세그먼트 트리

세그먼트 트리는 트리형태의 자료구조로 해당 노드에는 구간합이 저장되어 있다.

일반적으로 배열에서 특정 구간의 합을 구해야 한다면 O(N)이 걸린다.

세그먼트 트리는 노드마다 구간의 합이 저장되어 있기때문에 해당 노드에 접근하여 반환하면 되므로 O(logN)이 걸리게 된다.

이처럼 세그먼트 트리는 여러개의 데이터가 연속적으로 있을 때, 특정 구간의 합을 구하는데 효율적인 자료구조이다.

 

세그먼트 트리만들어보기

배열 A = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]이 주어졌다. 편의상 인덱스는 1부터 시작한다.

1 2 3 4 5 6 7 8 9 10

이를 트리로 만들어보자.

위 그림은 배열 A를 트리로 만들어본 형태다.

왼쪽노드는 index*2, 오른쪽노드는 index*2+1이 되고, 트리의 특성상 합을 구하면 O(logN)이 걸리게 된다.

구간힙 트리로 만들어보자.

맨 위 최상단노드는 모든 원소의 합이 들어간다.

그다음 두 번째 세번째 노드를 구해주는데 두번째노드는 1번부터 5번까지의 구간합, 세 번째 노드는 6번부터 10번까지의 구간합이 들어간다.

그 다음 자식노드도 구해보자. 

두번째 노드는 1~5의 구간합이고, 세번째노드는 6~10의 구간합이다.

중간인덱스를 구해 두 번째 노드의 왼쪽자식노드는 1~3, 오른쪽자식노드는 4~5

세 번째 노드의 왼쪽자식노드는 6~8, 오른쪽자식노드는 9~10 이 된다.

이렇게 마지막 노드까지 구하면 구간합 트리는 아래와 같다.

 

이전의 배열의 크기는 10이었고 트리로 만들어도 크기는 달라지지 않았지만, 구간힙 트리로 만들면서 크기가 변경되었다.

이처럼 배열을 구간힙트리로 만들려면 배열의 크기보다 크면서 가장 인접한 제곱수 * 2만큼의 크기를 할당해야 한다.

예시의 배열크기는 10이어서 10보다 큰 인접한 제곱수 4^2 = 16이고, 16*2 = 32이다. 최소 32의 크기를 할당해야 하지만, 더 넉넉하게 주려면 배열의 크기보다 4배 크면 된다.(최소로만 크기를 할당해도 상관은 없다.)

 

※종합

1. 연속된 데이터의 특정범위를 빠르게 구하기 위해 세그먼트 트리를 작성

2. 세그먼트 트리의 크기는 적어도 주어진 데이터크기보다 크며 가장 인접한 제곱수의 2배만큼 할당(넉넉히 4배만큼 줘도 상관은 없다)

3. 세그먼트 트리의 각 노드는 특정 구간의 합이 저장되어 있어 선형탐색(O(N))보다 효율적(O(logN))

 

이를 코드로 구현해 보자.

 

1. 세그먼트 트리 만들기

// 연속된 데이터 예시
// 트리의 특성(왼쪽자식인덱스:index*2, 오른쪽자식인덱스:index*2+1)를 이용하기위해 배열 첫번째는 임의의데이터를 넣는다.
var arr = Array(0...10)
// 크기 할당
var segTree = Array(repeating: 0, count: arr.count*4)

// 세그먼트 트리 만들기 (재귀호출)
// 구간합을 구해서 더한값이 부모노드가 됨
// 재귀호출로 맨 아래의 노드(start=end)까지 도달 후 더해서 부모노드를 채워가는 형식으로 진행됨
// start: 배열arr의 첫번째인덱스, end: 배열arr의 마지막인덱스
// index: 세그먼트트리의 인덱스
func segmentTreeInit(_ start: Int, _ end: Int, _ index: Int) -> Int{
	// 구간의 start == end는 해당 위치의 원소를 의미
    // 배열 arr의 3부터3까지의 구간합은 결국 arr[3]이다.
    if start == end {
        segTree[index] = arr[start]
        return segTree[index]
    }
    // 구간을 둘로 나누어 트리 특성상 왼쪽노드의 위치는 index*2, 오른쪽노드는 index*2+1
    var mid = (start + end) / 2
    // 재귀호출
    // 해당 구간합은 그 위치의 왼쪽 자식노드와 오른쪽자식노드의 합이다.
    segTree[index] = segmentTreeInit(start, mid, index*2) + segmentTreeInit(mid+1, end, index*2+1)
    return segTree[index]
}

 

결과:

segmentTreeInit(1, arr.count-1, 1)
print(segTree)
print(segmentTreeInit(1, arr.count-1, 1))

// [0, 55, 15, 40, 6, 9, 21, 19, 3, 3, 4, 5, 13, 8, 9, 10, 1, 2, 0, 0, 0, 0, 0, 0, 6, 7, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
// 55

 

 

