Comparator

THIS NOTE IS DEPRECATED. See https://github.com/appliedzkp/zkevm-specs/blob/master/src/encoding/comparator.py for the latest comparator.

Intro

We break each of the u256 words a and b into 32 u8 chunks. We want to perform conditional checks such as a >= b or a < b, which a PLONK circuit does not support natively.

We first build a table that compares two u8 values. Looking up each of the 32 byte pairs yields 32 2-bit comparators, each encoding (a == b) + 2 * (a > b): 0 for less-than, 1 for equal, 2 for greater-than.
We then pack the 32 comparators into 4 16-bit values (8 comparators each). A lookup into a second table reduces each 16-bit value to a single 2-bit comparator. The resulting four 2-bit comparators are packed into one 8-bit value.
Finally, we look up that 8-bit value in a third table to get the final result.

from typing import NewType, Sequence
U8 = NewType("U8", int)
# 2-bit comparator value: (a == b) + 2 * (a > b), i.e. 0 (<), 1 (==), 2 (>).
U2 = NewType("U2", int)
U256 = NewType("U256", int)
U32 = NewType("U32", int)

def u256_to_u8s(x: "U256") -> "Sequence[U8]":
    """
    Off circuit preprocessing: split a u256 word into 32 bytes.

    The result is little-endian: index i holds ``(x >> (8 * i)) & 255``,
    so index 0 is the least significant byte.

    Raises ValueError if x is not in [0, 2**256).
    """
    # An explicit check (not `assert`) survives `python -O`. The lower bound
    # matters too: a negative int would otherwise be split silently in
    # two's complement.
    if not 0 <= x < 2**256:
        raise ValueError(f"x must be in [0, 2**256), got {x}")
    return [(x >> (8 * i)) & 0xFF for i in range(32)]

# Example 32-byte little-endian decompositions (index 0 = least significant
# byte, the same layout u256_to_u8s produces).
# a = 1
a8s = [1] + [0] * 31
# b = 257 = 1 + 1 * 256
b8s = [1, 1] + [0] * 30
# c = 258 = 2 + 1 * 256
c8s = [2, 1] + [0] * 30
class U8IsGreaterTable:
    """
    Lookup table comparing two 8-bit values.

    a: a 8 bit value
    b: a 8 bit value
    comparator: 2 bits, (a == b) + 2 * (a > b), i.e.
        0 if a < b, 1 if a == b, 2 if a > b
    (3 columns and 2**16 rows)
    """

    def __init__(self):
        self.rows = [
            {"a": a, "b": b, "comparator": (a == b) + 2 * (a > b)}
            for a in range(2**8)
            for b in range(2**8)
        ]
        # Set of row tuples so lookup is O(1) instead of scanning all 2**16
        # rows on every call (is_greater_than_equal does 32 lookups).
        self._index = frozenset(
            (row["a"], row["b"], row["comparator"]) for row in self.rows
        )

    def lookup(self, a: "U8", b: "U8", comparator: int) -> bool:
        """Return True iff (a, b, comparator) is a row of the table."""
        return (a, b, comparator) in self._index

class U16ComparatorTable:
    """
    Lookup table folding 8 packed 2-bit comparators into one.

    a: a 16 bit value of 8 groups of 2 bits input comparators. Group j sits
       in bits 2*j and 2*j+1 (packed as ``4**j * comparator_j``), so group 7
       is the most significant. Bit 0 of a group is the "equal" flag and
       bit 1 the "greater" flag, as in (a == b) + 2 * (a > b).
    output_comparator: 2 bits. Scanning from group 7 down, the first group
       that is not "equal" decides: 2 if it is "greater", 0 if it is "less";
       1 if all 8 groups are "equal".
    (2 columns and 2**16 rows, exactly one row per value of a)
    """

    def __init__(self):
        self.rows = []
        for a in range(2**16):
            # Result when no group breaks out of the scan: all groups equal.
            output_comparator = 1
            for i in reversed(range(8)):
                # Extract the i-th 2-bit group, then its two flags.
                group = (a >> (2 * i)) & 0b11
                eq_bit, gt_bit = group & 1, (group >> 1) & 1
                if gt_bit == 1:
                    # NOTE: group value 3 (both flags) cannot come out of
                    # (a == b) + 2 * (a > b); it is treated as "greater".
                    output_comparator = 2
                    break
                if eq_bit == 0:
                    output_comparator = 0
                    break
            # Exactly one row per `a`: the original appended an extra
            # "equal" row even after a break, making every `a` look equal.
            self.rows.append({"a": a, "output_comparator": output_comparator})
        # O(1) membership instead of scanning all 2**16 rows per lookup.
        self._index = frozenset(
            (row["a"], row["output_comparator"]) for row in self.rows
        )

    def lookup(self, a: int, comparator: int) -> bool:
        """Return True iff (a, comparator) is a row of the table.

        a is the 16-bit packed value; comparator is a 2-bit value.
        """
        return (a, comparator) in self._index

