2021년 9월 20일 월요일

PR402 - 11: An Information-Geometric Distance on the Space of Tasks

 오늘 리뷰해볼 논문은 An Information-Geometric Distance on the Space of Tasks(Gao et al., 2021 ICML accepted)입니다. 참고로 내용이 조금 어렵습니다. 저 또한 이쪽에 익숙하지 않아서 그런지 정확한 맥락을 짚어내기가 좀 힘들었고, 그래서 정확하지 않은 부분이 있을 수도 있습니다. 양해 부탁드립니다!

추가로, 원 논문에 포함되지 않은 많은 내용들을 제 해석으로 담았습니다. 제가 이해하기 쉽도록 포인트를 잡아서 간략하게 흐름만 짚어보려고 했는데, 틀린 내용은 아닐지 걱정이 되네요...ㅠ


Intro:

최근 딥러닝에서 핫한 이슈 중 하나가 바로 meta-learning입니다. 단순히 transfer-learning을 넘어서 generalization에 대한 보다 깊은 이해와 다양한 접근 방식을 바탕으로 아키텍쳐와 데이터 자체에 대해 풀어나가고자 하는 분야인데요, 오늘의 논문도 따져보자면 meta-learning에 기여하는 내용이라고 볼 수 있겠습니다.

간단히 핵심만 요약하면 'Task A를 배운 모델이 Task B를 배우는 것은 어느 정도 힘든 일일지 계량화할 수 있을까?' 정도로 말해볼 수 있겠네요. 예를 들어 고양이와 강아지를 분류하는 것을 배운 로봇이 강아지와 호랑이를 구분하는 것은 상대적으로 쉬울 것입니다. 그러나 침엽수와 활엽수를 구분하라고 하면 어렵겠죠. 이때 해당 Task들 간의 난이도를 거리로 표현하고자 하는 노력을 담아낸 논문입니다. 

문제를 좀 더 formal하게 정의하면 Task A를 통해 학습한 지식인 p_A(y | x)에서 출발하여 Task B를 학습한 지식인 p_B(y | x)에 도달하기까지의 거리를 찾고자 하는 것입니다. 

그럼 이제 할 일은 Task A의 p(y | x)와 Task B의 p(y | x)가 얼마나 멀리 떨어져 있는지를 거리로 나타낼 차례입니다. 그러나 도달하는 과정에서 중간 지점은 직접 데이터로 관찰할 수 없는 지점이기 때문에 어느 길로 가야하는지 알 수가 없습니다. 이해를 위해 그림을 볼까요?


A에서 출발해 B에 도착한다는 야심찬 계획은 좋았지만, 해당 문제를 해결하기 위해 알아야 하는 p(x)의 형태를 파악하기가 굉장히 어렵습니다. 사진의 경우 픽셀로 나누어져 데이터로 저장되기 때문에, 같은 강아지 그림이라도 데이터로 표현하면 천차만별이기 일쑤입니다. 이는 본 논문에서는 OT(Optimal Transportation)을 이용해 해결합니다.

또한, 우리가 만들고자 하는 metric은 단지 얼마나 멀리 떨어져있는지를 기준으로만 계산해서는 안됩니다. '원래의 모델에서 학습한 weights가 target probability에 도달하기 얼마나 어려운지'를 포함해야 하기 때문에 확률 공간에서의 movement를 계산해야 합니다. 이는 Fisher-Rao metric을 통해 해결합니다.

궁극적으로 우리가 얻고자 하는 것은 p(y | x)간의 거리인 Fisher-Rao distance입니다. 그런데 A라는 Task도 우리 손에 있고 B라는 Task도 우리 손에 있지만 중간 지점의 지형은 아예 모르기 때문에 거리를 측정하기 어렵습니다. 그래서 중간 지점의 지형을 근사시켜주기 위해 OT를 사용하는 것입니다. 즉, OT를 이용해 Fisher-Rao distance을 구하여 Task간의 거리(전이학습의 난이도라고 이해해도 좋을 것 같습니다)를 구하는 것이 목표입니다.

Contents:

이제야 서론이 끝났네요. OT와 Fisher-Rao metric을 하나씩 살펴보도록 하겠습니다. 먼저 Fisher-Rao metric을 살펴보죠. 본 논문에서는 Task를 이렇게 정의합니다: 'Classification problem을 SGD를 사용해 Cross-entropy loss를 기반으로 optimization하는 과정'

