Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2021/01/28 18:41:32 UTC

[GitHub] [tvm] comaniac commented on a change in pull request #7359: [TVMC] Fix PyTorch support

comaniac commented on a change in pull request #7359:
URL: https://github.com/apache/tvm/pull/7359#discussion_r566306351



##########
File path: python/tvm/driver/tvmc/common.py
##########
@@ -136,3 +138,40 @@ def tracker_host_port_from_cli(rpc_tracker_str):
         logger.info("RPC tracker port: %s", rpc_port)
 
     return rpc_hostname, rpc_port
+
+
+def parse_input_shapes(xs):
+    """Turn the string from --input-shape into a list.
+
+    Parameters
+    ----------
+    xs : str
+        The input shapes, in a form "(1,2,3),(1,4),..."
+
+    Returns
+    -------
+    shapes : list
+        Input shapes as a list of lists
+    """
+
+    shapes = []
+    # Split up string into comma seperated sections ignoring commas in ()s
+    match = re.findall(r"(\(.*?\)|.+?),?", xs)
+    if match:
+        for inp in match:
+            # Test for and remove brackets
+            shape = re.match(r"\((.*)\)", inp)
+            if shape and shape.lastindex == 1:
+                # Remove white space and extract numbers
+                strshape = shape[1].replace(" ", "").split(",")

Review comment:
       It would be safer and easier to remove all spaces from `xs` at the beginning of this function.
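
A minimal sketch of what that could look like, not a drop-in patch. The helper name `normalize_shape_string` is hypothetical; only the idea of stripping whitespace up front comes from the comment above.

```python
import re


def normalize_shape_string(xs):
    """Drop every whitespace character so later parsing never sees spaces.

    Hypothetical helper, introduced only for illustration.
    """
    return re.sub(r"\s+", "", xs)


normalized = normalize_shape_string(" (1, 3, 224, 224) , (1, 4) ")

# With whitespace gone, a simple pattern is enough to pick out each
# parenthesized section.
sections = re.findall(r"\(.*?\)", normalized)
```

With this in place, the later `.replace(" ", "")` on each individual shape would no longer be needed.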

##########
File path: python/tvm/driver/tvmc/frontends.py
##########
@@ -99,10 +101,13 @@ def name():
     def suffixes():
         return ["h5"]
 
-    def load(self, path):
+    def load(self, path, input_shape):
         # pylint: disable=C0103
         tf, keras = import_keras()
 
+        if input_shape:
+            raise TVMCException("--input-shape is not supported for {}".format(self.name()))
+

Review comment:
       Raising an exception in each frontend that does not support `--input-shape` is definitely too ad hoc.

##########
File path: python/tvm/driver/tvmc/common.py
##########
@@ -136,3 +138,40 @@ def tracker_host_port_from_cli(rpc_tracker_str):
         logger.info("RPC tracker port: %s", rpc_port)
 
     return rpc_hostname, rpc_port
+
+
+def parse_input_shapes(xs):
+    """Turn the string from --input-shape into a list.
+
+    Parameters
+    ----------
+    xs : str
+        The input shapes, in a form "(1,2,3),(1,4),..."
+
+    Returns
+    -------
+    shapes : list
+        Input shapes as a list of lists
+    """
+
+    shapes = []
+    # Split up string into comma seperated sections ignoring commas in ()s
+    match = re.findall(r"(\(.*?\)|.+?),?", xs)
+    if match:
+        for inp in match:
+            # Test for and remove brackets
+            shape = re.match(r"\((.*)\)", inp)
+            if shape and shape.lastindex == 1:
+                # Remove white space and extract numbers
+                strshape = shape[1].replace(" ", "").split(",")
+                try:
+                    shapes.append([int(i) for i in strshape])
+                except ValueError:
+                    raise argparse.ArgumentTypeError(f"expected numbers in shape '{shape[1]}'")

Review comment:
       Consider the following two input shapes:
   - `(8)`: `shapes=[[8]]`
   - `(8,)`: `ValueError`, because `strshape` would be `["8", ""]`.
   
   Accordingly, I guess you intend users to write `(8)` rather than `(8,)`. However, this is inconsistent with Python syntax, so it might confuse people. I have two proposals to deal with this:
   1. Use list syntax instead of tuple syntax, so that the semantics are clear and we can simply use the JSON loader to handle all variants (e.g., spaces):
       ```python
       import json
       xs = "[1,3,224,224], [32]"
       # Wrap in brackets so the string parses as one JSON array.
       shapes = json.loads(f"[{xs}]")  # [[1, 3, 224, 224], [32]]
       ```
   2. Follow Python syntax: accept only `(8,)` and raise an error for `(8)`, which Python evaluates to an integer rather than a tuple because the parentheses act as plain grouping. In this case, I would suggest using `eval` to deal with all variants.
       ```python
       xs = "(1,3,224,224), (32,)"
       shapes = eval(xs, {}, {}) # Remember to disable all local and global symbols to isolate this expression.
       # shapes=[(1,3,224,224),(32,)]
       ```
   
   Either way is fine with me. Please update the help message and add a unit test that covers the corner cases.
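
To make the second proposal concrete, here is a rough sketch that follows Python tuple syntax but swaps `eval` for `ast.literal_eval`, which accepts only literals. This is an assumption on my side about how the parser could look, not the code in this PR; the error messages are illustrative.

```python
import argparse
import ast


def parse_input_shapes(xs):
    """Parse a string such as "(1,3,224,224), (32,)" into a list of shapes."""
    try:
        # The added trailing comma forces a tuple of shapes even when the
        # user passes a single shape.
        parsed = ast.literal_eval(f"({xs.strip()},)")
    except (ValueError, SyntaxError) as err:
        raise argparse.ArgumentTypeError(f"invalid input shapes '{xs}'") from err

    shapes = []
    for shape in parsed:
        # "(8)" evaluates to the int 8, so it is rejected here. bool is
        # excluded explicitly because it is a subclass of int.
        if (
            not isinstance(shape, tuple)
            or not shape
            or not all(isinstance(d, int) and not isinstance(d, bool) for d in shape)
        ):
            raise argparse.ArgumentTypeError(
                f"expected a non-empty tuple of integers, got '{shape}'"
            )
        shapes.append(list(shape))
    return shapes
```

The corner cases a unit test should probably cover are `(8,)` (accepted), `(8)` (rejected), extra spaces, an empty string, and non-numeric entries.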

##########
File path: python/tvm/driver/tvmc/frontends.py
##########
@@ -389,6 +403,8 @@ def load_model(path, model_format=None):
     model_format : str, optional
         The underlying framework used to create the model.
         If not specified, this will be inferred from the file type.
+    input shape : list, optional
+        The shape of input tensor for PyTorch models

Review comment:
       Ditto. Make it general instead of PyTorch-specific.

##########
File path: python/tvm/driver/tvmc/compiler.py
##########
@@ -59,6 +59,12 @@ def add_compile_parser(subparsers):
         default="",
         help="comma separarated list of formats to export, e.g. 'asm,ll,relay' ",
     )
+    parser.add_argument(
+        "--input-shape",
+        type=common.parse_input_shapes,
+        metavar="INPUT_SHAPE,[INPUT_SHAPE]...",
+        help="for PyTorch, e.g. '(1,3,224,224)'",

Review comment:
       Agreed. It's confusing to see such a general option used only for PyTorch. I would suggest the following changes:
   1. Make `--input-shape` a general option for all frontends. If it is present, we skip input shape inference.
   2. Keep `--input-shape` optional by default. However, if users want to process a PyTorch model but don't specify `--input-shape`, we raise an error in the PyTorch frontend.
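
A sketch of how the two points could fit together, under stated assumptions: `TVMCException` is redefined locally so the example runs on its own, the two frontend classes are hypothetical stand-ins for the real ones, and the parser uses the list syntax from my other comment purely as a placeholder.

```python
import argparse
import json


class TVMCException(Exception):
    """Local stand-in for tvmc's exception type so the sketch is self-contained."""


def parse_input_shapes(xs):
    # Placeholder parser (list syntax); the real one is discussed separately.
    return json.loads(f"[{xs}]")


def build_parser():
    parser = argparse.ArgumentParser(prog="tvmc-sketch")
    parser.add_argument(
        "--input-shape",
        type=parse_input_shapes,
        default=None,  # optional for every frontend
        metavar="INPUT_SHAPE",
        help="input shapes, e.g. '[1,3,224,224],[32]'; "
        "when given, input shape inference is skipped",
    )
    return parser


class KerasFrontendSketch:
    def load(self, path, input_shape=None):
        # None means: fall back to the frontend's own shape inference.
        return path, input_shape


class PyTorchFrontendSketch:
    def load(self, path, input_shape=None):
        # Only this frontend insists on the option.
        if not input_shape:
            raise TVMCException("--input-shape must be specified for PyTorch models")
        return path, input_shape
```

This keeps the option's help text frontend-neutral and moves the PyTorch-specific requirement into the PyTorch frontend itself.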

##########
File path: python/tvm/driver/tvmc/frontends.py
##########
@@ -285,17 +299,17 @@ def suffixes():
         # Torch Script is a zip file, but can be named pth
         return ["pth", "zip"]
 
-    def load(self, path):
+    def load(self, path, input_shape):
         # pylint: disable=C0415
         import torch
 
-        traced_model = torch.jit.load(path)
-
-        inputs = list(traced_model.graph.inputs())[1:]

Review comment:
       Does this approach not work at all? If it works in some cases, we should still try it first when `--input-shape` is missing.
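
The fallback order I have in mind could be sketched roughly as below. `resolve_input_shapes` and its `infer_shapes` callback are hypothetical names; the callback stands in for whatever graph-based inference the frontend already does, and I am assuming it either raises or returns nothing when inference fails.

```python
def resolve_input_shapes(cli_shapes, infer_shapes):
    """Pick the input shapes to use for a model.

    cli_shapes   -- value of --input-shape, or None when the option is missing
    infer_shapes -- zero-argument callable wrapping the existing inference
    """
    if cli_shapes:
        # An explicit option wins and inference is skipped.
        return cli_shapes
    try:
        inferred = infer_shapes()
    except Exception:  # broad on purpose: the sketch cannot know how inference fails
        inferred = None
    if not inferred:
        raise ValueError("could not infer input shapes; please pass --input-shape")
    return inferred
```

That way models for which inference already works keep working without the new option.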



