文档反馈

问题文档片段

问题文档片段包含公式时,显示为空格。

提交类型
issue

有点复杂...

找人问问吧。

PR

小问题,全程线上修改...

一键搞定!

请选择提交类型

问题类型
规范和低错类

- 规范和低错类:

- 错别字或拼写错误,标点符号使用错误、公式错误或显示异常。

- 链接错误、空单元格、格式错误。

- 英文中包含中文字符。

- 界面和描述不一致,但不影响操作。

- 表述不通顺,但不影响理解。

- 版本号不匹配:如软件包名称、界面版本号。

易用性

- 易用性:

- 关键步骤错误或缺失,无法指导用户完成任务。

- 缺少主要功能描述、关键词解释、必要前提条件、注意事项等。

- 描述内容存在歧义指代不明、上下文矛盾。

- 逻辑不清晰,该分类、分项、分步骤的没有给出。

正确性

- 正确性:

- 技术原理、功能、支持平台、参数类型、异常报错等描述和软件实现不一致。

- 原理图、架构图等存在错误。

- 命令、命令参数等错误。

- 代码片段错误。

- 命令无法完成对应功能。

- 界面错误,无法指导操作。

- 代码样例运行报错、运行结果不符。

风险提示

- 风险提示:

- 对重要数据或系统存在风险的操作,缺少安全提示。

内容合规

- 内容合规:

- 违反法律法规,涉及政治、领土主权等敏感词。

- 内容侵权。

请选择问题类型

问题描述

点击输入详细问题描述,以帮助我们快速定位问题。

优化器的编译优化

下载Notebook下载样例代码查看源文件

概述

mindspore.ops.composite中提供了一些涉及图变换的组合类算子,例如MultitypeFuncGraphHyperMap等。

MultitypeFuncGraph

MultitypeFuncGraph用于生成重载函数,支持不同类型的输入。用户可以使用MultitypeFuncGraph定义一组重载的函数,根据不同类型,采用不同实现。首先初始化一个MultitypeFuncGraph 对象,使用带有输入类型的 register 作为待注册函数的装饰器,使得该对象支持多种类型的输入。更多使用方法见:MultitypeFuncGraph

代码样例如下:

[1]:
import numpy as np
from mindspore.ops import MultitypeFuncGraph
import mindspore as ms
import mindspore.ops as ops

add = MultitypeFuncGraph('add')
@add.register("Number", "Number")
def add_scalar(x, y):
    return ops.scalar_add(x, y)

@add.register("Tensor", "Tensor")
def add_tensor(x, y):
    return ops.tensor_add(x, y)

tensor1 = ms.Tensor(np.array([[1.2, 2.1], [2.2, 3.2]]).astype('float32'))
tensor2 = ms.Tensor(np.array([[1.2, 2.1], [2.2, 3.2]]).astype('float32'))
print('tensor', add(tensor1, tensor2))
print('scalar', add(1, 2))
tensor [[2.4 4.2]
 [4.4 6.4]]
scalar 3

HyperMap

HyperMap可以对一组或多组输入做指定的运算,可以配合MultitypeFuncGraph一起使用。例如定义一组重载的add函数后,对多组不同类型的输入进行add运算。不同于MapHyperMap 能够用于嵌套结构,对序列或嵌套序列中的输入做指定运算。更多使用方法见:HyperMap

代码样例如下:

[2]:
import mindspore as ms
from mindspore.ops import MultitypeFuncGraph, HyperMap
import mindspore.ops as ops

add = MultitypeFuncGraph('add')
@add.register("Number", "Number")
def add_scalar(x, y):
    return ops.scalar_add(x, y)

@add.register("Tensor", "Tensor")
def add_tensor(x, y):
    return ops.tensor_add(x, y)

add_map = HyperMap(add)
output = add_map((ms.Tensor(1, ms.float32), ms.Tensor(2, ms.float32), 1), (ms.Tensor(3, ms.float32), ms.Tensor(4, ms.float32), 2))
print("output =", output)
output = (Tensor(shape=[], dtype=Float32, value= 4), Tensor(shape=[], dtype=Float32, value= 6), 3)

此例子中传入add_map的输入包含了两个序列,HyperMap会以operation(args[0][i], args[1][i])的形式分别从两个序列中取相应的元素作为add函数的输入xy,例如add(Tensor(1, mstype.float32), Tensor(3, mstype.float32))