우리는 Cross-entropy를 minimize하는 것은 결국 KL-divergence를 minimize하는 것과 같음을 알고 있습니다. 우리가 궁금한 것은 이제 Task A를 기반으로 학습한 weights가 Task B를 학습하기 위해 얼마나 빠른 속도로 weight이 수렴하는지입니다. 위 사진에서 보면 p_ws(y | x)가 p_wt(y | x)로 향해 가는 과정을 나타내고 있음을 볼 수 있죠. 가장 빠른 루트를 선택하여 해당 루트를 얼마만큼의 속도로 따라갔을때의 시간이 결국 난이도를 나타낸다고 이해하면 되겠습니다. 

이는 Information Geometry에서의 문제 세팅이랑 굉장히 유사하고, 실제로 KL-divergence를 이용한 metric이 존재하기 때문에 활용해볼 수 있는데, 바로 Fisher-Rao metric이 여기서 등장합니다. 

꽤 복잡하기 때문에 개념만 간단하게 잡고 가자면, 다음과 같이 설명해볼 수 있겠습니다.


위는 우리가 아는 통상적인 KL-divergence입니다. Riemmanian manifold 위에서 pw가 변화하는 순간속도를 다음과 같이 나타낼 수 있습니다.
위에서 g는 FIM(Fisher Information Matrix)를 뜻하고, KL-divergence의 Hessian Matrix로도 이해할 수 있습니다. FIM은 아래의 수식으로 표현됩니다.

확률공간 위에서의 manifold theory를 기반으로 하고 있어 통상적으로 알고 있던 Hessian과는 모습이 살짝 다르게 적분이 들어가 있음을 확인할 수 있습니다. 

우리가 궁금한 것은 w'을 향해서 이동할 때의 최단 거리 및 속도였는데요, 이는 아래의 Fisher-Rao distance로 정의됩니다. 
새로 정의한 w라는 선을 따라 선적분을 한다고 생각하시면 편할 것 같습니다. 물론 w를 정의함에 있어 최단거리라는 조건을 걸기 위해 min이 들어간 것입니다. 그러나 위는 w를 parameter로 갖는 함수여야 하기 때문에(즉, 모델의 파라미터가 영향을 끼치는 함수) p_w(y | x) (간단히 p(y | x)로 생각하셔도 됩니다.)에 대해서 구할 수 있고, 이는 나중에 실제 계산 단계에서는 근사됩니다. 

자, 이제 Fisher-Rao distance를 통해 p_ws(y | x)를 p_wt(y | x)로 보내는 난이도를 수치화하였습니다. 논리를 짧게 요약하면 '모델은 Cross-entropy를 바탕으로 minimize하기 때문에 KL-divergence를 minimize하는 방향을 따라 선적분하여 걸리는 시간을 거리로 나타낼 수 있다'정도가 되겠네요.

그럼 이제 선적분을 따라 거리를 계산하기 위해 필요한 중간지점을 어떻게 구하는지 볼까요? 본 논문에서는 OT(Optimal Transportation)을 사용합니다. OT를 한눈에 쉽게 이해할 수 있는 그림이 있길래 가져와보았습니다.

(출처: wikipedia)

x라는 확률분포를 y라는 확률분포로 보내는 행렬을 통해 변환이 가능합니다. 그런데, 단순히 변환을 하는 행렬을 찾는 것이 아니라 변환 간에 비용이 존재할 때 이를 반영해 최소화하는 것을 목표로 하는 방법론이 OT입니다. 

OT는 원래 Transportation 문제를 풀기 위해 고안된 방법인데, 확률론 쪽에서도 많이 쓰인다고 합니다. 개념 자체는 n개의 금광과 m개의 공장이 있는데, 이 때 n개의 금광에서 m개의 공장으로 어떻게 금을 실어 날라야 최소한의 비용으로 금을 이쁘게 분배할 수 있을지를 고민하는 것으로부터 시작했습니다. 문제를 정의해보면 아래와 같이 정의됩니다.

감마가 행렬이고, C가 비용함수, H는 entropy를 나타냅니다. 자세한 내용은 Transportation theory를 참고해보셔도 좋을 것 같습니다.

