Program Listing for File model.h

Return to documentation for file (include/include/api/model.h)

#ifndef MINDSPORE_INCLUDE_API_MODEL_H
#define MINDSPORE_INCLUDE_API_MODEL_H

#include <string>
#include <vector>
#include <map>
#include <memory>
#include <utility>
#include "include/api/status.h"
#include "include/api/types.h"
#include "include/api/graph.h"
#include "include/api/context.h"
#include "include/api/callback/callback.h"
#include "include/api/cell.h"
#include "include/api/cfg.h"
#include "include/api/dual_abi_helper.h"

namespace mindspore {
// Forward declarations. Within this header these types are only used through
// pointers or smart pointers (std::shared_ptr<ModelImpl>, Metrics *,
// std::shared_ptr<dataset::Dataset>), so their full definitions are not needed here.
class ModelImpl;
class Metrics;

namespace dataset {
class Dataset;
}  // namespace dataset
/// \brief Public handle used to build, configure, run and train a model.
///
/// Copying is disabled. The object's only data member is impl_, a shared pointer to the
/// forward-declared ModelImpl. Public methods that take or return std::string are declared
/// inline and forward to private overloads that use std::vector<char> instead (their
/// definitions follow this class in the same header).
class  Model {
 public:
  Model();
  ~Model();
  // Copy construction and copy assignment are deleted: a Model cannot be copied.
  Model(const Model &) = delete;
  void operator=(const Model &) = delete;

  /// \brief Build the model from a GraphCell.
  ///
  /// \param[in] graph Graph cell to build from (taken by value).
  /// \param[in] model_context Optional context; defaults to nullptr.
  /// \param[in] train_cfg Optional training configuration; defaults to nullptr.
  ///
  /// \return Status of the operation.
  Status Build(GraphCell graph, const std::shared_ptr<Context> &model_context = nullptr,
               const std::shared_ptr<TrainCfg> &train_cfg = nullptr);

  /// \brief Build a transfer-learning model from a backbone graph and a head graph.
  ///
  /// \param[in] backbone Backbone graph cell.
  /// \param[in] head Head graph cell.
  /// \param[in] context Context to build with (no default, unlike Build()).
  /// \param[in] train_cfg Optional training configuration; defaults to nullptr.
  ///
  /// \return Status of the operation.
  Status BuildTransferLearning(GraphCell backbone, GraphCell head, const std::shared_ptr<Context> &context,
                               const std::shared_ptr<TrainCfg> &train_cfg = nullptr);

  /// \brief Resize input tensors.
  ///
  /// \param[in] inputs Input tensors to resize.
  /// \param[in] dims New shapes. NOTE(review): presumably one shape per entry of `inputs`,
  ///            in the same order - confirm against the implementation.
  ///
  /// \return Status of the operation.
  Status Resize(const std::vector<MSTensor> &inputs, const std::vector<std::vector<int64_t>> &dims);

  /// \brief Update the model weights from the given tensors.
  ///
  /// \param[in] new_weights Tensors carrying the new weights.
  ///
  /// \return Status of the operation.
  Status UpdateWeights(const std::vector<MSTensor> &new_weights);

  /// \brief Run the model on `inputs` and store the results through `outputs`.
  ///
  /// \param[in] inputs Input tensors.
  /// \param[out] outputs Pointer to the vector receiving the output tensors.
  /// \param[in] before Optional kernel callback; defaults to nullptr.
  /// \param[in] after Optional kernel callback; defaults to nullptr.
  ///            NOTE(review): presumably `before`/`after` run around each kernel - confirm.
  ///
  /// \return Status of the operation.
  Status Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs,
                 const MSKernelCallBack &before = nullptr, const MSKernelCallBack &after = nullptr);

  /// \brief Run a single step, with optional kernel callbacks (both default to nullptr).
  ///
  /// \return Status of the operation.
  Status RunStep(const MSKernelCallBack &before = nullptr, const MSKernelCallBack &after = nullptr);

