changed 4 years ago
Published Linked with GitHub

matrix chain

連鎖舉陣列相乘
定義:給定n個矩陣,求最小合併成本,對於兩個相鄰矩陣,都是可以乘的合法矩陣

存法

對於每個矩陣的行跟列,因為前一矩陣的行等於下一個矩陣的列,所以只要存一次就好了

//test 
1 2
2 5
5 7
7 3
3 9

//存法
1 2 5 7 3 9

重要:除了頭尾的元素,每個都可以代表前一個矩陣的列和當前矩陣的行

演算法

DP & Recursion

對於每個區段(l,r)都去找分割點k,分成左右,左右再各自遞迴,最後再把兩邊遞迴的結果合併,對於每次合併,都只會選擇(l+1~r-1)的元素來消除,所以l和r是會一直留著的,最後的合併就會是a[l]*a[k]矩陣和a[k]*a[r]矩陣 (a代表矩陣的行或列)

假設:1 2 5 7 3 9選擇5來分割

左邊:1 2 5
右邊:5 7 3 9

分別代表

左邊:a[1][2] a[2][5]
右邊:a[5][7] a[7][3] a[3][9]
其中a[i][j]代表矩陣的行(i)跟列(j)

兩邊都各自遞迴直到只剩一個矩陣

最後的狀態:

左邊:a[1][5]
//a[1][2]*a[2][5]=a[1][5]
右邊:a[5][9]  
//假設第一個矩陣先跟第二個乘(a[5][7]*a[7][3]=a[5][3])
//再跟第三個乘(a[5][3]*a[3][9]=a[5][9]),或者調換順序結果都是一樣的

合併:

 a[1][5]*a[5][9]
=a[l][k]*a[k][r]  l<k<r,要找到如何分割才會有最小成本

狀態轉移為:

 dp[l][r]=min(cur(l,k)+cur(k,r)+a[l]*a[k]*a[r];l<k<r)
 //左邊的合併成本 + 右邊的合併成本 + 最後這兩個矩陣的合併成本
int cur(int l,int r){ if(r-l+1<=2) return 0; //小於等於2個編號(a[i][k])代表小於兩個矩陣(a[i][k],a[k][j]),就不用合併直接回傳 if(dp[l][r]) return dp[l][r]; //有算過直接回傳 dp[l][r]=1e9; for(int k=l+1;k<r;k++){ dp[l][r]=min(dp[l][r],cur(l,k)+cur(k,r)+a[l]*a[k]*a[r]); //遞迴兩邊的成本加上最後兩個矩陣合併的成本,取min } return dp[l][r]; }

印出優先順序

要印出優先順序,首先要知道切哪邊成本最小,所以對於每次求出區間(l,r)的分割點k,存在一個陣列(cut[l][r])裡,該陣列存放(l,r)區間最小成本的分割點k,所以對於每個區間最後要輸出的都是

if(l==r) return  r;  //只剩一個矩陣,不用合併
else return      "(" + 左邊(l,k)的結果 + "*" + 右邊(k,r)的結果 + ")" 

因為根據狀態轉移,我們"分割"然後"合併",所以"分割"左右就讓遞迴去跑,最後會傳合併後的結果,找完分割點的"合併"就是' * ',終止條件是l==r,只剩一個矩陣就不用合併,大於兩個才要

要注意的是:

  1. index=每個矩陣的編號=每個矩陣的行

  2. 遞迴求優先順序時,是以index來遞迴,因為要印出編號,所以以編號(每個矩陣的行)來遞迴會比較方便,要注意印出結果的遞迴是呼叫print(0,n-1),而求最小合併成本的遞迴是rec(0,n)

  3. 為了方便print的呼叫,cut要以index為參數,也就是該矩陣的行。因為(l,r)區間的r是右界元素的列,所以在紀錄cut的時候要記錄在cut[l][r-1]才會是右界元素的index(右界元素的行)

  4. 分割時,左邊的k代表左邊右界矩陣的列,右邊的k代表右邊左界矩陣的行,所以呼叫print左邊時,要變成cut[l][r]-1才會是左邊右界的index

string print(int l,int r){ if(l==r) return "A" + to_string(l+1); else return "(" + print(l,cut[l][r]-1) + " x " + print(cut[l][r],r) + ")"; //遞迴左右兩邊的分割 } int cur(int l,int r){ if(r-l+1<3) return 0; if(dp[l][r]) return dp[l][r]; dp[l][r]=1e9; for(int k=l+1;k<r;k++){ int result=cur(l,k)+cur(k,r)+a[l]*a[k]*a[r]; if(result<=dp[l][r]){ dp[l][r]=result; cut[l][r-1]=k; //紀錄分割點:分割成編號l~k-1和k~r-1 } } return dp[l][r]; }

c112: 00348 - Optimal Array Multiplication Sequence

#include <bits/stdc++.h> using namespace std; #define MAXN 505 int dp[MAXN][MAXN]; int cut[MAXN][MAXN]; int a[MAXN]; string print(int l,int r){ if(l==r) return "A" + to_string(l+1); else return "(" + print(l,cut[l][r]-1) + " x " + print(cut[l][r],r) + ")"; } int cur(int l,int r){ if(r-l+1<3) return 0; if(dp[l][r]) return dp[l][r]; dp[l][r]=1e9; for(int i=l+1;i<r;i++){ int result=cur(l,i)+cur(i,r)+a[l]*a[i]*a[r]; if(result<=dp[l][r]){ dp[l][r]=result; cut[l][r-1]=i; } } return dp[l][r]; } int main(){ cin.sync_with_stdio(0),cin.tie(0); int n,Case=1; while(cin >> n && n){ cout << "Case " << Case++ << ": "; for(int i=0;i<n;i++) cin >> a[i] >> a[i+1]; memset(dp,0,sizeof(dp)); memset(cut,0,sizeof(cut)); cur(0,n); cout << print(0,n-1) << '\n'; } }
Select a repo