// Union-Find (Disjoint Set Union)


class UnionFindManager {

public:

    // Every connected component has exactly one root node r, identified by parent[r] == r.
    // parent[i] is the parent of node i in its component's tree.
    std::vector<int> parent;

    // size[i] is meaningful only when i is a root node:
    // size[i] = number of nodes in i's connected component.
    std::vector<int> size;

    /**
    * Creates N singleton sets {0}, {1}, ..., {N-1}.
    *
    * inputs ::
    *     N  : total number of nodes; a negative value is treated as 0
    */
    explicit UnionFindManager(int N) {
        // Guard against negative N, which would otherwise convert to a huge size_t.
        const std::size_t n = (N > 0) ? static_cast<std::size_t>(N) : 0;

        size.assign(n, 1);  // every node starts in its own set of size 1
        parent.resize(n);

        // each node is initially its own parent, i.e. the root of its own set
        for (std::size_t i = 0; i < n; ++i) {
            parent[i] = static_cast<int>(i);
        }
    }

    /**
    * Returns the root (representative) of the set containing node i.
    * Uses path halving: every visited node is re-pointed to its grandparent,
    * flattening the tree so that later queries are cheaper.
    * Precondition: 0 <= i < N (not checked).
    */
    int findRoot(int i) {
        while (parent[i] != i) {
            parent[i] = parent[parent[i]];
            i = parent[i];
        }
        return i;
    }

    /**
    * Merges the sets containing nodes i and j; no-op if they are already in the same set.
    *
    * Union by size: the root of the smaller set is attached under the root of the larger
    * set. On a size tie, findRoot(i) becomes the root of the merged set. The merged root is
    * therefore NOT always findRoot(i); it is whichever root had the larger set.
    *
    * Complexity: with union by size + path halving, one call is O(log n) in the worst case
    * and O(alpha(n)) amortized (alpha = inverse Ackermann, effectively constant).
    */
    void unionNodes(int i, int j) {
        int bigRoot = findRoot(i);
        int smallRoot = findRoot(j);

        if (bigRoot == smallRoot) {
            return;
        }

        // keep findRoot(i) as the winner on ties; swap only if j's set is strictly larger
        if (size[bigRoot] < size[smallRoot]) {
            const int tmp = bigRoot;
            bigRoot = smallRoot;
            smallRoot = tmp;
        }

        parent[smallRoot] = bigRoot;
        size[bigRoot] += size[smallRoot];
    }

    // Returns true if nodes i and j belong to the same set.
    bool connected(int i, int j) {
        return findRoot(i) == findRoot(j);
    }

    // Returns true if node i is the root (representative) of its set.
    bool isRoot(int i) const {
        return parent[i] == i;
    }

};


int main(){

    UnionFindManager ufm(10);

    // Union by size: when both sets have equal size, the first argument's root wins.
    ufm.unionNodes(0,2);        // sizes tie -> 2 is merged into 0, root = 0
    ufm.unionNodes(1,7);        // sizes tie -> 7 is merged into 1, root = 1
    ufm.unionNodes(7,8);        // 8 joins the larger set rooted at 1
    ufm.unionNodes(8,9);        // 9 joins the larger set rooted at 1
    ufm.unionNodes(4,5);        // sizes tie -> 5 is merged into 4, root = 4

    std::cout << ufm.findRoot(8) << '\n';      // expected: 1

    // Each disjoint set has exactly one root, so counting roots counts sets.
    int totalDisjointSets = 0;
    for (std::size_t i = 0; i < ufm.parent.size(); ++i) {
        if (ufm.isRoot(static_cast<int>(i))) {
            ++totalDisjointSets;
        }
    }

    std::cout << "Total number of disjoint sets are : " << totalDisjointSets << '\n';

    //  expected: totalDisjointSets = 5
    //  = {
    //      [0,2],
    //      [1,7,8,9],
    //      [3],
    //      [4,5],
    //      [6]
    //    }

    return 0;
}

// Last updated