Decentralized ML training with Federated Learning
There are two inevitable truths in machine learning: you always need more data, and you always need more computational power to train on that data.
In a recent article I covered a basic compute and storage architecture for ML. One of the main assumptions in the architecture I presented was that all training data is consolidated in a single location (a storage system) and, more critically, that this data is readily available for ML training. What happens, then, if we need some data for training but, for a variety of reasons, that data cannot be stored in our environment? Enter Federated Learning.
The approach I outlined in that article assumed that both data and compute are centralized. Under a centralized approach, all the data needed to build and train models is readily available and lives in the same environment as the compute. Federated Learning flips this model on its head. Rather than centralizing data and compute, Federated Learning runs under a decentralized model without the need to share data. That last point is key.
The diagram below illustrates how Federated Learning works. The first step is to bootstrap a model, which happens within your own environment: you build an initial version of the model using the data and compute resources at your disposal. You then push this model out to remote sites. Each site has proprietary data that you want to train your model on, data that could help your model generalize and perform better. The model you push down to each site is trained on that site’s local data, and the data never leaves the site. Once training completes, the updated model from each site is pushed back to your environment. This in turn produces a new model that combines the initial version with the models produced at each remote site.
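To make the local training step concrete, here is a minimal sketch of what a remote site might run, assuming a toy linear model trained with plain NumPy; the function name local_update, the learning rate, and the simple gradient-descent loop are illustrative assumptions rather than any particular framework’s API.

```python
import numpy as np

def local_update(global_weights, X_local, y_local, lr=0.01, num_epochs=5):
    """Train the received global model on a site's local data.

    The data (X_local, y_local) never leaves the site; only the
    updated weights are sent back to the coordinating environment.
    """
    w = global_weights.copy()
    for _ in range(num_epochs):
        # Gradient of mean squared error for a toy linear model y ~ X @ w
        preds = X_local @ w
        grad = X_local.T @ (preds - y_local) / len(y_local)
        w -= lr * grad
    # Return the weights plus the local sample count (useful for weighting later)
    return w, len(y_local)
```

The important property is what the function returns: updated weights and a sample count, never the raw data itself.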
In practice, what happens is that each site’s model weights are sent back and ensembled (typically averaged) with the weights from the other models, including the bootstrapped one. The result is a new version of your model, derived from the combined models across all the remote sites you trained on. All done without having to share data!
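On the coordinating side, a common way to combine the returned weights is a sample-size-weighted average, in the spirit of the FedAvg algorithm. The sketch below assumes each site returns the (weights, num_samples) pair produced by the hypothetical local_update above.

```python
def federated_average(site_results):
    """Combine per-site weights into a new global model.

    site_results is a list of (weights, num_samples) pairs, one per
    remote site. Weights are averaged in proportion to how much data
    each site trained on, so no raw data is ever shared.
    """
    total_samples = sum(n for _, n in site_results)
    return sum(w * (n / total_samples) for w, n in site_results)


# One hypothetical federated round: train locally at every site, then aggregate.
# site_data is assumed to be a list of (X_site, y_site) arrays held remotely.
# results = [local_update(global_w, X, y) for X, y in site_data]
# global_w = federated_average(results)
```

Weighting by sample count is just one reasonable choice; a plain unweighted average also works when sites hold similar amounts of data.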
Federated Learning comes into play in several situations; perhaps the most prevalent and useful are massively distributed training and addressing data privacy concerns.
Consider the case where you have a wildly popular mobile application, used by hundreds of millions of people globally. You might want to leverage that wide adoption for collaborative, highly distributed training. Google is a great example, using Federated Learning in Gboard on Android, the Google Keyboard. When Gboard shows a suggested query, the user’s phone stores this information locally along with relevant metadata. Federated Learning processes that history on-device to suggest improvements to the next iteration of Gboard’s query suggestion model. The data stored on the user’s device remains there and never reaches Google’s servers; only the learnings derived locally are applied to the global model.
Yet another example is one from a domain I am familiar with: healthcare. Imagine that you are building a model for cancer detection. You’ll need a lot of medical imaging data, like X-rays, mammograms, PET/CTs and so on. You might also need other medical data, like pathology reports and patient medical records, all of which is highly sensitive and private. You’re not going to be able to move or copy this data into your own environment. Federated Learning can be a huge help here: rather than sending the data back to your mothership, you train a model locally at each remote site and use Federated Learning to fold those learnings into your global model.