  /// \brief Variant of Predict() that takes nested input vectors and applies preprocessing.
  ///
  /// \param[in] inputs Nested vector of input tensors.
  /// \param[out] outputs Pointer to the vector receiving the output tensors.
  /// \param[in] before Optional kernel callback; defaults to nullptr.
  /// \param[in] after Optional kernel callback; defaults to nullptr.
  ///
  /// \return Status of the operation.
  Status PredictWithPreprocess(const std::vector<std::vector<MSTensor>> &inputs, std::vector<MSTensor> *outputs,
                               const MSKernelCallBack &before = nullptr, const MSKernelCallBack &after = nullptr);

  /// \brief Apply preprocessing only.
  ///
  /// \param[in] inputs Nested vector of input tensors.
  /// \param[out] outputs Pointer to the vector receiving the preprocessed tensors.
  ///
  /// \return Status of the operation.
  Status Preprocess(const std::vector<std::vector<MSTensor>> &inputs, std::vector<MSTensor> *outputs);

  /// \brief Report whether the model has a preprocess stage.
  bool HasPreprocess();

  /// \brief Load configuration from a file path. Forwards to the std::vector<char> overload.
  ///
  /// \param[in] config_path Path of the configuration file.
  ///
  /// \return Status of the operation.
  inline Status LoadConfig(const std::string &config_path);

  /// \brief Update one configuration entry. Forwards to the std::vector<char> overload.
  ///
  /// \param[in] section Configuration section name.
  /// \param[in] config Pair of strings. NOTE(review): presumably (key, value) - confirm.
  ///
  /// \return Status of the operation.
  inline Status UpdateConfig(const std::string &section, const std::pair<std::string, std::string> &config);

  /// \brief Get the model's input tensors.
  std::vector<MSTensor> GetInputs();

  /// \brief Get an input tensor by name. Forwards to the std::vector<char> overload.
  ///
  /// \param[in] tensor_name Name of the input tensor.
  inline MSTensor GetInputByTensorName(const std::string &tensor_name);

  /// \brief Get the gradient tensors.
  std::vector<MSTensor> GetGradients() const;

  /// \brief Apply the given gradient tensors.
  ///
  /// \return Status of the operation.
  Status ApplyGradients(const std::vector<MSTensor> &gradients);

  /// \brief Get the feature-map tensors.
  std::vector<MSTensor> GetFeatureMaps() const;

  /// \brief Update the feature maps from the given tensors.
  ///
  /// \return Status of the operation.
  Status UpdateFeatureMaps(const std::vector<MSTensor> &new_weights);

  /// \brief Get the optimizer parameter tensors.
  std::vector<MSTensor> GetOptimizerParams() const;

  /// \brief Set the optimizer parameter tensors.
  ///
  /// \return Status of the operation.
  Status SetOptimizerParams(const std::vector<MSTensor> &params);

  /// \brief Configure virtual batching.
  ///
  /// \param[in] virtual_batch_multiplier Virtual batch multiplier.
  /// \param[in] lr Learning rate; defaults to -1.0f.
  /// \param[in] momentum Momentum; defaults to -1.0f.
  ///            NOTE(review): the -1.0f defaults look like "not specified" sentinels - confirm.
  ///
  /// \return Status of the operation.
  Status SetupVirtualBatch(int virtual_batch_multiplier, float lr = -1.0f, float momentum = -1.0f);

  /// \brief Set the learning rate.
  ///
  /// \return Status of the operation.
  Status SetLearningRate(float learning_rate);

  /// \brief Get the learning rate.
  float GetLearningRate();

  /// \brief Register metrics objects (raw pointers; ownership is not expressed by the signature).
  ///
  /// \return Status of the operation.
  Status InitMetrics(std::vector<Metrics *> metrics);
  /// \brief Get the metrics objects as raw pointers.
  std::vector<Metrics *> GetMetrics();

  /// \brief Get the model's output tensors.
  std::vector<MSTensor> GetOutputs();

  /// \brief Get the names of the output tensors. Converts the result of GetOutputTensorNamesChar().
  inline std::vector<std::string> GetOutputTensorNames();

  /// \brief Get an output tensor by name. Forwards to the std::vector<char> overload.
  ///
  /// \param[in] tensor_name Name of the output tensor.
  inline MSTensor GetOutputByTensorName(const std::string &tensor_name);

