---
title: "Codeforces Round 787 (Div. 3) F(樹上DP)"
tags: 題解, 樹, 動態規劃
---
https://codeforces.com/contest/1675/problem/F
#### 題意
給一棵樹,其中有一些特殊的點 $A=a_1, ..., a_k$,以及起點 $X$ 和終點 $X$;求從起點 $X$ 出發,經過所有特殊的點後,抵達終點 $Y$ 所需的最小成本。
#### 思路
用這題學一下「從根節點出發,走過所有特殊的點所需的成本」。
先用 $dfs()$ 遍歷每個點,紀錄在 $u$ 的 subtree 底下特殊點的數量(或者紀錄是否存在特殊節點即可?)。
再用 $dfs2()$ 去計算所需的路徑數量:對於一個點 $u$,如果它的子節點 $v$ 的 subtree 有包含特殊點,則它對 $u$ 的貢獻是 $1+dfs2(v)+1$,即從 $u$ 走到 $v$、$v$ 的答案、從 $v$ 走回 $u$。
這題需要額外考慮終點 $Y$ 的存在,由於最後要走到的是 $Y$ 而不是 $X$,對於子節點 $v$,不見得總是走過去再走回來(有可能走過去就不回來了,直接去終點那邊。)因此對於那些處在 $X$-$Y$ 路徑 $P$ 上的點,計算時先不考慮它們,最後加上從 $X$ 走到 $Y$ 的貢獻時便相當於走過它們。
要這麼做需要額外紀錄哪些點處在 $P$ 上,做 $dfs2()$ 時只考慮不在 $P$ 上的點。
#### Code
```python=1
def solve(_tc):
dbg("=== Case {} ===".format(str(_tc).rjust(2)))
@bootstrap
def dfs(node, prev):
# check if subtree rooted at u contains special nodes
# also check if node is on the path between x and y
for neigh in adj[node]:
if neigh == prev:
continue
(yield dfs(neigh, node))
special[node] |= special[neigh]
inside[node] |= inside[neigh]
yield
@bootstrap
def dfs2(node, prev):
# check number of edges to collect all special nodes in the subtree
# rooted at node. only check children nodes that are not on the path
# between x and y.
# for a child node, its contribution would be 1 + dfs2(child) + 1,
# which means (go to it) + (its answer) + (come back from it).
e = 0
for neigh in adj[node]:
if neigh == prev:
continue
if not inside[neigh] and special[neigh]:
e += 1 + (yield dfs2(neigh, node)) + 1
yield e
input()
N, K = map(int, input().split())
X, Y = map(int, input().split())
X -= 1; Y -= 1
A = list(map(int, input().split()))
adj = defaultdict(list)
for _ in range(N - 1):
u, v = map(int, input().split())
u -= 1; v -= 1
adj[u].append(v)
adj[v].append(u)
special = [False for _ in range(N)]
for a in A:
special[a - 1] = True
inside = [False for _ in range(N)]
inside[Y] = True
dfs(X, -1)
ans = 0
for v in range(N):
if inside[v]:
ans += dfs2(v, -1)
# finally, there is also a contribution of path from x to y
# which is just the number of nodes on the path and minus one.
ans += inside.count(True) - 1
print(ans)
if __name__ == "__main__":
for _tc in range(1, int(input()) + 1):
solve(vars().get("_tc", 1))
```