Program Listing for File transforms.h
↰ Return to documentation for file (include/transforms.h)
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_DATASET_TRANSFORMS_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_DATASET_TRANSFORMS_H_
#include <cstdint>
#include <functional>
#include <map>
#include <memory>
#include <string>
#include <vector>
#include "include/api/dual_abi_helper.h"
#include "include/api/status.h"
#include "include/api/types.h"
#include "include/dataset/constants.h"
namespace mindspore {
namespace dataset {
// Forward declaration of the type that TensorTransform::Parse() returns a shared pointer to.
class TensorOperation;
// The two groups of forward declarations below are needed so that class TensorTransform can name these
// classes in its friend declarations.
namespace transforms {
class Compose;
class RandomApply;
class RandomChoice;
} // namespace transforms
namespace vision {
class BoundingBoxAugment;
class RandomSelectSubpolicy;
class UniformAugment;
} // namespace vision
// Abstract class to represent a tensor transform operation in the data pipeline.
// Both Parse() overloads are protected, so only subclasses and the classes befriended below can call them.
class TensorTransform : public std::enable_shared_from_this<TensorTransform> {
  friend class Dataset;
  friend class Execute;
  friend class transforms::Compose;
  friend class transforms::RandomApply;
  friend class transforms::RandomChoice;
  friend class vision::BoundingBoxAugment;
  friend class vision::RandomSelectSubpolicy;
  friend class vision::UniformAugment;

 public:
  // Default constructor.
  TensorTransform() = default;

  // Virtual destructor: this class is a polymorphic base.
  virtual ~TensorTransform() = default;

 protected:
  // Returns a shared pointer to the TensorOperation that corresponds to this transform.
  // Pure virtual: every concrete transform has to provide it.
  virtual std::shared_ptr<TensorOperation> Parse() = 0;

  // Overload of Parse() that additionally receives a MapTargetDevice value.
  // This base implementation ignores `env` and always returns nullptr; a subclass may override it.
  // NOTE(review): the meaning of the individual MapTargetDevice values is defined outside this header.
  virtual std::shared_ptr<TensorOperation> Parse(const MapTargetDevice &env) {
    // Deliberately unused here. The cast keeps -Wunused-parameter quiet in every file that includes this header.
    (void)env;
    return nullptr;
  }
};
class Slice {
public:
Slice() : start_(0), stop_(0), step_(0) {}
Slice(dsize_t start, dsize_t stop, dsize_t step) : start_(start), stop_(stop), step_(step) {}
Slice(dsize_t start, dsize_t stop) : start_(start), stop_(stop), step_(1) {}
explicit Slice(dsize_t stop) : start_(0), stop_(stop), step_(1) {}
Slice(Slice const &slice) = default;
~Slice() = default;
bool valid() const { return step_ != 0; }
dsize_t start_;
dsize_t stop_;
dsize_t step_;
};
// One slicing choice, expressed in one of three forms: a boolean `all` flag, an explicit list of indices, or a
// Slice object. Each constructor fills exactly one of the three members and leaves the other two at their
// defaults (empty index list, default-constructed Slice, all_ == false).
class SliceOption {
 public:
  // Stores the flag in all_.
  explicit SliceOption(bool all) : all_(all) {}

  // Stores a copy of the index list in indices_.
  explicit SliceOption(const std::vector<dsize_t> &indices) : indices_(indices) {}

  // Stores a copy of the Slice in slice_.
  explicit SliceOption(const Slice &slice) : slice_(slice) {}

  SliceOption(SliceOption const &slice) = default;
  ~SliceOption() = default;

  // Only one of the following members will be valid.

  // Given indices to slice the Tensor. Empty by default.
  std::vector<dsize_t> indices_{};

  // Slice object. Its start, stop and step are all 0 if invalid.
  Slice slice_;

  // Flag supplied through the bool constructor; false by default.
  bool all_{false};
};
// Transform operations for performing data transformation.
namespace transforms {
// Transform that is constructed from a vector of other TensorTransform objects.
// The vector may hold raw pointers, shared pointers or reference wrappers; there is one constructor per form.
// NOTE(review): going by the name, the listed transforms are presumably applied one after another. The
// implementation is not part of this header, so confirm there.
class Compose final : public TensorTransform {
 public:
  // Takes the transforms as raw pointers.
  explicit Compose(const std::vector<TensorTransform *> &transforms);

