- The paper presents the Generalized Kullback-Leibler (GKL) Divergence Loss, which decouples KL divergence to address optimization asymmetry and sample bias.
- The approach incorporates class-wise global information instead of per-sample predictions, reducing bias and mitigating overfitting risks during training.
- Experimental results demonstrate the GKL loss achieves state-of-the-art adversarial robustness on RobustBench and delivers competitive knowledge distillation performance.
Overview and Motivation
The Generalized Kullback-Leibler (GKL) Divergence Loss extends the classical KL divergence loss by leveraging a decoupled formulation. This formulation breaks the loss into two main components—a weighted Mean Square Error (wMSE) term and a Cross-Entropy component with soft labels—thereby enabling more granular control during optimization, particularly in tasks like adversarial training and knowledge distillation. The approach addresses two inherent issues in the conventional KL loss: the asymmetric optimization behavior and bias induced by sample-wise predictions.
The paper rigorously proves that the classical KL divergence loss can be equivalently rewritten as a Decoupled KL (DKL) divergence loss with the following structure:
- Weighted Mean Square Error (wMSE) Component: This term encapsulates the prediction error by weighting the MSE according to the soft label distribution. The classical KL divergence in its raw form might neglect gradients contributed by this component in specific applications, e.g., when teacher outputs are detached. By explicitly modeling this component, the approach ensures that the convergence properties for high-probability classes are not compromised.
- Cross-Entropy Loss with Soft Labels: The second term mirrors the standard cross-entropy loss but operates on soft labels, thereby capturing the entropy part of the divergence. This is crucial during tasks like knowledge distillation where the teacher’s probability distribution provides a softer supervision signal.
The decoupling allows for addressing the limitations of the original KL loss by isolating the effects of each component on the gradient flow during optimization.
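Schematically, the two-component decomposition described above can be written as follows (an illustrative sketch based on the description in this summary; $\alpha$, $\beta$, and the weights $w_j$ stand in for the paper's exact weighting scheme, which may differ):

```latex
\mathcal{L}_{\mathrm{DKL}}\bigl(p^{t} \,\|\, p^{s}\bigr)
  = \underbrace{\alpha \sum_{j} w_{j}\,\bigl(p^{s}_{j} - p^{t}_{j}\bigr)^{2}}_{\text{wMSE component}}
  \;+\;
  \underbrace{\beta\,\Bigl(-\sum_{j} p^{t}_{j}\,\log p^{s}_{j}\Bigr)}_{\text{cross-entropy with soft labels}}
```

where $p^{t}$ and $p^{s}$ denote the teacher and student probability distributions over classes $j$.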
Addressing Asymmetric Optimization
One of the key limitations of the traditional KL divergence loss is its asymmetric gradient propagation: one loss component may dominate the gradient update, leaving insufficient corrective signal from the other. In knowledge distillation, this asymmetry results in inadequate gradient flow from the wMSE component, particularly for classes with high predicted scores. The GKL formulation breaks this asymmetry by:
- Ensuring that both the wMSE and the Cross-Entropy components contribute meaningful gradients.
- Smoothing the weighting function so that classes with high predicted probabilities still receive well-moderated gradients, alleviating convergence difficulties during training.
The modification leads to an improved optimization dynamic, which in turn results in more stable and reliable convergence properties across varied tasks.
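One concrete symptom of the classical formulation's limited gradient structure: when the teacher distribution is detached, as in standard distillation, the entire gradient of the KL loss with respect to the student logits collapses to the single form $(p^s - p^t)/B$, leaving no independent handle on the wMSE-like part. A minimal PyTorch check of this identity (illustrative code, not the paper's implementation):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch_size = 2
student_logits = torch.randn(batch_size, 5, requires_grad=True)
teacher_probs = F.softmax(torch.randn(batch_size, 5), dim=1)  # detached teacher

# Classical KL distillation loss: KL(teacher || student).
loss = F.kl_div(F.log_softmax(student_logits, dim=1), teacher_probs,
                reduction="batchmean")
loss.backward()

# Through the softmax, the gradient w.r.t. the student logits reduces to
# (student_probs - teacher_probs) / batch_size for every class alike.
expected = (F.softmax(student_logits, dim=1) - teacher_probs) / batch_size
print(torch.allclose(student_logits.grad, expected, atol=1e-6))  # True
```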
Incorporating Class-wise Global Information
Alongside addressing asymmetric optimization, the approach introduces a critical enhancement—incorporating class-wise global information into the loss computation. Rather than relying solely on per-sample predictions, which can be susceptible to noise and outlier effects, the GKL loss employs a global statistic for each class. This is done by:
- Aggregating predictions at the class level to form a more stable guiding signal.
- Using a smoother class-wise weight function that reduces the influence of individual noisy samples.
This formulation not only reduces sample-wise bias but also mitigates overfitting risks that can arise from hard examples, leading to a more consistent training procedure in both adversarial scenarios and distillation settings.
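As a sketch of what batch-level class aggregation could look like (the exact statistic used in the paper may differ; `class_wise_mean_probs` is a hypothetical helper):

```python
import torch

def class_wise_mean_probs(probs, labels, num_classes):
    """Aggregate per-sample predictions into class-level statistics.

    probs:  [B, C] softmax outputs; labels: [B] ground-truth class indices.
    Returns [num_classes, C]: for each class, the mean predicted distribution
    over the samples belonging to it (an illustrative aggregation scheme).
    """
    sums = torch.zeros(num_classes, probs.size(1))
    counts = torch.zeros(num_classes)
    sums.index_add_(0, labels, probs)                       # per-class sum
    counts.index_add_(0, labels, torch.ones(labels.size(0)))  # per-class count
    return sums / counts.clamp(min=1).unsqueeze(1)
```

Averaging over all samples of a class yields a guiding signal that is far less sensitive to any single noisy prediction than a per-sample target.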
Experimental Validation and Numerical Results
The effectiveness of the GKL loss is validated through experiments on multiple benchmarks, including CIFAR-10/100, ImageNet, and vision-language datasets. The results indicate:
- Adversarial robustness: The GKL loss achieved new state-of-the-art robustness on the RobustBench leaderboard; the decoupled gradient contributions and class-wise regularization significantly improve the model's ability to withstand adversarial perturbations.
- Knowledge distillation: The improved gradient balance and bias mitigation translate into competitive distillation performance across diverse architectures (from CIFAR models to large-scale ImageNet models, as well as CLIP-based systems), effectively transferring soft-label information from teacher to student.
These numerical improvements confirm the practical merits of the approach, making it well-suited for real-world deployments where stability and robustness are paramount.
Practical Implementation Considerations
For implementation, the general structure of the GKL Divergence loss can be integrated into existing deep learning frameworks. Key considerations include:
- Loss Function Decomposition:
Implement the loss by explicitly computing the two components. Pseudocode for the combined loss might look like:
```python
import torch
import torch.nn.functional as F

def GKL_loss(student_logits, teacher_soft_labels, class_weight_global, alpha=0.5):
    # Student probabilities; the teacher's soft labels are assumed to be
    # precomputed probabilities and are detached from the graph.
    student_probs = F.softmax(student_logits, dim=1)
    teacher_probs = teacher_soft_labels.detach()
    # Weighted Mean Square Error (wMSE) component, scaled by the
    # class-wise global weights (shape: [num_classes]).
    wMSE = (class_weight_global * (student_probs - teacher_probs) ** 2).sum(dim=1).mean()
    # Cross-Entropy component with soft labels.
    CE = -(teacher_probs * torch.log(student_probs + 1e-8)).sum(dim=1).mean()
    # Combine with balancing factor alpha.
    return alpha * wMSE + (1 - alpha) * CE
```
Ensure that the gradient flow through both terms is preserved. This might require careful tuning of any detachments that occur naturally in knowledge distillation where teacher outputs are precomputed.
- Global Information Integration:
Compute class-wise global predictions over a batch or an epoch, which can then be used to adjust the weight function dynamically. This might involve tracking running averages per class and updating the class-wise weights periodically.
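Tracking running averages per class might look like the following sketch (the EMA momentum, the choice of statistic, and the `1 - confidence` weighting are assumptions for illustration, not values from the paper):

```python
import torch

class ClassWeightTracker:
    """Maintain an exponential moving average of per-class confidence,
    used to derive smoother class-wise weights (illustrative design)."""

    def __init__(self, num_classes, momentum=0.99):
        self.momentum = momentum
        # Start from the uniform-confidence baseline 1 / num_classes.
        self.avg_conf = torch.full((num_classes,), 1.0 / num_classes)

    @torch.no_grad()
    def update(self, probs, labels):
        # Predicted probability of the true class for each sample in the batch.
        batch_conf = probs.gather(1, labels.unsqueeze(1)).squeeze(1)
        # EMA update for every class that appears in this batch.
        for c in labels.unique():
            mask = labels == c
            self.avg_conf[c] = (self.momentum * self.avg_conf[c]
                                + (1 - self.momentum) * batch_conf[mask].mean())

    def weights(self):
        # One possible weighting: down-weight classes the model is
        # already confident on.
        return 1.0 - self.avg_conf
```

The resulting `tracker.weights()` tensor could then serve as the `class_weight_global` argument of the loss function, updated once per batch or per epoch.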
The additional computations to aggregate class-level statistics introduce marginal overhead. However, such operations are highly parallelizable and can be efficiently implemented within most modern deep learning libraries.
As demonstrated by experiments in adversarial training and distillation, the GKL loss is robust across various model architectures. Integrating this loss into current pipelines is straightforward given the modular structure, and the publicly available implementation can accelerate experimentation (2503.08038).
Conclusion
The Generalized Kullback-Leibler Divergence Loss provides a refined approach to utilizing KL divergence in training, particularly for adversarial robustness and knowledge distillation. By decoupling the loss into a weighted MSE and a Cross-Entropy term, and by introducing mechanisms to address both asymmetric optimization and sample-wise bias, the method offers theoretically justified and practically validated improvements over traditional approaches. The strong numerical results and state-of-the-art scores on benchmarks like RobustBench highlight its utility in demanding applications, making it a valuable tool for practitioners aiming to enhance model robustness and performance in challenging environments.