Kruskal's algorithm: minimum spanning tree (MST)
Sat Aug 26 2023 06:45:06 GMT+0000 (Coordinated Universal Time)
Saved by @utp
class DisjointSet:
    """Union-find (disjoint-set) structure with path compression.

    Slots 0..n are allocated, so nodes may be labelled either 0-based
    (0..n-1) or 1-based (1..n).

    Both ``rank`` and ``size`` are kept valid for every root regardless of
    which union strategy is used, so the strategies may be mixed and
    ``size[find_upar(x)]`` always reports the size of x's component.
    """

    def __init__(self, n):
        """Create n + 1 singleton sets, one per node label 0..n."""
        self.rank = [0] * (n + 1)  # upper bound on tree height; meaningful at roots
        self.parent = list(range(n + 1))  # every node starts as its own root
        self.size = [1] * (n + 1)  # component size; meaningful at roots

    def find_upar(self, node):
        """Return the root ("ultimate parent") of the set containing ``node``.

        Every node on the walked path is re-pointed directly at the root
        (path compression). The walk is iterative so that a long parent
        chain cannot exceed Python's recursion limit.
        """
        root = node
        while self.parent[root] != root:
            root = self.parent[root]
        # Second pass: compress the path just walked.
        while node != root:
            next_node = self.parent[node]
            self.parent[node] = root
            node = next_node
        return root

    def _link(self, child, root):
        """Hang root ``child`` beneath root ``root``, updating size and rank."""
        self.parent[child] = root
        self.size[root] += self.size[child]
        # rank bounds tree height: placing `child` one level below `root`
        # can raise root's height to at most rank[child] + 1.
        self.rank[root] = max(self.rank[root], self.rank[child] + 1)

    def union_by_rank(self, u, v):
        """Merge the sets of ``u`` and ``v``, attaching the lower-rank root.

        On equal ranks, ``v``'s root goes under ``u``'s root and that root's
        rank grows by one.

        Returns:
            True if two distinct sets were merged, False if ``u`` and ``v``
            were already in the same set.
        """
        ulp_u = self.find_upar(u)
        ulp_v = self.find_upar(v)
        if ulp_u == ulp_v:
            return False
        if self.rank[ulp_u] < self.rank[ulp_v]:
            self._link(ulp_u, ulp_v)
        else:
            self._link(ulp_v, ulp_u)
        return True

    def union_by_size(self, u, v):
        """Merge the sets of ``u`` and ``v``, attaching the smaller component.

        On equal sizes, ``v``'s root goes under ``u``'s root.

        Returns:
            True if two distinct sets were merged, False if ``u`` and ``v``
            were already in the same set.
        """
        ulp_u = self.find_upar(u)
        ulp_v = self.find_upar(v)
        if ulp_u == ulp_v:
            return False
        if self.size[ulp_u] < self.size[ulp_v]:
            self._link(ulp_u, ulp_v)
        else:
            self._link(ulp_v, ulp_u)
        return True
class Solution:
    """Minimum-spanning-tree weight computed with Kruskal's algorithm."""

    def spanningTree(self, V, adj):
        """Return the total weight of the edges Kruskal's algorithm accepts.

        Args:
            V: number of vertices, labelled 0..V-1.
            adj: adjacency list; ``adj[i]`` yields ``(neighbour, weight)`` pairs.

        Returns:
            Sum of accepted edge weights. For a connected graph this is the
            MST weight; for a disconnected one it is the spanning-forest weight.
        """
        # Each adjacency entry is a candidate edge. An undirected edge listed
        # in both directions appears twice; its second copy is rejected below
        # because its endpoints are already in the same component.
        candidates = sorted(
            (weight, (src, dst))
            for src in range(V)
            for dst, weight in adj[src]
        )
        forest = DisjointSet(V)
        total = 0
        for weight, (a, b) in candidates:
            # Skip any edge that would close a cycle.
            if forest.find_upar(a) == forest.find_upar(b):
                continue
            total += weight
            forest.union_by_size(a, b)
        return total
if __name__ == "__main__":
    # Sample graph: 5 vertices, undirected weighted edges given as (u, v, weight).
    V = 5
    edges = [[0, 1, 2], [0, 2, 1], [1, 2, 1], [2, 3, 2], [3, 4, 1], [4, 2, 2]]

    # Undirected graph: record every edge in both endpoints' adjacency lists.
    adj = [[] for _ in range(V)]
    for u, v, wt in edges:
        adj[u].append((v, wt))
        adj[v].append((u, wt))

    mstWt = Solution().spanningTree(V, adj)
    print("The sum of all the edge weights:", mstWt)



Comments