2021년 9월 2일 목요일

PR402 - 8: A Simple Baseline for Bayesian Uncertainty in Deep Learning (SWAG)

 오늘 가져온 논문은 A Simple Baseline for Bayesian Uncertainty in Deep Learning (Maddox et al., NIPS 2019)입니다. 이전 포스팅 중 https://ajlab402.blogspot.com/2021/08/pr402-6-loss-surface-simplexes-for-mode.html과 연결되는 내용입니다. 오늘의 논문을 바탕으로 loss surface를 이해하여 실전에 써먹는 논문인데 둘 다 굉장히 재미있고 매력적입니다. 두 논문 모두 읽어보시길 추천드립니다!

Link - https://arxiv.org/abs/1902.02476


Intro:

Kaggle이나 Dacon같은 대회 수상 솔루션들을 보다보면 가끔 보이는 SWA(Stochastic Weight Average)라는 알고리즘이 있습니다. neural net 내부의 weight들을 최적화하는 과정에서 weight들의 평균값을 사용하여 보다 나은 weight을 찾아가는 방법입니다. 보통 매 batch마다 고정되어 있는 learning rate와 방향을 가진 SGD와 같이 사용됩니다. 

오늘 소개드릴 논문은 위의 SWA에서 한 단계 더 나아간 SWAG (SWA-Gaussian, 이름 마음에 드네요 ㅋㅋ) 알고리즘을 제안하는 논문입니다. SWA는 weight들의 1차 모멘트인 평균값만을 사용하지만, SWAG에서는 2차 모멘트인 분산까지 활용합니다. 분산을 활용하기 때문에, 모델의 성능과 더불어 Uncertainty까지 부산물로써 얻을 수 있다는 장점이 있습니다.


Contents:

기본이 되는 SWA의 weight는 다음과 같이 주어집니다:

학습을 시키면서 도달한 weight space에서의 point들의 평균을 내어 사용하는 방식이죠. 마치 동일 아키텍쳐의 뉴럴넷 여러개의 앙상블을 할 때 weight을 평균내어 사용하는 방식과 비슷한 효과를 낼 수 있습니다. 위에서 구한 Өswa를 '평균'으로 인식한다면, 분산을 다음과 같은 방식으로 구해볼 수 있습니다.


E(X^2) - E(X)^2 = V(X)인 분산의 공식을 사용하여 weight 각각의 분산을 구한 것입니다.
이를 바탕으로 weight의 posterior를 다음과 같이 나타낼 수 있습니다.


그러나 이렇게 되면 weight 각각의 분산만을 고려하게 되고, Covariance를 전혀 고려하지 못하게 됩니다. 따라서 다음과 같은 방식으로 Covariance를 고려해주어야 합니다.


정석으로 Covariance의 근사치를 위와 같이 구해줄 수 있습니다. 그러나 Өswa의 경우 학습이 끝나고 나서야 접근 가능하므로, Өswa대신 학습과정에서 구해진 Өi들의 평균을 사용하여 근사치를 구할 수 있습니다. 


위처럼 Covariance Matrix를 구성할 수 있는데, 만약 위처럼 1~T의 전 학습과정에 걸쳐 일일이 공분산을 계산해주게 되면 연산량도 무시 못할 뿐더러 (본래의 목적 중 하나인 앙상블 없이 가려는 효율성을 해치게 됩니다.) 너무 과거의 weight 간 관계까지 matrix에 녹아들게 되어 저해할 위험이 있으므로, 가장 최근의 K개만을 대상으로 Covariance Matrix를 구해줍니다. 이를 Σlow-rank라고 부릅니다. (K 또한 Hyperparameter입니다.) 

그러나 위처럼 Covariance를 계산해주게 되면 결국 Өswa의 근사치로 Өi의 평균을 사용한 것이기 때문에 이를 보정해주고자 위에서 Өswa를 직접 사용한 분산과 합쳐서 최종적인 Multivariate Gaussian Distribution을 다음과 같이 구합니다.


(개인적인 의견 - 이렇게 되면 대각성분의 경우는 이론적으로 근사치라고 할 수 있지만, 비대각성분의 경우는 뒤의 low-rank에서만 고려되기 때문에 근사치의 절반 정도로 나옵니다. 그러나 그럼에도 이렇게 한 이유는 각각의 variance를 더 중요시하기 때문이며, covariance 또한 low-rank로 근사한 것이기 때문에 정확도가 낮아 covariance의 영향을 약화시키는 역할 또한 맡을 수 있기 때문이라고 짐작됩니다. 물론 몇 번 돌려보고 성능 때문에 이렇게 했을 확률이 제일 높지만 굳이굳이 정당화를 하자면요... 물론 이건 제 의견일 뿐입니다!)

위 분포를 바탕으로 weight를 sampling하면 다음과 같이 구해집니다.


위를 통해 weight space에서의 분포를 근사할 수 있습니다. 
최종적인 알고리즘은 아래와 같습니다.


왼쪽은 SWAG을 통해 weight의 분포를 찾아내는 과정입니다. (자세한 notation은 논문을 참조해주시면 감사하겠습니다.) 오른쪽은 찾아낸 분포를 바탕으로 Bayesian Inference를 하는 과정입니다. 


우변의 적분식은 S번의 평균을 통해 근사되는 값으로, 위의 알고리즘에서는 1/S를 곱해서 계속 더해주는데, p(y* l Data)를 Bayes룰을 통해 평균으로 근사시킨다고 이해하는 쪽이 직관적으로는 더 편합니다.

