## Packed Poseidon2 in MLIR
Following [the previous post](https://hackmd.io/@Cb5ED-vYSb-4QslX0tXOSw/By5kJkDAeg), we benchmarked our MLIR implementation of Poseidon2 on packed prime-field elements (reference PR: https://github.com/fractalyze/zkir/pull/86). The results are as follows:

| CPU | Width | Plonky3 | ZKIR | Ratio (Plonky3 / ZKIR) |
| ------------------- | ----- | ------- | ------- | ---------------------- |
| AMD Ryzen 9 9950X3D | 16 | 7.38 ms | 7.38 ms | 1.00 |
| Apple M4 Pro | 4 | 5.51 ms | 5.58 ms | 0.98 |

*(Times are totals for `Width` × 10,000 hash operations, where `Width` is the number of BabyBear elements packed into one SIMD vector (16 with AVX-512, 4 with NEON). The benchmark used **BabyBear**, but we expect similar results for **KoalaBear**.)*

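
Dividing each total by the number of hashes gives the per-hash latency implied by the table:

$$
\begin{aligned}
\text{AMD, Plonky3 and ZKIR:}\quad & \frac{7.38\ \text{ms}}{16 \times 10^4} \approx 46.1\ \text{ns per hash}\\
\text{M4 Pro, Plonky3:}\quad & \frac{5.51\ \text{ms}}{4 \times 10^4} \approx 137.8\ \text{ns per hash}\\
\text{M4 Pro, ZKIR:}\quad & \frac{5.58\ \text{ms}}{4 \times 10^4} \approx 139.5\ \text{ns per hash}
\end{aligned}
$$
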
We added the following math tricks to the ZKIR pass:
- Algebraic rewrite for $a^x$ when $x$ is a constant (e.g., $a^3 \to a^2 \times a$)
- Algebraic rewrite for $k \times a$ when $k$ is a constant (e.g., $3 \times a \to 2\times a + a$)
- Algebraic rewrite for $a/2^x$
- Algebraic rewrite for $(a - b)^2$
- Vector constant folding for the `mod_arith` dialect
- Algebraic rewrite for `add`, `sub`, and `mul` when operands match specific AVX-512 patterns
- Custom `mod_arith.mul` implementation for ARM NEON (not merged yet; see the code [here](https://github.com/fractalyze/zkir/tree/feat/benchmark-arm-neon-poseidon2))
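
The first three rewrites can be modeled with plain Python integers modulo the BabyBear prime $p = 15 \cdot 2^{27} + 1$. This is a minimal sketch of what the rewrites compute. It is not ZKIR's MLIR patterns themselves, and the helper names `pow_const`, `mul_const`, `halve`, and `div_pow2` are hypothetical.

```python
# Plain-integer model of three ZKIR rewrites over BabyBear (illustration only).
P = 15 * 2**27 + 1  # BabyBear prime, 2013265921


def pow_const(a: int, x: int) -> int:
    """a^x for a constant x via left-to-right square-and-multiply.

    For x = 3 this computes a^2 * a; for x = 7 it computes (a^2 * a)^2 * a.
    """
    result = 1
    for bit in bin(x)[2:]:
        result = result * result % P
        if bit == "1":
            result = result * a % P
    return result


def mul_const(k: int, a: int) -> int:
    """k * a for a constant k via double-and-add, e.g. 3 * a = (a + a) + a."""
    result = 0
    for bit in bin(k)[2:]:
        result = (result + result) % P
        if bit == "1":
            result = (result + a) % P
    return result


def halve(a: int) -> int:
    """a / 2 mod P for 0 <= a < P, using only a shift (P is odd)."""
    return a >> 1 if a % 2 == 0 else (a + P) >> 1


def div_pow2(a: int, x: int) -> int:
    """a / 2^x mod P as x repeated halvings."""
    for _ in range(x):
        a = halve(a)
    return a


a = 123456789
print(pow_const(a, 7) == pow(a, 7, P))             # True
print(mul_const(3, a) == 3 * a % P)               # True
print(div_pow2(a, 27) == a * pow(2, -27, P) % P)  # True
```

Halving needs no multiplication because $p$ is odd. When $a$ is odd, $a + p$ is even and $(a + p)/2 < p$, so one conditional add and a shift are enough.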
Separately, we implemented an additional trick directly in MLIR, which still needs to be integrated into the ZKIR pass:
- Algebraic rewrite for $A \times b$ when $A$ is a constant matrix and $b$ is a vector
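
The example below illustrates this kind of rewrite, though it is not necessarily the exact transformation we implemented. The constant 4×4 matrix $M_4$ used in `mds_light_permutation` later in this post can be evaluated with shared additions and two doublings instead of a generic matrix-vector product. The function names are hypothetical. Plain integers stand in for field elements, since only additions are involved.

```python
# Illustration: constant-matrix rewrite for Poseidon2's M_4.
M4 = [
    [2, 3, 1, 1],
    [1, 2, 3, 1],
    [1, 1, 2, 3],
    [3, 1, 1, 2],
]


def m4_dense(b):
    """Generic A @ b: 16 multiplications and 12 additions."""
    return [sum(a * x for a, x in zip(row, b)) for row in M4]


def m4_rewritten(b):
    """Same product with 9 additions and 2 doublings, sharing partial sums."""
    x0, x1, x2, x3 = b
    x01 = x0 + x1
    x23 = x2 + x3
    x0123 = x01 + x23
    x01123 = x0123 + x1
    x01233 = x0123 + x3
    return [
        x01123 + x01,     # 2*x0 + 3*x1 + x2 + x3
        x01123 + 2 * x2,  # x0 + 2*x1 + 3*x2 + x3
        x01233 + x23,     # x0 + x1 + 2*x2 + 3*x3
        x01233 + 2 * x0,  # 3*x0 + x1 + x2 + 2*x3
    ]


print(m4_dense([1, 2, 3, 4]))      # [15, 18, 21, 16]
print(m4_rewritten([1, 2, 3, 4]))  # [15, 18, 21, 16]
```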
---
## JAX
For the first time, we implemented code generation from **JAX** to **ZKX**. It currently supports only boolean and integer types; we plan to add support for ZK types within the next two weeks.
```python
import jax

def f(x):
    return x + 1

fast_f = jax.jit(f)
print(fast_f(1))
```
The Python code above produces the following **StableHLO**:
```mlir
module @jit_f attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<i32>) -> (tensor<i32> {jax.result_info = "result"}) {
    %c = stablehlo.constant dense<1> : tensor<i32>
    %0 = stablehlo.add %arg0, %c : tensor<i32>
    return %0 : tensor<i32>
  }
}
```
which is then lowered to **HLO**:
```
ENTRY %main.4 (Arg_0.1: s32[]) -> s32[] {
  %Arg_0.1 = s32[] parameter(0), metadata={op_name="x"}
  %constant.2 = s32[] constant(1)
  ROOT %add.3 = s32[] add(s32[] %Arg_0.1, s32[] %constant.2), metadata={op_name="jit(f)/add" source_file="/private/var/tmp/_bazel_chokobole/bb0d5db6656f6e67a16815740113965e/execroot/jax/bazel-out/darwin_arm64-opt/bin/tests/my_test.runfiles/jax/tests/my_test.py" source_line=8}
}
```
and finally converted to **MLIR**:
```mlir
module {
  func.func @add.3(%arg0: memref<i32> {bufferization.writable = false}, %arg1: memref<i32> {bufferization.writable = false}) -> memref<i32> attributes {llvm.emit_c_interface} {
    %0 = memref.load %arg0[] : memref<i32>
    %1 = memref.load %arg1[] : memref<i32>
    %2 = arith.addi %0, %1 : i32
    %alloc = memref.alloc() : memref<i32>
    memref.store %2, %alloc[] : memref<i32>
    return %alloc : memref<i32>
  }
}
```
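
You can reproduce the StableHLO stage with upstream JAX alone, using its ahead-of-time lowering API. This API lowers a jitted function for example arguments without compiling or running it. This is a minimal sketch that assumes a recent JAX release, where `Lowered.as_text()` emits StableHLO by default. It does not involve ZKX, and the exact output text varies between JAX versions.

```python
import jax


def f(x):
    return x + 1


# Trace and lower for an int32 scalar argument; nothing is compiled or executed.
lowered = jax.jit(f).lower(1)
stablehlo_text = lowered.as_text()
print(stablehlo_text)
```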
---
## ZKX
To support Groth16, we implemented the following opcodes (see the [XLA operation semantics](https://openxla.org/xla/operation_semantics) for their definitions):
- `add`
- `bitcast`
- `const`
- `convert`
- `dot`
- `fft`
- `msm`
- `multiply`
- `parameter`
- `slice`
- `subtract`
- `tuple`
To support **WHIR-p3**, we also added the following opcodes to ZKX, which should let us cover most single-machine provers:
- `abs`
- `and`
- `bitcast-convert`
- `call`
- `clamp`
- `concatenate`
- `conditional`
- `count-leading-zeros`
- `compare`
- `dynamic-slice`
- `dynamic-update-slice`
- `iota`
- `map`
- `maximum`
- `minimum`
- `not`
- `pad`
- `population-count`
- `reduce`
- `remainder`
- `reshape`
- `reverse`
- `select`
- `shift-left`
- `shift-right-arithmetic`
- `shift-right-logical`
- `sign`
- `sort`
- `transpose`
- `while`
- `xor`
See the related PRs:
- [feat: enable more opcodes](https://github.com/fractalyze/zkx/pull/97)
- [feat: enable more shape inference](https://github.com/fractalyze/zkx/pull/98)
- [feat: add more instructions](https://github.com/fractalyze/zkx/pull/99)
- [feat: add more cpu code generation part1](https://github.com/fractalyze/zkx/pull/101)
- [feat: add more cpu code generation part2](https://github.com/fractalyze/zkx/pull/102)
- [feat: add more cpu code generation part3](https://github.com/fractalyze/zkx/pull/103)
- [feat: add more cpu code generation part4](https://github.com/fractalyze/zkx/pull/104)
- [feat: add more cpu code generation part5](https://github.com/fractalyze/zkx/pull/105)
- [feat: add more cpu code generation part6](https://github.com/fractalyze/zkx/pull/106)
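
Several of the newly supported opcodes are what ordinary integer-only JAX code lowers to. The sketch below lowers a small function with upstream JAX and checks which StableHLO ops appear. The function `g` is a hypothetical example. We assume these ops correspond to the HLO opcodes `dynamic-slice`, `population-count`, `compare`, and `select` after conversion, and we have not run this exact program through ZKX.

```python
import jax
import jax.numpy as jnp
from jax import lax


def g(x, i):
    window = lax.dynamic_slice(x, (i,), (4,))  # dynamic-slice
    bits = lax.population_count(window)        # population count per element
    return jnp.where(bits > 1, window, 0)      # compare + select


x = jnp.arange(16, dtype=jnp.int32)
text = jax.jit(g).lower(x, 2).as_text()
for op in ("stablehlo.dynamic_slice", "stablehlo.popcnt",
           "stablehlo.compare", "stablehlo.select"):
    print(op, op in text)
```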
---
## Poseidon2 in JAX
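Since we haven't added a prime-field type to JAX yet, the following is an early, **integer-based** version of Poseidon2 for BabyBear with width 16. It follows the structure of the permutation (4 initial external rounds, 13 internal rounds, and 4 terminal external rounds). However, its arithmetic is not reduced modulo the BabyBear prime: `int32` values wrap around, and the internal layer uses floating-point placeholders for inverses of powers of two. As a result, its outputs do not match a field-based implementation: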
```python
import jax
from jax import jit
import jax.numpy as jnp
import jax.lax as lax
import jax.tree_util as tree_util


class Poseidon2BabyBear16:
    """
    A Poseidon2BabyBear16 hash function.
    """

    # INITIAL ROUND CONSTANTS (4 rounds x 16 constants)
    INITIAL_RC = jnp.array([
        [0x69cbb6af, 0x46ad93f9, 0x60a00f4e, 0x6b1297cd, 0x23189afe, 0x732e7bef, 0x72c246de, 0x2c941900, 0x0557eede, 0x1580496f, 0x3a3ea77b, 0x54f3f271, 0x0f49b029, 0x47872fe1, 0x221e2e36, 0x1ab7202e],
        [0x487779a6, 0x3851c9d8, 0x38dc17c0, 0x209f8849, 0x268dcee8, 0x350c48da, 0x5b9ad32e, 0x0523272b, 0x3f89055b, 0x01e894b2, 0x13ddedde, 0x1b2ef334, 0x7507d8b4, 0x6ceeb94e, 0x52eb6ba2, 0x50642905],
        [0x05453f3f, 0x06349efc, 0x6922787c, 0x04bfff9c, 0x768c714a, 0x3e9ff21a, 0x15737c9c, 0x2229c807, 0x0d47f88c, 0x097e0ecc, 0x27eadba0, 0x2d7d29e4, 0x3502aaa0, 0x0f475fd7, 0x29fbda49, 0x018afffd],
        [0x0315b618, 0x6d4497d1, 0x1b171d9e, 0x52861abd, 0x2e5d0501, 0x3ec8646c, 0x6e5f250a, 0x148ae8e6, 0x17f5fa4a, 0x3e66d284, 0x0051aa3b, 0x483f7913, 0x2cfe5f15, 0x023427ca, 0x2cc78315, 0x1e36ea47],
    ])

    # INTERNAL ROUND CONSTANTS (13 scalar constants)
    INTERNAL_RC = jnp.array([
        0x5a8053c0, 0x693be639, 0x3858867d, 0x19334f6b, 0x128f0fd8, 0x4e2b1ccb, 0x61210ce0, 0x3c318939, 0x0b5b2f22, 0x2edb11d5, 0x213effdf, 0x0cac4606, 0x241af16d,
    ])

    # TERMINAL ROUND CONSTANTS (4 rounds x 16 constants)
    TERMINAL_RC = jnp.array([
        [0x7290a80d, 0x6f7e5329, 0x598ec8a8, 0x76a859a0, 0x6559e868, 0x657b83af, 0x13271d3f, 0x1f876063, 0x0aeeae37, 0x706e9ca6, 0x46400cee, 0x72a05c26, 0x2c589c9e, 0x20bd37a7, 0x6a2d3d10, 0x20523767],
        [0x5b8fe9c4, 0x2aa501d6, 0x1e01ac3e, 0x1448bc54, 0x5ce5ad1c, 0x4918a14d, 0x2c46a83f, 0x4fcf6876, 0x61d8d5c8, 0x6ddf4ff9, 0x11fda4d3, 0x02933a8f, 0x170eaf81, 0x5a9c314f, 0x49a12590, 0x35ec52a1],
        [0x58eb1611, 0x5e481e65, 0x367125c9, 0x0eba33ba, 0x1fc28ded, 0x066399ad, 0x0cbec0ea, 0x75fd1af0, 0x50f5bf4e, 0x643d5f41, 0x6f4fe718, 0x5b3cbbde, 0x1e3afb3e, 0x296fb027, 0x45e1547b, 0x4a8db2ab],
        [0x59986d19, 0x30bcdfa3, 0x1db63932, 0x1d7c2824, 0x53b33681, 0x0673b747, 0x038a98a3, 0x2c5bce60, 0x351979cd, 0x5008fb73, 0x547bca78, 0x711af481, 0x3f93bf64, 0x644d987b, 0x3c8bcd87, 0x608758b8],
    ])

    def __init__(self):
        pass

    @jit
    def add_rc_and_sbox(self, elem, rc):
        """
        Adds the round constant and applies the S-box (x^7) to the element.
        """
        return (elem + rc) ** 7

    @jit
    def mds_light_permutation(self, state):
        """
        Applies the MDS light permutation to the state.
        The permutation consists of two phases:
        1. Apply M_4 matrix to each consecutive four elements
        2. Apply outer circulant matrix transformation
        Args:
            state: A 16-element array representing the state
        Returns:
            The transformed state array
        """
        # First phase: Apply M_4 to each consecutive four elements
        # [ 2 3 1 1 ]
        # [ 1 2 3 1 ]
        # [ 1 1 2 3 ]
        # [ 3 1 1 2 ]
        def process_chunk(chunk_idx, state):
            offset = chunk_idx * 4
            x0 = state[offset]
            x1 = state[offset + 1]
            x2 = state[offset + 2]
            x3 = state[offset + 3]
            # Compute intermediate sums
            x01 = x0 + x1
            x23 = x2 + x3
            x0123 = x01 + x23
            x01123 = x0123 + x1
            x01233 = x0123 + x3
            # Compute doubles
            x00 = x0 * 2
            x22 = x2 * 2
            # Compute new values
            x0_new = x01123 + x01
            x1_new = x01123 + x22
            x2_new = x01233 + x23
            x3_new = x01233 + x00
            # Update state
            state = state.at[offset].set(x0_new)
            state = state.at[offset + 1].set(x1_new)
            state = state.at[offset + 2].set(x2_new)
            state = state.at[offset + 3].set(x3_new)
            return state

        # Apply M_4 to each chunk of 4 elements
        state = lax.fori_loop(0, 4, process_chunk, state)

        # Second phase: Apply outer circulant matrix
        # Compute sums: sums[k] = sum of state[j + k] for j = 0, 4, 8, 12
        sums = state[0:4] + state[4:8] + state[8:12] + state[12:16]
        # Apply the formula: y_i = x_i' + sums[i % 4]
        # Reshape state to (4, 4) and transpose so columns correspond to [i, i+4, i+8, i+12],
        # then add sums[i] to each element in column i
        state_reshaped = state.reshape(4, 4).T  # Transpose to get columns [i, i+4, i+8, i+12]
        sums_expanded = sums[:, jnp.newaxis]  # Shape (4, 1) for broadcasting
        state_reshaped = state_reshaped + sums_expanded
        state = state_reshaped.T.reshape(16)  # Transpose back and flatten
        return state

    @jit
    def permute_state_terminal(self, state, rc):
        """
        External layer: terminal permutation (4 rounds: add RC, S-box, MDS).
        Args:
            state: A 16-element array representing the state
            rc: A (4, 16) array of round constants for the 4 rounds
        Returns:
            The permuted state array
        """
        # Loop through 4 rounds of external terminal permutation
        for round_idx in range(4):
            # Get round constants for this round
            rc_round = rc[round_idx]
            state = (state + rc_round) ** 7
            # Apply MDS light permutation
            state = self.mds_light_permutation(state)
        return state

    @jit
    def permute_state_initial(self, state):
        """
        External layer: initial permutation (MDS light + terminal permutation).
        Args:
            state: A 16-element array representing the state
        Returns:
            The permuted state array
        """
        # First apply MDS light permutation
        state = self.mds_light_permutation(state)
        return self.permute_state_terminal(state, self.INITIAL_RC)

    @jit
    def internal_layer_mat_mul(self, state, sum_val):
        """
        Internal layer matrix multiplication.
        Applies (1 + diagonal_mat) multiplication to state[1] through state[15].
        Args:
            state: A 16-element array representing the state
            sum_val: The sum value to use in the matrix multiplication
        Returns:
            The transformed state array
        """
        # Precompute powers of 2 inverses
        # NOTE(batzor): if both are field, it would be inverse of 2 lol
        inv_two = 1 / 2  # Inverse of 2
        inv_four = inv_two ** 2  # Inverse of 4
        inv_eight = inv_two ** 3  # Inverse of 8
        inv_sixteen = inv_two ** 4  # Inverse of 16
        inv_256 = inv_sixteen ** 2  # Inverse of 256 (16^2)
        inv_2_27 = inv_two ** 27  # Inverse of 2^27
        # Extract state[1:16] for vectorized operations
        state_slice = state[1:16]
        # Define multipliers for each element: state[i] = state[i] * multiplier[i] + sum_val
        multipliers = jnp.array([
            1,             # state[1]: state[1] + sum
            2,             # state[2]: state[2] * 2 + sum
            inv_two,       # state[3]: state[3] * inv_two + sum
            3,             # state[4]: state[4] * 3 + sum (2 + 1)
            4,             # state[5]: state[5] * 4 + sum
            -inv_two,      # state[6]: sum - state[6] * inv_two
            -3,            # state[7]: sum - state[7] * 3 (2 + 1)
            -4,            # state[8]: sum - state[8] * 4
            inv_256,       # state[9]: state[9] * inv_256 + sum
            inv_four,      # state[10]: state[10] * inv_four + sum
            inv_eight,     # state[11]: state[11] * inv_eight + sum
            inv_2_27,      # state[12]: state[12] * inv_2_27 + sum
            -inv_256,      # state[13]: sum - state[13] * inv_256
            -inv_sixteen,  # state[14]: sum - state[14] * inv_sixteen
            -inv_2_27,     # state[15]: sum - state[15] * inv_2_27
        ])
        # Vectorized operation: state[1:16] = state[1:16] * multipliers + sum_val
        state_slice = state_slice * multipliers + sum_val
        # The float placeholders promote the slice to float32; cast back explicitly
        # so the int32 state keeps its dtype (instead of an implicit scatter cast)
        state_slice = state_slice.astype(state.dtype)
        # Update state with the transformed slice
        state = state.at[1:16].set(state_slice)
        return state

    @jit
    def permute_state(self, state):
        """
        Internal layer: permutation (add RC to first element, S-box first, internal diffusion).
        Args:
            state: A 16-element array representing the state
        Returns:
            The permuted state array
        """
        # For each internal constant: add RC and S-box to first element, then apply matrix multiplication
        for round_idx in range(13):
            # Get current round constant
            rc_val = self.INTERNAL_RC[round_idx]
            # Add RC and apply S-box to the first element
            new_elem0 = self.add_rc_and_sbox(state[0], rc_val)
            # Sum of state[1..15], and the full sum including the post-S-box first element
            partial_sum = jnp.sum(state[1:])
            total_sum = partial_sum + new_elem0
            # state[0] = partial_sum - state[0], using the post-S-box value
            # (equivalent to total_sum - 2 * new_elem0, i.e. diagonal entry -2)
            state = state.at[0].set(partial_sum - new_elem0)
            # Apply internal layer matrix multiplication to state[1..15]
            state = self.internal_layer_mat_mul(state, total_sum)
        return state

    @jit
    def permute(self, state):
        """
        Permutes the state using the Poseidon2BabyBear16 hash function.
        Args:
            state: A 16-element array representing the state
        Returns:
            The permuted state array
        """
        state = self.permute_state_initial(state)
        state = self.permute_state(state)
        state = self.permute_state_terminal(state, self.TERMINAL_RC)
        return state

    def _tree_flatten(self):
        children = ()  # arrays / dynamic values
        aux_data = {}  # static values
        return (children, aux_data)

    @classmethod
    def _tree_unflatten(cls, aux_data, children):
        return cls(*children, **aux_data)


# Register the class as a pytree so that `self` can be passed through the @jit methods
tree_util.register_pytree_node(Poseidon2BabyBear16,
                               Poseidon2BabyBear16._tree_flatten,
                               Poseidon2BabyBear16._tree_unflatten)


if __name__ == "__main__":
    # Example: permute the state [0, 1, ..., 15]
    state = jnp.arange(16, dtype=jnp.int32)
    print(Poseidon2BabyBear16().permute(state))
```
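
The integer version cannot be checked against real field arithmetic. As a complement, here is a plain-Python sketch of one internal round over the BabyBear field ($p = 15 \cdot 2^{27} + 1$). It assumes the internal matrix is $J + \operatorname{diag}(V)$, with $J$ the all-ones matrix. $V_1, \dots, V_{15}$ are the `multipliers` from the JAX code above, taken as modular inverses instead of floats. $V_0 = -2$ is implied by the `state[0] = partial_sum - state[0]` shortcut. The function names are hypothetical. The check confirms that the shortcut, applied to the post-S-box `state[0]`, matches the dense product.

```python
P = 15 * 2**27 + 1  # BabyBear prime, 2013265921


def inv(x: int) -> int:
    """Modular inverse in the BabyBear field (Python 3.8+)."""
    return pow(x, -1, P)


# Diagonal V of the internal matrix J + diag(V), reduced mod P.
V = [v % P for v in [
    -2, 1, 2, inv(2), 3, 4, -inv(2), -3, -4,
    inv(2**8), inv(4), inv(8), inv(2**27), -inv(2**8), -inv(16), -inv(2**27),
]]


def sbox_first(state, rc):
    """Add the round constant and apply x^7 to state[0] only."""
    s = list(state)
    s[0] = pow((s[0] + rc) % P, 7, P)
    return s


def internal_round_shortcut(state, rc):
    s = sbox_first(state, rc)
    part_sum = sum(s[1:]) % P
    full_sum = (part_sum + s[0]) % P
    first = (part_sum - s[0]) % P  # equals full_sum - 2 * s[0]
    return [first] + [(V[i] * s[i] + full_sum) % P for i in range(1, 16)]


def internal_round_dense(state, rc):
    s = sbox_first(state, rc)
    total = sum(s) % P
    # Row i of (J + diag(V)) @ s is sum(s) + V[i] * s[i].
    return [(total + V[i] * s[i]) % P for i in range(16)]


state = list(range(16))
print(internal_round_shortcut(state, 0x5a8053c0) == internal_round_dense(state, 0x5a8053c0))  # True
```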