mindspore.ops.UniformCandidateSampler

查看源文件
class mindspore.ops.UniformCandidateSampler(num_true, num_sampled, unique, range_max, seed=0, remove_accidental_hits=False)[源代码]

使用均匀分布对一组类别进行采样。

此函数使用均匀分布从[0, range_max-1]中采样一组类(sampled_candidates)。如果 uniqueTrue ,则候选采样没有重复;如果 uniqueFalse ,则有重复。

更多参考详见 mindspore.ops.uniform_candidate_sampler()

警告

  • Ascend后端不支持随机数重现功能, seed 参数不起作用。

  • Ascend后端暂不支持动态shape场景。

参数:
  • num_true (int) - 每个训练样本的目标类数。

  • num_sampled (int) - 随机采样的类数。 sampled_candidates 的shape将为 num_sampled 。如果 uniqueTrue ,则 num_sampled 必须小于或等于 range_max

  • unique (bool) - 表示一个batch中的所有采样类是否唯一。

  • range_max (int) - 可能的类数,该值必须是非负的。

  • seed (int,可选) - 随机种子,该值必须是非负的。如果 seed 的值为 0 ,则 seed 的值将被随机生成的值替换。默认值: 0

  • remove_accidental_hits (bool,可选) - 表示是否移除accidental hit。accidental hit表示其中一个 true_classes 目标类匹配 sampled_candidates 采样类之一,设置为 True 表示移除等于目标类的采样类。默认值: False

输入:
  • true_classes (Tensor) - 输入Tensor,目标类,其shape为 \((batch\_size, num\_true)\)。 其元素值范围需要在 \([0, range\_max)\)

输出:
  • sampled_candidates (Tensor) - 候选采样与目标类之间不存在联系,其shape为 \((num\_sampled, )\)

  • true_expected_count (Tensor) - 在每组目标类的采样分布下的预期计数。Shape为 \((batch\_size, num\_true)\)

  • sampled_expected_count (Tensor) - 每个候选采样分布下的预期计数。Shape为 \((num\_sampled, )\)

支持平台:

Ascend GPU CPU

样例:

>>> import numpy as np
>>> from mindspore import Tensor, ops
>>> sampler = ops.UniformCandidateSampler(1, 3, False, 4, 1)
>>> output1, output2, output3 = sampler(Tensor(np.array([[1], [3], [4], [6], [3]], dtype=np.int64)))
>>> print(output1.shape)
(3,)
>>> print(output2.shape)
(5, 1)
>>> print(output3.shape)
(3,)