
Effective Gradient Sample Size via Variation Estimation for Accelerating Sharpness-Aware Minimization

Published 24 Feb 2024 in cs.CV and cs.LG | (2403.08821v1)

Abstract: Sharpness-aware Minimization (SAM) has been proposed recently to improve model generalization ability. However, SAM calculates the gradient twice in each optimization step, thereby doubling the computation costs compared to stochastic gradient descent (SGD). In this paper, we propose a simple yet efficient sampling method to significantly accelerate SAM. Concretely, we discover that the gradient of SAM is a combination of the gradient of SGD and the Projection of the Second-order gradient matrix onto the First-order gradient (PSF). PSF exhibits a gradually increasing frequency of change during the training process. To leverage this observation, we propose an adaptive sampling method based on the variation of PSF, and we reuse the sampled PSF for non-sampling iterations. Extensive empirical results illustrate that the proposed method achieved state-of-the-art accuracies comparable to SAM on diverse network architectures.


Summary

  • The paper introduces vSAM, which adaptively samples the Projection of the Second-order gradient matrix onto the First-order gradient (PSF) to reduce redundant gradient computations in Sharpness-Aware Minimization.
  • The paper demonstrates that vSAM attains comparable or superior accuracy on models like ResNet-18 and WideResNet by leveraging a dynamic sampling strategy.
  • The paper’s method yields significant speed-ups, reducing training time by approximately 40% while maintaining robust model generalization.


Introduction

The paper introduces an approach to accelerate Sharpness-Aware Minimization (SAM), an optimizer that balances training loss against loss sharpness to improve model generalization. SAM traditionally suffers from computational inefficiency because it must compute gradients twice per optimization step. The paper proposes an efficient sampling method that adapts to the variation of the Projection of the Second-order gradient matrix onto the First-order gradient (PSF) during optimization. The result is a new variant, variation-based SAM (vSAM), that achieves comparable accuracy with a significant reduction in training time.

SAM and Gradient Decomposition

SAM aims to find parameter values corresponding to flat minima, which are associated with better generalization performance. However, the standard SAM algorithm computes gradients twice to determine the perturbation, increasing its time complexity compared to traditional Stochastic Gradient Descent (SGD). The authors show that the SAM gradient decomposes into the SGD gradient plus the PSF. The primary innovation is recognizing that the PSF's rate of change can be exploited to skip unnecessary computations without sacrificing accuracy (Figure 1).

Figure 1: Accuracy vs. training speed of SGD, SAM, LookSAM, ESAM, SAF, and vSAM (ours). Each connected line represents a method training WideResNet-28-10 and PyramidNet-110 models on CIFAR-100. vSAM substantially accelerates training with almost no reduction in accuracy.
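The decomposition can be made concrete with a first-order Taylor expansion of SAM's perturbed gradient (a sketch in our own notation, with $\rho$ the perturbation radius, $L$ the loss, and $w$ the parameters):

```latex
g_{\text{SAM}}
  \;=\; \nabla L\!\left(w + \rho\,\frac{\nabla L(w)}{\lVert \nabla L(w) \rVert}\right)
  \;\approx\; \underbrace{\nabla L(w)}_{g_{\text{SGD}}}
  \;+\; \underbrace{\frac{\rho}{\lVert \nabla L(w) \rVert}\,\nabla^{2} L(w)\,\nabla L(w)}_{\text{PSF}}
```

The second term is the Hessian applied to the normalized first-order gradient, i.e. the projection of the second-order gradient matrix onto the first-order gradient, which is why the residual $g_{\text{SAM}} - g_{\text{SGD}}$ is called the PSF.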

Adaptive Sampling Strategy

A pivotal aspect of vSAM is its dynamic sampling strategy based on the variation of the PSF. Because the PSF changes at varying rates throughout training, the authors sample it less frequently when it changes slowly and more frequently when it changes rapidly. This adaptive strategy yields significant computational savings, delivering roughly a 40% acceleration in training speed without compromising model generalization.
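One way to realize such a rule is to track the recent L2-norms of the PSF and shrink the sampling interval as their variation grows. The function below is a minimal sketch of this idea; the names, window size, and thresholds are our illustrative choices, not the paper's exact rule.

```python
import numpy as np

def update_sampling_interval(psf_norm_history, base_interval=8,
                             min_interval=1, max_interval=16):
    """Choose how many iterations to wait before recomputing the PSF.

    High recent variance in the PSF's L2-norm means the PSF is changing
    quickly, so we sample more often; low variance lets the cached PSF
    be reused for longer. Illustrative, not the authors' exact policy.
    """
    if len(psf_norm_history) < 2:
        return min_interval  # not enough history: sample every step
    recent = np.asarray(psf_norm_history[-8:])
    # Coefficient of variation: a scale-free measure of how fast the PSF moves.
    cv = recent.std() / (abs(recent.mean()) + 1e-12)
    # Shrink the interval as variation grows, clamped to sane bounds.
    interval = int(base_interval / (1.0 + 10.0 * cv))
    return max(min_interval, min(max_interval, interval))
```

With a flat history the interval stays at its base value, while a rapidly fluctuating PSF norm drives the interval down to sampling every iteration.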

Implementation Details

The vSAM algorithm involves:

  1. Calculating the SGD gradient.
  2. Evaluating whether a fresh PSF computation is needed, based on the PSF's historical variance and its magnitude relative to the SGD gradient.
  3. If a PSF calculation is deemed necessary, computing it and updating the model parameters with the full SAM gradient.
  4. Reusing the last computed PSF in iterations where recalculation is unnecessary.
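The steps above can be sketched as a single training step. This is an illustrative NumPy reconstruction, not the authors' code: `grad_fn` stands in for a backward pass, and a fixed `interval` stands in for vSAM's adaptive sampling decision.

```python
import numpy as np

def vsam_step(w, grad_fn, state, rho=0.05, lr=0.1, interval=4):
    """One vSAM-style update on parameters w (sketch under our assumptions).

    grad_fn(w) returns the loss gradient; state caches the last PSF and
    an iteration counter standing in for the adaptive sampling rule.
    """
    g = grad_fn(w)                                   # 1. SGD gradient
    if state["t"] % interval == 0:                   # 2. sample the PSF this step?
        w_adv = w + rho * g / (np.linalg.norm(g) + 1e-12)
        g_sam = grad_fn(w_adv)                       # 3. full SAM gradient
        state["psf"] = g_sam - g                     #    cache PSF = g_SAM - g_SGD
    update = g + state["psf"]                        # 4. reuse cached PSF otherwise
    state["t"] += 1
    return w - lr * update

# Toy usage: minimize L(w) = 0.5 * ||w||^2, whose gradient is w itself.
grad_fn = lambda w: w
w = np.array([1.0, -2.0])
state = {"t": 0, "psf": np.zeros_like(w)}
for _ in range(50):
    w = vsam_step(w, grad_fn, state)
```

On this toy quadratic the PSF only needs the second forward/backward pass on one step in four, which is the source of vSAM's speed-up in the full method.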

This implementation leverages the stochastic nature of the variance in the PSF's L2-norm, effectively anticipating when full recalculations can be avoided (Figure 2).

Figure 2: ResNet-18.

Experimental Results

The experiments on the CIFAR-10 and CIFAR-100 datasets, using several architectures (ResNet-18, WideResNet-28-10, PyramidNet-110), demonstrate that vSAM achieves accuracy comparable to, or even better than, SAM while significantly reducing training time. Models trained with vSAM reached state-of-the-art accuracy at lower computational cost, validating the effectiveness of adaptive PSF sampling.

Implications and Future Directions

The proposed method strikes a balance between optimization efficiency and generalization performance, making vSAM a compelling choice for applications that require large-scale training under resource constraints. The adaptive sampling approach also provides a versatile framework that could be applied to other optimization problems where redundant computations can be skipped without harm.

Conclusions

The adaptive sampling of gradients as proposed in vSAM offers a promising enhancement over traditional SAM by strategically reducing computational overhead while maintaining model performance. Future work could explore extending this methodology to other types of gradients or regularization terms, thereby broadening the applicability of this approach in more generalized contexts across various deep learning architectures and tasks.
