Try   HackMD

RVVM JIT compiler

Tracing JIT 基於以下假設:

  1. 程式的執行時間大部份都花在迴圈
  2. 迴圈每次都有相似的執行路徑

Tracing JIT 有下面幾個步驟:

  1. profile
    找出程式中執行頻繁的熱點,記錄 jump 發生的次數,超過一定的閾值就進入追蹤階段
  2. tracing
    當程式執行到熱點的時候紀錄執行的指令,這裡要考慮的情況很多,像是如下 case
    • 巢狀迴圈、分支
    • break / continue
    • 結合上述兩點的情況
      以下論文提到藉由 Trace tree 跟 Control flow graph (CFG) 來紀錄指令,兩者相同的地方都是透過走訪來得到 Hot path。
  3. compile and optimize
    簡單的方式是逐一指令以 codegen 產生機械碼
    RVVM
    複雜一點的則是生成 IR 來優化再編譯成機械碼
    MIR
  4. execute code
    當程式執行到同樣的地方時,換成執行 JIT 編譯後的機械碼

RVVM 為 RISC-V 系統模擬器,具備 Tracing JIT 的功能,又號稱比 QEMU 快 10 倍,這是怎麼一回事?
RVVM 的程式碼中也沒看到 CFG 跟 Trace tree,那又是怎麼處理分支跟巢狀迴圈?

真相就是 RVVM 直接跳過 profile,從運行開始就在做 tracing 跟 codegen,直到遇到 jump 或是 branch 再檢查 tracing 到的 block 是否超過最大長度,如果超過就結束編譯並把 block 放進快取,等到下次執行到同樣位置時,就直接呼叫 JIT 編譯後的機械碼,利用犧牲空間巧妙的迴避掉複雜的 case。

RVVM/cpu 的目錄裡包含所有 RISC-V 的指令處理,每道指令都有它的解碼跟實作,而且每道指令的實作都以 rvjit_{inst}(rds, rs1, rs2, 4); 來紀錄要執行的指令。

接下來分別以一般指令跟跳躍/分支指令為例來說明 RVVM 怎麼 tracing

一般指令

以下說明 addi 這道指令,其他指令有相似處,差在 load/store 跟 jump/branch 是不同的寫法

// risc_i.c
static void riscv_i_addi(rvvm_hart_t *vm, const uint32_t instruction)
{
    /*
     * ADDI: rds = rs1 + sign-extended 12-bit immediate.
     * Field layout used here: rds at bits 7..11, rs1 at bits 15..19,
     * immediate at bits 20..31.
     */
    const regid_t dst = bit_cut(instruction, 7, 5);
    const regid_t src = bit_cut(instruction, 15, 5);
    const sxlen_t offset = sign_extend(bit_cut(instruction, 20, 12), 12);

    // Source value is fetched before the JIT hook, as the hook may return early
    const xlen_t base = riscv_read_register(vm, src);

    // JIT hook: either runs cached code and returns, or records this addi (4-byte instruction)
    rvjit_addi(dst, src, offset, 4);

    // Interpreter path
    riscv_write_register(vm, dst, base + offset);
}

追蹤 rvjit_addi 可以看到 RVVM_RVJIT_TRACE,並把指令跟 size 當作參數傳入,如果 jit 正在編譯,就更新 jit 的 pc offset,並標記目前的 block 尚未結束 (block_ends = false)

// riscv_cpu.h
/*
 * JIT hook for addi, used inside an instruction handler that has `vm` in scope.
 * Forwards the 32-bit emitter call and the instruction size to RVVM_RVJIT_TRACE.
 * The last line deliberately has no trailing backslash, so the macro ends here
 * and does not swallow whatever line follows it.
 */
