You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/08/09 18:00:51 UTC

[GitHub] KellenSunderland closed pull request #12081: WIP: Tensorrt integration 12

KellenSunderland closed pull request #12081: WIP: Tensorrt integration 12
URL: https://github.com/apache/incubator-mxnet/pull/12081
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below (as it won't show otherwise due to GitHub magic):

diff --git a/.gitmodules b/.gitmodules
index 9aeb1c75498..836d824a6f5 100644
--- a/.gitmodules
+++ b/.gitmodules
@@ -26,3 +26,6 @@
 [submodule "3rdparty/tvm"]
 	path = 3rdparty/tvm
 	url = https://github.com/dmlc/tvm
+[submodule "3rdparty/onnx-tensorrt"]
+	path = 3rdparty/onnx-tensorrt
+	url = https://github.com/onnx/onnx-tensorrt.git
diff --git a/3rdparty/onnx-tensorrt b/3rdparty/onnx-tensorrt
new file mode 160000
index 00000000000..e7be19cff37
--- /dev/null
+++ b/3rdparty/onnx-tensorrt
@@ -0,0 +1 @@
+Subproject commit e7be19cff377a95817503e8525e20de34cdc574a
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 483108a6841..8ff337ed159 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -37,6 +37,7 @@ mxnet_option(ENABLE_CUDA_RTC      "Build with CUDA runtime compilation support"
 mxnet_option(BUILD_CPP_EXAMPLES   "Build cpp examples" ON)
 mxnet_option(INSTALL_EXAMPLES     "Install the example source files." OFF)
 mxnet_option(USE_SIGNAL_HANDLER   "Print stack traces on segfaults." OFF)
+mxnet_option(USE_TENSORRT         "Enable inference optimization with TensorRT." OFF)
 
 message(STATUS "CMAKE_SYSTEM_NAME ${CMAKE_SYSTEM_NAME}")
 if(USE_CUDA AND NOT USE_OLDCMAKECUDA)
@@ -185,6 +186,36 @@ if(USE_VTUNE)
   list(APPEND mxnet_LINKER_LIBS dl)
 endif()
 
+if(USE_TENSORRT)  # optional TensorRT backend: compile flags plus ONNX / onnx-tensorrt link libs
+  message(STATUS "Using TensorRT")
+  set(ONNX_PATH 3rdparty/onnx-tensorrt/third_party/onnx/build/)  # ONNX build tree
+  set(ONNX_TRT_PATH 3rdparty/onnx-tensorrt/build/)  # onnx-tensorrt build tree
+
+  include_directories(${ONNX_PATH})
+  include_directories(3rdparty/onnx-tensorrt/)
+  include_directories(3rdparty/)
+  add_definitions(-DMXNET_USE_TENSORRT=1)
+  add_definitions(-DONNX_NAMESPACE=onnx)
+
+  find_package(Protobuf REQUIRED)
+
+  find_library(ONNX_LIBRARY NAMES libonnx.so REQUIRED
+          PATHS ${ONNX_PATH}
+          DOC "Path to onnx library.")
+  find_library(ONNX_PROTO_LIBRARY NAMES libonnx_proto.so REQUIRED
+          PATHS ${ONNX_PATH}
+          DOC "Path to onnx_proto library.")
+  find_library(ONNX_TRT_RUNTIME_LIBRARY NAMES libnvonnxparser_runtime.so REQUIRED
+          PATHS ${ONNX_TRT_PATH}
+          DOC "Path to nvonnxparser_runtime (onnx-tensorrt runtime) library.")
+  find_library(ONNX_TRT_PARSER_LIBRARY NAMES libnvonnxparser.so REQUIRED
+          PATHS ${ONNX_TRT_PATH}
+          DOC "Path to nvonnxparser (onnx-tensorrt parser) library.")
+
+  list(APPEND mxnet_LINKER_LIBS libnvinfer.so ${ONNX_TRT_PARSER_LIBRARY} ${ONNX_TRT_RUNTIME_LIBRARY}
+          ${ONNX_PROTO_LIBRARY} ${ONNX_LIBRARY} ${PROTOBUF_LIBRARY})
+endif()
+
 if(USE_MKLDNN)
   include(cmake/MklDnn.cmake)
   # CPU architecture (e.g., C5) can't run on another architecture (e.g., g3).
diff --git a/Jenkinsfile b/Jenkinsfile
index 6d21f496426..758e8e870ee 100644
--- a/Jenkinsfile
+++ b/Jenkinsfile
@@ -28,6 +28,7 @@ mx_dist_lib = 'lib/libmxnet.so, lib/libmxnet.a, 3rdparty/dmlc-core/libdmlc.a, 3r
 mx_cmake_lib = 'build/libmxnet.so, build/libmxnet.a, build/3rdparty/dmlc-core/libdmlc.a, build/tests/mxnet_unit_tests, build/3rdparty/openmp/runtime/src/libomp.so'
 mx_cmake_mkldnn_lib = 'build/libmxnet.so, build/libmxnet.a, build/3rdparty/dmlc-core/libdmlc.a, build/tests/mxnet_unit_tests, build/3rdparty/openmp/runtime/src/libomp.so, build/3rdparty/mkldnn/src/libmkldnn.so.0'
 mx_mkldnn_lib = 'lib/libmxnet.so, lib/libmxnet.a, lib/libiomp5.so, lib/libmkldnn.so.0, lib/libmklml_intel.so, 3rdparty/dmlc-core/libdmlc.a, 3rdparty/tvm/nnvm/lib/libnnvm.a'
+mx_tensorrt_lib = 'lib/libmxnet.so, lib/libnvonnxparser_runtime.so.0, lib/libnvonnxparser.so.0, lib/libonnx_proto.so, lib/libonnx.so'
 // timeout in minutes
 max_time = 120
 // assign any caught errors here
@@ -372,6 +373,17 @@ try {
         }
       }
     },
