# Where do features come from? A story of sinusoids and inductive biases
By [Ben Edelman](https://www.benjaminedelman.com/), [Depen Morwani](https://depenm.github.io), Costin Oncescu, and [Rosie Zhao](https://rosieyzh.github.io/)
This blog post is based on [Feature emergence via margin maximization: case studies in algebraic tasks](linktoarxiv) by [Depen Morwani](https://depenm.github.io), [Ben Edelman](https://www.benjaminedelman.com/), Costin Oncescu, [Rosie Zhao](https://rosieyzh.github.io/), and [Sham Kakade](https://sham.seas.harvard.edu/).
## The Mystery
Suppose it is 9am. What will the time be 5 hours from now?
There are many valid ways to solve this problem. For instance:
1. **Counting up by one hour five times:** 10am, 11am, 12pm, 1pm, 2pm. (At the fourth step, you needed to use the memorized fact that 1, not 13, follows 12 on a clock.)
2. **Addition followed by subtraction:** $9 + 5 = 14$. $14 - 12 = 2$. So: 2pm.
3. **Memorization:** You conveniently remember off the top of your head that 5 hours after 9am is 2pm.
4. **Clock visualization:** You envision an analog clock set to 9 o'clock in your mind's eye, and mentally rotate the hour hand five hours forward. At its new angle, you observe the hand is pointing at 2 o'clock.
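Each of these strategies takes only a few lines. Below is a minimal Python sketch for the 12-hour clock, writing hours as 1 to 12 (so 12 plays the role of 0). The function names are illustrative, and the "memorized" table is simply precomputed with the second strategy.

```python
import math

def count_up(start: int, hours: int) -> int:
    """Strategy 1: step forward one hour at a time, remembering that 1 follows 12."""
    h = start
    for _ in range(hours):
        h = 1 if h == 12 else h + 1
    return h

def add_then_subtract(start: int, hours: int) -> int:
    """Strategy 2: add, then subtract 12 until the result is a valid clock hour."""
    h = start + hours
    while h > 12:
        h -= 12
    return h

# Strategy 3: a lookup table "memorized" ahead of time (filled in here with strategy 2).
MEMORIZED = {(s, n): add_then_subtract(s, n) for s in range(1, 13) for n in range(12)}

def memorized(start: int, hours: int) -> int:
    return MEMORIZED[(start, hours)]

def clock_hand(start: int, hours: int) -> int:
    """Strategy 4: turn the hour hand 30 degrees per hour, then read the hour off its angle."""
    angle = 2 * math.pi * (start + hours) / 12
    # atan2 recovers the hand's direction in (-pi, pi]; convert it back to an hour.
    h = round(math.atan2(math.sin(angle), math.cos(angle)) / (2 * math.pi) * 12) % 12
    return 12 if h == 0 else h

print(count_up(9, 5), add_then_subtract(9, 5), memorized(9, 5), clock_hand(9, 5))  # 2 2 2 2
```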
It would be straightforward to write code implementing any of the above algorithms for [modular addition](https://en.wikipedia.org/wiki/Modular_arithmetic) (in this case, adding two integers modulo 12). But this is going to be a story about deep learning, and the fundamental principle of deep learning is laziness: why intelligently design an algorithm and hand-code it yourself, when you could just feed a bunch of data into an off-the-shelf neural network, train it with an all-purpose optimizer, and watch it *learn* its *own* computational strategy?
Admittedly, the machinery of deep learning isn't especially useful in practice for clean, synthetic tasks like this; it's meant for messy tasks like predicting the next word in Internet text or classifying cats and dogs. But our goal will be *understanding* deep learning, and scientific understanding is sometimes easiest to arrive at by studying toy cases.
In any case, training a neural network to perform modular addition turns out to be an interesting exercise. In 2022, [Power et al.](https://arxiv.org/abs/2201.02177) trained transformers on modular arithmetic tasks and observed that, surprisingly, "long after severely overfitting, validation accuracy sometimes suddenly begins to increase from chance level toward perfect generalization." [Various](https://arxiv.org/abs/2205.10343) works [since then](https://arxiv.org/abs/2206.04817) have tried to [understand](https://arxiv.org/abs/2210.01117) this [befuddling](https://arxiv.org/abs/2310.06110) generalization behavior, dubbed "grokking." But instead of focusing on the question of why grokking occurs, we will focus on a different question: Which modular addition algorithm does the trained network implement, *and why*?
Earlier this year, [Nanda et al.](https://arxiv.org/abs/2301.05217) empirically investigated the first half of this question. They found that, remarkably, small transformers consistently learn to implement a version of the **clock visualization algorithm**---converting the inputs into cosines and sines of the corresponding angles, and then using trigonometry to add the angles!
{{Expanding box: Details of clock visualization algorithm}}
The training dataset consists of inputs of the form $(a,b)$, paired with the corresponding target outputs $c = a+b \bmod p$, where $a,b,c \in \mathbb{Z}_p$ with $\mathbb{Z}_p = \{ 0,1,...,p-1 \}$. The algorithm identified by [Nanda et al.](https://arxiv.org/abs/2301.05217) can be seen as a real-valued implementation of the following procedure:
1. Choose a fixed frequency $k \in \mathbb{Z}_p \setminus \{0\}$. Embed $a \mapsto e^{2\pi i k a/p}$, $b \mapsto e^{2 \pi i k b/p}$, representing rotations by the angles $2\pi ka/p$ and $2\pi kb/p$.
2. Multiply these (i.e., compose the rotations) to obtain $e^{2 \pi i k(a+b)/p}$.
3. Then, for each candidate output $c$, multiply by $e^{-2\pi i k c/p}$ and take the real part to obtain the logit for $c$.

Caption: Based on Figure 1 from Nanda et al.
The algorithm fundamentally relies on the following identity: for prime $p$, any $a, b \in \mathbb{Z}_p$, and any $k \in \mathbb{Z}_p \setminus \{0\}$,
$$(a+b) \textrm{ mod } p = \text{argmax}_{c\in \mathbb{Z}_p} \left\{\cos\left(\frac{2\pi k(a+b-c)}{p}\right)\right\}$$
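The three steps and the identity can be checked directly. Here is a minimal NumPy sketch, assuming $p$ is prime so that every nonzero $k$ works; the function name `clock_logits` is ours, not from Nanda et al.

```python
import numpy as np

def clock_logits(a: int, b: int, p: int, k: int) -> np.ndarray:
    """Logits over every c in Z_p from the single-frequency clock algorithm."""
    omega = 2 * np.pi * k / p
    # Step 1: embed the inputs as rotations by angles omega*a and omega*b.
    za, zb = np.exp(1j * omega * a), np.exp(1j * omega * b)
    # Step 2: compose the rotations by multiplying the unit complex numbers.
    zab = za * zb
    # Step 3: rotate back by each candidate c and keep the real part,
    # which equals cos(2*pi*k*(a + b - c)/p).
    c = np.arange(p)
    return np.real(zab * np.exp(-1j * omega * c))
```

For instance, `np.argmax(clock_logits(9, 5, 23, 3))` is `14`, i.e., $(9 + 5) \bmod 23$.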
Moreover, averaging the result over neurons with different frequencies $k$ results in destructive interference when $c \neq (a + b) \bmod p$, accentuating the correct answer.
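The interference can be made exact. For odd $p$, $\sum_{k=1}^{p-1} \cos(2\pi k d/p)$ equals $p-1$ when $d \equiv 0 \pmod p$ and $-1$ otherwise, and summing only over $k = 1, \dots, (p-1)/2$ halves both values. A minimal sketch with a hypothetical helper that sums the single-frequency logits over a chosen set of frequencies:

```python
import numpy as np

def multi_frequency_logits(a, b, p, freqs):
    """Sum of cos(2*pi*k*(a + b - c)/p) over the frequencies k in freqs, for every c in Z_p."""
    c = np.arange(p)
    k = np.asarray(freqs)[:, None]   # shape (num_freqs, 1), broadcast against c
    return np.cos(2 * np.pi * k * (a + b - c) / p).sum(axis=0)
```

With `freqs=range(1, p)`, the correct output gets logit $p-1$ while every wrong output gets $-1$, so the gap grows linearly with the number of frequencies used.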
<!--- The training dataset consists of inputs of the form $(a,b)$, paired with the corresponding target outputs $c = a+b \bmod p$. (Note that $a$, $b$, and $c$ are encoded as [one-hot](https://en.wikipedia.org/wiki/One-hot) vectors.) Nanda et al. discovered that the network computes the output in three steps, from bottom to top in the following diagram. For each neuron, the steps are roughly
1. Map one-hot inputs $a, b$ to $\cos(wa), \cos(wb), \sin(wa), \sin(wb)$ for some Fourier frequency $w = 2\pi k/p$.
2. Compute $\cos(w(a+b)) = \cos(wa)\cos(wb) - \sin(wa)\sin(wb)$ and $\sin(w(a+b)) = \cos(wa)\sin(wb) + \sin(wa)\cos(wb)$.
3. For each $c$ output logits proportional to $\cos(w(a+b-c)) = \cos(w(a+b))\cos(wc) - \sin(w(a+b))\sin(wc)$, which is maximized precisely when $c = a + b \mod p$.

Caption: from [Nanda et al.](https://arxiv.org/abs/2301.05217) --->
{{End expanding box}}
In a twist on the story, [Zhong et al.](https://arxiv.org/abs/2306.17844) found that not all neural network architectures use the same procedure---some modified networks learn to implement a related but distinct 'pizza algorithm'. But, notably, the pizza algorithm also starts out by calculating sinusoidal functions of the input, which we will refer to as *Fourier features*. This, then, is our primary mystery:
***Why do neural networks have a bias towards using Fourier features?***
To be more mathematical about it, modular addition is a finite group operation, with the group in question being the [cyclic](https://en.wikipedia.org/wiki/Cyclic_group) group. And Fourier analysis on the cyclic group is a special case of [representation theory](https://en.wikipedia.org/wiki/Representation_theory) for general groups. [Chughtai et al.](https://arxiv.org/abs/2302.03025) presented suggestive evidence that neural networks trained on another group, the [symmetric group](https://en.wikipedia.org/wiki/Symmetric_group), learn to convert the inputs into features corresponding to the [irreducible representations](https://en.wikipedia.org/wiki/Irreducible_representation) of the group, which are analogous to Fourier features!
{{Expanding box: Chughtai et al.'s construction}}
To be precise, [Chughtai et al.](https://arxiv.org/abs/2302.03025) show that one-layer ReLU MLPs and transformers learn the task by taking representation matrices $R(a)$, $R(b)$ of group elements $a, b$ and multiplying them by $R(c^{-1})$, so that the logit at output $c$ is proportional to the *character* $\chi_R(abc^{-1}) = \mathrm{tr}(R(a)R(b)R(c^{-1}))$, i.e., the trace of the resulting matrix product. Since a character takes its largest value, $\chi_R(e) = \dim R$, at the identity $e$, the logit is maximized when $c = ab$; if $R$ is faithful, this maximizer is unique.
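As a concrete instance, here is a NumPy sketch for the smallest non-abelian group, $S_3$, using its 2-dimensional standard irreducible representation, obtained by restricting $3 \times 3$ permutation matrices to the vectors whose entries sum to zero. This is our illustration of the formula above, not Chughtai et al.'s code, and the helper names are hypothetical. The standard representation of $S_3$ is faithful, so the logit is maximized only at $c = ab$.

```python
import itertools
import numpy as np

# The six elements of S_3, as tuples g with g[i] = image of i.
ELEMENTS = list(itertools.permutations(range(3)))

def compose(g, h):
    """Group product gh: apply h first, then g."""
    return tuple(g[h[i]] for i in range(3))

def perm_matrix(g):
    """Permutation matrix with P(g) e_i = e_{g(i)}, so that P(g) P(h) = P(gh)."""
    P = np.zeros((3, 3))
    for i in range(3):
        P[g[i], i] = 1.0
    return P

# Orthonormal basis (as columns) of the sum-zero subspace of R^3. Restricting the
# permutation matrices to this subspace gives the 2-dim standard irrep of S_3.
BASIS = np.array([[1, 1], [-1, 1], [0, -2]]) / np.array([np.sqrt(2), np.sqrt(6)])

def rep(g):
    """2x2 orthogonal matrix R(g) of the standard representation."""
    return BASIS.T @ perm_matrix(g) @ BASIS

def character_logits(a, b):
    """Logit for each c: tr(R(a) R(b) R(c^{-1})), using R(c^{-1}) = R(c)^T."""
    rab = rep(a) @ rep(b)
    return np.array([np.trace(rab @ rep(c).T) for c in ELEMENTS])
```

The logit equals $2 = \dim R$ at $c = ab$ and is $0$ or $-1$ at every other $c$.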
{{End expanding box}}
So the mystery has deepened.
*Why do neural networks have a bias towards solving finite group tasks using irreducible representations?*
Before we try to solve the mysteries, let's take a step back. What's the point?
Trained deep neural networks are famously black boxes. Or are they? Over the past several years, various researchers have peered into large, real-world AI models and tried---sometimes with a modicum of success---to explain *how* they compute the things they compute, at a comprehensible level of abstraction. What features and circuits do neural networks learn to employ when they are trained to solve a given task?
This pursuit has gone under various names---[BERTology](https://arxiv.org/abs/2002.12327) in the NLP community; [mechanistic interpretability](https://transformer-circuits.pub/2022/mech-interp-essay/index.html) in the machine learning community. Whenever such an investigation is successful, it raises a further question: *why*? What was it about the architecture and training process that biased the network towards a particular computational strategy?
If we can answer the "why" question, we can gain more leverage on various other questions: Can we predict which mechanisms a network will learn? Can we understand why [different](https://transformer-circuits.pub/2022/in-context-learning-and-induction-heads/index.html) mechanisms [are favored](https://arxiv.org/abs/2309.07311) at different stages of training? Can we intervene on the learning process to modify the mechanisms, to make them more robust, safe, fair, etc.?
The modular addition task has served as a relatively tractable case study for mechanistic interpretability. It is thus a natural choice of case study for that *why* question. So let's begin our investigation.
## Boiling the problem down to the essence
First, we will simplify the setting down to the essentials. Just an MLP, no biases---an embedding layer, activation function, and unembedding layer.

For simplicity, we train these networks using population gradient descent, i.e., full-batch gradient descent over all $p^2$ input pairs $(a, b)$. The phenomenon seen in previous works still happens!
Below, we visualize how the embedding weights and their Fourier power spectrum evolve throughout training on the mod-71 task (with $L_2$ regularization), when ReLU activations are used:
FIRST VIDEO GOES HERE
**Caption** - In the above video, the red dots in each row on the left depict the vector $\textbf{u}$ of weights feeding into a single hidden-layer neuron (for 10 arbitrary neurons in the network). The light blue interpolation is obtained by finding the function over the reals with the same [Fourier spectrum](https://en.wikipedia.org/wiki/Discrete_Fourier_transform) as the weight vector. On the right, each row shows the corresponding [Fourier power spectrum](https://en.wikipedia.org/wiki/Spectral_density): the strength of each frequency in the Fourier spectrum.
We can see that the embedding weights for each neuron become periodic, with almost all of the Fourier spectrum concentrated on a single frequency! But the vectors aren't quite pure sinusoids: they aren't smooth (because, we suspect, of the ReLU activations).
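Power spectra like the ones on the right can be computed with a discrete Fourier transform. Here is a minimal sketch using NumPy's real FFT; this is our assumption about how such plots are produced, and the helper names are hypothetical.

```python
import numpy as np

def power_spectrum(weights: np.ndarray) -> np.ndarray:
    """Fourier power |DFT|^2 of a length-p weight vector, for frequencies 0..floor(p/2)."""
    return np.abs(np.fft.rfft(weights)) ** 2

def dominant_frequency(weights: np.ndarray) -> int:
    """Nonzero frequency carrying the most power (the constant k = 0 term is skipped)."""
    return int(np.argmax(power_spectrum(weights)[1:]) + 1)
```

A pure sinusoid such as `np.cos(2 * np.pi * 5 * np.arange(71) / 71 + 0.7)` puts essentially all of its power at frequency 5. Rectifying it with `np.maximum(w, 0)` keeps the same dominant frequency but spreads some power into the constant term and the harmonics, which is consistent with periodic but non-smooth embeddings.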
Let's replace the ReLU activations with quadratic activations $x^2$. This makes the phenomenon much cleaner (and easier to analyze!)[^1]
[^1]: Note that the clock visualization algorithm can still be expressed using quadratic activations. One could further ask whether switching to this architecture makes studying the solutions uninteresting, in the sense that the clock algorithm is the *only* solution it can express. In the Appendix of our paper, we show that the network can still express a memorizing solution even with quadratic activations.
SECOND VIDEO GOES HERE
**Caption** - Evolution of trained embeddings and their Fourier power spectrum for a 1-hidden-layer quadratic network trained on a mod-71 addition dataset with $L_{2,3}$ regularization[^2]. Each row corresponds to an arbitrary neuron from the trained network. The red dots represent the actual value of the weights, while the light blue interpolation is obtained by finding the function over the reals with the same Fourier spectrum as the weight vector.
[^2]: $L_{2,3}$ norm: Consider a neural network of width $m$, and let the parameters associated with the $i^{th}$ neuron be represented by $u_i, v_i$ and $w_i$. Then the $L_{2,3}$ norm of the network is defined as $\sum_{i=1}^m (\|u_i\|^2 + \|v_i\|^2 + \|w_i\|^2)^{3/2}$.
In a technical sense, this norm is the "natural" norm for quadratic activations, in the same sense that $L_2$ norm is natural for ReLUs.
What seems to be happening is that as training progresses, the network approaches a limit, and in this limit each neuron's embedding vector is a pure sinusoid. This is also true for the unembedding vectors. In fact, for each neuron, the frequency of its embedding and unembedding vectors is the same.
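To reproduce the flavor of these experiments, here is a minimal NumPy sketch of the quadratic-activation network trained with population (full-batch) gradient descent on cross-entropy plus the $L_{2,3}$ penalty defined above. Our assumptions: the two one-hot inputs get separate embedding matrices `U` and `V` whose columns are summed before the activation, the unembedding `W` maps back to $p$ logits, gradients are derived by hand, and the hyperparameters (`lr`, `lam`, initialization scale, `steps`) are untuned placeholders rather than the paper's settings. Whether a given run converges to clean sinusoids depends on those choices.

```python
import numpy as np

def forward(U, V, W):
    """Pre-activations, activations, and logits f(a, b)[c] = sum_i W[i, c] (U[i, a] + V[i, b])^2."""
    pre = U[:, :, None] + V[:, None, :]            # (m, p, p): neuron i on input pair (a, b)
    h = pre ** 2                                   # quadratic activation
    return pre, h, np.einsum("iab,ic->abc", h, W)  # logits: (p, p, p)

def loss_and_grads(U, V, W, lam):
    """Mean cross-entropy over all (a, b) pairs plus lam * L_{2,3} penalty, with hand-derived gradients."""
    p = U.shape[1]
    pre, h, logits = forward(U, V, W)
    a, b = np.meshgrid(np.arange(p), np.arange(p), indexing="ij")
    target = (a + b) % p
    z = logits - logits.max(axis=-1, keepdims=True)        # numerically stable log-softmax
    logp = z - np.log(np.exp(z).sum(axis=-1, keepdims=True))
    ce = -logp[a, b, target].mean()
    # Gradient of the mean cross-entropy w.r.t. the logits: (softmax - one_hot) / p^2.
    G = np.exp(logp)
    G[a, b, target] -= 1.0
    G /= p * p
    dW = np.einsum("iab,abc->ic", h, G)
    dpre = 2.0 * pre * np.einsum("ic,abc->iab", W, G)      # chain rule through x -> x^2
    dU, dV = dpre.sum(axis=2), dpre.sum(axis=1)
    # L_{2,3} penalty sum_i s_i^{3/2}, where s_i = |u_i|^2 + |v_i|^2 + |w_i|^2.
    s = (U ** 2).sum(1) + (V ** 2).sum(1) + (W ** 2).sum(1)
    coef = 3.0 * np.sqrt(s)[:, None]                       # d(s^{3/2})/du_i = 3 s^{1/2} u_i
    grads = (dU + lam * coef * U, dV + lam * coef * V, dW + lam * coef * W)
    return ce + lam * (s ** 1.5).sum(), grads

def train(p=11, m=64, steps=1000, lr=0.1, lam=1e-4, seed=0):
    """Full-batch gradient descent; returns the weights and the loss before each step."""
    rng = np.random.default_rng(seed)
    U, V, W = (0.3 * rng.standard_normal((m, p)) for _ in range(3))
    history = []
    for _ in range(steps):
        loss, (dU, dV, dW) = loss_and_grads(U, V, W, lam)
        history.append(loss)
        U, V, W = U - lr * dU, V - lr * dV, W - lr * dW
    return U, V, W, history
```

After training, each row of `U` (or `V`, `W`) is one neuron's weight vector, the kind of quantity plotted in the videos.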
## Solving the mystery
Given that the phenomenon is exhibited in such a pure form in MLPs with quadratic activations, one may hope for an elegant mathematical explanation.
This is where we're going to bring in an important insight from deep learning theory: the inductive bias of neural networks toward _maximum margin_ solutions. A maximum margin solution is a setting of the network weights that minimizes the network's total weight norm, subject to classifying every data point correctly with a given confidence (or "margin").
{{Expanding box: Formal definition of maximum (normalized) margin}}
Consider a neural network $f(\theta; x)$, where $\theta$ and $x$ represent its parameters and input respectively. For a given norm $|| \cdot ||$, let $\Theta = \{ \theta: || \theta || \leq 1 \}$. The maximum normalized margin of the network with respect to the given norm, when trained on a multi-class classification task with dataset $D$ is defined as
$$\max_{\theta \in \Theta} \min_{(x,y) \sim D} f(\theta; x)[y] - \max_{y' \in \mathcal{Y}\backslash y} f(\theta; x)[y']$$
where $\mathcal{Y}$ represents the set of classes.
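In code, the quantity inside the maximum is a minimum over examples of the gap between the correct logit and the best competing logit. Below is a minimal sketch with hypothetical helper names. The normalization step assumes the quadratic network and the $L_{2,3}$ norm defined earlier, both of which scale as $t^3$ when every parameter is multiplied by $t$, so their ratio is the margin of the rescaled network with norm 1.

```python
import numpy as np

def margin(logits: np.ndarray, targets: np.ndarray) -> float:
    """min over examples of (correct-class logit) - (largest other logit).

    logits: (n, num_classes) array; targets: (n,) integer labels.
    """
    n = logits.shape[0]
    correct = logits[np.arange(n), targets]
    others = logits.copy()
    others[np.arange(n), targets] = -np.inf   # exclude the correct class
    return float(np.min(correct - others.max(axis=1)))

def l23_norm(U, V, W):
    """sum_i (|u_i|^2 + |v_i|^2 + |w_i|^2)^{3/2}, with one row per neuron."""
    s = (U ** 2).sum(1) + (V ** 2).sum(1) + (W ** 2).sum(1)
    return float((s ** 1.5).sum())

def normalized_margin(logits, targets, U, V, W):
    # Assumption: logits are 3-homogeneous in (U, V, W), like l23_norm,
    # so this ratio does not change when the parameters are rescaled.
    return margin(logits, targets) / l23_norm(U, V, W)
```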
{{End expanding box}}
In particular, [a result of Wei et al.](https://arxiv.org/abs/1810.05369) implies that standard training with sufficiently small regularization tends towards the maximum margin solution.[^3]
[^3]: Let $\gamma^*$ represent the maximum normalized margin of the network with respect to a norm $\| \cdot \|$. Under mild assumptions on $f$, the normalized margin of the global minimizer of the loss given by $\mathbb{E}_{(x,y) \sim D} \ell(f(\theta; x), y) + \lambda \|\theta\|^r$ approaches $\gamma^*$ as $\lambda \to 0$. Here $\ell$ represents the standard cross-entropy loss and $r > 0$.
In our paper, we present a suite of theoretical techniques for deriving the *precise value* of the maximum margin. In the plots below, we show that, empirically, the margin of the network indeed approaches the derived value over the course of training!

**Caption** - Evolution of the normalized margin of the quadratic network with training steps for the task of addition mod 23. It asymptotically reaches the theoretically predicted maximum margin.

**Caption** - Evolution of the normalized margin of the quadratic network with training steps for the task of addition mod 71. It asymptotically reaches the theoretically predicted maximum margin.
So we can predict the value of the margin... but does that actually imply anything about the learned circuit?
Yes. We are able to prove that for the task of addition mod $p$, if the network has width at least $4(p-1)$ and achieves the maximum margin, then all of the weight vectors are sinusoids, precisely of the following form:
$$u(a) = \lambda \cos(\theta_u^* + 2 \pi ka/p), \quad v(b) = \lambda \cos(\theta_v^* + 2 \pi kb/p), \quad w(c) = \lambda \cos(\theta_w^* + 2 \pi kc/p),$$
where $\lambda \in \mathbb{R}$ is some constant, $k \in \left\{1, \dots, \frac{p-1}{2}\right\}$ is the frequency of the neuron, and $\theta_u^*,\theta_v^*,\theta_w^*$ are phase offsets satisfying $\theta_u^* + \theta_v^* = \theta_w^*$.
Moreover, we prove that _every_ frequency is used by _some_ neuron.

**Caption** - Final distribution of the neurons corresponding to each frequency, in a quadratic-activation network trained on the task of addition mod 23.

**Caption** - Final distribution of the neurons corresponding to each frequency, in a quadratic-activation network trained on the task of addition mod 71.
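As a numerical check that neurons of this sinusoidal form really do compute modular addition, the sketch below builds a quadratic network whose neurons are exactly such sinusoids (with $\lambda = 1$), for an odd prime $p$, using every frequency $k = 1, \dots, (p-1)/2$. As a hypothetical choice of ours, the phases $\theta_u^*, \theta_v^*$ range over a $4 \times 4$ grid of multiples of $\pi/2$, with $\theta_w^* = \theta_u^* + \theta_v^*$, giving $8(p-1)$ neurons. This illustrates the structure of the solution; it is not the paper's construction, and we do not claim it attains the maximum margin.

```python
import numpy as np

def sinusoidal_network(p: int):
    """Quadratic network whose neurons all have the sinusoidal form above (lambda = 1).

    Hypothetical phase choice: for each frequency k, one neuron per pair
    (theta_u, theta_v) in a 4x4 grid of multiples of pi/2, with theta_w = theta_u + theta_v.
    """
    x = np.arange(p)
    grid = np.arange(4) * np.pi / 2
    U, V, W = [], [], []
    for k in range(1, (p - 1) // 2 + 1):
        omega = 2 * np.pi * k / p
        for tu in grid:
            for tv in grid:
                U.append(np.cos(tu + omega * x))        # u(a)
                V.append(np.cos(tv + omega * x))        # v(b)
                W.append(np.cos(tu + tv + omega * x))   # w(c), phase theta_u + theta_v
    return np.array(U), np.array(V), np.array(W)        # each of shape (8 * (p - 1), p)

def network_logits(U, V, W):
    """f(a, b)[c] = sum_i W[i, c] * (U[i, a] + V[i, b]) ** 2 for all a, b, c."""
    pre = U[:, :, None] + V[:, None, :]                 # (m, p, p)
    return np.einsum("iab,ic->abc", pre ** 2, W)        # (p, p, p)
```

Over this phase grid, every cross term cancels except one proportional to $\cos(2\pi k(a+b-c)/p)$, so each frequency contributes $8\cos(2\pi k(a+b-c)/p)$. Summing over frequencies gives a logit of $4(p-1)$ for the correct $c$ and $-4$ for every other $c$.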
{{Expanding box: Brief discussion of mathematical techniques}}
How did we calculate the max-margin value and characterize the maximum margin solutions? The central tool is the [max-min inequality](https://en.wikipedia.org/wiki/Max–min_inequality). Consider the definition of normalized maximum margin:
$$\max_{\theta \in \Theta} \min_{(x,y) \sim D} f(\theta; x)[y] - \max_{y' \in \mathcal{Y}\backslash y} f(\theta; x)[y']$$
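Letting $Q$ be the set of distributions over the data points $(x,y) \in D$, we can rewrite the definition above as follows (a minimum of expectations over distributions is always attained at a point mass, so it equals the minimum over individual data points):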
$$\max_{\theta \in \Theta} \min_{q \in Q} \mathbb{E}_{(x,y) \sim q} \left[f(\theta; x)[y] - \max_{y' \in \mathcal{Y}\backslash y} f(\theta; x)[y']\right]$$
Let $\gamma_\theta^q = \mathbb{E}_{(x,y) \sim q} \left[f(\theta; x)[y] - \max_{y' \in \mathcal{Y}\backslash y} f(\theta; x)[y']\right]$. Then, the max-min inequality implies that
$$ \max_{\theta \in \Theta} \min_{q \in Q} \gamma_\theta^q \leq \min_{q \in Q} \max_{\theta \in \Theta} \gamma_\theta^q.$$
Our technique aims at finding a certificate pair $(\theta^*, q^*)$ such that
$$ q^* \in \text{argmin}_{q \in Q} \gamma_{\theta^*}^q \text{ and } \theta^* \in \text{argmax}_{\theta \in \Theta} \gamma_\theta^{q^*}.$$
If such a pair exists, the "max-min property" holds: the above inequality becomes an equality, with the optimal value given by $\gamma_{\theta^*}^{q^*}$.
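Why does a certificate pair close the gap? Chaining its two defining properties with the max-min inequality gives a loop that starts and ends at the same value, so every inequality in it must hold with equality:

$$
\begin{aligned}
\gamma_{\theta^*}^{q^*} &= \min_{q \in Q} \gamma_{\theta^*}^{q} && (q^* \text{ minimizes } \gamma_{\theta^*}^{q})\\
&\leq \max_{\theta \in \Theta} \min_{q \in Q} \gamma_\theta^q && (\theta^* \in \Theta)\\
&\leq \min_{q \in Q} \max_{\theta \in \Theta} \gamma_\theta^q && (\text{max-min inequality})\\
&\leq \max_{\theta \in \Theta} \gamma_\theta^{q^*} && (q^* \in Q)\\
&= \gamma_{\theta^*}^{q^*} && (\theta^* \text{ maximizes } \gamma_\theta^{q^*})
\end{aligned}
$$

In particular, $\theta^*$ attains the maximum normalized margin.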
In order to find a certificate pair, we reduce the problem from an optimization of the *full network* to an optimization over a single neuron considered in isolation. For details, refer to Section 3 of our paper.
{{End expanding box}}
Thus, we have a resolution to the mystery in our stylized setting:
**Neural networks have a tendency to approach maximum margin solutions, and every maximum margin solution uses Fourier features.**
## Other algebraic tasks
Furthermore, we were able to extend our max-margin analysis from modular addition (the cyclic group) to general finite groups, explaining the empirical results of Chughtai et al.! What is the 'analogous' result here? Basically, instead of all _frequencies_ being used, all group _representations_ are used. Furthermore, each neuron uses only a _single_ representation. For more details, see Section 6 of our paper.
We also derived results of a similar flavour for the **sparse parity setting** studied in works such as [Daniely et al.](https://arxiv.org/abs/2002.07400), [Barak et al.](https://arxiv.org/abs/2207.08799), and [Edelman et al.](https://arxiv.org/abs/2309.03800).
## Beyond algebraic tasks?
We have shown that at least for simple algebraic tasks and simple neural networks, we can actually explain *where features come from* as a consequence of a known inductive bias of deep learning.
What are the prospects for understanding where features come from in general? *If* we can explain why neural networks prefer certain circuits over others, this can have significant implications:
- To what extent are learned circuits *universal*, and to what extent are they sensitive to the architecture and learning algorithm?
- Can we modify any aspects of the learning process to favor circuits that are more interpretable, robust, fair, or have other desired properties?
- In some cases, like training transformers on standard arithmetic, the "right" algorithm (i.e., one that generalizes off the training distribution) isn't learned by default. If we can explain success stories of algorithm learning, can we also explain the failure cases?
We are hopeful that better understanding the inductive biases of neural networks will lead to a better understanding of feature learning.