Client
import com.mindspore.flclient.model.Client
Client定义了端侧联邦学习算法执行流程对象。
公有成员函数
initCallbacks
public abstract List<Callback> initCallbacks(RunType runType, DataSet dataSet)
初始化callback列表。
参数
runType
: RunType类,标识训练、评估还是预测阶段。dataSet
: DataSet类,训练、评估还是预测阶段数据集。
返回值
初始化的callback列表。
initDataSets
public abstract Map<RunType, Integer> initDataSets(Map<RunType, List<String>> files)
初始化dataset列表。
参数
files
: 训练、评估和预测阶段使用的数据文件。
返回值
训练、评估和预测阶段数据集样本量。
getEvalAccuracy
public abstract float getEvalAccuracy(List<Callback> evalCallbacks)
获取评估阶段的精度。
参数
evalCallbacks
: 评估阶段使用的callback列表。
返回值
评估阶段精度。
getInferResult
public abstract List<Object> getInferResult(List<Callback> inferCallbacks)
获取预测结果。
参数
inferCallbacks
: 预测阶段使用的callback列表。
返回值
预测结果。
trainModel
public Status trainModel(int epochs)
开启模型训练。
参数
epochs
: 训练的epoch数。
返回值
模型训练结果。
evalModel
public float evalModel()
执行模型评估过程。
返回值
模型评估精度。
genUnsupervisedEvalData
public Map<String, float[]> genUnsupervisedEvalData(List<Callback> evalCallbacks)
生成无监督训练评估数据,子类需要覆写该函数。
参数
evalCallbacks
: 推理回调类,该类生成数据。
返回值
无监督训练评估数据。
inferModel
public List<Object> inferModel()
执行模型预测过程。
返回值
模型预测结果。
setLearningRate
public Status setLearningRate(float lr)
设置学习率。
参数
lr
: 学习率。
返回值
设置结果。
setBatchSize
public void setBatchSize(int batchSize)
设置执行批次数。
参数
batchSize
: 批次数。