###### tags: `Leetcode` `medium` `binary search` `heap` `sort` `python` `c++`
# 378. Kth Smallest Element in a Sorted Matrix
## [題目連結:] https://leetcode.com/problems/kth-smallest-element-in-a-sorted-matrix/description/
## 題目:
Given an ```n x n``` ```matrix``` where each of the rows and columns is sorted in ascending order, return the ```kth``` smallest element in the matrix.
Note that it is the ```kth``` smallest element **in the sorted order**, not the ```kth``` **distinct** element.
You must find a solution with a memory complexity better than ```O(n^2)```.
**Follow up:**
* Could you solve the problem with a constant memory (i.e., ```O(1)``` memory complexity)?
* Could you solve the problem in ```O(n)``` time complexity? The solution may be too advanced for an interview but you may find reading this paper fun.
**Example 1:**
```
Input: matrix = [[1,5,9],[10,11,13],[12,13,15]], k = 8
Output: 13
Explanation: The elements in the matrix are [1,5,9,10,11,12,13,13,15], and the 8th smallest number is 13
```
**Example 2:**
```
Input: matrix = [[-5]], k = 1
Output: -5
```
* All the rows and columns of matrix are guaranteed to be sorted in non-decreasing order.
## 解題想法:
* 此題為給一矩陣, 其中每一行每一列都按照遞增排列
* 尋找第k個元素
* **Sol1**: Follow up要求
* **使用Binary Search左下到右上階梯式找**
* time: O(NlogK)
* k為最大值與最小值的差值
* space: O(1)
* Step1: init
* head=matrix[0][0] 最小值在左上
* tail=matrix[-1][-1] 最大值在右下
* Step2:
* while head<=tail:
* **mid=(head+tail)//2**
* mid表示數字值,而非位置
* Step3: 額外需要函式
* def **CountLower**(matrix,mid)
* 用以紀錄整個matrix有多少個數字小於or等於當前mid
* cur_count=self.CountLower(matrix,mid)
* **if cur_count>=k**:
* tail=mid-1
* 目前選到的數mid太大了 大過k個以上的數量
* else:
* head=mid+1
* 選到的數字mid對於matrix之中太小了
* final return head
* Step4: **CountLower(matrix, num):**
* 從matrix中左下角開始找
* init:
* i=len(matrix)-1 列
* j=0 行
* count=0
* while i>=0 and j<len(matrix[0]) :不要出界
* if matrix[i][j]<=num:
* 表示當前該直行**皆比num小**
* **count+= i+1** :加上整行數量
* **j+=1** : 同時j需右移找更大的
* else:
* i-=1
* return count
## Python_Sol1:
``` python=
class Solution(object):
def kthSmallest(self, matrix, k):
"""
:type matrix: List[List[int]]
:type k: int
:rtype: int
"""
#binary search左下到右上階梯式找
head=matrix[0][0] #最小值在左上
tail=matrix[-1][-1] #最大值在右下
while head<=tail:
mid=(head+tail)//2
cur_count=self.CountLower(matrix,mid)
if cur_count>=k: #目前選到的數mid太大了 大過k個以上的數量
tail=mid-1
else:
head=mid+1 #選到的數mid太小了
return head
def CountLower(self,matrix,num):
#從左下開始找
i=len(matrix)-1
j=0
count=0
while i>=0 and j<len(matrix[0]):
if matrix[i][j]<=num:
#表示目前該直行皆比num小
count+=i+1 #因此加上該行數量
j+=1 #同時j右移找更大的
else:
i-=1
return count
matrix = [[1,5,9],[10,11,13],[12,13,15]]
k = 8
result = Solution()
ans=result.kthSmallest(matrix,k)
print(ans)
```
## C++_Sol1:
``` cpp=
class Solution {
public:
int kthSmallest(vector<vector<int>>& matrix, int k) {
int head=matrix[0][0];
int tail=matrix.back().back(); //python: matrix[-1][-1]
while (head<=tail){
int mid=(head+tail)/2;
int curCount=CountLower(matrix,mid);
if (curCount>=k)
tail=mid-1;
else
head=mid+1;
}
return head;
}
int CountLower(vector<vector<int>>& matrix, int num){
int i=matrix.size()-1;
int j=0;
int curCount=0;
while (i>=0 && j<matrix[0].size()){
if (matrix[i][j]<=num){
curCount+=i+1;
j+=1;
}
else
i-=1;
}
return curCount;
}
};
```
* Sol2:
* 使用heap queue進行判斷
* que存(該數字, row, col)
* for _ in range(k): 進行k次pop
* res=heappop()
* 將res該位置下面、右邊的鄰居加入que
## Python_Sol2:
``` python=
from heapq import heappush, heappop
class Solution(object):
def kthSmallest(self, matrix, k):
"""
:type matrix: List[List[int]]
:type k: int
:rtype: int
"""
#heap: O(NlogN)
m=len(matrix)
n=len(matrix[0])
#heapq
que=[(matrix[0][0], 0, 0)] #value, row, col
for _ in range(k):
res, row, col=heappop(que)
#將下面、右邊的鄰居加入que
if col==0 and row+1<m: #若在左邊界,且還能往下一列,才能往下
heappush(que,(matrix[row+1][col], row+1, col))
if col+1<n: #向右
heappush(que,(matrix[row][col+1], row, col+1))
return res
```
# C++_Sol2:
``` cpp=
class Solution {
public:
int kthSmallest(vector<vector<int>>& matrix, int k) {
int m=matrix.size(), n=matrix[0].size();
priority_queue<vector<int>> que; //max_heap
que.push({-matrix[0][0], 0, 0});
int res=matrix[0][0], row=0, col=0;
for (int i=0; i<k; i++){
vector<int> curRes=que.top(); que.pop();
//res need to plus'-' to get the positive value
res=-curRes[0], row=curRes[1], col=curRes[2];
if (col==0 && row+1<m)
que.push({-matrix[row+1][col], row+1, col});
if (col+1<n)
que.push({-matrix[row][col+1], row, col+1});
}
return res;
}
};
```