문제 설명
N(1 ≤ N ≤ 100)개의 수로 이루어진 1차원 배열이 있다. 이 배열에서 M(1 ≤ M ≤ ⌈(N/2)⌉)개의 구간을 선택해서, 구간에 속한 수들의 총 합이 최대가 되도록 하려 한다. 단, 다음의 조건들이 만족되어야 한다.
- 각 구간은 한 개 이상의 연속된 수들로 이루어진다.
- 서로 다른 두 구간끼리 겹쳐있거나 인접해 있어서는 안 된다.
- 정확히 M개의 구간이 있어야 한다. M개 미만이어서는 안 된다.
N개의 수들이 주어졌을 때, 답을 구하는 프로그램을 작성하시오.
https://www.acmicpc.net/problem/2228
제한 사항
풀이
문제를 요약하면, N개의 수를 M개의 구간으로 나누었을 대 구간에 포함된 수들의 합의 최댓값을 구하는 것이다.
이때, 구간은 겹치면 안되며 정확히 M개여야 한다.
처음에는 N이 충분히 작기 때문에 부분 집합을 구해 M개의 구간으로 나누어진다면 최댓값을 갱신하도록 풀었다.
하지만, 계속 틀렸고 이유를 찾지 못했다.
M개의 구간으로 나누어 지는지 확인하는 부분이 문제가 될 것이라 예측되지만 1011001이라고 했을 때 j가 1이면서 j-1이 0인 개수를 세었는데 왜 틀리는지는 모르겠다.(j=0은 따로 체크)
이 문제를 푸는 방법은 DP를 이용하는 것이다.
dp[i][j]: i번째 수까지 중 j개의 구간의 합의 최댓값
즉, 우리가 최종적으로 구할 것은 dp[N][M]이 될 것이다.
그럼 기저 조건을 설정해 보자.
j가 0이되면 dp[i][0]이 될 것이고 이는 명백히 0이다.
0개를 골라야 하기 때문이다.
M의 조건이 N/2의 상한이기 때문에 N < M / 2 + 1이 되면 M개의 구간으로 나눌 수 없게 된다.
붙어있거나 겹치는 구간이 발생하기 때문이다.
그럼 이제 재귀와 dp를 통해 답을 구해나가면 된다.
우선, (N, M)에서 시작하여 다음과 같은 논리를 따라가게 된다.
- n번째 수를 선택하지 않는 경우: n-1번째까지의 수를 m개의 구간으로 나눌 때의 최댓값
- n번째 수를 선택한 경우: n-2번째까지의 수를 m-1개의 구간으로 나눌때의 최댓값
이는 지금 선택하는 수가 마지막 구간이라고 생각한다면 이해하기 쉬울 것이다.
따라서, 마지막 구간으로 최대 1~n까지의 수를 선택할 수 있다.
이를 k라고 한다면 마지막 구간은 k~n까지의 합이 될 것이다.
그럼, 이 부분에서 계속하여 배열의 합을 구해야 하기 때문에 누적합을 이용하여 이를 효율적으로 처리할 수 있다.
for (int i = n; i >= 1; i--)
{
int temp = Solve(i - 2, m - 1);
dp[n][m] = max(dp[n][m], temp + prefixSum[n] - prefixSum[i - 1]);
}
정리하자면, 재귀 함수에서는 두가지 중 하나를 선택하면 된다.
아무것도 선택하지 않고 n-1까지 구한 최댓값을 그대로 사용하는 경우와 n-2까지 구한 최댓값에 하나의 구간을 더해 최댓값을 구하는 경우이다.
n-2까지 구한 최댓값을 이용하려면 마지막 구간을 만들어야 하는데 1~n까지의 수를 하나씩 살펴보며 더해보면 된다.
전체 코드
#include <bits/stdc++.h>
using namespace std;
int N, M;
vector<int> nums;
vector<vector<int>> dp;
vector<int> prefixSum;
const int MIN = -987654321;
int Solve(int n, int m)
{
if (n < m * 2 - 1) return MIN; // 겹침
if (m == 0) return 0; // 0개 선택
if (dp[n][m] != MIN) return dp[n][m]; // dp
//n번째 수를 선택하지 않을 경우
dp[n][m] = Solve(n - 1, m);
//n번째 수를 선택하는 경우
for (int i = n; i >= 1; i--)
{
int temp = Solve(i - 2, m - 1);
dp[n][m] = max(dp[n][m], temp + prefixSum[n] - prefixSum[i - 1]);
}
return dp[n][m];
}
int main()
{
ios_base::sync_with_stdio(0);
cin.tie(0);
cout.tie(0);
cout << fixed;
cin >> N >> M;
nums.resize(N+1);
prefixSum.resize(N+1);
for (int i = 1; i <= N; i++)
{
cin >> nums[i];
if (i == 0) prefixSum[i] = nums[i];
else prefixSum[i] = prefixSum[i - 1] + nums[i];
}
dp.resize(N+1, vector<int>(M+1));
for (int i = 1; i <= N; i++)
{
for (int j = 1; j <= M; j++)
{
dp[i][j] = MIN;
}
}
cout << Solve(N, M);
return 0;
}