基于BCEWithLogitsLoss样本不均衡的处理方案

 更新时间:2021年05月13日 10:52:49   作者:ucas_fhx  
这篇文章主要介绍了BCEWithLogitsLoss样本不均衡的处理方案,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教

最近在做deepfake检测任务(可以将其视为二分类问题,label为1和0),遇到了正负样本不均衡的问题,正样本数目是负样本的5倍,这样会导致FP率较高。

尝试将正样本的loss权重增高,看BCEWithLogitsLoss的源码

Examples::
 
    >>> target = torch.ones([10, 64], dtype=torch.float32)  # 64 classes, batch size = 10
    >>> output = torch.full([10, 64], 0.999)  # A prediction (logit)
    >>> pos_weight = torch.ones([64])  # All weights are equal to 1
    >>> criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)
    >>> criterion(output, target)  # -log(sigmoid(0.999))
    tensor(0.3135)
 
Args:
    weight (Tensor, optional): a manual rescaling weight given to the loss
        of each batch element. If given, has to be a Tensor of size `nbatch`.
    size_average (bool, optional): Deprecated (see :attr:`reduction`). By default,
        the losses are averaged over each loss element in the batch. Note that for
        some losses, there are multiple elements per sample. If the field :attr:`size_average`
        is set to ``False``, the losses are instead summed for each minibatch. Ignored
        when reduce is ``False``. Default: ``True``
    reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the
        losses are averaged or summed over observations for each minibatch depending
        on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per
        batch element instead and ignores :attr:`size_average`. Default: ``True``
    reduction (string, optional): Specifies the reduction to apply to the output:
        ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
        ``'mean'``: the sum of the output will be divided by the number of
        elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average`
        and :attr:`reduce` are in the process of being deprecated, and in the meantime,
        specifying either of those two args will override :attr:`reduction`. Default: ``'mean'``
    pos_weight (Tensor, optional): a weight of positive examples.
            Must be a vector with length equal to the number of classes.

对其中的参数pos_weight的使用存在疑惑,BCEloss里的例子pos_weight = torch.ones([64]) # All weights are equal to 1,不懂为什么会有64个class,因为BCEloss是针对二分类问题的loss,后经过检索,得知还有多标签分类

多标签分类就是多个标签,每个标签有两个label(0和1),这类任务同样可以使用BCEloss。

现在讲一下BCEWithLogitsLoss里的pos_weight使用方法

比如我们有正负两类样本,正样本数量为100个,负样本为400个,我们想要对正负样本的loss进行加权处理,将正样本的loss权重放大4倍,通过这样的方式缓解样本不均衡问题。

criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([4]))
 
# pos_weight (Tensor, optional): a weight of positive examples.
#            Must be a vector with length equal to the number of classes.

pos_weight里是一个tensor列表,需要和标签个数相同,比如我们现在是二分类,只需要将正样本loss的权重写上即可。

如果是多标签分类,有64个标签,则

Examples::
 
    >>> target = torch.ones([10, 64], dtype=torch.float32)  # 64 classes, batch size = 10
    >>> output = torch.full([10, 64], 0.999)  # A prediction (logit)
    >>> pos_weight = torch.ones([64])  # All weights are equal to 1
    >>> criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)
    >>> criterion(output, target)  # -log(sigmoid(0.999))
    tensor(0.3135)

补充:Pytorch —— BCEWithLogitsLoss()的一些问题

一、等价表达

1、pytorch:

torch.sigmoid() + torch.nn.BCELoss()

2、自己编写

def ce_loss(y_pred, y_train, alpha=1):
    
    p = torch.sigmoid(y_pred)
    # p = torch.clamp(p, min=1e-9, max=0.99)  
    loss = torch.sum(- alpha * torch.log(p) * y_train \
           - torch.log(1 - p) * (1 - y_train))/len(y_train)
    return loss~

3、验证

import torch
import torch.nn as nn
torch.cuda.manual_seed(300)       # 为当前GPU设置随机种子
torch.manual_seed(300)            # 为CPU设置随机种子
def ce_loss(y_pred, y_train, alpha=1):
   # 计算loss
   p = torch.sigmoid(y_pred)
   # p = torch.clamp(p, min=1e-9, max=0.99)
   loss = torch.sum(- alpha * torch.log(p) * y_train \
          - torch.log(1 - p) * (1 - y_train))/len(y_train)
   return loss
py_lossFun = nn.BCEWithLogitsLoss()
input = torch.randn((10000,1), requires_grad=True)
target = torch.ones((10000,1))
target.requires_grad_(True)
py_loss = py_lossFun(input, target)
py_loss.backward()
print("*********BCEWithLogitsLoss***********")
print("loss: ")
print(py_loss.item())
print("梯度: ")
print(input.grad)
input = input.detach()
input.requires_grad_(True)
self_loss = ce_loss(input, target)
self_loss.backward()
print("*********SelfCELoss***********")
print("loss: ")
print(self_loss.item())
print("梯度: ")
print(input.grad)

测试结果:

在这里插入图片描述

– 由上结果可知,我编写的loss和pytorch中提供的j基本一致。

– 但是仅仅这样就可以了吗?NO! 下面介绍BCEWithLogitsLoss()的强大之处:

– BCEWithLogitsLoss()具有很好的对nan的处理能力,对于我写的代码(四层神经网络,层之间的激活函数采用的是ReLU,输出层激活函数采用sigmoid(),由于数据处理的问题,所以会导致我们编写的CE的loss出现nan:原因如下:

–首先神经网络输出的pre_target较大,就会导致sigmoid之后的p为1,则torch.log(1 - p)为nan;

– 使用clamp(函数虽然会解除这个nan,但是由于在迭代过程中,网络输出可能越来越大(层之间使用的是ReLU),则导致我们写的loss陷入到某一个数值而无法进行优化。但是BCEWithLogitsLoss()对这种情况下出现的nan有很好的处理,从而得到更好的结果。

– 我此实验的目的是为了比较CE和FL的区别,自己编写FL,则必须也要自己编写CE,不能使用BCEWithLogitsLoss()。

二、使用场景

二分类 + sigmoid()

使用sigmoid作为输出层非线性表达的分类问题(虽然可以处理多分类问题,但是一般用于二分类,并且最后一层只放一个节点)

三、注意事项

输入格式

要求输入的input和target均为float类型

以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • Python 探针的实现原理

    Python 探针的实现原理

    本文将简单讲述一下 Python 探针的实现原理。 同时为了验证这个原理,我们也会一起来实现一个简单的统计指定函数执行时间的探针程序。
    2016-04-04
  • python scipy卷积运算的实现方法

    python scipy卷积运算的实现方法

    这篇文章主要介绍了python scipy卷积运算的实现方法,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2019-09-09
  • python爬虫超时的处理的实例

    python爬虫超时的处理的实例

    今天小编就为大家分享一篇python爬虫超时的处理的实例,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-12-12
  • 把大数据数字口语化(python与js)两种实现

    把大数据数字口语化(python与js)两种实现

    当出现万以上的整型数字时,经常要把它们口语化比较直观。下面分享两段代码,python与js的
    2013-02-02
  • Pycharm创建python文件自动添加日期作者等信息(步骤详解)

    Pycharm创建python文件自动添加日期作者等信息(步骤详解)

    这篇文章主要介绍了Pycharm创建python文件自动添加日期作者等信息(步骤详解),本文分步骤给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下
    2021-02-02
  • python爬虫lxml库解析xpath网页过程示例

    python爬虫lxml库解析xpath网页过程示例

    这篇文章主要为大家介绍了python爬虫lxml库解析xpath网页的过程示例,有需要的朋友可以借鉴参考下,希望能够有所帮助,祝大家多多进步,早日升职加薪
    2022-05-05
  • Python实现修改文件内容的方法分析

    Python实现修改文件内容的方法分析

    这篇文章主要介绍了Python实现修改文件内容的方法,结合实例形式分析了Python文件读写、字符串替换及shell方法调用等相关操作技巧,需要的朋友可以参考下
    2018-03-03
  • python中join()方法介绍

    python中join()方法介绍

    Python join() 方法用于将序列中的元素以指定的字符连接生成一个新的字符串。这篇文章主要介绍了python中join()方法,需要的朋友可以参考下
    2018-10-10
  • python添加菜单图文讲解

    python添加菜单图文讲解

    在本篇文章中小编给大家整理的是关于python添加菜单图文讲解以及步骤分析,需要的朋友们学习下吧。
    2019-06-06
  • flask框架视图函数用法示例

    flask框架视图函数用法示例

    这篇文章主要介绍了flask框架视图函数用法,结合实例形式分析了flask框架视图函数常见配置与使用技巧,需要的朋友可以参考下
    2018-07-07

最新评论