    Capturing and Deploying PyTorch Models with torch.export

    By ProfitlyAI | August 20, 2025


    When we think about a model development project, most of our attention tends to focus on the big items, such as creating and curating datasets, designing the best ML architecture, acquiring appropriately large GPU clusters for training, and building an inference solution that meets target quality-of-service (QoS) requirements. However, it is often the small details that turn out to be our Achilles' heel, leading to unanticipated bugs and significant production delays.

    One detail that is often overlooked is the handoff of a trained model to the inference environment. While this handoff may seem trivial, it can easily become the source of a great deal of frustration. The training and inference environments are rarely identical, with differences ranging from runtime libraries to hardware targets. To navigate these differences, the AI/ML model developer must ensure:

    1. that the model definition, along with its trained weights, is loaded properly into the inference environment, and
    2. that the model's behavior does not change.

    This post will focus on the first challenge: reliable restoration of the model definition and state in an inference environment. We will survey some of the legacy options and their shortcomings. We will then introduce the new torch.export API and demonstrate it in action on a toy model built with HuggingFace's transformers library (version 4.54.0). For our experiments, we will use an Amazon EC2 g5.xlarge instance (containing an NVIDIA A10G GPU and 4 vCPUs) running a PyTorch (2.7) Deep Learning AMI (DLAMI).

    One of the technologies underlying torch.export is TorchDynamo, a key component of PyTorch's graph compilation solution, torch.compile. In a recent post we demonstrated the power of torch.compile and its importance in optimizing the runtime performance of AI/ML models. In some ways this post can be viewed as a sequel: we will revisit some of the key concepts (notably graph breaks), reuse the same toy model, and demonstrate the use of torch.compile for just-in-time (JIT) compilation in our inference environment. While we recommend reading the previous post, it is not a prerequisite for this one.

    Disclaimers:

    As of the time of this writing (PyTorch 2.8.0), torch.export is a "prototype" feature. While the behaviors discussed in this post are likely to mostly remain, the API definitions may change. Be sure to align your use of torch.export with the latest API versions.

    This post covers only a subset of torch.export's features and behaviors. It should not be viewed as a substitute for the official PyTorch documentation. There are a number of pages covering the export feature, including an introductory tutorial, an overview of the programming model, and a guide to solving export challenges.

    While we will demonstrate its functionality on a toy example, torch.export depends heavily on the model's details and may exhibit very different behavior in your own project.

    The code we share is for demonstration purposes only and should not be relied on for correctness or optimality. Please do not interpret our choice of platform, framework, or any other tool or library as an endorsement of its use.

    A Toy HuggingFace Model

    To facilitate our discussion, we define a simple image-to-text generative model using the HuggingFace transformers library. For simplicity, we skip the training phase and assume that the default constructor returns a pretrained model.

    import torch
    
    NUM_TOKENS = 1024
    MAX_SEQ_LEN = 256
    PAD_ID = 0
    START_ID = 1
    END_ID = 2
    
    # Set up an image-to-text model.
    def get_model():
    
        # import transformers utilities
        from transformers import (
            VisionEncoderDecoderModel,
            VisionEncoderDecoderConfig,
            AutoConfig
        )
    
        config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(
            encoder_config=AutoConfig.for_model("vit"),  # vit encoder
            decoder_config=AutoConfig.for_model("gpt2")  # gpt2 decoder
        )
        config.decoder.vocab_size = NUM_TOKENS
        config.decoder.use_cache = False
        config.decoder_start_token_id = START_ID
        config.pad_token_id = PAD_ID
        config.eos_token_id = END_ID
        config.max_length = MAX_SEQ_LEN
    
        model = VisionEncoderDecoderModel(config=config)
        model.encoder.pooler = None  # remove unused pooler
        model.eval()  # prepare the model for evaluation
        return model

    We define an auto-regressive image-to-text generator that uses the encoder and decoder components of the model to produce a caption for the input image. For simplicity, we use a basic implementation, leaving out common optimization techniques such as KV caching.

    # generate the next token
    def generate_token(decoder, encoder_hidden_states, sequence):
        outputs = decoder(
            sequence,
            encoder_hidden_states
        )
        logits = outputs[0][:, -1, :]
        return torch.argmax(logits, dim=-1, keepdim=True)
    
    # simple auto-regressive sequence generator
    def image_to_text_generator(encoder, decoder, image):
        # run encoder
        encoder_hidden_states = encoder(image)[0]
    
        # initialize sequence
        generated_ids = torch.ones(
            (image.shape[0], 1),
            dtype=torch.long,
            device=image.device
        ) * START_ID
    
        for _ in range(MAX_SEQ_LEN):
            # generate next token
            next_token = generate_token(
                decoder,
                encoder_hidden_states,
                generated_ids
            )
            generated_ids = torch.cat([generated_ids, next_token], dim=-1)
            if (next_token == END_ID).all():
                break
    
        return generated_ids

    The following code block demonstrates the use of our generator on a batch of random input images.

    import os, time, random, torch
    
    torch.manual_seed(42)
    random.seed(42)
    
    BATCH_SIZE = 64
    EXPORT_PATH = '/tmp/export/'
    
    def test_inference(model_path=EXPORT_PATH, mode=None, compile=False):
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        rnd_image = torch.randn(BATCH_SIZE, 3, 224, 224).to(device)
        encoder, decoder = load_model(model_path, mode)
        encoder = encoder.to(device)
        decoder = decoder.to(device)
    
        if compile:
            encoder = torch.compile(encoder, mode="reduce-overhead")
            decoder = torch.compile(decoder, dynamic=True)
            # run a few warmup rounds
            for i in range(10):
                image_to_text_generator(encoder, decoder, rnd_image)
    
        t0 = time.perf_counter()
    
        # optionally enable mixed precision
        with torch.amp.autocast(device, dtype=torch.bfloat16, enabled=True):
            with torch.no_grad():
                caption = image_to_text_generator(encoder, decoder, rnd_image)
    
        total_time = time.perf_counter() - t0
        print(f'batched inference total time: {total_time}')

    The encoder and decoder models are loaded via a load_model utility. We define an initial implementation here and amend it later based on our choice of model capturing strategy.

    To prepare our model for the export process, we define a pass-through wrapper class for the decoder. This wrapper ensures that the model can be traced using positional (rather than keyword) arguments, which is a current requirement of torch.export.

    class DecoderWrapper(torch.nn.Module):
        def __init__(self, decoder_model):
            super().__init__()
            self.decoder = decoder_model
    
        def forward(self, input_ids, encoder_hidden_states):
            return self.decoder(
                input_ids=input_ids,
                encoder_hidden_states=encoder_hidden_states,
                use_cache=False,
                output_attentions=False,
                output_hidden_states=False,
                return_dict=False
            )
    
    def load_model(path=EXPORT_PATH, mode=None):
        model = get_model()
        encoder = model.encoder
        decoder = model.decoder
        return encoder, DecoderWrapper(decoder)

    Now that we have defined our toy model, let's explore different strategies for capturing and deploying it to an inference environment.

    Model Capturing and Deployment Strategies

    In this section, we review two common methods for capturing and deploying an AI/ML model's state: weights-only capture and conversion to a serializable intermediate graph representation using TorchScript.

    Weights-Only Capture

    The first option is to use torch.save to capture only the model weights, not the model definition. This requires you to explicitly redefine the PyTorch model in the inference environment. There are a number of ways to carry over the code definition, including copy-pasting (which is very error-prone), pulling from a shared code repository, or using Python packages. If the model definition relies on specific Python package dependencies, you will need to make sure that those packages exist in the inference environment and that their versions match the training environment.

    The following code block demonstrates capturing and loading the model weights for our toy model:

    def capture_model(model, path=EXPORT_PATH):
        # weights only
        weights_path = os.path.join(path, "weights.pth")
        torch.save(model.state_dict(), weights_path)
    
    def load_model(path=EXPORT_PATH, mode=None):
        if mode == 'weights':
            model = get_model()
            weights_path = os.path.join(path, "weights.pth")
            state_dict = torch.load(weights_path, map_location="cpu")
            model.load_state_dict(state_dict)
            return model.encoder, DecoderWrapper(model.decoder)
        else:
            model = get_model()
            return model.encoder, DecoderWrapper(model.decoder)

    One advantage of this method is that it provides maximum flexibility for tuning the model's configuration to the inference environment. For example, you can apply machine-specific optimizations that increase the throughput of the inference workload. This freedom is especially important for very large models that require advanced sharding strategies for execution.

    However, this method assumes that you can design and configure the inference environment as you wish, which cannot be taken for granted. Some inference environments are extremely constrained, with limited control over the runtime libraries and package installations.

    The separation between the model definition and the model weights can also be the source of all kinds of ugly bugs. Ensuring proper alignment between the source code and the model weights requires disciplined version management. A preferred approach would be to bundle the model definition and weights into a single archive.
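
    One lightweight mitigation, sketched below under the assumption that the model exposes a HuggingFace-style config object, is to save version metadata and the configuration alongside the weights so that the two artifacts stay in sync. This is an illustrative option, not the approach used in the remainder of this post:

    import os
    import torch
    import transformers

    def capture_checkpoint(model, path=EXPORT_PATH):
        # sketch: bundle the weights with config and version metadata so the
        # inference side can detect drift between the code and the weights
        checkpoint = {
            "torch_version": torch.__version__,
            "transformers_version": transformers.__version__,
            "config": model.config.to_dict(),  # HuggingFace PretrainedConfig
            "state_dict": model.state_dict(),
        }
        torch.save(checkpoint, os.path.join(path, "checkpoint.pth"))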

    TorchScript Variants

    For many years, the primary method for capturing a PyTorch model along with its weights was TorchScript. Despite its recent deprecation, it remains widely popular. TorchScript encapsulates two different capturing solutions, torch.jit.script and torch.jit.trace, both of which convert PyTorch models into serializable graph representations. Such a graph can be loaded into another environment and run as a standalone PyTorch program with minimal runtime dependencies.

    The scripting and tracing functionalities are complementary. Scripting performs ahead-of-time (AOT) static analysis of the source code, while tracing records the actual execution of the model on a sample input. Scripting is able to capture more complexity in the graph, such as conditional control flow and dynamic shapes, whereas tracing captures just the execution path and tensor shapes dictated by the sample input it runs on. On the other hand, torch.jit.trace supports more operations than torch.jit.script (e.g., see here). Often, scripting will succeed where tracing fails and vice versa. Sometimes, a combination of both methods is required.
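
    The toy module below (a hypothetical example, unrelated to our image-to-text model) illustrates the difference: tracing bakes in whichever branch the sample input happens to take, while scripting preserves the conditional logic:

    import torch

    class Gate(torch.nn.Module):
        # the Python `if` depends on tensor data
        def forward(self, x):
            if x.sum() > 0:
                return x + 1
            return x - 1

    m = Gate()
    pos, neg = torch.ones(3), -torch.ones(3)

    traced = torch.jit.trace(m, pos)  # records only the branch taken for `pos`
    print(traced(neg))                # still computes x + 1, i.e., the wrong branch

    scripted = torch.jit.script(m)    # static analysis keeps both branches
    print(scripted(neg))              # computes x - 1, as intended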

    Let's now attempt to convert our toy model to TorchScript. Since the sequence generator calls the encoder once and the decoder iteratively, we will capture them as separate graphs. The encoder takes fixed-shape inputs and has no input-dependent conditional logic, so we can apply the more flexible tracing option. However, the sequence that we feed to the decoder grows in size each time a token is generated, so we have no choice but to use the scripting option.

    HuggingFace supports TorchScript via a dedicated configuration. See the HuggingFace documentation for more details.

    config.decoder.torchscript = True
    config.encoder.torchscript = True

    In the code block below, we extend our capture and loading utilities with TorchScript support. Note the inclusion of the torch.jit.freeze optimization during capture and the use of the torch.jit.optimize_for_inference optimization on the target device.

    def capture_model(model, path=EXPORT_PATH):
        # weights only
        weights_path = os.path.join(path, "weights.pth")
        torch.save(model.state_dict(), weights_path)
    
        encoder = model.encoder
        decoder = DecoderWrapper(model.decoder)
    
        # torchscript encoder using trace
        example = torch.randn(1, 3, 224, 224)
        encoder_jit = torch.jit.trace(encoder, example)
        # optionally apply jit.freeze optimization
        encoder_jit = torch.jit.freeze(encoder_jit)
        encoder_path = os.path.join(path, "encoder.pt")
        torch.jit.save(encoder_jit, encoder_path)
    
        try:
            # torchscript decoder using scripting
            decoder_jit = torch.jit.script(decoder)
            # optionally apply jit.freeze optimization
            decoder_jit = torch.jit.freeze(decoder_jit)
            decoder_path = os.path.join(path, "decoder.pt")
            torch.jit.save(decoder_jit, decoder_path)
        except Exception as e:
            print(f'torch.jit.script(model.decoder) failed\n{e}')
    
    def load_model(path=EXPORT_PATH, mode=None):
        if mode == 'weights':
            model = get_model()
            weights_path = os.path.join(path, "weights.pth")
            state_dict = torch.load(weights_path, map_location="cpu")
            model.load_state_dict(state_dict)
            return model.encoder, DecoderWrapper(model.decoder)
        elif mode == 'torchscript':
            encoder_path = os.path.join(path, "encoder.pt")
            decoder_path = os.path.join(path, "decoder.pt")
            encoder = torch.jit.load(encoder_path)
            decoder = torch.jit.load(decoder_path)
            # optionally apply target-device optimization
            encoder = torch.jit.optimize_for_inference(encoder)
            decoder = torch.jit.optimize_for_inference(decoder)
            return encoder, decoder
        else:
            model = get_model()
            return model.encoder, DecoderWrapper(model.decoder)

    Unfortunately, our capture utility fails when trying to script the decoder model. A common problem with TorchScript is that it often fails on complex models. While many issues can be bypassed by using torch.jit.trace, it cannot be applied to models with dynamic shapes or conditional logic, such as our decoder. It is often possible to rewrite the model implementation to be script-compliant, but that can require a great deal of painstaking (and ugly) work. In the case of our decoder model, it would require a lot of intrusive patchwork to the transformers library source code.

    For more on the topic of TorchScript, please see the official documentation.

    Model Capturing with torch.export

    The new approach to capturing a model for deployment is torch.export. Similar to torch.jit.trace, export works by tracing the model's execution on input samples. However, unlike torch.jit.trace, export includes support for dynamism and conditional control flow. The output of the export utility is an intermediate graph representation called Export IR, which can be loaded and executed in a clean inference environment. One of the advantages of exported models is that, contrary to TorchScript models, they can be optimized for the inference environment using torch.compile. On the other hand, optimizations that require source-code changes (e.g., configuring the attn_implementation) cannot be applied.
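
    As a minimal illustration of the basic API flow (a sketch on a trivial, hypothetical module, separate from our toy model), the steps are export, save, load, and call:

    import torch

    class TinyModel(torch.nn.Module):
        def forward(self, x):
            return torch.relu(x) + 1

    example_args = (torch.randn(2, 8),)
    exported = torch.export.export(TinyModel(), example_args)  # trace into Export IR
    torch.export.save(exported, "/tmp/tiny.pt2")               # serialize the program

    # in another process or environment:
    restored = torch.export.load("/tmp/tiny.pt2")
    output = restored.module()(torch.randn(2, 8))              # runs like a regular nn.Module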

    Overcoming Graph Breaks

    A graph break occurs when the export function encounters an "untraceable" portion of Python code (e.g., see here for unsupported operations). We encountered the concept of graph breaks in our previous post on graph compilation. However, contrary to model compilation, where PyTorch will simply fall back to eager mode, torch.export forbids the presence of graph breaks. If export fails on your model due to a graph break, you will have to rewrite your code to get around it.

    There are a number of resources at your disposal for overcoming graph breaks, including the Draft Export utility, which generates a detailed report of export issues, ExportDB, which maintains a list of supported and unsupported export cases, and a tutorial on overcoming common export issues.
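
    A data-dependent Python branch, similar to the one in the earlier trace/script sketch, is a typical culprit (the module below is again hypothetical, not taken from our model). PyTorch's torch.cond higher-order operator is one way to rewrite such code so that both branches are captured in the exported graph:

    import torch

    class DataDependentGate(torch.nn.Module):
        def forward(self, x):
            if x.sum() > 0:  # data-dependent control flow: export typically fails here
                return x + 1
            return x - 1

    # torch.export.export(DataDependentGate(), (torch.randn(4),))  # raises an export error

    class ExportableGate(torch.nn.Module):
        def forward(self, x):
            def add_one(x):
                return x + 1
            def sub_one(x):
                return x - 1
            # both branches live inside the exported graph
            return torch.cond(x.sum() > 0, add_one, sub_one, (x,))

    exported = torch.export.export(ExportableGate(), (torch.randn(4),))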

    Debugging an Exported Graph

    On some occasions, you may succeed in exporting a graph only to find that running it on the target device either fails with an error or returns incorrect output. A common cause of such errors is that some of the variables from the export environment are treated as constants and baked into the exported graph.

    Unfortunately, as of the time of this writing, the tools for debugging Export IR graphs are somewhat limited. Although the exported model is a torch.nn.Module with a forward function, you cannot use a debugger to step into it to find the source of the errors.

    We can, however, inspect the contents of the generated graph using GraphModule.print_readable. This prints out all of the graph operations along with comments that point to the source code from which they were generated. Often this, combined with the error output, is enough to find the source of the problem and tweak the source code accordingly. See below for an example.

    Exporting a HuggingFace Model

    To apply torch.export to our toy model, we first make sure to update the PyTorch library to the latest version (2.8.0 at the time of this writing). The export utility is under rapid development, and we want to make sure we get the most up-to-date feature support.

    In the code block below, we revise our capture and load utility functions to support torch.export. Note our use of torch.export.Dim to specify dynamic dimensions:

    def capture_model(model, path=EXPORT_PATH):
        encoder = model.encoder
        decoder = DecoderWrapper(model.decoder)
    
        # define dynamic dimensions
        batch = torch.export.Dim("batch")
        seq_len = torch.export.Dim("seq_len", min=2, max=MAX_SEQ_LEN)
    
        # export encoder
        # sample input
        example = torch.randn(4, 3, 224, 224)
        encoder_export = torch.export.export(
            encoder,
            (example,),
            dynamic_shapes=((batch,
                             torch.export.Dim.STATIC,
                             torch.export.Dim.STATIC,
                             torch.export.Dim.STATIC),
                            )
        )
        torch.export.save(encoder_export, os.path.join(path, "encoder.pt2"))
    
        # export decoder
        # get sample input for decoder
        encoder_hidden_states = encoder_export.module()(example)[0]
        decoder_input_ids = torch.ones((4, MAX_SEQ_LEN),
                                       dtype=torch.long) * START_ID
    
        decoder_export = torch.export.export(
            decoder,
            (decoder_input_ids, encoder_hidden_states),
            dynamic_shapes={
                'input_ids': (batch, seq_len),
                'encoder_hidden_states': (batch,
                                          torch.export.Dim.STATIC,
                                          torch.export.Dim.STATIC)
            }
        )
        torch.export.save(decoder_export, os.path.join(path, "decoder.pt2"))
    
    
    def load_model(path=EXPORT_PATH, mode=None):
        if mode == 'weights':
            model = get_model()
            weights_path = os.path.join(path, "weights.pth")
            state_dict = torch.load(weights_path, map_location="cpu")
            model.load_state_dict(state_dict)
            return model.encoder, DecoderWrapper(model.decoder)
        elif mode == 'export':
            encoder_path = os.path.join(path, "encoder.pt2")
            decoder_path = os.path.join(path, "decoder.pt2")
            encoder = torch.export.load(encoder_path).module()
            decoder = torch.export.load(decoder_path).module()
            return encoder, decoder
        else:
            model = get_model()
            return model.encoder, DecoderWrapper(model.decoder)

    Contrary to TorchScript, torch.export had no issues capturing our decoder and encoder models. In particular, no graph breaks were encountered during tracing.

    Deploying an Exported Model

    To complete our demonstration, we test our exported model in a clean inference environment. For this experiment, we use an Amazon EC2 g5.xlarge instance (containing an NVIDIA A10G GPU and 4 vCPUs) running a PyTorch (2.7) Deep Learning AMI (DLAMI). We update to the latest PyTorch version but deliberately do not install the transformers package.
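
    A minimal sketch of the smoke test we run on the target instance, assuming the exported .pt2 files have been copied into EXPORT_PATH and the utilities defined above are available:

    # on the inference instance: only PyTorch is required, no transformers
    test_inference(model_path=EXPORT_PATH, mode='export', compile=False)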

    Unfortunately, our excitement at the successful export of our model was premature, as running the exported decoder on the GPU results in a runtime error, a portion of which we have pasted below:

    File "<eval_with_key>.24", line 306, in forward
    
    File "/opt/pytorch/lib/python3.12/site-packages/torch/_ops.py", line 829, in __call__
    
      return self._op(*args, **kwargs)
    
             ^^^^^^^^^^^^^^^^^^^^^^^^^
    
    RuntimeError: Expected all tensors to be on the same device, but got index is on cpu, different from other tensors on cuda:0 (when checking argument in method wrapper_CUDA__index_select)

    The error indicates that during execution of the exported graph, some of the tensors reside on the CPU when they are expected to be on the GPU. Since we explicitly copy the input tensors onto the GPU, we can conclude that the error refers to tensors that the graph creates internally. Because the graph was exported on a CPU device, the use of device="cpu" was baked into the graph creation, resulting in a runtime error when running on a GPU.

    Although the error message points to the faulty line of code (File "<eval_with_key>.24", line 306, in forward), there is no actual file that we can add a breakpoint to and debug. We can, however, inspect the contents of the decoder graph and search for places where the use of the CPU device has been inadvertently baked in:

    decoder_export.module().print_readable()

    Examining the output (searching for "cpu" references) and cross-referencing the source code using the embedded comments, we discover four locations where the transformers library (in the modeling_gpt2.py file) creates tensors on the CPU:

    1. If not specified, the GPT2 model creates a cache_position tensor on the baked-in CPU device using torch.arange. This can be fixed by passing in a user-defined value for cache_position.
    2. On line 861, a torch.Tensor.to operation is performed to ensure proper placement of the position_embeds tensor. While this may be required in a case of model parallelism, we do not need it.
    3. If not specified, the model creates a causal mask using tensors that it creates on the baked-in CPU device. We could bypass this by passing in a user-defined causal mask, but we are perfectly happy leaving it None and relying on the is_causal flag of the sdpa attention function.
    4. If not specified, the model creates an encoder_attention_mask tensor on the baked-in CPU device. Once again, we could specify a value for this tensor, but since the mask would be all True, setting it to None achieves the same purpose.

    The following patch summarizes the changes we made to the modeling_gpt2.py file. Importantly, these changes are specific to our toy model and will not generalize to all use cases. This kind of monkey-patching is ill-advised and requires extensive testing before being used in a production setting.

    @@ -861 +861 @@
    -        hidden_states = inputs_embeds + position_embeds.to(inputs_embeds.device)
    +        hidden_states = inputs_embeds + position_embeds
    @@ -867,3 +867 @@
    -        causal_mask = self._update_causal_mask(
    -            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
    -        )
    +        causal_mask = None
    @@ -877 +875 @@
    -            if encoder_attention_mask is None:
    +            if not _use_sdpa and encoder_attention_mask is None:
    @@ -880,3 +878 @@
    -                encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
    -                    mask=encoder_attention_mask, dtype=inputs_embeds.dtype, tgt_len=input_shape[-1]
    -                )
    +                pass

    In the following code block, we extend the decoder definitions and export implementation with an explicit value for cache_position.

    def generate_token(decoder, encoder_hidden_states, sequence):
        outputs = decoder(
            sequence,
            encoder_hidden_states,
            torch.arange(sequence.shape[1], device=sequence.device)
        )
        logits = outputs[0][:, -1, :]
        return torch.argmax(logits, dim=-1, keepdim=True)
    
    class DecoderWrapper(torch.nn.Module):
        def __init__(self, decoder_model):
            super().__init__()
            self.decoder = decoder_model
    
        def forward(self, input_ids, encoder_hidden_states, cache_position):
            return self.decoder(
                input_ids=input_ids,
                cache_position=cache_position,
                encoder_hidden_states=encoder_hidden_states,
                use_cache=False,
                output_attentions=False,
                output_hidden_states=False,
                return_dict=False
            )
    
    def capture_model(model, path=EXPORT_PATH):
        encoder = model.encoder
        decoder = DecoderWrapper(model.decoder)
    
        # define dynamic dimensions
        batch = torch.export.Dim("batch")
        seq_len = torch.export.Dim("seq_len", min=2, max=MAX_SEQ_LEN)
    
        # export encoder
        # sample tensor
        example = torch.randn(4, 3, 224, 224)
        encoder_export = torch.export.export(
            encoder,
            (example,),
            dynamic_shapes=((batch,
                             torch.export.Dim.STATIC,
                             torch.export.Dim.STATIC,
                             torch.export.Dim.STATIC),
                            )
        )
        torch.export.save(encoder_export, os.path.join(path, "encoder.pt2"))
    
        # export decoder
        # get sample input for decoder
        encoder_hidden_states = encoder_export.module()(example)[0]
        decoder_input_ids = torch.ones((4, MAX_SEQ_LEN),
                                       dtype=torch.long) * START_ID
    
        decoder_args = (
            decoder_input_ids,
            encoder_hidden_states,
            torch.arange(MAX_SEQ_LEN)
        )
    
        dynamic_shapes = {
            'input_ids': (batch, seq_len),
            'encoder_hidden_states': (batch,
                                      torch.export.Dim.STATIC,
                                      torch.export.Dim.STATIC),
            'cache_position': (seq_len,),
        }
    
        decoder_export = torch.export.export(
            decoder,
            decoder_args,
            dynamic_shapes=dynamic_shapes
        )
        torch.export.save(decoder_export, os.path.join(path, "decoder.pt2"))

    Following these changes, the exported decoder succeeds in generating sequences on the GPU device. We hope (and expect) that as the torch.export feature evolves, these kinds of issues will be handled automatically by the internal tracing mechanism.

    We further test the potential for machine-specific optimizations by applying graph compilation. For details on our choice of compilation parameters, see our previous post.

    encoder = torch.compile(encoder, mode="reduce-overhead")
    decoder = torch.compile(decoder, dynamic=True)

    The table below captures the execution time of our model on a batch of random images, with and without torch.compile. (Note that running the original model requires installation of the transformers library, version 4.54.0.)

    Runtime Results (by Author)

    We can see that exporting the model to a graph representation results in a speed-up of 10.7%. Model compilation, however, had the opposite effect, significantly increasing the execution time. It is likely that this could be fixed with appropriate tuning.
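
    For example, one might experiment with different compilation settings for the dynamically shaped decoder; the knobs below are untested suggestions, not settings we benchmarked:

    # untested tuning ideas, not benchmarked in this post
    torch._dynamo.config.cache_size_limit = 64  # tolerate more recompilations
    encoder = torch.compile(encoder, mode="max-autotune")
    decoder = torch.compile(decoder, dynamic=True, mode="reduce-overhead")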

    Summary

    In this post we explored the new torch.export utility and demonstrated its use in capturing and deploying a toy HuggingFace model. We found that it has a number of powerful and compelling properties, including:

    • Support for complex models: torch.export succeeded in capturing models that failed with TorchScript.
    • Portability: Exported models can be loaded and executed as standalone programs without special package dependencies.
    • Machine-specific optimizations: Exported models are compatible with graph compilation, enabling the application of machine-specific optimizations.

    We also encountered some of torch.export's limitations:

    • Unintended consequences of graph creation: If we are not careful about how we design our model, values from the export environment can be inadvertently baked into the resulting graph, breaking its compatibility with the inference environment.
    • Limited debugging tools: As of this writing, the tools for debugging the execution of exported graphs are limited.

    Although it still requires some refinement, torch.export is already a huge improvement over previous capturing solutions such as TorchScript. We look forward to seeing it continue to evolve and improve.

    For more details on capturing AI/ML models using PyTorch export and to keep track of API changes, please see the torch.export documentation. If you are running inference on an edge device, also see the ExecuTorch solution for model deployment, which is based on torch.export.


