Program Listing for File cfg.h

Return to documentation for file (include/cfg.h)

#ifndef MINDSPORE_INCLUDE_API_CFG_H
#define MINDSPORE_INCLUDE_API_CFG_H

#include <cstddef>
#include <cstdint>
#include <memory>
#include <string>
#include <vector>
#include "include/api/data_type.h"
#include "include/api/dual_abi_helper.h"

namespace mindspore {

/// \brief Mixed-precision settings (embedded in TrainCfg as mix_precision_cfg_).
///
/// Every member carries its default as an in-class initializer, so a
/// default-constructed instance has:
///   dynamic_loss_scale_     == false
///   loss_scale_             == 128.0f
///   num_of_not_nan_iter_th_ == 1000
///   is_raw_mix_precision_   == false
class MixPrecisionCfg {
 public:
  MixPrecisionCfg() = default;

  /// Dynamic loss-scale switch (default: false).
  /// NOTE(review): semantics inferred from the name — confirm against the trainer.
  bool dynamic_loss_scale_ = false;
  /// Loss-scale factor (default: 128.0f).
  /// NOTE(review): presumably the initial scale when dynamic scaling is enabled — confirm.
  float loss_scale_ = 128.0f;
  /// Iteration-count threshold (default: 1000).
  /// NOTE(review): presumably the number of NaN-free iterations before the
  /// dynamic loss scale is adjusted — TODO confirm.
  uint32_t num_of_not_nan_iter_th_ = 1000;
  /// Raw mixed-precision flag (default: false); its meaning is not visible
  /// from this header.
  bool is_raw_mix_precision_ = false;
};

/// \brief Training configuration.
///
/// Every member carries its default as an in-class initializer, so a
/// default-constructed instance has:
///   optimization_level_   == kO0
///   loss_name_            == "_loss_fn"
///   mix_precision_cfg_    == a default-constructed MixPrecisionCfg
///   accumulate_gradients_ == false
class TrainCfg {
 public:
  TrainCfg() = default;

  /// Optimization level (default: kO0).
  OptimizationLevel optimization_level_ = kO0;
  /// Loss name (default: "_loss_fn"). Initialized in-class rather than
  /// default-constructed and then assigned in the constructor body.
  /// NOTE(review): presumably used to locate the loss in the model — confirm.
  std::string loss_name_ = "_loss_fn";
  /// Mixed-precision settings (defaults come from MixPrecisionCfg).
  MixPrecisionCfg mix_precision_cfg_;
  /// Gradient-accumulation switch (default: false).
  bool accumulate_gradients_ = false;
};

}  // namespace mindspore
#endif  // MINDSPORE_INCLUDE_API_CFG_H