The MoonMath.ai team has released a bf16 forward attention kernel for AMD’s MI300X GPU. It is written in HIP rather than hand-written assembly, and the code is open-source under the MIT license. The team reports that it beats AITER v3, AMD’s own optimized attention kernel, on every tested shape and rounding mode. Bare-metal access to the hardware came from HotAisle, a cloud provider offering AMD GPUs.
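Attention is the softmax(QKᵀ/√d)·V operation inside every transformer, where d is the head dimension; an attention kernel fuses the two matmuls and the softmax into one GPU pass. The MI300X is AMD’s CDNA3 data-center GPU, and its ISA target is gfx942. This kernel runs on that hardware only.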
TL;DR
- MoonMath.ai open-sources a bf16 forward attention kernel for AMD MI300X, written in HIP, not assembly (MIT).
- It beats AMD’s AITER v3 on every shape and rounding mode, with geomean speedups of 1.18× (RTNE), 1.15× (RTNA), and 1.08× (RTZ), and up to 1.26× on a single shape.
- The core trick: one-instruction asm wrappers let you pick the opcode while the compiler allocates registers.
- Most of the speedup is memory placement — K in LDS, V hot in L1, Q and accumulators in registers.
- A real SGLang PR used it to speed up Wan2.1 video diffusion by 1.23×, with no quality regression.
Understanding the Kernel
A kernel is a small program that runs directly on the GPU’s many cores to perform one specific computation, here the attention math, as fast as the hardware allows. This kernel computes forward attention in bf16 on MI300X only. It accepts inputs in either BSHD (batch, sequence, heads, head dim) or BHSD (batch, heads, sequence, head dim) layout, with no transpose. The head dimension is fixed at 128. Any sequence length is supported, including cross-attention, where the key/value length differs from the query length.
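The no-transpose claim follows from how the two layouts place data. In both, the D = 128 values of one (batch, head, position) row are contiguous; only the head and sequence strides trade places. A minimal sketch, assuming contiguous tensors; `layout_strides` and `element_offset` are hypothetical helpers for illustration, not part of the kernel's API:

```python
def layout_strides(layout, B, H, S, D):
    """Element strides for the logical (b, h, s, d) axes of a contiguous tensor."""
    if layout == "bshd":   # physical shape (B, S, H, D)
        return {"b": S * H * D, "h": D, "s": H * D, "d": 1}
    if layout == "bhsd":   # physical shape (B, H, S, D)
        return {"b": H * S * D, "h": S * D, "s": D, "d": 1}
    raise ValueError(f"unknown layout: {layout!r}")


def element_offset(b, h, s, d, layout, B, H, S, D):
    """Flat offset of logical element (b, h, s, d) in the given layout."""
    st = layout_strides(layout, B, H, S, D)
    return b * st["b"] + h * st["h"] + s * st["s"] + d * st["d"]


# A head_dim row has stride 1 along d in both layouts, so it can be loaded the
# same way; only the h and s strides differ.
dims = dict(B=2, H=24, S=8192, D=128)
for layout in ("bshd", "bhsd"):
    print(layout, layout_strides(layout, **dims))
```

Passing strides rather than physically transposing keeps the extra memory traffic of a layout conversion out of the attention call.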
There are real limits: no causal mask, no grouped-query attention (GQA), and no variable-length (varlen) batching. Outputs are bf16, and the kernel runs on gfx942 hardware exclusively.
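For reference, what the kernel computes within those limits (one head, no mask, query and key/value lengths free to differ) can be written as a naive plain-Python loop. This is only a correctness oracle for small inputs: it works in double precision and materializes every score, whereas a fused kernel like this one streams K and V in blocks with bf16 inputs and fp32 accumulation. `reference_attention` is a hypothetical name:

```python
import math


def reference_attention(q, k, v):
    """Naive softmax(Q Kᵀ / sqrt(d)) V for a single head, with no mask.

    q has Sq rows; k and v have Skv rows (Sq and Skv may differ, as in
    cross-attention). Every q and k row has length d.
    """
    d = len(q[0])
    scale = 1.0 / math.sqrt(d)
    out = []
    for q_row in q:
        scores = [scale * sum(a * b for a, b in zip(q_row, k_row)) for k_row in k]
        m = max(scores)                      # subtract the row max for numerical stability
        weights = [math.exp(s - m) for s in scores]
        total = sum(weights)
        out.append([
            sum(w * v_row[c] for w, v_row in zip(weights, v)) / total
            for c in range(len(v[0]))
        ])
    return out


# Two queries attending to three keys/values (cross-attention shapes).
print(reference_attention([[1.0, 0.0], [0.0, 1.0]],
                          [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]],
                          [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]))
```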
Numerics are tightly controlled. All three rounding modes match AITER’s per-mode rounding rule. Every finite output sits within 1 bf16 ULP of AITER. NaN and Inf handling is bit-identical, and results are deterministic.
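A check in the spirit of that claim can compare bf16 bit patterns: finite outputs may differ by at most one unit in the last place, and non-finite outputs must match exactly. This is a hypothetical test harness, not MoonMath's; it treats +0 and -0 as equal and assumes the inputs are Python floats that are already exactly representable in bf16:

```python
import math
import struct


def bf16_bits(x):
    """bf16 bit pattern of x: the upper 16 bits of its float32 encoding.

    Assumes x is already exactly representable in bf16.
    """
    return struct.unpack("<I", struct.pack("<f", x))[0] >> 16


def ordered(bits):
    """Map a bf16 bit pattern to an integer that grows monotonically with the value."""
    return -(bits & 0x7FFF) if bits & 0x8000 else bits


def ulp_distance(a, b):
    """Number of bf16 steps between two finite bf16 values (+0 and -0 count as equal)."""
    return abs(ordered(bf16_bits(a)) - ordered(bf16_bits(b)))


def matches_reference(out, ref, max_ulp=1):
    """Finite values may differ by max_ulp; NaN and Inf must match bit for bit."""
    if len(out) != len(ref):
        return False
    for x, y in zip(out, ref):
        if math.isfinite(x) and math.isfinite(y):
            if ulp_distance(x, y) > max_ulp:
                return False
        elif bf16_bits(x) != bf16_bits(y):
            return False
    return True
```

For example, `ulp_distance(1.0, 1.0078125)` is 1, because 1.0078125 = 1 + 2⁻⁷ is the next bf16 value above 1.0.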
The Core Trick: One-Instruction asm Wrappers
The core technique avoids a familiar dilemma. Compiler intrinsics keep code tidy but let the compiler reorder or rename operands. Raw inline assembly gives control but forces manual register and address management.
MoonMath wraps exactly one instruction in each `__device__ __forceinline__` function and describes its operands with extended asm constraints. The team picks the opcode; the compiler still allocates registers and tracks data flow.
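```cpp
#include <hip/hip_runtime.h>
// Assumed operand types: 4 packed bf16 inputs and 4 fp32 accumulators per lane.
typedef short bf16x4_t __attribute__((ext_vector_type(4)));
typedef float fp32x4_t __attribute__((ext_vector_type(4)));
// in/out tied to the SAME VGPRs → no accumulator rename, no v_mov copy.
__device__ __forceinline__ void asm_mfma(bf16x4_t a, bf16x4_t b, fp32x4_t& c) {
    asm volatile("v_mfma_f32_16x16x16_bf16 %0, %1, %2, %0"
                 : "+v"(c) : "v"(a), "v"(b));
}
```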
The `"+v"(c)` constraint ties the accumulator's input and output to the same VGPRs, so no copy instruction is emitted. The code stays close to ordinary HIP while still steering the machine one instruction at a time.
The Architecture: Eight Waves, Two Groups, Two Barriers
A CDNA3 compute unit has four SIMD units, so the textbook block is four waves, one per SIMD. MoonMath instead runs eight waves per block, in two groups of four.
The two groups run the same QKᵀ, softmax, O += P·V sequence, offset by one phase. While one group saturates the matrix core, the other runs softmax and issues loads. Then they swap, so the matrix core never idles.
Each iteration has two `s_barrier`s: one at the phase handoff and one at the iteration boundary. Per-counter waits (`s_waitcnt` on individual memory counters) handle the rest of the synchronization.
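The role swap can be modeled as a toy timeline. The sketch tracks only which group owns the matrix core, not real timing; it assumes 64-lane wavefronts (the CDNA3 width) with waves 0-3 as group A and 4-7 as group B, and the helper names are hypothetical. It checks the property the design depends on: in every phase exactly one group is issuing MFMAs.

```python
WAVE_SIZE = 64                 # CDNA3 wavefronts are 64 lanes wide
WAVES_PER_GROUP = 4
WAVES_PER_BLOCK = 8
THREADS_PER_BLOCK = WAVE_SIZE * WAVES_PER_BLOCK   # 512


def group_of(thread_idx):
    """0 = group A (waves 0-3), 1 = group B (waves 4-7)."""
    return (thread_idx // WAVE_SIZE) // WAVES_PER_GROUP


def role(group, phase):
    """Phase 0: group A owns the matrix core. Phase 1: the roles swap."""
    return "mfma" if group == phase else "softmax+loads"


def schedule(num_iters):
    """Role timeline for both groups, with the two s_barriers of each iteration."""
    events = []
    for it in range(num_iters):
        for phase, barrier in ((0, "phase handoff"), (1, "iteration boundary")):
            events.append(("work", it, phase, role(0, phase), role(1, phase)))
            events.append(("s_barrier", barrier))
    return events


for event in schedule(2):
    print(event)
```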
This echoes FlashAttention-3’s matmul and softmax alternation. It does not copy FA3’s producer and consumer warp split. On CDNA3, every memory move is already asynchronous, so a dedicated producer wave is unnecessary.
Where Data Lives, and Why 16×16×16
Most of the speedup comes from memory placement. K streams from HBM into LDS, double-buffered, shared by all eight waves. V stays hot in L1, read on every PV matmul. Q and accumulators live in registers.
The team picked the 16×16×16 MFMA over 32×32×8. Both shapes have identical throughput, but the smaller tile accumulates into 4 fp32 elements per lane, against 16 for 32×32×8. Lower accumulator pressure leaves room for deeper prefetch and a third Q tile.
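The 4-versus-16 figure is simple arithmetic: an M×N fp32 result tile is spread across the 64 lanes of a wave. A short sketch (the helper name is illustrative):

```python
WAVE_SIZE = 64  # lanes per CDNA3 wavefront


def mfma_numbers(m, n, k):
    """Per-instruction work and accumulator footprint of an M x N x K MFMA on one wave."""
    return {
        "flops": 2 * m * n * k,                   # each multiply-accumulate is 2 FLOPs
        "fp32_acc_per_lane": m * n // WAVE_SIZE,  # M x N fp32 result spread over 64 lanes
    }


small = mfma_numbers(16, 16, 16)
large = mfma_numbers(32, 32, 8)
print("16x16x16:", small)   # {'flops': 8192, 'fp32_acc_per_lane': 4}
print("32x32x8: ", large)   # {'flops': 16384, 'fp32_acc_per_lane': 16}
```

The 32×32×8 instruction does twice the work per issue; since the source reports equal throughput, it must also occupy the matrix core about twice as long. The smaller shape therefore gives up no peak rate while needing a quarter of the accumulator registers per instruction.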
| Decision | Choice | Reason |
|---|---|---|
| Waves per block | 8 (two groups of 4) | Plan the pipeline directly; share one K copy |
| MFMA shape | 16×16×16 bf16 | Same throughput, lower VGPR pressure, better power efficiency |
| K placement | LDS, double-buffered, 32 KiB | Shared by all 8 waves, swapped per iteration |
| V placement | L1, resident, prefetched | Reread across PV, kept hot deliberately |
| Q + accumulators | VGPRs | Read every iteration, never reloaded |
Two later optimizations closed the remaining gap. A third Q tile (3Q) raises data reuse for every K and V tile loaded. A Flash-Decoding-style tail KV split recovers the partially filled final round of work-groups across MI300X’s 304 CUs. These wins cascade: moving V to L1 freed the LDS that the third Q tile then fills.
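A back-of-envelope LDS budget shows why moving V out of LDS and adding the third Q tile are linked. The sketch assumes MI300X's 64 KiB of LDS per CU, 64 keys per K buffer (the count that makes the double buffer equal the 32 KiB in the table), and a hypothetical third Q tile of 16 query rows per wave (one 16-row MFMA tile). The row counts are inferences for illustration, not figures confirmed by the source.

```python
KiB = 1024
BF16_BYTES = 2
HEAD_DIM = 128
LDS_PER_CU = 64 * KiB      # CDNA3 LDS capacity per compute unit
WAVES_PER_BLOCK = 8


def tile_bytes(rows, head_dim=HEAD_DIM):
    """Size of a bf16 tile with `rows` rows of head_dim values."""
    return rows * head_dim * BF16_BYTES


# Assumption: each K buffer holds 64 keys, so two buffers take 32 KiB.
k_lds = 2 * tile_bytes(64)
# Assumption (hypothetical): third Q tile = 16 query rows per wave x 8 waves.
q3_lds = tile_bytes(16 * WAVES_PER_BLOCK)

print(f"K double buffer: {k_lds // KiB} KiB")     # 32 KiB
print(f"third Q tile:    {q3_lds // KiB} KiB")    # 32 KiB
print(f"LDS used:        {(k_lds + q3_lds) // KiB} of {LDS_PER_CU // KiB} KiB")  # 64 of 64 KiB
# A 64-row V tile would need another 16 KiB that is no longer available.
print("room for a V tile too:", k_lds + q3_lds + tile_bytes(64) <= LDS_PER_CU)   # False
```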
Benchmark
Tests ran on MI300X in bf16, head dimension 128. Each shape was measured at three rounding modes. RTNE rounds to nearest even. RTNA rounds to nearest, ties away from zero. RTZ truncates toward zero.
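The three modes differ only in how the 16 low bits of a float32 are discarded, since bf16 keeps the upper half of the float32 bit pattern. A minimal sketch for finite inputs; `to_bf16` is a hypothetical helper, not part of the kernel's API, and NaN handling is deliberately left out:

```python
import struct


def f32_bits(x):
    """float32 bit pattern of x as an unsigned 32-bit integer."""
    return struct.unpack("<I", struct.pack("<f", x))[0]


def bits_to_f32(u):
    """Inverse of f32_bits."""
    return struct.unpack("<f", struct.pack("<I", u))[0]


def to_bf16(x, mode="rtne"):
    """Round a finite float32 value to bf16 (returned as a Python float).

    The sign is stored separately, so incrementing the kept 16 bits always
    grows the magnitude. NaN is not handled here.
    """
    u = f32_bits(x)
    upper, lower = u >> 16, u & 0xFFFF
    if mode == "rtz":
        pass                                    # truncate: magnitude rounds toward zero
    elif mode == "rtna":
        if lower >= 0x8000:                     # at or past half-way: away from zero
            upper += 1
    elif mode == "rtne":
        if lower > 0x8000 or (lower == 0x8000 and upper & 1):
            upper += 1                          # exact ties go to the even pattern
    else:
        raise ValueError(f"unknown rounding mode: {mode!r}")
    return bits_to_f32(upper << 16)


tie = 1.00390625   # exactly halfway between bf16 neighbours 1.0 and 1.0078125
for mode in ("rtne", "rtna", "rtz"):
    print(mode, to_bf16(tie, mode))
# rtne 1.0
# rtna 1.0078125
# rtz 1.0
```

Ties are where RTNE and RTNA part ways; for inputs strictly above the half-way point both round up, while RTZ never does.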
| Shape (B, H, S, D) | Rounding | MoonMath (ms) | AITER v3 (ms) | Speedup vs AITER v3 | Speedup vs Modular MAX |
|---|---|---|---|---|---|
| (2, 24, 8192, 128) | RTNE | 3.083 | 3.792 | 1.23× | 1.37× |
| (2, 24, 16384, 128) | RTNE | 11.670 | 14.691 | 1.26× | 1.54× |
| (4, 16, 16384, 128) | RTZ | 15.055 | 16.183 | 1.07× | 1.47× |
| (2, 24, 32768, 128) | RTNA | 44.440 | 52.363 | 1.18× | 1.57× |
| (1, 16, 131072, 128) | RTNE | 232.517 | 269.278 | 1.16× | 1.46× |
Geomeans across the full sweep favor MoonMath. Versus AITER, it scores 1.18× (RTNE), 1.15× (RTNA), and 1.08× (RTZ). Versus Modular MAX, which rounds with RTNE internally and posts the same time in every mode, geomeans run 1.44× to 1.49×, and per-shape speedups reach 1.59×.
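The AITER geomeans can be reproduced from the per-shape timings MoonMath published for its nine-shape sweep (the table above shows five of those measurements). A geometric mean is the natural average for speedup ratios because it averages in log space, so a 2× win and a 2× loss cancel. A plain-Python check:

```python
import math

# (MoonMath ms, AITER v3 ms) for the nine published shapes, all with head_dim 128:
# (2,24,8192) (2,24,16384) (1,32,16384) (4,16,16384) (1,64,16384)
# (2,24,32768) (2,16,65536) (2,8,86016) (1,16,131072)
TIMINGS = {
    "rtne": [(3.083, 3.792), (11.670, 14.691), (8.013, 9.031), (15.591, 18.337),
             (15.528, 18.333), (46.002, 54.794), (117.612, 136.301),
             (101.071, 118.939), (232.517, 269.278)],
    "rtna": [(3.022, 3.605), (11.479, 13.801), (7.828, 8.656), (15.331, 17.567),
             (15.239, 17.535), (44.440, 52.363), (115.550, 130.278),
             (100.165, 114.515), (228.475, 258.092)],
    "rtz": [(2.983, 3.303), (11.385, 12.629), (7.731, 7.989), (15.055, 16.183),
            (15.040, 16.161), (44.075, 48.549), (114.665, 121.668),
            (99.397, 106.513), (226.152, 239.587)],
}


def geomean_speedup(pairs):
    """Geometric mean of baseline_ms / ours_ms: the exp of the mean log speedup."""
    logs = [math.log(base / ours) for ours, base in pairs]
    return math.exp(sum(logs) / len(logs))


for mode, pairs in TIMINGS.items():
    print(f"{mode}: {geomean_speedup(pairs):.2f}x vs AITER v3")
# rtne: 1.18x vs AITER v3
# rtna: 1.15x vs AITER v3
# rtz: 1.08x vs AITER v3
```

The same data also confirms the headline claims: every ratio is above 1, and the largest, shape (2, 24, 16384) at RTNE, rounds to 1.26×.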
RTZ is AITER’s own fastest mode and the tightest race. The (4, 16, 16384) RTZ shape moved from 0.95× to 1.07×. The tail KV split is what closed that final gap.
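Two pieces of arithmetic sit behind a tail fix of this kind. When the number of work-groups is not a multiple of the 304 CUs, the last round runs mostly empty. Flash-Decoding addresses that idle capacity by splitting a work item's KV range into chunks that run in parallel, then merging the partial results with a log-sum-exp rescale. A plain-Python sketch for a single query row, assuming one block per CU per round; the 2,752-block launch and all helper names are hypothetical illustrations, not MoonMath's implementation:

```python
import math

NUM_CUS = 304  # compute units on MI300X


def tail_round(num_blocks, num_cus=NUM_CUS):
    """Full rounds and the fraction of CUs busy in the last round (one block per CU)."""
    full, tail = divmod(num_blocks, num_cus)
    return full, tail / num_cus


def partial_attention(q, keys, values):
    """Attention of one query over one KV chunk, left unnormalized: (max, sum, acc)."""
    scale = 1.0 / math.sqrt(len(q))
    scores = [scale * sum(a * b for a, b in zip(q, k)) for k in keys]
    m = max(scores)
    w = [math.exp(s - m) for s in scores]
    acc = [sum(wi * v[c] for wi, v in zip(w, values)) for c in range(len(values[0]))]
    return m, sum(w), acc


def merge(parts):
    """Flash-Decoding-style combine: rescale each chunk to the global max, then normalize."""
    m = max(p[0] for p in parts)
    total = sum(l * math.exp(mi - m) for mi, l, _ in parts)
    dim = len(parts[0][2])
    return [sum(acc[c] * math.exp(mi - m) for mi, _, acc in parts) / total
            for c in range(dim)]


def split_kv_attention(q, keys, values, num_splits):
    """One query row with its KV range cut into num_splits contiguous chunks."""
    n = len(keys)
    bounds = [n * i // num_splits for i in range(num_splits + 1)]
    parts = [partial_attention(q, keys[a:b], values[a:b])
             for a, b in zip(bounds, bounds[1:]) if b > a]
    return merge(parts)


# Hypothetical launch: 2,752 blocks leave a tail round with only 16 of 304 CUs busy.
print(tail_round(2752))
keys = [[math.sin(i), math.cos(i)] for i in range(10)]
values = [[float(i), float(i * i)] for i in range(10)]
q = [0.7, -0.3]
print(split_kv_attention(q, keys, values, 1))
print(split_kv_attention(q, keys, values, 4))   # same result up to rounding
```

Because the merge is exact up to floating-point rounding, splitting only the leftover work trades a small combine step for filling CUs that would otherwise sit idle.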
Use Cases
The kernel installs with pip and exposes a small API. It launches on the caller’s stream, so it stays ordered with surrounding work and can overlap with work on other streams inside larger pipelines.
```python
import torch
import moonmath_attention as ma

# PyTorch's ROCm build uses the "cuda" device string on AMD GPUs.
# BSHD layout: (batch, seq_len, heads, head_dim); head_dim is fixed at 128.
q = torch.randn(2, 8192, 24, 128, dtype=torch.bfloat16, device="cuda")
k = torch.randn(2, 8192, 24, 128, dtype=torch.bfloat16, device="cuda")
v = torch.randn(2, 8192, 24, 128, dtype=torch.bfloat16, device="cuda")

out = ma.forward(q, k, v, layout="bshd")                        # default rounding mode
out_rtz = ma.forward(q, k, v, layout="bshd", round_mode="rtz")  # truncate toward zero
```
One concrete use case is video diffusion. The team added LiteAttention support and sent a PR to SGLang diffusion. On Wan2.1-T2V-1.3B-Diffusers, they switched the attention backend from AITER to liteattention_rocm. End-to-end generation became 1.23× faster on MI300X, with no visible quality regression.
The BSHD layout suits diffusion tensors directly. Cross-attention works with any KV length and no padding.
Key Takeaways
- The kernel is bf16 forward attention for MI300X, written in HIP under MIT.
- It beats AITER v3 on every shape and rounding mode, with geomeans of 1.18× (RTNE), 1.15× (RTNA), and 1.08× (RTZ).
- One-instruction asm wrappers give opcode control while the compiler allocates registers.
- Memory placement drove most of the gain: K in LDS, V hot in L1, Q in registers.
- A real SGLang PR sped up Wan2.1 video diffusion by 1.23× with no quality regression.
Check out the technical details at https://moonmath.ai/cdna3attention/.
The post MoonMath AI Open-Sources a HIP Attention Kernel for AMD MI300X That Beats AITER v3 on Every Shape and Rounding Mode appeared first on MarkTechPost.