Client


import com.mindspore.flclient.model.Client

Client defines the object that executes the device-side (end-side) federated learning algorithm.

Public Member Functions

function

abstract List<Callback> initCallbacks(RunType runType, DataSet dataSet)

abstract Map<RunType, Integer> initDataSets(Map<RunType, List<String>> files)

abstract float getEvalAccuracy(List<Callback> evalCallbacks)

abstract List<Object> getInferResult(List<Callback> inferCallbacks)

Status trainModel(int epochs)

float evalModel()

Map<String, float[]> genUnsupervisedEvalData(List<Callback> evalCallbacks)

List<Object> inferModel()

Status setLearningRate(float lr)

void setBatchSize(int batchSize)

initCallbacks

public abstract List<Callback> initCallbacks(RunType runType, DataSet dataSet)

Initialize the callback list.

  • Parameters

    • runType: RunType class, identifying the phase: training, evaluation, or prediction.

    • dataSet: DataSet class, the dataset used in the training, evaluation, or prediction phase.

  • Returns

    The initialized callback list.
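A subclass typically builds a different callback list per phase. The sketch below illustrates that shape under stated assumptions: `RunType`, `Callback` and `DataSet` are minimal hypothetical stand-ins (the real flclient types and their constant names may differ), and `LossLogger`, `AccuracyTracker` and `PredictionCollector` are invented callbacks used only for illustration.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical stand-ins; not the flclient definitions.
enum RunType { TRAIN, EVAL, INFER }

interface Callback {
    String name();
}

class DataSet {
    final int size;
    DataSet(int size) { this.size = size; }
}

// Hypothetical callbacks, one per phase.
class LossLogger implements Callback {
    public String name() { return "loss"; }
}

class AccuracyTracker implements Callback {
    final int expectedSamples;
    AccuracyTracker(int expectedSamples) { this.expectedSamples = expectedSamples; }
    public String name() { return "accuracy"; }
}

class PredictionCollector implements Callback {
    public String name() { return "predict"; }
}

class MyClientCallbacks {
    // Same shape as Client.initCallbacks: choose callbacks for the given phase.
    List<Callback> initCallbacks(RunType runType, DataSet dataSet) {
        List<Callback> callbacks = new ArrayList<>();
        switch (runType) {
            case TRAIN:
                callbacks.add(new LossLogger());
                break;
            case EVAL:
                // An accuracy callback may need the dataset size up front.
                callbacks.add(new AccuracyTracker(dataSet.size));
                break;
            case INFER:
                callbacks.add(new PredictionCollector());
                break;
        }
        return callbacks;
    }
}
```

The `dataSet` argument is what lets a phase-specific callback be sized or configured for the data it will observe.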

initDataSets

public abstract Map<RunType, Integer> initDataSets(Map<RunType, List<String>> files)

Initialize the dataset list.

  • Parameters

    • files: Data files used in the training, evaluation or prediction phase.

  • Returns

    The number of data items for each run type.
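The mapping from data files to per-phase counts can be sketched as follows. This assumes a hypothetical plain-text format with one sample per non-empty line; the real data files and parsing are model-specific, and `RunType` here is a stand-in enum.

```java
import java.io.IOException;
import java.io.UncheckedIOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.EnumMap;
import java.util.List;
import java.util.Map;

// Hypothetical stand-in; not the flclient definition.
enum RunType { TRAIN, EVAL, INFER }

class MyClientDataSets {
    // Samples loaded per phase, kept so later phases can use them.
    final Map<RunType, List<String>> samples = new EnumMap<>(RunType.class);

    Map<RunType, Integer> initDataSets(Map<RunType, List<String>> files) {
        Map<RunType, Integer> counts = new EnumMap<>(RunType.class);
        for (Map.Entry<RunType, List<String>> entry : files.entrySet()) {
            List<String> loaded = new ArrayList<>();
            for (String file : entry.getValue()) {
                Path path = Paths.get(file);
                try {
                    // Assumption: one sample per non-empty line.
                    for (String line : Files.readAllLines(path)) {
                        if (!line.isEmpty()) {
                            loaded.add(line);
                        }
                    }
                } catch (IOException ex) {
                    throw new UncheckedIOException(ex);
                }
            }
            samples.put(entry.getKey(), loaded);
            counts.put(entry.getKey(), loaded.size());
        }
        return counts;
    }
}
```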

getEvalAccuracy

public abstract float getEvalAccuracy(List<Callback> evalCallbacks)

Get the accuracy of the model in the evaluation phase.

  • Parameters

    • evalCallbacks: Callbacks used in the evaluation phase.

  • Returns

    The accuracy in the evaluation phase.
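One way an override can read the accuracy is to locate the metric callback it registered in `initCallbacks` and query it. In this sketch `Callback` is a stand-in interface, and `AccuracyTracker` with its `record`/`getAccuracy` methods is a hypothetical metric callback, not an flclient class.

```java
import java.util.Arrays;
import java.util.List;

// Hypothetical stand-in; not the flclient definition.
interface Callback { }

// Hypothetical metric callback that counts correct predictions.
class AccuracyTracker implements Callback {
    private int correct;
    private int total;

    void record(int predicted, int label) {
        total++;
        if (predicted == label) {
            correct++;
        }
    }

    float getAccuracy() {
        return total == 0 ? 0.0f : (float) correct / total;
    }
}

class LossLogger implements Callback { }

class MyClientEval {
    float getEvalAccuracy(List<Callback> evalCallbacks) {
        for (Callback cb : evalCallbacks) {
            if (cb instanceof AccuracyTracker) {
                return ((AccuracyTracker) cb).getAccuracy();
            }
        }
        // No accuracy callback registered: report NaN rather than a made-up score.
        return Float.NaN;
    }
}
```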

getInferResult

public abstract List<Object> getInferResult(List<Callback> inferCallbacks)

Get the result of the inference (prediction) phase.

  • Parameters

    • inferCallbacks: Callbacks used in the prediction phase.

  • Returns

    The prediction results.
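An override can gather the predictions collected by its prediction callback during the inference pass. Below is a minimal sketch with a stand-in `Callback` interface and a hypothetical `PredictionCollector`; neither is the flclient API.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

// Hypothetical stand-in; not the flclient definition.
interface Callback { }

// Hypothetical callback that stores one prediction per inferred sample.
class PredictionCollector implements Callback {
    private final List<Object> predictions = new ArrayList<>();

    void add(Object prediction) {
        predictions.add(prediction);
    }

    List<Object> getPredictions() {
        return Collections.unmodifiableList(predictions);
    }
}

class MyClientInfer {
    List<Object> getInferResult(List<Callback> inferCallbacks) {
        List<Object> results = new ArrayList<>();
        for (Callback cb : inferCallbacks) {
            if (cb instanceof PredictionCollector) {
                results.addAll(((PredictionCollector) cb).getPredictions());
            }
        }
        return results;
    }
}
```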

trainModel

public Status trainModel(int epochs)

Execute the model training process.

  • Parameters

    • epochs: Number of epochs used in the training process.

  • Returns

    Whether model training succeeded.
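The actual body of `trainModel` is not documented here. As an assumption-labeled sketch of how a base class might drive training, the loop below runs `epochs` passes, notifies callbacks around each pass, and maps failure to a status. `Status`, `Callback`, `SketchClient` and the `runOneEpoch` hook are all hypothetical stand-ins.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical stand-ins; not the flclient definitions.
enum Status { SUCCESS, FAILED }

interface Callback {
    void epochBegin(int epoch);
    void epochEnd(int epoch);
}

class EpochCounter implements Callback {
    int begun;
    int ended;
    public void epochBegin(int epoch) { begun++; }
    public void epochEnd(int epoch) { ended++; }
}

abstract class SketchClient {
    protected final List<Callback> trainCallbacks = new ArrayList<>();

    // Hypothetical hook: one pass over the training data; false on failure.
    protected abstract boolean runOneEpoch(int epoch);

    Status trainModel(int epochs) {
        if (epochs <= 0) {
            return Status.FAILED;
        }
        for (int epoch = 0; epoch < epochs; epoch++) {
            for (Callback cb : trainCallbacks) {
                cb.epochBegin(epoch);
            }
            if (!runOneEpoch(epoch)) {
                return Status.FAILED;
            }
            for (Callback cb : trainCallbacks) {
                cb.epochEnd(epoch);
            }
        }
        return Status.SUCCESS;
    }
}
```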

evalModel

public float evalModel()

Execute the model evaluation process.

  • Returns

    The accuracy in the evaluation process.
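A plausible relationship between the concrete `evalModel` and the abstract `getEvalAccuracy` is a template method: the base class runs the evaluation pass through the callbacks, then asks the subclass to turn them into an accuracy. This is a sketch under that assumption only; `predict`, `evalFeatures`, `evalLabels`, `onSample` and `AccuracyTracker` are hypothetical.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical stand-in; not the flclient definition.
interface Callback {
    void onSample(int predicted, int label);
}

class AccuracyTracker implements Callback {
    int correct;
    int total;

    public void onSample(int predicted, int label) {
        total++;
        if (predicted == label) {
            correct++;
        }
    }
}

abstract class SketchEvalClient {
    protected final List<Callback> evalCallbacks = new ArrayList<>();

    // Hypothetical hooks standing in for the model and the eval dataset.
    protected abstract int predict(float[] features);
    protected abstract List<float[]> evalFeatures();
    protected abstract List<Integer> evalLabels();

    // Same shape as Client.getEvalAccuracy.
    protected abstract float getEvalAccuracy(List<Callback> evalCallbacks);

    float evalModel() {
        List<float[]> features = evalFeatures();
        List<Integer> labels = evalLabels();
        for (int i = 0; i < features.size(); i++) {
            int predicted = predict(features.get(i));
            for (Callback cb : evalCallbacks) {
                cb.onSample(predicted, labels.get(i));
            }
        }
        return getEvalAccuracy(evalCallbacks);
    }
}
```

`inferModel` and `getInferResult` can follow the same pattern, with a prediction callback in place of the accuracy one.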

genUnsupervisedEvalData

public Map<String, float[]> genUnsupervisedEvalData(List<Callback> evalCallbacks)

Generate evaluation data for unsupervised training. Subclasses need to override this function.

  • Parameters

    • evalCallbacks: The evaluation callbacks that generate the data.

  • Returns

    The evaluation data for unsupervised training.
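The meaning of the map's keys and values is not specified above. The override below assumes, purely for illustration, that keys are sample identifiers and values are per-sample feature vectors (for example embeddings) captured by a hypothetical `FeatureCollector` callback; `Callback` is a stand-in interface.

```java
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

// Hypothetical stand-in; not the flclient definition.
interface Callback { }

// Hypothetical callback holding one feature vector per sample.
class FeatureCollector implements Callback {
    final Map<String, float[]> features = new LinkedHashMap<>();

    void put(String sampleId, float[] vector) {
        features.put(sampleId, vector);
    }
}

class MyUnsupervisedClient {
    Map<String, float[]> genUnsupervisedEvalData(List<Callback> evalCallbacks) {
        Map<String, float[]> data = new LinkedHashMap<>();
        for (Callback cb : evalCallbacks) {
            if (cb instanceof FeatureCollector) {
                for (Map.Entry<String, float[]> e : ((FeatureCollector) cb).features.entrySet()) {
                    // Copy so later changes inside the callback do not alter the result.
                    float[] v = e.getValue();
                    data.put(e.getKey(), Arrays.copyOf(v, v.length));
                }
            }
        }
        return data;
    }
}
```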

inferModel

public List<Object> inferModel()

Execute model prediction process.

  • Returns

    The prediction result.

setLearningRate

public Status setLearningRate(float lr)

Set the learning rate.

  • Parameters

    • lr: Learning rate.

  • Returns

    Whether the learning rate was set successfully.
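Because `setLearningRate` returns a `Status`, an implementation has room to reject invalid values. The validation rules below (reject non-positive, NaN and infinite rates) are an assumption of this sketch, not documented behavior; `Status` and `LearningRateHolder` are hypothetical stand-ins.

```java
// Hypothetical stand-in; not the flclient definition.
enum Status { SUCCESS, FAILED }

class LearningRateHolder {
    private float learningRate = 0.01f; // hypothetical default

    Status setLearningRate(float lr) {
        // !(lr > 0) also catches NaN; infinity is rejected explicitly.
        if (!(lr > 0.0f) || Float.isInfinite(lr)) {
            return Status.FAILED;
        }
        learningRate = lr;
        return Status.SUCCESS;
    }

    float getLearningRate() {
        return learningRate;
    }
}
```

Callers should check the returned status instead of assuming the new rate took effect.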

setBatchSize

public void setBatchSize(int batchSize)

Set the batch size.

  • Parameters

    • batchSize: Batch size.
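The batch size determines how many steps one epoch takes. The sketch below shows that arithmetic, assuming the last partial batch is kept (whether the real implementation keeps or drops it is not stated above). `BatchPlanner` is hypothetical, and throwing on a non-positive size is a choice of this sketch, since the documented `setBatchSize` returns `void`.

```java
class BatchPlanner {
    private int batchSize = 32; // hypothetical default

    void setBatchSize(int batchSize) {
        if (batchSize <= 0) {
            throw new IllegalArgumentException("batchSize must be positive: " + batchSize);
        }
        this.batchSize = batchSize;
    }

    // Ceiling division: steps needed to see every sample once, keeping the partial batch.
    int stepsPerEpoch(int sampleCount) {
        return (sampleCount + batchSize - 1) / batchSize;
    }
}
```

For example, 100 samples with a batch size of 16 give 7 steps, the last one holding 4 samples.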