2265 Count Nodes Equal to Average of Subtree
Published:
Problem
Solution
Intuition
The problem asks us to count the nodes whose value equals the average of their inclusive subtree, meaning the node itself plus all of its descendants. I initially thought it was asking which nodes equal the average of the whole tree, which seemed like an easy two-pass DFS solution; however, the average is taken over each node's own subtree, which adds some complexity. Each node has to pass some information up to its parent, and figuring out what that information is turns out to be the only hard part!
Approach
The approach is quite simple once you understand what information must be passed up to each parent node. To calculate its average, a node needs the sum of the values in its left and right subtrees as well as the number of nodes in each. The problem takes the average rounded down to the nearest integer, so integer division is used. In essence, we are checking the following condition for each node:
$node.val \stackrel{?}{=} \left\lfloor \dfrac{\sum node.leftSubtree + \sum node.rightSubtree + node.val}{node.leftNodes + node.rightNodes + 1} \right\rfloor$
Each call returns its subtree's sum and node count to the parent as a tuple, so both values accumulate up the tree as the recursive calls finish.
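To make the tuple passing concrete, here is a minimal sketch that records the (sum, count) pair each node hands to its parent on a small sample tree. The nested-tuple tree encoding `(val, left, right)` and the helper name `subtree_stats` are assumptions made for illustration only; they are not part of the LeetCode template.

```python
def subtree_stats(node, rows):
    """Return (sum, count) for the subtree rooted at node.

    node is a hypothetical (val, left, right) tuple, or None for a missing child.
    Each visited node appends (val, subtree_sum, subtree_count, matches) to rows.
    """
    if node is None:
        return (0, 0)
    val, left, right = node
    left_sum, left_nodes = subtree_stats(left, rows)
    right_sum, right_nodes = subtree_stats(right, rows)
    total = left_sum + right_sum + val
    nodes = left_nodes + right_nodes + 1
    rows.append((val, total, nodes, val == total // nodes))
    return (total, nodes)


# Sample tree: 4 has children 8 and 5; 8 has children 0 and 1; 5 has a right child 6.
tree = (4, (8, (0, None, None), (1, None, None)), (5, None, (6, None, None)))
rows = []
subtree_stats(tree, rows)
matches = sum(1 for row in rows if row[3])

for val, total, nodes, ok in rows:
    print(f"node {val}: sum={total}, count={nodes}, match={ok}")
print("matching nodes:", matches)  # 5
```

Only node 8 fails the check here: its subtree sums to 9 over 3 nodes, which gives an average of 3.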
Complexity
Time complexity: $O(n)$, as every node is visited exactly once and a constant number of operations is done on each node.
Space complexity: $O(n)$ in the worst case. Each recursive call keeps a constant number of variables for the running sum and node count, so the space used is proportional to the depth of the recursion stack, which is the height $h$ of the tree. A completely skewed tree has $h = n$, while a balanced tree only needs $O(\log n)$.
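The recursive solution uses one stack frame per level of the tree, so a very deep tree could run into Python's recursion limit. As a rough sketch of one way around that (an alternative, not the approach used in this post), the same post-order accumulation can be driven by an explicit stack. The function name `average_of_subtree_iterative` and the `stats` dictionary are hypothetical choices for this example.

```python
from typing import Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def average_of_subtree_iterative(root: Optional[TreeNode]) -> int:
    if root is None:
        return 0
    count = 0
    stats = {}  # node -> (subtree sum, subtree node count), removed once the parent reads it
    stack = [(root, False)]
    while stack:
        node, children_done = stack.pop()
        if not children_done:
            # Revisit this node after both of its children have been processed.
            stack.append((node, True))
            if node.right:
                stack.append((node.right, False))
            if node.left:
                stack.append((node.left, False))
            continue
        # A missing child (None) has no entry, so it contributes (0, 0).
        left_sum, left_nodes = stats.pop(node.left, (0, 0))
        right_sum, right_nodes = stats.pop(node.right, (0, 0))
        total = left_sum + right_sum + node.val
        nodes = left_nodes + right_nodes + 1
        if node.val == total // nodes:
            count += 1
        stats[node] = (total, nodes)
    return count


# A left-skewed chain of 5000 equal values: every node equals its subtree average.
root = TreeNode(7)
cur = root
for _ in range(4999):
    cur.left = TreeNode(7)
    cur = cur.left
deep_count = average_of_subtree_iterative(root)
print(deep_count)  # 5000
```

The worst-case space is still $O(n)$, since the explicit stack grows with the depth of the tree, but it lives on the heap rather than the call stack.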
Code
from typing import Optional
# Definition for a binary tree node.
# class TreeNode:
#     def __init__(self, val=0, left=None, right=None):
#         self.val = val
#         self.left = left
#         self.right = right
class Solution:
    def averageOfSubtree(self, root: Optional[TreeNode]) -> int:
        # Total number of nodes whose value equals the average of their inclusive subtree
        count = 0
        def dfs(root):
            # count lives in the enclosing function's scope, so declare it nonlocal
            nonlocal count
            # each node tracks the sum and number of nodes of its left and right subtrees
            leftSum, leftNodes, rightSum, rightNodes = 0, 0, 0, 0
            if root.left:
                leftSum, leftNodes = dfs(root.left)
            if root.right:
                rightSum, rightNodes = dfs(root.right)
            # check if this node matches the criteria (average is rounded down)
            if root.val == (leftSum + rightSum + root.val) // (leftNodes + rightNodes + 1):
                count += 1
            # return the total sum and node count of this subtree to the parent
            return (leftSum + rightSum + root.val, leftNodes + rightNodes + 1)
        dfs(root)
        return count