mindspore.ops.PrimitiveWithInfer

class mindspore.ops.PrimitiveWithInfer(name)

PrimitiveWithInfer is the base class of primitives in Python and defines the functions used to track inference in Python.

Four methods can be overridden to define the inference logic of a primitive: __infer__(), infer_shape(), infer_dtype(), and infer_value(). If the primitive defines __infer__(), it has the highest priority and is the one called. If __infer__() is not defined, infer_shape() and infer_dtype() can be defined to describe how the output shape and type are inferred. infer_value() is used for constant propagation.
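The priority rule above can be illustrated with a toy, pure-Python model. This is a hypothetical sketch, not MindSpore's implementation: the class names ToyPrimitiveWithInfer, ToyAdd, and ToyAddWithCustomInfer, and the convention of passing each input as a dict with 'shape', 'dtype', and 'value' keys, are assumptions made for illustration only. The base __infer__() fans out to infer_shape(), infer_dtype(), and infer_value(), so a subclass that overrides __infer__() bypasses those three methods entirely.

```python
# Hypothetical toy model of the documented dispatch order (NOT MindSpore code).
# Assumption: each input is a dict with "shape", "dtype", and "value" keys.


class ToyPrimitiveWithInfer:
    def __infer__(self, *args):
        # Default path: infer each track from the matching field of every input.
        return {
            "shape": self.infer_shape(*(a["shape"] for a in args)),
            "dtype": self.infer_dtype(*(a["dtype"] for a in args)),
            "value": self.infer_value(*(a["value"] for a in args)),
        }

    def infer_shape(self, *shapes):
        raise NotImplementedError

    def infer_dtype(self, *dtypes):
        raise NotImplementedError

    def infer_value(self, *values):
        # Default: no constant propagation, so the output value is unknown.
        return None


class ToyAdd(ToyPrimitiveWithInfer):
    def infer_shape(self, x, y):
        return x  # output shape same as first input

    def infer_dtype(self, x, y):
        return x  # output type same as first input

    def infer_value(self, x, y):
        # Constant propagation: fold only when both input values are known.
        if x is not None and y is not None:
            return x + y
        return None


class ToyAddWithCustomInfer(ToyAdd):
    def __infer__(self, x, y):
        # Overriding __infer__ takes priority: the inherited infer_shape,
        # infer_dtype, and infer_value are never called on this path.
        return {"shape": x["shape"], "dtype": "float32", "value": None}


a = {"shape": (2, 3), "dtype": "float32", "value": None}
b = {"shape": (2, 3), "dtype": "float32", "value": None}
print(ToyAdd().__infer__(a, b))
# {'shape': (2, 3), 'dtype': 'float32', 'value': None}

c = {"shape": (), "dtype": "int32", "value": 2}
d = {"shape": (), "dtype": "int32", "value": 3}
print(ToyAdd().__infer__(c, d))
# {'shape': (), 'dtype': 'int32', 'value': 5}
print(ToyAddWithCustomInfer().__infer__(c, d))
# {'shape': (), 'dtype': 'float32', 'value': None}
```

In this toy model, putting the per-track fan-out inside the default __infer__() is what gives __infer__() its priority: overriding it simply replaces the whole inference step.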

Parameters

name (str) – Name of the current Primitive.

Supported Platforms:

Ascend GPU CPU

Examples

>>> from mindspore.ops import prim_attr_register, PrimitiveWithInfer
>>> # define a Primitive class with infer logic
>>> class Add(PrimitiveWithInfer):
...     @prim_attr_register
...     def __init__(self):
...         pass
...
...     def infer_shape(self, x, y):
...         return x # output shape same as first input 'x'
...
...     def infer_dtype(self, x, y):
...         return x # output type same as first input 'x'
...
>>> # create a Primitive instance
>>> add = Add()