# AP325 精熟到登峰 上篇 > 作者:沈宗叡(暴力又被TLE) ## 一、教材說明 [AP325-Python](https://hackmd.io/@bangyewu/Hy2kbYLI6/%2Fg2kqHh5_Q4eQnz-mfNu3Kw) 是由吳邦一教授所撰寫的演算法教材,原為 C++ 版本(2020 年完成),後於 2024 年 1 月完成 Python 版本,主要針對已具備基礎程式語法知識的學習者,透過解題方式來介紹資料結構的使用與演算法思維,程度約對應 APCS 實作題從三級分到五級分的進階需求。教材包含基礎知識說明、題目講解,並附有測資供讀者練習。 而此份教材是筆者在寫完 AP325 並真的實作五級分後,又二刷並將所有題目都盡量推到最佳解法的題解,部分題目有邏輯上比 AP325 本篇更快的解法,而大部分則是著重在實作上的優化,建議在完整寫完 AP325 後再閱讀。適合對程式實作、演算法問題或者是 Python 底層機制有強烈興趣的讀者。 本文採用 [TCIRC judge](https://judge.tcirc.tw/problems/?category=3) 的題目順序撰寫,這個 Judge 對 Python 比較不友好,沒有多開時限,因此常數偏大的解可能過不了,也有些題目就算用 PyPy 仍完全不可能過得了;相較之下 [HWSH Judge](https://judge.hwsh.tc.edu.tw/Problems?tabid=apcs325#tab03) [只有一題](#021) Python 過不了,而 PyPy 可過。但後者的題目順序很亂,而且跳號問題嚴重;前者則比較輕微(但還是沒有跟講義的順序一模一樣,因此建議搭配上面的題單連結閱讀),且 Python 版本比較新,有比較多新東西可以講。總之如果在 TCIRC judge 上過不了,可以到 HWSH Judge 上試試看。 ## 二、優化方法綜整 ### 1. 比較 General 的優化 筆者寫的 [上一篇學習歷程](https://hackmd.io/@ericshen19555/APCS_optimized#4-優化策略概述) 裡有提到 Python 寫演算法題目的一些基本優化方法,內容大差不差。 - IO 優化 預設的 `input` 會用 `sys.stdin.write` 先印出使用者輸入提示(就算是空字串),再使用 `sys.stdout.readline` 讀取使用者輸入,但只能一次讀一整行,並且會自動將最後的換行字元去掉,開銷頗大。 以純 Python 表示就是: ```python= from sys import stdin, stdout def input(prompt: str = "") -> str: # 印出使用者輸入提示 stdout.write(prompt) stdout.flush() # 讀取輸入 並去掉結尾換行\n return stdin.readline().rstrip("\n") ``` 為了更靈活、更快速的輸入,可以直接對 `sys.stdin` 執行讀檔動作,例如以下範例: ```python= from sys import stdin # 一次讀入一整行 保留結尾換行字元 line: str = stdin.readline() # 正常來講這個就夠用了 # 讀入全部 分行存入一個 list 保留結尾換行字元 lines: list[str] = stdin.readlines() # 最快 但耗記憶體 # 把 stdin 當作 iterator # 讀入一行 line: str = next(stdin) # 等價於 stdin.readline() # 一次讀入一行 (當遇到題目以 EOF 結尾時 這樣寫最漂亮) (但 APCS 不會有 EOF) for line in stdin: # 結果上等價於 for line in stdin.readlines() ... # 一次讀入整個檔案純文字 不分行 text: str = stdin.read() # 基本上用不到 # 讀入指定字數 (當輸入檔只有一行 而且大到直接輸入會 MLE) (APCS 遇不到) text: str = stdin.read(100) ``` 預設的 `print` 一樣偏慢,他接受的傳入參數很自由,並且會自動加上換行。 以純 Python 表示就是: ```python= from sys import stdout def print(*objects, sep=" ", end="\n", file=None, flush=False) -> None: # 預設輸出到 sys.stdout if file is None: file = stdout # 將objects轉換成str 並使用sep連接 最後加上end結尾 text: str = sep.join(map(str, objects)) + end # 輸出 file.write(text) # 清除緩存 if flush: file.flush() ``` 只需要保留其中真正用來輸出的部分就好了: ```python= from sys import stdout stdout.write(text) ``` 但還需要自己加上換行字元,太麻煩了。就算是多行輸出,也可以先將每行存入一個 `list[str]` 內再使用 ```"\n".join(lines)``` 組合成一個字串,然後輸出就好了。 總之,一般我輸出還是都用 `print` 就好了,畢竟函數調用次數不多,開銷差距也就不大。 - 提升變數調用速度 Python 中,在程式碼被編譯成字節碼後,變數調用最主要有兩種:`LOAD_NAME` 和 `LOAD_FAST`(還有其他涉及閉包或 comprehension 的,且不同版本可能不同),`LOAD_NAME` 需要在局部、全局、內建命名空間三個地方尋找這個變數,因此非常慢;而 `LOAD_FAST` 只會在局部尋找變數,因此十分快。 ```python= a = 10 print(a) # LOAD_NAME 慢 def f(): a = 10 print(a) # LOAD_FAST 快 ``` 為了讓變數調用的程式碼都被編譯成 `LOAD_FAST`,我們可以將所有程式碼都塞進一個函數 `main` 裡面,編譯器就知道這個變數只需要在這個函數裡面尋找就好了。 - 減少自訂函數調用 Python 中的函數調用會創建新的 stack frame,開銷超大,尤其在大部分 Judge 的舊版本(Python 3.6)這個部分都還沒被優化過,因此非常慢。 再加上 Python 有 $1000$ 的遞迴深度限制,超過就會直接被 `RecursionError`,雖然可以用 `sys.setrecursionlimit` 設定,最高可以到 $2^{31} - 1$,也就是 C int 的範圍,但龐大的記憶體開銷還是會讓你吃 `MLE`。 因此基本上我都會用 Bottom-up 的邏輯來 DP,不然就用 stack 來替代遞迴,雖然實作難度直線上升,但這就是 Python 仔的宿命。只有在 backtracking 或是 merge sort 之類,遞迴深度在 20 層左右的狀況下我會使用自訂函數(就連 $1e5$ 量級的 DFS 我都一律用 stack 手刻)。 - 不要用 `try ... except` 和 `class` 競程的實作上幾乎完全無這兩者的用武之地,他們的開銷之巨難以想像,前者應該好好用 `if` 判斷;後者用 `list` 實作即可,他不只常數較小,`__repr__` 也實作好了,debug 比較方便,完全可以取代 class 的 attribute。 如果覺得用 index 當 attribute name 不好讀也難修改,可以用以下小技巧: ```python= class Node: def __init__(self, val: int, nxt: "Node"): self.val = val self.nxt = nxt node = Node(0, None) print(node.val) # 可以換成 VAL, NXT = range(2) # 以大寫命名比較不容易搞混 node = [0, None] print(node[VAL]) ``` - 快取 object 的 attributes 如果要連續 access 一個 object 的 attribute,像是以下例子: ```python= def main(): dic = {"a": 0, "b": 1, "c": 2} for i in input(): print(dic.get(i)) main() ``` `dic.get` 地方就有優化的空間。分析他的字節碼,要先用 `LOAD_FAST` 找到 `dic`,再使用 `LOAD_ATTR` 在 `dic.__dict__` 裡面尋找 `get`,而 `dic.__dict__` 通常是一個哈希表,這個開銷就又開始大了。 因此我們可以利用 Python 美妙的引用系統,直接拿一個變數指向 `dic.get`,這樣每次調用這個函數的時候都只需要一次 `LOAD_FAST` 就好了,快得很。 ```python= def main(): dic = {"a": 0, "b": 1, "c": 2} get = dic.get for i in input(): print(get(i)) main() ``` - 模版 綜合以上的基本優化方法,我們可以得到一個程式模板: ```python= def main(): # 加速變數調用 from sys import stdin e = stdin.readline # IO加速 並快取attribute(順帶增加打字速度) ... main() ``` ### 2. Stack 優化 - 快取最後一項 對於一個 Stack,有用的就只有最後一項,因此我們可以將 Stack 的最後一項存入一個變數 `top`,這樣就可以避免掉 Python 超慢的 list access。 如果在初始化時將 `top` 初始化成適當的值,像是 $\infty$ 之類的,還可以避免掉一些 edge case 的額外判斷。 - 快取函式 Stack 常用的操作就只有 `pop`, `append`,因此我們可以用變數把他們快取起來。 - 以下是一個單調棧的簡單例子: ```python= def main(): from sys import stdin e = stdin.readline stk = [] # 遞減單調棧 top = float("INF") # sentinel pop, append = stk.pop, stk.append # 快取函式 for val in map(int, e().split()): while top < val: # 需要注意去除尾端元素的邏輯 top = pop() # 需要注意插入至尾端時的邏輯 append(top) top = val ... main() ``` ### 3. [分治優化](https://www.facebook.com/share/p/15eUYWW9gb/) 只要區間切得好,複雜度在 worst case 沒有爛掉,分治的理論複雜度通常不錯,但切到後面區間越來越短的時候,因為每層遞迴有較大的固定消耗,效率相對而言會越來越低。 所以區間足夠短的時候,可以用理論複雜度差一些但常數較小的演算法直接暴力解決。 除此之外,分治常常會搭配 merge sort 使用,而在大部分情況下,第一層遞迴(也就是完整的區間 $[0, n)$)是沒有必要將左右兩個區間 merge sort 的,可以省不少時間。 ## 三、全題解 ### [P-1-1. 合成函數(1)](https://judge.tcirc.tw/problem/d001) 基礎遞迴題,但練習使用 Stack 實作。 重點在 else case 裡面的 while 迴圈,尤其是將變數 `i` 代入公式後直接 `pop` 掉一層遞迴 Stack 的這個操作,類似於一次 `return i`。 $Time: O(n), Space: O(n)$ ```python= def main(): from sys import stdin e = stdin.readline # f(x) = 2x - 1 ; g(x, y) = x + 2y - 3 l = e().split() # 輸入每個字元 stk = [] # Stack top = None # Stack 的頂端 初始化為 None 方便邊界檢查 append, pop = stk.append, stk.pop # 快取 Stack 常用函式 for i in l: if i == "f": # f(x) append(top) # 只有一個 parameter 不需要插空格 top = 0 # 表示目前的函數類型為 f elif i == "g": # g(x, y) append(top) append(None) # 有兩個 parameters 需要插一個空格做暫存 top = 1 # 表示目前的函數類型為 g else: i = int(i) # 轉成數字 代入函數 while top is not None: if top == 0: # 前一個函數是 f i = 2 * i - 1 # 直接代入公式 不需要暫存 top = pop() # 將代表 f 的 1 pop 掉 else: # top == 1 x = pop() # 檢查當前的數字是第一或二個 argument if x is None: # 現在填入的數字是第一個 argument append(i) break # 沒有更多的 return 跳出迴圈 else: # 現在填入的數字是第二個 argument i = x + 2 * i - 3 # 將兩個數字代入公式 top = pop() # 將代表 g 的 2 pop 掉 else: # top is None 代表整個運算式結束 break # 其實不用 break 程式也會自動結束 因為接下來沒有更多輸入了 print(i) main() ``` 先補充一下 `while ... else` 這個語法,大家或許有聽過 Python 有 `for ... else` 這個語法,其實這兩者的邏輯是完全一樣的,也就是「`else` 區塊會在迴圈內沒有 `break` 的時候執行」。 舉個例子: ```python= # for ... else for i in range(3): print(i) if i == 2: print("break!") break else: print("else!") # 輸出: 0 1 2 break! print(f"{i = }") # i = 2 for i in range(3): print(i) if i == 1000: # Unreachable code print("break!") break else: print("else!") # 輸出: 0 1 2 else! print(f"{i = }") # i = 2 # while 同理 i = 0 while i < 3: print(i) if i == 2: print("break!") break i += 1 else: print("else!") # 輸出: 0 1 2 break! print(f"{i = }") # i = 2 i = 0 while i < 3: print(i) if i == 1000: print("break!") break i += 1 else: print("else!") # 輸出: 0 1 2 else! print(f"{i = }") # i = 3 <- 注意 while 最後一輪有多做一次 i += 1 ``` ### [Q-1-2. 合成函數(2)](https://judge.tcirc.tw/problem/d002) 跟上面一模一樣,主要是多了一個有三個 parameters 的 h 函數。 最大的不同在於 while 迴圈內 `top == 2` 的 case 中,判斷要填入第幾個 argument 的部分。 $Time: O(n), Space: O(n)$ ```python= def main(): from sys import stdin e = stdin.readline # f(x) = 2x - 3 ; g(x, y) = 2x + y - 7 ; h(x, y, z) = 3x - 2y + z l = e().split() stk = [] top = None append, pop = stk.append, stk.pop for i in l: if i == "f": append(top) top = 0 elif i == "g": append(top) append(None) top = 1 elif i == "h": append(top) # 三個 parameter 需要開兩格空間暫存 append(None) append(None) top = 2 else: i = int(i) while top is not None: if top == 0: i = 2 * i - 3 top = pop() elif top == 1: x = pop() if x is None: append(i) break else: i = 2 * x + i - 7 top = pop() else: # top == 2 # 先看第一個 argmunet 填了沒 x = stk[-2] # (注意取值 index) (這邊沒有 pop) if x is None: stk[-2] = i # 填入第一格 break else: y = pop() # 再看第二個 argument 填了沒 if y is None: append(i) break else: i = 3 * x - 2 * y + i pop() # pop 掉第一個 argument (即 x) top = pop() else: break print(i) main() ``` ### [P-1-3. 棍子中點切割](https://judge.tcirc.tw/problem/d003) 這題的 edge case 算是挺麻煩的,比較難在簡潔度和效率間權衡。 實作想法是手上一直拿著一根棍子,可以減少一些對 Stack 的 append 和 pop 操作,提升效率。 二分搜可以直接使用 `bisect.bisect_left`,[官方文件](https://docs.python.org/zh-tw/3/library/bisect.html#searching-sorted-lists) 的範例寫得超貼心,~~我忘記怎麼寫的時候都會偷翻~~。 有趣的是 APCS 檢測時是可以翻得到官方文件的,只是介面比較陽春,在 IDLE Shell 視窗打上 `help()` 或是在程式碼裡加上 `help()` 再執行都可以查詢。 示意圖: ![image](https://hackmd.io/_uploads/Hy5XWH5DJe.png) $Time: O(n\log_2 n), Space: O(n)$ ```python= def main(): from sys import stdin from bisect import bisect_left e = stdin.readline n, r = map(int, e().split()) l = (0, *map(int, e().split()), r) # 插入左右端點 方便取中點 ans = 0 stk = [] i, j = 0, n + 1 # 正在處理的閉區間 [i, j] while True: ans += l[j] - l[i] # 加上成本 if i + 2 == j: # 只有一個切點 不用繼續切 if stk: # 找下一根待處理的棍子 i, j = stk.pop() continue break # 沒棍子了 結束遞迴 # 要找切點 s = l[i] + l[j] idx = bisect_left(l, s >> 1, lo = i + 1, hi = j - 1) # 由中點二分搜靠右切點 # 如果 可以往左 and 左切點與中點的距離 <= 右切點與中點的距離 if idx > i + 1 and l[idx] + l[idx - 1] >= s: idx -= 1 # 左側優先 # 更新遞迴資訊 if i + 2 <= idx: # 如果左側棍子還能切 if idx + 2 <= j: # 如果右邊棍子也還能切 stk.append((idx, j)) # 右邊棍子等等遞迴 i, j = i, idx # 拿左邊棍子 else: # 不然就拿右棍子 i, j = idx, j print(ans) main() ``` <span id="004"></span> ### [Q-1-4. 支點切割](https://judge.tcirc.tw/problem/d004) 我的 [前一篇學習歷程](https://hackmd.io/UtqevkNyQDeXzkCzuLemdw?view#55-%E9%82%8F%E8%BC%AF%E5%84%AA%E5%8C%96:~:text=1\)%0A%E2%80%8Bmain()-,f638%20%E6%94%AF%E9%BB%9E%E5%88%87%E5%89%B2,-%E5%8F%83%E8%80%83%E8%A7%A3%E9%A1%8C%E5%A0%B1%E5%91%8A) 也有收錄這題 APCS 考古題。 參考解題報告: [找支點O(1)的方法](https://zerojudge.tw/ShowThread?postid=30025&reply=0) 此題限制切割層級 $K<30$,因此使用遞迴可以 AC,但還是練習用 Stack。 這題一般來講就是用 sliding window 的邏輯,線性複雜度找切點;但也可以藉由億些前綴和寫出區間力矩和的 $O(1)$ 查值(常數稍大),再對區間內二分搜就能 $O(\log_2 n)$ 找切點;然而這遠沒有直接套質心公式來得好寫又快速,而套公式需要注意的就是各種 edge case,選的切點不能在邊緣、小數如何四捨五入...等等。 [`itertools.accumulate`](https://docs.python.org/zh-tw/3/library/itertools.html#itertools.accumulate) 跟 C++ 的 `std::accumulate` 作用不一樣,C++ 的更接近 Python 的 [`functools.reduce`](https://docs.python.org/zh-tw/3.13/library/functools.html#functools.reduce)。 簡單來講,預設的情況下就是對傳入的 `iterable` 做出前綴和的表(第 15 行),回傳的是一個 `iterator`,當然也可以指定特殊的運算函數,像是 `accumulate(it, mul)` 就會做出階乘表。 而 `initial` 則是 Python 3.8 才加入的新功能,就等價於把這一項加到傳入的 `iterable` 的第一項,但 APCS 內用不了。 [`itertools.starmap`](https://docs.python.org/zh-tw/3/library/itertools.html#itertools.starmap) 其實就是一般的 `map`,只是加上了一個 `*` 用於引數傳入函數時的 unpacking,也就是說在對 `iterable` 的每一項 parameter 傳入時,`map` 是 `function(a, b)`;而 `starmap` 則是 `function(*c)`,結果論而言是 `map(func, iterable1, itertable2)` 等價於 `starmap(func, zip(iterable1, itertable2))`。當然,`zip` 起來再丟給 `starmap` 做 unpacking 是脫褲子放屁,我從來只會將 `starmap` 跟 `enumerate` 搭配使用。 `int.__mul__` 是取整數用於相乘的函式,`int.__mul__(a, b)` 等價於 `a.__mul__(b)` 也等價於 `a * b`(要把前兩個的差別解釋清楚就得提到 class 和 OOD,好個大坑!還是請讀者自己跟 ChatGPT 求教吧!),更 pythonic 的寫法是用 `operation.mul`,但 import 會有性能損耗,所以我通常直接拿 magic method。 $Time: O(n+2^k), Space: O(n+2^k)$ ```python= def main(): from sys import stdin from itertools import accumulate, starmap e = stdin.readline """ 質心公式: (x 表示位置, m表示質量) sum(xi * mi for xi, mi in zip(x, m)) / sum(m) """ n, k = map(int, e().split()) l = tuple(map(int, e().split())) # 求質心需要用到區間和 把前綴和準備好 (1-based) # >= Python 3.8 accumulate(initial=) p = tuple(accumulate(l, initial=0)) xp = tuple(accumulate(starmap(int.__mul__, enumerate(l)), initial=0)) # < Python 3.8 # p = (0, *accumulate(l)) # xp = (0, *accumulate(starmap(int.__mul__, enumerate(l)))) ans = 0 # 初始化 區間為整個棍子(左閉右開) 切割層級1 stk = [(0, n, 1)] append, pop = stk.append, stk.pop while stk: # 取出一個區間 i, j, d = pop() # 計算切點 sxp, sp = xp[j] - xp[i], p[j] - p[i] # 前綴和取區間 (1-based) q, r = divmod(sxp, sp) if r > sp >> 1: # 四捨五入選靠近實際切點的位置 (>而不是>= 因為左側優先) q += 1 # 不能切邊邊 校正回歸 if q <= i: q = i + 1 elif q >= j - 1: q = j - 2 # 更新答案 ans += l[q] # 分割新的區間 if d < k: # 確保切割層級 # 左半邊 t = q - i # 該區間長度 if t == 3: # 長度剛好為3 可以直接更新答案 ans += l[i + 1] elif t > 3: # >3 就存起來等等切 append((i, q, d + 1)) # <3 不能切 直接丟掉 # 右半邊 同理 t = j - q - 1 if t == 3: ans += l[q + 2] elif t > 3: append((q + 1, j, d + 1)) print(ans) main() ``` ### [Q-1-5. 二維黑白影像編碼](https://judge.tcirc.tw/problem/d005) 我的 [前一篇學習歷程](https://hackmd.io/UtqevkNyQDeXzkCzuLemdw?view#54-%E5%84%AA%E5%8C%96stack-%E9%81%BF%E5%85%8D%E9%81%9E%E8%BF%B4:~:text=n%22\)%0A%E2%80%8Bmain()-,f637%20DF%2Dexpression,-%E9%A1%8C%E7%9B%AE%E7%B5%A6%E7%9A%84) 也有收錄這題 APCS 考古題。 題目給的 $n \le 1024$,前面就有提到使用函式遞迴的話很有可能會吃 `RecursionError`,而且函式遞迴的常數很大,而此題各層遞迴之間的變數傳遞關係很單純,因此繼續練習 Stack 的寫法。 一般寫法可能是紀錄邊長,紀錄答案時再平方,而我的寫法是直接紀錄面積,層級變化時再 `*4` 或 `/4` (位元運算左右移兩位);Stack 中紀錄的是當前區塊還剩多少子區塊尚未處理,並使用 `cur` 變數快取 Stack 的最後一項。 `stk` 初始化時放的 `None` 可以是任何值,只是為了在全部遞迴結束後,最後一次 `pop()` 時不會 `IndexError: pop from empty list`,當然也可以將 `cur` 初始化為 `2`,但這樣誤導性有點強。 $Time: O(n), Space: O(n)$ ```python= def main(): from sys import stdin e = stdin.readline s = e().strip() area = int(e()) ** 2 ans = 0 stk = [None] # None 是為了在全部遞迴結束後 最後一次 pop() 不會 IndexError cur = 1 # 初始化為一大塊 append, pop = stk.append, stk.pop # 快取函數 for i in s: # 處理目前這塊 cur -= 1 if i != "2": # i == "0" or "1" # 加上黑色面積 if i == "1": ans += area # 將已經處理完的色塊pop掉 把面積*4 while cur == 0: cur = pop() area <<= 2 else: # i == "2" # 面積/4 分成四小塊 area >>= 2 append(cur) cur = 4 print(ans) main() ``` <span id="006"></span> ### [P-1-7. 子集合乘積](https://judge.tcirc.tw/problem/d006) 這題會用到費馬小定理,需要算 $a$ 的模逆元 $a^{p - 2} \mod p$,因此會需要用到快速冪,Python 內建的 [`pow(base, exp, mod)`](https://docs.python.org/zh-tw/3.13/library/functions.html#pow) 本身就支援快速冪,不需要手刻,直接用就好了,但如果要比賽還是要會,矩陣快速冪是很基礎的考點。 有趣的是在 Python 3.8 以後,只要 `exp` 填負數就會變成計算模逆元,可以少做一些大數減法(對 Python 來說很有差),但 APCS 還是不能用(雖然 APCS 也不會考快速冪和費馬小定理)。 接下來講講 [`collections.defaultdict`](https://docs.python.org/zh-tw/3/library/collections.html#defaultdict-objects),`defaultdict` 是 `dict` 的子類,也就是說他繼承了 `dict` 的所有功能,唯一有區別的就是 `__missing__` 這個函式(雖然他在 C 裡沒有實際被實作成一個獨立的 magic method),最主要影響到的就是 `__getitem__` 這個 method,也就是 `val = dic[key]`,當 `key` 不存在的時候一般的 `dict` 會報 `KeyError`,而 `defaultdict` 則是會在 `key` 不存在時創建一個預設值,這個預設值從哪來?這就是 `defaultdict(default_factory)` 在初始化(準確來說是實例化)時傳入的 `default_factory` 的用處,這個預設值就是 `default_factory()`,為何要傳入函數而不是直接傳入值?因為如果直接傳入值,之後每個 `__missing__` 創建的 `key` 都會指向這個 object,就亂套了,類似 `[[]] * n` 這個初學者容易犯的小錯誤。 注意到,`defaultdict` 只有 override `__getitem__` 這個 method,也就是說 `get` 不受影響,`get` 在遇到不存在的值的時候還是預設會回傳 `None` 且不建立任何新鍵值對,下面的例子就是用這個特性節省時間。 幾個 `defaultdict` 常常搭配的 `default_factory`: ```python= from collections import defaultdict # 常用 defaultdict(int) # 0 用來替代 Counter defaultdict(list) # [] 鄰接串列存圖 # 偶爾 defaultdict(set) # set() defaultdict(dict) # {} # 冷門 defaultdict(bool) # False # 字典樹 nested_dict = lambda: defaultdict(nested_dict) trie = nested_dict() ``` 其實還有另一個專門用來計數的 [`collections.Counter`](https://docs.python.org/zh-tw/3/library/collections.html#collections.Counter),功能超級多超好用,但缺點是超級肥超級慢,一般來講用不到這麼多功能,因此我都用 `defaultdict(int)` 更快一些。 Python 3.8 還加入了 [`:=` 海象運算子](https://docs.python.org/zh-tw/3.13/reference/expressions.html#assignment-expressions),就是更自由的變數賦值,在以下例子裡也成就了一定的性能提升,既避免了 `val` 不存在時,比較需要注意的是他的運算優先級非常低,要記得括號,還有記得 APCS 不能用。 最後查值的部分還有一個小巧思:遍歷 `sa` 對 `sb` 查值,因為在 `n` 是奇數的時候,`sa` 會比 `sb` 小一半,而 `dict` 的查值又是 $O(1)$,因此遍歷小的效率更高。 $Time: O(2^{n / 2}\log_2p) \space Space: O(2^{n/2})$ ```python= def main(): from sys import stdin from itertools import islice from collections import defaultdict e = stdin.readline p = 10009 n = int(e()) l = map(int, e().split()) # 折半枚舉 sa, sb = defaultdict(int), defaultdict(int) sa[1] = sb[1] = 1 # 方便建表 後面會扣掉 for i in islice(l, n >> 1): # 取前半部 for k, v in tuple(sa.items()): sa[k * i % p] += v for i in l: # 取剩下的 for k, v in tuple(sb.items()): sb[k * i % p] += v # 配對 sa sb 並扣掉一個1*1的對 # >= Python 3.8 pow(a, -1, p), := print(sum(v * u for k, v in sa.items() if (u := sb.get(pow(k, -1, p))) is not None) - 1) # < Python 3.8 # print(sum(v * sb[pow(k, p-2, p)] for k, v in sa.items()) - 1) main() ``` <span id="007"></span> ### [Q-1-8. 子集合的和](https://judge.tcirc.tw/problem/d007) 我就直接放最佳解了。 這種題型有兩種可能,第一種:$a_i$ 很小,就是經典的背包,如果不帶權可以直接用一個 `int` 用來表示一個 `list[bool]`,Python 的大數位元運算雖然超級慢,但也比列表操作快上一個量級,而最後要求的答案通常跟 lowbit `x & -x` 有關,lowbit 就是指一個二進位數字裡面最右側的一個 $1$,舉個例子,`0b10010000` 的 lowbit 就是 `0b10000`,lowbit 的性質很不錯,是後面會提到的 Binary Indexed Tree 用到的重要性值;第二種:$a_i$ 超大,$n$ 則在 $30$ 上下,那就是折半枚舉,也就是這題的解法。 先看看第一種解法會怎麼做: $Time: O(n \times p), Space: O(p)$ ```python= def main(): from sys import stdin e = stdin.readline n, p = map(int, e().split()) a = map(int, e().split()) bit = 1 mask = (1 << p + 1) - 1 for i in a: bit |= bit << i bit &= mask print(bit.bit_length() - 1) main() ``` 在 TLE 前就 MLE 了。 第二種解法: 為了節省建表時間和記憶體,把其中一個數字先拿出來,在後面查表的時候再順便算,`sb` 就可以少複製一遍。 枚舉的部分,為了不要重複加到同一輪枚舉的數字,使用 [`itertools.islice`](https://docs.python.org/zh-tw/3.13/library/itertools.html#itertools.islice) 限制枚舉的範圍,`islice(it)` 大概像是 `list[slice()]` 的概念,最主要的區別是 `islice` 輸入是 `iterable`,不一定需要能 random access,而回傳的是 iterator;`list[slice()]` 則是回傳一個全新的 `list`,在 `list` 很大的時候就會燒雞,也是初學者常常犯的錯誤之一,但很小的時候卻又會有亮眼的表現,所以 Python 枚舉題的壓常是玄學,我也不會。 這一次因為第一項被拿走,所以變成 $n$ 是偶數的時候 `sb` 會是 `sa` 的一半,因此排序和二分搜都是對 `sb` 做,可以讓 $\log$ 項比較小(但也就小了 $1$?)。 $Time: O(n \times 2^{n/2}), Space: O(2^{n/2})$ ```python= def main(): from sys import stdin, stdout from bisect import bisect_right from itertools import islice e = stdin.readline n, p = map(int, e().split()) l = map(int, e().split()) x = next(l) # 第一項最後計算的時候再加入 ans = 0 sa, sb = [0], [0] for u in islice(l, n >> 1): for v in islice(sa, len(sa)): v += u if v <= p: sa.append(v) for u in l: for v in islice(sb, len(sb)): v += u if v <= p: sb.append(v) sb.sort() mx = sb[-1] for v in sa: if ans - mx < v: # 簡易減枝 減少二分搜次數 # 嘗試未加第一項 k = bisect_right(sb, p - v) - 1 ans = max(ans, v + sb[k]) v += x if ans - mx < v <= p: # 簡易減枝 減少二分搜次數 # 加上第一項 k = bisect_right(sb, p - v) - 1 ans = max(ans, v + sb[k]) stdout.write(str(ans)) # 上 IO 加速才好過 d019 main() ``` 另一種用 [`itertools.combinations`](https://docs.python.org/zh-tw/3.13/library/itertools.html#itertools.combinations) 枚舉的方法,有 C 的黑魔法加持再加上惰性的生成,空間複雜度只有 $O(n)$,十分省記憶體(上面的解的一半),效率意外地也不錯。 [`itertools.chain.from_iterable`](https://docs.python.org/zh-tw/3.13/library/itertools.html#itertools.chain.from_iterable) 則是用來將一個 `iterable[iterable]` 裡面的一個個 `iterable` 展開,連接在一起變成一長串。 ```python= def main(): from sys import stdin, stdout from bisect import bisect_right from itertools import islice, chain, combinations e = stdin.readline n, p = map(int, e().split()) l = map(int, e().split()) x = next(l) # 第一項最後計算的時候再加入 # 枚舉前半部 it = tuple(islice(l, n >> 1)) r = len(it) + 1 # 對每種挑選數量生成子集合 就是 C(n, 1) + C(n, 2) + ... + C(n, n) 的概念 sa = sorted(v for v in map(sum, chain.from_iterable(combinations(it, i) for i in range(1, r))) if v <= p) ub, mx = p - sa[0], sa[-1] ans = 0 it = tuple(l) r = len(it) + 1 # 惰性枚舉 省記憶體 for v in map(sum, chain.from_iterable(combinations(it, i) for i in range(1, r))): if ans - mx < v <= ub: # 嘗試未加第一項 k = bisect_right(sa, p - v) - 1 ans = max(ans, v + sa[k]) v += x if ans - mx < v <= ub: # 加上第一項 k = bisect_right(sa, p - v) - 1 ans = max(ans, v + sa[k]) stdout.write(str(ans)) # 得上 IO 加速才過得了 d019 main() ``` ### [Q-1-10. 最多得分的皇后](https://judge.tcirc.tw/problem/d008) 這題原本在舊的中一中 Judge 上面我不管怎麼寫都 TLE(或許是我那時候還太菜 QwQ),後來在[社團上發問](https://www.facebook.com/share/p/19HUUuxZjs/)後,可以說是我第一次寫了一個「真正的剪枝」,也就是 `ub` 的部分,我是將每一行各取最大值總加,每一列同理,然後兩個方向取值比較小的當上界,如果加上這個理想值還是無法更新答案,就不需要繼續 DFS 下去了。 很多人可能都不知道 `enumerate` 還有第二個 `parameter`,預設值是 `0`,簡單來說就是從多少開始數。下面的例子也可以寫成 `enumerate(tmp, 1)` 就好了,但這個東西有點冷門,我還是會加上 `start=` 增加可讀性。 `nonlocal` 和 `global` 的差別也是個很重要的坑,[這個影片](https://youtu.be/tLcDQhy9Ew0)講得非常好。 ```python= def main(): from sys import stdin e = stdin.readline def backtracking(l, i=0, cur=0): nonlocal best # 要對外層變數賦值需要用 nonlocal 宣告 best = max(best, cur) # 更新答案 if i == n: return # 到底了 backtracking # 計算 upper_bound 用於剪枝 ub = min(sum(max(i) for i in l), sum(max(i) for i in zip(*l))) if cur + ub <= best: return # 不可能更新答案 剪枝 # 先嘗試跳過這行 backtracking([i.copy() for i in l[1:]], i + 1, cur) for j in range(n): # 窮舉每個位置 v = l[0][j] if v == 0: continue # 放不了 # 放下去 把攻擊範圍內的格子刪掉 tmp = [i.copy() for i in l[1:]] # 複製一份 for idx, row in enumerate(tmp, start=1): row[j] = 0 if j + idx < n: row[j + idx] = 0 if j - idx >= 0: row[j - idx] = 0 backtracking(tmp, i + 1, cur + v) n = int(e()) l = [list(map(int, e().split())) for i in range(n)] best = 0 backtracking(l) print(best) main() ``` WIP:DLX解 ### [Q-1-11. 刪除矩形邊界](https://judge.tcirc.tw/problem/d009) 我就直接放最佳解了。 Bottom-up DP,四維稍微麻煩一點但也還好,index 填對了就沒什麼問題。加上前綴和,快速算成本。 $Time: O(m^2n^2), Space: O(m^2n^2)$ ```python= def main(): from sys import stdin from itertools import accumulate e = stdin.readline m, n = map(int, e().split()) area = m * n l = tuple(map(tuple, (map(int, e().split()) for _ in range(m)))) # >= Python 3.8 accumulate(initial) pr = tuple(tuple(accumulate(row, initial=0)) for row in l) pc = tuple(tuple(accumulate(col, initial=0)) for col in zip(*l)) # < Python 3.8 # pr = tuple((0, *accumulate(row)) for row in l) # pc = tuple((0, *accumulate(col)) for col in zip(*l)) # dp[t][s][j][i] = 將 [s, t) [i, j) 這個矩陣刪除的最小花費 dp = [[[[0] * j for j in range(n + 1)] for _ in range(i)] for i in range(m + 1)] # bottom-up # 所有大小為1的矩陣成本為0 初始化後就不用dp 所以矩陣大小 r, c 都從2開始 for r in range(2, m + 1): for s in range(m - r + 1): t = s + r chunk = dp[t][s] for c in range(2, n + 1): for i in range(n - c + 1): j = i + c # 欲求最小花費 初始化為整個圖的面積(無限大的意思) res = area # 以四個方向收縮為前置狀態 加上刪除的花費 x = pc[i][t] - pc[i][s] x = min(x, r - x) # 和翻轉後的取小 res = min(res, chunk[j][i + 1] + x) x = pc[j - 1][t] - pc[j - 1][s] x = min(x, r - x) res = min(res, chunk[j - 1][i] + x) x = pr[s][j] - pr[s][i] x = min(x, c - x) res = min(res, dp[t][s + 1][j][i] + x) x = pr[t - 1][j] - pr[t - 1][i] x = min(x, c - x) res = min(res, dp[t - 1][s][j][i] + x) chunk[j][i] = res print(dp[m][0][n][0]) main() ``` ### [P-2-1. 不同的數—排序](https://judge.tcirc.tw/problem/d010) Python 的內建排序演算法是所有語言裡面最好的,沒有之一。 大部分人可能都還認為 Python 的排序演算法是 Tim Sort,但在 Python 3.11 之後 Python 的內建排序法就被 Power Sort 取代了,因為 Tim Sort 在極端狀況下可能會在分治時將區間切成一個很大一個很小,導致合併時效率很差,而 Power Sort 就是改善了這個問題的 Tim Sort,名字中的 "Power" 就是取自他切區間的方式,好像是用 $2$ 的次方去算出甚麼 magic number,再用他來切區間,但我實在找不到談論詳細細節的介紹。 $Time: O(n\log_2 n), Space: O(n)$ ```python= def main(): from sys import stdin e = stdin.readline e() l = sorted(set(map(int, e().split()))) # 去重並排序 print(len(l)) print(*l) main() ``` ### [P-2-2. 離散化 – sort](https://judge.tcirc.tw/problem/d011) 沒什麼好聊了... $Time: O(n\log_2 n), Space: O(n)$ ```python= def main(): from sys import stdin e = stdin.readline e() l = tuple(map(int, e().split())) k = {v: i for i, v in enumerate(sorted(set(l)))}.__getitem__ print(*map(k, l)) main() ``` ### [P-2-3. 快速冪](https://judge.tcirc.tw/problem/d012) 別急,等等就會手刻快速冪了。 $Time: O(\log_2 n), Space: O(1)$ ```python= print(pow(*map(int, input().split()))) ``` ### [Q-2-4. 快速冪--200 位整數](https://judge.tcirc.tw/problem/d013) Python 的大數雖然慢到爆,但用起來還是挺爽的。 沒錯,程式碼跟上一題一模一樣。 $Time: O(\log_2 n), Space: O(1)$ ```python= print(pow(*map(int, input().split()))) ``` ### [Q-2-5. 快速計算費式數列第 n 項](https://judge.tcirc.tw/problem/d014) 因為是以 $-1$ 結尾而不處理,不想寫 if 判斷這個 case,所以就 iterate 下一項,處理這一項,用這個邏輯就不會計算到最後一項了,只會把他讀進來。 矩陣乘法的部分因為對稱,所以只有三項,直接寫死。 真要寫我會寫這個: ```python= def matmul(a, b): return tuple( tuple(sum(map(mul, row, col)) % mod for col in zip(*b)) for row in a ) ``` $Time: O(t\log_2 n), Space: O(1)$ ```python= def main(): from sys import stdin e = stdin.readline def fab(n): if n < 3: return int(n > 0) # 前三項直接特判 n -= 2 # 轉移次數 d = p = (1, 1, 0) while n: if n & 1: p = matmul(p, d) d = matmul(d, d) n >>= 1 return p[0] def matmul(a, b): x = (a[0] * b[0] + a[1] * b[1]) % p y = (a[0] * b[1] + a[1] * b[2]) % p z = (a[1] * b[1] + a[2] * b[2]) % p return (x, y, z) p = 1000000007 it = map(int, stdin) n = next(it) for nxt in it: print(fab(n)) n = nxt main() ``` ### [P-2-6. Two-Number problem](https://judge.tcirc.tw/problem/d015) 挑小的做成 `set` 以增加效率。 $Time: O(m + n), Space: O(\min(m, n))$ ```python= def main(): from sys import stdin e = stdin.readline m, n, k = map(int, e().split()) a = map(int, e().split()) b = map(int, e().split()) if m > n: a, b = b, a a = set(a) print(sum(1 for i in b if k - i in a)) main() ``` ### [Q-2-7. 互補團隊](https://judge.tcirc.tw/problem/d016) 考點是如何有效率地表示集合並快速找出補集,二進位就是一個很好的集合表示方式,常數小效率高。 比較怪異的是將字母轉換成 bit 的方法,用 `1 << ord(c) - 65` 反而會比用一個 `dict` 建表再查表慢,主要是因為前者用了比較多 Python 層面的操作(而且大數的位元操作很慢),而後者則全部是 C 層面的操作,乾淨俐落。 $Time: O(n), Space: O(n)$ ```python= def main(): from sys import stdin # 將大寫字母對應到每個的 bit get = {'A': 1, 'B': 2, 'C': 4, 'D': 8, 'E': 16, 'F': 32, 'G': 64, 'H': 128, 'I': 256, 'J': 512, 'K': 1024, 'L': 2048, 'M': 4096, 'N': 8192, 'O': 16384, 'P': 32768, 'Q': 65536, 'R': 131072, 'S': 262144, 'T': 524288, 'U': 1048576, 'V': 2097152, 'W': 4194304, 'X': 8388608, 'Y': 16777216, 'Z': 33554432}.__getitem__ m, n = map(int, stdin.readline().split()) s = set() # 哪些子集合有出現 add = s.add # 快取函數 for it in map(str.rstrip, stdin): # 每行輸入並去掉結尾換行 # 將字母集合轉換成一個二進位數字表示的集合 bit = 0 for i in map(get, it): # 將大寫字母 轉換成二進位 bit bit |= i add(bit) every = (1 << m) - 1 # 總共有哪些人 # 對每個子集合查詢其補集是否存在 答案為 互補組數/2 -1 print(sum(1 for i in s if (every & ~i) in s) >> 1) main() ``` ### [Q-2-8. 模逆元](https://judge.tcirc.tw/problem/d017) 裸題,直接費馬小定理。 $Time: O(n\log_2p), Space: O(1)$ ```python= p = int(input().split()[1]) # n 用不到 # >= Python 3.8 pow(exp < 0) print(*(pow(i, -1, p) for i in map(int, input().split()))) # < Python 3.8 # print(*(pow(i, p-2, p) for i in map(int, input().split()))) ``` ### [P-2-9. 子集合乘積](https://judge.tcirc.tw/problem/d018) 約等於 [P-1-7. 子集合乘積](#006),但這題 $n$ 比較大,可以少枚舉一個數字,最後查表時再考慮。 $n$ 為偶數時,`sb` 會是 `sa` 的一半,因此遍歷 `sb` 對 `sa` 查表。 $Time: O(2^{n / 2}\log_2p) \space Space: O(2^{n/2})$ 模逆元直接用 Python 3.8 以後的寫法。 ```python= def main(): from sys import stdin from collections import defaultdict from itertools import islice e = stdin.readline n, p = map(int, e().split()) l = map(int, e().split()) x = next(l) # 第一項最後計算的時候再加入 sa, sb = defaultdict(int), defaultdict(int) sa[1] = sb[1] = 1 # 便於建枚舉表 for u in islice(l, n >> 1): # 使用 tuple() 複製一份字典 才不會重複乘到同一個數字 for v, c in tuple(sa.items()): sa[u * v % p] += c for u in l: # 使用 tuple() 複製一份字典 才不會重複乘到同一個數字 for v, c in tuple(sb.items()): sb[u * v % p] += c ans = 0 get = sa.get # 快取函式 for v, c in sb.items(): ans += c * ( get(pow(v, -1, p), 0) + get(pow(v * x, -1, p), 0) ) print((ans - 1) % p) # 扣掉 1 * 1 的 case main() ``` ### [Q-2-10. 子集合的和](https://judge.tcirc.tw/problem/d019) 題解完全同 [Q-1-8. 子集合的和](#007)。 ### [P-2-11. 最接近的區間和](https://judge.tcirc.tw/problem/d020) Python 沒有內建 Red Black Tree,僅有的 `sortedcontainers.SortedList` 也僅僅是第三方套件,我所知的範圍內只有 LeetCode 裡面有裝,還是少數沒有事先 import 的套件。沒有內建的有序容器是 Python 除了效能以外的另一大劣勢。 第一種方法是用 `bisect.insort` 硬插,時間複雜度 $O(n)$,但有 C 的黑魔法加持,常數很小,如果測資沒有特別針對的話過得了(甚至很快)。 $Time: O(n^2), Space: O(n)$ ```python= def main(): from sys import stdin from bisect import bisect_left, insort from itertools import accumulate e = stdin.readline n, k = map(int, e().split()) # 儲存所有前綴和 s = [0] # 初始化為 [0] 讓第一輪的 s[-1] 不報錯 ans = 0 for v in accumulate(map(int, e().split())): # 做前綴和 # 對所有前綴和二分搜 找出配對起來最接近而不超過 k 的 if v - s[-1] <= k: # 篩掉必定超過 k 的 case (s[idx] 會 Index Error) idx = bisect_left(s, v - k) ans = max(ans, v - s[idx]) insort(s, v) # 插入前綴和 維護升冪 O(n) print(ans) main() ``` 第二種方法是手刻一個陽春的 `sortedcontainers.SortedList`,~~因為直接把原始碼整個貼上去會被 Judge 嗆~~,Python 的 `SortedList` 不是使用 RBT,而是用二維分塊的邏輯,將 $1000$ 量級的列表操作視為常數,當一個塊的長度超過 $2000$ 的時候,就將其均分成兩塊,實作上這個操作由 `_expand` 函數負責;當一個塊的長度小於 $500$ 的時候,就與前或後的塊合併,再對合併的塊執行 `_expand`。 這題因為只有插入操作,所以只需要刻 `_expand` 就好了。 `SortedList` 原始碼: ```python= def _expand(self, pos): """Split sublists with length greater than double the load-factor. Updates the index when the sublist length is less than double the load level. This requires incrementing the nodes in a traversal from the leaf node to the root. For an example traversal see ``SortedList._loc``. """ _load = self._load # 理想區間長度 預設1000 _lists = self._lists # 儲存所有塊 _index = self._index # 用以支援 index 取值的線段樹 if len(_lists[pos]) > (_load << 1): # 超過限制 分成兩塊 _maxes = self._maxes # 用以快取每塊最大值(即每塊最後一項) _lists_pos = _lists[pos] # 要切的塊 half = _lists_pos[_load:] # 複製右半塊 del _lists_pos[_load:] # 刪除右半塊 _maxes[pos] = _lists_pos[-1] # 更新左半塊的最大值 _lists.insert(pos + 1, half) # 將右半塊插入 _maxes.insert(pos + 1, half[-1]) # 插入右半塊最大值 del _index[:] # index 線段樹需要砍掉重種 else: if _index: # 如果線段樹在可維護狀態的話就更新他 child = self._offset + pos while child: _index[child] += 1 child = (child - 1) >> 1 _index[0] += 1 ``` 以 $\sqrt n$ 作為理想分塊長度。 $Time: O(n \sqrt n), Space: O(n)$ ```python= def main(): from sys import stdin from bisect import bisect_left, insort from itertools import accumulate e = stdin.readline n, k = map(int, e().split()) load = int(n ** 0.5) # 取 isqrt(n) 為門檻值 lists = [[0]] # 儲存所有塊 maxes = [0] # 快取每塊最大值 用於第一層二分搜 ans = 0 for v in accumulate(map(int, e().split())): # 做前綴和 # 對所有前綴和二分搜 找出配對起來最接近而不超過 k 的 if v - maxes[-1] <= k: # 篩掉必定超過 k 的 case t = v - k # 二分搜目標值 i = bisect_left(maxes, t) # 第一層二分搜 鎖定塊 chunk = lists[i] j = bisect_left(chunk, t) # 第二層二分搜 鎖定目標 ans = max(ans, v - chunk[j]) # 更新答案 # 插入前綴和 維護升冪 if v >= maxes[-1]: if v == maxes[-1]: continue # 剛好跟目前最大值一樣 跳過 i = len(maxes) - 1 # 加到最後一個塊尾端 block = lists[-1] block.append(v) maxes[-1] = v # 更新最後一個塊的最大值 else: i = bisect_left(maxes, v) # 第一層二分搜 鎖定塊 block = lists[i] insort(block, v) # 第二層二分搜 插入值 # _expand if len(block) >= (load << 1): # 太長 要切塊 half = block[load:] # 取右半塊 del block[load:] # 刪掉右半塊 maxes[i] = block[-1] # 更新左半塊最大值 lists.insert(i + 1, half) # 插入右半塊 maxes.insert(i + 1, half[-1]) # 插入右半塊最大值 print(ans) main() ``` 還有另一種方法不會用到資結,而且複雜度是漂亮的 $O(n\log_2 n)$,那就是分治,最一開始是在寫 [下一題](#021) 的時候不管怎樣都被 TLE,[求助](https://www.facebook.com/groups/359446638362710/posts/1274119463562085/?comment_id=1274166743557357) 於吳邦一教授後他提點的解法: > 這題當初是為了 C++ 設計的,原測資大小的話,Python 即使用了 SortedList 都過不了。我在 Python 版有把測資放小一點。如果 SortedList 不能用的話,可以用 divide and conquer 代替,**跨過中線的解把兩邊的 prefix sum 與 suffix sum 分別排序,用二分搜或者雙指標爬行**,時間 $O(M^2N\log^2N)$。 我最後想出的做法則是先對整個數列做前綴和,跨過中線的解用雙指標配對左右端點,再右減左得出區間。 實作上可以善用 Python 的 iterator,熟練後可以更輕鬆地撰寫這種雙指標的麻煩邏輯(好不好寫比較算個人偏好),並使用 `for` 和 `next(iterator, None)` 漂亮地應對 `StopIteration`,跑起來比 `list[idx]` 還要快。 需要注意的是 merge 的時候,因為都要先從 iterator 內預取一個數字出來比較,最後在左右串列其中一個結束、另一個還剩的時候,可別忘了把這個預取的數字 merge 進去。 那個 `for ... while ... else ... else` 的迴圈可能比較難理解,我標上了各個流程控制指令會跳到哪一行,可以多多琢磨一下。 ```python= def main(): from sys import stdin from itertools import accumulate e = stdin.readline # 分治 [s, t) 時間 O(nlogn) def merge(s, t): # -> list[int] nonlocal ans # 宣告取得上一層的變數 # 結束條件: 區間長度 == 1 if s + 1 == t: return [l[s]] # 切一半 分治 mid = s + t >> 1 le = merge(s, mid) ri = merge(mid, t) # 雙指標處理跨中線的區間 it = iter(le) u = next(it) for v in ri: # 對於每個右端點 # 找到相減區間 <= k 的左端點 while (val := v - u) > k: u = next(it, None) if u is None: break # 左端點沒了 else: # 更新答案 if val > ans: ans = val continue break # 左端點沒了就結束配對 # 將區間和 merge sort res = [] le, ri = iter(le), iter(ri) u = next(le) for v in ri: while u <= v: res.append(u) u = next(le, None) if u is None: # left end res.append(v) # 記得加這項 res.extend(ri) break # -> 51 else: # left not end res.append(v) continue # -> 41 ... 53 break # left end -> 56 else: # left not end res.append(u) # 記得加這項 res.extend(le) return res n, k = map(int, e().split()) # 做前綴和 # >= Python 3.8 accumulate(initial=) l = tuple(accumulate(map(int, e().split()), initial=0)) # < Python 3.8 # l = (0, *accumulate(map(int, e().split()))) ans = 0 merge(0, n + 1) # 分治 [0, n + 1) print(ans) main() ``` <span id="021"></span> ### [Q-2-12. 最接近的子矩陣和](https://judge.tcirc.tw/problem/d021) 為了過這題我可是煞費苦心。 標準解法是對 $m$ 方向做前綴和,再窮舉任兩行相減,變成一個個直向區間和,再對這一列由區間和形成的數列求「最接近的子區間和」,時間複雜度 $O(m^2n\log_2 n)$。 但這一算起來,大概有 $5.6e7$ 的計算量,Python 得跑幾十秒(可不是十幾秒)。 在「觀察」了測資(把每個測資的 $m, n$ 偷出來)以後,發現主要只有兩種極端測資:$(1, 3e6)$ 和 $(50, 6000)$,因此我們就針對這兩種大小的測資做優化,當然如果有甚麼 $(10, 3e5)$ 之類的測資的話就是融合這兩種極端的解法再去細調。 對於 $(1, 3e6)$,就是「最接近的子區間和」,只是將結束條件改成用常數小的 $O(n)$ 解法暴力求長度 $5000$ 以內的區間。 對於 $(50, 6000)$,因為總共有 $\binom{50}{2} = 1225$ 個數列要做「最接近的子區間和」,實在太多了!因此我們可以設法篩選掉一些:如果當前數列的「最大子區間和」比目前得到的「最接近的子區間和」還要小,那就代表這個數列不可能更新答案,可以跳過。 而分治——尤其是 merge sort 的部分——受限於 Python 的列表操作速度而超級無敵慢,為了不 merge sort,我們乾脆只分治一層,也就是只切成左右區間,個別 $O(n^2)$ 暴力算,再雙指標處理跨中線的區間,如果左右的前綴和形成的最大區間比目前的「最接近的子區間和」還要小,那就代表任一組跨中線的區間不可能更新答案,或是如果左右的前綴和形成的最小區間比 $k$ 還要大,就代表任一組跨中線的區間都不合,可以跳過。 如果好奇如果不切兩塊,而是切成三塊、四塊,再枚舉所有左右配對的情況做雙指標,一樣不用 merge sort,會不會比較快?我測過,不會,主要是因為雙指標計算的部分有太多純 Python 操作,很慢。 總之,結果 Python 還是過不了,PyPy 則能打敗部分 C++ 的正常解。 $Time: \begin{cases} O(n \log_2 n), & \text{if } m = 1, \\ O(m^2 n^2), & \text{otherwise.} \end{cases}, Space: O(n)$ ```python= def main(): from sys import stdin from bisect import bisect_left, insort from itertools import accumulate, combinations, islice e = stdin.readline add, sub = int.__add__, int.__sub__ # 快取加減法函數 k = int(e()) m, n = map(int, e().split()) if m == 1: # 矩陣退化為一行 n可以很大 用分治比較穩定 # 將區間分治處理 def merge(s, t): # -> list[int] nonlocal ans # 區間長度 <= 5000 時直接暴力求 O(n²) if t - s <= 5000: it = islice(l, t - s) # 限制取的區間 # 第一項特殊處理 mx = next(it) res = [mx] if ans < mx <= k: ans = mx # 繼續處理後面的 for p in it: if p - mx <= k: idx = bisect_left(res, p - k) ans = max(ans, p - res[idx]) insort(res, p) if p > mx: mx = p return res # 切一半 分治 mid = s + t >> 1 le = merge(s, mid) ri = merge(mid, t) # 雙指標處理跨中線的區間 it = iter(le) u = next(it) for v in ri: # 對於每個右端點 # 找到相減區間 <= k 的左端點 while (val := v - u) > k: u = next(it, None) if u is None: break # 左端點沒了 else: # 更新答案 if val > ans: ans = val continue break # 左端點沒了就結束配對 # 第一層遞迴不用merge sort if s == 0 and t > n: return # 將區間和 merge sort res = [] le, ri = iter(le), iter(ri) u = next(le) for v in ri: while u <= v: res.append(u) u = next(le, None) if u is None: # left end res.append(v) # 記得加這項 res.extend(ri) break else: # left not end res.append(v) continue break # left end else: # left not end res.append(u) # 記得加這項 res.extend(le) return res ans = 0 # >= Python 3.8 accumulate(initial) l = accumulate(map(int, e().split()), initial=0) # < Python 3.8 用 chain 代替 # l = chain((0, ), accumulate(map(int, e().split()))) merge(0, n + 1) else: p = (0,) * n grid = [p] for i in range(m): # 將每一直列做前綴和 p = tuple(map(add, p, map(int, e().split()))) grid.append(p) ans = 0 for a, b in combinations(grid, 2): # 選取前綴和兩行 # map(sub, b, a): 相減變成區間 # 先做一次最大連續區間和 如果 <= ans 就剪枝 dp = res = 0 for v in map(sub, b, a): # 卡丹算法 dp += v if dp > res: res = dp elif dp < 0: dp = 0 if res > ans: break # 可能更新答案 else: continue # 不可能更新答案 剪枝 it = accumulate(map(sub, b, a)) # 左右各自 O(n²) 暴力求 mx = 0 le = [0] for v in islice(it, n >> 1): if v - mx <= k: idx = bisect_left(le, v - k) ans = max(ans, v - le[idx]) insort(le, v) if v > mx: mx = v mx = 0 ri = [0] for v in it: if v - mx <= k: idx = bisect_left(ri, v - k) ans = max(ans, v - ri[idx]) insort(ri, v) if v > mx: mx = v # 最大區間和比當前答案還小 或 最小區間和超過 k 就剪枝 if ri[-1] - le[0] <= ans or ri[0] - le[-1] > k: continue # 雙指標處理跨中線的區間 le = iter(le) u = next(le) for v in ri: while v - u > k: u = next(le, None) if u is None: break # 左端點沒了 else: v -= u # 相減成區間 if v > ans: ans = v # 更新答案 continue # 繼續雙指標枚舉 break # 左端點沒了就跳出 # 只有分治一層 不用 merge sort print(ans) main() ``` ### [Q-2-13. 無理數的快速冪](https://judge.tcirc.tw/problem/d022) 上一題好恐怖,來個簡單的快速冪壓壓驚。 這題的特色就是無理數乘法而已。 $Time: O(\log_2 n), Space: O(1)$ ```python= def main(): from sys import stdin e = stdin.readline p = 1000000009 i, j, n = map(int, e().split()) n -= 1 # 次方: 1 -> n, 轉移 n - 1 次 x, y = i, j while n > 0: if n & 1 == 1: i, j = (i * x + 2 * j * y) % p, (i * y + j * x) % p x, y = (x * x + 2 * y * y) % p, (x * y << 1) % p n >>= 1 print(i, j) main() ``` ### [Q-2-14. 水槽](https://judge.tcirc.tw/problem/d023) 這題把雙指標耍得十分精妙。 這題實作上的小細節主要在找隔板的部分,一樣是用 iterator 以減少 `list` 取值操作,也減少實作難度。 [`operator.itemgetter(idx)`](https://docs.python.org/zh-tw/3.13/library/operator.html#operator.itemgetter) 就是創建一個專門取 `idx` 這個索引取值函式,等價於 `lambda obj: obj[idx]`,最常與 `sort`, `bisect` 和 `groupby` 搭配使用。 差分序列 `ans` 只維護 $n - 1$ 格的資料,而多開一格可以在區間修改時少寫 edge case 判斷,只是最後還原時要記得把最後一格刪掉。 $Time: O(\text{sort} + n), Space: O(n)$ ```python= def main(): from sys import stdin from itertools import accumulate from operator import itemgetter e = stdin.readline n, x, w = map(int, e().split()) # 按照高度降冪排列 保留index # 做成iterator後取其next函式 nxt = iter(sorted(enumerate(map(int, e().split())), key=itemgetter(1), reverse=True)).__next__ ans = [0] * n # 差分序列 方便區間修改 s, t = 0, n - 1 # 維護一個還沒填水的區間 (為左右擋板的index) while True: if s == t - 1: # 單格水直接放 ans[s] += w ans[t] -= w break # 找到此區間內最高的隔板 i, h = nxt() # index, height while not s < i < t: i, h = nxt() # 剩餘水量 > 整個區間加到最大隔板高度的體積 # 整個區間平均分配剩餘水量 l = t - s if w > h * l: h = w // l ans[s] += h ans[t] -= h break # 看水量是否越過擋板 # 沒有越過就縮小區間 # 有越過就把當前區間填滿 並移到空區間繼續填 if x < i: # 注水於左側 v = h * (i - s) if w <= v: # 水不會越過 t = i # 將區間收斂於左側 else: # 水會越過(到右邊) 區間修改左側 ans[s] += h ans[i] -= h w -= v # 水變少 s = i # 接著看右側區間 else: # x >= i 注水於右側 v = h * (t - i) if w <= v: # 水不會越過 s = i # 將區間收斂於右側 else: # 水會越過(到左邊) 區間修改右側 ans[i] += h ans[t] -= h w -= v # 水變少 t = i # 接著看左側區間 # 將差分還原 ans.pop() # 去掉結尾的 sentinel print(*accumulate(ans)) main() ``` <span id="024"></span> ### [P-2-15. 圓環出口](https://judge.tcirc.tw/problem/d024) 環狀數列題最簡單的應對方式就是把數列複製,接到後面,這樣就能無痛解決頭尾的問題。 而以下是不用此技巧,踏踏實實判斷 edge case 的解法。 這個解漂亮的部分在於建立前綴和的時候,最前面多加了一個 $0$,而使數列變成 1-based,讓「完成任務後停在下一個房間」這件事情變得十分自然,但最後別忘了要 `% n`。 $Time: O(n \log_2 n), Space: O(n)$ ```python= def main(): from sys import stdin from bisect import bisect_left from itertools import accumulate e = stdin.readline n, m = map(int, e().split()) # 做成前綴和 # >= Python 3.8 accumulate(initial) p = tuple(accumulate(map(int, e().split()), initial=0)) # < Python 3.8 # p = (0, *accumulate(map(int, e().split()))) last = p[-1] # 快取前綴和最後一項 idx = 0 for i in map(int, e().split()): # 試試看走到底夠不夠 rest = last - p[idx] if rest == i: # 剛剛好 idx = 0 else: if rest < i: # 走到底還不夠 要繞回頭 # 那就先走到底再二分搜 i -= rest idx = 0 idx = bisect_left(p, p[idx] + i) print(0 if idx == n else idx) # idx % n 的意思 main() ``` ### [P-3-1. 樹的高度與根](https://judge.tcirc.tw/problem/d025) 基本圖論題。 直覺想法:首先找出根節點,只要看誰不是任何人的子節點就好了。接下來從根 DFS 下去,當前節點的高度就是最高的子節點 +1。 ```python= # 找根節點 isroot = [True] * n for children in tree: # 假設 tree 紀錄每個節點的所有子節點 for node in children: isroot[node] = False root = isroot.index(True) # DFS def dfs(node: int) -> int: # 回傳當前節點高度 return max((dfs(child) + 1 for child in tree[node]), default=0) dfs(root) # 就可以找出根節點的高 ``` 這樣的複雜度無疑是好的,時間空間都是 $O(n)$,但別忘了我們在寫 Python,`dfs()` 的遞迴深度會是樹的深度,最差的情況就是樹是一條直鏈,深度可以達到 $1e5$ 量級,遠遠超過 $1000$ 的深度限制,當然也可以用 Stack 手刻遞迴,但不好寫,那何不換個方向?從上往下麻煩,就由下往上唄! 大方向是「每一輪處理所有高度相同的節點」,子節點數量是 $0$ 的就是葉子,也是遍歷的第一層,而遍歷一個節點前必須先遍歷其所有的子節點,下面這棵樹的遍歷順序就是 $[4, 5, 6], [2, 3], [1], [0]$。可以發現這樣的遍歷方式有很好的性質:第 $i$ 輪遍歷到的節點之高度就是 $i - 1$,而且最後一層只會有根節點,這樣問題就整個解決了! ![image](https://hackmd.io/_uploads/BJfuXu0P1e.png) 如何實作?將輸入改成紀錄每個節點的父節點,方便由下往上溯源,並紀錄每個節點的度數 `indeg`,每次遍歷到一個節點就將其父節點的 `indeg` 減少,當父節點的 `indeg` 為 $0$ 就代表他所有的子節點都遍歷過了,需要加入下一輪遍歷。 一般的 BFS 實作為了避免 `list.pop(0)` 這種燒雞寫法主要有兩種方式:使用 `collections.deque` 雙向佇列的 `deque.popleft`,他的複雜度是 $O(1)$;或者直接不從左邊移除元素,單純用一個指標指向當前遍歷到的地方,雖然感覺空間會燒雞,但仔細一想時間空間都是 $O(n)$,因此不會有問題,但總歸是浪費了空間。 以下這種 BFS 則是我很喜歡的一種實作方式(而且是原創),在兩種解法中間取了個折衷,避免了連續進行 `pop` 這種常數大的列表操作,又不會保留所有已經遍歷完的節點,最讚的是他可以同時維護當前的遍歷層數。 $Time: O(n), Space: O(n)$ ```python= def main(): from sys import stdin e = stdin.readline n = int(e()) pa = [None] * n indeg = [0] * n cur = [] # bfs由葉子開始 for i in range(n): child = map(int, e().split()) indeg[i] = cn = next(child) if cn == 0: # 子節點數為零就是葉子 cur.append(i) else: for j in child: # 讓每個子節點都找得到爸爸 pa[j - 1] = i ans = height = 0 while cur: nxt = [] ans += height * len(cur) # 計算當層高度總和 for i in cur: p = pa[i] if p is None: # 到根就停 break # bfs擴散 indeg[p] -= 1 if indeg[p] == 0: nxt.append(p) cur = nxt height += 1 print(i + 1) # 最後一個掃到的就是根 轉成1-based print(ans) # 節點高度總和 main() ``` ### [P-3-2. 括弧配對](https://judge.tcirc.tw/problem/d026) 括弧九成以上都是 Stack 題,這題如果只有一種括弧的話可以將 Stack 簡化為計數,達到常數空間,但有三種括弧的話就沒辦法了。 實作小技巧是將六種字元分別對應到一個數字,同時使左括號的編號 +3 後就是對應的右括號編號,這樣一來判斷就好做了!其中將 `\n` 也編號,並在 Stack 初始化時放個對應的數字,這樣就不用將輸入字串 `str.rstrip`,省下一些時間。 $Time: O(n), Space: O(n)$ ```python= def main(): from sys import stdin, stdout ans = [] trans = "([{)]}\n".index # 將括號轉換成對應編號的函式 for s in stdin: # 每行輸入 保留結尾換行 stk = [6] # sentinel 與\n配對 for c in map(trans, s): # 逐字轉換成編號 if c < 3: # 左括號 stk.append(c + 3) elif stk.pop() != c: # 右括號 # 配對不上 ans.append("no") break else: # 全部配對完才是合法 ans.append("yes" if len(stk) <= 1 else "no") stdout.write("\n".join(ans)) main() ``` ### [Q-3-3. 加減乘除](https://judge.tcirc.tw/problem/d027) 再次體會到 Python 的美妙之處。 需要注意的是 Python 的預設除法是真除法,依題意需要替換成整數除法。 $Time: O(n), Space: O(n)$ ```python= print(eval(input().replace("/", "//"))) ``` 好,認真來寫寫。 題目限制所有數字都是一位正整數,只會有加減乘除,不會有括號,那根本不用用到 Stack。 將 `+`, `-` 分隔的一個個只含乘除法的子字串稱為區塊,易知每個區塊之間不會互相影響,因此可以直接累加,而將每個區塊的結果初始化為 $1$ 的話,就只有遇到 `/` 的時候需要做除法,因為遇到 `+`, `-` 的時候必定剛初始化完,會是 $1 \times ...$ 的形式。 實作上有一點需要特別注意:為何不要根據當前是 `+`, `-` 來初始化 `cur` 為 $1$, $-1$ 就好了,還能少用一個變數 `sub`?確實可以,但 `//` 除法是向下取整,就要改成向 $0$ 取整才行,最簡單的改法是將 `cur //= c` 改成 `cur = int(cur / c)`,就是有點醜。 實作勇士可以來挑戰我出的 [暴力又被TLE 想讓人滅台](https://hackmd.io/dUgOKVDmR_aSp68oXUQDxw),可以到 [APCSS Judge](https://apcs-simulation.com/problem/apcs1703_py) 上測試。 $Time: O(n), Space: O(1)$ ```python= def main(): from sys import stdin ans, cur = 0, 1 # 加總後的答案, 目前在處理的小區塊 sub = div = False # 前一個+-是不是-, 前一個+-*/是不是/ for c in stdin.readline(): # 保留換行字元 if c.isdigit(): # 是數字 c = int(c) if div: # 只有/需要做除法 cur //= c else: # 因為cur初始化為1 不管*+-都直接*就好 cur *= c else: # 是運算子 if c in "+-\n": # 進到一個新的區塊 結算前一區塊 if sub: # 前一個區塊做的是減法 ans -= cur else: # 前一個區塊做的是加法 ans += cur sub = (c == "-") cur = 1 # 初始化為 1 方便運算 div = (c == "/") # 接下來是否做除法 print(ans) main() ``` ### [P-3-4. 最接近的高人](https://judge.tcirc.tw/problem/d028) 設好 sentinel,套上一些 Stack 優化。 $Time: O(n), Space: O(n)$ ```python= def main(): from sys import stdin e = stdin.readline n = int(e()) l = map(int, e().split()) ans = 0 stk = [] append, pop = stk.append, stk.pop ph, pi = float("INF"), -1 # 不可能超越的sentinel 位置-1 for i, v in enumerate(l): # 維護身高的單調對列 while ph <= v: ph, pi = pop() ans += i - pi # 更新答案 append((ph, pi)) ph, pi = v, i print(ans) main() ``` ### [Q-3-5. 帶著板凳排雞排的高人](https://judge.tcirc.tw/problem/d029) 就是上一題加上二分搜。 為了要能使用 `bisect`,我們要讓原本遞減的單調棧變成遞增,加上負號就好了,可以利用 `tuple` 的比較特性,把高度放在前面一欄。 當然另一種方式是分成兩個 `list`,一個存高度、一個存位置,因為 `tuple` 的比較速度慢,所以這是效率最高的方式;或是用 Python 3.10 後 `bisect` 加入的 `key`,但自訂函數比 tuple 比較還慢。 為了二分搜,不能將頂端元素另外快取。 $Time: O(n \log_2 n), Space: O(n)$ ```python= def main(): from sys import stdin from bisect import bisect_right e = stdin.readline n = int(e()) l = zip(map(int, e().split()), map(int, e().split())) ans = 0 # 身高取負號 遞增才能用bisect stk = [(float("-INF"), -1)] # 不可能超越的sentinel 位置-1 for i, (v, x) in enumerate(l): # 板凳加上身高 拿去二分搜 idx = bisect_right(stk, (-x-v, )) - 1 ans += i - stk[idx][1] - 1 # 維護身高的單調對列 while stk[-1][0] >= -v: stk.pop() stk.append((-v, i)) print(ans) main() ``` <span id="030"></span> ### [P-3-6. 砍樹](https://judge.tcirc.tw/problem/d030) 我的 [前一篇學習歷程](https://hackmd.io/UtqevkNyQDeXzkCzuLemdw?view#54-%E5%84%AA%E5%8C%96stack-%E9%81%BF%E5%85%8D%E9%81%9E%E8%BF%B4:~:text=%E7%9A%84%E5%84%AA%E5%8C%96%E3%80%82-,h028%20%E7%A0%8D%E6%A8%B9,-%E9%80%99%E9%A1%8C%E6%AF%94%E8%BC%83) 也有收錄這題 APCS 考古題。 這題比較明顯是純 Stack的題目。 比較有意思的是我 one-pass 的方式,輸入的資訊都維持 iterator 的狀態,遍歷時遍歷要砍的樹 `cur` 的右邊那棵樹 `nxt`,以 `nxt` 和前面未砍的樹 `last` (即為 `stack[-1]`)的位置判斷 `cur` 是否能往左右砍倒。 左側邊界條件,也就是 Stack 內的樹全都被砍空的情況,我在 `stack[0]` (也就是初始化時的last)放了一顆無限高的樹,因此他跟吳剛的桂樹一樣不可能被砍倒;右側邊界條件,則是用 `itertools.chain` 在輸入的最後安插了一顆位置 `R` 高度 `None` 的樹(使用None是因為其進行大部分運算時皆會報錯,便於 debug),因為只會討論 `nxt` 的前一棵樹是否能砍,右邊界 `(R, None)` 時沒有更右邊的樹了,因此程式不會考慮這顆「樹」。 $Time: O(n), Space: O(n)$ ```python= def main(): from sys import stdin from itertools import chain e = stdin.readline n, R = map(int, e().split()) p = map(int, e().split()) h = map(int, e().split()) c = m = 0 # 可砍數量 最大高度 # 初始化stack stk = [] # 儲存暫時沒辦法砍的樹的stack extend, pop = stk.extend, stk.pop # 快取函數 lp, lh = 0, float("INF") # 快取stack的最後一項 初始化為 sentinel # 製作所有樹的iterator it = chain(zip(p, h), ((R, None), )) # 加入右邊界 cp, ch = next(it) # 目前這棵樹 初始化先取第一項 for np, nh in it: # 取下一顆樹 討論目前這棵樹(cp, ch)能否被砍 # 往右砍 或 往左砍 if ch <= np - cp or ch <= cp - lp: # 可以直接砍掉 c += 1 m = max(ch, m) # 看前面的樹能不能往右砍 while lh <= np - lp: c += 1 m = max(lh, m) lh, lp = pop(), pop() else: # 往左往右都壓到 沒辦法砍 先加到stack extend((lp, lh)) lp, lh = cp, ch cp, ch = np, nh print(c, m, sep="\n") main() ``` ### [P-3-7. 正整數序列之最接近的區間和](https://judge.tcirc.tw/problem/d031) [`itertools.tee`](https://docs.python.org/zh-tw/3/library/itertools.html#itertools.tee) 是個特別有意思的函數,`iterator` 的特色就是遍歷過的東西直接丟掉,而 `tee` 則是提供了複製 `iterator` 的功能,複製出來的 $n$ 個 `iterator` 會共用傳入的那個 `iterator`,利用 linked-list 的概念(在 Python 3.11 以前似乎是用一個 queue 來維護?),暫存跑得最快和最慢的 `iterator` 之間的值,如果所有 `iterator` 都已經將某個值讀取完,這個值就會失去引用並被自然地回收掉。 這題因為維護的區間長度不固定,最長可能到 $n$,所以空間複雜度是 $O(n)$。 $Time: O(n), Space: O(n)$ ```python= def main(): from sys import stdin from itertools import tee e = stdin.readline n, k = map(int, e().split()) # 將iterator複製 l, p = tee(map(int, e().split()), 2) nxt = p.__next__ # 快取 左區間的取值函式 u = nxt() # 取左邊第一項 ans = cnt = cur = 0 for i, v in enumerate(l): cur += v # 向右擴展區間 while cur > k: # 太大 縮減左側區間 cur -= u u = nxt() # 更新最大值 if cur > ans: ans = cur cnt = 1 elif cur == ans: cnt += 1 print(ans, cnt, sep="\n") main() ``` ### [P-3-8. 固定長度區間的最大區段差](https://judge.tcirc.tw/problem/d032) 固定區間長度 $k$,`itertools.tee` 能使空間複雜度降到 $O(k)$。 [`collections.deque`](https://docs.python.org/zh-tw/3.13/library/collections.html#collections.deque) 是用 doubly-linked-list 實作的,因此兩端的操作非常快。一般對 `list` 右端的操作加個 `left` 尾綴就會是 `deque` 的左端操作,像是 `popleft`, `appendleft`, `extendleft` 等等,比較特殊的是 `rotate` 函數,會將整個 `deque` 進行輪轉,`deque.rotate(1)` 等價於 `deque.appendleft(deque.pop())`,負數則反之。 $Time: O(n), Space: O(k)$ ```python= def main(): from sys import stdin from collections import deque # 使用deque雙向隊列 方便把第一項移除 from itertools import islice, tee e = stdin.readline n, k = map(int, e().split()) l, p = tee(map(int, e().split()), 2) # 先取第一個區間 M, m = deque(maxlen=k), deque(maxlen=k) for i in islice(l, k): # 維護嚴格遞減/遞增 while M and M[-1] < i: M.pop() while m and m[-1] > i: m.pop() M.append(i) m.append(i) M0, m0 = M[0], m[0] ans = M0 - m0 for i, j in zip(l, p): # 先處理被丟掉的首項 if j == M0: M.popleft() # 首項是最大值 if j == m0: m.popleft() # 首項是最小值 # 處理最後加入的一項 跟取第一個區間時一樣 while M and M[-1] < i: M.pop() while m and m[-1] > i: m.pop() M.append(i) m.append(i) # 更新答案 M0, m0 = M[0], m[0] ans = max(ans, M0 - m0) print(ans) main() ``` ### [P-3-9. 最多色彩帶](https://judge.tcirc.tw/problem/d033) 使用一個變數 `cur` 維護當前色彩種類數,看當前增加、減少的顏色其數量是否在 $0, 1$ 之間切換。 $Time: O(n), Space: O(k)$ ```python= def main(): from sys import stdin from itertools import islice, tee e = stdin.readline n, k = map(int, e().split()) l, p = tee(map(int, e().split()), 2) c = {} cur = 0 # 先取第一個區間 for i in islice(l, k): v = c.get(i, 0) c[i] = v + 1 if v == 0: cur += 1 ans = cur for i, j in zip(l, p): # 先處理被丟掉的首項 v = c.get(j, 0) c[j] = v - 1 if v == 1: cur -= 1 # 更新此項 v = c.get(i, 0) c[i] = v + 1 if v == 0: cur += 1 # 更新答案 if cur > ans: ans = cur print(ans) main() ``` ### [P-3-10. 全彩彩帶](https://judge.tcirc.tw/problem/d034) 這題要求全彩,因此可以先將包含所有顏色的計數器字典建立好,也節省後續存取的時間。 為了減少記憶體使用也同時增加效率,將區間長度維護在「當前的答案」以下且不全彩的狀態,如果當前區間是全彩就嘗試縮減左側。 $Time: O(n), Space: O(n)$ ```python= def main(): from sys import stdin e = stdin.readline ans = int(e()) l = tuple(map(int, e().split())) c = dict.fromkeys(l, 0) # 建立所有顏色的計數器 k = len(c) # 全彩顏色數 l, p = iter(l), iter(l) nxt = p.__next__ # 目前區間長度, 區間內顏色種類數 s = t = 0 for i in l: # 維護區間長度 <= ans if s == ans: # 移除左端點 j = nxt() v = c[j] c[j] -= 1 if v == 1: t -= 1 else: s += 1 # 新增右端點 v = c[i] c[i] += 1 if v == 0: t += 1 # 達到全彩 if t == k: # 嘗試收斂左側區間 直到不再全彩 while True: # 移除左端點 s -= 1 j = nxt() v = c[j] c[j] -= 1 if v == 1: t -= 1 break # 不再全彩 # 目前的區間(長度s)不全彩 # 但加上剛移除的最後一個左端點就能全彩 # 所以 s + 1 全彩 if s < ans: ans = s + 1 print(ans) main() ``` ### [Q-3-11. 最長的相異色彩帶](https://judge.tcirc.tw/problem/d035) 維護一個異色區間,遇到區間內已經有的顏色就縮減區間直到左端點遇到該顏色。 $Time: O(n), Space: O(n)$ ```python= def main(): from sys import stdin from itertools import tee e = stdin.readline n = int(e()) l, p = tee(map(int, e().split()), 2) nxt = p.__next__ s = set() ans = cur = 0 for i in l: if i in s: # 同色 while True: # 移除左端點 j = nxt() if j == i: # 跟即將加入的右端點的同色 break # 不用移除 直接跳出 s.remove(j) cur -= 1 else: # 異色 需要加入 s.add(i) cur += 1 # 更新答案 if cur > ans: ans = cur print(ans) main() ``` ### [Q-3-12. 完美彩帶](https://judge.tcirc.tw/problem/d036) 有兩種想法:維護長度為 $k$ 的區間,檢查顏色種類是否為 $k$;或者反過來維護一個盡可能長的異色區間,看顏色種類(即區間長度)是否為 $k$。 第一種方式: $Time: O(n), Space: O(k)$ ```python= def main(): from sys import stdin from itertools import islice, tee e = stdin.readline k, n = map(int, e().split()) l, p = tee(map(int, e().split()), 2) c = {} cur = 0 # 先取第一個區間 for i in islice(l, k): v = c.get(i, 0) c[i] = v + 1 if v == 0: cur += 1 ans = 1 if cur == k else 0 for i, j in zip(l, p): # 先處理被丟掉的首項 v = c.get(j, 0) c[j] = v - 1 if v == 1: cur -= 1 # 更新此項 v = c.get(i, 0) c[i] = v + 1 if v == 0: cur += 1 # 更新答案 if cur == k: ans += 1 print(ans) main() ``` 第二種方式: $Time: O(n), Space: O(k)$ ```python= def main(): from sys import stdin from itertools import tee e = stdin.readline k, n = map(int, e().split()) l, p = tee(map(int, e().split()), 2) nxt = p.__next__ s = set() ans = cur = 0 for i in l: if i in s: # 同色 while True: # 移除左端點 j = nxt() if j == i: # 跟即將加入的右端點的同色 break # 不用移除 直接跳出 s.remove(j) cur -= 1 else: # 異色 需要加入 s.add(i) cur += 1 # 更新答案 if cur == k: ans += 1 print(ans) main() ``` ### [Q-3-13. X差值範圍內的最大Y差值](https://judge.tcirc.tw/problem/d037) 先照 $x$ 排序,維護 $\Delta x \le k$ 的區間,用兩個單調雙向對列紀錄當前 $\Delta y$。 $Time: O(\text{sort} + n), Space: O(n)$ ```python= def main(): from sys import stdin from collections import deque e = stdin.readline n, k = map(int, e().split()) # 對x座標做排序 l = sorted(zip(map(int, e().split()), map(int, e().split()))) # 左端點iterator p = iter(l) nxt = p.__next__ px, py = nxt() ans = 0 M = deque() # y最大值單調隊列 m = deque() # y最小值單調隊列 for x, y in l: # 維護兩個單調隊列 while M and M[-1] < y: M.pop() while m and m[-1] > y: m.pop() M.append(y) m.append(y) # 移除左邊的點 保持x區間在k以內 while M and x - px > k: if M[0] == py: M.popleft() if m[0] == py: m.popleft() px, py = nxt() # 更新答案 ans = max(ans, M[0] - m[0]) print(ans) main() ``` ### [Q-3-14. 線性函數](https://judge.tcirc.tw/problem/d038) 使用 Stack 維護斜率的單調性,Stack 最後狀態即為 `F(x)`。 最後讀取 Stack 的邏輯很漂亮,預取下一條線的交點,並用無限遠的交點避免觸發邊界條件。與 [P-3-6. 砍樹](#030) 的那顆吳剛的無限高桂樹有著異曲同工之妙。 $Time: O(\text{sort} + n + m), Space: O(n + m)$ ```python= def main(): from sys import stdin e = stdin.readline lm = float("-INF") # left most n, m = map(int, e().split()) # 排列: 斜率升冪 y截距降冪 l = iter(sorted((tuple(map(int, e().split())) for i in range(n)), key=lambda x: (x[0], -x[1]))) # 取第一項 px = lm pa, pb = next(l) stk = [] for a, b in l: # 斜率與前一相同 y截距較小 用不到 if a == pa: continue # 找出被新線覆蓋掉的舊線 丟掉 while True: # 求新線與最後一條舊線交點之x (沒被卡精度) x = (b - pb) / (pa - a) # 檢查覆蓋情形 if x <= px: # 覆蓋掉 px, pa, pb = stk.pop() else: # 沒覆蓋掉 跳出 break # 新增新線 stk.append((px, pa, pb)) px, pa, pb = x, a, b # 新增最後一條新線 & 右界sentinel stk.extend(((px, pa, pb), (float("INF"), None, None))) ans = 0 it = iter(stk) _, a, b = next(it) nx, na, nb = next(it) for c in sorted(map(int, e().split())): while c >= nx: # 找到此x座標所屬的最大值之線 a, b = na, nb nx, na, nb = next(it) ans += c * a + b # x代入函數求解 print(ans) main() ``` ### [P-4-1. 少林寺的代幣](https://judge.tcirc.tw/problem/d042) 基本的貪心。 來講講 `divmod` 函式(在 [Q-1-4. 支點切割](#004) 的程式碼中就有使用到),`divmod(a, b)` 就等價於 `(a // b, a % b)`,但少做一次除法,就醬。 $Time: O(t), Space: O(1)$ ```python= def main(): from sys import stdin stdin.readline() # 測資數量 用不到 l = (50, 10, 5) # 兌換代幣 由大到小 for x in map(int, stdin): # 每行輸入轉成 int ans = 0 for i in l: # 貪心地兌換 q, x = divmod(x, i) ans += q print(ans + x) # 加上兌換為 1 元的 main() ``` ### [P-4-2. 笑傲江湖之三戰](https://judge.tcirc.tw/problem/d043) 基本的貪心。 $Time: O(\text{sort} + n), Space: O(n)$ ```python= def main(): from sys import stdin e = stdin.readline e() # 人數 用不到 # 貪心地升冪排序 a = iter(sorted(map(int, e().split()))) b = sorted(map(int, e().split())) j = next(a) # 預取一項 用於比較 ans = 0 for i in b: if i > j: # 如果每場都贏 最後會 StopIteration j = next(a, None) # None 不會被比較到 ans += 1 print(ans) main() ``` ### [P-4-3. 十年磨一劍](https://judge.tcirc.tw/problem/d044) $Time: O(\text{sort} + n), Space: O(n)$ ```python= def main(): from sys import stdin from itertools import accumulate e = stdin.readline e() # 劍數 用不到 # 升冪排序後 前綴和總和 print(sum(accumulate(sorted(map(int, e().split()))))) main() ``` ### [P-4-4. 幾場華山論劍](https://judge.tcirc.tw/problem/d045) `itemgetter(1)` 等價於 `lambda x: x[1]` 但更快。 $Time: O(\text{sort} + n), Space: O(n)$ ```python= def main(): from sys import stdin from operator import itemgetter e = stdin.readline n = int(e()) # 按照結束時間排序 l = sorted((tuple(map(int, e().split())) for i in range(n)), key=itemgetter(1)) ans = 0 end = -1 # 結束時間初始化為 -1 第一場活動必取 for s, t in l: if s > end: # 貪心地參加 end = t # 更新結束時間 ans += 1 # 場次+1 print(ans) main() ``` ### [P-4-5. 嵩山磨劍坊的問題](https://judge.tcirc.tw/problem/d046) 可以先用 [`filter`](https://docs.python.org/zh-tw/3.13/library/functions.html#filter) 去掉權重為 $0$ 的訂單,`filter(function, iterable)` 等價於 `(i for i in iterable if function(i))`,但因為很少有適合的 `function` 而需要使用 `lambda`,因此我通常會寫後者,這題剛好只需要把 $0$ 去掉,而 `bool(int)` 等價於 `int != 0`,因此只需要用 `itemgetter` 取 `tuple` 的第二個即可。順帶一提 `filter(None, iterable)` 等價於 `(i for i in iterable if i)`。 而後面 `key=lambda x: x[0] / x[1]` 的部分我真的想不到甚麼好的函式可以達成一樣的功能,`operator.truediv` 等價於 `lambda a, b: a / b`,跟所需差了一次 unpacking,而 Python 也沒有內建將 function 加上一次 unpacking 的 decorator,只有 `starmap` 這樣特定的例子。 $Time: O(\text{sort} + n), Space: O(n)$ ```python= def main(): from sys import stdin from operator import itemgetter e = stdin.readline e() # 依照 時間/權重 由小排到大 (權重為0直接刪除) l = sorted(filter(itemgetter(1), zip(map(int, e().split()), map(int, e().split()))), key=lambda x: x[0] / x[1]) ans = now = 0 for t, w in l: now += t # 累加時間 ans += now * w # 乘上權重 print(ans) main() ``` ### [Q-4-6. 少林寺的自動寄物櫃](https://judge.tcirc.tw/problem/d047) 和上一題只差在累加時間和計算權重的順序。 $Time: O(\text{sort} + n), Space: O(n)$ ```python= def main(): from sys import stdin from operator import itemgetter e = stdin.readline e() # 依照時間/權重由小排到大 (權重為0直接刪除) l = sorted(filter(itemgetter(1), zip(map(int, e().split()), map(int, e().split()))), key=lambda x: x[0] / x[1]) ans = now = 0 for t, w in l: # 與上一題的差別是下面兩行調換 ans += now * w now += t print(ans) main() ``` ### [P-4-7. 岳不群的併派問題](https://judge.tcirc.tw/problem/d048) [`heapq`](https://docs.python.org/zh-tw/3.13/library/heapq.html) 是少數有純 Python 程式碼的內建函式庫(`bisect` 也有,只是他的實作很簡單),如果用 `dir(heapq)` 看一下會發現除了 `heapq.__all__` 裡面列出來的 `heappush`, `heappop`, `heapify`, `heapreplace`, `merge`, `nlargest`, `nsmallest`, `heappushpop` 以外,還有 `_heapify_max`, `_heappop_max`, `_heapreplace_max`, `_siftdown`, `_siftdown_max`, `_siftup`, `_siftup_max` 這些有趣的東西,不只有各種操作的 max-heap 版本,還有用於維護 `heap` 關鍵的 `_siftdown(heap, startpos, pos)` 和 `_siftup(heap, pos)`,了解這些函式的使用方法就可以做出一些更自由的 `heap` 操作。 話雖這麼說,但我目前也只寫過 [一題](https://colab.research.google.com/drive/1oXmT17Q90hgXNsgoFSGUpwU4nMThOemW?hl=en#scrollTo=nOrrzvdLJ4XU) 有用到這些操作的,主要是為了在不改變值的情況下由 max-heap 切換到 min-heap。 如何找到 `heapq` 純 Python 程式碼?以 Python IDLE 為例,找到 `File -> Open Module`,然後輸入套件名稱就好了! ![image](https://hackmd.io/_uploads/SkVfF0kOJx.png) 注意到 [`heapq.heapify`](https://docs.python.org/zh-tw/3.13/library/heapq.html#heapq.heapify) 的複雜度是 $O(n)$ 且 in-place ,而不是 $O(n \log_2 n)$,因此單純輸入大量數據時可以先全部加入再 `heapify`。 $Time: O(n \log_2 n), Space: O(n)$ ```python= def main(): from sys import stdin from heapq import heapify, heappop, heappush e = stdin.readline n = int(e()) q = list(map(int, e().split())) heapify(q) # O(n) ans = 0 for _ in range(n - 1): # 總共合併 n - 1 次 x = heappop(q) + heappop(q) # 貪心地取最小兩個合併 heappush(q, x) # 合併完放回去 ans += x # 增加成本 print(q[0], ans, sep="\n") main() ``` ### [P-4-9. 基地台](https://judge.tcirc.tw/problem/d049) 常見二分搜考法:對答案二分搜。 [`range(start, stop, step)`](https://docs.python.org/zh-tw/3.13/library/stdtypes.html#ranges) 是一個很有意思的 object,他除了最一般的當作 `iterable` 餵給 `for` 以外,還支援隨機存取,就像一個 `list` 或 `tuple` 一樣,應該說,他的本質就是一個 `sequence`,只是其中所有項都可以從給定的三個 arguments 用數學計算出來,因此記憶體效率十分高。 因為他是一個 `sequence`,所以理所當然地支援 `__iter__`, `__getitem__`, `__len__`, `__reversed__`(當然不能 `__setitem__`),甚至可以用 `slice` 取值、用 `index` 求某個數字在數列裡的位置,超級好用!既然能夠隨機存取,那就當然能夠用 `bisect` 下去二分搜,然而 [`bisect`](https://docs.python.org/zh-tw/3/library/bisect.html#bisect.bisect_left) 直到 Python 3.10 才加入 `key` 引數,因此「對答案二分搜」這個技巧在 APCS 裡還是只能手刻。 反正也不用怕忘記怎麼寫,可以直接把模組原始碼打開,`bisect` 跟 `heapq` 一樣也有寫純 Python 版本。 假設要搜 `[le, ri)` 這個區間,建議用 `bisect(range(ri), val, lo=le, key=func)` 這種寫法,雖然比較不直覺,但這樣可以保證 `range(ri)[x] = x`,不然就得寫 `le + bisect(range(le, ri), val, key=func)`,前面那個 `le +` 豈不是更不直覺? $Time: O(\text{sort} + n \log_2 R), Space: O(n)$ ```python= def main(): from sys import stdin from bisect import bisect_left e = stdin.readline def cover(r: int) -> bool: cnt, cov = 0, -1 for i in p: if i > cov: cnt += 1 # 基地台用完 不合 if cnt > k: return False cov = i + r return True # 完整覆蓋 n, k = map(int, e().split()) p = sorted(map(int, e().split())) # 排序各點 才好用掃描線 # >= Python 3.10 bisect(key) # r = 區間長度 // k + 1 必定可完整覆蓋 print(bisect_left(range((p[-1] - p[0]) // k), True, lo=1, key=cover)) return # 下面是舊版本的實作 # < Python 3.10 # 維護 [le, ri) 區間 le, ri = 1, (p[-1] - p[0]) // k while le < ri: mid = (le + ri) >> 1 if cover(mid) is True: ri = mid else: le = mid + 1 print(le) main() ``` ### [P-4-11. 線段聯集](https://judge.tcirc.tw/problem/d050) 基本掃描線。 別忘了結算最後一個線段。 $Time: O(\text{sort} + n), Space: O(n)$ ```python= def main(): from sys import stdin e = stdin.readline n = int(e()) # 將 (s, t) 線段聯集排列 l = sorted(tuple(map(int, e().split())) for i in range(n)) ans = 0 s = t = -1 # 最後一個線段的起終點 for i, j in l: if i < t: # 重疊 更新終點 if j > t: t = j else: # 未重疊 結算距離 更新新線段 ans += t - s s, t = i, j # 結算最後一個線段 ans += t - s print(ans) main() ``` ### [P-4-12. 一次買賣](https://judge.tcirc.tw/problem/d051) 就是當前值減去歷史最小值取大。 $Time: O(n), Space: O(1)$ ```python= def main(): from sys import stdin e = stdin.readline e() # n 用不到 l = map(int, e().split()) ans = 0 mn = next(l) # 最小值 初始化取第一項 for v in l: # 更新最小值 if v < mn: mn = v # 更新所求 (最大盈餘) v -= mn # 盈餘 if v > ans: ans = v print(ans) main() ``` 也可以寫成這樣: ```python= def main(): from sys import stdin from itertools import accumulate, tee e = stdin.readline e() # n 用不到 l, p = tee(map(int, e().split())) # 複製兩份 比較好寫 print(max(map(int.__sub__, l, accumulate(p, min)))) main() ``` ### [P-4-13. 最大連續子陣列](https://judge.tcirc.tw/problem/d052) 經典卡丹算法。 以下這種寫法可以用到「更新答案」和「負數不轉移」兩個 case 互斥的特性,但得要題目有說空陣列視為 $0$ 才能這樣寫。不敢說快了多少,心情總歸是比較舒暢。 $Time: O(n), Space: O(1)$ ```python= def main(): from sys import stdin e = stdin.readline e() # n 用不到 dp = ans = 0 for i in map(int, e().split()): dp += i if dp > ans: # 更新答案 ans = dp elif dp < 0: # 負數不轉移 dp = 0 print(ans) main() ``` <span id="053"></span> ### [Q-4-8. 先到先服務](https://judge.tcirc.tw/problem/d053) [`heapq.heapreplace`](https://docs.python.org/zh-tw/3.13/library/heapq.html#heapq.heapreplace) 等於先 `heappop` 再 `heappush`,需要和 [`heapq.heappushpop`](https://docs.python.org/zh-tw/3.13/library/heapq.html#heapq.heappushpop) 做出分別,`heappushpop` 就像字面上的一樣是先 `heappush` 再 `heappop`。 $Time: O(n \log_2 m), Space: O(m)$ ```python= def main(): from sys import stdin from heapq import heapreplace e = stdin.readline n, m = map(int, e().split()) l = map(int, e().split()) q = [0] * m for i in l: # 將新的客人交給最早結束的櫃台處理 heapreplace(q, i + q[0]) # 最晚結束的即為所求 print(max(q)) main() ``` ### [Q-4-10. 恢復能量的白雲熊膽丸](https://judge.tcirc.tw/problem/d054) 確認答案的部分有兩種寫法,一種是線性搜,一種是二分搜,就像是 [P-2-15. 圓環出口](#024) 的概念。兩者的複雜度不一樣,前者是 $O(n)$,worst case 就是闖關成功;後者是 $O(k \log_2 n)$,worst case 是 $k = n$ 且一次只過一關,這樣二分搜的範圍縮減很慢,而且搜滿 $k$ 次。 線性搜: $Time: O(n \times \log_2 \sum l), Space: O(n)$ ```python= def main(): from sys import stdin from bisect import bisect_left e = stdin.readline def check(x) -> bool: # O(n) r = cur = 0 for v in l: cur += v if cur > x: r += 1 # 藥沒了 到不了 if r == k: return False cur = v # 到終點了 return True n, k = map(int, e().split()) l = map(int, e().split()) if k == 0: # 特判: 不能吃藥只能一次幹到底 print(sum(l)) elif k == n - 1: # 特判: 每一回合都吃藥滿能量 print(max(l)) else: k += 1 # 算上一開始也是滿能量 l = tuple(l) # 對答案二分搜 # >= Python 3.10 bisect(key) print(bisect_left(range(sum(l)), True, lo=max(l), key=check)) return # 下面是舊版本的實作 # < Python 3.10 # 維護 [le, ri) 區間 while le < ri: mid = (le + ri) >> 1 if check(mid) is True: ri = mid else: le = mid + 1 print(le) main() ``` 二分搜: $Time: O(n + k \log_2 n \times \log_2 \sum l), Space: O(n)$ ```python= def main(): from sys import stdin from bisect import bisect_right from itertools import accumulate e = stdin.readline def check(x) -> bool: # O(klogn) idx = 0 for _ in range(k): # 搜下一個能量消耗 > x 的點 idx = bisect_right(l, l[idx] + x, lo=idx) - 1 # 成功到底 if idx >= n: return True # 在k+1次以內到不了 return False n, k = map(int, e().split()) l = map(int, e().split()) if k == 0: # 特判: 不能吃藥只能一次幹到底 print(sum(l)) elif k == n - 1: # 特判: 每一回合都吃藥滿能量 print(max(l)) else: k += 1 # 算上一開始也是滿能量 l = tuple(l) le = max(l) # 做前綴和 方便二分搜 # >= Python 3.8 accumulate(initial) l = tuple(accumulate(l, initial=0)) # < Python 3.8 # l = (0, *accumulate(l)) ri = l[-1] # sum(l) # 對答案二分搜 # >= Python 3.10 bisect(key) print(bisect_right(range(ri), False, lo=le, key=check)) return # 下面是舊版本的實作 # < Python 3.10 # 維護 [le, ri) 區間 while le < ri: mid = (le + ri) >> 1 if check(mid) is True: ri = mid else: le = mid + 1 print(le) main() ``` ### [P-4-14. 控制點](https://judge.tcirc.tw/problem/d055) 這一題來比較一下我和吳邦一教授寫法上的小差異。 兩段程式的邏輯完全相同,但有一些小小的差異: 首先輸入的部分,「以空白分隔並轉成數字」這個操作在 Python 內有兩種主流寫法,一種是 `[int(x) for x in input().split()]`,一種是 `map(int, input().split())`,這兩者其實不等價,前者是一個 `list`,後者是一個 `iterator`,前者是等價於 `list(map(int, input().split()))`,又後者等價於 `(int(x) for x in input().split())`。以這題來講,傳入 `zip` 中的只需要是 `iterable` 就好了,而 `zip` 回傳的也是 `iterable`,可以直接傳入接收 `iterable` 的 `sorted`,當然轉成 `list` 再 `list.sort()` 也完全沒問題。 但在 `for` 迴圈寫了 `point[::-1]`,這是許多人會忽視的一點:`list[slice]` 必定會複製並建立一個新的 list object,並造成可觀的效能損耗,如果是單純要 iterate 過去的話,可以使用 `reversed(list)`,`reversed` 回傳的是一個 `iterator`,支援 `list`, `tuple`, `str`, `range` 等常見的物件。 在 [AP325 Python 版](https://hackmd.io/@bangyewu/Hy2kbYLI6/%2FmEi9wqzjRCuj2aixy3Ea2Q#P-4-14-%E6%8E%A7%E5%88%B6%E9%BB%9E2D-max) 中,這題的範例解法: ```python= # P-4-14 2d max subarray, backward, O(nlogn) n = int(input()) px = [int(x) for x in input().split()] py = [int(x) for x in input().split()] point = list(zip(px,py)) point.sort() # smaller x, if x1=x2 then smaller y max_y = -1 # currently max y total = 0 # number of maximal points for x,y in point[::-1]: # backward if y > max_y: total += 1 # a new maximal point max_y = y # update # print(total) ``` $Time: O(\text{sort} + n), Space: O(n)$ ```python= def main(): from sys import stdin from operator import itemgetter e = stdin.readline e() # n 用不到 # x降冪 再y降冪 p = sorted(zip(map(int, e().split()), map(int, e().split())), reverse=True) ans = 0 my = -1 # 目前最大 y 初始化為 -1 確保無法控制任何點 for y in map(itemgetter(1), p): # 因為此點 x <= 右邊所有掃過的 # 所以只有 y > 右邊最大y(前一個控制點之y)時 # 這個點才不會被控制 而必須加進來 # 至於x相同時 y在降冪 所以不會重複取到 if y > my: ans += 1 my = y print(ans) main() ``` ### [P-4-15. 最靠近的一對](https://judge.tcirc.tw/problem/d056) 掃描線的寫法怎麼優化呢?既然 `sortedcontainers.SortedList` 沒有內建,就自己手刻吧! 比較麻煩的是,這題不只要實作 `add`,還有 `remove`、`irange` 要實作。`expand` 函數在很多地方都有用到,就把他獨立出來了;`remove` 的邏輯則由 `_delete` 實際執行,刪除後要確保塊夠長,否則要跟其他塊合併;`irange` 則二分搜後手動於塊與塊之間遍歷即可。 `SortedList` 原始碼: ```python= def _delete(self, pos, idx): _lists = self._lists _maxes = self._maxes _index = self._index _lists_pos = _lists[pos] # 快取目標塊 del _lists_pos[idx] # 直接硬刪 self._len -= 1 # 更新總長度 len_lists_pos = len(_lists_pos) # 目標塊新長度 if len_lists_pos > (self._load >> 1): # 夠長 _maxes[pos] = _lists_pos[-1] # 更新最大值即可 # !seg.update(pos,-1) if _index: # 維護用於 index 取值的線段樹 child = self._offset + pos while child > 0: _index[child] -= 1 child = (child - 1) >> 1 _index[0] -= 1 elif len(_lists) > 1: # 不夠長 有其他塊可以合併 if not pos: # 如果左邊沒塊就將索引往右一格 pos += 1 prev = pos - 1 # 向左合併 _lists[prev].extend(_lists[pos]) _maxes[prev] = _lists[prev][-1] # 刪除右塊 del _lists[pos] del _maxes[pos] # 需要重種 index 線段樹 del _index[:] self._expand(prev) # 維護合併後新塊的長度 elif len_lists_pos: # 不夠長 但只有自己這塊 代表總長度尚短 _maxes[pos] = _lists_pos[-1] # 更新最大值即可 else: # SortedList 全空 del _lists[pos] del _maxes[pos] del _index[:] ``` 以下的程式碼中比較難以理解的可能是 `expand_i` 這個變數以及其相關邏輯,他的功能是「處理上一輪 iterate 的 `expand(i)` 操作」,並負責把 `i` 暫存起來。為何要讓 `expand` 操作隔一輪延後執行?在 `remove` 後,如果當前塊太短,就會跟左邊的塊合併,合併後需要呼叫一次 `expand` 維護新塊長,但因為左邊的塊較右塊長,下一輪要刪除的值可能在左塊尾端,`expand` 後會跑去右塊,這樣索引就爛掉了。因此為了避免這種情況發生,間隔一個 iterate,也就是保證間隔一塊以上即可。 $Time: O(n \sqrt n), Space: O(n)$ ```python= def main(): from sys import stdin from bisect import bisect_right, insort_right e = stdin.readline inf = float("INF") def expand(i): # SortedList._expand chunk = sl[i] # 取要 expand 的塊 if len(chunk) > (LOAD << 1): # 如果塊過長 # 切半 half = chunk[LOAD:] del chunk[LOAD:] mx[i] = chunk[-1] # 更新左半最大值 # 插入右半塊 sl.insert(i + 1, half) mx.insert(i + 1, half[-1]) n = int(e()) LOAD = int((n >> 1) ** 0.5) # SortedList 理想塊長 # x升冪 再y升冪 l = sorted(tuple(map(int, e().split())) for i in range(n)) ans = inf # 先前的 (y, x) 照 y 升冪排序 sl = [] # 模擬 SortedList mx = [] # SortedList 每一區塊的最後一項 方便二分搜 rmv = [] # 要刪除的點 按照塊index分類 expand_i = None # 延後一個 iterate 再做 expand chunk for x, y in l: if not mx: # 如果 SortedList 為空 # 直接插入 val = (y, x) sl.append([val]) mx.append(val) continue # 二分搜y值下界 val = (y - ans, inf) if val < mx[-1]: # 如果下界不存在就跳過 i = bisect_right(mx, val) chunk = sl[i] nc = len(chunk) j = bisect_right(chunk, val) py, px = chunk[j] rmv.clear() ci = js = None while py < y + ans: if px <= x - ans: # 出界 要刪除 if i != ci: # 要刪除的點在新的塊內 ci, js = i, [] rmv.append((ci, js)) js.append(j) else: # 更新最短距離 ans = min(ans, abs(x - px) + abs(y - py)) # 遍歷下一個點 j += 1 if j == nc: # 此塊結束 j = 0 i += 1 # 找下一塊 if i == len(sl): break # 到底 chunk = sl[i] nc = len(chunk) py, px = chunk[j] # 取點 # 由右側往左側刪除 index 才不會改變 for i, js in reversed(rmv): # 單塊內刪除 chunk = sl[i] for j in reversed(js): del chunk[j] # expand 上一個該 expand 的塊 if expand_i is not None: expand(expand_i) expand_i = None # 維護塊長 nc = len(chunk) if nc > (LOAD << 1): # 長度夠 mx[-1] = chunk[-1] # 維護最大值即可 elif len(mx) > 1: # 長度不夠 可以合併 if i: i -= 1 # 預設向左合併 sl[i].extend(sl[i + 1]) mx[i] = sl[i][-1] del sl[i + 1] del mx[i + 1] expand_i = i # 合併後須要維護塊長 elif nc: # 長度不夠 但只有此塊 即SortedList本身很短 mx[-1] = chunk[-1] # 維護最大值即可 else: # 清空 SortedList del sl[i] del mx[i] # expand 最後一組 if expand_i is not None: expand(expand_i) expand_i = None # 插入新點 val = (y, x) if not mx: # 如果 SortedList 為空 # 直接插入 val = (y, x) sl.append([val]) mx.append(val) elif val >= mx[-1]: # 特判插入尾端 sl[-1].append(val) mx[-1] = val expand(len(mx) - 1) # 維護塊長 else: # 插入中間 i = bisect_right(mx, val) insort_right(sl[i], val) expand(i) # 維護塊長 print(ans) main() ``` 還有 $O(n \log_2 n)$ 的分治解、$O(n)$ 的神奇隨機解、一次 `sort` 後線性解的「人類智慧」之極致展現,就賣個關子留到 [第五章節](https://hackmd.io/@ericshen19555/AP325_Mastering_to_Peak_Part2#P-5-5-%E6%9C%80%E9%9D%A0%E8%BF%91%E7%9A%84%E4%B8%80%E5%B0%8D) 吧,不然那章好空虛 qwq。 ### [Q-4-16. 賺錢與罰款](https://judge.tcirc.tw/problem/d057) 經典貪心。 $Time: O(\text{sort} + n), Space: O(n)$ ```python= def main(): from sys import stdin e = stdin.readline n = int(e()) # 貪心策略: 製作時間短的先 l = sorted(zip(map(int, e().split()), map(int, e().split()))) now = ans = 0 for t, d in l: now += t # 累加時間 ans += d - now # 賺賠錢 print(ans) main() ``` 稍微觀察一下可以發現 `d` 直接求和就好了,跟 `t` 根本沒有關係。 ```python= def main(): from sys import stdin from itertools import accumulate e = stdin.readline e() # n 用不到 # 時間升冪排序 對前綴和累加 是扣款 # 工資直接加總就好了 print(-sum(accumulate(sorted(map(int, e().split())))) + sum(map(int, e().split()))) main() ``` ### [Q-4-17. 死線高手](https://judge.tcirc.tw/problem/d058) 中一中的裁判機不快,新的 Judge 又沒有對 Python 開比較大的時限,這題就算用上了 PyPy 再加上輸出優化還是過不了。 可以到 [惠文高中的 Judge](https://judge.hwsh.tc.edu.tw/ShowProblem?problemid=a205) 測試。 $Time: O(t \times (\text{sort} + n)), Space: O(n)$ ```python= def main(): from sys import stdin e = stdin.readline e() # 測資數量 用不到 for _ in stdin: # 每筆測資輸入的第一行是n 用不到 t = map(int, e().split()) # 時間 d = map(int, e().split()) # 死線 now = 0 # 死線早的先處理 for d, t in sorted(zip(d, t)): now += t # 累加時間 if now > d: # 超過死線 print("no") break else: # 全部完成 print("yes") main() ``` ### [Q-4-18. 少林寺的櫃姐](https://judge.tcirc.tw/problem/d059) 就是 [Q-4-8. 先到先服務](#053) 再加上「對答案二分搜」。 看到這題的時間複雜度,我實在得澄清一點:$\log_2^2n = \log_2 (\log_2 n) \ne (\log_2 n)^2$。這兩者不知道甚麼時候開始就在競程領域被瘋狂地誤用,但 $\log$ 可不是三角函數啊! (雖然但是 $\sin^2 x = (\sin x)^2 \ne \sin (\sin x)$ 這種不一致的規定也很解就是了) $Time: O(n \times (\log_2 n)^2), Space: O(n)$ ```python= def main(): from sys import stdin from bisect import bisect_left from heapq import heapreplace e = stdin.readline def check(x) -> bool: # O(nlogx) q = [0] * x for i in l: # 將新的客人交給最早結束的櫃台處理 heapreplace(q, i + q[0]) # 最晚結束的即為所求 return max(q) <= k n, k = map(int, e().split()) l = tuple(map(int, e().split())) # >= Python 3.10 bisect(key) # 下界要設定成 1 不然 check(0) 會讓 q[0] IndexError print(bisect_left(range(n), True, lo=1, key=check)) return # 下面是舊版本的實作 # < Python 3.10 le, ri = 1, n while le < ri: mid = (le + ri) >> 1 if check(mid) is True: ri = mid else: le = mid + 1 print(le) main() ``` ### [Q-4-19. 五嶽盟主的會議場所](https://judge.tcirc.tw/problem/d060) 「按照開始時間排序」可以確保更早到場的人一定已經掃過了,並保證已經下山的人之後不會再用到,使用 heap 可以讓他們依序下山。 $Time: O(n \log_2 n), Space: O(n)$ ```python= def main(): from sys import stdin from heapq import heappush, heappop from operator import itemgetter e = stdin.readline e() # n 用不到 ans = cur = 0 q = [(float("INF"), 0)] # 結束時間 sentinel 為無限 # 按照開始時間排序 for m, s, t in sorted((tuple(map(int, e().split())) for i in range(n)), key=itemgetter(1)): while q[0][0] < s: # 讓結束的門派下山 _, v = heappop(q) cur -= v # 減去下山人數 cur += m # 加上上山人數 if cur > ans: ans = cur # 更新最多人數 heappush(q, (t, m)) # 紀錄下山時間 print(ans) main() ``` ### [Q-4-20. 監看華山練功場](https://judge.tcirc.tw/problem/d061) 根據題意應該可以猜出要貪心,先想想看甚麼狀況下可以貪: - case 1: ```c++= ##### ###### ###### <- new ``` 這個貪不了,區間 1, 2 都要留著,還得加入 3。 - case 2: ```c++= ####### ######### ##### <- new ``` 可以看出區間 3 沒有存在的必要,不用加入。 - case 3: ```c++= ####### ###### ######## <- new ``` 這次變成區間 2 沒有存在的必要,可以刪掉 2 並加入 3。 因此,我們可以先將各區間按照起始點排列,再依據上面的三個 case 判斷就好了。 需要注意的是 case 3 判斷的部分,可能需要考慮要不要加上 `while`,當然如果不確定的話加上絕對不會錯,仔細思考一下:三個區間都留下的情況只有 case 1,條件是區間 1, 3 沒有重疊,假設再加上一個區間 4 和區間 2 有重疊,這樣 2, 3, 4 會形成 case 3,刪除 3 並加入 4,接下來如果再檢查一次 case 3 會發現 4, 1 沒有重疊,因為區間 4 的起點必定比 3 還要後面,不可能與 1 重疊。因此 case 3 最多觸發一次,用 `if` 就好了。 ```c++= ##### ####### ###### <- case 1 ###### <- new - case 3 - delete 3 ``` 題目要求覆蓋 $[x, y]$,我們可以假裝已經加入了一個 $[-1, x]$ 區間,左端點 $-1$ 確保他,$y$ 就只要看右端點覆蓋到的時候就結束就好了。 因為一次需要判斷三個區間之間的關係,因此快取 Stack 的最後兩項。 $Time: O(\text{sort} + n), Space: O(n)$ ```python= def main(): from sys import stdin from heapq import heappush, heappop e = stdin.readline n = int(e()) x, y = map(int, e().split()) l = sorted((tuple(map(int, e().split())) for i in range(n))) # 儲存結束點的stack 單調遞增 stk = [] pop, append = stk.pop, stk.append # 快取函數 # stack的最後兩項 初始化為 sentinel, 第一區間 pp, p = -1, x for s, t in l: if s > p: # 區間不連續 print(-1) return if t <= p: # 此區間沒有存在必要 - case 2 continue if s <= pp: # 如果前一個區間不必要就pop掉 - case 3 p, pp = pp, pop() append(pp) # 加入此區間 pp, p = p, t if t >= y: # 達成條件 輸出所求 break else: # 沒有break出來 代表區間右端沒有到y print(-1) return # 即為所求 (最後兩項在p,pp中 但扣掉初始化時預填的兩項剛好消掉) print(len(stk)) main() ``` ## 附件 - [AP325 初見](https://colab.research.google.com/drive/1LyBHbqB0n2lFb1Lt_7EM0gbBOfDTC2H1?hl=en) - 這篇是第一次刷 AP325 時寫的,時間跨距稍微有點大,可以看到很噁心的寫法和最佳解交雜,比較亂,屬於單純的歷程記錄。 - [AP325 二見](https://colab.research.google.com/drive/1lcYD5MQF2cVF3rQ725gSnsKtFc-NcRS-?hl=en) - 這篇即是本文所有程式碼的來源。 - [AP325 精熟到登峰 上篇](https://hackmd.io/@ericshen19555/AP325_Mastering_to_Peak_Part1) - 本文連結 - [AP325 精熟到登峰 下篇](https://hackmd.io/@ericshen19555/AP325_Mastering_to_Peak_Part2) - 下篇連結 ## 參考資料 - [AP325](https://drive.google.com/drive/u/0/folders/10hZCMHH0YgsfguVZCHU7EYiG8qJE5f-m) by 吳邦一教授 - [AP325-Python](https://hackmd.io/@bangyewu/Hy2kbYLI6/%2Fg2kqHh5_Q4eQnz-mfNu3Kw) by 吳邦一教授