mindspore_lite.ModelGroupFlag

查看源文件
class mindspore_lite.ModelGroupFlag[源代码]

ModelGroupFlag 类用于构造 ModelGroup 的标签。目前支持以下场景:

  1. ModelGroupFlag.SHARE_WEIGHT ,共享工作空间内存,ModelGroup 的默认构造标签。

  2. ModelGroupFlag.SHARE_WORKSPACE ,共享权重内存,多个模型共享权重(包括常量和变量)内存。

  3. ModelGroupFlag.SHARE_WEIGHT_WORKSPACE ,共享权重内存和工作空间内存。

样例:

>>> import mindspore_lite as mslite
>>> context = mslite.Context()
>>> context.target = ["Ascend"]
>>> context.ascend.device_id = 0
>>> context.ascend.rank_id = 0
>>> context.ascend.provider = "ge"
>>> model_group = mslite.ModelGroup(mslite.ModelGroupFlag.SHARE_WEIGHT)
>>> model0 = mslite.Model()
>>> model1 = mslite.Model()
>>> model_group.add_model([model0, model1])
>>> model0.build_from_file("seq_1024.mindir", mslite.ModelType.MINDIR, context, "config0.ini")
>>> model1.build_from_file("seq_1.mindir", mslite.ModelType.MINDIR, context, "config.ini")