PyTorch深度学习模型的保存和加载流程详解

 更新时间:2021年10月21日 09:32:00   作者:软耳朵DONG  
PyTorch是一个开源的Python机器学习库,基于Torch,用于自然语言处理等应用程序。2017年1月,由Facebook人工智能研究院(FAIR)基于Torch推出了PyTorch,这篇文章主要介绍了PyTorch模型的保存和加载流程

一、模型参数的保存和加载

  •  torch.save(module.state_dict(), path):使用module.state_dict()函数获取各层已经训练好的参数和缓冲区,然后将参数和缓冲区保存到path所指定的文件存放路径(常用文件格式为.pt.pth.pkl)。
  • torch.nn.Module.load_state_dict(state_dict):从state_dict中加载参数和缓冲区到Module及其子类中 。
  • torch.nn.Module.state_dict()函数返回python中的一个OrderedDict类型字典对象,该对象将每一层与它的对应参数和缓冲区建立映射关系,字典的键值是参数或缓冲区的名称。只有那些参数可以训练的层才会被保存到OrderedDict中,例如:卷积层、线性层等。
  • Python中的字典类以“键:值”方式存取数据,OrderedDict是它的一个子类,实现了对字典对象中元素的排序(OrderedDict根据放入元素的先后顺序进行排序)。由于进行了排序,所以顺序不同的两个OrderedDict字典对象会被当做是两个不同的对象。
  • 示例:
import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 2, 3)
        self.pool1 = nn.MaxPool2d(2, 2)

    def forward(self, x):
        x = self.conv1(x)
        x = self.pool1(x)
        return x

# 初始化网络
net = Net()
net.conv1.weight[0].detach().fill_(1)
net.conv1.weight[1].detach().fill_(2)
net.conv1.bias.data.detach().zero_()
# 获取state_dict
state_dict = net.state_dict()
# 字典的遍历默认是遍历key,所以param_tensor实际上是键值
for param_tensor in state_dict: 
    print(param_tensor,':\n',state_dict[param_tensor])
# 保存模型参数
torch.save(state_dict,"net_params.pth")
# 通过加载state_dict获取模型参数
net.load_state_dict(state_dict)

输出:

在这里插入图片描述

二、完整模型的保存和加载

  •  torch.save(module, path):将训练完的整个网络模型module保存到path所指定的文件存放路径(常用文件格式为.pt.pth)。
  • torch.load(path):加载保存到path中的整个神经网络模型。
  • 示例:
import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 2, 3)
        self.pool1 = nn.MaxPool2d(2, 2)

    def forward(self, x):
        x = self.conv1(x)
        x = self.pool1(x)
        return x

# 初始化网络
net = Net()
net.conv1.weight[0].detach().fill_(1)
net.conv1.weight[1].detach().fill_(2)
net.conv1.bias.data.detach().zero_()
# 保存整个网络
torch.save(net,"net.pth")
# 加载网络
net = torch.load("net.pth")

到此这篇关于PyTorch深度学习模型的保存和加载流程详解的文章就介绍到这了,更多相关PyTorch 模型的保存 内容请搜索脚本之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持脚本之家!

相关文章

  • Python邮箱API发送邮件的方法和步骤

    Python邮箱API发送邮件的方法和步骤

    Python是一种功能强大的编程语言,可以用来发送电子邮件,使用Python发送邮件可以通过邮箱API来实现,aoksend将介绍使用Python邮箱API发送邮件的方法和步骤,需要的朋友可以参考下
    2024-04-04
  • Python web框架fastapi中间件的使用及CORS跨域问题

    Python web框架fastapi中间件的使用及CORS跨域问题

    fastapi "中间件"是一个函数,它在每个请求被特定的路径操作处理之前,以及在每个响应之后工作,它接收你的应用程序的每一个请求,下面通过本文给大家介绍Python web框架fastapi中间件的使用及CORS跨域问题,感兴趣的朋友一起看看吧
    2024-03-03
  • 用Python写一段用户登录的程序代码

    用Python写一段用户登录的程序代码

    下面小编就为大家分享一篇用Python写一段用户登录的程序代码,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-04-04
  • pycharm 实现光标快速移动到括号外或行尾的操作

    pycharm 实现光标快速移动到括号外或行尾的操作

    这篇文章主要介绍了pycharm 实现光标快速移动到括号外或行尾的操作,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2021-02-02
  • python pandas loc 布尔索引示例说明

    python pandas loc 布尔索引示例说明

    loc跟iloc的区别,首先loc是location的意思,和iloc中i的意思是指integer,所以它只接受整数作为参数,详情见下面
    2022-03-03
  • Python爬虫:url中带字典列表参数的编码转换方法

    Python爬虫:url中带字典列表参数的编码转换方法

    今天小编就为大家分享一篇Python爬虫:url中带字典列表参数的编码转换方法,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2019-08-08
  • 关于python多进程中的常用方法详解

    关于python多进程中的常用方法详解

    这篇文章主要介绍了关于python多进程中的常用方法详解,python中的多线程其实并不是真正的多线程,如果想要充分地使用多核CPU资源,在python中大部分情况需要使用多进程,需要的朋友可以参考下
    2023-07-07
  • Python尾递归优化实现代码及原理详解

    Python尾递归优化实现代码及原理详解

    这篇文章主要介绍了Python尾递归优化实现代码及原理详解,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2020-10-10
  • 常见Python AutoEDA工具库及功能使用探究

    常见Python AutoEDA工具库及功能使用探究

    AutoEDA(自动探索性数据分析)工具库是数据科学中至关重要的一部分,它们能够自动生成数据摘要、探查数据的基本特征、检测异常值和提供可视化,为数据科学家和分析师们提供了解数据的便捷方式,本文为大家介绍常见的AutoEDA工具库及其功能和示例代码
    2024-01-01
  • 使用Python VTK 完成图像切割

    使用Python VTK 完成图像切割

    这篇文章主要介绍了使用Python VTK 完成图像切割,文章内容基于python的相关资料展开对主题的详细介绍,具有一定的参考价值,需要的小伙伴可以参考一下
    2022-04-04

最新评论