Client


import com.mindspore.flclient.model.Client

Client defines the object that executes the device-side federated learning algorithm.

Public Member Functions


abstract List<Callback> initCallbacks(RunType runType, DataSet dataSet)

abstract Map<RunType, Integer> initDataSets(Map<RunType, List<String>> files)

abstract float getEvalAccuracy(List<Callback> evalCallbacks)

abstract List<Integer> getInferResult(List<Callback> inferCallbacks)

Status initSessionAndInputs(String modelPath, MSConfig config)

Status trainModel(int epochs)

float evalModel()

List<Integer> inferModel()

Status saveModel(String modelPath)

List<MSTensor> getFeatures()

Status updateFeatures(String modelName, List<FeatureMap> featureMaps)

void free()

Status setLearningRate(float lr)

void setBatchSize(int batchSize)

initCallbacks

public abstract List<Callback> initCallbacks(RunType runType, DataSet dataSet)

Initialize the callback list.

  • Parameters

    • runType: The run phase the callbacks are created for.

    • dataSet: The dataset used in that phase.

  • Returns

    The initialized callback list.
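One way a subclass might build a different callback list for each run phase is sketched here. To keep it self-contained, `RunType`, `Callback`, and the three callback classes are local stand-ins, not the MindSpore types; the enum constant names, the `name()` method, and the callback classes are assumptions, and the `DataSet` argument is dropped because the sketch does not use it.

```java
import java.util.ArrayList;
import java.util.List;

// Local stand-ins for the real types; constant and class names are hypothetical.
enum RunType { TRAINMODE, EVALMODE, INFERMODE }

interface Callback {
    String name();
}

class LossCallback implements Callback {
    public String name() { return "loss"; }
}

class AccuracyCallback implements Callback {
    public String name() { return "accuracy"; }
}

class PredictCallback implements Callback {
    public String name() { return "predict"; }
}

class CallbackFactory {
    // Mirrors the shape of Client.initCallbacks: one callback list per run phase.
    static List<Callback> initCallbacks(RunType runType) {
        List<Callback> callbacks = new ArrayList<>();
        switch (runType) {
            case TRAINMODE:
                callbacks.add(new LossCallback());     // track training loss
                break;
            case EVALMODE:
                callbacks.add(new AccuracyCallback()); // accumulate eval accuracy
                break;
            case INFERMODE:
                callbacks.add(new PredictCallback());  // collect predicted labels
                break;
        }
        return callbacks;
    }
}
```

Keeping phase selection inside one method lets the same training, evaluation, and inference code paths run with whatever callbacks the subclass registers.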

initDataSets

public abstract Map<RunType, Integer> initDataSets(Map<RunType, List<String>> files)

Initialize the dataset list.

  • Parameters

    • files: Data file paths for each run type.

  • Returns

    The number of data samples for each run type.
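A minimal sketch of how a subclass might turn per-phase file lists into per-phase sample counts. `RunType` is a local stand-in, and `samplesInFile` is a hypothetical hook standing in for whatever parser a real subclass uses to load one data file; it is passed in so the sketch needs no file I/O.

```java
import java.util.EnumMap;
import java.util.List;
import java.util.Map;
import java.util.function.ToIntFunction;

// Local stand-in for the real RunType enum; constant names are assumptions.
enum RunType { TRAINMODE, EVALMODE, INFERMODE }

class DataSetCounter {
    // Mirrors the shape of Client.initDataSets. samplesInFile is hypothetical:
    // it loads one file and reports how many samples it contains.
    static Map<RunType, Integer> initDataSets(Map<RunType, List<String>> files,
                                              ToIntFunction<String> samplesInFile) {
        Map<RunType, Integer> counts = new EnumMap<>(RunType.class);
        for (Map.Entry<RunType, List<String>> entry : files.entrySet()) {
            int total = 0;
            for (String path : entry.getValue()) {
                total += samplesInFile.applyAsInt(path);
            }
            counts.put(entry.getKey(), total); // only phases that were given files appear
        }
        return counts;
    }
}
```

Counts like these are useful later, for example to report how much local data a client trained on.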

getEvalAccuracy

public abstract float getEvalAccuracy(List<Callback> evalCallbacks)

Get the model accuracy from the eval phase.

  • Parameters

    • evalCallbacks: Callbacks used in the eval phase.

  • Returns

    The accuracy in the eval phase.
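A sketch of one possible `getEvalAccuracy` implementation, assuming an eval callback that counts correct predictions across steps. `Callback` and `AccuracyCallback` are hypothetical local types; in a real client the predictions and labels would be read from the model's output tensors rather than passed in as arrays.

```java
import java.util.List;

// Local stand-in for the real callback type.
interface Callback {}

// Hypothetical eval callback that keeps a running count of correct predictions.
class AccuracyCallback implements Callback {
    private int correct = 0;
    private int total = 0;

    // Called once per eval step with that batch's predictions and labels.
    void onStepEnd(int[] predicted, int[] labels) {
        for (int i = 0; i < labels.length; i++) {
            if (predicted[i] == labels[i]) {
                correct++;
            }
        }
        total += labels.length;
    }

    float accuracy() {
        return total == 0 ? 0.0f : (float) correct / total;
    }
}

class EvalAccuracy {
    // Mirrors the shape of Client.getEvalAccuracy: find the accuracy callback
    // among the eval callbacks and report its accumulated value.
    static float getEvalAccuracy(List<Callback> evalCallbacks) {
        for (Callback callback : evalCallbacks) {
            if (callback instanceof AccuracyCallback) {
                return ((AccuracyCallback) callback).accuracy();
            }
        }
        return Float.NaN; // no accuracy callback was registered
    }
}
```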

getInferResult

public abstract List<Integer> getInferResult(List<Callback> inferCallbacks)

Get the result of the infer phase.

  • Parameters

    • inferCallbacks: Callbacks used in the infer phase.

  • Returns

    The prediction results.
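The `List<Integer>` return type suggests one predicted class index per sample. A sketch under that assumption: a hypothetical `PredictCallback` takes the arg-max of each sample's output scores, and `getInferResult` concatenates the stored predictions. Both types are local stand-ins, not MindSpore classes.

```java
import java.util.ArrayList;
import java.util.List;

// Local stand-in for the real callback type.
interface Callback {}

// Hypothetical infer callback: records the arg-max class of each sample's scores.
class PredictCallback implements Callback {
    private final List<Integer> predictions = new ArrayList<>();

    void onStepEnd(float[][] batchScores) {
        for (float[] scores : batchScores) {
            int best = 0;
            for (int c = 1; c < scores.length; c++) {
                if (scores[c] > scores[best]) {
                    best = c;
                }
            }
            predictions.add(best);
        }
    }

    List<Integer> predictions() {
        return predictions;
    }
}

class InferResult {
    // Mirrors the shape of Client.getInferResult: gather predictions from
    // every predict callback in registration order.
    static List<Integer> getInferResult(List<Callback> inferCallbacks) {
        List<Integer> results = new ArrayList<>();
        for (Callback callback : inferCallbacks) {
            if (callback instanceof PredictCallback) {
                results.addAll(((PredictCallback) callback).predictions());
            }
        }
        return results;
    }
}
```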

initSessionAndInputs

public Status initSessionAndInputs(String modelPath, MSConfig config)

Initialize the client runtime session and the input buffers.

  • Parameters

    • modelPath: Model file path.

    • config: Session configuration.

  • Returns

    Whether the initialization is successful.

trainModel

public Status trainModel(int epochs)

Execute the model training process.

  • Parameters

    • epochs: Number of epochs used in the training process.

  • Returns

    Whether training is successful.
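To make the roles of `epochs`, the learning rate, and the batch size concrete, here is a toy mini-batch SGD loop on a hypothetical one-weight linear model `y = w * x`. It is not the MindSpore training loop; `TinyTrainer` and the local `Status` enum are stand-ins, and it assumes one epoch means one full pass over the training data.

```java
// Local stand-in for the real Status type.
enum Status { SUCCESS, FAILED }

// Hypothetical toy trainer; not the MindSpore implementation.
class TinyTrainer {
    private float learningRate = 0.01f; // analogous to setLearningRate
    private int batchSize = 2;          // analogous to setBatchSize
    private float weight = 0.0f;

    void setLearningRate(float lr) { this.learningRate = lr; }
    void setBatchSize(int batchSize) { this.batchSize = batchSize; }
    float weight() { return weight; }

    // One epoch = one full pass over the data, split into mini-batches.
    Status trainModel(int epochs, float[] xs, float[] ys) {
        if (epochs <= 0 || batchSize <= 0 || xs.length != ys.length) {
            return Status.FAILED;
        }
        for (int epoch = 0; epoch < epochs; epoch++) {
            for (int start = 0; start < xs.length; start += batchSize) {
                int end = Math.min(start + batchSize, xs.length);
                // Gradient of the mean squared error over this mini-batch.
                float grad = 0.0f;
                for (int i = start; i < end; i++) {
                    grad += 2.0f * (weight * xs[i] - ys[i]) * xs[i];
                }
                grad /= (end - start);
                weight -= learningRate * grad; // one SGD step per mini-batch
            }
        }
        return Status.SUCCESS;
    }
}
```

With data drawn from `y = 2x`, a batch size of 2, and a learning rate of 0.01, a few hundred epochs drive the weight to about 2.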

evalModel

public float evalModel()

Execute the model evaluation process.

  • Returns

    The accuracy in the evaluation process.

inferModel

public List<Integer> inferModel()

Execute the model inference process.

  • Returns

    The results of the inference process.

saveModel

public Status saveModel(String modelPath)

Save the model.

  • Parameters

    • modelPath: Path where the model file is saved.

  • Returns

    Whether the model is saved successfully.

getFeatures

public List<MSTensor> getFeatures()

Get the feature weights of the model.

  • Returns

    The feature weights of the model.

updateFeatures

public Status updateFeatures(String modelName, List<FeatureMap> featureMaps)

Update model feature weights.

  • Parameters

    • modelName: Model file name.

    • featureMaps: New model weights.

  • Returns

    Whether the update is successful.
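A sketch of the idea behind `updateFeatures`: overwrite each local weight with the new values that share its name, and report failure on an unknown name or a size mismatch. `NamedWeights`, `WeightStore`, and the local `Status` enum are hypothetical stand-ins; the real method takes MindSpore `FeatureMap` objects, and whether it validates all entries before applying any is an assumption made here.

```java
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Local stand-in for the real Status type.
enum Status { SUCCESS, FAILED }

// Hypothetical stand-in for a FeatureMap entry: a weight name plus its values.
class NamedWeights {
    final String name;
    final float[] data;

    NamedWeights(String name, float[] data) {
        this.name = name;
        this.data = data;
    }
}

class WeightStore {
    private final Map<String, float[]> weights = new HashMap<>();

    void put(String name, float[] data) { weights.put(name, data.clone()); }
    float[] get(String name) { return weights.get(name); }

    // Validate every entry first so that a bad entry leaves all weights unchanged.
    Status updateFeatures(List<NamedWeights> featureMaps) {
        for (NamedWeights fm : featureMaps) {
            float[] local = weights.get(fm.name);
            if (local == null || local.length != fm.data.length) {
                return Status.FAILED;
            }
        }
        for (NamedWeights fm : featureMaps) {
            System.arraycopy(fm.data, 0, weights.get(fm.name), 0, fm.data.length);
        }
        return Status.SUCCESS;
    }
}
```

Checking everything before copying anything avoids leaving the model half-updated when the incoming weights do not match its layout.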

free

public void free()

Free the model.
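Assuming `free()` releases the model's runtime resources, a common Java pattern is to call it in a `finally` block so that it still runs when training or evaluation throws. `RecordingClient` is a hypothetical stand-in that only records which methods ran; it is not the MindSpore `Client`.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical stand-in exposing a few Client-like calls; it records call order.
class RecordingClient {
    final List<String> calls = new ArrayList<>();

    boolean trainModel(int epochs) {
        calls.add("trainModel");
        if (epochs <= 0) {
            throw new IllegalArgumentException("epochs must be positive");
        }
        return true;
    }

    float evalModel() {
        calls.add("evalModel");
        return 0.0f;
    }

    void free() {
        calls.add("free");
    }
}

class ClientLifecycle {
    // free() runs whether the round completes normally or throws.
    static void runRound(RecordingClient client, int epochs) {
        try {
            client.trainModel(epochs);
            client.evalModel();
        } finally {
            client.free();
        }
    }
}
```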

setLearningRate

public Status setLearningRate(float lr)

Set learning rate.

  • Parameters

    • lr: Learning rate.

  • Returns

    Whether the learning rate is set successfully.

setBatchSize

public void setBatchSize(int batchSize)

Set batch size.

  • Parameters

    • batchSize: Batch size.