You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2021/01/30 00:55:34 UTC

[GitHub] [tvm] comaniac commented on a change in pull request #7366: [TVMC] Allow manual shape specification in tvmc

comaniac commented on a change in pull request #7366:
URL: https://github.com/apache/tvm/pull/7366#discussion_r567167046



##########
File path: python/tvm/driver/tvmc/common.py
##########
@@ -136,3 +138,46 @@ def tracker_host_port_from_cli(rpc_tracker_str):
         logger.info("RPC tracker port: %s", rpc_port)
 
     return rpc_hostname, rpc_port
+
+
+def parse_shape_string(inputs):
+    """Parse an input shape dictionary string to a usable dictionary.
+
+    Parameters
+    ----------
+    inputs: str
+        A string of the form "name:num1xnum2x...xnumN,name2:num1xnum2xnum3" that indicates
+        the desired shape for specific model inputs.
+
+    Returns
+    -------
+    shape_dict: dict
+        A dictionary mapping input names to their shape for use in relay frontend converters.
+    """
+    inputs = inputs.replace(" ", "")
+    # Check if the passed input is in the proper format.
+    valid_pattern = re.compile("(\w+:(\d+(x|X))*(\d)+)(,(\w+:(\d+(x|X))*(\d)+))*")
+    result = re.fullmatch(valid_pattern, inputs)
+    if result is None:
+        raise argparse.ArgumentTypeError(
+            "--shapes argument must be of the form 'input_name:dim1xdim2x...xdimN,input_name2:dim1xdim2"
+        )
+    d = {}
+    # Break apart each specific input string
+    inputs = inputs.split(",")

Review comment:
       Do not mix types. `inputs` is already a string, so do not rebind it to a list.

##########
File path: python/tvm/driver/tvmc/common.py
##########
@@ -136,3 +138,46 @@ def tracker_host_port_from_cli(rpc_tracker_str):
         logger.info("RPC tracker port: %s", rpc_port)
 
     return rpc_hostname, rpc_port
+
+
+def parse_shape_string(inputs):
+    """Parse an input shape dictionary string to a usable dictionary.
+
+    Parameters
+    ----------
+    inputs: str
+        A string of the form "name:num1xnum2x...xnumN,name2:num1xnum2xnum3" that indicates
+        the desired shape for specific model inputs.
+
+    Returns
+    -------
+    shape_dict: dict
+        A dictionary mapping input names to their shape for use in relay frontend converters.
+    """
+    inputs = inputs.replace(" ", "")
+    # Check if the passed input is in the proper format.
+    valid_pattern = re.compile("(\w+:(\d+(x|X))*(\d)+)(,(\w+:(\d+(x|X))*(\d)+))*")
+    result = re.fullmatch(valid_pattern, inputs)
+    if result is None:
+        raise argparse.ArgumentTypeError(
+            "--shapes argument must be of the form 'input_name:dim1xdim2x...xdimN,input_name2:dim1xdim2"
+        )
+    d = {}
+    # Break apart each specific input string
+    inputs = inputs.split(",")
+    for string in inputs:

