Python实战小项目之Mnist手写数字识别

 更新时间:2021年10月20日 15:01:10   作者:GSAU-深蓝工作室  
MNIST 数据集已经是一个被”嚼烂”了的数据集, 很多教程都会对它”下手”, 几乎成为一个 “典范”. 不过有些人可能对它还不是很了解, 下面通过一个小实例来带你了解它

程序流程分析图:

传播过程:

代码展示:

创建环境

使用<pip install+包名>来下载torch,torchvision包

准备数据集

设置一次训练所选取的样本数Batch_Sized的值为512,训练此时Epochs的值为8

BATCH_SIZE = 512
EPOCHS = 8
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

下载数据集

Normalize()数字归一化,转换使用的值0.1307和0.3081是MNIST数据集的全局平均值和标准偏差,这里我们将它们作为给定值。model

train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('data', train=True, download=True,
                   transform=transforms.Compose([.
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=BATCH_SIZE, shuffle=True)

下载测试集

test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('data', train=False,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=BATCH_SIZE, shuffle=True)

绘制图像

我们可以使用matplotlib来绘制其中的一些图像

examples = enumerate(test_loader)
batch_idx, (example_data, example_targets) = next(examples)
print(example_targets)
print(example_data.shape)
print(example_data)
 
import matplotlib.pyplot as plt
fig = plt.figure()
for i in range(6):
  plt.subplot(2,3,i+1)
  plt.tight_layout()
  plt.imshow(example_data[i][0], cmap='gray', interpolation='none')
  plt.title("Ground Truth: {}".format(example_targets[i]))
  plt.xticks([])
  plt.yticks([])
plt.show()

搭建神经网络

这里我们构建全连接神经网络,我们使用三个全连接(或线性)层进行前向传播。

class linearNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)
    def forward(self, x):
        x = x.view(-1, 784)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        x = F.relu(x)
        x = self.fc3(x)
        x = F.log_softmax(x, dim=1)
        return x

训练模型

首先,我们需要使用optimizer.zero_grad()手动将梯度设置为零,因为PyTorch在默认情况下会累积梯度。然后,我们生成网络的输出(前向传递),并计算输出与真值标签之间的负对数概率损失。现在,我们收集一组新的梯度,并使用optimizer.step()将其传播回每个网络参数。

def train(model, device, train_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
 
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if (batch_idx) % 30 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                       100. * batch_idx / len(train_loader), loss.item()))

测试模型

def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item() # 将一批的损失相加
            pred = output.max(1, keepdim=True)[1] # 找到概率最大的下标
            correct += pred.eq(target.view_as(pred)).sum().item()
 
    test_loss /= len(test_loader.dataset)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

将训练次数进行循环

if __name__ == '__main__':
    model = linearNet()
    optimizer = optim.Adam(model.parameters())
 
    for epoch in range(1, EPOCHS + 1):
        train(model, device, train_loader, optimizer, epoch)
        test(model, device, test_loader)

保存训练模型

torch.save(model, 'MNIST.pth')

运行结果展示:

分享人:苏云云

到此这篇关于Python实战小项目之Mnist手写数字识别的文章就介绍到这了,更多相关Python Mnist手写数字识别内容请搜索脚本之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持脚本之家!

相关文章

  • python:socket传输大文件示例

    python:socket传输大文件示例

    本篇文章主要介绍了python:socket传输大文件示例,具有一定的参考价值,有兴趣的可以了解一下,
    2017-01-01
  • Python实现文件只读属性的设置与取消

    Python实现文件只读属性的设置与取消

    这篇文章主要为大家详细介绍了Python如何实现设置文件只读与取消文件只读的功能,文中的示例代码讲解详细,感兴趣的小伙伴可以了解一下
    2023-07-07
  • python实现向微信用户发送每日一句 python实现微信聊天机器人

    python实现向微信用户发送每日一句 python实现微信聊天机器人

    这篇文章主要为大家详细介绍了python实现向微信用户发送每日一句,python调实现微信聊天机器人,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2019-03-03
  • 一篇文章搞懂python混乱的切换操作与优雅的推导式

    一篇文章搞懂python混乱的切换操作与优雅的推导式

    这篇文章主要给大家介绍了如何通过一篇文章搞懂python混乱的切换操作与优雅的推导式的相关资料,文中通过示例代码介绍的非常详细,对大家的学习具有一定的参考学习价值,需要的朋友可以参考下
    2021-08-08
  • python去除拼音声调字母,替换为字母的方法

    python去除拼音声调字母,替换为字母的方法

    今天小编就为大家分享一篇python去除拼音声调字母,替换为字母的方法,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-11-11
  • 在python中pandas读文件,有中文字符的方法

    在python中pandas读文件,有中文字符的方法

    今天小编就为大家分享一篇在python中pandas读文件,有中文字符的方法,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-12-12
  • 浅谈Python数学建模之整数规划

    浅谈Python数学建模之整数规划

    整数规划并不一定是线性规划问题的变量取整限制,对于二次规划、非线性规划问题也有变量取整限制而引出的整数规划。但在数学建模问题中所说的整数规划,通常是指整数线性规划。整数规划与线性规划的差别只是变量的整数约束。选择简单通用的编程方案,让求解器去处理吧
    2021-06-06
  • Python装饰器简单用法实例小结

    Python装饰器简单用法实例小结

    这篇文章主要介绍了Python装饰器简单用法,结合实例形式总结分析了Python装饰器的基本功能、简单用法及相关操作注意事项,需要的朋友可以参考下
    2018-12-12
  • Pycharm中出现ImportError:DLL load failed:找不到指定模块的解决方法

    Pycharm中出现ImportError:DLL load failed:找不到指定模块的解决方法

    这篇文章主要介绍了Pycharm中出现ImportError:DLL load failed:找不到指定模块的解决方法,需要的朋友可以参考下
    2019-09-09
  • 一篇文章教你用python画动态爱心表白

    一篇文章教你用python画动态爱心表白

    这篇文章主要给大家介绍了关于如何用python画动态爱心表白的相关资料,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2020-11-11

最新评论