PyTorch 如何将CIFAR100数据按类标归类保存

 更新时间:2021年05月10日 09:12:51   作者:Xie_learning  
这篇文章主要介绍了PyTorch 将CIFAR100数据按类标归类保存的操作,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧

few-shot learning的采样

Few-shot learning 基于任务对模型进行训练,在N-way-K-shot中,一个任务中的meta-training中含有N类,每一类抽取K个样本构成support set, query set则是在刚才抽取的N类剩余的样本中sample一定数量的样本(可以是均匀采样,也可以是不均匀采样)。

对数据按类标归类

针对上述情况,我们需要使用不同类别放置在不同文件夹的数据集。但有时,数据并没有按类放置,这时就需要对数据进行处理。

下面以CIFAR100为列(不含N-way-k-shot的采样):

import os
from skimage import io
import torchvision as tv
import numpy as np
import torch
def Cifar100(root):
    character = [[] for i in range(100)]
    train_set = tv.datasets.CIFAR100(root, train=True, download=True)
    test_set = tv.datasets.CIFAR100(root, train=False, download=True)
    dataset = []
    for (X, Y) in zip(train_set.train_data, train_set.train_labels):  # 将train_set的数据和label读入列表
        dataset.append(list((X, Y)))
    for (X, Y) in zip(test_set.test_data, test_set.test_labels):  # 将test_set的数据和label读入列表
        dataset.append(list((X, Y)))
    for X, Y in dataset:
        character[Y].append(X)  # 32*32*3
    character = np.array(character)
    character = torch.from_numpy(character)
    # 按类打乱
    np.random.seed(6)
    shuffle_class = np.arange(len(character))
    np.random.shuffle(shuffle_class)
    character = character[shuffle_class]
    # shape = self.character.shape
    # self.character = self.character.view(shape[0], shape[1], shape[4], shape[2], shape[3])  # 将数据转成channel在前
    meta_training, meta_validation, meta_testing = \
    character[:64], character[64:80], character[80:]  # meta_training : meta_validation : Meta_testing = 64类:16类:20类
    dataset = []  # 释放内存
    character = []
    os.mkdir(os.path.join(root, 'meta_training'))
    for i, per_class in enumerate(meta_training):
        character_path = os.path.join(root, 'meta_training', 'character_' + str(i))
        os.mkdir(character_path)
        for j, img in enumerate(per_class):
            img_path = character_path + '/' + str(j) + ".jpg"
            io.imsave(img_path, img)
    os.mkdir(os.path.join(root, 'meta_validation'))
    for i, per_class in enumerate(meta_validation):
        character_path = os.path.join(root, 'meta_validation', 'character_' + str(i))
        os.mkdir(character_path)
        for j, img in enumerate(per_class):
            img_path = character_path + '/' + str(j) + ".jpg"
            io.imsave(img_path, img)
    os.mkdir(os.path.join(root, 'meta_testing'))
    for i, per_class in enumerate(meta_testing):
        character_path = os.path.join(root, 'meta_testing', 'character_' + str(i))
        os.mkdir(character_path)
        for j, img in enumerate(per_class):
            img_path = character_path + '/' + str(j) + ".jpg"
            io.imsave(img_path, img)
if __name__ == '__main__':
    root = '/home/xie/文档/datasets/cifar_100'
    Cifar100(root)
    print("-----------------")

补充:使用Pytorch对数据集CIFAR-10进行分类

主要是以下几个步骤:

1、下载并预处理数据集

2、定义网络结构

3、定义损失函数和优化器

4、训练网络并更新参数

5、测试网络效果

#数据加载和预处理
#使用CIFAR-10数据进行分类实验
import torch as t
import torchvision as tv
import torchvision.transforms as transforms
from torchvision.transforms import ToPILImage
show = ToPILImage() # 可以把Tensor转成Image,方便可视化
 
#定义对数据的预处理
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5)),  #归一化
])
 
#训练集
trainset = tv.datasets.CIFAR10(
    root = './data/',
    train = True,
    download = True,
    transform = transform
)
 
trainloader = t.utils.data.DataLoader(
    trainset,
    batch_size = 4,
    shuffle = True,
    num_workers = 2,
)
 
#测试集
testset = tv.datasets.CIFAR10(
    root = './data/',
    train = False,
    download = True,
    transform = transform,
)
testloader = t.utils.data.DataLoader(
    testset,
    batch_size = 4,
    shuffle = False,
    num_workers = 2,
)
 
classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

初次下载需要一些时间,运行结束后,显示如下:

