Class WeightedRandomSampler

Inheritance Relationships

Base Type

  • public mindspore::dataset::Sampler

Class Documentation

class WeightedRandomSampler : public mindspore::dataset::Sampler

A class to represent a Weighted Random Sampler in the data pipeline.

Note

Samples indices from [0, len(weights) - 1] at random, where index i is chosen with probability proportional to weights[i].
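Because the weights are relative, only their ratios matter. The sketch below makes this concrete by converting raw weights into the per-index probabilities w[i] / sum(w) for a single draw. `SelectionProbabilities` is a hypothetical helper written for illustration. It is not part of the MindSpore API and assumes non-negative weights with a positive sum.

```cpp
#include <numeric>
#include <stdexcept>
#include <vector>

// Hypothetical helper (not part of MindSpore): converts raw weights into the
// probability that each index is selected on one draw, i.e. w[i] / sum(w).
std::vector<double> SelectionProbabilities(const std::vector<double> &weights) {
  const double total = std::accumulate(weights.begin(), weights.end(), 0.0);
  if (weights.empty() || total <= 0.0) {
    throw std::invalid_argument("weights must be non-empty with a positive sum");
  }
  std::vector<double> probs;
  probs.reserve(weights.size());
  for (double w : weights) {
    if (w < 0.0) throw std::invalid_argument("weights must be non-negative");
    probs.push_back(w / total);  // relative weight -> probability
  }
  return probs;
}
```

For example, weights {1.0, 3.0} yield probabilities {0.25, 0.75}, the same as weights {0.25, 0.75} or {10.0, 30.0}.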

Public Functions

explicit WeightedRandomSampler(const std::vector<double> &weights, int64_t num_samples = 0, bool replacement = true)

Constructor.

Parameters
  • weights[in] A vector of weights, not necessarily summing to 1.

  • num_samples[in] The number of samples to draw (default=0, which draws all samples).

  • replacement[in] If true, put the sample ID back for the next draw (default=true).
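The difference between drawing with and without replacement can be illustrated with a standalone sketch built on std::discrete_distribution. This is not MindSpore's implementation. `SampleIndices` is a hypothetical function, and the sketch assumes that num_samples == 0 means "draw len(weights) indices" and that sampling without replacement zeroes the weight of each chosen index before the next draw.

```cpp
#include <cstddef>
#include <cstdint>
#include <random>
#include <stdexcept>
#include <vector>

// Hypothetical illustration (not MindSpore's implementation): draws indices in
// [0, weights.size() - 1] with probability proportional to their weights.
// Assumption: num_samples == 0 means "draw weights.size() indices".
std::vector<int64_t> SampleIndices(std::vector<double> weights, int64_t num_samples,
                                   bool replacement, std::uint32_t seed) {
  if (num_samples < 0) throw std::invalid_argument("num_samples must be >= 0");
  int64_t positive = 0;
  for (double w : weights) {
    if (w < 0.0) throw std::invalid_argument("weights must be non-negative");
    if (w > 0.0) ++positive;
  }
  if (positive == 0) throw std::invalid_argument("at least one weight must be positive");
  if (num_samples == 0) num_samples = static_cast<int64_t>(weights.size());

  std::mt19937 rng(seed);
  std::vector<int64_t> result;
  result.reserve(static_cast<std::size_t>(num_samples));

  if (replacement) {
    // One fixed distribution: every draw sees the same probabilities,
    // so the same index may be returned more than once.
    std::discrete_distribution<int64_t> dist(weights.begin(), weights.end());
    for (int64_t i = 0; i < num_samples; ++i) result.push_back(dist(rng));
    return result;
  }

  // Without replacement, each index can be drawn at most once, so we cannot
  // draw more indices than there are non-zero weights.
  if (num_samples > positive) {
    throw std::invalid_argument("not enough non-zero weights to sample without replacement");
  }
  for (int64_t i = 0; i < num_samples; ++i) {
    // Rebuild the distribution from the remaining weights for every draw.
    std::discrete_distribution<int64_t> dist(weights.begin(), weights.end());
    const int64_t idx = dist(rng);
    result.push_back(idx);
    weights[static_cast<std::size_t>(idx)] = 0.0;  // remove idx from future draws
  }
  return result;
}
```

Zeroing a chosen weight and renormalizing is one common way to sample without replacement. It is simple but rebuilds the distribution on every draw, which costs O(len(weights)) per sample.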

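Example
#include <memory>
#include <string>
#include <vector>
/* also include the MindSpore dataset headers that declare ImageFolder and WeightedRandomSampler */
using namespace mindspore::dataset;
/* creates a WeightedRandomSampler that will sample 4 elements without replacement */
std::vector<double> weights = {0.9, 0.8, 0.68, 0.7, 0.71, 0.6, 0.5, 0.4, 0.3, 0.5, 0.2, 0.1};
std::shared_ptr<Sampler> sampler = std::make_shared<WeightedRandomSampler>(weights, 4, false);
std::string folder_path = "/path/to/image/folder";
std::shared_ptr<Dataset> ds = ImageFolder(folder_path, false, sampler);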

~WeightedRandomSampler() = default

Destructor.