---
title: "MSM - Multi scalar multiplication"
date: 2023-07-08T10:00:00-07:00
tags:
- tech
- math
- cryptography
---
Problem: calculate $\sum_{i=0}^{N-1}k_{i}P_{i}$, where each $k_{i}$ is a $b$-bit scalar and each $P_{i}$ is a point on an elliptic curve (EC).
## Naive
Complexity: $N(b-1)$ squarings (point doublings) and up to $N(b-1)$ point additions, i.e. double-and-add on each of the $N$ scalars.

MSM can be divided into two main parts:
1. Modular multiplication
2. Point additions
## Pippenger bucketed approach
> [!info] Squaring: point doubling, i.e. $2*P$, where $P$ is an EC point.
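1. Divide each $b$-bit scalar into $b/c$ windows of $c$ bits each, so that $k_i=\sum_j 2^{cj}k_{ij}$ with $0\le k_{ij}<2^c$.
2. Calculate: $P=\sum_{i} k_iP_i=\sum_{j}2^{cj}\left(\sum_{i} k_{ij}P_i\right)=\sum_j 2^{cj} B_j$
3. $B_{j}=\sum_{i}k_{ij}P_{i}=\sum_{\lambda=1}^{2^{c}-1}\lambda\sum_{i\in u(\lambda)}P_{i}$, where $u(\lambda)$ is the set of indices $i$ with $k_{ij}=\lambda$.
4. Example: $c=3$, $N=15$, $B_{1}=4P_{1}+3P_{2}+5P_{3}+1P_{4}+4P_{5}+6P_{7}+6P_{8}+\ldots+3P_{14}+5P_{15}$
5. Now group points with the same coefficient together, i.e. create $2^c-1$ buckets $S_{j\lambda}=\sum_{i\in u(\lambda)}P_i$, so that $B_{j}=\sum_{\lambda}\lambda S_{j\lambda}$
6. Take partial sums, starting from the largest bucket (shown for $c=3$): $$\begin{aligned}T_{j1}&=S_{j7}\\ T_{j2}&=T_{j1}+S_{j6}\\T_{j3}&=T_{j2}+S_{j5}\\&\vdots\\T_{j7}&=T_{j6}+S_{j1}\end{aligned}$$ Then $B_j=\sum_{l=1}^{7}T_{jl}$, because $S_{j\lambda}$ appears in exactly $\lambda$ of the partial sums.
7. Sum over all windows: $P=\sum_j 2^{cj}B_j$, evaluated from the top window down with $c$ squarings between consecutive windows.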

Complexity: $\frac{b}{c}(2^c+N+1)$ additions + $b$ squarings
- Part 1: window result calculation: $(2^c+N)\frac{b}{c}$ additions
- Part 2: sum over all windows: each iteration requires $1$ addition and $c$ squarings. $b/c$ windows -> $\frac{b}{c}$ additions and $\frac{b}{c}*c=b$ squarings
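
A minimal Python sketch of the bucket method described above. To keep it short and runnable, the group is a stand-in (integers modulo a prime `Q` under addition, so a "point" is just an integer) and not a real elliptic curve. The names `Q`, `add`, `double` and `pippenger` are illustrative assumptions, not taken from any library.

```python
Q = 2**61 - 1  # toy modulus: stand-in additive group, NOT an elliptic curve


def add(a, b):
    """Group addition in the stand-in group."""
    return (a + b) % Q


def double(a):
    """'Squaring' in the notes' terminology, i.e. doubling a group element."""
    return (2 * a) % Q


def pippenger(scalars, points, b, c):
    """Compute sum(k_i * P_i) for b-bit scalars using c-bit windows."""
    num_windows = -(-b // c)  # ceil(b / c)
    mask = (1 << c) - 1
    result = 0  # identity element of the stand-in group
    for j in reversed(range(num_windows)):
        # Shift the accumulated result up by one window: c doublings.
        for _ in range(c):
            result = double(result)
        # Fill buckets: S[lam] collects every point whose window value is lam.
        buckets = [0] * (mask + 1)  # index 0 is never used
        for k, p in zip(scalars, points):
            lam = (k >> (c * j)) & mask
            if lam:
                buckets[lam] = add(buckets[lam], p)
        # Partial sums from the largest bucket down:
        # running = T_l, window_sum = T_1 + ... + T_l = sum(lam * S[lam]).
        running = 0
        window_sum = 0
        for lam in range(mask, 0, -1):
            running = add(running, buckets[lam])
            window_sum = add(window_sum, running)
        result = add(result, window_sum)
    return result
```

In the stand-in group $k*P$ is simply `k * P % Q`, so the output can be checked against `sum(k * p for k, p in zip(scalars, points)) % Q`.
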
## Point addition optimisations
The majority of the complexity is due to point additions. A point addition in affine coordinates costs 1 division, 2 multiplications, and 6 additions in $\mathbb{F}$. Division is very costly.
Use projective coordinates instead: the division can be deferred until a switch back to affine coordinates is needed.
Complexity: 7 mults, 4 field squarings, 9 additions, 3 mults by 2, 1 mult by 4. But no division is required.
### Batch affine
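Can be used when the aim is to compute many independent sums $G_{i}=P_i+Q_i$, where $P_i=(x_{i,1},y_{i,1})$ and $Q_i=(x_{i,2},y_{i,2})$ are points on the EC. Each affine addition needs the inverse of $x_{i,2}-x_{i,1}$, and the $n$ inversions can be replaced by a single one.
Denote: $a_{i}=x_{i,2}-x_{i,1}$
- $s$: $\prod_{i=1}^{n}a_i$
- $l_i$: $\prod_{j=1}^{i-1}a_{j}$
- $r_i$: $\prod_{j=i+1}^{n}a_{j}$
- $\frac{1}{a_i}=\frac{1}{x_{i,2}-x_{i,1}}=s^{-1}*l_{i}*r_{i}$, so only $s$ has to be inverted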
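
The products above are the usual batch-inversion (Montgomery) trick. A small Python sketch over a prime field, assuming every $a_i$ is non-zero modulo `p`; the function name `batch_inverse` is illustrative, and `pow(x, -1, p)` needs Python 3.8 or newer.

```python
def batch_inverse(values, p):
    """Invert every element of `values` modulo the prime p with ONE inversion.

    Assumes all values are non-zero modulo p.
    """
    n = len(values)
    # prefix[i] = a_0 * ... * a_{i-1}, i.e. l_i in the notation above.
    prefix = [1] * (n + 1)
    for i, a in enumerate(values):
        prefix[i + 1] = prefix[i] * a % p
    # The single field inversion: s^{-1} = (a_0 * ... * a_{n-1})^{-1}.
    inv = pow(prefix[n], -1, p)
    out = [0] * n
    for i in range(n - 1, -1, -1):
        # Here inv == (a_0 * ... * a_i)^{-1}, so prefix[i] * inv == a_i^{-1}.
        out[i] = prefix[i] * inv % p
        # Multiply a_i back in so that inv == (a_0 * ... * a_{i-1})^{-1}.
        inv = inv * values[i] % p
    return out
```

The cost is one inversion plus roughly $3n$ multiplications, instead of $n$ inversions.
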
### GLV with endomorphism
Goal: accelerate single point scalar mult $k*P$
#### Naive approach
1. Divide $k$ into windows of $c$ bits
2. Pre-compute $i*P \; \forall i\in[0,\ldots,2^c-1]$
3. Compute $k_i$ for each window $i$
4. $d=b/c$
5. $R=0$
6. For $i$ from $d-1$ to $0$:
    1. $R = 2^c*R$ (requires $c$ squarings)
    2. $R = R+(k_{i}*P)$ (requires 1 point addition)
7. return $R$

Complexity: $2^c$ (precompute) $+\,d$ point additions and $c*d=b$ squarings
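
The windowed method above can be sketched in Python as follows. As an assumption for brevity, the group is a stand-in (integers modulo a prime `Q` under addition) and not a real curve; `Q`, `add`, `double` and `windowed_mul` are illustrative names.

```python
Q = 2**61 - 1  # toy modulus: stand-in additive group, NOT an elliptic curve


def add(a, b):
    return (a + b) % Q


def double(a):
    return (2 * a) % Q


def windowed_mul(k, P, b, c):
    """Compute k * P for a b-bit scalar k using c-bit windows."""
    # Step 2: pre-compute table[i] = i * P for i in [0, 2^c - 1].
    table = [0] * (1 << c)
    for i in range(1, 1 << c):
        table[i] = add(table[i - 1], P)
    d = -(-b // c)  # number of windows, ceil(b / c)
    mask = (1 << c) - 1
    R = 0  # identity element
    for i in reversed(range(d)):
        # R = 2^c * R: c doublings.
        for _ in range(c):
            R = double(R)
        # R = R + k_i * P: one table lookup and one addition.
        k_i = (k >> (c * i)) & mask
        R = add(R, table[k_i])
    return R
```
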
#### Endomorphism
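For some curves like BN254, we can use a cube root of unity $\beta$ in the base field to find new points on the curve with the same $y$: since $\beta^3=1$, the point $(\beta x, y)$ satisfies the same curve equation as $(x, y)$.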
:::info
Equation of BN254: $y^2=x^3+b$.
:::
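So, $k*P$ can be split into $k_{1}*P_{1}+k_{2}*P_{2}$ with $P_1=P$, where
- $k\equiv k_1+(s*k_{2})$ modulo the group order, with $k_1$ and $k_2$ about $b/2$ bits each
- $P_{2}=s*P_{1}$, which is obtained almost for free as $(\beta x, y)$ for a cube root of unity $\beta$ in the base field
This halves the number of point doublings.
Complexity: $b/2$ squarings and $d/2+2^{c+1}$ additions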
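
A toy Python illustration of the split, under loud assumptions: the curve is $y^2=x^3+3$ over $\mathbb{F}_7$ (13 points, not BN254), the scalar `LAMBDA` is found by brute force, and the decomposition is plain division by `LAMBDA`, which only gives short halves because the group is tiny. Real implementations use a lattice-based decomposition. All names here are illustrative.

```python
P_MOD, B = 7, 3   # toy curve y^2 = x^3 + 3 over F_7 (NOT BN254)
N_ORDER = 13      # number of points on the toy curve (prime, so the group is cyclic)
BETA = 2          # cube root of unity in F_7: 2^3 = 8 = 1 (mod 7)
P = (1, 2)        # a point on the toy curve: 2^2 = 1^3 + 3


def ec_add(p1, p2):
    """Affine point addition; None is the point at infinity."""
    if p1 is None:
        return p2
    if p2 is None:
        return p1
    (x1, y1), (x2, y2) = p1, p2
    if x1 == x2 and (y1 + y2) % P_MOD == 0:
        return None  # p2 == -p1
    if p1 == p2:
        s = 3 * x1 * x1 * pow(2 * y1 % P_MOD, -1, P_MOD) % P_MOD
    else:
        s = (y2 - y1) * pow((x2 - x1) % P_MOD, -1, P_MOD) % P_MOD
    x3 = (s * s - x1 - x2) % P_MOD
    return (x3, (s * (x1 - x3) - y1) % P_MOD)


def ec_mul(k, pt):
    """Plain double-and-add, used as the reference."""
    acc = None
    while k > 0:
        if k & 1:
            acc = ec_add(acc, pt)
        pt = ec_add(pt, pt)
        k >>= 1
    return acc


def phi(pt):
    """Endomorphism (x, y) -> (BETA * x, y): same y, new point."""
    return (BETA * pt[0] % P_MOD, pt[1])


# Brute-force the scalar s with phi(P) == s * P (only feasible on a toy curve).
LAMBDA = next(l for l in range(1, N_ORDER) if ec_mul(l, P) == phi(P))


def glv_mul(k, pt):
    """k * pt as k1 * pt + k2 * phi(pt), sharing the doublings."""
    k1, k2 = k % LAMBDA, k // LAMBDA  # k == k1 + LAMBDA * k2
    pt2 = phi(pt)
    acc = None
    for bit in reversed(range(max(k1.bit_length(), k2.bit_length()))):
        acc = ec_add(acc, acc)
        if (k1 >> bit) & 1:
            acc = ec_add(acc, pt)
        if (k2 >> bit) & 1:
            acc = ec_add(acc, pt2)
    return acc
```

The loop in `glv_mul` runs over the bit length of the longer half, so the number of doublings is governed by $k_1$ and $k_2$ and not by the full scalar $k$.
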
### WNAF (windowed non-adjacent form)
Most of the time in MSM is spent in additions, and the number of additions depends on the Hamming weight of the scalar: the fewer non-zero digits the scalar has, the fewer additions are needed.
That is exactly what NAF is used for in MSM. Instead of representing a number with digits in $\{0, 1\}$, it is represented with digits in $\{-1, 0, 1\}$ such that no two adjacent digits are non-zero. This reduces the number of non-zero digits by about one third on average (from roughly $b/2$ to $b/3$).
Also, when the $b$-bit scalars are divided into $c$-bit slices, the value of each slice is used as a bucket index, so $2^{c}-1$ buckets are needed in total. Using signed slices, we map $0,1,\ldots,2^{c}-1$ to $\{-2^{c-1},\ldots,-1,0,1,\ldots,2^{c-1}-1\}$. Since negating a point is cheap, a slice $-\lambda$ can reuse bucket $\lambda$ with the negated point. For example, a slice $s_{i}\in\{1,2,3,4,5,6,7\}$ needs $2^3-1=7$ buckets in total; if we map $s_i$ to $\{-4,-3,-2,-1,1,2,3\}$, only $2^{3-1}=4$ buckets are needed.
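Code to find the NAF (Python):
Input: I = $(b_{n-1},b_{n-2},\ldots,b_{0})$
Output: O = $(\omega_{n},\omega_{n-1},\ldots,\omega_{0})$ with $\omega_i\in\{-1,0,1\}$ (the NAF can be one digit longer than the binary form)
```python
def naf(I: int) -> list[int]:
    """Return the NAF digits of I >= 0, least significant first (O[i] is w_i)."""
    O = []
    while I > 0:
        if I % 2 == 1:
            # Signed residue mod 4: +1 if I = 1 (mod 4), -1 if I = 3 (mod 4).
            w = 2 - (I % 4)
            # After this I is divisible by 4, so the next digit is 0.
            I -= w
        else:
            w = 0
        O.append(w)
        I //= 2
    return O
```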
Let's take a scalar $s$ with a $c$-bit window size, and let $L = 2^c$. Split the scalar into $K$ slices, where $K = \lceil b/c\rceil$.
$$
s = s_{0}+s_1L+\ldots+s_{K-1}L^{K-1} = \sum_{i}s_{i}L^{i}
$$
Now, to take advantage of wNAF, use the following relation:
$$
s_{i}L^{i}+s_{i+1}L^{i+1}=(s_{i}-L)L^{i}+(s_{i+1}+1)L^{i+1}
$$
i.e. subtract $2^c$ from the current slice and add 1 to the next slice whenever $s_i \ge L/2$. After this, every slice lies in the range $[-L/2,L/2)$. In the bucket array, whenever $s_i$ is negative, the point is subtracted from bucket $|s_i|$ instead of being added to it.
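
The carry rule above can be sketched in Python as follows. The function name `signed_slices` is illustrative, and the sketch appends one extra top slice to hold a possible final carry, which is an implementation choice of this example.

```python
def signed_slices(s, c, b):
    """Recode a b-bit scalar s into signed c-bit slices in [-L/2, L/2).

    Returns K + 1 slices, least significant first; the last one holds the
    final carry (0 or 1).
    """
    L = 1 << c
    K = -(-b // c)  # ceil(b / c)
    slices = []
    carry = 0
    for i in range(K):
        d = ((s >> (c * i)) & (L - 1)) + carry
        if d >= L // 2:
            # Subtract L from this slice and add 1 to the next one.
            d -= L
            carry = 1
        else:
            carry = 0
        slices.append(d)
    slices.append(carry)
    return slices
```

Summing `d * L**i` over the returned slices gives back `s`, and the bucket index for a slice is `abs(d)` with the point negated when `d < 0`.
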
## Barretenberg pseudocode
- `pippenger_internal()`
    - `compute_wnaf_states()`: takes the scalars, divides them among threads, splits each into its endomorphism halves, computes `skew_map` (a map from each scalar to whether it is even or odd), and writes the wNAF entries sequentially into `point_schedule`.
    - `process_buckets()`: sorts the wNAF entries in `point_schedule`.
    - `evaluate_pippenger_rounds()`: receives the sorted wNAF entries in `point_schedule`.
        - Divide the rounds among threads and divide `num_points` among threads.
        - `reduce_buckets()`: points are packed together according to their wNAF entries. Outputs `point_pairs_1` from `affine_product_state`. All the buckets are packed together at the end in reverse order, i.e. the first bucket is at position `affine_product_state.num_points`. **All buckets** in the subarray assigned to the thread are **summed together**.
            - runs recursively until only **one point per bucket** remains.
            - `construct_addition_chains()`:
                - outputs `max_bucket_bits`, derived from the bucket holding the most points: if `max_bucket_bits=3`, then there is a bucket with at least $2^3=8$ points (possibly more).
                - compute `bucket_counts` and `bucket_empty_status` in the first round.
                - `count_bits()`: calculates `bit_offsets`
                    - `bit_offsets` tells how points are grouped according to the bits of their bucket count. For example, if a bucket has 14 points, then it will have 8 points in the 3rd offset, 4 in the 2nd, and 2 in the 1st offset.
                - places wNAF entries from `point_schedule` bucket-wise in `point_pairs_1`. We have `bucket_count` and `bit_offsets`; for every bucket,
                    - take the count, and from the offset,
                    - find the place in the `point_pairs_1` array to put the point bucket-wise (bit-wise).
                    - bucket counts which have the 2nd bit set will be put together.
            - `evaluate_addition_chains()`:
                - add affine point pairs together, which halves the number of points.
                - example: `bucket_counts = [0, 4, 0, 10]` => `max_bucket_bits = 3`. `bit_offsets = [0, 0, 2, 4, 8]` => `[0, 0, 2, 6, 14]` (prefix sums).
                - addition chains of 2, 4, and 8 would have been formed already
                - loop runs for 4 iterations.
                    - 1st iteration: add all points except only 1 point pair
                    - 2nd iteration: add all points except till 2 point pair, because they have been converted to 1 point from the previous iteration
                    - 3rd iteration: add all points except till 4 point pair, and so on.
                - `add_affine_points()`
            - calculate new `bit_offsets`
            - calculate `point_schedules`, `bucket_counts`
        - accumulate summed bucket points together: $T_1=S_1,T_2=T_1+S_2,\ldots,T_n=T_{n-1}+S_n$, and calculate `running_sum` as the sum of all bucket points together.
        - scale by the first bucket, as threads will have buckets assigned from the middle of the `point_schedule`
        - aggregate all `thread_accumulator` values together to form `result`.
        - return `result`.
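
To make the `bit_offsets` example concrete, here is a small Python sketch that reproduces the numbers quoted above (`[0, 4, 0, 10]` and the 14-point bucket). It is not Barretenberg's code: the helper names `chain_sizes` and `bit_offsets` and their return shapes are assumptions made for illustration.

```python
def chain_sizes(count):
    """Split a bucket's point count into its power-of-two groups."""
    return [1 << bit for bit in range(count.bit_length()) if count & (1 << bit)]


def bit_offsets(bucket_counts):
    """Return (max_bucket_bits, per-bit point totals, their prefix sums).

    Entry bit + 1 of the per-bit list counts the points that sit in a
    2^bit-sized group across all buckets.
    """
    max_count = max(bucket_counts, default=0)
    max_bucket_bits = max(0, max_count.bit_length() - 1)
    per_bit = [0] * (max_bucket_bits + 2)
    for count in bucket_counts:
        for bit in range(max_bucket_bits + 1):
            per_bit[bit + 1] += count & (1 << bit)
    # Prefix sums turn the per-bit totals into starting offsets.
    prefix = list(per_bit)
    for i in range(1, len(prefix)):
        prefix[i] += prefix[i - 1]
    return max_bucket_bits, per_bit, prefix
```

With `bucket_counts = [0, 4, 0, 10]`, the count 4 contributes one group of 4 and the count 10 contributes groups of 8 and 2, which is where the per-bit totals `[0, 0, 2, 4, 8]` come from.
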
## Resources
- [zkStudyClub: Multi-scalar multiplication](https://www.youtube.com/watch?v=Bl5mQA7UL2I)
- [Known optimisation for MSM](https://www.notion.so/Known-Optimizations-for-MSM-13042f29196544938d98e83a19934305#9d8b79321f584477ac945a738042c396)
- [Aztec's wnaf](https://hackmd.io/@aztec-network/rJ3VZcyZ9)
- [Optimizing Multi-Scalar Multiplication (MSM): Learning from ZPRIZE](https://hackmd.io/@drouyang/msm)
- [EC arithmetic](https://cryptographyinrustforhackers.com/chapter_4/elliptic_curves.html)
- [Hardware Review: GPUs , FPGAs and Zero Knowledge Proofs](https://www.ingonyama.com/blog/hardware-review-gpus-fpgas-and-zero-knowledge-proofs#section-6)
- [Optimisation of MSM](https://hackernoon.com/optimization-of-multi-scalar-multiplication-algorithm-sin7y-tech-review-21)
- [Accelerating the PlonK zkSNARK Proving System using GPU Architectures](https://bpb-us-w2.wpmucdn.com/wordpress.lehigh.edu/dist/0/2548/files/2023/05/Master_Thesis_Tal_Derei.pdf)
- [FPGA Acceleration of Multi-Scalar Multiplication: CycloneMSM](https://eprint.iacr.org/2022/1396)
- [EdMSM: Multi-Scalar-Multiplication for SNARKs and Faster Montgomery multiplication](https://eprint.iacr.org/2022/1400)
- [cuZK: Accelerating Zero-Knowledge Proof with A Faster Parallel Multi-Scalar Multiplication Algorithm on GPUs](https://eprint.iacr.org/2022/1321.pdf)
- [PipeMSM: Hardware Acceleration for Multi-Scalar Multiplication](https://eprint.iacr.org/2022/999)
Pat's Questions:
- What is the proper definition of a pippenger round???? as mentioned here https://github.com/AztecProtocol/barretenberg/blob/master/cpp/src/barretenberg/ecc/scalar_multiplication/scalar_multiplication.cpp#L137
- We should check how this would work in Rust? I think it isn't an issue since arrays/vecs etc follow a similar layout and in memory it's just reversing assignment, but we should know before we implement. Also low priority given endomorphisms are curve dependent. https://github.com/AztecProtocol/barretenberg/blob/master/cpp/src/barretenberg/ecc/scalar_multiplication/scalar_multiplication.cpp#L106
- We should strongly consider using proptest instead of the random generator they use. That or hook into libFuzzer. Random also is dead simple so I'm probs hallucinating. IDEA: derive macro for struct headers that generates proptest; think this may be a thing but we can make it sick or sicker. Also use FFI to have implementation equivalency with Barretenberg. Would also be sick to have a criterion server, so to speak, echoing TensorBoard, that allows you to see performance data in real time/past data in browser. That is out of scope.
- What the fuck is this??? https://github.com/AztecProtocol/barretenberg/blob/master/cpp/src/barretenberg/ecc/scalar_multiplication/scalar_multiplication.cpp#L17
- Remembering Jonathan's words. Pippenger allows you to traverse both the bit space and the number space at the same time to speed things up instead of doing one and then the other. Make this a diagram showing each bucket and random set of bits.
- How???? https://github.com/AztecProtocol/barretenberg/blob/master/cpp/src/barretenberg/ecc/pippenger.md?plain=1#L64C9-L64C9
- How is this number derived? https://github.com/AztecProtocol/barretenberg/blob/master/cpp/src/barretenberg/ecc/scalar_multiplication/scalar_multiplication.hpp#L51
- Why? https://github.com/AztecProtocol/barretenberg/blob/master/cpp/src/barretenberg/ecc/fields/field.hpp#L292
- Note: https://github.com/AztecProtocol/barretenberg/blob/master/cpp/src/barretenberg/ecc/fields/field.hpp#L328
- Believe this is referencing batch inversion? Commonly known as the montgomery trick mentioned here? https://zcash.github.io/halo2/background/fields.html
- Gotta figure what the equivalent to this macro is in rust could probs do some whacky inlining for the compiler: https://zcash.github.io/halo2/background/fields.html
- Check Jane street opensource repos???
- It could be tricky to port over the runtime states to different curves may need others expertise: https://en.cppreference.com/w/cpp/language/static_cast
- This is a good idea: https://github.com/arkworks-rs/algebra/blob/master/ec/src/scalar_mul/fixed_base.rs#L83
- Got confused by move constructors all should be good.
TODO:
[ ] Function headers:
[ ] Tests:
[ ] scalar multiply reduce buckets
- Check reduce buckets with 16 buckets each with 1 point
[ ] Reduce Buckets
- points == 1 << 16
[ ] Disabled reduce buckets basic
- points == 1 << 20
- reduce buckets with smaller workload???
[ ] AddAffinePoints
- Verify affine addition used in addition chains for bucket reduction
[ ] ConstructAdditionChains
- clock/verify computation of addition chains for bucket reduction
[ ] EndomorphismSplit
- checks in memory endomorphism split
[ ] RadixSort
- Verify Radix sort is correct
[ ] OversizedInputs
- Checks for large inputs we chunk and perform multiple smaller pippengers
[ ] UndersizedInputs
- Test smallest size input that runs traditional MSM
[ ] PippengerSmall
- Pippenger on small input
[ ] PippengerEdgeCaseDbl
- All points are the same random element
[ ] PippengerShortInputs
- Test pippenger with truncated scalar values
[ ] PippengerUnsafe
- Test unsafe pippenger op
[ ] PippengerUnsafeShortInputs
- Test unsafe pippenger op with elements truncated with 0's
- Manually set underlying memory of scalars to be truncated (end with zeros) test edge case???
[ ] PippengerOne
- Run pippenger with 1 set of MSM
[ ] PippengerZeroPoints
- Run Pippenger with no points loaded state(0)
[ ] PippengerMulByZero
- Set scalar 0 to Fr::zero()
- set points to G1::One
- generate pippenger point table
- create runtime state
- run pippenger
- check result is point at infinity