Scaling Sparse Attention for Long-Context Reasoning
Modern transformer architectures face a fundamental tension: the self-attention mechanism that makes them powerful scales quadratically with sequence length. For a sequence of length $n$, standard attention computes an $n \times n$ matrix of pairwise interactions, a cost of $O(n^2)$ in both time and memory. This post explores how sparse attention patterns can break this bottleneck without sacrificing the reasoning capabilities that make dense attention so effective.
The Attention Bottleneck
Recall that in standard multi-head attention, for queries $Q$, keys $K$, and values $V$, each head computes:

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$$

where $d_k$ is the dimension of the key vectors. The product $QK^\top$ is the expensive part: it produces a dense $n \times n$ matrix even when most of its entries contribute negligibly to the output.
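As a reference point, the dense computation can be sketched in a few lines of NumPy (a minimal single-head version, ignoring batching and masking):

```python
import numpy as np

def dense_attention(Q, K, V):
    """Single-head scaled dot-product attention over (n, d_k) arrays."""
    d_k = Q.shape[-1]
    S = Q @ K.T / np.sqrt(d_k)             # dense (n, n) score matrix: the quadratic cost
    S = S - S.max(axis=-1, keepdims=True)  # subtract row max for numerical stability
    P = np.exp(S)
    P = P / P.sum(axis=-1, keepdims=True)  # row-wise softmax
    return P @ V                           # (n, d_k) output

rng = np.random.default_rng(0)
Q, K, V = rng.normal(size=(3, 6, 4))       # n = 6 tokens, d_k = 4
out = dense_attention(Q, K, V)
print(out.shape)                           # (6, 4)
```

The intermediate `S` is the $n \times n$ matrix discussed above; everything else in the computation is linear in $n$.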
Empirically, we observe that trained attention heads exhibit highly structured sparsity. Most tokens attend strongly to only a small subset of other tokens. The figure below shows a typical attention map from a 7B parameter model:

The pattern is clear: attention is concentrated along the diagonal (local context), a few vertical stripes (globally important tokens), and scattered high-value entries. The vast majority of the matrix is near-zero.
Sparse Attention Patterns
We can exploit this structure by restricting which token pairs are allowed to interact. A sparse attention pattern is defined by a mask $M \in \{0, 1\}^{n \times n}$:

$$\mathrm{SparseAttention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}} + B_M\right)V, \qquad B_M[i,j] = \begin{cases} 0 & M_{ij} = 1 \\ -\infty & M_{ij} = 0 \end{cases}$$

so that masked pairs receive exactly zero attention weight. The key design question is: which entries should $M$ include?
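In NumPy terms, the masking is a single `np.where` before the softmax. Here is a minimal sketch using a sliding-window mask (the simplest pattern from the table below):

```python
import numpy as np

def sliding_window_mask(n, w):
    """M[i, j] = 1 when |i - j| <= w, else 0."""
    i = np.arange(n)
    return (np.abs(i[:, None] - i[None, :]) <= w).astype(np.float64)

def sparse_attention(Q, K, V, M):
    """Attention restricted to pairs where M is 1."""
    d_k = Q.shape[-1]
    S = Q @ K.T / np.sqrt(d_k)
    S = np.where(M > 0, S, -np.inf)        # masked pairs can never win the softmax
    S = S - S.max(axis=-1, keepdims=True)  # each row's max is finite: the diagonal is unmasked
    P = np.exp(S)                          # exp(-inf) == 0, so masked weights are exactly zero
    P = P / P.sum(axis=-1, keepdims=True)
    return P @ V

rng = np.random.default_rng(0)
Q, K, V = rng.normal(size=(3, 8, 4))
M = sliding_window_mask(8, 2)
out = sparse_attention(Q, K, V, M)
```

This reference version still materializes the full score matrix; the point of the real kernel below is to never compute the masked entries at all.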
Common Patterns
| Pattern | Complexity | Description |
|---|---|---|
| Sliding window | $O(nw)$ | Each token attends to $w$ neighbors |
| Dilated | $O(nw)$ | Window with gaps, increasing receptive field |
| Global + local | $O(n(w + g))$ | Local window plus $g$ global tokens |
| Hash-based (LSH) | $O(n \log n)$ | Tokens hashed into buckets; attend within buckets |
| Learned | $O(nk)$ | Top-$k$ attention targets predicted per token |
Our approach combines sliding window attention for local context with a small set of learned anchor tokens that provide global information flow. The anchor selection is itself a lightweight attention operation over compressed representations.
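The post doesn't spell out the anchor-selection step, so here is one plausible sketch of "lightweight attention over compressed representations": mean-pool the sequence into segments, score each segment against a learned query vector, and take the centers of the top-scoring segments as anchors. The segment length and the scoring vector `w` are assumptions for illustration.

```python
import numpy as np

def select_anchors(H, seg_len, n_anchors, w):
    """Pick anchor token positions by scoring pooled segments.

    H: (n, d) hidden states. Segments of seg_len tokens are mean-pooled,
    scored against a (hypothetical) learned query vector w, and the tokens
    at the centers of the top-scoring segments become global anchors.
    """
    n, d = H.shape
    n_seg = n // seg_len
    pooled = H[: n_seg * seg_len].reshape(n_seg, seg_len, d).mean(axis=1)
    scores = pooled @ w                    # (n_seg,) one score per segment
    top = np.argsort(scores)[-n_anchors:]  # indices of the best segments
    return top * seg_len + seg_len // 2    # anchor = segment center

rng = np.random.default_rng(0)
H = rng.normal(size=(1024, 64))
w = rng.normal(size=64)
anchors = select_anchors(H, seg_len=16, n_anchors=8, w=w)
```

Pooling first keeps the selection cost at $O(n)$, so anchor selection never reintroduces a quadratic term.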
Implementation
The core kernel is implemented in Triton. The key insight is that sparse block patterns map naturally to GPU thread blocks, avoiding the need for scatter/gather operations:
```python
import triton
import triton.language as tl

@triton.jit
def sparse_attention_kernel(
    Q, K, V, Out,
    stride_qb, stride_qh, stride_qn, stride_qd,
    stride_kb, stride_kh, stride_kn, stride_kd,
    stride_vb, stride_vh, stride_vn, stride_vd,
    stride_ob, stride_oh, stride_on, stride_od,
    n_ctx, n_heads,
    head_dim: tl.constexpr,
    block_size: tl.constexpr,
    window_size: tl.constexpr,
):
    pid_b = tl.program_id(0)  # batch
    pid_h = tl.program_id(1)  # head
    pid_q = tl.program_id(2)  # query block

    # Compute block boundaries
    q_start = pid_q * block_size
    q_offs = q_start + tl.arange(0, block_size)

    # Initialize output accumulator and online-softmax statistics.
    # The accumulator holds one output row per query, so its shape is
    # (block_size, head_dim), not (block_size, block_size).
    acc = tl.zeros([block_size, head_dim], dtype=tl.float32)
    m_i = tl.full([block_size], float("-inf"), dtype=tl.float32)  # running row max
    l_i = tl.zeros([block_size], dtype=tl.float32)                # running row sum

    # Only iterate over key/value blocks within the sparse window
    k_start = tl.maximum(0, q_start - window_size)
    k_end = tl.minimum(n_ctx, q_start + window_size + block_size)
    for k_block in range(k_start, k_end, block_size):
        # Load K, V blocks and compute local attention
        k_offs = k_block + tl.arange(0, block_size)
        # ... (flash-attention-style online softmax)
```

The window restriction means each query block loads only $O(W/B)$ key/value blocks instead of $O(n/B)$, where $W$ is the window size and $B$ the block size. For typical values (e.g. $n = 32{,}768$, $W = 2048$, $B = 128$, so each query block touches roughly 4K of the 32K context), this is roughly an 8× reduction in memory traffic.
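As a sanity check on that arithmetic, the kernel's loop bounds can be counted directly in plain Python. The concrete values ($n = 32768$, $W = 2048$, $B = 128$) are assumed for illustration:

```python
def kv_blocks_dense(n_ctx, block_size):
    # A dense kernel streams every key/value block for each query block.
    return n_ctx // block_size

def kv_blocks_windowed(q_start, n_ctx, window_size, block_size):
    # Mirrors the kernel's loop bounds: [q_start - W, q_start + W + B)
    k_start = max(0, q_start - window_size)
    k_end = min(n_ctx, q_start + window_size + block_size)
    return (k_end - k_start + block_size - 1) // block_size

n, W, B = 32768, 2048, 128
dense = kv_blocks_dense(n, B)                   # 256 blocks
windowed = kv_blocks_windowed(n // 2, n, W, B)  # an interior query block: 33 blocks
print(f"{dense / windowed:.1f}x fewer KV blocks loaded")
```

Boundary query blocks near position 0 or $n$ load even fewer blocks, so this interior-block count is the worst case for the sparse kernel.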
Benchmark Results
We benchmark against FlashAttention-2 on an A100 GPU with varying sequence lengths:
| Sequence Length | FlashAttn-2 | Sparse (ours) | Speedup |
|---|---|---|---|
| 4,096 | 1.2 ms | 1.1 ms | 1.1× |
| 8,192 | 4.1 ms | 2.3 ms | 1.8× |
| 16,384 | 15.8 ms | 4.9 ms | 3.2× |
| 32,768 | 62.1 ms | 9.7 ms | 6.4× |
| 65,536 | OOM | 19.2 ms | — |
| 131,072 | OOM | 38.1 ms | — |
The crossover point is around 8K tokens — below that, the overhead of the sparse pattern management is not worth it. Above 16K, the quadratic scaling of dense attention becomes dominant and sparse attention pulls ahead dramatically.
Preserving Reasoning Quality
Speed is meaningless if the model loses its ability to reason over long contexts. We evaluate on three benchmarks that specifically test long-range reasoning.
Key finding: On multi-hop reasoning tasks requiring information synthesis across 10K+ token spans, our sparse attention achieves 97.3% of dense attention's accuracy while using 6× less compute. The remaining 2.7% gap is concentrated in tasks requiring simultaneous attention to 4+ distant passages.
The critical design choice is the anchor token mechanism. Without it, pure sliding window attention drops to 81% accuracy on cross-document QA — the model simply cannot propagate information across distant windows. With 64 learned anchors per layer, accuracy recovers to near-dense levels.
Mathematically, the anchor mechanism can be understood as a low-rank correction to the sparse attention matrix. If $A_w$ is the windowed attention and $UW^\top$ captures the global component:

$$A \approx A_w + UW^\top$$

where $U, W \in \mathbb{R}^{n \times r}$ are the learned anchor projections and $r \ll n$ is the number of anchors (64 per layer in our setup). This adds only $O(nr)$ cost while capturing the most important long-range interactions.
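To see why the correction stays cheap, here is a small NumPy sketch. The sizes ($n = 256$, $r = 8$) are illustrative, and `A_w` is a random banded stand-in for the windowed attention:

```python
import numpy as np

rng = np.random.default_rng(0)
n, r = 256, 8  # sequence length and anchor rank (illustrative sizes)

# Windowed component: a banded matrix, zero outside the local window.
i = np.arange(n)
band = (np.abs(i[:, None] - i[None, :]) <= 16).astype(np.float64)
A_w = band * rng.random((n, n))

# Low-rank global component from the anchor projections.
U = rng.normal(size=(n, r))
W = rng.normal(size=(n, r))
A = A_w + U @ W.T  # full corrected matrix, materialized only for checking

# The correction never needs the n x n product: (U W^T) x = U (W^T x),
# which costs O(nr) per matrix-vector product instead of O(n^2).
x = rng.normal(size=n)
cheap = A_w @ x + U @ (W.T @ x)
assert np.allclose(cheap, A @ x)
```

In the real kernel the banded term is likewise never materialized; each query block only touches its window plus the $r$ anchor columns.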
What's Next
Three directions we're actively exploring:
- **Dynamic sparsity** — adapting the attention pattern per input rather than using a fixed mask. Early experiments with a lightweight router network show promising results.
- **Hierarchical anchors** — instead of a flat set of global tokens, organizing them in a tree structure to support information propagation.
- **Training-time sparsity** — applying sparse attention during pretraining, not just inference. The challenge is maintaining gradient flow through the sparse mask.
The code is available at github.com/arturbcarneiro/sparse-attn. We welcome contributions and benchmarks on new hardware.