In partnership with Paperspace
One of the key challenges of machine learning is the need for large amounts of data. Gathering training datasets for machine learning models poses privacy, security, and processing risks that organizations would rather avoid.
One technique that can help address some of these challenges is “federated learning.” By distributing the training of models across user devices, federated learning makes it possible to take advantage of machine learning while minimizing the need to collect user data.
Cloud-based machine learning
The traditional process for developing machine learning applications is to gather a large dataset, train a model on the data, and run the trained model on a cloud server that users can reach through different applications such as web search, translation, text generation, and image processing.
Every time the application wants to use the machine learning model, it has to send the user’s data to the server where the model resides.
In many cases, sending data to the server is unavoidable. Content recommendation systems are one example: part of the data and content needed for machine learning inference resides on the cloud server.
But in applications such as text autocompletion or facial recognition, the data is local to the user and the device. In these cases, it would be preferable for the data to stay on the user’s device instead of being sent to the cloud.
Fortunately, advances in edge AI have made it possible to avoid sending sensitive user data to application servers. Sometimes referred to as TinyML, this active area of research aims to create machine learning models that fit on smartphones and other user devices. These models make it possible to perform inference on the device itself. Large tech companies are trying to bring some of their machine learning applications to users’ devices to improve privacy.
On-device machine learning has several added benefits. These applications can continue to work even when the device is not connected to the internet. They also save bandwidth when users are on metered connections. And in many applications, on-device inference is more energy-efficient than sending data to the cloud.
Training on-device machine learning models
On-device inference is an important privacy upgrade for machine learning applications. But one challenge remains: Developers still need data to train the models they will push to users’ devices. This isn’t a problem when the organization developing the models already owns the data (e.g., a bank owns its transactions) or the data is public knowledge (e.g., Wikipedia or news articles).
But if a company wants to train machine learning models on confidential user information such as emails, chat logs, or personal photos, collecting training data entails many challenges. The company will have to make sure its collection and storage policies conform to the various data protection regulations and that the data is anonymized to remove personally identifiable information (PII).
Once the machine learning model is trained, the development team must decide whether to keep or discard the training data. It will also need a policy and procedure for continuing to collect data from users so it can retrain and update its models regularly.
This is the problem federated learning addresses.
The main idea behind federated learning is to train a machine learning model on user data without the need to transfer that data to cloud servers.
Federated learning starts with a base machine learning model in the cloud server. This model is either trained on public data (e.g., Wikipedia articles or the ImageNet dataset) or has not been trained at all.
In the next stage, several user devices volunteer to train the model. These devices hold user data that is relevant to the model’s application, such as chat logs and keystrokes.
These devices download the base model at a suitable time, for instance when they are on a Wi-Fi network and connected to a power outlet (training is a compute-intensive operation and will drain the device’s battery if done at the wrong time). Then they train the model on the device’s local data.
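The on-device step can be sketched in a few lines of Python. This is a minimal illustration under stated assumptions, not how any particular framework implements it: the "model" is a hypothetical two-parameter linear regression trained with plain stochastic gradient descent, and `should_train` is a made-up eligibility rule mirroring the Wi-Fi and charging conditions described above. The raw `(x, y)` examples are only read locally; the function returns nothing but the updated parameters.

```python
from typing import List, Sequence, Tuple


def should_train(on_wifi: bool, charging: bool) -> bool:
    """Hypothetical eligibility rule: train only on Wi-Fi and while charging."""
    return on_wifi and charging


def local_train(weights: List[float], data: Sequence[Tuple[float, float]],
                lr: float = 0.1, epochs: int = 200) -> List[float]:
    """Train the model y = w0 + w1 * x on the device's local data with SGD.

    `weights` is the downloaded base model [w0, w1]. A new list is returned,
    so the base model is left unchanged and no raw examples leave the function.
    """
    w0, w1 = weights
    for _ in range(epochs):
        for x, y in data:
            err = (w0 + w1 * x) - y  # prediction error on one local example
            w0 -= lr * err           # gradient of 0.5 * err**2 with respect to w0
            w1 -= lr * err * x       # gradient of 0.5 * err**2 with respect to w1
    return [w0, w1]


# Hypothetical local data on one device, following y = 1 + 2x.
local_data = [(i / 10, 1 + 2 * i / 10) for i in range(11)]
base_model = [0.0, 0.0]
if should_train(on_wifi=True, charging=True):
    trained_model = local_train(base_model, local_data)
```

Only `trained_model`, a list of two floats, would be uploaded; `local_data` stays on the device.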
After training, they return the trained model to the server. An important property of popular machine learning algorithms such as deep neural networks and linear support vector machines is that they are parametric. Once trained, they encode the statistical patterns of their data in numerical parameters and no longer need the training data for inference. Therefore, when the device sends the trained model back to the server, it doesn’t contain raw user data.
Once the server receives the trained models from user devices, it updates the base model with the aggregated parameter values of the user-trained models.
The federated learning cycle must be repeated several times before the model reaches the level of accuracy the developers want. Once the final model is ready, it can be distributed to all users for on-device inference.
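The article does not say how the parameters are aggregated. A common choice, known as federated averaging (FedAvg), takes an average of the user models’ parameters weighted by how many local examples each device trained on. The sketch below assumes that rule and represents each model as a plain list of floats; `federated_average` is a hypothetical helper name.

```python
from typing import List, Sequence, Tuple


def federated_average(updates: Sequence[Tuple[List[float], int]]) -> List[float]:
    """Combine user-trained models into a new base model (FedAvg-style).

    Each update is (parameters, number_of_local_examples). Devices that
    trained on more examples get proportionally more weight.
    """
    total = sum(n for _, n in updates)
    if total == 0:
        raise ValueError("no training examples reported")
    dim = len(updates[0][0])
    averaged = [0.0] * dim
    for params, n in updates:
        for i, p in enumerate(params):
            averaged[i] += p * n / total
    return averaged


new_base_model = federated_average([([1.0, 2.0], 1), ([3.0, 4.0], 3)])
```

Here the second device trained on three times as many examples, so its parameters count three times as much and the result is `[2.5, 3.5]`.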
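The whole cycle can be simulated end to end. The following sketch assumes a toy linear model, three simulated devices whose local data covers different ranges of inputs, FedAvg-style weighted averaging, and a hypothetical stopping rule in which the server measures error on a small evaluation set it holds itself (the article does not say how developers judge accuracy). All names and thresholds are illustrative.

```python
from typing import List, Sequence, Tuple

Example = Tuple[float, float]  # (input x, target y)


def local_train(w: List[float], data: Sequence[Example],
                lr: float = 0.1, epochs: int = 5) -> List[float]:
    """On-device step: a few epochs of SGD for the model y = w[0] + w[1] * x."""
    w0, w1 = w
    for _ in range(epochs):
        for x, y in data:
            err = w0 + w1 * x - y
            w0 -= lr * err
            w1 -= lr * err * x
    return [w0, w1]


def mse(w: List[float], data: Sequence[Example]) -> float:
    """Mean squared error of the model on a dataset."""
    return sum((w[0] + w[1] * x - y) ** 2 for x, y in data) / len(data)


def run_federated(devices: Sequence[Sequence[Example]],
                  eval_data: Sequence[Example],
                  target_mse: float = 1e-4,
                  max_rounds: int = 2000) -> Tuple[List[float], int]:
    """Repeat download, local training, and aggregation until accurate enough."""
    w = [0.0, 0.0]  # untrained base model on the server
    for rnd in range(1, max_rounds + 1):
        # Each device trains a copy of the current base model on its own data.
        updates = [(local_train(w, d), len(d)) for d in devices]
        total = sum(n for _, n in updates)
        # The weighted average of the returned parameters becomes the new base model.
        w = [sum(p[i] * n for p, n in updates) / total for i in range(len(w))]
        if mse(w, eval_data) < target_mse:
            return w, rnd
    return w, max_rounds


def line(x: float) -> Example:
    """Toy ground truth: y = 1 + 2x, with no noise."""
    return (x, 1 + 2 * x)


devices = [
    [line(x) for x in (0.0, 0.1, 0.2, 0.3)],
    [line(x) for x in (0.4, 0.5, 0.6)],
    [line(x) for x in (0.7, 0.8, 0.9, 1.0)],
]
eval_data = [line(x) for x in (0.05, 0.55, 0.95)]
weights, rounds = run_federated(devices, eval_data)
```

No single device sees inputs across the whole range, yet repeated rounds of averaging drive the shared model toward `y = 1 + 2x`. This is why one round is rarely enough.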
Limits of federated learning
Federated learning does not apply to all machine learning applications. If the model is too large to run on user devices, then the developer will need to find other workarounds to preserve user privacy.
In addition, the developers must make sure that the data on user devices is relevant to the application. The traditional machine learning development cycle involves intensive data cleaning, in which data engineers remove misleading data points and fill the gaps where data is missing. Training machine learning models on irrelevant data can do more harm than good.
When the training data is on the user’s device, the data engineers have no way of evaluating the data and making sure it will be beneficial to the application. For this reason, federated learning must be limited to applications where the user data does not need preprocessing.
Another limit of federated machine learning is data labeling. Most machine learning models are supervised, which means they require training examples that are manually labeled by human annotators. For example, the ImageNet dataset is a crowdsourced repository that contains millions of images and their corresponding classes.
In federated learning, unless outcomes can be inferred from user interactions (e.g., predicting the next word the user is typing), the developers can’t expect users to go out of their way to label training data for the machine learning model. Federated learning is therefore better suited to self-supervised (sometimes loosely called unsupervised) applications such as language modeling, where the labels come from the data itself.
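Next-word prediction shows why interaction data can label itself. In the minimal sketch below, text the user has already typed is turned into `(context, next_word)` training pairs, where the "label" is simply the word that actually came next. The function name, the lowercase whitespace tokenization, and the two-word context window are simplifying assumptions.

```python
from typing import List, Tuple


def next_word_examples(typed_text: str,
                       context_size: int = 2) -> List[Tuple[Tuple[str, ...], str]]:
    """Turn text the user typed into (context, next_word) training pairs.

    The label of each pair is the word the user actually typed next,
    so no human annotation is required.
    """
    words = typed_text.lower().split()
    examples = []
    for i in range(context_size, len(words)):
        context = tuple(words[i - context_size:i])
        examples.append((context, words[i]))
    return examples


for context, label in next_word_examples("See you at the game"):
    print(context, "->", label)
```

For "See you at the game", this yields `('see', 'you') -> at`, `('you', 'at') -> the`, and `('at', 'the') -> game`, all without anyone labeling anything.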
Privacy implications of federated learning
While sending trained model parameters to the server is less privacy-sensitive than sending user data, it doesn’t mean that the model parameters are completely free of private data.
In fact, many experiments have shown that trained machine learning models can memorize parts of their training data. Membership inference attacks can reveal whether a specific record was used to train a model, and for some models, attackers can even reconstruct training examples by repeatedly querying the model.
One important remedy to the privacy concerns of federated learning is to discard the user-trained models after they are integrated into the central model. The cloud server doesn’t need to store individual models once it updates its base model.
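One way to make discarding the default is to aggregate incrementally: the server folds each incoming model into a running weighted sum and never keeps a per-user copy. The sketch assumes a FedAvg-style weighted average; `RunningAverage` is a hypothetical class, and a real system would also have to keep user models out of logs and backups.

```python
from typing import List


class RunningAverage:
    """Aggregate user models without storing them individually.

    Each call to `add` folds one model into a running weighted sum;
    only the sum and the example count are kept.
    """

    def __init__(self, dim: int) -> None:
        self.weighted_sum = [0.0] * dim
        self.total_examples = 0

    def add(self, params: List[float], num_examples: int) -> None:
        if len(params) != len(self.weighted_sum):
            raise ValueError("parameter vector has the wrong length")
        for i, p in enumerate(params):
            self.weighted_sum[i] += p * num_examples
        self.total_examples += num_examples

    def result(self) -> List[float]:
        """Return the weighted average of all models added so far."""
        if self.total_examples == 0:
            raise ValueError("no user models received")
        return [s / self.total_examples for s in self.weighted_sum]


agg = RunningAverage(dim=2)
# In a real server these uploads would arrive one at a time over the network.
for params, n in [([1.0, 2.0], 1), ([3.0, 4.0], 3)]:
    agg.add(params, n)
new_base_model = agg.result()
```

Because the aggregator never holds a list of user models, there is nothing per-user left to delete once the round is over.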
Another measure that can help is to increase the pool of model trainers. For example, if a model needs to be trained on the data of 100 users, the engineers can increase their pool of trainers to 250 or 500 users. For each training iteration, the system will send the base model to 100 random users from the training pool. This way, the system doesn’t collect trained parameters from the same users in every round.
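The sampling scheme in this example is easy to simulate. The sketch below uses Python’s `random.sample` to draw 100 distinct users per round from a pool of 500, the numbers from the text; the 50-round horizon, user IDs, and fixed seed are arbitrary choices for the simulation. On average, each user takes part in 50 × 100 / 500 = 10 rounds rather than all 50.

```python
import random
from collections import Counter
from typing import List


def sample_round(pool: List[str], per_round: int, rng: random.Random) -> List[str]:
    """Choose `per_round` distinct users from the pool for one training round."""
    return rng.sample(pool, per_round)


def participation_counts(pool_size: int = 500, per_round: int = 100,
                         rounds: int = 50, seed: int = 0) -> Counter:
    """Count how many rounds each user is asked to train in."""
    pool = [f"user{i}" for i in range(pool_size)]
    rng = random.Random(seed)  # seeded only so the simulation is reproducible
    counts: Counter = Counter()
    for _ in range(rounds):
        counts.update(sample_round(pool, per_round, rng))
    return counts


counts = participation_counts()
print("average rounds per user:", sum(counts.values()) / 500)
print("most rounds for any one user:", max(counts.values()))
```

Enlarging the pool relative to `per_round` lowers how often any single user contributes parameters, at the cost of needing more volunteer devices.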
Finally, by adding a bit of noise to the trained parameters and using normalization techniques, developers can considerably reduce the model’s ability to memorize users’ data.
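The text does not spell out the mechanism. One common recipe, borrowed from differential privacy, is to clip (normalize) each user’s update to a maximum L2 norm and then add Gaussian noise before it leaves the device. The sketch below assumes that recipe and applies it to the update, that is, the change from the base model. `privatize_update`, `clip_norm`, and `noise_std` are hypothetical names and values, and the code does not compute a formal privacy guarantee.

```python
import math
import random
from typing import List


def privatize_update(update: List[float], clip_norm: float, noise_std: float,
                     rng: random.Random) -> List[float]:
    """Clip an update to at most `clip_norm` in L2 norm, then add Gaussian noise.

    Clipping bounds how much any single user can change the model; the noise
    hides the exact values the user's data produced.
    """
    norm = math.sqrt(sum(u * u for u in update))
    scale = min(1.0, clip_norm / norm) if norm > 0 else 1.0
    return [u * scale + rng.gauss(0.0, noise_std) for u in update]


# Hypothetical example: the update is the trained parameters minus the base ones.
base = [0.50, -1.20, 0.30]
trained = [0.90, -0.70, 0.10]
update = [t - b for t, b in zip(trained, base)]
noisy = privatize_update(update, clip_norm=0.5, noise_std=0.01, rng=random.Random(0))
```

The server would then add `noisy` to the base model (or average many such updates) instead of using the exact trained parameters. Larger `noise_std` hides more but slows learning, which is the trade-off developers have to tune.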
Federated learning is gaining popularity as it addresses some of the fundamental problems of modern artificial intelligence. Researchers are constantly looking for new ways to apply federated learning to new AI applications and overcome its limits. It will be interesting to see how the field evolves in the future.