Program Listing for File base_operator.h

Return to documentation for file (include/converter/include/ops/base_operator.h)

#ifndef MINDSPORE_CORE_OPS_BASE_OPERATOR_
#define MINDSPORE_CORE_OPS_BASE_OPERATOR_

#include <map>
#include <memory>
#include <string>
#include <vector>

#include "mindapi/ir/primitive.h"

namespace mindspore {
namespace abstract {
class AnalysisEngine;
using AnalysisEnginePtr = std::shared_ptr<AnalysisEngine>;

class AbstractBase;
using AbstractBasePtr = std::shared_ptr<AbstractBase>;
}  // namespace abstract
}  // namespace mindspore

namespace mindspore {
class Primitive;
using PrimitivePtr = std::shared_ptr<Primitive>;
}  // namespace mindspore

namespace mindspore {
namespace ops {
using PrimitiveCPtr = PrimitivePtr;
class MIND_API BaseOperator : public api::Primitive {
 public:
  MIND_API_BASE_MEMBER(BaseOperator);
  explicit BaseOperator(const std::string &name);
  PrimitiveCPtr GetPrim();

  void set_batch_rank(int64_t batch_rank);
  int64_t get_batch_rank() const;

 protected:
  void InitIOName(const std::vector<std::string> &inputs_name, const std::vector<std::string> &outputs_name);
};

using OperatorDefineFunc = std::function<std::shared_ptr<BaseOperator>(const std::shared_ptr<mindspore::Base> &)>;
class MIND_API OperatorRegister {
 public:
  ~OperatorRegister() {}

  static OperatorRegister &GetInstance();

  const std::map<std::string, OperatorDefineFunc> &GetOperatorMap() const;

  void SetOperatorMap(const std::string &kname, const OperatorDefineFunc &fn);

 private:
  OperatorRegister() {}
  std::map<std::string, OperatorDefineFunc> operator_fns_;
};

class MIND_API OperatorRegisterHelper {
 public:
  OperatorRegisterHelper(const std::string &kname, const OperatorDefineFunc &fn) {
    OperatorRegister::GetInstance().SetOperatorMap(kname, fn);
    (void)id_;  // make compiler happy on macos
  }

  ~OperatorRegisterHelper() = default;

 private:
  int id_{0};
};

#define OPERATOR_CREATOR_REG(K_NAME, OP_CLASS)                                                                   \
  std::shared_ptr<BaseOperator> GetDefaultBaseOperator##OP_CLASS(const std::shared_ptr<mindspore::Base> &impl) { \
    return std::make_shared<OP_CLASS>(impl);                                                                     \
  }                                                                                                              \
  OperatorRegisterHelper operator_gen_##OP_CLASS(K_NAME, GetDefaultBaseOperator##OP_CLASS)

#define MIND_API_OPERATOR_IMPL(ClassName, ParentClassName)    \
  MIND_API_BASE_IMPL(ClassName, PrimitiveC, ParentClassName); \
  OPERATOR_CREATOR_REG(#ClassName, ClassName)

// This macro is for operator whose name is not same as its class name.
#define MIND_API_OPERATOR_NAME_IMPL(ClassName, OpName, ParentClassName) \
  MIND_API_BASE_IMPL(ClassName, PrimitiveC, ParentClassName);           \
  OPERATOR_CREATOR_REG(OpName, ClassName)
}  // namespace ops
}  // namespace mindspore

#endif  // MINDSPORE_CORE_OPS_BASE_OPERATOR_