오늘 리뷰해볼 논문은 An Information-Geometric Distance on the Space of Tasks(Gao et al., 2021 ICML accepted)입니다. 참고로 내용이 조금 어렵습니다. 저 또한 이쪽에 익숙하지 않아서 그런지 정확한 맥락을 짚어내기가 좀 힘들었고, 그래서 정확하지 않은 부분이 있을 수도 있습니다. 양해 부탁드립니다!
추가로, 원 논문에 포함되지 않은 많은 내용들을 제 해석으로 담았습니다. 제가 이해하기 쉽도록 포인트를 잡아서 간략하게 흐름만 짚어보려고 했는데, 틀린 내용은 아닐지 걱정이 되네요...ㅠ
Intro:
최근 딥러닝에서 핫한 이슈 중 하나가 바로 meta-learning입니다. 단순히 transfer-learning을 넘어서 generalization에 대한 보다 깊은 이해와 다양한 접근 방식을 바탕으로 아키텍쳐와 데이터 자체에 대해 풀어나가고자 하는 분야인데요, 오늘의 논문도 따져보자면 meta-learning에 기여하는 내용이라고 볼 수 있겠습니다.
간단히 핵심만 요약하면 'Task A를 배운 모델이 Task B를 배우는 것은 어느 정도 힘든 일일지 계량화할 수 있을까?' 정도로 말해볼 수 있겠네요. 예를 들어 고양이와 강아지를 분류하는 것을 배운 로봇이 강아지와 호랑이를 구분하는 것은 상대적으로 쉬울 것입니다. 그러나 침엽수와 활엽수를 구분하라고 하면 어렵겠죠. 이때 해당 Task들 간의 난이도를 거리로 표현하고자 하는 노력을 담아낸 논문입니다.
문제를 좀 더 formal하게 정의하면 Task A를 통해 학습한 지식인 p_A(y | x)에서 출발하여 Task B를 학습한 지식인 p_B(y | x)에 도달하기까지의 거리를 찾고자 하는 것입니다.
그럼 이제 할 일은 Task A의 p(y | x)와 Task B의 p(y | x)가 얼마나 멀리 떨어져 있는지를 거리로 나타낼 차례입니다. 그러나 도달하는 과정에서 중간 지점은 직접 데이터로 관찰할 수 없는 지점이기 때문에 어느 길로 가야하는지 알 수가 없습니다. 이해를 위해 그림을 볼까요?
A에서 출발해 B에 도착한다는 야심찬 계획은 좋았지만, 해당 문제를 해결하기 위해 알아야 하는 p(x)의 형태를 파악하기가 굉장히 어렵습니다. 사진의 경우 픽셀로 나누어져 데이터로 저장되기 때문에, 같은 강아지 그림이라도 데이터로 표현하면 천차만별이기 일쑤입니다. 이는 본 논문에서는 OT(Optimal Transportation)을 이용해 해결합니다.
또한, 우리가 만들고자 하는 metric은 단지 얼마나 멀리 떨어져있는지를 기준으로만 계산해서는 안됩니다. '원래의 모델에서 학습한 weights가 target probability에 도달하기 얼마나 어려운지'를 포함해야 하기 때문에 확률 공간에서의 movement를 계산해야 합니다. 이는 Fisher-Rao metric을 통해 해결합니다.
궁극적으로 우리가 얻고자 하는 것은 p(y | x)간의 거리인 Fisher-Rao distance입니다. 그런데 A라는 Task도 우리 손에 있고 B라는 Task도 우리 손에 있지만 중간 지점의 지형은 아예 모르기 때문에 거리를 측정하기 어렵습니다. 그래서 중간 지점의 지형을 근사시켜주기 위해 OT를 사용하는 것입니다. 즉, OT를 이용해 Fisher-Rao distance을 구하여 Task간의 거리(전이학습의 난이도라고 이해해도 좋을 것 같습니다)를 구하는 것이 목표입니다.
Contents:
이제야 서론이 끝났네요. OT와 Fisher-Rao metric을 하나씩 살펴보도록 하겠습니다. 먼저 Fisher-Rao metric을 살펴보죠. 본 논문에서는 Task를 이렇게 정의합니다: 'Classification problem을 SGD를 사용해 Cross-entropy loss를 기반으로 optimization하는 과정'
우리는 Cross-entropy를 minimize하는 것은 결국 KL-divergence를 minimize하는 것과 같음을 알고 있습니다. 우리가 궁금한 것은 이제 Task A를 기반으로 학습한 weights가 Task B를 학습하기 위해 얼마나 빠른 속도로 weight이 수렴하는지입니다. 위 사진에서 보면 p_ws(y | x)가 p_wt(y | x)로 향해 가는 과정을 나타내고 있음을 볼 수 있죠. 가장 빠른 루트를 선택하여 해당 루트를 얼마만큼의 속도로 따라갔을때의 시간이 결국 난이도를 나타낸다고 이해하면 되겠습니다.
이는 Information Geometry에서의 문제 세팅이랑 굉장히 유사하고, 실제로 KL-divergence를 이용한 metric이 존재하기 때문에 활용해볼 수 있는데, 바로 Fisher-Rao metric이 여기서 등장합니다.
꽤 복잡하기 때문에 개념만 간단하게 잡고 가자면, 다음과 같이 설명해볼 수 있겠습니다.
자, 이제 Fisher-Rao distance를 통해 p_ws(y | x)를 p_wt(y | x)로 보내는 난이도를 수치화하였습니다. 논리를 짧게 요약하면 '모델은 Cross-entropy를 바탕으로 minimize하기 때문에 KL-divergence를 minimize하는 방향을 따라 선적분하여 걸리는 시간을 거리로 나타낼 수 있다'정도가 되겠네요.
그럼 이제 선적분을 따라 거리를 계산하기 위해 필요한 중간지점을 어떻게 구하는지 볼까요? 본 논문에서는 OT(Optimal Transportation)을 사용합니다. OT를 한눈에 쉽게 이해할 수 있는 그림이 있길래 가져와보았습니다.
(출처: wikipedia)
x라는 확률분포를 y라는 확률분포로 보내는 행렬을 통해 변환이 가능합니다. 그런데, 단순히 변환을 하는 행렬을 찾는 것이 아니라 변환 간에 비용이 존재할 때 이를 반영해 최소화하는 것을 목표로 하는 방법론이 OT입니다.
OT는 원래 Transportation 문제를 풀기 위해 고안된 방법인데, 확률론 쪽에서도 많이 쓰인다고 합니다. 개념 자체는 n개의 금광과 m개의 공장이 있는데, 이 때 n개의 금광에서 m개의 공장으로 어떻게 금을 실어 날라야 최소한의 비용으로 금을 이쁘게 분배할 수 있을지를 고민하는 것으로부터 시작했습니다. 문제를 정의해보면 아래와 같이 정의됩니다.
감마가 행렬이고, C가 비용함수, H는 entropy를 나타냅니다. 자세한 내용은 Transportation theory를 참고해보셔도 좋을 것 같습니다.
우리 문제에서는 비용은 위에서 구한 Fisher-Rao distance가 되겠네요. 이 Fisher-Rao distance를 최소화하는 방향으로 p(x)를 변환하여 중간 지점에서도 Fisher-Rao distance를 계산할 수 있도록 사용됩니다. 다음과 같은 방식으로 말이죠!
델타 기호는 크로네커 델타 함수입니다. τ는 A에서 B로 이동한 정도를 나타냅니다. 이런 방식으로 통해 중간 지점에서도 p(x)를 계산할 수 있습니다. 이제, p(x) * p(y | x) = p(x, y)를 구할 수 있고, 이를 바탕으로 중간지점에서의 데이터를 생성하여 Fisher-Rao distance를 구하는 여정을 이어갈 수 있게 됩니다. 최종적인 알고리즘을 요약하면 다음과 같습니다:
위 알고리즘은 반복을 통해 최소의 FR distance를 갖는 w를 찾아가는 과정입니다. 이해하기 쉬운 순서는 (d) - (e) - (c) - (b) - (a) 로 생각하시면 될 것 같습니다.
댓글 없음:
댓글 쓰기