--- title: "Codeforces Round 787 (Div. 3) F(樹上DP)" tags: 題解, 樹, 動態規劃 --- https://codeforces.com/contest/1675/problem/F #### 題意 給一棵樹,其中有一些特殊的點 $A=a_1, ..., a_k$,以及起點 $X$ 和終點 $X$;求從起點 $X$ 出發,經過所有特殊的點後,抵達終點 $Y$ 所需的最小成本。 #### 思路 用這題學一下「從根節點出發,走過所有特殊的點所需的成本」。 先用 $dfs()$ 遍歷每個點,紀錄在 $u$ 的 subtree 底下特殊點的數量(或者紀錄是否存在特殊節點即可?)。 再用 $dfs2()$ 去計算所需的路徑數量:對於一個點 $u$,如果它的子節點 $v$ 的 subtree 有包含特殊點,則它對 $u$ 的貢獻是 $1+dfs2(v)+1$,即從 $u$ 走到 $v$、$v$ 的答案、從 $v$ 走回 $u$。 這題需要額外考慮終點 $Y$ 的存在,由於最後要走到的是 $Y$ 而不是 $X$,對於子節點 $v$,不見得總是走過去再走回來(有可能走過去就不回來了,直接去終點那邊。)因此對於那些處在 $X$-$Y$ 路徑 $P$ 上的點,計算時先不考慮它們,最後加上從 $X$ 走到 $Y$ 的貢獻時便相當於走過它們。 要這麼做需要額外紀錄哪些點處在 $P$ 上,做 $dfs2()$ 時只考慮不在 $P$ 上的點。 #### Code ```python=1 def solve(_tc): dbg("=== Case {} ===".format(str(_tc).rjust(2))) @bootstrap def dfs(node, prev): # check if subtree rooted at u contains special nodes # also check if node is on the path between x and y for neigh in adj[node]: if neigh == prev: continue (yield dfs(neigh, node)) special[node] |= special[neigh] inside[node] |= inside[neigh] yield @bootstrap def dfs2(node, prev): # check number of edges to collect all special nodes in the subtree # rooted at node. only check children nodes that are not on the path # between x and y. # for a child node, its contribution would be 1 + dfs2(child) + 1, # which means (go to it) + (its answer) + (come back from it). e = 0 for neigh in adj[node]: if neigh == prev: continue if not inside[neigh] and special[neigh]: e += 1 + (yield dfs2(neigh, node)) + 1 yield e input() N, K = map(int, input().split()) X, Y = map(int, input().split()) X -= 1; Y -= 1 A = list(map(int, input().split())) adj = defaultdict(list) for _ in range(N - 1): u, v = map(int, input().split()) u -= 1; v -= 1 adj[u].append(v) adj[v].append(u) special = [False for _ in range(N)] for a in A: special[a - 1] = True inside = [False for _ in range(N)] inside[Y] = True dfs(X, -1) ans = 0 for v in range(N): if inside[v]: ans += dfs2(v, -1) # finally, there is also a contribution of path from x to y # which is just the number of nodes on the path and minus one. ans += inside.count(True) - 1 print(ans) if __name__ == "__main__": for _tc in range(1, int(input()) + 1): solve(vars().get("_tc", 1)) ```