Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2020/09/15 12:55:29 UTC

[GitHub] [incubator-tvm] lhutton1 commented on a change in pull request #6395: [BYOC][TensorRT] TensorRT BYOC integration

lhutton1 commented on a change in pull request #6395:
URL: https://github.com/apache/incubator-tvm/pull/6395#discussion_r488545828



##########
File path: cmake/modules/contrib/TensorRT.cmake
##########
@@ -0,0 +1,55 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# TensorRT Codegen only. This can be enabled independently of USE_TENSORRT to

Review comment:
       ```suggestion
   # TensorRT Codegen only. This can be enabled independently of USE_TENSORRT_GRAPH_RUNTIME to
   ```

##########
File path: src/runtime/contrib/tensorrt/tensorrt_builder.cc
##########
@@ -0,0 +1,224 @@
+/* * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file runtime/contrib/tensorrt/tensorrt_builder.cc
+ * \brief The TensorRTBuilder class can be used to convert a JSONRuntime graph into a TRT engine
+ * which can be used for inference.
+ */
+
+#include "tensorrt_builder.h"
+
+#include <tvm/runtime/ndarray.h>
+
+#include <memory>
+#include <string>
+
+#include "tensorrt_logger.h"
+#include "tensorrt_ops.h"
+#include "tensorrt_utils.h"
+
+namespace tvm {
+namespace runtime {
+namespace contrib {
+
+TensorRTBuilder::TensorRTBuilder(TensorRTLogger* logger, size_t max_workspace_size,
+                                 bool use_implicit_batch, bool use_fp16, int batch_size)
+    : max_workspace_size_(max_workspace_size),
+      use_implicit_batch_(use_implicit_batch),
+      use_fp16_(use_fp16),
+      batch_size_(batch_size) {
+  // Create TRT builder and network.
+  builder_ = nvinfer1::createInferBuilder(*logger);
+#if TRT_VERSION_GE(6, 0, 1)
+  // Use INetworkV2.
+  auto flags =
+      1U << static_cast<uint32_t>(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
+  if (use_implicit_batch_) {
+    flags = 0U;
+    builder_->setMaxBatchSize(batch_size_);
+  }
+  network_ = builder_->createNetworkV2(flags);
+#else
+  // Use INetwork with implicit batch.
+  builder_->setMaxBatchSize(batch_size_);
+  builder_->setMaxWorkspaceSize(max_workspace_size_);
+  builder_->setFp16Mode(use_fp16_);
+  network_ = builder_->createNetwork();
+#endif
+}
+
+void TensorRTBuilder::AddInput(int nid, const JSONGraphNode& node) {
+  auto node_name = node.GetOpName();
+  auto shapes = node.GetOpShape();
+  auto dtypes = node.GetOpDataType();
+  CHECK_EQ(shapes.size(), dtypes.size());
+  node_output_map_[nid] = {};
+  for (size_t i = 0; i < shapes.size(); ++i) {
+    const std::string name = node_name + "_" + std::to_string(i);
+    auto shape = shapes[i];
+    // Remove batch dim when not in explicit batch mode.
+    if (use_implicit_batch_ && shape.size() > 1) {
+      shape.erase(shape.begin());
+    }
+    DLOG(INFO) << "TRT input: " << name << " " << DebugString(shape);

Review comment:
       Is it useful to log every input and output? I feel like it might be a bit unnecessary.

##########
File path: src/runtime/contrib/tensorrt/tensorrt_builder.h
##########
@@ -0,0 +1,156 @@
+/* * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file runtime/contrib/tensorrt/tensorrt_builder.h
+ * \brief The TensorRTBuilder class can be used to convert a JSONRuntime graph into a TRT engine
+ * which can be used for inference.
+ */
+
+#ifndef TVM_RUNTIME_CONTRIB_TENSORRT_TENSORRT_BUILDER_H_
+#define TVM_RUNTIME_CONTRIB_TENSORRT_TENSORRT_BUILDER_H_
+
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "../json/json_node.h"
+#include "NvInfer.h"
+#include "tensorrt_logger.h"
+#include "tensorrt_ops.h"
+
+namespace tvm {
+namespace runtime {
+namespace contrib {
+
+using JSONGraphNode = tvm::runtime::json::JSONGraphNode;
+using JSONGraphNodeEntry = tvm::runtime::json::JSONGraphNodeEntry;
+
+/*!
+ * \brief The product of TensorRTBuilder which provides everything needed to
+ * perform inference.
+ */
+struct TrtEngineAndContext {
+  nvinfer1::ICudaEngine* engine;
+  nvinfer1::IExecutionContext* context;
+  std::vector<std::string> inputs;
+  std::vector<std::string> outputs;
+};
+
+/*!
+ * \brief Converts a JSONRuntime graph into a TensorRT engine and execution context. Inputs,
+ * constants, layers, and outputs can be added to construct the TensorRT network definition. BuildEngine will then 

Review comment:
       Missing last part of sentence?

##########
File path: python/tvm/relay/op/contrib/tensorrt.py
##########
@@ -0,0 +1,675 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, unused-argument
+"""TensorRT supported operators."""
+import tvm
+from tvm import relay
+from tvm.relay import transform
+from tvm.relay.build_module import bind_params_by_name
+from tvm.relay.expr import Call, Constant, Tuple, GlobalVar
+from tvm.relay.expr_functor import ExprMutator
+
+import os
+import numpy as np
+
+# Version to use for annotation when there is no linked TRT.
+TENSORRT_VERSION = (6, 0, 1)
+USE_IMPLICIT_BATCH = True
+REMOVE_NO_MAC_SUBGRAPHS = False
+
+def is_tensorrt_runtime_enabled():
+    """Check if the TensorRT graph runtime is present.
+    Returns
+    -------
+    ret: bool
+        True if present, False if not.
+    """
+    check_enabled = tvm.get_global_func("relay.op.is_tensorrt_runtime_enabled", True)
+    if check_enabled:
+        return check_enabled()
+    return False
+
+def get_tensorrt_version():
+    """Gets the version of TensorRT that TVM is built against.
+
+    Returns
+    -------
+    ret: Tuple[int]
+        TensorRT version as a tuple of major, minor, and patch number. If TVM
+        is not built with TensorRT, the value set by set_tensorrt_version() is returned instead.
+    """
+    linked_ver = tuple(tvm.get_global_func("relay.op.get_tensorrt_version")())
+    if len(linked_ver) == 3:
+        return linked_ver
+    return TENSORRT_VERSION
+
+def set_tensorrt_version(version):
+    """Override TensorRT version for annotation
+
+    Returns
+    -------
+    version: Tuple[int]
+        TensorRT version as a tuple of major, minor, and patch number. If TVM
+        is not built with TensorRT, an empty tuple is returned instead.
+    """
+    global TENSORRT_VERSION
+    TENSORRT_VERSION = version
+
+def get_tensorrt_use_implicit_batch_mode():
+    return USE_IMPLICIT_BATCH
+
+def set_tensorrt_use_implicit_batch_mode(use_implicit_batch):
+    global USE_IMPLICIT_BATCH
+    USE_IMPLICIT_BATCH = use_implicit_batch
+
+def get_tensorrt_remove_no_mac_subgraphs():
+    return REMOVE_NO_MAC_SUBGRAPHS
+
+def set_tensorrt_remove_no_mac_subgraphs(remove_no_mac_subgraphs):
+    global REMOVE_NO_MAC_SUBGRAPHS
+    REMOVE_NO_MAC_SUBGRAPHS = remove_no_mac_subgraphs
+
+def partition_for_tensorrt(mod, params=None, version=None, use_implicit_batch=True, remove_no_mac_subgraphs=False, max_workspace_size=1 << 30):
+    """Partition the graph greedily offloading supported
+    operators to TensorRT.
+    Parameters
+    ----------
+    mod : Module
+        The module to run passes on.
+    params : Optional[Dict[str, NDArray]]
+        Constant input parameters.
+    version : Optional[Tuple(int)]
+        TensorRT version to target as tuple of (major, minor, patch). Will use linked TRT version if available if version is not specified.
+    use_implicit_batch : Optional[bool]
+
+    remove_no_mac_subgraphs : Optional[bool]
+
+    Returns
+    -------
+    ret : annotated and partitioned module.
+    """
+    if version:
+        assert isinstance(version, tuple) and len(version) == 3
+        set_tensorrt_version(version)
+    set_tensorrt_use_implicit_batch_mode(use_implicit_batch)
+    set_tensorrt_remove_no_mac_subgraphs(remove_no_mac_subgraphs)
+    if params:
+        mod['main'] = bind_params_by_name(mod['main'], params)
+
+    seq = tvm.transform.Sequential([transform.InferType(),
+                                    RemoveDropoutPass(),
+                                    transform.RemoveUnusedFunctions(),
+                                    transform.ConvertLayout({'nn.conv2d': ['NCHW', 'default'],
+                                                             'nn.conv3d': ['NCDHW', 'default']}),
+                                    transform.FoldConstant(),
+                                    transform.AnnotateTarget('tensorrt'),
+                                    transform.MergeCompilerRegions(),
+                                    transform.PartitionGraph(),
+                                    transform.InferType()])
+    with tvm.transform.PassContext(opt_level=3):
+        mod = seq(mod)
+    mod = prune_tensorrt_subgraphs(mod)
+    # Pass parameters to codegen
+    os.environ["TVM_TENSORRT_USE_IMPLICIT_BATCH"] = str(int(use_implicit_batch))
+    os.environ["TVM_TENSORRT_MAX_WORKSPACE_SIZE"] = str(int(max_workspace_size))
+    return mod
+
+
+def _register_external_op_helper(op_name, supported=True):
+    @tvm.ir.register_op_attr(op_name, "target.tensorrt")
+    def _func_wrapper(attrs, args):
+        if any([x.checked_type.dtype != "float32" for x in args]):
+            print("Only float32 inputs are supported for TensorRT.")
+            return False
+        return supported
+    return _func_wrapper
+
+
+def _register_external_op_helper_func(op_name, func):
+    @tvm.ir.register_op_attr(op_name, "target.tensorrt")
+    def _func_wrapper(attrs, args):
+        if any([x.checked_type.dtype != "float32" for x in args]):
+            print("Only float32 inputs are supported for TensorRT.")
+            return False
+        return func(attrs, args, op_name)
+    return _func_wrapper
+
+
+# Ops which are always supported
+_register_external_op_helper("nn.relu")
+_register_external_op_helper("sigmoid")
+_register_external_op_helper("tanh")
+_register_external_op_helper("subtract")
+_register_external_op_helper("multiply")
+_register_external_op_helper("divide")
+_register_external_op_helper("power")
+_register_external_op_helper("maximum")
+_register_external_op_helper("minimum")
+_register_external_op_helper("exp")
+_register_external_op_helper("log")
+_register_external_op_helper("sqrt")
+_register_external_op_helper("abs")
+_register_external_op_helper("negative")
+_register_external_op_helper("nn.batch_flatten")
+_register_external_op_helper("clip")
+
+@tvm.ir.register_op_attr("add", "target.tensorrt")
+def add_annotate_fn(attrs, args): # pylint: disable=unused-variable
+    if any([x.checked_type.dtype != "float32" for x in args]):
+        print("Only float32 inputs are supported for TensorRT.")
+        return False
+    if (isinstance(args[0], Constant) or isinstance(args[1], Constant)) and \
+            args[0].checked_type.shape[0] == args[0].checked_type.shape[0] and \
+            args[0].checked_type.shape[0] != 1 and \
+            (len(args[0].checked_type.shape) > 3 or len(args[1].checked_type.shape) > 3):
+        print("add: bug in TRT with adding batched constants.")
+        return False
+    return True
+
+@tvm.ir.register_op_attr("nn.batch_norm", "target.tensorrt")
+def batch_norm_annotate_fn(attrs, args): # pylint: disable=unused-variable
+    if any([x.checked_type.dtype != "float32" for x in args]):
+        print("Only float32 inputs are supported for TensorRT.")
+        return False
+    if int(attrs.axis) != 1 and int(attrs.axis) != 3:
+        print("nn.batch_norm: axis is {} but must be 1 or 3.".format(int(attrs.axis)))
+        return False
+    return True
+
+@tvm.ir.register_op_attr("nn.softmax", "target.tensorrt")
+def softmax_annotate_fn(attrs, args): # pylint: disable=unused-variable
+    if any([x.checked_type.dtype != "float32" for x in args]):
+        print("Only float32 inputs are supported for TensorRT.")
+        return False
+    if get_tensorrt_use_implicit_batch_mode() and int(attrs.axis) == 0:
+        print("nn.softmax: can't modify batch dimension.")
+        return False
+    return True
+
+@tvm.ir.register_op_attr("nn.conv2d", "target.tensorrt")
+def conv2d_annotate_fn(attrs, args): # pylint: disable=unused-variable
+    if any([x.checked_type.dtype != "float32" for x in args]):
+        print("Only float32 inputs are supported for TensorRT.")
+        return False
+    if attrs.data_layout != "NCHW":
+        print("nn.conv2d: data_layout is {} but must be NCHW.".format(attrs.data_layout))
+        return False
+    if attrs.kernel_layout != "OIHW":
+        print("nn.conv2d: kernel_layout is {} but must be OIHW.".format(attrs.kernel_layout))
+        return False
+    if attrs.out_layout and attrs.out_layout != "NCHW":
+        print("nn.conv2d: out_layout is {} but must be NCHW.".format(attrs.out_layout))
+        return False
+    return True
+
+@tvm.ir.register_op_attr("nn.dense", "target.tensorrt")
+def dense_annotate_fn(attrs, args): # pylint: disable=unused-variable
+    if any([x.checked_type.dtype != "float32" for x in args]):
+        print("Only float32 inputs are supported for TensorRT.")
+        return False
+    input_rank = len(args[0].checked_type.shape)
+    weight_rank = len(args[1].checked_type.shape)
+    if input_rank < 2 or input_rank > 4:
+        print("nn.dense: input has rank {} but must be 2, 3 or 4.".format(input_rank))
+        return False
+    if weight_rank != 2:
+        print("nn.dense: weight has rank {} but must be 2.".format(weight_rank))
+        return False
+    return True
+
+@tvm.ir.register_op_attr("nn.bias_add", "target.tensorrt")
+def bias_add_annotate_fn(attrs, args): # pylint: disable=unused-variable
+    if any([x.checked_type.dtype != "float32" for x in args]):
+        print("Only float32 inputs are supported for TensorRT.")
+        return False
+    input_rank = len(args[0].checked_type.shape)
+    if input_rank < 2 or input_rank > 4:
+        print("nn.bias_add: input rank is {} but must be 2, 3 or 4.".format(input_rank))
+        return False
+    return True
+
+@tvm.ir.register_op_attr("nn.max_pool2d", "target.tensorrt")
+def max_pool_2d_annotate_fn(attrs, args): # pylint: disable=unused-variable
+    if any([x.checked_type.dtype != "float32" for x in args]):
+        print("Only float32 inputs are supported for TensorRT.")
+        return False
+    if attrs.layout != "NCHW":
+        print("nn.max_pool2d: layout is {} but must be NCHW.".format(attrs.layout))
+        return False
+    if attrs.ceil_mode and get_tensorrt_version() < (5, 1, 5):
+        print("nn.avg_pool2d: ceil_mode=True requires TensorRT 5.1.5 or greater.")
+        return False
+    return True
+
+@tvm.ir.register_op_attr("nn.avg_pool2d", "target.tensorrt")
+def avg_pool_2d_annotate_fn(attrs, args): # pylint: disable=unused-variable
+    if any([x.checked_type.dtype != "float32" for x in args]):
+        print("Only float32 inputs are supported for TensorRT.")
+        return False
+    if attrs.layout != "NCHW":
+        print("nn.avg_pool2d: layout is {} but must be NCHW.".format(attrs.layout))
+        return False
+    if attrs.count_include_pad and len(attrs.padding) == 4:
+        print("nn.avg_pool2d: inclusive-counted blended or average "
+                "pooling is not supported in combination with asymmetric padding")
+        return False
+    if attrs.ceil_mode and get_tensorrt_version() < (5, 1, 5):
+        print("nn.avg_pool2d: ceil_mode=True requires TensorRT 5.1.5 or greater.")
+        return False
+    return True
+
+@tvm.ir.register_op_attr("nn.global_max_pool2d", "target.tensorrt")
+def global_max_pool_2d_annotate_fn(attrs, args): # pylint: disable=unused-variable
+    if any([x.checked_type.dtype != "float32" for x in args]):
+        print("Only float32 inputs are supported for TensorRT.")
+        return False
+    if attrs.layout != "NCHW":
+        print("nn.global_max_pool2d: layout is {} but must be NCHW.".format(attrs.layout))
+        return False
+    return True
+
+@tvm.ir.register_op_attr("nn.global_avg_pool2d", "target.tensorrt")
+def global_avg_pool_2d_annotate_fn(attrs, args): # pylint: disable=unused-variable
+    if any([x.checked_type.dtype != "float32" for x in args]):
+        print("Only float32 inputs are supported for TensorRT.")
+        return False
+    if attrs.layout != "NCHW":
+        print("nn.global_avg_pool2d: layout is {} but must be NCHW.".format(attrs.layout))
+        return False
+    return True
+
+@tvm.ir.register_op_attr("expand_dims", "target.tensorrt")
+def expand_dims_annotate_fn(attrs, args): # pylint: disable=unused-variable
+    if any([x.checked_type.dtype != "float32" for x in args]):
+        print("Only float32 inputs are supported for TensorRT.")
+        return False
+    if get_tensorrt_use_implicit_batch_mode() and int(attrs.axis) == 0:
+        print("expand_dims: can't modify batch dimension.")
+        return False
+    return True
+
+@tvm.ir.register_op_attr("squeeze", "target.tensorrt")
+def squeeze_annotate_fn(attrs, args): # pylint: disable=unused-variable
+    if any([x.checked_type.dtype != "float32" for x in args]):
+        print("Only float32 inputs are supported for TensorRT.")
+        return False
+    if not attrs.axis:
+        print("squeeze: must explicitly set axis.")
+        return False
+    if get_tensorrt_use_implicit_batch_mode() and any([axis == 0 for axis in map(int, attrs.axis)]):
+        print("squeeze: can't modify batch dimension.")
+        return False
+    return True
+
+@tvm.ir.register_op_attr("concatenate", "target.tensorrt")
+def concatenate_annotate_fn(attrs, args): # pylint: disable=unused-variable
+    if any([x.dtype != "float32" for x in args[0].checked_type.fields]):
+        print("Only float32 inputs are supported for TensorRT.")
+        return False
+    if not get_tensorrt_use_implicit_batch_mode():
+        return True
+    if int(attrs.axis) == 0:
+        print("concatenate: can't modify batch dimension.")
+        return False
+    if isinstance(args[0], Tuple):
+        for tuple_input in args[0].fields:
+            if isinstance(tuple_input, Constant):
+                print("concatenate: can't concatenate tensors with constants.")
+                return False
+    return True
+
+@tvm.ir.register_op_attr("nn.conv2d_transpose", "target.tensorrt")
+def conv2d_transpose_annotate_fn(attrs, args): # pylint: disable=unused-variable
+    if any([x.checked_type.dtype != "float32" for x in args]):
+        print("Only float32 inputs are supported for TensorRT.")
+        return False
+    if attrs.data_layout != "NCHW":
+        print("nn.conv2d_transpose: data_layout is {} but must be NCHW.".format(
+            attrs.data_layout))
+        return False
+    if attrs.kernel_layout != "OIHW":
+        print("nn.conv2d_transpose: kernel_layout is {} but must be OIHW.".format(
+            attrs.kernel_layout))
+        return False
+    if attrs.out_layout and attrs.out_layout != "NCHW":
+        print("nn.conv2d_transpose: out_layout is {} but must be NCHW.".format(
+            attrs.out_layout))
+        return False
+    if attrs.dilation and any([rate != 1 for rate in map(int, attrs.dilation)]):
+        print("nn.conv2d_transpose: dilation rate must be 1.")
+        return False
+    return True
+
+@tvm.ir.register_op_attr("transpose", "target.tensorrt")
+def transpose_annotate_fn(attrs, args): # pylint: disable=unused-variable
+    if any([x.checked_type.dtype != "float32" for x in args]):
+        print("Only float32 inputs are supported for TensorRT.")
+        return False
+    if get_tensorrt_use_implicit_batch_mode() and int(attrs.axes[0]) != 0:
+        print("transpose: can't modify batch dimension.")
+        return False
+    return True
+
+@tvm.ir.register_op_attr("layout_transform", "target.tensorrt")
+def resize_annotate_fn(attrs, args): # pylint: disable=unused-variable
+    if any([x.checked_type.dtype != "float32" for x in args]):
+        print("Only float32 inputs are supported for TensorRT.")
+        return False
+    if (attrs.src_layout, attrs.dst_layout) not in [("NCHW", "NHWC"), ("NHWC", "NCHW"), ("NDHWC", "NCDHW"), ("NCDHW", "NDHWC")]:
+        print("layout_transform: {} to {} is not supported.".format(attrs.src_layout, attrs.dst_layout))
+        return False
+    return True
+
+@tvm.ir.register_op_attr("reshape", "target.tensorrt")
+def reshape_annotate_fn(attrs, args): # pylint: disable=unused-variable
+    if args[0].checked_type.dtype != "float32":
+        print("Only float32 inputs are supported for TensorRT.")
+        return False
+    if any([x < -1 for x in map(int, attrs.newshape)]):
+        print("reshape: new shape dims must be explicit.")
+        return False
+    if get_tensorrt_use_implicit_batch_mode():
+        shape = list(map(int, args[0].checked_type.shape))
+        new_shape = list(map(int, attrs.newshape))
+        if len(new_shape) == 0 or len(shape) == 0:
+            print("reshape: Can't reshape to or from scalar.")
+            return False
+        # TRT cannot modify batch dimension.
+        original_volume = np.prod(shape)
+        # First, resolve 0.
+        for i, value in enumerate(new_shape):
+            if value == 0:
+                new_shape[i] = shape[i]
+        # Resolve -1.
+        for i, value in enumerate(new_shape):
+            if value == -1:
+                new_shape[i] = original_volume // np.prod([x for x in new_shape if x != -1])
+        if shape[0] != new_shape[0]:
+            print("reshape: can't modify batch dimension.")
+            return False
+    return True
+
+@tvm.ir.register_op_attr("nn.pad", "target.tensorrt")
+def pad_annotate_fn(attrs, args): # pylint: disable=unused-variable
+    if any([x.checked_type.dtype != "float32" for x in args]):
+        print("Only float32 inputs are supported for TensorRT.")
+        return False
+    if attrs.pad_mode != "constant":
+        print("nn.pad: pad mode is {} but must be constant.".format(attrs.pad_mode))
+        return False
+    if float(attrs.pad_value) != 0.0:
+        print("nn.pad: pad value is {} but must be 0.0.".format(float(attrs.pad_value)))
+        return False
+    if any([x != 0 for x in attrs.pad_width[0]]) or any([x != 0 for x in attrs.pad_width[1]]):
+        print("nn.pad: can't pad batch or channel dimensions.")
+        return False
+    if len(attrs.pad_width) == 5 and any([x != 0 for x in attrs.pad_width[2]]):
+        print("nn.pad: can only pad last two dimensions for 5D inputs.")
+    return True
+
+def reduce_annotate_fn(attrs, args, op_name):
+    if not attrs.axis or len(attrs.axis) == 0:
+        print("{}: cannot reduce to scalar.".format(op_name))
+        return False
+    if attrs.exclude:
+        print("{}: exclude not supported.".format(op_name))
+        return False
+    if get_tensorrt_use_implicit_batch_mode() and any([x == 0 for x in map(int, attrs.axis)]):
+        print("{}: can't modify batch dimension.".format(op_name))
+        return False
+    return True
+
+_register_external_op_helper_func("sum", reduce_annotate_fn)
+_register_external_op_helper_func("prod", reduce_annotate_fn)
+_register_external_op_helper_func("max", reduce_annotate_fn)
+_register_external_op_helper_func("min", reduce_annotate_fn)
+_register_external_op_helper_func("mean", reduce_annotate_fn)
+
+def trt_5_1_5_annotate_fn(attrs, args, op_name):
+    if get_tensorrt_version() < (5, 1, 5):
+        print("{}: requires TensorRT version 5.1.5 or higher.".format(op_name))
+        return False
+    return True
+
+_register_external_op_helper_func("nn.leaky_relu", trt_5_1_5_annotate_fn)
+_register_external_op_helper_func("sin", trt_5_1_5_annotate_fn)
+_register_external_op_helper_func("cos", trt_5_1_5_annotate_fn)
+_register_external_op_helper_func("atan", trt_5_1_5_annotate_fn)
+_register_external_op_helper_func("ceil", trt_5_1_5_annotate_fn)
+
+@tvm.ir.register_op_attr("strided_slice", "target.tensorrt")
+def strided_slice_annotate_fn(attrs, args): # pylint: disable=unused-variable
+    if args[0].checked_type.dtype != "float32":
+        print("Only float32 inputs are supported for TensorRT.")
+        return False
+    if get_tensorrt_version() < (5, 1, 5):
+        print("strided_slice: requires TensorRT version 5.1.5 or higher.")
+        return False

Review comment:
       ```suggestion
       if not trt_5_1_5_annotate_fn(attrs, args, "strided_slice"):
           return False
   ```

##########
File path: tests/python/contrib/test_tensorrt.py
##########
@@ -0,0 +1,573 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import numpy as np
+import time
+import pytest
+
+import tvm
+import tvm.relay.testing
+from tvm import relay
+from tvm.relay.op.contrib import tensorrt
+from tvm.contrib import graph_runtime
+
+def should_skip():
+    if not tvm.runtime.enabled("cuda") or not tvm.gpu(0).exist:
+        print("skip because cuda is not enabled.")
+        return True
+    if not tensorrt.is_tensorrt_runtime_enabled():
+        print("skip because tensorrt runtime is not available")
+        return True
+    return False
+
+def test_tensorrt_simple():
+    if should_skip():
+        return
+    dtype = 'float32'
+    xshape = (1, 3, 2, 2)
+    yshape = (1, 3,  1,  1)
+    zshape = (1,  1,  1,  1)
+    x = relay.var('x', shape=(xshape), dtype=dtype)
+    y = relay.var('y', shape=(yshape), dtype=dtype)
+    z = relay.var('z', shape=(zshape), dtype=dtype)
+    w = z * (x + y)
+    out = relay.nn.relu(w)
+    f = relay.Function([x, y, z], out)
+
+    mod = tvm.IRModule()
+    mod['main'] = f
+    mod = tensorrt.partition_for_tensorrt(mod)
+    with relay.build_config(opt_level=3):
+        graph, lib, params = relay.build(mod, "cuda")
+    mod = graph_runtime.create(graph, lib, ctx=tvm.gpu(0))
+    x_data = np.random.uniform(-1, 1, xshape).astype(dtype)
+    y_data = np.random.uniform(-1, 1, yshape).astype(dtype)
+    z_data = np.random.uniform(-1, 1, zshape).astype(dtype)
+    mod.run(x=x_data, y=y_data, z=z_data)
+    results = [mod.get_output(i).asnumpy() for i in range(mod.get_num_outputs())]
+
+def test_tensorrt_not_compatible():
+    if should_skip():
+        return
+    dtype = 'float32'
+    xshape = (1, 32, 14, 14)
+    x = relay.var('x', shape=(xshape), dtype=dtype)
+    y = relay.add(x, x)
+    z = relay.erf(y)
+    out = relay.nn.relu(z)
+    f = relay.Function([x], out)
+    mod = tvm.IRModule()
+    mod['main'] = f
+    mod = tensorrt.partition_for_tensorrt(mod)
+    with relay.build_config(opt_level=3):
+        graph, lib, params = relay.build(mod, "cuda")
+    mod = graph_runtime.create(graph, lib, ctx=tvm.gpu(0))
+    x_data = np.random.uniform(-1, 1, xshape).astype(dtype)
+    mod.run(x=x_data)
+    results = [mod.get_output(i).asnumpy() for i in range(mod.get_num_outputs())]
+
+def test_tensorrt_ops():
+    if should_skip():
+        return
+    def run_and_verify(config):
+        f, input_shapes, is_param = config
+        params = {x: np.random.uniform(-1, 1, input_shapes[x]).astype(np.float32) for x in is_param}
+        input_dict = {k: np.random.uniform(-1, 1, v).astype(np.float32) for k, v in input_shapes.items() if k not in is_param}
+
+        # Run TRT 
+        mod = tvm.IRModule()
+        mod['main'] = f
+        mod = tensorrt.partition_for_tensorrt(mod, params)
+        with relay.build_config(opt_level=3):
+            graph, lib, graph_params = relay.build(mod, "cuda", params=params)
+        mod = graph_runtime.create(graph, lib, ctx=tvm.gpu(0))
+        mod.set_input(**graph_params)
+        mod.run(**input_dict)
+        results = [mod.get_output(i) for i in range(mod.get_num_outputs())]
+
+        # Run reference
+        mod = tvm.IRModule()
+        mod['main'] = f
+        with relay.build_config(opt_level=3):
+            graph, lib, graph_params = relay.build(mod, "cuda", params=params)
+        mod = graph_runtime.create(graph, lib, ctx=tvm.gpu(0))
+        mod.set_input(**graph_params)
+        mod.run(**input_dict)
+        ref_results = [mod.get_output(i) for i in range(mod.get_num_outputs())]
+        
+        assert len(results) == len(ref_results)
+        for i in range(len(results)):
+            res = results[i].asnumpy()
+            ref_res = ref_results[i].asnumpy()
+            assert res.shape == ref_res.shape
+            tvm.testing.assert_allclose(res, ref_res, rtol=1e-3, atol=1e-3)

Review comment:
       Same as above

##########
File path: tests/python/contrib/test_tensorrt.py
##########
@@ -0,0 +1,573 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import numpy as np
+import time
+import pytest
+
+import tvm
+import tvm.relay.testing
+from tvm import relay
+from tvm.relay.op.contrib import tensorrt
+from tvm.contrib import graph_runtime
+
+def should_skip():
+    if not tvm.runtime.enabled("cuda") or not tvm.gpu(0).exist:
+        print("skip because cuda is not enabled.")
+        return True
+    if not tensorrt.is_tensorrt_runtime_enabled():
+        print("skip because tensorrt runtime is not available")
+        return True
+    return False
+
+def test_tensorrt_simple():
+    if should_skip():
+        return
+    dtype = 'float32'
+    xshape = (1, 3, 2, 2)
+    yshape = (1, 3,  1,  1)
+    zshape = (1,  1,  1,  1)
+    x = relay.var('x', shape=(xshape), dtype=dtype)
+    y = relay.var('y', shape=(yshape), dtype=dtype)
+    z = relay.var('z', shape=(zshape), dtype=dtype)
+    w = z * (x + y)
+    out = relay.nn.relu(w)
+    f = relay.Function([x, y, z], out)
+
+    mod = tvm.IRModule()
+    mod['main'] = f
+    mod = tensorrt.partition_for_tensorrt(mod)
+    with relay.build_config(opt_level=3):
+        graph, lib, params = relay.build(mod, "cuda")
+    mod = graph_runtime.create(graph, lib, ctx=tvm.gpu(0))
+    x_data = np.random.uniform(-1, 1, xshape).astype(dtype)
+    y_data = np.random.uniform(-1, 1, yshape).astype(dtype)
+    z_data = np.random.uniform(-1, 1, zshape).astype(dtype)
+    mod.run(x=x_data, y=y_data, z=z_data)
+    results = [mod.get_output(i).asnumpy() for i in range(mod.get_num_outputs())]
+
+def test_tensorrt_not_compatible():
+    if should_skip():
+        return
+    dtype = 'float32'
+    xshape = (1, 32, 14, 14)
+    x = relay.var('x', shape=(xshape), dtype=dtype)
+    y = relay.add(x, x)
+    z = relay.erf(y)
+    out = relay.nn.relu(z)
+    f = relay.Function([x], out)
+    mod = tvm.IRModule()
+    mod['main'] = f
+    mod = tensorrt.partition_for_tensorrt(mod)
+    with relay.build_config(opt_level=3):
+        graph, lib, params = relay.build(mod, "cuda")
+    mod = graph_runtime.create(graph, lib, ctx=tvm.gpu(0))
+    x_data = np.random.uniform(-1, 1, xshape).astype(dtype)
+    mod.run(x=x_data)
+    results = [mod.get_output(i).asnumpy() for i in range(mod.get_num_outputs())]
+
+def test_tensorrt_ops():
+    if should_skip():
+        return
+    def run_and_verify(config):
+        f, input_shapes, is_param = config
+        params = {x: np.random.uniform(-1, 1, input_shapes[x]).astype(np.float32) for x in is_param}
+        input_dict = {k: np.random.uniform(-1, 1, v).astype(np.float32) for k, v in input_shapes.items() if k not in is_param}
+
+        # Run TRT 
+        mod = tvm.IRModule()
+        mod['main'] = f
+        mod = tensorrt.partition_for_tensorrt(mod, params)
+        with relay.build_config(opt_level=3):
+            graph, lib, graph_params = relay.build(mod, "cuda", params=params)
+        mod = graph_runtime.create(graph, lib, ctx=tvm.gpu(0))
+        mod.set_input(**graph_params)
+        mod.run(**input_dict)
+        results = [mod.get_output(i) for i in range(mod.get_num_outputs())]
+
+        # Run reference
+        mod = tvm.IRModule()
+        mod['main'] = f
+        with relay.build_config(opt_level=3):
+            graph, lib, graph_params = relay.build(mod, "cuda", params=params)
+        mod = graph_runtime.create(graph, lib, ctx=tvm.gpu(0))
+        mod.set_input(**graph_params)
+        mod.run(**input_dict)
+        ref_results = [mod.get_output(i) for i in range(mod.get_num_outputs())]
+        
+        assert len(results) == len(ref_results)
+        for i in range(len(results)):
+            res = results[i].asnumpy()
+            ref_res = ref_results[i].asnumpy()
+            assert res.shape == ref_res.shape
+            tvm.testing.assert_allclose(res, ref_res, rtol=1e-3, atol=1e-3)
+
+    def test_conv2d(x_shape=(1, 32, 8, 8), k_shape=(16, 32, 3, 3), groups=1, padding=(0, 0), strides=(1, 1), dilation=(1, 1)):
+        x = relay.var('x', shape=(x_shape), dtype='float32')
+        kernel = relay.var('kernel', shape=(k_shape), dtype='float32')
+        out = relay.nn.conv2d(x, kernel, channels=k_shape[0], kernel_size=k_shape[2:4], groups=groups, padding=padding, strides=strides, dilation=dilation)
+        f = relay.Function([x, kernel], out)
+        return f, {'x': x_shape, 'kernel': k_shape}, ['kernel']
+    
+    def test_conv2d_nhwc(x_shape=(1, 8, 8, 32), k_shape=(3, 3, 32, 16)):
+        x = relay.var('x', shape=(x_shape), dtype='float32')
+        kernel = relay.var('kernel', shape=(k_shape), dtype='float32')
+        out = relay.nn.conv2d(x, kernel, channels=16, kernel_size=(3, 3), data_layout="NHWC", kernel_layout="HWIO")
+        f = relay.Function([x, kernel], out)
+        return f, {'x': x_shape, 'kernel': k_shape}, ['kernel']
+
+    def test_conv2d_const_weights(x_shape=(1, 32, 8, 8), k_shape=(16, 32, 3, 3), groups=1, padding=(0, 0), strides=(1, 1), dilation=(1, 1)):
+        x = relay.var('x', shape=(x_shape), dtype='float32')
+        kernel = relay.const(np.ones(k_shape).astype("float32"))
+        out = relay.nn.conv2d(x, kernel, channels=k_shape[0], kernel_size=k_shape[2:4], groups=groups, padding=padding, strides=strides, dilation=dilation)
+        f = relay.Function([x], out)
+        return f, {'x': x_shape}, []
+    
+    def test_dense(x_shape=(1, 16), k_shape=(32, 16)):
+        x = relay.var('x', shape=(x_shape), dtype='float32')
+        kernel = relay.var('kernel', shape=(k_shape), dtype='float32')
+        # Dense requires constant weights in TensorRT, so the weights are transposed by us.
+        out = relay.nn.dense(x, kernel, units=k_shape[0])
+        f = relay.Function([x, kernel], out)
+        return f, {'x': x_shape, 'kernel': k_shape}, ['kernel']
+    
+    def test_bias_add(x_shape=(1, 16), channels=16):
+        x = relay.var('x', shape=(x_shape), dtype='float32')
+        bias = relay.var('bias', shape=(channels,), dtype='float32')
+        out = relay.nn.bias_add(x, bias)
+        f = relay.Function([x, bias], out)
+        return f, {'x': x_shape, 'bias': (channels,)}, ['bias']
+    
+    def test_pool2d(op, x_shape=(1, 3, 32, 32), pool_size=(2, 2), strides=(2, 2), padding=(0, 0), ceil_mode=False, count_include_pad=None):
+        x = relay.var('x', shape=(x_shape), dtype='float32')
+        if count_include_pad is not None:
+            out = op(x, pool_size=pool_size, strides=strides, padding=padding, ceil_mode=ceil_mode, count_include_pad=count_include_pad)
+        else:
+            out = op(x, pool_size=pool_size, strides=strides, padding=padding, ceil_mode=ceil_mode)
+        f = relay.Function([x], out)
+        return f, {'x': x_shape}, []
+
+    def test_global_pool2d(op, x_shape=(1, 3, 32, 32)):
+        x = relay.var('x', shape=(x_shape), dtype='float32')
+        out = op(x)
+        f = relay.Function([x], out)
+        return f, {'x': x_shape}, []
+
+    def test_batch_flatten(x_shape=(1, 3, 4, 6)):
+        x = relay.var('x', shape=(x_shape), dtype='float32')
+        out = relay.nn.batch_flatten(x)
+        f = relay.Function([x], out)
+        return f, {'x': x_shape}, []
+    
+    def test_expand_dims(x_shape=(1, 3), axis=1, num_newaxis=1):
+        x = relay.var('x', shape=(x_shape), dtype='float32')
+        out = relay.expand_dims(x, axis, num_newaxis)
+        f = relay.Function([x], out)
+        return f, {'x': x_shape}, []
+
+    def test_squeeze(x_shape, axis):
+        x = relay.var('x', shape=(x_shape), dtype='float32')
+        out = relay.squeeze(x, axis=axis)
+        f = relay.Function([x], out)
+        return f, {'x': x_shape}, []
+    
+    def test_concatenate(input_shapes, axis):
+        concat_inputs = []
+        shapes_dict = {}
+        for i in range(len(input_shapes)):
+            name = 'input_{}'.format(i)
+            concat_inputs.append(relay.var(name, shape=(input_shapes[i]), dtype='float32'))
+            shapes_dict[name] = input_shapes[i]
+        out = relay.concatenate(concat_inputs, axis)
+        f = relay.Function(concat_inputs, out)
+        return f, shapes_dict, []
+    
+    def test_conv2d_transpose(x_shape=(1, 32, 8, 8), k_shape=(32, 16, 3, 3), groups=1, padding=(0, 0), strides=(1, 1)):
+        x = relay.var('x', shape=(x_shape), dtype='float32')
+        kernel = relay.var('kernel', shape=(k_shape), dtype='float32')
+        out = relay.nn.conv2d_transpose(x, kernel, channels=k_shape[1], kernel_size=k_shape[2:4], groups=groups, padding=padding, strides=strides)
+        f = relay.Function([x, kernel], out)
+        return f, {'x': x_shape, 'kernel': k_shape}, ['kernel']
+    
+    def test_reshape(x_shape, new_shape):
+        x = relay.var('x', shape=(x_shape), dtype='float32')
+        out = relay.reshape(x, new_shape)
+        f = relay.Function([x], out)
+        return f, {'x': x_shape}, []
+    
+    def test_transpose(x_shape, order):
+        x = relay.var('x', shape=(x_shape), dtype='float32')
+        out = relay.transpose(x, order)
+        f = relay.Function([x], out)
+        return f, {'x': x_shape}, []
+
+    def test_transpose_weights_conv2d(x_shape=(1, 32, 9, 9), k_shape=(3, 3, 32, 16), order=(3, 2, 0, 1)):
+        x = relay.var('x', shape=(x_shape), dtype='float32')
+        kernel = relay.var('kernel', shape=(k_shape), dtype='float32')
+        kernel_t = relay.transpose(kernel, order)
+        # Conv2d requires constant weights in TensorRT, so the weights are transposed by us.
+        out = relay.nn.conv2d(x, kernel_t, channels=k_shape[order[0]], kernel_size=(3, 3))
+        f = relay.Function([x, kernel], out)
+        return f, {'x': x_shape, 'kernel': k_shape}, ['kernel']
+
+    def test_transpose_weights_dense(x_shape=(1, 16), k_shape=(16, 32)):
+        x = relay.var('x', shape=(x_shape), dtype='float32')
+        kernel = relay.var('kernel', shape=(k_shape), dtype='float32')
+        kernel_t = relay.transpose(kernel, (1, 0))
+        # Dense requires constant weights in TensorRT, so the weights are transposed by us.
+        out = relay.nn.dense(x, kernel_t)
+        f = relay.Function([x, kernel], out)
+        return f, {'x': x_shape, 'kernel': k_shape}, ['kernel']
+
+    def test_dense_from_pytorch(x_shape=(1, 16), k_shape=(32, 16)):
+        # FoldConstant will fold away the tranpose -> mult -> transpose.
+        x = relay.var('x', shape=(x_shape), dtype='float32')
+        kernel = relay.var('kernel', shape=(k_shape), dtype='float32')
+        kernel_t = relay.transpose(kernel, (1, 0))
+        beta = relay.const(1, dtype='float32')
+        kernel_t = relay.multiply(kernel_t, beta)
+        kernel_t = relay.transpose(kernel_t, (1, 0))
+        out = relay.nn.dense(x, kernel_t)
+        f = relay.Function([x, kernel], out)
+        return f, {'x': x_shape, 'kernel': k_shape}, ['kernel']
+
+    def test_float_const(x_shape=(1, 16)):
+        x = relay.var('x', shape=(x_shape), dtype='float32')
+        beta = relay.const(1, dtype='float32')
+        out = relay.multiply(x, beta)
+        f = relay.Function([x], out)
+        return f, {'x': x_shape}, []
+
+    def test_pad(x_shape, pad_width):
+        x = relay.var('x', shape=(x_shape), dtype='float32')
+        out = relay.nn.pad(x, pad_width=pad_width)
+        f = relay.Function([x], out)
+        return f, {'x': x_shape}, []
+
+    def test_softmax(x_shape, axis):
+        x = relay.var('x', shape=(x_shape), dtype='float32')
+        out = relay.nn.softmax(x, axis=axis)
+        f = relay.Function([x], out)
+        return f, {'x': x_shape}, []
+
+    def test_batch_norm(x_shape, param_shape, axis=1, epsilon=1e-5):
+        x = relay.var("x", shape=(x_shape), dtype='float32')
+        beta = relay.var("beta", shape=(param_shape), dtype='float32')
+        gamma = relay.var("gamma",  shape=(param_shape), dtype='float32')
+        moving_mean = relay.var("moving_mean", shape=(param_shape), dtype='float32')
+        moving_var = relay.var("moving_var", shape=(param_shape), dtype='float32')
+        out, _, _ = relay.nn.batch_norm(x, gamma=gamma, beta=beta, moving_mean=moving_mean, moving_var=moving_var,
+                                        axis=axis, center=True, scale=True, epsilon=epsilon)
+        f = relay.Function([x, gamma, beta, moving_mean, moving_var], out)
+        return f, {'x': x_shape, 'beta': param_shape, 'gamma': param_shape,
+                   'moving_mean': param_shape, 'moving_var': param_shape}, ['beta', 'gamma', 'moving_mean', 'moving_var']
+
+    def test_unary(op, x_shape=(1, 8, 3, 3)):
+        x = relay.var('x', shape=(x_shape), dtype='float32')
+        out = op(x)
+        f = relay.Function([x], out)
+        return f, {'x': x_shape}, []
+
+    def test_clip(x_shape=(1, 8, 3, 3)):
+        x = relay.var('x', shape=(x_shape), dtype='float32')
+        out = relay.clip(x, a_min=-0.2, a_max=0.4)
+        f = relay.Function([x], out)
+        return f, {'x': x_shape}, []
+
+    def test_leaky_relu(x_shape=(1, 8, 3, 3)):
+        x = relay.var('x', shape=(x_shape), dtype='float32')
+        out = relay.nn.leaky_relu(x, alpha=0.1)
+        f = relay.Function([x], out)
+        return f, {'x': x_shape}, []
+    
+    def test_binary(op, x_shape, y_shape, y_is_const=False):
+        x = relay.var('x', shape=(x_shape), dtype='float32')
+        if y_is_const:
+            y = relay.const(np.ones(y_shape).astype('float32'))
+            out = op(x, y)
+            f = relay.Function([x], out)
+            return f, {'x': x_shape}, []
+        y = relay.var('y', shape=(y_shape), dtype='float32')
+        out = op(x, y)
+        f = relay.Function([x, y], out)
+        return f, {'x': x_shape, 'y': y_shape}, []
+
+    def test_reduce(op, x_shape=(1, 2, 3, 4), axis=(2, 3), keepdims=False):
+        x = relay.var('x', shape=(x_shape), dtype='float32')
+        out = op(x, axis=axis, keepdims=keepdims)
+        f = relay.Function([x], out)
+        return f, {'x': x_shape}, []
+
+    def test_strided_slice(x_shape, begin, end, strides=None):
+        x = relay.var('x', shape=(x_shape), dtype='float32')
+        if strides:
+            out = relay.strided_slice(x, relay.expr.const(begin, dtype="int32"), relay.expr.const(end, dtype="int32"), relay.expr.const(strides, dtype="int32"))
+        else:
+            out = relay.strided_slice(x, relay.expr.const(begin, dtype="int32"), relay.expr.const(end, dtype="int32"))
+        f = relay.Function([x], out)
+        return f, {'x': x_shape}, []
+
+    def test_adaptive_pool2d(op, x_shape=(1, 3, 32, 32), out_size=(1, 1)):
+        x = relay.var('x', shape=(x_shape), dtype='float32')
+        out = op(x, out_size)
+        f = relay.Function([x], out)
+        return f, {'x': x_shape}, []
+
+    def test_resize(x_shape=(1, 3, 16, 16), out_size=(32, 32), layout='NCHW', method='nearest_neighbor', coordinate_transformation_mode='align_corners'):
+        x = relay.var('x', shape=(x_shape), dtype='float32')
+        out = relay.image.resize(x, out_size, layout=layout, method=method, coordinate_transformation_mode=coordinate_transformation_mode)
+        f = relay.Function([x], out)
+        return f, {'x': x_shape}, []
+
+    def test_multiple_outputs():
+        x = relay.var('x', shape=(1, 3), dtype='float32')
+        y = relay.var('y', shape=(1, 3), dtype='float32')
+        z = relay.add(x, y)
+        w = relay.add(z, y)
+        out = relay.Tuple((z, w))
+        f = relay.Function([x, y], out)
+        return f, {'x': (1, 3), 'y': (1, 3)}, []
+
+    def test_conv3d(x_shape=(1, 32, 8, 8, 8), k_shape=(16, 32, 3, 3, 3), groups=1, padding=(0, 0, 0), strides=(1, 1, 1), dilation=(1, 1, 1)):
+        x = relay.var('x', shape=(x_shape), dtype='float32')
+        kernel = relay.var('kernel', shape=(k_shape), dtype='float32')
+        out = relay.nn.conv3d(x, kernel, channels=k_shape[0], kernel_size=k_shape[2:], groups=groups, padding=padding, strides=strides, dilation=dilation)
+        f = relay.Function([x, kernel], out)
+        return f, {'x': x_shape, 'kernel': k_shape}, ['kernel']
+
+    def test_pool3d(op, x_shape=(1, 3, 8, 32, 32), pool_size=(2, 2, 2), strides=(2, 2, 2), padding=(0, 0, 0), ceil_mode=False, count_include_pad=None):
+        x = relay.var('x', shape=(x_shape), dtype='float32')
+        if count_include_pad is not None:
+            out = op(x, pool_size=pool_size, strides=strides, padding=padding, ceil_mode=ceil_mode, count_include_pad=count_include_pad)
+        else:
+            out = op(x, pool_size=pool_size, strides=strides, padding=padding, ceil_mode=ceil_mode)
+        f = relay.Function([x], out)
+        return f, {'x': x_shape}, []
+
+    def test_conv3d_transpose(x_shape=(1, 32, 8, 8, 8), k_shape=(32, 16, 3, 3, 3), groups=1, padding=(0, 0, 0), strides=(1, 1, 1), output_padding=(0, 0, 0)):
+        x = relay.var('x', shape=(x_shape), dtype='float32')
+        kernel = relay.var('kernel', shape=(k_shape), dtype='float32')
+        out = relay.nn.conv3d_transpose(x, kernel, channels=k_shape[1], kernel_size=k_shape[2:5], groups=groups, padding=padding, strides=strides, output_padding=output_padding)
+        f = relay.Function([x, kernel], out)
+        return f, {'x': x_shape, 'kernel': k_shape}, ['kernel']
+
+    run_and_verify(test_float_const())
+    run_and_verify(test_multiple_outputs())
+    run_and_verify(test_clip())
+    run_and_verify(test_leaky_relu())
+    run_and_verify(test_batch_norm((1, 64, 56, 56), (64,)))
+    run_and_verify(test_batch_norm((1, 56, 56, 64), (64,), axis=3, epsilon=1.001e-05))
+    run_and_verify(test_softmax((1, 1000), axis=1))
+    run_and_verify(test_softmax((1, 1000), axis=-1))
+    run_and_verify(test_softmax((1, 3, 4), axis=-2))
+    run_and_verify(test_softmax((1, 3, 4), axis=1))
+    for k_shape, groups in [((16, 32, 3, 3), 1), ((32, 1, 3, 3), 32)]:
+        for padding in [(0, 0), (1, 1)]:
+            for strides in [(1, 1), (2, 2)]:
+                for dilation in [(1, 1), (2, 2)]:
+                    run_and_verify(test_conv2d(k_shape=k_shape, groups=groups, padding=padding,
+                                               strides=strides, dilation=dilation))
+    run_and_verify(test_conv2d_nhwc())
+    run_and_verify(test_conv2d_const_weights())
+    run_and_verify(test_dense())
+    run_and_verify(test_dense_from_pytorch())
+    run_and_verify(test_bias_add())
+    run_and_verify(test_bias_add((1, 6, 3, 4), 6))
+    for op in [relay.add, relay.subtract, relay.multiply, relay.divide, relay.power]:
+        for y_is_const in [True, False]:
+            run_and_verify(test_binary(op, (1, 8, 3, 3), (1, 8, 3, 3), y_is_const))
+            run_and_verify(test_binary(op, (1, 8, 1, 3), (1, 8, 3, 1), y_is_const))
+            run_and_verify(test_binary(op, (1, 10), (10,), y_is_const))
+            run_and_verify(test_binary(op, (1, 1, 1, 10), (10,), y_is_const))
+            run_and_verify(test_binary(op, (1, 1, 1), (3,), y_is_const))
+    for pool_size in [(2, 2), (3, 3)]:
+        for strides in [(1, 1), (2, 2)]:
+            for padding in [(0, 0), (1, 1), (0, 0, 1, 1)]:
+                for ceil_mode in [False, True]:
+                    # Skip "the padding size is larger than or equal to the filter size for exclusive-counting pooling"
+                    if pool_size == (2, 2) and padding == (0, 0, 1, 1):
+                        continue
+                    for count_include_pad in [False, True]:
+                        # Skip "inclusive-counted blended or average pooling is not supported in combination with asymmetric padding"
+                        if count_include_pad and (padding == (0, 0, 1, 1) or strides == (2, 2)):
+                            continue
+                        run_and_verify(test_pool2d(relay.nn.avg_pool2d, pool_size=pool_size, strides=strides, padding=padding, ceil_mode=ceil_mode, count_include_pad=count_include_pad))
+                    run_and_verify(test_pool2d(relay.nn.max_pool2d, pool_size=pool_size, strides=strides, padding=padding, ceil_mode=ceil_mode))
+    for op in [relay.nn.global_max_pool2d, relay.nn.global_max_pool2d]:
+        run_and_verify(test_global_pool2d(op))
+    for op in [relay.nn.relu, relay.sigmoid, relay.tanh, relay.exp, relay.log, relay.sqrt,
+               relay.abs, relay.negative, relay.sin, relay.cos, relay.atan, relay.ceil, relay.floor]:
+        run_and_verify(test_unary(op))
+    run_and_verify(test_batch_flatten())
+    run_and_verify(test_expand_dims())
+    run_and_verify(test_squeeze((1, 5, 1, 1), (2, 3)))
+    run_and_verify(test_squeeze((1, 3, 1), (-1,)))
+    run_and_verify(test_concatenate([(1, 2, 6, 6), (1, 3, 6, 6)], axis=1))
+    for padding in [(0, 0), (1, 1)]:
+        for strides in [(1, 1), (2, 2)]:           
+                run_and_verify(test_conv2d_transpose(padding=padding, strides=strides))
+    run_and_verify(test_transpose((1, 16, 7, 7), [0, 2, 3, 1]))
+    run_and_verify(test_transpose((1, 7, 7, 16), [0, 3, 1, 2]))
+    run_and_verify(test_transpose_weights_conv2d())
+    run_and_verify(test_transpose_weights_conv2d((1, 32, 9, 9), (3, 3, 16, 32), (2, 3, 0, 1)))
+    run_and_verify(test_transpose_weights_dense())
+    run_and_verify(test_reshape((1, 1, 1, 10), (-1, 10)))
+    run_and_verify(test_reshape((1, 10, 2, 3), (1, -1)))
+    run_and_verify(test_reshape((1, 1, 2, 3), (1, 6)))
+    run_and_verify(test_pad((1, 8, 16, 16), [[0, 0], [0, 0], [0, 0], [0, 0]]))
+    run_and_verify(test_pad((1, 8, 16, 16), [[0, 0], [0, 0], [1, 1], [1, 1]]))
+    run_and_verify(test_pad((1, 8, 16, 16), [[0, 0], [0, 0], [0, 1], [2, 0]]))
+    run_and_verify(test_pad((1, 8, 3, 16, 16), [[0, 0], [0, 0], [0, 0], [0, 0], [0, 0]]))
+    for op in [relay.sum, relay.prod, relay.max, relay.min, relay.mean]:
+        for keepdims in [True, False]:
+            run_and_verify(test_reduce(op, axis=(1), keepdims=keepdims))
+            run_and_verify(test_reduce(op, axis=(2, 3), keepdims=keepdims))
+            run_and_verify(test_reduce(op, axis=(1, 2), keepdims=keepdims))
+            run_and_verify(test_reduce(op, axis=(1, 2, 3), keepdims=keepdims))
+    run_and_verify(test_strided_slice((1, 3, 6, 7), [0, 0, 0, 0], [1, 1, 6, 7]))
+    run_and_verify(test_strided_slice((1, 3, 6, 7), [0, 1, 0, 0], [1, 2, 6, 6]))
+    run_and_verify(test_strided_slice((1, 10), [0, 0], [1, 10], [1, 2]))
+    for op in [relay.nn.adaptive_max_pool2d, relay.nn.adaptive_avg_pool2d]:
+        run_and_verify(test_adaptive_pool2d(op))
+    run_and_verify(test_conv3d())
+    run_and_verify(test_conv3d(padding=(0, 0, 0, 1, 1, 1)))
+    run_and_verify(test_pool3d(relay.nn.avg_pool3d))
+    run_and_verify(test_pool3d(relay.nn.max_pool3d))
+    run_and_verify(test_pool3d(relay.nn.max_pool3d, padding=(0, 0, 0, 1, 1, 1)))
+    run_and_verify(test_pool3d(relay.nn.max_pool3d, strides=(1, 1, 1)))
+    run_and_verify(test_conv3d_transpose())
+    run_and_verify(test_conv3d_transpose(strides=(2, 2, 2)))
+    run_and_verify(test_conv3d_transpose(strides=(2, 2, 2), output_padding=(1, 1, 1)))
+
+def test_tensorrt_integration(test_all_models=False):
+    if should_skip():
+        return
+    
+    def test_model(model, i_data, input_shape, dtype, use_trt=True, num_iteration=1):
+        import mxnet as mx
+        from mxnet.gluon.model_zoo.vision import get_model
+        def check_trt_used(graph):
+            import json
+            graph = json.loads(graph)
+            num_trt_subgraphs = sum([1 for n in graph['nodes'] if n.get('attrs', {}).get('func_name', '').startswith('tensorrt_')])
+            assert num_trt_subgraphs >= 1
+
+        block = get_model(model, pretrained=True)
+        mod, params = relay.frontend.from_mxnet(block, shape={'data': input_shape}, dtype=dtype)
+
+        if use_trt:
+            mod = tensorrt.partition_for_tensorrt(mod, params)
+        with relay.build_config(opt_level=3):
+            graph, lib, params = relay.build(mod, "cuda", params=params)
+        if use_trt:
+            check_trt_used(graph)
+
+        mod = graph_runtime.create(graph, lib, ctx=tvm.gpu(0))
+        mod.set_input(**params)
+        # Warmup
+        for i in range(10):
+            mod.run(data=i_data)
+
+        # Time
+        times = []
+        for i in range(num_iteration):
+            start_time = time.time()
+            mod.run(data=i_data)
+            res = mod.get_output(0)
+            times.append(time.time() - start_time)
+        latency = 1000.0 * np.mean(times)
+        print(model, latency)
+        return latency, res
+
+    latency = {}
+    models = [
+        'alexnet',
+        'resnet18_v1',
+        'resnet18_v2',
+        'squeezenet1.0',
+        'mobilenet0.25',
+        'mobilenetv2_0.25',
+        'vgg11',
+        'densenet121',
+    ]
+    additional_models = [
+        'resnet34_v1',
+        'resnet50_v1',
+        'resnet101_v1',
+        'resnet152_v1',
+        'resnet34_v2',
+        'resnet50_v2',
+        'resnet101_v2',
+        'resnet152_v2',
+        'mobilenet0.5',
+        'mobilenet0.75',
+        'mobilenet1.0',
+        'mobilenetv2_0.5',
+        'mobilenetv2_0.75',
+        'mobilenetv2_1.0',
+        'vgg16',
+        'densenet169',
+        'densenet201']
+    if test_all_models:
+        models.extend(additional_models)
+    
+    dtype = 'float32'
+    input_shape = (1, 3, 224, 224)
+    i_data = np.random.uniform(-1, 1, input_shape).astype(dtype)
+    for model in models:
+        latency[model], res = test_model(model, i_data, input_shape, dtype, use_trt=True)
+        _, ref_res = test_model(model, i_data, input_shape, dtype, use_trt=False, num_iteration=1)
+        tvm.testing.assert_allclose(res.asnumpy(), ref_res.asnumpy(), rtol=1e-3, atol=1e-3)

Review comment:
       It would be useful to know which model the test failed on, e.g. by attaching the model name to the assertion error. A rough sketch of what I mean is below.
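       Untested sketch, only to illustrate; it just wraps the existing `tvm.testing.assert_allclose` call in a try/except and re-raises with the model name attached:
       ```python
       for model in models:
           latency[model], res = test_model(model, i_data, input_shape, dtype, use_trt=True)
           _, ref_res = test_model(model, i_data, input_shape, dtype, use_trt=False, num_iteration=1)
           try:
               tvm.testing.assert_allclose(res.asnumpy(), ref_res.asnumpy(), rtol=1e-3, atol=1e-3)
           except AssertionError as err:
               # Re-raise with the failing model so CI logs show it directly.
               raise AssertionError("Results mismatch for model '%s': %s" % (model, err))
       ```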

##########
File path: src/runtime/contrib/tensorrt/tensorrt_runtime.cc
##########
@@ -0,0 +1,310 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/runtime/contrib/arm_compute_lib/acl_runtime.cc
+ * \brief A simple JSON runtime for Arm Compute Library.

Review comment:
       Should be TensorRT :)

##########
File path: tests/python/contrib/test_tensorrt.py
##########
@@ -0,0 +1,573 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import numpy as np
+import time
+import pytest
+
+import tvm
+import tvm.relay.testing
+from tvm import relay
+from tvm.relay.op.contrib import tensorrt
+from tvm.contrib import graph_runtime
+
+def should_skip():
+    if not tvm.runtime.enabled("cuda") or not tvm.gpu(0).exist:
+        print("skip because cuda is not enabled.")
+        return True
+    if not tensorrt.is_tensorrt_runtime_enabled():
+        print("skip because tensorrt runtime is not available")
+        return True
+    return False
+
+def test_tensorrt_simple():
+    if should_skip():
+        return
+    dtype = 'float32'
+    xshape = (1, 3, 2, 2)
+    yshape = (1, 3, 1, 1)
+    zshape = (1, 1, 1, 1)
+    x = relay.var('x', shape=(xshape), dtype=dtype)
+    y = relay.var('y', shape=(yshape), dtype=dtype)
+    z = relay.var('z', shape=(zshape), dtype=dtype)
+    w = z * (x + y)
+    out = relay.nn.relu(w)
+    f = relay.Function([x, y, z], out)
+
+    mod = tvm.IRModule()
+    mod['main'] = f
+    mod = tensorrt.partition_for_tensorrt(mod)
+    with relay.build_config(opt_level=3):
+        graph, lib, params = relay.build(mod, "cuda")
+    mod = graph_runtime.create(graph, lib, ctx=tvm.gpu(0))
+    x_data = np.random.uniform(-1, 1, xshape).astype(dtype)
+    y_data = np.random.uniform(-1, 1, yshape).astype(dtype)
+    z_data = np.random.uniform(-1, 1, zshape).astype(dtype)
+    mod.run(x=x_data, y=y_data, z=z_data)
+    results = [mod.get_output(i).asnumpy() for i in range(mod.get_num_outputs())]
+
+def test_tensorrt_not_compatible():
+    if should_skip():
+        return
+    dtype = 'float32'
+    xshape = (1, 32, 14, 14)
+    x = relay.var('x', shape=(xshape), dtype=dtype)
+    y = relay.add(x, x)
+    z = relay.erf(y)
+    out = relay.nn.relu(z)
+    f = relay.Function([x], out)
+    mod = tvm.IRModule()
+    mod['main'] = f
+    mod = tensorrt.partition_for_tensorrt(mod)
+    with relay.build_config(opt_level=3):
+        graph, lib, params = relay.build(mod, "cuda")
+    mod = graph_runtime.create(graph, lib, ctx=tvm.gpu(0))
+    x_data = np.random.uniform(-1, 1, xshape).astype(dtype)
+    mod.run(x=x_data)
+    results = [mod.get_output(i).asnumpy() for i in range(mod.get_num_outputs())]
+
+def test_tensorrt_ops():
+    if should_skip():
+        return
+    def run_and_verify(config):
+        f, input_shapes, is_param = config
+        params = {x: np.random.uniform(-1, 1, input_shapes[x]).astype(np.float32) for x in is_param}
+        input_dict = {k: np.random.uniform(-1, 1, v).astype(np.float32) for k, v in input_shapes.items() if k not in is_param}
+
+        # Run TRT 
+        mod = tvm.IRModule()
+        mod['main'] = f
+        mod = tensorrt.partition_for_tensorrt(mod, params)
+        with relay.build_config(opt_level=3):
+            graph, lib, graph_params = relay.build(mod, "cuda", params=params)
+        mod = graph_runtime.create(graph, lib, ctx=tvm.gpu(0))
+        mod.set_input(**graph_params)
+        mod.run(**input_dict)
+        results = [mod.get_output(i) for i in range(mod.get_num_outputs())]
+
+        # Run reference
+        mod = tvm.IRModule()
+        mod['main'] = f
+        with relay.build_config(opt_level=3):
+            graph, lib, graph_params = relay.build(mod, "cuda", params=params)

Review comment:
       If a module cannot be built, it would be useful to know which config it failed on. A rough sketch of one way to do that is below.
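       Untested sketch of what I mean; `build_or_report` and `config_desc` are just illustrative names, and it assumes the existing imports in test_tensorrt.py:
       ```python
       def build_or_report(mod, params, config_desc):
           """Build for the CUDA target, re-raising with a description of the config on failure."""
           try:
               with relay.build_config(opt_level=3):
                   return relay.build(mod, "cuda", params=params)
           except Exception as err:
               raise RuntimeError("relay.build failed for config %s: %s" % (config_desc, err))
       ```
       run_and_verify could then call it with e.g. `str(input_shapes)` as the description for both the TRT and reference builds.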

##########
File path: src/runtime/contrib/tensorrt/tensorrt_builder.cc
##########
@@ -0,0 +1,224 @@
+/* * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file runtime/contrib/tensorrt/tensorrt_builder.cc
+ * \brief The TensorRTBuilder class can be used to convert a JSONRuntime graph into a TRT engine
+ * which can be used for inference.
+ */
+
+#include "tensorrt_builder.h"
+
+#include <tvm/runtime/ndarray.h>
+
+#include <memory>
+#include <string>
+
+#include "tensorrt_logger.h"
+#include "tensorrt_ops.h"
+#include "tensorrt_utils.h"
+
+namespace tvm {
+namespace runtime {
+namespace contrib {
+
+TensorRTBuilder::TensorRTBuilder(TensorRTLogger* logger, size_t max_workspace_size,
+                                 bool use_implicit_batch, bool use_fp16, int batch_size)
+    : max_workspace_size_(max_workspace_size),
+      use_implicit_batch_(use_implicit_batch),
+      use_fp16_(use_fp16),
+      batch_size_(batch_size) {
+  // Create TRT builder and network.
+  builder_ = nvinfer1::createInferBuilder(*logger);
+#if TRT_VERSION_GE(6, 0, 1)
+  // Use INetworkV2.
+  auto flags =
+      1U << static_cast<uint32_t>(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
+  if (use_implicit_batch_) {
+    flags = 0U;
+    builder_->setMaxBatchSize(batch_size_);
+  }
+  network_ = builder_->createNetworkV2(flags);
+#else
+  // Use INetwork with implicit batch.
+  builder_->setMaxBatchSize(batch_size_);
+  builder_->setMaxWorkspaceSize(max_workspace_size_);
+  builder_->setFp16Mode(use_fp16_);
+  network_ = builder_->createNetwork();
+#endif
+}
+
+void TensorRTBuilder::AddInput(int nid, const JSONGraphNode& node) {
+  auto node_name = node.GetOpName();
+  auto shapes = node.GetOpShape();
+  auto dtypes = node.GetOpDataType();
+  CHECK_EQ(shapes.size(), dtypes.size());
+  node_output_map_[nid] = {};
+  for (size_t i = 0; i < shapes.size(); ++i) {
+    const std::string name = node_name + "_" + std::to_string(i);
+    auto shape = shapes[i];
+    // Remove batch dim when not in explicit batch mode.
+    if (use_implicit_batch_ && shape.size() > 1) {
+      shape.erase(shape.begin());
+    }
+    DLOG(INFO) << "TRT input: " << name << " " << DebugString(shape);
+    nvinfer1::Dims dims = VectorToTrtDims(shape);
+    CHECK(TypeMatch(dtypes[i], kDLFloat, 32)) << "Only FP32 inputs are supported.";
+    auto input_tensor = network_->addInput(name.c_str(), nvinfer1::DataType::kFLOAT, dims);
+    node_output_map_[nid].push_back(TrtOpInput(input_tensor));
+    network_input_names_.push_back(input_tensor->getName());
+  }
+}
+
+void TensorRTBuilder::AddConstant(int nid, const DLTensor* data) {
+  nvinfer1::Weights weight = GetDLTensorAsWeights(data, kDLCPU);
+  std::vector<int> shape(data->shape, data->shape + data->ndim);
+  // Remove batch dim when not in explicit batch mode.
+  if (use_implicit_batch_ && shape.size() > 1 && shape[0] == 1) {
+    shape.erase(shape.begin());
+  }
+  node_output_map_[nid] = {TrtOpInput(weight, shape)};
+}
+
+void TensorRTBuilder::AddOutput(const JSONGraphNodeEntry& node) {
+  auto it = node_output_map_.find(node.id_);
+  CHECK(it != node_output_map_.end()) << "Output was not found.";
+  auto out_tensor = it->second[node.index_].tensor;
+  std::string name = "tensorrt_output_" + std::to_string(network_output_names_.size());
+  out_tensor->setName(name.c_str());
+  network_->markOutput(*out_tensor);
+  network_output_names_.push_back(out_tensor->getName());
+  DLOG(INFO) << "TRT output: " << name << DebugString(TrtDimsToVector(out_tensor->getDimensions()));

Review comment:
       Same as above




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org