python深度学习之多标签分类器及pytorch实现源码

 更新时间:2022年01月30日 09:14:08   作者:鬼道2022  
这篇文章主要为大家介绍了python深度学习之多标签分类器的使用说明及pytorch的实现源码,有需要的朋友可以借鉴参考下,希望能够有所帮助

多标签分类器

多标签分类任务与多分类任务有所不同,多分类任务是将一个实例分到某个类别中,多标签分类任务是将某个实例分到多个类别中。多标签分类任务有有两大特点:

  • 类标数量不确定,有些样本可能只有一个类标,有些样本的类标可能高达几十甚至上百个
  • 类标之间相互依赖,例如包含蓝天类标的样本很大概率上包含白云

如下图所示,即为一个多标签分类学习的一个例子,一张图片里有多个类别,房子,树,云等,深度学习模型需要将其一一分类识别出来。

多标签分类器损失函数

代码实现

针对图像的多标签分类器pytorch的简化代码实现如下所示。因为图像的多标签分类器的数据集比较难获取,所以可以通过对mnist数据集中的每个图片打上特定的多标签,例如类别1的多标签可以为[1,1,0,1,0,1,0,0,1],然后再利用重新打标后的数据集训练出一个mnist的多标签分类器。

from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Dataset
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import os
class CNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.Sq1 = nn.Sequential(         
            nn.Conv2d(in_channels=1, out_channels=16, kernel_size=5, stride=1, padding=2),   # (16, 28, 28)                           #  output: (16, 28, 28)
            nn.ReLU(),                    
            nn.MaxPool2d(kernel_size=2),    # (16, 14, 14)
        )
        self.Sq2 = nn.Sequential(
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5, stride=1, padding=2),  # (32, 14, 14)
            nn.ReLU(),                      
            nn.MaxPool2d(2),                # (32, 7, 7)
        )
        self.out = nn.Linear(32 * 7 * 7, 100)  
    def forward(self, x):
        x = self.Sq1(x)
        x = self.Sq2(x)
        x = x.view(x.size(0), -1)    
        x = self.out(x)
        ## Sigmoid activation   
        output = F.sigmoid(x)  # 1/(1+e**(-x))
        return output
def loss_fn(pred, target):
    return -(target * torch.log(pred) + (1 - target) * torch.log(1 - pred)).sum()
def multilabel_generate(label):
    Y1 = F.one_hot(label, num_classes = 100)
    Y2 = F.one_hot(label+10, num_classes = 100)
    Y3 = F.one_hot(label+50, num_classes = 100) 	
    multilabel = Y1+Y2+Y3
    return multilabel
        
# def multilabel_generate(label):
# 	multilabel_dict = {}
# 	multi_list = []
# 	for i in range(label.shape[0]):
# 		multi_list.append(multilabel_dict[label[i].item()])
# 	multilabel_tensor = torch.tensor(multi_list)
#     return multilabel
def train():
    epoches = 10
    mnist_net = CNN()
    mnist_net.train()
    opitimizer = optim.SGD(mnist_net.parameters(), lr=0.002)
    mnist_train = datasets.MNIST("mnist-data", train=True, download=True, transform=transforms.ToTensor())
    train_loader = torch.utils.data.DataLoader(mnist_train, batch_size= 128, shuffle=True)
    for epoch in range(epoches):
    	loss = 0 
    	for batch_X, batch_Y in train_loader:
    		opitimizer.zero_grad()
    		outputs = mnist_net(batch_X)
    		loss = loss_fn(outputs, multilabel_generate(batch_Y)) / batch_X.shape[0]
    		loss.backward()
    		opitimizer.step()
    		print(loss)
if __name__ == '__main__':
	train()

以上就是python深度学习之多标签分类器及pytorch源码的详细内容,更多关于多标签分类器pytorch源码的资料请关注脚本之家其它相关文章!

相关文章

  • 一篇文章带你了解谷歌这些大厂是怎么写 python 代码的

    一篇文章带你了解谷歌这些大厂是怎么写 python 代码的

    这篇文章主要介绍了谷歌这些大厂怎么写python代码,我们写代码,往往还是按照其它语言的思维习惯来写,那样的写法不仅运行速度慢,代码读起来也费尽,给人一种拖泥带水的感觉,需要的朋友可以参考下
    2021-09-09
  • 解析Python扩展模块的加速方案

    解析Python扩展模块的加速方案

    这章我们来介绍Python的扩展名之ctypes,教大家认识ctypes,有需要的朋友可以借鉴参考下,希望可以有所帮助,祝大家多多进步,早日升职加薪
    2021-09-09
  • Python 3.6 -win64环境安装PIL模块的教程

    Python 3.6 -win64环境安装PIL模块的教程

    PIL功能非常强大,但API却非常简单易用。这篇文章主要介绍了Python 3.6 -win64环境安装PIL模块的教程,需要的朋友可以参考下
    2019-06-06
  • Python实现贪吃蛇小游戏(单人模式)

    Python实现贪吃蛇小游戏(单人模式)

    这篇文章主要为大家详细介绍了Python实现单人模式的贪吃蛇小游戏,文中示例代码介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2021-09-09
  • Flask 数据库集成的介绍

    Flask 数据库集成的介绍

    这篇文章主要给大家分享了Flask 数据库集成的介绍,数据库是大多数 Web 应用的基础设施,只要想把数据存储下来,就离不开数据库,下面将一起学习一下如何给 Flask 应用添加数据库支持。下面详细内容,需要的朋友可以参考一下
    2021-11-11
  • Python开多次方根的案例

    Python开多次方根的案例

    这篇文章主要介绍了Python开多次方根的案例,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2021-03-03
  • 怎么用Python识别手势数字

    怎么用Python识别手势数字

    今天给大家带来的文章是怎么用Python识别手势数字,文中有非常详细的图文示例,对正在学习python的小伙伴们很有帮助,需要的朋友可以参考下
    2021-06-06
  • Python sklearn 中的 make_blobs() 函数示例详解

    Python sklearn 中的 make_blobs() 函数示例详解

    make_blobs() 是 sklearn.datasets中的一个函数,这篇文章主要介绍了Python sklearn 中的 make_blobs() 函数,本文结合实例代码给大家介绍的非常详细,需要的朋友可以参考下
    2023-02-02
  • Pandas拼接concat使用方法

    Pandas拼接concat使用方法

    当我们需要将两个Pandas DataFrame对象合并为一个时,就需要使用Pandas拼接函数,本文主要介绍了Pandas拼接concat使用方法,感兴趣的可以了解一下
    2023-12-12
  • Python基本知识之datetime模块详解

    Python基本知识之datetime模块详解

    这篇文章主要给大家介绍了关于Python基本知识之datetime模块的相关资料,Python内置的时间模块datetime包含下面的模块包含六个类和两个常数,提供了用于处理日期和时间的类和对应的方法,一般用于处理年、月、日、时、分、秒的统计和计算等需求,需要的朋友可以参考下
    2023-08-08

最新评论