다익스트라 알고리즘을 복습하기 위해 문제 카테고리에서 제일 첫 번째에 위치한 문제를 풀러 들어갔다.
처음에는 다익스트라만 적용하면 간단하게 풀릴 문제라고 생각했다.
하지만...... 오만했던 나를 발견했다.....
거의 하루를 쏟아 부었다.....
몇 틀을 한 건지 모르겠다..... 15트만에 드디어 풀었다.
중간에 맞은 문제는 구글에 돌아다니는 코드를 붙여 넣어 봤다. 내 코드만 안 되는 줄 알고....
점심 나가버릴 거 같았다....
문제풀이는 간단했다.
노드의 수와 간선의 수, 시작 노드를 입력받고 간선들의 시작점, 도착점, 가중치를 입력받아 시작 노드에서 다른 노드들까지의 최단거리를 구하면 되는 문제이다.
나는 다익스트라 알고리즘을 이용하여 최단 경로를 구하려 했다.
다익스트라 알고리즘은 아직 방문하지 않은 노드 중 super set에서 갈 수 있는 최단 거리의 노드로 이동한 뒤,
그 노드 또한 super set에 포함시키며 모든 간선을 처리하면 된다.
처음에는 간선을 저장하는 자료구조를 시작, 도착, 가중치로 설정하는 Triple이란 자료구조를 구현해 사용하였으나 그다지 필요하지 않았다. 그리고 시작 노드에서 갈 수 있는 모든 노드를 구하기 위해서는 N번의 탐색을 거쳐야 하기 때문에 시간 복잡도도 좋지 않았다.
그래서 도착, 가중치만 저장하는 Pair를 구현한 뒤, Pair를 저장하는 ArrayList를 저장하는 ArrayList를 만들었다.
마치 마트로시카 같다.... ㅎㅎ
그렇게 만든 뒤, super set에서 갈 수 있는 노드들 중 거리를 줄일 수 있는 노드를 모두 줄이고 우선순위 큐에 넣어 가장 짧은 거리의 노드로 이동하여 super set을 점차 늘려갔다.
내 시간을 잡아먹은 요소는 여기에 있었다. 실수로 우선순위 큐에 넣을 때, 초기화된 가중치를 넣어야 한다.
즉, 시작 노드에서 도착 노드까지의 최단거리를 큐에 넣어야 하는데 그냥 가중치만 넣어서 잘못된 결과가 나왔다.
틀린 이유를 몰라 화가 많이 났지만.....
내가 다익스트라를 완전히 이해하지 못한 결과라고 생각한다.
부단히 노력해야겠다.
다음은 Java로 작성한 코드이다.
package graph;
import java.util.ArrayList;
import java.util.PriorityQueue;
import java.util.Scanner;
public class ShortestPathDijkstra {
public static void main(String[] args) {
//도착지, 가중치 저장 자료구조
class Pair implements Comparable<Pair>{
int a;
int b;
public Pair(int a, int b) {
this.a = a;
this.b = b;
}
public int getA() {
return a;
}
public int getB() {
return b;
}
@Override
public int compareTo(Pair o) {
return this.b - o.getB();
}
}
//무한대
final int INF = 987654321;
Scanner scan = new Scanner(System.in);
int nodeNum = scan.nextInt();
int edgeNum = scan.nextInt();
int startNode = scan.nextInt();
//graph
ArrayList<ArrayList<Pair>> edge = new ArrayList<ArrayList<Pair>>();
//방문 여부
boolean[] visted = new boolean[nodeNum+1];
//거리 저장
int[] distance = new int[nodeNum+1];
//우선 순위 큐
PriorityQueue<Pair> q = new PriorityQueue<Pair>();
//초기화
for (int i = 0; i <= nodeNum; i++) {
edge.add(new ArrayList<Pair>());
}
for (int i = 0; i < edgeNum; i++) {
int a = scan.nextInt();
int b = scan.nextInt();
int w = scan.nextInt();
edge.get(a).add(new Pair(b,w));
}
for( int i = 0 ; i <= nodeNum ; i++ ) {
distance[i] = INF;
visted[i] = false;
}
//시작 노드 거리 0으로 초기화
distance[startNode] = 0;
//큐에 추가
q.add(new Pair(startNode, 0));
//다익스트라
while(!q.isEmpty()) {
int a = q.poll().a;
if(visted[a]){
continue;
}
visted[a] = true;
for (Pair p : edge.get(a)) {
int b = p.getA();
int w = p.getB();
if(distance[b] > distance[a] + w) {
distance[b] = distance[a] + w;
//이부분에서 줄어든 거리를 넣어야 한다.
q.add(new Pair(b, distance[b]));
}
}
}
//출력
for( int i = 1 ; i <= nodeNum ; i++ ) {
if(!visted[i]) {
System.out.println("INF");
continue;
}
if(distance[i]==INF) {
System.out.println("INF");
}else {
System.out.println(distance[i]);
}
}
}
}