본문 바로가기

백준/다이아

BOJ 18798 - OR과 쿼리 (D5)

사용 알고리즘

풀이 1 : Fenwick Tree + Segment Tree (binary search) + Offline Query + fastIO (O(30NlogN))

풀이 2 : Fenwick Tree + Segment Tree (bitwise AND) (O(30N + NlogN))

 

풀이 1

이 풀이의 기본 아이디어는 KOI'23 고기 파티 문제 풀이를 기반으로 한다.

모든 수들은 1번 쿼리에 의해서 어떤 시점에서 K가 되었다가 다시 K가 아니게 된다. OR 연산이 단조성을 띄므로 이 K가 되는 구간은 연속하게 된다. 그래서 모든 수에 대해서 이 시점을 기록한 후, 이를 나중에 쭉 처리하는 아이디어를 이용한다.

1번 쿼리에 순서대로 번호를 붙이자. 이제 i=1~n에 대하여 Ai가 언제 K가 되는지 구할 것이다. 이를 위해서는 i를 포함하는 모든 1번 쿼리를 알아야 하는데, 이는 어떤 쿼리를 이벤트로 생각해서 시작/끝점에서 추가/삭제해주는 방식으로 처리 가능하다.

2번 예제의 1번 쿼리를 나타낸 것이다. 1번이 2~4를 포함하므로 2,3,4번을 처리할 때 1번을 추가해 준다. 2번은 6~7이므로 6,7번을 처리할 때 2번을 추가해 준다.

이제 Ai가 K가 되는 최소 쿼리 번호를 구할 것이다. 이는 K에는 있지만 Ai에는 없는 비트들에 대해서 각 비트가 채워지는 쿼리 중 최대 번호와 같다. 그럼 각 비트를 포함하는 최소 쿼리 번호를 알아야 한다. 이는 각 쿼리의 비트를 담당하는 합 세그먼트 트리와 이분 탐색으로 해결 가능하다. 이렇게 저장한 값들을 잘 이용해주면 K가 되는 쿼리 번호(L)를 알 수 있고, 비슷한 방법으로 K를 벗어나는 쿼리 번호(R)도 알 수 있다.

이렇게 L, R을 구해놨으면 L,R도 이벤트로 관리하면서 또 다른 세그먼트 트리를 이용해 전체 쿼리를 다시 한번 훑어주면 문제를 해결할 수 있다.

 

이제 중요한 것은 TLE이다. 이론상 위 풀이의 시간복잡도는 O(XNlogN) (X는 비트 개수) = O(30NlogN)이다. 그런데 이는 약 1.3억으로 1.5초인 시간 제한에 비해 빡빡하다. 그런데 실제로 구현을 해보면 30NlogN * 상수 꼴로 약 6억개의 연산을 필요로 한다. 이를 빠르게 처리하기 위해서 최대한 같이 처리할 수 있는 연산을 같이 처리해주고, 이분 탐색을 세그먼트 트리 내부에서 수행하게 하며 비재귀로 구현하고, 그냥 시간이 낭비될 것 같은 모든 부분을 처리해주고 난 후 마지막으로 fastio를 이용하면 약 1400ms로 간신히 AC를 받는다.

아래는 1432ms로 AC를 받은 코드이다.

더보기
#include <bits/stdc++.h>
#include <sys/stat.h>
#include <sys/mman.h>
#include <unistd.h>
using namespace std;
#pragma GCC optimize("O3")
#pragma GCC optimize("Ofast")
#pragma optimize("unroll-loops")
#define all(x) x.begin(), x.end()
#define rll(x) x.rbegin(), x.rend()
#define comp(x) x.erase(unique(all(x)), x.end())
#define MOD 998244353
#define debug(x) cout << #x<<" is "<<x<<"\n";
#define X first
#define Y second
typedef int ll;
typedef long double ld;
typedef pair<ll,ll> P;
typedef pair<ll,P> PP;
/////////////////////////////////////////////////////////////////////////////////////////////
/*
 * Author : jinhan814
 * Date : 2021-05-06
 * Source : https://blog.naver.com/jinhan814/222266396476
 * Description : FastIO implementation for cin, cout. (mmap ver.)
 */
