[特殊字符]【联邦学习实战】用 PyTorch 从 0 搭建一个最简单的联邦学习系统(含完整代码)
💡 本文是联邦学习系列的第二篇,适合有一定 PyTorch 基础的读者。
✅ 适合练手
✅ 模拟真实场景
✅ 附完整可运行代码
📖 目录
- 一、项目简介
- 二、环境准备
- 三、模型定义
- 四、数据划分:模拟客户端
- 五、本地训练函数
- 六、参数聚合函数(FedAvg)
- 七、联邦主流程
- 八、测试准确率
- 九、总结与思考
- 🔚 下一步建议
一、项目简介
本项目通过 PyTorch 模拟联邦学习流程:
- 两个客户端(Client)
- 每个客户端本地训练模型
- 不共享数据,仅共享模型参数
- 服务端聚合模型,更新全局参数
📌 任务:手写数字识别(MNIST)
二、环境准备
pip install torch torchvision matplotlib
三、模型定义
我们使用一个简单的 2 层全连接神经网络(MLP):
import torch
import torch.nn as nn
import torch.nn.functional as Fclass MLP(nn.Module):def __init__(self):super(MLP, self).__init__()self.fc1 = nn.Linear(28*28, 128)self.fc2 = nn.Linear(128, 10)def forward(self, x):x = x.view(-1, 28*28) # flattenx = F.relu(self.fc1(x))x = self.fc2(x)return x
四、数据划分:模拟客户端
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_splittransform = transforms.ToTensor()
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)# 模拟两个客户端,每人一半数据
client1_data, client2_data = random_split(train_dataset, [30000, 30000])
client1_loader = DataLoader(client1_data, batch_size=64, shuffle=True)
client2_loader = DataLoader(client2_data, batch_size=64, shuffle=True)
五、本地训练函数
def train_local(model, dataloader, epochs=1):model.train()optimizer = torch.optim.SGD(model.parameters(), lr=0.01)loss_fn = nn.CrossEntropyLoss()for _ in range(epochs):for X, y in dataloader:optimizer.zero_grad()output = model(X)loss = loss_fn(output, y)loss.backward()optimizer.step()return model.state_dict()
六、参数聚合函数(FedAvg)
def average_weights(w1, w2):avg_weights = {}for key in w1.keys():avg_weights[key] = (w1[key] + w2[key]) / 2return avg_weights
七、联邦主流程
global_model = MLP()for round in range(5):print(f"\n联邦训练第 {round+1} 轮")# 拷贝模型到客户端client1_model = MLP()client2_model = MLP()client1_model.load_state_dict(global_model.state_dict())client2_model.load_state_dict(global_model.state_dict())# 本地训练w1 = train_local(client1_model, client1_loader)w2 = train_local(client2_model, client2_loader)# 聚合参数new_weights = average_weights(w1, w2)global_model.load_state_dict(new_weights)
八、测试准确率
test_loader = DataLoader(datasets.MNIST('./data', train=False, download=True, transform=transform), batch_size=1000)def test(model):model.eval()correct = 0total = 0with torch.no_grad():for X, y in test_loader:outputs = model(X)_, predicted = torch.max(outputs, 1)correct += (predicted == y).sum().item()total += y.size(0)print(f"测试准确率: {correct / total * 100:.2f}%")test(global_model)
九、总结与思考
- 本例展示了 联邦学习核心思想:数据不动,模型移动
- 在本地训练的同时保护了“数据隐私”
- 虽然只有两个客户端,但结构已经接近真实场景
🔚 下一步建议
想继续深入?可以尝试以下方向:
- 💡 模拟多个客户端(5 个以上)
- 💡 客户端数据不均衡(Non-IID)
- 💡 聚合方式改进(如加权平均)
- 💡 加入差分隐私(DP-SGD)
- 💡 使用 CNN 模型替代 MLP
🗣 互动交流
💬 你是否还想看 “联邦学习 + 差分隐私”、“联邦学习在医疗中的模拟” 等内容?欢迎留言!
👍 点赞 + ⭐ 收藏 = 给我更多创作动力!