본문 바로가기

백준/골드

BOJ 2613 - 숫자구슬 (G2)

알고리즘 분류

  • Dynamic Programming
  • Greedy
  • Binary Search
  • Parametric Search

풀이

풀이 1: Greedy, O(N^3)

 

N이 최대 300이므로 O(N^3)까지 고려할 수 있다. 따라서 임의의 구간 [l, r]을 골라서, [l,r]의 합이 최댓값이 되게 할 수 있는지를 빠르게 구할 수 있으면 된다.

  • [l,r]의 합을 x라 하자. 배열의 처음부터 시작해서 값을 연속해서 더해나간다. 이 때, 합이 x 초과일 경우 새로 그룹을 갱신한다. (단, 하나의 값이 x 초과일경우는 불가능하다.) [l,r]에 대해서는 예외처리를 해준다.

위 알고리즘은 x가 최댓값이 되게 하면서 그룹의 크기를 최소화하는 그리디적인 방법이다. 만약 위 알고리즘대로 만든 그룹의 개수가 m보다 크다면, [l,r]이 최댓값이 되는 경우는 없다는 뜻이다. 만든 그룹의 개수가 m보다 작다면, 그룹을 쪼개서 m개로 만들 수 있는지를 판별하면 된다. 가능하다면 정답과 x의 크기를 비교하여 답을 갱신하면 된다.

 

정당성 판별은 쉽다. 그 이유는 최적해가 당연히 어떠한 [l,r]의 합이 최대인 경우에 포함되기 때문이며, 그 [l,r]을 제외한 나머지 그룹들은 합이 x 이하라는 조건만 만족하면 되기 때문이다.

 

시간복잡도는 임의의 구간을 고르는데 O(N^2), 위 알고리즘에 O(N)이 드므로 총 O(N^3)이다.

#include <iostream>
#include <cmath>
#include <algorithm>
#include <numeric>
#include <cstring>
#include <vector>
#include <string>
#include <climits>
#include <map>
#include <set>
#include <stack>
#include <queue>
#include <bitset>
#include <cassert>
#include <list>
using namespace std;
#define fast ios_base::sync_with_stdio(0); cin.tie(0); cout.tie(0)
#define all(x) x.begin(), x.end()
#define rll(x) x.rbegin(), x.rend()
#define comp(x) x.erase(unique(all(x)), x.end())
#define MOD 998244353
typedef long long ll;
ll n, m;
int main(){
    fast;
    cin>>n>>m;
    vector<ll> v(n);
    for(auto &i : v)cin>>i;
    ll ans = 1e18;
    vector<ll> res;
    for(int i = 0 ; i < n ; i++){
        for(int j = i ; j < n ; j++){
            ll l = i, r = j;
            ll x = accumulate(v.begin()+l, v.begin()+r+1, 0);
            if(x>ans)continue;
            vector<ll> group;
            ll cnt = 0;
            ll sum = 0;
            bool flag = 1;
            for(int k = 0 ; k < n ; k++){
                if(k==l){
                    if(cnt)group.push_back(cnt);
                    cnt=sum=0;
                    while(k<=r)k++, cnt++;
                    group.push_back(cnt);
                    k--;
                    cnt=0;
                }
                else{
                    if(sum+v[k] > x){
                        if(!cnt){
                            flag = 0;
                            break;
                        }
                        group.push_back(cnt);
                        cnt=sum=0;
                        k--;
                    }
                    else{
                        cnt++;
                        sum += v[k];
                    }
                }
            }
            if(!flag)continue;
            if(cnt)group.push_back(cnt);
            if(group.size()>m)continue;
            while(group.size()<m){
                ll ptr = group.size()-1;
                while(ptr>=0 and group[ptr]==1)ptr--;
                if(ptr<0)break;
                group[ptr]--, group.push_back(1);
            }
            if(group.size()<m)continue;
            ans=x;
            res=group;
        }
    }
    cout<<ans<<"\n";
    for(auto i : res)cout<<i<<" ";
}

 

풀이 2: Parametric Search, O(NlogS) (S : 배열 값의 합)

f(x) := 최댓값이 x 이하가 되게 할 때, 구간의 크기가 m 이하일 수 있는가?

위와 같이 f(x)를 정의하면, f(x)의 값은 True/False이고 그 값은 x에 따라 단조성을 띈다. 따라서 이분탐색을 적용하여 f(x)=True인 최소 x를 찾을 수 있다. f(x)의 판별은 위 풀이 1의 알고리즘에서 [l,r]만 제거한 형태이므로 생략한다.

 

시간복잡도는 이분 탐색에 O(logS) (S : 배열 값의 합), f(x)를 구할 때 O(N)이므로 O(NlogS)이다.

#include <iostream>
#include <cstring>
#include <algorithm>
#include <vector>
#include <cstring>
#include <queue>
#include <climits>
#include <numeric>
using namespace std;
int n, m;
int arr[303];
int M;
vector<int> v, ans;
bool chk(int x){
    v.clear();
    int cur = 0, cnt = 0;
    if(M>x)return 0;
    for(int i = 0 ; i < n ; i++){
        if(cur+arr[i]>x){
            v.push_back(cnt);
            cur = cnt = 0;
        }
        if(v.size()>m)return 0;
        cur += arr[i]; cnt++;
    }
    v.push_back(cnt);
    if(v.size()>m)return 0;
    return 1;
}
int main(){
    cin>>n>>m;
    for(int i = 0 ; i < n ; i++)cin>>arr[i], M = max(M, arr[i]);
    int l = 0, r = accumulate(arr, arr+n, 0)+1;
    while(l+1 < r){
        int mid = (l+r)/2;
        if(chk(mid)){
            ans = v;
            r = mid;
        }
        else l = mid;
    }
    cout<<r<<endl;
    int idx = ans.size()-1;
    while(ans.size()<m){
        ans.push_back(1);
        while(ans[idx]<=1)idx--;
        ans[idx]--;
    }
    for(auto i : ans)cout<<i<<" "; cout<<endl;
    return 0;
}

위쪽이 풀이1, 아래가 풀이2. O(N^3)이 생각보다 빨리 돈다.