  // Takes the transforms as shared pointers.
  explicit Compose(const std::vector<std::shared_ptr<TensorTransform>> &transforms);

  // Takes the transforms as reference wrappers.
  explicit Compose(const std::vector<std::reference_wrapper<TensorTransform>> &transforms);

  ~Compose() = default;

 protected:
  // Overrides TensorTransform::Parse(); returns a shared pointer to a TensorOperation.
  std::shared_ptr<TensorOperation> Parse() override;

 private:
  // Data is only declared here; its definition lives outside this header.
  struct Data;
  std::shared_ptr<Data> data_;
};
// Transform configured with an axis and two optional tensors, `prepend` and `append`.
// NOTE(review): presumably the input is joined with `prepend` in front and `append` behind along `axis`. The
// implementation is not part of this header, so confirm there.
class Concatenate final : public TensorTransform {
 public:
  // axis defaults to 0; prepend and append each default to a default-constructed MSTensor.
  explicit Concatenate(int8_t axis = 0, const MSTensor &prepend = {}, const MSTensor &append = {});

  ~Concatenate() = default;

 protected:
  // Overrides TensorTransform::Parse(); returns a shared pointer to a TensorOperation.
  std::shared_ptr<TensorOperation> Parse() override;

 private:
  // Data is only declared here; its definition lives outside this header.
  struct Data;
  std::shared_ptr<Data> data_;
};
// Transform that takes no constructor arguments and declares no data members in this header.
// NOTE(review): presumably duplicates its input tensor. The implementation is not part of this header, so
// confirm there.
class Duplicate final : public TensorTransform {
 public:
  // Declared here, defined outside this header.
  Duplicate();

  ~Duplicate() = default;

 protected:
  // Overrides TensorTransform::Parse(); returns a shared pointer to a TensorOperation.
  std::shared_ptr<TensorOperation> Parse() override;
};
// Transform configured with a single tensor, `fill_value`.
// NOTE(review): presumably every element of the input is set to that value. The implementation is not part of
// this header, so confirm there.
class Fill final : public TensorTransform {
 public:
  // fill_value is required; there is no default.
  explicit Fill(const MSTensor &fill_value);

  ~Fill() = default;

 protected:
  // Overrides TensorTransform::Parse(); returns a shared pointer to a TensorOperation.
  std::shared_ptr<TensorOperation> Parse() override;

 private:
  // Data is only declared here; its definition lives outside this header.
  struct Data;
  std::shared_ptr<Data> data_;
};
// Transform configured with a relational operator, a constant tensor and an output data type.
// NOTE(review): presumably each input element is compared with `constant` using `op` and the result is stored as
// `ms_type`. The implementation is not part of this header, so confirm there.
class Mask final : public TensorTransform {
 public:
  // op and constant are required; ms_type defaults to DataType::kNumberTypeBool.
  explicit Mask(RelationalOp op, const MSTensor &constant,
                mindspore::DataType ms_type = mindspore::DataType(mindspore::DataType::kNumberTypeBool));

  ~Mask() = default;

 protected:
  // Overrides TensorTransform::Parse(); returns a shared pointer to a TensorOperation.
  std::shared_ptr<TensorOperation> Parse() override;

 private:
  // Data is only declared here; its definition lives outside this header.
  struct Data;
  std::shared_ptr<Data> data_;
};
// Transform configured with a number of classes.
// NOTE(review): presumably converts labels into one-hot form with `num_classes` entries. The implementation is
// not part of this header, so confirm there.
class OneHot final : public TensorTransform {
 public:
  // num_classes is required; there is no default.
  explicit OneHot(int32_t num_classes);

  ~OneHot() = default;

 protected:
  // Overrides TensorTransform::Parse(); returns a shared pointer to a TensorOperation.
  std::shared_ptr<TensorOperation> Parse() override;

 private:
  // Data is only declared here; its definition lives outside this header.
  struct Data;
  std::shared_ptr<Data> data_;
};
// Transform configured with a target shape and an optional padding value.
// NOTE(review): presumably pads the input at its end up to `pad_shape` using `pad_value`. The implementation is
// not part of this header, so confirm there.
class PadEnd final : public TensorTransform {
 public:
  // pad_shape is required; pad_value defaults to a default-constructed MSTensor.
  explicit PadEnd(const std::vector<dsize_t> &pad_shape, const MSTensor &pad_value = {});

  ~PadEnd() = default;