#define rvjit_addi(rds, rs1, imm, size) \
RVVM_RVJIT_TRACE(rvjit32_addi(&vm->jit, rds, rs1, imm), size)
// riscv_cpu.h
/*
 * Trace hook for ordinary instructions. Must be expanded inside a void
 * instruction handler that has `vm` in scope (it uses `return`).
 *
 * intrinsic - emitter call to run while a block is being compiled
 * inst_size - size of the current instruction in bytes
 *
 * 1. When no block is being compiled, riscv_jit_tlb_lookup() is tried. If it
 *    reports that JIT code ran, PC is decremented by inst_size and the handler
 *    returns; riscv_emulate() adds the instruction size back after the handler,
 *    so the two cancel out.
 * 2. jit_compiling is tested again on purpose (not as an `else`): the lookup
 *    can reach riscv_jit_lookup(), which may start a new block, set
 *    jit_compiling and return false. In that case this very instruction is
 *    recorded as the first one of the new block.
 * 3. While compiling, the emitter runs, pc_off advances by inst_size and
 *    block_ends is cleared so riscv_emulate() does not finalize the block.
 */
#define RVVM_RVJIT_TRACE(intrinsic, inst_size) \
do { \
    if (!vm->jit_compiling && riscv_jit_tlb_lookup(vm)) { \
        vm->registers[REGISTER_PC] -= inst_size; \
        return; \
    } \
    if (vm->jit_compiling) { \
        intrinsic; \
        vm->jit.pc_off += inst_size; \
        vm->block_ends = false; \
    } \
} while (0)

來看 rvjit32_addi 如何實作: RVVM 支援 32 跟 64 位元的指令,因此要分別實作 rvjit32_addi 與 rvjit64_addi,其中 RVJIT32_IMM_INC 便能產生 32 位元的指令及其實作。

// rvjit_emit.c
// Instantiates the addi emitters through RVJIT_IMM_INC (rvjit32_addi via RVJIT32_IMM_INC, plus whatever RVJIT64_IMM_INC generates)
RVJIT_IMM_INC(addi)
    
/*
 * Expands to RVJIT32_IMM_INC(instr) followed by RVJIT64_IMM_INC(instr).
 * RVJIT64_IMM_INC is defined elsewhere; presumably it generates the
 * rvjit64_<instr> counterpart - confirm against its definition.
 */
#define RVJIT_IMM_INC(instr) \
RVJIT32_IMM_INC(instr) \
RVJIT64_IMM_INC(instr)

// rvjit_emit.c
/*
 * Defines the function
 *     void rvjit32_<instr>(rvjit_block_t* block, regid_t rds, regid_t rs1, int32_t imm)
 * whose body first applies RVJIT32_IMM_INC_OPTIMIZE(rds, rs1, imm) and then
 * RVJIT_2REG_IMM_OP with the backend emitter rvjit32_native_<instr>.
 *
 * Both helper macros are defined elsewhere.
 * NOTE(review): `block` is never passed to them explicitly, so they presumably
 * pick it up by name from the generated function - keep the parameter names
 * as they are unless the helpers are checked.
 */
