
Scaling Sparse Attention for Long-Context Reasoning

Modern transformer architectures face a fundamental tension: the self-attention mechanism that makes them powerful scales quadratically with sequence length. For a sequence of length $n$, standard attention computes an $n \times n$ matrix of pairwise interactions, a cost of $\mathcal{O}(n^2)$ in both time and memory. This post explores how sparse attention patterns can break this bottleneck without sacrificing the reasoning capabilities that make dense attention so effective.

The Attention Bottleneck

Recall that in standard multi-head attention, for queries $Q$, keys $K$, and values $V$, each head computes:

$$\text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right) V$$

where $d_k$ is the dimension of the key vectors. The $QK^\top$ product is the expensive part: it produces a dense $n \times n$ matrix even when most of its entries contribute negligibly to the output.
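For reference, here is a minimal pure-Python sketch of a single attention head (no batching, no masking). The function names are illustrative and not taken from the post's codebase. The comprehension that builds `scores` is the dense $n \times n$ matrix described above, so both the time and the memory of this function grow as $n^2$.

```python
import math

def softmax(row):
    """Numerically stable softmax over a list of floats."""
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]

def dense_attention(Q, K, V):
    """One attention head: softmax(Q K^T / sqrt(d_k)) V.

    Q and K are n x d_k, V is n x d_v, all given as lists of rows.
    """
    d_k = len(K[0])
    scale = 1.0 / math.sqrt(d_k)
    # Full n x n score matrix: every query is compared with every key.
    scores = [[scale * sum(a * b for a, b in zip(q, k)) for k in K] for q in Q]
    weights = [softmax(row) for row in scores]
    d_v = len(V[0])
    # Each output row is a weighted average of the value rows.
    return [[sum(w * v[c] for w, v in zip(row, V)) for c in range(d_v)] for row in weights]
```

With identical keys every query spreads its weight evenly, so each output row is the mean of the value rows; with one strongly matching key, the output collapses onto that key's value.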

Empirically, we observe that trained attention heads exhibit highly structured sparsity. Most tokens attend strongly to only a small subset of other tokens. The figure below shows a typical attention map from a 7B-parameter model:

[Figure: Attention heatmap showing sparse diagonal and vertical stripe patterns]

The pattern is clear: attention is concentrated along the diagonal (local context), a few vertical stripes (globally important tokens), and scattered high-value entries. The vast majority of the matrix is near-zero.
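The "mostly near-zero" observation can be quantified per query row by counting how many keys are needed to cover most of the attention mass. The helpers below are an illustrative sketch, not the measurement behind the figure; their names and the 90% default threshold are arbitrary choices.

```python
def keys_for_mass(weights, mass=0.9):
    """Smallest number of keys whose attention weights sum to at least `mass`.

    `weights` is one row of a softmax attention matrix (non-negative, summing
    to 1). A result much smaller than len(weights) means the row is
    effectively sparse.
    """
    total = 0.0
    for count, w in enumerate(sorted(weights, reverse=True), start=1):
        total += w
        if total >= mass:
            return count
    return len(weights)  # guards against rounding leaving total just below mass

def mean_fraction_needed(attn, mass=0.9):
    """Average fraction of keys needed per query row to reach `mass`."""
    return sum(keys_for_mass(row, mass) / len(row) for row in attn) / len(attn)
```

A uniform row needs nearly all of its keys, while a peaked row needs only a few; a heatmap like the one above would yield a small `mean_fraction_needed`.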

Sparse Attention Patterns

We can exploit this structure by restricting which token pairs are allowed to interact. A sparse attention pattern is defined by a mask $M \in \{0, 1\}^{n \times n}$:

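$$\text{SparseAttn}(Q, K, V) = \text{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}} + B\right) V, \qquad B_{ij} = \begin{cases} 0 & \text{if } M_{ij} = 1 \\ -\infty & \text{if } M_{ij} = 0 \end{cases}$$

Masked-out pairs enter the softmax as $-\infty$ and therefore receive exactly zero weight.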
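A direct, deliberately naive sketch of masked attention: disallowed (query, key) pairs get a score of $-\infty$ before the softmax, so they contribute exactly zero. The function name and list-of-lists layout are illustrative. This version still walks over all $n$ keys per query; the savings only appear when a kernel skips masked blocks entirely, as the Triton kernel in the Implementation section does.

```python
import math

def sparse_attention(Q, K, V, mask):
    """Single-head attention restricted by a boolean mask.

    mask[i][j] is True when query i may attend to key j. Disallowed scores are
    set to -inf before the softmax, so they receive exactly zero weight.
    Every row must allow at least one key.
    """
    d_k = len(K[0])
    scale = 1.0 / math.sqrt(d_k)
    out = []
    for i, q in enumerate(Q):
        scores = [
            scale * sum(a * b for a, b in zip(q, k)) if mask[i][j] else -math.inf
            for j, k in enumerate(K)
        ]
        m = max(scores)  # finite as long as the row allows some key
        exps = [math.exp(s - m) for s in scores]  # math.exp(-inf) == 0.0
        total = sum(exps)
        out.append([sum(e * v[c] for e, v in zip(exps, V)) / total for c in range(len(V[0]))])
    return out
```

If only one key is allowed for a query, that query's output is exactly that key's value row.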

The key design question is: which entries should $M$ include?

Common Patterns

Pattern             Complexity                  Description
────────────────────────────────────────────────────────────────────────────────────────
Sliding window      $\mathcal{O}(nw)$           Each token attends to $w$ neighbors
Dilated             $\mathcal{O}(nw)$           Window with gaps, increasing the receptive field
Global + local      $\mathcal{O}(n(w + g))$     Local window plus $g$ global tokens
Hash-based (LSH)    $\mathcal{O}(n \log n)$     Tokens hashed into buckets; attend within buckets
Learned             $\mathcal{O}(nk)$           Top-$k$ attention targets predicted per token
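The first three patterns in the table can be written as explicit boolean masks. The sketch below (function names are illustrative) builds them densely for readability, which is itself $\mathcal{O}(n^2)$; a real kernel never materializes the mask, but counting the allowed pairs shows where the $\mathcal{O}(nw)$ and $\mathcal{O}(n(w + g))$ figures come from. The window is assumed symmetric (about $w/2$ keys on each side), and global tokens are assumed to attend to, and be attended by, every position; both are common conventions rather than details given in the post.

```python
def sliding_window_mask(n, w):
    """Query i may attend to key j when |i - j| <= w // 2 (w + 1 keys for even w)."""
    half = w // 2
    return [[abs(i - j) <= half for j in range(n)] for i in range(n)]

def dilated_mask(n, w, dilation):
    """Same number of keys as sliding_window_mask(n, w), but only every
    `dilation`-th offset is kept, so the span is `dilation` times wider."""
    half = (w // 2) * dilation
    return [[abs(i - j) <= half and (i - j) % dilation == 0 for j in range(n)]
            for i in range(n)]

def global_local_mask(n, w, global_tokens):
    """Sliding window plus global tokens that see, and are seen by, every position."""
    mask = sliding_window_mask(n, w)
    for g in global_tokens:
        for j in range(n):
            mask[g][j] = True  # the global token attends to everything
            mask[j][g] = True  # everything attends to the global token
    return mask

def nnz(mask):
    """Number of allowed (query, key) pairs, i.e. the work a sparse kernel does."""
    return sum(sum(row) for row in mask)
```

For $n = 16$ and $w = 4$, the sliding window allows 74 of the 256 possible pairs (at most 5 per row, counting the token itself), and adding one global token raises that to 100, since each global token adds at most $2n$ pairs.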

Our approach combines sliding window attention for local context with a small set of learned anchor tokens that provide global information flow. The anchor selection is itself a lightweight attention operation over compressed representations.

Implementation

The core kernel is implemented in Triton. The key insight is that sparse block patterns map naturally to GPU thread blocks, avoiding the need for scatter/gather operations:

```python
import triton
import triton.language as tl

@triton.jit
def sparse_attention_kernel(
    Q, K, V, Out,
    stride_qb, stride_qh, stride_qn, stride_qd,
    stride_kb, stride_kh, stride_kn, stride_kd,
    stride_vb, stride_vh, stride_vn, stride_vd,
    stride_ob, stride_oh, stride_on, stride_od,
    n_ctx, n_heads,
    block_size: tl.constexpr,
    head_dim: tl.constexpr,     # per-head feature dimension
    window_size: tl.constexpr,  # keys visited on each side of the query block
):
    pid_b = tl.program_id(0)  # batch
    pid_h = tl.program_id(1)  # head
    pid_q = tl.program_id(2)  # query block

    # Compute block boundaries
    q_start = pid_q * block_size
    q_offs = q_start + tl.arange(0, block_size)

    # Output accumulator: one head_dim-wide row per query in the block
    acc = tl.zeros([block_size, head_dim], dtype=tl.float32)
    # Online-softmax state: running row max (m_i) and running denominator (l_i)
    m_i = tl.full([block_size], float("-inf"), dtype=tl.float32)
    l_i = tl.zeros([block_size], dtype=tl.float32)

    # Only iterate over key blocks within the sparse window
    # (block-aligned when window_size is a multiple of block_size)
    k_start = tl.maximum(0, q_start - window_size)
    k_end = tl.minimum(n_ctx, q_start + window_size + block_size)

    for k_block in range(k_start, k_end, block_size):
        # Load K, V blocks and compute local attention
        k_offs = k_block + tl.arange(0, block_size)
        # ... (flash-attention-style online softmax)
```
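The elided loop body is described as a FlashAttention-style online softmax. As a hedged CPU reference, not the author's kernel, the sketch below applies the standard running-max and running-sum recurrence to one query vector over the kernel's key range, one block at a time, and can be checked against a direct softmax over the same range. `window_bounds` mirrors the kernel's `k_start`/`k_end` expressions (the kernel uses one key range for every query in the block); all function names are illustrative.

```python
import math

def window_bounds(q_start, n_ctx, block_size, window_size):
    """Key range [k_start, k_end) visited for a query block, mirroring the kernel."""
    return max(0, q_start - window_size), min(n_ctx, q_start + window_size + block_size)

def online_softmax_attention(q, K, V, k_start, k_end, block_size, scale):
    """Attention output for one query over keys [k_start, k_end), computed block by
    block with the running max / running sum recurrence (m_i, l_i, acc in the kernel)."""
    m = -math.inf               # running max of scores seen so far
    l = 0.0                     # running sum of exp(score - m)
    acc = [0.0] * len(V[0])     # running un-normalized output
    for kb in range(k_start, k_end, block_size):
        block = range(kb, min(kb + block_size, k_end))
        scores = [scale * sum(a * b for a, b in zip(q, K[j])) for j in block]
        m_new = max(m, max(scores))
        corr = math.exp(m - m_new)  # rescales old partial sums; 0.0 on the first block
        p = [math.exp(s - m_new) for s in scores]
        l = l * corr + sum(p)
        acc = [a * corr + sum(pj * V[j][c] for pj, j in zip(p, block))
               for c, a in enumerate(acc)]
        m = m_new
    return [a / l for a in acc]

def direct_softmax_attention(q, K, V, k_start, k_end, scale):
    """Same result computed in a single pass over the whole key range."""
    keys = range(k_start, k_end)
    scores = [scale * sum(a * b for a, b in zip(q, K[j])) for j in keys]
    m = max(scores)
    w = [math.exp(s - m) for s in scores]
    total = sum(w)
    return [sum(wj * V[j][c] for wj, j in zip(w, keys)) / total for c in range(len(V[0]))]
```

Because `corr` keeps every partial sum expressed relative to the current maximum, the two functions agree up to floating-point rounding for any block size. That equivalence is what lets a kernel hold only one K/V block in fast memory at a time.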

The window restriction means each query block loads only $\mathcal{O}(w / B)$ key-value blocks instead of $\mathcal{O}(n / B)$, where $B$ is the block size. For typical values ($n = 32768$, $w = 4096$, $B = 128$), that is $w / B = 32$ blocks rather than $n / B = 256$, an 8× reduction in memory traffic.
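Counting blocks with the kernel's own `k_start`/`k_end` bounds makes this arithmetic explicit. Those bounds extend `window_size` keys on both sides of the query block, so the 8× figure matches a total window of $w = 4096$, i.e. `window_size` = 2048 per side. If instead `window_size` = 4096 per side, an interior query block loads 65 blocks, which is closer to a 4× reduction. The helper name below is illustrative.

```python
def kv_blocks_loaded(q_block, n_ctx, block_size, window_size):
    """K/V blocks visited by one query block, using the kernel's k_start/k_end."""
    q_start = q_block * block_size
    k_start = max(0, q_start - window_size)
    k_end = min(n_ctx, q_start + window_size + block_size)
    return -(-(k_end - k_start) // block_size)  # ceiling division

n_ctx, B = 32768, 128
dense = n_ctx // B                                      # dense attention visits every key block
per_side_2048 = kv_blocks_loaded(128, n_ctx, B, 2048)  # interior block, w = 4096 in total
per_side_4096 = kv_blocks_loaded(128, n_ctx, B, 4096)  # interior block, 4096 on each side
print(dense, per_side_2048, per_side_4096)              # 256 33 65
```

Query blocks near the sequence edges are clipped by `max`/`min` and load fewer blocks, so the average over all query blocks is slightly better than the interior figure.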

Benchmark Results

We benchmark against FlashAttention-2 on an A100 GPU with varying sequence lengths:

Sequence Length    FlashAttn-2     Sparse (ours)    Speedup
─────────────────────────────────────────────────────────
     4,096          1.2 ms          1.1 ms          1.1×
     8,192          4.1 ms          2.3 ms          1.8×
    16,384         15.8 ms          4.9 ms          3.2×
    32,768         62.1 ms          9.7 ms          6.4×
    65,536            OOM          19.2 ms           —
   131,072            OOM          38.1 ms           —

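Sparse attention is ahead at every measured length, but at 4K tokens the gain is marginal (1.1×): the overhead of managing the sparse pattern nearly cancels the savings. The advantage grows quickly from 8K onward, and above 16K the quadratic scaling of dense attention dominates. The speedup reaches 6.4× at 32K, and beyond that only the sparse kernel fits in memory.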
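The table's own numbers illustrate the scaling argument. Each row doubles the sequence length, so an $\mathcal{O}(n^2)$ kernel should take about 4× longer per row and a windowed $\mathcal{O}(nw)$ kernel about 2× longer. The short script below uses only values copied from the benchmark table; the variable names are illustrative.

```python
# Timings in ms copied from the benchmark table; None marks an OOM entry.
seq_lens = [4096, 8192, 16384, 32768, 65536, 131072]
flash_attn2 = [1.2, 4.1, 15.8, 62.1, None, None]
sparse_ours = [1.1, 2.3, 4.9, 9.7, 19.2, 38.1]

# Sanity check: the sequence length doubles from one row to the next.
assert all(b == 2 * a for a, b in zip(seq_lens, seq_lens[1:]))

def step_growth(times):
    """Ratio of each timing to the previous row's timing."""
    return [round(b / a, 1) if a is not None and b is not None else None
            for a, b in zip(times, times[1:])]

speedups = [round(f / s, 1) if f is not None else None
            for f, s in zip(flash_attn2, sparse_ours)]

print(step_growth(flash_attn2))  # [3.4, 3.9, 3.9, None, None]
print(step_growth(sparse_ours))  # [2.1, 2.1, 2.0, 2.0, 2.0]
print(speedups)                  # [1.1, 1.8, 3.2, 6.4, None, None]
```

Dense time grows by roughly 3.4 to 3.9× per doubling while the sparse kernel grows by roughly 2×, which is why the speedup keeps widening, roughly doubling per row from 8K onward.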

Preserving Reasoning Quality

Speed is meaningless if the model loses its ability to reason over long contexts. We evaluate on three benchmarks that specifically test long-range reasoning:

Key finding: On multi-hop reasoning tasks requiring information synthesis across 10K+ token spans, our sparse attention achieves 97.3% of dense attention's accuracy while using 6× less compute. The remaining 2.7% gap is concentrated in tasks requiring simultaneous attention to 4+ distant passages.

The critical design choice is the anchor token mechanism. Without it, pure sliding-window attention drops to 81% accuracy on cross-document QA: each layer moves information by at most a window's width, so the model simply cannot propagate information across distant windows. With 64 learned anchors per layer, accuracy recovers to near-dense levels.
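A toy reachability count makes the propagation argument concrete. The sketch assumes a non-causal, symmetric window and treats anchors as fixed positions that attend to, and are attended by, every token. The learned anchor selection described in the post is not modeled, and all names are illustrative.

```python
def reachable_after(n, half_window, anchors, layers, source=0):
    """Positions whose representation can depend on `source` after `layers` layers.

    In each layer, a regular position j reads from keys within `half_window`
    of j and from every anchor; an anchor position reads from all positions.
    """
    anchors = set(anchors)
    reached = {source}
    for _ in range(layers):
        nxt = set(reached)
        for j in range(n):
            if j in anchors:
                keys = range(n)
            else:
                lo, hi = max(0, j - half_window), min(n, j + half_window + 1)
                keys = list(range(lo, hi)) + sorted(anchors)
            if any(k in reached for k in keys):
                nxt.add(j)
        reached = nxt
    return reached
```

With a half-window of 4 and no anchors, information from token 0 spreads only 4 positions per layer (9 positions after 2 layers); a single anchor lets it reach all 64 positions in 2 layers.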

Mathematically, the anchor mechanism can be understood as a low-rank correction to the sparse attention matrix. If $A_{\text{sparse}}$ is the windowed attention and $A_{\text{anchor}}$ captures the global component:

$$A_{\text{effective}} = (1 - \alpha) \, A_{\text{sparse}} + \alpha \, U \Sigma V^\top$$

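where $U \in \mathbb{R}^{n \times k}$ and $V \in \mathbb{R}^{n \times k}$ are the learned anchor projections (this $V$ is distinct from the value matrix), $\Sigma \in \mathbb{R}^{k \times k}$, and $k \ll n$, so the global component $A_{\text{anchor}} = U \Sigma V^\top$ has rank at most $k$. This adds only $\mathcal{O}(nk)$ cost while capturing the most important long-range interactions.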
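The $\mathcal{O}(nk)$ cost relies on evaluating the low-rank term right to left, so that no $n \times n$ matrix is ever formed. A minimal sketch, assuming the combined operator is applied to some $n \times d$ input $X$ (for example, the value rows) and that $A_{\text{sparse}}$ is available as a matrix; it is passed densely here only to keep the example short. All names are illustrative, and `V_anc` stands in for the formula's $V$.

```python
def matmul(A, B):
    """Product of two matrices stored as lists of rows."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]

def transpose(A):
    """Transpose of a matrix stored as a list of rows."""
    return [list(col) for col in zip(*A)]

def apply_effective_attention(A_sparse, U, S, V_anc, X, alpha):
    """Return A_effective @ X for A_effective = (1 - alpha) A_sparse + alpha U S V_anc^T.

    U and V_anc are n x k (V_anc is the formula's V, renamed to avoid confusion
    with the value matrix), S is k x k, and X is n x d. The low-rank term is
    evaluated as U (S (V_anc^T X)), which costs O(n k d) and never builds an
    n x n matrix. A real kernel would touch only the nonzero blocks of A_sparse.
    """
    sparse_part = matmul(A_sparse, X)                                  # n x d
    low_rank_part = matmul(U, matmul(S, matmul(transpose(V_anc), X)))  # n x d
    return [[(1 - alpha) * s + alpha * r for s, r in zip(s_row, r_row)]
            for s_row, r_row in zip(sparse_part, low_rank_part)]
```

The result matches forming $A_{\text{effective}}$ explicitly and multiplying by $X$, but the intermediate $V^\top X$ is only $k \times d$.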

What's Next

Three directions we're actively exploring:

  1. Dynamic sparsity: adapting the attention pattern per input rather than using a fixed mask. Early experiments with a lightweight router network show promising results.

  2. Hierarchical anchors: instead of a flat set of global tokens, organizing them in a tree structure to support $\mathcal{O}(\log n)$ information propagation.

  3. Training-time sparsity: applying sparse attention during pretraining, not just inference. The challenge is maintaining gradient flow through the sparse mask.

The code is available at github.com/arturbcarneiro/sparse-attn. We welcome contributions and benchmarks on new hardware.