Try   HackMD

【CSES】1137. Subtree Queries

題目連結

時間複雜度

  • O(NlogN)

解題想法

這題想考的點有兩個,一個是 樹壓平,另一個是 支持單點修改和查詢前綴和的資料結構

以下會一一做講解:

  1. 樹壓平

樹壓平是利用 DFS 順序的特性把樹壓成一個序列,在這個序列中每個節點和其子樹節點都會相連,因此我們就可以利用這個性質對這個序列使用資料結構或其他演算法做到子樹修改。

那要怎麼求樹壓平的這個序列呢?只要在DFS時,每碰到一個新的點就加進序列。

int timer; void dfs(int now, int pa) { pos[now] = ++timer; sz[now] = 1; for (int v : g[now]) { if (v == pa) continue; dfs(v, now); sz[now] += sz[v]; } }

現在我們只需要再知道每個子樹大小,就可以知道每個子樹區間了。

因為任意一個節點

x 一定比他的子樹都還要早被放進序列。 所以
x
的子樹區間就是
[pos[x],pos[x]+sz[x]1]

但是除了這個做法之外,還有另一種方式是紀錄每個點進入和離開的時間,其實和上面的差不多,所以就不多解釋了,直接看程式。

int timer; void dfs(int now, int pa) { st[now] = ++timer; for (int v : g[now]) { if (v == pa) continue; dfs(v, now); } ed[now] = timer; }

這樣

x 的子樹區間就是
[st[x],ed[x]]

  1. 支持單點修改和查詢前綴和的資料結構 - 資料結構 Binary Indexed Tree【BIT】

其實支持單點修改和查詢前綴和這個特性的話其實也可以使用線段樹,但是這題會用 BIT 的原因很簡單,就是因為他比線段樹好刻太多了,而且空間複雜度也不像線段樹需要的

4×N 那麼高,只需要開到
N
就可以了

接著簡短介紹一下 BIT 的操作

  1. 查詢前綴和

要查詢以

x 為結尾的前綴,就應該找到幾段不重疊,且聯集恰好是所求前綴的區間。

​​ int query( int x ){ ​​ int res = 0; ​​ while( x ){ ​​ res += BIT[x]; ​​ x -= x & (-x); ​​ } ​​ return res; ​​ }
  1. 單點修改

只要一直把

x 加上
lowbit(pos)
,就可以得到它和覆蓋它的所有節點,接著修改這些節點的答案就可以了。

​​ void update( int x, int delta ){ ​​ while ( x <= n ){ ​​ BIT[x] += delta; ​​ x += x & (-x); ​​ } ​​ }

完整程式

/* Question : CSES 1137. Subtree Queries */ #include<bits/stdc++.h> using namespace std; #define opt ios_base::sync_with_stdio(false); cin.tie(0); cout.tie(0); #define pirq(type) priority_queue<type, vector<type>, greater<type>> #define mem(x, value) memset(x, value, sizeof(x)); #define pii pair<int, int> #define pdd pair<double, double> #define pb push_back #define f first #define s second #define int long long const auto dir = vector< pair<int, int> > { {1, 0}, {0, 1}, {-1, 0}, {0, -1} }; const int MAXN = 2e5 + 50; const int Mod = 1e9 + 7; int n, q, a, b, st, timer, arr[MAXN], sz[MAXN], pos[MAXN], BIT[MAXN]; vector<vector<int>> graph; int query( int x ){ int res = 0; while( x ){ res += BIT[x]; x -= x & (-x); } return res; } void update( int x, int delta ){ while ( x <= n ){ BIT[x] += delta; x += x & (-x); } } void treeflat( int cnt, int fa ){ pos[cnt] = ++timer; sz[cnt] = 1; for( auto i : graph[cnt] ){ if( i == fa ) continue; treeflat(i, cnt); sz[cnt] += sz[i]; } } signed main(){ opt; cin >> n >> q; graph.resize(n+5); for( int i = 1 ; i <= n ; i++ ) cin >> arr[i]; for( int i = 0 ; i < n-1 ; i++ ){ cin >> a >> b; graph[a].pb(b); graph[b].pb(a); } treeflat(1, 0); // Init for( int i = 1 ; i <= n ; i++ ) update(pos[i], arr[i]); while( q-- ){ cin >> st; if( st == 1 ){ cin >> a >> b; update(pos[a], b - arr[a]); arr[a] = b; }else{ cin >> a; cout << query(pos[a] + sz[a] - 1) - query(pos[a] - 1) << "\n"; } } }