Posted to commits@tvm.apache.org by ma...@apache.org on 2022/04/07 12:00:08 UTC
[tvm] branch main updated: [TRT] Minor fixes on TRT python interface (#10917)
This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 6b4d351e9b [TRT] Minor fixes on TRT python interface (#10917)
6b4d351e9b is described below
commit 6b4d351e9b3b2a4f814de35415071faf3d19715e
Author: Michalis Papadimitriou <mi...@users.noreply.github.com>
AuthorDate: Thu Apr 7 15:00:03 2022 +0300
[TRT] Minor fixes on TRT python interface (#10917)
Co-authored-by: Michalis Papapdimitriou <mp...@octoml.ai>
---
python/tvm/relay/op/contrib/tensorrt.py | 57 +++++++++++++++++++++++++--------
1 file changed, 43 insertions(+), 14 deletions(-)
diff --git a/python/tvm/relay/op/contrib/tensorrt.py b/python/tvm/relay/op/contrib/tensorrt.py
index f24d366e59..160939369c 100644
--- a/python/tvm/relay/op/contrib/tensorrt.py
+++ b/python/tvm/relay/op/contrib/tensorrt.py
@@ -25,10 +25,9 @@ from tvm.ir import Op
 from tvm.relay import transform
 from tvm.relay.build_module import bind_params_by_name
 from tvm.relay.dataflow_pattern import is_op, wildcard
-from tvm.relay.expr import Call, Constant, GlobalVar, Tuple, TupleGetItem, Var
+from tvm.relay.expr import Call, Constant, GlobalVar, Tuple
 from tvm.relay.expr_functor import ExprMutator, ExprVisitor
 from tvm.relay.op.contrib.register import register_pattern_table
-from tvm.relay.op.transform import split
 logger = logging.getLogger("TensorRT")
 supported_types = ["float32", "float16"]
@@ -236,6 +235,36 @@ def get_pass_order(use_patterns):
)
+def check_type_dynamism(type, op_name):  # pylint: disable=redefined-builtin
+    r"""
+    Check for dynamic TensorType in an input op
+
+    Parameters
+    ----------
+    type: the checked_type of the op
+    op_name: str
+        Name of the op for debugging purposes.
+    Returns
+    -------
+    ret: bool
+        True if the arg's dynamic type is not supported in TRT, False otherwise
+    """
+
+    if isinstance(type, tvm.ir.TensorType):
+        # assumes dim 0 is the batch dimension and may be dynamic
+        for dim_shape in type.shape[1:]:
+            if isinstance(dim_shape, tvm.tir.expr.Any):
+                return True
+    elif isinstance(type, tvm.ir.TupleType):
+        for field_type in type.fields:
+            if check_type_dynamism(field_type, op_name):
+                return True
+    else:
+        logger.info("Arg not supported in TensorRT for %s with type %s", op_name, type)
+        return True
+    return False
+
+
 def check_dynamism(args, op_name):
     """
     Check for dynamism inside any of the args in the op.
@@ -253,14 +282,7 @@ def check_dynamism(args, op_name):
         True if dynamism is present, False otherwise
     """
     for arg in args:
-        if isinstance(arg, (Call, Var, Constant, TupleGetItem)):
-            for dim_shape in arg.checked_type.shape[1:]:
-                if isinstance(dim_shape, tvm.tir.expr.Any):
-                    return True
-        elif isinstance(arg, Tuple):
-            return check_dynamism(arg.fields, op_name)
-        else:
-            logger.info("Arg not supported in TensorRT for %s with type %s", op_name, type(arg))
+        if check_type_dynamism(arg.checked_type, op_name):
             return True
     return False
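A small usage sketch of the refactored check, assuming a TVM build with this patch applied; it exercises the rule that only dims after the batch dim count as unsupported dynamism:

    import tvm
    from tvm import relay
    from tvm.relay.op.contrib.tensorrt import check_type_dynamism

    def infer(func):
        # Run InferType so checked_type is populated on the body
        mod = tvm.IRModule.from_expr(func)
        return relay.transform.InferType()(mod)["main"].body

    # Dynamic batch dim only: allowed, the check returns False.
    x = relay.var("x", shape=(relay.Any(), 3, 224, 224), dtype="float32")
    body = infer(relay.Function([x], relay.nn.relu(x)))
    print(check_type_dynamism(body.checked_type, "nn.relu"))  # False

    # Dynamic non-batch dim: unsupported, the check returns True.
    y = relay.var("y", shape=(1, relay.Any(), 224), dtype="float32")
    body = infer(relay.Function([y], relay.nn.relu(y)))
    print(check_type_dynamism(body.checked_type, "nn.relu"))  # True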
@@ -355,6 +377,7 @@ _register_external_op_helper_with_checker("prod", reduce_annotate_fn)
_register_external_op_helper_with_checker("max", reduce_annotate_fn)
_register_external_op_helper_with_checker("min", reduce_annotate_fn)
_register_external_op_helper_with_checker("mean", reduce_annotate_fn)
+_register_external_op_helper_with_checker("variance", reduce_annotate_fn)
def trt_version_annotate_fn(version):
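Registering variance through reduce_annotate_fn means it inherits the same offload constraints as the other reductions above (for example, reductions that touch the batch dim are rejected). A minimal sketch of a graph this change lets the partitioner consider; variable names are illustrative only:

    from tvm import relay

    x = relay.var("x", shape=(1, 3, 224, 224), dtype="float32")
    # Reduces only over the spatial axes, so the shared reduce checker
    # can accept it; axis=[0] would touch the batch dim and be rejected.
    v = relay.variance(x, axis=[2, 3], keepdims=True)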
@@ -464,6 +487,9 @@ def conv2d_annotate_fn(expr): # pylint: disable=unused-variable
     attrs, args = expr.attrs, expr.args
     if not is_supported_trt_dtype(args):
         return False
+    if not isinstance(args[1], Constant):
+        logger.info("nn.conv2d: kernel argument must be constant.")
+        return False
     if attrs.data_layout != "NCHW":
         logger.info("nn.conv2d: data_layout is %s but must be NCHW.", attrs.data_layout)
         return False
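With the new guard, nn.conv2d is only considered for TRT when its kernel is already a relay.Constant; a kernel that is still a free Var is left to TVM. In practice partition_for_tensorrt binds params via bind_params_by_name first, which is what folds weights into Constants. A hedged illustration (names are hypothetical):

    import numpy as np
    from tvm import relay

    x = relay.var("x", shape=(1, 3, 32, 32), dtype="float32")

    # Kernel as a free Var: args[1] is not a Constant, so the checker
    # now returns False and the op stays on the TVM side.
    w_var = relay.var("w", shape=(16, 3, 3, 3), dtype="float32")
    conv_rejected = relay.nn.conv2d(x, w_var, kernel_size=(3, 3), channels=16)

    # Kernel folded to a Constant: the remaining layout/dtype checks
    # apply exactly as before.
    w_const = relay.const(np.zeros((16, 3, 3, 3), dtype="float32"))
    conv_accepted = relay.nn.conv2d(x, w_const, kernel_size=(3, 3), channels=16)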
@@ -483,6 +509,9 @@ def dense_annotate_fn(expr): # pylint: disable=unused-variable
     args = expr.args
     if not is_supported_trt_dtype(args):
         return False
+    if not isinstance(args[1], Constant):
+        logger.info("nn.dense: weight must be constant.")
+        return False
     input_rank = len(args[0].checked_type.shape)
     weight_rank = len(args[1].checked_type.shape)
     if input_rank not in (2, 3, 4):
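The same constant-weight requirement now applies to nn.dense before its rank checks run. A short sketch under the same assumptions:

    from tvm import relay

    d = relay.var("d", shape=(1, 128), dtype="float32")
    w = relay.var("w", shape=(64, 128), dtype="float32")
    # Weight is a Var, so dense is rejected before the rank checks;
    # binding params (bind_params_by_name) turns it into a Constant.
    out = relay.nn.dense(d, w)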
@@ -790,7 +819,9 @@ def pad_annotate_fn(expr): # pylint: disable=unused-variable
     if not is_supported_trt_dtype(args):
         return False
     pad_value = args[1]
-    assert isinstance(pad_value, relay.Constant)
+    if not isinstance(pad_value, relay.Constant):
+        logger.info("nn.pad: pad argument must be constant.")
+        return False
     pad_value = pad_value.data.numpy().item()
     if attrs.pad_mode != "constant":
         logger.info("nn.pad: pad mode is %s but must be constant.", attrs.pad_mode)
@@ -999,8 +1030,6 @@ def pattern_table():
         ),
         ("tensorrt.divide", binary_op_pattern("divide")),
         ("tensorrt.multiply", binary_op_pattern("multiply")),
-        ("tensorrt.split", unary_op_pattern("split")),
-        ("tensorrt.reshape", unary_op_pattern("reshape")),
         ("tensorrt.nn.relu", unary_op_pattern("nn.relu")),
         (
             "tensorrt.nn.leaky_relu",
@@ -1038,7 +1067,7 @@ def pattern_table():
         ),
         ("tensorrt.transpose", unary_op_pattern("transpose"), transpose_annotate_fn),
         ("tensorrt.reshape", unary_op_pattern("reshape"), reshape_annotate_fn),
-        ("tensorrt.split", unary_op_pattern("split"), split),
+        ("tensorrt.split", unary_op_pattern("split"), split_annotate_fn),
         ("tensorrt.nn.pad", unary_op_pattern("nn.pad"), pad_annotate_fn),
         ("tensorrt.strided_slice", unary_op_pattern("strided_slice"), strided_slice_annotate_fn),
         (
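For reference, each pattern-table entry is a (rule name, pattern, optional predicate) triple; the old entry mistakenly passed the relay split op (imported at the top of the file, now removed) where a checker function belongs. A self-contained sketch of the entry shape, with a stub predicate standing in for the real split_annotate_fn defined in tensorrt.py:

    from tvm.relay.dataflow_pattern import is_op, wildcard

    def unary_op_pattern(op):
        # mirrors the helper in tensorrt.py: match `op` applied to any input
        return is_op(op)(wildcard())

    def split_annotate_fn(expr):
        # stand-in predicate; the real checker rejects unsupported configs
        return True

    # (rule name, pattern, predicate run on each matched call)
    entry = ("tensorrt.split", unary_op_pattern("split"), split_annotate_fn)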