mindspore.ops.scatter_max
- mindspore.ops.scatter_max(input_x, indices, updates)
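Updates the values of input_x through the max operation, using the values in updates at the rows selected by indices. For each index position \((i, ..., j)\) of indices: \(\text{input\_x}[\text{indices}[i, ..., j], :] = \max(\text{input\_x}[\text{indices}[i, ..., j], :], \text{updates}[i, ..., j, :])\). This operation returns input_x after the update is done, which makes it convenient to use the updated value.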
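The update rule above can be sketched in plain NumPy as a reference for checking results. The helper `scatter_max_reference` is hypothetical and not part of MindSpore: it returns a new array instead of updating a Parameter in place, and it visits index positions one at a time, so a row that is indexed more than once keeps the largest value seen.

```python
import numpy as np


def scatter_max_reference(input_x, indices, updates):
    """Hypothetical NumPy reference for the scatter_max rule (not the MindSpore kernel)."""
    out = np.array(input_x, copy=True)  # work on a copy; MindSpore updates the Parameter itself
    indices = np.asarray(indices)
    updates = np.asarray(updates)
    # updates must have shape indices.shape + input_x.shape[1:]
    expected = indices.shape + out.shape[1:]
    if updates.shape != expected:
        raise ValueError(f"updates.shape {updates.shape} != expected {expected}")
    # Each position in indices selects one row of out; take the element-wise max
    # with the matching slice of updates. Sequential visits handle repeated rows.
    for pos in np.ndindex(*indices.shape):
        row = indices[pos]
        out[row] = np.maximum(out[row], updates[pos])
    return out


if __name__ == "__main__":
    x = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=np.float32)
    idx = np.array([0, 0, 1])  # row 0 is updated twice
    upd = np.array([[5, 0, 0], [0, 7, 0], [0, 0, 9]], dtype=np.float32)
    # row 0 -> max([1, 2, 3], [5, 0, 0], [0, 7, 0]) = [5, 7, 3]; row 1 -> [4, 5, 9]
    print(scatter_max_reference(x, idx, upd))
```

Because row 0 appears twice in `idx`, its result combines both update rows, which is the behavior a naive `out[idx] = np.maximum(out[idx], upd)` would get wrong (the last write would win).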
- Parameters
input_x (Parameter) – The target tensor, with data type of Parameter. The shape is \((N, *)\), where \(*\) means any number of additional dimensions.
indices (Tensor) – The indices at which to perform the max operation. Its data type must be mindspore.int32 or mindspore.int64.
updates (Tensor) – The tensor whose values are compared with input_x in the max operation. Its data type must be the same as input_x, and its shape must be indices.shape + input_x.shape[1:].
- Returns
Tensor, the updated input_x, with the same type and shape as input_x.
- Raises
TypeError – If indices is not an int32 or int64.
ValueError – If the shape of updates is not equal to indices.shape + input_x.shape[1:].
RuntimeError – If input_x and updates have different data types and the required data type conversion of the Parameter input_x is not supported.
RuntimeError – If, on the Ascend platform, input_x, indices or updates has more than 8 dimensions.
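The TypeError, ValueError and Ascend rank conditions listed above can be mirrored as a pre-flight check on NumPy arrays before calling the operator. The helper `check_scatter_max_args` and its `on_ascend` flag are hypothetical and not part of MindSpore; the Parameter dtype-conversion RuntimeError is not modeled, and the real operator may validate its inputs in a different order.

```python
import numpy as np


def check_scatter_max_args(input_x, indices, updates, on_ascend=False):
    """Hypothetical check mirroring the documented Raises conditions (not a MindSpore API)."""
    # TypeError: indices must hold int32 or int64 values
    if indices.dtype not in (np.int32, np.int64):
        raise TypeError(f"indices must be int32 or int64, got {indices.dtype}")
    # ValueError: updates.shape must equal indices.shape + input_x.shape[1:]
    expected = indices.shape + input_x.shape[1:]
    if updates.shape != expected:
        raise ValueError(f"updates.shape {updates.shape} != expected {expected}")
    # RuntimeError (Ascend only): no input may have more than 8 dimensions
    if on_ascend and max(a.ndim for a in (input_x, indices, updates)) > 8:
        raise RuntimeError("on Ascend, inputs must have at most 8 dimensions")


if __name__ == "__main__":
    x = np.zeros((2, 3), dtype=np.float32)
    idx = np.array([[0, 0], [1, 1]], dtype=np.int32)
    # indices.shape (2, 2) + input_x.shape[1:] (3,) -> updates must be (2, 2, 3)
    check_scatter_max_args(x, idx, np.ones((2, 2, 3), dtype=np.float32))
    try:
        check_scatter_max_args(x, idx, np.ones((2, 3), dtype=np.float32))
    except ValueError as err:
        print(err)
```

The shape rule is plain tuple concatenation, so the expected updates shape can be read off directly: the leading dimensions come from indices and the trailing ones from input_x without its first axis.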
- Supported Platforms:
Ascend
CPU
GPU
Examples
>>> import numpy as np
>>> import mindspore
>>> from mindspore import Tensor, Parameter, ops
>>> input_x = Parameter(Tensor(np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]), mindspore.float32), name="input_x")
>>> indices = Tensor(np.array([[0, 0], [1, 1]]), mindspore.int32)
>>> updates = Tensor(np.ones([2, 2, 3]) * 88, mindspore.float32)
>>> output = ops.scatter_max(input_x, indices, updates)
>>> print(output)
[[88. 88. 88.]
 [88. 88. 88.]]
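To cross-check this example without MindSpore installed, NumPy's `np.maximum.at` performs an unbuffered, in-place element-wise maximum over the rows selected by an index array, so every occurrence of a repeated index is applied. This is an equivalence sketch that assumes NumPy's semantics match the operator for this input; it says nothing about how MindSpore implements scatter_max internally.

```python
import numpy as np

input_x = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=np.float32)
indices = np.array([[0, 0], [1, 1]], dtype=np.int32)
updates = np.ones([2, 2, 3], dtype=np.float32) * 88

# In-place, unbuffered max: for every position pos of indices, with row = indices[pos],
# input_x[row] = max(input_x[row], updates[pos]). Rows 0 and 1 are each hit twice.
np.maximum.at(input_x, indices, updates)
print(input_x)
```

Every row is raised to 88, matching the MindSpore output above. Unlike the fancy-index assignment `input_x[indices] = ...`, `np.maximum.at` does not drop earlier writes to a repeated row.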