pytorch实现图像识别(实战)

 更新时间:2022年02月18日 08:41:51   作者:AI AX AT  
这篇文章主要介绍了pytorch实现图像识别(实战),文章主要分享实现代码,但也具有一定的参考价值,需要的小伙伴可以才可以一下,希望对你有所帮助

1. 代码讲解

1.1 导库

import os.path
from os import listdir
import numpy as np
import pandas as pd
from PIL import Image
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.nn import AdaptiveAvgPool2d
from torch.utils.data.sampler import SubsetRandomSampler
from torch.utils.data import Dataset
import torchvision.transforms as transforms
from sklearn.model_selection import train_test_split

1.2 标准化、transform、设置GPU

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
normalize = transforms.Normalize(
   mean=[0.485, 0.456, 0.406],
   std=[0.229, 0.224, 0.225]
)
transform = transforms.Compose([transforms.ToTensor(), normalize])  # 转换

1.3 预处理数据

class DogDataset(Dataset):
# 定义变量
    def __init__(self, img_paths, img_labels, size_of_images):  
        self.img_paths = img_paths
        self.img_labels = img_labels
        self.size_of_images = size_of_images

# 多少长图片
    def __len__(self):
        return len(self.img_paths)

# 打开每组图片并处理每张图片
    def __getitem__(self, index):
        PIL_IMAGE = Image.open(self.img_paths[index]).resize(self.size_of_images)
        TENSOR_IMAGE = transform(PIL_IMAGE)
        label = self.img_labels[index]
        return TENSOR_IMAGE, label


print(len(listdir(r'C:\Users\AIAXIT\Desktop\DeepLearningProject\Deep_Learning_Data\dog-breed-identification\train')))
print(len(pd.read_csv(r'C:\Users\AIAXIT\Desktop\DeepLearningProject\Deep_Learning_Data\dog-breed-identification\labels.csv')))
print(len(listdir(r'C:\Users\AIAXIT\Desktop\DeepLearningProject\Deep_Learning_Data\dog-breed-identification\test')))

train_paths = []
test_paths = []
labels = []
# 训练集图片路径
train_paths_lir = r'C:\Users\AIAXIT\Desktop\DeepLearningProject\Deep_Learning_Data\dog-breed-identification\train'
for path in listdir(train_paths_lir):
    train_paths.append(os.path.join(train_paths_lir, path))  
# 测试集图片路径
labels_data = pd.read_csv(r'C:\Users\AIAXIT\Desktop\DeepLearningProject\Deep_Learning_Data\dog-breed-identification\labels.csv')
labels_data = pd.DataFrame(labels_data)  
# 把字符标签离散化,因为数据有120种狗,不离散化后面把数据给模型时会报错:字符标签过多。把字符标签从0-119编号
size_mapping = {}
value = 0
size_mapping = dict(labels_data['breed'].value_counts())
for kay in size_mapping:
    size_mapping[kay] = value
    value += 1
# print(size_mapping)
labels = labels_data['breed'].map(size_mapping)
labels = list(labels)
# print(labels)
print(len(labels))
# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(train_paths, labels, test_size=0.2)

train_set = DogDataset(X_train, y_train, (32, 32))
test_set = DogDataset(X_test, y_test, (32, 32))

train_loader = torch.utils.data.DataLoader(train_set, batch_size=64)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=64)

1.4 建立模型

class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()

        self.features = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5),  
            nn.ReLU(),
            nn.AvgPool2d(kernel_size=2, stride=2),
            nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5),
            nn.ReLU(),
            nn.AvgPool2d(kernel_size=2, stride=2)
        )
        self.classifier = nn.Sequential(
            nn.Linear(16 * 5 * 5, 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, 120)
        )

    def forward(self, x):
        batch_size = x.shape[0]
        x = self.features(x)
        x = x.view(batch_size, -1)
        x = self.classifier(x)
        return x


model = LeNet().to(device)
criterion = nn.CrossEntropyLoss().to(device)
optimizer = optim.Adam(model.parameters())
TRAIN_LOSS = []  # 损失
TRAIN_ACCURACY = []  # 准确率

1.5 训练模型