import torch.nn as nn
import torch.nn.functional as F
import time
start = time.time()#计时
#定义网络结构
class Net(nn.Module):
    def __init__(self):
        super(Net,self).__init__()
        self.conv1 = nn.Conv2d(3,6,5)
        self.conv2 = nn.Conv2d(6,16,5)
        self.fc1 = nn.Linear(16*5*5,120)
        self.fc2 = nn.Linear(120,84)
        self.fc3 = nn.Linear(84,10)
        
    def forward(self,x):
        x = F.max_pool2d(F.relu(self.conv1(x)),2)
        x = F.max_pool2d(F.relu(self.conv2(x)),2)
        
        x = x.view(x.size()[0],-1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
net = Net()
print(net)

显示net结构如下:

#定义优化和损失
loss_func = nn.CrossEntropyLoss()  #交叉熵损失函数
optimizer = t.optim.SGD(net.parameters(),lr = 0.001,momentum = 0.9)
 
#训练网络
for epoch in range(2):
    running_loss = 0
    for i,data in enumerate(trainloader,0):
        inputs,labels = data
       
        outputs = net(inputs)
        loss = loss_func(outputs,labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        running_loss +=loss.item()
        if i%2000 ==1999:
            print('epoch:',epoch+1,'|i:',i+1,'|loss:%.3f'%(running_loss/2000))
            running_loss = 0.0
end = time.time()
time_using = end - start
print('finish training')
print('time:',time_using)

结果如下:

下一步进行使用测试集进行网络测试:

#测试网络
correct = 0 #定义的预测正确的图片数
total = 0#总共图片个数
with t.no_grad():
    for data in testloader:
        images,labels = data
        outputs = net(images)
        _,predict = t.max(outputs,1)
        total += labels.size(0)
        correct += (predict == labels).sum()
print('测试集中的准确率为:%d%%'%(100*correct/total))

结果如下:

简单的网络训练确实要比10%的比例高一点:)

在GPU中训练:

#在GPU中训练
device = t.device('cuda:0' if t.cuda.is_available() else 'cpu')
 
net.to(device)
images = images.to(device)
labels = labels.to(device)
 
output = net(images)
loss = loss_func(output,labels)
 
loss

以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。如有错误或未考虑完全的地方,望不吝赐教。

相关文章

  • django2用iframe标签完成网页内嵌播放b站视频功能

    django2用iframe标签完成网页内嵌播放b站视频功能

    这篇文章主要介绍了django2 用iframe标签完成 网页内嵌播放b站视频功能,小编觉得挺不错的,现在分享给大家,也给大家做个参考。一起跟随小编过来看看吧
    2018-06-06
  • python 多线程爬取壁纸网站的示例

    python 多线程爬取壁纸网站的示例

    这篇文章主要介绍了python 多线程爬取壁纸网站的示例,帮助大家更好的理解和学习使用python,感兴趣的朋友可以了解下
    2021-02-02
  • Python实现数据库编程方法详解

    Python实现数据库编程方法详解

    这篇文章主要介绍了Python实现数据库编程方法,较为详细的总结了Python数据库编程涉及的各种常用技巧与相关组件,需要的朋友可以参考下
    2015-06-06
  • Python如何读写字节数据

    Python如何读写字节数据

    这篇文章主要介绍了Python如何读写字节数据,文中讲解非常细致,代码帮助大家更好的理解和学习,感兴趣的朋友可以了解下
    2020-08-08
  • 如何解决Pycharm运行报错No Python interpreter selected问题

    如何解决Pycharm运行报错No Python interpreter selected

    这篇文章主要介绍了如何解决Pycharm运行时No Python interpreter selected问题,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教
    2024-05-05
  • 使用Python脚本备份华为交换机的配置信息

    使用Python脚本备份华为交换机的配置信息

    在现代网络管理中,备份交换机的配置信息是一项至关重要的任务,备份可以确保在交换机发生故障或配置错误时,能够迅速恢复到之前的工作状态,本文将详细介绍如何使用Python脚本备份华为交换机的配置信息,需要的朋友可以参考下
    2024-06-06
  • python类继承与子类实例初始化用法分析

    python类继承与子类实例初始化用法分析

    这篇文章主要介绍了python类继承与子类实例初始化用法,实例分析了Python类的使用技巧,具有一定参考借鉴价值,需要的朋友可以参考下
    2015-04-04
  • Python标准库之os模块详解

    Python标准库之os模块详解

    Python的os模块是用于与操作系统进行交互的模块,它提供了许多函数和方法来执行文件和目录操作、进程管理、环境变量访问等,本文详细介绍了Python标准库中os模块,感兴趣的同学跟着小编一起来看看吧
    2023-08-08
  • python调用函数、类和文件操作简单实例总结

    python调用函数、类和文件操作简单实例总结

    这篇文章主要介绍了python调用函数、类和文件操作,结合简单实例形式总结分析了Python调用函数、类和文件操作的各种常见操作技巧,需要的朋友可以参考下
    2019-11-11
  • python高级搜索实现高效搜索GitHub资源

    python高级搜索实现高效搜索GitHub资源

    这篇文章主要为大家介绍了python高级搜索来高效搜索GitHub,从而高效获取所需资源,有需要的朋友可以借鉴参考下,希望能够有所帮助
    2021-11-11

最新评论