Program Listing for File anf.h

Return to documentation for file (include/converter/include/mindapi/ir/anf.h)

#ifndef MINDSPORE_CORE_MINDAPI_IR_ANF_H_
#define MINDSPORE_CORE_MINDAPI_IR_ANF_H_

#include <vector>
#include <string>
#include "mindapi/base/base.h"
#include "mindapi/ir/common.h"
#include "mindapi/ir/abstract.h"
#include "mindapi/ir/primitive.h"
#include "mindapi/ir/value.h"

namespace mindspore::api {
/// \brief Common base of the ANF node API classes declared in this header
/// (CNode, Parameter and ValueNode all derive from it).
///
/// Only declarations are visible here; the member functions are implemented
/// elsewhere.
class MIND_API AnfNode : public Base {
 public:
  // Boilerplate members supplied by a macro defined outside this header
  // (expansion not visible here).
  MIND_API_BASE_MEMBER(AnfNode);

  /// \brief Get the full name of this node together with its scope.
  ///
  /// \return The full name with scope, as a std::string.
  std::string fullname_with_scope() const;

  /// \brief Get the abstract attached to this node.
  ///
  /// \return The abstract as an AbstractBasePtr.
  ///         NOTE(review): presumably null when no abstract was set — confirm.
  AbstractBasePtr abstract() const;

  /// \brief Attach an abstract to this node.
  ///
  /// \param[in] abs The abstract to set.
  void set_abstract(const AbstractBasePtr &abs);
};

/// \brief ANF node that carries a list of input nodes and named attributes.
///
/// Only declarations are visible here; the member functions are implemented
/// elsewhere.
class MIND_API CNode : public AnfNode {
 public:
  // Boilerplate members supplied by a macro defined outside this header
  // (expansion not visible here).
  MIND_API_BASE_MEMBER(CNode);

  /// \brief Get the size of this node.
  ///
  /// \return The size as a size_t.
  ///         NOTE(review): presumably the number of inputs — confirm.
  size_t size() const;

  /// \brief Get the input at a given index.
  ///
  /// \param[in] i Index of the requested input.
  ///              NOTE(review): out-of-range behavior is not visible here.
  ///
  /// \return The input node at index i.
  AnfNodePtr input(size_t i) const;

  /// \brief Get all inputs of this node.
  ///
  /// \return The inputs, returned by value as a vector of AnfNodePtr.
  std::vector<AnfNodePtr> inputs() const;

  /// \brief Set the inputs of this node.
  ///
  /// \param[in] inputs The input nodes to set.
  void set_inputs(const std::vector<AnfNodePtr> &inputs);

  /// \brief Add one input to this node.
  ///
  /// \param[in] input The input node to add.
  ///                  NOTE(review): presumably appended at the end — confirm.
  void add_input(const AnfNodePtr &input);

  /// \brief Set the full name with scope of this node.
  ///
  /// \param[in] full_name The full name to set.
  void set_fullname_with_scope(const std::string &full_name);

  /// \brief Add an attribute under the given name.
  ///
  /// \param[in] name The attribute name.
  /// \param[in] attr The attribute value.
  ///                 NOTE(review): behavior when the name already exists is not
  ///                 visible here.
  void AddAttr(const std::string &name, const ValuePtr &attr);

  /// \brief Erase the attribute with the given name.
  ///
  /// \param[in] name The attribute name.
  void EraseAttr(const std::string &name);

  /// \brief Get the attribute with the given name.
  ///
  /// \param[in] name The attribute name.
  ///
  /// \return The attribute value.
  ///         NOTE(review): presumably null when the name is absent — confirm.
  ValuePtr GetAttr(const std::string &name) const;
};

/// \brief SharedPtr handle to a CNode.
using CNodePtr = SharedPtr<CNode>;

/// \brief ANF node that has a name and an optional default value.
///
/// Only declarations are visible here; the member functions are implemented
/// elsewhere.
class MIND_API Parameter : public AnfNode {
 public:
  // Boilerplate members supplied by a macro defined outside this header
  // (expansion not visible here).
  MIND_API_BASE_MEMBER(Parameter);

  /// \brief Get the name of this parameter.
  ///
  /// \return The name, as a std::string.
  std::string name() const;

  /// \brief Set the name of this parameter.
  ///
  /// \param[in] name The name to set.
  void set_name(const std::string &name);

  /// \brief Query whether this parameter has a default.
  ///
  /// \return A bool.
  ///         NOTE(review): presumably true once a default parameter value has
  ///         been set — confirm.
  bool has_default() const;

