Implementing a Sentiment Classification Application (Android)
In privacy-compliance scenarios, federated learning based on device-cloud synergy makes full use of the data on devices while keeping sensitive user data from being reported directly to the cloud. Among the application scenarios explored for federated learning, the input method stands out: users care both about the privacy of what they type and about intelligent input features, so federated learning is a natural fit. MindSpore Federated applies a federated language model to the emoji prediction function of the input method. Based on the chat text, the model recommends emojis that suit the current context. For federated modeling, each emoji is defined as a sentiment label category, and each chat phrase corresponds to one emoji. MindSpore Federated therefore treats emoji prediction as a federated sentiment classification task.
Preparations
Environment
For details, see Server Environment Configuration and Client Environment Configuration.
Data
The training data contains 20 user chat files. The directory structure is as follows:
datasets/supervise/train/
├── 0.txt # Training data of user 0
├── 1.txt # Training data of user 1
│
│ ......
│
└── 19.txt # Training data of user 19
The validation data contains one chat file. The directory structure is as follows:
datasets/supervise/eval/
└── eval.txt # Validation data
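Each training file plays the role of one client's local dataset. As a minimal sketch, assuming you want to hand one file to each simulated client, the helper below resolves the per-user paths from the listings above. The class and method names are hypothetical and are not MindSpore Federated APIs.

```java
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.List;

// Hypothetical helper (not part of MindSpore Federated): resolves the
// per-user files of the dataset layout shown above.
class DatasetLayout {
    static final String TRAIN_DIR = "datasets/supervise/train";
    static final String EVAL_FILE = "datasets/supervise/eval/eval.txt";
    static final int NUM_USERS = 20; // 0.txt ... 19.txt

    // Training file of one simulated user, e.g. user 3 -> datasets/supervise/train/3.txt
    static Path trainFileForUser(int userId) {
        if (userId < 0 || userId >= NUM_USERS) {
            throw new IllegalArgumentException("userId must be in [0, " + NUM_USERS + ")");
        }
        return Paths.get(TRAIN_DIR, userId + ".txt");
    }

    // All training files, one per user, in user-ID order.
    static List<Path> allTrainFiles() {
        List<Path> files = new ArrayList<>();
        for (int userId = 0; userId < NUM_USERS; userId++) {
            files.add(trainFileForUser(userId));
        }
        return files;
    }

    public static void main(String[] args) {
        System.out.println(trainFileForUser(3));    // datasets/supervise/train/3.txt on Unix-like systems
        System.out.println(allTrainFiles().size()); // 20
        System.out.println(Paths.get(EVAL_FILE));
    }
}
```

In the Android demo, a single device holds only its own file (0.txt in the assets tree), so a helper like this is mainly useful when simulating many clients on one machine.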
The labels in the training data and validation data correspond to four types of emojis: good, leimu, xiaoku, and xin.
Defining the Network
Federated learning here uses the ALBERT language model [1]. On the client, the ALBERT model consists of an embedding layer, an encoder layer, and a classifier layer.
For details about the network definition, see the source code.
Generating a Device Model File
Exporting a Model as a MindIR File
The sample code is as follows:
import numpy as np
from mindspore import export, Tensor
from src.config import train_cfg, client_net_cfg
from src.cell_wrapper import NetworkTrainCell
# Build a model.
client_network_train_cell = NetworkTrainCell(client_net_cfg)
# Build input data.
input_ids = Tensor(np.zeros((train_cfg.batch_size, client_net_cfg.seq_length), dtype=np.int32))
attention_mask = Tensor(np.zeros((train_cfg.batch_size, client_net_cfg.seq_length), dtype=np.int32))
token_type_ids = Tensor(np.zeros((train_cfg.batch_size, client_net_cfg.seq_length), dtype=np.int32))
label_ids = Tensor(np.zeros((train_cfg.batch_size, client_net_cfg.num_labels), dtype=np.int32))
# Export the model.
export(client_network_train_cell, input_ids, attention_mask, token_type_ids, label_ids, file_name='albert_train.mindir', file_format='MINDIR')
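The zero tensors above only fix the input shapes for export: three int32 inputs of shape (batch_size, seq_length) plus label_ids of shape (batch_size, num_labels). The sketch below shows one plausible way a single tokenized phrase maps onto one row of those inputs. Padding ID 0, a single text segment, plain truncation, and one-hot labels are all assumptions (the label_ids shape suggests one-hot encoding, but the source does not say so), and the class is hypothetical rather than the client's real preprocessing code.

```java
import java.util.Arrays;

// Hypothetical sketch, not the MindSpore Federated preprocessing code.
// Assumptions: padding ID 0, one text segment, plain truncation, one-hot labels.
class AlbertInputs {
    final int[] inputIds;      // token IDs, zero-padded to seqLength
    final int[] attentionMask; // 1 = real token, 0 = padding
    final int[] tokenTypeIds;  // single segment: all zeros

    AlbertInputs(int[] tokenIds, int seqLength) {
        inputIds = new int[seqLength];      // Java zero-initializes arrays
        attentionMask = new int[seqLength];
        tokenTypeIds = new int[seqLength];
        int n = Math.min(tokenIds.length, seqLength); // truncate overly long inputs
        for (int i = 0; i < n; i++) {
            inputIds[i] = tokenIds[i];
            attentionMask[i] = 1;
        }
    }

    // One-hot row for label_ids, whose exported shape is (batch_size, num_labels).
    static int[] oneHot(int label, int numLabels) {
        if (label < 0 || label >= numLabels) {
            throw new IllegalArgumentException("label out of range: " + label);
        }
        int[] row = new int[numLabels];
        row[label] = 1;
        return row;
    }

    public static void main(String[] args) {
        AlbertInputs x = new AlbertInputs(new int[] {101, 7, 9, 102}, 6); // example token IDs
        System.out.println(Arrays.toString(x.inputIds));      // [101, 7, 9, 102, 0, 0]
        System.out.println(Arrays.toString(x.attentionMask)); // [1, 1, 1, 1, 0, 0]
        System.out.println(Arrays.toString(oneHot(2, 4)));    // [0, 0, 1, 0]
    }
}
```

Stacking batch_size such rows gives exactly the (batch_size, seq_length) tensors passed to export, which is why zero-filled placeholders of those shapes are enough to trace the network.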
Converting the MindIR File into an MS File that Can Be Used by the Federated Learning Framework on the Device
For details about how to generate a model file on the device, see Implementing an Image Classification Application.
Starting the Federated Learning Process
Start the script on the server. For details, see Cloud-based Deployment.
The overall process for running ALBERT training and inference tasks on Android is as follows:
Create an Android project.
Build the MindSpore Lite AAR package.
Describe the Android instance program structure.
Write code.
Configure Android project dependencies.
Build and run on Android.
Creating an Android Project
Create a project in Android Studio and install the corresponding SDK. (After you specify the SDK version, Android Studio installs it automatically.)
Building the MindSpore Lite AAR Package
For details, see Federated Learning Deployment.
Name of the generated Android AAR package:
mindspore-lite-full-{version}.aar
Place the AAR package in the app/libs/ directory of the Android project.
Android Instance Program Structure
app
├── libs                                   # Binary archive files of the Android library project
│   └── mindspore-lite-full-{version}.aar  # MindSpore Lite archive file (Android version)
├── src/main
│   ├── assets                             # Resource directory
│   │   ├── model                          # Model directory
│   │   │   ├── albert_ad_train.mindir.ms  # Pre-trained model file
│   │   │   └── albert_ad_infer.mindir.ms  # Inference model file
│   │   └── data                           # Data directory
│   │       ├── 0.txt                      # Training data file
│   │       ├── vocab.txt                  # Dictionary file
│   │       ├── vocab_map_ids.txt          # Dictionary ID mapping file
│   │       ├── eval.txt                   # Evaluation data for the trained model
│   │       └── eval_no_label.txt          # Inference data file
│   ├── java                               # Application code at the Java layer
│   │   └── ...                            # Android code files; the directory layout can be customized
│   ├── res                                # Android resource files
│   └── AndroidManifest.xml                # Android configuration file
├── build.gradle                           # Android project build file
├── download.gradle                        # Downloads the project dependency files
└── ...
Writing Code
AssetCopyer.java: copies the resource files in the app/src/main/assets directory of the Android project to the disk of the Android system, so that the federated learning framework APIs can read them by absolute path during model training and inference.
import android.content.Context;

import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.util.logging.Logger;

public class AssetCopyer {
    private static final Logger LOGGER = Logger.getLogger(AssetCopyer.class.getName());

    public static void copyAllAssets(Context context, String destination) {
        LOGGER.info("destination: " + destination);
        copyAssetsToDst(context, "", destination);
    }

    // Copy the resource files in the assets directory to the disk of the Android system.
    // You can view the specific path by printing destination.
    private static void copyAssetsToDst(Context context, String srcPath, String dstPath) {
        try {
            // List the entries under srcPath; an empty array means srcPath is a file.
            String[] fileNames = context.getAssets().list(srcPath);
            if (fileNames != null && fileNames.length > 0) {
                // srcPath is a directory: create the destination directory.
                File file = new File(dstPath);
                file.mkdirs();
                for (String fileName : fileNames) {
                    // Recurse into each entry of the directory.
                    if (!srcPath.equals("")) {
                        copyAssetsToDst(context, srcPath + "/" + fileName, dstPath + "/" + fileName);
                    } else {
                        copyAssetsToDst(context, fileName, dstPath + "/" + fileName);
                    }
                }
            } else {
                // srcPath is a file: copy it through a 1024-byte buffer.
                // try-with-resources closes both streams even if an exception is thrown.
                try (InputStream is = context.getAssets().open(srcPath);
                     OutputStream fos = new FileOutputStream(new File(dstPath))) {
                    byte[] buffer = new byte[1024];
                    int byteCount;
                    while ((byteCount = is.read(buffer)) != -1) {
                        fos.write(buffer, 0, byteCount);
                    }
                    // Flush the output stream.
                    fos.flush();
                }
            }
        } catch (IOException e) {
            e.printStackTrace();
        }
    }
}
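The file branch of copyAssetsToDst is an ordinary buffered stream copy. The sketch below isolates that loop in plain Java with no Android dependency (the class name is hypothetical) so it can be tested off-device. On Android, the input stream would come from context.getAssets().open(srcPath) and the output stream from a FileOutputStream; opening both in try-with-resources closes them even when read() or write() throws.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.util.Arrays;

// Hypothetical off-device sketch of the copy loop used in AssetCopyer.copyAssetsToDst.
class StreamCopy {
    // Copies every byte from in to out through a 1024-byte buffer; returns the byte count.
    static long copy(InputStream in, OutputStream out) throws IOException {
        byte[] buffer = new byte[1024];
        long total = 0;
        int byteCount;
        while ((byteCount = in.read(buffer)) != -1) {
            out.write(buffer, 0, byteCount);
            total += byteCount;
        }
        out.flush();
        return total;
    }

    public static void main(String[] args) throws IOException {
        byte[] data = new byte[2500]; // larger than the buffer, so several reads are needed
        for (int i = 0; i < data.length; i++) {
            data[i] = (byte) i;
        }
        // try-with-resources closes both streams even if copy() throws.
        try (InputStream in = new ByteArrayInputStream(data);
             ByteArrayOutputStream out = new ByteArrayOutputStream()) {
            long copied = copy(in, out);
            // prints "2500 bytes, identical: true"
            System.out.println(copied + " bytes, identical: " + Arrays.equals(data, out.toByteArray()));
        }
    }
}
```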
FlJob.java: defines the training and inference tasks. For details about the federated learning APIs, see Federated Learning APIs.
import android.annotation.SuppressLint;
import android.os.Build;

import androidx.annotation.RequiresApi;

// Assumed package locations of BindMode and RunType; check them against the flclient version in your AAR.
import com.mindspore.flclient.BindMode;
import com.mindspore.flclient.FLParameter;
import com.mindspore.flclient.SyncFLJob;
import com.mindspore.flclient.model.RunType;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.logging.Logger;

public class FlJob {
    private static final Logger LOGGER = Logger.getLogger(FlJob.class.getName());
    // Directory that AssetCopyer copied the assets to; data and model paths are resolved against it.
    private final String parentPath;

    public FlJob(String parentPath) {
        this.parentPath = parentPath;
    }

    // Android federated learning training task
    @SuppressLint("NewApi")
    @RequiresApi(api = Build.VERSION_CODES.M)
    public void syncJobTrain() {
        // Create dataMap. The relative paths must match the layout of the copied assets.
        String trainTxtPath = parentPath + "/data/albert/supervise/client/1.txt";
        // Optional: only needed to verify model accuracy after getModel.
        String evalTxtPath = parentPath + "/data/albert/supervise/eval/eval.txt";
        // Path of the dictionary file used for data preprocessing.
        String vocabFile = parentPath + "/data/albert/supervise/vocab.txt";
        // Path of the dictionary ID mapping file.
        String idsFile = parentPath + "/data/albert/supervise/vocab_map_ids.txt";
        Map<RunType, List<String>> dataMap = new HashMap<>();
        List<String> trainPath = new ArrayList<>();
        trainPath.add(trainTxtPath);
        trainPath.add(vocabFile);
        trainPath.add(idsFile);
        dataMap.put(RunType.TRAINMODE, trainPath);
        // Optional: the EVALMODE entry is only needed to verify model accuracy after getModel.
        List<String> evalPath = new ArrayList<>();
        evalPath.add(evalTxtPath);
        evalPath.add(vocabFile);
        evalPath.add(idsFile);
        dataMap.put(RunType.EVALMODE, evalPath);

        String flName = "com.mindspore.flclient.demo.albert.AlbertClient"; // Package path of AlbertClient.java
        String trainModelPath = parentPath + "/ms/albert/train/albert_ad_train.mindir0.ms"; // Absolute path
        String inferModelPath = parentPath + "/ms/albert/train/albert_ad_train.mindir0.ms"; // Absolute path, same as trainModelPath
        String sslProtocol = "TLSv1.2";
        String deployEnv = "android";
        // URL for device-cloud communication. Make sure the Android device can reach the server;
        // otherwise, the message "connection failed" is displayed.
        String domainName = "http://10.*.*.*:6668";
        boolean ifUseElb = true;
        int serverNum = 4;
        int threadNum = 4;
        BindMode cpuBindMode = BindMode.NOT_BINDING_CORE;
        int batchSize = 32;

        FLParameter flParameter = FLParameter.getInstance();
        flParameter.setFlName(flName);
        flParameter.setDataMap(dataMap);
        flParameter.setTrainModelPath(trainModelPath);
        flParameter.setInferModelPath(inferModelPath);
        flParameter.setSslProtocol(sslProtocol);
        flParameter.setDeployEnv(deployEnv);
        flParameter.setDomainName(domainName);
        flParameter.setUseElb(ifUseElb);
        flParameter.setServerNum(serverNum);
        flParameter.setThreadNum(threadNum);
        flParameter.setCpuBindMode(cpuBindMode);
        flParameter.setBatchSize(batchSize);

        // Start the federated learning job.
        SyncFLJob syncFLJob = new SyncFLJob();
        syncFLJob.flJobRun();
    }

    // Android federated learning inference task
    public void syncJobPredict() {
        // Create dataMap: inference data file, dictionary file, dictionary ID mapping file.
        String inferTxtPath = parentPath + "/data/albert/supervise/eval/eval.txt";
        String vocabFile = parentPath + "/data/albert/supervise/vocab.txt";
        String idsFile = parentPath + "/data/albert/supervise/vocab_map_ids.txt";
        Map<RunType, List<String>> dataMap = new HashMap<>();
        List<String> inferPath = new ArrayList<>();
        inferPath.add(inferTxtPath);
        inferPath.add(vocabFile);
        inferPath.add(idsFile);
        dataMap.put(RunType.INFERMODE, inferPath);

        String flName = "com.mindspore.flclient.demo.albert.AlbertClient"; // Package path of AlbertClient.java
        String inferModelPath = parentPath + "/ms/albert/train/albert_ad_train.mindir0.ms"; // Absolute path, same as trainModelPath
        int threadNum = 4;
        BindMode cpuBindMode = BindMode.NOT_BINDING_CORE;
        int batchSize = 32;

        FLParameter flParameter = FLParameter.getInstance();
        flParameter.setFlName(flName);
        flParameter.setDataMap(dataMap);
        flParameter.setInferModelPath(inferModelPath);
        flParameter.setThreadNum(threadNum);
        flParameter.setCpuBindMode(cpuBindMode);
        flParameter.setBatchSize(batchSize);

        // Run inference and log the predicted labels.
        SyncFLJob syncFLJob = new SyncFLJob();
        int[] labels = syncFLJob.modelInference();
        LOGGER.info("labels = " + Arrays.toString(labels));
    }
}
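The dataMap passed to FLParameter.setDataMap maps each run type to a list of files in a fixed order: the data file, the dictionary file, and the dictionary ID mapping file. The stand-alone sketch below reproduces that assembly so it can be checked without the AAR. Mode is a hypothetical stand-in for the flclient RunType enum, and resolving the relative paths against parentPath (the directory AssetCopyer copies to) is an assumption based on the framework reading files by absolute path.

```java
import java.util.Arrays;
import java.util.EnumMap;
import java.util.List;
import java.util.Map;

// Hypothetical sketch; Mode stands in for the flclient RunType enum.
class DataMapSketch {
    enum Mode { TRAINMODE, EVALMODE, INFERMODE }

    static final String ROOT = "data/albert/supervise";

    // Resolve a path relative to the copied assets against the app's disk directory.
    static String abs(String parentPath, String relative) {
        return parentPath + "/" + relative;
    }

    // Each list holds: data file, dictionary file, dictionary ID mapping file.
    static Map<Mode, List<String>> buildTrainDataMap(String parentPath, boolean withEval) {
        String vocab = abs(parentPath, ROOT + "/vocab.txt");
        String ids = abs(parentPath, ROOT + "/vocab_map_ids.txt");
        Map<Mode, List<String>> dataMap = new EnumMap<>(Mode.class);
        dataMap.put(Mode.TRAINMODE, Arrays.asList(abs(parentPath, ROOT + "/client/1.txt"), vocab, ids));
        if (withEval) { // only needed to verify model accuracy after getModel
            dataMap.put(Mode.EVALMODE, Arrays.asList(abs(parentPath, ROOT + "/eval/eval.txt"), vocab, ids));
        }
        return dataMap;
    }

    public static void main(String[] args) {
        // "/example/files" is a placeholder for getExternalFilesDir(null).getAbsolutePath().
        Map<Mode, List<String>> dataMap = buildTrainDataMap("/example/files", true);
        dataMap.forEach((mode, files) -> System.out.println(mode + " -> " + files));
    }
}
```

An EnumMap keeps the entries in declaration order, which makes the printed map easy to compare with the file lists in FlJob.java.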
MainActivity.java: This code file is used to start federated learning training and inference tasks.
import android.os.Build;
import android.os.Bundle;

import androidx.annotation.RequiresApi;
import androidx.appcompat.app.AppCompatActivity;

import com.huawei.flAndroid.job.FlJob;
import com.huawei.flAndroid.utils.AssetCopyer;

@RequiresApi(api = Build.VERSION_CODES.P)
public class MainActivity extends AppCompatActivity {
    private String parentPath;

    @Override
    protected void onCreate(Bundle savedInstanceState) {
        super.onCreate(savedInstanceState);
        // Obtain the disk path of the application in the Android system.
        this.parentPath = this.getExternalFilesDir(null).getAbsolutePath();
        // Copy the resource files in the assets directory to the disk of the Android system.
        AssetCopyer.copyAllAssets(this.getApplicationContext(), parentPath);
        // Create a thread and start the federated learning training and inference tasks.
        new Thread(() -> {
            FlJob flJob = new FlJob(parentPath);
            flJob.syncJobTrain();
            flJob.syncJobPredict();
        }).start();
    }
}
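MainActivity runs both tasks on a worker thread because they block on network I/O and model execution, and Android rejects network calls on the main thread (NetworkOnMainThreadException). The sketch below is a hypothetical alternative to the raw Thread, not the demo's code: a single-thread ExecutorService keeps the same train-then-predict order and lets the caller shut the worker down cleanly. The Runnable arguments stand in for flJob::syncJobTrain and flJob::syncJobPredict.

```java
import java.util.List;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

// Hypothetical alternative to the raw Thread in MainActivity: a single-thread
// executor that runs training and then inference off the UI thread.
class FlTaskRunner {
    // Submits one task that runs train to completion before starting predict.
    static Future<?> runInOrder(ExecutorService executor, Runnable train, Runnable predict) {
        return executor.submit(() -> {
            train.run();
            predict.run();
        });
    }

    public static void main(String[] args) throws InterruptedException, ExecutionException {
        ExecutorService executor = Executors.newSingleThreadExecutor();
        List<String> log = new CopyOnWriteArrayList<>();
        // On Android these would be flJob::syncJobTrain and flJob::syncJobPredict.
        Future<?> done = runInOrder(executor, () -> log.add("train"), () -> log.add("predict"));
        done.get(); // blocking is fine in this demo; never block the UI thread like this
        executor.shutdown();
        System.out.println(log); // [train, predict]
    }
}
```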
Configuring Android Project Dependencies
AndroidManifest.xml
<?xml version="1.0" encoding="utf-8"?>
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
    package="com.huawei.flAndroid">
    <!-- Allow network access. -->
    <uses-permission android:name="android.permission.INTERNET" />
    <application
        android:allowBackup="true"
        android:supportsRtl="true"
        android:usesCleartextTraffic="true"
        android:theme="@style/Theme.Flclient">
        <!-- Customize the location of the MainActivity file. -->
        <activity android:name="com.huawei.flAndroid.activity.MainActivity">
            <intent-filter>
                <action android:name="android.intent.action.MAIN" />
                <category android:name="android.intent.category.LAUNCHER" />
            </intent-filter>
        </activity>
    </application>
</manifest>
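android:usesCleartextTraffic="true" matters because the domainName in FlJob.java uses plain http://, and since Android 9 (API level 28) cleartext traffic is disabled by default for apps that target API level 28 or higher (this project targets 30). The hypothetical helper below makes that dependency explicit by checking the URL scheme; the example addresses are placeholders.

```java
import java.net.URI;

// Hypothetical helper: does the device-cloud URL need cleartext HTTP?
class CleartextCheck {
    // True when the URL uses plain http, which the manifest must explicitly allow
    // via android:usesCleartextTraffic="true" for apps targeting API level 28 or higher.
    static boolean needsCleartext(String domainName) {
        String scheme = URI.create(domainName).getScheme();
        return "http".equalsIgnoreCase(scheme);
    }

    public static void main(String[] args) {
        System.out.println(needsCleartext("http://192.168.1.10:6668"));   // true
        System.out.println(needsCleartext("https://fl.example.com:6668")); // false
    }
}
```

If the server is reached over HTTPS instead, the attribute can be dropped and the sslProtocol setting in FlJob.java applies.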
app/build.gradle
plugins {
    id 'com.android.application'
}

android {
    // Android SDK build version. A version later than 27 is recommended.
    compileSdkVersion 30
    buildToolsVersion "30.0.3"

    defaultConfig {
        applicationId "com.huawei.flAndroid"
        minSdkVersion 27
        targetSdkVersion 30
        versionCode 1
        versionName "1.0"
        multiDexEnabled true
        testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner"
        ndk {
            // Different phone models need different ABIs; 'armeabi-v7a' is used for the Mate 20.
            abiFilters 'armeabi-v7a'
        }
    }
    // Specified NDK version
    ndkVersion '21.3.6528147'
    sourceSets {
        main {
            // Specified JNI directory
            jniLibs.srcDirs = ['libs']
            jni.srcDirs = []
        }
    }
    compileOptions {
        sourceCompatibility JavaVersion.VERSION_1_8
        targetCompatibility JavaVersion.VERSION_1_8
    }
}

dependencies {
    // Scan the libs directory for AAR packages.
    implementation fileTree(dir: 'libs', include: ['*.aar'])
    implementation 'androidx.appcompat:appcompat:1.1.0'
    implementation 'com.google.android.material:material:1.1.0'
    implementation 'androidx.constraintlayout:constraintlayout:1.1.3'
    androidTestImplementation 'androidx.test.ext:junit:1.1.1'
    androidTestImplementation 'androidx.test.espresso:espresso-core:3.2.0'
    implementation 'com.android.support:multidex:1.0.3'

    // Third-party open source software that federated learning depends on
    implementation group: 'com.squareup.okhttp3', name: 'okhttp', version: '3.14.9'
    implementation group: 'com.google.flatbuffers', name: 'flatbuffers-java', version: '2.0.0'
    implementation(group: 'org.bouncycastle', name: 'bcprov-jdk15on', version: '1.68')
}
Building and Running on Android
Connect the Android device to your computer through a USB cable for debugging, then click Run 'app' to run the federated learning training and inference task on the device. For details about how to connect Android Studio to a device for debugging, see https://developer.android.com/studio/run/device. Android Studio can identify the phone only when USB debugging is enabled on it. On Huawei phones, enable USB debugging by choosing Settings > System & updates > Developer options > USB debugging.
Continue the installation on the Android device. After the installation is complete, start the app to run federated learning training and inference of the ALBERT model.
The program running result is as follows:
I/SyncFLJob: <FLClient> [model inference] inference finish
I/SyncFLJob: labels = [2, 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 4, 4, 4, 4]
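syncJobPredict() logs only the raw array of predicted label indices. To read such a log more easily, the hypothetical helper below counts how often each index occurs; which emoji each index stands for is not specified here, so the indices are kept as-is.

```java
import java.util.Map;
import java.util.TreeMap;

// Hypothetical helper: count how many inputs were assigned each predicted label index.
class LabelSummary {
    static Map<Integer, Integer> countLabels(int[] labels) {
        Map<Integer, Integer> counts = new TreeMap<>(); // sorted by label index
        for (int label : labels) {
            counts.merge(label, 1, Integer::sum);
        }
        return counts;
    }

    public static void main(String[] args) {
        // The label array from the log above.
        int[] labels = {2, 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 4, 4, 4, 4};
        System.out.println(countLabels(labels)); // {0=4, 1=4, 2=5, 4=4}
    }
}
```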
Experiment Result
The total number of federated learning iterations is 5, the number of epochs for local training on the client is 10, and the value of batchSize is 16.
| | Top 1 Accuracy | Top 5 Accuracy |
|---|---|---|
| ALBERT | 24% | 70% |
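The table reports Top-1 and Top-5 accuracy. Under the usual definition, a sample counts as correct for Top-k if its true label is among the k highest-scoring classes. The sketch below implements that definition; it is not the evaluation code used for this experiment, the scores are made-up inputs, and ties are resolved in favor of the true label.

```java
// Hypothetical sketch of Top-k accuracy under the usual definition.
class TopKAccuracy {
    // True if trueLabel is among the k highest scores (ties favor the true label).
    static boolean inTopK(float[] scores, int trueLabel, int k) {
        int higher = 0; // classes scoring strictly higher than the true label
        for (int c = 0; c < scores.length; c++) {
            if (scores[c] > scores[trueLabel]) {
                higher++;
            }
        }
        return higher < k;
    }

    // Fraction of samples whose true label is in their top-k classes.
    static double topKAccuracy(float[][] scores, int[] labels, int k) {
        if (labels.length == 0 || scores.length != labels.length) {
            throw new IllegalArgumentException("need one score row per label");
        }
        int correct = 0;
        for (int i = 0; i < labels.length; i++) {
            if (inTopK(scores[i], labels[i], k)) {
                correct++;
            }
        }
        return (double) correct / labels.length;
    }

    public static void main(String[] args) {
        float[][] scores = {{0.1f, 0.7f, 0.2f}, {0.5f, 0.3f, 0.2f}};
        int[] labels = {1, 2};
        System.out.println(topKAccuracy(scores, labels, 1)); // 0.5
        System.out.println(topKAccuracy(scores, labels, 3)); // 1.0
    }
}
```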
References
[1] Lan Z, Chen M, Goodman S, et al. ALBERT: A Lite BERT for Self-supervised Learning of Language Representations[J]. 2019.