Program Listing for File batch_norm.h

Return to documentation for file (include/converter/include/ops/batch_norm.h)

#ifndef MINDSPORE_CORE_OPS_BATCH_NORMAL_H_
#define MINDSPORE_CORE_OPS_BATCH_NORMAL_H_
#include <map>
#include <memory>
#include <string>
#include <vector>
#include "mindapi/base/format.h"
#include "ops/base_operator.h"

namespace mindspore {
namespace ops {
// Names handed to the BaseOperator constructor by the operator classes declared in this header.
constexpr const char *kNameBatchNorm = "BatchNorm";
constexpr const char *kNameBatchNormWithActivation = "BatchNormWithActivation";
constexpr const char *kNameBatchNormWithAddAndActivation = "BatchNormWithAddAndActivation";
/// \brief Operator wrapper constructed under the name "BatchNorm" by default.
///
/// Every constructor registers five input names ("x", "scale", "offset", "mean", "variance") and five output
/// names ("y", "batch_mean", "batch_variance", "reserve_space_1", "reserve_space_2") through InitIOName().
/// Init() and the attribute accessors are only declared here; their definitions live outside this header.
class MIND_API BatchNorm : public BaseOperator {
 public:
  MIND_API_BASE_MEMBER(BatchNorm);

  /// \brief Construct the operator under the default name kNameBatchNorm.
  ///
  /// Delegates to the named constructor so the input/output name lists are written only once.
  BatchNorm() : BatchNorm(kNameBatchNorm) {}

  /// \brief Construct the operator under a caller-supplied name (the derived operators in this header use this).
  ///
  /// \param[in] kernel_name Name forwarded to the BaseOperator constructor. Taken by const reference to avoid
  ///            copying the string at the call boundary.
  explicit BatchNorm(const std::string &kernel_name) : BaseOperator(kernel_name) {
    InitIOName({"x", "scale", "offset", "mean", "variance"},
               {"y", "batch_mean", "batch_variance", "reserve_space_1", "reserve_space_2"});
  }

  /// \brief Initialise the operator attributes in one call.
  ///
  /// NOTE(review): presumably forwards each argument to the matching set_* method below — confirm against the
  /// out-of-line definition.
  ///
  /// \param[in] is_training Training-mode flag. Default: false.
  /// \param[in] epsilon Epsilon attribute. Default: 1e-5.
  /// \param[in] momentum Momentum attribute. Default: 0.1.
  /// \param[in] format Data format attribute. Default: NCHW.
  void Init(const bool is_training = false, const float epsilon = 1e-5, const float momentum = 0.1,
            const Format &format = NCHW);

  /// \brief Setter for the is_training attribute.
  /// \param[in] is_training Value to store.
  void set_is_training(const bool is_training);

  /// \brief Setter for the epsilon attribute.
  /// \param[in] epsilon Value to store.
  void set_epsilon(const float epsilon);

  /// \brief Setter for the format attribute.
  /// \param[in] format Value to store.
  void set_format(const Format &format);

  /// \brief Setter for the momentum attribute.
  /// \param[in] momentum Value to store.
  void set_momentum(const float momentum);

  /// \brief Getter for the is_training attribute.
  /// \return The is_training flag.
  bool get_is_training() const;

  /// \brief Getter for the epsilon attribute.
  /// \return The epsilon value.
  float get_epsilon() const;

  /// \brief Getter for the format attribute.
  /// \return The data format.
  Format get_format() const;

  /// \brief Getter for the momentum attribute.
  /// \return The momentum value.
  float get_momentum() const;
};

/// \brief BatchNorm variant constructed under the name kNameBatchNormWithActivation.
///
/// Adds no members of its own. Its constructor registers the same five input names and five output names that
/// the BatchNorm constructors register.
class MIND_API BatchNormWithActivation : public BatchNorm {
 public:
  MIND_API_BASE_MEMBER(BatchNormWithActivation);

  /// \brief Build the operator and register its input/output names.
  BatchNormWithActivation() : BatchNorm(kNameBatchNormWithActivation) {
    InitIOName(
      /* inputs  */ {"x", "scale", "offset", "mean", "variance"},
      /* outputs */ {"y", "batch_mean", "batch_variance", "reserve_space_1", "reserve_space_2"});
  }
};

/// \brief BatchNorm variant constructed under the name kNameBatchNormWithAddAndActivation.
///
/// Adds no members of its own. Its constructor registers six input names — the five used by BatchNorm plus a
/// trailing "z" — and the same five output names as BatchNorm.
class MIND_API BatchNormWithAddAndActivation : public BatchNorm {
 public:
  MIND_API_BASE_MEMBER(BatchNormWithAddAndActivation);

  /// \brief Build the operator and register its input/output names, including the additional "z" input.
  BatchNormWithAddAndActivation() : BatchNorm(kNameBatchNormWithAddAndActivation) {
    InitIOName(
      /* inputs  */ {"x", "scale", "offset", "mean", "variance", "z"},
      /* outputs */ {"y", "batch_mean", "batch_variance", "reserve_space_1", "reserve_space_2"});
  }
};

/// \brief Inference function declared for the BatchNorm operator; the definition is outside this header.
///
/// NOTE(review): presumably derives the abstract value (shape/type) of the outputs from the inputs — confirm
/// against the implementation.
///
/// \param[in] (unnamed) Analysis engine pointer; the parameter carries no name in this declaration.
/// \param[in] primitive Primitive pointer passed to the inference function.
/// \param[in] input_args Abstract values supplied as the operator's input arguments.
/// \return An abstract::AbstractBasePtr produced by the inference.
MIND_API abstract::AbstractBasePtr BatchNormInferFunc(const abstract::AnalysisEnginePtr &,
                                                      const PrimitivePtr &primitive,
                                                      const std::vector<abstract::AbstractBasePtr> &input_args);
}  // namespace ops
}  // namespace mindspore

#endif  // MINDSPORE_CORE_OPS_BATCH_NORMAL_H_