当前位置: 首页 > news >正文

【机器学习】非参数贝叶斯回归方法 GPR

目录

  • 🎯 一、GPR 方法原理
    • 1.1 核心思想
    • 1.2 回归模型设定
    • 1.3 后验推断
    • 1.4 常见核函数
  • 🛠️ 二、Python 实现 GPR(含数据生成)
    • 2.1 使用 scikit-learn 实现 GPR(基础案例)
  • 参考

Gaussian Process Regression (GPR) 是一种强大的 非参数贝叶斯回归方法,适用于拟合非线性关系,并能提供预测的不确定性。

与传统的线性回归模型不同,GPR 能够通过指定的核函数捕捉复杂的非线性关系,并提供不确定性的估计。

下面将详细介绍 GPR 的原理、实现步骤,并附上完整的 Python 实现示例,包括数据生成过程。

🎯 一、GPR 方法原理

1.1 核心思想

GPR 假设目标函数是一个 高斯过程(Gaussian Process) 的样本。高斯过程是定义在输入空间上的 随机函数集合,任意有限个点上的函数值服从联合高斯分布:
在这里插入图片描述

1.2 回归模型设定

假设我们有训练数据:
在这里插入图片描述

1.3 后验推断

构建联合高斯分布:
在这里插入图片描述

1.4 常见核函数

核函数是 GPR 的核心,它决定了模型的平滑度、周期性等特性。选择合适的核函数可以显著提高模型的性能。常见的核函数包括:

  • RBF (Gaussian) Kernel:适用于平滑且连续的函数建模。
    在这里插入图片描述
  • Matern Kernel
  • Polynomial Kernel
  • Dot Product Kernel
  • 线性核:适用于线性关系建模。

核函数的形式和参数需要根据具体问题进行选择和调整。

🛠️ 二、Python 实现 GPR(含数据生成)

2.1 使用 scikit-learn 实现 GPR(基础案例)

核函数设计:RBF 核 + 白噪声 WhiteKernel 是常见组合。
参数优化:n_restarts_optimizer 用于多次尝试最优化超参数。
不确定性估计:GPR 不仅给出预测值,还提供置信区间(方差)。

在这里插入图片描述

完整Python代码如下:

import numpy as np
import matplotlib.pyplot as plt
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import RBF, WhiteKernel, ConstantKernel as C
import matplotlib as mpl# 设置字体
mpl.rcParams['font.family'] = 'Times New Roman'# 1. 生成训练数据
np.random.seed(42)
X_train = np.linspace(0, 5, 20).reshape(-1, 1)
y_train = np.sin(X_train).ravel() + np.random.normal(0, 0.1, X_train.shape[0])# 2. 构造核函数
kernel = C(1.0) * RBF(length_scale=1.0) + WhiteKernel(noise_level=0.1)# 3. 初始化并训练 GPR 模型
gpr = GaussianProcessRegressor(kernel=kernel, n_restarts_optimizer=10)
gpr.fit(X_train, y_train)# 4. 测试点预测
X_test = np.linspace(0, 6, 100).reshape(-1, 1)
y_pred, y_std = gpr.predict(X_test, return_std=True)# 5. 可视化结果
plt.figure(figsize=(10, 6))
plt.plot(X_train, y_train, 'ro', label="Training Data")
plt.plot(X_test, y_pred, 'b-', label="Mean Prediction")
plt.fill_between(X_test.ravel(),y_pred - 1.96 * y_std,y_pred + 1.96 * y_std,alpha=0.3, color='blue', label="95% Confidence Interval")plt.title("Gaussian Process Regression", fontsize=16,fontweight='bold')
plt.xlabel("x", fontsize=16,fontweight='bold')
plt.ylabel("f(x)", fontsize=16,fontweight='bold')plt.legend()
plt.grid(True)
plt.show()

参考

http://www.lqws.cn/news/523819.html

相关文章:

  • ipfs在windows下载和安装
  • JSON框架转化isSuccess()为sucess字段
  • C++(智能指针)
  • Liunx操作系统笔记2
  • linux-修改文件命令(补充)
  • IT运维效率提升: 当IT监控遇上3D可视化
  • 三步实现B站缓存视频转MP4格式
  • 记一次AWS 中RDS优化费用使用的案例
  • Postman鉴权动态传参?对比脚本变量vs环境变量!
  • 理论加案例,一文读懂数据分析中的分类建模
  • 通过pyqt5学习MVC
  • 代理型 AI 重塑营销格局:国产 R²AIN SUITE 如何破解数据与技术瓶颈,实现 AI 赋能全链路提效
  • VScode常用快捷键【个人总结】
  • 2024年AEI SCI1区TOP,强化学习人工兔优化算法RLTARO+山地森林地形无人机编队路径规划,深度解析+性能实测
  • Dify、n8n、Coze、FastGPT 和 Ragflow 对比分析:如何选择最适合你的智能体平台?
  • Wpf的Binding
  • 数据库1.0
  • Python 爬虫入门:从数据爬取到转存 MySQL 数据库
  • 【Ansible】Ansible入门
  • Git常用操作详解
  • Python核心可视化库:Matplotlib与Seaborn深度解析
  • React 第六十四节Router中HashRouter的使用详细介绍及案例分析
  • 重置 MySQL root 密码
  • 基于STM32的智能节能风扇的设计
  • KNN算法(K近邻算法)
  • K8s在centos7安装及kubectl
  • 50天50个小项目 (Vue3 + Tailwindcss V4) ✨ | BackgroundSlider(背景滑块)
  • 设备维修全流程记录,提升设备运维效率
  • 前端面试专栏-主流框架:13.vue3组件通信与生命周期
  • 【MPC】实战:基于MPC的车辆自适应巡航控制 (ACC) 系统设计