성장일기

내가 보려고 정리하는 공부기록

자료구조

[Data Structure] Segment Tree (세그먼트 트리) | 자바

와나나나 2025. 1. 20. 18:34
728x90

백준 문제를 풀다가 정리하지 않은 자료구조가 나와서 정리하고자 글을 쓴다 ! 이번 주제는 세그먼트 트리이다.

순서는 아래와 같다.

 

  1. 세그먼트 트리란?
  2. 구현하기 (구간합을 예시로)
    1. 생성
    2. 데이터를 수정하게 된다면?

 

* 세그먼트 트리의 개념은 아래를 참고하였습니다 *

https://cano721.tistory.com/38

 

[알고리즘 개념] 세그먼트 트리(Segment Tree) / Java

세그먼트 트리란 특정 구간 내 데이터에 대한 연산(쿼리)을 빠르게 구할 수 있는 트리. ex) 특정 구간 합,최소값,최대값,평균값 등등 Segment : 부분.분할.나누다.분할하다. 시간복잡도 데이터 변경:

cano721.tistory.com

https://book.acmicpc.net/ds/segment-tree

 


1. 세그먼트 트리?

일반적으로 트리는 부모노드와 자식 노드에 데이터 자체를 저장한다. 하지만 세그먼트 트리는 리프노드 여부에 따라 다른 의미를 갖는 트리이다.

  • 리프노드 - 배열의 수 그 자체
  • 리프노드가 아닌 노드 - 왼쪽 자식노드와 오른쪽 자식노드의 무언가 (누적합을 구하는 경우엔 왼쪽과 오른쪽 자식노드의 합이 저장되고, 최대값을 구하는 경우에는 왼쪽과 오른쪽 자식노드 중 더 큰 값이 저장됨)

세그먼트 트리

 

시간복잡도는 O(log N) 을 갖기 때문에 효율적이다!

 


2. 구현하기

✅ 생성하기

트리의 값은 보통 배열에 저장한다. 이진트리인 것을 이용하면 루트노드의 인덱스는 1부터 시작하여

왼쪽은 루트노드idx * 2, 오른쪽은 루트노드idx * 2 + 1이 된다.

 

배열에 저장하려면 미리 배열의 크기를 지정해야 하고, 이 과정에서 트리의 높이를 알아야 한다.

 

세그먼트 트리의 높이

만약 리프노드의 개수가 8이라면, 세그먼트 트리를 생성했을 때 아래와 같은 모양이 된다.

 

노드의 개수가 8 -> 4 -> 2 -> 1 이 되고 높이는 3이라는 것을 알 수 있다. 즉, 원소의 개수를 N이라 하면 높이는

밑이 2인 로그 N 임을 알 수 있다. N이 9이면 높이가 1 증가하는 것으로 보면, 로그 결과를 올림해주면 높이가 나온다.

 

자바에서 Math.log() 는 밑이 e인 자연로그이기 때문에, 아래와 같이 써주면 된다.

int h = Math.ceil(Math.log(N) / Math.log(2));

 

높이를 구하고 나면, 세그먼트 트리의 배열 크기는 단순하게 구해진다.

int[] segmentTree = new int[Math.pow(2, h + 1)];

 

 

세그먼트 트리 생성

이제 재귀함수를 통해 배열을 채워나간다.

현재 배열의 시작과 끝 인덱스를 파라미터로 받아서 구현한다. 누적합을 구하는 경우를 예로 들어 코드를 작성하려 한다!

// tree 배열 : 세그먼트 트리 배열
// input 배열 : 리프노드 입력받은 배열

int[] tree = new int[N];
public long make(int[] input, int now, int start,int end){
            
            // 배열의 시작과 끝이 같다면 리프 노드 -> 리프노드 값 담음
            if(start == end){
                return tree[now] = arr[start];
            }
			
            // 리프 노드가 아니면 -> 자식노드 합 담기
            return tree[now] =
            init(arr, now * 2, start, (start + end) / 2)
            + init(arr, now * 2 + 1, (start + end) / 2 + 1, end);
}

 

 

데이터 수정시 트리 수정하기

특정 데이터가 수정되었다면 기존 값과 특정 데이터 값의 차이를 파라미터로 넘겨서 관련 노드들의 값을 수정한다.

// node: 현재노드idx, start: 배열의 시작, end:배열의 끝, change: 변경된 데이터의 idx, diff: 원래 데이터 값과 변경 데이터값의 차이
public void update(int now, int start, int end, int change, long diff) {

    // 변경할 idx 범위 체크 -> 해당 범위에 없으면 수정 안 함
    if(change < start || end < change) return;

    // 차이 적용
    tree[now] += diff;

    // 리프노드가 아니면 아래 자식들도 확인
    if(start != end){
        update(now * 2, start, (start + end) / 2, change, diff);
        update(now * 2 + 1, (start + end) / 2 + 1, end, change, diff);
    }
}

 

 

누적합 구하기 (변형하면 구간의 최대 최소 구하기도 가능)

// start : 현재 배열의 시작 end : 현재 배열의 끝, left : 원하는 구간의 시작, right: 원하는 구간의 끝

public long sum(int now, int start, int end, int left, int right){
    // 범위체크
    if(left > end || right < start){
        return 0;
    }

    // 범위 내 완전히 포함되면 바로 리턴
    if(left <= start && end <= right){
        return tree[now];
    }

    // 그 외의 경우 좌우 트리 탐색
    return sum( now * 2, start, (start + end) / 2, left, right)
    + sum(node * 2 + 1, (start + end) / 2 + 1, end, left, right);
}

 

재귀를 이용하고, 범위만 잘 확인해주면 어렵지 않게 풀 수 있을 거 같다!