技術路線:KFAC+Tikhonov Petersen, Felix, et al. "ISAAC newton: Input-based approximate curvature for newton’s method." in proceedings of ICLR, 2023. Citations: 1 #### 先備知識: ##### 神經網路 - 對於一個 $L$ 層的網路 $f(x;\theta)$,第 $i$ 層可以以這個形式表示: - $z_i\leftarrow \alpha_{i-1}W^{(i)}$ (pre-activations) 與 $\alpha_{i}\leftarrow \phi (z_i)$ - $a_0=x, a_L=f(x;\theta)$ - 而標準的牛頓法會以這個形式進行更新:$$ \theta'=\theta - \eta\textbf{H}^{-1}_\theta \nabla_\theta L(f(x;\theta),y)$$ ##### Tikhonov regularization - 為了增加效率與穩定性,可以引入一個 $\lambda\textbf{I}$ 進行 regularization:$$ \theta'=\theta - \eta(\textbf{H}_\theta+\lambda\textbf{I})^{-1} \nabla_\theta L(f(x;\theta),y)$$ - 其中,$\lambda > 0$,為 Tikhonov regularization parameter ##### 問題形式: - 可以注意到,一般神經網路問題可以被歸類為以下最佳化問題:$$ argmin_\theta \mbox{E}_{x, y}[L(f(x;\theta), y)] $$ #### 近似: ##### 高斯牛頓 - 在一階泰勒展開來近似模型的假設下,Hessian 可以以廣義高斯牛頓矩陣(GGN)$\textbf{G}_\theta$ 來表示 [1]$$ \theta'=\theta - \eta(\textbf{G}_\theta+\lambda\textbf{I})^{-1}_\theta \nabla_\theta L(f(x;\theta),y)$$ ##### K-FAC - 對於 inverse 的計算仍然是問題瓶頸,通過引入 K-FAC 來對 block-diagonal (i.e., layer-wise) 來去近似 Hessian 或 GGN - block-diagonal: 將完整 Hessian 分解成多個對角矩陣,每個對角矩陣對應神經網路的一個層 - 具體來說, GGN 是這樣子的:$$ \textbf{G}_\theta = \mbox{E}[(\mbox{J}_\theta f(x;\theta))^\top \nabla^2_f L(f(x;\theta) \mbox{J}_\theta f(x;\theta) ]$$ - 基本來說 $\nabla^2_f = \textbf{H}$ 、$\mbox{J}_\theta$ 為 Jacobian matrix - 接著我們討論第 $i$ 層的 diagonal block 的表示:$$ \textbf{G}_{W^{(i)}} = \mbox{E}[(\mbox{J}_{W^{(i)}} f(x;\theta))^\top \nabla^2_f L(f(x;\theta) \mbox{J}_{W^{(i)}} f(x;\theta) ] $$感覺就是把 $\theta$ 替換成 $W_{(i)}$ - 根據 Jacobian 的 backpropagation rule $$\mbox{J}_{W^{(i)}} f(x;\theta) = \mbox{J}_{z_i}f(x;\theta)a_{i-1}, \ \ a^\top b = a \otimes b $$與 mixed-product property ( $(AB)C = A(BC)$ ),從而可獲得: - $G_{W^{(i)}}=$ $$ \mbox{E}[ ((\mbox{J}_{z_i}f(x;\theta)a_{i-1})^\top(\nabla^2_f L(f(x;\theta), y))^{1/2}) $$把前面替換成 backpropagation rule 版本,然後把後面拆成兩個開根號相乘:$$ +(\nabla^2_f L(f(x;\theta), y))^{1/2}\mbox{J}_{z_i}f(x;\theta)a_{i-1}) ] $$ - 然後再做代數替換 $\bar{g}=(\mbox{J}_{z_i}f(x;\theta))^\top (\nabla^2_f L(f(x;\theta), y))^{1/2}$ ,總之會變這樣:$$ G_{W^{(i)}}=\mbox{E}[(\bar{g}^\top a_{i-1})^\top(\bar{g}^\top a_{i-1})] =\mbox{E}[(\bar{g} \otimes a_{i-1})^\top(\bar{g} \otimes a_{i-1})] =\mbox{E}[ (\bar{g}^\top \bar{g}) \otimes (a_{i-1}^\top a_{i-1}) ] $$ - 這邊的 $\bar{g}$ 本質上可以視為在第 $i$ 層輸出 $z_i$ 上的梯度乘上 Hessian ##### Monte-Carlo Low-Rank Approximation - $\bar{g}^\top \bar{g}$ 仍然很難算,通過這個方法來去近似,以及 estimate $\nabla^2_f L(f(x;\theta), y))$ - 其中,$\bar{g}∈ \mbox{R}^{m \times d_i}$, $m$ 為 $f$ 輸出的維度 - 我們將使用到 loss 的機率模型 的分佈 $p_f(x)$ - MSE:我們可以假設輸出服從 Gaussian - Cross entropy:我們可以假設輸出服從 categorical distribution - 這時我們可以做 $\nabla^2_f L(f(x;\theta), y))$ 的 estimator:$$ \nabla^2_f L(f(x;\theta), y)) = \mbox{E}_{\hat{y} \sim p_f(x)}[\nabla_f L(f(x;\theta), \hat{y})^\top\nabla_f L(f(x;\theta), \hat{y})] $$ - 接著我們引入 rank-1 approximation "$\triangleq$":$$ \nabla^2_f L(f(x;\theta), y)) \triangleq \nabla_f L(f(x;\theta), \hat{y})^\top \nabla_f L(f(x;\theta), \hat{y}) $$其中 $\hat{y} \sim p_f(x)$ - 由於 $\bar{g}$ 的後半部份是開 1/2 次方的緣故,可以視為只取上式的一半:$$ \bar{g} \triangleq (\mbox{J}_{z_i}f(x;\theta))^\top \nabla_f L(f(x;\theta), \hat{y}) = \nabla_{z_i} L(f(x;\theta), \hat{y})$$ - 而對於訓練目標的 gradient $g_i$ 則會得到以下式子:$$ g_i = (\mbox{J}_{z_i}f(x;\theta))^\top \nabla_f L(f(x;\theta), y) = \nabla_{z_i} L(f(x;\theta), y) $$其中 $g∈\mbox{R}^{1 \times d_i}$,為給定輸出 $z_i$ 的梯度,而 $\bar{g}$ 為給定 $g$ 的倒數,相當於網路的梯度乘上 Hessian 的平方根,且後者需要多次採樣迭代來近似 - 最後在套用 K-FAC 近似,GGN 會變成這個形式$$ \mbox{G}=\mbox{E}[ (\bar{g}^\top \bar{g}) \otimes (a_{i-1}^\top a_{i-1}) ] \approx \mbox{E}(\bar{g}^\top \bar{g}) \otimes \mbox{E}(a_{i-1}^\top a_{i-1}) $$為兩個期望值的 **Kronecker product** - 套用至神經網路 mini-batch 的學習方式,這兩個期望值可以這樣表示: - $\mbox{E}(\bar{g}^\top \bar{g}) \approx \bar{g}^\top \bar{g} / b$ - $\mbox{E}(a_{i-1}^\top a_{i-1}) \approx a_{i-1}^\top a_{i-1} / b$ - $g∈\mbox{R}^{b\times d_i}, \bar{g}∈\mbox{R}^{rb\times d_i}, a∈\mbox{R}^{b\times d_{i-1}}$,$d_i$ 為第 $i$ 層輸出維度,$r$ 為近似採樣次數 - 代回至原式, GGN 與其 inverse 可以這樣表示: - $\mbox{G}\approx (\bar{g}^\top \bar{g}) \otimes (a_{i-1}^\top a_{i-1}) / b^2$ - $\mbox{G}^{-1} \approx (\bar{g}^\top \bar{g})^{-1} \otimes (a_{i-1}^\top a_{i-1})^{-1} / b^2$ #### Update step - 通過上面的推導,我們將其結合:$$ \theta^{(i)\prime} = \theta^{(i)} - \eta( \bar{g}^\top\bar{g}/b + \lambda \mbox{I})^{-1} \otimes (a^\top a/b + \lambda \mbox{I})^{-1} \nabla_{\theta^{(i)}}L(f(x;\theta), y) ) $$ - 他們通過觀察發現,把 $\lambda$ disentangle成兩個 independent regularization 參數 $\lambda_g, \lambda_a > 0$ - 我們的 Kronecker-factorized Gauss-Newton update step 就會變成:$$ \zeta = \lambda_g \lambda_a (\bar{g}^\top\bar{g}/b + \lambda_g \mbox{I})^{-1} \otimes (a^\top a/b + \lambda_a \mbox{I})^{-1} \nabla_{\theta^{(i)}}L(f(x;\theta), y) $$以及更新規則:$\theta^{(i)\prime} = \theta^{(i)} - \eta^* \zeta$,$\eta^* = \eta /(\lambda_g \lambda_a)$ - 當 $\lambda = \lambda_g = \lambda_a$ 時,下式會等於上式 #### 理論 ##### 特性 - 首先探討對於 Kronecker factors $\bar{g}^\top\bar{g}, a^\top a$ 各自的 regularization $\lambda_g, \lambda_a$ 之間的 disentangle,與存在其中一方為極大值時的狀況 - 通過 vector product identity 與 Woodbury matrix identity,可以把 $\zeta$ 推導成這個:$$ \zeta = (\ \mbox{I}_m - \frac{1}{b\lambda_g}\bar{g}^\top (\mbox{I}_b + \frac{1}{b\lambda_g} \bar{g}\bar{g}^\top )^{-1} \bar{g} (\ \mbox{I}_b - \frac{1}{b\lambda_a} aa^\top (I_b + aa^\top) ) a $$ - 這時我們可以輕鬆的討論 $\lambda_g, \lambda_a$ 之間的關係: - 當 $\lambda_g, \lambda_a \to 0$,就會是完全採用 K-FAC 的近似更新方法$$ lim_{\lambda_g, \lambda_a \to 0} \frac{1}{\lambda_g\lambda_a}\zeta \approx \mbox{G}^{-1} \nabla_{\theta^{(i)}} \mbox{L}(f(x; \theta)) $$ - $\lambda_g, \lambda_a \to \infty$ 時,就會是一般的 gradient $$ lim_{\lambda_g, \lambda_a \to \infty} \frac{1}{\lambda_g\lambda_a}\zeta \approx \nabla_{\theta^{(i)}} \mbox{L}(f(x; \theta)) $$ ##### $\lambda_g, \lambda_a$ 與loss方向獨立 - 他們想去證明,$\lambda_g, \lambda_a$ 並不與更新方向有關 - 我們重新來看 $\bar{g}^\top\bar{g}/b + \lambda_g \mbox{I}$ 與 $a^\top a/b + \lambda_a \mbox{I}$,從定義上他們都是 PSD matrices,由此他們的 inverse 也會是 PSD matrices - PSD matrices 的 Kronecker product 也會是 PSD ,因此 $\mbox{G}^{-1}$ 也會是 PSD,因此更新步伐也會正確 ##### 加速版本,也就是論文所提的 ISAAC (why work?) - 重新考慮到 $\lambda_g \to \infty$ ,我們可以獲得下式:$$ lim\lambda_g \zeta = \zeta^* = g^\top(\ \mbox{I}_b - \frac{1}{b\lambda_a} aa^\top (I_b + aa^\top) ) a $$可以發現這個式子並不依賴 $\bar{g}$,避免了算 $\bar{g}$ 所需要的數次額外的反向傳播,並且$g^\top, a ∈ R^{b \times b}$,計算上相當快 - 故此我們獲得了 input based ($a$) 而且不需要計算 $\bar{g}$ 的更新方向 - 原本公式的前項將收斂到 0,可以直接省略 - 時間複雜度為:$\mbox{O}(bn^2+b^2n+b^3)$: - $n$ 為 layer 的 neuron 數量 - $\nabla = g^\top x$ 為其中的 $\mbox{O}(bn^2)$ - $\zeta^*$ 為後面的 $\mbox{O}(b^2n + b^3)$ - 我們假設$n >> b$,若 $b>n$ 則會是$\mbox{O}(bn^2 + n^3)$ - 關於時間複雜度的證明跳過,感覺沒特別重要,就只是單純的計算 - 值得注意的是,這個可以為每層獨立計算,可以針對想要的 layer 單獨為他進行二階優化 - 有一段我看不太懂:![image](https://hackmd.io/_uploads/HkbjXfbGR.png) ##### $\zeta^*$ coincides with the Gauss-Newton update direction - 對於特定網路中($\zeta^*$ 只在最後一層使用、且為 MSE loss),$\zeta^*$ 為 Gauss-Newton update direction - MSE 的 Hessian 其實會是 identity matrix,這時的 $\bar{g}^\top\bar{g}$ 也為 $\mbox{I}$,故此$\zeta^* = \zeta$ ##### 方向 - $\zeta^*$ 的方向對應到 Gauss-Newton update direction 所使用的近似 $\mbox{G}$ 可以被這樣表示:$$ \mbox{G} \approx \mbox{E}[\mbox{I}(a^\top a)] $$ ##### 拓展至 Fisher-based natural gradient - 有時候使用 Fisher-based natural gradient 會更有效果,此時我們只需要把 GGN matrix $\mbox{G}$ 改為對應的 empirical $Fisher\ information\ matrix\ \mbox{F}$ - 他們的理論也可以映用到 $\mbox{F}$,而$\zeta^*$也可以很有效率的近似 $\mbox{F}^{-1}\nabla$,此時對應到第 $i$ 層的 diagonal block 公式為:$$ \mbox{F}_{\theta^{(i)}} = \mbox{E}(g_i^\top g_i) \otimes (a^\top_{i-1}a_{i-1})] $$與 GGN matrix $\mbox{G}$ 相同 - 故此,只需要把 $\bar{g}$ 替換為 $g$ ,即可獲得到 $\mbox{F}$ [1] [Gauss Newton Matrix - Andrew Gibiansky](https://andrew.gibiansky.com/blog/machine-learning/gauss-newton-matrix/)