You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by jw...@apache.org on 2022/05/10 19:39:31 UTC

[tvm] branch main updated: [TENSORRT] Improvements and fixes for TensorRT (#11203)

This is an automated email from the ASF dual-hosted git repository.

jwfromm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new be2ae9433e [TENSORRT] Improvements and fixes for TensorRT (#11203)
be2ae9433e is described below

commit be2ae9433ebe19738529bb9008251b883a3f5e89
Author: Matthew Barrett <55...@users.noreply.github.com>
AuthorDate: Tue May 10 21:39:25 2022 +0200

    [TENSORRT] Improvements and fixes for TensorRT (#11203)
    
    A number of small fixes and refactors to improve the robustness of
    the TensorRT integration.
    
    Co-authored-by: Mark Shields <mb...@octoml.ai>
    
    Co-authored-by: Mark Shields <mb...@octoml.ai>
---
 python/tvm/relay/op/contrib/tensorrt.py            | 1123 +++++++++-----------
 src/relay/backend/contrib/tensorrt/codegen.cc      |  178 +++-
 src/relay/transforms/inline_composites.cc          |   94 --
 src/runtime/contrib/json/json_node.h               |    6 +
 src/runtime/contrib/tensorrt/tensorrt_builder.cc   |   53 +-
 src/runtime/contrib/tensorrt/tensorrt_calibrator.h |    2 +-
 src/runtime/contrib/tensorrt/tensorrt_ops.cc       |  304 ++++--
 src/runtime/contrib/tensorrt/tensorrt_ops.h        |   35 +-
 src/runtime/contrib/tensorrt/tensorrt_runtime.cc   |    8 +-
 tests/python/contrib/test_tensorrt.py              |   37 +-
 tests/python/relay/test_pass_inline_composites.py  |  165 ---
 tests/scripts/task_mypy.sh                         |    1 +
 12 files changed, 901 insertions(+), 1105 deletions(-)

diff --git a/python/tvm/relay/op/contrib/tensorrt.py b/python/tvm/relay/op/contrib/tensorrt.py
index 3f867d97b7..58dac06382 100644
--- a/python/tvm/relay/op/contrib/tensorrt.py
+++ b/python/tvm/relay/op/contrib/tensorrt.py
@@ -14,39 +14,26 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-# pylint: disable=invalid-name, unused-argument
+# pylint: disable=invalid-name, unused-argument, logging-format-interpolation
 """TensorRT supported operators."""
 import logging
+from typing import Tuple, List, Dict, Union, Optional, Any, Callable
 
-import numpy as np
+import numpy as np  # type: ignore
 import tvm
 from tvm import relay
 from tvm.ir import Op
 from tvm.relay import transform
 from tvm.relay.build_module import bind_params_by_name
-from tvm.relay.dataflow_pattern import is_op, wildcard, is_constant, is_tuple
-from tvm.relay.expr import Call, Constant, GlobalVar, Tuple
+from tvm.relay.dataflow_pattern import is_op, wildcard, is_constant, is_tuple, is_tuple_get_item
+from tvm.relay.expr import Call, Constant, GlobalVar, TupleGetItem
 from tvm.relay.expr_functor import ExprMutator, ExprVisitor
 from tvm.relay.op.contrib.register import register_pattern_table
 
 logger = logging.getLogger("TensorRT")
-supported_types = ["float32", "float16"]
 
 
-def is_supported_trt_dtype(args):
-    """Check if the TensorRT BYOC support input tensor dtype.
-    Returns
-    -------
-    ret: bool
-        True if supported, False if not.
-    """
-    if not all([x.checked_type.dtype in supported_types for x in args]):
-        logger.info("Only float32 and float16 inputs are supported for TensorRT BYOC.")
-        return False
-    return True
-
-
-def is_tensorrt_runtime_enabled():
+def is_tensorrt_runtime_enabled() -> bool:
     """Check if the TensorRT graph executor is present.
     Returns
     -------
@@ -59,7 +46,7 @@ def is_tensorrt_runtime_enabled():
     return False
 
 
-def get_tensorrt_version():
+def get_tensorrt_version() -> Tuple[int, int, int]:
     """Gets the version of TensorRT that TVM is built against or is targeting.
 
     Returns
@@ -70,11 +57,11 @@ def get_tensorrt_version():
     """
     pass_ctx = tvm.transform.PassContext.current()
     if "relay.ext.tensorrt.options" in pass_ctx.config:
-        return tuple(pass_ctx.config["relay.ext.tensorrt.options"].tensorrt_version)
-    return tuple(tvm.get_global_func("relay.op.get_tensorrt_version")())
+        return tuple(pass_ctx.config["relay.ext.tensorrt.options"].tensorrt_version)  # type: ignore
+    return tuple(tvm.get_global_func("relay.op.get_tensorrt_version")())  # type: ignore
 
 
-def get_tensorrt_use_implicit_batch_mode():
+def get_tensorrt_use_implicit_batch_mode() -> bool:
     pass_ctx = tvm.transform.PassContext.current()
     if "relay.ext.tensorrt.options" in pass_ctx.config:
         return pass_ctx.config["relay.ext.tensorrt.options"].use_implicit_batch
@@ -85,7 +72,7 @@ def get_tensorrt_use_implicit_batch_mode():
     return True
 
 
-def get_tensorrt_remove_no_mac_subgraphs():
+def get_tensorrt_remove_no_mac_subgraphs() -> bool:
     pass_ctx = tvm.transform.PassContext.current()
     if "relay.ext.tensorrt.options" in pass_ctx.config:
         return pass_ctx.config["relay.ext.tensorrt.options"].remove_no_mac_subgraphs
@@ -97,55 +84,53 @@ def get_tensorrt_remove_no_mac_subgraphs():
 
 
 def partition_for_tensorrt(
-    mod,
-    params=None,
-    version=None,
-    use_implicit_batch=True,
-    remove_no_mac_subgraphs=False,
-    max_workspace_size=1 << 30,
-    use_fp16=False,
-    use_uint8=False,
-    use_patterns=False,
-):
+    mod: tvm.IRModule,
+    params: Optional[Dict[str, tvm.nd.NDArray]] = None,
+    version: Optional[Tuple[int, int, int]] = None,
+    use_implicit_batch: bool = True,
+    remove_no_mac_subgraphs: bool = False,
+    max_workspace_size: int = 1 << 30,
+    use_fp16: bool = False,
+    use_uint8: bool = False,
+) -> Tuple[tvm.IRModule, Dict[str, Any]]:
     """Partition the graph greedily offloading supported operators to TensorRT.
 
     Parameters
     ----------
-    mod : Module
+    mod : tvm.IRModule
         The module to run passes on.
-    params : Optional[Dict[str, NDArray]]
+    params : Optional[Dict[str, tvm.nd.NDArray]]
         Constant input parameters.
     version : Optional[Tuple[int, int, int]]
         TensorRT version to target as tuple of (major, minor, patch). If TVM is compiled with
         USE_TENSORRT_RUNTIME=ON, the linked TensorRT version will be used instead.
-    use_implicit_batch : Optional[bool]
+    use_implicit_batch : bool
         Use TensorRT implicit batch mode (default true). Setting to false will enable explicit batch
         mode which will widen supported operators to include those which modify the batch dimension,
         but may reduce performance for some models.
-    remove_no_mac_subgraphs : Optional[bool]
+    remove_no_mac_subgraphs : bool
         Removes subgraphs which have been partitioned for TensorRT if they do not have any
         multiply-accumulate operations. The removed subgraphs will go through TVM's standard
         compilation instead. Can improve performance.
-    max_workspace_size : Optional[int]
+    max_workspace_size : int
         How many bytes of workspace size to allow each subgraph to use for TensorRT engine creation.
         See TensorRT documentation for more info.
-    use_fp16: Optional[bool]
+    use_fp16: bool
         Allows, TRT to automatically convert FP32 inputs to FP16. Also, it is required to be enabled
         if FP16 inputs tensors and weights are used.
         Note that TensorRT will still choose a higher-precision kernel if it results in overall
         lower runtime, or if no low-precision implementation exists.
-    use_uint8: Optional[bool]
+    use_uint8: bool
         Allows, TRT to automatically convert FP32 inputs to UINT8.
-    use_patterns: Optional[bool]
-        Switches to use pattern-based op suppot by applying MergeCompsite and InlineComposites
-        passes.
+
     Returns
     -------
-    mod_and_config : Tuple[Module, Dict[str, Any]]
+    mod_and_config : Tuple[tvm.IRModule, Dict[str, Any]]
         A tuple of 1) annotated and partitioned module and 2) "relay.ext.tensorrt.options"
         configuration which should be given to PassContext when building.
+
     """
-    config = {
+    config: Dict[str, Any] = {
         "use_implicit_batch": use_implicit_batch,
         "max_workspace_size": max_workspace_size,
         "remove_no_mac_subgraphs": remove_no_mac_subgraphs,
@@ -168,247 +153,163 @@ def partition_for_tensorrt(
     if params:
         mod["main"] = bind_params_by_name(mod["main"], params)
 
-    seq = get_pass_order(use_patterns)
+    seq = tvm.transform.Sequential(
+        [
+            transform.InferType(),
+            RemoveDropoutPass(),
+            transform.RemoveUnusedFunctions(),
+            transform.ConvertLayout(
+                {
+                    "nn.conv1d": ["NCW", "default"],
+                    "nn.conv2d": ["NCHW", "default"],
+                    "nn.conv3d": ["NCDHW", "default"],
+                    "nn.conv2d_transpose": ["NCHW", "default"],
+                }
+            ),
+            transform.FoldConstant(),
+            transform.MergeComposite(pattern_table()),
+            transform.AnnotateTarget("tensorrt"),
+            transform.MergeCompilerRegions(),
+            transform.PartitionGraph(),
+            transform.InferType(),
+        ]
+    )
     with tvm.transform.PassContext(opt_level=3, config={"relay.ext.tensorrt.options": config}):
         mod = seq(mod)
-        mod = prune_tensorrt_subgraphs(mod)
+        # TODO(mbs): Revisit
+        # mod = prune_tensorrt_subgraphs(mod)
     return mod, config
 
 
-def get_pass_order(use_patterns):
-    """
-    Get the pass ordering based on using predicates or patterns.
-
-    Parameters
-    ----------
-    use_patterns: Bool
-        True if pass needs to work with op patterns
-    Returns
-    ----------
-    ret : Sequential
-        Pass object
-    """
-    return (
-        tvm.transform.Sequential(
-            [
-                transform.InferType(),
-                RemoveDropoutPass(),
-                transform.RemoveUnusedFunctions(),
-                transform.ConvertLayout(
-                    {
-                        "nn.conv1d": ["NCW", "default"],
-                        "nn.conv2d": ["NCHW", "default"],
-                        "nn.conv3d": ["NCDHW", "default"],
-                        "nn.conv2d_transpose": ["NCHW", "default"],
-                    }
-                ),
-                transform.FoldConstant(),
-                transform.MergeComposite(pattern_table()),
-                transform.AnnotateTarget("tensorrt"),
-                transform.MergeCompilerRegions(),
-                transform.PartitionGraph(),
-                transform.InlineComposites("tensorrt"),
-                transform.InferType(),
-            ]
-        )
-        if use_patterns
-        else tvm.transform.Sequential(
-            [
-                transform.InferType(),
-                RemoveDropoutPass(),
-                transform.RemoveUnusedFunctions(),
-                transform.ConvertLayout(
-                    {
-                        "nn.conv1d": ["NCW", "default"],
-                        "nn.conv2d": ["NCHW", "default"],
-                        "nn.conv3d": ["NCDHW", "default"],
-                        "nn.conv2d_transpose": ["NCHW", "default"],
-                    }
-                ),
-                transform.FoldConstant(),
-                transform.AnnotateTarget("tensorrt"),
-                transform.MergeCompilerRegions(),
-                transform.PartitionGraph(),
-                transform.InferType(),
-            ]
-        )
-    )
-
-
-def check_type_dynamism(type, op_name):  # pylint: disable=redefined-builtin
-    r"""
-    Check for dynamic TensorType for an input op
-
-    Parameters
-    ----------
-    type: checked_type of the op
-    op_name: str
-        Name of the op for debugging pursposes.
-    Returns
-    -------
-    ret: bool
-        True if arg dynamic type not suppot in TRT, False otherwise
-    """
-
-    if isinstance(type, tvm.ir.TensorType):
+def is_supported_trt_type(typ: Union[tvm.ir.TensorType, tvm.ir.TupleType], op_name: str) -> bool:
+    """Check whether a type is supported by TensorRT."""
+    supported_dtypes = ["float32", "float16"]
+    if isinstance(typ, tvm.ir.TensorType):
+        if typ.dtype not in supported_dtypes:
+            logger.info(f"{op_name}: Only float32 and float16 tensor dtypes are supported.")
+            return False
         # assumes dim 0 is for batch and can be dynamic
-        for dim_shape in type.shape[1:]:
+        # TODO(mbs): But does this depend on the use_implicit_batch flag?
+        for dim_shape in typ.shape[1:]:
             if isinstance(dim_shape, tvm.tir.expr.Any):
-                return True
-    elif isinstance(type, tvm.ir.TupleType):
-        for field_type in type.fields:
-            if check_type_dynamism(field_type, op_name):
-                return True
+                logger.info(f"{op_name}: Only statically known tensor shapes are supported.")
+                return False
+    elif isinstance(typ, tvm.ir.TupleType):
+        for field_type in typ.fields:
+            if not is_supported_trt_type(field_type, op_name):
+                return False
     else:
-        logger.info("Arg not supported in TensorRT for %s with type %s", op_name, type)
-        return True
-    return False
-
-
-def check_dynamism(args, op_name):
-    """
-    Check for dynamism inside any of the args in the op.
-
-    Parameters
-    ----------
-    args : tvm.ir.container.Array
-        Arguments of the op. Each of the argument shape is checked for presence of dynamic
-        components.
-    op_name: str
-        Name of the op for debugging purposes only.
-    Returns
-    ----------
-    ret : bool
-        True if dynamism is present, False otherwise
-    """
-    for arg in args:
-        if check_type_dynamism(arg.checked_type, op_name):
-            return True
-    return False
+        logger.info(f"{op_name}: Type {typ} is not supported.")
+        return False
+    return True
 
 
-def _register_external_op_helper_with_checker(op_name, checker):
-    @tvm.ir.register_op_attr(op_name, "target.tensorrt")
-    def _func_wrapper(expr):
-        attrs, args = expr.attrs, expr.args
-        # ops with dynamic shapes are offloaded to VM
-        if not is_supported_trt_dtype(args):
-            return False
-        if check_dynamism(args, op_name):
+def get_op_name(expr: relay.expr.Expr) -> str:
+    """Get the operator name from an expression."""
+    if isinstance(expr, Op):
+        return expr.name
+    if isinstance(expr, Call):
+        return get_op_name(expr.op)
+    if isinstance(expr, TupleGetItem):
+        return get_op_name(expr.tuple_value)
+    if isinstance(expr, relay.Tuple):
+        return get_op_name(expr.fields[0])
+    return ""
+
+
+def get_args(expr: relay.expr.Expr) -> List[relay.expr.Expr]:
+    """Get the arguments from an expression."""
+    if isinstance(expr, Call):
+        return expr.args
+    if isinstance(expr, TupleGetItem):
+        return get_args(expr.tuple_value)
+    if isinstance(expr, relay.Tuple):
+        return [arg for args in map(get_args, expr.fields) for arg in args]
+    return []
+
+
+def get_attrs(expr: relay.expr.Expr) -> Any:
+    """Get the attributes from an expression."""
+    if isinstance(expr, Call):
+        return expr.attrs
+    if isinstance(expr, TupleGetItem):
+        return get_attrs(expr.tuple_value)
+    return {}
+
+
+CheckFunc = Callable[[Any, List[relay.expr.Expr], str], bool]
+
+
+def make_predicate(checker: CheckFunc) -> Callable[[relay.expr.Expr], bool]:
+    def predicate(expr: relay.expr.Expr) -> bool:
+        op_name = get_op_name(expr)
+        attrs = get_attrs(expr)
+        args = get_args(expr)
+        if not all([is_supported_trt_type(arg.checked_type, op_name) for arg in args]):
             return False
-        if op_name == "multiply":
-            shapes = [
-                [
-                    int(x) if not isinstance(x, tvm.tir.expr.Any) else -1
-                    for x in arg.checked_type.shape
-                ]
-                for arg in args
-            ]
-            # Batched multiply operations don't work in implicit batch mode. The following shapes
-            # have been excluded because they occur in PT MaskRCNN model. The long term solution is
-            # to switch to explicit batch mode after performance regressions are solved.
-            if all(
-                [list(map(int, shape)) in [[300, 64, 7, 7], [300, 1, 1, 1]] for shape in shapes]
-            ):
-                return False
         return checker(attrs, args, op_name)
 
-    return _func_wrapper
+    return predicate
 
 
-def _register_external_op_helper(op_name, supported=True):
-    return _register_external_op_helper_with_checker(
-        op_name, lambda attrs, args, op_name: supported
-    )
+standard_predicate = make_predicate(lambda attrs, args, op_name: True)
 
 
-def _register_external_dynamic_check_func(op_name):
-    """Wrapper to check dynamic shapes inside any of the args in the op."""
+def make_trt_version_checker(version: Tuple[int, int, int]) -> CheckFunc:
+    """Helper for ops which require a minimum TRT version"""
 
-    def _decorator_helper(checker):
-        @tvm.ir.register_op_attr(op_name, "target.tensorrt")
-        def _func_wrapper(expr):
-            args = expr.args
-            # ops with dynamic shapes are offloaded to VM
-            if check_dynamism(args, op_name):
-                return False
-            return checker(expr)
+    def checker(attrs: Any, args: List[relay.expr.Expr], op_name: str) -> bool:
+        if get_tensorrt_version() < version:
+            logger.info(
+                f"{op_name}: requires TensorRT version {'.'.join(map(str, version))} or higher."
+            )
+            return False
+        return True
 
-        return _func_wrapper
+    return checker
 
-    return _decorator_helper
 
+def make_and_checker(*checkers: CheckFunc) -> CheckFunc:
+    def checker(attrs: Any, args: List[relay.expr.Expr], op_name: str) -> bool:
+        return all([c(attrs, args, op_name) for c in checkers])
 
-# Ops which are always supported
-_register_external_op_helper("nn.relu")
-_register_external_op_helper("sigmoid")
-_register_external_op_helper("tanh")
-_register_external_op_helper("subtract")
-_register_external_op_helper("multiply")
-_register_external_op_helper("divide")
-_register_external_op_helper("power")
-_register_external_op_helper("maximum")
-_register_external_op_helper("minimum")
-_register_external_op_helper("exp")
-_register_external_op_helper("log")
-_register_external_op_helper("sqrt")
-_register_external_op_helper("abs")
-_register_external_op_helper("negative")
-_register_external_op_helper("nn.batch_flatten")
-_register_external_op_helper("clip")
+    return checker
 
 
-def reduce_annotate_fn(attrs, args, op_name):
+def multiply_checker(attrs: Any, args: List[relay.expr.Expr], op_name: str) -> bool:
+    """Helper for multiply operations."""
+    shapes = [
+        [int(x) if not isinstance(x, tvm.tir.expr.Any) else -1 for x in arg.checked_type.shape]
+        for arg in args
+    ]
+    # TODO(mbs): Follow up
+    # Batched multiply operations don't work in implicit batch mode. The following shapes
+    # have been excluded because they occur in PT MaskRCNN model. The long term solution is
+    # to switch to explicit batch mode after performance regressions are solved.
+    if all([list(map(int, shape)) in [[300, 64, 7, 7], [300, 1, 1, 1]] for shape in shapes]):
+        logger.info(f"{op_name}: Excluding since problematic in implicit batch mode")
+        return False
+    return True
+
+
+def reduce_checker(attrs: Any, args: List[relay.expr.Expr], op_name: str) -> bool:
     """Helper for reduce operations."""
     if get_tensorrt_use_implicit_batch_mode() and (not attrs.axis or len(attrs.axis) == 0):
-        logger.info("%s: cannot reduce to scalar.", op_name)
+        logger.info(f"{op_name}: cannot reduce to scalar.")
         return False
     if attrs.exclude:
-        logger.info("%s: exclude not supported.", op_name)
+        logger.info(f"{op_name}: exclude not supported.")
         return False
     if get_tensorrt_use_implicit_batch_mode() and any([x == 0 for x in map(int, attrs.axis)]):
-        logger.info("%s: can't modify batch dimension.", op_name)
+        logger.info(f"{op_name}: can't modify batch dimension.")
         return False
     return True
 
 
-_register_external_op_helper_with_checker("sum", reduce_annotate_fn)
-_register_external_op_helper_with_checker("prod", reduce_annotate_fn)
-_register_external_op_helper_with_checker("max", reduce_annotate_fn)
-_register_external_op_helper_with_checker("min", reduce_annotate_fn)
-_register_external_op_helper_with_checker("mean", reduce_annotate_fn)
-_register_external_op_helper_with_checker("variance", reduce_annotate_fn)
-
-
-def trt_version_annotate_fn(version):
-    """Helper for ops which require a minimum TRT version"""
-
-    def _func_wrapper(attrs, args, op_name):
-        if get_tensorrt_version() < version:
-            logger.info(
-                "%s: requires TensorRT version %s or higher.", op_name, ".".join(map(str, version))
-            )
-            return False
-        return True
-
-    return _func_wrapper
-
-
-_register_external_op_helper_with_checker("nn.leaky_relu", trt_version_annotate_fn((5, 1, 5)))
-_register_external_op_helper_with_checker("sin", trt_version_annotate_fn((5, 1, 5)))
-_register_external_op_helper_with_checker("cos", trt_version_annotate_fn((5, 1, 5)))
-_register_external_op_helper_with_checker("atan", trt_version_annotate_fn((5, 1, 5)))
-_register_external_op_helper_with_checker("ceil", trt_version_annotate_fn((5, 1, 5)))
-_register_external_op_helper_with_checker("erf", trt_version_annotate_fn((7, 0, 0)))
-
-
-@_register_external_dynamic_check_func("add")
-def add_annotate_fn(expr):  # pylint: disable=unused-variable
+def add_checker(
+    attrs: Any, args: List[relay.expr.Expr], op_name: str
+) -> bool:  # pylint: disable=unused-variable
     """Check if add is supported by TensorRT."""
-
-    args = expr.args
-    if not is_supported_trt_dtype(args):
-        return False
     shapes = [
         [int(x) if not isinstance(x, tvm.tir.expr.Any) else -1 for x in arg.checked_type.shape]
         for arg in args
@@ -416,6 +317,7 @@ def add_annotate_fn(expr):  # pylint: disable=unused-variable
 
     # Scalars require explicit batch mode.
     if get_tensorrt_use_implicit_batch_mode() and any([len(shape) < 1 for shape in shapes]):
+        logger.info(f"{op_name}: Scalars not supported in implicit batch mode")
         return False
 
     if (
@@ -427,172 +329,141 @@ def add_annotate_fn(expr):  # pylint: disable=unused-variable
         and shapes[0][0] != 1
         and (len(shapes[0]) > 3 or len(shapes[1]) > 3)
     ):
-        logger.info("add: bug in TRT with adding batched constants.")
+        logger.info(f"{op_name}: bug in TRT with adding batched constants.")
         return False
     return True
 
 
-@_register_external_dynamic_check_func("nn.batch_norm")
-def batch_norm_annotate_fn(expr):  # pylint: disable=unused-variable
+def batch_norm_checker(
+    attrs: Any, args: List[relay.expr.Expr], op_name: str
+) -> bool:  # pylint: disable=unused-variable
     """Check if nn.batch_norm is supported by TensorRT."""
-
-    attrs, args = expr.attrs, expr.args
-    if not is_supported_trt_dtype(args):
-        return False
     if len(args[0].checked_type.shape) == 5 and get_tensorrt_version() < (6, 0, 1):
-        logger.info("nn.batch_norm: TensorRT 6.0.1 or higher is required for rank 5 inputs.")
+        logger.info(f"{op_name}: TensorRT 6.0.1 or higher is required for rank 5 inputs.")
         return False
     if len(args[0].checked_type.shape) > 5:
-        logger.info("nn.batch_norm: Input rank must be 5 or less.")
+        logger.info(f"{op_name}: Input rank must be 5 or less.")
         return False
     if int(attrs.axis) not in (1, 3):
-        logger.info("nn.batch_norm: axis is %d but must be 1 or 3.", int(attrs.axis))
+        logger.info(f"{op_name}: axis is {int(attrs.axis)} but must be 1 or 3.")
         return False
     return True
 
 
-@_register_external_dynamic_check_func("nn.softmax")
-def softmax_annotate_fn(expr):  # pylint: disable=unused-variable
+def softmax_checker(
+    attrs: Any, args: List[relay.expr.Expr], op_name: str
+) -> bool:  # pylint: disable=unused-variable
     """Check if nn.softmax is supported by TensorRT."""
-
-    attrs, args = expr.attrs, expr.args
-    if not is_supported_trt_dtype(args):
-        return False
     if get_tensorrt_use_implicit_batch_mode() and int(attrs.axis) == 0:
-        logger.info("nn.softmax: can't modify batch dimension.")
+        logger.info(f"{op_name}: can't modify batch dimension.")
         return False
     return True
 
 
-@_register_external_dynamic_check_func("nn.conv1d")
-def conv1d_annotate_fn(expr):  # pylint: disable=unused-variable
+def conv1d_checker(
+    attrs: Any, args: List[relay.expr.Expr], op_name: str
+) -> bool:  # pylint: disable=unused-variable
     """Check if nn.conv1d is supported by TensorRT."""
-
-    attrs, args = expr.attrs, expr.args
-    if not is_supported_trt_dtype(args):
-        return False
     if not isinstance(args[1], Constant):
-        logger.info("nn.conv1d: kernel argument must be constant.")
+        logger.info(f"{op_name}: kernel argument must be constant.")
         return False
     if attrs.data_layout != "NCW":
-        logger.info("nn.conv1d: data_layout is %s but must be NCW.", attrs.data_layout)
+        logger.info(f"{op_name}: data_layout is {attrs.data_layout} but must be NCW.")
         return False
     if attrs.kernel_layout != "OIW":
-        logger.info("nn.conv1d: kernel_layout is %s but must be OIW.", attrs.kernel_layout)
+        logger.info(f"{op_name}: kernel_layout is {attrs.kernel_layout} but must be OIW.")
         return False
     return True
 
 
-@_register_external_dynamic_check_func("nn.conv2d")
-def conv2d_annotate_fn(expr):  # pylint: disable=unused-variable
+def conv2d_checker(
+    attrs: Any, args: List[relay.expr.Expr], op_name: str
+) -> bool:  # pylint: disable=unused-variable
     """Check if nn.conv2d is supported by TensorRT."""
-
-    attrs, args = expr.attrs, expr.args
-    if not is_supported_trt_dtype(args):
-        return False
+    assert len(args) == 2
     if not isinstance(args[1], Constant):
-        logger.info("nn.conv2d: kernel argument must be constant.")
+        logger.info(f"{op_name}: kernel argument must be constant.")
         return False
     if attrs.data_layout != "NCHW":
-        logger.info("nn.conv2d: data_layout is %s but must be NCHW.", attrs.data_layout)
+        logger.info(f"{op_name}: data_layout is {attrs.data_layout} but must be NCHW.")
         return False
     if attrs.kernel_layout != "OIHW":
-        logger.info("nn.conv2d: kernel_layout is %s but must be OIHW.", attrs.kernel_layout)
+        logger.info(f"{op_name}: kernel_layout is {attrs.kernel_layout} but must be OIHW.")
         return False
     if attrs.out_layout and attrs.out_layout != "NCHW":
-        logger.info("nn.conv2d: out_layout is %s but must be NCHW.", attrs.out_layout)
+        logger.info(f"{op_name}: out_layout is {attrs.out_layout} but must be NCHW.")
         return False
     return True
 
 
-@_register_external_dynamic_check_func("nn.dense")
-def dense_annotate_fn(expr):  # pylint: disable=unused-variable
+def dense_checker(
+    attrs: Any, args: List[relay.expr.Expr], op_name: str
+) -> bool:  # pylint: disable=unused-variable
     """Check if dense is supported by TensorRT."""
-
-    args = expr.args
-    if not is_supported_trt_dtype(args):
-        return False
     if not isinstance(args[1], Constant):
-        logger.info("nn.dense: weight must be constant")
+        logger.info(f"{op_name}: weight must be constant")
         return False
     input_rank = len(args[0].checked_type.shape)
     weight_rank = len(args[1].checked_type.shape)
     if input_rank not in (2, 3, 4):
-        logger.info("nn.dense: input has rank %d but must be 2, 3 or 4.", input_rank)
+        logger.info(f"{op_name}: input has rank {input_rank} but must be 2, 3 or 4.")
         return False
     if weight_rank != 2:
-        logger.info("nn.dense: weight has rank %d but must be 2.", weight_rank)
+        logger.info(f"{op_name}: weight has rank {weight_rank} but must be 2.")
         return False
     return True
 
 
-@_register_external_dynamic_check_func("nn.batch_matmul")
-def batch_matmul_annotate_fn(expr):
+def batch_matmul_checker(
+    attrs: Any, args: List[relay.expr.Expr], op_name: str
+) -> bool:  # pylint: disable=unused-variable
     """Check if nn.batch_matmul is supported by TensorRT."""
-
-    args = expr.args
-    if not is_supported_trt_dtype(args):
-        return False
-    if get_tensorrt_use_implicit_batch_mode() and len(expr.args[0].checked_type.shape) != len(
-        expr.args[1].checked_type.shape
+    if get_tensorrt_use_implicit_batch_mode() and len(args[0].checked_type.shape) != len(
+        args[1].checked_type.shape
     ):
-        logger.info("nn.batch_matmul: requires use_implict_batch=False.")
+        logger.info(f"{op_name}: requires use_implicit_batch=False.")
         return False
     return True
 
 
-@_register_external_dynamic_check_func("nn.layer_norm")
-def layer_norm_annotate_fn(expr):
+def layer_norm_checker(attrs: Any, args: List[relay.expr.Expr], op_name: str) -> bool:
     """Check if nn.layer_norm is supported by TensorRT."""
-
-    args = expr.args
-    if not is_supported_trt_dtype(args):
-        return False
-    if get_tensorrt_use_implicit_batch_mode() and int(expr.attrs.axis) == 0:
-        logger.info("nn.layer_norm: requires use_implict_batch=False.")
+    if get_tensorrt_use_implicit_batch_mode() and int(attrs.axis) == 0:
+        logger.info(f"{op_name}: requires use_implicit_batch=False.")
         return False
     return True
 
 
-@_register_external_dynamic_check_func("nn.bias_add")
-def bias_add_annotate_fn(expr):  # pylint: disable=unused-variable
+def bias_add_checker(
+    attrs: Any, args: List[relay.expr.Expr], op_name: str
+) -> bool:  # pylint: disable=unused-variable
     """Check if nn.bias_add is supported by TensorRT."""
-
-    args = expr.args
-    if not is_supported_trt_dtype(args):
-        return False
     input_rank = len(args[0].checked_type.shape)
     if input_rank not in (2, 3, 4):
-        logger.info("nn.bias_add: input rank is %d but must be 2, 3 or 4.", input_rank)
+        logger.info(f"{op_name}: input rank is {input_rank} but must be 2, 3 or 4.")
         return False
     return True
 
 
-@_register_external_dynamic_check_func("nn.max_pool2d")
-def max_pool_2d_annotate_fn(expr):  # pylint: disable=unused-variable
+def max_pool_2d_checker(
+    attrs: Any, args: List[relay.expr.Expr], op_name: str
+) -> bool:  # pylint: disable=unused-variable
     """Check if nn.max_pool2d is supported by TensorRT."""
-
-    attrs, args = expr.attrs, expr.args
-    if not is_supported_trt_dtype(args):
-        return False
     if attrs.layout != "NCHW":
-        logger.info("nn.max_pool2d: layout is %s but must be NCHW.", attrs.layout)
+        logger.info(f"{op_name}: layout is {attrs.layout} but must be NCHW.")
         return False
     if attrs.ceil_mode and get_tensorrt_version() < (5, 1, 5):
-        logger.info("nn.avg_pool2d: ceil_mode=True requires TensorRT 5.1.5 or greater.")
+        logger.info(f"{op_name}: ceil_mode=True requires TensorRT 5.1.5 or greater.")
         return False
     return True
 
 
-@_register_external_dynamic_check_func("nn.avg_pool2d")
-def avg_pool_2d_annotate_fn(expr):  # pylint: disable=unused-variable
+def avg_pool_2d_checker(
+    attrs: Any, args: List[relay.expr.Expr], op_name: str
+) -> bool:  # pylint: disable=unused-variable
     """Check if nn.avg_pool2d is supported by TensorRT."""
-
-    attrs, args = expr.attrs, expr.args
-    if not is_supported_trt_dtype(args):
-        return False
     if attrs.layout != "NCHW":
-        logger.info("nn.avg_pool2d: layout is %d but must be NCHW.", attrs.layout)
+        logger.info(f"{op_name}: layout is {attrs.layout} but must be NCHW.")
         return False
     if (
         attrs.count_include_pad
@@ -603,175 +474,141 @@ def avg_pool_2d_annotate_fn(expr):  # pylint: disable=unused-variable
         )
     ):
         logger.info(
-            "nn.avg_pool2d: inclusive-counted blended or average "
+            f"{op_name}: inclusive-counted blended or average "
             "pooling is not supported in combination with asymmetric padding"
         )
         return False
     if attrs.ceil_mode and get_tensorrt_version() < (5, 1, 5):
-        logger.info("nn.avg_pool2d: ceil_mode=True requires TensorRT 5.1.5 or greater.")
+        logger.info(f"{op_name}: ceil_mode=True requires TensorRT 5.1.5 or greater.")
         return False
     return True
 
 
-@_register_external_dynamic_check_func("nn.global_max_pool2d")
-def global_max_pool_2d_annotate_fn(expr):  # pylint: disable=unused-variable
+def global_max_pool_2d_checker(
+    attrs: Any, args: List[relay.expr.Expr], op_name: str
+) -> bool:  # pylint: disable=unused-variable
     """Check if nn.global_max_pool2d is supported by TensorRT."""
-
-    attrs, args = expr.attrs, expr.args
-    if not is_supported_trt_dtype(args):
-        return False
     if attrs.layout != "NCHW":
-        logger.info("nn.global_max_pool2d: layout is %s but must be NCHW.", attrs.layout)
+        logger.info(f"{op_name}: layout is {attrs.layout} but must be NCHW.")
         return False
     return True
 
 
-@_register_external_dynamic_check_func("nn.global_avg_pool2d")
-def global_avg_pool_2d_annotate_fn(expr):  # pylint: disable=unused-variable
+def global_avg_pool_2d_checker(
+    attrs: Any, args: List[relay.expr.Expr], op_name: str
+) -> bool:  # pylint: disable=unused-variable
     """Check if nn.global_avg_pool2d is supported by TensorRT."""
-
-    attrs, args = expr.attrs, expr.args
-    if not is_supported_trt_dtype(args):
-        return False
     if attrs.layout != "NCHW":
-        logger.info("nn.global_avg_pool2d: layout is %s but must be NCHW.", attrs.layout)
+        logger.info(f"{op_name}: layout is {attrs.layout} but must be NCHW.")
         return False
     return True
 
 
-@_register_external_dynamic_check_func("expand_dims")
-def expand_dims_annotate_fn(expr):  # pylint: disable=unused-variable
+def expand_dims_checker(
+    attrs: Any, args: List[relay.expr.Expr], op_name: str
+) -> bool:  # pylint: disable=unused-variable
     """Check if expand_dims is supported by TensorRT."""
-
-    attrs, args = expr.attrs, expr.args
-    if not is_supported_trt_dtype(args):
-        return False
     if get_tensorrt_use_implicit_batch_mode() and int(attrs.axis) == 0:
-        logger.info("expand_dims: can't modify batch dimension.")
+        logger.info(f"{op_name}: can't modify batch dimension.")
         return False
     return True
 
 
-@_register_external_dynamic_check_func("squeeze")
-def squeeze_annotate_fn(expr):  # pylint: disable=unused-variable
+def squeeze_checker(
+    attrs: Any, args: List[relay.expr.Expr], op_name: str
+) -> bool:  # pylint: disable=unused-variable
     """Check if squeeze is supported by TensorRT."""
-
-    attrs, args = expr.attrs, expr.args
-    if not is_supported_trt_dtype(args):
-        return False
     if not attrs.axis:
-        logger.info("squeeze: must explicitly set axis.")
+        logger.info(f"{op_name}: must explicitly set axis.")
         return False
     if get_tensorrt_use_implicit_batch_mode() and any([axis == 0 for axis in map(int, attrs.axis)]):
-        logger.info("squeeze: can't modify batch dimension.")
+        logger.info(f"{op_name}: can't modify batch dimension.")
         return False
     return True
 
 
-@_register_external_dynamic_check_func("concatenate")
-def concatenate_annotate_fn(expr):  # pylint: disable=unused-variable
+def concatenate_checker(
+    attrs: Any, args: List[relay.expr.Expr], op_name: str
+) -> bool:  # pylint: disable=unused-variable
     """Check if concatenate is supported by TensorRT."""
-
-    attrs, args = expr.attrs, expr.args
-    if any([x.dtype not in supported_types for x in args[0].checked_type.fields]):
-        logger.info("Only float16 and float32 inputs are supported for TensorRT.")
-    if not get_tensorrt_use_implicit_batch_mode():
-        return True
-    if int(attrs.axis) == 0:
-        logger.info("concatenate: can't modify batch dimension.")
-        return False
-    if isinstance(args[0], Tuple):
-        for tuple_input in args[0].fields:
-            if isinstance(tuple_input, Constant):
-                logger.info("concatenate: can't concatenate tensors with constants.")
-                return False
+    if get_tensorrt_use_implicit_batch_mode():
+        if int(attrs.axis) == 0:
+            logger.info(f"{op_name}: can't modify batch dimension.")
+            return False
+        if isinstance(args[0], relay.Tuple):
+            for tuple_input in args[0].fields:
+                if isinstance(tuple_input, Constant):
+                    logger.info(f"{op_name}: can't concatenate tensors with constants.")
+                    return False
     return True
 
 
-@_register_external_dynamic_check_func("split")
-def split_annotate_fn(expr):
+def split_checker(
+    attrs: Any, args: List[relay.expr.Expr], op_name: str
+) -> bool:  # pylint: disable=unused-variable
     """Check if split is supported by TensorRT."""
-
-    args = expr.args
-    if not is_supported_trt_dtype(args):
-        return False
-    if get_tensorrt_use_implicit_batch_mode() and int(expr.attrs.axis) == 0:
-        logger.info("split: can't modify batch dimension.")
+    if get_tensorrt_use_implicit_batch_mode() and int(attrs.axis) == 0:
+        logger.info(f"{op_name}: can't modify batch dimension.")
         return False
     return True
 
 
-@_register_external_dynamic_check_func("nn.conv2d_transpose")
-def conv2d_transpose_annotate_fn(expr):  # pylint: disable=unused-variable
+def conv2d_transpose_checker(
+    attrs: Any, args: List[relay.expr.Expr], op_name: str
+) -> bool:  # pylint: disable=unused-variable
     """Check if nn.conv2d_transpose is supported by TensorRT."""
-
-    attrs, args = expr.attrs, expr.args
-    if not is_supported_trt_dtype(args):
-        return False
     if attrs.data_layout != "NCHW":
-        logger.info("nn.conv2d_transpose: data_layout is %s but must be NCHW.", attrs.data_layout)
+        logger.info(f"{op_name}: data_layout is {attrs.data_layout} but must be NCHW.")
         return False
     if attrs.kernel_layout != "OIHW":
-        logger.info(
-            "nn.conv2d_transpose: kernel_layout is %s but must be OIHW.", attrs.kernel_layout
-        )
+        logger.info(f"{op_name}: kernel_layout is {attrs.kernel_layout} but must be OIHW.")
         return False
     if attrs.out_layout and attrs.out_layout != "NCHW":
-        logger.info("nn.conv2d_transpose: out_layout is %s but must be NCHW.", attrs.out_layout)
+        logger.info(f"{op_name}: out_layout is {attrs.out_layout} but must be NCHW.")
         return False
     if attrs.dilation and any([rate != 1 for rate in map(int, attrs.dilation)]):
-        logger.info("nn.conv2d_transpose: dilation rate must be 1.")
+        logger.info(f"{op_name}: dilation rate must be 1.")
         return False
     return True
 
 
-@_register_external_dynamic_check_func("transpose")
-def transpose_annotate_fn(expr):  # pylint: disable=unused-variable
+def transpose_checker(
+    attrs: Any, args: List[relay.expr.Expr], op_name: str
+) -> bool:  # pylint: disable=unused-variable
     """Check if transpose is supported by TensorRT."""
-
-    attrs, args = expr.attrs, expr.args
-    if not is_supported_trt_dtype(args):
-        return False
     if get_tensorrt_use_implicit_batch_mode() and int(attrs.axes[0]) != 0:
-        logger.info("transpose: can't modify batch dimension.")
+        logger.info(f"{op_name}: can't modify batch dimension.")
         return False
     return True
 
 
-@_register_external_dynamic_check_func("layout_transform")
-def layout_transform_annotate_fn(expr):  # pylint: disable=unused-variable
+def layout_transform_checker(
+    attrs: Any, args: List[relay.expr.Expr], op_name: str
+) -> bool:  # pylint: disable=unused-variable
     """Check if layout_transform is supported by TensorRT."""
-
-    attrs, args = expr.attrs, expr.args
-    if not is_supported_trt_dtype(args):
-        return False
     if (attrs.src_layout, attrs.dst_layout) not in [
         ("NCHW", "NHWC"),
         ("NHWC", "NCHW"),
         ("NDHWC", "NCDHW"),
         ("NCDHW", "NDHWC"),
     ]:
-        logger.info(
-            "layout_transform: %s to %s is not supported.", attrs.src_layout, attrs.dst_layout
-        )
+        logger.info(f"{op_name}: {attrs.src_layout} to {attrs.dst_layout} is not supported.")
         return False
     return True
 
 
-@_register_external_dynamic_check_func("reshape")
-def reshape_annotate_fn(expr):  # pylint: disable=unused-variable
+def reshape_checker(
+    attrs: Any, args: List[relay.expr.Expr], op_name: str
+) -> bool:  # pylint: disable=unused-variable
     """Check if reshape is supported by TensorRT."""
-    attrs, args = expr.attrs, expr.args
-    if not is_supported_trt_dtype(args):
-        return False
     if any([x < -1 for x in map(int, attrs.newshape)]):
-        logger.info("reshape: new shape dims must be explicit.")
+        logger.info(f"{op_name}: new shape dims must be explicit.")
         return False
     if get_tensorrt_use_implicit_batch_mode():
         shape = args[0].checked_type.shape
         new_shape = attrs.newshape
         if len(new_shape) == 0 or len(shape) == 0:
-            logger.info("reshape: Can't reshape to or from scalar.")
+            logger.info(f"{op_name}: Can't reshape to or from scalar.")
             return False
         dynamic_reshape = any([isinstance(x, tvm.tir.expr.Any) for x in shape])
 
@@ -784,6 +621,7 @@ def reshape_annotate_fn(expr):  # pylint: disable=unused-variable
                         and isinstance(new_shape_val, (int, tvm.tir.expr.IntImm))
                         and int(shape_val) == int(new_shape_val)
                     ):
+                        logger.info(f"{op_name}: can't modify batch dimension")
                         return False
             elif int(new_shape[0]) > 0:
                 # Currently we only allow dim[0] to be Any, so this branch will always be False
@@ -792,67 +630,60 @@ def reshape_annotate_fn(expr):  # pylint: disable=unused-variable
                     and isinstance(new_shape[0], (int, tvm.tir.expr.IntImm))
                     and int(shape[0]) == int(new_shape[0])
                 ):
+                    logger.info(f"{op_name}: can't modify batch dimension")
                     return False
-            return True
-        shape = list(map(int, shape))
-        new_shape = list(map(int, new_shape))
-
-        # TRT cannot modify batch dimension.
-        original_volume = np.prod(shape)
-        # First, resolve 0.
-        for i, value in enumerate(new_shape):
-            if value == 0:
-                new_shape[i] = shape[i]
-        # Resolve -1.
-        for i, value in enumerate(new_shape):
-            if value == -1:
-                new_shape[i] = original_volume // np.prod([x for x in new_shape if x != -1])
-        # Remove batch dimension and see if volumes match
-        if shape[0] != new_shape[0]:
-            logger.info("reshape: can't modify batch dimension.")
-            return False
+        else:
+            shape = list(map(int, shape))
+            new_shape = list(map(int, new_shape))
+
+            # TRT cannot modify batch dimension.
+            original_volume = np.prod(shape)
+            # First, resolve 0.
+            for i, value in enumerate(new_shape):
+                if value == 0:
+                    new_shape[i] = shape[i]
+            # Resolve -1.
+            for i, value in enumerate(new_shape):
+                if value == -1:
+                    new_shape[i] = original_volume // np.prod([x for x in new_shape if x != -1])
+            # Remove batch dimension and see if volumes match
+            if shape[0] != new_shape[0]:
+                logger.info(f"{op_name}: can't modify batch dimension.")
+                return False
     return True
 
 
-@_register_external_dynamic_check_func("nn.pad")
-def pad_annotate_fn(expr):  # pylint: disable=unused-variable
+def pad_checker(
+    attrs: Any, args: List[relay.expr.Expr], op_name: str
+) -> bool:  # pylint: disable=unused-variable
     """Check if nn.pad is supported by TensorRT."""
-
-    attrs, args = expr.attrs, expr.args
-    if not is_supported_trt_dtype(args):
-        return False
     pad_value = args[1]
     if not isinstance(pad_value, relay.Constant):
-        logger.info("nn.pad: pad argument must be constant")
+        logger.info(f"{op_name}: pad argument must be constant")
         return False
     pad_value = pad_value.data.numpy().item()
     if attrs.pad_mode != "constant":
-        logger.info("nn.pad: pad mode is %s but must be constant.", attrs.pad_mode)
+        logger.info(f"{op_name}: pad mode is {attrs.pad_mode} but must be constant.")
         return False
     if pad_value > 0.0:
-        logger.info("nn.pad: pad value is %f but must be 0.0.", pad_value)
+        logger.info(f"{op_name}: pad value is {pad_value} but must be 0.0.")
         return False
     if len(attrs.pad_width) not in [4, 5]:
-        logger.info("nn.pad: can only pad 4D or 5D inputs")
+        logger.info(f"{op_name}: can only pad 4D or 5D inputs")
         return False
     if any([x != 0 for x in attrs.pad_width[0]]) or any([x != 0 for x in attrs.pad_width[1]]):
-        logger.info("nn.pad: can't pad batch or channel dimensions.")
+        logger.info(f"{op_name}: can't pad batch or channel dimensions.")
         return False
     if len(attrs.pad_width) == 5 and any([x != 0 for x in attrs.pad_width[2]]):
-        logger.info("nn.pad: can only pad last two dimensions for 5D inputs.")
+        logger.info(f"{op_name}: can only pad last two dimensions for 5D inputs.")
         return False
     return True
 
 
-@_register_external_dynamic_check_func("strided_slice")
-def strided_slice_annotate_fn(expr):  # pylint: disable=unused-variable
+def strided_slice_checker(
+    attrs: Any, args: List[relay.expr.Expr], op_name: str
+) -> bool:  # pylint: disable=unused-variable
     """Check if strided_slice is supported by TensorRT."""
-
-    attrs, args = expr.attrs, expr.args
-    if not is_supported_trt_dtype(args):
-        return False
-    if not trt_version_annotate_fn((5, 1, 5))(attrs, args, "strided_slice"):
-        return False
     if get_tensorrt_use_implicit_batch_mode():
         batch_dim_begin_modified = attrs.begin[0] is not None and int(attrs.begin[0]) != 0
         batch_dim_end_modified = (
@@ -861,10 +692,10 @@ def strided_slice_annotate_fn(expr):  # pylint: disable=unused-variable
             and int(attrs.end[0]) != int(args[0].checked_type.shape[0])
         )
         if batch_dim_begin_modified or batch_dim_end_modified:
-            logger.info("strided_slice: can't modify batch dimension.")
+            logger.info(f"{op_name}: can't modify batch dimension.")
             return False
     if any([x is not None and x <= 0 for x in attrs.strides]):
-        logger.info("strided_slice: stride must be positive")
+        logger.info(f"{op_name}: stride must be positive")
         return False
     for i in range(0, len(args[0].checked_type.shape)):
         begin = int(attrs.begin[i])
@@ -882,238 +713,304 @@ def strided_slice_annotate_fn(expr):  # pylint: disable=unused-variable
                 else args[0].checked_type.shape[i] - begin
             )
         else:
-            logger.warning("strided_slice: unknown slice mode encountered")
+            logger.warning(f"{op_name}: unknown slice mode encountered")
+            size = 1
 
         if int(size) < 1:
-            logger.info("strided_slice: size of slice must be at least 1")
+            logger.info(f"{op_name}: size of slice must be at least 1")
             return False
 
     return True
 
 
-@_register_external_dynamic_check_func("nn.adaptive_max_pool2d")
-def adaptive_max_pool2d_annotate_fn(expr):  # pylint: disable=unused-variable
+def adaptive_max_pool2d_checker(
+    attrs: Any, args: List[relay.expr.Expr], op_name: str
+) -> bool:  # pylint: disable=unused-variable
     """Check if nn.adaptive_max_pool2d is supported by TensorRT."""
-
-    attrs, args = expr.attrs, expr.args
-    if not is_supported_trt_dtype(args):
-        return False
     if len(attrs.output_size) == 0 or any([size != 1 for size in map(int, attrs.output_size)]):
-        logger.info("nn.adaptive_max_pool2d: output size must be (1, 1).")
+        logger.info(f"{op_name}: output size must be (1, 1).")
         return False
     return True
 
 
-@_register_external_dynamic_check_func("nn.adaptive_avg_pool2d")
-def adaptive_avg_pool2d_annotate_fn(expr):  # pylint: disable=unused-variable
+def adaptive_avg_pool2d_checker(
+    attrs: Any, args: List[relay.expr.Expr], op_name: str
+) -> bool:  # pylint: disable=unused-variable
     """Check if nn.adaptive_avg_pool2d is supported by TensorRT."""
-
-    attrs, args = expr.attrs, expr.args
-    if not is_supported_trt_dtype(args):
-        return False
     if len(attrs.output_size) == 0 or any([size != 1 for size in map(int, attrs.output_size)]):
-        logger.info("nn.adaptive_avg_pool2d: output size must be (1, 1).")
+        logger.info(f"{op_name}: output size must be (1, 1).")
         return False
     return True
 
 
-@_register_external_dynamic_check_func("nn.conv3d")
-def conv3d_annotate_fn(expr):  # pylint: disable=unused-variable
+def conv3d_checker(
+    attrs: Any, args: List[relay.expr.Expr], op_name: str
+) -> bool:  # pylint: disable=unused-variable
     """Check if nn.conv3d is supported by TensorRT."""
-
-    attrs, args = expr.attrs, expr.args
-    if not is_supported_trt_dtype(args):
-        return False
     if not isinstance(args[1], Constant):
-        logger.info("nn.conv3d: kernel argument must be constant.")
-        return False
-    if not trt_version_annotate_fn((6, 0, 1))(attrs, args, "nn.conv3d"):
+        logger.info(f"{op_name}: kernel argument must be constant.")
         return False
     if attrs.data_layout != "NCDHW":
-        logger.info("nn.conv3d: data_layout is %s but must be NCDHW.", attrs.data_layout)
+        logger.info(f"{op_name}: data_layout is {attrs.data_layout} but must be NCDHW.")
         return False
     if attrs.kernel_layout != "OIDHW":
-        logger.info("nn.conv3d: kernel_layout is %s but must be OIDHW.", attrs.kernel_layout)
+        logger.info(f"{op_name}: kernel_layout is {attrs.kernel_layout} but must be OIDHW.")
         return False
     if attrs.out_layout and attrs.out_layout != "NCDHW":
-        logger.info("nn.conv3d: out_layout is %s but must be NCDHW.", attrs.out_layout)
+        logger.info(f"{op_name}: out_layout is {attrs.out_layout} but must be NCDHW.")
         return False
     return True
 
 
-@_register_external_dynamic_check_func("nn.max_pool3d")
-def max_pool_3d_annotate_fn(expr):  # pylint: disable=unused-variable
+def max_pool_3d_checker(
+    attrs: Any, args: List[relay.expr.Expr], op_name: str
+) -> bool:  # pylint: disable=unused-variable
     """Check if nn.max_pool3d is supported by TensorRT."""
-
-    attrs, args = expr.attrs, expr.args
-    if not is_supported_trt_dtype(args):
-        return False
-    if not trt_version_annotate_fn((6, 0, 1))(attrs, args, "nn.max_pool3d"):
-        return False
     if attrs.layout != "NCDHW":
-        logger.info("nn.max_pool3d: layout is %s but must be NCDHW.", attrs.layout)
+        logger.info(f"{op_name}: layout is {attrs.layout} but must be NCDHW.")
         return False
     return True
 
 
-@_register_external_dynamic_check_func("nn.avg_pool3d")
-def avg_pool_3d_annotate_fn(expr):  # pylint: disable=unused-variable
+def avg_pool_3d_checker(
+    attrs: Any, args: List[relay.expr.Expr], op_name: str
+) -> bool:  # pylint: disable=unused-variable
     """Check if nn.avg_pool3d is supported by TensorRT."""
-
-    attrs, args = expr.attrs, expr.args
-    if not is_supported_trt_dtype(args):
-        return False
-    if not trt_version_annotate_fn((6, 0, 1))(attrs, args, "nn.avg_pool3d"):
-        return False
     if attrs.layout != "NCDHW":
-        logger.info("nn.avg_pool3d: layout is %s but must be NCDHW.", attrs.layout)
+        logger.info(f"{op_name}: layout is {attrs.layout} but must be NCDHW.")
         return False
     return True
 
 
-@_register_external_dynamic_check_func("nn.conv3d_transpose")
-def conv3d_transpose_annotate_fn(expr):  # pylint: disable=unused-variable
+def conv3d_transpose_checker(
+    attrs: Any, args: List[relay.expr.Expr], op_name: str
+) -> bool:  # pylint: disable=unused-variable
     """Check if nn.conv3d_transpose is supported by TensorRT."""
-
-    attrs, args = expr.attrs, expr.args
-    if not is_supported_trt_dtype(args):
-        return False
-    if not trt_version_annotate_fn((6, 0, 1))(attrs, args, "nn.conv3d_transpose"):
-        return False
     if attrs.data_layout != "NCDHW":
-        logger.info("nn.conv3d_transpose: data_layout is %s but must be NCDHW.", attrs.data_layout)
+        logger.info(f"{op_name}: data_layout is {attrs.data_layout} but must be NCDHW.")
         return False
     if attrs.kernel_layout != "OIDHW":
-        logger.info(
-            "nn.conv3d_transpose: kernel_layout is %s but must be OIDHW.", attrs.kernel_layout
-        )
+        logger.info(f"{op_name}: kernel_layout is {attrs.kernel_layout} but must be OIDHW.")
         return False
     if attrs.out_layout and attrs.out_layout != "NCDHW":
-        logger.info("nn.conv3d_transpose: out_layout is %s but must be NCDHW.", attrs.out_layout)
+        logger.info(f"{op_name}: out_layout is {attrs.out_layout} but must be NCDHW.")
         return False
     if attrs.dilation and any([rate != 1 for rate in map(int, attrs.dilation)]):
-        logger.info("nn.conv3d_transpose: dilation rate must be 1.")
+        logger.info(f"{op_name}: dilation rate must be 1.")
         return False
     if attrs.output_padding and any([x != 0 for x in map(int, attrs.output_padding)]):
-        logger.info("nn.conv3d_transpose: output padding is not supported.")
+        logger.info(f"{op_name}: output padding is not supported.")
         return False
     return True
 
 
-def unary_op_pattern(op):
+def unary_op_pattern(op: relay.expr.Expr) -> relay.dataflow_pattern.DFPattern:
     """Matches unary operation"""
     return is_op(op)(wildcard())
 
 
-def unary_op_pattern_with_any_tuple(op):
+def unary_op_pattern_with_any_tuple(op: relay.expr.Expr) -> relay.dataflow_pattern.DFPattern:
     """Matches unary operation with literal tuple argument"""
     return is_op(op)(is_tuple(None))
 
 
-def binary_op_pattern(op):
+def binary_op_pattern(op: relay.expr.Expr) -> relay.dataflow_pattern.DFPattern:
     """Matches binary operation"""
     return is_op(op)(wildcard(), wildcard())
 
 
-def binary_op_pattern_with_const(op):
+def binary_op_pattern_with_const(op: relay.expr.Expr) -> relay.dataflow_pattern.DFPattern:
     """Matches binary operation with rhs arg a constant"""
     return is_op(op)(wildcard(), is_constant())
 
 
+def proj_five_op_pattern_with_const(op: relay.expr.Expr) -> relay.dataflow_pattern.DFPattern:
+    return is_tuple_get_item(
+        is_op(op)(wildcard(), is_constant(), is_constant(), is_constant(), is_constant()), 0
+    )
+
+
 @register_pattern_table("tensorrt")
-def pattern_table():
+def pattern_table() -> List[
+    Tuple[str, relay.dataflow_pattern.DFPattern, Callable[[relay.expr.Call], bool]]
+]:
     """Get the Tensorrt compiler pattern table for supported ops."""
 
     return [
-        ("tensorrt.nn.conv3d", binary_op_pattern_with_const("nn.conv3d"), conv3d_annotate_fn),
-        ("tensorrt.nn.conv2d", binary_op_pattern_with_const("nn.conv2d"), conv2d_annotate_fn),
-        ("tensorrt.nn.conv1d", binary_op_pattern_with_const("nn.conv1d"), conv1d_annotate_fn),
+        (
+            "tensorrt.nn.conv3d",
+            binary_op_pattern_with_const("nn.conv3d"),
+            make_predicate(make_and_checker(make_trt_version_checker((6, 0, 1)), conv3d_checker)),
+        ),
+        (
+            "tensorrt.nn.conv2d",
+            binary_op_pattern_with_const("nn.conv2d"),
+            make_predicate(conv2d_checker),
+        ),
+        (
+            "tensorrt.nn.conv1d",
+            binary_op_pattern_with_const("nn.conv1d"),
+            make_predicate(conv1d_checker),
+        ),
         (
             "tensorrt.nn.conv2d_transpose",
             binary_op_pattern("nn.conv2d_transpose"),
-            conv2d_transpose_annotate_fn,
+            make_predicate(conv2d_transpose_checker),
+        ),
+        ("tensorrt.squeeze", binary_op_pattern("squeeze"), make_predicate(squeeze_checker)),
+        ("tensorrt.add", binary_op_pattern("add"), make_predicate(add_checker)),
+        (
+            "tensorrt.nn.dense",
+            binary_op_pattern_with_const("nn.dense"),
+            make_predicate(dense_checker),
         ),
-        ("tensorrt.squeeze", binary_op_pattern("squeeze"), squeeze_annotate_fn),
-        ("tensorrt.add", binary_op_pattern("add"), add_annotate_fn),
-        ("tensorrt.nn.dense", binary_op_pattern_with_const("nn.dense"), dense_annotate_fn),
-        ("tensorrt.bias_add", binary_op_pattern("nn.bias_add"), bias_add_annotate_fn),
+        ("tensorrt.bias_add", binary_op_pattern("nn.bias_add"), make_predicate(bias_add_checker)),
         (
             "tensorrt.nn.batch_matmul",
             binary_op_pattern("nn.batch_matmul"),
-            batch_matmul_annotate_fn,
+            make_predicate(batch_matmul_checker),
         ),
-        ("tensorrt.divide", binary_op_pattern("divide")),
-        ("tensorrt.multiply", binary_op_pattern("multiply")),
-        ("tensorrt.nn.relu", unary_op_pattern("nn.relu")),
+        ("tensorrt.divide", binary_op_pattern("divide"), standard_predicate),
+        ("tensorrt.multiply", binary_op_pattern("multiply"), make_predicate(multiply_checker)),
+        ("tensorrt.subtract", binary_op_pattern("subtract"), standard_predicate),
+        ("tensorrt.power", binary_op_pattern("power"), standard_predicate),
+        ("tensorrt.maximum", binary_op_pattern("maximum"), standard_predicate),
+        ("tensorrt.minimum", binary_op_pattern("minimum"), standard_predicate),
+        ("tensorrt.nn.relu", unary_op_pattern("nn.relu"), standard_predicate),
         (
             "tensorrt.nn.leaky_relu",
             unary_op_pattern("nn.leaky_relu"),
-            trt_version_annotate_fn((5, 1, 5)),
+            make_predicate(make_trt_version_checker((5, 1, 5))),
+        ),
+        ("tensorrt.nn.pad", unary_op_pattern("nn.pad"), standard_predicate),
+        ("tensorrt.sigmoid", unary_op_pattern("sigmoid"), standard_predicate),
+        ("tensorrt.tanh", unary_op_pattern("tanh"), standard_predicate),
+        ("tensorrt.exp", unary_op_pattern("exp"), standard_predicate),
+        ("tensorrt.log", unary_op_pattern("log"), standard_predicate),
+        ("tensorrt.sqrt", unary_op_pattern("sqrt"), standard_predicate),
+        ("tensorrt.abs", unary_op_pattern("abs"), standard_predicate),
+        ("tensorrt.negative", unary_op_pattern("negative"), standard_predicate),
+        ("tensorrt.nn.batch_flatten", unary_op_pattern("nn.batch_flatten"), standard_predicate),
+        ("tensorrt.clip", unary_op_pattern("clip"), standard_predicate),
+        (
+            "tensorrt.sin",
+            unary_op_pattern("sin"),
+            make_predicate(make_trt_version_checker((5, 1, 5))),
         ),
-        ("tensorrt.nn.pad", unary_op_pattern("nn.pad")),
-        ("tensorrt.sigmoid", unary_op_pattern("sigmoid")),
-        ("tensorrt.tanh", unary_op_pattern("tanh")),
-        ("tensorrt.exp", unary_op_pattern("exp")),
-        ("tensorrt.log", unary_op_pattern("log")),
-        ("tensorrt.sqrt", unary_op_pattern("sqrt")),
-        ("tensorrt.abs", unary_op_pattern("abs")),
-        ("tensorrt.power", unary_op_pattern("power")),
-        ("tensorrt.negative", unary_op_pattern("negative")),
-        ("tensorrt.nn.batch_flatten", unary_op_pattern("nn.batch_flatten")),
-        ("tensorrt.sin", unary_op_pattern("sin"), trt_version_annotate_fn((5, 1, 5))),
-        ("tensorrt.clip", unary_op_pattern("clip")),
-        ("tensorrt.cos", unary_op_pattern("cos"), trt_version_annotate_fn((5, 1, 5))),
-        ("tensorrt.atan", unary_op_pattern("atan"), trt_version_annotate_fn((5, 1, 5))),
-        ("tensorrt.ceil", unary_op_pattern("ceil"), trt_version_annotate_fn((5, 1, 5))),
-        ("tensorrt.floor", unary_op_pattern("floor")),
-        ("tensorrt.erf", unary_op_pattern("erf"), trt_version_annotate_fn((7, 0, 0))),
-        ("tensorrt.sum", unary_op_pattern("sum"), reduce_annotate_fn),
-        ("tensorrt.prod", unary_op_pattern("prod"), reduce_annotate_fn),
-        ("tensorrt.max", unary_op_pattern("max"), reduce_annotate_fn),
-        ("tensorrt.min", unary_op_pattern("min"), reduce_annotate_fn),
-        ("tensorrt.max", unary_op_pattern("max"), reduce_annotate_fn),
+        (
+            "tensorrt.cos",
+            unary_op_pattern("cos"),
+            make_predicate(make_trt_version_checker((5, 1, 5))),
+        ),
+        (
+            "tensorrt.atan",
+            unary_op_pattern("atan"),
+            make_predicate(make_trt_version_checker((5, 1, 5))),
+        ),
+        (
+            "tensorrt.ceil",
+            unary_op_pattern("ceil"),
+            make_predicate(make_trt_version_checker((5, 1, 5))),
+        ),
+        ("tensorrt.floor", unary_op_pattern("floor"), standard_predicate),
+        (
+            "tensorrt.erf",
+            unary_op_pattern("erf"),
+            make_predicate(make_trt_version_checker((7, 0, 0))),
+        ),
+        ("tensorrt.sum", unary_op_pattern("sum"), make_predicate(reduce_checker)),
+        ("tensorrt.prod", unary_op_pattern("prod"), make_predicate(reduce_checker)),
+        ("tensorrt.max", unary_op_pattern("max"), make_predicate(reduce_checker)),
+        ("tensorrt.min", unary_op_pattern("min"), make_predicate(reduce_checker)),
+        ("tensorrt.max", unary_op_pattern("max"), make_predicate(reduce_checker)),
+        ("tensorrt.mean", unary_op_pattern("mean"), make_predicate(reduce_checker)),
         (
             "tensorrt.concatenate",
             unary_op_pattern_with_any_tuple("concatenate"),
-            concatenate_annotate_fn,
+            make_predicate(concatenate_checker),
+        ),
+        (
+            "tensorrt.expand_dims",
+            unary_op_pattern("expand_dims"),
+            make_predicate(expand_dims_checker),
         ),
-        ("tensorrt.expand_dims", unary_op_pattern("expand_dims"), expand_dims_annotate_fn),
         (
             "tensorrt.layout_transform",
             unary_op_pattern("layout_transform"),
-            layout_transform_annotate_fn,
+            make_predicate(layout_transform_checker),
+        ),
+        ("tensorrt.transpose", unary_op_pattern("transpose"), make_predicate(transpose_checker)),
+        ("tensorrt.reshape", unary_op_pattern("reshape"), make_predicate(reshape_checker)),
+        ("tensorrt.split", unary_op_pattern("split"), make_predicate(split_checker)),
+        ("tensorrt.nn.pad", unary_op_pattern("nn.pad"), make_predicate(pad_checker)),
+        (
+            "tensorrt.strided_slice",
+            unary_op_pattern("strided_slice"),
+            make_predicate(
+                make_and_checker(make_trt_version_checker((5, 1, 5)), strided_slice_checker)
+            ),
         ),
-        ("tensorrt.transpose", unary_op_pattern("transpose"), transpose_annotate_fn),
-        ("tensorrt.reshape", unary_op_pattern("reshape"), reshape_annotate_fn),
-        ("tensorrt.split", unary_op_pattern("split"), split_annotate_fn),
-        ("tensorrt.nn.pad", unary_op_pattern("nn.pad"), pad_annotate_fn),
-        ("tensorrt.strided_slice", unary_op_pattern("strided_slice"), strided_slice_annotate_fn),
         (
             "tensorrt.nn.adaptive_avg_pool2d",
             unary_op_pattern("nn.adaptive_avg_pool2d"),
-            adaptive_avg_pool2d_annotate_fn,
+            make_predicate(adaptive_avg_pool2d_checker),
+        ),
+        (
+            "tensorrt.nn.adaptive_max_pool2d",
+            unary_op_pattern("nn.adaptive_max_pool2d"),
+            make_predicate(adaptive_max_pool2d_checker),
+        ),
+        (
+            "tensorrt.nn.max_pool3d",
+            unary_op_pattern("nn.max_pool3d"),
+            make_predicate(
+                make_and_checker(make_trt_version_checker((6, 0, 1)), max_pool_3d_checker)
+            ),
+        ),
+        (
+            "tensorrt.nn.avg_pool3d",
+            unary_op_pattern("nn.avg_pool3d"),
+            make_predicate(
+                make_and_checker(make_trt_version_checker((6, 0, 1)), avg_pool_3d_checker)
+            ),
         ),
-        ("tensorrt.nn.max_pool3d", unary_op_pattern("nn.max_pool3d"), max_pool_3d_annotate_fn),
-        ("tensorrt.nn.avg_pool3d", unary_op_pattern("nn.avg_pool3d"), avg_pool_3d_annotate_fn),
         (
             "tensorrt.nn.conv3d_transpose",
             unary_op_pattern("nn.conv3d_transpose"),
-            conv3d_transpose_annotate_fn,
+            make_predicate(
+                make_and_checker(make_trt_version_checker((6, 0, 1)), conv3d_transpose_checker)
+            ),
+        ),
+        ("tensorrt.nn.softmax", unary_op_pattern("nn.softmax"), make_predicate(softmax_checker)),
+        (
+            "tensorrt.nn.layer_norm",
+            unary_op_pattern("nn.layer_norm"),
+            make_predicate(layer_norm_checker),
+        ),
+        (
+            "tensorrt.nn.max_pool2d",
+            unary_op_pattern("nn.max_pool2d"),
+            make_predicate(max_pool_2d_checker),
+        ),
+        (
+            "tensorrt.nn.avg_pool2d",
+            unary_op_pattern("nn.avg_pool2d"),
+            make_predicate(avg_pool_2d_checker),
         ),
-        ("tensorrt.nn.softmax", unary_op_pattern("nn.softmax"), softmax_annotate_fn),
-        ("tensorrt.nn.layer_norm", unary_op_pattern("nn.layer_norm"), layer_norm_annotate_fn),
-        ("tensorrt.nn.max_pool2d", unary_op_pattern("nn.max_pool2d"), max_pool_2d_annotate_fn),
-        ("tensorrt.nn.avg_pool2d", unary_op_pattern("nn.avg_pool2d"), avg_pool_2d_annotate_fn),
-        ("tensorrt.nn.max_pool3d", unary_op_pattern("nn.max_pool3d"), max_pool_3d_annotate_fn),
         (
             "tensorrt.nn.global_max_pool2d",
             unary_op_pattern("nn.global_max_pool2d"),
-            global_max_pool_2d_annotate_fn,
+            make_predicate(global_max_pool_2d_checker),
         ),
         (
             "tensorrt.nn.global_avg_pool2d",
             unary_op_pattern("nn.global_avg_pool2d"),
-            global_avg_pool_2d_annotate_fn,
+            make_predicate(global_avg_pool_2d_checker),
+        ),
+        (
+            "tensorrt.nn.batch_norm",
+            proj_five_op_pattern_with_const("nn.batch_norm"),
+            make_predicate(batch_norm_checker),
         ),
     ]
 
@@ -1124,34 +1021,32 @@ class IsComputeIntensiveGraph(ExprVisitor):
     its transpose, dense and batch mat-mul.
     """
 
-    def __init__(self):
+    def __init__(self) -> None:
         ExprVisitor.__init__(self)
         self.is_compute_intensive = False
 
-    def visit_call(self, call):
-        compute_intensive_ops = set(
-            [
-                "nn.conv1d",
-                "nn.conv2d",
-                "nn.conv2d_transpose",
-                "nn.conv3d",
-                "nn.conv3d_transpose",
-                "nn.dense",
-                "nn.batch_matmul",
-                "sum",
-                "prod",
-                "max",
-                "min",
-                "mean",
-            ]
-        )
+    def visit_call(self, call: relay.expr.Call) -> None:
+        compute_intensive_ops = {
+            "nn.conv1d",
+            "nn.conv2d",
+            "nn.conv2d_transpose",
+            "nn.conv3d",
+            "nn.conv3d_transpose",
+            "nn.dense",
+            "nn.batch_matmul",
+            "sum",
+            "prod",
+            "max",
+            "min",
+            "mean",
+        }
         if isinstance(call.op, tvm.tir.op.Op):
             if str(call.op) in compute_intensive_ops:
                 self.is_compute_intensive = True
 
         return super().visit_call(call)
 
-    def is_graph_compute_intensive(self, subgraph) -> bool:
+    def is_graph_compute_intensive(self, subgraph: relay.expr.Expr) -> bool:
         """
         This function recursively visits the graph and checks if it's compute intensive"
         """
@@ -1159,7 +1054,7 @@ class IsComputeIntensiveGraph(ExprVisitor):
         return self.is_compute_intensive
 
 
-def is_valid_subgraph(params, body):
+def is_valid_subgraph(params: List[relay.expr.Var], body: relay.expr.Expr) -> bool:
     """Final check on whether the subgraph is valid and should be offloaded to TensorRT."""
     # Remove invalid subgraphs for implicit batch mode.
     if get_tensorrt_use_implicit_batch_mode():
@@ -1192,7 +1087,7 @@ def is_valid_subgraph(params, body):
     return True
 
 
-def prune_tensorrt_subgraphs(mod):
+def prune_tensorrt_subgraphs(mod: tvm.IRModule) -> tvm.IRModule:
     """
     Removes invalid subgraphs and those with no multiply-accumulates (if remove_no_max_subgraphs
     is set).
@@ -1203,13 +1098,15 @@ def prune_tensorrt_subgraphs(mod):
         Reverts subgraphs in subgraphs_to_remove back to TVM instead of using an external codegen.
         """
 
-        def __init__(self, subgraphs_to_remove, mod, new_mod):
+        def __init__(
+            self, subgraphs_to_remove: List[str], mod: tvm.IRModule, new_mod: tvm.IRModule
+        ) -> None:
             ExprMutator.__init__(self)
             self.subgraphs_to_remove = subgraphs_to_remove
             self.mod = mod
             self.new_mod = new_mod
 
-        def visit_call(self, call):
+        def visit_call(self, call: relay.expr.Call) -> relay.expr.Expr:
             if isinstance(call.op, GlobalVar):
                 name = call.op.name_hint
                 if name in self.subgraphs_to_remove:
@@ -1227,7 +1124,7 @@ def prune_tensorrt_subgraphs(mod):
                     return call.op(*args)
             return super().visit_call(call)
 
-    subgraphs_to_remove = []
+    subgraphs_to_remove: List[str] = []
     # Remove invalid subgraphs
     for subgraph in mod.get_global_vars():
         name = subgraph.name_hint
@@ -1247,7 +1144,7 @@ class RemoveDropout(ExprMutator):
     Removes all nn.dropout from an expr.
     """
 
-    def visit_tuple_getitem(self, op):
+    def visit_tuple_getitem(self, op: TupleGetItem) -> relay.expr.Expr:
         visit = super().visit_tuple_getitem(op)
         if visit.index != 0:
             return visit
@@ -1263,5 +1160,7 @@ class RemoveDropout(ExprMutator):
 
 @transform.function_pass(opt_level=0)
 class RemoveDropoutPass:
-    def transform_function(self, func, mod, _):
+    def transform_function(
+        self, func: relay.function.Function, mod: tvm.IRModule, _: tvm.transform.PassContext
+    ) -> relay.function.Function:
         return RemoveDropout().visit(func)
diff --git a/src/relay/backend/contrib/tensorrt/codegen.cc b/src/relay/backend/contrib/tensorrt/codegen.cc
index 431be8ed3d..149cc485c7 100644
--- a/src/relay/backend/contrib/tensorrt/codegen.cc
+++ b/src/relay/backend/contrib/tensorrt/codegen.cc
@@ -70,51 +70,28 @@ class TensorRTCompilerConfig : public Attrs {
 TVM_REGISTER_NODE_TYPE(TensorRTCompilerConfigNode);
 TVM_REGISTER_PASS_CONFIG_OPTION("relay.ext.tensorrt.options", TensorRTCompilerConfig);
 
+using JSONGraphNode = tvm::runtime::json::JSONGraphNode;
+using JSONGraphNodeEntry = tvm::runtime::json::JSONGraphNodeEntry;
+using JSONGraphObjectPtr = backend::contrib::JSONGraphObjectPtr;
+using OpAttrExtractor = backend::contrib::OpAttrExtractor;
+using JSONSerializer = backend::contrib::JSONSerializer;
+
+class TensorRTJSONSerializer;
+
 /*!
- * \brief Generates an TensorRTModule from a relay expression by serializing the expression to a
- * json representation. TensorRT is not required here because use of TensorRT APIs is deferred until
- * runtime.
+ * \brief Collect the constants and attributes from all operator calls in the body
+ * of a "Composite" function.
  */
-class TensorRTJSONSerializer : public backend::contrib::JSONSerializer {
-  using JSONGraphNode = tvm::runtime::json::JSONGraphNode;
-  using JSONGraphNodeEntry = tvm::runtime::json::JSONGraphNodeEntry;
-
+class CollectFromCompositeFunctionBody : public ExprVisitor {
  public:
-  TensorRTJSONSerializer(const std::string& symbol, const Expr& expr)
-      : JSONSerializer(symbol, expr) {}
-
-  std::vector<JSONGraphNodeEntry> VisitExpr_(const CallNode* cn) {
-    std::string name;
-    if (const auto* op_node = cn->op.as<OpNode>()) {
-      name = op_node->name;
-    } else {
-      return JSONSerializer::VisitExpr_(cn);
-    }
+  explicit CollectFromCompositeFunctionBody(TensorRTJSONSerializer* serializer)
+      : serializer_(serializer), node_(std::make_shared<JSONGraphNode>()) {}
 
-    std::vector<JSONGraphNodeEntry> inputs;
-    for (const auto& arg : cn->args) {
-      auto res = VisitExpr(arg);
-      inputs.insert(inputs.end(), res.begin(), res.end());
-    }
-    auto node = std::make_shared<JSONGraphNode>(name,     /* name_ */
-                                                "kernel", /* op_type_ */
-                                                inputs, 1 /* num_outputs_ */);
-    if (name == "nn.pad") {
-      SetPadNodeAttribute(node, cn);
-    } else if (name == "strided_slice") {
-      SetStridedSliceNodeAttribute(node, cn);
-    } else if (name == "split") {
-      SetSplitNodeAttribute(node, cn);
-    } else {
-      SetCallNodeAttribute(node, cn);
-    }
-    // These attributes are global to the whole module.
-    SaveGlobalAttributes(node);
-    return AddNode(node, GetRef<Expr>(cn));
-  }
+  void VisitExpr_(const ConstantNode* constant_node) final;
+  void VisitExpr_(const CallNode* call_node) final;
 
-  void SetPadNodeAttribute(std::shared_ptr<JSONGraphNode> node, const CallNode* cn) {
-    const auto* pad_attr = cn->attrs.as<PadAttrs>();
+  void SetPadNodeAttribute(const CallNode* call_node) {
+    const auto* pad_attr = call_node->attrs.as<PadAttrs>();
     ICHECK(pad_attr);
     auto p = pad_attr->pad_width;
     const int dim_h = (p.size() == 5) ? 3 : 2;
@@ -125,16 +102,16 @@ class TensorRTJSONSerializer : public backend::contrib::JSONSerializer {
                                         std::to_string(p[dim_w][1].as<IntImmNode>()->value)};
     std::vector<dmlc::any> padding_attr;
     padding_attr.emplace_back(padding);
-    node->SetAttr("padding", padding_attr);
+    node_->SetAttr("padding", padding_attr);
   }
 
-  void SetStridedSliceNodeAttribute(std::shared_ptr<JSONGraphNode> node, const CallNode* cn) {
-    const auto* attrs = cn->attrs.as<StridedSliceAttrs>();
+  void SetStridedSliceNodeAttribute(const CallNode* call_node) {
+    const auto* attrs = call_node->attrs.as<StridedSliceAttrs>();
     ICHECK(attrs && attrs->begin && attrs->end && attrs->strides)
         << "StridedSlice must have static begin, end, and strides.";
     const bool default_strides =
         !attrs->strides.value().defined() || attrs->strides.value().size() == 0;
-    auto ishape = backend::GetShape(cn->args[0]->checked_type());
+    auto ishape = backend::GetShape(call_node->args[0]->checked_type());
 
     auto process_slice_index = [](Integer x, int default_value, int dim_value) {
       if (!x.defined()) return default_value;
@@ -173,19 +150,19 @@ class TensorRTJSONSerializer : public backend::contrib::JSONSerializer {
     start_attr.emplace_back(start);
     size_attr.emplace_back(size);
     strides_attr.emplace_back(strides);
-    node->SetAttr("start", start_attr);
-    node->SetAttr("size", size_attr);
-    node->SetAttr("strides", strides_attr);
+    node_->SetAttr("start", start_attr);
+    node_->SetAttr("size", size_attr);
+    node_->SetAttr("strides", strides_attr);
   }
 
-  void SetSplitNodeAttribute(std::shared_ptr<JSONGraphNode> node, const CallNode* cn) {
-    const auto* split_attr = cn->attrs.as<SplitAttrs>();
+  void SetSplitNodeAttribute(const CallNode* call_node) {
+    const auto* split_attr = call_node->attrs.as<SplitAttrs>();
     ICHECK(split_attr);
 
     std::vector<std::string> indices_or_sections;
     std::vector<std::string> mode;
     std::vector<std::string> axis = {std::to_string(split_attr->axis)};
-    if (const IntImmNode* sections = split_attr->indices_or_sections.as<IntImmNode>()) {
+    if (const auto* sections = split_attr->indices_or_sections.as<IntImmNode>()) {
       mode.emplace_back("sections");
       indices_or_sections.emplace_back(std::to_string(sections->value));
     } else {
@@ -202,12 +179,80 @@ class TensorRTJSONSerializer : public backend::contrib::JSONSerializer {
     indices_or_sections_attr.emplace_back(indices_or_sections);
     mode_attr.emplace_back(mode);
     axis_attr.emplace_back(axis);
-    node->SetAttr("indices_or_sections", indices_or_sections_attr);
-    node->SetAttr("mode", mode_attr);
-    node->SetAttr("axis", axis_attr);
+    node_->SetAttr("indices_or_sections", indices_or_sections_attr);
+    node_->SetAttr("mode", mode_attr);
+    node_->SetAttr("axis", axis_attr);
+  }
+
+  void SetGenericAttributes(const CallNode* call_node) {
+    OpAttrExtractor extractor(node_);
+    const Object* attr_obj = call_node->attrs.get();
+    extractor.Extract(const_cast<Object*>(attr_obj));
   }
 
-  void SaveGlobalAttributes(std::shared_ptr<JSONGraphNode> node) {
+  TensorRTJSONSerializer* serializer_;
+  /*! \brief Accumulated translated arguments. */
+  std::vector<JSONGraphNodeEntry> args_;
+  /*!
+   * \brief Temporary node into which we'll accumulate attributes. Ideally this would be the
+   * final JSONGraphNode; however, we don't yet know how many inputs it will have.
+   */
+  JSONGraphObjectPtr node_;
+};
+
+/*!
+ * \brief Generates a TensorRTModule from a relay expression by serializing the expression to a
+ * json representation. TensorRT is not required here because use of TensorRT APIs is deferred until
+ * runtime.
+ */
+class TensorRTJSONSerializer : public JSONSerializer {
+ public:
+  TensorRTJSONSerializer(const std::string& symbol, const Expr& expr)
+      : JSONSerializer(symbol, expr) {}
+
+  using JSONSerializer::VisitExpr_;
+
+  std::vector<JSONGraphNodeEntry> VisitExpr_(const CallNode* call_node) final {
+    // The call must be to an inline "Composite" function
+    const auto* function_node = call_node->op.as<FunctionNode>();
+    ICHECK(function_node != nullptr);
+    auto opt_composite = function_node->GetAttr<String>(attr::kComposite);
+    ICHECK(opt_composite.defined());
+    std::string name = opt_composite.value();
+
+    // Collect the constants and attributes of all operator calls inside the composite body.
+    CollectFromCompositeFunctionBody collector(this);
+    collector.VisitExpr(function_node->body);
+
+    // Capture the args to the "Composite" function as inputs for this node.
+    std::vector<JSONGraphNodeEntry> inputs;
+    for (const auto& arg : call_node->args) {
+      auto res = VisitExpr(arg);
+      inputs.insert(inputs.end(), res.begin(), res.end());
+    }
+
+    // Capture constants from the composite function body as additional inputs for this node.
+    for (const auto& node : collector.args_) {
+      inputs.emplace_back(node);
+    }
+
+    // Create the final node.
+    auto node = std::make_shared<JSONGraphNode>(name,
+                                                /*op_type=*/"kernel", inputs,
+                                                /*num_output=*/1);
+
+    // Transfer attributes from the collector's node to the final node.
+    node->CaptureAttrs(*collector.node_);
+
+    // Capture global settings on the JSON node.
+    SaveGlobalAttributes(node);
+
+    VLOG(1) << name << " has " << node->GetInputs().size() << " inputs";
+
+    return AddNode(node, GetRef<Expr>(call_node));
+  }
+
+  static void SaveGlobalAttributes(std::shared_ptr<JSONGraphNode> node) {
     auto ctx = transform::PassContext::Current();
     auto cfg = ctx->GetConfig<TensorRTCompilerConfig>("relay.ext.tensorrt.options");
     if (!cfg.defined()) {
@@ -236,6 +281,28 @@ class TensorRTJSONSerializer : public backend::contrib::JSONSerializer {
   }
 };
 
+void CollectFromCompositeFunctionBody::VisitExpr_(const ConstantNode* constant_node) {
+  for (const auto& entry : serializer_->VisitExpr(GetRef<Constant>(constant_node))) {
+    args_.emplace_back(entry);
+  }
+}
+
+void CollectFromCompositeFunctionBody::VisitExpr_(const CallNode* call_node) {
+  const auto* op_node = call_node->op.as<OpNode>();
+  ICHECK(op_node != nullptr);
+  std::string name = op_node->name;
+  if (name == "nn.pad") {
+    SetPadNodeAttribute(call_node);
+  } else if (name == "strided_slice") {
+    SetStridedSliceNodeAttribute(call_node);
+  } else if (name == "split") {
+    SetSplitNodeAttribute(call_node);
+  } else {
+    SetGenericAttributes(call_node);
+  }
+  ExprVisitor::VisitExpr_(call_node);
+}
+
 /*!
  * \brief Create a runtime module for TensorRT.
  * \param ref The ext_func Relay expression/module to be executed using extern ops.
@@ -246,12 +313,15 @@ runtime::Module TensorRTCompiler(const ObjectRef& ref) {
   Function func = Downcast<Function>(ref);
   std::string func_name = backend::GetExtSymbol(func);
 
+  VLOG(1) << "TensorRT partition:" << std::endl << PrettyPrint(func);
   TensorRTJSONSerializer serializer(func_name, func);
   serializer.serialize();
   std::string graph_json = serializer.GetJSON();
+  VLOG(1) << "TensorRT JSON:" << std::endl << graph_json;
   auto param_names = serializer.GetParams();
   const auto* pf = runtime::Registry::Get("runtime.tensorrt_runtime_create");
   ICHECK(pf != nullptr) << "Cannot find TensorRT runtime module create function.";
+  VLOG(1) << "Creating tensorrt runtime::Module for '" << func_name << "'";
   runtime::Module lib = (*pf)(func_name, graph_json, param_names);
   return lib;
 }
diff --git a/src/relay/transforms/inline_composites.cc b/src/relay/transforms/inline_composites.cc
deleted file mode 100644
index daa82816dd..0000000000
--- a/src/relay/transforms/inline_composites.cc
+++ /dev/null
@@ -1,94 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \file src/relay/transforms/inline_composites.cc
- * \brief Undo the partioned graphs originate from merge composite.
- */
-#include <tvm/relay/expr.h>
-#include <tvm/relay/expr_functor.h>
-#include <tvm/relay/transform.h>
-
-#include "../analysis/call_graph.h"
-#include "../op/call/call.h"
-
-using namespace tvm::runtime;
-
-namespace tvm {
-
-namespace relay {
-
-class CompositeInliner : public MixedModeMutator {
- public:
-  CompositeInliner() = default;
-
-  using MixedModeMutator::Rewrite_;
-
-  Expr Rewrite_(const CallNode* call_node, const Expr& post) final {
-    const auto* post_call_node = post.as<CallNode>();
-    Call vanilla_post_call = GetAnyCall(post_call_node);
-    if (const auto* function_node = vanilla_post_call->op.as<FunctionNode>()) {
-      if (function_node->GetAttr(attr::kComposite, Optional<String>()).defined()) {
-        // Is a call to a literal function with the "Composite" attribute.
-        // Inline the function body.
-        Map<Var, Expr> bind_map;
-        for (size_t i = 0; i < vanilla_post_call->args.size(); i++) {
-          bind_map.Set(function_node->params[i], vanilla_post_call->args[i]);
-        }
-        return Bind(function_node->body, bind_map);
-      }
-    }
-    return post;
-  }
-
-  Function Inline(const Function& func) {
-    return WithFields(func, /*opt_params=*/{}, VisitExpr(func->body));
-  }
-};
-
-IRModule InlineComposites(const IRModule& module, runtime::String target) {
-  IRModule out_mod = module->ShallowCopy();
-  for (const auto& kv : module->functions) {
-    Optional<String> opt_compiler = kv.second->GetAttr(attr::kCompiler, Optional<String>());
-    if (const auto* function_node = kv.second.as<FunctionNode>()) {
-      if (opt_compiler.defined() && opt_compiler.value() == target) {
-        // Is a global function with the "Compiler" attribute matching the desired target.
-        // Inline all "Composite" function calls in the body.
-        out_mod->Add(kv.first, CompositeInliner().Inline(GetRef<Function>(function_node)));
-      }
-    }
-  }
-  return out_mod;
-}
-
-namespace transform {
-
-Pass InlineComposites(runtime::String target) {
-  runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =
-      [=](IRModule m, PassContext pc) { return relay::InlineComposites(m, target); };
-  return CreateModulePass(pass_func, 0, "InlineComposites", {});
-}
-
-TVM_REGISTER_GLOBAL("relay._transform.InlineComposites").set_body_typed(InlineComposites);
-
-}  // namespace transform
-
-}  // namespace relay
-
-}  // namespace tvm
diff --git a/src/runtime/contrib/json/json_node.h b/src/runtime/contrib/json/json_node.h
index 77c289b04c..1a8d09cbba 100644
--- a/src/runtime/contrib/json/json_node.h
+++ b/src/runtime/contrib/json/json_node.h
@@ -281,6 +281,12 @@ class JSONGraphNode {
    */
   bool HasAttr(const std::string& key) const { return attrs_.find(key) != attrs_.end(); }
 
+  void CaptureAttrs(const JSONGraphNode& that) {
+    for (const auto& kv : that.attrs_) {
+      attrs_[kv.first] = kv.second;
+    }
+  }
+
   virtual ~JSONGraphNode() {}
 
  private:
diff --git a/src/runtime/contrib/tensorrt/tensorrt_builder.cc b/src/runtime/contrib/tensorrt/tensorrt_builder.cc
index 4f196265b5..5f923667d0 100644
--- a/src/runtime/contrib/tensorrt/tensorrt_builder.cc
+++ b/src/runtime/contrib/tensorrt/tensorrt_builder.cc
@@ -71,6 +71,12 @@ TensorRTBuilder::TensorRTBuilder(TensorRTLogger* logger,
 #endif
 }
 
+nvinfer1::DataType DLDataType2NVDataType(DLDataType data_type) {
+  ICHECK(data_type.code == kDLFloat && (data_type.bits == 16 || data_type.bits == 32))
+      << "Invalid input Tensor type. Only float16 and float32 are supported";
+  return (data_type.bits == 16) ? nvinfer1::DataType::kHALF : nvinfer1::DataType::kFLOAT;
+}
+
 void TensorRTBuilder::AddInput(int nid, uint32_t entry_id, const JSONGraphNode& node) {
   auto node_name = node.GetOpName();
   auto shapes = node.GetOpShape();
@@ -85,13 +91,7 @@ void TensorRTBuilder::AddInput(int nid, uint32_t entry_id, const JSONGraphNode&
       shape.erase(shape.begin());
     }
     nvinfer1::Dims dims = VectorToTrtDims(shape);
-    ICHECK((dtypes[i].bits != 16 || dtypes[i].bits != 32))
-        << "Invalid input Tensor type. Float16 and Float32 are supported";
-
-    auto tensor_dtype =
-        (dtypes[i].bits == 16) ? nvinfer1::DataType::kHALF : nvinfer1::DataType::kFLOAT;
-
-    auto input_tensor = network_->addInput(name.c_str(), tensor_dtype, dims);
+    auto input_tensor = network_->addInput(name.c_str(), DLDataType2NVDataType(dtypes[i]), dims);
     node_output_map_[nid].push_back(TensorRTOpInput(input_tensor));
     network_input_names_.push_back(name);
     entry_id_map_[name] = entry_id + i;
@@ -124,40 +124,43 @@ void TensorRTBuilder::AddOutput(const JSONGraphNodeEntry& node, uint32_t entry_i
 }
 
 void TensorRTBuilder::AddLayer(int nid, const JSONGraphNode& node) {
-  TensorRTOpConverterParams params(network_, node, &trt_weights_);
+  TensorRTOpConverterParams params(network_, nid, node, &trt_weights_);
   // Look up converter.
-  auto it = GetOpConverters()->find(params.op_name);
-  ICHECK(it != GetOpConverters()->end())
-      << "Unsupported operator conversion to TRT, op name: " << params.op_name;
-  const auto converter = it->second;
+  const std::unordered_map<std::string, std::unique_ptr<TensorRTOpConverter>>& map =
+      GetOpConverters();
+  auto it = map.find(params.op_name);
+  ICHECK(it != map.end()) << params.op_name << ": Unsupported operator";
+  const TensorRTOpConverter& converter = *it->second;
+  if (!converter.variable_input_count) {
+    ICHECK_EQ(node.GetInputs().size(), converter.input_types.size())
+        << params.op_name << ": Mismatched input sizes";
+  }
   // Get inputs.
   for (size_t i = 0; i < node.GetInputs().size(); ++i) {
     auto in_node = node.GetInputs()[i];
     auto it = node_output_map_.find(in_node.id_);
-    ICHECK(it != node_output_map_.end()) << "Input was not found.";
+    ICHECK(it != node_output_map_.end()) << params.op_name << ": Input was not found";
     auto input = it->second[in_node.index_];
-    if (!converter->variable_input_count) {
-      if (converter->input_types[i] == kTensor && input.type == kWeight) {
+    if (!converter.variable_input_count) {
+      if (converter.input_types[i] == kTensor && input.type == kWeight) {
         input = TensorRTOpInput(GetInputAsTensor(input));
-      } else if (converter->input_types[i] == kWeight && input.type == kTensor) {
-        LOG(FATAL) << "Input " << i << " for " << params.op_name
-                   << " requires weights but got a tensor.";
+      } else if (converter.input_types[i] == kWeight && input.type == kTensor) {
+        LOG(FATAL) << params.op_name << ": Input " << i << " must be a constant.";
       }
     }
     params.inputs.push_back(input);
   }
 
   // Convert op to TRT.
-  converter->Convert(&params);
+  converter.Convert(&params);
 
   // Get outputs.
   node_output_map_[nid] = {};
-  for (auto out : params.outputs) {
-    auto out_type = params.inputs.at(1).weight.type == params.inputs.at(0).tensor->getType()
-                        ? params.inputs.at(0).tensor->getType()
-                        : params.inputs.at(1).weight.type;
-    out->setType(out_type);
-
+  std::vector<DLDataType> dtype = node.GetOpDataType();
+  ICHECK_EQ(params.outputs.size(), dtype.size()) << params.op_name << ": Mismatched output sizes";
+  for (size_t i = 0; i < params.outputs.size(); ++i) {
+    auto out = params.outputs[i];
+    out->setType(DLDataType2NVDataType(dtype[i]));
     node_output_map_[nid].push_back(TensorRTOpInput(out));
   }
 }
diff --git a/src/runtime/contrib/tensorrt/tensorrt_calibrator.h b/src/runtime/contrib/tensorrt/tensorrt_calibrator.h
index 58bfcc248f..523676b947 100755
--- a/src/runtime/contrib/tensorrt/tensorrt_calibrator.h
+++ b/src/runtime/contrib/tensorrt/tensorrt_calibrator.h
@@ -80,7 +80,7 @@ class TensorRTCalibrator : public nvinfer1::IInt8EntropyCalibrator2 {
     }
     num_batches_calibrated_++;
     // TODO(trevmorr): Free data from previous batch?
-    return (num_batches_calibrated_ < data_.size());
+    return (num_batches_calibrated_ < static_cast<int>(data_.size()));
   }
 
   const void* readCalibrationCache(size_t& length) noexcept override {
diff --git a/src/runtime/contrib/tensorrt/tensorrt_ops.cc b/src/runtime/contrib/tensorrt/tensorrt_ops.cc
index e7e83bf984..3971081bf8 100644
--- a/src/runtime/contrib/tensorrt/tensorrt_ops.cc
+++ b/src/runtime/contrib/tensorrt/tensorrt_ops.cc
@@ -29,6 +29,7 @@
 #include <memory>
 #include <string>
 #include <unordered_map>
+#include <utility>
 #include <vector>
 
 #include "../json/json_node.h"
@@ -39,9 +40,12 @@ namespace tvm {
 namespace runtime {
 namespace contrib {
 
-TensorRTOpConverter::TensorRTOpConverter(const std::vector<TensorRTInputType>& input_types,
+TensorRTOpConverter::TensorRTOpConverter(std::string op_name,
+                                         const std::vector<TensorRTInputType>& input_types,
                                          bool variable_input_count)
-    : input_types(input_types), variable_input_count(variable_input_count) {}
+    : op_name(std::move(op_name)),
+      input_types(input_types),
+      variable_input_count(variable_input_count) {}
 
 nvinfer1::ITensor* TensorRTOpConverter::Reshape(TensorRTOpConverterParams* params,
                                                 nvinfer1::ITensor* input,
@@ -156,7 +160,9 @@ void TensorRTOpConverter::GetPadding3D(const std::vector<std::string>& padding,
 
 class ActivationOpConverter : public TensorRTOpConverter {
  public:
-  ActivationOpConverter() : TensorRTOpConverter({kTensor}) {}
+  explicit ActivationOpConverter(std::string op_name)
+      : TensorRTOpConverter(std::move(op_name), {kTensor}) {}
+  ~ActivationOpConverter() = default;
 
   void Convert(TensorRTOpConverterParams* params) const {
     static const std::unordered_map<std::string, nvinfer1::ActivationType> op_map = {
@@ -168,17 +174,17 @@ class ActivationOpConverter : public TensorRTOpConverter {
       {"nn.leaky_relu", nvinfer1::ActivationType::kLEAKY_RELU},
 #endif
     };
-    auto it = op_map.find(params->op_name);
-    ICHECK(it != op_map.end()) << "Unsupported activation type " << params->op_name;
+    auto it = op_map.find(op_name);
+    ICHECK(it != op_map.end()) << "Unsupported activation type " << op_name;
     nvinfer1::IActivationLayer* act_layer =
         params->network->addActivation(*params->inputs.at(0).tensor, it->second);
 #if TRT_VERSION_GE(5, 1, 5)
-    if (params->op_name == "clip") {
+    if (op_name == "clip") {
       float a_min = std::stof(params->node.GetAttr<std::vector<std::string>>("a_min")[0]);
       float a_max = std::stof(params->node.GetAttr<std::vector<std::string>>("a_max")[0]);
       act_layer->setAlpha(a_min);
       act_layer->setBeta(a_max);
-    } else if (params->op_name == "nn.leaky_relu") {
+    } else if (op_name == "nn.leaky_relu") {
       float alpha = std::stof(params->node.GetAttr<std::vector<std::string>>("alpha")[0]);
       act_layer->setAlpha(alpha);
     }
@@ -190,7 +196,9 @@ class ActivationOpConverter : public TensorRTOpConverter {
 
 class ElementWiseBinaryOpConverter : public TensorRTOpConverter {
  public:
-  ElementWiseBinaryOpConverter() : TensorRTOpConverter({kTensor, kTensor}) {}
+  explicit ElementWiseBinaryOpConverter(std::string op_name)
+      : TensorRTOpConverter(std::move(op_name), {kTensor, kTensor}) {}
+  ~ElementWiseBinaryOpConverter() = default;
 
   void Convert(TensorRTOpConverterParams* params) const {
     static const std::unordered_map<std::string, nvinfer1::ElementWiseOperation> op_map = {
@@ -201,8 +209,8 @@ class ElementWiseBinaryOpConverter : public TensorRTOpConverter {
         {"power", nvinfer1::ElementWiseOperation::kPOW},
         {"maximum", nvinfer1::ElementWiseOperation::kMAX},
         {"minimum", nvinfer1::ElementWiseOperation::kMIN}};
-    auto it = op_map.find(params->op_name);
-    ICHECK(it != op_map.end()) << "Unsupported elementwise type " << params->op_name;
+    auto it = op_map.find(op_name);
+    ICHECK(it != op_map.end()) << "Unsupported elementwise type " << op_name;
     // Broadcast
     auto input0 = params->inputs.at(0).tensor;
     auto input0_dims = TrtDimsToVector(input0->getDimensions());
@@ -230,7 +238,9 @@ class ElementWiseBinaryOpConverter : public TensorRTOpConverter {
 
 class Conv1DOpConverter : public TensorRTOpConverter {
  public:
-  Conv1DOpConverter() : TensorRTOpConverter({kTensor, kWeight}) {}
+  explicit Conv1DOpConverter(std::string op_name)
+      : TensorRTOpConverter(std::move(op_name), {kTensor, kWeight}) {}
+  ~Conv1DOpConverter() = default;
 
   void Convert(TensorRTOpConverterParams* params) const {
     auto input_tensor = params->inputs.at(0).tensor;
@@ -281,7 +291,9 @@ class Conv1DOpConverter : public TensorRTOpConverter {
 
 class Conv2DOpConverter : public TensorRTOpConverter {
  public:
-  Conv2DOpConverter() : TensorRTOpConverter({kTensor, kWeight}) {}
+  explicit Conv2DOpConverter(std::string op_name)
+      : TensorRTOpConverter(std::move(op_name), {kTensor, kWeight}) {}
+  ~Conv2DOpConverter() = default;
 
   void Convert(TensorRTOpConverterParams* params) const {
     auto input_tensor = params->inputs.at(0).tensor;
@@ -322,6 +334,7 @@ class Conv2DOpConverter : public TensorRTOpConverter {
     auto conv_layer = params->network->addConvolution(*input_tensor, channels, kernel_size,
                                                       params->inputs.at(1).weight, bias);
     ICHECK(conv_layer != nullptr);
+    conv_layer->setName(params->LayerName().c_str());
     if (use_asymmetric_padding) {
 #if TRT_VERSION_GE(5, 1, 5)
       conv_layer->setPrePadding(prepadding);
@@ -344,7 +357,9 @@ class Conv2DOpConverter : public TensorRTOpConverter {
 #if TRT_VERSION_GE(6, 0, 1)
 class Conv3DOpConverter : public TensorRTOpConverter {
  public:
-  Conv3DOpConverter() : TensorRTOpConverter({kTensor, kWeight}) {}
+  explicit Conv3DOpConverter(std::string op_name)
+      : TensorRTOpConverter(std::move(op_name), {kTensor, kWeight}) {}
+  ~Conv3DOpConverter() = default;
 
   void Convert(TensorRTOpConverterParams* params) const {
     auto input_tensor = params->inputs.at(0).tensor;
@@ -393,7 +408,9 @@ class Conv3DOpConverter : public TensorRTOpConverter {
 
 class DenseOpConverter : public TensorRTOpConverter {
  public:
-  DenseOpConverter() : TensorRTOpConverter({kTensor, kWeight}) {}
+  explicit DenseOpConverter(std::string op_name)
+      : TensorRTOpConverter(std::move(op_name), {kTensor, kWeight}) {}
+  ~DenseOpConverter() = default;
 
   void Convert(TensorRTOpConverterParams* params) const {
     auto input_tensor = params->inputs.at(0).tensor;
@@ -427,7 +444,9 @@ class DenseOpConverter : public TensorRTOpConverter {
 
 class BatchNormOpConverter : public TensorRTOpConverter {
  public:
-  BatchNormOpConverter() : TensorRTOpConverter({kTensor, kWeight, kWeight, kWeight, kWeight}) {}
+  explicit BatchNormOpConverter(std::string op_name)
+      : TensorRTOpConverter(std::move(op_name), {kTensor, kWeight, kWeight, kWeight, kWeight}) {}
+  ~BatchNormOpConverter() = default;
 
   void Convert(TensorRTOpConverterParams* params) const {
     auto input = params->inputs.at(0).tensor;
@@ -524,7 +543,9 @@ class BatchNormOpConverter : public TensorRTOpConverter {
 
 class LayerNormOpConverter : public TensorRTOpConverter {
  public:
-  LayerNormOpConverter() : TensorRTOpConverter({kTensor, kWeight, kWeight}) {}
+  explicit LayerNormOpConverter(std::string op_name)
+      : TensorRTOpConverter(std::move(op_name), {kTensor, kWeight, kWeight}) {}
+  ~LayerNormOpConverter() = default;
 
   void Convert(TensorRTOpConverterParams* params) const {
     auto input = params->inputs.at(0).tensor;
@@ -596,7 +617,9 @@ class LayerNormOpConverter : public TensorRTOpConverter {
 
 class BatchFlattenOpConverter : public TensorRTOpConverter {
  public:
-  BatchFlattenOpConverter() : TensorRTOpConverter({kTensor}) {}
+  explicit BatchFlattenOpConverter(std::string op_name)
+      : TensorRTOpConverter(std::move(op_name), {kTensor}) {}
+  ~BatchFlattenOpConverter() = default;
 
   void Convert(TensorRTOpConverterParams* params) const {
     std::vector<int> new_shape{-1};
@@ -609,7 +632,9 @@ class BatchFlattenOpConverter : public TensorRTOpConverter {
 
 class SoftmaxOpConverter : public TensorRTOpConverter {
  public:
-  SoftmaxOpConverter() : TensorRTOpConverter({kTensor}) {}
+  explicit SoftmaxOpConverter(std::string op_name)
+      : TensorRTOpConverter(std::move(op_name), {kTensor}) {}
+  ~SoftmaxOpConverter() = default;
 
   void Convert(TensorRTOpConverterParams* params) const {
     auto input = params->inputs.at(0).tensor;
@@ -625,15 +650,17 @@ class SoftmaxOpConverter : public TensorRTOpConverter {
 
 class PoolingOpConverter : public TensorRTOpConverter {
  public:
-  PoolingOpConverter() : TensorRTOpConverter({kTensor}) {}
+  explicit PoolingOpConverter(std::string op_name)
+      : TensorRTOpConverter(std::move(op_name), {kTensor}) {}
+  ~PoolingOpConverter() = default;
 
   void Convert(TensorRTOpConverterParams* params) const {
     auto input = params->inputs.at(0).tensor;
     static const std::unordered_map<std::string, nvinfer1::PoolingType> op_map = {
         {"nn.max_pool2d", nvinfer1::PoolingType::kMAX},
         {"nn.avg_pool2d", nvinfer1::PoolingType::kAVERAGE}};
-    auto it = op_map.find(params->op_name);
-    ICHECK(it != op_map.end()) << "Unsupported pooling type " << params->op_name << " in TensorRT";
+    auto it = op_map.find(op_name);
+    ICHECK(it != op_map.end()) << "Unsupported pooling type " << op_name << " in TensorRT";
     ICHECK_EQ(params->node.GetAttr<std::vector<std::string>>("layout")[0], "NCHW");
     auto str_pool_size = params->node.GetAttr<std::vector<std::string>>("pool_size");
     auto str_padding = params->node.GetAttr<std::vector<std::string>>("padding");
@@ -671,7 +698,7 @@ class PoolingOpConverter : public TensorRTOpConverter {
     } else {
       pool_layer->setPadding(prepadding);
     }
-    if (params->op_name == "nn.avg_pool2d") {
+    if (op_name == "nn.avg_pool2d") {
       bool count_include_pad =
           std::stoi(params->node.GetAttr<std::vector<std::string>>("count_include_pad")[0]);
       // count_include_pad=True is useless if there is no padding. TRT doesn't
@@ -698,15 +725,17 @@ class PoolingOpConverter : public TensorRTOpConverter {
 #if TRT_VERSION_GE(6, 0, 1)
 class Pooling3DOpConverter : public TensorRTOpConverter {
  public:
-  Pooling3DOpConverter() : TensorRTOpConverter({kTensor}) {}
+  explicit Pooling3DOpConverter(std::string op_name)
+      : TensorRTOpConverter(std::move(op_name), {kTensor}) {}
+  ~Pooling3DOpConverter() = default;
 
   void Convert(TensorRTOpConverterParams* params) const {
     auto input = params->inputs.at(0).tensor;
     static const std::unordered_map<std::string, nvinfer1::PoolingType> op_map = {
         {"nn.max_pool3d", nvinfer1::PoolingType::kMAX},
         {"nn.avg_pool3d", nvinfer1::PoolingType::kAVERAGE}};
-    auto it = op_map.find(params->op_name);
-    ICHECK(it != op_map.end()) << "Unsupported pooling type " << params->op_name << " in TensorRT";
+    auto it = op_map.find(op_name);
+    ICHECK(it != op_map.end()) << "Unsupported pooling type " << op_name << " in TensorRT";
     ICHECK_EQ(params->node.GetAttr<std::vector<std::string>>("layout")[0], "NCDHW");
     auto str_pool_size = params->node.GetAttr<std::vector<std::string>>("pool_size");
     auto str_padding = params->node.GetAttr<std::vector<std::string>>("padding");
@@ -728,7 +757,7 @@ class Pooling3DOpConverter : public TensorRTOpConverter {
     } else {
       pool_layer->setPaddingNd(prepadding);
     }
-    if (params->op_name == "nn.avg_pool3d") {
+    if (op_name == "nn.avg_pool3d") {
       bool count_include_pad =
           std::stoi(params->node.GetAttr<std::vector<std::string>>("count_include_pad")[0]);
       pool_layer->setAverageCountExcludesPadding(!count_include_pad);
@@ -743,7 +772,9 @@ class Pooling3DOpConverter : public TensorRTOpConverter {
 
 class GlobalPoolingOpConverter : public TensorRTOpConverter {
  public:
-  GlobalPoolingOpConverter() : TensorRTOpConverter({kTensor}) {}
+  explicit GlobalPoolingOpConverter(std::string op_name)
+      : TensorRTOpConverter(std::move(op_name), {kTensor}) {}
+  ~GlobalPoolingOpConverter() = default;
 
   void Convert(TensorRTOpConverterParams* params) const {
     auto input_tensor = params->inputs.at(0).tensor;
@@ -751,8 +782,8 @@ class GlobalPoolingOpConverter : public TensorRTOpConverter {
     static const std::unordered_map<std::string, nvinfer1::PoolingType> op_map = {
         {"nn.global_max_pool2d", nvinfer1::PoolingType::kMAX},
         {"nn.global_avg_pool2d", nvinfer1::PoolingType::kAVERAGE}};
-    auto it = op_map.find(params->op_name);
-    ICHECK(it != op_map.end()) << "Unsupported pooling type " << params->op_name << " in TensorRT";
+    auto it = op_map.find(op_name);
+    ICHECK(it != op_map.end()) << "Unsupported pooling type " << op_name << " in TensorRT";
     ICHECK_EQ(params->node.GetAttr<std::vector<std::string>>("layout")[0], "NCHW");
     const int h = TRT_HAS_IMPLICIT_BATCH(params) ? input_dims[1] : input_dims[2];
     const int w = TRT_HAS_IMPLICIT_BATCH(params) ? input_dims[2] : input_dims[3];
@@ -765,7 +796,9 @@ class GlobalPoolingOpConverter : public TensorRTOpConverter {
 
 class ExpandDimsOpConverter : public TensorRTOpConverter {
  public:
-  ExpandDimsOpConverter() : TensorRTOpConverter({kTensor}) {}
+  explicit ExpandDimsOpConverter(std::string op_name)
+      : TensorRTOpConverter(std::move(op_name), {kTensor}) {}
+  ~ExpandDimsOpConverter() = default;
 
   void Convert(TensorRTOpConverterParams* params) const {
     auto input_tensor = params->inputs.at(0).tensor;
@@ -783,7 +816,9 @@ class ExpandDimsOpConverter : public TensorRTOpConverter {
 
 class SqueezeOpConverter : public TensorRTOpConverter {
  public:
-  SqueezeOpConverter() : TensorRTOpConverter({kTensor}) {}
+  explicit SqueezeOpConverter(std::string op_name)
+      : TensorRTOpConverter(std::move(op_name), {kTensor}) {}
+  ~SqueezeOpConverter() = default;
 
   void Convert(TensorRTOpConverterParams* params) const {
     auto input_tensor = params->inputs.at(0).tensor;
@@ -800,7 +835,9 @@ class SqueezeOpConverter : public TensorRTOpConverter {
 
 class UnaryOpConverter : public TensorRTOpConverter {
  public:
-  UnaryOpConverter() : TensorRTOpConverter({kTensor}) {}
+  explicit UnaryOpConverter(std::string op_name)
+      : TensorRTOpConverter(std::move(op_name), {kTensor}) {}
+  ~UnaryOpConverter() = default;
 
   void Convert(TensorRTOpConverterParams* params) const {
     // The following ops are supported by TRT but don't exist in relay yet:
@@ -822,8 +859,8 @@ class UnaryOpConverter : public TensorRTOpConverter {
       {"erf", nvinfer1::UnaryOperation::kERF},
 #endif
     };
-    auto it = op_map.find(params->op_name);
-    ICHECK(it != op_map.end()) << "Unsupported unary type " << params->op_name;
+    auto it = op_map.find(op_name);
+    ICHECK(it != op_map.end()) << "Unsupported unary type " << op_name;
     nvinfer1::IUnaryLayer* unary_layer =
         params->network->addUnary(*params->inputs.at(0).tensor, it->second);
     ICHECK(unary_layer != nullptr);
@@ -833,7 +870,9 @@ class UnaryOpConverter : public TensorRTOpConverter {
 
 class ConcatOpConverter : public TensorRTOpConverter {
  public:
-  ConcatOpConverter() : TensorRTOpConverter({}, /*variable_input_count=*/true) {}
+  explicit ConcatOpConverter(std::string op_name)
+      : TensorRTOpConverter(std::move(op_name), {}, /*variable_input_count=*/true) {}
+  ~ConcatOpConverter() = default;
 
   void Convert(TensorRTOpConverterParams* params) const {
     const int num_inputs = params->inputs.size();
@@ -860,7 +899,9 @@ class ConcatOpConverter : public TensorRTOpConverter {
 #if TRT_VERSION_GE(5, 1, 5)
 class SplitOpConverter : public TensorRTOpConverter {
  public:
-  SplitOpConverter() : TensorRTOpConverter({kTensor}) {}
+  explicit SplitOpConverter(std::string op_name)
+      : TensorRTOpConverter(std::move(op_name), {kTensor}) {}
+  ~SplitOpConverter() = default;
 
   void Convert(TensorRTOpConverterParams* params) const {
     auto input = params->inputs.at(0).tensor;
@@ -908,7 +949,9 @@ class SplitOpConverter : public TensorRTOpConverter {
 
 class BiasAddOpConverter : public TensorRTOpConverter {
  public:
-  BiasAddOpConverter() : TensorRTOpConverter({kTensor, kWeight}) {}
+  explicit BiasAddOpConverter(std::string op_name)
+      : TensorRTOpConverter(std::move(op_name), {kTensor, kWeight}) {}
+  ~BiasAddOpConverter() = default;
 
   void Convert(TensorRTOpConverterParams* params) const {
     auto input_tensor = params->inputs.at(0).tensor;
@@ -941,7 +984,9 @@ class BiasAddOpConverter : public TensorRTOpConverter {
 
 class Conv2DTransposeOpConverter : public TensorRTOpConverter {
  public:
-  Conv2DTransposeOpConverter() : TensorRTOpConverter({kTensor, kWeight}) {}
+  explicit Conv2DTransposeOpConverter(std::string op_name)
+      : TensorRTOpConverter(std::move(op_name), {kTensor, kWeight}) {}
+  ~Conv2DTransposeOpConverter() = default;
 
   void Convert(TensorRTOpConverterParams* params) const {
     auto input_tensor = params->inputs.at(0).tensor;
@@ -1011,7 +1056,9 @@ class Conv2DTransposeOpConverter : public TensorRTOpConverter {
 #if TRT_VERSION_GE(6, 0, 1)
 class Conv3DTransposeOpConverter : public TensorRTOpConverter {
  public:
-  Conv3DTransposeOpConverter() : TensorRTOpConverter({kTensor, kWeight}) {}
+  explicit Conv3DTransposeOpConverter(std::string op_name)
+      : TensorRTOpConverter(std::move(op_name), {kTensor, kWeight}) {}
+  ~Conv3DTransposeOpConverter() = default;
 
   void Convert(TensorRTOpConverterParams* params) const {
     auto input_tensor = params->inputs.at(0).tensor;
@@ -1067,7 +1114,9 @@ class Conv3DTransposeOpConverter : public TensorRTOpConverter {
 
 class TransposeOpConverter : public TensorRTOpConverter {
  public:
-  TransposeOpConverter() : TensorRTOpConverter({kTensor}) {}
+  explicit TransposeOpConverter(std::string op_name)
+      : TensorRTOpConverter(std::move(op_name), {kTensor}) {}
+  ~TransposeOpConverter() = default;
 
   void Convert(TensorRTOpConverterParams* params) const {
     auto input = params->inputs.at(0).tensor;
@@ -1082,7 +1131,9 @@ class TransposeOpConverter : public TensorRTOpConverter {
 
 class LayoutTransformOpConverter : public TensorRTOpConverter {
  public:
-  LayoutTransformOpConverter() : TensorRTOpConverter({kTensor}) {}
+  explicit LayoutTransformOpConverter(std::string op_name)
+      : TensorRTOpConverter(std::move(op_name), {kTensor}) {}
+  ~LayoutTransformOpConverter() = default;
 
   void Convert(TensorRTOpConverterParams* params) const {
     auto input = params->inputs.at(0).tensor;
@@ -1104,13 +1155,17 @@ class LayoutTransformOpConverter : public TensorRTOpConverter {
 
 class ReshapeOpConverter : public TensorRTOpConverter {
  public:
-  ReshapeOpConverter() : TensorRTOpConverter({kTensor}) {}
+  explicit ReshapeOpConverter(std::string op_name)
+      : TensorRTOpConverter(std::move(op_name), {kTensor}) {}
+  ~ReshapeOpConverter() = default;
 
   void Convert(TensorRTOpConverterParams* params) const {
     auto input = params->inputs.at(0).tensor;
+    auto input_dims = TrtDimsToVector(input->getDimensions());
     auto str_newshape = params->node.GetAttr<std::vector<std::string>>("newshape");
     std::vector<int> new_shape;
-    const int start_index = TRT_HAS_IMPLICIT_BATCH(params) ? 1 : 0;
+    int start_index = TRT_HAS_IMPLICIT_BATCH(params) ? 1 : 0;
+    if (std::stoi(str_newshape[0]) == -1) start_index = 0;
     for (size_t i = start_index; i < str_newshape.size(); ++i) {
       const int value = std::stoi(str_newshape[i]);
       ICHECK_GE(value, -1);
@@ -1122,7 +1177,9 @@ class ReshapeOpConverter : public TensorRTOpConverter {
 
 class PadOpConverter : public TensorRTOpConverter {
  public:
-  PadOpConverter() : TensorRTOpConverter({kTensor}) {}
+  explicit PadOpConverter(std::string op_name)
+      : TensorRTOpConverter(std::move(op_name), {kTensor, kIgnored}) {}
+  ~PadOpConverter() = default;
 
   void Convert(TensorRTOpConverterParams* params) const {
     auto input = params->inputs.at(0).tensor;
@@ -1138,7 +1195,9 @@ class PadOpConverter : public TensorRTOpConverter {
 
 class ReduceOpConverter : public TensorRTOpConverter {
  public:
-  ReduceOpConverter() : TensorRTOpConverter({kTensor}) {}
+  explicit ReduceOpConverter(std::string op_name)
+      : TensorRTOpConverter(std::move(op_name), {kTensor}) {}
+  ~ReduceOpConverter() = default;
 
   void Convert(TensorRTOpConverterParams* params) const {
     static const std::unordered_map<std::string, nvinfer1::ReduceOperation> op_map = {
@@ -1147,8 +1206,8 @@ class ReduceOpConverter : public TensorRTOpConverter {
         {"max", nvinfer1::ReduceOperation::kMAX},
         {"min", nvinfer1::ReduceOperation::kMIN},
         {"mean", nvinfer1::ReduceOperation::kAVG}};
-    auto it = op_map.find(params->op_name);
-    ICHECK(it != op_map.end()) << "Unsupported reduce type " << params->op_name;
+    auto it = op_map.find(op_name);
+    ICHECK(it != op_map.end()) << "Unsupported reduce type " << op_name;
 
     auto input = params->inputs.at(0).tensor;
     ICHECK_EQ(std::stoi(params->node.GetAttr<std::vector<std::string>>("exclude")[0]), false);
@@ -1177,7 +1236,9 @@ class ReduceOpConverter : public TensorRTOpConverter {
 #if TRT_VERSION_GE(5, 1, 5)
 class StridedSliceOpConverter : public TensorRTOpConverter {
  public:
-  StridedSliceOpConverter() : TensorRTOpConverter({kTensor}) {}
+  explicit StridedSliceOpConverter(std::string op_name)
+      : TensorRTOpConverter(std::move(op_name), {kTensor}) {}
+  ~StridedSliceOpConverter() = default;
 
   void Convert(TensorRTOpConverterParams* params) const {
     auto input = params->inputs.at(0).tensor;
@@ -1206,7 +1267,9 @@ class StridedSliceOpConverter : public TensorRTOpConverter {
 
 class AdaptivePoolingOpConverter : public TensorRTOpConverter {
  public:
-  AdaptivePoolingOpConverter() : TensorRTOpConverter({kTensor}) {}
+  explicit AdaptivePoolingOpConverter(std::string op_name)
+      : TensorRTOpConverter(std::move(op_name), {kTensor}) {}
+  ~AdaptivePoolingOpConverter() = default;
 
   void Convert(TensorRTOpConverterParams* params) const {
     auto input_tensor = params->inputs.at(0).tensor;
@@ -1214,8 +1277,8 @@ class AdaptivePoolingOpConverter : public TensorRTOpConverter {
     static const std::unordered_map<std::string, nvinfer1::PoolingType> op_map = {
         {"nn.adaptive_max_pool2d", nvinfer1::PoolingType::kMAX},
         {"nn.adaptive_avg_pool2d", nvinfer1::PoolingType::kAVERAGE}};
-    auto it = op_map.find(params->op_name);
-    ICHECK(it != op_map.end()) << "Unsupported pooling type " << params->op_name << " in TensorRT";
+    auto it = op_map.find(op_name);
+    ICHECK(it != op_map.end()) << "Unsupported pooling type " << op_name << " in TensorRT";
     ICHECK_EQ(params->node.GetAttr<std::vector<std::string>>("layout")[0], "NCHW");
 
     // This is an approximation of adaptive pooling. Results will not be
@@ -1236,7 +1299,9 @@ class AdaptivePoolingOpConverter : public TensorRTOpConverter {
 
 class BatchMatmulOpConverter : public TensorRTOpConverter {
  public:
-  BatchMatmulOpConverter() : TensorRTOpConverter({kTensor, kTensor}) {}
+  explicit BatchMatmulOpConverter(std::string op_name)
+      : TensorRTOpConverter(std::move(op_name), {kTensor, kTensor}) {}
+  ~BatchMatmulOpConverter() = default;
 
   void Convert(TensorRTOpConverterParams* params) const {
     auto transa = std::stoi(params->node.GetAttr<std::vector<std::string>>("transpose_a")[0]);
@@ -1252,75 +1317,84 @@ class BatchMatmulOpConverter : public TensorRTOpConverter {
   }
 };
 
-const std::shared_ptr<std::unordered_map<std::string, std::shared_ptr<TensorRTOpConverter>>>
-GetOpConverters() {
-  static auto map =
-      std::make_shared<std::unordered_map<std::string, std::shared_ptr<TensorRTOpConverter>>>();
-  if (!map->empty()) return map;
-  map->emplace("nn.relu", std::make_shared<ActivationOpConverter>());
-  map->emplace("sigmoid", std::make_shared<ActivationOpConverter>());
-  map->emplace("tanh", std::make_shared<ActivationOpConverter>());
-  map->emplace("nn.batch_norm", std::make_shared<BatchNormOpConverter>());
-  map->emplace("nn.layer_norm", std::make_shared<LayerNormOpConverter>());
-  map->emplace("nn.softmax", std::make_shared<SoftmaxOpConverter>());
-  map->emplace("nn.conv1d", std::make_shared<Conv1DOpConverter>());
-  map->emplace("nn.conv2d", std::make_shared<Conv2DOpConverter>());
-  map->emplace("nn.dense", std::make_shared<DenseOpConverter>());
-  map->emplace("nn.bias_add", std::make_shared<BiasAddOpConverter>());
-  map->emplace("add", std::make_shared<ElementWiseBinaryOpConverter>());
-  map->emplace("subtract", std::make_shared<ElementWiseBinaryOpConverter>());
-  map->emplace("multiply", std::make_shared<ElementWiseBinaryOpConverter>());
-  map->emplace("divide", std::make_shared<ElementWiseBinaryOpConverter>());
-  map->emplace("power", std::make_shared<ElementWiseBinaryOpConverter>());
-  map->emplace("maximum", std::make_shared<ElementWiseBinaryOpConverter>());
-  map->emplace("minimum", std::make_shared<ElementWiseBinaryOpConverter>());
-  map->emplace("nn.max_pool2d", std::make_shared<PoolingOpConverter>());
-  map->emplace("nn.avg_pool2d", std::make_shared<PoolingOpConverter>());
-  map->emplace("nn.global_max_pool2d", std::make_shared<GlobalPoolingOpConverter>());
-  map->emplace("nn.global_avg_pool2d", std::make_shared<GlobalPoolingOpConverter>());
-  map->emplace("exp", std::make_shared<UnaryOpConverter>());
-  map->emplace("log", std::make_shared<UnaryOpConverter>());
-  map->emplace("sqrt", std::make_shared<UnaryOpConverter>());
-  map->emplace("abs", std::make_shared<UnaryOpConverter>());
-  map->emplace("negative", std::make_shared<UnaryOpConverter>());
-  map->emplace("nn.batch_flatten", std::make_shared<BatchFlattenOpConverter>());
-  map->emplace("expand_dims", std::make_shared<ExpandDimsOpConverter>());
-  map->emplace("squeeze", std::make_shared<SqueezeOpConverter>());
-  map->emplace("concatenate", std::make_shared<ConcatOpConverter>());
-  map->emplace("nn.conv2d_transpose", std::make_shared<Conv2DTransposeOpConverter>());
-  map->emplace("transpose", std::make_shared<TransposeOpConverter>());
-  map->emplace("layout_transform", std::make_shared<LayoutTransformOpConverter>());
-  map->emplace("reshape", std::make_shared<ReshapeOpConverter>());
-  map->emplace("nn.pad", std::make_shared<PadOpConverter>());
-  map->emplace("sum", std::make_shared<ReduceOpConverter>());
-  map->emplace("prod", std::make_shared<ReduceOpConverter>());
-  map->emplace("max", std::make_shared<ReduceOpConverter>());
-  map->emplace("min", std::make_shared<ReduceOpConverter>());
-  map->emplace("mean", std::make_shared<ReduceOpConverter>());
-  map->emplace("nn.adaptive_max_pool2d", std::make_shared<AdaptivePoolingOpConverter>());
-  map->emplace("nn.adaptive_avg_pool2d", std::make_shared<AdaptivePoolingOpConverter>());
-  map->emplace("nn.batch_matmul", std::make_shared<BatchMatmulOpConverter>());
+const std::unordered_map<std::string, std::unique_ptr<TensorRTOpConverter>>& GetOpConverters() {
+  static const std::unordered_map<std::string, std::unique_ptr<TensorRTOpConverter>>* map = []() {
+    std::vector<std::unique_ptr<TensorRTOpConverter>> all_converters;
+    all_converters.emplace_back(std::make_unique<ActivationOpConverter>("nn.relu"));
+    all_converters.emplace_back(std::make_unique<ActivationOpConverter>("sigmoid"));
+    all_converters.emplace_back(std::make_unique<ActivationOpConverter>("tanh"));
+    all_converters.emplace_back(std::make_unique<BatchNormOpConverter>("nn.batch_norm"));
+    all_converters.emplace_back(std::make_unique<LayerNormOpConverter>("nn.layer_norm"));
+    all_converters.emplace_back(std::make_unique<SoftmaxOpConverter>("nn.softmax"));
+    all_converters.emplace_back(std::make_unique<Conv1DOpConverter>("nn.conv1d"));
+    all_converters.emplace_back(std::make_unique<Conv2DOpConverter>("nn.conv2d"));
+    all_converters.emplace_back(std::make_unique<DenseOpConverter>("nn.dense"));
+    all_converters.emplace_back(std::make_unique<BatchMatmulOpConverter>("nn.batch_matmul"));
+    all_converters.emplace_back(std::make_unique<BiasAddOpConverter>("nn.bias_add"));
+    all_converters.emplace_back(std::make_unique<ElementWiseBinaryOpConverter>("add"));
+    all_converters.emplace_back(std::make_unique<ElementWiseBinaryOpConverter>("subtract"));
+    all_converters.emplace_back(std::make_unique<ElementWiseBinaryOpConverter>("multiply"));
+    all_converters.emplace_back(std::make_unique<ElementWiseBinaryOpConverter>("divide"));
+    all_converters.emplace_back(std::make_unique<ElementWiseBinaryOpConverter>("power"));
+    all_converters.emplace_back(std::make_unique<ElementWiseBinaryOpConverter>("maximum"));
+    all_converters.emplace_back(std::make_unique<ElementWiseBinaryOpConverter>("minimum"));
+    all_converters.emplace_back(std::make_unique<PoolingOpConverter>("nn.max_pool2d"));
+    all_converters.emplace_back(std::make_unique<PoolingOpConverter>("nn.avg_pool2d"));
+    all_converters.emplace_back(std::make_unique<GlobalPoolingOpConverter>("nn.global_max_pool2d"));
+    all_converters.emplace_back(std::make_unique<GlobalPoolingOpConverter>("nn.global_avg_pool2d"));
+    all_converters.emplace_back(std::make_unique<UnaryOpConverter>("exp"));
+    all_converters.emplace_back(std::make_unique<UnaryOpConverter>("log"));
+    all_converters.emplace_back(std::make_unique<UnaryOpConverter>("sqrt"));
+    all_converters.emplace_back(std::make_unique<UnaryOpConverter>("abs"));
+    all_converters.emplace_back(std::make_unique<UnaryOpConverter>("negative"));
+    all_converters.emplace_back(std::make_unique<BatchFlattenOpConverter>("nn.batch_flatten"));
+    all_converters.emplace_back(std::make_unique<ExpandDimsOpConverter>("expand_dims"));
+    all_converters.emplace_back(std::make_unique<SqueezeOpConverter>("squeeze"));
+    all_converters.emplace_back(std::make_unique<ConcatOpConverter>("concatenate"));
+    all_converters.emplace_back(
+        std::make_unique<Conv2DTransposeOpConverter>("nn.conv2d_transpose"));
+    all_converters.emplace_back(std::make_unique<TransposeOpConverter>("transpose"));
+    all_converters.emplace_back(std::make_unique<LayoutTransformOpConverter>("layout_transform"));
+    all_converters.emplace_back(std::make_unique<ReshapeOpConverter>("reshape"));
+    all_converters.emplace_back(std::make_unique<PadOpConverter>("nn.pad"));
+    all_converters.emplace_back(std::make_unique<ReduceOpConverter>("sum"));
+    all_converters.emplace_back(std::make_unique<ReduceOpConverter>("prod"));
+    all_converters.emplace_back(std::make_unique<ReduceOpConverter>("max"));
+    all_converters.emplace_back(std::make_unique<ReduceOpConverter>("min"));
+    all_converters.emplace_back(std::make_unique<ReduceOpConverter>("mean"));
+    all_converters.emplace_back(
+        std::make_unique<AdaptivePoolingOpConverter>("nn.adaptive_max_pool2d"));
+    all_converters.emplace_back(
+        std::make_unique<AdaptivePoolingOpConverter>("nn.adaptive_avg_pool2d"));
+    all_converters.emplace_back(std::make_unique<BatchMatmulOpConverter>("nn.batch_matmul"));
 #if TRT_VERSION_GE(5, 1, 5)
-  map->emplace("clip", std::make_shared<ActivationOpConverter>());
-  map->emplace("nn.leaky_relu", std::make_shared<ActivationOpConverter>());
-  map->emplace("sin", std::make_shared<UnaryOpConverter>());
-  map->emplace("cos", std::make_shared<UnaryOpConverter>());
-  map->emplace("atan", std::make_shared<UnaryOpConverter>());
-  map->emplace("ceil", std::make_shared<UnaryOpConverter>());
-  map->emplace("floor", std::make_shared<UnaryOpConverter>());
-  map->emplace("split", std::make_shared<SplitOpConverter>());
-  map->emplace("strided_slice", std::make_shared<StridedSliceOpConverter>());
+    all_converters.emplace_back(std::make_unique<ActivationOpConverter>("clip"));
+    all_converters.emplace_back(std::make_unique<ActivationOpConverter>("nn.leaky_relu"));
+    all_converters.emplace_back(std::make_unique<UnaryOpConverter>("sin"));
+    all_converters.emplace_back(std::make_unique<UnaryOpConverter>("cos"));
+    all_converters.emplace_back(std::make_unique<UnaryOpConverter>("atan"));
+    all_converters.emplace_back(std::make_unique<UnaryOpConverter>("ceil"));
+    all_converters.emplace_back(std::make_unique<UnaryOpConverter>("floor"));
+    all_converters.emplace_back(std::make_unique<SplitOpConverter>("split"));
+    all_converters.emplace_back(std::make_unique<StridedSliceOpConverter>("strided_slice"));
 #endif  // TRT_VERSION_GE(5, 1, 5)
 #if TRT_VERSION_GE(6, 0, 1)
-  map->emplace("nn.conv3d", std::make_shared<Conv3DOpConverter>());
-  map->emplace("nn.max_pool3d", std::make_shared<Pooling3DOpConverter>());
-  map->emplace("nn.avg_pool3d", std::make_shared<Pooling3DOpConverter>());
-  map->emplace("nn.conv3d_transpose", std::make_shared<Conv3DTransposeOpConverter>());
+    all_converters.emplace_back(std::make_unique<Conv3DOpConverter>("nn.conv3d"));
+    all_converters.emplace_back(std::make_unique<Pooling3DOpConverter>("nn.max_pool3d"));
+    all_converters.emplace_back(std::make_unique<Pooling3DOpConverter>("nn.avg_pool3d"));
+    all_converters.emplace_back(
+        std::make_unique<Conv3DTransposeOpConverter>("nn.conv3d_transpose"));
 #endif  // TRT_VERSION_GE(6, 0, 1)
 #if TRT_VERSION_GE(7, 0, 0)
-  map->emplace("erf", std::make_shared<UnaryOpConverter>());
+    all_converters.emplace_back(std::make_unique<UnaryOpConverter>("erf"));
 #endif  // TRT_VERSION_GE(7, 0, 0)
-  return map;
+    auto* map = new std::unordered_map<std::string, std::unique_ptr<TensorRTOpConverter>>();
+    for (auto& converter : all_converters) {
+      map->emplace("tensorrt." + converter->op_name, std::move(converter));
+    }
+    return map;
+  }();
+  return *map;
 }
 
 }  // namespace contrib
diff --git a/src/runtime/contrib/tensorrt/tensorrt_ops.h b/src/runtime/contrib/tensorrt/tensorrt_ops.h
index b71dec00c9..e2ef341b4a 100644
--- a/src/runtime/contrib/tensorrt/tensorrt_ops.h
+++ b/src/runtime/contrib/tensorrt/tensorrt_ops.h
@@ -49,13 +49,10 @@ namespace contrib {
 using JSONGraphNode = tvm::runtime::json::JSONGraphNode;
 
 /*!
- * \brief An input to a op may be either kTensor in the case of nvinfer::ITensor*
- * or kWeight for nvinfer1::Weights.
+ * \brief An input to an op may be either kTensor in the case of nvinfer1::ITensor*,
+ * kWeight for nvinfer1::Weights, or kIgnored (e.g. for the nn.pad value).
  */
-enum TensorRTInputType {
-  kTensor,
-  kWeight,
-};
+enum TensorRTInputType { kTensor, kWeight, kIgnored };
 
 /*!
  * \brief An input to a TensorRTOpConverter. The type of the input is either kTensor
@@ -85,7 +82,9 @@ struct TensorRTOpInput {
 struct TensorRTOpConverterParams {
   /*! \brief The TRT network that the new layer should be added to. */
   nvinfer1::INetworkDefinition* network;
-  /*! \brief The corresponding serialized node. */
+  /*! \brief Index of JSON node. */
+  int nid;
+  /*! \brief The corresponding JSON node. */
   const JSONGraphNode& node;
   /*! \brief The type of op. */
   std::string op_name;
@@ -96,20 +95,25 @@ struct TensorRTOpConverterParams {
   /*! \brief Any newly allocated weights should be stored here also. */
   std::vector<nvinfer1::Weights>* trt_weights;
 
-  TensorRTOpConverterParams(nvinfer1::INetworkDefinition* network, const JSONGraphNode& node,
-                            std::vector<nvinfer1::Weights>* trt_weights)
-      : network(network), node(node), trt_weights(trt_weights) {
+  TensorRTOpConverterParams(nvinfer1::INetworkDefinition* network, int nid,
+                            const JSONGraphNode& node, std::vector<nvinfer1::Weights>* trt_weights)
+      : network(network), nid(nid), node(node), trt_weights(trt_weights) {
     op_name = node.GetOpName();
   }
+
+  std::string LayerName() const { return op_name + "(" + std::to_string(nid) + ")"; }
 };
 
 /*! \brief Base class for an op converter from Relay to TRT. */
 class TensorRTOpConverter {
  public:
+  virtual ~TensorRTOpConverter() = default;
+
+  /*! \brief Operator name. */
+  std::string op_name;
   /*! \brief Used to specify whether each input is tensor or weight. */
   const std::vector<TensorRTInputType> input_types;
-  /*! \brief If set to true, any number of tensor inputs can be used for the op.
-   */
+  /*! \brief If set to true, any number of tensor inputs can be used for the op. */
   const bool variable_input_count;
 
   /*!
@@ -123,8 +127,8 @@ class TensorRTOpConverter {
    * true. input_types vector will be ignored and any number of input tensors
    * can be used for this op. All inputs will be tensors and not weights.
    */
-  explicit TensorRTOpConverter(const std::vector<TensorRTInputType>& input_types,
-                               bool variable_input_count = false);
+  TensorRTOpConverter(std::string op_name, const std::vector<TensorRTInputType>& input_types,
+                      bool variable_input_count = false);
 
   /*!
    * \brief Convert to TRT. Implementation should use inputs and attributes
@@ -197,8 +201,7 @@ class TensorRTOpConverter {
  * \brief Get the map of available TensorRTOpConverters, where the key is the name of the relay op.
  * \return Map of TensorRTOpConverters.
  */
-const std::shared_ptr<std::unordered_map<std::string, std::shared_ptr<TensorRTOpConverter>>>
-GetOpConverters();
+const std::unordered_map<std::string, std::unique_ptr<TensorRTOpConverter>>& GetOpConverters();
 
 }  // namespace contrib
 }  // namespace runtime
diff --git a/src/runtime/contrib/tensorrt/tensorrt_runtime.cc b/src/runtime/contrib/tensorrt/tensorrt_runtime.cc
index 814d96863b..b60074e66d 100644
--- a/src/runtime/contrib/tensorrt/tensorrt_runtime.cc
+++ b/src/runtime/contrib/tensorrt/tensorrt_runtime.cc
@@ -127,7 +127,9 @@ class TensorRTRuntime : public JSONRuntimeBase {
           max_workspace_size_ =
               std::stoul(nodes_[i].GetAttr<std::vector<std::string>>("max_workspace_size")[0]);
         }
-        return;
+      }
+      if (nodes_[i].HasAttr("use_fp16")) {
+        use_fp16_ = std::stoi(nodes_[i].GetAttr<std::vector<std::string>>("use_fp16")[0]);
       }
     }
   }
@@ -300,8 +302,8 @@ class TensorRTRuntime : public JSONRuntimeBase {
       }
     }
 
-    LOG(INFO) << "Finished building TensorRT engine for subgraph " << symbol_name_
-              << " with batch size " << batch_size;
+    VLOG(1) << "Finished building TensorRT engine for subgraph " << symbol_name_
+            << " with batch size " << batch_size;
     CacheEngineToDisk();
     return trt_engine_cache_.at(std::make_pair(symbol_name_, batch_size));
   }
diff --git a/tests/python/contrib/test_tensorrt.py b/tests/python/contrib/test_tensorrt.py
index 4e6f2421b5..4e6aab14c0 100644
--- a/tests/python/contrib/test_tensorrt.py
+++ b/tests/python/contrib/test_tensorrt.py
@@ -15,30 +15,21 @@
 # specific language governing permissions and limitations
 # under the License.
 import tvm.testing
-from curses import tparm
-from unittest import result
 import numpy as np
-import time
 import pytest
 import itertools
-import pdb
 
 
 import tvm
-from tvm.relay.op.contrib.bnns import dtype_is_supported
 import tvm.relay.testing
 
-from tvm import relay, runtime
+from tvm import relay
 from tvm.relay.op.contrib import tensorrt
-from tvm.contrib import graph_executor, utils
-from tvm.runtime.vm import VirtualMachine
 
 from tvm.relay import Any, GlobalVar
-from tvm.relay.transform import FirstOrderGradient, InferType
-from tvm.relay.transform.transform import ToMixedPrecision
 
 from tvm.relay.expr_functor import ExprVisitor
-from typing import Dict, Tuple, Union
+from typing import Tuple
 from tvm.contrib.download import download
 from tvm.relay.op.contrib import tensorrt
 
@@ -78,7 +69,7 @@ def assert_result_dict_holds(result_dict, dtype="float16"):
             if dtype == "float16":
                 tvm.testing.assert_allclose(r1, r2, rtol=1e-1, atol=1e-1)
             else:
-                tvm.testing.assert_allclose(r1, r2, rtol=1e-3, atol=1e-3)
+                tvm.testing.assert_allclose(r1, r2, rtol=1e-3, atol=5e-3)
 
 
 def set_func_attr(func, compile_name, symbol_name):
@@ -105,6 +96,7 @@ def run_and_verify_func(config, target="cuda", run_module=True, data_type="float
     data_type: str
         Check between single and double floating precision
     """
+    np.random.seed(42)
     f, input_shapes, is_param = config
     params = {
         x: np.random.uniform(-1, 1, input_shapes[x]).astype(dtype=data_type) for x in is_param
@@ -125,7 +117,9 @@ def run_and_verify_func(config, target="cuda", run_module=True, data_type="float
                 result_key = mode + ("_trt" if use_trt else "")
                 if use_trt:
                     mod = relay.transform.InferType()(mod)
-                    mod, config = tensorrt.partition_for_tensorrt(mod, params)
+                    mod, config = tensorrt.partition_for_tensorrt(
+                        mod, params, use_fp16=data_type == "float16"
+                    )
                     with tvm.transform.PassContext(
                         opt_level=3, config={"relay.ext.tensorrt.options": config}
                     ):
@@ -185,7 +179,6 @@ def test_tensorrt_simple(run_module):
                 if run_module:
                     result_dict[result_key] = func(x_data, y_data, z_data)
 
-        print(result_dict)
         if run_module:
             assert_result_dict_holds(result_dict)
 
@@ -594,9 +587,13 @@ def test_reshape(run_module):
         f = relay.Function([x], out)
         return f, {"x": x_shape}, []
 
-    run_and_verify_func(get_graph((1, 1, 1, 10), (-1, 10)), run_module=run_module)
-    run_and_verify_func(get_graph((1, 10, 2, 3), (1, -1)), run_module=run_module)
-    run_and_verify_func(get_graph((1, 1, 2, 3), (1, 6)), run_module=run_module)
+    run_and_verify_func(
+        get_graph((1, 1, 1, 10), (-1, 10)), run_module=run_module, data_type="float16"
+    )
+    run_and_verify_func(
+        get_graph((1, 10, 2, 3), (1, -1)), run_module=run_module, data_type="float16"
+    )
+    run_and_verify_func(get_graph((1, 1, 2, 3), (1, 6)), run_module=run_module, data_type="float16")
 
 
 class AreOpsOnGraph(ExprVisitor):
@@ -731,7 +728,7 @@ def test_float_const16(run_module):
         f = relay.Function([x], out)
         return f, {"x": x_shape}, []
 
-    run_and_verify_func(get_graph(), run_module=run_module)
+    run_and_verify_func(get_graph(), run_module=run_module, data_type="float16")
 
 
 def test_pad(run_module):
@@ -1056,8 +1053,8 @@ def test_multiple_outputs(run_module):
 
 def test_conv3d(run_module):
     def get_graph(
-        x_shape=(1, 32, 8, 8, 8),
-        k_shape=(16, 32, 3, 3, 3),
+        x_shape=(1, 24, 8, 8, 8),
+        k_shape=(16, 24, 3, 3, 3),
         groups=1,
         padding=(0, 0, 0),
         strides=(1, 1, 1),
diff --git a/tests/python/relay/test_pass_inline_composites.py b/tests/python/relay/test_pass_inline_composites.py
deleted file mode 100644
index 54fc08c879..0000000000
--- a/tests/python/relay/test_pass_inline_composites.py
+++ /dev/null
@@ -1,165 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-# pylint: disable=invalid-name, missing-docstring, too-many-statements
-"""Unit tests for inline composites."""
-import pytest
-import tvm
-from tvm import relay, tir
-from tvm.relay.dataflow_pattern import TupleGetItemPattern, is_op, wildcard
-from tvm.relay.testing import run_opt_pass
-
-"""
-The inline composite pass is designed to inline multiple kernel generated through 
-the merge composite composite pass. The underlying idea is to inline N kernels 
-produced from merge composite based on a given set of pattern into a single IR module.
-Also, clears Composite and PartionedFromPatterns that infer with certain BYOC implementations
-
-For example suppose we have the graph:
-
-        a  b                   
-        \ /              
-        add     
-         |            
-       relu                            
-
-Merge composite will wrap each standalone op to it's own function, while setting Composite and
-PartitionedFromPattern attrs. 
-       
-Relay IR after merge composite pass when registering each op as a standalone pattern: 
-fn (%a: Tensor[(10, 10), float32], %b: Tensor[(10, 10), float32]) -> Tensor[(10, 10), float32] {
-  %0 = fn (%FunctionVar_0_01: Tensor[(10, 10), float32], %FunctionVar_0_1: Tensor[(10, 10), float32], PartitionedFromPattern="add_", Composite="add") -> Tensor[(10, 10), float32] {
-    add(%FunctionVar_0_01, %FunctionVar_0_1) /* ty=Tensor[(10, 10), float32] */
-  };
-  %1 = %0(%a, %b) /* ty=Tensor[(10, 10), float32] */;
-  %2 = fn (%FunctionVar_0_0: Tensor[(10, 10), float32], PartitionedFromPattern="nn.relu_", Composite="nn.relu") -> Tensor[(10, 10), float32] {
-    nn.relu(%FunctionVar_0_0) /* ty=Tensor[(10, 10), float32] */
-  };
-  %2(%1) /* ty=Tensor[(10, 10), float32] */
-}
-
-Relay IR after inline composites pass:
-fn (%a: Tensor[(10, 10), float32], %b: Tensor[(10, 10), float32]) -> Tensor[(10, 10), float32] {
-  %0 = add(%a, %b) /* ty=Tensor[(10, 10), float32] */;
-  nn.relu(%0) /* ty=Tensor[(10, 10), float32] */
-}
-
-One convenient use of this pass is to use Pattern-based operator support to move away
-from the original operator predicates, and inline them into a single primitive function to offload it 
-to an external BYOC backend, such as TensorRT.
-"""
-
-
-def make_add_relu_pattern():
-    r"""Create a pattern to match the following graph.
-
-     add
-      |
-    relu
-    """
-    add_node = wildcard() + wildcard()
-    r = is_op("nn.relu")(add_node)
-    return r
-
-
-def make_relu_pattern():
-    r"""Create a pattern to match the following graph
-     a
-     |
-    relu
-     |
-    """
-    pattern = is_op("nn.relu")(wildcard())
-    return pattern
-
-
-def make_add_pattern():
-    r"""Create a pattern to match the following graph
-    a  b
-    \  /
-    add
-     |
-    """
-    pattern = is_op("add")(wildcard(), wildcard())
-    return pattern
-
-
-def check_success_composite_pass(func):
-    return func.body.op.attrs["Composite"] is not None
-
-
-def check_result(pattern_table, expected_graph, import_prelude=False):
-    """Utility function to check inline composites results."""
-    result = run_opt_pass(
-        expected_graph, relay.transform.MergeComposite(pattern_table), import_prelude=import_prelude
-    )
-    assert check_success_composite_pass(
-        result
-    ), "Merge Composite pass didn't produced partioned from Pattern"
-    result = run_opt_pass(
-        expected_graph, relay.transform.InlineComposites(target=""), import_prelude=import_prelude
-    )
-    assert not relay.analysis.free_vars(result), "Found free vars in the result graph: {0}".format(
-        str(result)
-    )
-    expected = run_opt_pass(expected_graph, relay.transform.InferType())
-    assert tvm.ir.structural_equal(
-        result, expected, map_free_vars=True
-    ), "Graph mismatch: output vs. expected\n{0}\n=====\n{1}".format(str(result), str(expected))
-
-
-def test_single_op_registry():
-    r"""Test inline composite pass is correctly inline the post-merge composite graph.
-
-    We could expect the patterns `make_add_pattern` and `make_relu_pattern` to be inlined
-    into a single func instead of an single func per registered pattern.
-
-    """
-    pattern_table = [("add", make_add_pattern()), ("nn.relu", make_relu_pattern())]
-
-    def expected():
-        in_1 = relay.var("in_1", shape=(10, 10))
-        in_2 = relay.var("in_2", shape=(10, 10))
-        add_node = relay.add(in_1, in_2)
-        relu_node = relay.nn.relu(add_node)
-        add_relu = relay.Function([in_1, in_2], relu_node)
-        return add_relu
-
-    check_result(pattern_table, expected())
-
-
-def test_mix_fused_and_single_op():
-    r"""Test inline composite pass is correctly inline the merge composite result"""
-    pattern_table = [("add_relu", make_add_relu_pattern()), ("nn.relu", make_relu_pattern())]
-
-    def expected():
-        a = relay.var("a", shape=(10, 10))
-        b = relay.var("b", shape=(10, 10))
-
-        # add_relu function
-        in_1 = relay.var("in_1", shape=(10, 10))
-        in_2 = relay.var("in_2", shape=(10, 10))
-        add_node = relay.add(in_1, in_2)
-        relu_node = relay.nn.relu(add_node)
-        relu_nd = relay.nn.relu(relu_node)
-        add_relu = relay.Function([in_1, in_2], relu_nd)
-        return add_relu
-
-    check_result(pattern_table, expected())
-
-
-if __name__ == "__main__":
-    pytest.main()
diff --git a/tests/scripts/task_mypy.sh b/tests/scripts/task_mypy.sh
index aaba996dbe..c26ab73846 100755
--- a/tests/scripts/task_mypy.sh
+++ b/tests/scripts/task_mypy.sh
@@ -40,6 +40,7 @@ echo "Checking MyPy Type defs in tvm.relay.op.contrib"
 mypy --disallow-untyped-defs python/tvm/relay/op/contrib/cublas.py
 mypy --disallow-untyped-defs python/tvm/relay/op/contrib/cudnn.py
 mypy --disallow-untyped-defs python/tvm/relay/op/contrib/te_target.py
+mypy --disallow-untyped-defs python/tvm/relay/op/contrib/tensorrt.py
 
 #TODO(@mikepapadim): This is failing atm
 # echo "Checking MyPy Type defs in the tvm.relay.backend.contrib.ethosu package."