mindspore.ops.MaskedScatter

查看源文件
class mindspore.ops.MaskedScatter[源代码]

返回一个Tensor。根据 maskupdates 更新输入Tensor的值。

警告

这是一个实验性API,后续可能修改或删除。

输入:
  • x (Tensor) - 被更新输入Tensor。

  • mask (Tensor[bool]) - 指示应修改或替换哪些元素的掩码Tensor, maskx 的shape必须相等或者两者的shape可以广播。

  • updates (Tensor) - 要散播到目标张量或数组中的值。其数据类型与 x 相同。 updates 中的元素数量必须大于等于 mask 中的True元素的数量。

输出:

Tensor,其数据类型和shape与 x 相同。

异常:
  • TypeError - 如果 xmask 或者 updates 不是Tensor。

  • TypeError - 如果 x 的数据类型不被支持。

  • TypeError - 如果 mask 的dtype不是bool。

  • TypeError - 如果 x 的维度数小于 mask 的维度数。

  • ValueError - 如果 mask 不能广播到 x

  • ValueError - 如果 updates 中的元素数目小于 mask 中的True元素的数量。

支持平台:

Ascend CPU

样例:

>>> import mindspore
>>> import numpy as np
>>> from mindspore import Tensor, ops
>>> input_x = Tensor(np.array([1., 2., 3., 4.]), mindspore.float32)
>>> mask = Tensor(np.array([True, True, False, True]), mindspore.bool_)
>>> updates = Tensor(np.array([5., 6., 7.]), mindspore.float32)
>>> output = ops.MaskedScatter()(input_x, mask, updates)
>>> print(output)
[5. 6. 3. 7.]