논문읽기: FedMatch(FEDERATED SEMI-SUPERVISED LEARNING WITHINTER-CLIENT CONSISTENCY & DISJOINT LEARNING)
Federated Learning과 Semi-supervised learning을 결합한 FSSL(Federated Semi-Supervised Learning)에 대한 연구이다.
우리나라 카이스트에서 나온 논문이다.

https://arxiv.org/abs/2006.12097
Federated Semi-Supervised Learning with Inter-Client Consistency & Disjoint Learning
While existing federated learning approaches mostly require that clients have fully-labeled data to train on, in realistic settings, data obtained at the client-side often comes without any accompanying labels. Such deficiency of labels may result from eit
arxiv.org
논문의 주요 아이디어는 다음 2가지로 요약할 수 있을 것 같다: 1.Inter-Client consistency loss, 2. Parameter decomposition for disjoint learning
Inter-Client consistency loss
Inter-Client consistency loss
Semi-supervised learning에서는 Consistency loss라는 개념을 많이 사용한다. Data augmentation을 하더라도 Class분류가 바뀌면 안된다는 점에 착안하여 만들어진 개념이다. FedMatch에서는 이 개념을 Federated learning에 적용하여 아래와 같은 Loss로 정의하고, inter-client consistency loss(ICCL)로 정의하였다.
$$ \frac{1}{H} \sum_{j=1}^{H}KL[p_{\theta^{h_j}}^{*}(y|u)||p_{\theta^{l}}(y|u)] \\
where \: KL: Kullback-Leibler \: divergence, \\
h_j: Selectetd \: agent \: for \: local \: model \: l $$
(여기에서 Agent는 Client의 Local model과 가장 비슷한 Client의 Model로 선택된다.)
각 Local model과 Agent들과의 분포가 유사하게 유지되도록 학습을 진행하겠다 정도로 해석할 수 있겠다.
Pseudo-label
뿐만 아니라, 각 Unlabeld data에 대한 학습을 위해 Pseudo-Label을 생성하고 이를 이용하여 학습을 진행하였다. Pseudo-label은 아래와 같이 생성한다.
여기에서 $ \unicode{x1D7D9} $ 는 Softmax값을 One-hot vector로 만드는 함수이다. 일정 Threshold를 넘지 못하면 아예 무시한다. Max함수는 일반적인 의미로 쓰인 것이 아니고, Class별로 투표를 해서 가장 많은 득표를 한 Class로 Pseudo-label을 정하겠다는 것을 의미한다.
이렇게 만든 Pseudo-label $ \hat{y} $와 Aument된 이미지로부터 추정한 확률값을 이용하여 Cross Entropy를 계산한다.
Consistency regularization term
최종 Consistency regularization term은 Pseudo-label을 이용한 Cross entropy와 Inter-client consistency loss를 결합하여 아래와 같이 정의하였다.
Parameter Decomposition For Disjoint Learning
일반적인 Semi-Supervised learning과는 다르게 Federated라는 제약이 추가되었기 때문에, Labeled data와 Unlabeled data에 대한 학습이 같이 진행될 수 없다. 따라서 Parameter를 2개의 set으로 분리하고 따로따로 학습을 진행하였다.
$$ \theta = \sigma + \psi \\
\sigma: Parameters \: for \: supervised \: learning \\
\psi: Parameters \: for \: Unsupervised \: learning $$
Supervised learning은 아래의 Loss로 진행하고, ($ \psi $는 고정)
Unsupervised learning은 위에서 정의한 Consistency regularization을 활용하여 아래의 Loss로 진행한다. ($ \sigma $는 고정)
$ \psi, \sigma $가 너무 멀어지지 않도록 두개의 Regularization term을 추가하였다.
Scenarios
Labels-at-client와 Labels-at-server 시나리오가 있다. Label data가 클라이언트에 있는지 서버에 있는지에 따라 달라지는 것이고, Labels-at-client의 경우 SL을 클라이언트에서, Labels-at-server의 경우 SL을 서버에서 하는 것이 다르다.
Experiments
- 타 SSL 방법론을 단순히 적용한 것에 비해 좋은 성능을 보여주었음.(UDA, FixMatch)
- Forgetting issue를 해결해 주기 때문인 것으로 보임.
- Forgetting issue: Labeled data를 통해 학습된 지식이 Unlabeled 트레이닝과정에서 점차 희석되는 현상. 즉, Inter-task inference가 발생하기 때문.
- Labels-at-Server에서 타 방법론 대비 우수성이 더욱 두드러지게 관측됨.
- Non-IID task에서도 큰 성능하락이 관측되지 않음.
Ablation studies
- Inter-client consistency loss를 제거하면 성능이 소폭 하락. Chart (a)
- $ \psi \: or \: \sigma $을 제거하면 성능이 크게 하락, 특히 $ \sigma $를 제거하면 트레이닝이 거의 되지 않음 Chart (b)
- (c)에서 Forgetting issue를 관측할 수 있음. 타 SSL 방법론의 경우 Round가 올라갈수록 성능이 떨어지는 현상이 관측됨.
- Label data를 늘려가면 전반적인 성능이 향상됨. 당연한 것으로 보이나, 타 방법론의 경우 한번 감소했다가 증가하는 모양이 관측되었음. Chart (d).