class U8ComparatorTable:
    """
    Final lookup table turning 4 packed 2-bit comparators into a >= b.

    a: a 8 bit value. 4 groups of 2 bits input comparators; group i sits in
       bits 2*i and 2*i+1 (packed as ``4**i * comparator_i``), so group 3 is
       the most significant. Bit 0 of a group is the "equal" flag and bit 1
       the "greater" flag.
    is_gte: bool. Scanning from group 3 down, the first group that is not
       "equal" decides: True if it is "greater", False if it is "less";
       True if all 4 groups are "equal".
    (2 columns and 2**8 rows, exactly one row per value of a)
    """

    def __init__(self):
        self.rows = []
        for a in range(2**8):
            # Result when no group breaks out of the scan: all groups equal.
            is_gte = True
            for i in reversed(range(4)):
                group = (a >> (2 * i)) & 0b11
                eq_bit, gt_bit = group & 1, (group >> 1) & 1
                if gt_bit == 1:
                    is_gte = True
                    break
                if eq_bit == 0:
                    is_gte = False
                    break
            # Exactly one row per `a`: the original appended an extra
            # is_gte=True row for every `a`, so any a >= b claim passed.
            self.rows.append({"a": a, "is_gte": is_gte})
        # O(1) membership instead of a linear scan per lookup.
        self._index = frozenset((row["a"], row["is_gte"]) for row in self.rows)

    def lookup(self, a: "U8", is_gte: bool) -> bool:
        """Return True iff (a, is_gte) is a row of the table."""
        return (a, is_gte) in self._index

def is_greater_than_equal(
    a: U256,
    b: U256,
    # [00] * 32: one 2-bit comparator per byte pair, little-endian order
    comparators: Sequence[int],
    # 00 00 00 00: one 2-bit comparator per group of 8 byte pairs
    comparators_2: Sequence[int],
    a_gte_b: bool,
):
    """
    validate a >= b

    Checks, via table lookups, that:
      * comparators[k] is the comparator of byte k of a and b,
      * comparators_2[i] folds comparators[8*i : 8*i + 8],
      * a_gte_b is the fold of the 4 values in comparators_2.
    Each 2-bit comparator is (x == y) + 2 * (x > y).
    (32 + 4 + 1 = 37 lookups)
    """
    # TODO: commitment and 8 bit range check
    a8s = u256_to_u8s(a)
    b8s = u256_to_u8s(b)

    # zip() below would silently skip missing comparators, so pin lengths.
    require(len(comparators) == 32, "should provide 32 byte comparators")
    require(len(comparators_2) == 4, "should provide 4 group comparators")

    # The lookup methods are instance methods, so build the tables first.
    # NOTE(review): these are rebuilt on every call (2**16 rows each);
    # hoist them if this model is used in a loop.
    u8_is_greater_table = U8IsGreaterTable()
    u16_comparator_table = U16ComparatorTable()
    u8_comparator_table = U8ComparatorTable()

    for a8, b8, cbits in zip(a8s, b8s, comparators):
        require(
            u8_is_greater_table.lookup(a8, b8, cbits),
            "should satisfy cbit = (a==b) + 2 * (a>b)"
        )

    # we end up 32 groups of 2 bits comparators here
    # We group 8 of them; comparators[8*i + 7] is the most significant.
    for i in range(4):
        # 00 00 00 00 00 00 00 00
        u16 = sum(4**j * comparators[8 * i + j] for j in range(8))
        require(
            u16_comparator_table.lookup(u16, comparators_2[i]),
            "should fold 8 byte comparators into comparators_2[i]"
        )

    # 00 00 00 00, comparators_2[3] in the most significant position
    u8 = sum(4**i * comparators_2[i] for i in range(4))
    require(
        u8_comparator_table.lookup(u8, a_gte_b),
        "should fold 4 group comparators into a_gte_b"
    )

Appendix


# Limbs are little-endian: index 0 is the least significant limb.
# Let x0 < x1 == y < x2 < x3 (as the numbers the limbs encode)
x0 = [5, 6]
x1 = [6, 6]
x2 = [7, 6]
x3 = [0, 7]

y = [6, 6]


def limbs_to_int(limbs, base=256):
    """Return the integer encoded by little-endian ``limbs`` in ``base``."""
    return sum(limb * base**i for i, limb in enumerate(limbs))


# A bare `x3 >= y` would use Python's lexicographic list ordering (index 0
# first) and evaluate to False here, so compare the encoded values instead.
assert not limbs_to_int(x0) >= limbs_to_int(y)  # False
assert limbs_to_int(x1) >= limbs_to_int(y)  # True
assert limbs_to_int(x2) >= limbs_to_int(y)  # True
assert limbs_to_int(x3) >= limbs_to_int(y)  # True

# U8 table use a>=b (per-limb comparison against y)
cbit_0 = [0, 1]
cbit_1 = [1, 1]
cbit_2 = [1, 1]
cbit_3 = [0, 1]
# x0 < y and x3 > y both give [0, 1]: indistinguishable.

# U8 table use a>b
cbit_0 = [0, 0]
cbit_1 = [0, 0]
cbit_2 = [1, 0]
cbit_3 = [0, 1]
# x0 < y and x1 == y both give [0, 0]: indistinguishable.

# Conclusion: It's not possible to use a>=b and a>b table to differentiate 4 cases

# Add extra carry bit
# U8 table use (a==b) + 2 * (a>b)
# Each trailing number equals cbit[0] + 3 * cbit[1]; the values increase in
# the same order as x0 < x1 < x2 < x3, so all 4 cases are distinguished.
cbit_0 = [0, 1] # 3
cbit_1 = [1, 1] # 4
cbit_2 = [2, 1] # 5
cbit_3 = [0, 2] # 6


Select a repo