mindspore.mint.nn.functional.linear
- mindspore.mint.nn.functional.linear(input, weight, bias=None)
对输入 input 应用全连接操作。全连接定义为:
\[output = input * weight^{T} + bias\]警告
这是一个实验性API,后续可能修改或删除。
在PYNATIVE模式下,如果 bias 不是1D, input 不可以大于6D。
- 参数:
input (Tensor) - 输入Tensor,shape是 \((*, in\_channels)\),其中 \(*\) 表示任意的附加维度。
weight (Tensor) - 输入Tensor的权重,shape是 \((out\_channels, in\_channels)\) 或 \((in\_channels)\)。
bias (Tensor,可选) - 添加在输出结果的偏差,shape是 \((out\_channels)\) 或 \(()\)。默认值:
None
,偏差为0。
- 返回:
输出结果,shape由 input 和 weight 的shape决定。
- 异常:
TypeError - input 不是Tensor。
TypeError - weight 不是Tensor。
TypeError - bias 不是Tensor。
RuntimeError - 在PYNATIVE模式下, bias 不是1D且 input 大于6D。
- 支持平台:
Ascend
样例:
>>> import numpy as np >>> import mindspore >>> from mindspore import Tensor, mint >>> input = Tensor([[-1., 1., 2.], [-3., -3., 1.]], mindspore.float32) >>> weight = Tensor([[-2., -2., -2.], [0., -1., 0.]], mindspore.float32) >>> bias = Tensor([0., 1.], mindspore.float32) >>> output = mint.nn.functional.linear(input, weight, bias) >>> print(output) [[-4. 0.] [10. 4.]]