    Optimizing Deep Learning Models with SAM

    By ProfitlyAI, February 24, 2026 (16 min read)

    Overparameterization, Generalizability, and SAM

    The dramatic success of modern deep learning, especially in the domains of Computer Vision and Natural Language Processing, is built on “overparameterized” models: models with more than enough parameters to memorize the training data perfectly. Functionally, a model can be identified as overparameterized when it can easily achieve near-perfect training accuracy (close to 100%) with near-zero training loss for a given task.

    However, the usefulness of such a model depends on whether it performs well on held-out test data drawn from the same distribution as the training set but unseen during training. This property is called “generalizability”, the ability of a model to maintain performance on new examples, and it is essential for any deep learning model to be practically useful.

    Classical Machine Learning theory tells us that overparameterized models should catastrophically overfit and therefore generalize poorly. However, one of the most surprising discoveries of the past decade is that models in this class often generalize remarkably well.

    This highly counterintuitive phenomenon has been investigated in a series of papers, starting with the seminal works of Belkin et al. (2018) and Nakkiran et al. (2019), which demonstrated that there exists a “double descent” curve for generalizability: as model size increases, generalization first improves, then worsens as the model approaches the interpolation threshold (the overfitting that classical theory predicts), and then improves again beyond that critical threshold, provided the model is trained with the appropriate optimization methods.

    Figure 1: A schematic illustration of the double descent behavior. Image generated by the author with Gemini.

    Figure 1 shows a cartoon of a double descent curve. The y-axis plots test error (a measure of generalizability, where lower error indicates better generalization), while the x-axis shows the number of model parameters. As model size increases, training error (dashed blue line) rapidly approaches zero, as expected.

    The test error (solid blue line) shows a more interesting behavior: it initially decreases with model size (the first descent, highlighted by the left red circle) and then rises to a peak at the interpolation threshold, marked by the vertical dashed line, where the model has the worst generalization. Beyond this threshold, however, in the overparameterized regime, the test error decreases again (the second descent, highlighted by the right red circle) and continues to decline as more parameters are added. This is the regime of interest for modern deep learning models.

    In Machine Learning, one finds the parameters of a model by minimizing a loss function on the training dataset. But does simply minimizing our favorite loss function, such as cross-entropy, on the training dataset guarantee satisfactory generalization properties for the class of overparameterized models? The answer is, generally speaking, no! Whether one is interested in fine-tuning a pre-trained model or training a model from scratch, it is important to optimize the training algorithm to ensure that the resulting model is sufficiently generalizable. This is what makes the choice of the optimizer a crucial design choice.

    Sharpness-Aware Minimization (SAM), introduced in a paper by Foret et al. (2020), is an optimizer designed to improve the generalizability of an overparameterized model. In this article, I present a pedagogical review of SAM that includes:

    1. An intuitive understanding of how SAM works and why it improves generalization.
    2. A deep dive into the algorithm, explaining the key mathematical steps involved.
    3. A PyTorch implementation of the optimizer class in a training loop, including an important caveat for models with BatchNorm layers.
    4. A quick demonstration of the effectiveness of the optimizer in improving generalization on an image classification task with a ResNet-18 model.

    All the code used in this article can be found in this GitHub repo; feel free to play around with it!

    The Notion of Sharpness

    To begin with, let us try to get an intuitive sense of why simply minimizing the loss function may not be enough for optimal generalization.

    A useful picture to keep in mind is that of the loss landscape. For a large overparameterized model, the loss landscape has multiple local and global minima. The local geometries around such minima can vary significantly across the landscape. For example, two minima may have nearly identical loss values, yet differ dramatically in their local geometry: one may be sharp (narrow valley) while the other is flat (wide valley).

    One formal measure for comparing these local geometries is “sharpness”. At any given point w in the loss landscape with loss function L(w), sharpness S(w) is defined as:
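
    In symbols, and consistent with the description that follows, the definition can be written as below. This is a reconstruction in the article's notation; the use of the Euclidean norm for the ball of radius ρ is an assumption.

```latex
S(w) = \max_{\|\epsilon\|_2 \le \rho} \big[ L(w + \epsilon) - L(w) \big]
```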

    Let me unpack the definition. Imagine you are at a point w in the loss landscape and you perturb the parameters such that the new parameter always lies within a ball of radius ρ centered at w. Sharpness is then defined as the maximal change in the loss function within this family of perturbations. In the literature, it is also referred to as the worst-direction sharpness, for obvious reasons.

    One can readily see that for a sharp minimum (a steep, narrow valley) the value of the loss function will change dramatically with small perturbations in certain directions, leading to a high value of sharpness. For a flat minimum (a wide valley), on the other hand, the value of the loss function will change relatively slowly with small perturbations, leading to a lower value of sharpness. Therefore, sharpness gives a measure of flatness for a given minimum in the loss landscape.
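
    To make the contrast concrete, here is a minimal sketch that approximates the worst-direction sharpness by brute force for two one-dimensional toy losses. The helper sharpness_1d and the functions sharp_loss and flat_loss are illustrative assumptions, not part of the article's code; real loss landscapes are high-dimensional and cannot be scanned on a grid like this.

```python
import math


def sharpness_1d(loss, w, rho, num_points=2001):
    """Approximate S(w) = max over |eps| <= rho of loss(w + eps) - loss(w)."""
    base = loss(w)
    worst = -math.inf
    for i in range(num_points):
        # eps sweeps the closed interval [-rho, rho] on a uniform grid
        eps = -rho + 2.0 * rho * i / (num_points - 1)
        worst = max(worst, loss(w + eps) - base)
    return worst


# Two hypothetical losses with the same minimum value (0 at w = 0)
def sharp_loss(w):
    return 50.0 * w ** 2  # steep, narrow valley


def flat_loss(w):
    return 0.5 * w ** 2   # wide valley


rho = 0.1
s_sharp = sharpness_1d(sharp_loss, 0.0, rho)
s_flat = sharpness_1d(flat_loss, 0.0, rho)
print(f"sharp minimum: {s_sharp:.4f}, flat minimum: {s_flat:.4f}")
```

    Both minima have the same loss value, yet with ρ = 0.1 the narrow valley has a sharpness of about 0.5 while the wide valley has a sharpness of about 0.005, a factor of 100 apart.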

    There exists a deep connection between the local geometry of a minimum (specifically the sharpness measure) and the generalization property of the resulting model. Over the last decade, a large amount of theoretical and empirical research has gone into clarifying this connection. For instance, as the paper by Keskar et al. (2016) points out, global minima with similar values of the loss function can have significantly different generalization properties depending on their sharpness measures.

    The basic lesson that seems to emerge from these studies is: flatter (less sharp) minima are positively correlated with better generalization of models. In particular, the model should avoid getting stuck in a sharp minimum during training if it is to generalize well. Therefore, to train a model with good generalization, one needs to ensure that the optimization procedure not only minimizes the loss function but also seeks to maximize the flatness (or equivalently minimize the sharpness) of the minima.

    This is precisely the problem that the SAM optimizer is designed to solve, and this is what we turn to in the next section.

    A quick aside: note that the above picture gives a conceptual explanation of why an overparameterized model can potentially avoid the problem of overfitting. It is because a large model has a rich loss landscape which provides a multiplicity of flat global minima with excellent generalization properties.

    The Sharpness-Aware Minimization (SAM) Algorithm

    Let us recall the standard optimization of a model. It involves finding model parameters that minimize a given loss function computed over a mini-batch B. At every time-step, one computes the gradient of the loss with respect to the parameters, and updates the parameters according to the rule:
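
    Writing η for the learning rate (a symbol assumed here) and w_t for the parameters at step t, the plain gradient-descent form of this rule reads as follows, with L evaluated on the mini-batch B. This is a reconstruction; optimizers such as Adam modify the step but keep the same structure.

```latex
w_{t+1} = w_t - \eta \, \nabla_w L(w) \big|_{w = w_t}
```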

    Unlike SGD or Adam, SAM does not minimize L directly. Instead, at a given point in the loss landscape, it first scans a neighborhood of a given size ρ and finds the perturbation that maximizes the loss function. In the second step, it minimizes this maximum loss. This allows the optimizer to find parameters that lie in neighborhoods with uniformly low loss value, which results in smaller sharpness values and flatter minima.

    Let’s discuss the procedure in a little more detail. The loss function for the SAM optimizer is:
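
    In the notation of the sharpness definition, the SAM objective can be written as follows. This is a reconstruction that assumes the Euclidean norm and suppresses the mini-batch argument and any weight-decay term.

```latex
L^{\mathrm{SAM}}(w) = \max_{\|\epsilon\|_2 \le \rho} L(w + \epsilon)
```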

    where ρ denotes the upper bound on the size of the perturbations. The perturbation that maximizes the function L (often called the adversarial perturbation, since it maximizes the usual loss) can be found by noting that:
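
    Spelled out, the chain of steps is presumably the following, where ϵ*(w) is a symbol assumed here for the maximizing perturbation and the Taylor expansion is kept to first order in ϵ:

```latex
\epsilon^{*}(w)
  = \arg\max_{\|\epsilon\|_2 \le \rho} L(w + \epsilon)
  \approx \arg\max_{\|\epsilon\|_2 \le \rho} \big[ L(w) + \epsilon^{\top} \nabla_w L(w) \big]
  = \arg\max_{\|\epsilon\|_2 \le \rho} \epsilon^{\top} \nabla_w L(w)
```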

    where the second equality is an approximation obtained by Taylor-expanding the perturbed function in the first step to first order in ϵ, and the last equality follows from the ϵ-independence of the first term in square brackets in the previous step. This last equality can be solved for the adversarial perturbation as follows:
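
    A linear function of ϵ over a ball of radius ρ is maximized by the vector of length ρ that points along the gradient, so the solution can be written as below. The hat notation is an assumption introduced here to mark the first-order approximation.

```latex
\hat{\epsilon}(w) = \rho \, \frac{\nabla_w L(w)}{\|\nabla_w L(w)\|_2}
```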

    Plugging this back into the equation for the SAM loss, one can compute the gradients of the SAM loss to leading order in derivatives of ϵ:
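
    Dropping the terms that involve derivatives of the perturbation with respect to w, the result can be sketched as follows, where the hatted ϵ denotes the first-order adversarial perturbation of length ρ along the gradient (notation assumed here):

```latex
\nabla_w L^{\mathrm{SAM}}(w) \approx \nabla_w L(w) \big|_{w + \hat{\epsilon}(w)}
```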

    This is the most important equation for the optimization procedure. To leading order in derivatives of ϵ, the gradients of the SAM loss function can be approximated by the gradients of the usual loss function evaluated at the adversarially perturbed point. Using the above formula for the gradients, one can now execute the standard optimizer step:
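
    With η denoting the learning rate of the base optimizer (a symbol assumed here) and the hatted ϵ the adversarial perturbation, the plain-SGD version of this step would read:

```latex
w_{t+1} = w_t - \eta \, \nabla_w L(w) \big|_{w = w_t + \hat{\epsilon}(w_t)}
```

    Note that the gradient is taken at the perturbed point, but the update is applied to the original weights w_t.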

    This completes one full SAM iteration. Next, let us translate the algorithm from English to PyTorch.

    PyTorch Implementation in a Training Loop

    An illustrative example of a training loop with a SAM optimizer is given in the code block sam_training_loop.py. For concreteness, we have chosen a generic image classification problem, but the same structure broadly holds for a wide range of Computer Vision and NLP tasks. The SAM optimizer class is shown in the code block sam_optimizer_class.py.

    Note that defining a SAM optimizer requires specifying two pieces of information:

    1. A base optimizer (like SGD or Adam), since SAM involves a standard optimizer step at the end.
    2. A hyperparameter ρ, which puts an upper bound on the size of the admissible perturbations.
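
    As a rough illustration of these two ingredients, here is a minimal sketch of a SAM wrapper. It is not the author's sam_optimizer_class.py: the class layout, the choice to wrap an already-constructed base optimizer, and the small constant added to the gradient norm are all assumptions made for this example, and it assumes all parameters live on a single device. Only the method names first_step and second_step are taken from the text.

```python
import torch


class SAM:
    """Minimal SAM wrapper around an already-constructed base optimizer."""

    def __init__(self, base_optimizer, rho=0.05):
        if rho < 0.0:
            raise ValueError(f"rho must be non-negative, got {rho}")
        self.base_optimizer = base_optimizer
        self.rho = rho
        self._eps = {}  # parameter -> perturbation applied in first_step

    def _params_with_grad(self):
        for group in self.base_optimizer.param_groups:
            for p in group["params"]:
                if p.grad is not None:
                    yield p

    @torch.no_grad()
    def first_step(self):
        """Perturb the weights in place: w -> w + rho * g / ||g||_2."""
        params = list(self._params_with_grad())
        if not params:
            return
        # Global L2 norm of the gradient across all parameters
        grad_norm = torch.stack([p.grad.norm(2) for p in params]).norm(2)
        scale = self.rho / (grad_norm + 1e-12)  # constant guards against 0/0
        for p in params:
            eps = p.grad * scale
            p.add_(eps)
            self._eps[p] = eps  # remembered so second_step can undo it

    @torch.no_grad()
    def second_step(self):
        """Restore w, then let the base optimizer step using the gradients
        currently in .grad (expected to come from the perturbed point)."""
        for p, eps in self._eps.items():
            p.sub_(eps)
        self._eps.clear()
        self.base_optimizer.step()

    def zero_grad(self):
        self.base_optimizer.zero_grad()
```

    A typical construction under these assumptions would be SAM(torch.optim.SGD(model.parameters(), lr=0.05), rho=0.05). Keeping the base optimizer as a separate object means any momentum or adaptive state stays inside it, untouched by the perturbation.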

    A single iteration of the optimizer involves two forward passes and two backward passes. Let’s trace out the key steps of the code in sam_training_loop.py:

    1. Line 5 computes the loss function L(w, B) for the current mini-batch B: the first forward pass.
    2. Line 6 computes the gradients of the loss function L(w, B): the first backward pass.
    3. Line 7 calls the function sam_optimizer.first_step from the SAM optimizer class (see below), which computes the adversarial perturbation using the formula discussed above and perturbs the weights of the model accordingly.
    4. Line 10 computes the loss function for the perturbed model: the second forward pass.
    5. Line 11 computes the gradients of the loss function for the perturbed model: the second backward pass.
    6. Line 12 calls the function sam_optimizer.second_step from the optimizer class (see below), which restores the weights to w_t and then uses the base optimizer to update the weights w_t using the gradients computed at the perturbed point.
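
    The six steps above can be sketched as a single training function. This is an illustration, not the author's sam_training_loop.py, so the line numbers quoted in the list do not apply to it. The function name train_one_epoch and its arguments are hypothetical, and sam_optimizer is assumed to be any object exposing first_step, second_step, and zero_grad.

```python
import torch


def train_one_epoch(model, loader, loss_fn, sam_optimizer, device="cpu"):
    """One epoch of SAM training: two forward/backward passes per batch."""
    model.train()
    total_loss, num_samples = 0.0, 0
    for inputs, targets in loader:
        inputs, targets = inputs.to(device), targets.to(device)

        # First forward and backward pass: loss and gradients at w
        loss = loss_fn(model(inputs), targets)
        loss.backward()
        sam_optimizer.first_step()   # w -> w + eps
        sam_optimizer.zero_grad()    # discard the gradients taken at w

        # Second forward and backward pass: gradients at w + eps
        loss_fn(model(inputs), targets).backward()
        sam_optimizer.second_step()  # restore w, then base optimizer step
        sam_optimizer.zero_grad()

        # Report the unperturbed loss (assumes loss_fn averages over the batch)
        total_loss += loss.item() * inputs.size(0)
        num_samples += inputs.size(0)
    return total_loss / max(num_samples, 1)
```

    The logged loss is the one from the first pass, so it stays comparable to the loss of a non-SAM run.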

    A Caveat: SAM with BatchNorm

    There is an important point to keep in mind when deploying SAM in a training loop if the model has any module that includes batch-normalization layers. During training, BatchNorm normalizes using the current batch statistics and updates the running statistics at every forward pass. For evaluation, it uses the running statistics.

    Now, as we saw above, SAM involves two forward passes per iteration. For the first pass, BatchNorm works in the standard fashion. During the second pass, however, we are using perturbed weights to compute the loss, and the naive training function in the code block sam_training_loop.py will allow the BatchNorm layers to update the running statistics during the second pass as well. This is undesirable because the running statistics should only reflect the behavior of the original model, not the perturbed model, which is merely an intermediate step for computing gradients. Therefore, one has to explicitly disable the running statistics update during the second pass and enable it before the next iteration.

    For this purpose, we will use two explicit functions, disable_bn_stats and enable_bn_stats, in the training loop. Simple examples of such functions are shown in the code block running_stat.py; they toggle the track_running_stats parameter (line 4 and line 9) of the BatchNorm function in PyTorch. The modified training loop is given in the code block mod_train.py.
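
    A minimal sketch of such a pair of helpers is given below. It follows the description above (toggling track_running_stats) but is not the author's running_stat.py, so the line numbers quoted above do not refer to it. One assumption to flag: enable_bn_stats switches tracking on for every BatchNorm layer, which is only appropriate if all of them were tracking statistics to begin with.

```python
import torch

# Public BatchNorm variants whose running statistics we want to freeze
_BN_TYPES = (torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d)


def disable_bn_stats(model):
    """Stop BatchNorm layers from updating running_mean / running_var.

    In training mode the layers still normalize with the batch statistics.
    """
    for module in model.modules():
        if isinstance(module, _BN_TYPES):
            module.track_running_stats = False


def enable_bn_stats(model):
    """Resume running-statistics updates in all BatchNorm layers."""
    for module in model.modules():
        if isinstance(module, _BN_TYPES):
            module.track_running_stats = True
```

    In a SAM loop, one would call disable_bn_stats(model) just before the second forward pass and enable_bn_stats(model) after the optimizer step, so that only the unperturbed pass contributes to the running statistics.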

    Demo: Image classification with ResNet-18

    Finally, let’s demonstrate how SAM optimization improves the generalization of a model in a concrete example. We will consider an image classification problem using the Fashion-MNIST dataset (MIT License): it consists of 60,000 training images and 10,000 testing images across 10 distinct, mutually exclusive classes, where each image is grayscale with 28×28 pixels.

    As the classifier model, we will choose a PreAct ResNet-18 without any pre-training. While a discussion of the precise ResNet-18 architecture is not very relevant for our purpose, let us recall that the model consists of a sequence of building blocks, each of which is made up of convolutional layers, BatchNorm layers, and ReLU activations, with skip connections. The PreAct (pre-activation) indicates that the activation function (ReLU) comes before the convolutional layer in each block. For the standard ResNet-18, it is the other way round. I would refer the reader to the paper, He et al. (2015), for more details on the architecture.

    What is important to note, however, is that this model has about 11.2 million parameters, and therefore, from the perspective of classical Machine Learning, it is an overparameterized model with the parameter-to-sample ratio being about 186:1. Also, since the model includes BatchNorm layers, we have to be careful about disabling the running statistics for the second pass while using SAM.

    We are now ready to carry out the following experiment. We train the model on the Fashion-MNIST dataset with the standard SGD optimizer first and then with the SAM optimizer using the same SGD as the base optimizer. We will consider a simple setup with a fixed learning rate lr=0.05 and with the momentum and the weight decay both set to zero. The hyperparameter ρ in SAM is set to 0.05. All runs are carried out on a single A100 GPU.

    Since each SAM weight update requires two backpropagation steps (one to compute the perturbations and another to compute the final gradients), for a fair comparison each non-SAM training run must execute twice as many epochs as each SAM training run. We will therefore have to compare a metric from one epoch of a SAM training run to a metric from two epochs of a non-SAM training run. We will call this a “standardized epoch”, and a metric recorded at standardized epochs will be labelled metric_st. We will restrict the experiment to 150 standardized epochs, which means the SAM training runs for 150 epochs and the non-SAM training runs for 300 epochs. We will train the SAM-optimized model for an additional 50 epochs to get an idea of how the model behaves under longer training.
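
    One way to put the two runs on a common axis is to subsample the per-epoch metrics, as in the sketch below. The helper name standardize and its argument names are hypothetical, and the metric lists are placeholders rather than measured values.

```python
def standardize(per_epoch_metric, epochs_per_standardized_epoch):
    """Keep the metric recorded at the end of each standardized epoch.

    Use 1 for a SAM run and 2 for a non-SAM run, so that entry k of the
    result corresponds to the same number of backward passes in both runs.
    """
    step = epochs_per_standardized_epoch
    if step < 1:
        raise ValueError("epochs_per_standardized_epoch must be >= 1")
    return per_epoch_metric[step - 1::step]


# Placeholder histories: the epoch index stands in for a real metric
sgd_history = list(range(1, 301))  # 300 raw SGD epochs
sam_history = list(range(1, 151))  # 150 raw SAM epochs

sgd_st = standardize(sgd_history, 2)
sam_st = standardize(sam_history, 1)
print(len(sgd_st), len(sam_st))
```

    Both standardized histories have 150 entries, and the k-th SGD entry is the value recorded after raw epoch 2k.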

    To determine which optimizer gives better generalization, we will compare the following two metrics after each standardized epoch of training:

    1. Test accuracy: Performance of the model on the test dataset.
    2. Generalization gap: Difference between the training accuracy and the test accuracy.

    The test accuracy is an absolute measure of how well the model generalizes after a certain number of training epochs. The generalization gap, on the other hand, is a diagnostic that tells you how much a model is overfitting at a given stage of training.

    Let us begin by comparing the training_loss_st and training_accuracy_st graphs, as shown in Figure 3. The model with SGD reaches near-zero loss and close to 99% training accuracy within 150 epochs, as expected of an overparameterized model. SAM trains slowly compared to SGD and takes more standardized epochs to reach near-perfect training accuracy: the training loss as well as the training accuracy continues to improve as one trains the SAM-optimized model for more epochs beyond the stipulated 150.

    Figure 3: Comparison of the standardized training losses and training accuracies.

    Test accuracy. The graphs in Figure 4 compare the test accuracies for the two cases after each standardized epoch.

    Figure 4: Comparison of the standardized test accuracies.

    The SGD-optimized model reaches 92% test accuracy around epoch 50 and plateaus around that value for the next 100 epochs. The SAM-optimized model generalizes poorly in the initial phase of training (until around 80 epochs), as is evident from the lower test accuracies in this phase compared to the SGD graph. However, around epoch 80, it catches up with the SGD graph and eventually surpasses it by a thin margin.

    For this specific run, at the end of 150 epochs, the test accuracy for SAM stands at test_SAM = 92.5%, while that for SGD is test_SGD = 92.0%. Note that this is despite the fact that the SAM-trained model has a much lower training accuracy (and a correspondingly higher training loss) at this stage. If one trains the SAM model for another 50 epochs, the test accuracy improves slightly to 92.7%.

    Generalization gap. The evolution of the generalization gap after each standardized epoch over the course of training is shown in Figure 5.

    Figure 5: Comparison of the generalization gap.

    The gap for the SGD model grows steadily with training and after 150 epochs reaches gap_SGD = 6.8%, while for SAM it grows much more slowly and reaches gap_SAM = 2.3%. On further training for another 50 epochs, the gap for SAM climbs to around 3%, but it is still much lower than the SGD value.
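
    As a quick arithmetic check, the quoted gaps and test accuracies at 150 standardized epochs imply the training accuracies of the two runs, since the gap is defined as training accuracy minus test accuracy. The function names below are illustrative; the four input numbers are the ones quoted in the text.

```python
def generalization_gap(train_acc, test_acc):
    """Gap in percentage points: training accuracy minus test accuracy."""
    return train_acc - test_acc


def implied_train_accuracy(test_acc, gap):
    """Invert the definition: training accuracy = test accuracy + gap."""
    return test_acc + gap


# Values quoted in the article at 150 standardized epochs (percent)
train_sgd = implied_train_accuracy(test_acc=92.0, gap=6.8)
train_sam = implied_train_accuracy(test_acc=92.5, gap=2.3)
print(f"SGD train accuracy: {train_sgd:.1f}%, SAM train accuracy: {train_sam:.1f}%")
```

    This gives roughly 98.8% for SGD, in line with the near-99% training accuracy reported earlier, and roughly 94.8% for SAM, which shows how far the SAM run still is from fitting the training set.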

    While the difference in test accuracies between the two optimizers is small for the Fashion-MNIST dataset, there is a non-trivial difference in the generalization gaps, which demonstrates that optimizing with SAM leads to better generalization.

    Concluding Remarks

    In this article, I presented a pedagogical review of SAM as an optimizer that significantly improves the generalization of overparameterized deep learning models. We discussed the motivation and intuition behind SAM, walked through a step-by-step breakdown of the algorithm, and studied a simple example demonstrating its effectiveness compared to a standard SGD optimizer.

    There are several interesting aspects of SAM that I did not have a chance to cover here. Let me briefly mention two of them. First, as a practical tool, SAM is particularly useful for fine-tuning pre-trained models on small datasets, something explored in detail by Foret et al. (2020) for CNN-type architectures and in many subsequent works for more general architectures. Second, since we opened our discussion with the connection between flat minima in the loss landscape and generalization, it is natural to ask whether a SAM-trained model, which demonstrably improves generalizability, does indeed converge to a flatter minimum. This is a non-trivial question, requiring a careful analysis of the Hessian spectrum of the trained model and a comparison with its SGD-trained counterpart. But that’s a story for another day!


    Thanks for reading! If you have enjoyed the article and would be interested in reading more pedagogical articles on deep learning, do follow me on Medium and LinkedIn. Unless otherwise stated, all images and graphs used in this article were generated by the author.


