You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ma...@apache.org on 2021/03/18 05:10:09 UTC

[tvm] branch main updated: [TVMC] Allow optional arguments to be passed to importers (#7674)

This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 38aed59  [TVMC] Allow optional arguments to be passed to importers (#7674)
38aed59 is described below

commit 38aed59f9fdddcbc9ac98afb8aa11455c81fc9de
Author: CircleSpin <2k...@gmail.com>
AuthorDate: Thu Mar 18 01:09:55 2021 -0400

    [TVMC] Allow optional arguments to be passed to importers (#7674)
    
    * Add support for optional arguments to TVMC frontends
    
    * remove unnecessary comments
    
    * Add changes suggested by Matt W. via PR
    
    Co-authored-by: Jocelyn <jo...@pop-os.localdomain>
---
 python/tvm/driver/tvmc/frontends.py        | 27 ++++++++++++++-------------
 tests/python/driver/tvmc/test_frontends.py | 22 +++++++++++++++-------
 2 files changed, 29 insertions(+), 20 deletions(-)

diff --git a/python/tvm/driver/tvmc/frontends.py b/python/tvm/driver/tvmc/frontends.py
index 16e6c8e..0488223 100644
--- a/python/tvm/driver/tvmc/frontends.py
+++ b/python/tvm/driver/tvmc/frontends.py
@@ -54,7 +54,7 @@ class Frontend(ABC):
         """File suffixes (extensions) used by this frontend"""
 
     @abstractmethod
-    def load(self, path, shape_dict=None):
+    def load(self, path, shape_dict=None, **kwargs):
         """Load a model from a given path.
 
         Parameters
@@ -101,7 +101,7 @@ class KerasFrontend(Frontend):
     def suffixes():
         return ["h5"]
 
-    def load(self, path, shape_dict=None):
+    def load(self, path, shape_dict=None, **kwargs):
         # pylint: disable=C0103
         tf, keras = import_keras()
 
@@ -130,7 +130,8 @@ class KerasFrontend(Frontend):
         input_shapes = {name: x.shape for (name, x) in zip(model.input_names, inputs)}
         if shape_dict is not None:
             input_shapes.update(shape_dict)
-        return relay.frontend.from_keras(model, input_shapes, layout="NHWC")
+        kwargs.setdefault("layout", "NHWC")
+        return relay.frontend.from_keras(model, input_shapes, **kwargs)
 
     def is_sequential_p(self, model):
         _, keras = import_keras()
@@ -158,14 +159,14 @@ class OnnxFrontend(Frontend):
     def suffixes():
         return ["onnx"]
 
-    def load(self, path, shape_dict=None):
+    def load(self, path, shape_dict=None, **kwargs):
         # pylint: disable=C0415
         import onnx
 
         # pylint: disable=E1101
         model = onnx.load(path)
 
-        return relay.frontend.from_onnx(model, shape=shape_dict)
+        return relay.frontend.from_onnx(model, shape=shape_dict, **kwargs)
 
 
 class TensorflowFrontend(Frontend):
@@ -179,7 +180,7 @@ class TensorflowFrontend(Frontend):
     def suffixes():
         return ["pb"]
 
-    def load(self, path, shape_dict=None):
+    def load(self, path, shape_dict=None, **kwargs):
         # pylint: disable=C0415
         import tensorflow as tf
         import tvm.relay.testing.tf as tf_testing
@@ -192,7 +193,7 @@ class TensorflowFrontend(Frontend):
         graph_def = tf_testing.ProcessGraphDefParam(graph_def)
 
         logger.debug("parse TensorFlow model and convert into Relay computation graph")
-        return relay.frontend.from_tensorflow(graph_def, shape=shape_dict)
+        return relay.frontend.from_tensorflow(graph_def, shape=shape_dict, **kwargs)
 
 
 class TFLiteFrontend(Frontend):
@@ -206,7 +207,7 @@ class TFLiteFrontend(Frontend):
     def suffixes():
         return ["tflite"]
 
-    def load(self, path, shape_dict=None):
+    def load(self, path, shape_dict=None, **kwargs):
         # pylint: disable=C0415
         import tflite.Model as model
 
@@ -229,7 +230,7 @@ class TFLiteFrontend(Frontend):
             raise TVMCException("input file not tflite version 3")
 
         logger.debug("parse TFLite model and convert into Relay computation graph")
-        mod, params = relay.frontend.from_tflite(tflite_model, shape_dict=shape_dict)
+        mod, params = relay.frontend.from_tflite(tflite_model, shape_dict=shape_dict, **kwargs)
         return mod, params
 
 
@@ -245,7 +246,7 @@ class PyTorchFrontend(Frontend):
         # Torch Script is a zip file, but can be named pth
         return ["pth", "zip"]
 
-    def load(self, path, shape_dict=None):
+    def load(self, path, shape_dict=None, **kwargs):
         # pylint: disable=C0415
         import torch
 
@@ -259,7 +260,7 @@ class PyTorchFrontend(Frontend):
         input_shapes = list(shape_dict.items())
 
         logger.debug("parse Torch model and convert into Relay computation graph")
-        return relay.frontend.from_pytorch(traced_model, input_shapes)
+        return relay.frontend.from_pytorch(traced_model, input_shapes, **kwargs)
 
 
 ALL_FRONTENDS = [
@@ -339,7 +340,7 @@ def guess_frontend(path):
     raise TVMCException("failed to infer the model format. Please specify --model-format")
 
 
-def load_model(path, model_format=None, shape_dict=None):
+def load_model(path, model_format=None, shape_dict=None, **kwargs):
     """Load a model from a supported framework and convert it
     into an equivalent relay representation.
 
@@ -367,6 +368,6 @@ def load_model(path, model_format=None, shape_dict=None):
     else:
         frontend = guess_frontend(path)
 
-    mod, params = frontend.load(path, shape_dict)
+    mod, params = frontend.load(path, shape_dict, **kwargs)
 
     return mod, params
diff --git a/tests/python/driver/tvmc/test_frontends.py b/tests/python/driver/tvmc/test_frontends.py
index b41f4c4..5a63c5c 100644
--- a/tests/python/driver/tvmc/test_frontends.py
+++ b/tests/python/driver/tvmc/test_frontends.py
@@ -115,26 +115,34 @@ def test_load_model__tflite(tflite_mobilenet_v1_1_quant):
     assert "_param_1" in params.keys()
 
 
-def test_load_model__keras(keras_resnet50):
+@pytest.mark.parametrize("load_model_kwargs", [{}, {"layout": "NCHW"}])
+def test_load_model__keras(keras_resnet50, load_model_kwargs):
     # some CI environments wont offer TensorFlow/Keras, so skip in case it is not present
     pytest.importorskip("tensorflow")
 
-    mod, params = tvmc.frontends.load_model(keras_resnet50)
+    mod, params = tvmc.frontends.load_model(keras_resnet50, **load_model_kwargs)
     assert type(mod) is IRModule
     assert type(params) is dict
     ## check whether one known value is part of the params dict
     assert "_param_1" in params.keys()
 
 
+def verify_load_model__onnx(model, **kwargs):
+    mod, params = tvmc.frontends.load_model(model, **kwargs)
+    assert type(mod) is IRModule
+    assert type(params) is dict
+    return mod, params
+
+
 def test_load_model__onnx(onnx_resnet50):
     # some CI environments wont offer onnx, so skip in case it is not present
     pytest.importorskip("onnx")
-
-    mod, params = tvmc.frontends.load_model(onnx_resnet50)
-    assert type(mod) is IRModule
-    assert type(params) is dict
-    ## check whether one known value is part of the params dict
+    mod, params = verify_load_model__onnx(onnx_resnet50)
+    # check whether one known value is part of the params dict
     assert "resnetv24_batchnorm0_gamma" in params.keys()
+    mod, params = verify_load_model__onnx(onnx_resnet50, freeze_params=True)
+    # check that the parameter dict is empty, implying that they have been folded into constants
+    assert params == {}
 
 
 def test_load_model__pb(pb_mobilenet_v1_1_quant):