You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ma...@apache.org on 2022/04/07 12:00:08 UTC

[tvm] branch main updated: [TRT] Minor fixes on TRT python interface (#10917)

This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 6b4d351e9b [TRT] Minor fixes on TRT python interface (#10917)
6b4d351e9b is described below

commit 6b4d351e9b3b2a4f814de35415071faf3d19715e
Author: Michalis Papadimitriou <mi...@users.noreply.github.com>
AuthorDate: Thu Apr 7 15:00:03 2022 +0300

    [TRT] Minor fixes on TRT python interface (#10917)
    
    Co-authored-by: Michalis Papapdimitriou <mp...@octoml.ai>
---
 python/tvm/relay/op/contrib/tensorrt.py | 57 +++++++++++++++++++++++++--------
 1 file changed, 43 insertions(+), 14 deletions(-)

diff --git a/python/tvm/relay/op/contrib/tensorrt.py b/python/tvm/relay/op/contrib/tensorrt.py
index f24d366e59..160939369c 100644
--- a/python/tvm/relay/op/contrib/tensorrt.py
+++ b/python/tvm/relay/op/contrib/tensorrt.py
@@ -25,10 +25,9 @@ from tvm.ir import Op
 from tvm.relay import transform
 from tvm.relay.build_module import bind_params_by_name
 from tvm.relay.dataflow_pattern import is_op, wildcard
-from tvm.relay.expr import Call, Constant, GlobalVar, Tuple, TupleGetItem, Var
+from tvm.relay.expr import Call, Constant, GlobalVar, Tuple
 from tvm.relay.expr_functor import ExprMutator, ExprVisitor
 from tvm.relay.op.contrib.register import register_pattern_table
-from tvm.relay.op.transform import split
 
 logger = logging.getLogger("TensorRT")
 supported_types = ["float32", "float16"]
@@ -236,6 +235,36 @@ def get_pass_order(use_patterns):
     )
 
 
+def check_type_dynamism(type, op_name):  # pylint: disable=redefined-builtin
+    r"""
+    Check for dynamic TensorType for an input op
+
+    Parameters
+    ----------
+    type: checked_type of the op
+    op_name: str
+        Name of the op for debugging purposes.
+    Returns
+    -------
+    ret: bool
+        True if the arg has a dynamic type not supported in TRT, False otherwise
+    """
+
+    if isinstance(type, tvm.ir.TensorType):
+        # assumes dim 0 is for batch and can be dynamic
+        for dim_shape in type.shape[1:]:
+            if isinstance(dim_shape, tvm.tir.expr.Any):
+                return True
+    elif isinstance(type, tvm.ir.TupleType):
+        for field_type in type.fields:
+            if check_type_dynamism(field_type, op_name):
+                return True
+    else:
+        logger.info("Arg not supported in TensorRT for %s with type %s", op_name, type)
+        return True
+    return False
+
+
 def check_dynamism(args, op_name):
     """
     Check for dynamism inside any of the args in the op.
@@ -253,14 +282,7 @@ def check_dynamism(args, op_name):
         True if dynamism is present, False otherwise
     """
     for arg in args:
-        if isinstance(arg, (Call, Var, Constant, TupleGetItem)):
-            for dim_shape in arg.checked_type.shape[1:]:
-                if isinstance(dim_shape, tvm.tir.expr.Any):
-                    return True
-        elif isinstance(arg, Tuple):
-            return check_dynamism(arg.fields, op_name)
-        else:
-            logger.info("Arg not supported in TensorRT for %s with type %s", op_name, type(arg))
+        if check_type_dynamism(arg.checked_type, op_name):
             return True
     return False
 
@@ -355,6 +377,7 @@ _register_external_op_helper_with_checker("prod", reduce_annotate_fn)
 _register_external_op_helper_with_checker("max", reduce_annotate_fn)
 _register_external_op_helper_with_checker("min", reduce_annotate_fn)
 _register_external_op_helper_with_checker("mean", reduce_annotate_fn)
+_register_external_op_helper_with_checker("variance", reduce_annotate_fn)
 
 
 def trt_version_annotate_fn(version):
@@ -464,6 +487,9 @@ def conv2d_annotate_fn(expr):  # pylint: disable=unused-variable
     attrs, args = expr.attrs, expr.args
     if not is_supported_trt_dtype(args):
         return False
+    if not isinstance(args[1], Constant):
+        logger.info("nn.conv2d: kernel argument must be constant.")
+        return False
     if attrs.data_layout != "NCHW":
         logger.info("nn.conv2d: data_layout is %s but must be NCHW.", attrs.data_layout)
         return False
@@ -483,6 +509,9 @@ def dense_annotate_fn(expr):  # pylint: disable=unused-variable
     args = expr.args
     if not is_supported_trt_dtype(args):
         return False
+    if not isinstance(args[1], Constant):
+        logger.info("nn.dense: weight must be constant")
+        return False
     input_rank = len(args[0].checked_type.shape)
     weight_rank = len(args[1].checked_type.shape)
     if input_rank not in (2, 3, 4):
@@ -790,7 +819,9 @@ def pad_annotate_fn(expr):  # pylint: disable=unused-variable
     if not is_supported_trt_dtype(args):
         return False
     pad_value = args[1]
-    assert isinstance(pad_value, relay.Constant)
+    if not isinstance(pad_value, relay.Constant):
+        logger.info("nn.pad: pad argument must be constant")
+        return False
     pad_value = pad_value.data.numpy().item()
     if attrs.pad_mode != "constant":
         logger.info("nn.pad: pad mode is %s but must be constant.", attrs.pad_mode)
@@ -999,8 +1030,6 @@ def pattern_table():
         ),
         ("tensorrt.divide", binary_op_pattern("divide")),
         ("tensorrt.multiply", binary_op_pattern("multiply")),
-        ("tensorrt.split", unary_op_pattern("split")),
-        ("tensorrt.reshape", unary_op_pattern("reshape")),
         ("tensorrt.nn.relu", unary_op_pattern("nn.relu")),
         (
             "tensorrt.nn.leaky_relu",
@@ -1038,7 +1067,7 @@ def pattern_table():
         ),
         ("tensorrt.transpose", unary_op_pattern("transpose"), transpose_annotate_fn),
         ("tensorrt.reshape", unary_op_pattern("reshape"), reshape_annotate_fn),
-        ("tensorrt.split", unary_op_pattern("split"), split),
+        ("tensorrt.split", unary_op_pattern("split"), split_annotate_fn),
         ("tensorrt.nn.pad", unary_op_pattern("nn.pad"), pad_annotate_fn),
         ("tensorrt.strided_slice", unary_op_pattern("strided_slice"), strided_slice_annotate_fn),
         (