mindspore.ops.MatrixBandPart
- class mindspore.ops.MatrixBandPart
Extracts the central diagonal band of each matrix in a tensor, with all values outside the central band set to zero.
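Written out element-wise, the band condition can be stated as below. This is a sketch that assumes MatrixBandPart follows the same definition as TensorFlow's tf.linalg.band_part; it agrees with the example at the end of this page. Here \(i, j\) index the last two dimensions of x, and lower and upper are the inputs described below:

\[
\text{output}[\ldots, i, j] = \text{in\_band}(i, j)\cdot x[\ldots, i, j],
\]
\[
\text{in\_band}(i, j) = \big(\text{lower} < 0 \;\lor\; i - j \le \text{lower}\big) \land \big(\text{upper} < 0 \;\lor\; j - i \le \text{upper}\big).
\]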
Refer to mindspore.ops.matrix_band_part() for more details.
Warning
This is an experimental API that is subject to change or deletion.
- Inputs:
x (Tensor) - Input tensor of shape \((*, m, n)\), where \(*\) means any number of additional dimensions.
lower (Union[int, Tensor]) - Number of subdiagonals to keep. The data type must be int32 or int64. If negative, the entire lower triangle is kept.
upper (Union[int, Tensor]) - Number of superdiagonals to keep. The data type must be int32 or int64. If negative, the entire upper triangle is kept.
- Outputs:
Tensor with the same data type and shape as x.
- Supported Platforms:
Ascend
GPU
CPU
Examples
>>> import numpy as np
>>> from mindspore import Tensor, ops
>>> matrix_band_part = ops.MatrixBandPart()
>>> x = np.ones([2, 4, 4]).astype(np.float32)
>>> output = matrix_band_part(Tensor(x), 2, 1)
>>> print(output)
[[[1. 1. 0. 0.]
  [1. 1. 1. 0.]
  [1. 1. 1. 1.]
  [0. 1. 1. 1.]]
 [[1. 1. 0. 0.]
  [1. 1. 1. 0.]
  [1. 1. 1. 1.]
  [0. 1. 1. 1.]]]
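To make the lower/upper semantics concrete without a MindSpore installation, here is a minimal NumPy sketch. The helper band_part_reference is hypothetical and not part of MindSpore; it assumes the band rule "keep x[..., i, j] when i - j <= lower and j - i <= upper, with a negative bound meaning no limit on that side", which reproduces the example output above.

```python
import numpy as np


def band_part_reference(x, lower, upper):
    """Hypothetical NumPy sketch of band-part semantics (not a MindSpore API)."""
    x = np.asarray(x)
    m, n = x.shape[-2], x.shape[-1]
    rows = np.arange(m)[:, None]  # row index i, shape (m, 1)
    cols = np.arange(n)[None, :]  # column index j, shape (1, n)
    # A negative bound keeps the whole triangle on that side of the diagonal.
    keep_lower = (lower < 0) | (rows - cols <= lower)
    keep_upper = (upper < 0) | (cols - rows <= upper)
    mask = keep_lower & keep_upper  # shape (m, n), broadcast over leading dims
    # Zero every element outside the band; dtype and shape of x are preserved.
    return np.where(mask, x, np.zeros_like(x))


x = np.ones([2, 4, 4], dtype=np.float32)
print(band_part_reference(x, 2, 1)[0])
```

Under this assumed definition, band_part_reference(x, -1, 0) matches np.tril(x), band_part_reference(x, 0, -1) matches np.triu(x), and band_part_reference(x, 0, 0) keeps only the main diagonal.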