mindspore.ops.UniformCandidateSampler
- class mindspore.ops.UniformCandidateSampler(num_true, num_sampled, unique, range_max, seed=0, remove_accidental_hits=False)
Uniform candidate sampler.
This operator samples a set of classes (sampled_candidates) from [0, range_max-1] based on a uniform distribution.
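The sampling step can be pictured with the NumPy sketch below. It is an illustration under stated assumptions, not MindSpore's kernel: the helper name uniform_sample is hypothetical, and its seed argument follows NumPy's convention (None means nondeterministic) rather than MindSpore's seed=0 rule. With unique=True the classes are drawn without replacement, which is why num_sampled may not exceed range_max.

```python
import numpy as np

def uniform_sample(range_max, num_sampled, unique, seed=None):
    """Hypothetical helper: draw num_sampled class IDs uniformly from [0, range_max - 1].

    Illustrative NumPy sketch only; not MindSpore's implementation.
    """
    if unique and num_sampled > range_max:
        raise ValueError("num_sampled must be <= range_max when unique=True")
    rng = np.random.default_rng(seed)
    # replace=False yields distinct IDs (unique=True); replace=True allows repeats.
    return rng.choice(range_max, size=num_sampled, replace=not unique)

sampled = uniform_sample(range_max=4, num_sampled=3, unique=True, seed=1)
print(sampled.shape)  # (3,)
```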
Refer to mindspore.ops.uniform_candidate_sampler() for more details.
- Parameters
num_true (int) – The number of target classes in each training example.
num_sampled (int) – The number of classes to randomly sample. sampled_candidates will have shape (num_sampled,). If unique=True, num_sampled must be less than or equal to range_max.
unique (bool) – Whether all sampled classes in a batch are unique.
range_max (int) – The number of possible classes, must be non-negative.
seed (int, optional) – Seed used for random number generation, must be non-negative. If seed is 0, it is replaced with a randomly generated value. Default: 0.
remove_accidental_hits (bool, optional) – Whether to remove accidental hits. An accidental hit occurs when one of the true classes matches one of the sampled classes. Set it to True to remove cases where a true class is accidentally sampled as a candidate. Default: False.
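The accidental-hit definition above can be made concrete with a small NumPy sketch. It only locates hits; it does not reproduce how MindSpore removes them when remove_accidental_hits=True. The helper name accidental_hit_mask is hypothetical, and it assumes true_classes has shape (batch_size, num_true) and sampled_candidates has shape (num_sampled,), as in the Inputs and Outputs sections.

```python
import numpy as np

def accidental_hit_mask(true_classes, sampled_candidates):
    """Hypothetical helper returning a (batch_size, num_sampled) boolean mask.

    mask[i, j] is True when sampled_candidates[j] equals one of the
    true classes of example i, i.e. an accidental hit.
    """
    true_classes = np.asarray(true_classes)    # (batch_size, num_true)
    sampled = np.asarray(sampled_candidates)   # (num_sampled,)
    # Compare via broadcasting to (batch_size, num_true, num_sampled),
    # then reduce over the num_true axis.
    return (true_classes[:, :, None] == sampled[None, None, :]).any(axis=1)

mask = accidental_hit_mask([[1], [3]], [0, 1, 3])
print(mask.tolist())  # [[False, True, False], [False, False, True]]
```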
- Inputs:
true_classes (Tensor) - The target classes, with shape \((batch\_size, num\_true)\).
- Outputs:
sampled_candidates (Tensor) - The sampled classes, which are drawn independently of the true classes. Shape: \((num\_sampled, )\).
true_expected_count (Tensor) - The expected counts under the sampling distribution of each of true_classes. Shape: \((batch\_size, num\_true)\).
sampled_expected_count (Tensor) - The expected counts under the sampling distribution of each of sampled_candidates. Shape: \((num\_sampled, )\).
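For intuition about the two expected-count outputs, the sketch below assumes sampling with replacement (unique=False): each of the num_sampled draws picks a given class with probability 1 / range_max, so every class has the same expected count num_sampled / range_max. This is a standard property of uniform sampling, not a statement of MindSpore's exact computation, and the unique=True case is deliberately not modeled. The helper name uniform_expected_counts is hypothetical; the output shapes mirror those listed above.

```python
import numpy as np

def uniform_expected_counts(true_classes, sampled_candidates, range_max):
    """Hypothetical helper: expected counts for uniform sampling WITH replacement.

    Assumes unique=False, so each class's expected count is
    num_sampled / range_max.
    """
    true_classes = np.asarray(true_classes)      # (batch_size, num_true)
    sampled = np.asarray(sampled_candidates)     # (num_sampled,)
    expected = sampled.shape[0] / range_max
    true_expected_count = np.full(true_classes.shape, expected, dtype=np.float32)
    sampled_expected_count = np.full(sampled.shape, expected, dtype=np.float32)
    return true_expected_count, sampled_expected_count

true_ec, sampled_ec = uniform_expected_counts([[1], [3], [4], [6], [3]], [0, 2, 2], range_max=4)
print(true_ec.shape, sampled_ec.shape)  # (5, 1) (3,)
print(sampled_ec.tolist())  # [0.75, 0.75, 0.75]
```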
- Supported Platforms:
Ascend
GPU
CPU
Examples
>>> import numpy as np
>>> from mindspore import Tensor, ops
>>> # Positional arguments: num_true=1, num_sampled=3, unique=False, range_max=4, seed=1
>>> sampler = ops.UniformCandidateSampler(1, 3, False, 4, 1)
>>> output1, output2, output3 = sampler(Tensor(np.array([[1], [3], [4], [6], [3]], dtype=np.int64)))
>>> print(output1.shape)
(3,)
>>> print(output2.shape)
(5, 1)
>>> print(output3.shape)
(3,)