def train(epoch):
    model.train()
    epoch_loss = 0.0 # 损失
    correct = 0  # 精确率
    for batch_index, (Data, Label) in enumerate(train_loader):
    # 扔到GPU中
        Data = Data.to(device)
        Label = Label.to(device)
        output_train = model(Data)
    # 计算损失
        loss_train = criterion(output_train, Label)
        epoch_loss = epoch_loss + loss_train.item()
    # 计算精确率
        pred = torch.max(output_train, 1)[1]
        train_correct = (pred == Label).sum()
        correct = correct + train_correct.item()
    # 梯度归零、反向传播、更新参数
        optimizer.zero_grad()
        loss_train.backward()
        optimizer.step()
    print('Epoch: ', epoch, 'Train_loss: ', epoch_loss / len(train_set), 'Train correct: ', correct / len(train_set))

1.6 测试模型

和训练集差不多。

def test():
    model.eval()
    correct = 0.0
    test_loss = 0.0
    with torch.no_grad():
        for Data, Label in test_loader:
            Data = Data.to(device)
            Label = Label.to(device)
            test_output = model(Data)
            loss = criterion(test_output, Label)
            pred = torch.max(test_output, 1)[1]
            test_correct = (pred == Label).sum()
            correct = correct + test_correct.item()
            test_loss = test_loss + loss.item()
    print('Test_loss: ', test_loss / len(test_set), 'Test correct: ', correct / len(test_set))

1.7结果

epoch = 10
for n_epoch in range(epoch):
    train(n_epoch)
test()

到此这篇关于pytorch实现图像识别(实战)的文章就介绍到这了,更多相关pytorch实现图像识别内容请搜索脚本之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持脚本之家!

相关文章

  • python列表排序用 sort()和sorted()的区别

    python列表排序用 sort()和sorted()的区别

    这篇文章主要介绍了python列表排序用 sort()和sorted()的区别,主要比较 Python 中用于列表排序的两种函数 sort() 和 sorted(),选择合适的排序函数,下文详细内容需要的小伙伴可以参考一下
    2022-03-03
  • Python csv文件记录流程代码解析

    Python csv文件记录流程代码解析

    这篇文章主要介绍了Python csv文件记录流程代码解析,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2020-07-07
  • 200个Python 标准库总结

    200个Python 标准库总结

    这篇文章主要给大家分享了200个Python 标准库总结,主要对文本、数据类型、数学等多个类型总结,既有一定的参考价值,需要的小伙伴可以参考一下
    2022-01-01
  • Python多层嵌套list的递归处理方法(推荐)

    Python多层嵌套list的递归处理方法(推荐)

    下面小编就为大家带来一篇Python多层嵌套list的递归处理方法(推荐)。小编觉得挺不错的,现在就分享给大家,也给大家做个参考。一起跟随小编过来看看吧
    2016-06-06
  • Python获取指定字符前面的所有字符方法

    Python获取指定字符前面的所有字符方法

    下面小编就为大家分享一篇Python获取指定字符前面的所有字符方法,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-05-05
  • Python字典操作简明总结

    Python字典操作简明总结

    这篇文章主要介绍了Python字典操作简明总结,本文总结了创建字典 、创建一个"默认"字典、遍历字典、获得value值、成员操作符:in或not in 、更新字典、删除字典等常用操作,需要的朋友可以参考下
    2015-04-04
  • Python搭建代理IP池实现存储IP的方法

    Python搭建代理IP池实现存储IP的方法

    这篇文章主要介绍了Python搭建代理IP池实现存储IP的方法,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2019-10-10
  • python3正则提取字符串里的中文实例

    python3正则提取字符串里的中文实例

    今天小编就为大家分享一篇python3正则提取字符串里的中文实例,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2019-01-01
  • 老生常谈python的私有公有属性(必看篇)

    老生常谈python的私有公有属性(必看篇)

    下面小编就为大家带来一篇老生常谈python的私有公有属性(必看篇)。小编觉得挺不错的,现在就分享给大家,也给大家做个参考。一起跟随小编过来看看吧
    2017-06-06
  • python中关于py文件之间相互import的问题及解决方法

    python中关于py文件之间相互import的问题及解决方法

    这篇文章主要介绍了python中关于py文件之间相互import的问题,本文用一个例子演示下如何解决python中循环引用的问题,需要的朋友可以参考下
    2022-02-02

最新评论