Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2021/01/29 14:16:47 UTC

[GitHub] [tvm] ekalda commented on a change in pull request #7366: [TVMC] Allow manual shape specification in tvmc

ekalda commented on a change in pull request #7366:
URL: https://github.com/apache/tvm/pull/7366#discussion_r566843771



##########
File path: python/tvm/driver/tvmc/frontends.py
##########
@@ -297,6 +303,15 @@ def load(self, path):
         traced_model.eval()  # Switch to inference mode
         input_shapes = [("input{}".format(idx), shape) for idx, shape in enumerate(shapes)]

Review comment:
(1) The "shapes" in "enumerate(shapes)" is not defined. This is an old bug, which I discovered when I tried to run a PyTorch model with tvmc.
(2) This approach of extracting input shapes from a PyTorch model no longer works (see the discussion in #7359). So this parameter needs to be set for PyTorch, and the PyTorch frontend should raise an error when it is not set. (Unless someone knows a way to extract the shapes from the model?)
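For point (2), here is a minimal sketch of what "raise an error when it is not set" might look like. The helper name pytorch_input_infos is hypothetical, and ValueError stands in for whatever exception tvmc uses, so the snippet runs without TVM. Its output is meant for the input_infos argument of relay.frontend.from_pytorch, which takes a list of (name, shape) pairs. My assumption, worth double-checking, is that the frontend pairs those entries with the graph inputs by position.

```python
from typing import Dict, List, Optional, Sequence, Tuple


def pytorch_input_infos(
    shape_dict: Optional[Dict[str, Sequence[int]]],
) -> List[Tuple[str, List[int]]]:
    """Hypothetical helper: build the (name, shape) pairs for the PyTorch frontend.

    A traced TorchScript module does not reliably expose its input shapes,
    so refuse to guess and fail with an actionable message instead.
    """
    if not shape_dict:
        raise ValueError(
            "PyTorch models need explicit input shapes, e.g. 'input0:1x3x224x224'"
        )
    # Dicts keep insertion order (Python 3.7+), so the pairs follow the
    # order in which the user listed the inputs.
    return [(name, list(shape)) for name, shape in shape_dict.items()]
```

Assuming load() receives the parsed shape dictionary as shape_dict, the undefined "shapes" line would become something like `mod, params = relay.frontend.from_pytorch(traced_model, pytorch_input_infos(shape_dict))`.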

##########
File path: python/tvm/driver/tvmc/compiler.py
##########
@@ -36,6 +36,41 @@
 logger = logging.getLogger("TVMC")
 
 
+def parse_shape(inputs):
+    """Parse an input shape dictionary string to a usable dictionary.
+
+    Parameters
+    ----------
+    inputs: str
+        A string of the form "name:num1xnum2x...xnumN,name2:num1xnum2xnum3" that indicates
+        the desired shape for specific model inputs.
+
+    Returns
+    -------
+    shape_dict: dict
+        A dictionary mapping input names to their shape for use in relay frontend converters.
+    """
+    d = {}
+    # Break apart each specific input string
+    inputs = inputs.split(",")
+    for string in inputs:
+        # Split name from shape string.
+        string = string.split(":")
+        shapelist = []
+        # Separate each dimension in the shape.
+        string[1] = string[1].split("x")
+        # Parse each dimension into an integer.
+        for x in string[1]:
+            x = int(x)
+            # Negative numbers are converted to dynamic axes.
+            if x < 0:
+                x = relay.Any()
+            shapelist.append(x)
+        # Assign dictionary key value pair.
+        d[string[0]] = shapelist
+    return d
+

Review comment:
Maybe it would be good to add some error handling for user input that is not in the expected format? Also, should we allow inputs that are "close", e.g. "name:num1Xnum2" or "name:num1xnum2, name2:num1xnum2"? Some unit tests covering the corner cases would be nice :)
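To make the suggestion concrete, here is a sketch of a stricter but more forgiving parse_shape, together with a corner-case test. None of this is in the PR; the choices are assumptions. It accepts an upper-case "X" and whitespace around entries. It splits the name from the shape on the last colon, so TensorFlow-style names such as "input:0" survive. It raises ValueError with a descriptive message on malformed input. It also takes a hypothetical any_dim callable for negative dimensions. In tvmc that would presumably be relay.Any; it is left out here so the snippet runs without TVM.

```python
import re
from typing import Callable, Dict, List, Optional


def parse_shape(
    inputs: str, any_dim: Optional[Callable[[], object]] = None
) -> Dict[str, List[object]]:
    """Parse "name:1x224x224x3,name2:1x10" into {"name": [1, 224, 224, 3], "name2": [1, 10]}.

    Sketch only: tolerates whitespace and an upper-case "X", and raises
    ValueError on malformed input instead of failing with an IndexError.
    """
    shape_dict: Dict[str, List[object]] = {}
    for entry in inputs.split(","):
        entry = entry.strip()
        if not entry:
            raise ValueError(f"Empty input specification in {inputs!r}")
        # Split on the LAST colon so names like "input:0" are kept intact.
        name, sep, dims = entry.rpartition(":")
        name = name.strip()
        if not sep or not name:
            raise ValueError(f"Expected 'name:shape', got {entry!r}")
        if name in shape_dict:
            raise ValueError(f"Input {name!r} specified more than once")
        shape: List[object] = []
        for dim in re.split(r"[xX]", dims):
            try:
                value = int(dim.strip())
            except ValueError:
                raise ValueError(f"Invalid dimension {dim!r} in {entry!r}") from None
            # Negative dimensions mark dynamic axes; the hypothetical any_dim
            # callable builds the placeholder (e.g. relay.Any in tvmc).
            shape.append(any_dim() if value < 0 and any_dim is not None else value)
        shape_dict[name] = shape
    return shape_dict


def test_parse_shape_corner_cases():
    assert parse_shape("input:1x224x224x3") == {"input": [1, 224, 224, 3]}
    # "Close" spellings are accepted.
    assert parse_shape("a:1X2, b:3x4") == {"a": [1, 2], "b": [3, 4]}
    assert parse_shape("input:0:1x3") == {"input:0": [1, 3]}
    # Negative dimensions are passed to any_dim only when one is given.
    assert parse_shape("a:-1x3") == {"a": [-1, 3]}
    assert parse_shape("a:-1x3", any_dim=lambda: "?") == {"a": ["?", 3]}
    for bad in ["a", "a:", ":1x2", "a:1xb", "a:1x", "a:1,,b:2", "a:1,a:2"]:
        try:
            parse_shape(bad)
        except ValueError:
            continue
        raise AssertionError(f"{bad!r} should have been rejected")


if __name__ == "__main__":
    test_parse_shape_corner_cases()
```

Splitting on the last colon trades away one case: a bare "input:0" with no shape would be read as input "input" with shape [0]. That seems an acceptable ambiguity for a command-line format.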

##########
File path: tests/python/driver/tvmc/test_compiler.py
##########
@@ -56,6 +53,15 @@ def test_compile_tflite_module(tflite_mobilenet_v1_1_quant):
     assert type(dumps) is dict
 
 
+def test_compile_tflite_module(tflite_mobilenet_v1_1_quant):
+    # Check default compilation.
+    verify_compile_tflite_module(tflite_mobilenet_v1_1_quant)
+    # Check with manual shape override
+    shape_string = "input:1x224x224x3"
+    shape_dict = tvmc.compiler.parse_shape(shape_string)
+    verify_compile_onnx_module(tflite_mobilenet_v1_1_quant, shape_dict)

Review comment:
Should this call verify_compile_tflite_module rather than verify_compile_onnx_module?




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org