mindspore.train.RunContext

class mindspore.train.RunContext(original_args)

Hold and manage information about the model.

RunContext is mainly used to collect context information about the model during training or evaluation and pass it to Callback objects as an input parameter, so that they can share information.

Callback objects can obtain the model context information by calling RunContext.original_args() and add extra attributes to it. They can also stop the training process by calling the request_stop method. For details on writing a custom Callback, see Callback Mechanism.
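A minimal sketch of the pattern described above, written so that it runs without MindSpore. It uses a hypothetical stand-in class, `SketchRunContext`, that mirrors only the three RunContext methods documented on this page; the callback class `StopBelowLoss` and the `loss_history` key are likewise illustrative names, not part of MindSpore. In real code the callback would subclass mindspore.train.Callback, receive a mindspore.train.RunContext, and be passed to Model.train() through its callbacks argument.

```python
class SketchRunContext:
    """Hypothetical stand-in mirroring the documented RunContext methods."""

    def __init__(self, original_args: dict):
        self._original_args = original_args
        self._stop_requested = False

    def original_args(self) -> dict:
        # Returns the shared dict itself, so callbacks can add keys to it.
        return self._original_args

    def request_stop(self) -> None:
        self._stop_requested = True

    def get_stop_requested(self) -> bool:
        return self._stop_requested


class StopBelowLoss:
    """Illustrative callback: records the loss and requests a stop below a threshold.

    The method name mirrors MindSpore's on_train_step_end hook, but this class
    does not inherit from mindspore.train.Callback.
    """

    def __init__(self, threshold: float):
        self.threshold = threshold

    def on_train_step_end(self, run_context) -> None:
        cb_params = run_context.original_args()
        # Assumes net_outputs is a scalar loss; in this sketch it is a plain float.
        loss = float(cb_params["net_outputs"])
        # Add an extra attribute to the shared context dict.
        cb_params.setdefault("loss_history", []).append(loss)
        if loss < self.threshold:
            run_context.request_stop()


cb_params = {"cur_epoch_num": 1, "cur_step_num": 3, "net_outputs": 0.05}
ctx = SketchRunContext(cb_params)
StopBelowLoss(threshold=0.1).on_train_step_end(ctx)
print(ctx.get_stop_requested())   # True
print(cb_params["loss_history"])  # [0.05]
```

Because the callback writes into the dict returned by original_args(), any callback that runs later in the same step sees the added key.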

RunContext.original_args() returns the model context information as a dictionary. The keys stored in it differ between the training and evaluation processes, as follows:

Attributes supported in train | Attributes supported in eval | Meaning
----------------------------- | ---------------------------- | -------
train_network                 |                              | training network with optimizer and loss function
epoch_num                     |                              | total number of training epochs
train_dataset                 |                              | training dataset
loss_fn                       |                              | loss function
optimizer                     |                              | optimizer
parallel_mode                 |                              | parallel mode
device_number                 |                              | number of devices
train_dataset_element         |                              | training data element of the current step
last_save_ckpt_step           |                              | step number at which the last checkpoint was saved
latest_ckpt_file              |                              | latest checkpoint file
cur_epoch_num                 |                              | current epoch number
                              | eval_network                 | evaluation network
                              | valid_dataset                | validation dataset
                              | metrics                      | evaluation metrics
mode                          | mode                         | "train" or "eval"
batch_num                     | batch_num                    | number of training/evaluation batches
list_callback                 | list_callback                | list of callbacks
network                       | network                      | base network
cur_step_num                  | cur_step_num                 | current training/evaluation step number
dataset_sink_mode             | dataset_sink_mode            | whether dataset sink mode is used for training/evaluation
net_outputs                   | net_outputs                  | network outputs
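Because the available keys depend on the mode, a callback shared between training and evaluation can check mode before reading train-only keys such as cur_epoch_num. Below is a minimal sketch over a plain dict; the helper name format_progress is hypothetical. The object returned by RunContext.original_args() is a dictionary, so it can be passed in the same way.

```python
def format_progress(cb_params: dict) -> str:
    """Summarize progress using keys from the attribute table (hypothetical helper)."""
    mode = cb_params.get("mode")
    step = cb_params.get("cur_step_num")
    if mode == "train":
        # cur_epoch_num and epoch_num exist only in the training context.
        epoch = cb_params.get("cur_epoch_num")
        epochs = cb_params.get("epoch_num")
        return f"train epoch {epoch}/{epochs}, step {step}"
    if mode == "eval":
        # batch_num is available in both modes.
        return f"eval step {step}/{cb_params.get('batch_num')}"
    return f"unknown mode: {mode!r}"


print(format_progress({"mode": "train", "cur_epoch_num": 2, "epoch_num": 4, "cur_step_num": 7}))
# train epoch 2/4, step 7
print(format_progress({"mode": "eval", "cur_step_num": 3, "batch_num": 10}))
# eval step 3/10
```

Using dict.get rather than indexing keeps the helper from raising KeyError when it is called in a mode that lacks a key.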

Parameters

original_args (dict) – Holds the information related to the model.

Examples

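>>> from mindspore import Tensor
>>> from mindspore.train import RunContext
>>> # Build the context dict by hand; during Model.train()/eval() it is filled in automatically.
>>> cb_params = {}
>>> cb_params["cur_epoch_num"] = 4
>>> cb_params["epoch_num"] = 4
>>> cb_params["cur_step_num"] = 2
>>> cb_params["batch_num"] = 2
>>> cb_params["net_outputs"] = Tensor(2.0)
>>> run_context = RunContext(cb_params)
>>> # No stop has been requested yet.
>>> whether_stop = run_context.get_stop_requested()
>>> print(whether_stop)
False
>>> # A callback calls request_stop() to end the iterations early.
>>> run_context.request_stop()
>>> print(run_context.get_stop_requested())
True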

get_stop_requested()

Return whether a stop is requested or not.

Returns

bool. If True, model.train() stops iterating.

original_args()

Get the _original_args object.

Returns

dict, an object that holds the original arguments of the model.

request_stop()

Request that training or evaluation stop.

Callbacks can call this method to request that the iterations stop. model.train() checks whether it has been called.
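The check that model.train() performs can be pictured with the simplified loop below. It sketches the control flow only and is not MindSpore's implementation: `SketchRunContext`, `run_training`, and `StopAtStep` are hypothetical names, this loop checks the flag after every step's callbacks, and the real training loop also handles dataset sink mode, epoch-level hooks, and more.

```python
class SketchRunContext:
    """Hypothetical stand-in mirroring RunContext's documented methods."""

    def __init__(self, original_args: dict):
        self._original_args = original_args
        self._stop_requested = False

    def original_args(self) -> dict:
        return self._original_args

    def request_stop(self) -> None:
        self._stop_requested = True

    def get_stop_requested(self) -> bool:
        return self._stop_requested


class StopAtStep:
    """Illustrative callback that requests a stop once cur_step_num reaches stop_step."""

    def __init__(self, stop_step: int):
        self.stop_step = stop_step

    def on_train_step_end(self, run_context) -> None:
        if run_context.original_args()["cur_step_num"] >= self.stop_step:
            run_context.request_stop()


def run_training(epoch_num: int, batch_num: int, callbacks) -> int:
    """Simplified training loop; returns the number of steps actually run."""
    cb_params = {"mode": "train", "epoch_num": epoch_num,
                 "batch_num": batch_num, "cur_step_num": 0}
    run_context = SketchRunContext(cb_params)
    for epoch in range(1, epoch_num + 1):
        cb_params["cur_epoch_num"] = epoch
        for _ in range(batch_num):
            cb_params["cur_step_num"] += 1
            # ... the forward/backward pass would run here ...
            for cb in callbacks:
                cb.on_train_step_end(run_context)
            # Poll the stop flag after this step's callbacks have run.
            if run_context.get_stop_requested():
                return cb_params["cur_step_num"]
    return cb_params["cur_step_num"]


print(run_training(epoch_num=3, batch_num=4, callbacks=[StopAtStep(6)]))  # 6
print(run_training(epoch_num=3, batch_num=4, callbacks=[]))               # 12
```

Because the flag is only polled, a callback's request takes effect at the next check rather than interrupting the step in progress.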
