While fine-tuning GPT/BERT based models for custom tasks, I often found myself in the unfortunate situation of "CUDA out of memory". It turns out that transformer models are memory intensive. On top of that, their memory requirements grow with the sequence length.

It is useful to have some idea of how much memory is needed to fine-tune/train a model. Even a rough estimate helps in planning the resources needed for the task.

If you are short on time or don't want to go into the details, you can skip to the TL;DR and insights.

All neural networks are trained with backpropagation. Keeping this in mind, the following simple relation gives us the memory usage:

total_memory = memory_model + memory_activations + memory_gradients

Here memory_model is the memory required to store all parameters of the model. Activations are computed and stored during the forward pass. Gradients are then computed from these activations during the backward pass, and for a rough estimate the memory they need is about the same as the memory needed for the activations, i.e. memory_gradients ≈ memory_activations. Hence, we can write:

total_memory = memory_model + 2 * memory_activations

In essence, we need to find the values of memory_model and memory_activations to estimate the total memory required.
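As a minimal sketch in Python (the function and argument names here are just for illustration), this relation is simply:

def total_memory(memory_model, memory_activations):
    # Gradients are assumed to take roughly as much memory as the activations,
    # which is where the second memory_activations term comes from.
    return memory_model + 2 * memory_activations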

Estimating the model's memory requirements

Let's take GPT as an example. GPT consists of a number of transformer blocks (let's call this number n_tr_blocks from now on). Each transformer block has the following structure:

multi_headed_attention --> layer_normalization --> MLP --> layer_normalization

Each multi_headed_attention element consists of value nets, keys and queries. Let's say that each of these has n_head attention heads, each of dimension dim. The MLP also has a dimension of n_head * dim. The memory needed to store these weights will be:

memory_per_block = memory of multi_headed_attention + memory of MLP
			 = memory of value nets + memory of keys + memory of queries + memory of MLP
			 = square_of(n_head * dim) + square_of(n_head * dim) + square_of(n_head * dim) + square_of(n_head * dim)
			 = 4 * square_of(n_head * dim)
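As a quick sanity check of this per-block count, here is a small PyTorch sketch (assuming torch is installed; biases are dropped to match the estimate above):

import torch.nn as nn

n_head, dim = 12, 64
hidden = n_head * dim  # 768

# One block, approximated as 4 weight matrices of size hidden x hidden:
# the query, key and value projections plus the MLP.
block = nn.ModuleList([nn.Linear(hidden, hidden, bias=False) for _ in range(4)])

print(sum(p.numel() for p in block.parameters()))  # 2359296 == 4 * 768**2

Real GPT blocks also have an attention output projection and a wider MLP, which this rough count ignores, but the scaling with n_head * dim is the same.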

Since our model contains n_tr_blocks of these blocks, the total memory required by the model becomes:

memory_model = 4 * n_tr_blocks * square_of(n_head * dim)

The above estimate ignores the memory required for biases, since it is small compared to the weights and does not depend on things like batch size or input sequence length.
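Written as a small Python helper (a sketch; the name memory_model simply mirrors the formula above and counts parameters rather than bytes):

def memory_model(n_tr_blocks, n_head, dim):
    # 4 weight matrices of size (n_head * dim) x (n_head * dim) per block.
    return 4 * n_tr_blocks * (n_head * dim) ** 2

print(memory_model(n_tr_blocks=12, n_head=12, dim=64))  # 28311552 parameters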

Estimating the memory requirements of model activations

At its core, multi-headed attention is a softmax over scaled dot products. More specifically, it can be written as:

multi_headed_attention = softmax(query * transpose(key) / square_root(dim)) * value_net

query, key and value_net all have a tensor shape of

[batch_size, n_head, sequence_length, dim]

The query * transpose(key) operation gives a result of the following shape:

[batch_size, n_head, sequence_length, sequence_length]

This gives the memory cost of storing the softmax output as:

memory_softmax  = batch_size * n_head * square_of(sequence_length)

Multiplying the softmax output by value_net gives a tensor of shape [batch_size, n_head, sequence_length, dim]. The MLP output has the same shape. So the memory cost of these operations becomes:

memory of MLP  = batch_size * n_head * sequence_length * dim
memory of value_net  = batch_size * n_head * sequence_length * dim
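To sanity-check these shapes, here is a tiny PyTorch sketch (the sizes are arbitrary, just to make the shapes visible):

import torch

batch_size, n_head, sequence_length, dim = 2, 12, 128, 64
query = torch.randn(batch_size, n_head, sequence_length, dim)
key = torch.randn(batch_size, n_head, sequence_length, dim)
value_net = torch.randn(batch_size, n_head, sequence_length, dim)

scores = query @ key.transpose(-2, -1)           # [2, 12, 128, 128] -> the S * S softmax term
attention = scores.softmax(dim=-1) @ value_net   # [2, 12, 128, 64]  -> the S * dim term
print(scores.shape, attention.shape)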

This gives us the memory of the model activations per block:

mem_act = memory_softmax + memory_value_net + memory_MLP
		= batch_size * n_head * square_of(sequence_length)
		  + batch_size * n_head * sequence_length * dim
		  + batch_size * n_head * sequence_length * dim
		= batch_size * n_head * sequence_length * (sequence_length + 2*dim)

The activation memory across all n_tr_blocks blocks of the model will be:

memory_activations = n_tr_blocks * batch_size * n_head * sequence_length * (sequence_length + 2*dim)
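As a Python helper (a sketch that counts activation elements, mirroring the formula above; the numbers in the example call are purely illustrative):

def memory_activations(n_tr_blocks, batch_size, n_head, sequence_length, dim):
    # Per block: the softmax scores (S * S per head) plus the attention output
    # and the MLP output (each S * dim per head).
    per_block = batch_size * n_head * sequence_length * (sequence_length + 2 * dim)
    return n_tr_blocks * per_block

print(memory_activations(12, 8, 12, 1024, 64))  # 1358954496 (~1.36e9) activation values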

Summing it all up

To sum up, the total memory needed for fine-tuning/training a transformer model is:

total_memory = memory_model + 2 * memory_activations

The memory for the model is:

memory_model = 4 * n_tr_blocks * square_of(n_head * dim)

And the memory for the model activations is:

memory_activations = n_tr_blocks * batch_size * n_head * sequence_length * (sequence_length + 2*dim)

These rough formulas can be written more succinctly using the following notation.

R = n_tr_blocks = number of transformer blocks in the model
N = n_head = number of attention heads
D = dim = dimension of each attention head
B = batch_size = batch size
S = sequence_length = input sequence length



memory model = 4 * R * N^2 * D^2

memory activations = RBNS(S + 2D)

The total memory consumption for model training is then:

M = 4 * R * N^2 * D^2 + 2 * RBNS(S + 2D)

For very long sequences we have S >> D, so S + 2D ≈ S, and M in this case becomes:

M = 4 * R * N^2 * D^2 + 2 * RBNS(S) = 4*R*N^2*D^2 + 2*RBNS^2

For large sequences, M is proportional to the square of the input sequence length.
M is linearly proportional to the batch size.
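A quick numerical check of these two claims, using the rough M from above with an illustrative configuration (values chosen only for demonstration):

R, N, D, B = 12, 12, 64, 8

def M(B, S):
    # Rough total: model term + 2 * activation term (counted in floats, not bytes).
    return 4 * R * N**2 * D**2 + 2 * R * B * N * S * (S + 2 * D)

print(M(B, 2048) / M(B, 1024))      # ~3.7: doubling S roughly quadruples the total
print(M(2 * B, 1024) / M(B, 1024))  # ~2.0: doubling B roughly doubles the total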

TL;DR

These are the rough formulas for estimating the memory requirements of fine-tuning transformer models:

R = n_tr_blocks = number of transformer blocks in the model
N = n_head = number of attention heads
D = dim = dimension of each attention head
B = batch_size = batch size
S = sequence_length = input sequence length



memory model = 4 * R * N^2 * D^2

memory activations = RBNS(S + 2D)

total memory required = (4 * R * N^2 * D^2 + 2 * RBNS(S + 2D)) * memory of one float in bytes (8 for float64, 4 for float32)
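Putting it together as a rough back-of-the-envelope estimator (a sketch, not a measurement: it ignores embeddings, optimizer state, and CUDA/framework overhead, so treat the result as a lower bound; the function name and example numbers are illustrative):

def estimate_training_memory_gb(n_tr_blocks, n_head, dim,
                                batch_size, sequence_length,
                                bytes_per_float=4):  # 4 for float32, 8 for float64
    R, N, D = n_tr_blocks, n_head, dim
    B, S = batch_size, sequence_length
    memory_model = 4 * R * N**2 * D**2
    memory_activations = R * B * N * S * (S + 2 * D)
    total_floats = memory_model + 2 * memory_activations
    return total_floats * bytes_per_float / 1024**3

# Example: 12 blocks, 12 heads of dim 64, batch size 8, sequence length 1024.
print(estimate_training_memory_gb(12, 12, 64, 8, 1024))  # ~10.2 GB in float32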

Insights

  1. Memory consumption is directly proportional to the square of the input sequence length for large sequences.

  2. Memory consumption is linearly proportional to the batch size.