constexpr int SZ = 1 << 20;

class INPUT {
private:
    char* p;
    bool __END_FLAG__, __GETLINE_FLAG__;
public:
    explicit operator bool() { return !__END_FLAG__; }
    INPUT() {
        struct stat st; fstat(0, &st);
        p = (char*)mmap(0, st.st_size, PROT_READ, MAP_SHARED, 0, 0);
    }
    bool IsBlank(char c) { return c == ' ' || c == '\n'; }
    bool IsEnd(char c) { return c == '\0'; }
    char _ReadChar() { return *p++; }
    char ReadChar() {
        char ret = _ReadChar();
        for (; IsBlank(ret); ret = _ReadChar());
        return ret;
    }
    template<typename T> T ReadInt() {
        T ret = 0; char cur = _ReadChar(); bool flag = 0;
        for (; IsBlank(cur); cur = _ReadChar());
        if (cur == '-') flag = 1, cur = _ReadChar();
        for (; !IsBlank(cur) && !IsEnd(cur); cur = _ReadChar()) ret = 10 * ret + (cur & 15);
        if (IsEnd(cur)) __END_FLAG__ = 1;
        return flag ? -ret : ret;
    }
    string ReadString() {
        string ret; char cur = _ReadChar();
        for (; IsBlank(cur); cur = _ReadChar());
        for (; !IsBlank(cur) && !IsEnd(cur); cur = _ReadChar()) ret.push_back(cur);
        if (IsEnd(cur)) __END_FLAG__ = 1;
        return ret;
    }
    double ReadDouble() {
        string ret = ReadString();
        return stod(ret);
    }
    string getline() {
        string ret; char cur = _ReadChar();
        for (; cur != '\n' && !IsEnd(cur); cur = _ReadChar()) ret.push_back(cur);
        if (__GETLINE_FLAG__) __END_FLAG__ = 1;
        if (IsEnd(cur)) __GETLINE_FLAG__ = 1;
        return ret;
    }
    friend INPUT& getline(INPUT& in, string& s) { s = in.getline(); return in; }
} _in;

class OUTPUT {
private:
    char write_buf[SZ];
    int write_idx;
public:
    ~OUTPUT() { Flush(); }
    explicit operator bool() { return 1; }
    void Flush() {
        write(1, write_buf, write_idx);
        write_idx = 0;
    }
    void WriteChar(char c) {
        if (write_idx == SZ) Flush();
        write_buf[write_idx++] = c;
    }
    template<typename T> int GetSize(T n) {
        int ret = 1;
        for (n = n >= 0 ? n : -n; n >= 10; n /= 10) ret++;
        return ret;
    }
    template<typename T> void WriteInt(T n) {
        int sz = GetSize(n);
        if (write_idx + sz >= SZ) Flush();
        if (n < 0) write_buf[write_idx++] = '-', n = -n;
        for (int i = sz; i --> 0; n /= 10) write_buf[write_idx + i] = n % 10 | 48;
        write_idx += sz;
    }
    void WriteString(string s) { for (auto& c : s) WriteChar(c); }
    void WriteDouble(double d) { WriteString(to_string(d)); }
} _out;

/* operators */
INPUT& operator>> (INPUT& in, char& i) { i = in.ReadChar(); return in; }
INPUT& operator>> (INPUT& in, string& i) { i = in.ReadString(); return in; }
template<typename T, typename std::enable_if_t<is_arithmetic_v<T>>* = nullptr>
INPUT& operator>> (INPUT& in, T& i) {
    if constexpr (is_floating_point_v<T>) i = in.ReadDouble();
    else if constexpr (is_integral_v<T>) i = in.ReadInt<T>(); return in; }

