# Kattis A+B Problem & Fast Fourier Transform 筆記
###### tags: `tutorial` `解題筆記` `Kattis` `FFT`
:::warning
今天也要持續學習。
[原題目在這裡。](https://open.kattis.com/problems/aplusb)
:::
## Attempt #1 : Brute Force + Hash (TLE)
### 思路
暴力搜索 $i,j$ , 再用 $unordered\_map$ 查詢是否存在 $k$ 使 $a_k=a_i+a_j$ 且 $i,j,k$ 兩兩不相同。
```cpp=
int nums[n];
for(int i=0;i<n;i++)cin>>nums[i];
unordered_map<int,int> nms;
unsigned long long int ans=0;
// 暴力搜索 i
for(int i=2;i<n;i++){
//把第一個數字加入 map,從第二個開始暴力搜索 j
//把每一個搜過的 j 加入 map,避免重複
nms[nums[0]]=1;
for(int j=1;j<i;j++){
ans+=2*(nms[nums[i]+nums[j]]
+nms[nums[i]-nums[j]]
+nms[nums[j]-nums[i]]);
nms[nums[j]]++;
}
//清除 map,保證不會取到 k>j
nms.clear();
}
```
> 最差複雜度:$O(n^3)$
> 平均複雜度:$O(n^2)$
:::warning
第七個點 TLE 了。能不能再壓複雜度?
:::
---
## Attempt #2 : Brute Force + 線性搜 (RTE)
### 思路
先建好每一個 $a_i+a_j$ 出現次數,再搜尋 $a_k$
```cpp=
typedef ll long long int
//全部轉換成正數
const ll ADD=50001;
//記錄出現頻率
vector<ll>freq(200010,0),res(200010,0);
ll n,tmp,mn=INT_MAX,mx=INT_MIN,zeros=0,ans=0;
//保留原陣列
vector<ll> all;
cin>>n;
//特判 n<3 ,雖然好像沒有這組
if(n<3){cout<<0;return 0;}
//輸入
for(ll i=0;i<n;i++){
cin>>tmp;
//數零
if(!tmp)zeros++;
freq[tmp+ADD]++;
all.push_back(tmp+ADD);
//記錄值域,減少不必要計算
if(mn>tmp+ADD)mn=tmp+ADD;
if(mx<tmp+ADD)mx=tmp+ADD;
}
//計算 a[i]+a[j] 出現次數
for(ll i=mn;i<=mx;i++){
for(ll j=mn;j<=mx;j++){
res[i+j-ADD]+=freq[i]*freq[j];
}
}
//去除 a[i]+a[i] 的情形
for(auto i: all)res[i*2-ADD]--;
//把每一個 a[k] 出現次數加起來
for(auto i: all)ans+=res[i];
//去除 a[i]+a[j]=a[i] (a[j]=0)的情形
ans-=(ll)2*zeros*(n-1);
```
> 複雜度:$O(n^2)$ 。 $n$ 是值域。
:::warning
第四個點 RTE 了。作者不會 debug :sweat:
:::
沒關係我們直接繼續~~
---
## Attempt #3 : Fast Fourier Transform
> [推薦影片](https://youtu.be/spUNpyF58BY),不過跟題目沒有直接關係就是了。
### 思路
把出現次數的陣列看成多項式,計算次數就變成了多項式相乘。
例如範測二:
```
6
1 1 3 3 4 6
```
數字出現的頻率:
```
num: 1 2 3 4 5 6
freq: 2 0 2 1 0 1
```
轉換成多項式:
> $2x^1+0x^2+2x^3+1x^4+0x^5+1x^6$
> $=2x^1+2x^3+x^4+x^6$
把這個多項式平方(自乘):
> $(2x^1+2x^3+x^4+x^6)^2$
> $=4x^2+8x^4+2x^5+4x^6+8x^7+x^8+4x^9+2x^{10}+x^{12}$
接著,把 $i=j$ 的狀況去掉:
> $(4x^2+8x^4+2x^5+4x^6+8x^7+x^8+4x^9+2x^{10}+x^{12})\\-(2x^{1+1}+2x^{3+3}+1x^{4+4}+1x^{6+6})\\=8x^4+2x^5+2x^6+8x^7+4x^9+2x^{10}$
最後,找尋出現過的數字(項次):
> $8x^4+2x^6$
係數加起來,再減掉 $a_i+a_j=a_i\ (a_j=0)$ 的狀況數就是答案了。
> $8+2-0=10$
:::info
但這就跟第二種做法一樣啊。要怎麼壓複雜度?
:::
### 複雜度?
這整個做法的複雜度是 $O(n^2)$,包含 $O(n^2)$ 的多項式乘法以及 $O(n)$ 的線性搜尋。好消息是,我們有個方法可以把多項式乘法壓到$O(n\log_2n)$ ——
:::success
:tada:快速傅立葉變換 FFT!:tada:
[這裏](https://youtu.be/iTMn0Kt18tg)有很棒的概念教學呦。下面的概念教學大多是這部影片的筆記,在往下看前可以先試著去理解一下喔。:wink:
:::
> 注意:傅立葉變換和快速傅立葉變換不同,
> 前者處理的是時域---頻域變換,
> 後者做的是加速其中多項式乘法的演算法。
作者懶得講概念了(可能會補啦),請自己去看上面的教學喔。
有點長,但是看懂了絕對值得啦:satisfied:
### FFT如何加速多項式乘法?
#### 多項式的性質
我們知道多項式有三種表示方法:
* 係數表示法
> $a_0x^0 + a_1x^1 +a_2x^2 +......+ a_nx^n$
> $=\sum_{k=0}^{n} a_kx^k$
> $=a_0+x(a_1+x(a_2+x(......)))$
> $=<a_0,a_1,a_2,......,a_n>$
其中 $a_n$ 是係數。
這是最常見、也最直觀地表示方式。
一個 $n$ 次多項式,共有 $n+1$ 個係數(記得常數項喔)。
* 根(解)表示法(因式分解)
> $a(x-x_0)(x-x_1)(x-x_2)......(x-x_n)$
> $=a\prod_{k=1}^{n} (x-x_k)$
> $=<x_1,x_2,......,x_n>$
其中 $a$ 是領導係數,$x_n$ 是根。
在解題畫圖的時候常用的表示方式。
* 取樣點表示法(由外文直譯,有正式中文名稱的話請留言告訴作者)
給定平面上的 $n$ 個相異的點,只有唯一的一個 $n$ 次多項式會通過所有這些點。因此,一個 $n$ 次多項式也可以直接帶入 $x$ 取 $n$ 個相異的點,表示這個多項式。
> $(x_1,y_1),(x_2,y_2),......,(x_n,y_n)$
> 換成方程式要用拉格朗日插值法,好麻煩
> 這裡就不多寫了
現實中少用的表示法,但是有一些奇怪的性質。
:::info
那麼,他們在運算上各有什麼優點呢?
:::
多項式的運算主要分為三種:
* 求值(求一個 $x$ 對應的 $y$ )
* 相加(多項式兩兩相加)
* 相乘(多項式兩兩相乘)
而三種表示法分別對應到三種運算的理論複雜度是:
* 係數表示法
| 求值 | 加法 | 乘法 |
|:--------------------------:|:----------------------------:| -------- |
| $O(n)$ | $O(n)$ | $O(n^2)$ |
| 參照上面係數表示法第三行。 | 把同次項的係數兩兩相加即可。 | 一項一項乘啊⋯⋯慢。 |
* 根表示法
| 求值 | 加法 | 乘法 |
|:--------------------------:|:----------------------------:| :--------: |
| $O(n)$ | $\infty$ | $O(n)$ |
| 一個一個帶入乘起來。 | 算算看就知道。 | 把兩個接起來就好嚕。 |
* 取樣點表示法
| 求值 | 加法 | 乘法 |
|:--------------------------:|:----------------------------:| :--------: |
| $O(n^2)$ | $O(n)$ | $O(n)$ |
| 已知最好演算法的複雜度是 $O(n^2)$。 | 只要兩邊取的 $x$ 值一樣,就可以直接把對應的 $y$ 值兩兩相加。 | 跟加法很相似,只要兩邊取的 $x$ 值一樣,直接把對應的 $y$ 值兩兩相乘就是答案了。 |
> 註:求值的部分作者不會證明(大概就是拉格朗日插下去),影片裡也沒有說。請自行參透。
我們的目標是要 $O(n \log n)$,然而沒有一個表示方式可以完全達到。
:::info
如果在這三者之間互相轉換呢?
:::
如果我們使用係數表示法來進行加法、求值,用取樣點表示法來計算多項式乘法——那麼只要找到一種 $O(n \log n)$ 的轉換方式,就可以達成目標了。
> 註:雖然根表示法的乘法時間也是 $O(n \log n)$,但是沒辦法好好轉換(五次以上多項式沒有公式解),故不考慮。
所以,題目解法就是:
1. 把出現次數的多項式從係數表示轉換成取樣點表示
2. 各個取樣點相乘
3. 再轉換回係數表示
這樣,問題就只剩下怎麼把多項式由係數表示在 $O(n \log n)$ 時間內轉換成取樣點表示。
#### 快速傅立葉變換?
正常的「係數——取樣點」轉換需要 $O(n^2)$:有 $n$ 個點,每個點要計算 $n$ 的時間。
但是,只要巧妙的選擇要取樣的點座標,讓要計算的 $x$ 重複,就能用 Divide & Conquer 壓成 $O(n\log_2n)$。耶,問題解決。
[再貼一個教學。](http://www.csie.ntnu.edu.tw/~u91029/Wave.html)
作者真的懶得解釋,請自行參透。
### 實作
首先是位元反轉。
為了輔助稍後的 FFT,先將資料按照位元反轉後的順序排序,DP 時就可以直接疊上去。藉由 DP 實作可以達到 $O(n)$。
```cpp=
#define MAX 600000 //題目值域
ll revBit[MAX];
ll bound,logBound; //記錄陣列長度
void bitRev(){
//紀錄 High Bit 位置
ll highBitPtr=-1;
// 0 的 revBit 就是 0,不要做,會爛掉
for(ll i=1;i<bound;i++){
//(i&(i-1)==0 代表 i 是 2 的冪次
if((i&(i-1))==0)highBitPtr++;
//取比 i 小一個 High Bit 的數的 revBit
//再把自己的 High Bit 翻過去後接上
revBit[i]=revBit[i^(1<<highBitPtr)]
|(1<<(logBound- highBitPtr-1));
}
}
```
再來是製作取樣點。這裏用ω代表。
```cpp=
//為了讓程式看起來更簡潔,增加可讀性
//C++ 也可以用希臘字母喔
//但是上傳 Kattis 前請取代掉,Kattis 的編譯器不吃喔
typedef complex<double> cd;
const double π=M_PI;
cd ω[MAX];
void calc(){
ω[0]=1;
for(int i=1;i<=bound;i++){
ω[i]=cd(cos(2*π/bound*i),sin(2*π/bound*i));
}
}
```
:::warning
注意,這裏有浮點數精度的問題喔。
作者嘗試用迭代的方式去做,也就是用 ω[i-1] 計算 ω[i],但是這樣會造成浮點數問題,害作者吃了快十次 WA 才找到 bug ⋯⋯
另外,最後轉換回去要四捨五入。
以上兩步驟的迴圈可以合併,但為了可讀性先分開放著。
:::
接下來就是重頭戲了—— **FFT 本體**。
為了方便理解先放一張圖。
圖源:[演算法筆記——Wave](http://www.csie.ntnu.edu.tw/~u91029/Wave.html)

```cpp=
//用指標傳入,直接修改原陣列
void fft(cd *arr){
//按照 revBit 排序,方便直接 DP
//也有的版本是最後做排序
for(int i=0;i<bound;i++){
//避免重複交換
if(i<revBit[i])swap(arr[i],arr[revBit[i]]);
}
// DP 暫存用
cd a,b;
//上圖橫軸,每一次做一層
for(int len=1;len<bound;len*=2){
//上圖中的結,每一次做一組
for(int g=0;g<bound;g+=2*len){
//上圖中的單向,每一次做一個
for(int i=0;i<len;i++){
a=arr[g+i],
b=arr[g+i+(len>>1)]*ω[bound/len/2*i];
arr[g+i]=a+b;
// a-b = a+(-b)
arr[g+i+len]=a-b;
}
}
}
//雖然像是 O(n^3),但實際上最外圈是 O(log n),
//內兩圈合起來才是 O(n)
}
```
有了 FFT 後,也得有逆轉換才行。
好消息是,逆轉換的流程基本上相同:satisfied:
逆轉換不同的地方:
* $ω$ 要換成 $\overline{ω}$(共軛複數)
* 轉換完之後要把每一項除以 $n$ (總長)
:::warning
為什麼?請看[這裏](https://youtu.be/iTMn0Kt18tg)的教學呦(跟上面同一個。)
$1:05:00$ 左右的地方有解釋呦。
:::
在原本的函數上加上一個 flag,作為正負轉換標記:
```cpp=
//inverse: 判斷是否為逆轉換
void fft(cd *arr,bool inverse){
for(int i=0;i<bound;i++){
if(i<revBit[i])swap(arr[i],arr[revBit[i]]);
}
cd a,b;
for(int len=1;len<bound;len<<=1){
for(int g=0;g<bound;g+=2*len){
for(int i=0;i<len;i++){
a=arr[g+i];
//判斷要不要取共軛
if(inverse)b=arr[g+i+len]*conj(W[bound/len/2*i]);
else b=arr[g+i+len]*W[bound/len/2*i];
arr[g+i]=a+b;
arr[g+i+len]=a-b;
}
}
}
//把每一項除回來
if(inverse){
for(int i=0;i<bound;i++){
arr[i]/=bound;
}
}
}
```
> 註:另外有一個方法不用逆轉換,只要做正常的轉換後,把第二項到最後一項倒序即可。用 STL 的 reverse 實作。
> 為什麽第一項不用翻?難得作者知道。第一項要翻應該要跟第 $n+1$ 項換,但沒有這項,而且 $(1,0)$ 轉了一整圈之後還是在 $(1,0)$。不了解的話請先看一下[這個](https://youtu.be/spUNpyF58BY)喔。
最後,還記得一開始的步驟嗎?
1. 把出現次數的多項式從係數表示轉換成取樣點表示
2. 各個取樣點相乘
3. 再轉換回係數表示
把這個步驟做成一個函式:
```cpp=
//暫存用的陣列
cd tmpa[MAX*2],tmpb[MAX*2];
//傳入要相乘的兩個陣列 a、b,把答案存到 c
void mult(vector<ll> &a,vector<ll> &b,vector<ll> &c){
//為了判斷要做幾層(上面的圖橫軸),先判斷兩陣列大小
logBound=0;
while((1<<logBound)<a.size()||(1<<logBound)<b.size())logBound++;
bound=1<<(++logBound);
//預先計算位元反轉和取樣點的資料
//少打這兩行也害作者 debug 很久
bitRev();
calc();
//把傳入的資料換成負數格式,並補上不足長度的零
//為什麼補零?作者不會證明,只知道這樣方便直接計算
//想了解請去看教學喔
for(int i=0;i<a.size();i++)tmpa[i]=cd(a[i],0);
for(int i=a.size();i<bound;i++)tmpa[i]=cd(0,0);
for(int i=0;i<b.size();i++)tmpb[i]=cd(b[i],0);
for(int i=b.size();i<bound;i++)tmpb[i]=cd(0,0);
//轉換~~~
fft(tmpa,0);
fft(tmpb,0);
//把每一個取樣點相乘
for(int i=0;i<bound;i++)tmpa[i]*=tmpb[i];
//逆轉換~~~
fft(tmpa,1);
c.resize(bound);
for(int i=0;i<bound;i++){
//避免浮點數精度造成的問題
c[i]=(ll)(tmpa[i].real()>0?tmpa[i].real()+0.5:tmpa[i].real()-0.5);
}
//把剛剛補的零去掉
while(c.size()&&c.back()==0)c.pop_back();
}
```
到目前為止,完整的 $O(n\log_2n)$ 多項式相乘已經完成了!:tada:
接下來就是做題目要的部分囉。
步驟:
1. 輸入所有數字,紀錄每個數字出現的頻率、零的數量。
2. 轉換成多項式,把它平方。
3. 去除 $a_i+a_i=a_k$ 的情形。
4. 挑出符合的 $a_j$ 出現的次數。
5. 去除 $a_i+a_j=a_i$ 的狀況數(數零的用處)。
```cpp=
//把所有數字換成非負整數,方便計算
#define ADD 50000
#define pb push_back
int main(){
//養成輸入優化的好習慣,尤其是這種大量輸入的題
ios_base::sync_with_stdio(0);
cin.tie(0);
cout.tie(0);
ll n;cin>>n;
//紀錄每個數字出現頻率
vector<ll> freq(100001,0);
//保留所有入的數字
vector<ll> nums;
//儲存相乘後的頻率(a[i]+a[j]各個數字的出現頻率)
vector<ll> ansf;
//輸入暫存和零的數量
ll tmp,zeros=0;
for(int i=0;i<n;i++){
cin>>tmp;
if(!tmp)zeros++;
nums.pb(tmp+ADD);
freq[tmp+ADD]++;
}
//多項式相乘,本題為自乘
//還可以針對本題優化啦,但是先寫萬用版本
mult(freq,freq,ansf);
//去除 a[i]+a[i]=a[k] 的情形
for(auto i:nums)ansf[i*2]--;
//加總各個a[k]出現次數
ll ans=0;
for(auto i:nums)ans+=ansf[i+ADD];
//去除 a[i]+a[j]=a[i] (a[j]=0) 的狀況數
ans-=(ll)(2*zeros*(n-1));
cout<<ans;
//養成回傳的好習慣,有時候不做會造成 RTE 喔。
return 0;
}
```
### 完成囉:tada:
但是還可以針對單一題目優化呦。
例如函數合併、常數壓準確之類的,有空再補吧。
用這份扣的跑出來的結果:

> 本機編譯花很久的時間是正常的呦,畢竟佔了很大的記憶體嘛。
:::success
有任何問題麻煩留言告訴作者喔~~
:::