mindspore.nn.GraphCell
- class mindspore.nn.GraphCell(graph, params_init=None)[源代码]
运行从MindIR加载的计算图。
此功能仍在开发中。目前 GraphCell 不支持修改图结构,在导出MindIR时只能使用shape和类型与输入相同的数据。
参数:
graph (object) - 从MindIR加载的编译图。
params_init (dict) - 需要在图中初始化的参数。key为参数名称,类型为字符串,value为 Tensor 或 Parameter。如果参数名在图中已经存在,则更新其值;如果不存在,则忽略。默认值:None。
异常:
TypeError – 如果图不是FuncGraph类型。
TypeError – 如果 params_init 不是字典。
TypeError – 如果 params_init 的key不是字符串。
TypeError – 如果 params_init 的value既不是 Tensor也不是Parameter。
- 支持平台:
Ascend
GPU
CPU
样例:
>>> import numpy as np >>> import mindspore.nn as nn >>> from mindspore import Tensor, export, load >>> >>> net = nn.Conv2d(1, 1, kernel_size=3, weight_init="ones") >>> input = Tensor(np.ones([1, 1, 3, 3]).astype(np.float32)) >>> export(net, input, file_name="net", file_format="MINDIR") >>> graph = load("net.mindir") >>> net = nn.GraphCell(graph) >>> output = net(input) >>> print(output) [[[[4. 6. 4.] [6. 9. 6.] [4. 6. 4.]]]]