Program Listing for File transforms.h

Return to documentation for file (include/transforms.h)

#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_DATASET_TRANSFORMS_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_DATASET_TRANSFORMS_H_

#include <functional>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "include/api/dual_abi_helper.h"
#include "include/api/status.h"
#include "include/api/types.h"
#include "include/dataset/constants.h"

namespace mindspore {
namespace dataset {

// TensorOperation is only named through std::shared_ptr in this header, so a
// forward declaration is enough here.
class TensorOperation;

// TensorTransform (below) names these classes as friends using qualified names,
// which requires each of them to be declared before that point.
namespace vision {
class UniformAugment;
class RandomSelectSubpolicy;
class BoundingBoxAugment;
}  // namespace vision

namespace transforms {
class RandomChoice;
class RandomApply;
class Compose;
}  // namespace transforms

// Abstract class to represent a tensor transform operation in the data pipeline.
//
// Derived transforms implement Parse(), which returns the TensorOperation for the
// transform. Parse() is protected; the friend classes listed below are granted
// access to it.
class TensorTransform : public std::enable_shared_from_this<TensorTransform> {
  friend class Dataset;
  friend class Execute;
  friend class transforms::Compose;
  friend class transforms::RandomApply;
  friend class transforms::RandomChoice;
  friend class vision::BoundingBoxAugment;
  friend class vision::RandomSelectSubpolicy;
  friend class vision::UniformAugment;

 public:
  TensorTransform() = default;

  // Virtual because TensorTransform is a polymorphic base class (it declares a
  // pure virtual Parse()). With a non-virtual destructor, destroying a derived
  // transform through a TensorTransform pointer (e.g. `delete base_ptr` or
  // std::unique_ptr<TensorTransform>) is undefined behavior.
  virtual ~TensorTransform() = default;

 protected:
  // Returns the TensorOperation for this transform; every derived class must implement it.
  virtual std::shared_ptr<TensorOperation> Parse() = 0;

  // Target-device-specific overload. The default implementation ignores the
  // device argument and returns nullptr; derived classes may override it.
  virtual std::shared_ptr<TensorOperation> Parse(const MapTargetDevice & /*env*/) { return nullptr; }
};

// A start/stop/step slice descriptor (used as SliceOption::slice_).
//
// No special member functions are declared (Rule of Zero), so the compiler
// generates copy, move, assignment and destruction. Declaring a defaulted copy
// constructor and destructor would make the implicit copy assignment deprecated
// and suppress the implicit move operations, for no benefit.
class Slice {
 public:
  // All fields zero; since step_ == 0, valid() returns false.
  Slice() : start_(0), stop_(0), step_(0) {}
  // Fully specified slice.
  Slice(dsize_t start, dsize_t stop, dsize_t step) : start_(start), stop_(stop), step_(step) {}
  // Slice from start to stop with step 1.
  Slice(dsize_t start, dsize_t stop) : start_(start), stop_(stop), step_(1) {}
  // Slice from 0 to stop with step 1.
  explicit Slice(dsize_t stop) : start_(0), stop_(stop), step_(1) {}

  // A slice is considered valid exactly when its step is non-zero.
  bool valid() const { return step_ != 0; }

  dsize_t start_;
  dsize_t stop_;
  dsize_t step_;
};

// One slicing option for a Tensor: an index list, a Slice, or an "all" flag.
//
// No special member functions are declared (Rule of Zero). Declaring a copy
// constructor and destructor would suppress the implicit move constructor, so
// e.g. std::vector<SliceOption> reallocation would deep-copy every indices_
// vector instead of moving it.
class SliceOption {
 public:
  explicit SliceOption(bool all) : all_(all) {}
  // The index list is taken by value and moved into place, so callers passing
  // an rvalue pay for no copy at all.
  explicit SliceOption(std::vector<dsize_t> indices) : indices_(std::move(indices)) {}
  explicit SliceOption(Slice slice) : slice_(slice) {}

