Client

查看源文件

import com.mindspore.flclient.model.Client

Client定义了端侧联邦学习算法执行流程对象。

公有成员函数

function

abstract List<Callback> initCallbacks(RunType runType, DataSet dataSet)

abstract Map<RunType, Integer> initDataSets(Map<RunType, List<String>> files)

abstract float getEvalAccuracy(List<Callback> evalCallbacks)

abstract List<Integer> getInferResult(List<Callback> inferCallbacks)

Status initSessionAndInputs(String modelPath, MSConfig config)

Status trainModel(int epochs)

evalModel()

List<Integer> inferModel()

Status saveModel(String modelPath)

List<MSTensor> getFeatures()

Status updateFeatures(String modelName, List<FeatureMap> featureMaps)

void free()

Status setLearningRate(float lr)

void setBatchSize(int batchSize)

initCallbacks

public abstract List<Callback> initCallbacks(RunType runType, DataSet dataSet)

初始化callback列表。

  • 参数

    • runType: RunType类,标识训练、评估还是预测阶段。

    • dataSet: DataSet类,训练、评估还是预测阶段数据集。

  • 返回值

    初始化的callback列表。

initDataSets

public abstract Map<RunType, Integer> initDataSets(Map<RunType, List<String>> files)

初始化dataset列表。

  • 参数

    • files: 训练、评估和预测阶段使用的数据文件。

  • 返回值

    训练、评估和预测阶段数据集样本量。

getEvalAccuracy

public abstract float getEvalAccuracy(List<Callback> evalCallbacks)

获取评估阶段的精度。

  • 参数

    • evalCallbacks: 评估阶段使用的callback列表。

  • 返回值

    评估阶段精度。

getInferResult

public abstract List<Integer> getInferResult(List<Callback> inferCallbacks)

获取预测结果。

  • 参数

    • inferCallbacks: 预测阶段使用的callback列表。

  • 返回值

    预测结果。

initSessionAndInputs

public Status initSessionAndInputs(String modelPath, MSConfig config)

初始化client底层会话和输入。

  • 参数

    • modelPath: 模型文件。

    • config: 会话配置。

  • 返回值

    初始化状态结果。

trainModel

public Status trainModel(int epochs)

开启模型训练。

  • 参数

    • epochs: 训练的epoch数。

  • 返回值

    模型训练结果。

evalModel

public float evalModel()

执行模型评估过程。

  • 返回值

    模型评估精度。

inferModel

public List<Integer> inferModel()

执行模型预测过程。

  • 返回值

    模型预测结果。

saveModel

public Status saveModel(String modelPath)

保存模型。

  • 返回值

    模型保存结果。

getFeatures

public List<MSTensor> getFeatures()

获取端侧权重。

  • 返回值

    模型权重。

updateFeatures

public Status updateFeatures(String modelName, List<FeatureMap> featureMaps)

更新端侧权重。

  • 参数

    • modelName: 待更新的模型文件。

    • featureMaps: 待更新的模型权重。

  • 返回值

    模型权重。

free

public void free()

释放模型。

setLearningRate

public Status setLearningRate(float lr)

设置学习率。

  • 参数

    • lr: 学习率。

  • 返回值

    设置结果。

setBatchSize

public void setBatchSize(int batchSize)

设置执行批次数。

  • 参数

    • batchSize: 批次数。