Program Listing for File context.h

Return to documentation for file (include/context.h)

#ifndef MINDSPORE_INCLUDE_API_CONTEXT_H
#define MINDSPORE_INCLUDE_API_CONTEXT_H

#include <cstddef>
#include <cstdint>
#include <map>
#include <memory>
#include <string>
#include <type_traits>
#include <vector>
#include "include/api/types.h"
#include "include/api/dual_abi_helper.h"

namespace mindspore {
/// Kinds of device a DeviceInfoContext subclass can describe.
///
/// Every value is written out explicitly. The numbers are observable by callers, so
/// they must not shift when entries are added. Value 5 is not used by any enumerator
/// listed here (presumably reserved; TODO confirm). kInvalidDeviceType is presumably a
/// "no/unknown device" sentinel; TODO confirm.
enum DeviceType {
  kCPU = 0,
  kGPU = 1,
  kKirinNPU = 2,
  kAscend910 = 3,
  kAscend310 = 4,
  kHexagonDSP = 6,
  // add new type here
  kInvalidDeviceType = 100,
};

class Allocator;
class Delegate;
class DeviceInfoContext;

class  Context {
 public:
  struct Data;
  Context();
  ~Context() = default;

  void SetThreadNum(int32_t thread_num);

  int32_t GetThreadNum() const;

  void SetThreadAffinity(int mode);

  int GetThreadAffinityMode() const;

  void SetThreadAffinity(const std::vector<int> &core_list);

  std::vector<int32_t> GetThreadAffinityCoreList() const;

  void SetEnableParallel(bool is_parallel);

  bool GetEnableParallel() const;

  void SetDelegate(const std::shared_ptr<Delegate> &delegate);

  std::shared_ptr<Delegate> GetDelegate() const;

  std::vector<std::shared_ptr<DeviceInfoContext>> &MutableDeviceInfo();

 private:
  std::shared_ptr<Data> data_;
};

/// Abstract base for the settings of one device attached to a Context.
///
/// Subclasses report which DeviceType they describe via GetDeviceType(). The settings
/// themselves live in an opaque `Data` struct (declared here, defined out of line) that
/// is held by std::shared_ptr. The std::string accessors are inline wrappers that
/// convert with StringToChar/CharToString and forward to protected std::vector<char>
/// overloads implemented out of line. This presumably keeps std::string out of the
/// exported binary interface (see dual_abi_helper.h); TODO confirm.
class DeviceInfoContext : public std::enable_shared_from_this<DeviceInfoContext> {
 public:
  struct Data;

  DeviceInfoContext();
  virtual ~DeviceInfoContext() = default;

  /// Device kind described by this object; implemented by each concrete subclass.
  virtual enum DeviceType GetDeviceType() const = 0;

  /// Downcasts to the concrete device-info type T, sharing ownership with the existing
  /// shared_ptr. Returns nullptr when this object's device type differs from T's.
  ///
  /// T's device type is obtained from a default-constructed temporary. T must therefore
  /// be default-constructible, and one temporary T is built per call.
  /// NOTE(review): shared_from_this() requires *this to already be owned by a
  /// std::shared_ptr. Otherwise it throws std::bad_weak_ptr (C++17) or has undefined
  /// behaviour (earlier standards).
  template <class T>
  std::shared_ptr<T> Cast() {
    static_assert(std::is_base_of<DeviceInfoContext, T>::value, "Wrong cast type.");
    const auto target_type = T().GetDeviceType();
    if (target_type == GetDeviceType()) {
      return std::static_pointer_cast<T>(shared_from_this());
    }
    return nullptr;
  }

  /// Provider name, as a std::string wrapper over the protected char-vector accessors.
  inline std::string GetProvider() const;
  inline void SetProvider(const std::string &provider);

  /// Provider-specific device name, as a std::string wrapper.
  inline std::string GetProviderDevice() const;
  inline void SetProviderDevice(const std::string &device);

  /// Shares ownership of `allocator` with this object (defined out of line).
  void SetAllocator(const std::shared_ptr<Allocator> &allocator);
  /// Getter counterpart of SetAllocator().
  std::shared_ptr<Allocator> GetAllocator() const;

 protected:
  std::vector<char> GetProviderChar() const;
  void SetProvider(const std::vector<char> &provider);
  std::vector<char> GetProviderDeviceChar() const;
  void SetProviderDevice(const std::vector<char> &device);

  std::shared_ptr<Data> data_;
};

std::string DeviceInfoContext::GetProvider() const {
  return CharToString(this->GetProviderChar());
}

void DeviceInfoContext::SetProvider(const std::string &provider) {
  this->SetProvider(StringToChar(provider));
}

std::string DeviceInfoContext::GetProviderDevice() const {
  return CharToString(this->GetProviderDeviceChar());
}

void DeviceInfoContext::SetProviderDevice(const std::string &device) {
  this->SetProviderDevice(StringToChar(device));
}

/// Settings for running on the CPU (DeviceType::kCPU).
class CPUDeviceInfo : public DeviceInfoContext {
 public:
  /// Always reports DeviceType::kCPU.
  enum DeviceType GetDeviceType() const override { return DeviceType::kCPU; }

