pytorch fine-tune 预训练的模型操作
之一:
torchvision 中包含了很多预训练好的模型,这样就使得 fine-tune 非常容易。本文主要介绍如何 fine-tune torchvision 中预训练好的模型。
安装
pip install torchvision
如何 fine-tune
以 resnet18 为例:
from torchvision import models from torch import nn from torch import optim resnet_model = models.resnet18(pretrained=True) # pretrained 设置为 True,会自动下载模型 所对应权重,并加载到模型中 # 也可以自己下载 权重,然后 load 到 模型中,源码中有 权重的地址。 # 假设 我们的 分类任务只需要 分 100 类,那么我们应该做的是 # 1. 查看 resnet 的源码 # 2. 看最后一层的 名字是啥 (在 resnet 里是 self.fc = nn.Linear(512 * block.expansion, num_classes)) # 3. 在外面替换掉这个层 resnet_model.fc= nn.Linear(in_features=..., out_features=100) # 这样就 哦了,修改后的模型除了输出层的参数是 随机初始化的,其他层都是用预训练的参数初始化的。 # 如果只想训练 最后一层的话,应该做的是: # 1. 将其它层的参数 requires_grad 设置为 False # 2. 构建一个 optimizer, optimizer 管理的参数只有最后一层的参数 # 3. 然后 backward, step 就可以了 # 这一步可以节省大量的时间,因为多数的参数不需要计算梯度 for para in list(resnet_model.parameters())[:-2]: para.requires_grad=False optimizer = optim.SGD(params=[resnet_model.fc.weight, resnet_model.fc.bias], lr=1e-3) ...
为什么
这里介绍下 运行resnet_model.fc= nn.Linear(in_features=..., out_features=100)时 框架内发生了什么
这时应该看 nn.Module 源码的 __setattr__ 部分,因为 setattr 时都会调用这个方法:
def __setattr__(self, name, value): def remove_from(*dicts): for d in dicts: if name in d: del d[name]
首先映入眼帘就是 remove_from 这个函数,这个函数的目的就是,如果出现了 同名的属性,就将旧的属性移除。 用刚才举的例子就是:
预训练的模型中 有个 名字叫fc 的 Module。
在类定义外,我们 将另一个 Module 重新 赋值给了 fc。
类定义内的 fc 对应的 Module 就会从 模型中 删除。
之二:
前言
这篇文章算是论坛PyTorch Forums关于参数初始化和finetune的总结,也是我在写代码中用的算是“最佳实践”吧。最后希望大家没事多逛逛论坛,有很多高质量的回答。
参数初始化
参数的初始化其实就是对参数赋值。而我们需要学习的参数其实都是Variable,它其实是对Tensor的封装,同时提供了data,grad等借口,这就意味着我们可以直接对这些参数进行操作赋值了。这就是PyTorch简洁高效所在。
所以我们可以进行如下操作进行初始化,当然其实有其他的方法,但是这种方法是PyTorch作者所推崇的:
def weight_init(m): # 使用isinstance来判断m属于什么类型 if isinstance(m, nn.Conv2d): n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels m.weight.data.normal_(0, math.sqrt(2. / n)) elif isinstance(m, nn.BatchNorm2d): # m中的weight,bias其实都是Variable,为了能学习参数以及后向传播 m.weight.data.fill_(1) m.bias.data.zero_()
Finetune
往往在加载了预训练模型的参数之后,我们需要finetune模型,可以使用不同的方式finetune。
局部微调
有时候我们加载了训练模型后,只想调节最后的几层,其他层不训练。其实不训练也就意味着不进行梯度计算,PyTorch中提供的requires_grad使得对训练的控制变得非常简单。
model = torchvision.models.resnet18(pretrained=True) for param in model.parameters(): param.requires_grad = False # 替换最后的全连接层, 改为训练100类 # 新构造的模块的参数默认requires_grad为True model.fc = nn.Linear(512, 100) # 只优化最后的分类层 optimizer = optim.SGD(model.fc.parameters(), lr=1e-2, momentum=0.9)
全局微调
有时候我们需要对全局都进行finetune,只不过我们希望改换过的层和其他层的学习速率不一样,这时候我们可以把其他层和新层在optimizer中单独赋予不同的学习速率。比如:
ignored_params = list(map(id, model.fc.parameters())) base_params = filter(lambda p: id(p) not in ignored_params, model.parameters()) optimizer = torch.optim.SGD([ {'params': base_params}, {'params': model.fc.parameters(), 'lr': 1e-3} ], lr=1e-2, momentum=0.9)
其中base_params使用1e-3来训练,model.fc.parameters使用1e-2来训练,momentum是二者共有的。
之三:
pytorch finetune模型
文章主要讲述如何在pytorch上读取以往训练的模型参数,在模型的名字已经变更的情况下又如何读取模型的部分参数等。
pytorch 模型的存储与读取
其中在模型的保存过程有存储模型和参数一起的也有单独存储模型参数的
单独存储模型参数
存储时使用:
torch.save(the_model.state_dict(), PATH)
读取时:
the_model = TheModelClass(*args, **kwargs) the_model.load_state_dict(torch.load(PATH))
存储模型与参数
存储:
torch.save(the_model, PATH)
读取:
the_model = torch.load(PATH)
模型的参数
fine-tune的过程是读取原有模型的参数,但是由于模型的所要处理的数据集不同,最后的一层class的总数不同,所以需要修改模型的最后一层,这样模型读取的参数,和在大数据集上训练好下载的模型参数在形式上不一样。需要我们自己去写函数读取参数。
pytorch模型参数的形式
模型的参数是以字典的形式存储的。
model_dict = the_model.state_dict(), for k,v in model_dict.items(): print(k)
即可看到所有的键值
如果想修改模型的参数,给相应的键值赋值即可
model_dict[k] = new_value
最后更新模型的参数
the_model.load_state_dict(model_dict)
如果模型的key值和在大数据集上训练时的key值是一样的
我们可以通过下列算法进行读取模型
model_dict = model.state_dict() pretrained_dict = torch.load(model_path) # 1. filter out unnecessary keys diff = {k: v for k, v in model_dict.items() if \ k in pretrained_dict and pretrained_dict[k].size() == v.size()} pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict and model_dict[k].size() == v.size()} pretrained_dict.update(diff) # 2. overwrite entries in the existing state dict model_dict.update(pretrained_dict) # 3. load the new state dict model.load_state_dict(model_dict)
如果模型的key值和在大数据集上训练时的key值是不一样的,但是顺序是一样的
model_dict = model.state_dict() pretrained_dict = torch.load(model_path) keys = [] for k,v in pretrained_dict.items(): keys.append(k) i = 0 for k,v in model_dict.items(): if v.size() == pretrained_dict[keys[i]].size(): print(k, ',', keys[i]) model_dict[k]=pretrained_dict[keys[i]] i = i + 1 model.load_state_dict(model_dict)
如果模型的key值和在大数据集上训练时的key值是不一样的,但是顺序是也不一样的
自己找对应关系,一个key对应一个key的赋值
以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。
相关文章
用ReactJS和Python的Flask框架编写留言板的代码示例
这篇文章主要介绍了用ReactJS和Python的Flask框架编写留言板的代码示例,其他的话用到了MongoDB这个方便使用JavaScript来操作的数据库,需要的朋友可以参考下2015-12-12Django生成PDF文档显示在网页上以及解决PDF中文显示乱码的问题
这篇文章主要介绍了Django生成PDF文档显示在网页上以及解决PDF中文显示乱码的问题,小编觉得挺不错的,现在分享给大家,也给大家做个参考。一起跟随小编过来看看吧2019-07-07
最新评论