Python中torch.load()加载模型以及其map_location参数详解

 更新时间:2022年09月23日 09:55:41   作者:eecspan  
torch.load()作用用来加载torch.save()保存的模型文件,下面这篇文章主要给大家介绍了关于Python中torch.load()加载模型以及其map_location参数的相关资料,需要的朋友可以参考下

参考

TORCH.LOAD

torch.load()

函数格式为:torch.load(f, map_location=None, pickle_module=pickle, **pickle_load_args),一般我们使用的时候,基本只使用前两个参数。

模型的保存

模型保存有两种形式,一种是保存模型的state_dict(),只是保存模型的参数。那么加载时需要先创建一个模型的实例model,之后通过torch.load()将保存的模型参数加载进来,得到dict,再通过model.load_state_dict(dict)将模型的参数更新。

另一种是将整个模型保存下来,之后加载的时候只需要通过torch.load()将模型加载,即可返回一个加载好的模型。

具体可参考:PyTorch模型的保存与加载

模型加载中的map_location参数

具体来说,map_location参数是用于重定向,比如此前模型的参数是在cpu中的,我们希望将其加载到cuda:0中。或者我们有多张卡,那么我们就可以将卡1中训练好的模型加载到卡2中,这在数据并行的分布式深度学习中可能会用到。

首先定义一个AlexNet,并使用cuda:0将其训练了一个猫狗分类,之后把模型存储起来。

map_location=None

我们先把state_dict加载进来。

model_path = "./cuda_model.pth"
model = torch.load(model_path)
print(next(model.parameters()).device)

结果为:

cuda:0

因为保存的时候就是模型就是cuda:0的,所以加载进来也是。

map_location=torch.device()

model_path = "./cuda_model.pth"
model = torch.load(model_path, map_location=torch.device('cpu'))
print(next(model.parameters()).device)

结果为:

cpu

模型从cuda:0变成了cpu

map_location={xx:xx}

model_path = "./cuda_model.pth"
model = torch.load(model_path, map_location={'cuda:0':'cuda:1'})
print(next(model.parameters()).device)

结果为:

cuda:1

模型从cuda:0变成了cuda:1

model_path = "./cuda_model.pth"
model = torch.load(model_path, map_location={'cuda:2':'cpu'})
print(next(model.parameters()).device)

结果为:

cuda:0

模型还是cuda:0,并没有变成cpu。因为这个map_location的映射是不对的,原始的模型就是cuda:0,而映射是cuda:2cpu,是不对的。这种情况下,map_location返回None,也就是和不加map_location相同。

总结

到此这篇关于Python中torch.load()加载模型以及其map_location参数详解的文章就介绍到这了,更多相关torch.load()加载模型map_location参数内容请搜索脚本之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持脚本之家!

相关文章

  • Django使用原生SQL查询数据库详解

    Django使用原生SQL查询数据库详解

    本文介绍了Django ORM的优缺点,然后介绍了使用原生SQL进行查询的优点,包括更灵活、更高效等。接着介绍了如何在Django中使用原生SQL进行查询,包括利用Django的connection对象进行查询以及使用Django的CursorWrapper类进行封装。最后提醒了使用原生SQL查询的注意事项。
    2023-04-04
  • Pandas实现groupby分组统计方法实例

    Pandas实现groupby分组统计方法实例

    在数据处理的过程,有可能需要对一堆数据分组处理,例如对不同的列进行agg聚合操作(mean,min,max等等),下面这篇文章主要给大家介绍了关于Pandas实现groupby分组统计方法的相关资料,需要的朋友可以参考下
    2023-06-06
  • python Matplotlib数据可视化(1):简单入门

    python Matplotlib数据可视化(1):简单入门

    这篇文章主要介绍了python Matplotlib的相关资料,帮助大家入门matplotlib,绘制各种图表,感兴趣的朋友可以了解下
    2020-09-09
  • Python黑魔法之metaclass详情

    Python黑魔法之metaclass详情

    Python 有很多黑魔法,为了不分你的心,今天只讲 metaclass。对于 metaclass 这种特性,有两种极端的观点:下面小编将为大家详细的介绍,刚兴趣的小伙伴可以参考一下
    2021-09-09
  • Matplotlib使用字符串代替变量绘制散点图的方法

    Matplotlib使用字符串代替变量绘制散点图的方法

    这篇文章主要介绍了Matplotlib使用字符串代替变量绘制散点图的方法,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2020-02-02
  • Python django搭建layui提交表单,表格,图标的实例

    Python django搭建layui提交表单,表格,图标的实例

    今天小编就为大家分享一篇Python django搭建layui提交表单,表格,图标的实例,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2019-11-11
  • Python的网络编程库Gevent的安装及使用技巧

    Python的网络编程库Gevent的安装及使用技巧

    Gevent库的奥义在于并发式的高性能网络程序设计支持,这里我们将来讲解Python的网络编程库Gevent的安装及使用技巧,来看一下Gevent支持的多进程程序编写:
    2016-06-06
  • Python之——生成动态路由轨迹图的实例

    Python之——生成动态路由轨迹图的实例

    今天小编就为大家分享一篇Python之——生成动态路由轨迹图的实例,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2019-11-11
  • Python 多线程爬取案例

    Python 多线程爬取案例

    这篇文章主要介绍了Python 多线程爬取案例,爬虫属于I/O密集型的程序,所以使用多线程可以大大提高爬取效率,文章围绕主题展开详细的内容介绍,具有一定的参考价值,需要的小伙伴可以参考一下
    2022-08-08
  • python多线程爬取西刺代理的示例代码

    python多线程爬取西刺代理的示例代码

    这篇文章主要介绍了python多线程爬取西刺代理的示例代码,帮助大家更好的理解和学习python的爬虫,感兴趣的朋友可以了解下
    2021-01-01

最新评论