Pytorch如何加载部分权重
更新时间:2023年09月15日 10:14:16 作者:Mr_寒路
这篇文章主要介绍了Pytorch如何加载部分权重问题,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教
1.修改网络层输出
比如在人脸检测项目中,已经训练好人脸框的回归,但是此时需要再加入人脸关键点。
为了节约大量时间,我们可以加载部分权重。加载的网络权重
if os.path.exists(self.load_params): pretext_model = torch.load(self.load_params)
打印出来,会看到网络权重存储在一个字典中,需要修改哪一层,用字典的键索引值进行修改。
比如原本输出层为4,我将网络输出层修改为14,又由于输出的都是坐标值,属于同一分布,所以我将原参4复制扩充为了14,效果非常好。
w = pretext_model["fc2.weight"] b = pretext_model["fc2.bias"] pretext_model["fc2.weight"] = torch.cat((w,w,w,w[:2]),dim=0) pretext_model["fc2.bias"] = torch.cat((b,b,b,b[:2]),dim=0)
最后加载修改后的参数
self.net.load_state_dict(pretext_model)
2.删除或增加了网络层
查看模型的参数,也是存放在一个字典中
if os.path.exists(self.load_params): pretext_model = torch.load(self.load_params) #加载的参数 model_dict = net.state_dict() #模型参数 print(model_dict) print(pretext_model)
#如果模型有k层,就加载 state_dict = {k: v for k, v in pretext_model.items() if k in model_dict.keys()} model_dict.update(state_dict) net.load_state_dict(model_dict)
3.迁移学习
有时我们也会用别人的模型,加载与训练参数,但是需要对输出层做一些修改,一般有两种方法,直接修改输出层个数或增加网络层
修改输出层个数
net = models.vgg19(pretrained=True) #下载与训练参数 print(net) #查看网络结构 net.classifier[6] = torch.nn.Linear(4096,10) #将输出层修改为10分类
增加输出网络层
num_fc_ftr = net.classifier[6] net.fc = nn.Linear(num_fc_ftr, 128) net.out = nn.Linear(128, 10)
总结
以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。
相关文章
示例详解Python3 or Python2 两者之间的差异
这篇文章主要介绍了Python3 or Python2?示例详解两者之间的差异,在本文中给大家介绍的非常详细,需要的朋友可以参考下2018-08-08itchat和matplotlib的结合使用爬取微信信息的实例
下面小编就为大家带来一篇itchat和matplotlib的结合使用爬取微信信息的实例。小编觉得挺不错的,现在就分享给大家,也给大家做个参考。一起跟随小编过来看看吧2017-08-08
最新评论