pytorch加载自己的图片数据集的2种方法详解

 更新时间:2022年06月11日 11:19:30   作者:_-周-_  
数据预处理在解决深度学习问题的过程中,往往需要花费大量的时间和精力,下面这篇文章主要给大家介绍了关于pytorch加载自己的图片数据集的2种方法,文中通过示例代码介绍的非常详细,需要的朋友可以参考下

pytorch加载图片数据集有两种方法。

1.ImageFolder 适合于分类数据集,并且每一个类别的图片在同一个文件夹, ImageFolder加载的数据集, 训练数据为文件件下的图片, 训练标签是对应的文件夹, 每个文件夹为一个类别

导入ImageFolder()包
from torchvision.datasets import ImageFolder

在Flower_Orig_dataset文件夹下有flower_orig 和 sunflower这两个文件夹, 这两个文件夹下放着同一个类别的图片。 使用 ImageFolder 加载的图片, 就会返回图片信息和对应的label信息, 但是label信息是根据文件夹给出的, 如flower_orig就是标签0, sunflower就是标签1。

ImageFolder 加载数据集

1. 导入包和设置transform

import torch
from torchvision import transforms, datasets
import torch.nn as nn
from torch.utils.data import DataLoader
 
transforms = transforms.Compose([
    transforms.Resize(256),    # 将图片短边缩放至256,长宽比保持不变:
    transforms.CenterCrop(224),   #将图片从中心切剪成3*224*224大小的图片
    transforms.ToTensor()          #把图片进行归一化,并把数据转换成Tensor类型
]) 

2.加载数据集: 将分类图片的父目录作为路径传递给ImageFolder(), 并传入transform。这样就有了要加载的数据集, 之后就可以使用DataLoader加载数据, 并构建网络训练。

path = r'D:\数据集\Flower_Orig_dataset'
data_train = datasets.ImageFolder(path, transform=transforms)
data_loader = DataLoader(data_train, batch_size=64, shuffle=True)
for i, data in enumerate(data_loader):
    images, labels = data
    print(images.shape)
    print(labels.shape)
    break

使用pytorch提供的Dataset类创建自己的数据集。

具体步骤:

1.  首先要有一个txt文件, 这个文件格式是: 图片路径  标签.  这样的格式, 所以使用os库, 遍历自己的图片名, 并把标签和图片路径写入txt文件。

2. 有了这个txt文件, 我们就可以在类里面构造我们的数据集.

2.1    把图片路径和图片标签分割开, 有两个列表, 一个列表是图片路径名, 一个列表是标签号, 有一点就是第 i 个图片列表和 第 i 个标签是对应的

3. 重写__len__方法  和  __getitem__方法

3.1 getitem方法中, 获得对应的图片路径,并用PIL库读取文件把图片transfrom后, 在getitem函数中返回读取的图片和标签即可

4.就可以构建数据集实例和加载数据集.

 定义一个用来生成[ 图片路径 标签] 这样的txt文件函数

def make_txt(root, file_name, label):
    path = os.path.join(root, file_name)
    data = os.listdir(path)
    f = open(path+'\\'+'f.txt', 'w')
    for line in data:
        f.write(line+' '+str(label)+'\n')
    f.close()
#调用函数生成两个文件夹下的txt文件
make_txt(path, file_name='flower_orig', label=0)
make_txt(path, file_name='sunflower', label=1)

将连个txt文件合并成一个txt文件,表示数据集所有的图片和标签

def link_txt(file1, file2):
    txt_list = []
    path = r'D:\数据集\Flower_Orig_dataset\data.txt'
 
    f = open(path, 'a')
 
    f1 = open(file1, 'r')
    data1 = f1.readlines()
    for line in data1:
        txt_list.append(line)
 
    f2 = open(file2, 'r')
    data2 = f2.readlines()
    for line in data2:
        txt_list.append(line)
 
    for line in txt_list:
        f.write(line)
 
    f.close()
    f1.close()
    f2.close()
 
#调用函数, 将两个文件夹下的txt文件合并
file1 = r'D:\数据集\Flower_Orig_dataset\flower_orig\f.txt'
file2 = r'D:\数据集\Flower_Orig_dataset\sunflower\f.txt'
link_txt(file1=file1, file2=file2)

现在我们已经有了我们制作数据集所需要的txt文件, 接下来要做的即使继承Dataset类, 来构建自己的数据集 , 别忘了前面说的 构建数据集步骤, 在__getitem__函数中, 需要拿到图片路径和标签, 并且用PIL库方法读取图片,对图片进行transform转换后,返回图片信息和标签信息

Dataset加载数据集

我们读取图片的根目录, 在根目录下有所有图片的txt文件, 拿到txt文件后, 先读取txt文件, 之后遍历txt文件中的每一行, 首先去除掉尾部的换行符, 在以空格切分,前半部分是图片名称, 后半部分是图片标签, 当图片名称和根目录结合,就得到了我们的图片路径   
class MyDataset(Dataset):
    def __init__(self, img_path, transform=None):
        super(MyDataset, self).__init__()
        self.root = img_path
 
        self.txt_root = self.root + 'data.txt'
        f = open(self.txt_root, 'r')
        data = f.readlines()
 
        imgs = []
        labels = []
        for line in data:
            line = line.rstrip()
            word = line.split()
            imgs.append(os.path.join(self.root, word[1], word[0]))
 
            labels.append(word[1])
        self.img = imgs
        self.label = labels
        self.transform = transform
 
    def __len__(self):
        return len(self.label)
 
    def __getitem__(self, item):
        img = self.img[item]
        label = self.label[item]
 
        img = Image.open(img).convert('RGB')
 
        #此时img是PIL.Image类型   label是str类型
 
        if transforms is not None:
            img = self.transform(img)
 
        label = np.array(label).astype(np.int64)
        label = torch.from_numpy(label)
        
        return img, label

 加载我们的数据集:

