코딩테스트

[백준][세그먼트트리] 최솟값과 최댓값

pythaac 2022. 3. 23. 03:20
BAEKJOON Online Judge(BOJ) 문제입니다.

https://www.acmicpc.net/

 

Baekjoon Online Judge

Baekjoon Online Judge 프로그래밍 문제를 풀고 온라인으로 채점받을 수 있는 곳입니다.

www.acmicpc.net

 

문제

https://www.acmicpc.net/problem/2357

 

2357번: 최솟값과 최댓값

N(1 ≤ N ≤ 100,000)개의 정수들이 있을 때, a번째 정수부터 b번째 정수까지 중에서 제일 작은 정수, 또는 제일 큰 정수를 찾는 것은 어려운 일이 아니다. 하지만 이와 같은 a, b의 쌍이 M(1 ≤ M ≤ 100

www.acmicpc.net

 

내가 작성한 코드

import sys
sys.setrecursionlimit(10 ** 6)

read = sys.stdin.readline

def read_data():
    N, M = map(int, read().rstrip().split())
    arr = []
    for _ in range(N):
        arr.append(int(read().rstrip()))
    return N, M, arr

def update_mx_segtree(N, num, target, mx_tree, l, r, idx):
    if target < l or r < target:
        return
    mx_tree[idx] = max(mx_tree[idx], num)
    if l == r == target:
        return

    mid = l + ((r - l) // 2)
    update_mx_segtree(N, num, target, mx_tree, l, mid, (idx*2)+1)
    update_mx_segtree(N, num, target, mx_tree, mid+1, r, (idx*2)+2)

def get_mx_segtree(N, arr, mx_tree):
    for i, num in enumerate(arr):
        update_mx_segtree(N, num, i, mx_tree, 0, N-1, 0)


def update_mn_segtree(N, num, target, mn_tree, l, r, idx):
    if target < l or r < target:
        return
    mn_tree[idx] = min(mn_tree[idx], num)
    if l == r == target:
        return

    mid = l + ((r - l) // 2)
    update_mn_segtree(N, num, target, mn_tree, l, mid, (idx * 2) + 1)
    update_mn_segtree(N, num, target, mn_tree, mid + 1, r, (idx * 2) + 2)


def get_mn_segtree(N, arr, mn_tree):
    for i, num in enumerate(arr):
        update_mn_segtree(N, num, i, mn_tree, 0, N-1, 0)


def find_mx_segtree(tree, target_l, target_r, l, r, idx):
    if target_l == l and target_r == r:
        return tree[idx]

    mid = l + ((r - l) // 2)
    if target_r <= mid:
        return find_mx_segtree(tree, target_l, target_r, l, mid, (idx * 2) + 1)
    elif mid < target_l:
        return find_mx_segtree(tree, target_l, target_r, mid + 1, r, (idx * 2) + 2)
    else:
        return max(
            find_mx_segtree(tree, target_l, mid, l, mid, (idx * 2) + 1),
            find_mx_segtree(tree, mid + 1, target_r, mid + 1, r, (idx * 2) + 2)
        )

def find_mn_segtree(tree, target_l, target_r, l, r, idx):
    if target_l == l and target_r == r:
        return tree[idx]

    mid = l + ((r-l) // 2)
    if target_r <= mid:
        return find_mn_segtree(tree, target_l, target_r, l, mid, (idx*2)+1)
    elif mid < target_l:
        return find_mn_segtree(tree, target_l, target_r, mid+1, r, (idx * 2) + 2)
    else:
        return min(
            find_mn_segtree(tree, target_l, mid, l, mid, (idx * 2) + 1),
            find_mn_segtree(tree, mid+1, target_r, mid + 1, r, (idx * 2) + 2)
        )



def print_answer(N, M, mx_tree, mn_tree):
    for _ in range(M):
        a, b = map(int, read().rstrip().split())
        print(find_mn_segtree(mn_tree, a-1, b-1, 0, N-1, 0),
              find_mx_segtree(mx_tree, a-1, b-1, 0, N-1, 0))

N, M, arr = read_data()
mx_tree, mn_tree = [0 for _ in range(2 ** 20)], [sys.maxsize for _ in range(2 ** 20)]
get_mx_segtree(N, arr, mx_tree)
get_mn_segtree(N, arr, mn_tree)
print_answer(N, M, mx_tree, mn_tree)
  • 세그먼트 트리
    • 구간합 대신 max/min으로 비교하여 저장

 

다른 사람이 작성한 코드

import math
import sys

sys.setrecursionlimit(10 ** 8)  # pypy 제출시 삭제!
input = lambda: sys.stdin.readline().rstrip()
# in_range = lambda y,x: 0<=y<n and 0<=x<m


def make_seg(idx, s, e):

    if s == e:
        seg[idx] = (arr[s], arr[s])  # min, max
        return seg[idx]

    mid = (s + e) // 2

    l = make_seg(idx * 2, s, mid)
    r = make_seg(idx * 2 + 1, mid + 1, e)

    seg[idx] = (min(l[0], r[0]), max(l[1], r[1]))
    return seg[idx]


def f(s, e, idx):

    # 탐색범위 s~e
    if e < a or b < s:  # 범위 밖
        return (1000000000, 0)

    mid = (s + e) // 2

    if a <= s and e <= b:  # 탐색 범위가 작아서 다 리턴
        return seg[idx]

    else:
        l = f(s, mid, idx * 2)
        r = f(mid + 1, e, idx * 2 + 1)
        return (min(l[0], r[0]), max(l[1], r[1]))


n, m = map(int, input().split())
arr = [int(input()) for _ in range(n)]

b = math.ceil(math.log2(n)) + 1
node_n = 1 << b
seg = [0 for _ in range(node_n)]
make_seg(1, 0, len(arr) - 1)

for _ in range(m):
    a, b = map(int, input().split())
    a, b = a - 1, b - 1  # idx
    ans = f(0, len(arr) - 1, 1)
    print(ans[0], ans[1])
  • min/max를 하나로 작성
    • 훨씬 간결하고 깔끔하다
  • f()
    • 탐색 범위보다 작으면 모두 return

https://velog.io/@sunkyuj/python-%EB%B0%B1%EC%A4%80-2357-%EC%B5%9C%EC%86%9F%EA%B0%92%EA%B3%BC-%EC%B5%9C%EB%8C%93%EA%B0%92

 

[python] 백준 2357 : 최솟값과 최댓값

세그먼트 트리를 이용한 풀이

velog.io

기억해야할 것

  • 세그먼트 트리의 재귀
    • mid+1 같은 인덱스를 자꾸 헷갈려함