Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2019/01/16 06:14:50 UTC

[incubator-mxnet] Diff for: [GitHub] KellenSunderland merged pull request #13310: [MXNET-703] Update to TensorRT 5, ONNX IR 3. Fix inference bugs.

diff --git a/3rdparty/onnx-tensorrt b/3rdparty/onnx-tensorrt
index 3d8ee049970..f1c7aa63d88 160000
--- a/3rdparty/onnx-tensorrt
+++ b/3rdparty/onnx-tensorrt
@@ -1 +1 @@
-Subproject commit 3d8ee049970e81ff4935cc7f36b653c0b27bcbbc
+Subproject commit f1c7aa63d88d8d8ef70490f2ebb6b33f7450218b
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 3e3de205347..23609e5ec24 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -187,6 +187,7 @@ if(USE_TENSORRT)
   include_directories(${ONNX_PATH})
   include_directories(3rdparty/onnx-tensorrt/)
   include_directories(3rdparty/)
+  include_directories(3rdparty/onnx-tensorrt/third_party/onnx/)
   add_definitions(-DMXNET_USE_TENSORRT=1)
   add_definitions(-DONNX_NAMESPACE=onnx)
 
diff --git a/ci/docker/Dockerfile.build.ubuntu_gpu_tensorrt b/ci/docker/Dockerfile.build.ubuntu_gpu_tensorrt
index 255da316041..f4844115c0f 100644
--- a/ci/docker/Dockerfile.build.ubuntu_gpu_tensorrt
+++ b/ci/docker/Dockerfile.build.ubuntu_gpu_tensorrt
@@ -18,7 +18,7 @@
 #
 # Dockerfile to run MXNet on Ubuntu 16.04 for CPU
 
-FROM nvidia/cuda:9.0-cudnn7-devel
+FROM nvidia/cuda:10.0-cudnn7-devel
 
 WORKDIR /work/deps
 
diff --git a/ci/docker/install/tensorrt.sh b/ci/docker/install/tensorrt.sh
index 61e73ef9a62..1950cad0b52 100755
--- a/ci/docker/install/tensorrt.sh
+++ b/ci/docker/install/tensorrt.sh
@@ -26,7 +26,7 @@ pip3 install gluoncv==0.2.0
 pushd .
 cd ..
 apt-get update
-apt-get install -y automake libtool
+apt-get install -y automake libtool zip
 git clone --recursive -b 3.5.1.1 https://github.com/google/protobuf.git
 cd protobuf
 ./autogen.sh
@@ -41,7 +41,7 @@ popd
 
 # Install TensorRT
 echo "TensorRT build enabled. Installing TensorRT."
-wget -qO tensorrt.deb https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1604/x86_64/nvinfer-runtime-trt-repo-ubuntu1604-4.0.1-ga-cuda9.0_1-1_amd64.deb
+wget -qO tensorrt.deb https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1604/x86_64/nvinfer-runtime-trt-repo-ubuntu1604-5.0.2-ga-cuda10.0_1-1_amd64.deb
 dpkg -i tensorrt.deb
 apt-get update
 apt-get install -y --allow-downgrades libnvinfer-dev
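
The repo switch above moves the container to TensorRT 5.0.2 built against CUDA 10.0. As a rough sanity check (an assumption for illustration, not part of this change), the installed libnvinfer version can be read back through the getInferLibVersion symbol it exports, provided libnvinfer.so is on the loader path:

    # Hedged sketch: read the TensorRT runtime version from libnvinfer.
    # Assumes libnvinfer.so is resolvable and exports getInferLibVersion().
    import ctypes

    nvinfer = ctypes.CDLL("libnvinfer.so")
    nvinfer.getInferLibVersion.restype = ctypes.c_int
    v = nvinfer.getInferLibVersion()            # encoded as major*1000 + minor*100 + patch
    print(v // 1000, (v // 100) % 10, v % 100)  # e.g. 5 0 2 for TensorRT 5.0.2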
diff --git a/ci/docker/install/ubuntu_core.sh b/ci/docker/install/ubuntu_core.sh
index 4382aa6aefd..fc903e5c889 100755
--- a/ci/docker/install/ubuntu_core.sh
+++ b/ci/docker/install/ubuntu_core.sh
@@ -22,6 +22,10 @@
 
 set -ex
 apt-get update || true
+
+# Avoid interactive package installers such as tzdata.
+export DEBIAN_FRONTEND=noninteractive
+
 apt-get install -y \
     apt-transport-https \
     build-essential \
@@ -41,10 +45,11 @@ apt-get install -y \
     unzip \
     wget
 
-
-# Ubuntu 14.04
-if [[ $(lsb_release -r | grep 14.04) ]]; then
-    apt-get install -y cmake3
-else
-    apt-get install -y cmake
-fi
+# Note: we specify an exact cmake version to work around a cmake 3.10 CUDA 10 issue.
+# Reference: https://github.com/clab/dynet/issues/1457
+mkdir /opt/cmake && cd /opt/cmake
+wget -nv https://cmake.org/files/v3.12/cmake-3.12.4-Linux-x86_64.sh
+sh cmake-3.12.4-Linux-x86_64.sh --prefix=/opt/cmake --skip-license
+ln -s /opt/cmake/bin/cmake /usr/local/bin/cmake
+rm cmake-3.12.4-Linux-x86_64.sh
+cmake --version
diff --git a/ci/docker/runtime_functions.sh b/ci/docker/runtime_functions.sh
index a6bb1064a58..fcad7ffd97c 100755
--- a/ci/docker/runtime_functions.sh
+++ b/ci/docker/runtime_functions.sh
@@ -602,23 +602,23 @@ build_ubuntu_gpu_tensorrt() {
     cp -L 3rdparty/onnx-tensorrt/build/libnvonnxparser_runtime.so.0 /work/mxnet/lib/
     cp -L 3rdparty/onnx-tensorrt/build/libnvonnxparser.so.0 /work/mxnet/lib/
 
-    rm -rf build
-    make \
-        DEV=1                                                \
-        ENABLE_TESTCOVERAGE=1                                \
-        USE_BLAS=openblas                                    \
-        USE_CUDA=1                                           \
-        USE_CUDA_PATH=/usr/local/cuda                        \
-        USE_CUDNN=1                                          \
-        USE_OPENCV=0                                         \
-        USE_MKLDNN=0                                         \
-        USE_DIST_KVSTORE=0                                   \
-        USE_TENSORRT=1                                       \
-        USE_JEMALLOC=0                                       \
-        USE_GPERFTOOLS=0                                     \
-        ONNX_NAMESPACE=onnx                                  \
-        CUDA_ARCH="-gencode arch=compute_70,code=compute_70" \
-        -j$(nproc)
+    cd /work/build
+    cmake -DUSE_CUDA=1                            \
+          -DCMAKE_CXX_COMPILER_LAUNCHER=ccache    \
+          -DCMAKE_C_COMPILER_LAUNCHER=ccache      \
+          -DUSE_CUDNN=1                           \
+          -DUSE_OPENCV=1                          \
+          -DUSE_TENSORRT=1                        \
+          -DUSE_OPENMP=0                          \
+          -DUSE_MKLDNN=0                          \
+          -DUSE_MKL_IF_AVAILABLE=OFF              \
+          -DENABLE_TESTCOVERAGE=ON                \
+          -DCUDA_ARCH_NAME=Manual                 \
+          -DCUDA_ARCH_BIN=$CI_CMAKE_CUDA_ARCH_BIN \
+          -G Ninja                                \
+          /work/mxnet
+
+    ninja -v
 }
 
 build_ubuntu_gpu_mkldnn() {
diff --git a/ci/jenkins/Jenkins_steps.groovy b/ci/jenkins/Jenkins_steps.groovy
index 33d76aa1668..b0bd4873695 100644
--- a/ci/jenkins/Jenkins_steps.groovy
+++ b/ci/jenkins/Jenkins_steps.groovy
@@ -34,7 +34,7 @@ mx_cmake_lib = 'build/libmxnet.so, build/libmxnet.a, build/3rdparty/dmlc-core/li
 mx_cmake_lib_debug = 'build/libmxnet.so, build/libmxnet.a, build/3rdparty/dmlc-core/libdmlc.a, build/tests/mxnet_unit_tests'
 mx_cmake_mkldnn_lib = 'build/libmxnet.so, build/libmxnet.a, build/3rdparty/dmlc-core/libdmlc.a, build/tests/mxnet_unit_tests, build/3rdparty/openmp/runtime/src/libomp.so, build/3rdparty/mkldnn/src/libmkldnn.so.0'
 mx_mkldnn_lib = 'lib/libmxnet.so, lib/libmxnet.a, lib/libiomp5.so, lib/libmkldnn.so.0, lib/libmklml_intel.so, 3rdparty/dmlc-core/libdmlc.a, 3rdparty/tvm/nnvm/lib/libnnvm.a'
-mx_tensorrt_lib = 'lib/libmxnet.so, lib/libnvonnxparser_runtime.so.0, lib/libnvonnxparser.so.0, lib/libonnx_proto.so, lib/libonnx.so'
+mx_tensorrt_lib = 'build/libmxnet.so, lib/libnvonnxparser_runtime.so.0, lib/libnvonnxparser.so.0, lib/libonnx_proto.so, lib/libonnx.so'
 mx_lib_cpp_examples = 'lib/libmxnet.so, lib/libmxnet.a, 3rdparty/dmlc-core/libdmlc.a, 3rdparty/tvm/nnvm/lib/libnnvm.a, 3rdparty/ps-lite/build/libps.a, deps/lib/libprotobuf-lite.a, deps/lib/libzmq.a, build/cpp-package/example/*'
 mx_lib_cpp_examples_cpu = 'build/libmxnet.so, build/cpp-package/example/*'
 
diff --git a/src/executor/onnx_to_tensorrt.cc b/src/executor/onnx_to_tensorrt.cc
index e3a4ae868ce..c37b856f9d6 100644
--- a/src/executor/onnx_to_tensorrt.cc
+++ b/src/executor/onnx_to_tensorrt.cc
@@ -28,7 +28,7 @@
 
 #include "./onnx_to_tensorrt.h"
 
-#include <onnx/onnx.pb.h>
+#include <onnx/onnx_pb.h>
 
 #include <NvInfer.h>
 #include <google/protobuf/io/coded_stream.h>
diff --git a/src/executor/tensorrt_pass.cc b/src/executor/tensorrt_pass.cc
index d26704c35cf..762dc0de9db 100644
--- a/src/executor/tensorrt_pass.cc
+++ b/src/executor/tensorrt_pass.cc
@@ -31,7 +31,7 @@
 #include <mxnet/op_attr_types.h>
 #include <mxnet/operator.h>
 #include <nnvm/graph_attr_types.h>
-#include <onnx/onnx.pb.h>
+#include <onnx/onnx_pb.h>
 
 #include "../operator/contrib/nnvm_to_onnx-inl.h"
 #include "./exec_pass.h"
diff --git a/src/executor/trt_graph_executor.cc b/src/executor/trt_graph_executor.cc
index ec35fee98a9..92bdcab9039 100644
--- a/src/executor/trt_graph_executor.cc
+++ b/src/executor/trt_graph_executor.cc
@@ -21,7 +21,7 @@
 
 #include "trt_graph_executor.h"
 
-#include <onnx/onnx.pb.h>
+#include <onnx/onnx_pb.h>
 #include <NvInfer.h>
 #include "./onnx_to_tensorrt.h"
 #include "../operator/contrib/tensorrt-inl.h"
diff --git a/src/operator/contrib/nnvm_to_onnx-inl.h b/src/operator/contrib/nnvm_to_onnx-inl.h
index 011ffe6b7dd..e0c4d9369e6 100644
--- a/src/operator/contrib/nnvm_to_onnx-inl.h
+++ b/src/operator/contrib/nnvm_to_onnx-inl.h
@@ -37,7 +37,7 @@
 #include <nnvm/graph.h>
 #include <nnvm/pass_functions.h>
 
-#include <onnx/onnx.pb.h>
+#include <onnx/onnx_pb.h>
 
 #include <algorithm>
 #include <iostream>
diff --git a/src/operator/contrib/nnvm_to_onnx.cc b/src/operator/contrib/nnvm_to_onnx.cc
index 784384e94e1..ccb6e04b0a4 100644
--- a/src/operator/contrib/nnvm_to_onnx.cc
+++ b/src/operator/contrib/nnvm_to_onnx.cc
@@ -62,15 +62,22 @@ namespace nnvm_to_onnx {
 op::ONNXParam ConvertNnvmGraphToOnnx(
     const nnvm::Graph& g,
     std::unordered_map<std::string, NDArray>* const shared_buffer) {
-    op::ONNXParam onnx_param;
-    op::nnvm_to_onnx::NameToIdx_t onnx_input_map;
-    op::nnvm_to_onnx::InferenceMap_t onnx_output_map;
+
+  static std::atomic_ulong subgraph_count = { 0 };
+
+  op::ONNXParam onnx_param;
+  op::nnvm_to_onnx::NameToIdx_t onnx_input_map;
+  op::nnvm_to_onnx::InferenceMap_t onnx_output_map;
 
   const nnvm::IndexedGraph& ig = g.indexed_graph();
   const auto& storage_types = g.GetAttr<StorageTypeVector>("storage_type");
   const auto& dtypes = g.GetAttr<DTypeVector>("dtype");
   const auto& shape_inputs = g.GetAttr<ShapeVector>("shape_inputs");
 
+  // TODO(kellens): At the moment this check always passes no matter the weight dtypes used in your
+  // graph.  We should first iterate over datatypes by name and ensure they're valid types
+  // (fp16 or fp32) and that they're uniform.  Then ensure later conversions set tensor types
+  // correctly in ONNX.
   for (auto& e : storage_types) {
     if (e != mshadow::kFloat32) {
       LOG(FATAL) << "ONNX converter does not support types other than float32 "
@@ -79,9 +86,23 @@ op::ONNXParam ConvertNnvmGraphToOnnx(
   }
 
   ModelProto model_proto;
-  // Need to determine IR versions and features to support
-  model_proto.set_ir_version(static_cast<int64>(2));
+
+  // We're currently serializing our models in ONNX 3, opset 8 as it is best supported by the
+  // currently linked version of the onnx-tensorrt library.
+  // More information on ONNX versions and opsets can be found at:
+  // https://github.com/onnx/onnx/blob/master/docs/IR.md
+
+  auto opset_proto = model_proto.add_opset_import();
+  const int64 onnx_opset = 8;
+  const int64 onnx_major_version = 3;
+
+  // Declare our ONNX versions in our protobuf model.
+  opset_proto->set_version(onnx_opset);
+  model_proto.set_ir_version(onnx_major_version);
+
   GraphProto* graph_proto = model_proto.mutable_graph();
+  auto subgraph_name_id = subgraph_count.fetch_add(1);
+  graph_proto->set_name("MXNetTRTSubgraph" + std::to_string(subgraph_name_id));
 
   std::unordered_map<std::string, TShape> placeholder_shapes =
       GetPlaceholderShapes(shape_inputs, ig);
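
For reference, the IR/opset declaration made above through the protobuf C++ API corresponds to the following sketch with the onnx Python helpers (the empty graph and its name are placeholders, not converter output):

    # Declare ONNX IR version 3 and opset 8 on a model, mirroring ConvertNnvmGraphToOnnx.
    from onnx import helper

    graph = helper.make_graph(nodes=[], name="MXNetTRTSubgraph0", inputs=[], outputs=[])
    model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 8)])
    model.ir_version = 3
    print(model.ir_version, model.opset_import[0].version)  # -> 3 8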
@@ -176,6 +197,20 @@ void ConvertConvolution(NodeProto* node_proto, const NodeAttrs& attrs,
   // const bool no_bias = conv_param.no_bias;
   const dmlc::optional<int> layout = conv_param.layout;
 
+  // dilations
+  AttributeProto* const dilations = node_proto->add_attribute();
+  dilations->set_name("dilations");
+  dilations->set_type(AttributeProto::INTS);
+  for (const dim_t kval : dilate) {
+    dilations->add_ints(static_cast<int64>(kval));
+  }
+
+  // group
+  AttributeProto* const group = node_proto->add_attribute();
+  group->set_name("group");
+  group->set_type(AttributeProto::INT);
+  group->set_i(static_cast<int64>(num_group));
+
   // kernel shape
   AttributeProto* const kernel_shape = node_proto->add_attribute();
   kernel_shape->set_name("kernel_shape");
@@ -195,14 +230,6 @@ void ConvertConvolution(NodeProto* node_proto, const NodeAttrs& attrs,
     pads->add_ints(static_cast<int64>(kval));
   }
 
-  // dilations
-  AttributeProto* const dilations = node_proto->add_attribute();
-  dilations->set_name("dilations");
-  dilations->set_type(AttributeProto::INTS);
-  for (const dim_t kval : dilate) {
-    dilations->add_ints(static_cast<int64>(kval));
-  }
-
   // strides
   AttributeProto* const strides = node_proto->add_attribute();
   strides->set_name("strides");
@@ -210,12 +237,6 @@ void ConvertConvolution(NodeProto* node_proto, const NodeAttrs& attrs,
   for (const dim_t kval : stride) {
     strides->add_ints(static_cast<int64>(kval));
   }
-
-  // group
-  AttributeProto* const group = node_proto->add_attribute();
-  group->set_name("group");
-  group->set_type(AttributeProto::INT);
-  group->set_i(static_cast<int64>(num_group));
 }  // end ConvertConvolution
 
 void ConvertPooling(NodeProto* node_proto, const NodeAttrs& attrs,
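
The convolution hunks reorder attribute emission so that dilations and group are written before kernel_shape, pads and strides; the attribute set on the resulting ONNX Conv node is unchanged. A hedged sketch of that node with the onnx Python helpers (tensor names and values are illustrative only):

    # Full attribute set emitted for Conv; "data"/"weight"/"conv0" are made-up names.
    from onnx import helper

    conv = helper.make_node(
        "Conv", inputs=["data", "weight"], outputs=["conv0"],
        kernel_shape=[3, 3],
        pads=[1, 1, 1, 1],
        strides=[1, 1],
        dilations=[1, 1],
        group=1,
    )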
@@ -250,8 +271,12 @@ void ConvertPooling(NodeProto* node_proto, const NodeAttrs& attrs,
   AttributeProto* const pads = node_proto->add_attribute();
   pads->set_name("pads");
   pads->set_type(AttributeProto::INTS);
-  for (int kval : pad) {
-    pads->add_ints(static_cast<int64>(kval));
+
+  // Convert MXNet's symmetric pads to ONNX's explicit begin/end pads by writing the pad values twice.
+  for (int i = 0; i < 2; i++) {
+    for (dim_t kval : pad) {
+      pads->add_ints(static_cast<int64>(kval));
+    }
   }
 
   // strides
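
The padding fix above duplicates MXNet's per-dimension pads so ONNX receives explicit begin and end pads for every spatial axis. The same conversion as a minimal Python sketch:

    # MXNet keeps one pad value per spatial dim; ONNX wants [x1_begin, x2_begin, ..., x1_end, x2_end].
    mxnet_pad = (1, 1)               # (pad_h, pad_w) from the pooling parameters
    onnx_pads = list(mxnet_pad) * 2  # -> [1, 1, 1, 1]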
@@ -315,11 +340,6 @@ void ConvertFullyConnected(NodeProto* node_proto, const NodeAttrs& attrs,
       beta->set_type(AttributeProto::FLOAT);
       beta->set_f(1.0f);
 
-      AttributeProto* const broadcast = node_proto->add_attribute();
-      broadcast->set_name("broadcast");
-      broadcast->set_type(AttributeProto::INT);
-      broadcast->set_i(1);
-
       AttributeProto* const transA = node_proto->add_attribute();
       transA->set_name("transA");
       transA->set_type(AttributeProto::INT);
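
The broadcast attribute is removed here because Gemm no longer accepts it at the opset this converter now targets (ONNX dropped it from Gemm in opset 7 in favour of unidirectional broadcasting of the bias). A hedged sketch of the remaining Gemm attributes, with placeholder tensor names:

    # Gemm under opset 8: alpha/beta/transA/transB remain, broadcast is gone.
    # transB=1 reflects MXNet's (out_features, in_features) weight layout.
    from onnx import helper

    gemm = helper.make_node(
        "Gemm", inputs=["flat_in", "weight", "bias"], outputs=["fc_out"],
        alpha=1.0, beta=1.0, transA=0, transB=1,
    )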
@@ -371,11 +391,6 @@ void ConvertBatchNorm(NodeProto* node_proto, const NodeAttrs& attrs,
   epsilon->set_type(AttributeProto::FLOAT);
   epsilon->set_f(static_cast<float>(param.eps));
 
-  AttributeProto* const is_test = node_proto->add_attribute();
-  is_test->set_name("is_test");
-  is_test->set_type(AttributeProto::INT);
-  is_test->set_i(1);
-
   AttributeProto* const momentum = node_proto->add_attribute();
   momentum->set_name("momentum");
   momentum->set_type(AttributeProto::FLOAT);
@@ -384,31 +399,16 @@ void ConvertBatchNorm(NodeProto* node_proto, const NodeAttrs& attrs,
   AttributeProto* const spatial = node_proto->add_attribute();
   spatial->set_name("spatial");
   spatial->set_type(AttributeProto::INT);
-  spatial->set_i(1);
-
-  AttributeProto* const consumed = node_proto->add_attribute();
-  consumed->set_name("consumed_inputs");
-  consumed->set_type(AttributeProto::INTS);
-
-  for (int i = 0; i < 5; i++) {
-    int val = (i < 3) ? 0 : 1;
-    consumed->add_ints(static_cast<int64>(val));
-  }
+  // MXNet computes mean and variance per feature for batchnorm.  Enabling spatial mode
+  // (the default in ONNX 3) would run batchnorm across all spatial features, so we explicitly
+  // disable it for MXNet's BatchNorm.
+  spatial->set_i(0);
 }
 
 void ConvertElementwiseAdd(NodeProto* node_proto, const NodeAttrs& /*attrs*/,
                            const nnvm::IndexedGraph& /*ig*/,
                            const array_view<IndexedGraph::NodeEntry>& /*inputs*/) {
   node_proto->set_op_type("Add");
-  AttributeProto* const axis = node_proto->add_attribute();
-  axis->set_name("axis");
-  axis->set_type(AttributeProto::INT);
-  axis->set_i(1);
-
-  AttributeProto* const broadcast = node_proto->add_attribute();
-  broadcast->set_name("broadcast");
-  broadcast->set_type(AttributeProto::INT);
-  broadcast->set_i(0);  // 1
 }
 
 std::unordered_map<std::string, TShape> GetPlaceholderShapes(
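
Two notes on this hunk: BatchNormalization keeps only epsilon, momentum and spatial at opset 8 (is_test and consumed_inputs no longer exist), with spatial=0 so statistics stay per-feature as MXNet computes them; and Add loses axis/broadcast because opset-7+ Add broadcasts numpy-style. A hedged sketch of the nodes the converter now emits (input/output names are illustrative):

    from onnx import helper

    bn = helper.make_node(
        "BatchNormalization",
        inputs=["x", "gamma", "beta", "moving_mean", "moving_var"],
        outputs=["bn_out"],
        epsilon=1e-5, momentum=0.9, spatial=0,  # spatial=0: per-feature mean/variance
    )
    add = helper.make_node("Add", inputs=["bn_out", "residual"], outputs=["sum_out"])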
@@ -461,32 +461,40 @@ void ConvertPlaceholder(
 void ConvertConstant(
     GraphProto* const graph_proto, const std::string& node_name,
     std::unordered_map<std::string, NDArray>* const shared_buffer) {
-  NodeProto* const node_proto = graph_proto->add_node();
-  node_proto->set_name(node_name);
-  node_proto->add_output(node_name);
-  node_proto->set_op_type("Constant");
+  TensorProto* const initializer_proto = graph_proto->add_initializer();
+
+  // Create initializer for constants
+  initializer_proto->set_name(node_name);
+  // TODO(kellens): convert to fp16 if needed.
+  initializer_proto->set_data_type(TensorProto_DataType_FLOAT);
 
   const NDArray nd = shared_buffer->find(node_name)->second;
   const TBlob& blob = nd.data();
   const TShape shape = blob.shape_;
-  const int32_t size = shape.Size();
 
+  for (auto& dim : shape) {
+    initializer_proto->add_dims(static_cast<int64>(dim));
+  }
+
+  auto size = shape.Size();
+  // TODO(kellens): Note: a hard-coded float32 element size is assumed here.
   std::shared_ptr<float> shared_data_ptr(new float[size]);
   float* const data_ptr = shared_data_ptr.get();
   nd.SyncCopyToCPU(static_cast<void*>(data_ptr), size);
 
-  AttributeProto* const tensor_attr = node_proto->add_attribute();
-  tensor_attr->set_name("value");
-  tensor_attr->set_type(AttributeProto::TENSOR);
-
-  TensorProto* const tensor_proto = tensor_attr->mutable_t();
-  tensor_proto->set_data_type(TensorProto_DataType_FLOAT);
-  for (auto& dim : shape) {
-    tensor_proto->add_dims(static_cast<int64>(dim));
+  for (size_t blob_idx = 0; blob_idx < size; ++blob_idx) {
+    initializer_proto->add_float_data(data_ptr[blob_idx]);
   }
 
-  for (int blob_idx = 0; blob_idx < size; ++blob_idx) {
-    tensor_proto->add_float_data(data_ptr[blob_idx]);
+  // Create inputs for constants.
+  ValueInfoProto* const input_proto = graph_proto->add_input();
+  input_proto->set_name(node_name);
+
+  // TODO(kellens): (fp16 support)
+  input_proto->mutable_type()->mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT);
+  for (auto& dim : shape) {
+    auto new_dim = input_proto->mutable_type()->mutable_tensor_type()->mutable_shape()->add_dim();
+    new_dim->set_dim_value(static_cast<int64>(dim));
   }
 }
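
The constant-handling rewrite above stops emitting Constant nodes and instead registers each weight as a graph initializer plus a matching graph input. A rough sketch of the equivalent construction with the onnx Python helpers (names and shapes are illustrative, not taken from the converter):

    import numpy as np
    from onnx import helper, TensorProto

    weight = np.zeros((64, 3, 7, 7), dtype=np.float32)  # stand-in for an MXNet NDArray weight
    init = helper.make_tensor("conv0_weight", TensorProto.FLOAT, weight.shape,
                              weight.flatten().tolist())
    inp = helper.make_tensor_value_info("conv0_weight", TensorProto.FLOAT, weight.shape)
    graph = helper.make_graph(nodes=[], name="MXNetTRTSubgraph0",
                              inputs=[inp], outputs=[], initializer=[init])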
 
diff --git a/tests/python/tensorrt/test_resnet18.py b/tests/python/tensorrt/test_resnet18.py
new file mode 100644
index 00000000000..fff3ac5dd76
--- /dev/null
+++ b/tests/python/tensorrt/test_resnet18.py
@@ -0,0 +1,68 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from mxnet.gluon.model_zoo import vision
+from mxnet.test_utils import assert_almost_equal
+import mxnet as mx
+import numpy as np
+import os
+
+batch_shape = (1, 3, 224, 224)
+url = 'https://github.com/dmlc/web-data/blob/master/mxnet/doc/tutorials/python/predict_image/cat.jpg?raw=true'
+model_file_name = 'resnet18_v2_trt_test'
+
+
+def get_image(image_url):
+    fname = mx.test_utils.download(image_url, fname=image_url.split('/')[-1].split('?')[0])
+    img = mx.image.imread(fname)
+    img = mx.image.imresize(img, 224, 224)  # Resize
+    img = img.transpose((2, 0, 1))  # Channel first
+    img = img.expand_dims(axis=0)  # Batchify
+    img = mx.nd.cast(img, dtype=np.float32)
+    return img/255.0
+
+
+def test_tensorrt_resnet18_feature_vect():
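+    """Run ResNet-18 inference with and without TensorRT and check the outputs match."""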
+    print("downloading sample input")
+    input_data = get_image(url)
+    gluon_resnet18 = vision.resnet18_v2(pretrained=True)
+    gluon_resnet18.hybridize()
+    gluon_resnet18.forward(input_data)
+    gluon_resnet18.export(model_file_name)
+    sym, arg_params, aux_params = mx.model.load_checkpoint(model_file_name, 0)
+
+    os.environ['MXNET_USE_TENSORRT'] = '0'
+    executor = sym.simple_bind(ctx=mx.gpu(), data=batch_shape, grad_req='null', force_rebind=True)
+    executor.copy_params_from(arg_params, aux_params)
+    y = executor.forward(is_train=False, data=input_data)
+
+    os.environ['MXNET_USE_TENSORRT'] = '1'
+    all_params = arg_params
+    all_params.update(aux_params)
+    executor = mx.contrib.tensorrt.tensorrt_bind(sym, ctx=mx.gpu(), all_params=all_params, data=batch_shape,
+                                                 grad_req='null', force_rebind=True)
+    y_trt = executor.forward(is_train=False, data=input_data)
+
+    no_trt_output = y[0].asnumpy()[0]
+    trt_output = y_trt[0].asnumpy()[0]
+    assert_almost_equal(no_trt_output, trt_output, 1e-4, 1e-4)
+
+
+if __name__ == '__main__':
+    import nose
+
+    nose.runmodule()


With regards,
Apache Git Services