Kruskal's algorithm: minimum spanning tree (MST)

PHOTO EMBED

Sat Aug 26 2023 06:45:06 GMT+0000 (Coordinated Universal Time)

Saved by @utp

class DisjointSet:
    """Union-find (disjoint-set) structure over the integer labels 0..n inclusive.

    Two union strategies are provided. Each one maintains only its own
    bookkeeping array (``union_by_rank`` -> ``rank``, ``union_by_size`` ->
    ``size``), so pick one strategy per instance: mixing them makes the
    balancing decisions use stale data and can produce long parent chains.
    """

    def __init__(self, n):
        # Rank heuristic per root: bumped only when two equal-rank roots merge
        # in union_by_rank; never touched by union_by_size.
        self.rank = [0] * (n + 1)
        # parent[x] == x marks x as the root (representative) of its set.
        self.parent = list(range(n + 1))
        # Element count per root; maintained only by union_by_size.
        self.size = [1] * (n + 1)

    def find_upar(self, node):
        """Return the root ("ultimate parent") of *node*'s set.

        Performs full path compression: afterwards every node on the walked
        path points directly at the root. Implemented iteratively so that an
        arbitrarily long parent chain cannot exceed Python's recursion limit.
        """
        parent = self.parent

        # First pass: walk up to the root.
        root = node
        while parent[root] != root:
            root = parent[root]

        # Second pass: re-point every node on the path straight at the root.
        while node != root:
            next_node = parent[node]
            parent[node] = root
            node = next_node
        return root

    def union_by_rank(self, u, v):
        """Merge the sets containing *u* and *v*, attaching the lower-rank root.

        On a rank tie, v's root goes under u's root and u's root's rank is
        incremented. Does nothing if *u* and *v* are already in the same set.
        Only ``rank`` is updated; ``size`` is left untouched.
        """
        ulp_u = self.find_upar(u)
        ulp_v = self.find_upar(v)
        if ulp_u == ulp_v:
            return
        if self.rank[ulp_u] < self.rank[ulp_v]:
            self.parent[ulp_u] = ulp_v
        elif self.rank[ulp_v] < self.rank[ulp_u]:
            self.parent[ulp_v] = ulp_u
        else:
            self.parent[ulp_v] = ulp_u
            self.rank[ulp_u] += 1

    def union_by_size(self, u, v):
        """Merge the sets containing *u* and *v*, attaching the smaller set's root.

        On a size tie, v's root goes under u's root. The surviving root's size
        absorbs the other's. Does nothing if *u* and *v* are already in the same
        set. Only ``size`` is updated; ``rank`` is left untouched.
        """
        ulp_u = self.find_upar(u)
        ulp_v = self.find_upar(v)
        if ulp_u == ulp_v:
            return
        if self.size[ulp_u] < self.size[ulp_v]:
            self.parent[ulp_u] = ulp_v
            self.size[ulp_v] += self.size[ulp_u]
        else:
            self.parent[ulp_v] = ulp_u
            self.size[ulp_u] += self.size[ulp_v]


class Solution:
    """Minimum-spanning-tree weight via Kruskal's algorithm."""

    def spanningTree(self, V, adj):
        """Return the total weight of the edges Kruskal's algorithm selects.

        Args:
            V: number of vertices, labelled 0..V-1.
            adj: adj[i] is an iterable of (neighbour, weight) pairs for vertex i.
                An undirected edge listed from both endpoints is harmless: the
                second copy is rejected because its endpoints are already joined.

        Returns:
            The MST weight for a connected graph. For a disconnected graph,
            the weight of a minimum spanning forest.
        """
        # (weight, (u, v)) tuples so a plain sort orders primarily by weight.
        edges = [
            (wt, (node, adj_node))
            for node in range(V)
            for adj_node, wt in adj[node]
        ]
        edges.sort()

        ds = DisjointSet(V)
        mst_wt = 0
        edges_taken = 0
        for wt, (u, v) in edges:
            if ds.find_upar(u) != ds.find_upar(v):
                mst_wt += wt
                ds.union_by_size(u, v)
                edges_taken += 1
                # A spanning tree on V vertices has exactly V - 1 edges; any
                # edge after that would close a cycle, so stop scanning.
                if edges_taken == V - 1:
                    break

        return mst_wt


if __name__ == "__main__":
    # Demo graph: 5 vertices; each undirected weighted edge is [u, v, weight].
    V = 5
    edges = [[0, 1, 2], [0, 2, 1], [1, 2, 1], [2, 3, 2], [3, 4, 1], [4, 2, 2]]

    # Record every undirected edge in both endpoints' lists as (neighbour, weight).
    adj = [[] for _ in range(V)]
    for u, v, wt in edges:
        adj[u].append((v, wt))
        adj[v].append((u, wt))

    print("The sum of all the edge weights:", Solution().spanningTree(V, adj))
content_copyCOPY