# Matrix Chain 本文主要介紹演算法 DP 問題中其中一個常出現的題目 Matrix Chain 實際上如何計算,以及稍微解釋建 table 時背後的意義。 如果對於文中的截圖有興趣,可參考最下方的 Reference。 --- 另外,這篇文章是為了我考研究所的妹妹寫的 :D 這次寫完這篇筆記,回顧一年前考試時的內容已經比當時更熟悉。發現這件事讓我很開心,也很希望能像以前一樣不斷更新筆記,只不過研究所實在是太忙了QQ 或許要等到去工作以後才有辦法吧 XD 總之,現在有辦法更新的東西大概只有 github 了,我的頁面主要放一些學校的作業和筆記,有興趣的朋友可以來[我的 github](https://github.com/pipi-bear) 串串門子~ --- # Pseudocode ![code](https://hackmd.io/_uploads/H1zffwDqp.png) 在下方的 `直觀說明` 中並不會太著墨於 pseudocode 本身,用意是希望在理解整個意義以後,在最後一節 `pseudocode 說明` 再詳細的做 pseudocode 的解釋。 # 直觀說明 ## input 首先我們先看 `MATRIX-CHAIN-ORDER` 的 input $p$ 的定義: ![input](https://hackmd.io/_uploads/HJF3eTh8xx.png) 底線第二行中間告訴我們,我們的 input $p$ 是一個長度為 $n+1$ 的 sequence($n$ 是要乘的矩陣個數): \begin{equation} p = \langle p_0, \ p_1, \cdots ,\ p_n \rangle \end{equation} 那這裡的 $p_0$ 到 $p_n$ 的值又是什麼呢? 前半段畫底線的地方他告訴我們,對每個矩陣 $A_i \ (i:1 \sim n)$,$A_i$ 的 dimension 定義了 $p_{i-1}, \ p_i$,舉例來說,假如我們第一個要乘的矩陣 $A_1$ 是一個 $3 \times 4$ 的矩陣,那前面的 $3$ 就是 $p_{i-1} = p_{1-1} = p_0$,同理 $p_i = p_1 = 4$,寫清楚一點: \begin{equation} A_1: 3 \times 4 \quad \Rightarrow \quad \begin{cases} p_0 = 3 \\ p_1 = 4 \end{cases} \end{equation} 知道 $p_i$ 怎麼求以後,我們就可以構建出我們的 input sequence $p$ 啦! ## 每一格的意義 回過頭來看我們的 pseudocode,因為 `MATRIX-CHAIN-ORDER` 是一種採用 ==bottom-up== 方式的 DP,所以我們解題的方式就是要先建立 table,然後由下往上填。在後面我們會看一個例子,但如果我們現在先不管值,先大概看看表格長什麼樣子,我們可以看到下圖中,假如我們有 $n$ 個矩陣要乘(在這個例子裡是 $6$ 個),它會畫成這樣的一個形狀: ![table_with_line](https://hackmd.io/_uploads/SJfV1anUxg.png) 在 pseudocode 的 `3-4` 行我們做的是先把最下面那層全部填零,在 pseudocode 裡它說: \begin{equation} m[i, i] = 0 \qquad i = 1, \cdots, n \end{equation} 它的意思是什麼呢?為什麼要這麼做?每一格 $m[i, j]$ 又是什麼? 如果我們先看這裡 $i = j$ 的情形,每格 <font color = "blue">$m[i, i]$</font> 其實在說的是,<ins>第 $i$ 個矩陣乘到第 $i$ 個矩陣所需要的最少矩陣乘法數</ins>,可是現在「第 $i$ 個矩陣乘到第 $i$ 個矩陣」不就代表不用乘嗎?所以這也就是為什麼我們把這層全部填零的理由。 同樣地: :::info 拓展到 $i$ 不必然等於 $j$ 的情形,每格 <font color = "blue">$m[i, j]$</font> 代表的是:<ins>第 $i$ 個矩陣乘到第 $j$ 個矩陣所需要的最少矩陣乘法數</ins> ::: 因為在做矩陣乘法時,我們不能隨意調換相乘的順序,我們能決定的只有好比說四個矩陣 $A_1 \sim A_4$,我們要先乘 $A_1, \ A_2$ 和 $A_3, \ A_4$ 然後再把兩個結果相乘,還是先把 $A_2$ 乘到 $A_4$ ($A_2 \times A_3 \times A_4$),再和 $A_1$ 相乘,這兩種方式(以及其他種方式)所需要的矩陣乘法數,其實並不相同。 而放進我們的符號 $m[i, j]$ 裡,對這個例子來說,當我們在說 <font color = "green">$m[1, 4]$</font> 時,代表的意義其實是: :::success 如果我要乘 $A_1 \sim A_4$,在所有可能的相乘順序下,最少需要多少個矩陣乘法。 ::: ### 計算每一格的方式 知道每一格是什麼意思以後,我們的就可以再回過頭來看 pseudocode 的 `5-13` 行的意義,也就是去填除了最下層以外的格子。 如果清楚 $m[i, j]$ 是什麼意思,第二層就也很好填了,為什麼呢? 我們用剛才的例子說明,例子的完整定義如下圖: # 例子 ![ex](https://hackmd.io/_uploads/BkWuNPvqT.png) 在這個 `Figure 15.5` 裡它說我們有 $6$ 個矩陣要相乘,每個的矩陣是幾乘幾寫在中間的表格,根據前面 `說明/input` 小節的定義,我們可以標上 $p_0$ 到 $p_6$: ![image](https://hackmd.io/_uploads/By4x9an8ll.png) 現在我們可以填第二層了,我們需要填的格子從左到右分別為: \begin{equation} m[1, 2], \ m[2, 3], \ m[3, 4], \ m[4, 5], \ m[5, 6] \end{equation} 也就是下圖中橘色框起來的地方: ![image](https://hackmd.io/_uploads/BJeOrY6nLll.png) 其中我們以最左邊那格 <font color = "purple">$m[1, 2]$</font> 為例,它代表什麼意思呢? 回想前面我們說 $m[i, j]$ 代表的是「第 $i$ 個矩陣乘到第 $j$ 個矩陣所需要的最少矩陣乘法數」,那 $m[1, 2]$ 就是「從第一個矩陣乘到第二個矩陣所需要的最少乘法數」了,可是 $A_1$ 乘到 $A_2$ 沒有什麼誰先乘誰後乘的問題,總共也就兩個,我們只能直接乘,所以所需要的乘法數就是: \begin{equation} \begin{cases} A_1: 30 \times 35 \\ A_2: 35 \times 15 \end{cases} \quad \Rightarrow \quad 30 \times 35 \times 15 = 15750 \end{equation} 其他格同理,這樣第二層也完成了。 接下來更上面幾層的做法都相同,如果直接看 pseudocode 可能會覺得變數太多很複雜,但我們用最直觀的方法去想就好了,我們看其中一格 <font color = "purple">$m[2, 5]$</font> 的例子: ![image](https://hackmd.io/_uploads/rkTThp28lg.png) > 劃線的箭頭可以先不用管。 首先我們再度先思考 $m[2, 5]$ 的意義,再回想前面我們說 $m[i, j]$ 代表的是「第 $i$ 個矩陣乘到第 $j$ 個矩陣所需要的最少矩陣乘法數」,那 $m[2, 5]$ 就是「從第二個矩陣乘到第五個矩陣所需要的最少乘法數」了: \begin{equation} A_2 \times A_3 \times A_4 \times A_5 \end{equation} 四個矩陣相乘有幾種可能性呢? 首先我們先認知到,不管怎麼乘,最後一步一定都是兩個矩陣相乘,對吧! 例如 $A_2$ 先跟 $A_3$ 乘、$A_4$ 先跟 $A_5$ 乘,最後我們才將 $A_2 \times A_3$ 和 $A_4 \times A_5$ 得到的兩個矩陣乘起來,也就是: \begin{equation} [(A_2 \times A_3) \cdot (A_4 \times A_5)] \end{equation} 或是另一種方式: \begin{equation} [(A_2 \times A_3 \times A_4) \cdot A_5] \end{equation} 在第二種方式裡,假如我們可以知道 $A_2 \times A_3 \times A_4$ 所需要的最少乘法數,那我們把它看成一個 $A$,這樣變成: \begin{equation} A \cdot A_5 \end{equation} 不就是兩個矩陣相乘,沒有誰先乘誰後乘的問題,所以就是前面講過的兩個矩陣的情形: \begin{equation} \begin{cases} A: 35 \times 10 \\ A_5: 10 \times 20 \end{cases} \quad \Rightarrow \quad 35 \times 10 \times 15 \end{equation} 可是這並不是 $[(A_2 \times A_3 \times A_4) \cdot A_5]$ 所需要的乘法數喔!仔細想想,我們只考慮了一整個 $A$ 和 $A_5$ 直接乘時最後他們兩個乘所需要的乘法數,可是 $A$ 裡面呢?$A_2, \ A_3, \ A_4$ 是誰先乘誰後乘?裡面又多花了多少個矩陣乘法?這個問題,其實就回到我們的 table 來看就知道了: ![image](https://hackmd.io/_uploads/rkTThp28lg.png) 在我們算紫色框框的 $m[2, 5]$ 時,因為我們是 bottom-up 做上來的,下面幾層早就填好了,$A_2, \ A_3, \ A_4$ 以哪種方式相乘會得到最少的乘法數,最少所需乘法數的值是多少,透過定義,我們轉換成符號就是 $m[2, 4]$ 對吧! 在上面的表裡,也就是紫色框框左下方綠色底線的 $4375$。 另一邊呢,雖然 $A_5$ 自己一個不用乘,但還記得我們定義了 $m[i, i]$ 對吧,因此我們可以用 $m[5, 5] = 0$ 作為 $A_5$ 乘到 $A_5$ 所需的最少乘法數,於是我們就會得到如果最後是 $[(A_2 \times A_3 \times A_4) \cdot A_5]$ 這樣相乘的情況下,我們需要最少多少個矩陣乘法: \begin{equation} 4375 + 0 + 30 \times 10 \times 15 \end{equation} 但是這樣還沒完,畢竟像前面我們說過,最後相乘的兩個矩陣,也有可能是 $[(A_2 \times A_3) \cdot (A_4 \times A_5)]$,或是其他可能,所以同樣的過程我們算一個 $m[2, 5]$ 需要對每種可能都做一遍,然後取最小值。總共有幾種可能呢?可以用離散排列組合中學到的方式去想,$A_2$ 乘到 $A_5$ 總共有幾種可能就是在 $A_2$ 和 $A_5$ 間可以切幾刀,如下圖: <img src=https://hackmd.io/_uploads/r1Xw7C3Lge.png width="500" height="auto"> > 不同顏色的底線對應到前一張圖中的底線。 下圖就有 $m[2, 5]$ 這格怎麼算的詳細過程,可以放大看: ![detailed_ex](https://hackmd.io/_uploads/Hyvcch3Ixl.png) 當我們把每種可能都算出來以後就會發現最少的乘法數即是 $7125$。 其實,回歸到 DP 的意義,就只是我們把一個大問題拆分成更小的子問題(subproblems),像是這個例子裡,我們直接看不會知道 $A_2$ 乘到 $A_5$ 會需要幾個矩陣乘法,可是當我們先分成最後是哪些矩陣乘上哪些矩陣,就能把問題化成比原本小的問題,本質上還是暴力算出所有可能。 # pseudocode 說明 如果對 pseudocode 還有興趣,以下是說明: ![code](https://hackmd.io/_uploads/H1zffwDqp.png) 前四行前面講過了,我們可以接著看第五行,它說 $l$ 從 $2$ 到 $n$,而 $l$ 是 chain length,什麼意思? 其實對照到 table 就很直覺,當 $l=2$ 時指的就是第二層(最下面那排 $0$ 視為第一層的話),為什麼? 還記得第二層的內容是: \begin{equation} m[1, 2], \ m[2, 3], \ m[3, 4], \ m[4, 5], \ m[5, 6] \end{equation} 也就是: \begin{equation} \underbrace{A_1 \times A_2}_{m[1, 2]}, \ \underbrace{A_2 \times A_3}_{m[2, 3]}, \ \underbrace{A_3 \times A_4}_{m[3, 4]}, \ \underbrace{A_4 \times A_5}_{m[4, 5]}, \ \underbrace{A_5 \times A_6}_{m[5, 6]} \end{equation} 第二層的每一格都是兩個矩陣相乘,其實這就是 $l=2$,同理,$l=3$ 其實就是在處理第三層,其餘也類似。 所以,在第五行這個 for loop 裡,每個 iteration 都是做完一層。 接下來看第六行,從前面第五行我們知道 $l$ 會從 $2$ loop 到 $n$,代表說在這個 for loop 裡,當 $l=2$ 時 $i$ 經過 $n-1$ 次 iteration,一直到 $l=n$ 時 $i=1$ 一次 iteration: <img src="https://hackmd.io/_uploads/r1uo_JaLxl.png" width="250" height="auto"> 按照前面 $l$ 的意義去想就更清楚了,例如 $l=n$ 代表的是我們在處理第 $n$ 層,也就是 table 的最上面那格,還記得我們 $m[i, j]$ 的意思嗎? ——「第 $i$ 個矩陣乘到第 $j$ 個矩陣所需要的最少矩陣乘法數」 $m[i, j]$ 的 $i$ 就是這行在 loop 的 $i$,所以這也就次為什麼 $l=n$ 時我們的 $i$ 只能為 $1$,因為在 $l=n$ 時我們要處理的是 $n$ 個矩陣相乘,這裡的 $n$ 是我們全部的矩陣總數,所以當然唯有在從第一個乘到最後一個矩陣時,我們才會乘到 $n$ 個。 再好比說,前面我們在考慮 $m[2, 5]$ 時,我們算那一層總共有: \begin{equation} \underbrace{A_1 \times A_2}_{m[1, 2]}, \ \underbrace{A_2 \times A_3}_{m[2, 3]}, \ \underbrace{A_3 \times A_4}_{m[3, 4]}, \ \underbrace{A_4 \times A_5}_{m[4, 5]}, \ \underbrace{A_5 \times A_6}_{m[5, 6]} \end{equation} 這麼多格,這一層是 $l=2$,因為每一格都是長度為 $2$ 的矩陣數相乘,在上面的圖裡,$l=2$ 我們實際算出來會從 $i=1$ loop 到 $i = n-1$,也就對應了 $m[1, 2]$ 到 $m[5, 6]$ 這些 $m[i, j]$ 的 $i = 1 \sim 5$。 :::info 結合起來講,第五行我們 loop 的是每一層,第六行我們 loop 的是一層中的每一格。 ::: 這樣的話,在每一格,也就是第六行的 for loop 裡,我們又做什麼事呢? 第 $7$ 到 $13$ 行被包在這個 for loop 裡面,也就是說我們在算每一格時都會經過這幾行,一開始我們在第 $7$ 行計算每個 $i$ 對應的 $j$,這也不難理解,因為前面決定好幾個矩陣相乘($l$)以後,我們自然就知道當第一個矩陣是第幾個($i$)時,我們要往後乘幾個,一直乘到第 $j$ 個,才會有 $l$ 個。 接下來第 $8$ 行就是將這一格 initialize 成無限大,也就是說在開始計算前,我們把除了第一層的 $0$ 以外的每一格全部都設成 $\infty$,然後每次我們如果有算到更小的值再去更新(也就是第 $11$ 行)。 而在第 $9$ 行的 for loop 中,<font color = "blue">$k$</font> 即代表<ins>在第 $k$ 個矩陣後切一刀</ins>,以前面的例子來說就是下圖中的藍色分隔線,當 $k=2$ 時,代表在 $A_2$ 後切一刀,也就是最後我們將矩陣 $A_2$ 以及作為一個矩陣的 $A_3 \times A_4 \times A_5$ 相乘: <img src=https://hackmd.io/_uploads/r1Xw7C3Lge.png width="500" height="auto"> 除此之外,我們也能發現所有的可能性,就是在 $A_2, \ A_3, \ A_4$ 後的三個位置做出分隔,而 $A_4$ 後 $k=4$ 的位置,也就是這一格最後一個矩陣 $A_j$ 前的那個位置 $j-1$。 最後,計算一格中切在不同的位置(不同的 $k$ 值)會需要多少矩陣乘法數的公式為第 $10$ 行的: :::success \begin{equation} q = \underbrace{m[i, k]}_{第 k 個矩陣以前相乘的最少乘法數} + \underbrace{m[k+1, j]}_{第 k 個矩陣之後相乘的最少乘法數} + \underbrace{p_{i-1}p_kp_j}_{最後兩個矩陣直接相乘所需的乘法數} \end{equation} ::: 每次我們會根據不同的切的位置($k$)算出不同的矩陣乘法數($q$),於是我們拿它去跟現在 table 中這格的值($m[i, j]$)相比,如果新的這種切法只需要更少的矩陣乘法,那麼我們就把這個值($q$)改寫入 table 中,並且,如果題目有要求寫出相乘的順序,就再多畫一個一樣大的 table $s$,把 $s$ 同個位置的那格記錄為 $k$。 # Reference - Cormen, T. H., Leiserson, C. E., & Rivest, R. L. (1990). *Introduction to algorithms*. MIT Press, pp.375-377.