mindspore.train.RunContext

查看源文件
class mindspore.train.RunContext(original_args)[源代码]

保存和管理模型的相关信息。

RunContext 主要用于收集训练或推理过程中模型的上下文相关信息,并作为入参传入callback对象中来实现信息的共享。

callback的类方法中,调用 RunContext.original_args() 可以获取模型当前的上下文信息,用户也可以为此信息添加额外的自定义属性,同时 request_stop() 方法可以控制训练过程的停止。

RunContext.original_args() 存储的模型信息为一个字典型变量,在训练和推理过程中会存储不同的属性。详情如下:

参数:
  • original_args (dict) - 模型的相关信息。

样例:

>>> from mindspore import Tensor
>>> from mindspore.train import RunContext
>>> cb_params = {}
>>> cb_params["cur_epoch_num"] = 4
>>> cb_params["epoch_num"] = 4
>>> cb_params["cur_step_num"] = 2
>>> cb_params["batch_num"] = 2
>>> cb_params["net_outputs"] = Tensor(2.0)
>>> run_context = RunContext(cb_params)
>>> whether_stop = run_context.get_stop_requested()
get_stop_requested()[源代码]

获取是否停止训练的标志。

返回:

bool,如果为True,则 Model.train() 停止迭代。

original_args()[源代码]

获取模型相关信息的对象。

返回:

dict,含有模型的相关信息的对象。

request_stop()[源代码]

在训练期间设置停止请求。

可以使用此函数请求停止训练。 Model.train() 会检查是否调用此函数。