tags: ML數學分部

JS Divergence

在之前 "KL Divergence & CrossEntrophy 的真面目"中提過了 KL Divergence
也有講到 KL Divergence 有非對稱的性質,
也就是

DKL(P||Q)DKL(Q||P)

這時就出現了 JS Divergence 來解決此非對稱性

1. JS Divergence 推導

JS Divergence 的想法非常簡單,如果

(P||Q) 的差距和
(Q||P)
差距不同
那我們找出一個平均點,個別計算與
(P||Q)
(Q||P)
的距離相加不就對稱了

設 M 為

(P||Q)
(Q||P)
的平均分佈

M=12(P+Q)

並分別計算 M 與 P、Q 的差距 ( KL Divergence ) 做平均
就為 JS Divergence

JSD(P||Q)=12DKL(P||M)+12DKL(Q||M)

JSD(P||Q)=JSD(Q||P)

所以利用 JS Divergence 解決了非對稱的問題

JSD(P||Q)=JSD(Q||P)=12DKL(P||M)+12DKL(Q||M)



2. JS Divergence 的問題

JS Divergence 常被用在 GAN 計算生成的 Data 與實際 Data 間的差距
但卻有一大問題,這問題也讓 GAN 很難 Train 成功

=> 就是在 P、Q 兩分佈沒有重疊時,JS Divergence 恆為 log2
導致無法測量出兩 Distribution 間的差距

推導

JSD(P||Q)=12DKL(P||M)+12DKL(Q||M)

將 M 帶入 KL Divergence

JSD(P||Q)=12P(x)log(P(x)12(P+Q))+12Q(x)log(Q(x)12(P+Q))

將 log 中分母的

12 取出 ( 先變分母,在分離 log )

JSD(P||Q)=12P(x)log(P(x)P+Q)+12Q(x)log(Q(x)P+Q)+log2

如果在 P、Q 兩分佈沒有重疊時,假如以 Q 來觀察,那 P 則會為 0
並將 0 帶入 JS Divergence

JSD(P||Q)=0+0+log2=log2

在真實 data 分佈上其實是很難有重疊的
可以把 data 看成是"高維向量中的低維向量" ( ex: 三維空間中的平面 )
所以很容易就會不重疊

更不用說上百維度的空間了 ~



Reference