백준 1717번 - disjoint set

각 노드를 가리키는 다른 노드들이 유니언 연산에 의해서 밑에 추가될 때마다 해당 로드의 랭크 인덱스를 계속 갱신해주고 랭크가 큰 쪽 부분집합 (disjoint set) 에다가 union 을 함으로써 트리가 너무 깊어지지 않게끔 만들었습니다 😀

문제

첫째 줄에 n(1 ≤ n ≤ 1,000,000), m(1 ≤ m ≤ 100,000)이 주어진다. m은 입력으로 주어지는 연산의 개수이다. 다음 m개의 줄에는 각각의 연산이 주어진다. 합집합은 0 a b의 형태로 입력이 주어진다. 이는 a가 포함되어 있는 집합과, b가 포함되어 있는 집합을 합친다는 의미이다. 두 원소가 같은 집합에 포함되어 있는지를 확인하는 연산은 1 a b의 형태로 입력이 주어진다. 이는 a와 b가 같은 집합에 포함되어 있는지를 확인하는 연산이다. a와 b는 n 이하의 자연수 또는 0이며 같을 수도 있다.

7 8
0 1 3
1 1 7
0 7 6
1 7 1
0 3 7
0 4 2
0 1 1
1 1 1
NO
NO
YES

코드

#python 3 
import sys
class DisjointSet:
    def read(self):
        self.item_num,self.operation_num=map(int,input().split())
        self.union_question=[]
        self.root=[i for i in range(self.item_num+1)]
        self.rank=[0]*len(self.root)
        for _ in range(self.operation_num):
        		 self.union_question.append(list(map(int,sys.stdin.readline().split())))
    
    def find(self,a):
        if self.root[a]==a:
            return a
        else:
            return self.find(self.root[a])
    
    def union_by_rank(self,a,b):
        A=self.find(a)
        B=self.find(b)
        if A==B:
            return
        elif self.rank[A]>self.rank[B]:
            self.root[B]=A
        else:
            self.root[A]=B
            if self.rank[A]==self.rank[B]:
                self.rank[B]+=1
                
    def solve(self):
        for item in self.union_question:
            if item[0]==0:
                self.union_by_rank(item[1],item[2])
            else:
                if self.find(item[1])==self.find(item[2]):
                    print("YES",end='\n')
                else:
                    print("NO",end='\n')

disjoint=DisjointSet()
disjoint.read()
disjoint.solve()

코드 과정

7 8
0 1 3
1 1 7
0 7 6
1 7 1
0 3 7
0 4 2
0 1 1
1 1 1
NO
NO
YES

KakaoTalk_20210302_234825442

0 1 3

KakaoTalk_20210302_234825752

0 7 6

KakaoTalk_20210302_234825918

0 3 7

KakaoTalk_20210302_234826060

0 4 2

KakaoTalk_20210302_234826254