mindspore.ops.IndexAdd
- class mindspore.ops.IndexAdd(*args, **kwargs)
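Adds tensor input_y to tensor input_x along the specified axis at the given indices: for each i, the i-th slice of input_y along axis is added to the slice of input_x at position indices[i] along the same axis. axis must be in the range from 0 to rank(input_x) - 1, and each value in indices must be in the range [0, input_x.shape[axis]).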
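To make the computation concrete, here is a minimal NumPy sketch of what the operator computes, assuming plain NumPy arrays in place of MindSpore Parameter/Tensor objects. index_add_reference is a hypothetical helper, not a MindSpore API. It uses np.add.at so that a repeated index accumulates every contribution; treating duplicates this way is an assumption about the operator, not something this page states.

```python
import numpy as np


def index_add_reference(x, indices, y, axis):
    """Hypothetical NumPy reference for the IndexAdd computation (not a MindSpore API).

    Returns a new array in which slice i of y along `axis` has been added
    to slice indices[i] of x along the same axis.
    """
    out = np.array(x, copy=True)
    # Move `axis` to the front: out_view[k] is slice k of out along `axis`.
    # np.moveaxis returns a view, so writes through out_view land in out.
    out_view = np.moveaxis(out, axis, 0)
    y_view = np.moveaxis(np.asarray(y, dtype=out.dtype), axis, 0)
    # np.add.at is unbuffered, so a repeated index accumulates all its slices.
    np.add.at(out_view, np.asarray(indices), y_view)
    return out


# Same inputs as in the Examples section: axis=1, indices=[0, 2].
x = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.float32)
y = np.array([[0.5, 1.0], [1.0, 1.5], [2.0, 2.5]], dtype=np.float32)
print(index_add_reference(x, [0, 2], y, axis=1))
```

With these inputs the sketch reproduces the values shown in the Examples section: column 0 of input_y is added to column 0 of input_x, and column 1 of input_y to column 2.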
- Parameters
axis (int) – The dimension along which to index.
- Inputs:
input_x (Parameter) - The input tensor to add to, with data type float64, float32, float16, int32, int16, int8, uint8.
indices (Tensor) - The positions along the axis-th dimension of input_x to which input_y is added, with data type int32. indices must be 1-D, and its size must equal the size of the axis-th dimension of input_y. Each value of indices must be in the range [0, input_x.shape[axis]).
input_y (Tensor) - The tensor containing the values to add. Must have the same data type as input_x. Its shape must match that of input_x in every dimension except the axis-th.
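The shape contract among the three inputs can be stated as a small function. This is a hypothetical helper (expected_input_y_shape is not part of MindSpore): given the shape of input_x, the axis, and the number of indices, it returns the shape input_y must have under the rules above. It only accepts non-negative axis values, following the 0 to rank(input_x) - 1 range stated in the description.

```python
def expected_input_y_shape(x_shape, axis, num_indices):
    """Hypothetical helper: the shape input_y must have for IndexAdd.

    input_y matches input_x in every dimension except `axis`,
    where its size equals the number of indices.
    """
    if not 0 <= axis < len(x_shape):
        raise ValueError(f"axis {axis} is out of range for rank {len(x_shape)}")
    shape = list(x_shape)
    shape[axis] = num_indices  # only the indexed dimension changes
    return tuple(shape)


# Examples-section setup: input_x is 3x3, axis=1, and there are 2 indices.
print(expected_input_y_shape((3, 3), 1, 2))  # (3, 2)
```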
- Outputs:
Tensor, has the same shape and dtype as input_x.
- Raises
TypeError – If dtype of input_x is not one of: float64, float32, float16, int32, int16, int8, uint8.
TypeError – If indices or input_y is not a Tensor.
TypeError – If the dtype of input_y is not the same as that of input_x.
ValueError – If axis is out of the valid range for the rank of input_x.
ValueError – If the rank of input_x is not the same as the rank of input_y.
ValueError – If the size of indices is not equal to input_y.shape[axis].
ValueError – If the shape of input_y differs from that of input_x in any dimension other than the axis-th.
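The argument checks described above can be mirrored on the host before calling the operator, which gives clearer errors in user code. The sketch below is hypothetical: check_index_add_inputs is not part of MindSpore, it uses NumPy arrays as stand-ins for Parameter/Tensor inputs, and it only approximates the operator's own validation (for example, it does not check that index values lie within [0, input_x.shape[axis])).

```python
import numpy as np

# Data types accepted for input_x, as listed under Inputs.
SUPPORTED_X_DTYPES = {np.dtype(t) for t in (np.float64, np.float32, np.float16,
                                            np.int32, np.int16, np.int8, np.uint8)}


def check_index_add_inputs(x, indices, y, axis):
    """Hypothetical pre-flight check mirroring the documented argument rules.

    NumPy arrays stand in for Parameter/Tensor inputs; this is not MindSpore code.
    """
    if x.dtype not in SUPPORTED_X_DTYPES:
        raise TypeError(f"unsupported input_x dtype: {x.dtype}")
    if not isinstance(indices, np.ndarray) or not isinstance(y, np.ndarray):
        raise TypeError("indices and input_y must both be arrays (Tensors)")
    # input_y must share input_x's data type (see Inputs).
    if y.dtype != x.dtype:
        raise TypeError("input_y must have the same dtype as input_x")
    if not 0 <= axis < x.ndim:
        raise ValueError(f"axis {axis} is out of range for rank {x.ndim}")
    if x.ndim != y.ndim:
        raise ValueError("input_x and input_y must have the same rank")
    # Safe to index y.shape[axis] here: y.ndim == x.ndim > axis.
    if indices.ndim != 1 or indices.shape[0] != y.shape[axis]:
        raise ValueError("indices must be 1-D with size input_y.shape[axis]")
    for d in range(x.ndim):
        if d != axis and x.shape[d] != y.shape[d]:
            raise ValueError(f"input_y differs from input_x at dimension {d}")


# Passes silently for the shapes used in the Examples section.
check_index_add_inputs(np.zeros((3, 3), np.float32), np.array([0, 2], np.int32),
                       np.ones((3, 2), np.float32), axis=1)
```

The checks run in an order where each one can rely on the previous ones, so reading y.shape[axis] never goes out of bounds.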
- Supported Platforms:
GPU
Examples
>>> import numpy as np
>>> import mindspore
>>> from mindspore import Tensor, Parameter, nn, ops
>>> class Net(nn.Cell):
...     def __init__(self):
...         super(Net, self).__init__()
...         self.index_add = ops.IndexAdd(axis=1)
...         self.input_x = Parameter(Tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), mindspore.float32))
...         self.indices = Tensor(np.array([0, 2]), mindspore.int32)
...
...     def construct(self, input_y):
...         return self.index_add(self.input_x, self.indices, input_y)
...
>>> input_y = Tensor(np.array([[0.5, 1.0], [1.0, 1.5], [2.0, 2.5]]), mindspore.float32)
>>> net = Net()
>>> # Columns 0 and 1 of input_y are added to columns 0 and 2 of input_x.
>>> output = net(input_y)
>>> print(output)
[[ 1.5  2.   4. ]
 [ 5.   5.   7.5]
 [ 9.   8.  11.5]]