【机器学习】非参数贝叶斯回归方法 GPR
目录
- 🎯 一、GPR 方法原理
- 1.1 核心思想
- 1.2 回归模型设定
- 1.3 后验推断
- 1.4 常见核函数
- 🛠️ 二、Python 实现 GPR(含数据生成)
- 2.1 使用 scikit-learn 实现 GPR(基础案例)
- 参考
Gaussian Process Regression (GPR) 是一种强大的 非参数贝叶斯回归方法,适用于拟合非线性关系,并能提供预测的不确定性。
与传统的线性回归模型不同,GPR 能够通过指定的核函数捕捉复杂的非线性关系,并提供不确定性的估计。
下面将详细介绍 GPR 的原理、实现步骤,并附上完整的 Python 实现示例,包括数据生成过程。
🎯 一、GPR 方法原理
1.1 核心思想
GPR 假设目标函数是一个 高斯过程(Gaussian Process) 的样本。高斯过程是定义在输入空间上的 随机函数集合,任意有限个点上的函数值服从联合高斯分布:
1.2 回归模型设定
假设我们有训练数据:
1.3 后验推断
构建联合高斯分布:
1.4 常见核函数
核函数是 GPR 的核心,它决定了模型的平滑度、周期性等特性。选择合适的核函数可以显著提高模型的性能。常见的核函数包括:
- RBF (Gaussian) Kernel:适用于平滑且连续的函数建模。
- Matern Kernel
- Polynomial Kernel
- Dot Product Kernel
- 线性核:适用于线性关系建模。
核函数的形式和参数需要根据具体问题进行选择和调整。
🛠️ 二、Python 实现 GPR(含数据生成)
2.1 使用 scikit-learn 实现 GPR(基础案例)
核函数设计:RBF 核 + 白噪声 WhiteKernel 是常见组合。
参数优化:n_restarts_optimizer 用于多次尝试最优化超参数。
不确定性估计:GPR 不仅给出预测值,还提供置信区间(方差)。
完整Python代码如下:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import RBF, WhiteKernel, ConstantKernel as C
import matplotlib as mpl# 设置字体
mpl.rcParams['font.family'] = 'Times New Roman'# 1. 生成训练数据
np.random.seed(42)
X_train = np.linspace(0, 5, 20).reshape(-1, 1)
y_train = np.sin(X_train).ravel() + np.random.normal(0, 0.1, X_train.shape[0])# 2. 构造核函数
kernel = C(1.0) * RBF(length_scale=1.0) + WhiteKernel(noise_level=0.1)# 3. 初始化并训练 GPR 模型
gpr = GaussianProcessRegressor(kernel=kernel, n_restarts_optimizer=10)
gpr.fit(X_train, y_train)# 4. 测试点预测
X_test = np.linspace(0, 6, 100).reshape(-1, 1)
y_pred, y_std = gpr.predict(X_test, return_std=True)# 5. 可视化结果
plt.figure(figsize=(10, 6))
plt.plot(X_train, y_train, 'ro', label="Training Data")
plt.plot(X_test, y_pred, 'b-', label="Mean Prediction")
plt.fill_between(X_test.ravel(),y_pred - 1.96 * y_std,y_pred + 1.96 * y_std,alpha=0.3, color='blue', label="95% Confidence Interval")plt.title("Gaussian Process Regression", fontsize=16,fontweight='bold')
plt.xlabel("x", fontsize=16,fontweight='bold')
plt.ylabel("f(x)", fontsize=16,fontweight='bold')plt.legend()
plt.grid(True)
plt.show()