기존 Deep network는 학습데이터와 테스트 데이터가 similar한 distribution일때는 잘 동작해 왔습니다. 그러나, Real-World와 같이 unseen 데이터가 갑작스럽게 들어오는 경우에는 오작동하는 경우가 다반사입니다. 이러한 문제를 domain shift라고 합니다. 이 문제를 다루는 Domain Adaptation분야는 다양하게 연구되어오고 있습니다. 최근에는 TTA(Test-Time Adaptation)라는 분야가 연구가 활발히 되고 있습니다. 기본적인 problem setting은 source model에는 access가능하지만 data privacy와 같은 이유로 source data에는 no available한 상황을 가정하고 있습니다.
구체적인 설명은
2023.06.2 - TENT: FULLY TEST-TIME ADAPTATION BY ENTROPY MINIMIZATION (ICLR 2021)서론 부분을 참고해보세요!
본 논문에서 말하는 TTA의 mojor challenge
- 어떻게 target domain의 label정보 없이 target domain의 representation을 학습할 수 있을까?
- source domain의 classifier만으로 target domain classifier를 만들 수 있을까?
위의 difficulty를 해결하기위해 제안하는 method는 크게 두가지 key factor가 있습니다.
- Self-Supervised Contrastive Learning
- Pseudo Labeling
위 두가지를 online으로 학습함으로 2022년 기준으로 TTA task에서 SOTA를 찍었네요.
Realted Works
최근에 TTA를 다룬 논문인 TENT, SHOT과 간략히 설명해보겠습니다.TENT는 Target domain에 대해 Test Entropy를 minimization하는 objective통행 학습하는 전략이고, SHOT은 Pseudo Labeling를 이용해서 Entropy Minimization를 하는 전략을 취합니다. 본 논문에서 언급하는 두가지 method의 limitation은 다음과 같습니다.
The entropy minimization does not model the relation among different samples.
More importantly, distrupts the model calibration on target data due to direct entropy optimization.
The pseudo labels are updated only a per-epoch basis, which fails to reflect the most recent model improvment during an epoch.
Method
Online pseudo label refinement(위 그림 a)
Online으로 pseudo label의 정제를 하기위해 epoch단위가 아닌 batch 단위로 pseudo label을 반영하도록 합니다. 그렇다면 정제는 어떻게 할까요? source data로 학습된 모델로 taerget encoder를 initiation 합니다.
taget image가 주어지면, weak augmentation를 통해 target encoder를 통과시켜 feature vector들을 만들어 냅니다. 그리고 target feature space에서 기존 target image 와 nearest neighbor를 voting하여 선택합니다. 선택된 feature vector들과 기존 target image의 feature vecotor w(weak feature)를 classifier에 태워 probability 평균을 산출합니다. 최종 pseudo label은 argmax를 통해 구합니다.
- Memory queue
위에서 다룬 nearest neighbor search를 하기 위해서, weak augmented target sample들의 feature, probability를 memory queue 길이 M만큼 저장합니다. 처음에는 random target sample를 사용하여 memory queue에서 관리됩니다. 제안 방법은 feature space를 보다 stable하게 하기 위해, 조금씩 바뀌는 momentum model를 feature, probability를 계산하기 위해 사용됩니다. - Nearest-neighbor soft voting
위에서 Memory queue에 저장된 feature과 w(weak feature)와의 cosine distance를 계산합니다. 이 수치를 통해, nearest neighbor를 뽑아서 probability의 평균을 구해, argmax operation으로 pseudo label 만들어내게 됩니다.
Joint self-supervised contrastive learning(위 그림 b, c)
기존의 제안된 self-supervised contrastive learning의 아이디어에 영감을 받아서, target data의 pair-wise information를 추출합니다. 기존방식과 공유되는 점은 다음과 같습니다.
Positive pairs: 같은 이미지의 different view들의 feature
Negative pairs: 다른 이미지의 feature
즉, 골자는 Positive pairs는 서로 끌어당기고, Negative pairs는 서로 멀어지도록 학습하는 원리입니다.
- Encoder initialization by source
제안된 방법의 momentum encoder는 source weight로 initialize되게 됩니다. 또한, momentum encoder는 memory queue를 update하는데 사용합니다. 즉, momentum encoder는 (a)에서 target feature, probability를 위한 memory queue 업데이트 및 (b)에서 contrastive feature를 만듭니다. - Exclusion of same-class negative pairs
Strong augmentaion으로 만들어진 두개 버전의 target image는 query와 key faeture로 encoding됩니다. 그 중 key feature들은 memory queue에 update됩니다.(그림 (b)에 해당)(참고: 이 memory queue는 이전에 weak augmented target 들이 저장된 memory queue와는 별개입니다.). 제안 방법은 MoCo에서 사용된 InfoNCE Loss를 통해 positive, negative를 정의해 학습합니다. Positive는 query와 key feature 사이의 similarity, Negative는 query와 memory queue feature 사이의 similarity를 뜻합니다. 기존의 Contrastive learning의 개념에 따라, Positive pair의 similarity는 작아지도록, Negative pair의 similarity는 커지도록 학습됩니다. 주의할 점은, negative pairs가운데는 같은 class도 포함되어 있기에 이를 제외하도록 합니다. 아래의 식이 위의 설명을 대변합니다.
Additional regularization
- Weak-strong consistency
TTA는 test에 대한 ground truth가 주어지지 않기에, weakly-augmented target image로부터 얻어지는 pseudo label를 가지고 strongly-augmented target image에 대한 예측에 대해 cross entorpy로 supervise합니다. 이는 weak-aug와 strong-aug의 예측의 consitency를 학습한다고 보면 됩니다. 본 논문에서는 ground truth를 가지고 있는 조건에서 refined pseudo label를 사용하기에 confidence thresholding도 하지 않기에 이점이 있다고 합니다. - Diversity regularization
위에서 pseudo label를 만드는 과정이 noises를 줄여줄 수 있는 효과는 있지만, ideal하지않기에 제안 방법에서는 regularization term를 추가하였습니다. Class diversification를 주기위해 class diversification loss로 학습합니다.
최종적을 위의 세가지 loss를 weighted sum하여 encoder가 학습을 하게 됩니다.
더보기를 클릭하시면 코드 베이스로 구체적인 설명을 추가하였습니다! 참고해주세요😁
코드리뷰(참고)
input: twss (test, weak(test_w), strong(test_q), strong(test_k)) 4개
model: momentum model(adamoco), encoder(adamoco)
- weak augmented image가 encoder에 태워져 feature(feats_w)와 logit(logits_w) 출력합니다. (위의 그림(a))
- 해당 feats_w와 memory queue(for w)에 있는 feature들과 distance를 계산해 nearest neighbor들 찾은 후, 해당top k에 대해 메모리 뱅크에 있는 각 클래스별 prob 평균을 구합니다. pseudo label은 해당 prob에 argmax를 취해서 만들어줍니다. 이렇게 만들어진 pseudo label은 앞서 설명에서처럼, classification 및 contrastive learning에 사용됩니다.(위의 그림(a))
- strong augmentation으로 생성된 test_q, test_k를 각각 encoder, momentum model로 feature를 뽑습니다. 이 두개의 feature를 앞으로 q, k라고 명명하겠습니다. (참고: shape: (Batch, Dim(feature)) 그리고, momentum model은 encoder와 momentum model의 parameter의 EMA로 업데이트 됩니다.
- 앞선 k와 pseudo label를 각각 memory queue(for k) feature 및 memory queue(for k) label에 업데이트 합니다.
- Contrastive learning을 위해 positive pair의 cosine similarity matrix, negative pair의 cosine similarty matrix를 산출합니다.
Positive matrix: q(Batch, Dim), k(Batch, Dim)를 성분끼리 곱하고 feature방향으로 summation를 하여 (Batch, 1)형태의 consine similarity matrix
Negtive matrix: q(Batch,Dim)와 memory queue feature(Dim, M(memory length)) 행렬곱을 하여 (Batch, M)형태의 consine similarity matrix.(참고: memory queue는 이전 배치까지의 moment model의 feature(K1, K2, ..., KM)입니다. )
즉, Positive matrix는 같은 sample간의 feature similarity를 표현하고 있고, Negative matrix는 q와 다른 샘플들과의 feature similarity를 표현하고 있습니다. 두 matrix를 concat하여 하나의 matrix로 만들어줍니다. (Batch, 1+M) - 본 논문의 제안 방법은 앞선 Positive, Negative matrix를 활용해 moco에서 제안된 contrastive loss인 InfoNCE를 적용합니다. 여시거 주의할점은 negative sample에 같은 class sample도 포함될 수 있다는 건데요. 이 문제를 해결하기 위해서 pseudo label과 memory queue label과 비교하여 다른 클래스에 해당되는 Negtive matrix 성분들만 masking합니다. (Batch, 0)이 positive, (Batch, maked idx)가 negative이므로 0 idex에 대해 F.cross entropy를 적용하여 contrastive loss를 산출합니다.
- logit q와 pseudo label로부터 cross entropy loss, logit q로부터 diversification loss를 취합니다.
- 앞선 세개 loss(contrast, cross entropy, diversification)를 통해 encoder를 업데이트합니다.
- 마지막으로, momentum model에 test_w를 태워서 feats_w, logits_w를 memory queue(for w)에 업데이트합니다.
Reference
- Chen, Dian, et al. "Contrastive test-time adaptation." Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2022.
- He, Kaiming, et al. "Momentum contrast for unsupervised visual representation learning." Proceedings of the IEEE/CVF conference on computer vision and pattern recognition. 2020.
'Computer Vision > Test-Time Adaptation' 카테고리의 다른 글
TENT: FULLY TEST-TIME ADAPTATION BY ENTROPY MINIMIZATION (ICLR 2021) (0) | 2023.06.21 |
---|