mindspore.ops.uniform_candidate_sampler
- mindspore.ops.uniform_candidate_sampler(true_classes, num_true, num_sampled, unique, range_max, seed=0, remove_accidental_hits=False)
Uniform candidate sampler.
This function samples a set of classes (sampled_candidates) from [0, range_max-1] according to a uniform distribution. If unique=True, candidates are drawn without replacement; if unique=False, they are drawn with replacement.
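The sampling rule can be illustrated with a small NumPy sketch. It is not MindSpore's implementation and does not reproduce its random stream; the helper name `uniform_candidate_sample` is hypothetical. It only shows the documented semantics: uniform draws from [0, range_max-1], with or without replacement depending on `unique`, the documented argument constraints, and the convention that seed=0 means "use a random seed".

```python
import numpy as np

def uniform_candidate_sample(num_sampled, unique, range_max, seed=0):
    """Hypothetical NumPy sketch of uniform candidate sampling.

    Draws num_sampled class IDs uniformly from [0, range_max - 1].
    This does NOT reproduce MindSpore's random number stream.
    """
    if range_max <= 0:
        raise ValueError("range_max must be positive")
    if seed < 0:
        raise ValueError("seed must be non-negative")
    if unique and num_sampled > range_max:
        raise ValueError("with unique=True, num_sampled must be <= range_max")
    # Mirror the documented convention: seed=0 means "pick a random seed".
    rng = np.random.default_rng(None if seed == 0 else seed)
    # replace=False draws without replacement (unique=True);
    # replace=True draws with replacement (unique=False).
    return rng.choice(range_max, size=num_sampled, replace=not unique)

candidates = uniform_candidate_sample(num_sampled=3, unique=True, range_max=4, seed=1)
print(candidates.shape)  # (3,)
```

With unique=True the three candidates are distinct; with unique=False the same class may appear more than once.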
- Parameters
true_classes (Tensor) – The target classes, a Tensor of shape \((batch\_size, num\_true)\).
num_true (int) – The number of target classes in each training example.
num_sampled (int) – The number of classes to randomly sample. sampled_candidates has shape \((num\_sampled, )\). If unique=True, num_sampled must be less than or equal to range_max.
unique (bool) – Whether all sampled classes in a batch are unique.
range_max (int) – The number of possible classes, must be positive.
seed (int) – Seed used for random number generation; must be non-negative. If seed is 0, it is replaced with a randomly generated value. Default: 0.
remove_accidental_hits (bool) – Whether to remove accidental hits. An accidental hit occurs when one of the true classes matches one of the sampled classes. Set to True to remove cases where a true class is accidentally sampled as a candidate. Default: False.
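The definition of an accidental hit can be made concrete with a NumPy sketch that flags, for each row of true_classes, which sampled candidates coincide with one of that row's true classes. The helper `accidental_hit_mask` is hypothetical; it only illustrates the definition above and makes no claim about how MindSpore removes hits internally when remove_accidental_hits=True.

```python
import numpy as np

def accidental_hit_mask(true_classes, sampled_candidates):
    """Return a (batch_size, num_sampled) bool mask; True marks an accidental hit.

    An accidental hit is a sampled candidate that equals one of the
    true classes of that row.
    """
    true_classes = np.asarray(true_classes)       # (batch_size, num_true)
    sampled = np.asarray(sampled_candidates)      # (num_sampled,)
    # Broadcast (batch, num_true, 1) == (1, 1, num_sampled)
    # -> (batch, num_true, num_sampled)
    hits = true_classes[:, :, None] == sampled[None, None, :]
    # A candidate is a hit for a row if it matches any true class in that row.
    return hits.any(axis=1)

true_classes = np.array([[1], [3], [4], [6], [3]])
sampled = np.array([0, 3, 6])
print(accidental_hit_mask(true_classes, sampled).astype(int))
```

Here candidate 3 is a hit for the second and fifth rows, and candidate 6 is a hit for the fourth row.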
- Returns
sampled_candidates (Tensor) - The sampled candidates, drawn independently of the true classes. Shape: \((num\_sampled, )\).
true_expected_count (Tensor) - The expected count of each entry of true_classes under the sampling distribution. Shape: \((batch\_size, num\_true)\).
sampled_expected_count (Tensor) - The expected count of each entry of sampled_candidates under the sampling distribution. Shape: \((num\_sampled, )\).
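For the with-replacement case (unique=False), the expected counts follow directly from the uniform distribution: each draw selects any class with probability 1/range_max, so over num_sampled draws every class has expected count num_sampled / range_max. The NumPy sketch below (hypothetical helper `uniform_expected_counts`) builds outputs of the documented shapes under that assumption. It does not cover unique=True, where the value depends on how many draws the implementation needed to collect num_sampled distinct classes, and it does not claim to match MindSpore's output dtype.

```python
import numpy as np

def uniform_expected_counts(true_classes, sampled_candidates, num_sampled, range_max):
    """Expected counts for uniform sampling WITH replacement (unique=False only).

    Each draw picks any class with probability 1 / range_max, so over
    num_sampled draws every class has expected count num_sampled / range_max.
    """
    expected = num_sampled / range_max
    # Same shape as true_classes: (batch_size, num_true)
    true_expected_count = np.full(np.shape(true_classes), expected)
    # Same shape as sampled_candidates: (num_sampled,)
    sampled_expected_count = np.full(np.shape(sampled_candidates), expected)
    return true_expected_count, sampled_expected_count

true_classes = np.array([[1], [3], [4], [6], [3]])
sampled = np.array([2, 0, 2])
t, s = uniform_expected_counts(true_classes, sampled, num_sampled=3, range_max=4)
print(t.shape, s.shape)  # (5, 1) (3,)
print(s)                 # [0.75 0.75 0.75]
```

Because the distribution is uniform, every entry is the same value; only the shapes differ between the two outputs.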
- Raises
- Supported Platforms:
Ascend
GPU
CPU
Examples
>>> import numpy as np
>>> from mindspore import Tensor, ops
>>> data = Tensor(np.array([[1], [3], [4], [6], [3]], dtype=np.int64))
>>> output1, output2, output3 = ops.uniform_candidate_sampler(data, 1, 3, False, 4, 1)
>>> print(output1.shape)
(3,)
>>> print(output2.shape)
(5, 1)
>>> print(output3.shape)
(3,)