mindspore.ops.scatter_max

mindspore.ops.scatter_max(input_x, indices, updates)

Updates the value of the input tensor by taking the element-wise maximum of its current values and the given updates, at the positions selected by indices. The operation returns input_x after the update is done, which makes the updated value convenient to use.

Parameters
  • input_x (Parameter) – The target tensor, with data type of Parameter. The shape is \((N,*)\), where \(*\) means any number of additional dimensions.

  • indices (Tensor) – The indices along the first dimension of input_x at which the max operation is applied. The data type must be mindspore.int32 or mindspore.int64.

  • updates (Tensor) – The tensor that is combined with input_x through the max operation. Its data type is the same as that of input_x, and its shape is indices.shape + input_x.shape[1:].

Returns

Tensor, the updated input_x, with the same type and shape as input_x.

Raises
  • TypeError – If the data type of indices is neither int32 nor int64.

  • ValueError – If the shape of updates is not equal to indices.shape + input_x.shape[1:].

  • RuntimeError – If input_x and updates have different data types, so that a data type conversion of the Parameter is required, but that conversion is not supported.

  • RuntimeError – On the Ascend platform, if input_x, indices, or updates has more than 8 dimensions.
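
The shape requirement behind the ValueError above can be checked before calling the operator. The following is a minimal sketch in plain Python, not part of MindSpore: the helper names expected_updates_shape and check_updates_shape are hypothetical and only illustrate the rule indices.shape + input_x.shape[1:].

```python
def expected_updates_shape(input_shape, indices_shape):
    """Hypothetical helper: the shape `updates` must have,
    i.e. indices.shape + input_x.shape[1:]."""
    return tuple(indices_shape) + tuple(input_shape[1:])


def check_updates_shape(input_shape, indices_shape, updates_shape):
    """Hypothetical helper: raise ValueError when the shape of
    `updates` does not match the documented rule."""
    expected = expected_updates_shape(input_shape, indices_shape)
    if tuple(updates_shape) != expected:
        raise ValueError(
            f"updates shape {tuple(updates_shape)} does not match "
            f"indices.shape + input_x.shape[1:] = {expected}"
        )
    return expected


# input_x of shape (2, 3) and indices of shape (2, 2)
# require updates of shape (2, 2, 3).
print(expected_updates_shape((2, 3), (2, 2)))
```

Running such a check on the shapes first gives an early, readable error on the Python side; it is an illustration of the rule, not a description of how MindSpore validates its inputs.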

Supported Platforms:

Ascend CPU GPU

Examples

>>> import numpy as np
>>> import mindspore
>>> from mindspore import Tensor, Parameter, ops
>>> input_x = Parameter(Tensor(np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]), mindspore.float32), name="input_x")
>>> indices = Tensor(np.array([[0, 0], [1, 1]]), mindspore.int32)
>>> updates = Tensor(np.ones([2, 2, 3]) * 88, mindspore.float32)
>>> output = ops.scatter_max(input_x, indices, updates)
>>> print(output)
[[88. 88. 88.]
 [88. 88. 88.]]
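
Because every update in the example above is 88, the output does not show how repeated indices and smaller update values behave. The following NumPy sketch mirrors the update rule described on this page with varied values. It is an illustrative reference only, not MindSpore's implementation: scatter_max_reference is a hypothetical name, and unlike the operator it returns a new array rather than updating its input.

```python
import numpy as np


def scatter_max_reference(input_x, indices, updates):
    """Hypothetical NumPy reference for the scatter-max rule.

    For every position `pos` in `indices`, the slice
    input_x[indices[pos]] becomes the element-wise maximum of its
    current value and updates[pos]. Returns a copy; the argument
    is left unchanged.
    """
    out = np.array(input_x, copy=True)
    indices = np.asarray(indices)
    updates = np.asarray(updates)
    if updates.shape != indices.shape + out.shape[1:]:
        raise ValueError("updates.shape must equal indices.shape + input_x.shape[1:]")
    for pos in np.ndindex(*indices.shape):
        row = indices[pos]
        out[row] = np.maximum(out[row], updates[pos])
    return out


input_x = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=np.float32)
indices = np.array([[0, 0], [1, 1]], dtype=np.int32)
updates = np.array(
    [[[0.0, 9.0, 1.0], [5.0, 0.0, 2.0]],
     [[7.0, 1.0, 1.0], [0.0, 8.0, 6.5]]],
    dtype=np.float32,
)

result = scatter_max_reference(input_x, indices, updates)
# Row 0 receives two updates and row 1 receives two updates;
# the rows become [5, 9, 3] and [7, 8, 6.5].
print(result)
```

Each element keeps the largest value seen among its original entry and all updates routed to its row, so an update smaller than the current value leaves that element unchanged.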