# Learning to Stack
Define the network functions $\{(f_{\mbox{att}}^l, f_{\mbox{ffn}}^l, f^l_{\mbox{norm}})\}_{l=1}^L$ and the gates $\{g^l\}_{l=1}^L$. The input to layer $l$, $\mathbf{x}^l$, is transformed to $\mathbf{x}^{l+1}$ via
$$ \mathbf{h}^l = \mathbf{x}^l + g^l f_{\mbox{att}}^l(f^l_{\mbox{norm}}(\mathbf{x}^l))$$
$$ \mathbf{x}^{l+1} = \mathbf{h}^l + g^l f_{\mbox{ffn}}^l(f^l_{\mbox{norm}}(\mathbf{h}^l))$$
(inspired by https://arxiv.org/pdf/2010.13369.pdf). Define the final prediction layer as $f_{\mbox{pred}}$ (this layer is not gated). Define $g^l \sim B(p^l)$, where $B(p^l)$ is the Bernoulli distribution with mean $p^l = \sigma(\theta^l)$ and $\sigma$ is the sigmoid function.
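For concreteness, here is a minimal PyTorch sketch of one gated pre-norm block implementing the two equations above; the class name `GatedBlock` and the dimensions `d_model`, `n_heads`, `d_ffn` are illustrative choices rather than anything fixed by this note, and the single shared `LayerNorm` mirrors the reuse of $f^l_{\mbox{norm}}$ in both equations (implementations often use two separate norms).

```python
import torch
import torch.nn as nn

class GatedBlock(nn.Module):
    """Pre-norm transformer block whose two residual branches are scaled by a gate g."""

    def __init__(self, d_model: int = 256, n_heads: int = 4, d_ffn: int = 1024):
        super().__init__()
        self.norm = nn.LayerNorm(d_model)  # f_norm^l (shared, as in the equations above)
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)  # f_att^l
        self.ffn = nn.Sequential(          # f_ffn^l
            nn.Linear(d_model, d_ffn), nn.GELU(), nn.Linear(d_ffn, d_model)
        )

    def forward(self, x: torch.Tensor, g: float = 1.0) -> torch.Tensor:
        # h^l = x^l + g^l * f_att^l(f_norm^l(x^l))
        n = self.norm(x)
        attn_out, _ = self.attn(n, n, n, need_weights=False)
        h = x + g * attn_out
        # x^{l+1} = h^l + g^l * f_ffn^l(f_norm^l(h^l))
        return h + g * self.ffn(self.norm(h))
```

Setting $g = 1$ recovers a standard pre-norm block, while $g = 0$ makes the block an exact identity map, which is what makes skipping an untrained layer safe. Given these definitions, the training procedure is: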
> Initialize $g^1 = 1$ and $g^l = 0 \;\forall l \in [2, L]$; network size $S = 1$; number of inner updates $T$ with counter $t = 0$; total number of outer iterations $I$ with counter $i = 0$; gate parameters $\theta^l = 0 \;\forall l \in [2, L]$ (so $p^l = \sigma(0) = 0.5$); number of samples $k$
> While $i < I$:
$\quad$ Reset $t \leftarrow 0$
$\quad$ While $t < T$:
$\qquad$ Sample batch $\mathcal{B} \subset \mathcal{D}$
$\qquad$ Update $\{(f_{\mbox{att}}^l, f_{\mbox{ffn}}^l, f^l_{\mbox{norm}})\}_{l=1}^S$ and $f_{\mbox{pred}}$ in the usual way, via gradients of the loss function $\mathcal{L}$ on batch $\mathcal{B}$
$\qquad$ Increment $t \leftarrow t + 1$
$\quad$ Copy $(f_{\mbox{att}}^{S+1}, f_{\mbox{ffn}}^{S+1}, f^{S+1}_{\mbox{norm}}) = (f_{\mbox{att}}^{S}, f_{\mbox{ffn}}^{S}, f^{S}_{\mbox{norm}})$
$\quad$ Sample $\{g_j^{S+1}\}_{j=1}^k \sim B(p^{S+1})$
$\quad$ Sample multiple batches $\{\mathcal{B}_j\}_{j=1}^k \subset \mathcal{D}$
$\quad$ Update $\theta^{S+1}$ via $\nabla_{\theta^{S+1}} \mathbb{E}_{g^{S+1} \sim B(p^{S+1})}[\mathcal{L}(\mathcal{B}, g^{S+1})] \approx \frac{1}{k-1} \sum_{j=1}^k \nabla_{\theta^{S+1}} \big(g_j^{S+1}\log p^{S+1} + (1-g_j^{S+1})\log(1-p^{S+1})\big) \Big(\mathcal{L}(\mathcal{B}_j, g_j^{S+1}) - \frac{1}{k} \sum_{h=1}^k \mathcal{L}(\mathcal{B}_h, g_h^{S+1}) \Big)$ (the leave-one-out estimator from equation (8) of https://openreview.net/pdf?id=r1lgTGL5DE)
$\quad$ Sample $g^{S+1} \sim B(p^{S+1})$
$\quad$ If $g^{S+1} = 1$:
$\qquad$ Increment $S \leftarrow S + 1$
$\quad$ Increment $i \leftarrow i + 1$
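Below is a rough sketch of one outer iteration (the part after the inner while-loop), assuming PyTorch and the `GatedBlock` above; `loss_fn` and `sample_batch` are hypothetical placeholders for the task loss and data pipeline, and the plain SGD step on $\theta^{S+1}$ stands in for whatever optimizer is actually used.

```python
import copy
import torch

def grow_step(blocks, theta, S, loss_fn, sample_batch, k=8, lr_theta=1e-2):
    """One outer iteration of the stacking procedure (inner training loop omitted).

    blocks       -- list of GatedBlock modules; blocks[:S] are the currently active layers
    theta        -- scalar tensor (requires_grad=True) parameterising the candidate gate g^{S+1}
    loss_fn      -- hypothetical: loss_fn(batch, blocks, S, g_new) -> scalar loss, where g_new
                    gates the candidate layer (g_new = 0 means the candidate is skipped)
    sample_batch -- hypothetical: returns a fresh mini-batch from the dataset
    Returns the (possibly incremented) network size S.
    """
    # Copy layer S to initialise the candidate layer S+1.
    candidate = copy.deepcopy(blocks[S - 1])
    if len(blocks) > S:
        blocks[S] = candidate
    else:
        blocks.append(candidate)

    # Sample k gates g_j ~ Bernoulli(p^{S+1}) and k batches, and record the losses.
    p = torch.sigmoid(theta)
    gates = torch.bernoulli(p.detach().expand(k))
    losses = torch.stack(
        [loss_fn(sample_batch(), blocks, S, g.item()) for g in gates]
    ).detach()  # L(B_j, g_j); no gradient flows from the losses to theta

    # Score-function update with the shared baseline from the procedure above:
    # grad ~ 1/(k-1) * sum_j d/dtheta[ g_j log p + (1-g_j) log(1-p) ] * (L_j - mean(L))
    log_prob = gates * torch.log(p) + (1 - gates) * torch.log(1 - p)
    surrogate = (log_prob * (losses - losses.mean())).sum() / (k - 1)
    (grad_theta,) = torch.autograd.grad(surrogate, theta)
    with torch.no_grad():
        theta -= lr_theta * grad_theta  # plain SGD step on theta^{S+1}

    # Sample a fresh gate; if it comes up 1, the candidate layer is kept and S grows.
    if torch.bernoulli(torch.sigmoid(theta).detach()) == 1:
        S += 1
    return S
```

Note that the losses are detached: $\theta^{S+1}$ only receives gradient through the log-probability term, which is what makes this update cheap.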
### Notes
- The update for $\theta^{S+1}$ is cheap to compute (see the derivation below this list). If $g_j^{S+1} = 0$ then $\mathcal{L}(\mathcal{B}_j, g_j^{S+1})$ is the loss on batch $\mathcal{B}_j$ _without_ the added layer $(f_{\mbox{att}}^{S+1}, f_{\mbox{ffn}}^{S+1}, f^{S+1}_{\mbox{norm}})$; if $g_j^{S+1} = 1$ it is the loss on $\mathcal{B}_j$ _with_ the added layer.
- To reduce the number of hyperparameters I fixed the number of sampled batches $\mathcal{B}_j$ to equal the number of sampled gates $g_j^{S+1}$ (both $k$), but they don't have to be the same.
- How to pitch the paper:
- https://proceedings.neurips.cc/paper/2017/file/9ef2ed4b7fd2c810847ffa5fa85bce38-Paper.pdf
- https://arxiv.org/pdf/2210.11369.pdf
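Derivation referenced in the first note: because $p^{S+1} = \sigma(\theta^{S+1})$, the gradient of the log-probability term has a closed form,

$$\nabla_{\theta^{S+1}} \big(g_j^{S+1}\log p^{S+1} + (1-g_j^{S+1})\log(1-p^{S+1})\big) = g_j^{S+1}(1-p^{S+1}) - (1-g_j^{S+1})\,p^{S+1} = g_j^{S+1} - p^{S+1},$$

so the update reduces to

$$\nabla_{\theta^{S+1}} \mathbb{E}\big[\mathcal{L}\big] \approx \frac{1}{k-1} \sum_{j=1}^k \big(g_j^{S+1} - p^{S+1}\big)\Big(\mathcal{L}(\mathcal{B}_j, g_j^{S+1}) - \frac{1}{k}\sum_{h=1}^k \mathcal{L}(\mathcal{B}_h, g_h^{S+1})\Big),$$

i.e. a weighted sum of scalars that are already available from the $k$ forward passes.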