mindspore.train.RunContext

class mindspore.train.RunContext(original_args)[源代码]

保存和管理模型的相关信息。

RunContext 主要用于收集训练或推理过程中模型的上下文相关信息并作为入参传入callback对象中来实现信息的共享。

Callback的类方法中,调用 RunContext.original_args() 可以获取模型当前的上下文信息,用户也可以为此信息添加额外的自定义属性,同时 request_stop() 方法可以控制训练过程的停止。具体用法请查看 回调机制Callback

RunContext.original_args() 存储的模型信息为一个字典型变量,在训练和推理过程会存储不同的属性。详情如下:

训练过程支持的属性

推理过程支持的属性

说明

train_network

包含了优化器和损失的训练网络

epoch_num

训练的epoch数

train_dataset

训练集

loss_fn

损失函数

optimizer

优化器

parallel_mode

并行模式

device_number

设备编号

train_dataset_element

当前step的训练数据

last_save_ckpt_step

最后一次存储ckpt的step

latest_ckpt_file

ckpt文件名

cur_epoch_num

当前的epoch

eval_network

评估网络

valid_dataset

验证集

metrics

评估指标

mode

mode

“train”或”eval”模式

batch_num

batch_num

训练或推理的batch数

list_callback

list_callback

回调列表

network

network

基础的网络结构

cur_step_num

cur_step_num

当前的训练或推理的step

dataset_sink_mode

dataset_sink_mode

训练或推理的数据是否下沉

net_outputs

net_outputs

训练或推理的网络输出

参数:
  • original_args (dict) - 模型的相关信息。

样例:

>>> from mindspore import Tensor
>>> from mindspore.train import RunContext
>>> cb_params = {}
>>> cb_params["cur_epoch_num"] = 4
>>> cb_params["epoch_num"] = 4
>>> cb_params["cur_step_num"] = 2
>>> cb_params["batch_num"] = 2
>>> cb_params["net_outputs"] = Tensor(2.0)
>>> run_context = RunContext(cb_params)
>>> whether_stop = run_context.get_stop_requested()
get_stop_requested()[源代码]

获取是否停止训练的标志。

返回:

bool,如果为True,则 Model.train() 停止迭代。

original_args()[源代码]

获取模型相关信息的对象。

返回:

dict,含有模型的相关信息的对象。

教程样例:
request_stop()[源代码]

在训练期间设置停止请求。

可以使用此函数请求停止训练。 Model.train() 会检查是否调用此函数。

教程样例: