# Tree rerooting DP

Tree rerooting DP is a way to speed up problems where the naive solution is to run a DFS from every node of a tree, which takes $O(N^2)$ time in total. The example problem we will be using is the following:

Given a tree with $N$ nodes, find the node that maximizes the sum of distances to all other nodes. The distance between two nodes is the number of edges you must cross to get from one node to the other.

Let us call this node the anti-centroid. The motivation for this problem is that the solution to [BOI 2023 minequake](https://open.kattis.com/problems/boi23.minequake) is to find the anti-centroid, and then simulate the walk from it. The naive solution looks something like the following:

```python
adj = [...]  # adjacency list: adj[u] lists the neighbours of u
n = len(adj)

def dfs(u, p):
    # Returns total distance to subtree and number of nodes
    ret = [0, 1]
    for e in adj[u]:
        if e == p:
            continue
        r = dfs(e, u)
        ret[0] += r[0] + r[1]  # The r[1] term is because the distance to
                               # all nodes in the subtree should be increased by 1
        ret[1] += r[1]
    return ret

best_node = 0
best_dist = 0
for i in range(n):
    d = dfs(i, i)[0]  # call dfs once per root and reuse the result
    if d > best_dist:
        best_dist = d
        best_node = i
print(best_node)
# best_node will now be the anti-centroid
```

This code is "obviously correct", but takes $O(N^2)$ time in total. Surely, we can do better.

## Attempt 1: DP

What happens if we do a DP on dfs? That is to say, we cache the result of dfs(u, p) for every pair (u, p). It may seem like this is a lot of state, but in fact there are only two pairs (u, p) for every edge and one pair (u, u) for every node. Since we have a linear amount of state, it may seem like this should be fast enough. Sadly, it isn't. Consider the following graph:

![Screenshot 2024-10-23 232033](https://hackmd.io/_uploads/Hyf3u1vgyx.png)

For this graph, the DP takes $O(N^2)$ time in total. For every node $u \neq 1$, we will call $dfs(1, u)$. There are $N-1$ possible pairs $(1, u)$, and each of them takes $O(N)$ time to calculate, since the loop in dfs still goes over all neighbours of node 1 even when the recursive calls are cached.
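
To make the quadratic behaviour concrete, here is a small sketch that counts how many adjacency entries the cached DP looks at. It assumes the pictured graph is a star; the centre is numbered 0 here instead of 1, and `count_edge_visits` is a name made up for this illustration.

```python
from functools import lru_cache

def count_edge_visits(n):
    # Star graph (assumed shape): node 0 is the centre, nodes 1..n-1 are leaves.
    adj = [list(range(1, n))] + [[0] for _ in range(n - 1)]
    visits = 0

    @lru_cache(maxsize=None)  # cache the result for every pair (u, p)
    def dfs(u, p):
        nonlocal visits
        total, size = 0, 1
        for e in adj[u]:
            visits += 1  # count every adjacency entry we look at
            if e == p:
                continue
            t, s = dfs(e, u)
            total += t + s
            size += s
        return total, size

    for i in range(n):
        dfs(i, i)
    return visits

for n in (10, 100, 1000):
    print(n, count_edge_visits(n))
```

On a star with $n$ nodes the count works out to $(n-1)(n+2)$, so it grows quadratically even though only a linear number of pairs is cached.
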
However, if we can handle high-degree nodes quickly, then the whole DP should be linear.

## Attempt 2: deleting edges

The problem seems to be that some nodes have too many edges. What if, every time we use an edge, we delete it? Then, our DP will actually be fast enough. However, how do we actually calculate the DP if we delete edges? We will cache the value of the resulting DFS call for every edge:

$$\text{ev[u][e]=dfs(e, u)}$$

Then, to calculate the value of some dfs(u, p), we combine the values of $\text{ev[u][e]}$ for every $e \neq p$. We say combine, because to aggregate the dfs calls, we don't simply add them; we aggregate by

```python
ret[0] += r[0] + r[1]
ret[1] += r[1]
```

(see the $O(N^2)$ code for why we do it this way). However, we can slightly rewrite the code to make it simpler:

```python
r = dfs(e, u)
r[0] += r[1]
ret[0] += r[0]
ret[1] += r[1]
```

After adding the $\text{r[0]+=r[1]}$ step, ret really is the plain sum of all the transformed $\text{dfs(e,u)}$ values (plus 1 in the node count for $u$ itself).

However, this is still slow; combining all the $\text{ev}$ values takes as much time as before. The key insight is that we can also maintain

$$\text{accum[u]=the combination of all edges calculated so far}$$

Then, our return value is simply

$$\text{accum[u]-ev[u][p]}$$

However, the implementation is still not complete from this description; how do we iterate over all edges to compute $\text{ev}$ and $\text{accum}$, while deleting elements from the list at the same time? Notice that we delete every edge except the one to our parent, and that deleting the last element of an array is cheap. So we move $p$ to the front of our neighbour list, and then repeatedly dfs into and delete the last element of the adjacency list. Note that this implementation will destroy the adjacency list, so make a copy of it if you will need it again (needed for minequake).

```python
import sys
sys.setrecursionlimit(int(2e5))

accum = [[0, 1] for i in range(n)]
ev = [{} for i in range(n)]

def dfs(u, p):
    # Returns total distance to subtree and number of nodes

    # Move parent edge to the front
    if p in adj[u]:
        adj[u].remove(p)
        adj[u].insert(0, p)

    while len(adj[u]):
        # Repeatedly remove edges from the end
        e = adj[u][-1]
        if e == p:
            break
        r = dfs(e, u)
        r[0] += r[1]  # Transform r so it's easier to work with
        # Cache dfs(e, u)
        accum[u][0] += r[0]
        accum[u][1] += r[1]
        ev[u][e] = r
        del adj[u][-1]

    # Calculate the value of this node
    ret = [accum[u][0], accum[u][1]]
    # If we don't create a new list explicitly, we operate on
    # a reference to accum[u]
    if p in ev[u]:
        # Exclude our parent
        ret[0] -= ev[u][p][0]
        ret[1] -= ev[u][p][1]
    return ret

best_node = 0
best_dist = 0
for i in range(n):
    d = dfs(i, i)[0]  # call dfs once per root and reuse the result
    if d > best_dist:
        best_dist = d
        best_node = i
```

This will in fact work fast enough. It may seem like moving the parent to the beginning is expensive. While a single move takes time linear in the length of the list, after the parent has been moved to the beginning, we remove every edge it was moved past. So we only move each parent past each other edge at most once, which amortizes to $O(N)$ in total.

If we want to make it faster, the first thing to improve would probably be to not use lists of size 2 all over the place. After that, the slowest part is very likely to be the fact that $\text{ev}$ uses dictionaries, which are slow.

# Part 3: generalizations

In our last problem, we had the combine function be

```python
def combine(a, b):
    return [a[0]+b[0], a[1]+b[1]]
```

And we excluded our parent by subtracting it away. However, we can't subtract away from every combine function. More formally, we used the fact that combine has an inverse. Many useful combine functions, such as max or min, don't have an inverse. We still require that $\text{combine}$ be associative, that is to say, the way we group the applications doesn't matter.
More formally, we require

$$\text{combine(combine(x,y),z)=combine(x, combine(y,z))}$$

Most functions we come across are at least associative. Associativity is, for example, also required in segment trees.

In order to solve it for general associative combine functions, let us think deeper about the behaviour of our DP in Attempt 2. It turns out that the while-loop will only do work in at most 2 dfs calls for every node. The first time $\text{dfs(u,p)}$ is called, it removes every edge except for $p$. Then, when it is called for some other parent, it will finally calculate the last edge. This will be useful later.

Now, how can we query everything except for $p$? How about we create an array of all $\text{ev}$ values (call it $\text{vals}$), and then compute a prefix and suffix sum (not really a sum, but a combination) for it? So for example, $\text{pref[2]=combine(combine(vals[0], vals[1]), vals[2])}$. Then, to get everything except for $p$, we compute $\text{combine(pref[}ind_p-1\text{], suf[}ind_p\text{+1])}$, where $ind_p$ is the index of $p$. If $p$ is the first or last entry, the missing side is simply left out. Of course, we also have to special-handle the case where $p$ has not been computed yet.

To make this fast, we compute pref and suf after the while-loop every time something changes. As we concluded before, this will happen at most 2 times. So in total, this all runs in linear time. The slowest part will be the use of dictionaries.

# Part 4: problems

Blessed with this knowledge, go ahead and solve some problems:

- [https://open.kattis.com/problems/boi23.minequake](https://open.kattis.com/problems/boi23.minequake)
- [https://open.kattis.com/problems/gasqutva](https://open.kattis.com/problems/gasqutva)
- [https://po2punkt0.kattis.com/problems/pokval22.ornattack](https://po2punkt0.kattis.com/problems/pokval22.ornattack)
- [https://po2punkt0.kattis.com/problems/bonsai](https://po2punkt0.kattis.com/problems/bonsai)
- [https://open.kattis.com/problems/perfectdate](https://open.kattis.com/problems/perfectdate)

You can also read more about it [here](https://codeforces.com/blog/entry/124286).
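
As a companion to Part 3, here is a minimal sketch of the prefix/suffix idea for a general associative combine. Note that it is organised differently from the edge-deleting DFS above: it uses the common two-pass formulation (one pass from the leaves up, one from the root down) and is iterative to avoid recursion limits. The names `reroot`, `identity` and `lift` are made up for this sketch; `lift` is assumed to turn the value of a subtree into what it contributes across the edge to its neighbour.

```python
def reroot(adj, combine, identity, lift):
    # For every node u, returns the combination of lift(value) over the
    # subtrees hanging off each neighbour of u. Node 0 is used as the root.
    n = len(adj)
    parent = [-1] * n
    order = []
    stack = [0]
    while stack:  # iterative DFS: parents appear before their children
        u = stack.pop()
        order.append(u)
        for v in adj[u]:
            if v != parent[u]:
                parent[v] = u
                stack.append(v)

    down = [identity] * n  # down[u]: value of u's subtree when rooted at 0
    for u in reversed(order):
        for v in adj[u]:
            if v != parent[u]:
                down[u] = combine(down[u], lift(down[v]))

    up = [identity] * n  # up[v]: value of the tree rooted at parent[v], without v's subtree
    ans = [identity] * n
    for u in order:
        # Contribution of every neighbour, in adjacency order
        vals = [lift(up[u]) if v == parent[u] else lift(down[v]) for v in adj[u]]
        k = len(vals)
        pref = [identity] * (k + 1)
        suf = [identity] * (k + 1)
        for i in range(k):
            pref[i + 1] = combine(pref[i], vals[i])
        for i in range(k - 1, -1, -1):
            suf[i] = combine(vals[i], suf[i + 1])
        ans[u] = pref[k]
        for i, v in enumerate(adj[u]):
            if v != parent[u]:
                up[v] = combine(pref[i], suf[i + 1])  # everything except v
    return ans

def add_pairs(a, b):
    return (a[0] + b[0], a[1] + b[1])

def lift_pair(x):
    # x = (distance sum, node count) below a node, not counting the node itself
    return (x[0] + x[1] + 1, x[1] + 1)

path = [[1], [0, 2], [1, 3], [2]]  # the path 0 - 1 - 2 - 3
ecc = reroot(path, max, 0, lambda x: x + 1)  # max has no inverse
sums = reroot(path, add_pairs, (0, 0), lift_pair)
print(ecc)                    # [3, 2, 2, 3]
print([s for s, _ in sums])   # [6, 4, 4, 6]
```

The first call computes, for every node, the distance to the farthest node, which could not be done by subtracting the parent away. The second call reproduces the sum of distances from the anti-centroid problem, so the anti-centroid is the index of the largest sum.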