#define RVJIT32_IMM_INC(instr) \
void rvjit32_##instr(rvjit_block_t* block, regid_t rds, regid_t rs1, int32_t imm) \
{ \
    RVJIT32_IMM_INC_OPTIMIZE(rds, rs1, imm); \
    RVJIT_2REG_IMM_OP(rvjit32_native_##instr, rds, rs1, imm); \
}

addi 帶入到巨集會產生 rvjit32_native_addi,這對應到各個宿主處理器架構 (像是 arm, x86, risc-v) 各自的 addi 實作,負責產生該架構的機械碼。以 x86 為例:

// rvjit_x86.h
/*
 * x86 backend for the 32-bit addi: emits hrds = hrs1 + imm using the
 * generic two-register/immediate emitter with X86_ADD_IMM.
 */
static inline void rvjit32_native_addi(rvjit_block_t* block, regid_t hrds, regid_t hrs1, int32_t imm)
{
    // 32-bit operation, so the emitter is told not to use 64-bit operands
    const bool use_64bit = false;

    rvjit_x86_2reg_imm_op(block, X86_ADD_IMM, hrds, hrs1, imm, use_64bit);
}

// generate machine code
/*
 * Emit x86 code for "hrds = hrs1 <op> imm".
 *
 * block   - block receiving the generated code
 * opcode  - x86 immediate-form opcode (X86_AND_IMM and X86_ADD_IMM get special handling)
 * hrds    - host destination register
 * hrs1    - host source register
 * imm     - 32-bit immediate operand
 * bits_64 - whether 64-bit operand size is requested
 *
 * Generic form: an optional mov hrds, hrs1 followed by an optional
 * "<op> hrds, imm" (skipped when imm is zero).
 */
static inline void rvjit_x86_2reg_imm_op(rvjit_block_t* block, uint8_t opcode, regid_t hrds, regid_t hrs1, int32_t imm, bool bits_64)
{
    const bool needs_move = (hrds != hrs1);

    if (opcode == X86_AND_IMM) {
        if (imm == 0) {
            // andi r1, r2, 0 always yields zero -> xor r1, r1
            rvjit_x86_2reg_op(block, X86_XOR, hrds, hrds, false);
            return;
        }
        if (imm == 0xFF && x86_byte_reg_usable(hrs1)) {
            // andi r1, r2, 0xFF -> movzxb r1, r2
            rvjit_x86_movzxb(block, hrds, hrs1);
            return;
        }
        if (imm > 0) {
            // Positive mask: drop the REX.W prefix for unsigned andi imm
            bits_64 = false;
        }
    } else if (opcode == X86_ADD_IMM) {
        if (imm != 0 && needs_move) {
            // addi r1, r2, imm -> lea r1, [r2 + imm]
            rvjit_x86_lea_addi(block, hrds, hrs1, imm, bits_64);
            return;
        }
    }

    if (needs_move) {
        rvjit_x86_mov(block, hrds, hrs1, bits_64);
    }
    if (imm != 0) {
        rvjit_x86_r_imm_op(block, opcode, hrds, imm, bits_64);
    }
}

假設經過 if-else 最後是執行 rvjit_x86_lea_addi 便會開始產生機械碼,最後再透過 rvjit_put_code 把指令附加 (append) 到 block

/*
 * Emit "lea dest, [src + imm]".
 *
 * An optional prefix byte is assembled from X64_REX_W (when bits_64 is set),
 * X64_REX_B (src >= X64_R8) and X64_REX_R (dest >= X64_R8). It is emitted only
 * when non-zero, followed by opcode byte 0x8d. The memory operand is then
 * appended by rvjit_x86_memory_ref().
 */
static inline void rvjit_x86_lea_addi(rvjit_block_t* block, regid_t dest, regid_t src, int32_t imm, bool bits_64)
{
    uint8_t rex = bits_64 ? X64_REX_W : 0;

    if (src >= X64_R8) {
        rex |= X64_REX_B;
    }
    if (dest >= X64_R8) {
        rex |= X64_REX_R;
    }

    if (rex) {
        // Prefix needed: emit prefix + opcode
        const uint8_t prefixed[2] = { rex, 0x8d };
        rvjit_put_code(block, prefixed, 2);
    } else {
        // No prefix: emit the bare opcode byte
        const uint8_t bare = 0x8d;
        rvjit_put_code(block, &bare, 1);
    }

    rvjit_x86_memory_ref(block, dest, src, imm);
}

/*
 * Append `size` bytes from `inst` to the end of the block's code buffer.
 *
 * block - block whose code buffer (code / size / space) is extended
 * inst  - bytes to append
 * size  - number of bytes to append
 *
 * The buffer is grown in 1024-byte steps until the new data fits. A single
 * fixed step is not enough when `size` exceeds the step plus the remaining
 * room, which would let memcpy write past the allocation.
 */
static inline void rvjit_put_code(rvjit_block_t* block, const void* inst, size_t size)
{
    const size_t needed = block->size + size;

    if (unlikely(block->space < needed)) {
        do {
            block->space += 1024;
        } while (block->space < needed);
        block->code = safe_realloc(block->code, block->space);
    }
    memcpy(block->code + block->size, inst, size);
    block->size += size;
}

跳躍/分支指令

遇到跳躍或分支的時候會檢查 block 的大小是否超過 BRANCH_MAX_BLOCK_SIZE,而 jit 的 pc_off 則是加上跳躍的 offset (而不是指令長度),也就是指向要跳躍過去的位址

跳躍

// Limit compared against vm->jit.size by the JAL and branch trace hooks; once exceeded, block_ends is set.
// NOTE(review): presumably measured in bytes of emitted code - confirm against rvjit_block_t.
#define BRANCH_MAX_BLOCK_SIZE 250

/*
 * Trace hook for JAL. Must be expanded inside a void instruction handler
 * that has `vm` in scope (it uses `return`).
 *
 * intrinsic - emitter call to run while a block is being compiled
 * offset    - jump offset; applied to pc_off instead of the instruction size,
 *             so tracing continues at the jump target
 * inst_size - size of the JAL instruction, used only to undo the PC increment
 *             that follows the handler when cached JIT code was executed
 *
 * Page crossing is already checked in riscv_emulate().
 * Unlike the plain trace hook, block_ends is not simply cleared: the block is
 * marked as ended once vm->jit.size exceeds BRANCH_MAX_BLOCK_SIZE.
 * jit_compiling is re-tested after the lookup because the lookup may have
 * started a new block.
 */
#define RVVM_RVJIT_TRACE_JAL(intrinsic, offset, inst_size) \
do { \
    if (!vm->jit_compiling && riscv_jit_tlb_lookup(vm)) { \
        vm->registers[REGISTER_PC] -= inst_size; \
        return; \
    } \
    if (vm->jit_compiling) { \
        intrinsic; \
        vm->jit.pc_off += offset; \
        vm->block_ends = vm->jit.size > BRANCH_MAX_BLOCK_SIZE; \
    } \
} while (0)

/*
 * Compile hook for JALR (indirect jump). Expects `vm` in scope.
 *
 * The block ends immediately at an indirect jump, so nothing is traced:
 * there is no JTLB lookup, pc_off is not advanced and block_ends is left
 * untouched. riscv_emulate() sets block_ends to true before dispatching, so
 * it stays true and the block is finalized on the next riscv_emulate() call.
 * The emitter only runs while a block is being compiled.
 */
#define RVVM_RVJIT_COMPILE_JALR(intrinsic) \
do { \
    if (vm->jit_compiling) { \
        intrinsic; \
    } \
} while (0)

分支

/*
 * Trace hook for conditional branches. Must be expanded inside a void
 * instruction handler that has `vm` in scope (it uses `return`).
 * Branches taken in the interpreter are treated as likely branches and inlined.
 *
 * intrinsic      - emitter call to run while a block is being compiled
 * target_off     - pc offset of the path the trace continues on
 * falthrough_off - pc offset of the other path
 * inst_size      - size of the branch instruction, used only to undo the PC
 *                  increment that follows the handler when JIT code was run
 *
 * pc_off is advanced by falthrough_off while the emitter runs and is then
 * corrected, so the net advance is target_off.
 * NOTE(review): presumably the emitter uses the intermediate pc_off for the
 * exit taken when the branch goes the other way - confirm in the emitter.
 *
 * The block is marked as ended once vm->jit.size exceeds BRANCH_MAX_BLOCK_SIZE.
 * Macro arguments are parenthesized so that expression arguments (for example
 * `a + b` as falthrough_off) cannot change the meaning of the subtraction.
 */
#define RVVM_RVJIT_BRANCH(intrinsic, target_off, falthrough_off, inst_size) \
do { \
    if (!vm->jit_compiling && riscv_jit_tlb_lookup(vm)) { \
        vm->registers[REGISTER_PC] -= (inst_size); \
        return; \
    } \
    if (vm->jit_compiling) { \
        vm->jit.pc_off += (falthrough_off); \
        intrinsic; \
        vm->jit.pc_off += ((target_off) - (falthrough_off)); \
        vm->block_ends = vm->jit.size > BRANCH_MAX_BLOCK_SIZE; \
    } \
} while (0)

執行 JIT

RVVM 執行指令的函式為 riscv_emulate。其中比較 virt_pc 跟 register 的 pc (兩者都先右移 PAGE_SHIFT) 是為了檢查目前的 pc 是否已經離開 block 起始位址所在的 page (也就是程式碼註解中的 cross page boundaries);只要不執行到 riscv_jit_finalize 那麼就會繼續編譯,如果進行 finalize 的話就會把 tracing 到的 block 放到 vm->jtlb 當作快取

// riscv_cpu.c
/*
 * Execute one instruction on the hart.
 *
 * With USE_JIT and a block under compilation, the block is finalized first
 * when the previous instruction left block_ends set or when PC has moved to a
 * different page than the block start. block_ends is then set pessimistically;
 * a handler whose trace hook records the instruction clears it again.
 *
 * The instruction is dispatched through the compressed or the regular decoder
 * table, and PC is advanced by 2 or 4 afterwards.
 */
static inline void riscv_emulate(rvvm_hart_t *vm, uint32_t instruction)
{
#ifdef USE_JIT
    if (unlikely(vm->jit_compiling)) {
        const bool left_page =
            (vm->jit.virt_pc >> PAGE_SHIFT) != (vm->registers[REGISTER_PC] >> PAGE_SHIFT);

        // A non-compilable instruction or a page crossing closes the block
        if (vm->block_ends || left_page) {
            riscv_jit_finalize(vm);
        }
        vm->block_ends = true;
    }
#endif
    const bool full_width = (instruction & RV_OPCODE_MASK) == RV_OPCODE_MASK;

    // FYI: Any jump instruction implementation should take care of PC increment
    if (full_width) {
        vm->decoder.opcodes[riscv_funcid(instruction)](vm, instruction);
        vm->registers[REGISTER_PC] += 4;
    } else {
        vm->decoder.opcodes_c[riscv_c_funcid(instruction)](vm, instruction);
        vm->registers[REGISTER_PC] += 2;
    }
}

如果 block 被 finalize,則透過 riscv_jit_tlb_put 將位置跟 block (機械碼) 放到 jtlb 當作快取;如果 rvjit_block_finalize 沒有回傳 block (快取滿了),就把快取清掉。不論哪一種情況,最後都會結束編譯。

/*
 * Close the block currently being compiled and leave compile mode.
 *
 * A non-empty block is finalized. On success it is cached in the JTLB under
 * the block's starting virtual PC. When finalization yields no function, the
 * JTLB and the JIT code cache are both flushed. jit_compiling is cleared in
 * every case.
 */
static void riscv_jit_finalize(rvvm_hart_t* vm)
{
    if (rvjit_block_nonempty(&vm->jit)) {
        rvjit_func_t compiled = rvjit_block_finalize(&vm->jit);

        if (!compiled) {
            // Our cache is full, flush it
            riscv_jit_tlb_flush(vm);
            rvjit_flush_cache(&vm->jit);
        } else {
            riscv_jit_tlb_put(vm, vm->jit.virt_pc, compiled);
        }
    }

    vm->jit_compiling = false;
}

/*
 * Cache a compiled block in the JTLB.
 *
 * vaddr - virtual PC the block starts at; also stored as the entry's tag
 * block - compiled function to call for that PC
 *
 * The slot is chosen from vaddr shifted right by one and masked with
 * TLB_MASK. Whatever was in the slot before is overwritten.
 */
static inline void riscv_jit_tlb_put(rvvm_hart_t* vm, vaddr_t vaddr, rvjit_func_t block)
{
    const vaddr_t slot = (vaddr >> 1) & TLB_MASK;

    vm->jtlb[slot].pc = vaddr;
    vm->jtlb[slot].block = block;
}

假設已經紀錄好 addi,接著下一道指令是 and,執行 and 的時候會呼叫 rvjit_and,如果還沒執行 finalize,就會紀錄指令到 tracing,如果已經 finalize,vm->jit_compiling 會變 false,這時候會呼叫 riscv_jit_tlb_lookup

/*
 * Annotated copy of the trace hook for ordinary instructions. Must be expanded
 * inside a void instruction handler that has `vm` in scope.
 *
 * Comments inside the macro body are block comments followed by a
 * continuation backslash; a `//` line without a backslash would end the
 * macro after `do {`.
 */
#define RVVM_RVJIT_TRACE(intrinsic, inst_size) \
do { \
    /* Not compiling: try to run JIT-compiled machine code via riscv_jit_tlb_lookup */ \
    if (!vm->jit_compiling && riscv_jit_tlb_lookup(vm)) { \
        vm->registers[REGISTER_PC] -= inst_size; \
        return; \
    } \
    /* Compiling (possibly just started by the lookup): keep recording instructions */ \
    if (vm->jit_compiling) { \
        intrinsic; \
        vm->jit.pc_off += inst_size; \
        vm->block_ends = false; \
    } \
} while (0)

當呼叫 riscv_jit_tlb_lookup 時會去找 jtlb 的快取有沒有 jit 編譯的機械碼,有的話就把 vm 傳入 block 當作參數來執行機械碼。

/*
 * Try to run JIT-compiled code for the current PC out of the JTLB.
 *
 * Cached blocks are executed back to back for as long as the PC left behind
 * by one block hits the JTLB again, up to 11 blocks per call.
 *
 * Returns false when JIT is disabled. When the very first probe misses, the
 * result of riscv_jit_lookup() is returned. Returns true once at least one
 * cached block has been executed.
 *
 * NOTE(review): the slot is masked with (TLB_SIZE - 1) here while
 * riscv_jit_tlb_put() uses TLB_MASK - confirm the two are equal.
 */
static inline bool riscv_jit_tlb_lookup(rvvm_hart_t* vm)
{
    if (unlikely(!vm->jit_enabled)) {
        return false;
    }

    // Try to find & execute a block, then chain into the next one
    for (size_t chained = 0;; chained++) {
        const vaddr_t cur_pc = vm->registers[REGISTER_PC];
        const vaddr_t slot = (cur_pc >> 1) & (TLB_SIZE - 1);

        if (unlikely(vm->jtlb[slot].pc != cur_pc)) {
            if (chained == 0) {
                // Nothing executed yet: fall back to the slow lookup
                return riscv_jit_lookup(vm);
            }
            return true;
        }

        vm->jtlb[slot].block(vm);

        // Bound the chain so control returns to the caller
        if (unlikely(chained >= 10)) {
            return true;
        }
    }
}

riscv_jit_lookup 則是在 pc 不等於 tpc 時,先把 virtual address 轉成 physical address,再用 physical address 到 hashmap 找看看有沒有編譯過的 block;找到的話就放進 jtlb 並執行,如果這樣也找不到就初始化一個新的 block 並開始記錄執行的指令,如此便能夠逐步進行 JIT 編譯

/*
 * Slow path of the JIT lookup, taken when the JTLB has no entry for the PC.
 *
 * The current virtual PC is translated to a physical address and the block
 * hashmap is consulted. A hit is cached in the JTLB, executed, and reported
 * with true. On a miss a fresh block is started at this PC and compile mode
 * is switched on; the function then returns false so the caller carries on
 * interpreting (and tracing) the current instruction. false is also returned
 * when the translation yields no pointer.
 */
NOINLINE bool riscv_jit_lookup(rvvm_hart_t* vm)
{
    /*
     * Translate virtual address into physical.
     * We are tracing address already fetched from,
     * thus a pagefault isn't possible
     */
    const vaddr_t cur_pc = vm->registers[REGISTER_PC];
    vmptr_t host_ptr = riscv_vma_translate_e(vm, cur_pc);

    if (!host_ptr) {
        return false;
    }

    // Lookup in the hashmap, cache in JTLB
    const paddr_t phys_addr = (size_t)(host_ptr - vm->mem.data) + vm->mem.begin;
    rvjit_func_t compiled = rvjit_block_lookup(&vm->jit, phys_addr);

    if (compiled) {
        riscv_jit_tlb_put(vm, cur_pc, compiled);
        compiled(vm);
        return true;
    }

    /*
     * No valid block compiled for this location,
     * make a new one and enable compiler
     */
    rvjit_block_init(&vm->jit);
    vm->jit.pc_off = 0;
    vm->jit.virt_pc = cur_pc;
    vm->jit.phys_pc = phys_addr;

    vm->block_ends = false;
    vm->jit_compiling = true;
    return false;
}