Experiment:

여러 가지 결과를 Reporting해주는데, 우선 성능 부문입니다. SWAG 모델 하나가 SGD 모델을 3개 앙상블한 것보다 성능이 뛰어났고, SGD 모델을 5개 앙상블한 것과 견줄만 했습니다. (놀랐던 점은 SWAG 모델이 이미 앙상블을 한 것과 비슷한 효과를 냈음에도 SWAG 모델끼리의 앙상블을 하니 성능이 더 많이 올랐습니다.) -> 이는 CIFAR 데이터를 기반으로 했고, Vision 데이터 말고 Language 데이터인 Penn Treebank와 WikiText-2 데이터셋에서도 기존의 알고리즘들을 뛰어넘는 성능을 보였습니다.

아래는 추가적으로 Bayesian Uncertainty를 계량화한 그래프입니다.

자세히 보면 Confidence와 Accuracy가 주요 척도로 나오는데, Confidence는 classification 문제에서 softmax로 출력했을 때 가장 높은 값을 의미합니다. 모델이 uncertainty를 정확히 캐치하고 있으면 Confidence가 높을 때 accuracy도 높고, confidence가 낮을 때 accuracy도 낮게 됩니다. 따라서 confidence와 accuracy의 차이가 모델의 uncertainty 캐치 정도를 어느 정도 계량화하고 있다고 이해할 수 있는데, 위 그래프에서 y축에 명시되어 있습니다. 즉, 위 그래프를 해석할 때는 수평점선에 가까울수록 좋습니다. SWAG을 보면 대부분의 알고리즘들보다는 가까운 것을 알 수 있습니다. (Uncertainty Estimation이 훌륭합니다.)

Conclusion:

저번 포스팅을 하고 나서 SWAG 알고리즘을 꼭 읽어봐야겠다고 생각했는데 기대보다도 더 재미있었습니다. 아쉬웠던 부분은 주석을 달았던 diag와 low-rank를 합치는 부분입니다. 그래도 해당 알고리즘이 단지 모델의 성능을 높이는 데만 사용된 것이 아니고, weight space에서의 분포를 짐작할 수 있게 해줘서 느끼는 부분이 있었습니다. 이에서 더 발전하여 저번 포스팅으로 이어져 loss surface simplexes와 연결된다는 점도 스토리적으로 정말 재밌는 것 같습니다ㅎㅎ 개인적으로 이 부분이 연구해볼 점도 많고, 또 매력적이라고 생각합니다!


Github:

SWAG 알고리즘 자체가 적용이 굉장히 쉬운 편이라 가볍게 구현을 해보았습니다. 원 저자가 만든 repo도 물론 있습니다만, 그냥 개념만 잡으려면 제 코드를 보시면 훨씬 간단할 것 같습니다. 또한 torch보다 keras에 더 편하신 분들에게도 좋을 것 같습니다. 단점이라면 GPU를 사용하지 않아서 Bayesian model averaging 단계에서 상당히 느립니다. 

Github Link : https://github.com/Eternity402/SWA-Gaussian_Keras

구현하면서 느낀 점은, 예전의 weights까지 다 모델에 포함되다보니 학습이 많이 진척이 된 후에야 기존의 SGD를 뛰어넘는 성능을 보여주는 것 같았습니다. 마지막 성능 최적화 단계까지 가서야 앙상블 모델보다 효율적일 것 같습니다(시간, 계산량 대비 성능 부분에서).  SWAG은 optimizer 자체로서의 능력보다는, Loss surface를 가늠할 수 있게 해주는 하나의 툴로서 바라보는 것이 더 좋아보입니다.


+추가)

SWAG이 성능이 생각보다 잘 나오지 않아서 알고리즘을 살짝 바꾸어 돌려보았습니다.

위에서 제가 제기했던 문제점처럼 covariance matrix가 맘에 들지 않아서 변형을 해주었는데요, 두 가지 방법을 더 시도해 보았습니다. 첫 번째는 아예 Σlow-rank만 사용해준 방식이고, 두 번째는 Σdiag만 사용한 방식입니다. 시간이 없어서 그냥 학습 도중에 바로 plot을 해보았는데 우선 중간까지는 다음과 같은 결과를 낳았습니다.


확실히 공분산을 고려해주지 않는 diag만 사용한 경우 그냥 swag보다도 현저히 성능이 떨어집니다. 놀랍게도 lowrank만 사용해준 경우 오히려 swag보다도 높은 성능을 보여줍니다. 물론 아직 sgd보다 낮은 것은 K를 10으로 설정해주어서 16 epoch 째에서도 7 epoch째의 weights가 영향을 미치고 있는 상태이기 때문인 것으로 추정됩니다. 

위에서 보면 swag이 마치 swag-lowrank와 swag-diag의 가중평균인 것처럼 느껴져서 그럼 만약 lowrank에서 covariance의 크기를 오히려 증가시키면 성능이 올라갈지 궁금해져서 그것도 plot해보았습니다. (즉, covariance의 근사치가 아니라 분산을 오히려 증가시킨 case)


당연히 swag-lowrank보다 떨어질 줄 알았는데 신기하게도 swag-lowrank랑 거의 비슷한 성능을 보여줍니다. 생각보다 covariance가 weight space에서 robust하게 분포해 있는 것 같네요. 아마 이 covariance matrix를 잘 근사하는 알고리즘이 사용된다면 훨씬 좋은 성능을 보여줄것으로 기대됩니다.

댓글 없음:

댓글 쓰기