Pytorch中torch.flatten()和torch.nn.Flatten()实例详解
torch.flatten(x)等于torch.flatten(x,0)默认将张量拉成一维的向量,也就是说从第一维开始平坦化,torch.flatten(x,1)代表从第二维开始平坦化。
import torch x=torch.randn(2,4,2) print(x) z=torch.flatten(x) print(z) w=torch.flatten(x,1) print(w) 输出为: tensor([[[-0.9814, 0.8251], [ 0.8197, -1.0426], [-0.8185, -1.3367], [-0.6293, 0.6714]], [[-0.5973, -0.0944], [ 0.3720, 0.0672], [ 0.2681, 1.8025], [-0.0606, 0.4855]]]) tensor([-0.9814, 0.8251, 0.8197, -1.0426, -0.8185, -1.3367, -0.6293, 0.6714, -0.5973, -0.0944, 0.3720, 0.0672, 0.2681, 1.8025, -0.0606, 0.4855]) tensor([[-0.9814, 0.8251, 0.8197, -1.0426, -0.8185, -1.3367, -0.6293, 0.6714] , [-0.5973, -0.0944, 0.3720, 0.0672, 0.2681, 1.8025, -0.0606, 0.4855] ])
torch.flatten(x,0,1)代表在第一维和第二维之间平坦化
import torch x=torch.randn(2,4,2) print(x) w=torch.flatten(x,0,1) #第一维长度2,第二维长度为4,平坦化后长度为2*4 print(w.shape) print(w) 输出为: tensor([[[-0.5523, -0.1132], [-2.2659, -0.0316], [ 0.1372, -0.8486], [-0.3593, -0.2622]], [[-0.9130, 1.0038], [-0.3996, 0.4934], [ 1.7269, 0.8215], [ 0.1207, -0.9590]]]) torch.Size([8, 2]) tensor([[-0.5523, -0.1132], [-2.2659, -0.0316], [ 0.1372, -0.8486], [-0.3593, -0.2622], [-0.9130, 1.0038], [-0.3996, 0.4934], [ 1.7269, 0.8215], [ 0.1207, -0.9590]])
对于torch.nn.Flatten(),因为其被用在神经网络中,输入为一批数据,第一维为batch,通常要把一个数据拉成一维,而不是将一批数据拉为一维。所以torch.nn.Flatten()默认从第二维开始平坦化。
import torch #随机32个通道为1的5*5的图 x=torch.randn(32,1,5,5) model=torch.nn.Sequential( #输入通道为1,输出通道为6,3*3的卷积核,步长为1,padding=1 torch.nn.Conv2d(1,6,3,1,1), torch.nn.Flatten() ) output=model(x) print(output.shape) # 6*(7-3+1)*(7-3+1) 输出为: torch.Size([32, 150])
总结
到此这篇关于Pytorch中torch.flatten()和torch.nn.Flatten()的文章就介绍到这了,更多相关Pytorch torch.flatten()和torch.nn.Flatten()内容请搜索脚本之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持脚本之家!
相关文章
python opencv实现直线检测并测出倾斜角度(附源码+注释)
这篇文章主要介绍了python opencv实现直线检测并测出倾斜角度(附源码+注释),文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧2020-12-12
最新评论