# 7.6 循環神經網路中的梯度爆炸和梯度消失 雖然循環神經網路理論上可以捕捉長時間序列的資訊,但仍容易出現梯度爆炸和梯度消失的問題,導致訓練無法收斂。 假設有以下簡化的循環神經網路模型。忽略偏置和輸入,只考慮隱狀態向量 $h_t$ ,即 $t$ 時刻的隱狀態 $h_t$ 和 $t-1$ 時刻的隱狀態 $h_{t-1}$ 具有以下關係。 \begin{equation} h_t = \sigma(wh_{t-1}) \end{equation} \begin{equation} \frac{\partial{h_t}}{\partial{h_{t-1}}} = w\sigma^{\prime}(wh_{t-1}) \end{equation} \begin{equation} \frac{\partial{h_{t-1}}}{\partial{h_{t-2}}} = w\sigma^{\prime}(wh_{t-2}) \end{equation} 假設從 $t$ 時刻開始,經過一系列時刻 $(t+1, t+2, \cdot\cdot\cdot, t^{\prime})$ 到達 $t^{\prime}$ 時刻。在反向求導時, $t^{\prime}$ 時刻的 $h_{t^{\prime}}$ 關於 $t$ 時刻的 $h_t$ 的偏導數為 \begin{aligned} \frac{\partial{h_{t^{\prime}}}}{\partial{h_t}} &= \frac{\partial{h_{t^{\prime}}}}{\partial{h_{t^{\prime}-1}}} \frac{\partial{h_{t^{\prime}-1}}}{\partial{h_{t^{\prime}-2}}}\cdot\cdot\cdot \frac{\partial{h_{t+1}}}{\partial{h_t}} \\ &= (w\sigma^{\prime}(wh_{t^{\prime}-1}))(w\sigma^{\prime}(wh_{t^{\prime}-2}))\cdot\cdot\cdot(w\sigma^{\prime}(wh_t)) \\ &= \prod^{t^{\prime}-t}_{k = 1}w\sigma^{\prime}(wh_{t^{\prime}-k}) \\ &= \underbrace{w^{t^{\prime}-t}}_{!!!}\prod^{t^{\prime}-t}_{k = 1}\sigma^{\prime}(wh_{t^{\prime}-k}) \end{aligned} 若權值 $w$ 不等於0,那麼 : 當 $0<|w|<1$ 時,上式將以 $t^{\prime}-t$ 的速度指數衰減到 $0$ ; 當 $|w|>1$ 時,上式將增長到無限大。也就是說,梯度 $\frac{\partial{h_{t^{\prime}}}}{\partial{h_t}}$ 將衰減到$0$或爆炸到無限大。 參數的更新公式為 \begin{equation} w = w - \alpha\frac{\partial{L}}{\partial{w}} \end{equation} \begin{equation} \frac{\partial{L}}{\partial{w}} = \sum^{n}_{t = 1}\frac{\partial{L}}{\partial{h_{t^T}}}\frac{\partial{h_{t^T}}}{\partial{h_t}}h_t \end{equation} 由上可知,$\frac{\partial{L}}{\partial{w}}$將隨著$\frac{\partial{h_{t^{\prime}}}}{\partial{h_t}}$衰減為$0$或爆炸到無限大,使訓練無法收斂。 # 7.7 長短期記憶網路(LSTM) LSTM引入了和隱狀態$h_t$不同的元胞狀態(Cell State)$c_t$,$c_{t-1}$和$c_t$之間為家法關係而非乘法關係 \begin{equation} c_t = i\odot \hat{c_t}+f\odot c_{t-1} \end{equation} \begin{equation} \frac{\partial{L}}{\partial{c_{t-1}}} = \cdot\cdot\cdot+f\odot\frac{\partial{L}}{\partial{c_t}} \end{equation} $f$是一個接近$1$的值,因此,仍可以保證$\frac{\partial{L}}{\partial{c_t}}$穩定,既不至於梯度消失,也緩解了梯度爆炸問題(但仍會產生梯度爆炸)。 ## 7.7.1 LSTM的神經元 LSTM的神經元稱為元胞(Cell)。元胞在傳統循環神經網路的隱狀態$h_t$的基礎上,增加了一個專門用於記憶歷史資訊的元胞狀態$c_t$。$c_t$記錄了所有歷史資訊,可以從一個元胞流入下一個元胞。 ![S__121356371](https://hackmd.io/_uploads/rkRWR1aQp.jpg) 元胞中有一個當前記憶單元(也稱候選記憶單元),用於計算當前輸入對整體歷史資訊$c_t$的貢獻值$\hat{c_t}$(也稱啟動值)。當前記憶單元根據資料登錄$x_t$和隱狀態輸入$h_{t-1}$計算當前時刻的啟動值$\hat{c_t}$,公式如下 \begin{equation} \hat{c_t} = \tanh(x_tW_{xc}+h_{t-1}W_{hc}+b_c) \end{equation} 其中,$W_{xc}\in R^{d*h}$、$W_{hc}\in R^{h*h}$為權值參數,$b_c\in R^{1*h}$為偏置參數,$h$表示元組狀態$h_t$和隱狀態$c_t$的向量長度,$d$表示輸入樣本的特徵數目。 \begin{equation} \hat{c_t} = \tanh(x_tW_{xc} + h_{t-1}W_{hc}+b_c) \end{equation} 其中,$x_t \in R^{n*d}$為當前時刻的輸入,$h_{t-1}\in R^{n*h}$為前一個時刻的隱狀態,$n$為樣本個數<br/> 元胞中除當前記憶單元外還包含三種門(Gate): 輸入門、輸出門、遺忘門 \begin{equation} out = f*in \end{equation} ![S__121356372](https://hackmd.io/_uploads/S1LvEx6m6.jpg) 遺忘門公式如下 \begin{equation} F_t = \sigma(X_tW_{xf} + H_{t-1}W_{hf}+ b_f) \end{equation} 輸入門公式如下 \begin{equation} I_t = \sigma(X_tW_{xi} + H_{t-1}W_{hi} + b_i) \end{equation} \begin{equation} c_t = f_tc_{t-1} + i_t\hat{c_t} \end{equation} ![S__121356373](https://hackmd.io/_uploads/BJYKUea7a.jpg) 輸出門的公式如下 \begin{equation} O_t = \sigma(X_tW_{xo} + H_{t-1}W_{ho} +b_o) \end{equation} ![S__121356374](https://hackmd.io/_uploads/Hy2XDlT7T.jpg) 元胞的輸出 \begin{equation} H_t = O_t * \tanh(c_t) \end{equation} ![S__121356375](https://hackmd.io/_uploads/ByCpPlTQa.jpg) 計算當前時刻的輸出值 \begin{equation} Z_t = (H_tW_y + b_y) \end{equation} ## 7.7.2 LSTM的反向求導 \begin{equation} \frac{\partial{L_t}}{\partial{W_y}} = H^{T}_{t} \frac{\partial{L_t}}{\partial{Z_t}} \end{equation} \begin{equation} \frac{\partial{L_t}}{\partial{b_y}} = np.sum(\frac{\partial{L_t}}{\partial{Z_t}}, axis = 0, keepdims = True) \end{equation} \begin{equation} \frac{\partial{L_t}}{\partial{H_t}} = \frac{\partial{L_t}}{\partial{Z_t}}W^T_y \end{equation} \begin{equation} \frac{\partial{L}}{\partial{H_t}} = \frac{\partial{L_t}}{\partial{Z_t}}W^T_y + \frac{\partial{L^{t-}}}{\partial{H_t}} \end{equation} \begin{equation} \frac{\partial{L}}{\partial{c_t}} = O_t \odot \tanh^{\prime}(c_t) \frac{\partial{L_t}}{\partial{H_t}} + \frac{\partial{L^{t-}}}{\partial{c_t}} \end{equation} \begin{equation} \frac{\partial{L}}{\partial{O_t}} = \frac{\partial{L}}{\partial{H_t}}\odot \tanh(c_t) \end{equation} ## 7.7.3 LSTM的程式實現 ```python= import os # 設置環境變數 KMP_DUPLICATE_LIB_OK os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE' import torch import torch.nn as nn import torch.optim as optim import numpy as np import matplotlib.pyplot as plt # 生成虛擬時間序列數據 def generate_data(n_points): time = np.arange(0, n_points) data = np.sin(0.2 * time) + 0.5 * np.random.randn(n_points) return torch.tensor(data, dtype=torch.float32).view(-1, 1) # 定義 LSTM 模型 class SimpleLSTM(nn.Module): def __init__(self, input_size=1, hidden_layer_size=100, output_size=1): super(SimpleLSTM, self).__init__() self.hidden_layer_size = hidden_layer_size self.lstm = nn.LSTM(input_size, hidden_layer_size) self.linear = nn.Linear(hidden_layer_size, output_size) self.hidden_cell = (torch.zeros(1, 1, self.hidden_layer_size), torch.zeros(1, 1, self.hidden_layer_size)) def forward(self, input_seq): lstm_out, self.hidden_cell = self.lstm(input_seq.view(len(input_seq), 1, -1), self.hidden_cell) predictions = self.linear(lstm_out.view(len(input_seq), -1)) return predictions[-1] # 模型、損失函數和優化器初始化 model = SimpleLSTM() criterion = nn.MSELoss() optimizer = optim.Adam(model.parameters(), lr=0.001) # 生成數據 data = generate_data(100) # 訓練模型 epochs = 150 losses = [] for i in range(epochs): for seq, labels in zip(data, data): optimizer.zero_grad() model.hidden_cell = (torch.zeros(1, 1, model.hidden_layer_size), torch.zeros(1, 1, model.hidden_layer_size)) y_pred = model(seq) single_loss = criterion(y_pred, labels) single_loss.backward() optimizer.step() losses.append(single_loss.item()) # 繪製損失曲線 plt.plot(losses) plt.xlabel('Epochs') plt.ylabel('Loss') plt.title('Loss Curve') plt.show() # 測試模型 model.eval() with torch.no_grad(): test_data = generate_data(100) model.hidden = (torch.zeros(1, 1, model.hidden_layer_size), torch.zeros(1, 1, model.hidden_layer_size)) predictions = [] for seq in test_data: predictions.append(model(seq).item()) # 繪製原始數據和預測結果 plt.plot(data, label='Original Data') plt.plot(predictions, label='Predictions', linestyle='dashed') plt.xlabel('Time') plt.ylabel('Value') plt.title('Learning Curve') plt.legend() plt.show() ``` ![Figure_1](https://hackmd.io/_uploads/SJ4PtgTQa.png)