# Python-LeetCode 581 第16招 Dynamic Programming III: Tree and Subset
## DP on tree
因為Tree沒有環路,所以很多Tree的題目天生就是適合分治(divide and conquer)與DP,子問題通常就是subtree的解,基本上就是由下而上解決subtree子問題後合併成更大subtree的解,直到root為止。
實作上可以走bottom-up的方式,也可以top-down遞迴,最常見的方式就是走DFS的遞迴方式。
LeetCode上Tree的問題輸入分成兩種:給定的tree node structure,都是binary rooted tree;或者如同graph給edge list。
[(題目連結) 124. Binary Tree Maximum Path Sum (Hard)](https://leetcode.com/problems/binary-tree-maximum-path-sum/)
題目是給一個rooted binary tree,點上有權值(可能為負),要找樹上最大一條非空路徑,其點的權值總和最大。路徑必須為simple path,也就是點不可以重複。
當邊上有權值(長度)時,圖(樹)上一條最長的路徑稱為直徑(diameter)。在樹上找直徑為經典教科書問題,邊可能有不同的權值,或著unweighted,也就是所有邊都一樣長。本題的權值在點上,但解法跟找直徑是差不多的。
一個更基本的問題是對rooted tree的每一個點,找出往下走到leaf的最長路徑。這是一個非常簡單的DP,因為一個點往下的路徑必定經過他的其中一個孩子,所以每個孩子往下的最長路徑都求出來之後,從中求最大就好了。以本題的權值在點上的狀況來說,若令$d(v)$是$v$點往下路徑的最大權值,則
$d(v) = \max\{d(v.left),d(v.right)\}+v.val$
如果權值都是非負的或者允許空路徑,那麼我們可以將邊界條件定為
$d(v)=0$ if $v$=None
但本題權值可能為負又不允許空路徑,這情形就如同在一個陣列中找最大的總和的非空子陣列一樣(max subarray problem),我們可以將邊界條件設為
$d(v)=-\infty$ if $v$=None
然後把遞迴式改成
$d(v) = \max\{d(v.left),d(v.right),0\}+v.val$
事實上,一個陣列就是一個退化為每代都只有一個孩子的樹(只有一條路徑)。所以本題可以看成max subarray的變化版。
找出往下的最大路徑,但還沒解決本題,因為本題的路徑不一定是往下。在rooted tree上的路徑,一定最高點掛在某點,也就是每點往下的兩條路徑,當然可能退化成其中一條為空。所以,若$p(v)$是最高點在$v$的最長路徑,則
$p(v) = v.val+\max(d(v.left),0)+\max(d(v.right),0)$
以下是以DFS走訪的範例程式。$dfs(r)$會回傳兩個值:$r$點以下子樹的最佳解,以及$r$點往下的最大路徑權值。第10行是邊界條件,我們回傳兩個無窮大。
第11 ~ 12行分別遞迴呼叫兩個孩子,取得子樹的解,然後第13 ~ 15行合併解並回傳。
時間複雜度$O(n)$,$n$是節點數量。
```python=
# Definition for a binary tree node.
# class TreeNode:
# def __init__(self, val=0, left=None, right=None):
# self.val = val
# self.left = left
# self.right = right
class Solution:
def maxPathSum(self, root: Optional[TreeNode]) -> int:
def dfs(r): # return max_solution, max_path from r
if not r: return -1001, -1001
lmax,lpath = dfs(r.left)
rmax,rpath = dfs(r.right)
imax = max(lmax, rmax, r.val+max(lpath,0)+max(rpath,0))
ipath = r.val + max(0, lpath, rpath)
return imax, ipath
#main
ans,_ = dfs(root)
return ans
```
[(題目連結) 310. Minimum Height Trees (Medium)](https://leetcode.com/problems/minimum-height-trees/description/)
給一棵樹,請問從哪個點當root吊起來樹的高度最小。本題輸入的樹是以graph edge的形式,且不一定是binary。本題每一個邊的長度都以1計算(unweighted)。
因為diameter是樹上最長的path,所以從diameter的中點吊起來,樹的高度最小,這一點稱為樹的center。當diameter的長度(edge個數)是偶數時,是center恰有一個,若diameter長度是奇數時,center恰有兩個。此外,樹的diameter不唯一,但所有diameter都會交集在center(s),以上這些性質都不難證明。
根據以上的性質,我們只要找出diameter,然後找中點就是答案。找diameter可以像前一題的DP的方式,也可以運用以下圖論性質:**距離樹上任一點最遠的一點,必是某直徑的一端點**。因此,我們可以用下列步驟求直徑以及center。
1. 任選一點,例如$0$,找出距離$0$最遠的一點$v$;
2. 找距離$v$最遠的一點$u$,$u$與$v$的距離即為直徑長度;
3. 找$u,v$路徑的中點。
我們寫一個DFS計算某$v$點出發到所有點的距離,並回傳最大的距離,以及最遠的點。如果最遠的點不唯一,回傳其中任何一點。為了要能是之後找到路徑中點,我們在DFS過程設定每個點的parent,如此只要往上一路沿著parent就可以回溯與出發點之間的路徑。
以下是範例程式。第3 ~ 6行先轉換輸入為每個點的adjacency list。
第7 ~ 13行是DFS,會回傳最大距離與最遠的點,並設定每個點的parent。Tree的DFS可以不必使用visited來避免重複,因為tree沒有環路,$v$點的鄰居中只要不是它的parent的都是要往下走的點而且不會重複。本題的寫法因為希望記錄每個點的parent,所以呼叫前初設一個parent的list,並將每個點的parent都設為-1。回傳值先初設為$(-1,v)$,然後呼叫他的孩子進行DFS的結果比大小,距離放在第一個參數,所以是先比距離。最後把距離加一後回傳。
主程式先從$0$做第一次DFS,我們要的是端點$v$,然後從$v$點做第二次DFS,我們要的是距離$diameter$與端點$u$,以及設定的parent,然後從$u$點回溯parent一半diameter的長度就是中點,如果是奇數,他的parent也是中點。
時間複雜度$O(n)$。
```python=
class Solution:
def findMinHeightTrees(self, n: int, edges: List[List[int]]) -> List[int]:
adj = [[] for i in range(n)]
for u,v in edges:
adj[u].append(v)
adj[v].append(u)
def dfs(v,p): #return (max distancs,node)
imax = (-1,v)
for u in adj[v]:
if u != p[v]:
p[u] = v
imax = max(imax, dfs(u,p))
return (imax[0]+1,imax[1])
#
p1 = [-1]*n
maxd,v = dfs(0,p1)
p2 = [-1]*n
diameter,u = dfs(v,p2)
for i in range(diameter//2):
u = p2[u]
if diameter&1: return [u,p2[u]]
return [u]
```
**Another solution:剝洋蔥**
這題tree center有另外一個解法,筆者戲稱為剝洋蔥的方法。洋蔥是一層一層的,想像一棵樹,如果我們將目前的所有leaf(degree=1的節點)全部刪除,那麼直徑的兩端會各被刪除一個節點。重複這個動作,如果剩下兩個點一條邊,這兩點就是center。如果直徑長度是偶數的狀況,最後會刪到剩下一個點。
以下是範例程式。一開始先把edge轉換為每個點的鄰居,$deg$是每個點鄰居數(第9行)。先抓出一開始的葉節點(第10行)。跑一個while迴圈直到點數小於等於2為止。
每次進入迴圈後,我們打算把leaf刪掉,並且抓出剝掉葉節點之後新的葉節點,我們檢查葉節點的鄰居,將每個鄰居$deg$減去1,若減去後恰好為$1$就是新的葉節點,我們不必檢查已經刪除的鄰居,因為他們的$deg$會變成$0$而不是$1$,不會誤判。在做完這動作之後,我們修改$n$的值,並且讓leaf指向新的葉節點。
這樣在最後如果剩下兩點時,他們的$deg$都是$1$,所以當然沒問題會找到正確的答案。但是如果是剩下一點的情況呢?他的$deg$會是$0$,這樣還有正確找到答案嗎?事實上是的,因為它的$deg$在被刪為$0$之前,會先被刪為$1$,而在那時被納入$tem$。
時間複雜度一樣是$O(n)$。
```python=
class Solution:
# iteratively remove all leaves, until 1 or 2 nodes
def findMinHeightTrees(self, n: int, edges: List[List[int]]) -> List[int]:
if n==1: return [0]
adj = [[] for i in range(n)]
for u,v in edges:
adj[u].append(v)
adj[v].append(u)
deg = [len(adj[v]) for v in range(n)]
leaf = [v for v in range(n) if deg[v]==1]
while n > 2:
tem = [] # new leaf
for v in leaf:
for u in adj[v]:
deg[u] -= 1
if deg[u]==1: tem.append(u)
# even when unique center, it will be added to leaf
n -= len(leaf)
leaf = tem
return leaf
```
[(題目連結) 834. Sum of Distances in Tree (Hard)](https://leetcode.com/problems/sum-of-distances-in-tree/)
給一棵樹,點的編號是$0$ ~ $n-1$,對每一個點$i$,要算出$i$到其他點的距離總和。
在tree上求一點到其他點的距離只需要$O(n)$,但是如果每個點都做一次起點,那就會導致$O(n^2)$的時間複雜度。
關鍵點在每次不需要全部重算。令$d(v)$是點$v$到所有點的距離和。若$(u,v)$是樹的一個邊,也就是兩點互為鄰居,我們可以發現$d(u)$與$d(v)$是有關係的。最重要的性質是:**樹上任兩點之間只有一條路徑**。

以這個示意圖來說明,上方所有的點要走到$v$都會經過$(u,v)$這個邊,而且上方任一點走到$v$距離都是走到$u$距離再加$1$。相同的,下方每一個點到$u$距離都是走到$v$距離再加$1$。因此,
$d(v) = d(u)+(n-n_v)-n_v$
其中$n$是總點數而$n_v$是$v$點以下的總點數,$n-n_v$也就是$u$點以上的點數。
我們可以從任一點當root,先算出每個點以下的點數,以及root到所有點的距離總和,然後根據上面的遞迴式由上而下計算出每個點到其他點的距離總和值。
以下是範例程式。我們任選$root=0$,第11 ~ 16行是dfs是計算每個點以下的節點數放在$n\_node[]$中,也計算一個點到它下面所有點的距離總和,這部分我們只要$root$的結果。
接著第17 ~ 21行是另外一個遞迴函數,由上而下,根據parent的$d$值計算孩子的$d$值。
時間複雜度$O(n)$。
```python=
class Solution:
# p is parent of r. total(r) = total(p) + n -2node(r)
def sumOfDistancesInTree(self, n: int, edges: List[List[int]]) -> List[int]:
adj = [[] for i in range(n)]
for u,v in edges:
adj[u].append(v)
adj[v].append(u)
root = 0 # random root
n_node = [1]*n # num of node in subtree
d = [0]*n # total distance
def dfs(r,p): # root r with parent p. find dist to descendant
for c in adj[r]:
if c==p: continue
dfs(c,r)
n_node[r] += n_node[c]
d[r] += d[c]+n_node[c]
def topdown(r,p): # find total distance
for c in adj[r]:
if c==p: continue
d[c] = d[r]+n-n_node[c]-n_node[c]
topdown(c,r)
#main
dfs(root,-1)
topdown(root,-1)
return d
```
[(題目連結) 968. Binary Tree Cameras (Hard)](https://leetcode.com/problems/binary-tree-cameras/description/)
題目說有一個rooted binary tree,如果在一個點上裝監視器,可以監看到自己與鄰居。請問最少要裝幾個監視器才可以使得所有的點都被監看到。
這個問題是圖論上的dominating set problem。在一般圖上是NP-hard,在tree上則可以用DP在linear time解。這個DP遞迴關係稍微有點難。我們可秉持DP的基本原則:若列不出遞迴式時,可以進一步分類。
若令$d(v)$是$v$點以下的最佳解,我們很難列出遞迴式,各位讀者可以自行試看看。究其原因是,我們不知道$v$點有沒有放監視器以及$v$點有沒有被下方的監視器監看到。因此我們將$v$點以下的最佳解區分成三類:
* $d1(v)$: $v$點有監視器的最佳解,當然$v$點也被監視到了;
* $d01(v)$: $v$點沒有監視器但被下方監視到的最佳解;
* $d00(v)$: $v$點沒有監視器而且也沒有被下方監視到的最佳解;
分類以後遞迴式就很簡單可以照定義列出了,請看以下範例程式。我們就不另外列出與解釋了。
時間複雜度$O(n)$。
```python=
# Definition for a binary tree node.
# class TreeNode:
# def __init__(self, val=0, left=None, right=None):
# self.val = val
# self.left = left
# self.right = right
class Solution:
# tree dominating set
def minCameraCover(self, root: Optional[TreeNode]) -> int:
# dxy: x=1 if root with camera;
# 01: root without camera but dominated
# 00: root without camera and not dominated
def dfs(v): # return camera in subtree (1,01,00)
if not v: return 1000,0,0
d1,d01,d00 = dfs(v.left)
p1,p01,p00 = dfs(v.right)
dp1 = 1+min(p1,p01,p00)+min(d1,d01,d00)
dp01 = min(d1+min(p1,p01), p1+min(d1,d01))
dp00 = d01+p01
return dp1,dp01,dp00
#main
k1,k01,k00 = dfs(root)
return min(k1,k01)
```
## Subset
DP的子問題(狀態)除了可能由一維、二維與subtree表示外,也可能有更複雜的結構,例如狀態可能是某個集合的所有或一部分子集合來表示。這一類的問題通常都較困難。子集合通常編碼為一個非負整數:第$i$個bit為$1$表示子集合包含第$i$個元素,因此這一類的問題有時也被稱為**Bit DP**或**Bitmask DP**。
讀者要做這一類的問題或看教學之前,需要先了解bitwise運算與bitmask表達集合的一些基本運算。
Python的位元運算有6個:AND, OR, NOT, XOR, 左移位與右移位。
以下我們列出一些在表示集合時常用的位元運算,其中$n$是成員個數,$0\leq s < 2^n$是一個子集合的代表整數,$0\leq i <n$是一個成員。
* $i$的代表數字就是$1<<i$,也就是1左移$i$位,也就是$2^i$。某些成員所成的集合就是成員的代表數字相加或OR起來的結果。
* 空集合就是0,宇集合(全部成員)就是$(1<<n)-1$,也就是$n$個bit都是1,也就是$2^n -1$。
* 將$i$加入$s$,$s|(1<<i)$。這個用法比用$s+(1<<i)$好,因為若$i$已經在$s$中,用加法會是錯的。
* 檢查$i$是否在$s$中,$s\&(1<<i)\ !=\ 0$。
* 從$s$中刪除$i$,$s\&( \text{~} (1<<i))$。
* 兩個集合聯集直接$s\ |\ t$,交集為$s\ \&\ t$。
---
以下我們來看一些例題。
[(題目連結) 1349. Maximum Students Taking Exam (Hard)](https://leetcode.com/problems/maximum-students-taking-exam/)
教室的座位看成$m\times n$的矩陣,有些座位壞掉不能坐人,除此之外,若學生坐在某個位置,它可以看到左邊、右邊、左前、以及右前四個座椅上人的答案,請問該教室最多可以有幾個學生入座考試,滿足每一個考試學生不會看到其他學生的答案。$1\leq m,n\leq 8$。

顯然每一列的座位是否可以坐只跟前一列與自己這一列有關,與更前面的座位無關。因此,我們若依序計算前1列、前2列,...,直到最後一列。計算前$i$列時,對於前$i-1$列,我們記錄下各種可能的最後一列坐人的狀態,每個狀態是一個$n$-column的子集合,對於每一個子集合我們紀錄在該狀態下前面可以坐得下的最大人數。
對第$i-1$的每一種狀態,我們可以根據第$i$列的坐椅狀況產生所有可能的第$i$列狀態。
以下是範例程式。第6行我們先寫一個函數$one\_row(row,pre)$,它根據目前這一列的坐椅狀況$row$以及前一列的狀態$prev$,產生所有可能的這一列狀態以及該狀態有幾個$1$-bit,其中狀態都是子集合的代表整數。做法是從$0$(空集合)開始,逐步檢查每一個column,把現有狀態中可以在該column擺$1$的,全部放上$1$再附加回原來的狀態列表(第10行),至於可以放$1$的規則就是依照題意:首先坐椅必須是好的(第9行),此外,本列前一個(左方)不是$1$、前一列的前一個(左前方)不是1,以及前一列的下一個(右前方)不是1(第10 ~ 12行)。這裡我們採取1-index,也就是從bit-1開始,原因是省略左邊邊界的檢查。
在主程式方面,我們用字典Counter()來放置每個狀態它所坐的人數,$prev$是前一列,每次考慮一列,把新產生的狀態暫放在$tem$,在迴圈末再把$tem$丟回$prev$做為下一列的前一列。初始的前一列給$(0,0)$是空集合與0個人。第17行歷遍前一列的每一個狀態$s1$,對每一個$s1$都產生所有可能的本列狀態$s2$(第19行),然後更新$s2$的最大人數(第20行)。最後答案在最後一列各種狀態的最大人數。
一列的狀態數有$O(2^n)$種,每一個狀態最多產生$O(2^n)$個下一列狀態,再加上有$m$列,所以時間複雜度總計$O(m\times 2^{2n})$。本題$m,n\leq 8$。
```python=
class Solution:
# dp on the last row, 2^8 states
def maxStudents(self, seats: List[List[str]]) -> int:
m,n = len(seats),len(seats[0])
# column is 1-indexed
def one_row(row,prev): # return all possible susbet and size
w = [(0,0)] # column set and its size
for i,c in enumerate(row):
if c=='#': continue
w += [(s|(1<<(i+1)), k+1) for s,k in w \
if s&(1<<i) == 0 and prev&(1<<i)==0 and
prev&(1<<(i+2))==0]
return w
prev = Counter({0:0}) # (set,size)
for row in seats:
tem = Counter()
for s1 in prev:
k1 = prev[s1]
for s2,k2 in one_row(row,s1): # (set,size)
tem[s2] = max(tem[s2],k1+k2)
prev = tem
#main
ans = max(prev.values())
return ans
```
[(題目連結) 1125. Smallest Sufficient Team (Medium)](https://leetcode.com/problems/smallest-sufficient-team/description/)
有$m$種需要的技術,每種以一個字串表示;有$n$個人,每個人擁有某些需要的技術。要選出最少的一組人,使得每種需要的技術都至少有一個被選出的人具備該技術。$m\leq 16$,$n\leq 60$。保證答案存在。
既然要找出一組人,直覺的想法是枚舉所有可能的人的子集合,無奈$O(2^n)$太大,因為本題$n$可以到$60$。我們換一個角度想,雖然人數$n$大,但是技術種類$m$小,如果是相同的技術組合,不同組合的人並不重要,只要留一組最小的就好了。所以我們應該對技術的各種組合進行DP。
對於一個技術組合$s$,令$dp(i,s)$為前$i$個人擁有$s$的最少人數,而則
$$
dp(i,s) = \min\Bigg\{
\begin{array}{ll}
dp(i-1,s) \\
\min\{dp(i-1,t)+1: t\cup skill(i)=s\}
\end{array}\Bigg\}
$$
實作時,直接從前一列$dp(i-1,*)$計算出下一列$dp(i,*)$。此外,本題不只要算最少人數,而是要輸出一組最少人數的組合(index),因此,DP過程必須以list儲存組合。
以下是範例程式。
首先,技術是以字串表示,我們先建立一個將技術字串轉換為整數的字典,如果把第$i$項技術轉換為$i$,將來使用時還是要換成$1<<i$,不如直接直接轉換為$1<<i$。第5 ~ 7行是建立轉換的字典。
第8行$hit$是儲存每一種技術組合的最小團隊的字典,初始是空集合不需要任何人。
第9行歷遍每一個人的技術,一開始先將他的技術轉換為一個代表整數(第10 ~ 11行),然後將目前的$hit$複製一份到$tem$,我們將根據舊的($hit$)更新成新的結果放在$tem$,在迴圈結尾再把它丟回$hit$。第13開始歷遍每一個$hit$中的技術組合,檢查搭配$myskill$是否變成更好的解,如果是就放入$tem$中(第14 ~ 16行)。
最後的結果在$hit[(1<<m)-1]$,因為$(1<<m)-1$是$m$個bit都是1的整數,代表全部技術的集合。
時間複雜度$O(n\times 2^m)$。
在歷遍技術組合的時候,我們也可以不使用字典,對於每一個人,將$[0,2^m-1]$所有的技術都狂掃一遍,就worst case來說是一樣的,但LeetCode上跑出來的時間會比較差,因為是全測資的時間。在下面範例程式的後半段註解中,就是這樣的一個寫法,提供讀者參考。
```python=
class Solution:
def smallestSufficientTeam(self, req_skills: List[str], people: List[List[str]]) -> List[int]:
m,n = len(req_skills),len(people)
# encoding each skill string to int
idx = {}
for i,s in enumerate(req_skills):
idx[s] = 1<<i # i-th bit is 1
hit = {0:[]} # min hitting set of skill subset
for i,skill in enumerate(people):
myskill = 0 # encoding my skills to int
for ss in skill: myskill |= idx[ss]
tem = hit.copy() # updated result
for subset in hit:
if (subset|myskill) not in tem or \
len(hit[subset])+1 < len(tem[subset|myskill]):
tem[subset|myskill] = hit[subset]+[i]
hit,tem = tem,hit # swap new result
#
return hit[(1<<m)-1] # all skills
'''
Another method, without dict hitSet
hitSet = [list(range(n)) for i in range(1<<m)]
hitSet[0] = []
for i,skill in enumerate(people):
myskill = 0
for ss in skill: myskill |= idx[ss]
for j in range((1<<m)-2,-1,-1):
if len(hitSet[j])+1 < len(hitSet[j|myskill]):
hitSet[j|myskill] = hitSet[j]+[i]
#
return hitSet[-1]
'''
```
[(題目連結) 847. Shortest Path Visiting All Nodes (Hard)](https://leetcode.com/problems/shortest-path-visiting-all-nodes/)
題目說有一個無向圖,點的編號是$0$ ~ $n-1$,要找一條路徑走過所有的點,路徑的起點與終點可以任選,邊也可以重複走。邊上無權重,點數不超過12。
圖上找一條最短路徑歷遍所有節點是著名的旅行推銷員問題,**Travelling Salesperson Problem (TSP)**。這個問題屬於NP-hard,也就是尚未找到多項式時間的算法,有很多的形式。
本題如果先求出所有點到點的最短路徑,就會得到一個完全圖,任兩點有個距離,在此完全圖上找一條包含所有點且點不重複的路徑,就是答案,這就是標準的TSP。
解TSP最直接的方法是枚舉所有的可能,也就是要嘗試所有$n!$種可能。但$n!$是一個上升非常快速的函數,時間複雜度$O(n!)$的程式通常只能勉強做到$n=10$或$11$,因為$11!$大約是4千萬,$12!$則已經到4億多了。
雖然同樣是exponential,$2^n$要比$n!$成長慢得多。TSP有一個著名的DP可以將時間複雜度降到$O(n^2\cdot 2^n)$。讓我們來考慮一下子問題:哪些路徑可以歸成一類,只留下最好的一條?或者說當我們想將一條路徑再往後增加一點時,我們需要考慮哪些因素?
一條路徑的狀態可以歸納為兩點:(1).目前已經走過的點;(2).目前所停的最後一點。也就是任何一條路徑$P$可以用$(S(P),last(P))$表示,$S(P)$是$P$上點的集合,$last(P)$是$P$的最後一點。兩個路徑$P$與$Q$,如果$S(P)=S(Q)$且$last(P)=last(Q)$,我們只需要留其中比較好的一條,因為,很顯然的,如果$P+R$是一個解,若且為若$Q+R$是一個解('+'是指路徑串接),而且路徑串接的成本(長度)是兩者獨立計算。
根據以上分析,我們可以把經過點的集合逐步由小到大推來產生各個狀態的最短路徑,因為點集合有$2^n$個,最後的點有$n$種可能,所以狀態(經過點的集,最後一點)總共有$O(n\cdot 2^n)$,每一個狀態需要考慮$O(n)$種可能(下一步),所以時間複雜度是$O(n^2\cdot 2^n)$,這要比枚舉的$O(n!)$要好很多。
這一題因為是TSP的簡化型,所以有更簡單一點的方法,以下是用TSP的解法來做的範例程式。稍後再提供一個比較簡單快速的方法。
首先先將輸入轉換成距離矩陣$d$,不存在的邊設成$n(\infty)$,然後用Floyd-Warshall三迴圈演算法算出all-to-all distance(第5 ~ 14行)。
接著第16行是從各點出發的初始狀態,點集合用整數表示,第$i$個點所成的集合就是$(1<<i)$,也就是$2^i$。第17行開始的迴圈做$n-1$次,每次把每一個狀態(第19行)都嘗試多增加一點(第20行),將
每種可能狀態比較好的都放入$tem$中,這回合結束時,再將$tem$放回$dp$以便進行下一回合。
最後的答案在最後狀態中的最小值(各種可能終點)。
```python=
class Solution:
# find all to all shortest path and then TSP using DP
def shortestPathLength(self, graph: List[List[int]]) -> int:
n = len(graph)
d = [[n]*n for i in range(n)]
for i in range(n): d[i][i]=0
for i,adj in enumerate(graph):
for j in adj:
d[i][j]=d[j][i]=1
#Floyd-Warshall
for i in range(n):
for u in range(n):
for v in range(n):
d[u][v] = min(d[u][v],d[u][i]+d[i][v])
# DP state = (visited, end_vertex)
dp = {(1<<i,i): 0 for i in range(n)} #starting state
for step in range(n-1):
tem = {}
for s,v in dp: # visited set, end vertex
for i in range(n): # next vertex
si = s|(1<<i) #
if s == si: continue
if (si,i) not in tem or dp[s,v]+d[v][i]<tem[si,i]:
tem[si, i] = dp[s,v]+d[v][i]
dp = tem
return min(dp.values())
```
這一題有比較簡單的寫法,因為邊上無權重,每一步都是$1$,我們可以看成在所有狀態$(S,v)$上的一個最短路徑問題,所以用BFS就可以做,以下是範例程式。
第7行的$visit$是紀錄已經走過的狀態,第8行的$que$是BFS待走訪的點,用一個deque來做,然後就是標準BFS的做法。這裡$que$中放的是$(s,v,d)$,其中$s$是經過的點集合(編碼為一整數)、$v$是最後的點、而$d$是此路徑目前的長度。
對在LeetCode上面的測資,這個方法要比前一個速度快。時間複雜度方面,因為每個點往外走的邊會被看$2^n$次,所以複雜度是degree總和的$2^n$倍,也就是$O(m\cdot 2^n)$,$m$是邊數。Wotst case其實與前者相同,但sparse graph這個就跑得快一些,另外前面那個多跑了一個$O(n^3)$的最短路徑,在$n$小的時候,這個是影響時間不可忽略的因素。
```python=
class Solution:
# TSP using BFS on (visited, last)
def shortestPathLength(self, graph: List[List[int]]) -> int:
n = len(graph)
allnode = (1<<n)-1
# DP state = (visited, last_vertex)
visit = {(1<<i,i) for i in range(n)} #starting state
que = deque([(1<<i,i,0) for i in range(n)])
while que:
s,v,d = que.popleft()
if s==allnode: return d
for u in graph[v]: # extend one step
if (s|(1<<u), u) in visit: continue
visit.add((s|(1<<u), u))
que.append((s|(1<<u), u, d+1))
#end while
return -1
```
[(題目連結) 1434. Number of Ways to Wear Different Hats to Each Other (Hard)](https://leetcode.com/problems/number-of-ways-to-wear-different-hats-to-each-other/)
有$n$個人與40種帽子,每個人喜歡某些帽子,請問有多少種可能的組合,滿足每個人都戴喜歡的帽子而且每個人戴的帽子都不同。$n\le 10$
這題跟前面的技術團隊有些類似,如果枚舉每個人可能戴的帽子,那種類就太多了。因為$n\le 10$,我們在人的子集合上DP。把帽子依照編號一一納入考量,對於一個人的子集合$S$,以$d(S,i-1)$表示目前(前$i-1$種帽子)$S$可能的組合數,帽子$i$可以戴在$p$這個人且$p\notin S$,則$S\cup \{p\}$在前$i$頂帽子會對應產生$d(S,i-1)$種組合,也就是
$$
dp(S,i)=\sum_{p\in S\cap H(i)} dp(S-\{p\},i-1)
$$
其中$H(i)$表示第$i$頂帽子可以戴在哪些人頭上。
以下是範例程式。第6 ~ 9行是將輸入轉換為$H[i]$,它是喜歡戴第$i$頂帽子的人,第$p$個人用整數$1<<p$表示。上面的遞迴式是為了方便了解而寫成二維的樣子,因為$dp(i,*)$只跟$dp(i-1,*)$有關,顯然可以用滾動陣列的方式,本題甚至可以只用一個list就處理好,重點在於因為不能踩到這一回合新設的值,所以我們掃描子集合時,由大到小來進行,這樣本回合所更新的值就不會被引用到。
第13行開始一一考慮帽子,第14行由大到小歷遍所有子集合,第19行則檢視喜歡戴第$i$頂帽子的人。如果條件滿足就依照遞迴式更新DP值。
時間複雜度方面,若$n$是人數,$m$是帽子數($=40$),而$K\le mn$是每個人喜歡的帽子數量總和,第13 ~ 17行是主要花時間的部分,
$$
\sum_{i=1}^{41}\sum_{s=0}^{2^n}|H[i]| =2^n\times \sum_{i=1}^{41}|H[i]| =2^n\times K \le mn2^n
$$
所以時間複雜度是$O(mn\cdot 2^n)$,
```python=
class Solution:
# dp on set of people
# O(2^n *mn) time
def numberWays(self, hats: List[List[int]]) -> int:
n = len(hats)
H = [[] for i in range(41)] # hat can assign to whem
for p,hlist in enumerate(hats):
for i in hlist:
H[i].append(1<<p) # hat i can assign to p
nn = 1<<n
d = [0]*nn
d[0] = 1 # empty set
for i in range(1,41): # for each hat
for subset in range(nn-2,-1,-1): # subset of people, backward
for person in H[i]:
if (person & subset) == 0: # not in subset
d[subset|person] += d[subset]
return d[nn-1]%1000000007
```
[(題目連結) 691. Stickers to Spell Word (Hard)](https://leetcode.com/problems/stickers-to-spell-word/description/)
有$n\le 50$種貼紙,每張貼紙上有一個字串,現在想要挑選最少的貼紙,以上面的字母剪下後拼出一個字串$target$,請問最少需要幾張貼紙。每種貼紙無限供應,也就是使用數量不限,字串中可能有重複出現的字母,貼紙上的字串長度不超過10,$target$的長度不超過$15$。
因為字母是剪下來拚字,所以字母的順序無關,字母可能重複,所以貼紙與$target$都可以看成multi-set(元素可以重複的集合),要找最少的multi-set相加的結果涵蓋$target$,相加是指重複元素的數量相加。
以$target$子集合的角度來做DP,最多只有$2^{15}$種子集合,若$S$是一個$target$的子集合,
$dp(S) = \min_{t\in sticker}\{dp(S-t)+1\}$
邊界條件為$dp(\emptyset)=0$。
實作時有一些需要的技巧。我們需要multi-set的減法,此外要使用DP,我們用字典,字典需要可以hash的東西當作key。作法不只一種,在下面的範例程式裡,我們用字串來當做字典的key,為了方便,我們一律將字串中的字母排序,此外,我們自己寫一個字串的"減法"。
第4 ~ 13行是自製的字串減法,將字串$t$中的字母從$s$中刪除(如果該字母存在),因為字母排序過,所以掃一遍就可以做到。
第15 ~ 17行是一開始將字串中的字母都加以排序,第18行是top-down DP遞迴要用的字典。第19 ~ 27行是遞迴函數,目的是求出字串$req$需要的貼紙數量,一開始第20 ~ 21行是終止條件,接著,跑一個迴圈,嘗試使用每一張貼紙,計算使用該貼紙後的需求數量,並求出最小值。第24行加一個條件,避免該貼紙毫無幫助。
時間複雜度$O(mn\cdot 2^m)$,其中$m$是字串長度,而$n$是貼紙數量。
```python=
class Solution:
# dp, subsequence of target, no counter but sorted str
def minStickers(self, stickers: List[str], target: str) -> int:
def minus(s,t): # two sorted string
res =""
i=0
for c in s:
while i<len(t) and t[i]<c:
i+=1
if i>=len(t) or t[i]!=c:
res += c
else: i += 1
return res
#end minus
for i in range(len(stickers)):
stickers[i]="".join(sorted(stickers[i]))
target="".join(sorted(target))
dp = {}
def dfs(req):
if not req: return 0
if req in dp: return dp[req]
best = 1000
for w in stickers:
if req[0] not in w: continue
tem = minus(req,w)
best = min(best,dfs(tem)+1)
#end for
dp[req]=best; #print('str',reqstr,best)
return best
#end dfs
ans = dfs(target)
#print(ans); print(dp)
if ans==1000: return -1
return ans
# 158ms, 85%
```
Python中的Counter()是用來處理multi-set的字典,這題使用Counter可以簡化一點,不必自製減法,不過因為要用一個可以hash的資料型態做字典,所以並不能省太多事,以下是用這樣的範例程式,提供讀者參考。第4 ~ 9行是字串$s$減去第$k$張貼紙的函數,此處貼紙都已經轉成Counter()了(第12 ~ 13行)。執行速度毀比前面一支慢一點。
```python=
class Solution:
# dp, subsequence of target, using counter for set subtract
def minStickers(self, stickers: List[str], target: str) -> int:
def minus(s,k): # string s - stickers[i]
cs = Counter(s) - stickers[k]
res = ""
for c,i in cs.items():
res += c*i
return res
#end minus
n = len(stickers)
for i in range(n):
stickers[i] = Counter(stickers[i])
dp = {}
def dfs(req):
if not req: return 0
if req in dp: return dp[req]
best = 1000
for i in range(n):
if req[0] not in stickers[i]: continue
tem = minus(req,i)
best = min(best,dfs(tem)+1)
#end for
dp[req] = best
return best
#end dfs
ans = dfs(target)
if ans==1000: return -1
return ans
# 324ms 58%
```
[(題目連結) 1655. Distribute Repeating Integers (Hard)](https://leetcode.com/problems/distribute-repeating-integers/)
題目說有一群數字$nums$,數字也許很多,但最多不超過50種不同的數字。此外,有一些需求,第$i$個需求是一個正整數$q[i]$,我們要從$nums$中挑出$q[i]$個相同的整數來滿足第$i$個需求,挑的是哪一個數字都可以,但必須是相同的。請問是否可以滿足所有的需求,當然,每一個$nums[j]$不能分給兩個人。
首先,我們可以把$nums$中相同的數字統計一下,同一種數字,我們可以看成一個箱子(bin),容量就是這種數字的數量。目標是把這些需求放入箱子,每個箱子裝的需求總和不可超過它的容量。
DP主要的步驟要記錄哪些需求的組合可以被滿足,對於$m$個需求,有$2^m$個需求的組合,初始時只有空集合可以被滿足,我們一一考慮每一個箱子,計算出前$i$個箱子可以滿足哪些組合。這個過程其實很像背包問題的DP,只是背包問題時紀錄的是哪些重量,而此處紀錄的是哪些組合(子集合)。
考慮第$i$個箱子時,我們檢查單獨第$i$個箱子可以滿足些組合,搭配前$i-1$個箱子能滿足的組合就可以了。
以下是範例程式。一開始第4行是把原輸入的整數數列利用Counter統計出哪些數字出現多少次,然後把出現次數(values())放到一個List中並且由大到小排序。接著,我們計算出對於$2^m$個需求子集合,每一個子集和的需求總和放在$subsum$,這裡我們寫了一個遞迴生成所有子集合的函數$gen_subsum$,時間複雜度是$O(2^m)$。
第16行呼叫$gen\_subsum$,第17行將所有子集合的代表整數依照總和由小到大排序,這個是為了方便決定每一個箱子可以滿足哪些組合。第18行的字典$d$是紀錄已經可以滿足的子集合。
第19 ~ 26行是DP的過程,歷遍每一個箱子。每次由$d$計算出納入此箱子後可以被滿足的組合放在$tem$,迴圈末尾再放回$d$。因為$subset$已經依照總和排序,所以只要是比不超過箱子容量的組合都是可以加入的,所以用雙迴圈枚舉就可以了。
時間複雜度部分,產生子集合的總和是$O(2^m)$,$m\le 10$是需求數。DP部分,箱子數$n\le 50$,每個箱子最多可以滿足$O(2^m)$個組合,那麼複雜度將會是$O(n\times 2^m\times 2^m)=O(n\times 2^{2m})$。這看起來不是很好,不過如果每個箱子能夠滿足的組合數都很大的話,可能很快就找到滿足的方法而提前結束;如果箱子能滿足的組合數不多,複雜度就不會那麼壞。前述的複雜度或許有高估,實際跑起來並沒有想像的壞。
這一類演算法的複雜度是否高估或有更精準的界線並不一定是容易證明的。本題實際的測資用branch+backtracking跑得更快,我們將在其他單元再說明。
```python=
class Solution:
# DP
def canDistribute(self, nums: List[int], quantity: List[int]) -> bool:
bins = sorted(Counter(nums).values(),reverse=True) # size of bin
n,m= len(bins),len(quantity)
subsum = [0]*(1<<m)
def gen_subsum(i,isum,iset): #sum of all subsets of q
nonlocal quantity,subsum
if i==m:
subsum[iset] = isum
return
gen_subsum(i+1,isum,iset)
gen_subsum(i+1,isum+quantity[i],iset|(1<<i))
return
#
gen_subsum(0,0,0)
subset = sorted(range(1<<m),key=lambda x:subsum[x])
d = {0} # subset can be satisfied
for b in bins: #
tem = set()
for s in subset: # sum<=b, can be satisfied by b
if subsum[s] > b: break
for t in d:
tem.add(t|s)
d = tem
if (1<<m)-1 in d: return True
#
return False
```
**Set-partition**
集合的partition就是將一個集合分成若干個不相交的子集,而聯集必須等於原集合。集合的分割通常比爪子集合的題目更難一點,因為要找的是若干子集合而不只是一個子集合。這一類的問題通常也都是NP-hard。直覺的搜尋方式類似如下:
```
sol(s):
if terminal condition:
do something
for each valid subset of s:
sol(s-subset)
```
這裡有個需要處理的問題:通常我們會將集合encoded成為一個整數,對於整個集合找子集合,我們可以用迴圈枚舉的方式或者遞迴枚舉,這樣我們就搞了個遞迴中再有另外一個遞迴的架構,姑且不論好不好,我們還有另外一個問題,就是每次進來的集合可能是個子集合而非整個集合。一個直覺的處理方法是將子集合內的元素重新編號為$\{0,1,2...\}$,但是這通常比較麻煩。以下介紹一個利用位元特性枚舉一個部份集合的所有子集合的方式:
```python=
s = 0b0110010 # an example {1,4,5}
subset = s
while subset:
print(bin(subset)[2:].zfill(7)) # handle this subset
subset = (subset-1)&s # next subset
#end while
print(bin(subset)[2:].zfill(7)) # handle empty set if necessary
```
我們可以看跑出來的結果
```
0110010
0110000
0100010
0100000
0010010
0010000
0000010
0000000
```
這個方法的原理是每次減一把最低一個位元1變成0,而他的後面都是1,再與原集合取bitwise-and,就是原集合中更小編號的元素都取出來。所以每次的效果就是**將最小編號的元素刪除,加入編號比她小的所有元素**。所以它可以很方便而有效率的歷遍子集合的所有子集合。
那麼,用前面的方式搜尋所有的set-partition時間複雜度會是如何呢?如果我們什麼都不做,那搜尋數量將是非常大的,但我們可以用memoization的方式改善,這樣會進入$sol(s)$的次數將是子集合的個數,也就是$O(2^n)$,每一個子集合要搜尋他的子集合,所以直覺來看總時間是$O(2^n\times 2^n)=O(4^n)$。但事實上這樣是高估了,因為元素個數為$i$的子集合有$C(n,i)$個,他的子集合有$2^i$個,所以總數是
$$
\sum_{i=0}^{n}C(n,i)\times 2^i = (1+2)^n=3^n
$$
後面是根據二項式定理。
每個題目狀況不同,有些題目也許用別的方法會更快。
先來看一題很直接的題目。
[(題目連結) 698. Partition to K Equal Sum Subsets (Hard)](https://leetcode.com/problems/partition-to-k-equal-sum-subsets/)
給一個正整數的陣列以及一個整數$K$,請問陣列可否恰好分割為$K$個總和相同的非空子集合。
以下是照著set-partition的方法做的範例程式。第4 ~ 7行是一些特例的判斷,第9 ~ 14行是計算一個index的集合$s$(以整數bit表示)所對應的數的總和。第16 ~ 26行則是set-partition的DP搜尋,這裡我們不需要紀錄成功分割的子集合,因為這題只要判斷yes/no,如果搜到一個yes,就已經結束返回了。第18 ~ 19行是終端條件,第20 ~ 24則是前面介紹的枚舉子集合的方法。
時間複雜度如前面說明的是$O(3^n)$,$isum$的執行最多只花$O(n2^n)$。
這支程式可以通過,但執行時間的rank非常低。本題$n\le 16$,$O(3^n)$很邊緣,而且$O(n2^n)$並不小。
```python=
class Solution:
def canPartitionKSubsets(self, nums: List[int], k: int) -> bool:
n = len(nums)
if k==1: return True
isum = sum(nums)
psum = isum//k
if k*psum != isum or max(nums)>psum: return False
@lru_cache(maxsize=None)
def isum(s): # sum of subset s
total = 0
for i in range(n):
if s&(1<<i): total += nums[i]
return total
#
fail = set() # set of no-case
def dp(s,k): # subset s can be partitioned into k*psum
if k==1: return True # ensure total = psum*k
if s in fail: return False #memoization
subset = s
while subset: # check all subset of s
if isum(subset)==psum:
if dp(s-subset,k-1): return True
subset = (subset-1)&s # next subset
fail.add(s)
return False
#end dp
return dp((1<<n)-1,k) # all
```
前一支程式第一個可以改善的是點是:計算每一個子集合的和,我們可以先把所有符合要求的子集合(也就是總和等於$psum$)全部找出來,以下程式第9 ~ 16行是一個遞迴暴搜子集合的函數,但是,我們加了一個提早結束的cut:第13行如果發現和已經超過目標,則不再搜下去,這樣做$find\_good$的wort case時間複雜度由$O(n\times 2^n)$降至$O(2^n)$,而且在很多情形會更好。
這個程式對LeetCode的測資比前面一支快了不少,752ms,rank=20%。
```python=
class Solution:
def canPartitionKSubsets(self, nums: List[int], k: int) -> bool:
n = len(nums)
if k==1: return True
isum = sum(nums)
psum = isum//k
if k*psum != isum or max(nums)>psum: return False
good = set() # all subset with sum=psum
def find_good(idx,subset,curr):
if idx==n:
if curr==psum: good.add(subset)
return
if curr>psum: return
find_good(idx+1,subset,curr)
find_good(idx+1,subset|(1<<idx),curr+nums[idx])
return
#
fail = set() # set of no-case
find_good(0,0,0)
def dp(s,k): # subset s can be partitioned into k*psum
if k==1: return True # ensure total = psum*k
if s in fail: return False #memoization
subset = s
while subset: # check all subset of s
if subset in good:
if dp(s-subset,k-1): return True
subset = (subset-1)&s # next subset
fail.add(s)
return False
#end dp
return dp((1<<n)-1,k) # all
#752ms 20%
```
我們可以想一下就知道前面程式rank不好的原因,子集合中其和恰好是我們要的$psum$的應該只佔少數,但我們的程式對每一個子集合都會搜尋他的所有子集,因而浪費了許多時間。所以我們可以檢查那些總和為$psum$的子集($good$)就好。
以下是修改後的程式。第23行從檢查所有子集改成檢查所有$good$中的子集,第24行的if是要確定該子集在$s$中($subset\ \&\ s == subset$)。
改善後的程式跑169ms,rank = 65%。
```python=
class Solution:
def canPartitionKSubsets(self, nums: List[int], k: int) -> bool:
n = len(nums)
if k==1: return True
isum = sum(nums)
psum = isum//k
if k*psum != isum or max(nums)>psum: return False
good = set() # all subset with sum=psum
def find_good(idx,subset,curr):
if idx==n:
if curr==psum: good.add(subset)
return
if curr>psum: return
find_good(idx+1,subset,curr)
find_good(idx+1,subset|(1<<idx),curr+nums[idx])
return
#
fail = set() # set of no-case
find_good(0,0,0)
def dp(s,k): # subset s can be partitioned into k*psum
if k==1: return True # ensure total = psum*k
if s in fail: return False #memoization
for subset in good:
if subset&s == subset:
if dp(s-subset,k-1): return True
fail.add(s)
return False
#end dp
return dp((1<<n)-1,k) # all
#169ms 65%
```
我們還可以再改善前一支程式,我們比對所有$good$中的子集有點太多,我們可以將$good$中的子集分類,例如根據他的最低1-bit分成$n$類,最低1-bit在第$i$個bit就放入$good[1<<i]$。以下是修改後的程式,我們用一個方法很快可以決定一個整數$s$的最低1-bit:$s\ \&\ (-s)$,這個方法也是用在**binary indexed tree**中的方法,不知道的讀者可以先了解一下2's-complement。
修改的地方是
1. good變成一個default為set的字典
2. 第13行找到子集合時,根據該集合的最低1-bit放入不同的地方。
3. 第25行在搜尋時,我們只要搜每一類good子集就好了。
修改後跑119ms,rank=68%。
```python=
class Solution:
def canPartitionKSubsets(self, nums: List[int], k: int) -> bool:
n = len(nums)
if k==1: return True
isum = sum(nums)
psum = isum//k
if k*psum != isum or max(nums)>psum: return False
good = defaultdict(set) # all subset with sum=psum
# good[i]: lowest 1-bit at 1<<i
def find_good(idx,subset,curr):
if idx==n:
if curr==psum:
good[subset&(-subset)].add(subset)
return
if curr>psum: return
find_good(idx+1,subset,curr)
find_good(idx+1,subset|(1<<idx),curr+nums[idx])
return
#
fail = set() # set of no-case
find_good(0,0,0)
def dp(s,k): # subset s can be partitioned into k*psum
if k==1: return True # ensure total = psum*k
if s in fail: return False #memoization
for subset in good[s&(-s)]: # lowest bit of s
if subset&s == subset:
if dp(s-subset,k-1): return True
fail.add(s)
return False
#end dp
return dp((1<<n)-1,k) # all
#119ms, 68%
```
其實LeetCode的running time並不是特別重要,因為全測資的時間與worst case的考慮還是有差別的。筆者故意寫這麼多種逐步改善的方法,是讓讀者了解如何運用自己的知識與算法的原則去改善優化程式。
這一題對於LeetCode測資跑更快的方法,其實反而是對於一個子集合,直接暴搜他的子集合,但加上簡單的cut,因為總和只要超過目標就不必再搜下去,所以實際跑起來並沒有那麼壞。寫法有很多種,以下是用tuple當做hash key的一種寫法。提供有興趣的讀者參考。注意到程式中始終保持進入的tuple是由小到大排好序的。
```python=
class Solution:
def canPartitionKSubsets(self, nums: List[int], k: int) -> bool:
if k==1: return True
isum = sum(nums)
psum = isum//k
if k*psum != isum or max(nums)>psum: return False
fail = set() # set of no-case
nums.sort()
def dp(s): # if tuple s can be partitioned =psum
if not s: return True
if s in fail: return False
ls = list(s)
e = ls.pop()
if e==psum: return dp(tuple(ls))
for i in range(len(ls)):
if ls[i]+e<=psum and dp(tuple(ls[:i]+ls[i+1:]+[ls[i]+e])):
return True
fail.add(s)
return False
#end dp
return dp(tuple(nums))
# 84ms 76%
```
## 結語
DP的題型豐富多變,無論是在LeetCode或是各種程式比賽中都時常見到。他的道理雖然很簡單,但運用與變化非常多。對於1D與2D的DP,LeetCode中也只涉及比較簡單的優化技巧,事實上在競程領域還存在著許多奇奇怪怪的精妙優化技術,精妙的意思是一般人很難靠自己想能想得出來的。
Tree DP的問題也很多,但相對來說,樣子比較固定,這些題目往往不只有DP解,也有靠著圖論知識所導致的解,例如樹的直徑與中心。
集合的DP問題大多是NP-hard的難題,目標是要找一個比直接暴力好一些的複雜度的解。這一類的問題往往也存在另外一個解法,就是branch(分支)算法搭配著一些優化的技術,有時稱為**branch and bound**或**branch and cut**,有的則歸類在**Backtracking**。這些算法往往我們只知道他在worst case時的一個粗略的時間複雜度上界,但是這些上界是否是**tight**並不一定都知道。因為這個原因,競程界出現這種題目的機會比較少,更何況有些exponential-time algorithm的技術並不在競程界流傳,例如fixed-parameter。
以LeetCode的測資來說,很多時候這些方法跑得比DP要來得快。