mindspore.ops.sparse_segment_mean

查看源文件
mindspore.ops.sparse_segment_mean(x, indices, segment_ids)[源代码]

计算输入张量稀疏段的平均值。

outputi=jxindices[j]N

其中 N 为满足 segment_ids[j]==i 的元素个数,如果 segment_ids 中没有某个值 i ,则 output[i]=0

说明

  • 在CPU平台, segment_ids 必须是升序, indices 必须在[0, x.shape[0])范围内。

  • 在GPU平台,对于 segment_ids 无需排序,但可能导致未定义的行为。超出范围的 indices 将被忽略。

参数:
  • x (Tensor) - 输入tensor,至少为一维。

  • indices (Tensor) - 指定索引,一维tensor。

  • segment_ids (Tensor) - 一维tensor,必须按升序排列,可包含重复值。

返回:

Tensor

支持平台:

GPU CPU

样例:

>>> import mindspore
>>> x = mindspore.tensor([[0, 1, 2], [1, 2, 3], [3, 6, 7]], dtype=mindspore.float32)
>>> indices = mindspore.tensor([0, 1, 2], dtype=mindspore.int32)
>>> segment_ids = mindspore.tensor([1,2,2], dtype=mindspore.int32)
>>> out = mindspore.ops.sparse_segment_mean(x, indices, segment_ids)
>>> print(out)
[[0. 0. 0.]
 [0. 1. 2.]
 [2. 4. 5.]]