You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ma...@apache.org on 2021/03/18 05:10:09 UTC
[tvm] branch main updated: [TVMC] Allow optional arguments to be passed to importers (#7674)
This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 38aed59 [TVMC] Allow optional arguments to be passed to importers (#7674)
38aed59 is described below
commit 38aed59f9fdddcbc9ac98afb8aa11455c81fc9de
Author: CircleSpin <2k...@gmail.com>
AuthorDate: Thu Mar 18 01:09:55 2021 -0400
[TVMC] Allow optional arguments to be passed to importers (#7674)
* add support for optional args for frontends tvmc
* remove unnecessary comments
* Add changes suggested by Matt W. via PR
Co-authored-by: Jocelyn <jo...@pop-os.localdomain>
---
python/tvm/driver/tvmc/frontends.py | 27 ++++++++++++++-------------
tests/python/driver/tvmc/test_frontends.py | 22 +++++++++++++++-------
2 files changed, 29 insertions(+), 20 deletions(-)
diff --git a/python/tvm/driver/tvmc/frontends.py b/python/tvm/driver/tvmc/frontends.py
index 16e6c8e..0488223 100644
--- a/python/tvm/driver/tvmc/frontends.py
+++ b/python/tvm/driver/tvmc/frontends.py
@@ -54,7 +54,7 @@ class Frontend(ABC):
"""File suffixes (extensions) used by this frontend"""
@abstractmethod
- def load(self, path, shape_dict=None):
+ def load(self, path, shape_dict=None, **kwargs):
"""Load a model from a given path.
Parameters
@@ -101,7 +101,7 @@ class KerasFrontend(Frontend):
def suffixes():
return ["h5"]
- def load(self, path, shape_dict=None):
+ def load(self, path, shape_dict=None, **kwargs):
# pylint: disable=C0103
tf, keras = import_keras()
@@ -130,7 +130,8 @@ class KerasFrontend(Frontend):
input_shapes = {name: x.shape for (name, x) in zip(model.input_names, inputs)}
if shape_dict is not None:
input_shapes.update(shape_dict)
- return relay.frontend.from_keras(model, input_shapes, layout="NHWC")
+ kwargs.setdefault("layout", "NHWC")
+ return relay.frontend.from_keras(model, input_shapes, **kwargs)
def is_sequential_p(self, model):
_, keras = import_keras()
@@ -158,14 +159,14 @@ class OnnxFrontend(Frontend):
def suffixes():
return ["onnx"]
- def load(self, path, shape_dict=None):
+ def load(self, path, shape_dict=None, **kwargs):
# pylint: disable=C0415
import onnx
# pylint: disable=E1101
model = onnx.load(path)
- return relay.frontend.from_onnx(model, shape=shape_dict)
+ return relay.frontend.from_onnx(model, shape=shape_dict, **kwargs)
class TensorflowFrontend(Frontend):
@@ -179,7 +180,7 @@ class TensorflowFrontend(Frontend):
def suffixes():
return ["pb"]
- def load(self, path, shape_dict=None):
+ def load(self, path, shape_dict=None, **kwargs):
# pylint: disable=C0415
import tensorflow as tf
import tvm.relay.testing.tf as tf_testing
@@ -192,7 +193,7 @@ class TensorflowFrontend(Frontend):
graph_def = tf_testing.ProcessGraphDefParam(graph_def)
logger.debug("parse TensorFlow model and convert into Relay computation graph")
- return relay.frontend.from_tensorflow(graph_def, shape=shape_dict)
+ return relay.frontend.from_tensorflow(graph_def, shape=shape_dict, **kwargs)
class TFLiteFrontend(Frontend):
@@ -206,7 +207,7 @@ class TFLiteFrontend(Frontend):
def suffixes():
return ["tflite"]
- def load(self, path, shape_dict=None):
+ def load(self, path, shape_dict=None, **kwargs):
# pylint: disable=C0415
import tflite.Model as model
@@ -229,7 +230,7 @@ class TFLiteFrontend(Frontend):
raise TVMCException("input file not tflite version 3")
logger.debug("parse TFLite model and convert into Relay computation graph")
- mod, params = relay.frontend.from_tflite(tflite_model, shape_dict=shape_dict)
+ mod, params = relay.frontend.from_tflite(tflite_model, shape_dict=shape_dict, **kwargs)
return mod, params
@@ -245,7 +246,7 @@ class PyTorchFrontend(Frontend):
# Torch Script is a zip file, but can be named pth
return ["pth", "zip"]
- def load(self, path, shape_dict=None):
+ def load(self, path, shape_dict=None, **kwargs):
# pylint: disable=C0415
import torch
@@ -259,7 +260,7 @@ class PyTorchFrontend(Frontend):
input_shapes = list(shape_dict.items())
logger.debug("parse Torch model and convert into Relay computation graph")
- return relay.frontend.from_pytorch(traced_model, input_shapes)
+ return relay.frontend.from_pytorch(traced_model, input_shapes, **kwargs)
ALL_FRONTENDS = [
@@ -339,7 +340,7 @@ def guess_frontend(path):
raise TVMCException("failed to infer the model format. Please specify --model-format")
-def load_model(path, model_format=None, shape_dict=None):
+def load_model(path, model_format=None, shape_dict=None, **kwargs):
"""Load a model from a supported framework and convert it
into an equivalent relay representation.
@@ -367,6 +368,6 @@ def load_model(path, model_format=None, shape_dict=None):
else:
frontend = guess_frontend(path)
- mod, params = frontend.load(path, shape_dict)
+ mod, params = frontend.load(path, shape_dict, **kwargs)
return mod, params
diff --git a/tests/python/driver/tvmc/test_frontends.py b/tests/python/driver/tvmc/test_frontends.py
index b41f4c4..5a63c5c 100644
--- a/tests/python/driver/tvmc/test_frontends.py
+++ b/tests/python/driver/tvmc/test_frontends.py
@@ -115,26 +115,34 @@ def test_load_model__tflite(tflite_mobilenet_v1_1_quant):
assert "_param_1" in params.keys()
-def test_load_model__keras(keras_resnet50):
+@pytest.mark.parametrize("load_model_kwargs", [{}, {"layout": "NCHW"}])
+def test_load_model__keras(keras_resnet50, load_model_kwargs):
# some CI environments wont offer TensorFlow/Keras, so skip in case it is not present
pytest.importorskip("tensorflow")
- mod, params = tvmc.frontends.load_model(keras_resnet50)
+ mod, params = tvmc.frontends.load_model(keras_resnet50, **load_model_kwargs)
assert type(mod) is IRModule
assert type(params) is dict
## check whether one known value is part of the params dict
assert "_param_1" in params.keys()
+def verify_load_model__onnx(model, **kwargs):
+ mod, params = tvmc.frontends.load_model(model, **kwargs)
+ assert type(mod) is IRModule
+ assert type(params) is dict
+ return mod, params
+
+
def test_load_model__onnx(onnx_resnet50):
# some CI environments wont offer onnx, so skip in case it is not present
pytest.importorskip("onnx")
-
- mod, params = tvmc.frontends.load_model(onnx_resnet50)
- assert type(mod) is IRModule
- assert type(params) is dict
- ## check whether one known value is part of the params dict
+ mod, params = verify_load_model__onnx(onnx_resnet50)
+ # check whether one known value is part of the params dict
assert "resnetv24_batchnorm0_gamma" in params.keys()
+ mod, params = verify_load_model__onnx(onnx_resnet50, freeze_params=True)
+ # check that the parameter dict is empty, implying that they have been folded into constants
+ assert params == {}
def test_load_model__pb(pb_mobilenet_v1_1_quant):