Program Listing for File batch_norm.h
↰ Return to documentation for file (include/converter/include/ops/batch_norm.h)
#ifndef MINDSPORE_CORE_OPS_BATCH_NORMAL_H_
#define MINDSPORE_CORE_OPS_BATCH_NORMAL_H_
#include <map>
#include <memory>
#include <string>
#include <vector>
#include "mindapi/base/format.h"
#include "ops/base_operator.h"
namespace mindspore {
namespace ops {
// Names handed to the operator constructors declared in this header; one constant per operator class.
constexpr const char *kNameBatchNorm = "BatchNorm";
constexpr const char *kNameBatchNormWithActivation = "BatchNormWithActivation";
constexpr const char *kNameBatchNormWithAddAndActivation = "BatchNormWithAddAndActivation";
/// \brief Operator class constructed under the name kNameBatchNorm ("BatchNorm") by default.
///
/// Construction registers five input names ("x", "scale", "offset", "mean", "variance") and five
/// output names ("y", "batch_mean", "batch_variance", "reserve_space_1", "reserve_space_2") through
/// InitIOName(). Init() and the set_*/get_* accessors are only declared here; their definitions are
/// not part of this header.
class MIND_API BatchNorm : public BaseOperator {
 public:
  MIND_API_BASE_MEMBER(BatchNorm);

  /// \brief Construct the operator under the default name kNameBatchNorm.
  ///
  /// Delegates to the named constructor so the input/output name lists are spelled out only once.
  BatchNorm() : BatchNorm(kNameBatchNorm) {}

  /// \brief Construct the operator under a caller-supplied name.
  ///
  /// The derived operators in this header use this constructor to reuse BatchNorm under their own
  /// names.
  ///
  /// \param[in] kernel_name Name forwarded to the BaseOperator constructor. Taken by const
  ///            reference so the argument is not copied into the parameter.
  explicit BatchNorm(const std::string &kernel_name) : BaseOperator(kernel_name) {
    InitIOName({"x", "scale", "offset", "mean", "variance"},
               {"y", "batch_mean", "batch_variance", "reserve_space_1", "reserve_space_2"});
  }

  /// \brief Initialise the operator from the four values below.
  ///
  /// NOTE(review): presumably forwards each argument to the matching set_* method -- the
  /// definition is not visible in this header, confirm against the implementation file.
  ///
  /// \param[in] is_training Defaults to false.
  /// \param[in] epsilon Defaults to 1e-5.
  /// \param[in] momentum Defaults to 0.1.
  /// \param[in] format Defaults to NCHW.
  void Init(const bool is_training = false, const float epsilon = 1e-5, const float momentum = 0.1,
            const Format &format = NCHW);

  // Setters, one per Init() argument (declarations only).
  void set_is_training(const bool is_training);
  void set_epsilon(const float epsilon);
  void set_format(const Format &format);
  void set_momentum(const float momentum);

  // Getters mirroring the setters above (declarations only).
  bool get_is_training() const;
  float get_epsilon() const;
  Format get_format() const;
  float get_momentum() const;
};
/// \brief BatchNorm subclass constructed under the name kNameBatchNormWithActivation.
///
/// Declares nothing beyond its default constructor and whatever MIND_API_BASE_MEMBER expands to;
/// Init() and the attribute accessors are inherited from BatchNorm.
class MIND_API BatchNormWithActivation : public BatchNorm {
public:
MIND_API_BASE_MEMBER(BatchNormWithActivation);
/// \brief Construct the operator, passing kNameBatchNormWithActivation to the BatchNorm(kernel_name) constructor.
BatchNormWithActivation() : BatchNorm(kNameBatchNormWithActivation) {
// Same five input names and five output names that BatchNorm(kernel_name) has already registered.
// NOTE(review): this second call looks redundant -- confirm InitIOName simply overwrites before removing it.
InitIOName({"x", "scale", "offset", "mean", "variance"},
{"y", "batch_mean", "batch_variance", "reserve_space_1", "reserve_space_2"});
}
};
/// \brief BatchNorm subclass constructed under the name kNameBatchNormWithAddAndActivation.
///
/// Differs from BatchNorm in the registered input names: a sixth input, "z", is appended.
/// The output names are the same five as BatchNorm.
class MIND_API BatchNormWithAddAndActivation : public BatchNorm {
public:
MIND_API_BASE_MEMBER(BatchNormWithAddAndActivation);
/// \brief Construct the operator, passing kNameBatchNormWithAddAndActivation to the BatchNorm(kernel_name) constructor.
BatchNormWithAddAndActivation() : BatchNorm(kNameBatchNormWithAddAndActivation) {
// Called again after the base constructor so the input list carries the extra "z" entry.
// NOTE(review): "z" is presumably the operand of the fused add -- confirm against the kernel.
InitIOName({"x", "scale", "offset", "mean", "variance", "z"},
{"y", "batch_mean", "batch_variance", "reserve_space_1", "reserve_space_2"});
}
};
/// \brief Declaration of BatchNormInferFunc; the definition is not part of this header.
///
/// NOTE(review): by its name and signature this is presumably the shape/type inference entry
/// point for the BatchNorm primitive -- confirm against the implementation file.
///
/// \param[in] (unnamed) Analysis engine pointer; the parameter is left unnamed in this declaration.
/// \param[in] primitive Primitive pointer passed by const reference.
/// \param[in] input_args Vector of abstract::AbstractBasePtr passed by const reference.
/// \return An abstract::AbstractBasePtr.
MIND_API abstract::AbstractBasePtr BatchNormInferFunc(const abstract::AnalysisEnginePtr &,
const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args);
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_BATCH_NORMAL_H_