Disjoint Set Union (Union-Find) for graph connectivity

PHOTO EMBED

Sat Aug 26 2023 06:17:05 GMT+0000 (Coordinated Universal Time)

Saved by @utp

class DisjointSet:
    """Disjoint-set (union-find) structure over the integer nodes ``0..n``.

    Three parallel lists are indexed by node id:

    * ``parent[x]`` -- parent of ``x``; a root is its own parent.
    * ``size[r]``   -- number of nodes in the set rooted at ``r``
      (only meaningful while ``r`` is a root).
    * ``rank[r]``   -- upper bound on the height of the tree rooted at ``r``
      (only meaningful while ``r`` is a root).

    Both ``union_by_rank`` and ``union_by_size`` keep ``size`` *and* ``rank``
    up to date, so the bookkeeping stays valid whichever method is called.
    """

    def __init__(self, n: int) -> None:
        """Create ``n + 1`` singleton sets, one for each node in ``0..n``.

        The extra slot lets callers use either 0-based or 1-based labels.
        """
        self.rank = [0] * (n + 1)
        self.parent = list(range(n + 1))
        self.size = [1] * (n + 1)

    def find_upar(self, node: int) -> int:
        """Return the root (ultimate parent) of the set containing ``node``.

        Every node visited on the way to the root is re-pointed directly at
        the root (path compression).

        Raises:
            IndexError: if ``node`` is outside ``0..n``. Negative ids are
                rejected explicitly because list indexing would otherwise
                wrap around and silently alias another node.
        """
        if not 0 <= node < len(self.parent):
            raise IndexError(
                f"node {node} is outside 0..{len(self.parent) - 1}"
            )

        # First pass: walk up to the root. Done iteratively so that a long
        # parent chain cannot exhaust the interpreter's recursion limit.
        root = node
        while self.parent[root] != root:
            root = self.parent[root]

        # Second pass: compress the path just walked.
        while node != root:
            next_node = self.parent[node]
            self.parent[node] = root
            node = next_node
        return root

    def _link(self, child_root: int, new_root: int) -> None:
        """Attach the tree rooted at ``child_root`` beneath ``new_root``."""
        self.parent[child_root] = new_root
        self.size[new_root] += self.size[child_root]
        # Hanging a tree of height <= rank[child_root] under new_root makes
        # new_root's height at most rank[child_root] + 1. For union_by_rank
        # this is exactly the classic rule: +1 on a tie, unchanged otherwise.
        self.rank[new_root] = max(
            self.rank[new_root], self.rank[child_root] + 1
        )

    def union_by_rank(self, u: int, v: int) -> None:
        """Merge the sets containing ``u`` and ``v``, lower rank under higher.

        On a rank tie the root of ``v`` is attached under the root of ``u``.
        Does nothing if both nodes are already in the same set.
        """
        ulp_u = self.find_upar(u)
        ulp_v = self.find_upar(v)
        if ulp_u == ulp_v:
            return
        if self.rank[ulp_u] < self.rank[ulp_v]:
            self._link(ulp_u, ulp_v)
        else:
            self._link(ulp_v, ulp_u)

    def union_by_size(self, u: int, v: int) -> None:
        """Merge the sets containing ``u`` and ``v``, smaller under larger.

        On a size tie the root of ``v`` is attached under the root of ``u``.
        Does nothing if both nodes are already in the same set.
        """
        ulp_u = self.find_upar(u)
        ulp_v = self.find_upar(v)
        if ulp_u == ulp_v:
            return
        if self.size[ulp_u] < self.size[ulp_v]:
            self._link(ulp_u, ulp_v)
        else:
            self._link(ulp_v, ulp_u)


# Demo: build two components, {1, 2, 3} and {4, 5, 6, 7}, check whether
# nodes 3 and 7 share a root, then join the components and check again.
ds = DisjointSet(7)
for pair in ((1, 2), (2, 3), (4, 5), (6, 7), (5, 6)):
    ds.union_by_size(*pair)

# Before the final union the two nodes are in different components.
print("Same" if ds.find_upar(3) == ds.find_upar(7) else "Not same")

ds.union_by_size(3, 7)

# After the union both nodes resolve to the same root.
print("Same" if ds.find_upar(3) == ds.find_upar(7) else "Not same")
content_copyCOPY