mindspore.ops.matrix_band_part
- mindspore.ops.matrix_band_part(x, lower, upper)[源代码]
将矩阵的每个中心带外的所有位置设置为0。
参数:
x (Tensor) - x 的shape为 \((*, m, n)\) ,其中 \(*\) 表示任意batch维度。x 的数据类型必须为float16、float32、float64、int32或int64。
lower (Union[int, Tensor]) - 要保留的下部子对角线数。其数据类型必须是int32或int64。如果为负数,则保留整个下三角形。
upper (Union[int, Tensor]) - 要保留的上部子对角线数。其数据类型必须是int32或int64。如果为负数,则保留整个上三角形。
返回:
Tensor,其数据类型和维度必须和输入中的 x 保持一致。
异常:
TypeError - x 不是一个Tensor。
TypeError - x 的数据类型不是float16、float32、float64、int32或int64。
TypeError - lower 不是一个数值或者Tensor。
TypeError - upper 不是一个数值或者Tensor。
TypeError - lower 的数据类型不是int32或int64。
TypeError - upper 的数据类型不是int32或int64。
ValueError - x 的shape不是大于或等于2维。
ValueError - lower 的shape不等于0维。
ValueError - upper 的shape不等于0维。
- 支持平台:
GPU
CPU
样例:
>>> from mindspore.ops import functional as F >>> x = Tensor(np.ones([2, 4, 4]).astype(np.float32)) >>> output = F.matrix_band_part(x, 2, 1) >>> print(output) [[[1. 1. 0. 0.] [1. 1. 1. 0.] [1. 1. 1. 1.] [0. 1. 1. 1.]] [[1. 1. 0. 0.] [1. 1. 1. 0.] [1. 1. 1. 1.] [0. 1. 1. 1.]]]