深度学习模型部署与加速汇总
深度学习模型的部署与加速方法归纳汇总
1. 模型训练
- 使用 PyTorch、TensorFlow、Keras、MXNet 等框架进行训练
- 输出:.pt, .pth, .h5, .pb 等格式的模型文件
2. 模型转换
将原始模型转换为适合部署的格式:
- ONNX
- TensorRT 格式(.engine)
- OpenVINO IR 格式(.xml, .bin)
- TFLite(用于移动端)
使用 ONNX + ONNX Runtime 部署
安装依赖:
sudo apt-get update
sudo apt-get install python3-pip
pip3 install torch onnx onnxruntime
# 导出ONNX模型
import torch
import torchvisionmodel = torchvision.models.resnet18(pretrained=True)
model.eval()
dummy_input = torch.randn(1, 3, 224, 224)
torch.onnx.export(model, dummy_input, "resnet18.onnx")
# 使用 ONNX Runtime 推理
import onnxruntime as ort
import numpy as nport_session = ort.InferenceSession("resnet18.onnx")
outputs = ort_session.run(None,{'input': dummy_input.numpy()}
)
print(outputs)
3. 模型优化
- 量化(Quantization):FP32 → FP16 / INT8
- 剪枝(Pruning)
- 蒸馏(Distillation)
常用模型加速方法
方法 | 描述 | 工具 |
---|---|---|
FP16 推理 | 使用半精度浮点数降低计算量 | TensorRT, OpenVINO |
INT8 量化 | 使用整型代替浮点运算 | TensorRT, OpenVINO |
模型剪枝 | 减少冗余参数 | PyTorch/TensorFlow 自带工具 |
知识蒸馏 | 小模型模仿大模型输出 | 自定义损失函数 |
模型压缩(如 MobileNet, EfficientNet) | 使用轻量级网络结构 | TensorFlow Lite, MMDetection |
异构计算(GPU/CPU/NPU) | 利用不同硬件加速 | CUDA, OpenCL, VPU |
PyTorch 中剪枝 ResNet
import torch.nn.utils.prune as prunemodel = torchvision.models.resnet18(pretrained=True)
# 对 conv1 层做 L1 无结构剪枝
prune.l1_unstructured(model.conv1, name='weight', amount=0.3) # 剪掉 30% 权重
prune.remove(model.conv1, 'weight') # 固定剪枝后的权重
PyTorch 中简单蒸馏训练
teacher_model.eval()
student_model.train()criterion_ce = torch.nn.CrossEntropyLoss()
criterion_kl = torch.nn.KLDivLoss(reduction="batchmean")for images, labels in dataloader:with torch.no_grad():teacher_logits = teacher_model(images)student_logits = student_model(images)loss_ce = criterion_ce(student_logits, labels)loss_kd = criterion_kl(F.log_softmax(student_logits / T, dim=1),F.softmax(teacher_logits / T, dim=1)) * (alpha * T * T)total_loss = loss_ce + loss_kdoptimizer.zero_grad()total_loss.backward()optimizer.step()
4. 模型部署
- CPU 推理(OpenVINO, ONNX Runtime)
- GPU 推理(TensorRT, CUDA)
- 边缘设备(Jetson, RK3399, Coral TPU)
使用 Flask提供 REST API
from fastapi import FastAPI
import uvicornapp = FastAPI()@app.post("/predict")
def predict(data: dict):# 模型推理逻辑return {"result": ...}uvicorn.run(app, host="0.0.0.0", port=5000)
使用 Triton Inference Server
适用于多模型、多框架部署,支持动态批处理、模型热加载等高级功能。
# 启动 Triton
docker run --gpus all --rm -p8000:8000 -p8001:8001 -p8002:8002 nvcr.io/nvidia/tritonserver:23.09-py3
5. 接口封装
- REST API(Flask, FastAPI)
- gRPC
- 模型服务化(Triton Inference Server, TorchServe)