Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2021/01/28 13:34:26 UTC

[GitHub] [tvm] ekalda opened a new pull request #7359: [TVMC] Fix PyTorch support

ekalda opened a new pull request #7359:
URL: https://github.com/apache/tvm/pull/7359


   A PyTorch model could not be compiled through tvmc because the shape
   of the input tensor could not be deduced from the model after it had been
   saved. We've added an --input-shape parameter to tvmc compile and
   tvmc tune that allows the input shapes to be specified for PyTorch models.
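To illustrate what the new parameter ultimately feeds, here is a minimal sketch assuming the parsed shapes are handed to TVM's `relay.frontend.from_pytorch`, which takes its inputs as a list of `(name, shape)` pairs. The helper `to_input_infos` and its `input0`, `input1`, ... naming scheme are hypothetical and not part of tvmc.

```python
def to_input_infos(shapes, prefix="input"):
    """Pair each parsed --input-shape entry with a generated input name.

    The "input0", "input1", ... naming scheme is a hypothetical choice for this
    sketch, since the saved TorchScript model does not supply the shapes.
    """
    # Keep the user's order: the i-th shape belongs to the model's i-th input
    return [(f"{prefix}{i}", list(shape)) for i, shape in enumerate(shapes)]


print(to_input_infos([[1, 3, 224, 224]]))  # [('input0', [1, 3, 224, 224])]
```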
   
   


----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] comaniac closed pull request #7359: [TVMC] Fix PyTorch support

Posted by GitBox <gi...@apache.org>.
comaniac closed pull request #7359:
URL: https://github.com/apache/tvm/pull/7359


   





[GitHub] [tvm] comaniac commented on pull request #7359: [TVMC] Fix PyTorch support

comaniac commented on pull request #7359:
URL: https://github.com/apache/tvm/pull/7359#issuecomment-769954804


   The functionality has been included in #7366.





[GitHub] [tvm] leandron commented on a change in pull request #7359: [TVMC] Fix PyTorch support

leandron commented on a change in pull request #7359:
URL: https://github.com/apache/tvm/pull/7359#discussion_r566195003



##########
File path: python/tvm/driver/tvmc/common.py
##########
@@ -136,3 +138,40 @@ def tracker_host_port_from_cli(rpc_tracker_str):
         logger.info("RPC tracker port: %s", rpc_port)
 
     return rpc_hostname, rpc_port
+
+
+def parse_input_shapes(xs):
+    """Turn the string from --input-shape into a list.
+

Review comment:
       It would be good to have an example here, that describes the input format and expected output format, similar to what you have on `test_parse_input_shapes__turn_into_list`.
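As a sketch of the kind of docstring example requested here (a hypothetical rewrite, not the PR's final code), a numpydoc `Examples` section with doctests documents both the input format and the returned list of lists. The parsing regex below is an assumption that only recognizes parenthesized groups; it is not the PR's `(\(.*?\)|.+?),?` pattern.

```python
import argparse
import re


def parse_input_shapes(xs):
    """Turn the string from --input-shape into a list of shapes.

    Parameters
    ----------
    xs : str
        Input shapes as parenthesized, comma-separated integers,
        e.g. "(1,3,224,224),(32,4)".

    Returns
    -------
    shapes : list of list of int
        One list of dimensions per input.

    Examples
    --------
    >>> parse_input_shapes("(1,3,224,224)")
    [[1, 3, 224, 224]]
    >>> parse_input_shapes("(1,3,224,224),(32,4)")
    [[1, 3, 224, 224], [32, 4]]
    """
    shapes = []
    # Capture the contents of each (...) group; text outside parentheses is ignored
    for group in re.findall(r"\(([^()]*)\)", xs):
        try:
            shapes.append([int(dim) for dim in group.split(",")])
        except ValueError:
            raise argparse.ArgumentTypeError(f"expected numbers in shape '({group})'") from None
    return shapes
```

The examples double as tests when the file is run with `python -m doctest`.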

##########
File path: python/tvm/driver/tvmc/compiler.py
##########
@@ -59,6 +59,12 @@ def add_compile_parser(subparsers):
         default="",
         help="comma separarated list of formats to export, e.g. 'asm,ll,relay' ",
     )
+    parser.add_argument(
+        "--input-shape",
+        type=common.parse_input_shapes,
+        metavar="INPUT_SHAPE,[INPUT_SHAPE]...",
+        help="for PyTorch, e.g. '(1,3,224,224)'",

Review comment:
       Maybe clarify that it is in fact mandatory for PyTorch.







[GitHub] [tvm] ekalda commented on a change in pull request #7359: [TVMC] Fix PyTorch support

ekalda commented on a change in pull request #7359:
URL: https://github.com/apache/tvm/pull/7359#discussion_r566826298



##########
File path: python/tvm/driver/tvmc/frontends.py
##########
@@ -285,17 +299,17 @@ def suffixes():
         # Torch Script is a zip file, but can be named pth
         return ["pth", "zip"]
 
-    def load(self, path):
+    def load(self, path, input_shape):
         # pylint: disable=C0415
         import torch
 
-        traced_model = torch.jit.load(path)
-
-        inputs = list(traced_model.graph.inputs())[1:]

Review comment:
       I looked into this and didn't find a way to extract the inputs from the model after it has been saved and loaded. I asked on the PyTorch forum as well (https://discuss.pytorch.org/t/input-size-disappears-between-torch-jit-save-and-torch-jit-load/108955) and, since I received a grand total of zero responses, I suspect it is a deliberate design decision. If there were a way, it would be good to keep it, of course, but in its current form it no longer works.







[GitHub] [tvm] ekalda commented on pull request #7359: [TVMC] Fix PyTorch support

ekalda commented on pull request #7359:
URL: https://github.com/apache/tvm/pull/7359#issuecomment-769161279


   cc @leandron @u99127 





[GitHub] [tvm] comaniac commented on a change in pull request #7359: [TVMC] Fix PyTorch support

comaniac commented on a change in pull request #7359:
URL: https://github.com/apache/tvm/pull/7359#discussion_r566306351



##########
File path: python/tvm/driver/tvmc/common.py
##########
@@ -136,3 +138,40 @@ def tracker_host_port_from_cli(rpc_tracker_str):
         logger.info("RPC tracker port: %s", rpc_port)
 
     return rpc_hostname, rpc_port
+
+
+def parse_input_shapes(xs):
+    """Turn the string from --input-shape into a list.
+
+    Parameters
+    ----------
+    xs : str
+        The input shapes, in a form "(1,2,3),(1,4),..."
+
+    Returns
+    -------
+    shapes : list
+        Input shapes as a list of lists
+    """
+
+    shapes = []
+    # Split up string into comma seperated sections ignoring commas in ()s
+    match = re.findall(r"(\(.*?\)|.+?),?", xs)
+    if match:
+        for inp in match:
+            # Test for and remove brackets
+            shape = re.match(r"\((.*)\)", inp)
+            if shape and shape.lastindex == 1:
+                # Remove white space and extract numbers
+                strshape = shape[1].replace(" ", "").split(",")

Review comment:
       It would be safer and easier to remove all spaces in `xs` in the beginning of this function.
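A small sketch of why stripping whitespace up front is safer: running the PR's section-splitting pattern on a string with spaces yields whitespace-only sections, while the same pattern on the normalized string yields only the shapes. The sample input string is made up for illustration.

```python
import re

# The section-splitting pattern used in the PR's parse_input_shapes
SECTION_RE = r"(\(.*?\)|.+?),?"

raw = " (1, 3, 224, 224), (32, 4) "
print(re.findall(SECTION_RE, raw))
# [' ', '(1, 3, 224, 224)', ' ', '(32, 4)', ' ']

# str.split() with no argument splits on any whitespace run, so this removes it all
clean = "".join(raw.split())
print(re.findall(SECTION_RE, clean))
# ['(1,3,224,224)', '(32,4)']
```

With the string normalized once at the top, the later `.replace(" ", "")` on each shape becomes unnecessary.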

##########
File path: python/tvm/driver/tvmc/frontends.py
##########
@@ -99,10 +101,13 @@ def name():
     def suffixes():
         return ["h5"]
 
-    def load(self, path):
+    def load(self, path, input_shape):
         # pylint: disable=C0103
         tf, keras = import_keras()
 
+        if input_shape:
+            raise TVMCException("--input-shape is not supported for {}".format(self.name()))
+

Review comment:
       This is definitely too ad hoc.

##########
File path: python/tvm/driver/tvmc/common.py
##########
@@ -136,3 +138,40 @@ def tracker_host_port_from_cli(rpc_tracker_str):
         logger.info("RPC tracker port: %s", rpc_port)
 
     return rpc_hostname, rpc_port
+
+
+def parse_input_shapes(xs):
+    """Turn the string from --input-shape into a list.
+
+    Parameters
+    ----------
+    xs : str
+        The input shapes, in a form "(1,2,3),(1,4),..."
+
+    Returns
+    -------
+    shapes : list
+        Input shapes as a list of lists
+    """
+
+    shapes = []
+    # Split up string into comma seperated sections ignoring commas in ()s
+    match = re.findall(r"(\(.*?\)|.+?),?", xs)
+    if match:
+        for inp in match:
+            # Test for and remove brackets
+            shape = re.match(r"\((.*)\)", inp)
+            if shape and shape.lastindex == 1:
+                # Remove white space and extract numbers
+                strshape = shape[1].replace(" ", "").split(",")
+                try:
+                    shapes.append([int(i) for i in strshape])
+                except ValueError:
+                    raise argparse.ArgumentTypeError(f"expected numbers in shape '{shape[1]}'")

Review comment:
       Consider the following two input shapes:
   - `(8)`: parsed as the shape `[8]`.
   - `(8,)`: a `ValueError` (re-raised as `argparse.ArgumentTypeError`), because `strshape` would be `["8", ""]`.
   
   So I guess you intend `(8)` rather than `(8,)`. However, this is inconsistent with Python syntax, so it might confuse people. I have two proposals to deal with this:
   1. Use list syntax instead of tuple syntax, so that the semantics are clear, and we can simply use the JSON loader to handle all variants (e.g., spaces):
       ```python
       import json
       xs = "[1,3,224,224], [32]"
       shapes = json.loads("[" + xs + "]")  # wrap in [] to form one JSON array: [[1, 3, 224, 224], [32]]
       ```
   2. Follow Python syntax: accept only `(8,)` and throw an error for `(8)`, which Python evaluates to an integer rather than a tuple because the parentheses only group the expression. In this case, I would suggest using `eval` to handle all variants.
       ```python
       xs = "(1,3,224,224), (32,)"
       shapes = eval(xs, {"__builtins__": {}}, {})  # an empty globals dict is not enough: eval inserts __builtins__ into it
       # shapes == ((1, 3, 224, 224), (32,))
       ```
   
   Either way is fine with me; please update the help message and make sure you have a unit test covering the corner cases.
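If the tuple syntax of proposal 2 is chosen, Python's `ast.literal_eval` could stand in for `eval`: it only accepts literals, so no names or calls can run. The sketch below is an alternative to the `eval` suggestion, not code from the PR, and `parse_shapes_literal` is a hypothetical name. It also covers the corner cases discussed above: a lone `(1,3,224,224)` evaluates to a flat tuple and is wrapped, and `(8)` evaluates to the integer `8` and is rejected.

```python
import argparse
import ast


def parse_shapes_literal(xs):
    """Parse Python tuple syntax such as "(1,3,224,224), (32,)" without eval."""
    try:
        # literal_eval only accepts literals: no names, calls, or attribute access
        value = ast.literal_eval(xs)
    except (ValueError, SyntaxError, TypeError):
        raise argparse.ArgumentTypeError(f"cannot parse input shapes '{xs}'") from None
    # "(8)" is just the integer 8 in Python, so reject anything that is not a tuple
    if not isinstance(value, tuple):
        raise argparse.ArgumentTypeError(f"expected tuple syntax such as '(8,)', got '{xs}'")
    # A single shape like "(1,3,224,224)" is a flat tuple of ints; wrap it
    if all(isinstance(dim, int) for dim in value):
        value = (value,)
    shapes = []
    for shape in value:
        if not (isinstance(shape, tuple) and all(isinstance(d, int) for d in shape)):
            raise argparse.ArgumentTypeError(f"expected a tuple of integers, got {shape!r}")
        shapes.append(list(shape))
    return shapes
```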

##########
File path: python/tvm/driver/tvmc/frontends.py
##########
@@ -389,6 +403,8 @@ def load_model(path, model_format=None):
     model_format : str, optional
         The underlying framework used to create the model.
         If not specified, this will be inferred from the file type.
+    input shape : list, optional
+        The shape of input tensor for PyTorch models

Review comment:
       Ditto. Make it general instead of PyTorch-only.

##########
File path: python/tvm/driver/tvmc/compiler.py
##########
@@ -59,6 +59,12 @@ def add_compile_parser(subparsers):
         default="",
         help="comma separarated list of formats to export, e.g. 'asm,ll,relay' ",
     )
+    parser.add_argument(
+        "--input-shape",
+        type=common.parse_input_shapes,
+        metavar="INPUT_SHAPE,[INPUT_SHAPE]...",
+        help="for PyTorch, e.g. '(1,3,224,224)'",

Review comment:
       Agree. It's confusing to see such a general option only for PyTorch. I would suggest the following changes:
   1. Make `--input-shape` a general option for all frontends. If it is present, we skip the input shape inference.
   2. `--input-shape` is optional by default. However, if users want to process a PyTorch model but don't specify `--input-shape`, we throw an error in the PyTorch frontend.
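Proposal 1 (list syntax) and the general-option idea could be combined roughly as follows. This is a sketch assuming JSON list syntax for the flag; `parse_shape_list`, `add_input_shape_option`, and the `compile-sketch` program name are hypothetical and not part of tvmc. With `default=None`, a frontend can tell that the option was omitted and fall back to its own shape inference.

```python
import argparse
import json


def parse_shape_list(xs):
    """Parse "[1,3,224,224],[32]" into [[1, 3, 224, 224], [32]] (proposal 1)."""
    # Wrap in an outer list so several comma-separated shapes form one JSON array
    shapes = json.loads("[" + xs + "]")
    if not all(isinstance(s, list) and all(isinstance(d, int) for d in s) for s in shapes):
        raise argparse.ArgumentTypeError(f"expected lists of integers, got '{xs}'")
    return shapes


def add_input_shape_option(parser):
    # One optional flag shared by every frontend; None means "not given"
    parser.add_argument(
        "--input-shape",
        type=parse_shape_list,
        default=None,
        metavar="SHAPES",
        help="input shapes, e.g. '[1,3,224,224]'; mandatory for PyTorch models",
    )


parser = argparse.ArgumentParser(prog="compile-sketch")
parser.add_argument("FILE")
add_input_shape_option(parser)

args = parser.parse_args(["model.pth", "--input-shape", "[1,3,224,224],[32]"])
print(args.input_shape)  # [[1, 3, 224, 224], [32]]
print(parser.parse_args(["model.h5"]).input_shape)  # None
```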

##########
File path: python/tvm/driver/tvmc/frontends.py
##########
@@ -285,17 +299,17 @@ def suffixes():
         # Torch Script is a zip file, but can be named pth
         return ["pth", "zip"]
 
-    def load(self, path):
+    def load(self, path, input_shape):
         # pylint: disable=C0415
         import torch
 
-        traced_model = torch.jit.load(path)
-
-        inputs = list(traced_model.graph.inputs())[1:]

Review comment:
       Is this approach not working at all? If it works for some cases, we should still use it first when `--input-shape` is missing.
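One way to honour this while keeping `--input-shape` optional is a fixed resolution order: explicit shapes first, then whatever the loaded model can report, then an error. A minimal sketch; `try_infer_shapes` is a hypothetical callback standing in for the `traced_model.graph.inputs()` approach, and `TVMCException` is redefined locally so the snippet runs on its own.

```python
class TVMCException(Exception):
    """Local stand-in for tvmc's exception type (hypothetical in this sketch)."""


def resolve_torch_input_shapes(input_shape, try_infer_shapes):
    """Pick input shapes for a TorchScript model.

    input_shape: shapes from --input-shape, or None if the option was omitted.
    try_infer_shapes: hypothetical callable returning shapes read from the loaded
    model, or None when the saved graph carries no shape information.
    """
    # 1. An explicit --input-shape always wins; inference is skipped entirely
    if input_shape is not None:
        return input_shape
    # 2. Otherwise use whatever the model itself can report
    inferred = try_infer_shapes()
    if inferred is not None:
        return inferred
    # 3. Nothing usable: ask for the option instead of guessing
    raise TVMCException("could not infer input shapes; please pass --input-shape")
```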



