pytorch __init__、forward与__call__的用法小结

 更新时间:2021年02月27日 09:11:32   作者:时光碎了天  
这篇文章主要介绍了pytorch __init__、forward与__call__的用法小结,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧

1.介绍

当我们使用pytorch来构建网络框架的时候,也会遇到和tensorflow(tensorflow __init__、build 和call小结)类似的情况,即经常会遇到__init__、forward和call这三个互相搭配着使用,那么它们的主要区别又在哪里呢?

1)__init__主要用来做参数初始化用,比如我们要初始化卷积的一些参数,就可以放到这里面,这点和tf里面的用法是一样的

2)forward是表示一个前向传播,构建网络层的先后运算步骤

3)__call__的功能其实和forward类似,所以很多时候,我们构建网络的时候,可以用__call__替代forward函数,但它们两个的区别又在哪里呢?

当网络构建完之后,调__call__的时候,会去先调forward,即__call__其实是包了一层forward,所以会导致两者的功能类似。

在pytorch在nn.Module中,实现了__call__方法,而在__call__方法中调用了forward函数:

https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/module.py

2.代码

import torch
import torch.nn as nn
import torch.nn.functional as F
 
class Net(nn.Module):
 def __init__(self, in_channels, mid_channels, out_channels):
 super(Net, self).__init__()
 self.conv0 = torch.nn.Sequential(
 torch.nn.Conv2d(in_channels, mid_channels, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
 torch.nn.LeakyReLU())
 self.conv1 = torch.nn.Sequential(
 torch.nn.Conv2d(mid_channels, out_channels * 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)))
 
 def forward(self, x):
 x = self.conv0(x)
 x = self.conv1(x)
 return x
 
class Net(nn.Module):
 def __init__(self, in_channels, mid_channels, out_channels):
 super(Net, self).__init__()
 self.conv0 = torch.nn.Sequential(
 torch.nn.Conv2d(in_channels, mid_channels, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
 torch.nn.LeakyReLU())
 self.conv1 = torch.nn.Sequential(
 torch.nn.Conv2d(mid_channels, out_channels * 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)))
 
 def __call__(self, x):
 x = self.conv0(x)
 x = self.conv1(x)
 return x

补充:torch/nn目录结构以及__init__.py

torch/nn目录结构以及init.py

torch/nn目录结构

__init__.py:

from .modules import *
#nn.modules  导入modules目录下内容 定义容器modules
from .parameter import Parameter
#nn.Parameter 导入parameter.py  定义parameter
from .parallel import DataParallel
#导入parallel目录下data_parallel.py中的DataParallel类
from . import init
#nn.init   导入init.py   参数初始化
from . import utils
#nn.utils  导入utils目录下内容 官网api下nn.utils下api

对于backends, functional.py, _functions 需要在代码前重新Import

例如我们常用的

import torch.nn.functional as F 就是导入了functional.py

backends和_functions是functional.py实现各种函数时所用到的。

以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。如有错误或未考虑完全的地方,望不吝赐教。

相关文章

  • Python实现控制台输入密码的方法

    Python实现控制台输入密码的方法

    这篇文章主要介绍了Python实现控制台输入密码的方法,实例对比分析了几种输入密码的方法,具有一定参考借鉴价值,需要的朋友可以参考下
    2015-05-05
  • Python基于Django实现验证码登录功能

    Python基于Django实现验证码登录功能

    验证码登录是一种常见的身份验证方式,它可以有效防止恶意攻击和机器人登录,本文将介绍如何基于Python Django实现验证码登录功能,需要的可以参考一下
    2023-05-05
  • Python如何用字典完成匹配任务

    Python如何用字典完成匹配任务

    在生物信息学领域,经常需要根据基因名称匹配其对应的编号,本文介绍了一种通过字典进行基因名称与编号匹配的方法,首先定义一个空列表存储对应编号,对于字典中不存在的基因名称,其编号默认为0
    2024-09-09
  • python实现批量改文件名称的方法

    python实现批量改文件名称的方法

    这篇文章主要介绍了python实现批量改文件名称的方法,涉及Python中os模块rename方法的相关使用技巧,需要的朋友可以参考下
    2015-05-05
  • python中BackgroundScheduler和BlockingScheduler的区别

    python中BackgroundScheduler和BlockingScheduler的区别

    这篇文章主要介绍了python中BackgroundScheduler和BlockingScheduler的区别,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2021-07-07
  • 10分钟教你用python动画演示深度优先算法搜寻逃出迷宫的路径

    10分钟教你用python动画演示深度优先算法搜寻逃出迷宫的路径

    这篇文章主要介绍了10分钟教你用python动画演示深度优先算法搜寻逃出迷宫的路径,非常不错,具有一定的参考借鉴价值,需要的朋友可以参考下
    2019-08-08
  • Python Pandas处理CSV文件的常用技巧分享

    Python Pandas处理CSV文件的常用技巧分享

    这篇文章主要和大家分享几个Python Pandas中处理CSV文件的常用技巧,如:统计列值出现的次数、筛选特定列值、遍历数据行等,需要的可以参考一下
    2022-06-06
  • Python中无限元素列表的实现方法

    Python中无限元素列表的实现方法

    这篇文章主要介绍了Python中无限元素列表的实现方法,很实用的功能,需要的朋友可以参考下
    2014-08-08
  • 基于opencv实现简单画板功能

    基于opencv实现简单画板功能

    这篇文章主要为大家详细介绍了基于opencv实现简单画板功能,文中示例代码介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2020-08-08
  • 解决Pyinstaller 打包exe文件 取消dos窗口(黑框框)的问题

    解决Pyinstaller 打包exe文件 取消dos窗口(黑框框)的问题

    今天小编就为大家分享一篇解决Pyinstaller 打包exe文件 取消dos窗口(黑框框)的问题,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2019-06-06

最新评论