2. 특정 구간의 합 구하기

세그먼트 트리를 만드는 이유가 되기도 하는 특정구간의 합을 구해보자.

특정 구간의 합을 구하는 과정은 세그먼트트리의 최상단 노드부터 시작하며,

해당노드가 구간 안에 존재한다면, 모두 더해준다.

해당 구간을 찾아가는 방법은 재귀호출로 찾아갈 것이다.

예를 들어 4~6구간의 합을 구한다고 하자.

파랑노드를 찾아가는 조건은 재귀호출마다 해당 노드의 구간이 구하려는 범위 내에 존재하면 그 노드로 탐색하는 것이다.

- [4~6] 구간을 구할 때, 왼쪽은 [1~5] 오른쪽은 [6~10]이므로 범위 내에 존재한다. 진입

- [1~5]에서 왼쪽 [1~3]은 범위내에 존재하지 않는다 0

- [1~5]에서 오른쪽 [4~5]는 범위내에 존재한다. 해당노드를 더해줌 9

- [6~10]에서 오른쪽노드는 범위 내에 존재하지않고, 왼쪽노드는 존재한다. 왼쪽노드 진입

- [6~8]에서 왼쪽노드가 범위내에 존재 왼쪽노드진입

- [6~7]에서 왼쪽노드가 범위내에 존재 왼쪽노드 진입

- [6]은 범위내에 존재. 더해준다.

-나머지 범위들은 모두 범위밖에 구간합이므로 0을 더해준다.

 

코드로 구현하면 다음과 같다.

// 특정 구간합 구하기
// start: 배열arr의 시작인덱스, end: 배열arr의 마지막인덱스
// index: 세그먼트 트리의 인덱스(재귀호출을위해 1로 입력)
// left: 구하려는 범위의 시작인덱스, right: 구하려는범위의 마지막인덱스
func segmentTreeSum(_ start: Int,_ end: Int,_ index: Int,_ left: Int,_ right: Int) -> Int{
	// left...right 범위를 벗어났다면 0리턴
    if left > end || right < start { return 0 }
    // left...right 범위내에 있는 구간합이라면 출력
    if left <= start && right >= end { return segTree[index] }
    // 자식노드로 재귀호출
    // 범위내의 구간합을 찾는다.
    let mid = (start + end) / 2
    return segmentTreeSum(start, mid, index*2, left, right) + segmentTreeSum(mid+1, end, index*2+1, left, right)
}

 

결과:

print(segmentTreeSum(1, arr.count-1, 1, 4, 6))
// 15

 

3. 특정 원소의 값 수정하기

세그먼트 트리는 구간합이 저장되어 있으므로, 특정원소의 값을 수정한다면 그 원소의 구간합 노드들을 전부 수정해 주면 된다.

세그먼트 트리의 특정원소를 수정할 때에는 새로운 값으로 초기화하는 방법이 아닌 증가, 감소하는 방향으로 구한다.

만약 아예 새로운 값으로 초기화한다면 새로 변경된 배열을 토대로 세그먼트 트리를 다시 만들어주면 된다.

 

예를 들어 배열 세 번째의 값을 6으로 만든다면

위 그림과 같이 구간합노드 중 세번째 값이 해당범위에 존재하면 모두 바꿔준다.

재귀호출로 특정원소까지 갈 때마다 증가 혹은 감소되는 값을 더해주며 내려간다.

 

코드로 구현하면 다음과 같다.

// 특정원소의 값 수정(증가 혹은 감소)
// start: 배열arr의 시작인덱스, end: 배열arr의 마지막인덱스
// index: 세그먼트트리의 인덱스
// node: 변경하려는 원소의 위치
// newValue: 변경값(증가 혹은 감소)
func segmentUpdate(_ start: Int,_ end: Int,_ index: Int,_ node: Int,_ newValue: Int) {
    // 수정하려는 값의 위치가 범위를 벗어나면 종료
    if node > end || node < start { return }
    // 수정하려는 값의 위치가 범위내에 존재하면 해당 구간합도 수정해준다.
    // 이때 수정하려는값을 더해준다.
    segTree[index] += newValue
    // 해당위치에 도달했다면 종료
    // 예: 배열arr의 3번째 원소를 수정한다면 3이 포함된 구간합을 모두 수정 후, 구간합 3~3의 노드까지 도달하면 종료
    if start == end { return }
    let mid = (start + end) / 2
    // 세그먼트트리의 왼쪽자식노드와 오른쪽자식노드 모두 접근
    segmentUpdate(start, mid, index*2, node, newValue)
    segmentUpdate(mid+1, end, index*2+1, node, newValue)
}

결과:

// 3번째 위치한 원소의 값을 3 증가시킨다. (arr[3]을 6으로 변경)
segmentUpdate(1, arr.count-1, 1, 3, 3)
print(segTree)
//     1   2       4                9 -> 변경된 값
// [0, 58, 18, 40, 9, 9, 21, 19, 3, 6, 4, 5, 13, 8, 9, 10, 1, 2, 0, 0, 0, 0, 0, 0, 6, 7, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]

 

