Implement Device Training Based On Java Interface
Overview
This tutorial demonstrates how to use the Java API of MindSpore Lite by building and deploying a Java version of LeNet. The following first describes how to train a LeNet model locally, and then explains the sample code.
Preparation
Environment Requirements
Downloading MindSpore and Building the Java Package for On-device Training
Clone the source code and build the Java package for MindSpore Lite training. The Linux command is as follows:
git clone https://gitee.com/mindspore/mindspore.git -b r2.0.0-alpha
cd mindspore
bash build.sh -I x86_64 -j8
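In the command, -I x86_64 selects the x86_64 target architecture for the build, and -j8 runs the build with 8 parallel threads.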
For the environment requirements and settings of the build, see Building MindSpore Lite. The build generates the mindspore-lite-${version}-linux-x64.tar.gz package in the output directory, which the run script below consumes.
The sample source code used in this tutorial is in the mindspore/lite/examples/train_lenet_java directory.
Downloading the Dataset
The MNIST dataset used in this example consists of 28 x 28 pixel grayscale images in 10 classes. The training set contains 60,000 images, and the test set contains 10,000 images.
You can download the MNIST dataset from the official website http://yann.lecun.com/exdb/mnist/. The four links on the page provide the training images, training labels, test images, and test labels.
Download and decompress the files to the local host. The decompressed training and test sets are stored in the /PATH/MNIST_Data/train and /PATH/MNIST_Data/test directories, respectively.
The directory structure is as follows:
MNIST_Data/
├── test
│   ├── t10k-images-idx3-ubyte
│   └── t10k-labels-idx1-ubyte
└── train
    ├── train-images-idx3-ubyte
    └── train-labels-idx1-ubyte
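The image and label files are stored in the IDX binary format: a big-endian header (a magic number, the item count, and, for images, the row and column sizes) followed by raw unsigned bytes. The demo's DataSet.java implements its own reader; the following is a minimal independent sketch of parsing these files in Java, in which the MnistReader class name and the [0, 1] pixel scaling are illustrative assumptions, not part of the demo.

import java.io.DataInputStream;
import java.io.FileInputStream;
import java.io.IOException;

public class MnistReader {
    // Reads an IDX3 image file; DataInputStream.readInt() is big-endian,
    // which matches the IDX header layout.
    public static float[][] readImages(String path) throws IOException {
        try (DataInputStream in = new DataInputStream(new FileInputStream(path))) {
            int magic = in.readInt();
            if (magic != 0x00000803) {
                throw new IOException("not an IDX3 image file: " + path);
            }
            int count = in.readInt();
            int rows = in.readInt(); // 28 for MNIST
            int cols = in.readInt(); // 28 for MNIST
            float[][] images = new float[count][rows * cols];
            byte[] buf = new byte[rows * cols];
            for (int i = 0; i < count; i++) {
                in.readFully(buf);
                for (int p = 0; p < buf.length; p++) {
                    images[i][p] = (buf[p] & 0xff) / 255.0f; // scale pixels to [0, 1] (assumption)
                }
            }
            return images;
        }
    }

    // Reads an IDX1 label file into an int array of class indices (0-9).
    public static int[] readLabels(String path) throws IOException {
        try (DataInputStream in = new DataInputStream(new FileInputStream(path))) {
            int magic = in.readInt();
            if (magic != 0x00000801) {
                throw new IOException("not an IDX1 label file: " + path);
            }
            int count = in.readInt();
            byte[] buf = new byte[count];
            in.readFully(buf);
            int[] labels = new int[count];
            for (int i = 0; i < count; i++) {
                labels[i] = buf[i] & 0xff;
            }
            return labels;
        }
    }
}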
Deploying an Application
Building and Running
Go to the directory where the sample project is located and execute it. The commands are as follows:
cd /codes/mindspore/mindspore/lite/examples/train_lenet_java
./prepare_and_run.sh -D /PATH/MNIST_Data/ -r ../../../../output/mindspore-lite-${version}-linux-x64.tar.gz
../resources/model/lenet_tod.ms is the LeNet training model preconfigured in the sample project. You can also generate this model yourself by referring to Creating MindSpore Lite Models.
/PATH/MNIST_Data/ is the path of the MNIST dataset.
The command output is as follows:
......
==========Loading Model, Create Train Session=============
Model path is ../model/lenet_tod.ms
batch_size: 4
virtual batch multiplier: 1
==========Initing DataSet================
train data cnt: 60000
test data cnt: 10000
==========Training Model===================
step_500: Loss is 0.05553353 [min=0.010149269] max_acc=0.9543269
step_1000: Loss is 0.15295759 [min=0.0018140086] max_acc=0.96594554
step_1500: Loss is 0.018035552 [min=0.0018140086] max_acc=0.9704527
step_2000: Loss is 0.029250022 [min=0.0010245014] max_acc=0.9765625
......
==========Evaluating The Trained Model============
accuracy = 0.9770633
Trained model successfully saved: ../model/lenet_tod_trained.ms
Detailed Demo Description
Demo Structure
train_lenet_java
├── lib
├── build.sh
├── model
│   ├── lenet_export.py
│   ├── prepare_model.sh
│   └── train_utils.sh
├── pom.xml
├── prepare_and_run.sh
├── resources
│   └── model
│       └── lenet_tod.ms # LeNet training model
└── src
    └── main
        └── java
            └── com
                └── mindspore
                    └── lite
                        └── train_lenet
                            ├── DataSet.java # MNIST dataset processing
                            ├── Main.java # Main function
                            └── NetRunner.java # Overall training process
Writing On-Device Training Code
For details about how to use Java APIs, visit https://www.mindspore.cn/lite/api/en/r2.0.0-alpha/index.html.
Load the MindSpore Lite model file and build a session.
MSContext context = new MSContext(); // use default param init context
context.init();
boolean isSuccess = context.addDeviceInfo(DeviceType.DT_CPU, false, 0);
TrainCfg trainCfg = new TrainCfg();
trainCfg.init();
model = new Model();
Graph graph = new Graph();
graph.load(modelPath);
model.build(graph, context, trainCfg);
model.setupVirtualBatch(virtualBatch, 0.01f, 1.00f);
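model.setupVirtualBatch enables virtual batching: gradients are accumulated over virtualBatch consecutive runStep() calls before the weights are updated, which simulates a larger batch size within a limited memory budget. The two float arguments (0.01f and 1.00f here) set the learning rate and momentum used together with virtual batching.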
Switch to training mode, perform cyclic iteration, and train the model.
model.setTrainMode(true);
float min_loss = 1000;
float max_acc = 0;
for (int i = 0; i < cycles; i++) {
    for (int b = 0; b < virtualBatch; b++) {
        fillInputData(ds.getTrainData(), false);
        model.runStep();
        float loss = getLoss();
        if (min_loss > loss) {
            min_loss = loss;
        }
        if ((b == 0) && ((i + 1) % 500 == 0)) {
            float acc = calculateAccuracy(10); // evaluate on only 10 batches
            if (max_acc < acc) {
                max_acc = acc;
            }
            System.out.println("step_" + (i + 1) + ": \tLoss is " + loss + " [min=" + min_loss + "]" + " max_acc=" + max_acc);
        }
    }
}
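In the loop above, fillInputData copies one batch of images and labels into the model's input tensors (see DataSet.java and NetRunner.java in the demo sources), and getLoss reads the scalar loss from the session outputs. The following sketch shows one plausible implementation of the loss lookup; it assumes the loss is the only single-element output tensor, and the helper bodies are illustrative rather than copied from the demo. The searchOutputsForSize helper is also used by the evaluation step below.

// Scan the session outputs for a tensor with the expected number of elements.
MSTensor searchOutputsForSize(int size) {
    for (MSTensor tensor : model.getOutputs()) {
        if (tensor.elementsNum() == size) {
            return tensor;
        }
    }
    return null;
}

// Assumption: the loss is the only single-element output of the training graph.
float getLoss() {
    MSTensor tensor = searchOutputsForSize(1);
    return tensor.getFloatData()[0];
}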
Switch to inference mode, perform inference, and evaluate the model accuracy.
model.setTrainMode(false);
for (long i = 0; i < tests; i++) {
    Vector<Integer> labels = fillInputData(test_set, (maxTests == -1));
    if (labels.size() != batchSize) {
        System.err.println("unexpected labels size: " + labels.size() + " batch_size size: " + batchSize);
        System.exit(1);
    }
    model.predict();
    MSTensor outputsv = searchOutputsForSize((int) (batchSize * numOfClasses));
    if (outputsv == null) {
        System.err.println("can not find output tensor with size: " + batchSize * numOfClasses);
        System.exit(1);
    }
    float[] scores = outputsv.getFloatData();
    for (int b = 0; b < batchSize; b++) {
        int max_idx = 0;
        float max_score = scores[(int) (numOfClasses * b)];
        for (int c = 0; c < numOfClasses; c++) {
            if (scores[(int) (numOfClasses * b + c)] > max_score) {
                max_score = scores[(int) (numOfClasses * b + c)];
                max_idx = c;
            }
        }
        if (labels.get(b) == max_idx) {
            accuracy += 1.0;
        }
    }
}
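Each correctly classified sample increments accuracy by 1, so after the loop the final accuracy is the accumulated count divided by the total number of evaluated samples, for example:

accuracy /= (batchSize * tests); // fraction of correctly classified test samples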
After the inference is complete, if you need to continue the training, switch to training mode.
Save the trained model.
// arg 0: FileName
// arg 1: quantization type QT_DEFAULT -> 0
// arg 2: model type MT_TRAIN -> 0
// arg 3: use default output tensor names
model.export(trainedFilePath, 0, false, null);
After the model training is complete, save the model to the specified path. You can then reload the trained model to run inference or continue training.
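For example, the exported model can be reloaded with the same session-creation code as in the first step; the file path below matches the export target shown in the run output above:

// Rebuild a session from the exported model and switch to inference mode.
MSContext context = new MSContext();
context.init();
context.addDeviceInfo(DeviceType.DT_CPU, false, 0);
TrainCfg trainCfg = new TrainCfg();
trainCfg.init();
Graph graph = new Graph();
graph.load("../model/lenet_tod_trained.ms"); // file produced by model.export above
Model model = new Model();
model.build(graph, context, trainCfg);
model.setTrainMode(false); // run inference, or setTrainMode(true) to continue training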