문제 설명
n개의 점으로 이루어진 트리가 있습니다. 이때, 트리 상에서 다음과 같은 것들을 정의합니다.
어떤 두 점 사이의 거리는, 두 점을 잇는 경로 상 간선의 개수로 정의합니다.
임의의 3개의 점 a, b, c에 대한 함수 f(a, b, c)의 값을 a와 b 사이의 거리, b와 c 사이의 거리, c와 a 사이의 거리, 3개 값의 중간값으로 정의합니다.
트리의 정점의 개수 n과 트리의 간선을 나타내는 2차원 정수 배열 edges가 매개변수로 주어집니다. 주어진 트리에서 임의의 3개의 점을 뽑아 만들 수 있는 모든 f값 중에서, 제일 큰 값을 구해 return 하도록 solution 함수를 완성해주세요.
https://school.programmers.co.kr/learn/courses/30/lessons/68937#
제한 사항
- n은 3 이상 250,000 이하입니다.
- edges의 행의 개수는 n-1 입니다.
- edges의 각 행은 [v1, v2] 2개의 정수로 이루어져 있으며, 이는 v1번 정점과 v2번 정점 사이에 간선이 있음을 의미합니다.
- v1, v2는 각각 1 이상 n 이하입니다.
- v1, v2는 다른 수입니다.
- 입력으로 주어지는 그래프는 항상 트리입니다.
풀이
문제를 요약하면, 트리에서 임의의 3개의 노드를 선택하여 선택한 노드사이의 거리의 중앙값을 최대화하는 것이다.
우선, 시간 복잡도를 고려해 보자.
임의의 노드 3개를 선택하는 상황만 놓고 봐도 ${}_n \mathrm{ C }_3$으로 꽤 큰 시간이 필요하다.
만약, 이 과정이 시간 안에 이루어졌다고 하더라도 a와 b, b와 c, a와 c사이의 거리를 구하는 과정이 추가로 필요하기 때문에 시간안에 해결할 수 없을 것이다.
즉, 다른 접근법이 필요하다.
우선, 중간값이라는 조건을 생각해 보자.
중간값은 3개의 값 중 둘째로 큰 수가 된다.
만약, 가장 큰 값이 동일하다면 가장 큰 수가 중간값이 될 수 있다.
예를 들어, {1, 10, 10}의 경우 중간값은 10이다.
그렇다면 a, b, c 세 개의 노드사이의 거리의 중간값이 최대가 되려면 트리에서 가장 먼 노드들을 선택하면 유리하다고 생각할 수 있다.
트리에서 가장 먼 두 노드의 거리는 트리의 지름이라고 한다.
만약, 트리의 지름을 구한다면 다음 두 가지 상황을 생각할 수 있다.
- 트리의 지름을 이루는 노드의 쌍이 여러 개일 경우
- 트리의 지름이 이루는 노드의 쌍이 유일할 경우
우선, 노드의 쌍이 여러 개일 경우를 생각해 보자.
해당 트리의 지름은 (1, 2), (1, 3), (1, 4)... 등으로 구할 수 있다.
이 중, (1, 3)으로 지름을 이루는 두 노드 a, b를 선택했다고 가정해 보자.
그렇다면 {2, 4, 5} 중 남은 노드 c로 선택해야 한다.
이때 5를 선택한다고 하면 각 노드 사이의 거리가 {1, 1, 2}로 중간값은 1이 된다.
하지만, 3과 마찬가지로 1과 트리의 지름을 이루는 2를 선택한다고 하면 (2, 2, 2)로 중간값은 2가 된다.
즉, 트리의 지름을 이루는 노드의 쌍이 여러 개인 경우에는 정답이 트리의 지름이 된다.
그렇다면 트리의 지름을 이루는 노드가 유일한 경우를 생각해 보자.
트리의 지름을 이루는 노드쌍은 (1, 4)로 유일하다.
이때, 중간값을 크게 하기 위해서는 1과 4에서 1만큼 떨어진 노드를 c로 선택하는 것이 유리하다.
예를 들어, 2번 노드를 c로 선택한 경우 거리를 (1, 2, 3)으로 중간값은 2가 된다.
3번 노드를 선택해도 결과는 마찬가지이다.
즉, 트리의 지름을 이루는 노드쌍이 유일한 경우 트리의 지름에서 1을 뺀 값이 중간값의 최댓값이 된다.
정리하면, 트리의 지름과 트리의 지름을 이루는 노드 쌍의 개수를 구한다면 정답을 구할 수 있다.
트리의 지름을 구하는 방법은 dfs, bfs를 통해 쉽게 구할 수 있다.
임의의 노드를 선택한 뒤, 해당 노드에서 가장 먼 노드를 구하고 그 노드에서 가장 먼 노드를 구하면 두 노드는 트리의 지름을 이룬다.
예를 들어, (4, 7)이 트리의 지름을 이루는 노드의 쌍이다.
임의의 노드 1번에서 가장 먼 노드를 구하면 {4, 6}이다.
둘 중 아무거나 선택해도 상관없다.
만약, 4번을 선택했다고 하면 4에서 가장 먼 노드는 {1, 7}이 된다.
즉, 위와 같은 경우는 트리의 지름을 이루는 노드의 쌍이 여러 개인 경우이며 {(1,4), (1,6), (4, 1), (4, 7)}이 될 수 있다.
이를 문제에 적용한다면 트리의 지름을 이루는 노드의 쌍이 여러 개이기 때문에 트리의 지름인 4가 정답이 된다.
이를 구현하면 다음과 같다.
struct BFSResult
{
BFSResult(int d, int i, int c) : dist(d), idx(i), cnt(c) {};
int dist;
int idx;
int cnt;
};
BFSResult BFS(int start)
{
queue<int> q;
vector<bool> visited(N+1, false);
q.push(start);
visited[start] = true;
int dist = 0;
int farNode = 0;
int cnt = 0;
while(!q.empty())
{
cnt = q.size();
bool bEnd = true;
for(int i = 0 ; i < cnt; i++)
{
int current = q.front();
q.pop();
for(auto next : adj[current])
{
if(visited[next]) continue;
visited[next] = true;
farNode = next;
bEnd = false;
q.push(next);
}
}
if(!bEnd) dist++;
}
return BFSResult(dist, farNode, cnt);
}
BFS를 통해 거리, 가장 먼 노드, 같은 거리의 노드의 개수를 구하고 반환한다.
이제 BFS를 호출하는 방법만 구현하면 된다.
결론부터 말하자면 BFS를 총 세 번 호출한다.
- 임의의 시작점(1번 노드)을 기준으로 가장 멀리 있는 트리의 지름을 이루는 노드(a)를 찾는다.
- 1번에서 찾은 노드를 기준으로 가장 멀리 있는 또 다른 트리의 지름을 이루는 노드(b)를 찾는다.
- 2번에서 찾은 노드(b)를 기준으로 또 다른 지름을 이루는 노드가 있는지 탐색한다.
- 만약, 그러한 노드가 있다면 해당 노드가 c가 된다.
- 그렇지 않은 경우, a와 b가 유일한 트리의 지름이므로 답은 트리의 지름 - 1이다.
auto start = BFS(1); //1
auto leaf = BFS(start.idx); //2
if(leaf.cnt > 1) answer = leaf.dist;
else
{
auto last = BFS(leaf.idx); //3
if(last.cnt > 1) answer = last.dist;
else answer = last.dist - 1;
}
만약, 2번 과정에서 이미 지름을 이루는 노드가 여러 개인 경우에는 더 이상 진행할 필요가 없다.
전체 코드
#include <string>
#include <vector>
#include <bits/stdc++.h>
using namespace std;
vector<vector<int>> adj;
int N;
struct BFSResult
{
BFSResult(int d, int i, int c) : dist(d), idx(i), cnt(c) {};
int dist;
int idx;
int cnt;
};
BFSResult BFS(int start)
{
queue<int> q;
vector<bool> visited(N+1, false);
q.push(start);
visited[start] = true;
int dist = 0;
int farNode = 0;
int cnt = 0;
while(!q.empty())
{
cnt = q.size();
bool bEnd = true;
for(int i = 0 ; i < cnt; i++)
{
int current = q.front();
q.pop();
for(auto next : adj[current])
{
if(visited[next]) continue;
visited[next] = true;
farNode = next;
bEnd = false;
q.push(next);
}
}
if(!bEnd) dist++;
}
return BFSResult(dist, farNode, cnt);
}
int solution(int n, vector<vector<int>> edges) {
int answer = 0;
N = n;
adj.resize(n+1);
for(auto edge : edges)
{
adj[edge[0]].push_back(edge[1]);
adj[edge[1]].push_back(edge[0]);
}
auto start = BFS(1);
int leaf = start.idx;
auto startLeaf = BFS(leaf);
if(startLeaf.cnt > 1) answer = startLeaf.dist;
else
{
auto last = BFS(startLeaf.idx);
if(last.cnt > 1) answer = last.dist;
else answer = last.dist - 1;
}
return answer;
}