Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2020/09/16 18:07:05 UTC

[GitHub] [incubator-tvm] zhiics commented on a change in pull request #6395: [BYOC][TensorRT] TensorRT BYOC integration

zhiics commented on a change in pull request #6395:
URL: https://github.com/apache/incubator-tvm/pull/6395#discussion_r489579130



##########
File path: python/tvm/relay/op/contrib/tensorrt.py
##########
@@ -0,0 +1,671 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, unused-argument
+"""TensorRT supported operators."""
+import logging
+import os
+
+import numpy as np
+
+import tvm
+from tvm import relay
+from tvm.relay import transform
+from tvm.relay.build_module import bind_params_by_name
+from tvm.relay.expr import Call, Constant, Tuple, GlobalVar
+from tvm.relay.expr_functor import ExprMutator
+
+# Version to use for annotation when there is no linked TRT.
+#TENSORRT_VERSION = (6, 0, 1)
+#USE_IMPLICIT_BATCH = True
+#REMOVE_NO_MAC_SUBGRAPHS = False
+
+def is_tensorrt_runtime_enabled():
+    """Check if the TensorRT graph runtime is present.
+    Returns
+    -------
+    ret: bool
+        True if present, False if not.
+    """
+    check_enabled = tvm.get_global_func("relay.op.is_tensorrt_runtime_enabled", True)
+    if check_enabled:
+        return check_enabled()
+    return False
+
+def get_tensorrt_version():
+    """Gets the version of TensorRT that TVM is built against or is targeting.
+
+    Returns
+    -------
+    ret: Tuple[int]
+        TensorRT version as a tuple of major, minor, and patch number. If TVM
+        is not built with TensorRT, the version from the
+        "relay.ext.tensorrt.options" PassContext config is returned instead.
+    """
+    pass_ctx = tvm.transform.PassContext.current()
+    if "relay.ext.tensorrt.options" in pass_ctx.config:
+        return tuple(pass_ctx.config["relay.ext.tensorrt.options"].tensorrt_version)
+    return tuple(tvm.get_global_func("relay.op.get_tensorrt_version")())
+
+def get_tensorrt_use_implicit_batch_mode():
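+    """Return the use_implicit_batch setting from the
+    "relay.ext.tensorrt.options" PassContext config."""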
+    pass_ctx = tvm.transform.PassContext.current()
+    return pass_ctx.config["relay.ext.tensorrt.options"].use_implicit_batch
+
+def get_tensorrt_remove_no_mac_subgraphs():
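+    """Return the remove_no_mac_subgraphs setting from the
+    "relay.ext.tensorrt.options" PassContext config."""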
+    pass_ctx = tvm.transform.PassContext.current()
+    return pass_ctx.config["relay.ext.tensorrt.options"].remove_no_mac_subgraphs
+
+def partition_for_tensorrt(mod, params=None, version=None, use_implicit_batch=True,
+                           remove_no_mac_subgraphs=False, max_workspace_size=1 << 30):
+    """Partition the graph greedily offloading supported
+    operators to TensorRT.
+    Parameters
+    ----------
+    mod : Module
+        The module to run passes on.
+    params : Optional[Dict[str, NDArray]]
+        Constant input parameters.
+    version : Optional[Tuple[int, int, int]]
+        TensorRT version to target as a tuple of (major, minor, patch). If TVM is compiled
+        with USE_TENSORRT_GRAPH_RUNTIME=ON, the linked TensorRT version is used instead.
+    use_implicit_batch : Optional[bool]
+        Use TensorRT implicit batch mode (default true). Setting to false enables explicit
+        batch mode, which widens the set of supported operators to include those that modify
+        the batch dimension, but may reduce performance for some models.
+    remove_no_mac_subgraphs : Optional[bool]
+        Removes subgraphs which have been partitioned for TensorRT if they do not have any
+        multiply-accumulate operations. The removed subgraphs will go through TVM's standard
+        compilation instead. Can improve performance.
+    max_workspace_size : Optional[int]
+        The maximum number of bytes of workspace each subgraph is allowed to use during
+        TensorRT engine creation. See TensorRT documentation for more info.
+    Returns
+    -------
+    mod_and_config : Tuple[Module, Dict[str, Any]]
+        A tuple of 1) the annotated and partitioned module and 2) the
+        "relay.ext.tensorrt.options" configuration to give to PassContext when building.
+    """
+    config = {
+        "use_implicit_batch": use_implicit_batch,
+        "max_workspace_size": max_workspace_size,
+        "remove_no_mac_subgraphs": remove_no_mac_subgraphs
+    }
+    if version:
+        assert isinstance(version, tuple) and len(version) == 3
+        config["tensorrt_version"] = version
+    else:
+        linked_version = tuple(tvm.get_global_func("relay.op.get_tensorrt_version")())
+        if not linked_version:
+            logging.warn("TVM was not built against TensorRT and no version was provided to partition_for_tensorrt. Defaulting to 6.0.1")
+            linked_version = (6, 0, 1)
+        config["tensorrt_version"] = linked_version
+
+    if params:
+        mod['main'] = bind_params_by_name(mod['main'], params)
+    seq = tvm.transform.Sequential([transform.InferType(),
+                                    RemoveDropoutPass(),
+                                    transform.RemoveUnusedFunctions(),
+                                    transform.ConvertLayout({'nn.conv2d': ['NCHW', 'default'],
+                                                             'nn.conv3d': ['NCDHW', 'default']}),
+                                    transform.FoldConstant(),
+                                    transform.AnnotateTarget('tensorrt'),
+                                    transform.MergeCompilerRegions(),
+                                    transform.PartitionGraph(),
+                                    transform.InferType()])
+    with tvm.transform.PassContext(opt_level=3, config={"relay.ext.tensorrt.options": config}):
+        mod = seq(mod)
+        mod = prune_tensorrt_subgraphs(mod)
+    return mod, config
+
+
+def _register_external_op_helper(op_name, supported=True):
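+    """Register `op_name` as a TensorRT-supported op: the generated checker
+    rejects non-float32 inputs and otherwise reports `supported`."""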
+    @tvm.ir.register_op_attr(op_name, "target.tensorrt")
+    def _func_wrapper(attrs, args):
+        if any([x.checked_type.dtype != "float32" for x in args]):
+            print("Only float32 inputs are supported for TensorRT.")
+            return False
+        return supported
+    return _func_wrapper
+
+
+def _register_external_op_helper_func(op_name, func):
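+    """Like _register_external_op_helper, but delegate the final decision to
+    `func(attrs, args, op_name)` once the float32 check passes."""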
+    @tvm.ir.register_op_attr(op_name, "target.tensorrt")
+    def _func_wrapper(attrs, args):
+        if any([x.checked_type.dtype != "float32" for x in args]):
+            print("Only float32 inputs are supported for TensorRT.")
+            return False
+        return func(attrs, args, op_name)
+    return _func_wrapper
+
+
+# Ops which are always supported
+_register_external_op_helper("nn.relu")
+_register_external_op_helper("sigmoid")
+_register_external_op_helper("tanh")
+_register_external_op_helper("subtract")
+_register_external_op_helper("multiply")
+_register_external_op_helper("divide")
+_register_external_op_helper("power")
+_register_external_op_helper("maximum")
+_register_external_op_helper("minimum")
+_register_external_op_helper("exp")
+_register_external_op_helper("log")
+_register_external_op_helper("sqrt")
+_register_external_op_helper("abs")
+_register_external_op_helper("negative")
+_register_external_op_helper("nn.batch_flatten")
+_register_external_op_helper("clip")
+
+@tvm.ir.register_op_attr("add", "target.tensorrt")
+def add_annotate_fn(attrs, args): # pylint: disable=unused-variable
+    if any([x.checked_type.dtype != "float32" for x in args]):
+        print("Only float32 inputs are supported for TensorRT.")
+        return False
+    if not get_tensorrt_use_implicit_batch_mode() and \
+            (isinstance(args[0], Constant) or isinstance(args[1], Constant)) and \
+            args[0].checked_type.shape[0] == args[1].checked_type.shape[0] and \
+            args[0].checked_type.shape[0] != 1 and \
+            (len(args[0].checked_type.shape) > 3 or len(args[1].checked_type.shape) > 3):
+        print("add: bug in TRT with adding batched constants.")
+        return False
+    return True
+
+@tvm.ir.register_op_attr("nn.batch_norm", "target.tensorrt")
+def batch_norm_annotate_fn(attrs, args): # pylint: disable=unused-variable
+    if any([x.checked_type.dtype != "float32" for x in args]):
+        print("Only float32 inputs are supported for TensorRT.")
+        return False
+    if int(attrs.axis) not in (1, 3):
+        print("nn.batch_norm: axis is {} but must be 1 or 3.".format(int(attrs.axis)))
+        return False
+    return True
+
+@tvm.ir.register_op_attr("nn.softmax", "target.tensorrt")
+def softmax_annotate_fn(attrs, args): # pylint: disable=unused-variable
+    if any([x.checked_type.dtype != "float32" for x in args]):
+        print("Only float32 inputs are supported for TensorRT.")
+        return False
+    if get_tensorrt_use_implicit_batch_mode() and int(attrs.axis) == 0:
+        print("nn.softmax: can't modify batch dimension.")
+        return False
+    return True
+
+@tvm.ir.register_op_attr("nn.conv2d", "target.tensorrt")
+def conv2d_annotate_fn(attrs, args): # pylint: disable=unused-variable
+    if any([x.checked_type.dtype != "float32" for x in args]):
+        print("Only float32 inputs are supported for TensorRT.")
+        return False
+    if attrs.data_layout != "NCHW":
+        print("nn.conv2d: data_layout is {} but must be NCHW.".format(attrs.data_layout))
+        return False
+    if attrs.kernel_layout != "OIHW":
+        print("nn.conv2d: kernel_layout is {} but must be OIHW.".format(attrs.kernel_layout))
+        return False
+    if attrs.out_layout and attrs.out_layout != "NCHW":
+        print("nn.conv2d: out_layout is {} but must be NCHW.".format(attrs.out_layout))
+        return False
+    return True
+
+@tvm.ir.register_op_attr("nn.dense", "target.tensorrt")
+def dense_annotate_fn(attrs, args): # pylint: disable=unused-variable
+    if any([x.checked_type.dtype != "float32" for x in args]):
+        print("Only float32 inputs are supported for TensorRT.")
+        return False
+    input_rank = len(args[0].checked_type.shape)
+    weight_rank = len(args[1].checked_type.shape)
+    if input_rank not in (2, 3, 4):
+        print("nn.dense: input has rank {} but must be 2, 3 or 4.".format(input_rank))
+        return False
+    if weight_rank != 2:
+        print("nn.dense: weight has rank {} but must be 2.".format(weight_rank))
+        return False
+    return True
+
+@tvm.ir.register_op_attr("nn.bias_add", "target.tensorrt")
+def bias_add_annotate_fn(attrs, args): # pylint: disable=unused-variable
+    if any([x.checked_type.dtype != "float32" for x in args]):
+        print("Only float32 inputs are supported for TensorRT.")
+        return False
+    input_rank = len(args[0].checked_type.shape)
+    if input_rank not in (2, 3, 4):
+        print("nn.bias_add: input rank is {} but must be 2, 3 or 4.".format(input_rank))
+        return False
+    return True
+
+@tvm.ir.register_op_attr("nn.max_pool2d", "target.tensorrt")
+def max_pool_2d_annotate_fn(attrs, args): # pylint: disable=unused-variable
+    if any([x.checked_type.dtype != "float32" for x in args]):
+        print("Only float32 inputs are supported for TensorRT.")
+        return False
+    if attrs.layout != "NCHW":
+        print("nn.max_pool2d: layout is {} but must be NCHW.".format(attrs.layout))
+        return False
+    if attrs.ceil_mode and get_tensorrt_version() < (5, 1, 5):
+        print("nn.avg_pool2d: ceil_mode=True requires TensorRT 5.1.5 or greater.")
+        return False
+    return True
+
+@tvm.ir.register_op_attr("nn.avg_pool2d", "target.tensorrt")
+def avg_pool_2d_annotate_fn(attrs, args): # pylint: disable=unused-variable
+    if any([x.checked_type.dtype != "float32" for x in args]):
+        print("Only float32 inputs are supported for TensorRT.")
+        return False
+    if attrs.layout != "NCHW":
+        print("nn.avg_pool2d: layout is {} but must be NCHW.".format(attrs.layout))
+        return False
+    if attrs.count_include_pad and len(attrs.padding) == 4:
+        print("nn.avg_pool2d: inclusive-counted blended or average "
+                "pooling is not supported in combination with asymmetric padding")
+        return False
+    if attrs.ceil_mode and get_tensorrt_version() < (5, 1, 5):
+        print("nn.avg_pool2d: ceil_mode=True requires TensorRT 5.1.5 or greater.")
+        return False
+    return True
+
+@tvm.ir.register_op_attr("nn.global_max_pool2d", "target.tensorrt")
+def global_max_pool_2d_annotate_fn(attrs, args): # pylint: disable=unused-variable
+    if any([x.checked_type.dtype != "float32" for x in args]):
+        print("Only float32 inputs are supported for TensorRT.")
+        return False
+    if attrs.layout != "NCHW":
+        print("nn.global_max_pool2d: layout is {} but must be NCHW.".format(attrs.layout))
+        return False
+    return True
+
+@tvm.ir.register_op_attr("nn.global_avg_pool2d", "target.tensorrt")
+def global_avg_pool_2d_annotate_fn(attrs, args): # pylint: disable=unused-variable
+    if any([x.checked_type.dtype != "float32" for x in args]):
+        print("Only float32 inputs are supported for TensorRT.")
+        return False
+    if attrs.layout != "NCHW":
+        print("nn.global_avg_pool2d: layout is {} but must be NCHW.".format(attrs.layout))
+        return False
+    return True
+
+@tvm.ir.register_op_attr("expand_dims", "target.tensorrt")
+def expand_dims_annotate_fn(attrs, args): # pylint: disable=unused-variable
+    if any([x.checked_type.dtype != "float32" for x in args]):
+        print("Only float32 inputs are supported for TensorRT.")
+        return False
+    if get_tensorrt_use_implicit_batch_mode() and int(attrs.axis) == 0:
+        print("expand_dims: can't modify batch dimension.")
+        return False
+    return True
+
+@tvm.ir.register_op_attr("squeeze", "target.tensorrt")
+def squeeze_annotate_fn(attrs, args): # pylint: disable=unused-variable
+    if any([x.checked_type.dtype != "float32" for x in args]):
+        print("Only float32 inputs are supported for TensorRT.")
+        return False
+    if not attrs.axis:
+        print("squeeze: must explicitly set axis.")
+        return False
+    if get_tensorrt_use_implicit_batch_mode() and any([axis == 0 for axis in map(int, attrs.axis)]):
+        print("squeeze: can't modify batch dimension.")
+        return False
+    return True
+
+@tvm.ir.register_op_attr("concatenate", "target.tensorrt")
+def concatenate_annotate_fn(attrs, args): # pylint: disable=unused-variable
+    if any([x.dtype != "float32" for x in args[0].checked_type.fields]):
+        print("Only float32 inputs are supported for TensorRT.")
+        return False
+    if not get_tensorrt_use_implicit_batch_mode():
+        return True
+    if int(attrs.axis) == 0:
+        print("concatenate: can't modify batch dimension.")
+        return False
+    if isinstance(args[0], Tuple):
+        for tuple_input in args[0].fields:
+            if isinstance(tuple_input, Constant):
+                print("concatenate: can't concatenate tensors with constants.")
+                return False
+    return True
+
+@tvm.ir.register_op_attr("nn.conv2d_transpose", "target.tensorrt")
+def conv2d_transpose_annotate_fn(attrs, args): # pylint: disable=unused-variable
+    if any([x.checked_type.dtype != "float32" for x in args]):
+        print("Only float32 inputs are supported for TensorRT.")
+        return False
+    if attrs.data_layout != "NCHW":
+        print("nn.conv2d_transpose: data_layout is {} but must be NCHW.".format(
+            attrs.data_layout))
+        return False
+    if attrs.kernel_layout != "OIHW":
+        print("nn.conv2d_transpose: kernel_layout is {} but must be OIHW.".format(
+            attrs.kernel_layout))
+        return False
+    if attrs.out_layout and attrs.out_layout != "NCHW":
+        print("nn.conv2d_transpose: out_layout is {} but must be NCHW.".format(
+            attrs.out_layout))
+        return False
+    if attrs.dilation and any([rate != 1 for rate in map(int, attrs.dilation)]):
+        print("nn.conv2d_transpose: dilation rate must be 1.")
+        return False
+    return True
+
+@tvm.ir.register_op_attr("transpose", "target.tensorrt")
+def transpose_annotate_fn(attrs, args): # pylint: disable=unused-variable
+    if any([x.checked_type.dtype != "float32" for x in args]):
+        print("Only float32 inputs are supported for TensorRT.")
+        return False
+    if get_tensorrt_use_implicit_batch_mode() and int(attrs.axes[0]) != 0:
+        print("transpose: can't modify batch dimension.")
+        return False
+    return True
+
+@tvm.ir.register_op_attr("layout_transform", "target.tensorrt")
+def resize_annotate_fn(attrs, args): # pylint: disable=unused-variable
+    if any([x.checked_type.dtype != "float32" for x in args]):
+        print("Only float32 inputs are supported for TensorRT.")
+        return False
+    if (attrs.src_layout, attrs.dst_layout) not in [("NCHW", "NHWC"), ("NHWC", "NCHW"), ("NDHWC", "NCDHW"), ("NCDHW", "NDHWC")]:
+        print("layout_transform: {} to {} is not supported.".format(attrs.src_layout, attrs.dst_layout))
+        return False
+    return True
+
+@tvm.ir.register_op_attr("reshape", "target.tensorrt")
+def reshape_annotate_fn(attrs, args): # pylint: disable=unused-variable
+    if args[0].checked_type.dtype != "float32":
+        print("Only float32 inputs are supported for TensorRT.")
+        return False
+    if any([x < -1 for x in map(int, attrs.newshape)]):
+        print("reshape: new shape dims must be explicit.")
+        return False
+    if get_tensorrt_use_implicit_batch_mode():
+        shape = list(map(int, args[0].checked_type.shape))
+        new_shape = list(map(int, attrs.newshape))
+        if len(new_shape) == 0 or len(shape) == 0:
+            print("reshape: Can't reshape to or from scalar.")
+            return False
+        # TRT cannot modify batch dimension.
+        original_volume = np.prod(shape)
+        # First, resolve 0.
+        for i, value in enumerate(new_shape):
+            if value == 0:
+                new_shape[i] = shape[i]
+        # Resolve -1.
+        for i, value in enumerate(new_shape):
+            if value == -1:
+                new_shape[i] = original_volume // np.prod([x for x in new_shape if x != -1])
+        if shape[0] != new_shape[0]:
+            print("reshape: can't modify batch dimension.")
+            return False
+    return True
+
+@tvm.ir.register_op_attr("nn.pad", "target.tensorrt")
+def pad_annotate_fn(attrs, args): # pylint: disable=unused-variable
+    if any([x.checked_type.dtype != "float32" for x in args]):
+        print("Only float32 inputs are supported for TensorRT.")
+        return False
+    if attrs.pad_mode != "constant":
+        print("nn.pad: pad mode is {} but must be constant.".format(attrs.pad_mode))
+        return False
+    if float(attrs.pad_value) != 0.0:
+        print("nn.pad: pad value is {} but must be 0.0.".format(float(attrs.pad_value)))
+        return False
+    if any([x != 0 for x in attrs.pad_width[0]]) or any([x != 0 for x in attrs.pad_width[1]]):
+        print("nn.pad: can't pad batch or channel dimensions.")
+        return False
+    if len(attrs.pad_width) == 5 and any([x != 0 for x in attrs.pad_width[2]]):
+        print("nn.pad: can only pad last two dimensions for 5D inputs.")
+        return False
+    return True
+
+def reduce_annotate_fn(attrs, args, op_name):
+    if not attrs.axis:
+        print("{}: cannot reduce to scalar.".format(op_name))
+        return False
+    if attrs.exclude:
+        print("{}: exclude not supported.".format(op_name))
+        return False
+    if get_tensorrt_use_implicit_batch_mode() and any([x == 0 for x in map(int, attrs.axis)]):
+        print("{}: can't modify batch dimension.".format(op_name))
+        return False
+    return True
+
+_register_external_op_helper_func("sum", reduce_annotate_fn)
+_register_external_op_helper_func("prod", reduce_annotate_fn)
+_register_external_op_helper_func("max", reduce_annotate_fn)
+_register_external_op_helper_func("min", reduce_annotate_fn)
+_register_external_op_helper_func("mean", reduce_annotate_fn)
+
+def trt_5_1_5_annotate_fn(attrs, args, op_name):
+    if get_tensorrt_version() < (5, 1, 5):
+        print("{}: requires TensorRT version 5.1.5 or higher.".format(op_name))
+        return False
+    return True
+
+_register_external_op_helper_func("nn.leaky_relu", trt_5_1_5_annotate_fn)
+_register_external_op_helper_func("sin", trt_5_1_5_annotate_fn)
+_register_external_op_helper_func("cos", trt_5_1_5_annotate_fn)
+_register_external_op_helper_func("atan", trt_5_1_5_annotate_fn)
+_register_external_op_helper_func("ceil", trt_5_1_5_annotate_fn)
+
+@tvm.ir.register_op_attr("strided_slice", "target.tensorrt")
+def strided_slice_annotate_fn(attrs, args): # pylint: disable=unused-variable
+    if args[0].checked_type.dtype != "float32":
+        print("Only float32 inputs are supported for TensorRT.")
+        return False
+    if not trt_5_1_5_annotate_fn(attrs, args, "strided_slice"):
+        return False
+    if get_tensorrt_use_implicit_batch_mode():
+        batch_dim_begin_modified = attrs.begin[0] is not None and int(attrs.begin[0]) != 0
+        batch_dim_end_modified = attrs.end[0] is not None and int(attrs.end[0]) != -1 and \
+                                    int(attrs.end[0]) != int(args[0].checked_type.shape[0])
+        if batch_dim_begin_modified or batch_dim_end_modified:
+            print("strided_slice: can't modify batch dimension.")
+            return False
+    if any([x is not None and x <= 0 for x in attrs.strides]):
+        print("strided_slice: stride must be positive")
+        return False
+    return True
+
+@tvm.ir.register_op_attr("nn.adaptive_max_pool2d", "target.tensorrt")
+def adaptive_max_pool2d_annotate_fn(attrs, args): # pylint: disable=unused-variable
+    if any([x.checked_type.dtype != "float32" for x in args]):
+        print("Only float32 inputs are supported for TensorRT.")
+        return False
+    if len(attrs.output_size) == 0 or any([size != 1 for size in map(int, attrs.output_size)]):
+        print("nn.adaptive_max_pool2d: output size must be (1, 1).")
+        return False
+    return True
+
+@tvm.ir.register_op_attr("nn.adaptive_avg_pool2d", "target.tensorrt")
+def adaptive_avg_pool2d_annotate_fn(attrs, args): # pylint: disable=unused-variable
+    if any([x.checked_type.dtype != "float32" for x in args]):
+        print("Only float32 inputs are supported for TensorRT.")
+        return False
+    if len(attrs.output_size) == 0 or any([size != 1 for size in map(int, attrs.output_size)]):
+        print("nn.adaptive_avg_pool2d: output size must be (1, 1).")
+        return False
+    return True
+
+@tvm.ir.register_op_attr("nn.upsampling", "target.tensorrt")
+def upsampling_annotate_fn(attrs, args): # pylint: disable=unused-variable
+    # TODO(trevmorr): Output does not match TVM. Disable.
+    return False
+
+@tvm.ir.register_op_attr("nn.conv3d", "target.tensorrt")
+def conv3d_annotate_fn(attrs, args): # pylint: disable=unused-variable
+    if any([x.checked_type.dtype != "float32" for x in args]):
+        print("Only float32 inputs are supported for TensorRT.")
+        return False
+    if get_tensorrt_version() < (6, 0, 1):
+        print("nn.conv3d: requires TensorRT version 6.0.1 or higher.")
+        return False
+    if attrs.data_layout != "NCDHW":
+        print("nn.conv3d: data_layout is {} but must be NCDHW.".format(attrs.data_layout))
+        return False
+    if attrs.kernel_layout != "OIDHW":
+        print("nn.conv3d: kernel_layout is {} but must be OIDHW.".format(attrs.kernel_layout))
+        return False
+    if attrs.out_layout and attrs.out_layout != "NCDHW":
+        print("nn.conv3d: out_layout is {} but must be NCDHW.".format(attrs.out_layout))
+        return False
+    return True
+
+@tvm.ir.register_op_attr("nn.max_pool3d", "target.tensorrt")
+def max_pool_3d_annotate_fn(attrs, args): # pylint: disable=unused-variable
+    if any([x.checked_type.dtype != "float32" for x in args]):
+        print("Only float32 inputs are supported for TensorRT.")
+        return False
+    if get_tensorrt_version() < (6, 0, 1):
+        print("nn.max_pool3d: requires TensorRT version 6.0.1 or higher.")
+        return False
+    if attrs.layout != "NCDHW":
+        print("nn.max_pool3d: layout is {} but must be NCDHW.".format(attrs.layout))
+        return False
+    return True
+
+@tvm.ir.register_op_attr("nn.avg_pool3d", "target.tensorrt")
+def avg_pool_3d_annotate_fn(attrs, args): # pylint: disable=unused-variable
+    if any([x.checked_type.dtype != "float32" for x in args]):
+        print("Only float32 inputs are supported for TensorRT.")
+        return False
+    if get_tensorrt_version() < (6, 0, 1):
+        print("nn.avg_pool3d: requires TensorRT version 6.0.1 or higher.")
+        return False
+    if attrs.layout != "NCDHW":
+        print("nn.avg_pool3d: layout is {} but must be NCDHW.".format(attrs.layout))
+        return False
+    return True
+
+@tvm.ir.register_op_attr("nn.conv3d_transpose", "target.tensorrt")
+def conv3d_transpose_annotate_fn(attrs, args): # pylint: disable=unused-variable
+    if any([x.checked_type.dtype != "float32" for x in args]):
+        print("Only float32 inputs are supported for TensorRT.")
+        return False
+    if get_tensorrt_version() < (6, 0, 1):
+        print("nn.conv3d_transpose: requires TensorRT version 6.0.1 or higher.")
+        return False
+    if attrs.data_layout != "NCDHW":
+        print("nn.conv3d_transpose: data_layout is {} but must be NCDHW.".format(
+            attrs.data_layout))
+        return False
+    if attrs.kernel_layout != "OIDHW":
+        print("nn.conv3d_transpose: kernel_layout is {} but must be OIDHW.".format(
+            attrs.kernel_layout))
+        return False
+    if attrs.out_layout and attrs.out_layout != "NCDHW":
+        print("nn.conv3d_transpose: out_layout is {} but must be NCDHW.".format(
+            attrs.out_layout))
+        return False
+    if attrs.dilation and any([rate != 1 for rate in map(int, attrs.dilation)]):
+        print("nn.conv3d_transpose: dilation rate must be 1.")
+        return False
+    if attrs.output_padding and any([x != 0 for x in map(int, attrs.output_padding)]):
+        print("nn.conv3d_transpose: output padding is not supported.")
+        return False
+    return True
+
+def is_invalid_subgraph(params, body):
+    # Remove invalid subgraphs for implicit batch mode.
+    if get_tensorrt_use_implicit_batch_mode():
+        input_batch_sizes = []
+        for var in params:
+            # In implicit batch mode, all inputs must have the same batch size.
+            if isinstance(var.checked_type, relay.TupleType):
+                for tuple_type in var.checked_type.fields:
+                    # Scalar inputs not allowed
+                    if len(tuple_type.shape) == 0:
+                        print('tensorrt: scalar inputs not supported')
+                        return True
+                    input_batch_sizes.append(int(tuple_type.shape[0]))
+            else:
+                # Scalar inputs not allowed
+                if len(var.checked_type.shape) == 0:
+                    print('tensorrt: scalar inputs not supported')
+                    return True
+                input_batch_sizes.append(int(var.checked_type.shape[0]))
+        if len(input_batch_sizes) > 1 and \
+           any([x != input_batch_sizes[0] for x in input_batch_sizes[1:]]):
+            print('tensorrt: inputs have different batch sizes')
+            return True
+    # Remove subgraphs with no multiply-accumulates
+    if get_tensorrt_remove_no_mac_subgraphs() and relay.analysis.get_total_mac_number(body) == 0:
+        return True
+    return False
+
+def prune_tensorrt_subgraphs(mod, target="tensorrt"):
+    class VarReplacer(ExprMutator):

Review comment:
       This looks the same as `relay.bind`.
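       If so, a minimal sketch of the substitution (untested; assumes `var_map`
       maps each `relay.Var` to its replacement expression and `func` is the
       partitioned function):
       ```python
       new_body = relay.bind(func.body, var_map)
       new_func = relay.Function(func.params, new_body, func.ret_type,
                                 func.type_params, func.attrs)
       ```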

##########
File path: python/tvm/relay/op/contrib/tensorrt.py
##########
@@ -0,0 +1,671 @@
+def prune_tensorrt_subgraphs(mod, target="tensorrt"):
+    class VarReplacer(ExprMutator):
+        """
+        Visit an expression while replacing vars according to var_map. Used by
+        SubgraphRemover/PruneSubgraphs to return a subgraph originally partitioned to TRT back to TVM.
+        """
+        def __init__(self, var_map):
+            ExprMutator.__init__(self)
+            self.var_map = var_map
+
+        def visit_var(self, var):
+            if var in self.var_map:
+                return self.var_map[var]
+            return super().visit_var(var)
+
+    class SubgraphRemover(ExprMutator):

Review comment:
       Can we just remove the attributes of the function, e.g. inline, and then run the inline pass? 
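
       For reference, a rough sketch of that idea (hedged: it assumes the `Inline`
       pass keys off the "Inline" function attribute, and the helper name below is
       illustrative):

       ```python
       import tvm
       from tvm import relay

       def return_subgraph_to_tvm(mod, subgraph_name):
           """Strip a partitioned function's attributes, mark it inline, and
           let the standard Inline pass fold it back into its caller."""
           gvar = mod.get_global_var(subgraph_name)
           func = mod[gvar]
           # Rebuilding the Function drops the Compiler/global_symbol attributes.
           func = relay.Function(func.params, func.body, func.ret_type, func.type_params)
           func = func.with_attr("Inline", tvm.tir.IntImm("int32", 1))
           mod[gvar] = func
           return relay.transform.Inline()(mod)
       ```

       That would let us drop the hand-written VarReplacer/SubgraphRemover machinery.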

##########
File path: src/runtime/contrib/tensorrt/tensorrt_runtime.cc
##########
@@ -0,0 +1,311 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/runtime/contrib/tensorrt/tensorrt_runtime.cc
+ * \brief JSON runtime implementation for TensorRT.
+ */
+
+#include <dmlc/parameter.h>
+#include <tvm/runtime/ndarray.h>
+#include <tvm/runtime/registry.h>
+
+#include <fstream>
+
+#include "../../file_util.h"
+#include "../json/json_node.h"
+#include "../json/json_runtime.h"
+
+#ifdef TVM_GRAPH_RUNTIME_TENSORRT
+#include "NvInfer.h"
+#include "tensorrt_builder.h"
+#endif
+
+namespace tvm {
+namespace runtime {
+namespace contrib {
+
+using namespace tvm::runtime::json;
+
+class TensorRTRuntime : public JSONRuntimeBase {
+ public:
+  /*!
+   * \brief The TensorRT runtime module. Deserialize the provided functions
+   * on creation and store in the layer cache.
+   *
+   * \param symbol_name The name of the function.
+   * \param graph_json The serialized JSON representation of a sub-graph.
+   * \param const_names The names of each constant in the sub-graph.
+   */
+  explicit TensorRTRuntime(const std::string& symbol_name, const std::string& graph_json,
+                           const Array<String>& const_names)
+      : JSONRuntimeBase(symbol_name, graph_json, const_names), use_implicit_batch_(true),
+        max_workspace_size_(size_t(1) << 30) {}
+
+  /*!
+   * \brief The type key of the module.
+   *
+   * \return module type key.
+   */
+  const char* type_key() const override { return "tensorrt"; }
+
+  /*!
+   * \brief Initialize runtime. Create TensorRT layers from the JSON
+   * representation.
+   *
+   * \param consts The constant params from compiled model.
+   */
+  void Init(const Array<NDArray>& consts) override {
+    CHECK_EQ(consts.size(), const_idx_.size())
+        << "The number of input constants must match the number of required.";
+    LoadGlobalAttributes();
+    if (GetCachedEnginesFromDisk()) return;
+    SetupConstants(consts);
+    BuildEngine();
+    CacheEngineToDisk();
+  }
+
+  void LoadGlobalAttributes() {
+    // These settings are global to the entire subgraph. Codegen will add them as attributes to all
+    // op nodes. Read from first one.
+    for (size_t i = 0; i < nodes_.size(); ++i) {
+      if (nodes_[i].HasAttr("use_implicit_batch") && nodes_[i].HasAttr("max_workspace_size")) {
+        use_implicit_batch_ =
+            std::stoi(nodes_[i].GetAttr<std::vector<std::string>>("use_implicit_batch")[0]);
+        // Allow max_workspace_size to be overridden at runtime.
+        size_t runtime_max_workspace_size =
+            dmlc::GetEnv("TVM_TENSORRT_MAX_WORKSPACE_SIZE", size_t(0));
+        if (runtime_max_workspace_size != 0) {
+          max_workspace_size_ = runtime_max_workspace_size;
+        } else {
+          max_workspace_size_ =
+              std::stoul(nodes_[i].GetAttr<std::vector<std::string>>("max_workspace_size")[0]);
+        }
+        return;
+      }
+    }
+  }
+
+#ifdef TVM_GRAPH_RUNTIME_TENSORRT
+  /*! \brief Run inference using built engine. */
+  void Run() override {
+    auto& engine_and_context = trt_engine_cache_.at(symbol_name_);
+    auto engine = engine_and_context.engine;
+    auto context = engine_and_context.context;
+    std::vector<void*> bindings(engine->getNbBindings(), nullptr);
+
+    for (size_t i = 0; i < input_nodes_.size(); ++i) {
+      auto nid = input_nodes_[i];
+      if (nodes_[nid].GetOpType() == "input") {
+        for (size_t j = 0; j < nodes_[nid].GetOpShape().size(); ++j) {
+          uint32_t eid = EntryID(nid, j);
+          const std::string name = nodes_[nid].GetOpName() + "_" + std::to_string(j);
+          int binding_index = engine->getBindingIndex(name.c_str());
+          CHECK_NE(binding_index, -1);
+          bindings[binding_index] = data_entry_[eid]->data;
+        }
+      }
+    }
+
+    for (size_t i = 0; i < outputs_.size(); ++i) {
+      uint32_t eid = EntryID(outputs_[i]);
+      const std::string& name = engine_and_context.outputs[i];
+      int binding_index = engine->getBindingIndex(name.c_str());
+      CHECK_NE(binding_index, -1);
+      bindings[binding_index] = data_entry_[eid]->data;
+    }
+
+#if TRT_VERSION_GE(6, 0, 1)
+    if (use_implicit_batch_) {
+      CHECK(context->execute(batch_size_, bindings.data())) << "Running TensorRT failed.";
+    } else {
+      CHECK(context->executeV2(bindings.data())) << "Running TensorRT failed.";
+    }
+#else
+    CHECK(context->execute(batch_size_, bindings.data())) << "Running TensorRT failed.";
+#endif
+  }
+
+ private:
+  /*!
+   * \brief Build TensorRT engine from JSON representation.
+   */
+  void BuildEngine() {
+    LOG(INFO) << "Building new TensorRT engine for subgraph " << symbol_name_;
+    const bool use_fp16 = dmlc::GetEnv("TVM_TENSORRT_USE_FP16", false);
+    batch_size_ = GetBatchSize();
+    TensorRTBuilder builder(&logger_, max_workspace_size_, use_implicit_batch_, use_fp16,
+                            batch_size_);
+
+    // Add inputs and constants.
+    for (size_t i = 0; i < input_nodes_.size(); ++i) {
+      auto nid = input_nodes_[i];
+      const auto& node = nodes_[nid];
+      if (node.GetOpType() == "input") {
+        builder.AddInput(nid, node);
+      } else {
+        CHECK_EQ(node.GetOpType(), "const");
+        uint32_t eid = EntryID(nid, 0);
+        builder.AddConstant(nid, data_entry_[eid]);
+      }
+    }
+
+    // Add layers.
+    for (size_t nid = 0; nid < nodes_.size(); ++nid) {
+      const auto& node = nodes_[nid];
+      if (node.GetOpType() != "kernel") continue;
+      builder.AddLayer(nid, node);
+    }
+
+    // Add outputs.
+    for (size_t i = 0; i < outputs_.size(); ++i) {
+      builder.AddOutput(outputs_[i]);
+    }
+
+    // Build engine.
+    trt_engine_cache_[symbol_name_] = builder.BuildEngine();
+    LOG(INFO) << "Finished building TensorRT engine for subgraph " << symbol_name_;
+  }
+
+  /*! \brief If TVM_TENSORRT_CACHE_DIR is set, check that directory for
+   * already-built TRT engines and load them into trt_engine_cache_ so they
+   * do not have to be rebuilt at first inference.
+   */
+  bool GetCachedEnginesFromDisk() {
+    std::string cache_dir = dmlc::GetEnv("TVM_TENSORRT_CACHE_DIR", std::string(""));
+    if (cache_dir.empty()) return false;
+    std::string key = GetSubgraphKey();
+    std::string path = cache_dir + "/" + key + ".plan";
+    // Check if engine is in the cache.
+    std::ifstream infile(path, std::ios::binary);
+    if (!infile.good()) return false;
+    LOG(INFO) << "Loading cached TensorRT engine from " << path;

Review comment:
       I suggest we use DLOG for all

##########
File path: src/runtime/contrib/tensorrt/tensorrt_runtime.cc
##########
+  /*! \brief If TVM_TENSORRT_CACHE_DIR is set, check that directory for
+   * already-built TRT engines and load them into trt_engine_cache_ so they
+   * do not have to be rebuilt at first inference.
+   */
+  bool GetCachedEnginesFromDisk() {

Review comment:
       I am not sure if we need these two serialization methods. Can we just rely on LoadFrom/SaveToBinary?
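
       For context, a hedged sketch of the standard path those hooks feed into
       (`mod` and `params` here are assumed to come from partitioning; the file
       name is illustrative):

       ```python
       import tvm
       from tvm import relay

       with tvm.transform.PassContext(opt_level=3):
           lib = relay.build(mod, target="cuda", params=params)
       lib.export_library("compiled.so")                # invokes SaveToBinary
       loaded = tvm.runtime.load_module("compiled.so")  # invokes LoadFromBinary
       ```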

##########
File path: python/tvm/relay/op/contrib/tensorrt.py
##########
@@ -0,0 +1,671 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, unused-argument
+"""TensorRT supported operators."""
+import tvm
+from tvm import relay
+from tvm.relay import transform
+from tvm.relay.build_module import bind_params_by_name
+from tvm.relay.expr import Call, Constant, Tuple, GlobalVar
+from tvm.relay.expr_functor import ExprMutator
+
+import os
+import logging
+import numpy as np
+
+
+def is_tensorrt_runtime_enabled():
+    """Check if the TensorRT graph runtime is present.
+    Returns
+    -------
+    ret: bool
+        True if present, False if not.
+    """
+    check_enabled = tvm.get_global_func("relay.op.is_tensorrt_runtime_enabled", True)
+    if check_enabled:
+        return check_enabled()
+    return False
+
+def get_tensorrt_version():
+    """Gets the version of TensorRT that TVM is built against or is targeting.
+
+    Returns
+    -------
+    ret: Tuple[int]
+        TensorRT version as a tuple of major, minor, and patch number. If TVM
+        is not built with TensorRT, the version from the "relay.ext.tensorrt.options"
+        PassContext config is returned instead.
+    """
+    pass_ctx = tvm.transform.PassContext.current()
+    if "relay.ext.tensorrt.options" in pass_ctx.config:
+        return tuple(pass_ctx.config["relay.ext.tensorrt.options"].tensorrt_version)
+    return tuple(tvm.get_global_func("relay.op.get_tensorrt_version")())
+
+def get_tensorrt_use_implicit_batch_mode():
+    """Check the PassContext for the use_implicit_batch setting (default True)."""
+    pass_ctx = tvm.transform.PassContext.current()
+    if "relay.ext.tensorrt.options" in pass_ctx.config:
+        return pass_ctx.config["relay.ext.tensorrt.options"].use_implicit_batch
+    return True
+
+def get_tensorrt_remove_no_mac_subgraphs():
+    """Check the PassContext for the remove_no_mac_subgraphs setting (default False)."""
+    pass_ctx = tvm.transform.PassContext.current()
+    if "relay.ext.tensorrt.options" in pass_ctx.config:
+        return pass_ctx.config["relay.ext.tensorrt.options"].remove_no_mac_subgraphs
+    return False
+
+def partition_for_tensorrt(mod, params=None, version=None, use_implicit_batch=True,
+                           remove_no_mac_subgraphs=False, max_workspace_size=1 << 30):
+    """Partition the graph greedily offloading supported
+    operators to TensorRT.
+    Parameters
+    ----------
+    mod : Module
+        The module to run passes on.
+    params : Optional[Dict[str, NDArray]]
+        Constant input parameters.
+    version : Optional[Tuple(int)]
+        TensorRT version to target as tuple of (major, minor, patch). If TVM is compiled with
+        USE_TENSORRT_GRAPH_RUNTIME=ON, the linked TensorRT version will be used instead.
+    use_implicit_batch : Optional[bool]
+        Use TensorRT implicit batch mode (default True). Setting this to False enables explicit
+        batch mode, which widens the set of supported operators to include those that modify the
+        batch dimension, but may reduce performance for some models.
+    remove_no_mac_subgraphs : Optional[bool]
+        Removes subgraphs which have been partitioned for TensorRT if they do not have any
+        multiply-accumulate operations. The removed subgraphs will go through TVM's standard
+        compilation instead. Can improve performance.
+    max_workspace_size : Optional[int]
+        How many bytes of workspace size to allow each subgraph to use for
+        TensorRT engine creation. See TensorRT documentation for more info.
+    Returns
+    -------
+    mod : annotated and partitioned module.
+    config : "relay.ext.tensorrt.options" configuration which should be given to PassContext when building.
+    """
+    config = {
+        "use_implicit_batch": use_implicit_batch,
+        "max_workspace_size": max_workspace_size,
+        "remove_no_mac_subgraphs": remove_no_mac_subgraphs
+    }
+    if version:
+        assert isinstance(version, tuple) and len(version) == 3
+        config["tensorrt_version"] = version
+    else:
+        linked_version = tuple(tvm.get_global_func("relay.op.get_tensorrt_version")())
+        if not linked_version:
+            logging.warn("TVM was not built against TensorRT and no version was provided to partition_for_tensorrt. Defaulting to 6.0.1")
+            linked_version = (6, 0, 1)
+        config["tensorrt_version"] = linked_version
+
+    if params:
+        mod['main'] = bind_params_by_name(mod['main'], params)
+    seq = tvm.transform.Sequential([transform.InferType(),
+                                    RemoveDropoutPass(),
+                                    transform.RemoveUnusedFunctions(),
+                                    transform.ConvertLayout({'nn.conv2d': ['NCHW', 'default'],
+                                                             'nn.conv3d': ['NCDHW', 'default']}),
+                                    transform.FoldConstant(),
+                                    transform.AnnotateTarget('tensorrt'),
+                                    transform.MergeCompilerRegions(),
+                                    transform.PartitionGraph(),
+                                    transform.InferType()])
+    with tvm.transform.PassContext(opt_level=3, config={"relay.ext.tensorrt.options": config}):
+        mod = seq(mod)
+        mod = prune_tensorrt_subgraphs(mod)
+    return mod, config
+
+
+def _register_external_op_helper(op_name, supported=True):
+    @tvm.ir.register_op_attr(op_name, "target.tensorrt")
+    def _func_wrapper(attrs, args):
+        if any([x.checked_type.dtype != "float32" for x in args]):
+            print("Only float32 inputs are supported for TensorRT.")
+            return False
+        return supported
+    return _func_wrapper
+
+
+def _register_external_op_helper_func(op_name, func):
+    @tvm.ir.register_op_attr(op_name, "target.tensorrt")
+    def _func_wrapper(attrs, args):
+        if any([x.checked_type.dtype != "float32" for x in args]):
+            print("Only float32 inputs are supported for TensorRT.")
+            return False
+        return func(attrs, args, op_name)
+    return _func_wrapper
+
+
+# Ops which are always supported
+_register_external_op_helper("nn.relu")
+_register_external_op_helper("sigmoid")
+_register_external_op_helper("tanh")
+_register_external_op_helper("subtract")
+_register_external_op_helper("multiply")
+_register_external_op_helper("divide")
+_register_external_op_helper("power")
+_register_external_op_helper("maximum")
+_register_external_op_helper("minimum")
+_register_external_op_helper("exp")
+_register_external_op_helper("log")
+_register_external_op_helper("sqrt")
+_register_external_op_helper("abs")
+_register_external_op_helper("negative")
+_register_external_op_helper("nn.batch_flatten")
+_register_external_op_helper("clip")
+
+@tvm.ir.register_op_attr("add", "target.tensorrt")
+def add_annotate_fn(attrs, args): # pylint: disable=unused-variable
+    if any([x.checked_type.dtype != "float32" for x in args]):
+        print("Only float32 inputs are supported for TensorRT.")
+        return False
+    if not get_tensorrt_use_implicit_batch_mode() and \
+            (isinstance(args[0], Constant) or isinstance(args[1], Constant)) and \
+            args[0].checked_type.shape[0] == args[1].checked_type.shape[0] and \
+            args[0].checked_type.shape[0] != 1 and \
+            (len(args[0].checked_type.shape) > 3 or len(args[1].checked_type.shape) > 3):
+        print("add: bug in TRT with adding batched constants.")
+        return False
+    return True
+
+@tvm.ir.register_op_attr("nn.batch_norm", "target.tensorrt")
+def batch_norm_annotate_fn(attrs, args): # pylint: disable=unused-variable
+    if any([x.checked_type.dtype != "float32" for x in args]):
+        print("Only float32 inputs are supported for TensorRT.")
+        return False
+    if int(attrs.axis) not in (1, 3):
+        print("nn.batch_norm: axis is {} but must be 1 or 3.".format(int(attrs.axis)))
+        return False
+    return True
+
+@tvm.ir.register_op_attr("nn.softmax", "target.tensorrt")
+def softmax_annotate_fn(attrs, args): # pylint: disable=unused-variable
+    if any([x.checked_type.dtype != "float32" for x in args]):
+        print("Only float32 inputs are supported for TensorRT.")
+        return False
+    if get_tensorrt_use_implicit_batch_mode() and int(attrs.axis) == 0:
+        print("nn.softmax: can't modify batch dimension.")
+        return False
+    return True
+
+@tvm.ir.register_op_attr("nn.conv2d", "target.tensorrt")
+def conv2d_annotate_fn(attrs, args): # pylint: disable=unused-variable
+    if any([x.checked_type.dtype != "float32" for x in args]):
+        print("Only float32 inputs are supported for TensorRT.")
+        return False
+    if attrs.data_layout != "NCHW":
+        print("nn.conv2d: data_layout is {} but must be NCHW.".format(attrs.data_layout))
+        return False
+    if attrs.kernel_layout != "OIHW":
+        print("nn.conv2d: kernel_layout is {} but must be OIHW.".format(attrs.kernel_layout))
+        return False
+    if attrs.out_layout and attrs.out_layout != "NCHW":
+        print("nn.conv2d: out_layout is {} but must be NCHW.".format(attrs.out_layout))
+        return False
+    return True
+
+@tvm.ir.register_op_attr("nn.dense", "target.tensorrt")
+def dense_annotate_fn(attrs, args): # pylint: disable=unused-variable
+    if any([x.checked_type.dtype != "float32" for x in args]):
+        print("Only float32 inputs are supported for TensorRT.")
+        return False
+    input_rank = len(args[0].checked_type.shape)
+    weight_rank = len(args[1].checked_type.shape)
+    if input_rank not in (2, 3, 4):
+        print("nn.dense: input has rank {} but must be 2, 3 or 4.".format(input_rank))
+        return False
+    if weight_rank != 2:
+        print("nn.dense: weight has rank {} but must be 2.".format(weight_rank))
+        return False
+    return True
+
+@tvm.ir.register_op_attr("nn.bias_add", "target.tensorrt")
+def bias_add_annotate_fn(attrs, args): # pylint: disable=unused-variable
+    if any([x.checked_type.dtype != "float32" for x in args]):
+        print("Only float32 inputs are supported for TensorRT.")
+        return False
+    input_rank = len(args[0].checked_type.shape)
+    if input_rank not in (2, 3, 4):
+        print("nn.bias_add: input rank is {} but must be 2, 3 or 4.".format(input_rank))
+        return False
+    return True
+
+@tvm.ir.register_op_attr("nn.max_pool2d", "target.tensorrt")
+def max_pool_2d_annotate_fn(attrs, args): # pylint: disable=unused-variable
+    if any([x.checked_type.dtype != "float32" for x in args]):
+        print("Only float32 inputs are supported for TensorRT.")
+        return False
+    if attrs.layout != "NCHW":
+        print("nn.max_pool2d: layout is {} but must be NCHW.".format(attrs.layout))
+        return False
+    if attrs.ceil_mode and get_tensorrt_version() < (5, 1, 5):
+        print("nn.avg_pool2d: ceil_mode=True requires TensorRT 5.1.5 or greater.")
+        return False
+    return True
+
+@tvm.ir.register_op_attr("nn.avg_pool2d", "target.tensorrt")
+def avg_pool_2d_annotate_fn(attrs, args): # pylint: disable=unused-variable
+    if any([x.checked_type.dtype != "float32" for x in args]):
+        print("Only float32 inputs are supported for TensorRT.")
+        return False
+    if attrs.layout != "NCHW":
+        print("nn.avg_pool2d: layout is {} but must be NCHW.".format(attrs.layout))
+        return False
+    if attrs.count_include_pad and len(attrs.padding) == 4:
+        print("nn.avg_pool2d: inclusive-counted blended or average "
+                "pooling is not supported in combination with asymmetric padding")
+        return False
+    if attrs.ceil_mode and get_tensorrt_version() < (5, 1, 5):
+        print("nn.avg_pool2d: ceil_mode=True requires TensorRT 5.1.5 or greater.")
+        return False
+    return True
+
+@tvm.ir.register_op_attr("nn.global_max_pool2d", "target.tensorrt")
+def global_max_pool_2d_annotate_fn(attrs, args): # pylint: disable=unused-variable
+    if any([x.checked_type.dtype != "float32" for x in args]):
+        print("Only float32 inputs are supported for TensorRT.")
+        return False
+    if attrs.layout != "NCHW":
+        print("nn.global_max_pool2d: layout is {} but must be NCHW.".format(attrs.layout))
+        return False
+    return True
+
+@tvm.ir.register_op_attr("nn.global_avg_pool2d", "target.tensorrt")
+def global_avg_pool_2d_annotate_fn(attrs, args): # pylint: disable=unused-variable
+    if any([x.checked_type.dtype != "float32" for x in args]):
+        print("Only float32 inputs are supported for TensorRT.")
+        return False
+    if attrs.layout != "NCHW":
+        print("nn.global_avg_pool2d: layout is {} but must be NCHW.".format(attrs.layout))
+        return False
+    return True
+
+@tvm.ir.register_op_attr("expand_dims", "target.tensorrt")
+def expand_dims_annotate_fn(attrs, args): # pylint: disable=unused-variable
+    if any([x.checked_type.dtype != "float32" for x in args]):
+        print("Only float32 inputs are supported for TensorRT.")
+        return False
+    if get_tensorrt_use_implicit_batch_mode() and int(attrs.axis) == 0:
+        print("expand_dims: can't modify batch dimension.")
+        return False
+    return True
+
+@tvm.ir.register_op_attr("squeeze", "target.tensorrt")
+def squeeze_annotate_fn(attrs, args): # pylint: disable=unused-variable
+    if any([x.checked_type.dtype != "float32" for x in args]):
+        print("Only float32 inputs are supported for TensorRT.")
+        return False
+    if not attrs.axis:
+        print("squeeze: must explicitly set axis.")
+        return False
+    if get_tensorrt_use_implicit_batch_mode() and any([axis == 0 for axis in map(int, attrs.axis)]):
+        print("squeeze: can't modify batch dimension.")
+        return False
+    return True
+
+@tvm.ir.register_op_attr("concatenate", "target.tensorrt")
+def concatenate_annotate_fn(attrs, args): # pylint: disable=unused-variable
+    if any([x.dtype != "float32" for x in args[0].checked_type.fields]):
+        print("Only float32 inputs are supported for TensorRT.")
+        return False
+    if not get_tensorrt_use_implicit_batch_mode():
+        return True
+    if int(attrs.axis) == 0:
+        print("concatenate: can't modify batch dimension.")
+        return False
+    if isinstance(args[0], Tuple):
+        for tuple_input in args[0].fields:
+            if isinstance(tuple_input, Constant):
+                print("concatenate: can't concatenate tensors with constants.")
+                return False
+    return True
+
+@tvm.ir.register_op_attr("nn.conv2d_transpose", "target.tensorrt")
+def conv2d_transpose_annotate_fn(attrs, args): # pylint: disable=unused-variable
+    if any([x.checked_type.dtype != "float32" for x in args]):
+        print("Only float32 inputs are supported for TensorRT.")
+        return False
+    if attrs.data_layout != "NCHW":
+        print("nn.conv2d_transpose: data_layout is {} but must be NCHW.".format(
+            attrs.data_layout))
+        return False
+    if attrs.kernel_layout != "OIHW":
+        print("nn.conv2d_transpose: kernel_layout is {} but must be OIHW.".format(
+            attrs.kernel_layout))
+        return False
+    if attrs.out_layout and attrs.out_layout != "NCHW":
+        print("nn.conv2d_transpose: out_layout is {} but must be NCHW.".format(
+            attrs.out_layout))
+        return False
+    if attrs.dilation and any([rate != 1 for rate in map(int, attrs.dilation)]):
+        print("nn.conv2d_transpose: dilation rate must be 1.")
+        return False
+    return True
+
+@tvm.ir.register_op_attr("transpose", "target.tensorrt")
+def transpose_annotate_fn(attrs, args): # pylint: disable=unused-variable
+    if any([x.checked_type.dtype != "float32" for x in args]):
+        print("Only float32 inputs are supported for TensorRT.")
+        return False
+    if get_tensorrt_use_implicit_batch_mode() and int(attrs.axes[0]) != 0:
+        print("transpose: can't modify batch dimension.")
+        return False
+    return True
+
+@tvm.ir.register_op_attr("layout_transform", "target.tensorrt")
+def layout_transform_annotate_fn(attrs, args): # pylint: disable=unused-variable
+    if any([x.checked_type.dtype != "float32" for x in args]):
+        print("Only float32 inputs are supported for TensorRT.")
+        return False
+    if (attrs.src_layout, attrs.dst_layout) not in [("NCHW", "NHWC"), ("NHWC", "NCHW"),
+                                                    ("NDHWC", "NCDHW"), ("NCDHW", "NDHWC")]:
+        print("layout_transform: {} to {} is not supported.".format(
+            attrs.src_layout, attrs.dst_layout))
+        return False
+    return True
+
+@tvm.ir.register_op_attr("reshape", "target.tensorrt")
+def reshape_annotate_fn(attrs, args): # pylint: disable=unused-variable
+    if args[0].checked_type.dtype != "float32":
+        print("Only float32 inputs are supported for TensorRT.")
+        return False
+    if any([x < -1 for x in map(int, attrs.newshape)]):
+        print("reshape: new shape dims must be explicit.")
+        return False
+    if get_tensorrt_use_implicit_batch_mode():
+        shape = list(map(int, args[0].checked_type.shape))
+        new_shape = list(map(int, attrs.newshape))
+        if len(new_shape) == 0 or len(shape) == 0:
+            print("reshape: Can't reshape to or from scalar.")
+            return False
+        # TRT cannot modify batch dimension.
+        original_volume = np.prod(shape)
+        # First, resolve 0.
+        for i, value in enumerate(new_shape):
+            if value == 0:
+                new_shape[i] = shape[i]
+        # Resolve -1.
+        for i, value in enumerate(new_shape):
+            if value == -1:
+                new_shape[i] = original_volume // np.prod([x for x in new_shape if x != -1])
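+        # e.g. shape=[4, 1, 50], newshape=[0, -1]: 0 -> 4, then -1 -> 200 // 4 = 50,
+        # giving new_shape=[4, 50]. The batch dim is unchanged, so TRT accepts it.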
+        if shape[0] != new_shape[0]:
+            print("reshape: can't modify batch dimension.")
+            return False
+    return True
+
+@tvm.ir.register_op_attr("nn.pad", "target.tensorrt")
+def pad_annotate_fn(attrs, args): # pylint: disable=unused-variable
+    if any([x.checked_type.dtype != "float32" for x in args]):
+        print("Only float32 inputs are supported for TensorRT.")
+        return False
+    if attrs.pad_mode != "constant":
+        print("nn.pad: pad mode is {} but must be constant.".format(attrs.pad_mode))
+        return False
+    if float(attrs.pad_value) != 0.0:
+        print("nn.pad: pad value is {} but must be 0.0.".format(float(attrs.pad_value)))
+        return False
+    if any([x != 0 for x in attrs.pad_width[0]]) or any([x != 0 for x in attrs.pad_width[1]]):
+        print("nn.pad: can't pad batch or channel dimensions.")
+        return False
+    if len(attrs.pad_width) == 5 and any([x != 0 for x in attrs.pad_width[2]]):
+        print("nn.pad: can only pad last two dimensions for 5D inputs.")
+        return False
+    return True
+
+def reduce_annotate_fn(attrs, args, op_name):
+    if not attrs.axis or len(attrs.axis) == 0:
+        print("{}: cannot reduce to scalar.".format(op_name))
+        return False
+    if attrs.exclude:
+        print("{}: exclude not supported.".format(op_name))
+        return False
+    if get_tensorrt_use_implicit_batch_mode() and any([x == 0 for x in map(int, attrs.axis)]):
+        print("{}: can't modify batch dimension.".format(op_name))
+        return False
+    return True
+
+_register_external_op_helper_func("sum", reduce_annotate_fn)
+_register_external_op_helper_func("prod", reduce_annotate_fn)
+_register_external_op_helper_func("max", reduce_annotate_fn)
+_register_external_op_helper_func("min", reduce_annotate_fn)
+_register_external_op_helper_func("mean", reduce_annotate_fn)
+
+def trt_5_1_5_annotate_fn(attrs, args, op_name):
+    if get_tensorrt_version() < (5, 1, 5):
+        print("{}: requires TensorRT version 5.1.5 or higher.".format(op_name))
+        return False
+    return True
+
+_register_external_op_helper_func("nn.leaky_relu", trt_5_1_5_annotate_fn)
+_register_external_op_helper_func("sin", trt_5_1_5_annotate_fn)
+_register_external_op_helper_func("cos", trt_5_1_5_annotate_fn)
+_register_external_op_helper_func("atan", trt_5_1_5_annotate_fn)
+_register_external_op_helper_func("ceil", trt_5_1_5_annotate_fn)
+
+@tvm.ir.register_op_attr("strided_slice", "target.tensorrt")
+def strided_slice_annotate_fn(attrs, args): # pylint: disable=unused-variable
+    if args[0].checked_type.dtype != "float32":
+        print("Only float32 inputs are supported for TensorRT.")
+        return False
+    if not trt_5_1_5_annotate_fn(attrs, args, "strided_slice"):
+        return False
+    if get_tensorrt_use_implicit_batch_mode():
+        batch_dim_begin_modified = attrs.begin[0] is not None and int(attrs.begin[0]) != 0
+        batch_dim_end_modified = attrs.end[0] is not None and int(attrs.end[0]) != -1 and \
+                                    int(attrs.end[0]) != int(args[0].checked_type.shape[0])
+        if batch_dim_begin_modified or batch_dim_end_modified:
+            print("strided_slice: can't modify batch dimension.")
+            return False
+    if any([x is not None and x <= 0 for x in attrs.strides]):
+        print("strided_slice: stride must be positive")
+        return False
+    return True
+
+@tvm.ir.register_op_attr("nn.adaptive_max_pool2d", "target.tensorrt")
+def adaptive_max_pool2d_annotate_fn(attrs, args): # pylint: disable=unused-variable
+    if any([x.checked_type.dtype != "float32" for x in args]):
+        print("Only float32 inputs are supported for TensorRT.")
+        return False
+    if len(attrs.output_size) == 0 or any([size != 1 for size in map(int, attrs.output_size)]):
+        print("nn.adaptive_max_pool2d: output size must be (1, 1).")
+        return False
+    return True
+
+@tvm.ir.register_op_attr("nn.adaptive_avg_pool2d", "target.tensorrt")
+def adaptive_avg_pool2d_annotate_fn(attrs, args): # pylint: disable=unused-variable
+    if any([x.checked_type.dtype != "float32" for x in args]):
+        print("Only float32 inputs are supported for TensorRT.")
+        return False
+    if len(attrs.output_size) == 0 or any([size != 1 for size in map(int, attrs.output_size)]):
+        print("nn.adaptive_avg_pool2d: output size must be (1, 1).")
+        return False
+    return True
+
+@tvm.ir.register_op_attr("nn.upsampling", "target.tensorrt")
+def upsampling_annotate_fn(attrs, args): # pylint: disable=unused-variable
+    # TODO(trevmorr): Output does not match TVM. Disable.
+    return False
+
+@tvm.ir.register_op_attr("nn.conv3d", "target.tensorrt")
+def conv3d_annotate_fn(attrs, args): # pylint: disable=unused-variable
+    if any([x.checked_type.dtype != "float32" for x in args]):
+        print("Only float32 inputs are supported for TensorRT.")
+        return False
+    if get_tensorrt_version() < (6, 0, 1):
+        print("nn.conv3d: requires TensorRT version 6.0.1 or higher.")
+        return False
+    if attrs.data_layout != "NCDHW":
+        print("nn.conv3d: data_layout is {} but must be NCDHW.".format(attrs.data_layout))
+        return False
+    if attrs.kernel_layout != "OIDHW":
+        print("nn.conv3d: kernel_layout is {} but must be OIDHW.".format(attrs.kernel_layout))
+        return False
+    if attrs.out_layout and attrs.out_layout != "NCDHW":
+        print("nn.conv3d: out_layout is {} but must be NCDHW.".format(attrs.out_layout))
+        return False
+    return True
+
+@tvm.ir.register_op_attr("nn.max_pool3d", "target.tensorrt")
+def max_pool_3d_annotate_fn(attrs, args): # pylint: disable=unused-variable
+    if any([x.checked_type.dtype != "float32" for x in args]):
+        print("Only float32 inputs are supported for TensorRT.")
+        return False
+    if get_tensorrt_version() < (6, 0, 1):
+        print("nn.max_pool3d: requires TensorRT version 6.0.1 or higher.")
+        return False
+    if attrs.layout != "NCDHW":
+        print("nn.max_pool3d: layout is {} but must be NCDHW.".format(attrs.layout))
+        return False
+    return True
+
+@tvm.ir.register_op_attr("nn.avg_pool3d", "target.tensorrt")
+def avg_pool_3d_annotate_fn(attrs, args): # pylint: disable=unused-variable
+    if any([x.checked_type.dtype != "float32" for x in args]):
+        print("Only float32 inputs are supported for TensorRT.")
+        return False
+    if get_tensorrt_version() < (6, 0, 1):
+        print("nn.avg_pool3d: requires TensorRT version 6.0.1 or higher.")
+        return False
+    if attrs.layout != "NCDHW":
+        print("nn.avg_pool3d: layout is {} but must be NCDHW.".format(attrs.layout))
+        return False
+    return True
+
+@tvm.ir.register_op_attr("nn.conv3d_transpose", "target.tensorrt")
+def conv3d_transpose_annotate_fn(attrs, args): # pylint: disable=unused-variable
+    if any([x.checked_type.dtype != "float32" for x in args]):
+        print("Only float32 inputs are supported for TensorRT.")
+        return False
+    if get_tensorrt_version() < (6, 0, 1):
+        print("nn.conv3d_transpose: requires TensorRT version 6.0.1 or higher.")
+        return False
+    if attrs.data_layout != "NCDHW":
+        print("nn.conv3d_transpose: data_layout is {} but must be NCDHW.".format(
+            attrs.data_layout))
+        return False
+    if attrs.kernel_layout != "OIDHW":
+        print("nn.conv3d_transpose: kernel_layout is {} but must be OIDHW.".format(
+            attrs.kernel_layout))
+        return False
+    if attrs.out_layout and attrs.out_layout != "NCDHW":
+        print("nn.conv3d_transpose: out_layout is {} but must be NCDHW.".format(
+            attrs.out_layout))
+        return False
+    if attrs.dilation and any([rate != 1 for rate in map(int, attrs.dilation)]):
+        print("nn.conv3d_transpose: dilation rate must be 1.")
+        return False
+    if attrs.output_padding and any([x != 0 for x in map(int, attrs.output_padding)]):
+        print("nn.conv3d_transpose: output padding is not supported.")
+        return False
+    return True
+
+def is_invalid_subgraph(params, body):
+    # Remove invalid subgraphs for implicit batch mode.
+    if get_tensorrt_use_implicit_batch_mode():
+        input_batch_sizes = []
+        for var in params:
+            # In implicit batch mode, all inputs must have same batch size
+            if isinstance(var.checked_type, relay.TupleType):
+                for tuple_type in var.checked_type.fields:
+                    # Scalar inputs not allowed
+                    if len(tuple_type.shape) == 0:
+                        print('tensorrt: scalar inputs not supported')
+                        return True
+                    input_batch_sizes.append(int(tuple_type.shape[0]))
+            else:
+                # Scalar inputs not allowed
+                if len(var.checked_type.shape) == 0:
+                    print('tensorrt: scalar inputs not supported')
+                    return True
+                input_batch_sizes.append(int(var.checked_type.shape[0]))
+        if len(input_batch_sizes) > 1 and \

Review comment:
       `if len(input_batch_sizes) > 1 and len(set(input_batch_sizes)) == 1`?
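
       Presumably with the comparison flipped, e.g. `len(set(input_batch_sizes)) > 1`,
       since the branch should fire when the batch sizes differ:

       ```python
       input_batch_sizes = [4, 4, 8]
       if len(set(input_batch_sizes)) > 1:  # differing batch sizes -> invalid subgraph
           print('tensorrt: inputs have different batch sizes')
       ```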




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org