  // only one of the following will be valid
  // given indices to slice the Tensor.
  std::vector<dsize_t> indices_ = {};
  // Slice object. Default-constructed (start, stop and step all 0, so
  // slice_.valid() is false) unless the Slice constructor was used.
  Slice slice_;
  bool all_ = false;
};

// Transform operations for performing data transformation.
namespace transforms {

/// \brief Transform built from a list of other transforms.
/// NOTE(review): presumably applies the given transforms in sequence — confirm
/// against the out-of-line implementation.
class Compose final : public TensorTransform {
 public:
  /// \param[in] transforms Transforms given as raw pointers.
  explicit Compose(std::vector<TensorTransform *> const &transforms);
  /// \param[in] transforms Transforms given as shared pointers.
  explicit Compose(std::vector<std::shared_ptr<TensorTransform>> const &transforms);
  /// \param[in] transforms Transforms given as reference wrappers.
  explicit Compose(std::vector<std::reference_wrapper<TensorTransform>> const &transforms);

  ~Compose() = default;

 protected:
  /// \return The TensorOperation for this transform (defined out of line).
  auto Parse() -> std::shared_ptr<TensorOperation> override;

 private:
  struct Data;                  // Opaque state, defined outside this header.
  std::shared_ptr<Data> data_;  // Handle to that state.
};

class Concatenate final : public TensorTransform {
 public:
  explicit Concatenate(int8_t axis = 0, const MSTensor &prepend = {}, const MSTensor &append = {});

  ~Concatenate() = default;

 protected:
  std::shared_ptr<TensorOperation> Parse() override;

 private:
  struct Data;
  std::shared_ptr<Data> data_;
};

/// \brief Transform that takes no configuration.
/// NOTE(review): name suggests it duplicates the input tensor — confirm against the implementation.
class Duplicate final : public TensorTransform {
 public:
  /// Constructs the transform; there are no parameters.
  Duplicate();

  ~Duplicate() = default;

 protected:
  /// \return The TensorOperation for this transform (defined out of line).
  auto Parse() -> std::shared_ptr<TensorOperation> override;
};

class Fill final : public TensorTransform {
 public:
  explicit Fill(const MSTensor &fill_value);

  ~Fill() = default;

 protected:
  std::shared_ptr<TensorOperation> Parse() override;

 private:
  struct Data;
  std::shared_ptr<Data> data_;
};

/// \brief Transform configured with a relational operator, a constant tensor and a data type.
class Mask final : public TensorTransform {
 public:
  /// \param[in] op Relational operator to apply.
  /// \param[in] constant Tensor operand for op. NOTE(review): presumably compared against the input — confirm.
  /// \param[in] ms_type Data type, defaulting to DataType::kNumberTypeBool.
  ///            NOTE(review): presumably the output type — confirm.
  explicit Mask(RelationalOp op, MSTensor const &constant,
                mindspore::DataType ms_type = mindspore::DataType(mindspore::DataType::kNumberTypeBool));

  ~Mask() = default;

 protected:
  /// \return The TensorOperation for this transform (defined out of line).
  auto Parse() -> std::shared_ptr<TensorOperation> override;

 private:
  struct Data;                  // Opaque state, defined outside this header.
  std::shared_ptr<Data> data_;  // Handle to that state.
};

/// \brief Transform configured with a class count.
class OneHot final : public TensorTransform {
 public:
  /// \param[in] num_classes Number of classes.
  explicit OneHot(int32_t num_classes);

  ~OneHot() = default;

 protected:
  /// \return The TensorOperation for this transform (defined out of line).
  auto Parse() -> std::shared_ptr<TensorOperation> override;

 private:
  struct Data;                  // Opaque state, defined outside this header.
  std::shared_ptr<Data> data_;  // Handle to that state.
};

class PadEnd final : public TensorTransform {
 public:
  explicit PadEnd(const std::vector<dsize_t> &pad_shape, const MSTensor &pad_value = {});

  ~PadEnd() = default;

 protected:
  std::shared_ptr<TensorOperation> Parse() override;

