## A note on ghost grads

This is a brief note about [ghost grads](https://transformer-circuits.pub/2024/jan-update/index.html#dict-learning-resampling). I've seen multiple implementations that I think are confused, so I've decided to write a note clarifying things.

I use the [notation](https://transformer-circuits.pub/2023/monosemantic-features/index.html#appendix-autoencoder) from Anthropic's [Towards Monosemanticity](https://transformer-circuits.pub/2023/monosemantic-features/index.html) paper.

### SAEs

#### Architecture

Let $\mathbf{x} \in \mathbb{R}^n$ denote the activation to be reconstructed. Then we can write the forward pass of an SAE as follows:

$$\bar{\mathbf{x}} = \mathbf{x} - \mathbf{b}_{\text{dec}}$$

$$\mathbf{f} = \text{ReLU}(W_{\text{enc}} \bar{\mathbf{x}} + \mathbf{b}_{\text{enc}})$$

$$\hat{\mathbf{x}} = W_{\text{dec}}\mathbf{f} + \mathbf{b}_{\text{dec}}$$

#### Training

SAEs are trained to minimize a loss function that aims for low reconstruction error and high sparsity of the intermediate feature activations:

$$\mathcal{L} = \frac{1}{|X|} \sum_{\mathbf{x} \in X} \Big( \underbrace{\lvert \lvert \mathbf{x} - \hat{\mathbf{x}} \rvert \rvert_2^2}_{\text{reconstruction error}} + \underbrace{\lambda \lvert \lvert \mathbf{f} \rvert \rvert_1}_{\text{sparsity penalty}} \Big)$$

where $\lambda$ is a hyperparameter that controls the tradeoff between reconstruction error and sparsity of the intermediate feature activations.

### Dead features

During SAE training, a significant number of features may become "dead": their post-ReLU activations are 0 on all input activations. Mathematically, we say a feature $i$ is **dead** if, for all input activations $\mathbf{x}$, $W_{\text{enc}}[i, :] \cdot \bar{\mathbf{x}} + \mathbf{b}_{\text{enc}}[i] \le 0$, where $W_{\text{enc}}[i, :]$ denotes the $i^{\text{th}}$ row of $W_{\text{enc}}$ and $\mathbf{b}_{\text{enc}}[i]$ denotes the $i^{\text{th}}$ entry of $\mathbf{b}_{\text{enc}}$.

Dead features are driven during SAE training by the sparsity penalty term, which penalizes a high L1 norm of the intermediate feature vector $\mathbf{f}$.
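The forward pass, the loss, and the dead-feature condition above can be sketched in plain Python (lists instead of a tensor library, so the snippet has no dependencies). This is a minimal illustration under stated assumptions, not a training implementation: the function names are hypothetical, `W_enc` is stored as one row per feature, `W_dec` as one column per feature, and "dead" is checked over a finite batch `X` standing in for "all input activations".

```python
def matvec(W, v):
    """Multiply a matrix W (list of rows) by a vector v (list)."""
    return [sum(w * vj for w, vj in zip(row, v)) for row in W]


def sae_forward(x, W_enc, b_enc, W_dec, b_dec):
    """Return (f, x_hat) for a single input x, following the three equations above."""
    x_bar = [xi - bi for xi, bi in zip(x, b_dec)]               # x - b_dec
    pre = [z + b for z, b in zip(matvec(W_enc, x_bar), b_enc)]  # W_enc x_bar + b_enc
    f = [max(z, 0.0) for z in pre]                              # ReLU
    x_hat = [r + b for r, b in zip(matvec(W_dec, f), b_dec)]    # W_dec f + b_dec
    return f, x_hat


def sae_loss(X, W_enc, b_enc, W_dec, b_dec, lam):
    """Mean over X of squared reconstruction error plus lam * ||f||_1."""
    total = 0.0
    for x in X:
        f, x_hat = sae_forward(x, W_enc, b_enc, W_dec, b_dec)
        recon = sum((xi - xhi) ** 2 for xi, xhi in zip(x, x_hat))
        sparsity = lam * sum(abs(fi) for fi in f)
        total += recon + sparsity
    return total / len(X)


def dead_features(X, W_enc, b_enc, W_dec, b_dec):
    """Indices of features whose activation is 0 on every input in X."""
    fired = [False] * len(W_enc)
    for x in X:
        f, _ = sae_forward(x, W_enc, b_enc, W_dec, b_dec)
        fired = [was or fi > 0 for was, fi in zip(fired, f)]
    return [i for i, was in enumerate(fired) if not was]


# Toy example: n = 2 input dimensions, 3 features; feature 2 has a large negative bias.
W_enc = [[1.0, 0.0], [0.0, 1.0], [-1.0, -1.0]]  # one row per feature
b_enc = [0.0, 0.0, -5.0]
W_dec = [[1.0, 0.0, 0.5], [0.0, 1.0, 0.5]]      # one column per feature
b_dec = [0.0, 0.0]
X = [[1.0, 2.0], [3.0, 0.5]]

print(sae_loss(X, W_enc, b_enc, W_dec, b_dec, lam=0.5))  # 1.625
print(dead_features(X, W_enc, b_enc, W_dec, b_dec))       # [2]
```

In this toy setup, feature 2's pre-activation is $-x_1 - x_2 - 5$, which is negative for both inputs, so it never fires on `X` and is reported as dead.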
Once dead, features will generally not be resuscitated. In order to be resuscitated, a dead feature needs its corresponding encoder vector, i.e. $W_{\text{enc}}[i, :]$, to be altered. However, in the negative region of the ReLU, [gradients are blocked](https://datascience.stackexchange.com/a/5734), and so no gradient updates are propagated to the encoder vectors of dead features.

### Ghost grads

The main idea behind [ghost grads](https://transformer-circuits.pub/2024/jan-update/index.html#dict-learning-resampling) is to resuscitate dead features, and put them to good use in improving reconstruction error. Intuitively, there are two parts of this idea:

1. **Resuscitate dead neurons.** This will require altering the activation function of dead features (to something more sensitive than ReLU) so that gradients can flow through and update their corresponding encoder directions.
2. **Repurpose dead neurons to improve reconstruction error.** This can be achieved by pushing the dead neurons to help close the gap between the reconstructed activation $\hat{\mathbf{x}}$ and the original activation $\mathbf{x}$.

How do ghost grads achieve this? Let's rewrite the SAE forward pass equations, but partition the feature space into alive features and dead features:

$$\bar{\mathbf{x}} = \mathbf{x} - \mathbf{b}_{\text{dec}}$$

$$\mathbf{f}_{\text{alive}} = \text{ReLU}(W_{\text{enc}} \bar{\mathbf{x}} + \mathbf{b}_{\text{enc}}) * \text{mask}_{\text{alive}}$$

$$\mathbf{f}_{\text{dead}} = \text{ReLU}(W_{\text{enc}} \bar{\mathbf{x}} + \mathbf{b}_{\text{enc}}) * \text{mask}_{\text{dead}}$$

$$\hat{\mathbf{x}} = W_{\text{dec}}\mathbf{f}_{\text{alive}} + W_{\text{dec}}\mathbf{f}_{\text{dead}} + \mathbf{b}_{\text{dec}}$$

Note that this is a bit silly to write out: given the definition of a "dead feature", we have that $\mathbf{f}_{\text{alive}} = \mathbf{f}$, and $\mathbf{f}_{\text{dead}} = \mathbf{0}$. But I think it's helpful to partition things in this way.
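A small sketch can check both claims made so far: with masks, $\mathbf{f}_{\text{alive}} = \mathbf{f}$ and $\mathbf{f}_{\text{dead}} = \mathbf{0}$, and the ReLU derivative zeroes the gradient of the reconstruction error $\lvert \lvert \mathbf{x} - \hat{\mathbf{x}} \rvert \rvert_2^2$ with respect to a dead feature's encoder row. The gradient is derived by hand with the chain rule rather than by an autodiff library. All names are hypothetical, and the set of dead indices `dead` is assumed to be tracked elsewhere.

```python
def matvec(W, v):
    """Multiply a matrix W (list of rows) by a vector v (list)."""
    return [sum(w * vj for w, vj in zip(row, v)) for row in W]


def relu_prime(z):
    # Derivative of ReLU: 0 on the negative region (and, by convention here, at 0).
    return 1.0 if z > 0 else 0.0


def partitioned_forward(x, W_enc, b_enc, W_dec, b_dec, dead):
    """Forward pass with features split into alive and dead via complementary masks."""
    x_bar = [xi - bi for xi, bi in zip(x, b_dec)]
    pre = [z + b for z, b in zip(matvec(W_enc, x_bar), b_enc)]
    mask_dead = [1.0 if i in dead else 0.0 for i in range(len(pre))]
    mask_alive = [1.0 - m for m in mask_dead]
    f = [max(z, 0.0) for z in pre]
    f_alive = [fi * m for fi, m in zip(f, mask_alive)]
    f_dead = [fi * m for fi, m in zip(f, mask_dead)]
    x_hat = [a + d + b for a, d, b in
             zip(matvec(W_dec, f_alive), matvec(W_dec, f_dead), b_dec)]
    return x_bar, pre, f, f_alive, f_dead, x_hat


def encoder_row_grad(x, W_enc, b_enc, W_dec, b_dec, dead, i):
    """Gradient of ||x - x_hat||^2 with respect to W_enc[i, :].

    Because the masks are complementary, f_alive + f_dead = f, so the gradient
    can be taken through f directly.
    """
    x_bar, pre, _, _, _, x_hat = partitioned_forward(x, W_enc, b_enc, W_dec, b_dec, dead)
    # dL/df_i = -2 * sum_k (x_k - x_hat_k) * W_dec[k][i]
    dL_df_i = -2.0 * sum((xk - xhk) * W_dec[k][i]
                         for k, (xk, xhk) in enumerate(zip(x, x_hat)))
    # df_i/dW_enc[i, j] = ReLU'(pre_i) * x_bar_j
    return [dL_df_i * relu_prime(pre[i]) * xb for xb in x_bar]


W_enc = [[1.0, 0.0], [0.0, 1.0], [-1.0, -1.0]]
b_enc = [0.0, 0.0, -5.0]
W_dec = [[0.5, 0.0, 1.0], [0.0, 0.5, 1.0]]  # imperfect decoder, so the residual is nonzero
b_dec = [0.0, 0.0]
x = [1.0, 2.0]

_, _, f, f_alive, f_dead, x_hat = partitioned_forward(x, W_enc, b_enc, W_dec, b_dec, {2})
print(f_alive == f, f_dead)                                   # True [0.0, 0.0, 0.0]
print(encoder_row_grad(x, W_enc, b_enc, W_dec, b_dec, {2}, 0))  # [-0.5, -1.0]
grad_dead = encoder_row_grad(x, W_enc, b_enc, W_dec, b_dec, {2}, 2)
print(all(g == 0.0 for g in grad_dead))                       # True: blocked by ReLU
```

Feature 2 would reduce the reconstruction error if it could move (its upstream gradient $\partial \mathcal{L} / \partial f_2$ is $-3$), but the zero ReLU derivative at its negative pre-activation discards that signal entirely.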
We want our reconstruction to approximate the original activation $\mathbf{x}$:

$$W_{\text{dec}}\mathbf{f}_{\text{alive}} + W_{\text{dec}}\mathbf{f}_{\text{dead}} + \mathbf{b}_{\text{dec}} \underset{\text{want}}{\approx} \mathbf{x}$$

Since $\mathbf{f}_{\text{dead}} = \mathbf{0}$ in the ordinary forward pass, the current reconstruction is $\hat{\mathbf{x}} = W_{\text{dec}}\mathbf{f}_{\text{alive}} + \mathbf{b}_{\text{dec}}$, so we can rewrite our desire as

$$\hat{\mathbf{x}} + W_{\text{dec}}\mathbf{f}_{\text{dead}} \underset{\text{want}}{\approx} \mathbf{x}$$

or equivalently,

$$W_{\text{dec}}\mathbf{f}_{\text{dead}} \underset{\text{want}}{\approx} \mathbf{x} - \hat{\mathbf{x}}$$

The intuition behind ghost grads is to try to *alter the contribution of* $W_{\text{dec}} \mathbf{f}_{\text{dead}}$ *from* $\mathbf{0}$ *to* $\mathbf{x} - \hat{\mathbf{x}}$.

To do this, we first compute the activations of dead features using a more sensitive activation function, the $\exp$ function, whose derivative is positive everywhere, so that gradients can flow:

$$\mathbf{f}_{\text{dead}}^{\text{ghost}} = \exp(W_{\text{enc}} \bar{\mathbf{x}} + \mathbf{b}_{\text{enc}}) * \text{mask}_{\text{dead}}$$

We then construct a loss term that shapes the output contributions of these dead neurons towards $\mathbf{x} - \hat{\mathbf{x}}$ (i.e. the "residual"):

$$\mathcal{L}_{\text{ghost}} = \lvert \lvert \gamma W_{\text{dec}}\mathbf{f}_{\text{dead}}^{\text{ghost}} - (\mathbf{x} - \hat{\mathbf{x}}) \rvert \rvert_2^2$$

[See the [original post](https://transformer-circuits.pub/2024/jan-update/index.html#dict-learning-resampling) for details about scaling the outputs of the ghost features (denoted $\gamma$ in the equation above) and about how to scale the ghost grad loss term relative to the other loss terms.]

**Note:** it is $W_{\text{dec}}\mathbf{f}_{\text{dead}}^{\text{ghost}}$ that ought to approximate the residual, **not** $W_{\text{dec}}\mathbf{f}_{\text{dead}}^{\text{ghost}} + \mathbf{b}_{\text{dec}}$. The decoder bias $\mathbf{b}_{\text{dec}}$ is already included in $\hat{\mathbf{x}}$, so adding it to the ghost output as well would count it twice.
I have seen this subtle mistake in a couple of independent implementations so far, and I hope this note clears up the misconception.
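To make the point concrete, here is a minimal plain-Python sketch of the ghost grad loss for a single input, with a hypothetical flag `add_b_dec` that reproduces the mistake. Assumptions: `dead` is a set of dead-feature indices tracked elsewhere, `gamma` is passed in as a given constant (the original post explains how to choose it), and the rescaling of the ghost loss relative to the other loss terms is omitted. The function and argument names are illustrative only.

```python
import math


def matvec(W, v):
    """Multiply a matrix W (list of rows) by a vector v (list)."""
    return [sum(w * vj for w, vj in zip(row, v)) for row in W]


def ghost_grad_loss(x, W_enc, b_enc, W_dec, b_dec, dead, gamma, add_b_dec=False):
    """Ghost grad loss for one input x.

    add_b_dec=True reproduces the mistake of adding b_dec to the ghost output.
    """
    x_bar = [xi - bi for xi, bi in zip(x, b_dec)]
    pre = [z + b for z, b in zip(matvec(W_enc, x_bar), b_enc)]

    # Ordinary forward pass: dead features contribute nothing, so this is x_hat.
    f = [max(z, 0.0) for z in pre]
    x_hat = [r + b for r, b in zip(matvec(W_dec, f), b_dec)]
    # With autodiff one would likely stop gradients through the residual
    # (an assumption here; plain Python has no gradients to stop).
    residual = [xi - xhi for xi, xhi in zip(x, x_hat)]

    # Ghost pass: exp activation, restricted to dead features.
    f_ghost = [math.exp(z) if i in dead else 0.0 for i, z in enumerate(pre)]
    ghost_out = [gamma * g for g in matvec(W_dec, f_ghost)]
    if add_b_dec:
        # Wrong: b_dec is already part of x_hat, so it gets counted twice.
        ghost_out = [g + b for g, b in zip(ghost_out, b_dec)]

    return sum((g - r) ** 2 for g, r in zip(ghost_out, residual))


W_enc = [[1.0, 0.0], [0.0, 1.0], [-1.0, -1.0]]
b_enc = [0.0, 0.0, -5.0]
W_dec = [[0.5, 0.0, 1.0], [0.0, 0.5, 1.0]]
x = [1.0, 2.0]

for b_dec in ([0.0, 0.0], [1.0, 1.0]):
    correct = ghost_grad_loss(x, W_enc, b_enc, W_dec, b_dec, {2}, gamma=1.0)
    buggy = ghost_grad_loss(x, W_enc, b_enc, W_dec, b_dec, {2}, gamma=1.0, add_b_dec=True)
    print(b_dec, correct, buggy)
```

With $\mathbf{b}_{\text{dec}} = \mathbf{0}$ the two variants agree exactly, so the mistake only shows up once the decoder bias is nonzero; with $\mathbf{b}_{\text{dec}} = (1, 1)$ the buggy variant asks the ghost output to absorb an extra copy of the bias and its loss is much larger.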