Review comment:
       Please avoid bad naming. `string` is too general and looks like a reserved word (although it's not).

##########
File path: tests/python/driver/tvmc/test_common.py
##########
@@ -149,3 +149,27 @@ def test_tracker_host_port_from_cli__only_hostname__default_port_is_9090():
 
     assert expected_host == actual_host
     assert expected_port == actual_port
+
+
+def test_shape_parser():

Review comment:
       Cover negative shapes.

##########
File path: python/tvm/driver/tvmc/frontends.py
##########
@@ -389,6 +396,8 @@ def load_model(path, model_format=None):
     model_format : str, optional
         The underlying framework used to create the model.
         If not specified, this will be inferred from the file type.
+    shape_dict : dict, optional
+        A mapping between input names and their desired shape.

Review comment:
       ```suggestion
           Mapping from input names to their shapes.
   ```

##########
File path: python/tvm/driver/tvmc/common.py
##########
@@ -136,3 +138,46 @@ def tracker_host_port_from_cli(rpc_tracker_str):
         logger.info("RPC tracker port: %s", rpc_port)
 
     return rpc_hostname, rpc_port
+
+
+def parse_shape_string(inputs):
+    """Parse an input shape dictionary string to a usable dictionary.
+
+    Parameters
+    ----------
+    inputs: str
+        A string of the form "name:num1xnum2x...xnumN,name2:num1xnum2xnum3" that indicates
+        the desired shape for specific model inputs.
+
+    Returns
+    -------
+    shape_dict: dict
+        A dictionary mapping input names to their shape for use in relay frontend converters.
+    """
+    inputs = inputs.replace(" ", "")
+    # Check if the passed input is in the proper format.
+    valid_pattern = re.compile("(\w+:(\d+(x|X))*(\d)+)(,(\w+:(\d+(x|X))*(\d)+))*")
+    result = re.fullmatch(valid_pattern, inputs)
+    if result is None:
+        raise argparse.ArgumentTypeError(
+            "--shapes argument must be of the form 'input_name:dim1xdim2x...xdimN,input_name2:dim1xdim2"
+        )
+    d = {}
+    # Break apart each specific input string
+    inputs = inputs.split(",")
+    for string in inputs:
+        # Split name from shape string.
+        string = string.split(":")
+        shapelist = []
+        # Separate each dimension in the shape.
+        string[1] = string[1].lower().split("x")
+        # Parse each dimension into an integer.
+        for x in string[1]:
+            x = int(x)
+            # Negative numbers are converted to dynamic axes.
+            if x < 0:
+                x = relay.Any()
+            shapelist.append(x)

Review comment:
       ```suggestion
               shapelist.append(x if x >= 0 else relay.Any())
   ```

##########
File path: python/tvm/driver/tvmc/common.py
##########
@@ -136,3 +138,46 @@ def tracker_host_port_from_cli(rpc_tracker_str):
         logger.info("RPC tracker port: %s", rpc_port)
 
     return rpc_hostname, rpc_port
+
+
+def parse_shape_string(inputs):
+    """Parse an input shape dictionary string to a usable dictionary.
+
+    Parameters
+    ----------
+    inputs: str
+        A string of the form "name:num1xnum2x...xnumN,name2:num1xnum2xnum3" that indicates
+        the desired shape for specific model inputs.
+
+    Returns
+    -------
+    shape_dict: dict
+        A dictionary mapping input names to their shape for use in relay frontend converters.
+    """
+    inputs = inputs.replace(" ", "")
+    # Check if the passed input is in the proper format.
+    valid_pattern = re.compile("(\w+:(\d+(x|X))*(\d)+)(,(\w+:(\d+(x|X))*(\d)+))*")
+    result = re.fullmatch(valid_pattern, inputs)

Review comment:
       - This doesn't seem to match negative numbers such as `data:-1x3x224x224`.
   - Same comment as the one I left in another PR: This syntax is not straightforward to many users. I would suggest using either JSON or Python syntax.

##########
File path: python/tvm/driver/tvmc/common.py
##########
@@ -136,3 +138,46 @@ def tracker_host_port_from_cli(rpc_tracker_str):
         logger.info("RPC tracker port: %s", rpc_port)
 
     return rpc_hostname, rpc_port
+
+
+def parse_shape_string(inputs):
+    """Parse an input shape dictionary string to a usable dictionary.
+
+    Parameters
+    ----------
+    inputs: str
+        A string of the form "name:num1xnum2x...xnumN,name2:num1xnum2xnum3" that indicates
+        the desired shape for specific model inputs.
+
+    Returns
+    -------
+    shape_dict: dict
+        A dictionary mapping input names to their shape for use in relay frontend converters.
+    """
+    inputs = inputs.replace(" ", "")
+    # Check if the passed input is in the proper format.
+    valid_pattern = re.compile("(\w+:(\d+(x|X))*(\d)+)(,(\w+:(\d+(x|X))*(\d)+))*")
+    result = re.fullmatch(valid_pattern, inputs)
+    if result is None:
+        raise argparse.ArgumentTypeError(
+            "--shapes argument must be of the form 'input_name:dim1xdim2x...xdimN,input_name2:dim1xdim2"
+        )
+    d = {}
+    # Break apart each specific input string
+    inputs = inputs.split(",")
+    for string in inputs:
+        # Split name from shape string.
+        string = string.split(":")
+        shapelist = []
+        # Separate each dimension in the shape.
+        string[1] = string[1].lower().split("x")

Review comment:
       Ditto. Do not mix types.

##########
File path: python/tvm/driver/tvmc/frontends.py
##########
@@ -54,13 +54,15 @@ def suffixes():
         """File suffixes (extensions) used by this frontend"""
 
     @abstractmethod
-    def load(self, path):
+    def load(self, path, shape_dict=None):
         """Load a model from a given path.
 
         Parameters
         ----------
         path: str
             Path to a file
+        shape_dict: dict, optional
+            A dictionary mapping input names to shapes.

Review comment:
       ```suggestion
               Mapping from input names to their shapes.
   ```

##########
File path: python/tvm/driver/tvmc/common.py
##########
@@ -136,3 +138,46 @@ def tracker_host_port_from_cli(rpc_tracker_str):
         logger.info("RPC tracker port: %s", rpc_port)
 
     return rpc_hostname, rpc_port
+
+
+def parse_shape_string(inputs):
+    """Parse an input shape dictionary string to a usable dictionary.
+
+    Parameters
+    ----------
+    inputs: str
+        A string of the form "name:num1xnum2x...xnumN,name2:num1xnum2xnum3" that indicates
+        the desired shape for specific model inputs.
+
+    Returns
+    -------
+    shape_dict: dict
+        A dictionary mapping input names to their shape for use in relay frontend converters.
+    """
+    inputs = inputs.replace(" ", "")
+    # Check if the passed input is in the proper format.
+    valid_pattern = re.compile("(\w+:(\d+(x|X))*(\d)+)(,(\w+:(\d+(x|X))*(\d)+))*")
+    result = re.fullmatch(valid_pattern, inputs)
+    if result is None:
+        raise argparse.ArgumentTypeError(
+            "--shapes argument must be of the form 'input_name:dim1xdim2x...xdimN,input_name2:dim1xdim2"
+        )
+    d = {}

Review comment:
       Please avoid bad naming.

##########
File path: python/tvm/driver/tvmc/frontends.py
##########
@@ -285,17 +291,18 @@ def suffixes():
         # Torch Script is a zip file, but can be named pth
         return ["pth", "zip"]
 
-    def load(self, path):
+    def load(self, path, shape_dict=None):
         # pylint: disable=C0415
         import torch
 
         traced_model = torch.jit.load(path)
+        traced_model.eval()  # Switch to inference mode
 
-        inputs = list(traced_model.graph.inputs())[1:]
-        input_shapes = [inp.type().sizes() for inp in inputs]
+        if shape_dict is None:
+            raise TVMCException("--shapes must be specified for {}".format(self.name()))

Review comment:
       ```suggestion
               raise TVMCException("--shapes must be specified for %s" % self.name())
   ```

##########
File path: python/tvm/driver/tvmc/compiler.py
##########
@@ -158,6 +167,9 @@ def compile_model(
         The layout to convert the graph to. Note, the convert layout
         pass doesn't currently guarantee the whole of the graph will
         be converted to the chosen layout.
+    shape_dict: dict, optional
+        A mapping between input names and their shape. This is useful
+        to override the default values in a model if needed.

Review comment:
       ```suggestion
           A mapping from input names to their shapes. When present,
           the default shapes in the model will be overwritten.
   ```




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org