# **Building a Digit Recognition Classifier and Verifier with PyTorch, Lilith and Next.js**

![image](https://hackmd.io/_uploads/BJd0f2a96.png)

This is part 1 of our tutorial on building the [MNIST-Clansifier](https://e2e-mnist.vercel.app/) demo app. The entire app is open source and viewable on GitHub; check out the model [here](https://github.com/zkonduit/e2e-mnist/blob/main/mnist_classifier.ipynb), the frontend [here](https://github.com/zkonduit/e2e-mnist) and the smart contract [here](https://goerli-optimism.etherscan.io/address/0xf5cDCD333E3Fd09929BAcEa32c2c1E3A5A746d45#code).

## **Introduction**

This tutorial demonstrates an end-to-end workflow: training a model for hand-drawn digit recognition, deploying it with Lilith, building a smart contract that interfaces with an EVM verifier, and building a frontend that calls on Lilith for delegated proving. The proofs are submitted to a contract that assigns the user's account to a digit clan (0-9). The contract keeps track of the member count of each clan, and the clan with the most members is the winner!

This application is a good faith fork of the first ZKML application, built by [horacepan](https://horacepan.github.io/), [sunfishstanford](https://github.com/sunfishstanford) and [henripal](https://github.com/henripal) as part of 0xPARC's winter 2021 applied zk learning group. You can find the original app [here](https://zkmnist.netlify.app/) and source code [here](https://github.com/0xZKML/zk-mnist). This is the project that inspired the creation of EZKL, and we are excited to replicate and gamify it :)

## **Data Preparation and Training**

- Dataset Loading: Uses PyTorch's torchvision.datasets.MNIST to load the MNIST dataset of handwritten digits.
- Normalize Data: Uses normalize_img to set all pixel values to binary (0 or 1), mimicking input data from the drawing interface we will build.
- Data Pipeline: Employs DataLoader for batching (256 images per batch) for both the training and testing datasets.

## **Training Process**

- Model Building: Uses the LeNet model defined in PyTorch with convolutional and fully connected layers.
- Model Configuration: The model is moved to the appropriate device (GPU/CPU) and trained with the Adam optimizer and the CrossEntropyLoss function.
- Training Loop: Runs for 25 epochs, including forward pass, loss calculation, and backpropagation, with model accuracy evaluated after each epoch.
- Exporting: Exports the PyTorch model to ONNX format and exports a sample dataset in JSON format.

## **Deploying to Lilith**

- Proving Service: Lilith, a backend proving service, simplifies the process of generating and managing zk-specific artifacts.
- Deployment Steps: Initializing the EZKL run args, gathering calibration data, deploying the model with the Lilith CLI, and setting up a Next.js serverless function for making proof requests to the Lilith API.

## **EVM Verifier Deployment and Verification**

- Verifier Contract: The verifier contract is deployed on the Goerli testnet of the Optimism L2 network.
- Proof Verification: We set up a Next.js serverless function that uses the Lilith API to fetch proofs on the client side; the proofs are then submitted on chain for verification.

## **Frontend and JS Binding**

- Tensor Preparation: Converts the drawn digit into a tensor for processing.
- Digit Classification: Analyzes the drawn digit and displays the prediction result from the `proof.pretty_public_inputs.outputs[0]` output.
- Proof Generation: Utilizes Lilith's backend [API](https://archon.ezkl.xyz/swagger-ui/) for remotely proving that the model classified a given digit.
- Smart Contract Integration: Uses wagmi's public provider and specifies the contract's ABI and address for on-chain verification.
# Model Tutorial

This is part 2 of our tutorial on building the [e2e-mnist](https://e2e-mnist.vercel.app) demo app, where we go over the model training and exporting process. To follow along with this portion of the tutorial, you can run the associated [notebook](https://github.com/zkonduit/e2e-mnist/blob/main/mnist_classifier.ipynb) in Google Colab.

[!embed aspect="1:1" height="340"](https://www.researchgate.net/publication/318972455/figure/fig2/AS:525282893615105@1502248609221/The-overall-LeNet-architecture-The-numbers-at-the-convolution-and-pooling-layers.png)

> Diagram of the LeNet architecture. Image source: [ResearchGate](https://www.researchgate.net/profile/Gerard-Pons-3/publication/318972455/figure/fig2/AS:525282893615105@1502248609221/The-overall-LeNet-architecture-The-numbers-at-the-convolution-and-pooling-layers.png)

### Data Preparation and Training

1. LeNet Model:
   - The LeNet model is defined using PyTorch's neural network module (torch.nn). It includes a series of convolutional layers (nn.Conv2d) and fully connected layers (nn.Linear), with activation functions (F.sigmoid) applied appropriately. The architecture is a classic choice for image classification tasks.

```python mnist_classifier.ipynb
import torch
import torch.nn as nn
import torch.nn.functional as F

class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        # Convolutional encoder
        self.conv1 = nn.Conv2d(1, 6, 5)   # 1 input channel, 6 output channels, 5x5 kernel
        self.conv2 = nn.Conv2d(6, 16, 5)  # 6 input channels, 16 output channels, 5x5 kernel

        # Fully connected layers / Dense block
        self.fc1 = nn.Linear(16 * 4 * 4, 120)
        self.fc2 = nn.Linear(120, 84)  # 120 inputs, 84 outputs
        self.fc3 = nn.Linear(84, 10)   # 84 inputs, 10 outputs (number of classes)

    def forward(self, x):
        # Convolutional block
        x = F.avg_pool2d(F.sigmoid(self.conv1(x)), (2, 2))  # Convolution -> Sigmoid -> Avg Pool
        x = F.avg_pool2d(F.sigmoid(self.conv2(x)), (2, 2))  # Convolution -> Sigmoid -> Avg Pool

        # Flattening
        x = x.view(x.size(0), -1)

        # Fully connected layers
        x = F.sigmoid(self.fc1(x))
        x = F.sigmoid(self.fc2(x))
        x = self.fc3(x)  # No activation function here, will use CrossEntropyLoss later
        return x
```

2. Dataset Loading:
   - The MNIST dataset, a collection of handwritten digits, is loaded using PyTorch's torchvision.datasets.MNIST class. This involves specifying parameters such as root for the storage directory, train to toggle between the training and testing sets, transform to apply transformations like converting images to tensor format, and download to fetch the data if it's not already present locally.

```python mnist_classifier.ipynb
import numpy as np
import os
import torch
from torchvision.datasets import mnist
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor

device = 'cuda' if torch.cuda.is_available() else 'cpu'
batch_size = 256

train_dataset = mnist.MNIST(root='./train', train=True, transform=ToTensor(), download=True)
test_dataset = mnist.MNIST(root='./test', train=False, transform=ToTensor(), download=True)
train_loader = DataLoader(train_dataset, batch_size=batch_size)
test_loader = DataLoader(test_dataset, batch_size=batch_size)
```
3. Normalization Function:
   - The normalize_img function is a crucial part of the data preparation. It rounds the pixel values of the images to 0 or 1. This step mimics the binary nature of the input data expected from a drawing interface, where pixels are either filled or not. We will build this interface in part 4 of the tutorial, where we focus on the frontend.

```python mnist_classifier.ipynb
def normalize_img(image, label):
    return torch.round(image), label
```

4. Data Pipeline for Training Set:
   - The training dataset undergoes normalization and is wrapped in the DataLoader class, which combines the dataset and a sampler, providing an iterable over the dataset.
   - Data is batched with a size of 256, meaning 256 images are processed in each iteration of training.
   - No explicit shuffling or prefetching is used here, but these could be incorporated for enhanced performance (see the sketch after this list).
5. Data Pipeline for Testing Set:
   - The testing dataset is similarly normalized and batched using the DataLoader.
   - The batching process for the test dataset mirrors the training set, facilitating a consistent evaluation process.
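If you do want shuffling and prefetching, the loaders could be constructed as follows. This is a minimal sketch, not part of the original notebook; `shuffle` and `num_workers` are standard `DataLoader` arguments.

```python
from torch.utils.data import DataLoader

# Sketch: the same pipeline, with per-epoch shuffling and background prefetching.
# shuffle=True re-orders the training samples each epoch; num_workers > 0
# loads batches in parallel worker processes.
train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=256)  # evaluation data needs no shuffling
```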
### Training

1. Model Configuration:
   - The model is moved to the appropriate device (GPU if available, otherwise CPU) using PyTorch's `.to(device)` method.
   - An instance of the Adam optimizer is created with the model's parameters. Adam is chosen for its efficiency in handling sparse gradients and its adaptive learning rate.
   - The loss function is CrossEntropyLoss, a standard choice for multi-class classification tasks.
2. Model Training:
   - The training loop runs for a predefined number of epochs (25 in this case). Each epoch consists of a forward pass where predictions are generated, a loss calculation, and a backward pass for gradient computation and parameter updates.
   - The accuracy of the model is evaluated at the end of each epoch on the test dataset. This provides insight into the model's generalization capability on unseen data.

```python mnist_classifier.ipynb
import numpy as np
import os
from torch.optim import Adam
from torch.nn import CrossEntropyLoss

model = LeNet().to(device)
adam = Adam(model.parameters())  # Using Adam optimizer
loss_fn = CrossEntropyLoss()
all_epoch = 25
prev_acc = 0
for current_epoch in range(all_epoch):
    model.train()
    for idx, (train_x, train_label) in enumerate(train_loader):
        train_x = train_x.to(device)
        # normalize the image to 0 or 1 to reflect the inputs from the drawing board
        train_x = train_x.round()
        train_label = train_label.to(device)
        adam.zero_grad()  # Use adam optimizer
        predict_y = model(train_x.float())
        loss = loss_fn(predict_y, train_label.long())
        loss.backward()
        adam.step()  # Use adam optimizer

    all_correct_num = 0
    all_sample_num = 0
    model.eval()
    for idx, (test_x, test_label) in enumerate(test_loader):
        test_x = test_x.to(device)
        # normalize the image to 0 or 1 to reflect the inputs from the drawing board
        test_x = test_x.round()
        test_label = test_label.to(device)
        predict_y = model(test_x.float()).detach()
        predict_y = torch.argmax(predict_y, dim=-1)
        current_correct_num = predict_y == test_label
        all_correct_num += np.sum(current_correct_num.to('cpu').numpy(), axis=-1)
        all_sample_num += current_correct_num.shape[0]
    acc = all_correct_num / all_sample_num
    print('test accuracy: {:.3f}'.format(acc), flush=True)
    if not os.path.isdir("models"):
        os.mkdir("models")
    torch.save(model, 'models/mnist_{:.3f}.pkl'.format(acc))
    prev_acc = acc
```

### Exporting to ONNX

The trained PyTorch model is then converted to the ONNX (Open Neural Network Exchange) format using torch.onnx.export, which facilitates model interoperability and makes the model compatible with various platforms and tools. The conversion involves specifying the input shape and the ONNX opset version (always set to 12 here). Alongside, an input data file (input.json) is created by extracting a sample from the training dataset and serializing it into JSON format. This data is used later for the Lilith deployment.

```python mnist_classifier.ipynb
import torch
import json
import os

model_path = os.path.join('network_lenet.onnx')

model.eval()  # Set the model to evaluation mode

# Fetch a single data point from the train_dataset
# Ensure train_dataset is already loaded and accessible
train_data_point, _ = next(iter(train_dataset))
train_data_point = train_data_point.unsqueeze(0)  # Add a batch dimension

# Verify the device (CPU or CUDA) and transfer the data point to the same device as the model
device = 'cuda' if torch.cuda.is_available() else 'cpu'
train_data_point = train_data_point.to(device)

# Export the model to ONNX format
torch.onnx.export(model, train_data_point, model_path,
                  export_params=True, opset_version=12,
                  do_constant_folding=True, input_names=['input_0'],
                  output_names=['output'])

# Convert the tensor to a numpy array and reshape it for JSON serialization
x = train_data_point.cpu().detach().numpy().reshape([-1]).tolist()
data = {'input_data': [x]}
with open('input.json', 'w') as f:
    json.dump(data, f)

print(f"Model exported to {model_path} and input data saved to input.json")
```
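The deployment steps below also expect a `cal_data.json` file containing a handful of training samples for calibration. The cell above only writes `input.json`; here is a minimal sketch of how such a file could be produced (the sample count of 20 and the rounding of images are our assumptions, not prescribed by the notebook; any modest set of representative inputs works):

```python
import json
import torch

# Sketch: serialize a small batch of (rounded) training images as calibration data,
# using the same {'input_data': [...]} layout as input.json.
num_samples = 20  # assumption: a handful of representative samples
cal_points = []
for i in range(num_samples):
    img, _ = train_dataset[i]  # train_dataset comes from the data-loading cell above
    cal_points.append(torch.round(img).reshape(-1).tolist())

with open('cal_data.json', 'w') as f:
    json.dump({'input_data': cal_points}, f)
```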
### Deploying to Lilith

If you have been following along with the `mnist_classifier` notebook, you should be on cell number [7](https://colab.research.google.com/github/zkonduit/e2e-mnist/blob/main/mnist_classifier.ipynb#scrollTo=dS-yXte30rZ3). There is another version of this notebook on the main EZKL repo where we perform the setup and proving all locally [here](https://colab.research.google.com/github/zkonduit/ezkl/blob/main/examples/notebooks/mnist_classifier.ipynb#scrollTo=dS-yXte30rZ3).

Comparing that local version of the notebook to the one on the e2e-mnist repo, you may notice the cells are identical up until cell 7. After this point the local version generates a lot of zk-specific artifacts (namely a settings file, a compiled circuit, a proving key, a verification key, a witness, a proof and a Solidity verifier), whereas the e2e-mnist notebook doesn't even import the ezkl library. These artifacts take a lot of compute resources to generate and are a lot to manage. Moreover, once you have the artifacts necessary to generate proofs, you then have to set up a proving server if you want to generate proofs on behalf of your users.

To make this process easier, we have created a backend proving service called [Lilith](https://app.ezkl.xyz/) that generates and manages these artifacts and makes serving proofs as easy as using the EZKL CLI, except all the commands are run on a remote compute cluster instead of locally.

To deploy your model to Lilith, you will need to log into Lilith with your GitHub account. Then generate an API key [here](https://app.ezkl.xyz/api-key) and add it to your shell's environment (the instructions for doing so are on the site). Now we can deploy our model to Lilith using Lilith's CLI, which we refer to as `archon`. Here is the set of commands we will run:

1. Create an artifact and name it whatever you want (we will call it mnist for this example). Make sure to replace `network_lenet.onnx` with the path to your model, and likewise for `input.json` and `cal_data.json`. The `cal_data.json` file contains a set of data points from the training set to use for calibration. Calibrating the circuit with sample inputs like this ensures that the lookup tables within our circuit are robust to outliers and therefore not likely to fall out of range of the [lookups](https://zcash.github.io/halo2/design/proving-system/lookup.html) (we represent non-linearities as lookup tables).

```bash
archon create-artifact -a mnist -m network_lenet.onnx -i input.json -c cal_data.json
```

2. Once the artifact is created, we can run a series of EZKL-specific commands to set up the circuit we need for generating proofs. We start by generating a settings file. We set the param visibility to fixed so that the parameters are hardcoded into the circuit and can't be updated later. We use the default values for the input and output visibility (private and public respectively).

```bash
archon job gen-settings -a mnist --param-visibility fixed
```

> Note: After each job is run, you will be given a job id. You can use this job id to check the status of the job with the `archon get -i <your job id>` command. Add the `-p` flag to poll the job until it is complete.

3. Next we calibrate the settings file, passing input scales of 2 and 7. These are the scales tried during calibration: under the hood, calibration iterates over the scales array to see how precise it can go before failure, so in the snippet below, if 7 fails it falls back to 2. (For reference, the current default scales are 8-10 for the resources target and 10-13 for the accuracy target.) We found that setting the scale any lower makes the accuracy of the circuit too low (it fails to recognize digits more often than the vanilla model), and setting it any higher makes the proving time too high (greater than 20 seconds), so 2 is a good middle ground.

```bash
archon job -a mnist calibrate-settings --scales 2,7
```

4. Now we can compile the circuit.

```bash
archon job -a mnist compile-circuit
```

5. Next we call the setup command to generate a proving and verifying key pair.

```bash
archon job -a mnist setup
```

6. Now we can generate a witness file, which we will use to generate a test proof to ensure that our model was deployed properly. :)

```bash
archon job -a mnist gen-witness
```

7. Finally, we can generate a proof.

```bash
archon job -a mnist gen-proof
```
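As the note in step 2 mentions, every job returns a job id that you can poll. A quick status check after kicking off any of the jobs above looks like this (substitute the id you were given):

```bash
# Check a job's status; -p polls until it completes.
archon get -i <your job id> -p
```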
# Contracts Tutorial

This is part 3 of our tutorial on building the [e2e-mnist](https://e2e-mnist.vercel.app) demo app, where we go over the development of the MNIST clan contract. Check out the contract code [here](https://goerli-optimism.etherscan.io/address/0xf5cDCD333E3Fd09929BAcEa32c2c1E3A5A746d45#code).

The MNIST clan contract stores the hand-drawn digits of user accounts and calls the on-chain EVM verifier of the corresponding digit recognition model to validate each submitted digit.

## Developing MnistClan.sol

**1. Verifier Contract Interface:**

The contract leverages an external Verifier contract to verify the proofs submitted by users. The Verifier interface has one function, **`verifyProof`**, which takes a proof and the public inputs (instances) as parameters and returns a boolean indicating whether the proof is valid.

```solidity
// SPDX-License-Identifier: MIT
pragma solidity ^0.8.13;

interface Verifier {
    function verifyProof(
        bytes calldata proof,
        uint256[] calldata instances
    ) external returns (bool);
}
```

**2. Constants and State Variables:**

- Verifier Reference: The contract holds a reference to the Verifier contract, which is set at the time of contract deployment.
- Constants:
  - ORDER: The order of the [prime field](https://0xparc.org/blog/zk-pairing-2#:~:text=Although%20our%20circuits%20target%20the%20BLS12%2D381%20curve%2C%20they%20can%20be%20easily%20adapted%20to%20other%20curves.%20We%20have%20recently%20modified%20our%20circuits%20to%20work%20for%20the%20BN254%20curve%2C%20which%20Ethereum%20supports%20with%20precompiles%20for%20elliptic%20curve%20arithmetic%20and%20pairings.%20This%20used%20the%20following%20modifications%3A) EZKL uses.
  - THRESHOLD: Used in the feltToInt function to determine the cut-off at which a field element stops representing a positive integer and starts representing a negative one.
- Mappings:
  - entered: Tracks whether an account has already submitted a digit.
  - clan: Maps an account to its chosen digit.
  - counts: Keeps track of the number of submissions for each digit.

```solidity
contract MnistClan {
    // Reference to the verifier contract used to check submitted proofs.
    Verifier public verifier;

    /**
     * @notice EZKL P value
     * @dev In order to prevent the verifier from accepting two versions of the same pubInput, n and the quantity (n + P), where n + P <= 2^256, we require that all instances are strictly less than P.
     * @dev The reason for this is that the assembly code of the verifier performs all arithmetic operations modulo P and as a consequence can't distinguish between n and n + P.
     */
    uint256 constant ORDER = 0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001;

    uint256 private constant THRESHOLD = uint128(type(int128).max);

    // Keeps track of whether an account has already entered a digit.
    mapping(address => bool) public entered;

    // Account's clan
    mapping(address => uint8) public clan;

    // Keeps track of the number of submissions for each digit.
    mapping(uint8 => uint256) public counts;

    constructor(Verifier _verifier) {
        verifier = _verifier;
    }
```
**3. Core Functions**

- submitDigit
  - Purpose: Allows users to submit a digit classification along with a proof.
  - Process:
    - Ensures one submission per account.
    - Validates the proof using the Verifier.
    - Determines the digit with the max value from the instances using feltToInt. We need this conversion because some of the outputs of the model can be negative.
    - Updates the `entered` and `clan` mappings.
    - Increments the count of submissions for the submitted digit.

```solidity
    function submitDigit(
        bytes calldata proof,
        uint256[] calldata instances
    ) public {
        // One submission per account
        require(!entered[msg.sender]);
        // Verify EZKL proof.
        require(verifier.verifyProof(proof, instances));

        // retrieve the index with the max value
        uint256 maxIndex = 0;
        int256 maxValue = feltToInt(instances[0]);
        for (uint256 i = 1; i < instances.length; i++) {
            int256 adjustedValue = feltToInt(instances[i]);
            if (adjustedValue > maxValue) {
                maxValue = adjustedValue;
                maxIndex = i;
            }
        }

        // update the entered mapping
        entered[msg.sender] = true;
        // update clan mapping
        clan[msg.sender] = uint8(maxIndex);
        // Update the counts mapping (this should be safe since instances length will always be == 10)
        uint256 count = ++counts[uint8(maxIndex)];
        counts[uint8(maxIndex)] = count;
    }
```

- feltToInt
  - Purpose: Converts a field element to an integer.
  - Process:
    - If the field element is at most the threshold, it is returned as is.
    - Otherwise, the prime field order is subtracted from the field element, yielding a negative integer.

```solidity
    function feltToInt(uint256 value) private pure returns (int256) {
        if (value > THRESHOLD) {
            return int256(value) - int256(ORDER);
        } else {
            return int256(value);
        }
    }
```

- getCounts
  - Purpose: Getter function returning the contents of the `counts` mapping.
  - Use Case: Makes getting the clan counts easier on the frontend, as we don't have to call the mapping iteratively.

```solidity
    function getCounts() public view returns (uint256[10] memory) {
        uint256[10] memory countsArray;
        for (uint8 i = 0; i < 10; i++) {
            countsArray[i] = counts[i];
        }
        return countsArray;
    }
}
```
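To build intuition for the field-element conversion, here is the same logic sketched in Python. This is our own illustration using the ORDER and THRESHOLD constants from the contract above; it is not part of the app's codebase.

```python
# Sketch: decode an EZKL field element into a signed integer,
# mirroring feltToInt in MnistClan.sol.
ORDER = 0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001
THRESHOLD = 2**127 - 1  # uint128(type(int128).max)

def felt_to_int(value: int) -> int:
    # Field elements above the threshold encode a negative n as ORDER + n.
    return value - ORDER if value > THRESHOLD else value

assert felt_to_int(5) == 5           # small positives pass through unchanged
assert felt_to_int(ORDER - 3) == -3  # ORDER - 3 decodes to -3
```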
## Deploying the contracts

**1. Download Solidity Verifier from Lilith**

- First, run this Lilith job from the CLI to generate the verifier contract:

```bash
archon job create-evm-verifier
```

- The verifier contract is available on [Lilith](https://app.ezkl.xyz/artifacts/mnist) and can be downloaded by clicking the 'Download' button next to the `evm_deploy.sol` artifact name. Copy and paste this code into a new file in the Remix IDE.

**2. Adjust compiler settings.**

- Setting EVM Version: As you prepare to deploy the verifier and MNIST clan contracts, it is critical to set the Ethereum Virtual Machine (EVM) version to a configuration that's compatible with layer 2 blockchains. Our testing has shown the 'London' version to be most compatible. For this tutorial we'll use the [Remix](https://remix.ethereum.org/) IDE to deploy our contracts, but you could also use Foundry, Hardhat or your preferred EVM-compatible smart contract development framework. To change the EVM version to `London`, navigate to the `Advanced Configurations` tab and select `London` from the `EVM Version` dropdown list. Neglecting to make this adjustment may result in an unsuccessful deployment of your verifier Solidity code, often manifesting as a 'push0 not a valid opcode' error.
- Enabling Optimizations: In the `Compiler` tab, make sure `Enable optimization` is checked and set to `1`, else you will run into a `stack too deep` error.

![](../../assets/mnistcontracts0.png)

**3. Deployment**

- Deploy the verifier contract first, as you will need to pass the address of the deployed verifier to the `MnistClan` contract's constructor. Click the page icon next to the 'x' on the deployed verifier instance to copy its address, then paste it into the `_verifier` deploy param of `MnistClan.sol`.

# Frontend Tutorial

This is part 4 of our tutorial on building the [e2e-mnist](https://e2e-mnist.vercel.app) demo app; check out the [frontend code](https://github.com/zkonduit/e2e-mnist).

## Overview

Armed with the artifacts we need to prove and verify on-chain, we need to build a frontend that can:

1. [Collect input data from the user (drawn digits)](#step-1-collecting-input-data)
2. [Retrieve proofs from Lilith](#step-2-generating-proofs-using-lilith)
3. [Submit proofs for on-chain verification](#step-3-verifying-on-chain)

## Step 1. Collecting Input Data

The first step is to collect the input data from the user. In our case, we want to collect a 28x28 pixel image of a hand-drawn digit. Here is the Next.js code that renders a drawing board to facilitate such collection:

```tsx MNISTDraw.tsx
'use client'
import { useState, useCallback } from 'react'
import './MNIST.css'
import './App.css'

const GRID_SIZE = 28 as const

interface IMNISTBoardProps {
  grid: number[][]
  onChange: (row: number, col: number) => void
}

interface IGridSquareProps {
  isActive: boolean
  onMouseDown: () => void
  onMouseEnter: () => void
  onMouseUp: () => void
}

function GridSquare({
  isActive,
  onMouseDown,
  onMouseEnter,
  onMouseUp,
}: IGridSquareProps) {
  return (
    <div
      className={`square ${isActive ? 'on' : 'off'}`}
      onMouseEnter={onMouseEnter}
      onMouseDown={onMouseDown}
      onMouseUp={onMouseUp}
    />
  )
}

function MNISTBoard({ grid, onChange }: IMNISTBoardProps) {
  const [mouseDown, setMouseDown] = useState(false)

  const handleMouseDown = useCallback(
    (row: number, col: number) => {
      setMouseDown(true)
      onChange(row, col)
    },
    [onChange]
  )

  const handleMouseUp = useCallback(() => {
    setMouseDown(false)
  }, [])

  const handleMouseEnter = useCallback(
    (row: number, col: number) => {
      if (mouseDown) {
        onChange(row, col)
      }
    },
    [mouseDown, onChange]
  )

  const size = GRID_SIZE
  return (
    <div className='MNISTBoard'>
      <div className='centerObject'>
        <div className='grid'>
          {Array.from({ length: size }, (_, col) => (
            <div key={`col-${col}`}>
              {Array.from({ length: size }, (_, row) => (
                <GridSquare
                  key={`row-${row}-col-${col}`}
                  isActive={grid[row][col] === 1}
                  onMouseDown={() => handleMouseDown(row, col)}
                  onMouseEnter={() => handleMouseEnter(row, col)}
                  onMouseUp={handleMouseUp}
                />
              ))}
            </div>
          ))}
        </div>
      </div>
    </div>
  )
}

export function MNISTDraw() {
  const size = GRID_SIZE
  const [grid, setGrid] = useState(
    Array(size)
      .fill(null)
      .map(() => Array(size).fill(0))
  )

  function handleSetSquare(myrow: number, mycol: number) {
    const newArray = grid.map((row, i) => (i === myrow ? [...row] : row))
    newArray[myrow][mycol] = 1
    setGrid(newArray)
  }

  return (
    <>
      <div className='MNISTPage'>
        <h1 className='text-2xl'>Draw and classify a digit</h1>
        <MNISTBoard grid={grid} onChange={handleSetSquare} />
      </div>
    </>
  )
}
```
## Step 2. Generating Proofs Using Lilith

Next, we need to generate a proof using Lilith. The API key we use to authenticate with Lilith must be kept secret, so we store it in an environment variable and access it in our code. We also set up a Next.js serverless function to make proof requests to the Lilith API, ensuring the API key is never exposed to the client. Here are the steps:

1. Create a `.env.local` file in the root of the project and add the following lines to it:

```
# Your API key, generated from the app.ezkl.xyz dashboard
ARCHON_API_KEY="insert-your-archon-api-key-here"
# The URL we will make Lilith API requests to
NEXT_PUBLIC_ARCHON_URL="https://archon.ezkl.xyz"
```

2. Create a Next.js serverless function that will make proof requests to the Lilith API. Create a folder called `pages/api` in the root of the project, then create a file called `generateProof.tsx` in the `api` folder and add the following code to it:

> NOTE: For more details on the Lilith API schema, check out the API docs [here](https://archon.ezkl.xyz/swagger-ui/)

```tsx generateProof.tsx
// pages/api/generateProof.tsx
// Route that calls into the Lilith API to generate a proof by:
// 1. Uploading the input data to Lilith
// 2. Calling GenWitness and Prove commands on the uploaded data
// 3. Polling for the status of the proof generation
// 4. Returning the proof once it is generated
import type { NextApiRequest, NextApiResponse } from 'next';
import axios from 'axios';

export default async function handler(req: NextApiRequest, res: NextApiResponse) {
  if (req.method === 'POST') {
    try {
      let formData = new FormData();
      formData.append("data", new Blob([JSON.stringify(req.body)], { type: "application/json" }));

      const axiosResponse = await axios.put(`${process.env.NEXT_PUBLIC_ARCHON_URL}/artifact/mnist`, formData, {
        headers: {
          'X-API-KEY': process.env.ARCHON_API_KEY,
          'Content-Type': 'multipart/form-data' // Ensures the correct Content-Type header for multipart/form-data
        },
      });
      // get the latest uuid
      const uuid = axiosResponse.data.latest_uuid;

      // Prepare data for gen-witness and prove requests
      const requestBody = [
        {
          "ezkl_command": {
            "GenWitness": {
              "data": `input_${uuid}.json`,
              "compiled_circuit": "model.compiled",
              "output": `witness_${uuid}.json`,
            },
          },
          "working_dir": "mnist",
        },
        {
          "ezkl_command": {
            "Prove": {
              "witness": `witness_${uuid}.json`,
              "compiled_circuit": "model.compiled",
              "pk_path": "pk.key",
              "proof_path": `proof_${uuid}.json`,
              "proof_type": "Single",
              "check_mode": "UNSAFE",
            },
          },
          "working_dir": "mnist",
        },
      ];

      // Prove request using axios
      const proveRes = await axios.post(`${process.env.NEXT_PUBLIC_ARCHON_URL}/recipe`, requestBody, {
        headers: {
          'X-API-KEY': process.env.ARCHON_API_KEY,
        }
      });
      console.log("proveRes.data.id: ", proveRes.data.id)

      // Poll until the proof recipe completes
      let getProofResp
      let status = null
      while (status !== 'Complete') {
        getProofResp = await axios.get(`${process.env.NEXT_PUBLIC_ARCHON_URL}/recipe/${proveRes.data.id}`, {
          headers: {
            'X-API-KEY': process.env.ARCHON_API_KEY
          }
        });
        status = getProofResp.data[1].status
        if (status === 'Complete') {
          break
        }
        await new Promise((resolve) => setTimeout(resolve, 2_000))
      }
      let parsedData = JSON.parse(getProofResp?.data[1].output)
      res.status(200).json({ message: 'Proof generation successful', data: parsedData });
    } catch (error) {
      console.error(error);
      res.status(500).json({ error });
    }
  } else {
    res.setHeader('Allow', ['POST']);
    res.status(405).end(`Method ${req.method} Not Allowed`);
  }
}
```
3. Now we make a `POST` request to the `generateProof` route from the frontend. Once we get the proof back, we parse the rescaled outputs of the model from it and find the index of the max value of the outputs array; that index is the predicted digit. Here is the code that does this:

> NOTE: The `rescaled_outputs` field of the proof contains the field element outputs of the model converted back into floating point values, based on the scaling factor of the circuit's outputs.

```tsx ezkl.ts
const [prediction, setPrediction] = useState<number>(-1);
const [grid, setGrid] = useState(
  Array(size)
    .fill(null)
    .map(() => Array(size).fill(0))
); // initialize to a 28x28 array of 0's
const [generatingProof, setGeneratingProof] = useState(false);
const [proofDone, setProofDone] = useState(false);
const [proof, setProof] = useState<any>(null);

async function doProof() {
  // get image from grid
  let imgTensor: number[] = Array(MNISTSIZE).fill(0);
  for (let i = 0; i < size; i++) {
    for (let j = 0; j < size; j++) {
      imgTensor[i * size + j] = grid[i][j];
    }
  }
  const inputFile = JSON.stringify({ input_data: [imgTensor] });
  console.log("inputFile", inputFile);

  setGeneratingProof(true);
  try {
    const responseProof = await fetch('/api/generateProof', {
      method: 'POST',
      headers: {
        'Content-Type': 'application/json', // Set Content-Type to application/json
      },
      body: inputFile, // Send the inputFile string as the request body
    });
    let result = await responseProof.json();
    console.log(result);
    setProof(result?.data)
    const results = result?.data?.pretty_public_inputs?.rescaled_outputs[0]
    console.log('results', results)
    if (!results || results.length === 0) {
      throw new Error('Array is empty')
    }
    // find the index of the max value; this is the predicted digit
    let maxIndex = 0
    let maxValue = results[0]
    for (let i = 1; i < results.length; i++) {
      if (results[i] > maxValue) {
        maxValue = results[i]
        maxIndex = i
      }
    }
    setPrediction(maxIndex)
    setProofDone(true)
  } catch (error) {
    alert(error);
    console.log('error', error)
  }
  setGeneratingProof(false)
}
```
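To demystify that note: a public output starts life as a field element, is decoded to a signed integer exactly as the contract's feltToInt does, and is then divided by the scaling factor. Here is a rough Python sketch of the rescaling (our paraphrase, assuming `scale` is the output scale fixed during calibration; the frontend never has to do this itself, since Lilith returns `rescaled_outputs` precomputed):

```python
# Sketch: how a field element output becomes a float in
# pretty_public_inputs.rescaled_outputs.
ORDER = 0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001
THRESHOLD = 2**127 - 1

def rescale_output(value: int, scale: int) -> float:
    signed = value - ORDER if value > THRESHOLD else value  # same as feltToInt
    return signed / 2**scale  # undo the fixed-point scaling

# e.g. with scale = 2, the field element ORDER - 6 decodes to -6 / 4 = -1.5
assert rescale_output(ORDER - 6, 2) == -1.5
```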
## Step 3. Verifying on Chain

Last but not least, we need to verify the proof on chain. For this we will call the `submitDigit` method on the `MnistClan.sol` contract, passing in the `proof` and `instances` as arguments.

- First we need to create a wagmi provider and a Rainbow wallet, which give us read access to the contract deployed on the Optimism Goerli network and a wallet connection respectively. Here is how that looks:

```tsx wagmi.tsx
"use client";
import '@rainbow-me/rainbowkit/styles.css';
import { getDefaultWallets, RainbowKitProvider } from '@rainbow-me/rainbowkit';
import { configureChains, createConfig, WagmiConfig } from 'wagmi';
import { alchemyProvider } from 'wagmi/providers/alchemy';
import { publicProvider } from 'wagmi/providers/public';
import { optimismGoerli } from "wagmi/chains";

const { chains, publicClient } = configureChains(
  [optimismGoerli],
  [
    alchemyProvider({
      apiKey: process.env.NEXT_PUBLIC_ALCHEMY_KEY!,
    }),
    publicProvider()
  ]
);

const { connectors } = getDefaultWallets({
  appName: 'My RainbowKit App',
  projectId: 'YOUR_PROJECT_ID',
  chains
});

const wagmiConfig = createConfig({
  autoConnect: true,
  connectors,
  publicClient
})

export default function WagmiProvider({
  children,
}: {
  children: React.ReactNode;
}) {
  return (
    <WagmiConfig config={wagmiConfig}>
      <RainbowKitProvider chains={chains}>
        {children}
      </RainbowKitProvider>
    </WagmiConfig>
  )
}
```

- Next we wrap this provider around the app so that we can use wagmi's hooks to access the contract. Here is how that looks:

```tsx layout.tsx
import type { Metadata } from "next";
import "./globals.css";
import WagmiProvider from "@/providers/wagmi";

export const metadata: Metadata = {
  title: "MNIST Clan",
  description: "Submit a ZKML proof of your classified hand-drawn digit to join the clan.",
};

export default function RootLayout({
  children,
}: {
  children: React.ReactNode;
}) {
  return (
    <html lang="en">
      <body>
        <WagmiProvider>{children}</WagmiProvider>
      </body>
    </html>
  );
}
```

- Next we create a couple of [JSON files](https://github.com/zkonduit/e2e-mnist/tree/main/contract_data) that store the address and ABI of each contract we deployed, and import that info into our app. You can get the ABI from Remix by clicking on the Solidity compiler icon and then the ABI button at the bottom, under "Compilation Details". (A sketch of the shape of these files appears at the end of this tutorial.)
- Finally, we instantiate both the MNIST clan and verifier contracts. If the user hasn't submitted a digit already, the proof of digit recognition is sent to the `MnistClan.sol` contract when the `SubmitMnistDigitButton` is clicked. If the user has already submitted a digit, the `VerifyOnChainButton` appears instead, which merely statically calls the `verifyProof` method on the `Halo2Verifier.sol` contract.
- We also display the counts of each digit submitted by all users in the form of a bar graph, as well as the user's clan and rank within the clan.
Here is the code that does this:

```tsx MNISTDraw.tsx
'use client'
import { Modal } from 'flowbite-react'
import { useState, useEffect, FC } from 'react'
import './MNIST.css'
import './App.css'
import { Button } from '@/components/button/Button'
import styles from '../../app/styles.module.scss'
import { stringify } from 'json-bigint'
import { getContract } from 'wagmi/actions'
import { publicProvider } from 'wagmi/providers/public'
import {
  useAccount,
  usePrepareContractWrite,
  useContractWrite,
  useWaitForTransaction
} from 'wagmi'
import { ConnectButton } from '@rainbow-me/rainbowkit';
import BarGraph from '../bargraph/BarGraph'; // Adjust the path as necessary
import MNIST from '../../contract_data/MnistClan.json'
import Verifier from '../../contract_data/Halo2Verifier.json'

const size = 28
const MNISTSIZE = 784

export function MNISTDraw() {
  const [openModal, setOpenModal] = useState<string | undefined>()
  const props = { openModal, setOpenModal }
  const [prediction, setPrediction] = useState<number>(-1)
  const [proof, setProof] = useState<any | null>(null)
  const [generatingProof, setGeneratingProof] = useState(false)
  const [counts, setCounts] = useState<number[] | null>(null)
  const [clan, setClan] = useState<number | null>(null)
  const [clanRank, setClanRank] = useState<number | null>(null)
  const [verifyResult, setVerifyResult] = useState<boolean | null>(null)
  const [proofDone, setProofDone] = useState(false)
  const [grid, setGrid] = useState<number[][]>(
    Array(size)
      .fill(null)
      .map(() => Array(size).fill(0))
  ) // initialize to a 28x28 array of 0's

  const { address, isConnected } = useAccount()

  const { config } = usePrepareContractWrite({
    address: MNIST.address as `0x${string}`,
    abi: MNIST.abi,
    functionName: 'submitDigit',
    args: [proof?.hex_proof, proof?.pretty_public_inputs?.outputs[0]],
    enabled: true,
  })
  const { data, error, isError, write } = useContractWrite(config)
  const { isLoading, isSuccess } = useWaitForTransaction({
    hash: data?.hash,
  })

  const provider = publicProvider()

  // Instantiate the contract using wagmi's getContract action
  const contract = getContract({
    address: MNIST.address as `0x${string}`,
    abi: MNIST.abi,
    walletClient: publicProvider(),
    chainId: 420,
  })
  async function getAccountClanInfo() {
    let entry = await contract.read.entered([address]) as boolean
    let clan = await contract.read.clan([address]) as number
    setClan(entry ? clan : null)
    console.log('entry', entry)
    console.log('clan', clan)
    let counts = await contract.read.getCounts() as number[]
    // convert BigInt to number
    counts = counts.map((count) => Number(count))
    setCounts(counts)
    if (!entry) {
      return
    }
    // determine clan rank
    let rank = 1
    for (let i = 0; i < counts.length; i++) {
      if (counts[i] > counts[clan]) {
        rank++
      }
    }
    setClanRank(rank)
    console.log('counts', counts)
  }

  useEffect(() => {
    (async () => {
      // note: compare against null explicitly so that clan 0 (the digit zero)
      // is not treated as "no clan"
      if (isConnected && (clan === null || isSuccess)) {
        getAccountClanInfo()
      }
      if (!isConnected && clan !== null) {
        setClan(null)
        setCounts(null)
      }
    })()
  }, [isConnected, isSuccess, address])

  // Reload clan info when the account changes
  useEffect(() => {
    if (isConnected) {
      setClan(null)
      setCounts(null)
    }
  }, [address, isConnected]);

  function ShowClanResultsBlock() {
    if (!counts) {
      return
    }
    return (
      <div>
        <div className="MNISTClanChart">
          <div className="chart-container">
            <BarGraph data={counts} />
          </div>
        </div>
      </div>
    )
  }

  async function doProof() {
    // get image from grid
    let imgTensor: number[] = Array(MNISTSIZE).fill(0);
    for (let i = 0; i < size; i++) {
      for (let j = 0; j < size; j++) {
        imgTensor[i * size + j] = grid[i][j];
      }
    }
    const inputFile = JSON.stringify({ input_data: [imgTensor] });
    console.log("inputFile", inputFile);

    setGeneratingProof(true);
    try {
      const responseProof = await fetch('/api/generateProof', {
        method: 'POST',
        headers: {
          'Content-Type': 'application/json', // Set Content-Type to application/json
        },
        body: inputFile, // Send the inputFile string as the request body
      });
      let result = await responseProof.json();
      console.log(result);
      setProof(result?.data)
      const results = result?.data?.pretty_public_inputs?.rescaled_outputs[0]
      console.log('results', results)
      if (!results || results.length === 0) {
        throw new Error('Array is empty')
      }
      let maxIndex = 0
      let maxValue = results[0]
      for (let i = 1; i < results.length; i++) {
        if (results[i] > maxValue) {
          maxValue = results[i]
          maxIndex = i
        }
      }
      setPrediction(maxIndex)
      setProofDone(true)
    } catch (error) {
      alert(error);
      console.log('error', error)
    }
    setGeneratingProof(false)
  }

  async function doOnChainVerify() {
    let verifierContract = getContract({
      address: Verifier.address as `0x${string}`,
      abi: Verifier.abi,
      walletClient: provider,
      chainId: 420,
    })

    let result = await verifierContract.read.verifyProof([
      proof?.hex_proof,
      proof?.pretty_public_inputs?.outputs[0]
    ]) as boolean
    setVerifyResult(result);
  }

  async function doSubmitMnistDigit() {
    if (!write) {
      return
    }
    write()
  }

  function resetImage() {
    var newArray = Array(size)
      .fill(null)
      .map((_) => Array(size).fill(0))
    setGrid(newArray)
    setProofDone(false)
    setVerifyResult(null)
  }

  function handleSetSquare(myrow: number, mycol: number) {
    var newArray = []
    for (var i = 0; i < grid.length; i++) newArray[i] = grid[i].slice()
    newArray[myrow][mycol] = 1
    setGrid(newArray)
  }

  function ProofButton() {
    return (
      <Button
        className={styles.button}
        text='Classify & Prove'
        loading={generatingProof}
        loadingText='Proving...'
        onClick={doProof}
      />
    )
  }

  function VerifyOnChainButton() {
    return (
      <Button
        className={styles.button}
        text='Verify On Chain'
        disabled={!proofDone}
        loading={isLoading}
        loadingText='Verifying...'
        onClick={doOnChainVerify}
      />
    )
  }

  function SubmitMnistDigitButton() {
    return (
      <Button
        className={styles.button}
        text='Submit Mnist Digit'
        disabled={!proofDone || !write || isLoading}
        loading={isLoading}
        loadingText='Verifying...'
        onClick={doSubmitMnistDigit}
      />
    )
  }

  function ResetButton() {
    return (
      <Button className={styles.button} text='Reset' onClick={resetImage} />
    )
  }

  function ProofBlock() {
    return (
      <div className='proof'>
        <Button
          className='w-auto'
          onClick={() => props.setOpenModal('default')}
          data-modal-target='witness-modal'
          data-modal-toggle='witness-modal'
          text='Show Proof'
        />
        <Modal
          show={props.openModal === 'default'}
          onClose={() => props.setOpenModal(undefined)}
        >
          <Modal.Header>Proof: </Modal.Header>
          <Modal.Body className='bg-black'>
            <div className='mt-4 p-4 bg-black-100 rounded'>
              <pre className='whitespace-pre-wrap'>
                {stringify(proof, null, 6)}
              </pre>
            </div>
          </Modal.Body>
        </Modal>
      </div>
    )
  }

  function PredictionBlock() {
    return (
      <div className='predction color-white'>
        <h1>Prediction</h1>
        {prediction}
      </div>
    )
  }

  function VerifyOnChainBlock() {
    return (
      <div className='verify'>
        <h1 className='text-2xl'>
          Verified on chain:{' '}
          <a
            href={`https://goerli-optimism.etherscan.io/address/${Verifier.address}#code`}
            target='_blank'
            rel='noopener noreferrer'
            style={{ textDecoration: 'underline' }}
          >
            {Verifier.address}
          </a>
        </h1>
      </div>
    )
  }

  if (proofDone && isError) {
    window.alert(`Transaction failed on MnistClan contract:${error?.message}`)
  }

  return (
    <div className='MNISTPage'>
      <h1 className='text-2xl'>Draw and classify a digit</h1>
      <MNISTBoard grid={grid} onChange={(r, c) => handleSetSquare(r, c)} />
      <div className='flex justify-center pt-7'>
        <ConnectButton />
      </div>
      {clan !== null && <h1 className='text-2xl pt-7'>Your MNIST Clan: {clan}</h1>}
      {clan !== null && <h1 className='text-2xl'>Your Clan Rank: {clanRank}</h1>}
      <div className='buttonPanel'>
        <ProofButton />
        {clan !== null ? <VerifyOnChainButton /> : <SubmitMnistDigitButton />}
        <ResetButton />
      </div>
      {proofDone && PredictionBlock()}
      {proofDone && ProofBlock()}
      {(isSuccess || !(verifyResult == null)) && VerifyOnChainBlock()}
      {clan !== null && ShowClanResultsBlock()}
    </div>
  )
}
```
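For reference, each file in `contract_data/` simply pairs a deployed address with its ABI. Here is a sketch of the shape of `MnistClan.json`: the address is the contract linked in part 3, and the ABI is truncated to the single `submitDigit` entry (derived from the Solidity shown earlier) — in practice you would paste in the full ABI copied from Remix.

```json
{
  "address": "0xf5cDCD333E3Fd09929BAcEa32c2c1E3A5A746d45",
  "abi": [
    {
      "type": "function",
      "name": "submitDigit",
      "stateMutability": "nonpayable",
      "inputs": [
        { "name": "proof", "type": "bytes" },
        { "name": "instances", "type": "uint256[]" }
      ],
      "outputs": []
    }
  ]
}
```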