
Algorithm Interview Patterns | Two Heaps


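This post is mainly a translation and summary of Grokking the Coding Interview: Patterns for Coding Questions. If you have the time, consider purchasing the course to read the original and practice on its platform.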

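Key Points

  • Two Heaps
  • Strategy: split the elements into two parts, using a min heap to find the smallest element of one part and a max heap to find the largest element of the other
  • Problem types:
    • The given set of elements can be divided into two parts
    • We need the largest element of one part and the smallest element of the other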

Problem List

  • 0295 Find Median from Data Stream
  • 0480 Sliding Window Median
  • 0502 IPO

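[Example] Find the Median of a Number Stream

Problem

Design a class that computes the median of a number stream. The class needs the following two methods:

  • insertNum(int num) stores a number in the class
  • findMedian() returns the median of all numbers inserted so far

When the stream holds an even count of numbers, the median is the average of the two middle numbers.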

Example

insertNum(3)  // [3]
insertNum(1)  // [1, 3]
findMedian()  // 2
insertNum(5)  // [1, 3, 5]
findMedian()  // 3
insertNum(4)  // [1, 3, 4, 5]
findMedian()  // 3.5

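Solution

As we know, the median is the middle value of a sorted sequence of integers, so an intuitive brute-force solution is to maintain a sorted array of the inserted numbers, which lets us return the median efficiently. Inserting a number into a sorted array, however, takes O(n) time, where n is the count of numbers; the process resembles one step of Insertion Sort. Can we do better? Can we exploit the fact that we only care about the middle values, without maintaining a fully sorted array?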
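For comparison, here is a minimal Python sketch of the brute-force baseline described above. The class name `SortedListMedian` and its method names are illustrative assumptions, not part of the original course material; `bisect.insort` is used to keep the list sorted on every insert.

```python
from bisect import insort


class SortedListMedian:
    """Brute-force baseline: keep every number in a sorted list."""

    def __init__(self):
        self.nums = []

    def insert_num(self, num):
        # insort locates the position in O(log n), but shifting the
        # following elements still makes the insert O(n) overall
        insort(self.nums, num)

    def find_median(self):
        n = len(self.nums)
        mid = n // 2
        if n % 2 == 1:
            return float(self.nums[mid])
        # even count: average the two middle values
        return (self.nums[mid - 1] + self.nums[mid]) / 2.0
```

Reading the median is O(1) here; the cost sits entirely in the insert, which is what the two-heap approach improves on.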

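Suppose x is the median of a set of numbers. Then half of the numbers are less than or equal to x, and the other half are greater than or equal to x. So we can split the numbers into two halves: one holding all the smaller numbers (call it smallNumList) and one holding all the larger numbers (call it largeNumList). The median is then the largest number in smallNumList or the smallest number in largeNumList; when the count is even, it is the average of the two.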

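The data structure best suited to finding the extreme value of a set of numbers is the Heap, so we can use heaps for this problem with the following algorithm:

  1. Store the first half of the numbers (smallNumList) in a max heap, because we want its largest number
  2. Store the second half of the numbers (largeNumList) in a min heap, because we want its smallest number
  3. Inserting an element into a heap takes O(log n) time, which is more efficient than the brute-force solution above
  4. At any time, the current median can be computed from the top elements of the two heaps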
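Python's `heapq` module only offers a min heap, so the max heap in step 1 needs a workaround. One common approach is to store negated values; the `MaxHeap` wrapper below is a hypothetical helper written for illustration, assuming numeric elements.

```python
import heapq


class MaxHeap:
    """Hypothetical helper: a max heap built on heapq by storing negated values."""

    def __init__(self):
        self._data = []

    def push(self, value):
        # the smallest negated value corresponds to the largest original value
        heapq.heappush(self._data, -value)

    def pop(self):
        return -heapq.heappop(self._data)

    def top(self):
        return -self._data[0]

    def __len__(self):
        return len(self._data)
```

Pushing 3, 1, 5, 4 and then calling `top()` gives 5, which is the behavior smallNumList needs.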

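Note: how do we decide whether a number belongs to the first half or the second half?

This is the same as asking "should the incoming number go to the max heap or the min heap?" The decision is made by comparing it with the current top (largest) element of the max heap, while keeping the sizes of the two heaps as balanced as possible:

  • If the number is less than or equal to the top of maxHeap (or maxHeap is empty), insert it into maxHeap
  • If the number is greater than the top of maxHeap, insert it into minHeap
  • After each insertion, check that the two heaps are still balanced (neither may grow too large)
    • If maxHeap has two more elements than minHeap, move the top of maxHeap to minHeap
    • If minHeap has more elements than maxHeap, move the top of minHeap to maxHeap, so that maxHeap holds the extra element when the count is odd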

The walkthrough below applies these steps to the example:

1. insertNum(3): the max heap is empty, so the number goes into the max heap first


2. insertNum(1): since 1 is smaller than 3, insert it into the max heap

3. Check the balance: the max heap now has two more elements than the min heap, so move its largest value 3 to the min heap

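4. findMedian(): both heaps have the same size, so the count is even and the median is the average of the two top elements, (1 + 3) / 2 = 2.0
5. insertNum(5): since 5 is larger than 1, insert it into the min heap; the heaps are now unbalanced, so the min heap's top element 3 moves to the max heap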

6. findMedian(): the heaps differ in size, so take the top element of the max heap, 3
7. insertNum(4): since 4 is larger than 3, insert it into the min heap

8. findMedian(): both heaps have the same size, so the count is even and the median is the average of the two top elements, (3 + 4) / 2 = 3.5
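The walkthrough can be reproduced with a small tracing sketch. The function `trace_two_heaps` is an illustrative assumption, not from the source; it records the contents of both halves and the current median after every insert, with the max heap stored as negated values because `heapq` is a min heap.

```python
import heapq


def trace_two_heaps(stream):
    """Return (smaller_half, larger_half, median) after each insert."""
    max_heap = []  # smaller half, stored negated to act as a max heap
    min_heap = []  # larger half
    states = []
    for num in stream:
        if not max_heap or num <= -max_heap[0]:
            heapq.heappush(max_heap, -num)
        else:
            heapq.heappush(min_heap, num)

        # keep sizes equal, or let max_heap hold exactly one extra element
        if len(max_heap) > len(min_heap) + 1:
            heapq.heappush(min_heap, -heapq.heappop(max_heap))
        elif len(max_heap) < len(min_heap):
            heapq.heappush(max_heap, -heapq.heappop(min_heap))

        if len(max_heap) == len(min_heap):
            median = (-max_heap[0] + min_heap[0]) / 2.0
        else:
            median = float(-max_heap[0])
        states.append((sorted(-x for x in max_heap), sorted(min_heap), median))
    return states
```

Calling `trace_two_heaps([3, 1, 5, 4])` yields one tuple per insert; the final state holds 1 and 3 in the smaller half, 4 and 5 in the larger half, and a median of 3.5, matching the last step above.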

Code

C++

#include <functional>
#include <queue>
#include <vector>
using namespace std;

class MedianOfAStream {
 public:
  priority_queue<int> maxHeap;
  priority_queue<int, vector<int>, greater<int>> minHeap;
  
  virtual void insertNum(int num) {
    if (maxHeap.empty() || maxHeap.top() >= num) maxHeap.push(num);
    else minHeap.push(num);
    
    if (maxHeap.size() > minHeap.size() + 1) {
      minHeap.push(maxHeap.top());
      maxHeap.pop();
    } else if (maxHeap.size() < minHeap.size()) {
      maxHeap.push(minHeap.top());
      minHeap.pop();
    }
  }
  
  virtual double findMedian() {
    if (maxHeap.size() == minHeap.size()) {
      return (maxHeap.top() + minHeap.top()) / 2.0;
    }
    return maxHeap.top();
  }
};

Java

import java.util.PriorityQueue;

class MedianOfAStream {
  PriorityQueue<Integer> maxHeap;  // contain first half of numbers
  PriorityQueue<Integer> minHeap;  // contain second half of numbers
  
  public MedianOfAStream() {
    maxHeap = new PriorityQueue<>((a, b) -> b - a);
    minHeap = new PriorityQueue<>((a, b) -> a - b);
  }
  
  public void insertNum(int num) {
    if (maxHeap.isEmpty() || maxHeap.peek() >= num) maxHeap.add(num);
    else minHeap.add(num);
    
    // either both the heaps will have equal number of elements or max-heap will have
    // one more element than the min-heap
    if (maxHeap.size() > minHeap.size() + 1) minHeap.add(maxHeap.poll());
    else if (maxHeap.size() < minHeap.size()) maxHeap.add(minHeap.poll());
  }
  
  public double findMedian() {
    if (maxHeap.size() == minHeap.size()) return (maxHeap.peek() + minHeap.peek()) / 2.0;
    else return maxHeap.peek();
  }
}

JavaScript

JavaScript has no built-in heap data structure, so a third-party implementation is needed; collections.js is used here.

$ npm install --save collections
const Heap = require("collections/heap");

class MedianOfAStream {
  constructor() {
    this.maxHeap = new Heap([], null, ((a, b) => a - b));
    this.minHeap = new Heap([], null, ((a, b) => b - a));
  }
  
  insertNum(num) {
    if (this.maxHeap.length === 0 || this.maxHeap.peek() >= num) {
      this.maxHeap.push(num);
    } else {
      this.minHeap.push(num);
    }
    
    // the heaps must be the same size, or the extra element goes to the max heap
    if (this.maxHeap.length > this.minHeap.length + 1) {
      this.minHeap.push(this.maxHeap.pop());
    } else if (this.maxHeap.length < this.minHeap.length) {
      this.maxHeap.push(this.minHeap.pop());
    }
  }
  
  findMedian() {
    if (this.maxHeap.length === this.minHeap.length) {
      return (this.maxHeap.peek() + this.minHeap.peek()) / 2.0;
    }
    return this.maxHeap.peek();
  }
}

Python

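from heapq import heappop, heappush


class MedianOfAStream:
    def __init__(self):
        # heapq only provides a min-heap, so max_heap stores negated values
        self.max_heap = []  # first (smaller) half of the numbers
        self.min_heap = []  # second (larger) half of the numbers

    def insert_num(self, num):
        if not self.max_heap or -self.max_heap[0] >= num:
            heappush(self.max_heap, -num)
        else:
            heappush(self.min_heap, num)

        # keep the heaps the same size, or let max_heap hold one extra element
        if len(self.max_heap) > len(self.min_heap) + 1:
            heappush(self.min_heap, -heappop(self.max_heap))
        elif len(self.max_heap) < len(self.min_heap):
            heappush(self.max_heap, -heappop(self.min_heap))

    def find_median(self):
        if len(self.max_heap) == len(self.min_heap):
            return (-self.max_heap[0] + self.min_heap[0]) / 2.0
        return float(-self.max_heap[0])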

Analysis

  • Time complexity
    • insertNum() takes O(log n) to insert a number into a heap
    • findMedian() takes O(1) to read the top elements of the heaps
  • Space complexity: O(n) to store all the numbers

[Example] Sliding Window Median

Problem

Given an array of numbers and a number k, find the median of every contiguous subarray of k elements.

Example 01

Input   : [1, 2, -1, 3, 5], k = 2
Output  : [1.5, 0.5, 1.0, 4.0]

// [1, 2]            -> 1.5
//    [2, -1]        -> 0.5
//       [-1, 3]     -> 1.0
//           [3, 5]  -> 4.0

Example 02

Input   : [1, 2, -1, 3, 5], k = 3
Output  : [1.0, 2.0, 3.0]

// [1, 2, -1]        -> 1.0
//    [2, -1, 3]     -> 2.0
//       [-1, 3, 5]  -> 3.0

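Solution

This problem follows the Two Heaps pattern and resembles Find the Median of a Number Stream. We again maintain a max heap and a min heap to find the median.

The only difference is that we must track the numbers inside a sliding window of size k: in each iteration, while inserting the new number into a heap, we also remove one number from the heaps (the one leaving the window), and we rebalance the heaps after each of these operations.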
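A brute-force reference is handy for checking a two-heap implementation against the examples. The sketch below simply sorts every window; the function name `sliding_window_median_bruteforce` is an assumption made for illustration, and the approach is deliberately naive rather than the Two Heaps technique itself.

```python
def sliding_window_median_bruteforce(nums, k):
    """Reference implementation: sort every window, O(n * k log k)."""
    result = []
    for start in range(len(nums) - k + 1):
        window = sorted(nums[start:start + k])
        mid = k // 2
        if k % 2 == 1:
            result.append(float(window[mid]))
        else:
            # even window size: average the two middle values
            result.append((window[mid - 1] + window[mid]) / 2.0)
    return result
```

For `[1, 2, -1, 3, 5]` it returns `[1.5, 0.5, 1.0, 4.0]` with k = 2 and `[1.0, 2.0, 3.0]` with k = 3, matching Example 01 and Example 02.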

Code

C++

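#include <algorithm>
#include <functional>
#include <queue>
#include <vector>
using namespace std;

// priority_queue_with_remove is not part of the standard library; this is one
// possible helper that adds remove() on top of std::priority_queue
template <typename T, class Container = vector<T>, class Compare = less<typename Container::value_type>>
class priority_queue_with_remove : public priority_queue<T, Container, Compare> {
 public:
  bool remove(const T &value) {
    auto it = std::find(this->c.begin(), this->c.end(), value);  // O(k) linear search
    if (it == this->c.end()) return false;
    this->c.erase(it);
    std::make_heap(this->c.begin(), this->c.end(), this->comp);  // rebuild the heap, O(k)
    return true;
  }
};

class SlidingWindowMedian {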
 public:
  priority_queue_with_remove<int> maxHeap;
  priority_queue_with_remove<int, vector<int>, greater<int>> minHeap;
  
  virtual vector<double> findSlidingWindowMedian(const vector<int> &nums, int k) {
    vector<double> result(nums.size() - k + 1);
    for (int i = 0; i < nums.size(); i++) {
      if (maxHeap.size() == 0 || maxHeap.top() >= nums[i]) {
        maxHeap.push(nums[i]);
      } else {
        minHeap.push(nums[i]);
      }
      rebalanceHeaps();
      
      if (i - k + 1 >= 0) {
        if (maxHeap.size() == minHeap.size()) {
          result[i - k + 1] = (maxHeap.top() + minHeap.top()) / 2.0;
        } else {
          result[i - k + 1] = maxHeap.top();
        }
        
        int elementToBeRemoved = nums[i - k + 1];
        if (elementToBeRemoved <= maxHeap.top()) maxHeap.remove(elementToBeRemoved);
        else minHeap.remove(elementToBeRemoved);
        
        rebalanceHeaps();
      }
    }
    return result;
  }
  
 private:
  void rebalanceHeaps() {
    if (maxHeap.size() > minHeap.size() + 1) {
      minHeap.push(maxHeap.top());
      maxHeap.pop();
    } else if (maxHeap.size() < minHeap.size()) {
      maxHeap.push(minHeap.top());
      minHeap.pop();
    }
  }
};

Java

import java.util.Collections;
import java.util.PriorityQueue;

class SlidingWindowMedian {
  PriorityQueue<Integer> maxHeap = new PriorityQueue<>(Collections.reverseOrder());
  PriorityQueue<Integer> minHeap = new PriorityQueue<>();

  public double[] findSlidingWindowMedian(int[] nums, int k) {
    double[] result = new double[nums.length - k + 1];
    for (int i = 0; i < nums.length; i++) {
      if (maxHeap.size() == 0 || maxHeap.peek() >= nums[i]) maxHeap.add(nums[i]);
      else minHeap.add(nums[i]);

      rebalanceHeaps();

      // once the window holds k elements, record its median
      if (i - k + 1 >= 0) {
        if (maxHeap.size() == minHeap.size()) {
          result[i - k + 1] = (maxHeap.peek() + minHeap.peek()) / 2.0;
        } else {
          result[i - k + 1] = maxHeap.peek();
        }

        // remove the element leaving the sliding window
        int elementToBeRemoved = nums[i - k + 1];
        if (elementToBeRemoved <= maxHeap.peek()) maxHeap.remove(elementToBeRemoved);
        else minHeap.remove(elementToBeRemoved);
        rebalanceHeaps();
      }
    }
    return result;
  }

  private void rebalanceHeaps() {
    if (maxHeap.size() > minHeap.size() + 1) minHeap.add(maxHeap.poll());
    else if (maxHeap.size() < minHeap.size()) maxHeap.add(minHeap.poll());
  }
}

JavaScript

const Heap = require('collections/heap');

class SlidingWindowMedian {
  constructor() {
    this.maxHeap = new Heap([], null, ((a, b) => a - b));
    this.minHeap = new Heap([], null, ((a, b) => b - a));
  }
  
  findSlidingWindowMedian(nums, k) {
    const result = Array(nums.length - k + 1).fill(0.0);
    for (let i = 0; i < nums.length; i++) {
      if (this.maxHeap.length === 0 || nums[i] <= this.maxHeap.peek()) {
        this.maxHeap.push(nums[i]);
      } else {
        this.minHeap.push(nums[i]);
      }
      
      this.rebalanceHeaps();
      
      // if we have at least k elements in the sliding window
      if (i - k + 1 >= 0) {
        // add the median to the result array
        if (this.maxHeap.length === this.minHeap.length) {
          result[i - k + 1] = (this.maxHeap.peek() + this.minHeap.peek()) / 2.0;
        } else {
          result[i - k + 1] = this.maxHeap.peek();
        }
        
        // remove the element going out of the sliding window
        const elementToBeRemoved = nums[i - k + 1];
        if (elementToBeRemoved <= this.maxHeap.peek()) {
          this.maxHeap.delete(elementToBeRemoved);
        } else {
          this.minHeap.delete(elementToBeRemoved);
        }
        
        this.rebalanceHeaps();
      }
    }
    
    return result;
  }
  
  rebalanceHeaps() {
    if (this.maxHeap.length > this.minHeap.length + 1) {
      this.minHeap.push(this.maxHeap.pop());
    } else if (this.maxHeap.length < this.minHeap.length) {
      this.maxHeap.push(this.minHeap.pop());
    }
  }
}

Python

import heapq
from heapq import heappop, heappush


class SlidingWindowMedian:
    def __init__(self):
        # heapq only provides a min-heap, so max_heap stores negated values
        self.max_heap = []
        self.min_heap = []

    def find_sliding_window_median(self, nums, k):
        result = [0.0 for _ in range(len(nums) - k + 1)]
        for i in range(len(nums)):
            if not self.max_heap or nums[i] <= -self.max_heap[0]:
                heappush(self.max_heap, -nums[i])
            else:
                heappush(self.min_heap, nums[i])

            self.rebalance_heaps()

            # once the window holds k elements, record its median
            if i - k + 1 >= 0:
                if len(self.max_heap) == len(self.min_heap):
                    result[i - k + 1] = (-self.max_heap[0] + self.min_heap[0]) / 2.0
                else:
                    result[i - k + 1] = float(-self.max_heap[0])

                # remove the element leaving the sliding window
                remove_element = nums[i - k + 1]
                if remove_element <= -self.max_heap[0]:
                    self.remove(self.max_heap, -remove_element)
                else:
                    self.remove(self.min_heap, remove_element)

                self.rebalance_heaps()
        return result

    def remove(self, heap, element):
        idx = heap.index(element)  # O(k) linear search
        # overwrite the element with the last one, then shrink the list
        heap[idx] = heap[-1]
        del heap[-1]

        # restore the heap property around idx; _siftup and _siftdown are
        # private CPython helpers in heapq
        if idx < len(heap):
            heapq._siftup(heap, idx)
            heapq._siftdown(heap, 0, idx)

    def rebalance_heaps(self):
        if len(self.max_heap) > len(self.min_heap) + 1:
            heappush(self.min_heap, -heappop(self.max_heap))
        elif len(self.max_heap) < len(self.min_heap):
            heappush(self.max_heap, -heappop(self.min_heap))

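Analysis

  • Time complexity: O(n×k)
    • Inserting into or removing the top of a heap of size k takes O(log k)
    • Removing the element that leaves a sliding window of size k takes O(k), because we must first search for it in the heap
  • Space complexity: O(k) to store the elements in the sliding window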

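[Example] Maximize Capital

Problem

Given a set of investment projects with their respective profits, an initial capital, and the number of projects we are allowed to invest in, select the projects that give the maximum final capital. We can start a project only when we have enough capital for it, and once a project is selected we can assume its profit has been added to our capital.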

Example 01

Input   : Project Capitals = [0, 1, 2]
          Project Profits = [1, 2, 3]
          Initial Capital = 1
          Number of Projects = 2
Output  : 6
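  1. With an initial capital of 1, we can start with the second project, which yields a profit of 2. Once we select it, our capital becomes 3 (Profit + Initial Capital)
  2. With a capital of 3, we can select the third project, which yields a profit of 3
  3. In total, the final capital is 1+2+3=6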

Example 02

Input   : Project Capitals = [0, 1, 2, 3]
          Project Profits = [1, 2, 3, 5]
          Initial Capital = 0
          Number of Projects = 3
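Output  : 8
  1. With an initial capital of 0, we can start with the first project, which yields a profit of 1. Once we select it, our capital becomes 1
  2. With a capital of 1, we can select the second project, which yields a profit of 2. Once we select it, our capital becomes 3
  3. With a capital of 3, we can select the fourth project, which yields a profit of 5
  4. In total, the final capital is 1+2+5=8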

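Solution

When selecting projects, we face two constraints:

  • We can select a project only if we have enough capital for it
  • There is a maximum number of projects we can select

A greedy strategy yields the optimal answer. When selecting a project, we do the following:

  1. Find the projects that can be selected with the current capital
  2. Among those, select the one with the maximum profit

We can therefore follow the Two Heaps pattern, with a strategy similar to Find the Median of a Number Stream. The steps of the algorithm are:

  1. Put all projects into a min heap, minHeap, keyed by required capital, so that the projects with the smallest capital requirement come out first
  2. From the top of minHeap, take out every project the current capital can afford and push its profit into a max heap, maxHeap, so that we can pick the most profitable one
  3. Select the project at the top of the max heap
  4. Repeat steps 2 and 3 for the allowed number of projects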
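To sanity-check the greedy strategy on small inputs, one option is an exhaustive search over every ordering of projects. The sketch below is only a testing aid under that assumption; the function name `max_capital_bruteforce` is hypothetical, and the factorial running time makes it unsuitable for anything but tiny examples.

```python
from itertools import permutations


def max_capital_bruteforce(capital, profits, number_of_projects, initial_capital):
    """Try every ordering of projects; only feasible for very small inputs."""
    best = initial_capital
    n = len(profits)
    for order in permutations(range(n), min(number_of_projects, n)):
        available = initial_capital
        for i in order:
            if capital[i] > available:
                break  # cannot afford this project, stop following this ordering
            available += profits[i]
        # a partially followed ordering is still a valid selection of fewer projects
        best = max(best, available)
    return best
```

On the two examples above it returns 6 and 8 respectively, the same totals the greedy walkthroughs arrive at.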

Code

C++

#include <queue>
#include <utility>
#include <vector>
using namespace std;

class MaximizeCapital {
 public:
  struct capitalCompare {
    bool operator()(const pair<int, int> &x, const pair<int, int> &y) {
      return x.first > y.first;
    }
  };
  
  struct profitCompare {
    bool operator()(const pair<int, int> &x, const pair<int, int> &y) {
      return x.first < y.first;
    }
  };
  
  static int findMaximumCapital(const vector<int> &capital, const vector<int> &profits, int numberOfProjects, int initialCapital) {
    int n = profits.size();
    priority_queue<pair<int, int>, vector<pair<int, int>>, capitalCompare> minCapitalHeap;
    priority_queue<pair<int, int>, vector<pair<int, int>>, profitCompare> maxProfitHeap;
    
    // insert all project capitals to a min-heap
    for (int i = 0; i < n; i++) {
      minCapitalHeap.push(make_pair(capital[i], i));
    }
    
    // let's try to find a total of 'numberOfProjects' best projects
    int availableCapital = initialCapital;
    for (int i = 0; i < numberOfProjects; i++) {
      // find all projects that can be selected within the available capital and insert them in a max-heap
      while (!minCapitalHeap.empty() && minCapitalHeap.top().first <= availableCapital) {
        auto capitalIndex = minCapitalHeap.top().second;
        minCapitalHeap.pop();
        maxProfitHeap.push(make_pair(profits[capitalIndex], capitalIndex));
      }
      
      // terminate if we are not able to find any project that can be completed within the available capital
      if (maxProfitHeap.empty()) break;
      
      // select the project with the maximum profit
      availableCapital += maxProfitHeap.top().first;
      maxProfitHeap.pop();
    }
    
    return availableCapital;
  }
};

Java

import java.util.PriorityQueue;

class MaximizeCapital {
  public static int findMaximumCapital(int[] capital, int[] profits, int numberOfProjects, int initialCapital) {
    int n = profits.length;
    // both heaps store project indices: one ordered by capital, the other by profit
    PriorityQueue<Integer> minCapitalHeap = new PriorityQueue<>(n, (i1, i2) -> capital[i1] - capital[i2]);
    PriorityQueue<Integer> maxProfitHeap = new PriorityQueue<>(n, (i1, i2) -> profits[i2] - profits[i1]);
    
    // insert all project capitals to a min-heap
    for (int i = 0; i < n; i++) minCapitalHeap.offer(i);
    
    // let's try to find a total of numberOfProjects best projects
    int availableCapital = initialCapital;
    for (int i = 0; i < numberOfProjects; i++) {
      while (!minCapitalHeap.isEmpty() && capital[minCapitalHeap.peek()] <= availableCapital) {
        maxProfitHeap.add(minCapitalHeap.poll());
      }
      
      // terminate if we are not able to find any project that can be completed with the available capital
      if (maxProfitHeap.isEmpty()) break;
      
      // select the project with the maximum profit
      availableCapital += profits[maxProfitHeap.poll()];
    }
    
    return availableCapital;
  }
}

JavaScript

const Heap = require('collections/heap');

const findMaximumCapital = (capital, profits, numberOfProjects, initialCapital) => {
  const minCapitalHeap = new Heap([], null, ((a, b) => b[0] - a[0]));
  const maxProfitHeap = new Heap([], null, ((a, b) => a[0] - b[0]));
  
  // insert all project capitals to a min-heap
  for (let i = 0; i < profits.length; i++) {
    minCapitalHeap.push([capital[i], i]);
  }
  
  // let's try to find a total of 'numberOfProjects' best projects
  let availableCapital = initialCapital;
  for (let i = 0; i < numberOfProjects; i++) {
    // find all projects that can be selected within the available capital and insert them in a max-heap
    while (minCapitalHeap.length > 0 && minCapitalHeap.peek()[0] <= availableCapital) {
      const [capital, index] = minCapitalHeap.pop();
      maxProfitHeap.push([profits[index], index]);
    }
    
    // terminate if we are not able to find any project that can be complete within the available capital
    if (maxProfitHeap.length === 0) break;
    
    // select the project with the maximum profit
    availableCapital += maxProfitHeap.pop()[0];
  }
  
  return availableCapital;
}

Python

from heapq import *


def find_maximum_capital(capital, profits, number_of_projects, initial_capital):
    min_capital_heap = []
    max_profit_heap = []
    
    # insert all project capitals to a min-heap
    for i in range(0, len(profits)):
        heappush(min_capital_heap, (capital[i], i))
        
    # let's try to find a total of number_of_projects best projects
    available_capital = initial_capital
    for _ in range(number_of_projects):
        # find all projects that can be selected within the available capital and insert them in a max-heap
        while min_capital_heap and min_capital_heap[0][0] <= available_capital:
            _, i = heappop(min_capital_heap)
            # heapq is a min-heap, so push the negated profit to get max-heap behavior
            heappush(max_profit_heap, (-profits[i], i))

        # terminate if we are not able to find any project that can be completed within the available capital
        if not max_profit_heap:
            break

        # select the project with the maximum profit
        available_capital += -heappop(max_profit_heap)[0]
        
    return available_capital

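Analysis

  • Time complexity: O(n log n + k log n), where n is the total number of projects and k is the number of projects we select
  • Space complexity: O(n) to store the projects in the heaps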

References