THIS NOTE IS DEPRECATED. See https://github.com/appliedzkp/zkevm-specs/blob/master/src/encoding/comparator.py for the latest comparator.
We break down the u256 words $a$ and $b$ into 32 u8 chunks each. We now want to perform a conditional check such as $a \geq b$ or $a < b$, which is not natively supported in a Plonk circuit.
We first look up a table that compares a u8 with a u8. This step yields 32 2-bit comparators, one per chunk, each encoding $(a = b) + 2 \cdot (a > b)$, i.e. 0 for "<", 1 for "==" and 2 for ">".
We then regroup the 32 comparators into 4 groups of 8, and pack each group into a 16-bit value. A second table lookup reduces each group to a single 2-bit comparator. The resulting 4 groups of 2 bits are packed into an 8-bit value.
Finally, we look up that 8-bit value in an 8-bit table to get the final result.
from typing import NewType, Sequence
# Width-tagged integer aliases. They are only used in annotations; at runtime
# NewType(...)(x) simply returns x.
# U2 is referenced by the comparator code below, so it must be defined here.
U2 = NewType("U2", int)
U8 = NewType("U8", int)
U256 = NewType("U256", int)
U32 = NewType("U32", int)
def u256_to_u8s(x: U256) -> Sequence[U8]:
    """
    Off circuit preprocessing: split a u256 word into 32 u8 chunks.

    The chunks are little-endian: result[i] holds bits 8*i .. 8*i + 7 of x,
    so result[0] is the least significant byte.

    Raises AssertionError if x is not in [0, 2**256).
    """
    # Reject negative values too: (x >> k) & 255 on a negative int would
    # silently yield two's-complement bytes instead of failing.
    assert 0 <= x < 2**256
    return [(x >> (8 * i)) & 255 for i in range(32)]
# Sample inputs written as little-endian u8 chunks (index 0 is the least
# significant byte, the same layout u256_to_u8s produces).
# a = 1
a8s = [1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0]
# b = 257 = 1 + 1 * 256
b8s = [1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0]
# c = 258 = 2 + 1 * 256
c8s = [2,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0]
class U8IsGreaterTable:
    """
    Fixed lookup table comparing two u8 values.

    Columns:
        a: an 8-bit value
        b: an 8-bit value
        comparator: 2 bits, (a == b) + 2 * (a > b), i.e. 0 when a < b,
            1 when a == b and 2 when a > b
    (3 columns and 2**16 rows)
    """

    def __init__(self):
        # One row per (a, b) pair, enumerated with `a` as the outer loop.
        self.rows = [
            {"a": lhs, "b": rhs, "comparator": (lhs == rhs) + 2 * (lhs > rhs)}
            for lhs in range(2**8)
            for rhs in range(2**8)
        ]

    def lookup(self, a: U8, b: U8, comparator: int) -> bool:
        """Return True iff some row matches (a, b, comparator) exactly."""
        return any(
            row["a"] == a and row["b"] == b and row["comparator"] == comparator
            for row in self.rows
        )
class U16ComparatorTable:
    """
    Lookup table folding 8 packed 2-bit comparators into one 2-bit comparator.

    Columns:
        a: a 16-bit value holding 8 groups of 2-bit input comparators.
           Group j occupies bits 2j..2j+1: bit 2j is the "equal" bit and
           bit 2j+1 the "greater than" bit, matching the
           (a == b) + 2 * (a > b) encoding of U8IsGreaterTable.
           Higher groups are more significant.
        output_comparator: 2 bits, decided by the most significant group that
           is not "equal": 2 if it is "greater than", 0 if it is "less than",
           and 1 when all 8 groups are "equal".
    (2 columns and 2**16 rows; exactly one row per value of `a`)
    """

    def __init__(self):
        self.rows = []
        for a in range(2**16):
            # Scan from the most significant group down; the first group
            # that is not "equal" decides the result.
            for i in reversed(range(8)):
                group = (a >> (2 * i)) & 0b11
                eq_bit, gt_bit = group & 1, group >> 1
                if gt_bit == 1:
                    self.rows.append({"a": a, "output_comparator": 2})
                    break
                if eq_bit == 0:
                    self.rows.append({"a": a, "output_comparator": 0})
                    break
            else:
                # Reached only without a break: all 8 groups have the eq bit
                # set, which means equality.
                self.rows.append({"a": a, "output_comparator": 1})

    def lookup(self, a: int, comparator: int) -> bool:
        """Return True iff the table has the row (a, comparator).

        a: 16-bit packed input comparators.
        comparator: 2-bit output comparator (0, 1 or 2).
        """
        for row in self.rows:
            if row["a"] == a and row["output_comparator"] == comparator:
                return True
        return False
