문제 링크

풀이

구간 합 구하기 문제와 연산의 종류만 바뀌었지만 업데이트와 연산 쿼리를 생각보다 많이 수정해야 한다. 덧셈과 곱셈의 특징을 잘 생각하며 두 세그먼트 트리의 차이점을 고려해 코드를 수정해보자. 첫번째로 두 연산의 단위 차이가 있다. 덧셈은 더해서 자기 자신이 나오려면 0을 더해야하지만 곱셈은 1을 곱해야 자기 자신이 나온다. 따라서 연산 쿼리를 수행할 때 범위를 벗어난 구간은 0이 아니라 1을 반환해야한다. 두번째로 덧셈 세그먼트 트리에서는 업데이트 쿼리 시 변화량을 반환하며 해당 인덱스를 포함하는 부모 노드들에 변화량을 추가하는 방식으로 구현했다. 이는 덧셈에만 가능하고 곱셈에서는 구현할 수 없는 방식이다. 따라서 곱셈 세그먼트 트리에서는 부모 노드를 업데이트할 때 업데이트되지 않은 반대쪽 세그먼트 트리의 값도 참조해 부모 노드의 값을 새로 계산해야한다. 그림으로 나타내면 아래와 같다. ( 왜 부모 노드의 반대쪽 자식의 값도 계산해야 하는지 설명하는 그림 ) 위에서 보이는 것처럼 반대쪽 자식의 값과 업데이트된 자식의 값을 곱해 새로운 부모의 값을 구할 수 있다. 그렇기 때문에 곱셈 세그먼트 트리의 업데이트 쿼리에서는 범위를 완전히 벗어난 구간이 나왔을 때 현재 노드의 값을 반환해야한다.

코드

#include <iostream>
#include <algorithm>
#include <vector>
#include <cmath>
using namespace std;
using ll = long long;
const ll mod = 1e9 + 7;
int n, m, k;
vector<ll> arr, seg;
// n : tree 기준 노드 번호
// s : 배열 구간 시작 인덱스
// e : 배열 구간 종료 종료 인덱스
ll init(int n, int s, int e) {
    if (s == e) { // 시작 인덱스와 종료 인덱스가 같다면 리프 노드이므로 배열값 반환
        return seg[n] = arr[s];
    }
    int mid = (s + e) / 2;
    ll s1 = init(n * 2, s, mid); // 왼자식 탐색
    ll s2 = init(n * 2 + 1, mid + 1, e); // 오른자식 탐색
    return seg[n] = (s1 * s2) % mod; // 양쪽 자식 곱한 값 반환
}
// n : tree 기준 노드 번호
// t : 업데이트 대상 인덱스
// p : 업데이트 값
// s : 배열 구간 시작 인덱스
// e : 배열 구간 종료 종료 인덱스 
ll update(int n, int t, ll p, int s, int e) {
    if (t < s || t > e) return seg[n]; // 구간이 업데이트 대상 인덱스를 포함하지 않으면 현재 노드값 반환
    if (s == e) { // 업데이트 대상 리프 노드를 찾으면        
        return seg[n] = p; // 업데이트 후 값 반환
    }
    int mid = (s + e) / 2;
    ll s1 = update(n * 2, t, p, s, mid); // 왼자식 탐색
    ll s2 = update(n * 2 + 1, t, p, mid + 1, e); // 오른자식 탐색
    return seg[n] = (s1 * s2) % mod; // 양쪽 자식의 값을 곱한 값으로 부모 업데이트 및 반환    
}
// n : tree 기준 노드 번호
// l : 합을 구하는 구간 시작 인덱스
// r : 합을 구하는 구간 종료 인덱스
// s : 배열 구간 시작 인덱스
// e : 배열 구간 종료 종료 인덱스 
ll getmult(int n, int l, int r, int s, int e) {
    if (l <= s && e <= r) return seg[n]; // 구간을 완전히 포함하는 경우, 노드 값 반환
    if (r < s || e < l) return 1; // 구간이 전혀 포함되지 않는 경우, 곱셈 단위인 1 반환
    int mid = (s + e) / 2;
    ll s1 = getmult(n * 2, l, r, s, mid); // 왼자식 탐색 
    ll s2 = getmult(n * 2 + 1, l, r, mid + 1, e); // 오른자식 탐색
    return (s1 * s2) % mod; // 자식 곱 반환
}
int main()
{
    ios::sync_with_stdio(0); cin.tie(0); cout.tie(0);
    cin >> n >> m >> k; arr = vector<ll>(n + 1); seg = vector<ll>((int)pow(2, (int)ceil(log2(n)) + 1));
    m += k;
    for (int i = 1; i <= n; i++) {
        cin >> arr[i];
    }
    init(1, 1, n);
    for (int i = 1; i <= m; i++) {
        ll a, b, c; cin >> a >> b >> c;
        if (a == 1) {
            update(1, b, c, 1, n);            
        }
        if (a == 2) {
            cout << getmult(1, b, c, 1, n) << "\n";
        }
    }
    return 0;
}