# Solution
class Disjoint_Set(object):
    """Union-find (disjoint-set) structure over hashable elements.

    Elements come into existence implicitly the first time they are passed
    to :meth:`union`. An element that has never been unioned with a
    different element behaves as its own singleton group.

    Internal state:
      * ``parent_map`` maps each non-root element to its parent; roots have
        no entry, which is how the root of a group is recognised.
      * ``group`` maps the root of every multi-element group to the set of
        the *other* (non-root) members of that group.
    """

    def __init__(self, ini_list=None):
        """Create the structure, optionally seeded with connections.

        Args:
            ini_list: iterable of pairs ``(a, b)``; each pair is passed to
                :meth:`union` in order. ``None`` (the default) means start
                empty.
        """
        self.parent_map = {}
        self.group = {}
        if ini_list is None:
            ini_list = ()
        for connect in ini_list:
            self.union(connect[0], connect[1])

    def _find(self, a):
        """Return the root of *a*'s group, compressing the path walked."""
        root = a
        while root in self.parent_map:
            root = self.parent_map[root]
        # Path compression: re-point every node on the walked path directly
        # at the root. The root itself has no parent_map entry, so the loop
        # stops there without needing an equality test.
        while a in self.parent_map:
            next_node = self.parent_map[a]
            self.parent_map[a] = root
            a = next_node
        return root

    def union(self, a, b):
        """Merge the groups containing *a* and *b*.

        The root of *a*'s group is attached beneath the root of *b*'s group,
        so *b*'s root becomes the root of the merged group. Does nothing if
        both are already in the same group.
        """
        root_a = self._find(a)
        root_b = self._find(b)
        if root_a == root_b:
            return
        self.parent_map[root_a] = root_b
        # root_a stops being a root: its former members and root_a itself
        # all become non-root members of root_b's group.
        members = self.group.setdefault(root_b, set())
        members.update(self.group.pop(root_a, ()))
        members.add(root_a)

    def is_in_same_group(self, a, b):
        """Return True if *a* and *b* belong to the same group."""
        # Compare by value, not identity: equal-but-distinct objects (e.g.
        # large ints) are the same dict key and hence the same element.
        return self._find(a) == self._find(b)

    def find_group(self, a):
        """Return all members of *a*'s group as a list, with the root last.

        Returns None when *a*'s root has no recorded members, i.e. *a* has
        never been unioned with a different element.
        """
        root = self._find(a)
        members = self.group.get(root)
        if members is None:
            return None
        res = list(members)
        res.append(root)
        return res