If you weren’t already aware, the T in ChatGPT stands for Transformer, the linchpin architecture for developing state-of-the-art AI. Initially developed for machine translation, the Transformer is a neural network architecture that introduced self-attention in the paper “Attention Is All You Need.” Through layers of interconnected nodes, it builds an internal mathematical representation of the relationships and relevance within an input sequence, transforming it into an output sequence.
The advent of the Transformer architecture ushered in a new era of AI research focused on increasing the efficiency of its core mechanism, attention. Attention’s scalability is limited by time and memory complexity that scale quadratically, O(n^2), with the sequence length n. This is troublesome because efficiently modeling long sequences is essential for capturing the long-range dependencies needed to model lengthy texts, codebases, high-resolution images, and more. To address this, many researchers have been working on hardware-aware and memory-efficient algorithms, such as FlashAttention.
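For a back-of-the-envelope sense of that quadratic growth (toy numbers of my own, not from the paper), consider the size of the NxN score matrix that standard attention materializes per head and per batch element:

```python
# Size of the NxN attention score matrix (per head, per batch element), assuming fp16 storage.
for n in (1024, 4096, 8192):
    size_mib = n * n * 2 / 2**20     # 2 bytes per fp16 element
    print(f"N = {n:>4}: {size_mib:6.1f} MiB")
# N = 1024:    2.0 MiB
# N = 4096:   32.0 MiB
# N = 8192:  128.0 MiB
```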
The goal of this article is to highlight the concepts that made FlashAttention (2022) successful in achieving wall-clock speedup over the standard attention mechanism. The techniques leveraged in its second (2023) and third (2024) iterations will be covered in subsequent blog posts.
Familiarity with the following will help with understanding the topics presented in this article:
Modern accelerators, such as Hopper and Ampere GPUs, offer an abundance of floating-point operations per second, or FLOPS, a metric that indicates a device’s theoretical computational power. However, these same accelerators are limited by memory bandwidth, the rate at which data can be transferred between the GPU’s memory and its processing units. With this in mind, designing hardware-aware and memory-efficient algorithms for GPUs requires strategic consideration of how to best leverage the memory hierarchy and use as much of the theoretical maximum FLOPS as possible.
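To make “limited by memory bandwidth” concrete, here is a quick back-of-the-envelope ratio using approximate published A100 (Ampere) specs; the numbers are my own illustration, not from the FlashAttention paper. A kernel that performs fewer FLOPs per byte of data moved than this ratio is memory-bound rather than compute-bound:

```python
# Rough compute-to-bandwidth ratio ("ridge point") for an NVIDIA A100,
# using approximate published specs -- illustration only.
peak_flops = 312e12       # ~312 TFLOPS of FP16/BF16 tensor-core throughput
hbm_bandwidth = 2.0e12    # ~2 TB/s of HBM bandwidth (80 GB SXM variant)

ridge = peak_flops / hbm_bandwidth
print(f"~{ridge:.0f} FLOPs per byte moved are needed to stay compute-bound")
# Elementwise steps like softmax perform only a handful of FLOPs per byte,
# so their runtime is dominated by HBM traffic rather than arithmetic.
```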
FlashAttention is an excellent example of a hardware-aware and memory-efficient algorithm that enables longer context in Transformers by optimizing the attention mechanism for the hardware it’s computed on.
FlashAttention is introduced as an “IO-aware exact attention algorithm that uses tiling to reduce the number of reads/writes between GPU high bandwidth memory (HBM) and GPU on-chip SRAM.”
Let’s break that down.
The terminology surrounding GPU memory types can be confusing, with numerous terms often describing identical or overlapping concepts. FlashAttention involves two memory types: HBM and SRAM.
| Memory | AKA | Key characteristics |
|---|---|---|
| HBM (High Bandwidth Memory) | GPU memory, global memory | Slow, larger memory capacity |
| SRAM (Static Random-Access Memory) | L1 cache, shared memory | Fast, smaller memory capacity, on-chip |
Diagram from Aleksa Gordić’s YouTube video featuring FlashAttention author Tri Dao: Streaming multiprocessors (2) are in blue and contain compute units and SRAM. Global memory accesses to and from HBM are slow and should be minimized if possible.
It is worthwhile to build an understanding of how data is transferred in the GPU.
The self-attention calculation in matrix form, figure from The Illustrated Transformer by Jay Alammar.
Here’s a refresher on the variables involved in calculating the self-attention layer of the Transformer (a minimal code sketch follows the list of definitions).
Query (Q): The query vector is the current input or element for which attention will be computed. The vector is part of a query matrix of size Nxd, where N is the sequence length (on the order of 1K-8K) and d is the head dimension (typically 64-128).
Key (K): The key matrix has the same dimensions as the query matrix. The key vectors are multiplied by the query vectors to calculate the similarity scores.
Similarity Score (S): The similarity score is a measure of how similar the query is to each element in the sequence. Multiplying the query matrix with the transposed key matrix produces an NxN matrix of similarity scores.
Attention Probability (P in algorithm, A in diagram): The Attention Probability is a probability distribution computed by applying the softmax operation to the similarity scores, S. The softmax function normalizes the similarity scores, ensuring they are positive and sum up to 1.
Note that the S and P/A matrices are both intermediates and are therefore not depicted in the attention formula.
Value (V): The value vectors of the Nxd value matrix contain information about each element in the sequence; the matrix is multiplied by the attention probabilities to produce an Nxd output.
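Putting those pieces together, here is a minimal NumPy sketch of standard (non-fused) attention; the toy sizes are my own, and the comments mark the NxN intermediates (S and P) that a naive implementation materializes in HBM:

```python
import numpy as np

N, d = 1024, 64                      # sequence length, head dimension (toy sizes)
rng = np.random.default_rng(0)
Q = rng.standard_normal((N, d))
K = rng.standard_normal((N, d))
V = rng.standard_normal((N, d))

# Similarity scores: an N x N intermediate (written to HBM in the standard implementation).
S = Q @ K.T / np.sqrt(d)

# Attention probabilities: softmax over each row, another N x N intermediate.
P = np.exp(S - S.max(axis=-1, keepdims=True))
P /= P.sum(axis=-1, keepdims=True)

# Output: N x d, the only matrix the rest of the network actually needs.
O = P @ V
print(O.shape)  # (1024, 64)
```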
The standard attention algorithm as depicted in the FlashAttention paper. In step 1, the Q and K matrices are loaded from HBM to compute S, which is written back to HBM. In step 2, S is read from HBM, the softmax is applied, and the result is written back to HBM as P. This step takes the longest.
From Aleksa Gordić’s YouTube video featuring FlashAttention author Tri Dao: The diagram explains how reading and writing the intermediate matrices (S and A) is the main bottleneck when computing attention. Note that A in this diagram is the same thing as P in the algorithm above.
Now that we’ve established that the standard attention implementation lacks IO-awareness, with its redundant reads from and writes to slow GPU memory (HBM), let’s discuss the hurdles FlashAttention had to overcome to achieve IO-awareness.
FlashAttention boosts performance by fusing the attention computation into a single CUDA kernel. While kernel fusion may seem straightforward, the FlashAttention algorithm had to be carefully designed to ensure that the on-chip memory does not exceed hardware limits.
Tiling is a technique that involves partitioning data into smaller blocks, or “tiles”, that can fit into on-chip memory. Memory bandwidth requirements are reduced with tiling-assisted kernel fusion since data is transferred from global memory to the streaming multiprocessors only once per tile.
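To illustrate the idea (with toy sizes and block shapes of my own choosing, not the kernel’s actual configuration), here is a blocked matrix multiply in NumPy: only one small tile from each operand needs to be resident in fast on-chip memory while its partial product is accumulated, and the final result is unchanged:

```python
import numpy as np

def tiled_matmul(A, B, tile=128):
    """Blocked matrix multiply: each (tile x tile) pair of blocks is processed independently,
    mimicking how a fused kernel keeps tiles resident in SRAM while accumulating."""
    M, K = A.shape
    K2, N = B.shape
    assert K == K2
    C = np.zeros((M, N))
    for i in range(0, M, tile):
        for j in range(0, N, tile):
            for k in range(0, K, tile):
                # Only these small blocks need to be on-chip at once.
                C[i:i+tile, j:j+tile] += A[i:i+tile, k:k+tile] @ B[k:k+tile, j:j+tile]
    return C

rng = np.random.default_rng(0)
A, B = rng.standard_normal((512, 256)), rng.standard_normal((256, 384))
assert np.allclose(tiled_matmul(A, B), A @ B)   # tiling changes the order, not the result
```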
Tiling is particularly effective for associative operations like matrix multiplication. This property allows the computation to be reordered without affecting the final result, enabling efficient processing of smaller tiles. The softmax operation in self-attention, however, is not associative, meaning the order of the computations does matter.
Leveraging the online softmax trick to make softmax associative is arguably the key innovation of FlashAttention.
FlashAttention forward pass diagram from the FlashAttention-2 paper: To incrementally perform the softmax reduction, the attention computation is restructured as indicated by the figure. The inputs Q, K, V are split into blocks. Instead of materializing the intermediate matrices (S, A/P) in HBM, they are computed block by block in SRAM. Each block’s output is rescaled by the correct denominator (normalization factor) before the blocks are summed at the end, giving the same result as the standard attention implementation.
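Here is a simplified NumPy sketch of that restructuring for a single block of queries (my own simplification of the forward pass, ignoring the outer loop over query blocks and all GPU-specific details): as each K/V block streams in, the running row max m and running denominator l are updated and the partial output is rescaled, so the final result matches a full softmax without ever forming the NxN matrices:

```python
import numpy as np

def online_softmax_attention(q, K, V, block=256):
    """Attention for a block of queries q (n_q x d), streaming over K/V in blocks.
    Keeps only running statistics (m, l) and a partial output O per query row."""
    n_q, d = q.shape
    m = np.full(n_q, -np.inf)        # running row-wise max of the scores
    l = np.zeros(n_q)                # running softmax denominator
    O = np.zeros((n_q, d))           # running (rescaled) partial output

    for start in range(0, K.shape[0], block):
        Kb, Vb = K[start:start+block], V[start:start+block]
        Sb = q @ Kb.T / np.sqrt(d)                   # scores for this block only

        m_new = np.maximum(m, Sb.max(axis=-1))       # updated running max
        # Rescale previous accumulators to the new max, then add this block's contribution.
        correction = np.exp(m - m_new)
        Pb = np.exp(Sb - m_new[:, None])
        l = l * correction + Pb.sum(axis=-1)
        O = O * correction[:, None] + Pb @ Vb
        m = m_new

    return O / l[:, None]                            # final normalization

rng = np.random.default_rng(0)
N, d = 1024, 64
q, K, V = (rng.standard_normal((N, d)) for _ in range(3))

ref = np.exp(q @ K.T / np.sqrt(d))
ref = (ref / ref.sum(axis=-1, keepdims=True)) @ V
assert np.allclose(online_softmax_attention(q, K, V), ref)   # matches standard attention
```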
Redundant read/write operations are avoided by not storing the intermediate S and A/P matrices and instead recomputing them in the backward pass. This is done by storing the output O and the softmax normalization statistics (m, l), which are used to recompute the intermediate S and A/P matrices from the Q, K, V blocks in SRAM during the backward pass.
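Continuing the sketch above (again a simplification with hypothetical helper names, not the paper’s actual backward pass): if the forward pass saves the per-row statistics m and l alongside the output O, any block of P can be rebuilt in the backward pass from the corresponding Q and K blocks alone:

```python
import numpy as np

rng = np.random.default_rng(0)
N, d, block = 1024, 64, 256
Q, K = rng.standard_normal((N, d)), rng.standard_normal((N, d))

# Per-row statistics the forward pass saves alongside O. (S is materialized here only to
# produce reference values for the demo; the forward sketch above accumulates m and l blockwise.)
S = Q @ K.T / np.sqrt(d)
m = S.max(axis=-1)                        # row-wise max
l = np.exp(S - m[:, None]).sum(axis=-1)   # softmax denominator

def recompute_P_block(Qb, Kb, mb, lb):
    """Rebuild one block of P from Q/K blocks plus saved statistics -- no NxN matrix stored."""
    Sb = Qb @ Kb.T / np.sqrt(d)           # recomputed scores, block-sized, lives in SRAM
    return np.exp(Sb - mb[:, None]) / lb[:, None]

# The recomputed block matches the corresponding slice of the full P matrix.
P_full = np.exp(S - m[:, None]) / l[:, None]
assert np.allclose(recompute_P_block(Q[:block], K[:block], m[:block], l[:block]),
                   P_full[:block, :block])
```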
FlashAttention overview, diagram from HuggingFace
By cleverly reordering the attention computation with classical techniques like tiling and recomputation to exploit the asymmetric GPU memory hierarchy, FlashAttention sped up the attention mechanism and reduced memory usage from quadratic to linear in sequence length. This algorithm does an excellent job of demonstrating both the art and effectiveness of designing hardware-aware algorithms.