OUTPUT& operator<< (OUTPUT& out, char i) { out.WriteChar(i); return out; }
OUTPUT& operator<< (OUTPUT& out, string i) { out.WriteString(i); return out; }
template<typename T, typename std::enable_if_t<is_arithmetic_v<T>>* = nullptr>
OUTPUT& operator<< (OUTPUT& out, T i) {
    if constexpr (is_floating_point_v<T>) out.WriteDouble(i);
    else if constexpr (is_integral_v<T>) out.WriteInt<T>(i); return out; }

/* macros */
#define fastio 1
#define cin _in
#define cout _out
#define istream INPUT
#define ostream OUTPUT
ll n,k,q;
struct segtree{
    vector<int> tree;
    segtree(): tree(1<<19){}
    ll siz = 1<<18;
    void upd(ll node, ll x){
        node += siz-1;
        node <<= 1;
        while(node>>=1)tree[node] += x;
    }
    ll find_kth(ll node, ll p){
        if(node>=siz)return node;
        if(p<tree[node<<1])return find_kth(node<<1,p);
        return find_kth(node<<1|1, p-tree[node<<1]);
    }
} seg[30];
struct fenwick{
    vector<int> tree;
    fenwick(): tree(252525){}
    void upd(ll i, ll x){
        while(i<=n){
            tree[i] += x;
            i += (i&-i);
        }
    }
    ll query(ll i, ll ret=0){
        while(i){
            ret += tree[i];
            i -= (i&-i);
        }
        return ret;
    }
} SEG;
ll a[252525];
vector<P> v1[252525], v2[252525];   //v1 : [쿼리 추가&삭제,값] / v2 : [인덱스, 추가&삭제]
ll isbit[33];
vector<pair<P,P>> query;    //1번이면 {{1,값},{l,r}}, 2번이면 {{2,*},{l,r}}
ll C[33];
int main(){
    cin>>n>>k;
    for(int i = 1 ; i <= n ; i++)cin>>a[i];
    cin>>q;
    query.resize(q);
    ll cnt=0;
    for(auto &[a,b]: query){
        cnt++;
        cin>>a.X>>b.X>>b.Y;
        if(a.X==1){
            cin>>a.Y;
            v1[b.X].push_back({cnt,a.Y});
            v1[b.Y+1].push_back({-cnt,a.Y});
        }
    }
    for(int i = 0 ; i < 30 ; i++)isbit[i] = q+1;
    for(int i = 1 ; i <= n ; i++){
        for(auto [isadd,diff] : v1[i]){
            for(int j = 0 ; j < 30 ; j++){
                if(diff&(1<<j)){
                    if(isadd>0)seg[j].upd(isadd,1), C[j]++;
                    else seg[j].upd(-isadd,-1), C[j]--;
                }
            }
        }
        bool flag = 1;
        ll r = 1;
        ll x = 0;
        for(int j = 0 ; j < 30 ; j++){
            if(a[i]&(1<<j) and !(k&(1<<j))){
                flag = 0;
                break;
            }
            isbit[j] = (C[j]>0 ? seg[j].find_kth(1,0)-seg[j].siz+1 : q+1);
            if(k&(1<<j) and !(a[i]&(1<<j))){
                r = max(r,isbit[j]);
            }
        }
        if(!flag)continue;
        if(r>=q+1)continue;
        ll L = r;
        r = q+1;
        for(int j = 0 ; j < 30 ; j++){
            if((a[i]&(1<<j)) or isbit[j]<=L)x += (1<<j);
            if(!(k&(1<<j)))r = min(r, isbit[j]);
        }
        if(x!=k)continue;
        if(L<r)v2[L].push_back({1,i}), v2[r].push_back({-1,i});
    }
    for(int i = 0 ; i < q ; i++){
        for(auto [isadd,idx] : v2[i+1])SEG.upd(idx,isadd);
        if(query[i].X.X == 2){
            cout<<SEG.query(query[i].Y.Y)-SEG.query(query[i].Y.X-1)<<"\n";
        }
    }
}

