Try   HackMD

這是一個具有一層隱藏層的神經網路:

Image Not Showing Possible Reasons
  • The image was uploaded to a note which you don't have access to
  • The note which the image was originally uploaded to has been deleted
Learn More →

假設

  • 輸入層有 3 個節點,輸入 X 中有 3 筆數據,其標籤為 Y:
    X=[120232113],Y=[123]

W1=[101111]

  • 隱藏層有 2 個節點,隱藏層權重矩陣為
    W1
    ,線性組合
    Z=XW1

    經過激活函數
    σ
    後的值為
    K
    ,即
    K=σ(Z)
  • 另激活函數
    σ
    Relu
    函數,
    σ(x)=Relu(x)=max(x,0)
  • 輸出層有 1 個節點,其權重矩陣為
    W2
    ,線性輸出
    O=KW2

W2=[12]

  • 將輸出值與標籤去計算損失,令損失為
    J
    ,假設使用加總型式的最小平方損失

J=(12(OY)2)

此時,已知輸出層梯度:

Gout=JO=OY

  • 隱藏層梯度:

G2=JW2=JOOW2=((Gout)TK)T=KTGout

  • Relu
    函數的微分式:

σ(x)={0x<01x0

  • 先假定一個暫存的
    Gtemp

Gtemp=(GoutW2T)σ(Z)

其中的

代表一般的矩陣乘法、
代表阿達瑪乘積,為對應位置的矩陣元素乘積

  • 輸入層梯度:

G1=JW1=JOOKKZZW1=(((GoutW2T)σ(Z))TX)T=((Gtemp)TX)T=XTGtemp

  • 可用
    G1,G2
    梯度更新權重
    W1,W2
    的值,得到新權重
    W1new,W2new

假設我們採用隨機梯度下降法來進行更新,且學習率令為

0.1,則

{W1new=W10.1×G1W2new=W20.1×G2

問題

求矩陣

Z,K,O,Gout,G2,σ(Z),Gtemp,G1,W1new,W2new

求解過程皆省略公式推導過程,將直接使用最終結果代入計算


Z

由線性組合

Z=XW1

Z=XW1=[120232113][101111]=[323554]


K

由經過激活函數

σ 後的值
K=σ(Z)
,且激活函數
σ
Relu
函數,
σ(x)=Relu(x)=max(x,0)

K=σ(Z)=σ ([323554])=[023050]


O

輸出層有 1 個節點,其權重矩陣為

W2,由線性輸出
O=KW2

O=KW2=[023050][12]=[435]


Gout

由輸出層梯度

Gout=OY

Gout=OY=[435][123]=[518]


G2

由隱藏層梯度

G2=KTGout,其中
KT
為 矩陣
K
的轉置:

K=[023050],KT=[035200]

G2=KTGout=[035200][518]=[4310]


σ(Z)

其中

σ
Relu
函數的微分式:

σ(x)={0x<01x0
σ(Z)=σ([323554])=[011010]


Gtemp

由假定暫存的

Gtemp=(GoutW2T)σ(Z)
其中的
代表一般的矩陣乘法、
代表阿達瑪乘積,為對應位置的矩陣元素乘積,且
W2T
為矩陣
W2
的轉置:

W2=[12],W2T=[12]

Gtemp=(GoutW2T)σ(Z)=([518][12])[011010]=[51012816][011010]=[0101080]


G1

由輸入層梯度

G1=XTGtemp,其中
XT
為 矩陣
X
的轉置:

X=[120232113],XT=[121231023],

G1=XTGtemp=[121231023][0101080]=[6101120260]


W1new,W2new

  • 可用
    G1,G2
    梯度更新權重
    W1,W2
    的值,得到新權重
    W1new,W2new

    已知我們採用隨機梯度下降法來進行更新,且學習率為
    0.1
    ,則

{W1new=W10.1×G1W2new=W20.1×G2

W1new=W10.1×G1=[101111]0.1×[6101120260]=[101111][0.611.122.60]=[0.410.111.61]

W2new=W20.1×G2=[12]0.1×[4310]=[101111][4.31]=[3.31]


點擊回到導覽頁面