
    On the Challenge of Converting TensorFlow Models to PyTorch

    By ProfitlyAI, December 5, 2025


    In the interest of managing reader expectations and preventing disappointment, we would like to begin by stating that this post does not provide a fully satisfactory solution to the problem described in the title. We will propose and assess two possible schemes for auto-conversion of TensorFlow models to PyTorch: the first based on the Open Neural Network Exchange (ONNX) format and libraries, and the second using the Keras3 API. However, as we will see, each comes with its own set of challenges and limitations. To the best of the authors' knowledge, at the time of this writing, there are no publicly available foolproof solutions to this problem.

    Many thanks to Rom Maltser for his contributions to this post.

    The Decline of TensorFlow

    Over the years, the field of computer science has known its fair share of "religious wars": heated, sometimes hostile, debates among programmers and engineers over the "best" tools, languages, and methodologies. Up until a few years ago, the religious war between PyTorch and TensorFlow, two prominent open-source deep learning frameworks, loomed large. Proponents of TensorFlow would highlight its fast graph-execution mode, while those in the PyTorch camp would emphasize its "Pythonic" nature and ease of use.

    These days, however, the amount of activity in PyTorch far overshadows that of TensorFlow. This is evidenced by the number of big-tech companies that have embraced PyTorch over TensorFlow, by the number of models per framework in HuggingFace's models repository, and by the amount of innovation and optimization in each framework. Simply put, TensorFlow is a shell of its former self. The war is over, with PyTorch the definitive winner. For a brief history of the PyTorch-TensorFlow wars and the reasons for TensorFlow's downfall, see Pan Xinghan's post: TensorFlow Is Dead. PyTorch Won.

    Problem: What do we do with all of our legacy TensorFlow models?!!

    In light of this new reality, many organizations that once used TensorFlow have moved all of their new AI/ML model development to PyTorch. But they are faced with a difficult challenge when it comes to their legacy code: What should they do with all of the models that have already been built and deployed in TensorFlow?

    Option 1: Do Nothing.

    You might be wondering why this is even a problem: the TensorFlow models work, so let's not touch them. While this is a valid approach, there are a number of disadvantages that should be taken into account:

    1. Reduced Maintenance: As TensorFlow continues to decline, so will its maintenance. Inevitably, things will start to break. For example, there may be compatibility issues with newer Python packages or system libraries.
    2. Limited Ecosystem: AI/ML solutions typically involve multiple supporting software libraries and services that interface with our framework of choice, be it PyTorch or TensorFlow. Over time, we can expect many of these to discontinue their support for TensorFlow. Case in point: HuggingFace recently announced the deprecation of its support for TensorFlow.
    3. Limited Community: The AI/ML industry owes its fast pace of development, in large part, to its community. The number of open-source projects, the number of online tutorials, and the amount of activity in dedicated support channels in the AI/ML space are unparalleled. As TensorFlow declines, so will its community, and you may find it increasingly difficult to get the help you need. Needless to say, the PyTorch community is flourishing.
    4. Opportunity Cost: The PyTorch ecosystem is thriving, with constant innovations and optimizations. Recent years have seen the development of flash-attention kernels, support for the eight-bit floating-point data type, graph compilation, and many other advancements that have demonstrated significant boosts to runtime performance and significant reductions in AI/ML costs. During the same period, the feature offering in TensorFlow has remained mostly static. Sticking with TensorFlow means forgoing many opportunities for AI/ML cost optimization.

    Option 2: Manually Convert TensorFlow Models to PyTorch

    The second option is to rewrite legacy TensorFlow models in PyTorch. This is probably the best option in terms of its outcome, but for companies that have built up technical debt over many years, converting even a single model could be a daunting task. Given the effort required, you may choose to do this only for models that are still under active development (e.g., in the model training phase). Doing this for all of the models that are already deployed may prove prohibitive.

    Option 3: Automate TensorFlow to PyTorch Conversion

    The third option, and the approach we explore in this post, is to automate the conversion of legacy TensorFlow models to PyTorch. In this way, we hope to gain the benefits of model execution in PyTorch, but without the tremendous effort of manually converting each model.

    To facilitate our discussion, we will define a toy TensorFlow model and assess two proposals for converting it to PyTorch. As our runtime environment, we will use an Amazon EC2 g6e.xlarge with an NVIDIA L40S GPU, an AWS Deep Learning Ubuntu (22.04) AMI, and a Python environment that includes the TensorFlow (2.20), PyTorch (2.9), torchvision (0.24.0), and transformers (4.55.4) libraries. Please note that the code blocks we will share are intended for demonstrative purposes. Please do not interpret our use of any code, library, or platform as an endorsement of its use.

    Model Conversion: Why is it Hard?

    An AI model definition consists of two components: a model architecture and its trained weights. A model conversion solution must address both. Conversion of the model weights is fairly straightforward; the weights are typically stored in a format that can be easily parsed into individual tensor arrays and reapplied in the framework of choice. In contrast, conversion of the model architecture presents a much greater challenge.
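    "Reapplied" usually still means some reshuffling. As a minimal sketch, assuming Keras's default layouts (Dense kernels stored as (in, out), Conv2D kernels as (kh, kw, in, out)) and PyTorch's documented ones (nn.Linear weight as (out, in), nn.Conv2d weight as (out, in, kh, kw)), the snippet below transposes the arrays and checks that the math is preserved. It emulates PyTorch's formulas in NumPy rather than calling torch, and the helper names are hypothetical. Since both frameworks implement convolution as cross-correlation, no kernel flip is needed.

```python
import numpy as np

rng = np.random.default_rng(0)

# Keras layouts: Dense kernel is (in, out); Conv2D kernel is (kh, kw, in, out).
keras_dense_kernel = rng.standard_normal((8, 4)).astype(np.float32)
keras_dense_bias = rng.standard_normal(4).astype(np.float32)
keras_conv_kernel = rng.standard_normal((3, 3, 2, 5)).astype(np.float32)

def dense_to_torch(kernel):
    """(in, out) -> (out, in), the torch.nn.Linear weight layout."""
    return np.ascontiguousarray(kernel.T)

def conv2d_to_torch(kernel):
    """(kh, kw, in, out) -> (out, in, kh, kw), the torch.nn.Conv2d weight layout."""
    return np.ascontiguousarray(kernel.transpose(3, 2, 0, 1))

linear_w = dense_to_torch(keras_dense_kernel)
conv_w = conv2d_to_torch(keras_conv_kernel)

# Dense: Keras computes x @ kernel + b; nn.Linear computes x @ W.T + b.
x = rng.standard_normal((2, 8)).astype(np.float32)
dense_match = np.allclose(x @ keras_dense_kernel + keras_dense_bias,
                          x @ linear_w.T + keras_dense_bias, atol=1e-5)

# Conv2D on one 3x3 patch with "valid" padding gives a 1x1 output per filter.
patch_hwc = rng.standard_normal((3, 3, 2)).astype(np.float32)
keras_conv = np.einsum("hwi,hwio->o", patch_hwc, keras_conv_kernel)
torch_conv = np.einsum("ihw,oihw->o", patch_hwc.transpose(2, 0, 1), conv_w)
conv_match = np.allclose(keras_conv, torch_conv, atol=1e-5)

print(linear_w.shape, conv_w.shape, dense_match, conv_match)
```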

    One approach could be to create a mapping between the building blocks of the model in each of the frameworks. However, a number of factors make this approach, for all intents and purposes, practically intractable:

    • API Overlap and Proliferation: When you take into account the sheer number of (sometimes overlapping) TensorFlow APIs for building model components, and then add the vast number of API controls and arguments for each layer, you can see how creating a comprehensive, one-to-one mapping can quickly get ugly.
    • Differing Implementation Approaches: At the implementation level, TensorFlow and PyTorch take fundamentally different approaches. Although these are usually hidden behind the top-level APIs, some assumptions require individual attention. For example, while TensorFlow defaults to the "channels-last" (NHWC) format, PyTorch prefers "channels-first" (NCHW). This difference in how tensors are indexed and stored complicates the conversion of model operations, as every layer must be checked and, where needed, altered for correct dimension ordering.
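    To make the dimension-ordering point concrete, here is a small NumPy sketch (the array sizes and the `remap_axis` helper are our own, hypothetical choices). Any axis argument written for NHWC, such as a per-channel reduction or a normalization over `axis=-1`, has to be translated after the tensor is transposed to NCHW; reusing the same axis numbers silently computes something else.

```python
import numpy as np

# Hypothetical activation batch: 2 images, 4x4 pixels, 3 channels (NHWC).
x_nhwc = np.arange(2 * 4 * 4 * 3, dtype=np.float32).reshape(2, 4, 4, 3)
# NCHW axis i holds NHWC axis NHWC_TO_NCHW[i].
NHWC_TO_NCHW = (0, 3, 1, 2)
x_nchw = x_nhwc.transpose(NHWC_TO_NCHW)

def remap_axis(axis, ndim=4):
    """Translate an axis index written for NHWC into its NCHW position."""
    return NHWC_TO_NCHW.index(axis % ndim)

# A per-channel mean reduces over N, H and W.
nhwc_axes = (0, 1, 2)
tf_style = x_nhwc.mean(axis=nhwc_axes)      # shape (3,): one value per channel
naive = x_nchw.mean(axis=nhwc_axes)         # same numbers on NCHW reduce N, C, H
nchw_axes = tuple(remap_axis(a) for a in nhwc_axes)
correct = x_nchw.mean(axis=nchw_axes)       # reduces axes (0, 2, 3)

print(nchw_axes, tf_style.shape, naive.shape, np.allclose(tf_style, correct))
```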

    Rather than attempt conversion at the API level, an alternative approach could be to capture and convert an internal TensorFlow graph representation. However, as anyone who has ever looked under the hood of TensorFlow will tell you, this too can get pretty nasty very quickly. TensorFlow's internal graph representation is highly complex, often including a multitude of low-level operations, control flow, and auxiliary nodes that do not have a direct equivalent in PyTorch (especially if you are dealing with older versions of TensorFlow). Just understanding it seems beyond normal human ability, let alone converting it to PyTorch.

    Note that the same challenges would make it difficult for a generative AI model to perform the conversion in a fully reliable manner.

    Proposed Conversion Schemes

    In light of these difficulties, we abandon our attempt at implementing our own model converter and instead look at what tools the AI/ML community has to offer. More specifically, we consider two different strategies for overcoming the challenges we described:

    1. Conversion Via a Unified Graph Representation: This solution assumes a common standard for representing an AI/ML model definition, along with utilities for converting models to and from this standard. The solution we will explore uses the popular ONNX format.
    2. Conversion Based on a Standardized High-level API: In this solution we simplify the conversion task by limiting our model to a defined set of high-level abstract APIs with supported implementations in each of the AI/ML frameworks of interest. For this approach, we will use the Keras3 library.

    In the next sections we will assess these strategies on a toy TensorFlow model.

    A Toy TensorFlow Model

    In the code block below we initialize a TensorFlow Vision Transformer (ViT) model, TFViTForImageClassification, from HuggingFace's popular transformers library (version 4.55.4). Note that, in line with HuggingFace's decision to deprecate support for TensorFlow, this class was removed from recent releases of the library. The HuggingFace TensorFlow model relies on Keras 2, which we dutifully install via the tf-keras (2.20.1) package. We set the ViTConfig.hidden_act field to "gelu_new" for ONNX compatibility:

    import tensorflow as tf
    gpu = tf.config.list_physical_devices('GPU')[0]
    tf.config.experimental.set_memory_growth(gpu, True)
    
    from transformers import ViTConfig, TFViTForImageClassification
    vit_config = ViTConfig(hidden_act="gelu_new", return_dict=False)
    tf_model = TFViTForImageClassification(vit_config)
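    For context on the `hidden_act` setting: in HuggingFace's activation naming, "gelu_new" refers to the tanh approximation of GELU rather than the exact, erf-based form. The two differ only slightly, but a reimplementation of the model that aims to reproduce its outputs closely should use the same variant; Keras 3's `keras.activations.gelu`, for example, exposes an `approximate` argument for this choice. The NumPy sketch below, with function names of our own choosing, measures the gap between the two forms:

```python
import math

import numpy as np

def gelu_exact(x):
    # GELU(x) = x * Phi(x), with Phi the standard normal CDF written via erf.
    erf = np.vectorize(math.erf)
    return 0.5 * x * (1.0 + erf(x / math.sqrt(2.0)))

def gelu_tanh(x):
    # Tanh approximation, the form HuggingFace calls "gelu_new".
    return 0.5 * x * (1.0 + np.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x**3)))

x = np.linspace(-6.0, 6.0, 10001)
max_gap = float(np.max(np.abs(gelu_exact(x) - gelu_tanh(x))))
print(f"max |exact - tanh approx| on [-6, 6]: {max_gap:.2e}")
```

    The gap is small but non-zero, which is enough to show up when comparing converted outputs at tight tolerances.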

    Model Conversion Using ONNX

    The first method we assess relies on Open Neural Network Exchange (ONNX), a community project that aims to define an open format for building AI/ML models, in order to increase interoperability between AI/ML frameworks and reduce dependence on any single one. The ONNX API offering includes utilities for converting models from common frameworks, including TensorFlow, to the ONNX format. There are also several public libraries for converting ONNX models to PyTorch. In this post we use the onnx2torch utility. Thus, model conversion from TensorFlow to PyTorch can be achieved by applying TensorFlow-to-ONNX conversion followed by ONNX-to-PyTorch conversion.

    To assess this solution we install the onnx (1.19.1), tf2onnx (1.16.1), and onnx2torch (1.5.15) libraries. We apply the --no-deps flag to prevent an undesired downgrade of the protobuf library:

    pip install --no-deps onnx tf2onnx onnx2torch

    The conversion scheme appears in the code block below:

    import tensorflow as tf
    import torch
    import tf2onnx, onnx2torch
    
    BATCH_SIZE = 32
    DEVICE = "cuda"
    
    spec = (tf.TensorSpec((BATCH_SIZE, 3, 224, 224), tf.float32, name="input"),)
    onnx_model, _ = tf2onnx.convert.from_keras(tf_model, input_signature=spec)
    converted_model = onnx2torch.convert(onnx_model)

    To verify that the resultant model is indeed a PyTorch module, we run the following assertion:

    assert isinstance(converted_model, torch.nn.Module)

    Let us now assess the quality and makeup of the resultant PyTorch model.

    Numerical Precision

    To verify the validity of the converted model, we execute both the TensorFlow model and the converted model on the same input and compare the results:

    import numpy as np
    
    batch_input = np.random.randn(BATCH_SIZE, 3, 224, 224).astype(np.float32)
    
    # execute tf model
    tf_input = tf.convert_to_tensor(batch_input)
    tf_output = tf_model(tf_input, training=False)
    tf_output = tf_output[0].numpy()
    
    # execute converted model
    converted_model = converted_model.to(DEVICE)
    converted_model = converted_model.eval()
    torch_input = torch.from_numpy(batch_input).to(DEVICE)
    torch_output = converted_model(torch_input)
    torch_output = torch_output.detach().cpu().numpy()
    
    # compare results
    print("Max diff:", np.max(np.abs(tf_output - torch_output)))
    
    # sample output:
    # Max diff: 9.3877316e-07

    The outputs are indeed close enough to validate the converted model.
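    A single max-abs-diff number can hide the scale of the outputs. A slightly fuller check, sketched here in NumPy, also reports relative error and whether the top-1 predictions agree. The `compare_outputs` helper and its tolerances are our own assumptions rather than values from this post, and the synthetic logits merely stand in for `tf_output` and `torch_output`.

```python
import numpy as np

def compare_outputs(ref, test, rtol=1e-4, atol=1e-5):
    """Summarize how closely two logit arrays of shape (batch, classes) agree."""
    ref = np.asarray(ref, dtype=np.float64)
    test = np.asarray(test, dtype=np.float64)
    abs_diff = np.abs(ref - test)
    # Guard the denominator so near-zero logits do not blow up the ratio.
    rel_diff = abs_diff / np.maximum(np.abs(ref), atol)
    return {
        "max_abs": float(abs_diff.max()),
        "max_rel": float(rel_diff.max()),
        "allclose": bool(np.allclose(test, ref, rtol=rtol, atol=atol)),
        "argmax_agree": float(np.mean(ref.argmax(-1) == test.argmax(-1))),
    }

# Synthetic logits standing in for the TensorFlow and converted-model outputs.
rng = np.random.default_rng(0)
ref_logits = rng.standard_normal((32, 2)).astype(np.float32)
test_logits = ref_logits + 1e-7 * rng.standard_normal((32, 2)).astype(np.float32)
report = compare_outputs(ref_logits, test_logits)
print(report)
```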

    Model Structure

    To get a feel for the structure of the converted model, we calculate the number of trainable parameters and compare it to that of the original model:

    num_tf_params = sum([np.prod(v.shape) for v in tf_model.trainable_weights])
    num_pyt_params = sum([p.numel()
                          for p in converted_model.parameters()
                          if p.requires_grad])
    print(f"TensorFlow Trainable Parameters: {int(num_tf_params):,}")
    print(f"PyTorch Trainable Parameters: {num_pyt_params:,}")

    The difference in the number of trainable parameters is profound: just 589,824 in the converted model compared to over 85 million in the original model. Traversing the layers of the converted model leads to the same conclusion: the ONNX-based conversion has completely altered the model structure, rendering it essentially unrecognizable. This finding has a number of ramifications, including:

    1. Training/fine-tuning the converted model: Although we have shown that the converted model can be used for inference, the change in structure, particularly the fact that some of the model parameters have been baked in, means that we cannot use the converted model for training or fine-tuning.
    2. Applying pinpoint PyTorch optimizations to the model: The converted model consists of a very large number of layers, each representing a relatively low-level operation. This greatly limits our ability to replace inefficient operations with optimized PyTorch equivalents, such as torch.nn.functional.scaled_dot_product_attention (SDPA).
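    For reference, the operation SDPA fuses is softmax(QK^T / sqrt(d_k))V. A plain NumPy version, sketched below as a hypothetical ground truth for validating an operator swap (no masking or dropout), spells out the matmul, scale, softmax and matmul sequence that a fused kernel performs in one call. The shapes follow the ViT configuration used in this post: 12 heads of size 64 and 197 tokens (196 patches plus the class token).

```python
import numpy as np

def sdpa_reference(q, k, v):
    """softmax(q @ k^T / sqrt(d_k)) @ v over the last two axes (no mask, no dropout)."""
    d_k = q.shape[-1]
    scores = q @ np.swapaxes(k, -1, -2) / np.sqrt(d_k)
    scores = scores - scores.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ v

# Layout: (batch, heads, tokens, head_dim).
rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((2, 12, 197, 64)) for _ in range(3))
out = sdpa_reference(q, k, v)
print(out.shape)  # (2, 12, 197, 64)
```

    A quick sanity check: with all-zero queries every token attends uniformly, so the output collapses to the mean of `v` over the token axis.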

    Model Optimization

    We have already seen that our ability to access and modify model operations is limited, but there are a number of optimizations we can apply that do not require such access. In the code block below, we apply PyTorch compilation and automatic mixed precision (AMP) and compare the resultant throughput to that of the TensorFlow model. For further context, we also test the runtime of the PyTorch version of the ViTForImageClassification model:

    # Set tf mixed precision policy to bfloat16
    tf.keras.mixed_precision.set_global_policy('mixed_bfloat16')
    
    # Set torch matmul precision to high
    torch.set_float32_matmul_precision('high')
    
    @tf.function
    def tf_infer_fn(batch):
        return tf_model(batch, training=False)
    
    def get_torch_infer_fn(model):
        def infer_fn(batch):
            with torch.inference_mode(), torch.amp.autocast(
                    DEVICE,
                    dtype=torch.bfloat16,
                    enabled=DEVICE=='cuda'
            ):
                output = model(batch)
            return output
        return infer_fn
    
    def benchmark(infer_fn, batch):
        # warm-up
        for _ in range(20):
            _ = infer_fn(batch)
        # CUDA events report elapsed time in milliseconds
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        torch.cuda.synchronize()
        start.record()
    
        iters = 100
    
        for _ in range(iters):
            _ = infer_fn(batch)
        end.record()
        torch.cuda.synchronize()
        # average milliseconds per step
        return start.elapsed_time(end) / iters
    
    # assess throughput of TF model
    avg_time = benchmark(tf_infer_fn, tf_input)
    print(f"\nTensorFlow average step time: {(avg_time):.4f} ms")
    
    # assess throughput of converted model
    torch_infer_fn = get_torch_infer_fn(converted_model)
    avg_time = benchmark(torch_infer_fn, torch_input)
    print(f"\nConverted model average step time: {(avg_time):.4f} ms")
    
    # assess throughput of compiled model
    torch_infer_fn = get_torch_infer_fn(torch.compile(converted_model))
    avg_time = benchmark(torch_infer_fn, torch_input)
    print(f"\nCompiled model average step time: {(avg_time):.4f} ms")
    
    # assess throughput of torch ViT
    from transformers import ViTForImageClassification
    torch_model = ViTForImageClassification(vit_config).to(DEVICE).eval()
    torch_infer_fn = get_torch_infer_fn(torch_model)
    avg_time = benchmark(torch_infer_fn, torch_input)
    print(f"\nPyTorch ViT model average step time: {(avg_time):.4f} ms")
    
    # assess throughput of compiled torch ViT
    torch_infer_fn = get_torch_infer_fn(torch.compile(torch_model))
    avg_time = benchmark(torch_infer_fn, torch_input)
    print(f"\nCompiled ViT model average step time: {(avg_time):.4f} ms")

    Note that PyTorch compilation initially fails on the converted model due to the use of the torch.Size operator in the OnnxReshape layer. While this is easily fixable (e.g., tuple([int(i) for i in shape])), it points to a deeper obstacle to optimizing the model: the reshape layer, which appears dozens of times in the model, treats shapes as PyTorch tensors residing on the GPU. This means that each call requires detaching the shape tensor from the graph and copying it to the CPU. The conclusion is that although the converted model is functionally accurate, its resultant definition is not optimized for runtime performance. This can be seen from the step time results of the different model configurations:

    ONNX-Based Conversion Step Time Results (by Author)

    The converted model is slower than the original TensorFlow model and significantly slower than the PyTorch version of the ViT model.
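    Returning to the reshape issue described above, the sketch below contrasts the two patterns in plain PyTorch. Both module names are hypothetical; neither is onnx2torch's actual OnnxReshape implementation. The first receives the target shape as a tensor and converts it element by element on every call, which for a CUDA tensor means a device-to-host copy each time. The second resolves the shape to Python ints once. That only works when the shape really is static (e.g., a fixed batch size), which is an assumption you would need to verify for your own model.

```python
import torch

class TensorShapeReshape(torch.nn.Module):
    """Mimics the problematic pattern: the target shape arrives as a tensor."""
    def forward(self, x, shape):
        # For a CUDA shape tensor, each int() forces a device-to-host copy
        # (and synchronization) on every forward call.
        return torch.reshape(x, tuple(int(i) for i in shape))

class StaticReshape(torch.nn.Module):
    """Resolves the target shape to Python ints once, at construction time."""
    def __init__(self, shape):
        super().__init__()
        self.shape = tuple(int(i) for i in shape)

    def forward(self, x):
        return torch.reshape(x, self.shape)

x = torch.randn(32, 197, 768)
shape_tensor = torch.tensor([32, 197, 12, 64])
dynamic_out = TensorShapeReshape()(x, shape_tensor)
static_out = StaticReshape(shape_tensor)(x)
print(static_out.shape, torch.equal(dynamic_out, static_out))
```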

    Limitations

    Although the ONNX-based conversion scheme works (in the case of our toy model), it has a number of significant limitations:

    1. During the conversion many parameters were baked into the model, limiting its use to inference workloads only.
    2. The ONNX conversion breaks the computation graph into low-level operators in a manner that makes it difficult to apply, and/or reap the benefit of, some PyTorch optimizations.
    3. The reliance on ONNX means that our conversion scheme will only work on ONNX-friendly models. It will not work on models that cannot be mapped to the standard ONNX operator set (e.g., models with dynamic control flow).
    4. The conversion scheme relies on the health and maintenance of a third-party library that is not part of the official ONNX offering.

    Although the scheme works, at least for inference workloads, you may find its limitations too restrictive for use on your own TensorFlow models. One possible option is to abandon the ONNX-to-PyTorch conversion and perform inference using the ONNX Runtime library.

    Model Conversion Via Keras3

    Keras3 is a high-level deep learning API focused on maximizing the readability, maintainability, and ease of use of AI/ML applications. In a previous post, we evaluated Keras3 and highlighted its support for multiple backends. In this post we revisit its multi-framework support and assess whether it can be applied to model conversion. The scheme we propose is to 1) migrate the existing TensorFlow model to Keras3 and then 2) run the model with the Keras3 PyTorch backend.

    Upgrading TensorFlow to Keras3

    In contrast to the ONNX-based conversion scheme, our current solution may require some code changes to migrate the TensorFlow model to Keras3. While the documentation makes it sound simple, in practice the difficulty of the migration will depend greatly on the details of the model implementation. In the case of our toy model, HuggingFace explicitly enforces the use of the legacy tf-keras, preventing the use of Keras3. To implement our scheme, we need to 1) redefine the model without this restriction, and 2) replace native TensorFlow operators with Keras3 equivalents. The code block below contains a stripped-down version of the model, along with the required adjustments. To get a full grasp of the changes that were required, perform a side-by-side code comparison with the original model definition.

    import math
    import keras
    
    HIDDEN_SIZE = 768
    IMG_SIZE = 224
    PATCH_SIZE = 16
    ATTN_HEADS = 12
    NUM_LAYERS = 12
    INTER_SZ = 4*HIDDEN_SIZE
    N_LABELS = 2
    
    
    class TFViTEmbeddings(keras.layers.Layer):
        def __init__(self, **kwargs):
            super().__init__(**kwargs)
            self.patch_embeddings = TFViTPatchEmbeddings()
            num_patches = self.patch_embeddings.num_patches
            self.cls_token = self.add_weight((1, 1, HIDDEN_SIZE))
            self.position_embeddings = self.add_weight((1, num_patches+1,
                                                        HIDDEN_SIZE))
    
        def call(self, pixel_values, training=False):
            bs, num_channels, height, width = pixel_values.shape
            embeddings = self.patch_embeddings(pixel_values, training=training)
            # prepend the [CLS] token to every sequence in the batch
            cls_tokens = keras.ops.repeat(self.cls_token, repeats=bs, axis=0)
            embeddings = keras.ops.concatenate((cls_tokens, embeddings), axis=1)
            embeddings = embeddings + self.position_embeddings
            return embeddings
    
    class TFViTPatchEmbeddings(keras.layers.Layer):
        def __init__(self, **kwargs):
            super().__init__(**kwargs)
            patch_size = (PATCH_SIZE, PATCH_SIZE)
            image_size = (IMG_SIZE, IMG_SIZE)
            num_patches = ((image_size[1]//patch_size[1]) *
                           (image_size[0]//patch_size[0]))
            self.patch_size = patch_size
            self.num_patches = num_patches
            self.projection = keras.layers.Conv2D(
                filters=HIDDEN_SIZE,
                kernel_size=patch_size,
                strides=patch_size,
                padding="valid",
                data_format="channels_last"
            )
    
        def call(self, pixel_values, training=False):
            bs, num_channels, height, width = pixel_values.shape
            # NCHW -> NHWC to match the channels_last convolution
            pixel_values = keras.ops.transpose(pixel_values, (0, 2, 3, 1))
            projection = self.projection(pixel_values)
            num_patches = ((width // self.patch_size[1]) *
                           (height // self.patch_size[0]))
            embeddings = keras.ops.reshape(projection, (bs, num_patches, -1))
            return embeddings
    
    class TFViTSelfAttention(keras.layers.Layer):
        def __init__(self, **kwargs):
            super().__init__(**kwargs)
            self.num_attention_heads = ATTN_HEADS
            self.attention_head_size = int(HIDDEN_SIZE / ATTN_HEADS)
            self.all_head_size = ATTN_HEADS * self.attention_head_size
            self.sqrt_att_head_size = math.sqrt(self.attention_head_size)
            self.query = keras.layers.Dense(self.all_head_size, name="query")
            self.key = keras.layers.Dense(self.all_head_size, name="key")
            self.value = keras.layers.Dense(self.all_head_size, name="value")
    
        def transpose_for_scores(self, tensor, batch_size: int):
            # (bs, seq, hidden) -> (bs, heads, seq, head_size)
            tensor = keras.ops.reshape(tensor, (batch_size, -1, ATTN_HEADS,
                                                self.attention_head_size))
            return keras.ops.transpose(tensor, [0, 2, 1, 3])
    
        def call(self, hidden_states, training=False):
            bs = hidden_states.shape[0]
            mixed_query_layer = self.query(inputs=hidden_states)
            mixed_key_layer = self.key(inputs=hidden_states)
            mixed_value_layer = self.value(inputs=hidden_states)
            query_layer = self.transpose_for_scores(mixed_query_layer, bs)
            key_layer = self.transpose_for_scores(mixed_key_layer, bs)
            value_layer = self.transpose_for_scores(mixed_value_layer, bs)
            # attention: softmax(Q K^T / sqrt(head_size)) V
            key_layer_T = keras.ops.transpose(key_layer, [0,1,3,2])
            attention_scores = keras.ops.matmul(query_layer, key_layer_T)
            dk = keras.ops.cast(self.sqrt_att_head_size,
                                dtype=attention_scores.dtype)
            attention_scores = keras.ops.divide(attention_scores, dk)
            attention_probs = keras.ops.softmax(attention_scores+1e-9, axis=-1)
            attention_output = keras.ops.matmul(attention_probs, value_layer)
            # (bs, heads, seq, head_size) -> (bs, seq, hidden)
            attention_output = keras.ops.transpose(attention_output,[0,2,1,3])
            attention_output = keras.ops.reshape(attention_output,
                                                 (bs, -1, self.all_head_size))
            return (attention_output,)
    
    class TFViTSelfOutput(keras.layers.Layer):
        def __init__(self, **kwargs):
            super().__init__(**kwargs)
            self.dense = keras.layers.Dense(HIDDEN_SIZE)
    
        def call(self, hidden_states, input_tensor, training=False):
            return self.dense(inputs=hidden_states)
    
    class TFViTAttention(keras.layers.Layer):
        def __init__(self, **kwargs):
            super().__init__(**kwargs)
            self.self_attention = TFViTSelfAttention()
            self.dense_output = TFViTSelfOutput()
    
        def call(self, input_tensor, training=False):
            self_outputs = self.self_attention(
                hidden_states=input_tensor, training=training
            )
            attention_output = self.dense_output(
                hidden_states=self_outputs[0],
                input_tensor=input_tensor,
                training=training
            )
            return (attention_output,)
    
    class TFViTIntermediate(keras.layers.Layer):
        def __init__(self, **kwargs):
            super().__init__(**kwargs)
            self.dense = keras.layers.Dense(INTER_SZ)
            self.intermediate_act_fn = keras.activations.gelu
    
        def call(self, hidden_states):
            hidden_states = self.dense(hidden_states)
            hidden_states = self.intermediate_act_fn(hidden_states)
            return hidden_states
    
    class TFViTOutput(keras.layers.Layer):
        def __init__(self, **kwargs):
            super().__init__(**kwargs)
            self.dense = keras.layers.Dense(HIDDEN_SIZE)
    
        def call(self, hidden_states, input_tensor, training: bool = False):
            hidden_states = self.dense(inputs=hidden_states)
            # residual connection
            hidden_states = hidden_states + input_tensor
            return hidden_states
    
    class TFViTLayer(keras.layers.Layer):
        def __init__(self, **kwargs):
            super().__init__(**kwargs)
            self.attention = TFViTAttention()
            self.intermediate = TFViTIntermediate()
            self.vit_output = TFViTOutput()
            self.layernorm_before = keras.layers.LayerNormalization(
                epsilon=1e-12
            )
            self.layernorm_after = keras.layers.LayerNormalization(
                epsilon=1e-12
            )
    
        def call(self, hidden_states, training=False):
            # pre-norm attention block with residual connection
            attention_outputs = self.attention(
                input_tensor=self.layernorm_before(inputs=hidden_states),
                training=training,
            )
            attention_output = attention_outputs[0]
            hidden_states = attention_output + hidden_states
            # pre-norm MLP block; TFViTOutput adds the residual
            layer_output = self.layernorm_after(hidden_states)
            intermediate_output = self.intermediate(layer_output)
            layer_output = self.vit_output(
                hidden_states=intermediate_output,
                input_tensor=hidden_states,
                training=training
            )
            outputs = (layer_output,)
            return outputs
    
    class TFViTEncoder(keras.layers.Layer):
        def __init__(self, **kwargs):
            super().__init__(**kwargs)
            self.layer = [TFViTLayer(name=f"layer_{i}")
                          for i in range(NUM_LAYERS)]
    
        def call(self, hidden_states, training=False):
            for i, layer_module in enumerate(self.layer):
                layer_outputs = layer_module(
                    hidden_states=hidden_states,
                    training=training,
                )
                hidden_states = layer_outputs[0]
            return tuple([hidden_states])
    
    class TFViTMainLayer(keras.layers.Layer):
        def __init__(self, **kwargs):
            super().__init__(**kwargs)
            self.embeddings = TFViTEmbeddings()
            self.encoder = TFViTEncoder()
            self.layernorm = keras.layers.LayerNormalization(epsilon=1e-12)
    
        def call(self, pixel_values, training=False):
            embedding_output = self.embeddings(
                pixel_values=pixel_values,
                training=training,
            )
            encoder_outputs = self.encoder(
                hidden_states=embedding_output,
                training=training,
            )
            sequence_output = encoder_outputs[0]
            sequence_output = self.layernorm(inputs=sequence_output)
            return (sequence_output,)
    
    class TFViTForImageClassification(keras.Model):
        def __init__(self, *inputs, **kwargs):
            super().__init__(*inputs, **kwargs)
            self.vit = TFViTMainLayer()
            self.classifier = keras.layers.Dense(N_LABELS)
    
        def call(self, pixel_values, training=False):
            outputs = self.vit(pixel_values, training=training)
            sequence_output = outputs[0]
            # classify from the [CLS] token representation
            logits = self.classifier(inputs=sequence_output[:, 0, :])
            return (logits,)
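    Much of the migration effort lies in hunting down native TensorFlow calls that still need replacing. As a rough aid, here is a hypothetical helper (not part of Keras or TensorFlow) that scans source text for `tf.*` calls and suggests a `keras.ops` equivalent where one appears in the model above. The mapping is partial and illustrative; anything it does not know about is flagged for manual review.

```python
import re

# Partial, illustrative mapping of TensorFlow calls to keras.ops equivalents.
TF_TO_KERAS_OPS = {
    "tf.transpose": "keras.ops.transpose",
    "tf.reshape": "keras.ops.reshape",
    "tf.concat": "keras.ops.concatenate",
    "tf.repeat": "keras.ops.repeat",
    "tf.matmul": "keras.ops.matmul",
    "tf.cast": "keras.ops.cast",
    "tf.math.divide": "keras.ops.divide",
    "tf.nn.softmax": "keras.ops.softmax",
}

# A dotted name starting with "tf." that is immediately called.
TF_CALL = re.compile(r"\btf(?:\.\w+)+(?=\s*\()")

def find_tf_calls(source):
    """Return (line_no, call, suggestion) for each tf.* call in the source text."""
    findings = []
    for line_no, line in enumerate(source.splitlines(), start=1):
        for match in TF_CALL.finditer(line):
            call = match.group(0)
            suggestion = TF_TO_KERAS_OPS.get(call, "needs manual review")
            findings.append((line_no, call, suggestion))
    return findings

legacy = """x = tf.transpose(x, (0, 2, 3, 1))
y = tf.nn.softmax(scores + 1e-9, axis=-1)
z = tf.linalg.band_part(m, -1, 0)"""
for finding in find_tf_calls(legacy):
    print(finding)
```

    A text scan like this cannot catch everything (aliased imports or `from tensorflow import ...` would slip through), so it complements, rather than replaces, the side-by-side comparison.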

    TensorFlow to PyTorch Conversion

    The conversion sequence appears in the code block below. As before, we validate the output of the resultant model as well as the number of trainable parameters.

    import numpy as np
    import torch
    
    # save weights of TensorFlow model
    tf_model.save_weights("model_weights.h5")
    
    import keras
    keras.config.set_backend("torch")
    
    from keras3_vit import TFViTForImageClassification as Keras3ViT
    keras3_model = Keras3ViT()
    
    # call model to initialize all layers
    keras3_model(torch_input, training=False)
    
    # load the weights from the TensorFlow model
    keras3_model.load_weights("model_weights.h5")
    
    # validate converted model
    assert isinstance(keras3_model, torch.nn.Module)
    
    keras3_model = keras3_model.to(DEVICE)
    keras3_model = keras3_model.eval()
    torch_output = keras3_model(torch_input, training=False)
    torch_output = torch_output[0].detach().cpu().numpy()
    print("Max diff:", np.max(np.abs(tf_output - torch_output)))
    
    num_pyt_params = sum([p.numel()
                          for p in keras3_model.parameters()
                          if p.requires_grad])
    print(f"Keras3 Trainable Parameters: {num_pyt_params:,}")
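Printing the maximum difference is a quick check; for a classifier it can also help to confirm that the predicted labels agree and to assert against an explicit tolerance. The sketch below assumes that `tf_output` and `torch_output` are NumPy arrays of shape (batch, num_labels), as in the block above. The helper name `compare_outputs` and the default `atol=1e-4` are our own illustrative choices, not part of any library.

```python
import numpy as np


def compare_outputs(ref, test, atol=1e-4):
    """Summarize how closely two (batch, num_labels) logit arrays agree.

    `atol` is an illustrative default; a suitable value depends on model
    depth and numeric precision.
    """
    ref = np.asarray(ref, dtype=np.float64)
    test = np.asarray(test, dtype=np.float64)
    if ref.shape != test.shape:
        raise ValueError(f"shape mismatch: {ref.shape} vs {test.shape}")
    abs_diff = np.abs(ref - test)
    return {
        # largest element-wise deviation, as printed in the block above
        "max_abs_diff": float(abs_diff.max()),
        # fraction of samples whose predicted class is unchanged
        "top1_agreement": float(np.mean(ref.argmax(-1) == test.argmax(-1))),
        # True if every element is within the absolute tolerance
        "within_tol": bool(np.allclose(ref, test, rtol=0.0, atol=atol)),
    }
```

For example, asserting on `compare_outputs(tf_output, torch_output)["within_tol"]` could replace the printed diff in an automated conversion test.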

    Training/Fine-tuning the Model

    Contrary to the ONNX-converted model, the Keras3 model maintains the same structure and trainable parameters. This allows us to resume training and/or fine-tune the converted model. This can be done either within the Keras3 training framework or using a standard PyTorch training loop.
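Since the converted model is a `torch.nn.Module`, a plain PyTorch loop can be applied to it. The following is a minimal sketch, not a tuned training recipe: it assumes a classification task with integer labels, uses cross-entropy loss, and unpacks the `(logits,)` tuple returned by the Keras3 model. The `call_kwargs` parameter is a hypothetical convenience for forwarding `training=True`, on the assumption that Keras3 layers read that argument rather than the PyTorch `model.train()` flag. A tiny `torch.nn.Linear` stands in for the ViT so that the example runs on its own.

```python
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset


def train_one_epoch(model, loader, optimizer, device="cpu", call_kwargs=None):
    """Run one epoch of a standard PyTorch training loop; return the mean loss."""
    call_kwargs = call_kwargs or {}
    model.train()
    total_loss, num_batches = 0.0, 0
    for inputs, labels in loader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs, **call_kwargs)
        # the Keras3 ViT returns a tuple (logits,); plain modules return logits
        logits = outputs[0] if isinstance(outputs, (tuple, list)) else outputs
        loss = F.cross_entropy(logits, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        num_batches += 1
    return total_loss / max(num_batches, 1)


# stand-in model and random data, only so that the sketch runs on its own
torch.manual_seed(0)
model = torch.nn.Linear(8, 3)
dataset = TensorDataset(torch.randn(64, 8), torch.randint(0, 3, (64,)))
loader = DataLoader(dataset, batch_size=16)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
losses = [train_one_epoch(model, loader, optimizer) for _ in range(20)]
```

Applied to the converted model, the call might look like `train_one_epoch(keras3_model, train_loader, torch.optim.AdamW(keras3_model.parameters()), device=DEVICE, call_kwargs={"training": True})`, where `train_loader` is a hypothetical DataLoader of image/label batches.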

    Optimizing Model Layers

    Contrary to the ONNX-converted model, the coherence of the Keras3 model definition makes it easy to modify and optimize the layer implementations. In the code block below, we replace the existing attention mechanism with PyTorch’s highly efficient SDPA operator.

    import math
    
    import keras
    from torch.nn.functional import scaled_dot_product_attention as sdpa
    
    class TFViTSelfAttention(keras.layers.Layer):
        def __init__(self, **kwargs):
            super().__init__(**kwargs)
            self.num_attention_heads = ATTN_HEADS
            self.attention_head_size = int(HIDDEN_SIZE / ATTN_HEADS)
            self.all_head_size = ATTN_HEADS * self.attention_head_size
            self.sqrt_att_head_size = math.sqrt(self.attention_head_size)
            self.query = keras.layers.Dense(self.all_head_size, name="query")
            self.key = keras.layers.Dense(self.all_head_size, name="key")
            self.value = keras.layers.Dense(self.all_head_size, name="value")
    
        def transpose_for_scores(self, tensor, batch_size: int):
            # (batch, seq, hidden) -> (batch, heads, seq, head_size)
            tensor = keras.ops.reshape(tensor, (batch_size, -1, ATTN_HEADS,
                                                self.attention_head_size))
            return keras.ops.transpose(tensor, [0, 2, 1, 3])
    
        def call(self, hidden_states, training=False):
            bs = hidden_states.shape[0]
            mixed_query_layer = self.query(inputs=hidden_states)
            mixed_key_layer = self.key(inputs=hidden_states)
            mixed_value_layer = self.value(inputs=hidden_states)
            query_layer = self.transpose_for_scores(mixed_query_layer, bs)
            key_layer = self.transpose_for_scores(mixed_key_layer, bs)
            value_layer = self.transpose_for_scores(mixed_value_layer, bs)
            # fused softmax(Q K^T / sqrt(head_size)) V
            sdpa_output = sdpa(query_layer, key_layer, value_layer)
            # (batch, heads, seq, head_size) -> (batch, seq, hidden)
            attention_output = keras.ops.transpose(sdpa_output, [0, 2, 1, 3])
            attention_output = keras.ops.reshape(attention_output,
                                                 (bs, -1, self.all_head_size))
            return (attention_output,)
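To make explicit what the `sdpa` call computes, here is a NumPy reference under a few assumptions of ours: no attention mask, no attention dropout (matching the defaults of `sdpa` as called above), and the default scaling of 1/sqrt(head_size). The hypothetical helpers `split_heads` and `merge_heads` mirror the reshape/transpose steps in `transpose_for_scores` and at the end of `call`. For brevity, the toy usage feeds the hidden states directly as Q, K and V, skipping the Dense projections.

```python
import numpy as np


def split_heads(x, num_heads):
    """(batch, seq, hidden) -> (batch, heads, seq, head_size)."""
    b, s, h = x.shape
    return x.reshape(b, s, num_heads, h // num_heads).transpose(0, 2, 1, 3)


def merge_heads(x):
    """(batch, heads, seq, head_size) -> (batch, seq, hidden); inverse of split_heads."""
    b, n, s, d = x.shape
    return x.transpose(0, 2, 1, 3).reshape(b, s, n * d)


def reference_attention(q, k, v):
    """softmax(q @ k^T / sqrt(head_size)) @ v, with no mask and no dropout."""
    scores = q @ np.swapaxes(k, -1, -2) / np.sqrt(q.shape[-1])
    scores = scores - scores.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ v


# toy self-attention: batch=2, seq=5, hidden=12, 3 heads of size 4
rng = np.random.default_rng(0)
hidden_states = rng.standard_normal((2, 5, 12))
q = k = v = split_heads(hidden_states, num_heads=3)
attention_output = merge_heads(reference_attention(q, k, v))
```

Because SDPA fuses these steps into one operator, PyTorch can select a fused kernel that avoids materializing the full attention-weight matrix when the hardware and inputs allow.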

    We use the same benchmarking function from above to assess the impact of this optimization on the model’s runtime performance:

    torch_infer_fn = get_torch_infer_fn(keras3_model)
    avg_time = benchmark(torch_infer_fn, torch_input)
    print(f"Keras3 converted model average step time: {(avg_time):.4f}")

    The results are captured in the table below:

    Keras3 Conversion Step Time Results (by Author)

    Using the Keras3-based model conversion scheme and applying the SDPA optimization, we are able to increase the model’s inference throughput by 22% compared to the original TensorFlow model.
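Note that a throughput gain and a step-time reduction are not the same percentage. With a fixed batch size, throughput is batch_size / step_time, so the batch size cancels when comparing two step times. The helper names below are ours, and the step times in the usage lines are hypothetical, not the measured values from the table.

```python
def throughput_gain(baseline_step_time, new_step_time):
    """Relative throughput increase implied by two average step times.

    Throughput is batch_size / step_time, so for a fixed batch size the
    batch size cancels and the gain is baseline / new - 1.
    """
    return baseline_step_time / new_step_time - 1.0


def step_time_reduction(gain):
    """Fractional step-time reduction equivalent to a given throughput gain."""
    return 1.0 - 1.0 / (1.0 + gain)


# hypothetical step times in seconds: 0.122 baseline vs 0.100 optimized
print(f"throughput gain: {throughput_gain(0.122, 0.100):.0%}")
print(f"step-time reduction for a 22% gain: {step_time_reduction(0.22):.1%}")
```

By this arithmetic, a 22% throughput increase corresponds to a step time roughly 18% shorter (1 - 1/1.22).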

    Model Compilation

    Another optimization we would like to apply is PyTorch compilation. Unfortunately (as of the time of this writing), support for PyTorch compilation in Keras3 is limited. In the case of our toy model, both our attempt to apply torch.compile directly to the model and our attempt to set the jit_compile field of the Keras3 Model.compile function failed. In both cases, the failure resulted from multiple recompilations triggered by Keras3’s internal machinery. While Keras3 grants access to the PyTorch ecosystem, its high-level abstraction may impose some limitations.

    Limitations

    Once again, we have a conversion scheme that works but has several limitations:

    1. The TensorFlow models must be Keras3-compatible. The amount of work this requires will depend on the details of your model implementation. It may require some Keras layer customization.
    2. While the resulting model is a torch.nn.Module, it is not a “pure” PyTorch model, in the sense that it is composed of Keras3 layers and includes a lot of additional Keras3 code. This may require some adaptations to our PyTorch tooling and may impose some restrictions, as we saw when we tried to apply PyTorch compilation.
    3. The solution relies on the health and maintenance of Keras3 and its support for the TensorFlow and PyTorch backends.

    Summary

    In this post we have proposed and assessed two methods for auto-conversion of legacy TensorFlow models to PyTorch. We summarize our findings in the following table.

    Comparison of Conversion Schemes (by Author)

    Ultimately, the best approach, whether it is one of the methods discussed here, manual conversion, a solution based on generative AI, or the decision not to perform conversion at all, will depend greatly on the details of the model and the situation.


