Estimating memory requirements of transformer networks
Fine-tuning GPT/BERT-based models for custom tasks, I often found myself in the unfortunate situation of "CUDA out of memory". It turns out that transformer models are memory intensive. On top of that, the memory requirements grow with the sequence length.
It is useful to have some idea of how much memory is needed to fine-tune/train a model. A rough estimate helps in planning the resources needed for the task.
If you are short on time or don't want to go into the details, you can skip to the TL;DR and insights.
All neural networks are trained with back-propagation. Keeping this in mind, the following simple relation gives us the memory usage:
total_memory = memory_model + memory_activations + memory_gradients
Here memory_model is the memory required to store all the parameters of the model. Activations are computed and stored during the forward pass. During the backward pass, gradients are computed from these activations, and holding on to them takes roughly as much memory as the activations themselves, so memory_gradients ≈ memory_activations. Hence, we can write:
total_memory = memory_model + 2 * memory_activations
In essence, we need to find the values of memory_model and memory_activations to estimate the total memory required.
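To make the memory_model term concrete, here is a minimal sketch that counts the parameters of a pretrained GPT-2 and converts the count to bytes. The article itself doesn't assume any particular framework; this sketch assumes PyTorch and the Hugging Face transformers package are available, which is just one convenient way to get the number.

```python
# Minimal sketch: measuring memory_model for a real model by counting parameters.
# Assumes PyTorch and Hugging Face transformers are installed (not from the article).
from transformers import GPT2Model

model = GPT2Model.from_pretrained("gpt2")
n_params = sum(p.numel() for p in model.parameters())
bytes_per_param = 4  # assuming float32 weights; use 2 for fp16, 8 for float64
# For GPT-2 small (~124M parameters) this comes to roughly half a gigabyte in float32.
print(f"memory_model ≈ {n_params * bytes_per_param / 1024**2:.0f} MiB")
```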
Estimating the model's memory requirements
Let's take GPT as an example. GPT consists of a number of transformer blocks (let's call this number n_tr_blocks
from now on). Each transformer block has the following structure:
multi_headed_attention --> layer_normalization --> MLP --> layer_normalization
Each multi_headed_attention element consists of value_net, key and query. Let's say that each of these has n_head attention heads and dim dimensions. The MLP also has a dimension of n_head * dim. The memory needed to store these weights is:
memory per transformer block = memory of multi_headed_attention + memory of MLP
= memory of value_net + memory of key + memory of query + memory of MLP
= square_of(n_head * dim) + square_of(n_head * dim) + square_of(n_head * dim) + square_of(n_head * dim)
= 4 * square_of(n_head * dim)
Since our model contains n_tr_blocks of these blocks, the total memory required by the model becomes:
memory_model = 4 * n_tr_blocks * square_of(n_head * dim)
The above estimate does not take into account the memory required for biases, since that is mostly static and does not depend on things like the batch size or the input sequence length.
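Putting the formula above into code, here is a small sketch (it returns parameter counts, not bytes). The helper name estimate_model_params and the GPT-2-small-like values in the example call are illustrative assumptions, not numbers from the article.

```python
# Rough parameter count from the formula: 4 weight matrices of size (n_head*dim)^2 per block.
# (A real GPT-2 has more parameters, e.g. embeddings and a 4x-wider MLP; the formula is intentionally rough.)
def estimate_model_params(n_tr_blocks: int, n_head: int, dim: int) -> int:
    return 4 * n_tr_blocks * (n_head * dim) ** 2

# Illustrative GPT-2-small-like shapes (assumed): 12 blocks, 12 heads, dim 64.
print(estimate_model_params(n_tr_blocks=12, n_head=12, dim=64))  # ≈ 28.3M parameters
```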
Estimating the model activations' memory requirements
Multi-headed attention is essentially a softmax over query-key products. More specifically, it can be written as:
multi_headed_attention = softmax(query * transpose(key)) * value_net
query, key and value_net all have a tensor shape of
[batch_size, n_head, sequence_length, dim]
The query * transpose(key) operation gives a result of the following shape:
[batch_size, n_head, sequence_length, sequence_length]
This gives the memory cost of the attention softmax as:
memory_softmax = batch_size * n_head * square_of(sequence_length)
Multiplying the result of query * transpose(key) by value_net gives a tensor of shape [batch_size, n_head, sequence_length, dim]. The MLP output has the same shape. So the memory cost of these operations becomes:
memory of MLP = batch_size * n_head * sequence_length * dim
memory of value_net = batch_size * n_head * sequence_length * dim
This gives us the memory of the model activations per block:
mem_act = memory_softmax + memory_value_net + memory_MLP
= batch_size * n_head * square_of(sequence_length)
+ batch_size * n_head * sequence_length * dim
+ batch_size * n_head * sequence_length * dim
= batch_size * n_head * sequence_length * (sequence_length + 2*dim)
The memory of the model activations across the whole model is then:
memory_activations = n_tr_blocks * (batch_size * n_head * sequence_length * (sequence_length + 2*dim))
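The activation formula can likewise be sketched as a small helper (again element counts, not bytes). The helper name estimate_activation_elements and the values in the example call are illustrative assumptions.

```python
# Activation element count per the formula above:
# n_tr_blocks * batch_size * n_head * sequence_length * (sequence_length + 2*dim)
def estimate_activation_elements(n_tr_blocks, n_head, dim, batch_size, sequence_length):
    per_block = batch_size * n_head * sequence_length * (sequence_length + 2 * dim)
    return n_tr_blocks * per_block

# Illustrative (assumed) values: 12 blocks, 12 heads, dim 64, batch 8, sequence length 512.
print(estimate_activation_elements(12, 12, 64, 8, 512))  # ≈ 377M activation values
```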
Summing it all up
To sum up, the total memory needed for fine-tuning/training a transformer model is:
total_memory = memory_model + 2 * memory_activations
The memory for the model is:
memory_model = 4 * n_tr_blocks * square_of(n_head * dim)
And the memory for the model activations is:
memory_activations = n_tr_blocks * (batch_size * n_head * sequence_length * (sequence_length + 2*dim))
These rough formulas can be written more succinctly using the following notation:
R = n_tr_blocks = number of transformer blocks in the model
N = n_head = number of attention heads
D = dim = dimension of each attention head
B = batch_size = batch size
S = sequence_length = input sequence length
memory_model = 4 * R * N^2 * D^2
memory_activations = RBNS(S + 2D)
The total memory consumption for training the model is then:
M = (4 * R * N^2 * D^2) + 2 * RBNS(S + 2D)
If we have a very long sequence length, S >> D, so
S + 2D ≈ S
and M in this case becomes:
M = (4 * R * N^2 * D^2) + 2 * RBNS(S) = 4*R*N^2*D^2 + 2*R*B*N*S^2
M is directly proportional to the square of the input sequence length for large sequences.
M is linearly proportional to the batch size.
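As a quick sanity check of these two observations, here is a tiny sketch that plugs illustrative (assumed) values R=12, N=12, D=64, B=8 into the formula for M and compares the effect of doubling S versus doubling B.

```python
# M = 4*R*N^2*D^2 + 2*R*B*N*S*(S + 2*D), i.e. model memory plus twice the activation memory.
def M(R, N, D, B, S):
    return 4 * R * N**2 * D**2 + 2 * R * B * N * S * (S + 2 * D)

# Illustrative (assumed) values: R=12, N=12, D=64, B=8.
print(M(12, 12, 64, 8, 2048) / M(12, 12, 64, 8, 1024))   # ≈ 3.7: doubling S almost quadruples M
print(M(12, 12, 64, 16, 1024) / M(12, 12, 64, 8, 1024))  # ≈ 2.0: doubling B doubles M
```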
TL;DR
These are the rough formulas for estimating the memory requirements of fine-tuning transformer models:
R = n_tr_blocks = number of transformer blocks in the model
N = n_head = number of attention heads
D = dim = dimension of each attention head
B = batch_size = batch size
S = sequence_length = input sequence length
memory_model = 4 * R * N^2 * D^2
memory_activations = RBNS(S + 2D)
total memory required = ((4 * R * N^2 * D^2) + 2 * RBNS(S + 2D)) * memory of one float in bytes (e.g. 8 for float64, 4 for float32)
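To turn the TL;DR into something runnable, here is a minimal sketch combining the formulas into a byte estimate. The helper name estimate_total_memory_bytes and the parameter values in the example are illustrative assumptions.

```python
# Total memory estimate in bytes: (memory_model + 2 * memory_activations) * bytes per float.
def estimate_total_memory_bytes(R, N, D, B, S, bytes_per_float=4):
    memory_model = 4 * R * N**2 * D**2
    memory_activations = R * B * N * S * (S + 2 * D)
    return (memory_model + 2 * memory_activations) * bytes_per_float

# Illustrative (assumed) values: R=12, N=12, D=64, B=8, S=512, float32 tensors.
estimate = estimate_total_memory_bytes(R=12, N=12, D=64, B=8, S=512)
print(f"≈ {estimate / 1024**3:.1f} GiB")  # ≈ 2.9 GiB for these settings
```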
Insights
- Memory consumption is directly proportional to the square of the input sequence length for large sequences.
- Memory consumption is linearly proportional to the batch size.