pytorch如何保存训练模型参数并实现继续训练

 更新时间:2023年09月11日 14:56:36   作者:回炉重造P  
这篇文章主要介绍了pytorch如何保存训练模型参数并实现继续训练问题,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教

最近的想法是在推荐模型中考虑根据用户对推荐结果的后续选择,利用已训练的offline预训练模型参数来更新新的结果。

简单记录一下中途保存参数和后续使用不同数据训练的方法。

简单模型和训练数据

先准备一个简单模型,简单两层linear出个分类结果。

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(64, 32)
        self.linear1 = nn.Linear(32, 10)
        self.relu = nn.ReLU()
    def forward(self, x):
        x = self.linear(x)
        x = self.relu(x)
        x = self.linear1(x)
        return x

准备训练用数据,这里直接随机两份,同时给出配套的十个类的分类label结果。

要注意的是 crossEntropy 交叉熵只认 long 以上的tensor,label记得转一下类型。

    rand1 = torch.rand((100, 64)).to(torch.float)
    label1 = np.random.randint(0, 10, size=100)
    label1 = torch.from_numpy(label1).to(torch.long)
    rand2 = torch.rand((100, 64)).to(torch.float)
    label2 = np.random.randint(0, 10, size=100)
    label2 = torch.from_numpy(label2).to(torch.long)

训练简单使用交叉熵,优化器Adam。

    model = MyModel()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    loss = nn.CrossEntropyLoss()
    iteration = 100
    for i in range(iteration):
        output = model(rand1)
        my_loss = loss(output, label1)
        optimizer.zero_grad()
        my_loss.backward()
        optimizer.step()
        print("iteration:{} loss:{}".format(i, my_loss))

反正能跑起来:

在这里插入图片描述

保存与读取训练参数结果的方法

关键的保存方法,可以分为两种,一种是直接把模型整体保存:

torch.save(model, save_path)

两个参数,模型和保存目录。不过这种不常用,如果模型变化或者只需要其中一部分参数就不太灵活。

常用方法的是将需要的模型或优化器参数取出以字典形式存储,这样可以在使用时初始化相关模型,读入对应参数即可。

def save_model(save_path, iteration, optimizer, model):
    torch.save({'iteration': iteration,
                'optimizer_dict': optimizer.state_dict(),
                'model_dict': model.state_dict()},
                save_path)
    print("model save success")

分别存储训练循环次数,优化器设置和模型参数结果。

初始化模型,读取参数并设置:

def load_model(save_name, optimizer, model):
    model_data = torch.load(save_name)
    model.load_state_dict(model_data['model_dict'])
    optimizer.load_state_dict(model_data['optimizer_dict'])
    print("model load success")

初始化新模型:

    path = "net.dict"
    save_model(path, iteration, optimizer, model)
    print(model.state_dict()['linear.weight'])
    new_model = MyModel()
    new_optimizer = torch.optim.Adam(new_model.parameters(), lr=0.01)
    load_model(path, new_optimizer, new_model)
    print(new_model.state_dict()['linear.weight'])

输出第一个linear层的参数看看,确实相同,参数成功读取上了。注意optimizer的初始化对应模型别写错了。

在这里插入图片描述

之后用新模型继续训练试试:

    for i in range(iteration):
        output = new_model(rand2)
        my_loss = loss(output, label2)
        new_optimizer.zero_grad()
        my_loss.backward()
        new_optimizer.step()
        print("iteration:{} loss:{}".format(i, my_loss))

能成功训练。

在这里插入图片描述

变化学习率的保存

上面的demo只用了固定的学习率来做实验。

如果使用了 scheduler 来变化步长,只要保存 scheduler state_dict ,之后对新初始化的 scheduler 设置对应的当前循环次数即可。

# 存储时
'scheduler': scheduler.state_dict()
# 读取时
scheduler.load_state_dict(checkpoint['lr_schedule'])

scheduler的使用可以看看我之前整理的文章:利用scheduler实现learning-rate学习率动态变化

总结

这次主要是整理了一下pytorch模型参数的整体保存方法,来实现新数据的后续训练或直接作为offline预训练模型来使用。

不过后续数据分布不同的话感觉效果会很差啊…

也不知道能不能用什么算法修改下权重来贴合新的数据,找点多次训练优化论文看看好了。

以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • python pyenv多版本管理工具的使用

    python pyenv多版本管理工具的使用

    这篇文章主要介绍了python pyenv多版本管理工具的使用,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2019-12-12
  • Python使用Redis实现作业调度系统(超简单)

    Python使用Redis实现作业调度系统(超简单)

    Redis作为内存数据库的一个典型代表,已经在很多应用场景中被使用,这里仅就Redis的pub/sub功能来说说怎样通过此功能来实现一个简单的作业调度系统。这里只是想展现一个简单的想法,所以还是有很多需要考虑的东西没有包括在这个例子中,比如错误处理,持久化等
    2016-03-03
  • 一文了解conda虚拟环境的使用及常见问题

    一文了解conda虚拟环境的使用及常见问题

    管理不同项目的依赖关系是一个常见而棘手的问题,本文主要介绍了一文了解conda虚拟环境的使用及常见问题,具有一定的参考价值,感兴趣的可以了解一下
    2024-03-03
  • 使用python为mysql实现restful接口

    使用python为mysql实现restful接口

    这篇文章主要介绍了使用python为mysql实现restful接口的相关资料,需要的朋友可以参考下
    2018-01-01
  • python web.py开发httpserver解决跨域问题实例解析

    python web.py开发httpserver解决跨域问题实例解析

    这篇文章主要介绍了python web.py开发httpserver解决跨域问题实例解析,分享了相关代码示例,小编觉得还是挺不错的,具有一定借鉴价值,需要的朋友可以参考下
    2018-02-02
  • python的paramiko模块基本用法详解

    python的paramiko模块基本用法详解

    paramiko 是一个用于在Python中执行远程操作的模块,支持SSH协议,它可以用于连接到远程服务器,执行命令、上传和下载文件,以及在远程服务器上执行各种操作,这篇文章主要介绍了python的paramiko模块基本用法,需要的朋友可以参考下
    2023-08-08
  • Python爬虫工程师面试问题总结

    Python爬虫工程师面试问题总结

    本篇文章给大家总结了关于Python爬虫工程师面试问题总结,希望我们整理的内容能够帮助到大家。
    2018-03-03
  • Python利用pandas和matplotlib实现绘制柱状折线图

    Python利用pandas和matplotlib实现绘制柱状折线图

    这篇文章主要为大家详细介绍了如何使用 Python 中的 Pandas 和 Matplotlib 库创建一个柱状图与折线图结合的数据可视化图表,感兴趣的可以了解一下
    2023-11-11
  • Scrapy框架使用的基本知识

    Scrapy框架使用的基本知识

    今天小编就为大家分享一篇关于Scrapy框架使用的基本知识,小编觉得内容挺不错的,现在分享给大家,具有很好的参考价值,需要的朋友一起跟随小编来看看吧
    2018-10-10
  • Python利用scapy实现ARP欺骗的方法

    Python利用scapy实现ARP欺骗的方法

    今天小编就为大家分享一篇Python利用scapy实现ARP欺骗的方法,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2019-07-07

最新评论