https://leetcode.com/problems/min-cost-to-connect-all-points/description/
給一個陣列 points
裡面裝著點不重複的二維座標資料
定義連結兩點的成本 (cost) 為兩點之曼哈頓距離 (Manhattan distance):
回傳使所有點相連的最小成本。只要任意兩點之間恰存在一條 simple path,所有點就被視為相連。
首先我們要搞懂題目想問的是什麼
題目有兩個關鍵線索:
如果我們把圖畫出來就會發現,第二點就代表若視連線為圖上的邊,那個我們連線出來的結果不應該會有 cycle
這下應該就能想到本題其實要考的就是最小生成樹 (minimum spanning tree)
提到 MST 就會想到大名鼎鼎的兩個演算法:Kruskal's algorithm 與 Prim's algorithm。兩者皆是採取 greedy 的策略,前者每次選取 weight 最小的邊加入 MST;後者則可以想成是每次選能使連出去的 weight 最小的那個點
本題我們採用 Kruskal's algorithm 結合 disjoint-set,其 pseudocode 可參考下方:
我們初始條件為對
下面提供本題我的寫法:
class Edge
{
public:
int cost;
int x;
int y;
Edge() : cost(0), x(0), y(0) {}
Edge(int cost, int x, int y) : cost(cost), x(x), y(y) {}
bool operator<(const Edge &e) const { return cost < e.cost; }
Edge &operator=(const Edge &e)
{
if (this != &e)
{
cost = e.cost;
x = e.x;
y = e.y;
}
return *this;
}
};
class Solution
{
public:
int minCostConnectPoints(vector<vector<int>> &points)
{
size_t n = points.size();
vector<Edge> edge(n * (n - 1) / 2);
for (int i = 0, k = 0; i < n; i++)
{
for (int j = i + 1; j < n; j++)
{
int x1 = points[i][0], x2 = points[j][0];
int y1 = points[i][1], y2 = points[j][1];
int d = abs(x1 - x2) + abs(y1 - y2);
edge[k++] = Edge(d, i, j);
}
}
sort(edge.begin(), edge.end());
vector<int> p(n), rank(n, 0);
iota(p.begin(), p.end(), 0); // 填入遞增
int ans = 0, count = 0;
for (auto &e : edge)
{
int rx = find(p, e.x), ry = find(p, e.y);
if (rx == ry)
continue;
ans += e.cost;
if (rank[rx] < rank[ry])
swap(rx, ry);
p[rx] = ry; // 合併
rank[ry] += (rank[rx] == rank[rx]);
if (++count == n - 1)
break;
}
return ans;
}
private:
// Disjoint Set Union -- find()
int find(vector<int> &parent, int x)
{
while (parent[x] != x) // 當 x 並非集合之根元素
{
int temp = parent[x];
parent[x] = parent[parent[x]];
x = temp;
}
return x;
}
};
一樣的核心精神,不過這個寫法可能會再更平易近人些
class Solution
{
public:
int dist(int x1, int y1, int x2, int y2)
{
return abs(x1 - x2) + abs(y1 - y2);
}
int minCostConnectPoints(vector<vector<int>> &points)
{
size_t n = points.size();
vector<int> edge(n, 0);
int ans = 0;
edge[0] = INT_MAX;
for (auto i = 1; i < n; i++)
edge[i] = dist(points[0][0], points[0][1], points[i][0], points[i][1]);
for (int i = 1; i < n; i++)
{
auto it = min_element(edge.begin(), edge.end());
ans += *it;
int index = it - edge.begin();
*it = INT_MAX;
for (auto i = 0; i < n; i++)
{
if (edge[i] == INT_MAX)
continue;
edge[i] = min(edge[i], dist(points[i][0], points[i][1], points[index][0], points[index][1]));
}
}
return ans;
}
};
【定義】
給定