  /// \brief Get output tensors by node name. Forwards to the std::vector<char> overload.
  ///
  /// \param[in] node_name Name of the node.
  inline std::vector<MSTensor> GetOutputsByNodeName(const std::string &node_name);


  /// \brief Bind OpenGL 2D textures to the model's inputs and outputs.
  ///
  /// \param[in] inputGLTexture Map from string to unsigned int for the inputs.
  /// \param[out] outputGLTexture Pointer to a map from string to unsigned int for the outputs.
  ///             NOTE(review): presumably tensor name -> GL texture id - confirm.
  ///
  /// \return Status of the operation.
  Status BindGLTexture2DMemory(const std::map<std::string, unsigned int> &inputGLTexture,
                               std::map<std::string, unsigned int> *outputGLTexture);

  /// \brief Check whether a device type / model type combination is supported.
  ///
  /// Static: callable without a Model instance.
  static bool CheckModelSupport(enum DeviceType device_type, ModelType model_type);

  /// \brief Switch train mode on or off.
  Status SetTrainMode(bool train);
  /// \brief Query the current train mode.
  bool GetTrainMode() const;
  /// \brief Train for `epochs` epochs on dataset `ds`, with the training callbacks `cbs`.
  Status Train(int epochs, std::shared_ptr<dataset::Dataset> ds, std::vector<TrainCallBack *> cbs);
  /// \brief Evaluate on dataset `ds`, with the training callbacks `cbs`.
  Status Evaluate(std::shared_ptr<dataset::Dataset> ds, std::vector<TrainCallBack *> cbs);

  /// \brief Build the model from an in-memory buffer.
  ///
  /// \param[in] model_data Pointer to the model data.
  /// \param[in] data_size Size of the model data. NOTE(review): presumably in bytes - confirm.
  /// \param[in] model_type Type of the model.
  /// \param[in] model_context Optional context; defaults to nullptr.
  ///
  /// \return Status of the operation.
  Status Build(const void *model_data, size_t data_size, ModelType model_type,
               const std::shared_ptr<Context> &model_context = nullptr);

  /// \brief Build the model from a file path. Forwards to the std::vector<char> overload.
  ///
  /// \param[in] model_path Path of the model file.
  /// \param[in] model_type Type of the model.
  /// \param[in] model_context Optional context; defaults to nullptr.
  ///
  /// \return Status of the operation.
  inline Status Build(const std::string &model_path, ModelType model_type,
                      const std::shared_ptr<Context> &model_context = nullptr);

  /// \brief Build the model from an in-memory buffer, with decryption parameters.
  ///
  /// \param[in] model_data Pointer to the model data.
  /// \param[in] data_size Size of the model data.
  /// \param[in] model_type Type of the model.
  /// \param[in] model_context Context to build with (no default in this overload).
  /// \param[in] dec_key Decryption key.
  /// \param[in] dec_mode Decryption mode.
  /// \param[in] cropto_lib_path Path of the crypto library.
  ///            NOTE(review): `cropto` looks like a misspelling of `crypto`; name kept as is.
  ///
  /// \return Status of the operation.
  Status Build(const void *model_data, size_t data_size, ModelType model_type,
               const std::shared_ptr<Context> &model_context, const Key &dec_key, const std::string &dec_mode,
               const std::string &cropto_lib_path);

  /// \brief Build the model from a file path, with decryption parameters.
  ///        Forwards to the std::vector<char> overload.
  ///
  /// \param[in] model_path Path of the model file.
  /// \param[in] model_type Type of the model.
  /// \param[in] model_context Context to build with (no default in this overload).
  /// \param[in] dec_key Decryption key.
  /// \param[in] dec_mode Decryption mode.
  /// \param[in] cropto_lib_path Path of the crypto library (spelling as declared).
  ///
  /// \return Status of the operation.
  inline Status Build(const std::string &model_path, ModelType model_type,
                      const std::shared_ptr<Context> &model_context, const Key &dec_key, const std::string &dec_mode,
                      const std::string &cropto_lib_path);

