Calling the Custom Class
Overview
In static graph mode, after you decorate a custom class with ms_class, you can create and call instances of that class and access its attributes and methods.
ms_class applies to static graph mode and extends the range of Python syntax that static graph compilation supports. In dynamic graph mode, that is, PyNative mode, ms_class does not affect the execution logic.
This document describes how to use ms_class effectively.
Decorating a Custom Class with ms_class
After a custom class is decorated with @ms_class, you can create and call its instances and access its attributes and methods.
import numpy as np
import mindspore.nn as nn
import mindspore as ms
from mindspore import ms_class

@ms_class
class InnerNet:
    # Class attribute holding a Tensor
    value = ms.Tensor(np.array([1, 2, 3]))

class Net(nn.Cell):
    def construct(self):
        # Create an InnerNet instance and read its attribute
        return InnerNet().value

ms.set_context(mode=ms.GRAPH_MODE)
net = Net()
out = net()
print(out)
ms_class supports nesting custom classes inside other custom classes, as well as nesting custom classes inside nn.Cell. Note that ms_class is inherited: if a parent class is decorated with ms_class, its subclasses also gain the ms_class capability.
import numpy as np
import mindspore.nn as nn
import mindspore as ms
from mindspore import ms_class

@ms_class
class Inner:
    def __init__(self):
        self.value = ms.Tensor(np.array([1, 2, 3]))

# A custom class nested inside another custom class
@ms_class
class InnerNet:
    def __init__(self):
        self.inner = Inner()

# A custom class nested inside an nn.Cell
class Net(nn.Cell):
    def __init__(self):
        super(Net, self).__init__()
        self.inner_net = InnerNet()

    def construct(self):
        out = self.inner_net.inner.value
        return out

ms.set_context(mode=ms.GRAPH_MODE)
net = Net()
out = net()
print(out)
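The inheritance rule stated earlier can be illustrated without MindSpore. The sketch below assumes, for illustration only, that a class decorator marks a class by setting a class attribute; the decorator `mark_graph_class` and the attribute `_graph_class_marker` are hypothetical names, not MindSpore's API. Because Python resolves class attributes along the method resolution order, a subclass of a marked class reports the mark even though it was never decorated itself.

```python
# Hypothetical stand-in for a marking decorator; this is NOT MindSpore's ms_class.
def mark_graph_class(cls):
    """Mark `cls` by setting a class attribute (illustrative assumption)."""
    cls._graph_class_marker = True  # hypothetical attribute name
    return cls


def is_marked(cls):
    # getattr follows the MRO, so a mark set on a parent class is found too.
    return getattr(cls, "_graph_class_marker", False)


@mark_graph_class
class Parent:
    pass


class Child(Parent):  # not decorated directly
    pass


class Unrelated:
    pass


print(is_marked(Parent))     # True
print(is_marked(Child))      # True, inherited from Parent
print(is_marked(Unrelated))  # False
```

Note that the mark lives only on `Parent`; `Child` sees it through normal attribute lookup rather than through a copy.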
ms_class can only decorate custom classes; it cannot decorate nn.Cell subclasses or non-class objects. Running either of the following examples raises an error.
import mindspore.nn as nn
import mindspore as ms
from mindspore import ms_class

# Error: ms_class cannot decorate an nn.Cell subclass
@ms_class
class Net(nn.Cell):
    def construct(self, x):
        return x

ms.set_context(mode=ms.GRAPH_MODE)
x = ms.Tensor(1)
net = Net()
net(x)
The error information is as follows:
TypeError: ms_class is used for user-defined classes and cannot be used for nn.Cell: Net<>.
from mindspore import ms_class

# Error: ms_class cannot decorate a function
@ms_class
def func(x, y):
    return x + y

func(1, 2)
The error information is as follows:
TypeError: Decorator ms_class can only be used for class type, but got <function func at 0x7fee33c005f0>.
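Both errors come from checks on the object being decorated. The plain-Python sketch below reproduces those two checks under stated assumptions: `Cell` is a hypothetical stand-in for mindspore.nn.Cell and `checked_decorator` is a hypothetical decorator. The message texts are copied from the errors above, but this is an illustration, not MindSpore's source code.

```python
import inspect


class Cell:
    """Hypothetical stand-in for mindspore.nn.Cell."""


def checked_decorator(obj):
    # Reject anything that is not a class, such as a plain function.
    if not inspect.isclass(obj):
        raise TypeError(f"Decorator ms_class can only be used for class type, but got {obj}.")
    # Reject subclasses of the Cell stand-in.
    if issubclass(obj, Cell):
        raise TypeError(f"ms_class is used for user-defined classes and cannot be used for nn.Cell: {obj}.")
    return obj


def try_decorate(obj):
    """Return None if decoration succeeds, otherwise the TypeError message."""
    try:
        checked_decorator(obj)
        return None
    except TypeError as err:
        return str(err)


class Net(Cell):
    pass


class Plain:
    pass


def func(x, y):
    return x + y


print(try_decorate(Plain))  # None
print(try_decorate(Net))    # message about nn.Cell
print(try_decorate(func))   # message about class type
```

The order of the checks matters: `issubclass` raises its own TypeError when given a non-class, so the class-type check runs first.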
Obtaining the Attributes and Methods of the Custom Class
Class attributes can be accessed through the class name, but methods cannot be called through the class name. Through an instance of the class, both attributes and methods can be accessed.
import mindspore.nn as nn
import mindspore as ms
from mindspore import ms_class

@ms_class
class InnerNet:
    def __init__(self, val):
        self.number = val

    def act(self, x, y):
        return self.number * (x + y)

class Net(nn.Cell):
    def __init__(self):
        super(Net, self).__init__()
        self.inner_net = InnerNet(2)

    def construct(self, x, y):
        # Read an instance attribute and call an instance method
        return self.inner_net.number + self.inner_net.act(x, y)

ms.set_context(mode=ms.GRAPH_MODE)
x = ms.Tensor(2, dtype=ms.int32)
y = ms.Tensor(3, dtype=ms.int32)
net = Net()
out = net(x, y)
print(out)
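Because ms_class does not change PyNative execution, the arithmetic of the example above can be checked as ordinary Python. The sketch below drops MindSpore and uses Python ints in place of Tensors (an assumption made so it runs anywhere); it also adds a hypothetical class attribute `scale` to show the access pattern that is allowed through the class name. With number = 2, x = 2 and y = 3, the result is 2 + 2 × (2 + 3) = 12, which should match the value printed by the MindSpore example.

```python
class InnerNet:
    scale = 10  # hypothetical class attribute, readable as InnerNet.scale

    def __init__(self, val):
        self.number = val  # instance attribute

    def act(self, x, y):
        return self.number * (x + y)


inner_net = InnerNet(2)
x, y = 2, 3

# Allowed in static graph mode: instance attribute plus instance method.
out = inner_net.number + inner_net.act(x, y)
print(out)  # 12

# Allowed in static graph mode: class attribute through the class name.
print(InnerNet.scale)  # 10

# Plain Python also accepts InnerNet.act(inner_net, x, y), but calling a
# method through the class name is not supported in static graph mode.
```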
Accessing private attributes and calling magic methods is not supported, and any method that is called must use only syntax supported by static graph compilation. Running the following example raises an error.
import numpy as np
import mindspore.nn as nn
import mindspore as ms
from mindspore import ms_class

@ms_class
class InnerNet:
    def __init__(self):
        self.value = ms.Tensor(np.array([1, 2, 3]))

class Net(nn.Cell):
    def __init__(self):
        super(Net, self).__init__()
        self.inner_net = InnerNet()

    def construct(self):
        # Error: magic methods such as __str__ cannot be called
        out = self.inner_net.__str__()
        return out

ms.set_context(mode=ms.GRAPH_MODE)
net = Net()
out = net()
The error information is as follows:
RuntimeError: __str__ is a private variable or magic method, which is not supported.
Creating an Instance of the Custom Class
In static graph mode, when you create an instance of a custom class inside construct or an ms_function, its arguments must be constants.
import mindspore.nn as nn
import mindspore as ms
from mindspore import ms_class

@ms_class
class InnerNet:
    def __init__(self, val):
        self.number = val + 3

class Net(nn.Cell):
    def construct(self):
        # Instance created inside construct: the argument 2 is a constant
        net = InnerNet(2)
        return net.number

ms.set_context(mode=ms.GRAPH_MODE)
net = Net()
out = net()
print(out)
In other scenarios, such as creating the instance in the __init__ method of an nn.Cell, the arguments do not need to be constants. For example:
import mindspore.nn as nn
import mindspore as ms
from mindspore import ms_class

@ms_class
class InnerNet:
    def __init__(self, val):
        self.number = val + 3

class Net(nn.Cell):
    def __init__(self, val):
        super(Net, self).__init__()
        # Instance created in __init__: the argument may be a Tensor
        self.inner = InnerNet(val)

    def construct(self):
        return self.inner.number

ms.set_context(mode=ms.GRAPH_MODE)
x = ms.Tensor(2, dtype=ms.int32)
net = Net(x)
out = net()
print(out)
Calling the Instance of the Custom Class
Calling an instance of a custom class invokes the __call__ method of that class.
import mindspore.nn as nn
import mindspore as ms
from mindspore import ms_class

@ms_class
class InnerNet:
    def __init__(self, number):
        self.number = number

    def __call__(self, x, y):
        return self.number * (x + y)

class Net(nn.Cell):
    def construct(self, x, y):
        net = InnerNet(2)
        # Calling the instance invokes InnerNet.__call__
        out = net(x, y)
        return out

ms.set_context(mode=ms.GRAPH_MODE)
x = ms.Tensor(2, dtype=ms.int32)
y = ms.Tensor(3, dtype=ms.int32)
net = Net()
out = net(x, y)
print(out)
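In ordinary Python, calling an instance is shorthand for invoking __call__ on the instance's class. The framework-free sketch below uses Python ints instead of Tensors (an assumption so it runs without MindSpore) and checks that both spellings give 2 × (2 + 3) = 10, the value the example above is expected to print.

```python
class InnerNet:
    def __init__(self, number):
        self.number = number

    def __call__(self, x, y):
        return self.number * (x + y)


net = InnerNet(2)
direct = net(2, 3)                        # instance call
explicit = type(net).__call__(net, 2, 3)  # what the instance call dispatches to
print(direct, explicit)  # 10 10
```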
If the class does not define a __call__ method, calling its instance raises an error, as in the following example.
import mindspore.nn as nn
import mindspore as ms
from mindspore import ms_class

@ms_class
class InnerNet:
    def __init__(self, number):
        self.number = number
    # No __call__ method is defined

class Net(nn.Cell):
    def construct(self, x, y):
        net = InnerNet(2)
        # Error: InnerNet has no __call__ method
        out = net(x, y)
        return out

ms.set_context(mode=ms.GRAPH_MODE)
x = ms.Tensor(2, dtype=ms.int32)
y = ms.Tensor(3, dtype=ms.int32)
net = Net()
out = net(x, y)
print(out)
The error information is as follows:
RuntimeError: MsClassObject: 'InnerNet' has no __call__ function, please check the code.
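For comparison, ordinary Python reports the same mistake as a TypeError at call time. The sketch below, again without MindSpore, uses the built-in callable() to detect a missing __call__ before calling; it shows ordinary Python behavior only and is not a MindSpore feature.

```python
class InnerNet:
    def __init__(self, number):
        self.number = number  # no __call__ defined


net = InnerNet(2)
print(callable(net))  # False: the class defines no __call__

message = ""
try:
    net(2, 3)
except TypeError as err:
    message = str(err)
print(message)  # 'InnerNet' object is not callable
```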