Program Listing for File value.h

Return to documentation for file (include/converter/include/mindapi/ir/value.h)

#ifndef MINDSPORE_CORE_MINDAPI_IR_VALUE_H_
#define MINDSPORE_CORE_MINDAPI_IR_VALUE_H_

#include <vector>
#include <string>
#include <type_traits>
#include "mindapi/base/base.h"
#include "mindapi/ir/common.h"

namespace mindspore::api {
/// \brief Maps a C++ value type to the shared-pointer type of the immediate (Imm) that wraps it.
///
/// The primary template is deliberately empty. Only the specializations generated by
/// MIND_API_IMM_TRAIT define `type`, so naming `ImmTrait<T>::type` is a substitution failure
/// for unsupported T. The MakeValue/GetValue templates rely on this in their default template
/// arguments to drop out of overload resolution.
template <class T>
struct ImmTrait {
};

/// \brief Specializes ImmTrait so that `ImmTrait<prototype>::type` is `SharedPtr<typeimm>`.
///
/// Example: `MIND_API_IMM_TRAIT(BoolImm, bool)` lets the generic MakeValue(bool) construct a
/// BoolImm and lets GetValue<bool>() cast a ValuePtr to `SharedPtr<BoolImm>`.
/// The expansion ends without a semicolon; each invocation supplies its own.
#define MIND_API_IMM_TRAIT(typeimm, prototype) \
  template <>                                  \
  struct ImmTrait<prototype> {                 \
    using type = SharedPtr<typeimm>;           \
  }

/// \brief Root of the value hierarchy declared in this header (sequences, strings, scalars).
class MIND_API Value : public Base {
 public:
  // Shared mindapi member boilerplate; the macro is presumably defined in mindapi/base/base.h.
  MIND_API_BASE_MEMBER(Value);

  /// \brief Get the type of this value.
  ///
  /// \return The TypePtr describing this value.
  TypePtr type() const;

  /// \brief Get an abstract representation of this value.
  ///
  /// \return An AbstractBasePtr built from this value (construction details live in the implementation).
  AbstractBasePtr ToAbstract() const;
};

/// \brief Base class for values that hold a sequence of ValuePtr elements.
class MIND_API ValueSequence : public Value {
 public:
  MIND_API_BASE_MEMBER(ValueSequence);

  /// \brief Get the number of elements in the sequence.
  std::size_t size() const;

  /// \brief Get the elements of the sequence.
  ///
  /// \return The elements, returned by value (each call yields a new vector).
  std::vector<ValuePtr> value() const;
};

/// Shared pointer to a ValueSequence; the vector overload of GetValue casts to this type.
using ValueSequencePtr = SharedPtr<ValueSequence>;

/// \brief Concrete ValueSequence; the vector overloads of MakeValue produce this type.
class MIND_API ValueTuple : public ValueSequence {
 public:
  MIND_API_BASE_MEMBER(ValueTuple);

  /// \brief Create a tuple from the given elements.
  ///
  /// \param[in] elements The element values to store.
  explicit ValueTuple(const std::vector<ValuePtr> &elements);
};

/// Shared pointer to a ValueTuple.
using ValueTuplePtr = SharedPtr<ValueTuple>;

/// \brief Immediate value holding a std::string (derives from Value directly, not from Scalar).
class MIND_API StringImm : public Value {
 public:
  MIND_API_BASE_MEMBER(StringImm);

  /// \brief Create a StringImm from a string.
  ///
  /// \param[in] str The string to store.
  explicit StringImm(const std::string &str);

  /// \brief Get the stored string.
  ///
  /// \return A const reference to the string; presumably valid only while this StringImm is alive.
  const std::string &value() const;
};

/// Shared pointer to a StringImm.
using StringImmPtr = SharedPtr<StringImm>;

// Registers StringImm as the immediate type for std::string (used by MakeValue/GetValue).
MIND_API_IMM_TRAIT(StringImm, std::string);

/// \brief Base class for scalar immediates: BoolImm, IntegerImm and FloatImm derive from it.
class MIND_API Scalar : public Value {
 public:
  MIND_API_BASE_MEMBER(Scalar);
};

/// \brief Scalar immediate holding a bool.
class MIND_API BoolImm : public Scalar {
 public:
  MIND_API_BASE_MEMBER(BoolImm);

  /// \brief Create a BoolImm.
  ///
  /// \param[in] b The boolean to store.
  explicit BoolImm(bool b);

