Pytorch中的torch.nn.Linear()方法用法解读
Pytorch torch.nn.Linear()方法
torch.nn.Linear()作为深度学习中最简单的线性变换方法,其主要作用是对输入数据应用线性转换
看一下官方的解释及介绍
class Linear(Module): r"""Applies a linear transformation to the incoming data: :math:`y = xA^T + b` This module supports :ref:`TensorFloat32<tf32_on_ampere>`. Args: in_features: size of each input sample out_features: size of each output sample bias: If set to ``False``, the layer will not learn an additive bias. Default: ``True`` Shape: - Input: :math:`(N, *, H_{in})` where :math:`*` means any number of additional dimensions and :math:`H_{in} = \text{in\_features}` - Output: :math:`(N, *, H_{out})` where all but the last dimension are the same shape as the input and :math:`H_{out} = \text{out\_features}`. Attributes: weight: the learnable weights of the module of shape :math:`(\text{out\_features}, \text{in\_features})`. The values are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`, where :math:`k = \frac{1}{\text{in\_features}}` bias: the learnable bias of the module of shape :math:`(\text{out\_features})`. If :attr:`bias` is ``True``, the values are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where :math:`k = \frac{1}{\text{in\_features}}` Examples:: >>> m = nn.Linear(20, 30) >>> input = torch.randn(128, 20) >>> output = m(input) >>> print(output.size()) torch.Size([128, 30]) """ __constants__ = ['in_features', 'out_features'] in_features: int out_features: int weight: Tensor def __init__(self, in_features: int, out_features: int, bias: bool = True) -> None: super(Linear, self).__init__() self.in_features = in_features self.out_features = out_features self.weight = Parameter(torch.Tensor(out_features, in_features)) if bias: self.bias = Parameter(torch.Tensor(out_features)) else: self.register_parameter('bias', None) self.reset_parameters() def reset_parameters(self) -> None: init.kaiming_uniform_(self.weight, a=math.sqrt(5)) if self.bias is not None: fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight) bound = 1 / math.sqrt(fan_in) init.uniform_(self.bias, -bound, bound) def forward(self, input: Tensor) -> Tensor: return F.linear(input, self.weight, self.bias) def extra_repr(self) -> str: return 'in_features={}, out_features={}, bias={}'.format( self.in_features, self.out_features, self.bias is not None ) # This class exists solely for Transformer; it has an annotation stating # that bias is never None, which appeases TorchScript
这里我们主要看__init__()方法,很容易知道,当我们使用这个方法时一般需要传入2~3个参数,分别是in_features: int, out_features: int, bias: bool = True,第三个参数是说是否加偏置(bias),简单来讲,这个函数其实就是一个'一次函数':y = xA^T + b,(T表示张量A的转置),首先super(Linear, self).__init__()就是老生常谈的方法,之后初始化in_features和out_features,接下来就是比较重要的weight的设置,我们可以很清晰的看到weight的shape是(out_features,in_features)的,而我们在做xA^T时,并不是x和A^T相乘的,而是x和A.weight^T相乘的,这里需要大大留意,也就是说先对A做转置得到A.weight,然后在丢入y = xA^T + b中,得出结果。
接下来奉上一个小例子来实践一下:
import torch # 随机初始化一个shape为(128,20)的Tensor x = torch.randn(128,20) # 构造线性变换函数y = xA^T + b,且参数(20,30)指的是A的shape,则A.weight的shape就是(30,20)了 y= torch.nn.Linear(20,30) output = y(x) # 按照以上逻辑使用torch中的简单乘法函数进行检验,结果很显然与上述符合 # 下面的y.weight可以理解为一个shape为(30,20)的一个可学习的矩阵,.t()表示转置 # y.bias若为TRUE,则bias是一个Tensor,且其shape为out_features,在该程序中应为30 # 更加细致的表达一下y = (128 * 20) * (30 * 20)^T + (if bias (1,30) ,else: 0) ans = torch.mm(x,y.weight.t())+y.bias print('ans.shape:\n',ans.shape) print(torch.equal(ans,output))
对torch.nn.Linear的理解
torch.nn.Linear是pytorch的线性变换层
定义如下:
Linear(in_features: int, out_features: int, bias: bool = True, device: Any | None = None, dtype: Any | None = None)
全连接层 Fully Connect 一般就就用这个函数来实现。
因此在潜意识里,变换的输入张量的 shape 为 (batchsize, in_features),输出张量的 shape 为 (batchsize, out_features)。
当然这是常用的方式,但是 Linear 的输入张量的维度其实并不需要必须为上述的二维,多维也是完全可以的,Linear 仅是对输入的最后一维做线性变换,不影响其他维。
可以看下官网的解释
Linear — PyTorch 1.11.0 documentation
一个例子
如下:
import torch input = torch.randn(30, 20, 10) # [30, 20, 10] linear = torch.nn.Linear(10, 15) # (*, 10) --> (*, 15) output = linear(input) print(output.size()) # 输出 [30, 20, 15]
总结
以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。
相关文章
Windows系统下使用flup搭建Nginx和Python环境的方法
这篇文章主要介绍了Windows系统下使用flup搭建Nginx和Python环境的方法,文中使用到了flup这个Python的FastCGI工具,需要的朋友可以参考下2015-12-12django2+uwsgi+nginx上线部署到服务器Ubuntu16.04
这篇文章主要介绍了django2+uwsgi+nginx上线部署到服务器Ubuntu16.04,小编觉得挺不错的,现在分享给大家,也给大家做个参考。一起跟随小编过来看看吧2018-06-06
最新评论