Program Listing for File model_parallel_runner.h

Return to documentation for file (include/runtime/include/api/model_parallel_runner.h)

#ifndef MINDSPORE_INCLUDE_API_MODEL_PARALLEL_RUNNER_H
#define MINDSPORE_INCLUDE_API_MODEL_PARALLEL_RUNNER_H
#include <vector>
#include <memory>
#include <utility>
#include <map>
#include <string>
#include "include/api/status.h"
#include "include/api/context.h"
namespace mindspore {
/// \brief Configuration object that ModelParallelRunner::Init accepts through a std::shared_ptr.
///
/// Every std::string based member below is declared inline and defined later in this header as a
/// thin wrapper: it converts its arguments or its result with the StringToChar / CharToString
/// family of helpers and forwards to a private std::vector<char> based counterpart. Those
/// counterparts, and all remaining members, are only declared in this header.
/// NOTE(review): this looks like the dual-ABI pattern (std::string kept out of the exported
/// symbols) - confirm against the helper header.
class MS_API RunnerConfig {
 public:
  /// Storage type behind data_; only forward-declared here, so its layout is not visible.
  struct Data;
  RunnerConfig();
  ~RunnerConfig();

  /// \brief Set the number of workers.
  /// \param[in] workers_num Requested worker count. NOTE(review): the accepted range and the
  ///   handling of zero or negative values are not visible in this header - confirm in the
  ///   implementation.
  void SetWorkersNum(int32_t workers_num);

  /// \brief Get the number of workers.
  /// \return Worker count; presumably the value given to SetWorkersNum - TODO confirm.
  int32_t GetWorkersNum() const;

  /// \brief Set the Context associated with this configuration.
  /// \param[in] context Shared pointer to the Context. NOTE(review): whether nullptr is accepted
  ///   is not visible here - confirm in the implementation.
  void SetContext(const std::shared_ptr<Context> &context);

  /// \brief Get the Context associated with this configuration.
  /// \return Shared pointer to a Context, returned by value; presumably the one passed to
  ///   SetContext - TODO confirm.
  std::shared_ptr<Context> GetContext() const;

  /// \brief Set the key/value configuration entries of one section.
  /// \param[in] section Section name.
  /// \param[in] config Key/value pairs for that section.
  /// Converts both arguments with StringToChar / MapStringToVectorChar and forwards to the
  /// private std::vector<char> overload.
  inline void SetConfigInfo(const std::string &section, const std::map<std::string, std::string> &config);

  /// \brief Get the configuration entries.
  /// \return Map of maps of strings, produced by converting the result of GetConfigInfoChar()
  ///   with MapMapCharToString. Presumably keyed by section name, then by entry key - TODO
  ///   confirm.
  inline std::map<std::string, std::map<std::string, std::string>> GetConfigInfo() const;

  /// \brief Set the configuration file path.
  /// \param[in] config_path Path string; converted with StringToChar and forwarded to the private
  ///   std::vector<char> overload.
  inline void SetConfigPath(const std::string &config_path);

  /// \brief Get the configuration file path.
  /// \return Result of GetConfigPathChar() converted with CharToString.
  inline std::string GetConfigPath() const;

