알고리즘/그래프 알고리즘

백준 5052 - 전화번호 목록

hvv_an 2025. 5. 14. 21:58

문제 설명

개의 노드로 이루어진 트리가 주어지고 M개의 두 노드 쌍을 입력받을 때 두 노드 사이의 거리를 출력하라.

https://www.acmicpc.net/problem/1240


 

 

 

 

 

 

제한 사항


 

 

 

 

 

 

풀이

문제를 요약하면, N개의 노드와 N-1개의 간선으로 연결된 트리가 있고 두 노드가 주어졌을 때 두 노드 사이의 거리를 구하면 된다.

 

두 노드 사이의 거리를 구하는 건 사실 bfs를 통해 순회하면 된다.

하지만 더 효율적인 방법이 있다.

그 방법은 부모가 같을 때까지 한 칸씩 올려보며 부모가 같아질 때 그 노드에서부터 시작 노드까지의 거리를 구해서 더하면 된다.

일반적인 방법은 노드의 depth를 맞추고 한칸씩 올리는 것이다.

하지만, depth를 또 저장하고 싶지 않아 다른 방법을 썼다.

hashset을 이용하여 각 경로의 방문 노드들을 저장했다.

만약 왼쪽 노드의 이동 경로중 하나를 오른쪽 노드가 지난다면 그 노드가 두 노드의 공통 조상이 된다.

반대도 마찬가지다.

 

즉, 임의의 노드를 시작으로 트리 구조를 만들어 가중치를 누적하여 저장한다.

void dfs(int start, int cost)
{
    for (auto [next, c] : adj[start])
    {
        if (tree[next].first != 0) continue;

        tree[next] = { start, cost + c };
        dfs(next, cost + c);
    }
}

이때, 부모 노드와 누적 가중치를 저장하면 나중에 쉽게 계산할 수 있다.

 

이후, 주어진 두 노드를 한 칸씩 부모쪽으로 이동시킨다.

만약 이동중 두 노드가 일치한다면 해당 노드가 공통 조상 노드가 된다.

그게 아니라 어느 한쪽이 다른 한쪽의 경로에 포함된다면 그 노드에 대해 가중치를 구하면 된다.

for (int i = 0; i < M; i++)
{
    int a, b;
    cin >> a >> b;

    int left = a;
    int right = b;
    unordered_set<int> leftPath;
    unordered_set<int> rightPath;

    leftPath.insert(left);
    rightPath.insert(right);

    while (left != right)
    {
        left = tree[left].first;
        right = tree[right].first;

        if (left == right)
        {
            cout << (tree[a].second - tree[left].second) + (tree[b].second - tree[right].second) << "\n";
            break;
        }

        if (rightPath.count(left) > 0)
        {
            cout << (tree[a].second - tree[left].second) + (tree[b].second - tree[left].second) << "\n";
            break;
        }

        if (leftPath.count(right) > 0)
        {
            cout << (tree[a].second - tree[right].second) + (tree[b].second - tree[right].second) << "\n";
            break;
        }

        if(left != 1) leftPath.insert(left);
        if(right != 1) rightPath.insert(right);
    }
}

 

 

 

 

 

전체 코드

#include <bits/stdc++.h>
#include <unordered_set>
using namespace std;
#define INPUT_OPTIMIZE cin.tie(NULL); cout.tie(NULL); ios::sync_with_stdio(false);
#define INF 2e9

using namespace std;
int N, M;
vector<vector<pair<int, int>>> adj;
vector<pair<int, int>> tree;

void dfs(int start, int cost)
{
    for (auto [next, c] : adj[start])
    {
        if (tree[next].first != 0) continue;

        tree[next] = { start, cost + c };
        dfs(next, cost + c);
    }
}

int main() 
{
    INPUT_OPTIMIZE;

    cin >> N >> M;

    adj.resize(N + 1);
    tree.resize(N + 1);

    for (int i = 0; i < N - 1; i++)
    {
        int a, b, c;
        cin >> a >> b >> c;

        adj[a].push_back({ b,c });
        adj[b].push_back({ a,c });
    }

    //루트는 1
    tree[1] = { 1, 0 };
    dfs(1, 0);

    for (int i = 0; i < M; i++)
    {
        int a, b;
        cin >> a >> b;

        int left = a;
        int right = b;
        unordered_set<int> leftPath;
        unordered_set<int> rightPath;

        leftPath.insert(left);
        rightPath.insert(right);

        while (left != right)
        {
            left = tree[left].first;
            right = tree[right].first;

            if (left == right)
            {
                cout << (tree[a].second - tree[left].second) + (tree[b].second - tree[right].second) << "\n";
                break;
            }

            if (rightPath.count(left) > 0)
            {
                cout << (tree[a].second - tree[left].second) + (tree[b].second - tree[left].second) << "\n";
                break;
            }

            if (leftPath.count(right) > 0)
            {
                cout << (tree[a].second - tree[right].second) + (tree[b].second - tree[right].second) << "\n";
                break;
            }

            if(left != 1) leftPath.insert(left);
            if(right != 1) rightPath.insert(right);
        }
    }

    return 0;
}