## Poseidon2 in MLIR

In hash-based STARK systems, hashing dominates prover time, and Poseidon2 is the most commonly used hash. Improving Poseidon2 performance therefore translates directly into faster hash-based STARKs.

Over the past two weeks, one of our team members implemented [Poseidon2 in MLIR](#MLIR-code), and the result outperformed **Plonky3** on both AMD and Apple Silicon:

|             | Plonky3 | ZKIR    | Ratio (Plonky3 / ZKIR) |
| ----------- | ------- | ------- | ---------------------- |
| AMD 9950X3d | 6.12 ms | 5.70 ms | 1.07                   |
| Mac M4 Pro  | 7.76 ms | 5.80 ms | 1.34                   |

*(The table shows the total time for 10,000 hash operations. The benchmark used **BabyBear**, but similar results are expected for **KoalaBear**.)*

We also benchmarked with **Packed PrimeField**, where ZKIR is currently slower:

|             | Plonky3 | ZKIR    | Ratio (Plonky3 / ZKIR) |
| ----------- | ------- | ------- | ---------------------- |
| AMD 9950X3d | 7.44 ms | 9.30 ms | 0.80                   |

This slowdown is due to **engineering optimizations** in Plonky3 that we haven’t yet implemented. For example, consider computing $(x - y)^2$ for $0 \le x, y < P$:

1. $v = P + x - y$
2. $\mathsf{reduce}(v)$
3. $v^2$
4. $\mathsf{reduce}(v^2)$

In step (2), $\mathsf{reduce}$ is a single conditional subtraction, which suffices because $0 < v < 2P$ there:

$$
\mathsf{reduce}(v) =
\begin{cases}
v - P & \text{if } v \ge P \\
v & \text{otherwise}
\end{cases}
$$

In step (4), it stands for the full reduction of $v^2$ modulo $P$.

Plonky3 skips step (2) and instead performs:

1. $v = x - y$
3. $v^2$
4. $\mathsf{reduce}(v^2)$

Since $v$ after step (1) lies within $-P < v < P$, omitting the first reduction does not affect correctness: $v^2$ is still congruent to $(x - y)^2$ modulo $P$ and still satisfies $0 \le v^2 < P^2$, so the final reduction receives an input in the range it expects.

We’re currently implementing similar optimizations in ZKIR and will share the results soon.

---

## WHIR-p3 in Python

[XLA](https://openxla.org/xla) is a runtime compiler used in deep learning.
It defines **[HLO (High-Level Operations)](https://openxla.org/xla/operation_semantics)** and allows compilation from **Python → [Jaxpr](https://docs.jax.dev/en/latest/jaxpr.html) → HLO → [MLIR](https://mlir.llvm.org/) → LLVM IR → binary** via the [JAX](https://docs.jax.dev/en/latest/notebooks/thinking_in_jax.html) library.

Our team has been **developing a ZK-specific runtime compiler called ZKX** by forking XLA. Using ZKX, we have already achieved **[SOTA Groth16 performance on CPU](https://narrow-cello-dab.notion.site/EN-The-Fastest-Groth16-CPU-Prover-in-the-World-22cc62052b8e80fab91cd531a9ca70a0?pvs=73)**.

The next goal is to **express [WHIR-p3](https://github.com/tcoratger/whir-p3) in Python**. To that end, we are **forking and adapting JAX** for ZKX compatibility. In parallel, we’re porting **WHIR-p3 (including Poseidon2)** to Python using JAX. For example, the Rust implementation of [poly/dense.rs](https://github.com/tcoratger/whir-p3/blob/main/src/poly/dense.rs) can be expressed in JAX as [follows](#Python-code).

---

## MLIR code

```mlir
// Poseidon2 utility functions for BabyBear field
// Based on Plonky3 implementation: https://github.com/Plonky3/Plonky3

!pf = !field.pf<2013265921 : i32, true>
!pf_std = !field.pf<2013265921 : i32>
!state = memref<16x!pf>
!state_std = memref<16x!pf_std>

// Add the round constant, then apply the S-box x -> x^7 (as x^4 * x^3)
func.func @add_rc_and_sbox(%var: !pf, %c: !pf) -> !pf {
  %c7 = arith.constant 7 : i32
  %sum = field.add %var, %c : !pf
  %sum_sq = field.square %sum : !pf
  %sum_sq_sq = field.square %sum_sq : !pf
  %sum_cu = field.mul %sum, %sum_sq : !pf
  %sum_exp7 = field.mul %sum_sq_sq, %sum_cu : !pf
  return %sum_exp7 : !pf
}

// In-place version of apply_mat4 using memref
// Optimally, we just want to do matmul which then lowers to the following
// sequence but at this moment, it seems hard to achieve. Therefore, we just use
// field addition instead of matrix multiplication.
func.func @apply_mat4(%state: memref<4x!pf, strided<[1], offset: ?>>) {
  // 4x4 MDS matrix, kept for reference only: the addition chain below
  // computes the same product without a matmul.
  %matrix = arith.constant dense<[
    [2, 3, 1, 1],
    [1, 2, 3, 1],
    [1, 1, 2, 3],
    [3, 1, 1, 2]
  ]> : tensor<4x4xi32>

  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c2 = arith.constant 2 : index
  %c3 = arith.constant 3 : index

  // Temporary output buffer (not used by the in-place addition chain below)
  %output = memref.alloca() : memref<4x!pf>

  // Load the four elements and build the shared partial sums
  %x0 = memref.load %state[%c0] : memref<4x!pf, strided<[1], offset: ?>>
  %x1 = memref.load %state[%c1] : memref<4x!pf, strided<[1], offset: ?>>
  %x2 = memref.load %state[%c2] : memref<4x!pf, strided<[1], offset: ?>>
  %x3 = memref.load %state[%c3] : memref<4x!pf, strided<[1], offset: ?>>
  %x01 = field.add %x0, %x1 : !pf
  %x23 = field.add %x2, %x3 : !pf
  %x0123 = field.add %x01, %x23 : !pf
  %x01123 = field.add %x0123, %x1 : !pf
  %x01233 = field.add %x0123, %x3 : !pf
  %x00 = field.double %x0 : !pf
  %x22 = field.double %x2 : !pf

  // x[0] = x01123 + x01
  %x0_new = field.add %x01123, %x01 : !pf
  // x[1] = x01123 + 2*x[2]
  %x1_new = field.add %x01123, %x22 : !pf
  // x[2] = x01233 + x23
  %x2_new = field.add %x01233, %x23 : !pf
  // x[3] = x01233 + 2*x[0]
  %x3_new = field.add %x01233, %x00 : !pf

  // Write the results back in place
  memref.store %x0_new, %state[%c0] : memref<4x!pf, strided<[1], offset: ?>>
  memref.store %x1_new, %state[%c1] : memref<4x!pf, strided<[1], offset: ?>>
  memref.store %x2_new, %state[%c2] : memref<4x!pf, strided<[1], offset: ?>>
  memref.store %x3_new, %state[%c3] : memref<4x!pf, strided<[1], offset: ?>>
  return
}

func.func @mds_light_permutation(%state: !state) {
  // First, apply M_4 to each consecutive four elements of the state
  // This replaces each x_i with x_i'
  affine.for %chunk_idx = 0 to 4 {
    // Load the chunk at offset chunk_idx * 4 and build the shared partial sums
    %x0 = affine.load %state[%chunk_idx * 4] : !state
    %x1 = affine.load %state[%chunk_idx * 4 + 1] : !state
    %x01 = field.add %x0, %x1 : !pf
    %x2 = affine.load %state[%chunk_idx * 4 + 2] : !state
    %x3 = affine.load %state[%chunk_idx * 4 + 3] : !state
    %x23 = field.add %x2, %x3 : !pf
    %x0123 = field.add %x01, %x23 : !pf
    %x01123 = field.add %x0123, %x1 : !pf
    %x01233 = field.add %x0123, %x3 : !pf
    %x00 = field.double %x0 : !pf
    %x22 = field.double %x2 : !pf

    // x[0] = x01123 + x01
    %x0_new = field.add %x01123, %x01 : !pf
    // x[1] = x01123 + 2*x[2]
    %x1_new = field.add %x01123, %x22 : !pf
    // x[2] = x01233 + x23
    %x2_new = field.add %x01233, %x23 : !pf
    // x[3] = x01233 + 2*x[0]
    %x3_new = field.add %x01233, %x00 : !pf

    // Write the results back in place
    affine.store %x0_new, %state[%chunk_idx * 4] : !state
    affine.store %x1_new, %state[%chunk_idx * 4 + 1] : !state
    affine.store %x2_new, %state[%chunk_idx * 4 + 2] : !state
    affine.store %x3_new, %state[%chunk_idx * 4 + 3] : !state
  }

  // Now apply the outer circulant matrix
  // Precompute the four sums of every four elements
  // Compute sums: sums[k] = sum of state[j + k] for j = 0, 4, 8, 12
  %sums = memref.alloca() : memref<4x!pf>
  affine.for %k = 0 to 4 {
    %val0 = affine.load %state[%k] : !state
    %val1 = affine.load %state[%k + 4] : !state
    %val2 = affine.load %state[%k + 8] : !state
    %val3 = affine.load %state[%k + 12] : !state
    %sum01 = field.add %val0, %val1 : !pf
    %sum23 = field.add %val2, %val3 : !pf
    %new_sum = field.add %sum01, %sum23 : !pf
    affine.store %new_sum, %sums[%k] : memref<4x!pf>
  }

  // Apply the formula: y_i = x_i' + sums[i % 4]
  affine.for %i = 0 to 4 {
    %val0 = affine.load %state[%i] : !state
    %val1 = affine.load %state[%i + 4] : !state
    %val2 = affine.load %state[%i + 8] : !state
    %val3 = affine.load %state[%i + 12] : !state
    %sum = affine.load %sums[%i] : memref<4x!pf>
    %sum0 = field.add %val0, %sum : !pf
    %sum1 = field.add %val1, %sum : !pf
    %sum2 = field.add %val2, %sum : !pf
    %sum3 = field.add %val3, %sum : !pf
    affine.store %sum0, %state[%i] : !state
    affine.store %sum1, %state[%i + 4] : !state
    affine.store %sum2, %state[%i + 8] : !state
    affine.store %sum3, %state[%i + 12] : !state
  }
  return
}

// Internal layer matrix multiplication
func.func @internal_layer_mat_mul(%state: !state, %sum: !pf) {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c2 = arith.constant 2 : index
  %c3 = arith.constant 3 : index
  %c4 = arith.constant 4 : index
  %c5 = arith.constant 5 : index
  %c6 = arith.constant 6 : index
  %c7 = arith.constant 7 : index
  %c8 = arith.constant 8 : index
  %c9 = arith.constant 9 : index
  %c10 = arith.constant 10 : index
  %c11 = arith.constant 11 : index
  %c12 = arith.constant 12 : index
  %c13 = arith.constant 13 : index
  %c14 = arith.constant 14 : index
  %c15 = arith.constant 15 : index

  // Precompute powers of 2 inverses using powui
  // [-2, 1, 2, 1/2, 3, 4, -1/2, -3, -4, 1/2^8, 1/4, 1/8, 1/2^27, -1/2^8, -1/16, -1/2^27]
  // Each literal below equals 2^(32-k) mod P, i.e. 1/2^k scaled by 2^32,
  // which is what the Montgomery form of 1/2^k would be for R = 2^32.
  %inv_two = field.constant 134217727 : !pf
  %inv_four = field.constant 1073741824 : !pf
  %inv_eight = field.constant 536870912 : !pf
  %inv_sixteen = field.constant 268435456 : !pf
  %inv_256 = field.constant 16777216 : !pf
  %inv_2_27 = field.constant 32 : !pf

  // state[1] += sum
  %s1 = memref.load %state[%c1] : !state
  %new_s1 = field.add %s1, %sum : !pf
  memref.store %new_s1, %state[%c1] : !state

  // state[2] = state[2].double() + sum
  %s2 = memref.load %state[%c2] : !state
  %s2_double = field.double %s2 : !pf
  %new_s2 = field.add %s2_double, %sum : !pf
  memref.store %new_s2, %state[%c2] : !state

  // state[3] = state[3].halve() + sum
  %s3 = memref.load %state[%c3] : !state
  %s3_halve = field.mul %s3, %inv_two : !pf
  %new_s3 = field.add %s3_halve, %sum : !pf
  memref.store %new_s3, %state[%c3] : !state

  // state[4] = sum + state[4].double() + state[4]
  %s4 = memref.load %state[%c4] : !state
  %s4_double = field.double %s4 : !pf
  %s4_sum = field.add %s4_double, %s4 : !pf
  %new_s4 = field.add %sum, %s4_sum : !pf
  memref.store %new_s4, %state[%c4] : !state

  // state[5] = sum + state[5].double().double()
  %s5 = memref.load %state[%c5] : !state
  %s5_double = field.double %s5 : !pf
  %s5_double_double = field.double %s5_double : !pf
  %new_s5 = field.add %sum, %s5_double_double : !pf
  memref.store %new_s5, %state[%c5] : !state

  // state[6] = sum - state[6].halve()
  %s6 = memref.load %state[%c6] : !state
  %s6_halve = field.mul %s6, %inv_two : !pf
  %new_s6 = field.sub %sum, %s6_halve : !pf
  memref.store %new_s6, %state[%c6] : !state

  // state[7] = sum - (state[7].double() + state[7])
  %s7 = memref.load %state[%c7] : !state
  %s7_double = field.double %s7 : !pf
  %s7_sum = field.add %s7_double, %s7 : !pf
  %new_s7 = field.sub %sum, %s7_sum : !pf
  memref.store %new_s7, %state[%c7] : !state

  // state[8] = sum - state[8].double().double()
  %s8 = memref.load %state[%c8] : !state
  %s8_double = field.double %s8 : !pf
  %s8_double_double = field.double %s8_double : !pf
  %new_s8 = field.sub %sum, %s8_double_double : !pf
  memref.store %new_s8, %state[%c8] : !state

  // state[9] = state[9] * inv_256 + sum
  %s9 = memref.load %state[%c9] : !state
  %s9_div_256 = field.mul %s9, %inv_256 : !pf
  %new_s9 = field.add %s9_div_256, %sum : !pf
  memref.store %new_s9, %state[%c9] : !state

  // state[10] = state[10] * inv_four + sum
  %s10 = memref.load %state[%c10] : !state
  %s10_div_4 = field.mul %s10, %inv_four : !pf
  %new_s10 = field.add %s10_div_4, %sum : !pf
  memref.store %new_s10, %state[%c10] : !state

  // state[11] = state[11] * inv_eight + sum
  %s11 = memref.load %state[%c11] : !state
  %s11_div_8 = field.mul %s11, %inv_eight : !pf
  %new_s11 = field.add %s11_div_8, %sum : !pf
  memref.store %new_s11, %state[%c11] : !state

  // state[12] = state[12] * inv_2_27 + sum
  %s12 = memref.load %state[%c12] : !state
  %s12_div_27 = field.mul %s12, %inv_2_27 : !pf
  %new_s12 = field.add %s12_div_27, %sum : !pf
  memref.store %new_s12, %state[%c12] : !state

  // state[13] = sum - state[13] * inv_256
  %s13 = memref.load %state[%c13] : !state
  %s13_div_256 = field.mul %s13, %inv_256 : !pf
  %new_s13 = field.sub %sum, %s13_div_256 : !pf
  memref.store %new_s13, %state[%c13] : !state

  // state[14] = sum - state[14] * inv_sixteen
  %s14 = memref.load %state[%c14] : !state
  %s14_div_16 = field.mul %s14, %inv_sixteen : !pf
  %new_s14 = field.sub %sum, %s14_div_16 : !pf
  memref.store %new_s14, %state[%c14] : !state

  // state[15] = sum - state[15] * inv_2_27
  %s15 = memref.load %state[%c15] : !state
  %s15_div_27 = field.mul %s15, %inv_2_27 : !pf
  %new_s15 = field.sub %sum, %s15_div_27 : !pf
  memref.store %new_s15, %state[%c15] : !state
  return
}

// Internal layer: permutation (add RC to first element, S-box first, internal diffusion)
func.func @permute_state(%state: !state) {
  // Convert to memref for in-place operations
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c2 = arith.constant 2 : index
  %c3 = arith.constant 3 : index
  %c4 = arith.constant 4 : index
  %c5 = arith.constant 5 : index
  %c6 = arith.constant 6 : index
  %c7 = arith.constant 7 : index
  %c8 = arith.constant 8 : index
  %c9 = arith.constant 9 : index
  %c10 = arith.constant 10 : index
  %c11 = arith.constant 11 : index
  %c12 = arith.constant 12 : index
  %c13 = arith.constant 13 : index
  %c14 = arith.constant 14 : index
  %c15 = arith.constant 15 : index

  // BABYBEAR_RC16_INTERNAL (13 scalar constants)
  %rc_internal = arith.constant dense<[250494022, 528496384, 1472966118, 977089650, 1885890237, 1094557811, 147492661, 664163003, 398852570, 336233633, 1628648315, 888594966, 586791090]> : tensor<13xi32>
  %rc_internal_mont = field.bitcast %rc_internal : tensor<13xi32> -> tensor<13x!pf>

  // For each internal constant: add RC and S-box to first element, then apply matrix multiplication
  affine.for %round = 0 to 13 {
    // Get current round constant via tensor.extract
    %rc = tensor.extract %rc_internal_mont[%round] : tensor<13x!pf>

    // Add RC and apply S-box to first element
    %s0 = memref.load %state[%c0] : !state
    %elem0 = func.call @add_rc_and_sbox(%s0, %rc) : (!pf, !pf) -> !pf

    // Compute sum of all elements using affine.for
    // NOTE: this is extremely slow, so we manually add them.
    // %zero = field.constant 0 : !pf
    // %sum = affine.for %i = 0 to 16 iter_args(%acc = %zero) -> (!pf) {
    //   %elem = tensor.extract %t[%i] : tensor<16x!pf>
    //   %new_acc = field.add %acc, %elem : !pf
    //   affine.yield %new_acc : !pf
    // }
    %elem1 = memref.load %state[%c1] : memref<16x!pf>
    %elem2 = memref.load %state[%c2] : memref<16x!pf>
    %elem3 = memref.load %state[%c3] : memref<16x!pf>
    %elem4 = memref.load %state[%c4] : memref<16x!pf>
    %elem5 = memref.load %state[%c5] : memref<16x!pf>
    %elem6 = memref.load %state[%c6] : memref<16x!pf>
    %elem7 = memref.load %state[%c7] : memref<16x!pf>
    %elem8 = memref.load %state[%c8] : memref<16x!pf>
    %elem9 = memref.load %state[%c9] : memref<16x!pf>
    %elem10 = memref.load %state[%c10] : memref<16x!pf>
    %elem11 = memref.load %state[%c11] : memref<16x!pf>
    %elem12 = memref.load %state[%c12] : memref<16x!pf>
    %elem13 = memref.load %state[%c13] : memref<16x!pf>
    %elem14 = memref.load %state[%c14] : memref<16x!pf>
    %elem15 = memref.load %state[%c15] : memref<16x!pf>

    // --- Step 2: Sum the elements using a reduction tree ---
    // This structure allows for maximum parallel execution by the CPU.
    // Level 1 (7 parallel additions; elem1 joins at level 2, elem0 at the end)
    %sum2_3 = field.add %elem2, %elem3 : !pf
    %sum4_5 = field.add %elem4, %elem5 : !pf
    %sum6_7 = field.add %elem6, %elem7 : !pf
    %sum8_9 = field.add %elem8, %elem9 : !pf
    %sum10_11 = field.add %elem10, %elem11 : !pf
    %sum12_13 = field.add %elem12, %elem13 : !pf
    %sum14_15 = field.add %elem14, %elem15 : !pf

    // Level 2 (4 parallel additions)
    %sum1_3 = field.add %elem1, %sum2_3 : !pf
    %sum4_7 = field.add %sum4_5, %sum6_7 : !pf
    %sum8_11 = field.add %sum8_9, %sum10_11 : !pf
    %sum12_15 = field.add %sum12_13, %sum14_15 : !pf

    // Level 3 (2 parallel additions)
    %sum1_7 = field.add %sum1_3, %sum4_7 : !pf
    %sum8_15 = field.add %sum8_11, %sum12_15 : !pf

    // Level 4 (Partial sum)
    %partial_sum = field.add %sum1_7, %sum8_15 : !pf
    %total_sum = field.add %partial_sum, %elem0 : !pf

    // state[0] has diagonal entry -2: total_sum - 2 * elem0 = partial_sum - elem0
    %new_s0 = field.sub %partial_sum, %elem0 : !pf
    memref.store %new_s0, %state[%c0] : !state

    // Apply internal layer matrix multiplication
    func.call @internal_layer_mat_mul(%state, %total_sum) : (!state, !pf) -> ()
  }
  return
}

// External layer: terminal permutation (4 rounds: add RC, S-box, MDS)
func.func @permute_state_terminal(%state: !state) {
  %c0 = arith.constant 0 : index
  %c4 = arith.constant 4 : index
  %c1 = arith.constant 1 : index

  // BABYBEAR_RC16_EXTERNAL_FINAL (4 rounds x 16 constants)
  %rc_external_const = arith.constant dense<[
    [999830298, 304461056, 552699684, 450698925, 667466464, 1736509752, 1327760865, 1153241151, 816675655, 1076172858, 1914832527, 1668723429, 1365579850, 975704528, 1031625628, 1393317533],
    [1554700828, 1023828605, 1610378860, 347744760, 1909572073, 739227895, 428565985, 633143046, 121797685, 94048546, 1369350241, 1250010422, 114268841, 515033604, 49052844, 1962329907],
    [1380892638, 1860017417, 64711457, 9758460, 1681838395, 710850601, 1020228997, 1414164790, 1531515535, 36158805, 713604525, 89935127, 1870801994, 395985906, 1122769045, 1760811055],
    [819787042, 134654834, 1755145179, 18433016, 1701878989, 1782339297, 1483861396, 962480061, 1857590724,
     222440409, 63223417, 515206622, 1348364213, 973414686, 1591066884, 705852913]
  ]> : tensor<4x16xi32>
  %rc_external_final = field.bitcast %rc_external_const : tensor<4x16xi32> -> tensor<4x16x!pf>
  %state_tensor = bufferization.to_tensor %state restrict : memref<16x!pf> to tensor<16x!pf>

  // Loop through 4 rounds of external terminal permutation
  affine.for %round = 0 to 4 {
    affine.for %i = 0 to 16 {
      %s = tensor.extract %state_tensor[%i] : tensor<16x!pf>
      %c = tensor.extract %rc_external_final[%round, %i] : tensor<4x16x!pf>
      %sbox = func.call @add_rc_and_sbox(%s, %c) : (!pf, !pf) -> !pf
      affine.store %sbox, %state[%i] : !state
    }
    // Apply MDS light permutation (in-place)
    func.call @mds_light_permutation(%state) : (!state) -> ()
  }
  return
}

// External layer: initial permutation (MDS light + terminal permutation)
func.func @permute_state_initial(%state: !state) {
  // First apply MDS light permutation
  func.call @mds_light_permutation(%state) : (!state) -> ()

  // Round constants for 16-width Poseidon2 on BabyBear
  // BABYBEAR_RC16_EXTERNAL_INITIAL (4 rounds x 16 constants)
  %rc_external_const = arith.constant dense<[
    [1582131512, 1899519471, 1641921850, 462688640, 1293997949, 1380417575, 1932416963, 283521298, 1016708647, 35751290, 1270782647, 851730739, 795004022, 929571430, 523703523, 1593957757],
    [895976710, 1742343460, 917700746, 1516725708, 1170237629, 785693164, 613651155, 352999196, 678775274, 1005433272, 1704854670, 1174551920, 508930349, 530338447, 1327158816, 1417652352],
    [1153538870, 583201050, 397833841, 1440603828, 454600685, 174490638, 171758601, 1998476616, 1403697810, 1807736944, 450348306, 1458895865, 787037868, 1063762964, 1987002214, 481645916],
    [1231767638, 1323639433, 238360103, 2012412459, 1024945356, 1108359895, 1284135849, 606928406, 1021455954, 719347978, 659671051, 769588663, 805534062, 592213995, 1752728055, 663410947]
  ]> : tensor<4x16xi32>
  %rc_external_final = field.bitcast %rc_external_const : tensor<4x16xi32> -> tensor<4x16x!pf>
  %state_tensor = bufferization.to_tensor %state restrict : memref<16x!pf> to tensor<16x!pf>

  // Then apply terminal permutation with initial external constants
  // Loop through 4 rounds of external terminal permutation
  affine.for %round = 0 to 4 {
    affine.for %i = 0 to 16 {
      %s = tensor.extract %state_tensor[%i] : tensor<16x!pf>
      %c = tensor.extract %rc_external_final[%round, %i] : tensor<4x16x!pf>
      %sbox = func.call @add_rc_and_sbox(%s, %c) : (!pf, !pf) -> !pf
      affine.store %sbox, %state[%i] : !state
    }
    // Apply MDS light permutation (in-place)
    func.call @mds_light_permutation(%state) : (!state) -> ()
  }
  return
}

// Complete Poseidon2 permutation
func.func @poseidon2_permute(%state: !state) {
  func.call @permute_state_initial(%state) : (!state) -> ()
  func.call @permute_state(%state) : (!state) -> ()
  func.call @permute_state_terminal(%state) : (!state) -> ()
  return
}

func.func @permute_10000(%state : !state) attributes { llvm.emit_c_interface } {
  affine.for %i = 0 to 10000 {
    func.call @poseidon2_permute(%state) : (!state) -> ()
  }
  return
}
```

---

## Python code

```python
# Python (JAX) port of https://github.com/tcoratger/whir-p3/blob/main/src/poly/dense.rs
import jax
import jax.random as rnd
import jax.numpy as jnp
import jax.lax as lax
import numpy.random as nprnd


### is_zero
@jax.jit
def is_zero(poly):
    # A zero polynomial is all zeros, or an empty array
    poly = jnp.array(poly)
    return (poly.size == 0) | jnp.all(poly == 0)


### evaluate - use as is
@jax.jit
def evaluate(poly, x):
    # Coefficients are ordered highest degree first, as jnp.polyval expects
    poly = jnp.array(poly)
    return jnp.polyval(poly, x)


### random
@jax.jit
def random(key, poly):
    # Draw uniform floats in [0, 999) and round to the nearest integer for "discrete real numbers"
    # (Could also use randint but that produces ints, not real/float dtype)
    poly = jnp.array(poly)
    rand = rnd.uniform(key, poly.shape, minval=0., maxval=999)
    return jnp.round(rand)


### lagrange_interpolation
@jax.jit
def has_dup_x_jit(xs):
    # After sorting, duplicates are adjacent
    xs_sorted = jnp.sort(xs)
    return jnp.any(xs_sorted[1:] == xs_sorted[:-1])


@jax.jit
def lagrange_interpolation(values):
    size = len(values)
    if size == 0:
        return jnp.zeros(0)

    # Unzip the list of (x, y) pairs to separate arrays
    xs, ys = zip(*values)
    xs = jnp.array(xs)
    ys = jnp.array(ys)

    # Check for duplicate x-coordinates using JAX ops. Both lax.cond branches
    # must return the same shape, so the duplicate case returns zeros, not None.
    def dup_case(_):
        return jnp.zeros(size)  # or whatever you want to represent None

    def no_dup_case(_):
        # Normal interpolation computation here
        # (return polynomial coefficients or whatever you compute)
        monomial_init = jnp.zeros(size)

        def body(i, acc):
            result_poly, basis_poly = acc
            current_y = jnp.polyval(result_poly, xs[i])
            delta = ys[i] - current_y
            basis_eval = jnp.polyval(basis_poly, xs[i])
            c_i = delta / basis_eval
            # element-wise multiplication
            term = basis_poly * c_i
            result_poly = result_poly + term

            # Build the monomial (x - x_i), highest degree first
            monomial = monomial_init.at[size - 1].set(-xs[i])
            monomial_all = monomial.at[size - 2].set(1)
            # After i steps, B(x) = (x - x_0)(x - x_1)...(x - x_{i-1})
            basis_poly = jnp.polymul(basis_poly, monomial_all)
            # Keep only the lowest `size` coefficients so the shape stays fixed
            basis_poly = basis_poly[-size:]
            return (result_poly, basis_poly)

        # The result polynomial P(x) starts at zero and is updated iteratively.
        zero_poly = jnp.zeros(size)
        # The basis polynomial B(x) starts at 1.
        basis_poly = zero_poly.at[size - 1].set(1)
        result_poly, basis_poly = jax.lax.fori_loop(0, size, body, (zero_poly, basis_poly))
        return result_poly

    return lax.cond(has_dup_x_jit(xs), dup_case, no_dup_case, operand=None)


### add - use as is
@jax.jit
def add(a, b):
    jax_array_a = jnp.array(a)
    jax_array_b = jnp.array(b)
    return jnp.polyadd(jax_array_a, jax_array_b)


### mul - use as is
@jax.jit
def mul(a, b):
    jax_array_a = jnp.array(a)
    jax_array_b = jnp.array(b)
    return jnp.polymul(jax_array_a, jax_array_b)
```
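
The `lagrange_interpolation` sketch above divides in floating point (`delta / basis_eval`), while a prover would presumably work in a prime field. As a rough illustration only, the same incremental scheme can be written over BabyBear in plain Python, with the division replaced by a modular inverse. The helper names `poly_eval` and `interpolate_mod_p` are hypothetical and are not part of whir-p3, JAX, or ZKX; returning `None` on duplicate x-coordinates is an assumption taken from the comment in the code above.

```python
P = 2013265921  # BabyBear prime (15 * 2**27 + 1), same modulus as the MLIR !pf type


def poly_eval(coeffs, x):
    """Horner evaluation mod P; coefficients are highest degree first."""
    acc = 0
    for c in coeffs:
        acc = (acc * x + c) % P
    return acc


def interpolate_mod_p(points):
    """Incremental interpolation over F_P.

    Returns coefficients (highest degree first), or None if two points
    share an x-coordinate modulo P.
    """
    xs = [x % P for x, _ in points]
    if len(set(xs)) != len(xs):
        return None

    n = len(points)
    result = [0] * n                            # P(x) starts at zero
    basis = ([0] * (n - 1) + [1]) if n else []  # B(x) starts at 1

    for x_i, (_, y_i) in zip(xs, points):
        # Correction needed so that P(x_i) == y_i
        delta = (y_i - poly_eval(result, x_i)) % P
        # B(x_i) is a product of nonzero differences, so it is invertible
        c_i = delta * pow(poly_eval(basis, x_i), -1, P) % P
        result = [(r + c_i * b) % P for r, b in zip(result, basis)]

        # B(x) <- B(x) * (x - x_i), truncated to n coefficients
        shifted = basis[1:] + [0]  # B(x) * x, dropping the overflow term
        basis = [(s - x_i * b) % P for s, b in zip(shifted, basis)]

    return result


coeffs = interpolate_mod_p([(1, 2), (2, 5), (3, 10)])
print(coeffs)  # [1, 0, 1], i.e. x^2 + 1
```

Plain Python integers keep the sketch short; `pow(a, -1, P)` needs Python 3.8 or newer. A JAX version would need the same modular arithmetic expressed on fixed-width integer arrays.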
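
The skipped-reduction argument from the Poseidon2 section can also be checked numerically. The sketch below compares the four-step and three-step computations of $(x - y)^2$ on random BabyBear elements. It uses Python's unbounded integers and `%` as a stand-in for the final reduction, so it only illustrates the arithmetic, not the Montgomery-form machine code that Plonky3 or ZKIR emit; the function names are hypothetical.

```python
import random

P = 2013265921  # BabyBear prime, as in !field.pf<2013265921 : i32, ...>


def reduce_once(v):
    """Single conditional subtraction; valid for 0 <= v < 2P."""
    return v - P if v >= P else v


def square_diff_with_reduce(x, y):
    v = reduce_once(P + x - y)  # steps 1-2: v lands in [0, P)
    return (v * v) % P          # steps 3-4


def square_diff_lazy(x, y):
    v = x - y                   # step 1 only: -P < v < P
    assert 0 <= v * v < P * P   # same input range the final reduction expects
    return (v * v) % P          # steps 3-4


rng = random.Random(0)
for _ in range(1000):
    x, y = rng.randrange(P), rng.randrange(P)
    assert square_diff_with_reduce(x, y) == square_diff_lazy(x, y)
```

The two variants agree because $(x - y)$ and $(P + x - y) \bmod P$ differ by a multiple of $P$, which vanishes after squaring and reducing.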