⏲️ Estimated reading time ~45min.
Flash Attention has become one of the most impactful optimizations in modern deep learning. Since the original paper was published in 2022, we’ve seen four major versions—each squeezing more performance out of increasingly powerful hardware. But here’s the thing: reading papers is one thing, understanding why these optimizations were made is another entirely.
My goal here is simple: start from first principles, implement FlashAttention v1 exactly as described in the paper, profile it, find the bottlenecks, and see how far we can push it. We’ll build intuition by iterating on the algorithm, discovering through profiling exactly why v2, v3, and v4 were necessary. Think of this as archaeology—digging through the performance layers to understand what each version was really solving.
So let’s put ourselves in the shoes of that mythical Stanford grad student. You’ve finally finished configuring your neovim and archlinux setup (a multi-year endeavor, naturally). You open up a fresh LLaMA model to peek under the hood. Text goes in, gets tokenized, embedded, then flows through a stack of transformer blocks. Standard stuff. But then you look closer at the attention mechanism—three projections and then… there it is:
scores = torch.matmul(q, k.transpose(3, 2)) / math.sqrt(self.head_dim)
scores = F.softmax(scores, dim=-1)
output = torch.matmul(scores, v) # (B, N_h, S, D_h)
This monstrosity of code is staring you right in the face. For those who don’t immediately see the problem with these three lines, let me annotate how they would normally execute in PyTorch (without compilation).
# 1. We load q `(B, N_h, S, D_h)` and k `(B, N_h, S, D_h)`
# 2. We compute Q.Kt and write it back to HBM. Note score is `(B, N_h, S, S)`.
scores = torch.matmul(q, k.transpose(3, 2)) / math.sqrt(self.head_dim)
# 3. Reload the scores tensor to compute the softmax and write it back to HBM
scores = F.softmax(scores, dim=-1)
# 4. Load v `(B, N_h, S, D_h)`, load the scores from HBM
# 5. Compute scores@v and write it back to HBM.
output = torch.matmul(scores, v) # `(B, N_h, S, D_h)`
Do you see it now? We have three tensors q, k, v, each of shape (B, N_h, S, D_h)¹. The output tensor of the attention mechanism is also (B, N_h, S, D_h), and yet somewhere in the middle we had to materialize a (B, N_h, S, S) tensor for funsies. This is the first critical bottleneck of attention: quadratic memory complexity. Take a standard training sequence length of S=8192. Computing attention naively requires O(S²) memory to store the full attention matrix, which consumes several gigabytes of GPU memory. Crucially, in modern transformers S >> D_h (sequence length is much larger than head dimension): we typically have S=8192 or more while D_h=64 or 128. This massive asymmetry is what makes the Flash Attention algorithm possible. The second big issue is the back and forth to HBM. Modern GPUs have compute throughput vastly exceeding memory bandwidth, so repeatedly reading Q, K, V, and the scores from slow High Bandwidth Memory (HBM) greatly hurts performance (more on this later).
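To put numbers on the quadratic blow-up, here is a small back-of-the-envelope calculator. It assumes fp32 scores (4 bytes per element) and an illustrative configuration of 32 heads with D_h=128 at S=8192; `score_bytes` and `qkv_bytes` are hypothetical helper names, not part of any library.

```python
def score_bytes(batch: int, n_heads: int, seq_len: int, bytes_per_elem: int = 4) -> int:
    """Bytes needed to materialize the (B, N_h, S, S) score tensor."""
    return batch * n_heads * seq_len * seq_len * bytes_per_elem


def qkv_bytes(batch: int, n_heads: int, seq_len: int, head_dim: int,
              bytes_per_elem: int = 4) -> int:
    """Bytes for one (B, N_h, S, D_h) tensor such as q, k, v or the output."""
    return batch * n_heads * seq_len * head_dim * bytes_per_elem


MiB, GiB = 2**20, 2**30
S, D_h, N_h = 8192, 128, 32  # illustrative values (assumption)

per_head = score_bytes(1, 1, S)
all_heads = score_bytes(1, N_h, S)
output = qkv_bytes(1, N_h, S, D_h)

print(f"scores, one head  : {per_head / MiB:.0f} MiB")
print(f"scores, {N_h} heads : {all_heads / GiB:.0f} GiB")
print(f"output, {N_h} heads : {output / GiB:.3f} GiB")
print(f"scores / output = S / D_h = {all_heads // output}")
```

Doubling S quadruples the score tensor but only doubles q, k, v and the output, which is the asymmetry the rest of this post exploits.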
The whole idea of Flash Attention is to bypass these intermediate steps—i.e., go from tensors q, k, v to the output tensor directly and compute the attention in one go, with minimal memory footprint (materializing only the tensors we need) and minimal back and forth to HBM, and hopefully getting to O(S) memory complexity.
The Plan: Start with FlashAttention v1 from the original paper. Implement it faithfully in Triton, profile with NVIDIA’s tooling, identify the bottlenecks, then iterate. Each optimization will teach us something about the GPU memory hierarchy and why the subsequent versions (v2, v3, v4) introduced the changes they did. Four versions represent an enormous amount of engineering work; let’s see how much of that journey we can reconstruct just by following the profiler’s breadcrumbs.
First things first, I’ll use Triton to implement the attention kernels. I wanted to implement them in CUDA, but I thought this was a good opportunity to learn Triton. I’m always skeptical about DSLs and abstractions that promise “write once, run fast everywhere,” but like my very wise friend Chris Fleetwood said, you can’t go wrong learning something shipped inside PyTorch. Running Triton is also extremely straightforward and cuts out the mess of boilerplate and C++ code you have to write when implementing CUDA kernels, especially if you want to call them from Python. I actually went back and reimplemented Flash Attention in CUDA to have a bit more control, but that will come later (a little teasing).
I’ll compare these kernels directly with a reference PyTorch implementation, first to make sure the kernel is doing what it’s supposed to, and second to profile the kernel against the baseline.
Triton is a Python-based DSL for writing GPU kernels. The pitch is simple: write Python-like code at the block level, and the compiler handles individual thread management, memory coalescing, and other low-level optimizations. What makes Triton fundamentally different from CUDA is the abstraction level: in CUDA, you explicitly program the individual threads of a block. In Triton, you write your kernel thinking about blocks/tiles of data, a bit like how you’d write torch code, and the compiler figures out how to map that onto threads. It handles per-thread data access patterns, synchronization barriers, and work splitting. My skepticism comes from the fact that I don’t really trust these “magic” compilers for running parallel code; usually, the promised performance falls apart quickly. But Triton has a few things going for it:
I’m running this on my personal machine:
Here are the GPU specs from cudaGetDeviceProperties. These numbers will be useful once we start profiling our kernel to detect performance issues:
| Property | Value |
|---|---|
| Name | NVIDIA GeForce RTX 2070 |
| Compute Capability | 7.5 |
| Total Memory | 8 GB |
| MultiProcessor Count (SM_count) | 36 |
| Max Threads per MultiProcessor | 1024 |
| Warp Size | 32 |
| L2 Cache Size | 4 MB |
| Shared Memory per Block | 48 KB |
| Shared Memory per MultiProcessor | 64 KB |
| Registers per MultiProcessor | 65536 |
For profiling, I use three tools at different granularities:
torch.profiler: Quick and dirty. Good for seeing wall-clock time and basic GPU utilization. I use this for initial sanity checks.
NVIDIA Nsight Systems (nsys): System-wide profiler. Shows CPU/GPU timeline, kernel launches, memory transfers. Great for spotting gaps where the GPU is idle.
NVIDIA Nsight Compute (ncu): The heavy hitter. Gives you everything: occupancy, memory throughput, warp stalls, instruction mix, bank conflicts. This is what I’ll use for deep-diving into kernel performance. You can run it with:
sudo ncu --set full --kernel-name "attn_kernel" -o profile_output -f python script.py
Then open the .ncu-rep file with ncu-ui for the full analysis.
For modern GPUs, compute throughput vastly exceeds memory bandwidth: an A100 can do ~300 TFLOPs (dense FP16 on its tensor cores) but only has ~2 TB/s of memory bandwidth. This massive compute-to-memory ratio means that naive algorithms spend most of their time waiting for memory, not computing. Flash Attention restructures the computation to maximize arithmetic intensity (FLOPs per byte transferred).
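A roofline-style sketch makes the "waiting for memory" claim concrete. It reuses the ~300 TFLOPs and ~2 TB/s figures above, assumes fp32 tensors for a single head, and assumes roughly 5 FLOPs per score for a standalone softmax pass (max, subtract, exp, sum, divide). Caching is ignored, so this is an order-of-magnitude argument only.

```python
PEAK_FLOPS = 300e12  # FLOP/s, the ~300 TFLOPs figure quoted above
PEAK_BW = 2e12       # bytes/s, ~2 TB/s of HBM bandwidth

# Ridge point: below this many FLOPs per byte, memory bandwidth is the limit.
RIDGE = PEAK_FLOPS / PEAK_BW


def intensity(flops: float, bytes_moved: float) -> float:
    """Arithmetic intensity in FLOPs per byte."""
    return flops / bytes_moved


def regime(ai: float) -> str:
    return "compute-bound" if ai >= RIDGE else "memory-bound"


S, D = 8192, 64
FP32 = 4

# Naive q @ k^T for one head: 2*S*S*D FLOPs, reads q and k, writes the S x S scores.
qk_ai = intensity(2 * S * S * D, FP32 * (2 * S * D + S * S))

# Standalone softmax pass: ~5 FLOPs per score (assumed), read + write every score.
softmax_ai = intensity(5 * S * S, FP32 * 2 * S * S)

print(f"ridge point: {RIDGE:.0f} FLOPs/byte")
for name, ai in [("q @ k^T", qk_ai), ("softmax", softmax_ai)]:
    print(f"{name:8s} {ai:7.3f} FLOPs/byte -> {regime(ai)}")
```

Even the matmul lands well below the ridge point, because writing the S × S scores back dominates its byte count.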
Let’s build the intuition behind the core ideas by walking back from the solution. The goal is to build the output tensor in one shot (a single kernel).
We are in GPU programming land. If you want to minimize transfers to main memory—like any good GPU engineer who has read Simon’s matmul blog post (stop reading this post and go read it. I am serious. It’s the best matmul writeup on the internet.)—you’d immediately think about computing blocks of output in parallel and writing each block to memory.
This is commonly known as tiling, and every GPU kernel out there uses this trick. The idea is simple but powerful: instead of streaming data directly from slow global memory for each operation, you load a tile (a small block of data) into fast shared memory once, reuse it across multiple computations within a thread block, and only write the final results back. This dramatically reduces memory bandwidth bottlenecks by exploiting data locality and the GPU’s memory hierarchy—exactly what you need when global memory access is orders of magnitude slower than compute.
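Here is a minimal pure-Python sketch of the tiling idea for a matrix multiply. It only illustrates the loop structure: the `a_tile`/`b_tile` slices stand in for data a real kernel would stage in shared memory, and `tile` is an arbitrary illustrative size.

```python
def matmul(a, b):
    """Reference triple-loop matmul."""
    n, k, m = len(a), len(b), len(b[0])
    return [[sum(a[i][p] * b[p][j] for p in range(k)) for j in range(m)]
            for i in range(n)]


def tiled_matmul(a, b, tile=2):
    """Compute C = A @ B one (tile x tile) output block at a time."""
    n, k, m = len(a), len(b), len(b[0])
    c = [[0 for _ in range(m)] for _ in range(n)]
    for i0 in range(0, n, tile):            # which output tile (rows)
        for j0 in range(0, m, tile):        # which output tile (cols)
            for p0 in range(0, k, tile):    # march along the shared dimension
                # Load one tile of A and one tile of B (shared memory on a GPU).
                a_tile = [row[p0:p0 + tile] for row in a[i0:i0 + tile]]
                b_tile = [row[j0:j0 + tile] for row in b[p0:p0 + tile]]
                # Each loaded element is reused up to `tile` times below.
                for ii, a_row in enumerate(a_tile):
                    for jj in range(len(b_tile[0])):
                        c[i0 + ii][j0 + jj] += sum(
                            a_row[pp] * b_tile[pp][jj] for pp in range(len(a_row)))
    return c


A = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
B = [[1, 0, 2], [0, 1, 0], [3, 0, 1]]
print(tiled_matmul(A, B) == matmul(A, B))
```

The result is identical to the naive product; only the order in which data is touched changes, which is exactly what matters on a GPU.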
Let’s first look at how the output tensor is computed from the attention formula. We’ll focus on a single batch and head (B, N_h) to clearly see the shapes of the matmuls and their elements. We’ll also omit causal masking here (although it’s not complicated to add it to the attention scores before the softmax):
$$O = PV, \qquad P = \text{softmax}\left(\frac{QK^\top}{\sqrt{D_h}}\right), \qquad Q, K, V, O \in \mathbb{R}^{S \times D_h}, \quad P \in \mathbb{R}^{S \times S}$$
If you have been following along, you’re probably scratching your head wondering why we are sharding the V tensor row-wise. Sorry for pulling a quick one here, but this formulation of the output (using the rows of V instead of the columns) will actually help us build the flash attention algorithm. For each row of the output:
$$O_i = \sum_{k=1}^{S} P_{ik} V_k$$
where $O_i$ and $V_k$ are rows of length $D_h$ and each $P_{ik}$ is a scalar weight.
Still, why split row-wise and not column-wise? First of all, sharding on the D dim makes less sense. As I pointed out earlier, D is usually 64 or 128, whereas the sequence length S could be 8192 or more depending on the context window. If we tried to split column-wise (along D), the math would look like this:
$$O_{:,\,1:D_h/2} = P \, V_{:,\,1:D_h/2}$$
This would force us to load the entire score matrix P just to compute the left half of the output. But we can’t store the entire matrix; that’s the exact problem FlashAttention solves! By splitting row-wise into $T_c = S / B_c$ blocks:
$$O_i = \sum_{j=1}^{T_c} P_{ij} V_j$$
where $P_{ij}$ is now the $j$-th block of $B_c$ scores from row $i$ and $V_j$ is the matching block of $B_c$ rows of $V$.
Now, the genius piece of flash attention is this: if we found a way to build $O_i$ iteratively for a given $i$, we could calculate a small chunk $P_{ij}$, multiply it by $V_j$, add it to the sum, and discard $P_{ij}$. This means we only ever allocate a small block of memory!
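A pure-Python sketch of the row-wise decomposition: it accumulates one output row from (P_ij, V_j) blocks and checks it against the full product. Note the catch it deliberately leaves in place: `softmax` still needs the whole score row for its denominator, which is exactly the piece we don’t have yet. Function names and toy inputs are made up for illustration.

```python
import math


def softmax(row):
    """Plain softmax over a full row (numerical stability comes later)."""
    exps = [math.exp(x) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def attention_row(scores_row, V):
    """Reference: O_i = sum_k P_ik * V_k over the whole row."""
    p = softmax(scores_row)
    D = len(V[0])
    return [sum(p[k] * V[k][d] for k in range(len(V))) for d in range(D)]


def attention_row_blocked(scores_row, V, Bc=2):
    """Same O_i, accumulated one (P_ij, V_j) block at a time."""
    p = softmax(scores_row)  # still needs the WHOLE row for the denominator
    D = len(V[0])
    o = [0.0] * D
    for j0 in range(0, len(V), Bc):
        p_ij = p[j0:j0 + Bc]   # small chunk of the probability row
        v_j = V[j0:j0 + Bc]    # matching Bc rows of V
        for weight, v_row in zip(p_ij, v_j):
            for d in range(D):
                o[d] += weight * v_row[d]
        # p_ij and v_j are no longer needed and could be discarded here
    return o


scores = [0.5, -1.0, 2.0, 0.0, 1.5]
V = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [2.0, -1.0], [0.5, 0.5]]
print(attention_row(scores, V))
print(attention_row_blocked(scores, V))
```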
But why does this matter anyway? Isn’t it still mathematically equivalent if we compute the whole P_i row? Well, if you live in some mythical mathematical idea world, yes. But this code runs on hardware, and when and how we access memory does matter a LOT!
Understanding GPU memory is crucial, especially because GPUs are SIMT (Single Instruction, Multiple Thread) machines. If you’re familiar with SIMD programming on the CPU side, you usually need to use explicit vector instructions to pack and align vectors before executing instructions. SIMT, on the other hand, lets you write scalar thread code that the GPU transparently executes as a lockstep vector, coalescing loads automatically when memory is well-aligned.
GPUs execute thousands of lightweight threads in lockstep (called warps). Multiple warps form a block, and blocks are scheduled into SMs. Looking back at the Setup section, we can see that my RTX 2070 has 36 SMs (streaming multiprocessors), and each SM can handle 1024 threads (32 warps) concurrently.
GPU cores are simple and rely on this massive parallelism to hide latency: whenever one warp stalls waiting on memory, the scheduler swaps in another warp that already has its data “in flight.” This only works if kernels expose enough parallelism and if memory accesses are structured to take advantage of locality². A core part of the CUDA programming model is that it exposes the memory hierarchy directly, forcing the programmer (or compiler) to reason explicitly about where data lives and how far it is from the compute units. Memory proximity matters enormously:
DRAM/HBM (High Bandwidth Memory): I don’t have HBM on my RTX 2070, but even on high-end cards like the A100 with large capacity (~80GB), this memory is relatively far from the SMs. Despite impressive raw bandwidth (~1–2 TB/s), latency is high, so naive repeated reads kill performance.
L2 Cache: A chip-wide cache that helps buffer global memory traffic. Still far slower than on-SM memories. L2 cache behavior is typically leveraged implicitly through access patterns rather than explicit optimization.
L1 / Shared Memory (SRAM): On-SM, extremely fast, and explicitly managed in CUDA. This is where SRAM truly shines compared to HBM. First, the physics: SRAM sits directly on the SM, mere micrometers from the compute units (vs millimeters for HBM - that’s 1000x closer!). It uses 6-transistor flip-flop circuits that hold state without refresh cycles, unlike HBM’s capacitor-based cells. This means SRAM access takes only ~20-30 cycles vs ~200-600 for HBM.
The bandwidth story is even more compelling: each SM gets its own SRAM (up to ~164 KB of shared memory per SM on an A100; a meager 64 KB per SM, 48 KB per block, on my poor RTX 2070). Summed across the A100’s 108 SMs, that adds up to roughly 19 TB/s of aggregate SRAM bandwidth (the figure quoted in the FlashAttention paper), compared to “only” ~1.5-2 TB/s to HBM, an order of magnitude more. But here’s the real beauty: SRAM is explicitly programmer-controlled! Unlike HBM accesses, which go through the L2 cache, the L1 cache, and replacement policies you can’t control, with SRAM you orchestrate the exact choreography of data movement and computation. No cache thrashing, no surprise evictions, just deterministic high-speed access exactly when you need it.
Registers: The closest memory to the compute units. Tiny (~256KB per SM) but insanely fast (100+ TB/s effective). The compiler allocates these, and register pressure directly impacts occupancy (how many warps can run concurrently, which in turn affects latency hiding).
Look very closely at this diagram of the memory hierarchy, with bandwidths and memory sizes, taken from the Flash Attention v1 paper, and try to internalize it. It is extremely important to take I/O into account when designing high-performance algorithms.

If you’re still following along, the main insight is to restructure attention so that all intermediate activations fit within these fast memories (registers and shared memory), minimizing slow HBM traffic and maximizing warp-level parallelism. This is why building the output block by block, with tiles small enough to live on-chip, is a core pillar of fast attention. But it is still unclear at this stage how we could build the output incrementally.
Before we go deeper into how we could build these blocks, we need one last quick detour into something I omitted until now for simplification reasons. The last lego block we need in our toolbox before building flash attention is softmax numerical stability.
Numerically stable softmax
Until now I wrote the softmax function for a vector $x$ as:
$$\text{softmax}(x)_i = \frac{e^{x_i}}{\sum_j e^{x_j}}$$
But what happens if input values are large (e.g., $x_i = 1000$)? Computing $e^{1000}$ will exceed the maximum value a floating-point number can hold, resulting in inf (infinity) or NaN (Not a Number) errors. Conversely, underflow can happen if values are very negative: the denominator might vanish to zero, leading to division-by-zero errors. The trick is to numerically stabilize the computation by subtracting the maximum logit.
The key property is that softmax is invariant to adding or subtracting a constant $c$:
$$\text{softmax}(x + c)_i = \frac{e^{x_i + c}}{\sum_j e^{x_j + c}} = \frac{e^{c}\, e^{x_i}}{e^{c} \sum_j e^{x_j}} = \text{softmax}(x)_i$$
For numerical stability, we usually subtract the maximum value $m = \max_j x_j$. Setting $c = -m$ ensures that the exponentials are at most 1, preventing overflow:
$$\text{softmax}(x)_i = \frac{e^{x_i - m}}{\sum_j e^{x_j - m}}$$
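A quick pure-Python demonstration of both failure modes and the fix. One difference from GPU code: Python’s `math.exp` raises `OverflowError` on huge inputs instead of silently returning inf as float32 kernels would. The function names are illustrative.

```python
import math


def naive_softmax(x):
    exps = [math.exp(v) for v in x]
    total = sum(exps)
    return [e / total for e in exps]


def stable_softmax(x):
    m = max(x)                            # subtract the row max: every exponent <= 0
    exps = [math.exp(v - m) for v in x]   # so every exponential is in (0, 1]
    total = sum(exps)                     # >= 1, since the max term contributes e^0 = 1
    return [e / total for e in exps]


big = [1000.0, 1001.0, 1002.0]
try:
    naive_softmax(big)
except OverflowError as err:
    print("naive softmax failed:", err)

print(stable_softmax(big))
print(stable_softmax([0.0, 1.0, 2.0]))  # same values: softmax is shift-invariant
```

Because the max term always contributes exactly 1 to the denominator, the stable version also cannot divide by zero for very negative inputs such as `[-1000.0, -1001.0]`.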
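So let’s go back to incrementally computing a row of the output $O_i$. Let’s start slow to build intuition, treating each block as a single score for now. We take row $i=0$ and $j=0$. This means we are focusing on the first score $S_{00}$ of the row and the first row $V_0$ of the value tensor $V$.
If you already see an issue with the current softmax, bear with me a minute. Remember that each element of the (softmaxed) score matrix is computed as $P_{ij} = \frac{e^{S_{ij} - m_i}}{\sum_k e^{S_{ik} - m_i}}$, with $m_i = \max_k S_{ik}$. So the denominator is clearly broken (we need the whole row to compute this sum).
For now we only have $S_{00}$ in the row, so the first value we computed, $O_0 = P_{00} V_0$, is clearly broken. Similarly, $m_0$ is supposed to be the maximum value of the whole row; for now we’ve only seen one value, so the max is also broken. Let’s continue to the second score of the row and see how we can fix what is broken.
When we see a new block, we can’t just throw away our previous work. Instead, we need to rescale what we’ve already computed. How, though? If we simplify the problem to only two values in this row ($S=2$) and we have already computed the first one, the correct output value should be:
$$O_0 = \frac{e^{S_{00} - m} V_0 + e^{S_{01} - m} V_1}{e^{S_{00} - m} + e^{S_{01} - m}}, \qquad m = \max(S_{00}, S_{01})$$
Thanks to the multiplicative property of the exponential ($e^{a - b} = e^{a - c}\, e^{c - b}$), we can easily build these from the previous work and the currently computed block. Writing $m_{\text{old}}, \ell_{\text{old}}$ for the running max and sum so far, and $m_{ij}, \ell_{ij}$ for the max and sum of the new block, this is done using running correction factors:
$$m_{\text{new}} = \max(m_{\text{old}}, m_{ij}), \quad \alpha = e^{m_{\text{old}} - m_{\text{new}}}, \quad \beta = e^{m_{ij} - m_{\text{new}}}, \quad \ell_{\text{new}} = \alpha\, \ell_{\text{old}} + \beta\, \ell_{ij}$$
In our two-value example, $m_{\text{old}} = S_{00}$, $m_{01} = S_{01}$, and $\ell_{\text{old}} = \ell_{01} = 1$.
These $\alpha$ and $\beta$ factors let us rescale our previous exponentials to the new maximum. Think of it as “adjusting the baseline”: when we find a new maximum or a new sum, we scale down the previous values accordingly.
Now here’s where the magic happens. We also need to update our output $O_0$. Remember, we had computed $O_{\text{old}}$ using only the first block, but now we need to incorporate the second block. The update formula rescales the old output and adds the contribution from the new block, where $\tilde{P}_{ij} = e^{S_{ij} - m_{ij}}$ is the new block before normalization:
$$O_{\text{new}} = \frac{\alpha\, \ell_{\text{old}}\, O_{\text{old}} + \beta\, \tilde{P}_{ij} V_j}{\ell_{\text{new}}}$$
Let’s expand this to see what’s really happening. The old output was $O_{\text{old}} = \frac{e^{S_{00} - m_{\text{old}}} V_0}{\ell_{\text{old}}}$, so:
$$O_{\text{new}} = \frac{e^{m_{\text{old}} - m_{\text{new}}}\, \ell_{\text{old}} \cdot \frac{e^{S_{00} - m_{\text{old}}} V_0}{\ell_{\text{old}}} + e^{m_{01} - m_{\text{new}}}\, e^{S_{01} - m_{01}}\, V_1}{\ell_{\text{new}}}$$
Simplifying by combining the exponentials:
$$O_{\text{new}} = \frac{e^{S_{00} - m_{\text{new}}} V_0 + e^{S_{01} - m_{\text{new}}} V_1}{\ell_{\text{new}}}$$
And voilà! Since $\ell_{\text{new}} = e^{S_{00} - m_{\text{new}}} + e^{S_{01} - m_{\text{new}}}$ and $m_{\text{new}} = \max(S_{00}, S_{01})$, we get exactly what we’d expect, the proper softmax formula:
$$O_{\text{new}} = \frac{e^{S_{00} - m_{\text{new}}} V_0 + e^{S_{01} - m_{\text{new}}} V_1}{e^{S_{00} - m_{\text{new}}} + e^{S_{01} - m_{\text{new}}}}$$
The beauty of this approach is that we never had to store the full attention matrix. We just kept updating our running statistics ($m$ and $\ell$) and our output incrementally!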
This is the core insight of Flash Attention’s online softmax - we can compute exact attention without materializing the entire attention matrix in memory.
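Here is a pure-Python sketch of the online-softmax update for a single output row, processing the scores in blocks of `Bc` and applying the running-max/running-sum rescaling derived above (α, β, and a division by the new sum at every step). It is meant to check the math, not to be fast; the function names and toy inputs are made up for illustration.

```python
import math


def reference_row(scores_row, V):
    """O_i computed the materialized way: full stable softmax, then P_i @ V."""
    m = max(scores_row)
    p = [math.exp(s - m) for s in scores_row]
    l = sum(p)
    return [sum(pk * v[d] for pk, v in zip(p, V)) / l for d in range(len(V[0]))]


def online_row(scores_row, V, Bc=2):
    """O_i built block by block with a running max m and running sum l."""
    D = len(V[0])
    m_old, l_old = float("-inf"), 0.0
    o = [0.0] * D
    for j0 in range(0, len(scores_row), Bc):
        s_blk, v_blk = scores_row[j0:j0 + Bc], V[j0:j0 + Bc]
        m_blk = max(s_blk)                              # block max m_ij
        p_blk = [math.exp(s - m_blk) for s in s_blk]    # unnormalized block
        l_blk = sum(p_blk)                              # block sum l_ij
        m_new = max(m_old, m_blk)
        alpha = math.exp(m_old - m_new)   # exp(-inf) == 0.0 on the first block
        beta = math.exp(m_blk - m_new)
        l_new = alpha * l_old + beta * l_blk
        pv = [sum(p * v[d] for p, v in zip(p_blk, v_blk)) for d in range(D)]
        # Rescale the old (normalized) output and add the new block's contribution.
        o = [(alpha * l_old * o[d] + beta * pv[d]) / l_new for d in range(D)]
        m_old, l_old = m_new, l_new
    return o


scores = [3.0, -2.0, 10.0, 0.5, 7.0, 1.0, 9.5]
V = [[1.0, 2.0], [0.0, 1.0], [-1.0, 0.5], [2.0, 2.0], [0.3, -0.7], [1.0, 1.0], [4.0, 0.0]]
print(reference_row(scores, V))
print(online_row(scores, V, Bc=3))
```

Only one block of scores is alive at any time, yet the result matches the materialized computation for any block size.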
TLDR, Flash Attention solves the memory problem through two key insights:
- **Tiling**: compute the output in chunks of n rows of the matrix. For a single chunk of n output rows, we need n rows of the query matrix Q and all the rows of K and V. More concretely, the algorithm divides the sequence into blocks of size Bc and processes them iteratively.
- **Online softmax**: keep running statistics (max m and sum l) to compute the correct softmax normalization.
Let’s roll up our sleeves and implement the algorithm!
Here is the Triton implementation of the FA1 algorithm³ straight out of the paper. I tried to follow the algorithm as closely as possible to have a good baseline that we can work from:
We recognize the usual suspects from the math we derived earlier. I will only focus on the forward pass without causal masking to keep the focus on the core algorithm. Let’s start with the simple reference implementation in PyTorch:
import math
import torch.nn.functional as F
# q,k,v are of shape: `(B, N_h, S, D_h)`
# B: batch_size
# N_h: num heads
# S: sequence length
# D_h: head dim
def simple_attn(q, k, v):
    att = q @ k.transpose(-2, -1) * (1.0 / math.sqrt(k.size(-1)))
    att = F.softmax(att, dim=-1)
    y = att @ v
    return y
My first implementation kernels/triton_flash_att.py follows the original Flash Attention paper very closely:
@triton.jit
def attn_kernel(Q, K, V, O, S, D, Tc, Tr, Bc, Br, softmax_scale, l, m):
    # ... [SETUP offsets here]
    # Outer loop over K, V blocks
    for j in range(0, Tc):
        # Load K_j, V_j from HBM to SRAM
        kj = tl.load(k_ptr + offset_j)  # (Bc, D)
        vj = tl.load(v_ptr + offset_j)  # (Bc, D)
        # Inner loop over Q blocks
        for i in range(0, Tr):
            # Load Q_i and previous O_i, l_i, m_i
            qi = tl.load(q_ptr + offset_i)  # (Bc, D)
            prev_oi = tl.load(o_ptr + offset_i)
            prev_li = tl.load(l_ptr + S_i_offset)
            prev_mi = tl.load(m_ptr + S_i_offset)
            # Compute attention scores for this block
            Sij = tl.dot(qi, tl.trans(kj)) * softmax_scale  # (Bc, Bc)
            # Online softmax update
            mij = tl.max(Sij, 1)  # Row-wise max
            pij = tl.exp(Sij - mij[:, None])
            lij = tl.sum(pij, 1)  # Row-wise sum
            # Update running statistics
            mi_new = tl.maximum(prev_mi, mij)
            alpha = tl.exp(prev_mi - mi_new)
            beta = tl.exp(mij - mi_new)
            li_new = prev_li * alpha + lij * beta
            # Update output
            oi_new = (alpha[:, None] * prev_li[:, None] * prev_oi
                      + beta[:, None] * tl.dot(pij, vj)) / li_new[:, None]
            # Write back to HBM
            tl.store(o_ptr + offset_i, oi_new)
            tl.store(m_ptr + S_i_offset, mi_new)
            tl.store(l_ptr + S_i_offset, li_new)
Some notes about this implementation:
- We reload the query block with `qi = tl.load(q_ptr + offset_i) # (Bc, D)` in the inner loop, which is pretty inefficient.
- The kernel is launched on a grid of B x N_h blocks. This means that each block has to find the correct matrix offset in Q, K, V and compute a full (S, D) slice of the output O for its batch and attention head. I immediately thought that we were leaving parallelism on the table here, as each row of O can be computed independently of the others. But don’t worry, I’ll come back to this in our v2.
- The running statistics l and m are allocated in HBM from the host as `l = torch.zeros(B, N_h, S).cuda()` and `m = torch.full((B, N_h, S), float("-inf")).cuda()`. If you recall the memory access discussion earlier, we are effectively loading rows from HBM using S_i_offset, which isn’t great. What we want instead is for each block to allocate these accumulators locally on-chip and avoid allocating them on HBM entirely.
- We use Br = Bc = 32. This means that we are splitting Q, K, and V into same-size chunks. This isn’t the greatest idea, because we might want different values for optimization reasons. It is just another simplification I threw in to keep the implementation straightforward.
Before looking at the profile of this naive v1 implementation, let’s first look at another requirement I brushed over quickly but that is really important: the SRAM limit. Because we live in the real world, we can’t have infinite SRAM on GPUs (sigh!), so we need a ballpark estimate of how much SRAM we need to allocate and whether it fits into my device’s SRAM.
Quick back-of-the-envelope calculation: at minimum, the kernel needs to store the following in SRAM simultaneously:
- K_j: Bc × D floats
- V_j: Bc × D floats
- Q_i: Bc × D floats
- S_ij: Bc × Bc floats
- m_i and l_i: 2 × Bc floats
For B=10; Bc = 32; N_h = 64; S = 64; D_h = 32, Total SRAM: 2 × Bc + 3 × Bc × D + Bc² floats = 2 × 32 + 3 × 32 × 64 + 32² = 7,232 floats = 7,232 × 4 bytes ≈ 28KB. This fits comfortably in the ~48KB of shared memory available per thread block. It does mean, though, that we are limited to ~2 blocks per SM (64 KB of shared memory per SM divided by ~28 KB per block rounds down to 2). If you’re confused about this limit, don’t worry, I’ll explain it once we get into profiling.
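The same estimate as a tiny function, so other tile sizes can be tried. It counts only the five buffers listed above (with D = 64, as in the arithmetic) and ignores any extra scratch space Triton may allocate, so treat it as a lower bound; `fa1_sram_floats` is a hypothetical helper name.

```python
FLOAT_BYTES = 4
SMEM_PER_BLOCK = 48 * 1024  # RTX 2070 per-block limit from the spec table


def fa1_sram_floats(Bc: int, D: int) -> int:
    """Q_i, K_j, V_j tiles (Bc x D each) + S_ij (Bc x Bc) + m_i, l_i (Bc each)."""
    return 3 * Bc * D + Bc * Bc + 2 * Bc


for Bc in (16, 32, 64):
    n = fa1_sram_floats(Bc, D=64)
    size = n * FLOAT_BYTES
    verdict = "fits" if size <= SMEM_PER_BLOCK else "does NOT fit"
    print(f"Bc={Bc:2d}: {n:5d} floats = {size / 1024:5.2f} KiB -> {verdict} in 48 KiB")
```

With D = 64, going to Bc = 64 already overflows the 48 KB per-block limit, while Bc = 16 more than halves the footprint.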
As I mentioned in the Setup and hardware section, I’ll be using the Nsight Compute ncu tool from the start. It’s one of those tools that comes with a LOT of information, and it can be pretty overwhelming. I recommend watching this lecture to get a global overview of all the ncu-ui sections. I also highly recommend reading Modal’s GPU glossary performance section front to back to at least familiarize yourself with the key terms you need to understand GPU program performance, as well as gain a general understanding of how GPUs execute code. If you don’t know what a warp scheduler is, please go read that resource before continuing to the next section.
I ran the ncu profiler using this command to get the full set of timers and metrics:
#!/bin/bash
sudo CUDA_HOME=/opt/cuda /usr/local/NVIDIA-Nsight-Compute/ncu \
--set full \
--kernel-name "attn_kernel" \
--import-source yes \
-o profile_flash_attn_v1 \
-f .venv/bin/python3 kernels/triton_flash_att.py
Once the profiling is done, we can open the report with ncu-ui. Here is an annotated view of the profile:
Great, let’s drill down on the key metrics. First, we can see that our kernel took 166.47ms to execute. Our kernel used a grid size of (10, 64, 1), which matches our Triton grid of (B, N_h). Each block has a size of (128, 1, 1). Because we are using Triton to write our kernel, there is no way to specify the block size directly (although there is a num_warps parameter that we could use). Here, Triton used 4 warps (128 threads) per block to run the kernel.
Next, the profiler actually picked up on a very interesting problem that we identified earlier:
The 2.00 theoretical warps per scheduler this kernel can issue according to its occupancy are below the hardware maximum of 8. This kernel’s theoretical occupancy (25.0%) is limited by the required amount of shared memory.
To understand this occupancy problem a bit better, ncu provides a handy occupancy calculator tool. Let’s look at what the tool says:

Because our shared memory requirement per block is quite high, we can only have 2 active blocks at a time per SM! Low occupancy really hurts performance when there aren’t enough warps to hide the latency. But once occupancy is sufficient for latency hiding, increasing it further can degrade performance. Higher occupancy reduces resources per thread, potentially bottlenecking the kernel on registers or reducing the arithmetic intensity.
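A simplified sketch of the calculation the occupancy tool performs. The defaults are the compute capability 7.5 limits of my RTX 2070 (64 KB shared memory and 32 warps per SM from the spec table, plus a 16-resident-blocks-per-SM cap); registers and any per-block shared memory the driver reserves are ignored, so it is an approximation rather than ncu’s exact calculator. The function name is hypothetical.

```python
def theoretical_occupancy(smem_per_block: int, warps_per_block: int,
                          smem_per_sm: int = 64 * 1024,
                          max_warps_per_sm: int = 32,
                          max_blocks_per_sm: int = 16) -> float:
    """Fraction of the SM's warp slots that can be resident, considering only
    shared memory, warp-slot and block-count limits (registers are ignored)."""
    by_smem = smem_per_sm // smem_per_block
    by_warps = max_warps_per_sm // warps_per_block
    blocks = min(by_smem, by_warps, max_blocks_per_sm)
    return blocks * warps_per_block / max_warps_per_sm


# ~28 KB per block with 4 warps: only 2 blocks fit, i.e. 8 of 32 warp slots.
print(theoretical_occupancy(28928, warps_per_block=4))
# A hypothetical ~12 KB block with 4 warps: 5 blocks fit.
print(theoretical_occupancy(12 * 1024, warps_per_block=4))
```

The first call reproduces the 25% theoretical occupancy ncu reports for v1: shared memory, not warps, is the binding constraint.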
To fix this we need to change our SRAM requirement: SRAM only depends on Bc and D_h. So we can either have smaller chunks (although this could impact memory bandwidth), or have more attention heads to reduce each D (if we control model architecture). Either way, we can go back to tweaking this later. Let’s look at the detail section of the profile report to see what other surprises await us!

The whole idea behind Flash Attention is to avoid going to main memory (HBM) for intermediate steps. The 11.58 GB of reads and 5.54 GB of writes confirm a bottleneck. Quick math: in this naive implementation, we iterate over columns (K, V) in the outer loop and rows (Q, O) in the inner loop. This loop order forces us to reload the accumulator ($O_i$) and the query block ($Q_i$) from HBM for every single chunk of keys. With $T_c$ key/value chunks, we are reading and writing the entire output matrix $T_c$ times, 64 times in this run! Math: size of $O$ = B × N_h × S × D_h × 4 bytes; reads ≈ $T_c$ × (size of $O$ + size of $Q$); writes ≈ $T_c$ × size of $O$. That is why the reads come out at roughly twice the writes.
for j in range(0, Tc):
    # Load K_j, V_j from HBM
    # ....
    for i in range(0, Tr):
        # ....
        # -> Reading previous O_i (Bc,D)
        prev_oi = tl.load(o_ptr + offset_i)
        # -> Writing updated O_i (Bc,D)
        tl.store(o_ptr + offset_i, oi_new)
We are treating HBM like a register, which explains the massive bandwidth usage!
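A rough model of the logical HBM traffic for the v1 loop order, following the quick math above: K and V are read once, while Q_i, O_i, l_i and m_i are read, and O_i, l_i, m_i written, once per K/V chunk. The shapes below are illustrative assumptions, and the DRAM numbers ncu reports can differ because the L2 cache absorbs part of this traffic; `v1_hbm_traffic` is a hypothetical helper.

```python
def v1_hbm_traffic(B, N_h, S, D, Bc, bytes_per_elem=4):
    """Logical HBM traffic (reads, writes) in bytes for the v1 loop order."""
    Tc = -(-S // Bc)                          # number of K/V chunks (ceil division)
    mat = B * N_h * S * D * bytes_per_elem    # one (B, N_h, S, D) tensor
    vec = B * N_h * S * bytes_per_elem        # one (B, N_h, S) statistic, l or m
    reads = 2 * mat + Tc * (2 * mat + 2 * vec)   # K, V once; Q, O, l, m every chunk
    writes = Tc * (mat + 2 * vec)                # O, l, m every chunk
    return reads, writes


for Bc in (16, 32, 64):
    r, w = v1_hbm_traffic(B=10, N_h=64, S=1024, D=32, Bc=Bc)  # illustrative shape
    print(f"Bc={Bc:2d}: reads ~ {r / 1e9:5.2f} GB, writes ~ {w / 1e9:5.2f} GB")
```

Halving Bc doubles T_c and therefore roughly doubles the traffic: it scales with the number of K/V chunks, which is why the fix is to change the loop order rather than the tile size.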
The last thing to check before looking at potential solutions is the source view from the profiler.

ncu marks problematic kernel code lines with a warning emoji ⚠️, which is quite helpful! The div operation is usually costly on CPUs, and it’s basically the same story for GPUs. CUDA implements floating-point division on GPUs via reciprocal + multiply ops: using MUFU (Multi-Function Unit) followed by one or two FFMA or FMUL instructions to refine precision. If we look at the SASS disassembly for this line:
MUFU.RCP R8, R8
...
@P5 FMUL R29, R29, 0.25
@P5 FMUL R15, R15, 0.25
@!P4 FMUL R10, R10, 16777216
@!P6 FMUL R29, R29, 16777216
@!P4 FMUL R23, R23, 16777216
@!P6 FMUL R15, R15, 16777216
FMUL R10, R9, R10
FMUL R29, R8, R29
FMUL R9, R9, R23
FMUL R8, R8, R15
Because we are doing the division inside the hot loop, we are doing a lot more work here, so we can start thinking about how to avoid this division altogether.
oi_new = (
alpha[:, None] * prev_li[:, None] * prev_oi
+ beta[:, None] * tl.dot(pij, vj)
) / li_new[:, None]
Based on the ncu profiling data and our analysis of the memory traffic, we have identified three critical bottlenecks to address in the next iteration (V2).
Invert the Loop Order: The massive 11.58 GB of main memory reads and 5.54 GB of writes is our biggest performance killer. It stems from treating HBM as a temporary register for our output accumulator. The current loop order forces us to repeatedly read/write the accumulator $O_i$ and reload $Q_i$ for every block of $K_j, V_j$. We must invert the loops: parallelize the kernel over queries (rows) so that each thread block loads one tile $Q_i$ once and streams over all the $K_j, V_j$ blocks.
Defer Normalization: Before describing the fix, let’s recall the true attention output: $O_i = \frac{\sum_j e^{S_{ij} - m_i} V_j}{\sum_j e^{S_{ij} - m_i}}$, with $m_i = \max_j S_{ij}$. Inside the kernel, we currently divide by the running sum `li_new` during every iteration of the loop (the `/ li_new[:, None]` in the `oi_new` update). This normalization step does not need to happen each time. We will store the unnormalized accumulated numerator and denominator in registers throughout the loop, and perform the normalization once, at the end of the kernel, just before writing the result to HBM.
Tuning Block Sizes: Theoretical occupancy is capped at 25% due to shared-memory pressure. Smaller B_c may reduce register/SRAM pressure and increase the number of active warps. But too small a tile can fragment memory access and reduce bandwidth efficiency. I’ll leave tweaking this value for later.
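To check that deferring the division is safe, here is a pure-Python sketch of one output row computed both ways: normalizing after every block (as in the v1 kernel’s `oi_new` update) and keeping an unnormalized accumulator that is divided once at the end (the plan for v2). It also counts divisions per row. The names are hypothetical and the block size is arbitrary.

```python
import math


def v1_style_row(scores_row, V, Bc=2):
    """Normalize after every block (v1 update). Returns (output, divisions)."""
    D = len(V[0])
    m_i, l_i, o, divs = float("-inf"), 0.0, [0.0] * D, 0
    for j0 in range(0, len(scores_row), Bc):
        s, v = scores_row[j0:j0 + Bc], V[j0:j0 + Bc]
        m_ij = max(s)
        p = [math.exp(x - m_ij) for x in s]
        m_new = max(m_i, m_ij)
        alpha, beta = math.exp(m_i - m_new), math.exp(m_ij - m_new)
        l_new = alpha * l_i + beta * sum(p)
        pv = [sum(pk * vk[d] for pk, vk in zip(p, v)) for d in range(D)]
        o = [(alpha * l_i * o[d] + beta * pv[d]) / l_new for d in range(D)]
        divs += D                                  # D divisions per block
        m_i, l_i = m_new, l_new
    return o, divs


def deferred_row(scores_row, V, Bc=2):
    """Unnormalized accumulator, one division at the end. Returns (output, divisions)."""
    D = len(V[0])
    m_i, l_i, acc = float("-inf"), 0.0, [0.0] * D
    for j0 in range(0, len(scores_row), Bc):
        s, v = scores_row[j0:j0 + Bc], V[j0:j0 + Bc]
        m_ij = max(s)
        p = [math.exp(x - m_ij) for x in s]
        m_new = max(m_i, m_ij)
        alpha, beta = math.exp(m_i - m_new), math.exp(m_ij - m_new)
        l_i = alpha * l_i + beta * sum(p)
        pv = [sum(pk * vk[d] for pk, vk in zip(p, v)) for d in range(D)]
        acc = [alpha * acc[d] + beta * pv[d] for d in range(D)]  # no division here
        m_i = m_new
    return [a / l_i for a in acc], D                             # D divisions in total


scores = [1.0, 4.0, -2.0, 3.5, 0.0, 6.0]
V = [[1.0, 0.0, 2.0], [0.0, 1.0, 1.0], [2.0, 2.0, 0.0],
     [1.0, -1.0, 0.5], [0.0, 0.0, 1.0], [3.0, 1.0, 1.0]]
print(v1_style_row(scores, V))
print(deferred_row(scores, V))
```

Both versions produce the same row, but the deferred one performs D divisions per row instead of D per block.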
The v2 implementation (kernels/triton_flash_att_v2.py) makes critical changes for better performance:
@triton.jit
def attn_kernel(
    Q, K, V, O,
    S, stride_H, softmax_scale,
    D: tl.constexpr, Tc: tl.constexpr, Bc: tl.constexpr
):
    # ... offset setup
    # Load query block ONCE at the start !!
    qi = tl.load(
        q_ptr + offset_i, mask=mask[:, None], other=0.0
    )  # shape (Bc,D)
    # Block accumulator, running sum and running max live on-chip (registers) !!
    prev_li = tl.zeros([Bc], dtype=tl.float32)
    prev_mi = tl.zeros([Bc], dtype=tl.float32) - float("inf")
    acc = tl.zeros([Bc, D], dtype=tl.float32)
    for j in range(0, Tc):
        # .. setup offset for K and V
        # Load K_j, V_j from HBM to SRAM
        kj = tl.load(k_ptr + offset_j)  # shape(Bc,D)
        vj = tl.load(v_ptr + offset_j)  # shape(Bc,D)
        # Compute Sij on SRAM : Q_i * K_j.T / sqrt(D)
        Sij = tl.dot(qi, tl.trans(kj)) * softmax_scale  # (Bc,Bc)
        # accumulators
        mij = tl.max(Sij, 1)  # Rowmax(Sij): (Bc,)
        pij = tl.exp(Sij - mij[:, None])  # (Bc,Bc)
        lij = tl.sum(pij, 1)  # (Bc,)
        # Running maximum
        mi_new = tl.maximum(prev_mi, mij)
        # Compute scaling factors using previous_max
        alpha = tl.exp(prev_mi - mi_new)
        beta = tl.exp(mij - mi_new)
        # Update running sum
        li_new = prev_li * alpha + lij * beta
        # Update the output block
        acc = alpha[:, None] * acc + beta[:, None] * tl.dot(pij, vj)
        prev_li = li_new
        prev_mi = mi_new
    # Divide by the last accumulated sum !
    acc = acc / prev_li[:, None]
    # Update in HBM
    tl.store(o_ptr + offset_i, acc)
I also adjusted the grid configuration:
grid = lambda META: (triton.cdiv(S, META["Bc"]), B * N_h)
Where Dimension 0: Number of Q blocks = S / Bc and Dimension 1: Batch × Heads = B × N_h. To reiterate the main changes compared to v1:
- More parallelism: with B=10, N_h=64, S=1024, D_h=32 and Bc=32, we launch 32 × 640 = 20,480 independent thread blocks.
- On-chip accumulators: acc stays in fast registers, with no main memory writes until the end.
Great, let’s profile our newly crafted kernel and see what ncu tells us!

Well, it is faster than v1, but only 6% faster. Let’s first look at the memory chart, which showed the biggest problem in v1:

Great! Inverting the loops reduced main memory reads to 412.18 MB (-92.98%). We are also writing 80MB, corresponding exactly to the size of the output matrix O. But there is still a clear problem somewhere: we would expect a lot more speedup, since we are now parallelizing 32x more. Let’s look at the occupancy; maybe it is somehow worse than v1 (sanity check here)?

We are at 63% occupancy and are limited by shared memory with ~12KB per block. So that’s clearly not the culprit. Let’s go back to the summary page of the profile; ncu picks up on these issues:
- Shared Load Bank Conflicts Est. Speedup: 63.57%: The memory access pattern for shared loads might not be optimal and causes on average a 6.3 - way bank conflict across all 293601280 shared load requests. This results in 1174579308 bank conflicts, which represent 63.64% of the overall 1845667948 wavefronts for shared loads. Check the Source Counters section for uncoalesced shared loads.
- Uncoalesced Shared Accesses Est. Speedup: 61.46%: This kernel has uncoalesced shared accesses resulting in a total of 1174405120 excessive wavefronts (62% of the total 1909063680 wavefronts). Check the L1 Wavefronts Shared Excessive table for the primary source locations.
- MIO Throttle Stalls Est. Speedup: 50.06%: On average, each warp of this workload spends 18.4 cycles being stalled waiting for the MIO (memory input/output) instruction queue to be not full. This stall reason is high in cases of extreme utilization of the MIO pipelines, which include special math instructions, dynamic branches, as well as shared memory instructions. When caused by shared memory accesses, trying to use fewer but wider loads can reduce pipeline pressure. This stall type represents about 50.7% of the total average of 36.2 cycles between issuing two instructions.
It seems to be a problem with how we access shared memory in our kernel. A lot of terms are thrown around here, so let’s first understand what bank conflicts, uncoalesced shared access, and L1 wavefronts shared load mean.
The profiler is screaming at us about bank conflicts and wavefronts. Before we can fix anything, we need to understand what’s actually happening at the hardware level. Let’s build intuition from the ground up.
Remember from the memory hierarchy section that shared memory (SRAM) lives on-chip, right next to the compute units. It’s blazingly fast (~10-20 TB/s) but there’s a catch: shared memory isn’t a single monolithic block. It’s physically divided into 32 memory banks that can be accessed simultaneously (remember how each warp has 32 threads).
Think of these banks like checkout lanes at a grocery store. If 32 customers each go to a different lane, everyone gets served in one “cycle.” But if multiple customers try to use the same lane, they have to wait in line - that’s a bank conflict.
The mapping from memory address to bank is straightforward. Each bank is 4 bytes wide, so bank = (byte_address / 4) mod 32.
For a float32 array in shared memory, consecutive elements land in consecutive banks:
data[0] → byte 0 → bank 0
data[1] → byte 4 → bank 1
...
data[31] → byte 124 → bank 31
data[32] → byte 128 → bank 0 ← wraps around!
This is intentional! When threads in a warp access consecutive array elements (a very common pattern), each thread hits a different bank. Perfect parallelism!
Now here’s where things get spicy. Let’s find which line in our kernel triggers these conflicts. Thankfully, ncu again shows us the exact problematic line:

So it turns out that this line is the culprit:
Sij = tl.dot(qi, tl.trans(kj)) * softmax_scale
But why, though? To understand why we are hitting shared memory bank conflicts, we’ll need to go a little bit deeper. How deep, you ask? Well, to the end of the world… But first, let’s work through the remaining issues ncu reported.
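The profiler keeps mentioning “wavefronts”. What are those? A wavefront (sometimes loosely called a memory transaction) is the unit of work the shared memory pipeline can process in a single cycle. In the ideal case, 32 threads request shared memory, all hit different banks → 1 wavefront → 1 cycle. When k threads hit different addresses in the same bank, the hardware has to split the request into k wavefronts (a k-way bank conflict). Threads reading the exact same address are served by a broadcast instead.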
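To make the bank mapping and the wavefront count concrete, here is a small pure-Python model. It is a simplification that assumes 32 banks of 4 bytes and one 32-bit word per lane. The number of wavefronts is set by the busiest bank, counting only distinct words, because lanes reading the same word share a broadcast. Real hardware has extra rules for 64- and 128-bit accesses. The names bank_of and wavefronts are only illustrative.

```python
from collections import defaultdict

NUM_BANKS = 32        # shared memory banks
BANK_WIDTH_BYTES = 4  # each bank serves one 32-bit word per cycle


def bank_of(byte_addr: int) -> int:
    """Bank that a shared memory byte address maps to."""
    return (byte_addr // BANK_WIDTH_BYTES) % NUM_BANKS


def wavefronts(byte_addrs) -> int:
    """Simplified wavefront count for one warp-wide 32-bit shared load.

    Only distinct words per bank cost extra passes: lanes reading the
    exact same word are served together by a broadcast.
    """
    words_per_bank = defaultdict(set)
    for addr in byte_addrs:
        words_per_bank[bank_of(addr)].add(addr // BANK_WIDTH_BYTES)
    return max(len(words) for words in words_per_bank.values())


lanes = range(32)
print(wavefronts([4 * i for i in lanes]))       # consecutive floats: 1 wavefront
print(wavefronts([0] * 32))                     # same word everywhere: broadcast, 1 wavefront
print(wavefronts([8 * i for i in lanes]))       # stride of 2 floats: 2-way conflict
print(wavefronts([4 * 32 * i for i in lanes]))  # stride of 32 floats: 32-way conflict
```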
Now the profiler numbers start making sense:
| Metric | Value | Meaning |
|---|---|---|
| Shared load requests | 293,601,280 | Times our kernel asked for shared memory |
| Bank conflicts | 1,174,579,308 | Extra transactions due to serialization |
| Total wavefronts | 1,845,667,948 | Actual memory operations executed |
| Conflict rate | 63.64% | Almost 2/3 of bandwidth wasted! |
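The 6.3-way average conflict means that each shared load request is split into about 6.3 wavefronts on average (1,845,667,948 / 293,601,280 ≈ 6.3). Roughly speaking, 6.3 accesses end up serialized on the same bank.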
Put differently, almost two out of every three shared memory wavefronts are spent on bank conflicts. No wonder we only got a 6% speedup despite reducing main memory traffic by 93%!
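The headline numbers are easy to re-derive from the report. The values below are copied from the ncu output above. The “6.3-way” figure appears to be simply total wavefronts divided by load requests.

```python
# Values copied from the ncu report above
shared_load_requests = 293_601_280
bank_conflicts = 1_174_579_308
total_wavefronts = 1_845_667_948

wavefronts_per_request = total_wavefronts / shared_load_requests  # ~6.3, the "6.3-way" figure
conflict_share = bank_conflicts / total_wavefronts                # ~63.64%
conflict_free = total_wavefronts - bank_conflicts                 # wavefronts without a conflict

print(f"{wavefronts_per_request:.1f} wavefronts per request")
print(f"{conflict_share:.2%} of wavefronts come from bank conflicts")
print(f"{conflict_free / shared_load_requests:.2f} conflict-free wavefronts per request")
```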
The profiler also complains about **uncoalesced shared accesses** with **62% excessive wavefronts**. This is essentially the same problem viewed from a different angle: threads access scattered addresses. Our column reads are stride-D accesses, which the hardware cannot combine efficiently. The “excessive wavefronts” metric counts how many extra wavefronts we issue beyond the theoretical minimum.
The third warning, MIO Throttle Stalls (50.06%), is a consequence of the first two. MIO (Memory Input/Output) is the pipeline that handles:
- shared memory instructions
- special math instructions (exp, log) (remember this for later)
- dynamic branches
When almost two-thirds of our shared memory wavefronts are spent on bank conflicts, the MIO pipeline gets clogged. Warps have to stall waiting for its instruction queue to drain, and that is what the ~50% stall figure is measuring.
OK, let’s go back and drill down on this single line:
Sij = tl.dot(qi, tl.trans(kj)) * softmax_scale
Staring at it won’t tell us much, so let’s see what the Triton compiler generated for it. Triton compiles kernels down to PTX, and the .loc directives in the PTX make it easy to map the generated code back to lines of the Python source:
TRITON_CACHE_DIR="./triton_dump" python kernels/triton_flash_att_v2.py
awk '/\.loc.* 83 /{p=1; print; next} /\.loc/{p=0} p' triton_dump/[ID]/attn_kernel.ptx > attn_kernel_v2.ptx
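If you’d rather not use awk, the same filtering can be sketched in Python. ptx_for_line is a hypothetical helper, not part of Triton. It turns collection on whenever a .loc directive points at the requested source line and off at any other .loc. That is similar to the awk one-liner, except it matches only the line-number field.

```python
import re

# A PTX source-location directive looks like: .loc <file_id> <line> <column>
LOC = re.compile(r"^\s*\.loc\s+\d+\s+(\d+)\s+\d+")


def ptx_for_line(ptx_text: str, line_no: int) -> list:
    """Return the PTX lines emitted for one line of the Python source."""
    keep, out = False, []
    for ptx_line in ptx_text.splitlines():
        match = LOC.match(ptx_line)
        if match:  # every .loc decides whether the following lines belong to line_no
            keep = int(match.group(1)) == line_no
        if keep:
            out.append(ptx_line.strip())
    return out


sample = """\
.loc 1 82 10
mov.u32 %r1, %tid.x;
.loc 1 83 34
bar.sync 0;
.loc 1 83 25
fma.rn.f32 %r516, %r132, %r388, 0f00000000;
.loc 1 84 5
mul.f32 %r1, %r2, %r3;
"""
print("\n".join(ptx_for_line(sample, 83)))
```

In practice you would read the .ptx file from the Triton cache directory and pass its contents in instead of the inline sample.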
Great! Let’s look at the assembly for the line triton_flash_att_v2.py:83:
.loc 1 83 34 // triton_flash_att_v2.py:83:34
bar.sync 0;
// We are storing stuff in shared memory `st.`
st.shared.v4.b32 [%r4], {%r100, %r101, %r102, %r103};
st.shared.v4.b32 [%r4+2048], {%r104, %r105, %r106, %r107};
st.shared.v4.b32 [%r4+4096], {%r108, %r109, %r110, %r111};
st.shared.v4.b32 [%r4+6144], {%r112, %r113, %r114, %r115};
// Here we are loading something from shared memory `ld.`
.loc 1 83 34 // triton_flash_att_v2.py:83:34
ld.shared.v4.b32 {%r388, %r389, %r390, %r391}, [%r8];
ld.shared.v4.b32 {%r392, %r393, %r394, %r395}, [%r8+256];
ld.shared.v4.b32 {%r396, %r397, %r398, %r399}, [%r8+16];
ld.shared.v4.b32 {%r400, %r401, %r402, %r403}, [%r8+272];
ld.shared.v4.b32 {%r404, %r405, %r406, %r407}, [%r8+32];
// ... A lot more ld.shared.v4.b32 instructions
// Then finally we do the actual matmul `tl.dot`
.loc 1 83 25 // triton_flash_att_v2.py:83:25
fma.rn.f32 %r516, %r132, %r388, 0f00000000;
fma.rn.f32 %r517, %r133, %r389, %r516;
fma.rn.f32 %r518, %r134, %r390, %r517;
// And this mul is the `* softmax_scale`; we can ignore it
.loc 1 83 41 // triton_flash_att_v2.py:83:41
mul.f32 %r1028, %r25, %r579;
mul.f32 %r1029, %r25, %r643;
mul.f32 %r1030, %r25, %r707;
mul.f32 %r1031, %r25, %r771;
mul.f32 %r1032, %r25, %r835;
The store part of the PTX is pretty straightforward. We are storing K from registers to shared memory, with each of the 128 threads in the block (4 warps) executing these instructions in parallel. st.shared.v4.b32 is a vectorized instruction, meaning each thread stores 4 × 4 bytes = 16 bytes in one shot. Remember that K_j has size (Bc, D), so each chunk is 32 × 64 floats × 4 bytes = 8192 bytes. Each thread issues 4 stores of 16 bytes (64 bytes in total), and with 128 threads this amounts to exactly 8192 bytes. Important detail: we are storing K in row-major order in shared memory, with a row stride of 64 × 4 = 256 bytes. Here is how K is stored in shared memory (with a base offset of 0 to simplify):
Now, the next set of instructions is very interesting. ld.shared.v4.b32 is a vectorized load of 4 × 4 bytes (16 bytes) from shared memory. Again, these instructions are executed in parallel by all threads in the warp. To understand where the conflict happens, we need to look at which addresses each thread loads from.
ld.shared.v4.b32 {%r388, %r389, %r390, %r391}, [%r8];
ld.shared.v4.b32 {%r392, %r393, %r394, %r395}, [%r8+256];
ld.shared.v4.b32 {%r396, %r397, %r398, %r399}, [%r8+16];
ld.shared.v4.b32 {%r400, %r401, %r402, %r403}, [%r8+272];
ld.shared.v4.b32 {%r404, %r405, %r406, %r407}, [%r8+32];
Looking at the pattern, we are doing multiple loads at fixed offsets from a base register %r8. To know which memory addresses we are effectively loading from, we need to figure out what this register holds for each thread. A quick grep through the PTX and we find this:
mov.u32 %r1, %tid.x; // Thread ID (0-127)
and.b32 %r77, %r1, 15; // tid & 0xF (0-15)
shl.b32 %r93, %r77, 9; // shift left by 9
add.s32 %r8, %r91, %r93; // %r8 = smem_base + lane_group * 512
Nice! I thought it would be way more tedious. But let’s break down what these instructions are doing. Thankfully, we are in 2025 and LLMs exist to explain PTX; one prompt to Gemini later:
- %r1 = %tid.x: set register %r1 to the thread ID within the block (0 to 127)
- %r77 = tid & 15: extract the low 4 bits (values 0-15, repeating for lanes 16-31)
- %r93 = %r77 << 9: multiply by 512 (= 2^9)
- %r8 = smem_base + %r93 = smem_base + lane_group × 512: the final shared memory address
This means that for each thread in the warp we have:
Key insight: because of the tid & 15 mask, only 16 unique base addresses exist. Lanes 16-31 duplicate lanes 0-15.
So within a single warp of 32 threads, threads 0-15 get unique base addresses (0×512, 1×512, … 15×512) and threads 16-31 get duplicate addresses (the same as lanes 0-15)! But bear in mind, we are not doing 2x the loading. When two lanes access the same address, the hardware broadcasts: it fetches once and delivers to both. No conflict here. Let’s dig deeper to see where the conflict occurs.
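The address arithmetic is easy to replay in Python. This sketch mirrors the and.b32 / shl.b32 pair from the PTX. It assumes smem_base (%r91) is 0 and that rows hold 64 floats (256 bytes). base_offset is a hypothetical name.

```python
ROW_BYTES = 64 * 4  # one row of K_j: D = 64 float32 values


def base_offset(tid: int) -> int:
    """Mirror of `and.b32 %r77, %r1, 15` followed by `shl.b32 %r93, %r77, 9`."""
    return (tid & 15) << 9


bases = [base_offset(lane) for lane in range(32)]  # lanes of the first warp
print(len(set(bases)))                             # number of distinct base addresses
print(bases[:16] == bases[16:])                    # lanes 16-31 repeat lanes 0-15
print([b // ROW_BYTES for b in bases[:16]])        # K row each lane starts at
```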
Now that we have the base pointer address, let’s see what each thread loads. The PTX shows that we executed 32 loads per thread to get data for the matrix multiply:
ld.shared.v4.b32 [%r8]; // 4 floats at base + 0
ld.shared.v4.b32 [%r8+256]; // 4 floats at base + 256 (next row)
ld.shared.v4.b32 [%r8+16]; // 4 floats at base + 16 (next 4 cols)
ld.shared.v4.b32 [%r8+272]; // 4 floats at base + 272 (next row, next 4 cols)
ld.shared.v4.b32 [%r8+32]; // 4 floats at base + 32
ld.shared.v4.b32 [%r8+288]; // 4 floats at base + 288
ld.shared.v4.b32 [%r8+48]; // ...
ld.shared.v4.b32 [%r8+304];
ld.shared.v4.b32 [%r8+64];
ld.shared.v4.b32 [%r8+320];
ld.shared.v4.b32 [%r8+80];
ld.shared.v4.b32 [%r8+336];
ld.shared.v4.b32 [%r8+96];
ld.shared.v4.b32 [%r8+352];
ld.shared.v4.b32 [%r8+112];
ld.shared.v4.b32 [%r8+368];
... (32 total loads)
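The offsets follow an interleaved pattern. The even-numbered loads walk along one row of K (base + 0, 16, 32, …, 240), and the odd-numbered loads walk along the next row (base + 256, 272, …, 496). Each lane’s base is (tid & 15) × 512 bytes, which is two 256-byte rows. So each of the 16 unique lanes reads two full rows of K_j: lane t reads rows 2t and 2t + 1.
Let’s draw over our previous shared memory diagram to see which addresses of K are loaded by which threads and to clearly visualize this interleaved access pattern: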
And there is the bank conflict right there!
Result: the 16 unique lanes all hit banks 0, 1, 2, 3, each at a different address. The hardware serializes these into 16 wavefronts. The theoretical efficiency of this load is therefore 1/16 = 6.25%, and 15/16 = 93.75% of the bandwidth is wasted!
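We can replay all 32 loads with the same kind of model. This is a simplified sketch that assumes smem_base = 0 and 256-byte rows. It also assumes a v4 load costs one pass per distinct 32-bit word in the busiest bank, with broadcasts free. It ignores how the hardware actually splits 128-bit requests into phases, but it reproduces the 16-way figure. OFFSETS and v4_wavefronts are hypothetical names.

```python
from collections import defaultdict

NUM_BANKS, BANK_WIDTH = 32, 4
ROW_BYTES = 256  # 64 floats per row of K_j
# Offsets from the PTX: even loads walk row 2t, odd loads walk row 2t + 1
OFFSETS = [row * ROW_BYTES + 16 * i for i in range(16) for row in (0, 1)]


def v4_wavefronts(byte_addrs) -> int:
    """Simplified model: distinct 32-bit words in the busiest bank (broadcasts are free)."""
    words_per_bank = defaultdict(set)
    for addr in byte_addrs:
        for w in range(4):  # a v4 load touches 4 consecutive 32-bit words
            word = addr // BANK_WIDTH + w
            words_per_bank[word % NUM_BANKS].add(word)
    return max(len(words) for words in words_per_bank.values())


per_load = []
for off in OFFSETS:
    addrs = [((lane & 15) << 9) + off for lane in range(32)]  # %r8 + offset for each lane
    per_load.append(v4_wavefronts(addrs))

print(len(OFFSETS), set(per_load))  # every one of the 32 loads is a 16-way conflict
```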
One thing we could try is setting D = D + 1 and seeing what happens. In theory, the off-by-one stride would spread the accesses across the banks. This is the idea behind padding: adding a dummy column to break the stride alignment.
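Here is the textbook version of the padding trick, in the same simplified model. Each of 32 lanes reads one element [lane][col] of a row-major float32 tile with 32-bit loads. This is not the exact v4 pattern of our kernel. It only illustrates why a stride of D + 1 helps. column_read_wavefronts is a hypothetical helper.

```python
from collections import defaultdict

NUM_BANKS = 32  # one float32 fills exactly one 4-byte bank word


def column_read_wavefronts(row_stride_floats: int, col: int = 0) -> int:
    """Wavefronts for 32 lanes reading tile[lane][col] from a row-major float32 tile."""
    words_per_bank = defaultdict(set)
    for lane in range(32):
        word = lane * row_stride_floats + col  # element index == 32-bit word index
        words_per_bank[word % NUM_BANKS].add(word)
    return max(len(words) for words in words_per_bank.values())


print(column_read_wavefronts(64))  # D = 64: every lane lands in the same bank
print(column_read_wavefronts(65))  # D + 1 = 65: consecutive lanes land in consecutive banks
```

One caveat: a 65-float row (260 bytes) also breaks the 16-byte alignment that vectorized loads like ld.shared.v4.b32 need.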
Unfortunately, Triton block shapes (for example the length of a tl.arange) must be powers of 2, so we can’t set D to 65. Instead, we end up doubling the dimension. The padding does break the bank conflicts, but it ends up doing a lot more loads and is way slower than the unpadded v2 kernel.
The implementation in kernels/triton_flash_att_v2_padded.py is very close to v2, with the addition of a function that computes the padded head_dim:
def compute_padded_headdim(D_h):
    """
    Compute padded head dimension to avoid bank conflicts.
    The doubling provides extra padding that naturally breaks stride-based conflicts.
    With stride=64 instead of 32, threads access different bank groups.
    """
    if D_h <= 0:
        return 1
    # Check if already power of 2
    if (D_h & (D_h - 1)) == 0:
        # Already power of 2, double it
        return D_h * 2
    else:
        # Round up to next power of 2
        return 1 << (D_h - 1).bit_length()
We can see that the bank conflicts went down to a 3.4-way bank conflict across all 21,135,360 shared store requests. So I guess this worked 😂
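Back to the drawing board. Now that we understand the problem, we have several options, including:
- Transposing K: lay K out as (D, S) in global memory before launching the kernel, so the kernel loads a (D, Bc) tile directly and tl.dot no longer needs tl.trans.
- Padding: allocate (Bc, D+1) in shared memory but only use D columns. This breaks the stride alignment -> works, but the additional work makes it slower.
In the next section, I’ll implement the transpose option and see how much performance we can claw back.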
The v2 transpose kernel kernels/triton_flash_att_v2_transpose.py seems a little bit like cheating—we transpose the matrix beforehand. A small but important detail is to make sure that the transpose isn’t just a view and force it to be contiguous in memory:
k_trans = k.transpose(-1, -2).contiguous() # IMPORTANT!
attn_kernel[grid](
q, k_trans, v, o,
S,
q.stride(1), stride_k_d, stride_k_s,
1 / math.sqrt(D_h),
D_h, Tc, Bc,
)
Now the only thing that changes is making sure we load the correct columns in the kernel. Note that I hoisted the softmax scaling out of the inner loop to save mul ops:
@triton.jit
def attn_kernel(
    Q, K, V, O,
    S,
    stride_H,
    stride_k_d,
    stride_k_s,
    softmax_scale,
    D: tl.constexpr,
    Tc: tl.constexpr,
    Bc: tl.constexpr,
):
    # ...
    # Pre-scale Q to save muls inside loop
    qi = qi * softmax_scale
    for j in range(0, Tc):
        # Pointer Math: (Row_Idx * Stride_Row) + (Col_Idx * Stride_Col)
        # Row_Idx is offs_d (0..D)
        # Col_Idx is current_cols (S dimension)
        offset_j_k = (offs_d[:, None] * stride_k_d) + (current_cols[None, :] * stride_k_s)
        # Load K (D, Bc) directly!
        kj = tl.load(k_ptr + offset_j_k)
        # ...
        Sij = tl.dot(qi, kj)  # no transpose needed for kj !
        # ...
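Hoisting the scale is safe because (scale · Q_i)K_jᵀ = scale · (Q_i K_jᵀ). The pure-Python sketch below checks this on random tiles and counts the multiplications. It assumes Br = Bc = 32 and D = 64 as in this post, plus an illustrative S = 1024. The sequence length and the helper name matmul_t are assumptions, not values from the actual launch.

```python
import math
import random

random.seed(0)
Br, Bc, D = 32, 32, 64  # tile sizes used in this post
S = 1024                # illustrative sequence length (assumption)
scale = 1 / math.sqrt(D)


def matmul_t(a, b):
    """Compute a @ b.T where a and b are lists of rows."""
    return [[sum(x * y for x, y in zip(row_a, row_b)) for row_b in b] for row_a in a]


qi = [[random.gauss(0.0, 1.0) for _ in range(D)] for _ in range(Br)]
kj = [[random.gauss(0.0, 1.0) for _ in range(D)] for _ in range(Bc)]

# Scale every score after the dot product (old) vs. scale the Q tile once (hoisted)
post = [[s * scale for s in row] for row in matmul_t(qi, kj)]
pre = matmul_t([[q * scale for q in row] for row in qi], kj)
max_abs_diff = max(abs(a - b) for ra, rb in zip(post, pre) for a, b in zip(ra, rb))

Tc = S // Bc
muls_in_loop = Tc * Br * Bc  # one mul per score, in every iteration
muls_hoisted = Br * D        # one mul per element of the Q tile, once
print(max_abs_diff, muls_in_loop, muls_hoisted)
```

The two results only differ by floating-point rounding, while the hoisted version does its scaling once per program instead of once per score.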
Let’s profile it and look at the result in comparison to the previous two kernels:

Another key metric to track (we haven’t until now because we needed to fix the obvious stuff first) is the GPU speed of light throughput. This is also known as roofline analysis.

Our v2 transpose kernel has way higher arithmetic intensity thanks to removing shared memory access conflicts. The warp scheduler statistics corroborate this: we see 153% more eligible warps per cycle, since warps are no longer stalled by serialized shared memory accesses.
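For reference, the roofline model boils down to one line. A kernel with arithmetic intensity I (FLOPs per byte moved) can reach at most min(peak compute, I × memory bandwidth). The sketch below uses placeholder hardware numbers, not the RTX 2070’s actual specs. attainable_flops is a hypothetical name.

```python
def attainable_flops(arith_intensity: float, peak_flops: float, bandwidth: float) -> float:
    """Roofline model: the lower of the compute roof and the memory roof.

    arith_intensity: FLOPs performed per byte moved at the modeled memory level
    peak_flops:      peak compute throughput in FLOP/s
    bandwidth:       bandwidth of that memory level in bytes/s
    """
    return min(peak_flops, arith_intensity * bandwidth)


# Placeholder numbers, NOT the specs of any particular GPU
PEAK_FLOPS = 8e12  # FLOP/s
BANDWIDTH = 450e9  # bytes/s
ridge_point = PEAK_FLOPS / BANDWIDTH  # FLOP/byte where a kernel stops being memory bound

for ai in (1.0, 10.0, ridge_point, 100.0):
    tflops = attainable_flops(ai, PEAK_FLOPS, BANDWIDTH) / 1e12
    print(f"AI = {ai:6.1f} FLOP/byte -> at most {tflops:.2f} TFLOP/s")
```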
Great! The kernel ran in 34 ms, a 145% improvement over v1! Looking at the ncu performance opportunities, we can see that the next big blocker is:
Mio Throttle Stalls Est. Speedup: 43.97%: On average, each warp of this workload spends 6.7 cycles being stalled waiting for the MIO (memory input/output) instruction queue to be not full. This stall reason is high in cases of extreme utilization of the MIO pipelines, which include special math instructions, dynamic branches, as well as shared memory instructions. When caused by shared memory accesses, trying to use fewer but wider loads can reduce pipeline pressure. This stall type represents about 44.0% of the total average of 15.3 cycles between issuing two instructions.
So we fixed bank conflicts and got a nice 145% speedup. What’s left? The MIO (Memory Input/Output) pipeline is now our bottleneck. Despite the name, this isn’t about main memory - MIO handles two things:
- shared memory accesses: loading the qi, kj, vj tiles
- special math and softmax bookkeeping: tl.exp, tl.max, tl.log
Every iteration of our inner loop calls tl.exp for the softmax and tl.max for numerical stability. The exponentials (and logs) run on the SFU (Special Function Unit), which has much lower throughput than the main FMA units. With Bc = 32, we’re doing these expensive operations very frequently relative to the actual matmul work.
I tried a few things to reduce MIO pressure:
FP16: Wrote an FP16 kernel hoping for a speedup by lowering shared memory requirements and increasing D_h and Bc. Ended up slower than FP32 because I kept fighting the compiler—it kept inserting extra shared memory accesses for type conversions, and I didn’t get any SRAM savings because FP16 precision sucks for accumulators and I needed to keep all the on-chip accumulators in FP32 anyway.
Larger block sizes: increasing Bc means fewer loop iterations, and therefore fewer of the per-iteration max reductions and rescaling exps per output row. The exps over the scores themselves stay the same. But my RTX 2070 doesn’t have enough shared memory to go much higher without killing occupancy.
Anyway, the main takeaway here is to reduce MIO stalls by issuing fewer but wider loads and by cutting down on special function math instructions.
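To see what a larger Bc actually buys, here is a simplified count for one query row. It assumes S is divisible by Bc, and that each iteration does one rescaling exp and one max reduction per row, as in the online softmax. The real kernel may differ. S = 4096 is an illustrative value, and softmax_ops_per_row is a hypothetical name.

```python
def softmax_ops_per_row(S: int, Bc: int) -> dict:
    """Simplified per-query-row op count for an online-softmax loop over Tc = S / Bc blocks.

    Per iteration: Bc exps for the scores, 1 exp to rescale the running state
    (exp(m_old - m_new)), and 1 row-wise max reduction over the Bc scores.
    """
    Tc = S // Bc
    return {"score_exps": Tc * Bc, "rescale_exps": Tc, "max_reductions": Tc}


for Bc in (16, 32, 64, 128):
    print(Bc, softmax_ops_per_row(4096, Bc))
```

The S score exponentials are fixed by the problem size. Only the per-iteration terms shrink as Bc grows.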
Looking at the instruction stats, something else jumped out:

The kernel is dominated by FFMA (Fused Floating-point Multiply-Add) instructions. These are regular CUDA core operations, not tensor core operations. Each tensor core can do a 4x4x4 matrix multiply-accumulate (64 FMAs) per clock, and tensor cores are the whole reason modern GPUs are so fast at deep learning. But my tl.dot calls aren’t using them for some reason!
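I fought with this for a while. I tried different memory layouts, alignment hints, and different block sizes… Nothing worked. It turns out that on SM 7.5 (Turing architecture), Triton struggles to generate tensor core code. One factor that very likely matters is the data type. This kernel computes in FP32, and Turing’s tensor cores only accept reduced-precision inputs such as FP16 (TF32 support arrived with Ampere, SM 8.0). The compiler just falls back to regular FMA instructions that run on regular old CUDA cores, clogging up the pipe.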
I guess I need to stop being poor and buy a newer GPU. More realistically, I’ll rent an H100 and see if the same code magically starts using tensor cores on SM 9.0…
At this point I’ve hit the limits of what I can optimize on my hardware.
Let’s open up the actual Flash Attention v2 paper [4] and see what we got right and what we missed:
1. Reducing Non-Matmul FLOPs ✅ (partially)
We did defer the final division to the end of the kernel. But FA2 goes further: it restructures the online softmax to minimize non-matmul operations such as exp and max. GPUs have separate units for matmul (Tensor Cores, 312 TFLOPS of FP16/BF16 on an A100) and generic math (CUDA cores, 19.5 TFLOPS of FP32 on an A100). That makes every non-matmul FLOP roughly 16x more expensive than a matmul FLOP, so minimizing them matters a lot.
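Here is a minimal pure-Python sketch of that deferred normalization for a single query row. The function names and block size are illustrative. Inside the loop, the accumulator stays unnormalized and is rescaled by exp(m_old - m_new) whenever the running max changes. The division by the softmax denominator happens once, after the last block. A plain softmax-then-weighted-sum reference is included for comparison.

```python
import math


def attention_row_online(scores, values, block=4):
    """One query row of softmax(scores) @ values with a single deferred division."""
    m = -math.inf                 # running max
    l = 0.0                       # running (unnormalized) softmax denominator
    acc = [0.0] * len(values[0])  # running unnormalized output
    for start in range(0, len(scores), block):
        s_blk = scores[start:start + block]
        v_blk = values[start:start + block]
        m_new = max(m, max(s_blk))
        alpha = math.exp(m - m_new)  # rescales the previous state (0.0 on the first block)
        p = [math.exp(s - m_new) for s in s_blk]
        l = alpha * l + sum(p)
        acc = [alpha * a + sum(pj * v[d] for pj, v in zip(p, v_blk)) for d, a in enumerate(acc)]
        m = m_new
    return [a / l for a in acc]  # the only division, after the last block


def attention_row_reference(scores, values):
    """Plain softmax followed by a weighted sum of the value rows."""
    m = max(scores)
    w = [math.exp(s - m) for s in scores]
    z = sum(w)
    return [sum(wi * v[d] for wi, v in zip(w, values)) / z for d in range(len(values[0]))]


scores = [0.5, -1.2, 3.0, 0.1, 2.2, -0.7, 1.5, 0.0, -2.0, 0.9]
values = [[float(i), float(i % 3), 1.0] for i in range(len(scores))]
online = attention_row_online(scores, values, block=4)
reference = attention_row_reference(scores, values)
print(online)
print(reference)
```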
2. Parallelism Over Sequence Length ✅
We nailed this one. Our v2 kernel parallelizes over (S/Bc, B*N_h) instead of just (B, N_h). This is exactly what FA2 does - split the query sequence into chunks and assign them to different thread blocks.
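Here is a sketch of what that split means for the launch. The helper names are hypothetical, and Br is the query block size, equal to Bc in our kernels. Each program owns one block of query rows for one (batch, head) pair, so even a batch with few heads yields many programs.

```python
import math


def launch_grid(B: int, N_h: int, S: int, Br: int) -> tuple:
    """One program per (query block, batch * head) pair."""
    return (math.ceil(S / Br), B * N_h)


def program_work(pid_q: int, pid_bh: int, N_h: int, S: int, Br: int) -> tuple:
    """(batch, head, query rows) owned by the program with ids (pid_q, pid_bh)."""
    b, h = divmod(pid_bh, N_h)
    return b, h, range(pid_q * Br, min((pid_q + 1) * Br, S))


grid = launch_grid(B=2, N_h=4, S=1024, Br=32)
print(grid, grid[0] * grid[1])  # programs launched vs. B * N_h = 8 without the split
print(program_work(3, 5, N_h=4, S=1024, Br=32))
```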
3. Warp-Level Work Partitioning ⛔
This is where Triton abstracts too much away. FA2 carefully controls how warps within a block divide work: instead of “Split-K” (warps split the K/V dimension and sync to combine results), they use “Split-Q” (warps split the Q dimension and work independently). This removes expensive synchronization barriers.
In Triton, we don’t control warp-level scheduling - the compiler decides. We could inspect the generated PTX to see what it’s doing, but we can’t easily change it.
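Here is a toy, pure-Python illustration of the difference. Warps are just function calls here, and partial_state and merge are hypothetical names. Under Split-K, two warps each see half of the keys for the same query row. Their partial softmax states must then be merged with a rescaling step, which in a real kernel means going through shared memory and a barrier. Under Split-Q, one warp sees all keys for its own row, so no merge is needed.

```python
import math


def partial_state(scores, values):
    """Softmax state for one query row over a subset of keys: (max, sum, unnormalized output)."""
    m = max(scores)
    p = [math.exp(s - m) for s in scores]
    acc = [sum(pi * v[d] for pi, v in zip(p, values)) for d in range(len(values[0]))]
    return m, sum(p), acc


def merge(state_a, state_b):
    """Combine step Split-K needs (in a real kernel: shared memory plus a barrier)."""
    (ma, la, acc_a), (mb, lb, acc_b) = state_a, state_b
    m = max(ma, mb)
    ca, cb = math.exp(ma - m), math.exp(mb - m)
    return m, ca * la + cb * lb, [ca * x + cb * y for x, y in zip(acc_a, acc_b)]


scores = [0.3, 2.0, -1.0, 0.7, 1.1, -0.4]
values = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0], [1.0, -1.0], [0.5, 0.5], [-1.0, 3.0]]

# Split-K: two "warps" each see half of the keys for the SAME query row, then merge
_, l, acc = merge(partial_state(scores[:3], values[:3]), partial_state(scores[3:], values[3:]))
split_k = [a / l for a in acc]

# Split-Q: one "warp" sees all keys for its own row, no combine step needed
_, l, acc = partial_state(scores, values)
split_q = [a / l for a in acc]
print(split_k, split_q)
```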
4. Larger Head Dimensions 🤷
FA2 supports D_h up to 256 efficiently. Our kernel works with any head dimension, but we haven’t optimized the tiling specifically for larger values. With D_h=128 or 256, you’d want different block sizes and potentially different memory layouts. For implementation simplicity reasons, I kept Bc==Br, so we haven’t gotten to play with tweaking these parameters.
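A rough way to see why the tiling should change with D_h is a back-of-the-envelope shared memory footprint for the Q_i, K_j and V_j tiles. The calculation assumes all three tiles are staged in shared memory in FP32 with Br = Bc = 32. This is a simplified model: it ignores the S_ij tile, pipelining buffers and compiler scratch space. tile_smem_bytes is a hypothetical helper.

```python
def tile_smem_bytes(Br: int, Bc: int, D: int, bytes_per_elem: int = 4) -> int:
    """Bytes needed to hold a (Br, D) Q tile plus (Bc, D) K and V tiles."""
    return (Br * D + 2 * Bc * D) * bytes_per_elem


for D in (64, 128, 256):
    kib = tile_smem_bytes(32, 32, D) / 1024
    print(f"D_h = {D:3d}: {kib:.0f} KiB of tiles")
```

Doubling D_h doubles the footprint. With these block sizes, D_h = 256 would already exceed the 64 KiB of shared memory a Turing GPU can give a single block, so you would need smaller blocks or narrower data types.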
Don’t get it twisted—Flash Attention is a brilliant example of algorithm-hardware co-design. My goal here wasn’t to diminish the work of Tri Dao, but to walk through the reasoning myself. Every brilliant solution seems kind of obvious once you have it in front of you, but having the insight and technical chops to come up with it in the first place is a whole other thing.
I tried my best to demystify the core ideas behind Flash Attention, and show how understanding the GPU memory hierarchy lets you optimize iteratively: profile, rewrite, tweak, rinse and repeat.
If you’ve fallen asleep three times reading this and just woke up, here are the key takeaways:
There’s more to explore - FA3 brings asynchronous memory copies and FP8 support, FA4 pushes things even further with Blackwell-specific optimizations. But that’s for another post (and another GPU).
The full implementation is available in this repository, including the Triton kernels and profiling scripts.