Program Listing for File model.h
↰ Return to documentation for file (include/model.h
)
#ifndef MINDSPORE_INCLUDE_API_MODEL_H
#define MINDSPORE_INCLUDE_API_MODEL_H
#include <string>
#include <vector>
#include <map>
#include <memory>
#include <utility>
#include "include/api/status.h"
#include "include/api/types.h"
#include "include/api/graph.h"
#include "include/api/context.h"
#include "include/api/callback/callback.h"
#include "include/api/cell.h"
#include "include/api/cfg.h"
#include "include/api/dual_abi_helper.h"
namespace mindspore {
class ModelImpl;
class Metrics;
namespace dataset {
class Dataset;
} // namespace dataset
class Model {
public:
Model();
~Model();
Model(const Model &) = delete;
void operator=(const Model &) = delete;
Status Build(GraphCell graph, const std::shared_ptr<Context> &model_context = nullptr,
const std::shared_ptr<TrainCfg> &train_cfg = nullptr);
Status Resize(const std::vector<MSTensor> &inputs, const std::vector<std::vector<int64_t>> &dims);
Status Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs,
const MSKernelCallBack &before = nullptr, const MSKernelCallBack &after = nullptr);
Status PredictWithPreprocess(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs,
const MSKernelCallBack &before = nullptr, const MSKernelCallBack &after = nullptr);
Status Preprocess(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs);
bool HasPreprocess();
inline Status LoadConfig(const std::string &config_path);
std::vector<MSTensor> GetInputs();
inline MSTensor GetInputByTensorName(const std::string &tensor_name);
std::vector<MSTensor> GetGradients() const;
Status ApplyGradients(const std::vector<MSTensor> &gradients);
std::vector<MSTensor> GetOptimizerParams() const;
Status SetOptimizerParams(const std::vector<MSTensor> ¶ms);
Status InitMetrics(std::vector<Metrics *> metrics);
std::vector<Metrics *> GetMetrics();
std::vector<MSTensor> GetOutputs();
inline std::vector<std::string> GetOutputTensorNames();
inline MSTensor GetOutputByTensorName(const std::string &tensor_name);
inline std::vector<MSTensor> GetOutputsByNodeName(const std::string &node_name);
static bool CheckModelSupport(enum DeviceType device_type, ModelType model_type);
Status SetTrainMode(bool train);
bool GetTrainMode() const;
Status Train(int epochs, std::shared_ptr<dataset::Dataset> ds, std::vector<TrainCallBack *> cbs);
Status Evaluate(std::shared_ptr<dataset::Dataset> ds, std::vector<TrainCallBack *> cbs);
inline Status Build(const void *model_data, size_t data_size, ModelType model_type,
const std::shared_ptr<Context> &model_context = nullptr, const Key &dec_key = {},
const std::string &dec_mode = kDecModeAesGcm);
inline Status Build(const std::string &model_path, ModelType model_type,
const std::shared_ptr<Context> &model_context = nullptr, const Key &dec_key = {},
const std::string &dec_mode = kDecModeAesGcm);
private:
friend class Serialization;
// api without std::string
MSTensor GetInputByTensorName(const std::vector<char> &tensor_name);
std::vector<std::vector<char>> GetOutputTensorNamesChar();
MSTensor GetOutputByTensorName(const std::vector<char> &tensor_name);
std::vector<MSTensor> GetOutputsByNodeName(const std::vector<char> &node_name);
Status LoadConfig(const std::vector<char> &config_path);
Status Build(const void *model_data, size_t data_size, ModelType model_type,
const std::shared_ptr<Context> &model_context, const Key &dec_key, const std::vector<char> &dec_mode);
Status Build(const std::vector<char> &model_path, ModelType model_type, const std::shared_ptr<Context> &model_context,
const Key &dec_key, const std::vector<char> &dec_mode);
std::shared_ptr<ModelImpl> impl_;
};
MSTensor Model::GetInputByTensorName(const std::string &tensor_name) {
return GetInputByTensorName(StringToChar(tensor_name));
}
std::vector<std::string> Model::GetOutputTensorNames() { return VectorCharToString(GetOutputTensorNamesChar()); }
MSTensor Model::GetOutputByTensorName(const std::string &tensor_name) {
return GetOutputByTensorName(StringToChar(tensor_name));
}
std::vector<MSTensor> Model::GetOutputsByNodeName(const std::string &node_name) {
return GetOutputsByNodeName(StringToChar(node_name));
}
Status Model::LoadConfig(const std::string &config_path) {
return LoadConfig(StringToChar(config_path));
}
Status Model::Build(const void *model_data, size_t data_size, ModelType model_type,
const std::shared_ptr<Context> &model_context, const Key &dec_key, const std::string &dec_mode) {
return Build(model_data, data_size, model_type, model_context, dec_key, StringToChar(dec_mode));
}
Status Model::Build(const std::string &model_path, ModelType model_type, const std::shared_ptr<Context> &model_context,
const Key &dec_key, const std::string &dec_mode) {
return Build(StringToChar(model_path), model_type, model_context, dec_key, StringToChar(dec_mode));
}
} // namespace mindspore
#endif // MINDSPORE_INCLUDE_API_MODEL_H