The history of computing innovation is marked by optimizations that seem obvious in hindsight but hid in plain sight for years.
FlashAttention (2022) is one such breakthrough. While many researchers focused on reducing FLOPs through approximation techniques, Tri Dao and the FlashAttention team recognized that the real bottleneck was memory traffic: redundant reads and writes between the GPU's high-bandwidth memory (HBM) and its on-chip SRAM. FlashAttention combined two classical techniques, kernel fusion and tiling, to achieve a wall-clock speedup when computing attention without sacrificing accuracy, unlike approximation methods.
FlashAttention-2 (2023) took this hardware-aware, IO-aware approach a step further, achieving roughly a 2x speedup over its predecessor.
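This article will expand on how FlashAttention-2 improved upon FlashAttention, discussing the following modifications to the algorithm in greater detail:
- Reducing non-matmul FLOPs
- Better work partitioning between warps
- Parallelizing over the sequence length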
We suggest reading the previous article on FlashAttention before proceeding. Familiarity with GPU performance optimization, the GPU memory hierarchy, and warps may also be helpful.
Maintaining a high throughput, the rate at which the system (GPU) can process data or perform operations, is crucial for handling increased workloads. To achieve a high throughput, programs must be designed to efficiently utilize computational resources.
For example, NVIDIA GPUs have specialized processing units called Tensor Cores that accelerate matrix multiplication. Floating-point operations that are not matrix multiplications (non-matmul FLOPs) cannot use Tensor Cores and therefore run much more slowly: on an A100, the theoretical peak is 312 TFLOPs/s for FP16/BF16 matmul but only 19.5 TFLOPs/s for non-matmul FP32 operations, so each non-matmul FLOP is roughly 16x more expensive. Reducing the number of non-matmul operations therefore keeps throughput high.
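To see why a small number of non-matmul FLOPs can matter, here is a back-of-the-envelope estimate. It is a rough sketch that assumes the A100 peak rates quoted above and ignores memory traffic and any overlap between the two kinds of work; the FLOP counts in the example are hypothetical.

```python
# Assumed A100 peak rates in FLOPs per second (spec-sheet values; real kernels reach less).
MATMUL_PEAK = 312e12       # FP16/BF16 matrix multiplication on Tensor Cores
NON_MATMUL_PEAK = 19.5e12  # FP32 non-matmul operations (no Tensor Cores)


def estimated_time(matmul_flops: float, non_matmul_flops: float) -> float:
    """Idealized compute time in seconds, ignoring memory traffic and overlap."""
    return matmul_flops / MATMUL_PEAK + non_matmul_flops / NON_MATMUL_PEAK


# How many matmul FLOPs cost the same time as one non-matmul FLOP.
relative_cost = MATMUL_PEAK / NON_MATMUL_PEAK

# Hypothetical workload: non-matmul work is only 1% of the matmul FLOP count.
t_matmul = estimated_time(1e12, 0.0)
t_total = estimated_time(1e12, 1e10)
share_non_matmul = (t_total - t_matmul) / t_total

print(f"one non-matmul FLOP costs about {relative_cost:.0f}x a matmul FLOP")
print(f"non-matmul share of estimated time: {share_non_matmul:.1%}")
```

With these assumed numbers, the non-matmul work is only 1% of the FLOPs yet takes about 14% of the estimated time, which is why trimming it pays off.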
This figure shows how most of the time in the attention computation is spent on non-matmul operations.
FlashAttention-2 reduces non-matmul FLOPs by tweaking parts of the algorithm in ways that do not change the final output. In particular, it adjusts how the online softmax is computed.
For more context on these modifications, this blog post by Aleksa Gordić and this article by Zihao Ye do an excellent job of breaking down the complexity of the algorithm.
In FlashAttention, softmax is computed one block at a time, using running statistics (the row max m and the sum of exponentials ℓ) to rescale the partial output after every block, so that the accumulated output is correct once all blocks have been processed. FlashAttention-2 instead maintains an unscaled version of the output throughout the loop and performs the rescaling only once, at the end of the loop, to get the correct output.
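The difference can be sketched for a single query row in plain Python. This is a simplified, scalar illustration (no tiling of Q, no GPU), and the function names are hypothetical. It shows that deferring the division by ℓ until after the loop gives the same output as renormalizing after every block, while doing fewer non-matmul operations inside the loop.

```python
import math


def attention_row_reference(scores, values):
    """Direct softmax(scores) @ values for one query row."""
    m = max(scores)
    w = [math.exp(s - m) for s in scores]
    z = sum(w)
    return [sum(wi * v[d] for wi, v in zip(w, values)) / z
            for d in range(len(values[0]))]


def attention_row_fa1_style(scores, values, block_size):
    """Renormalizes the output after every block (FlashAttention-like)."""
    m, l = -math.inf, 0.0          # running max and running sum of exponentials
    out = [0.0] * len(values[0])   # always kept normalized
    for start in range(0, len(scores), block_size):
        s_blk = scores[start:start + block_size]
        v_blk = values[start:start + block_size]
        m_new = max(m, max(s_blk))
        p = [math.exp(s - m_new) for s in s_blk]
        l_new = math.exp(m - m_new) * l + sum(p)
        # Undo the old normalization, add this block, renormalize: a division per step.
        out = [(l * math.exp(m - m_new) * o
                + sum(pi * v[d] for pi, v in zip(p, v_blk))) / l_new
               for d, o in enumerate(out)]
        m, l = m_new, l_new
    return out


def attention_row_fa2_style(scores, values, block_size):
    """Keeps an unscaled accumulator and divides by l once, after the loop."""
    m, l = -math.inf, 0.0
    acc = [0.0] * len(values[0])   # unnormalized output
    for start in range(0, len(scores), block_size):
        s_blk = scores[start:start + block_size]
        v_blk = values[start:start + block_size]
        m_new = max(m, max(s_blk))
        scale = math.exp(m - m_new)  # corrects old terms for the new running max
        p = [math.exp(s - m_new) for s in s_blk]
        l = scale * l + sum(p)
        acc = [scale * a + sum(pi * v[d] for pi, v in zip(p, v_blk))
               for d, a in enumerate(acc)]
        m = m_new
    return [a / l for a in acc]  # the only normalization


scores = [0.5, 2.0, -1.0, 3.0, 0.0]
values = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0], [1.0, -1.0], [0.5, 0.5]]
ref = attention_row_reference(scores, values)
fa1_out = attention_row_fa1_style(scores, values, block_size=2)
fa2_out = attention_row_fa2_style(scores, values, block_size=2)
print(ref, fa1_out, fa2_out, sep="\n")
```

All three functions agree up to floating-point rounding; the FlashAttention-2-style loop simply moves the division out of the loop body.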
Additionally, instead of storing both the max $m^{(j)}$ and the sum of exponentials $\ell^{(j)}$ for the backward pass, FlashAttention-2 stores only the logsumexp $L^{(j)} = m^{(j)} + \log\left(\ell^{(j)}\right)$, which combines both statistics into a single value.
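A minimal sketch of why one number per row is enough: since $L = m + \log \ell$, each softmax probability can be recomputed in the backward pass as $\exp(s - L)$ without keeping m and ℓ separately. This uses plain Python on a single row of scores, and the function names are hypothetical.

```python
import math


def logsumexp_stats(scores, block_size):
    """Streams over score blocks like the forward pass and returns only L = m + log(l)."""
    m, l = -math.inf, 0.0
    for start in range(0, len(scores), block_size):
        blk = scores[start:start + block_size]
        m_new = max(m, max(blk))
        l = math.exp(m - m_new) * l + sum(math.exp(s - m_new) for s in blk)
        m = m_new
    return m + math.log(l)


def recompute_probs(scores, lse):
    """Backward-pass style recomputation of softmax probabilities: P_j = exp(s_j - L)."""
    return [math.exp(s - lse) for s in scores]


scores = [1.0, -0.5, 2.5, 0.0, 1.5]
lse = logsumexp_stats(scores, block_size=2)
probs = recompute_probs(scores, lse)

# Direct softmax for comparison.
mx = max(scores)
exps = [math.exp(s - mx) for s in scores]
softmax = [e / sum(exps) for e in exps]
```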
Recall that a thread consists of the program's code, its current execution point, and the values of its variables and data structures. Threads are organized into thread blocks, and each thread block is executed by a streaming multiprocessor (SM), which runs hundreds of threads simultaneously.
Warp-level thread management is possible thanks to the single instruction, multiple thread (SIMT) execution model of NVIDIA GPUs, in which a group of 32 threads, called a warp, executes the same instruction together. Threads within a warp can cooperate on operations such as matrix multiplication, and different warps can communicate by reading from and writing to shared memory.
Assigning work to warps involves dividing large computations into smaller tasks and distributing them across these thread groups to be executed simultaneously. Poor allocation of tasks can result in redundant accesses to shared memory. FlashAttention-2 reduces shared memory accesses by partitioning the attention computation across warps more strategically.
Question | FlashAttention (Split-K) | FlashAttention-2 (Split-Q) |
---|---|---|
Which matrices are split among the 4 warps? | K and V | Q |
Which matrices are accessible to all 4 warps? | Q | K and V |
How is QK^T computed? | Each warp multiplies Q by its slice of K^T, producing a slice of the columns of QK^T. After multiplying by its slice of V, each warp holds only a partial sum of the output. | Each warp multiplies its slice of Q by K^T, producing its own slice of the rows of QK^T. |
Is synchronization and communication between warps necessary? | Yes. All 4 warps must write their intermediate results to shared memory, synchronize, and add up the intermediate results. | Not in the forward pass: each warp multiplies its slice of QK^T directly by V to get its part of the output. The backward pass still needs some synchronization because of the more complex dependencies between the inputs and gradients. |
What does this mean for speed? | The forward pass is slowed down by the extra shared memory reads/writes and synchronization. | Splitting Q while sharing K and V among warps eliminates the shared memory reads/writes between warps, yielding a speedup over FlashAttention in both the forward and backward passes. |
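The difference between the two partitioning schemes can be simulated with four "warps" as plain Python loops. This sketch only covers the output product O = PV (softmax normalization is omitted), models shared memory as a list of partial results, and assumes the matrix dimensions are divisible by the number of warps; the names and the 8 x 8 inputs are hypothetical, and this is not how the CUDA kernels are written.

```python
NUM_WARPS = 4


def matmul(a, b):
    """Plain-Python product of an (n x k) matrix a and a (k x m) matrix b."""
    return [[sum(a[i][t] * b[t][j] for t in range(len(b)))
             for j in range(len(b[0]))]
            for i in range(len(a))]


def pv_split_k(p, v):
    """FlashAttention-like: each warp owns a slice of the key/value dimension."""
    chunk = len(v) // NUM_WARPS
    shared_mem = []  # stands in for shared memory: one full-size partial per warp
    for w in range(NUM_WARPS):
        cols = range(w * chunk, (w + 1) * chunk)
        p_cols = [[row[c] for c in cols] for row in p]
        v_rows = [v[c] for c in cols]
        shared_mem.append(matmul(p_cols, v_rows))  # partial sum of the whole output
    # After a synchronization point, the partials from all warps must be added up.
    return [[sum(part[i][j] for part in shared_mem) for j in range(len(v[0]))]
            for i in range(len(p))]


def pv_split_q(p, v):
    """FlashAttention-2-like: each warp owns a slice of query rows and reads all of V."""
    chunk = len(p) // NUM_WARPS
    out = []
    for w in range(NUM_WARPS):
        # Each warp produces final output rows on its own; no cross-warp reduction.
        out.extend(matmul(p[w * chunk:(w + 1) * chunk], v))
    return out


# Hypothetical 8 x 8 block standing in for attention probabilities, and an 8 x 2 value block.
p = [[((i + 2 * j) % 5) / 10 for j in range(8)] for i in range(8)]
v = [[float(j), float(j % 3)] for j in range(8)]
out_k = pv_split_k(p, v)
out_q = pv_split_q(p, v)
```

Both schemes give the same output, but split-K needs every warp to write a full-size partial result and a reduction across warps, while split-Q produces each output row in exactly one warp.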
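Occupancy is the ratio of the number of active warps on a streaming multiprocessor to the maximum number of warps it supports. Memory-bound operations, like the softmax step of the attention computation, typically benefit from higher occupancy, because having more warps resident lets the streaming multiprocessor hide memory latency by switching to warps that are ready to run.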
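As a rough illustration, theoretical occupancy can be estimated from how many thread blocks fit on one streaming multiprocessor. This simplified sketch assumes A100 (compute capability 8.0) per-SM limits and ignores register usage, which also limits occupancy in practice; the kernel configurations in the example are hypothetical. For real kernels, the CUDA runtime's `cudaOccupancyMaxActiveBlocksPerMultiprocessor` gives an authoritative answer.

```python
# Assumed per-SM limits for an A100 (compute capability 8.0); register usage is ignored.
MAX_WARPS_PER_SM = 64
MAX_BLOCKS_PER_SM = 32
SHARED_MEM_PER_SM_KB = 164


def theoretical_occupancy(warps_per_block, smem_per_block_kb):
    """Resident warps / maximum warps, limited by warp slots, block slots and shared memory."""
    resident_blocks = min(
        MAX_BLOCKS_PER_SM,
        MAX_WARPS_PER_SM // warps_per_block,
        SHARED_MEM_PER_SM_KB // smem_per_block_kb,
    )
    return resident_blocks * warps_per_block / MAX_WARPS_PER_SM


# Hypothetical kernels with 4 warps (128 threads) per thread block:
print(theoretical_occupancy(4, 16))  # 10 blocks fit -> 40 / 64 = 0.625
print(theoretical_occupancy(4, 48))  # 3 blocks fit  -> 12 / 64 = 0.1875
```

Kernels that use a lot of shared memory per thread block, as tiled attention kernels do, can end up with few resident warps per streaming multiprocessor.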
The A100 GPU has 108 streaming multiprocessors, and scheduling is efficient only when there are enough thread blocks to keep them busy (roughly 80 or more). With fewer thread blocks, some streaming multiprocessors sit idle and the GPU's computing capacity is underutilized.
To increase occupancy, FlashAttention-2 adds parallelization over the sequence length. Here, parallelization refers to the simultaneous execution of independent tasks across thread blocks, with no synchronization between them.
Parallelized over | FlashAttention (and standard multi-head attention) | FlashAttention-2 |
---|---|---|
Batch size: the number of input sequences in a batch | ✔️ | ✔️ |
Number of heads: the number of attention heads | ✔️ | ✔️ |
Sequence length: the number of elements in an input sequence | | ✔️ |
FlashAttention parallelizes over the batch size and the number of heads, so the number of thread blocks is batch size × number of heads.
FlashAttention-2 additionally parallelizes over the sequence length by splitting each sequence's queries into blocks of rows. The number of thread blocks becomes batch size × number of heads × number of query blocks, where the number of query blocks is the sequence length divided by the query block size, rounded up.
Long sequences usually force a smaller batch size, since fewer input sequences fit in memory at once.
For FlashAttention, this means fewer thread blocks, because their count depends only on batch size × number of heads. Adding parallelization over the sequence length in FlashAttention-2 keeps the thread block count high even with small batches, allowing better utilization of the GPU's streaming multiprocessors.
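The effect on the number of thread blocks can be checked with a few lines of arithmetic. This sketch assumes one thread block per (sequence, head) pair for FlashAttention and one per (sequence, head, query block) for the FlashAttention-2 forward pass; the 128-row query block size and the long-context configuration are hypothetical values chosen for illustration.

```python
import math

NUM_SMS_A100 = 108


def thread_blocks_fa1(batch, num_heads):
    """One thread block per (sequence, head) pair."""
    return batch * num_heads


def thread_blocks_fa2(batch, num_heads, seq_len, block_m):
    """One thread block per (sequence, head, block of block_m query rows)."""
    return batch * num_heads * math.ceil(seq_len / block_m)


# Hypothetical long-context setting: batch size 1, 16 heads, 8192 tokens, 128-row Q blocks.
fa1_blocks = thread_blocks_fa1(1, 16)
fa2_blocks = thread_blocks_fa2(1, 16, 8192, 128)
print(fa1_blocks, fa2_blocks)
print(f"SMs with no FlashAttention thread block: {NUM_SMS_A100 - fa1_blocks}")
```

In this setting FlashAttention launches only 16 thread blocks for 108 streaming multiprocessors, while splitting the sequence into query blocks yields 1024 independent thread blocks.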
This figure shows how FlashAttention's outer loop (red arrows) iterates over blocks of the K and V matrices, loading them into SRAM. For each K/V block, the inner loop (blue arrows) iterates over blocks of the Q matrix, loading them into SRAM and writing the attention output back to HBM. FlashAttention-2 inverts this order.
Loop | FlashAttention | FlashAttention-2 |
---|---|---|
Outer | Over K,V blocks | Over Q blocks |
Inner | Over Q blocks | Over K,V blocks |
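The FlashAttention-2 loop order can be sketched in plain Python: the outer loop walks over blocks of Q, and the inner loop streams over blocks of K and V using the online softmax with a single final normalization. Because no outer-loop iteration reads another's results, processing the Q blocks in a different order gives the same output, which is what allows each Q block to be handed to a separate thread block. The block sizes and function names are hypothetical, the usual 1/sqrt(d) scaling and masking are omitted, and this is a CPU illustration rather than the CUDA kernel.

```python
import math
import random


def attention_reference(q, k, v):
    """Naive softmax(Q K^T) V, row by row (no 1/sqrt(d) scaling, for simplicity)."""
    out = []
    for qi in q:
        s = [sum(a * b for a, b in zip(qi, kj)) for kj in k]
        m = max(s)
        w = [math.exp(x - m) for x in s]
        z = sum(w)
        out.append([sum(wj * vj[d] for wj, vj in zip(w, v)) / z for d in range(len(v[0]))])
    return out


def process_q_block(q_blk, k, v, block_n):
    """Inner loop: stream over K/V blocks for one Q block (what one thread block does)."""
    d_v = len(v[0])
    m = [-math.inf] * len(q_blk)            # running row max
    l = [0.0] * len(q_blk)                  # running sum of exponentials
    acc = [[0.0] * d_v for _ in q_blk]      # unnormalized output rows
    for start in range(0, len(k), block_n):
        k_blk, v_blk = k[start:start + block_n], v[start:start + block_n]
        for r, qi in enumerate(q_blk):
            s = [sum(a * b for a, b in zip(qi, kj)) for kj in k_blk]
            m_new = max(m[r], max(s))
            scale = math.exp(m[r] - m_new)
            p = [math.exp(x - m_new) for x in s]
            l[r] = scale * l[r] + sum(p)
            acc[r] = [scale * acc[r][d] + sum(pj * vj[d] for pj, vj in zip(p, v_blk))
                      for d in range(d_v)]
            m[r] = m_new
    return [[a / l[r] for a in acc[r]] for r in range(len(q_blk))]


def flash2_forward(q, k, v, block_m, block_n, order=None):
    """Outer loop over Q blocks; iterations are independent, so any order works."""
    starts = list(range(0, len(q), block_m))
    if order is not None:
        starts = [starts[i] for i in order]
    out = [None] * len(q)
    for start in starts:
        out[start:start + block_m] = process_q_block(q[start:start + block_m], k, v, block_n)
    return out


random.seed(0)
n, d = 10, 3
q = [[random.uniform(-1, 1) for _ in range(d)] for _ in range(n)]
k = [[random.uniform(-1, 1) for _ in range(d)] for _ in range(n)]
v = [[random.uniform(-1, 1) for _ in range(d)] for _ in range(n)]
ref = attention_reference(q, k, v)
out = flash2_forward(q, k, v, block_m=4, block_n=3)
shuffled = flash2_forward(q, k, v, block_m=4, block_n=3, order=[2, 0, 1])
```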
Phil Tillet first introduced and implemented the optimizations of reversing the loop order and parallelizing along the sequence length dimension in Triton.
The figure depicts the forward and backward passes of the parallelization scheme. Here, workers correspond to thread blocks.
Forward pass: each thread block handles a block of rows of the attention matrix (one block of Q, iterated over in the outer loop).
Backward pass: each thread block handles a block of columns of the attention matrix (one block of K and V).
To recap, FlashAttention-2 achieved speedup over FlashAttention by reducing the number of non-matmul FLOPs to maintain high throughput, adding sequence-length parallelization to increase occupancy, and partitioning work between different warps of a thread block to reduce communication and shared memory reads/writes.
The success of FlashAttention and its successor FlashAttention-2 demonstrates that working with the hardware, rather than against it, yields superior results. By deeply understanding the systems we build upon instead of treating them as abstract black boxes, we can achieve major leaps in performance.
Thanks for learning with the DigitalOcean Community. Check out our offerings for compute, storage, networking, and managed databases.