문제 링크

병합정렬로 푸는 방법도 있다. 사실 병합정렬로 푸는 방법이 더 쉽고 세그먼트 트리 쓰는 방법이 더 어렵다. 병합 정렬은 나중에 씀 ( 2024.08.14 )

풀이

버블 정렬에서 Swap 이 일어나는 횟수는 나보다 인덱스가 큰데 나보다 작은 수의 개수의 합이다. 예제 3 2 1에서 2보다 큰 수가 3 하나 있고, 1보다 큰 수 2 3 이 2개 있으므로 합 3이 버블 정렬 시 발생하는 Swap 횟수다. 문제는 주어지는 수열의 개수가 상당히 커서 내 앞의 모든 수를 다 검사하는 알고리즘을 작성하면 으로 시간 초과가 발생한다. 이 검사를 세그먼트 트리를 이용하면 으로 최적화 할 수 있다. 먼저 수열의 원래 인덱스를 미리 저장한 뒤, 정렬한다. 그 후, 정렬된 수열을 순회하며 현재 수의 인덱스를 1로 업데이트 한다. 이렇게 하면 현재 수보다 먼저 나온 수들, 즉 현재 수보다 작은 수들 중 현재 수보다 원래 인덱스가 큰 수들만 세그먼트 트리에 업데이트 돼있는 상태가 된다. 즉, 현재 수의 원래 인덱스를 라고 하면 구간의 합이 현재 수보다 뒤에 있는 현재 수보다 작은 수의 개수가 된다. 세그먼트 트리를 사용하면 업데이트와 구간합을 각각 에 완료할 수 있다.

주어진 예제에 위 알고리즘을 적용하면 다음과 같이 진행된다. 3 2 1을 정렬하면 1 2 3이고, 각 수의 원래 인덱스는 3 2 1이다. 정렬한 수열 1 2 3을 순회해보자.

현재 수 1의 원래 인덱스는 3이다. 따라서 인덱스 3을 1로 업데이트 한다. 현재 세그트리 배열은 0 0 1 이므로 의 합은 이다. sum = 0

그 다음 수 2의 원래 인덱스는 2이다. 따라서 인덱스 2를 1로 업데이트 한다. 현재 세그트리 배열은 0 1 1 의 합은 이다. sum = 1

그 다음 수 3의 원래 인덱스는 1이다. 따라서 인덱스 3을 1로 업데이트 한다. 현재 세그트리 배열은 1 1 1 의 합은 2이다. sum = 3

따라서 최종 답은 3이 된다.

단순히 업데이트만하는게 아니라 시간축까지 고려해야 해서 어려운 문제였다.

코드

#include <iostream>
#include <algorithm>
#include <vector>
#include <cmath>
using namespace std;
using ll = long long;
// v : 값
// i : 인덱스
struct node {
    int v, i;
};
int n;
vector<int> seg;
vector<node> arr;
bool cmp(node a, node b) {
    if (a.v == b.v) { // 교환의 최소 횟수를 구하는 것이므로 수가 같은 경우는 입력 인덱스 순으로 배열
        return a.i < b.i;
    }
    return a.v < b.v;
}
// n : tree 기준 노드 번호
// l : 합을 구하는 구간 시작 인덱스
// r : 합을 구하는 구간 종료 인덱스
// s : 배열 구간 시작 인덱스
// e : 배열 구간 종료 종료 인덱스 
int getsum(int n, int l, int r, int s, int e) {
    if (l <= s && e <= r) return seg[n]; // 구간을 완전히 포함하는 경우, 노드 값 반환
    if (r < s || e < l) return 0; // 구간이 전혀 포함되지 않는 경우, 0 반환
    int mid = (s + e) / 2;
    int s1 = getsum(n * 2, l, r, s, mid); // 왼자식 탐색 
    int s2 = getsum(n * 2 + 1, l, r, mid + 1, e); // 오른자식 탐색
    return s1 + s2; // 양쪽 자식 최대 / 최소값 비교 후 더 나은 쪽 리턴
}
// n : tree 기준 노드 번호
// t : 업데이트 대상 배열 인덱스
// s : 배열 구간 시작 인덱스
// e : 배열 구간 종료 종료 인덱스 
void update(int n, int t, int s, int e) {
    if (t < s || t > e) return; // 구간이 대상 인덱스를 전혀 포함하지 않는 경우 업데이트 하지 않음
    seg[n]++; // 구간이 대상 인덱스를 포함하면 전부 + 1. 문제 특성상 변량이 1로 고정임.
    if (s == e) return; // 리프 노드인 경우 탐색 종료
    int mid = (s + e) / 2;
    update(n * 2, t, s, mid); // 왼자식 탐색 
    update(n * 2 + 1, t, mid + 1, e); // 오른 자식 탐색
}
int main()
{
    ios::sync_with_stdio(0); cin.tie(0); cout.tie(0);
    cin >> n; arr = vector<node>(n + 1); seg = vector<int>((int)pow(2, (int)ceil(log2(n)) + 1));
    for (int i = 1; i <= n; i++) {
        int tmp; cin >> tmp;
        arr[i] = { tmp, i };
    }    
    sort(arr.begin() + 1, arr.end(), cmp); // 원래 인덱스는 유지하면서 입력 배열 정렬
    ll sum = 0;
    for (int i = 1; i <= n; i++) {
        update(1, arr[i].i, 1, n); // 현재 수의 원래 인덱스를 업데이트하고
        sum += (ll)getsum(1, arr[i].i + 1, n, 1, n); // 현재 수보다 원래 인덱스가 큰 수 중에 현재 수보다 먼저 나온 수( 현재 수보다 작은 수 ) 의 개수를 구한다.
    }
    cout << sum;
    return 0;
}