path = r'D:\数据集\Flower_Orig_dataset'
dataset = MyDataset(path, transform=transform)
 
data_loader = DataLoader(dataset=dataset, batch_size=64, shuffle=True)

接下来我们就可以构建我们的网络架构:

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3,16,3)
        self.maxpool = nn.MaxPool2d(2,2)
        self.conv2 = nn.Conv2d(16,5,3)
 
        self.relu = nn.ReLU()
        self.fc1 = nn.Linear(55*55*5, 1200)
        self.fc2 = nn.Linear(1200,64)
        self.fc3 = nn.Linear(64,2)
 
    def forward(self,x):
        x = self.maxpool(self.relu(self.conv1(x)))    #113
        x = self.maxpool(self.relu(self.conv2(x)))    #55
        x = x.view(-1, self.num_flat_features(x))
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.fc3(x)
        return x
    
    
    def num_flat_features(self, x):
        size = x.size()[1:]
        num_features = 1
        for s in size:
            num_features *= s
 
        return num_features
 

 训练我们的网络:

model = Net()
 
criterion = torch.nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)
 
 
epochs = 10
for epoch in range(epochs):
    running_loss = 0.0
    for i, data in enumerate(data_loader):
        images, label = data
 
        out = model(images)
 
        loss = criterion(out, label)
 
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
 
        running_loss += loss.item()
        if(i+1)%10 == 0:
            print('[%d  %5d]   loss: %.3f'%(epoch+1, i+1, running_loss/100))
            running_loss = 0.0
 
print('finished train')

 保存网络模型(这里不止是保存参数,还保存了网络结构)

#保存模型
torch.save(net, 'model_name.pth')   #保存的是模型, 不止是w和b权重值
 
# 读取模型
model = torch.load('model_name.pth')

总结

到此这篇关于pytorch加载自己的图片数据集的2种方法的文章就介绍到这了,更多相关pytorch加载图片数据集内容请搜索脚本之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持脚本之家!

相关文章

  • python-redis-lock实现锁自动续期的源码逻辑

    python-redis-lock实现锁自动续期的源码逻辑

    这篇文章主要介绍了python-redis-lock实现锁自动续期的源码逻辑,其中用到了多线程threading、弱引用weakref和Lua脚本等相关知识,需要的朋友可以参考下
    2024-07-07
  • python读写json文件的简单实现

    python读写json文件的简单实现

    这篇文章主要介绍了python读写json文件的简单实现,实例分析了各种读写json的方法和技巧,有兴趣的可以了解一下
    2017-04-04
  • 详解python 破解网站反爬虫的两种简单方法

    详解python 破解网站反爬虫的两种简单方法

    这篇文章主要介绍了详解python 破解网站反爬虫的两种简单方法,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2020-02-02
  • Python+Selenium随机生成手机验证码并检查页面上是否弹出重复手机号码提示框

    Python+Selenium随机生成手机验证码并检查页面上是否弹出重复手机号码提示框

    这篇文章主要介绍了Python+Selenium随机生成手机验证码并检查页面上是否弹出重复手机号码提示框,本文通过实例代码给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下
    2020-09-09
  • 了解一下python内建模块collections

    了解一下python内建模块collections

    这篇文章主要介绍了Python内建模块——collections的相关资料,帮助大家更好的理解和使用python,感兴趣的朋友可以了解下
    2020-09-09
  • Ubuntu下升级 python3.7.1流程备忘(推荐)

    Ubuntu下升级 python3.7.1流程备忘(推荐)

    这篇文章主要介绍了Ubuntu下升级 python3.7.1流程备忘,本文通过实例代码给大家介绍的非常详细,具有一定的参考借鉴价值,需要的朋友可以参考下
    2018-12-12
  • python中绕过反爬虫的方法总结

    python中绕过反爬虫的方法总结

    在本篇文章里小编给大家整理的是一篇关于python中绕过反爬虫的方法总结内容,需要的朋友们可以参考下。
    2020-11-11
  • Python读取csv、Excel文件生成图表的方法

    Python读取csv、Excel文件生成图表的方法

    这篇文章主要介绍了Python读取csv、Excel文件生成图表,本文通过示例代码给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下
    2023-07-07
  • 详解tensorflow之过拟合问题实战

    详解tensorflow之过拟合问题实战

    这篇文章主要介绍了详解tensorflow之过拟合问题实战,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2020-11-11
  • Python qrcode 生成一个二维码的实例详解

    Python qrcode 生成一个二维码的实例详解

    在本篇文章里小编给大家整理的是关于Python qrcode 生成一个二维码的实例内容,需要的朋友们可以学习参考下。
    2020-02-02

最新评论