Program Listing for File context.h
↰ Return to documentation for file (include/context.h)
#ifndef MINDSPORE_INCLUDE_API_CONTEXT_H
#define MINDSPORE_INCLUDE_API_CONTEXT_H
#include <string>
#include <memory>
#include <vector>
#include <map>
#include "include/api/types.h"
#include "include/api/dual_abi_helper.h"
namespace mindspore {
// Tag returned by DeviceInfoContext::GetDeviceType() to tell the concrete
// device-info subclasses apart.
//
// Every enumerator carries an explicit value so the numbering is visible at a
// glance; note that 5 is not assigned to any enumerator in this listing.
enum DeviceType {
  kCPU = 0,
  kGPU = 1,
  kKirinNPU = 2,
  kAscend910 = 3,
  kAscend310 = 4,
  kHexagonDSP = 6,
  // add new type here
  kInvalidDeviceType = 100,
};
// Forward declarations. In the visible part of this header Allocator and
// Delegate are only referenced through std::shared_ptr, so their full
// definitions are not required here. DeviceInfoContext is defined further
// down, but Context::MutableDeviceInfo() needs the name before that point.
class Allocator;
class Delegate;
class DeviceInfoContext;
// Configuration object exposing threading, parallelism, delegate and device
// settings through setter/getter pairs. None of the member functions except
// the destructor is defined in this header; all state sits behind the opaque
// `Data` struct.
class Context {
public:
// Opaque state holder: declared here, no definition visible in this header.
struct Data;
Context();
~Context() = default;
// Thread-count accessors.
void SetThreadNum(int32_t thread_num);
int32_t GetThreadNum() const;
// Affinity selected by an integer `mode`. The meaning of individual mode
// values is not visible in this header.
void SetThreadAffinity(int mode);
int GetThreadAffinityMode() const;
// Affinity selected by an explicit core list (overload of the setter above).
// NOTE(review): the setter takes std::vector<int> while the getter returns
// std::vector<int32_t>; the spelling is inconsistent even where both name the
// same type.
void SetThreadAffinity(const std::vector<int> &core_list);
std::vector<int32_t> GetThreadAffinityCoreList() const;
// Parallel-execution flag accessors.
void SetEnableParallel(bool is_parallel);
bool GetEnableParallel() const;
// Delegate accessors; the delegate is passed and returned as a shared_ptr.
void SetDelegate(const std::shared_ptr<Delegate> &delegate);
std::shared_ptr<Delegate> GetDelegate() const;
// Returns a non-const reference to the list of device descriptions, so
// callers modify the list in place (there is no separate setter).
std::vector<std::shared_ptr<DeviceInfoContext>> &MutableDeviceInfo();
private:
// Held through a shared_ptr and no copy operations are declared, so an
// implicitly copied Context shares the same Data object with its source.
std::shared_ptr<Data> data_;
};
// Abstract base class of the per-device description objects.
//
// Concrete subclasses identify themselves through GetDeviceType(). The class
// derives from std::enable_shared_from_this so that Cast() can hand out a
// shared_ptr that shares ownership with the one already managing the object.
class DeviceInfoContext : public std::enable_shared_from_this<DeviceInfoContext> {
 public:
  // Opaque state holder: declared here, no definition visible in this header.
  struct Data;

  DeviceInfoContext();
  virtual ~DeviceInfoContext() = default;

  // Device tag of the concrete subclass.
  virtual enum DeviceType GetDeviceType() const = 0;

  // Downcasts this object to the concrete device-info type T.
  //
  // The check compares this object's device tag with the tag reported by a
  // default-constructed T, so T must be default-constructible. Returns nullptr
  // when the tags differ. On a match the result comes from shared_from_this(),
  // which requires that this object is already owned by a std::shared_ptr.
  template <typename T>
  std::shared_ptr<T> Cast() {
    static_assert(std::is_base_of<DeviceInfoContext, T>::value, "Wrong cast type.");
    if (T().GetDeviceType() == GetDeviceType()) {
      return std::static_pointer_cast<T>(shared_from_this());
    }
    return nullptr;
  }

  // std::string accessors for the provider name. They are defined inline after
  // the class and forward to the protected std::vector<char> overloads.
  inline std::string GetProvider() const;
  inline void SetProvider(const std::string &provider);

  // std::string accessors for the provider device, same forwarding scheme.
  inline std::string GetProviderDevice() const;
  inline void SetProviderDevice(const std::string &device);

  // Allocator accessors; the allocator is passed and returned as a shared_ptr.
  void SetAllocator(const std::shared_ptr<Allocator> &allocator);
  std::shared_ptr<Allocator> GetAllocator() const;

 protected:
  // std::vector<char> counterparts of the string accessors above.
  // NOTE(review): presumably these exist so that std::string does not cross
  // the library boundary (see include/api/dual_abi_helper.h) - confirm.
  std::vector<char> GetProviderChar() const;
  void SetProvider(const std::vector<char> &provider);
  std::vector<char> GetProviderDeviceChar() const;
  void SetProviderDevice(const std::vector<char> &device);

