$$ \newcommand{\part}{\partial} \newcommand{\R}{\mathbb{R}} \newcommand{\C}{\mathbb{C}} \newcommand{\Z}{\mathbb{Z}} \newcommand{\Span}{\text{Span}} \newcommand{\Lin}{\text{Lin}} \newcommand{\Func}{\text{Func}} \newcommand{\dim}{\text{dim}} \newcommand{\HP}{\text{HP}} $$

# Masterplan: A Geometrical Theory of Learning

A document tracking the work on the generalization theory of neural networks.

* Part 1 is a high-level description of the problem of generalization in ML. It contains an intuitive explanation of why symmetry might help us solve the problem, and should be readable by people who don't have a background in deep learning or mathematics.
* Part 2 describes the problem of generalization in the simpler setting where the function space is a vector space. I'll refer to this as linear learning. Connections with the representation theory of Lie groups and Lie algebras will be drawn and used to explain generalization. We'll apply these techniques to study learning with polynomials of 2 variables.
* Part 3 shows how some of the ideas in Part 2 can be generalized to the nonlinear setting of deep neural networks. This is a much harder task and the current state is still very immature.

## 1.1 Learning Functions

**Carles Gelada, Jacob Buckman**

At its core, machine learning is about fitting functions to data. There is a function $f: X\to Y$ between two spaces that we'd like to learn. The input space $X$ might be text sequences, images encoded as arrays of pixels, sound waves, etc. Similarly, the output space $Y$ could be numbers, text, images, etc. For example, one might want to learn a function which takes as input an image of a person and outputs an estimate of that person's height; for this application, the input space would be the space of images, and the output space would be the space of positive real numbers. In image reconstruction, the input is a corrupted/low-quality image and the output is the original, high-quality version, so image reconstruction is a mapping from the space of images to itself. In audio transcription, the input is a person speaking and the output is the text of what was said, so audio transcription maps from the space of sound waves to the space of text strings. If we are instead interested in audio synthesis, we can simply reverse this, and learn a mapping from the space of text strings to the space of sound waves. You get the idea. This is a very flexible approach, and it can model all sorts of tasks.

The challenge of machine learning is to discover $f$ when we are only given access to a dataset $D$, a collection of inputs $D = \{x_i\}_{i=1}^n\subset X$ for which we know the value $f(x_i)$.

_Some might raise concerns that this definition of machine learning is too narrow. For example, it doesn't account for cases where $f$ outputs a distribution, so that the dataset is composed of examples $(x_i, y_i)$ where $y_i$ is merely sampled from $f(x_i)$. Though that setting is important, we choose to ignore it for now; the deterministic case is simpler, and the issues we study are deeper than the superficial difference between trying to recover a function or a distribution. Understanding one will be the key to understanding the other. Similar tradeoffs are made throughout the document, prioritizing simplicity over generality as long as the simplification does not impact the central issues._

The basic approach of machine learning is to construct a large space of functions under consideration, and then find a function in this space that agrees with $f$ on $D$.
For this to work well, we must be confident that the space of functions is large enough that it contains $f$ (or at least a close approximation). The most common way to construct such function spaces is via a parameterization $h: X \times \R^p \to Y$, where $\R^p$ is the space of parameters. By picking different values $w\in \R^p$ we can control the behaviour of the function $h_w: X \to Y$, defined as $h_w(x) = h(x, w)$. Two examples of such function spaces are:

* **Polynomials of degree $d$.** We define $h: \R \times \R^{d+1} \to \R$ as $h(x, w) = w_0 + w_1 x + \cdots + w_d x^d$.
* **Neural networks.** We define $h(x, w) = w_l \sigma( w_{l-1} \sigma(\cdots \sigma(w_1 x)))$ where $l$ is the number of layers, $w_i \in \R^{s_{i+1}\times s_i}$ is a linear map from $\R^{s_{i}}$ to $\R^{s_{i+1}}$, and $\sigma$ is any standard nonlinearity like tanh or ReLU.

Once we have been given our dataset $D$ and chosen our parameterized function space $h$, the job of our learning algorithm is to find a value of the parameters $w\in \R^p$ that fits the data. With that aim in mind, we pick a loss function $L(D, w)$ which measures how successfully the function $h_w$ fits the data in $D$. For example, if $Y$ is a scalar, a common choice is the mean squared error $L(D, w) = \frac{1}{|D|} \sum_{x_i \in D} \left(h_w(x_i) - f(x_i)\right)^2$. A loss function must take its minimum possible value when $h_w(x_i) = f(x_i), \; \forall x_i \in D$, i.e., when $h_w$ fits the dataset perfectly. Learning is often described as the problem of finding the optimal parameters, defined as

$$w_{\textrm{opt}} = \underset{{w\in \R^p}}{\text{argmin}} \: L(D, w)$$

But as we are going to see, this is a fundamentally flawed definition in the context of modern machine learning. _This is because for the function spaces $h$ that we use in practice, there are vast subspaces of the parameter space $\R^p$ that all equivalently minimize the loss, but which behave completely differently outside of $D$._ The implications of this will be discussed in greater depth in Section 1.3; for now, all we need to understand is that it's not enough to talk about the argmin $w$: we need to think about *which* argmin $w$ we pick.

This means that to understand our machine learning algorithms, it is important to look carefully at the specific method used to choose a $w$ that fits the data. In any modern deep learning scenario, that method is **gradient descent**. It works by changing the weights through time according to the differential equation $\frac{\part w}{\part t} = -\nabla_w L(D, w)$, where $\nabla_w L(D, w)$ is the gradient of $L$ w.r.t. the weights $w$. _(Again, for simplicity we discuss only vanilla gradient descent, though in practice it is more common to use variants such as SGD, momentum, or Adam.)_ The gradient can be thought of as the best linear approximation to the loss function at the point $w$; it is a vector pointing in the direction in which the loss function increases the fastest, with a magnitude proportional to how fast it is increasing. Since the goal of our algorithm is to decrease the loss, we update the parameters in the opposite direction.

A function space $h: X \times \R^p \to Y$, together with a loss function $L(D, w)$, and gradient descent on $w$. That is all we are going to be studying here. Because somehow, this extremely simple setup is the core of the recent revolution in artificial intelligence.
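To make this setup concrete, here is a minimal sketch of the whole pipeline for the polynomial family. Everything in it (the target, the dataset, the step size, the number of steps) is an illustrative assumption rather than something specified above; with enough steps the training error should approach zero.

```python
import numpy as np

# A minimal sketch of the setup above: the degree-d polynomial family
# h_w(x) = w_0 + w_1 x + ... + w_d x^d, the mean-squared-error loss L(D, w),
# and an Euler discretization of the gradient flow dw/dt = -grad_w L(D, w).

d = 5                                          # polynomial degree, so p = d + 1 parameters
f = lambda x: np.sin(3 * x)                    # a stand-in target f: R -> R
xs = np.array([-0.9, -0.3, 0.4, 0.8])          # the dataset D (here |D| = 4 < p)
ys = f(xs)                                     # the observed values f(x_i)
phi = np.vander(xs, d + 1, increasing=True)    # phi[i, k] = x_i^k

def h(x, w):
    """h(x, w) = sum_k w_k x^k."""
    return np.polyval(w[::-1], x)              # np.polyval wants the highest degree first

def grad_loss(w):
    """Gradient of L(D, w) = (1/|D|) sum_i (h_w(x_i) - f(x_i))^2 with respect to w."""
    residuals = phi @ w - ys                   # h_w(x_i) - f(x_i), shape (|D|,)
    return (2.0 / len(xs)) * phi.T @ residuals # shape (p,)

w = np.zeros(d + 1)                            # start at w = 0
step = 0.1
for _ in range(100_000):                       # discretized gradient descent
    w = w - step * grad_loss(w)

print("max error on D:", np.max(np.abs(h(xs, w) - ys)))  # close to 0: h_w fits D
print("error off D:   ", abs(h(0.6, w) - f(0.6)))        # no guarantee here
```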
## 1.2 The Magic of Generalization

The point of learning a function $h_w$ is that we'd like to know $f(x)$ for $x\not\in D$, i.e., points outside of the dataset. (After all, if we wanted to know $f(x)$ for $x\in D$, we could just look it up in the dataset!) So, we find $w$ for which $h(x_i, w) = f(x_i), \;\forall x_i \in D$ by minimizing $L(D,w)$ with gradient descent, and we hope that this results in an $h_w$ which is similar to $f$ elsewhere. If it is, we say the model *generalizes*.

A way to measure how well a model $h_w$ generalizes is by checking whether it fits the data on a test set $D'\subset X$, a collection of points $x'_i \in D'$ for which we also know the values $f(x'_i)$, but which we never used during the training of the parameters $w$. After we've found some parameters $w$ that fit $D$, we check whether $h_w(x'_i) = f(x'_i)$ for $x'_i \in D'$; the more points it gets correct, the better it generalizes. _(For simplicity, in this discussion we consider fitting to be a binary yes/no; either a prediction is correct, or it is not. This ignores some important generalizations. For example, when the target lives in $\mathbb{R}$, a more relaxed notion of approximate fitting might be more desirable, where we consider a point to be correctly fit whenever the error is below some threshold.)_

A useful tool for thinking about generalization is to consider a partition of $X$ into its *correct* and *error* subsets, denoted $X_c$ and $X_e$ respectively, where $X_c = \{x\in X \: |\ h_w(x) = f(x) \}$ is the subset where $h_w$ correctly fits $f$ and $X_e = X-X_c$ is the region where it makes errors. Of course, these are purely theoretical objects, since they depend on the unknown function $f$. Yet they give us a natural way to think about generalization: a model is generalizing well when $X_c$ is much larger than $D$. They also make it evident that different weights or different function spaces might get correct predictions on equally-large but *completely different* regions of $X$, and so generalization is not naturally a scalar quantity. To put it in the language of test sets: how well a model seems to generalize depends on the test set $D'\subset X$ on which you measure it.

Ultimately, generalization is about giving something to the learning process and receiving more in return. We only know $f|_D$ (the restriction of $f$ to $D\subset X$) and in return we receive $f|_{X_c}$ (the target function on a larger space). The fact that $X_c$ is larger than $D$ might not seem impressive at first glance; after all, even if we pick the parameters $w$ at random, there might still be some points $x \in X$ where the predictions of $h_w$ are correct by pure coincidence. But a learning algorithm that produces correct regions $X_c$ which are consistently much larger than $D$ is something powerful indeed.

However, there is something very important to note about generalization. Since our target function $f$ is in general arbitrary, when we observe the value of $f$ only on the points of $D\subset X$, the only points $x$ for which we can be *absolutely sure* of the value $f(x)$ are the points in $D$ themselves. For $x\not\in D$, $f(x)$ *could* take any value in $Y$, and thus no matter how we construct our function space or what learning algorithm we use, there will always be functions $f$ for which $X_c = D$, i.e., the model gets all of its predictions wrong outside of the dataset.
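Here is a tiny illustration of that last point (the target and the dataset are made up): two targets that agree on every point of $D$ but disagree everywhere else. A learner that only sees $f|_D$ cannot tell them apart, so whatever it outputs must be wrong outside of $D$ for at least one of the two.

```python
import numpy as np

# Two candidate targets that agree on every point of D but disagree everywhere
# else (both functions and the dataset are made up for illustration).

D = np.array([-0.5, 0.1, 0.7])                  # a toy dataset of inputs
f = lambda x: np.sin(3 * x)                     # one possible target
f_alt = lambda x: np.where(np.isin(x, D), f(x), f(x) + 1.0)  # another possible target

print(np.allclose(f(D), f_alt(D)))              # True: identical on all of D
print(f(0.3), f_alt(0.3))                       # but different off the dataset
```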
**In general, generalization is impossible.** But the reason the field of ML has exploded is that we've found an approach which performs far better than anyone could have expected. To understand the challenge that lies ahead, it's worth emphasizing the astronomical generalization abilities that neural networks have. Consider a task like playing chess, where the input is a board position and the output is a move. The number of positions one might encounter in a game vastly exceeds the number of atoms in the universe. Yet a neural network trained to imitate the moves selected by human players on just a few billion board positions will predict good moves in almost any situation it encounters when playing. The vast majority of these will be brand-new, never-before-seen positions. Somehow, the neural network has taken a few examples of board positions and moves, and it has used them to choose good moves on a space larger than the number of atoms in the universe.

Another way to appreciate the reach of the generalization abilities of neural networks is by exploring the behavior of text-to-image models like DALLE-2.

![](https://i.imgur.com/8foVvCY.jpg) | ![](https://i.imgur.com/WHkWsKM.jpg) | ![](https://i.imgur.com/hJLfOrw.jpg)
:-------------------------:|:-------------------------:|:-------------------------:
"Photo of hip hop cow in a denim jacket recording a hit single in the studio" | "An italian town made of pasta, tomatoes, basil and parmesan" | “Spider-Man from Ancient Rome”

While the datasets used in training are as large as technologically viable, the few billion text-image pairs they contain are negligible compared to the sheer combinatorial complexity of the space of text inputs, which is of course much larger than even the number of positions in the game of chess. How can neural networks trained with gradient descent possibly return so much more than they are given? That is the problem of generalization. And it is the fundamental question we are trying to answer.

## 1.3 The Overparametrized Regime

Before trying to understand the ability of neural networks to generalize well outside of the dataset $D$, it is worthwhile to first consider why gradient descent is even able to find parameters which correctly fit $f$ on the training dataset $D$. Asking this question will make the mystery of generalization seem even more daunting.

An important fact to keep in mind is that modern ML tends to use parameter spaces $\R^p$ of much higher dimensionality than $|D|$, the number of datapoints in our dataset. For polynomial function spaces, we can (and will) prove that for any function $f$, if $|D| < p$, there always exist parameters $w\in \R^p$ with $h_w(x) = f(x)$ for all $x\in D$. Moreover, gradient descent (with a reasonably defined loss function) is guaranteed to find one of these solutions. Getting similar results for general neural networks is of course more challenging, but some do exist. For networks with one hidden layer, [cite] shows that there exist parameters that fit the data if $p$ is of the order $O(|D|)$. Showing that gradient descent will find these solutions is a different matter and, as far as we know, no such results exist. (Leaving aside limiting cases like the NTK [].)
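Here is a small numerical illustration of the polynomial claim above (not a proof; the degree, points, and targets are arbitrary choices). With $p = d + 1 > |D|$, the constraints $h_w(x_i) = f(x_i)$ form an underdetermined linear system in $w$, so exact fits always exist. Whether gradient descent finds one of them is the separate claim mentioned above and is not demonstrated here.

```python
import numpy as np

# With p = d + 1 parameters and |D| < p datapoints, the interpolation
# constraints phi @ w = ys are underdetermined, so exact solutions exist.

rng = np.random.default_rng(1)
d = 9                                         # p = 10 parameters
xs = rng.uniform(-1, 1, size=6)               # |D| = 6 < p, distinct inputs
ys = rng.normal(size=6)                       # completely arbitrary targets f(x_i)

phi = np.vander(xs, d + 1, increasing=True)   # phi[i, k] = x_i^k, shape (6, 10)
w, *_ = np.linalg.lstsq(phi, ys, rcond=None)  # minimum-norm solution of phi @ w = ys

print("max fit error on D:", np.max(np.abs(phi @ w - ys)))   # numerically ~0
```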
Given our lack of theoretical understanding, it is fortunate that this is a question that can be answered empirically. In perhaps one of the most important and insightful experiments in deep learning, [cite] took a variety of neural network architectures used in computer vision and trained them to fit a few variations of ImageNet, a popular image classification dataset:

* __Standard ImageNet.__ The inputs were real-world images -- pictures of cats, dogs, etc. The outputs were the correct labels given by ImageNet.
* __ImageNet with random labels.__ The inputs were once again images from the ImageNet dataset. However, the output labels were sampled at random from the set of all possible labels. (Sampling happened just once, at the start of training.)
* __Random inputs, random outputs.__ The input images were sampled according to a uniform distribution, essentially generating white noise. The outputs were also selected at random. (Once again, sampling happened just once at the start of training, after which the dataset was held fixed.)

The result of the experiment was conclusive: the neural networks were reliably capable of fitting all three datasets. This indicates that in practice, where $p$ is much larger than $|D|$, neural networks trained with gradient descent are flexible enough to fit essentially anything.

To recap: for simple cases like polynomial learning, it's easy to show that if $p > |D|$, gradient descent is guaranteed to find a solution that perfectly fits the dataset $D$ of any function $f: X\to Y$. For neural networks, mathematical results are lacking, but the same statement has been explored empirically and systematically found to be true. Accepting this premise leads us to an inescapable realization.

__In the overparametrized regime, low loss does not imply generalization.__ Imagine a fourth variant of the last set of experiments, in which we construct a dataset $D$ which contains both (1) every image-label pair from ImageNet, and (2) a second set of images, much larger than ImageNet, where every image is assigned a random label (since ImageNet has 1M images, this second set could have 100M). If we train a very large neural network on the expanded dataset $D$, gradient descent will find parameters that fit both ImageNet and all the images with random labels. But of course, this model will not generalize; after all, it was primarily trained with random targets. Thus, we've found parameters that perfectly fit ImageNet and yet don't generalize.

The crucial realization is that gradient descent does much more than find low-loss solutions. When we are learning on real-world tasks, there is a sea of parameters that fit $f$ on $D$ but which do completely different things elsewhere. Yet, gradient descent somehow finds the $h_w$ that actually generalizes! How can gradient descent possibly know which parameters will generalize and which don't?
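For concreteness, the fourth-variant dataset above could be assembled along the following lines; the array shapes and set sizes are toy stand-ins, not the real ImageNet dimensions.

```python
import numpy as np

# Toy construction of the "fourth variant": real image-label pairs mixed with a
# much larger set of images carrying labels sampled at random (once, up front).

rng = np.random.default_rng(2)
num_classes, shape = 10, (4, 4, 3)                          # toy stand-in shapes

real_images = rng.random(size=(1_000, *shape))              # stand-in for ImageNet images
real_labels = rng.integers(0, num_classes, size=1_000)      # stand-in for the true labels

extra_images = rng.random(size=(100_000, *shape))           # the much larger second set
random_labels = rng.integers(0, num_classes, size=100_000)  # labels assigned at random

# An overparametrized network can drive the loss to ~0 on all of this combined
# dataset D, yet the resulting h_w has no reason to generalize.
train_images = np.concatenate([real_images, extra_images])
train_labels = np.concatenate([real_labels, random_labels])
print(train_images.shape, train_labels.shape)               # (101000, 4, 4, 3) (101000,)
```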
_(To be clear, it wouldn't be true to say that real-world problems are always in the overparametrized regime, although it is certainly the most common case; it's perfectly possible to train an image model with a few hundred million parameters on a dataset with a billion images. But the fact that overparametrized neural networks are capable both of generalizing and of fitting any dataset hints at something very deep about the role that gradient descent plays in the ability of deep learning to generalize. It also turns out to be a very mathematically elegant case to study. That is why for the rest of the document we'll assume we are in the overparametrized regime.)_

## 1.4 Symmetry and Generalization

The complete mathematical definition of a group will be given later, but at a high level, a group $G$ is a collection of objects $g\in G$ that transform a space $X$, sending every $x\in X$ to another point $g\rhd x\in X$. One of the central uses of groups is studying functions that have symmetry, meaning that $f(x) = f(g\rhd x)$. Groups help us abstract away the redundant information in these functions. We can gain an intuition by looking at images with symmetry. (Note that an image can be thought of as a function from pixel coordinates to RGB colors, a perspective we shall return to in a moment to illustrate the close connection between these ideas and machine learning.)

| ![](https://i.imgur.com/1SAOe4q.png) | ![](https://i.imgur.com/XdzNBd0.png) |
| -------- | -------- |
| Image 1 | Image 2 |

The reflection transformation $r$ sends the pixel coordinates $(x_1,x_2)$ to $r\rhd (x_1, x_2) = (-x_1, x_2)$. Since an image of a butterfly has this symmetry, we only need the data on the left half (Image 1) to reconstruct the full image (Image 2). The RGB value at any pixel coordinates $(x_1, x_2)$ on the right half will be the same as the RGB value at $r\rhd (x_1, x_2)$, which we can determine by looking at Image 1. Thus, using this symmetry, we can reconstruct the entirety of Image 2.

For another example, take a look at this tiling pattern from Seville. The tile seen in Image 1 can be used to construct the entire pattern in Image 2 by placing shifted copies alongside one another. (After all, that is literally how Image 2 was constructed.) These shifts correspond to the group $G= \Z^2$, whose elements $(g_1, g_2) \in G$ transform points $(x_1,x_2)$ to $g\rhd (x_1, x_2) = (x_1 + g_1, x_2+ g_2)$.

| ![](https://i.imgur.com/7zUttn5.png) | ![](https://i.imgur.com/ouJwE64.jpg) |
| -------- | -------- |
| Image 1 | Image 2 |

Furthermore, note how even the tile itself (Image 1) is full of symmetry (in this case, the symmetry group is composed of rotations and reflections). We again only really need a small part of it, together with the group that governs how its copies should be arranged, to reconstruct the whole tile.

The analogy between these images and learning should be clear. In fact, we can formalize it precisely as a learning problem. The full image is a function $f: \R^2 \to \R^3$ from two coordinates $(x_1,x_2)\in \R^2$ to an RGB value $f(x_1, x_2) \in \R^3$. We are only given a patch of the image, so $D$ consists of all the pixel coordinates within this patch. To fill in the entire pattern in the image, the learning process would need to know the underlying symmetry group $G$. After all, there are many ways we can put a building block together to construct different patterns.
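As a toy numpy version of these two examples (random arrays stand in for the actual photos), here is how knowing the symmetry group lets us rebuild the full image from a piece of it.

```python
import numpy as np

# An image as a function from pixel coordinates to RGB values, stored as an array.
rng = np.random.default_rng(3)

# Reflection symmetry: r ▷ (x1, x2) = (-x1, x2). Knowing f on the left half
# (Image 1) determines the whole picture (Image 2): the value at a right-half
# pixel equals the value at its mirrored left-half pixel.
left_half = rng.random(size=(64, 32, 3))                  # the data we are given
butterfly = np.concatenate([left_half, left_half[:, ::-1]], axis=1)
assert np.allclose(butterfly, butterfly[:, ::-1])         # the reflection symmetry holds

# Shift symmetry: g = (g1, g2) in Z^2 acts by translation. Tiling shifted
# copies of one tile (Image 1) reconstructs the full pattern (Image 2).
tile = rng.random(size=(16, 16, 3))
pattern = np.tile(tile, (4, 4, 1))                        # a 4x4 grid of shifted copies
print(butterfly.shape, pattern.shape)                     # (64, 64, 3) (64, 64, 3)
```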
One naive hypothesis might be that function spaces like neural networks only contain symmetric functions; that there is a group $G$ that they use to generalize from the training datapoints to the rest of the inputs. But from the ability of a neural network to fit any dataset we can deduce that this isn't the case. For any symmetry $G$, we can find parameters for which the neural network does not have that symmetry. All we need to do is find two points $x,x'\in X$ with $x' = g\rhd x$ and then construct a dataset where the targets on $x$ and $x'$ differ. The neural network will find parameters $w$ that fit the data, but then $h(x, w) \not= h(g\rhd x, w)$, so the learned function does not have the symmetry $G$. Neural networks are capable of learning all sorts of functions with all sorts of symmetries. But how can they possibly know which symmetries the function $f$ has?

I believe the answer is that the dataset $D$ can contain information not only about the value of $f$ on the points in $D$, but also about the symmetries of $f$. Understanding this idea requires delving deep into the math of learning, which we will begin to do in Section 2. But we can get a bit of an intuition for this idea visually from some of the examples we will study later, where we learn with the function space of homogeneous polynomials of degree 7 in 2 input variables. The dimensionality of this function space is 8, so in all the visualizations below the datasets have fewer than $8$ points, keeping us in the overparametrized regime.

The first column of the following plots shows the target function $f: \R^2 \to \R$; the color of each pixel indicates the value of $f$ at those coordinates. The second column represents everything gradient descent "sees": every point drawn is a member of $D$, and its color corresponds to the value of $f$ at that point. The third column visualizes the gradient. It indicates the infinitesimal change that gradient descent (using the corresponding dataset in column 2) applies to each pixel. In other words, it shows the derivative with respect to time, at $t=0$, of the value of our function $h(x, w - t\nabla_w L(D, w))$. (The visualization is drawn for $w=0$.)

![](https://i.imgur.com/e8mWV7K.png)

As we can see, when the points of the dataset are distributed along some pattern that interacts nicely with the symmetry of the target function $f$, the gradients seem to pick up on this symmetry. We can change the target function and redraw the exact same plot; again, we observe the same phenomenon.

![](https://i.imgur.com/nJXo6J1.png)

But when we visualize the gradients for datasets that don't have any nice structure, neither do the gradients. As you can see, none of them captures the symmetric structure of $f$.

![](https://i.imgur.com/PICsUfV.png)

If we are in the overparametrized regime, it's not enough to measure the function at any set of random points. The structure of which points we measure, together with the values of $f$ observed at them, is crucial if we are to recover the symmetries of the function $f$.
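For reference, the quantity in the third column can be computed along the following lines, assuming the natural coefficient parameterization $h(x, w) = \sum_{k=0}^{7} w_k x_1^k x_2^{7-k}$ and the MSE loss from Section 1.1 (the dataset and target values below are made up; the actual figures use their own $f$ and $D$). Since $h$ is linear in $w$, the value plotted at a pixel $x$ is simply $-\nabla_w h(x, 0) \cdot \nabla_w L(D, 0)$.

```python
import numpy as np

def phi(x1, x2, degree=7):
    """Feature map: the 8 monomials x1^k * x2^(degree - k), k = 0..degree."""
    return np.stack([x1**k * x2**(degree - k) for k in range(degree + 1)], axis=-1)

rng = np.random.default_rng(4)
data_x = rng.uniform(-1, 1, size=(5, 2))     # a dataset D with |D| = 5 < 8
data_y = rng.normal(size=5)                  # stand-in values f(x_i)

# Gradient of L(D, w) = (1/|D|) sum_i (h(x_i, w) - f(x_i))^2 at w = 0. Since
# h(x, 0) = 0 and grad_w h(x, w) = phi(x), this is -(2/|D|) sum_i f(x_i) phi(x_i).
grad_L = -(2.0 / len(data_y)) * phi(data_x[:, 0], data_x[:, 1]).T @ data_y

# Value shown at each pixel x: d/dt h(x, w - t * grad_L) at t = 0 and w = 0,
# which equals -phi(x) . grad_L.
g1, g2 = np.meshgrid(np.linspace(-1, 1, 200), np.linspace(-1, 1, 200))
third_column = -phi(g1, g2) @ grad_L         # shape (200, 200), ready to plot
print(third_column.shape)
```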
Most machine learning applications don't deal with function spaces whose input is 2-dimensional. In computer vision, the input to our function space might be an array of 100x100 RGB pixels, and so the dimensionality of the input space is 30,000. It becomes much harder to develop an intuition for what a dataset with structure would look like in that case. The fact that such datasets contain millions or even billions of datapoints doesn't help us either. Yet, we'll try to say something about it. Imagine we have a dataset for object classification. The first thing to note is that the images in our dataset will be very special within the space of all images, the vast majority of which look something like this:

![](https://i.imgur.com/f7LvKq3.png)

But the images in our dataset are very different.

| ![](https://i.imgur.com/YCjzFNm.jpg) | ![](https://i.imgur.com/GwWYdXT.png) | ![](https://i.imgur.com/ASpbbkt.jpg) | ![](https://i.imgur.com/VfECaan.jpg) |
| -------- | -------- | -------- | -------- |

Under the category of "woman" we might find thousands of variations of the concept: women with hats, paintings of women, women against light or dark backgrounds, old and young women... All these variations hint at the underlying structure of the concept "woman". If you start with the picture of a woman and then change the lighting conditions, you still have the image of a woman. The same is true if you begin with the picture of a building. That is because changing the illumination is a symmetry of the function that maps images to classes. If you have the picture of a building in the desert and you move it to a meadow, it is still a building. Because again, changing the background is a symmetry of the function we are trying to learn. I could go on and on and fill endless pages listing all the ways one can vary images while preserving their category, and I wouldn't even make a dent in the whole thing. It's only through the constellation of millions of points in the space of images that these incredibly complicated symmetries start to appear and get picked up by our machine learning models. Well, at least that is the theory.