Client
import com.mindspore.flclient.model.Client
Client defines the object that executes the end-side federated learning algorithm.
Public Member Functions
initCallbacks
public abstract List<Callback> initCallbacks(RunType runType, DataSet dataSet)
Initialize the callback list.
Parameters
runType
: RunType class, identifies whether the current phase is training, evaluation, or prediction.
dataSet
: DataSet class, the dataset used in the training, evaluation, or prediction phase.
Returns
The initialized callback list.
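A subclass typically registers different callbacks per phase, for example a loss tracker for training and an accuracy tracker for evaluation. The sketch below shows that dispatch under stated assumptions: `RunType`, `DataSet`, and `Callback` are hypothetical stand-ins (not the MindSpore classes), and `LossCallback`, `AccuracyCallback`, `PredictCallback`, and `ExampleCallbackClient` are illustrative names only.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical stand-ins for the RunType, DataSet and Callback types in the signature;
// they are NOT the MindSpore classes.
enum RunType { TRAIN, EVAL, INFER }

class DataSet {}

interface Callback {
    String name();
}

// Hypothetical callbacks, one per phase.
class LossCallback implements Callback {
    public String name() { return "loss"; }
}

class AccuracyCallback implements Callback {
    public String name() { return "accuracy"; }
}

class PredictCallback implements Callback {
    public String name() { return "predict"; }
}

class ExampleCallbackClient {
    // Sketch of an initCallbacks override: choose the callbacks for the current phase.
    // dataSet is unused here; a real subclass might hand it to callbacks that need it.
    List<Callback> initCallbacks(RunType runType, DataSet dataSet) {
        List<Callback> callbacks = new ArrayList<>();
        switch (runType) {
            case TRAIN:
                callbacks.add(new LossCallback());
                break;
            case EVAL:
                callbacks.add(new AccuracyCallback());
                break;
            case INFER:
                callbacks.add(new PredictCallback());
                break;
            default:
                break;
        }
        return callbacks;
    }
}
```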
initDataSets
public abstract Map<RunType, Integer> initDataSets(Map<RunType, List<String>> files)
Initialize the dataset list.
Parameters
files
: Data files used in the training, evaluation, or prediction phase, keyed by run type.
Returns
The number of data samples for each run type.
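As a minimal sketch of the input and output shapes, the example below returns one sample count per run type. It assumes each file holds exactly one sample (for instance, one image per file); a real subclass would parse the files and count the samples it actually loads. `RunType` is a hypothetical stand-in and `ExampleDataSetClient` is an illustrative name.

```java
import java.util.Collections;
import java.util.EnumMap;
import java.util.List;
import java.util.Map;

// Hypothetical stand-in for RunType; NOT the MindSpore enum.
enum RunType { TRAIN, EVAL, INFER }

class ExampleDataSetClient {
    // Sketch of an initDataSets override. Assumption: one sample per file, so a phase's
    // count is its number of files. Phases with no files are reported as 0.
    Map<RunType, Integer> initDataSets(Map<RunType, List<String>> files) {
        Map<RunType, Integer> counts = new EnumMap<>(RunType.class);
        for (RunType runType : RunType.values()) {
            List<String> phaseFiles = files.getOrDefault(runType, Collections.<String>emptyList());
            counts.put(runType, phaseFiles.size());
        }
        return counts;
    }
}
```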
getEvalAccuracy
public abstract float getEvalAccuracy(List<Callback> evalCallbacks)
Get the model accuracy from the evaluation phase.
Parameters
evalCallbacks
: Callbacks used in the evaluation phase.
Returns
The accuracy in the evaluation phase.
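One way an implementation could obtain the accuracy is to look for the accuracy-tracking callback registered for the evaluation phase and read its value. The sketch below does this with hypothetical stand-ins (`Callback`, `AccuracyCallback`, `LossCallback`, `ExampleEvalClient`); returning `NaN` when no accuracy callback is present is an assumption, not documented behavior.

```java
import java.util.List;

// Hypothetical stand-in for the Callback type; NOT the MindSpore class.
interface Callback {}

// Hypothetical evaluation callback that tallies correct predictions.
class AccuracyCallback implements Callback {
    private int correct;
    private int total;

    void record(int predicted, int label) {
        total++;
        if (predicted == label) {
            correct++;
        }
    }

    float getAccuracy() {
        return total == 0 ? 0.0f : (float) correct / total;
    }
}

class LossCallback implements Callback {}

class ExampleEvalClient {
    // Sketch of a getEvalAccuracy override: return the accuracy of the first
    // AccuracyCallback found. Assumption: NaN signals that none was registered.
    float getEvalAccuracy(List<Callback> evalCallbacks) {
        for (Callback callback : evalCallbacks) {
            if (callback instanceof AccuracyCallback) {
                return ((AccuracyCallback) callback).getAccuracy();
            }
        }
        return Float.NaN;
    }
}
```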
getInferResult
public abstract List<Object> getInferResult(List<Callback> inferCallbacks)
Get the result of the prediction phase.
Parameters
inferCallbacks
: Callbacks used in the prediction phase.
Returns
The prediction results.
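Because the return type is `List<Object>`, the subclass decides what each element is. The sketch below assumes a classification model whose prediction callback stores the argmax class index per sample, and copies those indices into the result list. `Callback`, `PredictCallback`, and `ExampleInferClient` are hypothetical stand-ins, and returning an empty list when no prediction callback exists is an assumption.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

// Hypothetical stand-in for the Callback type; NOT the MindSpore class.
interface Callback {}

// Hypothetical prediction callback that keeps the predicted class index per sample.
class PredictCallback implements Callback {
    private final List<Integer> predictions = new ArrayList<>();

    void record(float[] scores) {
        // Argmax over the class scores gives the predicted label.
        int best = 0;
        for (int i = 1; i < scores.length; i++) {
            if (scores[i] > scores[best]) {
                best = i;
            }
        }
        predictions.add(best);
    }

    List<Integer> getPredictions() {
        return predictions;
    }
}

class ExampleInferClient {
    // Sketch of a getInferResult override: copy the first PredictCallback's predictions
    // into a List<Object>. Assumption: an empty list means no such callback was found.
    List<Object> getInferResult(List<Callback> inferCallbacks) {
        for (Callback callback : inferCallbacks) {
            if (callback instanceof PredictCallback) {
                return new ArrayList<Object>(((PredictCallback) callback).getPredictions());
            }
        }
        return Collections.emptyList();
    }
}
```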
trainModel
public Status trainModel(int epochs)
Execute the model training process.
Parameters
epochs
: Number of epochs used in the training process.
Returns
A Status indicating whether model training succeeded.
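The internals of `trainModel` are not described here. Purely as an illustration of an epoch-driven driver that reports its outcome as a `Status`, the sketch below loops over epochs, notifies per-epoch callbacks, and fails fast on two assumed error conditions (a non-positive epoch count and a `NaN` loss). `Status`, `EpochCallback`, and `TrainLoopSketch` are hypothetical names, not MindSpore types.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical stand-in for the Status type; NOT the MindSpore class.
enum Status { SUCCESS, FAILED }

// Hypothetical per-epoch hook.
interface EpochCallback {
    void epochEnd(int epoch, float loss);
}

abstract class TrainLoopSketch {
    protected final List<EpochCallback> callbacks = new ArrayList<>();

    // One pass over the training data; returns that epoch's loss.
    protected abstract float runOneEpoch(int epoch);

    Status trainModel(int epochs) {
        if (epochs <= 0) {
            return Status.FAILED; // assumption: a non-positive epoch count is rejected
        }
        for (int epoch = 0; epoch < epochs; epoch++) {
            float loss = runOneEpoch(epoch);
            if (Float.isNaN(loss)) {
                return Status.FAILED; // assumption: a diverged loss aborts training
            }
            for (EpochCallback callback : callbacks) {
                callback.epochEnd(epoch, loss);
            }
        }
        return Status.SUCCESS;
    }
}
```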
evalModel
public float evalModel()
Execute the model evaluation process.
Returns
The accuracy in the evaluation process.
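`evalModel` is concrete while `getEvalAccuracy` is abstract, which suggests a template-method split: the base class drives evaluation and the subclass supplies the final number. The sketch below shows that split under assumptions. `EvalTemplateSketch`, `ExampleEvalClient`, `initEvalCallbacks`, and `evalData` are hypothetical, and the model's predictions are precomputed as `{predicted, label}` pairs to keep the example self-contained.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical stand-in for the Callback type; NOT the MindSpore class.
interface Callback {
    void onSample(int predicted, int label);
}

// Hypothetical evaluation callback counting correct predictions.
class AccuracyCallback implements Callback {
    int correct;
    int total;

    public void onSample(int predicted, int label) {
        total++;
        if (predicted == label) {
            correct++;
        }
    }
}

// Hypothetical template: evalModel() feeds every sample to the eval callbacks,
// then delegates the final value to the subclass hook getEvalAccuracy.
abstract class EvalTemplateSketch {
    protected abstract List<Callback> initEvalCallbacks();

    protected abstract int[][] evalData(); // rows of {predicted, label}

    protected abstract float getEvalAccuracy(List<Callback> evalCallbacks);

    float evalModel() {
        List<Callback> callbacks = initEvalCallbacks();
        for (int[] row : evalData()) {
            for (Callback callback : callbacks) {
                callback.onSample(row[0], row[1]);
            }
        }
        return getEvalAccuracy(callbacks);
    }
}

class ExampleEvalClient extends EvalTemplateSketch {
    protected List<Callback> initEvalCallbacks() {
        List<Callback> callbacks = new ArrayList<>();
        callbacks.add(new AccuracyCallback());
        return callbacks;
    }

    protected int[][] evalData() {
        return new int[][] {{1, 1}, {0, 1}, {2, 2}, {0, 0}, {1, 2}};
    }

    protected float getEvalAccuracy(List<Callback> evalCallbacks) {
        AccuracyCallback acc = (AccuracyCallback) evalCallbacks.get(0);
        return acc.total == 0 ? 0.0f : (float) acc.correct / acc.total;
    }
}
```

With three of the five hard-coded pairs matching, this example evaluates to an accuracy of 0.6.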
genUnsupervisedEvalData
public Map<String, float[]> genUnsupervisedEvalData(List<Callback> evalCallbacks)
Generate evaluation data for unsupervised training; subclasses need to override this function.
Parameters
evalCallbacks
: The evaluation callbacks that generate the data.
Returns
The unsupervised training evaluation data.
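The expected contents of the returned `Map<String, float[]>` are not specified here. The sketch below assumes the keys are sample identifiers and the values are per-sample feature vectors (for example, embeddings) collected by an evaluation callback. `Callback`, `FeatureCallback`, `ExampleUnsupervisedClient`, and the `sample_<index>` key format are all hypothetical.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

// Hypothetical stand-in for the Callback type; NOT the MindSpore class.
interface Callback {}

// Hypothetical eval callback that stores one feature vector per sample.
class FeatureCallback implements Callback {
    final List<float[]> features = new ArrayList<>();

    void record(float[] feature) {
        // Defensive copy so later changes to the caller's array do not leak in.
        features.add(Arrays.copyOf(feature, feature.length));
    }
}

class ExampleUnsupervisedClient {
    // Sketch of a genUnsupervisedEvalData override. Assumption: key = sample id,
    // value = that sample's feature vector; only the first FeatureCallback is used.
    Map<String, float[]> genUnsupervisedEvalData(List<Callback> evalCallbacks) {
        Map<String, float[]> data = new LinkedHashMap<>();
        for (Callback callback : evalCallbacks) {
            if (callback instanceof FeatureCallback) {
                List<float[]> features = ((FeatureCallback) callback).features;
                for (int i = 0; i < features.size(); i++) {
                    data.put("sample_" + i, features.get(i));
                }
                return data;
            }
        }
        return data;
    }
}
```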
inferModel
public List<Object> inferModel()
Execute model prediction process.
Returns
The prediction result.
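Since `inferModel` returns `List<Object>`, the caller must know which element type its subclass's `getInferResult` produces and cast accordingly. The sketch below assumes the elements are `Integer` class labels and converts them to an `int[]`, failing loudly on any other type. `InferResultConsumer.toLabels` is a hypothetical helper.

```java
import java.util.List;

class InferResultConsumer {
    // Hypothetical helper for consuming an inferModel()-style List<Object>.
    // Assumption: every element is an Integer class label.
    static int[] toLabels(List<Object> results) {
        int[] labels = new int[results.size()];
        for (int i = 0; i < results.size(); i++) {
            Object item = results.get(i);
            if (!(item instanceof Integer)) {
                throw new IllegalArgumentException("unexpected result type at index " + i + ": " + item);
            }
            labels[i] = (Integer) item;
        }
        return labels;
    }
}
```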
setLearningRate
public Status setLearningRate(float lr)
Set the learning rate.
Parameters
lr
: Learning rate.
Returns
A Status indicating whether the learning rate was set successfully.
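The conditions under which setting the learning rate fails are not stated here. The sketch below illustrates one plausible policy: reject non-finite or non-positive values and keep the previous rate on failure. `Status`, `LearningRateHolder`, and the default of 0.01 are hypothetical.

```java
// Hypothetical stand-in for the Status type; NOT the MindSpore class.
enum Status { SUCCESS, FAILED }

class LearningRateHolder {
    private float lr = 0.01f; // hypothetical default

    // Sketch of a setLearningRate(float)-style setter. Assumption: only finite,
    // strictly positive values are accepted; on failure the old rate is kept.
    Status setLearningRate(float lr) {
        if (Float.isNaN(lr) || Float.isInfinite(lr) || lr <= 0.0f) {
            return Status.FAILED;
        }
        this.lr = lr;
        return Status.SUCCESS;
    }

    float getLearningRate() {
        return lr;
    }
}
```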
setBatchSize
public void setBatchSize(int batchSize)
Set the batch size.
Parameters
batchSize
: Batch size.
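The batch size, together with the per-phase sample counts returned by `initDataSets`, determines how many batches one epoch takes. The sketch below computes this with ceiling division, assuming the last partial batch is kept rather than dropped. `BatchPlanner` is hypothetical, and ignoring non-positive sizes is an assumption, since the documented setter returns `void` and cannot report failure.

```java
class BatchPlanner {
    private int batchSize = 1; // hypothetical default

    // Mirrors setBatchSize(int). Assumption: non-positive values are ignored.
    void setBatchSize(int batchSize) {
        if (batchSize > 0) {
            this.batchSize = batchSize;
        }
    }

    // Batches needed to cover sampleCount samples once (one epoch). The last batch may
    // be partial, hence the ceiling division.
    int batchesPerEpoch(int sampleCount) {
        return (sampleCount + batchSize - 1) / batchSize;
    }
}
```

For example, 100 samples with a batch size of 32 give 4 batches, the last holding only 4 samples.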