mindspore.RunContext ================================ .. py:class:: mindspore.RunContext(original_args) 保存和管理模型的相关信息。 `RunContext` 主要用于收集训练或推理过程中模型的上下文相关信息并作为入参传入callback对象中来实现信息的共享。 Callback的类方法中,调用 `RunContext.original_args()` 可以获取模型当前的上下文信息,用户也可以为此信息添加额外的自定义属性,同时 `request_stop()` 方法可以控制训练过程的停止。具体用法请查看 `Callback `_。 `RunContext.original_args()` 存储的模型信息为一个字典型变量,在训练和推理过程会存储不同的属性,详情如下: +--------------------------+------------------------+---------------------------------------+ | 训练过程支持的属性 | 推理过程支持的属性 | 说明 | +==========================+========================+=======================================+ | train_network | | 包含了优化器和损失的训练网络 | +--------------------------+------------------------+---------------------------------------+ | epoch_num | | 训练的epoch数 | +--------------------------+------------------------+---------------------------------------+ | train_dataset | | 训练集 | +--------------------------+------------------------+---------------------------------------+ | loss_fn | | 损失函数 | +--------------------------+------------------------+---------------------------------------+ | optimizer | | 优化器 | +--------------------------+------------------------+---------------------------------------+ | parallel_mode | | 并行模式 | +--------------------------+------------------------+---------------------------------------+ | device_number | | 设备编号 | +--------------------------+------------------------+---------------------------------------+ | train_dataset_element | | 当前step的训练数据 | +--------------------------+------------------------+---------------------------------------+ | last_save_ckpt_step | | 最后一次存储ckpt的step | +--------------------------+------------------------+---------------------------------------+ | latest_ckpt_file | | ckpt文件名 | +--------------------------+------------------------+---------------------------------------+ | cur_epoch_num | | 当前的epoch | +--------------------------+------------------------+---------------------------------------+ | | eval_network | 评估网络 | +--------------------------+------------------------+---------------------------------------+ | | valid_dataset | 验证集 | +--------------------------+------------------------+---------------------------------------+ | | metrics | 评估指标 | +--------------------------+------------------------+---------------------------------------+ | mode | mode | "train"或"eval"模式 | +--------------------------+------------------------+---------------------------------------+ | batch_num | batch_num | 训练或推理的batch数 | +--------------------------+------------------------+---------------------------------------+ | list_callback | list_callback | 回调列表 | +--------------------------+------------------------+---------------------------------------+ | network | network | 基础的网络结构 | +--------------------------+------------------------+---------------------------------------+ | cur_step_num | cur_step_num | 当前的训练或推理的step | +--------------------------+------------------------+---------------------------------------+ | dataset_sink_mode | dataset_sink_mode | 训练或推理的数据是否下沉 | +--------------------------+------------------------+---------------------------------------+ | net_outputs | net_outputs | 训练或推理的网络输出 | +--------------------------+------------------------+---------------------------------------+ **参数:** - **original_args** (dict) - 模型的相关信息。 .. py:method:: get_stop_requested() 获取是否停止训练的标志。 **返回:** bool,如果为True,则 `Model.train()` 停止迭代。 .. py:method:: original_args() 获取模型相关信息的对象。 **返回:** dict,含有模型的相关信息的对象。 .. py:method:: request_stop() 在训练期间设置停止请求。 可以使用此函数请求停止训练。 `Model.train()` 会检查是否调用此函数。