본문 바로가기
Programming/Algorithm

[Algorithm][Python]유니온-파인드(Union-Find) 알고리즘 + 코드

by NoiB 2023. 8. 3.
반응형

이번에는 유니온 파인드 알고리즘에 대해서 한 번 알아봅시다. 아직 학부생이던 시절에 저보다 먼저 알고리즘 공부를 하던 친구가 유니온 파인드 한 번 써보라고, 진짜 너무 편하다고 했던 기억이 나는데요. 그 때 저는 DFS/BFS도 모르던 시절이라 그런게 있구나 생각만 하고 넘겼습니다. 사실상 그런 일이 있었다는 것도 거의 잊고 지내다가 평소처럼 알고리즘 문제를 풀었는데 별로 런타임이 빠르지 않았고, 해당 문제가 어떤 태그가 있나 살펴보다가 분리 집합이라는 알고리즘으로 분류가 되어있어서 이걸 사용하면 좀 빨라지려나 싶어서 찾아봤던게 유니온 파인드와의 첫 만남이었네요.

 

사실 이게 유니온 파인드구나를 알았을 때는 약간의 당황과 겁이 났습니다. 저한테 유니온 파인드는 상당히 먼 얘기일거라고 생각하고 있었거든요. 아직 아무것도 모르던 시절에 이미 훨씬 앞서가있는 친구가 했던 얘기라 무의식 속에서 오랜 세월이 지나야지 만날 수 있을거라고 생각했었나 봅니다. 그래서 '어라, 아직 이걸 만나면 안되는데, 어떡하지' 하는 생각이 먼저 들었습니다. 하지만 몰랐다면 유야무야 넘어갔을지라도, 이미 유니온 파인드를 써야한다는 걸 알았는데 그냥 넘어갈 순 없죠. 바로 유니온 파인드에 대해서 알아보기 시작했고 처음 눈에 들어왔던 포스팅이 너무 좋았기 때문에 읽자마자 바로 직접 코드를 짤 수 있을 정도였습니다. 이 자리를 빌어서 다시 한 번 감사드리며, 해당 포스팅의 링크를 걸어두겠습니다. https://4legs-study.tistory.com/94

 

분리 집합 (Disjoint Set) : Union-Find

서론 다음과 같은 메신저 프로그램이 있다. "A와 B가 친구 관계이고, 내가 A와 친구 관계이면 자동으로 나와 B는 친구 관계가 된다." 이 메신저 프로그램을 통해 내가 A라는 사람과 친구 관계를 맺

4legs-study.tistory.com

 

유니온-파인드(Union-Find)

유니온 파인드 알고리즘은 엄밀하게 말하면 알고리즘이라기 보다는 '분리 집합을 조작 및 관리하는 방법'이라고 할 수 있을 것 같습니다. 분리 집합(Disjoint Set) 또는 서로소 집합이라고 부르는 자료 구조는 상호 배타적인(서로 다른 집합에 공통 원소가 없는 = 교집합이 없는) 자료 구조를 말합니다. 이런 서로소 집합들에서 특정 원소가 어떤 집합에 포함되는지를 찾거나(Find), 서로 다른 집합을 합치는(Union) 연산을 사용하기 때문에 유니온 파인드 자료 구조라고 부르기도 합니다.

 

사용하는 연산만 봐도 알겠지만, 서로 다른 집합을 합치거나 특정 원소가 어떤 집합에 있는지 찾는 것에 이점이 있는 자료 구조입니다. 사실 저는 이 표현을 더 좋아하는데, 누가 어떤 그룹인지를 찾는 것에 강점이 있습니다. 서로소 집합이니까 공통원소가 없다면 다른 그룹이라는 뜻이 되고 공통원소가 있다면 같은 그룹이라는 뜻이 되니까요.

 

분리 집합의 구현

현재는 이미 효율적인 방법을 사용하고 있지만, 우리가 최초로 분리 집합이라는 자료구조를 만들었다고 치고 어떻게 구현할 것인지 한 번 생각해봅시다. 실제 최초의 분리 집합 자료구조는 어땠을지 모르겠지만 아주 나이브하게 진짜로 배열에 각 원소를 담아보죠. 예를 들어 {0, 1, 2}, {3, 4, 5}, {6, 7, 8} 이렇게 집합이 있다면,

a = [0, 1, 2]
b = [3, 4, 5]
c = [6, 7, 8]

