谈谈对Pytorch中的forward的理解

 更新时间:2023年04月06日 16:07:57   作者:使者大牙  
这篇文章主要介绍了谈谈对Pytorch中的forward的理解,在Pytorch中,forward方法是一个特殊的方法,被专门用来进行前向传播,本文给大家详细讲解,需要的朋友可以参考下

写在前面

以下是本人根据Pytorch学习过程中总结出的经验,如果有错误,请指正。

正文

为什么都用def forward,而不改个名字?

在Pytorch建立神经元网络模型的时候,经常用到forward方法,表示在建立模型后,进行神经元网络的前向传播。说的直白点,forward就是专门用来计算给定输入,得到神经元网络输出的方法。

在代码实现中,也是用def forward来写forward前向传播的方法,我原来以为这是一种约定熟成的名字,也可以换成任意一个自己喜欢的名字。

但是看的多了之后发现并非如此:Pytorch对于forward方法赋予了一些特殊“功能”

(这里不禁再吐槽,一些看起来挺厉害的Pytorch“大神”,居然不知道这个。。。只能草草解释一下:“就是这样的。。。”)

forward有什么特殊功能?

第一条:.forward()可以不写

我最开始发现forward()的与众不同之处就是在此,首先举个例子:

import torch.nn as nn
class test(nn.Module):
    def __init__(self, input):
        super(test,self).__init__()
        self.input = input

    def forward(self,x):
        return self.input * x

T = test(8)
print(T(6))

# print(T.forward(6))
--------------------------运行结果-------------------------
D:\Users\Lenovo\anaconda3\python.exe C:/Users/Lenovo/Desktop/DL/pythonProject/tt.py
48

Process finished with exit code 0

可以发现,T(6)是可以输出的!而且不用指定,默认了调用forward方法。当然如果非要写上.forward()这也是可以正常运行的,和不写是一样的。

如果不调用Pytorch(正常的Python语法规则),这样肯定会报错的

# import torch.nn as nn  #不再调用torch
class test():
    def __init__(self, input):
        self.input = input

    def forward(self,x):
        return self.input * x

T = test(8)
print(T.forward(6))
print("************************")
print(T(6))
--------------------------运行结果-------------------------
D:\Users\Lenovo\anaconda3\python.exe C:/Users/Lenovo/Desktop/DL/pythonProject/tt.py
48
************************
Traceback (most recent call last):
  File "C:\Users\Lenovo\Desktop\DL\pythonProject\tt.py", line 77, in <module>
    print(T(6))
TypeError: 'test' object is not callable

Process finished with exit code 1

这里会报:‘test’ object is not callable
因为class不能被直接调用,不知道你想调用哪个方法。

第二条:优先运行forward方法

如果在class中再增加一个方法:

import torch.nn as nn
class test(nn.Module):
    def __init__(self, input):
        super(test,self).__init__()
        self.input = input

    def byten(self):
        return self.input * 10

    def forward(self,x):
        return self.input * x

T = test(8)
print(T(6))
print(T.byten())
--------------------------运行结果-------------------------
D:\Users\Lenovo\anaconda3\python.exe C:/Users/Lenovo/Desktop/DL/pythonProject/tt.py
48
80

Process finished with exit code 0

可以见到,在class中有多个method的时候,如果不指定method,forward是会被优先执行的。

总结

在Pytorch中,forward方法是一个特殊的方法,被专门用来进行前向传播。

到此这篇关于谈谈对Pytorch中的forward的理解的文章就介绍到这了,更多相关Pytorch中的forward内容请搜索脚本之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持脚本之家!

相关文章

  • Python基于回溯法子集树模板解决选排问题示例

    Python基于回溯法子集树模板解决选排问题示例

    这篇文章主要介绍了Python基于回溯法子集树模板解决选排问题,简单描述了选排问题并结合实例形式分析了Python使用回溯法子集树模板解决选排问题的具体实现步骤与相关操作注意事项,需要的朋友可以参考下
    2017-09-09
  • Scrapy爬虫文件批量运行的实现

    Scrapy爬虫文件批量运行的实现

    这篇文章主要介绍了Scrapy爬虫文件批量运行的实现,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2020-09-09
  • PyCharm常用配置和常用插件(小结)

    PyCharm常用配置和常用插件(小结)

    这篇文章主要介绍了PyCharm常用配置和常用插件(小结),文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2021-02-02
  • 浅谈python爬虫使用Selenium模拟浏览器行为

    浅谈python爬虫使用Selenium模拟浏览器行为

    这篇文章主要介绍了浅谈python爬虫使用Selenium模拟浏览器行为,小编觉得挺不错的,现在分享给大家,也给大家做个参考。一起跟随小编过来看看吧
    2018-02-02
  • Python获取字典中的值的八种方法

    Python获取字典中的值的八种方法

    Python 字典(dictionary)是一种可变容器模型,可以存储任意数量的任意类型的数据,字典通常用于存储键值对,每个元素由一个键(key)和一个值(value)组成,键和值之间用冒号分隔,本文给大家介绍了Python 字典取值的几种方法及其代码演示,需要的朋友可以参考下
    2024-07-07
  • 解决django后台管理界面添加中文内容乱码问题

    解决django后台管理界面添加中文内容乱码问题

    今天小编就为大家分享一篇解决django后台管理界面添加中文内容乱码问题,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2019-11-11
  • Python中的各种装饰器详解

    Python中的各种装饰器详解

    这篇文章主要介绍了Python中的各种装饰器详解,Python装饰器分两部分,一是装饰器本身的定义,一是被装饰器对象的定义,本文分别讲解了各种情况下的装饰器,需要的朋友可以参考下
    2015-04-04
  • 让python同时兼容python2和python3的8个技巧分享

    让python同时兼容python2和python3的8个技巧分享

    这篇文章主要介绍了让python同时兼容python2和python3的8个技巧分享,对代码稍微做些修改就可以很好的同时支持python2和python3的,需要的朋友可以参考下
    2014-07-07
  • python常量折叠基础知识点讲解

    python常量折叠基础知识点讲解

    在本篇文章里小编给大家整理的是一篇关于python常量折叠基础知识点讲解,对此有兴趣的朋友可以跟着学习下。
    2021-02-02
  • python中list列表复制的几种方法(赋值、切片、copy(),deepcopy())

    python中list列表复制的几种方法(赋值、切片、copy(),deepcopy())

    本文主要介绍了python中list列表复制的几种方法(赋值、切片、copy(),deepcopy()),文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2022-08-08

最新评论