PyTorch 如何将CIFAR100数据按类标归类保存
few-shot learning的采样
Few-shot learning 基于任务对模型进行训练,在N-way-K-shot中,一个任务中的meta-training中含有N类,每一类抽取K个样本构成support set, query set则是在刚才抽取的N类剩余的样本中sample一定数量的样本(可以是均匀采样,也可以是不均匀采样)。
对数据按类标归类
针对上述情况,我们需要使用不同类别放置在不同文件夹的数据集。但有时,数据并没有按类放置,这时就需要对数据进行处理。
下面以CIFAR100为列(不含N-way-k-shot的采样):
import os from skimage import io import torchvision as tv import numpy as np import torch def Cifar100(root): character = [[] for i in range(100)] train_set = tv.datasets.CIFAR100(root, train=True, download=True) test_set = tv.datasets.CIFAR100(root, train=False, download=True) dataset = [] for (X, Y) in zip(train_set.train_data, train_set.train_labels): # 将train_set的数据和label读入列表 dataset.append(list((X, Y))) for (X, Y) in zip(test_set.test_data, test_set.test_labels): # 将test_set的数据和label读入列表 dataset.append(list((X, Y))) for X, Y in dataset: character[Y].append(X) # 32*32*3 character = np.array(character) character = torch.from_numpy(character) # 按类打乱 np.random.seed(6) shuffle_class = np.arange(len(character)) np.random.shuffle(shuffle_class) character = character[shuffle_class] # shape = self.character.shape # self.character = self.character.view(shape[0], shape[1], shape[4], shape[2], shape[3]) # 将数据转成channel在前 meta_training, meta_validation, meta_testing = \ character[:64], character[64:80], character[80:] # meta_training : meta_validation : Meta_testing = 64类:16类:20类 dataset = [] # 释放内存 character = [] os.mkdir(os.path.join(root, 'meta_training')) for i, per_class in enumerate(meta_training): character_path = os.path.join(root, 'meta_training', 'character_' + str(i)) os.mkdir(character_path) for j, img in enumerate(per_class): img_path = character_path + '/' + str(j) + ".jpg" io.imsave(img_path, img) os.mkdir(os.path.join(root, 'meta_validation')) for i, per_class in enumerate(meta_validation): character_path = os.path.join(root, 'meta_validation', 'character_' + str(i)) os.mkdir(character_path) for j, img in enumerate(per_class): img_path = character_path + '/' + str(j) + ".jpg" io.imsave(img_path, img) os.mkdir(os.path.join(root, 'meta_testing')) for i, per_class in enumerate(meta_testing): character_path = os.path.join(root, 'meta_testing', 'character_' + str(i)) os.mkdir(character_path) for j, img in enumerate(per_class): img_path = character_path + '/' + str(j) + ".jpg" io.imsave(img_path, img) if __name__ == '__main__': root = '/home/xie/文档/datasets/cifar_100' Cifar100(root) print("-----------------")
补充:使用Pytorch对数据集CIFAR-10进行分类
主要是以下几个步骤:
1、下载并预处理数据集
2、定义网络结构
3、定义损失函数和优化器
4、训练网络并更新参数
5、测试网络效果
#数据加载和预处理 #使用CIFAR-10数据进行分类实验 import torch as t import torchvision as tv import torchvision.transforms as transforms from torchvision.transforms import ToPILImage show = ToPILImage() # 可以把Tensor转成Image,方便可视化 #定义对数据的预处理 transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5)), #归一化 ]) #训练集 trainset = tv.datasets.CIFAR10( root = './data/', train = True, download = True, transform = transform ) trainloader = t.utils.data.DataLoader( trainset, batch_size = 4, shuffle = True, num_workers = 2, ) #测试集 testset = tv.datasets.CIFAR10( root = './data/', train = False, download = True, transform = transform, ) testloader = t.utils.data.DataLoader( testset, batch_size = 4, shuffle = False, num_workers = 2, ) classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
初次下载需要一些时间,运行结束后,显示如下:
import torch.nn as nn import torch.nn.functional as F import time start = time.time()#计时 #定义网络结构 class Net(nn.Module): def __init__(self): super(Net,self).__init__() self.conv1 = nn.Conv2d(3,6,5) self.conv2 = nn.Conv2d(6,16,5) self.fc1 = nn.Linear(16*5*5,120) self.fc2 = nn.Linear(120,84) self.fc3 = nn.Linear(84,10) def forward(self,x): x = F.max_pool2d(F.relu(self.conv1(x)),2) x = F.max_pool2d(F.relu(self.conv2(x)),2) x = x.view(x.size()[0],-1) x = F.relu(self.fc1(x)) x = F.relu(self.fc2(x)) x = self.fc3(x) return x net = Net() print(net)
显示net结构如下:
#定义优化和损失 loss_func = nn.CrossEntropyLoss() #交叉熵损失函数 optimizer = t.optim.SGD(net.parameters(),lr = 0.001,momentum = 0.9) #训练网络 for epoch in range(2): running_loss = 0 for i,data in enumerate(trainloader,0): inputs,labels = data outputs = net(inputs) loss = loss_func(outputs,labels) optimizer.zero_grad() loss.backward() optimizer.step() running_loss +=loss.item() if i%2000 ==1999: print('epoch:',epoch+1,'|i:',i+1,'|loss:%.3f'%(running_loss/2000)) running_loss = 0.0 end = time.time() time_using = end - start print('finish training') print('time:',time_using)
结果如下:
下一步进行使用测试集进行网络测试:
#测试网络 correct = 0 #定义的预测正确的图片数 total = 0#总共图片个数 with t.no_grad(): for data in testloader: images,labels = data outputs = net(images) _,predict = t.max(outputs,1) total += labels.size(0) correct += (predict == labels).sum() print('测试集中的准确率为:%d%%'%(100*correct/total))
结果如下:
简单的网络训练确实要比10%的比例高一点:)
在GPU中训练:
#在GPU中训练 device = t.device('cuda:0' if t.cuda.is_available() else 'cpu') net.to(device) images = images.to(device) labels = labels.to(device) output = net(images) loss = loss_func(output,labels) loss
以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。如有错误或未考虑完全的地方,望不吝赐教。
相关文章
django2用iframe标签完成网页内嵌播放b站视频功能
这篇文章主要介绍了django2 用iframe标签完成 网页内嵌播放b站视频功能,小编觉得挺不错的,现在分享给大家,也给大家做个参考。一起跟随小编过来看看吧2018-06-06如何解决Pycharm运行报错No Python interpreter selected
这篇文章主要介绍了如何解决Pycharm运行时No Python interpreter selected问题,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教2024-05-05
最新评论