【机器学习深度学习】交互式线性回归 demo
目录
一、环境准备
二、Demo 功能
三、完整交互 demo 代码
3.1 执行代码
3.2 示例交互演示
3.3 运行结果
3.4 运行线性图
使用 PyTorch 构建交互式线性回归模型:输入数据、拟合直线、图像可视化并实现实时预测,助你深入理解机器学习从数据到模型的全过程。
一、环境准备
需要你本地能跑 Python + matplotlib
+ PyTorch
,无需其他安装。
二、Demo 功能
-
你输入一些样本点(比如面积和价格)
-
模型会学出一条最合适的线
-
自动训练并画图展示
-
你输入新数据,它来预测输出!
三、完整交互 demo 代码
PyTorch 的基本训练流程:
输入 → 线性模型建模 → 拟合 → 可视化 → 预测
3.1 执行代码
import torch
import torch.nn as nn
import matplotlib.pyplot as plt# 1. 让你输入数据(面积 → 房价)
print("请输入一些训练数据(输入特征和输出标签),格式如:50 100")
print("输入完后输入空行结束")X_list = []
y_list = []while True:line = input("请输入一组 (x y): ")if not line.strip():breaktry:x, y = map(float, line.strip().split())X_list.append([x])y_list.append([y])except:print("格式错误,请输入两个数字")X = torch.tensor(X_list, dtype=torch.float32)
y = torch.tensor(y_list, dtype=torch.float32)# 2. 定义模型
model = nn.Linear(1, 1)
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)# 3. 训练模型
epochs = 1000
for epoch in range(epochs):optimizer.zero_grad()outputs = model(X)loss = criterion(outputs, y)loss.backward()optimizer.step()# 4. 显示结果
print("\n训练完成!")
w = model.weight.item()
b = model.bias.item()
print(f"模型学到的函数为: y = {w:.2f} * x + {b:.2f}")# 5. 可视化
with torch.no_grad():predicted = model(X).numpy()plt.scatter(X.numpy(), y.numpy(), label='原始数据')
plt.plot(X.numpy(), predicted, 'r-', label='拟合直线')
plt.title(f"y = {w:.2f}x + {b:.2f}")
plt.legend()
plt.grid(True)
plt.show()# 6. 输入新数据做预测
while True:new_x = input("\n输入新的 x(或输入 'q' 退出预测):")if new_x.lower() == 'q':breaktry:x_value = torch.tensor([[float(new_x)]])y_pred = model(x_value).item()print(f"预测值为:y = {y_pred:.2f}")except:print("请输入有效数字")
3.2 示例交互演示
请输入一组 (x y): 50 100
请输入一组 (x y): 60 120
请输入一组 (x y): 70 140
请输入一组 (x y): 80 160
请输入一组 (x y):
你会看到:
-
模型学到函数
y = 2.00 * x + 0.00
-
会弹出一张图,直线穿过这些点
-
然后你可以输入
90
,它预测180
3.3 运行结果
请输入一些训练数据(输入特征和输出标签),格式如:50 100
输入完后输入空行结束
请输入一组 (x y): 1 20
请输入一组 (x y): 2 40
请输入一组 (x y): 3 60
请输入一组 (x y): 4 80
请输入一组 (x y): 训练完成!
模型学到的函数为: y = 19.88 * x + 0.36输入新的 x(或输入 'q' 退出预测):80
预测值为:y = 1590.59输入新的 x(或输入 'q' 退出预测):5
预测值为:y = 99.75输入新的 x(或输入 'q' 退出预测):q