# 演算法課程題解 - 動態規劃\: 最佳矩陣連乘
# UVa 348
## 題目
http://domen111.github.io/UVa-Easy-Viewer/?348
給兩個矩陣 $A_{n \times m}, B_{m \times p}$ ,根據矩陣相乘的定義可以得到矩陣 $C$ ,其中:
$$
C_{ij} = \sum_{k=1}^{m} A_{ik}B_{kj}
$$
根據這個定義可以知道兩矩陣相乘所需的乘法數量有 $n \times m \times p$ 次
如果有一連串的陣列要相乘,只要維持**相鄰的矩陣才可相乘**的原則,無論乘法先後順序為何答案都相同
這導致不同的乘法順序所需的乘法次數不相同
給你一些矩陣,請問這些矩陣應該怎樣相乘,才能使總共的乘法次數最少
## 想法 By Koios
無論如何,最後一定都只會剩下兩個矩陣相乘
所以問題就是我們應該要怎麼把一連串的矩陣分成兩半,讓兩邊的矩陣相乘次數相加,加上這兩個矩陣相乘次數要是最少的
發現到切割後又會變成同類型的子問題,於是可以得到 DP 式
定義 $dp[i][j]$ 表示第 $i$ 個矩陣到第 $j$ 個矩陣相乘的最少次數
定義 $arr[i]$ 表示第 $i$ 個矩陣,其中 $arr[i][0]$ 表示該矩陣的**列數(col)**, $arr[i][1]$ 表示該矩陣的**行數(row)**
則有轉移式
$$
dp[i][j] = min(dp[i][k] + dp[k+1][j] + (arr[i][0] \times arr[k][1] \times arr[j][1])) \quad i \leq k \leq j-1
$$
如果詢問的區間 $[i, j]$ 指包含了兩個矩陣,那麼答案就是兩個矩陣相乘次數
最後還需要回朔解
每次 DP 都可以知道最後我們選擇了哪個 $k$ ,那麼就可以依據這個遞迴
**先輸出左側的部分,再輸出右側的部分**
比較需要注意的有兩個部分
- 括號應該甚麼時候放
- 乘號應該甚麼時候放
首先,括號在區間包含的矩陣有兩個(包含兩個)以上就需要,所以在遞迴左側以及右側之前要放 `(` ,遞迴完要放 `)`
至於乘號的部分只會出現在我們切割的部分
實際做法可以參考程式碼
### 程式碼
```cpp=
//By Koios1143
#include<iostream>
#include<climits>
using namespace std;
const long long MaxN = LLONG_MAX;
int res[15][15], Case=1;
long long N, arr[15][2], dp[15][15];
long long sol(int n, int m){
// 不合法的情形可以直接回傳 0,表示不存在乘法次數
if(n >= m || n < 0 || m < 0 || n >= N || m >= N) return 0;
// 只有兩個元素就直接給兩個矩陣相乘次數
if(m - n == 1){
res[n][m] = -1;
return arr[n][0] * arr[n][1] * arr[m][1];
}
if(dp[n][m] != MaxN) return dp[n][m];
else{
// 枚舉 k
for(int k=n ; k<=m-1 ; k++){
// 題目有保證答案存在,即使不判斷是否可以相乘也沒關係
if(arr[k][1] == arr[k+1][0]){
int cnt = sol(n, k) + sol(k+1, m) + arr[n][0] * arr[k][1] * arr[m][1];
if(cnt < dp[n][m]){
dp[n][m] = cnt;
res[n][m] = k;
}
}
}
return dp[n][m];
}
}
// 輸出區間 [n, m] 的答案
void print_ans(int n, int m){
// 只要還有兩個以上的矩陣就需要括號
if(n != m) cout<<"(";
if(res[n][m] == -1){
if(n == m){
cout<<"A"<<n+1;
}
else{
cout<<"A"<<n+1<<" x "<<"A"<<m+1;
cout<<")";
}
return;
}
print_ans(n, res[n][m]);
// 只要還能切割,中間就需要乘號
cout<<" x ";
print_ans(res[n][m]+1, m);
if(n != m) cout<<")";
return;
}
int main(){
while(cin>>N && N){
// init
for(int i=0 ; i<N ; i++){
for(int j=0 ; j<N ; j++){
dp[i][j] = MaxN;
res[i][j] = -1;
}
}
for(int i=0 ; i<N ; i++){
cin>>arr[i][0]>>arr[i][1];
}
sol(0, N-1);
cout<<"Case "<<Case++<<": ";
print_ans(0, N-1);
cout<<"\n";
}
return 0;
}
```
### 時間複雜度分析
預處理時間複雜度為 $O(N^2)$
輸入時間複雜度為 $O(N)$
DP 每個狀態轉移時間複雜度約為 $O(N)$
大概總共有 $N^2$ 種狀態,DP 總時間複雜度為 $O(N^3)$
每筆測資時間複雜度為 $O(N + N^2 + N^3)$
###### tags: `SCIST 演算法 題解`