풀이 1시간 + 구현&최적화 4시간이 걸렸다. 처음 이 풀이를 찾고 나서는 이 문제를 P1로 기여하려고 했는데 그런 마음이 싹 사라졌다. 다이아는 다이아다.

 

풀이 2

나만 몰랐던 정해이다. 잘 알려져 있으니 자세한 설명은 생략하고, 간단하게만 말하면 OR이 항상 비트를 채우기 때문에 각 수는 30번 정도만 변경이 일어나게 된다. 그래서 수가 바뀌는 위치만 빠르게 찾아서 업데이트해주면 되는데, 이를 bitwise AND를 이용한 세그먼트 트리로 처리할 수 있다.

더보기
using namespace std;
#pragma GCC optimize("O3")
#pragma GCC optimize("Ofast")
#pragma optimize("unroll-loops")
#define fast ios_base::sync_with_stdio(0); cin.tie(0); cout.tie(0)
#define all(x) x.begin(), x.end()
#define rll(x) x.rbegin(), x.rend()
#define comp(x) x.erase(unique(all(x)), x.end())
#define MOD 998244353
#define debug(x) cout << #x<<" is "<<x<<"\n";
#define X first
#define Y second
typedef long long ll;
typedef long double ld;
typedef pair<ll,ll> P;
typedef pair<ll,P> PP;
ll n, k, q;
struct fenwick{
    vector<int> tree;
    fenwick(): tree(252525){}
    void upd(ll i, ll x){
        while(i<=n){
            tree[i] += x;
            i += (i&-i);
        }
    }
    ll query(ll i, ll ret=0){
        while(i){
            ret += tree[i];
            i -= (i&-i);
        }
        return ret;
    }
} SEG;
ll a[252525];
struct segtree{
    vector<ll> tree;
    segtree(): tree(1010101) {}
    void init(ll node, ll s, ll e){
        if(s==e)tree[node] = a[s];
        else{
            ll mid = s+e>>1;
            init(node<<1,s,mid); init(node<<1|1,mid+1,e);
            tree[node] = tree[node<<1] & tree[node<<1|1];
        }
    }
    void upd(ll node, ll s, ll e, ll l, ll r, ll diff){
        if(e<l or r<s)return;
        if(l<=s and e<=r){
            if((tree[node]&diff) == diff)return;
        }
        if(s==e){
            if(tree[node]==k)SEG.upd(s,-1);
            else if((tree[node]|diff) == k)SEG.upd(s,1);
            tree[node] |= diff;
            return;
        }
        diff &= (~tree[node]);
        ll mid = s+e>>1;
        upd(node<<1,s,mid,l,r,diff); upd(node<<1|1,mid+1,e,l,r,diff);
        tree[node] = tree[node<<1]&tree[node<<1|1];
    }
} seg;
int main(){
    fast;
    cin>>n>>k;
    for(int i = 1; i <= n ; i++){
        cin>>a[i];
        if(a[i]==k)SEG.upd(i,1);
    }
    seg.init(1,1,n);
    cin>>q;
    while(q--){
        ll a,b,c,d; cin>>a>>b>>c;
        if(a==1){
            cin>>d; seg.upd(1,1,n,b,c,d);
        }
        else cout<<SEG.query(c) - SEG.query(b-1)<<"\n";
    }
}

풀이 1 : 1432ms / 풀이 2 : 240ms

'백준 > 다이아' 카테고리의 다른 글

BOJ 16901 - XOR MST (D4)  (1) 2024.02.06
BOJ 3654 - L퍼즐 (D4)  (0) 2024.01.25
BOJ 18830 - 하이퍼 수열과 하이퍼 쿼리 (D5)  (0) 2023.11.17
BOJ 25392 - 사건의 지평선 (D3)  (0) 2023.09.21
BOJ 29202 - 책가방 (D5)  (0) 2023.09.03