Quantized LoRA

QLoRA further reduces the cost of finetuning LLMs by finetuning directly on a quantized model: the quantized base weights stay frozen, and gradient updates are applied only to the LoRA adapters.

Quick overview of LoRA

LoRA is a technique for finetuning LLMs without performing gradient updates on the entire parameter set; only a much, much smaller set of parameters is trained.

It relies on the assumption that the weight-update matrix ($\Delta W$) for each layer in an LLM has low rank. Formally, if a weight matrix in an LLM is $W_{a \times b}$, then

\[\mathrm{rank}(\Delta W_{a \times b}) \ll \min(a, b)\]

This can be exploited by introducing two smaller matrices $A_{a \times h}$ and $B_{h \times b}$ (where $h \ll \min(a, b)$) whose product $AB$ stands in for $\Delta W$. These act as adapters, and they are the only weights that receive gradient updates. The output of the layer can then be calculated as:

\[Y = XW + XAB\]

Here the weights in $W$ are kept frozen, while $A$ and $B$ are updated.
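As a rough illustration of the formula above, here is a minimal NumPy sketch of the forward pass. The function names, the sizes, and the choice to zero-initialize $B$ (so the adapter starts as a no-op, as is common in LoRA practice) are assumptions for this example, not something prescribed by the text.

```python
import numpy as np

def lora_forward(X, W, A, B):
    """Compute Y = XW + XAB, with W frozen and A, B trainable."""
    # Multiply X by A first so the full a x b product AB is never materialized.
    return X @ W + (X @ A) @ B

def trainable_fraction(a, b, h):
    """Adapter parameters (a*h + h*b) relative to the a*b parameters of W."""
    return (a * h + h * b) / (a * b)

rng = np.random.default_rng(0)
n, a, b, h = 4, 64, 32, 2            # illustrative sizes (assumed)
X = rng.normal(size=(n, a))          # layer input
W = rng.normal(size=(a, b))          # frozen pretrained weights
A = rng.normal(size=(a, h)) * 0.01   # trainable adapter factor
B = np.zeros((h, b))                 # zero init: Y equals XW before training

Y = lora_forward(X, W, A, B)
print(Y.shape)                            # (4, 32)
print(trainable_fraction(4096, 4096, 16)) # 0.0078125
```

For a hypothetical 4096 x 4096 layer with $h = 16$, the adapters hold under 1% of the parameters that full finetuning of $W$ would update.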

Quick overview of Quantization

With quantization, we reduce the precision used per parameter value to save memory and speed up calculations, without losing too much information. This is generally done by scaling floats into a discretized range, and doing so block-wise so that an outlier only distorts the scale of its own block. Not going into the details here, link attached below.
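To make the block-wise idea concrete, here is a minimal sketch assuming plain absmax scaling to signed integers. This is a simplified scheme chosen for illustration, and the helper names `quantize_blockwise` and `dequantize_blockwise` are hypothetical; real libraries pack the codes more tightly and may use different level spacings.

```python
import numpy as np

def quantize_blockwise(w, block_size=64, bits=4):
    """Absmax-quantize a 1-D float array, one scale per block."""
    levels = 2 ** (bits - 1) - 1                  # 7 for signed 4-bit codes
    blocks = w.reshape(-1, block_size)            # assumes len(w) % block_size == 0
    scales = np.abs(blocks).max(axis=1, keepdims=True)
    scales[scales == 0] = 1.0                     # guard against all-zero blocks
    # Codes are held in int8 here for simplicity, not packed two per byte.
    q = np.round(blocks / scales * levels).astype(np.int8)
    return q, scales

def dequantize_blockwise(q, scales, bits=4):
    """Map integer codes back to floats using each block's scale."""
    levels = 2 ** (bits - 1) - 1
    return (q.astype(np.float32) / levels * scales).reshape(-1)

rng = np.random.default_rng(0)
w = rng.normal(size=256).astype(np.float32)
w[3] = 50.0                                       # an outlier in the first block

q, scales = quantize_blockwise(w)
w_hat = dequantize_blockwise(q, scales)
print(scales.ravel())                             # only the first scale is 50.0
```

Because each block carries its own scale, the outlier coarsens the grid for its 64 neighbours only; the remaining blocks keep a fine-grained range.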

Quantization significantly speeds up inference and makes models accessible on a much wider range of machines (even Raspberry Pis, somehow). However, finetuning a quantized LLM results in significantly worse performance than finetuning before quantization (which is more expensive and less accessible).


Quantized LLMs can be combined with LoRA adapters 🥳

Essentially, QLoRA uses two precision types: 4-bit for storage (the frozen transformer weights) and 16-bit for computation (the stored weights are dequantized to 16-bit whenever they are used). Gradients are only computed for the LoRA adapters, which are always kept in 16-bit. Computationally:

\[Y^{16} = X^{16} dequant'(W^4) + X^{16} A^{16} B^{16}\]

Here the superscripts denote the bit width of each term. (The $dequant'$ function is not the standard dequantization, and is explained further on.)
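The equation above can be sketched in NumPy as follows. Note the assumptions: plain block-wise absmax rounding stands in for the real 4-bit format, the `dequant` helper is only a placeholder for $dequant'$ (which, as noted, differs from standard dequantization), the 4-bit codes are held in `int8` rather than packed, and all names and sizes are hypothetical.

```python
import numpy as np

LEVELS = 7  # signed 4-bit codes lie in [-7, 7] in this simplified scheme

def quantize_4bit(W, block_size=64):
    """Stand-in quantizer: block-wise absmax rounding to integer codes."""
    blocks = W.reshape(-1, block_size)      # assumes W.size % block_size == 0
    scales = np.abs(blocks).max(axis=1, keepdims=True)
    scales[scales == 0] = 1.0
    q = np.round(blocks / scales * LEVELS).astype(np.int8)
    return q, scales.astype(np.float16)

def dequant(q, scales, shape):
    """Placeholder for dequant': rebuild a 16-bit matrix from 4-bit codes."""
    return (q.astype(np.float16) / np.float16(LEVELS) * scales).reshape(shape)

def qlora_forward(X, q, scales, w_shape, A, B):
    """Y16 = X16 dequant(W4) + X16 A16 B16."""
    W16 = dequant(q, scales, w_shape)       # transient 16-bit copy of frozen W
    return X @ W16 + (X @ A) @ B            # only A and B would be trained

rng = np.random.default_rng(0)
n, a, b, h = 4, 64, 32, 4                   # illustrative sizes (assumed)
W = (rng.normal(size=(a, b)) * 0.1).astype(np.float32)
X = rng.normal(size=(n, a)).astype(np.float16)
A = (rng.normal(size=(a, h)) * 0.01).astype(np.float16)
B = np.zeros((h, b), dtype=np.float16)

q, scales = quantize_4bit(W)                # what is kept in memory long-term
Y = qlora_forward(X, q, scales, W.shape, A, B)
print(Y.dtype, Y.shape)                     # float16 (4, 32)
```

The point of the sketch is the division of labour: the low-precision codes and their scales are the only persistent copy of $W$, while every matrix product happens in 16-bit.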

Tricks used

The paper highlights some tricks and optimizations used to save memory and improve precision.