뭐 이런 식으로 시작했을 것 같아요(그 때는 파이썬도 없었겠지만 그런 부분은 넘어가도록 합시다). 그런데 집합을 합치거나, 연결을 끊거나, 어떤 원소가 어디에 있는지 찾는게 상당히 비효율적일 것으로 보입니다. 지금이야 집합 3개에 원소도 3개씩이라 한 눈에 바로 알아볼 수 있지만, 원소가 각 몇 십 개, 집합도 몇 십 개만 되어도 눈으로 보고 골라낼 순 없으니 반복문을 잔뜩 써야 하는데 각 집합이 서로소일거란 보장도 없죠. 두 집합이 서로소인지 확인 하는 것에 각 원소의 갯수가 n, m개라고 하면 O(n*m)이 걸릴겁니다. 합치는 건 좀 나은가 싶지만 이게 무의식적으로 파이썬이라고 생각하니까 나아보이지, 배열 사이즈 변화가 안되는 C같은 언어라고 생각하면 합칠 때 마다 새로운 배열을 만들어줘야 하니 메모리도 신경써야할겁니다. 이래서는 안되겠죠.

 

그렇다면 진짜로 배열에 각 원소를 담지 말고, 추상적으로 담아봅시다. 각 원소에다가 얘 여기 담겼어요 하고 표시만 해주는거에요.

disjoint_set = [a, a, a, b, b, b, c, c, c]

가독성은 좀 떨어진 것 같습니다만 활용도 면에 비해서는 아까 보다 훨씬 낫죠. 임의의 원소가 어떤 집합에 속하는지 알고 싶으면 disjoint_set[원소]로 바로 찾을 수 있습니다. 집합을 합칠 때도 예를 들어 a와 b를 합치면 b라고 되어있는 친구들을 다 찾아서 a로 바꿔주기만 하면 됩니다. 빼줄 때는 조금 귀찮겠지만 현재 사용하지 않는 문자를 찾아서 빼주는 친구들의 값을 그 문자로 다시 바꿔주면 되겠죠. Find 연산은 O(1)이고 Union 연산은 O(모든 원소 갯수)의 시간복잡도를 갖겠네요. 뺄 때는 사용중인 문자열이 몇개냐에 따라 다르겠죠. 계속 새로운 문자를 부여해도 되겠지만 집합의 갯수가 몇십만개쯤 된다면 aaaaaaaa... 같은 값을 갖기도 하겠네요. 역시 아직 너무 효율적이다 라고는 하기 힘들 것 같아요.

 

그리고 이게 실제로 사용중인 각 집합을 트리 자료 구조로 만드는 방식입니다. 위의 방식과 비슷하지만 자신의 부모 노드를 값으로 저장하고 루트 노드를 집합의 고유 번호로 사용하는 방식이죠.

disjoint_set = [0, 0, 1, 3, 3, 4, 6, 6, 7]

임의로 바로 앞의 노드를 부모노드로 정했습니다(루트 노드는 부모 노드가 없으니 자기 자신). 엥, 아까보다 더 비효율적이지 않아요? Find 연산만 생각했을 때는 약간 그렇습니다. 특정 원소가 어떤 집합에 속하는지 알기 위해서 최악의 경우에는 트리의 깊이만큼(가장 깊은 곳에서 루트 노드까지) 연산을 해야하니까요. 그게 지금처럼 불균형한 트리라면 더 그렇겠죠. 하지만 합치거나 빼는 부분에서는 위의 방법보다 훨씬 좋은 성능을 보여줍니다.

 

0번 집합과 3번 집합을 한 번 합쳐볼까요. 합칠 때는 disjoint_set[3]의 값을 3에서 0으로 바꿔주면 됩니다. 빼줄 때도 다시 disjoint[3]의 값을 0에서 3으로 바꿔주면 끝입니다. 훨씬 간단해졌죠. 에이, 이건 루트 노드 끼리 붙였으니까 그런거잖아요. 맞습니다. 합치라는 명령이 루트 노드와 루트 노드가 아닌 경우를 생각해보죠.

 

위에서 말했던 대로 한다면, 간선 정보가 2번 노드와 5번 노드를 연결하라고 주어졌을 때 2번의 값을 5로 바꾸거나 5번의 값을 2로 바꿔야겠죠. 하지만 그렇게 할 경우 기존 집합의 연결성이 깨집니다. 무슨 소린가 하면 2번의 값을 5로 바꾼다면 0,1,2를 원소로 갖는 집합을 3,4,5 집합에 붙인 것이 아닌 2만 떼어 붙인 3,4,5,2 이렇게 만들게 되는 것이죠.

 

