pytorch固定BN层参数的操作

 更新时间:2021年05月27日 08:58:15   作者:grllery  
这篇文章主要介绍了pytorch固定BN层参数的操作,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教

背景:

基于PyTorch的模型,想固定主分支参数,只训练子分支,结果发现在不同epoch相同的测试数据经过主分支输出的结果不同。

原因:

未固定主分支BN层中的running_mean和running_var。

解决方法:

将需要固定的BN层状态设置为eval。

问题示例:

环境:torch:1.7.0

# -*- coding:utf-8 -*-
import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):

    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, 3)
        self.bn1 = nn.BatchNorm2d(6)
        self.conv2 = nn.Conv2d(6, 16, 3)
        self.bn2 = nn.BatchNorm2d(16)
        # an affine operation: y = Wx + b
        self.fc1 = nn.Linear(16 * 6 * 6, 120)  # 6*6 from image dimension
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 5)

    def forward(self, x):
        # Max pooling over a (2, 2) window
        x = F.max_pool2d(F.relu(self.bn1(self.conv1(x))), (2, 2))
        # If the size is a square you can only specify a single number
        x = F.max_pool2d(F.relu(self.bn2(self.conv2(x))), 2)
        x = x.view(-1, self.num_flat_features(x))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

    def num_flat_features(self, x):
        size = x.size()[1:]  # all dimensions except the batch dimension
        num_features = 1
        for s in size:
            num_features *= s
        return num_features

def print_parameter_grad_info(net):
    print('-------parameters requires grad info--------')
    for name, p in net.named_parameters():
        print(f'{name}:\t{p.requires_grad}')

def print_net_state_dict(net):
    for key, v in net.state_dict().items():
        print(f'{key}')

if __name__ == "__main__":
    net = Net()

    print_parameter_grad_info(net)
    net.requires_grad_(False)
    print_parameter_grad_info(net)

    torch.random.manual_seed(5)
    test_data = torch.rand(1, 1, 32, 32)
    train_data = torch.rand(5, 1, 32, 32)

    # print(test_data)
    # print(train_data[0, ...])
    for epoch in range(2):
        # training phase, 假设每个epoch只迭代一次
        net.train()
        pre = net(train_data)
        # 计算损失和参数更新等
        # ....

        # test phase
        net.eval()
        x = net(test_data)
        print(f'epoch:{epoch}', x)

运行结果:

-------parameters requires grad info--------
conv1.weight: True
conv1.bias: True
bn1.weight: True
bn1.bias: True
conv2.weight: True
conv2.bias: True
bn2.weight: True
bn2.bias: True
fc1.weight: True
fc1.bias: True
fc2.weight: True
fc2.bias: True
fc3.weight: True
fc3.bias: True
-------parameters requires grad info--------
conv1.weight: False
conv1.bias: False
bn1.weight: False
bn1.bias: False
conv2.weight: False
conv2.bias: False
bn2.weight: False
bn2.bias: False
fc1.weight: False
fc1.bias: False
fc2.weight: False
fc2.bias: False
fc3.weight: False
fc3.bias: False
epoch:0 tensor([[-0.0755, 0.1138, 0.0966, 0.0564, -0.0224]])
epoch:1 tensor([[-0.0763, 0.1113, 0.0970, 0.0574, -0.0235]])

可以看到:

net.requires_grad_(False)已经将网络中的各参数设置成了不需要梯度更新的状态,但是同样的测试数据test_data在不同epoch中前向之后出现了不同的结果。

调用print_net_state_dict可以看到BN层中的参数running_mean和running_var并没在可优化参数net.parameters中

bn1.weight
bn1.bias
bn1.running_mean
bn1.running_var
bn1.num_batches_tracked

但在training pahse的前向过程中,这两个参数被更新了。导致整个网络在freeze的情况下,同样的测试数据出现了不同的结果

Also by default, during training this layer keeps running estimates of its computed mean and variance, which are then used for normalization during evaluation. The running estimates are kept with a defaultmomentumof 0.1. source

因此在training phase时对BN层显式设置eval状态:

if __name__ == "__main__":
    net = Net()
    net.requires_grad_(False)

    torch.random.manual_seed(5)
    test_data = torch.rand(1, 1, 32, 32)
    train_data = torch.rand(5, 1, 32, 32)

    # print(test_data)
    # print(train_data[0, ...])
    for epoch in range(2):
        # training phase, 假设每个epoch只迭代一次
        net.train()
        net.bn1.eval()
        net.bn2.eval()
        pre = net(train_data)
        # 计算损失和参数更新等
        # ....

        # test phase
        net.eval()
        x = net(test_data)
        print(f'epoch:{epoch}', x)

可以看到结果正常了:

epoch:0 tensor([[ 0.0944, -0.0372, 0.0059, -0.0625, -0.0048]])
epoch:1 tensor([[ 0.0944, -0.0372, 0.0059, -0.0625, -0.0048]])

补充:pytorch---之BN层参数详解及应用(1,2,3)(1,2)?

BN层参数详解(1,2)

一般来说pytorch中的模型都是继承nn.Module类的,都有一个属性trainning指定是否是训练状态,训练状态与否将会影响到某些层的参数是否是固定的,比如BN层(对于BN层测试的均值和方差是通过统计训练的时候所有的batch的均值和方差的平均值)或者Dropout层(对于Dropout层在测试的时候所有神经元都是激活的)。通常用model.train()指定当前模型model为训练状态,model.eval()指定当前模型为测试状态。

