## Packed Poseidon2 in MLIR

Following [the previous post](https://hackmd.io/@Cb5ED-vYSb-4QslX0tXOSw/By5kJkDAeg), we benchmarked our MLIR implementation on Packed PrimeField. The results are as follows (reference PR: https://github.com/fractalyze/zkir/pull/86):

| Machine     | Width | Plonky3 | ZKIR    | Ratio (Plonky3 / ZKIR) |
| ----------- | ----- | ------- | ------- | ---------------------- |
| AMD 9950X3d | 16    | 7.38 ms | 7.38 ms | 1.00                   |
| Mac M4 Pro  | 4     | 5.51 ms | 5.58 ms | 0.98                   |

*(The table shows the total time for `Width` * 10,000 hash operations. The benchmark used **BabyBear**, but similar results are expected for **KoalaBear**.)*

We added the following math tricks to the ZKIR pass:

- Algebraic rewrite for $a^x$ when $x$ is a constant (e.g., $a^3 \to a^2 \times a$)
- Algebraic rewrite for $k \times a$ when $k$ is a constant (e.g., $3 \times a \to 2 \times a + a$)
- Algebraic rewrite for $a/2^x$
- Algebraic rewrite for $(a - b)^2$
- Vector constant folding for the `mod_arith` dialect
- Algebraic rewrite for `add`, `sub`, and `mul` when operands match specific AVX-512 patterns
- Custom `mod_arith.mul` for Arm NEON (not merged yet, but the code is available [here](https://github.com/fractalyze/zkir/tree/feat/benchmark-arm-neon-poseidon2))

Separately, we implemented an additional math trick directly in MLIR that should also be integrated into the ZKIR pass:

- An algebraic rewrite for $A \times b$ when $A$ is a constant matrix and $b$ is a vector

---

## JAX

For the first time, we implemented code generation from **JAX** to **ZKX**. It currently supports only boolean and integer types; we plan to add support for ZK types within the next two weeks.

```python
import jax

fast_f = jax.jit(lambda x: x + 1)
print(fast_f(1))
```

The Python code above produces the following **StableHLO**:

```mlir
module @jit_f attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<i32>) -> (tensor<i32> {jax.result_info = "result"}) {
    %c = stablehlo.constant dense<1> : tensor<i32>
    %0 = stablehlo.add %arg0, %c : tensor<i32>
    return %0 : tensor<i32>
  }
}
```

which is then lowered to **HLO**:

```
ENTRY %main.4 (Arg_0.1: s32[]) -> s32[] {
  %Arg_0.1 = s32[] parameter(0), metadata={op_name="x"}
  %constant.2 = s32[] constant(1)
  ROOT %add.3 = s32[] add(s32[] %Arg_0.1, s32[] %constant.2), metadata={op_name="jit(f)/add" source_file="/private/var/tmp/_bazel_chokobole/bb0d5db6656f6e67a16815740113965e/execroot/jax/bazel-out/darwin_arm64-opt/bin/tests/my_test.runfiles/jax/tests/my_test.py" source_line=8}
}
```

and then converted to **MLIR**:

```mlir
module {
  func.func @add.3(%arg0: memref<i32> {bufferization.writable = false}, %arg1: memref<i32> {bufferization.writable = false}) -> memref<i32> attributes {llvm.emit_c_interface} {
    %0 = memref.load %arg0[] : memref<i32>
    %1 = memref.load %arg1[] : memref<i32>
    %2 = arith.addi %0, %1 : i32
    %alloc = memref.alloc() : memref<i32>
    memref.store %2, %alloc[] : memref<i32>
    return %alloc : memref<i32>
  }
}
```

---

## ZKX

To support Groth16, we implemented the following opcodes (see opcode semantics in https://openxla.org/xla/operation_semantics):

- `add`
- `bitcast`
- `const`
- `convert`
- `dot`
- `fft`
- `msm`
- `multiply`
- `parameter`
- `slice`
- `subtract`
- `tuple`

To support **WHIR-p3**, we additionally added the following opcodes to ZKX, enabling us to cover most single-machine provers:

- `abs`
- `and`
- `bitcast-convert`
- `call`
- `clamp`
- `concatenate`
- `conditional`
- `count-leading-zeros`
- `compare`
- `dynamic-slice`
- `dynamic-update-slice`
- `iota`
- `map`
- `maximum`
- `minimum`
- `not`
- `pad`
- `population-count`
- `reduce`
- `remainder`
- `reshape`
- `reverse`
- `select`
- `shift-left`
- `shift-right-arithmetic`
- `shift-right-logical`
- `sign`
- `sort`
- `transpose`
- `while`
- `xor`

See the related PRs:

- [feat: enable more opcodes](https://github.com/fractalyze/zkx/pull/97)
- [feat: enable more shape inference](https://github.com/fractalyze/zkx/pull/98)
- [feat: add more instructions](https://github.com/fractalyze/zkx/pull/99)
- [feat: add more cpu code generation part1](https://github.com/fractalyze/zkx/pull/101)
- [feat: add more cpu code generation part2](https://github.com/fractalyze/zkx/pull/102)
- [feat: add more cpu code generation part3](https://github.com/fractalyze/zkx/pull/103)
- [feat: add more cpu code generation part4](https://github.com/fractalyze/zkx/pull/104)
- [feat: add more cpu code generation part5](https://github.com/fractalyze/zkx/pull/105)
- [feat: add more cpu code generation part6](https://github.com/fractalyze/zkx/pull/106)

---

## Poseidon2 in JAX

Since we haven’t added a prime-field type to JAX yet, the following is an **integer-based** early version of Poseidon2:

```python
import jax
from jax import jit
import jax.numpy as jnp
import jax.lax as lax
import jax.tree_util as tree_util


class Poseidon2BabyBear16:
    """
    A Poseidon2BabyBear16 hash function.
    """

    # INITIAL ROUND CONSTANTS (4 rounds x 16 constants)
    INITIAL_RC = jnp.array([
        [0x69cbb6af, 0x46ad93f9, 0x60a00f4e, 0x6b1297cd, 0x23189afe, 0x732e7bef, 0x72c246de, 0x2c941900,
         0x0557eede, 0x1580496f, 0x3a3ea77b, 0x54f3f271, 0x0f49b029, 0x47872fe1, 0x221e2e36, 0x1ab7202e],
        [0x487779a6, 0x3851c9d8, 0x38dc17c0, 0x209f8849, 0x268dcee8, 0x350c48da, 0x5b9ad32e, 0x0523272b,
         0x3f89055b, 0x01e894b2, 0x13ddedde, 0x1b2ef334, 0x7507d8b4, 0x6ceeb94e, 0x52eb6ba2, 0x50642905],
        [0x05453f3f, 0x06349efc, 0x6922787c, 0x04bfff9c, 0x768c714a, 0x3e9ff21a, 0x15737c9c, 0x2229c807,
         0x0d47f88c, 0x097e0ecc, 0x27eadba0, 0x2d7d29e4, 0x3502aaa0, 0x0f475fd7, 0x29fbda49, 0x018afffd],
        [0x0315b618, 0x6d4497d1, 0x1b171d9e, 0x52861abd, 0x2e5d0501, 0x3ec8646c, 0x6e5f250a, 0x148ae8e6,
         0x17f5fa4a, 0x3e66d284, 0x0051aa3b, 0x483f7913, 0x2cfe5f15, 0x023427ca, 0x2cc78315, 0x1e36ea47]
    ])

    # INTERNAL ROUND CONSTANTS (13 scalar constants)
    INTERNAL_RC = jnp.array([
        0x5a8053c0, 0x693be639, 0x3858867d, 0x19334f6b, 0x128f0fd8, 0x4e2b1ccb, 0x61210ce0,
        0x3c318939, 0x0b5b2f22, 0x2edb11d5, 0x213effdf, 0x0cac4606, 0x241af16d
    ])

    # TERMINAL ROUND CONSTANTS (4 rounds x 16 constants)
    TERMINAL_RC = jnp.array([
        [0x7290a80d, 0x6f7e5329, 0x598ec8a8, 0x76a859a0, 0x6559e868, 0x657b83af, 0x13271d3f, 0x1f876063,
         0x0aeeae37, 0x706e9ca6, 0x46400cee, 0x72a05c26, 0x2c589c9e, 0x20bd37a7, 0x6a2d3d10, 0x20523767],
        [0x5b8fe9c4, 0x2aa501d6, 0x1e01ac3e, 0x1448bc54, 0x5ce5ad1c, 0x4918a14d, 0x2c46a83f, 0x4fcf6876,
         0x61d8d5c8, 0x6ddf4ff9, 0x11fda4d3, 0x02933a8f, 0x170eaf81, 0x5a9c314f, 0x49a12590, 0x35ec52a1],
        [0x58eb1611, 0x5e481e65, 0x367125c9, 0x0eba33ba, 0x1fc28ded, 0x066399ad, 0x0cbec0ea, 0x75fd1af0,
         0x50f5bf4e, 0x643d5f41, 0x6f4fe718, 0x5b3cbbde, 0x1e3afb3e, 0x296fb027, 0x45e1547b, 0x4a8db2ab],
        [0x59986d19, 0x30bcdfa3, 0x1db63932, 0x1d7c2824, 0x53b33681, 0x0673b747, 0x038a98a3, 0x2c5bce60,
         0x351979cd, 0x5008fb73, 0x547bca78, 0x711af481, 0x3f93bf64, 0x644d987b, 0x3c8bcd87, 0x608758b8]
    ])

    def __init__(self):
        pass

    @jit
    def add_rc_and_sbox(self, elem, rc):
        """
        Adds the round constant and applies the S-box (x^7) to the element.
        """
        return (elem + rc) ** 7

    @jit
    def mds_light_permutation(self, state):
        """
        Applies the MDS light permutation to the state.

        The permutation consists of two phases:
        1. Apply M_4 matrix to each consecutive four elements
        2. Apply outer circulant matrix transformation

        Args:
            state: A 16-element array representing the state

        Returns:
            The transformed state array
        """
        # First phase: Apply M_4 to each consecutive four elements
        # [ 2 3 1 1 ]
        # [ 1 2 3 1 ]
        # [ 1 1 2 3 ]
        # [ 3 1 1 2 ]
        def process_chunk(chunk_idx, state):
            offset = chunk_idx * 4
            x0 = state[offset]
            x1 = state[offset + 1]
            x2 = state[offset + 2]
            x3 = state[offset + 3]

            # Compute intermediate sums
            x01 = x0 + x1
            x23 = x2 + x3
            x0123 = x01 + x23
            x01123 = x0123 + x1
            x01233 = x0123 + x3

            # Compute doubles
            x00 = x0 * 2
            x22 = x2 * 2

            # Compute new values
            x0_new = x01123 + x01
            x1_new = x01123 + x22
            x2_new = x01233 + x23
            x3_new = x01233 + x00

            # Update state
            state = state.at[offset].set(x0_new)
            state = state.at[offset + 1].set(x1_new)
            state = state.at[offset + 2].set(x2_new)
            state = state.at[offset + 3].set(x3_new)
            return state

        # Apply M_4 to each chunk of 4 elements
        state = lax.fori_loop(0, 4, process_chunk, state)

        # Second phase: Apply outer circulant matrix
        # sums[k] = sum of state[j + k] for j = 0, 4, 8, 12
        sums = state[0:4] + state[4:8] + state[8:12] + state[12:16]

        # Apply the formula: y_i = x_i' + sums[i % 4]
        # Reshape state to (4, 4) and transpose so that row i holds [i, i+4, i+8, i+12],
        # then add sums[i] to each element in row i
        state_reshaped = state.reshape(4, 4).T
        sums_expanded = sums[:, jnp.newaxis]  # Shape (4, 1) for broadcasting
        state_reshaped = state_reshaped + sums_expanded
        state = state_reshaped.T.reshape(16)  # Transpose back and flatten
        return state

    @jit
    def permute_state_terminal(self, state, rc):
        """
        External layer: terminal permutation (4 rounds: add RC, S-box, MDS).

        Args:
            state: A 16-element array representing the state
            rc: A (4, 16) array of round constants for the 4 rounds

        Returns:
            The permuted state array
        """
        # Loop through 4 rounds of external terminal permutation
        for round_idx in range(4):
            # Add this round's constants and apply the S-box element-wise
            rc_round = rc[round_idx]
            state = (state + rc_round) ** 7
            # Apply MDS light permutation
            state = self.mds_light_permutation(state)
        return state

    @jit
    def permute_state_initial(self, state):
        """
        External layer: initial permutation (MDS light + terminal permutation).

        Args:
            state: A 16-element array representing the state

        Returns:
            The permuted state array
        """
        # First apply MDS light permutation
        state = self.mds_light_permutation(state)
        return self.permute_state_terminal(state, self.INITIAL_RC)

    @jit
    def internal_layer_mat_mul(self, state, sum_val):
        """
        Internal layer matrix multiplication.
        Applies (1 + diagonal_mat) multiplication to state[1] through state[15].

        Args:
            state: A 16-element array representing the state
            sum_val: The sum value to use in the matrix multiplication

        Returns:
            The transformed state array
        """
        # Precompute inverses of powers of 2.
        # NOTE(batzor): placeholder. With plain Python numbers this is the float 0.5;
        # once a prime-field type is available, it becomes the field inverse of 2.
        inv_two = 1 / 2               # Inverse of 2
        inv_four = inv_two ** 2       # Inverse of 4
        inv_eight = inv_two ** 3      # Inverse of 8
        inv_sixteen = inv_two ** 4    # Inverse of 16
        inv_256 = inv_sixteen ** 2    # Inverse of 256 (16^2)
        inv_2_27 = inv_two ** 27      # Inverse of 2^27

        # Extract state[1:16] for vectorized operations
        state_slice = state[1:16]

        # Define multipliers for each element: state[i] = state[i] * multiplier[i] + sum_val
        multipliers = jnp.array([
            1,             # state[1]: state[1] + sum
            2,             # state[2]: state[2] * 2 + sum
            inv_two,       # state[3]: state[3] * inv_two + sum
            3,             # state[4]: state[4] * 3 + sum (2 + 1)
            4,             # state[5]: state[5] * 4 + sum
            -inv_two,      # state[6]: sum - state[6] * inv_two
            -3,            # state[7]: sum - state[7] * 3 (2 + 1)
            -4,            # state[8]: sum - state[8] * 4
            inv_256,       # state[9]: state[9] * inv_256 + sum
            inv_four,      # state[10]: state[10] * inv_four + sum
            inv_eight,     # state[11]: state[11] * inv_eight + sum
            inv_2_27,      # state[12]: state[12] * inv_2_27 + sum
            -inv_256,      # state[13]: sum - state[13] * inv_256
            -inv_sixteen,  # state[14]: sum - state[14] * inv_sixteen
            -inv_2_27,     # state[15]: sum - state[15] * inv_2_27
        ])

        # Vectorized operation: state[1:16] = state[1:16] * multipliers + sum_val
        state_slice = state_slice * multipliers + sum_val

        # Update state with the transformed slice
        state = state.at[1:16].set(state_slice)
        return state

    @jit
    def permute_state(self, state):
        """
        Internal layer: permutation (add RC to first element, S-box first, internal diffusion).

        Args:
            state: A 16-element array representing the state

        Returns:
            The permuted state array
        """
        # For each internal constant: add RC and S-box to first element, then apply matrix multiplication
        for round_idx in range(13):
            # Get current round constant
            rc_val = self.INTERNAL_RC[round_idx]

            # Add RC and apply S-box to first element
            elem0 = self.add_rc_and_sbox(state[0], rc_val)

            # Sum of elements 1 to 15, and the full sum including the post-S-box state[0]
            partial_sum = jnp.sum(state[1:])
            total_sum = partial_sum + elem0

            # The linear layer acts on the post-S-box state, so
            # state[0] = total_sum - 2 * elem0 = partial_sum - elem0
            state = state.at[0].set(partial_sum - elem0)

            # Apply internal layer matrix multiplication
            state = self.internal_layer_mat_mul(state, total_sum)
        return state

    @jit
    def permute(self, state):
        """
        Permutes the state using the Poseidon2BabyBear16 hash function.

        Args:
            state: A 16-element array representing the state

        Returns:
            The permuted state array
        """
        state = self.permute_state_initial(state)
        state = self.permute_state(state)
        state = self.permute_state_terminal(state, self.TERMINAL_RC)
        return state

    def _tree_flatten(self):
        children = ()  # arrays / dynamic values
        aux_data = {}  # static values
        return (children, aux_data)

    @classmethod
    def _tree_unflatten(cls, aux_data, children):
        return cls(*children, **aux_data)


# Register the class as a pytree so that `self` can be passed through `jit`
tree_util.register_pytree_node(Poseidon2BabyBear16,
                               Poseidon2BabyBear16._tree_flatten,
                               Poseidon2BabyBear16._tree_unflatten)
```
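
The `1 / 2` placeholder in `internal_layer_mat_mul` only becomes meaningful once the values are field elements. As a rough illustration of what the multipliers and the $x^7$ S-box would look like in the BabyBear field, here is a minimal pure-Python sketch. It assumes the BabyBear modulus $p = 2^{31} - 2^{27} + 1$ and uses plain Python integers in place of a prime-field type; the names `inv_pow2`, `sbox7`, and `INTERNAL_MULTIPLIERS` are hypothetical and are not part of ZKX, ZKIR, or JAX.

```python
# Illustrative sketch only: plain Python ints stand in for a prime-field type.
# All helper names below are hypothetical.

P = 2**31 - 2**27 + 1  # BabyBear modulus (assumed here): 2013265921


def inv_pow2(k: int) -> int:
    """Return the inverse of 2**k modulo P."""
    # pow() with a negative exponent and a modulus computes a modular
    # inverse (available since Python 3.8).
    return pow(2, -k, P)


def sbox7(x: int) -> int:
    """Compute x**7 mod P with four multiplications: x^2, x^3, x^6, x^7."""
    x2 = x * x % P
    x3 = x2 * x % P
    x6 = x3 * x3 % P
    return x6 * x % P


# Field versions of the multipliers for state[1] .. state[15], in the same
# order as the `multipliers` array in the JAX code above. Negative entries
# are reduced into the range [0, P).
INTERNAL_MULTIPLIERS = [
    1, 2, inv_pow2(1), 3, 4,
    -inv_pow2(1) % P, -3 % P, -4 % P,
    inv_pow2(8), inv_pow2(2), inv_pow2(3), inv_pow2(27),
    -inv_pow2(8) % P, -inv_pow2(4) % P, -inv_pow2(27) % P,
]


if __name__ == "__main__":
    # 2 * (1/2) should be 1 in the field.
    print((2 * inv_pow2(1)) % P)
    # The addition-chain S-box should agree with the built-in modular power.
    print(sbox7(12345) == pow(12345, 7, P))
```

Reducing after every multiplication keeps intermediate values below $p^2$, which is the property a fixed-width integer implementation would need in order to avoid the overflow that `(elem + rc) ** 7` runs into with 32-bit integers.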