# Python-LeetCode 581 第19招 Branching algorithm and Backtracking
本單元中我們介紹 branching algroithm,也包含backtracking、剪枝、以及折半枚舉(meet in the middle)等優化的技巧。這裡的**branching algorithm (分支算法)** 泛指以遞迴方式展開所有可能的解並從中搜尋答案的演算法。
算法中最簡單的方法是branching algorithm,最難的也是branching algorithm,因為當我們什麼算法也不會的時候,我們只會用branching algorithm來遞迴暴搜,當我們學會很多算法技巧之後,對於某些問題卻發現所學的算法一招也用不上,只好再回到遞迴暴搜。不同的是,我們可以加上一些技巧讓branching algorithm 暴的快速、暴的優雅、暴的乾淨俐落。
Branching algorithm翻譯為分支算法與分枝算法可能都可以,因為有樹狀圖展開之意,這裡本來就是雙關語,如果不是樹枝的話,那「剪枝」這個詞也難以符合。
---
## 基本原理
組合問題從所求答案類型可以分為decision problem (回答true or false)、optimization problem (最大化或最小化某個目標函數)、以及enumeration problem (找出所有符合答案的解),而這些解都是要找出滿足某些要求的組合結構。常見的組合結構包含subset (combination)、permutation、以及partition (將集合切割成若干不相交的子集)。對於某些問題,特別是一些NP-hard的問題,我們也只能暴搜窮舉,也就是把所有可能的的結構都找出來。Branching名字的由來是:在遞迴展開時,其實是將所有可能的解以樹狀圖的方式分類展開,對各分支遞迴展開搜尋。
有一些名詞指的是類似的方法,但也彼此不盡相同。這些名詞基本上並沒有準確的定義。
* Exhaustive serach、Brute Force、Enumeration:這些名詞是著眼於「將所有可能的解都產生出來」,但不一定是遞迴的形式。
* Branching algorithm 常常採用DFS(depth first search)的方式搜尋,所以有時也直接被稱為DFS,但在算法上Branching algorithm也未必採取DFS,也有Breadth first search或best first search。
* 搜尋過程中加入剪枝的步驟以減少搜尋的範圍,這樣的算法稱為Branch and Bound,或Branch and cut、Branch and reduce。
* Backtracking指的是嘗試搜尋時,將解的內容帶入搜尋的參數。在每一步的時候是嘗試各種可能,每一種可能嘗試後必須予以回復改變的內容再嘗試下一種可能。Backtracking帶入解的內容通常是因為:1.為了判斷哪些分支滿足題目的條件;2.最後的答案必須輸出解的內容而非只是某個目標值;3.剪枝。
### 基本branching algorithm
我們用一個很簡單的題目來舉例說明。
輸入一個長度為$n$的正整數陣列$arr$以及一個整數$t$,現在要找一個子集合其總和不超過且最接近$t$。假設數值的範圍太大,不適合用背包的DP做。以下是一個簡單的暴搜。
```python=
# arr and t is given
best = 0
def branch(idx,currSum):
global t, best
if idx == n: # terminal case
if currSum <= t:
best = max(best, currSum)
return
# without arr[idx]
branch(idx+1, currSum)
# with arr[idx]
branch(idx+1, currSum+arr[idx])
return
# main program
branch(0,0)
# answer is best
```
在這個遞迴函式中,進入的參數$idx$代表目前是決定$arr[idx]$是否要納入總和,而$currSum$是目前已挑選元素的總和。常見的遞迴寫法是一開始先寫終端條件,若$idx==n$時,所有的元素皆已考慮完畢,所以我們檢查所得到的答案$currSum$是否比目前最好的答案$best$更好,是的話就更新答案。
如果不是終端條件,我們就分兩支遞迴搜尋,一個是不取$arr[idx]$;另外一個是取$arr[idx]$。
這個程式的時間複雜度是$O(2^n)$,因為遞迴呼叫的次數就是這麼多,也就是子集合的個數(2倍),每次遞迴呼叫都是花$O(1)$的時間。Space complexity是$O(n)$,因為走DFS,堆疊深度就是樹的高度。
請注意,通常遞迴會比迴圈枚舉要好一點。以下是迴圈枚舉的形式:
```python=
# arr and t is given
best = 0
for s in range(1<<n):
# bitmask, an int is for one subset
currSum = sum(arr[i] for i in range(n) if s&(1<<i))
if currSum <= t:
best = max(best, currSum)
# answer is best
```
這樣寫的時間複雜度是$O(n\times 2^n)$,因為每個子集合花了$O(n)$的時間做加總。如果你認真探究為何遞迴版的時間複雜度會比較好,你會發現它減少了重複計算,因為不是在終端條件時才計算總和,而是在往下分類時,就將選取元素的值加入了。
### 剪枝
剪枝(pruning)的名稱由來是在分支算法中,我們在分支過程中發現某些分支無須再往下搜尋,而未到終端條件就提前結束。
以上述的題目來說,因為陣列中都是正整數,我們可以加上一個簡單的剪枝:當目前總和已經超過目標,就不必再往下搜尋了。
```python=
best = 0
def branch(idx,currSum):
global t, best
if currSum > t: return # bounding
if idx == n: # terminal case
best = max(best, currSum)
return
#
branch(idx+1, currSum)
branch(idx+1, currSum+arr[idx])
return
# main program
branch(0,0)
```
這個例子的剪枝是簡單的,有時剪枝可以很複雜的。剪枝有的時候對runtime的改善是很巨大的,但對worst case的時間複雜度是否改善以及改善的程度就不一定了。如果程式比較複雜,很多branch and bound的演算法只有粗略的upper bound,但是否為tight bound往往是未知的狀態。所謂tight bound是指真的存在worst case讓程式的時間複雜度達到所稱的bound。
### Backtracking
如果題目要求找出達到目標值最佳的解,而非只是最佳的目標值,或者是要找出所有滿足條件的解,我們需要將解的內容(以本題來說是選取的元素)納入遞迴參數。這時,這個程式就成為Backtracking了。
```python=
best = 0
bestSol = []
def branch(idx,sol,currSum):
global t, best, bestSol
if currSum > t: return # bounding
if idx == n: # terminal case
if currSum > best:
best = currSum
bestSol = sol[:]
return
#
branch(idx+1, sol, currSum)
sol.append(arr[idx])
branch(idx+1, sol, currSum+arr[idx])
sol.pop() # backtracking, recover
return
# main program
branch(0,[],0)
print(best,bestSol)
```
另外,題目如果改成對選取的元素有一些限制,例如,選取的任兩個元素相差必須超過若干,或者選取的位置不可連續三個,....等等,這個時候我們必須把解的內容納入以便判斷。這狀況如同在著名的八后問題中,後面擺放的位置必須根據前面已經擺放的位置來決定。
### 折半枚舉 (Meet in the middle)
有一些題目可以運用折半枚舉的技巧來改善時間複雜度加快搜尋。將集合均分成兩半,一半有$n//2$個,另外一半有$n-n//2$個,兩半各自枚舉所有的子集合,最佳解可能全部落在其中一半,或者由兩邊各一個子集合聯合構成。我們先找出兩邊各自的最佳解,然後由其中一半的每一個子集合在另外一半中搜尋它最好的搭配。
以下是範例程式。因為要枚舉兩個集合,我們把枚舉的函數寫成$branch(p,idx,currSum,res)$,其中$p$是要枚舉的list,而結果放在$res$這個set中。
主程式先建立兩個空集合,將$arr$前半與後半各自呼叫枚舉,結果傳回來時各自排序後放在$left$與$right$。$best$先取兩邊最佳解的較大者,然後搜尋各一個子集合的和併的解,這裡可以用two pointers的方法或者用binary search。
時間複雜度由原來的$O(2^n)$降為$O(n\times 2^{n/2})$,這是很大的改善。
```python=
#from bisect import bisect_right
arr = [10,17,31,15,10,46,32]
t = 50
# sums of all subsets of p,
def branch(p,idx,currSum,res):
if currSum > t: return # bounding
if idx == len(p): # terminal case
res.add(currSum)
return
#
branch(p, idx+1, currSum, res)
branch(p, idx+1, currSum+p[idx], res)
return
# main program
s1,s2 = set(),set()
branch(arr[:len(arr)//2],0,0,s1)
branch(arr[len(arr)//2:],0,0,s2)
left,right = sorted(s1),sorted(s2)
print(left,right)
best = max(left[-1],right[-1]) # sol within either part
# find sol cross left and right
i = len(right)-1 # two pointers or binary search
for q in left:
# i = bisect_right(right,t-q)-1 # always >=0
while right[i] > t-q: i -= 1
best = max(best,q+right[i])
print(best)
print(best)
```
折半枚舉的名詞並非直接翻譯來的,英文Meet in the middle這個名詞不只用在這裡,可以廣泛的指從兩邊往中間找答案的策略。
## LeetCode範例
以下我們看一些LeetCode上的題目,難度等級幾乎都列在難題,雖然有些題目不見得很難。由於有剪枝的狀況下,worst case time complexity不一定能準確的分析,我們在範例程式尾端將列出在LeetCode上的runtime與rank,不過這個值可能隨著時間不完全相同。
[(題目連結) 1601. Maximum Number of Achievable Transfer Requests (Hard)](https://leetcode.com/problems/maximum-number-of-achievable-transfer-requests/)
有$n$個有向邊,要找出最大的子集合,滿足每個點的in-degree 等於 out-degree。點數不超過20,邊數不超過16。
我們暴搜所有邊的子集合,找出degree符合條件的最大子集。遞迴函式dfs的參數包含:
* $idx$:目前要決定$requests[idx]$是否納入;
* $num$:目前選的邊數;
* $deg$:每個點的out-degree與in-degree的差值;
* $nonzero$:$deg$中非0的個數。
策略與參數定義好了之後程式的寫法就不難了。第8 ~ 11行是終端條件,第12行是不挑選此邊的搜尋,接下來是挑選此邊,挑選時,先更改所挑邊端點的degree,搜尋後再將它回復。
時間複雜度是$O(n+2^m)$,其中$n,m$是點數與邊數。
```python=
class Solution:
def maximumRequests(self, n: int, requests: List[List[int]]) -> int:
d = [0]*n
m = len(requests)
imax = 0
def dfs(idx,num,deg,nonzero):
nonlocal imax,m,requests
if idx==m:
if nonzero==0:
imax = max(imax,num)
return
dfs(idx+1,num,deg,nonzero) # not using this edge
u,v = requests[idx]
deg[u] += 1
if deg[u]==1: nonzero += 1
elif deg[u]==0: nonzero -= 1
deg[v] -= 1
if deg[v]==-1: nonzero += 1
elif deg[v]==0: nonzero -= 1
dfs(idx+1,num+1,deg,nonzero)
deg[u] -= 1
deg[v] += 1
return
dfs(0,0,d,0)
return imax
# runtime=793ms, rank=75%
```
**Meet in the middle (折半枚舉)**
本題適合折半枚舉。以下是範例程式。我們寫一個副程式
$dfs(edge,idx,num,deg,res)$
枚舉$edge$的子集合,結果放在一個字典$res$中,其對應為
$res$[tuple of degree] = number of edges。
其餘參數與前面的程式相同。
第5 ~ 8行是終端條件,我們將$deg$這個list轉成tuple當做key,對應值為使用了幾根edge,也就是$num$。下面的遞迴呼叫與前面的程式相似,但我們不必記錄非零項有幾個。
第18行是將傳入的edge均勻切成兩半,第19行建立儲存結果的字典,這裡用的是Counter(),然後各自呼叫dfs枚舉子集合。因為本題degree均為0是我們要的答案,所以最佳解先設成兩邊各自最佳解之和,然後第24 ~ 27行搜尋跨兩邊的解,其中一半的degree如果是$p$,另外一邊就必須每一項都是它的負值,如果存在,就比較是否更大。
時間複雜度是$O(n\times 2^{m/2})$,其中$n,m$是點數與邊數。這樣比之前的$O(n+2^m)$好很多。
```python=
class Solution:
# meet in middle
def maximumRequests(self, n: int, req: List[List[int]]) -> int:
def dfs(edge,idx,num,deg,res):
if idx == len(edge):
key = tuple(deg)
res[key] = max(res[key],num)
return
dfs(edge,idx+1,num,deg,res) # not using this edge
u,v = edge[idx]
deg[u] += 1
deg[v] -= 1
dfs(edge,idx+1,num+1,deg,res)
deg[u] -= 1
deg[v] += 1
return
# main
edge1,edge2 = req[:len(req)//2],req[len(req)//2:]
d1,d2 = Counter(),Counter()
dfs(edge1,0,0,[0]*n,d1) # enumerate first part
dfs(edge2,0,0,[0]*n,d2) # enumerate second part
imax = d1[(0,)*n]+d2[(0,)*n]
#print(d1,d2,imax)
for p in d1:
q = tuple(-x for x in p)
if q in d2:
imax = max(imax, d1[p]+d2[q])
return imax
# runtime=49ms, rank=99%
```
[(題目連結) 1723. Find Minimum Time to Finish All Jobs (Hard)](https://leetcode.com/problems/find-minimum-time-to-finish-all-jobs/)
要把$n$個工作分給$k$個人,第$i$件工作需要$jobs[i]$的時間,每件工作只能分配給一個人,也就是必須由同一個人完成,不能多人分工。請找出最好的工作指派方式,使得這$k$個人中最大的總工時最小化。換句話說,最少需要經過多少時間,$k$個人可以完成全部的工作。參數範圍為$k\le n\le 12$。
這問題稱為minimum makespan scheduling,屬於NP-complete。
每個工作有$k$種選擇,所以總共有$k^n$個指派的方式,但這未免也太大了。事實上,最佳解中每個人一定會被指派到工作,不必考慮有空閒的人,此外,人的工作能力是沒有區別的,我們要減少對稱組合的重複計算。我們採用的Branching rule並非單純對每一個工作嘗試$k$種指派。
我們紀錄目前已經被指派工作的人的工作量,對於一個工作,考慮指派給已經有工作的人其中之一,或者指派給一個還沒有工作的人,指派給沒有工作的人的時候,並不區分給誰。這樣,就減少了很多重複的情形。
此外,這題我們採取branch and bound,給予一個簡單的剪枝:如果現在工作量的最大值已經不小於目前的最佳解,就不必再繼續指派其他的工作,因為不可能找到更好的解。
以下是範例程式。
一開始我們把工作由大到小排序,先跑一個heuristic algorithm找出一個解來當做目前的最佳解。這個heuristic方法是每次將工作指派給目前工作量最小的人,我們用一個heap來完成(第9 ~ 13行)。第14行是目前工作量的初始,第15行開始是branching algorithm,進入的參數$curr$是目前工作量的list,$idx$是目前要指定的工作,$cost$是目前工作量的最大值,也是本題的目標函數。第17行是剪枝,第18 ~ 20行是終端條件,已經指派完畢或者剩下的每個工作都可以指派給一個目前沒有工作的人。第21 ~ 24行是指派給沒有工作的人,第25 ~ 29行則是指派給一個目前已經有工作的人,留意兩種情形都需要backtracking。
```python=
class Solution:
# min makespan scheduling, NPC partition problem
# dp or branch and cut? branch
def minimumTimeRequired(self, jobs: List[int], k: int) -> int:
if k==1: return sum(jobs)
n = len(jobs)
if k==n: return max(jobs)
# find a greedy initial cost
jobs.sort(reverse=True)
pq = [0]*k
for v in jobs:
heappush(pq,heappop(pq)+v)
best = max(pq) # best so far
curr = [jobs[0]]
def part(curr,idx,cost): # insert jobs[idx]
nonlocal best
if cost>=best: return
if idx==n or n+len(curr)-idx<=k:
best = min(best,cost)
return
if len(curr)<k: # new bin
curr.append(jobs[idx])
part(curr,idx+1,cost)
curr.pop()
for i in range(len(curr)):
if curr[i]+jobs[idx]<best:
curr[i] += jobs[idx]
part(curr,idx+1,max(cost,curr[i]))
curr[i] -= jobs[idx]
return
#
part(curr,1,jobs[0])
return best
# runtime=41ms, rank=90%
```
[(題目連結) 473. Matchsticks to Square (Hard)](https://leetcode.com/problems/matchsticks-to-square/)
給$n$根火柴棒,每根的長度可能不同,請問是否可以用這些火柴棒圍成一個正方形,火柴棒不可折,每根或柴棒都要恰好使用一次。火柴棒數量不超過15,長度不超過$10^8$。
這問題也就是問一群正整數能否分成四個總和相同的子集合,屬於NP-complete。比較常見的是分成兩個子集合,稱為Partition problem,本題要分成4塊,且數值範圍大不適合用背包的DP解法。
既然和前一題一樣是Partition的問題,我們可以採取類似的Branching方式,不同的是本題固定要切成4份,此外本題是decision problem,答案只有True/False而非求最佳解。
以下是範例程式。一開始先判斷一下總長度是否是4的倍數,並且把每一邊的邊長$leng$算出來。我們一樣採取Backtracking的方式,dfs 傳入的參數中,$idx$是目前考慮到第幾根火柴棒,$curr$是目前每一個塊的總和,程式過程中,我們嘗試將$stick[idx]$放入4個可能的位置,但必須總和不超過$leng$才能放入。此外,如果$curr$有兩個以上的0,我們只需嘗試第一個0,這裡採取的寫法與前一題稍有不同,但意義相同。
如果dfs可以將所有的棍子放完,那就找到了一個答案。
本題的worst case粗估是$O(4^n)$,但事實上應該沒有那麼壞,因為很多case不可能有那麼多選擇。實際LeetCode的測資跑起來並不壞。
```python=
class Solution:
# partition into 4 equal sum
def makesquare(self, stick: List[int]) -> bool:
leng = sum(stick)
if leng%4 !=0: return False
leng //= 4
n = len(stick)
stick.sort(reverse=True)
curr = [0]*4
def dfs(idx,curr): # can be equal 4-partitioned
if idx == n: return True
for i in range(4):
if stick[idx] + curr[i] <= leng:
curr[i] += stick[idx]
if dfs(idx+1,curr): return True
curr[i] -= stick[idx]
if curr[i] == 0: # only try 1 empty
break
return False
#
return dfs(0,curr)
# runtime=304ms, rank=83%
```
**使用折半枚舉**
這題可以用折半枚舉。我們將火柴棒均分兩半,每一半都枚舉它分成四個子集合的各種可能,我們還是保持每一個子集合的和不得超過$leng$,枚舉的結果轉成tuple放入一個set中,同時,放進去的四個值保持由小排到大,這是便於後面的搜尋。
枚舉完畢後,我們由其中一半的每一個結果去另外一半的結果中搜它需要的對象,如果這邊的四個值是 p,那它的對象就是第27行寫的那樣,每一項都與它互補(各項總和為$leng$),因為我們各項之間是排序的,所以它的互補項應該是反序的。
這個時間複雜度粗估是$O(2^n \times n\log n)$,但實際可能更好。
```python=
class Solution:
# partition into 4 equal sum
# meet in middle
def makesquare(self, stick: List[int]) -> bool:
leng = sum(stick)
if leng%4 !=0: return False
leng //= 4
def dfs(s,idx,curr,res): # generate 4-partition
if idx == len(s):
res.add(tuple(sorted(curr)))
return
for i in range(4):
if s[idx] + curr[i] <= leng:
curr[i] += s[idx]
dfs(s,idx+1,curr,res)
curr[i] -= s[idx]
if curr[i] == 0: # only try 1 empty
break
return
#
n = len(stick)
s1,s2 = stick[:n//2], stick[n//2:]
part1,part2 = set(), set()
dfs(s1,0,[0]*4,part1)
dfs(s2,0,[0]*4,part2)
for p in part1:
q = tuple(leng-x for x in p[::-1])
if q in part2: return True
return False
# runtime=122ms, rank=94%
```
[(題目連結) 140. Word Break II (Hard)](https://leetcode.com/problems/word-break-ii/)
給一個字典以及一個字串,要找出所有將該字串切成字典中的字的方式。字串長度不超過20,字典的字數不超過$1000$,每個字長度不超過10。輸出時要將切開的地方放入空白。請看以下範例。
Input: s = "catsanddog", wordDict = ["cat", "cats", "and", "sand", "dog"]
Output: ["cats and dog", "cat sand dog"]
令$s$的長度為$n$,我們可以建構出以下的DAG (directed acyclic graph):
點集合為 $[0,n]$。對於$0\le i<j\le n$,若$s[i:j]$在字典中,我們說有一條$(i,j)$有向邊,也就是$i$可以走到$j$,那麼,一個合法的切割就是一條從$0$走到$n$的路徑,本題就是要找出所有從$0$走到$n$的路徑。
以下是範例程式。第3行將所給字典的字放入一個set()以便查詢,第6 ~ 9行是建構DAG的部分,我們枚舉所有可能起點$i$以及終點$j$,在字典中檢查$s[i:j]$是否是一個邊。
接著是dfs走訪,其中$curr$是目前已經走的路徑,$curr[-1]$就是目前的點,$adj$是邊集,$path$是用來放答案的list。第12行是找到一條到達終點的情形,我們要根據路徑建立一個題目要求的句子,zip($curr$,$curr[1:]$)是枚舉$curr$相鄰點對,例如$curr=[0,1,2,3]$,就會產生$[(0,1), (1,2), (2, 3)]$,從相鄰點對可以抓出所切割的子字串序列,然後用join在之間加入空白並串接。
如果不是終點情形,我們就用backtracking的方式,嘗試從目前的點$curr[-1]$往下一步走。
主程式從$curr = [0]$開始dfs就可以了。
```python=
class Solution:
def wordBreak(self, s: str, wordDict: List[str]) -> List[str]:
iword = set(wordDict)
# build dag
n = len(s)
adj = [[] for i in range(n)]
for i in range(n):
for j in range(i+1,min(i+11,n+1)):
if s[i:j] in iword: adj[i].append(j)
path = []
def dfs(curr, adj, path):
if curr[-1] == n:
t = ' '.join([s[i:j] for i,j in zip(curr,curr[1:])])
path.append(t)
return
for v in adj[curr[-1]]: # backtracking
curr.append(v)
dfs(curr, adj, path)
curr.pop() # recover
return
#end dfs
# enumerate all path from 0 to n
dfs([0], adj, path)
return path
# 33ms, rank=86%
```
[(題目連結) 301. Remove Invalid Parentheses (Hard)](https://leetcode.com/problems/remove-invalid-parentheses/)
給一個由左右小括號與字母組成的字串,要移除最少的括號,使得括號成為合法。找出所有移除數量最少的結果。字串長度不超過25,括號最多20個。
若$dif(t)$表示一個字串$t$中左括號的數量減去右括號的數量,那麼字串$s$是合法的充分且必要條件為:
* $s$的任意前綴皆滿足$dif(s[:i])\ge 0$;且
* $dif(s) == 0$。
因此,我們由前往後枚舉所有可能的前綴,對於字母,一定將它留住,對於括弧,我們分成選它或不選它分別進行搜尋,但要保持$dif\ge 0$。由於要輸出所有最長的合法字串,我們以backtracking的方式將字串代入參數,但因為字串是不可修改的,所以用list的方式做。
以下是範例程式。首先,第6 ~ 8行算一個後綴的右括弧個數,這是用來做簡單剪枝的。第9行的ans用來放答案,best則是最長合法字串的長度。第10行是遞迴搜尋的dfs函數,參數$idx$是目前做到第幾個字元,$curr$是目前的前綴,以list表示,$dif$則是前述的差值。
第12行是剪枝,有兩種狀況可以不必繼續搜,一是目前的差值已經大於後面的右括號數量,則最後一定不可能回到0,這是前面先算出$right$的目的。第二種狀況是剩下的長度加上現在的長度已經小於目前最佳解的長度,所以他也不可能成為最長的合法字串。
第14 ~ 20行是終端情形,因為有前面的剪枝判斷,這時如果已經走完整個字串,它必是目前的最佳解,但可能是更常的字串或者與最長字串等長,所以分兩種來更新答案。
最後是分支的處理,這裡依照目前的字元的種類來分別撰寫,留意都是以backtracking的方式。
```python=
class Solution:
# dif = num of '(' - num of ')'
# valid: dif>=0 at any position and dif==0 at end
def removeInvalidParentheses(self, s: str) -> List[str]:
n = len(s)
right = [0]*(n+1) # num of ) of suffix
for i in range(n-1,-1,-1):
right[i] = right[i+1] + (s[i]==')')
ans = set(); best = 0
def dfs(idx, curr, dif): # dif >= 0
nonlocal s,best,ans,n
if dif > right[idx] or len(curr)+n-idx<best:
return # pruning
if idx == n:
t = ''.join(curr)
if len(t) > best:
ans = {t}; best = len(t)
else: # len(t) == best
ans.add(t) # dif==0
return
if s[idx]=='(':
curr.append(s[idx])
dfs(idx+1,curr, dif+1)
curr.pop()
dfs(idx+1,curr,dif)
elif s[idx]==')':
if dif > 0:
curr.append(')')
dfs(idx+1,curr,dif-1)
curr.pop()
dfs(idx+1,curr,dif)
else: # letter
curr.append(s[idx])
dfs(idx+1,curr,dif)
curr.pop()
return
# end dfs
dfs(0,[],0)
return ans
# 54ms, rank=90%
```
[(題目連結) 1681. Minimum Incompatibility (Hard)](https://leetcode.com/problems/minimum-incompatibility/description/)
要將$n$個整數分到$k$個子集合,每個子集合分到的數字一樣多,且同一個子集合不可以有相同的數字。找出分配後各子集合的最大值減最小值的總和的最小值。如不可能分配則回傳-1。$n\le 16$,$k$保證可以整除$n$,數字大小不超過$n$。
這是個partition的問題,如果每個元素有$k$種選擇,總共會有$k^n$種可能,這太大了。但是本題有限制每個子集合的大小必須相等,而且子集合看成是相同的,再加上我們可以用簡易的剪枝,所以dfs不至於那麼可怕。以下是範例程式,分支方式是將每一個元素指定給每一個可能的子集合。
一開始先把數字排序一下,這將有助於我們判斷同一個子集合不可以有相同的元素。第8行的dfs()其參數為:$idx$目前處理哪一個元素,$curr$是目前每一個子集合的內容,我們用list of list來做,$cost$則是目前每個子集合的最大與最小的差值的總和,也就是題目要最小化的目標函數值。
第10行是剪枝,如果$cost$不小於目前最佳解,就不必做了。第11行是終端條件,由於前面有剪枝,所以這裡的$cost$一定是更好的解。第14行考慮將$nums[idx]$放入個子集合,不能放的情形包含子集合已經滿了(只能放$size$個),或者改子集合中的最後一個是相同的元素,這裡因為排序過,所以如果有相同必是最後一個。第18行算一下放入此子集合時$cost$的增加量,然後就做backtracking的動作:放入、遞迴搜尋、然後回復。第22行我們做一個判斷,如果$curr[i]$是空的,那後面的子集合就不必再考慮,因為我們的做法空集合是會在後面連續出現,由於對稱的關係,我們只要考慮其中一個就可以。
在LeetCode跑的結果並不差。
```python=
class Solution:
def minimumIncompatibility(self, nums: List[int], k: int) -> int:
n = len(nums)
if n==k: return 0
size = n//k
best = 256 #oo
nums.sort()
def dfs(idx, curr, cost):
nonlocal best,nums,n
if cost >= best: return #pruning
if idx == n: # terminal case
best = cost
return
for i in range(k):
if len(curr[i])==size or \
curr[i] and curr[i][-1]==nums[idx]:
continue
inc = nums[idx]-curr[i][-1] if curr[i] else 0
curr[i].append(nums[idx])
dfs(idx+1,curr,cost+inc)
curr[i].pop() # recover
if not curr[i]: break # try only one empty
return
#
dfs(0,[[] for i in range(k)],0)
if best == 256: return -1
return best
# 326ms, rank=83%; partition DP is better
```
[(題目連結) 2305. Fair Distribution of Cookies (Hard)](https://leetcode.com/problems/fair-distribution-of-cookies/)
要將$n$包餅乾分給$k$個小朋友,每包餅乾內的餅乾數量未必相同,希望分的越平均越好,請問拿到最多餅乾的最小值。$2\le k \le n\le 8$,每包餅乾內的餅乾數不超過$10^5$。
這也是個$n$個整數的集合分成$k$個子集的問題,目標是最小化最大子集總和。我們還是採用DFS分支的方式來搜尋,過程中也是採取簡單剪枝以及減少對稱組合的搜尋。
以下是範例程式,一開始先將$cookies$由大到小排序,然後第8 ~ 12行跑一個簡單的heuristic algorithm,計算一個最佳解的初值。這個方法是由大到小考慮集合內的數字,每次將數字放到目前總和最小的子集合中,我們用一個heap來做。
第13行開始的dfs,其參數$curr$是目前每個子集合的總和,這題我們換一個方法寫,只記錄非空的子集,此外$idx$則是目前處理到第幾個元素。第15行是終端情形,因為近來的結果一定比目前最佳解好,所以可以走到終端情形就更新目前最佳解。
第18行的迴圈考慮將$cookies[idx]$放入$curr[i]$的可能但只有總和小於$best$才需要考慮,加入時必須以backtracking的方式在搜尋後予以回復(第22行)。最後,我們第24行考慮放到一個新的空子集合中的情形。
```python=
class Solution:
# backtracking with bounding
def distributeCookies(self, cookies: List[int], k: int) -> int:
n = len(cookies)
cookies.sort(reverse=True)
if n==k: return cookies[0]
# initial heuristic improved from 45ms to 36 ms
g = cookies[:k][::-1]
for p in cookies[k:]:
x=heappop(g)
heappush(g,x+p)
best = max(g)
def dfs(curr,idx):
nonlocal best
if idx==n: # ony better is generated
best = max(curr)
return
for i in range(len(curr)): # insert current bin
if curr[i]+cookies[idx] < best:
curr[i] += cookies[idx]
dfs(curr,idx+1)
curr[i] -= cookies[idx]
# end for
if len(curr) < k: # open new bin
curr.append(cookies[idx])
dfs(curr,idx+1)
curr.pop()
return
#end dfs
dfs([],0)
return best
# 36ms, rank=97%
```
[(題目連結) 1655. Distribute Repeating Integers (Hard)](https://leetcode.com/problems/distribute-repeating-integers/)
題目說有一群數字$nums$,數字也許很多,但最多不超過50種不同的數字。此外,有一些需求,第$i$個需求是一個正整數$q[i]$,我們要從$nums$中挑出$q[i]$個相同的整數來滿足第$i$個需求,挑的是哪一個數字都可以,但必須是相同的。請問是否可以滿足所有的需求,當然,每一個$nums[j]$不能分給兩個人。
本題有DP的解法,請參考 [Python-LeetCode 581 第16招 Dynamic Programming III: Tree and Subset](/9Fwnt2MoSDGZ-VgCzQVzTQ)。這裡我們介紹用Branching algorithm的解法。
數字的內容不重要,重要的是相同的數字有幾個,我們把相同的數字看成箱子,如果$nums$中有7個3,我們看成有一個容量是7的箱子。現在就是要將每一個需求$q[i]$放到一個箱子裡面,而一個箱子裡面裝需求總量不可以超過它的容量。這是一個decision problem,要回答是否可以滿足所有需求,也就是把所有需求裝入箱子裡。
在分支的時候,相同剩餘容量的箱子只需要考慮其中一個,因此我們可以把箱子的剩餘容量相同的一起納入考量。以下是範例程式。一開始先用Counter統計箱子的大小,再把箱子容量用Counter歸類,第4行$bin$的內容是{箱子容量:個數}。
第7行的branch其參數$i$是目前處理第幾個需求,$currBin$是目前箱子的剩餘情形,它是個Counter,初值就是$bin$。第8行是終端情形,全部裝箱完畢就回傳True。接下來考慮各種剩餘容量的箱子,因為dict在歷遍過程中不可更改,第10行我們將它轉成list後歷遍。第11行排除無法裝的情形,其他的用backtracking的方式搜尋,將原容量數量減一,剩餘容量的數量加一。第15 ~ 16在搜尋後再予以回復。
```python=
class Solution:
# backtracking
def canDistribute(self, nums: List[int], quantity: List[int]) -> bool:
bins = Counter(Counter(nums).values()) # size of bin
quantity.sort(reverse=True) # larger q has less choices
m = len(quantity)
def branch(i,currBin): # i-th quantity, current bin size
if i==m: return True
# dict cannot be changed while iterating
for s,q in list(currBin.items()): # use a bin for q[i]
if s<quantity[i] or q<=0: continue
currBin[s] -= 1
currBin[s-quantity[i]] += 1 # remaining capacity
if branch(i+1,currBin): return True
currBin[s-quantity[i]] -= 1 #backtrack
currBin[s] += 1
return False
#
return branch(0,bins)
# 642ms, rank = 94%
```
## 結語
對於一些找不到有效率算法的問題來說,Branching algorithm是最後的算法手段之一,但Branching algorithm並非只是單純的暴力解,有很多技巧可以減少搜尋空間來加速搜尋。純暴搜與折半枚舉的worst case time complexity是比較可以估計的,其他牽涉剪枝的方法的worst case往往只能估算一個upper bound而難以估計tight bound,這可能也是程式競賽中較少出現branching algorithm題目的原因,或者出現時只當做中等題而非決勝的難題。另外,不少問題同時存在Branching與DP的兩種解法,DP的worst case是比較確定的,但兩種方法哪個比較好往往是難以判定的。
在研究領域的剪枝手法可能非常複雜,但並不適合在程式比賽與考試中出現,在考試與比賽中出現的通常僅限於簡單的剪枝。
就筆者的經驗來說,不少學生對於Branching algorithm非常陌生,有不少「看似應該會的人」(學校好成績好的學生),即使是單純的遞迴暴搜子集合,也不太會寫,這是有點納悶令人不解的事,或許是學程式時遞迴沒有充分的了解吧。
演算法始於Branching algorithm,終於Branching algorithm,可謂有始有終。在無法可施的時候,Branching還是有其需要的,近一二十年,演算法學術界對Exponential time algorithm也有比過去多的研究與進展,但那些新的技術尚未廣泛的進入程式競賽中。