Pytorch搭建简单的卷积神经网络(CNN)实现MNIST数据集分类任务

 更新时间:2023年03月23日 10:04:59   作者:无知的吱屋  
这篇文章主要介绍了Pytorch搭建简单的卷积神经网络(CNN)实现MNIST数据集分类任务,本文给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下

关于一些代码里的解释,可以看我上一篇发布的文章,里面有很详细的介绍!!!

可以依次把下面的代码段合在一起运行,也可以通过jupyter notebook分次运行

第一步:基本库的导入

import numpy as np
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import time
np.random.seed(1234)

第二步:引用MNIST数据集,这里采用的是torchvision自带的MNIST数据集

#这里用的是torchvision已经封装好的MINST数据集
trainset=torchvision.datasets.MNIST(
    root='MNIST',  #root是下载MNIST数据集保存的路径,可以自行修改
    train=True,
    transform=torchvision.transforms.ToTensor(),
    download=True
)
 
testset=torchvision.datasets.MNIST(
    root='MNIST',
    train=False,
    transform=torchvision.transforms.ToTensor(),
    download=True
)
 
trainloader = DataLoader(dataset=trainset, batch_size=100, shuffle=True)   #DataLoader是一个很好地能够帮助整理数据集的类,可以用来分批次,打乱以及多线程等操作
testloader = DataLoader(dataset=testset, batch_size=100, shuffle=True)

下载之后利用DataLoader实例化为适合遍历的训练集和测试集,我们把其中的某一批数据进行可视化,下面是可视化的代码,其实就是利用subplot画了子图。

#可视化某一批数据
train_img,train_label=next(iter(trainloader))   #iter迭代器,可以用来便利trainloader里面每一个数据,这里只迭代一次来进行可视化
fig, axes = plt.subplots(10, 10, figsize=(10, 10))
axes_list = []
#输入到网络的图像
for i in range(axes.shape[0]):
    for j in range(axes.shape[1]):
        axes[i, j].imshow(train_img[i*10+j,0,:,:],cmap="gray")    #这里画出来的就是我们想输入到网络里训练的图像,与之对应的标签用来进行最后分类结果损失函数的计算
        axes[i, j].axis("off")
#对应的标签
print(train_label)

 第三步:用pytorch搭建简单的卷积神经网络(CNN)

 这里把卷积模块单独拿出来作为一个类,看上去会舒服一点。

#卷积模块,由卷积核和激活函数组成
class conv_block(nn.Module):
    def __init__(self,ks,ch_in,ch_out):
        super(conv_block,self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(ch_in, ch_out, kernel_size=ks,stride=1,padding=1,bias=True),  #二维卷积核,用于提取局部的图像信息
            nn.ReLU(inplace=True), #这里用ReLU作为激活函数
            nn.Conv2d(ch_out, ch_out, kernel_size=ks,stride=1,padding=1,bias=True),
            nn.ReLU(inplace=True),
        )
    def forward(self,x):
        return self.conv(x)

下面是CNN主体部分,由上面的卷积模块和全连接分类器组合而成。这里只用了简单的几个卷积块进行堆叠,没有采用池化以及dropout的操作。主要目的是给大家简单搭建一下以便学习。

