mindspore_lite.GPUDeviceInfo

class mindspore_lite.GPUDeviceInfo(device_id=0, enable_fp16=False)

Helper class that describes GPU device hardware information. It inherits from the mindspore_lite.DeviceInfo base class.

Parameters:
  • device_id (int, optional) - Device ID. Default: 0.

  • enable_fp16 (bool, optional) - Whether to enable Float16 inference. Default: False.

Raises:
  • TypeError - device_id is not an int.

  • TypeError - enable_fp16 is not a bool.

  • ValueError - device_id is less than 0.
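The argument checks that produce the exceptions listed above can be sketched in plain Python as follows. This is a minimal illustration, not the library's code: `validate_gpu_device_args` is a hypothetical helper, and both the order of the checks and the explicit rejection of `bool` for `device_id` are assumptions.

```python
def validate_gpu_device_args(device_id=0, enable_fp16=False):
    """Hypothetical helper mirroring the documented GPUDeviceInfo argument rules."""
    # bool is a subclass of int in Python, so this sketch rejects it explicitly
    # (an assumption; the real class may behave differently).
    if not isinstance(device_id, int) or isinstance(device_id, bool):
        raise TypeError(f"device_id must be int, but got {type(device_id).__name__}")
    if not isinstance(enable_fp16, bool):
        raise TypeError(f"enable_fp16 must be bool, but got {type(enable_fp16).__name__}")
    if device_id < 0:
        raise ValueError(f"device_id must be >= 0, but got {device_id}")
    return device_id, enable_fp16


# Valid arguments pass through unchanged; each invalid case raises.
print(validate_gpu_device_args(device_id=1, enable_fp16=False))  # (1, False)
for bad in ({"device_id": "1"}, {"enable_fp16": 1}, {"device_id": -1}):
    try:
        validate_gpu_device_args(**bad)
    except (TypeError, ValueError) as err:
        print(type(err).__name__, err)
```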

Examples:

>>> # Use case: inference on a GPU device.
>>> # Precondition 1: build the MindSpore Lite GPU package with export MSLITE_GPU_BACKEND=tensorrt.
>>> # Precondition 2: install the MindSpore Lite wheel package built in precondition 1.
>>> import mindspore_lite as mslite
>>> gpu_device_info = mslite.GPUDeviceInfo(device_id=1, enable_fp16=False)
>>> print(gpu_device_info)
device_type: DeviceType.kGPU,
device_id: 1,
enable_fp16: False.
>>> cpu_device_info = mslite.CPUDeviceInfo(enable_fp16=False)
>>> context = mslite.Context()
>>> context.append_device_info(gpu_device_info)
>>> context.append_device_info(cpu_device_info)
>>> print(context)
thread_num: 0,
inter_op_parallel_num: 0,
thread_affinity_mode: 0,
thread_affinity_core_list: [],
enable_parallel: False,
device_list: 1, 0, .
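In the printed context, `device_list: 1, 0, .` lists the appended device types as integers in append order: the GPU entry first, then the CPU entry. The sketch below models only that ordering in plain Python. It assumes `kCPU = 0` and `kGPU = 1`, consistent with the printed output and the `DeviceType` enum of the MindSpore Lite C++ API; `ToyContext` is a hypothetical stand-in, not the real `mslite.Context`.

```python
from enum import IntEnum


class DeviceType(IntEnum):
    # Assumed values, consistent with the printed device_list (GPU -> 1, CPU -> 0).
    kCPU = 0
    kGPU = 1


class ToyContext:
    """Hypothetical stand-in that only records device types in append order."""

    def __init__(self):
        self.device_list = []

    def append_device_info(self, device_type):
        self.device_list.append(DeviceType(device_type))

    def format_device_list(self):
        # Reproduce the "1, 0, ." layout seen in print(context).
        return "".join(f"{int(d)}, " for d in self.device_list) + "."


ctx = ToyContext()
ctx.append_device_info(DeviceType.kGPU)  # appended first, like gpu_device_info
ctx.append_device_info(DeviceType.kCPU)  # appended second, like cpu_device_info
print("device_list:", ctx.format_device_list())  # device_list: 1, 0, .
```

Appending the CPU entry first would instead print `0, 1, .`, since the list simply preserves append order.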
get_group_size()

Get the size of the cluster (group size) from the context.

Returns:

int, the size of the cluster.

Examples:

>>> # Use case: inference on a GPU device.
>>> # Precondition 1: build the MindSpore Lite GPU package with export MSLITE_GPU_BACKEND=tensorrt.
>>> # Precondition 2: install the MindSpore Lite wheel package built in precondition 1.
>>> import mindspore_lite as mslite
>>> device_info = mslite.GPUDeviceInfo(device_id=1, enable_fp16=True)
>>> group_size = device_info.get_group_size()
>>> print(group_size)
1
get_rank_id()

Get the ID of the current device within the cluster (rank ID) from the context.

Returns:

int, the ID of the current device within the cluster. IDs always start from 0.

Examples:

>>> # Use case: inference on a GPU device.
>>> # Precondition 1: build the MindSpore Lite GPU package with export MSLITE_GPU_BACKEND=tensorrt.
>>> # Precondition 2: install the MindSpore Lite wheel package built in precondition 1.
>>> import mindspore_lite as mslite
>>> device_info = mslite.GPUDeviceInfo(device_id=1, enable_fp16=True)
>>> rank_id = device_info.get_rank_id()
>>> print(rank_id)
0
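Since rank IDs start from 0, and assuming the usual convention that a group of `group_size` devices uses one ID per device, a valid rank ID lies in the range [0, group_size). The plain-Python check below illustrates that relationship with the values printed by the get_group_size() and get_rank_id() examples (group size 1, rank ID 0). `is_valid_rank` is a hypothetical helper, not a mindspore_lite API. Note that the rank ID is independent of `device_id`: those examples pass `device_id=1` yet report rank 0.

```python
def is_valid_rank(rank_id: int, group_size: int) -> bool:
    """Hypothetical check: rank IDs start at 0, so a valid ID lies in [0, group_size)."""
    if group_size < 1:
        raise ValueError("group_size must be at least 1")
    return 0 <= rank_id < group_size


# Values printed by the get_group_size() and get_rank_id() examples.
group_size, rank_id = 1, 0
print(is_valid_rank(rank_id, group_size))  # True
print(list(range(group_size)))  # every rank ID in the group: [0]
```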