# 輕重鍊剖分HLD * 問題:詢問樹上兩點路徑和,**支援修改**(所以不能倍增)。 * HLD: $O(\log^2 n)$詢問,$O(\log n)$修改 :::info 重子樹:點數量最多的子樹(如果都一樣則選任意個) 輕子樹:不是重子樹的其他子樹 輕邊:連接輕子樹的邊 重邊:連接重子樹的邊 重鍊:一條連續的重邊 LCA(a, b):a,b的最小共同祖先 ::: * 依照定義,把重字樹,輕子樹以及其輕重邊找出來,並對每條重鍊分別開線段樹(or其他資節) * 性質: * 每條重鍊不會有分支(不然就不是鍊了) * 從葉節點往上到根節點,其路徑最多分別會遇到$\log n$個輕邊和重鍊(證: :::spoiler :::info 定義size(x)為點x的大小 假設b為a的輕子樹,c為a的重子樹,size(a)>=size(b)+size\(c\),且size\(c\)>=size(b),則size(a)>=2\*size(b)成立,而c一定存在因為如果沒有重子樹就沒有輕子樹 由size(a)>=2\*size(b)可知,每次往上一個輕邊,你目前的子樹大小就至少增加了兩倍(原本是size(b),往上變size(a)),而最多也只有n個點,也就是說你只能往上$\log_2 n$次 而因為多個重鍊一定是由輕邊接起來的,所以已知有log n個輕邊,那就只會有log n +1個重鍊 ::: * 詢問a, b兩點路徑,可以分別處理a到LCA(a, b)和b到LCA(a, b),所以先單看a到LCA(a, b)的部份 * 從a往上,當遇到重鍊,線段樹區間查詢在那條鍊上會走到的範圍的和,並且直接跳到鍊的頂端。 * 遇到輕邊,直接一個一個往上加 * 有log n個重鍊,每個重鍊的區間查詢log n。並且有log n的輕邊,所以總共$O(\log^2 n+\log n)=O(\log^2 n)$ * 實做上直接把每個重鍊都放在同一個線段樹上就好,只要確保同個鍊是在一個連續區間上,而輕鍊也可以直接放在線段樹,反正複雜度還是log^2 n。所以建樹的方法就是做個dfs,並且在外層存一個cnt變數紀錄目前線段樹的位置。每次優先遞迴到重子樹(cnt++),醬重鍊就在連續區間了(要記鍊的頂端,如果是輕子樹那頂端就是自己)。 https://cses.fi/problemset/task/2134 ```cpp= #include <bits/stdc++.h> #include <queue> #define pb push_back #define F first #define S second #define rep(X, a,b) for(int X=a;X<b;++X) #define ALL(a) (a).begin(), (a).end() #define SZ(a) (int)(a).size() #define NL "\n" using namespace std; typedef pair<long long,long long> pll; typedef pair<int,int> pii; typedef long long ll; ll n, q, val[200010], tree[800010]={0}; vector<int> adj[200010]; int dep[200010], up[200010], sz[200010], rt[200010], id[200010]; //dep:the depth, up:its parent, sz:subtree size, rt:its chain's root, id:pos in segment tree void upd(int v, int l, int r, int pos, ll x){ if(l==r){ tree[v]=x; return; } int mid=(l+r)>>1; if(pos<=mid) upd(2*v, l, mid, pos, x); else if(pos>mid) upd(2*v+1, mid+1, r, pos, x); tree[v]=max(tree[2*v], tree[2*v+1]); } ll get(int v, int l, int r, int ql, int qr){ if(l==ql && r==qr) return tree[v]; int mid=(l+r)>>1; if(qr<=mid) return get(2*v, l, mid, ql, qr); else if(ql>mid) return get(2*v+1, mid+1, r, ql, qr); return max(get(2*v, l, mid, ql, mid), get(2*v+1, mid+1, r, mid+1, qr)); } void dfs(int v, int par){ sz[v]=1; dep[v]=dep[par]+1; for(auto a:adj[v]){ if(a!=par){ dfs(a, v); sz[v]+=sz[a]; up[a]=v; } } } int cnt=0; void hld(int v, int par, int top){ id[v]=cnt++; rt[v]=top; upd(1, 0, n, id[v], val[v]); int mx=-1, mxi=-1; for(auto a:adj[v]){ if(a!=par) if(sz[a]>mx){ mx=sz[a]; mxi=a; } } if(mxi==-1) return; hld(mxi, v, top); for(auto a:adj[v]){ if(a!=par && a!=mxi){ hld(a, v, a); } } } ll query(int a, int b){ ll res=0; //while till the the chain of a, b is the lca while(rt[a]!=rt[b]){ if(dep[rt[a]]<dep[rt[b]]) swap(a, b); res=max(res, get(1, 0, n, id[rt[a]], id[a])); a=up[rt[a]]; } //walk through the rest if(dep[a]>dep[b]) swap(a, b); res=max(res, get(1, 0, n, id[a], id[b])); return res; } int main(){ ios_base::sync_with_stdio(false); cin.tie(0); cout.tie(0); cin>>n>>q; rep(i,1,n+1) cin>>val[i]; int a, b; rep(i,0,n-1){ cin>>a>>b; adj[a].pb(b); adj[b].pb(a); } dep[0]=0; dfs(1, 0); hld(1, 0, 1); int c; while(q--){ cin>>c; if(c==1){ ll s, x; cin>>s>>x; upd(1, 0, n, id[s], x); } else{ ll x, y; cin>>x>>y; cout<<query(x, y)<<" "; } } cout<<NL; } ```