PyTorch实现MNIST数据集手写数字识别详情

 更新时间:2022年09月06日 14:17:45   作者:长浔  
这篇文章主要介绍了PyTorch实现MNIST数据集手写数字识别详情,文章围绕主题展开详细的内容戒杀,具有一定的参考价值,需要的朋友可以参考一下

前言:

本篇文章基于卷积神经网络CNN,使用PyTorch实现MNIST数据集手写数字识别。

一、PyTorch是什么?

PyTorch 是一个 Torch7 团队开源的 Python 优先的深度学习框架,提供两个高级功能:

  • 强大的 GPU 加速 Tensor 计算(类似 numpy)
  • 构建基于 tape 的自动升级系统上的深度神经网络

你可以重用你喜欢的 python 包,如 numpy、scipy 和 Cython ,在需要时扩展 PyTorch。

二、程序示例

下面案例可供运行参考

1.引入必要库

import torchvision
import torch
from torch.utils.data import DataLoader
import torch.nn.functional as F

2.下载数据集

这里设置download=True,将会自动下载数据集,并存储在./data文件夹。

train_data = torchvision.datasets.MNIST(root="./data",train=True,transform=torchvision.transforms.ToTensor(),download=True)
test_data = torchvision.datasets.MNIST(root="./data",train=False,transform=torchvision.transforms.ToTensor(),download=True)

3.加载数据集

batch_size=32表示每一个batch中包含32张手写数字图片,shuffle=True表示打乱测试集(data和target仍一一对应)

train_loader = DataLoader(train_data,batch_size=32,shuffle=True)
test_loader = DataLoader(test_data,batch_size=32,shuffle=False)

4.搭建CNN模型并实例化

class Net(torch.nn.Module):
    def __init__(self):
        super(Net,self).__init__()
        self.con1 = torch.nn.Conv2d(1,10,kernel_size=5)
        self.con2 = torch.nn.Conv2d(10,20,kernel_size=5)
        self.pooling = torch.nn.MaxPool2d(2)
        self.fc = torch.nn.Linear(320,10)
    def forward(self,x):
        batch_size = x.size(0)
        x = F.relu(self.pooling(self.con1(x)))
        x = F.relu(self.pooling(self.con2(x)))
        x = x.view(batch_size,-1)
        x = self.fc(x)
        return x
#模型实例化        
model = Net()

5.交叉熵损失函数损失函数及SGD算法优化器

lossfun = torch.nn.CrossEntropyLoss()
opt = torch.optim.SGD(model.parameters(),lr=0.01,momentum=0.5)

6.训练函数

def train(epoch):
    running_loss = 0.0
    for i,(inputs,targets) in enumerate(train_loader,0):
        # inputs,targets = inputs.to(device),targets.to(device)
        opt.zero_grad()
        outputs = model(inputs)
        loss = lossfun(outputs,targets)
        loss.backward()
        opt.step()

        running_loss += loss.item()
        if i % 300 == 299:
            print('[%d,%d] loss:%.3f' % (epoch+1,i+1,running_loss/300))
            running_loss = 0.0

7.测试函数

def test():
    total = 0
    correct = 0
    with torch.no_grad():
        for (inputs,targets) in test_loader:
            # inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            _,predicted = torch.max(outputs.data,dim=1)
            total += targets.size(0)
            correct += (predicted == targets).sum().item()
    print(100*correct/total)

8.运行

if __name__ == '__main__':
    for epoch in range(20):
        train(epoch)
        test()

三、总结

到此这篇关于PyTorch实现MNIST数据集手写数字识别详情的文章就介绍到这了,更多相关PyTorch MNIST 内容请搜索脚本之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持脚本之家!

相关文章

  • Python实现Linux中的du命令

    Python实现Linux中的du命令

    这篇文章主要介绍了Python实现Linux中简单du命令,需要的朋友可以参考下
    2017-06-06
  • Python 对输入的数字进行排序的方法

    Python 对输入的数字进行排序的方法

    今天小编就为大家分享一篇Python 对输入的数字进行排序的方法,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-06-06
  • Django中使用Json返回数据的实现方法

    Django中使用Json返回数据的实现方法

    这篇文章主要介绍了Django中使用Json返回数据的实现方法,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2020-06-06
  • Pytorch固定随机数种子的方法小结

    Pytorch固定随机数种子的方法小结

    在对神经网络模型进行训练时,有时候会存在对训练过程进行复现的需求,然而,每次运行时 Pytorch、Numpy 中的随机性将使得该目的变得困难重重,基于此,本文记录了 Pytorch 中的固定随机数种子的方法,需要的朋友可以参考下
    2023-12-12
  • Python分析彩票记录并预测中奖号码过程详解

    Python分析彩票记录并预测中奖号码过程详解

    这篇文章主要介绍了Python分析彩票记录并预测中奖号码过程详解,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2019-07-07
  • Pycharm运行时总是跳出Python Console问题

    Pycharm运行时总是跳出Python Console问题

    这篇文章主要介绍了Pycharm运行时总是跳出Python Console问题,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教
    2023-04-04
  • Django中的AutoField字段使用

    Django中的AutoField字段使用

    这篇文章主要介绍了Django中的AutoField字段使用,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2020-05-05
  • Python对字符串实现去重操作的方法示例

    Python对字符串实现去重操作的方法示例

    字符串去重是python中字符串操作常见的一个需求,最近在工作中就又遇到了,所以下面这篇文章主要给大家介绍了关于Python对字符串实现去重操作的相关资料,文中给出了详细的介绍,需要的朋友可以参考借鉴,下面来一起看看吧。
    2017-08-08
  • 详解pytest分布式执行插件 pytest-xdist 的高级用法

    详解pytest分布式执行插件 pytest-xdist 的高级用法

    这篇文章主要介绍了pytest分布式执行插件 pytest-xdist 的高级用法,本文给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下
    2022-08-08
  • Python Pandas中rolling方法的使用指南

    Python Pandas中rolling方法的使用指南

    在数据分析和时间序列数据处理中,经常需要执行滚动计算或滑动窗口操作,Pandas库提供了rolling方法,用于执行这些操作,下面我们就来学习一下rolling方法的具体使用吧
    2023-11-11

最新评论