
# Halo2 Peak Memory Usage

This note walks through pseudocode of `create_proof` in `halo2_proofs` and tracks the memory usage at each step.

```rust
fn create_proof(params: &KZGParams, pk: &ProvingKey, circuit: &Circuit, instances: &[&[F]]) {
    // Let:
    // 
    // - k: log2 of the number of rows
    // - n: `1 << k`
    // - d: Degree of circuit
    // - e: Extension magnitude, equal to `(d - 1).next_power_of_two()`
    // - c_f: number of fixed columns
    // - c_a: number of advice columns
    // - c_i: number of instance columns
    // - c_p: number of columns enabled with copy constraints
    // - c_pg: number of grand products in the permutation argument, equal to `div_ceil(c_p, d - 2)`
    // - c_l: number of lookup arguments
    // 
    // The memory usage counters M.C, M.S, and M.E stand for:
    // 
    // - M.C: number of "n elliptic curve points" (with symbol ◯)
    // - M.S: number of "n field elements"        (with symbol △)
    // - M.E: number of "e * n field elements"    (with symbol ⬡)
    // 
    // So the actual memory usage in terms of bytes will be:
    //
    //   M = 32 * n * (2 * M.C + M.S + e * M.E)
    //
    // We'll ignore other values whose size is sublinear in n.
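    //
    // As a quick sanity check, here is a tiny illustrative helper (not part of
    // the halo2_proofs API) that turns these counters into bytes, assuming
    // 32-byte field elements and affine curve points of two field elements:
    fn estimated_bytes(n: usize, e: usize, m_c: usize, m_s: usize, m_e: usize) -> usize {
        32 * n * (2 * m_c + m_s + e * m_e)
    }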

    // 0. In the beginning:
    // 
    // `params` has:
    //  ◯ powers_of_tau
    //  ◯ ifft(powers_of_tau)
    // 
    // M.C = 2 (+= 2)
    // M.S = 0
    // M.E = 0
    // 
    // `pk` has:
    // ⬡ l0
    // ⬡ l_last
    // ⬡ l_active_row
    // △ fixed_lagranges (c_f)
    // △ fixed_monomials (c_f)
    // ⬡ fixed_extended_lagranges (c_f)
    // △ permutation_lagranges (c_p)
    // △ permutation_monomials (c_p)
    // ⬡ permutation_extended_lagranges (c_p)
    // 
    // M.C = 2
    // M.S = 2 * c_f + 2 * c_p (+= 2 * c_f + 2 * c_p)
    // M.E = 3 + c_f + c_p     (+= 3 + c_f + c_p)
    // 
    // And let's ignore `circuit`.

    // 1. Pad instances into lagrange form and compute their monomial form.
    // 
    // M.C = 2
    // M.S = 2 * c_f + 2 * c_p + 2 * c_i (+= 2 * c_i)
    // M.E = 3 + c_f + c_p
    let instance_lagranges = instances.to_lagranges();
    let instance_monomials = instance_lagranges.to_monomials();

    // 2. Synthesize circuit and collect advice column values.
    // 
    // M.C = 2
    // M.S = 2 * c_f + 2 * c_p + 2 * c_i + c_a (+= c_a)
    // M.E = 3 + c_f + c_p
    let advice_lagranges = circuit.synthesize_all_phases();

    // 3. Generate the permuted inputs and tables of the lookup arguments.
    // For each lookup argument, we have:
    //
    // △ compressed_input_lagranges - cached for later computation
    // △ permuted_input_lagranges
    // △ permuted_input_monomials
    // △ compressed_table_lagranges - cached for later computation
    // △ permuted_table_lagranges
    // △ permuted_table_monomials
    //
    // M.C = 2
    // M.S = 2 * c_f + 2 * c_p + 2 * c_i + c_a + 6 * c_l (+= 6 * c_l)
    // M.E = 3 + c_f + c_p
    let (
        compressed_input_lagranges,
        permuted_input_lagranges,
        permuted_input_monomials,
        compressed_table_lagranges,
        permuted_table_lagranges,
        permuted_table_monomials,
    ) = lookup_permuted();

    // 4. Generate grand products of permutation argument.
    //
    // M.C = 2
    // M.S = 2 * c_f + 2 * c_p + 2 * c_i + c_a + 6 * c_l + c_pg (+= c_pg)
    // M.E = 3 + c_f + c_p + c_pg                               (+= c_pg)
    let (
        perm_grand_product_monomials,
        perm_grand_product_extended_lagranges,
    ) = permutation_grand_products();

    // 5. Generate grand products of the lookup arguments,
    // and then drop the no-longer-needed lagrange values.
    //
    // M.C = 2
    // M.S = 2 * c_f + 2 * c_p + 2 * c_i + c_a + 3 * c_l + c_pg (-= 3 * c_l)
    // M.E = 3 + c_f + c_p + c_pg
    let lookup_product_monomials = lookup_grand_products();
    drop(compressed_input_lagranges);
    drop(permuted_input_lagranges);
    drop(compressed_table_lagranges);
    drop(permuted_table_lagranges);

    // 6. Generate random polynomial.
    //
    // M.C = 2
    // M.S = 1 + 2 * c_f + 2 * c_p + 2 * c_i + c_a + 3 * c_l + c_pg (+= 1)
    // M.E = 3 + c_f + c_p + c_pg
    let random_monomial = random();

    // 7. Turn advice_lagranges into advice_monomials.
    let advice_monomials = advice_lagranges.to_monomials();
    drop(advice_lagranges);

    // 8. Generate necessary extended lagranges.
    //
    // M.C = 2
    // M.S = 1 + 2 * c_f + 2 * c_p + 2 * c_i + c_a + 3 * c_l + c_pg
    // M.E = 3 + c_f + c_p + c_pg + c_i + c_a (+= c_i + c_a)
    let instance_extended_lagranges = instance_monomials.to_extended_lagranges();
    let advice_extended_lagranges = advice_monomials.to_extended_lagranges();

    // 9. While computing the quotient, these extended lagranges:
    // 
    // ⬡ permuted_input_extended_lagranges
    // ⬡ permuted_table_extended_lagranges
    // ⬡ lookup_product_extended_lagranges
    // 
    // of each lookup argument are generated on the fly and dropped before the
    // next one is processed.
    // 
    // And 1 extra quotient_extended_lagrange is created. So at the peak:
    //
    // M.C = 2
    // M.S = 1 + 2 * c_f + 2 * c_p + 2 * c_i + c_a + 3 * c_l + c_pg
    // M.E = 4 + c_f + c_p + c_pg + c_i + c_a + 3 * (c_l > 0) (+= 3 * (c_l > 0) + 1)
    let quotient_extended_lagrange = quotient_extended_lagrange();

    // 10. After the quotient is computed, drop all the other extended lagranges.
    //
    // M.C = 2
    // M.S = 1 + 2 * c_f + 2 * c_p + 2 * c_i + c_a + 3 * c_l + c_pg
    // M.E = 4 + c_f + c_p (-= c_pg + c_i + c_a + 3 * (c_l > 0))
    drop(instance_extended_lagranges);
    drop(advice_extended_lagranges);
    drop(perm_grand_product_extended_lagranges);

    // 11. Turn quotient_extended_lagrange into monomial form,
    // and then cut it into `d - 1` pieces.
    //
    // M.C = 2
    // M.S = 2 * c_f + 2 * c_p + 2 * c_i + c_a + 3 * c_l + c_pg + d (+= d - 1)
    // M.E = 3 + c_f + c_p (-= 1)
    let quotient_monomials = quotient_monomials();
    drop(quotient_extended_lagrange);

    // 12. Evaluate and open all polynomials except the instance ones.
}
```

From this accounting, the current peak memory usage occurs at step 9:

```
M.C = 2
M.S = 1 + 2 * c_f + 2 * c_p + 2 * c_i + c_a + 3 * c_l + c_pg
M.E = 4 + c_f + c_p + c_pg + c_i + c_a + 3 * (c_l > 0)
```
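
To make this concrete, here is a small sketch that evaluates the step 9 peak for a hypothetical circuit shape (every parameter below is a made-up example, not measured from a real circuit):

```rust
fn main() {
    // Hypothetical circuit shape; all numbers are made-up examples.
    let (k, d) = (20u32, 5usize);
    let n = 1usize << k;
    let e = (d - 1).next_power_of_two();
    let (c_f, c_a, c_i, c_p, c_l) = (5usize, 10, 1, 8, 2);
    let c_pg = (c_p + d - 3) / (d - 2); // div_ceil(c_p, d - 2)

    // Step 9 counters from the walkthrough above.
    let m_c = 2;
    let m_s = 1 + 2 * c_f + 2 * c_p + 2 * c_i + c_a + 3 * c_l + c_pg;
    let m_e = 4 + c_f + c_p + c_pg + c_i + c_a + 3 * ((c_l > 0) as usize);

    let bytes = 32 * n * (2 * m_c + m_s + e * m_e);
    println!("peak ≈ {:.3} GiB", bytes as f64 / (1u64 << 30) as f64);
    // Prints: peak ≈ 5.875 GiB
}
```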

Here is a script to reproduce the actual peak memory usage: https://gist.github.com/han0110/676ba81904a5e99d5ff7d8961cebaed5.

## Improvements

### Low-hanging fruits

- After step 5, we can already drop `instance_lagranges`
- Generate `perm_grand_product_extended_lagranges` on the fly as well, so it doesn't overlap with the lookup arguments' auxiliary extended lagranges

With these, the peak memory usage reduces to:

```
M.C = 2
M.S = 1 + 2 * c_f + 2 * c_p + c_i + c_a + 3 * c_l + c_pg
M.E = 4 + c_f + c_p + c_i + c_a + max(3 * (c_l > 0), c_pg)
```
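
With the hypothetical parameters from the sketch above (c_f = 5, c_a = 10, c_i = 1, c_p = 8, c_l = 2, c_pg = 3, e = 4, n = 2^20), this evaluates to 32 · 2^20 · (4 + 47 + 4 · 31) bytes ≈ 5.47 GiB, down from 5.875 GiB.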

### Compute quotient chunk by chunk

In fact:

```rust
best_fft(monomial, omega_nd, log2_nd).into_iter().skip(i).step_by(d)
    == best_fft(coset(monomial, omega_nd_to_i), omega_n, log2_n)
```

So we can compute the quotient chunk by chunk without holding the extended lagranges of all monomials in memory at once, and the degree no longer affects the peak memory usage.

The performance drawback is that `coset(monomial, omega_nd_to_i)` has to be computed `d - 1` extra times, which seems fine because it's not that expensive.
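
To sanity-check this identity, here is a self-contained sketch over the toy field F_97 (a stand-in for the real NTT-friendly field), using a naive O(n^2) DFT in place of `best_fft`; all names and parameters are illustrative:

```rust
const P: u64 = 97; // Toy prime; 5 generates its multiplicative group of order 96.

fn pow(mut base: u64, mut exp: u64) -> u64 {
    let mut acc = 1;
    while exp > 0 {
        if exp & 1 == 1 {
            acc = acc * base % P;
        }
        base = base * base % P;
        exp >>= 1;
    }
    acc
}

// Naive DFT: evaluate `coeffs` (ascending order) at omega^0, omega^1, ...
fn dft(coeffs: &[u64], omega: u64) -> Vec<u64> {
    (0..coeffs.len() as u64)
        .map(|j| {
            let x = pow(omega, j);
            coeffs.iter().rev().fold(0, |acc, &c| (acc * x + c) % P)
        })
        .collect()
}

fn main() {
    let (n, d) = (4usize, 4usize);
    let omega_nd = pow(5, 96 / (n * d) as u64); // primitive (n * d)-th root of unity
    let omega_n = pow(omega_nd, d as u64); // primitive n-th root of unity

    // One column's monomial form (n coefficients), zero-padded to size n * d
    // for the full-size DFT on the left-hand side.
    let monomial = vec![5u64, 1, 7, 2];
    let mut padded = monomial.clone();
    padded.resize(n * d, 0);

    for i in 0..d {
        // Left-hand side: every d-th evaluation of the size-(n * d) DFT.
        let lhs: Vec<u64> = dft(&padded, omega_nd).into_iter().skip(i).step_by(d).collect();
        // coset(monomial, omega_nd_to_i): scale the j-th coefficient by omega_nd^(i * j).
        let shift = pow(omega_nd, i as u64);
        let coset: Vec<u64> = monomial
            .iter()
            .enumerate()
            .map(|(j, &c)| c * pow(shift, j as u64) % P)
            .collect();
        assert_eq!(lhs, dft(&coset, omega_n));
    }
    println!("coset identity holds for every chunk i in 0..{d}");
}
```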

With this, the peak memory usage reduces to:

```
M.C = 2
M.S = 1 + 2 * c_f + 2 * c_p + 2 * c_i + 2 * c_a + 3 * c_l + c_pg
        + max(3 * (c_l > 0), c_pg)
M.E = 4 + c_f + c_p
```

To reduce memory even further, we can also discard the extended lagranges in the proving key and compute them on the fly:

```
M.C = 2
M.S = 7 + 3 * c_f + 3 * c_p + 2 * c_i + 2 * c_a + 3 * c_l + c_pg
        + max(3 * (c_l > 0), c_pg)
M.E = 1
```
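
Plugging the same hypothetical parameters into this final formula (still made-up numbers, with the one remaining extended vector keeping its size of e * n) gives, as a sketch:

```rust
fn main() {
    // Same hypothetical circuit shape as in the step 9 estimate above.
    let (n, e) = (1usize << 20, 4usize);
    let (c_f, c_a, c_i, c_p, c_l, c_pg) = (5usize, 10, 1, 8, 2, 3);

    let m_c = 2;
    let m_s = 7 + 3 * c_f + 3 * c_p + 2 * c_i + 2 * c_a + 3 * c_l + c_pg
        + std::cmp::max(3 * ((c_l > 0) as usize), c_pg);
    let m_e = 1;

    let bytes = 32 * n * (2 * m_c + m_s + e * m_e);
    println!("peak ≈ {:.3} GiB", bytes as f64 / (1u64 << 30) as f64);
    // Prints: peak ≈ 2.750 GiB, down from 5.875 GiB at step 9.
}
```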

So the peak memory usage is now roughly 2x the total size of all the polynomials.