  /// Toggles an FP16 flag; presumably enables float16 computation where supported (TODO confirm).
  void SetEnableFP16(bool is_fp16);

  /// Getter counterpart of SetEnableFP16() (defined out of line).
  bool GetEnableFP16() const;
};

/// Settings for running on a Kirin NPU (DeviceType::kKirinNPU).
class KirinNPUDeviceInfo : public DeviceInfoContext {
 public:
  /// Always reports DeviceType::kKirinNPU.
  enum DeviceType GetDeviceType() const override { return DeviceType::kKirinNPU; }

  /// Sets a frequency setting as an int; valid values and their meaning are
  /// implementation-defined (TODO confirm).
  void SetFrequency(int frequency);

  /// Getter counterpart of SetFrequency() (defined out of line).
  int GetFrequency() const;
};

/// Settings for running on a GPU (DeviceType::kGPU).
///
/// The std::string precision-mode accessors are inline wrappers, defined below the
/// class, that forward to private std::vector<char> overloads implemented out of line.
class GPUDeviceInfo : public DeviceInfoContext {
 public:
  /// Always reports DeviceType::kGPU.
  enum DeviceType GetDeviceType() const override { return DeviceType::kGPU; }

  /// Selects the GPU to use; presumably a zero-based device index (TODO confirm).
  void SetDeviceID(uint32_t device_id);

  /// Getter counterpart of SetDeviceID() (defined out of line).
  uint32_t GetDeviceID() const;

  /// Precision mode as a string; accepted values are not visible here (TODO confirm).
  inline void SetPrecisionMode(const std::string &precision_mode);

  /// Getter counterpart of SetPrecisionMode().
  inline std::string GetPrecisionMode() const;

  /// Toggles an FP16 flag; presumably enables float16 computation (TODO confirm).
  void SetEnableFP16(bool is_fp16);

  /// Getter counterpart of SetEnableFP16() (defined out of line).
  bool GetEnableFP16() const;

 private:
  void SetPrecisionMode(const std::vector<char> &precision_mode);
  std::vector<char> GetPrecisionModeChar() const;
};

void GPUDeviceInfo::SetPrecisionMode(const std::string &precision_mode) {
  SetPrecisionMode(StringToChar(precision_mode));
}
std::string GPUDeviceInfo::GetPrecisionMode() const { return CharToString(GetPrecisionModeChar()); }

/// Settings for running on an Ascend 910 device (DeviceType::kAscend910).
class Ascend910DeviceInfo : public DeviceInfoContext {
 public:
  /// Always reports DeviceType::kAscend910.
  enum DeviceType GetDeviceType() const override { return DeviceType::kAscend910; }

  /// Selects the device to use; presumably a zero-based device index (TODO confirm).
  void SetDeviceID(uint32_t device_id);

  /// Getter counterpart of SetDeviceID() (defined out of line).
  uint32_t GetDeviceID() const;
};

/// Settings for running on an Ascend 310 device (DeviceType::kAscend310).
///
/// Each std::string accessor is an inline wrapper, defined below the class, that
/// converts with StringToChar / CharToString and forwards to a private std::vector<char>
/// overload implemented out of line. This presumably keeps std::string out of the
/// exported binary interface (see dual_abi_helper.h); TODO confirm.
class Ascend310DeviceInfo : public DeviceInfoContext {
 public:
  /// Always reports DeviceType::kAscend310.
  enum DeviceType GetDeviceType() const override { return DeviceType::kAscend310; }

  /// Selects the device to use; presumably a zero-based device index (TODO confirm).
  void SetDeviceID(uint32_t device_id);

  /// Getter counterpart of SetDeviceID() (defined out of line).
  uint32_t GetDeviceID() const;

  /// Path of an "insert op" configuration file; its semantics are implementation-defined (TODO confirm).
  inline void SetInsertOpConfigPath(const std::string &cfg_path);

  /// Getter counterpart of SetInsertOpConfigPath().
  inline std::string GetInsertOpConfigPath() const;

  /// Input data format as a string; accepted values are not visible here (TODO confirm).
  inline void SetInputFormat(const std::string &format);

  /// Getter counterpart of SetInputFormat().
  inline std::string GetInputFormat() const;

  /// Input shape as a textual description; the expected syntax is not visible here (TODO confirm).
  inline void SetInputShape(const std::string &shape);

  /// Getter counterpart of SetInputShape().
  inline std::string GetInputShape() const;

  /// Input shapes keyed by an int, presumably the input index (TODO confirm).
  void SetInputShapeMap(const std::map<int, std::vector<int>> &shape);

  /// Getter counterpart of SetInputShapeMap() (defined out of line).
  std::map<int, std::vector<int>> GetInputShapeMap() const;

  /// Dynamic batch sizes are set as numbers, but GetDynamicBatchSize() reads them back
  /// as a single string. The string encoding is not visible here (TODO confirm).
  void SetDynamicBatchSize(const std::vector<size_t> &dynamic_batch_size);
  inline std::string GetDynamicBatchSize() const;

