문제 설명
개의 노드로 이루어진 트리가 주어지고 M개의 두 노드 쌍을 입력받을 때 두 노드 사이의 거리를 출력하라.
https://www.acmicpc.net/problem/1240
제한 사항


풀이
문제를 요약하면, N개의 노드와 N-1개의 간선으로 연결된 트리가 있고 두 노드가 주어졌을 때 두 노드 사이의 거리를 구하면 된다.
두 노드 사이의 거리를 구하는 건 사실 bfs를 통해 순회하면 된다.
하지만 더 효율적인 방법이 있다.
그 방법은 부모가 같을 때까지 한 칸씩 올려보며 부모가 같아질 때 그 노드에서부터 시작 노드까지의 거리를 구해서 더하면 된다.
일반적인 방법은 노드의 depth를 맞추고 한칸씩 올리는 것이다.
하지만, depth를 또 저장하고 싶지 않아 다른 방법을 썼다.
hashset을 이용하여 각 경로의 방문 노드들을 저장했다.
만약 왼쪽 노드의 이동 경로중 하나를 오른쪽 노드가 지난다면 그 노드가 두 노드의 공통 조상이 된다.
반대도 마찬가지다.
즉, 임의의 노드를 시작으로 트리 구조를 만들어 가중치를 누적하여 저장한다.
void dfs(int start, int cost)
{
for (auto [next, c] : adj[start])
{
if (tree[next].first != 0) continue;
tree[next] = { start, cost + c };
dfs(next, cost + c);
}
}
이때, 부모 노드와 누적 가중치를 저장하면 나중에 쉽게 계산할 수 있다.
이후, 주어진 두 노드를 한 칸씩 부모쪽으로 이동시킨다.
만약 이동중 두 노드가 일치한다면 해당 노드가 공통 조상 노드가 된다.
그게 아니라 어느 한쪽이 다른 한쪽의 경로에 포함된다면 그 노드에 대해 가중치를 구하면 된다.
for (int i = 0; i < M; i++)
{
int a, b;
cin >> a >> b;
int left = a;
int right = b;
unordered_set<int> leftPath;
unordered_set<int> rightPath;
leftPath.insert(left);
rightPath.insert(right);
while (left != right)
{
left = tree[left].first;
right = tree[right].first;
if (left == right)
{
cout << (tree[a].second - tree[left].second) + (tree[b].second - tree[right].second) << "\n";
break;
}
if (rightPath.count(left) > 0)
{
cout << (tree[a].second - tree[left].second) + (tree[b].second - tree[left].second) << "\n";
break;
}
if (leftPath.count(right) > 0)
{
cout << (tree[a].second - tree[right].second) + (tree[b].second - tree[right].second) << "\n";
break;
}
if(left != 1) leftPath.insert(left);
if(right != 1) rightPath.insert(right);
}
}
전체 코드
#include <bits/stdc++.h>
#include <unordered_set>
using namespace std;
#define INPUT_OPTIMIZE cin.tie(NULL); cout.tie(NULL); ios::sync_with_stdio(false);
#define INF 2e9
using namespace std;
int N, M;
vector<vector<pair<int, int>>> adj;
vector<pair<int, int>> tree;
void dfs(int start, int cost)
{
for (auto [next, c] : adj[start])
{
if (tree[next].first != 0) continue;
tree[next] = { start, cost + c };
dfs(next, cost + c);
}
}
int main()
{
INPUT_OPTIMIZE;
cin >> N >> M;
adj.resize(N + 1);
tree.resize(N + 1);
for (int i = 0; i < N - 1; i++)
{
int a, b, c;
cin >> a >> b >> c;
adj[a].push_back({ b,c });
adj[b].push_back({ a,c });
}
//루트는 1
tree[1] = { 1, 0 };
dfs(1, 0);
for (int i = 0; i < M; i++)
{
int a, b;
cin >> a >> b;
int left = a;
int right = b;
unordered_set<int> leftPath;
unordered_set<int> rightPath;
leftPath.insert(left);
rightPath.insert(right);
while (left != right)
{
left = tree[left].first;
right = tree[right].first;
if (left == right)
{
cout << (tree[a].second - tree[left].second) + (tree[b].second - tree[right].second) << "\n";
break;
}
if (rightPath.count(left) > 0)
{
cout << (tree[a].second - tree[left].second) + (tree[b].second - tree[left].second) << "\n";
break;
}
if (leftPath.count(right) > 0)
{
cout << (tree[a].second - tree[right].second) + (tree[b].second - tree[right].second) << "\n";
break;
}
if(left != 1) leftPath.insert(left);
if(right != 1) rightPath.insert(right);
}
}
return 0;
}