当前位置: 首页 > news > 正文

DAY 37 早停策略和模型权重的保存

  1. 早停策略

import time

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm

# Define the MLP model
class MLP(nn.Module):
    """Two-layer perceptron: Linear -> ReLU -> Linear.

    Args:
        input_dim: Number of input features. When omitted (``None``), it falls
            back to ``X_train.shape[1]`` from the surrounding script, which
            keeps the original zero-argument ``MLP()`` call working.
        hidden_dim: Width of the hidden layer (default 10, as before).
        output_dim: Number of output logits (default 2, for binary
            classification with ``nn.CrossEntropyLoss``).
    """

    def __init__(self, input_dim=None, hidden_dim=10, output_dim=2):
        super().__init__()
        if input_dim is None:
            # Backward compatibility: the original class read the feature
            # count from the global training matrix.
            input_dim = X_train.shape[1]
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Return raw (unnormalized) logits for the input batch ``x``."""
        out = self.fc1(x)
        out = self.relu(out)
        return self.fc2(out)
# Instantiate the model on the target device (`device` is defined elsewhere).
model = MLP().to(device)

# Loss function and optimizer (plain SGD, no momentum).
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

# Training settings.
num_epochs = 20000
# Non-improving early-stopping checks tolerated before training stops.
# NOTE(review): this is counted per early-stopping check in the training
# loop, which is not necessarily once per epoch — confirm against the loop.
early_stop_patience = 50

# Early-stopping state.
best_loss = float('inf')
best_epoch = 0
patience_counter = 0
early_stopped = False

# Loss history collected for the plot.
train_losses, test_losses, epochs = [], [], []
# How often (in epochs) losses are recorded and the early-stopping check runs.
# Patience is therefore measured in checks: training stops after
# `early_stop_patience` consecutive non-improving checks, i.e.
# `early_stop_patience * eval_interval` epochs without improvement.
eval_interval = 200
best_model_path = 'best_model.pth'


def _train_step():
    """Run one full-batch optimization step on the training set.

    Returns:
        The training loss of this step as a Python float.
    """
    model.train()
    optimizer.zero_grad()
    outputs = model(X_train)
    loss = criterion(outputs, y_train)
    loss.backward()
    optimizer.step()
    return loss.item()


def _compute_test_loss():
    """Return the test-set loss (computed without gradients) as a Python float."""
    model.eval()
    with torch.no_grad():
        return criterion(model(X_test), y_test).item()


# Start training
start_time = time.time()
last_epoch = 0  # number of epochs actually run by the main loop
with tqdm(total=num_epochs, desc="Training Progress", unit="epoch") as pbar:
    for epoch in range(num_epochs):
        train_loss = _train_step()
        last_epoch = epoch + 1
        # Advance once per epoch so the bar shows where training really
        # stopped instead of being force-filled after an early stop.
        pbar.update(1)

        if last_epoch % eval_interval != 0:
            continue

        # Test loss is only needed at checkpoints, so it is computed here.
        test_loss = _compute_test_loss()
        train_losses.append(train_loss)
        test_losses.append(test_loss)
        epochs.append(last_epoch)
        pbar.set_postfix({'Train Loss': f'{train_loss:.4f}', 'Test Loss': f'{test_loss:.4f}'})

        # Early stopping check.
        # NOTE(review): the test set doubles as the validation set here, so the
        # final test metrics are optimistically biased; a separate validation
        # split would be cleaner.
        if test_loss < best_loss:
            best_loss = test_loss
            best_epoch = last_epoch
            patience_counter = 0
            torch.save(model.state_dict(), best_model_path)
        else:
            patience_counter += 1
            if patience_counter >= early_stop_patience:
                early_stopped = True
                break

time_all = time.time() - start_time  # Calculate total training time

if early_stopped:
    print(f"Early stopping triggered! No improvement for {early_stop_patience} consecutive checks "
          f"({early_stop_patience * eval_interval} epochs).")
    print(f"Best test loss was at epoch {best_epoch} with a loss of {best_loss:.4f}")
print(f'Training time: {time_all:.2f} seconds')

if early_stopped:
    # Roll back to the checkpoint with the lowest test loss before continuing.
    # Optimizer state is not checkpointed; that is harmless for the
    # momentum-free SGD used here but would need saving otherwise.
    print(f"Loading best model from epoch {best_epoch} before continuing training...")
    model.load_state_dict(torch.load(best_model_path, map_location=device))

# Continue training for a fixed number of extra epochs. Their epoch numbers
# continue from the last epoch actually run (not from `num_epochs`), so the
# loss curves stay contiguous when early stopping cut the main loop short.
num_extra_epochs = 50
for extra_epoch in range(num_extra_epochs):
    train_loss = _train_step()
    test_loss = _compute_test_loss()
    epoch_number = last_epoch + extra_epoch + 1
    train_losses.append(train_loss)
    test_losses.append(test_loss)
    epochs.append(epoch_number)
    print(f"Epoch {epoch_number}: Train Loss = {train_loss:.4f}, Test Loss = {test_loss:.4f}")
# Plot the recorded training and test loss curves against their epoch numbers.
fig, ax = plt.subplots(figsize=(10, 6))
for series, series_label in ((train_losses, 'Train Loss'), (test_losses, 'Test Loss')):
    ax.plot(epochs, series, label=series_label)
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.set_title('Training and Test Loss over Epochs')
ax.legend()
ax.grid(True)
plt.show()
# Evaluate final accuracy on the test set with the model's current weights.
model.eval()
with torch.no_grad():
    logits = model(X_test)
    _, predicted = torch.max(logits, dim=1)  # index of the highest logit per row
    num_correct = (predicted == y_test).sum().item()
    num_samples = y_test.size(0)
    accuracy = num_correct / num_samples
    print(f'Test Accuracy: {accuracy * 100:.2f}%')

@浙大疏锦行

http://www.lqws.cn/news/452251.html

相关文章:

  • RPGMZ游戏引擎 如何手动控制文字显示速度
  • 机器翻译与跨语言学习数据集综述
  • 情感大模型
  • “地标界爱马仕”再拓疆域:世酒中菜联袂赤水金钗石斛定义中国GI
  • vue3 reactive重新赋值
  • QEMU学习之路(10)— RISCV64 virt 使用Ubuntu启动
  • Linux故障排查与性能优化实战经验
  • AI浪潮下的自媒体革命:智能体崛起与人类价值的重构
  • Qi无线充电:车载充电的便捷与安全之选
  • servlet前后端交互
  • C++设计模式
  • 在VTK中捕捉体绘制图像并实时图像处理
  • uniapp开发小程序,导出文件打开并保存,实现过程downloadFile下载,openDocument打开
  • 【Python】Excel表格操作:ISBN转条形码
  • React Native【实战范例】弹跳动画菜单导航
  • 学习threejs,三维汽车模拟器,场景有树、云、山等
  • Nginx-Ingress-Controller自定义端口实现TCP/UDP转发
  • 大数据系统架构实践(一):Zookeeper集群部署
  • 局域网投屏工具(将任何设备转换为计算机的辅助屏幕)Deskreen
  • LVS负载均衡群集:Nginx+Tomcat负载均衡群集
  • Lora训练
  • 项目管理利器:甘特图的全面解析与应用指南
  • 计算机网络八股第二期
  • net程序-Serilog 集成 SQL Server LocalDB 日志记录指南
  • 有方 N58 LTE Cat.1 模块联合 SD NAND 贴片式 TF 卡 MKDV1GIL-AST,打造 T-BOX 高性能解决方案
  • 如何在WordPress中添加导航菜单?
  • 基于 CNN-LSTM-GRU 架构的超音速导弹轨迹高级预测
  • Redis如何解决缓存击穿,缓存雪崩,缓存穿透
  • 技术革新赋能楼宇自控:物联网云计算推动应用前景深度拓展
  • 饼图:数据可视化的“切蛋糕”艺术