#常规CNN模块(由几个卷积模块堆叠而成)
class CNN(nn.Module):
    def __init__(self,kernel_size,in_ch,out_ch):
        super(CNN, self).__init__()
        feature_list = [16,32,64,128,256]   #代表每一层网络的特征数,扩大特征空间有助于挖掘更多的局部信息
        self.conv1 = conv_block(kernel_size,in_ch,feature_list[0])
        self.conv2 = conv_block(kernel_size,feature_list[0],feature_list[1])
        self.conv3 = conv_block(kernel_size,feature_list[1],feature_list[2])
        self.conv4 = conv_block(kernel_size,feature_list[2],feature_list[3])
        self.conv5 = conv_block(kernel_size,feature_list[3],feature_list[4])
        self.fc =  nn.Sequential(           #全连接层主要用来进行分类,整合采集的局部信息以及全局信息
            nn.Linear(feature_list[4] * 28 * 28, 1024),  #此处28为MINST一张图片的维度
            nn.ReLU(),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Linear(512, 10)
        )
 
    def forward(self,x):
        device = x.device
        x1 = self.conv1(x )
        x2 = self.conv2(x1)
        x3 = self.conv3(x2)
        x4 = self.conv4(x3)
        x5 = self.conv5(x4)
        x5 = x5.view(x5.size()[0], -1)  #全连接层相当于做了矩阵乘法,所以这里需要将维度降维来实现矩阵的运算
        out = self.fc(x5)
        return out

第四步:训练以及模型保存

先是一些网络参数的定义,包括优化器,迭代轮数,学习率,运行硬件等等的确定。

#网络参数定义
device = torch.device("cuda:4")  #此处根据电脑配置进行选择,如果没有cuda就用cpu
#device = torch.device("cpu")
net = CNN(3,1,1).to(device = device,dtype = torch.float32)
epochs = 50  #训练轮次
optimizer = torch.optim.Adam(net.parameters(), lr=1e-4, weight_decay=1e-8)  #使用Adam优化器
criterion = nn.CrossEntropyLoss()  #分类任务常用的交叉熵损失函数
train_loss = []

然后是每一轮训练的主体:

# Begin training
MinTrainLoss = 999
for epoch in range(1,epochs+1):
    total_train_loss = []      
    net.train()
    start = time.time()
    for input_img,label in trainloader:
        input_img = input_img.to(device = device,dtype=torch.float32)  #我们同样地,需要将我们取出来的训练集数据进行torch能够运算的格式转换
        label = label.to(device = device,dtype=torch.float32)          #输入和输出的格式都保持一致才能进行运算
        optimizer.zero_grad()  #每一次算loss前需要将之前的梯度清零,这样才不会影响后面的更新
        pred_img = net(input_img) 
        loss = criterion(pred_img,label.long())
        loss.backward()
        optimizer.step()
        total_train_loss.append(loss.item())
    train_loss.append(np.mean(total_train_loss))    #将一个minibatch里面的损失取平均作为这一轮的loss
    end = time.time()
    #打印当前的loss
    print("epochs[%3d/%3d] current loss: %.5f, time: %.3f"%(epoch,epochs,train_loss[-1],(end-start)))   #打印每一轮训练的结果
    
    if train_loss[-1]<MinTrainLoss:
        torch.save(net.state_dict(), "./model_min_train.pth")  #保存loss最小的模型
        MinTrainLoss = train_loss[-1]

以下是迭代过程:

 第五步:导入网络模型,输入某一批测试数据,查看结果

我们先来看某一批测试数据

#测试机某一批数据
test_img,test_label=next(iter(testloader))
fig, axes = plt.subplots(10, 10, figsize=(10, 10))
axes_list = []
#输入到网络的图像
for i in range(axes.shape[0]):
    for j in range(axes.shape[1]):
        axes[i, j].imshow(test_img[i*10+j,0,:,:],cmap="gray")
        axes[i, j].axis("off")

然后将其输入到训练好的模型进行预测

#预测我拿出来的那一批数据进行展示
cnn = CNN(3,1,1).to(device = device,dtype = torch.float32)
cnn.load_state_dict(torch.load("./model_min_train.pth", map_location=device)) #导入我们之前已经训练好的模型
cnn.eval()   #评估模式
 
test_img = test_img.to(device = device,dtype = torch.float32)
test_label = test_label.to(device = device,dtype = torch.float32)
 
pred_test = cnn(test_img)  #记住,输出的结果是一个长度为10的tensor
test_pred = np.argmax(pred_test.cpu().data.numpy(), axis=1)  #所以我们需要对其进行最大值对应索引的处理,从而得到我们想要的预测结果
 
