# Machine Learning zk-VM

In this document, I want to brainstorm some ideas on how to build a machine-learning-optimized zk-VM as part of [ZK Hack Istanbul](https://www.zkistanbul.com/).

## Project description (as posted to the ZK Hack `#project-ideas` channel)

My idea is to build a zk-VM that is specialized for inference of deep neural networks. I prototyped a tiny version of this a few months ago [here](https://github.com/georgwiese/ml-zkvm). The current version just performs a 2x2 matrix multiplication and a ReLU activation. For this, it uses special dot-product and ReLU instructions, so executing the program takes fewer steps than it would on a general-purpose instruction set (like RISC-V, EVM, or WebAssembly).

The idea is to push this as far as possible in one weekend, e.g.:

- Define an instruction set that is general enough to compute many common neural network architectures.
- Implement these instructions using highly optimized circuits. We can probably find some inspiration in projects like EZKL, which have Halo2 implementations of common operations.
- Write some kind of compiler that converts real-world models into an assembly program that uses our ML instruction set.

The ideal outcome of this would be a zk-VM whose performance comes close to that of circuit compilers like EZKL, but which has all the flexibility of being a zk-VM (e.g. one verifier for all models, and possibly, in the future, being able to interpret models directly without any compilation step).

## Background: EZKL

To get some inspiration on how to efficiently prove neural network inference in a zk-SNARK, I dove into how [EZKL](https://github.com/zkonduit/ezkl) (an ONNX-to-Halo2 compiler) does it. Specifically, I stepped through various stages of their pipeline following the [`2l_relu_fc`](https://github.com/zkonduit/ezkl/tree/main/examples/onnx/2l_relu_fc) example. [These docs](https://docs.ezkl.xyz/about_ezkl/commands/) describe the stages and how to run them.
The `2l_relu_fc` example is a very simple model which works as follows:

- It receives a 3-dimensional input vector $x$.
- Then, it applies a linear layer, i.e. it computes $\langle w, x \rangle + b$ (where $w$ is a learned 3-dimensional vector and $b$ is a learned scalar).
- Then, it applies a ReLU activation.

### The computational graph

The `compile-circuit` stage generates a [`GraphCircuit`](https://github.com/zkonduit/ezkl/blob/ee4e64faee0972ac29dece132df59aecab865185/src/graph/mod.rs#L461-L468) from an ONNX model. The computational graph is represented by the [`Model`](https://github.com/zkonduit/ezkl/blob/ee4e64faee0972ac29dece132df59aecab865185/src/graph/model.rs#L79-L84) type, which ultimately consists of a set of [`Node`s](https://github.com/zkonduit/ezkl/blob/ee4e64faee0972ac29dece132df59aecab865185/src/graph/node.rs#L502-L517), each containing a [`SupportedOp`](https://github.com/zkonduit/ezkl/blob/ee4e64faee0972ac29dece132df59aecab865185/src/graph/node.rs#L261-L278).

The main point is that this part of the pipeline deals with quantization and converts a generic ONNX file into a graph of a relatively small number of tensor operations. For the example above, [this](https://gist.github.com/georgwiese/9718624b18ac7befd49c5805b71b3fe4) is the generated graph. I think it could be beneficial to at least take this description of the model as input, because then we don't have to deal with the large number of ONNX operations or with floating-point values!

### Compiling to a circuit

The `GraphCircuit` type implements Halo2's `Circuit` trait, which generates the constraints and fills the columns with data.
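To make the quantization step more concrete, here is a minimal Python sketch of the integer arithmetic that a quantized `2l_relu_fc`-style model reduces to, i.e. roughly what the circuit ends up constraining. It assumes a fixed-point scale of $2^7$ and round-to-nearest rescaling after the multiplications; EZKL's actual scales and rounding rules may differ, and `quantize`, `rescale` and `forward` are hypothetical helpers, not EZKL APIs.

```python
# Hypothetical fixed-point parameters: a value v is stored as round(v * 2^7).
SCALE_BITS = 7
SCALE = 1 << SCALE_BITS


def quantize(v: float) -> int:
    """Convert a float to a fixed-point integer at scale 2^7."""
    return round(v * SCALE)


def rescale(v: int) -> int:
    """Divide a product (at scale 2^14) by 2^7, rounding to nearest (ties toward +inf)."""
    return (v + SCALE // 2) >> SCALE_BITS


def forward(x_q: list[int], w_q: list[int], b_q: int) -> int:
    """Quantized <w, x> + b followed by ReLU; inputs and output at scale 2^7."""
    acc = 0
    for xi, wi in zip(x_q, w_q):
        acc += xi * wi          # running sum of products, at scale 2^14
    y = rescale(acc) + b_q      # back to scale 2^7, then add the bias
    return max(y, 0)            # ReLU


x = [quantize(v) for v in [1.0, 2.0, -0.5]]
w = [quantize(v) for v in [0.5, 0.25, 1.0]]
b = quantize(0.25)
print(forward(x, w, b))  # 96, i.e. 0.75 * 2^7
```

Every step here is an integer operation (products, a running sum, an addition, a division by a power of two, a maximum), which is what makes it possible to express the model with a handful of elementary gates and lookups.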
In this simple example, EZKL generates:

- 3 Advice columns: These can mostly be interpreted as 2 inputs and 1 output.
- 1 Instance column: Used for public inputs / outputs.
- 12 Selector columns:
  - One for each of the 10 [`BaseOp`s](https://github.com/zkonduit/ezkl/blob/ee4e64faee0972ac29dece132df59aecab865185/src/circuit/ops/base.rs#L10-L22)
  - 1 for activating the ReLU lookup
  - 1 for activating a lookup for rescaling (e.g. dividing by $2^7$ after a multiplication)
- 4 Fixed columns: They come from the 2 lookup tables (for ReLU and division by $2^7$).

So, one can think of each row as performing one of 12 elementary operations on two scalars. However, an operation might pass state to the next row. For example, for a dot product of input vectors $a$ and $b$, each row multiplies $a_i$ and $b_i$ and sums up the results using an accumulator. This is illustrated [in this post](https://hackmd.io/mGwARMgvSeq2nGvQWLL2Ww#Supporting-arbitrary-model-sizes). Inputs and outputs of different operations are connected via copy constraints.

## Approach 1: Mimic EZKL execution with memory

EZKL's approach can easily be viewed as a simple program with a small number of instructions (e.g. 12 in the example), each operating on up to 2 scalars. Each row performs one instruction. This is very close to a VM execution.

The main challenge here is that EZKL connects inputs and outputs via copy constraints, so they are hard-wired into the circuit. This is not possible in a VM, where the same circuit has to execute any program. To work around that, we could use memory. Our VM would have 3 registers (`in1`, `in2`, `out`). Then, each row in the EZKL circuit would roughly translate to 4 instructions:

1. Load operand 1 from memory into `in1`
2. Load operand 2 from memory into `in2`
3. Run the actual operation (this can be one of the 12 operations and modifies the `out` register)
4. Write the content of `out` to memory

The memory addresses would be part of the program being executed.
They can simply be the ID of the equivalence class when looking at the EZKL circuit with copy constraints. This ensures that values that are supposed to be the same actually are the same. Note that:

- For operations only requiring 1 operand, we could skip step 2.
- For values of `out` that don't appear in any copy constraint (think of intermediate accumulator values in a dot-product computation), we can skip step 4.

The memory would be a write-once memory, so reads and writes are actually the same operation: a lookup into the (address, content) table. Here's a simple implementation in PIL:

```
// Address column: row i holds address i.
col fixed ADDR(i) i;
// Content of the write-once memory at each address (chosen by the prover).
col witness content;
// Each memory instruction constrains (addr, register) to be an entry of the
// (ADDR, content) table, where addr is the address given by the program.
instr_mem_in1 {addr, in1} in {ADDR, content};
instr_mem_in2 {addr, in2} in {ADDR, content};
instr_mem_out {addr, out} in {ADDR, content};
```

## Approach 2: Data-parallel instructions

The main downside of the first approach is an up to 4x blow-up in the number of rows compared to EZKL, because of the extra memory operations.

An alternative approach would be to have a machine that works on larger words and can perform computations in a data-parallel fashion. For example, let's say we have a word size of 8 field elements and registers `in1[8]`, `in2[8]` and `out`. Then, we could have an instruction that computes the dot product of two input vectors. A multiplication of two $16 \times 16$ matrices would involve $16 \cdot 16 \cdot 2 = 512$ such dot products (2 for each entry in the output matrix, whose results also have to be summed). This compares to $16 \cdot 16 \cdot 16 = 4096$ multiplications to implement the matrix multiplication naively (using Approach 1).

Still, this approach comes with many challenges for which I don't have a solution yet:

- It is not clear how to load the operands from memory. Storing individual field elements in memory would require many memory operations per word. Storing larger words would require each operand to be "in one piece" in memory (e.g. we wouldn't be able to read a column of a matrix stored in row-major order).
- While with Approach 1 it is fairly clear how to generate the ROM from the ONNX file (letting EZKL handle breaking the model down into low-level operations), this approach would require a more complicated compilation step.
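The operation counts for Approach 2 can be checked with a small simulation. The Python sketch below is only a model under stated assumptions: `dot_word` is a hypothetical data-parallel instruction on words of 8 field elements (plain Python integers stand in for field elements), the matrix size is a multiple of the word size, and only dot-product instructions are counted, not the additions that combine partial results or any memory operations.

```python
WORD = 8  # hypothetical word size in field elements


def dot_word(a: list[int], b: list[int]) -> int:
    """Hypothetical data-parallel instruction: dot product of two WORD-sized words."""
    assert len(a) == len(b) == WORD
    return sum(x * y for x, y in zip(a, b))


def matmul_word_parallel(A, B):
    """Square matmul built from dot_word; returns (C, number of dot_word calls).

    Assumes n is a multiple of WORD.
    """
    n = len(A)
    # Gathering columns of B is free in Python; in the VM this is exactly
    # the memory-layout problem for row-major matrices.
    cols = [[B[k][j] for k in range(n)] for j in range(n)]
    C = [[0] * n for _ in range(n)]
    dots = 0
    for i in range(n):
        for j in range(n):
            for k0 in range(0, n, WORD):
                C[i][j] += dot_word(A[i][k0:k0 + WORD], cols[j][k0:k0 + WORD])
                dots += 1
    return C, dots


def matmul_naive(A, B):
    """Scalar matmul; returns (C, number of scalar multiplications)."""
    n = len(A)
    C = [[0] * n for _ in range(n)]
    muls = 0
    for i in range(n):
        for j in range(n):
            for k in range(n):
                C[i][j] += A[i][k] * B[k][j]
                muls += 1
    return C, muls


n = 16
A = [[(i + 2 * j) % 11 - 5 for j in range(n)] for i in range(n)]
B = [[(3 * i - j) % 7 for j in range(n)] for i in range(n)]
C1, dots = matmul_word_parallel(A, B)
C2, muls = matmul_naive(A, B)
assert C1 == C2
print(dots, muls)  # 512 4096
```

Both functions compute the same product; the difference is only in how many VM steps each would take, which matches the 512 vs. 4096 estimate for $16 \times 16$ matrices.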