# 輕重鍊剖分HLD
* 問題:詢問樹上兩點路徑和,**支援修改**(所以不能倍增)。
* HLD: $O(\log^2 n)$詢問,$O(\log n)$修改
:::info
重子樹:點數量最多的子樹(如果都一樣則選任意個)
輕子樹:不是重子樹的其他子樹
輕邊:連接輕子樹的邊
重邊:連接重子樹的邊
重鍊:一條連續的重邊
LCA(a, b):a,b的最小共同祖先
:::
* 依照定義,把重字樹,輕子樹以及其輕重邊找出來,並對每條重鍊分別開線段樹(or其他資節)
* 性質:
* 每條重鍊不會有分支(不然就不是鍊了)
* 從葉節點往上到根節點,其路徑最多分別會遇到$\log n$個輕邊和重鍊(證:
:::spoiler
:::info
定義size(x)為點x的大小
假設b為a的輕子樹,c為a的重子樹,size(a)>=size(b)+size\(c\),且size\(c\)>=size(b),則size(a)>=2\*size(b)成立,而c一定存在因為如果沒有重子樹就沒有輕子樹
由size(a)>=2\*size(b)可知,每次往上一個輕邊,你目前的子樹大小就至少增加了兩倍(原本是size(b),往上變size(a)),而最多也只有n個點,也就是說你只能往上$\log_2 n$次
而因為多個重鍊一定是由輕邊接起來的,所以已知有log n個輕邊,那就只會有log n +1個重鍊
:::
* 詢問a, b兩點路徑,可以分別處理a到LCA(a, b)和b到LCA(a, b),所以先單看a到LCA(a, b)的部份
* 從a往上,當遇到重鍊,線段樹區間查詢在那條鍊上會走到的範圍的和,並且直接跳到鍊的頂端。
* 遇到輕邊,直接一個一個往上加
* 有log n個重鍊,每個重鍊的區間查詢log n。並且有log n的輕邊,所以總共$O(\log^2 n+\log n)=O(\log^2 n)$
* 實做上直接把每個重鍊都放在同一個線段樹上就好,只要確保同個鍊是在一個連續區間上,而輕鍊也可以直接放在線段樹,反正複雜度還是log^2 n。所以建樹的方法就是做個dfs,並且在外層存一個cnt變數紀錄目前線段樹的位置。每次優先遞迴到重子樹(cnt++),醬重鍊就在連續區間了(要記鍊的頂端,如果是輕子樹那頂端就是自己)。
https://cses.fi/problemset/task/2134
```cpp=
#include <bits/stdc++.h>
#include <queue>
#define pb push_back
#define F first
#define S second
#define rep(X, a,b) for(int X=a;X<b;++X)
#define ALL(a) (a).begin(), (a).end()
#define SZ(a) (int)(a).size()
#define NL "\n"
using namespace std;
typedef pair<long long,long long> pll;
typedef pair<int,int> pii;
typedef long long ll;
ll n, q, val[200010], tree[800010]={0};
vector<int> adj[200010];
int dep[200010], up[200010], sz[200010], rt[200010], id[200010];
//dep:the depth, up:its parent, sz:subtree size, rt:its chain's root, id:pos in segment tree
void upd(int v, int l, int r, int pos, ll x){
if(l==r){
tree[v]=x;
return;
}
int mid=(l+r)>>1;
if(pos<=mid) upd(2*v, l, mid, pos, x);
else if(pos>mid) upd(2*v+1, mid+1, r, pos, x);
tree[v]=max(tree[2*v], tree[2*v+1]);
}
ll get(int v, int l, int r, int ql, int qr){
if(l==ql && r==qr) return tree[v];
int mid=(l+r)>>1;
if(qr<=mid) return get(2*v, l, mid, ql, qr);
else if(ql>mid) return get(2*v+1, mid+1, r, ql, qr);
return max(get(2*v, l, mid, ql, mid), get(2*v+1, mid+1, r, mid+1, qr));
}
void dfs(int v, int par){
sz[v]=1;
dep[v]=dep[par]+1;
for(auto a:adj[v]){
if(a!=par){
dfs(a, v);
sz[v]+=sz[a];
up[a]=v;
}
}
}
int cnt=0;
void hld(int v, int par, int top){
id[v]=cnt++;
rt[v]=top;
upd(1, 0, n, id[v], val[v]);
int mx=-1, mxi=-1;
for(auto a:adj[v]){
if(a!=par) if(sz[a]>mx){
mx=sz[a];
mxi=a;
}
}
if(mxi==-1) return;
hld(mxi, v, top);
for(auto a:adj[v]){
if(a!=par && a!=mxi){
hld(a, v, a);
}
}
}
ll query(int a, int b){
ll res=0;
//while till the the chain of a, b is the lca
while(rt[a]!=rt[b]){
if(dep[rt[a]]<dep[rt[b]]) swap(a, b);
res=max(res, get(1, 0, n, id[rt[a]], id[a]));
a=up[rt[a]];
}
//walk through the rest
if(dep[a]>dep[b]) swap(a, b);
res=max(res, get(1, 0, n, id[a], id[b]));
return res;
}
int main(){
ios_base::sync_with_stdio(false);
cin.tie(0);
cout.tie(0);
cin>>n>>q;
rep(i,1,n+1) cin>>val[i];
int a, b;
rep(i,0,n-1){
cin>>a>>b;
adj[a].pb(b);
adj[b].pb(a);
}
dep[0]=0;
dfs(1, 0);
hld(1, 0, 1);
int c;
while(q--){
cin>>c;
if(c==1){
ll s, x;
cin>>s>>x;
upd(1, 0, n, id[s], x);
}
else{
ll x, y;
cin>>x>>y;
cout<<query(x, y)<<" ";
}
}
cout<<NL;
}
```