当前位置: 首页 > news >正文

【机器学习第一期(Python)】梯度提升决策树 GBDT

目录

  • 基础知识
    • 决策树(Decision Tree)
    • 回归树
  • 📌 一、GBDT 原理概述
    • 1.1 Boosting 思想
    • 1.2 GBDT 的关键思想
  • 🧠 二、GBDT 算法流程
  • 🐍 三、Python 实现步骤(以回归为例)
    • ✨ 可调参数总结(调参重点)
    • 3.1 数据准备
    • 3.2 模型训练
    • 3.3 模型评估
  • 四、完整案例(Python)
  • 参考

梯度提升决策树(GBDT:Gradient Boosting Decision Tree)是一种集成学习方法,广泛用于分类和回归任务。它结合了多个 弱学习器(通常是决策树) 的预测结果,通过梯度下降策略不断优化模型残差,从而提升整体预测性能。

基础知识

决策树(Decision Tree)

决策树(Decision Tree) 是一种基于树形结构进行决策分析的方法。它利用树形结构来表示各种决策结果之间的关系,并且可以用于 分类和回归分析等任务。其关键思想是通过递归地划分数据集,找到最优的特征及其划分点,使得子数据集尽可能纯净(即目标变量具有较小的方差或熵)。

决策树的构建过程可以分为两个步骤:

  • 第一步是选择一个合适的划分属性,将数据集划分成多个子集;
  • 第二步是针对每个子集递归地重复进行第一步,直到所有的子集都属于同一类别或者达到了预定义的停止条件。

在决策树的构建过程中,需要选择一个合适的划分属性。常见的划分属性选择方法有信息增益、信息增益比、基尼指数等。在选择划分属性时,通常选择对分类结果的影响最大的属性,以达到最优的划分效果。

决策树的优点包括易于理解和解释、计算复杂度低等。同时,决策树也有一些缺点,比如容易过拟合、对数据的噪声和异常值敏感等。
在这里插入图片描述

回归树

回归树(Regression Tree) 是决策树的回归版本,主要用于预测连续变量。它的核心思想是:

  • 选择一个特征及其划分点,使得划分后子数据集的均方误差(MSE)最小。
  • 递归执行上述过程,直到达到停止条件(如叶子节点样本数小于某个阈值)。
  • 预测时,输入样本沿着树结构从根节点到叶子节点,并返回叶子节点的均值作为最终预测值。

📌 一、GBDT 原理概述

1.1 Boosting 思想

Boosting 是一种序列学习方法,它通过加法模型将多个弱学习器组合成一个强学习器。
每一次迭代,模型都会试图拟合上一次模型的预测误差(残差)。

1.2 GBDT 的关键思想

使用CART回归树作为弱学习器。
每一步都在现有模型的基础上,通过最小化损失函数的负梯度方向来构建新的决策树。

GBDT 特性总结

特性描述
基学习器回归树(CART)
适用任务回归、二分类、多分类
优点高准确率、无需特征归一化、可处理非线性关系
缺点训练时间较长、难以并行

🧠 二、GBDT 算法流程

设训练数据为 ( x i , y i ) i = 1 n {(x_i, y_i)}_{i=1}^{n} (xi,yi)i=1n,我们的目标是学习一个函数 F ( x ) F(x) F(x) 来预测 y y y

2.1 初始化模型

选择一个常数值 F 0 ( x ) F_0(x) F0(x) 作为初始模型,通常是使损失函数最小的常数:

F_0(x) = \arg\min_\gamma \sum_{i=1}^n L(y_i, \gamma)

2.2 每一轮迭代步骤(共 M M M 轮)

m = 1 m = 1 m=1 M M M

1、计算残差(负梯度):

r_{im} = -\left[\frac{\partial L(y_i, F(x_i))}{\partial F(x_i)}\right]_{F=F_{m-1}}

实际上,这就是损失函数在当前模型处的梯度。

2、拟合一棵回归树 h m ( x ) h_m(x) hm(x) 来预测残差 r i m r_{im} rim

3、计算每个叶子节点的输出值 γ j m \gamma_{jm} γjm,使得损失最小:

\gamma_{jm} = \arg\min_\gamma \sum_{x_i \in R_{jm}} L(y_i, F_{m-1}(x_i) + \gamma)

4、更新模型:

F_m(x) = F_{m-1}(x) + \nu \cdot h_m(x)

其中 ν \nu ν 是学习率,控制每次迭代的步长。

🐍 三、Python 实现步骤(以回归为例)

我们使用 sklearn 中的 GradientBoostingRegressor 实现一个简单的 GBDT 模型。

✨ 可调参数总结(调参重点)

参数含义建议
n_estimators树的数量越大模型越复杂,需配合早停
learning_rate学习率小学习率需较多树
max_depth树的最大深度控制模型复杂度
subsample子样本比例小于1可防止过拟合
min_samples_split内部节点再划分所需的最小样本数控制过拟合

3.1 数据准备

from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_splitX, y = make_regression(n_samples=1000, n_features=10, noise=0.1)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)

3.2 模型训练

from sklearn.ensemble import GradientBoostingRegressormodel = GradientBoostingRegressor(n_estimators=100,    # 弱学习器数量learning_rate=0.1,   # 学习率max_depth=3,         # 每棵树的深度random_state=42
)model.fit(X_train, y_train)

