{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Calling Custom Classes\n", "\n", "[![Run online](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.0.0-alpha/resource/_static/logo_modelarts.png)](https://authoring-modelarts-cnnorth4.huaweicloud.com/console/lab?share-url-b64=aHR0cHM6Ly9vYnMuZHVhbHN0YWNrLmNuLW5vcnRoLTQubXlodWF3ZWljbG91ZC5jb20vbWluZHNwb3JlLXdlYnNpdGUvbm90ZWJvb2svcjIuMC4wLWFscGhhL3R1dG9yaWFscy9leHBlcnRzL3poX2NuL25ldHdvcmsvbWluZHNwb3JlX2ppdF9jbGFzcy5pcHluYg==&imageid=77ef960a-bd26-4de4-9695-5b85a786fb90) [![Download Notebook](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.0.0-alpha/resource/_static/logo_notebook.png)](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/r2.0.0-alpha/tutorials/experts/zh_cn/network/mindspore_jit_class.ipynb) [![Download sample code](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.0.0-alpha/resource/_static/logo_download_code.png)](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/r2.0.0-alpha/tutorials/experts/zh_cn/network/mindspore_jit_class.py) [![View source](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.0.0-alpha/resource/_static/logo_source.png)](https://gitee.com/mindspore/docs/blob/r2.0.0-alpha/tutorials/experts/source_zh_cn/network/jit_class.ipynb)\n", "\n", "## Overview\n", "\n", "In static graph mode, decorating a custom class with `jit_class` lets you create and call instances of that class and access its attributes and methods.\n", "\n", "`jit_class` applies to static graph mode and extends the range of syntax supported by static graph compilation. In dynamic graph mode (PyNative mode), using `jit_class` does not affect the execution logic.\n", "\n", "This document describes how to use the `jit_class` decorator and the restrictions to be aware of, so that you can use it more effectively.\n", "\n", "## Decorating a Custom Class with jit_class\n", "\n", "After a custom class is decorated with `@jit_class`, you can create and call instances of the class and access its attributes and methods." ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "ExecuteTime": { "end_time": "2022-01-04T11:36:31.406170Z", "start_time": "2022-01-04T11:36:29.874130Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[1 2 3]\n" ] } ], "source": [ "import numpy as np\n", "import 
mindspore.nn as nn\n", "import mindspore as ms\n", "\n", "@ms.jit_class\n", "class InnerNet:\n", "    value = ms.Tensor(np.array([1, 2, 3]))\n", "\n", "class Net(nn.Cell):\n", "    def construct(self):\n", "        return InnerNet().value\n", "\n", "ms.set_context(mode=ms.GRAPH_MODE)\n", "net = Net()\n", "out = net()\n", "print(out)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`jit_class` supports nesting custom classes within each other and nesting custom classes with `nn.Cell`. Note that with class inheritance, if the parent class is decorated with `jit_class`, the subclass also has the capabilities of `jit_class`." ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "ExecuteTime": { "end_time": "2022-01-04T11:36:31.517439Z", "start_time": "2022-01-04T11:36:31.499697Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[1 2 3]\n" ] } ], "source": [ "import numpy as np\n", "import mindspore.nn as nn\n", "import mindspore as ms\n", "\n", "@ms.jit_class\n", "class Inner:\n", "    def __init__(self):\n", "        self.value = ms.Tensor(np.array([1, 2, 3]))\n", "\n", "@ms.jit_class\n", "class InnerNet:\n", "    def __init__(self):\n", "        self.inner = Inner()\n", "\n", "class Net(nn.Cell):\n", "    def __init__(self):\n", "        super(Net, self).__init__()\n", "        self.inner_net = InnerNet()\n", "\n", "    def construct(self):\n", "        out = self.inner_net.inner.value\n", "        return out\n", "\n", "ms.set_context(mode=ms.GRAPH_MODE)\n", "net = Net()\n", "out = net()\n", "print(out)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`jit_class` can only decorate custom classes; it does not support `nn.Cell` or non-class types. Running the following example raises an error." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "```python\n", "import mindspore.nn as nn\n", "import mindspore as ms\n", "\n", "@ms.jit_class\n", "class Net(nn.Cell):\n", "    def construct(self, x):\n", "        return x\n", "\n", "ms.set_context(mode=ms.GRAPH_MODE)\n", "x = ms.Tensor(1)\n", "net = Net()\n", "net(x)\n", "\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The error message is as follows:\n", "\n", "```text\n", "TypeError: Decorator jit_class is used for user-defined classes and cannot be used for nn.Cell: 
Net<>.\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "```python\n", "import mindspore as ms\n", "\n", "@ms.jit_class\n", "def func(x, y):\n", "    return x + y\n", "\n", "func(1, 2)\n", "\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The error message is as follows:\n", "\n", "```text\n", "TypeError: Decorator jit_class can only be used for class type, but got .\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Accessing the Attributes and Methods of a Custom Class\n", "\n", "Class attributes can be accessed through the class name, but class methods cannot be called through the class name. For an instance of the class, both its attributes and its methods can be accessed." ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "ExecuteTime": { "end_time": "2022-01-04T11:36:31.406170Z", "start_time": "2022-01-04T11:36:29.874130Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "12\n" ] } ], "source": [ "import mindspore.nn as nn\n", "import mindspore as ms\n", "\n", "@ms.jit_class\n", "class InnerNet:\n", "    def __init__(self, val):\n", "        self.number = val\n", "\n", "    def act(self, x, y):\n", "        return self.number * (x + y)\n", "\n", "class Net(nn.Cell):\n", "    def __init__(self):\n", "        super(Net, self).__init__()\n", "        self.inner_net = InnerNet(2)\n", "\n", "    def construct(self, x, y):\n", "        return self.inner_net.number + self.inner_net.act(x, y)\n", "\n", "ms.set_context(mode=ms.GRAPH_MODE)\n", "x = ms.Tensor(2, dtype=ms.int32)\n", "y = ms.Tensor(3, dtype=ms.int32)\n", "net = Net()\n", "out = net(x, y)\n", "print(out)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Private attributes and magic methods cannot be called, and any method that is called must fall within the syntax supported by static graph compilation. Running the following example raises an error." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "```python\n", "import numpy as np\n", "import mindspore.nn as nn\n", "import mindspore as ms\n", "\n", "@ms.jit_class\n", "class InnerNet:\n", "    def __init__(self):\n", "        self.value = ms.Tensor(np.array([1, 2, 3]))\n", "\n", "class Net(nn.Cell):\n", "    def __init__(self):\n", "        super(Net, self).__init__()\n", "        self.inner_net = InnerNet()\n", "\n", "    def construct(self):\n", "        out = 
self.inner_net.__str__()\n", "        return out\n", "\n", "ms.set_context(mode=ms.GRAPH_MODE)\n", "net = Net()\n", "out = net()\n", "\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The error message is as follows:\n", "\n", "RuntimeError: `__str__` is a private variable or magic method, which is not supported." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Creating an Instance of a Custom Class\n", "\n", "In static graph mode, the arguments passed when creating an instance of a custom class must be constants." ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "ExecuteTime": { "end_time": "2022-01-04T11:36:31.406170Z", "start_time": "2022-01-04T11:36:29.874130Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "5\n" ] } ], "source": [ "import numpy as np\n", "import mindspore.nn as nn\n", "import mindspore as ms\n", "\n", "@ms.jit_class\n", "class InnerNet:\n", "    def __init__(self, val):\n", "        self.number = val + 3\n", "\n", "class Net(nn.Cell):\n", "    def construct(self):\n", "        net = InnerNet(2)\n", "        return net.number\n", "\n", "ms.set_context(mode=ms.GRAPH_MODE)\n", "net = Net()\n", "out = net()\n", "print(out)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In other scenarios, there is no requirement that the arguments be constants when creating an instance of a custom class. In the following example, the instance is created in `__init__` of the cell with a Tensor argument:" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "ExecuteTime": { "end_time": "2022-01-04T11:36:31.406170Z", "start_time": "2022-01-04T11:36:29.874130Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "5\n" ] } ], "source": [ "import numpy as np\n", "import mindspore.nn as nn\n", "import mindspore as ms\n", "\n", "@ms.jit_class\n", "class InnerNet:\n", "    def __init__(self, val):\n", "        self.number = val + 3\n", "\n", "class Net(nn.Cell):\n", "    def __init__(self, val):\n", "        super(Net, self).__init__()\n", "        self.inner = InnerNet(val)\n", "\n", "    def construct(self):\n", "        return self.inner.number\n", "\n", "ms.set_context(mode=ms.GRAPH_MODE)\n", "x = ms.Tensor(2, dtype=ms.int32)\n", "net = Net(x)\n", "out = net()\n", "print(out)" ] }, { "cell_type": "markdown", "metadata": 
{}, "source": [ "## Calling an Instance of a Custom Class\n", "\n", "Calling an instance of a custom class invokes the `__call__` method of that class." ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "ExecuteTime": { "end_time": "2022-01-04T11:36:31.517439Z", "start_time": "2022-01-04T11:36:31.499697Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "10\n" ] } ], "source": [ "import numpy as np\n", "import mindspore.nn as nn\n", "import mindspore as ms\n", "\n", "@ms.jit_class\n", "class InnerNet:\n", "    def __init__(self, number):\n", "        self.number = number\n", "\n", "    def __call__(self, x, y):\n", "        return self.number * (x + y)\n", "\n", "class Net(nn.Cell):\n", "    def construct(self, x, y):\n", "        net = InnerNet(2)\n", "        out = net(x, y)\n", "        return out\n", "\n", "ms.set_context(mode=ms.GRAPH_MODE)\n", "x = ms.Tensor(2, dtype=ms.int32)\n", "y = ms.Tensor(3, dtype=ms.int32)\n", "net = Net()\n", "out = net(x, y)\n", "print(out)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If the class does not define a `__call__` method, an error is reported. Running the following example raises an error." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "```python\n", "import numpy as np\n", "import mindspore.nn as nn\n", "import mindspore as ms\n", "\n", "@ms.jit_class\n", "class InnerNet:\n", "    def __init__(self, number):\n", "        self.number = number\n", "\n", "class Net(nn.Cell):\n", "    def construct(self, x, y):\n", "        net = InnerNet(2)\n", "        out = net(x, y)\n", "        return out\n", "\n", "ms.set_context(mode=ms.GRAPH_MODE)\n", "x = ms.Tensor(2, dtype=ms.int32)\n", "y = ms.Tensor(3, dtype=ms.int32)\n", "net = Net()\n", "out = net(x, y)\n", "print(out)\n", "\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The error message is as follows:\n", "\n", "RuntimeError: MsClassObject: 'InnerNet' has no `__call__` function, please check the code.\n" ] } ], "metadata": { "kernelspec": { "display_name": "MindSpore", "language": "python", "name": "mindspore" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": 
"text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.5" } }, "nbformat": 4, "nbformat_minor": 4 }