mindspore.ops.uniform_candidate_sampler

mindspore.ops.uniform_candidate_sampler(true_classes, num_true, num_sampled, unique, range_max, seed=0, remove_accidental_hits=False)

Uniform candidate sampler.

This function samples a set of classes (sampled_candidates) from [0, range_max-1] according to a uniform distribution. If unique=True, candidates are drawn without replacement; if unique=False, they are drawn with replacement.
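The sampling rule can be illustrated with NumPy. The sketch below is not MindSpore's implementation; it only mimics the documented behavior (uniform draws from [0, range_max-1], with or without replacement) using numpy.random.default_rng, and the helper name uniform_sample is hypothetical.

```python
import numpy as np

def uniform_sample(num_sampled, range_max, unique, seed=None):
    """Hypothetical helper: draw num_sampled class ids uniformly from [0, range_max - 1]."""
    if unique and num_sampled > range_max:
        # Without replacement there are only range_max distinct ids to draw from.
        raise ValueError("num_sampled must be <= range_max when unique=True")
    rng = np.random.default_rng(seed)
    # replace=False draws without replacement, so every id appears at most once.
    return rng.choice(range_max, size=num_sampled, replace=not unique)

candidates = uniform_sample(num_sampled=3, range_max=4, unique=True, seed=0)
print(candidates)
```

With unique=False the same helper can return more samples than there are classes, which is why the num_sampled <= range_max constraint applies only to unique=True.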

Parameters
  • true_classes (Tensor) – The target classes, with shape (batch_size, num_true).

  • num_true (int) – The number of target classes in each training example.

  • num_sampled (int) – The number of classes to randomly sample. The output sampled_candidates has shape (num_sampled,). If unique=True, num_sampled must be less than or equal to range_max.

  • unique (bool) – Whether all sampled classes in a batch are unique.

  • range_max (int) – The number of possible classes; must be positive.

  • seed (int) – The seed used for random number generation; must be non-negative. If seed is 0, it is replaced with a randomly generated value. Default: 0.

  • remove_accidental_hits (bool) – Whether to remove accidental hits, that is, sampled candidates that equal one of the true classes. Default: False.
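An accidental hit is commonly understood (for example, in TensorFlow's candidate samplers) as a sampled candidate that equals one of the true classes of a given example. The NumPy sketch below only locates such hits to illustrate the term; it does not reproduce what MindSpore does internally when remove_accidental_hits=True, and find_accidental_hits is a hypothetical name. The inputs reuse the data and the first output from the example at the end of this page.

```python
import numpy as np

def find_accidental_hits(true_classes, sampled_candidates):
    """Hypothetical helper: return (row, candidate_index) pairs where a sampled
    candidate coincides with one of that row's true classes."""
    true_classes = np.asarray(true_classes)              # shape (batch_size, num_true)
    sampled_candidates = np.asarray(sampled_candidates)  # shape (num_sampled,)
    # Broadcast to (batch_size, num_true, num_sampled) and compare element-wise.
    matches = true_classes[:, :, None] == sampled_candidates[None, None, :]
    # A candidate is a hit for a row if it matches any of that row's true classes.
    hit_mask = matches.any(axis=1)                        # shape (batch_size, num_sampled)
    rows, cols = np.nonzero(hit_mask)
    return list(zip(rows.tolist(), cols.tolist()))

true_classes = np.array([[1], [3], [4], [6], [3]])
sampled = np.array([0, 0, 3])
hits = find_accidental_hits(true_classes, sampled)
print(hits)  # [(1, 2), (4, 2)]
```

Rows 1 and 4 have true class 3, which is the candidate at index 2, so those two pairs are the accidental hits.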

Returns

  • sampled_candidates (Tensor) - The sampled classes, drawn independently of true_classes. Shape: (num_sampled,).

  • true_expected_count (Tensor) - The expected count of each class in true_classes under the sampling distribution. Shape: (batch_size, num_true).

  • sampled_expected_count (Tensor) - The expected count of each class in sampled_candidates under the sampling distribution. Shape: (num_sampled,).
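For unique=False, each of the num_sampled draws selects any given class with probability 1/range_max, so the expected count of every class is num_sampled / range_max; with num_sampled=3 and range_max=4 this gives the 0.75 values shown in the example at the end of this page. The NumPy sketch below computes both expected-count outputs under that assumption only. The formula used when unique=True is not specified here, and the helper name is hypothetical.

```python
import numpy as np

def expected_counts_with_replacement(true_classes, sampled_candidates, num_sampled, range_max):
    """Hypothetical helper for unique=False: every class has probability
    1 / range_max per draw, so its expected count over num_sampled draws is
    num_sampled / range_max."""
    p = num_sampled / range_max
    true_expected_count = np.full(np.shape(true_classes), p, dtype=np.float32)
    sampled_expected_count = np.full(np.shape(sampled_candidates), p, dtype=np.float32)
    return true_expected_count, sampled_expected_count

true_classes = np.array([[1], [3], [4], [6], [3]], dtype=np.int32)
sampled = np.array([0, 0, 3], dtype=np.int32)
t, s = expected_counts_with_replacement(true_classes, sampled, num_sampled=3, range_max=4)
print(t.ravel())  # [0.75 0.75 0.75 0.75 0.75]
print(s)          # [0.75 0.75 0.75]
```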

Raises
  • TypeError – If num_true or num_sampled is not an int.

  • TypeError – If unique or remove_accidental_hits is not a bool.

  • TypeError – If range_max or seed is not an int.

  • TypeError – If true_classes is not a Tensor.

Supported Platforms:

Ascend GPU CPU

Examples

>>> import numpy as np
>>> from mindspore import Tensor, ops
>>> data = Tensor(np.array([[1], [3], [4], [6], [3]], dtype=np.int32))
>>> output1, output2, output3 = ops.uniform_candidate_sampler(data, 1, 3, False, 4, 1)
>>> print(output1)
[0 0 3]
>>> print(output2)
[[0.75]
 [0.75]
 [0.75]
 [0.75]
 [0.75]]
>>> print(output3)
[0.75 0.75 0.75]