
I rebuilt FlashAttention in Triton to understand the performance archaeology

⏲️ Estimated reading time ~45min.

Flash Attention has become one of the most impactful optimizations in modern deep learning. Since the original paper was published in 2022, we’ve seen four major versions—each squeezing more performance out of increasingly powerful hardware. But here’s the thing: reading papers is one thing, understanding why these optimizations were made is another entirely.

My goal here is simple: start from first principles, implement FlashAttention v1 exactly as described in the paper, profile it, find the bottlenecks, and see how far we can push it. We’ll build intuition by iterating on the algorithm, discovering through profiling exactly why v2, v3, and v4 were necessary. Think of this as archaeology—digging through the performance layers to understand what each version was really solving.

So let’s put ourselves in the shoes of that mythical Stanford grad student. You’ve finally finished configuring your neovim and archlinux setup (a multi-year endeavor, naturally). You open up a fresh LLaMA model to peek under the hood. Text goes in, gets tokenized, embedded, then flows through a stack of transformer blocks. Standard stuff. But then you look closer at the attention mechanism—three projections and then… there it is:

scores = torch.matmul(q, k.transpose(3, 2)) / math.sqrt(self.head_dim)
scores = F.softmax(scores, dim=-1)
output = torch.matmul(scores, v)  # (B, N_h, T, D_h)

This monstrosity of code is staring you right in the face. For those who don't immediately see the problem with these four lines, let me add some annotations on how this would normally execute in PyTorch (without compilation).

# 1. We load q `(B, N_h, S, D_h)` and k `(B, N_h, S, D_h)`
# 2. We compute Q.K^T and write it back to HBM. Note that scores is `(B, N_h, S, S)`.
scores = torch.matmul(q, k.transpose(3, 2)) / math.sqrt(self.head_dim)
# 3. Reload the scores tensor to compute the softmax and write it back to HBM
scores = F.softmax(scores, dim=-1)
# 4. Load v `(B, N_h, S, D_h)`, load the scores from HBM
# 5. Compute scores@v and write it back to HBM.
output = torch.matmul(scores, v)  # `(B, N_h, T, D_h)`

Do you see it now? Well, we have three tensors q, k, v each of dimension (B, N_h, T, D_h)1. The output tensor of the attention mechanism is (B, N_h, T, D_h), and somehow in the middle we had to materialize a (B, N_h, S, S) tensor for funsies. The attention mechanism has a critical bottleneck: quadratic memory complexity. Let's take a standard training sequence length S=8192. Computing attention naively requires O(S²) memory to store the full attention matrix, which means consuming several gigabytes of GPU memory. Crucially, in modern transformers, S >> D_h (sequence length is much larger than head dimension) - we typically have S=8192 or more while D_h=64 or 128. This massive asymmetry is what makes the Flash Attention algorithm possible. The second big issue here is the back and forth to HBM. Modern GPUs have compute throughput vastly exceeding memory bandwidth. Repeatedly reading Q, K, V, and scores from slow High Bandwidth Memory (HBM) is going to greatly impact performance (more on this later).
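To put actual numbers on "several gigabytes", here is a quick back-of-the-envelope script. The batch size and head count are illustrative assumptions, not a specific model config:

# Hypothetical shapes, just to size the intermediate scores tensor
B, N_h, S, D_h = 1, 32, 8192, 128
bytes_per_float = 4

scores_bytes = B * N_h * S * S * bytes_per_float     # the (B, N_h, S, S) attention matrix
qkv_bytes = 3 * B * N_h * S * D_h * bytes_per_float  # q, k, v combined

print(f"scores: {scores_bytes / 2**30:.2f} GiB")     # 8.00 GiB
print(f"q+k+v:  {qkv_bytes / 2**30:.3f} GiB")        # 0.375 GiB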

The whole idea of Flash Attention is to bypass these intermediate steps—i.e., go from tensors q, k, v to the output tensor directly and compute the attention in one go, with minimal memory footprint (materializing only the tensors we need) and minimal back and forth to HBM, and hopefully getting to O(S) memory complexity.

The Plan: Start with FlashAttention v1 from the original paper. Implement it faithfully in Triton, profile with NVIDIA’s tooling, identify the bottlenecks, then iterate. Each optimization will teach us something about the GPU memory hierarchy and why the subsequent versions (v2, v3, v4) introduced the changes they did. Four versions represents an enormous amount of engineering work—let’s see how much of that journey we can reconstruct by just following the profiler’s breadcrumbs.

Setup and Hardware

First things first, I'll use Triton to implement the attention kernels. I wanted to implement them in CUDA, but I thought this was a good opportunity to learn Triton. I'm always skeptical about DSLs and abstractions that promise "write once, run fast everywhere," but as my very wise friend Chris Fleetwood said, you can't go wrong learning something shipped inside PyTorch. Running Triton is also extremely straightforward and avoids the mess of boilerplate and C++ code that you have to write when implementing CUDA kernels, especially if you want to call them from Python. I actually went back and reimplemented Flash Attention in CUDA to have a bit more control, but that will come later (a little teaser).

I’ll compare these kernels directly with a reference PyTorch implementation, first to make sure the kernel is doing what it’s supposed to, and second to profile the kernel against the baseline.

Why Triton? Block-level Programming Without Thread Hell

Triton is a Python-based DSL for writing GPU kernels. The pitch is simple: write Python-like code at the block level, and the compiler handles individual thread management, memory coalescing, and other low-level optimizations. What makes Triton fundamentally different from CUDA is the abstraction level - in CUDA, you explicitly program the individual threads of a block. In Triton, you write your kernel thinking about blocks/tiles of data, a little like how you'd write torch code, and the compiler figures out how to map that to threads. It handles per-thread data access patterns, synchronization barriers, and work splitting. My skepticism comes from the fact that I don't really trust these "magic" compilers for parallel code. Usually, the promised performance falls apart quickly. But Triton has a few things going for it:

  1. It’s maintained by OpenAI and ships with PyTorch 2.0+
  2. The generated PTX is easily inspectable, so you can see what it’s actually doing
  3. It handles the tedious bits (pointer arithmetic, bounds checking) while still giving you control over the algorithm. You write your kernel at the block level and the compiler issues the correct code to load your tensors.

Setup

I’m running this on my personal machine:

Here are the GPU specs from cudaGetDeviceProperties. These numbers will be useful once we start profiling our kernel to detect performance issues:

Name: NVIDIA GeForce RTX 2070
Compute Capability: 7.5
Total Memory: 8 GB
MultiProcessor Count (SM_count): 36
Max Threads per MultiProcessor: 1024
Warp Size: 32
L2 Cache Size: 4 MB
Shared Memory per Block: 48 KB
Shared Memory per MultiProcessor: 64 KB
Registers per MultiProcessor: 65536

Profiling Tools

For profiling, I use three tools at different granularities:

  1. torch.profiler: Quick and dirty. Good for seeing wall-clock time and basic GPU utilization. I use this for initial sanity checks (a minimal example follows after this list).

  2. NVIDIA Nsight Systems (nsys): System-wide profiler. Shows CPU/GPU timeline, kernel launches, memory transfers. Great for spotting gaps where the GPU is idle.

  3. NVIDIA Nsight Compute (ncu): The heavy hitter. Gives you everything: occupancy, memory throughput, warp stalls, instruction mix, bank conflicts. This is what I’ll use for deep-diving into kernel performance. You can run it with:

sudo ncu --set full --kernel-name "attn_kernel" -o profile_output -f python script.py

Then open the .ncu-rep file with ncu-ui for the full analysis.
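And since I mentioned torch.profiler above, here is roughly what that quick sanity check looks like. This is a minimal sketch with placeholder shapes, not the exact script behind the numbers later in this post:

import torch
from torch.profiler import profile, ProfilerActivity

q = torch.randn(1, 8, 1024, 64, device="cuda")
k, v = torch.randn_like(q), torch.randn_like(q)

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    scores = torch.softmax(q @ k.transpose(-2, -1) / q.size(-1) ** 0.5, dim=-1)
    out = scores @ v
    torch.cuda.synchronize()

print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))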

The Flash Attention Algorithm

For modern GPUs, compute throughput vastly exceeds memory bandwidth: an A100 can do ~300 TFLOPs but only has ~2 TB/s of memory bandwidth. This massive compute-to-memory ratio means that naive algorithms spend most of their time waiting for memory, not computing. Flash Attention restructures the computation to maximize arithmetic intensity (FLOPs per byte transferred).

Let's build the intuition behind the core ideas by working backwards from the solution. The goal is to build the output tensor in one shot (a single kernel).

We are in GPU programming land. If you want to minimize transfers to main memory—like any good GPU engineer who has read Simon’s matmul blog post (stop reading this post and go read it. I am serious. It’s the best matmul writeup on the internet.)—you’d immediately think about computing blocks of output in parallel and writing each block to memory.

This is commonly known as tiling, and every GPU kernel out there uses this trick. The idea is simple but powerful: instead of streaming data directly from slow global memory for each operation, you load a tile (a small block of data) into fast shared memory once, reuse it across multiple computations within a thread block, and only write the final results back. This dramatically reduces memory bandwidth bottlenecks by exploiting data locality and the GPU’s memory hierarchy—exactly what you need when global memory access is orders of magnitude slower than compute.
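To make that concrete, here is what tiling looks like for a plain matmul in NumPy. On a GPU the per-tile accumulator and the A/B tiles would live in registers/shared memory; this is only an illustrative sketch of the loop structure, not a kernel:

import numpy as np

def tiled_matmul(A, B, tile=32):
    M, K = A.shape
    K2, N = B.shape
    assert K == K2
    C = np.zeros((M, N), dtype=A.dtype)
    for i0 in range(0, M, tile):
        for j0 in range(0, N, tile):
            # On a GPU this (tile, tile) accumulator lives in registers/shared memory
            acc = np.zeros((min(tile, M - i0), min(tile, N - j0)), dtype=A.dtype)
            for k0 in range(0, K, tile):
                a = A[i0:i0 + tile, k0:k0 + tile]  # load a tile of A once
                b = B[k0:k0 + tile, j0:j0 + tile]  # load a tile of B once
                acc += a @ b                       # reuse both tiles
            C[i0:i0 + tile, j0:j0 + tile] = acc    # single write-back per output tile
    return C

A = np.random.randn(128, 96).astype(np.float32)
B = np.random.randn(96, 64).astype(np.float32)
assert np.allclose(tiled_matmul(A, B), A @ B, atol=1e-4)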

Let’s first look at how the output tensor is computed based on the attention formula. Let’s focus on a single batch and head (B, N_h) to clearly see the shapes of the matmuls and the elements. Also, we’ll omit the causal masking here (although it’s not really complicated to add it to the attention score before softmax).

output tensor

If you have been following along, you're probably scratching your head wondering why we are sharding the V tensor row-wise. Sorry for pulling a fast one here, but this formulation of the output (using the rows of V instead of the columns) will actually help us build the flash attention algorithm. For each row O_i of the output:

O_i = \frac{\sum_j e^{S_{ij}} V_j}{\sum_j e^{S_{ij}}} \quad\text{where}\quad S_{ij} = Q_i \cdot K_j^{\top}

Still, why split V_j row-wise and not column-wise? First of all, sharding on the D dimension makes less sense. Like I pointed out earlier, D is usually 64 or 128, whereas the sequence length S could be 8192 or more depending on the context window. If we tried to split V column-wise (along D), the math would look like this:

\begin{bmatrix} \dots & P_{ij} & \dots \end{bmatrix} \times \begin{bmatrix} V_{\text{left}} & V_{\text{right}} \end{bmatrix} = \begin{bmatrix} P \cdot V_{\text{left}} & P \cdot V_{\text{right}} \end{bmatrix}

This would force us to load the entire score matrix P = softmax(Q K^T) just to compute the left half of the output. But we can't store the entire P matrix; that's the exact problem FlashAttention solves! By splitting row-wise:

\left[ \begin{array}{c|c} P_{i0} & P_{i1} \end{array} \right] \times \begin{bmatrix} V_0 \\ \hline V_1 \end{bmatrix} = (P_{i0} \cdot V_0) + (P_{i1} \cdot V_1)

Now, the genius piece of flash attention is this: if we found a way to build O_i iteratively for a given i, we could calculate a small chunk P_ij, multiply it by V_j, add it to the sum, and discard P_ij. This means we are only allocating a small block of memory!

But why does this matter anyway? Isn't it still mathematically equivalent if we compute the whole P_i row? Well, if you live in some mythical mathematical ideal world, yes. But this code runs on hardware, and when and how we access memory matters a LOT!

Quick Detour into GPU Memory Land

Understanding GPU memory is crucial, especially because GPUs are SIMT (Single Instruction, Multiple Thread) machines. If you’re familiar with SIMD programming on the CPU side, you usually need to use explicit vector instructions to pack and align vectors before executing instructions. SIMT, on the other hand, lets you write scalar thread code that the GPU transparently executes as a lockstep vector, coalescing loads automatically when memory is well-aligned.

GPUs execute thousands of lightweight threads, grouped into warps of 32 threads that run in lockstep. Multiple warps form a block, and blocks are scheduled onto SMs. Looking back at the Setup section, we can see that my RTX 2070 has 36 SMs (streaming multiprocessors), and each SM can handle 1024 threads (32 warps) concurrently.

GPU cores are simple and rely on this massive parallelism to hide latency—whenever one warp stalls waiting on memory, the scheduler swaps in another warp that already has its data “in flight.” This only works if kernels expose enough parallelism and if memory accesses are structured to take advantage of locality2. A core part of the CUDA programming model is that it exposes the memory hierarchy directly, forcing the programmer (or compiler) to reason explicitly about where data lives and how far it is from the compute units. Memory proximity matters enormously:

Look very closely at this diagram of the memory hierarchy, with bandwidths and memory sizes, taken from the Flash Attention v1 paper, and try to internalize it. It is extremely important to take I/O into account when designing highly performant algorithms.

Memory hierarchy and bandwidth

If you're still following along with me, the main insight is to restructure attention so that all intermediate activations fit within these fast memories (registers and shared memory), minimizing slow HBM traffic and maximizing warp-level parallelism. This is why building the output block by block, using same-size P_ij blocks, is a core pillar of fast attention. But it is still unclear at this stage how we could build the output incrementally.

Online Softmax Mathematics

Before we go deeper into how we could build these blocks, we need one last quick detour into something I omitted until now for simplification reasons. The last lego block we need in our toolbox before building flash attention is softmax numerical stability.

Numerically stable softmax

Until now I wrote the softmax function for a vector z = (z_1, \dots, z_n) as:

\operatorname{softmax}(z_i) = \frac{e^{z_i}}{\sum_{j=1}^{n} e^{z_j}}

But what happens if the input values z are large (e.g., z_i = 1000)? Computing e^{1000} will exceed the maximum value a floating-point number can hold, resulting in inf (infinity) or NaN (Not a Number) errors. Conversely, underflow can happen if values are very negative: the denominator might vanish to zero, leading to division by zero errors. The trick is to numerically stabilize by subtracting the maximum logit.

The key property is that softmax is invariant to adding or subtracting a constant c:

\operatorname{softmax}(z_i + c) = \operatorname{softmax}(z_i)

For numerical stability, we usually subtract the maximum value m = \max_j z_j. Setting c = -m ensures that the exponentials are at most 1, preventing overflow:

\operatorname{softmax}(z_i) = \frac{e^{z_i - m}}{\sum_{j=1}^{n} e^{z_j - m}}
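A two-line experiment shows both the failure mode and the fix (a quick PyTorch sketch):

import torch

z = torch.tensor([1000.0, 1001.0, 1002.0])

naive = torch.exp(z) / torch.exp(z).sum()
stable = torch.exp(z - z.max()) / torch.exp(z - z.max()).sum()

print(naive)   # tensor([nan, nan, nan])  -- exp(1000) overflows to inf, and inf/inf = nan
print(stable)  # tensor([0.0900, 0.2447, 0.6652])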

So let's go back to incrementally computing a row of the output O_i. Let's start slow to build the intuition. We take row i=0 and j=0. This means we are focusing on the first block P_00 = Q_0 K_0^T of the scores and the first row of the value tensor V_0.

output tensor

If you already see an issue with the current softmax, bear with me a minute. Remember that each element of the score matrix is computed as P_{ij} = \frac{e^{Q_i \cdot K_j^T - m_i}}{\sum_j e^{Q_i \cdot K_j^T - m_i}}. So the denominator is clearly broken (we need to compute the whole row to compute this sum).

For now we only have Q_0 \cdot K_0^T in the row, so the first value we computed is clearly broken. Similarly, m_i = m_0 here represents the maximum value of the row. For now we've only computed one value, so the max is also broken. Let's continue to the second value of the output row and see how we can fix what is broken:

output tensor

When we see a new block, we can't just throw away our previous work. Instead, we need to rescale what we've already computed. How, though? If we simplify the problem to a row with only two score entries (two key/value rows) and we have already computed the first one, the correct output value should be:

O_1 = \frac{e^{Q_1 \cdot K_1^T - m_1} \cdot V_1 + e^{Q_1 \cdot K_2^T - m_1} \cdot V_2}{e^{Q_1 \cdot K_1^T - m_1} + e^{Q_1 \cdot K_2^T - m_1}} \quad \text{where} \quad m_1 = \max(Q_1 \cdot K_1^T, Q_1 \cdot K_2^T)

Thanks to the multiplicative property of the exponential, we can easily build these from the previous work and the currently computed block. This is done using running correction factors:

\begin{gathered} m_1 = Q_1 \cdot K_1^T, \quad m_2 = Q_1 \cdot K_2^T, \quad m_{\text{new}} = \max(m_1, m_2) \\ \alpha = e^{m_1 - m_{\text{new}}}, \quad \beta = e^{m_2 - m_{\text{new}}} \end{gathered}

These α and β factors let us rescale our previous exponentials to the new maximum. Think of it as "adjusting the baseline" - when we find a new maximum or a new sum, we scale down the previous values accordingly.

l_{\text{new}} = \alpha l_1 + \beta l_2 = e^{m_1 - m_{\text{new}}} l_1 + e^{m_2 - m_{\text{new}}} l_2

\boxed{l_{\text{new}} = e^{Q_1 \cdot K_1^T - m_{\text{new}}} + e^{Q_1 \cdot K_2^T - m_{\text{new}}}}

Now here's where the magic happens. We also need to update our output O_1. Remember, we had computed O_1^prev using only the first block, but now we need to incorporate the second block. The update formula rescales the old output and adds the contribution from the new block:

O_1^{\text{new}} = \frac{\alpha \cdot l_1 \cdot O_1^{\text{prev}} + \beta \cdot e^{Q_1 \cdot K_2^T - m_2} \cdot V_2}{l_{\text{new}}}

Let's expand this to see what's really happening. The old output was O_1^{\text{prev}} = \frac{e^{Q_1 \cdot K_1^T - m_1}}{l_1} \cdot V_1, so:

= \frac{1}{l_{\text{new}}} \left( e^{m_1 - m_{\text{new}}} \cdot \cancel{l_1} \cdot \frac{e^{Q_1 \cdot K_1^T - m_1}}{\cancel{l_1}} \cdot V_1 + e^{m_2 - m_{\text{new}}} \cdot e^{Q_1 \cdot K_2^T - m_2} \cdot V_2 \right)

Simplifying by combining the exponentials:

= \frac{1}{l_{\text{new}}} \left( e^{Q_1 \cdot K_1^T - m_{\text{new}}} \cdot V_1 + e^{Q_1 \cdot K_2^T - m_{\text{new}}} \cdot V_2 \right)

And voilà! We get exactly what we’d expect - the proper softmax formula:

= \frac{\sum_{j=1}^{2} e^{Q_1 \cdot K_j^T - m_{\text{new}}} V_j}{\sum_{j=1}^{2} e^{Q_1 \cdot K_j^T - m_{\text{new}}}}

The beauty of this approach is that we never had to store the full attention matrix. We just kept updating our running statistics (m and l) and our output incrementally!

This is the core insight of Flash Attention’s online softmax - we can compute exact attention without materializing the entire attention matrix in memory.
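Here is the whole derivation condensed into a few lines of NumPy: one output row built block by block, never holding more than one block of scores at a time. This is a reference sketch for checking the math against a full softmax, not a kernel:

import numpy as np

def online_attention_row(q, K, V, block=32):
    """Compute one output row O_i = softmax(q @ K.T / sqrt(d)) @ V block by block."""
    d = q.shape[0]
    m = -np.inf                 # running row max
    l = 0.0                     # running denominator (sum of exponentials)
    o = np.zeros(V.shape[1])    # running, already-normalized output row

    for j in range(0, K.shape[0], block):
        s = q @ K[j:j + block].T / np.sqrt(d)   # scores for this block only
        m_block = s.max()
        p = np.exp(s - m_block)
        l_block = p.sum()

        m_new = max(m, m_block)
        alpha = np.exp(m - m_new)               # rescale the old statistics
        beta = np.exp(m_block - m_new)          # rescale the new block
        l_new = alpha * l + beta * l_block

        o = (alpha * l * o + beta * (p @ V[j:j + block])) / l_new
        m, l = m_new, l_new
    return o

rng = np.random.default_rng(0)
S, D = 256, 64
q, K, V = rng.standard_normal(D), rng.standard_normal((S, D)), rng.standard_normal((S, D))

s = q @ K.T / np.sqrt(D)
p = np.exp(s - s.max())
reference = (p / p.sum()) @ V
assert np.allclose(online_attention_row(q, K, V), reference)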

Core Idea: Tiling and Online Softmax

TLDR, Flash Attention solves the memory problem through two key insights:

  1. Block-wise Computation: Instead of computing the full attention matrix, process Q, K, V in small blocks that fit in fast on-chip SRAM. We have been working with a single row of the output matrix, but we can easily generalize to computing a chunk of n rows of the output. For a single chunk of n output rows, we need the corresponding n rows of the query matrix Q and all rows of K and V. More concretely, the algorithm divides the sequence into blocks of size Bc and processes them iteratively.
  2. Online Softmax: Compute softmax incrementally without materializing the full attention matrix. This is usually the unintuitive part of the flash attention algorithm. Hopefully, if you have been following along, you can see the core idea behind the online softmax. The key is keeping track of running statistics (max value m and sum l) to compute the correct softmax normalization.

Let’s roll up our sleeves and implement the algorithm!


Implementing Flash Attention v1

Here is the Triton implementation of the FA1 algorithm3 straight out of the paper. I tried to follow the algorithm as closely as possible to have a good baseline to work from:

\small
\begin{array}{l}
\hline
\textbf{Algorithm 1 } \text{FlashAttention} \\
\hline \\
\textbf{Require: } \text{Matrices } Q, K, V \in \mathbb{R}^{N \times d} \text{ in HBM, on-chip SRAM of size } M. \\ \\
1. \quad \text{Set block sizes } B_c = \lceil \frac{M}{4d} \rceil, \ B_r = \min(\lceil \frac{M}{4d} \rceil, d). \\
2. \quad \text{Initialize } O = \mathbf{0}_{N \times d}, \ \ell = \mathbf{0}_N, \ m = (-\infty)_N. \\
3. \quad \text{Divide } Q \text{ into } T_r \text{ blocks } Q_i, \text{ and } K, V \text{ into } T_c \text{ blocks } K_j, V_j. \\
4. \quad \text{Divide } O, \ell, m \text{ into blocks } O_i, \ell_i, m_i. \\
5. \quad \textbf{for } 1 \le j \le T_c \textbf{ do} \\
6. \qquad \text{Load } K_j, V_j \text{ from HBM to SRAM.} \\
7. \qquad \textbf{for } 1 \le i \le T_r \textbf{ do} \\
8. \qquad \quad \text{Load } Q_i, O_i, \ell_i, m_i \text{ from HBM to SRAM.} \\
9. \qquad \quad \text{Compute } S_{ij} = Q_i K_j^\top \in \mathbb{R}^{B_r \times B_c}. \\
10. \qquad \quad \tilde{m}_{ij} = \mathrm{rowmax}(S_{ij}), \ \tilde{P}_{ij} = \exp(S_{ij} - \tilde{m}_{ij}), \ \tilde{\ell}_{ij} = \mathrm{rowsum}(\tilde{P}_{ij}). \\
11. \qquad \quad m_i^{\text{new}} = \max(m_i, \tilde{m}_{ij}), \ \ell_i^{\text{new}} = e^{m_i - m_i^{\text{new}}} \ell_i + e^{\tilde{m}_{ij} - m_i^{\text{new}}} \tilde{\ell}_{ij}. \\
12. \qquad \quad O_i \leftarrow \mathrm{diag}(\ell_i^{\text{new}})^{-1} \left( \mathrm{diag}(\ell_i) e^{m_i - m_i^{\text{new}}} O_i + e^{\tilde{m}_{ij} - m_i^{\text{new}}} \tilde{P}_{ij} V_j \right). \\
13. \qquad \quad \ell_i \leftarrow \ell_i^{\text{new}}, \ m_i \leftarrow m_i^{\text{new}}. \\
14. \qquad \textbf{end for} \\
15. \quad \textbf{end for} \\
16. \quad \textbf{return } O. \\
\hline
\end{array}

We recognize the usual suspects from the math we derived earlier. I will only focus on the forward pass without causal masking to keep the focus on the core algorithm. Let's start with the simple reference implementation in PyTorch:

# q,k,v are of shape: `(B, N_h, S, D_h)`
# B: batch_size
# N_h: num heads
# S: sequence length
# D_h: head dim
def simple_attn(q, k, v):
    att = q @ k.transpose(-2, -1) * (1.0 / math.sqrt(k.size(-1)))
    att = F.softmax(att, dim=-1)
    y = att @ v
    return y
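Before profiling anything, it's worth checking this reference against PyTorch's own fused attention (available since 2.0). A quick sketch, assuming simple_attn from above is in scope and using placeholder shapes:

import torch
import torch.nn.functional as F

q = torch.randn(2, 4, 128, 64, device="cuda")
k, v = torch.randn_like(q), torch.randn_like(q)

# F.scaled_dot_product_attention applies the same 1/sqrt(D_h) scaling and no mask by default
expected = F.scaled_dot_product_attention(q, k, v)
assert torch.allclose(simple_attn(q, k, v), expected, atol=1e-4)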

My first implementation kernels/triton_flash_att.py follows the original Flash Attention paper very closely:

@triton.jit
def attn_kernel(Q, K, V, O, S, D, Tc, Tr, Bc, Br, softmax_scale, l, m):
    # ... [SETUP offsets here]
    # Outer loop over K, V blocks
    for j in range(0, Tc):
        # Load K_j, V_j from HBM to SRAM
        kj = tl.load(k_ptr + offset_j)  # (Bc, D)
        vj = tl.load(v_ptr + offset_j)  # (Bc, D)

        # Inner loop over Q blocks
        for i in range(0, Tr):
            # Load Q_i and previous O_i, l_i, m_i
            qi = tl.load(q_ptr + offset_i)  # `(Bc, D)`
            prev_oi = tl.load(o_ptr + offset_i)
            prev_li = tl.load(l_ptr + S_i_offset)
            prev_mi = tl.load(m_ptr + S_i_offset)

            # Compute attention scores for this block
            Sij = tl.dot(qi, tl.trans(kj)) * softmax_scale  # (Bc, Bc)

            # Online softmax update
            mij = tl.max(Sij, 1)  # Row-wise max
            pij = tl.exp(Sij - mij[:, None])
            lij = tl.sum(pij, 1)  # Row-wise sum

            # Update running statistics
            mi_new = tl.maximum(prev_mi, mij)
            alpha = tl.exp(prev_mi - mi_new)
            beta = tl.exp(mij - mi_new)
            li_new = prev_li * alpha + lij * beta

            # Update output
            oi_new = (alpha[:, None] * prev_li[:, None] * prev_oi
                      + beta[:, None] * tl.dot(pij, vj)) / li_new[:, None]

            # Write back to HBM
            tl.store(o_ptr + offset_i, oi_new)
            tl.store(m_ptr + S_i_offset, mi_new)
            tl.store(l_ptr + S_i_offset, li_new)

Some notes about this implementation: the pointer-offset setup is omitted in the snippet above, and the loop order (K/V in the outer loop, Q in the inner loop) follows the paper exactly.

Before looking at the profiling of this naive v1 implementation, let's first look at another requirement I brushed over quickly but that is really important: the SRAM limit. Because we live in the real world, we can't have infinite SRAM on GPUs (sigh), so we need a ballpark estimate of how much SRAM we need to allocate and whether it fits on my device.

Quick back-of-the-envelope calculation: at a minimum, the kernel needs to hold in SRAM simultaneously the running statistics l_i and m_i (Bc floats each), the Q_i, K_j, and V_j tiles (Bc × D floats each), and the score block S_ij (Bc × Bc floats).

With Bc = 32 and D = 64, total SRAM: 2 × Bc + 3 × Bc × D + Bc² floats = 2 × 32 + 3 × 32 × 64 + 32² = 7,232 floats = 7,232 × 4 bytes ≈ 28 KB. This fits comfortably in the ~48 KB of shared memory available per thread block. It does mean, though, that we are limited to ~2 blocks per SM. If you're confused about this limit, don't worry, I'll explain it promptly once we get into profiling.
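Wrapped up as a tiny helper (this mirrors the formula above; the breakdown into tiles is my reading of it):

def sram_floats_per_block(Bc, D):
    """Rough per-block SRAM estimate for the v1 kernel."""
    stats = 2 * Bc        # running l_i and m_i vectors
    tiles = 3 * Bc * D    # Q_i, K_j, V_j tiles, each (Bc, D)
    scores = Bc * Bc      # the S_ij block
    return stats + tiles + scores

print(sram_floats_per_block(Bc=32, D=64) * 4 / 1024)  # ~28.25 KB -> fits in the 48 KB per-block limit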

Profiling the v1 Implementation

As I mentioned in the Setup and hardware section, I’ll be using the Nsight Compute ncu tool from the start. It’s one of those tools that comes with a LOT of information, and it can be pretty overwhelming. I recommend watching this lecture to get a global overview of all the ncu-ui sections. I also highly recommend reading Modal’s GPU glossary performance section front to back to at least familiarize yourself with the key terms you need to understand GPU program performance, as well as gain a general understanding of how GPUs execute code. If you don’t know what a warp scheduler is, please go read that resource before continuing to the next section.

I ran the ncu profiler using this command to get the full set of timers and metrics:

#!/bin/bash

sudo CUDA_HOME=/opt/cuda /usr/local/NVIDIA-Nsight-Compute/ncu \
    --set full \
    --kernel-name "attn_kernel" \
    --import-source yes \
    -o profile_flash_attn_v1 \
    -f .venv/bin/python3 kernels/triton_flash_att.py

Once the profiling is done, we can open the report with ncu-ui. Here is an annotated view of the profile:

[Figure: annotated ncu profile overview]

Great, let's drill down on the key metrics. First, we can see that our kernel took 166.47 ms to execute. Our kernel used a grid size of (10, 64, 1), which matches our Triton grid of (B, N_h). Each block has a size of (128, 1, 1). Because we are using Triton to write our kernel, there is no way to specify the block size directly (although there is a num_warps param that we could use). This means that Triton chose to use 4 warps per block to run the kernel.

Next, the profiler actually picked up on a very interesting problem that we identified earlier:

The 2.00 theoretical warps per scheduler this kernel can issue according to its occupancy are below the hardware maximum of 8. This kernel’s theoretical occupancy (25.0%) is limited by the required amount of shared memory.

To understand this occupancy problem a little bit more, ncu provides a quite handy occupancy calculator tool. Let's look at what the tool says:

occupancy v1

Because our shared memory requirement per block is quite high, we can only have 2 active blocks at a time per SM! Low occupancy really hurts performance when there aren’t enough warps to hide the latency. But once occupancy is sufficient for latency hiding, increasing it further can degrade performance. Higher occupancy reduces resources per thread, potentially bottlenecking the kernel on registers or reducing the arithmetic intensity.
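We can reproduce ncu's numbers by hand from the device limits in the Setup section (a small sketch; the 4 warp schedulers per SM is my assumption for this Turing part):

smem_per_sm = 64 * 1024            # bytes, from cudaGetDeviceProperties
smem_per_block = 28 * 1024         # our kernel's requirement, estimated above
warps_per_block = 4                # Triton launched blocks of 128 threads
max_warps_per_sm = 1024 // 32      # 32 warps per SM on this GPU
schedulers_per_sm = 4              # assumption: 4 warp schedulers per Turing SM

blocks_per_sm = smem_per_sm // smem_per_block    # 2 resident blocks
active_warps = blocks_per_sm * warps_per_block   # 8 warps
print(active_warps / max_warps_per_sm)           # 0.25 -> the 25% theoretical occupancy
print(active_warps / schedulers_per_sm)          # 2.0  -> the "2 warps per scheduler" ncu reports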

To fix this we need to change our SRAM requirement: SRAM only depends on Bc and D_h. So we can either have smaller chunks (although this could impact memory bandwidth), or have more attention heads to reduce each D (if we control model architecture). Either way, we can go back to tweaking this later. Let’s look at the detail section of the profile report to see what other surprises await us!

memory_v1

The whole idea behind Flash Attention is to avoid going to main memory (HBM) for intermediate steps. The 11.58 GB of reads and 5.54 GB of writes confirm a bottleneck. Quick math: in this naive implementation, we iterate over columns (K, V) in the outer loop and rows (Q, O) in the inner loop. This loop order forces us to reload the accumulator O and the query Q from HBM for every single chunk of keys. With S/Bc = 64 chunks, we are reading and writing the entire output matrix 64 times! Math: size of O ≈ 83 MB. Reads ≈ 64 × (Q + O) ≈ 10.6 GB. Writes ≈ 64 × O ≈ 5.3 GB.


  for j in range(0, Tc):
      # Load K_j, V_j from HBM
      # ....
      for i in range(0, Tr):
          # ....
          # -> Reading previous O_i (Bc,D)
          prev_oi = tl.load(o_ptr + offset_i)
          # -> Writing previous O_i (Bc,D)
          tl.store(o_ptr + offset_i, oi_new)

We are treating HBM like a register, which explains the massive bandwidth usage!

The last thing to examine before we turn to potential solutions is the Source view from the profiler.

source_v1

ncu marks problematic kernel code lines with a warning emoji ⚠️, which is quite helpful! The div operation is usually costly on CPUs, and it’s basically the same story for GPUs. CUDA implements floating-point division on GPUs via reciprocal + multiply ops: using MUFU (Multi-Function Unit) followed by one or two FFMA or FMUL instructions to refine precision. If we look at the SASS disassembly for this line:

     MUFU.RCP R8, R8
     ...
@P5  FMUL R29, R29, 0.25
@P5  FMUL R15, R15, 0.25
@!P4 FMUL R10, R10, 16777216
@!P6 FMUL R29, R29, 16777216
@!P4 FMUL R23, R23, 16777216
@!P6 FMUL R15, R15, 16777216
     FMUL R10, R9, R10
     FMUL R29, R8, R29
     FMUL R9, R9, R23
     FMUL R8, R8, R15

Because we are doing the division inside the hot loop, we are doing a lot more work than necessary, and we can start thinking about how to avoid this division altogether.

  oi_new = (
      alpha[:, None] * prev_li[:, None] * prev_oi
      + beta[:, None] * tl.dot(pij, vj)
  ) / li_new[:, None]

Next Implementation Plan

Based on the ncu profiling data and our analysis of the memory traffic, we have identified three critical bottlenecks to address in the next iteration (V2).

  1. Invert the Loop Order: The massive 11.58 GB of main memory reads and 5.54 GB of writes are our biggest performance killer. They stem from treating HBM as a temporary register for our output accumulator. The current loop forces us to repeatedly read/write the accumulator O and reload Q for every block of K. We must invert the loops: parallelize the kernel over queries (rows) so that each thread block handles a tile of Q:

    • Stationary Data: Rows of O and Q remain in SRAM/registers for the entire kernel.
    • Streaming Data: K and V stream from HBM. Since all blocks share the same K/V, the L2 cache will absorb most of the traffic (coalesced access).
  2. Defer Normalization: Before describing the fix, let's recall the true attention output: O_i = \frac{\sum_j e^{S_{ij}} V_j}{\sum_j e^{S_{ij}}}. Inside the kernel, we currently divide during every iteration of the loop: O_{\text{new}} = \frac{\dots}{l_{\text{new}}}. This normalization step does not need to happen each time. We will store the unnormalized accumulated numerator and denominator in registers throughout the loop and perform the normalization once, at the end of the kernel, just before writing the result to HBM.

  3. Tuning Block Sizes: Theoretical occupancy is capped at 25% due to shared-memory pressure. Smaller B_c may reduce register/SRAM pressure and increase the number of active warps. But too small a tile can fragment memory access and reduce bandwidth efficiency. I’ll leave tweaking this value for later.

(My) Flash Attention v2

The v2 implementation (kernels/triton_flash_att_v2.py) makes critical changes for better performance:

Reorganized Loop Structure

@triton.jit
def attn_kernel(
    Q, K, V, O,
    S, stride_H, softmax_scale,
    D: tl.constexpr, Tc: tl.constexpr, Bc: tl.constexpr
  ):
    # ... offset setup
    # Load query block ONCE at the start !!
    qi = tl.load(
        q_ptr + offset_i, mask=mask[:, None], other=0.0
    )  # shape (Bc,D)

    # Block accumulator and running max in SRAM  !!
    prev_li = tl.zeros([Bc], dtype=tl.float32)
    prev_mi = tl.zeros([Bc], dtype=tl.float32) - float("inf")
    acc = tl.zeros([Bc, D], dtype=tl.float32)

    for j in range(0, Tc):
        # .. setup offset for K and V
        # Load K_j, V_j from HBM to SRAM
        kj = tl.load(k_ptr + offset_j)  # shape(Bc,D)
        vj = tl.load(v_ptr + offset_j)  # shape(Bc,D)

        # Compute Sij on SRAM : Q_i * K_j.T / sqrt(D)
        Sij = tl.dot(qi, tl.trans(kj)) * softmax_scale  # (Bc,Bc)


        # accumulators
        mij = tl.max(Sij, 1) # Rowmax(Sij): (Bc,)
        pij = tl.exp(Sij - mij[:, None])  # (Bc,Bc)
        lij = tl.sum(pij, 1)  # (Bc,)

        # Running maximum
        mi_new = tl.maximum(prev_mi, mij)

        # Compute scaling factors using previous_max
        alpha = tl.exp(prev_mi - mi_new)
        beta = tl.exp(mij - mi_new)

        # Update running sum
        li_new = prev_li * alpha + lij * beta

        # Update the output block
        acc = alpha[:, None] * acc + beta[:, None] * tl.dot(pij, vj)

        prev_li = li_new
        prev_mi = mi_new

    # Divide by the  last accumulated sum !
    acc = acc / prev_li[:, None]
    # Update in HBM
    tl.store(o_ptr + offset_i, acc)

I also adjusted the grid configuration:

grid = lambda META: (triton.cdiv(S, META["Bc"]), B * N_h)

where dimension 0 is the number of Q blocks (S / Bc) and dimension 1 is batch × heads (B × N_h). To reiterate the main changes compared to v1:

  1. Single loop: Each thread block processes one Q block, iterating over all K/V blocks. Each thread block is independent, allowing parallelism. With B=10, N_h=64, S=1024, D_h=32, Bc=32, we launch 32 × 640 = 20,480 independent thread blocks.
  2. Load Q once: Query block is loaded once and reused across all iterations
  3. Register accumulation: Output accumulator acc stays in fast registers, no main memory writes until the end
  4. No intermediate main memory traffic: Notice that l and m are kept in registers; only the final output is written to main memory.

Profiling the v2 Implementation

Great, let’s profile our newly crafted kernel and see what ncu tells us!

compare_v2_v1

Well, it is faster than v1, but only by 6%. Let's first look at the memory chart, since that was the biggest issue last time:

memory_v2_v1

Great! The rewrite reduced main memory reads to 412.18 MB (-92.98%). We are also writing 80 MB, corresponding exactly to the O matrix size. But there is still a clear problem somewhere: we would expect a lot more speedup, as we are now parallelizing 32x more. Let's look at the occupancy; maybe it is somehow worse than v1 (sanity check here)?

occupancy_v2

We are at 63% occupancy and are limited by shared memory with ~12KB per block. So that’s clearly not the culprit. Let’s go back to the summary page of the profile; ncu picks up on these issues:

  • Shared Load Bank Conflicts Est. Speedup: 63.57%: The memory access pattern for shared loads might not be optimal and causes on average a 6.3 - way bank conflict across all 293601280 shared load requests. This results in 1174579308 bank conflicts, which represent 63.64% of the overall 1845667948 wavefronts for shared loads. Check the  Source Counters section for uncoalesced shared loads.
  • Uncoalesced Shared Accesses Est. Speedup: 61.46%: This kernel has uncoalesced shared accesses resulting in a total of 1174405120 excessive wavefronts (62% of the total 1909063680 wavefronts). Check the L1 Wavefronts Shared Excessive table for the primary source locations.
  • MIO Throttle Stalls Est. Speedup: 50.06%: On average, each warp of this workload spends 18.4 cycles being stalled waiting for the MIO (memory input/output) instruction queue to be not full. This stall reason is high in cases of extreme utilization of the MIO pipelines, which include special math instructions, dynamic branches, as well as shared memory instructions. When caused by shared memory accesses, trying to use fewer but wider loads can reduce pipeline pressure. This stall type represents about 50.7% of the total average of 36.2 cycles between issuing two instructions.

It seems to be a problem with how we access shared memory in our kernel. A lot of terms are thrown around here, so let’s first understand what bank conflicts, uncoalesced shared access, and L1 wavefronts shared load mean.

The profiler is screaming at us about bank conflicts and wavefronts. Before we can fix anything, we need to understand what’s actually happening at the hardware level. Let’s build intuition from the ground up.

Remember from the memory hierarchy section that shared memory (SRAM) lives on-chip, right next to the compute units. It’s blazingly fast (~10-20 TB/s) but there’s a catch: shared memory isn’t a single monolithic block. It’s physically divided into 32 memory banks that can be accessed simultaneously (remember how each warp has 32 threads).

Think of these banks like checkout lanes at a grocery store. If 32 customers each go to a different lane, everyone gets served in one “cycle.” But if multiple customers try to use the same lane, they have to wait in line - that’s a bank conflict.

The mapping from memory address to bank is straightforward:

\text{Bank number} = \left\lfloor \frac{\text{byte address}}{4} \right\rfloor \mod 32

For a float32 array in shared memory, consecutive elements land in consecutive banks:

data[0]  → byte 0   → bank 0
data[1]  → byte 4   → bank 1
...
data[31] → byte 124 → bank 31
data[32] → byte 128 → bank 0   ← wraps around!

This is intentional! It means that when threads in a warp access consecutive array elements (a very common pattern), each thread hits a different bank. Perfect parallelism!

What Goes Wrong: The Strided Access Pattern

Now here's where things get spicy. Let's look at which lines in our kernel trigger this access pattern. Thankfully, ncu again shows us the exact problematic line:

tranpose_v2

So it turns out that this line is the culprit:

Sij = tl.dot(qi, tl.trans(kj)) * softmax_scale

But why, though? To understand why we are hitting shared memory bank conflicts, we'll need to go a little bit deeper. How deep, you ask? Well, to the end of the world… But first, let's finish going through the remaining issues ncu reported.

Wavefronts: The Unit of Memory Work

The profiler keeps mentioning “wavefronts” - what are those? A wavefront (sometimes called a memory transaction) is a single memory operation that the hardware executes atomically. In the ideal case: 32 threads request shared memory, all hit different banks → 1 wavefront → 1 cycle

Now the profiler numbers start making sense:

Metric | Value | Meaning
Shared load requests | 293,601,280 | Times our kernel asked for shared memory
Bank conflicts | 1,174,579,308 | Extra transactions due to serialization
Total wavefronts | 1,845,667,948 | Actual memory operations executed
Conflict rate | 63.64% | Almost 2/3 of bandwidth wasted!

The 6.3-way average conflict means that on average, 6.3 threads are fighting for the same bank.
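The "ways" figure is just the ratio of wavefronts actually issued to load requests; we can double-check the profiler's arithmetic:

requests = 293_601_280
conflicts = 1_174_579_308
wavefronts = 1_845_667_948

print(wavefronts / requests)    # ~6.29  -> the "6.3-way" average conflict
print(conflicts / wavefronts)   # ~0.636 -> 63.64% of all wavefronts are conflict overhead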

The profiler is telling us we're doing 6x more memory transactions than necessary. No wonder we only got a 6% speedup despite reducing main memory traffic by 93%!

Uncoalesced Accesses: Same but different, but same

The profiler also complains about uncoalesced shared accesses with 62% excessive wavefronts. This is essentially the same problem viewed from a different angle: threads access scattered addresses. Our column reads are stride-D accesses, which the hardware cannot combine efficiently. The "excessive wavefronts" metric counts how many extra transactions we're issuing beyond the theoretical minimum.

MIO Throttle: The Pipeline Backs Up

The third warning - MIO Throttle Stalls (50.06%) - is a consequence of the first two. MIO (Memory Input/Output) is the pipeline that handles shared memory instructions and the special math instructions that run on the SFU (plus dynamic branches).

When we’re issuing 6x more shared memory transactions than necessary, the MIO pipeline gets clogged. Warps have to stall waiting for the pipeline to clear, which is what the 50% stall rate is measuring.

PTX and fun

OK, let’s go back and drill down on this single line:

Sij = tl.dot(qi, tl.trans(kj)) * softmax_scale

Well, we can't learn much just by staring at it, so let's see what the Triton compiler generated for it. Triton compiles kernels down to PTX, and you can very easily map the PTX back to the Python source lines:

TRITON_CACHE_DIR="./triton_dump" python kernels/triton_flash_att_v2.py

awk '/\.loc.* 83 /{p=1; print; next} /\.loc/{p=0} p' triton_dump/[ID]/attn_kernel.ptx > attn_kernel_v2.ptx

Great! Let’s look at the assembly for the line triton_flash_att_v2.py:83:

.loc	1 83 34                         // triton_flash_att_v2.py:83:34
bar.sync 	0;
// We are storing stuff in shared memory `st.`
st.shared.v4.b32 	[%r4], {%r100, %r101, %r102, %r103};
st.shared.v4.b32 	[%r4+2048], {%r104, %r105, %r106, %r107};
st.shared.v4.b32 	[%r4+4096], {%r108, %r109, %r110, %r111};
st.shared.v4.b32 	[%r4+6144], {%r112, %r113, %r114, %r115};


// Here we are loading something from shared memory `ld.`
.loc	1 83 34                         // triton_flash_att_v2.py:83:34
ld.shared.v4.b32 	{%r388, %r389, %r390, %r391}, [%r8];
ld.shared.v4.b32 	{%r392, %r393, %r394, %r395}, [%r8+256];
ld.shared.v4.b32 	{%r396, %r397, %r398, %r399}, [%r8+16];
ld.shared.v4.b32 	{%r400, %r401, %r402, %r403}, [%r8+272];
ld.shared.v4.b32 	{%r404, %r405, %r406, %r407}, [%r8+32];
// ... A lot more ld.shared.v4.b32 instruction

// Then finally we do the actual matmul `tl.dot`
.loc	1 83 25                         // triton_flash_att_v2.py:83:25
fma.rn.f32 	%r516, %r132, %r388, 0f00000000;
fma.rn.f32 	%r517, %r133, %r389, %r516;
fma.rn.f32 	%r518, %r134, %r390, %r517;

// And this mul is the `.softmax_scale` we can ignore it
.loc	1 83 41                         // triton_flash_att_v2.py:83:41
mul.f32 	%r1028, %r25, %r579;
mul.f32 	%r1029, %r25, %r643;
mul.f32 	%r1030, %r25, %r707;
mul.f32 	%r1031, %r25, %r771;
mul.f32 	%r1032, %r25, %r835;

The storing part of the PTX is pretty straightforward: we are copying K from registers to shared memory, with each of the 128 threads in the block executing these instructions in parallel. st.shared.v4.b32 is a vectorized instruction, meaning each thread stores 4 × 4 bytes = 16 bytes in one shot. Remember that we have K_j of size (Bc, D), meaning each chunk is 32 × 64 floats × 4 bytes = 8192 bytes. We are issuing 4 stores of 16 bytes per thread (64 bytes per thread in total), and with 128 threads this amounts to exactly 8192 bytes. Important detail: we are storing K in row-major order in shared memory, with a row stride of 256 bytes. Here is how K is stored in shared memory (with a base offset of 0 to simplify):

[Figure: K stored row-major in shared memory]

Now, the next set of instructions is very interesting. ld.shared.v4.b32 is a vectorized load of 4 × 4 bytes = 16 bytes from shared memory. Remember again, these instructions are executed in parallel by all threads in the warp, so to understand where the conflict happens, we need to look at what addresses each thread is loading from shared memory.

ld.shared.v4.b32 	{%r388, %r389, %r390, %r391}, [%r8];
ld.shared.v4.b32 	{%r392, %r393, %r394, %r395}, [%r8+256];
ld.shared.v4.b32 	{%r396, %r397, %r398, %r399}, [%r8+16];
ld.shared.v4.b32 	{%r400, %r401, %r402, %r403}, [%r8+272];
ld.shared.v4.b32 	{%r404, %r405, %r406, %r407}, [%r8+32];

Looking at the pattern, it seems that we are doing multiple loads with an offset from a base register %r8, so to know from which memory address we are effectively loading, we need to figure out what this register holds for each thread. A quick grep across the PTX codebase and we find this:

mov.u32   %r1, %tid.x;              // Thread ID (0-127)
and.b32   %r77, %r1, 15;            // tid & 0xF (0-15)
shl.b32   %r93, %r77, 9;            // shift left by 9
add.s32   %r8, %r91, %r93;          // %r8 = smem_base + lane_group * 512

Nice! I thought it would be way more tedious. But let’s break down what these instructions are doing. Thankfully, we are in 2025 and LLMs exist to explain PTX; one prompt to Gemini later:

This means that for each thread in the warp we have:

\begin{array}{c|c|l}
\text{Lane} & \text{tid} \land 15 & \text{r8 = base + offset} \\ \hline
0 & 0 & \text{base} + 0 \\
1 & 1 & \text{base} + 512 \\
2 & 2 & \text{base} + 1024 \\
3 & 3 & \text{base} + 1536 \\
\cdots & \cdots & \cdots \\
14 & 14 & \text{base} + 7168 \\
15 & 15 & \text{base} + 7680 \\ \hline
\textbf{16} & \textbf{0} & \textbf{base + 0} \leftarrow \text{duplicate!} \\
17 & 1 & \text{base} + 512 \\
\cdots & \cdots & \cdots \\
31 & 15 & \text{base} + 7680 \\
\end{array}

Key insight: Due to masking tid & 15, only 16 unique base addresses exist. Lanes 16-31 duplicate lanes 0-15.

So within a single warp of 32 threads, threads 0-15 get unique base addresses (0×512, 1×512, …, 15×512) and threads 16-31 get duplicate addresses (same as lanes 0-15)! But bear in mind, we are not doing 2x the loading. When two lanes access the same address, the hardware broadcasts: it fetches once and delivers to both. No conflict here. Let's dig deeper to see where the conflict occurs.

Now that we have the base pointer address, let’s see what each thread loads. The PTX shows that we executed 32 loads per thread to get data for the matrix multiply:

ld.shared.v4.b32  [%r8];        // 4 floats at base + 0
ld.shared.v4.b32  [%r8+256];    // 4 floats at base + 256 (next row)
ld.shared.v4.b32  [%r8+16];     // 4 floats at base + 16 (next 4 cols)
ld.shared.v4.b32  [%r8+272];    // 4 floats at base + 272 (next row, next 4 cols)
ld.shared.v4.b32  [%r8+32];     // 4 floats at base + 32
ld.shared.v4.b32  [%r8+288];    // 4 floats at base + 288
ld.shared.v4.b32  [%r8+48];     // ...
ld.shared.v4.b32  [%r8+304];
ld.shared.v4.b32  [%r8+64];
ld.shared.v4.b32  [%r8+320];
ld.shared.v4.b32  [%r8+80];
ld.shared.v4.b32  [%r8+336];
ld.shared.v4.b32  [%r8+96];
ld.shared.v4.b32  [%r8+352];
ld.shared.v4.b32  [%r8+112];
ld.shared.v4.b32  [%r8+368];
... (32 total loads)

The offsets follow an interleaved pattern: each pair of loads fetches 16 bytes from two consecutive rows (offset and offset + 256), then steps 16 bytes to the right.

Let’s draw over our previous shared memory schema to see which addresses of K are loaded by which threads and to clearly visualize this interleaved access pattern:

[Figure: which K addresses each thread loads (interleaved access pattern)]

And there is the bank conflict right there!

\begin{array}{c|l|cccc|l}
\text{Lane} & \text{Address} & \text{Bank 0} & \text{Bank 1} & \text{Bank 2} & \text{Bank 3} & \\ \hline
0 & \text{base} + 0 & \bullet & \bullet & \bullet & \bullet & \\
1 & \text{base} + 512 & \bullet & \bullet & \bullet & \bullet & \leftarrow \text{conflict} \\
2 & \text{base} + 1024 & \bullet & \bullet & \bullet & \bullet & \leftarrow \text{conflict} \\
3 & \text{base} + 1536 & \bullet & \bullet & \bullet & \bullet & \leftarrow \text{conflict} \\
4 & \text{base} + 2048 & \bullet & \bullet & \bullet & \bullet & \leftarrow \text{conflict} \\
5 & \text{base} + 2560 & \bullet & \bullet & \bullet & \bullet & \leftarrow \text{conflict} \\
6 & \text{base} + 3072 & \bullet & \bullet & \bullet & \bullet & \leftarrow \text{conflict} \\
7 & \text{base} + 3584 & \bullet & \bullet & \bullet & \bullet & \leftarrow \text{conflict} \\
8 & \text{base} + 4096 & \bullet & \bullet & \bullet & \bullet & \leftarrow \text{conflict} \\
9 & \text{base} + 4608 & \bullet & \bullet & \bullet & \bullet & \leftarrow \text{conflict} \\
10 & \text{base} + 5120 & \bullet & \bullet & \bullet & \bullet & \leftarrow \text{conflict} \\
11 & \text{base} + 5632 & \bullet & \bullet & \bullet & \bullet & \leftarrow \text{conflict} \\
12 & \text{base} + 6144 & \bullet & \bullet & \bullet & \bullet & \leftarrow \text{conflict} \\
13 & \text{base} + 6656 & \bullet & \bullet & \bullet & \bullet & \leftarrow \text{conflict} \\
14 & \text{base} + 7168 & \bullet & \bullet & \bullet & \bullet & \leftarrow \text{conflict} \\
15 & \text{base} + 7680 & \bullet & \bullet & \bullet & \bullet & \leftarrow \text{conflict} \\ \hline
16\text{-}31 & \text{broadcast from lanes 0-15} & & & & & \\
\end{array}

Result: 16 requests all landing in banks 0-3. The hardware serializes these into 16 phases, which means a theoretical efficiency of 1/16 = 6.25% for this access, i.e., 15/16 = 93.75% of the shared memory bandwidth is wasted!
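You can replay the PTX addressing in a few lines of Python to confirm the 16-way conflict (base offset taken as 0, as in the diagram):

def bank(byte_addr):
    return (byte_addr // 4) % 32

# r8 = base + ((tid & 15) << 9), then each lane issues a 16-byte (v4) load at r8
offsets = sorted({(lane & 15) << 9 for lane in range(32)})       # 16 unique base offsets
print(offsets[:4])                                               # [0, 512, 1024, 1536]
print({bank(off) for off in offsets})                            # {0}: every lane starts in bank 0
print({bank(off + 4 * i) for off in offsets for i in range(4)})  # {0, 1, 2, 3}: 16 lanes piled onto 4 banks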

One thing we could do is set D = D+1 and see what happens. Theoretically, the off-by-one stride would spread accesses across all banks (this is the idea behind padding: adding a dummy column to break the stride alignment).

Unfortunately, you can't set D to a non-power-of-2, so we end up doubling the dimension. The padding does break the bank conflicts, but it ends up doing a lot more loads and is way slower than the unpadded version. The implementation in kernels/triton_flash_att_v2_padded.py is very close to v2, with the addition of a function that computes the padded head_dim:


def compute_padded_headdim(D_h):
    """
    Compute padded head dimension to avoid bank conflicts.

    The doubling provides extra padding that naturally breaks stride-based conflicts.
    With stride=64 instead of 32, threads access different bank groups.
    """
    # Find next power of 2
    if D_h <= 0:
        return 1
    # Check if already power of 2
    if (D_h & (D_h - 1)) == 0:
        # Already power of 2, double it
        return D_h * 2
    else:
        # Round up to next power of 2
        return 1 << (D_h - 1).bit_length()
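For example:

print(compute_padded_headdim(32))   # 64  (already a power of 2 -> doubled)
print(compute_padded_headdim(48))   # 64  (rounded up to the next power of 2)
print(compute_padded_headdim(64))   # 128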

We can see that bank conflicts went down to a "3.4-way bank conflict across all 21135360 shared store requests". So I guess this worked 😂

Back to the drawing board. Now that we understand the problem, we have several options:

  1. Padding: Allocate K as (Bc, D+1) in shared memory but only use D columns - breaks the stride alignment -> works, but the additional work makes it slower.
  2. Pre-transpose K: Store K in column-major order before the kernel runs, so “column” reads become row reads
  3. Larger D: If we increase head dimension to something not divisible by 32, conflicts naturally reduce (but this changes the model architecture)
  4. Swizzling: Use a permuted memory layout that distributes bank accesses (this is what CUTLASS and newer Flash Attention versions do)

In the next section, I’ll implement the transpose option and see how much performance we can claw back.

(My) Flash Attention v2 Transpose

The v2 transpose kernel kernels/triton_flash_att_v2_transpose.py feels a little bit like cheating: we transpose the K matrix beforehand. A small but important detail is to make sure that the transpose isn't just a view, and to force it to be contiguous in memory:

  k_trans = k.transpose(-1, -2).contiguous() # IMPORTANT!
  attn_kernel[grid](
    q, k_trans, v, o,
    S,
    q.stride(1), stride_k_d, stride_k_s,
    1 / math.sqrt(D_h),
    D_h, Tc, Bc,
  )
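The .contiguous() matters because transpose alone only swaps strides without moving any data; a quick check with illustrative shapes:

import torch

k = torch.randn(2, 4, 1024, 64)      # (B, N_h, S, D_h)
k_view = k.transpose(-1, -2)         # just a view: same storage, swapped strides
k_trans = k_view.contiguous()        # actually lays the data out in the new order

print(k_view.is_contiguous())        # False
print(k_view.stride())               # (262144, 65536, 1, 64)
print(k_trans.stride())              # (262144, 65536, 1024, 1)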

Now the only thing that changes is making sure we load the correct columns in the kernel. Note that I also hoisted the softmax scaling out of the inner loop to save mul ops:


@triton.jit
def attn_kernel(
    Q, K, V, O,
    S,
    stride_H,
    stride_k_d,
    stride_k_s,
    softmax_scale,
    D: tl.constexpr,
    Tc: tl.constexpr,
    Bc: tl.constexpr,
):
    #  ...

    # Pre-scale Q to save muls inside loop
    qi = qi * softmax_scale

    for j in range(0, Tc):
        # Pointer Math: (Row_Idx * Stride_Row) + (Col_Idx * Stride_Col)
        # Row_Idx is offs_d (0..D)
        # Col_Idx is current_cols (S dimension)
        offset_j_k = (offs_d[:, None] * stride_k_d) + (current_cols[None, :] * stride_k_s)

        # Load K (D, Bc) directly!
        kj = tl.load(k_ptr + offset_j_k)

        # ...
        Sij = tl.dot(qi, kj) # no transpose needed for kj !

    # ...

Let’s profile it and look at the result in comparison to the previous two kernels:

compare_v2_transpose

Another key metric to track (we haven't until now because we needed to fix the obvious stuff first) is the GPU Speed Of Light throughput. This is also known as roofline analysis.

roofline_v2_trans

Our v2 transpose kernel achieves a much higher fraction of peak compute now that the shared memory access conflicts are gone. This is corroborated by the warp scheduler statistics: we see +153% more eligible warps per cycle, since they are no longer stalled on serialized shared memory accesses.

Great! The kernel ran in 34 ms, which is a 145% improvement compared to v1! Looking at the ncu performance opportunities, we can see that the next big blocker is:

Mio Throttle Stalls Est. Speedup: 43.97%: On average, each warp of this workload spends 6.7 cycles being stalled waiting for the MIO (memory input/output) instruction queue to be not full. This stall reason is high in cases of extreme utilization of the MIO pipelines, which include special math instructions, dynamic branches, as well as shared memory instructions. When caused by shared memory accesses, trying to use fewer but wider loads can reduce pipeline pressure. This stall type represents about 44.0% of the total average of 15.3 cycles between issuing two instructions.

Still Having an MIO Bottleneck

So we fixed bank conflicts and got a nice 145% speedup. What's left? The MIO (Memory Input/Output) pipeline is now our bottleneck. Despite the name, this isn't about main memory: MIO handles shared memory instructions and special-function math instructions, the same culprits ncu listed above.

Every iteration of our inner loop calls tl.exp for the softmax and tl.max for numerical stability. These operations go through the SFU (Special Function Unit), which is much slower than the main FMA units. With Bc=32, we’re doing these expensive operations very frequently relative to the actual matmul work.

I tried a few things to reduce MIO pressure:

  1. FP16: Wrote an FP16 kernel hoping for a speedup by lowering shared memory requirements and increasing D_h and Bc. Ended up slower than FP32 because I kept fighting the compiler—it kept inserting extra shared memory accesses for type conversions, and I didn’t get any SRAM savings because FP16 precision sucks for accumulators and I needed to keep all the on-chip accumulators in FP32 anyway.

  2. Larger block sizes: Increasing Bc means fewer loop iterations, so fewer exp/max calls per output element. But my RTX 2070 doesn’t have enough shared memory to go much higher without killing occupancy.

Anyway, the main takeaway here is to try to reduce MIO stalls by issuing fewer, wider shared memory accesses and fewer special-function math instructions.

The core Tensor Core problem

Looking at the instruction stats, something else jumped out:

inst_stats

The kernel is dominated by FFMA (Fused Floating-point Multiply-Add) instructions. These are regular CUDA core operations, not tensor core operations. Tensor cores can do 4x4 matrix multiplies in a single cycle—they’re the whole reason modern GPUs are so fast at deep learning. But my tl.dot calls aren’t using them for some reason!

I fought with this for a while. Tried different memory layouts, alignment hints, different block sizes… Nothing worked. Turns out, on SM 7.5 (Turing architecture), Triton struggles to generate tensor core code (and Turing's tensor cores only accept FP16/INT8 inputs anyway, so an FP32 kernel can't use them). The compiler just falls back to regular FMA instructions that run on regular old CUDA cores, clogging up the pipe.

I guess I need to stop being poor and buy a newer GPU. More realistically, I’ll rent an H100 and see if the same code magically starts using tensor cores on SM 9.0…

Comparison with the Real Flash Attention v2

At this point I’ve hit the limits of what I can optimize on my hardware.

Let’s open up the actual Flash Attention v2 paper4 and see what we got right and what we missed:

1. Reducing Non-Matmul FLOPs ✅ (partially)

We did defer the final division to the end of the kernel. But FA2 goes further - it restructures the entire online softmax to minimize exp and max operations. GPUs have separate units for matmul (Tensor Cores, ~300 TFLOPs on A100) vs generic math (CUDA cores, ~20 TFLOPs). Every non-matmul operation is 15x slower, so minimizing them matters a lot.

2. Parallelism Over Sequence Length

We nailed this one. Our v2 kernel parallelizes over (S/Bc, B*N_h) instead of just (B, N_h). This is exactly what FA2 does - split the query sequence into chunks and assign them to different thread blocks.

3. Warp-Level Work Partitioning

This is where Triton abstracts too much away. FA2 carefully controls how warps within a block divide work: instead of “Split-K” (warps split the K/V dimension and sync to combine results), they use “Split-Q” (warps split the Q dimension and work independently). This removes expensive synchronization barriers.

In Triton, we don’t control warp-level scheduling - the compiler decides. We could inspect the generated PTX to see what it’s doing, but we can’t easily change it.

4. Larger Head Dimensions 🤷

FA2 supports D_h up to 256 efficiently. Our kernel works with any head dimension, but we haven’t optimized the tiling specifically for larger values. With D_h=128 or 256, you’d want different block sizes and potentially different memory layouts. For implementation simplicity reasons, I kept Bc==Br, so we haven’t gotten to play with tweaking these parameters.

Conclusion

Don’t get it twisted—Flash Attention is a brilliant example of algorithm-hardware co-design. My goal here wasn’t to diminish the work of Tri Dao, but to walk through the reasoning myself. Every brilliant solution seems kind of obvious once you have it in front of you, but having the insight and technical chops to come up with it in the first place is a whole other thing.

I tried my best to demystify the core ideas behind Flash Attention, and show how understanding the GPU memory hierarchy lets you optimize iteratively: profile, rewrite, tweak, rinse and repeat.

If you’ve fallen asleep three times reading this and just woke up, here are the key takeaways:

  1. Tiling: Process attention in blocks that fit in fast SRAM
  2. Online softmax: Compute softmax incrementally without materializing the full attention matrix
  3. Minimize HBM traffic: Load data once, keep accumulators in registers/SRAM, write only the final result

There’s more to explore - FA3 brings asynchronous memory copies and FP8 support, FA4 pushes things even further with Blackwell-specific optimizations. But that’s for another post (and another GPU).

The full implementation is available in this repository, including the Triton kernels and profiling scripts.

References