mindspore.nn.Tril

class mindspore.nn.Tril

Returns a tensor in which the elements above the specified diagonal are set to zero.

The matrix elements are split into an upper and a lower triangle along the selected diagonal. The elements on that diagonal belong to the lower triangle, which is kept.

The parameter k selects the diagonal. If k = 0, the split is made along the main diagonal and all elements of the lower triangle are kept. If k > 0, the split is made along the diagonal k positions above the main diagonal, so the k diagonals just above the main diagonal are kept as well. If k < 0, the split is made along the diagonal |k| positions below the main diagonal, so the main diagonal and the |k| - 1 diagonals just below it are also set to zero. In other words, the element at row i and column j is kept when j - i <= k.
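The diagonal rule can be stated per element: entry (i, j) is kept when j - i <= k and set to zero otherwise. The NumPy sketch below builds that 0/1 mask explicitly to illustrate the semantics; it is not MindSpore's implementation, and the helper name `tril_mask` is hypothetical.

```python
import numpy as np

def tril_mask(rows, cols, k=0):
    """Hypothetical helper: 1 where column - row <= k (kept), 0 elsewhere."""
    i = np.arange(rows)[:, None]  # row indices as a column vector
    j = np.arange(cols)[None, :]  # column indices as a row vector
    return (j - i <= k).astype(np.int64)

x = np.array([[1, 2, 3, 4],
              [5, 6, 7, 8],
              [10, 11, 12, 13],
              [14, 15, 16, 17]])

# k = -1 keeps only the entries strictly below the main diagonal.
print(x * tril_mask(*x.shape, k=-1))
```

Broadcasting a column of row indices against a row of column indices gives the whole matrix of offsets j - i in one step, so no Python loop over elements is needed.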

Inputs:
  • x (Tensor) - The input tensor. Its data type must be numeric.

  • k (int) - The index of the diagonal. Default: 0. If the input matrix has dimensions d1 and d2, k should be in the range [-min(d1, d2)+1, min(d1, d2)-1]; when k is out of this range, the output is the same as the input x.
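A caller that wants to stay inside the documented range for k could check it before invoking the cell. The two helpers below are hypothetical (they are not part of MindSpore) and simply encode the bound [-min(d1, d2)+1, min(d1, d2)-1] stated above, assuming a 2-D input of shape (d1, d2).

```python
def documented_k_range(shape):
    """Hypothetical helper: (low, high) bounds on k for a 2-D shape (d1, d2)."""
    if len(shape) != 2:
        raise ValueError("expected a 2-D shape, got %r" % (shape,))
    d1, d2 = shape
    m = min(d1, d2)
    return -m + 1, m - 1

def k_in_documented_range(shape, k):
    """Hypothetical helper: True if k lies inside the documented range."""
    low, high = documented_k_range(shape)
    return low <= k <= high

print(documented_k_range((4, 4)))        # (-3, 3)
print(k_in_documented_range((2, 5), 2))  # False, the range for (2, 5) is (-1, 1)
```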

Outputs:

Tensor with the same shape and data type as the input x.
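One way to see why the output keeps the input's shape and data type: zeroing the entries above a diagonal amounts to an elementwise product with a 0/1 mask of the same shape, followed by a cast back to the input dtype. The NumPy sketch below, with the hypothetical function `tril_like`, illustrates this for an integer and a floating-point input; it is an assumption-based illustration, not MindSpore's code.

```python
import numpy as np

def tril_like(x, k=0):
    """Hypothetical sketch: zero entries above diagonal k, keep shape and dtype."""
    mask = np.tril(np.ones(x.shape, dtype=x.dtype), k)  # 1 on/below diagonal k, 0 above
    return (x * mask).astype(x.dtype)  # explicit cast back to the input dtype

for dtype in (np.int32, np.float16):
    x = np.arange(12, dtype=dtype).reshape(3, 4)
    y = tril_like(x, k=1)
    print(y.dtype, y.shape)
```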

Raises
Supported Platforms:

Ascend GPU CPU

Examples

>>> import numpy as np
>>> from mindspore import Tensor, nn
>>> # case1: k = 0
>>> x = Tensor(np.array([[ 1,  2,  3,  4],
...                      [ 5,  6,  7,  8],
...                      [10, 11, 12, 13],
...                      [14, 15, 16, 17]]))
>>> tril = nn.Tril()
>>> result = tril(x)
>>> print(result)
[[ 1  0  0  0]
 [ 5  6  0  0]
 [10 11 12  0]
 [14 15 16 17]]
>>> # case2: k = 1
>>> x = Tensor(np.array([[ 1,  2,  3,  4],
...                      [ 5,  6,  7,  8],
...                      [10, 11, 12, 13],
...                      [14, 15, 16, 17]]))
>>> tril = nn.Tril()
>>> result = tril(x, 1)
>>> print(result)
[[ 1  2  0  0]
 [ 5  6  7  0]
 [10 11 12 13]
 [14 15 16 17]]
>>> # case3: k = 2
>>> x = Tensor(np.array([[ 1,  2,  3,  4],
...                      [ 5,  6,  7,  8],
...                      [10, 11, 12, 13],
...                      [14, 15, 16, 17]]))
>>> tril = nn.Tril()
>>> result = tril(x, 2)
>>> print(result)
[[ 1  2  3  0]
 [ 5  6  7  8]
 [10 11 12 13]
 [14 15 16 17]]
>>> # case4: k = -1
>>> x = Tensor(np.array([[ 1,  2,  3,  4],
...                      [ 5,  6,  7,  8],
...                      [10, 11, 12, 13],
...                      [14, 15, 16, 17]]))
>>> tril = nn.Tril()
>>> result = tril(x, -1)
>>> print(result)
[[ 0  0  0  0]
 [ 5  0  0  0]
 [10 11  0  0]
 [14 15 16  0]]
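Without a MindSpore install, the printed results of the four cases can be cross-checked with NumPy's `np.tril`, which takes a diagonal offset `k` with the same convention and produces the same matrices for these inputs. This is only a consistency check against NumPy, not a substitute for running `nn.Tril` itself.

```python
import numpy as np

x = np.array([[1, 2, 3, 4],
              [5, 6, 7, 8],
              [10, 11, 12, 13],
              [14, 15, 16, 17]])

# Expected outputs copied from cases 1-4 above, keyed by k.
expected = {
    0: [[1, 0, 0, 0], [5, 6, 0, 0], [10, 11, 12, 0], [14, 15, 16, 17]],
    1: [[1, 2, 0, 0], [5, 6, 7, 0], [10, 11, 12, 13], [14, 15, 16, 17]],
    2: [[1, 2, 3, 0], [5, 6, 7, 8], [10, 11, 12, 13], [14, 15, 16, 17]],
    -1: [[0, 0, 0, 0], [5, 0, 0, 0], [10, 11, 0, 0], [14, 15, 16, 0]],
}

for k, want in expected.items():
    got = np.tril(x, k)
    print(k, "matches" if got.tolist() == want else "differs")
```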