  std::shared_ptr<Data> data_;
};
// Inline std::string accessors of DeviceInfoContext. Each one only converts
// between std::string and std::vector<char> (CharToString / StringToChar) and
// forwards to the protected std::vector<char> member declared in the class.
inline std::string DeviceInfoContext::GetProvider() const {
  const auto provider_chars = GetProviderChar();
  return CharToString(provider_chars);
}

inline void DeviceInfoContext::SetProvider(const std::string &provider) {
  const auto provider_chars = StringToChar(provider);
  SetProvider(provider_chars);
}

inline std::string DeviceInfoContext::GetProviderDevice() const {
  const auto device_chars = GetProviderDeviceChar();
  return CharToString(device_chars);
}

inline void DeviceInfoContext::SetProviderDevice(const std::string &device) {
  const auto device_chars = StringToChar(device);
  SetProviderDevice(device_chars);
}
// Device description whose GetDeviceType() reports DeviceType::kCPU.
class CPUDeviceInfo : public DeviceInfoContext {
 public:
  enum DeviceType GetDeviceType() const override {
    return DeviceType::kCPU;
  }

  // FP16 flag accessors; only declared in this header.
  // NOTE(review): presumably toggles half-precision execution - confirm
  // against the implementation.
  void SetEnableFP16(bool is_fp16);
  bool GetEnableFP16() const;
};
// Device description whose GetDeviceType() reports DeviceType::kKirinNPU.
class KirinNPUDeviceInfo : public DeviceInfoContext {
 public:
  enum DeviceType GetDeviceType() const override {
    return DeviceType::kKirinNPU;
  }

  // Frequency accessors; only declared in this header. The unit and the set of
  // accepted values of `frequency` are not visible here.
  void SetFrequency(int frequency);
  int GetFrequency() const;
};
// Device description whose GetDeviceType() reports DeviceType::kGPU.
class GPUDeviceInfo : public DeviceInfoContext {
 public:
  enum DeviceType GetDeviceType() const override {
    return DeviceType::kGPU;
  }

  // Device-id accessors; only declared in this header.
  void SetDeviceID(uint32_t device_id);
  uint32_t GetDeviceID() const;

  // std::string accessors for the precision mode. They are defined inline
  // after the class and forward to the private std::vector<char> overloads.
  // The accepted mode strings are not visible in this header.
  inline void SetPrecisionMode(const std::string &precision_mode);
  inline std::string GetPrecisionMode() const;

  // FP16 flag accessors; only declared in this header.
  void SetEnableFP16(bool is_fp16);
  bool GetEnableFP16() const;

 private:
  // std::vector<char> counterparts of the precision-mode accessors.
  void SetPrecisionMode(const std::vector<char> &precision_mode);
  std::vector<char> GetPrecisionModeChar() const;
};
// Inline std::string accessors of GPUDeviceInfo. Both only convert between
// std::string and std::vector<char> and forward to the private overloads.
inline void GPUDeviceInfo::SetPrecisionMode(const std::string &precision_mode) {
  const auto mode_chars = StringToChar(precision_mode);
  SetPrecisionMode(mode_chars);
}

inline std::string GPUDeviceInfo::GetPrecisionMode() const {
  const auto mode_chars = GetPrecisionModeChar();
  return CharToString(mode_chars);
}
// Device description whose GetDeviceType() reports DeviceType::kAscend910.
class Ascend910DeviceInfo : public DeviceInfoContext {
 public:
  enum DeviceType GetDeviceType() const override {
    return DeviceType::kAscend910;
  }

  // Device-id accessors; only declared in this header.
  void SetDeviceID(uint32_t device_id);
  uint32_t GetDeviceID() const;
};
// Device description whose GetDeviceType() reports DeviceType::kAscend310.
//
// Most string-valued options come as a public std::string accessor paired with
// a private std::vector<char> overload. The public versions are defined inline
// after the class and only convert between the two representations.
class Ascend310DeviceInfo : public DeviceInfoContext {
 public:
  enum DeviceType GetDeviceType() const override {
    return DeviceType::kAscend310;
  }

  // Device-id accessors; only declared in this header.
  void SetDeviceID(uint32_t device_id);
  uint32_t GetDeviceID() const;

  // Path of the insert-op configuration file.
  inline void SetInsertOpConfigPath(const std::string &cfg_path);
  inline std::string GetInsertOpConfigPath() const;

  // Input format, as a string. Accepted values are not visible in this header.
  inline void SetInputFormat(const std::string &format);
  inline std::string GetInputFormat() const;

  // Input shape, as a string. The expected syntax is not visible in this header.
  inline void SetInputShape(const std::string &shape);
  inline std::string GetInputShape() const;

  // Input shapes as a map from int to a list of ints.
  // NOTE(review): the key is presumably the input index - confirm.
  void SetInputShapeMap(const std::map<int, std::vector<int>> &shape);
  std::map<int, std::vector<int>> GetInputShapeMap() const;

