# Count of Nodes at distance K from S in its subtree for Q queries

• Last Updated : 14 Mar, 2023

Given a tree of N nodes rooted at node 1 and an array Q[] of M pairs, where each element represents a query of the form (S, K), the task is to print, for each query, the number of nodes in the subtree of node S that are at distance K from S.

Examples:

Input: Edges = {{1, 2}, {1, 3}, {2, 4}, {2, 5}, {2, 6}}, Q[] = {{2, 1}, {1, 1}}
Output: 3 2
Explanation:

1. Query(2, 1): Print 3, as there are 3 nodes (4, 5, and 6) at a distance of 1 in the subtree of node 2.
2. Query(1, 1): Print 2, as there are 2 nodes (2 and 3) at a distance of 1 in the subtree of node 1.

Input: Edges = {{1, 2}, {2, 3}, {3, 4}}, Q[] = {{1, 2}, {2, 2}}
Output: 1 1

Naive Approach: The simplest approach is to run a Depth First Search (DFS) from node S for each query, descending only into the subtree of S, and count the nodes that are at distance K from S.

Time Complexity: O(N*M)
Auxiliary Space: O(N) in the worst case for the DFS recursion stack
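
As a rough illustration of the naive approach, the sketch below answers each query with its own traversal. The function name `count_at_distance_naive`, the edge-list input, and the use of an explicit stack instead of recursion are assumptions made for this example, not part of the original solution.

```python
from collections import defaultdict


def count_at_distance_naive(edges, queries, root=1):
    """For each (S, K), count the nodes exactly K levels below S in S's subtree."""
    adj = defaultdict(list)
    for u, v in edges:
        adj[u].append(v)
        adj[v].append(u)

    # Root the tree once so that every node knows its parent.
    # Node labels are assumed to start at 1, so 0 marks "no parent".
    parent = {root: 0}
    stack = [root]
    while stack:
        node = stack.pop()
        for nxt in adj[node]:
            if nxt != parent[node]:
                parent[nxt] = node
                stack.append(nxt)

    answers = []
    for s, k in queries:
        count = 0
        stack = [(s, 0)]  # walk downward from S only
        while stack:
            node, dist = stack.pop()
            if dist == k:
                count += 1
                continue  # anything deeper is farther than K
            for nxt in adj[node]:
                if nxt != parent[node]:
                    stack.append((nxt, dist + 1))
        answers.append(count)
    return answers


edges = [(1, 2), (1, 3), (2, 4), (2, 5), (2, 6)]
print(count_at_distance_naive(edges, [(2, 1), (1, 1)]))  # [3, 2]
```

Every query may touch the whole subtree of S, which is where the O(N*M) bound comes from.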

Efficient Approach: The above approach can be optimized based on the following observations:

1. Suppose tin[] stores the entry time of every node and tout[] stores the exit time of every node in a DFS traversal of the tree.
2. Then, for two nodes A and B, B is in the subtree of A if and only if:
• tin[B] ≥ tin[A] and tout[B] ≤ tout[A]
3. Suppose levels[] is an array of lists, where levels[i] stores the entry times of all nodes at depth i, in increasing order.
4. Then the nodes at distance K from a node S in its subtree are exactly the entries of levels[depth[S] + K] that lie between tin[S] and tout[S], and they can be counted using binary search.
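
The subtree condition in observations 1 and 2 can be checked on the first example tree with a small sketch. The helper names `euler_times` and `in_subtree`, the child-list input, and the iterative traversal are assumptions chosen for illustration; the timestamps it produces match a recursive DFS that starts its clock at 1.

```python
def euler_times(children, root=1):
    """Return (tin, tout) dicts for a rooted tree given as lists of children."""
    tin, tout = {}, {}
    t = 1
    # Each stack entry is (node, done); done=True means the node is being exited.
    stack = [(root, False)]
    while stack:
        node, done = stack.pop()
        if done:
            tout[node] = t
        else:
            tin[node] = t
            stack.append((node, True))
            # Reversed so that children are visited in their listed order.
            for child in reversed(children.get(node, [])):
                stack.append((child, False))
        t += 1
    return tin, tout


def in_subtree(a, b, tin, tout):
    """True if b lies in the subtree of a (a itself included)."""
    return tin[a] <= tin[b] and tout[b] <= tout[a]


children = {1: [2, 3], 2: [4, 5, 6]}
tin, tout = euler_times(children)
print(tin[2], tout[2])               # 2 9
print(in_subtree(2, 5, tin, tout))   # True
print(in_subtree(2, 3, tin, tout))   # False
```

Node 5 is entered and exited inside the interval [tin[2], tout[2]], while node 3 is entered only after node 2 has been exited.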

Follow the steps below to solve the problem:

• Initialize three arrays, say tin[], tout[], and depth[] to store the entry time, exit time, and depth of a node respectively.
• Initialize two 2D vectors, say adj and levels, to store the adjacency list and entry times of every node at a specific depth.
• Initialize a variable, say t as 1, to keep track of time.
• Define a recursive DFS function, say dfs(node, parent, d), that performs the following steps:
    • Assign t to tin[node] and then increment t by 1.
    • Push tin[node] into the vector levels[d] and assign d to depth[node]. Since t only increases, every levels[d] stays sorted.
    • For every child X of the node, call dfs(X, node, d + 1).
    • After all children have been processed, assign t to tout[node] and increment t by 1.
• Call the recursive function dfs(1, 1, 0).
• Traverse the array Q[] using the variable i and do the following:
    • Store the current query as S = Q[i].first and K = Q[i].second.
    • Using binary search, find the index of the first entry in levels[depth[S] + K] that is not less than tin[S], and store it in a variable, say L.
    • Similarly, find the index of the first entry in levels[depth[S] + K] that is not less than tout[S], and store it in a variable, say R.
    • Print the value of R - L as the answer to the current query.
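
The steps above can be sketched compactly as a preprocessing function plus a query function that returns the count instead of printing it. The names `build_index` and `count_at_distance`, the iterative DFS, and the guard for depths deeper than the tree are assumptions added for this sketch; the listed steps do not mention the guard.

```python
from bisect import bisect_left


def build_index(edges, n, root=1):
    """Compute tin, tout, depth and the per-depth sorted lists of entry times."""
    adj = [[] for _ in range(n + 1)]
    for u, v in edges:
        adj[u].append(v)
        adj[v].append(u)

    tin = [0] * (n + 1)
    tout = [0] * (n + 1)
    depth = [0] * (n + 1)
    levels = [[] for _ in range(n)]  # depths range over 0 .. n-1
    t = 1
    stack = [(root, 0, False)]  # (node, parent, exiting?)
    while stack:
        node, parent, done = stack.pop()
        if done:
            tout[node] = t
        else:
            tin[node] = t
            levels[depth[node]].append(t)  # appended in increasing time order
            stack.append((node, parent, True))
            for nxt in reversed(adj[node]):
                if nxt != parent:
                    depth[nxt] = depth[node] + 1
                    stack.append((nxt, node, False))
        t += 1
    return tin, tout, depth, levels


def count_at_distance(s, k, tin, tout, depth, levels):
    """Count nodes in the subtree of s that are exactly k levels below s."""
    d = depth[s] + k
    if d >= len(levels):  # assumed guard: no node can be that deep
        return 0
    lo = bisect_left(levels[d], tin[s])
    hi = bisect_left(levels[d], tout[s])
    return hi - lo


index = build_index([(1, 2), (1, 3), (2, 4), (2, 5), (2, 6)], 6)
print(count_at_distance(2, 1, *index))  # 3
print(count_at_distance(1, 1, *index))  # 2
```

Returning the count rather than printing it makes the query easy to test, and the depth guard avoids an out-of-range access when K is larger than the height of the tree.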

Below is the implementation of the above approach:

## C++

```cpp
// C++ program for the above approach
#include <algorithm>
#include <iostream>
#include <utility>
#include <vector>
using namespace std;

int tin[100], tout[100], depth[100];
int t = 0;

// Function to add edges
void Add_edge(int parent, int child,
              vector<vector<int> >& adj)
{
    adj[parent].push_back(child);
    adj[child].push_back(parent);
}

// Function to perform Depth First Search
void dfs(int node, int parent, vector<vector<int> >& adj,
         vector<vector<int> >& levels, int d)
{
    // Stores the entry time of a node
    tin[node] = t++;

    // Stores the entry time
    // of a node at depth d
    levels[d].push_back(tin[node]);
    depth[node] = d;

    // Iterate over the children of node
    for (auto x : adj[node]) {
        if (x != parent)
            dfs(x, node, adj, levels, d + 1);
    }

    // Stores the exit time of a node
    tout[node] = t++;
}

// Function to find number of nodes
// at distance K from node S in the
// subtree of S
void numberOfNodes(int node, int dist,
                   vector<vector<int> >& levels)
{
    // Depth of the target nodes, measured from the root
    dist += depth[node];

    // Index of the first entry time at this depth
    // that is not less than tin[S]
    int start = lower_bound(levels[dist].begin(),
                            levels[dist].end(), tin[node])
                - levels[dist].begin();

    // Index of the first entry time at this depth
    // that is not less than tout[S]
    int ed = lower_bound(levels[dist].begin(),
                         levels[dist].end(), tout[node])
             - levels[dist].begin();

    // Answer to the Query
    cout << ed - start << endl;
}

// Function for performing DFS
// and answer to queries
void numberOfNodesUtil(pair<int, int> Q[], int M, int N)
{

    vector<vector<int> > adj(N + 5), levels(N + 5);

    Add_edge(1, 2, adj);
    Add_edge(1, 3, adj);
    Add_edge(2, 4, adj);
    Add_edge(2, 5, adj);
    Add_edge(2, 6, adj);

    t = 1;

    // DFS function call
    dfs(1, 1, adj, levels, 0);

    // Traverse the array Q[]
    for (int i = 0; i < M; ++i) {
        numberOfNodes(Q[i].first, Q[i].second, levels);
    }
}

// Driver Code
int main()
{
    // Input
    int N = 6;
    pair<int, int> Q[] = { { 2, 1 }, { 1, 1 } };
    int M = sizeof(Q) / sizeof(Q[0]);

    // Function call
    numberOfNodesUtil(Q, M, N);
}
```

## Java

```java
// Java program for the above approach
import java.util.*;

public class Main {
    static int[] tin = new int[100];
    static int[] tout = new int[100];
    static int[] depth = new int[100];
    static int t = 0;

    // Function to add edges
    static void addEdge(int parent, int child, List<List<Integer>> adj) {
        adj.get(parent).add(child);
        adj.get(child).add(parent);
    }

    // Function to perform Depth First Search
    static void dfs(int node, int parent, List<List<Integer>> adj,
                    List<List<Integer>> levels, int d) {
        // Stores the entry time of a node
        tin[node] = t++;

        // Stores the entry time
        // of a node at depth d
        levels.get(d).add(tin[node]);
        depth[node] = d;

        // Iterate over the children of node
        for (int x : adj.get(node)) {
            if (x != parent)
                dfs(x, node, adj, levels, d + 1);
        }

        // Stores the exit time of a node
        tout[node] = t++;
    }

    // Function to find number of nodes
    // at distance K from node S in the
    // subtree of S
    static void numberOfNodes(int node, int dist, List<List<Integer>> levels) {
        // Depth of the target nodes, measured from the root
        dist += depth[node];

        // Index of the first entry time at this depth
        // that is not less than tin[S]
        int start = Collections.binarySearch(levels.get(dist), tin[node]);

        // A negative result encodes the insertion point
        if (start < 0) {
            start = -(start + 1);
        }

        // Index of the first entry time at this depth
        // that is not less than tout[S]
        int ed = Collections.binarySearch(levels.get(dist), tout[node]);

        if (ed < 0) {
            ed = -(ed + 1);
        }

        // Answer to the Query
        System.out.println(ed - start);
    }

    // Function for performing DFS
    // and answer to queries
    static void numberOfNodesUtil(int[][] Q, int M, int N) {
        List<List<Integer>> adj = new ArrayList<>(N + 5);
        List<List<Integer>> levels = new ArrayList<>(N + 5);

        for (int i = 0; i < N + 5; i++) {
            adj.add(new ArrayList<>());
            levels.add(new ArrayList<>());
        }

        addEdge(1, 2, adj);
        addEdge(1, 3, adj);
        addEdge(2, 4, adj);
        addEdge(2, 5, adj);
        addEdge(2, 6, adj);

        t = 1;

        // DFS function call
        dfs(1, 1, adj, levels, 0);

        // Traverse the array Q[]
        for (int i = 0; i < M; ++i) {
            numberOfNodes(Q[i][0], Q[i][1], levels);
        }
    }

    // Driver Code
    public static void main(String[] args) {
        // Input
        int N = 6;
        int[][] Q = {{2, 1}, {1, 1}};
        int M = Q.length;

        // Function call
        numberOfNodesUtil(Q, M, N);
    }
}
```

## Python3

```python
# Python3 program for the above approach
from bisect import bisect_left

tin = [0] * 100
tout = [0] * 100
depth = [0] * 100
t = 0

# Function to add edges
def Add_edge(parent, child, adj):

    adj[parent].append(child)
    adj[child].append(parent)
    return adj

# Function to perform Depth First Search
def dfs(node, parent, d):

    global tin, tout, depth, adj, levels, t

    # Stores the entry time of a node
    tin[node] = t
    t += 1

    # Stores the entry time
    # of a node at depth d
    levels[d].append(tin[node])
    depth[node] = d

    # Iterate over the children of node
    for x in adj[node]:
        if (x != parent):
            dfs(x, node, d + 1)

    # Stores the exit time of a node
    tout[node] = t
    t += 1

# Function to find number of nodes
# at distance K from node S in the
# subtree of S
def numberOfNodes(node, dist):

    global levels, tin, tout

    # Depth of the target nodes, measured from the root
    dist += depth[node]

    # Index of the first entry time at this depth
    # that is not less than tin[S]
    start = bisect_left(levels[dist], tin[node])

    # Index of the first entry time at this depth
    # that is not less than tout[S]
    ed = bisect_left(levels[dist], tout[node])

    # Answer to the Query
    print(ed - start)

# Function for performing DFS
# and answer to queries
def numberOfNodesUtil(Q, M, N):

    global t, adj

    adj = Add_edge(1, 2, adj)
    adj = Add_edge(1, 3, adj)
    adj = Add_edge(2, 4, adj)
    adj = Add_edge(2, 5, adj)
    adj = Add_edge(2, 6, adj)

    t = 1

    # DFS function call
    dfs(1, 1, 0)

    # Traverse the array Q[]
    for i in range(M):
        numberOfNodes(Q[i][0], Q[i][1])

# Driver Code
if __name__ == '__main__':

    # Input
    N = 6
    Q = [[2, 1], [1, 1]]

    M = len(Q)

    adj = [[] for i in range(N + 5)]
    levels = [[] for i in range(N + 5)]

    # Function call
    numberOfNodesUtil(Q, M, N)

# This code is contributed by mohit kumar 29
```

## C#

```csharp
// C# equivalent of the above code
using System;
using System.Collections.Generic;
using System.Linq;

namespace Main
{
  public class Program
  {
    static int[] tin = new int[100];
    static int[] tout = new int[100];
    static int[] depth = new int[100];
    static int t = 0;

    // Function to add edges
    static void addEdge(int parent, int child, List<List<int>> adj)
    {
      adj[parent].Add(child);
      adj[child].Add(parent);
    }

    // Function to perform Depth First Search
    static void dfs(int node, int parent, List<List<int>> adj,
                    List<List<int>> levels, int d)
    {
      // Stores the entry time of a node
      tin[node] = t++;

      // Stores the entry time
      // of a node at depth d
      levels[d].Add(tin[node]);
      depth[node] = d;

      // Iterate over the children of node
      foreach (int x in adj[node])
      {
        if (x != parent)
          dfs(x, node, adj, levels, d + 1);
      }

      // Stores the exit time of a node
      tout[node] = t++;
    }

    // Function to find number of nodes
    // at distance K from node S in the
    // subtree of S
    static void numberOfNodes(int node, int dist, List<List<int>> levels)
    {
      // Depth of the target nodes, measured from the root
      dist += depth[node];

      // Index of the first entry time at this depth
      // that is not less than tin[S]
      int start = levels[dist].BinarySearch(tin[node]);

      // A negative result encodes the insertion point
      if (start < 0)
      {
        start = -(start + 1);
      }

      // Index of the first entry time at this depth
      // that is not less than tout[S]
      int ed = levels[dist].BinarySearch(tout[node]);

      if (ed < 0)
      {
        ed = -(ed + 1);
      }

      // Answer to the Query
      Console.WriteLine(ed - start);
    }

    // Function for performing DFS
    // and answer to queries
    static void numberOfNodesUtil(int[][] Q, int M, int N)
    {
      List<List<int>> adj = new List<List<int>>();
      List<List<int>> levels = new List<List<int>>();

      for (int i = 0; i < N + 5; i++)
      {
        adj.Add(new List<int>());
        levels.Add(new List<int>());
      }

      addEdge(1, 2, adj);
      addEdge(1, 3, adj);
      addEdge(2, 4, adj);
      addEdge(2, 5, adj);
      addEdge(2, 6, adj);

      t = 1;

      // DFS function call
      dfs(1, 1, adj, levels, 0);

      // Traverse the array Q[]
      for (int i = 0; i < M; ++i)
      {
        numberOfNodes(Q[i][0], Q[i][1], levels);
      }
    }

    // Driver Code
    public static void Main(string[] args)
    {
      // Input
      int N = 6;
      int[][] Q = { new int[] { 2, 1 }, new int[] { 1, 1 } };
      int M = Q.Length;

      // Function call
      numberOfNodesUtil(Q, M, N);
    }
  }
}
```

## Javascript

```javascript
// JavaScript program for the above approach
let tin = [];
let tout = [];
let depth = [];
let t = 0;

// Function to add edges
function addEdge(parent, child, adj) {
    adj[parent].push(child);
    adj[child].push(parent);
}

// Function to perform Depth First Search
function dfs(node, parent, adj, levels, d) {
    // Stores the entry time of a node
    tin[node] = t++;

    // Stores the entry time
    // of a node at depth d
    levels[d].push(tin[node]);
    depth[node] = d;

    // Iterate over the children of node
    for (let i = 0; i < adj[node].length; i++) {
        let x = adj[node][i];
        if (x !== parent) {
            dfs(x, node, adj, levels, d + 1);
        }
    }

    // Stores the exit time of a node
    tout[node] = t++;
}

// Returns the index of the first element of the sorted
// array arr that is not less than target (arr.length if none)
function lowerBound(arr, target) {
    let lo = 0, hi = arr.length;
    while (lo < hi) {
        let mid = (lo + hi) >> 1;
        if (arr[mid] < target) lo = mid + 1;
        else hi = mid;
    }
    return lo;
}

// Function to find number of nodes
// at distance K from node S in the
// subtree of S
function numberOfNodes(node, dist, levels) {
    // Depth of the target nodes, measured from the root
    dist += depth[node];

    // Index of the first entry time at this depth
    // that is not less than tin[S]
    let start = lowerBound(levels[dist], tin[node]);

    // Index of the first entry time at this depth
    // that is not less than tout[S]
    let ed = lowerBound(levels[dist], tout[node]);

    // Answer to the Query
    console.log(ed - start);
}

// Function for performing DFS
// and answer to queries
function numberOfNodesUtil(Q, M, N) {
    let adj = Array.from(Array(N + 5), () => []);
    let levels = Array.from(Array(N + 5), () => []);

    addEdge(1, 2, adj);
    addEdge(1, 3, adj);
    addEdge(2, 4, adj);
    addEdge(2, 5, adj);
    addEdge(2, 6, adj);

    t = 1;

    // DFS function call
    dfs(1, 1, adj, levels, 0);

    // Traverse the array Q[]
    for (let i = 0; i < M; i++) {
        numberOfNodes(Q[i][0], Q[i][1], levels);
    }
}

// Driver Code
let N = 6;
let Q = [[2, 1], [1, 1]];
let M = Q.length;

numberOfNodesUtil(Q, M, N);
```

Output

```
3
2
```

Time Complexity: O(N + M*log(N))
Auxiliary Space: O(N)
