You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ma...@apache.org on 2022/11/02 22:51:02 UTC

[tvm] branch main updated: [Frontend][Tensorflow2] Import graph_def to default graph before calling function_def_to_graph_def (#13260)

This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new ff6aaeb12a [Frontend][Tensorflow2] Import graph_def to default graph before calling function_def_to_graph_def (#13260)
ff6aaeb12a is described below

commit ff6aaeb12ae71393fef37da8f9c72a0f2017e6d5
Author: Alexander Pivovarov <pi...@amazon.com>
AuthorDate: Wed Nov 2 15:50:56 2022 -0700

    [Frontend][Tensorflow2] Import graph_def to default graph before calling function_def_to_graph_def (#13260)
    
    [TF2] Import graph_def to default graph before calling function_def_to_graph_def
---
 python/tvm/relay/frontend/tensorflow2.py | 30 ++++++++++++++++++------------
 1 file changed, 18 insertions(+), 12 deletions(-)

diff --git a/python/tvm/relay/frontend/tensorflow2.py b/python/tvm/relay/frontend/tensorflow2.py
index 465f530624..2a2a64b295 100644
--- a/python/tvm/relay/frontend/tensorflow2.py
+++ b/python/tvm/relay/frontend/tensorflow2.py
@@ -25,6 +25,7 @@ Otherwise use the tf1.x converter:
 """
 
 import numpy as np
+import tensorflow as tf
 from tensorflow.python.framework import function_def_to_graph, tensor_util, dtypes
 
 import tvm
@@ -839,16 +840,21 @@ def from_tensorflow(graph_def, layout="NHWC", shape=None, outputs=None):
 
     """
 
-    # Subgraph graph_defs are cached here to avoid a TF error when parsing after prelude init
-    graph_def_library = {}
-    for func in graph_def.library.function:
-        inshape = func.attr["_input_shapes"].list.shape
-        graph_def_library[func.signature.name], _ = function_def_to_graph.function_def_to_graph_def(
-            func, inshape
+    with tf.Graph().as_default():
+        tf.import_graph_def(graph_def, name="")
+        # Subgraph graph_defs are cached here to avoid a TF error when parsing after prelude init
+        graph_def_library = {}
+        for func in graph_def.library.function:
+            inshape = func.attr["_input_shapes"].list.shape
+            (
+                graph_def_library[func.signature.name],
+                _,
+            ) = function_def_to_graph.function_def_to_graph_def(func, inshape)
+        module = RelayModule()
+        g = GraphProto(module)
+        func, params = g.from_tensorflow(
+            graph_def, layout, shape, outputs, gdef_lib=graph_def_library
         )
-    module = RelayModule()
-    g = GraphProto(module)
-    func, params = g.from_tensorflow(graph_def, layout, shape, outputs, gdef_lib=graph_def_library)
-    module.mod["main"] = func
-    module.params.update(params)
-    return module.mod, module.params
+        module.mod["main"] = func
+        module.params.update(params)
+        return module.mod, module.params