PyTorch如何搭建一个简单的网络

 更新时间:2020年08月24日 10:27:45   作者:机器学习炼丹术  
这篇文章主要介绍了PyTorch如何搭建一个简单的网络,帮助大家更好的理解和学习PyTorch,感兴趣的朋友可以了解下

1 任务

首先说下我们要搭建的网络要完成的学习任务: 让我们的神经网络学会逻辑异或运算,异或运算也就是俗称的“相同取0,不同取1” 。再把我们的需求说的简单一点,也就是我们需要搭建这样一个神经网络,让我们在输入(1,1)时输出0,输入(1,0)时输出1(相同取0,不同取1),以此类推。

2 实现思路

因为我们的需求需要有两个输入,一个输出,所以我们需要在输入层设置两个输入节点,输出层设置一个输出节点。因为问题比较简单,所以隐含层我们只需要设置10个节点就可以达到不错的效果了,隐含层的激活函数我们采用ReLU函数,输出层我们用Sigmoid函数,让输出保持在0到1的一个范围,如果输出大于0.5,即可让输出结果为1,小于0.5,让输出结果为0.

3 实现过程

我们使用的简单的快速搭建法。

3.1 引入必要库

import torch
import torch.nn as nn
import numpy as np

用pytorch当然要引入torch包,然后为了写代码方便将torch包里的nn用nn来代替,nn这个包就是neural network的缩写,专门用来搭神经网络的一个包。引入numpy是为了创建矩阵作为输入。

3.2 创建训练集

# 构建输入集
x = np.mat('0 0;'
      '0 1;'
      '1 0;'
      '1 1')
x = torch.tensor(x).float()
y = np.mat('1;'
      '0;'
      '0;'
      '1')
y = torch.tensor(y).float()

我个人比较喜欢用np.mat这种方式构建矩阵,感觉写法比较简单,当然你也可以用其他的方法。但是构建完矩阵一定要有这一步 torch.tensor(x).float() ,必须要把你所创建的输入转换成tensor变量。

什么是tensor呢?你可以简单地理解他就是pytorch中用的一种变量,你想用pytorch这个框架就必须先把你的变量转换成tensor变量。而我们这个神经网络会要求你的输入和输出必须是float浮点型的,指的是tensor变量中的浮点型,而你用np.mat创建的输入是int型的,转换成tensor也会自动地转换成tensor的int型,所以要在后面加个.float()转换成浮点型。

这样我们就构建完成了输入和输出(分别是x矩阵和y矩阵),x是四行二列的一个矩阵,他的每一行是一个输入,一次输入两个值,这里我们把所有的输入情况都列了出来。输出y是一个四行一列的矩阵,每一行都是一个输出,对应x矩阵每一行的输入。

3.3 搭建网络

myNet = nn.Sequential( 
  nn.Linear(2,10),
  nn.ReLU(),
  nn.Linear(10,1),
  nn.Sigmoid()
  )
print(myNet)

输出结果:

我们使用nn包中的Sequential搭建网络,这个函数就是那个可以让我们像搭积木一样搭神经网络的一个东西。

nn.Linear(2,10)的意思搭建输入层,里面的2代表输入节点个数,10代表输出节点个数。Linear也就是英文的线性,意思也就是这层不包括任何其它的激活函数,你输入了啥他就给你输出了啥。nn.ReLU()这个就代表把一个激活函数层,把你刚才的输入扔到了ReLU函数中去。 接着又来了一个Linear,最后再扔到Sigmoid函数中去。 2,10,1就分别代表了三个层的个数,简单明了。

3.4 设置优化器

optimzer = torch.optim.SGD(myNet.parameters(),lr=0.05)
loss_func = nn.MSELoss()

对这一步的理解就是,你需要有一个优化的方法来训练你的网络,所以这步设置了我们所要采用的优化方法。

torch.optim.SGD的意思就是采用SGD(随机梯度下降)方法训练,你只需要把你网络的参数和学习率传进去就可以了,分别是 myNet.paramets 和 lr 。 loss_func 这句设置了代价函数,因为我们的这个问题比较简单,所以采用了MSE,也就是均方误差代价函数。

