# Extended Disjoint Set Union on Trees

• Last Updated : 01 Dec, 2021

Prerequisites: DFS, Trees, DSU
Given a tree of N nodes numbered 1 to N with E edges, rooted at node 1, and an array arr[] that gives the number associated with each node. You are also given Q queries, each consisting of 2 integers {V, F}. For each query, consider the subtree rooted at vertex V; the task is to check whether some number occurs exactly F times among the nodes of that subtree. If yes, print True; otherwise print False.
Examples:

Input: N = 8, E = { {1, 2}, {1, 3}, {2, 4}, {2, 5}, {5, 8}, {5, 6}, {6, 7} }, arr[] = { 11, 2, 7, 1, -3, -1, -1, -3 }, Q = 3, queries[] = { {2, 3}, {5, 2}, {7, 1} }
Output:
False
True
True
Explanation:
Query 1: No number occurs three times in sub-tree 2
Query 2: Numbers -1 and -3 each occur 2 times in sub-tree 5
Query 3: Number -1 occurs once in sub-tree 7
Input: N = 11, E = { {1, 2}, {1, 3}, {2, 4}, {2, 5}, {4, 9}, {4, 8}, {3, 6}, {3, 7}, {4, 10}, {5, 11} }, arr[] = { 2, -1, -12, -1, 6, 14, 7, -1, -2, 13, 12 }, Q = 2, queries[] = { {2, 2}, {4, 2} }
Output:
False
True
Explanation:
Query 1: No number occurs exactly 2 times in sub-tree 2
Query 2: Number -1 occurs twice in sub-tree 4
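
The implementations in this article only run the first example. As a quick sanity check of the second one, here is a minimal brute-force sketch. The helper name `subtree_has_frequency` and its parameters are hypothetical (they do not appear in the article), and the tree is assumed to be rooted at node 1.

```python
from collections import Counter


def subtree_has_frequency(n, edges, arr, v, f, root=1):
    """Return True if some number occurs exactly f times in the subtree of v."""
    adj = [[] for _ in range(n + 1)]
    for a, b in edges:
        adj[a].append(b)
        adj[b].append(a)

    # Find every node's parent with an iterative DFS from the root.
    parent = [0] * (n + 1)
    seen = [False] * (n + 1)
    seen[root] = True
    stack = [root]
    while stack:
        x = stack.pop()
        for y in adj[x]:
            if not seen[y]:
                seen[y] = True
                parent[y] = x
                stack.append(y)

    # Count the numbers in the subtree of v by following child edges only.
    counts = Counter()
    stack = [v]
    while stack:
        x = stack.pop()
        counts[arr[x - 1]] += 1  # arr is 0-indexed, nodes are 1-indexed
        for y in adj[x]:
            if y != parent[x]:
                stack.append(y)

    return f in counts.values()


edges2 = [(1, 2), (1, 3), (2, 4), (2, 5), (4, 9), (4, 8),
          (3, 6), (3, 7), (4, 10), (5, 11)]
arr2 = [2, -1, -12, -1, 6, 14, 7, -1, -2, 13, 12]

print(subtree_has_frequency(11, edges2, arr2, 2, 2))  # False
print(subtree_has_frequency(11, edges2, arr2, 4, 2))  # True
```

In sub-tree 2 the number -1 occurs three times and every other number once, so no number occurs exactly twice; in sub-tree 4 the number -1 occurs exactly twice.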

Naive Approach: The idea is to traverse the tree with a DFS for each query, count the frequency of the number associated with each vertex in the subtree of V, and store the counts in a hashmap. After the traversal, we only need to scan the values of the hashmap to check whether any number occurs exactly F times.
Below is the implementation of the above approach:

## Java

```java
// Java program for the above approach
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;

@SuppressWarnings("unchecked")
public class Main {

    // To store edges of tree
    static ArrayList<Integer> adj[];

    // To store number associated
    // with vertex
    static int num[];

    // To store frequency of number
    static HashMap<Integer, Integer> freq;

    // Function to add edges in tree
    static void add(int u, int v)
    {
        adj[u].add(v);
        adj[v].add(u);
    }

    // Function returns boolean value
    // representing is there any number
    // present in subtree qv having
    // frequency qc
    static boolean query(int qv, int qc)
    {

        freq = new HashMap<>();

        // Start from root
        int v = 1;

        // DFS Call
        if (qv == v) {
            dfs(v, 0, true, qv);
        }
        else
            dfs(v, 0, false, qv);

        // Check for frequency
        for (int fq : freq.values()) {
            if (fq == qc)
                return true;
        }
        return false;
    }

    // Function to implement DFS
    static void dfs(int v, int p,
                    boolean isQV, int qv)
    {
        if (isQV) {

            // If we are on subtree qv,
            // then increment freq of
            // num[v] in freq
            freq.put(num[v],
                     freq.getOrDefault(num[v], 0) + 1);
        }

        // Recursive DFS Call for
        // adjacency list of node v
        for (int u : adj[v]) {
            if (p != u) {
                if (qv == u) {
                    dfs(u, v, true, qv);
                }
                else
                    dfs(u, v, isQV, qv);
            }
        }
    }

    // Driver Code
    public static void main(String[] args)
    {

        // Given Nodes
        int n = 8;
        adj = new ArrayList[n + 1];
        for (int i = 1; i <= n; i++)
            adj[i] = new ArrayList<>();

        // Given edges of tree
        // (root=1)
        add(1, 2);
        add(1, 3);
        add(2, 4);
        add(2, 5);
        add(5, 8);
        add(5, 6);
        add(6, 7);

        // Number assigned to each vertex
        // (num[0] is unused)
        num = new int[] { -1, 11, 2, 7, 1, -3, -1, -1, -3 };

        // Total number of queries
        int q = 3;

        // Function Call to find each query
        System.out.println(query(2, 3));
        System.out.println(query(5, 2));
        System.out.println(query(7, 1));
    }
}
```

## Python3

```python
# Python 3 program for the above approach

# To store edges of tree
adj = []

# To store number associated
# with vertex
num = []

# To store frequency of number
freq = dict()


# Function to add edges in tree
def add(u, v):
    adj[u].append(v)
    adj[v].append(u)


# Function returns boolean value
# representing is there any number
# present in subtree qv having
# frequency qc
def query(qv, qc):
    freq.clear()

    # Start from root
    v = 1

    # DFS Call
    if qv == v:
        dfs(v, 0, True, qv)
    else:
        dfs(v, 0, False, qv)

    # Check for frequency
    if qc in freq.values():
        return True

    return False


# Function to implement DFS
def dfs(v, p, isQV, qv):
    if isQV:

        # If we are on subtree qv,
        # then increment freq of
        # num[v] in freq
        freq[num[v]] = freq.get(num[v], 0) + 1

    # Recursive DFS Call for
    # adjacency list of node v
    for u in adj[v]:
        if p != u:
            if qv == u:
                dfs(u, v, True, qv)
            else:
                dfs(u, v, isQV, qv)


# Driver Code
if __name__ == '__main__':

    # Given Nodes
    n = 8
    for _ in range(n + 1):
        adj.append([])

    # Given edges of tree
    # (root=1)
    add(1, 2)
    add(1, 3)
    add(2, 4)
    add(2, 5)
    add(5, 8)
    add(5, 6)
    add(6, 7)

    # Number assigned to each vertex
    # (num[0] is unused)
    num = [-1, 11, 2, 7, 1, -3, -1, -1, -3]

    # Total number of queries
    q = 3

    # Function Call to find each query
    print(query(2, 3))
    print(query(5, 2))
    print(query(7, 1))
```

Output:

```
false
true
true
```

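Time Complexity: O(N * Q), since the whole tree is traversed for each query.
Auxiliary Space: O(N + E + Q)

Efficient Approach: The idea is to apply Extended Disjoint Set Union (small-to-large merging of the subtree frequency maps) so that the tree is processed in a single DFS and all queries are answered offline: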

1. Create an array size[] to store the size of each subtree.
2. Create an array of hashmaps map[], where map[V][X] = number of vertices in the subtree of V whose number is X.
3. Calculate the size of each subtree with a DFS traversal by calling dfsSize().
4. Calculate map[V] for every vertex with a second DFS traversal by calling dfs(V, p).
5. During this traversal, to calculate map[V], choose the adjacent vertex of V with the maximum subtree size (bigU), excluding the parent vertex p.
6. For the join operation, pass the reference of map[bigU] to map[V], i.e., map[V] = map[bigU], and then count the number of V itself in map[V].
7. Finally, merge the maps of all remaining adjacent vertices u into map[V], excluding the parent vertex p and bigU.
8. Now check whether any number in map[V] has frequency F. If yes, print True; otherwise print False. The queries on V are answered at this point of the DFS, because map[V] is the same object as map[bigU] and will later be reused by the parent of V.
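
The join and merge steps above can be illustrated in isolation with a minimal sketch. The helper `merge_small_into_large` is a hypothetical name, not part of the article's code, and it picks the larger side by dictionary size, whereas the article's implementation picks it by subtree size. The sample dictionaries are assumed to be the state at vertex 5 of the first example: `heavy` is the map inherited from child 6 after counting vertex 5's own number, and `light` is the map of child 8.

```python
def merge_small_into_large(a, b):
    """Merge two frequency dicts, always iterating over the smaller one."""
    if len(a) < len(b):
        a, b = b, a  # make `a` the larger dict
    for value, count in b.items():
        a[value] = a.get(value, 0) + count
    return a  # the larger dict, now holding the combined counts


heavy = {-1: 2, -3: 1}  # assumed: map of child 6 plus vertex 5's own number
light = {-3: 1}         # assumed: map of child 8

merged = merge_small_into_large(light, heavy)
print(merged)                # {-1: 2, -3: 2}
print(merged is heavy)       # True: the larger dict is reused, not copied
print(2 in merged.values())  # True: answers the query {5, 2}
```

Because only the smaller dictionary is ever iterated, a value is moved into a dictionary at least twice the size each time it is copied, which bounds the number of moves per value by O(log N).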

Below is the implementation of the efficient approach:

## Java

```java
// Java program for the above approach
import java.awt.Point;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;
import java.util.Map.Entry;

@SuppressWarnings("unchecked")
public class Main {

    // To store edges
    static ArrayList<Integer> adj[];

    // num[v] = number assigned
    // to vertex v
    static int num[];

    // map[v].get(c) = total vertices
    // in subtree v having number c
    static Map<Integer, Integer> map[];

    // size[v] = size of subtree v
    static int size[];

    // ans.get(Point(v, f)) = answer to query {v, f}
    static HashMap<Point, Boolean> ans;

    // qv[v] = frequencies queried for vertex v
    static ArrayList<Integer> qv[];

    // Function to add edges
    static void add(int u, int v)
    {
        adj[u].add(v);
        adj[v].add(u);
    }

    // Function to find subtree size
    // of every vertex using dfsSize()
    static void dfsSize(int v, int p)
    {
        size[v]++;

        // Traverse dfsSize recursively
        for (int u : adj[v]) {
            if (p != u) {
                dfsSize(u, v);
                size[v] += size[u];
            }
        }
    }

    // Function to implement DFS Traversal
    static void dfs(int v, int p)
    {
        int mx = -1, bigU = -1;

        // Find adjacent vertex with
        // maximum size
        for (int u : adj[v]) {
            if (u != p) {
                dfs(u, v);
                if (size[u] > mx) {
                    mx = size[u];
                    bigU = u;
                }
            }
        }

        if (bigU != -1) {

            // Passing reference
            map[v] = map[bigU];
        }
        else {

            // If no adjacent vertex
            // present initialize map[v]
            map[v] = new HashMap<>();
        }

        // Update frequency of current number
        map[v].put(num[v],
                   map[v].getOrDefault(num[v], 0) + 1);

        // Add all adjacent vertices
        // maps to map[v]
        for (int u : adj[v]) {

            if (u != bigU && u != p) {

                for (Entry<Integer, Integer> pair :
                     map[u].entrySet()) {
                    map[v].put(
                        pair.getKey(),
                        pair.getValue()
                            + map[v].getOrDefault(
                                  pair.getKey(), 0));
                }
            }
        }

        // Store all queries related
        // to vertex v
        for (int freq : qv[v]) {
            ans.put(new Point(v, freq),
                    map[v].containsValue(freq));
        }
    }

    // Function to find answer for
    // each queries
    static void solveQuery(Point queries[],
                           int N, int q)
    {
        // Add all queries to qv, so that
        // qv[v] lists the frequencies
        // queried for vertex v
        qv = new ArrayList[N + 1];
        for (int i = 1; i <= N; i++)
            qv[i] = new ArrayList<>();

        for (Point p : queries) {
            qv[p.x].add(p.y);
        }

        // Get sizes of all subtrees
        size = new int[N + 1];

        // calculate size[]
        dfsSize(1, 0);

        // Map will be used to store
        // answers for current vertex
        // on dfs
        map = new HashMap[N + 1];

        // To store answer of queries
        ans = new HashMap<>();

        // DFS Call
        dfs(1, 0);

        for (Point p : queries) {

            // Print answer for each query
            System.out.println(ans.get(p));
        }
    }

    // Driver Code
    public static void main(String[] args)
    {
        int N = 8;
        adj = new ArrayList[N + 1];
        for (int i = 1; i <= N; i++)
            adj[i] = new ArrayList<>();

        // Given edges (root=1)
        add(1, 2);
        add(1, 3);
        add(2, 4);
        add(2, 5);
        add(5, 8);
        add(5, 6);
        add(6, 7);

        // Store number given to vertices
        // set num[0]=-1 because
        // there is no vertex 0
        num
            = new int[] { -1, 11, 2, 7, 1,
                          -3, -1, -1, -3 };

        // Queries
        int q = 3;

        // To store queries
        Point queries[] = new Point[q];

        // Given Queries
        queries[0] = new Point(2, 3);
        queries[1] = new Point(5, 2);
        queries[2] = new Point(7, 1);

        // Function Call
        solveQuery(queries, N, q);
    }
}
```

## Python3

```python
# Python 3 program for the above approach

# To store edges of tree
adj = []

# To store number associated
# with vertex
num = []

# mp[v][c] = total vertices
# in subtree v having number c
mp = dict()

# size[v] = size of subtree v
size = []

# ans[(v, f)] = answer to query {v, f}
ans = dict()

# qv[v] = frequencies queried for vertex v
qv = []


# Function to add edges in tree
def add(u, v):
    adj[u].append(v)
    adj[v].append(u)


# Function to find subtree size
# of every vertex using dfsSize()
def dfsSize(v, p):
    size[v] += 1

    # Traverse dfsSize recursively
    for u in adj[v]:
        if p != u:
            dfsSize(u, v)
            size[v] += size[u]


# Function to implement DFS Traversal
def dfs(v, p):
    mx = -1
    bigU = -1

    # Find adjacent vertex with
    # maximum size
    for u in adj[v]:
        if u != p:
            dfs(u, v)
            if size[u] > mx:
                mx = size[u]
                bigU = u

    if bigU != -1:

        # Passing reference
        mp[v] = mp[bigU]

    else:

        # If no adjacent vertex
        # present initialize mp[v]
        mp[v] = dict()

    # Update frequency of current number
    mp[v][num[v]] = mp[v].get(num[v], 0) + 1

    # Add all adjacent vertices
    # maps to mp[v]
    for u in adj[v]:
        if u not in (bigU, p):
            for pair in mp[u].items():
                mp[v][pair[0]] = pair[1] + mp[v].get(pair[0], 0)

    # Store all queries related
    # to vertex v (check the stored
    # frequencies, not the keys)
    for freq in qv[v]:
        ans[(v, freq)] = freq in mp[v].values()


# Function to find answer for
# each queries
def solveQuery(queries, N, q):
    global size, mp, qv, ans

    # Add all queries to qv.
    # The rest of this listing was cut off in the source;
    # it is completed here to mirror the Java version above.
    qv = [[] for _ in range(N + 1)]
    for p in queries:
        qv[p[0]].append(p[1])

    # Get sizes of all subtrees
    size = [0] * (N + 1)
    dfsSize(1, 0)

    # Maps and answers filled during dfs
    mp = dict()
    ans = dict()

    # DFS Call
    dfs(1, 0)

    # Print answer for each query
    for p in queries:
        print(ans[p])


# Driver Code
if __name__ == '__main__':
    N = 8
    for _ in range(N + 1):
        adj.append([])

    # Given edges (root=1)
    add(1, 2)
    add(1, 3)
    add(2, 4)
    add(2, 5)
    add(5, 8)
    add(5, 6)
    add(6, 7)

    # Number assigned to each vertex
    # (num[0] is unused)
    num = [-1, 11, 2, 7, 1, -3, -1, -1, -3]

    # Given Queries
    q = 3
    queries = [(2, 3), (5, 2), (7, 1)]

    # Function Call
    solveQuery(queries, N, q)
```

Output:

```
false
true
true
```

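Time Complexity: O(N * log²N) for building the maps, since small-to-large merging moves each value O(log N) times, plus a scan of map[V] for each query.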
Auxiliary Space: O(N + E + Q)
