# CSES Tree Algorithms # Disclaimer 這篇題解假設你已經會寫樹上的DFS,並且$dep_u$為$u$的深度,$sz_u$為$u$的子樹大小,$fa(u)$為$u$的父節點,$child(u)$為$u$的子節點們。 # [Subordinates](https://cses.fi/problemset/task/1674) 求每個點的子樹大小$-1$,可以用樹 DP 維護子樹大小,轉移式如下 $$sz_u=1+\sum_{v:child(u)}sz_v$$ 複雜度:$O(n)$ :::spoiler code ```c++= #include <bits/stdc++.h> using namespace std; #define ll long long #define pb push_back const int MAXN=200005; vector <int> G[MAXN]; int sz[MAXN]; void dfs(int cur, int lst){ sz[cur]=1; for(int nxt: G[cur]){ if(nxt==lst) continue; dfs(nxt, cur); sz[cur]+=sz[nxt]; } } int main(){ int n; cin>>n; int fa; for(int i=2; i<=n; i++){ cin>>fa; G[fa].pb(i); G[i].pb(fa); } dfs(1, 0); for(int i=1; i<=n; i++){ cout<<sz[i]-1<<" "; } ``` ::: # [Tree Matching](https://cses.fi/problemset/task/1130) 問題可以被改成找最小點覆蓋(最少可以覆蓋每一條邊的點集)。由 Konig's Theorem 這會對。 樹 DP。對每個節點維護選擇它與不選擇它的時候它的子樹中的最小點覆蓋大小。設$dp_{0, u}$是不選擇$u$時$u$的子樹中的最小點覆蓋,而$dp_{1,u}$是選擇$u$時$u$的子樹中的最小點覆蓋,我們有 $$dp_{0,u}=\sum_{v:child(u)}dp_{1,u}$$以及 $$dp_{1,u}=\sum_{v:child(u)}\min(dp_{0,u}, dp_{1,u})$$ 複雜度:$O(n)$ :::spoiler code ```c++= #include <bits/stdc++.h> using namespace std; #define ll long long #define pb push_back const int MAXN=200005; int dp[2][MAXN]; vector <int> G[MAXN]; int vis[MAXN]; void dfs(int cur){ vis[cur]=1; dp[0][cur]=0; dp[1][cur]=1; for(int nxt: G[cur]){ if(vis[nxt]) continue; dfs(nxt); dp[1][cur]+=min(dp[0][nxt],dp[1][nxt]); dp[0][cur]+=dp[1][nxt]; } } void addedge(int u, int v){ G[u].pb(v); G[v].pb(u); } int main(){ int n, u, v; cin>>n; for(int i=0; i<n-1; i++){ cin>>u>>v; addedge(u,v); } dfs(1); cout<<min(dp[0][1], dp[1][1]); ``` ::: # [Tree Diameter](https://cses.fi/problemset/task/1131) 找樹直徑。先隨便找一個人,DFS找離它最遠的點,再對這個人 DFS找離它最遠的點。這兩點的距離就是直徑長(離根的距離就是深度,可以樹 DP)。 複雜度:$O(n)$ :::spoiler code ```c++= #include <bits/stdc++.h> using namespace std; #define pb push_back #define ll long long const int MAXN=200005; vector <int> G[MAXN]; int dep[MAXN], sink, mxdep=-1; void dfs(int cur, int lst){ dep[cur]=dep[lst]+1; if(dep[cur]>mxdep){ sink=cur; mxdep=dep[cur]; } for(int nxt: G[cur]){ if(nxt==lst) continue; dfs(nxt, cur); } } int main(){ int n; cin>>n; int u,v; for(int i=0; i<n-1; i++){ cin>>u>>v; G[u].pb(v); G[v].pb(u); } dfs(1,0); dep[0]=-1; mxdep=0; dfs(sink,0); cout<<dep[sink]<<endl; ``` ::: # [Tree Distances I](https://cses.fi/problemset/task/1132) 對每個點找離它最遠的點。這一定是直徑上的一個端點。先把直徑兩端點找出來再求每個點的距離。 複雜度:$O(n)$ :::spoiler code ```c++= #include <bits/stdc++.h> using namespace std; #define pb push_back #define ll long long const int MAXN=200005; vector <int> G[MAXN]; int dep1[MAXN], dep2[MAXN], s1, s2, sink, mxdep=-1; void dfs(int cur, int lst){ dep1[cur]=dep1[lst]+1; if(dep1[cur]>mxdep){ mxdep=dep1[cur]; sink=cur; } for(int nxt: G[cur]){ if(nxt==lst) continue; dfs(nxt, cur); } } void dfs2(int cur, int lst){ dep2[cur]=dep2[lst]+1; for(int nxt: G[cur]){ if(nxt==lst) continue; dfs2(nxt, cur); } } int main(){ int n; cin>>n; int u,v; for(int i=0; i<n-1; i++){ cin>>u>>v; G[u].pb(v); G[v].pb(u); } dfs(1,0); s1=sink; dep1[0]=-1; mxdep=0; dfs(s1,0); s2=sink; dep2[0]=-1; mxdep=0; dfs2(s2,0); for(int i=1; i<=n; i++){ cout<<max(dep1[i], dep2[i])<<" "; } } ``` ::: # [Tree Distances II](https://cses.fi/problemset/task/1133) 換根 DP。假設 $u$的答案為$dp_u$,若$v$是$u$的父節點,則轉移式如下 $$dp_u=dp_v+n-2sz_{u}$$ 其中$sz_u$為$u$的子樹大小 複雜度:$O(n)$ :::spoiler code ```c++= #include <bits/stdc++.h> using namespace std; #define ll long long #define pb push_back const int MAXN=200005; vector <int> G[MAXN]; ll sz[MAXN], dp[MAXN], dep[MAXN]; ll totalsize; void dfs1(int cur, int lst){ sz[cur]=1; dep[cur]=dep[lst]+1; for(int nxt: G[cur]){ if(nxt==lst) continue; dfs1(nxt, cur); sz[cur]+=sz[nxt]; } } void dfs2(int cur, int lst){ dp[cur]=dp[lst]+totalsize-(2*sz[cur]); for(int nxt: G[cur]){ if(nxt==lst) continue; dfs2(nxt, cur); } } void addedge(int u, int v){ G[u].pb(v); G[v].pb(u); } int main(){ int n; cin>>n; totalsize=n; int u,v; for(int i=0; i<n-1; i++){ cin>>u>>v; addedge(u,v); } dep[0]=-1; dfs1(1,0); dp[0]=n; for(int i=1; i<=n; i++){ dp[0]+=dep[i]; } dfs2(1,0); for(int i=1; i<=n; i++){ cout<<dp[i]<<" "; } ``` ::: # [Company Queries I](https://cses.fi/problemset/task/1687) 問一個節點高$k$層的祖先是誰。考慮倍增,存下每個節點第$1,2,4,...,2^{18}$層的祖先再組合成第$k$層($u$第$2^{i+1}$層的祖先是它第$2^i$層的祖先的第$2^i$層的祖先) 複雜度:$O(n\log{n}+q\log{n})$ :::spoiler code ```c++= #include <bits/stdc++.h> using namespace std; #define ll long long #define pb push_back const int MAXN=200005; int bz[20][MAXN]; int main(){ int n,q; cin>>n>>q; for(int i=2; i<=n; i++){ cin>>bz[0][i]; } for(int i=1; i<=19; i++){ for(int j=1; j<=n; j++){ if(bz[i-1][j]) bz[i][j]=bz[i-1][bz[i-1][j]]; } } int x,y; for(int aa=0; aa<q; aa++){ cin>>x>>y; for(int i=0; i<20; i++){ if(y&(1<<i)){ if (x) x=bz[i][x]; } } if(x) cout<<x<<endl; else cout<<-1<<endl; } ``` ::: # [Company Queries II](https://cses.fi/problemset/task/1688) 找 LCA。考慮倍增,如果兩個人深度不一樣先讓低的往上跳他們的深度差,接著兩個人一起往上跳,直到跳到一樣的位置(如果$2^i$層的祖先一樣,表示 LCA 比它們高$2^i$層以下) 複雜度:$O((n+q)\log{n})$ :::spoiler code ```c++= #include <bits/stdc++.h> using namespace std; #define ll long long #define pb push_back const int MAXN=200005; vector <int> G[MAXN]; int bz[20][MAXN]; int dep[MAXN]; void dfs(int cur, int lst){ dep[cur]=dep[lst]+1; for(int nxt: G[cur]){ if(nxt==lst) continue; dfs(nxt, cur); } } void build(int n){ for(int i=2; i<=n; i++){ cin>>bz[0][i]; G[i].pb(bz[0][i]); G[bz[0][i]].pb(i); } for(int i=1; i<20; i++){ for(int j=1; j<=n; j++){ if(bz[i-1][j]) bz[i][j]=bz[i-1][bz[i-1][j]]; } } } int main(){ int n,q; cin>>n>>q; build(n); dfs(1,0); int a,b; int jump; for(int aa=0; aa<q; aa++){ cin>>a>>b; if(dep[a]>dep[b]) swap(a,b); jump=dep[b]-dep[a]; for(int i=0; i<20; i++){ if(jump&(1<<i)) b=bz[i][b]; } if(a==b) cout<<a<<endl; else{ for(int i=19; i>-1; i--){ if(bz[i][a]!=bz[i][b]){ a=bz[i][a]; b=bz[i][b]; } } cout<<bz[0][a]<<endl; } } } ``` ::: # [Distance Queries](https://cses.fi/problemset/task/1135) 問兩個點$u,v$的最短路徑距離,那就會是$dep_u+dep_v-2dep_{lca(u,v)}$。樹DP求深度,倍增求LCA。 **有點卡常,記得開 IO 優化。我開完之後 0.94 秒** 複雜度:$O((n+q)\log{n})$ :::spoiler code ```c++= #include <bits/stdc++.h> using namespace std; #define ll long long #define pb push_back const int MAXN=200005; vector <int> G[MAXN]; int bz[20][MAXN]; int dep[MAXN]; void dfs(int cur, int lst){ dep[cur]=dep[lst]+1; for(int nxt: G[cur]){ if(nxt==lst) continue; bz[0][nxt]=cur; dfs(nxt, cur); } } void build(int n){ int u,v; for(int i=0; i<n-1; i++){ cin>>u>>v; G[u].pb(v); G[v].pb(u); } dfs(1,0); for(int i=1; i<20; i++){ for(int j=1; j<=n; j++){ if(bz[i-1][j]) bz[i][j]=bz[i-1][bz[i-1][j]]; } } } int jump; int LCA(int a, int b){ if(dep[a]>dep[b]) swap(a,b); jump=dep[b]-dep[a]; for(int i=0; i<20; i++){ if(jump&(1<<i)) b=bz[i][b]; } if(a==b) return a; else{ for(int i=19; i>-1; i--){ if(bz[i][a]!=bz[i][b]){ a=bz[i][a]; b=bz[i][b]; } } return bz[0][a]; } } int main(){ int n,q; cin>>n>>q; build(n); int x,y,c; for(int aa=0; aa<q; aa++){ cin>>x>>y; c=LCA(x,y); cout<<dep[x]+dep[y]-dep[c]-dep[c]<<endl; } } ``` ::: # [Counting Paths](https://cses.fi/problemset/task/1136/) 考慮前綴和DP,每次我們加入一條$u$到$v$的路徑,我們就把$dp_u, dp_v$都加$1$,然後把$dp_{lca(u,v)}, dp_{fa(lca(u,v))}$都減$1$ 最後我們對整棵樹 DFS,把每個節點加上它所有子節點的$dp$值。 複雜度:$O((n+m)\log{n})$ :::spoiler code ```c++= #include <bits/stdc++.h> using namespace std; #define ll long long #define pb push_back const int MAXN=200014; vector <int> G[MAXN]; int dep[MAXN]; int bz[20][MAXN]; void getdep(int cur, int lst){ dep[cur]=dep[lst]+1; for(int nxt: G[cur]){ if(nxt==lst) continue; bz[0][nxt]=cur; getdep(nxt, cur); } } void build(int n){ int u,v; for(int i=0; i<n-1; i++){ cin>>u>>v; G[u].pb(v); G[v].pb(u); } getdep(1,0); for(int i=1; i<20; i++){ for(int j=1; j<=n; j++){ if(bz[i-1][j]) bz[i][j]=bz[i-1][bz[i-1][j]]; } } } int jump; int lca(int a, int b){ if(dep[a]>dep[b]) swap(a,b); jump=dep[b]-dep[a]; for(int i=0; i<20; i++){ if(jump&(1<<i)) b=bz[i][b]; } if(a==b) return a; for(int i=19; i>=0; i--){ if(bz[i][b]!=bz[i][a]){ a=bz[i][a]; b=bz[i][b]; } } return bz[0][a]; } int dp[MAXN]; void dfs(int cur, int lst){ for(int nxt: G[cur]){ if(nxt==lst) continue; dfs(nxt, cur); dp[cur]+=dp[nxt]; } } int main(){ int n, k; cin>>n>>k; build(n); int u,v; int l; for(int i=0; i<k; i++){ cin>>u>>v; dp[u]++; dp[v]++; l=lca(u,v); dp[l]--; dp[bz[0][l]]--; } dfs(1,0); for(int i=1; i<=n; i++){ cout<<dp[i]<<" "; } } ``` ::: # [Subtree Queries](https://cses.fi/problemset/task/1137/) 求子樹點權和,帶修改。先做樹壓平,之後變成求進入時間和離開時間的區間和,用BIT算區間和。 複雜度:$O(n+q\log{n})$ :::spoiler code ```c++= #include <bits/stdc++.h> using namespace std; #define ll long long #define pb push_back #define LO(x) ((x)&(-(x))) const int MAXN=200005; vector <int> G[MAXN]; ll syp[2*MAXN], tf[MAXN], tl[MAXN], val[MAXN]; bool vis[MAXN]; ll t=0; void dfs(int cur){ vis[cur]=1; t++; tf[cur]=t; syp[t]=val[cur]; for(int nxt: G[cur]){ if(vis[nxt]) continue; dfs(nxt); } t++; tl[cur]=t; syp[t]=val[cur]; } struct BIT{ int n; ll summation; ll s[MAXN*2]; void build(int _n){ n=_n; memset(s,0,sizeof(s)); } void sgadd(int i, int x){ while(i<=2*n){ s[i]+=x; i+=LO(i); } } ll pfsum(int i){ summation=0; while(i>0){ summation+=s[i]; i-=LO(i); } return summation; } }bit; void addedge(int u, int v){ G[u].pb(v); G[v].pb(u); } int main(){ int n,q; cin>>n>>q; for(int i=1; i<=n; i++){ cin>>val[i]; } int u,v; for(int i=0; i<n-1; i++){ cin>>u>>v; addedge(u,v); } dfs(1); bit.build(n*2); for(int i=1; i<=n*2; i++){ bit.sgadd(i, syp[i]); } cout<<endl; ll type, cg; ll ans, minus; for(int aa=0; aa<q; aa++){ cin>>type>>v; if(type==1){ cin>>cg; bit.sgadd(tf[v], cg-syp[tf[v]]); syp[tf[v]]=cg; bit.sgadd(tl[v], cg-syp[tl[v]]); syp[tl[v]]=cg; } else{ ans=bit.pfsum(tl[v]); if(tf[v]){ minus=bit.pfsum(tf[v]-1); } else minus=0; ans-=minus; ans/=2; cout<<ans<<'\n'; } } } ``` ::: # [Path Queries](https://cses.fi/problemset/task/1138) 求根到節點路徑上點權和,帶修改。考慮輕重鏈剖分。 理論上對每一條鏈開一棵線段樹,存這條鏈的點權和,每次詢問查$\log n$棵線段樹。但實際上只需要一棵線段樹,因為我每次先走重邊,鏈就會在壓平的樹上是一個連續段。我們只要把點跳到鏈頂就可以了 (樹壓平聽說也可以) 複雜度:$O(n\log^2n)$ 或 $O(n\log n)$ :::spoiler code(兩個 log) ```c++= #include <bits/stdc++.h> using namespace std; #define ll long long #define pb push_back #define L(id) id*2+1 #define R(id) id*2+2 const int MAXN=200014; vector <int> G[MAXN], ord; ll val[MAXN], seg[MAXN*4]; int sz[MAXN], dep[MAXN], fa[MAXN], hson[MAXN], in[MAXN], out[MAXN], deg[MAXN]; int t=0; void addedge(int u, int v){ G[u].pb(v); G[v].pb(u); deg[u]++; deg[v]++; } void dfs(int cur, int lst){ dep[cur]=dep[lst]+1; fa[cur]=lst; sz[cur]=1; int mx=0; for(int i=0; i<deg[cur]; i++){ if(G[cur][i]==lst) continue; dfs(G[cur][i], cur); sz[cur]+=sz[G[cur][i]]; if(sz[G[cur][i]]>mx){ mx=sz[G[cur][i]]; swap(G[cur][i], G[cur][0]); } } } void hld(int cur, int lst){ in[cur]=++t; ord.pb(cur); for(int nxt: G[cur]){ if(nxt==lst) continue; if(nxt==G[cur][0]) hson[nxt]=hson[cur]; else hson[nxt]=nxt; hld(nxt, cur); } out[cur]=t; } void pull(int id){ seg[id]=seg[L(id)]+seg[R(id)]; } void build(int id, int l, int r){ if(l==r){ seg[id]=val[ord[l]]; return; } int m=(l+r)/2; build(L(id), l, m); build(R(id), m+1, r); pull(id); } void change(int id, int l, int r, ll x, int pos){ if(l==r){ seg[id]=x; return; } int m=(l+r)/2; if(pos<=m) change(L(id), l, m, x, pos); else change(R(id), m+1, r, x, pos); pull(id); } ll query(int id, int l, int r, int L, int R){ if(L<=l&&R>=r) return seg[id]; int m=(l+r)/2; ll tot=0; if(L<=m) tot+=query(L(id), l, m, L, R); if(R>m) tot+=query(R(id), m+1, r, L, R); return tot; } int main(){ ios::sync_with_stdio(false); cin.tie(0); int n,q; cin>>n>>q; ord.pb(0); dep[0]=-1; hson[1]=1; int u, v, type; for(int i=1; i<=n; i++){ cin>>val[i]; } for(int i=0; i<n-1; i++){ cin>>u>>v; addedge(u,v); } dfs(1,0); hld(1,0); build(1,1,n); int a; ll b; int k; for(int i=0; i<q; i++){ cin>>type; if(type==1){ cin>>a>>b; change(1,1,n,b,in[a]); val[a]=b; } else{ cin>>a; ll ans=0; while(hson[a]!=1){ ans+=query(1,1,n,in[hson[a]],in[a]); a=fa[hson[a]]; } ans+=query(1,1,n,in[1],in[a]); cout<<ans<<'\n'; } } } ``` ::: # [Path Queries II](https://cses.fi/problemset/task/2134) 作法跟上一題差不多,只是線段樹改存點權最大值,一樣用HLD做。 **<font color="#f00">這題非常的卡常,只加IO優化通常不會過(除非你code寫很好),建議加Pragma O3</font>** 複雜度:$O(n\log^2n)$ :::spoiler code ```c++ #include <iostream> #include <vector> #include <algorithm> #pragma GCC optimize("O3,unroll-loops") #pragma GCC target("avx,avx2") using namespace std; #define ll long long #define pb push_back #define L(id) 2*id+1 #define R(id) 2*id+2 const int MAXN=200014; vector <int> G[MAXN]; vector <int> ord; int val[MAXN], sz[MAXN], fa[MAXN], deg[MAXN], seg[MAXN*4], in[MAXN], out[MAXN], hson[MAXN], dep[MAXN]; int t=0; void addedge(int u, int v){ G[u].pb(v); G[v].pb(u); deg[u]++; deg[v]++; } void dfs(int cur, int lst){ fa[cur]=lst; sz[cur]=1; dep[cur]=dep[lst]+1; int mx=0; for(int i=0; i<deg[cur]; i++){ if(G[cur][i]==lst) continue; dfs(G[cur][i], cur); sz[cur]+=sz[G[cur][i]]; if(sz[G[cur][i]]>mx){ mx=sz[G[cur][i]]; swap(G[cur][i], G[cur][0]); } } } void hld(int cur, int lst){ in[cur]=++t; ord.pb(cur); for(int nxt: G[cur]){ if(nxt==lst) continue; if(nxt==G[cur][0]) hson[nxt]=hson[cur]; else hson[nxt]=nxt; hld(nxt, cur); } out[cur]=t; } void pull(int id){ seg[id]=max(seg[L(id)], seg[R(id)]); } void build(int id, int l, int r){ if(l==r){ seg[id]=val[ord[l]]; return; } int m=(l+r)/2; build(L(id), l, m); build(R(id), m+1, r); pull(id); } void change(int id, int l, int r, int x, int pos){ if(l==r){ seg[id]=x; return; } int m=(l+r)/2; if(pos<=m) change(L(id), l, m, x, pos); else change(R(id), m+1, r, x, pos); pull(id); } int query(int id, int l, int r, int L, int R){ if(L<=l&&R>=r) return seg[id]; int m=(l+r)/2; int res=-2e9; if(L<=m) res=max(res, query(L(id), l, m, L, R)); if(R>m) res=max(res, query(R(id), m+1, r, L, R)); return res; } int main(){ ios::sync_with_stdio(false); cin.tie(0); int n,q; cin>>n>>q; int a,b; hson[1]=1; dep[0]=-1; ord.pb(0); for(int i=1; i<=n; i++) cin>>val[i]; for(int i=0; i<n-1; i++){ cin>>a>>b; addedge(a,b); } dfs(1,0); hld(1,0); build(1,1,n); int type; for(int i=0; i<q; i++){ cin>>type>>a>>b; if(type==1){ change(1, 1, n, b, in[a]); val[a]=b; } else{ int ans=val[a]; while(hson[a]!=hson[b]){ if(dep[hson[a]]<dep[hson[b]]) swap(a,b); ans=max(ans, query(1, 1, n, in[hson[a]], in[a])); a=fa[hson[a]]; } if(dep[a]<dep[b]) swap(a,b); ans=max(ans, query(1, 1, n, in[b], in[a])); cout<<ans<<" "; } } } ``` ::: # [Distinct Colors](https://cses.fi/problemset/task/1139/) 問一個節點的子樹有多少種相異的值。考慮啟發式合併,對每個節點開一個set,DFS下去把子節點的顏色加進去,但如果子節點大小更大就先交換。 複雜度:$O(n\log n)$ :::spoiler code ```c++= #include <bits/stdc++.h> using namespace std; #define ll long long #define pb push_back const int MAXN=200005; vector <int> G[MAXN]; set <int> s[MAXN]; int col[MAXN]; void dfs (int cur, int lst){ for(int nxt: G[cur]){ if(nxt==lst) continue; dfs(nxt, cur); if(s[cur].size()<s[nxt].size()){ swap(s[cur], s[nxt]); } for(auto c: s[nxt]) s[cur].insert(c); } col[cur]=s[cur].size(); } void addedge(int u, int v){ G[u].pb(v); G[v].pb(u); } int main(){ int n; cin>>n; int temp; int u,v; for(int i=1; i<=n; i++){ cin>>temp; s[i].insert(temp); } for(int i=0; i<n-1; i++){ cin>>u>>v; addedge(u,v); } dfs(1,0); for(int i=1; i<=n; i++){ cout<<col[i]<<" "; } } ``` ::: # [Finding a Centroid](https://cses.fi/problemset/task/2079) 找樹重心。先把每個節點的子樹大小求出來。接著從根DFS下去,如果一個子節點子樹大小超過$n/2$就往下DFS,然後如果找不到這樣的子節點就return。 複雜度:$O(n)$ :::spoiler code ```c++= #include <bits/stdc++.h> using namespace std; #define ll long long #define pb push_back const int MAXN=200005; vector <int> G[MAXN]; int sz[MAXN]; void dfs(int cur, int last){ for(int nxt: G[cur]){ if(nxt==last) continue; dfs(nxt, cur); sz[cur]+=sz[nxt]; } sz[cur]++; } int findcentroid(int cur, int last, int size){ for(int nxt: G[cur]){ if(nxt==last) continue; if(sz[nxt]>size) return findcentroid(nxt, cur, size); } return cur; } int main(){ int n; cin>>n; int u,v; for(int i=1; i<n; i++){ cin>>u>>v; G[u].pb(v); G[v].pb(u); } dfs(1,0); int ans=findcentroid(1,0,n/2); cout<<ans; } ``` ::: # [Fixed-Length Paths I](https://cses.fi/problemset/list/) 重心剖分。建重心樹然後每個點對他的子樹 DFS,用深度檢查答案。 複雜度:常數大的 $O(n\log n)$ **開個 IO 優化吧,我開完 0.99s** :::spoiler ```c++= #include <bits/stdc++.h> using namespace std; #define ll long long #define pb push_back const int MAXN=200005; vector <int> G[MAXN]; int sz[MAXN], dead[MAXN], cnt[MAXN]={1}, low; int n,k; ll ans=0; int getsize(int cur, int lst){ sz[cur]=1; for(int nxt: G[cur]){ if(nxt==lst || dead[nxt]) continue; sz[cur]+=getsize(nxt, cur); } return sz[cur]; } int findcentroid(int cur, int lst, int cursz){ for(int nxt: G[cur]){ if(nxt==lst || dead[nxt]) continue; if(sz[nxt]>=cursz) return findcentroid(nxt, cur, cursz); } return cur; } void solve(int cur, int lst, bool keep, int dep){ if(dep>k) return; low=max(low, dep); if(keep) cnt[dep]++; else ans+=cnt[k-dep]; for(int nxt: G[cur]){ if(nxt==lst || dead[nxt]) continue; solve(nxt, cur, keep, dep+1); } } void decompose(int cur){ int cent=findcentroid(cur, 0, getsize(cur, 0)>>1); low=0; dead[cent]=1; for(int nxt: G[cent]){ if(dead[nxt]) continue; solve(nxt, cent, false, 1); solve(nxt, cent, true, 1); } fill(cnt+1, cnt+low+1, 0); for(int nxt: G[cent]){ if(dead[nxt]) continue; decompose(nxt); } } int main(){ ios::sync_with_stdio(false); cin.tie(0); cin>>n>>k; int u,v; for(int i=0; i<n-1; i++){ cin>>u>>v; G[u].pb(v); G[v].pb(u); } decompose(1); cout<<ans; } ``` ::: # [Fixed-Length Paths II](https://cses.fi/problemset/task/2081) 一樣重心剖分。最直接的方法是前一題的作法配 BIT,但因為常數太大可以用區間和加速。 複雜度:<font color="#f00">$O(n\log^2n)$</font> 或 $O(n\log n)$ **<font color="#f00"> 這題很卡常,兩個$\log$大概不會過</font>** :::spoiler code ```c++= #include <iostream> #include <array> #include <vector> using namespace std; #define ll long long #define pb push_back #define LO(x) x&-x const int MAXN=200005; int n, k1, k2; vector <int> G[MAXN], rem; int sz[MAXN], dead[MAXN], low, cnt[MAXN], sc_nt[MAXN]{1}, vlow; int p1[MAXN], p2[MAXN]; ll ans=0; int findsubtree(int cur, int lst){ sz[cur]=1; for(int nxt: G[cur]){ if((nxt!=lst) && (!dead[nxt])){ findsubtree(nxt, cur); sz[cur]+=sz[nxt]; } } return sz[cur]; } int findcentroid(int cur, int lst, int cursz){ for(int nxt: G[cur]){ if((nxt!=lst) && (!dead[nxt]) && sz[nxt]>=cursz) return findcentroid(nxt, cur, cursz); } return cur; } void solve(int cur, int lst, int dep){ if(dep>k2) return; low=max(low, dep); cnt[dep]++; for(int nxt: G[cur]){ if((nxt!=lst) && (!dead[nxt])) solve(nxt, cur, dep+1); } } void decompose(int cur){ int cent=findcentroid(cur, 0, findsubtree(cur, 0)>>1); vlow=0; dead[cent]=1; ll sum_ori; if(k1==1) sum_ori=1ll; else sum_ori=0ll; for(int nxt: G[cent]){ if(dead[nxt]) continue; low=0; solve(nxt, cent, 1); ll sum=sum_ori; for(int i=1; i<=low; i++){ ans+=(sum*cnt[i]); if(k2-i>=0) sum-=sc_nt[k2-i]; if(k1-i-1>=0) sum+=sc_nt[k1-i-1]; } for(int i=k1-1; i<=min(k2-1, low); i++){ sum_ori+=cnt[i]; } for(int i=1; i<=low; i++){ sc_nt[i]+=cnt[i]; } vlow=max(vlow, low); fill(cnt, cnt+low+1, 0); } fill(sc_nt+1, sc_nt+vlow+1, 0); for(int nxt: G[cent]){ if(!dead[nxt]) decompose(nxt); } } int main(){ ios::sync_with_stdio(false); cin.tie(0); cin>>n>>k1>>k2; int u, v; for(int i=0; i<n-1; i++){ cin>>u>>v; G[u].pb(v); G[v].pb(u); } decompose(1); cout<<ans; } ``` :::