# Sparse NanoGPT

**Goal** Train a "decent" GPT2-small sized LLM from scratch while being "as interpretable as possible" and not taking "too long" compared to vanilla baselines. Quantify alignment tax by the time it takes to reach common validation CE loss thresholds used by the nanogpt-speedrunning literature. The goal is alignment tax $<15\times$.

The most common threshold is 3.28 CE loss.

* This took Andrej Karpathy the equivalent of 45min to train on an 8xH100 machine in 2024.
* Currently the record is ~2min. The main components needed to reach this record are:
    * The Muon optimizer (orthogonalizes gradient updates for all linear layers). This is the most qualitatively important optimization.
    * More efficient ways of batching the data.
    * CUDA / low-level optimizations (lumping in stuff like compilation, using FP8 where appropriate, changing activation functions to be slightly faster to compute).
* We'll aim for a fair comparison with a "vanilla" baseline with the same hyper-parameters.

## Broad approach: induce sparsity

As a proxy for "as interpretable as possible", I use "as sparse as possible". By this I mean both:

* *Weight sparsity*
    * In some sense this is the easier of the two since we can amortize the cost of whatever sparsification steps we take across a batch.
    * But the geometry of sparsification may not play well with Muon.
* *Activation sparsity*
    * This faces more of a "bottleneck":
        * For "best interpretability", we probably want $\gg768$ "features" in the GPT2-small residual stream (otherwise we'll have polysemanticity). For now, targeting something like $25{,}000$ features.

### Signals of plausibility from prior literature

I used to think training a sparse transformer from scratch might have too nasty a loss landscape to perform well at all. I've since updated based on the following two papers. There doesn't seem to be an obvious performance ceiling for sparse models; the main blocker seems to be compute costs.
#### Weight-Sparse Transformers

https://cdn.openai.com/pdf/41df8f28-d4ef-43e9-aed2-823f9393e470/circuit-sparsity-paper.pdf

* Induces sparsity in activations and weights via TopK.
* The model performs well on their corpus of coding data, despite only having $2048$ features in the residual stream.
    * Perhaps their coding corpus has less need for a large number of features / concepts than natural language.
    * Perhaps there's still some implicit computation in superposition. Imagine an over-complete basis whose vectors are each $p$-sparse in the standard basis, and suppose the $\ell_0$ of activations in this over-complete basis is $q$. Each active feature touches at most $p$ coordinates, so the $\ell_0$ in the standard basis is at most $pq$. If $pq\leq k$, the activations still fit within an $\ell_0$ budget of $k$ in the standard basis, even though the standard basis is not the most "disentangled" representation. They use $k=512$, so this seems plausible.
* Authors claim that weight-sparse models are 100-1000$\times$ more expensive than equivalent dense models for training/inference.

#### Baby Dragon Hatchling

https://arxiv.org/pdf/2509.26507

* Encourages sparsity via a very wide ReLU MLP. They find the hidden layer of this MLP is ~5% sparse without any other modifications.
* The architecture doesn't appear to have obvious performance ceilings; it performs similarly to GPT2 architectures trained on translation tasks.
* They don't talk about computational costs, but their approach is likely also 100-1000$\times$ as expensive as dense.

## Approach: sparsity in over-complete basis

The main differentiator of my approach vs WST / BDH is to encourage sparsity in an over-complete basis rather than the standard basis.

## Decomposing alignment tax: throughput and data efficiency

To fix ideas, there are two main components of the alignment tax:

1. Token throughput (how many tok/s can you achieve during training on the same hardware)
2. Data efficiency (how many tokens does it take to reach some target CE loss)

Sparsity may mess with either of these.

## This week: activation sparsity via ParityDeepTopK (edited 2/12)

From a computational perspective, this is the harder part.
So this serves as a useful estimate of how good / bad we can get on token throughput. From a computational perspective, my current approach (+ tireless kernel engineering from Opus 4.6) is going well.

To summarize, you can view my approach as splicing in a "deep topK" SAE at a given layer of a transformer. Essentially this is a multi-level mixture-of-experts style TopK SAE where the experts' directions are strategically chosen to:

1. maintain incoherence (approximate orthogonality of feature directions)
2. be efficiently computable via an H200/B200 tensor-core friendly "parity hash" function. This side-steps the need to load experts from HBM ("slow, but large" memory on GPU), allowing us to scale to a larger number of experts than standard mixture-of-experts (standard MoE requires a relatively small number of frequently used experts in order to amortize the cost of memory loads).
    * Essentially, by *designing* the over-complete basis rather than learning it post-hoc, we can ensure that it has computationally convenient properties.
    * The theory is that scaling to a larger number of (potentially rare) experts/features will be better from an interpretability perspective.

In this framework, features are arranged in a tree. "Experts" in the MoE analogy correspond to level-1 features in the tree. "Child" (level-2) features correspond to "expert weights". Unlike traditional MoE, because the features are designed, we can efficiently scale to multiple levels, forming a potentially deep hierarchy. However, I've currently only implemented a two-level tree.

See this doc for details: https://hackmd.io/@amack/Sk7UXFDPbx.

### Load-balancing (added 2/12)

As alluded to above, from an inductive bias perspective we don't really want any "load-balancing" of experts / level-1 features, as a priori we expect feature firing frequencies to be highly heterogeneous (as opposed to uniform).
However, from a training dynamics perspective, some amount of load-balancing is necessary, or else SGD quickly converges to a regime where 90% of level-1 features are dead (fire with frequency less than $1/m_1$, where $m_1$ is the number of level-1 features).

My current approach for load balancing adaptively keeps track of running estimates of quantiles of expert score distributions and strategically adds biases to infrequent experts' scores to guarantee they fire with at least a minimum frequency. The idea is to ensure no experts "die" while allowing for heterogeneity in firing frequencies.

My current approach appears to eliminate dead experts. However, I'm still tinkering with it (looking at things like whether this induces a plausible distribution over firing frequencies) and so the method is currently in flux. I'll write up a more detailed description once I've finalized the load-balancing method.

## Results (added 2/12)

I apply DeepTopK to the attention-outputs of a GPT2-small-sized transformer at all layers. In particular:

1. There are $m_1=768$ level-1 features (unlike in the doc, level-1 features are basis-aligned, and so are essentially free to compute). Each level-1 feature has $m_2=32$ children, for a total of $m_1 + m_1 m_2 = 25{,}344$ features in the dictionary.
    * Currently, the main bottleneck for scaling to more features seems to be training dynamics rather than compute efficiency. In particular, I noticed in small pilot runs that increasing the number of features to 1e5-1e6 led to worse data efficiency, perhaps because this introduces more noise in the DeepTopK reconstruction, which is harmful early in training when the network hasn't learned to use the features.
    * It's plausible that to scale to a larger number of features, you may need to progressively add levels to the tree, focusing on learning the most coarse-grained features early in training, then finer-grained ones later on.
2. I select $k_1=16$ features from the first level.
At the second level, we score $k_1 m_2=512$ level-2 candidates, and keep $k_2=32$ of these. This yields a total $\ell_0$ of $k_1+k_2=48$.

The only other difference from a vanilla GPT2 architecture is that we replace all instances of LayerNorm with RMSNorm in order to privilege the origin (similarly to weight-sparse transformers).

I haven't yet run a fully apples-to-apples comparison with the exact same hyper-parameters (learning rate schedule etc.) except for the DeepTopK operation. However, a ballpark summary is as follows:

1. **Token Throughput** - 1.5x slower than dense (on 8xH200)
    * dense: 3.8M tok/s, sparse: 2.6M tok/s
2. **Data efficiency** - 1-3.33x slower than dense. This depends on whether the dense baseline is allowed to use Muon or only Adam. Rough numbers for tokens needed to reach 3.28 CE loss are:
    * dense (Adam): 10B, dense (Muon): 3B, sparse (Adam/Muon): 10B
    * In particular, the 10B number for sparse is technically for a full sparse+Muon run (leading to a >=200% "data efficiency" tax). However, in smaller pilot runs, I've noticed that sparse+Adam appears to perform similarly to sparse+Muon.
        * Essentially, the DeepTopK operation appears to negate any performance benefits of Muon vs Adam.
    * Muon is a fairly recent optimization following years of research into dense architectures. At least one major lab has used Muon at scale (see kimi-k2) but other organizations report issues with stability / distributed training efficiency; making Muon more robust at scale seems to be an active research area (see e.g. [this paper from Microsoft AI](https://arxiv.org/pdf/2504.05295)).
    * Given that "weight-sparse transformers" as a concept are only a few months old, it's arguably unfair to compare against an optimizer which has benefited from years of research into optimizing dense transformers. So going forward, I'll report a range of "alignment taxes", comparing against both dense+Muon and dense+Adam.
3. **Total** - between 1.5-5x slower than dense (well within our 15x target)
    * Next steps are to quantify whether this is actually "as interpretable as possible" and adjust accordingly. It's possible the alignment tax will go up (or down) as we adjust to reach our interpretability goals.
        * **Training/Architectural Modifications to Consider** Most of these will increase the alignment tax, but may help interpretability:
            * sparsify weights
            * apply ParityDeepTopK at more locations in the model (e.g. residual stream)
        * **Measures of "interpretability"** These measures are imperfect but will hopefully give some signal:
            * **SAE-Bench** train a dense transformer with the same CE loss as sparse, then train an SAE on the dense transformer. Run SAE-bench on the post-hoc SAE.
            * **Circuit Isolation** use the circuit-pruning method from the weight-sparse transformers paper on several tasks, and see if we can explain performance with sparser circuits. This metric [has issues](https://www.lesswrong.com/posts/sHpZZnRDLg7ccX9aF/weight-sparse-circuits-may-be-interpretable-yet-unfaithful) but plausibly seems loosely correlated with interpretability.
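
As a sanity check on the Results numbers above, here is a minimal NumPy sketch that (a) runs the two-level selection on random placeholder scores to confirm the candidate count, $\ell_0$ and dictionary size, and (b) multiplies the throughput and data-efficiency ratios to recover the total tax range. Everything beyond the quoted sizes and measurements is an assumption for illustration: the helper names (`top_k_indices`, `two_level_select`, `alignment_tax`) are hypothetical, scores are random rather than produced by the parity hash, ranking is by signed score, and treating the total tax as the product of the two ratios is my reading of the throughput / data-efficiency decomposition. This is not the actual kernel.

```python
import numpy as np

# Sizes quoted in the Results section.
M1, M2 = 768, 32   # level-1 features, children per level-1 feature
K1, K2 = 16, 32    # features kept at level 1 and at level 2


def top_k_indices(scores: np.ndarray, k: int) -> np.ndarray:
    """Indices of the k largest entries of a 1-D array (in no particular order)."""
    return np.argpartition(scores, -k)[-k:]


def two_level_select(level1_scores, child_scores, k1=K1, k2=K2):
    """Pick k1 parents, then k2 children among those parents' candidates.

    child_scores has shape (m1, m2). How scores are actually computed
    (the parity hash) is NOT modelled; they are placeholder inputs here.
    """
    parents = top_k_indices(level1_scores, k1)       # shape (k1,)
    candidates = child_scores[parents].reshape(-1)   # shape (k1 * m2,)
    kept = top_k_indices(candidates, k2)             # shape (k2,)
    m2 = child_scores.shape[1]
    # Map flat candidate positions back to (parent index, child index) pairs.
    children = np.stack([parents[kept // m2], kept % m2], axis=1)
    return parents, children, candidates.size


def alignment_tax(dense_tok_s, sparse_tok_s, dense_tokens, sparse_tokens):
    """Assumed total tax: throughput ratio times data-efficiency ratio."""
    return (dense_tok_s / sparse_tok_s) * (sparse_tokens / dense_tokens)


rng = np.random.default_rng(0)
parents, children, n_candidates = two_level_select(
    rng.standard_normal(M1), rng.standard_normal((M1, M2))
)
dictionary_size = M1 + M1 * M2            # level-1 features plus all children
l0 = len(parents) + len(children)         # active features per token

# Ballpark measurements from the Results section (tok/s, tokens to 3.28 CE).
tax_vs_adam = alignment_tax(3.8e6, 2.6e6, 10e9, 10e9)
tax_vs_muon = alignment_tax(3.8e6, 2.6e6, 3e9, 10e9)

print(n_candidates, l0, dictionary_size)
print(round(tax_vs_adam, 2), round(tax_vs_muon, 2))
```

With the quoted numbers this gives 512 candidates, an $\ell_0$ of 48, a dictionary of 25,344 features, and a tax of roughly 1.46x against dense+Adam and 4.87x against dense+Muon, consistent with the 1.5-5x range.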