# AlphaTensor
###### tags: `paper notes` `deep learning`
[Discovering novel algorithms with AlphaTensor](https://www.deepmind.com/blog/discovering-novel-algorithms-with-alphatensor)
## Introduction
- 之前做出 AlphaGo, AlphaZero 等專案的 DeepMind 之最新力作 (2022-10-05),用 RL 來找到比現有方法更有效率的矩陣相乘做法
- AlphaTensor 為 AlphaZero 的延伸,AlphaZero 是個完全沒有訓練過的神經網路,它透過與自己對戰來進行強化學習,而 AlphaTensor 一開始也沒有被灌輸任何既有的矩陣乘法演算法,而是藉由自我學習來重新發現了歷史上的各種快速矩陣演算法,包括知名的 Strassen,最後它超越了人類,找到比現今SOTA更快的演算法。
## 這有什麼厲害的?
### 矩陣相乘 (GEMM)

```cpp
// GEMM, A:MxK * B:KxN
for (int m = 0; m < M; m++) {
for (int n = 0; n < N; n++) {
C[m][n] = 0;
for (int k = 0; k < K; k++) {
C[m][n] += A[m][k] * B[k][n];
}
}
}
```
### Strassen’s algo
- 在 1969 年由德國的數學家 Volker Strassen 所發明
- 原本需要乘 8 ($2^3$) 次 → 7 ($2^{2.8}$) 次,$O(n^{\log_27}) = O(n^{2.807})$,更準確地說,原本需要乘8次加4次,現在變成了乘7次加18次
- [https://ccjou.wordpress.com/2013/06/04/分治矩陣乘法──strassen-演算法/](https://ccjou.wordpress.com/2013/06/04/%E5%88%86%E6%B2%BB%E7%9F%A9%E9%99%A3%E4%B9%98%E6%B3%95%E2%94%80%E2%94%80strassen-%E6%BC%94%E7%AE%97%E6%B3%95/)
- 在這之後雖然[進一步的研究](http://www.cs.utoronto.ca/~yuvalf/Limitations.pdf)有把複雜度降到 $2^{2.38}$但進步已經微乎其微了


## 重新定義問題
> Although tensor decomposition is NP-hard, the inverse task of constructing the tensor from its rank-one factors is elementary.
>
將”發現更有效率的矩陣相乘”這個任務變成一個單人遊戲
### 搜索空間
他們使用了以下兩個對應關係來建立搜索空間
1. 一個尺度的矩陣乘法對應一個表徵張量
2. 表徵張量的一個低秩分解(Low Rank Filters),也就是用兩個 k*1 的 conv kernel 替換掉一個 k*k 的 conv kernel 的那個 [trick](https://arxiv.org/abs/1703.09746),對應一種包含 R 次乘法的矩陣乘法流程
接下來只要通過尋找能使用最小的乘法次數來達到目的的 U, V, W,就可以讓他自己去設計矩陣算法
---
先來解釋第一個對應關係
- 圖中以 2x2 矩陣相乘為例,表徵張量 $T_2$尺寸為4x4x4共64個元素
- c1 = a1b1 + a2b3 → tensor entries located at (a1, b1, c1) and (a2, b3, c1) are set to 1
- 深色取值為1,淺色取值為0,表徵張量會直接對應為 C = A.B

第二個對應關係
- T_n = sum(UxVxW)

rank=1 的表徵向量分解 (透過三個一維向量做 outer product 得到)
- b 就是 strassen,c 則是 stassen 在 tensor 的表示方法
- 下圖展示的是表徵張量 rank-7 (乘7次) 的分解,U, V, W 都是4 row R col 的矩陣 (4維向量),

以 U, V, W 的第二列 (第二個 rank=1 之 U,V,W) 為例,他表示的是 m_2 和 c1~c4 的計算
- 綠色矩陣U的第二列是 [0 0 1 1],第3, 4項為1表示了 m2 計算過程中的 a3 和 a4
- 紫色矩陣V的第二列是 [1 0 0 0],第1項為1表示了 m2 計算過程中的 b1
- 黃色矩陣W的第二項是 [0 0 1 -1],代表 m2 在輸出的 c1 c2 c3 c4 的係數為 0 0 1 -1
以這種方式做對應之後,我們就能確保一種特徵張量 T2 的分解方法 (一組 U,V,W矩陣) 就能夠確定了唯一一種 2x2 矩陣的算法流程
### 在搜索空間中搜索

- 初始狀態是 $S_0 = T_N$
- 在第 t 步的時候,agent 會根據狀態 $S_{t-1}$來做決策,而這個決策就是要給出 $U^{(t)}, V^{(t)}, W^{(t)}$,好得到一組 rank=1 的 tensor 來做外積,再以此來改變當前的狀態(對外積後的結果做減法)
- 停止條件是當 $S_t$ = 0 或是步驟數大於預先設好的 $R_{limit}$ 的時候停止
因為這很明顯是硬湊的,所以為了讓 agent 快點湊出0, reward 的設計就變成
1. 每湊出一個步驟,無論好壞都會得到 -1 reward
2. 若$R_{limit}$ 之後沒有湊出零特徵張量,會得到額外的 $-\gamma(SR_{limit})$ 的 reward,這個數字的大小會跟最後湊出來的數值的 rank 有關,當結果的 rank 越大這個值就會越大,懲罰就會越多
### 剪枝搜索空間
為了縮小搜索空間,agent 在做 action 的時候是離散化的,因此他們限制每個決策所得到的向量中的值都必須在 {-2, -1, 0, 1, 2} 中選擇,也就是說那矩陣中前面的那每一項係數都得是這五個數字其中一個
- 這其實是很強大的剪枝,把整個搜索空間縮小了非常多
在做完離散化之後,這就變得有點像是下圍棋,每一步都會有多個離散的分支選擇,這些分支就可以對應一個搜索樹,然後就可以用他們在做 AlphaZero 所使用的 RL + MCTS (蒙地卡羅搜尋) 來處理這些樹

## 網路

- 輸入是一串包含 current state 和 previous history of actions 以及 scalars (index) 的 tensors
- 輸出有兩個,一個是 value,另一個是 action space 的分布
- c = 512
### torso
- 這個網路是用來 mapping scalars 和 tensor 到對應的矩陣相乘操作
- 這是變形的 transformers,它會將 SxSxS 的輸入投影到三個 SxS grids 上做計算,也就變成類似 Q, K, V 的東西,但他們是分開來算的
- scalars 就是位置編碼

### policy head
- Use transformer decoder to build a autoregressive policy network

### Value head
- 就4層linear

## Result
- 最後,AlphaTensor 真的學出了 Strassen 的排法並繼續演化,在他們的 blog 上展示了一個 5x5 的矩陣相乘在 SOTA 上需要乘80次,而 AlphaTensor 只需要乘76次
- 不過 paper 有提到目前仍尚未找出 3x3 矩陣的最佳解
- modular arithmetic and standard arithmetic

## 這有什麼用?
最直覺的就是拿來加速 GPU 運算
- 最右邊的是 AlphaTensor 增加了一個減少硬體計算速度的 reward 所做出的結果 (用 timeit 函數算)

想像未來任何可以暴搜的領域都能這樣探索
## Reference
[https://www.nature.com/articles/s41586-022-05172-4](https://www.nature.com/articles/s41586-022-05172-4)