Last week, I wrote a tail-call interpreter using the become keyword, which was recently added to nightly Rust (seven months ago is recent, right?).
It was a surprisingly pleasant experience, and the resulting VM outperforms both my previous Rust implementation and my hand-coded ARM64 assembly. Tailcall-based techniques have been all the rage recently (see this overview); consider this my trip report implementing a simple but non-trivial system.
For those keeping track at home, this is the latest in my exploration of high-performance emulation of the Uxn CPU, which runs a bunch of applications in the Hundred Rabbits ecosystem.
If you want to read the whole saga, here's the list:
- Project writeup for Raven, my original Rust implementation
- Beating the compiler, in which I write an ARM64 assembly implementation which outperforms my Rust code
- Guided by the beauty of our test suite, in which I revisit the code a year later and improve its testing and CI
- An x86-64 backend for raven-uxn, in which I port the assembly implementation from ARM64 to x86 (with the help of Claude Code)
Experimenting with LLMs proved controversial, which wasn't a surprise; I'm pleased to declare that all of the tail-call code is human-written, and the new backend can be used as a substitute for the x86 assembly backend at a minor performance penalty.
(This blog post is also entirely human-written, per my personal standards)
The next few sections summarize previous work, so feel free to skim them if you've done the reading and jump straight to tailcalls in Rust.
Basics of Uxn emulation
Uxn is a simple stack machine with 256 instructions. The whole CPU has just over 64K of space, split between a few memories:
- Two 256-byte stacks, each with an index byte
- 65536 bytes of RAM, which is used for both data and program text
- A 2-byte program counter
- 256 bytes of "device memory", used for peripherals
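Adding up the items in that list shows where "just over 64K" comes from. This is only a quick arithmetic sketch; the constant and function names are illustrative and not taken from Raven:

```rust
// Sizes come from the list above; the names are hypothetical.
const STACK_BYTES: usize = 256; // data bytes per stack
const STACK_INDEX_BYTES: usize = 1; // one index byte per stack
const RAM_BYTES: usize = 65536;
const PC_BYTES: usize = 2;
const DEVICE_BYTES: usize = 256;

/// Total bytes of state in the Uxn CPU, per the list above.
const fn total_state_bytes() -> usize {
    2 * (STACK_BYTES + STACK_INDEX_BYTES) + RAM_BYTES + PC_BYTES + DEVICE_BYTES
}
```

That works out to 66308 bytes, or 772 bytes more than 64 KiB.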
The simplest emulator reads a byte from RAM at the program counter, then calls into an instruction (which may update the program counter):
fn run(core: &mut Uxn, dev: &mut Device, mut pc: u16) -> u16 {
    loop {
        let op = core.next(&mut pc);
        let Some(next) = core.op(op, dev, pc) else {
            break pc;
        };
        pc = next;
    }
}
impl Uxn {
    fn op(
        &mut self, op: u8, dev: &mut Device, pc: u16
    ) -> Option<u16> {
        match op {
            op::BRK => self.brk(pc),
            op::INC => self.inc::<0b000>(pc),
            op::POP => self.pop::<0b000>(pc),
            op::NIP => self.nip::<0b000>(pc),
            op::SWP => self.swp::<0b000>(pc),
            // ... etc
            op::ORA2kr => self.ora::<0b111>(pc),
            op::EOR2kr => self.eor::<0b111>(pc),
            op::SFT2kr => self.sft::<0b111>(pc),
        }
    }
}
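The next helper used by run is not shown in this post. A minimal sketch of what it plausibly does, assuming it reads the byte at the program counter and then advances the counter with 16-bit wraparound (the Core type here is a hypothetical stand-in that holds only RAM):

```rust
/// Hypothetical stand-in for the post's `Uxn`, holding only RAM.
struct Core<'a> {
    ram: &'a [u8; 65536],
}

impl<'a> Core<'a> {
    /// Reads the byte at `pc`, then advances `pc`, wrapping at 16 bits.
    /// Every `u16` is a valid index into a 65536-byte array.
    fn next(&self, pc: &mut u16) -> u8 {
        let op = self.ram[usize::from(*pc)];
        *pc = (*pc).wrapping_add(1);
        op
    }
}
```

Using wrapping_add on a u16 gives the same wraparound that the assembly backend gets from an explicit mask with 0xffff.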
There are 256 instructions, many of which are parameterized with flags. Here's
the INC instruction, which increments the top byte on the stack:
impl Uxn {
    pub fn inc<const FLAGS: u8>(&mut self, pc: u16) -> Option<u16> {
        let mut s = self.stack_view::<FLAGS>();
        let v = s.pop();
        s.push(v.wrapping_add(1));
        Some(pc)
    }
}
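For readers unfamiliar with the flags: in Uxn's published opcode encoding, the low five bits of an opcode byte select the operation and the top three bits select the short (2), return (r), and keep (k) modes. The sketch below assumes that FLAGS is simply those top three bits shifted down, which matches ORA2kr mapping to 0b111 above but is otherwise an assumption about Raven's internals; the decode function is hypothetical:

```rust
/// Splits an opcode byte into (base operation, flags).
/// Assumed layout: bit 5 = short (`2`), bit 6 = return (`r`), bit 7 = keep (`k`),
/// so after shifting, flags bit 0 = short, bit 1 = return, bit 2 = keep.
/// Special cases in the BRK column (such as LIT) are not handled here.
fn decode(op: u8) -> (u8, u8) {
    (op & 0x1f, op >> 5)
}
```

Under that layout, INC is 0x01 with flags 0b000, INC2 is 0x21 with flags 0b001, and ORA2kr is 0xfd with flags 0b111.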
All of the opcode implementations are inlined into the main op function, but
there's room for improvement:
some values are stored in memory rather than registers, and the main op
selection branch is unpredictable.
Threaded code in assembly
In our assembly implementation, we can instead use threaded code (specifically token threading). We store all of the CPU state in registers, then end each instruction with a jump to the subsequent instruction:
; x0 | stack pointer
; x1 | stack index
; x4 | ram pointer
; x5 | program counter
; x8 | opcode table
_INC:
ldrb w9, [x0, x1] ; read the byte from the top of the stack
add w9, w9, #1 ; increment it
strb w9, [x0, x1] ; write it back
ldrb w9, [x4, x5] ; load the next opcode from RAM
add x5, x5, #1 ; increment the program counter
and x5, x5, #0xffff ; wrap the program counter
ldr x10, [x8, x9, lsl #3] ; load the opcode implementation address
br x10 ; jump to the opcode's implementation
This distributes the dispatch operation across every opcode, making it easier for the branch predictor to learn sequences of opcodes in the program. Overall speedups were significant: 40-50% faster on ARM64, and about 2× faster on x86-64.
Unfortunately, it requires maintaining about 2000 lines of code, and is incredibly unsafe. In my x86 port, I introduced an out-of-bounds write, which stomped on a few bytes outside of device RAM; the only symptom was that the fuzzer would segfault when exiting after running a very particular program.
So, what's to be done?
Tail calls in Rust
We'd like to get the same behavior as our assembly implementation – VM state stored in registers, dispatch at the end of each opcode – without hand-writing every instruction in assembly. Fortunately, there is hope!
The core idea has almost certainly been reinvented a bunch of times, but I first encountered the idea of tail-call interpreters in the Massey Meta Machine writeup, which was a mind-expanding read.
There are two pieces:
- Store program state in function arguments, which are mapped to registers based on your system's calling convention
- End each function by calling the next function
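To make those two pieces concrete, here is a self-contained toy in stable Rust. It is not Uxn: the machine has a single accumulator and three made-up opcodes, and every name in it is hypothetical. All state travels through function arguments, and each handler ends by dispatching to the next one:

```rust
// Toy machine, not Uxn: one accumulator and three invented opcodes.
const HALT: u8 = 0;
const INC: u8 = 1;
const DBL: u8 = 2;

// Every handler shares one signature, so they can live in one table.
type Handler = fn(program: &[u8], pc: usize, acc: u32) -> u32;

const TABLE: [Handler; 3] = [halt, inc, dbl];

/// Fetches the opcode at `pc` and calls its handler with the updated state.
/// Panics on an unknown opcode or on running off the end of the program.
fn dispatch(program: &[u8], pc: usize, acc: u32) -> u32 {
    let op = program[pc];
    TABLE[usize::from(op)](program, pc + 1, acc)
}

fn halt(_program: &[u8], _pc: usize, acc: u32) -> u32 {
    acc // stop: return the final state up the call chain
}

fn inc(program: &[u8], pc: usize, acc: u32) -> u32 {
    dispatch(program, pc, acc.wrapping_add(1))
}

fn dbl(program: &[u8], pc: usize, acc: u32) -> u32 {
    dispatch(program, pc, acc.wrapping_mul(2))
}
```

Running dispatch(&[INC, INC, DBL, INC, HALT], 0, 0) walks the accumulator through 1, 2, 4, 5 and returns 5. Note that each step here is an ordinary call, so nothing stops the native stack from growing with every instruction executed.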
We could write this today in Rust; here's our inc function:
const TABLE: FunctionTable = FunctionTable([
    brk,
    inc::<0b000>,
    pop::<0b000>,
    nip::<0b000>,
    swp::<0b000>,
    rot::<0b000>,
    dup::<0b000>,
    // ...etc
]);
fn inc<'a, const FLAGS: u8>(
    stack_data: &'a mut [u8; 256],
    stack_index: u8,
    rstack_data: &'a mut [u8; 256],
    rstack_index: u8,
    dev: &'a mut [u8; 256],
    ram: &'a mut [u8; 65536],
    pc: u16,
    vdev: &mut dyn Device,
) -> (Uxn<'a>, u16) {
    // Rebuild the core from the individual (register-passed) arguments
    let mut core = Uxn {
        stack: Stack {
            data: stack_data,
            index: stack_index,
        },
        ret: Stack {
            data: rstack_data,
            index: rstack_index,
        },
        dev,
        ram,
    };
    match core.inc::<FLAGS>(pc) {
        // `mut` is needed because `next` advances the program counter
        Some(mut pc) => {
            let op = core.next(&mut pc);
            TABLE.0[op as usize](
                core.stack.data,
                core.stack.index,
                core.ret.data,
                core.ret.index,
                core.dev,
                core.ram,
                pc,
                vdev,
            )
        }
        None => (core, pc)
    }
}
We want to reuse our existing Uxn opcode implementations, so we reconstruct
the core: Uxn object at the beginning of the function, call its inc
function, then deconstruct it again when calling the next operation. There's a
lot of boilerplate, and it's tempting to just pass a &mut Uxn argument, but
that removes the "state is stored in registers" benefit; we'll remove
boilerplate with a macro later on.
Unfortunately, there's a problem with this implementation:
thread 'snapshots::tailcall::mandelbrot' (67889685) has overflowed its stack
fatal runtime error: stack overflow, aborting
error: test failed, to rerun pass `-p raven-varvara --test snapshots`
Even in a release build, the compiler has not optimized out the stack. As we execute more and more operations, the stack gets deeper and deeper until it inevitably overflows.
We need to tell the compiler to generate a br (branch to register) instead of a
bl (branch-and-link) instruction and, more importantly, not to allocate any
persistent space on the stack. In other words, we need a tail call.
In nightly Rust, this is a one-word fix:
     match core.inc::<FLAGS>(pc) {
         Some(mut pc) => {
             let op = core.next(&mut pc);
-            TABLE.0[op as usize](
+            become TABLE.0[op as usize](
                 core.stack.data,
                 core.stack.index,
                 core.ret.data,
                 core.ret.index,
                 core.dev,
                 core.ram,
                 pc,
                 vdev,
             )
With this change in place, the Rust compiler makes a guarantee:
When tail calling a function, instead of its stack frame being added to the stack, the stack frame of the caller is directly replaced with the callee’s.
That's it, everything works! End of writeup!
Implementation details
Okay, okay, I've got a little more to say.
First, I promised a macro to eliminate the boilerplate. As always, it's a horrifying thing to behold:
macro_rules! tail_fn {
    ($name:ident $(::<$flags:ident>)?) => {
        tail_fn!($name $(::<$flags>)?[][vdev: &mut dyn Device]);
    };
    ($name:ident $(::<$flags:ident>)?($($arg:ident: $ty:ty),*)) => {
        tail_fn!($name $(::<$flags>)?[$($arg: $ty),*][]);
    };
    ($name:ident $(::<$flags:ident>)?[$($arg0:ident: $ty0:ty),*][$($arg1:ident: $ty1:ty),*]) => {
        extern "rust-preserve-none" fn $name<'a, $(const $flags: u8)?>(
            stack_data: &'a mut [u8; 256],
            stack_index: u8,
            rstack_data: &'a mut [u8; 256],
            rstack_index: u8,
            dev: &'a mut [u8; 256],
            ram: &'a mut [u8; 65536],
            pc: u16,
            $($arg0: $ty0),*
            $($arg1: $ty1),*
        ) -> (UxnCore<'a>, u16) {
            let mut core = UxnCore {
                stack: Stack {
                    data: stack_data,
                    index: stack_index,
                },
                ret: Stack {
                    data: rstack_data,
                    index: rstack_index,
                },
                dev,
                ram,
            };
            match core.$name::<$($flags)?>(pc, $($arg0),*) {
                Some(mut pc) => {
                    let op = core.next(&mut pc);
                    become TABLE.0[op as usize](
                        core.stack.data,
                        core.stack.index,
                        core.ret.data,
                        core.ret.index,
                        core.dev,
                        core.ram,
                        pc,
                        $($arg0),*
                        $($arg1),*
                    )
                }
                None => (core, pc),
            }
        }
    };
}
(This is now from the actual implementation, so some types are slightly different than the simplified code earlier in this writeup)
The macro is very awkward, but it lets us declare all three kinds of functions:
- Fabled ones
- Innumerable ones
- Those drawn with a very fine camel hair brush
Wait, no, wrong list; the three kinds of functions are
- Bare functions
- Functions with FLAGS
- Functions with FLAGS and additional arguments
Here's what all three look like:
tail_fn!(brk); // bare function
tail_fn!(inc::<FLAGS>); // function with flags
tail_fn!(dei::<FLAGS>(dev: &mut dyn Device)); // flags and arguments
You don't need to spend much time puzzling over the macro; we're firmly in "if
it compiles, it works" territory here. It's also worth noting that this is
still 100% safe Rust: our #![forbid(unsafe_code)] attribute remains
untriggered.
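Part of why safe Rust is workable here is the shape of the data: the stacks are 256-byte arrays indexed by a single byte, so an index can never be out of range. A minimal sketch of the idea follows, with hypothetical function names. One would expect the compiler to elide the bounds check in this situation, but that is an assumption worth confirming in the disassembly:

```rust
/// Reads the byte at `index` in a 256-byte stack.
/// Every `u8` value is a valid index into `[u8; 256]`, so this cannot panic.
fn peek(stack: &[u8; 256], index: u8) -> u8 {
    stack[usize::from(index)]
}

/// Pops one byte, returning it together with the new index.
/// The index wraps around rather than underflowing.
fn pop(stack: &[u8; 256], index: u8) -> (u8, u8) {
    (peek(stack, index), index.wrapping_sub(1))
}
```

The same reasoning applies to a u16 program counter indexing 65536 bytes of RAM.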
Codegen notes
The compiler does a good job of inlining and stripping functions down to their
essential operations; the boilerplate of constructing and deconstructing the
UxnCore is fully optimized out.
0000000100039c50 <raven_uxn::tailcall::inc::<0>>:
and x8, x0, #0xff ; mask stack pointer
ldrb w9, [x21, x8] ; read byte
add w9, w9, #1 ; increment value
strb w9, [x21, x8] ; write back value
and x8, x2, #0xffff ; mask pc
ldrb w8, [x24, x8] ; look up next opcode
adrp x9, 0x100279000 ; load table base
add x9, x9, #3776 ; offset??
ldr x3, [x9, x8, lsl #3] ; get jump target
add w2, w2, #1 ; increment program counter
br x3 ; jump!
I see two main differences from our hand-written implementation:
- It does register masking before use, instead of after modification. For example, the program counter (in x2) is masked when used for a lookup, but not masked after it is incremented
- The table base is loaded from an offset, instead of stored in a register
We could fix the latter by also threading the table through the argument list, which improves performance on x86 (but doesn't seem to matter on ARM64).
Performance results
Speaking of performance, how does it do?
I've got two main benchmarks:
- A microbenchmark program which naively computes a Fibonacci number
- A larger program which renders a Mandelbrot fractal
On my laptop (M1 Macbook), I'm pleased to report that I'm no longer beating the compiler: the tail-call interpreter handily beats my hand-written assembly on both benchmarks.
| | Fibonacci | Mandelbrot |
|---|---|---|
| VM | 2.41 | 125 |
| Assembly | 1.32 | 87 |
| Tailcall | 1.19 | 76 |
(all times in milliseconds, smaller is faster)
Now, let's take a big sip of seasonally-inappropriate tea and test on x86—
| | Fibonacci | Mandelbrot |
|---|---|---|
| VM | 4.70 | 264 |
| Assembly | 1.84 | 168 |
| Tailcall | 3.23 | 175 |
—oh no. It's outperforming the VM, but is still losing to the assembly backend by a noticeable amount (especially in the Fibonacci microbenchmark). What's going on here?
Let's start by looking at the generated code for INC, our simplest opcode:
_RINvNtCs7kxjQDHw3ed_9raven_uxn8tailcall3incKh0_EB4_:
movzx eax,dil ; mask stack index
inc BYTE PTR [r13+rax] ; increment byte on stack
movzx eax,cx ; mask pc register
movzx eax,BYTE PTR [rdx+rax] ; read next opcode from RAM
inc ecx ; increment pc register
mov rax,QWORD PTR [r8+rax*8] ; read jump target from table
jmp rax ; jump to target
; followed by `int3` operations to pad to 32-byte alignment
This implementation looks broadly fine: it's doing the minimal number of reads and writes, and is basically what I'd expect. Indeed, incrementing the byte by address may be more efficient than my assembly, which loads and stores that byte.
(Also, I didn't harp on this before, but declaring these functions as extern "rust-preserve-none" is very important for the x86 implementation. The default
calling convention doesn't use enough registers for all of our arguments, which
adds tremendous amounts of overhead)
So, INC is inoffensive. Let's now look at ADD2, which adds the top two
16-bit values on the stack. I'll start by showing you the hand-tuned
implementation from the assembly backend:
_ADD2:
movzx eax,BYTE PTR [rbx+r12*1] ; read byte from data stack
dec r12b ; decrement data pointer
movzx ecx,BYTE PTR [rbx+r12*1] ; read byte from data stack
dec r12b ; decrement data pointer
shl ecx,0x8 ; build 16-bit value
or eax,ecx
movzx ecx,BYTE PTR [rbx+r12*1] ; read byte from data stack
lea r11,[r12-0x1] ; get next byte address
and r11,0xff ; mask index byte
movzx edx,BYTE PTR [rbx+r11*1] ; read byte from data stack
shl edx,0x8 ; build 16-bit value
or ecx,edx
add ecx,eax ; do addition
mov BYTE PTR [rbx+r12*1],cl ; write byte to data stack
shr ecx,0x8 ; shift 16-bit value
mov BYTE PTR [rbx+r11*1],cl ; write byte to data stack
movzx eax,BYTE PTR [r15+rbp*1] ; read next opcode
inc bp ; increment pc
lea rcx,[rip+0x22e5c7] ; read jump address
jmp QWORD PTR [rcx+rax*8] ; jump
This implementation is 79 bytes, and does the bare minimum number of reads and
writes: 4 byte reads + 2 byte writes to the data stack, one byte read from RAM
(to get the next opcode), and one qword read from the jump table (to get the
jump target).
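For reference, here is what that sequence computes, written as a standalone Rust sketch. It assumes the same conventions as the assembly: the index points at the top byte of the stack, and 16-bit values are stored with the low byte on top. The function name and signature are hypothetical and do not match Raven's stack_view API:

```rust
/// Adds the top two 16-bit values on a 256-byte stack.
/// Returns the new stack index; all index arithmetic wraps at 8 bits.
fn add2(stack: &mut [u8; 256], index: u8) -> u8 {
    // Four byte reads: the low byte of each value sits above its high byte.
    let b_lo = stack[usize::from(index)];
    let b_hi = stack[usize::from(index.wrapping_sub(1))];
    let a_lo = stack[usize::from(index.wrapping_sub(2))];
    let a_hi = stack[usize::from(index.wrapping_sub(3))];

    let a = u16::from_be_bytes([a_hi, a_lo]);
    let b = u16::from_be_bytes([b_hi, b_lo]);
    let [hi, lo] = a.wrapping_add(b).to_be_bytes();

    // Two byte writes: the result replaces the lower of the two values.
    let new_index = index.wrapping_sub(2);
    stack[usize::from(new_index)] = lo;
    stack[usize::from(new_index.wrapping_sub(1))] = hi;
    new_index
}
```

With 0x1234 and 0x00ff on the stack, the result is 0x1333 and the index drops by two.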
In contrast, here's the compiled tailcall implementation:
_RINvNtCs7kxjQDHw3ed_9raven_uxn8tailcall3addKh1_EB4_:
push rbp ; spill rbp to the stack
mov QWORD PTR [rsp-0x8],r11 ; spill r11 to the stack too?
mov r11,r9 ; do the Register Shuffle (??)
mov r9,r8
mov r8d,ecx
mov rax,r13
movzx r10d,dil ; get index as byte
movzx edi,BYTE PTR [r13+r10] ; read byte from stack
lea ebx,[r10-0x1] ; get next index
movzx ebx,bl ; mask index to byte
lea r13d,[r10-0x2] ; precompute another index
movzx ebp,BYTE PTR [rax+rbx*1] ; read byte from stack
shl ebp,0x8 ; build 16-bit value
or ebp,edi
movzx edi,r13b ; mask index byte
movzx r13d,BYTE PTR [rax+rdi] ; read byte from stack
add r10b,0xfd ; compute another address
movzx ecx,r10b ; mask index byte
movzx ebx,BYTE PTR [rax+rcx] ; read byte from stack
shl ebx,0x8 ; build 16-bit value
or ebx,r13d
add ebx,ebp ; do addition
mov BYTE PTR [rax+rcx*1],bh ; write byte to stack
mov BYTE PTR [rax+rdi*1],bl ; write byte to stack
movzx ecx,r8w ; mask pc word
movzx ecx,BYTE PTR [rdx+rcx*1] ; read opcode
inc r8d ; increment pc
mov r10,QWORD PTR [r9+rcx*8] ; get next jump address
mov r13,rax ; do the Reverse Register Shuffle (??)
mov ecx,r8d
mov r8,r9
mov r9,r11
mov r11,QWORD PTR [rsp-0x8] ; restore spilled r11
mov rax,r10
pop rbp ; restore pushed rbp
jmp rax ; jump
; followed by `int3` padding to a 32-byte boundary
This is 121 bytes, and – more concerningly – spills and restores two full
64-bit registers to the stack. This looks like (to use a technical term) real
bad codegen. On one hand, this is an unfinished nightly feature in rustc, so
it's understandable; on the other hand, I'm surprised this isn't well optimized
by the LLVM backend, even if the rustc side is immature.
This blog post is getting unwieldy, so I won't speculate too much further, but I will make a few observations:
- Plain C code is compiled to something more optimal by Clang
- Similarly plain Rust code isn't quite as bad as what I was seeing, but still spills rbp to the stack unnecessarily.
One more thing
WebAssembly also supports tail calls, and Raven supports compilation to WebAssembly! I wonder how the tail-call interpreter will fare, compared to both native performance and the simple VM interpreter?
I can only benchmark the Mandelbrot example; the Fibonacci program is too fast (given limitations on web timer resolution). Here are the numbers:
| Engine | Backend | Time (ms) |
|---|---|---|
| Native | VM | 125 |
| | Tailcall | 76 |
| Firefox | VM | 264 |
| | Tailcall | 311 |
| Chrome | VM | 244 |
| | Tailcall | 905 |
| wasmtime | VM | 128 |
| | Tailcall | 595 |
Surprise, it's terrible: 1.2× slower on Firefox, 3.7× slower on Chrome, and
4.6× slower in wasmtime. I guess patterns which generate good assembly don't
map well to the WASM stack machine, and the JITs aren't smart enough to lower it
to optimal machine code.
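Those slowdown factors are the tail-call time divided by the VM time for each engine, rounded to one decimal place. A small sketch of the arithmetic, using the timings from the table (the helper name is illustrative):

```rust
/// Slowdown of the tail-call backend relative to the VM backend,
/// rounded to one decimal place.
fn slowdown(vm_ms: f64, tailcall_ms: f64) -> f64 {
    (tailcall_ms / vm_ms * 10.0).round() / 10.0
}
```

For example, slowdown(264.0, 311.0) covers Firefox, slowdown(244.0, 905.0) covers Chrome, and slowdown(128.0, 595.0) covers wasmtime.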
wasmtime did manage impressive performance on the traditional VM
implementation though — within a few percentage points of the native Rust build!
(All of these tests were run on my M1 Max laptop, using wasmtime built from
e9e1665c5, Firefox 149.0, and Chrome 146.0.7680.178)
Conclusion
The tailcall interpreter PR is merged,
and has been deployed in the 0.3.0 release. When enabled, it's the default on
ARM64 systems, and the second choice on x86-64 systems (if the native feature
is not enabled).
I'd be very curious to get tips on improving x86 and WASM performance;