본문 바로가기

백준/다이아

BOJ 16901 - XOR MST (D4)

본인 풀이로 푼 사람이 없는 것 같아서 써 본다.

알고리즘 분류
Union Find, Trie(xor minimization/update), MST(idea)

 

풀이

일단 생각을 해보자. 일단 크루스칼 알고리즘 기준으로 생각해보면, 아직 하나도 병합되지 않은 정점이 처음 병합될 때, 당연히 가중치가 최소인 간선을 통해 병합될 것이다. 그럼 얘를 찾아주자. 가중치가 XOR이므로 자신을 제외한 정점 중 xor했을 때 최소인 정점과 연결해야 할 것이다. 이는 Trie를 이용하면 O(비트 개수)에 찾을 수 있다는 사실이 알려져 있다.

일단 이 과정을 모든 정점에 대해 한번 반복했다고 하자. 여기서 2가지 관찰을 할 수 있다.

1. 한 번 병합된 컴포넌트들을 다시 하나의 정점 그룹으로 생각하면, 각 정점 그룹에서도 동일한 과정을 시행할 수 있다.

2. 이 때 정점 그룹의 개수는 그 전 과정의 절반 이하이다.

1은 아래에서 다시 언급하고, 2번에 대해 생각해보자.

한번 과정을 수행하면 각 정점이 포함된 컴포넌트의 크기는 적어도 2가 된다. 그렇지 않으면 제대로 수행한 것이 아니기 때문이다.

즉 크기가 처음에 모두 1이었는데 이제 모두 2 이상이 되었으므로 컴포넌트 개수가 절반 이하로 줄어듦을 알 수 있다.

즉 이 과정은 MST가 만들어지기 전까지 O(logN)번만 발생한다.

결론적으로, 이 과정을 O(T)에 수행할 수 있다면 O(TlogN)에 문제를 풀 수 있다. 이제 이 과정을 빠르게 수행하는 방법을 생각하자.

이를 위해서는 1번을 빠르게 할 수 있어야 한다. 생각해보면 이건 결국 각 그룹 내 모든 정점에 대해 (다른 모든 그룹의 정점 중에서 xor했을 때 최솟값) 중 최솟값임을 알 수 있다. 이걸 위해서는 Trie에서 그룹 내의 모든 값을 지우고, Trie에서 최솟값을 찾고, 다시 그룹 내의 모든 값을 추가해야 한다. 이는 한 그룹의 크기를 K라 할 때 O(KlogX)가 걸린다. X는 2^30이다. K=N이므로 O(NlogX)가 걸린다는 것을 알 수 있다. 즉 한번의 과정이 O(NlogX)가 걸리므로 문제를 O(NlogNlogX)에 해결할 수 있다.

 

본인은 구현에서 많이 TLE가 났다. 그 이유는 O(NlogN(logX + logN))으로 구현했기 때문이다. 이를 해결하기 위해 다음과 같은 최적화를 적용했다.

1. set -> priority queue 2개

2. 쓸데없는 낭비 모두 제거

3. 구현에서 필요한 절차를 빠르게 하기 위해 std::list 사용

4. small to large

5. #pragma GCC optimize

6. 세그를 비재귀로 구현

다 쓸모 없는 짓이었고, 결국 1에서 언급한 priority queue를 아예 쓰지 않는 방식으로 바꿔서 O(NlogNlogX)에 풀었다.

