pytorch 状态字典:state_dict使用详解

 更新时间:2020年01月17日 17:12:28   作者:wzg2016  
今天小编就为大家分享一篇pytorch 状态字典:state_dict使用详解,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧

pytorch 中的 state_dict 是一个简单的python的字典对象,将每一层与它的对应参数建立映射关系.(如model的每一层的weights及偏置等等)

(注意,只有那些参数可以训练的layer才会被保存到模型的state_dict中,如卷积层,线性层等等)

优化器对象Optimizer也有一个state_dict,它包含了优化器的状态以及被使用的超参数(如lr, momentum,weight_decay等)

备注:

1) state_dict是在定义了model或optimizer之后pytorch自动生成的,可以直接调用.常用的保存state_dict的格式是".pt"或'.pth'的文件,即下面命令的 PATH="./***.pt"

torch.save(model.state_dict(), PATH)

2) load_state_dict 也是model或optimizer之后pytorch自动具备的函数,可以直接调用

model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()

注意:model.eval() 的重要性,在2)中最后用到了model.eval(),是因为,只有在执行该命令后,"dropout层"及"batch normalization层"才会进入 evalution 模态. 而在"训练(training)模态"与"评估(evalution)模态"下,这两层有不同的表现形式.

模态字典(state_dict)的保存(model是一个网络结构类的对象)

1.1)仅保存学习到的参数,用以下命令

 torch.save(model.state_dict(), PATH)

1.2)加载model.state_dict,用以下命令

 model = TheModelClass(*args, **kwargs)
 model.load_state_dict(torch.load(PATH))
 model.eval()

备注:model.load_state_dict的操作对象是 一个具体的对象,而不能是文件名

2.1)保存整个model的状态,用以下命令

torch.save(model,PATH)

2.2)加载整个model的状态,用以下命令:

   # Model class must be defined somewhere

 model = torch.load(PATH)

 model.eval()

state_dict 是一个python的字典格式,以字典的格式存储,然后以字典的格式被加载,而且只加载key匹配的项

如何仅加载某一层的训练的到的参数(某一层的state)

If you want to load parameters from one layer to another, but some keys do not match, simply change the name of the parameter keys in the state_dict that you are loading to match the keys in the model that you are loading into.

conv1_weight_state = torch.load('./model_state_dict.pt')['conv1.weight']

加载模型参数后,如何设置某层某参数的"是否需要训练"(param.requires_grad)

for param in list(model.pretrained.parameters()):
 param.requires_grad = False

注意: requires_grad的操作对象是tensor.

疑问:能否直接对某个层直接之用requires_grad呢?例如:model.conv1.requires_grad=False

回答:经测试,不可以.model.conv1 没有requires_grad属性.

全部测试代码:

#-*-coding:utf-8-*-
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
 
 
 
# define model
class TheModelClass(nn.Module):
 def __init__(self):
  super(TheModelClass,self).__init__()
  self.conv1 = nn.Conv2d(3,6,5)
  self.pool = nn.MaxPool2d(2,2)
  self.conv2 = nn.Conv2d(6,16,5)
  self.fc1 = nn.Linear(16*5*5,120)
  self.fc2 = nn.Linear(120,84)
  self.fc3 = nn.Linear(84,10)
 
 def forward(self,x):
  x = self.pool(F.relu(self.conv1(x)))
  x = self.pool(F.relu(self.conv2(x)))
  x = x.view(-1,16*5*5)
  x = F.relu(self.fc1(x))
  x = F.relu(self.fc2(x))
  x = self.fc3(x)
  return x
 
# initial model
model = TheModelClass()
 
#initialize the optimizer
optimizer = optim.SGD(model.parameters(),lr=0.001,momentum=0.9)
 
# print the model's state_dict
print("model's state_dict:")
for param_tensor in model.state_dict():
 print(param_tensor,'\t',model.state_dict()[param_tensor].size())
 
print("\noptimizer's state_dict")
for var_name in optimizer.state_dict():
 print(var_name,'\t',optimizer.state_dict()[var_name])
 
