## Poseidon2 in MLIR
In hash-based STARK systems, the hash function dominates performance, and Poseidon2 is the most commonly used hash. Speeding up Poseidon2 therefore translates directly into faster hash-based STARKs.
Over the past two weeks, one of our team members implemented [Poseidon2 in MLIR](#MLIR-code), and the result outperformed **Plonky3** on both AMD and Apple Silicon environments:
| CPU | Plonky3 | ZKIR | Ratio (Plonky3 / ZKIR) |
| ----------- | ------- | ------- | ---------------------- |
| AMD 9950X3D | 6.12 ms | 5.70 ms | 1.07 |
| Mac M4 Pro | 7.76 ms | 5.80 ms | 1.34 |
*(The table shows the total time for 10,000 hash operations, i.e. roughly 0.57 µs to 0.78 µs per hash. The benchmark used **BabyBear**; we expect similar results for **KoalaBear**.)*
We also benchmarked with **Packed PrimeField**, where ZKIR is currently slower than Plonky3:
| CPU | Plonky3 | ZKIR | Ratio (Plonky3 / ZKIR) |
| ----------- | ------- | ------- | ---------------------- |
| AMD 9950X3D | 7.44 ms | 9.30 ms | 0.80 |
This slowdown is due to **engineering optimizations** in Plonky3 that we haven’t yet implemented.
For example, consider computing $(x - y)^2$ for $0 \le x, y < P$:
1. $v = P + x - y$
2. $\mathsf{reduce}(v)$
3. $v^2$
4. $\mathsf{reduce}(v^2)$
$$
\mathsf{reduce}(v) =
\begin{cases}
v - P & \text{if } v \ge P \\
v & \text{otherwise}
\end{cases}
$$
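Plonky3 skips the reduction in step (2), along with the $+P$ offset in step (1), and instead performs:
1. $v = x - y$
2. $v^2$
3. $\mathsf{reduce}(v^2)$

Here $v$ lies in $-P < v < P$ after the first step, so $v^2 = (x - y)^2$ still satisfies $0 \le v^2 < P^2$, the same bound that holds when $v$ is reduced first. Omitting the intermediate reduction therefore does not affect correctness.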
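The argument can be checked numerically. The following is a minimal Python sketch, not Plonky3's or ZKIR's code: `P` is the BabyBear modulus taken from the MLIR listing, the function names are hypothetical, and Python's `%` operator stands in for the final full reduction of $v^2$, which a real implementation would perform with a Montgomery-style reduction.

```python
import random

P = 2013265921  # BabyBear modulus, as in the MLIR listing


def reduce_once(v):
    """Conditional subtraction: maps 0 <= v < 2P into [0, P)."""
    return v - P if v >= P else v


def square_diff_reduced(x, y):
    """Reduce the difference into [0, P) before squaring."""
    v = reduce_once(P + x - y)  # 0 <= v < P
    return (v * v) % P          # % stands in for the full reduction of v^2


def square_diff_lazy(x, y):
    """Square the signed difference directly, skipping the intermediate reduction."""
    v = x - y                   # -P < v < P, may be negative
    return (v * v) % P          # v * v is non-negative and below P * P


rng = random.Random(0)
samples = [(rng.randrange(P), rng.randrange(P)) for _ in range(10_000)]
samples += [(0, P - 1), (P - 1, 0), (0, 0), (P - 1, P - 1)]  # boundary cases
all_equal = all(square_diff_reduced(x, y) == square_diff_lazy(x, y) for x, y in samples)
print(all_equal)
```

Both variants agree on every sample because the sign of $v$ disappears once it is squared. The sketch uses arbitrary-precision Python integers, so it says nothing about how a fixed-width implementation represents a negative $v$.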
We’re currently implementing similar optimizations in ZKIR and will share the results soon.
---
## WHIR-p3 in Python
[XLA](https://openxla.org/xla) is a runtime compiler used in deep learning. It defines **[HLO (High-Level Operations)](https://openxla.org/xla/operation_semantics)** and allows compilation from **Python → [Jaxpr](https://docs.jax.dev/en/latest/jaxpr.html) → HLO → [MLIR](https://mlir.llvm.org/) → LLVM IR → binary** via the [JAX](https://docs.jax.dev/en/latest/notebooks/thinking_in_jax.html) library.
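The first stages of this pipeline can be inspected from Python with stock JAX. The sketch below is illustrative only: it uses upstream JAX rather than the ZKX fork, the function `f` is a made-up example on ordinary floats, and the text returned by `lower(...).as_text()` is the MLIR module that XLA goes on to compile.

```python
import jax
import jax.numpy as jnp


def f(x, y):
    # Hypothetical example function: (x - y)^2 on ordinary floats
    d = x - y
    return d * d


x = jnp.float32(3.0)
y = jnp.float32(1.0)

# Python -> Jaxpr: trace the function into JAX's intermediate representation
jaxpr = jax.make_jaxpr(f)(x, y)
print(jaxpr)

# Jaxpr -> HLO/MLIR: lower the jitted function without running it
lowered = jax.jit(f).lower(x, y)
hlo_text = lowered.as_text()
print(hlo_text)

# MLIR -> LLVM IR -> binary happens inside XLA when the module is compiled
compiled = lowered.compile()
result = compiled(x, y)
print(result)
```

Printing the intermediate forms is a convenient way to see which operations a Python function actually turns into before any backend-specific lowering takes place.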
Our team has been **developing a ZK-specific runtime compiler called ZKX** by forking XLA. Using ZKX, we have already achieved **[SOTA Groth16 performance on CPU](https://narrow-cello-dab.notion.site/EN-The-Fastest-Groth16-CPU-Prover-in-the-World-22cc62052b8e80fab91cd531a9ca70a0?pvs=73)**.
The next goal is to **express [WHIR-p3](https://github.com/tcoratger/whir-p3) in Python**. To that end, we are **forking and adapting JAX** for ZKX compatibility.
In parallel, we’re porting **WHIR-p3 (including Poseidon2)** to Python using JAX.
For example, the Rust implementation of [poly/dense.rs](https://github.com/tcoratger/whir-p3/blob/main/src/poly/dense.rs) can be expressed in JAX as [follows](#Python-code).
---
## MLIR code
```mlir
// Poseidon2 utility functions for BabyBear field
// Based on Plonky3 implementation: https://github.com/Plonky3/Plonky3
!pf = !field.pf<2013265921 : i32, true>
!pf_std = !field.pf<2013265921 : i32>
!state = memref<16x!pf>
!state_std = memref<16x!pf_std>
func.func @add_rc_and_sbox(%var: !pf, %c: !pf) -> !pf {
// S-box: (var + c)^7, computed as sum^4 * sum^3 with two squarings and two multiplications
%sum = field.add %var, %c : !pf
%sum_sq = field.square %sum : !pf
%sum_sq_sq = field.square %sum_sq : !pf
%sum_cu = field.mul %sum, %sum_sq : !pf
%sum_exp7 = field.mul %sum_sq_sq, %sum_cu : !pf
return %sum_exp7 : !pf
}
// In-place version of apply_mat4 using memref.
// Ideally this would be a single matmul that lowers to the sequence below,
// but that is hard to achieve at the moment, so the product is spelled out
// with field additions instead.
func.func @apply_mat4(%state: memref<4x!pf, strided<[1], offset: ?>>) {
// The 4x4 MDS matrix being applied. It is kept for reference only; the additions below never read %matrix.
%matrix = arith.constant dense<[
[2, 3, 1, 1],
[1, 2, 3, 1],
[1, 1, 2, 3],
[3, 1, 1, 2]
]> : tensor<4x4xi32>
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%c3 = arith.constant 3 : index
// Load the four inputs and build the partial sums shared by the outputs
%x0 = memref.load %state[%c0] : memref<4x!pf, strided<[1], offset: ?>>
%x1 = memref.load %state[%c1] : memref<4x!pf, strided<[1], offset: ?>>
%x2 = memref.load %state[%c2] : memref<4x!pf, strided<[1], offset: ?>>
%x3 = memref.load %state[%c3] : memref<4x!pf, strided<[1], offset: ?>>
%x01 = field.add %x0, %x1 : !pf
%x23 = field.add %x2, %x3 : !pf
%x0123 = field.add %x01, %x23 : !pf
%x01123 = field.add %x0123, %x1 : !pf
%x01233 = field.add %x0123, %x3 : !pf
%x00 = field.double %x0 : !pf
%x22 = field.double %x2 : !pf
// x[0] = x01123 + x01
%x0_new = field.add %x01123, %x01 : !pf
// x[1] = x01123 + 2*x[2]
%x1_new = field.add %x01123, %x22 : !pf
// x[2] = x01233 + x23
%x2_new = field.add %x01233, %x23 : !pf
// x[3] = x01233 + 2*x[0]
%x3_new = field.add %x01233, %x00 : !pf
// Write the four results back in place
memref.store %x0_new, %state[%c0] : memref<4x!pf, strided<[1], offset: ?>>
memref.store %x1_new, %state[%c1] : memref<4x!pf, strided<[1], offset: ?>>
memref.store %x2_new, %state[%c2] : memref<4x!pf, strided<[1], offset: ?>>
memref.store %x3_new, %state[%c3] : memref<4x!pf, strided<[1], offset: ?>>
return
}
func.func @mds_light_permutation(%state: !state) {
// First, apply M_4 to each consecutive four elements of the state
// This replaces each x_i with x_i'
affine.for %chunk_idx = 0 to 4 {
// Load the chunk at offset chunk_idx * 4 and apply the same sequence as apply_mat4
%x0 = affine.load %state[%chunk_idx * 4] : !state
%x1 = affine.load %state[%chunk_idx * 4 + 1] : !state
%x01 = field.add %x0, %x1 : !pf
%x2 = affine.load %state[%chunk_idx * 4 + 2] : !state
%x3 = affine.load %state[%chunk_idx * 4 + 3] : !state
%x23 = field.add %x2, %x3 : !pf
%x0123 = field.add %x01, %x23 : !pf
%x01123 = field.add %x0123, %x1 : !pf
%x01233 = field.add %x0123, %x3 : !pf
%x00 = field.double %x0 : !pf
%x22 = field.double %x2 : !pf
// x[0] = x01123 + x01
%x0_new = field.add %x01123, %x01 : !pf
// x[1] = x01123 + 2*x[2]
%x1_new = field.add %x01123, %x22 : !pf
// x[2] = x01233 + x23
%x2_new = field.add %x01233, %x23 : !pf
// x[3] = x01233 + 2*x[0]
%x3_new = field.add %x01233, %x00 : !pf
// Write the four results back in place
affine.store %x0_new, %state[%chunk_idx * 4] : !state
affine.store %x1_new, %state[%chunk_idx * 4 + 1] : !state
affine.store %x2_new, %state[%chunk_idx * 4 + 2] : !state
affine.store %x3_new, %state[%chunk_idx * 4 + 3] : !state
}
// Now apply the outer circulant matrix
// Precompute the four sums of every four elements
// Compute sums: sums[k] = sum of state[j + k] for j = 0, 4, 8, 12
%sums = memref.alloca() : memref<4x!pf>
affine.for %k = 0 to 4 {
%val0 = affine.load %state[%k] : !state
%val1 = affine.load %state[%k + 4] : !state
%val2 = affine.load %state[%k + 8] : !state
%val3 = affine.load %state[%k + 12] : !state
%sum01 = field.add %val0, %val1 : !pf
%sum23 = field.add %val2, %val3 : !pf
%new_sum = field.add %sum01, %sum23 : !pf
affine.store %new_sum, %sums[%k] : memref<4x!pf>
}
// Apply the formula: y_i = x_i' + sums[i % 4]
affine.for %i = 0 to 4 {
%val0 = affine.load %state[%i] : !state
%val1 = affine.load %state[%i + 4] : !state
%val2 = affine.load %state[%i + 8] : !state
%val3 = affine.load %state[%i + 12] : !state
%sum = affine.load %sums[%i] : memref<4x!pf>
%sum0 = field.add %val0, %sum : !pf
%sum1 = field.add %val1, %sum : !pf
%sum2 = field.add %val2, %sum : !pf
%sum3 = field.add %val3, %sum : !pf
affine.store %sum0, %state[%i] : !state
affine.store %sum1, %state[%i + 4] : !state
affine.store %sum2, %state[%i + 8] : !state
affine.store %sum3, %state[%i + 12] : !state
}
return
}
// Internal layer matrix multiplication
func.func @internal_layer_mat_mul(%state: !state, %sum: !pf) {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%c3 = arith.constant 3 : index
%c4 = arith.constant 4 : index
%c5 = arith.constant 5 : index
%c6 = arith.constant 6 : index
%c7 = arith.constant 7 : index
%c8 = arith.constant 8 : index
%c9 = arith.constant 9 : index
%c10 = arith.constant 10 : index
%c11 = arith.constant 11 : index
%c12 = arith.constant 12 : index
%c13 = arith.constant 13 : index
%c14 = arith.constant 14 : index
%c15 = arith.constant 15 : index
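// Per-element multipliers V, applied as state[i] = V[i] * state[i] + sum
// (the -2 entry for state[0] is applied by the caller):
// [-2, 1, 2, 1/2, 3, 4, -1/2, -3, -4, 1/2^8, 1/4, 1/8, 1/2^27, -1/2^8, -1/16, -1/2^27]
// The inverse powers of two are hard-coded as 2^(32 - k) mod P, i.e. 1/2^k in
// Montgomery form with R = 2^32 (for example 1/4 -> 2^30 = 1073741824).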
%inv_two = field.constant 134217727 : !pf
%inv_four = field.constant 1073741824 : !pf
%inv_eight = field.constant 536870912 : !pf
%inv_sixteen = field.constant 268435456 : !pf
%inv_256 = field.constant 16777216 : !pf
%inv_2_27 = field.constant 32 : !pf
// state[1] += sum
%s1 = memref.load %state[%c1] : !state
%new_s1 = field.add %s1, %sum : !pf
memref.store %new_s1, %state[%c1] : !state
// state[2] = state[2].double() + sum
%s2 = memref.load %state[%c2] : !state
%s2_double = field.double %s2 : !pf
%new_s2 = field.add %s2_double, %sum : !pf
memref.store %new_s2, %state[%c2] : !state
// state[3] = state[3].halve() + sum
%s3 = memref.load %state[%c3] : !state
%s3_halve = field.mul %s3, %inv_two : !pf
%new_s3 = field.add %s3_halve, %sum : !pf
memref.store %new_s3, %state[%c3] : !state
// state[4] = sum + state[4].double() + state[4]
%s4 = memref.load %state[%c4] : !state
%s4_double = field.double %s4 : !pf
%s4_sum = field.add %s4_double, %s4 : !pf
%new_s4 = field.add %sum, %s4_sum : !pf
memref.store %new_s4, %state[%c4] : !state
// state[5] = sum + state[5].double().double()
%s5 = memref.load %state[%c5] : !state
%s5_double = field.double %s5 : !pf
%s5_double_double = field.double %s5_double : !pf
%new_s5 = field.add %sum, %s5_double_double : !pf
memref.store %new_s5, %state[%c5] : !state
// state[6] = sum - state[6].halve()
%s6 = memref.load %state[%c6] : !state
%s6_halve = field.mul %s6, %inv_two : !pf
%new_s6 = field.sub %sum, %s6_halve : !pf
memref.store %new_s6, %state[%c6] : !state
// state[7] = sum - (state[7].double() + state[7])
%s7 = memref.load %state[%c7] : !state
%s7_double = field.double %s7 : !pf
%s7_sum = field.add %s7_double, %s7 : !pf
%new_s7 = field.sub %sum, %s7_sum : !pf
memref.store %new_s7, %state[%c7] : !state
// state[8] = sum - state[8].double().double()
%s8 = memref.load %state[%c8] : !state
%s8_double = field.double %s8 : !pf
%s8_double_double = field.double %s8_double : !pf
%new_s8 = field.sub %sum, %s8_double_double : !pf
memref.store %new_s8, %state[%c8] : !state
// state[9] = state[9] * inv_256 + sum
%s9 = memref.load %state[%c9] : !state
%s9_div_256 = field.mul %s9, %inv_256 : !pf
%new_s9 = field.add %s9_div_256, %sum : !pf
memref.store %new_s9, %state[%c9] : !state
// state[10] = state[10] * inv_four + sum
%s10 = memref.load %state[%c10] : !state
%s10_div_4 = field.mul %s10, %inv_four : !pf
%new_s10 = field.add %s10_div_4, %sum : !pf
memref.store %new_s10, %state[%c10] : !state
// state[11] = state[11] * inv_eight + sum
%s11 = memref.load %state[%c11] : !state
%s11_div_8 = field.mul %s11, %inv_eight : !pf
%new_s11 = field.add %s11_div_8, %sum : !pf
memref.store %new_s11, %state[%c11] : !state
// state[12] = state[12] * inv_2_27 + sum
%s12 = memref.load %state[%c12] : !state
%s12_div_27 = field.mul %s12, %inv_2_27 : !pf
%new_s12 = field.add %s12_div_27, %sum : !pf
memref.store %new_s12, %state[%c12] : !state
// state[13] = sum - state[13] * inv_256
%s13 = memref.load %state[%c13] : !state
%s13_div_256 = field.mul %s13, %inv_256 : !pf
%new_s13 = field.sub %sum, %s13_div_256 : !pf
memref.store %new_s13, %state[%c13] : !state
// state[14] = sum - state[14] * inv_sixteen
%s14 = memref.load %state[%c14] : !state
%s14_div_16 = field.mul %s14, %inv_sixteen : !pf
%new_s14 = field.sub %sum, %s14_div_16 : !pf
memref.store %new_s14, %state[%c14] : !state
// state[15] = sum - state[15] * inv_2_27
%s15 = memref.load %state[%c15] : !state
%s15_div_27 = field.mul %s15, %inv_2_27 : !pf
%new_s15 = field.sub %sum, %s15_div_27 : !pf
memref.store %new_s15, %state[%c15] : !state
return
}
// Internal layer: 13 partial rounds (add RC to the first element, apply the S-box to it, then internal diffusion)
func.func @permute_state(%state: !state) {
// Index constants for the in-place memref accesses below
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%c3 = arith.constant 3 : index
%c4 = arith.constant 4 : index
%c5 = arith.constant 5 : index
%c6 = arith.constant 6 : index
%c7 = arith.constant 7 : index
%c8 = arith.constant 8 : index
%c9 = arith.constant 9 : index
%c10 = arith.constant 10 : index
%c11 = arith.constant 11 : index
%c12 = arith.constant 12 : index
%c13 = arith.constant 13 : index
%c14 = arith.constant 14 : index
%c15 = arith.constant 15 : index
// BABYBEAR_RC16_INTERNAL (13 scalar constants)
%rc_internal = arith.constant dense<[250494022, 528496384, 1472966118, 977089650, 1885890237, 1094557811, 147492661, 664163003, 398852570, 336233633, 1628648315, 888594966, 586791090]> : tensor<13xi32>
%rc_internal_mont = field.bitcast %rc_internal : tensor<13xi32> -> tensor<13x!pf>
// For each internal constant: add RC and S-box to first element, then apply matrix multiplication
affine.for %round = 0 to 13 {
// Get current round constant via tensor.extract
%rc = tensor.extract %rc_internal_mont[%round] : tensor<13x!pf>
// Add RC and apply S-box to first element
%s0 = memref.load %state[%c0] : !state
%elem0 = func.call @add_rc_and_sbox(%s0, %rc) : (!pf, !pf) -> !pf
// Compute sum of all elements using affine.for
// NOTE: this is extremely slow, so we manually add them.
// %zero = field.constant 0 : !pf
// %sum = affine.for %i = 0 to 16 iter_args(%acc = %zero) -> (!pf) {
// %elem = tensor.extract %t[%i] : tensor<16x!pf>
// %new_acc = field.add %acc, %elem : !pf
// affine.yield %new_acc : !pf
// }
%elem1 = memref.load %state[%c1] : memref<16x!pf>
%elem2 = memref.load %state[%c2] : memref<16x!pf>
%elem3 = memref.load %state[%c3] : memref<16x!pf>
%elem4 = memref.load %state[%c4] : memref<16x!pf>
%elem5 = memref.load %state[%c5] : memref<16x!pf>
%elem6 = memref.load %state[%c6] : memref<16x!pf>
%elem7 = memref.load %state[%c7] : memref<16x!pf>
%elem8 = memref.load %state[%c8] : memref<16x!pf>
%elem9 = memref.load %state[%c9] : memref<16x!pf>
%elem10 = memref.load %state[%c10] : memref<16x!pf>
%elem11 = memref.load %state[%c11] : memref<16x!pf>
%elem12 = memref.load %state[%c12] : memref<16x!pf>
%elem13 = memref.load %state[%c13] : memref<16x!pf>
%elem14 = memref.load %state[%c14] : memref<16x!pf>
%elem15 = memref.load %state[%c15] : memref<16x!pf>
// Sum elements 1..15 with a reduction tree.
// The additions within each level are independent, so the CPU can execute them in parallel.
// Level 1 (7 independent additions)
%sum2_3 = field.add %elem2, %elem3 : !pf
%sum4_5 = field.add %elem4, %elem5 : !pf
%sum6_7 = field.add %elem6, %elem7 : !pf
%sum8_9 = field.add %elem8, %elem9 : !pf
%sum10_11 = field.add %elem10, %elem11 : !pf
%sum12_13 = field.add %elem12, %elem13 : !pf
%sum14_15 = field.add %elem14, %elem15 : !pf
// Level 2 (4 parallel additions)
%sum1_3 = field.add %elem1, %sum2_3 : !pf
%sum4_7 = field.add %sum4_5, %sum6_7 : !pf
%sum8_11 = field.add %sum8_9, %sum10_11 : !pf
%sum12_15 = field.add %sum12_13, %sum14_15 : !pf
// Level 3 (2 parallel additions)
%sum1_7 = field.add %sum1_3, %sum4_7 : !pf
%sum8_15 = field.add %sum8_11, %sum12_15 : !pf
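// Level 4: partial_sum = state[1] + ... + state[15]
%partial_sum = field.add %sum1_7, %sum8_15 : !pf
// total_sum also includes the S-boxed first element
%total_sum = field.add %partial_sum, %elem0 : !pf
// state[0] has multiplier -2: total_sum - 2 * elem0 = partial_sum - elem0
%new_s0 = field.sub %partial_sum, %elem0 : !pf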
memref.store %new_s0, %state[%c0] : !state
// Apply internal layer matrix multiplication
func.call @internal_layer_mat_mul(%state, %total_sum) : (!state, !pf) -> ()
}
return
}
// External layer: terminal permutation (4 rounds: add RC, S-box, MDS)
func.func @permute_state_terminal(%state: !state) {
%c0 = arith.constant 0 : index
%c4 = arith.constant 4 : index
%c1 = arith.constant 1 : index
// BABYBEAR_RC16_EXTERNAL_FINAL (4 rounds x 16 constants)
%rc_external_const = arith.constant dense<[
[999830298, 304461056, 552699684, 450698925, 667466464, 1736509752, 1327760865, 1153241151, 816675655, 1076172858, 1914832527, 1668723429, 1365579850, 975704528, 1031625628, 1393317533],
[1554700828, 1023828605, 1610378860, 347744760, 1909572073, 739227895, 428565985, 633143046, 121797685, 94048546, 1369350241, 1250010422, 114268841, 515033604, 49052844, 1962329907],
[1380892638, 1860017417, 64711457, 9758460, 1681838395, 710850601, 1020228997, 1414164790, 1531515535, 36158805, 713604525, 89935127, 1870801994, 395985906, 1122769045, 1760811055],
[819787042, 134654834, 1755145179, 18433016, 1701878989, 1782339297, 1483861396, 962480061, 1857590724, 222440409, 63223417, 515206622, 1348364213, 973414686, 1591066884, 705852913]
]> : tensor<4x16xi32>
%rc_external_final = field.bitcast %rc_external_const : tensor<4x16xi32> -> tensor<4x16x!pf>
%state_tensor = bufferization.to_tensor %state restrict : memref<16x!pf> to tensor<16x!pf>
// Loop through 4 rounds of external terminal permutation
affine.for %round = 0 to 4 {
affine.for %i = 0 to 16 {
%s = tensor.extract %state_tensor[%i] : tensor<16x!pf>
%c = tensor.extract %rc_external_final[%round, %i] : tensor<4x16x!pf>
%sbox = func.call @add_rc_and_sbox(%s, %c) : (!pf, !pf) -> !pf
affine.store %sbox, %state[%i] : !state
}
// Apply MDS light permutation (in-place)
func.call @mds_light_permutation(%state) : (!state) -> ()
}
return
}
// External layer: initial permutation (MDS light, then 4 rounds: add RC, S-box, MDS light)
func.func @permute_state_initial(%state: !state) {
// First apply MDS light permutation
func.call @mds_light_permutation(%state) : (!state) -> ()
// Round constants for 16-width Poseidon2 on BabyBear
// BABYBEAR_RC16_EXTERNAL_INITIAL (4 rounds x 16 constants)
%rc_external_const = arith.constant dense<[
[1582131512, 1899519471, 1641921850, 462688640, 1293997949, 1380417575, 1932416963, 283521298, 1016708647, 35751290, 1270782647, 851730739, 795004022, 929571430, 523703523, 1593957757],
[895976710, 1742343460, 917700746, 1516725708, 1170237629, 785693164, 613651155, 352999196, 678775274, 1005433272, 1704854670, 1174551920, 508930349, 530338447, 1327158816, 1417652352],
[1153538870, 583201050, 397833841, 1440603828, 454600685, 174490638, 171758601, 1998476616, 1403697810, 1807736944, 450348306, 1458895865, 787037868, 1063762964, 1987002214, 481645916],
[1231767638, 1323639433, 238360103, 2012412459, 1024945356, 1108359895, 1284135849, 606928406, 1021455954, 719347978, 659671051, 769588663, 805534062, 592213995, 1752728055, 663410947]
]> : tensor<4x16xi32>
%rc_external_initial = field.bitcast %rc_external_const : tensor<4x16xi32> -> tensor<4x16x!pf>
%state_tensor = bufferization.to_tensor %state restrict : memref<16x!pf> to tensor<16x!pf>
// Then run the same round structure as the terminal permutation, using the initial external constants
// Loop through 4 rounds of external initial permutation
affine.for %round = 0 to 4 {
affine.for %i = 0 to 16 {
%s = tensor.extract %state_tensor[%i] : tensor<16x!pf>
%c = tensor.extract %rc_external_initial[%round, %i] : tensor<4x16x!pf>
%sbox = func.call @add_rc_and_sbox(%s, %c) : (!pf, !pf) -> !pf
affine.store %sbox, %state[%i] : !state
}
// Apply MDS light permutation (in-place)
func.call @mds_light_permutation(%state) : (!state) -> ()
}
return
}
// Complete Poseidon2 permutation
func.func @poseidon2_permute(%state: !state) {
func.call @permute_state_initial(%state) : (!state) -> ()
func.call @permute_state(%state) : (!state) -> ()
func.call @permute_state_terminal(%state) : (!state) -> ()
return
}
func.func @permute_10000(%state : !state) attributes { llvm.emit_c_interface } {
affine.for %i = 0 to 10000 {
func.call @poseidon2_permute(%state) : (!state) -> ()
}
return
}
```
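Two details of the listing above are easy to check outside MLIR. The following Python sketch is a reference model written for illustration, not part of ZKIR: `apply_mat4` mirrors the MLIR function of the same name, and every other name is hypothetical. It verifies that the addition sequence in `@apply_mat4` equals multiplication by the 4x4 matrix declared there, and that the literals in `@internal_layer_mat_mul` equal $2^{32-k} \bmod P$, which is $2^{-k}$ in Montgomery form with $R = 2^{32}$. That the `true` flag on `!pf` selects this representation is an assumption inferred from those values.

```python
import random

P = 2013265921  # BabyBear modulus from the listing
M4 = [[2, 3, 1, 1],
      [1, 2, 3, 1],
      [1, 1, 2, 3],
      [3, 1, 1, 2]]


def apply_mat4(x):
    """Mirror of the addition sequence used in @apply_mat4."""
    x0, x1, x2, x3 = x
    x01 = (x0 + x1) % P
    x23 = (x2 + x3) % P
    x0123 = (x01 + x23) % P
    x01123 = (x0123 + x1) % P
    x01233 = (x0123 + x3) % P
    return [
        (x01123 + x01) % P,     # 2*x0 + 3*x1 +   x2 +   x3
        (x01123 + 2 * x2) % P,  #   x0 + 2*x1 + 3*x2 +   x3
        (x01233 + x23) % P,     #   x0 +   x1 + 2*x2 + 3*x3
        (x01233 + 2 * x0) % P,  # 3*x0 +   x1 +   x2 + 2*x3
    ]


def matvec(m, x):
    """Plain matrix-vector product over the field."""
    return [sum(a * b for a, b in zip(row, x)) % P for row in m]


rng = random.Random(1)
vec = [rng.randrange(P) for _ in range(4)]
mat4_ok = apply_mat4(vec) == matvec(M4, vec)

# Literals from @internal_layer_mat_mul, keyed by k where the value encodes 1 / 2^k
LITERALS = {1: 134217727, 2: 1073741824, 3: 536870912,
            4: 268435456, 8: 16777216, 27: 32}
R = 1 << 32  # assumed Montgomery radix
mont_ok = all((pow(2, -k, P) * R) % P == lit for k, lit in LITERALS.items())
print(mat4_ok, mont_ok)
```

A model like this is cheap to keep next to the MLIR source and catches transcription mistakes in constants or addition chains before any benchmark is run.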
---
## Python code
```python
# Python (JAX) port of https://github.com/tcoratger/whir-p3/blob/main/src/poly/dense.rs
import jax
import jax.random as rnd
import jax.numpy as jnp
import jax.lax as lax


### is_zero
@jax.jit
def is_zero(poly):
    # A zero polynomial is all zeros, or an empty array
    poly = jnp.array(poly)
    return (poly.size == 0) | jnp.all(poly == 0)


### evaluate - use as is
@jax.jit
def evaluate(poly, x):
    # Coefficients are ordered from highest degree to lowest (jnp.polyval convention)
    poly = jnp.array(poly)
    return jnp.polyval(poly, x)


### random
@jax.jit
def random(key, poly):
    # Draw uniform floats in [0, 999) and round to the nearest integer, so the
    # coefficients are integer-valued but keep a float dtype
    # (rnd.randint would produce an integer dtype instead)
    poly = jnp.array(poly)
    rand = rnd.uniform(key, poly.shape, minval=0., maxval=999)
    return jnp.round(rand)


### lagrange_interpolation
@jax.jit
def has_dup_x_jit(xs):
    xs_sorted = jnp.sort(xs)
    return jnp.any(xs_sorted[1:] == xs_sorted[:-1])


@jax.jit
def lagrange_interpolation(values):
    size = len(values)
    if size == 0:
        return jnp.zeros(0)
    # Unzip the list of (x, y) pairs to separate arrays
    xs, ys = zip(*values)
    xs = jnp.array(xs)
    ys = jnp.array(ys)

    # Duplicate x-coordinates: both lax.cond branches must return the same
    # shape, so an all-zero polynomial is used as a stand-in for None
    def dup_case(_):
        return jnp.zeros(size)

    def no_dup_case(_):
        # Incremental interpolation: add one point per iteration
        monomial_init = jnp.zeros(size)

        def body(i, acc):
            result_poly, basis_poly = acc
            current_y = jnp.polyval(result_poly, xs[i])
            delta = ys[i] - current_y
            basis_eval = jnp.polyval(basis_poly, xs[i])
            c_i = delta / basis_eval
            # Scale the basis polynomial element-wise and add the correction term
            term = basis_poly * c_i
            result_poly = result_poly + term
            # Build the monomial (x - x_i) as a length-`size` coefficient array
            monomial = monomial_init.at[size - 1].set(-xs[i])
            monomial_all = monomial.at[size - 2].set(1)
            # Extend the basis: B(x) <- B(x) * (x - x_i), keeping `size` coefficients
            basis_poly = jnp.polymul(basis_poly, monomial_all)
            basis_poly = basis_poly[-size:]
            return (result_poly, basis_poly)

        # The result polynomial P(x) starts at zero and is updated iteratively.
        zero_poly = jnp.zeros(size)
        # The basis polynomial B(x) starts at 1.
        basis_poly = zero_poly.at[size - 1].set(1)
        result_poly, basis_poly = jax.lax.fori_loop(0, size, body, (zero_poly, basis_poly))
        return result_poly

    return lax.cond(has_dup_x_jit(xs), dup_case, no_dup_case, operand=None)


### add - use as is
@jax.jit
def add(a, b):
    jax_array_a = jnp.array(a)
    jax_array_b = jnp.array(b)
    return jnp.polyadd(jax_array_a, jax_array_b)


### mul - use as is
@jax.jit
def mul(a, b):
    jax_array_a = jnp.array(a)
    jax_array_b = jnp.array(b)
    return jnp.polymul(jax_array_a, jax_array_b)
```
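The functions above operate on floating-point arrays, whereas the rest of this post targets finite fields such as BabyBear. As a rough illustration of what a field-aware `evaluate` could look like in stock JAX, the sketch below runs Horner's rule modulo the BabyBear prime using `uint64` arithmetic. It is a stand-in under stated assumptions: it does not use ZKX or any dedicated field dtype, `evaluate_mod_p` is a hypothetical name, and it relies on JAX's 64-bit mode so that the product of two 31-bit values does not overflow.

```python
import jax

# 64-bit mode must be enabled before any arrays are created
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp
from jax import lax

P = 2013265921  # BabyBear modulus


@jax.jit
def evaluate_mod_p(coeffs, x):
    """Evaluate a polynomial at x modulo P.

    Coefficients are ordered from highest degree to lowest, matching
    jnp.polyval, and are assumed to already lie in [0, P).
    """
    coeffs = jnp.asarray(coeffs, dtype=jnp.uint64)
    x = jnp.asarray(x, dtype=jnp.uint64)

    def step(acc, c):
        # acc, x, c < 2^31, so acc * x + c < 2^63 and fits in uint64
        return (acc * x + c) % P, None

    acc, _ = lax.scan(step, jnp.uint64(0), coeffs)
    return acc


# x^2 + 2x + 3 at x = 5
small = evaluate_mod_p(jnp.array([1, 2, 3]), 5)
# (P - 1) * x + (P - 1) at x = P - 1, which is (-1)(-1) + (-1) = 0 mod P
wrapped = evaluate_mod_p(jnp.array([P - 1, P - 1]), P - 1)
print(int(small), int(wrapped))
```

Reducing after every Horner step keeps all intermediates within 64 bits. It is also the kind of place where the lazy-reduction idea from the Poseidon2 section could later save work.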