mindspore.ops.UniformCandidateSampler

class mindspore.ops.UniformCandidateSampler(num_true, num_sampled, unique, range_max, seed=0, remove_accidental_hits=False)

Uniform candidate sampler.

This operator samples a set of classes (sampled_candidates) from [0, range_max-1] according to a uniform distribution. If unique=True, candidates are drawn without replacement; if unique=False, they are drawn with replacement.
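The two sampling modes can be illustrated with a minimal pure-Python sketch. This is not MindSpore's GPU kernel: the function name `uniform_candidate_sample` is hypothetical, and replacing a zero seed with `random.randrange` is an assumption that only mirrors the documented rule that seed 0 means "use a randomly generated seed".

```python
import random


def uniform_candidate_sample(num_sampled, range_max, unique, seed=0):
    """Hypothetical sketch of uniform candidate sampling (not the MindSpore kernel).

    Draws num_sampled class IDs from [0, range_max - 1], each with
    probability 1 / range_max.
    """
    # Assumed handling of the documented rule: seed 0 -> pick a random seed.
    if seed == 0:
        seed = random.randrange(1, 2**31)
    rng = random.Random(seed)
    if unique:
        if num_sampled > range_max:
            raise ValueError("num_sampled must be <= range_max when unique=True")
        # Without replacement: every class appears at most once.
        return rng.sample(range(range_max), num_sampled)
    # With replacement: the same class may be drawn more than once.
    return [rng.randrange(range_max) for _ in range(num_sampled)]


print(uniform_candidate_sample(3, 4, unique=False, seed=1))
print(uniform_candidate_sample(3, 4, unique=True, seed=1))
```

With unique=True, `random.sample` guarantees distinct IDs, which is why num_sampled may not exceed range_max in that mode; with a fixed nonzero seed, repeated calls return the same candidates.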

Parameters
  • num_true (int) – The number of target classes in each training example.

  • num_sampled (int) – The number of classes to randomly sample. sampled_candidates has shape (num_sampled,). If unique=True, num_sampled must be less than or equal to range_max.

  • unique (bool) – Whether all sampled classes in a batch are unique, that is, whether sampling is done without replacement.

  • range_max (int) – The number of possible classes. Must be non-negative.

  • seed (int) – Random seed used for sampling. Must be non-negative. If seed is 0, it is replaced with a randomly generated value. Default: 0.

  • remove_accidental_hits (bool) – Whether to remove accidental hits, that is, sampled candidates that match one of the true classes. Default: False.

Inputs:
  • true_classes (Tensor) - The target classes, with shape (batch_size, num_true).

Outputs:
  • sampled_candidates (Tensor) - The sampled classes, drawn independently of the true classes. Shape: (num_sampled,).

  • true_expected_count (Tensor) - The expected count of each class in true_classes under the sampling distribution. Shape: (batch_size, num_true).

  • sampled_expected_count (Tensor) - The expected count of each class in sampled_candidates under the sampling distribution. Shape: (num_sampled,).
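For unique=False, every class has probability 1/range_max on each of the num_sampled draws, so its expected count is num_sampled / range_max; with num_sampled=3 and range_max=4 this gives 0.75. The sketch below uses a hypothetical helper and plain Python lists instead of Tensors to build both expected-count outputs with the documented shapes. It covers only unique=False; the estimate used for unique=True is not documented here and is not shown.

```python
def expected_counts_with_replacement(true_classes, num_sampled, range_max):
    """Hypothetical sketch: expected counts for unique=False only.

    Each class is drawn with probability 1 / range_max per draw, so over
    num_sampled independent draws its expected count is num_sampled / range_max,
    whichever class it is.
    """
    p = num_sampled / range_max
    # Same shape as true_classes: (batch_size, num_true).
    true_expected = [[p for _ in row] for row in true_classes]
    # One value per sampled candidate: (num_sampled,).
    sampled_expected = [p] * num_sampled
    return true_expected, sampled_expected


true_classes = [[1], [3], [4], [6], [3]]  # batch_size=5, num_true=1
t, s = expected_counts_with_replacement(true_classes, num_sampled=3, range_max=4)
print(t)  # [[0.75], [0.75], [0.75], [0.75], [0.75]]
print(s)  # [0.75, 0.75, 0.75]
```

Because the distribution is uniform, the expected counts do not depend on the actual class IDs, only on num_sampled and range_max.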

Raises
  • TypeError – If num_true or num_sampled is not an int.

  • TypeError – If unique or remove_accidental_hits is not a bool.

  • TypeError – If range_max or seed is not an int.

  • TypeError – If true_classes is not a Tensor.
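Each TypeError above is raised when any one of the named arguments has the wrong type, not only when both do. A minimal sketch of those constructor checks follows; the helper names `_is_int` and `check_sampler_args` are hypothetical, MindSpore's own validator is not reproduced, and rejecting bool for the int arguments is an assumption (in Python, bool is a subclass of int). The true_classes check happens at call time and is not modeled.

```python
def _is_int(value):
    # Assumption: bool is rejected even though it subclasses int in Python.
    return isinstance(value, int) and not isinstance(value, bool)


def check_sampler_args(num_true, num_sampled, unique, range_max, seed=0,
                       remove_accidental_hits=False):
    """Hypothetical sketch of the constructor type checks listed under Raises."""
    for name, value in (("num_true", num_true), ("num_sampled", num_sampled),
                        ("range_max", range_max), ("seed", seed)):
        if not _is_int(value):
            raise TypeError(f"{name} must be an int, got {type(value).__name__}")
    for name, value in (("unique", unique),
                        ("remove_accidental_hits", remove_accidental_hits)):
        if not isinstance(value, bool):
            raise TypeError(f"{name} must be a bool, got {type(value).__name__}")


check_sampler_args(1, 3, False, 4)  # valid: no exception
try:
    check_sampler_args(1, 3.0, False, 4)
except TypeError as e:
    print(e)  # num_sampled must be an int, got float
```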

Supported Platforms:

GPU

Examples

>>> import numpy as np
>>> from mindspore import Tensor, ops
>>> sampler = ops.UniformCandidateSampler(1, 3, False, 4)
>>> output1, output2, output3 = sampler(Tensor(np.array([[1], [3], [4], [6], [3]], dtype=np.int32)))
>>> # output1 is random, e.g. [1, 1, 3]; every expected count is num_sampled / range_max = 0.75
>>> print(output1.shape, output2.shape, output3.shape)
(3,) (5, 1) (3,)