pytorch 利用lstm做mnist手写数字识别分类的实例

 更新时间:2020年01月10日 10:43:23   作者:xckkcxxck  
今天小编就为大家分享一篇pytorch 利用lstm做mnist手写数字识别分类的实例,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧

代码如下,U我认为对于新手来说最重要的是学会rnn读取数据的格式。

# -*- coding: utf-8 -*-
"""
Created on Tue Oct 9 08:53:25 2018
@author: www
"""
 
import sys
sys.path.append('..')
 
import torch
import datetime
from torch.autograd import Variable
from torch import nn
from torch.utils.data import DataLoader
 
from torchvision import transforms as tfs
from torchvision.datasets import MNIST
 
#定义数据
data_tf = tfs.Compose([
   tfs.ToTensor(),
   tfs.Normalize([0.5], [0.5])
])
train_set = MNIST('E:/data', train=True, transform=data_tf, download=True)
test_set = MNIST('E:/data', train=False, transform=data_tf, download=True)
 
train_data = DataLoader(train_set, 64, True, num_workers=4)
test_data = DataLoader(test_set, 128, False, num_workers=4)
 
#定义模型
class rnn_classify(nn.Module):
   def __init__(self, in_feature=28, hidden_feature=100, num_class=10, num_layers=2):
     super(rnn_classify, self).__init__()
     self.rnn = nn.LSTM(in_feature, hidden_feature, num_layers)#使用两层lstm
     self.classifier = nn.Linear(hidden_feature, num_class)#将最后一个的rnn使用全连接的到最后的输出结果
     
   def forward(self, x):
     #x的大小为(batch,1,28,28),所以我们需要将其转化为rnn的输入格式(28,batch,28)
     x = x.squeeze() #去掉(batch,1,28,28)中的1,变成(batch, 28,28)
     x = x.permute(2, 0, 1)#将最后一维放到第一维,变成(batch,28,28)
     out, _ = self.rnn(x) #使用默认的隐藏状态,得到的out是(28, batch, hidden_feature)
     out = out[-1,:,:]#取序列中的最后一个,大小是(batch, hidden_feature)
     out = self.classifier(out) #得到分类结果
     return out
     
net = rnn_classify()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adadelta(net.parameters(), 1e-1)
 
#定义训练过程
def get_acc(output, label):
  total = output.shape[0]
  _, pred_label = output.max(1)
  num_correct = (pred_label == label).sum().item()
  return num_correct / total
  
  
def train(net, train_data, valid_data, num_epochs, optimizer, criterion):
  if torch.cuda.is_available():
    net = net.cuda()
  prev_time = datetime.datetime.now()
  for epoch in range(num_epochs):
    train_loss = 0
    train_acc = 0
    net = net.train()
    for im, label in train_data:
      if torch.cuda.is_available():
        im = Variable(im.cuda()) # (bs, 3, h, w)
        label = Variable(label.cuda()) # (bs, h, w)
      else:
        im = Variable(im)
        label = Variable(label)
      # forward
      output = net(im)
      loss = criterion(output, label)
      # backward
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()
 
      train_loss += loss.item()
      train_acc += get_acc(output, label)
 
    cur_time = datetime.datetime.now()
    h, remainder = divmod((cur_time - prev_time).seconds, 3600)
    m, s = divmod(remainder, 60)
    time_str = "Time %02d:%02d:%02d" % (h, m, s)
    if valid_data is not None:
      valid_loss = 0
      valid_acc = 0
      net = net.eval()
      for im, label in valid_data:
        if torch.cuda.is_available():
          im = Variable(im.cuda())
          label = Variable(label.cuda())
        else:
          im = Variable(im)
          label = Variable(label)
        output = net(im)
        loss = criterion(output, label)
        valid_loss += loss.item()
        valid_acc += get_acc(output, label)
      epoch_str = (
        "Epoch %d. Train Loss: %f, Train Acc: %f, Valid Loss: %f, Valid Acc: %f, "
        % (epoch, train_loss / len(train_data),
          train_acc / len(train_data), valid_loss / len(valid_data),
          valid_acc / len(valid_data)))
    else:
      epoch_str = ("Epoch %d. Train Loss: %f, Train Acc: %f, " %
             (epoch, train_loss / len(train_data),
             train_acc / len(train_data)))
    prev_time = cur_time
    print(epoch_str + time_str)
    
train(net, train_data, test_data, 10, optimizer, criterion)    

以上这篇pytorch 利用lstm做mnist手写数字识别分类的实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • el-table 多表格弹窗嵌套数据显示异常错乱问题解决方案

    el-table 多表格弹窗嵌套数据显示异常错乱问题解决方案

    使用vue+element开发报表功能时,需要列表上某列的超链接按钮弹窗展示,在弹窗的el-table列表某列中再次使用超链接按钮点开弹窗,以此类推多表格弹窗嵌套,本文以弹窗两次为例,需要的朋友可以参考下
    2023-11-11
  • 利用Python将原始边列表转换为邻接矩阵的过程

    利用Python将原始边列表转换为邻接矩阵的过程

    有时候,我们会从外部数据源中得到原始的边列表,而需要将其转换为邻接矩阵以便进行后续的分析和处理,本文将介绍如何使用Python来实现这一转换过程,需要的朋友可以参考下
    2024-04-04
  • JPype实现在python中调用JAVA的实例

    JPype实现在python中调用JAVA的实例

    本篇文章主要介绍了JPype实现在python中调用JAVA的实例,小编觉得挺不错的,现在分享给大家,也给大家做个参考。一起跟随小编过来看看吧
    2017-07-07
  • pandas使用get_dummies进行one-hot编码的方法

    pandas使用get_dummies进行one-hot编码的方法

    今天小编就为大家分享一篇pandas使用get_dummies进行one-hot编码的方法,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-07-07
  • python从入门到精通(DAY 1)

    python从入门到精通(DAY 1)

    本文是此次python从入门到精通系列文章的第一篇,给大家汇总一下常用的Python的基础知识,非常的简单,但是很全面,有需要的小伙伴可以参考下
    2015-12-12
  • Win10环境中如何实现python2和python3并存

    Win10环境中如何实现python2和python3并存

    这篇文章主要介绍了Win10环境中如何实现python2和python3并存,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2020-07-07
  • pytorch permute维度转换方法

    pytorch permute维度转换方法

    今天小编就为大家分享一篇pytorch permute维度转换方法,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-12-12
  • 简单了解Python字典copy与赋值的区别

    简单了解Python字典copy与赋值的区别

    这篇文章主要介绍了简单了解Python字典copy与赋值区别,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2020-09-09
  • python字符串判断密码强弱

    python字符串判断密码强弱

    这篇文章主要为大家详细介绍了python字符串判断密码强弱,文中示例代码介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2020-03-03
  • Python CSV模块使用实例

    Python CSV模块使用实例

    这篇文章主要介绍了Python CSV模块使用实例,本文将举几个例子来介绍一下Python的CSV模块的使用方法,包括reader、writer、DictReader、DictWriter.register_dialect等,需要的朋友可以参考下
    2015-04-04

最新评论