mindspore_gs.ptq.RoundToNearest

查看源文件
class mindspore_gs.ptq.RoundToNearest(config=None)[源代码]

后量化算法的基本实现,通过统计最大最小值实现模型量化。

参数:
异常:
  • TypeError - config 在输入不为 None 时,元素类型不为 PTQConfig。

样例:

>>> import mindspore_gs
>>> from mindspore_gs import ptq
>>> from mindspore_gs.ptq import RoundToNearest as rtn
>>> from mindspore_gs.ptq import PTQConfig
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
>>> ptq = rtn()
>>> network = LeNet5()
>>> fake_quant_net = ptq.apply(net_work)
>>> quant_net = ptq.convert(fake_quant_net)
apply(network: Cell)[源代码]

network 中添加伪量化节点,转换成一个伪量化网络。

参数:
  • network (Cell) - 待伪量化的网络。

返回:

伪量化后的网络。

convert(net_opt: Cell, ckpt_path='')[源代码]

将量化网络 net_opt 转换为真实量化网络,后续导出用于部署。

参数:
  • net_opt (Cell) - 经过量化算法apply之后的网络。

  • ckpt_path (str) - 网络的checkpoint file文件路径,默认值为 "",表示不加载。注意,该参数会在后续版本中被遗弃。

异常:
  • TypeError - net_opt 数据类型不是Cell。

  • TypeError - ckpt_path 数据类型不是str。

  • ValueError - ckpt_path 非空但不是有效路径。

返回:

转换后的网络。