3.3 模型评估

from sklearn.metrics import mean_squared_errory_pred = model.predict(X_test)
mse = mean_squared_error(y_test, y_pred)print(f"Test MSE: {mse:.4f}")

四、完整案例(Python)

任务:使用 GBDT 拟合一个非线性函数
数据:人工生成的非线性回归数据
模型:Gradient Boosting Regressor

可视化:

  • 模型拟合曲线
  • 残差分析图
  • 训练集 vs 测试集预测效果

在这里插入图片描述

左图:拟合效果:拟合曲线很好地捕捉了数据的非线性趋势。

  • 蓝点:训练数据
  • 红点:测试数据
  • 黑线:GBDT 拟合曲线

右图:残差图:残差应随机分布在 y=0 附近,没有明显模式,表明模型拟合良好。

输出结果为:

Train MSE: 0.0110
Test MSE: 0.0465

完整Python实现代码如下:(可以尝试调整参数如 n_estimators, learning_rate, max_depth 来观察拟合效果的变化)

import numpy as np
import matplotlib.pyplot as plt
from sklearn.ensemble import GradientBoostingRegressor
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error# 1. 生成非线性数据
np.random.seed(42)
X = np.linspace(0, 10, 200).reshape(-1, 1)
y = np.sin(X).ravel() + np.random.normal(0, 0.2, X.shape[0])  # 添加噪声# 2. 拆分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)# 3. 拟合 GBDT 模型
model = GradientBoostingRegressor(n_estimators=100,learning_rate=0.1,max_depth=3,random_state=42
)
model.fit(X_train, y_train)# 4. 预测
y_train_pred = model.predict(X_train)
y_test_pred = model.predict(X_test)# 5. 模型评估
train_mse = mean_squared_error(y_train, y_train_pred)
test_mse = mean_squared_error(y_test, y_test_pred)print(f"Train MSE: {train_mse:.4f}")
print(f"Test MSE: {test_mse:.4f}")# 6. 可视化:模型拟合效果
plt.figure(figsize=(12, 6))# 原始数据 + 拟合线
plt.subplot(1, 2, 1)
plt.scatter(X_train, y_train, color='lightblue', label='Train Data', alpha=0.6)
plt.scatter(X_test, y_test, color='lightcoral', label='Test Data', alpha=0.6)# 绘制平滑预测曲线
X_all = np.linspace(0, 10, 1000).reshape(-1, 1)
y_all_pred = model.predict(X_all)
plt.plot(X_all, y_all_pred, color='black', label='GBDT Prediction', linewidth=2)plt.title("GBDT Model Fit")
plt.xlabel("X")
plt.ylabel("y")
plt.legend()
plt.grid(True)# 7. 可视化:残差图
plt.subplot(1, 2, 2)
train_residuals = y_train - y_train_pred
test_residuals = y_test - y_test_predplt.scatter(y_train_pred, train_residuals, color='blue', alpha=0.6, label='Train Residuals')
plt.scatter(y_test_pred, test_residuals, color='red', alpha=0.6, label='Test Residuals')
plt.axhline(y=0, color='black', linestyle='--')
plt.xlabel("Predicted y")
plt.ylabel("Residuals")
plt.title("Residual Plot")
plt.legend()
plt.grid(True)plt.tight_layout()
plt.show()

参考

http://www.lqws.cn/news/529939.html

相关文章:

  • Pycharm无法运行Vue项目的解决办法
  • Java 泛型详解:从入门到实战
  • jdbc实现跨库分页查询demo
  • 人力资源管理系统
  • Spring Cloud Config动态刷新实战指南
  • 用户统计-01.需求分析和设计
  • GNSS位移监测站在大坝安全中的用处
  • 渗透实战:使用隐式转换覆盖toString的反射型xss
  • Day43 复习日 图像数据集——CNN
  • 【PX4-AutoPilot教程-TIPS】PX4系统命令行控制台ConsolesShells常用命令(持续更新)
  • ES文件管理器v4.4.3(ES文件浏览器)
  • 鸿蒙 FoldSplitContainer 解析:折叠屏布局适配与状态管理
  • MySQL之存储函数与触发器详解
  • 多相机人脸扫描设备如何助力高效打造数字教育孪生体?
  • ethers.js express vue2 定时任务每天凌晨2点监听合约地址数据同步到Mysql整理
  • ASIO 避坑指南:高效、安全与稳健的异步网络编程
  • 微服务架构下面临的安全、合规审计挑战
  • Python打卡:Day37
  • 使用 Python 自动化文件获取:从 FTP 到 API 的全面指南
  • 【Bluedroid】蓝牙启动之 btm_acl_device_down 流程源码解析
  • 稳定币技术全解:从货币锚定机制到区块链金融基础设施
  • Java底层原理:深入理解线程与并发机制
  • GEO生成式引擎优化发展迅猛:热点数智化传播是GEO最佳路径
  • 人大金仓Kingbase数据库KSQL 常用命令指南
  • 【论文】云原生事件驱动架构在智能风控系统中的实践与思考
  • 小孙学变频学习笔记(八)变频器的输入电流(下)
  • RPC(Remote Procedure Call)技术解析
  • 计算机网络 网络层:控制平面(二)
  • WPF中Converter基础用法
  • 正则表达式,`[]`(字符类)和`|`(或操作符)