더보기
#include <bits/stdc++.h>
using namespace std;
#pragma GCC optimize("O3")
#pragma GCC optimize("Ofast")
#pragma GCC optimize("unroll-loops")
#define fast ios_base::sync_with_stdio(0); cin.tie(0); cout.tie(0)
#define all(x) x.begin(), x.end()
#define rll(x) x.rbegin(), x.rend()
#define comp(x) x.erase(unique(all(x)), x.end())
#define MOD 1000000007
#define MOD2 998244353
#define debug(x) cout << #x<<" is "<<x<<"\n";
#define X first
#define Y second
#define DEB cout<<"[DEBUG]"
#define PAIR(a,b) "("<<a<<", "<<b<<")"
#define PRINT1(V) DEB<<#V<<endl; for(auto i : V)DEB<<i<<"\n"
#define PRINT2(V) DEB<<#V<<endl; for(auto [a,b] : V)DEB<<PAIR(a,b)<<"\n";
typedef long long ll;
typedef long double ld;
typedef pair<ll,ll> P;
typedef pair<ll,P> PP;
ll n;
struct Node{
    ll l, r, d;
};
struct segtree{
    vector<Node> tree;
    void ins(){ tree.push_back({-1, -1, 0}); }
    void upd(ll node, ll s, ll e, ll i, ll d){
        while(s ^ e){
            tree[node].d += d;
            ll mid = s+e>>1;
            if(i <= mid){
                if(tree[node].l<0)tree[node].l = tree.size(), ins();
                node = tree[node].l, e = mid;
            }
            else{
                if(tree[node].r<0)tree[node].r = tree.size(), ins();
                node = tree[node].r, s = mid+1;
            }
        }
        tree[node].d += d;
    }
    ll find_minXOR(ll node, ll s, ll e, ll x, ll k){
        while(k >= 0){
            ll curbit = (x&(1<<k) ? 1:0);
            ll mid = s+e>>1;
            if(curbit==0){
                if(tree[node].l>=0 and tree[tree[node].l].d > 0)node = tree[node].l, e = mid;
                else node = tree[node].r, s = mid+1;
            }
            else{
                if(tree[node].r >= 0 and tree[tree[node].r].d > 0)node = tree[node].r, s = mid+1;
                else node = tree[node].l, e = mid;
            }
            k--;
        }
        return s;
    }
} seg;
const ll SZ = (1<<30)-1;
ll p[200001];
ll num[200001];
vector<ll> element[200001];
ll ans;
list<ll> lt;
ll find(ll x){
    if(p[x]<0)return x;
    return p[x] = find(p[x]);
}
void merge(ll x, ll y, ll w){
    x = find(x), y = find(y);
    if(x==y)return;
    if(element[num[x]].size() < element[num[y]].size())swap(x,y);
    ans += w;
    p[x] += p[y];
    p[y] = x;
    for(auto i : element[num[y]])element[num[x]].push_back(i);
    element[num[y]].clear();
}
vector<ll> v,t;
map<ll,ll> mp;
ll cp[200001];
bool merged[200001];
int main(){
    fast;
    seg.ins();
    memset(p,-1,sizeof(p));
    cin>>n;
    v.resize(n);
    for(auto &i : v)cin>>i, t.push_back(i);
    sort(all(v)); comp(v);
    n = v.size();
    for(int i = 0 ; i < n ; i++){
        lt.push_back(i);
        num[i]=i;
        mp[v[i]] = i;
        seg.upd(0,0,SZ,v[i],1);
        element[num[i]].push_back(i);
    }
    while(-p[find(1)] < n){
        memset(merged,0,sizeof(merged));
        for(auto it = lt.begin() ;; it++){
            while(it != lt.end() and p[*it]>=0)it = lt.erase(it);
            if(it == lt.end())break;
            auto &V = element[num[*it]];
            if(merged[find(V[0])])continue;
            for(auto i : V)seg.upd(0,0,SZ,v[i],-1);
            ll mn = 1e18;
            ll x;
            for(auto i : V){
                ll tmp = seg.find_minXOR(0,0,SZ,v[i],29);
                if(mn > (v[i] ^ tmp)){
                    mn = (v[i] ^ tmp);
                    x = tmp;
                }
            }
            ll idx = mp[x];
            for(auto i : V)seg.upd(0,0,SZ,v[i],1);
            merge(V[0],idx,mn);
            merged[find(V[0])]=1;
        }
    }
    cout<<ans;
}

'백준 > 다이아' 카테고리의 다른 글