
A tail-call interpreter in (nightly) Rust

Last week, I wrote a tail-call interpreter using the become keyword, which was recently added to nightly Rust (seven months ago is recent, right?).

It was a surprisingly pleasant experience, and the resulting VM outperforms both my previous Rust implementation and my hand-coded ARM64 assembly. Tailcall-based techniques have been all the rage recently (see this overview); consider this my trip report from implementing a simple but non-trivial system.


For those keeping track at home, this is the latest in my exploration of high-performance emulation of the Uxn CPU, which runs a bunch of applications in the Hundred Rabbits ecosystem.

If you want to read the whole saga, here's the list:

Experimenting with LLMs proved controversial, which wasn't a surprise; I'm pleased to declare that all of the tail-call code is human-written, and the new backend can be used as a substitute for the x86 assembly backend at a minor performance penalty.

(This blog post is also entirely human-written, per my personal standards)

The next few sections summarize previous work, so feel free to skim them if you've done the reading and jump straight to tailcalls in Rust.


Basics of Uxn emulation

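Uxn is a simple stack machine with 256 instructions. The whole CPU has just over 64K of space, split between a few memories:

- a 256-byte data stack
- a 256-byte return stack
- 256 bytes of device memory
- 64K (65536 bytes) of RAM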
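As a quick sanity check on "just over 64K", here's a hypothetical struct holding the four memories. The names are mine, not Raven's, and the sizes match the argument types of the tail-call functions later in this post. Every field is a byte array, so there's no padding and the total is 3 × 256 + 65536 bytes.

```rust
use std::mem::size_of;

/// Hypothetical container for Uxn's memories (field names are illustrative).
#[allow(dead_code)]
struct UxnMemories {
    stack: [u8; 256], // data (working) stack
    ret: [u8; 256],   // return stack
    dev: [u8; 256],   // device memory
    ram: [u8; 65536], // main RAM, indexed by a 16-bit program counter
}

/// Total CPU-visible memory in bytes; u8 arrays have alignment 1, so no padding.
const TOTAL_BYTES: usize = size_of::<UxnMemories>();
```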

The simplest emulator reads a byte from RAM at the program counter, then calls into an instruction (which may update the program counter):

fn run(core: &mut Uxn, dev: &mut Device, mut pc: u16) -> u16 {
    loop {
        let op = core.next(&mut pc);
        let Some(next) = core.op(op, dev, pc) else {
            break pc;
        };
        pc = next;
    }
}

impl Uxn {
    fn op(
        &mut self, op: u8, dev: &mut Device, pc: u16
    ) -> Option<u16> {
        match op {
            op::BRK => self.brk(pc),
            op::INC => self.inc::<0b000>(pc),
            op::POP => self.pop::<0b000>(pc),
            op::NIP => self.nip::<0b000>(pc),
            op::SWP => self.swp::<0b000>(pc),
            // ... etc
            op::ORA2kr => self.ora::<0b111>(pc),
            op::EOR2kr => self.eor::<0b111>(pc),
            op::SFT2kr => self.sft::<0b111>(pc),
        }
    }
}

There are 256 instructions, many of which are parameterized with flags. Here's the INC instruction, which increments the top byte on the stack:

impl Uxn {
    pub fn inc<const FLAGS: u8>(&mut self, pc: u16) -> Option<u16> {
        let mut s = self.stack_view::<FLAGS>();
        let v = s.pop();
        s.push(v.wrapping_add(1));
        Some(pc)
    }
}
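The flags come from the opcode byte itself. Per the Uxn specification, the low five bits select one of 32 base instructions, and the top three bits are the short ("2"), return ("r"), and keep ("k") modes. Below is a small decoding sketch. Packing the mode bits as op >> 5 is my assumption about how a FLAGS value like 0b111 (for ORA2kr) could be derived. It is not necessarily how Raven computes it.

```rust
/// Hypothetical decoded view of a Uxn opcode byte.
#[derive(Debug, PartialEq)]
struct Decoded {
    base: u8,    // low five bits: one of 32 base instructions (INC = 0x01, ...)
    short: bool, // 0x20, "2": operate on 16-bit values
    ret: bool,   // 0x40, "r": use the return stack
    keep: bool,  // 0x80, "k": leave the operands on the stack
    flags: u8,   // the three mode bits packed as `op >> 5` (assumed encoding)
}

fn decode(op: u8) -> Decoded {
    Decoded {
        base: op & 0x1f,
        short: (op & 0x20) != 0,
        ret: (op & 0x40) != 0,
        keep: (op & 0x80) != 0,
        flags: op >> 5,
    }
}
```

For example, decode(0xfd) (ORA2kr) yields base 0x1d with all three modes set. One caveat: in current Uxn, the BRK slot with mode bits set is reused for immediate instructions such as LIT (0x80), so a base of 0 doesn't always mean BRK.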

All of the opcode implementations are inlined into the main op function, but there's room for improvement: some values are stored in memory rather than registers, and the main op selection branch is unpredictable.
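To make this concrete, here's a self-contained toy version of the simple emulator above that handles only BRK, INC, and LIT. It's a sketch, not Raven's code. The ToyUxn type is hypothetical. The convention that the stack index points at the top byte and wraps at 256 is inferred from the disassembly later in this post. All state sits behind &mut self, and every opcode goes through the single match in op.

```rust
/// Toy, hypothetical version of the simple emulator: only BRK (0x00),
/// INC (0x01), and LIT (0x80) are implemented. The stack index points at
/// the current top byte and wraps at 256 (assumed from the disassembly).
struct ToyUxn {
    stack: [u8; 256],
    index: u8,
    ram: [u8; 65536],
}

impl ToyUxn {
    /// Load `program` at 0x0100, where Uxn ROMs begin.
    fn new(program: &[u8]) -> Self {
        let mut ram = [0u8; 65536];
        ram[0x100..0x100 + program.len()].copy_from_slice(program);
        ToyUxn { stack: [0; 256], index: 0, ram }
    }

    /// Read the byte at `pc` and advance `pc`, wrapping at 64K.
    fn next(&self, pc: &mut u16) -> u8 {
        let byte = self.ram[*pc as usize];
        *pc = pc.wrapping_add(1);
        byte
    }

    fn top(&self) -> u8 {
        self.stack[self.index as usize]
    }

    /// Execute one opcode; `None` means BRK was reached.
    fn op(&mut self, op: u8, mut pc: u16) -> Option<u16> {
        match op {
            0x00 => None, // BRK
            0x01 => {
                // INC: increment the top byte in place
                let i = self.index as usize;
                self.stack[i] = self.stack[i].wrapping_add(1);
                Some(pc)
            }
            0x80 => {
                // LIT: push the byte that follows the opcode
                let value = self.next(&mut pc);
                self.index = self.index.wrapping_add(1);
                self.stack[self.index as usize] = value;
                Some(pc)
            }
            _ => unimplemented!("opcode {:#04x} is not in this sketch", op),
        }
    }

    /// Central dispatch loop: fetch an opcode, then branch on it in `op`.
    fn run(&mut self, mut pc: u16) -> u16 {
        loop {
            let op = self.next(&mut pc);
            let Some(next) = self.op(op, pc) else {
                break pc;
            };
            pc = next;
        }
    }
}
```

Running the program LIT 0x41, INC, INC, BRK from 0x0100 leaves 0x43 on top of the stack and returns the program counter just past the BRK.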

Threaded code in assembly

In our assembly implementation, we can instead use threaded code (specifically token threading). We store all of the CPU state in registers, then end each instruction with a jump to the subsequent instruction:

; x0 | stack pointer
; x1 | stack index
; x4 | ram pointer
; x5 | program counter
; x8 | opcode table
_INC:
    ldrb w9, [x0, x1]           ; read the byte from the top of the stack
    add w9, w9, #1              ; increment it
    strb w9, [x0, x1]           ; write it back
    ldrb w9, [x4, x5]           ; load the next opcode from RAM
    add x5, x5, #1              ; increment the program counter
    and x5, x5, #0xffff         ; wrap the program counter
    ldr x10, [x8, x9, lsl #3]   ; load the opcode implementation address
    br x10                      ; jump to the opcode's implementation

This distributes the dispatch operation across every opcode, making it easier for the branch predictor to learn sequences of opcodes in the program. Overall speedups were significant: 40-50% faster on ARM64, and about 2× faster on x86-64.

Unfortunately, it requires maintaining about 2000 lines of code, and is incredibly unsafe. In my x86 port, I introduced an out-of-bounds write, which stomped on a few bytes outside of device RAM; the only symptom was that the fuzzer would segfault when exiting after running a very particular program.

So, what's to be done?

Tail calls in Rust

We'd like to get the same behavior as our assembly implementation – VM state stored in registers, dispatch at the end of each opcode – without hand-writing every instruction in assembly. Fortunately, there is hope!

The core idea has almost certainly been reinvented a bunch of times, but I first encountered tail-call interpreters in the Massey Meta Machine writeup, which was a mind-expanding read.

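There are two pieces:

- each opcode is implemented as its own function, with all of the VM state passed as arguments, so that the calling convention keeps it in registers
- each function ends by fetching the next opcode and calling that opcode's implementation through a function table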

We could write this today in stable Rust; here's our inc function:

const TABLE: FunctionTable = FunctionTable([
    brk,
    inc::<0b000>,
    pop::<0b000>,
    nip::<0b000>,
    swp::<0b000>,
    rot::<0b000>,
    dup::<0b000>,
    // ...etc
]);

fn inc<'a, const FLAGS: u8>(
    stack_data: &'a mut [u8; 256],
    stack_index: u8,
    rstack_data: &'a mut [u8; 256],
    rstack_index: u8,
    dev: &'a mut [u8; 256],
    ram: &'a mut [u8; 65536],
    pc: u16,
    vdev: &mut dyn Device,
) -> (Uxn<'a>, u16) {
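    // Rebuild a Uxn view from the unpacked arguments, so that we can
    // reuse the existing opcode implementation
    let mut core = Uxn {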
        stack: Stack {
            data: stack_data,
            index: stack_index,
        },
        ret: Stack {
            data: rstack_data,
            index: rstack_index,
        },
        dev,
        ram,
    };
    // Run the opcode; on success, fetch the next opcode and call its handler
    match core.inc::<FLAGS>(pc) {
        Some(mut pc) => {
            let op = core.next(&mut pc);
            TABLE.0[op as usize](
                core.stack.data,
                core.stack.index,
                core.ret.data,
                core.ret.index,
                core.dev,
                core.ram,
                pc,
                vdev,
            )
        }
        None => (core, pc)
    }
}

We want to reuse our existing Uxn opcode implementations, so we reconstruct the core: Uxn object at the beginning of the function, call its inc function, then deconstruct it again when calling the next operation. There's a lot of boilerplate, and it's tempting to just pass a &mut Uxn argument. However, the VM state would then live in memory behind that pointer, which removes the "state is stored in registers" benefit. We'll remove the boilerplate with a macro later on.

Unfortunately, there's a problem with this implementation:

thread 'snapshots::tailcall::mandelbrot' (67889685) has overflowed its stack
fatal runtime error: stack overflow, aborting
error: test failed, to rerun pass `-p raven-varvara --test snapshots`

Even in a release build, the compiler doesn't turn these calls into jumps: each executed operation adds another frame, so the stack gets deeper and deeper until it inevitably overflows.

We need to tell the compiler to generate a br (branch to register) instead of a blr (branch with link to register) instruction, and – more importantly – not to allocate any persistent space on the stack. In other words, we need a tail call.
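The stack growth is easy to see in a stable-Rust sketch of the same structure. The handlers below are hypothetical and cover only BRK, INC, and LIT. An extra depth argument counts how deeply the handler calls are nested. Each handler's last action is a call through the table, but stable Rust doesn't promise that this call becomes a jump. So depth tracks how many stack frames may be live at once, and it grows with every executed opcode.

```rust
/// Handler signature: (stack, stack index, RAM, pc, depth) -> (stack index, depth).
/// All VM state travels in arguments rather than behind a `&mut` pointer.
type Handler = fn(&mut [u8; 256], u8, &[u8], u16, usize) -> (u8, usize);

/// Hypothetical opcode table; opcodes not covered by this sketch panic.
static TABLE: [Handler; 256] = {
    let mut table = [unknown as Handler; 256];
    table[0x00] = brk as Handler;
    table[0x01] = inc as Handler;
    table[0x80] = lit as Handler;
    table
};

fn unknown(_: &mut [u8; 256], _: u8, _: &[u8], pc: u16, _: usize) -> (u8, usize) {
    panic!("opcode at {:#06x} is not in this sketch", pc.wrapping_sub(1))
}

/// BRK: stop; returning unwinds through every nested handler call.
fn brk(_: &mut [u8; 256], index: u8, _: &[u8], _: u16, depth: usize) -> (u8, usize) {
    (index, depth)
}

/// INC: increment the top byte, then dispatch.
fn inc(stack: &mut [u8; 256], index: u8, ram: &[u8], pc: u16, depth: usize) -> (u8, usize) {
    stack[index as usize] = stack[index as usize].wrapping_add(1);
    next(stack, index, ram, pc, depth)
}

/// LIT: push the byte after the opcode, skip over it, then dispatch.
fn lit(stack: &mut [u8; 256], index: u8, ram: &[u8], pc: u16, depth: usize) -> (u8, usize) {
    let index = index.wrapping_add(1);
    stack[index as usize] = ram[pc as usize];
    next(stack, index, ram, pc.wrapping_add(1), depth)
}

/// Fetch the opcode at `pc` and call its handler. The call is in tail
/// position, but without `become` it is not guaranteed to be a jump.
fn next(stack: &mut [u8; 256], index: u8, ram: &[u8], pc: u16, depth: usize) -> (u8, usize) {
    let op = ram[pc as usize];
    TABLE[op as usize](stack, index, ram, pc.wrapping_add(1), depth + 1)
}

/// Entry point: start at `pc` with stack index 0 and depth 0.
fn run(stack: &mut [u8; 256], ram: &[u8], pc: u16) -> (u8, usize) {
    next(stack, 0, ram, pc, 0)
}
```

A program of one LIT, 1000 INCs, and a BRK reaches a nesting depth of 1002, one level per executed opcode. A long-running program like Mandelbrot keeps growing the stack in the same way until it overflows.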

In nightly Rust, this is a one-word fix:

     match core.inc::<FLAGS>(pc) {
         Some(mut pc) => {
             let op = core.next(&mut pc);
-            TABLE.0[op as usize](
+            become TABLE.0[op as usize](
                 core.stack.data,
                 core.stack.index,
                 core.ret.data,
                 core.ret.index,
                 core.dev,
                 core.ram,
                 pc,
                 vdev,
             )

With this change in place, the Rust compiler makes a guarantee:

When tail calling a function, instead of its stack frame being added to the stack, the stack frame of the caller is directly replaced with the callee’s.

That's it, everything works! End of writeup!

Implementation details

Okay, okay, I've got a little more to say.

First, I promised a macro to eliminate the boilerplate. As always, it's a horrifying thing to behold:

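macro_rules! tail_fn {
    // Bare function, or function with flags: the device is only threaded
    // through to the next opcode, not passed to this opcode's method
    ($name:ident $(::<$flags:ident>)?) => {
        tail_fn!($name $(::<$flags>)?[][vdev: &mut dyn Device]);
    };
    // Function with extra arguments, which are passed to the opcode's method
    ($name:ident $(::<$flags:ident>)?($($arg:ident: $ty:ty),*)) => {
        tail_fn!($name $(::<$flags>)?[$($arg: $ty),*][]);
    };
    // Shared implementation:
    // [arguments passed to the method][arguments only passed through]
    ($name:ident $(::<$flags:ident>)?[$($arg0:ident: $ty0:ty),*][$($arg1:ident: $ty1:ty),*]) => {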
        extern "rust-preserve-none" fn $name<'a, $(const $flags: u8)?>(
            stack_data: &'a mut [u8; 256],
            stack_index: u8,
            rstack_data: &'a mut [u8; 256],
            rstack_index: u8,
            dev: &'a mut [u8; 256],
            ram: &'a mut [u8; 65536],
            pc: u16,
            $($arg0: $ty0),*
            $($arg1: $ty1),*
        ) -> (UxnCore<'a>, u16) {
            let mut core = UxnCore {
                stack: Stack {
                    data: stack_data,
                    index: stack_index,
                },
                ret: Stack {
                    data: rstack_data,
                    index: rstack_index,
                },
                dev,
                ram,
            };
            match core.$name::<$($flags)?>(pc, $($arg0),*) {
                Some(mut pc) => {
                    let op = core.next(&mut pc);
                    become TABLE.0[op as usize](
                        core.stack.data,
                        core.stack.index,
                        core.ret.data,
                        core.ret.index,
                        core.dev,
                        core.ram,
                        pc,
                        $($arg0),*
                        $($arg1),*
                    )
                }
                None => (core, pc),
            }
        }
    };
}

(This is now from the actual implementation, so some types are slightly different than the simplified code earlier in this writeup)

The macro is very awkward, but it lets us declare all three kinds of functions:

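Wait, no, wrong list; the three kinds of functions are:

- bare functions, like BRK
- functions with flags, like INC
- functions with flags and extra arguments, like DEI, which takes the device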

Here's what all three look like:

tail_fn!(brk); // bare function
tail_fn!(inc::<FLAGS>); // function with flags
tail_fn!(dei::<FLAGS>(dev: &mut dyn Device)); // flags and arguments

You don't need to spend much time puzzling over the macro; we're firmly in "if it compiles, it works" territory here. It's also worth noting that this is still 100% safe Rust: our #![forbid(unsafe_code)] attribute remains untriggered.

Codegen notes

The compiler does a good job of inlining and stripping functions down to their essential operations; the boilerplate of constructing and deconstructing the UxnCore is fully optimized out.

0000000100039c50 <raven_uxn::tailcall::inc::<0>>:
    and     x8, x0, #0xff           ; mask stack index
    ldrb    w9, [x21, x8]           ; read byte
    add     w9, w9, #1              ; increment value
    strb    w9, [x21, x8]           ; write back value
    and     x8, x2, #0xffff         ; mask pc
    ldrb    w8, [x24, x8]           ; look up next opcode
    adrp    x9, 0x100279000         ; load table base
    add     x9, x9, #3776           ; offset??
    ldr     x3, [x9, x8, lsl  #3]   ; get jump target
    add     w2, w2, #1              ; increment program counter
    br      x3                      ; jump!

I see two main differences from our hand-written implementation:

We could fix the latter, where the jump table's address is rematerialized with adrp + add in every opcode, by also threading the table through the argument list. This improves performance on x86 (but doesn't seem to matter on ARM64).
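Here's a stable-Rust sketch of that idea, with hypothetical names and covering only BRK, INC, and LIT. Each handler receives the table as its first argument and passes it along, instead of loading it from a fixed static address. The handler type mentions the table, and the table contains handlers, so the table is wrapped in a newtype struct to make the recursive type legal. This only shows the shape. Whether it helps depends on the backend and calling convention.

```rust
/// A handler takes the table it was dispatched from, plus the VM state.
type Handler = fn(&Table, &mut [u8; 256], u8, &[u8], u16) -> u8;

/// Newtype wrapper: lets `Handler` refer to `Table` and vice versa.
struct Table([Handler; 256]);

impl Table {
    fn new() -> Self {
        // Opcodes not covered by this sketch fall through to `unknown`
        let mut t = [unknown as Handler; 256];
        t[0x00] = brk as Handler;
        t[0x01] = inc as Handler;
        t[0x80] = lit as Handler;
        Table(t)
    }
}

fn unknown(_: &Table, _: &mut [u8; 256], _: u8, _: &[u8], pc: u16) -> u8 {
    panic!("opcode at {:#06x} is not in this sketch", pc.wrapping_sub(1))
}

/// BRK: stop and return the final stack index.
fn brk(_: &Table, _: &mut [u8; 256], index: u8, _: &[u8], _: u16) -> u8 {
    index
}

/// INC: increment the top byte, then dispatch.
fn inc(table: &Table, stack: &mut [u8; 256], index: u8, ram: &[u8], pc: u16) -> u8 {
    stack[index as usize] = stack[index as usize].wrapping_add(1);
    next(table, stack, index, ram, pc)
}

/// LIT: push the following byte, skip it, then dispatch.
fn lit(table: &Table, stack: &mut [u8; 256], index: u8, ram: &[u8], pc: u16) -> u8 {
    let index = index.wrapping_add(1);
    stack[index as usize] = ram[pc as usize];
    next(table, stack, index, ram, pc.wrapping_add(1))
}

/// Look up the next handler through the `table` argument, not a static.
fn next(table: &Table, stack: &mut [u8; 256], index: u8, ram: &[u8], pc: u16) -> u8 {
    let op = ram[pc as usize];
    table.0[op as usize](table, stack, index, ram, pc.wrapping_add(1))
}
```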

Performance results

Speaking of performance, how does it do?

I've got two main benchmarks:

On my laptop (M1 MacBook), I'm pleased to report that I'm no longer beating the compiler: the tail-call interpreter handily beats my hand-written assembly on both benchmarks.

Backend     Fibonacci   Mandelbrot
VM          2.41        125
Assembly    1.32        87
Tailcall    1.19        76

(all times in milliseconds, smaller is faster)

Now, let's take a big sip of seasonally-inappropriate tea and test on x86—

Backend     Fibonacci   Mandelbrot
VM          4.70        264
Assembly    1.84        168
Tailcall    3.23        175

oh no. It's outperforming the VM, but is still losing to the assembly backend by a noticeable amount, especially in the Fibonacci microbenchmark (about 1.75× slower there, versus about 4% slower on Mandelbrot). What's going on here?

Let's start by looking at the generated code for INC, our simplest opcode:

_RINvNtCs7kxjQDHw3ed_9raven_uxn8tailcall3incKh0_EB4_:
    movzx  eax,dil                  ; mask stack index
    inc    BYTE PTR [r13+rax]       ; increment byte on stack
    movzx  eax,cx                   ; mask pc register
    movzx  eax,BYTE PTR [rdx+rax]   ; read next opcode from RAM
    inc    ecx                      ; increment pc register
    mov    rax,QWORD PTR [r8+rax*8] ; read jump target from table
    jmp    rax                      ; jump to target
    ; followed by `int3` operations to pad to 32-byte alignment

This implementation looks broadly fine: it's doing the minimal number of reads and writes, and is basically what I'd expect. Indeed, incrementing the byte directly in memory (with a single inc instruction) may be more efficient than my assembly, which loads and stores that byte.

(Also, I didn't harp on this before, but declaring these functions as extern "rust-preserve-none" is very important for the x86 implementation. The default calling convention doesn't have enough argument registers for all of our arguments, so the rest are passed on the stack, which adds tremendous amounts of overhead)

So, INC is inoffensive. Let's now look at ADD2, which adds the top two 16-bit values on the stack. I'll start by showing you the hand-tuned implementation from the assembly backend:

_ADD2:
    movzx  eax,BYTE PTR [rbx+r12*1]     ; read byte from data stack
    dec    r12b                         ; decrement data pointer
    movzx  ecx,BYTE PTR [rbx+r12*1]     ; read byte from data stack
    dec    r12b                         ; decrement data pointer
    shl    ecx,0x8                      ; build 16-bit value
    or     eax,ecx
    movzx  ecx,BYTE PTR [rbx+r12*1]     ; read byte from data stack
    lea    r11,[r12-0x1]                ; get next byte address
    and    r11,0xff                     ; mask index byte
    movzx  edx,BYTE PTR [rbx+r11*1]     ; read byte from data stack
    shl    edx,0x8                      ; build 16-bit value
    or     ecx,edx
    add    ecx,eax                      ; do addition
    mov    BYTE PTR [rbx+r12*1],cl      ; write byte to data stack
    shr    ecx,0x8                      ; shift 16-bit value
    mov    BYTE PTR [rbx+r11*1],cl      ; write byte to data stack
    movzx  eax,BYTE PTR [r15+rbp*1]     ; read next opcode
    inc    bp                           ; increment pc
    lea    rcx,[rip+0x22e5c7]           ; load jump table address
    jmp    QWORD PTR [rcx+rax*8]        ; jump

This implementation is 79 bytes, and does the bare minimum number of reads and writes: 4 byte reads + 2 byte writes to the data stack, one byte read from RAM (to get the next opcode), and one qword read from the jump table (to get the jump target).

In contrast, here's the compiled tailcall implementation:

_RINvNtCs7kxjQDHw3ed_9raven_uxn8tailcall3addKh1_EB4_:
    push   rbp                      ; spill rbp to the stack
    mov    QWORD PTR [rsp-0x8],r11  ; spill r11 to the stack too?
    mov    r11,r9                   ; do the Register Shuffle (??)
    mov    r9,r8
    mov    r8d,ecx
    mov    rax,r13
    movzx  r10d,dil                 ; get index as byte
    movzx  edi,BYTE PTR [r13+r10]   ; read byte from stack
    lea    ebx,[r10-0x1]            ; get next index
    movzx  ebx,bl                   ; mask index to byte
    lea    r13d,[r10-0x2]           ; precompute another index
    movzx  ebp,BYTE PTR [rax+rbx*1] ; read byte from stack
    shl    ebp,0x8                  ; build 16-bit value
    or     ebp,edi
    movzx  edi,r13b                 ; mask index byte
    movzx  r13d,BYTE PTR [rax+rdi]  ; read byte from stack
    add    r10b,0xfd                ; compute another address
    movzx  ecx,r10b                 ; mask index byte
    movzx  ebx,BYTE PTR [rax+rcx]   ; read byte from stack
    shl    ebx,0x8                  ; build 16-bit value
    or     ebx,r13d
    add    ebx,ebp                  ; do addition
    mov    BYTE PTR [rax+rcx*1],bh  ; write byte to stack
    mov    BYTE PTR [rax+rdi*1],bl  ; write byte to stack
    movzx  ecx,r8w                  ; mask pc word
    movzx  ecx,BYTE PTR [rdx+rcx*1] ; read opcode
    inc    r8d                      ; increment pc
    mov    r10,QWORD PTR [r9+rcx*8] ; get next jump address
    mov    r13,rax                  ; do the Reverse Register Shuffle (??)
    mov    ecx,r8d
    mov    r8,r9
    mov    r9,r11
    mov    r11,QWORD PTR [rsp-0x8]  ; restore spilled r11
    mov    rax,r10
    pop    rbp                      ; restore pushed rbp
    jmp    rax                      ; jump
    ; followed by `int3` padding to a 32-byte boundary

This is 121 bytes, and – more concerningly – spills and restores two full 64-bit registers to the stack. This looks like (to use a technical term) real bad codegen. On one hand, this is an unfinished nightly feature in rustc, so it's understandable; on the other hand, I'm surprised this isn't well optimized by the LLVM backend, even if the rustc side is immature.
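For reference, here's what both ADD2 listings compute, written as plain Rust. This is a sketch built on assumptions inferred from the disassembly, not Raven's implementation. It assumes the index points at the top byte and wraps at 256, that 16-bit values are stored big-endian (high byte deeper in the stack), and it ignores the k and r modes. It makes the minimal memory traffic explicit: four byte reads and two byte writes.

```rust
/// Hypothetical sketch of plain ADD2 on the data stack; returns the new index.
fn add2(stack: &mut [u8; 256], index: u8) -> u8 {
    // Stack slot `offset` bytes below the top, wrapping at 256
    let at = |offset: u8| index.wrapping_sub(offset) as usize;
    // Read the top 16-bit value (b), then the one beneath it (a)
    let b = u16::from_be_bytes([stack[at(1)], stack[at(0)]]);
    let a = u16::from_be_bytes([stack[at(3)], stack[at(2)]]);
    // Write a + b (wrapping) where `a` was; its low byte is the new top
    let [hi, lo] = a.wrapping_add(b).to_be_bytes();
    stack[at(3)] = hi;
    stack[at(2)] = lo;
    index.wrapping_sub(2)
}
```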

This blog post is getting unwieldy, so I won't speculate too much further, but I will make a few observations:

One more thing

WebAssembly also supports tail calls, and Raven supports compilation to WebAssembly! How does the tail-call interpreter fare there, compared to both native performance and the simple VM interpreter?

I can only benchmark the Mandelbrot example; the Fibonacci program is too fast (given limitations on web timer resolution). Here are the numbers:

Engine      Backend     Time (ms)
Native      VM          125
Native      Tailcall    76
Firefox     VM          264
Firefox     Tailcall    311
Chrome      VM          244
Chrome      Tailcall    905
wasmtime    VM          128
wasmtime    Tailcall    595

Surprise, it's terrible: 1.2× slower on Firefox, 3.7× slower on Chrome, and 4.6× slower in wasmtime. I guess patterns which generate good assembly don't map well to the WASM stack machine, and the JITs aren't smart enough to lower it to optimal machine code.

wasmtime did manage impressive performance on the traditional VM implementation though — within a few percentage points of the native Rust build!

(All of these tests were on my M1 Max laptop, using wasmtime built from e9e1665c5, Firefox 149.0, and Chrome 146.0.7680.178)

Conclusion

The tailcall interpreter PR is merged, and has been deployed in the 0.3.0 release. When enabled, it's the default on ARM64 systems, and the second choice on x86-64 systems (if the native feature is not enabled).

I'd be very curious to get tips on improving x86 and WASM performance;