Abstract
We implement Flash Attention (Dao et al., 2022) in the Needle
deep learning framework, replacing the materialized N×N attention score matrix with an online
softmax recurrence that reduces working memory from O(N²) to O(N). Our contributions include a
tiled CUDA forward and backward pass wired into Needle's autograd engine, a portable C reference
implementation for CPU testing, and a drop-in use_flash_attention flag on MultiHeadAttention.
Correctness is verified against pre-computed reference outputs across batch sizes, sequence lengths,
head configurations, and causal masking modes.
Background: The Quadratic Memory Problem
Standard attention computes O = softmax(QKᵀ/√d)·V, which requires storing the full N×N score matrix: O(N²) memory per head. Flash Attention avoids this entirely by streaming over key/value blocks and maintaining only two running scalars per query row. The output is exact, matching standard attention up to floating-point rounding; it is not an approximation.
| Sequence Length N | Standard (N×N matrix) | Flash (stats only) | Reduction |
|---|---|---|---|
| 128 | 64 KB | 1 KB | 64× |
| 512 | 1 MB | 4 KB | 256× |
| 1 024 | 4 MB | 8 KB | 512× |
| 4 096 | 64 MB | 32 KB | 2 048× |
Per head (B=1, H=1), float32. Standard: N²×4 B. Flash: N×2×4 B (running max m and normalizer l).
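The table is just the footnote's arithmetic; a quick sanity check in Python:

```python
# Per-head working memory, float32 (4 bytes per value).
for n in (128, 512, 1024, 4096):
    standard = n * n * 4       # full N x N score matrix
    flash = n * 2 * 4          # running (m, l) per query row
    print(f"N={n}: standard {standard // 1024} KB, "
          f"flash {flash // 1024} KB, {standard // flash}x")
```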
Algorithm
Instead of computing all scores first and normalizing afterward, Flash Attention maintains a running maximum m_i and normalizer l_i per query row. When a larger score arrives, the accumulated output is rescaled by exp(m_prev − m_i), keeping results numerically stable without an N×N buffer.
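A minimal NumPy sketch of the recurrence for one query row (illustrative math only, not the project's kernels; `online_row` is a hypothetical name):

```python
import numpy as np

def online_row(q, K, V):
    """Attention for one query row using only running (m, l) and an accumulator."""
    m, l, acc = -np.inf, 0.0, np.zeros(V.shape[1])
    for j in range(K.shape[0]):                  # stream over key positions
        s = q @ K[j] / np.sqrt(q.shape[0])       # one score at a time
        m_new = max(m, s)
        c, w = np.exp(m - m_new), np.exp(s - m_new)
        l = l * c + w                            # rescale old normalizer
        acc = acc * c + w * V[j]                 # rescale old output sum
        m = m_new
    return acc / l                               # normalize once at the end

# Agrees with standard softmax attention up to float rounding:
rng = np.random.default_rng(0)
q, K, V = rng.normal(size=64), rng.normal(size=(33, 64)), rng.normal(size=(33, 64))
s = q @ K.T / np.sqrt(64)
p = np.exp(s - s.max()); p /= p.sum()
assert np.allclose(online_row(q, K, V), p @ V)
```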
The backward pass reuses the saved (m, l) to recompute attention weights without
storing the N×N probability matrix. Two CUDA kernels run sequentially:
(1) ComputeDelta: row-wise δ_i = Σ_j P_ij · (∂O_i · V_j), which simplifies to ∂O_i · O_i;
(2) AccumulateGradients: ∂Q, ∂K, ∂V via recomputed P_ij, using atomicAdd for the
cross-block updates to ∂K and ∂V.
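The same two steps in NumPy (vectorized for readability, so P appears in full here; the kernels recompute it tile by tile, and `flash_backward` is an illustrative name):

```python
import numpy as np

def flash_backward(Q, K, V, O, dO, m, l):
    """Backward for one head from saved per-row stats (m, l)."""
    d = Q.shape[1]
    P = np.exp(Q @ K.T / np.sqrt(d) - m[:, None]) / l[:, None]  # recomputed weights
    delta = (dO * O).sum(axis=1)              # ComputeDelta: sum_j P_ij (dO_i . V_j)
    dS = P * (dO @ V.T - delta[:, None])      # softmax backward using delta
    dQ = dS @ K / np.sqrt(d)
    dK = dS.T @ Q / np.sqrt(d)                # per-tile, these are atomicAdd targets
    dV = P.T @ dO
    return dQ, dK, dV
```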
Implementation
Four layers from user API to hardware, with the CUDA path for training and a standalone C path for correctness testing:
- nn.TransformerLayer · nn.AttentionLayer · nn.MultiHeadAttention (user-facing modules)
- ops.flash_attention(Q, K, V, causal=True): caches m, l on the output NDArray for backward
- ndarray_backend_cuda.cu (CUDA kernels)
- flash_attention_cpu_forward() in c_api/ · optional OpenMP
The forward CUDA kernel launches on a dim3(batch, heads) grid with block width
head_dim: each thread owns one element of the head dimension and accumulates its
output in a register while streaming over all key positions.
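Rendered as Python loops, the launch geometry looks like this (a sketch of the mapping, not the kernel source; `forward_grid` is a hypothetical name):

```python
import numpy as np

def forward_grid(Q, K, V):                      # shapes: (batch, heads, N, head_dim)
    B, H, N, D = Q.shape
    O = np.empty_like(Q)
    for b in range(B):                          # blockIdx.x
        for h in range(H):                      # blockIdx.y
            for i in range(N):                  # each block walks its query rows
                m, l, acc = -np.inf, 0.0, np.zeros(D)  # acc[t] is thread t's register
                for j in range(N):              # all D threads stream key j together
                    # on the GPU, this dot product is a reduction across the D threads
                    s = Q[b, h, i] @ K[b, h, j] / np.sqrt(D)
                    m_new = max(m, s)
                    c, w = np.exp(m - m_new), np.exp(s - m_new)
                    l, acc, m = l * c + w, acc * c + w * V[b, h, j], m_new
                O[b, h, i] = acc / l            # thread t writes element t
    return O
```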
FlashAttention extends Needle's TensorOp, attaching (m, l) to the
output so the backward pass can recompute attention weights without re-running the forward.
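A minimal sketch of that op, assuming Needle's standard TensorOp interface; the backend binding `flash_attention_forward`, the helper `flash_attention_grad`, and the `_flash_ml` attribute are placeholders, not the repo's exact names:

```python
from needle.autograd import TensorOp

class FlashAttention(TensorOp):
    def __init__(self, causal=True):
        self.causal = causal

    def compute(self, Q, K, V):
        # Placeholder backend binding: returns O plus per-row stats (m, l).
        O, m, l = flash_attention_forward(Q, K, V, self.causal)
        O._flash_ml = (m, l)          # cached so backward skips a second forward
        return O

    def gradient(self, out_grad, node):
        Q, K, V = node.inputs
        m, l = node.realize_cached_data()._flash_ml
        # A FlashAttentionGrad-style op recomputes P from (m, l),
        # returning (dQ, dK, dV).
        return flash_attention_grad(out_grad, Q, K, V, m, l, self.causal)
```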
API & Module Map
| Component | Key parameters / behavior |
|---|---|
| nn.TransformerLayer | hidden_size, dropout=0., causal=True |
| nn.AttentionLayer | use_flash_attention=False, causal=True |
| nn.MultiHeadAttention | use_flash_attention=False |
| FlashAttentionGrad(TensorTupleOp) | retrieves cached m/l, calls CUDA backward |
| FlashAttentionBackwardKernel | ComputeDeltaKernel + gradient accumulation |
| flash_attention_cpu_forward() | returns FLASH_ATTENTION_OK or error code |
| C reference build | max head_dim 512 · optional -DUSE_OPENMP |
Full documentation: docs/USER.md · docs/DEVELOPER.md · man flash_attention(3)
Results
CPU Reference Benchmark
Single-threaded, causal masking, B=1, H=4, D=64. Run with make example-c.
| N | Time | N² | Ratio vs N=128 |
|---|---|---|---|
| 128 | 0.0013 s | 16 384 | 1× |
| 512 | 0.0205 s | 262 144 | 15.8× |
| 1 024 | 0.0554 s | 1 048 576 | 42.6× |
Scales as O(N²·D), as expected. The theoretical ratios are 16× and 64×; the observed 15.8× and 42.6× are lower, likely due to cache effects at small N.
Correctness Verification
The Python test suite validates MultiHeadAttention, AttentionLayer,
TransformerLayer, and Transformer against pre-computed float32 reference
outputs (atol=1e-5), parametrized over batch ∈ {4,8},
seq_len ∈ {5,11,31}, heads ∈ {4,5,8}, head_dim ∈ {8,32,64},
causal ∈ {True,False}, device ∈ {cpu,cuda}.
The C reference test (make test-c) additionally verifies online softmax vs. naive
attention at atol=1e-4.
Quick Start
```bash
# Build Needle backends (requires CMake, pybind11)
make lib
# C reference (no Python/CUDA required)
make c-api && make test-c && make example-c
```
```python
import needle as ndl
import needle.nn as nn

# Drop-in: add use_flash_attention=True to any AttentionLayer
layer = nn.AttentionLayer(
    q_features=256, num_head=8, dim_head=32,
    causal=True, use_flash_attention=True,
    device=ndl.cuda(),
)
output = layer(x)  # x: (batch, seq_len, 256)

# Or call the op directly (Q, K, V shape: (batch, heads, seq_len, head_dim))
from needle.ops import flash_attention
output = flash_attention(Q, K, V, causal=True)
```
```c
/* C API */
#include "flash_attention.h"

/* Q, K, V, O: float buffers in the same (batch, heads, seq_len, head_dim)
   layout as the Python op */
FlashAttentionParams p = {.batch = 1, .heads = 4, .seq_len = 1024,
                          .head_dim = 64, .causal = 1};
flash_attention_cpu_forward(Q, K, V, O, &p);
```
Platform-specific build instructions and OpenMP setup: docs/INSTALL.md
References
- Tri Dao, Daniel Y. Fu, Stefano Ermon, Atri Rudra, Christopher Ré. FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness. NeurIPS, 2022.
- Zico Kolter and Tianqi Chen. Deep Learning Systems. Carnegie Mellon University, 2023. (Needle framework open-source course materials)