mindspore.ops.silent_check.ASDBase
- class mindspore.ops.silent_check.ASDBase(cls, *args, **kwargs)[源代码]
ASDBase 是 Python 中具有精度敏感检测特性的算子的基类。
- 参数:
cls (Primitive) - 需要增加精度敏感检测特性的原始算子。
args (tuple) - 传递给原始运算符的可变参数元组。
kwargs (dict) - 传递给原始运算符的可变参数字典。
- 支持平台:
Ascend
样例:
>>> from mindspore.ops.silent_check import ASDBase >>> from mindspore.ops import LayerNorm as OriginLayerNorm >>> class LayerNormASD(ASDBase): ... def __init__(self, *args, **kwargs): ... super().__init__(OriginLayerNorm, *args, **kwargs) ... # init parameters for feature value detection by calling the base class method generate_params() ... self.pre_val, self.min_val, self.max_val, self.cnt = self.generate_params() ... ... def __call__(self, input_x, gamma, beta): ... if self.enable_check: ... # execute feature value detection by calling the check_op of base class ... input_x = self.check_op( ... input_x, self.pre_val, self.min_val, self.max_val, self.cnt, None) ... self.cnt += 1 ... # return the result of original operator ... return self.op(input_x, gamma, beta)
- generate_params()[源代码]
生成支持精度敏感检测的参数。
- 返回:
包含4个元素的元组。 派生类通过调用此函数初始化精度敏感检测所需的参数。
样例:
>>> from mindspore.ops.silent_check import ASDBase >>> from mindspore.ops import LayerNorm as OriginLayerNorm >>> class LayerNormASD(ASDBase): ... def __init__(self, *args, **kwargs): ... super().__init__(OriginLayerNorm, *args, **kwargs) ... # init parameters for feature value detection by calling the base class function ... self.pre_val, self.min_val, self.max_val, self.cnt = self.generate_params()