  /// \brief Get the stored boolean.
  bool value() const;
};

/// Shared pointer to a BoolImm.
using BoolImmPtr = SharedPtr<BoolImm>;

// Registers BoolImm as the immediate type for bool (used by MakeValue/GetValue).
MIND_API_IMM_TRAIT(BoolImm, bool);

/// \brief Base class for integer immediates (Int8Imm, Int16Imm, Int32Imm, Int64Imm, UInt8Imm).
class MIND_API IntegerImm : public Scalar {
 public:
  MIND_API_BASE_MEMBER(IntegerImm);
};

/// \brief Integer immediate holding an int8_t.
class MIND_API Int8Imm : public IntegerImm {
 public:
  MIND_API_BASE_MEMBER(Int8Imm);

  /// \brief Create an Int8Imm.
  ///
  /// \param[in] value The 8-bit signed integer to store.
  explicit Int8Imm(int8_t value);

  /// \brief Get the stored int8_t.
  int8_t value() const;
};

/// Shared pointer to an Int8Imm.
using Int8ImmPtr = SharedPtr<Int8Imm>;

// Registers Int8Imm as the immediate type for int8_t (used by MakeValue/GetValue).
MIND_API_IMM_TRAIT(Int8Imm, int8_t);

/// \brief Integer immediate holding an int16_t.
class MIND_API Int16Imm : public IntegerImm {
 public:
  MIND_API_BASE_MEMBER(Int16Imm);

  /// \brief Create an Int16Imm.
  ///
  /// \param[in] value The 16-bit signed integer to store.
  explicit Int16Imm(int16_t value);

  /// \brief Get the stored int16_t.
  int16_t value() const;
};

/// Shared pointer to an Int16Imm.
using Int16ImmPtr = SharedPtr<Int16Imm>;

// Registers Int16Imm as the immediate type for int16_t (used by MakeValue/GetValue).
MIND_API_IMM_TRAIT(Int16Imm, int16_t);

/// \brief Integer immediate holding an int32_t.
class MIND_API Int32Imm : public IntegerImm {
 public:
  MIND_API_BASE_MEMBER(Int32Imm);

  /// \brief Create an Int32Imm.
  ///
  /// \param[in] value The 32-bit signed integer to store.
  explicit Int32Imm(int32_t value);

  /// \brief Get the stored int32_t.
  int32_t value() const;
};

/// Shared pointer to an Int32Imm.
using Int32ImmPtr = SharedPtr<Int32Imm>;

// Registers Int32Imm as the immediate type for int32_t (used by MakeValue/GetValue).
// NOTE(review): on platforms where int32_t is `int`, MakeValue(int) is a non-template exact match
// and yields an Int64Imm instead; verify this is intended.
MIND_API_IMM_TRAIT(Int32Imm, int32_t);

/// \brief Integer immediate holding an int64_t. MakeValue(int) also produces this type.
class MIND_API Int64Imm : public IntegerImm {
 public:
  MIND_API_BASE_MEMBER(Int64Imm);

  /// \brief Create an Int64Imm.
  ///
  /// \param[in] value The 64-bit signed integer to store.
  explicit Int64Imm(int64_t value);

  /// \brief Get the stored int64_t.
  int64_t value() const;
};

/// Shared pointer to an Int64Imm.
using Int64ImmPtr = SharedPtr<Int64Imm>;

// Registers Int64Imm as the immediate type for int64_t (used by MakeValue/GetValue).
MIND_API_IMM_TRAIT(Int64Imm, int64_t);

/// \brief Integer immediate holding a uint8_t.
class MIND_API UInt8Imm : public IntegerImm {
 public:
  MIND_API_BASE_MEMBER(UInt8Imm);

  /// \brief Create a UInt8Imm.
  ///
  /// \param[in] value The 8-bit unsigned integer to store.
  explicit UInt8Imm(uint8_t value);

  /// \brief Get the stored uint8_t.
  uint8_t value() const;
};

/// Shared pointer to a UInt8Imm.
using UInt8ImmPtr = SharedPtr<UInt8Imm>;

// Registers UInt8Imm as the immediate type for uint8_t (used by MakeValue/GetValue).
MIND_API_IMM_TRAIT(UInt8Imm, uint8_t);

/// \brief Base class for floating-point immediates (FP32Imm, FP64Imm).
class MIND_API FloatImm : public Scalar {
 public:
  MIND_API_BASE_MEMBER(FloatImm);
};

/// \brief Floating-point immediate holding a float.
class MIND_API FP32Imm : public FloatImm {
 public:
  MIND_API_BASE_MEMBER(FP32Imm);

  /// \brief Create an FP32Imm.
  ///
  /// \param[in] value The float to store.
  explicit FP32Imm(float value);

  /// \brief Get the stored float.
  float value() const;
};

/// Shared pointer to an FP32Imm.
using FP32ImmPtr = SharedPtr<FP32Imm>;

// Registers FP32Imm as the immediate type for float (used by MakeValue/GetValue).
MIND_API_IMM_TRAIT(FP32Imm, float);

/// \brief Floating-point immediate holding a double.
class MIND_API FP64Imm : public FloatImm {
 public:
  MIND_API_BASE_MEMBER(FP64Imm);

