## A note on ghost grads
This is a brief note about [ghost grads](https://transformer-circuits.pub/2024/jan-update/index.html#dict-learning-resampling). I've seen multiple implementations that I think are confused, and so I've decided to write a note clarifying things.
I use the [notation](https://transformer-circuits.pub/2023/monosemantic-features/index.html#appendix-autoencoder) from Anthropic's [Towards Monosemanticity](https://transformer-circuits.pub/2023/monosemantic-features/index.html) paper.
### SAEs
#### Architecture
Let $\mathbf{x} \in \mathbb{R}^n$ denote the activation to be reconstructed.
Then we can write the forward pass of an SAE as follows:
$$\bar{\mathbf{x}} = \mathbf{x} - \mathbf{b}_{\text{dec}}$$
$$\mathbf{f} = \text{ReLU}(W_{\text{enc}} \bar{\mathbf{x}} + \mathbf{b}_{\text{enc}})$$
$$\hat{\mathbf{x}} = W_{\text{dec}}\mathbf{f} + \mathbf{b}_{\text{dec}}$$
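For concreteness, here's a minimal PyTorch sketch of this forward pass (the class and variable names are mine, not taken from any particular implementation):

```python
import torch
import torch.nn as nn


class SAE(nn.Module):
    def __init__(self, n: int, m: int):
        """n: dimension of the activation x; m: number of features (typically m >> n)."""
        super().__init__()
        self.W_enc = nn.Parameter(torch.randn(m, n) * 0.01)
        self.b_enc = nn.Parameter(torch.zeros(m))
        self.W_dec = nn.Parameter(torch.randn(n, m) * 0.01)
        self.b_dec = nn.Parameter(torch.zeros(n))

    def forward(self, x: torch.Tensor):
        # x: (batch, n)
        x_bar = x - self.b_dec                              # x_bar = x - b_dec
        f = torch.relu(x_bar @ self.W_enc.T + self.b_enc)   # f = ReLU(W_enc x_bar + b_enc)
        x_hat = f @ self.W_dec.T + self.b_dec               # x_hat = W_dec f + b_dec
        return x_hat, f
```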
#### Training
SAEs are trained to minimize a loss function which aims for low reconstruction error and high sparsity of intermediate feature activations:
$$\mathcal{L} = \frac{1}{|X|} \sum_{\mathbf{x} \in X} \underbrace{\lvert \lvert \mathbf{x} - \hat{\mathbf{x}} \rvert \rvert_2^2}_{\text{reconstruction error}} + \underbrace{\lambda \lvert \lvert \mathbf{f} \rvert \rvert_1}_{\text{sparsity penalty}}$$
where $\lambda$ is a hyperparameter that controls the tradeoff between reconstruction error and sparsity of intermediate feature activations.
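Continuing the PyTorch sketch above, the loss might look like this (the name `l1_coeff` stands in for $\lambda$ and is my own):

```python
def sae_loss(x, x_hat, f, l1_coeff: float):
    # Mean over the batch of ||x - x_hat||_2^2 + lambda * ||f||_1
    recon_error = (x - x_hat).pow(2).sum(dim=-1)   # squared L2 reconstruction error
    sparsity = f.abs().sum(dim=-1)                 # L1 norm of the feature activations
    return (recon_error + l1_coeff * sparsity).mean()
```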
### Dead features
During SAE training, a significant number of features may become "dead" - their post-ReLU activations are 0 on all input activations.
Mathematically, we say a feature $i$ is **dead** if, for all input activations $\mathbf{x}$, its pre-activation is negative, i.e. $W_{\text{enc}}[i, :] \cdot \bar{\mathbf{x}} + \mathbf{b}_{\text{enc}}[i] < 0$, where $W_{\text{enc}}[i, :]$ denotes the $i^{\text{th}}$ row of $W_{\text{enc}}$.
Dead features arise during SAE training because of the sparsity penalty, which penalizes the L1 norm of the intermediate feature vector $\mathbf{f}$ and thereby pushes feature pre-activations towards (and eventually below) zero.
Once dead, features generally stay dead. To be resuscitated, a dead feature needs its encoder vector $W_{\text{enc}}[i, :]$ (and encoder bias) to change. However, ReLU has zero gradient in its negative region ([gradients are blocked](https://datascience.stackexchange.com/a/5734)), so no gradient updates propagate to the encoder parameters of dead features.
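In practice, "dead" is typically operationalized as "hasn't fired on any input for the last $N$ training steps." Here's a rough sketch of how one might track this; the window `dead_after` and the feature count are arbitrary illustrative choices, not something prescribed by the original post:

```python
import torch

num_features = 4096  # m, the number of SAE features (illustrative value)
steps_since_fired = torch.zeros(num_features, dtype=torch.long)

def update_dead_mask(f: torch.Tensor, dead_after: int = 10_000) -> torch.Tensor:
    """f: (batch, num_features) post-ReLU feature activations from the current step."""
    global steps_since_fired
    fired = (f > 0).any(dim=0)  # did feature i fire on any example in this batch?
    steps_since_fired = torch.where(
        fired, torch.zeros_like(steps_since_fired), steps_since_fired + 1
    )
    return steps_since_fired >= dead_after  # boolean mask: True means "dead"
```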
### Ghost grads
The main idea behind [ghost grads](https://transformer-circuits.pub/2024/jan-update/index.html#dict-learning-resampling) is to resuscitate dead features, and put them to good use in improving reconstruction error.
Intuitively, there are two parts of this idea:
1. **Resuscitate dead neurons.** This will require altering the activation function of dead features (to something more sensitive than ReLU) so that gradients can flow through and update their corresponding encoder directions.
2. **Repurpose dead neurons to improve reconstruction error.** This can be achieved by pushing the dead neurons to help close the gap between the reconstructed activation $\mathbf{\hat{x}}$ and the original activation $\mathbf{x}$.
How do ghost grads achieve this?
Let's rewrite the SAE forward pass equations, but partition the feature space into alive features and dead features:
$$\bar{\mathbf{x}} = \mathbf{x} - \mathbf{b}_{\text{dec}}$$
$$\mathbf{f}_{\text{alive}} = \text{ReLU}(W_{\text{enc}} \bar{\mathbf{x}} + \mathbf{b}_{\text{enc}}) * \text{mask}_{\text{alive}}$$
$$\mathbf{f}_{\text{dead}} = \text{ReLU}(W_{\text{enc}} \bar{\mathbf{x}} + \mathbf{b}_{\text{enc}}) * \text{mask}_{\text{dead}}$$
$$\hat{\mathbf{x}} = W_{\text{dec}}\mathbf{f}_{\text{alive}} + W_{\text{dec}}\mathbf{f}_{\text{dead}} + \mathbf{b}_{\text{dec}}$$
Note that this is a bit silly to write out: given the definition of a "dead feature", we have that $\mathbf{f}_{\text{alive}} = \mathbf{f}$, and $\mathbf{f}_{\text{dead}} = \mathbf{0}$. But I think it's helpful to partition things in this way.
We want our reconstruction to approximate the original activation $\mathbf{x}$:
$$W_{\text{dec}}\mathbf{f}_{\text{alive}} + W_{\text{dec}}\mathbf{f}_{\text{dead}} + \mathbf{b}_{\text{dec}} \underset{\text{want}}{\approx} \mathbf{x}$$
Noticing that $\hat{\mathbf{x}} = W_{\text{dec}}\mathbf{f}_{\text{alive}} + \mathbf{b}_{\text{dec}}$, we can re-write our desire as
$$
\hat{\mathbf{x}} + W_{\text{dec}}\mathbf{f}_{\text{dead}} \underset{\text{want}}{\approx} \mathbf{x}
$$
or equivalently,
$$W_{\text{dec}}\mathbf{f}_{\text{dead}} \underset{\text{want}}{\approx} \mathbf{x} - \hat{\mathbf{x}}$$
The intuition of ghost grads is to try to *alter the contribution of* $W_{\text{dec}} \mathbf{f}_{\text{dead}}$ *from* $\mathbf{0}$ *to* $\mathbf{x} - \mathbf{\hat{x}}$.
In order to do this, we'll first compute the activations of dead features using a more sensitive activation function, the $\exp$ function, so that gradients can flow:
$$\mathbf{f}_{\text{dead}}^{\text{ghost}} = \exp(W_{\text{enc}} \bar{\mathbf{x}} + \mathbf{b}_{\text{enc}}) * \text{mask}_{\text{dead}}$$
We then construct a loss term that shapes the output contributions of these dead neurons towards $\mathbf{x} - \hat{\mathbf{x}}$ (i.e. the "residual"):
$$\mathcal{L}_{\text{ghost}} = \lvert \lvert \gamma W_{\text{dec}}\mathbf{f}_{\text{dead}}^{\text{ghost}} - (\mathbf{x} - \hat{\mathbf{x}}) \rvert \rvert_2^2$$
[See the [original post](https://transformer-circuits.pub/2024/jan-update/index.html#dict-learning-resampling) for how the outputs of the ghost features are scaled (denoted $\gamma$ in the equation above) and how the ghost grads loss term is scaled relative to the other loss terms.]
**Note:** it is $W_{\text{dec}}\mathbf{f}_{\text{dead}}^{\text{ghost}}$ that ought to approximate the residual, **not** $W_{\text{dec}}\mathbf{f}_{\text{dead}}^{\text{ghost}} + \mathbf{b}_{\text{dec}}$. I have seen this subtle mistake in a couple of independent implementations so far, and hope that this note clears up this misconception.
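Putting this together, here's a rough PyTorch sketch of the ghost grads loss term. The scaling details (how $\gamma$ is computed, and how the ghost loss is weighted against the main loss) vary across implementations; the norm-matching scheme below is my reading of the original post and should be treated as an assumption rather than a reference implementation. Note, per the above, that $\mathbf{b}_{\text{dec}}$ is *not* added to the ghost reconstruction:

```python
def ghost_grads_loss(x, x_hat, pre_acts, dead_mask, W_dec):
    """
    x, x_hat:  (batch, n)  original and reconstructed activations
    pre_acts:  (batch, m)  W_enc x_bar + b_enc, *before* the ReLU
    dead_mask: (m,)        boolean mask of dead features
    W_dec:     (n, m)      decoder weight matrix
    """
    residual = (x - x_hat).detach()            # x - x_hat, treated as a fixed target

    # Ghost activations: exp instead of ReLU, restricted to dead features.
    f_ghost = torch.exp(pre_acts) * dead_mask  # (batch, m)
    x_ghost = f_ghost @ W_dec.T                # NOTE: no "+ b_dec" here

    # gamma: rescale the ghost reconstruction towards the residual's norm
    # (an assumed norm-matching scheme; see the original post for the exact recipe).
    gamma = residual.norm(dim=-1, keepdim=True) / (
        2 * x_ghost.norm(dim=-1, keepdim=True) + 1e-6
    )
    x_ghost = gamma.detach() * x_ghost

    return (x_ghost - residual).pow(2).sum(dim=-1).mean()
```

In training, one would add this term to the overall loss only on steps where `dead_mask.any()` is true, alongside the reconstruction and sparsity terms from earlier.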