mindspore.nn.Tril

class mindspore.nn.Tril

Returns a tensor in which the elements above the k-th diagonal are set to zero.
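The rule above can be sketched in NumPy: an element at row i and column j is kept when j - i <= k and zeroed otherwise. This is a minimal sketch that assumes the input has at least two dimensions and that the rule applies to its last two axes. The function name tril_reference is hypothetical and is not part of MindSpore; numpy.tril serves only as a reference to check it against.

```python
import numpy as np


def tril_reference(a, k=0):
    """Hypothetical reference: zero a[..., i, j] wherever j - i > k."""
    a = np.asarray(a)
    rows = np.arange(a.shape[-2])[:, None]  # row indices i as a column vector
    cols = np.arange(a.shape[-1])[None, :]  # column indices j as a row vector
    keep = (cols - rows) <= k               # True on and below the k-th diagonal
    # The 2-D mask broadcasts over any leading axes; zeros_like keeps a's dtype.
    return np.where(keep, a, np.zeros_like(a))


x = np.arange(1, 17).reshape(4, 4)
print(tril_reference(x, k=1))
```

Building the mask from broadcast index vectors avoids Python loops and preserves the input's shape and data type.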

Inputs:
  • x (Tensor) - The input tensor, with a Number data type and shape \((N, *)\), where \(*\) means any number of additional dimensions.

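  • k (int) - The index of the diagonal. k = 0 refers to the main diagonal, a positive k to a diagonal above it, and a negative k to a diagonal below it. Default: 0.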
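To see which diagonal a given k names, NumPy's np.diagonal is a convenient illustration, since its offset argument uses the same numbering (0 for the main diagonal, positive above it, negative below it). The sketch assumes nn.Tril follows this convention, as the examples on this page suggest, and uses numpy.tril as a stand-in to check that the k-th diagonal is kept while the (k+1)-th is zeroed.

```python
import numpy as np

# The same 4x4 input used in the examples.
x = np.array([[ 1,  2,  3,  4],
              [ 5,  6,  7,  8],
              [10, 11, 12, 13],
              [14, 15, 16, 17]])

# Elements lying on diagonals -1, 0 and 1 of x.
diags = {k: np.diagonal(x, offset=k).tolist() for k in (-1, 0, 1)}
for k, values in diags.items():
    print(k, values)

# With k = 1, diagonal 1 survives unchanged and diagonal 2 is all zeros.
k = 1
t = np.tril(x, k)
assert np.array_equal(np.diagonal(t, offset=k), np.diagonal(x, offset=k))
assert not np.diagonal(t, offset=k + 1).any()
```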

Outputs:

Tensor, with the same shape and data type as the input x.

Supported Platforms:

Ascend GPU CPU

Examples

>>> import numpy as np
>>> from mindspore import Tensor, nn
>>> x = Tensor(np.array([[ 1,  2,  3,  4],
...                      [ 5,  6,  7,  8],
...                      [10, 11, 12, 13],
...                      [14, 15, 16, 17]]))
>>> tril = nn.Tril()
>>> result = tril(x)
>>> print(result)
[[ 1  0  0  0]
 [ 5  6  0  0]
 [10 11 12  0]
 [14 15 16 17]]
>>> x = Tensor(np.array([[ 1,  2,  3,  4],
...                      [ 5,  6,  7,  8],
...                      [10, 11, 12, 13],
...                      [14, 15, 16, 17]]))
>>> tril = nn.Tril()
>>> result = tril(x, 1)
>>> print(result)
[[ 1  2  0  0]
 [ 5  6  7  0]
 [10 11 12 13]
 [14 15 16 17]]
>>> x = Tensor(np.array([[ 1,  2,  3,  4],
...                      [ 5,  6,  7,  8],
...                      [10, 11, 12, 13],
...                      [14, 15, 16, 17]]))
>>> tril = nn.Tril()
>>> result = tril(x, 2)
>>> print(result)
[[ 1  2  3  0]
 [ 5  6  7  8]
 [10 11 12 13]
 [14 15 16 17]]
>>> x = Tensor(np.array([[ 1,  2,  3,  4],
...                      [ 5,  6,  7,  8],
...                      [10, 11, 12, 13],
...                      [14, 15, 16, 17]]))
>>> tril = nn.Tril()
>>> result = tril(x, -1)
>>> print(result)
[[ 0  0  0  0]
 [ 5  0  0  0]
 [10 11  0  0]
 [14 15 16  0]]
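As a cross-check that runs without MindSpore, the expected outputs of the four examples can be reproduced with numpy.tril. This sketch assumes nn.Tril uses the same k convention as numpy.tril, which all four examples are consistent with; it does not call MindSpore itself.

```python
import numpy as np

# The same 4x4 input used in the examples.
x = np.array([[ 1,  2,  3,  4],
              [ 5,  6,  7,  8],
              [10, 11, 12, 13],
              [14, 15, 16, 17]])

# One result per k value used in the examples: 0, 1, 2 and -1.
results = {k: np.tril(x, k) for k in (0, 1, 2, -1)}

for k, r in results.items():
    print(f"k = {k}")
    print(r)
```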