mindspore.nn.Tril
- class mindspore.nn.Tril
Returns a tensor in which the elements above the specified diagonal are set to zero.
The matrix elements are split into an upper and a lower triangle along the selected diagonal; the lower triangle includes the diagonal itself.
The parameter k selects the diagonal. If k = 0, the split is made along the main diagonal and all elements of the lower triangle are kept. If k > 0, the split is made along the k-th diagonal above the main diagonal, and all elements on or below it are kept. If k < 0, the split is made along the |k|-th diagonal below the main diagonal, and all elements on or below it are kept.
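The selection rule above reduces to a simple index test: element (i, j) of a 2-D matrix lies on diagonal j - i, so it is kept exactly when j - i <= k. The following pure-Python sketch illustrates that rule on nested lists. The helper name `tril_reference` is hypothetical and is not part of MindSpore; it handles only a single 2-D matrix and is meant to explain the semantics, not to reproduce the library's implementation.

```python
from typing import List, Sequence


def tril_reference(matrix: Sequence[Sequence[int]], k: int = 0) -> List[List[int]]:
    """Hypothetical reference: zero every element above the k-th diagonal.

    Element (i, j) lies on diagonal j - i, so it is kept when j - i <= k
    and replaced with 0 otherwise.
    """
    return [
        [value if j - i <= k else 0 for j, value in enumerate(row)]
        for i, row in enumerate(matrix)
    ]


x = [[1, 2, 3, 4],
     [5, 6, 7, 8],
     [10, 11, 12, 13],
     [14, 15, 16, 17]]

print(tril_reference(x))        # k = 0: main diagonal and everything below it
print(tril_reference(x, k=1))   # k = 1: also keeps the first diagonal above the main one
print(tril_reference(x, k=-1))  # k = -1: keeps only elements strictly below the main diagonal
```

For this input, the three printed results agree with case1, case2 and case4 of the examples.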
- Inputs:
x (Tensor) - The input tensor. The data type must be a numeric type.
k (int) - The index of the diagonal. Default: 0. If the input matrix has dimensions d1 and d2, k should lie in the range [-min(d1, d2)+1, min(d1, d2)-1]; when k is out of range, the output is the same as the input x.
- Outputs:
Tensor, with the same shape and data type as the input x.
- Raises:
TypeError – If k is not an int.
ValueError – If the length of the shape of x is less than 1.
- Supported Platforms:
Ascend
GPU
CPU
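The input constraints listed above can be sketched in plain Python. The helpers `k_range` and `check_tril_args` are hypothetical names, not MindSpore APIs. They only mirror the documented behaviour and are not MindSpore's internal validation code: `k_range` computes the documented interval for k on a d1 x d2 matrix, and `check_tril_args` raises the two documented errors.

```python
from typing import Sequence, Tuple


def k_range(d1: int, d2: int) -> Tuple[int, int]:
    """Hypothetical helper: the documented interval [-min(d1, d2) + 1, min(d1, d2) - 1]."""
    m = min(d1, d2)
    return -m + 1, m - 1


def check_tril_args(shape: Sequence[int], k: object) -> None:
    """Hypothetical pre-check mirroring the documented Raises section."""
    if not isinstance(k, int):
        raise TypeError(f"k must be an int, but got {type(k).__name__}")
    if len(shape) < 1:
        raise ValueError(f"the length of the shape of x must be at least 1, but got {len(shape)}")


print(k_range(4, 4))  # (-3, 3)
print(k_range(2, 5))  # (-1, 1)

check_tril_args((4, 4), 1)  # valid arguments: returns None without raising

# Each of these argument pairs violates one documented condition.
for shape, k in [((4, 4), 1.5), ((), 0)]:
    try:
        check_tril_args(shape, k)
    except (TypeError, ValueError) as err:
        print(type(err).__name__, "-", err)
```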
Examples
>>> import numpy as np
>>> from mindspore import Tensor, nn
>>> # case1: k = 0
>>> x = Tensor(np.array([[ 1,  2,  3,  4],
...                      [ 5,  6,  7,  8],
...                      [10, 11, 12, 13],
...                      [14, 15, 16, 17]]))
>>> tril = nn.Tril()
>>> result = tril(x)
>>> print(result)
[[ 1  0  0  0]
 [ 5  6  0  0]
 [10 11 12  0]
 [14 15 16 17]]
>>> # case2: k = 1
>>> x = Tensor(np.array([[ 1,  2,  3,  4],
...                      [ 5,  6,  7,  8],
...                      [10, 11, 12, 13],
...                      [14, 15, 16, 17]]))
>>> tril = nn.Tril()
>>> result = tril(x, 1)
>>> print(result)
[[ 1  2  0  0]
 [ 5  6  7  0]
 [10 11 12 13]
 [14 15 16 17]]
>>> # case3: k = 2
>>> x = Tensor(np.array([[ 1,  2,  3,  4],
...                      [ 5,  6,  7,  8],
...                      [10, 11, 12, 13],
...                      [14, 15, 16, 17]]))
>>> tril = nn.Tril()
>>> result = tril(x, 2)
>>> print(result)
[[ 1  2  3  0]
 [ 5  6  7  8]
 [10 11 12 13]
 [14 15 16 17]]
>>> # case4: k = -1
>>> x = Tensor(np.array([[ 1,  2,  3,  4],
...                      [ 5,  6,  7,  8],
...                      [10, 11, 12, 13],
...                      [14, 15, 16, 17]]))
>>> tril = nn.Tril()
>>> result = tril(x, -1)
>>> print(result)
[[ 0  0  0  0]
 [ 5  0  0  0]
 [10 11  0  0]
 [14 15 16  0]]