mindspore.ops.silent_check.ASDBase

查看源文件
class mindspore.ops.silent_check.ASDBase(cls, *args, **kwargs)[源代码]

ASDBase 是 Python 中具有精度敏感检测特性的算子的基类。

参数:
  • cls (Primitive) - 需要增加精度敏感检测特性的原始算子。

  • args (tuple) - 传递给原始运算符的可变参数元组。

  • kwargs (dict) - 传递给原始运算符的可变参数字典。

支持平台:

Ascend

样例:

>>> from mindspore.ops.silent_check import ASDBase
>>> from mindspore.ops import LayerNorm as OriginLayerNorm
>>> class LayerNormASD(ASDBase):
...     def __init__(self, *args, **kwargs):
...         super().__init__(OriginLayerNorm, *args, **kwargs)
...         # init parameters for accuracy-sensitive detection by calling the base class method generate_params()
...         self.pre_val, self.min_val, self.max_val, self.cnt = self.generate_params()
...
...     def __call__(self, input_x, gamma, beta):
...         if self.enable_check:
...             # execute accuracy-sensitive detection by calling the check_op of base class
...             input_x = self.check_op(
...                 input_x, self.pre_val, self.min_val, self.max_val, self.cnt, None)
...             self.cnt += 1
...         # return the result of original operator
...         return self.op(input_x, gamma, beta)
generate_params()[源代码]

生成支持精度敏感检测的参数。

返回:

包含4个元素的元组。 派生类通过调用此函数初始化精度敏感检测所需的参数。

样例:

>>> from mindspore.ops.silent_check import ASDBase
>>> from mindspore.ops import LayerNorm as OriginLayerNorm
>>> class LayerNormASD(ASDBase):
...     def __init__(self, *args, **kwargs):
...         super().__init__(OriginLayerNorm, *args, **kwargs)
...         # init parameters for accuracy-sensitive detection by calling the base class function
...         self.pre_val, self.min_val, self.max_val, self.cnt = self.generate_params()