본문 바로가기
Computer Vision/Transformer

DeiT(Training Data-efficient Image Transformers & Distillation Trough Attention)

by 흐긴 2022. 8. 21.

Introduction

이번에는 페이스북에서 발표된 "Training Data-efficient Image Transformers&Distillation Through Attention" 줄여서 "DeiT"논문을 리뷰 해볼려고 합니다. 최근들어서 트랜스포머 관련된 논문이 쏟아져 나오고 있습니다. 그 중에서 가장 대표적인 모델이 ViT입니다. ViT에 대해서 모르시는 분들은 여기를 참고해주시고 들으시면 도움이 될 거 같습니다. ViT의 장점으로는 Imagenet 데이터에 대해서 SOTA에 가까운 성능을 보여주었지만, 단점으로는 높은 자유도로 훈련에 많은 양의 데이터가 필요로 하는 문제점이 있었습니다.

이 논문에서는 Konwledge Distillation(지식 증류)방법을 이용한 Transfomer모델을 제안합니다. 지식 증류란 훈련 중에 Student 모델이 Teacher 모델로부터오는 소프트 라벨을 이용하는 훈련 방법입니다. 일반적으로 모델의 경량화를 위해 자주 사용되어집니다. 이 논문에서 지식 증류 방법을 갖고 오게 된 이유는 크게 두가지입니다.

  1. Crop같은 Data augmentation을 이용하여 학습시 이미지의 물체가 없는 영역에 대해서도 Crop될 경우, 잘못된 라벨로 훈련될 가능성이 있다.
  2. Teacher 모델의 inductive biases들을 소프트한 방법으로 Student 모델로 전이시킬 수 있다.

Method

Attention&Inductive bias 설명

Self Attention은 다음 그림과 같이 피쳐 맵이 인풋으로 들어왔을 때, 1차원의 벡터로 만들어서 선형 함수를 곱하여 쿼리, 키, 밸류를 생성합니다. 여기서 선형함수는 weight parameter로 학습값 입니다. 그리고 쿼리와 키를 행렬곱하여 소프트 맥스를 곱하여 attention map을 생성합니다. 마지막으로 밸류와 곱하여서 Self attention map을 생성합니다.

CNN의 Inductive biases는 하나의 값을 도출하기 위해서 일정한 지역의 값을 모두 함께 이용하는 것입니다. 그리고 모든 입력에 대해서 같은 weight을 사용한다는 것입니다. 하지만 Attention의 경우에는 어떤 인풋이 들어오느냐에 따라 다른 weight을 활용한다는 점입니다.

Knowledge Distillation

Soft label distillation는 teacher 모델의 소프트맥스 값과 자신의(student) 모델의 소프트맥스 값의 KL-divergence를 계산하고, student 모델의 예측치와 원래 라벨의 크로스 엔트로피를 적절한 비율 람다로 합산하여 학습을 진행합니다. Hard label distillation은 student 모델의 예측치와 진짜 라벨과의 크로스 엔트로피, 그리고 student 모델의 예측치와 teacher 모델의 예측치(소프트맥스 통과 후 원핫 인코딩한 값)의 크로스 엔트로피로 로스를 계산합니다. 본 논문에서는 두가지 모두 실험한 결과를 제공해주고 있습니다.

Distillation with knowledge distillation

이 논문에서 새롭게 제안한 방법은 Distillation token이라는 개념입니다. 원래 클래스 토큰이 있는 것처럼 끝에 distillation token을 추가하여 모델을 설계했습니다. 흥미로운 점은 이 토큰들이 학습되는 대상이므로 학습이 끝난 후 두개의 토큰이 얼마나 다른지 코사인 similarity을 측정해 보았더니 0.06정도로 두 가지가 굉장히 달랐다는 것입니다. 하지만, 레이어를 통과한 두개의 토큰들의 임베딩의 코사인 simillarity는 0.93정도로 매우 비슷한 값을 가졌다고 합니다. 그리고 추가 실험으로서 distillation token대신에 클래스 토큰을 추가해서 학습했을 때는 토큰간의 similarity가 0.99로 매우 비슷했다고 합니다.

Experiment & Results

Distillation

먼저, 어떤 Teacher 모델을 쓰면 좋을지를 실험한 결과입니다. 결론적으로는 Convolutional Neural Net을 쓰는 경우가 더 좋습니다. 논문에서는 CNN이 inductive bias를 갖고 있기때문에 그게 잘 전달이 되어서 성능이 더 좋다고 풀이를 하고 있습니다.

다음, soft label & hard label 가운데 무슨 distillation 방법이 더 좋은가를 실험한 결과도 보여줍니다. hard distillation 방법이 더욱 좋은 성능을 보여주었습니다. 그리고 class & distillation token을 함께 썻을 때 더욱 좋은 성능을 가진다는 것을 알 수 있습니다.

Efficiency vs accuracy

이 결과는 imagenet 데이터만으로 학습을 했을 때 성능을 비교한 결과입니다. 증류 기호가 있는 모델이 distillation token을 사용한 결과입니다. 토큰을 사용하지 않아도 EfficientNet보다 조금 낮은 성능을 가지고 있고 기존의 ViT보다는 훨씬 높은 성능을 보여주었습니다. 또한 비슷한 성능 대비 1초당 처리하는 이미지 숫자가 많다는것을 보여줍니다. 이는 cnn커널보다 큰 행렬 곱 계산이 transformer에서 이루어지기 때문이라고 생각됩니다.

'Computer Vision > Transformer' 카테고리의 다른 글

ViT(Vision Transformer)  (0) 2022.08.21
Transformer("Attention is all you need")  (0) 2021.05.06