Full disclosure: this post is not about gardening, but about implementing ZK versions of machine learning algorithms with botanical nomenclature: decision trees, gradient boosted trees, and random forests. If you're a keen gardener, check this out.
Lingering GitHub issues give me heart palpitations, particularly those that have been open for months on end, sitting like mildew in an otherwise pristine home.
Here's one we've had open since January of this year:
EZKL (for those not in the know) is a library for converting common computational graphs, in the (quasi-)universal .onnx format, into zero-knowledge (ZK) circuits. This allows, for example, for:
Though our library has grown in the scope of models it supports, including transformer-based models (see here for a writeup), GANs, and LSTMs, implementing Kaggle-crushing models like random forests and gradient boosted trees has remained challenging.
Part of the issue stems from the way sklearn, xgboost, and lightgbm models are exported to .onnx. Instead of decomposing tree-based models into individual operations, like matrix multiplication, addition, and comparisons, the whole model is exported as a single node (see image above)!
Given our library's focus on modularity and composability, this has been a bit of an anti-pattern, a proverbial thorn in our side.
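To make the problem concrete, here's a minimal sketch (not from our codebase; the model and variable names are illustrative) of what a standard skl2onnx export produces: the entire forest collapses into a single TreeEnsembleClassifier op, with nothing resembling matmuls or comparisons to map onto circuit operations.

```python
# Sketch: export a small sklearn forest with skl2onnx and inspect the op types.
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from skl2onnx import to_onnx

X, y = make_classification(n_samples=200, n_features=4, random_state=0)
clf = RandomForestClassifier(n_estimators=10, max_depth=3, random_state=0).fit(X, y)

onx = to_onnx(clf, X[:1].astype(np.float32))
print([node.op_type for node in onx.graph.node])
# e.g. ['TreeEnsembleClassifier', 'ZipMap'] -- one opaque node for the whole ensemble
```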
This weekend, after yet another call with users asking for a timeline for implementing such models, we decided to roll up our sleeves and get it done in 48 hours. Check out a colab example here.
Here's what it took.
As noted above, having single-node onnx graphs is an anti-pattern, something that might destroy our library's clean architecture if we tried to accommodate it. A much better approach is to convert the single-node graph into its individual operations. Luckily, we are not the only folks in history keen to do this, and we landed on the beautifully coded sk2torch library, which takes a graph of this form:
And turns it into something like this:
So much nicer!
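For a single fitted tree, the conversion is essentially a one-liner. Here's a rough sketch (names are illustrative; we call the wrapped module directly, just as the forest module below does):

```python
# Sketch: wrap a fitted sklearn decision tree as a torch.nn.Module with sk2torch.
import torch
import sk2torch
from sklearn.datasets import make_regression
from sklearn.tree import DecisionTreeRegressor

X, y = make_regression(n_samples=200, n_features=4, random_state=0)
tree = DecisionTreeRegressor(max_depth=4).fit(X, y)

torch_tree = sk2torch.wrap(tree)  # comparisons / gathers instead of one opaque node
with torch.no_grad():
    print(torch_tree(torch.tensor(X[:5], dtype=torch.float64)))
```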
For more complex models like random forests, we can simply extract the individual trees / estimators, run them through sk2torch, and recreate the forest as a PyTorch module.
```python
trees = []
for tree in clr.estimators_:
    trees.append(sk2torch.wrap(tree))
print(trees)
```
```python
import torch.nn as nn

class RandomForest(nn.Module):
    def __init__(self, trees):
        super(RandomForest, self).__init__()
        self.trees = nn.ModuleList(trees)

    def forward(self, x):
        # average the outputs of the individual wrapped trees
        out = self.trees[0](x)
        for tree in self.trees[1:]:
            out += tree(x)
        return out / len(self.trees)
```
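From here, the recreated forest can be exported op-by-op to .onnx and fed to ezkl like any other model. A rough sketch of that last step (assuming the trees list above and that the wrapped estimators trace cleanly; file and variable names are illustrative):

```python
# Sketch: instantiate the recreated forest and export it op-by-op to ONNX.
import torch

forest = RandomForest(trees)
sample = torch.tensor(X[:1], dtype=torch.float64)  # X: the data the forest was fitted on

torch.onnx.export(
    forest,
    sample,
    "forest.onnx",
    input_names=["input"],
    output_names=["output"],
    dynamic_axes={"input": {0: "batch"}},
)
```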
For xgboost and lightgbm, we leveraged hummingbird, a Microsoft library for converting xgboost and similar models into torch / tensor graphs. A converted XGBoost classifier looks like this when exported to onnx:
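For reference, the conversion step itself looks roughly like this (a sketch assuming hummingbird's convert API; model, data, and file names are illustrative):

```python
# Sketch: convert a fitted xgboost classifier to a torch graph with hummingbird,
# then export that graph to ONNX for ezkl to consume.
import torch
import xgboost as xgb
from hummingbird.ml import convert

xgb_model = xgb.XGBClassifier(n_estimators=10, max_depth=3).fit(X, y)

hb_container = convert(xgb_model, "pytorch")  # wraps a torch.nn.Module
torch_model = hb_container.model
sample = torch.tensor(X[:1], dtype=torch.float32)
torch.onnx.export(torch_model, sample, "xgb.onnx")
```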
An observant reader will note that some operations, like ArgMax or Gather, don't have particularly obvious implementations in zero-knowledge circuits. This was the second leg of our sprint.
In Python, a simple and innocent indexing operation over a one-dimensional tensor \(x\), z = x[m], is trivial. But in ZK circuits, how do we enforce this sort of indexing, especially when indices like \(m\) might be private (advice, in plonk parlance) values?
The first argument we constructed allows us to implement a zk-circuit equivalent of the Gather operation, which essentially just indexes a tensor x at a given set of indices. To allow these indices to be advice values, we need to construct a new kind of argument for indexing over vectors / tensors in a zk-circuit.
We leverage the equals argument (see the appendix below for a full description of this argument) to generate the following constraint:
Note that we want \(b\) to be \(0\) at indices not equal to \(m\) and to be \(1\) at \(m\). This is a boolean operation, and should be distinguished from the typical zk-circuit operation of constraining two elements to be equal (i.e. arguments of the form \(x - y = 0\)).
Altogether, this set of arguments and constraints allows us to constrain the claimed output \(z\) to be the \(m^{th}\) element of \(x\).
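Spelled out, the argument amounts to constraints of roughly the following shape, for a claimed index \(m\), indicator vector \(b\), and output \(z\) (a sketch consistent with the description above; the exact gate layout in the circuit may differ):

\[
b_i = \text{equals}(i, m) \quad \text{for } i = 0, \dots, N-1, \qquad \sum_{i=0}^{N-1} b_i = 1, \qquad z = \sum_{i=0}^{N-1} b_i \cdot x_i .
\]

The first family of constraints forces \(b\) to be the one-hot indicator of \(m\), the second additionally forces \(m\) to lie within range, and the final dot product pins \(z\) to \(x[m]\).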
The construction of argmax and argmin is very similar to the private indexing argument (and in fact leverages it). We add one additional constraint, which is that, for a claimed \(m = \text{argmax}(x)\), we should have \(x[m] = \text{max}(x)\).
Say we want to calculate \(m = \text{argmax}(x)\), where \(x\) is of length \(N\).
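One way to write this down (a sketch along the lines of the constraints above, not necessarily the exact circuit layout) is to introduce a claimed value \(z\) and enforce

\[
z = x[m] \quad \text{(via the indexing argument)}, \qquad z = \max_{0 \le i < N} x_i \quad \text{(via the max argument)}.
\]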
For argmin, you only need to replace the above max operations with min :)
You can try out colab notebooks for the new tree-based models at:
All these models (when properly calibrated using ezkl) output predictions that are less than \(0.1\%\) away from the original sklearn, xgboost, and lightgbm predictions.
ReLU element-wise operation. Consider the following plonk columns:
| a0    | a1    | m         | s_sum |
|-------|-------|-----------|-------|
| a_i   | b_i   | m_{i-1}   | s_sum |
|       |       | m_i       |       |
The sum between vectors \(a\) and \(b\) is then enforced using the following constraint \(\forall i\): \(a_i + b_i + m_{i-1} = m_i\).
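Unrolling the accumulator (and assuming the initial cell \(m_{-1}\) is constrained to zero), the final cell holds the running total

\[
m_{N-1} = \sum_{i=0}^{N-1} (a_i + b_i),
\]

and fixing one of the inputs to the zero vector gives a plain sum over a single vector, which is the kind of check (e.g. \(\sum_i b_i = 1\) in the indexing argument above) this gate supports.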