ABOUT ME

-

Today
-
Yesterday
-
Total
-
  • [Python 3] BOJ - 17940 "지하철"
    BOJ 2021. 11. 1. 19:43

     # 문제 링크

    https://www.acmicpc.net/problem/17940

     

    17940번: 지하철

    대학원생인 형욱이는 연구실에 출근할 때 주로 지하철을 이용한다. 지하철은 A와 B, 두 개의 회사에서 운영하고 있다. 두 회사는 경쟁사 관계로 사람들이 상대 회사의 지하철을 이용하는 것을 매

    www.acmicpc.net

     

     # 풀이

     개인적으로 이 문제의 핵심은 0-1 BFS와 2차원 다익스트라라고 생각한다. 0-1 BFS란 그래프의 간선 가중치가 0 또는 1로만 구성되어 있을 때 최단 경로를 찾는 BFS이다. 일반 BFS와의 차이점은 간선이 0인 경로는 큐의 앞에 넣음으로서 최대한 간선의 가중치가 작도록 움직이는 것이다. 

     문제에서 회사가 다른 두 정점은 환승이 필요하고, 두 정점 사이의 환승 횟수는 같은 회사라면 0, 다른 회사라면 1이기 때문에 0-1 BFS의 구조를 만족한다. 따라서 0-1 BFS를 통해 시작점에서 끝점까지의 최소 환승 횟수를 구할 수 있다.

     그런데 문제에서는 최소 환승 횟수 경로에서도 가장 작은 비용으로 가야하기 때문에 다익스트라의 상태 공간을 현재 idx번째 정점이고 환승 횟수가 cnt일 때 걸리는 최소 비용인 d[idx][cnt]로 정의하면 정답을 구할 수 있다

     

     # 코드

    import sys, collections, heapq
    
    # bfs : 시작점 0에서 출발하여 끝점 m까지 도달하는데 필요한 최소 환승 수를 리턴하는 함수
    def bfs():
        q = collections.deque()
        vis = [-1] * n
        q.append(0)
        vis[0] = 0
        while q:
            now  = q.popleft()
            for (next, next_dist) in adj1[now]:
                if vis[next] == -1:
                    # 만약 회사가 달라 환승이 필요하면 맨 마지막에 넣는다
                    if arr[next] != arr[now]:
                        vis[next] = vis[now] + 1
                        q.append(next)
                    # 만약 회사가 같아서 환승이 필요없다면 맨 처음에 넣는다
                    else:
                        vis[next] = vis[now]
                        q.appendleft(next)
        return vis[m]
    
    # dijkstra : 특정 환승 횟수인 num이하인 경로 중 끝점 m에 대한 최단거리를 리턴하는 함수
    def dijkstra(num):
        inf = 9876543210
        
        # d : 현재 idx번째 정점이고 현재까지 환승 수가 cnt일 때 최단거리를 저장하는 2차원 상태 배열
        d = [[inf] * (num + 1) for _ in range(n)]
        d[0][0] = 0
        q = []
        # 우선순위 큐에는 (현재 거리, 현재 환승 수, 현재 정점)을 넣는다
        heapq.heappush(q, (0, 0, 0))
        while q:
            cur_dist, cur_cnt, cur = heapq.heappop(q)
            if d[cur][cur_cnt] < cur_dist:
                continue
            for (next, next_dist) in adj2[cur]:
                if arr[next] != arr[cur]:
                    if cur_cnt + 1 <= num:
                        if d[next][cur_cnt + 1] > cur_dist + next_dist:
                            d[next][cur_cnt + 1] = cur_dist + next_dist
                            heapq.heappush(q, (cur_dist + next_dist, cur_cnt + 1, next))
                else:
                    if d[next][cur_cnt] > cur_dist + next_dist:
                        d[next][cur_cnt] = cur_dist + next_dist
                        heapq.heappush(q, (cur_dist + next_dist, cur_cnt, next))
        return d[m][num]
    
    # 입력부
    n, m = map(int, sys.stdin.readline().split())
    arr = [0] * n
    for i in range(n):
        arr[i] = int(sys.stdin.readline())
    temp = [list(map(int, sys.stdin.readline().split())) for _ in range(n)]
    
    # adj1 : 0-1 BFS 인접리스트
    adj1 = [[] for _ in range(n)]
    # adj2 : 다익스트라 인접리스트
    adj2 = [[] for _ in range(n)]
    for i in range(n):
        for j in range(i + 1, n):
            if temp[i][j]:
                adj2[i].append((j, temp[i][j]))
                adj2[j].append((i, temp[i][j]))
                if arr[i] != arr[j]:
                    adj1[i].append((j, 1))
                    adj1[j].append((i, 1))
                else:
                    adj1[i].append((j, 0))
                    adj1[j].append((i, 0))
                    
    # 정답 출력
    val = bfs()
    res = dijkstra(val)
    print(val, res)

    댓글

Designed by Tistory.