mindspore.train.RunContext
- class mindspore.train.RunContext(original_args)[源代码]
保存和管理模型的相关信息。
RunContext 主要用于收集训练或推理过程中模型的上下文相关信息并作为入参传入callback对象中来实现信息的共享。
Callback的类方法中,调用 RunContext.original_args() 可以获取模型当前的上下文信息,用户也可以为此信息添加额外的自定义属性,同时 request_stop() 方法可以控制训练过程的停止。具体用法请查看 回调机制Callback。
RunContext.original_args() 存储的模型信息为一个字典型变量,在训练和推理过程会存储不同的属性。详情如下:
训练过程支持的属性
推理过程支持的属性
说明
train_network
包含了优化器和损失的训练网络
epoch_num
训练的epoch数
train_dataset
训练集
loss_fn
损失函数
optimizer
优化器
parallel_mode
并行模式
device_number
设备编号
train_dataset_element
当前step的训练数据
last_save_ckpt_step
最后一次存储ckpt的step
latest_ckpt_file
ckpt文件名
cur_epoch_num
当前的epoch
eval_network
评估网络
valid_dataset
验证集
metrics
评估指标
mode
mode
“train”或”eval”模式
batch_num
batch_num
训练或推理的batch数
list_callback
list_callback
回调列表
network
network
基础的网络结构
cur_step_num
cur_step_num
当前的训练或推理的step
dataset_sink_mode
dataset_sink_mode
训练或推理的数据是否下沉
net_outputs
net_outputs
训练或推理的网络输出
- 参数:
original_args (dict) - 模型的相关信息。
样例:
>>> from mindspore import Tensor >>> from mindspore.train import RunContext >>> cb_params = {} >>> cb_params["cur_epoch_num"] = 4 >>> cb_params["epoch_num"] = 4 >>> cb_params["cur_step_num"] = 2 >>> cb_params["batch_num"] = 2 >>> cb_params["net_outputs"] = Tensor(2.0) >>> run_context = RunContext(cb_params) >>> whether_stop = run_context.get_stop_requested()