class U8ComparatorTable:
    """
    Lookup table folding 4 packed 2-bit comparators into the final a >= b bit.

    Columns:
        a: an 8-bit value holding 4 groups of 2-bit input comparators.
           Group j occupies bits 2j..2j+1: bit 2j is the "equal" bit and
           bit 2j+1 the "greater than" bit. Higher groups are more
           significant.
        is_gte: bool, decided by the most significant group that is not
           "equal": True if it is "greater than", False if it is "less than",
           and True when all 4 groups are "equal".
    (2 columns and 2**8 rows; exactly one row per value of `a`)
    """

    def __init__(self):
        self.rows = []
        for a in range(2**8):
            # Scan from the most significant group down; the first group
            # that is not "equal" decides the result.
            for i in reversed(range(4)):
                group = (a >> (2 * i)) & 0b11
                eq_bit, gt_bit = group & 1, group >> 1
                if gt_bit == 1:
                    self.rows.append({"a": a, "is_gte": True})
                    break
                if eq_bit == 0:
                    self.rows.append({"a": a, "is_gte": False})
                    break
            else:
                # Reached only without a break: every group is "equal",
                # so a == b and therefore a >= b.
                self.rows.append({"a": a, "is_gte": True})

    def lookup(self, a: U8, is_gte: bool) -> bool:
        """Return True iff the table has the row (a, is_gte)."""
        for row in self.rows:
            if row["a"] == a and row["is_gte"] == is_gte:
                return True
        return False
def is_greater_than_equal(
    a: U256,
    b: U256,
    # 32 per-byte comparators, each (a8 == b8) + 2 * (a8 > b8)
    comparators: Sequence[int],
    # 4 folded comparators, one per group of 8 bytes
    comparators_2: Sequence[int],
    a_gte_b: bool,
):
    """
    Validate that the witness proves ``a_gte_b == (a >= b)``.

    All comparators use the 2-bit encoding 0 for "<", 1 for "==", 2 for ">".
        comparators:   32 values; comparators[i] compares byte i of a and b,
                       where byte 0 is the least significant (u256_to_u8s).
        comparators_2: 4 values; comparators_2[i] folds comparators[8*i:8*i+8].
        a_gte_b:       the claimed final result.

    Every check goes through ``require``.
    NOTE(review): ``require`` is not defined in this note; presumably it is
    the specs' assertion helper that fails on a false condition — confirm.

    (32 + 4 + 1 = 37 lookups)
    """
    # TODO: commitment and 8 bit range check
    a8s = u256_to_u8s(a)
    b8s = u256_to_u8s(b)
    assert len(a8s) == 32
    # zip() would silently truncate a short witness, so pin the lengths.
    require(len(comparators) == 32, "should have 32 byte comparators")
    require(len(comparators_2) == 4, "should have 4 group comparators")

    u8_is_greater_table = U8IsGreaterTable()
    u16_comparator_table = U16ComparatorTable()
    u8_comparator_table = U8ComparatorTable()

    for a8, b8, cbits in zip(a8s, b8s, comparators):
        require(
            u8_is_greater_table.lookup(a8, b8, cbits),
            "should satisfy cbit = (a==b) + 2 * (a>b)"
        )
    # We end up with 32 groups of 2-bit comparators here. Pack each run of 8
    # into a 16-bit value; comparator j lands at bits 2j..2j+1, so more
    # significant bytes occupy higher bits, as the table expects.
    for i in range(4):
        u16 = sum(4**j * comparators[8 * i + j] for j in range(8))
        require(u16_comparator_table.lookup(u16, comparators_2[i]))
    # Pack the 4 folded comparators the same way (group 3 is most significant).
    u8 = sum(4**j * comparators_2[j] for j in range(4))
    require(u8_comparator_table.lookup(u8, a_gte_b))
# Worked example with 2-chunk little-endian values (index 0 is the least
# significant chunk), chosen so that x0 < x1 == y < x2 < x3.
x0 = [5, 6]  # 5 + 6 * 256 = 1541
x1 = [6, 6]  # 1542
x2 = [7, 6]  # 1543
x3 = [0, 7]  # 1792
y = [6, 6]   # 1542

# Python compares lists lexicographically starting at index 0, which would
# treat the least significant chunk as the most significant one (for example,
# `x3 >= y` evaluates to False). Reverse the chunks to compare
# most-significant-first.
assert not x0[::-1] >= y[::-1]
assert x1[::-1] >= y[::-1]
assert x2[::-1] >= y[::-1]
assert x3[::-1] >= y[::-1]

# Per-chunk comparator using a >= b: x0 and x3 produce the same bits, so a
# smaller value cannot be told apart from a larger one.
cbit_0 = [0, 1]
cbit_1 = [1, 1]
cbit_2 = [1, 1]
cbit_3 = [0, 1]
assert cbit_0 == cbit_3

# Per-chunk comparator using a > b: x0 and x1 produce the same bits, so a
# smaller value cannot be told apart from an equal one.
cbit_0 = [0, 0]
cbit_1 = [0, 0]
cbit_2 = [1, 0]
cbit_3 = [0, 1]
assert cbit_0 == cbit_1

# Conclusion: a single a >= b bit or a single a > b bit per chunk cannot
# differentiate the 4 cases.
# Add an extra bit instead: use (a == b) + 2 * (a > b) per chunk, i.e. 0 for
# "<", 1 for "==" and 2 for ">". Read as little-endian base-3 numbers
# (c0 + 3 * c1), the comparator vectors keep the order of the inputs.
cbit_0 = [0, 1]  # 3
cbit_1 = [1, 1]  # 4
cbit_2 = [2, 1]  # 5
cbit_3 = [0, 2]  # 6
assert [c[0] + 3 * c[1] for c in (cbit_0, cbit_1, cbit_2, cbit_3)] == [3, 4, 5, 6]