Pytorch写数字识别LeNet模型

 更新时间:2022年01月26日 17:54:21   作者:Jokic_Rn   
这篇文章主要介绍了Pytorch写数字识别LeNet模型,LeNet-5是一个较简单的卷积神经网络,  LeNet-5 这个网络虽然很小,但是它包含了深度学习的基本模块:卷积层,池化层,全连接层。是其他深度学习模型的基础, 这里我们对LeNet-5进行深入分析,需要的朋友可以参考下

LeNet网络

LeNet网络过卷积层时候保持分辨率不变,过池化层时候分辨率变小。实现如下

from PIL import Image
import cv2
import matplotlib.pyplot as plt
import torchvision
from torchvision import transforms
import torch
from torch.utils.data import DataLoader
import torch.nn as nn
import numpy as np
import tqdm as tqdm

class LeNet(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.sequential = nn.Sequential(nn.Conv2d(1,6,kernel_size=5,padding=2),nn.Sigmoid(),
                                        nn.AvgPool2d(kernel_size=2,stride=2),
                                        nn.Conv2d(6,16,kernel_size=5),nn.Sigmoid(),
                                        nn.AvgPool2d(kernel_size=2,stride=2),
                                        nn.Flatten(),
                                        nn.Linear(16*25,120),nn.Sigmoid(),
                                        nn.Linear(120,84),nn.Sigmoid(),
                                        nn.Linear(84,10))
        
    
    def forward(self,x):
        return self.sequential(x)

class MLP(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.sequential = nn.Sequential(nn.Flatten(),
                          nn.Linear(28*28,120),nn.Sigmoid(),
                          nn.Linear(120,84),nn.Sigmoid(),
                          nn.Linear(84,10))
        
    
    def forward(self,x):
        return self.sequential(x)

epochs = 15
batch = 32
lr=0.9
loss = nn.CrossEntropyLoss()
model = LeNet()
optimizer = torch.optim.SGD(model.parameters(),lr)
device = torch.device('cuda')
root = r"./"
trans_compose  = transforms.Compose([transforms.ToTensor(),
                    ])
train_data = torchvision.datasets.MNIST(root,train=True,transform=trans_compose,download=True)
test_data = torchvision.datasets.MNIST(root,train=False,transform=trans_compose,download=True)
train_loader = DataLoader(train_data,batch_size=batch,shuffle=True)
test_loader = DataLoader(test_data,batch_size=batch,shuffle=False)

model.to(device)
loss.to(device)
# model.apply(init_weights)
for epoch in range(epochs):
  train_loss = 0
  test_loss = 0
  correct_train = 0
  correct_test = 0
  for index,(x,y) in enumerate(train_loader):
    x = x.to(device)
    y = y.to(device)
    predict = model(x)
    L = loss(predict,y)
    optimizer.zero_grad()
    L.backward()
    optimizer.step()
    train_loss = train_loss + L
    correct_train += (predict.argmax(dim=1)==y).sum()
  acc_train = correct_train/(batch*len(train_loader))
  with torch.no_grad():
    for index,(x,y) in enumerate(test_loader):
      [x,y] = [x.to(device),y.to(device)]
      predict = model(x)
      L1 = loss(predict,y)
      test_loss = test_loss + L1
      correct_test += (predict.argmax(dim=1)==y).sum()
    acc_test = correct_test/(batch*len(test_loader))
  print(f'epoch:{epoch},train_loss:{train_loss/batch},test_loss:{test_loss/batch},acc_train:{acc_train},acc_test:{acc_test}')

训练结果

epoch:12,train_loss:2.235553741455078,test_loss:0.3947642743587494,acc_train:0.9879833459854126,acc_test:0.9851238131523132
epoch:13,train_loss:2.028963804244995,test_loss:0.3220392167568207,acc_train:0.9891499876976013,acc_test:0.9875199794769287
epoch:14,train_loss:1.8020273447036743,test_loss:0.34837451577186584,acc_train:0.9901833534240723,acc_test:0.98702073097229

泛化能力测试

找了一张图片,将其分割成只含一个数字的图片进行测试

images_np = cv2.imread("/content/R-C.png",cv2.IMREAD_GRAYSCALE)
h,w = images_np.shape
images_np = np.array(255*torch.ones(h,w))-images_np#图片反色
images = Image.fromarray(images_np)
plt.figure(1)
plt.imshow(images)
test_images = []
for i in range(10):
  for j in range(16):
    test_images.append(images_np[h//10*i:h//10+h//10*i,w//16*j:w//16*j+w//16])
sample = test_images[77]
sample_tensor = torch.tensor(sample).unsqueeze(0).unsqueeze(0).type(torch.FloatTensor).to(device)
sample_tensor = torch.nn.functional.interpolate(sample_tensor,(28,28))
predict = model(sample_tensor)
output = predict.argmax()
print(output)
plt.figure(2)
plt.imshow(np.array(sample_tensor.squeeze().to('cpu')))

此时预测结果为4,预测正确。从这段代码中可以看到有一个反色的步骤,若不反色,结果会受到影响,如下图所示,预测为0,错误。
模型用于输入的图片是单通道的黑白图片,这里由于可视化出现了黄色,但实际上是黑白色,反色操作说明了数据的预处理十分的重要,很多数据如果是不清理过是无法直接用于推理的。

将所有用来泛化性测试的图片进行准确率测试:

correct = 0
i = 0
cnt = 1
for sample in test_images:
  sample_tensor = torch.tensor(sample).unsqueeze(0).unsqueeze(0).type(torch.FloatTensor).to(device)
  sample_tensor = torch.nn.functional.interpolate(sample_tensor,(28,28))
  predict = model(sample_tensor)
  output = predict.argmax()
  if(output==i):
    correct+=1
  if(cnt%16==0):
    i+=1
  cnt+=1
acc_g = correct/len(test_images)
print(f'acc_g:{acc_g}')

如果不反色,acc_g=0.15

acc_g:0.50625

到此这篇关于Pytorch写数字识别LeNet模型的文章就介绍到这了,更多相关Pytorch写数字识别LeNet模型内容请搜索脚本之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持脚本之家!

相关文章

  • 详解Python map函数及Python map()函数的用法

    详解Python map函数及Python map()函数的用法

    map() 会根据提供的函数对指定序列做映射。下面通过本文给大家介绍Python map函数及Python map()函数的用法,需要的朋友参考下吧
    2017-11-11
  • python多进程中的内存复制(实例讲解)

    python多进程中的内存复制(实例讲解)

    下面小编就为大家分享一篇python多进程中的内存复制(实例讲解),具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-01-01
  • python简单程序读取串口信息的方法

    python简单程序读取串口信息的方法

    这篇文章主要介绍了python简单程序读取串口信息的方法,涉及Python操作serial模块的技巧,需要的朋友可以参考下
    2015-03-03
  • Python使用装饰器进行django开发实例代码

    Python使用装饰器进行django开发实例代码

    这篇文章主要介绍了Python使用装饰器进行django开发实例代码,分享了相关代码示例,小编觉得还是挺不错的,具有一定借鉴价值,需要的朋友可以参考下
    2018-02-02
  • 神经网络(BP)算法Python实现及应用

    神经网络(BP)算法Python实现及应用

    这篇文章主要为大家详细介绍了Python实现神经网络(BP)算法及简单应用,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2018-04-04
  • 如何查看Python安装了哪些包

    如何查看Python安装了哪些包

    这篇文章主要给大家介绍了关于如何查看Python安装了哪些包的相关资料, Conda是另一种广泛使用的Python包管理工具,它用于安装、管理和升级软件包和其依赖项,需要的朋友可以参考下
    2023-07-07
  • Python中的进程分支fork和exec详解

    Python中的进程分支fork和exec详解

    这篇文章主要介绍了Python中的进程分支fork和exec详解,本文用实例讲解fork()的使用,并讲解了exec相关的8个方法等内容,需要的朋友可以参考下
    2015-04-04
  • AI人工智能 Python实现人机对话

    AI人工智能 Python实现人机对话

    这篇文章主要为大家详细介绍了AI人工智能应用,本文拟使用Python开发语言实现类似于WIndows平台的“小娜”,,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2017-11-11
  • python 查找轮廓的实现示例

    python 查找轮廓的实现示例

    边缘检测是一种从图像中提取轮廓和特征的技术,本文主要介绍了python查找轮廓的实现示例,具有一定的参考价值,感兴趣的可以了解一下
    2024-07-07
  • python数据抓取分析的示例代码(python + mongodb)

    python数据抓取分析的示例代码(python + mongodb)

    本篇文章主要介绍了python数据抓取分析的示例代码(python + mongodb),小编觉得挺不错的,现在分享给大家,也给大家做个参考。一起跟随小编过来看看吧
    2017-12-12

最新评论