[图像算法]-(yolov5.train)-Pytorch保存和加载模型完全指南: 关于使用Pytorch读写模型的一切方法

      本文是一篇关于如何用Pytorch保存和加载模型的指南


文章目录

  • 1.读写tensor
    • 1.1 单个张量
    • 1.2 张量列表和张量词典
  • 2.保存和加载模型
    • 2.1 state_dict
    • 2.2 保存和加载
      • 2.2.1 保存和加载state_dict(推荐方式)
      • 2.2.2 保存和读写整个模型
      • 2.2.3 保存和加载checkpiont
      • 2.2.4 在一个文件中保存多个模型
    • 2.3 使用来自不同模型的参数进行模型热启动
  • 3.跨设备保存和加载模型
    • 3.1 在GPU中保存,在CPU中加载
    • 3.2 在GPU中保存,在GPU中加载
    • 3.3 在CPU中保存,在GPU中加载
  • 4.保存torch.nn.DataParallel的模型

本文主要涉及到3个函数:

  • 1.torch.save: 使用Python的pickle实用程序将对象进行序列化,然后将序列化的对象保存到disk,可以保存各种对象,包括模型、张量和字典等。
  • 2.torch.load: 使用pickle unpickle工具将pickle的对象文件反序列化为内存。
  • 3.torch.nn.Module.load_state_dict: 用反序列化的state_dict来加载模型参数。

1.读写tensor

1.1单个张量

import torch

x = torch.tensor([3.,4.])
torch.save(x, 'x.pt')
x1 = torch.load('x.pt')
print(x1)

输出:

tensor([3., 4.])

1.2张量列表和张量词典

y = torch.ones((4,2))
torch.save([x,y],'xy.pt')
torch.save({'x':x, 'y':y}, 'xy_dict.pt')
xy = torch.load('xy.pt')
xy_dict = torch.load('xy_dict.pt')
print(xy)
print(xy_dict)

输出:

[tensor([3., 4.]), tensor([[1., 1.],
        [1., 1.],
        [1., 1.],
        [1., 1.]])]
{'x': tensor([3., 4.]), 'y': tensor([[1., 1.],
        [1., 1.],
        [1., 1.],
        [1., 1.]])}

2.保存和加载模型

2.1state_dict

state_dict是一个从每一个层的名称映射到这个层的参数Tesnor的字典对象。

注意,只有具有可学习参数的层(卷积层、线性层等)和注册缓存(batchnorm’s running_mean)才有state_dict中的条目。优化器(torch.optim)也有一个state_dict,其中包含关于优化器状态以及所使用的超参数的信息。

from torch import nn
class MLP(nn.Module):
    def __init__(self):
        super(MLP, self).__init__()
        self.hidden = nn.Linear(3, 2)
        self.act = nn.ReLU()
        self.output = nn.Linear(2, 1)

    def forward(self, x):
        a = self.act(self.hidden(x))
        return self.output(a)

net = MLP()
print(net.state_dict())
print('\n',net.state_dict()['output.weight'])

optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
print(optimizer.state_dict())

输出:

OrderedDict([('hidden.weight', tensor([[ 0.0405, -0.0659, -0.5540],
        [ 0.2954,  0.0676, -0.1933]])), ('hidden.bias', tensor([-0.1628,  0.0768])), ('output.weight', tensor([[-0.4635,  0.4958]])), ('output.bias', tensor([-0.5440]))])

 tensor([[-0.4635,  0.4958]])
{'state': {}, 'param_groups': [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [0, 1, 2, 3]}]}

2.2保存和加载

PyTorch中保存和加载训练模型有两种常见的方法:

    1. 仅保存和加载模型参数(state_dict);
    1. 保存和加载整个模型。
2.2.1保存和加载state_dict(推荐方式)
torch.save(net.state_dict(), 'net_state_dict.pt')## 后缀名一般写为: .pt或.pth
net1 = MLP()
net1.load_state_dict(torch.load('net_state_dict.pt'))
print(net1)

输出:

MLP(
  (hidden): Linear(in_features=3, out_features=2, bias=True)
  (act): ReLU()
  (output): Linear(in_features=2, out_features=1, bias=True)
)

注意: load_state_dict() 接受一个词典对象,而不是一个指向对象的路径。因此你需要先使用torch.load()来反序列化。比如,你不能直接这么用model.load_state_dict(PATH)。

