Skip to content
Related Articles

Related Articles

Extended Disjoint Set Union on Trees

View Discussion
Improve Article
Save Article
Like Article
  • Last Updated : 01 Dec, 2021

Prerequisites: DFS, Trees, DSU
Given a tree of N nodes numbered from 1 to N, its E edges, and an array arr[] that lists, in node order, the number associated with each node. You are also given Q queries, each containing two integers {V, F}. For each query, consider the subtree rooted at vertex V; the task is to check whether some number occurs exactly F times among the numbers associated with the nodes of that subtree. If yes, print True; otherwise print False.
Examples: 
 

Input: N = 8, E = { {1, 2}, {1, 3}, {2, 4}, {2, 5}, {5, 8}, {5, 6}, {6, 7} }, arr[] = { 11, 2, 7, 1, -3, -1, -1, -3 }, Q = 3, queries[] = { {2, 3}, {5, 2}, {7, 1} } 
Output: 
False 
True 
True 
Explanation: 
Query 1: No number occurs three times in sub-tree 2 
Query 2: Numbers -1 and -3 each occur 2 times in sub-tree 5 
Query 3: Number -1 occurs once in sub-tree 7 
Input: N = 11, E = { {1, 2}, {1, 3}, {2, 4}, {2, 5}, {4, 9}, {4, 8}, {3, 6}, {3, 7}, {4, 10}, {5, 11} }, arr[] = { 2, -1, -12, -1, 6, 14, 7, -1, -2, 13, 12 }, Q = 2, queries[] = { {2, 2}, {4, 2} } 
Output: 
False 
True 
Explanation: 
Query 1: No number occurs exactly 2 times in sub-tree 2 
Query 2: Number -1 occurs twice in sub-tree 4 
 

 

Naive Approach: The idea is to traverse the tree using DFS and, while inside the sub-tree rooted at V, count in a hashmap how often each number associated with the visited vertices occurs. After the traversal, we only need to scan the hashmap's values to check whether the given frequency F occurs.
Below is the implementation of the above approach:
 

Java




// Java program for the above approach
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;
 
public class Main {

    // adj[v] = vertices adjacent to v (vertices are numbered 1..n)
    static ArrayList<Integer> adj[];

    // num[v] = number associated with vertex v (num[0] is a placeholder)
    static int num[];

    // freq.get(x) = how many vertices of the queried subtree carry
    // number x; rebuilt from scratch by every call to query()
    static HashMap<Integer, Integer> freq;

    /**
     * Adds the undirected edge (u, v) to the tree.
     *
     * @param u one endpoint
     * @param v the other endpoint
     */
    static void add(int u, int v)
    {
        adj[u].add(v);
        adj[v].add(u);
    }

    /**
     * Answers one query naively by re-traversing the whole tree from
     * root 1.
     *
     * @param qv root of the queried subtree
     * @param qc required frequency
     * @return true if some number occurs exactly qc times among the
     *         numbers of the vertices in the subtree rooted at qv
     */
    static boolean query(int qv, int qc)
    {
        freq = new HashMap<>();

        // The tree is rooted at vertex 1; counting starts as soon as the
        // DFS reaches qv (immediately when qv is the root itself).
        final int root = 1;
        dfs(root, 0, qv == root, qv);

        return freq.containsValue(qc);
    }

    /**
     * Depth-first traversal that tallies num[] values of the vertices
     * lying inside the subtree rooted at qv.
     *
     * @param v    current vertex
     * @param p    parent of v (0 for the root, which is not a vertex)
     * @param isQV whether v lies inside the subtree rooted at qv
     * @param qv   root of the queried subtree
     */
    static void dfs(int v, int p, boolean isQV, int qv)
    {
        if (isQV) {
            freq.merge(num[v], 1, Integer::sum);
        }

        for (int u : adj[v]) {
            if (u != p) {
                // Once inside the subtree, every descendant stays inside.
                dfs(u, v, isQV || u == qv, qv);
            }
        }
    }

    // Driver Code
    @SuppressWarnings("unchecked") // generic array creation below
    public static void main(String[] args)
    {
        // Given Nodes
        int n = 8;
        adj = new ArrayList[n + 1];
        for (int i = 1; i <= n; i++)
            adj[i] = new ArrayList<>();

        // Given edges of tree (root=1)
        add(1, 2);
        add(1, 3);
        add(2, 4);
        add(2, 5);
        add(5, 8);
        add(5, 6);
        add(6, 7);

        // Number assigned to each vertex (num[0] is a placeholder)
        num = new int[] { -1, 11, 2, 7, 1, -3, -1, -1, -3 };

        // Answer each query
        System.out.println(query(2, 3));
        System.out.println(query(5, 2));
        System.out.println(query(7, 1));
    }
}