#预测结果以及标签
print("预测结果")
print(test_pred)
print("标签")
print(test_label.cpu().data.numpy())

从预测的结果我们可以看到,整体上这么一个简单的CNN搭配全连接分类器对MNIST这一批数据分类的效果还不错。当然,我这里只用了交叉熵损失函数,并且没有计算准确率,仅供大家对于CNN学习和参考。

到此这篇关于Pytorch搭建简单的卷积神经网络(CNN)实现MNIST数据集分类任务的文章就介绍到这了,更多相关Pytorch卷积神经网络内容请搜索脚本之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持脚本之家!

相关文章

  • python3+PyQt5实现柱状图

    python3+PyQt5实现柱状图

    这篇文章主要为大家详细介绍了python3+PyQt5实现柱状图的方法,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2018-04-04
  • python入门while循环语句理解学习

    python入门while循环语句理解学习

    这篇文章主要介绍了python入门while循环语句理解学习,文中附含详细图文示例教程,有需要的朋友可以借鉴参考下,希望能够有所帮助
    2021-09-09
  • Python基础 括号()[]{}的详解

    Python基础 括号()[]{}的详解

    这篇文章主要介绍了Python基础 括号()、[]、{},下面文章将围绕这三个括号的相关解析展开内容,需要的朋友可以参考一下,洗碗粉对你有所帮助
    2021-11-11
  • Python抓取网页图片难点分析

    Python抓取网页图片难点分析

    没想到python是如此强大,令人着迷,以前看见图片总是一张一张复制粘贴,现在好了,学会python就可以用程序将一张张图片,保存下来。今天网上冲浪看到很多美图,可是图片有点多,不想一张一张地复制粘贴,怎么办呢?办法总是有的,即便没有我们也可以创造一个办法
    2023-01-01
  • Python视频剪辑Moviepy库使用教程

    Python视频剪辑Moviepy库使用教程

    这篇文章主要为大家介绍了Python视频剪辑Moviepy库使用教程详解,有需要的朋友可以借鉴参考下,希望能够有所帮助,祝大家多多进步,早日升职加薪
    2023-06-06
  • Python实现Window路径格式转换为Linux路径格式的代码

    Python实现Window路径格式转换为Linux路径格式的代码

    这篇文章主要介绍了Python实现Window路径格式转换为Linux路径格式的方法,文中通过代码示例讲解的非常详细,对大家的学习或工作有一定的帮助,需要的朋友可以参考下
    2024-07-07
  • 深入浅析Python的类

    深入浅析Python的类

    这篇文章是一篇关于python基础知识内容,主要讲述了关于类的相关知识点,有兴趣的朋友参考下。
    2018-06-06
  • Python入门之三角函数tan()函数实例详解

    Python入门之三角函数tan()函数实例详解

    这篇文章主要介绍了Python入门之三角函数tan()的相关内容,介绍了tan()函数的描述,语法以及简单实例,具有一定参考价值,需要的朋友可以了解下。
    2017-11-11
  • python GUI库图形界面开发之PyQt5布局控件QGridLayout详细使用方法与实例

    python GUI库图形界面开发之PyQt5布局控件QGridLayout详细使用方法与实例

    这篇文章主要介绍了python GUI库图形界面开发之PyQt5布局控件QGridLayout详细使用方法与实例,需要的朋友可以参考下
    2020-03-03
  • 5分钟教会你用Docker部署一个Python应用

    5分钟教会你用Docker部署一个Python应用

    Docker是一个开源项目,为开发人员和系统管理员提供了一个开放平台,可以将应用程序构建、打包为一个轻量级容器,并在任何地方运行,下面这篇文章主要给大家介绍了关于如何通过5分钟教会你用Docker部署一个Python应用,需要的朋友可以参考下
    2022-06-06

最新评论