2.2.2保存和读写整个模型

这个就相对来说比较简单了。

torch.save(net, 'net.pt')
net2 = torch.load('net.pt')
print(net2)

输出:

MLP(
  (hidden): Linear(in_features=3, out_features=2, bias=True)
  (act): ReLU()
  (output): Linear(in_features=2, out_features=1, bias=True)
)

注意:以这种方式保存模型将使用Python的pickle模块保存整个模块。 这种方法的缺点是序列化的数据被绑定到特定的类,并且在保存模型时使用了确切的词典结构。 这样做的原因是因为pickle不会保存模型类本身。 而是将其保存到包含这个类的文件的路径,该路径在加载时使用。 因此,在其他项目中使用或重构后,您的代码可能会以各种方式中断。

2.2.3保存和加载checkpiont
## Save
torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss,
            ...
            }, PATH)

###########################
## Load
model = TheModelClass(*args, **kwargs)
optimizer = TheOptimizerClass(*args, **kwargs)

checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

保存用于检查或继续训练的checkpiont时,你必须保存的不只是模型的state_dict。 保存优化器的state_dict也很重要,因为它包含随着模型训练而更新的缓冲区和参数。 你可能要保存的其他项目包括你未设置的时间段,最新记录的训练损失,外部torch.nn.Embedding层等。

2.2.4在一个文件中保存多个模型
#Save
torch.save({
            'modelA_state_dict': modelA.state_dict(),
            'modelB_state_dict': modelB.state_dict(),
            'optimizerA_state_dict': optimizerA.state_dict(),
            'optimizerB_state_dict': optimizerB.state_dict(),
            ...
            }, PATH)


#Load
modelA = TheModelAClass(*args, **kwargs)
modelB = TheModelBClass(*args, **kwargs)
optimizerA = TheOptimizerAClass(*args, **kwargs)
optimizerB = TheOptimizerBClass(*args, **kwargs)

checkpoint = torch.load(PATH)
modelA.load_state_dict(checkpoint['modelA_state_dict'])
modelB.load_state_dict(checkpoint['modelB_state_dict'])
optimizerA.load_state_dict(checkpoint['optimizerA_state_dict'])
optimizerB.load_state_dict(checkpoint['optimizerB_state_dict'])

常见的PyTorch约定是使用.tar文件扩展名保存这些检查点。

2.3使用来自不同模型的参数进行模型热启动

这种方法一般用于迁移学习。利用经过训练的参数,即使只有少数几个可用的参数,也将有助于热启动训练过程,并希望与从头开始训练相比,可以更快地收敛模型。

torch.save(modelA.state_dict(), PATH)

modelB = TheModelBClass(*args, **kwargs)
modelB.load_state_dict(torch.load(PATH), strict=False)

无论是从缺少某些键的部分state_dict加载,还是加载比要加载的模型更多的keystate_dict,都可以在load_state_dict()函数中将strict参数设置为False,以忽略不匹配键。

如果你想要将一个层的参数加载到另一个层,但是一些keys不匹配,你只需改变你所加载的state_dict中的名称即可。

3.跨设备保存和加载模型

3.1在GPU中保存,在CPU中加载

torch.save(model.state_dict(), PATH)

device = torch.device('cpu')
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location=device))

3.2在GPU中保存,在GPU中加载

torch.save(model.state_dict(), PATH)

device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.to(device)

3.3在CPU中保存,在GPU中加载

torch.save(model.state_dict(), PATH)

device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location="cuda:0"))  # Choose whatever GPU device number you want
model.to(device)

4.保存torch.nn.DataParallel的模型

torch.save(model.module.state_dict(), PATH)

# Load to whatever device you want,加载方法使用常规方式即可。

参考链接:

  1. 官方文档
  2. Dive-into-DL-PyTorch
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 203,179评论 5 476
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 85,229评论 2 380
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 150,032评论 0 336
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 54,533评论 1 273
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 63,531评论 5 365
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 48,539评论 1 281
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 37,916评论 3 395
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 36,574评论 0 256
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 40,813评论 1 296
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 35,568评论 2 320
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 37,654评论 1 329
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 33,354评论 4 318
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 38,937评论 3 307
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 29,918评论 0 19
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 31,152评论 1 259
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 42,852评论 2 349
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 42,378评论 2 342

推荐阅读更多精彩内容