3.5 训练网络

for epoch in range(5000):
  out = myNet(x)
  loss = loss_func(out,y)
  optimzer.zero_grad()
  loss.backward()
  optimzer.step()

我这里设置了一个5000次的循环(可能不需要这么多次),让这个训练的动作迭代5000次。每一次的输出直接用myNet(x),把输入扔进你的网络就得到了输出out(就是这么简单粗暴!),然后用代价函数和你的标准输出y求误差。 清除梯度的那一步是为了每一次重新迭代时清除上一次所求出的梯度,你就把这一步记住就行,初学不用理解太深。 loss.backward() 当然就是让误差反向传播,接着 optimzer.step() 也就是让我们刚刚设置的优化器开始工作。

3.6 测试

print(myNet(x).data)

运行结果:

可以看到这个结果已经非常接近我们期待的结果了,当然你也可以换个数据测试,结果也会是相似的。这里简单解释下为什么我们的代码末尾加上了一个.data,因为我们的tensor变量其实是包含两个部分的,一部分是tensor数据,另一部分是tensor的自动求导参数,我们加上.data意思是输出取tensor中的数据,如果不加的话会输出下面这样:

以上就是PyTorch如何搭建一个简单的网络的详细内容,更多关于PyTorch搭建网络的资料请关注脚本之家其它相关文章!

相关文章

  • 深入了解Python枚举类型的相关知识

    深入了解Python枚举类型的相关知识

    这篇文章主要介绍了深入了解Python枚举类型的相关知识,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2019-07-07
  • pytorch 更改预训练模型网络结构的方法

    pytorch 更改预训练模型网络结构的方法

    今天小编就为大家分享一篇pytorch 更改预训练模型网络结构的方法,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2019-08-08
  • python中进程间通信详细介绍

    python中进程间通信详细介绍

    大家好,本篇文章主要讲的是python中进程间通信详细介绍,感兴趣的同学赶快来看一看吧,对你有帮助的话记得收藏一下,方便下次浏览
    2021-12-12
  • 基于np.arange与np.linspace细微区别(数据溢出问题)

    基于np.arange与np.linspace细微区别(数据溢出问题)

    这篇文章主要介绍了基于np.arange与np.linspace细微区别(数据溢出问题),具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教
    2022-05-05
  • Python 继承,重写,super()调用父类方法操作示例

    Python 继承,重写,super()调用父类方法操作示例

    这篇文章主要介绍了Python 继承,重写,super()调用父类方法,结合完整实例形式详细分析了Python面向对象程序设计中子类继承与重写父类方法的相关操作技巧,需要的朋友可以参考下
    2019-09-09
  • Python中set与frozenset方法和区别详解

    Python中set与frozenset方法和区别详解

    这篇文章主要介绍了Python中set与frozenset方法和区别详解的相关资料,需要的朋友可以参考下
    2016-05-05
  • Python第三方库xlrd/xlwt的安装与读写Excel表格

    Python第三方库xlrd/xlwt的安装与读写Excel表格

    最近开始学习python,想做做简单的自动化测试,需要读写excel,于是就接触到了Python的第三方库xlrd和xlwt,下面这篇文章就给大家主要介绍了Python中第三方库xlrd/xlwt的安装与读写Excel表格的方法,需要的朋友可以参考借鉴。
    2017-01-01
  • 使用Python对IP进行转换的一些操作技巧小结

    使用Python对IP进行转换的一些操作技巧小结

    这篇文章主要介绍了使用Python对IP进行转换的一些操作技巧小结,包括使用socket模块里的相关函数和匿名函数实现,需要的朋友可以参考下
    2015-11-11
  • 我用Python抓取了7000 多本电子书案例详解

    我用Python抓取了7000 多本电子书案例详解

    这篇文章主要介绍了我用Python抓取了7000 多本电子书案例,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2019-03-03
  • 详解python3类型注释annotations实用案例

    详解python3类型注释annotations实用案例

    这篇文章主要介绍了详解python3类型注释annotations实用案例,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2021-01-01

最新评论