mindspore_gl.translate

查看源文件
mindspore_gl.translate(obj, method_name: str, translate_path: None or str = None)[源代码]

将顶点中心代码转换为MindSpore可理解代码。

翻译后,将在 /.mindspore_gl 中生成一个新函数。原方法将被此函数替换。

参数:
  • obj (Object) - 翻译对象。

  • method_name (str) - 要转换的方法的名称。

  • translate_path (str) - 构造文件的保存路径。

支持平台:

Ascend GPU

样例:

>>> import mindspore as ms
>>> import mindspore.ops as ops
>>> from mindspore_gl.nn import GNNCell
>>> from mindspore_gl import BatchedGraph
>>> from mindspore_gl.parser.vcg import translate
...
>>> class Net(GNNCell):
...     def __init__(self):
...         super().__init__()
...         translate(self, "loss")
...
...     def construct(self, pred, label, g: BatchedGraph):
...         loss = self.loss(pred, label, g)
...         loss = loss * g.graph_mask
...         return loss
...
...     def loss(self, pred, label, g: BatchedGraph):
...         criterion = ms.nn.loss.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='none')
...         loss = criterion(pred, label)
...         loss = ops.ReduceMean()(loss * g.graph_mask)
...         return loss