print("\nprint particular param")
print('\n',model.conv1.weight.size())
print('\n',model.conv1.weight)
 
print("------------------------------------")
torch.save(model.state_dict(),'./model_state_dict.pt')
# model_2 = TheModelClass()
# model_2.load_state_dict(torch.load('./model_state_dict'))
# model.eval()
# print('\n',model_2.conv1.weight)
# print((model_2.conv1.weight == model.conv1.weight).size())
## 仅仅加载某一层的参数
conv1_weight_state = torch.load('./model_state_dict.pt')['conv1.weight']
print(conv1_weight_state==model.conv1.weight)
 
model_2 = TheModelClass()
model_2.load_state_dict(torch.load('./model_state_dict.pt'))
model_2.conv1.requires_grad=False
print(model_2.conv1.requires_grad)
print(model_2.conv1.bias.requires_grad)

以上这篇pytorch 状态字典:state_dict使用详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • 教你python制作自己的模块的基本步骤

    教你python制作自己的模块的基本步骤

    这篇文章主要介绍了python如何制作自己的模块,本文通过实例代码给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下
    2023-08-08
  • Python实现的人工神经网络算法示例【基于反向传播算法】

    Python实现的人工神经网络算法示例【基于反向传播算法】

    这篇文章主要介绍了Python实现的人工神经网络算法,结合实例形式分析了Python基于反向传播算法实现的人工神经网络相关操作技巧,需要的朋友可以参考下
    2017-11-11
  • Python利用treap实现双索引的方法

    Python利用treap实现双索引的方法

    所遍历的元素一定是递增(小堆)或是递减(大堆)关系,但是我们无法得知左子树与右子树两部分节点的排序关系。本文就来讲讲算法和数据结构共同满足一组特性,感兴趣的小伙伴请参考下面文章的内容
    2021-09-09
  • python基础入门之列表(一)

    python基础入门之列表(一)

    在Python中,列表(list)是常用的数据类型。列表由一系列按照特定顺序排列的项(item)组成。
    2021-06-06
  • python使用redis模块来跟redis实现交互

    python使用redis模块来跟redis实现交互

    这篇文章主要介绍了python使用redis模块来跟redis实现交互,文章围绕主题展开详细的内容介绍,具有一定的参考价值,需要的小伙伴可以参考一下
    2022-06-06
  • python判断、获取一张图片主色调的2个实例

    python判断、获取一张图片主色调的2个实例

    一幅图片,想通过程序判断获得其主要色调,应该怎么样处理?本文通过python实现判断、获取一张图片的主色调方法,需要的朋友可以参考下
    2014-04-04
  • 详解Python Flask API 示例演示(附cookies和session)

    详解Python Flask API 示例演示(附cookies和session)

    这篇文章主要为大家介绍了Python Flask API 示例演示(附cookies和session)详解,有需要的朋友可以借鉴参考下,希望能够有所帮助,祝大家多多进步,早日升职加薪
    2023-03-03
  • 利用Python将时间或时间间隔转为ISO 8601格式方法示例

    利用Python将时间或时间间隔转为ISO 8601格式方法示例

    国际标准化组织的国际标准ISO8601是日期和时间的表示方法,全称为《数据存储和交换形式·信息交换·日期和时间的表示方法》,下面这篇文章主要给大家介绍了关于利用Python将时间或时间间隔转为ISO 8601格式的相关资料,需要的朋友可以参考下。
    2017-09-09
  • python计算阶乘和的方法(1!+2!+3!+...+n!)

    python计算阶乘和的方法(1!+2!+3!+...+n!)

    今天小编就为大家分享一篇python计算阶乘和的方法(1!+2!+3!+...+n!),具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2019-02-02
  • Python使用内置函数setattr设置对象的属性值

    Python使用内置函数setattr设置对象的属性值

    这篇文章主要介绍了Python使用内置函数setattr设置对象的属性值,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2020-10-10

最新评论