关于pytorch处理类别不平衡的问题

 更新时间:2019年12月31日 09:09:22   作者:NAAE  
今天小编就为大家分享一篇关于pytorch处理类别不平衡的问题,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧

当训练样本不均匀时,我们可以采用过采样、欠采样、数据增强等手段来避免过拟合。今天遇到一个3d点云数据集合,样本分布极不均匀,正例与负例相差4-5个数量级。数据增强效果就不会太好了,另外过采样也不太合适,因为是空间数据,新增的点有可能会对真实分布产生未知影响。所以采用欠采样来缓解类别不平衡的问题。

下面的代码展示了如何使用WeightedRandomSampler来完成抽样。

numDataPoints = 1000
data_dim = 5
bs = 100

# Create dummy data with class imbalance 9 to 1
data = torch.FloatTensor(numDataPoints, data_dim)
target = np.hstack((np.zeros(int(numDataPoints * 0.9), dtype=np.int32),
     np.ones(int(numDataPoints * 0.1), dtype=np.int32)))

print 'target train 0/1: {}/{}'.format(
 len(np.where(target == 0)[0]), len(np.where(target == 1)[0]))

class_sample_count = np.array(
 [len(np.where(target == t)[0]) for t in np.unique(target)])
weight = 1. / class_sample_count
samples_weight = np.array([weight[t] for t in target])

samples_weight = torch.from_numpy(samples_weight)
samples_weight = samples_weight.double()
sampler = WeightedRandomSampler(samples_weight, len(samples_weight))

target = torch.from_numpy(target).long()
train_dataset = torch.utils.data.TensorDataset(data, target)

train_loader = DataLoader(
 train_dataset, batch_size=bs, num_workers=1, sampler=sampler)

for i, (data, target) in enumerate(train_loader):
 print "batch index {}, 0/1: {}/{}".format(
  i,
  len(np.where(target.numpy() == 0)[0]),
  len(np.where(target.numpy() == 1)[0]))

核心部分为实际使用时替换下变量把sampler传递给DataLoader即可,注意使用了sampler就不能使用shuffle,另外需要指定采样点个数:

class_sample_count = np.array(
 [len(np.where(target == t)[0]) for t in np.unique(target)])
weight = 1. / class_sample_count
samples_weight = np.array([weight[t] for t in target])

samples_weight = torch.from_numpy(samples_weight)
samples_weight = samples_weight.double()
sampler = WeightedRandomSampler(samples_weight, len(samples_weight))

参考:https://discuss.pytorch.org/t/how-to-handle-imbalanced-classes/11264/2

以上这篇关于pytorch处理类别不平衡的问题就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • Python之os操作方法(详解)

    Python之os操作方法(详解)

    下面小编就为大家带来一篇Python之os操作方法(详解)。小编觉得挺不错的,现在就分享给大家,也给大家做个参考。一起跟随小编过来看看吧
    2017-06-06
  • python 实现将txt文件多行合并为一行并将中间的空格去掉方法

    python 实现将txt文件多行合并为一行并将中间的空格去掉方法

    今天小编就为大家分享一篇python 实现将txt文件多行合并为一行并将中间的空格去掉方法,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-12-12
  • python tkinterEntry组件设置默认值方式

    python tkinterEntry组件设置默认值方式

    使用Tkinter库中的Entry组件创建文本输入框时,可以通过insert方法在指定位置插入默认文本作为提示,结合使用focus和focusin事件,可以实现用户点击时清除默认文本,以便输入自定义内容
    2024-09-09
  • Python内置的HTTP协议服务器SimpleHTTPServer使用指南

    Python内置的HTTP协议服务器SimpleHTTPServer使用指南

    这篇文章主要介绍了Python内置的HTTP协议服务器SimpleHTTPServer使用指南,SimpleHTTPServer本身的功能十分简单,文中介绍了需要的朋友可以参考下
    2016-03-03
  • Python 实例方法、类方法、静态方法的区别与作用

    Python 实例方法、类方法、静态方法的区别与作用

    Python中至少有三种比较常见的方法类型,即实例方法,类方法、静态方法。它们是如何定义的呢?如何调用的呢?它们又有何区别和作用呢?感兴趣的朋友跟随小编一起看看吧
    2019-08-08
  • pytest中文文档之编写断言

    pytest中文文档之编写断言

    这篇文章主要给大家介绍了关于pytest中文文档之编写断言的相关资料,文中通过示例代码介绍的非常详细,对大家学习或者使用pytest具有一定的参考学习价值,需要的朋友们下面来一起学习学习吧
    2019-09-09
  • python去掉字符串中重复字符的方法

    python去掉字符串中重复字符的方法

    这篇文章主要介绍了python去掉字符串中重复字符的方法,需要的朋友可以参考下
    2014-02-02
  • python之数字图像处理方式

    python之数字图像处理方式

    这篇文章主要介绍了python之数字图像处理方式,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教
    2023-05-05
  • Python模拟登录微博并爬取表情包

    Python模拟登录微博并爬取表情包

    前段时间爬取的知乎表情包用完了吗?今天再带大家去微博爬一波表情包吧.文中有非常详细的代码示例,废话不多说,让我们愉快地开始吧,需要的朋友可以参考下
    2021-06-06
  • PYTHON正则表达式 re模块使用说明

    PYTHON正则表达式 re模块使用说明

    正则表达式是一个复杂的主题。本文能否有助于你理解呢?那些部分是否不清晰,或在这儿没有找到你所遇到的问题?如果是那样的话,请将建议发给作者以便改进
    2011-05-05

最新评论