---
tags: idea
title: Neural Tangent Kernel for Explaining NNs
---

# Background on NTK

The neural tangent kernel (NTK) was originally developed to model the gradient-based training dynamics of infinitely wide neural networks. However, the underlying concept is interesting in its own right. The NTK asks: how does the network's output at a point $x'$ change in response to an infinitesimally small gradient step on a training datapoint $(x, y)$?

Let's denote the network by $f$ and its parameters by $\theta$. Given a datapoint $(x, y)$, we incur a loss $\ell(f(x, \theta), y)$. We can reduce this loss by taking a gradient step with step size $\eta$:

$$
\theta'_\eta(x, y) = \theta - \eta \nabla_\theta \ell(f(x, \theta), y)
$$

Now we'll look at how this change in the parameters affects the network's output at $x'$, in the limit $\eta \rightarrow 0$. For a scalar-output network, the chain rule gives $\nabla_\theta \ell(f(x, \theta), y) = \nabla_\theta f(x, \theta) \, \partial \ell / \partial \hat{y}$, so

$$
g(x', x, y, \theta) = \lim_{\eta \rightarrow 0} \frac{f(x', \theta'_\eta(x, y)) - f(x', \theta)}{\eta} = -\nabla_\theta f(x', \theta)^T \nabla_\theta f(x, \theta) \left.\frac{\partial}{\partial \hat{y}}\ell(\hat{y}, y)\right|_{\hat{y}=f(x, \theta)}
$$

The quantity $g(x', x, y, \theta)$ expresses the sensitivity of the network's prediction at $x'$ to a particular observation $(x, y)$; the product $\nabla_\theta f(x', \theta)^T \nabla_\theta f(x, \theta)$ is the neural tangent kernel itself. Note that this sensitivity is parameter-dependent and will generally change over the course of training: early on, when $\theta$ is randomly drawn, we get different sensitivities than later in training, when $\theta$ is close to its final values.

NTK theory states that in the limit of infinite width, when certain assumptions hold about the random initialization of the weights and the network is trained using gradient flow (full-batch gradient descent with infinitesimally small learning rates), the kernel $\nabla_\theta f(x', \theta)^T \nabla_\theta f(x, \theta)$ stays constant during training, so $g$ depends on $\theta$ only through the loss derivative at the current prediction. However, this is unlikely to hold exactly for the finite-width networks used in practice.
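To make the derivation concrete, here is a minimal numpy sketch that evaluates $g$ for a hypothetical toy network (one hidden tanh layer, scalar output) with a squared-error loss; the model, the loss, and all function names are assumptions chosen for illustration. It computes $g$ both with the closed form and with the difference quotient from the definition at a small $\eta$, and the two should agree closely, including the minus sign.

```python
import numpy as np

# Hypothetical toy model: one hidden tanh layer with a scalar output,
# f(x, theta) = w2 . tanh(W1 x). theta is stored as one flat vector.
D_IN, HIDDEN = 3, 8


def unpack(theta):
    """Split the flat parameter vector into (W1, w2)."""
    W1 = theta[: HIDDEN * D_IN].reshape(HIDDEN, D_IN)
    w2 = theta[HIDDEN * D_IN:]
    return W1, w2


def f(x, theta):
    W1, w2 = unpack(theta)
    return w2 @ np.tanh(W1 @ x)


def grad_f(x, theta):
    """Analytic gradient of f(x, theta) with respect to the flat theta."""
    W1, w2 = unpack(theta)
    h = np.tanh(W1 @ x)
    dW1 = np.outer(w2 * (1.0 - h**2), x)  # df/dW1, same layout as W1
    dw2 = h                               # df/dw2
    return np.concatenate([dW1.ravel(), dw2])


def dloss_dyhat(yhat, y):
    # Assumed loss: squared error l(yhat, y) = 0.5 * (yhat - y)**2.
    return yhat - y


def g(x_prime, x, y, theta):
    """Closed form: -grad f(x')^T grad f(x) * dl/dyhat evaluated at f(x)."""
    kernel = grad_f(x_prime, theta) @ grad_f(x, theta)  # empirical NTK entry
    return -kernel * dloss_dyhat(f(x, theta), y)


def g_finite_difference(x_prime, x, y, theta, eta=1e-6):
    """Difference quotient from the definition of g, for a small step size eta."""
    grad_loss = grad_f(x, theta) * dloss_dyhat(f(x, theta), y)  # chain rule
    theta_eta = theta - eta * grad_loss                          # one gradient step
    return (f(x_prime, theta_eta) - f(x_prime, theta)) / eta


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    theta = rng.normal(size=HIDDEN * D_IN + HIDDEN) / np.sqrt(D_IN)
    x, x_prime, y = rng.normal(size=D_IN), rng.normal(size=D_IN), 1.0
    print(g(x_prime, x, y, theta), g_finite_difference(x_prime, x, y, theta))
```

Swapping in a real network only changes how `grad_f` is obtained (for example, from an autodiff library); the formula for $g$ stays the same.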
### Using this to explain NNs

Let's say we trained a deep neural network and it misclassifies a datapoint $x'$. Can we use the NTK framework to understand which aspects of the training data resulted in that misclassification?

I'll assume we have access to a series of parameter values $\theta_0, \theta_1, \ldots, \theta_K$, which are snapshots of the network's weights at various stages of training: for example, $\theta_0$ at initialization, $\theta_1$ after the first epoch, and so on. For each $k$ we can calculate $g(x', x, y, \theta_k)$ for every datapoint $(x, y)$ in the training set, taking the output $f$ to be, say, the log-probability of a class of interest. We can then rank the datapoints by whether they would have had a good or bad influence on how the network classifies $x'$ at that stage. If taking a gradient step to reduce the loss on $(x, y)$ increases the probability of the correct class, we say the datapoint had a good influence. If it decreases the probability of the correct class, or increases the probability of the specific class the network wrongly predicts, we say that $(x, y)$ had a bad influence.

We could then visualise the training images with the strongest good or bad influence on this datapoint's classification at each stage of training. This visualisation might help explain why the misclassification is made, for example by identifying similar images in the training dataset that were classified differently.
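The ranking step might look like the following minimal numpy sketch. To keep the gradients closed-form, it swaps in a linear softmax classifier trained with cross-entropy as a stand-in for the deep network (an assumption; a real network would need per-example gradients from an autodiff library). The score for a training point is the first-order change in $\log p(c \mid x')$ per unit step size after a gradient step on $(x, y)$, i.e. $g$ with $f$ set to that log-probability. The names `influence_scores` and `rank_influences` and the list of weight snapshots are hypothetical.

```python
import numpy as np


def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)  # subtract max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


def influence_scores(W, X_train, Y_train, x_prime, cls):
    """First-order change in log p(cls | x') per unit step on each (x, y).

    Stand-in model (assumption): logits = W @ x, trained with cross-entropy.
    Both gradients are outer products (a x'^T and b x^T), so their Frobenius
    inner product factorises as (a . b) * (x' . x).
    """
    n_classes = W.shape[0]
    a = np.eye(n_classes)[cls] - softmax(W @ x_prime)  # d log p_cls(x') / d logits
    P = softmax(X_train @ W.T)                         # (n, K) predicted probabilities
    B = P - np.eye(n_classes)[Y_train]                 # d loss(x, y) / d logits, per row
    # Minus sign: the step moves against the loss gradient.
    return -(B @ a) * (X_train @ x_prime)


def rank_influences(snapshots, X_train, Y_train, x_prime, true_cls, wrong_cls, top=3):
    """For each weight snapshot W_k, list the most helpful and most harmful points."""
    report = []
    for k, W in enumerate(snapshots):
        helps_true = influence_scores(W, X_train, Y_train, x_prime, true_cls)
        helps_wrong = influence_scores(W, X_train, Y_train, x_prime, wrong_cls)
        report.append({
            "snapshot": k,
            "good": np.argsort(-helps_true)[:top],  # most raise p(true class)
            "bad": np.argsort(-helps_wrong)[:top],  # most raise p(wrong class)
        })
    return report
```

Ranking by the wrong class implements the second notion of bad influence from the text; ranking by the most negative true-class scores (`np.argsort(helps_true)[:top]`) would implement the first. The returned indices can then be used to look up and display the corresponding training images for each snapshot.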