当前位置: 首页 > news >正文

《Pytorch深度学习实践》ch3-反向传播

                                                           ------B站《刘二大人》

1.Introduction

  • 在神经网络中,可以看到权重非常多,计算 loss 对 w 的偏导非常困难,于是引入了反向传播方法;

2.Backward

  • 这里模型为  y = x * w,所以要计算的偏导数为 loss 对 w;

  • 这里模型为  y = x * w + b,所以要计算的偏导数为 loss 对 w 和 loss 对 b;
  • 模型有几个初始变量,就要求几个偏导;

3.Tensor

  • Pytorch 里常用的一种数据类型为 Tensor,包含两种值;
  • item():提取数值,不会保留 Tensor 结构,也不能用于更新权重。
  • data:直接修改 Tensor 数据(可以更新权重等),但不会影响梯度计算或反向传播。它允许修改 Tensor,而不会触发梯度计算。
  • 直接用原梯度才会触发梯度计算。

4.Implementation

import torch
import matplotlib.pyplot as plt# 数据集
x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]# 权重(Tensor)
w = torch.Tensor([1.0])
w.requires_grad = True# 模型(自动变为 Tensor 间的运算)
def forward(x): return x * w# 损失函数
def loss(x, y):y_pred = forward(x)return (y_pred - y) ** 2# 训练轮数 epoch 为横坐标,损失 loss 为纵坐标
epoch_list = []
loss_list = []# 计算 loss - epoch
print('Predict (before training)', 4, forward(4).item())for epoch in range(100):for x, y in zip(x_data, y_data):loss_val = loss(x , y)loss_val.backward() # 反馈完计算图会被释放print('\tgrad:', x, y, w.grad.item())w.data -= 0.01 * w.grad.data # data 不会加入计算图w.grad.data.zero_() # 清除当前梯度,以防止累积epoch_list.append(epoch)loss_list.append(loss_val.item())print('progress:', epoch, loss_val.item())print('Predict (after training)', 4, forward(4).item())# 绘图
plt.plot(epoch_list, loss_list)
plt.xlabel('epoch')
plt.ylabel('loss')
plt.grid()
plt.show()
  • 绘图如下:

http://www.lqws.cn/news/79615.html

相关文章:

  • 数字化转型全场景安全解析:从产品到管理的防线构建与实施要点
  • 自适应流量调度用于遥操作:面向时间敏感网络的通信与控制协同优化框架
  • 用wireshark抓包分析学习USB协议
  • 04powerbi-度量值-筛选引擎CALCULATE()
  • 吴恩达MCP课程(5):research_server_prompt_resource.py
  • 光伏功率预测 | BiLSTM多变量单步光伏功率预测(Matlab完整源码和数据)
  • HTML 等价字符引用:系统化记忆指南
  • 网络攻防技术五:网络扫描技术
  • Linux中的mysql逻辑备份与恢复
  • 二叉树的层序遍历与完全二叉树判断
  • HarmonyOS鸿蒙Taro跨端框架
  • 已有的前端项目打包到tauri运行(windows)
  • AI智能体|扣子(Coze)搭建【合同/文档审查】工作流
  • SpringBoot手动实现流式输出方案整理以及SSE规范输出详解
  • 从 LeetCode 到日志匹配:一行 Swift 实现规则识别
  • 【Godot】如何导出 Release 版本的安卓项目
  • Linux服务器安装GUI界面工具
  • Grafana对接Prometheus数据源
  • LlamaIndex的IngestionPipeline添加本地存储(本地文档存储)
  • 【深度学习】实验四 卷积神经网络CNN
  • 记录一次由打扑克牌测试国内各家大模型的经历
  • 2025年5月24日系统架构设计师考试题目回顾
  • 使用 OpenCV (C++) 进行人脸边缘提取
  • 大数据-275 Spark MLib - 基础介绍 机器学习算法 集成学习 随机森林 Bagging Boosting
  • shiro使用详解
  • Java后端优化:对象池模式解决高频ObjectMapper实例化问题及性能影响
  • 链式前向星图解
  • 【C++高级主题】转换与多个基类
  • InlineHook的原理与做法
  • 【TMS570LC4357】之相关驱动开发学习记录1