调用自定义类
概述
通过ms_class,用户可以在静态图模式下调用自定义类的属性和方法。
在静态图模式下,用户需要获取自定义类的属性/方法时,可以对该类使用@ms_class装饰器,从而调用其属性/方法。在动态图模式即PyNative模式下,ms_class的使用也是支持的,但用户不需要@ms_class装饰器也能调用自定义类的属性和方法。
本文档主要介绍ms_class的使用场景和使用须知,以便您可以更有效地使用ms_class功能。
使用场景
1、调用自定义类的属性
调用自定义类的属性时,可以通过@ms_class装饰器,对自定义类进行修饰。
[1]:
import numpy as np
import mindspore.nn as nn
from mindspore import context, Tensor, ms_class
@ms_class
class InnerNet:
def __init__(self):
self.value = Tensor(np.array([1, 2, 3]))
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.inner_net = InnerNet()
def construct(self):
out = self.inner_net.value
return out
context.set_context(mode=context.GRAPH_MODE)
net = Net()
out = net()
print(out)
[1 2 3]
2、调用自定义类的方法
调用自定义类的方法时,可以通过@ms_class装饰器,对自定义类进行修饰。
[2]:
import numpy as np
import mindspore.nn as nn
from mindspore import context, Tensor, ms_class
@ms_class
class InnerNet:
def act(self, x, y):
return x + y
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.inner_net = InnerNet()
def construct(self, x, y):
out = self.inner_net.act(x, y)
return out
context.set_context(mode=context.GRAPH_MODE)
x = Tensor(np.array([1, 2, 3]).astype(np.int32))
y = Tensor(np.array([4, 5, 6]).astype(np.int32))
net = Net()
out = net(x, y)
print(out)
[5 7 9]
3、调用嵌套的自定义类的属性和方法
多个自定义类嵌套时,如果都使用了@ms_class装饰器,则可以获取嵌套类的属性和方法。
[3]:
import numpy as np
import mindspore.nn as nn
from mindspore import context, Tensor, ms_class
@ms_class
class Inner:
def __init__(self):
self.value = Tensor(np.array([1, 2, 3]))
@ms_class
class InnerNet:
def __init__(self):
self.inner = Inner()
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.inner_net = InnerNet()
def construct(self):
out = self.inner_net.inner.value
return out
context.set_context(mode=context.GRAPH_MODE)
net = Net()
out = net()
print(out)
[1 2 3]
4、自定义类和nn.Cell嵌套使用
当自定义类和nn.Cell嵌套使用时,调用自定义类的属性和方法。关于nn.Cell的介绍,请参考mindspore.nn.Cell。
[4]:
import numpy as np
import mindspore.nn as nn
from mindspore import dtype as mstype
from mindspore import context, Tensor, ms_class
class Net(nn.Cell):
def __init__(self, val):
super().__init__()
self.val = val
def construct(self, x):
return x + self.val
@ms_class
class TrainNet():
class Loss(nn.Cell):
def __init__(self, net):
super().__init__()
self.net = net
def construct(self, x):
out = self.net(x)
return out * 2
def __init__(self, net):
self.net = net
loss_net = self.Loss(self.net)
self.number = loss_net(10)
global_net = Net(1)
class LearnNet(nn.Cell):
def __init__(self):
super().__init__()
self.value = TrainNet(global_net).number
def construct(self, x):
return x + self.value
context.set_context(mode=context.GRAPH_MODE)
x = Tensor(3, mstype.int32)
leanrn_net = LearnNet()
out = leanrn_net(x)
print(out)
25
使用须知
使用ms_class时,需要考虑以下条件:
1、ms_class不支持非class类型
from mindspore import ms_class
@ms_class
def func(x, y):
return x + y
func(1, 2)
执行代码后,将会提示以下报错信息:
TypeError: Decorator ms_class can only be used for class type, but got <function func at 0x7fee33c005f0>.
2、ms_class支持调用类实例的属性和方法,不支持直接从类定义获取其属性和方法,不支持在construct/ms_function函数中创建自定义类的实例。
import mindspore.nn as nn
from mindspore import context, ms_class
@ms_class
class InnerNet:
def __init__(self):
self.number = 1
class Net(nn.Cell):
def construct(self):
out = InnerNet().number
return out
context.set_context(mode=context.GRAPH_MODE)
net = Net()
net()
执行代码后,将会提示以下报错信息:
ValueError: This may be not defined, or it can’t be a operator. Please check code.
3、不支持对nn.Cell使用@ms_class装饰器。
import mindspore.nn as nn
from mindspore import context, Tensor, ms_class
@ms_class
class Net(nn.Cell):
def construct(self, x):
return x
context.set_context(mode=context.GRAPH_MODE)
x = Tensor(1)
net = Net()
net(x)
执行代码后,将会提示以下报错信息:
TypeError: ms_class is used for user-defined classes and cannot be used for nn.Cell: Net<>.
4、不支持调用自定义类的私有属性或魔术方法。
import numpy as np
import mindspore.nn as nn
from mindspore import context, Tensor, ms_class
@ms_class
class InnerNet:
def __init__(self):
self.value = Tensor(np.array([1, 2, 3]))
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.inner_net = InnerNet()
def construct(self):
out = self.inner_net.__str__()
return out
context.set_context(mode=context.GRAPH_MODE)
net = Net()
out = net()
执行代码后,将会提示以下报错信息:
AttributeError: __str__
is a private variable or magic method, which is not supported.