mindspore.amp.get_black_list

查看源文件
mindspore.amp.get_black_list()[源代码]

提供用于自动混合精度 amp_levelO2 等级的内置黑名单的拷贝。

当前的内置黑名单内容为:

[mindspore.nn.BatchNorm1d, mindspore.nn.BatchNorm2d, mindspore.nn.BatchNorm3d, mindspore.nn.LayerNorm]

返回:

list:内置黑名单的拷贝。

样例:

>>> from mindspore import amp
>>> black_list = amp.get_black_list()
>>> print(black_list)
[<class 'mindspore.nn.layer.normalization.BatchNorm1d'>, <class 'mindspore.nn.layer.normalization.BatchNorm2d'>,
 <class 'mindspore.nn.layer.normalization.BatchNorm3d'>, <class 'mindspore.nn.layer.normalization.LayerNorm'>]