Training Neural Nets on Larger Batches: Practical Tips for 1-GPU, Multi-GPU & Distributed setups

HuggingFace
Thomas Wolf
Natural Language Processing, Deep learning and Computational Linguistics – Science Lead @ Huggingface
Oct 15

By David Marcu
I’ve spent most of 2018 training neural networks that push the limits of my GPUs. Whether it was a 150-million-parameter language model like OpenAI’s huge Generative Pre-trained Transformer (or the recent and similar BERT model) or a meta-learning neural net fed with 30-million-element inputs like the one in our ICLR ’18 paper, I could barely fit more than a few training samples on a GPU.

But most of the time stochastic gradient descent algorithms require larger batches than just a handful of examples to get decent results.

How can you train your model on large batches when your GPU can’t hold more than a few samples?
There are several tools, tips and tricks you can use to do that, and I thought it would be nice to gather everything I use and have learned in one post.

In this post I will mainly talk about the PyTorch framework. Some of these tools are not in PyTorch yet (as of 1.0) so I include some custom code as well.

In particular, we’ll talk about:

How you can train a model on a single- or multi-GPU server with batches larger than the GPU memory, or even when a single training sample won’t fit (!),
How you can make the most efficient use of a multi-GPU machine, and
The simplest way to train using several machines in a distributed setting.

Let’s start with the simplest trick: gradient accumulation.

⌛️Large batches on one or several GPU(s)
So, you’ve built a nice model that might be the new SOTA on this neat task, but every time you try to stack more than a few samples in a batch you get a CUDA RuntimeError: out of memory.

Adam confirms your predicament! 😱 Oh no!
But you’re pretty sure that doubling the batch size will improve the results.

How can you do that?
There is an easy solution to this problem: accumulating gradients. Here is a quick reminder on how stochastic gradient descent works from my earlier post on meta-learning:

The 5 steps of a gradient descent optimization algorithm
The PyTorch code equivalent of these 5 steps can also be written in 5 lines:
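As a minimal sketch of those five lines, here is a runnable version that assumes a toy linear model, a mean-squared-error loss and plain SGD (all three are stand-ins chosen for illustration, and the inputs and labels are random):

```python
import torch

torch.manual_seed(0)

# Hypothetical stand-ins: a linear model, MSE loss and plain SGD on random data
model = torch.nn.Linear(10, 1)
loss_function = torch.nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
inputs, labels = torch.randn(8, 10), torch.randn(8, 1)

predictions = model(inputs)                # 1. forward pass
loss = loss_function(predictions, labels)  # 2. compute the loss
loss.backward()                            # 3. backward pass: fills parameter.grad
optimizer.step()                           # 4. gradient descent step on the parameters
predictions = model(inputs)                # 5. forward pass with the updated parameters
```

Nothing in these five lines resets parameter.grad; a real loop would call optimizer.zero_grad() before each backward pass, which is precisely the behaviour gradient accumulation exploits.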

During the loss.backward() operation, gradients are computed for each parameter (in green on our animation) and stored in a tensor associated with each parameter: parameter.grad (the middle graph on our animation).

Accumulating gradients just means that, before calling optimizer.step() to perform a step of gradient descent, we will sum the gradients of several backward operations in the parameter.grad tensors. This is straightforward to do in PyTorch as the gradient tensors are not reset unless we call model.zero_grad() or optimizer.zero_grad(). We’ll also need to divide by the number of accumulation steps if our loss is averaged over the training samples.
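To check that summing gradients over several backward passes reproduces a larger batch, here is a small sketch (toy linear model and random data, both assumptions). It compares the gradient of one 8-sample batch with the gradient accumulated over 4 micro-batches of 2 samples, each averaged loss being divided by the number of accumulation steps:

```python
import torch

torch.manual_seed(0)
model = torch.nn.Linear(4, 1)
loss_function = torch.nn.MSELoss()  # averages over the samples of a batch
inputs, labels = torch.randn(8, 4), torch.randn(8, 1)

# Reference: gradient of the loss over the full batch of 8 samples
model.zero_grad()
loss_function(model(inputs), labels).backward()
full_batch_grad = model.weight.grad.clone()

# Accumulation: 4 micro-batches of 2 samples, gradients summed in .grad
accumulation_steps = 4
model.zero_grad()
for x, y in zip(inputs.chunk(accumulation_steps), labels.chunk(accumulation_steps)):
    loss = loss_function(model(x), y) / accumulation_steps  # rescale the averaged loss
    loss.backward()  # adds into model.weight.grad instead of overwriting it
accumulated_grad = model.weight.grad.clone()

print(torch.allclose(full_batch_grad, accumulated_grad, atol=1e-6))  # True
```

The equality relies on the micro-batches having the same size; with uneven chunks, each loss would instead need to be weighted by its share of the samples.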

Here is a simple gist for training a model using gradient accumulation. In this example, we can train with a batch size that is accumulation_steps times larger than the maximum size that fits on our GPU(s):
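The gist itself is not reproduced here; the following is a minimal sketch of such a loop, assuming a toy dataset, a linear model, MSE loss and SGD (all hypothetical stand-ins). With micro-batches of 4 samples and accumulation_steps = 8, each optimizer step uses gradients from 32 samples:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Hypothetical stand-ins: 64 random samples, a linear model, MSE loss and SGD
dataset = TensorDataset(torch.randn(64, 10), torch.randn(64, 1))
training_set = DataLoader(dataset, batch_size=4)  # micro-batch that fits in memory
model = torch.nn.Linear(10, 1)
loss_function = torch.nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
accumulation_steps = 8  # effective batch size: 4 * 8 = 32 samples

optimizer_steps = 0
model.zero_grad()                                # reset the gradient tensors
for i, (inputs, labels) in enumerate(training_set):
    predictions = model(inputs)                  # forward pass
    loss = loss_function(predictions, labels)    # compute the (averaged) loss
    loss = loss / accumulation_steps             # normalize it over the accumulation steps
    loss.backward()                              # add this micro-batch's gradients to .grad
    if (i + 1) % accumulation_steps == 0:        # after accumulation_steps backward passes
        optimizer.step()                         # take one gradient descent step
        model.zero_grad()                        # and reset the gradient tensors
        optimizer_steps += 1

print(optimizer_steps)  # 2 (16 micro-batches / 8 accumulation steps)
```

If the number of micro-batches is not a multiple of accumulation_steps, the gradients of the last few micro-batches stay in parameter.grad without a step being taken; taking a final step or resetting the gradients at the start of each epoch keeps them from leaking into the next one.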

😱 Pushing that to the extreme
Can you train a model for which not even a single sample can fit on a GPU?

Well, if your architecture doesn’t have too many skip connections, yes, it’s possible! The solution is to trade compute for memory using gradient checkpointing.

Basically, the idea is to back-propagate the gradients in small chunks along the model, trading the memory needed to store a full back-propagation graph for the additional compute of a partial forward pass associated with each chunk. This is a rather slow method, as we add compute to reduce the memory requirements, but it can be interesting in some settings, e.g. to train RNN models over very long sequences (see for example my previous introduction to meta-learning).

I won’t go into more details here and will just refer you to the relevant links:

TensorFlow: https://github.com/openai/gradient-checkpointing
PyTorch doc: https://pytorch.org/docs/stable/checkpoint.html

A “memory-poor” strategy that needs O(1) memory (but requires O(n²) computation steps), from Yaroslav Bulatov’s nice post: https://medium.com/tensorflow/fitting-larger-networks-into-memory-583e3c758ff9
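For a concrete picture of the PyTorch side, here is a minimal sketch using torch.utils.checkpoint.checkpoint. It assumes a hypothetical 8-block MLP and a PyTorch version recent enough to accept use_reentrant=False. Each block keeps only its input during the forward pass and recomputes its internal activations during the backward pass; the script then checks that the gradients match a plain forward/backward pass:

```python
import torch
from torch.utils.checkpoint import checkpoint


class CheckpointedMLP(torch.nn.Module):
    """Hypothetical deep MLP whose blocks can be recomputed during backward."""

    def __init__(self, width=64, depth=8):
        super().__init__()
        self.blocks = torch.nn.ModuleList(
            [torch.nn.Sequential(torch.nn.Linear(width, width), torch.nn.ReLU())
             for _ in range(depth)]
        )
        self.use_checkpointing = True

    def forward(self, x):
        for block in self.blocks:
            if self.use_checkpointing:
                # Stores only the block input; activations inside the block are
                # recomputed by an extra forward pass when gradients are needed.
                x = checkpoint(block, x, use_reentrant=False)
            else:
                x = block(x)
        return x


torch.manual_seed(0)
model = CheckpointedMLP()
inputs = torch.randn(16, 64)

model.use_checkpointing = True
model(inputs).sum().backward()
checkpointed_grads = [p.grad.clone() for p in model.parameters()]

model.zero_grad()
model.use_checkpointing = False
model(inputs).sum().backward()
plain_grads = [p.grad.clone() for p in model.parameters()]

print(all(torch.allclose(a, b) for a, b in zip(checkpointed_grads, plain_grads)))  # True
```

Checkpointing every block runs each block’s forward pass twice per training step; checkpointing coarser segments (for example with torch.utils.checkpoint.checkpoint_sequential) shifts the balance between memory and compute.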
🕰 Making the best of a multi-GPU machine
Now let’s talk more specifically about training a model on multiple GPUs.

The go-to strategy to train a PyTorch model on a multi-GPU server is to use torch.nn.DataParallel. It’s a container which parallelizes the application of a module by splitting the input across the specified devices, chunking along the batch dimension.

DataParallel is very easy to use: we just add one line to encapsulate the model:
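A minimal sketch of that one line, assuming a toy linear model as a stand-in. DataParallel expects the wrapped module to sit on the first GPU; on a machine without GPUs, recent PyTorch versions simply run the wrapped module, so the sketch also runs on CPU:

```python
import torch

model = torch.nn.Linear(10, 1)  # hypothetical stand-in for "our model"
if torch.cuda.is_available():
    model = model.cuda()  # DataParallel expects the module on the first GPU

parallel_model = torch.nn.DataParallel(model)  # encapsulate the model: the one added line

inputs = torch.randn(32, 10)
# Forward pass: the batch is split along dim 0 across the visible GPUs,
# and the outputs are gathered back on the first GPU
predictions = parallel_model(inputs)
print(predictions.shape)  # torch.Size([32, 1])
```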

However, one issue can arise with DataParallel: unbalanced GPU usage.

Under some settings GPU-1 will be used a lot more than the other GPUs.
Where does this come from? I made an illustration to better explain what DataParallel does under the hood:

Forward and Backward passes with torch.nn.DataParallel
During step 4 of the forward pass (top right), the results of all the parallel computations are gathered on GPU-1. This is fine for a lot of classification problems, but it can become problematic when you train a language model on large batches, for example.

Let’s quickly compute the size of the output for a language model:

Number of elements in the output of a language model: batch size × sequence length × vocabulary size
If we assume a 40k vocabulary, 250 tokens in our sequences, 32 samples per batch and 4 bytes to store each element in memory, the output of our model takes about 1.2 GB. We need to double that to store the associated gradient tensors, so our model output requires about 2.4 GB of memory!
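The arithmetic behind these figures, as a small Python check (the variable names are ours). The figures above correspond to binary gigabytes (GiB, 2**30 bytes); in decimal gigabytes the same byte counts are 1.28 GB and 2.56 GB:

```python
# Figures assumed in the text
batch_size = 32
sequence_length = 250      # tokens per sample
vocabulary_size = 40_000
bytes_per_element = 4      # one float32 per logit

n_elements = batch_size * sequence_length * vocabulary_size
output_bytes = n_elements * bytes_per_element
output_and_grad_bytes = 2 * output_bytes  # the gradient has the same size as the output

GiB = 2**30
print(n_elements)                             # 320000000
print(round(output_bytes / GiB, 2))           # 1.19
print(round(output_and_grad_bytes / GiB, 2))  # 2.38
```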

That’s a significant portion of the memory of a typical 10 GB GPU, and it means that GPU-1 will be overused relative to the other GPUs, limiting the effect of the parallelization.

We cannot easily reduce the number of elements in this output without tweaking the model and/or optimization scheme. But we can make sure the memory load is more evenly distributed among the GPUs.

⚖️ Balanced load on a multi-GPU machine
The solution is to keep each partial output on its GPU instead of gathering all of them on GPU-1. We will also need to distribute the computation of our loss criterion so that we can compute and back-propagate our loss.

Thankfully for us, Hang Zhang (张航) has open-sourced a nice PyTorch package called PyTorch-Encoding which comprises these custom parallelization functions.

I’ve extracted and slightly adapted this module, and you can download it here as a gist (parallel.py) to include and call from your code. It mainly comprises two modules, DataParallelModel and DataParallelCriterion, which are meant to be used together: the first one wraps the model and the second one wraps the loss function.

The difference between DataParallelModel and torch.nn.DataParallel is just that the output of the forward pass (predictions) is not gathered on GPU-1 and is thus a tuple of n_gpu tensors, each tensor being located on a respective GPU.

The DataParallelCriterion container encapsulates the loss function and takes as input the tuple of n_gpu tensors and the target labels tensor. It computes the loss function in parallel on each GPU, splitting the target labels tensor the same way the model input was chunked by DataParallel.
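DataParallelModel and DataParallelCriterion live in the parallel.py gist, so the following is not their code. It is a CPU-only emulation of the idea, with a hypothetical parallel_criterion function: the predictions arrive as a tuple of per-device chunks, the targets are split the same way, each chunk gets its own loss, and only the scalar losses are combined. Averaging the per-chunk losses is an assumption of this sketch, and it matches the full-batch mean only when every chunk has the same size:

```python
import torch


def parallel_criterion(loss_function, predictions, targets):
    """Hypothetical CPU-only stand-in for DataParallelCriterion.

    predictions: tuple of per-GPU output chunks, as DataParallelModel returns them
    targets: full target tensor, split here the same way the inputs were chunked
    """
    chunk_sizes = [p.shape[0] for p in predictions]
    target_chunks = torch.split(targets, chunk_sizes, dim=0)
    # On real GPUs, each loss would be computed on the device holding its chunk,
    # so only scalar losses (not the large predictions) would need to be moved.
    losses = [loss_function(p, t) for p, t in zip(predictions, target_chunks)]
    return torch.stack(losses).mean()


torch.manual_seed(0)
n_gpu, batch_size, vocab_size = 4, 32, 100
logits = torch.randn(batch_size, vocab_size, requires_grad=True)
labels = torch.randint(0, vocab_size, (batch_size,))
predictions = logits.chunk(n_gpu)  # one chunk of 8 samples per "GPU"

loss_function = torch.nn.CrossEntropyLoss()
loss = parallel_criterion(loss_function, predictions, labels)
print(torch.allclose(loss, loss_function(logits, labels)))  # True: equal-sized chunks
loss.backward()  # gradients flow back into each chunk of the predictions
```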

I made an illustration of DataParallelModel/DataParallelCriterion internals:

Using DataParallelModel and DataParallelCriterion
Here is how to handle two particular cases you may encounter:

Your model outputs several tensors: you will likely want to disentangle them with output_1, output_2 = zip(*predictions)
Sometimes you don’t want to use a parallel loss function: gather all the tensors on the CPU with gathered_predictions = parallel.gather(predictions)
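A small CPU-only illustration of the first case, assuming a hypothetical model that returns two tensors (logits and hidden states) on two GPUs. The chunks are plain CPU tensors standing in for per-GPU outputs, and torch.cat stands in for the gather step:

```python
import torch

torch.manual_seed(0)
n_gpu = 2
# Emulated DataParallelModel output for a model returning (logits, hidden):
# one (logits, hidden) pair per GPU, each holding 4 of the 8 samples.
predictions = tuple((torch.randn(4, 10), torch.randn(4, 16)) for _ in range(n_gpu))

# Regroup by output instead of by GPU
output_1, output_2 = zip(*predictions)
print(len(output_1), output_1[0].shape)  # 2 torch.Size([4, 10])

# CPU stand-in for the gather step: concatenate the chunks along the batch dimension
logits = torch.cat(output_1, dim=0)
hidden = torch.cat(output_2, dim=0)
print(logits.shape, hidden.shape)  # torch.Size([8, 10]) torch.Size([8, 16])
```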
⏰ Distributed training: training on several machines
Now how can we harness the power of several servers to train on even larger batches?

The simplest option is to use PyTorch’s DistributedDataParallel, which is meant to be an almost drop-in replacement for the DataParallel discussed above.

But be careful: while the code looks similar, training your model in a distributed setting will change your workflow, because you will actually have to start an independent Python training script on each node (these scripts are all identical). As we will see, once started, these training scripts will be synchronized together by PyTorch’s distributed backend.

In practice, this means that each training script will have:

its own optimizer, and it will perform a complete optimization step at each iteration; no parameter broadcast (step 2 in DataParallel) is needed,
an independent Python interpreter: this will also avoid the GIL-freeze that can come from driving several parallel execution threads in a single Python interpreter.
Models that make heavy use of Python loops/calls in their forward passes can be slowed down by the Python interpreter’s GIL when several parallel forward calls are driven by a single interpreter. In these settings, DistributedDataParallel can advantageously replace DataParallel even on a single-machine setup.

Now let’s dive straight into the code and usage.

DistributedDataParallel is built on top of the torch.distributed package, which provides low-level primitives for synchronizing distributed operations and can make use of several backends (tcp, gloo, mpi, nccl) with different capabilities.
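To give a feel for these primitives, here is a minimal sketch of all_reduce, the collective that DistributedDataParallel relies on to average gradients. It assumes the gloo backend (which runs on CPU) and a single-process “world” so that it stays self-contained; with N processes, each would pass its own rank and world_size=N and end up holding the same averaged tensor:

```python
import socket

import torch
import torch.distributed as dist


def free_port():
    # Ask the OS for an unused TCP port for the rendezvous
    with socket.socket() as s:
        s.bind(("127.0.0.1", 0))
        return s.getsockname()[1]


# Single-process "world" (an assumption to keep the sketch self-contained)
dist.init_process_group(
    backend="gloo",
    init_method=f"tcp://127.0.0.1:{free_port()}",
    rank=0,
    world_size=1,
)

grad = torch.tensor([1.0, 2.0, 3.0])
dist.all_reduce(grad, op=dist.ReduceOp.SUM)  # sums the tensor across all processes, in place
grad /= dist.get_world_size()                # turn the sum into an average
print(grad)  # tensor([1., 2., 3.]) with a single process

dist.destroy_process_group()
```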

In this post I will select one simple way to use it out of the box, but you should read the doc and this nice tutorial by Séb Arnold to dive deeper into this module.

We will consider a simple but general setup with two 4-GPU servers (nodes):

The main server (server 1) has an accessible IP and an open port for communication.
🏃 Adapting our Python training script for distributed training
First we need to adapt our script so that it can be run separately on each machine (node). We are actually going to go fully distributed and run a separate process for each GPU of each node, so 8 processes in total.

Our training script is a bit longer, as we need to initialize the distributed backend for synchronization, encapsulate the model and prepare the data to train each process on a separate subset of the data (every process is independent, so we have to take care of that ourselves). Here is the updated code:
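The updated gist is not reproduced here. As a hedged sketch of what such a script involves, assuming toy data and a linear model (hypothetical stand-ins): the process group is initialized from the environment variables set by the launcher, a DistributedSampler gives each process its own shard of the data, and the model is wrapped in DistributedDataParallel. The sketch falls back to the gloo backend on CPU so it can also be smoke-tested as a single process. As a precaution, it accepts both --local_rank and --local-rank and falls back to the LOCAL_RANK environment variable, since the way the local rank is passed has varied across PyTorch versions:

```python
import argparse
import os

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler


def main(argv=None):
    parser = argparse.ArgumentParser()
    # Set per process by the launcher; both spellings accepted, LOCAL_RANK as fallback
    parser.add_argument("--local_rank", "--local-rank", type=int,
                        default=int(os.environ.get("LOCAL_RANK", 0)))
    args = parser.parse_args(argv)

    use_cuda = torch.cuda.is_available()
    if use_cuda:
        torch.cuda.set_device(args.local_rank)  # this process drives a single GPU
        device = torch.device("cuda", args.local_rank)
    else:
        device = torch.device("cpu")

    # "env://" reads MASTER_ADDR, MASTER_PORT, RANK and WORLD_SIZE from the environment
    dist.init_process_group(backend="nccl" if use_cuda else "gloo", init_method="env://")

    # Hypothetical stand-ins for the real dataset, model, loss and optimizer
    torch.manual_seed(0)  # every process builds the same toy dataset
    dataset = TensorDataset(torch.randn(256, 10), torch.randn(256, 1))
    sampler = DistributedSampler(dataset)  # gives each process its own shard of the data
    loader = DataLoader(dataset, batch_size=16, sampler=sampler)

    model = torch.nn.Linear(10, 1).to(device)
    model = DistributedDataParallel(model, device_ids=[args.local_rank] if use_cuda else None)
    loss_function = torch.nn.MSELoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)  # one optimizer per process

    for epoch in range(2):
        sampler.set_epoch(epoch)  # different shuffling of the shards at each epoch
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            loss = loss_function(model(inputs), labels)
            optimizer.zero_grad()
            loss.backward()   # DDP averages the gradients across all processes here
            optimizer.step()  # every process then applies the same update locally

    samples_per_epoch = len(sampler)
    dist.destroy_process_group()
    return samples_per_epoch
```

As written, the file only defines main(); a script started by torch.distributed.launch would end with if __name__ == "__main__": main().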

✨ Launching multiple instances of our Python training script
We are almost done now. We just have to start an instance of our training script on each server.

To run our script, we’ll use the torch.distributed.launch utility of PyTorch. It will take care of setting the environment variables and calling each script with the right local_rank argument.

The first machine will be our master: it needs to be accessible from all the other machines and thus have an accessible IP address (192.168.1.1 in our example) and an open port (1234 in our case). On this first machine, we run our training script using torch.distributed.launch:

python -m torch.distributed.launch --nproc_per_node=4 --nnodes=2 --node_rank=0 --master_addr="192.168.1.1" --master_port=1234 OUR_TRAINING_SCRIPT.py (--arg1 --arg2 --arg3 and all other arguments of our training script)

On the second machine, we similarly start our script:

python -m torch.distributed.launch --nproc_per_node=4 --nnodes=2 --node_rank=1 --master_addr="192.168.1.1" --master_port=1234 OUR_TRAINING_SCRIPT.py (--arg1 --arg2 --arg3 and all other arguments of our training script)

These two commands are identical except for the --node_rank argument, which is set to 0 on the first machine and 1 on the second (and would be 2 on an additional server, etc.).
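Since only --node_rank changes, the per-node commands can be generated rather than typed. The sketch below uses a hypothetical helper, launch_command, with the example values above as defaults; it also lists the global rank of each of the 4 processes per node (node_rank × nproc_per_node + local_rank), which is how the 8 processes of our two servers are numbered 0 to 7:

```python
import shlex


def launch_command(node_rank, nnodes=2, nproc_per_node=4,
                   master_addr="192.168.1.1", master_port=1234,
                   script="OUR_TRAINING_SCRIPT.py", script_args=()):
    """Hypothetical helper: torch.distributed.launch command line for one node."""
    return [
        "python", "-m", "torch.distributed.launch",
        f"--nproc_per_node={nproc_per_node}",
        f"--nnodes={nnodes}",
        f"--node_rank={node_rank}",  # the only argument that differs between nodes
        f"--master_addr={master_addr}",
        f"--master_port={master_port}",
        script, *script_args,
    ]


def global_ranks(node_rank, nproc_per_node=4):
    # Rank of each process across the whole job: node offset + local GPU index
    return [node_rank * nproc_per_node + local_rank for local_rank in range(nproc_per_node)]


for node_rank in range(2):
    print(shlex.join(launch_command(node_rank)), "# global ranks", global_ranks(node_rank))
```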

Running a bunch of almost identical commands on a cluster of machines might look a bit tedious. So now is probably a good time to learn about the magic of… GNU parallel:

One exciting improvement of PyTorch v1.0 is the release of the c10d backend for the distributed module. I will update this simple introduction when v1.0 is released with more details on the new backend 🔥

This concludes our quick post on a few tips, tricks and tools to train your model on larger batches in a variety of settings.

I hope you enjoyed this more technical post!

Clap 👏 a couple of times if you liked it and want us to post more of these!

