mindspore.nn.Tril
- class mindspore.nn.Tril[源代码]
返回一个Tensor,指定主对角线以上的元素被置为零。
将矩阵元素沿主对角线分为上三角和下三角(包含对角线)。
参数 k 控制对角线的选择。若 k 为0,则沿主对角线分割并保留下三角所有元素。若 k 为正值,则沿主对角线向上选择对角线 k ,并保留下三角所有元素。若 k 为负值,则沿主对角线向下选择对角线 k ,并保留下三角所有元素。
- 输入:
x (Tensor):输入Tensor。数据类型为 number 。
k (int):对角线的索引。默认值:0。假设输入的矩阵的维度分别为d1,d2,则k的范围应在[-min(d1, d2)+1, min(d1, d2)-1],超出该范围时输出值与输入 x 一致。
- 输出:
Tensor,数据类型和shape与 x 相同。
- 异常:
TypeError: k 不是int。
ValueError: x 的维度小于1。
- 支持平台:
Ascend
GPU
CPU
样例:
>>> # case1: k = 0 >>> x = Tensor(np.array([[ 1, 2, 3, 4], ... [ 5, 6, 7, 8], ... [10, 11, 12, 13], ... [14, 15, 16, 17]])) >>> tril = nn.Tril() >>> result = tril(x) >>> print(result) [[ 1 0 0 0] [ 5 6 0 0] [10 11 12 0] [14 15 16 17]] >>> # case2: k = 1 >>> x = Tensor(np.array([[ 1, 2, 3, 4], ... [ 5, 6, 7, 8], ... [10, 11, 12, 13], ... [14, 15, 16, 17]])) >>> tril = nn.Tril() >>> result = tril(x, 1) >>> print(result) [[ 1 2 0 0] [ 5 6 7 0] [10 11 12 13] [14 15 16 17]] >>> # case3: k = 2 >>> x = Tensor(np.array([[ 1, 2, 3, 4], ... [ 5, 6, 7, 8], ... [10, 11, 12, 13], ... [14, 15, 16, 17]])) >>> tril = nn.Tril() >>> result = tril(x, 2) >>> print(result) [[ 1 2 3 0] [ 5 6 7 8] [10 11 12 13] [14 15 16 17]] >>> # case4: k = -1 >>> x = Tensor(np.array([[ 1, 2, 3, 4], ... [ 5, 6, 7, 8], ... [10, 11, 12, 13], ... [14, 15, 16, 17]])) >>> tril = nn.Tril() >>> result = tril(x, -1) >>> print(result) [[ 0 0 0 0] [ 5 0 0 0] [10 11 0 0] [14 15 16 0]]