# Copyright 2022-2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Tensor API.
"""
from __future__ import absolute_import
from enum import Enum
import numpy
from mindspore_lite.lib import _c_lite_wrapper
__all__ = ['DataType', 'Format', 'Tensor']
class DataType(Enum):
    """
    The `DataType` class defines the data type of the Tensor in MindSpore Lite.

    Currently, the following 'DataType' are supported:

    ===========================  ==================================================================
    Definition                   Description
    ===========================  ==================================================================
    `DataType.UNKNOWN`           No matching any of the following known types.
    `DataType.BOOL`              Boolean `True` or `False` .
    `DataType.INT8`              8-bit integer.
    `DataType.INT16`             16-bit integer.
    `DataType.INT32`             32-bit integer.
    `DataType.INT64`             64-bit integer.
    `DataType.UINT8`             unsigned 8-bit integer.
    `DataType.UINT16`            unsigned 16-bit integer.
    `DataType.UINT32`            unsigned 32-bit integer.
    `DataType.UINT64`            unsigned 64-bit integer.
    `DataType.FLOAT16`           16-bit floating-point number.
    `DataType.FLOAT32`           32-bit floating-point number.
    `DataType.FLOAT64`           64-bit floating-point number.
    `DataType.INVALID`           The maximum threshold value of DataType to prevent invalid types.
    ===========================  ==================================================================

    Examples:
        >>> # Method 1: Import mindspore_lite package
        >>> import mindspore_lite as mslite
        >>> print(mslite.DataType.FLOAT32)
        DataType.FLOAT32
        >>> # Method 2: from mindspore_lite package import DataType
        >>> from mindspore_lite import DataType
        >>> print(DataType.FLOAT32)
        DataType.FLOAT32
    """

    # The integer values are fixed identifiers; do not renumber them.
    UNKNOWN = 0
    BOOL = 30
    INT8 = 32
    INT16 = 33
    INT32 = 34
    INT64 = 35
    UINT8 = 37
    UINT16 = 38
    UINT32 = 39
    UINT64 = 40
    FLOAT16 = 42
    FLOAT32 = 43
    FLOAT64 = 44
    INVALID = 2147483647  # INT32_MAX
# Lookup table from each Python-side `DataType` member to the attribute of
# `_c_lite_wrapper.DataType` that represents the same type in the C++ binding.
data_type_py_cxx_map = {
    py_type: getattr(_c_lite_wrapper.DataType, cxx_name)
    for py_type, cxx_name in (
        (DataType.UNKNOWN, "kTypeUnknown"),
        (DataType.BOOL, "kNumberTypeBool"),
        (DataType.INT8, "kNumberTypeInt8"),
        (DataType.INT16, "kNumberTypeInt16"),
        (DataType.INT32, "kNumberTypeInt32"),
        (DataType.INT64, "kNumberTypeInt64"),
        (DataType.UINT8, "kNumberTypeUInt8"),
        (DataType.UINT16, "kNumberTypeUInt16"),
        (DataType.UINT32, "kNumberTypeUInt32"),
        (DataType.UINT64, "kNumberTypeUInt64"),
        (DataType.FLOAT16, "kNumberTypeFloat16"),
        (DataType.FLOAT32, "kNumberTypeFloat32"),
        (DataType.FLOAT64, "kNumberTypeFloat64"),
        (DataType.INVALID, "kInvalidType"),
    )
}
# Reverse lookup table: binding-side data type -> Python-side `DataType` member.
# It is the exact inverse of `data_type_py_cxx_map` (which is one-to-one), so it
# is derived from that table instead of being spelled out a second time.
data_type_cxx_py_map = {
    cxx_type: py_type for py_type, cxx_type in data_type_py_cxx_map.items()
}
# Lookup table from each Python-side `Format` member to the attribute of
# `_c_lite_wrapper.Format` that represents the same layout in the C++ binding.
# Only `Format.DEFAULT` has a differently named counterpart (`DEFAULT_FORMAT`).
format_py_cxx_map = {
    py_format: getattr(_c_lite_wrapper.Format, cxx_name)
    for py_format, cxx_name in (
        (Format.DEFAULT, "DEFAULT_FORMAT"),
        (Format.NCHW, "NCHW"),
        (Format.NHWC, "NHWC"),
        (Format.NHWC4, "NHWC4"),
        (Format.HWKC, "HWKC"),
        (Format.HWCK, "HWCK"),
        (Format.KCHW, "KCHW"),
        (Format.CKHW, "CKHW"),
        (Format.KHWC, "KHWC"),
        (Format.CHWK, "CHWK"),
        (Format.HW, "HW"),
        (Format.HW4, "HW4"),
        (Format.NC, "NC"),
        (Format.NC4, "NC4"),
        (Format.NC4HW4, "NC4HW4"),
        (Format.NCDHW, "NCDHW"),
        (Format.NWC, "NWC"),
        (Format.NCW, "NCW"),
        (Format.NDHWC, "NDHWC"),
        (Format.NC8HW8, "NC8HW8"),
    )
}
# Reverse lookup table: binding-side format -> Python-side `Format` member.
# It is the exact inverse of `format_py_cxx_map` (which is one-to-one), so it is
# derived from that table instead of being spelled out a second time.
format_cxx_py_map = {
    cxx_format: py_format for py_format, cxx_format in format_py_cxx_map.items()
}
class Tensor:
    """
    The `Tensor` class defines a Tensor in MindSpore Lite.

    Args:
        tensor (_c_lite_wrapper.TensorBind, optional): The underlying binding object to wrap, i.e. the
            `_tensor` of another Tensor. When it is None, a new empty tensor is created through
            `_c_lite_wrapper.create_tensor()`. Default: None.

    Raises:
        TypeError: `tensor` is neither a `_c_lite_wrapper.TensorBind` nor None.

    Examples:
        >>> import mindspore_lite as mslite
        >>> tensor = mslite.Tensor()
        >>> tensor.name = "tensor1"
        >>> print(tensor.name)
        tensor1
        >>> tensor.dtype = mslite.DataType.FLOAT32
        >>> print(tensor.dtype)
        DataType.FLOAT32
        >>> tensor.shape = [1, 3, 2, 2]
        >>> print(tensor.shape)
        [1, 3, 2, 2]
        >>> tensor.format = mslite.Format.NCHW
        >>> print(tensor.format)
        Format.NCHW
        >>> print(tensor.element_num)
        12
        >>> print(tensor.data_size)
        48
        >>> print(tensor)
        name: tensor1,
        dtype: DataType.FLOAT32,
        shape: [1, 3, 2, 2],
        format: Format.NCHW,
        element_num: 12,
        data_size: 48.
    """

    def __init__(self, tensor=None):
        if tensor is not None:
            if not isinstance(tensor, _c_lite_wrapper.TensorBind):
                raise TypeError(f"tensor must be MindSpore Lite's Tensor._tensor, but got {type(tensor)}.")
            self._tensor = tensor
        else:
            self._tensor = _c_lite_wrapper.create_tensor()

    def __str__(self):
        res = f"name: {self.name},\n" \
              f"dtype: {self.dtype},\n" \
              f"shape: {self.shape},\n" \
              f"format: {self.format},\n" \
              f"element_num: {self.element_num},\n" \
              f"data_size: {self.data_size}."
        return res

    @property
    def data_size(self):
        """
        Get the data size of the Tensor.

        Data size of the Tensor = the element num of the Tensor * size of unit data type of the Tensor.

        Returns:
            int, the data size of the Tensor data.
        """
        return self._tensor.get_data_size()

    @property
    def dtype(self):
        """
        Get the data type of the Tensor.

        Returns:
            DataType, the data type of the Tensor. None is returned when the underlying type has no
            entry in `data_type_cxx_py_map`.
        """
        return data_type_cxx_py_map.get(self._tensor.get_data_type())

    @dtype.setter
    def dtype(self, dtype):
        """
        Set data type for the Tensor.

        Args:
            dtype (DataType): The data type of the Tensor. For details, see
                `DataType <https://mindspore.cn/lite/api/en/r2.0/mindspore_lite/mindspore_lite.DataType.html>`_ .

        Raises:
            TypeError: `dtype` is not a DataType.
        """
        if not isinstance(dtype, DataType):
            raise TypeError(f"dtype must be DataType, but got {type(dtype)}.")
        self._tensor.set_data_type(data_type_py_cxx_map.get(dtype))

    @property
    def element_num(self):
        """
        Get the element num of the Tensor.

        Returns:
            int, the element num of the Tensor data.
        """
        return self._tensor.get_element_num()

    @property
    def format(self):
        """
        Get the format of the Tensor.

        Returns:
            Format, the format of the Tensor. None is returned when the underlying format has no
            entry in `format_cxx_py_map`.
        """
        return format_cxx_py_map.get(self._tensor.get_format())

    @format.setter
    def format(self, tensor_format):
        """
        Set format of the Tensor.

        Args:
            tensor_format (Format): The format of the Tensor. For details, see
                `Format <https://mindspore.cn/lite/api/en/r2.0/mindspore_lite/mindspore_lite.Format.html>`_ .

        Raises:
            TypeError: `tensor_format` is not a Format.
        """
        if not isinstance(tensor_format, Format):
            raise TypeError(f"format must be Format, but got {type(tensor_format)}.")
        self._tensor.set_format(format_py_cxx_map.get(tensor_format))

    @property
    def name(self):
        """
        Get the name of the Tensor.

        Returns:
            str, the name of the Tensor.
        """
        return self._tensor.get_tensor_name()

    @name.setter
    def name(self, name):
        """
        Set the name of the Tensor.

        Args:
            name (str): The name of the Tensor.

        Raises:
            TypeError: `name` is not a str.
        """
        if not isinstance(name, str):
            raise TypeError(f"name must be str, but got {type(name)}.")
        self._tensor.set_tensor_name(name)

    @property
    def shape(self):
        """
        Get the shape of the Tensor.

        Returns:
            list[int], the shape of the Tensor.
        """
        return self._tensor.get_shape()

    @shape.setter
    def shape(self, shape):
        """
        Set shape for the Tensor.

        Args:
            shape (list[int]): The shape of the Tensor.

        Raises:
            TypeError: `shape` is not a list.
            TypeError: `shape` is a list, but the elements is not int.
        """
        if not isinstance(shape, list):
            raise TypeError(f"shape must be list, but got {type(shape)}.")
        # Validate every dimension before touching the underlying tensor so a
        # bad entry never leaves the tensor half-configured.
        for i, element in enumerate(shape):
            if not isinstance(element, int):
                raise TypeError(f"shape element must be int, but got {type(element)} at index {i}.")
        self._tensor.set_shape(shape)

    def get_data_to_numpy(self):
        """
        Get the data from the Tensor to the numpy object.

        Returns:
            numpy.ndarray, the numpy object from Tensor data.

        Examples:
            >>> import mindspore_lite as mslite
            >>> import numpy as np
            >>> tensor = mslite.Tensor()
            >>> tensor.shape = [1, 3, 2, 2]
            >>> tensor.dtype = mslite.DataType.FLOAT32
            >>> in_data = np.arange(1 * 3 * 2 * 2, dtype=np.float32)
            >>> tensor.set_data_from_numpy(in_data)
            >>> data = tensor.get_data_to_numpy()
            >>> print(data)
            [[[[ 0.  1.]
               [ 2.  3.]]
              [[ 4.  5.]
               [ 6.  7.]]
              [[ 8.  9.]
               [ 10. 11.]]]]
        """
        return self._tensor.get_data_to_numpy()

    def set_data_from_numpy(self, numpy_obj):
        """
        Set the data for the Tensor from the numpy object.

        The array is converted to a C-contiguous (row-major) array first when it is not one already,
        so strided views and Fortran-ordered arrays are copied before being handed to the Tensor.

        Args:
            numpy_obj(numpy.ndarray): the numpy object.

        Raises:
            TypeError: `numpy_obj` is not a numpy.ndarray.
            RuntimeError: The data type of `numpy_obj` is not equivalent to the data type of the Tensor.
            RuntimeError: The data size of `numpy_obj` is not equal to the data size of the Tensor.

        Examples:
            >>> # 1. set Tensor data which is from file
            >>> import mindspore_lite as mslite
            >>> import numpy as np
            >>> tensor = mslite.Tensor()
            >>> tensor.shape = [1, 3, 224, 224]
            >>> tensor.dtype = mslite.DataType.FLOAT32
            >>> in_data = np.fromfile("input.bin", dtype=np.float32)
            >>> tensor.set_data_from_numpy(in_data)
            >>> print(tensor)
            name: ,
            dtype: DataType.FLOAT32,
            shape: [1, 3, 224, 224],
            format: Format.NCHW,
            element_num: 150528,
            data_size: 602112.
            >>> # 2. set Tensor data which is numpy arange
            >>> import mindspore_lite as mslite
            >>> import numpy as np
            >>> tensor = mslite.Tensor()
            >>> tensor.shape = [1, 3, 2, 2]
            >>> tensor.dtype = mslite.DataType.FLOAT32
            >>> in_data = np.arange(1 * 3 * 2 * 2, dtype=np.float32)
            >>> tensor.set_data_from_numpy(in_data)
            >>> print(tensor)
            name: ,
            dtype: DataType.FLOAT32,
            shape: [1, 3, 2, 2],
            format: Format.NCHW,
            element_num: 12,
            data_size: 48.
        """
        if not isinstance(numpy_obj, numpy.ndarray):
            raise TypeError(f"numpy_obj must be numpy.ndarray, but got {type(numpy_obj)}.")
        # 'FORC' is also true for Fortran-ordered arrays, which would let a
        # column-major buffer through unconverted. Require C order explicitly.
        # NOTE(review): assumes the binding consumes the buffer in row-major
        # order -- confirm against the C++ side.
        if not numpy_obj.flags['C_CONTIGUOUS']:
            numpy_obj = numpy.ascontiguousarray(numpy_obj)
        # numpy scalar type -> DataType accepted by this Tensor.
        data_type_map = {
            numpy.bool_: DataType.BOOL,
            numpy.int8: DataType.INT8,
            numpy.int16: DataType.INT16,
            numpy.int32: DataType.INT32,
            numpy.int64: DataType.INT64,
            numpy.uint8: DataType.UINT8,
            numpy.uint16: DataType.UINT16,
            numpy.uint32: DataType.UINT32,
            numpy.uint64: DataType.UINT64,
            numpy.float16: DataType.FLOAT16,
            numpy.float32: DataType.FLOAT32,
            numpy.float64: DataType.FLOAT64,
        }
        # An unsupported numpy dtype maps to None and is reported as a mismatch.
        if data_type_map.get(numpy_obj.dtype.type) != self.dtype:
            raise RuntimeError(
                f"data type not equal! Numpy type: {numpy_obj.dtype.type}, Tensor type: {self.dtype}")
        if numpy_obj.nbytes != self.data_size:
            raise RuntimeError(
                f"data size not equal! Numpy size: {numpy_obj.nbytes}, Tensor size: {self.data_size}")
        self._tensor.set_data_from_numpy(numpy_obj)