mindspore.ops.MatrixBandPart

class mindspore.ops.MatrixBandPart[源代码]

提取一个Tensor中每个矩阵的中心带,中心带之外的所有值都设置为零。

更多参考详见 mindspore.ops.matrix_band_part()

警告

这是一个实验性API,后续可能修改或删除。

输入:
  • x (Tensor) - x 的shape为 \((*, m, n)\) ,其中 \(*\) 表示任意batch维度。

  • lower (Union[int, Tensor]) - 要保留的下部子对角线数。其数据类型必须是int32或int64。如果为负数,则保留整个下三角形。

  • upper (Union[int, Tensor]) - 要保留的上部子对角线数。其数据类型必须是int32或int64。如果为负数,则保留整个上三角形。

输出:

Tensor,其数据类型和维度必须和输入中的 x 保持一致。

支持平台:

Ascend GPU CPU

样例:

>>> matrix_band_part = ops.MatrixBandPart()
>>> x = np.ones([2, 4, 4]).astype(np.float32)
>>> output = matrix_band_part(Tensor(x), 2, 1)
>>> print(output)
[[[1. 1. 0. 0.]
  [1. 1. 1. 0.]
  [1. 1. 1. 1.]
  [0. 1. 1. 1.]]
 [[1. 1. 0. 0.]
  [1. 1. 1. 0.]
  [1. 1. 1. 1.]
  [0. 1. 1. 1.]]]