 private:
  // std::vector<char> based counterparts that the inline std::string wrappers forward to.
  void SetConfigInfo(const std::vector<char> &section, const std::map<std::vector<char>, std::vector<char>> &config);
  std::map<std::vector<char>, std::map<std::vector<char>, std::vector<char>>> GetConfigInfoChar() const;
  void SetConfigPath(const std::vector<char> &config_path);
  std::vector<char> GetConfigPathChar() const;
  // Null by default; what the constructor assigns is not visible in this header.
  std::shared_ptr<Data> data_ = nullptr;
};

void RunnerConfig::SetConfigInfo(const std::string &section, const std::map<std::string, std::string> &config) {
  SetConfigInfo(StringToChar(section), MapStringToVectorChar(config));
}

/// \brief std::string front end for GetConfigInfoChar().
/// \return The std::vector<char> based map from GetConfigInfoChar(), converted with
///   MapMapCharToString.
std::map<std::string, std::map<std::string, std::string>> RunnerConfig::GetConfigInfo() const {
  std::map<std::string, std::map<std::string, std::string>> config_info =
    MapMapCharToString(GetConfigInfoChar());
  return config_info;
}

void RunnerConfig::SetConfigPath(const std::string &config_path) { SetConfigPath(StringToChar(config_path)); }

/// \brief std::string front end for GetConfigPathChar().
/// \return The result of GetConfigPathChar(), converted with CharToString.
std::string RunnerConfig::GetConfigPath() const {
  std::string config_path = CharToString(GetConfigPathChar());
  return config_path;
}

/// Type held by ModelParallelRunner through a shared pointer; only forward-declared in this header.
class ModelParallelRunnerImpl;

/// \brief Runner that is initialized from a model (file path or memory buffer) plus an optional
/// RunnerConfig, exposes input/output tensors, and runs Predict.
///
/// All members except the inline std::string Init overload are only declared in this header.
/// NOTE(review): the shared_ptr<ModelParallelRunnerImpl> member suggests the work is delegated to
/// the implementation type; threading and parallelism behavior is not visible here - confirm in
/// the implementation.
class MS_API ModelParallelRunner {
 public:
  ModelParallelRunner();
  ~ModelParallelRunner();

  /// \brief Initialize from a model path.
  /// \param[in] model_path Path string; converted with StringToChar and forwarded to the private
  ///   std::vector<char> overload.
  /// \param[in] runner_config Optional configuration; defaults to nullptr.
  /// \return Status returned by the private overload.
  inline Status Init(const std::string &model_path, const std::shared_ptr<RunnerConfig> &runner_config = nullptr);

  /// \brief Initialize from a memory buffer.
  /// \param[in] model_data Pointer to the buffer. NOTE(review): presumably serialized model bytes;
  ///   ownership and required lifetime are not visible here - confirm in the implementation.
  /// \param[in] data_size Size of the buffer; presumably in bytes - TODO confirm.
  /// \param[in] runner_config Optional configuration; defaults to nullptr.
  /// \return Status of the call.
  Status Init(const void *model_data, const size_t data_size,
              const std::shared_ptr<RunnerConfig> &runner_config = nullptr);

  /// \brief Get the input tensors.
  /// \return Vector of MSTensor, returned by value.
  std::vector<MSTensor> GetInputs();

  /// \brief Get the output tensors.
  /// \return Vector of MSTensor, returned by value.
  std::vector<MSTensor> GetOutputs();

  /// \brief Run prediction.
  /// \param[in] inputs Input tensors.
  /// \param[out] outputs Pointer to the vector that receives the results. NOTE(review): whether
  ///   nullptr is tolerated is not visible here - confirm in the implementation.
  /// \param[in] before Optional callback; defaults to nullptr. Presumably invoked before each
  ///   kernel runs - TODO confirm.
  /// \param[in] after Optional callback; defaults to nullptr. Presumably invoked after each kernel
  ///   runs - TODO confirm.
  /// \return Status of the call.
  Status Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs,
                 const MSKernelCallBack &before = nullptr, const MSKernelCallBack &after = nullptr);

 private:
  // std::vector<char> based counterpart that the inline std::string Init overload forwards to.
  Status Init(const std::vector<char> &model_path, const std::shared_ptr<RunnerConfig> &runner_config);
  // Null by default; what the constructor assigns is not visible in this header.
  std::shared_ptr<ModelParallelRunnerImpl> model_parallel_runner_impl_ = nullptr;
};

/// \brief std::string front end for the private std::vector<char> Init overload.
/// \param[in] model_path Path string, converted with StringToChar before forwarding.
/// \param[in] runner_config Passed through unchanged.
/// \return Status returned by the std::vector<char> overload.
Status ModelParallelRunner::Init(const std::string &model_path, const std::shared_ptr<RunnerConfig> &runner_config) {
  const std::vector<char> model_path_chars = StringToChar(model_path);
  // Returned directly so no extra Status copy or move is introduced.
  return Init(model_path_chars, runner_config);
}
}  // namespace mindspore
#endif  // MINDSPORE_INCLUDE_API_MODEL_PARALLEL_RUNNER_H