728x90
세그먼트 트리(Segment Tree)란
세그먼트 트리(Segment Tree)는 효율적으로 배열 또는 리스트와 같은 순차적인 데이터 구조에서 구간 쿼리(구간 검색 또는 구간 연산)를 수행하기 위한 자료구조입니다. 주로 구간 합, 구간 최솟값, 구간 최댓 값 등을 빠르게 계산하는데 사용됩니다.
세그먼트 트리는 트리 구조로 표현되며, 각 노드는 배열의 일부 구간에 대한 정보를 저장합니다. 일반적으로 이진 트리 형태를 가지며, 아래와 같은 구성요소를 가집니다.
- 루트 노드 : 배열 전체 구간에 대한 정보를 저장합니다.
- 각 내부 노드 : 두 자식 노드의 구간 정보를 합친 결과를 저장합니다. 이를 통해 트리를 아래로 내려가면서 구간 정보를 계산할 수 있습니다.
- 리프 노드 : 배열의 개별 원소를 나타냅니다.
세그먼트 트리의 구성은 일반적으로 재귀적으로 정의되며, 구간 쿼리를 수행할 때 효율적으로 작동합니다. 주어진 구간의 합, 최솟값, 최댓값 등을 빠르게 계산하려면 세그먼트 트리를 미리 구축해야합니다. 그런 다음, 구간 쿼리가 필요할 때 해당 구간에 해당하는 노드를 찾아서 필요한 정보를 가져오는 방식으로 작동합니다.
구현
구간합을 구하는 세그먼트 트리 예제
public SegmentTree(int[] nums) {
this.nums = nums;
int n = nums.length;
// 세그먼트 트리의 크기는 주어진 배열 크기의 4배 이상이어야 합니다.
tree = new int[4 * n];
build(0, 0, n - 1);
}
SegmentTree(int[] nums)
- 생성자 함수로, 세그먼트 트리를 초기화합니다.
- nums 배열을 입력으로 받아 세그먼트 트리를 구축합니다.
- 세그먼트 트리를 구축하기 위해 build 함수를 호출합니다.
private void build(int node, int start, int end) {
if (start == end) {
tree[node] = nums[start];
} else {
int mid = (start + end) / 2;
int leftNode = 2 * node + 1;
int rightNode = 2 * node + 2;
build(leftNode, start, mid);
build(rightNode, mid + 1, end);
tree[node] = tree[leftNode] + tree[rightNode]; // 구간 합을 저장
}
}
build(int node, int start, int end)
- 세그먼트 트리를 구축하는 재귀함수입니다.
- node는 현재 노드의 인덱스를 나타냅니다.
- start와 end는 현재 노드가 나타내는 구간을 나타냅니다.
- 재귀적으로 왼쪽 자식 노드와 오른쪽 자식 노드를 구축하고, 해당 구간의 합을 현재 노드에 저장합니다.
public void update(int index, int val) {
update(0, 0, nums.length - 1, index, val);
}
private void update(int node, int start, int end, int index, int val) {
if (start == end) {
nums[index] = val;
tree[node] = val;
} else {
int mid = (start + end) / 2;
int leftNode = 2 * node + 1;
int rightNode = 2 * node + 2;
if (index <= mid) {
update(leftNode, start, mid, index, val);
} else {
update(rightNode, mid + 1, end, index, val);
}
tree[node] = tree[leftNode] + tree[rightNode]; // 구간 합 업데이트
}
}
update(int index, int val)
- 배열의 특정 인덱스를 업데이트하는 함수입니다.
- index는 업데이트할 인데스를 나타냅니다.
- val은 해당 인게스의 값을 업데이트할 값으로 설정합니다.
- update 함수를 호출하면 해당 인덱스의 값을 업데이트하고, 이에 업데이트 됩니다.
public int query(int left, int right) {
return query(0, 0, nums.length - 1, left, right);
}
private int query(int node, int start, int end, int left, int right) {
if (right < start || left > end) {
return 0; // 범위가 겹치지 않으면 0 반환
} else if (left <= start && right >= end) {
return tree[node]; // 현재 노드가 범위 내에 있으면 노드 값 반환
} else {
int mid = (start + end) / 2;
int leftNode = 2 * node + 1;
int rightNode = 2 * node + 2;
int leftSum = query(leftNode, start, mid, left, right);
int rightSum = query(rightNode, mid + 1, end, left, right);
return leftSum + rightSum; // 왼쪽 서브트리와 오른쪽 서브트리의 결과를 합산하여 반환
}
}
query(int left, int right)
- 구간 합을 계산하는 함수입니다.
- left와 right는 구간의 시작과 끝을 나타냅니다.
- query 함수를 호출하면 해당 구간의 합을 반환합니다.
- 재귀적으로 세그먼트 트리를 탐색하면서 구간을 나누어 합을 계산합니다.
- 효율적인 구간 합 계산을 위해 세그먼트 트리를 사용합니다.
전체 코드
class SegmentTree {
private int[] tree;
private int[] nums;
public SegmentTree(int[] nums) {
this.nums = nums;
int n = nums.length;
// 세그먼트 트리의 크기는 주어진 배열 크기의 4배 이상이어야 합니다.
tree = new int[4 * n];
build(0, 0, n - 1);
}
private void build(int node, int start, int end) {
if (start == end) {
tree[node] = nums[start];
} else {
int mid = (start + end) / 2;
int leftNode = 2 * node + 1;
int rightNode = 2 * node + 2;
build(leftNode, start, mid);
build(rightNode, mid + 1, end);
tree[node] = tree[leftNode] + tree[rightNode]; // 구간 합을 저장
}
}
public void update(int index, int val) {
update(0, 0, nums.length - 1, index, val);
}
private void update(int node, int start, int end, int index, int val) {
if (start == end) {
nums[index] = val;
tree[node] = val;
} else {
int mid = (start + end) / 2;
int leftNode = 2 * node + 1;
int rightNode = 2 * node + 2;
if (index <= mid) {
update(leftNode, start, mid, index, val);
} else {
update(rightNode, mid + 1, end, index, val);
}
tree[node] = tree[leftNode] + tree[rightNode]; // 구간 합 업데이트
}
}
public int query(int left, int right) {
return query(0, 0, nums.length - 1, left, right);
}
private int query(int node, int start, int end, int left, int right) {
if (right < start || left > end) {
return 0; // 범위가 겹치지 않으면 0 반환
} else if (left <= start && right >= end) {
return tree[node]; // 현재 노드가 범위 내에 있으면 노드 값 반환
} else {
int mid = (start + end) / 2;
int leftNode = 2 * node + 1;
int rightNode = 2 * node + 2;
int leftSum = query(leftNode, start, mid, left, right);
int rightSum = query(rightNode, mid + 1, end, left, right);
return leftSum + rightSum; // 왼쪽 서브트리와 오른쪽 서브트리의 결과를 합산하여 반환
}
}
}
참고
https://codingnojam.tistory.com/49
https://yoongrammer.tistory.com/103
728x90
'Computer Science > Algorithm' 카테고리의 다른 글
[Algorithm] 비트마스크 (BitMask) 알고리즘에 대해서 공부하자. (0) | 2023.09.30 |
---|---|
[Algorithm] 에라토스테네스의 체, 소수찾기 feat. Java (0) | 2023.08.16 |
[Algorithm] 최소공배수, 최대공약수와 유클리드 알고리즘 (GCD, LCM) (0) | 2023.08.16 |
[Algorithm] LIS, LDS (최장 증가 부분 수열, 최장 감소 부분 수열) (0) | 2023.07.31 |
[Algorithm] 동적 계획법(Dynamic Programming) (0) | 2023.07.30 |