mindspore.ops.MaskedScatter

class mindspore.ops.MaskedScatter

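Updates the values of the input x with values from updates at the positions where mask is True. The elements of updates are written, in order, into the True positions. Positions where mask is False keep their original values from x.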
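The following minimal NumPy sketch makes these semantics concrete. It is not MindSpore's implementation. `masked_scatter_reference` is a hypothetical name, and the sketch assumes three things: mask is broadcast to the shape of x, True positions are filled in row-major order, and any surplus elements of updates are ignored.

```python
import numpy as np

def masked_scatter_reference(x, mask, updates):
    """Hypothetical NumPy sketch of MaskedScatter's documented behavior."""
    x = np.asarray(x)
    # ASSUMPTION: the boolean mask is broadcast to x's shape.
    mask_b = np.broadcast_to(np.asarray(mask, dtype=bool), x.shape)
    n_true = int(mask_b.sum())
    flat_updates = np.asarray(updates, dtype=x.dtype).ravel()
    if flat_updates.size < n_true:
        raise ValueError("updates has fewer elements than True values in mask")
    out = x.copy()
    # Boolean-index assignment fills the True positions in row-major order,
    # using the first n_true elements of updates (the rest are ignored).
    out[mask_b] = flat_updates[:n_true]
    return out

print(masked_scatter_reference([1., 2., 3., 4.], [True, True, False, True], [5., 6., 7.]))
# [5. 6. 3. 7.]
```

The result matches the output of the MindSpore example on this page. Positions where mask is False (here, the third element) keep their original value.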

Warning

This is an experimental API that is subject to change or deletion.

Inputs:
  • x (Tensor): The input Tensor to be updated.

  • mask (Tensor[bool]): The mask Tensor indicating which elements of x should be replaced. mask must have the same shape as x or be broadcastable to it.

  • updates (Tensor): The values to scatter into x. Its data type must be the same as that of x, and its number of elements must be greater than or equal to the number of True values in mask.
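Because mask may be broadcast, a broadcast mask can contain more True values than the mask tensor that was passed in. The sketch below assumes that the count that matters for updates is taken *after* broadcasting. The text above does not state this explicitly, and `required_updates` is a hypothetical helper.

```python
import numpy as np

def required_updates(x_shape, mask):
    """Hypothetical helper: minimum number of elements updates must have.

    ASSUMPTION: mask is broadcast to x_shape before True values are counted.
    """
    mask_b = np.broadcast_to(np.asarray(mask, dtype=bool), x_shape)
    return int(mask_b.sum())

mask = np.array([True, False, True])  # shape (3,), two True values
print(required_updates((2, 3), mask))  # broadcast to (2, 3): 4 True positions
```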

Outputs:

Tensor, with the same type and shape as x.

Raises
  • TypeError – If x, mask or updates is not a Tensor.

  • TypeError – If the data type of x is not supported.

  • TypeError – If dtype of mask is not bool.

  • TypeError – If the number of dimensions of x is less than that of mask.

  • ValueError – If mask cannot be broadcast to the shape of x.

  • ValueError – If the number of elements in updates is less than the number of True values in mask.
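The checks listed above can be mirrored in a small pre-flight validator, which raises the same exception types in the same order. This is a hypothetical NumPy-based sketch, not MindSpore's own validation code. NumPy arrays stand in for Tensors, and the set of supported dtypes for x is an assumption because this page does not list them.

```python
import numpy as np

# ASSUMPTION: illustrative set of supported dtypes for x (not from this page).
SUPPORTED_X_DTYPES = {np.float16, np.float32, np.float64,
                      np.int8, np.int16, np.int32, np.int64, np.uint8}

def check_masked_scatter_inputs(x, mask, updates):
    """Hypothetical validator mirroring the documented Raises list."""
    for name, t in (("x", x), ("mask", mask), ("updates", updates)):
        if not isinstance(t, np.ndarray):
            raise TypeError(f"{name} must be an array (stand-in for Tensor)")
    if x.dtype.type not in SUPPORTED_X_DTYPES:
        raise TypeError(f"unsupported data type for x: {x.dtype}")
    if mask.dtype != np.bool_:
        raise TypeError("mask must have dtype bool")
    if x.ndim < mask.ndim:
        raise TypeError("x has fewer dimensions than mask")
    try:
        mask_b = np.broadcast_to(mask, x.shape)
    except ValueError:
        raise ValueError("mask cannot be broadcast to the shape of x") from None
    if updates.size < int(mask_b.sum()):
        raise ValueError("updates has fewer elements than True values in mask")
```

A call with valid inputs returns None. Each failing condition raises the exception type listed for it above.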

Supported Platforms:

Ascend CPU

Examples

>>> import mindspore
>>> import numpy as np
>>> from mindspore import Tensor, ops
>>> input_x = Tensor(np.array([1., 2., 3., 4.]), mindspore.float32)
>>> mask = Tensor(np.array([True, True, False, True]), mindspore.bool_)
>>> updates = Tensor(np.array([5., 6., 7.]), mindspore.float32)
>>> output = ops.MaskedScatter()(input_x, mask, updates)
>>> print(output)
[5. 6. 3. 7.]