우리 문제에서는 비용은 위에서 구한 Fisher-Rao distance가 되겠네요. 이 Fisher-Rao distance를 최소화하는 방향으로 p(x)를 변환하여 중간 지점에서도 Fisher-Rao distance를 계산할 수 있도록 사용됩니다. 다음과 같은 방식으로 말이죠!

델타 기호는 크로네커 델타 함수입니다. τ는 A에서 B로 이동한 정도를 나타냅니다. 이런 방식으로 통해 중간 지점에서도 p(x)를 계산할 수 있습니다. 이제, p(x) * p(y | x) = p(x, y)를 구할 수 있고, 이를 바탕으로 중간지점에서의 데이터를 생성하여 Fisher-Rao distance를 구하는 여정을 이어갈 수 있게 됩니다. 최종적인 알고리즘을 요약하면 다음과 같습니다:

위 알고리즘은 반복을 통해 최소의 FR distance를 갖는 w를 찾아가는 과정입니다. 이해하기 쉬운 순서는 (d) - (e) - (c) - (b) - (a) 로 생각하시면 될 것 같습니다.
(d) 이전 단계에서 생성된 OT 행렬을 바탕으로 joint distribution을 구하고, (e) 구해진 joint distribution으로부터 샘플을 생성합니다. (c) 생성된 샘플을 바탕으로 최소 거리를 갖는 trajectory를 생성하고, (d) 생성된 trajectory를 따라 FR distance를 계산합니다. (a) 구해진 FR distance를 기반으로 OT 행렬을 최적화합니다.

Experiment:


Task간의 distance를 나타낸 지표입니다. Asymmetric한 것은 어느 정도 당연한데, 예를 들어 CIFAR100 데이터를 학습한 모델이 CIFAR10을 학습하는 것이 그 반대보다 훨씬 쉽기 때문입니다. 위 히트맵을 설명하자면, (a)는 원 논문을 적용한 것, (b)는 Task2Vec라는 알고리즘을 적용한 것, (c)는 transfer learning에서 fine-tuning을 통해 도달하는 방식을 통해 측정한 것입니다. (a)의 4행 1열을 보면 0.31로 검게 칠해져 있는데, 이는 vehicles를 학습한 모델이 CIFAR 데이터를 학습하는데까지의 거리가 0.31이라는 의미입니다. 각 알고리즘 간 숫자는 비교불가능하기 때문에 숫자의 크기보다는 색깔에 집중해서 보는 것이 좋습니다. (b)의 경우 애초에 symmetric할 뿐더러 다른 모델들이 CIFAR100 데이터를 학습하는게 그리 어렵지 않다고 나오는 것으로 봐서 좋은 distance라고 할 수 없습니다. (c)의 경우 (b)보다는 낫지만 우리가 얻고자 하는 직관은 전혀 제공해 주지 못합니다.


Conclusion:

설명이 길었네요.
저도 이해하는데 굉장히 오래걸렸고, 원 논문이 그렇게 친절하게 적혀있는 것도 아니어서 좀 더 어려웠던 것 같습니다. 물론 이해한 내용이 완벽하진 않고 아예 틀렸을 수도 있지만 최대한 쉽게 풀어보고자 노력했습니다.

재밌는 점은 Appendix 부분에 FAQ가 있어서 내용의 부가설명 및 본 논문을 향한 비판들을 방어하고 있습니다. 3번을 보면 '아니 이거 돌리려면 모델을 몇 번이나 돌려야 하는데, 그럼 이거 해서 쓸데가 있나? 전이학습을 위한 좋은 데이터셋 찾는 데는 도움이 되는건가? 재미는 있는데, 실전에 쓰기 너무 복잡하다.'라는 비판이 등장합니다. 저자는 이에 대해 'learning task간의 거리에 대한 이해도 및 직관을 얻을 수 있다.' 정도로 답하고, 실전성에 대해서는 부족함을 인정했습니다. 그러나 저자가 말한 'leave a lot on the table'에는 어느 정도 동감하는 편입니다. Bayesian도 처음에는 딥러닝에의 적용이 복잡했다가 앙상블이나 dropout, Batchnorm을 통한 근사법이 나온 것처럼 이 분야도 좀 더 간단한 방법론이 제시되고 많은 직관이 곁들여지게 되면 한 단계 발전된 알고리즘에 도달할 수 있을 거라고 생각합니다. 
 

댓글 없음:

댓글 쓰기