Python3




# Python 3 program for the above approach
 
# adj[v] = list of neighbours of vertex v; the driver code appends one
# empty list per index 0..n and add() fills them
adj=[]

# num[v] = number associated with vertex v (index 0 is a placeholder);
# rebound by the driver code
num=[]

# freq[x] = how many vertices of the currently queried subtree carry
# number x; cleared by query() and filled by dfs()
freq=dict()
 
def add(u, v):
    """Record the undirected edge (u, v) in the module-level adjacency lists."""
    for src, dst in ((u, v), (v, u)):
        adj[src].append(dst)
 
 
def query(qv, qc):
    """Return True if some number occurs exactly qc times in subtree qv.

    The tree is rooted at vertex 1 and traversed from the root on every
    call; the module-level freq dict is cleared first and then filled by
    dfs() with the counts for subtree qv.
    """
    freq.clear()

    root = 1
    # Counting starts immediately when the queried vertex is the root.
    dfs(root, 0, qv == root, qv)

    return any(count == qc for count in freq.values())
 
 
def dfs(v, p, isQV, qv):
    """Walk the tree below v (reached from parent p) and tally num[] values.

    isQV tells whether v already lies inside the subtree rooted at qv; once
    inside, every descendant is counted as well.  Counts go into the
    module-level freq dict.
    """
    if isQV:
        freq[num[v]] = freq.get(num[v], 0) + 1

    for child in adj[v]:
        if child == p:
            continue
        # Entering qv switches counting on for the whole subtree below it
        dfs(child, v, True if child == qv else isQV, qv)
         
     
 
 
# Driver Code
if __name__ == '__main__':

    # Vertices are 1..n; adj[0] is an unused placeholder
    n = 8
    adj.extend([] for _ in range(n + 1))

    # Given edges of tree (root=1)
    for a, b in ((1, 2), (1, 3), (2, 4), (2, 5), (5, 8), (5, 6), (6, 7)):
        add(a, b)

    # Number assigned to each vertex (num[0] is a placeholder)
    num = [-1, 11, 2, 7, 1, -3, -1, -1, -3]

    # (vertex, frequency) queries, answered one at a time
    queries = ((2, 3), (5, 2), (7, 1))
    q = len(queries)
    for vertex, target in queries:
        print(query(vertex, target))


Output: 

false
true
true

 

Time Complexity: O(N * Q), since the whole tree is traversed for every query. 
Auxiliary Space: O(N + E + Q)
Efficient Approach: The idea is to apply Extended Disjoint Set Union on trees (small-to-large merging of per-subtree maps) to the above approach:
 

  1. Create an array, size[] to store size of sub-trees.
  2. Create an array of hashmaps, map[], where map[V][X] = total number of vertices in sub-tree V whose associated number is X.
  3. Calculate the size of each subtree using DFS traversal by calling dfsSize().
  4. Using DFS traversal by calling dfs(V, p), calculate the value of map[V].
  5. During the traversal, to calculate map[V], choose the child of V (an adjacent vertex other than the parent p) with the maximum subtree size; call it bigU.
     
  6. For the join operation, reuse map[bigU] for V by reference instead of copying it, i.e., map[V] = map[bigU].
  7. Finally, add the number of V itself to map[V] and merge into it the maps of all other adjacent vertices u, except the parent p and bigU.
  8. Now, check if map[V] contains frequency F or not. If yes then print True else print False.

Below is the implementation of the efficient approach: 
 

Java




// Java program for the above approach
import java.awt.*;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;
import java.util.Map.Entry;
 
public class Main {

    // adj[v] = vertices adjacent to v (vertices are numbered 1..N)
    static ArrayList<Integer> adj[];

    // num[v] = number assigned to vertex v (num[0] is a placeholder)
    static int num[];

    // map[v].get(c) = number of vertices in subtree v having number c.
    // dfs() reuses the heaviest child's map for its parent by reference,
    // so map[v] is only meaningful while vertex v itself is processed.
    static Map<Integer, Integer> map[];

    // cnt[v].get(f) = how many distinct numbers occur exactly f times in
    // subtree v. It is updated together with map[v] (and shared whenever
    // map[v] is), so a query is answered with one lookup instead of a
    // linear containsValue() scan over map[v].
    static Map<Integer, Integer> cnt[];

    // size[v] = number of vertices in subtree v
    static int size[];

    // ans.get(new Point(v, f)) = answer to the query (v, f)
    static HashMap<Point, Boolean> ans;

    // qv[v] = frequencies asked about for subtree v
    static ArrayList<Integer> qv[];

    /**
     * Adds the undirected edge (u, v).
     *
     * @param u one endpoint
     * @param v the other endpoint
     */
    static void add(int u, int v)
    {
        adj[u].add(v);
        adj[v].add(u);
    }