  /// \brief Set the default parameter value.
  ///
  /// \param[in] param The default value to set.
  void set_default_param(const ValuePtr &param);

  /// \brief Get the default parameter value.
  ///
  /// \return The default value as a ValuePtr.
  ///         NOTE(review): presumably null when there is no default — confirm.
  ValuePtr default_param() const;
};

/// \brief SharedPtr handle to a Parameter.
using ParameterPtr = SharedPtr<Parameter>;

/// \brief ANF node that is constructed from, and exposes, a ValuePtr.
///
/// Only declarations are visible here; the member functions are implemented
/// elsewhere.
class MIND_API ValueNode : public AnfNode {
 public:
  // Boilerplate members supplied by a macro defined outside this header
  // (expansion not visible here).
  MIND_API_BASE_MEMBER(ValueNode);

  /// \brief Create a ValueNode from a value.
  ///
  /// Declared explicit, so a ValuePtr never converts to a ValueNode implicitly.
  ///
  /// \param[in] value The value for this node.
  explicit ValueNode(const ValuePtr &value);

  /// \brief Get the value of this node.
  ///
  /// \return The value as a ValuePtr.
  ValuePtr value() const;
};

/// \brief SharedPtr handle to a ValueNode.
using ValueNodePtr = SharedPtr<ValueNode>;

// === ANF utility functions === //

/// \brief Create a ValueNode that wraps an existing value pointer.
///
/// This overload takes part in overload resolution only when T is Value or a
/// class derived from Value.
///
/// \param[in] value Shared pointer to the value to wrap.
///
/// \return A ValueNodePtr obtained from MakeShared<ValueNode>(value).
template <typename T, typename = typename std::enable_if_t<std::is_base_of_v<Value, T>, T>>
inline ValueNodePtr NewValueNode(const SharedPtr<T> &value) {
  ValueNodePtr created_node = MakeShared<ValueNode>(value);
  return created_node;
}

/// \brief Create a ValueNode from an arbitrary value.
///
/// The argument is first passed through MakeValue(), and the result is then
/// forwarded to NewValueNode() again.
///
/// \param[in] value The value to convert; taken by value.
///
/// \return The ValueNodePtr returned by the inner NewValueNode() call.
template <typename T>
inline ValueNodePtr NewValueNode(T value) {
  auto converted = MakeValue(value);
  return NewValueNode(converted);
}

/// \brief Extract the value held by a node, provided the node is a ValueNode.
///
/// \param[in] node The node to inspect; may be null.
///
/// \return The result of ValueNode::value() when node casts to ValueNodePtr;
///         nullptr when node is null or the cast yields null.
inline ValuePtr GetValueNode(const AnfNodePtr &node) {
  if (node == nullptr) {
    return nullptr;
  }
  // Only a node that casts to ValueNodePtr has a value to hand back.
  if (const auto holder = node->cast<ValueNodePtr>(); holder == nullptr) {
    return nullptr;
  } else {
    return holder->value();
  }
}

template <typename T, typename = typename std::enable_if_t<
                        is_wrapper_ptr<T>::value && std::is_base_of_v<Value, typename T::element_type>, T>>
inline T GetValueNode(const AnfNodePtr &node) {
  auto value = GetValueNode(node);
  if (value == nullptr) {
    return nullptr;
  }
  return value->cast<T>();
}

/// \brief Check a node against a primitive (implemented outside this header).
///
/// \param[in] node The node to check.
/// \param[in] prim The primitive to compare with; defaults to nullptr.
///                 NOTE(review): presumably a null prim accepts a CNode of any
///                 primitive — confirm against the implementation.
///
/// \return A bool.
///         NOTE(review): presumably true when node is a CNode of the given
///         primitive — confirm.
MIND_API bool IsPrimitiveCNode(const AnfNodePtr &node, const PrimitivePtr &prim = nullptr);

/// \brief Check a node against a primitive (implemented outside this header).
///
/// \param[in] node The node to check.
/// \param[in] prim The primitive to compare with; no default is provided.
///
/// \return A bool.
///         NOTE(review): presumably true when node holds the given primitive
///         — confirm against the implementation.
MIND_API bool IsPrimitive(const AnfNodePtr &node, const PrimitivePtr &prim);

/// \brief Check whether a node is a data node (implemented outside this header).
///
/// \param[in] node The node to check.
///
/// \return A bool.
///         NOTE(review): what counts as a "data node" is decided by the
///         implementation, which is not visible here.
MIND_API bool IsDataNode(const AnfNodePtr &node);
}  // namespace mindspore::api
#endif  // MINDSPORE_CORE_MINDAPI_IR_ANF_H_