그래서 이렇게 루트노드가 아닌 노드끼리 합치라는 경우가 발생하면 한쪽의 루트노드를 붙이도록 합니다. 예를 들어 2번과 5번을 붙여야 할 경우 2번의 루트 노드인 0번을, 그러니까 집합 전체를 다 가져다 5번에 붙이는 것이고, disjoint_set[0]의 값을 0에서 5로 바꾸는 것입니다. 그렇게 하면 더 이상 0번 노드가 루트 노드가 아니게 되고 2번 노드의 루트 노드를 구하면 3이 나오면서 아 두 집합이 서로 합쳐졌구나 하고 생각할 수 있는 것이죠. 합치는게 리프 노드라고 해도 루트까지 트리의 깊이 만큼만 타고 올라가면 되니까 이전에 했던 합치는 연산이 O(N)이 걸리는 것 보다 훨씬 좋죠.

def find(node):
    if disjoint_set[node] == node:
        return node
    else:
        return find(node)
    
def union(node1, node2):
    root1, root2 = find(node1), find(node2)
    if root1 != root2:
        disjoint_set[root1] = node2

코드로 짜면 이렇게 될겁니다. find연산으로는 해당 노드가 속하는 그룹의 루트 노드를 찾고, union 연산은 각 노드의 루트 노드가 다를 경우, 즉 다른 그룹일 경우 한 쪽의 루트 노드를 다른 한 쪽의 노드 아래에 붙이는 거죠. 코드로 보니까 훨씬 간단하죠. 다만 이렇게 할 경우 합치면 합칠수록 트리의 깊이가 점점 깊어짐 = find연산의 시간이 오래걸린다는 단점이 있습니다. 

 

유니온 파인드의 최적화

그래서 위와 같은 단점을 해결하기 위한 최적화 방법이 존재합니다. 유니온 파인드의 최적화 방법에는 경로 압축(Path Compression)과 랭크 기반 결합(Union by Rank)이 있습니다.

 

경로 압축(Path Compression)

좀 극단적인 예시를 생각해보겠습니다. 0번부터 n번까지 노드를 union연산을 통해서 각각의 바로 앞의 번호 아래에 붙인다고 해볼게요. 그러면 disjoint_set = [0, 0, 1, 2, 3, ..., n-1]이 되겠죠. 그림으로 그려보면,

이런 모양이 되겠네요. 만약 이 때 저희가 각 노드가 앞 번호 노드에 붙었다는 사실을 모른다고 한 번 해볼게요. 그럼 어떤 그룹인지, 같은 그룹인지 알기 위해서는 find 연산을 써봐야겠죠. 비교하는 노드가 n과 n-2라고 해볼게요. n의 루트 노드를 찾으려면 0 노드까지 n번을 타고 올라가야겠죠. n-2노드도 0까지 n-2번을 타고 올라가야 합니다. 이건 너무 비효율적이죠.

 

그럼 어떻게하면 좋을까요? 어차피 find 연산은 루트 노드를 반환하는게 주목적인데 아예 모든 노드를 루트 노드 바로 아래에 붙여버리는 건 어떨까 라는 아이디어를 구현한게 바로 경로압축입니다. find 연산을 할 때 root 노드까지 쭉 타고 올라간 다음에 재귀를 탈출하면서 각 노드의 disjoint_set의 값을 루트 노드로 변경하는 겁니다. 코드를 작성해보면,

def find(node):
    if disjoint_set[node] == node:
        return node
    else:
        disjoint_set[node] = find(disjoint_set[node])
        return disjoint_set[node]

말로 설명을 하면 find연산의 파라미터로 넣어준 노드가 루트 노드가 아니라면,(루트 노드는 자기 자신을 값으로 저장하기로 했었죠) 루트 노드를 찾을 때까지 find 연산에 현재 노드의 부모 노드를 넣어서 재귀 호출합니다. 찾았다면 루트 노드를 반환하면서 재귀를 탈출합니다. 탈출하면서 반환값을 현재 노드의 부모노드로 저장하고 현재 노드의 부모 노드를 다시 반환하면서 재귀를 탈출합니다. 실행 과정도 간단하게 써보겠습니다. node는 3이라고 해보죠.

find(3)

>> find(disjoint_set[3]) # disjoint_set[3] = 2

>> find(disjoint_set[2])

>> find(disjoint_set[1])

>> find(disjoint_set[0]) # 루트노드와 입력노드가 일치

>> uf[0] = 0, return 0

>> uf[1] = 0, return disjoint_set[1]

>> uf[2] = 0, return disjoint_set[2]

>> uf[3] = 0

