mindspore.ops.Primitive

class mindspore.ops.Primitive(name)[源代码]

Primitive是Python中算子原语的基类。

参数:

  • name (str) - 当前Primitive的名称。

样例:

>>> from mindspore.ops.primitive import prim_attr_register, Primitive
>>> add = Primitive('add')
>>>
>>> # or work with prim_attr_register:
>>> # init a Primitive class with attr1 and attr2
>>> class Add(Primitive):
...     @prim_attr_register
...     def __init__(self, attr1, attr2):
...         '''init for add'''
...     # check attr1 and attr2 or do some initializations
...     # init a Primitive obj with attr1=1 and attr2=2
>>> add = Add(attr1=1, attr2=2)
add_prim_attr(name, value)[源代码]

添加Primitive的属性。

参数:

  • name (str) - 属性名称。

  • value (Any) - 属性值。

样例:

>>> import mindspore.ops as ops
>>> a = ops.Add()
>>> a = a.add_prim_attr("attr",1)
>>> out = a.attrs["attr"]
>>> print(out)
1
check_elim(*args)[源代码]

检查是否可以消除此Primitive。有需要的子类可以重写该方法。

参数:

  • args (Primitive参数的类型) - 与当前Primitive的参数相同。

返回:

由两个元素组成的元组。第一个元素是指是否能在编译阶段计算Primitive,第二个元素是计算结果。

样例:

>>> from mindspore.ops.primitive import prim_attr_register, Primitive
>>> from mindspore import Tensor
>>> import numpy as np
>>> class AddN(Primitive):
...     @prim_attr_register
...     def __init__(self):
...         self.init_prim_io_names(inputs=["inputs"], outputs=["sum"])
...     def check_elim(self, inputs):
...         if len(inputs) != 1:
...             return (False, None)
...         if isinstance(inputs[0], Tensor):
...             return (True, inputs[0])
...
>>> addn = AddN()
>>> input_x = Tensor(np.array([1, 2, 3]), mindspore.float32)
>>> output = addn.check_elim((input_x,))
>>> print(output)
(True, Tensor(shape=[3], dtype=Float32, value= [ 1.00000000e+00,  2.00000000e+00,  3.00000000e+00]))
del_prim_attr(name)[源代码]

删除Primitive的属性。

参数:

  • name (str) - 属性名称。

样例:

>>> import mindspore.ops as ops
>>> a = ops.Add()
>>> a = a.add_prim_attr("attr",1)
>>> a = a.del_prim_attr("attr")
>>> print(a.attrs)
{'input_names': ['x', 'y'], 'output_names' : ['output']}
init_prim_io_names(inputs, outputs)[源代码]

初始化Tensor或属性的输入输出的名称。

参数:

  • inputs (list[str]) - 输入名称的列表。

  • outputs (list[str]) - 输出名称的列表。

样例:

>>> import mindspore.ops as ops
>>> a = ops.Add()
>>> a.init_prim_io_names(["x","y"],["sum"])
>>> print(a.input_names)
['x','y']
>>> print(a.output_names)
['sum']
recompute(mode=True)[源代码]

设置Primitive的重计算属性。

如果有一个被设置了重计算属性的Primitive,并且其结果在计算导数的时候被使用,那么不会保存该Primitive在前向网络中的中间计算结果,而是在自动微分的时候重新进行计算。

Note

  • 如果计算涉及随机化或全局变量,则暂无法保证等效性。

  • 在PyNative模式下不支持。

参数:

  • mode (bool) - Primitive是否设置了重计算。默认值:True。

样例:

>>> import numpy as np
>>> import mindspore as ms
>>> from mindspore import Tensor, ops, nn
>>> class NetRecompute(nn.Cell):
...     def __init__(self):
...         super(NetRecompute,self).__init__()
...         self.relu = ops.ReLU().recompute()
...         self.sqrt = ops.Sqrt()
...     def construct(self, x):
...         out = self.relu(x)
...         return self.sqrt(out)
...
>>> class GradNet(nn.Cell):
...     def __init__(self, network):
...         super(GradNet,self).__init__()
...         self.network = network
...         self.grad = ops.GradOperation()
...     def construct(self, x):
...         g_out = self.grad(self.network)(x)
...         return g_out
...
>>> x = Tensor(np.array([-1,1]).astype(np.float32))
>>> net = NetRecompute()
>>> grad = GradNet(net)
>>> a = grad(x)
>>> print(a)
[0. 0.5]
set_prim_instance_name(instance_name)[源代码]

设置Primitive算子的实例的名称。

Note

当用户定义Primitive算子时,默认调用它。

参数:

  • instance_name (str) - 用户设置的Primitive算子的实例的名称。

样例:

>>> import mindspore.ops as ops
>>> a = ops.Add()
>>> a = a.set_prim_instance_name("add")
>>> print(a.instance_name)
add
set_stage(stage)[源代码]

将stage的ID添加到Primitive属性中。

Note

仅在半自动并行模式下有效。在其他并行模式下,请将其设置为0。

参数:

  • stage (int) - 当前stage的ID。

样例:

>>> from mindspore import ops
>>> add = ops.Add()
>>> print(add.set_stage(0))
Prim[Add]<stage=0>
shard(in_strategy=None, out_strategy=None)[源代码]

将切分策略添加到Primitive属性中。

Note

仅在半自动并行或自动并行模式下有效。在其他并行模式中,将忽略此处设置的策略。

参数:

  • in_strategy (tuple) - 描述算子输入的切分策略。默认值:None。

  • out_strategy (tuple) - 描述算子输出的切分策略,仅针对某些算子,如MatMul。默认值:None。

样例:

>>> from mindspore import ops
>>> add = ops.Add()
>>> print(add.shard(((1, 1), (1, 1))))
Prim[Add]<in_strategy=((1, 1), (1, 1)), out_strategy=None>
property update_parameter

判断此Primitive是否会更新参数的值。