# Kattis A+B Problem & Fast Fourier Transform 筆記 ###### tags: `tutorial` `解題筆記` `Kattis` `FFT` :::warning 今天也要持續學習。 [原題目在這裡。](https://open.kattis.com/problems/aplusb) ::: ## Attempt #1 : Brute Force + Hash (TLE) ### 思路 暴力搜索 $i,j$ , 再用 $unordered\_map$ 查詢是否存在 $k$ 使 $a_k=a_i+a_j$ 且 $i,j,k$ 兩兩不相同。 ```cpp= int nums[n]; for(int i=0;i<n;i++)cin>>nums[i]; unordered_map<int,int> nms; unsigned long long int ans=0; // 暴力搜索 i for(int i=2;i<n;i++){ //把第一個數字加入 map,從第二個開始暴力搜索 j //把每一個搜過的 j 加入 map,避免重複 nms[nums[0]]=1; for(int j=1;j<i;j++){ ans+=2*(nms[nums[i]+nums[j]] +nms[nums[i]-nums[j]] +nms[nums[j]-nums[i]]); nms[nums[j]]++; } //清除 map,保證不會取到 k>j nms.clear(); } ``` > 最差複雜度:$O(n^3)$ > 平均複雜度:$O(n^2)$ :::warning 第七個點 TLE 了。能不能再壓複雜度? ::: --- ## Attempt #2 : Brute Force + 線性搜 (RTE) ### 思路 先建好每一個 $a_i+a_j$ 出現次數,再搜尋 $a_k$ ```cpp= typedef ll long long int //全部轉換成正數 const ll ADD=50001; //記錄出現頻率 vector<ll>freq(200010,0),res(200010,0); ll n,tmp,mn=INT_MAX,mx=INT_MIN,zeros=0,ans=0; //保留原陣列 vector<ll> all; cin>>n; //特判 n<3 ,雖然好像沒有這組 if(n<3){cout<<0;return 0;} //輸入 for(ll i=0;i<n;i++){ cin>>tmp; //數零 if(!tmp)zeros++; freq[tmp+ADD]++; all.push_back(tmp+ADD); //記錄值域,減少不必要計算 if(mn>tmp+ADD)mn=tmp+ADD; if(mx<tmp+ADD)mx=tmp+ADD; } //計算 a[i]+a[j] 出現次數 for(ll i=mn;i<=mx;i++){ for(ll j=mn;j<=mx;j++){ res[i+j-ADD]+=freq[i]*freq[j]; } } //去除 a[i]+a[i] 的情形 for(auto i: all)res[i*2-ADD]--; //把每一個 a[k] 出現次數加起來 for(auto i: all)ans+=res[i]; //去除 a[i]+a[j]=a[i] (a[j]=0)的情形 ans-=(ll)2*zeros*(n-1); ``` > 複雜度:$O(n^2)$ 。 $n$ 是值域。 :::warning 第四個點 RTE 了。作者不會 debug :sweat: ::: 沒關係我們直接繼續~~ --- ## Attempt #3 : Fast Fourier Transform > [推薦影片](https://youtu.be/spUNpyF58BY),不過跟題目沒有直接關係就是了。 ### 思路 把出現次數的陣列看成多項式,計算次數就變成了多項式相乘。 例如範測二: ``` 6 1 1 3 3 4 6 ``` 數字出現的頻率: ``` num: 1 2 3 4 5 6 freq: 2 0 2 1 0 1 ``` 轉換成多項式: > $2x^1+0x^2+2x^3+1x^4+0x^5+1x^6$ > $=2x^1+2x^3+x^4+x^6$ 把這個多項式平方(自乘): > $(2x^1+2x^3+x^4+x^6)^2$ > $=4x^2+8x^4+2x^5+4x^6+8x^7+x^8+4x^9+2x^{10}+x^{12}$ 接著,把 $i=j$ 的狀況去掉: > $(4x^2+8x^4+2x^5+4x^6+8x^7+x^8+4x^9+2x^{10}+x^{12})\\-(2x^{1+1}+2x^{3+3}+1x^{4+4}+1x^{6+6})\\=8x^4+2x^5+2x^6+8x^7+4x^9+2x^{10}$ 最後,找尋出現過的數字(項次): > $8x^4+2x^6$ 係數加起來,再減掉 $a_i+a_j=a_i\ (a_j=0)$ 的狀況數就是答案了。 > $8+2-0=10$ :::info 但這就跟第二種做法一樣啊。要怎麼壓複雜度? ::: ### 複雜度? 這整個做法的複雜度是 $O(n^2)$,包含 $O(n^2)$ 的多項式乘法以及 $O(n)$ 的線性搜尋。好消息是,我們有個方法可以把多項式乘法壓到$O(n\log_2n)$ —— :::success :tada:快速傅立葉變換 FFT!:tada: [這裏](https://youtu.be/iTMn0Kt18tg)有很棒的概念教學呦。下面的概念教學大多是這部影片的筆記,在往下看前可以先試著去理解一下喔。:wink: ::: > 注意:傅立葉變換和快速傅立葉變換不同, > 前者處理的是時域---頻域變換, > 後者做的是加速其中多項式乘法的演算法。 作者懶得講概念了(可能會補啦),請自己去看上面的教學喔。 有點長,但是看懂了絕對值得啦:satisfied: ### FFT如何加速多項式乘法? #### 多項式的性質 我們知道多項式有三種表示方法: * 係數表示法 > $a_0x^0 + a_1x^1 +a_2x^2 +......+ a_nx^n$ > $=\sum_{k=0}^{n} a_kx^k$ > $=a_0+x(a_1+x(a_2+x(......)))$ > $=<a_0,a_1,a_2,......,a_n>$ 其中 $a_n$ 是係數。 這是最常見、也最直觀地表示方式。 一個 $n$ 次多項式,共有 $n+1$ 個係數(記得常數項喔)。 * 根(解)表示法(因式分解) > $a(x-x_0)(x-x_1)(x-x_2)......(x-x_n)$ > $=a\prod_{k=1}^{n} (x-x_k)$ > $=<x_1,x_2,......,x_n>$ 其中 $a$ 是領導係數,$x_n$ 是根。 在解題畫圖的時候常用的表示方式。 * 取樣點表示法(由外文直譯,有正式中文名稱的話請留言告訴作者) 給定平面上的 $n$ 個相異的點,只有唯一的一個 $n$ 次多項式會通過所有這些點。因此,一個 $n$ 次多項式也可以直接帶入 $x$ 取 $n$ 個相異的點,表示這個多項式。 > $(x_1,y_1),(x_2,y_2),......,(x_n,y_n)$ > 換成方程式要用拉格朗日插值法,好麻煩 > 這裡就不多寫了 現實中少用的表示法,但是有一些奇怪的性質。 :::info 那麼,他們在運算上各有什麼優點呢? ::: 多項式的運算主要分為三種: * 求值(求一個 $x$ 對應的 $y$ ) * 相加(多項式兩兩相加) * 相乘(多項式兩兩相乘) 而三種表示法分別對應到三種運算的理論複雜度是: * 係數表示法 | 求值 | 加法 | 乘法 | |:--------------------------:|:----------------------------:| -------- | | $O(n)$ | $O(n)$ | $O(n^2)$ | | 參照上面係數表示法第三行。 | 把同次項的係數兩兩相加即可。 | 一項一項乘啊⋯⋯慢。 | * 根表示法 | 求值 | 加法 | 乘法 | |:--------------------------:|:----------------------------:| :--------: | | $O(n)$ | $\infty$ | $O(n)$ | | 一個一個帶入乘起來。 | 算算看就知道。 | 把兩個接起來就好嚕。 | * 取樣點表示法 | 求值 | 加法 | 乘法 | |:--------------------------:|:----------------------------:| :--------: | | $O(n^2)$ | $O(n)$ | $O(n)$ | | 已知最好演算法的複雜度是 $O(n^2)$。 | 只要兩邊取的 $x$ 值一樣,就可以直接把對應的 $y$ 值兩兩相加。 | 跟加法很相似,只要兩邊取的 $x$ 值一樣,直接把對應的 $y$ 值兩兩相乘就是答案了。 | > 註:求值的部分作者不會證明(大概就是拉格朗日插下去),影片裡也沒有說。請自行參透。 我們的目標是要 $O(n \log n)$,然而沒有一個表示方式可以完全達到。 :::info 如果在這三者之間互相轉換呢? ::: 如果我們使用係數表示法來進行加法、求值,用取樣點表示法來計算多項式乘法——那麼只要找到一種 $O(n \log n)$ 的轉換方式,就可以達成目標了。 > 註:雖然根表示法的乘法時間也是 $O(n \log n)$,但是沒辦法好好轉換(五次以上多項式沒有公式解),故不考慮。 所以,題目解法就是: 1. 把出現次數的多項式從係數表示轉換成取樣點表示 2. 各個取樣點相乘 3. 再轉換回係數表示 這樣,問題就只剩下怎麼把多項式由係數表示在 $O(n \log n)$ 時間內轉換成取樣點表示。 #### 快速傅立葉變換? 正常的「係數——取樣點」轉換需要 $O(n^2)$:有 $n$ 個點,每個點要計算 $n$ 的時間。 但是,只要巧妙的選擇要取樣的點座標,讓要計算的 $x$ 重複,就能用 Divide & Conquer 壓成 $O(n\log_2n)$。耶,問題解決。 [再貼一個教學。](http://www.csie.ntnu.edu.tw/~u91029/Wave.html) 作者真的懶得解釋,請自行參透。 ### 實作 首先是位元反轉。 為了輔助稍後的 FFT,先將資料按照位元反轉後的順序排序,DP 時就可以直接疊上去。藉由 DP 實作可以達到 $O(n)$。 ```cpp= #define MAX 600000 //題目值域 ll revBit[MAX]; ll bound,logBound; //記錄陣列長度 void bitRev(){ //紀錄 High Bit 位置 ll highBitPtr=-1; // 0 的 revBit 就是 0,不要做,會爛掉 for(ll i=1;i<bound;i++){ //(i&(i-1)==0 代表 i 是 2 的冪次 if((i&(i-1))==0)highBitPtr++; //取比 i 小一個 High Bit 的數的 revBit //再把自己的 High Bit 翻過去後接上 revBit[i]=revBit[i^(1<<highBitPtr)] |(1<<(logBound- highBitPtr-1)); } } ``` 再來是製作取樣點。這裏用ω代表。 ```cpp= //為了讓程式看起來更簡潔,增加可讀性 //C++ 也可以用希臘字母喔 //但是上傳 Kattis 前請取代掉,Kattis 的編譯器不吃喔 typedef complex<double> cd; const double π=M_PI; cd ω[MAX]; void calc(){ ω[0]=1; for(int i=1;i<=bound;i++){ ω[i]=cd(cos(2*π/bound*i),sin(2*π/bound*i)); } } ``` :::warning 注意,這裏有浮點數精度的問題喔。 作者嘗試用迭代的方式去做,也就是用 ω[i-1] 計算 ω[i],但是這樣會造成浮點數問題,害作者吃了快十次 WA 才找到 bug ⋯⋯ 另外,最後轉換回去要四捨五入。 以上兩步驟的迴圈可以合併,但為了可讀性先分開放著。 ::: 接下來就是重頭戲了—— **FFT 本體**。 為了方便理解先放一張圖。 圖源:[演算法筆記——Wave](http://www.csie.ntnu.edu.tw/~u91029/Wave.html) ![](https://i.imgur.com/aWStmOQ.png) ```cpp= //用指標傳入,直接修改原陣列 void fft(cd *arr){ //按照 revBit 排序,方便直接 DP //也有的版本是最後做排序 for(int i=0;i<bound;i++){ //避免重複交換 if(i<revBit[i])swap(arr[i],arr[revBit[i]]); } // DP 暫存用 cd a,b; //上圖橫軸,每一次做一層 for(int len=1;len<bound;len*=2){ //上圖中的結,每一次做一組 for(int g=0;g<bound;g+=2*len){ //上圖中的單向,每一次做一個 for(int i=0;i<len;i++){ a=arr[g+i], b=arr[g+i+(len>>1)]*ω[bound/len/2*i]; arr[g+i]=a+b; // a-b = a+(-b) arr[g+i+len]=a-b; } } } //雖然像是 O(n^3),但實際上最外圈是 O(log n), //內兩圈合起來才是 O(n) } ``` 有了 FFT 後,也得有逆轉換才行。 好消息是,逆轉換的流程基本上相同:satisfied: 逆轉換不同的地方: * $ω$ 要換成 $\overline{ω}$(共軛複數) * 轉換完之後要把每一項除以 $n$ (總長) :::warning 為什麼?請看[這裏](https://youtu.be/iTMn0Kt18tg)的教學呦(跟上面同一個。) $1:05:00$ 左右的地方有解釋呦。 ::: 在原本的函數上加上一個 flag,作為正負轉換標記: ```cpp= //inverse: 判斷是否為逆轉換 void fft(cd *arr,bool inverse){ for(int i=0;i<bound;i++){ if(i<revBit[i])swap(arr[i],arr[revBit[i]]); } cd a,b; for(int len=1;len<bound;len<<=1){ for(int g=0;g<bound;g+=2*len){ for(int i=0;i<len;i++){ a=arr[g+i]; //判斷要不要取共軛 if(inverse)b=arr[g+i+len]*conj(W[bound/len/2*i]); else b=arr[g+i+len]*W[bound/len/2*i]; arr[g+i]=a+b; arr[g+i+len]=a-b; } } } //把每一項除回來 if(inverse){ for(int i=0;i<bound;i++){ arr[i]/=bound; } } } ``` > 註:另外有一個方法不用逆轉換,只要做正常的轉換後,把第二項到最後一項倒序即可。用 STL 的 reverse 實作。 > 為什麽第一項不用翻?難得作者知道。第一項要翻應該要跟第 $n+1$ 項換,但沒有這項,而且 $(1,0)$ 轉了一整圈之後還是在 $(1,0)$。不了解的話請先看一下[這個](https://youtu.be/spUNpyF58BY)喔。 最後,還記得一開始的步驟嗎? 1. 把出現次數的多項式從係數表示轉換成取樣點表示 2. 各個取樣點相乘 3. 再轉換回係數表示 把這個步驟做成一個函式: ```cpp= //暫存用的陣列 cd tmpa[MAX*2],tmpb[MAX*2]; //傳入要相乘的兩個陣列 a、b,把答案存到 c void mult(vector<ll> &a,vector<ll> &b,vector<ll> &c){ //為了判斷要做幾層(上面的圖橫軸),先判斷兩陣列大小 logBound=0; while((1<<logBound)<a.size()||(1<<logBound)<b.size())logBound++; bound=1<<(++logBound); //預先計算位元反轉和取樣點的資料 //少打這兩行也害作者 debug 很久 bitRev(); calc(); //把傳入的資料換成負數格式,並補上不足長度的零 //為什麼補零?作者不會證明,只知道這樣方便直接計算 //想了解請去看教學喔 for(int i=0;i<a.size();i++)tmpa[i]=cd(a[i],0); for(int i=a.size();i<bound;i++)tmpa[i]=cd(0,0); for(int i=0;i<b.size();i++)tmpb[i]=cd(b[i],0); for(int i=b.size();i<bound;i++)tmpb[i]=cd(0,0); //轉換~~~ fft(tmpa,0); fft(tmpb,0); //把每一個取樣點相乘 for(int i=0;i<bound;i++)tmpa[i]*=tmpb[i]; //逆轉換~~~ fft(tmpa,1); c.resize(bound); for(int i=0;i<bound;i++){ //避免浮點數精度造成的問題 c[i]=(ll)(tmpa[i].real()>0?tmpa[i].real()+0.5:tmpa[i].real()-0.5); } //把剛剛補的零去掉 while(c.size()&&c.back()==0)c.pop_back(); } ``` 到目前為止,完整的 $O(n\log_2n)$ 多項式相乘已經完成了!:tada: 接下來就是做題目要的部分囉。 步驟: 1. 輸入所有數字,紀錄每個數字出現的頻率、零的數量。 2. 轉換成多項式,把它平方。 3. 去除 $a_i+a_i=a_k$ 的情形。 4. 挑出符合的 $a_j$ 出現的次數。 5. 去除 $a_i+a_j=a_i$ 的狀況數(數零的用處)。 ```cpp= //把所有數字換成非負整數,方便計算 #define ADD 50000 #define pb push_back int main(){ //養成輸入優化的好習慣,尤其是這種大量輸入的題 ios_base::sync_with_stdio(0); cin.tie(0); cout.tie(0); ll n;cin>>n; //紀錄每個數字出現頻率 vector<ll> freq(100001,0); //保留所有入的數字 vector<ll> nums; //儲存相乘後的頻率(a[i]+a[j]各個數字的出現頻率) vector<ll> ansf; //輸入暫存和零的數量 ll tmp,zeros=0; for(int i=0;i<n;i++){ cin>>tmp; if(!tmp)zeros++; nums.pb(tmp+ADD); freq[tmp+ADD]++; } //多項式相乘,本題為自乘 //還可以針對本題優化啦,但是先寫萬用版本 mult(freq,freq,ansf); //去除 a[i]+a[i]=a[k] 的情形 for(auto i:nums)ansf[i*2]--; //加總各個a[k]出現次數 ll ans=0; for(auto i:nums)ans+=ansf[i+ADD]; //去除 a[i]+a[j]=a[i] (a[j]=0) 的狀況數 ans-=(ll)(2*zeros*(n-1)); cout<<ans; //養成回傳的好習慣,有時候不做會造成 RTE 喔。 return 0; } ``` ### 完成囉:tada: 但是還可以針對單一題目優化呦。 例如函數合併、常數壓準確之類的,有空再補吧。 用這份扣的跑出來的結果: ![](https://i.imgur.com/j2IZq1W.png) > 本機編譯花很久的時間是正常的呦,畢竟佔了很大的記憶體嘛。 :::success 有任何問題麻煩留言告訴作者喔~~ :::