# 演算法課程題解 - 動態規劃\: 最佳矩陣連乘 # UVa 348 ## 題目 http://domen111.github.io/UVa-Easy-Viewer/?348 給兩個矩陣 $A_{n \times m}, B_{m \times p}$ ,根據矩陣相乘的定義可以得到矩陣 $C$ ,其中: $$ C_{ij} = \sum_{k=1}^{m} A_{ik}B_{kj} $$ 根據這個定義可以知道兩矩陣相乘所需的乘法數量有 $n \times m \times p$ 次 如果有一連串的陣列要相乘,只要維持**相鄰的矩陣才可相乘**的原則,無論乘法先後順序為何答案都相同 這導致不同的乘法順序所需的乘法次數不相同 給你一些矩陣,請問這些矩陣應該怎樣相乘,才能使總共的乘法次數最少 ## 想法 By Koios 無論如何,最後一定都只會剩下兩個矩陣相乘 所以問題就是我們應該要怎麼把一連串的矩陣分成兩半,讓兩邊的矩陣相乘次數相加,加上這兩個矩陣相乘次數要是最少的 發現到切割後又會變成同類型的子問題,於是可以得到 DP 式 定義 $dp[i][j]$ 表示第 $i$ 個矩陣到第 $j$ 個矩陣相乘的最少次數 定義 $arr[i]$ 表示第 $i$ 個矩陣,其中 $arr[i][0]$ 表示該矩陣的**列數(col)**, $arr[i][1]$ 表示該矩陣的**行數(row)** 則有轉移式 $$ dp[i][j] = min(dp[i][k] + dp[k+1][j] + (arr[i][0] \times arr[k][1] \times arr[j][1])) \quad i \leq k \leq j-1 $$ 如果詢問的區間 $[i, j]$ 指包含了兩個矩陣,那麼答案就是兩個矩陣相乘次數 最後還需要回朔解 每次 DP 都可以知道最後我們選擇了哪個 $k$ ,那麼就可以依據這個遞迴 **先輸出左側的部分,再輸出右側的部分** 比較需要注意的有兩個部分 - 括號應該甚麼時候放 - 乘號應該甚麼時候放 首先,括號在區間包含的矩陣有兩個(包含兩個)以上就需要,所以在遞迴左側以及右側之前要放 `(` ,遞迴完要放 `)` 至於乘號的部分只會出現在我們切割的部分 實際做法可以參考程式碼 ### 程式碼 ```cpp= //By Koios1143 #include<iostream> #include<climits> using namespace std; const long long MaxN = LLONG_MAX; int res[15][15], Case=1; long long N, arr[15][2], dp[15][15]; long long sol(int n, int m){ // 不合法的情形可以直接回傳 0,表示不存在乘法次數 if(n >= m || n < 0 || m < 0 || n >= N || m >= N) return 0; // 只有兩個元素就直接給兩個矩陣相乘次數 if(m - n == 1){ res[n][m] = -1; return arr[n][0] * arr[n][1] * arr[m][1]; } if(dp[n][m] != MaxN) return dp[n][m]; else{ // 枚舉 k for(int k=n ; k<=m-1 ; k++){ // 題目有保證答案存在,即使不判斷是否可以相乘也沒關係 if(arr[k][1] == arr[k+1][0]){ int cnt = sol(n, k) + sol(k+1, m) + arr[n][0] * arr[k][1] * arr[m][1]; if(cnt < dp[n][m]){ dp[n][m] = cnt; res[n][m] = k; } } } return dp[n][m]; } } // 輸出區間 [n, m] 的答案 void print_ans(int n, int m){ // 只要還有兩個以上的矩陣就需要括號 if(n != m) cout<<"("; if(res[n][m] == -1){ if(n == m){ cout<<"A"<<n+1; } else{ cout<<"A"<<n+1<<" x "<<"A"<<m+1; cout<<")"; } return; } print_ans(n, res[n][m]); // 只要還能切割,中間就需要乘號 cout<<" x "; print_ans(res[n][m]+1, m); if(n != m) cout<<")"; return; } int main(){ while(cin>>N && N){ // init for(int i=0 ; i<N ; i++){ for(int j=0 ; j<N ; j++){ dp[i][j] = MaxN; res[i][j] = -1; } } for(int i=0 ; i<N ; i++){ cin>>arr[i][0]>>arr[i][1]; } sol(0, N-1); cout<<"Case "<<Case++<<": "; print_ans(0, N-1); cout<<"\n"; } return 0; } ``` ### 時間複雜度分析 預處理時間複雜度為 $O(N^2)$ 輸入時間複雜度為 $O(N)$ DP 每個狀態轉移時間複雜度約為 $O(N)$ 大概總共有 $N^2$ 種狀態,DP 總時間複雜度為 $O(N^3)$ 每筆測資時間複雜度為 $O(N + N^2 + N^3)$ ###### tags: `SCIST 演算法 題解`