Program Listing for File model.h
↰ Return to documentation for file (include/include/api/model.h)
#ifndef MINDSPORE_INCLUDE_API_MODEL_H
#define MINDSPORE_INCLUDE_API_MODEL_H
#include <string>
#include <vector>
#include <map>
#include <memory>
#include <utility>
#include "include/api/status.h"
#include "include/api/types.h"
#include "include/api/graph.h"
#include "include/api/context.h"
#include "include/api/callback/callback.h"
#include "include/api/cell.h"
#include "include/api/cfg.h"
#include "include/api/dual_abi_helper.h"
namespace mindspore {
// Forward declarations: this header refers to these types only through pointers or std::shared_ptr
// (std::shared_ptr<ModelImpl>, Metrics *, std::shared_ptr<dataset::Dataset>), so their full definitions
// are not needed here.
class ModelImpl;
class Metrics;
namespace dataset {
class Dataset;
} // namespace dataset
class Model {
public:
Model();
~Model();
Model(const Model &) = delete;
void operator=(const Model &) = delete;
Status Build(GraphCell graph, const std::shared_ptr<Context> &model_context = nullptr,
const std::shared_ptr<TrainCfg> &train_cfg = nullptr);
Status BuildTransferLearning(GraphCell backbone, GraphCell head, const std::shared_ptr<Context> &context,
const std::shared_ptr<TrainCfg> &train_cfg = nullptr);
Status Resize(const std::vector<MSTensor> &inputs, const std::vector<std::vector<int64_t>> &dims);
Status UpdateWeights(const std::vector<MSTensor> &new_weights);
Status Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs,
const MSKernelCallBack &before = nullptr, const MSKernelCallBack &after = nullptr);
Status RunStep(const MSKernelCallBack &before = nullptr, const MSKernelCallBack &after = nullptr);
Status PredictWithPreprocess(const std::vector<std::vector<MSTensor>> &inputs, std::vector<MSTensor> *outputs,
const MSKernelCallBack &before = nullptr, const MSKernelCallBack &after = nullptr);
Status Preprocess(const std::vector<std::vector<MSTensor>> &inputs, std::vector<MSTensor> *outputs);
bool HasPreprocess();
inline Status LoadConfig(const std::string &config_path);
inline Status UpdateConfig(const std::string §ion, const std::pair<std::string, std::string> &config);
std::vector<MSTensor> GetInputs();
inline MSTensor GetInputByTensorName(const std::string &tensor_name);
std::vector<MSTensor> GetGradients() const;
Status ApplyGradients(const std::vector<MSTensor> &gradients);
std::vector<MSTensor> GetFeatureMaps() const;
Status UpdateFeatureMaps(const std::vector<MSTensor> &new_weights);
std::vector<MSTensor> GetOptimizerParams() const;
Status SetOptimizerParams(const std::vector<MSTensor> ¶ms);
Status SetupVirtualBatch(int virtual_batch_multiplier, float lr = -1.0f, float momentum = -1.0f);
Status SetLearningRate(float learning_rate);
float GetLearningRate();
Status InitMetrics(std::vector<Metrics *> metrics);
std::vector<Metrics *> GetMetrics();
std::vector<MSTensor> GetOutputs();
inline std::vector<std::string> GetOutputTensorNames();
inline MSTensor GetOutputByTensorName(const std::string &tensor_name);
inline std::vector<MSTensor> GetOutputsByNodeName(const std::string &node_name);
Status BindGLTexture2DMemory(const std::map<std::string, unsigned int> &inputGLTexture,
std::map<std::string, unsigned int> *outputGLTexture);
static bool CheckModelSupport(enum DeviceType device_type, ModelType model_type);
Status SetTrainMode(bool train);
bool GetTrainMode() const;
Status Train(int epochs, std::shared_ptr<dataset::Dataset> ds, std::vector<TrainCallBack *> cbs);
Status Evaluate(std::shared_ptr<dataset::Dataset> ds, std::vector<TrainCallBack *> cbs);
inline Status Build(const void *model_data, size_t data_size, ModelType model_type,
const std::shared_ptr<Context> &model_context = nullptr, const Key &dec_key = {},
const std::string &dec_mode = kDecModeAesGcm);
inline Status Build(const std::string &model_path, ModelType model_type,
const std::shared_ptr<Context> &model_context = nullptr, const Key &dec_key = {},
const std::string &dec_mode = kDecModeAesGcm);
private:
friend class Serialization;
// api without std::string
MSTensor GetInputByTensorName(const std::vector<char> &tensor_name);
std::vector<std::vector<char>> GetOutputTensorNamesChar();
MSTensor GetOutputByTensorName(const std::vector<char> &tensor_name);
std::vector<MSTensor> GetOutputsByNodeName(const std::vector<char> &node_name);
Status LoadConfig(const std::vector<char> &config_path);
Status UpdateConfig(const std::vector<char> §ion, const std::pair<std::vector<char>, std::vector<char>> &config);
Status Build(const void *model_data, size_t data_size, ModelType model_type,
const std::shared_ptr<Context> &model_context, const Key &dec_key, const std::vector<char> &dec_mode);
Status Build(const std::vector<char> &model_path, ModelType model_type, const std::shared_ptr<Context> &model_context,
const Key &dec_key, const std::vector<char> &dec_mode);
std::shared_ptr<ModelImpl> impl_;
};
/// std::string convenience overload: converts the tensor name with StringToChar and hands it to the
/// private std::vector<char> overload.
MSTensor Model::GetInputByTensorName(const std::string &tensor_name) {
  const std::vector<char> name_chars = StringToChar(tensor_name);
  return GetInputByTensorName(name_chars);
}
/// Fetches the output tensor names from GetOutputTensorNamesChar and converts them to std::string with
/// VectorCharToString.
std::vector<std::string> Model::GetOutputTensorNames() {
  const std::vector<std::vector<char>> raw_names = GetOutputTensorNamesChar();
  return VectorCharToString(raw_names);
}
/// std::string convenience overload: converts the tensor name with StringToChar and hands it to the
/// private std::vector<char> overload.
MSTensor Model::GetOutputByTensorName(const std::string &tensor_name) {
  const std::vector<char> name_chars = StringToChar(tensor_name);
  return GetOutputByTensorName(name_chars);
}
std::vector<MSTensor> Model::GetOutputsByNodeName(const std::string &node_name) {
return GetOutputsByNodeName(StringToChar(node_name));
}
/// std::string convenience overload: converts the path with StringToChar and hands it to the private
/// std::vector<char> overload.
Status Model::LoadConfig(const std::string &config_path) {
  const std::vector<char> path_chars = StringToChar(config_path);
  return LoadConfig(path_chars);
}
/// std::string convenience overload of UpdateConfig.
///
/// Converts the section name and both members of the config pair with StringToChar, then forwards to the
/// private std::vector<char> overload and returns its Status.
/// @param section Section name.
/// @param config Pair of strings; each member is converted independently.
Status Model::UpdateConfig(const std::string &section, const std::pair<std::string, std::string> &config) {
  std::pair<std::vector<char>, std::vector<char>> config_pair = {StringToChar(config.first),
                                                                 StringToChar(config.second)};
  return UpdateConfig(StringToChar(section), config_pair);
}
/// Buffer-based Build taking std::string: converts dec_mode with StringToChar and forwards every other
/// argument unchanged to the private std::vector<char> overload.
Status Model::Build(const void *model_data, size_t data_size, ModelType model_type,
                    const std::shared_ptr<Context> &model_context, const Key &dec_key, const std::string &dec_mode) {
  const std::vector<char> dec_mode_chars = StringToChar(dec_mode);
  return Build(model_data, data_size, model_type, model_context, dec_key, dec_mode_chars);
}
/// Path-based Build taking std::string: converts model_path and dec_mode with StringToChar and forwards
/// the remaining arguments unchanged to the private std::vector<char> overload.
Status Model::Build(const std::string &model_path, ModelType model_type, const std::shared_ptr<Context> &model_context,
                    const Key &dec_key, const std::string &dec_mode) {
  const std::vector<char> path_chars = StringToChar(model_path);
  const std::vector<char> dec_mode_chars = StringToChar(dec_mode);
  return Build(path_chars, model_type, model_context, dec_key, dec_mode_chars);
}
} // namespace mindspore
#endif // MINDSPORE_INCLUDE_API_MODEL_H