--- tags: Programming Contest --- # Maximum Size Square ## [題目](https://leetcode.com/problems/maximal-square/) 給定 NxM 的 0/1 矩陣 A,求矩陣中最大的、且全部都是 1 的正方形的面積。 ## 題解 設 dp[r, c] = 正方形右下角在 (r, c) 時,最大的正方形的邊長。顯而易見,若 A[r, c] 是 0,那 dp[r, c] 一定是 0;若 A[r, c] 是 1,那 dp[r, c] 最小也是 1。以下只探討 A[r, c] 是 1 的情況。 dp[r, c] 可能從哪些 state 轉移過來呢?假設 dp[r, c] = L,那位於 (r, c) 的最大正方形邊長是 L,這個正方形可能是從哪些位置的,且邊長是 L - 1 的正方形轉移過來?是 (r - 1, c), (r, c - 1), (r - 1, c - 1)。 如示意圖的最上方 row 所示,注意圖中藍色部份都是 1。 ![](https://i.imgur.com/MMlHZn0.png) 要能夠轉移過來,需要那些擴張的格子必需都是 1,即示意圖中間 row 的 **實線綠色**、**實線紫色** 的矩形。由此我們得到我們轉移方程: ```python dp[r, c] = 1 L1 = dp[r - 1, c] if np.all(A[r - L1 : r, c - L1] == 1) and np.all(A[r, c - L1 : c] == 1): dp[r, c] = max(dp[r, c], L1 + 1) L2 = dp[r, c - 1] if np.all(A[r - L2 : r, c] == 1) and np.all(A[r - L2, c - L2 : c] == 1): dp[r, c] = max(dp[r, c], L2 + 1) L3 = dp[r - 1, c - 1] if np.all(A[r, c - L3 : c] == 1) and np.all(A[r - L3 : r, c] == 1): dp[r, c] = max(dp[r, c], L3 + 1) ``` 實作上,需要小心邊界的情況,不能讓 index 為負的。此解法需要 O(N * M * max(N, M)) 的時間,讓我們想辦法優化。O(max(N, M)) 來自於檢查實線紫色與實線綠色的格子是不是都是 1,這個部份可以使用 prefix sum 預建表,將時間壓成 O(1)。 但存在一個更巧妙的方法:利用 dp 表格。以示意圖中間 row,左數第 2 張圖為例,**檢查紫色的部份是不是都是 1,相當於檢查 dp[r, c - 1] 的值是不是大於等於 L1**。 因為若是這個值大於等於 L1,即說明 (r, c - 1) 存在一個邊長大於等於 L1 的正方形,即紫色的部份都是 1。若是小於 L1,即說明紫色部份中必有 0。 同理,綠色的部份也可以變成檢查 dp[r - 1, c - 1]。示意圖中最下方 row 的每張圖,虛線綠色格子與虛線紫色格子,分別對應他上方圖中實線矩形的部份。要檢查實線矩形是不是全為 1,只需檢查虛線對應顏色格子的 dp 值是不是大於等於 L1 或 L2 或 L3。 由此我們得到 O(N * M) 的解法: ```python dp[r, c] = 1 L1 = dp[r - 1, c] L2 = dp[r, c - 1] L3 = dp[r - 1, c - 1] if L2 >= L1 and L3 >= L1: dp[r, c] = max(dp[r, c], L1 + 1) if L1 >= L2 and L3 >= L2: dp[r, c] = max(dp[r, c], L2 + 1) if L1 >= L3 and L2 >= L3: dp[r, c] = max(dp[r, c], L3 + 1) ``` 讓我們仔細看一下這段 code,會發現他是在說 L1 比 L2, L3 都小的時候,我們用 L1 + 1 更新看看 dp[r, c];L2 比 L1, L3 都小的時候,用 L2 + 1 更新看看 dp[r, c];L3 比 L1, L2 都小的時候,用 L3 + 1 更新看看 dp[r, c]。這可以寫得更簡潔: ```python dp[r, c] = 1 dp[r, c] = max(dp[r, c], min(dp[r - 1, c], dp[r, c - 1], dp[r - 1, c - 1]) + 1) ``` 這就是最終的轉移方程。 P.S. 網上一堆人直接跳到這個結論,真不知他們怎麼想到的,也許是我太笨了,別人的題解都看不懂。只好自己慢慢想,最後得到這個題解。可能存在更優雅的方式去理解這個 dp 吧。 ## AC Code ```python import numpy as np class Solution: def maximalSquare(self, matrix: List[List[str]]) -> int: if len(matrix) == 0 or len(matrix[0]) == 0: return 0 N, M = len(matrix), len(matrix[0]) A = np.int64(matrix) dp = np.zeros((N, M), dtype=int) dp[:, 0] = A[:, 0] == 1 dp[0, :] = A[0, :] == 1 for r in range(1, N): for c in range(1, M): if A[r, c] == 0: continue dp[r, c] = 1 dp[r, c] = max(dp[r, c], min(dp[r - 1, c], dp[r, c - 1], dp[r - 1, c - 1]) + 1) ans = dp.max().item() return ans * ans ``` ```python import numpy as np class Solution: def maximalSquare(self, matrix: List[List[str]]) -> int: if len(matrix) == 0 or len(matrix[0]) == 0: return 0 N, M = len(matrix), len(matrix[0]) A = np.int64(matrix) dp = np.zeros((N, M), dtype=int) dp[:, 0] = A[:, 0] == 1 dp[0, :] = A[0, :] == 1 for r in range(1, N): for c in range(1, M): if A[r, c] == 0: continue dp[r, c] = 1 L1 = dp[r - 1, c] L2 = dp[r, c - 1] L3 = dp[r - 1, c - 1] if L2 >= L1 and L3 >= L1: dp[r, c] = max(dp[r, c], L1 + 1) if L1 >= L2 and L3 >= L2: dp[r, c] = max(dp[r, c], L2 + 1) if L1 >= L3 and L2 >= L3: dp[r, c] = max(dp[r, c], L3 + 1) ans = dp.max().item() return ans * ans ``` ```python import numpy as np class Solution: def maximalSquare(self, matrix: List[List[str]]) -> int: if len(matrix) == 0 or len(matrix[0]) == 0: return 0 N, M = len(matrix), len(matrix[0]) A = np.int64(matrix) dp = np.zeros((N, M), dtype=int) dp[:, 0] = A[:, 0] == 1 dp[0, :] = A[0, :] == 1 for r in range(1, N): for c in range(1, M): if A[r, c] == 0: continue dp[r, c] = 1 L1 = dp[r - 1, c] if ( r >= L1 and c >= L1 and np.all(A[r - L1 : r, c - L1] == 1) and np.all(A[r, c - L1 : c] == 1) ): dp[r, c] = max(dp[r, c], L1 + 1) L2 = dp[r, c - 1] if ( r >= L2 and c >= L2 and np.all(A[r - L2 : r, c] == 1) and np.all(A[r - L2, c - L2 : c] == 1) ): dp[r, c] = max(dp[r, c], L2 + 1) L3 = dp[r - 1, c - 1] if ( r >= L3 and c >= L3 and np.all(A[r, c - L3 : c] == 1) and np.all(A[r - L3 : r, c] == 1) ): dp[r, c] = max(dp[r, c], L3 + 1) ans = dp.max().item() return ans * ans ``` :::success [< 回到所有題解](https://hackmd.io/@amoshyc/HkWWepaPv/) :::