在pytorch中如何查看模型model参数parameters

 更新时间:2022年11月28日 14:52:59   作者:xiaoju233  
这篇文章主要介绍了在pytorch中如何查看模型model参数parameters,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教

pytorch查看模型model参数parameters

示例1:pytorch自带的faster r-cnn模型

import torch
import torchvision

model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)

for name, p in model.named_parameters():
    print(name)
    print(p.requires_grad)
    print(...)

#或者

for p in model.parameters():
    print(p)
    print(...)

示例2:自定义网络模型

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()

        cfg = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512]
        self.features = self._vgg_layers(cfg)

    def _vgg_layers(self, cfg):
        layers = []
        in_channels = 3
        for x in cfg:
            if x == 'M':
                layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
            else:
                layers += [nn.Conv2d(in_channels, x ,kernel_size=3, padding=1),
                        nn.BatchNorm2d(x),
                        nn.ReLU(inplace=True)
                        ]
                in_channels = x
            
        return nn.Sequential(*layers)

    def forward(self, data):
        out_map = self.features(data)
        return out_map
    
Model = Net()

for name, p in model.named_parameters():
    print(name)
    print(p.requires_grad)
    print(...)

#或者

for p in model.parameters():
    print(p)
    print(...)

在自定义网络中,model.parameters()方法继承自nn.Module

pytorch查看模型参数总结

1:DNN_printer

其中(3, 32, 32)是输入的大小,其他方法中的参数同理

from DNN_printer import DNN_printer

batch_size = 512
def train(epoch):
    print('\nEpoch: %d' % epoch)
    net.train()
    train_loss = 0
    correct = 0
    total = 0
    // put the code here and you can get the result
    DNN_printer(net, (3, 32, 32),batch_size)

结果

2:parameters

def cnn_paras_count(net):
    """cnn参数量统计, 使用方式cnn_paras_count(net)"""
    # Find total parameters and trainable parameters
    total_params = sum(p.numel() for p in net.parameters())
    print(f'{total_params:,} total parameters.')
    total_trainable_params = sum(p.numel() for p in net.parameters() if p.requires_grad)
    print(f'{total_trainable_params:,} training parameters.')
    return total_params, total_trainable_params

cnn_paras_count(net)

直接输出参数量,然后自己计算

需要注意的是,一般模型中参数是以float32保存的,也就是一个参数由4个bytes表示,那么就可以将参数量转化为存储大小。

例如:

  • 44426个参数*4 / 1024 ≈ 174KB

3:get_model_complexity_info()

from ptflops import get_model_complexity_info
from torchvision import models

net = models.mobilenet_v2()
ops, params = get_model_complexity_info(net, (3, 224, 224), as_strings=True, 
										print_per_layer_stat=True, verbose=True)

4:torchstat

from torchstat import stat
import torchvision.models as models
model = models.resnet152()
stat(model, (3, 224, 224))

输出

以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • Python中Selenium模拟JQuery滑动解锁实例

    Python中Selenium模拟JQuery滑动解锁实例

    这篇文章主要介绍了Python中Selenium模拟JQuery滑动解锁实例,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2017-07-07
  • python使用箱型图剔除异常值的实现方法

    python使用箱型图剔除异常值的实现方法

    python中的箱线图可用于分析数据中的异常值,下面这篇文章主要给大家介绍了关于python使用箱型图剔除异常值的相关资料,文中通过实例代码介绍的非常详细,需要的朋友可以参考下
    2023-05-05
  • python实现的登录与提交表单数据功能示例

    python实现的登录与提交表单数据功能示例

    这篇文章主要介绍了python实现的登录与提交表单数据功能,结合实例形式分析了Python表单登录相关的请求与响应操作实现技巧,需要的朋友可以参考下
    2019-09-09
  • Python实现读取并写入Excel文件过程解析

    Python实现读取并写入Excel文件过程解析

    这篇文章主要介绍了Python实现读取并写入Excel文件过程解析,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2020-05-05
  • Pytorch的mean和std调查实例

    Pytorch的mean和std调查实例

    今天小编就为大家分享一篇Pytorch的mean和std调查实例,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2020-01-01
  • Python中各类Excel表格批量合并问题的实现思路与案例

    Python中各类Excel表格批量合并问题的实现思路与案例

    在日常工作中,可能会遇到各类表格合并的需求。本文主要介绍了Python中各类Excel表格批量合并问题的实现思路与案例,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2023-01-01
  • python爬虫使用requests发送post请求示例详解

    python爬虫使用requests发送post请求示例详解

    这篇文章主要介绍了python爬虫使用requests发送post请求示例详解,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2020-08-08
  • Tensorflow与RNN、双向LSTM等的踩坑记录及解决

    Tensorflow与RNN、双向LSTM等的踩坑记录及解决

    这篇文章主要介绍了Tensorflow与RNN、双向LSTM等的踩坑记录及解决方案,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教
    2021-05-05
  • python文件写入write()的操作

    python文件写入write()的操作

    这篇文章主要介绍了python文件写入write()的操作,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2019-05-05
  • 对网站内嵌gradio应用的输入输出做审核实现详解

    对网站内嵌gradio应用的输入输出做审核实现详解

    这篇文章主要为大家介绍了对网站内嵌gradio应用的输入输出做审核实现详解,有需要的朋友可以借鉴参考下,希望能够有所帮助,祝大家多多进步,早日升职加薪
    2023-04-04

最新评论