문제 링크

풀이

가장 기본적인 세그먼트 트리 문제이다. 세그먼트 트리 설명을 읽고 구현해보자.

코드

#include <iostream>
#include <algorithm>
#include <vector>
#include <cmath>
using namespace std;
using ll = long long;
int n, m, k;
vector<ll> arr, seg;
// n : tree 기준 노드 번호
// s : 배열 구간 시작 인덱스
// e : 배열 구간 종료 종료 인덱스
ll init(int n, int s, int e) {
    if (s == e) { // 시작 인덱스와 종료 인덱스가 같다면 리프 노드이므로 배열값 반환
        return seg[n] = arr[s];
    }
    int mid = (s + e) / 2;
    ll s1 = init(n * 2, s, mid); // 왼자식 탐색
    ll s2 = init(n * 2 + 1, mid + 1, e); // 오른자식 탐색
    return seg[n] = s1 + s2; // 양쪽 자식 합친 값 반환
}
// n : tree 기준 노드 번호
// t : 업데이트 대상 인덱스
// p : 업데이트 값
// s : 배열 구간 시작 인덱스
// e : 배열 구간 종료 종료 인덱스 
ll update(int n, int t, ll p, int s, int e) {
    if (t < s || t > e) return 0; // 구간이 업데이트 대상 인덱스를 포함하지 않으면 변화량 0  
    if (s == e) { // 업데이트 대상 리프 노드를 찾으면
        ll tmp = seg[n];
        seg[n] = p;
        return p - tmp; // 해당 노드의 변화량 반환
    }
    int mid = (s + e) / 2;
    ll s1 = update(n * 2, t, p, s, mid); // 왼자식 탐색
    ll s2 = update(n * 2 + 1, t, p, mid + 1, e); // 오른자식 탐색
    seg[n] += s1 + s2; // 자식의 변화량 만큼 부모도 업데이트
    return s1 + s2; // 변화량 반환
}
// n : tree 기준 노드 번호
// l : 합을 구하는 구간 시작 인덱스
// r : 합을 구하는 구간 종료 인덱스
// s : 배열 구간 시작 인덱스
// e : 배열 구간 종료 종료 인덱스 
ll getsum(int n, int l, int r, int s, int e) {
    if (l <= s && e <= r) return seg[n]; // 구간을 완전히 포함하는 경우, 노드 값 반환
    if (r < s || e < l) return 0; // 구간이 전혀 포함되지 않는 경우, 탐색 중지
    int mid = (s + e) / 2;
    ll s1 = getsum(n * 2, l, r, s, mid); // 왼자식 탐색 
    ll s2 = getsum(n * 2 + 1, l, r, mid + 1, e); // 오른자식 탐색
    return s1 + s2; // 자식 합계 반환
}
int main()
{
    ios::sync_with_stdio(0); cin.tie(0); cout.tie(0);
    cin >> n >> m >> k; arr = vector<ll>(n + 1); seg = vector<ll>((int)pow(2, (int)ceil(log2(n)) + 1));
    m += k;
    for (int i = 1; i <= n; i++) {
        cin >> arr[i];
    }
    init(1, 1, n);
    for (int i = 1; i <= m; i++) {
        ll a, b, c; cin >> a >> b >> c;
        if (a == 1) {
            update(1, b, c, 1, n);            
        }
        if (a == 2) {
            cout << getsum(1, b, c, 1, n) << "\n";
        }
    }
    return 0;
}