  /// Requested output data type (defined out of line).
  void SetOutputType(enum DataType output_type);

  /// Getter counterpart of SetOutputType().
  enum DataType GetOutputType() const;

  /// Precision mode as a string; accepted values are not visible here (TODO confirm).
  inline void SetPrecisionMode(const std::string &precision_mode);

  /// Getter counterpart of SetPrecisionMode().
  inline std::string GetPrecisionMode() const;

  /// Operator-implementation selection mode as a string; accepted values are not visible here (TODO confirm).
  inline void SetOpSelectImplMode(const std::string &op_select_impl_mode);

  /// Getter counterpart of SetOpSelectImplMode().
  inline std::string GetOpSelectImplMode() const;

  /// Path of a fusion-switch configuration file, plus its getter.
  inline void SetFusionSwitchConfigPath(const std::string &cfg_path);
  inline std::string GetFusionSwitchConfigPath() const;

  // Optional "l1_optimize", "l2_optimize", "off_optimize" or "l1_and_l2_optimize", default as "l2_optimize"
  inline void SetBufferOptimizeMode(const std::string &buffer_optimize_mode);
  inline std::string GetBufferOptimizeMode() const;

 private:
  void SetInsertOpConfigPath(const std::vector<char> &cfg_path);
  std::vector<char> GetInsertOpConfigPathChar() const;

  void SetInputFormat(const std::vector<char> &format);
  std::vector<char> GetInputFormatChar() const;

  void SetInputShape(const std::vector<char> &shape);
  std::vector<char> GetInputShapeChar() const;

  std::vector<char> GetDynamicBatchSizeChar() const;

  void SetPrecisionMode(const std::vector<char> &precision_mode);
  std::vector<char> GetPrecisionModeChar() const;

  void SetOpSelectImplMode(const std::vector<char> &op_select_impl_mode);
  std::vector<char> GetOpSelectImplModeChar() const;

  void SetFusionSwitchConfigPath(const std::vector<char> &cfg_path);
  std::vector<char> GetFusionSwitchConfigPathChar() const;

  void SetBufferOptimizeMode(const std::vector<char> &buffer_optimize_mode);
  std::vector<char> GetBufferOptimizeModeChar() const;
};

void Ascend310DeviceInfo::SetInsertOpConfigPath(const std::string &cfg_path) {
  SetInsertOpConfigPath(StringToChar(cfg_path));
}
std::string Ascend310DeviceInfo::GetInsertOpConfigPath() const { return CharToString(GetInsertOpConfigPathChar()); }

void Ascend310DeviceInfo::SetInputFormat(const std::string &format) { SetInputFormat(StringToChar(format)); }
std::string Ascend310DeviceInfo::GetInputFormat() const { return CharToString(GetInputFormatChar()); }

void Ascend310DeviceInfo::SetInputShape(const std::string &shape) { SetInputShape(StringToChar(shape)); }
std::string Ascend310DeviceInfo::GetInputShape() const { return CharToString(GetInputShapeChar()); }

std::string Ascend310DeviceInfo::GetDynamicBatchSize() const { return CharToString(GetDynamicBatchSizeChar()); }

void Ascend310DeviceInfo::SetPrecisionMode(const std::string &precision_mode) {
  SetPrecisionMode(StringToChar(precision_mode));
}
std::string Ascend310DeviceInfo::GetPrecisionMode() const { return CharToString(GetPrecisionModeChar()); }

void Ascend310DeviceInfo::SetOpSelectImplMode(const std::string &op_select_impl_mode) {
  SetOpSelectImplMode(StringToChar(op_select_impl_mode));
}
std::string Ascend310DeviceInfo::GetOpSelectImplMode() const { return CharToString(GetOpSelectImplModeChar()); }

void Ascend310DeviceInfo::SetFusionSwitchConfigPath(const std::string &cfg_path) {
  SetFusionSwitchConfigPath(StringToChar(cfg_path));
}
std::string Ascend310DeviceInfo::GetFusionSwitchConfigPath() const {
  return CharToString(GetFusionSwitchConfigPathChar());
}

void Ascend310DeviceInfo::SetBufferOptimizeMode(const std::string &buffer_optimize_mode) {
  SetBufferOptimizeMode(StringToChar(buffer_optimize_mode));
}
std::string Ascend310DeviceInfo::GetBufferOptimizeMode() const { return CharToString(GetBufferOptimizeModeChar()); }

/// Settings for running on a Hexagon DSP (DeviceType::kHexagonDSP).
/// It adds no options of its own beyond those inherited from DeviceInfoContext.
class HexagonDspDeviceInfo : public DeviceInfoContext {
 public:
  /// Always reports DeviceType::kHexagonDSP.
  enum DeviceType GetDeviceType() const override { return DeviceType::kHexagonDSP; }
};
}  // namespace mindspore
#endif  // MINDSPORE_INCLUDE_API_CONTEXT_H