Skip to content
Related Articles
Open in App
Not now

Related Articles

Sum of nodes within K distance from target

Improve Article
Save Article
  • Difficulty Level : Hard
  • Last Updated : 25 Jan, 2023
Improve Article
Save Article

Given a binary tree, the value of a target node, and a positive integer K, the task is to find the sum of all nodes within distance K of the target node (including the value of the target node itself).

Examples:

Input: target = 9, K = 1,  
Binary Tree =
                  1
                /   \
               2     9
              /     / \
             4     5   7
            / \       / \
           8   19    20  11
          /   /  \
         30  40   50
Output: 22
Explanation: The nodes within distance 1 of 9 are 9, 5, 7 and 1, so the sum is 9 + 5 + 7 + 1 = 22.

Input: target = 40,  K = 2,  
Binary Tree =
                  1
                /   \
               2     9
              /     / \
             4     5   7
            / \       / \
           8   19    20  11
          /   /  \
         30  40   50
Output: 113
Explanation: The nodes within distance 2 of 40 are 40, 19, 50 and 4, so the sum is
40 + 19 + 50 + 4 = 113.

 

Approach: This problem can be solved using hashing and Depth-First-Search based on the following idea:

Use a data structure to store the parent of each node. Now utilise that data structure to perform a DFS traversal from target and calculate the sum of all the nodes within K distance from that node.

Follow the steps mentioned below to implement the approach:

  • Create a hash table (say par) to store the parent of each node.
  • Perform a DFS and store the parent of each node.
  • Now find the target in the tree.
  • Create a hash table to mark the visited nodes.
  • Start a DFS from target:
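    • If the current distance does not exceed K and the node has not been visited, add its value to the final sum and mark it as visited; otherwise stop exploring from this node.
    • Continue the DFS to the node's neighbours (its parent, found through par, and its left and right children), increasing the distance by 1.
    • When the recursion for a node finishes, the sum includes every unvisited node reachable through its neighbours within the remaining distance.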
  • Return the sum of all the nodes within K distance from the target.

Below is the implementation of the above approach:

C++




// C++ code to implement above approach
 
#include <bits/stdc++.h>
using namespace std;
 
// Structure of a tree node
struct Node {
    int data;
    Node* left;
    Node* right;
    Node(int val) : data(val), left(nullptr), right(nullptr) {}
};

// Record the parent of every node reachable from `root` in `par`
// (child -> parent), visiting nodes in preorder. The entry for `root`
// itself is not written here; the caller seeds it.
void dfs(Node* root,
         std::unordered_map<Node*, Node*>& par)
{
    if (root == nullptr)
        return;
    if (root->left != nullptr)
        par[root->left] = root;
    if (root->right != nullptr)
        par[root->right] = root;
    dfs(root->left, par);
    dfs(root->right, par);
}

// Add to `sum` the value of every node within distance `k` of `root`,
// where `h` is the distance already travelled from the starting node.
// The tree is treated as an undirected graph: a node's neighbours are
// its two children and its parent (looked up in `par`). `vis` marks
// nodes that were already counted so no node is added twice.
void dfs3(Node* root, int h, int& sum, int k,
          std::unordered_map<Node*, int>& vis,
          std::unordered_map<Node*, Node*>& par)
{
    // Stop once we are farther than k from the start. `>` (rather than
    // `== k + 1`) also makes a negative k contribute nothing instead of
    // walking the whole tree.
    if (h > k)
        return;
    if (root == nullptr)
        return;
    // find() instead of operator[] so a lookup never inserts an entry.
    auto seen = vis.find(root);
    if (seen != vis.end() && seen->second)
        return;
    sum += root->data;
    vis[root] = 1;
    dfs3(root->left, h + 1, sum, k, vis, par);
    dfs3(root->right, h + 1, sum, k, vis, par);
    // A node missing from `par` is treated as having no parent.
    auto up = par.find(root);
    Node* parent = (up == par.end()) ? nullptr : up->second;
    dfs3(parent, h + 1, sum, k, vis, par);
}

// Return the first node (in preorder) whose value equals `target`,
// or nullptr if no such node exists. Every path returns a value.
Node* dfs2(Node* root, int target)
{
    if (root == nullptr)
        return nullptr;
    if (root->data == target)
        return root;
    if (Node* found = dfs2(root->left, target))
        return found;
    // Right subtree's result is nullptr when the target is absent there too.
    return dfs2(root->right, target);
}

// Return the sum of all node values within distance k (inclusive) of the
// first node whose value equals `target`. Returns 0 when the tree is empty,
// `target` is not present, or k is negative.
int sum_at_distK(Node* root, int target,
                 int k)
{
    // Parent links let the search move upward as well as downward.
    std::unordered_map<Node*, Node*> par;
    par[root] = nullptr;  // the root has no parent
    dfs(root, par);

    // Locate the start node (nullptr if absent; dfs3 then adds nothing).
    Node* node = dfs2(root, target);

    // Nodes already counted.
    std::unordered_map<Node*, int> vis;

    int sum = 0;
    dfs3(node, 0, sum, k, vis, par);
    return sum;
}
 
// Driver Code
int main()
{
    // Taking Input
    Node* root = new Node(1);
    root->left = new Node(2);
    root->right = new Node(9);
    root->left->left = new Node(4);
    root->right->left = new Node(5);
    root->right->right = new Node(7);
    root->left->left->left = new Node(8);
    root->left->left->right = new Node(19);
    root->right->right->left = new Node(20);
    root->right->right->right
        = new Node(11);
    root->left->left->left->left
        = new Node(30);
    root->left->left->right->left
        = new Node(40);
    root->left->left->right->right
        = new Node(50);
 
    int target = 9, K = 1;
 
    // Function call
    cout << sum_at_distK(root, target, K);
    return 0;
}


Java




// Java code to implement above approach
import java.util.*;
 
public class Main {
    // Structure of a tree node
    static class Node {
        int data;
        Node left;
        Node right;
        Node(int val)
        {
            this.data = val;
            this.left = null;
            this.right = null;
        }
    }

    // Record the parent of every node reachable from root in par
    // (child -> parent), visiting nodes in preorder. The caller seeds
    // the root's own entry.
    static void dfs(Node root,
            HashMap <Node, Node> par)
    {
        if (root == null)
            return;
        if (root.left != null)
            par.put(root.left, root);
        if (root.right != null)
            par.put(root.right, root);
        dfs(root.left, par);
        dfs(root.right, par);
    }

    // Running total filled in by dfs3; reset by sum_at_distK on every call.
    static int sum;

    // Add to sum every node within distance k of root, where h is the
    // distance already travelled from the start node. Neighbours are the
    // two children and the parent (from par); vis prevents a node from
    // being counted twice.
    static void dfs3(Node root, int h, int k,
            HashMap <Node, Integer> vis,
            HashMap <Node, Node> par)
    {
        // '>' rather than '== k + 1' so a negative k adds nothing instead
        // of walking the whole tree.
        if (h > k)
            return;
        if (root == null || vis.containsKey(root))
            return;
        sum += root.data;
        vis.put(root, 1);
        dfs3(root.left, h + 1, k, vis, par);
        dfs3(root.right, h + 1, k, vis, par);
        // par.get returns null for the root (or an unmapped node),
        // which ends this branch.
        dfs3(par.get(root), h + 1, k, vis, par);
    }

    // Return the first node (in preorder) whose data equals target,
    // or null if no node matches.
    static Node dfs2(Node root, int target)
    {
        if (root == null)
            return null;
        if (root.data == target)
            return root;
        Node inLeft = dfs2(root.left, target);
        if (inLeft != null)
            return inLeft;
        return dfs2(root.right, target);
    }

    // Return the sum of all node values within distance k (inclusive) of
    // the first node holding target; 0 if target is absent or k < 0.
    static int sum_at_distK(Node root, int target,
                 int k)
    {
        // Parent links let the search move upward as well as downward.
        HashMap <Node, Node> par = new HashMap<>();

        // The root has no parent.
        par.put(root, null);
        dfs(root, par);

        // Find the target node in the tree (null if absent).
        Node node = dfs2(root, target);

        // Nodes already counted.
        HashMap <Node, Integer> vis = new HashMap<>();

        sum = 0;
        dfs3(node, 0, k, vis, par);
        return sum;
    }

    public static void main(String args[]) {
        // Taking Input
        Node root = new Node(1);
        root.left = new Node(2);
        root.right = new Node(9);
        root.left.left = new Node(4);
        root.right.left = new Node(5);
        root.right.right = new Node(7);
        root.left.left.left = new Node(8);
        root.left.left.right = new Node(19);
        root.right.right.left = new Node(20);
        root.right.right.right
            = new Node(11);
        root.left.left.left.left
            = new Node(30);
        root.left.left.right.left
            = new Node(40);
        root.left.left.right.right
            = new Node(50);

        int target = 9, K = 1;

        // Function call
        System.out.println( sum_at_distK(root, target, K) );
    }
}
 
// This code has been contributed by Sachin Sahara (sachin801)


Python3




# python program to implement above approach
# binary tree node
class Node:
    """A binary tree node: an integer payload plus left/right child links."""

    def __init__(self, val):
        # Children start empty; callers attach them after construction.
        self.left = None
        self.right = None
        self.data = val
 
 
# function for recording the parent node
# of every node in the tree
def dfs(root, par):
    """Record each node's parent in ``par`` (child -> parent), in preorder.

    The entry for ``root`` itself is not written here; callers seed it.
    An explicit stack replaces recursion; right children are pushed before
    left ones so nodes are processed in the same preorder sequence.
    """
    pending = [root]
    while pending:
        node = pending.pop()
        if node is None:
            continue
        for child in (node.left, node.right):
            if child is not None:
                par[child] = node
        pending.append(node.right)
        pending.append(node.left)
 
 
# running total written by dfs3 (module-level so the driver can print it)
summ = 0


def dfs3(root, h, k, vis, par):
    """Add to the global ``summ`` every node within distance ``k`` of ``root``.

    The tree is explored as an undirected graph: a node's neighbours are its
    two children and its parent as recorded in ``par``.

    Args:
        root: node to visit; ``None`` (or a falsy placeholder such as ``0``)
            means "no node".
        h: distance already travelled from the starting node.
        k: maximum distance to include.
        vis: dict marking visited nodes with ``1`` so none is counted twice.
        par: dict mapping each node to its parent.
    """
    global summ
    # '>' rather than '== k + 1' so a negative k adds nothing instead of
    # walking the whole reachable tree.
    if h > k:
        return
    # `not root` also rejects a 0 stored as the root's "parent", which would
    # otherwise fail on `0.data`. Node instances define no __bool__/__len__,
    # so they are always truthy.
    if not root:
        return
    if vis.get(root) == 1:
        return
    summ += root.data
    vis[root] = 1
    dfs3(root.left, h + 1, k, vis, par)
    dfs3(root.right, h + 1, k, vis, par)
    # .get: a node missing from par is treated as having no parent.
    dfs3(par.get(root), h + 1, k, vis, par)
 
 
# function for locating
# the target node in the tree
def dfs2(root, target):
    """Return the first node, in preorder, whose ``data`` equals ``target``.

    Returns None when the tree is empty or no node matches. Uses an explicit
    stack; right children are pushed first so the left subtree is searched
    before the right one.
    """
    pending = [root]
    while pending:
        node = pending.pop()
        if node is None:
            continue
        if node.data == target:
            return node
        pending.append(node.right)
        pending.append(node.left)
    return None
         
 
# function to find the sum of nodes within distance k
def sum_at_distK(root, target, k):
    """Return the sum of all node values within distance ``k`` of ``target``.

    The search starts from the first node (in preorder) whose ``data``
    equals ``target``. The result is also left in the module-level ``summ``
    so existing code that prints ``summ`` keeps working. Returns 0 when
    ``target`` is not in the tree.
    """
    global summ

    # Parent links let the search move up as well as down. The root has no
    # parent, so store None (not 0): dfs3 stops on a None node.
    par = {root: None}
    dfs(root, par)

    # find the target node in the tree (None if absent)
    node = dfs2(root, target)

    # visited markers so no node is counted twice
    vis = {}

    # Reset the accumulator so repeated calls do not add to earlier results.
    summ = 0
    dfs3(node, 0, k, vis, par)
    return summ
 
 
# driver program: build the sample tree from the examples and query it
root = Node(1)
root.left, root.right = Node(2), Node(9)
root.left.left = Node(4)
root.right.left, root.right.right = Node(5), Node(7)

four = root.left.left
four.left, four.right = Node(8), Node(19)
four.left.left = Node(30)
four.right.left, four.right.right = Node(40), Node(50)

root.right.right.left, root.right.right.right = Node(20), Node(11)

target, K = 9, 1

# function call; the result is accumulated in the module-level summ
sum_at_distK(root, target, K)
print(summ)

# this code is contributed by Yash Agarwal(yashagarwal2852002)


C#




// C# code to implement above approach
 
using System;
using System.Collections.Generic;
 
public class GFG {

  // Structure of a tree node
  class Node {
    public int data;
    public Node left;
    public Node right;
    public Node(int val)
    {
      this.data = val;
      this.left = null;
      this.right = null;
    }
  }

  // Record the parent of every node reachable from root in par
  // (child -> parent), visiting nodes in preorder. The caller adds
  // the root's own entry.
  static void dfs(Node root, Dictionary<Node, Node> par)
  {
    if (root == null)
      return;
    if (root.left != null)
      par.Add(root.left, root);
    if (root.right != null)
      par.Add(root.right, root);
    dfs(root.left, par);
    dfs(root.right, par);
  }

  // Running total filled in by dfs3; reset by sum_at_distK on every call.
  static int sum;

  // Add to sum every node within distance k of root, where h is the
  // distance already travelled from the start node. Neighbours are both
  // children and the parent (from par); vis prevents a node from being
  // counted twice.
  static void dfs3(Node root, int h, int k,
                   Dictionary<Node, int> vis,
                   Dictionary<Node, Node> par)
  {
    // '>' rather than '== k + 1' so a negative k adds nothing instead of
    // walking the whole tree.
    if (h > k)
      return;
    if (root == null || vis.ContainsKey(root))
      return;
    sum += root.data;
    vis.Add(root, 1);
    dfs3(root.left, h + 1, k, vis, par);
    dfs3(root.right, h + 1, k, vis, par);
    // TryGetValue: an unmapped node is treated as having no parent
    // (parent stays null) instead of throwing KeyNotFoundException.
    Node parent;
    par.TryGetValue(root, out parent);
    dfs3(parent, h + 1, k, vis, par);
  }

  // Return the first node (in preorder) whose data equals target,
  // or null if no node matches.
  static Node dfs2(Node root, int target)
  {
    if (root == null)
      return null;
    if (root.data == target)
      return root;
    Node node1 = dfs2(root.left, target);
    if (node1 != null)
      return node1;
    return dfs2(root.right, target);
  }

  // Return the sum of all node values within distance k (inclusive) of
  // the first node holding target; 0 for an empty tree, a missing
  // target, or a negative k.
  static int sum_at_distK(Node root, int target, int k)
  {
    // Nothing to sum in an empty tree; this also avoids
    // Dictionary.Add(null, ...) throwing ArgumentNullException.
    if (root == null)
      return 0;

    // Parent links let the search move upward as well as downward.
    Dictionary<Node, Node> par
      = new Dictionary<Node, Node>();

    // The root has no parent.
    par.Add(root, null);
    dfs(root, par);

    // Find the target node in the tree (null if absent).
    Node node = dfs2(root, target);

    // Nodes already counted.
    Dictionary<Node, int> vis
      = new Dictionary<Node, int>();

    sum = 0;
    dfs3(node, 0, k, vis, par);
    return sum;
  }

  static public void Main()
  {

    // Code
    Node root = new Node(1);
    root.left = new Node(2);
    root.right = new Node(9);
    root.left.left = new Node(4);
    root.right.left = new Node(5);
    root.right.right = new Node(7);
    root.left.left.left = new Node(8);
    root.left.left.right = new Node(19);
    root.right.right.left = new Node(20);
    root.right.right.right = new Node(11);
    root.left.left.left.left = new Node(30);
    root.left.left.right.left = new Node(40);
    root.left.left.right.right = new Node(50);

    int target = 9, K = 1;

    // Function call
    Console.Write(sum_at_distK(root, target, K));
  }
}
 
// This code is contributed by lokesh(lokeshmvs21).


Javascript




       // JavaScript code for the above approach
       // Structure of a tree node
       class Node {
           constructor(val) {
               this.data = val;
               this.left = null;
               this.right = null;
           }
       }
 
       // Function for marking the parent node
       // for all the nodes using DFS
       function dfs(root, par) {
           if (root === null) return;
           if (root.left !== null) par.set(root.left, root);
           if (root.right !== null) par.set(root.right, root);
           dfs(root.left, par);
           dfs(root.right, par);
       }
 
       let sum = 0;
 
       // Function calling for finding the sum
       function dfs3(root, h, k, vis, par) {
           if (h === k + 1) return;
           if (root === null) return;
           if (vis.has(root)) return;
           sum += root.data;
           vis.set(root, 1);
           dfs3(root.left, h + 1, k, vis, par);
           dfs3(root.right, h + 1, k, vis, par);
           if (par.get(root) !== null && vis.has(par.get(root))) {
               dfs3(par.get(root), h + 1, k, vis, par);
           }
       }
 
       // Function for finding
       // the target node in the tree
       function dfs2(root, target) {
           if (root === null) return null;
           if (root.data === target) return root;
           let node1 = dfs2(root.left, target);
           let node2 = dfs2(root.right, target);
           if (node1 !== null) return node1;
           if (node2 !== null) return node2;
           return null;
       }
 
       function sumAtDistK(root, target, k)
       {
        
           // Map to store the parent of a node
           let par = new Map();
 
           // Make the parent of root node as NULL
           // since it does not have any parent
           par.set(root, null);
 
           // Mark the parent node for all the
           // nodes using DFS
           dfs(root, par);
 
           // Find the target node in the tree
           let node = dfs2(root, target);
 
           // Map to mark the visited nodes
           let vis = new Map();
           sum = 1;
 
           // DFS call to find the sum
           dfs3(node, 0, k, vis, par);
           return sum;
       }
 
       // Taking Input
       let root = new Node(1);
       root.left = new Node(2);
       root.right = new Node(9);
       root.left.left = new Node(4);
       root.right.left = new Node(5);
       root.right.right = new Node(7);
       root.left.left.left = new Node(8);
       root.left.left.right = new Node(19);
       root.right.right.left = new Node(20);
       root.right.right.right = new Node(11);
       root.left.left.left.left = new Node(30);
       root.left.left.right.left = new Node(40);
       root.left.left.right.right = new Node(50);
 
       let target = 9;
       let K = 1;
 
       console.log(sumAtDistK(root, target, K));
 
// This code is contributed by Potta Lokesh


Output

22

Time Complexity: O(N) where N is the number of nodes in the tree
Auxiliary Space: O(N)


My Personal Notes arrow_drop_up
Related Articles

Start Your Coding Journey Now!