Designing Hardware-Aware Algorithms: FlashAttention

Published on October 15, 2024
Attention is All You Need

If you weren’t already aware, the T in ChatGPT stands for Transformer, the linchpin architecture for developing state-of-the-art AI. Initially developed for machine translation, the Transformer is a neural network architecture built around self-attention and introduced in the paper Attention Is All You Need. Through layers of interconnected nodes that build an internal mathematical representation of the relationships and relevance between elements, an input sequence is transformed into an output sequence.

If Attention is All You Need, Let’s Make it Better…

The advent of the Transformer architecture ushered in a new era of AI research focused on increasing the efficiency of its core mechanism, attention. Attention scales poorly because its time and memory complexity grow quadratically, O(n^2), with sequence length n. This is troublesome because efficiently modeling long sequences is essential for capturing the long-range dependencies found in lengthy texts, codebases, high-resolution images, etc. To handle this, many researchers have been working on hardware-aware and memory-efficient algorithms, such as FlashAttention.

Introduction

The goal of this article is to highlight the concepts that made FlashAttention (2022) successful in achieving wall-clock speedup over the standard attention mechanism. The techniques leveraged in its second (2023) and third (2024) iterations will be covered in subsequent blog posts.

Prerequisites

Familiarity with the following will help with understanding the topics presented in this article:

Designing Hardware-Aware And Memory-Efficient Algorithms

Modern accelerators, such as Hopper and Ampere GPUs, offer an abundance of FLOPS (floating-point operations per second), a metric that indicates a device’s theoretical computational power. However, these same accelerators are limited by memory bandwidth, the rate at which data can be transferred between the GPU’s memory and its processing units. With this in mind, designing hardware-aware and memory-efficient algorithms for GPUs requires strategic consideration of how to best leverage the memory hierarchy and use as much of the theoretical maximum FLOPS as possible.
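To make the compute-versus-bandwidth tension concrete, here is a rough back-of-the-envelope sketch. The hardware figures are illustrative assumptions (roughly an A100-class GPU), not measurements, and the helper names `machine_balance` and `is_memory_bound` are hypothetical. The FLOP count for softmax is also a loose estimate.

```python
def machine_balance(peak_flops: float, mem_bandwidth_bytes: float) -> float:
    """FLOPs the chip can execute in the time it takes to move one byte."""
    return peak_flops / mem_bandwidth_bytes


def is_memory_bound(flops: float, bytes_moved: float, balance: float) -> bool:
    """An operation is memory-bound when its FLOPs per byte fall below the balance."""
    return flops / bytes_moved < balance


# Assumed, illustrative hardware figures; substitute the numbers for your device.
PEAK_FLOPS = 312e12      # half-precision FLOPs per second (assumption)
HBM_BANDWIDTH = 1.5e12   # bytes per second (assumption)

balance = machine_balance(PEAK_FLOPS, HBM_BANDWIDTH)

# Element-wise softmax over an N x N half-precision matrix:
# only a handful of FLOPs per element, but every element is read and written.
N = 4096
softmax_flops = 5 * N * N   # rough estimate of operations per element (assumption)
softmax_bytes = 4 * N * N   # 2 bytes read + 2 bytes written per element

print(f"balance: {balance:.0f} FLOPs per byte")
print("softmax memory-bound:", is_memory_bound(softmax_flops, softmax_bytes, balance))
```

Under these assumptions the device can perform a couple of hundred floating-point operations in the time it takes to move a single byte, so an operation like softmax spends most of its time waiting on memory rather than computing.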

FlashAttention is an excellent example of a hardware-aware and memory-efficient algorithm that enables longer context in Transformers by optimizing the attention mechanism for the hardware it’s computed on.

FlashAttention (2022)

FlashAttention is introduced as an “IO-aware exact attention algorithm that uses tiling to reduce the number of reads/writes between GPU high bandwidth memory (HBM) and GPU on-chip SRAM.”

Let’s break that down.

GPU Memory: HBM & SRAM

The terminology surrounding GPU memory types can be confusing, with numerous terms often describing identical or overlapping concepts. FlashAttention involves two memory types: HBM and SRAM.

| Memory | AKA | Key characteristics |
| --- | --- | --- |
| HBM (High Bandwidth Memory) | GPU memory, global memory | Slow, larger memory capacity |
| SRAM (Static Random-Access Memory) | L1 cache, shared memory | Fast, smaller memory capacity, on-chip |

GPU Compute Model

Diagram from Aleksa Gordić’s YouTube video featuring FlashAttention author Tri Dao: Streaming multiprocessors (2) are in blue and contain compute units and SRAM. Global memory accesses to and from HBM are slow and should be minimized if possible.

It is worthwhile to build an understanding of how data is transferred in the GPU.

  1. Input starts out in HBM (GPU memory)
  2. Data moves into compute units & SRAM for computation
  3. Output written back to HBM

Computing Attention

The self-attention calculation in matrix form, figure from The Illustrated Transformer by Jay Alammar.

The Attention Line-up

Here’s a refresher of the variables involved in calculating the self-attention layer of the transformer.

Query (Q): The query vector is the current input or element for which attention will be computed. It is one row of the query matrix of size N×d, where N is the sequence length (on the order of 1K-8K) and d is the head dimension (typically 64-128).

Key (K): The key matrix has the same dimensions as the query matrix. The key vectors are multiplied by the query vectors to calculate the similarity scores.

Similarity Score (S): The similarity score is a measure of how similar the query is to each element in the sequence. Multiplying the query matrix by the transposed key matrix produces an N×N matrix of similarity scores.

Attention Probability (P in algorithm, A in diagram): The attention probability is a probability distribution computed by applying the softmax operation to each row of the similarity scores, S. The softmax function normalizes the similarity scores, ensuring they are positive and sum to 1.

Note that the S and P/A matrices are both intermediate matrices and are therefore not depicted in the formula.

Value (V): The value vectors of the N×d value matrix contain information about each element in the sequence and are multiplied by the attention probabilities to produce an N×d output.
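The variables above can be tied together in a minimal reference sketch. This is plain Python for readability rather than speed, and the function names are hypothetical. The 1/sqrt(d) scaling is the one from the original Transformer formula and is exposed as an optional `scale` argument, since the descriptions above define S simply as the product of Q and the transposed K.

```python
import math


def matmul(a, b):
    """Multiply two matrices stored as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def transpose(m):
    return [list(col) for col in zip(*m)]


def softmax_rows(s):
    """Numerically stable softmax applied to each row independently."""
    out = []
    for row in s:
        m = max(row)
        exps = [math.exp(x - m) for x in row]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out


def standard_attention(q, k, v, scale=None):
    """Return the intermediates S and P along with the output O."""
    d = len(q[0])
    if scale is None:
        scale = 1.0 / math.sqrt(d)  # scaling from the original Transformer formula
    s = [[x * scale for x in row] for row in matmul(q, transpose(k))]  # N x N
    p = softmax_rows(s)                                                # N x N
    o = matmul(p, v)                                                   # N x d
    return s, p, o


# Tiny example with N = 3 and d = 2.
q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
k = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
v = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
s, p, o = standard_attention(q, k, v)
print("S shape:", len(s), "x", len(s[0]))
print("O shape:", len(o), "x", len(o[0]))
```

Note how S and P are both N×N while the inputs and the output are only N×d; for long sequences the intermediates dominate.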

Standard Attention Algorithm

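Attention algorithm as depicted in the FlashAttention paper. In step 1, the Q and K matrices are loaded from HBM to compute S, which is written back to HBM. In step 2, S is read from HBM to have softmax applied to it, and the result is written to HBM as P; this step takes the longest. In step 3, P and V are loaded from HBM to compute the output O, which is written to HBM.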
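A rough tally of the HBM traffic helps show why the intermediates hurt. The sketch below counts bytes for a single attention head, assuming 2-byte (half-precision) elements, and includes the final step of the standard algorithm in which P and V are read to produce O. The function names are hypothetical and the count ignores caching effects.

```python
def matrix_bytes(rows, cols, bytes_per_element=2):
    """Size of a dense matrix in bytes (2 bytes per element assumes fp16/bf16)."""
    return rows * cols * bytes_per_element


def standard_attention_hbm_bytes(n, d, bytes_per_element=2):
    """Approximate HBM reads and writes for one head of standard attention."""
    nd = matrix_bytes(n, d, bytes_per_element)  # Q, K, V, or O
    nn = matrix_bytes(n, n, bytes_per_element)  # S or P
    step1 = nd + nd + nn   # read Q and K, write S
    step2 = nn + nn        # read S, write P
    step3 = nn + nd + nd   # read P and V, write O
    return step1 + step2 + step3


n, d = 8192, 64
s_bytes = matrix_bytes(n, n)
q_bytes = matrix_bytes(n, d)
total = standard_attention_hbm_bytes(n, d)

print(f"one N x d matrix: {q_bytes / 2**20:.0f} MiB")
print(f"one N x N matrix: {s_bytes / 2**20:.0f} MiB")
print(f"total HBM traffic: {total / 2**20:.0f} MiB")
```

With N = 8192 and d = 64, each N×N intermediate is 128 times larger than an input matrix, so nearly all of the traffic comes from writing and re-reading S and P.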

From Aleksa Gordić’s YouTube video featuring FlashAttention author Tri Dao: The diagram explains how reading and writing the intermediate matrices (S and A) is the main bottleneck when computing attention. Note that A in this diagram is the same thing as P in the algorithm above.

FlashAttention is IO-aware

Now that we’ve established that the standard attention implementation lacks IO-awareness with its redundant reads and writes from slow GPU memory (HBM), let’s discuss the hurdles FlashAttention had to overcome to achieve IO-awareness.

Kernel Fusion

FlashAttention boosts performance by fusing the attention computation into a single CUDA kernel. While kernel fusion may seem straightforward, the FlashAttention algorithm had to be carefully designed to ensure that its on-chip memory usage does not exceed hardware limits.

Tiling

Tiling is a technique that involves partitioning data into smaller blocks, or “tiles”, that can fit into on-chip memory. Memory bandwidth requirements are reduced with tiling-assisted kernel fusion since data is transferred from global memory to the streaming multiprocessors only once per tile.
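As a sketch of how tile sizes might be chosen so that a tile fits on-chip, the snippet below uses the block-size rule given in Algorithm 1 of the FlashAttention paper (a column block of ceil(M / 4d) and a row block capped at d, where M is the on-chip memory size in elements). The 96 KiB SRAM budget and the function name `flash_block_sizes` are assumptions for illustration only.

```python
import math


def flash_block_sizes(sram_elements, head_dim):
    """Block sizes (rows of Q per tile, rows of K/V per tile) for a given SRAM budget."""
    b_c = math.ceil(sram_elements / (4 * head_dim))  # K/V block length
    b_r = min(b_c, head_dim)                         # Q/O block length
    return b_r, b_c


SRAM_BYTES = 96 * 1024    # assumed usable on-chip memory per streaming multiprocessor
BYTES_PER_ELEMENT = 2     # half precision
HEAD_DIM = 64

sram_elements = SRAM_BYTES // BYTES_PER_ELEMENT
b_r, b_c = flash_block_sizes(sram_elements, HEAD_DIM)

# Elements resident on-chip for one tile: Q and O blocks, K and V blocks, and the score block.
tile_elements = 2 * b_r * HEAD_DIM + 2 * b_c * HEAD_DIM + b_r * b_c

print("block sizes (B_r, B_c):", (b_r, b_c))
print("tile elements:", tile_elements, "of", sram_elements)
```

The point of the calculation is that the working set of one tile depends only on the block sizes and the head dimension, not on the sequence length N.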

Tiling is particularly effective for associative operations like matrix multiplication. Associativity allows the computation to be reordered without affecting the final result, enabling efficient processing of smaller tiles. The softmax operation in self-attention, however, is not associative, meaning the order of the computations does matter: each row is normalized by a sum over all of its scores, so a tile cannot be finalized in isolation.
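The contrast can be seen in a small sketch. The first half accumulates a matrix product one tile of the shared dimension at a time and gets the same answer as the untiled product. The second half applies softmax to two halves of a row separately, which does not reproduce the softmax of the whole row. The function names here are hypothetical.

```python
import math


def matmul(a, b):
    """Multiply two matrices stored as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def tiled_matmul(a, b, tile):
    """Accumulate a @ b one tile of the shared dimension at a time."""
    rows, cols = len(a), len(b[0])
    out = [[0] * cols for _ in range(rows)]
    for start in range(0, len(b), tile):
        a_tile = [row[start:start + tile] for row in a]  # rows x tile
        b_tile = b[start:start + tile]                   # tile x cols
        partial = matmul(a_tile, b_tile)
        for i in range(rows):
            for j in range(cols):
                out[i][j] += partial[i][j]  # partial sums can be added in any order
    return out


def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


a = [[1, 2, 3, 4], [5, 6, 7, 8]]
b = [[1, 0], [0, 1], [2, 3], [4, 5]]
full = matmul(a, b)
tiled = tiled_matmul(a, b, tile=2)
print("tiled matmul matches:", tiled == full)

row = [1.0, 2.0, 3.0, 4.0]
whole = softmax(row)
pieces = softmax(row[:2]) + softmax(row[2:])  # each half normalized on its own
print("tiled softmax matches:", pieces == whole)
print("sum of naive pieces:", sum(pieces))
```

Each half of the naively tiled softmax sums to 1 on its own, so the pieces together sum to 2 instead of 1; the normalization needs information from the entire row.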

Making Softmax Associative

Leveraging the online softmax trick to make softmax associative is arguably the key innovation of FlashAttention.

FlashAttention forward pass diagram from the FlashAttention-2 paper: To incrementally perform the softmax reduction, the attention computation is restructured as indicated by the figure. The inputs Q, K, V are split into blocks. Instead of materializing the intermediate matrices (S, A/P) in HBM, they are computed in SRAM. Each block’s output is rescaled to the correct denominator (normalization factor) before the outputs are added up at the end, giving the same result as the standard attention implementation.
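Here is a minimal sketch of the online softmax idea for a single query row, written in plain Python. It is a simplified illustration rather than the paper's kernel: it keeps a running maximum `m`, a running sum of exponentials `l`, and an unnormalized output accumulator, rescaling the earlier partial results whenever a new block raises the maximum. Scores are plain dot products here; any scaling would be applied to the scores before this step. The function names are hypothetical.

```python
import math


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def attention_row_reference(q, keys, values):
    """Standard attention for one query: needs every score before normalizing."""
    scores = [dot(q, k) for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    return [sum(w * v[c] for w, v in zip(weights, values)) / total
            for c in range(len(values[0]))]


def attention_row_online(q, keys, values, block_size):
    """Process K/V one block at a time, never holding the full score row."""
    m = -math.inf                   # running maximum of the scores seen so far
    l = 0.0                         # running sum of exp(score - m)
    acc = [0.0] * len(values[0])    # unnormalized output accumulator
    for start in range(0, len(keys), block_size):
        k_blk = keys[start:start + block_size]
        v_blk = values[start:start + block_size]
        scores = [dot(q, k) for k in k_blk]
        m_new = max(m, max(scores))
        correction = math.exp(m - m_new)  # rescales earlier results; 0.0 on the first block
        p = [math.exp(s - m_new) for s in scores]
        l = l * correction + sum(p)
        acc = [a * correction + sum(p_j * v[c] for p_j, v in zip(p, v_blk))
               for c, a in enumerate(acc)]
        m = m_new
    return [a / l for a in acc], m, l


q = [1.0, 0.5]
keys = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [-1.0, 2.0], [0.5, 0.5]]
values = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0]]

expected = attention_row_reference(q, keys, values)
for block_size in (1, 2, 5):
    got, m, l = attention_row_online(q, keys, values, block_size)
    diff = max(abs(g - e) for g, e in zip(got, expected))
    print(f"block_size={block_size}: max difference {diff:.2e}")
```

Whatever the block size, the blockwise result agrees with the reference up to floating-point rounding, which is why FlashAttention can be described as an exact attention algorithm rather than an approximation.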

Recomputation in the Backward Pass

Redundant read/write operations are avoided by not storing the intermediate S and A/P matrices and instead recomputing them in the backward pass. To make this possible, the forward pass stores the output O and the softmax normalization statistics (m, l), from which S and A/P are recomputed from the Q, K, V blocks in SRAM during the backward pass.
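The recomputation step can be sketched for a single query row as follows. Assuming the forward pass saved the row maximum `m` and the sum of exponentials `l`, any block of attention probabilities can be rebuilt later from the query and the corresponding key block alone, without the full N×N matrix ever having been stored. The function names are hypothetical and the gradient computation itself is omitted.

```python
import math


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def forward_statistics(q, keys):
    """What the forward pass keeps for one query row: max score and sum of exponentials."""
    scores = [dot(q, k) for k in keys]
    m = max(scores)
    l = sum(math.exp(s - m) for s in scores)
    return m, l


def recompute_probability_block(q, k_blk, m, l):
    """Rebuild one block of attention probabilities from Q, a K block, and (m, l)."""
    return [math.exp(dot(q, k) - m) / l for k in k_blk]


def softmax(xs):
    mx = max(xs)
    exps = [math.exp(x - mx) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


q = [1.0, 0.5]
keys = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [-1.0, 2.0], [0.5, 0.5]]

# Forward pass: store only two numbers for this row instead of N probabilities.
m, l = forward_statistics(q, keys)

# Backward pass: rebuild the probabilities block by block.
block_size = 2
recomputed = []
for start in range(0, len(keys), block_size):
    recomputed += recompute_probability_block(q, keys[start:start + block_size], m, l)

reference = softmax([dot(q, k) for k in keys])
print("max difference:", max(abs(r - e) for r, e in zip(recomputed, reference)))
```

The trade is extra arithmetic in exchange for fewer HBM accesses, which pays off on hardware where compute is plentiful and memory bandwidth is the bottleneck.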

FlashAttention overview, diagram from HuggingFace.

Conclusion

By cleverly reordering the attention computation with classical techniques like tiling and recomputation to exploit the asymmetric GPU memory hierarchy, FlashAttention sped up the attention mechanism and reduced memory usage from quadratic to linear in sequence length. This algorithm does an excellent job of demonstrating both the art and effectiveness of designing hardware-aware algorithms.