 protected:
  // Overrides TensorTransform::Parse(); returns a shared pointer to a TensorOperation.
  std::shared_ptr<TensorOperation> Parse() override;

 private:
  // Data is only declared here; its definition lives outside this header.
  struct Data;
  std::shared_ptr<Data> data_;
};
// Transform that is constructed from a vector of other TensorTransform objects plus a probability.
// The vector may hold raw pointers, shared pointers or reference wrappers; there is one constructor per form,
// and `prob` defaults to 0.5 in all three.
// NOTE(review): presumably the listed transforms are applied with probability `prob`. The implementation is not
// part of this header, so confirm there.
class RandomApply final : public TensorTransform {
 public:
  // Takes the transforms as raw pointers.
  explicit RandomApply(const std::vector<TensorTransform *> &transforms, double prob = 0.5);

  // Takes the transforms as shared pointers.
  explicit RandomApply(const std::vector<std::shared_ptr<TensorTransform>> &transforms, double prob = 0.5);

  // Takes the transforms as reference wrappers.
  explicit RandomApply(const std::vector<std::reference_wrapper<TensorTransform>> &transforms, double prob = 0.5);

  ~RandomApply() = default;

 protected:
  // Overrides TensorTransform::Parse(); returns a shared pointer to a TensorOperation.
  std::shared_ptr<TensorOperation> Parse() override;

 private:
  // Data is only declared here; its definition lives outside this header.
  struct Data;
  std::shared_ptr<Data> data_;
};
// Transform that is constructed from a vector of other TensorTransform objects.
// The vector may hold raw pointers, shared pointers or reference wrappers; there is one constructor per form.
// NOTE(review): presumably one of the listed transforms is picked at random for each application. The
// implementation is not part of this header, so confirm there.
class RandomChoice final : public TensorTransform {
 public:
  // Takes the transforms as raw pointers.
  explicit RandomChoice(const std::vector<TensorTransform *> &transforms);

  // Takes the transforms as shared pointers.
  explicit RandomChoice(const std::vector<std::shared_ptr<TensorTransform>> &transforms);

  // Takes the transforms as reference wrappers.
  explicit RandomChoice(const std::vector<std::reference_wrapper<TensorTransform>> &transforms);

  ~RandomChoice() = default;

 protected:
  // Overrides TensorTransform::Parse(); returns a shared pointer to a TensorOperation.
  std::shared_ptr<TensorOperation> Parse() override;

 private:
  // Data is only declared here; its definition lives outside this header.
  struct Data;
  std::shared_ptr<Data> data_;
};
// Transform configured with a vector of SliceOption objects.
// This is transforms::Slice, a different class from the plain start/stop/step Slice declared in the enclosing
// namespace.
// NOTE(review): presumably each SliceOption applies to one dimension of the input. The implementation is not
// part of this header, so confirm there.
class Slice final : public TensorTransform {
 public:
  // slice_input is required; there is no default.
  explicit Slice(const std::vector<SliceOption> &slice_input);

  ~Slice() = default;

 protected:
  // Overrides TensorTransform::Parse(); returns a shared pointer to a TensorOperation.
  std::shared_ptr<TensorOperation> Parse() override;

 private:
  // Data is only declared here; its definition lives outside this header.
  struct Data;
  std::shared_ptr<Data> data_;
};
// Transform configured with a mindspore::DataType.
// NOTE(review): presumably casts the input tensor to `data_type`. The implementation is not part of this header,
// so confirm there.
class TypeCast final : public TensorTransform {
 public:
  // data_type is required; there is no default.
  explicit TypeCast(mindspore::DataType data_type);

  ~TypeCast() = default;

 protected:
  // Overrides TensorTransform::Parse(); returns a shared pointer to a TensorOperation.
  std::shared_ptr<TensorOperation> Parse() override;

 private:
  // Data is only declared here; its definition lives outside this header.
  struct Data;
  std::shared_ptr<Data> data_;
};
// Transform that takes no constructor arguments and declares no data members in this header.
// NOTE(review): presumably reduces the input to its distinct elements. The implementation is not part of this
// header, so confirm there.
class Unique final : public TensorTransform {
 public:
  // Declared here, defined outside this header.
  Unique();

  ~Unique() = default;

 protected:
  // Overrides TensorTransform::Parse(); returns a shared pointer to a TensorOperation.
  std::shared_ptr<TensorOperation> Parse() override;
};
} // namespace transforms
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_DATASET_TRANSFORMS_H_