方法一:仅保存和加载模型参数(推荐)
这种方法只保存模型的参数(weights 和 biases),而不保存模型的结构。需要在加载模型时重新定义模型结构。
- 保存模型参数
import torch
# 假设 model 是已经训练好的模型实例
torch.save(model.state_dict(), 'model_parameters.pth')
- 加载模型参数
import torch
import torch.nn as nn
# 重新定义模型结构
class SimpleModel(nn.Module):
def __init__(self):
super(SimpleModel, self).__init__()
self.fc = nn.Linear(10, 1)
def forward(self, x):
return self.fc(x)
# 实例化模型
model = SimpleModel()
# 加载模型参数
model.load_state_dict(torch.load('model_parameters.pth'))
方法二:保存和加载整个模型
这种方法将模型的结构和参数一起保存和加载。它更简单,但不太灵活,因为你依赖于保存时的模型结构。
- 保存整个模型
# 假设 model 是已经训练好的模型实例
torch.save(model, 'full_model.pth')
- 加载整个模型
model = torch.load('full_model.pth')
两种方法对比
兼容性问题
方法一仅保存模型参数,只需要在加载时定义相同的模型结构,减少了环境和版本依赖性;而方法二会导致模型文件依赖于创建它的Python环境和PyTorch版本。如果在不同的环境或PyTorch版本中加载,可能会遇到兼容性问题。文件大小:
方法一仅保存参数,文件较小,更加高效。方法二由于保存了整个模型的结构和状态,文件通常会比仅保存参数的文件大。