TL;DR
- Problem: Federated Learning에서 data heterogeneity로 인한 global model의 convergence 어려움
- forgetting이 FL의 bottleneck임을 가설로 세움.
- global model이 이전 round의 knowledge를 잊어버리는 것을 발견함.
- local training은 local distribution 밖의 knowledge를 잊어버리도록 유도함.
- FedNTD(Federated Not-True Distillation) : not-true class에 대해서만 global 유지 (?)
- global model의 knowledge를 local distribution 밖에 유지.
→ FedNTD가 cost/privacy 손실 없이 다양한 도메인에서 SOTA 달성!
Motivation
Federated Learning과 Continual Learning이 유사.
- Federated Learning : client간 data heterogeneity(non-IID) 때문에 global performance 달성이 어려움.
- 새로운 local 학습시 global parameter drift.
- Continual Learning: sequence of task가 들어올때 모든 task에 대해 well-perform 해야하는 objective, but 이전 정보를 잊어버리는 문제(catastrophic forgetting).
- 새 task에 대해 학습시 previous task에 대한 parameter drift.
→ 이 가설의 검증을 위해 1) global model consistency와 2) round별 class-wise acc 측정.
Method
- Global model prediction consistency
- Non-IID에서 round 별로 class-wise acc가 변하고, consistency 유지가 안됨.
→ heterogeneity가 커질수록
- Knowledge outside of local distribution
- 매 communication round 마다 local / global model의 accuracy 측정.
- local training 이후 local model이 가지고 있는 data에 대해서는 잘 perform. (어떻게 보면 당연)
- 하지만 local model이 가지고 있는 data 외의 것들에 대해서는 degrade.
→ out-local distribution에 대한 forgetting이 일어나고 있음에 대한 근거.
Experiments
- Algorithms: FedAvg, FedProx, SCAFFOLD, MOON, FedCurv, FedNova, FedNTD
- Datasets: MNIST, CIFAR-10, CIFAR-100, CINIC-10
- 다양한 환경에서 다는 algorithm들 보다 좋은 성능을 내는 것을 확인.
Questions
- 왜 not-true에서만 distillation 하는 것인지?
- true label을 제외하는 이유가 있는것인지?
- Classification에서는 다른 algorithm들 보다 성능이 좋은데, personal re-id에서는 왜 MOON이 더 좋은 것인지?