222. Count Complete Tree Nodes

Question

Given a complete binary tree, count the number of nodes.

Note:

Definition of a complete binary tree from Wikipedia:
In a complete binary tree every level, except possibly the last, is completely filled, and all nodes in the last level are as far left as possible. It can have between 1 and $2^h$ nodes inclusive at the last level $h$.

Example:

Input: 
    1
   / \
  2   3
 / \  /
4  5 6

Output: 6

Solution: Python

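Compare the depths of the left subtree and the right subtree, each measured along its leftmost path.
A. If they are equal, the left subtree is a full binary tree. Together with the root it contributes pow(2, leftDepth) nodes, so only the right subtree still needs counting.
B. If they are not equal, the right subtree is a full binary tree that is one level shorter. Together with the root it contributes pow(2, rightDepth) nodes, so only the left subtree still needs counting.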

    # O(logn * logn)
    class Solution:
        # @param {TreeNode} root
        # @return {integer}
        def countNodes(self, root):
            if not root:
                return 0
            leftDepth = self.getDepth(root.left)
            rightDepth = self.getDepth(root.right)
            if leftDepth == rightDepth:
                # Left subtree is full: 2^leftDepth nodes including the root.
                return pow(2, leftDepth) + self.countNodes(root.right)
            else:
                # Right subtree is full: 2^rightDepth nodes including the root.
                return pow(2, rightDepth) + self.countNodes(root.left)

        def getDepth(self, root):
            # Depth along the leftmost path, which is the true depth in a complete tree.
            if not root:
                return 0
            return 1 + self.getDepth(root.left)

Solution: Official

Approach 1: Naive solution
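
A minimal sketch of the naive approach: visit every node and count it. It assumes the usual LeetCode-style TreeNode, redefined here so the snippet runs on its own. The name count_nodes_naive is illustrative and does not come from the official solution.

```python
class TreeNode:
    """Minimal node type (assumed to mirror the usual LeetCode definition)."""

    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def count_nodes_naive(root):
    """Count nodes by visiting each one exactly once: O(N) time."""
    if root is None:
        return 0
    return 1 + count_nodes_naive(root.left) + count_nodes_naive(root.right)


# The six-node tree from the example in the question.
example = TreeNode(
    1,
    TreeNode(2, TreeNode(4), TreeNode(5)),
    TreeNode(3, TreeNode(6)),
)
print(count_nodes_naive(example))  # 6
```

This works for any binary tree, which is exactly why it cannot do better than linear time.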

Intuition

Approach 1 visits every node, which takes $\mathcal{O}(N)$ time, and doesn't profit from the fact that the tree is a complete one.

In a complete binary tree every level, except possibly the last, is completely filled, and all nodes in the last level are as far left as possible.

That means that a complete tree has $2^k$ nodes in the $k$th level (counting the root as level 0) if the $k$th level is not the last one. The last level may not be completely filled, and hence the number of nodes in the last level can vary from 1 to $2^d$, where $d$ is the tree depth.

Now one could compute the number of nodes in all levels but the last one: $\sum_{k=0}^{d-1} 2^k = 2^d - 1$. That reduces the problem to a simple check of how many nodes the tree has in the last level.
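
As a quick check of this formula against the six-node example tree from the question (depth $d = 2$, with the root at level 0, and three nodes in the last level):

$$
\sum_{k=0}^{1} 2^k = 1 + 2 = 3 = 2^2 - 1, \qquad \underbrace{3}_{\text{levels } 0,1} + \underbrace{3}_{\text{last level}} = 6 .
$$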

Now there are two questions:

  1. How many nodes in the last level have to be checked?
  2. What is the best time performance for such a check?

Let's start with the first question. It's a complete tree, and hence all nodes in the last level are as far left as possible. That means that instead of checking the existence of all $2^d$ possible leaves, one could use binary search and check only $\log_2(2^d) = d$ leaves.

Let's move to the second question, and enumerate the potential nodes in the last level from 0 to $2^d - 1$. How do we check whether the node with index idx exists? Let's use binary search again to reconstruct the sequence of moves from the root to node idx. For example, take idx = 4 in a tree of depth $d = 3$. idx is in the second half of nodes 0,1,2,3,4,5,6,7 and hence the first move is to the right. Then idx is in the first half of nodes 4,5,6,7 and hence the second move is to the left. Finally, idx is in the first half of nodes 4,5 and hence the last move is to the left. The time complexity of one check is $\mathcal{O}(d)$.
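
The move reconstruction can be sketched without a tree at all. The helper below, moves_to_leaf, is a hypothetical name introduced for illustration. It only returns the sequence of left and right moves for a given last-level index, assuming indices run from 0 to 2^d - 1 as described above.

```python
def moves_to_leaf(idx, d):
    """Return the root-to-leaf moves ('L' or 'R') for last-level index idx.

    Assumes 0 <= idx <= 2**d - 1, where d is the tree depth.
    """
    left, right = 0, 2 ** d - 1
    moves = []
    for _ in range(d):
        pivot = left + (right - left) // 2
        if idx <= pivot:
            # idx lies in the first half of the current range: go left.
            moves.append("L")
            right = pivot
        else:
            # idx lies in the second half of the current range: go right.
            moves.append("R")
            left = pivot + 1
    return moves


print(moves_to_leaf(4, 3))  # ['R', 'L', 'L']
```

Each call halves the index range d times, which is where the per-check cost of d steps comes from.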

1 and 2 together result in $\mathcal{O}(d)$ checks, each check at a price of $\mathcal{O}(d)$. That means that the overall time complexity would be $\mathcal{O}(d^2)$.

Algorithm

  • Return 0 if the tree is empty.
  • Compute the tree depth d.
  • Return 1 if d == 0.
  • The number of nodes in all levels but the last one is $2^d - 1$. The number of nodes in the last level could vary from 1 to $2^d$. Enumerate potential nodes from 0 to $2^d - 1$ and perform the binary search by the node index to check how many nodes are in the last level. Use the function exists(idx, d, root) to check if the node with index idx exists.
  • Use binary search to implement exists(idx, d, root) as well.
  • Return $2^d - 1$ + the number of nodes in the last level.

Implementation
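
The original implementation is missing from these notes, so what follows is one possible sketch of the steps listed above, not the official code. exists(idx, d, root) follows the signature named in the algorithm. TreeNode, compute_depth, count_nodes, and build_complete_tree are assumed or hypothetical names added so the snippet runs on its own.

```python
class TreeNode:
    """Minimal node type (assumed to mirror the usual LeetCode definition)."""

    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def compute_depth(node):
    """Depth of a non-empty complete tree, with the root at depth 0."""
    d = 0
    while node.left:
        node = node.left
        d += 1
    return d


def exists(idx, d, node):
    """Check whether last-level node idx (0 .. 2**d - 1) exists, in O(d) steps."""
    left, right = 0, 2 ** d - 1
    for _ in range(d):
        pivot = left + (right - left) // 2
        if idx <= pivot:
            node = node.left
            right = pivot
        else:
            node = node.right
            left = pivot + 1
    return node is not None


def count_nodes(root):
    if root is None:
        return 0
    d = compute_depth(root)
    if d == 0:
        return 1
    # Last-level index 0 always exists, so search 1 .. 2**d - 1
    # for the first index that is missing.
    left, right = 1, 2 ** d - 1
    while left <= right:
        pivot = left + (right - left) // 2
        if exists(pivot, d, root):
            left = pivot + 1
        else:
            right = pivot - 1
    # Levels 0 .. d-1 hold 2**d - 1 nodes; the last level holds `left` nodes.
    return (2 ** d - 1) + left


def build_complete_tree(n):
    """Hypothetical helper: build a complete tree with n nodes in level order."""
    nodes = [TreeNode(i + 1) for i in range(n)]
    for i in range(n):
        if 2 * i + 1 < n:
            nodes[i].left = nodes[2 * i + 1]
        if 2 * i + 2 < n:
            nodes[i].right = nodes[2 * i + 2]
    return nodes[0] if n else None


print(count_nodes(build_complete_tree(6)))  # 6
```

The outer loop runs about d times and each exists call takes d steps, matching the quadratic bound in d derived above.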

Complexity Analysis

  • Time complexity : $\mathcal{O}(d^2) = \mathcal{O}(\log^2 N)$, where $d$ is the tree depth.
  • Space complexity : $\mathcal{O}(1)$.

Solution: Python O(n)

    # Definition for a binary tree node.
    # class TreeNode(object):
    #     def __init__(self, x):
    #         self.val = x
    #         self.left = None
    #         self.right = None
    #
    #  n   h   last_level   c
    #  1   1       1        0
    #  2   2       1       -1
    #  3   2       2        1
    #  4   3       1       -2
    #  5   3       2        0
    #  6   3       3        0
    #  7   3       4        2
    #  8   4       1       -3
    #  9   4       2       -1
    # 10   4       3       -1
    # 11   4       4        1
    # 12   4       5       -1
    # 13   4       6        1
    # 14   4       7        1
    # 15   4       8        3
    #
    # height = log_2(num_of_nodes)
    # number_of_nodes = 2 ** (height) - 1
    # number_of_nodes_last_level = 2 ** (height - 1)
    # last_level = getRightmostNode()
    #
    # Index of each node within its level (1-based):
    #                1
    #            /       \
    #        1               2
    #      /   \           /   \
    #    1       2       3       4
    #   / \     / \     / \     / \
    #  1   2   3   4   5   6   7   8

    class Solution(object):
        def countNodes(self, root):
            """
            :type root: TreeNode
            :rtype: int
            """
            self.height = self.getHeight(root)  # O(log n)
            rightmost_index = self.getRightmostIndex(root) or 0  # O(log^2 n)
            return self.getMaximumCount(self.height - 1) + rightmost_index

        def countNodesLinear(self, root):
            """
            :type root: TreeNode
            :rtype: int
            """
            self.height = self.getHeight(root)  # O(log n)
            rightmost_index = self.getRightmostIndexLinear(root, 1, 1) or 0  # O(n)
            return self.getMaximumCount(self.height - 1) + rightmost_index

        def getMaximumCount(self, height):
            # Number of nodes in a full tree with `height` levels.
            if height < 0:
                return 0
            return 2 ** height - 1

        def getHeight(self, node):
            # Number of levels, measured along the leftmost path.
            if node is None:
                return 0
            height = self.getHeight(node.left)
            return height + 1

        def getRightmostIndexLinear(self, node, depth, index):
            # Search right-to-left for the last node on the deepest level.
            if node is None:
                return None
            if depth == self.height:
                return index
            return (
                self.getRightmostIndexLinear(node.right, depth + 1, index * 2)
                or self.getRightmostIndexLinear(node.left, depth + 1, index * 2 - 1)
            )

        def getRightmostIndex(self, node, index=1):
            # Follow the subtree that contains the last node of the deepest level.
            if node is None:
                return None
            if node.left is None:
                # A leaf reached this way always lies on the deepest level.
                return index
            l_height = self.getHeight(node.left)
            r_height = self.getHeight(node.right)
            return (
                self.getRightmostIndex(node.right, index * 2)
                if l_height == r_height
                else self.getRightmostIndex(node.left, index * 2 - 1)
            )