PyTorch学习指南
首页
基础篇
进阶篇
高级篇
实战项目
🚀 编程指南
首页
基础篇
进阶篇
高级篇
实战项目
🚀 编程指南
  • 📈 进阶篇

    • 📈 进阶篇概述
    • 🧠 构建神经网络
    • 🎓 模型训练
    • 💾 保存与加载模型
    • 🚀 GPU加速训练

💾 保存与加载模型

训练好的模型需要保存下来,以便后续使用。本节介绍PyTorch中保存和加载模型的方法。

📝 两种保存方式

PyTorch有两种保存模型的方式:

方式保存内容优点缺点
state_dict(推荐)只保存参数灵活、文件小需要先定义模型结构
整个模型模型结构+参数方便不够灵活、可能有兼容问题

💡 推荐方式:保存state_dict

保存模型

import torch
import torch.nn as nn

# 定义模型
class SimpleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 20)
        self.fc2 = nn.Linear(20, 1)
    
    def forward(self, x):
        x = torch.relu(self.fc1(x))
        return self.fc2(x)

model = SimpleNet()

# 假设已经训练好了...

# 保存模型参数
torch.save(model.state_dict(), 'model.pth')
print("模型已保存到 model.pth")

加载模型

import torch

# 必须先创建相同结构的模型
model = SimpleNet()

# 加载参数
model.load_state_dict(torch.load('model.pth'))

# 设置为评估模式
model.eval()

print("模型已加载!")

# 使用模型进行预测
x = torch.randn(5, 10)
with torch.no_grad():
    output = model(x)
    print(output)

⚠️ 注意

加载前必须先创建模型实例,且模型结构要和保存时一致!

📦 保存整个模型

# 保存整个模型(包括结构)
torch.save(model, 'model_full.pth')

# 加载整个模型
model = torch.load('model_full.pth')
model.eval()

🔴 不推荐

这种方式在某些情况下可能出问题:

  • 代码结构改变后可能无法加载
  • 不同PyTorch版本可能不兼容
  • 文件更大

🎯 保存训练检查点(Checkpoint)

在长时间训练中,我们需要定期保存检查点,以便中断后继续训练:

import torch
import torch.nn as nn
import torch.optim as optim

# 保存检查点
def save_checkpoint(model, optimizer, epoch, loss, path):
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
    }, path)
    print(f"检查点已保存: epoch {epoch}")

# 加载检查点
def load_checkpoint(path, model, optimizer):
    checkpoint = torch.load(path)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint['epoch']
    loss = checkpoint['loss']
    print(f"检查点已加载: epoch {epoch}, loss {loss:.4f}")
    return epoch, loss

# 使用示例
model = SimpleNet()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# 训练过程中保存
for epoch in range(100):
    # ... 训练代码 ...
    
    # 每10个epoch保存一次
    if (epoch + 1) % 10 == 0:
        save_checkpoint(model, optimizer, epoch, loss, f'checkpoint_epoch{epoch+1}.pth')

# 从检查点恢复训练
start_epoch, _ = load_checkpoint('checkpoint_epoch50.pth', model, optimizer)
for epoch in range(start_epoch, 100):
    # 继续训练...
    pass

🔧 保存最佳模型

best_val_loss = float('inf')
best_model_path = 'best_model.pth'

for epoch in range(num_epochs):
    # 训练...
    train_loss = train_epoch(model, train_loader, criterion, optimizer)
    
    # 验证...
    val_loss = validate(model, val_loader, criterion)
    
    # 如果是最佳模型,保存
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), best_model_path)
        print(f"Epoch {epoch}: 保存最佳模型 (val_loss: {val_loss:.4f})")

🖥️ 跨设备加载

GPU训练的模型加载到CPU

# 在GPU上训练并保存
device = torch.device('cuda')
model = SimpleNet().to(device)
# ... 训练 ...
torch.save(model.state_dict(), 'model_gpu.pth')

# 在CPU上加载
model = SimpleNet()
model.load_state_dict(torch.load('model_gpu.pth', map_location='cpu'))

CPU训练的模型加载到GPU

# 在CPU上训练并保存
model = SimpleNet()
# ... 训练 ...
torch.save(model.state_dict(), 'model_cpu.pth')

# 在GPU上加载
device = torch.device('cuda')
model = SimpleNet()
model.load_state_dict(torch.load('model_cpu.pth', map_location=device))
model.to(device)

通用写法

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# 加载时自动适配设备
model = SimpleNet()
model.load_state_dict(torch.load('model.pth', map_location=device))
model.to(device)
model.eval()

📋 保存额外信息

# 保存模型和其他信息
save_dict = {
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'epoch': epoch,
    'loss': loss,
    'accuracy': accuracy,
    'hyperparameters': {
        'learning_rate': 0.001,
        'batch_size': 32,
        'hidden_size': 256,
    },
    'class_names': ['cat', 'dog', 'bird'],
}

torch.save(save_dict, 'model_with_info.pth')

# 加载
checkpoint = torch.load('model_with_info.pth')
print(f"训练epoch: {checkpoint['epoch']}")
print(f"准确率: {checkpoint['accuracy']}")
print(f"类别: {checkpoint['class_names']}")

🎯 完整示例

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

# 模型
class Classifier(nn.Module):
    def __init__(self, input_size, hidden_size, num_classes):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, num_classes)
    
    def forward(self, x):
        x = torch.relu(self.fc1(x))
        return self.fc2(x)

# 训练函数
def train(model, train_loader, criterion, optimizer, device):
    model.train()
    for data, labels in train_loader:
        data, labels = data.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(data)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

# 主程序
def main():
    # 设置
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    # 模拟数据
    X = torch.randn(500, 20)
    y = torch.randint(0, 3, (500,))
    dataset = TensorDataset(X, y)
    train_loader = DataLoader(dataset, batch_size=32, shuffle=True)
    
    # 模型
    model = Classifier(20, 64, 3).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    
    # 训练
    for epoch in range(10):
        train(model, train_loader, criterion, optimizer, device)
        print(f"Epoch {epoch+1} 完成")
    
    # 保存
    torch.save({
        'model_state_dict': model.state_dict(),
        'input_size': 20,
        'hidden_size': 64,
        'num_classes': 3,
    }, 'classifier.pth')
    print("模型已保存!")

# 加载并使用
def predict():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    # 加载
    checkpoint = torch.load('classifier.pth', map_location=device)
    
    # 重建模型
    model = Classifier(
        checkpoint['input_size'],
        checkpoint['hidden_size'],
        checkpoint['num_classes']
    )
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)
    model.eval()
    
    # 预测
    x = torch.randn(5, 20).to(device)
    with torch.no_grad():
        outputs = model(x)
        predictions = outputs.argmax(dim=1)
        print(f"预测结果: {predictions.tolist()}")

if __name__ == '__main__':
    main()
    predict()

📝 最佳实践总结

  1. 优先使用state_dict:更灵活、更安全
  2. 保存检查点:长时间训练时定期保存
  3. 保存超参数:方便复现结果
  4. 注意设备:使用map_location处理跨设备加载
  5. 设置eval模式:推理时别忘了model.eval()

下一步

模型保存好了,让我们学习如何使用GPU加速训练!

上次更新: 2025/11/25 18:38
Prev
🎓 模型训练
Next
🚀 GPU加速训练