이렇게 진행되겠네요. 그렇다면 find(n)을 진행하면 모든 노드의 부모 노드가 0이 되겠죠. 그림으로 그려보면,

연결 상태는 이렇게 됩니다. 이제 find(n)을 해도 n번이 아니라 한 번만 타고 올라가면 바로 루트 노드에 도달할 수 있습니다.

 

랭크 기반 결합(Union by Rank)

이번 최적화 방법은 Union by Rank 입니다. 정식 명칭이 뭔지 찾아보려고 했는데 다들 Union by Rank라고만 적어놔서 제 나름대로 이해하기 편하도록 번역을 해봤습니다. 트리의 랭크(Rank)는 해당 트리의 자식 노드 중 최대 레벨을 나타낸 것입니다. 노드의 레벨(level)은 특정 깊이에 있는 노드의 집합을 나타냅니다. 노드의 깊이는 루트 노드에서 해당 노드에 도착할 때 까지 거쳐야 하는 간선의 수를 말합니다. 말이 많죠. 사실 다 개념이 조금씩 달라서 이게 이거다라고 말할 수는 없지만 말만 조금 다를 뿐 읽어보면 사실상 같은 값이 나올 수 밖에 없는 설명이라 그냥 그렇구나만 생각하시고 랭크는 트리의 최대 깊이와 같은 값이라고 생각하셔도 무방합니다.

 

그러면 랭크 기반 결합은 뭐냐. 최대한 트리의 랭크가 커지지 않도록 결합을 하는 테크닉이라고 생각하시면 됩니다. 위에서 트리의 깊이가 깊어지면 연산이 비효율적이게 된다고 했죠. 그래서 연산의 효율성을 위해 합쳤을 때 트리의 깊이가 깊어지지 않도록 하는 방법입니다. 그럼 어떻게 합쳐야 깊이가 깊어지지 않느냐, 그건 그림을 먼저 한 번 보는게 좋을 것 같아요.

어떤 말을 하려고 하는지 감이 벌써 오신 분도 있을 것 같은데요. 랭크가 큰 트리에 작은 트리를 붙이는게 결과적으로 랭크가 커지지 않도록 결합하는 방법입니다. 큰 트리에 붙이면 랭크가 바뀌지 않기 때문에 따로 랭크를 키워주는 조작도 필요없어요. 어, 그러면 랭크가 같을 때는 어떻게 하죠? 그것도 그림으로 볼까요.

이렇게 랭크가 같은 트리를 합칠 경우에는 그냥 아무 한쪽 트리에 붙인 다음 랭크를 1 상승시켜주기만 해도 됩니다. 이렇게 결합을 하면 트리가 너무 불균형해지는 것도 어느정도 방지할 수 있겠죠. 이번엔 코드로 확인해봅시다.

def union(node1, node2):
    root1, root2 = find(node1), find(node2)
    if root1 != root2:
        if rank[root1] < rank[root2]:
            disjoint_set[root1] = root2
        elif rank[root1] == rank[root2]:
            disjoint_set[root1] = root2
            rank[root2] += 1
        else:
            disjoint_set[root2] = root1

각 루트 노드를 찾아서 서로 다를 때, 랭크가 작은 쪽을 큰 쪽에 붙이고, 랭크가 같다면 아무거나 붙인 다음 붙인 쪽의 랭크를 1 키워줍니다.

 

개념도 별로 어렵지 않고, 코드도 간단하죠? 왜 그렇게 친구가 편하다고 써보라고 추천을 했는지 알 것 같았습니다. 사실 유니온 파인드가 별로 어렵지 않다는 걸 알았을 땐 데이크스트라를 이해했을 때 보다 훨씬 기뻤던 것 같아요. 물론 그 기쁨이 오롯이 유니온 파인드를 이해했다는 점 때문은 아니었고 도전 중인 문제에서 연결 요소를 구하는 과정에 유니온 파인드를 접목시켜볼 수 있을 것 같다는 생각이 들어서 그랬던 게 절반 정도 였지만요(아쉽게도 유니온 파인드를 써서 예제는 풀었지만 시간초과로 문제 통과는 못했습니다).

 

다만 확실히 데이크스트라에 비해 고생을 덜해서 그런가 포스팅이 훨씬 짧습니다. 아마 이 모든 건 처음에 봤던 자료가 너무 훌륭했기 때문이라고 생각이 드네요. 꼭 상단에 제가 첨부한 링크를 들어가보셨으면 좋겠습니다. 아마 제 글보다 훨씬 이해가 잘 되실거에요. 

반응형