###### tags: `Leetcode` `medium` `binary search` `heap` `sort` `python` `c++` # 378. Kth Smallest Element in a Sorted Matrix ## [題目連結:] https://leetcode.com/problems/kth-smallest-element-in-a-sorted-matrix/description/ ## 題目: Given an ```n x n``` ```matrix``` where each of the rows and columns is sorted in ascending order, return the ```kth``` smallest element in the matrix. Note that it is the ```kth``` smallest element **in the sorted order**, not the ```kth``` **distinct** element. You must find a solution with a memory complexity better than ```O(n^2)```. **Follow up:** * Could you solve the problem with a constant memory (i.e., ```O(1)``` memory complexity)? * Could you solve the problem in ```O(n)``` time complexity? The solution may be too advanced for an interview but you may find reading this paper fun. **Example 1:** ``` Input: matrix = [[1,5,9],[10,11,13],[12,13,15]], k = 8 Output: 13 Explanation: The elements in the matrix are [1,5,9,10,11,12,13,13,15], and the 8th smallest number is 13 ``` **Example 2:** ``` Input: matrix = [[-5]], k = 1 Output: -5 ``` * All the rows and columns of matrix are guaranteed to be sorted in non-decreasing order. ## 解題想法: * 此題為給一矩陣, 其中每一行每一列都按照遞增排列 * 尋找第k個元素 * **Sol1**: Follow up要求 * **使用Binary Search左下到右上階梯式找** * time: O(NlogK) * k為最大值與最小值的差值 * space: O(1) * Step1: init * head=matrix[0][0] 最小值在左上 * tail=matrix[-1][-1] 最大值在右下 * Step2: * while head<=tail: * **mid=(head+tail)//2** * mid表示數字值,而非位置 * Step3: 額外需要函式 * def **CountLower**(matrix,mid) * 用以紀錄整個matrix有多少個數字小於or等於當前mid * cur_count=self.CountLower(matrix,mid) * **if cur_count>=k**: * tail=mid-1 * 目前選到的數mid太大了 大過k個以上的數量 * else: * head=mid+1 * 選到的數字mid對於matrix之中太小了 * final return head * Step4: **CountLower(matrix, num):** * 從matrix中左下角開始找 * init: * i=len(matrix)-1 列 * j=0 行 * count=0 * while i>=0 and j<len(matrix[0]) :不要出界 * if matrix[i][j]<=num: * 表示當前該直行**皆比num小** * **count+= i+1** :加上整行數量 * **j+=1** : 同時j需右移找更大的 * else: * i-=1 * return count ## Python_Sol1: ``` python= class Solution(object): def kthSmallest(self, matrix, k): """ :type matrix: List[List[int]] :type k: int :rtype: int """ #binary search左下到右上階梯式找 head=matrix[0][0] #最小值在左上 tail=matrix[-1][-1] #最大值在右下 while head<=tail: mid=(head+tail)//2 cur_count=self.CountLower(matrix,mid) if cur_count>=k: #目前選到的數mid太大了 大過k個以上的數量 tail=mid-1 else: head=mid+1 #選到的數mid太小了 return head def CountLower(self,matrix,num): #從左下開始找 i=len(matrix)-1 j=0 count=0 while i>=0 and j<len(matrix[0]): if matrix[i][j]<=num: #表示目前該直行皆比num小 count+=i+1 #因此加上該行數量 j+=1 #同時j右移找更大的 else: i-=1 return count matrix = [[1,5,9],[10,11,13],[12,13,15]] k = 8 result = Solution() ans=result.kthSmallest(matrix,k) print(ans) ``` ## C++_Sol1: ``` cpp= class Solution { public: int kthSmallest(vector<vector<int>>& matrix, int k) { int head=matrix[0][0]; int tail=matrix.back().back(); //python: matrix[-1][-1] while (head<=tail){ int mid=(head+tail)/2; int curCount=CountLower(matrix,mid); if (curCount>=k) tail=mid-1; else head=mid+1; } return head; } int CountLower(vector<vector<int>>& matrix, int num){ int i=matrix.size()-1; int j=0; int curCount=0; while (i>=0 && j<matrix[0].size()){ if (matrix[i][j]<=num){ curCount+=i+1; j+=1; } else i-=1; } return curCount; } }; ``` * Sol2: * 使用heap queue進行判斷 * que存(該數字, row, col) * for _ in range(k): 進行k次pop * res=heappop() * 將res該位置下面、右邊的鄰居加入que ## Python_Sol2: ``` python= from heapq import heappush, heappop class Solution(object): def kthSmallest(self, matrix, k): """ :type matrix: List[List[int]] :type k: int :rtype: int """ #heap: O(NlogN) m=len(matrix) n=len(matrix[0]) #heapq que=[(matrix[0][0], 0, 0)] #value, row, col for _ in range(k): res, row, col=heappop(que) #將下面、右邊的鄰居加入que if col==0 and row+1<m: #若在左邊界,且還能往下一列,才能往下 heappush(que,(matrix[row+1][col], row+1, col)) if col+1<n: #向右 heappush(que,(matrix[row][col+1], row, col+1)) return res ``` # C++_Sol2: ``` cpp= class Solution { public: int kthSmallest(vector<vector<int>>& matrix, int k) { int m=matrix.size(), n=matrix[0].size(); priority_queue<vector<int>> que; //max_heap que.push({-matrix[0][0], 0, 0}); int res=matrix[0][0], row=0, col=0; for (int i=0; i<k; i++){ vector<int> curRes=que.top(); que.pop(); //res need to plus'-' to get the positive value res=-curRes[0], row=curRes[1], col=curRes[2]; if (col==0 && row+1<m) que.push({-matrix[row+1][col], row+1, col}); if (col+1<n) que.push({-matrix[row][col+1], row, col+1}); } return res; } }; ```