💾 Saving and Loading Models
A trained model needs to be saved so it can be reused later. This section covers how to save and load models in PyTorch.
📝 Two Ways to Save a Model
PyTorch offers two ways to save a model:
| Method | What Is Saved | Pros | Cons |
|---|---|---|---|
| state_dict (recommended) | Parameters only | Flexible, small file | Model structure must be defined first |
| Entire model | Structure + parameters | Convenient | Less flexible, possible compatibility issues |
💡 Recommended: Saving the state_dict
Saving a model
import torch
import torch.nn as nn

# Define the model
class SimpleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 20)
        self.fc2 = nn.Linear(20, 1)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        return self.fc2(x)

model = SimpleNet()
# Assume the model has already been trained...

# Save the model parameters
torch.save(model.state_dict(), 'model.pth')
print("Model saved to model.pth")
Loading a model
import torch

# A model with the same structure must be created first
model = SimpleNet()

# Load the parameters
model.load_state_dict(torch.load('model.pth'))

# Switch to evaluation mode
model.eval()
print("Model loaded!")

# Use the model for prediction
x = torch.randn(5, 10)
with torch.no_grad():
    output = model(x)
    print(output)
⚠️ Note
You must create a model instance before loading, and its structure must match the one used when saving! What a mismatch looks like is sketched below.
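If the architectures differ, load_state_dict raises an error by default. A minimal sketch of the failure mode and the strict=False escape hatch; the BiggerNet class here is hypothetical, just to illustrate a mismatch against the model.pth saved above:

import torch
import torch.nn as nn

class BiggerNet(nn.Module):
    # Hypothetical model: same fc1/fc2 as SimpleNet, plus an extra layer
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 20)
        self.fc2 = nn.Linear(20, 1)
        self.fc3 = nn.Linear(1, 1)  # not present in the saved file

    def forward(self, x):
        return self.fc3(self.fc2(torch.relu(self.fc1(x))))

model = BiggerNet()
state = torch.load('model.pth')

try:
    model.load_state_dict(state)  # strict=True (default): raises RuntimeError on mismatch
except RuntimeError as e:
    print("Strict loading failed:", e)

# strict=False loads the overlapping keys and reports the rest
result = model.load_state_dict(state, strict=False)
print("Missing keys:", result.missing_keys)        # ['fc3.weight', 'fc3.bias']
print("Unexpected keys:", result.unexpected_keys)  # []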
📦 Saving the Entire Model
# Save the entire model (including its structure)
torch.save(model, 'model_full.pth')

# Load the entire model
model = torch.load('model_full.pth')
model.eval()
🔴 Not recommended
This approach can cause problems in some situations:
- It may fail to load after the code structure changes
- It may be incompatible across PyTorch versions (see the sketch below)
- The file is larger
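One concrete example of the version sensitivity, stated as an assumption about your environment: since PyTorch 2.6, torch.load defaults to weights_only=True and refuses to unpickle arbitrary objects, so loading a fully pickled model needs an explicit flag:

# Assumes PyTorch >= 2.6, where torch.load defaults to weights_only=True
# and will not unpickle a full model object without this flag.
model = torch.load('model_full.pth', weights_only=False)
model.eval()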
🎯 Saving Training Checkpoints
During long training runs, checkpoints should be saved periodically so that training can resume after an interruption:
import torch
import torch.nn as nn
import torch.optim as optim

# Save a checkpoint
def save_checkpoint(model, optimizer, epoch, loss, path):
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
    }, path)
    print(f"Checkpoint saved: epoch {epoch}")

# Load a checkpoint
def load_checkpoint(path, model, optimizer):
    checkpoint = torch.load(path)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint['epoch']
    loss = checkpoint['loss']
    print(f"Checkpoint loaded: epoch {epoch}, loss {loss:.4f}")
    return epoch, loss

# Usage example
model = SimpleNet()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Save during training
for epoch in range(100):
    # ... training code (computes `loss`) ...
    # Save every 10 epochs
    if (epoch + 1) % 10 == 0:
        save_checkpoint(model, optimizer, epoch, loss, f'checkpoint_epoch{epoch+1}.pth')

# Resume training from a checkpoint
start_epoch, _ = load_checkpoint('checkpoint_epoch50.pth', model, optimizer)
for epoch in range(start_epoch + 1, 100):  # the saved epoch already finished, so resume from the next one
    # Continue training...
    pass
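A checkpoint can carry more than the model and optimizer. A minimal sketch, assuming a StepLR learning-rate scheduler is in use (adapt the keys to whatever your training loop actually tracks), that also captures the scheduler and the CPU RNG state so a resumed run follows the same schedule:

from torch.optim.lr_scheduler import StepLR

scheduler = StepLR(optimizer, step_size=30, gamma=0.1)

def save_full_checkpoint(path, model, optimizer, scheduler, epoch):
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'rng_state': torch.get_rng_state(),  # CPU RNG state, for reproducible resumption
    }, path)

def load_full_checkpoint(path, model, optimizer, scheduler):
    checkpoint = torch.load(path)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    torch.set_rng_state(checkpoint['rng_state'])
    return checkpoint['epoch']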
🔧 Saving the Best Model
# Assumes num_epochs, train_epoch(), validate(), the data loaders,
# criterion, and optimizer are defined elsewhere in the training script.
best_val_loss = float('inf')
best_model_path = 'best_model.pth'

for epoch in range(num_epochs):
    # Training...
    train_loss = train_epoch(model, train_loader, criterion, optimizer)
    # Validation...
    val_loss = validate(model, val_loader, criterion)
    # Save if this is the best model so far
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), best_model_path)
        print(f"Epoch {epoch}: saved best model (val_loss: {val_loss:.4f})")
🖥️ Loading Across Devices
Loading a GPU-trained model on the CPU
# Train and save on the GPU
device = torch.device('cuda')
model = SimpleNet().to(device)
# ... training ...
torch.save(model.state_dict(), 'model_gpu.pth')

# Load on the CPU
model = SimpleNet()
model.load_state_dict(torch.load('model_gpu.pth', map_location='cpu'))
Loading a CPU-trained model on the GPU
# Train and save on the CPU
model = SimpleNet()
# ... training ...
torch.save(model.state_dict(), 'model_cpu.pth')

# Load on the GPU
device = torch.device('cuda')
model = SimpleNet()
model.load_state_dict(torch.load('model_cpu.pth', map_location=device))
model.to(device)
A general pattern
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Adapt to the available device automatically when loading
model = SimpleNet()
model.load_state_dict(torch.load('model.pth', map_location=device))
model.to(device)
model.eval()
📋 Saving Extra Information
# Save the model together with other information.
# Assumes model, optimizer, epoch, loss, and accuracy come from the training loop.
save_dict = {
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'epoch': epoch,
    'loss': loss,
    'accuracy': accuracy,
    'hyperparameters': {
        'learning_rate': 0.001,
        'batch_size': 32,
        'hidden_size': 256,
    },
    'class_names': ['cat', 'dog', 'bird'],
}
torch.save(save_dict, 'model_with_info.pth')

# Load
checkpoint = torch.load('model_with_info.pth')
print(f"Trained epochs: {checkpoint['epoch']}")
print(f"Accuracy: {checkpoint['accuracy']}")
print(f"Classes: {checkpoint['class_names']}")
🎯 Complete Example
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

# Model
class Classifier(nn.Module):
    def __init__(self, input_size, hidden_size, num_classes):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        return self.fc2(x)

# Training function
def train(model, train_loader, criterion, optimizer, device):
    model.train()
    for data, labels in train_loader:
        data, labels = data.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(data)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

# Main program
def main():
    # Setup
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Simulated data
    X = torch.randn(500, 20)
    y = torch.randint(0, 3, (500,))
    dataset = TensorDataset(X, y)
    train_loader = DataLoader(dataset, batch_size=32, shuffle=True)

    # Model
    model = Classifier(20, 64, 3).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    # Training
    for epoch in range(10):
        train(model, train_loader, criterion, optimizer, device)
        print(f"Epoch {epoch+1} done")

    # Save
    torch.save({
        'model_state_dict': model.state_dict(),
        'input_size': 20,
        'hidden_size': 64,
        'num_classes': 3,
    }, 'classifier.pth')
    print("Model saved!")

# Load and use
def predict():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Load
    checkpoint = torch.load('classifier.pth', map_location=device)

    # Rebuild the model
    model = Classifier(
        checkpoint['input_size'],
        checkpoint['hidden_size'],
        checkpoint['num_classes']
    )
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)
    model.eval()

    # Predict
    x = torch.randn(5, 20).to(device)
    with torch.no_grad():
        outputs = model(x)
        predictions = outputs.argmax(dim=1)
        print(f"Predictions: {predictions.tolist()}")

if __name__ == '__main__':
    main()
    predict()
📝 Best Practice Summary
- Prefer state_dict: more flexible and safer
- Save checkpoints: save periodically during long training runs
- Save hyperparameters: makes results easier to reproduce
- Mind the device: use map_location to handle cross-device loading
- Set eval mode: don't forget model.eval() at inference time (a short demo follows below)
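Why model.eval() matters, in a minimal sketch: layers such as Dropout and BatchNorm behave differently in training and evaluation mode, so forgetting eval() makes inference noisy and non-deterministic.

import torch
import torch.nn as nn

layer = nn.Dropout(p=0.5)
x = torch.ones(1, 4)

layer.train()    # training mode: randomly zeroes elements and rescales the rest by 1/(1-p)
print(layer(x))  # e.g. tensor([[2., 0., 2., 0.]]) -- varies from run to run

layer.eval()     # evaluation mode: dropout becomes a no-op
print(layer(x))  # tensor([[1., 1., 1., 1.]])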
Next Steps
With the model saved, let's learn how to use the GPU to speed up training!