owned this note
                
                
                     
                     owned this note
                
                
                     
                    
                
                
                     
                    
                
                
                     
                    
                        
                            
                            Published
                        
                        
                            
                                
                                Linked with GitHub
                            
                            
                                
                                
                            
                        
                     
                
            
            
                
                    
                    
                
                
                    
                
                
                
                    
                        
                    
                    
                    
                
                
                
                    
                
            
            
         
        
        # Activation function: SwiGLU
緣由: 看了[CS336: Language Modeling from Scratch
](https://github.com/stanford-cs336/spring2025-lectures/tree/e94e33f433985e57036b25215dff2a4292e67a4f)介紹的資料(Lecture 3),內文介紹LLM相關的時候(下圖)

後面的activation function大多採用SwiGLU,很久沒看新東西,就看一下這個非線性函數在做什麼。
SwiGLU(Swish-Gated Linear Unit)這個激活函數(activation function)是結合Swish和GLU的混合激活函數。我們要分別看什麼是
1. Swish
2. GLU:Gated Linear Unit
然後再來看
4. SwiGLU
## 1. Swish
Swish是參考LSTM中sigmoid的來控制輸出大小的概念,利用sigmoid來控制輸出的比重,sigmoid是0-1的輸出,Swish公式如下
$$
f(x)=x\times \sigma(\beta x)
$$
$\sigma$: sigmoid function
這方法取代ReLU,gradient>0,避免梯度消失的問題,且具有平滑特性,有利於優化和泛化性。
```
import numpy as np
import matplotlib.pyplot as plt
import torch.nn.functional as F
import torch
def swish(x, beta=1):
    return x * F.sigmoid(beta * x)
x = np.linspace(-10,10,100)
x = torch.tensor(x)
for beta in [1,2,3]:
    y =  swish(x, beta=beta)
    plt.plot(x,y, label='beta:{}'.format(beta))
plt.grid('on')
plt.legend()
plt.show()
```

此函數德garident為,假設$\beta=1$
$$f(x)= x\times \sigma(x)$$
$$f'(x)=f(x)+\sigma(x)\times(1-f(x))$$
推導:
$$f'(x)= \frac{x}{\partial x} \times \sigma(x) + x \times \frac{\sigma(x)}{\partial x} = \sigma(x)+ x\times\sigma(x)\times(1-\sigma(x))
=x\times\sigma(x)+\sigma(x)(1-x\times\sigma(x))=f(x)+\sigma(x)\times(1-f(x))$$
正數部分: 和ReLU一樣,正數沒有上限,因為sigmoid函數輸出最大是1,所以當輸入值過大,則$f(x)$輸出就是$x$,跟ReLU一樣。
負數部分: 有下限,因為sigmoid函數輸出最小是0,所以當負數過大,則$f(x)$輸出就是$0$。(見上圖)
NOTE:當$\beta=1$則 Swich = SiLU(Sigmoid Linear Unit)。
## 2. GLU
GLU全名是Gated Linear Unit,在LSTM是用在gate control的原件,來控制輸出是否要保留,跟剛剛Swish長得很像,但不一樣。
$$
GLU(a,b)=a\otimes\sigma(b)
$$
$\otimes$: 是元素相乘(element-wise product)的符號,也就是哈達瑪積(Hadamard product)。
$\sigma$: sigmoid函數。
輸入變成$a$和$b$倆個。
## 3. SwiGLU(Swish-Gated Linear Unit)
結合Swish和GLU,
$$
SwiGLU(a,b) = GLU(Swish(a),b)=Swish(a)\otimes\sigma(b)
$$那a和b是什麼?
a和b是模型要學習兩個FC的輸出,也就是
$$a=xW_1$$$$b=xW_2$$$Swish$內的$\beta$通常取$1$,等於用SiLU。
$\otimes$: 是哈達瑪積(Hadamard product)。
$$SwiGLU(a,b) =Swish(a)\otimes b=Swish(xW_1) \otimes xW_2$$
加了兩個FC要學習,這有什麼好處,
1. 透過gate權重的的方式動態調整訊息傳遞,可以過濾不重要的訊息。
2. Swish的連續可微分,減緩了ReLU梯度消失的問題。
3. 非單調: 讓模型可以擷取到更複雜的非線性關係。
4. 相較於一般transformer內的FFN:採用ReLU和FC的運算,SwiGLU透過gate減少冗餘計算。
Note: FFN: feed forward layer,本質就是3層的MLP也就是加上兩個fully connection,$FFN(x)=f(xW_1)W_2$
```
import torch
import torch.nn as nn
import torch.nn.functional as F
class SwiGLU(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.w1 = nn.Linear(input_dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(input_dim, hidden_dim, bias=False)
        self.w3 = nn.Linear(hidden_dim, input_dim, bias=False)
    def forward(self, x):
        a = self.w1(x)
        b = self.w2(x)
        gate = F.silu(self.w1(a))  # F.silu等价于Swish(β=1)
        filtered = gate * self.w2(b)
        return self.w3(filtered)
```
以上是網路的寫法,在原論文[GLU Variants Improve Transformer](https://arxiv.org/pdf/2002.05202)中
SwiGLU主要是為了提升改良transformer中的FFN(feed-forward network)。
FFN定義如下,(這個我沒有寫上bias項目):
$$FFN_{ReLU}(x, W_1,W_2)= ReLU(xW_1)W2=max(xW_1,0)W_2$$$$FFN_{Swish}(x, W_1,W_2)= Swish(xW_1)W2$$
我這邊為了跟前面的a和b呼應,我有改論文的公式
論文寫法:$$GLU(x,W,V,b,c)=\sigma(xW+b)\otimes(xV+c)$$
我改寫
$GLU(x,W_1,W_2,b_1,b_2)=GLU(a=xW_1+b_1,b=xW_2+b_2)=\sigma(xW_1+b_1)\otimes(xW_2+b_2)$
$SwiGLU(x,W_1,W_2,b_1,b_2,\beta)=Swish_{\beta}(xW_1+b_1)\otimes(xW_2+b_2)$
在原論文中寫,在transformer內的FFN,我們可以看一下他的變形(假設不考慮bias)
$FFN_{GLU}(x,W_1,W_2,W_3)=(\sigma(xW_1)\otimes xW_2)W_3$
$FFN_{SwiGLU}(x,W_1,W_2,W_3)=(Swish_1(xW_1)\otimes xW_2)W_3$
所以論文其實在函數是沒有再加上sigmoid在$xW_2$的上面,但網路找到的寫法都有加上去(居多)。
所以實作程式碼應該是
```
class FeedForward(nn.Module):
    def __init__(self, dim: int, hidden_dim: int, multiple_of: int, dropout: float):
        super().__init__()
        hidden_dim = multiple_of * ((2 * hidden_dim // 3 + multiple_of - 1) // multiple_of)
        self.w1 = nn.Linear(dim, hidden_dim)
        self.w2 = nn.Linear(hidden_dim, dim)
        self.w3 = nn.Linear(dim, hidden_dim)
        self.dropout = nn.Dropout(dropout)
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.dropout(self.w2(F.silu(self.w1(x)) * self.w3(x)))
```