当前位置: 首页 > news >正文

《Pytorch深度学习实践》ch5-Logistic回归

                                                        ------B站《刘二大人》

1.Classification

  • 经典的分类数据集:MNIST(0 - 9)

  • 导入数据集:(路径,训练集/测试集,是否下载)
import torchvision
train_set = torchvision.datasets.MINIST(root='../dataset/mnist', train=True,  download=True)
test_set  = torchvision.datasets.MINIST(root='../dataset/mnist', train=False, download=True)

2.Sigmoid functions

  • 由于分类问题就是求概率的最大值,所以利用 S 函数将数值全部映射到 [0,1] 区间;
  • 最著名的就是这个 Logistic 函数:

  • 其它的 S 函数:

3.Logistic Regression Model

  • 就是在原函数基础上加一个 Sigmoid:

4.Loss and BCE

  • BCE:Binary Cross Entropy,二元交叉熵损失:

5.Implemetation

  • 导包:
import torch
import torch.nn.functional as F
  • 数据集:y 变为 {0,1},二分类
# 数据集
x_data = torch.Tensor([[1.0], [2.0], [3.0]])
y_data = torch.Tensor([[0], [0], [1]])
  • 模型:F.sigmoid()函数
# 模型
class LogisticRegressionModel(torch.nn.Module): # Module 构建计算图def __init__(self):super(LogisticRegressionModel, self).__init__()self.linear = torch.nn.Linear(1, 1) def forward(self, x): # 前馈y_pred = F.sigmoid(self.linear(x))return y_predmodel = LogisticRegressionModel() # 实例化
  •  损失和优化器:BCELoss
# 损失函数和优化器
criterion = torch.nn.BCELoss(reduction = 'sum') # 计算损失,参数为(y_pred, y)optimizer = torch.optim.SGD(model.parameters(), lr = 0.01) # 进行更新
  • 训练:
# 训练
for epoch in range(1000):y_pred = model(x_data)loss = criterion(y_pred, y_data) # 1.前馈print(epoch, loss)optimizer.zero_grad() # 梯度清零loss.backward() # 2.反馈optimizer.step() # 3.更新

6.Result

import numpy as np
import matplotlib.pyplot as pltx = np.linspace(0, 10, 200)
x_t = torch.Tensor(x).view((200,1)) # 将x数组转换为PyTorch张量,并将其形状调整为列向量(200x1)
y_t = model(x_t)
y = y_t.data.numpy() # 将输出张量y_t转换为NumPy数组yplt.plot(x, y)
plt.plot([0, 10], [0.5, 0.5], c='r') # 绘制一条从x=0到x=10的红色水平线,y值为0.5
plt.xlabel('Hours')
plt.ylabel('Probability of Pass')
plt.grid()
plt.show()
  • 绘图如下:

http://www.lqws.cn/news/111907.html

相关文章:

  • ollama的安装及加速下载技巧
  • VBA模拟进度条
  • 缩量和放量指的是什么?
  • 二叉树(二)
  • Windows应用-音视频捕获
  • 嵌入式SDK技术EasyRTC音视频实时通话助力即时通信社交/教育等多场景创新应用
  • Win11系统不推送24H2/西数SSD无法安装24H2 - 解决方案
  • 6.4 note
  • 【请关注】VC内存泄露的排除及处理
  • 数据加密标准(DES)解析及代码实现(java)
  • 解决Vditor加载Markdown网页很慢的问题(Vite+JS+Vditor)
  • 后台管理系统八股
  • VRRP虚拟路由器协议的基本概述
  • 【Bluedroid】蓝牙启动之sdp_init 源码解析
  • Win11/Win10 打不开 gpedit.msc 之 组策略编辑器安装
  • 文件IO流
  • 生成JavaDoc文档
  • 安科电动机保护器通过ModbusRTU转profinet网关与PLC通讯
  • PowerShell脚本编程基础指南
  • Python爬虫解析动态网页:从渲染到数据提取
  • MAU算法流程理解
  • OpenEMMA: 打破Waymo闭源,首个开源端到端多模态模型
  • MPLS-EVPN笔记详述
  • 内存 DC(双缓冲)是个什么东西?
  • RM-R1:基于推理任务构建奖励模型
  • 飞腾D2000,麒麟系统V10,docker,ubuntu1804,小白入门喂饭级教程
  • JavaWeb是什么?总结一下JavaWeb的体系
  • 68道Hbase高频题整理(附答案背诵版)
  • RAG架构中用到的模型学习思考
  • 互联网三高架构 一