mindspore.ops.matrix_band_part

查看源文件
mindspore.ops.matrix_band_part(x, lower, upper)[源代码]

将矩阵的每个中心带外的所有位置设置为0。中心带为对角线加上 lowerupper 对应保留的部分。

警告

这是一个实验性API,后续可能修改或删除。

参数:
  • x (Tensor) - x 的shape为 \((*, m, n)\) ,其中 \(*\) 表示任意batch维度。

  • lower (Union[int, Tensor]) - 要保留的下部子对角线数。其数据类型必须是int32或int64。如果为负数,则保留对角线下方所有元素。

  • upper (Union[int, Tensor]) - 要保留的上部子对角线数。其数据类型必须是int32或int64。如果为负数,则保留对角线上方所有元素。

返回:

Tensor,其数据类型和维度必须和输入中的 x 保持一致。

异常:
  • TypeError - x 不是一个Tensor。

  • TypeError - x 的类型无效。

  • TypeError - lower 不是一个数值或者Tensor。

  • TypeError - upper 不是一个数值或者Tensor。

  • TypeError - lower 的数据类型不是int32或int64。

  • TypeError - upper 的数据类型不是int32或int64。

  • ValueError - x 的shape不是大于或等于二维。

  • ValueError - lower 的shape不等于零维。

  • ValueError - upper 的shape不等于零维。

支持平台:

Ascend GPU CPU

样例:

>>> import numpy as np
>>> from mindspore import Tensor, ops
>>> x = Tensor(np.ones([2, 4, 4]).astype(np.float32))
>>> output = ops.matrix_band_part(x, 2, 1)
>>> print(output)
[[[1. 1. 0. 0.]
  [1. 1. 1. 0.]
  [1. 1. 1. 1.]
  [0. 1. 1. 1.]]
 [[1. 1. 0. 0.]
  [1. 1. 1. 0.]
  [1. 1. 1. 1.]
  [0. 1. 1. 1.]]]