Pytorch统计参数网络参数数量方式

 更新时间:2023年02月20日 10:09:39   作者:qq_34535410  
这篇文章主要介绍了Pytorch统计参数网络参数数量方式,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教

Pytorch统计参数网络参数数量

def get_parameter_number(net):
    total_num = sum(p.numel() for p in net.parameters())
    trainable_num = sum(p.numel() for p in net.parameters() if p.requires_grad)
    return {'Total': total_num, 'Trainable': trainable_num}

Pytorch如何计算网络的参数量

本文以 Dense Block 为例,Pytorch 为 DL 框架,最终计算模块参数量方法如下:

import torch
import torch.nn as nn

class Norm_Conv(nn.Module):

    def __init__(self,in_channel):
        super(Norm_Conv,self).__init__()
        self.layers = nn.Sequential(
            nn.Conv2d(in_channel,in_channel,3,1,1),
            nn.ReLU(True),
            nn.BatchNorm2d(in_channel),
            nn.Conv2d(in_channel,in_channel,3,1,1),
            nn.ReLU(True),
            nn.BatchNorm2d(in_channel),
            nn.Conv2d(in_channel,in_channel,3,1,1),
            nn.ReLU(True),
            nn.BatchNorm2d(in_channel))
    def forward(self,input):
        out = self.layers(input)
        return out


class DenseBlock_Norm(nn.Module):
    def __init__(self,in_channel):
        super(DenseBlock_Norm,self).__init__()

        self.first_layer = nn.Sequential(nn.Conv2d(in_channel,in_channel,3,1,1),
                                        nn.ReLU(True),
                                        nn.BatchNorm2d(in_channel))
        self.second_layer = nn.Sequential(nn.Conv2d(in_channel*2,in_channel,3,1,1),
                                          nn.ReLU(True),
                                          nn.BatchNorm2d(in_channel))
        self.third_layer = nn.Sequential(
            nn.Conv2d(in_channel*3,in_channel,3,1,1),
            nn.ReLU(True),
            nn.BatchNorm2d(in_channel))

    def forward(self,input):

        output1 = self.first_layer(input)
        output2 = self.second_layer(torch.cat((output1,input),dim=1))
        output3 = self.third_layer(torch.cat((input,output1,output2),dim=1))

        return output3

def count_param(model):
    param_count = 0
    for param in model.parameters():
        param_count += param.view(-1).size()[0]
    return param_count

# Get Parameter number of Network
in_channel = 128
net1 = Norm_Conv(in_channel)
print('Norm Conv parameter count is {}'.format(count_param(net1)))
net2 = DenseBlock_Norm(in_channel)
print('DenseBlock Norm parameter count is {}'.format(count_param(net2)))

最终结果如下

Norm Conv parameter count is 443520
DenseBlock Norm parameter count is 885888

总结

以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • 使用pandas对两个dataframe进行join的实例

    使用pandas对两个dataframe进行join的实例

    今天小编就为大家分享一篇使用pandas对两个dataframe进行join的实例,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-06-06
  • python selenium 获取接口数据的实现

    python selenium 获取接口数据的实现

    这篇文章主要介绍了python selenium 获取接口数据的实现,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2020-12-12
  • Python中os模块功能与用法详解

    Python中os模块功能与用法详解

    这篇文章主要介绍了Python中os模块功能与用法,总结分析了Python os模块基本功能、内置函数、使用方法及相关操作注意事项,需要的朋友可以参考下
    2020-02-02
  • Python实现Windows上气泡提醒效果的方法

    Python实现Windows上气泡提醒效果的方法

    这篇文章主要介绍了Python实现Windows上气泡提醒效果的方法,涉及Python针对windows窗口操作的相关技巧,需要的朋友可以参考下
    2015-06-06
  • Python GAE、Django导出Excel的方法

    Python GAE、Django导出Excel的方法

    在Python中操作Excel的方法可以通过COM,最常用的跨平台的方法是使用pyExcelerator,pyExcelerator的使用方法可以参考limodou的《使用pyExcelerator来读写Excel文件》。
    2008-11-11
  • python循环某一特定列的所有行数据(方法示例)

    python循环某一特定列的所有行数据(方法示例)

    在Python中,处理表格数据(比如CSV文件、Excel文件等)时,我们通常会使用pandas库,因为它提供了丰富的数据结构和数据分析工具,下面,我将以处理CSV文件中的某一特定列的所有行数据为例,给出详细、完整的代码示例,感兴趣的朋友跟随小编一起看看吧
    2024-08-08
  • Python实现进程同步和通信的方法

    Python实现进程同步和通信的方法

    本篇文章主要介绍了Python实现进程同步和通信的方法,详细的介绍了Process、Queue、Pipe、Lock等组件,小编觉得挺不错的,现在分享给大家,也给大家做个参考。一起跟随小编过来看看吧
    2018-01-01
  • python根据文章标题内容自动生成摘要的实例

    python根据文章标题内容自动生成摘要的实例

    今天小编就为大家分享一篇python根据文章标题内容自动生成摘要的实例,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2019-02-02
  • Python使用matplotlib.pyplot画热图和损失图的代码详解

    Python使用matplotlib.pyplot画热图和损失图的代码详解

    众所周知,在完成论文相关工作时画图必不可少,如损失函数图、热力图等是非常常见的图,在本文中,总结了这两个图的画法,下面给出了完整的代码,开箱即用,感兴趣的同学可以自己动手尝试一下
    2023-09-09
  • python统计中文字符数量的两种方法

    python统计中文字符数量的两种方法

    今天小编就为大家分享一篇python统计中文字符数量的两种方法,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2019-01-01

最新评论