You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2020/11/12 17:42:52 UTC

[GitHub] [incubator-tvm] anijain2305 commented on a change in pull request #6905: [TRT][BYOC] handling dynamism in TensorRT to support OD models

anijain2305 commented on a change in pull request #6905:
URL: https://github.com/apache/incubator-tvm/pull/6905#discussion_r522282005



##########
File path: python/tvm/relay/op/contrib/tensorrt.py
##########
@@ -134,15 +135,18 @@ def partition_for_tensorrt(
 
     if params:
         mod["main"] = bind_params_by_name(mod["main"], params)
+
     seq = tvm.transform.Sequential(
         [
             transform.InferType(),
             RemoveDropoutPass(),
             transform.RemoveUnusedFunctions(),
             transform.ConvertLayout(
-                {"nn.conv2d": ["NCHW", "default"], "nn.conv3d": ["NCDHW", "default"]}
+                {"nn.conv2d": ["NCHW", "default"],
+                 "nn.conv3d": ["NCDHW", "default"]}
             ),
             transform.FoldConstant(),
+            transform.InferType(),

Review comment:
       Do we need this?

##########
File path: python/tvm/relay/op/contrib/tensorrt.py
##########
@@ -18,11 +18,12 @@
 """TensorRT supported operators."""
 import logging
 import numpy as np
+import os

Review comment:
       Do we need this?

##########
File path: python/tvm/relay/op/contrib/tensorrt.py
##########
@@ -152,13 +156,51 @@ def partition_for_tensorrt(
     with tvm.transform.PassContext(opt_level=3, config={"relay.ext.tensorrt.options": config}):
         mod = seq(mod)
         mod = prune_tensorrt_subgraphs(mod)
+
     return mod, config
 
+def check_dynamism(args, op_name):
+    """
+    This function checks for dynamism inside any of the args in the op.
+    Can be used to offload dynamic ops that are not supported by TRT to
+    be offloaded to relay VM.
+
+    Raises a NotImplementedError if the type of the arg is not of types
+    Call, Var, Constant, or TupleGetItem.
+
+    Parameters
+    ----------
+    args: a TRT array of the arguments of the op

Review comment:
       Please follow Python style for docstrings

##########
File path: python/tvm/relay/op/contrib/tensorrt.py
##########
@@ -715,6 +771,34 @@ def conv3d_transpose_annotate_fn(expr):  # pylint: disable=unused-variable
         return False
     return True
 
+_register_external_dynamic_check_func("add", add_annotate_fn)
+_register_external_dynamic_check_func("nn.batch_norm", batch_norm_annotate_fn)
+_register_external_dynamic_check_func("nn.softmax", softmax_annotate_fn)
+_register_external_dynamic_check_func("nn.conv2d", conv2d_annotate_fn)
+_register_external_dynamic_check_func("nn.dense", dense_annotate_fn)
+_register_external_dynamic_check_func("nn.bias_add", bias_add_annotate_fn)
+_register_external_dynamic_check_func("nn.max_pool2d", max_pool_2d_annotate_fn)
+_register_external_dynamic_check_func("nn.avg_pool2d", avg_pool_2d_annotate_fn)
+_register_external_dynamic_check_func("nn.global_max_pool2d", global_max_pool_2d_annotate_fn)
+_register_external_dynamic_check_func("nn.global_avg_pool2d", global_avg_pool_2d_annotate_fn)
+_register_external_dynamic_check_func("expand_dims", expand_dims_annotate_fn)
+_register_external_dynamic_check_func("squeeze", squeeze_annotate_fn)
+_register_external_dynamic_check_func("concatenate", concatenate_annotate_fn)
+_register_external_dynamic_check_func("nn.conv2d_transpose", conv2d_transpose_annotate_fn)
+_register_external_dynamic_check_func("transpose", transpose_annotate_fn)
+_register_external_dynamic_check_func("layout_transform", layout_transform_annotate_fn)
+_register_external_dynamic_check_func("reshape", reshape_annotate_fn)
+_register_external_dynamic_check_func("nn.pad", pad_annotate_fn)
+_register_external_dynamic_check_func("strided_slice", strided_slice_annotate_fn)
+_register_external_dynamic_check_func("nn.adaptive_max_pool2d", adaptive_max_pool2d_annotate_fn)
+_register_external_dynamic_check_func("nn.adaptive_avg_pool2d", adaptive_avg_pool2d_annotate_fn)
+_register_external_dynamic_check_func("nn.conv3d", conv3d_annotate_fn)
+_register_external_dynamic_check_func("nn.max_pool3d", max_pool_3d_annotate_fn)
+_register_external_dynamic_check_func("nn.avg_pool3d", avg_pool_3d_annotate_fn)
+_register_external_dynamic_check_func("nn.conv3d_transpose", conv3d_transpose_annotate_fn)
+
+
+

Review comment:
       Remove the extra blank lines everywhere.

##########
File path: python/tvm/relay/op/contrib/tensorrt.py
##########
@@ -173,6 +215,29 @@ def _register_external_op_helper(op_name, supported=True):
     )
 
 
+def _register_external_dynamic_check_func(op_name, checker):
+    """
+    Wrapper to check dynamic shapes inside any of the args in the op
+
+    Parameters
+    ----------
+    op_name: name of the op for debugging purposes only

Review comment:
       Same as above: please follow Python style for docstrings.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org