mindspore.ops.PrimitiveWithInfer
- class mindspore.ops.PrimitiveWithInfer(name)
PrimitiveWithInfer is the base class of primitives in Python and defines functions for tracking inference in Python.
There are four methods that can be overridden to define the inference logic of a primitive: __infer__(), infer_shape(), infer_dtype(), and infer_value(). If __infer__() is defined in the primitive, it has the highest priority and is the method that gets called. If __infer__() is not defined, infer_shape() and infer_dtype() can be defined to describe the inference logic for the output shape and data type. infer_value() is used for constant propagation.
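The priority rule above can be illustrated without MindSpore installed. The sketch below is a hypothetical, framework-free toy model: ToyPrimitiveWithInfer, ToyAdd, ToyIdentity, the abstract() helper, and the dict layout with "shape", "dtype", and "value" keys are all assumptions made for illustration and do not reproduce MindSpore's internal implementation. It shows a subclass that relies on the per-track methods (including infer_value() folding constants) and a subclass whose own __infer__() takes precedence.

```python
# Hypothetical, framework-free model of the dispatch rule described above.
# ToyPrimitiveWithInfer is NOT MindSpore's class; it only mimics the priority:
# a subclass-defined __infer__ wins, otherwise infer_shape, infer_dtype and
# infer_value are called one track at a time.


class ToyPrimitiveWithInfer:
    def infer_shape(self, *shapes):
        # Default: no shape information (subclasses override this).
        return None

    def infer_dtype(self, *dtypes):
        # Default: no dtype information (subclasses override this).
        return None

    def infer_value(self, *values):
        # Default: the result is not a compile-time constant.
        return None

    def __infer__(self, *args):
        # Each arg is assumed to be a dict with "shape", "dtype", "value".
        return {
            "shape": self.infer_shape(*(a["shape"] for a in args)),
            "dtype": self.infer_dtype(*(a["dtype"] for a in args)),
            "value": self.infer_value(*(a["value"] for a in args)),
        }


class ToyAdd(ToyPrimitiveWithInfer):
    def infer_shape(self, x, y):
        return x  # output shape same as first input

    def infer_dtype(self, x, y):
        return x  # output dtype same as first input

    def infer_value(self, x, y):
        # Constant propagation: fold only when both inputs are known.
        if x is None or y is None:
            return None
        return x + y


class ToyIdentity(ToyPrimitiveWithInfer):
    def __infer__(self, x):
        # Overriding __infer__ takes priority, so infer_shape is never used.
        return dict(x)

    def infer_shape(self, x):
        raise AssertionError("not called when __infer__ is overridden")


def abstract(shape, dtype, value=None):
    # Hypothetical helper bundling the three tracked properties of an input.
    return {"shape": shape, "dtype": dtype, "value": value}


if __name__ == "__main__":
    add = ToyAdd()
    print(add.__infer__(abstract((2, 3), "float32", 1.0),
                        abstract((2, 3), "float32", 2.0)))
    print(ToyIdentity().__infer__(abstract((4,), "int32")))
```

Because the toy base class's __infer__() delegates to the per-track methods, a subclass only needs to override __infer__() when shape, type, and value must be computed together.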
- Parameters
name (str) – Name of the current Primitive.
- Supported Platforms:
Ascend
GPU
CPU
Examples
>>> from mindspore.ops import PrimitiveWithInfer, prim_attr_register
>>> # init a Primitive class with infer
>>> class Add(PrimitiveWithInfer):
...     @prim_attr_register
...     def __init__(self):
...         pass
...
...     def infer_shape(self, x, y):
...         return x  # output shape same as first input 'x'
...
...     def infer_dtype(self, x, y):
...         return x  # output type same as first input 'x'
...
>>> # init a Primitive obj
>>> add = Add()
- infer_dtype(*args)
Infer output dtype based on input dtype.
- Parameters
args (mindspore.dtype) – Data types of the inputs.
- Returns
mindspore.dtype, data types of the outputs.
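As a sketch of a stricter infer_dtype() than the one in the Add example, the toy class below receives one data type per input and rejects mismatched inputs instead of silently returning the first. It assumes data types are represented as plain strings such as "float32" rather than mindspore.dtype objects, so it runs without MindSpore; StrictToyAdd is a hypothetical name, not a MindSpore operator.

```python
# Hypothetical sketch: dtypes are modelled as plain strings ("float32", ...)
# rather than mindspore.dtype objects, so this runs without MindSpore.


class StrictToyAdd:
    """Toy primitive whose infer_dtype requires both inputs to share a dtype."""

    def infer_dtype(self, x_dtype, y_dtype):
        # One positional argument per input, in order; the return value is
        # the data type of the single output.
        if x_dtype != y_dtype:
            raise TypeError(
                f"StrictToyAdd expects matching dtypes, got {x_dtype} and {y_dtype}"
            )
        return x_dtype


if __name__ == "__main__":
    add = StrictToyAdd()
    print(add.infer_dtype("float32", "float32"))  # float32
    try:
        add.infer_dtype("float32", "int32")
    except TypeError as err:
        print("rejected:", err)
```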