Program Listing for File value.h
↰ Return to documentation for file (include/converter/include/mindapi/ir/value.h)
#ifndef MINDSPORE_CORE_MINDAPI_IR_VALUE_H_
#define MINDSPORE_CORE_MINDAPI_IR_VALUE_H_
#include <vector>
#include <string>
#include <type_traits>
#include "mindapi/base/base.h"
#include "mindapi/ir/common.h"
namespace mindspore::api {
/// \brief Trait mapping a plain C++ type T to the shared-pointer type of its immediate-value class.
///
/// The primary template is deliberately empty: it has no nested `type`, so templates in this
/// header that name `ImmTrait<T>::type` in their template parameter list are removed from
/// overload resolution for any T that has not been registered. Specializations are created
/// with the MIND_API_IMM_TRAIT macro.
template <typename T>
struct ImmTrait {};
/// \brief Register `prototype` in ImmTrait so that `ImmTrait<prototype>::type` is `SharedPtr<typeimm>`.
///
/// \param typeimm   The immediate-value class that wraps the C++ type.
/// \param prototype The plain C++ type being wrapped.
///
/// The expansion is an explicit specialization of ImmTrait without a trailing semicolon,
/// so each use site supplies its own `;`.
#define MIND_API_IMM_TRAIT(typeimm, prototype) \
template <> \
struct ImmTrait<prototype> { \
using type = SharedPtr<typeimm>; \
}
/// \brief Root of the value class hierarchy declared in this header.
///
/// Sequences, strings and scalar immediates in this file all derive (directly or indirectly)
/// from Value.
class MIND_API Value : public Base {
public:
// Project macro; its expansion is defined outside this file (presumably the common
// boilerplate members shared by API classes — confirm in mindapi/base/base.h).
MIND_API_BASE_MEMBER(Value);
/// \brief Get the type associated with this value.
/// \return A TypePtr. NOTE(review): presumably describes the value's data type — the
///         implementation is not in this header.
TypePtr type() const;
/// \brief Produce an abstract representation of this value.
/// \return An AbstractBasePtr; how the abstraction is built is implemented outside this header.
AbstractBasePtr ToAbstract() const;
};
/// \brief Value that exposes a sequence of ValuePtr elements.
class MIND_API ValueSequence : public Value {
public:
// Project macro; expansion defined outside this file.
MIND_API_BASE_MEMBER(ValueSequence);
/// \brief Get a std::size_t size; presumably the number of elements in the sequence.
std::size_t size() const;
/// \brief Get the elements of the sequence.
/// \return A vector of ValuePtr, returned by value (the caller owns the returned vector).
std::vector<ValuePtr> value() const;
};
/// Shared-pointer alias for ValueSequence; used by GetValue to cast a generic ValuePtr.
using ValueSequencePtr = SharedPtr<ValueSequence>;
/// \brief ValueSequence subclass that can be constructed directly from a vector of elements.
///
/// The vector-taking MakeValue overloads in this header create instances of this class.
class MIND_API ValueTuple : public ValueSequence {
public:
// Project macro; expansion defined outside this file.
MIND_API_BASE_MEMBER(ValueTuple);
/// \brief Construct a tuple from the given elements.
/// \param[in] elements The ValuePtr elements of the tuple.
explicit ValueTuple(const std::vector<ValuePtr> &elements);
};
/// Shared-pointer alias for ValueTuple.
using ValueTuplePtr = SharedPtr<ValueTuple>;
/// \brief Immediate value wrapping a std::string.
class MIND_API StringImm : public Value {
public:
// Project macro; expansion defined outside this file.
MIND_API_BASE_MEMBER(StringImm);
/// \brief Construct from a string.
/// \param[in] str The string to wrap.
explicit StringImm(const std::string &str);
/// \brief Get the string.
/// \return A const reference to a std::string; presumably valid only while this object is
///         alive — confirm against the implementation.
const std::string &value() const;
};
/// Shared-pointer alias for StringImm.
using StringImmPtr = SharedPtr<StringImm>;
// Registers std::string in ImmTrait: ImmTrait<std::string>::type is SharedPtr<StringImm>.
MIND_API_IMM_TRAIT(StringImm, std::string);
/// \brief Common base class of the scalar immediates in this header
///        (BoolImm, IntegerImm and FloatImm derive from it).
///
/// Declares no members of its own beyond what the project macro provides.
class MIND_API Scalar : public Value {
public:
// Project macro; expansion defined outside this file.
MIND_API_BASE_MEMBER(Scalar);
};
/// \brief Scalar immediate wrapping a bool.
class MIND_API BoolImm : public Scalar {
public:
// Project macro; expansion defined outside this file.
MIND_API_BASE_MEMBER(BoolImm);
/// \brief Construct from a bool.
/// \param[in] b The boolean to wrap.
explicit BoolImm(bool b);
/// \brief Get the bool (presumably the one given at construction).
bool value() const;
};
/// Shared-pointer alias for BoolImm.
using BoolImmPtr = SharedPtr<BoolImm>;
// Registers bool in ImmTrait: ImmTrait<bool>::type is SharedPtr<BoolImm>.
MIND_API_IMM_TRAIT(BoolImm, bool);
/// \brief Common base class of the integer immediates in this header
///        (Int8Imm, Int16Imm, Int32Imm, Int64Imm and UInt8Imm derive from it).
///
/// Declares no members of its own beyond what the project macro provides.
class MIND_API IntegerImm : public Scalar {
public:
// Project macro; expansion defined outside this file.
MIND_API_BASE_MEMBER(IntegerImm);
};
/// \brief Integer immediate wrapping an int8_t.
class MIND_API Int8Imm : public IntegerImm {
public:
// Project macro; expansion defined outside this file.
MIND_API_BASE_MEMBER(Int8Imm);
/// \brief Construct from an int8_t.
/// \param[in] value The integer to wrap.
explicit Int8Imm(int8_t value);
/// \brief Get the int8_t (presumably the one given at construction).
int8_t value() const;
};
/// Shared-pointer alias for Int8Imm.
using Int8ImmPtr = SharedPtr<Int8Imm>;
// Registers int8_t in ImmTrait: ImmTrait<int8_t>::type is SharedPtr<Int8Imm>.
MIND_API_IMM_TRAIT(Int8Imm, int8_t);
/// \brief Integer immediate wrapping an int16_t.
class MIND_API Int16Imm : public IntegerImm {
public:
// Project macro; expansion defined outside this file.
MIND_API_BASE_MEMBER(Int16Imm);
/// \brief Construct from an int16_t.
/// \param[in] value The integer to wrap.
explicit Int16Imm(int16_t value);
/// \brief Get the int16_t (presumably the one given at construction).
int16_t value() const;
};
/// Shared-pointer alias for Int16Imm.
using Int16ImmPtr = SharedPtr<Int16Imm>;
// Registers int16_t in ImmTrait: ImmTrait<int16_t>::type is SharedPtr<Int16Imm>.
MIND_API_IMM_TRAIT(Int16Imm, int16_t);
/// \brief Integer immediate wrapping an int32_t.
class MIND_API Int32Imm : public IntegerImm {
public:
// Project macro; expansion defined outside this file.
MIND_API_BASE_MEMBER(Int32Imm);
/// \brief Construct from an int32_t.
/// \param[in] value The integer to wrap.
explicit Int32Imm(int32_t value);
/// \brief Get the int32_t (presumably the one given at construction).
int32_t value() const;
};
/// Shared-pointer alias for Int32Imm.
using Int32ImmPtr = SharedPtr<Int32Imm>;
// Registers int32_t in ImmTrait: ImmTrait<int32_t>::type is SharedPtr<Int32Imm>.
MIND_API_IMM_TRAIT(Int32Imm, int32_t);
/// \brief Integer immediate wrapping an int64_t.
///
/// The non-template MakeValue(int) overload in this header also produces this class.
class MIND_API Int64Imm : public IntegerImm {
public:
// Project macro; expansion defined outside this file.
MIND_API_BASE_MEMBER(Int64Imm);
/// \brief Construct from an int64_t.
/// \param[in] value The integer to wrap.
explicit Int64Imm(int64_t value);
/// \brief Get the int64_t (presumably the one given at construction).
int64_t value() const;
};
/// Shared-pointer alias for Int64Imm.
using Int64ImmPtr = SharedPtr<Int64Imm>;
// Registers int64_t in ImmTrait: ImmTrait<int64_t>::type is SharedPtr<Int64Imm>.
MIND_API_IMM_TRAIT(Int64Imm, int64_t);
/// \brief Integer immediate wrapping a uint8_t (the only unsigned immediate declared in this header).
class MIND_API UInt8Imm : public IntegerImm {
public:
// Project macro; expansion defined outside this file.
MIND_API_BASE_MEMBER(UInt8Imm);
/// \brief Construct from a uint8_t.
/// \param[in] value The integer to wrap.
explicit UInt8Imm(uint8_t value);
/// \brief Get the uint8_t (presumably the one given at construction).
uint8_t value() const;
};
/// Shared-pointer alias for UInt8Imm.
using UInt8ImmPtr = SharedPtr<UInt8Imm>;
// Registers uint8_t in ImmTrait: ImmTrait<uint8_t>::type is SharedPtr<UInt8Imm>.
MIND_API_IMM_TRAIT(UInt8Imm, uint8_t);
/// \brief Common base class of the floating-point immediates in this header
///        (FP32Imm and FP64Imm derive from it).
///
/// Declares no members of its own beyond what the project macro provides.
class MIND_API FloatImm : public Scalar {
public:
// Project macro; expansion defined outside this file.
MIND_API_BASE_MEMBER(FloatImm);
};
/// \brief Floating-point immediate wrapping a float.
class MIND_API FP32Imm : public FloatImm {
public:
// Project macro; expansion defined outside this file.
MIND_API_BASE_MEMBER(FP32Imm);
/// \brief Construct from a float.
/// \param[in] value The number to wrap.
explicit FP32Imm(float value);
/// \brief Get the float (presumably the one given at construction).
float value() const;
};
/// Shared-pointer alias for FP32Imm.
using FP32ImmPtr = SharedPtr<FP32Imm>;
// Registers float in ImmTrait: ImmTrait<float>::type is SharedPtr<FP32Imm>.
MIND_API_IMM_TRAIT(FP32Imm, float);
/// \brief Floating-point immediate wrapping a double.
class MIND_API FP64Imm : public FloatImm {
public:
// Project macro; expansion defined outside this file.
MIND_API_BASE_MEMBER(FP64Imm);
/// \brief Construct from a double.
/// \param[in] value The number to wrap.
explicit FP64Imm(double value);
/// \brief Get the double (presumably the one given at construction).
double value() const;
};
/// Shared-pointer alias for FP64Imm.
using FP64ImmPtr = SharedPtr<FP64Imm>;
// Registers double in ImmTrait: ImmTrait<double>::type is SharedPtr<FP64Imm>.
MIND_API_IMM_TRAIT(FP64Imm, double);
// === Utility functions for Value === //
/// \brief Wrap a plain C++ value in the immediate class registered for its type.
///
/// Only participates in overload resolution when ImmTrait<T> has a nested `type`;
/// U is then the immediate class that the registered shared pointer points to.
///
/// \param[in] v The value to wrap.
/// \return A ValuePtr referring to a newly created U constructed from v.
template <typename T, typename U = typename ImmTrait<T>::type::element_type>
inline ValuePtr MakeValue(T v) {
  ValuePtr wrapped = MakeShared<U>(v);
  return wrapped;
}
/// \brief Create a StringImm from a C string.
///
/// \param[in] s A null-terminated C string. It must not be null: the std::string
///              constructor used here has undefined behavior for a null pointer.
/// \return A ValuePtr referring to a newly created StringImm holding a copy of s.
inline ValuePtr MakeValue(const char *s) {
  const std::string text(s);
  return MakeShared<StringImm>(text);
}
/// \brief Create an Int64Imm from a plain int.
///
/// This is a non-template overload, so for an argument of type exactly `int` it is preferred
/// over the generic MakeValue template; the result is therefore an Int64Imm rather than the
/// class ImmTrait would select for `int`.
///
/// \param[in] i The integer to wrap; it is widened to int64_t.
/// \return A ValuePtr referring to a newly created Int64Imm.
inline ValuePtr MakeValue(int i) {
  const int64_t widened = static_cast<int64_t>(i);
  return MakeShared<Int64Imm>(widened);
}
/// \brief Create a ValueTuple from elements that are already ValuePtr.
///
/// \param[in] values The tuple elements; they are passed to the ValueTuple constructor unchanged.
/// \return A ValuePtr referring to a newly created ValueTuple.
inline ValuePtr MakeValue(const std::vector<ValuePtr> &values) {
  ValuePtr tuple = MakeShared<ValueTuple>(values);
  return tuple;
}
/// \brief Create a ValueTuple from a vector of plain C++ values.
///
/// Enabled only when is_vector<T>::value is true. Each element is converted with whichever
/// MakeValue overload matches its type, and the resulting ValuePtr elements are collected
/// into a new ValueTuple in the same order.
///
/// \param[in] values The vector whose elements are to be wrapped.
/// \return A ValuePtr referring to a newly created ValueTuple with one element per input element.
template <typename T, typename = typename std::enable_if_t<is_vector<T>::value, T>>
inline ValuePtr MakeValue(const T &values) {
  std::vector<ValuePtr> value_vector;
  value_vector.reserve(values.size());
  // Bind by const reference: making a per-iteration copy of each element (which allocates for
  // types such as std::string or nested vectors) only to pass it on is wasted work.
  for (const auto &value : values) {
    value_vector.emplace_back(MakeValue(value));
  }
  return MakeShared<ValueTuple>(value_vector);
}
/// \brief Extract a plain C++ value of type T from a generic ValuePtr.
///
/// Only participates in overload resolution when ImmTrait<T> has a nested `type`;
/// U is the shared-pointer type registered for T.
///
/// \param[in] value The value to read; may be null.
/// \return The result of value() on the immediate when `value` is non-null and casts to U;
///         otherwise a value-initialized T.
template <typename T, typename U = typename ImmTrait<T>::type>
inline T GetValue(const ValuePtr &value) {
  // Nothing to unwrap from a null pointer.
  if (value == nullptr) {
    return T{};
  }
  // A failed cast (wrong immediate class) yields a null U and falls through to the default.
  if (U typed = value->cast<U>(); !(typed == nullptr)) {
    return typed->value();
  }
  return T{};
}
/// \brief Extract a std::vector from a ValuePtr that refers to a ValueSequence.
///
/// Enabled only when the decayed T satisfies is_vector; U is that vector's element type.
/// When U is ValuePtr the sequence's elements are returned as they are; otherwise each
/// element is converted with GetValue<U>.
///
/// \param[in] value The value to read; may be null.
/// \return The extracted elements, or an empty vector when `value` is null or does not cast
///         to ValueSequencePtr.
template <typename T, typename S = typename std::decay_t<T>,
          typename U = typename std::enable_if_t<is_vector<S>::value, typename S::value_type>>
std::vector<U> GetValue(const ValuePtr &value) {
  if (value == nullptr) {
    return {};
  }
  auto sequence = value->cast<ValueSequencePtr>();
  if (sequence == nullptr) {
    return {};
  }
  // Fetch the elements once; both branches below start from this vector.
  auto items = sequence->value();
  if constexpr (std::is_same_v<ValuePtr, U>) {
    // Caller asked for the raw ValuePtr elements: no per-element conversion needed.
    return items;
  } else {
    std::vector<U> converted;
    converted.reserve(items.size());
    for (auto &item : items) {
      converted.emplace_back(GetValue<U>(item));
    }
    return converted;
  }
}
} // namespace mindspore::api
#endif // MINDSPORE_CORE_MINDAPI_IR_VALUE_H_