    /**
     * Fills size[] for the subtree rooted at v (size[] must start zeroed).
     *
     * @param v current vertex
     * @param p parent of v (0 for the root)
     */
    static void dfsSize(int v, int p)
    {
        size[v]++;

        for (int u : adj[v]) {
            if (p != u) {
                dfsSize(u, v);
                size[v] += size[u];
            }
        }
    }

    /**
     * Adds {@code delta} (> 0) occurrences of {@code key} to counter m and
     * keeps the frequency-of-frequency table c consistent with it.
     */
    private static void addCount(Map<Integer, Integer> m,
                                 Map<Integer, Integer> c,
                                 int key, int delta)
    {
        int before = m.getOrDefault(key, 0);
        int after = before + delta;
        m.put(key, after);

        // key no longer occurs 'before' times ...
        if (before > 0) {
            c.computeIfPresent(before, (f, n) -> n == 1 ? null : n - 1);
        }
        // ... it now occurs 'after' times
        c.merge(after, 1, Integer::sum);
    }

    /**
     * Small-to-large pass over the subtree rooted at v that answers every
     * query stored in qv[v].
     *
     * Children are processed first; the maps of the child with the largest
     * subtree (bigU) are then reused for v by reference and the maps of the
     * remaining children are merged into them. Queries for v are answered
     * immediately because map[v]/cnt[v] may later be mutated through an
     * ancestor that reuses them. Expects size[] to be filled by dfsSize().
     *
     * @param v current vertex
     * @param p parent of v (0 for the root)
     */
    static void dfs(int v, int p)
    {
        int mx = -1, bigU = -1;

        // Process every child and remember the one with the largest subtree
        for (int u : adj[v]) {
            if (u != p) {
                dfs(u, v);
                if (size[u] > mx) {
                    mx = size[u];
                    bigU = u;
                }
            }
        }

        if (bigU != -1) {
            // Reuse the heaviest child's counters instead of copying them
            map[v] = map[bigU];
            cnt[v] = cnt[bigU];
        }
        else {
            // Leaf: start with empty counters
            map[v] = new HashMap<Integer, Integer>();
            cnt[v] = new HashMap<Integer, Integer>();
        }

        // Count the number on v itself
        addCount(map[v], cnt[v], num[v], 1);

        // Merge the (smaller) maps of the remaining children
        for (int u : adj[v]) {
            if (u != bigU && u != p) {
                for (Entry<Integer, Integer> pair : map[u].entrySet()) {
                    addCount(map[v], cnt[v], pair.getKey(), pair.getValue());
                }
            }
        }

        // Some number occurs exactly freq times iff cnt[v] has key freq
        for (int freq : qv[v]) {
            ans.put(new Point(v, freq), cnt[v].containsKey(freq));
        }
    }

    /**
     * Answers all queries offline and prints one answer per query, in the
     * order given.
     *
     * @param queries (vertex, frequency) pairs
     * @param N       number of vertices (numbered 1..N, root 1)
     * @param q       number of queries (not used; kept for compatibility)
     */
    @SuppressWarnings("unchecked") // generic array creation below
    static void solveQuery(Point queries[],
                           int N, int q)
    {
        // Group the asked frequencies by vertex
        qv = new ArrayList[N + 1];
        for (int i = 1; i <= N; i++)
            qv[i] = new ArrayList<>();

        for (Point p : queries) {
            qv[p.x].add(p.y);
        }

        // Subtree sizes, used to pick each vertex's heaviest child
        size = new int[N + 1];
        dfsSize(1, 0);

        // Per-vertex counters (filled during dfs) and the answer table
        map = new HashMap[N + 1];
        cnt = new HashMap[N + 1];
        ans = new HashMap<>();

        dfs(1, 0);

        for (Point p : queries) {
            System.out.println(ans.get(p));
        }
    }

    // Driver Code
    @SuppressWarnings("unchecked") // generic array creation below
    public static void main(String[] args)
    {
        int N = 8;
        adj = new ArrayList[N + 1];
        for (int i = 1; i <= N; i++)
            adj[i] = new ArrayList<>();

        // Given edges (root=1)
        add(1, 2);
        add(1, 3);
        add(2, 4);
        add(2, 5);
        add(5, 8);
        add(5, 6);
        add(6, 7);

        // Number given to each vertex; num[0] = -1 is a placeholder
        // because there is no vertex 0
        num
            = new int[] { -1, 11, 2, 7, 1,
                          -3, -1, -1, -3 };

        // Given queries
        int q = 3;
        Point queries[] = new Point[q];
        queries[0] = new Point(2, 3);
        queries[1] = new Point(5, 2);
        queries[2] = new Point(7, 1);

        solveQuery(queries, N, q);
    }
}