  // The setter takes a list of sizes while the getter reports a string; how
  // the sizes are encoded into that string is not visible in this header.
  void SetDynamicBatchSize(const std::vector<size_t> &dynamic_batch_size);
  inline std::string GetDynamicBatchSize() const;

  // Output data type accessors.
  void SetOutputType(enum DataType output_type);
  enum DataType GetOutputType() const;

  // Precision mode, as a string. Accepted values are not visible in this header.
  inline void SetPrecisionMode(const std::string &precision_mode);
  inline std::string GetPrecisionMode() const;

  // Operator-selection implementation mode, as a string.
  inline void SetOpSelectImplMode(const std::string &op_select_impl_mode);
  inline std::string GetOpSelectImplMode() const;

  // Path of the fusion-switch configuration file.
  inline void SetFusionSwitchConfigPath(const std::string &cfg_path);
  inline std::string GetFusionSwitchConfigPath() const;

  // Accepted values: "l1_optimize", "l2_optimize", "off_optimize" or
  // "l1_and_l2_optimize"; the default is "l2_optimize".
  inline void SetBufferOptimizeMode(const std::string &buffer_optimize_mode);
  inline std::string GetBufferOptimizeMode() const;

 private:
  // std::vector<char> counterparts of the public std::string accessors above.
  void SetInsertOpConfigPath(const std::vector<char> &cfg_path);
  std::vector<char> GetInsertOpConfigPathChar() const;

  void SetInputFormat(const std::vector<char> &format);
  std::vector<char> GetInputFormatChar() const;

  void SetInputShape(const std::vector<char> &shape);
  std::vector<char> GetInputShapeChar() const;

  std::vector<char> GetDynamicBatchSizeChar() const;

  void SetPrecisionMode(const std::vector<char> &precision_mode);
  std::vector<char> GetPrecisionModeChar() const;

  void SetOpSelectImplMode(const std::vector<char> &op_select_impl_mode);
  std::vector<char> GetOpSelectImplModeChar() const;

  void SetFusionSwitchConfigPath(const std::vector<char> &cfg_path);
  std::vector<char> GetFusionSwitchConfigPathChar() const;

  void SetBufferOptimizeMode(const std::vector<char> &buffer_optimize_mode);
  std::vector<char> GetBufferOptimizeModeChar() const;
};
void Ascend310DeviceInfo::SetInsertOpConfigPath(const std::string &cfg_path) {
SetInsertOpConfigPath(StringToChar(cfg_path));
}
std::string Ascend310DeviceInfo::GetInsertOpConfigPath() const { return CharToString(GetInsertOpConfigPathChar()); }
void Ascend310DeviceInfo::SetInputFormat(const std::string &format) { SetInputFormat(StringToChar(format)); }
std::string Ascend310DeviceInfo::GetInputFormat() const { return CharToString(GetInputFormatChar()); }
void Ascend310DeviceInfo::SetInputShape(const std::string &shape) { SetInputShape(StringToChar(shape)); }
std::string Ascend310DeviceInfo::GetInputShape() const { return CharToString(GetInputShapeChar()); }
std::string Ascend310DeviceInfo::GetDynamicBatchSize() const { return CharToString(GetDynamicBatchSizeChar()); }
void Ascend310DeviceInfo::SetPrecisionMode(const std::string &precision_mode) {
SetPrecisionMode(StringToChar(precision_mode));
}
std::string Ascend310DeviceInfo::GetPrecisionMode() const { return CharToString(GetPrecisionModeChar()); }
void Ascend310DeviceInfo::SetOpSelectImplMode(const std::string &op_select_impl_mode) {
SetOpSelectImplMode(StringToChar(op_select_impl_mode));
}
std::string Ascend310DeviceInfo::GetOpSelectImplMode() const { return CharToString(GetOpSelectImplModeChar()); }
void Ascend310DeviceInfo::SetFusionSwitchConfigPath(const std::string &cfg_path) {
SetFusionSwitchConfigPath(StringToChar(cfg_path));
}
std::string Ascend310DeviceInfo::GetFusionSwitchConfigPath() const {
return CharToString(GetFusionSwitchConfigPathChar());
}
void Ascend310DeviceInfo::SetBufferOptimizeMode(const std::string &buffer_optimize_mode) {
SetBufferOptimizeMode(StringToChar(buffer_optimize_mode));
}
std::string Ascend310DeviceInfo::GetBufferOptimizeMode() const { return CharToString(GetBufferOptimizeModeChar()); }
// Device description whose GetDeviceType() reports DeviceType::kHexagonDSP.
// It adds no options of its own beyond those inherited from DeviceInfoContext.
class HexagonDspDeviceInfo : public DeviceInfoContext {
 public:
  enum DeviceType GetDeviceType() const override {
    return DeviceType::kHexagonDSP;
  }
};
} // namespace mindspore
#endif // MINDSPORE_INCLUDE_API_CONTEXT_H