比较与tf.keras.Model的功能差异

查看源文件

tf.keras.Model

tf.keras.Model(*args, **kwargs)

更多内容详见tf.keras.Model

mindspore.train.Model

mindspore.train.Model(network, loss_fn=None, optimizer=None, metrics=None, eval_network=None, eval_indexes=None, amp_level="O0", boost_level="O0", **kwargs)

更多内容详见mindspore.train.Model

使用方式

框架提供的模型训练和推理的高阶API,实例化一个Model的常见场景可参考代码示例。

代码示例

TensorFlow:

  1. 实例化Model的两种方法:

创建一个前向传递,根据输入输出创建一个Model实例:

import tensorflow as tf

inputs = tf.keras.Input(shape=(3,))
x = tf.keras.layers.Dense(4, activation=tf.nn.relu)(inputs)
outputs = tf.keras.layers.Dense(5, activation=tf.nn.softmax)(x)
model = tf.keras.Model(inputs=inputs, outputs=outputs)

继承Model类,在__init__中定义模型层,在call中明确执行逻辑。

import tensorflow as tf

class MyModel(tf.keras.Model):

  def __init__(self):
    super(MyModel, self).__init__()
    self.dense1 = tf.keras.layers.Dense(4, activation=tf.nn.relu)
    self.dense2 = tf.keras.layers.Dense(5, activation=tf.nn.softmax)

  def call(self, inputs):
    x = self.dense1(inputs)
    return self.dense2(x)

model = MyModel()
  1. 使用compile方法进行模型配置:

model.compile(loss='mae', optimizer='adam')

MindSpore:

import mindspore as ms
from mindspore.train import Model
from mindspore import nn
from mindspore.common.initializer import Normal

class LinearNet(nn.Cell):
    def __init__(self):
        super().__init__()
        self.fc = nn.Dense(1, 1, Normal(0.02), Normal(0.02))

    def construct(self, x):
        return self.fc(x)

net = LinearNet()
crit = nn.MSELoss()
opt = nn.Momentum(net.trainable_params(), learning_rate=0.005, momentum=0.9)

model = Model(network=net, loss_fn=crit, optimizer=opt, metrics={"mae"})