mindspore.Model

class mindspore.Model(network, loss_fn=None, optimizer=None, metrics=None, eval_network=None, eval_indexes=None, amp_level="O0", boost_level="O0", **kwargs)[source]

High-Level API for training or inference.

Model groups layers into an object with training and inference features.

Parameters
  • network (Cell) – A training or testing network.

  • loss_fn (Cell) – Objective function. If loss_fn is None, the network should contain the logic for loss and gradient calculation, and for parallelism if needed. Default: None.

  • optimizer (Cell) – Optimizer for updating the weights. Default: None.

  • metrics (Union[dict, set]) – A dictionary or a set of metrics to be evaluated by the model during training and inference, e.g. {'accuracy', 'recall'}. Default: None.

  • eval_network (Cell) – Network for evaluation. If not defined, network and loss_fn will be wrapped as eval_network. Default: None.

  • eval_indexes (list) – When defining the eval_network, if eval_indexes is None, all outputs of the eval_network are passed to the metrics. Otherwise eval_indexes must contain three elements: the positions of the loss value, the predicted value and the label. The loss value is passed to the Loss metric, and the predicted value and label are passed to the other metrics. Default: None.

  • amp_level (str) –

    Option for the argument level in mindspore.amp.build_train_network, the level for mixed precision training. Supports [“O0”, “O2”, “O3”, “auto”]. Default: “O0”.

    • O0: Do not change.

    • O2: Cast network to float16, keep batchnorm running in float32, and use dynamic loss scale.

    • O3: Cast network to float16, with additional property keep_batchnorm_fp32=False.

    • auto: Set the level to the recommended level for the device: O2 on GPU and O3 on Ascend. The recommended level is chosen based on expert experience and may not generalize to every network. Users should specify the level explicitly for special networks.

    O2 is recommended on GPU, O3 is recommended on Ascend. A more detailed explanation of the amp_level setting can be found at mindspore.amp.build_train_network.

  • boost_level (str) –

    Option for the argument level in mindspore.boost, the level for boost mode training. Supports [“O0”, “O1”, “O2”]. Default: “O0”.

    • O0: Do not change.

    • O1: Enable the boost mode. Performance improves by about 20%, and accuracy is the same as the original accuracy.

    • O2: Enable the boost mode. Performance improves by about 30%, and accuracy is reduced by less than 3%.
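The routing described for eval_indexes can be pictured with a small plain-Python sketch. This is only an illustration of the rule stated above, not MindSpore's implementation: the function name route_eval_outputs and the placeholder output values are hypothetical.

```python
def route_eval_outputs(outputs, eval_indexes=None):
    """Split eval-network outputs into (loss_args, other_metric_args).

    Hypothetical helper: it mirrors the documented rule only.
    """
    if eval_indexes is None:
        # No indexes given: every output is passed to the metrics.
        return tuple(outputs), tuple(outputs)
    if len(eval_indexes) != 3:
        raise ValueError("eval_indexes must contain three elements")
    loss_idx, pred_idx, label_idx = eval_indexes
    # The loss value goes to the Loss metric.
    loss_args = (outputs[loss_idx],)
    # The predicted value and the label go to the other metrics.
    metric_args = (outputs[pred_idx], outputs[label_idx])
    return loss_args, metric_args


# Placeholder strings stand in for the tensors an eval network would return.
outputs = ("loss", "logits", "labels")
loss_args, metric_args = route_eval_outputs(outputs, eval_indexes=[0, 1, 2])
```

With eval_indexes=[0, 1, 2], the first output is treated as the loss and the remaining two as the (prediction, label) pair.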

Examples

>>> from mindspore import Model, nn
>>>
>>> class Net(nn.Cell):
...     def __init__(self, num_class=10, num_channel=1):
...         super(Net, self).__init__()
...         self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')
...         self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
...         self.fc1 = nn.Dense(16*5*5, 120, weight_init='ones')
...         self.fc2 = nn.Dense(120, 84, weight_init='ones')
...         self.fc3 = nn.Dense(84, num_class, weight_init='ones')
...         self.relu = nn.ReLU()
...         self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
...         self.flatten = nn.Flatten()
...
...     def construct(self, x):
...         x = self.max_pool2d(self.relu(self.conv1(x)))
...         x = self.max_pool2d(self.relu(self.conv2(x)))
...         x = self.flatten(x)
...         x = self.relu(self.fc1(x))
...         x = self.relu(self.fc2(x))
...         x = self.fc3(x)
...         return x
>>>
>>> net = Net()
>>> loss = nn.SoftmaxCrossEntropyWithLogits()
>>> optim = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
>>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None)
>>> # For details about how to build the dataset, please refer to the tutorial
>>> # document on the official website.
>>> dataset = create_custom_dataset()
>>> model.train(2, dataset)

build(train_dataset=None, valid_dataset=None, sink_size=-1, epoch=1)[source]

Build computational graphs and data graphs with the sink mode.

Warning

This is an experimental prototype that is subject to change and/or deletion.

Note

Pre-build currently supports only GRAPH_MODE and the Ascend target. If this interface is executed first, it builds the computational graphs, and ‘model.train’ then only executes them. It supports only dataset sink mode.

Parameters
  • train_dataset (Dataset) – A training dataset iterator. If train_dataset is defined, training graphs will be initialized. Default: None.

  • valid_dataset (Dataset) – An evaluating dataset iterator. If valid_dataset is defined, evaluation graphs will be initialized, and metrics in Model cannot be None. Default: None.

  • sink_size (int) – Control the amount of data in each sink. Default: -1.

  • epoch (int) – Control the training epochs. Default: 1.

Examples

>>> from mindspore import Model, nn, FixedLossScaleManager
>>>
>>> # For details about how to build the dataset, please refer to the tutorial
>>> # document on the official website.
>>> dataset = create_custom_dataset()
>>> net = Net()
>>> loss = nn.SoftmaxCrossEntropyWithLogits()
>>> loss_scale_manager = FixedLossScaleManager()
>>> optim = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
>>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None, loss_scale_manager=loss_scale_manager)
>>> model.build(dataset, epoch=2)
>>> model.train(2, dataset)

eval(valid_dataset, callbacks=None, dataset_sink_mode=True)[source]

Evaluation API where the iteration is controlled by the Python front end.

In PyNative mode or on CPU, the evaluation process is performed in dataset non-sink mode.

Note

If dataset_sink_mode is True, data will be sent to the device. If the device is Ascend, features of data will be transferred one by one. Each data transfer is limited to 256M. When dataset_sink_mode is True, the step_end method of the Callback class is executed when the epoch_end method is called.

Parameters
  • valid_dataset (Dataset) – Dataset to evaluate the model.

  • callbacks (Optional[list(Callback)]) – List of callback objects which should be executed while evaluating. Default: None.

  • dataset_sink_mode (bool) – Determines whether to pass the data through the dataset channel. Default: True.

Returns

Dict, where each key is a metric name defined by the user and each value is that metric's value for the model in test mode.
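The shape of the returned dictionary can be illustrated without MindSpore. The sketch below is a plain-Python stand-in, assuming a hypothetical evaluate helper and a hand-written accuracy function; it only shows how metric names map to values and does not reproduce how Model.eval computes them.

```python
def accuracy(batches):
    """Fraction of predictions that equal their labels across all batches."""
    correct = total = 0
    for predictions, labels in batches:
        correct += sum(p == l for p, l in zip(predictions, labels))
        total += len(labels)
    return correct / total


def evaluate(batches, metrics):
    """Return {metric_name: metric_value}, keyed by the user-chosen names."""
    return {name: fn(batches) for name, fn in metrics.items()}


# Two toy batches of (predicted class ids, label class ids).
batches = [([1, 0, 2], [1, 0, 1]), ([2, 2], [2, 2])]
result = evaluate(batches, {"acc": accuracy})
print(result)
```

Here four of the five predictions match their labels, so the dictionary holds a single "acc" entry for that ratio.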

Examples

>>> from mindspore import Model, nn
>>>
>>> # For details about how to build the dataset, please refer to the tutorial
>>> # document on the official website.
>>> dataset = create_custom_dataset()
>>> net = Net()
>>> loss = nn.SoftmaxCrossEntropyWithLogits()
>>> model = Model(net, loss_fn=loss, optimizer=None, metrics={'acc'})
>>> acc = model.eval(dataset, dataset_sink_mode=False)

property eval_network

Get the model’s eval_network.

infer_predict_layout(*predict_data)[source]

Generate parameter layout for the predict network in auto or semi auto parallel mode.

Data could be a single tensor or multiple tensors.

Note

Batch data should be put together in one tensor.

Parameters

predict_data (Tensor) – One tensor or multiple tensors of predict data.

Returns

Dict, parameter layout dictionary used for loading a distributed checkpoint.

Raises

RuntimeError – If the mode returned by get_context is not GRAPH_MODE.

Examples

>>> # This example should be run with multiple devices. Refer to the tutorial > Distributed Training on
>>> # mindspore.cn.
>>> import numpy as np
>>> import mindspore as ms
>>> from mindspore import Model, context, Tensor
>>> from mindspore.context import ParallelMode
>>> from mindspore.communication import init
>>>
>>> context.set_context(mode=context.GRAPH_MODE)
>>> init()
>>> context.set_auto_parallel_context(full_batch=True, parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
>>> input_data = Tensor(np.random.randint(0, 255, [1, 1, 32, 32]), ms.float32)
>>> model = Model(Net())
>>> predict_map = model.infer_predict_layout(input_data)

infer_train_layout(train_dataset, dataset_sink_mode=True, sink_size=-1)[source]

Generate parameter layout for the train network in auto or semi auto parallel mode. Only dataset sink mode is supported for now.

Warning

This is an experimental prototype that is subject to change and/or deletion.

Note

This is a pre-compile function. The arguments should be the same as those of the model.train() function.

Parameters
  • train_dataset (Dataset) – A training dataset iterator. If there is no loss_fn, a tuple with multiple data (data1, data2, data3, …) should be returned and passed to the network. Otherwise, a tuple (data, label) should be returned. The data and label would be passed to the network and loss function respectively.

  • dataset_sink_mode (bool) – Determines whether to pass the data through the dataset channel. In PyNative mode or on CPU, the training process is performed in dataset non-sink mode. Default: True.

  • sink_size (int) – Control the amount of data in each sink. If sink_size = -1, sink the complete dataset for each epoch. If sink_size > 0, sink sink_size data for each epoch. If dataset_sink_mode is False, sink_size has no effect. Default: -1.
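The interaction between dataset_sink_mode and sink_size can be summarized as a small decision function. This is a minimal sketch of the rules listed above, not MindSpore code: steps_per_epoch is a hypothetical name, dataset_size stands for the number of batches in the dataset, and rejecting values other than -1 or a positive integer is an assumption, since the text does not say how they are handled.

```python
def steps_per_epoch(dataset_size, dataset_sink_mode=True, sink_size=-1):
    """Number of batches consumed per epoch under the documented rules."""
    if not dataset_sink_mode:
        # Non-sink mode: sink_size has no effect.
        return dataset_size
    if sink_size == -1:
        # Sink the complete dataset for each epoch.
        return dataset_size
    if sink_size > 0:
        # Sink exactly sink_size batches for each epoch.
        return sink_size
    # Assumption: any other value is treated as an error in this sketch.
    raise ValueError("sink_size must be -1 or a positive integer")


full = steps_per_epoch(100)                                 # complete dataset
partial = steps_per_epoch(100, sink_size=25)                # 25 batches per epoch
ignored = steps_per_epoch(100, dataset_sink_mode=False, sink_size=25)
```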

Returns

Dict, parameter layout dictionary used for loading a distributed checkpoint.

Examples

>>> # This example should be run with multiple devices. Refer to the tutorial > Distributed Training on
>>> # mindspore.cn.
>>> import numpy as np
>>> import mindspore as ms
>>> from mindspore import Model, context, Tensor, nn, FixedLossScaleManager
>>> from mindspore.context import ParallelMode
>>> from mindspore.communication import init
>>>
>>> context.set_context(mode=context.GRAPH_MODE)
>>> init()
>>> context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
>>>
>>> # For details about how to build the dataset, please refer to the tutorial
>>> # document on the official website.
>>> dataset = create_custom_dataset()
>>> net = Net()
>>> loss = nn.SoftmaxCrossEntropyWithLogits()
>>> loss_scale_manager = FixedLossScaleManager()
>>> optim = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
>>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None, loss_scale_manager=loss_scale_manager)
>>> layout_dict = model.infer_train_layout(dataset)

predict(*predict_data)[source]

Generate output predictions for the input samples.

Data could be a single tensor, a list of tensors, or a tuple of tensors.

Note

This is a pre-compile function. The arguments should be the same as those of the model.predict() function.

Parameters

predict_data (Optional[Tensor, list[Tensor], tuple[Tensor]]) – The predict data, which can be a single tensor, a list of tensors, or a tuple of tensors.

Returns

Tensor, array(s) of predictions.

Examples

>>> import numpy as np
>>> import mindspore as ms
>>> from mindspore import Model, Tensor
>>>
>>> input_data = Tensor(np.random.randint(0, 255, [1, 1, 32, 32]), ms.float32)
>>> model = Model(Net())
>>> result = model.predict(input_data)

property predict_network

Get the model’s predict_network.

train(epoch, train_dataset, callbacks=None, dataset_sink_mode=True, sink_size=-1)[source]

Training API where the iteration is controlled by the Python front end.

In PyNative mode or on CPU, the training process is performed in dataset non-sink mode.

Note

If dataset_sink_mode is True, data will be sent to the device. If the device is Ascend, features of data will be transferred one by one. Each data transfer is limited to 256M. When dataset_sink_mode is True, the step_end method of the Callback class is executed when the epoch_end method is called. If sink_size > 0, each epoch can traverse the dataset an unlimited number of times until sink_size elements have been taken from it. The next epoch continues traversing from the position where the previous traversal ended. The interface builds the computational graphs and then executes them. However, when ‘model.build’ is executed first, the interface only executes the graphs.

Parameters
  • epoch (int) – Number of training epochs. Generally, each epoch iterates over the complete dataset. When dataset_sink_mode is True and sink_size > 0, each epoch sinks sink_size steps of data instead of the complete dataset.

  • train_dataset (Dataset) – A training dataset iterator. If there is no loss_fn, a tuple with multiple data (data1, data2, data3, …) should be returned and passed to the network. Otherwise, a tuple (data, label) should be returned. The data and label would be passed to the network and loss function respectively.

  • callbacks (Optional[list[Callback], Callback]) – A list of callback objects or a single callback object to be executed while training. Default: None.

  • dataset_sink_mode (bool) – Determines whether to pass the data through the dataset channel. In PyNative mode or on CPU, the training process is performed in dataset non-sink mode. Default: True.

  • sink_size (int) – Control the amount of data in each sink. If sink_size = -1, sink the complete dataset for each epoch. If sink_size > 0, sink sink_size data for each epoch. If dataset_sink_mode is False, sink_size has no effect. Default: -1.
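The note that a positive sink_size lets each epoch continue from where the previous one stopped can be pictured with a plain-Python sketch. It assumes the dataset behaves like a repeating stream of batches; the helper name sink_epochs and the batch labels are hypothetical, and this is not how MindSpore implements sinking.

```python
from itertools import cycle, islice


def sink_epochs(dataset, epochs, sink_size):
    """Return the batches each epoch would consume when sink_size > 0."""
    # cycle() wraps around, so the dataset can be traversed repeatedly.
    stream = cycle(dataset)
    # Each epoch takes the next sink_size batches from the shared stream,
    # resuming where the previous epoch stopped.
    return [list(islice(stream, sink_size)) for _ in range(epochs)]


dataset = ["b0", "b1", "b2", "b3", "b4"]  # five placeholder batches
sunk = sink_epochs(dataset, epochs=3, sink_size=2)
# sunk == [['b0', 'b1'], ['b2', 'b3'], ['b4', 'b0']]
```

The third epoch starts at the last batch and wraps around to the first, which matches the described behavior of continuing from the end position of the previous traversal.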

Examples

>>> from mindspore import Model, nn, FixedLossScaleManager
>>>
>>> # For details about how to build the dataset, please refer to the tutorial
>>> # document on the official website.
>>> dataset = create_custom_dataset()
>>> net = Net()
>>> loss = nn.SoftmaxCrossEntropyWithLogits()
>>> loss_scale_manager = FixedLossScaleManager()
>>> optim = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
>>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None, loss_scale_manager=loss_scale_manager)
>>> model.train(2, dataset)

property train_network

Get the model’s train_network.