Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2020/09/04 09:16:43 UTC

[GitHub] [incubator-tvm] leandron commented on a change in pull request #6395: [BYOC][TensorRT] TensorRT BYOC integration

leandron commented on a change in pull request #6395:
URL: https://github.com/apache/incubator-tvm/pull/6395#discussion_r483477103



##########
File path: python/tvm/relay/op/contrib/tensorrt.py
##########
@@ -0,0 +1,675 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, unused-argument
+"""TensorRT supported operators."""
+import tvm
+from tvm import relay
+from tvm.relay import transform
+from tvm.relay.build_module import bind_params_by_name
+from tvm.relay.expr import Call, Constant, Tuple, GlobalVar
+from tvm.relay.expr_functor import ExprMutator
+
+import os
+import numpy as np
+
+# Version to use for annotation when there is no linked TRT.
+TENSORRT_VERSION = (6, 0, 1)
+USE_IMPLICIT_BATCH = True
+REMOVE_NO_MAC_SUBGRAPHS = False
+
+def is_tensorrt_runtime_enabled():
+    """Check if the TensorRT graph runtime is present.
+    Returns
+    -------
+    ret: bool
+        True if present, False if not.
+    """
+    check_enabled = tvm.get_global_func("relay.op.is_tensorrt_runtime_enabled", True)
+    if check_enabled:
+        return check_enabled()
+    return False
+
+def get_tensorrt_version():
+    """Gets the version of TensorRT that TVM is built against.
+
+    Returns
+    -------
+    ret: Tuple[int]
+        TensorRT version as a tuple of major, minor, and patch number. If TVM
+        is not built with TensorRT, the value set by set_tensorrt_version() is returned instead.
+    """
+    linked_ver = tuple(tvm.get_global_func("relay.op.get_tensorrt_version")())
+    if len(linked_ver) == 3:
+        return linked_ver
+    return TENSORRT_VERSION
+
+def set_tensorrt_version(version):
+    """Override TensorRT version for annotation
+
+    Returns
+    -------
+    version: Tuple[int]
+        TensorRT version as a tuple of major, minor, and patch number. If TVM
+        is not built with TensorRT, an empty tuple is returned instead.
+    """
+    global TENSORRT_VERSION
+    TENSORRT_VERSION = version
+
+def get_tensorrt_use_implicit_batch_mode():
+    return USE_IMPLICIT_BATCH
+
+def set_tensorrt_use_implicit_batch_mode(use_implicit_batch):
+    global USE_IMPLICIT_BATCH
+    USE_IMPLICIT_BATCH = use_implicit_batch
+
+def get_tensorrt_remove_no_mac_subgraphs():
+    return REMOVE_NO_MAC_SUBGRAPHS
+
+def set_tensorrt_remove_no_mac_subgraphs(remove_no_mac_subgraphs):
+    global REMOVE_NO_MAC_SUBGRAPHS
+    REMOVE_NO_MAC_SUBGRAPHS = remove_no_mac_subgraphs
+
+def partition_for_tensorrt(mod, params=None, version=None, use_implicit_batch=True, remove_no_mac_subgraphs=False, max_workspace_size=1 << 30):
+    """Partition the graph greedily offloading supported
+    operators to TensorRT.
+    Parameters
+    ----------
+    mod : Module
+        The module to run passes on.
+    params : Optional[Dict[str, NDArray]]
+        Constant input parameters.
+    version : Optional[Tuple[int]]
+        TensorRT version to target as a tuple of (major, minor, patch). If not
+        specified, the linked TRT version is used if available.
+    use_implicit_batch : Optional[bool]
+        Use TensorRT implicit batch mode (default True). Explicit batch mode
+        allows more dynamism in the model but may be less performant.
+    remove_no_mac_subgraphs : Optional[bool]
+        If True, remove partitioned subgraphs that contain no multiply-accumulate
+        operations, since offloading them to TensorRT is unlikely to improve
+        performance.
+    max_workspace_size : Optional[int]
+        Maximum workspace size in bytes made available to TensorRT (default
+        1 GiB).
+
+    Returns
+    -------
+    ret : annotated and partitioned module.

Review comment:
       Maybe `ret` can be replaced by `mod` here?
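       For context, a minimal sketch of how this entry point is meant to be used,
       assuming a Relay module `mod` and `params` obtained from a frontend import
       (the `target` and build flow here are illustrative, not part of this patch):
   
   ```
   import tvm
   from tvm import relay
   from tvm.relay.op.contrib.tensorrt import partition_for_tensorrt
   
   # Annotate and partition the module; supported operators are offloaded to
   # TensorRT, everything else goes through TVM's normal compilation path.
   mod = partition_for_tensorrt(mod, params)
   with tvm.transform.PassContext(opt_level=3):
       lib = relay.build(mod, target="cuda", params=params)
   ```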

##########
File path: python/tvm/relay/op/contrib/tensorrt.py
##########
@@ -0,0 +1,675 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, unused-argument
+"""TensorRT supported operators."""
+import tvm
+from tvm import relay
+from tvm.relay import transform
+from tvm.relay.build_module import bind_params_by_name
+from tvm.relay.expr import Call, Constant, Tuple, GlobalVar
+from tvm.relay.expr_functor import ExprMutator
+
+import os
+import numpy as np
+
+# Version to use for annotation when there is no linked TRT.
+TENSORRT_VERSION = (6, 0, 1)
+USE_IMPLICIT_BATCH = True
+REMOVE_NO_MAC_SUBGRAPHS = False
+
+def is_tensorrt_runtime_enabled():
+    """Check if the TensorRT graph runtime is present.
+    Returns
+    -------
+    ret: bool
+        True if present, False if not.
+    """
+    check_enabled = tvm.get_global_func("relay.op.is_tensorrt_runtime_enabled", True)
+    if check_enabled:
+        return check_enabled()
+    return False
+
+def get_tensorrt_version():
+    """Gets the version of TensorRT that TVM is built against.
+
+    Returns
+    -------
+    ret: Tuple[int]
+        TensorRT version as a tuple of major, minor, and patch number. If TVM
+        is not built with TensorRT, the value set by set_tensorrt_version() is returned instead.
+    """
+    linked_ver = tuple(tvm.get_global_func("relay.op.get_tensorrt_version")())
+    if len(linked_ver) == 3:
+        return linked_ver
+    return TENSORRT_VERSION

Review comment:
       Minor: this can be simplified to a conditional expression:
   
   ```
   return linked_ver if len(linked_ver) == 3 else TENSORRT_VERSION
   ```

##########
File path: python/tvm/relay/op/contrib/tensorrt.py
##########
@@ -0,0 +1,675 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, unused-argument
+"""TensorRT supported operators."""
+import tvm
+from tvm import relay
+from tvm.relay import transform
+from tvm.relay.build_module import bind_params_by_name
+from tvm.relay.expr import Call, Constant, Tuple, GlobalVar
+from tvm.relay.expr_functor import ExprMutator
+
+import os
+import numpy as np
+
+# Version to use for annotation when there is no linked TRT.
+TENSORRT_VERSION = (6, 0, 1)
+USE_IMPLICIT_BATCH = True
+REMOVE_NO_MAC_SUBGRAPHS = False
+
+def is_tensorrt_runtime_enabled():
+    """Check if the TensorRT graph runtime is present.
+    Returns
+    -------
+    ret: bool
+        True if present, False if not.
+    """
+    check_enabled = tvm.get_global_func("relay.op.is_tensorrt_runtime_enabled", True)
+    if check_enabled:
+        return check_enabled()
+    return False
+
+def get_tensorrt_version():
+    """Gets the version of TensorRT that TVM is built against.
+
+    Returns
+    -------
+    ret: Tuple[int]
+        TensorRT version as a tuple of major, minor, and patch number. If TVM
+        is not built with TensorRT, the value set by set_tensorrt_version() is returned instead.
+    """
+    linked_ver = tuple(tvm.get_global_func("relay.op.get_tensorrt_version")())
+    if len(linked_ver) == 3:
+        return linked_ver
+    return TENSORRT_VERSION
+
+def set_tensorrt_version(version):
+    """Override TensorRT version for annotation
+
+    Returns
+    -------
+    version: Tuple[int]
+        TensorRT version as a tuple of major, minor, and patch number. If TVM
+        is not built with TensorRT, an empty tuple is returned instead.
+    """
+    global TENSORRT_VERSION
+    TENSORRT_VERSION = version
+
+def get_tensorrt_use_implicit_batch_mode():
+    return USE_IMPLICIT_BATCH
+
+def set_tensorrt_use_implicit_batch_mode(use_implicit_batch):
+    global USE_IMPLICIT_BATCH
+    USE_IMPLICIT_BATCH = use_implicit_batch
+
+def get_tensorrt_remove_no_mac_subgraphs():
+    return REMOVE_NO_MAC_SUBGRAPHS
+
+def set_tensorrt_remove_no_mac_subgraphs(remove_no_mac_subgraphs):
+    global REMOVE_NO_MAC_SUBGRAPHS
+    REMOVE_NO_MAC_SUBGRAPHS = remove_no_mac_subgraphs
+
+def partition_for_tensorrt(mod, params=None, version=None, use_implicit_batch=True, remove_no_mac_subgraphs=False, max_workspace_size=1 << 30):
+    """Partition the graph greedily offloading supported
+    operators to TensorRT.
+    Parameters
+    ----------
+    mod : Module
+        The module to run passes on.
+    params : Optional[Dict[str, NDArray]]
+        Constant input parameters.
+    version : Optional[Tuple[int]]
+        TensorRT version to target as a tuple of (major, minor, patch). If not
+        specified, the linked TRT version is used if available.
+    use_implicit_batch : Optional[bool]
+        Use TensorRT implicit batch mode (default True). Explicit batch mode
+        allows more dynamism in the model but may be less performant.
+    remove_no_mac_subgraphs : Optional[bool]
+        If True, remove partitioned subgraphs that contain no multiply-accumulate
+        operations, since offloading them to TensorRT is unlikely to improve
+        performance.
+    max_workspace_size : Optional[int]
+        Maximum workspace size in bytes made available to TensorRT (default
+        1 GiB).
+
+    Returns
+    -------
+    ret : annotated and partitioned module.
+    """
+    if version:
+        assert isinstance(version, tuple) and len(version) == 3
+        set_tensorrt_version(version)
+    set_tensorrt_use_implicit_batch_mode(use_implicit_batch)
+    set_tensorrt_remove_no_mac_subgraphs(remove_no_mac_subgraphs)
+    if params:
+        mod['main'] = bind_params_by_name(mod['main'], params)
+
+    seq = tvm.transform.Sequential([transform.InferType(),
+                                    RemoveDropoutPass(),
+                                    transform.RemoveUnusedFunctions(),
+                                    transform.ConvertLayout({'nn.conv2d': ['NCHW', 'default'],
+                                                             'nn.conv3d': ['NCDHW', 'default']}),
+                                    transform.FoldConstant(),
+                                    transform.AnnotateTarget('tensorrt'),
+                                    transform.MergeCompilerRegions(),
+                                    transform.PartitionGraph(),
+                                    transform.InferType()])
+    with tvm.transform.PassContext(opt_level=3):
+        mod = seq(mod)
+    mod = prune_tensorrt_subgraphs(mod)
+    # Pass parameters to codegen
+    os.environ["TVM_TENSORRT_USE_IMPLICIT_BATCH"] = str(int(use_implicit_batch))
+    os.environ["TVM_TENSORRT_MAX_WORKSPACE_SIZE"] = str(int(max_workspace_size))
+    return mod
+
+
+def _register_external_op_helper(op_name, supported=True):
+    @tvm.ir.register_op_attr(op_name, "target.tensorrt")
+    def _func_wrapper(attrs, args):
+        if any([x.checked_type.dtype != "float32" for x in args]):
+            print("Only float32 inputs are supported for TensorRT.")
+            return False
+        return supported
+    return _func_wrapper
+
+
+def _register_external_op_helper_func(op_name, func):
+    @tvm.ir.register_op_attr(op_name, "target.tensorrt")
+    def _func_wrapper(attrs, args):
+        if any([x.checked_type.dtype != "float32" for x in args]):
+            print("Only float32 inputs are supported for TensorRT.")
+            return False
+        return func(attrs, args, op_name)
+    return _func_wrapper
+
+
+# Ops which are always supported
+_register_external_op_helper("nn.relu")
+_register_external_op_helper("sigmoid")
+_register_external_op_helper("tanh")
+_register_external_op_helper("subtract")
+_register_external_op_helper("multiply")
+_register_external_op_helper("divide")
+_register_external_op_helper("power")
+_register_external_op_helper("maximum")
+_register_external_op_helper("minimum")
+_register_external_op_helper("exp")
+_register_external_op_helper("log")
+_register_external_op_helper("sqrt")
+_register_external_op_helper("abs")
+_register_external_op_helper("negative")
+_register_external_op_helper("nn.batch_flatten")
+_register_external_op_helper("clip")
+
+@tvm.ir.register_op_attr("add", "target.tensorrt")
+def add_annotate_fn(attrs, args): # pylint: disable=unused-variable
+    if any([x.checked_type.dtype != "float32" for x in args]):
+        print("Only float32 inputs are supported for TensorRT.")
+        return False
+    if (isinstance(args[0], Constant) or isinstance(args[1], Constant)) and \
+            args[0].checked_type.shape[0] == args[0].checked_type.shape[0] and \

Review comment:
       This comparison `args[0].checked_type.shape[0] == args[0].checked_type.shape[0]` needs some review as it is comparing something with itself. Maybe the right side is expected to be `args[1]` ?
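       If the intent was to compare the batch dimensions of the two inputs, the
       condition would presumably read (sketch only):
   
   ```
   if (isinstance(args[0], Constant) or isinstance(args[1], Constant)) and \
           args[0].checked_type.shape[0] == args[1].checked_type.shape[0] and \
           args[0].checked_type.shape[0] != 1 and \
           (len(args[0].checked_type.shape) > 3 or len(args[1].checked_type.shape) > 3):
       print("add: bug in TRT with adding batched constants.")
       return False
   ```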

##########
File path: python/tvm/relay/op/contrib/tensorrt.py
##########
@@ -0,0 +1,675 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, unused-argument
+"""TensorRT supported operators."""
+import tvm
+from tvm import relay
+from tvm.relay import transform
+from tvm.relay.build_module import bind_params_by_name
+from tvm.relay.expr import Call, Constant, Tuple, GlobalVar
+from tvm.relay.expr_functor import ExprMutator
+
+import os
+import numpy as np
+
+# Version to use for annotation when there is no linked TRT.
+TENSORRT_VERSION = (6, 0, 1)
+USE_IMPLICIT_BATCH = True
+REMOVE_NO_MAC_SUBGRAPHS = False
+
+def is_tensorrt_runtime_enabled():
+    """Check if the TensorRT graph runtime is present.
+    Returns
+    -------
+    ret: bool
+        True if present, False if not.
+    """
+    check_enabled = tvm.get_global_func("relay.op.is_tensorrt_runtime_enabled", True)
+    if check_enabled:
+        return check_enabled()
+    return False
+
+def get_tensorrt_version():
+    """Gets the version of TensorRT that TVM is built against.
+
+    Returns
+    -------
+    ret: Tuple[int]
+        TensorRT version as a tuple of major, minor, and patch number. If TVM
+        is not built with TensorRT, the value set by set_tensorrt_version() is returned instead.
+    """
+    linked_ver = tuple(tvm.get_global_func("relay.op.get_tensorrt_version")())
+    if len(linked_ver) == 3:
+        return linked_ver
+    return TENSORRT_VERSION
+
+def set_tensorrt_version(version):
+    """Override TensorRT version for annotation
+
+    Returns
+    -------
+    version: Tuple[int]
+        TensorRT version as a tuple of major, minor, and patch number. If TVM
+        is not built with TensorRT, an empty tuple is returned instead.
+    """
+    global TENSORRT_VERSION
+    TENSORRT_VERSION = version
+
+def get_tensorrt_use_implicit_batch_mode():
+    return USE_IMPLICIT_BATCH
+
+def set_tensorrt_use_implicit_batch_mode(use_implicit_batch):
+    global USE_IMPLICIT_BATCH
+    USE_IMPLICIT_BATCH = use_implicit_batch
+
+def get_tensorrt_remove_no_mac_subgraphs():
+    return REMOVE_NO_MAC_SUBGRAPHS
+
+def set_tensorrt_remove_no_mac_subgraphs(remove_no_mac_subgraphs):
+    global REMOVE_NO_MAC_SUBGRAPHS
+    REMOVE_NO_MAC_SUBGRAPHS = remove_no_mac_subgraphs
+
+def partition_for_tensorrt(mod, params=None, version=None, use_implicit_batch=True, remove_no_mac_subgraphs=False, max_workspace_size=1 << 30):
+    """Partition the graph greedily offloading supported
+    operators to TensorRT.
+    Parameters
+    ----------
+    mod : Module
+        The module to run passes on.
+    params : Optional[Dict[str, NDArray]]
+        Constant input parameters.
+    version : Optional[Tuple[int]]
+        TensorRT version to target as a tuple of (major, minor, patch). If not
+        specified, the linked TRT version is used if available.
+    use_implicit_batch : Optional[bool]
+        Use TensorRT implicit batch mode (default True). Explicit batch mode
+        allows more dynamism in the model but may be less performant.
+    remove_no_mac_subgraphs : Optional[bool]
+        If True, remove partitioned subgraphs that contain no multiply-accumulate
+        operations, since offloading them to TensorRT is unlikely to improve
+        performance.
+    max_workspace_size : Optional[int]
+        Maximum workspace size in bytes made available to TensorRT (default
+        1 GiB).
+
+    Returns
+    -------
+    ret : annotated and partitioned module.
+    """
+    if version:
+        assert isinstance(version, tuple) and len(version) == 3
+        set_tensorrt_version(version)
+    set_tensorrt_use_implicit_batch_mode(use_implicit_batch)
+    set_tensorrt_remove_no_mac_subgraphs(remove_no_mac_subgraphs)
+    if params:
+        mod['main'] = bind_params_by_name(mod['main'], params)
+
+    seq = tvm.transform.Sequential([transform.InferType(),
+                                    RemoveDropoutPass(),
+                                    transform.RemoveUnusedFunctions(),
+                                    transform.ConvertLayout({'nn.conv2d': ['NCHW', 'default'],
+                                                             'nn.conv3d': ['NCDHW', 'default']}),
+                                    transform.FoldConstant(),
+                                    transform.AnnotateTarget('tensorrt'),
+                                    transform.MergeCompilerRegions(),
+                                    transform.PartitionGraph(),
+                                    transform.InferType()])
+    with tvm.transform.PassContext(opt_level=3):
+        mod = seq(mod)
+    mod = prune_tensorrt_subgraphs(mod)
+    # Pass parameters to codegen
+    os.environ["TVM_TENSORRT_USE_IMPLICIT_BATCH"] = str(int(use_implicit_batch))
+    os.environ["TVM_TENSORRT_MAX_WORKSPACE_SIZE"] = str(int(max_workspace_size))
+    return mod
+
+
+def _register_external_op_helper(op_name, supported=True):
+    @tvm.ir.register_op_attr(op_name, "target.tensorrt")
+    def _func_wrapper(attrs, args):
+        if any([x.checked_type.dtype != "float32" for x in args]):
+            print("Only float32 inputs are supported for TensorRT.")
+            return False
+        return supported
+    return _func_wrapper
+
+
+def _register_external_op_helper_func(op_name, func):
+    @tvm.ir.register_op_attr(op_name, "target.tensorrt")
+    def _func_wrapper(attrs, args):
+        if any([x.checked_type.dtype != "float32" for x in args]):
+            print("Only float32 inputs are supported for TensorRT.")
+            return False
+        return func(attrs, args, op_name)
+    return _func_wrapper
+
+
+# Ops which are always supported
+_register_external_op_helper("nn.relu")
+_register_external_op_helper("sigmoid")
+_register_external_op_helper("tanh")
+_register_external_op_helper("subtract")
+_register_external_op_helper("multiply")
+_register_external_op_helper("divide")
+_register_external_op_helper("power")
+_register_external_op_helper("maximum")
+_register_external_op_helper("minimum")
+_register_external_op_helper("exp")
+_register_external_op_helper("log")
+_register_external_op_helper("sqrt")
+_register_external_op_helper("abs")
+_register_external_op_helper("negative")
+_register_external_op_helper("nn.batch_flatten")
+_register_external_op_helper("clip")
+
+@tvm.ir.register_op_attr("add", "target.tensorrt")
+def add_annotate_fn(attrs, args): # pylint: disable=unused-variable
+    if any([x.checked_type.dtype != "float32" for x in args]):
+        print("Only float32 inputs are supported for TensorRT.")
+        return False
+    if (isinstance(args[0], Constant) or isinstance(args[1], Constant)) and \
+            args[0].checked_type.shape[0] == args[0].checked_type.shape[0] and \
+            args[0].checked_type.shape[0] != 1 and \
+            (len(args[0].checked_type.shape) > 3 or len(args[1].checked_type.shape) > 3):
+        print("add: bug in TRT with adding batched constants.")
+        return False
+    return True
+
+@tvm.ir.register_op_attr("nn.batch_norm", "target.tensorrt")
+def batch_norm_annotate_fn(attrs, args): # pylint: disable=unused-variable
+    if any([x.checked_type.dtype != "float32" for x in args]):
+        print("Only float32 inputs are supported for TensorRT.")
+        return False
+    if int(attrs.axis) != 1 and int(attrs.axis) != 3:
+        print("nn.batch_norm: axis is {} but must be 1 or 3.".format(int(attrs.axis)))
+        return False
+    return True
+
+@tvm.ir.register_op_attr("nn.softmax", "target.tensorrt")
+def softmax_annotate_fn(attrs, args): # pylint: disable=unused-variable
+    if any([x.checked_type.dtype != "float32" for x in args]):
+        print("Only float32 inputs are supported for TensorRT.")
+        return False
+    if get_tensorrt_use_implicit_batch_mode() and int(attrs.axis) == 0:
+        print("nn.softmax: can't modify batch dimension.")
+        return False
+    return True
+
+@tvm.ir.register_op_attr("nn.conv2d", "target.tensorrt")
+def conv2d_annotate_fn(attrs, args): # pylint: disable=unused-variable
+    if any([x.checked_type.dtype != "float32" for x in args]):
+        print("Only float32 inputs are supported for TensorRT.")
+        return False
+    if attrs.data_layout != "NCHW":
+        print("nn.conv2d: data_layout is {} but must be NCHW.".format(attrs.data_layout))
+        return False
+    if attrs.kernel_layout != "OIHW":
+        print("nn.conv2d: kernel_layout is {} but must be OIHW.".format(attrs.kernel_layout))
+        return False
+    if attrs.out_layout and attrs.out_layout != "NCHW":
+        print("nn.conv2d: out_layout is {} but must be NCHW.".format(attrs.out_layout))
+        return False
+    return True
+
+@tvm.ir.register_op_attr("nn.dense", "target.tensorrt")
+def dense_annotate_fn(attrs, args): # pylint: disable=unused-variable
+    if any([x.checked_type.dtype != "float32" for x in args]):
+        print("Only float32 inputs are supported for TensorRT.")
+        return False
+    input_rank = len(args[0].checked_type.shape)
+    weight_rank = len(args[1].checked_type.shape)
+    if input_rank < 2 or input_rank > 4:
+        print("nn.dense: input has rank {} but must be 2, 3 or 4.".format(input_rank))
+        return False
+    if weight_rank != 2:
+        print("nn.dense: weight has rank {} but must be 2.".format(weight_rank))
+        return False
+    return True
+
+@tvm.ir.register_op_attr("nn.bias_add", "target.tensorrt")
+def bias_add_annotate_fn(attrs, args): # pylint: disable=unused-variable
+    if any([x.checked_type.dtype != "float32" for x in args]):
+        print("Only float32 inputs are supported for TensorRT.")
+        return False
+    input_rank = len(args[0].checked_type.shape)
+    if input_rank < 2 or input_rank > 4:
+        print("nn.bias_add: input rank is {} but must be 2, 3 or 4.".format(input_rank))
+        return False
+    return True
+
+@tvm.ir.register_op_attr("nn.max_pool2d", "target.tensorrt")
+def max_pool_2d_annotate_fn(attrs, args): # pylint: disable=unused-variable
+    if any([x.checked_type.dtype != "float32" for x in args]):
+        print("Only float32 inputs are supported for TensorRT.")
+        return False
+    if attrs.layout != "NCHW":
+        print("nn.max_pool2d: layout is {} but must be NCHW.".format(attrs.layout))
+        return False
+    if attrs.ceil_mode and get_tensorrt_version() < (5, 1, 5):
+        print("nn.avg_pool2d: ceil_mode=True requires TensorRT 5.1.5 or greater.")
+        return False
+    return True
+
+@tvm.ir.register_op_attr("nn.avg_pool2d", "target.tensorrt")
+def avg_pool_2d_annotate_fn(attrs, args): # pylint: disable=unused-variable
+    if any([x.checked_type.dtype != "float32" for x in args]):
+        print("Only float32 inputs are supported for TensorRT.")
+        return False
+    if attrs.layout != "NCHW":
+        print("nn.avg_pool2d: layout is {} but must be NCHW.".format(attrs.layout))
+        return False
+    if attrs.count_include_pad and len(attrs.padding) == 4:
+        print("nn.avg_pool2d: inclusive-counted blended or average "
+                "pooling is not supported in combination with asymmetric padding")
+        return False
+    if attrs.ceil_mode and get_tensorrt_version() < (5, 1, 5):
+        print("nn.avg_pool2d: ceil_mode=True requires TensorRT 5.1.5 or greater.")
+        return False
+    return True
+
+@tvm.ir.register_op_attr("nn.global_max_pool2d", "target.tensorrt")
+def global_max_pool_2d_annotate_fn(attrs, args): # pylint: disable=unused-variable
+    if any([x.checked_type.dtype != "float32" for x in args]):
+        print("Only float32 inputs are supported for TensorRT.")
+        return False
+    if attrs.layout != "NCHW":
+        print("nn.global_max_pool2d: layout is {} but must be NCHW.".format(attrs.layout))
+        return False
+    return True
+
+@tvm.ir.register_op_attr("nn.global_avg_pool2d", "target.tensorrt")
+def global_avg_pool_2d_annotate_fn(attrs, args): # pylint: disable=unused-variable
+    if any([x.checked_type.dtype != "float32" for x in args]):
+        print("Only float32 inputs are supported for TensorRT.")
+        return False
+    if attrs.layout != "NCHW":
+        print("nn.global_avg_pool2d: layout is {} but must be NCHW.".format(attrs.layout))
+        return False
+    return True
+
+@tvm.ir.register_op_attr("expand_dims", "target.tensorrt")
+def expand_dims_annotate_fn(attrs, args): # pylint: disable=unused-variable
+    if any([x.checked_type.dtype != "float32" for x in args]):
+        print("Only float32 inputs are supported for TensorRT.")
+        return False
+    if get_tensorrt_use_implicit_batch_mode() and int(attrs.axis) == 0:
+        print("expand_dims: can't modify batch dimension.")
+        return False
+    return True
+
+@tvm.ir.register_op_attr("squeeze", "target.tensorrt")
+def squeeze_annotate_fn(attrs, args): # pylint: disable=unused-variable
+    if any([x.checked_type.dtype != "float32" for x in args]):
+        print("Only float32 inputs are supported for TensorRT.")
+        return False
+    if not attrs.axis:
+        print("squeeze: must explicitly set axis.")
+        return False
+    if get_tensorrt_use_implicit_batch_mode() and any([axis == 0 for axis in map(int, attrs.axis)]):
+        print("squeeze: can't modify batch dimension.")
+        return False
+    return True
+
+@tvm.ir.register_op_attr("concatenate", "target.tensorrt")
+def concatenate_annotate_fn(attrs, args): # pylint: disable=unused-variable
+    if any([x.dtype != "float32" for x in args[0].checked_type.fields]):
+        print("Only float32 inputs are supported for TensorRT.")
+        return False
+    if not get_tensorrt_use_implicit_batch_mode():
+        return True
+    if int(attrs.axis) == 0:
+        print("concatenate: can't modify batch dimension.")
+        return False
+    if isinstance(args[0], Tuple):
+        for tuple_input in args[0].fields:
+            if isinstance(tuple_input, Constant):
+                print("concatenate: can't concatenate tensors with constants.")
+                return False
+    return True
+
+@tvm.ir.register_op_attr("nn.conv2d_transpose", "target.tensorrt")
+def conv2d_transpose_annotate_fn(attrs, args): # pylint: disable=unused-variable
+    if any([x.checked_type.dtype != "float32" for x in args]):
+        print("Only float32 inputs are supported for TensorRT.")
+        return False
+    if attrs.data_layout != "NCHW":
+        print("nn.conv2d_transpose: data_layout is {} but must be NCHW.".format(
+            attrs.data_layout))
+        return False
+    if attrs.kernel_layout != "OIHW":
+        print("nn.conv2d_transpose: kernel_layout is {} but must be OIHW.".format(
+            attrs.kernel_layout))
+        return False
+    if attrs.out_layout and attrs.out_layout != "NCHW":
+        print("nn.conv2d_transpose: out_layout is {} but must be NCHW.".format(
+            attrs.out_layout))
+        return False
+    if attrs.dilation and any([rate != 1 for rate in map(int, attrs.dilation)]):
+        print("nn.conv2d_transpose: dilation rate must be 1.")
+        return False
+    return True
+
+@tvm.ir.register_op_attr("transpose", "target.tensorrt")
+def transpose_annotate_fn(attrs, args): # pylint: disable=unused-variable
+    if any([x.checked_type.dtype != "float32" for x in args]):
+        print("Only float32 inputs are supported for TensorRT.")
+        return False
+    if get_tensorrt_use_implicit_batch_mode() and int(attrs.axes[0]) != 0:
+        print("transpose: can't modify batch dimension.")
+        return False
+    return True
+
+@tvm.ir.register_op_attr("layout_transform", "target.tensorrt")
+def resize_annotate_fn(attrs, args): # pylint: disable=unused-variable
+    if any([x.checked_type.dtype != "float32" for x in args]):
+        print("Only float32 inputs are supported for TensorRT.")
+        return False
+    if (attrs.src_layout, attrs.dst_layout) not in [("NCHW", "NHWC"), ("NHWC", "NCHW"), ("NDHWC", "NCDHW"), ("NCDHW", "NDHWC")]:
+        print("layout_transform: {} to {} is not supported.".format(attrs.src_layout, attrs.dst_layout))
+        return False
+    return True
+
+@tvm.ir.register_op_attr("reshape", "target.tensorrt")
+def reshape_annotate_fn(attrs, args): # pylint: disable=unused-variable
+    if args[0].checked_type.dtype != "float32":
+        print("Only float32 inputs are supported for TensorRT.")
+        return False
+    if any([x < -1 for x in map(int, attrs.newshape)]):
+        print("reshape: new shape dims must be explicit.")
+        return False
+    if get_tensorrt_use_implicit_batch_mode():
+        shape = list(map(int, args[0].checked_type.shape))
+        new_shape = list(map(int, attrs.newshape))
+        if len(new_shape) == 0 or len(shape) == 0:
+            print("reshape: Can't reshape to or from scalar.")
+            return False
+        # TRT cannot modify batch dimension.
+        original_volume = np.prod(shape)
+        # First, resolve 0.
+        for i, value in enumerate(new_shape):
+            if value == 0:
+                new_shape[i] = shape[i]
+        # Resolve -1.
+        for i, value in enumerate(new_shape):
+            if value == -1:
+                new_shape[i] = original_volume // np.prod([x for x in new_shape if x != -1])
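+        # Worked example: shape=(4, 2, 3), newshape=(0, -1). The 0 resolves to
+        # 4 and the -1 to 24 // 4 = 6, giving new_shape=(4, 6); the batch
+        # dimension (4) is unchanged, so this reshape is allowed.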
+        if shape[0] != new_shape[0]:
+            print("reshape: can't modify batch dimension.")
+            return False
+    return True
+
+@tvm.ir.register_op_attr("nn.pad", "target.tensorrt")
+def pad_annotate_fn(attrs, args): # pylint: disable=unused-variable
+    if any([x.checked_type.dtype != "float32" for x in args]):
+        print("Only float32 inputs are supported for TensorRT.")
+        return False
+    if attrs.pad_mode != "constant":
+        print("nn.pad: pad mode is {} but must be constant.".format(attrs.pad_mode))
+        return False
+    if float(attrs.pad_value) != 0.0:
+        print("nn.pad: pad value is {} but must be 0.0.".format(float(attrs.pad_value)))
+        return False
+    if any([x != 0 for x in attrs.pad_width[0]]) or any([x != 0 for x in attrs.pad_width[1]]):
+        print("nn.pad: can't pad batch or channel dimensions.")
+        return False
+    if len(attrs.pad_width) == 5 and any([x != 0 for x in attrs.pad_width[2]]):
+        print("nn.pad: can only pad last two dimensions for 5D inputs.")
+    return True
+
+def reduce_annotate_fn(attrs, args, op_name):
+    if not attrs.axis or len(attrs.axis) == 0:
+        print("{}: cannot reduce to scalar.".format(op_name))
+        return False
+    if attrs.exclude:
+        print("{}: exclude not supported.".format(op_name))
+        return False
+    if get_tensorrt_use_implicit_batch_mode() and any([x == 0 for x in map(int, attrs.axis)]):
+        print("{}: can't modify batch dimension.".format(op_name))
+        return False
+    return True
+
+_register_external_op_helper_func("sum", reduce_annotate_fn)
+_register_external_op_helper_func("prod", reduce_annotate_fn)
+_register_external_op_helper_func("max", reduce_annotate_fn)
+_register_external_op_helper_func("min", reduce_annotate_fn)
+_register_external_op_helper_func("mean", reduce_annotate_fn)
+
+def trt_5_1_5_annotate_fn(attrs, args, op_name):
+    if get_tensorrt_version() < (5, 1, 5):
+        print("{}: requires TensorRT version 5.1.5 or higher.".format(op_name))
+        return False
+    return True
+
+_register_external_op_helper_func("nn.leaky_relu", trt_5_1_5_annotate_fn)
+_register_external_op_helper_func("sin", trt_5_1_5_annotate_fn)
+_register_external_op_helper_func("cos", trt_5_1_5_annotate_fn)
+_register_external_op_helper_func("atan", trt_5_1_5_annotate_fn)
+_register_external_op_helper_func("ceil", trt_5_1_5_annotate_fn)
+
+@tvm.ir.register_op_attr("strided_slice", "target.tensorrt")
+def strided_slice_annotate_fn(attrs, args): # pylint: disable=unused-variable
+    if args[0].checked_type.dtype != "float32":
+        print("Only float32 inputs are supported for TensorRT.")
+        return False
+    if get_tensorrt_version() < (5, 1, 5):
+        print("strided_slice: requires TensorRT version 5.1.5 or higher.")
+        return False
+    if args[0].checked_type.dtype != "float32":
+        print("strided_slice: only fp32 inputs are supported.")
+        return False

Review comment:
       This clause is repeated here; maybe it can be removed?
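       For reference, with the duplicated dtype check removed, the opening checks
       would collapse to (sketch):
   
   ```
   if args[0].checked_type.dtype != "float32":
       print("strided_slice: only float32 inputs are supported.")
       return False
   if get_tensorrt_version() < (5, 1, 5):
       print("strided_slice: requires TensorRT version 5.1.5 or higher.")
       return False
   ```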

##########
File path: python/tvm/relay/op/contrib/tensorrt.py
##########
@@ -0,0 +1,675 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, unused-argument
+"""TensorRT supported operators."""
+import tvm
+from tvm import relay
+from tvm.relay import transform
+from tvm.relay.build_module import bind_params_by_name
+from tvm.relay.expr import Call, Constant, Tuple, GlobalVar
+from tvm.relay.expr_functor import ExprMutator
+
+import os
+import numpy as np
+
+# Version to use for annotation when there is no linked TRT.
+TENSORRT_VERSION = (6, 0, 1)
+USE_IMPLICIT_BATCH = True
+REMOVE_NO_MAC_SUBGRAPHS = False
+
+def is_tensorrt_runtime_enabled():
+    """Check if the TensorRT graph runtime is present.
+    Returns
+    -------
+    ret: bool
+        True if present, False if not.
+    """
+    check_enabled = tvm.get_global_func("relay.op.is_tensorrt_runtime_enabled", True)
+    if check_enabled:
+        return check_enabled()
+    return False
+
+def get_tensorrt_version():
+    """Gets the version of TensorRT that TVM is built against.
+
+    Returns
+    -------
+    ret: Tuple[int]
+        TensorRT version as a tuple of major, minor, and patch number. If TVM
+        is not built with TensorRT, the value set by set_tensorrt_version() is returned instead.
+    """
+    linked_ver = tuple(tvm.get_global_func("relay.op.get_tensorrt_version")())
+    if len(linked_ver) == 3:
+        return linked_ver
+    return TENSORRT_VERSION
+
+def set_tensorrt_version(version):
+    """Override TensorRT version for annotation
+
+    Returns
+    -------
+    version: Tuple[int]
+        TensorRT version as a tuple of major, minor, and patch number. If TVM
+        is not built with TensorRT, an empty tuple is returned instead.
+    """
+    global TENSORRT_VERSION
+    TENSORRT_VERSION = version
+
+def get_tensorrt_use_implicit_batch_mode():
+    return USE_IMPLICIT_BATCH
+
+def set_tensorrt_use_implicit_batch_mode(use_implicit_batch):
+    global USE_IMPLICIT_BATCH
+    USE_IMPLICIT_BATCH = use_implicit_batch
+
+def get_tensorrt_remove_no_mac_subgraphs():
+    return REMOVE_NO_MAC_SUBGRAPHS
+
+def set_tensorrt_remove_no_mac_subgraphs(remove_no_mac_subgraphs):
+    global REMOVE_NO_MAC_SUBGRAPHS
+    REMOVE_NO_MAC_SUBGRAPHS = remove_no_mac_subgraphs
+
+def partition_for_tensorrt(mod, params=None, version=None, use_implicit_batch=True, remove_no_mac_subgraphs=False, max_workspace_size=1 << 30):
+    """Partition the graph greedily offloading supported
+    operators to TensorRT.
+    Parameters
+    ----------
+    mod : Module
+        The module to run passes on.
+    params : Optional[Dict[str, NDArray]]
+        Constant input parameters.
+    version : Optional[Tuple[int]]
+        TensorRT version to target as a tuple of (major, minor, patch). If not
+        specified, the linked TRT version is used if available.
+    use_implicit_batch : Optional[bool]
+        Use TensorRT implicit batch mode (default True). Explicit batch mode
+        allows more dynamism in the model but may be less performant.
+    remove_no_mac_subgraphs : Optional[bool]
+        If True, remove partitioned subgraphs that contain no multiply-accumulate
+        operations, since offloading them to TensorRT is unlikely to improve
+        performance.
+    max_workspace_size : Optional[int]
+        Maximum workspace size in bytes made available to TensorRT (default
+        1 GiB).
+
+    Returns
+    -------
+    ret : annotated and partitioned module.
+    """
+    if version:
+        assert isinstance(version, tuple) and len(version) == 3
+        set_tensorrt_version(version)
+    set_tensorrt_use_implicit_batch_mode(use_implicit_batch)
+    set_tensorrt_remove_no_mac_subgraphs(remove_no_mac_subgraphs)
+    if params:
+        mod['main'] = bind_params_by_name(mod['main'], params)
+
+    seq = tvm.transform.Sequential([transform.InferType(),
+                                    RemoveDropoutPass(),
+                                    transform.RemoveUnusedFunctions(),
+                                    transform.ConvertLayout({'nn.conv2d': ['NCHW', 'default'],
+                                                             'nn.conv3d': ['NCDHW', 'default']}),
+                                    transform.FoldConstant(),
+                                    transform.AnnotateTarget('tensorrt'),
+                                    transform.MergeCompilerRegions(),
+                                    transform.PartitionGraph(),
+                                    transform.InferType()])
+    with tvm.transform.PassContext(opt_level=3):
+        mod = seq(mod)
+    mod = prune_tensorrt_subgraphs(mod)
+    # Pass parameters to codegen
+    os.environ["TVM_TENSORRT_USE_IMPLICIT_BATCH"] = str(int(use_implicit_batch))
+    os.environ["TVM_TENSORRT_MAX_WORKSPACE_SIZE"] = str(int(max_workspace_size))
+    return mod
+
+
+def _register_external_op_helper(op_name, supported=True):
+    @tvm.ir.register_op_attr(op_name, "target.tensorrt")
+    def _func_wrapper(attrs, args):
+        if any([x.checked_type.dtype != "float32" for x in args]):
+            print("Only float32 inputs are supported for TensorRT.")
+            return False
+        return supported
+    return _func_wrapper
+
+
+def _register_external_op_helper_func(op_name, func):
+    @tvm.ir.register_op_attr(op_name, "target.tensorrt")
+    def _func_wrapper(attrs, args):
+        if any([x.checked_type.dtype != "float32" for x in args]):
+            print("Only float32 inputs are supported for TensorRT.")
+            return False
+        return func(attrs, args, op_name)
+    return _func_wrapper
+
+
+# Ops which are always supported
+_register_external_op_helper("nn.relu")
+_register_external_op_helper("sigmoid")
+_register_external_op_helper("tanh")
+_register_external_op_helper("subtract")
+_register_external_op_helper("multiply")
+_register_external_op_helper("divide")
+_register_external_op_helper("power")
+_register_external_op_helper("maximum")
+_register_external_op_helper("minimum")
+_register_external_op_helper("exp")
+_register_external_op_helper("log")
+_register_external_op_helper("sqrt")
+_register_external_op_helper("abs")
+_register_external_op_helper("negative")
+_register_external_op_helper("nn.batch_flatten")
+_register_external_op_helper("clip")
+
+@tvm.ir.register_op_attr("add", "target.tensorrt")
+def add_annotate_fn(attrs, args): # pylint: disable=unused-variable
+    if any([x.checked_type.dtype != "float32" for x in args]):
+        print("Only float32 inputs are supported for TensorRT.")
+        return False
+    if (isinstance(args[0], Constant) or isinstance(args[1], Constant)) and \
+            args[0].checked_type.shape[0] == args[0].checked_type.shape[0] and \
+            args[0].checked_type.shape[0] != 1 and \
+            (len(args[0].checked_type.shape) > 3 or len(args[1].checked_type.shape) > 3):
+        print("add: bug in TRT with adding batched constants.")
+        return False
+    return True
+
+@tvm.ir.register_op_attr("nn.batch_norm", "target.tensorrt")
+def batch_norm_annotate_fn(attrs, args): # pylint: disable=unused-variable
+    if any([x.checked_type.dtype != "float32" for x in args]):
+        print("Only float32 inputs are supported for TensorRT.")
+        return False
+    if int(attrs.axis) != 1 and int(attrs.axis) != 3:
+        print("nn.batch_norm: axis is {} but must be 1 or 3.".format(int(attrs.axis)))
+        return False
+    return True
+
+@tvm.ir.register_op_attr("nn.softmax", "target.tensorrt")
+def softmax_annotate_fn(attrs, args): # pylint: disable=unused-variable
+    if any([x.checked_type.dtype != "float32" for x in args]):
+        print("Only float32 inputs are supported for TensorRT.")
+        return False
+    if get_tensorrt_use_implicit_batch_mode() and int(attrs.axis) == 0:
+        print("nn.softmax: can't modify batch dimension.")
+        return False
+    return True
+
+@tvm.ir.register_op_attr("nn.conv2d", "target.tensorrt")
+def conv2d_annotate_fn(attrs, args): # pylint: disable=unused-variable
+    if any([x.checked_type.dtype != "float32" for x in args]):
+        print("Only float32 inputs are supported for TensorRT.")
+        return False
+    if attrs.data_layout != "NCHW":
+        print("nn.conv2d: data_layout is {} but must be NCHW.".format(attrs.data_layout))
+        return False
+    if attrs.kernel_layout != "OIHW":
+        print("nn.conv2d: kernel_layout is {} but must be OIHW.".format(attrs.kernel_layout))
+        return False
+    if attrs.out_layout and attrs.out_layout != "NCHW":
+        print("nn.conv2d: out_layout is {} but must be NCHW.".format(attrs.out_layout))
+        return False
+    return True
+
+@tvm.ir.register_op_attr("nn.dense", "target.tensorrt")
+def dense_annotate_fn(attrs, args): # pylint: disable=unused-variable
+    if any([x.checked_type.dtype != "float32" for x in args]):
+        print("Only float32 inputs are supported for TensorRT.")
+        return False
+    input_rank = len(args[0].checked_type.shape)
+    weight_rank = len(args[1].checked_type.shape)
+    if input_rank < 2 or input_rank > 4:

Review comment:
       Suggestion: This could be simplified to `if input_rank not in (2, 3, 4):`.
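       Applied here, that would read:
   
   ```
   if input_rank not in (2, 3, 4):
       print("nn.dense: input has rank {} but must be 2, 3 or 4.".format(input_rank))
       return False
   ```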

##########
File path: python/tvm/relay/op/contrib/tensorrt.py
##########
@@ -0,0 +1,675 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, unused-argument
+"""TensorRT supported operators."""
+import tvm
+from tvm import relay
+from tvm.relay import transform
+from tvm.relay.build_module import bind_params_by_name
+from tvm.relay.expr import Call, Constant, Tuple, GlobalVar
+from tvm.relay.expr_functor import ExprMutator
+
+import os
+import numpy as np
+
+# Version to use for annotation when there is no linked TRT.
+TENSORRT_VERSION = (6, 0, 1)
+USE_IMPLICIT_BATCH = True
+REMOVE_NO_MAC_SUBGRAPHS = False
+
+def is_tensorrt_runtime_enabled():
+    """Check if the TensorRT graph runtime is present.
+    Returns
+    -------
+    ret: bool
+        True if present, False if not.
+    """
+    check_enabled = tvm.get_global_func("relay.op.is_tensorrt_runtime_enabled", True)
+    if check_enabled:
+        return check_enabled()
+    return False
+
+def get_tensorrt_version():
+    """Gets the version of TensorRT that TVM is built against.
+
+    Returns
+    -------
+    ret: Tuple[int]
+        TensorRT version as a tuple of major, minor, and patch number. If TVM
+        is not built with TensorRT, the value set by set_tensorrt_version() is returned instead.
+    """
+    linked_ver = tuple(tvm.get_global_func("relay.op.get_tensorrt_version")())
+    if len(linked_ver) == 3:
+        return linked_ver
+    return TENSORRT_VERSION
+
+def set_tensorrt_version(version):
+    """Override TensorRT version for annotation
+
+    Returns
+    -------
+    version: Tuple[int]
+        TensorRT version as a tuple of major, minor, and patch number. If TVM
+        is not built with TensorRT, an empty tuple is returned instead.
+    """
+    global TENSORRT_VERSION
+    TENSORRT_VERSION = version
+
+def get_tensorrt_use_implicit_batch_mode():
+    return USE_IMPLICIT_BATCH
+
+def set_tensorrt_use_implicit_batch_mode(use_implicit_batch):
+    global USE_IMPLICIT_BATCH
+    USE_IMPLICIT_BATCH = use_implicit_batch
+
+def get_tensorrt_remove_no_mac_subgraphs():
+    return REMOVE_NO_MAC_SUBGRAPHS
+
+def set_tensorrt_remove_no_mac_subgraphs(remove_no_mac_subgraphs):
+    global REMOVE_NO_MAC_SUBGRAPHS
+    REMOVE_NO_MAC_SUBGRAPHS = remove_no_mac_subgraphs
+
+def partition_for_tensorrt(mod, params=None, version=None, use_implicit_batch=True, remove_no_mac_subgraphs=False, max_workspace_size=1 << 30):
+    """Partition the graph greedily offloading supported
+    operators to TensorRT.
+    Parameters
+    ----------
+    mod : Module
+        The module to run passes on.
+    params : Optional[Dict[str, NDArray]]
+        Constant input parameters.
+    version : Optional[Tuple[int]]
+        TensorRT version to target as a tuple of (major, minor, patch). If not
+        specified, the linked TRT version is used if available.
+    use_implicit_batch : Optional[bool]
+        Use TensorRT implicit batch mode (default True). Explicit batch mode
+        allows more dynamism in the model but may be less performant.
+    remove_no_mac_subgraphs : Optional[bool]
+        If True, remove partitioned subgraphs that contain no multiply-accumulate
+        operations, since offloading them to TensorRT is unlikely to improve
+        performance.
+    max_workspace_size : Optional[int]
+        Maximum workspace size in bytes made available to TensorRT (default
+        1 GiB).
+
+    Returns
+    -------
+    ret : annotated and partitioned module.
+    """
+    if version:
+        assert isinstance(version, tuple) and len(version) == 3
+        set_tensorrt_version(version)
+    set_tensorrt_use_implicit_batch_mode(use_implicit_batch)
+    set_tensorrt_remove_no_mac_subgraphs(remove_no_mac_subgraphs)
+    if params:
+        mod['main'] = bind_params_by_name(mod['main'], params)
+
+    seq = tvm.transform.Sequential([transform.InferType(),
+                                    RemoveDropoutPass(),
+                                    transform.RemoveUnusedFunctions(),
+                                    transform.ConvertLayout({'nn.conv2d': ['NCHW', 'default'],
+                                                             'nn.conv3d': ['NCDHW', 'default']}),
+                                    transform.FoldConstant(),
+                                    transform.AnnotateTarget('tensorrt'),
+                                    transform.MergeCompilerRegions(),
+                                    transform.PartitionGraph(),
+                                    transform.InferType()])
+    with tvm.transform.PassContext(opt_level=3):
+        mod = seq(mod)
+    mod = prune_tensorrt_subgraphs(mod)
+    # Pass parameters to codegen
+    os.environ["TVM_TENSORRT_USE_IMPLICIT_BATCH"] = str(int(use_implicit_batch))
+    os.environ["TVM_TENSORRT_MAX_WORKSPACE_SIZE"] = str(int(max_workspace_size))
+    return mod
+
+
+def _register_external_op_helper(op_name, supported=True):
+    @tvm.ir.register_op_attr(op_name, "target.tensorrt")
+    def _func_wrapper(attrs, args):
+        if any([x.checked_type.dtype != "float32" for x in args]):
+            print("Only float32 inputs are supported for TensorRT.")
+            return False
+        return supported
+    return _func_wrapper
+
+
+def _register_external_op_helper_func(op_name, func):
+    @tvm.ir.register_op_attr(op_name, "target.tensorrt")
+    def _func_wrapper(attrs, args):
+        if any([x.checked_type.dtype != "float32" for x in args]):
+            print("Only float32 inputs are supported for TensorRT.")
+            return False
+        return func(attrs, args, op_name)
+    return _func_wrapper
+
+
+# Ops which are always supported
+_register_external_op_helper("nn.relu")
+_register_external_op_helper("sigmoid")
+_register_external_op_helper("tanh")
+_register_external_op_helper("subtract")
+_register_external_op_helper("multiply")
+_register_external_op_helper("divide")
+_register_external_op_helper("power")
+_register_external_op_helper("maximum")
+_register_external_op_helper("minimum")
+_register_external_op_helper("exp")
+_register_external_op_helper("log")
+_register_external_op_helper("sqrt")
+_register_external_op_helper("abs")
+_register_external_op_helper("negative")
+_register_external_op_helper("nn.batch_flatten")
+_register_external_op_helper("clip")
+
+@tvm.ir.register_op_attr("add", "target.tensorrt")
+def add_annotate_fn(attrs, args): # pylint: disable=unused-variable
+    if any([x.checked_type.dtype != "float32" for x in args]):
+        print("Only float32 inputs are supported for TensorRT.")
+        return False
+    if (isinstance(args[0], Constant) or isinstance(args[1], Constant)) and \
+            args[0].checked_type.shape[0] == args[0].checked_type.shape[0] and \
+            args[0].checked_type.shape[0] != 1 and \
+            (len(args[0].checked_type.shape) > 3 or len(args[1].checked_type.shape) > 3):
+        print("add: bug in TRT with adding batched constants.")
+        return False
+    return True
+
+@tvm.ir.register_op_attr("nn.batch_norm", "target.tensorrt")
+def batch_norm_annotate_fn(attrs, args): # pylint: disable=unused-variable
+    if any([x.checked_type.dtype != "float32" for x in args]):
+        print("Only float32 inputs are supported for TensorRT.")
+        return False
+    if int(attrs.axis) != 1 and int(attrs.axis) != 3:
+        print("nn.batch_norm: axis is {} but must be 1 or 3.".format(int(attrs.axis)))
+        return False
+    return True
+
+@tvm.ir.register_op_attr("nn.softmax", "target.tensorrt")
+def softmax_annotate_fn(attrs, args): # pylint: disable=unused-variable
+    if any([x.checked_type.dtype != "float32" for x in args]):
+        print("Only float32 inputs are supported for TensorRT.")
+        return False
+    if get_tensorrt_use_implicit_batch_mode() and int(attrs.axis) == 0:
+        print("nn.softmax: can't modify batch dimension.")
+        return False
+    return True
+
+@tvm.ir.register_op_attr("nn.conv2d", "target.tensorrt")
+def conv2d_annotate_fn(attrs, args): # pylint: disable=unused-variable
+    if any([x.checked_type.dtype != "float32" for x in args]):
+        print("Only float32 inputs are supported for TensorRT.")
+        return False
+    if attrs.data_layout != "NCHW":
+        print("nn.conv2d: data_layout is {} but must be NCHW.".format(attrs.data_layout))
+        return False
+    if attrs.kernel_layout != "OIHW":
+        print("nn.conv2d: kernel_layout is {} but must be OIHW.".format(attrs.kernel_layout))
+        return False
+    if attrs.out_layout and attrs.out_layout != "NCHW":
+        print("nn.conv2d: out_layout is {} but must be NCHW.".format(attrs.out_layout))
+        return False
+    return True
+
+@tvm.ir.register_op_attr("nn.dense", "target.tensorrt")
+def dense_annotate_fn(attrs, args): # pylint: disable=unused-variable
+    if any([x.checked_type.dtype != "float32" for x in args]):
+        print("Only float32 inputs are supported for TensorRT.")
+        return False
+    input_rank = len(args[0].checked_type.shape)
+    weight_rank = len(args[1].checked_type.shape)
+    if input_rank < 2 or input_rank > 4:
+        print("nn.dense: input has rank {} but must be 2, 3 or 4.".format(input_rank))
+        return False
+    if weight_rank != 2:
+        print("nn.dense: weight has rank {} but must be 2.".format(weight_rank))
+        return False
+    return True
+
+@tvm.ir.register_op_attr("nn.bias_add", "target.tensorrt")
+def bias_add_annotate_fn(attrs, args): # pylint: disable=unused-variable
+    if any([x.checked_type.dtype != "float32" for x in args]):
+        print("Only float32 inputs are supported for TensorRT.")
+        return False
+    input_rank = len(args[0].checked_type.shape)
+    if input_rank < 2 or input_rank > 4:

Review comment:
       Suggestion: This could be simplified to `if input_rank not in (2, 3, 4):`.

##########
File path: python/tvm/relay/op/contrib/tensorrt.py
##########
@@ -0,0 +1,675 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, unused-argument
+"""TensorRT supported operators."""
+import tvm
+from tvm import relay
+from tvm.relay import transform
+from tvm.relay.build_module import bind_params_by_name
+from tvm.relay.expr import Call, Constant, Tuple, GlobalVar
+from tvm.relay.expr_functor import ExprMutator
+
+import os
+import numpy as np
+
+# Version to use for annotation when there is no linked TRT.
+TENSORRT_VERSION = (6, 0, 1)
+USE_IMPLICIT_BATCH = True
+REMOVE_NO_MAC_SUBGRAPHS = False
+
+def is_tensorrt_runtime_enabled():
+    """Check if the TensorRT graph runtime is present.
+    Returns
+    -------
+    ret: bool
+        True if present, False if not.
+    """
+    check_enabled = tvm.get_global_func("relay.op.is_tensorrt_runtime_enabled", True)
+    if check_enabled:
+        return check_enabled()
+    return False
+
+def get_tensorrt_version():
+    """Gets the version of TensorRT that TVM is built against.
+
+    Returns
+    -------
+    ret: Tuple[int]
+        TensorRT version as a tuple of major, minor, and patch number. If TVM
+        is not built with TensorRT, the value set by set_tensorrt_version() is returned instead.
+    """
+    linked_ver = tuple(tvm.get_global_func("relay.op.get_tensorrt_version")())
+    if len(linked_ver) == 3:
+        return linked_ver
+    return TENSORRT_VERSION
+
+def set_tensorrt_version(version):
+    """Override TensorRT version for annotation
+
+    Returns
+    -------
+    version: Tuple[int]
+        TensorRT version as a tuple of major, minor, and patch number. If TVM
+        is not built with TensorRT, an empty tuple is returned instead.
+    """
+    global TENSORRT_VERSION
+    TENSORRT_VERSION = version
+
+def get_tensorrt_use_implicit_batch_mode():
+    return USE_IMPLICIT_BATCH
+
+def set_tensorrt_use_implicit_batch_mode(use_implicit_batch):
+    global USE_IMPLICIT_BATCH
+    USE_IMPLICIT_BATCH = use_implicit_batch
+
+def get_tensorrt_remove_no_mac_subgraphs():
+    return REMOVE_NO_MAC_SUBGRAPHS
+
+def set_tensorrt_remove_no_mac_subgraphs(remove_no_mac_subgraphs):
+    global REMOVE_NO_MAC_SUBGRAPHS
+    REMOVE_NO_MAC_SUBGRAPHS = remove_no_mac_subgraphs
+
+def partition_for_tensorrt(mod, params=None, version=None, use_implicit_batch=True, remove_no_mac_subgraphs=False, max_workspace_size=1 << 30):
+    """Partition the graph greedily offloading supported
+    operators to TensorRT.
+    Parameters
+    ----------
+    mod : Module
+        The module to run passes on.
+    params : Optional[Dict[str, NDArray]]
+        Constant input parameters.
+    version : Optional[Tuple(int)]
+        TensorRT version to target as tuple of (major, minor, patch). Will use linked TRT version if available if version is not specified.
+    use_implicit_batch : Optional[bool]
+
+    remove_no_mac_subgraphs : Optional[bool]
+
+    Returns
+    -------
+    ret : annotated and partitioned module.
+    """
+    if version:
+        assert isinstance(version, tuple) and len(version) == 3
+        set_tensorrt_version(version)
+    set_tensorrt_use_implicit_batch_mode(use_implicit_batch)
+    set_tensorrt_remove_no_mac_subgraphs(remove_no_mac_subgraphs)
+    if params:
+        mod['main'] = bind_params_by_name(mod['main'], params)
+
+    seq = tvm.transform.Sequential([transform.InferType(),
+                                    RemoveDropoutPass(),
+                                    transform.RemoveUnusedFunctions(),
+                                    transform.ConvertLayout({'nn.conv2d': ['NCHW', 'default'],
+                                                             'nn.conv3d': ['NCDHW', 'default']}),
+                                    transform.FoldConstant(),
+                                    transform.AnnotateTarget('tensorrt'),
+                                    transform.MergeCompilerRegions(),
+                                    transform.PartitionGraph(),
+                                    transform.InferType()])
+    with tvm.transform.PassContext(opt_level=3):
+        mod = seq(mod)
+    mod = prune_tensorrt_subgraphs(mod)
+    # Pass parameters to codegen
+    os.environ["TVM_TENSORRT_USE_IMPLICIT_BATCH"] = str(int(use_implicit_batch))
+    os.environ["TVM_TENSORRT_MAX_WORKSPACE_SIZE"] = str(int(max_workspace_size))
+    return mod
+
+
+def _register_external_op_helper(op_name, supported=True):
+    @tvm.ir.register_op_attr(op_name, "target.tensorrt")
+    def _func_wrapper(attrs, args):
+        if any([x.checked_type.dtype != "float32" for x in args]):
+            print("Only float32 inputs are supported for TensorRT.")
+            return False
+        return supported
+    return _func_wrapper
+
+
+def _register_external_op_helper_func(op_name, func):
+    @tvm.ir.register_op_attr(op_name, "target.tensorrt")
+    def _func_wrapper(attrs, args):
+        if any([x.checked_type.dtype != "float32" for x in args]):
+            print("Only float32 inputs are supported for TensorRT.")
+            return False
+        return func(attrs, args, op_name)
+    return _func_wrapper
+
+
+# Ops which are always supported
+_register_external_op_helper("nn.relu")
+_register_external_op_helper("sigmoid")
+_register_external_op_helper("tanh")
+_register_external_op_helper("subtract")
+_register_external_op_helper("multiply")
+_register_external_op_helper("divide")
+_register_external_op_helper("power")
+_register_external_op_helper("maximum")
+_register_external_op_helper("minimum")
+_register_external_op_helper("exp")
+_register_external_op_helper("log")
+_register_external_op_helper("sqrt")
+_register_external_op_helper("abs")
+_register_external_op_helper("negative")
+_register_external_op_helper("nn.batch_flatten")
+_register_external_op_helper("clip")
+
+@tvm.ir.register_op_attr("add", "target.tensorrt")
+def add_annotate_fn(attrs, args): # pylint: disable=unused-variable
+    if any([x.checked_type.dtype != "float32" for x in args]):
+        print("Only float32 inputs are supported for TensorRT.")
+        return False
+    if (isinstance(args[0], Constant) or isinstance(args[1], Constant)) and \
+            args[0].checked_type.shape[0] == args[0].checked_type.shape[0] and \
+            args[0].checked_type.shape[0] != 1 and \
+            (len(args[0].checked_type.shape) > 3 or len(args[1].checked_type.shape) > 3):
+        print("add: bug in TRT with adding batched constants.")

Review comment:
       I think it would be better to replace the use of `print` with the Python `logging` infrastructure.
   
   This applies to all uses of `print` in this module and in other sources.
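
   A minimal sketch of what that could look like for the helper above (the logger name is an assumption, not something this module defines today):

   ```python
   import logging

   import tvm

   # Hypothetical module-level logger; applications can then route or
   # silence these messages by configuring the "TensorRT" logger.
   logger = logging.getLogger("TensorRT")

   def _register_external_op_helper(op_name, supported=True):
       @tvm.ir.register_op_attr(op_name, "target.tensorrt")
       def _func_wrapper(attrs, args):
           if any(x.checked_type.dtype != "float32" for x in args):
               # logger.info instead of print: filterable, timestamped,
               # and silent by default unless the app enables INFO.
               logger.info("Only float32 inputs are supported for TensorRT.")
               return False
           return supported
       return _func_wrapper
   ```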

##########
File path: python/tvm/relay/op/contrib/tensorrt.py
##########
@@ -0,0 +1,675 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, unused-argument
+"""TensorRT supported operators."""
+import tvm
+from tvm import relay
+from tvm.relay import transform
+from tvm.relay.build_module import bind_params_by_name
+from tvm.relay.expr import Call, Constant, Tuple, GlobalVar
+from tvm.relay.expr_functor import ExprMutator
+
+import os
+import numpy as np
+
+# Version to use for annotation when there is no linked TRT.
+TENSORRT_VERSION = (6, 0, 1)
+USE_IMPLICIT_BATCH = True
+REMOVE_NO_MAC_SUBGRAPHS = False
+
+def is_tensorrt_runtime_enabled():
+    """Check if the TensorRT graph runtime is present.
+    Returns
+    -------
+    ret: bool
+        True if present, False if not.
+    """
+    check_enabled = tvm.get_global_func("relay.op.is_tensorrt_runtime_enabled", True)
+    if check_enabled:
+        return check_enabled()
+    return False
+
+def get_tensorrt_version():
+    """Gets the version of TensorRT that TVM is built against.
+
+    Returns
+    -------
+    ret: Tuple[int]
+        TensorRT version as a tuple of major, minor, and patch number. If TVM
+        is not built with TensorRT, the value set by set_tensorrt_version() is returned instead.
+    """
+    linked_ver = tuple(tvm.get_global_func("relay.op.get_tensorrt_version")())
+    if len(linked_ver) == 3:
+        return linked_ver
+    return TENSORRT_VERSION
+
+def set_tensorrt_version(version):
+    """Override TensorRT version for annotation
+
+    Returns
+    -------
+    version: Tuple[int]
+        TensorRT version as a tuple of major, minor, and patch number. If TVM
+        is not built with TensorRT, an empty tuple is returned instead.
+    """
+    global TENSORRT_VERSION
+    TENSORRT_VERSION = version
+
+def get_tensorrt_use_implicit_batch_mode():
+    return USE_IMPLICIT_BATCH
+
+def set_tensorrt_use_implicit_batch_mode(use_implicit_batch):
+    global USE_IMPLICIT_BATCH
+    USE_IMPLICIT_BATCH = use_implicit_batch
+
+def get_tensorrt_remove_no_mac_subgraphs():
+    return REMOVE_NO_MAC_SUBGRAPHS
+
+def set_tensorrt_remove_no_mac_subgraphs(remove_no_mac_subgraphs):
+    global REMOVE_NO_MAC_SUBGRAPHS
+    REMOVE_NO_MAC_SUBGRAPHS = remove_no_mac_subgraphs
+
+def partition_for_tensorrt(mod, params=None, version=None, use_implicit_batch=True, remove_no_mac_subgraphs=False, max_workspace_size=1 << 30):

Review comment:
       Missing `max_workspace_size` in the docstring.
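
   Something along these lines would close the gap (suggested wording only):

   ```python
       max_workspace_size : Optional[int]
           Maximum scratch workspace, in bytes, that each TensorRT engine is
           allowed to use while being built. Defaults to 1 << 30 (1 GiB).
   ```

   The `use_implicit_batch` and `remove_no_mac_subgraphs` entries also have no descriptions and could be filled in at the same time.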

##########
File path: python/tvm/relay/op/contrib/tensorrt.py
##########
@@ -0,0 +1,675 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, unused-argument
+"""TensorRT supported operators."""
+import tvm
+from tvm import relay
+from tvm.relay import transform
+from tvm.relay.build_module import bind_params_by_name
+from tvm.relay.expr import Call, Constant, Tuple, GlobalVar
+from tvm.relay.expr_functor import ExprMutator
+
+import os
+import numpy as np
+
+# Version to use for annotation when there is no linked TRT.
+TENSORRT_VERSION = (6, 0, 1)
+USE_IMPLICIT_BATCH = True
+REMOVE_NO_MAC_SUBGRAPHS = False
+
+def is_tensorrt_runtime_enabled():
+    """Check if the TensorRT graph runtime is present.
+    Returns
+    -------
+    ret: bool
+        True if present, False if not.
+    """
+    check_enabled = tvm.get_global_func("relay.op.is_tensorrt_runtime_enabled", True)
+    if check_enabled:
+        return check_enabled()
+    return False
+
+def get_tensorrt_version():
+    """Gets the version of TensorRT that TVM is built against.
+
+    Returns
+    -------
+    ret: Tuple[int]
+        TensorRT version as a tuple of major, minor, and patch number. If TVM
+        is not built with TensorRT, the value set by set_tensorrt_version() is returned instead.
+    """
+    linked_ver = tuple(tvm.get_global_func("relay.op.get_tensorrt_version")())
+    if len(linked_ver) == 3:
+        return linked_ver
+    return TENSORRT_VERSION
+
+def set_tensorrt_version(version):
+    """Override TensorRT version for annotation
+
+    Returns
+    -------
+    version: Tuple[int]
+        TensorRT version as a tuple of major, minor, and patch number. If TVM
+        is not built with TensorRT, an empty tuple is returned instead.
+    """
+    global TENSORRT_VERSION
+    TENSORRT_VERSION = version
+
+def get_tensorrt_use_implicit_batch_mode():
+    return USE_IMPLICIT_BATCH
+
+def set_tensorrt_use_implicit_batch_mode(use_implicit_batch):
+    global USE_IMPLICIT_BATCH
+    USE_IMPLICIT_BATCH = use_implicit_batch
+
+def get_tensorrt_remove_no_mac_subgraphs():
+    return REMOVE_NO_MAC_SUBGRAPHS
+
+def set_tensorrt_remove_no_mac_subgraphs(remove_no_mac_subgraphs):
+    global REMOVE_NO_MAC_SUBGRAPHS
+    REMOVE_NO_MAC_SUBGRAPHS = remove_no_mac_subgraphs
+
+def partition_for_tensorrt(mod, params=None, version=None, use_implicit_batch=True, remove_no_mac_subgraphs=False, max_workspace_size=1 << 30):
+    """Partition the graph greedily offloading supported
+    operators to TensorRT.
+    Parameters
+    ----------
+    mod : Module
+        The module to run passes on.
+    params : Optional[Dict[str, NDArray]]
+        Constant input parameters.
+    version : Optional[Tuple(int)]
+        TensorRT version to target as tuple of (major, minor, patch). Will use linked TRT version if available if version is not specified.
+    use_implicit_batch : Optional[bool]
+
+    remove_no_mac_subgraphs : Optional[bool]
+
+    Returns
+    -------
+    ret : annotated and partitioned module.
+    """
+    if version:
+        assert isinstance(version, tuple) and len(version) == 3
+        set_tensorrt_version(version)
+    set_tensorrt_use_implicit_batch_mode(use_implicit_batch)
+    set_tensorrt_remove_no_mac_subgraphs(remove_no_mac_subgraphs)
+    if params:
+        mod['main'] = bind_params_by_name(mod['main'], params)
+
+    seq = tvm.transform.Sequential([transform.InferType(),
+                                    RemoveDropoutPass(),
+                                    transform.RemoveUnusedFunctions(),
+                                    transform.ConvertLayout({'nn.conv2d': ['NCHW', 'default'],
+                                                             'nn.conv3d': ['NCDHW', 'default']}),
+                                    transform.FoldConstant(),
+                                    transform.AnnotateTarget('tensorrt'),
+                                    transform.MergeCompilerRegions(),
+                                    transform.PartitionGraph(),
+                                    transform.InferType()])
+    with tvm.transform.PassContext(opt_level=3):
+        mod = seq(mod)
+    mod = prune_tensorrt_subgraphs(mod)
+    # Pass parameters to codegen
+    os.environ["TVM_TENSORRT_USE_IMPLICIT_BATCH"] = str(int(use_implicit_batch))
+    os.environ["TVM_TENSORRT_MAX_WORKSPACE_SIZE"] = str(int(max_workspace_size))

Review comment:
       Is there any alternative to the use of environment variables to communicate these arguments from the partitioning to the codegen? If there is no alternative, I think at least these environment variables should be documented somewhere.
   
   I don't know the right answer, so I'll just cc @mbaret and @comaniac 
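
   One possible direction, sketched purely as an assumption: carry the options in the `PassContext` config instead of the process environment. The config key below is hypothetical and would have to be registered on the C++ side (e.g. with `TVM_REGISTER_PASS_CONFIG_OPTION`) before `PassContext` would accept it and the codegen could query it.

   ```python
   import tvm
   from tvm import relay
   from tvm.relay import transform

   # Tiny module so the snippet is self-contained.
   x = relay.var("x", shape=(1, 3, 224, 224), dtype="float32")
   mod = tvm.IRModule.from_expr(relay.Function([x], relay.nn.relu(x)))

   # Hypothetical option dict; the codegen would read it back through
   # PassContext.current().config rather than os.environ.
   trt_options = {"use_implicit_batch": True, "max_workspace_size": 1 << 30}

   with tvm.transform.PassContext(
           opt_level=3, config={"relay.ext.tensorrt.options": trt_options}):
       mod = transform.InferType()(mod)
       # ... remaining partitioning passes would run here ...
   ```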
   

##########
File path: python/tvm/relay/op/contrib/tensorrt.py
##########
@@ -0,0 +1,675 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, unused-argument
+"""TensorRT supported operators."""
+import tvm
+from tvm import relay
+from tvm.relay import transform
+from tvm.relay.build_module import bind_params_by_name
+from tvm.relay.expr import Call, Constant, Tuple, GlobalVar
+from tvm.relay.expr_functor import ExprMutator
+
+import os
+import numpy as np
+
+# Version to use for annotation when there is no linked TRT.
+TENSORRT_VERSION = (6, 0, 1)
+USE_IMPLICIT_BATCH = True
+REMOVE_NO_MAC_SUBGRAPHS = False
+
+def is_tensorrt_runtime_enabled():
+    """Check if the TensorRT graph runtime is present.
+    Returns
+    -------
+    ret: bool
+        True if present, False if not.
+    """
+    check_enabled = tvm.get_global_func("relay.op.is_tensorrt_runtime_enabled", True)
+    if check_enabled:
+        return check_enabled()
+    return False
+
+def get_tensorrt_version():
+    """Gets the version of TensorRT that TVM is built against.
+
+    Returns
+    -------
+    ret: Tuple[int]
+        TensorRT version as a tuple of major, minor, and patch number. If TVM
+        is not built with TensorRT, the value set by set_tensorrt_version() is returned instead.
+    """
+    linked_ver = tuple(tvm.get_global_func("relay.op.get_tensorrt_version")())
+    if len(linked_ver) == 3:
+        return linked_ver
+    return TENSORRT_VERSION
+
+def set_tensorrt_version(version):
+    """Override TensorRT version for annotation
+
+    Returns
+    -------
+    version: Tuple[int]
+        TensorRT version as a tuple of major, minor, and patch number. If TVM
+        is not built with TensorRT, an empty tuple is returned instead.
+    """
+    global TENSORRT_VERSION
+    TENSORRT_VERSION = version
+
+def get_tensorrt_use_implicit_batch_mode():
+    return USE_IMPLICIT_BATCH
+
+def set_tensorrt_use_implicit_batch_mode(use_implicit_batch):
+    global USE_IMPLICIT_BATCH
+    USE_IMPLICIT_BATCH = use_implicit_batch
+
+def get_tensorrt_remove_no_mac_subgraphs():
+    return REMOVE_NO_MAC_SUBGRAPHS
+
+def set_tensorrt_remove_no_mac_subgraphs(remove_no_mac_subgraphs):
+    global REMOVE_NO_MAC_SUBGRAPHS
+    REMOVE_NO_MAC_SUBGRAPHS = remove_no_mac_subgraphs
+
+def partition_for_tensorrt(mod, params=None, version=None, use_implicit_batch=True, remove_no_mac_subgraphs=False, max_workspace_size=1 << 30):
+    """Partition the graph greedily offloading supported
+    operators to TensorRT.
+    Parameters
+    ----------
+    mod : Module
+        The module to run passes on.
+    params : Optional[Dict[str, NDArray]]
+        Constant input parameters.
+    version : Optional[Tuple(int)]
+        TensorRT version to target as tuple of (major, minor, patch). Will use linked TRT version if available if version is not specified.
+    use_implicit_batch : Optional[bool]
+
+    remove_no_mac_subgraphs : Optional[bool]
+
+    Returns
+    -------
+    ret : annotated and partitioned module.
+    """
+    if version:
+        assert isinstance(version, tuple) and len(version) == 3
+        set_tensorrt_version(version)
+    set_tensorrt_use_implicit_batch_mode(use_implicit_batch)
+    set_tensorrt_remove_no_mac_subgraphs(remove_no_mac_subgraphs)
+    if params:
+        mod['main'] = bind_params_by_name(mod['main'], params)
+
+    seq = tvm.transform.Sequential([transform.InferType(),
+                                    RemoveDropoutPass(),
+                                    transform.RemoveUnusedFunctions(),
+                                    transform.ConvertLayout({'nn.conv2d': ['NCHW', 'default'],
+                                                             'nn.conv3d': ['NCDHW', 'default']}),
+                                    transform.FoldConstant(),
+                                    transform.AnnotateTarget('tensorrt'),
+                                    transform.MergeCompilerRegions(),
+                                    transform.PartitionGraph(),
+                                    transform.InferType()])
+    with tvm.transform.PassContext(opt_level=3):
+        mod = seq(mod)
+    mod = prune_tensorrt_subgraphs(mod)
+    # Pass parameters to codegen
+    os.environ["TVM_TENSORRT_USE_IMPLICIT_BATCH"] = str(int(use_implicit_batch))
+    os.environ["TVM_TENSORRT_MAX_WORKSPACE_SIZE"] = str(int(max_workspace_size))
+    return mod
+
+
+def _register_external_op_helper(op_name, supported=True):
+    @tvm.ir.register_op_attr(op_name, "target.tensorrt")
+    def _func_wrapper(attrs, args):
+        if any([x.checked_type.dtype != "float32" for x in args]):
+            print("Only float32 inputs are supported for TensorRT.")
+            return False
+        return supported
+    return _func_wrapper
+
+
+def _register_external_op_helper_func(op_name, func):
+    @tvm.ir.register_op_attr(op_name, "target.tensorrt")
+    def _func_wrapper(attrs, args):
+        if any([x.checked_type.dtype != "float32" for x in args]):
+            print("Only float32 inputs are supported for TensorRT.")
+            return False
+        return func(attrs, args, op_name)
+    return _func_wrapper
+
+
+# Ops which are always supported
+_register_external_op_helper("nn.relu")
+_register_external_op_helper("sigmoid")
+_register_external_op_helper("tanh")
+_register_external_op_helper("subtract")
+_register_external_op_helper("multiply")
+_register_external_op_helper("divide")
+_register_external_op_helper("power")
+_register_external_op_helper("maximum")
+_register_external_op_helper("minimum")
+_register_external_op_helper("exp")
+_register_external_op_helper("log")
+_register_external_op_helper("sqrt")
+_register_external_op_helper("abs")
+_register_external_op_helper("negative")
+_register_external_op_helper("nn.batch_flatten")
+_register_external_op_helper("clip")
+
+@tvm.ir.register_op_attr("add", "target.tensorrt")
+def add_annotate_fn(attrs, args): # pylint: disable=unused-variable
+    if any([x.checked_type.dtype != "float32" for x in args]):
+        print("Only float32 inputs are supported for TensorRT.")
+        return False
+    if (isinstance(args[0], Constant) or isinstance(args[1], Constant)) and \
+            args[0].checked_type.shape[0] == args[0].checked_type.shape[0] and \
+            args[0].checked_type.shape[0] != 1 and \
+            (len(args[0].checked_type.shape) > 3 or len(args[1].checked_type.shape) > 3):
+        print("add: bug in TRT with adding batched constants.")
+        return False
+    return True
+
+@tvm.ir.register_op_attr("nn.batch_norm", "target.tensorrt")
+def batch_norm_annotate_fn(attrs, args): # pylint: disable=unused-variable
+    if any([x.checked_type.dtype != "float32" for x in args]):
+        print("Only float32 inputs are supported for TensorRT.")
+        return False
+    if int(attrs.axis) != 1 and int(attrs.axis) != 3:

Review comment:
       Suggestion: This could be simplified to `if int(attrs.axis) not in (1, 3):`.
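
   As a standalone sketch of the equivalence (helper name hypothetical):

   ```python
   def batch_norm_axis_ok(axis):
       # Membership test replacing
       # `int(attrs.axis) != 1 and int(attrs.axis) != 3`.
       return int(axis) in (1, 3)

   assert batch_norm_axis_ok(1) and batch_norm_axis_ok(3)
   assert not batch_norm_axis_ok(2)
   ```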




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org