empirical-sam

The gradient magnitude is a crucial factor in determining the ability of SAM to find flat minima.

https://github.com/lhchau/empirical-sam

Science Score: 44.0%

This score indicates how likely this project is to be science-related based on various indicators:

  • CITATION.cff file
    Found CITATION.cff file
  • codemeta.json file
    Found codemeta.json file
  • .zenodo.json file
    Found .zenodo.json file
  • DOI references
  • Academic publication links
  • Academic email domains
  • Institutional organization owner
  • JOSS paper metadata
  • Scientific vocabulary similarity
    Low similarity (8.2%) to scientific vocabulary
Last synced: 6 months ago

Repository


Basic Info
  • Host: GitHub
  • Owner: lhchau
  • Language: Python
  • Default Branch: master
  • Homepage:
  • Size: 44.9 KB
Statistics
  • Stars: 2
  • Watchers: 1
  • Forks: 0
  • Open Issues: 0
  • Releases: 0
Created over 1 year ago · Last pushed over 1 year ago
Metadata Files
Readme Citation

README.md

Empirical Study on Sharpness-Aware Minimization (SAM)

Table of Contents

  1. Research Question
  2. Optimization Background
  3. Experiments
  4. Observations and Reproducibility
  5. Intuition from Update Rule
  6. Conclusion
  7. Cite this repository

Research Question:

Between the magnitude and direction of the gradient, which is more important to SAM?

We decompose a gradient vector into two components:
  • Direction: $\frac{\nabla L(w_t)}{\|\nabla L(w_t)\|}$
  • Magnitude: $\|\nabla L(w_t)\|$

Our investigation focuses on which of these components has a more significant impact on the success of SAM.
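A minimal NumPy sketch of this decomposition, assuming the gradient is treated as a single flattened vector (the repository may instead decompose it per layer). The helper name `decompose` is hypothetical.

```python
import numpy as np


def decompose(grad, eps=1e-12):
    """Split a flattened gradient into (magnitude, unit direction)."""
    magnitude = float(np.linalg.norm(grad))
    # eps avoids division by zero when the gradient vanishes
    direction = grad / (magnitude + eps)
    return magnitude, direction


g = np.array([3.0, 4.0])
mag, dirn = decompose(g)
print(mag)  # 5.0
# Recombining the two components recovers the original gradient
print(np.allclose(mag * dirn, g))  # True
```

Swapping one component between two gradients while keeping the other is what the SAMDIRECTION and SAMMAGNITUDE variants in the experiments do.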

Optimization Background

Optimization plays a critical role in training deep learning models. Two techniques we discuss here are Stochastic Gradient Descent (SGD) and Sharpness-Aware Minimization (SAM).

Basic Concepts

Optimization in the context of machine learning refers to the process of adjusting the parameters of a model to minimize (or maximize) an objective function, often referred to as the loss function. The goal is to find the parameter values that result in the best performance of the model on a given task.

Stochastic Gradient Descent (SGD)

SGD is a fundamental optimization algorithm used to minimize the loss function. The key idea is to iteratively adjust the model parameters in the direction opposite to the gradient of the loss function with respect to the parameters. This process is repeated until convergence.

The update rule for stochastic gradient descent is given by:

$$ w_{t+1} = w_t - \eta \nabla L(w_t) $$

where:
  • $w_t$ are the model parameters at iteration $t$,
  • $\eta$ is the learning rate,
  • $\nabla L(w_t)$ is the gradient of the loss function with respect to $w_t$.
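The update rule can be sketched on a toy quadratic loss $L(w) = \tfrac{1}{2} w^\top A w$, an illustrative assumption rather than one of the experiments. Because `quadratic_grad` here is deterministic, the loop is plain gradient descent; in SGD the gradient would be evaluated on a random mini-batch at each step. The function names are hypothetical.

```python
import numpy as np

# Toy quadratic loss L(w) = 0.5 * w^T A w with its minimum at w = 0 (illustrative assumption)
A = np.diag([1.0, 10.0])


def quadratic_grad(w):
    return A @ w


def sgd_step(w, grad_fn, lr):
    """One update w_{t+1} = w_t - lr * grad L(w_t)."""
    return w - lr * grad_fn(w)


w = np.array([1.0, 1.0])
for _ in range(100):
    w = sgd_step(w, quadratic_grad, lr=0.05)
print(np.linalg.norm(w))  # about 0.006: both coordinates shrink toward the minimum
```

Each coordinate is multiplied by $1 - \eta a_i$ per step, so the low-curvature coordinate ($a_1 = 1$) converges much more slowly than the high-curvature one ($a_2 = 10$).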

Challenges in Optimization

  • Local Minima: Optimization algorithms can get stuck in local minima, especially in non-convex loss landscapes.
  • Saddle Points: Points where the gradient is zero but that are not minima can slow down optimization.
  • Sharp Minima: Solutions where the loss function has high curvature can lead to poor generalization on new data.

Sharpness-Aware Minimization (SAM)

Sharpness-Aware Minimization (SAM) aims to address the issue of sharp minima: regions where the loss rises steeply as the parameters move away from the minimum, which often correspond to poor generalization. SAM seeks flatter solutions that are robust to perturbations in the model parameters, leading to better generalization.

SAM modifies the loss function to penalize sharp minima: $L^{SAM}(w_t) = \max_{\|\epsilon\| \leq \rho} L(w_t + \epsilon)$

To minimize this objective, SAM uses the following update rule:

$$ w_{t+1} = w_t - \eta \nabla L\left(w_t + \rho \frac{\nabla L(w_t)}{\|\nabla L(w_t)\|}\right) $$
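Maximizing a first-order Taylor expansion of $L(w_t + \epsilon)$ over $\|\epsilon\| \leq \rho$ gives $\epsilon = \rho \nabla L(w_t) / \|\nabla L(w_t)\|$, which is where the perturbation in this update comes from. Below is a minimal sketch of the two-step SAM update on a toy quadratic loss (an illustrative assumption); `sam_step` is a hypothetical name, and the momentum and weight decay from the experiment settings are omitted.

```python
import numpy as np

# Toy quadratic loss L(w) = 0.5 * w^T A w (illustrative assumption)
A = np.diag([1.0, 10.0])


def quadratic_grad(w):
    return A @ w


def sam_step(w, grad_fn, lr, rho, eps=1e-12):
    g = grad_fn(w)
    # Ascent: first-order maximizer of L(w + e) over ||e|| <= rho
    e = rho * g / (np.linalg.norm(g) + eps)
    # Descent: gradient taken at the perturbed point, applied to the original weights
    return w - lr * grad_fn(w + e)


w = np.array([1.0, 1.0])
for _ in range(200):
    w = sam_step(w, quadratic_grad, lr=0.05, rho=0.05)
print(np.linalg.norm(w))
```

Setting `rho=0` removes the ascent step and reduces `sam_step` to a plain gradient step.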

Experiments

Experiment Setting

  • Batch size: 256
  • Initial learning rate (lr): 0.1
  • Momentum: 0.9
  • Weight decay (wd): 0.001
  • $\rho \in \{0.1, 0.2, 0.4\}$

Accuracy Results

| Accuracy (ResNet-18) | $\rho=0.1$ | $\rho=0.2$ | $\rho=0.4$ |
|----------------------|------------|------------|------------|
| SAM | 79.24 | 79.54 | 79.44 |
| SAMDIRECTION | 77.89 | 78.52 | 78.22 |
| SAMMAGNITUDE | 78.73 | 79.23 | 77.94 |

| Accuracy (ResNet-34) | $\rho=0.1$ | $\rho=0.2$ | $\rho=0.4$ |
|----------------------|------------|------------|------------|
| SAM | ?? | 80.95 | 80.89 |
| SAMDIRECTION | 79.35 | 80.02 | 76.38 |
| SAMMAGNITUDE | 80.08 | 79.71 | 72.74 |

| Accuracy (ResNet-50) | $\rho=0.1$ | $\rho=0.2$ | $\rho=0.4$ |
|----------------------|------------|------------|------------|
| SAM | ?? | 81.24 | ?? |
| SAMDIRECTION | ?? | ?? | 80.34 |
| SAMMAGNITUDE | ?? | ?? | ?? |

| Accuracy (WideResNet-28-10) | $\rho=0.1$ | $\rho=0.2$ | $\rho=0.4$ |
|-----------------------------|------------|------------|------------|
| SAM | 83.50 | 83.91 | 83.44 |
| SAMDIRECTION | 82.44 | 82.54 | 82.11 |
| SAMMAGNITUDE | 82.47 | 80.63 | 79.36 |

  • Experiment 1: SAMMAGNITUDE
    • Keeps the magnitude of the SAM update but replaces its direction with the SGD direction.
    • ResNet-18: the results approximate SAM performance.
    • ResNet-34 and WideResNet-28-10: the results fall below SAM and are very sensitive to the perturbation radius $\rho$.
  • Experiment 2: SAMDIRECTION
    • Keeps the direction of the SAM update but replaces its magnitude with the SGD magnitude.
    • ResNet-18: the results differ significantly from SAM and converge to extremely sharp minima.
    • ResNet-34 and WideResNet-28-10: the results fall below SAM but are less sensitive to the perturbation radius $\rho$ than SAMMAGNITUDE.
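One way to read the two hybrid updates, sketched on a single flattened gradient vector. The function names and the whole-vector (rather than per-layer) normalization are assumptions for illustration, not the repository's implementation; the toy gradients come from a quadratic loss.

```python
import numpy as np


def _unit(v, eps=1e-12):
    """Unit vector in the direction of v (eps guards against a zero vector)."""
    return v / (np.linalg.norm(v) + eps)


def hybrid_updates(g_sgd, g_sam):
    """Return (SAMMAGNITUDE update, SAMDIRECTION update) for flattened gradients."""
    # SAMMAGNITUDE: keep SAM's norm, use SGD's direction
    sam_magnitude = np.linalg.norm(g_sam) * _unit(g_sgd)
    # SAMDIRECTION: keep SAM's direction, use SGD's norm
    sam_direction = np.linalg.norm(g_sgd) * _unit(g_sam)
    return sam_magnitude, sam_direction


# Toy gradients from L(w) = 0.5 * w^T A w; the SAM gradient is taken at the perturbed point
A = np.diag([1.0, 10.0])
w = np.array([1.0, 1.0])
rho = 0.1
g_sgd = A @ w
g_sam = A @ (w + rho * g_sgd / np.linalg.norm(g_sgd))

upd_mag, upd_dir = hybrid_updates(g_sgd, g_sam)
print(np.linalg.norm(upd_mag), np.linalg.norm(g_sam))  # same norm (up to eps)
print(np.linalg.norm(upd_dir), np.linalg.norm(g_sgd))  # same norm (up to eps)
```

Each hybrid update is then used in place of the gradient in the usual step $w_{t+1} = w_t - \eta \cdot \text{update}$.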

Flatness Results (ResNet-18)

| Sharpness (ResNet-18) | $\rho=0.1$ | $\rho=0.2$ | $\rho=0.4$ |
|-----------------------|------------|------------|------------|
| SAM | 71.04 | 56.53 | 52.23 |
| SAMDIRECTION | 259.27 | 3737.54 | 1254.17 |
| SAMMAGNITUDE | 93.54 | 114.01 | 373.49 |

  • Experiment 3: Magnitude Comparison
    • The magnitude of SAM updates tends to be larger than that of SGD updates.
    • We count how often the ratio of the SAM update magnitude to the SGD update magnitude exceeds one.
    • This happens in over 50% of cases early in training and rises to 85% at later stages.
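A sketch of the counting in Experiment 3, assuming each "instance" is one training step with a pair of SGD and SAM gradients (the text does not specify the unit being counted). The toy data below use a convex quadratic loss, where $\|g + \rho A \hat{g}\|^2 = \|g\|^2 + 2\rho\, g^\top A \hat{g} + \rho^2 \|A \hat{g}\|^2$ and $g^\top A \hat{g} \ge 0$, so the SAM gradient is never shorter and the fraction is 1.0. This is a property of the toy loss, not a reproduction of the reported 50% to 85%.

```python
import numpy as np


def fraction_sam_larger(pairs):
    """Fraction of (g_sgd, g_sam) pairs with ||g_sam|| / ||g_sgd|| > 1."""
    ratios = np.array([np.linalg.norm(g_sam) / np.linalg.norm(g_sgd) for g_sgd, g_sam in pairs])
    return float(np.mean(ratios > 1.0))


# Toy setup: a convex quadratic loss and random weights standing in for training steps
A = np.diag([1.0, 10.0])
rho = 0.05
rng = np.random.default_rng(0)

pairs = []
for _ in range(1000):
    w = rng.normal(size=2)
    g_sgd = A @ w
    g_sam = A @ (w + rho * g_sgd / np.linalg.norm(g_sgd))
    pairs.append((g_sgd, g_sam))

print(fraction_sam_larger(pairs))  # 1.0 on this convex toy loss
```

On a real network the loss is non-convex, which is one reason the measured fraction can sit well below one.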

Reproducibility

To reproduce our experiments, follow these steps:
  • Install the wandb package and this repository:
    pip install wandb
    pip install -e .
  • Run the scripts:

```
# SAM
python samhessian/train.py --experiment=defaultsam --rho=0.2 --wd=0.001 --project_name=CIFAR100-SAM --framework_name=wandb

# SAMDIRECTION
python samhessian/train.py --experiment=defaultsam --opt_name=samdirection --rho=0.2 --wd=0.001 --project_name=CIFAR100-SAM --framework_name=wandb

# SAMMAGNITUDE
python samhessian/train.py --experiment=defaultsam --opt_name=sammagnitude --rho=0.2 --wd=0.001 --project_name=CIFAR100-SAM --framework_name=wandb
```

Intuition from Update Rule

Inspired by the SAMMAGNITUDE results on ResNet-18, which suggest that the update magnitude correlates with flatness, we run experiments to check whether this behavior also appears with a simpler optimizer (SGD).

Consider the SAM update rule, where the subscript $D$ denotes a gradient computed on the full batch and the subscript $B$ denotes a gradient computed on a mini-batch:

$$ w_{t+1} = w_t - \eta \nabla_B L\left(w_t + \rho \frac{\nabla_B L(w_t)}{\|\nabla_B L(w_t)\|}\right) $$

We focus on the gradient term and apply a first-order Taylor approximation. For convenience in the analysis, we drop the gradient normalization in the denominator; since $\|\nabla_B L(w_t)\|$ is a scalar, this amounts to absorbing it into $\rho$:

$$
\begin{aligned}
\eta \nabla_B L(w_t + \rho \nabla_B L(w_t)) &\approx \eta \left[\nabla_B L(w_t) + \rho \nabla_B^2 L(w_t) \nabla_B L(w_t)\right] \\
&= \eta \left(I + \rho \nabla_B^2 L(w_t)\right) \nabla_B L(w_t)
\end{aligned}
$$

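The mini-batch gradient is therefore rescaled by the matrix $I + \rho \nabla_B^2 L(w_t)$: along a Hessian eigenvector with eigenvalue $\lambda$, its component is multiplied by $1 + \rho \lambda$.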
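The first-order step can be checked numerically. On a quadratic loss the approximation is exact, so the sketch below uses the toy quartic loss $L(w) = \tfrac{1}{4}\sum_i w_i^4$ (an illustrative assumption) and compares the gradient at the unnormalized SAM point with $(I + \rho \nabla^2 L(w)) \nabla L(w)$. It uses full-batch gradients, so the $B$/$D$ distinction does not arise; the function names are hypothetical.

```python
import numpy as np


# Toy non-quadratic loss L(w) = 0.25 * sum(w**4) (illustrative assumption)
def grad(w):
    return w ** 3


def hessian(w):
    return np.diag(3.0 * w ** 2)


def taylor_error(w, rho):
    g = grad(w)
    exact = grad(w + rho * g)  # gradient at the unnormalized SAM point
    approx = (np.eye(w.size) + rho * hessian(w)) @ g  # (I + rho * H) g
    return np.linalg.norm(exact - approx)


w = np.array([0.5, -1.0])
e1 = taylor_error(w, 0.01)
e2 = taylor_error(w, 0.005)
print(e1 / e2)  # close to 4: halving rho quarters the error, i.e. O(rho^2)
```

The error shrinks quadratically in $\rho$, so the rescaling view is most accurate for small perturbation radii.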

Quick Check for Intuition

We introduce an SGD variant named SGDHESS, derived from the following intuition:

$$ w_{t+1} = w_t - \eta \left(I + \rho \nabla_B^2 L(w_t)\right) \nabla L(w_t) $$

For efficient computation, we use the Gauss-Newton approximation $G$ for $\nabla_B^2 L(w_t)$. This implementation is based on AdaHessian.
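A minimal sketch of the SGDHESS step. Instead of the Gauss-Newton approximation used in the text, it assumes an exact Hessian-vector product on a toy quadratic loss; `quadratic_hvp` and `sgdhess_step` are hypothetical names. Using a Hessian-vector product lets $(I + \rho H) g$ be computed as $g + \rho H g$ without forming $H$.

```python
import numpy as np

# Toy quadratic loss L(w) = 0.5 * w^T A w (illustrative assumption)
A = np.diag([1.0, 10.0])


def quadratic_grad(w):
    return A @ w


def quadratic_hvp(w, v):
    # For L(w) = 0.5 * w^T A w the Hessian is A everywhere, so H v = A v
    return A @ v


def sgdhess_step(w, grad_fn, hvp_fn, lr, rho):
    g = grad_fn(w)
    # (I + rho * H) g, computed as g + rho * (H g) without forming H
    return w - lr * (g + rho * hvp_fn(w, g))


w = np.array([1.0, 1.0])
for _ in range(300):
    w = sgdhess_step(w, quadratic_grad, quadratic_hvp, lr=0.02, rho=0.5)
print(np.linalg.norm(w))
```

With $\rho = 0.5$, the gradient component along the curvature-10 axis is multiplied by $1 + 0.5 \cdot 10 = 6$, while the component along the curvature-1 axis is multiplied by only $1.5$, so high-curvature directions receive the largest boost.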

The results indicate that rescaling the gradient $\nabla L(w_t)$ by taking the Hessian $\nabla_B^2 L(w_t)$ into account can lead to flatter minima.

| Flatness (ResNet-18) | Accuracy | Sharpness |
|----------------------|----------|-----------|
| SGD | 77.36% | 207.06 |
| SGDHESS ($\rho = 0.05$) | 77.27% | 192.45 |
| SGDHESS ($\rho = 1$) | 76.75% | 151.62 |
| SGDHESS ($\rho = 2$) | 75.96% | 138.62 |

However, the test accuracy of SGDHESS is lower than that of SGD and decreases as $\rho$ grows. This suggests that the actual SAM update does more than reduce sharpness.

Conclusion

  • The gradient magnitude is a crucial factor in SAM's ability to find flat minima in some architectures, such as ResNet-18; in others, such as ResNet-34 and WideResNet-28-10, keeping only SAM's magnitude performs worse.
  • Both magnitude and direction are important to SAM's ability.

Further Questions

  • How does the magnitude affect SAM's ability?
    • SAMmagnitude/SGDmagnitude > 1?
  • How does the direction affect SAM's ability?
    • SAMdirection · SGDdirection < 0?

Cite this repository

If you use this insight in your research, please cite our work:

```bibtex
@techreport{Luong_Empirical_Study_on_2024,
  author = {Luong, Hoang-Chau},
  month = jun,
  title = {{Empirical Study on Sharpness-Aware Minimization}},
  url = {https://github.com/lhchau/empirical-sam},
  year = {2024}
}
```

Owner

  • Name: Hoang-Chau Luong
  • Login: lhchau
  • Kind: user

Interested in Bayesian Deep Learning

Citation (citation.cff)

abstract: The gradient magnitude is a crucial factor in determining the ability of SAM to find flat minima.
authors:
  - family-names: Luong
    given-names: Hoang-Chau
cff-version: 1.2.0
preferred-citation: 
  type: report
  authors:
  - family-names: Luong
    given-names: Hoang-Chau
  date-released: "2024-06-05"
  keywords:
    - research
    - "Sharpness-Aware Minimization"
    - "training dynamics"
  license: Apache-2.0
  message: If you use this insight in your research, please cite our work.
  repository-code: "https://github.com/lhchau/empirical-sam"
  title: "Empirical Study on Sharpness-Aware Minimization"


Dependencies

setup.py pypi