Pytorch框架之one_hot编码函数解读
Pytorch one_hot编码函数解读
one_hot编码定义
在一个给定的向量中,按照设定的最值–可以是向量中包含的最大值(作为最高分类数),有也可以是自定义的最大值,设计one_hot编码的长度:最大值+1【详见举的例子吧】。
然后按照最大值创建一个1*(最大值+1)的维度大小的全零零向量:[0, 0, 0, …] => 共最大值+1对应的个数
接着按照向量中的值,从第0位开始索引,将向量中值对应的位置设置为1,其他保持为0.
eg:
假设设定one_hot长度为4(最大值) –
且当前向量中值为1对应的one_hot编码:
[0, 1, 0, 0]
当前向量中值为2对应的one_hot编码:
[0, 0, 1, 0]
eg:
假设设定one_hot长度为6(等价最大值+1) –
且当前向量中值为4对应的one_hot编码:
[0, 0, 0, 0, 1, 0]
当前向量中值为2对应的one_hot编码:
[0, 0, 1, 0, 0, 0]
eg:
targets = [4, 1, 0, 3] => max_value=4=>one_hot的长度为(4+1)
假设设定one_hot长度为5(最大值) –
且当前向量中值为4对应的one_hot编码:
[0, 0, 0, 0, 1]
当前向量中值为1对应的one_hot编码:
[0, 1, 0, 0, 0]
Pytorch中one_hot转换
import torch targets = torch.tensor([5, 3, 2, 1]) targets_to_one_hot = torch.nn.functional.one_hot(targets) # 默认按照targets其中的最大值+1作为one_hot编码的长度 # result: # tensor( # [0, 0, 0, 0, 0, 1], # [0, 0, 0, 1, 0, 0], # [0, 0, 1, 0, 0, 0], # [0, 1, 0, 0, 0, 0] #) targets_to_one_hot = torch.nn.functional.one_hot(targets, num_classes=7) 3# 指定one_hot编码长度为7 # result: # tensor( # [0, 0, 0, 0, 0, 1, 0], # [0, 0, 0, 1, 0, 0, 0], # [0, 0, 1, 0, 0, 0, 0], # [0, 1, 0, 0, 0, 0, 0] #)
总结:one_hot编码主要用于分类时,作为一个类别的编码–方便判别与相关计算;
1. 如同类别数统计,只需要将one_hot编码相加得到一个一维向量就知道了一批数据中所有类别的预测或真实的分布情况;
2. 相比于预测出具体的类别数–43等,用向量可以使用向量相关的算法进行时间上的优化等等
Pytorch变量类型转换及one_hot编码表示
生成张量
y = torch.empty(3, dtype=torch.long).random_(5) y = torch.Tensor(2,3).random_(10) y = torch.randn(3,4).random_(10)
查看类型
y.type y.dtype
类型转化
tensor.long()/int()/float() long(),int(),float() 实现类型的转化
One_hot编码表示
def one_hot(y): ''' y: (N)的一维tensor,值为每个样本的类别 out: y_onehot: 转换为one_hot 编码格式 ''' y = y.view(-1, 1) # y_onehot = torch.FloatTensor(3, 5) # y_onehot.zero_() y_onehot = torch.zeros(3,5) # 等价于上面 y_onehot.scatter_(1, y, 1) return y_onehot y = torch.empty(3, dtype=torch.long).random_(5) #标签 res = one_hot(y) # 转化为One_hot类型 # One_hot类型标签转化为整数型列表的两种方法 h = torch.argmax(res,dim=1) _,h1 = res.max(dim=1)
expand()函数
这个函数的作用就是对指定的维度进行数值大小的改变。只能改变维大小为1的维,否则就会报错。不改变的维可以传入-1或者原来的数值。
a=torch.randn(1,1,3,768) print(a.shape) #torch.Size([1, 1, 3, 768]) b=a.expand(2,-1,-1,-1) print(b.shape) #torch.Size([2, 1, 3, 768]) c=a.expand(2,1,3,768) print(c.shape) #torch.Size([2, 1, 3, 768])
repeat()函数
沿着指定的维度,对原来的tensor进行数据复制。这个函数和expand()还是有点区别的。expand()只能对维度为1的维进行扩大,而repeat()对所有的维度可以随意操作。
a=torch.randn(2,1,768) print(a) print(a.shape) #torch.Size([2, 1, 768]) b=a.repeat(1,2,1) print(b) print(b.shape) #torch.Size([2, 2, 768]) c=a.repeat(3,3,3) print(c) print(c.shape) #torch.Size([6, 3, 2304])
总结
以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。
相关文章
Python开发的单词频率统计工具wordsworth使用方法
wordsworth是字母,单词和n元组频率分析,用来分析文件中的单词出现频率的工具。2014-06-06使用Python编写一个在Linux下实现截图分享的脚本的教程
这篇文章主要介绍了使用Python编写一个在Linux下实现截图分享的脚本的教程,利用到了scrot和urllib2库,需要的朋友可以参考下2015-04-04
最新评论