---
# System prepended metadata

title: Leetcode 2316. Count Unreachable Pairs of Nodes in an Undirected Graph
tags: [graph, daily, leetcode, union find]

---

# Leetcode 2316. Count Unreachable Pairs of Nodes in an Undirected Graph
###### tags: `leetcode` `daily` `graph` `union find`

[Problem link](https://leetcode.com/problems/count-unreachable-pairs-of-nodes-in-an-undirected-graph/)

# Method: Union Find
:::info
:bulb: **Approach**:
The problem asks us to count the pairs of distinct nodes that are unreachable from each other, i.e. the pairs whose two nodes lie in different connected components.
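As a baseline for the definition, the sketch below counts unreachable pairs by brute force: it runs a BFS from every node and counts the nodes that BFS never reaches. This is not the approach of this note, and the function name `bruteForcePairs` is hypothetical; at O(N · (N + E)) it is only meant as a checker for small inputs.

```cpp
#include <cassert>
#include <queue>
#include <vector>
using namespace std;

// Hypothetical brute-force checker: O(N * (N + E)), small inputs only.
long long bruteForcePairs(int n, const vector<vector<int>>& edges) {
    vector<vector<int>> adj(n);
    for (const vector<int>& e : edges) {
        assert(0 <= e[0] && e[0] < n && 0 <= e[1] && e[1] < n);
        adj[e[0]].push_back(e[1]);
        adj[e[1]].push_back(e[0]);  // undirected edge
    }
    long long unreachableOrdered = 0;
    for (int s = 0; s < n; s++) {
        // BFS from s marks every node in s's connected component.
        vector<bool> seen(n, false);
        queue<int> q;
        q.push(s);
        seen[s] = true;
        while (!q.empty()) {
            int u = q.front();
            q.pop();
            for (int v : adj[u]) {
                if (!seen[v]) {
                    seen[v] = true;
                    q.push(v);
                }
            }
        }
        for (int t = 0; t < n; t++) {
            if (!seen[t]) unreachableOrdered++;  // (s, t) is an unreachable ordered pair
        }
    }
    return unreachableOrdered / 2;  // each unordered pair {s, t} was counted from both ends
}
```

For instance, with n = 7 and edges `[[0,2],[0,5],[2,4],[1,6],[5,4]]` the components are {0, 2, 4, 5}, {1, 6} and {3}, and the checker returns 4·2 + 4·1 + 2·1 = 14.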


:::warning
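**Base case 1**

Suppose the graph consists of two connected components,
groupA `{a1, ... an}` and groupB `{b1, ... bn}`; we also write `an` and `bn` for the number of nodes in each group.

a1 cannot reach any node in `{b1, ... bn}`
a2 cannot reach any node in `{b1, ... bn}`
..
an cannot reach any node in `{b1, ... bn}`

so every one of the `an` nodes in groupA forms an unreachable pair with every one of the `bn` nodes in groupB, and the number of unreachable pairs is `an * bn`.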

---
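**Base case 2**
Extending the base case above from two components to three,
groupA, groupB and groupC, with `cn` nodes in groupC,

the number of pairs of different nodes that are unreachable from each other is
`an * (bn + cn) + bn * cn`

Each node in groupA is unreachable from every node in groupB and groupC, and each node in groupB is unreachable from every node in groupC; the pairs between groupB and groupA are already counted in the first term.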
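A small numeric check of the three-component formula: the hypothetical helper `checkedThreeGroupCount` labels `an` nodes as group 0, `bn` nodes as group 1 and `cn` nodes as group 2, counts the cross-group pairs directly, and asserts that the count equals `an * (bn + cn) + bn * cn`.

```cpp
#include <cassert>
#include <vector>
using namespace std;

// Three-component formula from the text: an * (bn + cn) + bn * cn.
long long threeGroupFormula(long long an, long long bn, long long cn) {
    return an * (bn + cn) + bn * cn;
}

// Hypothetical helper: builds the group label of every node, counts the
// pairs (i, j), i < j, that sit in different groups, and checks the formula.
long long checkedThreeGroupCount(int an, int bn, int cn) {
    vector<int> groupOf;
    const int sizes[3] = {an, bn, cn};
    for (int g = 0; g < 3; g++)
        for (int k = 0; k < sizes[g]; k++) groupOf.push_back(g);

    long long direct = 0;
    for (size_t i = 0; i < groupOf.size(); i++)
        for (size_t j = i + 1; j < groupOf.size(); j++)
            if (groupOf[i] != groupOf[j]) direct++;

    assert(direct == threeGroupFormula(an, bn, cn));
    return direct;
}
```

For example, `checkedThreeGroupCount(2, 3, 4)` returns 2 * (3 + 4) + 3 * 4 = 26.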


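In general, let $g_x$ be the node count of group x, with groups numbered 0 through n.
The number of unreachable pairs is

$$
g_0 (g_1 + g_2 + \dots + g_n) + g_1 (g_2 + g_3 + \dots + g_n) + \dots + g_{n-1} g_n
= \sum_{i=0}^{n-1} \sum_{j=i+1}^{n} g_i g_j
$$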
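The general formula can be evaluated in a single pass without any division: keep the total size of the groups that come after the current one, and subtract each group as you pass it. This is a sketch under the assumption that the group sizes g0, ..., gn are already known; the name `pairsFromSizes` is hypothetical.

```cpp
#include <cassert>
#include <vector>
using namespace std;

// Hypothetical helper: evaluates
//   g0*(g1+...+gn) + g1*(g2+...+gn) + ... + g(n-1)*gn
// in one pass. `rest` holds the total size of the groups after the current one.
long long pairsFromSizes(const vector<long long>& g) {
    long long rest = 0;
    for (long long size : g) {
        assert(size >= 0);  // group sizes are node counts
        rest += size;       // rest = g0 + ... + gn
    }
    long long pairs = 0;
    for (long long size : g) {
        rest -= size;          // rest = sizes of the later groups only
        pairs += size * rest;  // term g_i * (g_{i+1} + ... + g_n)
    }
    return pairs;
}
```

With sizes {4, 2, 1} it adds 4·3 + 2·1 + 1·0 = 14.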
     
     
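Let $\text{total}$ be the number of nodes in the graph. The term $g_i (\text{total} - g_i)$ counts the pairs with one node in group i and the other node in any other group. Summing this term over all groups counts every unordered pair twice (once from each of its two groups), so the formula becomes

$$
\frac{g_0 (\text{total} - g_0) + g_1 (\text{total} - g_1) + \dots + g_n (\text{total} - g_n)}{2}
$$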
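A minimal sketch of the halved form on a list of group sizes; the name `pairsViaTotal` is hypothetical. The `assert` mirrors the argument above: the doubled sum is always even because each pair is counted exactly twice.

```cpp
#include <cassert>
#include <vector>
using namespace std;

// Hypothetical helper: (g0*(total-g0) + ... + gn*(total-gn)) / 2
long long pairsViaTotal(const vector<long long>& g) {
    long long total = 0;
    for (long long size : g) total += size;  // total = node count of the graph

    long long doubled = 0;
    for (long long size : g) doubled += size * (total - size);  // pairs leaving group i

    assert(doubled % 2 == 0);  // every unordered pair was counted twice
    return doubled / 2;
}
```

With sizes {2, 3, 4} (total 9), the doubled sum is 2·7 + 3·6 + 4·5 = 52, giving 26, the same value as the three-component formula.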
:::


Step 1: Use union find to compute the node count of each connected component (group).

Step 2: Apply the formula above to compute the final answer.

:::

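TC: O(N + E·α(N)), where E is the number of edges and α is the inverse Ackermann function (effectively constant) SC: O(N)
Full code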
```cpp=
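class union_find {
public:
    vector<int> parent;  // parent[i]: parent of node i; roots satisfy parent[i] == i
    vector<int> rank;    // upper bound on tree height, used by union by rank
    vector<int> cnt;     // cnt[r]: size of the component rooted at r (valid only for roots)
    union_find(int n) {
        parent.resize(n, 0);
        rank.resize(n, 0);
        cnt.resize(n, 0);
        // Initially every node is its own component of size 1.
        for(int i = 0 ; i < n ; i++) {
            parent[i] = i;
            cnt[i] = 1;
        }
    }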

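    // Returns the root of x, pointing every node on the path directly at the root (path compression).
    int Find(int x) {
        if(parent[x] != x) {
            parent[x] = Find(parent[x]);
        }
        return parent[x];
    }

    // Attaches root x under root y and adds x's component size to y.
    void merge(int x, int y) {
        parent[x] = y;
        cnt[y] += cnt[x];
    }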

    // Union by rank: attach the root of the lower-rank tree under the other root.
    void Union(int x, int y) {
        int px = Find(x);
        int py = Find(y);
        if(px == py) {
            return;
        }
        
        if(rank[px] > rank[py]) {
            merge(py, px);
        }
        else if(rank[py] > rank[px]) {
            merge(px, py);
        }
        else {
            merge(py, px);
            rank[px]++;
        }
    }
};

class Solution {
public:
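    long long countPairs(int n, vector<vector<int>>& edges) {
        union_find uf(n);
        long long output = 0;

        // Step 1: merge the endpoints of every edge into one component.
        for(vector<int> &v : edges) {
            uf.Union(v[0], v[1]);
        }
        // Step 2: each root i represents one component with cnt[i] nodes.
        // Sum cnt * (n - cnt) over all components; every pair is counted twice.
        for(int i = 0 ; i < n ; i++) {
            if(uf.parent[i] == i) {
                long long cnt = uf.cnt[i];  // long long avoids int overflow in cnt * (n - cnt)
                output += cnt * (n-cnt);
            }
        }
        return output / 2;
    }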
};
```