  /// \brief Create an FP64Imm.
  ///
  /// \param[in] value The double to store.
  explicit FP64Imm(double value);

  /// \brief Get the stored double.
  double value() const;
};

/// Shared pointer to an FP64Imm.
using FP64ImmPtr = SharedPtr<FP64Imm>;

// Registers FP64Imm as the immediate type for double (used by MakeValue/GetValue).
MIND_API_IMM_TRAIT(FP64Imm, double);

// === Utility functions for Value === //

/// \brief Wraps a scalar or string in the immediate type registered for it via ImmTrait.
///
/// Only viable for types with an ImmTrait specialization; for any other T the default
/// template argument fails to substitute and this overload is silently discarded.
///
/// \param[in] v The value to wrap (taken by value).
/// \return A ValuePtr to a newly created immediate of type `ImmTrait<T>::type::element_type`.
template <typename T, typename U = typename ImmTrait<T>::type::element_type>
inline ValuePtr MakeValue(T v) {
  ValuePtr wrapped = MakeShared<U>(v);
  return wrapped;
}

/// \brief Wraps a C string in a StringImm.
///
/// A null pointer yields an empty StringImm. Constructing std::string from a null
/// `const char *` is undefined behavior, so the guard keeps this overload safe for
/// callers that forward possibly-null C strings.
///
/// \param[in] s NUL-terminated string to copy, or nullptr.
/// \return A ValuePtr to a new StringImm holding a copy of `s`.
inline ValuePtr MakeValue(const char *s) {
  if (s == nullptr) {
    return MakeShared<StringImm>(std::string());
  }
  return MakeShared<StringImm>(std::string(s));
}

/// \brief Wraps a plain `int` as an Int64Imm; the value is widened to int64_t first.
///
/// As a non-template exact match, this overload wins over the ImmTrait-based template for `int`.
inline ValuePtr MakeValue(int i) {
  const int64_t widened = i;
  return MakeShared<Int64Imm>(widened);
}

/// \brief Packs already-built values into a ValueTuple.
///
/// \param[in] values The elements of the tuple.
/// \return A ValuePtr to a new ValueTuple holding `values`.
inline ValuePtr MakeValue(const std::vector<ValuePtr> &values) {
  auto tuple = MakeShared<ValueTuple>(values);
  return tuple;
}

/// \brief Converts a vector-like container into a ValueTuple, one element at a time.
///
/// Only enabled when `is_vector<T>::value` holds (the trait is presumably provided by an
/// included mindapi header). Each element goes through the matching MakeValue overload,
/// so a nested vector becomes a nested ValueTuple.
///
/// \param[in] values The container to convert.
/// \return A ValuePtr to a new ValueTuple with one converted element per input element.
template <typename T, typename = typename std::enable_if_t<is_vector<T>::value, T>>
inline ValuePtr MakeValue(const T &values) {
  std::vector<ValuePtr> value_vector;
  value_vector.reserve(values.size());
  // Bind by const reference. A by-value loop variable copied every element first,
  // which meant copying entire inner vectors for nested containers.
  for (const auto &value : values) {
    value_vector.emplace_back(MakeValue(value));
  }
  return MakeShared<ValueTuple>(value_vector);
}

/// \brief Extracts a scalar or string from a ValuePtr holding the matching immediate.
///
/// Only viable for types with an ImmTrait specialization.
///
/// \param[in] value The value to read; may be null.
/// \return The stored value, or a value-initialized T when `value` is null or is not
///         the immediate type registered for T.
template <typename T, typename U = typename ImmTrait<T>::type>
inline T GetValue(const ValuePtr &value) {
  if (value == nullptr) {
    return T();
  }
  const U imm = value->cast<U>();
  return (imm == nullptr) ? T() : static_cast<T>(imm->value());
}

template <typename T, typename S = typename std::decay_t<T>,
          typename U = typename std::enable_if_t<is_vector<S>::value, typename S::value_type>>
std::vector<U> GetValue(const ValuePtr &value) {
  if (value == nullptr) {
    return {};
  }
  auto seq = value->cast<ValueSequencePtr>();
  if (seq == nullptr) {
    return {};
  }
  if constexpr (std::is_same_v<ValuePtr, U>) {
    return seq->value();
  } else {
    auto elements = seq->value();
    std::vector<U> result;
    result.reserve(elements.size());
    for (auto &e : elements) {
      result.emplace_back(GetValue<U>(e));
    }
    return result;
  }
}
}  // namespace mindspore::api
#endif  // MINDSPORE_CORE_MINDAPI_IR_VALUE_H_