+    'TensorRT': {
+      node(NODE_LINUX_CPU) {
+        ws('workspace/build-tensorrt') {
+          timeout(time: max_time, unit: 'MINUTES') {
+            utils.init_git()
+            utils.docker_run('ubuntu_gpu_tensorrt', 'build_ubuntu_gpu_tensorrt', false)
+            utils.pack_lib('tensorrt', mx_tensorrt_lib)
+          }
+        }
+      }
+    },
     'Build CPU windows':{
       node('mxnetwindows-cpu') {
         timeout(time: max_time, unit: 'MINUTES') {
@@ -740,6 +752,22 @@ try {
         }
       }
     },
+    'Python3: TensorRT GPU': {  // unpacks the 'tensorrt' libs and runs unittest_ubuntu_tensorrt_gpu
+      node(NODE_LINUX_GPU_P3) {
+        ws('workspace/build-tensorrt') {
+          timeout(time: max_time, unit: 'MINUTES') {
+            try {
+              utils.init_git()
+              utils.unpack_lib('tensorrt', mx_tensorrt_lib)
+              utils.docker_run('ubuntu_gpu_tensorrt', 'unittest_ubuntu_tensorrt_gpu', true)
+              utils.publish_test_coverage()
+            } finally {  // source name is the --xunit-file that unittest_ubuntu_tensorrt_gpu writes
+              utils.collect_test_results_unix('nosetests_trt_gpu.xml', 'nosetests_python3_tensorrt_gpu.xml')
+            }
+          }
+        }
+      }
+    },
     'Scala: CPU': {
       node('mxnetlinux-cpu') {
         ws('workspace/ut-scala-cpu') {
diff --git a/Makefile b/Makefile
index 88f7dd9278c..b794e00f00a 100644
--- a/Makefile
+++ b/Makefile
@@ -91,6 +91,14 @@ else
 endif
 CFLAGS += -I$(TPARTYDIR)/mshadow/ -I$(TPARTYDIR)/dmlc-core/include -fPIC -I$(NNVM_PATH)/include -I$(DLPACK_PATH)/include -I$(TPARTYDIR)/tvm/include -Iinclude $(MSHADOW_CFLAGS)
 LDFLAGS = -pthread $(MSHADOW_LDFLAGS) $(DMLC_LDFLAGS)
+
+
+ifeq ($(USE_TENSORRT), 1)
+	CFLAGS +=  -I$(ROOTDIR) -I$(TPARTYDIR) -DONNX_NAMESPACE=$(ONNX_NAMESPACE) -DMXNET_USE_TENSORRT=1
+	LDFLAGS += -lprotobuf -pthread -lonnx -lonnx_proto -lnvonnxparser -lnvonnxparser_runtime -lnvinfer -lnvinfer_plugin
+endif
+# -L/usr/local/lib
+
 ifeq ($(DEBUG), 1)
 	NVCCFLAGS += -std=c++11 -Xcompiler -D_FORCE_INLINES -g -G -O0 -ccbin $(CXX) $(MSHADOW_NVCCFLAGS)
 else
diff --git a/amalgamation/amalgamation.py b/amalgamation/amalgamation.py
index 52d775b7692..a3c28f7118e 100644
--- a/amalgamation/amalgamation.py
+++ b/amalgamation/amalgamation.py
@@ -23,13 +23,12 @@
 import platform
 
 blacklist = [
-    'Windows.h', 'cublas_v2.h', 'cuda/tensor_gpu-inl.cuh',
-    'cuda_runtime.h', 'cudnn.h', 'cudnn_lrn-inl.h', 'curand.h', 'curand_kernel.h',
-    'glog/logging.h', 'io/azure_filesys.h', 'io/hdfs_filesys.h', 'io/s3_filesys.h',
-    'kvstore_dist.h', 'mach/clock.h', 'mach/mach.h',
-    'malloc.h', 'mkl.h', 'mkl_cblas.h', 'mkl_vsl.h', 'mkl_vsl_functions.h',
-    'nvml.h', 'opencv2/opencv.hpp', 'sys/stat.h', 'sys/types.h', 'cuda.h', 'cuda_fp16.h',
-    'omp.h', 'execinfo.h', 'packet/sse-inl.h', 'emmintrin.h', 'thrust/device_vector.h',
+    'Windows.h', 'cublas_v2.h', 'cuda/tensor_gpu-inl.cuh', 'cuda_runtime.h', 'cudnn.h',
+    'cudnn_lrn-inl.h', 'curand.h', 'curand_kernel.h', 'glog/logging.h', 'io/azure_filesys.h',
+    'io/hdfs_filesys.h', 'io/s3_filesys.h', 'kvstore_dist.h', 'mach/clock.h', 'mach/mach.h',
+    'malloc.h', 'mkl.h', 'mkl_cblas.h', 'mkl_vsl.h', 'mkl_vsl_functions.h', 'NvInfer.h', 'nvml.h',
+    'opencv2/opencv.hpp', 'sys/stat.h', 'sys/types.h', 'cuda.h', 'cuda_fp16.h', 'omp.h',
+    'onnx/onnx.pb.h', 'execinfo.h', 'packet/sse-inl.h', 'emmintrin.h', 'thrust/device_vector.h',
     'cusolverDn.h', 'internal/concurrentqueue_internal_debug.h', 'relacy/relacy_std.hpp',
     'relacy_shims.h', 'ittnotify.h', 'shared_mutex'
     ]
@@ -150,6 +149,7 @@ def expand(x, pending, stage):
                     h not in sysheaders and
                     'mkl' not in h and
                     'nnpack' not in h and
+                    'tensorrt' not in h and
                     not h.endswith('.cuh')): sysheaders.append(h)
             else:
                 expand.treeDepth += 1
diff --git a/ci/docker/Dockerfile.build.ubuntu_gpu_tensorrt b/ci/docker/Dockerfile.build.ubuntu_gpu_tensorrt
new file mode 100755
index 00000000000..255da316041
--- /dev/null
+++ b/ci/docker/Dockerfile.build.ubuntu_gpu_tensorrt
@@ -0,0 +1,41 @@
+# -*- mode: dockerfile -*-
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+# Dockerfile to run MXNet on Ubuntu 16.04 for GPU with TensorRT
+
+FROM nvidia/cuda:9.0-cudnn7-devel
+
+WORKDIR /work/deps
+
+COPY install/ubuntu_core.sh /work/
+RUN /work/ubuntu_core.sh
+COPY install/deb_ubuntu_ccache.sh /work/
+RUN /work/deb_ubuntu_ccache.sh
+COPY install/ubuntu_python.sh /work/
+RUN /work/ubuntu_python.sh
+COPY install/tensorrt.sh /work
+RUN /work/tensorrt.sh
+
+ARG USER_ID=0
+COPY install/ubuntu_adduser.sh /work/
+RUN /work/ubuntu_adduser.sh
+
+COPY runtime_functions.sh /work/
+
+WORKDIR /work/mxnet
+ENV LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/usr/local/lib
diff --git a/ci/docker/install/tensorrt.sh b/ci/docker/install/tensorrt.sh
new file mode 100755
index 00000000000..a6258d94f62
--- /dev/null
+++ b/ci/docker/install/tensorrt.sh
@@ -0,0 +1,45 @@
+#!/bin/bash
+
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# Install gluoncv since we're testing Gluon models as well
+pip2 install gluoncv==0.2.0
+pip3 install gluoncv==0.2.0
+
+# Install Protobuf
+# Install protoc 3.5 and build protobuf here (for onnx and onnx-tensorrt)
+pushd .
+cd ..
+apt-get update
+apt-get install -y automake libtool
+git clone --recursive -b 3.5.1.1 https://github.com/google/protobuf.git
+cd protobuf
+./autogen.sh
+./configure
+make -j$(nproc)
+make install
+ldconfig
+popd
+
+# Install TensorRT
+echo "TensorRT build enabled. Installing TensorRT."
+wget -qO tensorrt.deb https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1604/x86_64/nvinfer-runtime-trt-repo-ubuntu1604-4.0.1-ga-cuda9.0_1-1_amd64.deb
+dpkg -i tensorrt.deb
+apt-get update
+apt-get install -y --allow-downgrades libnvinfer-dev
+rm tensorrt.deb
diff --git a/ci/docker/runtime_functions.sh b/ci/docker/runtime_functions.sh
index a0795eb58a5..3e19eaf7004 100755
--- a/ci/docker/runtime_functions.sh
+++ b/ci/docker/runtime_functions.sh
@@ -436,6 +436,60 @@ build_ubuntu_gpu() {
     build_ubuntu_gpu_cuda91_cudnn7
 }
 
+build_ubuntu_gpu_tensorrt() {
+
+    set -ex
+
+    build_ccache_wrappers
+
+    # Build ONNX
+    pushd .
+    echo "Installing ONNX."
+    cd 3rdparty/onnx-tensorrt/third_party/onnx
+    rm -rf build
+    mkdir -p build
+    cd build
+    cmake \
+        -DCMAKE_CXX_FLAGS=-I/usr/include/python${PYVER}\
+        -DBUILD_SHARED_LIBS=ON ..\
+        -G Ninja
+    ninja -v
+    export LIBRARY_PATH=`pwd`:`pwd`/onnx/:$LIBRARY_PATH
+    export CPLUS_INCLUDE_PATH=`pwd`:$CPLUS_INCLUDE_PATH
+    popd
+
+    # Build ONNX-TensorRT
+    pushd .
+    cd 3rdparty/onnx-tensorrt/
+    mkdir -p build
+    cd build
+    cmake ..
+    make -j$(nproc)
+    export LIBRARY_PATH=`pwd`:$LIBRARY_PATH
+    popd
+
+    mkdir -p /work/mxnet/lib/
+    cp 3rdparty/onnx-tensorrt/third_party/onnx/build/*.so /work/mxnet/lib/
+    cp -L 3rdparty/onnx-tensorrt/build/libnvonnxparser_runtime.so.0 /work/mxnet/lib/
+    cp -L 3rdparty/onnx-tensorrt/build/libnvonnxparser.so.0 /work/mxnet/lib/
+
+    rm -rf build
+    make \
+        DEV=1                                               \
+        USE_BLAS=openblas                                   \
+        USE_CUDA=1                                          \
+        USE_CUDA_PATH=/usr/local/cuda                       \
+        USE_CUDNN=1                                         \
+        USE_OPENCV=0                                        \
+        USE_DIST_KVSTORE=0                                  \
+        USE_TENSORRT=1                                      \
+        USE_JEMALLOC=0                                      \
+        USE_GPERFTOOLS=0                                    \
+        ONNX_NAMESPACE=onnx                                 \
+        CUDA_ARCH="-gencode arch=compute_70,code=compute_70"\
+        -j$(nproc)
+}
+
 build_ubuntu_gpu_mkldnn() {
     set -ex
 
@@ -638,6 +692,15 @@ unittest_ubuntu_python3_gpu_nocudnn() {
     nosetests-3.4 $NOSE_COVERAGE_ARGUMENTS --with-xunit --xunit-file nosetests_gpu.xml --verbose tests/python/gpu
 }
 
+unittest_ubuntu_tensorrt_gpu() {
+    set -ex
+    export PYTHONPATH=./python/
+    export MXNET_STORAGE_FALLBACK_LOG_VERBOSE=0
+    export LD_LIBRARY_PATH=/work/mxnet/lib:$LD_LIBRARY_PATH
+    python tests/python/tensorrt/lenet5_train.py
+    nosetests-3.4 $NOSE_COVERAGE_ARGUMENTS --with-xunit --xunit-file nosetests_trt_gpu.xml --verbose tests/python/tensorrt/
+}
+
 # quantization gpu currently only runs on P3 instances
 # need to separte it from unittest_ubuntu_python2_gpu()
 unittest_ubuntu_python2_quantization_gpu() {
@@ -970,3 +1033,5 @@ EOF
     declare -F | cut -d' ' -f3
     echo
 fi
+
+
diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h
index 75147cfd706..58b1b1b4daf 100644
--- a/include/mxnet/c_api.h
+++ b/include/mxnet/c_api.h
@@ -1714,6 +1714,13 @@ MXNET_DLL int MXExecutorReshape(int partial_shaping,
                                 NDArrayHandle** aux_states,
                                 ExecutorHandle shared_exec,
                                 ExecutorHandle *out);
+
+/*!
+ * \brief get optimized graph from graph executor
+ */
+MXNET_DLL int MXExecutorGetOptimizedSymbol(ExecutorHandle handle,
+                                           SymbolHandle *out);
+
 /*!
  * \brief set a call back to notify the completion of operation
  */
diff --git a/include/mxnet/executor.h b/include/mxnet/executor.h
index 842653f8653..0ab04b86a0a 100644
--- a/include/mxnet/executor.h
+++ b/include/mxnet/executor.h
@@ -166,6 +166,7 @@ class Executor {
                               std::unordered_map<std::string, NDArray>*
                                 shared_data_arrays = nullptr,
                               Executor* shared_exec = nullptr);
+
   /*!
    * \brief the prototype of user-defined monitor callback
    */
diff --git a/python/mxnet/base.py b/python/mxnet/base.py
index 4df794bdfe3..5f4ae8bd0ac 100644
--- a/python/mxnet/base.py
+++ b/python/mxnet/base.py
@@ -709,3 +709,19 @@ def write_all_str(module_file, module_all_list):
     module_op_file.close()
     write_all_str(module_internal_file, module_internal_all)
     module_internal_file.close()
+
+def cint(init_val=0):
+    """create a C int with an optional initial value"""
+    return C.c_int(init_val)
+
+def int_addr(x):
+    """given a c_int, return its address as an int ptr"""
+    x_addr = C.addressof(x)
+    int_p = C.POINTER(C.c_int)
+    x_int_addr = C.cast(x_addr, int_p)
+    return x_int_addr
+
+def checked_call(f, *args):
+    """call a cuda function and check for success"""
+    error_t = f(*args)
+    assert error_t == 0, "Failing cuda call %s returns %s." % (f.__name__, error_t)
diff --git a/python/mxnet/contrib/__init__.py b/python/mxnet/contrib/__init__.py
index fbfd3469678..606bb0ada54 100644
--- a/python/mxnet/contrib/__init__.py
+++ b/python/mxnet/contrib/__init__.py
@@ -32,3 +32,4 @@
 from . import io
 from . import quantization
 from . import quantization as quant
+from . import tensorrt
diff --git a/python/mxnet/contrib/tensorrt.py b/python/mxnet/contrib/tensorrt.py
new file mode 100644
index 00000000000..11bdecc4486
--- /dev/null
+++ b/python/mxnet/contrib/tensorrt.py
@@ -0,0 +1,73 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+""" Module to enable the use of TensorRT optimized graphs."""
+
+import ctypes
+import logging
+import os
+
+from mxnet.symbol import Symbol
+
+from ..base import _LIB, SymbolHandle, MXNetError
+from ..base import check_call
+
+
+def set_use_tensorrt(status):
+    """
+    Set an environment variable which will enable or disable the use of TensorRT in the backend.
+    Note: this is useful for A/B testing purposes.
+    :param status: Boolean, true if TensorRT optimization should be applied, False for legacy
+    behaviour.
+    """
+    os.environ["MXNET_USE_TENSORRT"] = str(int(status))
+
+
+def get_use_tensorrt():
+    """
+    Get an environment variable which describes if TensorRT is currently enabled in the backend.
+    Note: this is useful for A/B testing purposes.
+    :return: Boolean, true if TensorRT optimization should be applied, False for legacy
+    behaviour.
+    """
+    return bool(int(os.environ.get("MXNET_USE_TENSORRT", 0)) == 1)
+
+
+def get_optimized_symbol(executor):
+    """
+    Take an executor's underlying symbol graph and return its generated optimized version.
+
+    Parameters
+    ----------
+    executor :
+        An executor for which you want to see an optimized symbol. Getting an optimized symbol
+        is useful to compare and verify the work TensorRT has done against a legacy behaviour.
+
+    Returns
+    -------
+    symbol : nnvm::Symbol
+        The nnvm symbol optimized.
+    """
+    handle = SymbolHandle()
+    try:
+        check_call(_LIB.MXExecutorGetOptimizedSymbol(executor.handle, ctypes.byref(handle)))
+        result = Symbol(handle=handle)
+        return result
+    except MXNetError:
+        logging.error('Error while trying to fetch TRT optimized symbol for graph. Please ensure '
+                      'build was compiled with MXNET_USE_TENSORRT enabled.')
+        raise
diff --git a/python/mxnet/executor.py b/python/mxnet/executor.py
index c0272c5bb43..fcd5406236e 100644
--- a/python/mxnet/executor.py
+++ b/python/mxnet/executor.py
@@ -73,6 +73,7 @@ def __init__(self, handle, symbol, ctx, grad_req, group2ctx):
         self.aux_arrays = []
         self.outputs = self._get_outputs()
         self._symbol = copy.deepcopy(symbol)
+        self._optimized_symbol = None
         self._arg_dict = None
         self._grad_dict = None
         self._aux_dict = None
diff --git a/python/mxnet/module/executor_group.py b/python/mxnet/module/executor_group.py
index 5d8e95077c4..c4050699bd5 100755
--- a/python/mxnet/module/executor_group.py
+++ b/python/mxnet/module/executor_group.py
@@ -592,8 +592,8 @@ def backward(self, out_grads=None):
                     # pylint: disable=no-member
                     og_my_slice = nd.slice_axis(grad, axis=axis, begin=islice.start,
                                                 end=islice.stop)
-                    # pylint: enable=no-member
                     out_grads_slice.append(og_my_slice.as_in_context(self.contexts[i]))
+                    # pylint: enable=no-member
                 else:
                     out_grads_slice.append(grad.copyto(self.contexts[i]))
             exec_.backward(out_grads=out_grads_slice)
diff --git a/src/c_api/c_api_executor.cc b/src/c_api/c_api_executor.cc
index 09bc23934e5..b99350525bf 100644
--- a/src/c_api/c_api_executor.cc
+++ b/src/c_api/c_api_executor.cc
@@ -26,6 +26,10 @@
 #include <mxnet/c_api.h>
 #include <mxnet/executor.h>
 #include "./c_api_common.h"
+#include "../executor/graph_executor.h"
+#if MXNET_USE_TENSORRT
+#include "../executor/trt_graph_executor.h"
+#endif  // MXNET_USE_TENSORRT
 
 int MXExecutorPrint(ExecutorHandle handle, const char **out_str) {
   Executor *exec = static_cast<Executor*>(handle);
@@ -439,13 +443,38 @@ int MXExecutorSimpleBind(SymbolHandle symbol_handle,
   std::vector<NDArray> in_arg_vec;
   std::vector<NDArray> arg_grad_vec;
   std::vector<NDArray> aux_state_vec;
-
-  *out = Executor::SimpleBind(*sym, ctx, ctx_map, in_arg_ctx_vec, arg_grad_ctx_vec,
-                              aux_state_ctx_vec, arg_shape_map, arg_dtype_map, arg_stype_map,
-                              grad_req_type_vec, shared_arg_name_set, &in_arg_vec,
-                              &arg_grad_vec, &aux_state_vec,
-                              use_shared_buffer ? &shared_buffer_map : nullptr,
-                              reinterpret_cast<Executor*>(shared_exec_handle));
+#if MXNET_USE_TENSORRT
+  // If we've built with TensorRT support, return a TrtGraphExecutor only when the user opts in
+  // by setting the MXNET_USE_TENSORRT env var; otherwise fall back to the regular executor.
+  // The runtime switch is useful, for example, for A/B performance testing.
+  if (dmlc::GetEnv("MXNET_USE_TENSORRT", false)) {
+    *out = exec::TrtGraphExecutor::TensorRTBind(*sym, ctx, ctx_map, &in_arg_ctx_vec,
+                                                &arg_grad_ctx_vec, &aux_state_ctx_vec,
+                                                &arg_shape_map, &arg_dtype_map, &arg_stype_map,
+                                                &grad_req_type_vec, shared_arg_name_set,
+                                                &in_arg_vec, &arg_grad_vec, &aux_state_vec,
+                                                use_shared_buffer ? &shared_buffer_map : nullptr,
+                                                reinterpret_cast<Executor*>(shared_exec_handle));
+  } else {
+    // Warn users of a TensorRT build who left MXNET_USE_TENSORRT unset and describe next steps.
+    // The sentinel must be a value nobody sets on purpose: numeric_limits<int>::quiet_NaN() is 0
+    // (int has no NaN) and would also match an explicit MXNET_USE_TENSORRT=0.
+    const int unset_indicator = std::numeric_limits<int>::min();
+    if (dmlc::GetEnv("MXNET_USE_TENSORRT", unset_indicator) == unset_indicator) {
+      LOG(INFO) << "TensorRT not enabled by default.  Please set the MXNET_USE_TENSORRT "
+                   "environment variable to 1 or call mx.contrib.tensorrt.set_use_tensorrt(True) "
+                   "to enable.";
+    }
+#endif  // MXNET_USE_TENSORRT
+    *out = Executor::SimpleBind(*sym, ctx, ctx_map, in_arg_ctx_vec, arg_grad_ctx_vec,
+                                aux_state_ctx_vec, arg_shape_map, arg_dtype_map, arg_stype_map,
+                                grad_req_type_vec, shared_arg_name_set, &in_arg_vec,
+                                &arg_grad_vec, &aux_state_vec,
+                                use_shared_buffer ? &shared_buffer_map : nullptr,
+                                reinterpret_cast<Executor*>(shared_exec_handle));
+#if MXNET_USE_TENSORRT
+  }
+#endif  // MXNET_USE_TENSORRT
 
   // copy ndarray ptrs to ret->handles so that front end
   // can access them
@@ -597,6 +626,25 @@ int MXExecutorReshape(int partial_shaping,
   API_END_HANDLE_ERROR(delete out);
 }
 
+int MXExecutorGetOptimizedSymbol(ExecutorHandle handle,
+                                 SymbolHandle *out) {
+  auto s = new nnvm::Symbol();
+  API_BEGIN();
+
+#if MXNET_USE_TENSORRT
+  auto exec = static_cast<exec::TrtGraphExecutor*>(handle);
+  *s = exec->GetOptimizedSymbol();
+  *out = s;
+#else
+  LOG(FATAL) << "GetOptimizedSymbol may only be used when MXNet is compiled with "
+                "MXNET_USE_TENSORRT enabled.  Please re-compile MXNet with TensorRT support.";
+#endif  // MXNET_USE_TENSORRT
+
+  API_END_HANDLE_ERROR(delete s);
+}
+
+
+
 int MXExecutorSetMonitorCallback(ExecutorHandle handle,
                                  ExecutorMonitorCallback callback,
                                  void* callback_handle) {
diff --git a/src/common/serialization.h b/src/common/serialization.h
new file mode 100644
index 00000000000..56b6069304d
--- /dev/null
+++ b/src/common/serialization.h
@@ -0,0 +1,319 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2015 by Contributors
+ * \file serialization.h
+ * \brief Serialization of some STL and nnvm data-structures
+ * \author Clement Fuji Tsang
+ */
+
+#ifndef MXNET_COMMON_SERIALIZATION_H_
+#define MXNET_COMMON_SERIALIZATION_H_
+
+#include <dmlc/logging.h>
+#include <mxnet/graph_attr_types.h>
+#include <nnvm/graph_attr_types.h>
+#include <nnvm/tuple.h>
+
+#include <cstring>
+#include <map>
+#include <set>
+#include <string>
+#include <tuple>
+#include <unordered_map>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
+
+namespace mxnet {
+namespace common {
+
+template<typename T>
+inline size_t SerializedSize(const T &obj);
+
+template<typename T>
+inline size_t SerializedSize(const nnvm::Tuple <T> &obj);
+
+template<typename K, typename V>
+inline size_t SerializedSize(const std::map <K, V> &obj);
+
+template<>
+inline size_t SerializedSize(const std::string &obj);
+
+template<typename... Args>
+inline size_t SerializedSize(const std::tuple<Args...> &obj);
+
+template<typename T>
+inline void Serialize(const T &obj, char **buffer);
+
+template<typename T>
+inline void Serialize(const nnvm::Tuple <T> &obj, char **buffer);
+
+template<typename K, typename V>
+inline void Serialize(const std::map <K, V> &obj, char **buffer);
+
+template<>
+inline void Serialize(const std::string &obj, char **buffer);
+
+template<typename... Args>
+inline void Serialize(const std::tuple<Args...> &obj, char **buffer);
+
+template<typename T>
+inline void Deserialize(T *obj, const std::string &buffer, size_t *curr_pos);
+
+template<typename T>
+inline void Deserialize(nnvm::Tuple <T> *obj, const std::string &buffer, size_t *curr_pos);
+
+template<typename K, typename V>
+inline void Deserialize(std::map <K, V> *obj, const std::string &buffer, size_t *curr_pos);
+
+template<>
+inline void Deserialize(std::string *obj, const std::string &buffer, size_t *curr_pos);
+
+template<typename... Args>
+inline void Deserialize(std::tuple<Args...> *obj, const std::string &buffer, size_t *curr_pos);
+
+
+template<typename T>
+struct is_container {
+  static const bool value = !std::is_pod<T>::value;
+};
+
+template<typename T>
+inline size_t SerializedSize(const T &obj) {
+  return sizeof(T);
+}
+
+template<typename T>
+inline size_t SerializedSize(const nnvm::Tuple <T> &obj) {
+  if (is_container<T>::value) {
+    size_t sum_val = 4;
+    for (auto& el : obj) {
+      sum_val += SerializedSize(el);
+    }
+    return sum_val;
+  } else {
+    return 4 + (obj.ndim() * sizeof(T));
+  }
+}
+
+template<typename K, typename V>
+inline size_t SerializedSize(const std::map <K, V> &obj) {
+  size_t sum_val = 4;
+  if (is_container<K>::value && is_container<V>::value) {
+    for (const auto& p : obj) {
+      sum_val += SerializedSize(p.first) + SerializedSize(p.second);
+    }
+  } else if (is_container<K>::value) {
+    for (const auto& p : obj) {
+      sum_val += SerializedSize(p.first);
+    }
+    sum_val += sizeof(V) * obj.size();
+  } else if (is_container<V>::value) {
+    for (const auto& p : obj) {
+      sum_val += SerializedSize(p.second);
+    }
+    sum_val += sizeof(K) * obj.size();
+  } else {
+    sum_val += (sizeof(K) + sizeof(V)) * obj.size();
+  }
+  return sum_val;
+}
+
+template<>
+inline size_t SerializedSize(const std::string &obj) {
+  return obj.size() + 4;
+}
+
+template<int I>
+struct serialized_size_tuple {
+  template<typename... Args>
+  static inline size_t Compute(const std::tuple<Args...> &obj) {
+    return SerializedSize(std::get<I>(obj)) + serialized_size_tuple<I-1>::Compute(obj);
+  }
+};
+
+template<>
+struct serialized_size_tuple<0> {
+  template<typename... Args>
+  static inline size_t Compute(const std::tuple<Args...> &obj) {
+    return SerializedSize(std::get<0>(obj));
+  }
+};
+
+template<typename... Args>
+inline size_t SerializedSize(const std::tuple<Args...> &obj) {
+  return serialized_size_tuple<sizeof... (Args)-1>::Compute(obj);
+}
+
+//  Serializer
+
+template<typename T>
+inline size_t SerializedContainerSize(const T &obj, char **buffer) {
+  uint32_t size = obj.size();
+  std::memcpy(*buffer, &size, 4);
+  *buffer += 4;
+  return (size_t) size;
+}
+
+template<typename T>
+inline void Serialize(const T &obj, char **buffer) {
+  std::memcpy(*buffer, &obj, sizeof(T));
+  *buffer += sizeof(T);
+}
+
+template<typename T>
+inline void Serialize(const nnvm::Tuple <T> &obj, char **buffer) {
+  uint32_t size = obj.ndim();
+  std::memcpy(*buffer, &size, 4);
+  *buffer += 4;
+  for (auto& el : obj) {
+    Serialize(el, buffer);
+  }
+}
+
+template<typename K, typename V>
+inline void Serialize(const std::map <K, V> &obj, char **buffer) {
+  SerializedContainerSize(obj, buffer);
+  for (auto& p : obj) {
+    Serialize(p.first, buffer);
+    Serialize(p.second, buffer);
+  }
+}
+
+template<>
+inline void Serialize(const std::string &obj, char **buffer) {
+  auto size = SerializedContainerSize(obj, buffer);
+  std::memcpy(*buffer, &obj[0], size);
+  *buffer += size;
+}
+
+template<int I>
+struct serialize_tuple {
+  template<typename... Args>
+  static inline void Compute(const std::tuple<Args...> &obj, char **buffer) {
+    serialize_tuple<I-1>::Compute(obj, buffer);
+    Serialize(std::get<I>(obj), buffer);
+  }
+};
+
+template<>
+struct serialize_tuple<0> {
+  template<typename... Args>
+  static inline void Compute(const std::tuple<Args...> &obj, char **buffer) {
+    Serialize(std::get<0>(obj), buffer);
+  }
+};
+
+template<typename... Args>
+inline void Serialize(const std::tuple<Args...> &obj, char **buffer) {
+  serialize_tuple<sizeof... (Args)-1>::Compute(obj, buffer);
+}
+
+// Deserializer
+
+template<typename T>
+inline size_t DeserializedContainerSize(T *obj, const std::string &buffer, size_t *curr_pos) {
+  uint32_t size = obj->size();
+  std::memcpy(&size, &buffer[*curr_pos], 4);
+  *curr_pos += 4;
+  return (size_t) size;
+}
+
+template<typename T>
+inline void Deserialize(T *obj, const std::string &buffer, size_t *curr_pos) {
+  std::memcpy(obj, &buffer[*curr_pos], sizeof(T));
+  *curr_pos += sizeof(T);
+}
+
+template<typename T>
+inline void Deserialize(nnvm::Tuple <T> *obj, const std::string &buffer, size_t *curr_pos) {
+  uint32_t size = 0;  // element count; overwritten by the 4-byte prefix read below
+  std::memcpy(&size, &buffer[*curr_pos], 4);
+  *curr_pos += 4;
+  obj->SetDim(size);
+  for (size_t i = 0; i < size; ++i) {
+    Deserialize(&(*obj)[i], buffer, curr_pos);  // every Deserialize overload takes a pointer
+  }
+}
+
+template<typename K, typename V>
+inline void Deserialize(std::map <K, V> *obj, const std::string &buffer, size_t *curr_pos) {
+  auto size = DeserializedContainerSize(obj, buffer, curr_pos);
+  K first;
+  for (size_t i = 0; i < size; ++i) {
+    Deserialize(&first, buffer, curr_pos);
+    Deserialize(&(*obj)[first], buffer, curr_pos);
+  }
+}
+
+template<>
+inline void Deserialize(std::string *obj, const std::string &buffer, size_t *curr_pos) {
+  auto size = DeserializedContainerSize(obj, buffer, curr_pos);  // 4-byte length prefix
+  obj->resize(size);
+  if (size > 0) std::memcpy(&(*obj)[0], &buffer[*curr_pos], size);  // front() on "" is UB
+  *curr_pos += size;
+}
+
+template<int I>
+struct deserialize_tuple {
+  template<typename... Args>
+  static inline void Compute(std::tuple<Args...> *obj,
+                             const std::string &buffer, size_t *curr_pos) {
+    deserialize_tuple<I-1>::Compute(obj, buffer, curr_pos);
+    Deserialize(&std::get<I>(*obj), buffer, curr_pos);
+  }
+};
+
+template<>
+struct deserialize_tuple<0> {
+  template<typename... Args>
+  static inline void Compute(std::tuple<Args...> *obj,
+                             const std::string &buffer, size_t *curr_pos) {
+    Deserialize(&std::get<0>(*obj), buffer, curr_pos);
+  }
+};
+
+template<typename... Args>
+inline void Deserialize(std::tuple<Args...> *obj, const std::string &buffer, size_t *curr_pos) {
+  deserialize_tuple<sizeof... (Args)-1>::Compute(obj, buffer, curr_pos);
+}
+
+
+template<typename T>
+inline void Serialize(const T& obj, std::string* serialized_data) {
+  serialized_data->resize(SerializedSize(obj));
+  char* curr_pos = &(serialized_data->front());
+  Serialize(obj, &curr_pos);
+  CHECK_EQ((int64_t)curr_pos - (int64_t)&(serialized_data->front()),
+           serialized_data->size());
+}
+
+template<typename T>
+inline void Deserialize(T* obj, const std::string& serialized_data) {
+  size_t curr_pos = 0;
+  Deserialize(obj, serialized_data, &curr_pos);
+  CHECK_EQ(curr_pos, serialized_data.size());
+}
+
+}  // namespace common
+}  // namespace mxnet
+#endif  // MXNET_COMMON_SERIALIZATION_H_
diff --git a/src/executor/exec_pass.h b/src/executor/exec_pass.h
index 26a24911894..8c483e9b2b8 100644
--- a/src/executor/exec_pass.h
+++ b/src/executor/exec_pass.h
@@ -198,6 +198,18 @@ Graph InferStorageType(Graph&& graph,
                        StorageTypeVector&& storage_type_inputs = StorageTypeVector(),
                        const std::string& storage_type_attr_key = "");
 
+#if MXNET_USE_TENSORRT
+/*!
+ * \brief Replace subgraphs by TRT (forward only)
+ * \param g graph to transform
+ * \param set_subgraph presumably the nodes of g that form the subgraph to replace;
+ *        the definition is not in this header -- confirm there
+ * \param params_map name -> NDArray map; how it is consumed is not visible here
+ */
+Graph ReplaceSubgraph(Graph&& g,
+                      const std::unordered_set<nnvm::Node*>& set_subgraph,
+                      std::unordered_map<std::string, NDArray>* const params_map);
+
+/*!
+ * \brief Judging by the name and return type, computes the node sets of g that are
+ *        candidates for ReplaceSubgraph. NOTE(review): confirm against the definition.
+ */
+std::vector<std::unordered_set<nnvm::Node*>> GetTrtCompatibleSubsets(const Graph& g,
+    std::unordered_map<std::string, NDArray>* const params_map);
+#endif  // MXNET_USE_TENSORRT
+
 }  // namespace exec
 }  // namespace mxnet
 
diff --git a/src/executor/graph_executor.cc b/src/executor/graph_executor.cc
index 7386de4d12e..f9c286b596b 100644
--- a/src/executor/graph_executor.cc
+++ b/src/executor/graph_executor.cc
@@ -56,8 +56,8 @@ GraphExecutor::~GraphExecutor() {
   }
 }
 
-inline NDArray InitZeros(const NDArrayStorageType stype, const TShape &shape,
-                                const Context &ctx, const int dtype) {
+inline NDArray GraphExecutor::InitZeros(const NDArrayStorageType stype, const TShape &shape,
+                                        const Context &ctx, const int dtype) {
   // NDArray with default storage
   if (stype == kDefaultStorage) {
     NDArray ret(shape, ctx, false, dtype);
@@ -68,9 +68,9 @@ inline NDArray InitZeros(const NDArrayStorageType stype, const TShape &shape,
   return NDArray(stype, shape, ctx, true, dtype);
 }
 
-inline void EmplaceBackZeros(const NDArrayStorageType stype, const TShape &shape,
-                             const Context &ctx, const int dtype,
-                             std::vector<NDArray> *vec) {
+inline void GraphExecutor::EmplaceBackZeros(const NDArrayStorageType stype, const TShape &shape,
+                                            const Context &ctx, const int dtype,
+                                            std::vector<NDArray> *vec) {
   // NDArray with default storage
   if (stype == kDefaultStorage) {
     vec->emplace_back(shape, ctx, false, dtype);
@@ -312,15 +312,15 @@ nnvm::Graph GraphExecutor::InitFullGraph(nnvm::Symbol symbol,
  * \brief Assign context to the graph.
  * This is triggered by both simple_bind and bind flows.
  */
-static Graph AssignContext(Graph g,
-                    const Context& default_ctx,
-                    const std::map<std::string, Context>& ctx_map,
-                    const std::vector<Context>& in_arg_ctxes,
-                    const std::vector<Context>& arg_grad_ctxes,
-                    const std::vector<Context>& aux_state_ctxes,
-                    const std::vector<OpReqType>& grad_req_types,
-                    size_t num_forward_inputs,
-                    size_t num_forward_outputs) {
+Graph GraphExecutor::AssignContext(Graph g,
+                                   const Context& default_ctx,
+                                   const std::map<std::string, Context>& ctx_map,
+                                   const std::vector<Context>& in_arg_ctxes,
+                                   const std::vector<Context>& arg_grad_ctxes,
+                                   const std::vector<Context>& aux_state_ctxes,
+                                   const std::vector<OpReqType>& grad_req_types,
+                                   size_t num_forward_inputs,
+                                   size_t num_forward_outputs) {
   const auto& idx = g.indexed_graph();
   const auto& mutable_nodes = idx.mutable_input_nodes();
   // default use default context.
@@ -437,9 +437,9 @@ static Graph AssignContext(Graph g,
   return g;
 }
 
-static void HandleInferShapeError(const size_t num_forward_inputs,
-                           const nnvm::IndexedGraph& idx,
-                           const nnvm::ShapeVector& inferred_shapes) {
+void GraphExecutor::HandleInferShapeError(const size_t num_forward_inputs,
+                                          const nnvm::IndexedGraph& idx,
+                                          const nnvm::ShapeVector& inferred_shapes) {
   int cnt = 10;
   std::ostringstream oss;
   for (size_t i = 0; i < num_forward_inputs; ++i) {
@@ -460,9 +460,9 @@ static void HandleInferShapeError(const size_t num_forward_inputs,
              << oss.str();
 }
 
-static void HandleInferTypeError(const size_t num_forward_inputs,
-                          const nnvm::IndexedGraph& idx,
-                          const nnvm::DTypeVector& inferred_dtypes) {
+void GraphExecutor::HandleInferTypeError(const size_t num_forward_inputs,
+                                         const nnvm::IndexedGraph& idx,
+                                         const nnvm::DTypeVector& inferred_dtypes) {
   int cnt = 10;
   std::ostringstream oss;
   for (size_t i = 0; i < num_forward_inputs; ++i) {
@@ -483,9 +483,9 @@ static void HandleInferTypeError(const size_t num_forward_inputs,
              << oss.str();
 }
 
-static void HandleInferStorageTypeError(const size_t num_forward_inputs,
-                                 const nnvm::IndexedGraph& idx,
-                                 const StorageTypeVector& inferred_stypes) {
+void GraphExecutor::HandleInferStorageTypeError(const size_t num_forward_inputs,
+                                                const nnvm::IndexedGraph& idx,
+                                                const StorageTypeVector& inferred_stypes) {
   int cnt = 10;
   std::ostringstream oss;
   for (size_t i = 0; i < num_forward_inputs; ++i) {
@@ -688,13 +688,13 @@ void GraphExecutor::InitArguments(const nnvm::IndexedGraph& idx,
  * Shareable storages include both default storage and row_sparse storage
  * if enable_row_sparse_sharing is `True`, otherwise default storage only.
  */
-static NDArray ReshapeOrCreate(const std::string& name,
-                        const TShape& dest_arg_shape,
-                        const int dest_arg_dtype,
-                        const NDArrayStorageType dest_arg_stype,
-                        const Context& ctx,
-                        std::unordered_map<std::string, NDArray>* shared_buffer,
-                        bool enable_row_sparse_sharing) {
+NDArray GraphExecutor::ReshapeOrCreate(const std::string& name,
+                                       const TShape& dest_arg_shape,
+                                       const int dest_arg_dtype,
+                                       const NDArrayStorageType dest_arg_stype,
+                                       const Context& ctx,
+                                       std::unordered_map<std::string, NDArray>* shared_buffer,
+                                       bool enable_row_sparse_sharing) {
   bool stype_shareable = dest_arg_stype == kDefaultStorage;
   if (enable_row_sparse_sharing) {
     stype_shareable = stype_shareable || dest_arg_stype == kRowSparseStorage;
diff --git a/src/executor/graph_executor.h b/src/executor/graph_executor.h
index bfc415b4526..05429b2508a 100644
--- a/src/executor/graph_executor.h
+++ b/src/executor/graph_executor.h
@@ -163,20 +163,21 @@ class GraphExecutor : public Executor {
                      std::vector<NDArray>* aux_state_vec);
   // Initialize in_args, arg_grads and aux_states with
   // shared_buffer and shared_exec
-  void InitArguments(const nnvm::IndexedGraph& idx,
-                     const nnvm::ShapeVector& inferred_shapes,
-                     const nnvm::DTypeVector& inferred_dtypes,
-                     const StorageTypeVector& inferred_stypes,
-                     const std::vector<Context>& in_arg_ctxes,
-                     const std::vector<Context>& arg_grad_ctxes,
-                     const std::vector<Context>& aux_state_ctxes,
-                     const std::vector<OpReqType>& grad_req_types,
-                     const std::unordered_set<std::string>& shared_arg_names,
-                     const Executor* shared_exec,
-                     std::unordered_map<std::string, NDArray>* shared_buffer,
-                     std::vector<NDArray>* in_arg_vec,
-                     std::vector<NDArray>* arg_grad_vec,
-                     std::vector<NDArray>* aux_state_vec);
+  virtual void InitArguments(const nnvm::IndexedGraph& idx,
+                             const nnvm::ShapeVector& inferred_shapes,
+                             const nnvm::DTypeVector& inferred_dtypes,
+                             const StorageTypeVector& inferred_stypes,
+                             const std::vector<Context>& in_arg_ctxes,
+                             const std::vector<Context>& arg_grad_ctxes,
+                             const std::vector<Context>& aux_state_ctxes,
+                             const std::vector<OpReqType>& grad_req_types,
+                             const std::unordered_set<std::string>& shared_arg_names,
+                             const Executor* shared_exec,
+                             std::unordered_map<std::string, NDArray>* shared_buffer,
+                             std::vector<NDArray>* in_arg_vec,
+                             std::vector<NDArray>* arg_grad_vec,
+                             std::vector<NDArray>* aux_state_vec);
+
   // internal initialization of the graph for simple bind
   Graph InitGraph(nnvm::Symbol symbol,
                   const Context& default_ctx,
@@ -212,9 +213,46 @@ class GraphExecutor : public Executor {
   void BulkInferenceOpSegs();
   // perform bulking and segmentation on a training graph
   void BulkTrainingOpSegs(size_t total_num_nodes);
-
+  // prints a helpful message after shape inference errors in executor.
+  static void HandleInferShapeError(const size_t num_forward_inputs,
+                                    const nnvm::IndexedGraph& idx,
+                                    const nnvm::ShapeVector& inferred_shapes);
+  // prints a helpful message after type inference errors in executor.
+  static void HandleInferTypeError(const size_t num_forward_inputs,
+                                   const nnvm::IndexedGraph& idx,
+                                   const nnvm::DTypeVector& inferred_dtypes);
+  // prints a helpful message after storage type checking errors in executor.
+  static void HandleInferStorageTypeError(const size_t num_forward_inputs,
+                                          const nnvm::IndexedGraph& idx,
+                                          const StorageTypeVector& inferred_stypes);
+  // helper to initialize an NDArray to all zeros.
+  static NDArray InitZeros(const NDArrayStorageType stype, const TShape &shape,
+                           const Context &ctx, const int dtype);
+  // helper to add a NDArray of zeros to a std::vector.
+  static void EmplaceBackZeros(const NDArrayStorageType stype, const TShape &shape,
+                               const Context &ctx, const int dtype,
+                               std::vector<NDArray> *vec);
+  // helper to reshape an NDArray of certain shape if it doesn't already exist.
+  static NDArray ReshapeOrCreate(const std::string& name,
+                                 const TShape& dest_arg_shape,
+                                 const int dest_arg_dtype,
+                                 const NDArrayStorageType dest_arg_stype,
+                                 const Context& ctx,
+                                 std::unordered_map<std::string, NDArray>* shared_buffer,
+                                 bool enable_row_sparse_sharing);
+  // Assign context to the graph.
+  static Graph AssignContext(Graph g,
+                             const Context& default_ctx,
+                             const std::map<std::string, Context>& ctx_map,
+                             const std::vector<Context>& in_arg_ctxes,
+                             const std::vector<Context>& arg_grad_ctxes,
+                             const std::vector<Context>& aux_state_ctxes,
+                             const std::vector<OpReqType>& grad_req_types,
+                             size_t num_forward_inputs,
+                             size_t num_forward_outputs);
   // indicate whether there is a backward graph for gradients.
   bool need_grad_;
+
   // internal graph
   nnvm::Graph graph_;
   // operator node
diff --git a/src/executor/onnx_to_tensorrt.cc b/src/executor/onnx_to_tensorrt.cc
new file mode 100644
index 00000000000..0b4d91be700
--- /dev/null
+++ b/src/executor/onnx_to_tensorrt.cc
@@ -0,0 +1,148 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2018 by Contributors
+ * \file onnx_to_tensorrt.cc
+ * \brief TensorRT integration with the MXNet executor
+ * \author Marek Kolodziej, Clement Fuji Tsang
+ */
+
+#if MXNET_USE_TENSORRT
+
+#include "./onnx_to_tensorrt.h"
+
+#include <onnx/onnx.pb.h>
+
+#include <NvInfer.h>
+#include <google/protobuf/io/coded_stream.h>
+#include <google/protobuf/io/zero_copy_stream_impl.h>
+#include <google/protobuf/text_format.h>
+#include <onnx-tensorrt/NvOnnxParser.h>
+#include <onnx-tensorrt/NvOnnxParserRuntime.h>
+#include <onnx-tensorrt/PluginFactory.hpp>
+#include <onnx-tensorrt/plugin_common.hpp>
+
+using std::cout;
+using std::cerr;
+using std::endl;
+
+namespace onnx_to_tensorrt {
+
+/*!
+ * \brief Deleter functor that releases an object by calling its destroy() method.
+ *        Null pointers are ignored.
+ */
+struct InferDeleter {
+  template<typename T>
+  void operator()(T* obj) const {
+    if (obj == nullptr) {
+      return;
+    }
+    obj->destroy();
+  }
+};
+
+/*!
+ * \brief Wrap a raw pointer in a std::shared_ptr that releases it through InferDeleter.
+ * \param obj raw pointer to take ownership of
+ * \return shared_ptr owning obj
+ * \throws std::runtime_error if obj is null
+ */
+template<typename T>
+inline std::shared_ptr<T> InferObject(T* obj) {
+  if (obj == nullptr) {
+    throw std::runtime_error("Failed to create object");
+  }
+  std::shared_ptr<T> owned(obj, InferDeleter());
+  return owned;
+}
+
+/*!
+ * \brief Format an IR version number as "major.minor.patch".
+ *
+ * The number is split as major * 1000000 + minor * 10000 + patch.
+ * \param ir_version version number to format, defaults to onnx::IR_VERSION
+ */
+std::string onnx_ir_version_string(int64_t ir_version = onnx::IR_VERSION) {
+  const int major_part = ir_version / 1000000;
+  const int minor_part = ir_version % 1000000 / 10000;
+  const int patch_part = ir_version % 10000;
+  std::string version = std::to_string(major_part);
+  version += ".";
+  version += std::to_string(minor_part);
+  version += ".";
+  version += std::to_string(patch_part);
+  return version;
+}
+
+/*!
+ * \brief Print the ONNX IR version and the TensorRT version this file was compiled
+ *        against to stdout.
+ */
+void PrintVersion() {
+  const std::string onnx_version = onnx_ir_version_string(onnx::IR_VERSION);
+  cout << "Parser built against:" << endl;
+  cout << "  ONNX IR version:  " << onnx_version << endl;
+  cout << "  TensorRT version: "
+       << NV_TENSORRT_MAJOR << "." << NV_TENSORRT_MINOR << "." << NV_TENSORRT_PATCH
+       << endl;
+}
+
+/*!
+ * \brief Build a TensorRT engine from a serialized ONNX model.
+ * \param onnx_model serialized ONNX ModelProto
+ * \param max_batch_size forwarded to the builder's setMaxBatchSize
+ * \param max_workspace_size forwarded to the builder's setMaxWorkspaceSize
+ * \param verbosity threshold handed to TRT_Logger; at kINFO or above the offending
+ *        ONNX node is also dumped when parsing fails
+ * \param debug_builder forwarded to the builder's setDebugSync
+ * \return whatever buildCudaEngine returns, as a raw pointer that is not wrapped in any
+ *         owner here. NOTE(review): callers should presumably check for nullptr.
+ * \throws dmlc::Error if protobuf or the ONNX parser rejects the model
+ * \throws std::runtime_error (from InferObject) if a builder/network/parser is null
+ */
+nvinfer1::ICudaEngine* onnxToTrtCtx(
+        const std::string& onnx_model,
+        int32_t max_batch_size,
+        size_t max_workspace_size,
+        nvinfer1::ILogger::Severity verbosity,
+        bool debug_builder) {
+  GOOGLE_PROTOBUF_VERIFY_VERSION;
+
+  TRT_Logger trt_logger(verbosity);
+  auto trt_builder = InferObject(nvinfer1::createInferBuilder(trt_logger));
+  auto trt_network = InferObject(trt_builder->createNetwork());
+  auto trt_parser  = InferObject(nvonnxparser::createParser(
+      *trt_network, trt_logger));
+  ::ONNX_NAMESPACE::ModelProto parsed_model;
+  // We check for a valid parse, but the main effect is the side effect
+  // of populating parsed_model, which is only used for error reporting below.
+  if (!parsed_model.ParseFromString(onnx_model)) {
+    throw dmlc::Error("Could not parse ONNX from string");
+  }
+
+  if ( !trt_parser->parse(onnx_model.c_str(), onnx_model.size()) ) {
+    const bool dump_nodes = static_cast<int>(verbosity) >=
+        static_cast<int>(nvinfer1::ILogger::Severity::kINFO);
+    const int num_nodes = parsed_model.graph().node_size();
+    const int nerror = trt_parser->getNbErrors();
+    for ( int i = 0; i < nerror; ++i ) {
+      nvonnxparser::IParserError const* error = trt_parser->getError(i);
+      const int node_idx = error->node();
+      // -1 means "no node"; any other index outside the parsed graph is skipped as
+      // well, because indexing the protobuf repeated field out of range is invalid.
+      if ( node_idx >= 0 && node_idx < num_nodes ) {
+        ::ONNX_NAMESPACE::NodeProto const& node = parsed_model.graph().node(node_idx);
+        cerr << "While parsing node number " << node_idx
+             << " [" << node.op_type();
+        if ( !node.output().empty() ) {
+          cerr << " -> \"" << node.output(0) << "\"";
+        }
+        cerr << "]:" << endl;
+        if ( dump_nodes ) {
+          cerr << "--- Begin node ---" << endl;
+          cerr << node.DebugString() << endl;
+          cerr << "--- End node ---" << endl;
+        }
+      }
+      cerr << "ERROR: "
+           << error->file() << ":" << error->line()
+           << " In function " << error->func() << ":\n"
+           << "[" << static_cast<int>(error->code()) << "] " << error->desc()
+           << endl;
+    }
+    throw dmlc::Error("Cannot parse ONNX into TensorRT Engine");
+  }
+
+  bool fp16 = trt_builder->platformHasFastFp16();
+
+  trt_builder->setMaxBatchSize(max_batch_size);
+  trt_builder->setMaxWorkspaceSize(max_workspace_size);
+  if ( fp16 && dmlc::GetEnv("MXNET_TENSORRT_USE_FP16_FOR_FP32", false) ) {
+    // Reduced precision is opt-in through the environment variable; report it at
+    // WARNING severity so it is not lost among INFO messages.
+    LOG(WARNING) << "WARNING: TensorRT using fp16 given original MXNet graph in fp32 !!!";
+    trt_builder->setHalf2Mode(true);
+  }
+
+  trt_builder->setDebugSync(debug_builder);
+  nvinfer1::ICudaEngine* trt_engine = trt_builder->buildCudaEngine(*trt_network);
+  return trt_engine;
+}
+
+}  // namespace onnx_to_tensorrt
+
+#endif  // MXNET_USE_TENSORRT
diff --git a/src/executor/onnx_to_tensorrt.h b/src/executor/onnx_to_tensorrt.h
new file mode 100644
index 00000000000..259cfce7c33
--- /dev/null
+++ b/src/executor/onnx_to_tensorrt.h
@@ -0,0 +1,77 @@
+#ifndef MXNET_EXECUTOR_ONNX_TO_TENSORRT_H_
+#define MXNET_EXECUTOR_ONNX_TO_TENSORRT_H_
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2018 by Contributors
+ * \file onnx_to_tensorrt.h
+ * \brief TensorRT integration with the MXNet executor
+ * \author Marek Kolodziej, Clement Fuji Tsang
+ */
+
+#if MXNET_USE_TENSORRT
+
+#include <ctime>
+#include <fstream>
+#include <iostream>
+#include <NvInfer.h>
+#include <sstream>
+#include <string>
+
+#include "../operator/contrib/tensorrt-inl.h"
+
+namespace onnx_to_tensorrt {
+
+/*!
+ * \brief nvinfer1::ILogger implementation that writes UTC-timestamped messages to a
+ *        std::ostream.
+ */
+class TRT_Logger : public nvinfer1::ILogger {
+  nvinfer1::ILogger::Severity _verbosity;
+  std::ostream* _ostream;
+
+ public:
+  /*!
+   * \param verbosity threshold: messages with severity <= verbosity are printed
+   * \param ostream destination stream; only its address is stored, so it has to
+   *        outlive the logger
+   */
+  TRT_Logger(Severity verbosity = Severity::kWARNING,
+             std::ostream& ostream = std::cout)
+      : _verbosity(verbosity), _ostream(&ostream) {}
+
+  /*!
+   * \brief Print msg prefixed with "[<UTC time> <severity>] " if it passes the threshold.
+   * \param severity severity of the message
+   * \param msg message text
+   */
+  void log(Severity severity, const char* msg) override {
+    // Relies on the Severity enum ordering (presumably lower value = more severe in
+    // NvInfer.h -- confirm).
+    if ( severity > _verbosity ) {
+      return;
+    }
+    const std::time_t rawtime = std::time(nullptr);
+    char buf[256] = "";
+    // std::gmtime may return a null pointer; passing that to strftime is undefined
+    // behaviour, so the timestamp is simply left empty in that case.
+    // NOTE(review): std::gmtime uses a shared static buffer; if TensorRT logs from
+    // several threads a re-entrant variant would be needed.
+    const std::tm* utc = std::gmtime(&rawtime);
+    if ( utc != nullptr ) {
+      std::strftime(buf, sizeof(buf), "%Y-%m-%d %H:%M:%S", utc);
+    }
+    const char* sevstr = (severity == Severity::kINTERNAL_ERROR ? "    BUG" :
+                          severity == Severity::kERROR          ? "  ERROR" :
+                          severity == Severity::kWARNING        ? "WARNING" :
+                          severity == Severity::kINFO           ? "   INFO" :
+                          "UNKNOWN");
+    (*_ostream) << "[" << buf << " " << sevstr << "] "
+                << msg
+                << std::endl;
+  }
+};
+
+nvinfer1::ICudaEngine* onnxToTrtCtx(
+        const std::string& onnx_model,
+        int32_t max_batch_size = 32,
+        size_t max_workspace_size = 1L << 30,
+        nvinfer1::ILogger::Severity verbosity = nvinfer1::ILogger::Severity::kWARNING,
+        bool debug_builder = false);
+}  // namespace onnx_to_tensorrt
+
+#endif  // MXNET_USE_TENSORRT
+
+#endif  // MXNET_EXECUTOR_ONNX_TO_TENSORRT_H_
diff --git a/src/executor/tensorrt_pass.cc b/src/executor/tensorrt_pass.cc
new file mode 100644
index 00000000000..b5fc8d15f7a
--- /dev/null
+++ b/src/executor/tensorrt_pass.cc
@@ -0,0 +1,596 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2018 by Contributors
+ * \file tensorrt_pass.cc
+ * \brief Replace TRT compatible subgraphs by TRT engines
+ * \author Clement Fuji Tsang
+ */
+
+#if MXNET_USE_TENSORRT
+
+#include <NvInfer.h>
+#include <mxnet/base.h>
+#include <mxnet/op_attr_types.h>
+#include <mxnet/operator.h>
+#include <nnvm/graph_attr_types.h>
+#include <onnx/onnx.pb.h>
+
+#include "../operator/contrib/nnvm_to_onnx-inl.h"
+#include "./exec_pass.h"
+#include "./onnx_to_tensorrt.h"
+
+namespace mxnet {
+namespace exec {
+
+using NodePtr = nnvm::NodePtr;
+
+/*!
+ * \brief Custom graph class, which will contain bi-directional nodes
+ * we need to compute DFS and reverse DFS for graph partitioning
+ *
+ * NOTE(review): Node objects point at each other through raw pointers into `nodes`,
+ * so a copied BidirectionalGraph would still reference the original's storage.
+ */
+class BidirectionalGraph {
+ public:
+  /*! \brief An nnvm node together with links to its producers and consumers. */
+  struct Node {
+    nnvm::Node* nnvmptr;
+    std::vector<Node*> inputs;
+    std::vector<Node*> outputs;
+  };
+  // Storage of all nodes. Node::inputs/outputs and `outputs` point into this vector,
+  // so it must not be resized after construction.
+  std::vector<Node> nodes;
+  // nnvm node -> index into `nodes`.
+  std::unordered_map<nnvm::Node*, uint32_t> nnvm2nid;
+  // Nodes that produce the outputs of the wrapped graph.
+  std::vector<Node*> outputs;
+  static const std::unordered_set<std::string> unconditionalTRTop;
+
+  /*!
+   * \brief Mirror every node reachable from g.outputs and link producers and consumers.
+   * \param g graph to mirror
+   */
+  explicit BidirectionalGraph(const Graph &g) {
+    auto& idx = g.indexed_graph();
+    auto num_nodes = idx.num_nodes();
+    nodes.reserve(num_nodes);
+    nnvm2nid.reserve(num_nodes);
+    outputs.reserve(idx.outputs().size());
+    DFSVisit(g.outputs, [this](const nnvm::NodePtr& n) {
+      BidirectionalGraph::Node new_node;
+      new_node.nnvmptr = n.get();
+      nnvm2nid[n.get()] = static_cast<uint32_t>(nodes.size());
+      nodes.emplace_back(std::move(new_node));
+    });
+    for (const auto& it : nnvm2nid) {
+      nnvm::Node* nnvmnode = it.first;
+      uint32_t nid = it.second;
+      for (auto& n : nnvmnode->inputs) {
+        // at() instead of operator[]: a lookup must never insert (and possibly
+        // rehash) while nnvm2nid is being iterated.
+        uint32_t input_nid = nnvm2nid.at(n.node.get());
+        nodes[input_nid].outputs.emplace_back(&nodes[nid]);
+        nodes[nid].inputs.emplace_back(&nodes[input_nid]);
+      }
+    }
+    for (auto& e : g.outputs) {
+      uint32_t nid = nnvm2nid.at(e.node.get());
+      outputs.emplace_back(&nodes[nid]);
+    }
+  }
+
+  /*!
+   * \brief Iterative depth-first traversal starting from heads.
+   * \param heads start nodes (they are visited too)
+   * \param reverse follow Node::inputs when true, Node::outputs otherwise
+   * \param fvisit callable invoked once for each reached node
+   */
+  template <typename FVisit>
+  void DFS(const std::vector<Node*>& heads, bool reverse, FVisit fvisit) {
+    std::unordered_set<Node*> visited;
+    std::vector<Node*> vec(heads.begin(), heads.end());
+    visited.reserve(heads.size());
+    while (!vec.empty()) {
+      Node* vertex = vec.back();
+      vec.pop_back();
+      // insert().second is false when the vertex has been seen before.
+      if (visited.insert(vertex).second) {
+        fvisit(vertex);
+        // Bind by reference: copying the neighbour list on every visit is wasted work.
+        const std::vector<Node*>& nexts = reverse ? vertex->inputs : vertex->outputs;
+        for (Node* node : nexts) {
+          if (visited.count(node) == 0) {
+            vec.emplace_back(node);
+          }
+        }
+      }
+    }
+  }
+
+  using t_pairset = std::pair<std::unordered_set<Node*>, std::unordered_set<Node*>>;
+  using t_pairvec = std::pair<std::vector<Node*>, std::vector<Node*>>;
+  using t_uncomp_map = std::unordered_map<Node*, std::unordered_set<Node*>>;
+
+  /*!
+   * \brief Grow a subgraph from head across both inputs and outputs.
+   *
+   * A node is taken when it is still in set_unused and is not listed as incompatible
+   * with any node already taken; taken nodes are removed from set_unused.
+   * \param head seed node
+   * \param set_unused nodes not yet assigned to a subgraph (modified)
+   * \param uncomp_map for each node, the nodes that must not share a subgraph with it
+   * \return the nodes of the grown subgraph
+   */
+  std::unordered_set<Node*> naive_grow_subgraph(Node* head,
+                                                std::unordered_set<Node*>* set_unused,
+                                                t_uncomp_map* uncomp_map) {
+    std::unordered_set<Node*> subgraph;
+    std::unordered_set<Node*> uncomp_set;
+    std::deque<Node*> stack;
+    stack.emplace_back(head);
+    while (!stack.empty()) {
+      Node* vertex = stack.back();
+      stack.pop_back();
+      if (set_unused->count(vertex) && !uncomp_set.count(vertex)) {
+        set_unused->erase(vertex);
+        subgraph.insert(vertex);
+        const std::unordered_set<Node*>& incompatible = (*uncomp_map)[vertex];
+        uncomp_set.insert(incompatible.begin(), incompatible.end());
+        for (Node* input : vertex->inputs) {
+          if (set_unused->count(input) && !uncomp_set.count(input)) {
+            stack.emplace_back(input);
+          }
+        }
+        for (Node* output : vertex->outputs) {
+          if (set_unused->count(output) && !uncomp_set.count(output)) {
+            stack.emplace_back(output);
+          }
+        }
+      }
+    }
+    return subgraph;
+  }
+
+  /*!
+   * \brief Partition the TensorRT-compatible nodes into subgraphs.
+   * \param params_map parameters by name; a node without an operator is only
+   *        considered when its name is a key of this map
+   * \return one node set per subgraph, in the order they are discovered while walking
+   *         from the graph outputs towards the inputs
+   */
+  std::vector<std::unordered_set<Node*>> get_subsets(
+    std::unordered_map<std::string, NDArray>* const params_map) {
+    std::vector<std::unordered_set<Node*>> subgraphs;
+    std::unordered_set<Node*> set_nonTRTnodes;
+    std::unordered_set<Node*> set_allnodes(nodes.size());
+    std::vector<t_pairset> separation_sets;
+    // For every incompatible node record (ancestors, descendants); both sets contain
+    // the node itself because DFS also visits its heads.
+    for (Node& node : nodes) {
+      if (!IsTRTCompatible(node.nnvmptr)) {
+        set_nonTRTnodes.insert(&node);
+        std::unordered_set<Node*> in_graph;
+        std::unordered_set<Node*> out_graph;
+        std::vector<Node*> dummy_head;
+        dummy_head.emplace_back(&node);
+        DFS(dummy_head, false, [&out_graph](Node* reached) {
+          out_graph.insert(reached);
+        });
+        DFS(dummy_head, true, [&in_graph](Node* reached) {
+          in_graph.insert(reached);
+        });
+        separation_sets.emplace_back(std::move(in_graph), std::move(out_graph));
+      }
+      set_allnodes.emplace(&node);
+    }
+    t_uncomp_map uncomp_map;
+    std::unordered_set<Node*> set_TRTnodes;
+    set_TRTnodes.insert(set_allnodes.begin(), set_allnodes.end());
+    for (Node* n : set_nonTRTnodes) {
+      set_TRTnodes.erase(n);
+    }
+    // A compatible node upstream of an incompatible one must not be grouped with
+    // anything downstream of it, and vice versa.
+    for (Node* n : set_TRTnodes) {
+      std::unordered_set<Node*>& incompatible = uncomp_map[n];
+      // Iterate by const reference: copying both sets of every pair for every
+      // compatible node is needlessly expensive.
+      for (const t_pairset& p : separation_sets) {
+        if (p.first.count(n)) {
+          incompatible.insert(p.second.begin(), p.second.end());
+        } else if (p.second.count(n)) {
+          incompatible.insert(p.first.begin(), p.first.end());
+        }
+      }
+      for (Node* nonTRTn : set_nonTRTnodes) {
+        incompatible.erase(nonTRTn);
+      }
+    }
+    std::unordered_set<Node*> set_unused;
+    set_unused.reserve(set_TRTnodes.size());
+
+    for (Node* n : set_TRTnodes) {
+      if (n->nnvmptr->attrs.op != nullptr || params_map->count(n->nnvmptr->attrs.name)) {
+        set_unused.insert(n);
+      }
+    }
+    // Breadth-first walk from the outputs; every still-unused node seeds a subgraph.
+    std::unordered_set<Node*> visited;
+    std::deque<Node*> stack(outputs.begin(), outputs.end());
+    while (!stack.empty()) {
+      Node* vertex = stack.front();
+      stack.pop_front();
+      if (!visited.count(vertex)) {
+        visited.insert(vertex);
+        if (set_unused.count(vertex)) {
+          subgraphs.emplace_back(naive_grow_subgraph(vertex, &set_unused, &uncomp_map));
+        }
+        for (Node* input : vertex->inputs) {
+          stack.emplace_back(input);
+        }
+      }
+    }
+
+    return subgraphs;
+  }
+
+
+ private:
+  friend class Graph;
+
+  /*!
+   * \brief Decide whether a node may be placed into a TensorRT subgraph.
+   *
+   * Nodes without an operator are always accepted. Pooling and Activation are accepted
+   * only for the listed modes; when the deciding attribute is not present in
+   * attrs.dict the node is conservatively rejected instead of throwing
+   * std::out_of_range.
+   * \param nodeptr node to inspect
+   */
+  bool IsTRTCompatible(nnvm::Node* nodeptr) {
+    if (nodeptr->op() == nullptr) {
+      return true;
+    }
+
+    const std::string op_name = nodeptr->op()->name;
+    const auto& dict = nodeptr->attrs.dict;
+    if (op_name == "Pooling") {
+      const auto pool_type = dict.find("pool_type");
+      if (pool_type == dict.end()) {
+        return false;
+      }
+      return pool_type->second == "avg" || pool_type->second == "max";
+    }
+
+    if (unconditionalTRTop.count(op_name)) {
+      return true;
+    }
+
+    if (op_name == "Activation") {
+      const auto act_type = dict.find("act_type");
+      if (act_type == dict.end()) {
+        return false;
+      }
+      return act_type->second == "relu" ||
+        act_type->second == "tanh" ||
+        act_type->second == "sigmoid";
+    }
+
+    return false;
+  }
+};  // class BidirectionalGraph
+
+/*!
+ * \brief Operator names that BidirectionalGraph::IsTRTCompatible accepts without
+ *        looking at the node attributes
+ */
+const std::unordered_set<std::string> BidirectionalGraph::unconditionalTRTop = {
+  "Convolution",
+  "BatchNorm",
+  "elemwise_add",
+  "elemwise_sub",
+  "elemwise_mul",
+  "rsqrt",
+  "pad",
+  "Pad",
+  "mean",
+  "FullyConnected",
+  "Flatten",
+  "SoftmaxOutput",
+};
+
+
+using NodeEntrySet = std::unordered_set<nnvm::NodeEntry, nnvm::NodeEntryHash,
+                                        nnvm::NodeEntryEqual>;
+
+/*!
+ * \brief Collect the entries produced inside the subgraph that are consumed outside of it
+ * \param g graph containing the subgraph
+ * \param set_subgraph nodes that belong to the subgraph
+ * \return every entry of a subgraph node that is either an output of g or an input of a
+ *         node outside the subgraph; each entry appears once
+*/
+std::vector<nnvm::NodeEntry> GetSubgraphNodeEntries(Graph g,
+    std::unordered_set<nnvm::Node*> set_subgraph) {
+  NodeEntrySet boundary;
+  const auto produced_inside = [&set_subgraph](const nnvm::NodeEntry& entry) {
+    return set_subgraph.count(entry.node.get()) != 0;
+  };
+  // Graph outputs that come straight out of the subgraph.
+  for (auto& head : g.outputs) {
+    if (produced_inside(head)) {
+      boundary.insert(head);
+    }
+  }
+  // Inputs of outside nodes that are fed by the subgraph.
+  DFSVisit(g.outputs, [&set_subgraph, &boundary, &produced_inside](const nnvm::NodePtr &node) {
+    if (set_subgraph.count(node.get())) {
+      return;
+    }
+    for (auto& entry : node->inputs) {
+      if (produced_inside(entry)) {
+        boundary.insert(entry);
+      }
+    }
+  });
+  return std::vector<nnvm::NodeEntry>(boundary.begin(), boundary.end());
+}
+
+
+/*!
+ * \brief Collect the entries produced outside the subgraph that are consumed inside of it
+ * \param g graph containing the subgraph
+ * \param set_subgraph nodes that belong to the subgraph
+ * \return every input entry of a subgraph node whose producer is not part of the
+ *         subgraph; each entry appears once
+*/
+std::vector<nnvm::NodeEntry> GetSubgraphInterfaceNodes(Graph g,
+    std::unordered_set<nnvm::Node*> set_subgraph) {
+  NodeEntrySet external_inputs;
+  DFSVisit(g.outputs, [&set_subgraph, &external_inputs](const nnvm::NodePtr &node) {
+    if (!set_subgraph.count(node.get())) {
+      return;
+    }
+    for (auto& entry : node->inputs) {
+      const bool produced_outside = !set_subgraph.count(entry.node.get());
+      if (produced_outside) {
+        external_inputs.insert(entry);
+      }
+    }
+  });
+  return std::vector<nnvm::NodeEntry>(external_inputs.begin(), external_inputs.end());
+}
+
+/*!
+ * \brief Map the node id of every graph input to its position in
+ *        indexed_graph().input_nodes()
+ * \param g graph to inspect
+ * \return node id -> position lookup table
+ */
+std::unordered_map<uint32_t, uint32_t> GetGraphInputsMap(const Graph& g) {
+  auto& indexed = g.indexed_graph();
+  std::unordered_map<uint32_t, uint32_t> node_to_pos;
+  node_to_pos.reserve(indexed.num_nodes());
+  const std::vector<uint32_t>& input_nodes = indexed.input_nodes();
+  uint32_t pos = 0;
+  for (const uint32_t nid : input_nodes) {
+    node_to_pos[nid] = pos;
+    ++pos;
+  }
+  return node_to_pos;
+}
+
+/*!
+ * \brief Create a "_trt_op" node whose attributes carry the ONNX serialization of g
+ * \param g graph handed to op::nnvm_to_onnx::ConvertNnvmGraphToOnnx
+ * \param params_map parameter map forwarded to the same converter
+ * \return the new node, with its attributes parsed if the operator registers a parser
+ */
+nnvm::NodePtr ConvertNnvmGraphToOnnx(const nnvm::Graph &g,
+                                     std::unordered_map<std::string, NDArray>* const params_map) {
+  nnvm::NodePtr trt_node = nnvm::Node::Create();
+  trt_node->attrs.op = nnvm::Op::Get("_trt_op");
+  const op::TRTParam converted = op::nnvm_to_onnx::ConvertNnvmGraphToOnnx(g, params_map);
+  auto& attr_dict = trt_node->attrs.dict;
+  attr_dict["serialized_output_map"] = converted.serialized_output_map;
+  attr_dict["serialized_input_map"]  = converted.serialized_input_map;
+  attr_dict["serialized_onnx_graph"] = converted.serialized_onnx_graph;
+  // Turn the string attributes into the operator's parsed form, when a parser exists.
+  const auto& parser = trt_node->op()->attr_parser;
+  if (parser != nullptr) {
+    parser(&(trt_node->attrs));
+  }
+  return trt_node;
+}
+
+/*!
+ * \brief Copy the inferred attributes (shape, dtype, storage type and their "*_inputs"
+ *        counterparts) of the main graph onto the matching entries of the subgraph
+ * \param subgraph graph receiving the attributes; it is moved into the return value
+ * \param g main graph, which must already carry the "shape", "dtype", "storage_type",
+ *        "shape_inputs", "dtype_inputs" and "storage_type_inputs" attributes
+ * \param old2new map from main graph nodes to their copies in the subgraph
+ * \param main_input_entry_to_sub map from main graph entries feeding the subgraph to the
+ *        placeholder entries that replace them inside the subgraph
+ * \return the subgraph with the six attributes above set
+ */
+Graph UpdateSubgraphAttrs(Graph&& subgraph, const Graph& g,
+                          const std::unordered_map<nnvm::Node*, nnvm::NodePtr>& old2new,
+                          const nnvm::NodeEntryMap<nnvm::NodeEntry>& main_input_entry_to_sub) {
+  const auto& idx     = g.indexed_graph();
+  const auto& sub_idx = subgraph.indexed_graph();
+
+  const auto& shape               = g.GetAttr<nnvm::ShapeVector>("shape");
+  const auto& dtype               = g.GetAttr<nnvm::DTypeVector>("dtype");
+  const auto& storage_type        = g.GetAttr<StorageTypeVector>("storage_type");
+  const auto& shape_inputs        = g.GetAttr<nnvm::ShapeVector>("shape_inputs");
+  const auto& dtype_inputs        = g.GetAttr<nnvm::DTypeVector>("dtype_inputs");
+  const auto& storage_type_inputs = g.GetAttr<StorageTypeVector>("storage_type_inputs");
+
+  nnvm::ShapeVector sub_shape(sub_idx.num_node_entries());
+  nnvm::DTypeVector sub_dtype(sub_idx.num_node_entries());
+  StorageTypeVector sub_storage_type(sub_idx.num_node_entries());
+  nnvm::ShapeVector sub_shape_inputs(sub_idx.input_nodes().size());
+  nnvm::DTypeVector sub_dtype_inputs(sub_idx.input_nodes().size());
+  StorageTypeVector sub_storage_type_inputs(sub_idx.input_nodes().size());
+
+  const std::unordered_map<uint32_t, uint32_t> inputsindex2pos     = GetGraphInputsMap(g);
+  const std::unordered_map<uint32_t, uint32_t> sub_inputsindex2pos = GetGraphInputsMap(subgraph);
+  // map attributes from graph to subgraph
+  for (const auto& p : old2new) {
+    const uint32_t nid     = idx.node_id(p.first);
+    const uint32_t sub_nid = sub_idx.node_id(p.second.get());
+    const nnvm::Node* sub_node = sub_idx[sub_nid].source;
+    if (sub_node->op() == nullptr) {  // if it's an input node, there is only one output node entry
+      const uint32_t sub_i       = sub_idx.entry_id(sub_nid, 0);
+      const uint32_t sub_input_i = sub_inputsindex2pos.at(sub_nid);
+      const uint32_t i           = idx.entry_id(nid, 0);
+      // position of this node among the inputs of the main graph, looked up once
+      const uint32_t input_i     = inputsindex2pos.at(nid);
+
+      sub_shape[sub_i] = shape[i];
+      sub_dtype[sub_i] = dtype[i];
+      sub_storage_type[sub_i]       = storage_type[i];
+      sub_shape_inputs[sub_input_i] = shape_inputs[input_i];
+      sub_dtype_inputs[sub_input_i] = dtype_inputs[input_i];
+      sub_storage_type_inputs[sub_input_i] = storage_type_inputs[input_i];
+    } else {
+      // Node::num_outputs() also covers ops whose output count depends on their attributes
+      // (get_num_outputs), which the static Op::num_outputs field does not reflect
+      const uint32_t num_outputs = sub_node->num_outputs();
+      for (uint32_t oi = 0; oi < num_outputs; ++oi) {
+        const uint32_t sub_i = sub_idx.entry_id(sub_nid, oi);
+        const uint32_t i = idx.entry_id(nid, oi);
+        sub_shape[sub_i] = shape[i];
+        sub_dtype[sub_i] = dtype[i];
+        sub_storage_type[sub_i] = storage_type[i];
+      }
+    }
+  }
+  // old2new doesn't contain placeholder / interfaces
+  for (const auto& p : main_input_entry_to_sub) {
+    const nnvm::NodeEntry& main_entry = p.first;
+    const nnvm::NodeEntry& sub_entry = p.second;
+    const uint32_t sub_nid = sub_idx.node_id(sub_entry.node.get());
+    const uint32_t sub_i = sub_idx.entry_id(sub_entry);
+    const uint32_t i = idx.entry_id(main_entry);
+    const uint32_t sub_input_i = sub_inputsindex2pos.at(sub_nid);
+    sub_shape[sub_i] = shape[i];
+    sub_dtype[sub_i] = dtype[i];
+    sub_storage_type[sub_i] = storage_type[i];
+    // a placeholder is an input of the subgraph: its input attributes are those of the
+    // main graph entry it stands for
+    sub_shape_inputs[sub_input_i] = sub_shape[sub_i];
+    sub_dtype_inputs[sub_input_i] = sub_dtype[sub_i];
+    sub_storage_type_inputs[sub_input_i] = sub_storage_type[sub_i];
+  }
+  subgraph.attrs["shape"] =
+      std::make_shared<dmlc::any>(std::move(sub_shape));
+  subgraph.attrs["dtype"] =
+      std::make_shared<dmlc::any>(std::move(sub_dtype));
+  subgraph.attrs["storage_type"] =
+      std::make_shared<dmlc::any>(std::move(sub_storage_type));
+  subgraph.attrs["shape_inputs"] =
+      std::make_shared<dmlc::any>(std::move(sub_shape_inputs));
+  subgraph.attrs["dtype_inputs"] =
+      std::make_shared<dmlc::any>(std::move(sub_dtype_inputs));
+  subgraph.attrs["storage_type_inputs"] =
+      std::make_shared<dmlc::any>(std::move(sub_storage_type_inputs));
+
+  // subgraph is a named rvalue reference, i.e. an lvalue: without std::move it would be copied
+  return std::move(subgraph);
+}
+
+/*!
+ * \brief Generate a "TRT_node<N>" name for a new TRT node, using the smallest N that does not
+ *        collide with a name remaining in the graph once the subgraph has been replaced
+ * \param g main graph whose node names are scanned
+ * \param subgraph graph about to be replaced; the names of its nodes are considered free
+ * \return the generated name
+ */
+const std::string GetNewTrtName(const Graph& g, const Graph& subgraph) {
+  const std::string name_prefix("TRT_node");
+  const auto has_prefix = [&name_prefix](const std::string& name) {
+    return name.compare(0, name_prefix.size(), name_prefix) == 0;
+  };
+
+  std::unordered_set<std::string> taken;
+  DFSVisit(g.outputs, [&taken, &has_prefix](const nnvm::NodePtr& node) {
+    if (has_prefix(node->attrs.name)) {
+      taken.insert(node->attrs.name);
+    }
+  });
+  // names inside the subgraph will be available as those nodes will be removed
+  DFSVisit(subgraph.outputs, [&taken, &has_prefix](const nnvm::NodePtr& node) {
+    if (has_prefix(node->attrs.name)) {
+      taken.erase(node->attrs.name);
+    }
+  });
+
+  // try TRT_node0, TRT_node1, ... until a free name is found
+  uint32_t next_suffix = 0;
+  std::string candidate;
+  do {
+    candidate = name_prefix + std::to_string(next_suffix++);
+  } while (taken.count(candidate) != 0);
+  return candidate;
+}
+
+/*!
+ * \brief Debug helper: print the name of every node reachable from the graph outputs,
+ *        prefixed with "Y" when the node is in the given set and "N" otherwise
+ * \param g graph to traverse
+ * \param s set of nodes to test membership against
+ */
+void dispNodesSet(Graph g, std::unordered_set<nnvm::Node*> s) {
+  DFSVisit(g.outputs, [&s](const nnvm::NodePtr n){
+    const char* const marker = s.count(n.get()) ? "  Y " : "  N ";
+    std::cout << marker << n->attrs.name << std::endl;
+  });
+}
+
+/*!
+ * \brief Replace a set of nodes by a TensorRT node
+ * \param g main graph; its nodes are rewired in place to consume the new TRT node
+ * \param set_subgraph raw pointers of the main graph nodes to replace
+ * \param params_map parameters forwarded to the ONNX conversion of the subgraph
+ * \return a new graph with the same outputs as g, after rewiring; it carries no attributes
+ */
+Graph ReplaceSubgraph(Graph&& g,
+                      const std::unordered_set<nnvm::Node*>& set_subgraph,
+                      std::unordered_map<std::string, NDArray>* const params_map) {
+  // Create MXNet subgraph
+  Graph subgraph;
+
+  const auto sub_outputs_in_main = GetSubgraphNodeEntries(g, set_subgraph);
+  subgraph.outputs = sub_outputs_in_main;
+  // old2new will link raw pointer of the nodes in the graph to
+  // the corresponding shared_ptr of the nodes in the generated subgraph
+  std::unordered_map<nnvm::Node*, nnvm::NodePtr> old2new;
+  std::deque<nnvm::Node*> stack;
+  std::unordered_set<nnvm::Node*> visited;
+  const size_t reservation = set_subgraph.size();
+  old2new.reserve(reservation);
+  visited.reserve(reservation);
+
+  // Copy every node of the subgraph so that the subgraph owns nodes distinct from the main graph
+  for (auto& n : set_subgraph) {
+    old2new[n] = std::make_shared<nnvm::Node>(*n);
+  }
+
+  // To generate a subgraph, each input has to be replaced by a data node (no op)
+  // and it has to be agnostic to the node from which it's an output
+  // (for example even if two inputs are two different outputs from the same node)
+  nnvm::NodeEntryMap<nnvm::NodeEntry> main_input_entry_to_sub;
+  for (auto& e : GetSubgraphInterfaceNodes(g, set_subgraph)) {
+    auto node = nnvm::Node::Create();
+    node->attrs.name = e.node->attrs.name + "_" + std::to_string(e.index);
+    auto new_e = nnvm::NodeEntry{node, 0, 0};
+    main_input_entry_to_sub[e] = new_e;
+  }
+
+  for (nnvm::NodeEntry& e : subgraph.outputs) {
+    e.node = old2new[e.node.get()];
+    stack.emplace_back(e.node.get());
+  }
+  // link all nodes in the subgraph to nodes in the subgraph instead of main graph
+  while (!stack.empty()) {
+    auto vertex = stack.front();
+    stack.pop_front();
+    if (!visited.count(vertex)) {
+      visited.insert(vertex);
+      for (auto& e : vertex->inputs) {
+        auto it = main_input_entry_to_sub.find(e);
+        if (it != main_input_entry_to_sub.end()) {
+          e = it->second;
+        } else {
+          e.node = old2new[e.node.get()];
+        }
+        stack.emplace_back(e.node.get());
+      }
+    }
+  }
+  // Remove the control dependencies of the subgraph to nodes that are not in the subgraph.
+  // Only the copied nodes are touched: a traversal from the subgraph outputs would follow
+  // the not-yet-remapped control dependencies into the main graph and modify its nodes.
+  for (auto& kv : old2new) {
+    auto& deps = kv.second->control_deps;
+    // std::remove_if only shifts the kept elements to the front; erase() is what drops the rest
+    deps.erase(std::remove_if(deps.begin(), deps.end(),
+                              [&set_subgraph](const nnvm::NodePtr& n_ptr) {
+                                return !set_subgraph.count(n_ptr.get());
+                              }),
+               deps.end());
+    // every remaining dependency is a subgraph node, so it has a copy registered in old2new
+    for (nnvm::NodePtr& n_ptr : deps) {
+      n_ptr = old2new.at(n_ptr.get());
+    }
+  }
+
+  subgraph = UpdateSubgraphAttrs(std::move(subgraph), g, old2new, main_input_entry_to_sub);
+  const auto& sub_idx = subgraph.indexed_graph();
+
+  auto trtnodeptr = ConvertNnvmGraphToOnnx(subgraph, params_map);
+  trtnodeptr->attrs.name = GetNewTrtName(g, subgraph);
+
+  // Insert new trt node and unplug replaced nodes
+  std::unordered_map<uint32_t, nnvm::NodeEntry> sub_input_entryid_to_main;
+  for (auto& p : main_input_entry_to_sub) {
+    sub_input_entryid_to_main[sub_idx.entry_id(p.second)] = p.first;
+  }
+
+  // Plug the nodes from the main graph as inputs of the trt node
+  trtnodeptr->inputs.resize(main_input_entry_to_sub.size());
+  {
+    uint32_t counter = 0;
+    for (uint32_t i : sub_idx.input_nodes()) {
+      auto it = sub_input_entryid_to_main.find(sub_idx.entry_id(i, 0));
+      if (it != sub_input_entryid_to_main.end()) {
+        trtnodeptr->inputs[counter++] = it->second;
+      }
+    }
+  }
+  nnvm::NodeEntryMap<uint32_t> sub_outputs_in_main_to_pos;
+  for (uint32_t i = 0; i < sub_outputs_in_main.size(); ++i) {
+    sub_outputs_in_main_to_pos[sub_outputs_in_main[i]] = i;
+  }
+  // Plug the trt node as inputs to the main graph nodes
+  DFSVisit(g.outputs, [&sub_outputs_in_main_to_pos, &trtnodeptr](const nnvm::NodePtr& n) {
+    for (auto& e : n->inputs) {
+      auto it = sub_outputs_in_main_to_pos.find(e);
+      if (it != sub_outputs_in_main_to_pos.end()) {
+        e.index = it->second;
+        e.node = trtnodeptr;
+      }
+    }
+  });
+
+  // graph outputs produced by the subgraph are now produced by the trt node as well
+  for (auto& output : g.outputs) {
+    auto it = sub_outputs_in_main_to_pos.find(output);
+    if (it != sub_outputs_in_main_to_pos.end()) {
+      output.index = it->second;
+      output.node = trtnodeptr;
+    }
+  }
+
+  Graph new_graph;
+  new_graph.outputs = g.outputs;
+  return new_graph;
+}
+
+/*!
+ * \brief Convert the node subsets returned by BidirectionalGraph::get_subsets into sets of
+ *        raw nnvm::Node pointers
+ * \param g graph wrapped in a BidirectionalGraph
+ * \param params_map parameters forwarded to get_subsets
+ * \return one set of nnvm nodes per subset, in the order returned by get_subsets
+ */
+std::vector<std::unordered_set<nnvm::Node*>> GetTrtCompatibleSubsets(const Graph& g,
+    std::unordered_map<std::string, NDArray>* const params_map) {
+  BidirectionalGraph biG = BidirectionalGraph(g);
+  const std::vector<std::unordered_set<BidirectionalGraph::Node*>> subsets =
+      biG.get_subsets(params_map);
+
+  std::vector<std::unordered_set<nnvm::Node*>> nnvm_subsets;
+  nnvm_subsets.reserve(subsets.size());
+  for (const auto& subset : subsets) {
+    nnvm_subsets.emplace_back();
+    std::unordered_set<nnvm::Node*>& converted = nnvm_subsets.back();
+    converted.reserve(subset.size());
+    for (BidirectionalGraph::Node* bi_node : subset) {
+      converted.insert(bi_node->nnvmptr);
+    }
+  }
+  return nnvm_subsets;
+}
+
+}  // namespace exec
+}  // namespace mxnet
+
+#endif  // MXNET_USE_TENSORRT
diff --git a/src/executor/trt_graph_executor.cc b/src/executor/trt_graph_executor.cc
new file mode 100644
index 00000000000..a73013b49fc
--- /dev/null
+++ b/src/executor/trt_graph_executor.cc
@@ -0,0 +1,445 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#if MXNET_USE_TENSORRT
+
+#include "trt_graph_executor.h"
+
+#include <onnx/onnx.pb.h>
+#include <NvInfer.h>
+#include "./onnx_to_tensorrt.h"
+#include "../operator/contrib/tensorrt-inl.h"
+
+namespace mxnet {
+namespace exec {
+
+/*!
+ * \brief TrtGraphExecutor initializer for simple bind flow in
+ * which only certain input shapes and dtypes are provided by users.
+ * The initializer uses these shapes and dtypes to perform
+ * shape and dtype inferences, and then create NDArrays
+ * to populate data entries of the graph. The created NDArrays
+ * for in_args, arg_grads and aux_states are passed to the
+ * front end to attach the created executor.
+ * In front end, if the simple_bind flow is triggered by
+ * _bind_ith_exec, the shared data arrays of DataParallelExecutorGroup
+ * and shared executor will be taken into account in creating
+ * NDArrays for in_args, arg_grads, and aux_states for reusing
+ * already allocated memory.
+ *
+ * This version of an executor exports the computation graph to TensorRT to make use of fused
+ * kernels and other runtime enhancements.  TRT will compile the sub-graphs to executable fused
+ * operators without intervention from the user.  Operators in the original graph that are not
+ * supported by TRT will continue to be executed normally by MXNet.
+ *
+ * Fails with LOG(FATAL) when gradients are requested or when shared_buffer is null or empty.
+ * The context, grad_req, shape, dtype and stype containers as well as shared_buffer are
+ * taken by pointer because ReinitGraph updates them after each subgraph replacement.
+ */
+void TrtGraphExecutor::Init(nnvm::Symbol symbol,
+                            const Context& default_ctx,
+                            const std::map<std::string, Context>& ctx_map,
+                            std::vector<Context> *in_arg_ctxes,
+                            std::vector<Context> *arg_grad_ctxes,
+                            std::vector<Context> *aux_state_ctxes,
+                            std::unordered_map<std::string, TShape> *arg_shape_map,
+                            std::unordered_map<std::string, int> *arg_dtype_map,
+                            std::unordered_map<std::string, int> *arg_stype_map,
+                            std::vector<OpReqType> *grad_req_types,
+                            const std::unordered_set<std::string>& shared_arg_names,
+                            std::vector<NDArray>* in_arg_vec,
+                            std::vector<NDArray>* arg_grad_vec,
+                            std::vector<NDArray>* aux_state_vec,
+                            std::unordered_map<std::string, NDArray>* shared_buffer,
+                            Executor* shared_exec,
+                            const nnvm::NodeEntryMap<NDArray>& feed_dict) {
+  // work on a copy of the symbol: the subgraph replacement below rewires nodes in place
+  symbol = symbol.Copy();
+  nnvm::Graph g = InitGraph(symbol, default_ctx, ctx_map, *in_arg_ctxes, *arg_grad_ctxes,
+                            *aux_state_ctxes, *grad_req_types);
+
+  if (need_grad_) {
+    LOG(FATAL) << "You may be attempting to use TensorRT for training.  TensorRT is an inference "
+                  "only library.  To re-enable legacy MXNet graph execution, which will support "
+                  "training, set the MXNET_USE_TENSORRT environment variable to 0, or call "
+                  "mx.contrib.tensorrt.set_use_tensorrt(False)";
+  }
+
+  if (shared_buffer == nullptr || shared_buffer->empty()) {
+    LOG(FATAL) << "MXNET_USE_TENSORRT = 1 but shared_buffer is empty. "
+               << "Please provide weights and other parameters, such as "
+               << "BatchNorm moments, via the shared_buffer, during simple bind call.";
+  }
+
+  // The following code of shape and dtype inferences and argument
+  // initialization is for simple_bind only. Regular bind operation
+  // should do this differently.
+
+  // Initialize arg_shapes and arg_dtypes for shape and type inferences.
+  // It contains all in_args and aux_states' shapes and types in a certain order.
+  const nnvm::IndexedGraph& idx = g.indexed_graph();
+  nnvm::ShapeVector arg_shapes(idx.input_nodes().size(), TShape());
+  nnvm::DTypeVector arg_dtypes(idx.input_nodes().size(), -1);
+  StorageTypeVector arg_stypes(idx.input_nodes().size(), kUndefinedStorage);
+  for (size_t i = 0; i < num_forward_inputs_; ++i) {
+    const uint32_t nid = idx.input_nodes().at(i);
+    const std::string& name = idx[nid].source->attrs.name;
+    auto it1 = arg_shape_map->find(name);
+    if (arg_shape_map->end() != it1) {
+      arg_shapes[i] = it1->second;
+    }
+    auto it2 = arg_dtype_map->find(name);
+    if (arg_dtype_map->end() != it2) {
+      arg_dtypes[i] = it2->second;
+    }
+    auto it3 = arg_stype_map->find(name);
+    if (arg_stype_map->end() != it3) {
+      arg_stypes[i] = it3->second;
+    }
+  }
+  g = InferShape(std::move(g), std::move(arg_shapes), "__shape__");
+  if (g.GetAttr<size_t>("shape_num_unknown_nodes") != 0U) {
+    HandleInferShapeError(num_forward_inputs_, g.indexed_graph(),
+                          g.GetAttr<nnvm::ShapeVector>("shape"));
+  }
+
+  g = InferType(std::move(g), std::move(arg_dtypes), "__dtype__");
+  if (g.GetAttr<size_t>("dtype_num_unknown_nodes") != 0U) {
+    HandleInferTypeError(num_forward_inputs_, g.indexed_graph(),
+                         g.GetAttr<nnvm::DTypeVector>("dtype"));
+  }
+
+  g = InferStorageType(std::move(g), std::move(arg_stypes), "__storage_type__");
+  if (g.GetAttr<size_t>("storage_type_num_unknown_nodes") != 0U) {
+    HandleInferStorageTypeError(num_forward_inputs_, g.indexed_graph(),
+                                g.GetAttr<StorageTypeVector>("storage_type"));
+  }
+
+  // Replace each group of more than one node by a single TRT node, then re-run the graph
+  // initialization so that the attributes match the new topology.
+  const auto trt_groups = GetTrtCompatibleSubsets(g, shared_buffer);
+  // iterate by reference: each group is a hash set, copying it per iteration is wasteful
+  for (const auto& trt_group : trt_groups) {
+    if (trt_group.size() > 1) {
+      g = ReplaceSubgraph(std::move(g), trt_group, shared_buffer);
+      g = ReinitGraph(std::move(g), default_ctx, ctx_map, in_arg_ctxes, arg_grad_ctxes,
+                      aux_state_ctxes, grad_req_types, arg_shape_map, arg_dtype_map,
+                      arg_stype_map, shared_buffer);
+    }
+  }
+
+
+  InitArguments(g.indexed_graph(), g.GetAttr<nnvm::ShapeVector>("shape"),
+                g.GetAttr<nnvm::DTypeVector>("dtype"),
+                g.GetAttr<StorageTypeVector>("storage_type"),
+                *in_arg_ctxes, *arg_grad_ctxes, *aux_state_ctxes,
+                *grad_req_types, shared_arg_names, shared_exec,
+                shared_buffer, in_arg_vec, arg_grad_vec, aux_state_vec);
+
+  // The above code of shape and dtype inferences and argument
+  // initialization is for simple_bind only. Regular bind operation
+  // should do this differently.
+
+  // Initialize the rest attributes of the graph.
+  // This function can be called by regular bind
+  // operation flow as well.
+  FinishInitGraph(symbol, g, shared_exec, feed_dict);
+}
+/*!
+ * \brief Initialize in_args, arg_grads, and aux_states
+ * and their data_entry_ of the executor using
+ * shared_buffer from DataParallelExecutorGroup
+ * and shared_exec if available.
+ *
+ * For each of the first num_forward_inputs_ input nodes of idx:
+ *  - a mutable input is treated as an auxiliary state: it is taken from shared_exec when one
+ *    is given, otherwise copied from shared_buffer when an entry of that name exists,
+ *    otherwise zero-initialized;
+ *  - any other input is treated as an argument: names listed in shared_arg_names are taken
+ *    from shared_exec when one is given (zero-initialized otherwise), the remaining names are
+ *    copied from shared_buffer when present (zero-initialized otherwise).
+ * A gradient array is created for an argument unless its grad_req_type is kNullOp.
+ *
+ * Populates data_entry_, in_arg_map_, arg_grad_map_, aux_state_map_ and grad_store_, and
+ * appends to in_arg_vec, arg_grad_vec and aux_state_vec.
+ */
+void TrtGraphExecutor::InitArguments(const nnvm::IndexedGraph& idx,
+                                  const nnvm::ShapeVector& inferred_shapes,
+                                  const nnvm::DTypeVector& inferred_dtypes,
+                                  const StorageTypeVector& inferred_stypes,
+                                  const std::vector<Context>& in_arg_ctxes,
+                                  const std::vector<Context>& arg_grad_ctxes,
+                                  const std::vector<Context>& aux_state_ctxes,
+                                  const std::vector<OpReqType>& grad_req_types,
+                                  const std::unordered_set<std::string>& shared_arg_names,
+                                  const Executor* shared_exec,
+                                  std::unordered_map<std::string, NDArray>* shared_buffer,
+                                  std::vector<NDArray>* in_arg_vec,
+                                  std::vector<NDArray>* arg_grad_vec,
+                                  std::vector<NDArray>* aux_state_vec) {
+  // initialize in_args, arg_grads, and aux_states and populate grad_store_
+  data_entry_.resize(idx.num_node_entries());
+  // arg_top / aux_top: number of arguments / auxiliary states handled so far; they index the
+  // per-argument and per-aux-state context and grad_req vectors
+  size_t arg_top = 0, aux_top = 0;
+  const auto& mutable_nodes = idx.mutable_input_nodes();
+  for (size_t i = 0; i < num_forward_inputs_; ++i) {
+    const uint32_t nid = idx.input_nodes().at(i);
+    const uint32_t eid = idx.entry_id(nid, 0);
+    const TShape& inferred_shape = inferred_shapes[eid];
+    const int inferred_dtype = inferred_dtypes[eid];
+    const NDArrayStorageType inferred_stype = (NDArrayStorageType) inferred_stypes[eid];
+    const std::string& arg_name = idx[nid].source->attrs.name;
+    // aux_states: the inputs listed among the mutable input nodes of the graph
+    if (mutable_nodes.count(nid)) {
+      if (nullptr != shared_exec) {
+        // reuse the array of shared_exec, after checking that it matches the inferred attributes
+        const NDArray& aux_nd = shared_exec->aux_state_map().at(arg_name);
+        CHECK(inferred_stype == kDefaultStorage && aux_nd.storage_type() == kDefaultStorage)
+          << "Non-default storage type detected when creating auxilliary NDArray. The allocated "
+          << "memory of shared_exec.aux_array cannot be resued for argument: "
+          << arg_name << " for the current executor";
+        CHECK_EQ(inferred_shape, aux_nd.shape())
+          << "Inferred shape does not match shared_exec.aux_array's shape."
+             " Therefore, the allocated memory for shared_exec.aux_array cannot"
+             " be resued for creating auxilliary NDArray of the argument: "
+          << arg_name << " for the current executor";
+        CHECK_EQ(inferred_dtype, aux_nd.dtype())
+          << "Inferred dtype does not match shared_exec.aux_array's dtype."
+             " Therefore, the allocated memory for shared_exec.aux_array cannot"
+             " be resued for creating auxilliary NDArray of the argument: "
+          << arg_name << " for the current executor";
+        aux_state_vec->emplace_back(aux_nd);
+      } else {
+        // no shared executor: copy the value from shared_buffer to the aux state context
+        // when an entry of that name exists, otherwise start from zeros
+        auto it = shared_buffer->find(arg_name);
+        if (it != shared_buffer->end()) {
+          aux_state_vec->push_back(std::move(it->second.Copy(aux_state_ctxes[aux_top])));
+        } else {
+          aux_state_vec->push_back(std::move(InitZeros(inferred_stype, inferred_shape,
+                                                       aux_state_ctxes[aux_top], inferred_dtype)));
+        }
+      }  // if (has_shared_exec)
+      data_entry_[eid] = aux_state_vec->back();
+      aux_state_map_.emplace(arg_name, aux_state_vec->back());
+      ++aux_top;
+    } else {  // in_args and grad for in_args
+      if (shared_arg_names.count(arg_name)) {  // model parameter
+        // argument listed in shared_arg_names: take it from shared_exec when one is given
+        if (nullptr != shared_exec) {
+          const NDArray& in_arg_nd = shared_exec->in_arg_map().at(arg_name);
+          auto arg_nd_stype = in_arg_nd.storage_type();
+          // for model parameter, both default storage and row_sparse storage can be shared
+          bool shareable_arg_stype = inferred_stype == kDefaultStorage ||
+                                     inferred_stype == kRowSparseStorage;
+          // try to reuse memory from shared_exec
+          CHECK(shareable_arg_stype) << "Inferred storage type "
+            << common::stype_string(inferred_stype)
+            << " does not support memory sharing with shared_exec.arg_array";
+          CHECK_EQ(inferred_stype, arg_nd_stype)
+            << "Inferred stype does not match shared_exec.arg_array's stype"
+               " Therefore, the allocated memory for shared_exec.arg_array cannot"
+               " be resued for creating NDArray of the argument "
+            << arg_name << " for the current executor";
+          CHECK_EQ(inferred_shape, in_arg_nd.shape())
+            << "Inferred shape does not match shared_exec.arg_array's shape"
+               " Therefore, the allocated memory for shared_exec.arg_array cannot"
+               " be resued for creating NDArray of the argument "
+            << arg_name << " for the current executor";
+          CHECK_EQ(inferred_dtype, in_arg_nd.dtype())
+            << "Inferred dtype does not match shared_exec.arg_array's dtype"
+               " Therefore, the allocated memory for shared_exec.arg_array cannot"
+               " be resued for creating NDArray of the argument "
+            << arg_name << " for the current executor";
+          in_arg_vec->emplace_back(in_arg_nd);
+        } else {
+          // no shared_exec to take the array from: append a zero-initialized one
+          EmplaceBackZeros(inferred_stype, inferred_shape, in_arg_ctxes[arg_top],
+                           inferred_dtype, in_arg_vec);
+        }
+        // gradient for model parameter
+        if (kNullOp == grad_req_types[arg_top]) {
+          // no gradient requested: keep an empty placeholder so indices stay aligned
+          arg_grad_vec->emplace_back();
+        } else {
+          // the gradient entries follow the forward outputs in idx.outputs(); grad_store_
+          // holds one element per gradient created so far
+          auto grad_oid = grad_store_.size() + num_forward_outputs_;
+          auto grad_eid = idx.entry_id(idx.outputs()[grad_oid]);
+          auto grad_stype = (NDArrayStorageType) inferred_stypes[grad_eid];
+          if (nullptr != shared_exec && grad_stype == kDefaultStorage &&
+              shared_exec->arg_grad_map().at(arg_name).storage_type() == kDefaultStorage) {
+            // try to reuse memory from shared_exec
+            arg_grad_vec->emplace_back(shared_exec->arg_grad_map().at(arg_name));
+          } else {
+            // no need to reuse memory from shared_exec for gradient of non-default storage
+            EmplaceBackZeros(grad_stype, inferred_shape, arg_grad_ctxes[arg_top],
+                             inferred_dtype, arg_grad_vec);
+          }
+          grad_store_.emplace_back(grad_req_types[arg_top], arg_grad_vec->back());
+        }
+      } else {  // !shared_arg_names.count(arg_name)
+        // argument not listed in shared_arg_names: copy its value from shared_buffer to the
+        // argument context when an entry of that name exists, otherwise start from zeros
+        auto it = shared_buffer->find(arg_name);
+        if (it != shared_buffer->end()) {
+          in_arg_vec->push_back(std::move(it->second.Copy(in_arg_ctxes[arg_top])));
+        } else {
+          in_arg_vec->push_back(std::move(InitZeros(inferred_stype, inferred_shape,
+                                                    in_arg_ctxes[arg_top], inferred_dtype)));
+        }
+        // gradient for this argument, obtained through ReshapeOrCreate on shared_buffer
+        // with row_sparse ndarray sharing disabled
+        if (kNullOp == grad_req_types[arg_top]) {
+          arg_grad_vec->emplace_back();
+        } else {
+          auto grad_oid = grad_store_.size() + num_forward_outputs_;
+          auto grad_eid = idx.entry_id(idx.outputs()[grad_oid]);
+          auto grad_stype = (NDArrayStorageType) inferred_stypes[grad_eid];
+          bool enable_row_sparse_sharing = false;
+          arg_grad_vec->emplace_back(ReshapeOrCreate("grad of " + arg_name, inferred_shape,
+                                                     inferred_dtype, grad_stype,
+                                                     arg_grad_ctxes[arg_top], shared_buffer,
+                                                     enable_row_sparse_sharing));
+          grad_store_.emplace_back(grad_req_types[arg_top], arg_grad_vec->back());
+        }  // if (kNullOp == grad_req_types[arg_top])
+      }  // if (shared_arg_names.count(arg_name))
+      in_arg_map_.emplace(arg_name, in_arg_vec->back());
+      // only arguments that actually received a gradient array are registered
+      if (!arg_grad_vec->back().is_none()) {
+        arg_grad_map_.emplace(arg_name, arg_grad_vec->back());
+      }
+      data_entry_[eid] = in_arg_vec->back();
+      ++arg_top;
+    }
+  }
+}
+
+
+/*!
+ * \brief This function is triggered after each tensorrt subgraph replacement pass.
+ * It resets the arguments of GraphExecutor::Init(...), as some variables (weights and biases)
+ * are absorbed into the TRT engine, and it reruns the attribute inferences according
+ * to the new topology.
+ *
+ * Every entry of params_map whose name no longer appears in the graph is erased from
+ * params_map and from the shape, dtype and stype maps. The context and grad_req vectors
+ * are resized to the new number of arguments and auxiliary states, and
+ * num_forward_inputs_ and num_forward_nodes_ are recomputed.
+ *
+ * \return the graph with context assigned and shape, dtype and storage type inferred
+ */
+Graph TrtGraphExecutor::ReinitGraph(Graph&& g, const Context &default_ctx,
+                                 const std::map<std::string, Context> &ctx_map,
+                                 std::vector<Context> *in_arg_ctxes,
+                                 std::vector<Context> *arg_grad_ctxes,
+                                 std::vector<Context> *aux_state_ctxes,
+                                 std::vector<OpReqType> *grad_req_types,
+                                 std::unordered_map<std::string, TShape> *arg_shape_map,
+                                 std::unordered_map<std::string, int> *arg_dtype_map,
+                                 std::unordered_map<std::string, int> *arg_stype_map,
+                                 std::unordered_map<std::string, NDArray> *params_map) {
+  // start from every parameter name, then keep only those no longer present in the graph
+  std::unordered_set<std::string> to_remove_params;
+  for (auto& el : *params_map) {
+    to_remove_params.insert(el.first);
+  }
+
+  DFSVisit(g.outputs, [&to_remove_params](const nnvm::NodePtr& n) {
+    to_remove_params.erase(n->attrs.name);
+  });
+
+  for (auto& el : to_remove_params) {
+    params_map->erase(el);
+    arg_shape_map->erase(el);
+    arg_dtype_map->erase(el);
+    arg_stype_map->erase(el);
+  }
+  {
+    // This index is scoped on purpose: g is reassigned right below, and a reference into the
+    // index of the previous graph must not be used past that point.
+    const auto& pre_idx = g.indexed_graph();
+    const size_t num_aux_states = pre_idx.mutable_input_nodes().size();
+    num_forward_inputs_ = pre_idx.input_nodes().size();
+    in_arg_ctxes->resize(num_forward_inputs_ - num_aux_states);
+    arg_grad_ctxes->resize(num_forward_inputs_ - num_aux_states);
+    grad_req_types->resize(num_forward_inputs_ - num_aux_states);
+    aux_state_ctxes->resize(num_aux_states);
+  }
+
+  // create "device" and "context" attrs for the graph
+  g = AssignContext(g, default_ctx, ctx_map, *in_arg_ctxes, *arg_grad_ctxes,
+                    *aux_state_ctxes, *grad_req_types, num_forward_inputs_,
+                    num_forward_outputs_);
+
+  // fetch the index from the graph returned by AssignContext, which may differ from the
+  // one queried above
+  const auto& idx = g.indexed_graph();
+
+  // get number of nodes used in forward pass
+  num_forward_nodes_ = 0;
+  for (size_t i = 0; i < num_forward_outputs_; ++i) {
+    num_forward_nodes_ = std::max(
+        num_forward_nodes_, static_cast<size_t>(idx.outputs()[i].node_id + 1));
+  }
+  // seed the inferences with the attributes still known for the remaining inputs
+  nnvm::ShapeVector arg_shapes(idx.input_nodes().size(), TShape());
+  nnvm::DTypeVector arg_dtypes(idx.input_nodes().size(), -1);
+  StorageTypeVector arg_stypes(idx.input_nodes().size(), kUndefinedStorage);
+  for (size_t i = 0; i < num_forward_inputs_; ++i) {
+    const uint32_t nid = idx.input_nodes().at(i);
+    const std::string &name = idx[nid].source->attrs.name;
+    auto it1 = arg_shape_map->find(name);
+    if (arg_shape_map->end() != it1) {
+      arg_shapes[i] = it1->second;
+    }
+    auto it2 = arg_dtype_map->find(name);
+    if (arg_dtype_map->end() != it2) {
+      arg_dtypes[i] = it2->second;
+    }
+    auto it3 = arg_stype_map->find(name);
+    if (arg_stype_map->end() != it3) {
+      arg_stypes[i] = it3->second;
+    }
+  }
+  g = InferShape(std::move(g), std::move(arg_shapes), "__shape__");
+  if (g.GetAttr<size_t>("shape_num_unknown_nodes") != 0U) {
+    HandleInferShapeError(num_forward_inputs_, g.indexed_graph(),
+                          g.GetAttr<nnvm::ShapeVector>("shape"));
+  }
+
+  g = InferType(std::move(g), std::move(arg_dtypes), "__dtype__");
+  if (g.GetAttr<size_t>("dtype_num_unknown_nodes") != 0U) {
+    HandleInferTypeError(num_forward_inputs_, g.indexed_graph(),
+                         g.GetAttr<nnvm::DTypeVector>("dtype"));
+  }
+
+  g = InferStorageType(std::move(g), std::move(arg_stypes), "__storage_type__");
+
+  if (g.GetAttr<size_t>("storage_type_num_unknown_nodes") != 0U) {
+    HandleInferStorageTypeError(num_forward_inputs_, g.indexed_graph(),
+                                g.GetAttr<StorageTypeVector>("storage_type"));
+  }
+
+  return g;
+}
+
+
+/*!
+ * \brief Return a copy of the forward part of the executor graph as a symbol, i.e. the
+ * "optimized" symbol produced by passes such as the TensorRT pass.
+ * The attribute dictionary of every "_trt_op" node in the copy is cleared.
+ */
+nnvm::Symbol TrtGraphExecutor::GetOptimizedSymbol() {
+  // keep only the first num_forward_outputs_ outputs of the graph
+  const auto first_output = graph_.outputs.begin();
+  Symbol forward;
+  forward.outputs.assign(first_output, first_output + num_forward_outputs_);
+  Symbol ret = forward.Copy();
+
+  static const Op* trt_op = Op::Get("_trt_op");
+  DFSVisit(ret.outputs, [](const nnvm::NodePtr n) {
+    if (n->op() != trt_op) {
+      return;
+    }
+    n->attrs.dict.clear();
+  });
+  return ret;
+}
+
+/*!
+ * \brief Create a TrtGraphExecutor and initialize it through TrtGraphExecutor::Init
+ * All arguments are forwarded to Init; param_names is passed as its shared_arg_names.
+ * \return the initialized executor, owned by the caller
+ */
+Executor *TrtGraphExecutor::TensorRTBind(nnvm::Symbol symbol,
+                                         const Context &default_ctx,
+                                         const std::map<std::string, Context> &group2ctx,
+                                         std::vector<Context> *in_arg_ctxes,
+                                         std::vector<Context> *arg_grad_ctxes,
+                                         std::vector<Context> *aux_state_ctxes,
+                                         std::unordered_map<std::string, TShape> *arg_shape_map,
+                                         std::unordered_map<std::string, int> *arg_dtype_map,
+                                         std::unordered_map<std::string, int> *arg_stype_map,
+                                         std::vector<OpReqType> *grad_req_types,
+                                         const std::unordered_set<std::string> &param_names,
+                                         std::vector<NDArray> *in_args,
+                                         std::vector<NDArray> *arg_grads,
+                                         std::vector<NDArray> *aux_states,
+                                         std::unordered_map<std::string, NDArray> *shared_buffer,
+                                         Executor *shared_exec) {
+  auto exec = new exec::TrtGraphExecutor();
+  try {
+    exec->Init(symbol, default_ctx, group2ctx,
+               in_arg_ctxes, arg_grad_ctxes, aux_state_ctxes,
+               arg_shape_map, arg_dtype_map, arg_stype_map,
+               grad_req_types, param_names,
+               in_args, arg_grads, aux_states,
+               shared_buffer, shared_exec);
+  } catch (...) {
+    // Init reports failures through exceptions (e.g. LOG(FATAL)); do not leak the executor
+    delete exec;
+    throw;
+  }
+  return exec;
+}
+
+}  // namespace exec
+
+}  // namespace mxnet
+
+#endif  // MXNET_USE_TENSORRT
diff --git a/src/executor/trt_graph_executor.h b/src/executor/trt_graph_executor.h
new file mode 100644
index 00000000000..96ac4426270
--- /dev/null
+++ b/src/executor/trt_graph_executor.h
@@ -0,0 +1,111 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#ifndef MXNET_EXECUTOR_TRT_GRAPH_EXECUTOR_H_
+#define MXNET_EXECUTOR_TRT_GRAPH_EXECUTOR_H_
+
+#if MXNET_USE_TENSORRT
+
+#include <map>
+#include <string>
+#include <vector>
+
+#include "./graph_executor.h"
+
+namespace mxnet {
+
+namespace exec {
+
+/*!
+ * \brief GraphExecutor subclass used when MXNet is built with TensorRT support.
+ */
+class TrtGraphExecutor : public GraphExecutor {
+ public:
+  /*!
+   * \brief Allocate a TrtGraphExecutor and initialize it through Init().
+   *
+   * The arguments are forwarded to Init() in the same order.
+   * \return the newly allocated executor.
+   */
+  static Executor* TensorRTBind(
+      nnvm::Symbol symbol,
+      const Context& default_ctx,
+      const std::map<std::string, Context>& group2ctx,
+      std::vector<Context>* in_arg_ctxes,
+      std::vector<Context>* arg_grad_ctxes,
+      std::vector<Context>* aux_state_ctxes,
+      std::unordered_map<std::string, TShape>* arg_shape_map,
+      std::unordered_map<std::string, int>* arg_dtype_map,
+      std::unordered_map<std::string, int>* arg_stype_map,
+      std::vector<OpReqType>* grad_req_types,
+      const std::unordered_set<std::string>& param_names,
+      std::vector<NDArray>* in_args,
+      std::vector<NDArray>* arg_grads,
+      std::vector<NDArray>* aux_states,
+      std::unordered_map<std::string, NDArray>* shared_data_arrays = nullptr,
+      Executor* shared_exec = nullptr);
+
+  /*!
+   * \brief Initialization entry point invoked by TensorRTBind().
+   */
+  virtual void Init(
+      nnvm::Symbol symbol,
+      const Context& default_ctx,
+      const std::map<std::string, Context>& ctx_map,
+      std::vector<Context>* in_arg_ctxes,
+      std::vector<Context>* arg_grad_ctxes,
+      std::vector<Context>* aux_state_ctxes,
+      std::unordered_map<std::string, TShape>* arg_shape_map,
+      std::unordered_map<std::string, int>* arg_dtype_map,
+      std::unordered_map<std::string, int>* arg_stype_map,
+      std::vector<OpReqType>* grad_req_types,
+      const std::unordered_set<std::string>& shared_arg_names,
+      std::vector<NDArray>* in_arg_vec,
+      std::vector<NDArray>* arg_grad_vec,
+      std::vector<NDArray>* aux_state_vec,
+      std::unordered_map<std::string, NDArray>* shared_buffer = nullptr,
+      Executor* shared_exec = nullptr,
+      const nnvm::NodeEntryMap<NDArray>& feed_dict = nnvm::NodeEntryMap<NDArray>());
+
+  /*! \brief Symbol of the TRT-optimized graph, exposed for comparison purposes. */
+  nnvm::Symbol GetOptimizedSymbol();
+
+ protected:
+  // NOTE(review): the definition is not part of this header; judging by the name it rebuilds
+  // the graph attributes after the graph has been rewritten -- confirm in the .cc file.
+  Graph ReinitGraph(
+      Graph&& g,
+      const Context& default_ctx,
+      const std::map<std::string, Context>& ctx_map,
+      std::vector<Context>* in_arg_ctxes,
+      std::vector<Context>* arg_grad_ctxes,
+      std::vector<Context>* aux_state_ctxes,
+      std::vector<OpReqType>* grad_req_types,
+      std::unordered_map<std::string, TShape>* arg_shape_map,
+      std::unordered_map<std::string, int>* arg_dtype_map,
+      std::unordered_map<std::string, int>* arg_stype_map,
+      std::unordered_map<std::string, NDArray>* params_map);
+
+  /*! \brief TensorRT-specific replacement for the base class InitArguments(). */
+  void InitArguments(
+      const nnvm::IndexedGraph& idx,
+      const nnvm::ShapeVector& inferred_shapes,
+      const nnvm::DTypeVector& inferred_dtypes,
+      const StorageTypeVector& inferred_stypes,
+      const std::vector<Context>& in_arg_ctxes,
+      const std::vector<Context>& arg_grad_ctxes,
+      const std::vector<Context>& aux_state_ctxes,
+      const std::vector<OpReqType>& grad_req_types,
+      const std::unordered_set<std::string>& shared_arg_names,
+      const Executor* shared_exec,
+      std::unordered_map<std::string, NDArray>* shared_buffer,
+      std::vector<NDArray>* in_arg_vec,
+      std::vector<NDArray>* arg_grad_vec,
+      std::vector<NDArray>* aux_state_vec) override;
+};
+
+}  // namespace exec
+
+}  // namespace mxnet
+
+#endif  // MXNET_USE_TENSORRT
+
+#endif  // MXNET_EXECUTOR_TRT_GRAPH_EXECUTOR_H_
diff --git a/src/operator/contrib/nnvm_to_onnx-inl.h b/src/operator/contrib/nnvm_to_onnx-inl.h
new file mode 100644
index 00000000000..58f88b05143
--- /dev/null
+++ b/src/operator/contrib/nnvm_to_onnx-inl.h
@@ -0,0 +1,156 @@
+#ifndef MXNET_OPERATOR_CONTRIB_NNVM_TO_ONNX_INL_H_
+#define MXNET_OPERATOR_CONTRIB_NNVM_TO_ONNX_INL_H_
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2018 by Contributors
+ * \file nnvm_to_onnx-inl.h
+ * \brief Declarations for converting an NNVM graph to ONNX (used by the TensorRT operator)
+ * \author Marek Kolodziej, Clement Fuji Tsang
+*/
+
+#if MXNET_USE_TENSORRT
+
+#include <dmlc/logging.h>
+#include <dmlc/memory_io.h>
+#include <dmlc/serializer.h>
+#include <dmlc/parameter.h>
+#include <mxnet/base.h>
+#include <mxnet/operator.h>
+#include <nnvm/graph.h>
+#include <nnvm/pass_functions.h>
+
+#include <NvInfer.h>
+#include <onnx/onnx.pb.h>
+
+#include <algorithm>
+#include <iostream>
+#include <map>
+#include <vector>
+#include <tuple>
+#include <unordered_map>
+#include <utility>
+#include <string>
+
+#include "./tensorrt-inl.h"
+#include "../operator_common.h"
+#include "../../common/utils.h"
+#include "../../common/serialization.h"
+
+namespace mxnet {
+namespace op {
+namespace nnvm_to_onnx {
+
+using namespace nnvm;
+using namespace ::onnx;
+using int64 = ::google::protobuf::int64;
+
+/*!
+ * \brief Collect the shapes of the graph inputs, keyed by input node name.
+ * \param shape_inputs shapes, indexed in the same order as ig.input_nodes().
+ * \param ig indexed graph used to resolve the input node names.
+ * \return name -> shape map; only shapes with ndim() > 0 are included.
+ */
+std::unordered_map<std::string, TShape> GetPlaceholderShapes(
+    const ShapeVector& shape_inputs,
+    const nnvm::IndexedGraph& ig);
+
+/*!
+ * \brief Map the name of each node that produces a graph output to its position in
+ *        ig.outputs().
+ */
+std::unordered_map<std::string, uint32_t> GetOutputLookup(
+    const nnvm::IndexedGraph& ig);
+
+/*!
+ * \brief Add a float graph input called node_name, with the shape stored in
+ *        placeholder_shapes, to graph_proto.
+ */
+void ConvertPlaceholder(
+    const std::string& node_name,
+    const std::unordered_map<std::string, TShape>& placeholder_shapes,
+    GraphProto* const graph_proto);
+
+/*!
+ * \brief Add an ONNX "Constant" node called node_name to graph_proto, holding the float data
+ *        of the NDArray stored under that name in shared_buffer.
+ */
+void ConvertConstant(
+    GraphProto* const graph_proto,
+    const std::string& node_name,
+    std::unordered_map<std::string, NDArray>* const shared_buffer);
+
+/*!
+ * \brief Register node_name as a graph output, both in trt_output_map and as a float output
+ *        of graph_proto.
+ * \param out_iter entry of the map returned by GetOutputLookup() for this node.
+ */
+void ConvertOutput(
+    op::tensorrt::InferenceMap_t* const trt_output_map,
+    GraphProto* const graph_proto,
+    const std::unordered_map<std::string, uint32_t>::iterator& out_iter,
+    const std::string& node_name,
+    const nnvm::Graph& g,
+    const StorageTypeVector& storage_types,
+    const DTypeVector& dtypes);
+
+/*!
+ * \brief Signature shared by all per-operator converters: fill node_proto (ONNX op type and
+ *        attributes) from the attributes of one NNVM operator node.
+ */
+using ConverterFunction = void (*)(
+    NodeProto* node_proto,
+    const NodeAttrs& attrs,
+    const nnvm::IndexedGraph& ig,
+    const array_view<IndexedGraph::NodeEntry>& inputs);
+
+// Forward declarations of the per-operator converters registered in converter_map.
+
+void ConvertConvolution(
+    NodeProto* node_proto,
+    const NodeAttrs& attrs,
+    const nnvm::IndexedGraph& ig,
+    const array_view<IndexedGraph::NodeEntry>& inputs);
+
+void ConvertPooling(
+    NodeProto* node_proto,
+    const NodeAttrs& attrs,
+    const nnvm::IndexedGraph& ig,
+    const array_view<IndexedGraph::NodeEntry>& inputs);
+
+void ConvertActivation(
+    NodeProto* node_proto,
+    const NodeAttrs& attrs,
+    const nnvm::IndexedGraph& ig,
+    const array_view<IndexedGraph::NodeEntry>& inputs);
+
+void ConvertFullyConnected(
+    NodeProto* node_proto,
+    const NodeAttrs& attrs,
+    const nnvm::IndexedGraph& ig,
+    const array_view<IndexedGraph::NodeEntry>& inputs);
+
+void ConvertSoftmaxOutput(
+    NodeProto* node_proto,
+    const NodeAttrs& attrs,
+    const nnvm::IndexedGraph& ig,
+    const array_view<IndexedGraph::NodeEntry>& inputs);
+
+void ConvertFlatten(
+    NodeProto* node_proto,
+    const NodeAttrs& attrs,
+    const nnvm::IndexedGraph& ig,
+    const array_view<IndexedGraph::NodeEntry>& inputs);
+
+void ConvertBatchNorm(
+    NodeProto* node_proto,
+    const NodeAttrs& attrs,
+    const nnvm::IndexedGraph& ig,
+    const array_view<IndexedGraph::NodeEntry>& inputs);
+
+void ConvertElementwiseAdd(
+    NodeProto* node_proto,
+    const NodeAttrs& attrs,
+    const nnvm::IndexedGraph& ig,
+    const array_view<IndexedGraph::NodeEntry>& inputs);
+
+/*!
+ * \brief Convert an NNVM graph to ONNX.
+ * \param g graph to convert.
+ * \param shared_buffer name -> NDArray map; variables found in it are exported as constants.
+ * \return TRTParam carrying the serialized ONNX graph and the serialized input/output maps.
+ */
+TRTParam ConvertNnvmGraphToOnnx(
+    const nnvm::Graph& g,
+    std::unordered_map<std::string, NDArray>* const shared_buffer);
+
+/*!
+ * \brief Dispatch table from MXNet operator name to its ONNX converter.
+ *
+ * Operators without an entry here cannot be converted.
+ */
+static const std::unordered_map<std::string, ConverterFunction> converter_map = {
+  {"Convolution",    ConvertConvolution},
+  {"Pooling",        ConvertPooling},
+  {"Activation",     ConvertActivation},
+  {"FullyConnected", ConvertFullyConnected},
+  {"SoftmaxOutput",  ConvertSoftmaxOutput},
+  {"Flatten",        ConvertFlatten},
+  {"BatchNorm",      ConvertBatchNorm},
+  {"elemwise_add",   ConvertElementwiseAdd}
+};
+
+}  // namespace nnvm_to_onnx
+}  // namespace op
+}  // namespace mxnet
+
+#endif  // MXNET_USE_TENSORRT
+
+#endif  // MXNET_OPERATOR_CONTRIB_NNVM_TO_ONNX_INL_H_
diff --git a/src/operator/contrib/nnvm_to_onnx.cc b/src/operator/contrib/nnvm_to_onnx.cc
new file mode 100644
index 00000000000..902466614c7
--- /dev/null
+++ b/src/operator/contrib/nnvm_to_onnx.cc
@@ -0,0 +1,527 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2018 by Contributors
+ * \file nnvm_to_onnx.cc
+ * \brief Conversion of an NNVM graph to ONNX for consumption by TensorRT
+ * \author Marek Kolodziej, Clement Fuji Tsang
+*/
+
+#if MXNET_USE_TENSORRT
+
+#include "./nnvm_to_onnx-inl.h"
+
+#include <mxnet/base.h>
+#include <nnvm/graph.h>
+#include <nnvm/pass_functions.h>
+
+#include <algorithm>
+#include <fstream>
+#include <iostream>
+#include <unordered_map>
+#include <vector>
+
+#include "../../common/serialization.h"
+#include "../../common/utils.h"
+#include "../../ndarray/ndarray_function.h"
+#include "../../operator/nn/activation-inl.h"
+#include "../../operator/nn/batch_norm-inl.h"
+#include "../../operator/nn/convolution-inl.h"
+#include "../../operator/nn/fully_connected-inl.h"
+#include "../../operator/nn/pooling-inl.h"
+#include "../../operator/softmax_output-inl.h"
+#include "./tensorrt-inl.h"
+
+#if MXNET_USE_TENSORRT_ONNX_CHECKER
+#include <onnx/checker.h>
+#endif  // MXNET_USE_TENSORRT_ONNX_CHECKER
+
+namespace mxnet {
+namespace op {
+namespace nnvm_to_onnx {
+
+/*!
+ * \brief Convert an NNVM graph into a serialized ONNX model plus the input/output lookup maps.
+ *
+ * Variables whose name is present in shared_buffer are exported as ONNX Constant nodes. Every
+ * other variable becomes a graph input, except variables whose name contains "label", which
+ * are skipped. Operator nodes are translated through converter_map.
+ *
+ * \param g graph carrying the "storage_type", "dtype", "shape_inputs" and "shape" attributes.
+ * \param shared_buffer name -> NDArray map holding the values of the constants (weights).
+ * \return TRTParam with the serialized ONNX graph and the serialized input and output maps.
+ */
+op::TRTParam ConvertNnvmGraphToOnnx(
+    const nnvm::Graph& g,
+    std::unordered_map<std::string, NDArray>* const shared_buffer) {
+  op::TRTParam trt_param;
+  op::tensorrt::NameToIdx_t trt_input_map;
+  op::tensorrt::InferenceMap_t trt_output_map;
+
+  const nnvm::IndexedGraph& ig = g.indexed_graph();
+  const auto& storage_types = g.GetAttr<StorageTypeVector>("storage_type");
+  const auto& dtypes = g.GetAttr<DTypeVector>("dtype");
+  const auto& shape_inputs = g.GetAttr<ShapeVector>("shape_inputs");
+
+  // Every tensor is exported as ONNX FLOAT, so all dtypes must be float32. This loop used to
+  // iterate over the storage types, which meant the dtypes were never actually validated.
+  for (const auto& dtype : dtypes) {
+    if (dtype != mshadow::kFloat32) {
+      LOG(FATAL) << "ONNX converter does not support types other than float32 "
+                    "right now.";
+    }
+  }
+
+  // Keep rejecting non-default storage, which the old loop did as a side effect of comparing
+  // storage types against a dtype enumerator.
+  for (const auto& stype : storage_types) {
+    if (stype != kDefaultStorage) {
+      LOG(FATAL) << "ONNX converter does not support storage types other than default "
+                    "right now.";
+    }
+  }
+
+  ModelProto model_proto;
+  // Need to determine IR versions and features to support
+  model_proto.set_ir_version(static_cast<int64>(2));
+  GraphProto* graph_proto = model_proto.mutable_graph();
+
+  std::unordered_map<std::string, TShape> placeholder_shapes =
+      GetPlaceholderShapes(shape_inputs, ig);
+  std::unordered_map<std::string, uint32_t> output_lookup = GetOutputLookup(ig);
+  // Running index over placeholders. Skipped label inputs still consume an index.
+  uint32_t current_input = 0;
+
+  // Can't do a foreach over IndexedGraph since it doesn't implement begin(), etc.
+  for (uint32_t node_idx = 0; node_idx < ig.num_nodes(); ++node_idx) {
+    const IndexedGraph::Node& node = ig[node_idx];
+    const nnvm::Node* source = node.source;
+    const NodeAttrs& attrs = source->attrs;
+    const Op* op = source->op();
+
+    const std::string& node_name = attrs.name;
+    // Here, "variable" actually means anything that's not an op i.e. a constant (weights) or a
+    // placeholder
+    if (source->is_variable()) {
+      // Is this a placeholder?
+      if (shared_buffer->count(node_name) == 0) {
+        // This fixes the problem with a SoftmaxOutput node during inference, but it's hacky.
+        // Need to figure out how to properly fix it.
+        if (node_name.find("label") != std::string::npos) {
+          current_input++;
+          continue;
+        }
+        trt_input_map.emplace(node_name, current_input++);
+        ConvertPlaceholder(node_name, placeholder_shapes, graph_proto);
+      } else {
+        // If it's not a placeholder, then by exclusion it's a constant.
+        ConvertConstant(graph_proto, node_name, shared_buffer);
+      }  // is_placeholder
+    } else {
+      // It's an op, rather than a "variable" (constant or placeholder)
+      NodeProto* node_proto = graph_proto->add_node();
+      node_proto->set_name(node_name);
+      // Look the converter up once. Unsupported operators abort here, so the iterator is
+      // always valid by the time it is dereferenced below.
+      const auto converter_iter = converter_map.find(op->name);
+      if (converter_iter == converter_map.end()) {
+        LOG(FATAL) << "Conversion for node of type " << op->name << " (node "
+                   << node_name << ") "
+                   << " is not supported yet.";
+      }
+      converter_iter->second(node_proto, attrs, ig, node.inputs);
+      // Add all inputs to the current node (i.e. add graph edges)
+      for (const nnvm::IndexedGraph::NodeEntry& entry : node.inputs) {
+        const std::string& in_node_name = ig[entry.node_id].source->attrs.name;
+        // As before, we're not adding labels e.g. for SoftmaxOutput, but I wish there was a less
+        // hacky way to do it than name matching.
+        if (in_node_name.find("label") != std::string::npos) {
+          continue;
+        }
+        node_proto->add_input(in_node_name);
+      }
+      // The node's output will have the same name as the node name.
+      node_proto->add_output(node_name);
+      // See if the current node is an output node
+      auto out_iter = output_lookup.find(node_name);
+      // We found an output
+      if (out_iter != output_lookup.end()) {
+        ConvertOutput(&trt_output_map, graph_proto, out_iter, node_name, g,
+                      storage_types, dtypes);
+      }  // output found
+    }    // variable vs. op
+  }      // loop over node_idx from 0 to num_nodes
+
+  model_proto.SerializeToString(&trt_param.serialized_onnx_graph);
+  common::Serialize<op::tensorrt::NameToIdx_t>(trt_input_map,
+                                               &trt_param.serialized_input_map);
+  common::Serialize<op::tensorrt::InferenceMap_t>(trt_output_map,
+                                                  &trt_param.serialized_output_map);
+
+#if MXNET_USE_TENSORRT_ONNX_CHECKER
+  onnx::checker::check_model(model_proto);
+#endif  // MXNET_USE_TENSORRT_ONNX_CHECKER
+
+  return trt_param;
+}
+
+/*!
+ * \brief Fill node_proto as an ONNX "Conv" node from MXNet Convolution attributes.
+ *
+ * Exports the kernel_shape, pads, dilations, strides and group attributes.
+ *
+ * \param node_proto ONNX node to fill in.
+ * \param attrs node attributes; attrs.parsed must hold an op::ConvolutionParam.
+ */
+void ConvertConvolution(NodeProto* node_proto, const NodeAttrs& attrs,
+                        const nnvm::IndexedGraph& /*ig*/,
+                        const array_view<IndexedGraph::NodeEntry>& /*inputs*/) {
+  const auto& conv_param = nnvm::get<op::ConvolutionParam>(attrs.parsed);
+
+  node_proto->set_op_type("Conv");
+
+  const TShape& kernel = conv_param.kernel;
+  const TShape& stride = conv_param.stride;
+  const TShape& dilate = conv_param.dilate;
+  const TShape& pad = conv_param.pad;
+  const uint32_t num_group = conv_param.num_group;
+
+  // kernel shape
+  AttributeProto* const kernel_shape = node_proto->add_attribute();
+  kernel_shape->set_name("kernel_shape");
+  kernel_shape->set_type(AttributeProto::INTS);
+
+  for (const dim_t kval : kernel) {
+    kernel_shape->add_ints(static_cast<int64>(kval));
+  }
+
+  // pads
+  AttributeProto* const pads = node_proto->add_attribute();
+  pads->set_name("pads");
+  pads->set_type(AttributeProto::INTS);
+
+  // MXNet stores one symmetric pad per spatial axis, while ONNX expects all begin values
+  // followed by all end values: (x1_begin, x2_begin, ..., x1_end, x2_end, ...). Emit the whole
+  // pad vector twice; duplicating each value in place would mix the axes up whenever the
+  // per-axis pads differ.
+  for (int side = 0; side < 2; ++side) {
+    for (const dim_t kval : pad) {
+      pads->add_ints(static_cast<int64>(kval));
+    }
+  }
+
+  // dilations
+  AttributeProto* const dilations = node_proto->add_attribute();
+  dilations->set_name("dilations");
+  dilations->set_type(AttributeProto::INTS);
+  for (const dim_t kval : dilate) {
+    dilations->add_ints(static_cast<int64>(kval));
+  }
+
+  // strides
+  AttributeProto* const strides = node_proto->add_attribute();
+  strides->set_name("strides");
+  strides->set_type(AttributeProto::INTS);
+  for (const dim_t kval : stride) {
+    strides->add_ints(static_cast<int64>(kval));
+  }
+
+  // group
+  AttributeProto* const group = node_proto->add_attribute();
+  group->set_name("group");
+  group->set_type(AttributeProto::INT);
+  group->set_i(static_cast<int64>(num_group));
+}  // end ConvertConvolution
+
+/*!
+ * \brief Fill node_proto as an ONNX pooling node from MXNet Pooling attributes.
+ *
+ * Global pooling becomes GlobalMaxPool / GlobalAveragePool without attributes. Otherwise the
+ * node becomes MaxPool / AveragePool with kernel_shape, pads and strides.
+ *
+ * NOTE(review): every pool_type other than 0 is exported as average pooling; confirm that no
+ * other pooling type can reach this converter.
+ *
+ * \param node_proto ONNX node to fill in.
+ * \param attrs node attributes; attrs.parsed must hold an op::PoolingParam.
+ */
+void ConvertPooling(NodeProto* node_proto, const NodeAttrs& attrs,
+                    const nnvm::IndexedGraph& /*ig*/,
+                    const array_view<IndexedGraph::NodeEntry>& /*inputs*/) {
+  const auto& pooling_param = nnvm::get<op::PoolingParam>(attrs.parsed);
+
+  const TShape& kernel = pooling_param.kernel;
+  const TShape& stride = pooling_param.stride;
+  const TShape& pad = pooling_param.pad;
+  const int pool_type = pooling_param.pool_type;
+  const bool global_pool = pooling_param.global_pool;
+
+  if (global_pool) {
+    if (pool_type == 0) {
+      node_proto->set_op_type("GlobalMaxPool");
+    } else {
+      node_proto->set_op_type("GlobalAveragePool");
+    }
+    return;
+  }
+
+  // kernel_shape
+  AttributeProto* const kernel_shape = node_proto->add_attribute();
+  kernel_shape->set_name("kernel_shape");
+  kernel_shape->set_type(AttributeProto::INTS);
+  for (const dim_t kval : kernel) {
+    kernel_shape->add_ints(static_cast<int64>(kval));
+  }
+
+  // pads
+  AttributeProto* const pads = node_proto->add_attribute();
+  pads->set_name("pads");
+  pads->set_type(AttributeProto::INTS);
+  // ONNX needs a begin and an end pad per spatial axis, laid out as
+  // (x1_begin, x2_begin, ..., x1_end, x2_end, ...). MXNet pads are symmetric, so the pad
+  // vector is emitted twice.
+  for (int side = 0; side < 2; ++side) {
+    for (const dim_t kval : pad) {
+      pads->add_ints(static_cast<int64>(kval));
+    }
+  }
+
+  // strides
+  AttributeProto* const strides = node_proto->add_attribute();
+  strides->set_name("strides");
+  strides->set_type(AttributeProto::INTS);
+  for (const dim_t kval : stride) {
+    strides->add_ints(static_cast<int64>(kval));
+  }
+
+  if (pool_type == 0) {
+    node_proto->set_op_type("MaxPool");
+  } else {
+    node_proto->set_op_type("AveragePool");
+  }
+}  // end ConvertPooling
+
+/*!
+ * \brief Fill node_proto with the ONNX op type matching an MXNet Activation node.
+ *
+ * ReLU, sigmoid and tanh map to "Relu", "Sigmoid" and "Tanh". SoftReLU and any unknown
+ * activation type raise dmlc::Error and leave the op type unset.
+ */
+void ConvertActivation(NodeProto* node_proto, const NodeAttrs& attrs,
+                       const nnvm::IndexedGraph& /*ig*/,
+                       const array_view<IndexedGraph::NodeEntry>& /*inputs*/) {
+  const auto& param = nnvm::get<op::ActivationParam>(attrs.parsed);
+
+  switch (param.act_type) {
+    case op::activation::kReLU:
+      node_proto->set_op_type("Relu");
+      break;
+    case op::activation::kSigmoid:
+      node_proto->set_op_type("Sigmoid");
+      break;
+    case op::activation::kTanh:
+      node_proto->set_op_type("Tanh");
+      break;
+    case op::activation::kSoftReLU:
+      throw dmlc::Error("SoftReLU is not supported in ONNX");
+    default:
+      throw dmlc::Error("Activation of such type doesn't exist");
+  }
+}
+
+/*!
+ * \brief Fill node_proto from an MXNet FullyConnected node.
+ *
+ * Without bias the node becomes a plain "MatMul". With bias it becomes a "Gemm" with
+ * alpha = beta = 1, broadcast = 1, transA = 0 and transB = 1.
+ */
+void ConvertFullyConnected(NodeProto* node_proto, const NodeAttrs& attrs,
+                           const nnvm::IndexedGraph& /*ig*/,
+                           const array_view<IndexedGraph::NodeEntry>& /*inputs*/) {
+  const auto& fc_param = nnvm::get<op::FullyConnectedParam>(attrs.parsed);
+  if (fc_param.no_bias) {
+    node_proto->set_op_type("MatMul");
+    return;
+  }
+
+  node_proto->set_op_type("Gemm");
+
+  // Small helpers that append one scalar attribute to the node.
+  const auto add_float_attr = [node_proto](const char* name, const float value) {
+    AttributeProto* const attr = node_proto->add_attribute();
+    attr->set_name(name);
+    attr->set_type(AttributeProto::FLOAT);
+    attr->set_f(value);
+  };
+  const auto add_int_attr = [node_proto](const char* name, const int64 value) {
+    AttributeProto* const attr = node_proto->add_attribute();
+    attr->set_name(name);
+    attr->set_type(AttributeProto::INT);
+    attr->set_i(value);
+  };
+
+  // The order of the attributes is kept as it always was.
+  add_float_attr("alpha", 1.0f);
+  add_float_attr("beta", 1.0f);
+  add_int_attr("broadcast", 1);
+  add_int_attr("transA", 0);
+  add_int_attr("transB", 1);
+}
+
+/*!
+ * \brief Fill node_proto as an ONNX "Softmax" node with axis = 1.
+ */
+void ConvertSoftmaxOutput(NodeProto* node_proto, const NodeAttrs& /*attrs*/,
+                          const nnvm::IndexedGraph& /*ig*/,
+                          const array_view<IndexedGraph::NodeEntry>& /*inputs*/) {
+  node_proto->set_op_type("Softmax");
+
+  // MXNet's node params carry no axis for softmax, so the axis is fixed to 1. It only matters
+  // when the input is coerced to 2D, in which case dimension 0 is taken as the batch dimension.
+  AttributeProto* const axis_attr = node_proto->add_attribute();
+  axis_attr->set_type(AttributeProto::INT);
+  axis_attr->set_name("axis");
+  axis_attr->set_i(1);
+}
+
+/*!
+ * \brief Fill node_proto as an ONNX "Flatten" node with axis = 1.
+ */
+void ConvertFlatten(NodeProto* node_proto, const NodeAttrs& /*attrs*/,
+                    const nnvm::IndexedGraph& /*ig*/,
+                    const array_view<IndexedGraph::NodeEntry>& /*inputs*/) {
+  node_proto->set_op_type("Flatten");
+
+  // MXNet's node params carry no axis for Flatten, so the axis is fixed to 1: dimension 0 is
+  // treated as the batch dimension when the input is coerced to 2D.
+  AttributeProto* const axis_attr = node_proto->add_attribute();
+  axis_attr->set_type(AttributeProto::INT);
+  axis_attr->set_name("axis");
+  axis_attr->set_i(1);
+}
+
+/*!
+ * \brief Fill node_proto as an ONNX "BatchNormalization" node in inference mode.
+ *
+ * epsilon and momentum come from the MXNet BatchNormParam; is_test and spatial are fixed to 1
+ * and consumed_inputs to {0, 0, 0, 1, 1}.
+ */
+void ConvertBatchNorm(NodeProto* node_proto, const NodeAttrs& attrs,
+                      const nnvm::IndexedGraph& /*ig*/,
+                      const array_view<IndexedGraph::NodeEntry>& /*inputs*/) {
+  node_proto->set_op_type("BatchNormalization");
+  const auto& bn_param = nnvm::get<op::BatchNormParam>(attrs.parsed);
+
+  AttributeProto* const eps_attr = node_proto->add_attribute();
+  eps_attr->set_name("epsilon");
+  eps_attr->set_type(AttributeProto::FLOAT);
+  eps_attr->set_f(static_cast<float>(bn_param.eps));
+
+  AttributeProto* const is_test_attr = node_proto->add_attribute();
+  is_test_attr->set_name("is_test");
+  is_test_attr->set_type(AttributeProto::INT);
+  is_test_attr->set_i(1);
+
+  AttributeProto* const momentum_attr = node_proto->add_attribute();
+  momentum_attr->set_name("momentum");
+  momentum_attr->set_type(AttributeProto::FLOAT);
+  momentum_attr->set_f(bn_param.momentum);
+
+  AttributeProto* const spatial_attr = node_proto->add_attribute();
+  spatial_attr->set_name("spatial");
+  spatial_attr->set_type(AttributeProto::INT);
+  spatial_attr->set_i(1);
+
+  AttributeProto* const consumed_attr = node_proto->add_attribute();
+  consumed_attr->set_name("consumed_inputs");
+  consumed_attr->set_type(AttributeProto::INTS);
+
+  // One flag per input: the first three are 0, the last two are 1.
+  const int consumed_flags[] = {0, 0, 0, 1, 1};
+  for (const int flag : consumed_flags) {
+    consumed_attr->add_ints(static_cast<int64>(flag));
+  }
+}
+
+/*!
+ * \brief Fill node_proto as an ONNX "Add" node with axis = 1 and broadcasting disabled.
+ */
+void ConvertElementwiseAdd(NodeProto* node_proto, const NodeAttrs& /*attrs*/,
+                           const nnvm::IndexedGraph& /*ig*/,
+                           const array_view<IndexedGraph::NodeEntry>& /*inputs*/) {
+  node_proto->set_op_type("Add");
+
+  AttributeProto* const axis_attr = node_proto->add_attribute();
+  axis_attr->set_type(AttributeProto::INT);
+  axis_attr->set_name("axis");
+  axis_attr->set_i(1);
+
+  // broadcast = 0 (disabled).
+  AttributeProto* const broadcast_attr = node_proto->add_attribute();
+  broadcast_attr->set_type(AttributeProto::INT);
+  broadcast_attr->set_name("broadcast");
+  broadcast_attr->set_i(0);
+}
+
+/*!
+ * \brief Collect the shapes of the graph inputs, keyed by input node name.
+ *
+ * \param shape_inputs shapes, indexed in the same order as ig.input_nodes().
+ * \param ig indexed graph used to resolve the input node names.
+ * \return name -> shape map; only shapes with ndim() > 0 are included.
+ */
+std::unordered_map<std::string, TShape> GetPlaceholderShapes(
+    const ShapeVector& shape_inputs, const nnvm::IndexedGraph& ig) {
+  std::unordered_map<std::string, TShape> shapes_by_name;
+  const auto& input_node_ids = ig.input_nodes();
+  for (uint32_t idx = 0; idx < shape_inputs.size(); ++idx) {
+    const std::string& input_name = ig[input_node_ids[idx]].source->attrs.name;
+    const TShape& input_shape = shape_inputs[idx];
+    if (input_shape.ndim() > 0) {
+      shapes_by_name.emplace(input_name, input_shape);
+    }
+  }
+  return shapes_by_name;
+}
+
+/*!
+ * \brief Map the name of each node producing a graph output to its position in ig.outputs().
+ *
+ * When several outputs come from the same node, the first position is the one kept, because
+ * emplace() does not overwrite an existing key.
+ */
+std::unordered_map<std::string, uint32_t> GetOutputLookup(
+    const nnvm::IndexedGraph& ig) {
+  std::unordered_map<std::string, uint32_t> position_by_name;
+  const std::vector<nnvm::IndexedGraph::NodeEntry>& outputs = ig.outputs();
+  for (uint32_t position = 0; position < outputs.size(); ++position) {
+    const nnvm::Node* const producer = ig[outputs[position].node_id].source;
+    position_by_name.emplace(producer->attrs.name, position);
+  }
+  return position_by_name;
+}
+
+/*!
+ * \brief Add a float graph input called node_name to graph_proto.
+ *
+ * \param node_name name of the placeholder; also used as the ONNX input name.
+ * \param placeholder_shapes name -> shape map; it must contain node_name, otherwise the
+ *        conversion fails with a fatal check.
+ * \param graph_proto ONNX graph that receives the new input.
+ */
+void ConvertPlaceholder(
+    const std::string& node_name,
+    const std::unordered_map<std::string, TShape>& placeholder_shapes,
+    GraphProto* const graph_proto) {
+  // The shape map can legitimately lack an entry (GetPlaceholderShapes drops inputs whose
+  // shape has ndim() == 0), so fail loudly instead of dereferencing end().
+  const auto shape_iter = placeholder_shapes.find(node_name);
+  CHECK(shape_iter != placeholder_shapes.end())
+      << "No input shape is known for placeholder " << node_name;
+  const TShape& entry_shape = shape_iter->second;
+
+  auto val_info_proto = graph_proto->add_input();
+  auto type_proto = val_info_proto->mutable_type()->mutable_tensor_type();
+  auto shape_proto = type_proto->mutable_shape();
+
+  val_info_proto->set_name(node_name);
+  // Will support fp16, etc. in the near future
+  type_proto->set_elem_type(TensorProto_DataType_FLOAT);
+
+  for (const auto& elem : entry_shape) {
+    TensorShapeProto_Dimension* const tsp_dim = shape_proto->add_dim();
+    tsp_dim->set_dim_value(static_cast<int64>(elem));
+  }
+}
+
+/*!
+ * \brief Add an ONNX "Constant" node holding the float data of a weight to graph_proto.
+ *
+ * \param graph_proto ONNX graph that receives the new node.
+ * \param node_name name of the constant; used as the node name and as its output name.
+ * \param shared_buffer name -> NDArray map; it must contain node_name, otherwise the
+ *        conversion fails with a fatal check.
+ */
+void ConvertConstant(
+    GraphProto* const graph_proto, const std::string& node_name,
+    std::unordered_map<std::string, NDArray>* const shared_buffer) {
+  const auto buffer_iter = shared_buffer->find(node_name);
+  CHECK(buffer_iter != shared_buffer->end())
+      << "No value is known for constant " << node_name;
+
+  NodeProto* const node_proto = graph_proto->add_node();
+  node_proto->set_name(node_name);
+  node_proto->add_output(node_name);
+  node_proto->set_op_type("Constant");
+
+  const NDArray& nd = buffer_iter->second;
+  const TBlob& blob = nd.data();
+  const TShape shape = blob.shape_;
+  // Keep the element count as size_t; narrowing it to a 32-bit int could overflow.
+  const size_t size = shape.Size();
+
+  // Staging buffer on the host. A vector releases the array correctly, unlike the previous
+  // std::shared_ptr<float>(new float[size]), which freed it with delete instead of delete[].
+  std::vector<float> host_data(size);
+  nd.SyncCopyToCPU(static_cast<void*>(host_data.data()), size);
+
+  AttributeProto* const tensor_attr = node_proto->add_attribute();
+  tensor_attr->set_name("value");
+  tensor_attr->set_type(AttributeProto::TENSOR);
+
+  TensorProto* const tensor_proto = tensor_attr->mutable_t();
+  tensor_proto->set_data_type(TensorProto_DataType_FLOAT);
+  for (auto& dim : shape) {
+    tensor_proto->add_dims(static_cast<int64>(dim));
+  }
+
+  for (const float value : host_data) {
+    tensor_proto->add_float_data(value);
+  }
+}
+
+/*!
+ * \brief Register node_name as a graph output.
+ *
+ * Stores (output position, shape, storage type, dtype) under node_name in trt_output_map and
+ * adds a float output with the same name and shape to graph_proto.
+ *
+ * \param out_iter entry of the output lookup map; its value is the position in ig.outputs().
+ * \param g graph that provides the indexed graph and the "shape" attribute.
+ * \param storage_types storage types, indexed by entry id.
+ * \param dtypes data types, indexed by entry id.
+ */
+void ConvertOutput(
+    op::tensorrt::InferenceMap_t* const trt_output_map,
+    GraphProto* const graph_proto,
+    const std::unordered_map<std::string, uint32_t>::iterator& out_iter,
+    const std::string& node_name, const nnvm::Graph& g,
+    const StorageTypeVector& storage_types, const DTypeVector& dtypes) {
+  const nnvm::IndexedGraph& ig = g.indexed_graph();
+  const uint32_t entry_id = ig.entry_id(ig.outputs()[out_iter->second]);
+  const TShape out_shape = g.GetAttr<nnvm::ShapeVector>("shape")[entry_id];
+  const int out_stype = storage_types[entry_id];
+  const int out_dtype = dtypes[entry_id];
+
+  // This should work with fp16 as well
+  op::tensorrt::InferenceTuple_t out_tuple{out_iter->second, out_shape, out_stype,
+                                           out_dtype};
+  trt_output_map->emplace(node_name, out_tuple);
+
+  auto onnx_output = graph_proto->add_output();
+  onnx_output->set_name(node_name);
+  auto onnx_tensor_type = onnx_output->mutable_type()->mutable_tensor_type();
+  // Also support fp16.
+  onnx_tensor_type->set_elem_type(TensorProto_DataType_FLOAT);
+
+  auto onnx_shape = onnx_tensor_type->mutable_shape();
+  for (const int64_t extent : out_shape) {
+    onnx_shape->add_dim()->set_dim_value(static_cast<int64>(extent));
+  }
+}
+
+}  // namespace nnvm_to_onnx
+}  // namespace op
+}  // namespace mxnet
+
+#endif  // MXNET_USE_TENSORRT
diff --git a/src/operator/contrib/tensorrt-inl.h b/src/operator/contrib/tensorrt-inl.h
new file mode 100644
index 00000000000..be335ab1208
--- /dev/null
+++ b/src/operator/contrib/tensorrt-inl.h
@@ -0,0 +1,113 @@
+#ifndef MXNET_OPERATOR_CONTRIB_TENSORRT_INL_H_
+#define MXNET_OPERATOR_CONTRIB_TENSORRT_INL_H_
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2018 by Contributors
+ * \file tensorrt-inl.h
+ * \brief TensorRT Operator
+ * \author Marek Kolodziej, Clement Fuji Tsang
+*/
+
+#if MXNET_USE_TENSORRT
+
+#include <dmlc/logging.h>
+#include <dmlc/memory_io.h>
+#include <dmlc/serializer.h>
+#include <dmlc/parameter.h>
+#include <mxnet/base.h>
+#include <mxnet/operator.h>
+#include <nnvm/graph.h>
+#include <nnvm/pass_functions.h>
+
+#include <NvInfer.h>
+#include <onnx/onnx.pb.h>
+
+#include <algorithm>
+#include <iostream>
+#include <map>
+#include <vector>
+#include <tuple>
+#include <unordered_map>
+#include <utility>
+#include <string>
+
+#include "../operator_common.h"
+#include "../../common/utils.h"
+#include "../../common/serialization.h"
+#include "../../executor/exec_pass.h"
+#include "../../executor/graph_executor.h"
+#include "../../executor/onnx_to_tensorrt.h"
+
+namespace mxnet {
+namespace op {
+
+using namespace nnvm;
+using namespace ::onnx;
+using int64 = ::google::protobuf::int64;
+
+// Type aliases shared by the TensorRT operator code.
+namespace tensorrt {
+  // Tags a binding as belonging to the operator's inputs or to its outputs.
+  enum class TypeIO { Inputs = 0, Outputs = 1 };
+  // Tensor name -> integer index.
+  using NameToIdx_t = std::map<std::string, int32_t>;
+  // Per-output record: (index, shape, storage type, dtype), in that order.
+  using InferenceTuple_t = std::tuple<uint32_t, TShape, int, int>;
+  // Tensor name -> its InferenceTuple_t record.
+  using InferenceMap_t = std::map<std::string, InferenceTuple_t>;
+}  // namespace tensorrt
+
+// Name -> unsigned index map. Not referenced anywhere else in this header.
+using trt_name_to_idx = std::map<std::string, uint32_t>;
+
+/*!
+ * \brief Parameters of the TensorRT operator node.
+ *
+ * Only the three serialized_* strings are declared as DMLC fields. input_map,
+ * output_map and onnx_pb_graph hold the decoded form of those strings.
+ */
+struct TRTParam : public dmlc::Parameter<TRTParam> {
+  std::string serialized_onnx_graph;
+  std::string serialized_input_map;
+  std::string serialized_output_map;
+  tensorrt::NameToIdx_t input_map;
+  tensorrt::InferenceMap_t output_map;
+  ::onnx::ModelProto onnx_pb_graph;
+
+  TRTParam() {}
+
+  /*!
+   * \brief Legacy constructor, kept unchanged so existing callers still compile.
+   *
+   * NOTE(review): the argument types are swapped relative to the input_map /
+   * output_map members, so serialized_input_map ends up holding an
+   * InferenceMap_t, which does not match what is later deserialized into
+   * input_map. The decoded members are left empty. Prefer the overload below.
+   */
+  TRTParam(const ::onnx::ModelProto& onnx_graph,
+           const tensorrt::InferenceMap_t& input_map,
+           const tensorrt::NameToIdx_t& output_map) {
+    common::Serialize(input_map, &serialized_input_map);
+    common::Serialize(output_map, &serialized_output_map);
+    onnx_graph.SerializeToString(&serialized_onnx_graph);
+  }
+
+  /*!
+   * \brief Build the parameter from decoded values.
+   *
+   * Argument types match the members, so the serialized strings can be
+   * deserialized back into input_map / output_map. Fills both the decoded
+   * members and their serialized counterparts.
+   *
+   * \param onnx_graph ONNX model to store and serialize
+   * \param inputs input name -> index map
+   * \param outputs output name -> (index, shape, storage type, dtype) map
+   */
+  TRTParam(const ::onnx::ModelProto& onnx_graph,
+           const tensorrt::NameToIdx_t& inputs,
+           const tensorrt::InferenceMap_t& outputs)
+      : input_map(inputs), output_map(outputs), onnx_pb_graph(onnx_graph) {
+    common::Serialize(input_map, &serialized_input_map);
+    common::Serialize(output_map, &serialized_output_map);
+    onnx_pb_graph.SerializeToString(&serialized_onnx_graph);
+  }
+
+  DMLC_DECLARE_PARAMETER(TRTParam) {
+    DMLC_DECLARE_FIELD(serialized_onnx_graph)
+    .describe("Serialized ONNX graph");
+    DMLC_DECLARE_FIELD(serialized_input_map)
+    .describe("Map from inputs to topological order as input.");
+    DMLC_DECLARE_FIELD(serialized_output_map)
+    .describe("Map from outputs to order in g.outputs.");
+  }
+};
+
+/*! \brief Runtime state of one TensorRT operator: execution context plus binding layout. */
+struct TRTEngineParam {
+  // TensorRT execution context on which inference is enqueued.
+  nvinfer1::IExecutionContext* trt_executor;
+  // One entry per engine binding, in binding order: (index into the operator's
+  // inputs or outputs, which of the two the index refers to).
+  std::vector<std::pair<uint32_t, tensorrt::TypeIO> > binding_map;
+};
+
+}  // namespace op
+}  // namespace mxnet
+
+#endif  // MXNET_USE_TENSORRT
+
+#endif  // MXNET_OPERATOR_CONTRIB_TENSORRT_INL_H_
diff --git a/src/operator/contrib/tensorrt.cc b/src/operator/contrib/tensorrt.cc
new file mode 100644
index 00000000000..619fe1e2b8f
--- /dev/null
+++ b/src/operator/contrib/tensorrt.cc
@@ -0,0 +1,183 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2018 by Contributors
+ * \file tensorrt.cc
+ * \brief TensorRT operation registration
+ * \author Marek Kolodziej, Clement Fuji Tsang
+*/
+
+#if MXNET_USE_TENSORRT
+
+#include "./tensorrt-inl.h"
+
+#include <mxnet/base.h>
+#include <nnvm/graph.h>
+#include <nnvm/pass_functions.h>
+
+#include <algorithm>
+#include <fstream>
+#include <iostream>
+#include <unordered_map>
+#include <vector>
+
+#include "../../common/serialization.h"
+#include "../../common/utils.h"
+
+namespace mxnet {
+namespace op {
+
+// Register TRTParam with the DMLC parameter system.
+DMLC_REGISTER_PARAMETER(TRTParam);
+
+/*!
+ * \brief Create the operator state for a built TensorRT engine.
+ *
+ * Walks the engine bindings in order and records, for each one, which operator
+ * input or output it corresponds to, then creates the execution context.
+ * Fails with a fatal check if a binding name is missing from the matching map.
+ *
+ * \param trt_engine engine whose bindings are inspected
+ * \param input_map binding name -> operator input index
+ * \param output_map binding name -> operator output index
+ * \return operator state wrapping a TRTEngineParam
+ */
+OpStatePtr GetPtrMapping(nvinfer1::ICudaEngine* trt_engine,
+                         tensorrt::NameToIdx_t input_map,
+                         tensorrt::NameToIdx_t output_map) {
+  TRTEngineParam param;
+  const int num_bindings = trt_engine->getNbBindings();
+  param.binding_map.reserve(num_bindings);
+  for (int b = 0; b < num_bindings; ++b) {
+    const std::string binding_name = trt_engine->getBindingName(b);
+    const bool is_input = trt_engine->bindingIsInput(b);
+    const tensorrt::NameToIdx_t& name_to_idx = is_input ? input_map : output_map;
+    // Look the name up explicitly: operator[] would silently insert index 0
+    // for an unknown name and bind the wrong tensor.
+    const auto found = name_to_idx.find(binding_name);
+    CHECK(found != name_to_idx.end())
+        << "TensorRT binding \"" << binding_name << "\" is not a known "
+        << (is_input ? "input" : "output") << " of the TensorRT node.";
+    const tensorrt::TypeIO kind =
+        is_input ? tensorrt::TypeIO::Inputs : tensorrt::TypeIO::Outputs;
+    param.binding_map.emplace_back(found->second, kind);
+  }
+  param.trt_executor = trt_engine->createExecutionContext();
+  return OpStatePtr::Create<TRTEngineParam>(param);
+}
+
+/*!
+ * \brief Build the TensorRT engine for a node and wrap it as operator state.
+ *
+ * Parses the serialized ONNX model, takes the batch size from the first
+ * dimension of the first graph input, converts the model to a TensorRT engine
+ * and maps the engine bindings to operator inputs/outputs.
+ *
+ * \param attrs node attributes; attrs.parsed must hold a TRTParam
+ * \return operator state created by GetPtrMapping
+ */
+OpStatePtr TRTCreateState(const nnvm::NodeAttrs& attrs, Context /*ctx*/,
+                          const std::vector<TShape>& /*ishape*/,
+                          const std::vector<int>& /*itype*/) {
+  const auto& node_param = nnvm::get<TRTParam>(attrs.parsed);
+
+  ::onnx::ModelProto model_proto;
+  bool success = model_proto.ParseFromString(node_param.serialized_onnx_graph);
+  if (!success) {
+    LOG(FATAL) << "Problems parsing serialized ONNX model.";
+  }
+  // Bind by reference: plain `auto` would deep-copy the protobuf messages.
+  const auto& graph = model_proto.graph();
+  CHECK_GT(graph.input_size(), 0) << "Serialized ONNX graph has no inputs.";
+  const auto& first_input_shape = graph.input(0).type().tensor_type().shape();
+  CHECK_GT(first_input_shape.dim_size(), 0)
+      << "First input of the serialized ONNX graph has no dimensions.";
+  const auto batch_size = static_cast<int32_t>(first_input_shape.dim(0).dim_value());
+  // TODO: derive the max workspace size from device properties instead of a fixed 1 << 30.
+  nvinfer1::ICudaEngine* const trt_engine = ::onnx_to_tensorrt::onnxToTrtCtx(
+      node_param.serialized_onnx_graph, batch_size, 1 << 30);
+  CHECK(trt_engine != nullptr) << "Failed to build a TensorRT engine from the ONNX model.";
+
+  // GetPtrMapping only needs name -> output index, i.e. tuple field 0.
+  tensorrt::NameToIdx_t output_map;
+  for (const auto& el : node_param.output_map) {
+    output_map[el.first] = std::get<0>(el.second);
+  }
+  return GetPtrMapping(trt_engine, node_param.input_map, output_map);
+}
+
+/*!
+ * \brief Attribute parser for the TensorRT operator.
+ *
+ * Initializes a TRTParam from attrs->dict, decodes the serialized input map,
+ * output map and ONNX graph into their in-memory members and stores the result
+ * in attrs->parsed.
+ *
+ * \param attrs node attributes to parse; attrs->parsed is overwritten
+ * \throws dmlc::ParamError if initialization fails or the ONNX graph cannot be
+ *         parsed; the message is extended with the operator name and attributes
+ */
+void TRTParamParser(nnvm::NodeAttrs* attrs) {
+  TRTParam param_;
+
+  try {
+    param_.Init(attrs->dict);
+    common::Deserialize(&param_.input_map, param_.serialized_input_map);
+    common::Deserialize(&param_.output_map, param_.serialized_output_map);
+    // Report an unparsable graph here instead of continuing with an empty model.
+    if (!param_.onnx_pb_graph.ParseFromString(param_.serialized_onnx_graph)) {
+      throw dmlc::ParamError("Problems parsing serialized ONNX model");
+    }
+  } catch (const dmlc::ParamError& e) {
+    std::ostringstream os;
+    os << e.what();
+    os << ", in operator " << attrs->op->name << "("
+       << "name=\"" << attrs->name << "\"";
+    for (const auto& k : attrs->dict) {
+      os << ", " << k.first << "=\"" << k.second << "\"";
+    }
+    os << ")";
+    throw dmlc::ParamError(os.str());
+  }
+
+  attrs->parsed = std::move(param_);
+}
+
+/*!
+ * \brief Shape inference: copy each output shape out of the parsed output_map.
+ *
+ * For every output_map entry, tuple field 1 (shape) is written to out_shape at
+ * the position given by tuple field 0. Input shapes are ignored.
+ */
+inline bool TRTInferShape(const NodeAttrs& attrs, std::vector<TShape>* /*in_shape*/,
+                          std::vector<TShape>* out_shape) {
+  const TRTParam& trt_param = nnvm::get<TRTParam>(attrs.parsed);
+  std::vector<TShape>& shapes = *out_shape;
+  for (const auto& name_and_info : trt_param.output_map) {
+    const tensorrt::InferenceTuple_t& info = name_and_info.second;
+    shapes[std::get<0>(info)] = std::get<1>(info);
+  }
+  return true;
+}
+
+/*!
+ * \brief Storage type inference: delegates to storage_type_assign with default
+ *        storage for the outputs and the kFCompute dispatch mode.
+ */
+inline bool TRTInferStorageType(const NodeAttrs& /*attrs*/, const int /*dev_mask*/,
+                                DispatchMode* dispatch_mode,
+                                std::vector<int>* /*in_storage_type*/,
+                                std::vector<int>* out_storage_type) {
+  const bool assigned = storage_type_assign(out_storage_type, mxnet::kDefaultStorage,
+                                            dispatch_mode, DispatchMode::kFCompute);
+  return assigned;
+}
+
+/*!
+ * \brief Type inference: copy each output dtype out of the parsed output_map.
+ *
+ * For every output_map entry, tuple field 3 (dtype) is written to out_dtype at
+ * the position given by tuple field 0. Input dtypes are ignored.
+ */
+inline bool TRTInferType(const NodeAttrs& attrs, std::vector<int>* /*in_dtype*/,
+                         std::vector<int>* out_dtype) {
+  const TRTParam& trt_param = nnvm::get<TRTParam>(attrs.parsed);
+  std::vector<int>& dtypes = *out_dtype;
+  for (const auto& name_and_info : trt_param.output_map) {
+    const tensorrt::InferenceTuple_t& info = name_and_info.second;
+    dtypes[std::get<0>(info)] = std::get<3>(info);
+  }
+  return true;
+}
+
+/*! \brief List input names, each placed at the index recorded for it in input_map. */
+inline std::vector<std::string> TRTListInputNames(const NodeAttrs& attrs) {
+  const TRTParam& trt_param = nnvm::get<TRTParam>(attrs.parsed);
+  std::vector<std::string> names(trt_param.input_map.size());
+  for (const auto& entry : trt_param.input_map) {
+    names[entry.second] = entry.first;
+  }
+  return names;
+}
+
+/*! \brief List output names, each placed at the index stored in tuple field 0 of output_map. */
+inline std::vector<std::string> TRTListOutputNames(const NodeAttrs& attrs) {
+  const TRTParam& trt_param = nnvm::get<TRTParam>(attrs.parsed);
+  std::vector<std::string> names(trt_param.output_map.size());
+  for (const auto& entry : trt_param.output_map) {
+    names[std::get<0>(entry.second)] = entry.first;
+  }
+  return names;
+}
+
+// Registers the `_trt_op` operator. The number of inputs and outputs is taken
+// from the sizes of the parsed input_map / output_map; the remaining attributes
+// attach the parser, inference, naming and state-creation functions defined above.
+NNVM_REGISTER_OP(_trt_op)
+    .describe(R"code(TRT operation (one engine)
+)code" ADD_FILELINE)
+    .set_num_inputs([](const NodeAttrs& attrs) {
+      const auto& node_param = nnvm::get<TRTParam>(attrs.parsed);
+      return node_param.input_map.size();
+    })
+    .set_num_outputs([](const NodeAttrs& attrs) {
+      const auto& node_param = nnvm::get<TRTParam>(attrs.parsed);
+      return node_param.output_map.size();
+    })
+    .set_attr_parser(TRTParamParser)
+    .set_attr<nnvm::FInferShape>("FInferShape", TRTInferShape)
+    .set_attr<nnvm::FInferType>("FInferType", TRTInferType)
+    .set_attr<nnvm::FListInputNames>("FListInputNames", TRTListInputNames)
+    .set_attr<nnvm::FListOutputNames>("FListOutputNames", TRTListOutputNames)
+    .set_attr<FCreateOpState>("FCreateOpState", TRTCreateState)
+    .set_attr<FInferStorageType>("FInferStorageType", TRTInferStorageType);
+
+}  // namespace op
+}  // namespace mxnet
+
+#endif  // MXNET_USE_TENSORRT
diff --git a/src/operator/contrib/tensorrt.cu b/src/operator/contrib/tensorrt.cu
new file mode 100644
index 00000000000..2fe8727b73e
--- /dev/null
+++ b/src/operator/contrib/tensorrt.cu
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2018 by Contributors
+ * \file tensorrt.cu
+ * \brief TensorRT GPU operation
+ * \author Marek Kolodziej, Clement Fuji Tsang
+*/
+
+#if MXNET_USE_TENSORRT
+
+#include "./tensorrt-inl.h"
+
+namespace mxnet {
+namespace op {
+
+// Evaluates the CUDA runtime call `x` once. If the result is not cudaSuccess,
+// prints the expression, the error code, its error string and the source
+// location to stderr and terminates the process with exit(1).
+#define CHECK_CUDART(x) do { \
+  cudaError_t res = (x); \
+  if (res != cudaSuccess) { \
+    fprintf(stderr, "CUDART: %s = %d (%s) at (%s:%d)\n", \
+      #x, res, cudaGetErrorString(res), __FILE__, __LINE__); \
+    exit(1); \
+  } \
+} while (0)
+
+/*!
+ * \brief GPU forward pass: run the TensorRT engine on the operator's tensors.
+ *
+ * Builds the binding pointer array in engine binding order from the input and
+ * output blobs, enqueues inference on the operator's CUDA stream and waits for
+ * the stream to finish.
+ *
+ * \param state operator state holding a TRTEngineParam
+ * \param ctx operator context providing the GPU stream
+ * \param inputs input blobs; the first dimension of inputs[0] is the batch size
+ * \param req write requests (unused)
+ * \param outputs output blobs written by the engine
+ */
+void TRTCompute(const OpStatePtr& state, const OpContext& ctx,
+                const std::vector<TBlob>& inputs, const std::vector<OpReqType>& req,
+                const std::vector<TBlob>& outputs) {
+  using namespace mshadow;
+
+  Stream<gpu>* s = ctx.get_stream<gpu>();
+  cudaStream_t cuda_s = Stream<gpu>::GetStream(s);
+  const auto& param = state.get_state<TRTEngineParam>();
+  std::vector<void*> bindings;
+  bindings.reserve(param.binding_map.size());
+  for (const auto& p : param.binding_map) {
+    if (p.second == tensorrt::TypeIO::Inputs) {
+      bindings.emplace_back(inputs[p.first].dptr_);
+    } else {
+      bindings.emplace_back(outputs[p.first].dptr_);
+    }
+  }
+
+  const int batch_size = static_cast<int>(inputs[0].shape_[0]);
+  // enqueue() reports failure through its return value; do not ignore it.
+  const bool enqueued =
+      param.trt_executor->enqueue(batch_size, bindings.data(), cuda_s, nullptr);
+  CHECK(enqueued) << "TensorRT failed to enqueue inference for batch size " << batch_size;
+  CHECK_CUDART(cudaStreamSynchronize(cuda_s));
+}
+
+// Attach TRTCompute to `_trt_op` under the "FStatefulCompute<gpu>" attribute.
+NNVM_REGISTER_OP(_trt_op)
+.set_attr<FStatefulCompute>("FStatefulCompute<gpu>", TRTCompute);
+
+}  // namespace op
+}  // namespace mxnet
+
+#endif  // MXNET_USE_TENSORRT
diff --git a/tests/.gitignore b/tests/.gitignore
index d6459089c24..3e5eed695f0 100644
--- a/tests/.gitignore
+++ b/tests/.gitignore
@@ -1 +1,2 @@
 *_unittest
+*.gz
diff --git a/tests/cpp/misc/serialization.cc b/tests/cpp/misc/serialization.cc
new file mode 100644
index 00000000000..96f8b6c3a3a
--- /dev/null
+++ b/tests/cpp/misc/serialization.cc
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <gtest/gtest.h>
+#include <../../../src/common/serialization.h>
+
+using namespace mxnet;
+using namespace std;
+
+/*
+ * Test that the data structures in use are properly serialized and deserialized.
+ */
+
+// Round-trips a name -> index map through Serialize/Deserialize and checks that
+// every original entry is present with the same value.
+TEST(SerializerTest, InputMapCorrect) {
+    std::map<std::string, int32_t> expected;
+    expected.emplace("input_0", 2);
+    expected.emplace("another_input", 0);
+    expected.emplace("last_input", 1);
+
+    std::string blob;
+    common::Serialize(expected, &blob);
+    std::map<std::string, int32_t> restored;
+    common::Deserialize(&restored, blob);
+
+    ASSERT_EQ(expected.size(), restored.size());
+    for (auto& entry : expected) {
+        auto found = restored.find(entry.first);
+        ASSERT_NE(found, restored.end());
+        ASSERT_EQ(found->second, entry.second);
+    }
+}
+
+// Round-trips a name -> (index, shape, int, int) map through
+// Serialize/Deserialize and compares every tuple field of every entry.
+TEST(SerializerTest, OutputMapCorrect) {
+    using OutputInfo = std::tuple<uint32_t, TShape, int, int>;
+    using OutputMap = std::map<std::string, OutputInfo>;
+
+    OutputMap expected;
+    expected.emplace("output_0", std::make_tuple(1, TShape({23, 12, 63, 432}), 0, 1));
+    expected.emplace("another_output", std::make_tuple(2, TShape({23, 123}), 14, -23));
+    expected.emplace("last_output", std::make_tuple(0, TShape({0}), -1, 0));
+
+    std::string blob;
+    common::Serialize(expected, &blob);
+    OutputMap restored;
+    common::Deserialize(&restored, blob);
+
+    ASSERT_EQ(expected.size(), restored.size());
+    for (auto& entry : expected) {
+        auto found = restored.find(entry.first);
+        ASSERT_NE(found, restored.end());
+        const OutputInfo actual = found->second;
+        const OutputInfo wanted = entry.second;
+        ASSERT_EQ(std::get<0>(actual), std::get<0>(wanted));
+        ASSERT_EQ(std::get<1>(actual), std::get<1>(wanted));
+        ASSERT_EQ(std::get<2>(actual), std::get<2>(wanted));
+        ASSERT_EQ(std::get<3>(actual), std::get<3>(wanted));
+    }
+}
+
diff --git a/tests/python/tensorrt/common.py b/tests/python/tensorrt/common.py
new file mode 100644
index 00000000000..eb599f69973
--- /dev/null
+++ b/tests/python/tensorrt/common.py
@@ -0,0 +1,39 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import os
+from ctypes.util import find_library
+
+
+def check_tensorrt_installation():
+    """Assert that the ``nvinfer`` shared library can be located."""
+    library_path = find_library('nvinfer')
+    assert library_path is not None, "Can't find the TensorRT shared library"
+
+
+def merge_dicts(*dict_args):
+    """Merge the given dictionaries into a new one.
+
+    Used to combine arg_params and aux_params into a single shared_buffer.
+    When a key occurs more than once, the value from the later argument wins.
+    """
+    merged = dict()
+    for mapping in dict_args:
+        merged.update(mapping)
+    return merged
+
+
+def get_fp16_infer_for_fp16_graph():
+    """Return MXNET_TENSORRT_USE_FP16_FOR_FP32 from the environment as an int (0 if unset)."""
+    raw_value = os.environ.get("MXNET_TENSORRT_USE_FP16_FOR_FP32", 0)
+    return int(raw_value)
+
+
+def set_fp16_infer_for_fp16_graph(status=False):
+    """Store ``status`` in MXNET_TENSORRT_USE_FP16_FOR_FP32 as the string "0" or "1"."""
+    flag = int(status)
+    os.environ["MXNET_TENSORRT_USE_FP16_FOR_FP32"] = str(flag)
diff --git a/tests/python/tensorrt/lenet5_common.py b/tests/python/tensorrt/lenet5_common.py
new file mode 100644
index 00000000000..347d6f3c11b
--- /dev/null
+++ b/tests/python/tensorrt/lenet5_common.py
@@ -0,0 +1,31 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import numpy as np
+import mxnet as mx
+from common import *
+
+def get_iters(mnist, batch_size):
+    """Build MNIST iterators.
+
+    Returns a shuffled training iterator, two separate iterators over the test
+    split (validation and test), and the test labels as a NumPy array.
+    """
+    train_iter = mx.io.NDArrayIter(mnist['train_data'], mnist['train_label'],
+                                   batch_size, shuffle=True)
+    test_data, test_label = mnist['test_data'], mnist['test_label']
+    val_iter = mx.io.NDArrayIter(test_data, test_label, batch_size)
+    test_iter = mx.io.NDArrayIter(test_data, test_label, batch_size)
+    return train_iter, val_iter, test_iter, np.array(test_label)
diff --git a/tests/python/tensorrt/lenet5_train.py b/tests/python/tensorrt/lenet5_train.py
new file mode 100644
index 00000000000..74de66620e8
--- /dev/null
+++ b/tests/python/tensorrt/lenet5_train.py
@@ -0,0 +1,120 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import os
+import numpy as np
+import mxnet as mx
+from lenet5_common import get_iters
+
+def lenet5():
+    """LeNet-5 Symbol"""
+    #pylint: disable=no-member
+    net = mx.sym.Variable('data')
+    # two conv -> tanh -> max-pool stages, with 20 and then 50 filters
+    for num_filter in (20, 50):
+        net = mx.sym.Convolution(data=net, kernel=(5, 5), num_filter=num_filter)
+        net = mx.sym.Activation(data=net, act_type="tanh")
+        net = mx.sym.Pooling(data=net, pool_type="max",
+                             kernel=(2, 2), stride=(2, 2))
+    # first fully connected layer
+    net = mx.sym.Flatten(data=net)
+    net = mx.sym.FullyConnected(data=net, num_hidden=500)
+    net = mx.sym.Activation(data=net, act_type="tanh")
+    # second fully connected layer
+    net = mx.sym.FullyConnected(data=net, num_hidden=10)
+    # loss
+    net = mx.sym.SoftmaxOutput(data=net, name='softmax')
+    #pylint: enable=no-member
+    return net
+
+def train_lenet5(num_epochs, batch_size, train_iter, val_iter, test_iter):
+    """Train LeNet-5 on GPU 0 and return the fitted module.
+
+    Fits with SGD, then scores on ``test_iter`` and asserts that the accuracy
+    is above 0.95.
+    """
+    device = mx.gpu(0)
+    module = mx.mod.Module(lenet5(), context=device)
+
+    sgd_settings = {'learning_rate': 0.1, 'momentum': 0.9}
+    module.fit(train_iter,
+               eval_data=val_iter,
+               optimizer='sgd',
+               optimizer_params=sgd_settings,
+               eval_metric='acc',
+               batch_end_callback=mx.callback.Speedometer(batch_size, 1),
+               num_epoch=num_epochs)
+
+    # measure accuracy on the test iterator
+    metric = mx.metric.Accuracy()
+    module.score(test_iter, metric)
+    _, accuracy = metric.get()
+    assert accuracy > 0.95, "LeNet-5 training accuracy on MNIST was too low"
+    return module
+
+
+def run_inference(sym, arg_params, aux_params, mnist, all_test_labels, batch_size):
+    """Run inference with either MXNet or TensorRT.
+
+    Binds ``sym`` on GPU 0 with the merged parameters as shared buffer, runs the
+    MNIST test split through it and returns the percentage of predictions that
+    match ``all_test_labels``.
+    """
+    # Merge inline: this module does not import a merge_dicts helper.
+    shared_buffer = dict(arg_params)
+    shared_buffer.update(aux_params)
+    # The flag accessor lives in mx.contrib.tensorrt; the bare name was undefined here.
+    if not mx.contrib.tensorrt.get_use_tensorrt():
+        shared_buffer = {k: v.as_in_context(mx.gpu(0)) for k, v in shared_buffer.items()}
+    executor = sym.simple_bind(ctx=mx.gpu(0),
+                               data=(batch_size,) +  mnist['test_data'].shape[1:],
+                               softmax_label=(batch_size,),
+                               shared_buffer=shared_buffer,
+                               grad_req='null',
+                               force_rebind=True)
+
+    # Number of examples comes from the labels instead of a hard-coded 10000.
+    num_ex = len(all_test_labels)
+    num_classes = 10
+    all_preds = np.zeros([num_ex, num_classes])
+    test_iter = mx.io.NDArrayIter(mnist['test_data'], mnist['test_label'], batch_size)
+
+    example_ct = 0
+
+    for idx, dbatch in enumerate(test_iter):
+        executor.arg_dict["data"][:] = dbatch.data[0]
+        executor.forward(is_train=False)
+        offset = idx*batch_size
+        # The last batch may cover fewer than batch_size real examples.
+        extent = min(batch_size, num_ex - offset)
+        all_preds[offset:offset+extent, :] = executor.outputs[0].asnumpy()[:extent]
+        example_ct += extent
+
+    all_preds = np.argmax(all_preds, axis=1)
+    matches = (all_preds[:example_ct] == all_test_labels[:example_ct]).sum()
+
+    percentage = 100.0 * matches / example_ct
+
+    return percentage
+
+# Script entry point: train LeNet-5 and save a checkpoint unless both the
+# symbol file and the params file already exist under the model directory.
+if __name__ == '__main__':
+
+    num_epochs = 10
+    batch_size = 128
+    model_name = 'lenet5'
+    model_dir = os.getenv("LENET_MODEL_DIR", "/tmp")
+    model_file = '%s/%s-symbol.json' % (model_dir, model_name)
+    params_file = '%s/%s-%04d.params' % (model_dir, model_name, num_epochs)
+
+    if not (os.path.exists(model_file) and os.path.exists(params_file)):
+        mnist = mx.test_utils.get_mnist()
+
+        # Build the iterators once; the label array is not needed for training.
+        train_iter, val_iter, test_iter, _ = get_iters(mnist, batch_size)
+
+        trained_lenet = train_lenet5(num_epochs, batch_size,
+                                     train_iter, val_iter, test_iter)
+        # NOTE(review): the checkpoint prefix is model_name without model_dir,
+        # while the existence check above looks under model_dir -- confirm
+        # which location is intended.
+        trained_lenet.save_checkpoint(model_name, num_epochs)
diff --git a/tests/python/tensorrt/test_cvnets.py b/tests/python/tensorrt/test_cvnets.py
new file mode 100644
index 00000000000..2c64b96fafc
--- /dev/null
+++ b/tests/python/tensorrt/test_cvnets.py
@@ -0,0 +1,178 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import gc
+import gluoncv
+import mxnet as mx
+import numpy as np
+
+from mxnet import gluon
+from time import time
+
+from mxnet.gluon.data.vision import transforms
+
+
+def get_classif_model(model_name, use_tensorrt, ctx=mx.gpu(0), batch_size=128):
+    """Return a bound executor for a pretrained gluoncv classifier on 32x32 inputs.
+
+    Sets the TensorRT flag to ``use_tensorrt`` first. With TensorRT the network
+    is applied to a symbolic input and bound with its parameters as shared
+    buffer; otherwise it is hybridized, exported, reloaded as a checkpoint and
+    bound, with the loaded parameters copied into the executor.
+    """
+    mx.contrib.tensorrt.set_use_tensorrt(use_tensorrt)
+    input_shape = (batch_size, 3, 32, 32)
+    label_shape = (batch_size,)
+    net = gluoncv.model_zoo.get_model(model_name, pretrained=True)
+    data = mx.sym.var('data')
+
+    if use_tensorrt:
+        softmax = mx.sym.SoftmaxOutput(net(data), name='softmax')
+        all_params = {name: param.data() for name, param in net.collect_params().items()}
+        return softmax.simple_bind(ctx=ctx, data=input_shape,
+                                   softmax_label=label_shape, grad_req='null',
+                                   shared_buffer=all_params, force_rebind=True)
+
+    # Convert the gluon model to a symbolic one via export + load_checkpoint.
+    net.hybridize()
+    net.forward(mx.ndarray.zeros(input_shape))
+    net.export(model_name)
+    symbol, arg_params, aux_params = mx.model.load_checkpoint(model_name, 0)
+    executor = symbol.simple_bind(ctx=ctx, data=input_shape,
+                                  softmax_label=label_shape)
+    executor.copy_params_from(arg_params, aux_params)
+    return executor
+
+
+def cifar10_infer(model_name, use_tensorrt, num_workers, ctx=mx.gpu(0), batch_size=128):
+    """Run a model over the CIFAR-10 test split and return (seconds, accuracy in percent).
+
+    A first pass over the data collects the labels and warms up the executor;
+    only the second pass is timed. Undersized final batches are skipped in both.
+    """
+    executor = get_classif_model(model_name, use_tensorrt, ctx, batch_size)
+
+    num_ex = 10000
+    all_preds = np.zeros([num_ex, 10])
+    all_label_test = np.zeros(num_ex)
+
+    transform_test = transforms.Compose([
+        transforms.ToTensor(),
+        transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])
+    ])
+
+    def make_loader():
+        dataset = gluon.data.vision.CIFAR10(train=False).transform_first(transform_test)
+        return gluon.data.DataLoader(dataset, batch_size=batch_size,
+                                     shuffle=False, num_workers=num_workers)
+
+    # Pass 1: record labels and warm up; the predictions are discarded.
+    val_data = make_loader()
+    for batch_idx, (images, labels) in enumerate(val_data):
+        if images.shape[0] >= batch_size:
+            first_row = batch_idx * batch_size
+            all_label_test[first_row:first_row + batch_size] = labels.asnumpy()
+            executor.forward(is_train=False, data=images)
+            executor.outputs[0].wait_to_read()
+
+    gc.collect()
+    val_data = make_loader()
+    example_ct = 0
+    start = time()
+
+    # Pass 2: timed inference, keeping the predictions.
+    for batch_idx, (images, _) in enumerate(val_data):
+        if images.shape[0] >= batch_size:
+            executor.forward(is_train=False, data=images)
+            batch_preds = executor.outputs[0].asnumpy()
+            first_row = batch_idx * batch_size
+            all_preds[first_row:first_row + batch_size, :] = batch_preds[:batch_size]
+            example_ct += batch_size
+
+    predicted_classes = np.argmax(all_preds, axis=1)
+    matches = (predicted_classes[:example_ct] == all_label_test[:example_ct]).sum()
+    duration = time() - start
+
+    return duration, 100.0 * matches / example_ct
+
+
+def run_experiment_for(model_name, batch_size, num_workers):
+    """Compare pure MXNet against MXNet + TensorRT for one model.
+
+    Prints timings and accuracies and returns (speedup, absolute accuracy difference).
+    """
+    shared_args = dict(model_name=model_name, batch_size=batch_size,
+                       num_workers=num_workers)
+
+    print("\n===========================================")
+    print("Model: %s" % model_name)
+    print("===========================================")
+    print("*** Running inference using pure MXNet ***\n")
+    mx_duration, mx_pct = cifar10_infer(use_tensorrt=False, **shared_args)
+    print("\nMXNet: time elapsed: %.3fs, accuracy: %.2f%%" % (mx_duration, mx_pct))
+
+    print("\n*** Running inference using MXNet + TensorRT ***\n")
+    trt_duration, trt_pct = cifar10_infer(use_tensorrt=True, **shared_args)
+    print("TensorRT: time elapsed: %.3fs, accuracy: %.2f%%" % (trt_duration, trt_pct))
+
+    speedup = mx_duration / trt_duration
+    print("TensorRT speed-up (not counting compilation): %.2fx" % speedup)
+    acc_diff = abs(mx_pct - trt_pct)
+    print("Absolute accuracy difference: %f" % acc_diff)
+    return speedup, acc_diff
+
+
+def test_tensorrt_on_cifar_resnets(batch_size=32, tolerance=0.1, num_workers=1):
+    """Check that TensorRT accuracy stays within ``tolerance`` of MXNet on CIFAR models.
+
+    Also prints speed-up statistics. The TensorRT flag is restored to its
+    previous value when the test ends, whether it passes or fails.
+    """
+    saved_flag = mx.contrib.tensorrt.get_use_tensorrt()
+    try:
+        models = [
+            'cifar_resnet20_v1',
+            'cifar_resnet56_v1',
+            'cifar_resnet110_v1',
+            'cifar_resnet20_v2',
+            'cifar_resnet56_v2',
+            'cifar_resnet110_v2',
+            'cifar_wideresnet16_10',
+            'cifar_wideresnet28_10',
+            'cifar_wideresnet40_8',
+            'cifar_resnext29_16x64d'
+        ]
+
+        speedups = np.zeros(len(models), dtype=np.float32)
+        acc_diffs = np.zeros(len(models), dtype=np.float32)
+
+        test_start = time()
+
+        for position, model in enumerate(models):
+            speedup, acc_diff = run_experiment_for(model, batch_size, num_workers)
+            speedups[position], acc_diffs[position] = speedup, acc_diff
+            assert acc_diff < tolerance, "Accuracy difference between MXNet and TensorRT > %.2f%% for model %s" % (
+                tolerance, model)
+
+        print("Perf and correctness checks run on the following models:")
+        print(models)
+        print("\nSpeedups:")
+        print(speedups)
+        print("Speedup range: [%.2f, %.2f]" % (np.min(speedups), np.max(speedups)))
+        print("Mean speedup: %.2f" % np.mean(speedups))
+        print("St. dev. of speedups: %.2f" % np.std(speedups))
+        print("\nAcc. differences: %s" % str(acc_diffs))
+
+        print("Test duration: %.2f seconds" % (time() - test_start))
+    finally:
+        mx.contrib.tensorrt.set_use_tensorrt(saved_flag)
+
+
+# When executed as a script, run this module's tests through nose.
+if __name__ == '__main__':
+    import nose
+
+    nose.runmodule()
diff --git a/tests/python/tensorrt/test_cycle.py b/tests/python/tensorrt/test_cycle.py
new file mode 100644
index 00000000000..25f515a106a
--- /dev/null
+++ b/tests/python/tensorrt/test_cycle.py
@@ -0,0 +1,69 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import mxnet as mx
+from common import *
+
+
+def detect_cycle_from(sym, visited, stack):
+    visited.add(sym.handle.value)
+    stack.add(sym.handle.value)
+    for s in sym.get_children():
+        if s.handle.value not in visited:
+            if detect_cycle_from(sym, visited, stack):
+                return True
+        elif s.handle.value in stack:
+            return True
+        stack.remove(sym.handle.value)
+    return False
+
+
+def has_no_cycle(sym):
+    visited = set()
+    stack = set()
+    all_nodes = sym.get_internals()
+    for s in all_nodes:
+        if s.handle.value in visited:
+            if detect_cycle_from(s, visited, stack):
+                return False
+    return True
+
+
+def test_simple_cycle():
+    inp = mx.sym.Variable('input', shape=[1,10])
+    A = mx.sym.FullyConnected(data=inp, num_hidden=10, no_bias=False, name='A')
+    B = mx.sym.FullyConnected(data=A, num_hidden=10, no_bias=False, name='B')
+    D = mx.sym.sin(data=A, name='D')
+    C = mx.sym.elemwise_add(lhs=B, rhs=D, name='C')
+    arg_params = {
+                'I_weight': mx.nd.zeros([10,10]),
+                'I_bias': mx.nd.zeros([10]),
+                'A_weight': mx.nd.zeros([10,10]),
+                'A_bias': mx.nd.zeros([10]),
+                'B_weight': mx.nd.zeros([10,10]),
+                'B_bias': mx.nd.zeros([10]),
+               }
+
+    executor = C.simple_bind(ctx=mx.gpu(0), data=(1,10), softmax_label=(1,),
+                           shared_buffer=arg_params, grad_req='null', force_rebind=True)
+    optimized_graph = mx.contrib.tensorrt.get_optimized_symbol(executor)
+    assert has_no_cycle(optimized_graph), "The graph optimized by TRT contains a cycle"
+
+
+if __name__ == '__main__':
+    # Direct execution of this file: hand the module over to the nose test runner.
+    import nose
+    nose.runmodule()
diff --git a/tests/python/tensorrt/test_tensorrt_lenet5.py b/tests/python/tensorrt/test_tensorrt_lenet5.py
new file mode 100644
index 00000000000..8e2730bcc7d
--- /dev/null
+++ b/tests/python/tensorrt/test_tensorrt_lenet5.py
@@ -0,0 +1,103 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import os
+import numpy as np
+import mxnet as mx
+from common import *
+from lenet5_common import get_iters
+
+
+def run_inference(sym, arg_params, aux_params, mnist, all_test_labels, batch_size, use_tensorrt):
+    """Run inference with either MXNet or TensorRT"""
+    mx.contrib.tensorrt.set_use_tensorrt(use_tensorrt)
+
+    shared_buffer = merge_dicts(arg_params, aux_params)
+    if not use_tensorrt:
+        shared_buffer = dict([(k, v.as_in_context(mx.gpu(0))) for k, v in shared_buffer.items()])
+
+    executor = sym.simple_bind(ctx=mx.gpu(0),
+                               data=(batch_size,) +  mnist['test_data'].shape[1:],
+                               softmax_label=(batch_size,),
+                               shared_buffer=shared_buffer,
+                               grad_req='null',
+                               force_rebind=True)
+
+    # Get this value from all_test_labels
+    # Also get classes from the dataset
+    num_ex = 10000
+    all_preds = np.zeros([num_ex, 10])
+    test_iter = mx.io.NDArrayIter(mnist['test_data'], mnist['test_label'], batch_size)
+
+    example_ct = 0
+
+    for idx, dbatch in enumerate(test_iter):
+        executor.arg_dict["data"][:] = dbatch.data[0]
+        executor.forward(is_train=False)
+        offset = idx*batch_size
+        extent = batch_size if num_ex - offset > batch_size else num_ex - offset
+        all_preds[offset:offset+extent, :] = executor.outputs[0].asnumpy()[:extent]
+        example_ct += extent
+
+    all_preds = np.argmax(all_preds, axis=1)
+    matches = (all_preds[:example_ct] == all_test_labels[:example_ct]).sum()
+
+    percentage = 100.0 * matches / example_ct
+
+    return percentage
+
+
+def test_tensorrt_inference():
+    """Run LeNet-5 inference comparison between MXNet and TensorRT."""
+    original_try_value = mx.contrib.tensorrt.get_use_tensorrt()
+    try:
+        check_tensorrt_installation()
+        mnist = mx.test_utils.get_mnist()
+        num_epochs = 10
+        batch_size = 128
+        model_name = 'lenet5'
+        model_dir = os.getenv("LENET_MODEL_DIR", "/tmp")
+        model_file = '%s/%s-symbol.json' % (model_dir, model_name)
+        params_file = '%s/%s-%04d.params' % (model_dir, model_name, num_epochs)
+
+        _, _, _, all_test_labels = get_iters(mnist, batch_size)
+
+        # Load serialized MXNet model (model-symbol.json + model-epoch.params)
+        sym, arg_params, aux_params = mx.model.load_checkpoint(model_name, num_epochs)
+
+        print("LeNet-5 test")
+        print("Running inference in MXNet")
+        mx_pct = run_inference(sym, arg_params, aux_params, mnist, all_test_labels,
+                               batch_size=batch_size, use_tensorrt=False)
+
+        print("Running inference in MXNet-TensorRT")
+        trt_pct = run_inference(sym, arg_params, aux_params, mnist, all_test_labels,
+                                batch_size=batch_size, use_tensorrt=True)
+
+        print("MXNet accuracy: %f" % mx_pct)
+        print("MXNet-TensorRT accuracy: %f" % trt_pct)
+
+        assert abs(mx_pct - trt_pct) < 1e-2, \
+            """Diff. between MXNet & TensorRT accuracy too high:
+               MXNet = %f, TensorRT = %f""" % (mx_pct, trt_pct)
+    finally:
+        mx.contrib.tensorrt.set_use_tensorrt(original_try_value)
+
+
+if __name__ == '__main__':
+    # Direct execution of this file: hand the module over to the nose test runner.
+    import nose
+    nose.runmodule()
diff --git a/tests/python/tensorrt/test_training_warning.py b/tests/python/tensorrt/test_training_warning.py
new file mode 100644
index 00000000000..3008a4234b5
--- /dev/null
+++ b/tests/python/tensorrt/test_training_warning.py
@@ -0,0 +1,65 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import gluoncv
+import mxnet as mx
+
+from tests.python.unittest.common import assertRaises
+
+
+def test_training_without_trt():
+    """Bind the ResNet for training with the TensorRT flag turned off."""
+    run_resnet(is_train=True, use_tensorrt=False)
+
+
+def test_inference_without_trt():
+    """Bind the ResNet for inference with the TensorRT flag turned off."""
+    run_resnet(is_train=False, use_tensorrt=False)
+
+
+def test_training_with_trt():
+    """Binding for training with the TensorRT flag on is expected to raise RuntimeError."""
+    # assertRaises is the helper imported from tests.python.unittest.common;
+    # presumably it calls run_resnet with the given keyword arguments -- confirm.
+    assertRaises(RuntimeError, run_resnet, is_train=True, use_tensorrt=True)
+
+
+def test_inference_with_trt():
+    """Bind the ResNet for inference with the TensorRT flag turned on."""
+    run_resnet(is_train=False, use_tensorrt=True)
+
+
+def run_resnet(is_train, use_tensorrt):
+    original_trt_value = mx.contrib.tensorrt.get_use_tensorrt()
+    try:
+        mx.contrib.tensorrt.set_use_tensorrt(use_tensorrt)
+        ctx = mx.gpu(0)
+        batch_size = 1
+        h = 32
+        w = 32
+        model_name = 'cifar_resnet20_v1'
+        resnet = gluoncv.model_zoo.get_model(model_name, pretrained=True)
+        data = mx.sym.var('data')
+        out = resnet(data)
+        softmax = mx.sym.SoftmaxOutput(out, name='softmax')
+        all_params = dict([(k, v.data()) for k, v in resnet.collect_params().items()])
+        if is_train:
+            grad_req = 'write'
+        else:
+            grad_req = 'null'
+        softmax.simple_bind(ctx=ctx, data=(batch_size, 3, h, w), softmax_label=(batch_size,),
+                            shared_buffer=all_params,  force_rebind=True, grad_req=grad_req)
+    finally:
+        mx.contrib.tensorrt.set_use_tensorrt(original_trt_value)
+
+
+if __name__ == '__main__':
+    # Direct execution of this file: hand the module over to the nose test runner.
+    import nose
+    nose.runmodule()


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services