이렇게 세그먼트리를 이용한다면 특정 구간의 합을 구하거나 수정할 때, 선형탐색 O(N) 보다 더 효율적이다 O(logN)

 

전체코드:

import Foundation

// 연속된 데이터 예시
// 트리의 특성(왼쪽자식인덱스:index*2, 오른쪽자식인덱스:index*2+1)를 이용하기위해 배열 첫번째는 임의의데이터를 넣는다.
var arr = Array(0...10)
// 크기 할당
var segTree = Array(repeating: 0, count: arr.count*4)

// 세그먼트 트리 만들기 (재귀호출)
// 구간합을 구해서 더한값이 부모노드가 됨
// 재귀호출로 맨 아래의 노드(start=end)까지 도달 후 더해서 부모노드를 채워가는 형식으로 진행됨
// start: 배열arr의 첫번째인덱스, end: 배열arr의 마지막인덱스
// index: 세그먼트트리의 인덱스
func segmentTreeInit(_ start: Int, _ end: Int, _ index: Int) -> Int{
    // 구간의 start == end는 해당 위치의 원소를 의미
    // 배열 arr의 3부터3까지의 구간합은 결국 arr[3]이다.
    if start == end {
        segTree[index] = arr[start]
        return segTree[index]
    }
    // 구간을 둘로 나누어 트리 특성상 왼쪽노드의 위치는 index*2, 오른쪽노드는 index*2+1
    var mid = (start + end) / 2
    // 재귀호출
    // 해당 구간합은 그 위치의 왼쪽 자식노드와 오른쪽자식노드의 합이다.
    segTree[index] = segmentTreeInit(start, mid, index*2) + segmentTreeInit(mid+1, end, index*2+1)
    return segTree[index]
}

// 특정 구간합 구하기
// start: 배열arr의 시작인덱스, end: 배열arr의 마지막인덱스
// index: 세그먼트 트리의 인덱스(재귀호출을위해 1로 입력)
// left: 구하려는 범위의 시작인덱스, right: 구하려는범위의 마지막인덱스
func segmentTreeSum(_ start: Int,_ end: Int,_ index: Int,_ left: Int,_ right: Int) -> Int{
    // left...right 범위를 벗어났다면 0리턴
    if left > end || right < start { return 0 }
    // left...right 범위내에 있는 구간합이라면 출력
    if left <= start && right >= end { return segTree[index] }
    // 자식노드로 재귀호출
    // 범위내의 구간합을 찾는다.
    let mid = (start + end) / 2
    return segmentTreeSum(start, mid, index*2, left, right) + segmentTreeSum(mid+1, end, index*2+1, left, right)
}

// 특정원소의 값 수정(증가 혹은 감소)
// start: 배열arr의 시작인덱스, end: 배열arr의 마지막인덱스
// index: 세그먼트트리의 인덱스
// node: 변경하려는 원소의 위치
// newValue: 변경값(증가 혹은 감소)
func segmentUpdate(_ start: Int,_ end: Int,_ index: Int,_ node: Int,_ newValue: Int) {
    // 수정하려는 값의 위치가 범위를 벗어나면 종료
    if node > end || node < start { return }
    // 수정하려는 값의 위치가 범위내에 존재하면 해당 구간합도 수정해준다.
    // 이때 수정하려는값을 더해준다.
    segTree[index] += newValue
    // 해당위치에 도달했다면 종료
    // 예: 배열arr의 3번째 원소를 수정한다면 3이 포함된 구간합을 모두 수정 후, 구간합 3~3의 노드까지 도달하면 종료
    if start == end { return }
    let mid = (start + end) / 2
    // 세그먼트트리의 왼쪽자식노드와 오른쪽자식노드 모두 접근
    segmentUpdate(start, mid, index*2, node, newValue)
    segmentUpdate(mid+1, end, index*2+1, node, newValue)
}

    
segmentTreeInit(1, arr.count-1, 1)
print(segmentTreeSum(1, arr.count-1, 1, 4, 6))
segmentUpdate(1, arr.count-1, 1, 3, 3)
print(segTree)

 

 

Reference:

https://m.blog.naver.com/ndb796/221282210534

 

41. 세그먼트 트리(Segment Tree)

이번 시간에 다룰 내용은 여러 개의 데이터가 연속적으로 존재할 때 특정한 범위의 데이터의 합을 구하는 ...

blog.naver.com

 

'Swift > 자료구조' 카테고리의 다른 글

트라이(Trie) Swift  (0) 2024.04.30
분리집합과 Union-Find Swift  (1) 2024.04.28
Swift 덱(Deque) 구현  (0) 2024.04.03
Swift 큐(Queue) 구현  (0) 2024.04.03
Swift 우선순위 큐(Priority Queue) 구현  (0) 2023.05.17

+ Recent posts