mindspore.ops.IndexAdd

class mindspore.ops.IndexAdd(axis, use_lock=True, check_index_bound=True)

Adds tensor y to tensor x at the specified indices along the given axis. axis must be in the range [0, len(x.shape) - 1], and every value in indices must be in [0, b), where b is the size of x in the axis dimension.
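The semantics described above can be sketched as a NumPy reference function. This is an illustration under stated assumptions, not MindSpore's implementation: the name `index_add_reference` is hypothetical, the sketch returns a new array instead of updating a Parameter in place, and it uses `np.add.at`, which accumulates if an index repeats (the MindSpore behavior for repeated indices is not stated on this page).

```python
import numpy as np


def index_add_reference(x, indices, y, axis):
    """Hypothetical NumPy reference for IndexAdd (not MindSpore code).

    Returns a copy of x in which the slice of y at position j along `axis`
    has been added to the slice of x at position indices[j] along `axis`.
    """
    x = np.asarray(x)
    y = np.asarray(y, dtype=x.dtype)
    indices = np.asarray(indices, dtype=np.int64)
    if not 0 <= axis < x.ndim:
        raise ValueError(f"axis must be in [0, {x.ndim - 1}], got {axis}")
    # Same range rule as the docs: every index must lie in [0, x.shape[axis]).
    if indices.size and (indices.min() < 0 or indices.max() >= x.shape[axis]):
        raise ValueError(f"indices must lie in [0, {x.shape[axis]})")
    out = x.copy()
    # np.moveaxis returns a view, so np.add.at writes straight into `out`.
    target = np.moveaxis(out, axis, 0)
    src = np.moveaxis(y, axis, 0)
    # Unbuffered in-place add: target[indices[j]] += src[j] for every j.
    np.add.at(target, indices, src)
    return out
```

Moving `axis` to the front lets a single `np.add.at` call handle any axis, because the indices then always select along the first dimension.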

Parameters
  • axis (int) – The dimension along which to index.

  • use_lock (bool) – Whether to use lock mode when updating x. Default: True.

  • check_index_bound (bool) – Whether to check that every value in indices lies within the bounds of x in the axis dimension. Default: True.

Inputs:
  • x (Parameter) - The input Parameter to add to.

  • indices (Tensor) - The positions along axis of x to which the corresponding slices of y are added, with data type int32. indices must be 1-D, and its size must equal the size of y in the axis dimension. Its values must be in [0, b), where b is the size of x in the axis dimension.

  • y (Tensor) - The tensor containing the values to add. Must have the same data type as x. Its shape must match the shape of x in every dimension except the axis-th dimension.

Outputs:

Tensor, has the same shape and dtype as x.

Raises
  • TypeError – If x is not a Parameter.

  • TypeError – If indices or y is not a Tensor.

  • ValueError – If axis is out of the range of the rank of x.

  • ValueError – If the rank of x is not the same as the rank of y.

  • ValueError – If the size of indices is not equal to the size of y in the axis dimension.

  • ValueError – If the shape of y differs from the shape of x in any dimension other than axis.
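The ValueError conditions listed above, together with the 1-D requirement on indices from the Inputs section, can be expressed as a shape check on plain tuples. This is a minimal sketch: `check_index_add_shapes` is a hypothetical helper, not part of MindSpore, and it assumes a non-negative axis as stated in the description.

```python
def check_index_add_shapes(x_shape, indices_shape, y_shape, axis):
    """Hypothetical helper mirroring the documented ValueError conditions.

    Returns None if the shapes are compatible, otherwise raises ValueError.
    """
    rank = len(x_shape)
    # axis must be in [0, rank - 1].
    if not 0 <= axis < rank:
        raise ValueError(f"axis {axis} is out of range for rank {rank}")
    # x and y must have the same rank.
    if len(y_shape) != rank:
        raise ValueError(f"rank of y ({len(y_shape)}) differs from rank of x ({rank})")
    # indices must be 1-D with size equal to y.shape[axis].
    if len(indices_shape) != 1 or indices_shape[0] != y_shape[axis]:
        raise ValueError(f"indices must be 1-D with size {y_shape[axis]}")
    # Every dimension other than axis must match between x and y.
    for dim, (xs, ys) in enumerate(zip(x_shape, y_shape)):
        if dim != axis and xs != ys:
            raise ValueError(f"dimension {dim} mismatch: x has {xs}, y has {ys}")
```

For the example below, `check_index_add_shapes((3, 3), (2,), (3, 2), axis=1)` passes, while a y of shape (2, 2) would fail on dimension 0.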

Supported Platforms:

Ascend GPU

Examples

>>> import numpy as np
>>> import mindspore
>>> from mindspore import Tensor, Parameter, nn, ops
>>> class Net(nn.Cell):
...     def __init__(self):
...         super(Net, self).__init__()
...         self.index_add = ops.IndexAdd(axis=1)
...         self.x = Parameter(Tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), mindspore.float32),
...                 name="name_x")
...         self.indices = Tensor(np.array([0, 2]), mindspore.int32)
...
...     def construct(self, y):
...         return self.index_add(self.x, self.indices, y)
...
>>> y = Tensor(np.array([[0.5, 1.0], [1.0, 1.5], [2.0, 2.5]]), mindspore.float32)
>>> net = Net()
>>> output = net(y)
>>> print(output)
[[ 1.5  2.   4. ]
 [ 5.   5.   7.5]
 [ 9.   8.  11.5]]
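To see where each value in the printed result comes from, the same computation can be reproduced with plain NumPy and an explicit loop. This sketch only mirrors the example's data; it does not call MindSpore and is not how the operator is implemented.

```python
import numpy as np

# Same data as the MindSpore example, as plain NumPy arrays.
x = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.float32)
indices = np.array([0, 2], dtype=np.int32)
y = np.array([[0.5, 1.0], [1.0, 1.5], [2.0, 2.5]], dtype=np.float32)

out = x.copy()
# With axis=1, column j of y is added to column indices[j] of x:
# column 0 of y goes to column 0 of x, column 1 of y goes to column 2 of x.
for j, idx in enumerate(indices):
    out[:, idx] += y[:, j]

print(out)
```

Column 1 of x is untouched because 1 does not appear in indices, which is why the middle column of the output still reads 2, 5, 8.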