 private:
  struct Data;
  std::shared_ptr<Data> data_;
};

/// \brief Transform built from a list of transforms plus a probability value.
class RandomApply final : public TensorTransform {
 public:
  /// \param[in] transforms Transforms given as raw pointers.
  /// \param[in] prob Probability value (default 0.5). NOTE(review): presumably the chance of applying them — confirm.
  explicit RandomApply(std::vector<TensorTransform *> const &transforms, double prob = 0.5);
  /// \param[in] transforms Transforms given as shared pointers.
  /// \param[in] prob Probability value (default 0.5).
  explicit RandomApply(std::vector<std::shared_ptr<TensorTransform>> const &transforms, double prob = 0.5);
  /// \param[in] transforms Transforms given as reference wrappers.
  /// \param[in] prob Probability value (default 0.5).
  explicit RandomApply(std::vector<std::reference_wrapper<TensorTransform>> const &transforms, double prob = 0.5);

  ~RandomApply() = default;

 protected:
  /// \return The TensorOperation for this transform (defined out of line).
  auto Parse() -> std::shared_ptr<TensorOperation> override;

 private:
  struct Data;                  // Opaque state, defined outside this header.
  std::shared_ptr<Data> data_;  // Handle to that state.
};

/// \brief Transform built from a list of candidate transforms.
/// NOTE(review): presumably picks one of them at random per call — confirm.
class RandomChoice final : public TensorTransform {
 public:
  /// \param[in] transforms Transforms given as raw pointers.
  explicit RandomChoice(std::vector<TensorTransform *> const &transforms);
  /// \param[in] transforms Transforms given as shared pointers.
  explicit RandomChoice(std::vector<std::shared_ptr<TensorTransform>> const &transforms);
  /// \param[in] transforms Transforms given as reference wrappers.
  explicit RandomChoice(std::vector<std::reference_wrapper<TensorTransform>> const &transforms);

  ~RandomChoice() = default;

 protected:
  /// \return The TensorOperation for this transform (defined out of line).
  auto Parse() -> std::shared_ptr<TensorOperation> override;

 private:
  struct Data;                  // Opaque state, defined outside this header.
  std::shared_ptr<Data> data_;  // Handle to that state.
};

/// \brief Transform configured with a list of SliceOption values.
/// Not to be confused with the dataset::Slice descriptor class; within namespace
/// transforms, after this declaration, unqualified `Slice` names this transform.
class Slice final : public TensorTransform {
 public:
  /// \param[in] slice_input Slicing options.
  explicit Slice(std::vector<SliceOption> const &slice_input);

  ~Slice() = default;

 protected:
  /// \return The TensorOperation for this transform (defined out of line).
  auto Parse() -> std::shared_ptr<TensorOperation> override;

 private:
  struct Data;                  // Opaque state, defined outside this header.
  std::shared_ptr<Data> data_;  // Handle to that state.
};

/// \brief Transform configured with a target data type.
class TypeCast final : public TensorTransform {
 public:
  /// \param[in] data_type Data type. NOTE(review): presumably the type to cast to — confirm.
  explicit TypeCast(mindspore::DataType data_type);

  ~TypeCast() = default;

 protected:
  /// \return The TensorOperation for this transform (defined out of line).
  auto Parse() -> std::shared_ptr<TensorOperation> override;

 private:
  struct Data;                  // Opaque state, defined outside this header.
  std::shared_ptr<Data> data_;  // Handle to that state.
};

/// \brief Transform that takes no configuration.
/// NOTE(review): name suggests it extracts unique elements — confirm against the implementation.
class Unique final : public TensorTransform {
 public:
  /// Constructs the transform; there are no parameters.
  Unique();

  ~Unique() = default;

 protected:
  /// \return The TensorOperation for this transform (defined out of line).
  auto Parse() -> std::shared_ptr<TensorOperation> override;
};
}  // namespace transforms
}  // namespace dataset
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_DATASET_TRANSFORMS_H_