You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ma...@apache.org on 2022/11/02 22:51:02 UTC
[tvm] branch main updated: [Frontend][Tensorflow2] Import graph_def to default graph before calling function_def_to_graph_def (#13260)
This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new ff6aaeb12a [Frontend][Tensorflow2] Import graph_def to default graph before calling function_def_to_graph_def (#13260)
ff6aaeb12a is described below
commit ff6aaeb12ae71393fef37da8f9c72a0f2017e6d5
Author: Alexander Pivovarov <pi...@amazon.com>
AuthorDate: Wed Nov 2 15:50:56 2022 -0700
[Frontend][Tensorflow2] Import graph_def to default graph before calling function_def_to_graph_def (#13260)
[TF2] Import graph_def to default graph before calling function_def_to_graph_def
---
python/tvm/relay/frontend/tensorflow2.py | 30 ++++++++++++++++++------------
1 file changed, 18 insertions(+), 12 deletions(-)
diff --git a/python/tvm/relay/frontend/tensorflow2.py b/python/tvm/relay/frontend/tensorflow2.py
index 465f530624..2a2a64b295 100644
--- a/python/tvm/relay/frontend/tensorflow2.py
+++ b/python/tvm/relay/frontend/tensorflow2.py
@@ -25,6 +25,7 @@ Otherwise use the tf1.x converter:
"""
import numpy as np
+import tensorflow as tf
from tensorflow.python.framework import function_def_to_graph, tensor_util, dtypes
import tvm
@@ -839,16 +840,21 @@ def from_tensorflow(graph_def, layout="NHWC", shape=None, outputs=None):
"""
- # Subgraph graph_defs are cached here to avoid a TF error when parsing after prelude init
- graph_def_library = {}
- for func in graph_def.library.function:
- inshape = func.attr["_input_shapes"].list.shape
- graph_def_library[func.signature.name], _ = function_def_to_graph.function_def_to_graph_def(
- func, inshape
+ with tf.Graph().as_default():
+ tf.import_graph_def(graph_def, name="")
+ # Subgraph graph_defs are cached here to avoid a TF error when parsing after prelude init
+ graph_def_library = {}
+ for func in graph_def.library.function:
+ inshape = func.attr["_input_shapes"].list.shape
+ (
+ graph_def_library[func.signature.name],
+ _,
+ ) = function_def_to_graph.function_def_to_graph_def(func, inshape)
+ module = RelayModule()
+ g = GraphProto(module)
+ func, params = g.from_tensorflow(
+ graph_def, layout, shape, outputs, gdef_lib=graph_def_library
)
- module = RelayModule()
- g = GraphProto(module)
- func, params = g.from_tensorflow(graph_def, layout, shape, outputs, gdef_lib=graph_def_library)
- module.mod["main"] = func
- module.params.update(params)
- return module.mod, module.params
+ module.mod["main"] = func
+ module.params.update(params)
+ return module.mod, module.params