 private:
  friend class Serialization;
  // Counterparts of the public std::string APIs that exchange std::vector<char> instead.
  // The inline public wrappers convert their arguments and call these.
  // NOTE(review): the decrypting Build overload below still takes `dec_mode` as std::string,
  // unlike its other string-like parameters - confirm this is intentional.
  MSTensor GetInputByTensorName(const std::vector<char> &tensor_name);
  std::vector<std::vector<char>> GetOutputTensorNamesChar();
  MSTensor GetOutputByTensorName(const std::vector<char> &tensor_name);
  std::vector<MSTensor> GetOutputsByNodeName(const std::vector<char> &node_name);
  Status LoadConfig(const std::vector<char> &config_path);
  Status UpdateConfig(const std::vector<char> &section, const std::pair<std::vector<char>, std::vector<char>> &config);
  Status Build(const std::vector<char> &model_path, ModelType model_type,
               const std::shared_ptr<Context> &model_context);
  Status Build(const std::vector<char> &model_path, ModelType model_type, const std::shared_ptr<Context> &model_context,
               const Key &dec_key, const std::string &dec_mode, const std::vector<char> &cropto_lib_path);
  // Shared pointer to the implementation object (type defined outside this header).
  std::shared_ptr<ModelImpl> impl_;
};

// std::string front end: converts the tensor name with StringToChar() and
// delegates to the char-vector overload of GetInputByTensorName().
MSTensor Model::GetInputByTensorName(const std::string &tensor_name) {
  const auto name_chars = StringToChar(tensor_name);
  return GetInputByTensorName(name_chars);
}

// std::string front end: fetches the names via GetOutputTensorNamesChar() and
// converts them with VectorCharToString().
std::vector<std::string> Model::GetOutputTensorNames() {
  const auto raw_names = GetOutputTensorNamesChar();
  return VectorCharToString(raw_names);
}

// std::string front end: converts the tensor name with StringToChar() and
// delegates to the char-vector overload of GetOutputByTensorName().
MSTensor Model::GetOutputByTensorName(const std::string &tensor_name) {
  const auto name_chars = StringToChar(tensor_name);
  return GetOutputByTensorName(name_chars);
}

// std::string front end: converts the node name with StringToChar() and
// delegates to the char-vector overload of GetOutputsByNodeName().
std::vector<MSTensor> Model::GetOutputsByNodeName(const std::string &node_name) {
  const auto node_chars = StringToChar(node_name);
  return GetOutputsByNodeName(node_chars);
}

// std::string front end: converts the path with StringToChar() and delegates
// to the char-vector overload of LoadConfig().
Status Model::LoadConfig(const std::string &config_path) {
  const auto path_chars = StringToChar(config_path);
  return LoadConfig(path_chars);
}

// std::string front end: converts the section name and both halves of the
// config pair with StringToChar(), then delegates to the char-vector overload
// of UpdateConfig().
Status Model::UpdateConfig(const std::string &section, const std::pair<std::string, std::string> &config) {
  using CharVec = std::vector<char>;
  const auto section_chars = StringToChar(section);
  const std::pair<CharVec, CharVec> entry_chars(StringToChar(config.first), StringToChar(config.second));
  return UpdateConfig(section_chars, entry_chars);
}

// std::string front end for building from a file path with decryption
// parameters: converts the model path and the library path with StringToChar()
// and delegates to the char-vector overload of Build(). dec_key and dec_mode
// are passed through unchanged.
Status Model::Build(const std::string &model_path, ModelType model_type, const std::shared_ptr<Context> &model_context,
                    const Key &dec_key, const std::string &dec_mode, const std::string &cropto_lib_path) {
  const auto path_chars = StringToChar(model_path);
  const auto lib_path_chars = StringToChar(cropto_lib_path);
  return Build(path_chars, model_type, model_context, dec_key, dec_mode, lib_path_chars);
}

// std::string front end for building from a file path: converts the model path
// with StringToChar() and delegates to the char-vector overload of Build().
Status Model::Build(const std::string &model_path, ModelType model_type,
                    const std::shared_ptr<Context> &model_context) {
  const auto path_chars = StringToChar(model_path);
  return Build(path_chars, model_type, model_context);
}
}  // namespace mindspore
#endif  // MINDSPORE_INCLUDE_API_MODEL_H