2021년 7월 7일 수요일

Self-Coding Review(1) : keras custom loss를 편하게 짜는 법?

최근에 Actor-Critic 코드를 짜면서 생각보다 custom loss를 토대로 backpropagation을 하는게 어렵다는 느낌을 받았습니다. 저는 특히나 torch가 아니라 keras를 주로 사용하는데, keras의 경우에 custom loss를 따로 적용하는게 생각보다 복잡하더라구요. 물론 경우에 따라 다릅니다. 이번 글에서는 해당 custom loss를 쉽게 적용할 수 있는 방법?에 대해서 조금 적어보고자 합니다. 

keras에서 보통 사용하는 custom loss의 형태는 다음과 같이 구성됩니다. 

Link - https://keras.io/api/losses/ (keras 공식 홈페이지)

위처럼 보통 예측값과 실제값을 인자로 받아서 이를 바탕으로 backpropagation을 하게 됩니다.

보통의 경우에는 이것만으로도 충분한 경우가 많죠. 뉴럴넷이 뽑아낸 값과 실제값을 바탕으로 업데이트를 하는 경우가 대부분이니까요. 그렇지만 논문구현을 하는 등 필요에 따라서 그렇지 않은 경우에 기울기 업데이트를 해야하는 경우가 생기면 이러한 방법을 적용하기 어렵습니다.

예를 들어, actor-critic을 업데이트 하는 경우를 살펴볼 수 있습니다.
critic은 critic이 예측한 value값과 one-step forward value값을 줄여나가는 방식이기 때문에 평소에 사용하는 방법대로 update를 할 수 있습니다. 

critic의 gradient update

그러나 actor의 경우 조금 다른 방식으로 update가 이루어지기 때문에 loss custom이 필수적입니다. 
actor의 gradient update

따라서 대부분의 actor-critic을 구현한 github을 보면 custom loss를 나름의 방식으로 정의하고, update를 하고 있습니다. 그런데, keras는 torch에 비해서 이런 부분에서 조금 어려움이 있어서 그런지 제가 찾아보았을 때는 잘 정리된 github이 적었을 뿐더러 예전 자료여서 최근의 keras 버전과 잘 호환이 되지 않았습니다. 그래서 actor-critic을 구현하는데 어려움을 겪고 있었는데, 나름의 방법을 찾아서 까먹지 않게 적어놓으려고 합니다.


<코드 실행 기준 tensorflow 버전은 2.4.1이고, python 버전은 3.8.5입니다.>

우선, 기본적으로 tensorflow에서 제공하는 gradient tape를 사용합니다. (Link - https://www.tensorflow.org/api_docs/python/tf/GradientTape)

gradient tape를 사용하면, tensor 연산에 한해 행해진 연산들을 모두 기록하여 해당 기록을 통해 gradient update를 실행할 수 있습니다. 제가 처음에 시도해봤던 방법은 keras model을 구축하고, 해당 모델의 predict 연산을 통해 input에 model layer들을 모두 덧씌우는 연산을 적용해주는 과정을 gradient tape에 기록하여 update를 하는 방법이었습니다. 그러나 이 경우 치명적인 결함이 keras의 model.predict를 사용하면 model의 output이 numpy 형태로 도출되어 gradient tape에 model에서 적용된 연산을 추적할 수가 없습니다. 

그래서 찾아낸 대안은 model.layers를 이용해 직접 덧씌우는 방법입니다. 즉, model.layers[-1](model.layers[-2](model.layers[-3]....)을 하면 model의 feedforward를 그대로 적용하면서 output이 tensor로 나오기 때문에 gradient tape에 연산을 기록할 수 있게 됩니다. 따라서 gradient update 또한 가능합니다. 저는 다음과 같이 적용했습니다.



이를 gradient tape과 같이 활용한다면 custom loss를 마음대로 구현할 수 있습니다. 예를 들면 model 전체의 update 말고 부분적인 update를 원하는 경우 등 여러 방향으로 적용할 수 있을 것 같습니다. 적어놓고 보니 별로 어려운 건 아닌 것 같은데 저는 이거 찾느라 엄청 고생했네요...ㅠ

제 경우에는 actor의 action이 continuous한 경우를 만드려고 해서 조금 더 github에 올라온 코드들보다 복잡하게 설계했어야 하는데 discrete한 경우에는 github에 올라와있는 코드들도 충분히 도움이 될 것 같습니다.

댓글 없음:

댓글 쓰기