성장일기

내가 보려고 정리하는 공부기록

코딩테스트/백준 골드

[백준] 1197: 최소 스패닝 트리 (MST, 크루스칼 알고리즘) - JAVA

와나나나 2024. 10. 24. 15:24
728x90

class 5

https://www.acmicpc.net/problem/1197

 

 

# 문제

그래프가 주어졌을 때, 그 그래프의 최소 스패닝 트리를 구하는 프로그램을 작성하시오.

최소 스패닝 트리는, 주어진 그래프의 모든 정점들을 연결하는 부분 그래프 중에서 그 가중치의 합이 최소인 트리를 말한다.

 

# 예제

입력 : 첫째 줄에 정점의 개수 V(1 ≤ V ≤ 10,000)와 간선의 개수 E(1 ≤ E ≤ 100,000)가 주어진다. 다음 E개의 줄에는 각 간선에 대한 정보를 나타내는 세 정수 A, B, C가 주어진다. 이는 A번 정점과 B번 정점이 가중치 C인 간선으로 연결되어 있다는 의미이다. C는 음수일 수도 있으며, 절댓값이 1,000,000을 넘지 않는다.

그래프의 정점은 1번부터 V번까지 번호가 매겨져 있고, 임의의 두 정점 사이에 경로가 있다. 최소 스패닝 트리의 가중치가 -2,147,483,648보다 크거나 같고, 2,147,483,647보다 작거나 같은 데이터만 입력으로 주어진다.

3 3
1 2 1
2 3 2
1 3 3

 

 

출력

3

 

# 필요개념

스패닝 트리라는 이름을 보자마자 익숙하다는 생각이 들어 찾아보니 예전에 공부한 적이 있는 트리였다. 다시 한 번 정리해보면 아래와 같이 설명할 수 있다.

 

스패닝 트리는 모든 노드를 거쳐가는 트리이지만 순환하지 않는 트리를 의미한다. 그렇기에 노드의 개수가 N이라면, 간선의 개수는 이보다 1개 작은 N - 1 이라고 정의할 수 있게 된다. 

 

또 스패닝 트리는 가중치를 갖는데, 이 가중치가 최소가 되도록 만들어진 트리가 바로 최소 스패닝 트리 (MST, Minimum Spanning Tree) 이다.

 

스패닝 트리를 구하기 위해서는 프림 알고리즘, 크루스칼 알고리즘 등을 사용한다.

 

✅ 프림 알고리즘

프림 알고리즘은 임의의 시작 정점에서 가장 가까운 정점을 하나씩 추가하여 최소신장트리(MST)를 만들어가는 알고리즘이다. 자세한 것은 아래 링크에 적어두었다.

https://wanna-developer02.tistory.com/51

 

[알고리즘] CHAP 4. The Greedy Approach (탐욕알고리즘) - Prim's alg (프림알고리즘)

이번 챕터에서는 탐욕알고리즘과 이를 이용하는 프림, 크루스칼, 다익스트라 알고리즘을 정리해보려고 한다. 한 게시물에 정리하기에는 내용이 너무 많아서 나눠서 정리할 예정이다. 목차는 다

wanna-developer02.tistory.com

 

✅ 크루스칼 알고리즘

이번 문제에서 사용한 알고리즘이다. MST는 결국 가중치를 최소로 만들어야 하기 때문에 우선 가중치를 오름차순 정렬하고 순환하지 않는 선에서 간선을 추가해가는 방식이다. 

 

이 알고리즘에서 가장 중요한 것은 순환 여부를 어떻게 확인하느냐 이다. 이를 위해 배열을 한 개 만들어 해당 노드의 루트를 담는다. 그 후 연결되는 노드의 루트를 통합(union)해준다.

 

위 작업을 위해 두가지 함수를 만들었다. 

 

1. 루트노드를 구하는 함수 - find()

기본적으로 parent[idx] 에는 idx가 들어있다. 하지만 값이 다르다면 무언가 작업이 이루어졌다는 뜻이고, 이를 재귀함수를 통해 구해낸다.

 

만약 두 노드의 루트가 다르다면? 순환그래프가 되지 않을 것이기 때문에 MST에 해당 노드를 추가하고, 루트를 통합해준다. 

 

2. 루트를 통일해주는 함수 - union()

임의로 부모노드의 루트로 통합해주었다. 이때 루트를 구하는 과정에서 앞서 구현한 find 함수를 사용한다.

 

3. 크루스칼 알고리즘 함수 - kruskal()

위 함수들을 이용해 최종 함수를 만들었다.

private static void kruskal() {
        int idx = 0;
        while (true) {
            if (edges == N - 1) break;
            int parent = inputs[idx][0];
            int child = inputs[idx][1];
            int weight = inputs[idx][2];

            if (find(parent) != find(child)) {
                union(parent, child);
                mst_sum += weight;
                edges++;
            }
            idx++;
        }
    }

 

MST의 특성상 간선의 개수는 노드의 개수 - 1 이므로 이를 종료조건으로 잡았다.

# Code

import java.io.*;
import java.util.*;

public class Main {
    static int[][] inputs;
    static int edges = 0;
    static int mst_sum = 0;
    static int N;
    static int[] parent;
    public static void main(String[] args) throws IOException {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        StringTokenizer st = new StringTokenizer(br.readLine());
        N = Integer.parseInt(st.nextToken());
        int E = Integer.parseInt(st.nextToken());
        inputs = new int[E][3];
        parent = new int[N + 1];

        for (int i = 1 ; i <= N ; i++) {
            parent[i] = i;
        };

        for (int i = 0 ; i < E ; i++) {
            st = new StringTokenizer(br.readLine());
            int v1 = Integer.parseInt(st.nextToken());
            int v2 = Integer.parseInt(st.nextToken());
            int weight = Integer.parseInt(st.nextToken());
            inputs[i][0] = v1;
            inputs[i][1] = v2;
            inputs[i][2] = weight;
        }
        Arrays.sort(inputs, new Comparator<int[]>() {
            @Override
            public int compare(int[] o1, int[] o2) {
                return o1[2] - o2[2];
            }
        });

        kruskal();
        System.out.println(mst_sum);
    }

    private static void kruskal() {
        int idx = 0;
        while (true) {
            if (edges == N - 1) break;
            int parent = inputs[idx][0];
            int child = inputs[idx][1];
            int weight = inputs[idx][2];

            if (find(parent) != find(child)) {
                union(parent, child);
                mst_sum += weight;
                edges++;
            }
            idx++;
        }
    }

    private static void union(int v1, int v2) {
        int root1 = find(v1);
        int root2 = find(v2);
        parent[root2] = root1;
    }

    private static int find(int node) {
        if (node != parent[node]) {
            parent[node] = find(parent[node]);
        }
        return parent[node];
    }
}

 

 

# 결과