同时,BN的API中有几个参数需要比较关心的,一个是affine指定是否需要仿射,还有个是track_running_stats指定是否跟踪当前batch的统计特性。容易出现问题也正好是这三个参数:trainning,affine,track_running_stats。

其中的affine指定是否需要仿射,也就是是否需要上面算式的第四个,如果affine=False则γ=1,β=0 \gamma=1,\beta=0γ=1,β=0,并且不能学习被更新。一般都会设置成affine=True。(这里是一个可学习参数)

trainning和track_running_stats,track_running_stats=True表示跟踪整个训练过程中的batch的统计特性,得到方差和均值,而不只是仅仅依赖与当前输入的batch的统计特性(意思就是说新的batch依赖于之前的batch的均值和方差这里使用momentum参数,参考了指数移动平均的算法EMA)。相反的,如果track_running_stats=False那么就只是计算当前输入的batch的统计特性中的均值和方差了。当在推理阶段的时候,如果track_running_stats=False,此时如果batch_size比较小,那么其统计特性就会和全局统计特性有着较大偏差,可能导致糟糕的效果。

应用技巧:(1,2)

通常pytorch都会用到optimizer.zero_grad() 来清空以前的batch所累加的梯度,因为pytorch中Variable计算的梯度会进行累计,所以每一个batch都要重新清空一次梯度,原始的做法是下面这样的:

问题:参数non_blocking,以及pytorch的整体框架??

代码(1)

for index,data,target in enumerate(dataloader):
    data = data.cuda(non_blocking=True)
    target = torch.from_numpy(np.array(target)).float().cuda(non_blocking = Trye)
    output = model(data)
    loss = criterion(output,target)
    
    #清空梯度
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

而这里为了模仿minibacth,我们每次batch不清0,累积到一定次数再清0,再更新权重:

for index, data, target in enumerate(dataloader):
    #如果不是Tensor,一般要用到torch.from_numpy()
    data = data.cuda(non_blocking = True)
    target = torch.from_numpy(np.array(target)).float().cuda(non_blocking = True)
    output = model(data)
    loss = criterion(data, target)
    loss.backward()
    if index%accumulation == 0:
        #用累积的梯度更新权重
        optimizer.step()
        #清空梯度
        optimizer.zero_grad()

虽然这里的梯度是相当于原来的accumulation倍,但是实际在前向传播的过程中,对于BN几乎没有影响,因为前向的BN还是只是一个batch的均值和方差,这个时候可以用pytorch中BN的momentum参数,默认是0.1,BN参数如下,就是指数移动平均

x_new_running = (1 - momentum) * x_running + momentum * x_new_observed. momentum

以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • python微信好友数据分析详解

    python微信好友数据分析详解

    这篇文章主要为大家详细介绍了python微信好友数据分析,实现对微信好友的获取,并对省份、性别等数据分析,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2018-11-11
  • 通过字符串导入 Python 模块的方法详解

    通过字符串导入 Python 模块的方法详解

    这篇文章主要介绍了通过字符串导入 Python 模块的方法详解,本文通过实例结合,给大家介绍的非常详细,具有一定的参考借鉴价值,需要的朋友可以参考下
    2019-10-10
  • python实现向微信用户发送每日一句 python实现微信聊天机器人

    python实现向微信用户发送每日一句 python实现微信聊天机器人

    这篇文章主要为大家详细介绍了python实现向微信用户发送每日一句,python调实现微信聊天机器人,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2019-03-03
  • 基于python的docx模块处理word和WPS的docx格式文件方式

    基于python的docx模块处理word和WPS的docx格式文件方式

    今天小编就为大家分享一篇基于python的docx模块处理word和WPS的docx格式文件方式,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2020-02-02
  • 使用Python生成XML的方法实例

    使用Python生成XML的方法实例

    这篇文章主要介绍了使用Python生成XML的方法,结合具体实例形式详细分析了Python生成xml文件的具体流畅与相关注意事项,需要的朋友可以参考下
    2017-03-03
  • 如何在vscode中安装python库的方法步骤

    如何在vscode中安装python库的方法步骤

    这篇文章主要介绍了如何在vscode中安装python库的方法步骤,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2021-01-01
  • python中BackgroundScheduler和BlockingScheduler的区别

    python中BackgroundScheduler和BlockingScheduler的区别

    这篇文章主要介绍了python中BackgroundScheduler和BlockingScheduler的区别,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2021-07-07
  • Python扫描IP段查看指定端口是否开放的方法

    Python扫描IP段查看指定端口是否开放的方法

    这篇文章主要介绍了Python扫描IP段查看指定端口是否开放的方法,涉及Python使用socket模块实现端口扫描功能的相关技巧,需要的朋友可以参考下
    2015-06-06
  • Python判断文件和字符串编码类型的实例

    Python判断文件和字符串编码类型的实例

    下面小编就为大家分享一篇Python判断文件和字符串编码类型的实例,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2017-12-12
  • PyTorch深度学习LSTM从input输入到Linear输出

    PyTorch深度学习LSTM从input输入到Linear输出

    这篇文章主要为大家介绍了PyTorch深度学习LSTM从input输入到Linear输出深入理解,有需要的朋友可以借鉴参考下,希望能够有所帮助,祝大家多多进步,早日升职加薪
    2022-05-05

最新评论