# Keccak Halo2 implementation spec

Acknowledgement

We port Konstantin Panarin's idea of sparse representation and the code examples to Halo2.

Terminology

Keccak

Sparse representation

chunk
a chunk is one digit in the sparse base; it represents a single bit
slice
a slice is a group of consecutive chunks

The lane size is 64 bits. We split it into slices. In the binary-to-base-13 mapping, we work with 16 bits, i.e. 16 chunks, per slice, so a lane has 4 slices.

In the base-13-to-base-9 mapping, we work with at most 4 chunks per slice.

first base
the number 13
second base
the number 9

Introduction

Overview

The original keccak

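# 1 word = 64 bits
# state: 200 bytes = 25 words (lanes), indexed state[x][y]
def keccak(raw_input):
    state = [[0 for y in range(5)] for x in range(5)]
    for _input in split_every_17_word(raw_input):
        # take rate = 17 words of input and xor them into lanes 0..16
        for i in range(17):
            state[i % 5][i // 5] ^= _input[i]
        # keccak_f(state)
        for i in range(24):
            # round_b
            state = theta(state)
            # rotation offsets are per lane (ROT[x][y]), not per round
            state = rho(state)
            state = pi(state)
            state = xi(state)
            # add round constant to state[0][0]
            state = iota(state, round_constant[i])
    return state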
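For reference, the binary permutation above can be written out as a small runnable Keccak-256. Such a function is useful as a test oracle for the sparse-base circuit. This is a sketch, not the circuit code. It follows FIPS 202 for the step mappings and derives the round constants with the FIPS 202 LFSR. It uses Keccak's original 0x01 ... 0x80 padding, as in Ethereum's keccak256. The names keccak_f_binary, rc_bit, ROUND_CONSTANT and keccak256 are our own. The step this document calls xi is Keccak's χ (chi).

```python
from typing import List

MASK64 = (1 << 64) - 1

# ROT[x][y]: rho rotation offsets; 64 - ROT[x][y] gives the OFFSETS table in the Appendix
ROT = [
    [0, 36, 3, 41, 18],
    [1, 44, 10, 45, 2],
    [62, 6, 43, 15, 61],
    [28, 55, 25, 21, 56],
    [27, 20, 39, 8, 14],
]


def rotate_left(v: int, n: int) -> int:
    n %= 64
    return ((v << n) | (v >> (64 - n))) & MASK64


def rc_bit(t: int) -> int:
    """Output bit of the round-constant LFSR (FIPS 202, Algorithm 5)."""
    if t % 255 == 0:
        return 1
    r = [1, 0, 0, 0, 0, 0, 0, 0]
    for _ in range(t % 255):
        r = [0] + r
        for i in (0, 4, 5, 6):
            r[i] ^= r[8]
        r = r[:8]
    return r[0]


# round i sets bit 2**j - 1 to rc(j + 7*i) (FIPS 202, Algorithm 6)
ROUND_CONSTANT = [sum(rc_bit(j + 7 * i) << (2**j - 1) for j in range(7)) for i in range(24)]


def keccak_f_binary(state: List[List[int]]) -> List[List[int]]:
    for i in range(24):
        # theta
        c = [state[x][0] ^ state[x][1] ^ state[x][2] ^ state[x][3] ^ state[x][4] for x in range(5)]
        d = [c[(x - 1) % 5] ^ rotate_left(c[(x + 1) % 5], 1) for x in range(5)]
        state = [[state[x][y] ^ d[x] for y in range(5)] for x in range(5)]
        # rho
        state = [[rotate_left(state[x][y], ROT[x][y]) for y in range(5)] for x in range(5)]
        # pi
        state = [[state[(x + 3 * y) % 5][x] for y in range(5)] for x in range(5)]
        # xi (Keccak's chi)
        state = [
            [state[x][y] ^ (~state[(x + 1) % 5][y] & state[(x + 2) % 5][y]) for y in range(5)]
            for x in range(5)
        ]
        # iota
        state[0][0] ^= ROUND_CONSTANT[i]
    return state


def keccak256(data: bytes) -> bytes:
    rate = 136  # 1088 bits = 17 lanes
    padded = bytearray(data) + b"\x01" + b"\x00" * ((-len(data) - 1) % rate)
    padded[-1] |= 0x80
    state = [[0] * 5 for _ in range(5)]
    for off in range(0, len(padded), rate):
        for i in range(17):
            lane = int.from_bytes(padded[off + 8 * i: off + 8 * i + 8], "little")
            state[i % 5][i // 5] ^= lane
        state = keccak_f_binary(state)
    # the first 256 bits are lanes 0..3, serialized little-endian
    return b"".join(state[i % 5][i // 5].to_bytes(8, "little") for i in range(4))
```

As a usage check, keccak256(b"").hex() gives Ethereum's empty-input hash, c5d2460186f7233c927e7db2dcc703c0e500b653ca82273b7bfad8045d85a470.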

How we do keccak in the circuit

first_input, *rest_inputs = split_every_17_word(raw_input)
# the state starts at all zeros, so xoring in the first block is just converting it to base 13
state = convert_base(first_input, _from=2, _to=13)
for _input in rest_inputs:
    next_mixing = convert_base(_input, _from=2, _to=9)
    state = keccak_f(state, next_mixing)
state = keccak_f(state, None)
# TODO: padding
    
def keccak_f(state, next_mixing):
    for i in range(23):
        # enter with 25 lanes, each lane is 64 base-13 chunks
        state = theta(state)
        # enter with 65 base-13 chunks
        state = rho(state)
        # leave with 64 base-9 chunks
        state = pi(state)
        # first 23 rounds: we add round_constant but don't absorb
        state = xi(state)
        state = iota_b9(state, convert_base(round_constant[i], _from=2, _to=9))
        state = convert_base(state, _from=9, _to=13)
    state = theta(state)
    state = rho(state)
    state = pi(state)
    state = xi(state)
    # last round:
    if next_mixing is None:
        # if we don't have new input to mix, just add the round_constant
        state = iota_b9(
            state,
            convert_base(round_constant[23], _from=2, _to=9),
        )
    else:
        # if we have new input to mix, add the round_constant after the base conversion.
        # This is why the coefficient of `state[0][0]` entering the next theta could be 2.
        state = absorb(state, next_mixing)
        state = convert_base(state, _from=9, _to=13)
        state = iota_b13(
            state,
            convert_base(round_constant[23], _from=2, _to=13),
        )
    return state

The first base and the second base

Why is the first sparse base 13? Because in the theta step we have

  • at most 12 values of 0~1 adding together, so each chunk is in 0~12 and needs 13 symbols.
  • a rotate_left(·, 1)
    Note that we don't do a full rotate_left(·, 1). We do shift_left(·, 1) instead and defer the wrap-around to a later step. After shift_left, the sparse representation is less than 13^65, which is about 241 bits.

Why is the second sparse base 9? Because in the xi and iota steps combined we compute 2a + b + 3c + 2d from bits a, b, c, d, which is at most 8.
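Both numeric claims can be checked directly with plain Python integers. This is a minimal sketch, and the variable names are ours.

```python
import math
from itertools import product

# After shift_left, a lane is a 65-chunk base-13 value, i.e. below 13**65.
SPARSE13_BITS = (13**65 - 1).bit_length()      # 241
APPROX_BITS = 65 * math.log2(13)               # about 240.5

# xi and iota combined: largest value of 2a + b + 3c + 2d over bits a, b, c, d.
MAX_XI_IOTA = max(2 * a + b + 3 * c + 2 * d for a, b, c, d in product((0, 1), repeat=4))  # 8
```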

Breakdown of the theta step

new_state[x][y] = state[x][y] + c[(x+4)% 5] + 13 * c[(x+1)%5]

   state[x][y]                   s63    s62    s61    ...     s2     s1     s0
   c[(x+4)% 5]                 c4_63  c4_62  c4_61    ...   c4_2   c4_1   c4_0
+) 13 * c[(x+1)% 5]     c1_63  c1_62  c1_61  c1_60    ...   c1_1   c1_0
------------------------------------------------------------------------------
   new_state[x][y]       ns64   ns63   ns62   ns61    ...    ns2    ns1    ns0

new_state[x][y] is a base 13 number with 65 chunks.

The full rotation would give us a base 13 value t with 64 chunks. Let's denote them t0~t63.

t0 should be s0 + c4_0 + c1_63. But in new_state[x][y], s0 + c4_0 sits at ns0, while c1_63 sits at ns64.

We call ns0 the "last chunk low value" and ns64 the "last chunk high value". This matches last_chunk_low_value = state[x][y] % 13 in the rho code.
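The sketch below checks this chunk layout end to end. It treats plain Python integers as field elements and works in three steps:

1. It runs the base-13 theta, using shift_left instead of rotate_left.
2. It reads each chunk's parity, folding ns64 into ns0.
3. It compares the result with the binary theta.

Beforehand, it adds a round constant to state[0][0] in base 13, so that lane starts with coefficients up to 2. All helper names here are hypothetical.

```python
import random
from typing import List

MASK64 = (1 << 64) - 1


def to_sparse(v: int, base: int) -> int:
    """Put bit i of a 64-bit lane into the base**i chunk."""
    return sum(((v >> i) & 1) * base**i for i in range(64))


def theta_binary(state: List[List[int]]) -> List[List[int]]:
    c = [state[x][0] ^ state[x][1] ^ state[x][2] ^ state[x][3] ^ state[x][4] for x in range(5)]
    rotl1 = [((v << 1) | (v >> 63)) & MASK64 for v in c]
    return [[state[x][y] ^ c[(x + 4) % 5] ^ rotl1[(x + 1) % 5] for y in range(5)] for x in range(5)]


def theta_sparse(state13: List[List[int]]) -> List[List[int]]:
    c = [sum(state13[x]) for x in range(5)]
    # 13 * c[...] is shift_left by one chunk; the wrap-around lands in chunk 64
    return [[state13[x][y] + c[(x + 4) % 5] + 13 * c[(x + 1) % 5] for y in range(5)] for x in range(5)]


def decode_theta_lane(ns: int) -> int:
    """Parity of each chunk, with chunk 64 folded into chunk 0 (no carries since chunks <= 12)."""
    digits = []
    for _ in range(65):
        digits.append(ns % 13)
        ns //= 13
    assert ns == 0
    bits = [(digits[0] + digits[64]) & 1] + [dg & 1 for dg in digits[1:64]]
    return sum(b << i for i, b in enumerate(bits))


def check_theta(seed: int) -> bool:
    rng = random.Random(seed)
    st = [[rng.getrandbits(64) for _ in range(5)] for _ in range(5)]
    rc = rng.getrandbits(64)
    sp = [[to_sparse(st[x][y], 13) for y in range(5)] for x in range(5)]
    sp[0][0] += to_sparse(rc, 13)  # iota_b13: coefficients of state[0][0] now in 0~2
    st[0][0] ^= rc
    out, ref = theta_sparse(sp), theta_binary(st)
    return all(decode_theta_lane(out[x][y]) == ref[x][y] for x in range(5) for y in range(5))
```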

Break down the rotate_and_convert

In rotate_and_convert, we do the rotation and the conversion from base 13 to base 9 simultaneously.

For example, suppose we want to rotate by 5. Let's first look at the ground truth t.

We assume the first base converter has been applied to t, so each chunk ti is either 0 or 1.

   t                t63  t62  t61  t60  t59  t58  ...   t5   t4   t3   t2   t1   t0
   rho(t, 5)        t58  t57  t56  t55  t54  t53  ...   t0  t63  t62  t61  t60  t59
   base 9 weight    9^4  9^3  9^2    9    1 9^63  ... 9^10  9^9  9^8  9^7  9^6  9^5

Summing each ti times its base 9 weight gives the rotated value as a 64-chunk base 9 number.
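The weight table can be checked in a few lines. Giving t_i the weight 9**((i + rot) % 64) yields the same number as converting rotate_left(t, rot) to base 9. In this sketch the helper names are hypothetical, and t is assumed to be a plain 64-bit value, so every chunk is 0~1.

```python
MASK64 = (1 << 64) - 1


def rotate_left(v: int, n: int) -> int:
    n %= 64
    return ((v << n) | (v >> (64 - n))) & MASK64


def sparse9(v: int) -> int:
    """Base-9 sparse form: bit i of v becomes the coefficient of 9**i."""
    return sum(((v >> i) & 1) * 9**i for i in range(64))


def rotated_by_weights(t: int, rot: int) -> int:
    """Sum of t_i * 9**((i + rot) % 64): the rotation is moved into the weights."""
    return sum(((t >> i) & 1) * 9 ** ((i + rot) % 64) for i in range(64))
```

With rot = 5, t58 gets weight 9**63 and t59 gets weight 1, matching the table above.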

But new_state[x][y] from the theta step has the special chunks "chunk 0" and "chunk 64".

We can handle the rest of the chunks as usual, but ns0 and ns64 need special care.

nsxy      ns64 ns63 ns62 ns61 ns60 ns59 ns58 ... ns5 ns4 ns3 ns2 ns1 ns0
base 9          9^4  9^3  9^2    9    1 9^63    9^10 9^9 9^8 9^7 9^6 

Adding keccak_u64_first_converter(ns0 + ns64) * 9**5 to the sum of the middle chunks times their weights gives the correctly rotated base 9 number.

Example

nsxy      ns64 ns63 ns62 ns61 ns60 ns59 ns58 ...  ns5  ns4  ns3  ns2  ns1 ns0
example      1    7    6    0    5    8   10        9    2    1   12    3  11

# apply first base converter for the middle chunks
converted         1    0    0    1    0    0        1    0    1    0    1
base 9          9^4  9^3  9^2    9    1 9^63     9^10  9^9  9^8  9^7  9^6

Note that ns64 + ns0 must be less than or equal to 12 (here 1 + 11 = 12).

keccak_u64_first_converter(ns0 + ns64) evaluates to 0, so the special chunk adds 0 * 9**5.

Break down the rho step validation

How do we implement Keccak steps in the circuit?

Theta

Original

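from typing import List

def rotate_left(v: int, n: int) -> int:
    # rotate a 64-bit lane left by n bits
    return ((v << n) | (v >> (64 - n))) & (2**64 - 1)

# state in base 2
def theta(state: List[List[int]]):
    c = [state[x][0] ^ state[x][1] ^ state[x][2] ^ state[x][3] ^ state[x][4] for x in range(5)]
    d = [c[x - 1] ^ rotate_left(c[(x + 1) % 5], 1) for x in range(5)]
    for x in range(5):
        for y in range(5):
            state[x][y] ^= d[x]
    return state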

In circuit

# state in base 13
# input state: coefficients of state[0][0] in 0~2, other state[x][y] in 0~1
# output state: coefficients in 0~12
def theta(state: List[List[int]]):
    c = [state[x][0] + state[x][1] + state[x][2] + state[x][3] + state[x][4] for x in range(5)]
    new_state = [[0 for x in range(5)] for y in range(5)]
    for x in range(5):
        for y in range(5):
            new_state[x][y] = state[x][y] + c[(x+4)% 5] + 13 * c[(x+1)%5]
    return new_state

Why each output lane's coefficients are at most 12

Note that we might enter the theta step right after iota_b13. That step adds a round constant to state[0][0] in base 13, in the last round of keccak_f when there is a next input to absorb. So the coefficients of state[0][0] might start at 2.

Since the coefficients of state[0][0] might be 2 and c[0] contains state[0][0], the coefficients of c[0] could be as large as 6.

Look closely at the base 13 number new_state[x][y] = state[x][y] + c[(x+4)% 5] + 13 * c[(x+1)%5]. We want the largest possible coefficient at any given position.

For x not in (0, 1, 4), state[x][y] is at most 1 and c[(x+4)% 5] is at most 5. 13 * c[(x+1)%5] carries the coefficient from the previous position, which is at most 5. The total is at most 1 + 5 + 5 = 11.

For x == 0, state[x][y] is at most 2, which happens at state[0][0]. c[(x+4)% 5] is at most 5 since state[0][0] is not in it, and 13 * c[(x+1)%5] is at most 5 for the same reason. The total is at most 2 + 5 + 5 = 12.

For x in (1, 4), state[x][y] is at most 1. Exactly one of c[(x+4)% 5] and c[(x+1)%5] contains state[0][0]. That one can reach 6 and the other is at most 5. The total is at most 1 + 5 + 6 = 12.
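The case analysis can be mechanized by plugging in each lane's largest input coefficient: 2 for state[0][0] and 1 elsewhere. This is a sketch, and the function names are hypothetical.

```python
def max_input(x: int, y: int) -> int:
    """Largest coefficient entering theta: 2 for state[0][0] (after iota_b13), else 1."""
    return 2 if (x, y) == (0, 0) else 1


def theta_bounds() -> list:
    """Largest output coefficient of theta for each x, maximized over y."""
    c_max = [sum(max_input(x, y) for y in range(5)) for x in range(5)]  # [6, 5, 5, 5, 5]
    # chunk i of new_state[x][y] = state[x][y]_i + c[x-1]_i + c[x+1]_(i-1)
    return [
        max(max_input(x, y) + c_max[(x + 4) % 5] + c_max[(x + 1) % 5] for y in range(5))
        for x in range(5)
    ]
```

For x = 0..4, theta_bounds() returns [12, 12, 11, 11, 12], which matches the three cases above.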

Cost analysis

5 linear combinations for c
25 linear combinations for new_state

Pi

Nothing too interesting here

# state in base 9
def pi(state: List[List[int]]):
    new_state = [[0 for x in range(5)] for y in range(5)]
    for x in range(5):
        for y in range(5):
            new_state[x][y] = state[(x + 3 * y) % 5][x]
    return new_state

Cost analysis

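Probably no cost in the circuit. pi only permutes the lanes, so it amounts to reindexing cells rather than adding constraints.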

Rho

Original

def rho(state: List[List[int]]):
    for x in range(5):
        for y in range(5):
            state[x][y] = rotate_left(state[x][y], ROT[x][y])
    return state

In circuit

# input state[x][y] is in base 13, has been shifted left by 1, and thus has 65 chunks
# output state[x][y] is in base 9 and rotated properly. It has 64 chunks, each coefficient is 0~1.
def rho(state: List[List[int]]):
    offset_map = {1: [], 2: [], 3: []}
    new_state = [[0 for x in range(5)] for y in range(5)]
    for x in range(5):
        for y in range(5):
            # in base 9, coefficients 0~1. Rotation has been applied
            new_state[x][y] = rotate_and_convert(state[x][y], ROT[x][y])
            output_slices = []
            output_coefs = []
            # `acc` looks like the `raw` below. The difference is that
            # acc is witnessed and checked
            acc = state[x][y]
            last_chunk_low_value = state[x][y] % 13
            # raw is computed by the prover but not witnessed
            raw = state[x][y] // 13
            last_coef = 9**ROT[x][y]
            # coef of base 13 slice
            cur_input_coef = 13
            # coef of base 9 slice
            cur_output_coef = 1 if ROT[x][y] == 63 else last_coef * 9
            # which chunk are we at
            cur_offset = 1
            # Looping through slices
            while cur_offset < 64:
                # step could be 1, 2, 3, 4. A slice must not straddle chunk 64 - ROT[x][y],
                # where the output coefficient wraps around to 1 (see OFFSETS in the Appendix)
                step = get_step_size(cur_offset, 64 - ROT[x][y])
                input_slice = raw % 13**step
                raw //= 13**step
                block_count, output_slice = first_to_second_base_converter_table.lookup(input_slice)
                acc -= cur_input_coef * input_slice
                output_slices.append(output_slice)
                output_coefs.append(cur_output_coef)
                cur_offset += step
                cur_input_coef *= 13 ** step
                cur_output_coef = (
                    1 if cur_offset == 64 - ROT[x][y] else cur_output_coef * 9 ** step
                )
                # keep track of non 4 step
                if step < 4:
                    offset_map[step].append(block_count)
            last_chunk_high_value = raw % 13
            assert raw // 13 == 0, "No chunk should be left in the raw"
            last_chunk_value = last_chunk_high_value * (13**64) + last_chunk_low_value
            # we are looking up table to get this evaluation in the circuit
            # keccak_u64_first_converter(last_chunk_low_value + last_chunk_high_value)
            output_slice = of_first_to_second_base_converter_table.lookup(
                last_chunk_value
            )
            assert acc == last_chunk_value
            output_coefs.append(last_coef)
            output_slices.append(output_slice)
            # linear combination check
            assert new_state[x][y] == sum([s*c for s, c in zip(output_slices, output_coefs)])
    # offset_transformed = [0, 0, 1, 13, 170]
    # for i in (1, 2, 3):
    #     # Perform a range check here
    #     assert sum(offset_map[i]) <= offset_transformed[i] * len(offset_map[i])
    assert sum(offset_map[1]) == 0
    assert sum(offset_map[2]) <= len(offset_map[2])
    assert sum(offset_map[3]) <= 13 * len(offset_map[3])
    return new_state
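Below is a self-contained sketch of the per-lane slicing above, for experimenting outside the circuit. It simplifies the circuit in three ways, and these are assumptions:

* The lookup tables are replaced by direct computation. convert_slice stands in for first_to_second_base_converter_table, and first_conv(low + high) stands in for the special-chunk table.
* The block-count bookkeeping is omitted.
* get_step_size is called with max_offset = 64 - rot, the chunk at which the output coefficient wraps back to 1.

The names rho_lane, convert_slice and check_lane are hypothetical.

```python
import random
from typing import List, Tuple


def first_conv(n: int) -> int:
    # keccak_u64_first_converter: parity of a base-13 coefficient in 0~12
    assert 0 <= n < 13
    return n & 1


def get_step_size(cur_offset: int, max_offset: int) -> int:
    if cur_offset < max_offset < cur_offset + 4:
        return max_offset - cur_offset
    if cur_offset < 64 < cur_offset + 4:
        return 64 - cur_offset
    return 4


def convert_slice(input_slice: int, step: int) -> int:
    # stands in for first_to_second_base_converter_table
    out = 0
    for i in range(step):
        out += first_conv(input_slice % 13) * 9**i
        input_slice //= 13
    return out


def rho_lane(lane13: int, rot: int) -> Tuple[int, List[int]]:
    """Rotate-and-convert one 65-chunk base-13 lane slice by slice; returns (lane, steps)."""
    acc = lane13
    low = lane13 % 13                   # last chunk low value (ns0)
    raw = lane13 // 13
    out = 0
    in_coef = 13
    out_coef = 1 if rot == 63 else 9 ** (rot + 1)
    cur, steps = 1, []
    while cur < 64:
        step = get_step_size(cur, 64 - rot)
        input_slice = raw % 13**step
        raw //= 13**step
        acc -= in_coef * input_slice
        out += convert_slice(input_slice, step) * out_coef
        steps.append(step)
        cur += step
        in_coef *= 13**step
        out_coef = 1 if cur == 64 - rot else out_coef * 9**step
    high = raw % 13                     # last chunk high value (ns64)
    assert raw // 13 == 0, "No chunk should be left in the raw"
    assert acc == low + high * 13**64   # only the two special chunks remain
    out += first_conv(low + high) * 9**rot
    return out, steps


def check_lane(rot: int, seed: int) -> bool:
    rng = random.Random(seed)
    digits = [rng.randrange(13) for _ in range(65)]
    digits[64] = rng.randrange(13 - digits[0])  # keep low + high <= 12
    lane13 = sum(dg * 13**i for i, dg in enumerate(digits))
    bits = [(digits[0] + digits[64]) & 1] + [dg & 1 for dg in digits[1:64]]
    expected = sum(b * 9 ** ((i + rot) % 64) for i, b in enumerate(bits))
    return rho_lane(lane13, rot)[0] == expected
```

The returned step lists reproduce the get_steps output in the Helpers section. For example, rot = 1 (max_offset 63) gives fifteen 4s followed by 2 and 1.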

Cost analysis

  • 25 lanes each needs
    • for middle slices: approximately 16 iterations (64 chunks / 4 chunks) in the while loop, each needs
      • a lookup to first_to_second_base_converter_table
      • 1 linear combination check for current acc value
    • for the special chunk
      • a lookup to of_first_to_second_base_converter_table
      • a check that the remaining acc == last_chunk_value
    • 1 linear combination check for new_state[x][y] and the sum product of output_slices and output_coefs

Lastly, we need 3 linear combination checks on offset_map, one for each non-4 step size.

total:

  • 25 * 16 lookups to first_to_second_base_converter_table
  • 25 lookups to of_first_to_second_base_converter_table
  • 25*(16 + 1 + 1) + 3 = 453 linear combinations
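The "approximately 16" iterations per lane can be made exact by replaying get_step_size over the OFFSETS table from the Appendix (the 64 - ROT values). Under that slicing, the 12 lanes that need both a 1-chunk and a 2-chunk slice take 17 slices, and the other 13 lanes take 16. That gives 412 middle-slice lookups rather than 400. This is a sketch that assumes those offsets; slices_per_lane is a hypothetical name.

```python
# OFFSETS[x][y] = 64 - ROT[x][y], copied from the Appendix
OFFSETS = [
    [64, 28, 61, 23, 46],
    [63, 20, 54, 19, 62],
    [2, 58, 21, 49, 3],
    [36, 9, 39, 43, 8],
    [37, 44, 25, 56, 50],
]


def get_step_size(cur_offset: int, max_offset: int) -> int:
    if cur_offset < max_offset < cur_offset + 4:
        return max_offset - cur_offset
    if cur_offset < 64 < cur_offset + 4:
        return 64 - cur_offset
    return 4


def slices_per_lane(max_offset: int) -> int:
    cur_offset, n = 1, 0
    while cur_offset < 64:
        cur_offset += get_step_size(cur_offset, max_offset)
        n += 1
    return n


per_lane = [slices_per_lane(m) for row in OFFSETS for m in row]
total_middle_lookups = sum(per_lane)  # 12 lanes * 17 + 13 lanes * 16 = 412
```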

Xi (Keccak's χ) + Iota step

Original

# state in base 2
def xi_and_iota(state: List[List[int]], round_constant: int):
    new_state = [[0 for x in range(5)] for y in range(5)]
    # Xi step
    for x in range(5):
        for y in range(5):
            new_state[x][y] = state[x][y] ^ ((~state[(x + 1) % 5][y]) & state[(x + 2) % 5][y])
    # Iota step
    new_state[0][0] ^= round_constant
    return new_state

Original: Merging Iota step

# state in base 2
def xi_and_iota(state: List[List[int]], round_constant: int):
    new_state = [[0 for x in range(5)] for y in range(5)]
    for x in range(5):
        for y in range(5):
            a = state[x][y]
            b = state[(x + 1) % 5][y]
            c = state[(x + 2) % 5][y]
            d = round_constant if x == 0 and y == 0 else 0
            new_state[x][y] = a ^ (~b & c) ^ d
    return new_state

We have the keccak_u64_second_converter function to map 2*a + b + 3*c + 2*d to a ^ (~b & c) ^ d.

Depending on where we are, we use d to either
- apply the iota step by adding a round constant in base 9 (the first 23 rounds of keccak_f, and the last round when there is no more input), or
- absorb the next input (the last round of keccak_f when more input follows), and add the round constant in base 13 later.

In the circuit

from typing import List

# state in base 9, coefficient in 0~1
def xi(state: List[List[int]]):
    new_state = [[0 for x in range(5)] for y in range(5)]
    # Xi step
    for x in range(5):
        for y in range(5):
            # a, b, c are base9, coefficient in 0~1
            a = state[x][y]
            b = state[(x + 1) % 5][y]
            c = state[(x + 2) % 5][y]
            # coefficient in 0~6
            new_state[x][y] = 2*a + b + 3*c
    return new_state
# state in base 9, coefficient in 0~6
# next_input is state in base 9, coefficient in 0~1
def absorb(state: List[List[int]], next_input: List[List[int]]):
    # only the first 17 lanes (the rate) take input; lane idx is state[idx % 5][idx // 5]
    for idx in range(17):
        x, y = idx % 5, idx // 5
        # state[x][y] has 2*a + b + 3*c already, now add 2*d to make it 2*a + b + 3*c + 2*d
        # coefficient in 0~8
        state[x][y] += 2 * next_input[x][y]
    return state
# state in base 9, coefficient in 0~6
def iota_b9(state: List[List[int]], round_constant_base9: int):
    d = round_constant_base9
    # state[0][0] has 2*a + b + 3*c already, now add 2*d to make it 2*a + b + 3*c + 2*d
    # coefficient in 0~8
    state[0][0] += 2*d
    return state
# state in base 13, coefficient in 0~1; afterwards state[0][0] has coefficients in 0~2
def iota_b13(state: List[List[int]], round_constant_base13: int):
    state[0][0] += round_constant_base13
    return state
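The sketch below checks the xi + iota arithmetic against the binary definition. It assumes lanes are 64-chunk base-9 integers with 0~1 coefficients. It computes 2a + b + 3c (plus 2d for state[0][0]) chunk-wise, then applies keccak_u64_second_converter to each chunk. No chunk exceeds 8, so nothing carries into the next chunk. The helper names are hypothetical.

```python
import random
from typing import List

# keccak_u64_second_converter as a lookup list indexed by 2a + b + 3c + 2d
SECOND = [0, 0, 1, 1, 0, 0, 1, 1, 0]


def sparse9(v: int) -> int:
    return sum(((v >> i) & 1) * 9**i for i in range(64))


def decode9(v: int) -> int:
    """Apply the second converter chunk by chunk and read the bits back as binary."""
    out = 0
    for i in range(64):
        out |= SECOND[v % 9] << i
        v //= 9
    return out


def xi_iota_sparse(lanes9: List[List[int]], d9: int) -> List[List[int]]:
    new = [
        [2 * lanes9[x][y] + lanes9[(x + 1) % 5][y] + 3 * lanes9[(x + 2) % 5][y] for y in range(5)]
        for x in range(5)
    ]
    new[0][0] += 2 * d9  # iota_b9: d is the round constant in base 9
    return new


def xi_iota_binary(lanes: List[List[int]], rc: int) -> List[List[int]]:
    new = [
        [lanes[x][y] ^ (~lanes[(x + 1) % 5][y] & lanes[(x + 2) % 5][y]) for y in range(5)]
        for x in range(5)
    ]
    new[0][0] ^= rc
    return new


def check_xi_iota(seed: int) -> bool:
    rng = random.Random(seed)
    lanes = [[rng.getrandbits(64) for _ in range(5)] for _ in range(5)]
    rc = rng.getrandbits(64)
    sparse = xi_iota_sparse([[sparse9(v) for v in row] for row in lanes], sparse9(rc))
    binary = xi_iota_binary(lanes, rc)
    return all(decode9(sparse[x][y]) == binary[x][y] for x in range(5) for y in range(5))
```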

Cost analysis

25 linear combination checks for new_state[x][y]

Helpers

get_step_size (check_offset_helper)

# cur_offset begins with 1
def get_step_size(cur_offset: int, max_offset: int) -> int:
    """
    Is the `check_offset_helper` in rust with
    base_num_of_chunks = 4
    """
    # near the start of the lane
    if cur_offset < max_offset < cur_offset + 4:
        return max_offset - cur_offset
    # near the end of the lane
    if cur_offset < 64 < cur_offset + 4:
        return 64 - cur_offset
    return 4

get_steps

def get_steps(max_offset: int):
    cur_offset = 1
    offsets = [cur_offset]
    steps = []
    while cur_offset < 64:
        step = get_step_size(cur_offset, max_offset)
        cur_offset += step
        steps.append(step)
        offsets.append(cur_offset)
    print("=======")
    print("max_offset", max_offset)
    print("offsets",  offsets, "length", len(offsets))
    print("steps", steps)

get_steps(1)
get_steps(4)
get_steps(5)
get_steps(63)

=======
max_offset 1
offsets [1, 5, 9, 13, 17, 21, 25, 29, 33, 37, 41, 45, 49, 53, 57, 61, 64] length 17
steps [4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 3]
=======
max_offset 4
offsets [1, 4, 8, 12, 16, 20, 24, 28, 32, 36, 40, 44, 48, 52, 56, 60, 64] length 17
steps [3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4]
=======
max_offset 5
offsets [1, 5, 9, 13, 17, 21, 25, 29, 33, 37, 41, 45, 49, 53, 57, 61, 64] length 17
steps [4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 3]
=======
max_offset 63
offsets [1, 5, 9, 13, 17, 21, 25, 29, 33, 37, 41, 45, 49, 53, 57, 61, 63, 64] length 18
steps [4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 2, 1]

keccak_u64_first_converter


def keccak_u64_first_converter(n: int) -> int:
    """
    n is a sum of at most 12 bits (a chunk of the theta output), so 0 <= n <= 12.
    The xor of those bits is 1 if the sum is odd
    and 0 if the sum is even.
    """
    assert 0 <= n < 13
    return n & 1

keccak_u64_second_converter

def keccak_u64_second_converter(n: int) -> int:
    """
    n is the output of 2a + b + 3c + 2d, where a, b, c, d are bits
    every possible n can be uniquely mapped to the output of a ^ (~b & c) ^ d
    bit_table is the output of a ^ (~b & c) ^ d
    """
    assert n < 9
    bit_table = [0, 0, 1, 1, 0, 0, 1, 1, 0]
    return bit_table[n]

f_logic     = a ^ (~b & c) ^ d
f_algebraic = 2a + b + 3c + 2d

idx  a  b  c  d  a ^ (~b & c) ^ d  2a + b + 3c + 2d
  0  0  0  0  0                 0                 0
  1  0  0  0  1                 1                 2
  2  0  0  1  0                 1                 3
  3  0  0  1  1                 0                 5
  4  0  1  0  0                 0                 1
  5  0  1  0  1                 1                 3
  6  0  1  1  0                 0                 4
  7  0  1  1  1                 1                 6
  8  1  0  0  0                 1                 2
  9  1  0  0  1                 0                 4
 10  1  0  1  0                 0                 5
 11  1  0  1  1                 1                 7
 12  1  1  0  0                 1                 3
 13  1  1  0  1                 0                 5
 14  1  1  1  0                 1                 6
 15  1  1  1  1                 0                 8

idx ranges over 0~15.

a, b, c, d are the binary decomposition of idx, with a as the most significant bit.

If we sort by the f_algebraic column and keep the f_logic column alongside, we get the following table. Don't worry about the duplicated rows. They are consistent, since f_algebraic -> f_logic is well-defined.

2a + b + 3c + 2d  a ^ (~b & c) ^ d
               0                 0
               1                 0
               2                 1
               3                 1
               4                 0
               5                 0
               6                 1
               7                 1
               8                 0

If we implement that table it would look just like keccak_u64_second_converter
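The table can also be built by brute force. The sketch enumerates all 16 bit patterns and asserts that equal values of 2a + b + 3c + 2d always map to the same a ^ (~b & c) ^ d, which is exactly the well-definedness claim. build_second_converter_table is a hypothetical name. On single bits, ~b is written as 1 - b.

```python
from itertools import product


def build_second_converter_table() -> list:
    table = {}
    for a, b, c, d in product((0, 1), repeat=4):
        key = 2 * a + b + 3 * c + 2 * d
        value = a ^ ((1 - b) & c) ^ d  # a ^ (~b & c) ^ d on single bits
        # a collision is only allowed if both rows agree
        assert table.setdefault(key, value) == value
    return [table[k] for k in range(9)]
```

The result is [0, 0, 1, 1, 0, 0, 1, 1, 0], the bit_table used in keccak_u64_second_converter.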

rotate_and_convert

def rotate_and_convert(base13_input: int, rot: int) -> int:
    """
    Convert the base13 input to base9 output
    base13_input is assumed to have been shifted left by 1 (65 chunks)
    each base9 coefficient is the result of the binary operations encoded in a base13 coefficient, so it is 0~1
    rot is applied on base9
    """
    base = 9**rot
    special_chunk = 0
    raw = base13_input
    acc = 0

    for i in range(65):
        remainder = raw % 13
        if i in (0, 64):
            # ns0 and ns64
            special_chunk += remainder
        else:
            acc += keccak_u64_first_converter(remainder) * base
        raw //= 13
        base *= 9
        # chunk 64 - rot is the first chunk that wraps around to weight 1
        if i == 63 - rot:
            base = 1
    acc += keccak_u64_first_converter(special_chunk) * 9**rot
    return acc

normalizer

from typing import Callable

def normalizer(
    _input: int,
    input_base: int,
    output_base: int,
    transform_f: Callable[[int], int]
) -> int:
    """
    Takes a _input in input_base, converts to output in output_base,
    applying transform_f to each chunk
    """
    output = 0
    base = 1
    while _input > 0:
        remainder = _input % input_base
        output += transform_f(remainder) * base

        _input //= input_base
        base *= output_base
    return output
    

Tables

from_binary_converter_table

Purpose: Convert a 64-bit value from binary to base 13 (and base 9) in 16-bit chunks

>>> import itertools
>>> list(itertools.product([1, 2, 3], [4, 5, 6]))
[(1, 4), (1, 5), (1, 6), (2, 4), (2, 5), (2, 6), (3, 4), (3, 5), (3, 6)]

This table has 2**16 rows


import itertools

class FromBinaryConverterTable:

    def __init__(self):
        axes = [[0, 1] for i in range(16)]

        for coefs in itertools.product(*axes):
            assert len(coefs) == 16
            # key is the binary evaluation of coefs
            key = sum([coef*(2**i) for i, coef in enumerate(coefs)])
            # value0: base 13 sparse form, value1: base 9 sparse form
            value0 = sum([coef*(13**i) for i, coef in enumerate(coefs)])
            value1 = sum([coef*(9**i) for i, coef in enumerate(coefs)])
            self.add_row(key, value0, value1)

first_to_second_base_converter_table

Purpose: convert the middle slices, where a slice has at most 4 chunks

This table has 13**4 rows


import itertools

def block_counting_function(n: int) -> int:
    """
    aka the function g.
    The table is reverse engineered from the trace of `of_transformed`
    See the block counting function section in the Appendix
    """
    table = [0, 0, 1, 13, 170]
    return table[n]


class FirstToSecondBaseConverterTable:
    def __init__(self):
        axes = [list(range(13)) for i in range(4)]

        for coefs in itertools.product(*axes):
            assert len(coefs) == 4
            # x0, x1, x2, x3 are 0~12
            x0, x1, x2, x3 = coefs
            key = x0 + x1 * 13 + x2 *13**2 + x3*13**3
            fx0 = keccak_u64_first_converter(x0)
            fx1 = keccak_u64_first_converter(x1)
            fx2 = keccak_u64_first_converter(x2)
            fx3 = keccak_u64_first_converter(x3)
            # fx0, fx1, fx2, fx3 are 0~1
            value = fx0 + fx1 * 9 + fx2 *9**2 + fx3*9**3

            # could be 0, 1, 2, 3, 4
            non_zero_chunk_count = 4 - len([i for i in coefs if i==0])
            # maps 0, 1, 2, 3, 4 to 0, 0, 1, 13, 170
            block_count = block_counting_function(non_zero_chunk_count)

            self.add_row(key, block_count, value)

of_first_to_second_base_converter_table

Purpose: convert the special chunks ns0 and ns64 (which together make up t0) from the arithmetic result of the sparse form to the bitwise operation result of the binary form.

The table size should be (13 + 1) * 13 / 2 = 91 rows

class OfFirstToSecondBaseConverterTable:

    def __init__(self):
        for i in range(13):
            for j in range(13 - i):
                low = i
                high = j
                # 13 ** 64 is the base_b ^ offset
                key = low + high * (13 ** 64)
                value = keccak_u64_first_converter(low + high)
                self.add_row(key, value)

from_second_base_converter_table

Purpose: Convert from base 9 to base 13 or binary, applying keccak_u64_second_converter to each chunk

This table has 9**5 rows

class FromSecondBaseConverterTable:

    def __init__(self):
        axes = [list(range(9)) for i in range(5)]

        for coefs in itertools.product(*axes):
            assert len(coefs) == 5
            # key is the base 9 evaluation of coefs (each coef is 2a + b + 3c + 2d, in 0~8)
            key = sum([coef*(9**i) for i, coef in enumerate(coefs)])
            bits = [keccak_u64_second_converter(coef) for coef in coefs]
            value0 = sum([bit*(13**i) for i, bit in enumerate(bits)])
            value1 = sum([bit*(2**i) for i, bit in enumerate(bits)])
            self.add_row(key, value0, value1)

Sponges

Keccak uses a simple sponge construction.

For Keccak256, the squeeze phase is a no-op: the 256-bit output fits within the 1088-bit rate, so no extra permutation is needed.

First we absorb the input bytes, converting them to the base 13 sparse representation using from_binary_converter_table.

The conversion handles 16 bits per lookup. (We could enlarge the table to optimize this.)

The state is initialized to 1600 zero bits.

We absorb 1088 bits of input at a time.

We run keccak_f1600 on the 1600-bit state to get the next state.

When we finish absorbing, we xor in the padding: 0x01 right after the last input byte and 0x80 into the last byte of the 1088-bit block. Then we run keccak_f1600 again.

Return the first 256 bits of the state.
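The padding and block split can be sketched as follows. The sketch assumes Ethereum's keccak256, which uses the original Keccak padding rather than SHA-3's 0x06 domain byte, and the 1088-bit rate above. keccak_pad and blocks_of_lanes are hypothetical names. When the message ends one byte short of a block boundary, 0x01 and 0x80 land in the same byte and combine to 0x81.

```python
RATE_BYTES = 136  # 1088 bits = 17 lanes of 64 bits


def keccak_pad(data: bytes) -> bytes:
    """Keccak multi-rate padding: 0x01, zeros, then 0x80 in the last byte of the block."""
    padded = bytearray(data)
    padded.append(0x01)
    padded.extend(b"\x00" * (-len(padded) % RATE_BYTES))
    padded[-1] |= 0x80
    return bytes(padded)


def blocks_of_lanes(data: bytes) -> list:
    """Split padded input into blocks of 17 little-endian 64-bit lanes."""
    padded = keccak_pad(data)
    return [
        [int.from_bytes(padded[off + 8 * i: off + 8 * i + 8], "little") for i in range(17)]
        for off in range(0, len(padded), RATE_BYTES)
    ]
```

For a 32-byte EVM word this yields a single block, consistent with the "0 absorb loop, 1 padding keccak1600" count below.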

Cost analysis

  • keccak1600 for 25 lanes (total 500ish lcs + 845 lookups)
    • 100 lookups (=1600/16) for 16 chunks b2tob13 conversion (table is \(2^{16}\) rows)
    • 30 lcs in theta
    • Rho
      • 2800 lcs (running down (2) + running up (2) + block count acc (3)) *16 *25
      • 400 lookups in Base13To9Table(28.5k rows)
      • 25 lookups in specialChunkTable(170 rows) in rho
      • final check
        • range check 0~12
        • range check 0~169
    • 25 lcs in xi+iota
    • 320 lookups (25 lanes * 64 chunks/5 chunks) for 5 chunks b9tob2 conversion (59k rows)
  • sponge 1 absorb loop
    • takes 1088 bits input (1088/16= 68 b2tob13 conversion)
    • 1 keccak1600
  • sponge padding when absorb complete
    • 1 keccak1600
  • for n bytes input
    • x = floor(8*n / 1088) absorb phases
    • x + 1 keccak1600
  • for a 32 bytes word in EVM (x = 0)
    • 0 absorb loop
    • 1 padding keccak1600

(WIP) Working with multiple inputs

Han: yeah it’s determined at proving time. Naively we can handle keccak opcode by copying the memory value from bus mapping to another table, which could contains (hash, offset, u64_0, …, u64_16), then keccak circuit consume the table row by row to update the hash state, and it only resets the state (re-initialize) when hash != hash_prev hash == output

To keep things as simple as possible, I take opcode SHA3 as an example and limit the keccak circuit’s access to bus mapping.

  1. When the evm circuit sees a SHA3, it takes the offset and size from the stack. It then starts a multi-step process to copy memory[offset:offset+size] from BusMapping to another table; let’s call it KeccakIO. The copy here actually means doing a lookup into both tables and ensuring they have the same row.
  2. Naively, we can have KeccakIO with 19 columns which are (hash, offset, data0, …, data16), where data* are decided by the absorbing rate
  3. Then in the keccak circuit, say we have a chip keccak-f that takes (old_state, data0, …, data16) and outputs (new_state, hash)
  4. What we need to do next is repeat the keccak-f chip as many times as needed. Each time, it looks up the KeccakIO table to get some data to absorb, updates the state, and checks whether the output hash matches the one in the table.
  5. If it’s not matching, we keep absorbing until it matches.
  6. If it’s matching, we reset the state for next chip, and start to verify next hash.
  7. Somehow we need a mechanism like global_counter in BusMapping to avoid malicious insertion of invalid hash data pair.

Appendix

(WIP) Deciphering block counting function

What's known

  • g = |x| { of_transformed[x as usize] };
  • The value of counts is [12, 12, 13]. The first_base_num_of_chunks is 4, so usually we look up 4-chunk slices. But there are exceptional cases, and counts means
    • 12 cases in the OFFSETS, we have to deal with 1 chunks
    • 12 cases in the OFFSETS, we have to deal with 2 chunks
    • 13 cases in the OFFSETS, we have to deal with 3 chunks
  • "of_transformed" is from "counts"
  • printed log shows the value of "of_transformed" is [0, 0, 1, 13, 170]
  • They have an unused constant MAX_OF_COUNT_PER_ITER= 50.

What's unknown

  • The value of "of_transformed" does not match the code comments. The value of g(3) is 13 and is not at least 51
  • what does "25 shifted rows so there are at most 50 overflows" mean?

https://github.com/matter-labs/franklin-crypto/blob/a78e17674dd1c99788259986cae11239953b50af/src/plonk/circuit/hashes_with_tables/keccak/gadgets.rs#L176

we have 25 shifted rows so there are at most 50 overflows
let g(0) = 0, g(1) = 0, g(2) = 1
g(3) should be than at least 51
g(4) than should be g(3) * 50 + 1 = 2551
and so forth: g(i+1) = 25 * g(i) + 1

# OFFSETS are the 64 - ROTs
OFFSETS = [
    [64, 28, 61, 23, 46],
    [63, 20, 54, 19, 62],
    [2, 58, 21, 49, 3],
    [36, 9, 39, 43, 8],
    [37, 44, 25, 56, 50]
]

def get_step_size(cur_offset: int, max_offset: int) -> int:
    """
    Is the `check_offset_helper` in rust with
    base_num_of_chunks = 4
    outputs either 1, 2, 3, 4
    """
    # near the start of the lane
    if cur_offset < max_offset < cur_offset + 4:
        return max_offset - cur_offset
    # near the end of the lane
    if cur_offset < 64 < cur_offset + 4:
        return 64 - cur_offset
    return 4



def get_counts():
    counts = [0, 0, 0]
    for row in OFFSETS:
        for max_offset in row:
            cur_offset = 1
            while cur_offset < 64:
                step = get_step_size(cur_offset, max_offset)
                if step != 4:
                    counts[step - 1] += 1
                cur_offset += step
    print("counts", counts)  # [12, 12, 13]

    of_transformed = [0, 0]
    for c in counts:
        elem = of_transformed[-1]
        of_transformed.append(c* elem + 1)
    print("of_transformed", of_transformed)  # [0, 0, 1, 13, 170]
get_counts()

Chips

expose keccak256 as a gadget

impl keccak256 for keccak-f

keccak-f as a trait,

gadget

rho, theta,

gadget is a group of traits

chips implements the traits

keccak-f as the unit interface

Try to learn from the Poseidon example https://github.com/zcash/orchard/blob/main/src/circuit/gadget/poseidon.rs

(WIP) Circuit as a Lookup Table

Most keccak256 calls in the EVM don't have fixed-length input. Suppose we want to handle all cases, including the worst case, by defining some max input length with a complete keccak256. Then we might have to fine-tune the circuit so it doesn't waste too much.

Another approach is to not have a complete keccak256, we instead have repeating keccak_f with some beginning and end condition, which seems more friendly for other circuit to use.

Naively each keccak_f could look like:

  • size: 1~17, the current block size (for later padding use)
  • offset: if not is_final then we have to bump the number of offset (either 1 or 17)

where i_* is input and the region inside waves we perform keccak_f.

Other circuits call it like lookup(i_1, i_2, i_3, ..., i_17, hash, offset, size, is_final) to ensure keccak256 is correct. (would call multiple times if the input is more than 17 words)

When is_final is set to 1, the keccak256 circuit then sets the next input to all 0s and checks whether hash matches the next output (o_1, o_2, o_3, o_4).

We use the next input and next output because the current algorithm partly defers keccak_f to the next round, which has no input. If we changed the order, base 13 would not work anymore.

When is_final = 1, we do the padding.

The way we handle is_final = 1 seems to leave much room for optimization, but we haven't given it much thought yet.
han
