Try   HackMD

LeetCode 2458. Height of Binary Tree After Subtree Removal Queries

https://leetcode.com/problems/height-of-binary-tree-after-subtree-removal-queries/description/

題目大意

題目會給好幾個值裝在 queries 裡,我們每次對原始的二元樹移除 queries 裡的某個值
把移除後的結果裝進陣列 ans

題目保證不會移除二元樹的 root

思考

如果我們要計算移除某節點後的樹高,那我們最好的方式就是事先紀錄各個子樹的樹高情況,這樣求樹高會更有效率

我們先紀錄好原始二元樹,再來紀錄移除某點後的二元樹

最後 treeQueries() 我們只需要對答案就好

值得提醒的是, dfs() 中我們要傳入的 height 是要考慮自己兄弟的子樹,也就是 node->left 要考慮的是 node->rightnode->right 要考慮的是 node->left

class Solution
{
public:
    vector<int> treeQueries(TreeNode *root, vector<int> &queries)
    {
        memset(heightAfterRemove, 0, sizeof(heightAfterRemove));
        memset(heightInit, 0, sizeof(heightInit));
        dfs(root, 0, 0);
        vector<int> ans;

        for (int query : queries)
        {
            ans.push_back(heightAfterRemove[query]);
        }

        return ans;
    }

private:
    int heightInit[100001];
    int heightAfterRemove[100001];

    void dfs(TreeNode *node, int level, int height)
    {
        if (!node)
            return;
        heightAfterRemove[node->val] = height;
        dfs(node->left, level + 1, max(height, level + getHeight(node->right)));
        dfs(node->right, level + 1, max(height, level + getHeight(node->left)));
    }

    int getHeight(TreeNode *node)
    {
        if (!node)
            return 0;
        if (heightInit[node->val])
            return heightInit[node->val];
        return heightInit[node->val] = max(getHeight(node->left), getHeight(node->right)) + 1;
    }
};

再附上 Go 的參考解答:

func treeQueries(root *TreeNode, queries []int) []int {
	const maxNodes = 100001
	var heightInit [maxNodes]int
	var heightAfterRemove [maxNodes]int

	dfs(root, 0, 0, &heightInit, &heightAfterRemove)

	ans := make([]int, len(queries))
	for i, query := range queries {
		ans[i] = heightAfterRemove[query]
	}
	return ans
}

func dfs(node *TreeNode, level, height int, heightInit, heightAfterRemove *[100001]int) {
	if node == nil {
		return
	}
	heightAfterRemove[node.Val] = height

	dfs(node.Left, level+1, int(math.Max(float64(height), float64(level+getHeight(node.Right, heightInit)))), heightInit, heightAfterRemove)
	dfs(node.Right, level+1, int(math.Max(float64(height), float64(level+getHeight(node.Left, heightInit)))), heightInit, heightAfterRemove)
}

func getHeight(node *TreeNode, heightInit *[100001]int) int {
	if node == nil {
		return 0
	}
	if heightInit[node.Val] != 0 {
		return heightInit[node.Val]
	}

	heightInit[node.Val] = int(math.Max(float64(getHeight(node.Left, heightInit)), float64(getHeight(node.Right, heightInit)))) + 1
	return heightInit[node.Val]
}