mindspore.ops.PrimitiveWithInfer

class mindspore.ops.PrimitiveWithInfer(name)

PrimitiveWithInfer is the base class of primitives in Python and defines the Python functions used for inference, that is, for deriving the shape, data type, and value of a primitive's outputs.

Four methods can be overridden to define the inference logic of a primitive: __infer__(), infer_shape(), infer_dtype(), and infer_value(). If the primitive defines __infer__(), it has the highest priority and is called in preference to the others. If __infer__() is not defined, infer_shape() and infer_dtype() can be defined to describe how the output shape and data type are inferred. infer_value() is used for constant propagation.
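The dispatch order described above can be modeled in a few lines of plain Python. This is a simplified sketch, not MindSpore's implementation: the classes `ToyPrimitiveWithInfer`, `ToyAdd`, and `ToyCustom` are hypothetical, and the dictionary keys `shape`, `dtype`, and `value` used to describe each input are assumptions chosen for illustration. In the sketch, the default `__infer__` forwards to the three finer-grained methods, so a subclass that overrides `__infer__` bypasses them entirely.

```python
class ToyPrimitiveWithInfer:
    """Hypothetical stand-in for PrimitiveWithInfer (illustration only)."""

    def __infer__(self, *args):
        # Default behaviour in this sketch: split each input description
        # into its parts and delegate to the per-aspect hooks.
        shapes = [a["shape"] for a in args]
        dtypes = [a["dtype"] for a in args]
        values = [a["value"] for a in args]
        return {
            "shape": self.infer_shape(*shapes),
            "dtype": self.infer_dtype(*dtypes),
            "value": self.infer_value(*values),
        }

    def infer_shape(self, *args):
        raise NotImplementedError

    def infer_dtype(self, *args):
        raise NotImplementedError

    def infer_value(self, *args):
        # Value is treated as unknown at compile time unless a subclass says otherwise.
        return None


class ToyAdd(ToyPrimitiveWithInfer):
    # Only the per-aspect hooks are defined, so the default __infer__ uses them.
    def infer_shape(self, x, y):
        return x

    def infer_dtype(self, x, y):
        return x


class ToyCustom(ToyAdd):
    # Overriding __infer__ takes priority: infer_shape/infer_dtype are never called.
    def __infer__(self, x, y):
        return {"shape": (1,), "dtype": "int32", "value": None}


# Strings stand in for mindspore.dtype values in this sketch.
x = {"shape": (2, 3), "dtype": "float32", "value": None}
y = {"shape": (2, 3), "dtype": "float32", "value": None}
print(ToyAdd().__infer__(x, y))
print(ToyCustom().__infer__(x, y))
```

Keeping a single entry point that defaults to the per-aspect hooks lets simple primitives define only infer_shape() and infer_dtype(), while primitives whose shape depends on input values can take full control through __infer__().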

Parameters

name (str) – Name of the current Primitive.

Examples

>>> from mindspore.ops import prim_attr_register, PrimitiveWithInfer
>>> # init a Primitive class with infer
>>> class Add(PrimitiveWithInfer):
...     @prim_attr_register
...     def __init__(self):
...         pass
...
...     def infer_shape(self, x, y):
...         return x # output shape same as first input 'x'
...
...     def infer_dtype(self, x, y):
...         return x # output type same as first input 'x'
...
>>> # init a Primitive obj
>>> add = Add()

infer_dtype(*args)

Infer output dtype based on input dtype.

Parameters

args (mindspore.dtype) – data types of the inputs.

Returns

mindspore.dtype, data type of the outputs.
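As an illustration, an element-wise addition could require both inputs to share a data type and return that type. The sketch below is a hypothetical stand-alone function, `add_infer_dtype`, that uses strings such as `"float32"` in place of `mindspore.dtype` objects so that it runs without MindSpore installed; a real infer_dtype() would receive and return `mindspore.dtype` values, and the "reject mismatches" rule is an assumption rather than MindSpore's actual type-promotion behavior.

```python
def add_infer_dtype(x_dtype, y_dtype):
    """Hypothetical infer_dtype for an element-wise add (illustration only).

    Strings stand in for mindspore.dtype values.
    """
    if x_dtype != y_dtype:
        # Reject mismatched inputs instead of guessing a promotion rule.
        raise TypeError(f"dtype mismatch: {x_dtype} vs {y_dtype}")
    # The output type is the common input type.
    return x_dtype


print(add_infer_dtype("float32", "float32"))
```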

infer_shape(*args)

Infer output shape based on input shape.

Note

The shape of a scalar is an empty tuple.

Parameters

args (tuple(int)) – shapes of input tensors.

Returns

tuple(int), shapes of output tensors.
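For example, an element-wise operator that follows NumPy-style broadcasting could compute its output shape as below. This is a sketch under the assumption that the operator broadcasts; `broadcast_infer_shape` is a hypothetical helper, not a MindSpore API. A scalar input, whose shape is the empty tuple `()`, broadcasts against any shape.

```python
from itertools import zip_longest


def broadcast_infer_shape(x_shape, y_shape):
    """Hypothetical infer_shape for a broadcasting binary op (illustration only)."""
    out = []
    # Walk both shapes from the trailing dimension; missing dims count as 1.
    for a, b in zip_longest(reversed(x_shape), reversed(y_shape), fillvalue=1):
        if a == b or b == 1:
            out.append(a)
        elif a == 1:
            out.append(b)
        else:
            raise ValueError(f"shapes {x_shape} and {y_shape} are not broadcastable")
    return tuple(reversed(out))


print(broadcast_infer_shape((2, 3), (3,)))
print(broadcast_infer_shape((), (4, 5)))
```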

infer_value(*args)

Infer output value based on input value at compile time.

Parameters

args (Any) – values of the inputs.

Returns

Value of the outputs. If the value cannot be inferred at compile time, None is returned.
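For constant propagation, an infer_value() for element-wise addition could fold the result when every input value is known and return None otherwise. In the hypothetical sketch below, `add_infer_value` uses plain Python numbers as stand-ins for constant inputs (real inputs would typically be MindSpore tensors) and assumes that an input whose value is unknown at compile time arrives as None.

```python
def add_infer_value(x, y):
    """Hypothetical infer_value for an add (illustration only)."""
    if x is None or y is None:
        # At least one input is only known at run time: cannot fold.
        return None
    # Both inputs are compile-time constants: compute the result now.
    return x + y


print(add_infer_value(2, 3))
print(add_infer_value(2, None))
```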