Disjoint Set (C++)

PHOTO EMBED

Sun Jun 23 2024 18:01:25 GMT+0000 (Coordinated Universal Time)

Saved by @utp #c++

#include <cstddef>
#include <iostream>
#include <numeric>
#include <stdexcept>
#include <utility>
#include <vector>

/// Disjoint-set union (union-find) over the integer nodes 0..n inclusive.
///
/// Two merge strategies are offered: union by rank and union by size. Both
/// keep the rank AND size bookkeeping current, so they may be mixed on the
/// same instance without either heuristic acting on stale data.
/// find_upar() performs full path compression.
class DisjointSet {
public:
    /// Creates n + 1 singleton sets {0}, {1}, ..., {n}.
    /// @param n largest valid node index; must be non-negative.
    /// @throws std::invalid_argument if n is negative.
    explicit DisjointSet(int n) {
        if (n < 0) {
            throw std::invalid_argument("DisjointSet: n must be non-negative");
        }
        const std::size_t count = static_cast<std::size_t>(n) + 1;
        rank_.assign(count, 0);
        size_.assign(count, 1);
        parent_.resize(count);
        std::iota(parent_.begin(), parent_.end(), 0);  // every node is its own root
    }

    /// Returns the representative (root) of the set containing `node` and
    /// re-points every node on the traversed path directly at that root.
    /// @param node index in [0, n]; not range-checked.
    int find_upar(int node) {
        int root = node;
        while (parent_[root] != root) {
            root = parent_[root];
        }
        // Second pass: path compression, done iteratively so deep chains
        // cannot exhaust the call stack.
        while (parent_[node] != root) {
            const int next = parent_[node];
            parent_[node] = root;
            node = next;
        }
        return root;
    }

    /// Merges the sets containing u and v, attaching the lower-rank root
    /// beneath the higher-rank one (ties: v's root goes under u's root).
    void union_by_rank(int u, int v) {
        int root_u = find_upar(u);
        int root_v = find_upar(v);
        if (root_u == root_v) {
            return;
        }
        if (rank_[root_u] < rank_[root_v]) {
            std::swap(root_u, root_v);
        }
        // root_u now has rank >= root_v.
        parent_[root_v] = root_u;
        size_[root_u] += size_[root_v];  // keep sizes valid for union_by_size
        if (rank_[root_u] == rank_[root_v]) {
            ++rank_[root_u];
        }
    }

    /// Merges the sets containing u and v, attaching the smaller set's root
    /// beneath the larger one (ties: v's root goes under u's root).
    void union_by_size(int u, int v) {
        int root_u = find_upar(u);
        int root_v = find_upar(v);
        if (root_u == root_v) {
            return;
        }
        if (size_[root_u] < size_[root_v]) {
            std::swap(root_u, root_v);
        }
        // root_u now heads the set at least as large as root_v's.
        parent_[root_v] = root_u;
        size_[root_u] += size_[root_v];
        // Keep rank an upper bound on tree height so union_by_rank stays sound.
        if (rank_[root_u] <= rank_[root_v]) {
            rank_[root_u] = rank_[root_v] + 1;
        }
    }

private:
    std::vector<int> rank_;    // upper bound on tree height, meaningful at roots
    std::vector<int> parent_;  // parent_[i] == i marks a root
    std::vector<int> size_;    // number of nodes in the set, meaningful at roots
};

class Solution {
public:
    /// Minimum number of cable relocations needed to connect computers
    /// 0..n-1, given the existing cables in `edge`.
    ///
    /// A cable joining two already-connected computers is redundant and can
    /// be moved; joining c separate components needs c - 1 such moves.
    ///
    /// @param n    number of computers; nodes are numbered 0..n-1.
    /// @param edge cables as {u, v}; only the first two entries are read.
    /// @return the number of moves, 0 for an empty network, or -1 if there
    ///         are too few redundant cables to connect every component.
    /// @throws std::invalid_argument if n is negative, an edge has fewer than
    ///         two entries, or an endpoint lies outside [0, n).
    int Solve(int n, const std::vector<std::vector<int>>& edge) {
        if (n < 0) {
            throw std::invalid_argument("Solve: n must be non-negative");
        }
        if (n == 0) {
            return 0;  // nothing to connect; avoids the 0 - 1 == -1 sentinel clash
        }

        DisjointSet ds(n);
        int cnt_extras = 0;
        for (const auto& e : edge) {
            if (e.size() < 2) {
                throw std::invalid_argument("Solve: each edge needs two endpoints");
            }
            const int u = e[0];
            const int v = e[1];
            // Endpoints outside [0, n) would index past the DSU or land in its
            // spare slot n, silently corrupting the component count.
            if (u < 0 || u >= n || v < 0 || v >= n) {
                throw std::invalid_argument("Solve: edge endpoint out of range");
            }
            if (ds.find_upar(u) == ds.find_upar(v)) {
                ++cnt_extras;  // already connected: this cable is spare
            } else {
                ds.union_by_size(u, v);
            }
        }

        int cnt_c = 0;  // number of connected components among 0..n-1
        for (int i = 0; i < n; ++i) {
            if (ds.find_upar(i) == i) {
                ++cnt_c;
            }
        }

        const int needed = cnt_c - 1;
        return cnt_extras >= needed ? needed : -1;
    }
};

/// Demo: 9 computers whose cables form the components {0,1,2,3}, {4,5,6}
/// and {7,8}. The cables {1,2} and {2,3} are redundant, which is exactly
/// enough to join the three components, so the expected output is 2.
int main() {
    const int V = 9;
    // Left non-const so this also works with a Solve taking a non-const reference.
    std::vector<std::vector<int>> edge = { {0, 1}, {0, 2}, {0, 3}, {1, 2}, {2, 3}, {4, 5}, {5, 6}, {7, 8} };

    Solution obj;
    const int ans = obj.Solve(V, edge);
    // '\n' rather than std::endl: the stream is flushed at exit anyway.
    std::cout << "The number of operations needed: " << ans << '\n';

    return 0;
}
content_copyCOPY