Python3




# Python 3 program for the above approach
 
# adj[v] = list of neighbours of vertex v; rebound by the driver code
adj=[]

# num[v] = number associated with vertex v (index 0 is a placeholder);
# rebound by the driver code
num=[]

# mp[v][c] = number of vertices in subtree v having number c.
# dfs() reuses the heaviest child's dict for its parent, so after the
# traversal mp[v] may be the same object as an ancestor's dict.
mp=dict()

# size[v] = number of vertices in subtree v (filled by dfsSize)
size=[]

# ans[(v, f)] = answer to the query (v, f); filled by dfs()
ans=dict()

# qv[v] = list of frequencies queried for subtree v
qv=[]
 
def add(u, v):
    """Record the undirected edge (u, v) in the module-level adjacency lists."""
    for src, dst in ((u, v), (v, u)):
        adj[src].append(dst)
 
 
def dfsSize(v, p):
    """Accumulate into size[v] the number of vertices in the subtree of v.

    p is the parent of v (0 for the root).  size[] is expected to start
    zeroed, as solveQuery() prepares it.
    """
    size[v] += 1

    children = (u for u in adj[v] if u != p)
    for child in children:
        dfsSize(child, v)
        size[v] += size[child]
         
     
 
 
def dfs(v, p):
    """Small-to-large (DSU on tree) pass that answers the queries of subtree v.

    After the children of v are processed, the dict of the child with the
    largest subtree (bigU) is reused for v by reference, and the dicts of
    the remaining children are merged into it.  mp[v][c] then counts the
    vertices of subtree v carrying number c.  Every query (v, freq) listed
    in qv[v] is answered immediately, because mp[v] may later be mutated
    through an ancestor that reuses it.

    Reads the module-level adj, size (filled by dfsSize), num and qv;
    writes mp and ans.  p is the parent of v (0 for the root).
    """
    mx = -1
    bigU = -1

    # Process every child first and remember the one with the largest subtree
    for u in adj[v]:
        if u != p:
            dfs(u, v)
            if size[u] > mx:
                mx = size[u]
                bigU = u

    if bigU != -1:
        # Reuse the heaviest child's dict instead of copying it
        mp[v] = mp[bigU]
    else:
        # Leaf: start with an empty counter
        mp[v] = dict()

    # Count the number on v itself
    mp[v][num[v]] = mp[v].get(num[v], 0) + 1

    # Merge the (smaller) dicts of the remaining children
    for u in adj[v]:
        if u not in (bigU, p):
            for key, count in mp[u].items():
                mp[v][key] = count + mp[v].get(key, 0)

    # A query asks whether some number occurs exactly freq times, so freq
    # must be looked up among the counts (values), not the numbers (keys).
    for freq in qv[v]:
        ans[(v, freq)] = freq in mp[v].values()
 
def solveQuery(queries, N, q):
    """Answer every (vertex, frequency) query offline and print the answers.

    queries is a sequence of (v, f) tuples and N the number of vertices.
    The module-level qv, size, mp and ans are rebuilt from scratch, the
    subtree sizes are computed, a single dfs from root 1 answers all
    queries, and the answers are printed in input order.  q is not used.
    """
    global size,mp,qv,ans

    # Group the asked frequencies by the vertex whose subtree they refer to
    qv = [[] for _ in range(N + 1)]
    for query in queries:
        qv[query[0]].append(query[1])

    # Subtree sizes, needed to pick each vertex's heaviest child
    size = [0] * (N + 1)
    dfsSize(1, 0)

    # Fresh per-vertex counters and answer table
    mp = {}
    ans = {}

    dfs(1, 0)

    for query in queries:
        print(ans[query])
     
 
 
# Driver Code
if __name__ == '__main__':
    # Example tree with N vertices; adj[0] is an unused placeholder
    N = 8
    adj = [[] for _ in range(N + 1)]

    # Given edges (root=1)
    for a, b in ((1, 2), (1, 3), (2, 4), (2, 5), (5, 8), (5, 6), (6, 7)):
        add(a, b)

    # Number given to each vertex; num[0] = -1 is a placeholder
    # because there is no vertex 0
    num = [-1, 11, 2, 7, 1, -3, -1, -1, -3]

    # Queries as (vertex, frequency) tuples
    queries = [(2, 3), (5, 2), (7, 1)]
    q = len(queries)

    solveQuery(queries, N, q)


Output: 

false
true
true

 

Time Complexity: O(N * log^2(N))
Auxiliary Space: O(N + E + Q) 

 


My Personal Notes arrow_drop_up
Recommended Articles
Page :

Start Your Coding Journey Now!