Posted to commits@mxnet.apache.org by ma...@apache.org on 2018/08/10 09:38:14 UTC

[incubator-mxnet] branch master updated: [MXNET-703] TensorRT runtime integration (#11325)

This is an automated email from the ASF dual-hosted git repository.

marcoabreu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new c053262  [MXNET-703] TensorRT runtime integration (#11325)
c053262 is described below

commit c053262613670f51e7b20f39895c74a8166dcb19
Author: Marek Kolodziej <mk...@gmail.com>
AuthorDate: Fri Aug 10 02:38:04 2018 -0700

    [MXNET-703] TensorRT runtime integration (#11325)
    
    * [MXNET-703] TensorRT runtime integration
    
    Co-authored-by: Clement Fuji-Tsang <ca...@hotmail.com>
    Co-authored-by: Kellen Sunderland <ke...@gmail.com>
    
    * correctly assign self._optimized_symbol in executor
    
    * declare GetTrtCompatibleSubsets and ReplaceSubgraph only if MXNET_USE_TENSORRT
    
    * add comments in ReplaceSubgraph
    
    * Addressing Haibin's code review points
    
    * Check that shared_buffer is not empty when USE_TENSORRT is set
    
    * Added check that TensorRT binding is for inference only
    
    * Removed redundant decl.
    
    * WIP Refactored TRT integration and tests
    
    * Add more build guards, remove unused code
    
    * Remove ccache report
    
    * Remove redundant const in declaration
    
    * Clean Cmake TRT files
    
    * Remove TensorRT env var usage
    
    We don't want to use environment variables with TensorRT yet, the
    logic being that we want to try to have as much forward compatibility as
    possible when working on an experimental feature.  Were we to add
    env vars they would have to be guaranteed to work in the future until
    a major version change.  Moving the functionality to a contrib call
    reduces this risk.
    
    * Use contrib optimize_graph instead of bind
    
    * Clean up cycle detector
    
    * Convert lenet test to contrib optimize
    
    * Protect interface with trt build flag
    
    * Fix whitespace issues
    
    * Add another build guard to c_api
    
    * Move get_optimized_symbol to contrib area
    
    * Ignore gz files in test folder
    
    * Make trt optimization implicit
    
    * Remove unused declaration
    
    * Replace build guards with runtime errors
    
    * Change default value of TensorRT to off
    
    This change applies to both TensorRT and non-TensorRT builds.
    
    * Warn user when TRT not active at runtime
    
    * Move TensorRTBind declaration, add descriptive errors
    
    * Test TensorRT graph execution, fix bugs
    
    * Fix lint and whitespace issues
    
    * Fix typo
    
    * Removed default value for set_use_tensorrt
    
    * Improved documentation and fixed spacing issues
    
    * Move static exec funcs to util files
    
    * Update comments to match util style
    
    * Apply const to loop element
    
    * Fix a few namespace issues
    
    * Make static funcs inline to avoid compiler warning
    
    * Remove unused inference code from lenet5_train
    
    * Add explicit trt contrib bind, update tests to use it
    
    * Rename trt bind call
    
    * Remove documentation that is not needed for trt
    
    * Reorder arguments, allow positional calling
---
 .gitmodules                                        |   3 +
 3rdparty/onnx-tensorrt                             |   1 +
 CMakeLists.txt                                     |  31 ++
 Jenkinsfile                                        |  28 +
 Makefile                                           |   8 +
 amalgamation/amalgamation.py                       |  14 +-
 .../docker/Dockerfile.build.ubuntu_gpu_tensorrt    |  33 +-
 .../__init__.py => ci/docker/install/tensorrt.sh   |  41 +-
 ci/docker/runtime_functions.sh                     |  65 +++
 include/mxnet/c_api.h                              |   7 +
 include/mxnet/executor.h                           |   1 +
 python/mxnet/base.py                               |  16 +
 python/mxnet/contrib/__init__.py                   |   1 +
 python/mxnet/contrib/tensorrt.py                   | 110 ++++
 python/mxnet/executor.py                           |   1 +
 python/mxnet/module/executor_group.py              |   2 +-
 src/c_api/c_api_executor.cc                        |  62 ++-
 src/common/exec_utils.h                            | 255 +++++++++
 src/common/serialization.h                         | 319 +++++++++++
 src/common/utils.h                                 |  31 ++
 src/executor/exec_pass.h                           |  12 +
 src/executor/graph_executor.cc                     | 275 +---------
 src/executor/graph_executor.h                      |  30 +-
 src/executor/onnx_to_tensorrt.cc                   | 148 +++++
 src/executor/onnx_to_tensorrt.h                    |  77 +++
 src/executor/tensorrt_pass.cc                      | 596 +++++++++++++++++++++
 src/executor/trt_graph_executor.cc                 | 450 ++++++++++++++++
 src/executor/trt_graph_executor.h                  | 111 ++++
 src/operator/contrib/nnvm_to_onnx-inl.h            | 156 ++++++
 src/operator/contrib/nnvm_to_onnx.cc               | 527 ++++++++++++++++++
 src/operator/contrib/tensorrt-inl.h                | 113 ++++
 src/operator/contrib/tensorrt.cc                   | 183 +++++++
 src/operator/contrib/tensorrt.cu                   |  73 +++
 tests/.gitignore                                   |   1 +
 tests/cpp/misc/serialization.cc                    |  68 +++
 .../__init__.py => tests/python/tensorrt/common.py |  31 +-
 .../python/tensorrt/lenet5_common.py               |  29 +-
 tests/python/tensorrt/lenet5_train.py              |  84 +++
 tests/python/tensorrt/test_cvnets.py               | 179 +++++++
 tests/python/tensorrt/test_cycle.py                |  69 +++
 tests/python/tensorrt/test_tensorrt_lenet5.py      | 108 ++++
 tests/python/tensorrt/test_training_warning.py     |  70 +++
 42 files changed, 4059 insertions(+), 360 deletions(-)
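
For readers of this change, here is a minimal usage sketch of the Python API added below
in python/mxnet/contrib/tensorrt.py. It is illustrative only: the toy network, shapes and
parameter values are placeholders and are not taken from this commit's tests.

    import mxnet as mx
    from mxnet.contrib import tensorrt as trt

    # Opt into the TensorRT execution path (sets the MXNET_USE_TENSORRT env var).
    trt.set_use_tensorrt(True)

    # A hypothetical toy network; TensorRT binding is inference-only.
    data = mx.sym.Variable('data')
    fc = mx.sym.FullyConnected(data, num_hidden=10, name='fc1')
    net = mx.sym.SoftmaxOutput(fc, name='softmax')

    ctx = mx.gpu(0)
    all_params = {'fc1_weight': mx.nd.random.normal(shape=(10, 100), ctx=ctx),
                  'fc1_bias': mx.nd.zeros((10,), ctx=ctx)}

    # tensorrt_bind forwards the params as shared_buffer so the TRT pass can bake them in.
    executor = trt.tensorrt_bind(net, ctx=ctx, all_params=all_params,
                                 data=(32, 100), softmax_label=(32,), grad_req='null')

    # Inspect what the TensorRT pass produced, then run inference.
    optimized_sym = trt.get_optimized_symbol(executor)
    executor.forward(is_train=False, data=mx.nd.ones((32, 100), ctx=ctx))
    out = executor.outputs[0].asnumpy()
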

diff --git a/.gitmodules b/.gitmodules
index 9aeb1c7..836d824 100644
--- a/.gitmodules
+++ b/.gitmodules
@@ -26,3 +26,6 @@
 [submodule "3rdparty/tvm"]
 	path = 3rdparty/tvm
 	url = https://github.com/dmlc/tvm
+[submodule "3rdparty/onnx-tensorrt"]
+	path = 3rdparty/onnx-tensorrt
+	url = https://github.com/onnx/onnx-tensorrt.git
diff --git a/3rdparty/onnx-tensorrt b/3rdparty/onnx-tensorrt
new file mode 160000
index 0000000..e7be19c
--- /dev/null
+++ b/3rdparty/onnx-tensorrt
@@ -0,0 +1 @@
+Subproject commit e7be19cff377a95817503e8525e20de34cdc574a
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 000bbbf..8c3e635 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -37,6 +37,7 @@ mxnet_option(ENABLE_CUDA_RTC      "Build with CUDA runtime compilation support"
 mxnet_option(BUILD_CPP_EXAMPLES   "Build cpp examples" ON)
 mxnet_option(INSTALL_EXAMPLES     "Install the example source files." OFF)
 mxnet_option(USE_SIGNAL_HANDLER   "Print stack traces on segfaults." OFF)
+mxnet_option(USE_TENSORRT         "Enable inference optimization with TensorRT." OFF)
 
 message(STATUS "CMAKE_SYSTEM_NAME ${CMAKE_SYSTEM_NAME}")
 if(USE_CUDA AND NOT USE_OLDCMAKECUDA)
@@ -185,6 +186,36 @@ if(USE_VTUNE)
   list(APPEND mxnet_LINKER_LIBS dl)
 endif()
 
+if(USE_TENSORRT)
+  message(STATUS "Using TensorRT")
+  set(ONNX_PATH 3rdparty/onnx-tensorrt/third_party/onnx/build/)
+  set(ONNX_TRT_PATH 3rdparty/onnx-tensorrt/build/)
+
+  include_directories(${ONNX_PATH})
+  include_directories(3rdparty/onnx-tensorrt/)
+  include_directories(3rdparty/)
+  add_definitions(-DMXNET_USE_TENSORRT=1)
+  add_definitions(-DONNX_NAMESPACE=onnx)
+
+  find_package(Protobuf REQUIRED)
+
+  find_library(ONNX_LIBRARY NAMES libonnx.so REQUIRED
+          PATHS ${ONNX_PATH}
+          DOC "Path to onnx library.")
+  find_library(ONNX_PROTO_LIBRARY NAMES libonnx_proto.so REQUIRED
+          PATHS ${ONNX_PATH}
+          DOC "Path to onnx_proto library.")
+  find_library(ONNX_TRT_RUNTIME_LIBRARY NAMES libnvonnxparser_runtime.so REQUIRED
+          PATHS ${ONNX_TRT_PATH}
+          DOC "Path to nvonnxparser_runtime library.")
+  find_library(ONNX_TRT_PARSER_LIBRARY NAMES libnvonnxparser.so REQUIRED
+          PATHS ${ONNX_TRT_PATH}
+          DOC "Path to nvonnxparser library.")
+
+  list(APPEND mxnet_LINKER_LIBS libnvinfer.so ${ONNX_TRT_PARSER_LIBRARY} ${ONNX_TRT_RUNTIME_LIBRARY}
+          ${ONNX_PROTO_LIBRARY} ${ONNX_LIBRARY} ${PROTOBUF_LIBRARY})
+endif()
+
 if(USE_MKLDNN)
   include(cmake/MklDnn.cmake)
   # CPU architecture (e.g., C5) can't run on another architecture (e.g., g3).
diff --git a/Jenkinsfile b/Jenkinsfile
index d74f0b4..003e79c 100644
--- a/Jenkinsfile
+++ b/Jenkinsfile
@@ -30,6 +30,7 @@ mx_cmake_lib = 'build/libmxnet.so, build/libmxnet.a, build/3rdparty/dmlc-core/li
 mx_cmake_lib_debug = 'build/libmxnet.so, build/libmxnet.a, build/3rdparty/dmlc-core/libdmlc.a, build/tests/mxnet_unit_tests'
 mx_cmake_mkldnn_lib = 'build/libmxnet.so, build/libmxnet.a, build/3rdparty/dmlc-core/libdmlc.a, build/tests/mxnet_unit_tests, build/3rdparty/openmp/runtime/src/libomp.so, build/3rdparty/mkldnn/src/libmkldnn.so.0'
 mx_mkldnn_lib = 'lib/libmxnet.so, lib/libmxnet.a, lib/libiomp5.so, lib/libmkldnn.so.0, lib/libmklml_intel.so, 3rdparty/dmlc-core/libdmlc.a, 3rdparty/tvm/nnvm/lib/libnnvm.a'
+mx_tensorrt_lib = 'lib/libmxnet.so, lib/libnvonnxparser_runtime.so.0, lib/libnvonnxparser.so.0, lib/libonnx_proto.so, lib/libonnx.so'
 // timeout in minutes
 max_time = 120
 
@@ -301,6 +302,17 @@ core_logic: {
         }
       }
     },
+    'TensorRT': {
+      node(NODE_LINUX_CPU) {
+        ws('workspace/build-tensorrt') {
+          timeout(time: max_time, unit: 'MINUTES') {
+            utils.init_git()
+            utils.docker_run('ubuntu_gpu_tensorrt', 'build_ubuntu_gpu_tensorrt', false)
+            utils.pack_lib('tensorrt', mx_tensorrt_lib)
+          }
+        }
+      }
+    },
     'Build CPU windows':{
       node(NODE_WINDOWS_CPU) {
         timeout(time: max_time, unit: 'MINUTES') {
@@ -616,6 +628,22 @@ core_logic: {
         }
       }
     },
+    'Python3: TensorRT GPU': {
+      node(NODE_LINUX_GPU_P3) {
+        ws('workspace/build-tensorrt') {
+          timeout(time: max_time, unit: 'MINUTES') {
+            try {
+              utils.init_git()
+              utils.unpack_lib('tensorrt', mx_tensorrt_lib)
+              utils.docker_run('ubuntu_gpu_tensorrt', 'unittest_ubuntu_tensorrt_gpu', true)
+              utils.publish_test_coverage()
+            } finally {
+              utils.collect_test_results_unix('nosetests_tensorrt.xml', 'nosetests_python3_tensorrt_gpu.xml')
+            }
+          }
+        }
+      }
+    },
     'Scala: CPU': {
       node(NODE_LINUX_CPU) {
         ws('workspace/ut-scala-cpu') {
diff --git a/Makefile b/Makefile
index 18661aa..6e93567 100644
--- a/Makefile
+++ b/Makefile
@@ -91,6 +91,14 @@ else
 endif
 CFLAGS += -I$(TPARTYDIR)/mshadow/ -I$(TPARTYDIR)/dmlc-core/include -fPIC -I$(NNVM_PATH)/include -I$(DLPACK_PATH)/include -I$(TPARTYDIR)/tvm/include -Iinclude $(MSHADOW_CFLAGS)
 LDFLAGS = -pthread $(MSHADOW_LDFLAGS) $(DMLC_LDFLAGS)
+
+
+ifeq ($(USE_TENSORRT), 1)
+	CFLAGS +=  -I$(ROOTDIR) -I$(TPARTYDIR) -DONNX_NAMESPACE=$(ONNX_NAMESPACE) -DMXNET_USE_TENSORRT=1
+	LDFLAGS += -lprotobuf -pthread -lonnx -lonnx_proto -lnvonnxparser -lnvonnxparser_runtime -lnvinfer -lnvinfer_plugin
+endif
+# -L/usr/local/lib
+
 ifeq ($(DEBUG), 1)
 	NVCCFLAGS += -std=c++11 -Xcompiler -D_FORCE_INLINES -g -G -O0 -ccbin $(CXX) $(MSHADOW_NVCCFLAGS)
 else
diff --git a/amalgamation/amalgamation.py b/amalgamation/amalgamation.py
index 52d775b..a3c28f7 100644
--- a/amalgamation/amalgamation.py
+++ b/amalgamation/amalgamation.py
@@ -23,13 +23,12 @@ from io import BytesIO, StringIO
 import platform
 
 blacklist = [
-    'Windows.h', 'cublas_v2.h', 'cuda/tensor_gpu-inl.cuh',
-    'cuda_runtime.h', 'cudnn.h', 'cudnn_lrn-inl.h', 'curand.h', 'curand_kernel.h',
-    'glog/logging.h', 'io/azure_filesys.h', 'io/hdfs_filesys.h', 'io/s3_filesys.h',
-    'kvstore_dist.h', 'mach/clock.h', 'mach/mach.h',
-    'malloc.h', 'mkl.h', 'mkl_cblas.h', 'mkl_vsl.h', 'mkl_vsl_functions.h',
-    'nvml.h', 'opencv2/opencv.hpp', 'sys/stat.h', 'sys/types.h', 'cuda.h', 'cuda_fp16.h',
-    'omp.h', 'execinfo.h', 'packet/sse-inl.h', 'emmintrin.h', 'thrust/device_vector.h',
+    'Windows.h', 'cublas_v2.h', 'cuda/tensor_gpu-inl.cuh', 'cuda_runtime.h', 'cudnn.h',
+    'cudnn_lrn-inl.h', 'curand.h', 'curand_kernel.h', 'glog/logging.h', 'io/azure_filesys.h',
+    'io/hdfs_filesys.h', 'io/s3_filesys.h', 'kvstore_dist.h', 'mach/clock.h', 'mach/mach.h',
+    'malloc.h', 'mkl.h', 'mkl_cblas.h', 'mkl_vsl.h', 'mkl_vsl_functions.h', 'NvInfer.h', 'nvml.h',
+    'opencv2/opencv.hpp', 'sys/stat.h', 'sys/types.h', 'cuda.h', 'cuda_fp16.h', 'omp.h',
+    'onnx/onnx.pb.h', 'execinfo.h', 'packet/sse-inl.h', 'emmintrin.h', 'thrust/device_vector.h',
     'cusolverDn.h', 'internal/concurrentqueue_internal_debug.h', 'relacy/relacy_std.hpp',
     'relacy_shims.h', 'ittnotify.h', 'shared_mutex'
     ]
@@ -150,6 +149,7 @@ def expand(x, pending, stage):
                     h not in sysheaders and
                     'mkl' not in h and
                     'nnpack' not in h and
+                    'tensorrt' not in h and
                     not h.endswith('.cuh')): sysheaders.append(h)
             else:
                 expand.treeDepth += 1
diff --git a/python/mxnet/contrib/__init__.py b/ci/docker/Dockerfile.build.ubuntu_gpu_tensorrt
old mode 100644
new mode 100755
similarity index 58%
copy from python/mxnet/contrib/__init__.py
copy to ci/docker/Dockerfile.build.ubuntu_gpu_tensorrt
index fbfd346..255da31
--- a/python/mxnet/contrib/__init__.py
+++ b/ci/docker/Dockerfile.build.ubuntu_gpu_tensorrt
@@ -1,3 +1,4 @@
+# -*- mode: dockerfile -*-
 # Licensed to the Apache Software Foundation (ASF) under one
 # or more contributor license agreements.  See the NOTICE file
 # distributed with this work for additional information
@@ -14,21 +15,27 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+#
+# Dockerfile to run MXNet on Ubuntu 16.04 for GPU with TensorRT
+
+FROM nvidia/cuda:9.0-cudnn7-devel
 
-# coding: utf-8
-"""Experimental contributions"""
+WORKDIR /work/deps
 
-from . import symbol
-from . import ndarray
+COPY install/ubuntu_core.sh /work/
+RUN /work/ubuntu_core.sh
+COPY install/deb_ubuntu_ccache.sh /work/
+RUN /work/deb_ubuntu_ccache.sh
+COPY install/ubuntu_python.sh /work/
+RUN /work/ubuntu_python.sh
+COPY install/tensorrt.sh /work
+RUN /work/tensorrt.sh
 
-from . import symbol as sym
-from . import ndarray as nd
+ARG USER_ID=0
+COPY install/ubuntu_adduser.sh /work/
+RUN /work/ubuntu_adduser.sh
 
-from . import autograd
-from . import tensorboard
+COPY runtime_functions.sh /work/
 
-from . import text
-from . import onnx
-from . import io
-from . import quantization
-from . import quantization as quant
+WORKDIR /work/mxnet
+ENV LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/usr/local/lib
diff --git a/python/mxnet/contrib/__init__.py b/ci/docker/install/tensorrt.sh
old mode 100644
new mode 100755
similarity index 50%
copy from python/mxnet/contrib/__init__.py
copy to ci/docker/install/tensorrt.sh
index fbfd346..a6258d9
--- a/python/mxnet/contrib/__init__.py
+++ b/ci/docker/install/tensorrt.sh
@@ -1,3 +1,5 @@
+#!/bin/bash
+
 # Licensed to the Apache Software Foundation (ASF) under one
 # or more contributor license agreements.  See the NOTICE file
 # distributed with this work for additional information
@@ -15,20 +17,29 @@
 # specific language governing permissions and limitations
 # under the License.
 
-# coding: utf-8
-"""Experimental contributions"""
-
-from . import symbol
-from . import ndarray
-
-from . import symbol as sym
-from . import ndarray as nd
+# Install gluoncv since we're testing Gluon models as well
+pip2 install gluoncv==0.2.0
+pip3 install gluoncv==0.2.0
 
-from . import autograd
-from . import tensorboard
+# Install Protobuf
+# Install protoc 3.5 and build protobuf here (for onnx and onnx-tensorrt)
+pushd .
+cd ..
+apt-get update
+apt-get install -y automake libtool
+git clone --recursive -b 3.5.1.1 https://github.com/google/protobuf.git
+cd protobuf
+./autogen.sh
+./configure
+make -j$(nproc)
+make install
+ldconfig
+popd
 
-from . import text
-from . import onnx
-from . import io
-from . import quantization
-from . import quantization as quant
+# Install TensorRT
+echo "TensorRT build enabled. Installing TensorRT."
+wget -qO tensorrt.deb https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1604/x86_64/nvinfer-runtime-trt-repo-ubuntu1604-4.0.1-ga-cuda9.0_1-1_amd64.deb
+dpkg -i tensorrt.deb
+apt-get update
+apt-get install -y --allow-downgrades libnvinfer-dev
+rm tensorrt.deb
diff --git a/ci/docker/runtime_functions.sh b/ci/docker/runtime_functions.sh
index 1c861be..815eae9 100755
--- a/ci/docker/runtime_functions.sh
+++ b/ci/docker/runtime_functions.sh
@@ -414,6 +414,60 @@ build_ubuntu_gpu() {
     build_ubuntu_gpu_cuda91_cudnn7
 }
 
+build_ubuntu_gpu_tensorrt() {
+
+    set -ex
+
+    build_ccache_wrappers
+
+    # Build ONNX
+    pushd .
+    echo "Installing ONNX."
+    cd 3rdparty/onnx-tensorrt/third_party/onnx
+    rm -rf build
+    mkdir -p build
+    cd build
+    cmake \
+        -DCMAKE_CXX_FLAGS=-I/usr/include/python${PYVER}\
+        -DBUILD_SHARED_LIBS=ON ..\
+        -G Ninja
+    ninja -v
+    export LIBRARY_PATH=`pwd`:`pwd`/onnx/:$LIBRARY_PATH
+    export CPLUS_INCLUDE_PATH=`pwd`:$CPLUS_INCLUDE_PATH
+    popd
+
+    # Build ONNX-TensorRT
+    pushd .
+    cd 3rdparty/onnx-tensorrt/
+    mkdir -p build
+    cd build
+    cmake ..
+    make -j$(nproc)
+    export LIBRARY_PATH=`pwd`:$LIBRARY_PATH
+    popd
+
+    mkdir -p /work/mxnet/lib/
+    cp 3rdparty/onnx-tensorrt/third_party/onnx/build/*.so /work/mxnet/lib/
+    cp -L 3rdparty/onnx-tensorrt/build/libnvonnxparser_runtime.so.0 /work/mxnet/lib/
+    cp -L 3rdparty/onnx-tensorrt/build/libnvonnxparser.so.0 /work/mxnet/lib/
+
+    rm -rf build
+    make \
+        DEV=1                                               \
+        USE_BLAS=openblas                                   \
+        USE_CUDA=1                                          \
+        USE_CUDA_PATH=/usr/local/cuda                       \
+        USE_CUDNN=1                                         \
+        USE_OPENCV=0                                        \
+        USE_DIST_KVSTORE=0                                  \
+        USE_TENSORRT=1                                      \
+        USE_JEMALLOC=0                                      \
+        USE_GPERFTOOLS=0                                    \
+        ONNX_NAMESPACE=onnx                                 \
+        CUDA_ARCH="-gencode arch=compute_70,code=compute_70"\
+        -j$(nproc)
+}
+
 build_ubuntu_gpu_mkldnn() {
     set -ex
 
@@ -610,6 +664,15 @@ unittest_ubuntu_python3_gpu_nocudnn() {
     nosetests-3.4 $NOSE_COVERAGE_ARGUMENTS --with-xunit --xunit-file nosetests_gpu.xml --verbose tests/python/gpu
 }
 
+unittest_ubuntu_tensorrt_gpu() {
+    set -ex
+    export PYTHONPATH=./python/
+    export MXNET_STORAGE_FALLBACK_LOG_VERBOSE=0
+    export LD_LIBRARY_PATH=/work/mxnet/lib:$LD_LIBRARY_PATH
+    python tests/python/tensorrt/lenet5_train.py
+    nosetests-3.4 $NOSE_COVERAGE_ARGUMENTS --with-xunit --xunit-file nosetests_trt_gpu.xml --verbose tests/python/tensorrt/
+}
+
 # quantization gpu currently only runs on P3 instances
 # need to separate it from unittest_ubuntu_python2_gpu()
 unittest_ubuntu_python2_quantization_gpu() {
@@ -961,3 +1024,5 @@ EOF
     declare -F | cut -d' ' -f3
     echo
 fi
+
+
diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h
index 6bbe9df..43f8227 100644
--- a/include/mxnet/c_api.h
+++ b/include/mxnet/c_api.h
@@ -1761,6 +1761,13 @@ MXNET_DLL int MXExecutorReshape(int partial_shaping,
                                 NDArrayHandle** aux_states,
                                 ExecutorHandle shared_exec,
                                 ExecutorHandle *out);
+
+/*!
+ * \brief get optimized graph from graph executor
+ */
+MXNET_DLL int MXExecutorGetOptimizedSymbol(ExecutorHandle handle,
+                                           SymbolHandle *out);
+
 /*!
  * \brief set a call back to notify the completion of operation
  */
diff --git a/include/mxnet/executor.h b/include/mxnet/executor.h
index 842653f..0ab04b8 100644
--- a/include/mxnet/executor.h
+++ b/include/mxnet/executor.h
@@ -166,6 +166,7 @@ class Executor {
                               std::unordered_map<std::string, NDArray>*
                                 shared_data_arrays = nullptr,
                               Executor* shared_exec = nullptr);
+
   /*!
    * \brief the prototype of user-defined monitor callback
    */
diff --git a/python/mxnet/base.py b/python/mxnet/base.py
index 3d8ee01..1bbc121 100644
--- a/python/mxnet/base.py
+++ b/python/mxnet/base.py
@@ -729,3 +729,19 @@ def _generate_op_module_signature(root_namespace, module_name, op_code_gen_func)
     module_op_file.close()
     write_all_str(module_internal_file, module_internal_all)
     module_internal_file.close()
+
+def cint(init_val=0):
+    """create a C int with an optional initial value"""
+    return C.c_int(init_val)
+
+def int_addr(x):
+    """given a c_int, return it's address as an int ptr"""
+    x_addr = C.addressof(x)
+    int_p = C.POINTER(C.c_int)
+    x_int_addr = C.cast(x_addr, int_p)
+    return x_int_addr
+
+def checked_call(f, *args):
+    """call a cuda function and check for success"""
+    error_t = f(*args)
+    assert error_t == 0, "Failing cuda call %s returns %s." % (f.__name__, error_t)
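
A hedged sketch of how these new ctypes helpers might be exercised; the libcudart handle and
the cudaGetDeviceCount call are assumptions about the local CUDA installation, not part of
this commit:

    import ctypes
    from mxnet.base import cint, int_addr, checked_call

    # Assumes the CUDA runtime shared library is discoverable by the dynamic loader.
    cudart = ctypes.CDLL('libcudart.so')

    count = cint()  # a C int initialised to 0
    # cudaGetDeviceCount(int*) returns 0 (cudaSuccess) on success; checked_call asserts that.
    checked_call(cudart.cudaGetDeviceCount, int_addr(count))
    print('visible CUDA devices:', count.value)
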
diff --git a/python/mxnet/contrib/__init__.py b/python/mxnet/contrib/__init__.py
index fbfd346..606bb0a 100644
--- a/python/mxnet/contrib/__init__.py
+++ b/python/mxnet/contrib/__init__.py
@@ -32,3 +32,4 @@ from . import onnx
 from . import io
 from . import quantization
 from . import quantization as quant
+from . import tensorrt
diff --git a/python/mxnet/contrib/tensorrt.py b/python/mxnet/contrib/tensorrt.py
new file mode 100644
index 0000000..bb20767
--- /dev/null
+++ b/python/mxnet/contrib/tensorrt.py
@@ -0,0 +1,110 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+""" Module to enable the use of TensorRT optimized graphs."""
+
+import ctypes
+import logging
+import os
+
+from mxnet.symbol import Symbol
+
+from ..base import _LIB, SymbolHandle, MXNetError
+from ..base import check_call
+
+
+def set_use_tensorrt(status):
+    """
+    Set an environment variable which will enable or disable the use of TensorRT in the backend.
+    Note: this is useful for A/B testing purposes.
+    :param status: Boolean, True if TensorRT optimization should be applied, False for legacy
+    behaviour.
+    """
+    os.environ["MXNET_USE_TENSORRT"] = str(int(status))
+
+
+def get_use_tensorrt():
+    """
+    Get an environment variable which describes if TensorRT is currently enabled in the backend.
+    Note: this is useful for A/B testing purposes.
+    :return: Boolean, True if TensorRT optimization should be applied, False for legacy
+    behaviour.
+    """
+    return bool(int(os.environ.get("MXNET_USE_TENSORRT", 0)) == 1)
+
+
+def get_optimized_symbol(executor):
+    """
+    Take an executor's underlying symbol graph and return its generated optimized version.
+
+    Parameters
+    ----------
+    executor :
+        An executor for which you want to see an optimized symbol. Getting an optimized symbol
+        is useful to compare and verify the work TensorRT has done against a legacy behaviour.
+
+    Returns
+    -------
+    symbol : nnvm::Symbol
+        The optimized nnvm symbol.
+    """
+    handle = SymbolHandle()
+    try:
+        check_call(_LIB.MXExecutorGetOptimizedSymbol(executor.handle, ctypes.byref(handle)))
+        result = Symbol(handle=handle)
+        return result
+    except MXNetError:
+        logging.error('Error while trying to fetch TRT optimized symbol for graph. Please ensure '
+                      'build was compiled with MXNET_USE_TENSORRT enabled.')
+        raise
+
+
+def tensorrt_bind(symbol, ctx, all_params, type_dict=None, stype_dict=None, group2ctx=None,
+                  **kwargs):
+    """Bind current symbol to get an optimized trt executor.
+
+    Parameters
+    ----------
+    symbol : Symbol
+        The symbol you wish to bind, and optimize with TensorRT.
+
+    ctx : Context
+        The device context the generated executor to run on.
+
+    all_params : Dict of str->ndarray
+        A dictionary of mappings from parameter names to parameter NDArrays.
+
+    type_dict  : Dict of str->numpy.dtype
+        Input type dictionary, name->dtype
+
+    stype_dict  : Dict of str->str
+        Input storage type dictionary, name->storage_type
+
+    group2ctx : Dict of string to mx.Context
+        The dict mapping the `ctx_group` attribute to the context assignment.
+
+    kwargs : Dict of str->shape
+        Input shape dictionary, name->shape
+
+    Returns
+    -------
+    executor : mxnet.Executor
+        An optimized TensorRT executor.
+    """
+    kwargs['shared_buffer'] = all_params
+    return symbol.simple_bind(ctx, type_dict=type_dict, stype_dict=stype_dict,
+                              group2ctx=group2ctx, **kwargs)
diff --git a/python/mxnet/executor.py b/python/mxnet/executor.py
index c0272c5..fcd5406 100644
--- a/python/mxnet/executor.py
+++ b/python/mxnet/executor.py
@@ -73,6 +73,7 @@ class Executor(object):
         self.aux_arrays = []
         self.outputs = self._get_outputs()
         self._symbol = copy.deepcopy(symbol)
+        self._optimized_symbol = None
         self._arg_dict = None
         self._grad_dict = None
         self._aux_dict = None
diff --git a/python/mxnet/module/executor_group.py b/python/mxnet/module/executor_group.py
index 5d8e950..c405069 100755
--- a/python/mxnet/module/executor_group.py
+++ b/python/mxnet/module/executor_group.py
@@ -592,8 +592,8 @@ class DataParallelExecutorGroup(object):
                     # pylint: disable=no-member
                     og_my_slice = nd.slice_axis(grad, axis=axis, begin=islice.start,
                                                 end=islice.stop)
-                    # pylint: enable=no-member
                     out_grads_slice.append(og_my_slice.as_in_context(self.contexts[i]))
+                    # pylint: enable=no-member
                 else:
                     out_grads_slice.append(grad.copyto(self.contexts[i]))
             exec_.backward(out_grads=out_grads_slice)
diff --git a/src/c_api/c_api_executor.cc b/src/c_api/c_api_executor.cc
index 09bc239..b993505 100644
--- a/src/c_api/c_api_executor.cc
+++ b/src/c_api/c_api_executor.cc
@@ -26,6 +26,10 @@
 #include <mxnet/c_api.h>
 #include <mxnet/executor.h>
 #include "./c_api_common.h"
+#include "../executor/graph_executor.h"
+#if MXNET_USE_TENSORRT
+#include "../executor/trt_graph_executor.h"
+#endif  // MXNET_USE_TENSORRT
 
 int MXExecutorPrint(ExecutorHandle handle, const char **out_str) {
   Executor *exec = static_cast<Executor*>(handle);
@@ -439,13 +443,38 @@ int MXExecutorSimpleBind(SymbolHandle symbol_handle,
   std::vector<NDArray> in_arg_vec;
   std::vector<NDArray> arg_grad_vec;
   std::vector<NDArray> aux_state_vec;
-
-  *out = Executor::SimpleBind(*sym, ctx, ctx_map, in_arg_ctx_vec, arg_grad_ctx_vec,
-                              aux_state_ctx_vec, arg_shape_map, arg_dtype_map, arg_stype_map,
-                              grad_req_type_vec, shared_arg_name_set, &in_arg_vec,
-                              &arg_grad_vec, &aux_state_vec,
-                              use_shared_buffer ? &shared_buffer_map : nullptr,
-                              reinterpret_cast<Executor*>(shared_exec_handle));
+#if MXNET_USE_TENSORRT
+  // If we've built with TensorRT support we by default return a TrtGraphExecutor.
+  // Users can override this behaviour via env var, which is useful for example for A/B
+  // performance testing.
+  if (dmlc::GetEnv("MXNET_USE_TENSORRT", false)) {
+    *out = exec::TrtGraphExecutor::TensorRTBind(*sym, ctx, ctx_map, &in_arg_ctx_vec,
+                                                &arg_grad_ctx_vec, &aux_state_ctx_vec,
+                                                &arg_shape_map, &arg_dtype_map, &arg_stype_map,
+                                                &grad_req_type_vec, shared_arg_name_set,
+                                                &in_arg_vec, &arg_grad_vec, &aux_state_vec,
+                                                use_shared_buffer ? &shared_buffer_map : nullptr,
+                                                reinterpret_cast<Executor*>(shared_exec_handle));
+  } else {
+    // Checks to see if this env var has been set to true or false by the user.
+    // If the user is using a TensorRT build, but has not enabled TRT at inference time, warn
+    // them and describe further steps.
+    const int unset_indicator =  std::numeric_limits<int>::quiet_NaN();
+    if (dmlc::GetEnv("MXNET_USE_TENSORRT", unset_indicator) == unset_indicator) {
+      LOG(INFO) << "TensorRT not enabled by default.  Please set the MXNET_USE_TENSORRT "
+                   "environment variable to 1 or call mx.contrib.tensorrt.set_use_tensorrt(True) "
+                   "to enable.";
+    }
+#endif  // MXNET_USE_TENSORRT
+    *out = Executor::SimpleBind(*sym, ctx, ctx_map, in_arg_ctx_vec, arg_grad_ctx_vec,
+                                aux_state_ctx_vec, arg_shape_map, arg_dtype_map, arg_stype_map,
+                                grad_req_type_vec, shared_arg_name_set, &in_arg_vec,
+                                &arg_grad_vec, &aux_state_vec,
+                                use_shared_buffer ? &shared_buffer_map : nullptr,
+                                reinterpret_cast<Executor*>(shared_exec_handle));
+#if MXNET_USE_TENSORRT
+  }
+#endif  // MXNET_USE_TENSORRT
 
   // copy ndarray ptrs to ret->handles so that front end
   // can access them
@@ -597,6 +626,25 @@ int MXExecutorReshape(int partial_shaping,
   API_END_HANDLE_ERROR(delete out);
 }
 
+int MXExecutorGetOptimizedSymbol(ExecutorHandle handle,
+                                 SymbolHandle *out) {
+  auto s = new nnvm::Symbol();
+  API_BEGIN();
+
+#if MXNET_USE_TENSORRT
+  auto exec = static_cast<exec::TrtGraphExecutor*>(handle);
+  *s = exec->GetOptimizedSymbol();
+  *out = s;
+#else
+  LOG(FATAL) << "GetOptimizedSymbol may only be used when MXNet is compiled with "
+                "MXNET_USE_TENSORRT enabled.  Please re-compile MXNet with TensorRT support.";
+#endif  // MXNET_USE_TENSORRT
+
+  API_END_HANDLE_ERROR(delete s);
+}
+
+
+
 int MXExecutorSetMonitorCallback(ExecutorHandle handle,
                                  ExecutorMonitorCallback callback,
                                  void* callback_handle) {
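
The dispatch above is what "Make trt optimization implicit" in the commit message refers to:
a plain simple_bind takes the TensorRT path whenever MXNET_USE_TENSORRT is set, which makes
A/B comparison of the two executors straightforward. A hedged Python sketch (the toy symbol,
shapes and parameter values are placeholders, not from this commit's tests):

    import mxnet as mx
    from mxnet.contrib import tensorrt as trt

    data = mx.sym.Variable('data')
    net = mx.sym.SoftmaxOutput(mx.sym.FullyConnected(data, num_hidden=10, name='fc1'),
                               name='softmax')
    ctx = mx.gpu(0)
    params = {'fc1_weight': mx.nd.random.normal(shape=(10, 100), ctx=ctx),
              'fc1_bias': mx.nd.zeros((10,), ctx=ctx)}

    # Same bind call, two execution paths, selected purely by the env var.
    trt.set_use_tensorrt(True)
    exe_trt = net.simple_bind(ctx=ctx, data=(32, 100), grad_req='null', shared_buffer=params)

    trt.set_use_tensorrt(False)
    exe_legacy = net.simple_bind(ctx=ctx, data=(32, 100), grad_req='null', shared_buffer=params)
    # exe_trt.outputs and exe_legacy.outputs can now be compared for accuracy and speed.
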
diff --git a/src/common/exec_utils.h b/src/common/exec_utils.h
index 816599b..fbe5442 100644
--- a/src/common/exec_utils.h
+++ b/src/common/exec_utils.h
@@ -24,10 +24,14 @@
 #ifndef MXNET_COMMON_EXEC_UTILS_H_
 #define MXNET_COMMON_EXEC_UTILS_H_
 
+#include <nnvm/graph.h>
+#include <nnvm/pass_functions.h>
+#include <map>
 #include <vector>
 #include <string>
 #include <utility>
 #include "../common/utils.h"
+#include "../executor/exec_pass.h"
 
 namespace mxnet {
 namespace common {
@@ -366,6 +370,257 @@ inline void LogInferStorage(const nnvm::Graph& g) {
   }
 }
 
+// prints a helpful message after shape inference errors in executor.
+inline void HandleInferShapeError(const size_t num_forward_inputs,
+                                  const nnvm::IndexedGraph& idx,
+                                  const nnvm::ShapeVector& inferred_shapes) {
+  int cnt = 10;
+  std::ostringstream oss;
+  for (size_t i = 0; i < num_forward_inputs; ++i) {
+    const uint32_t nid = idx.input_nodes().at(i);
+    const uint32_t eid = idx.entry_id(nid, 0);
+    const TShape& inferred_shape = inferred_shapes[eid];
+    if (inferred_shape.ndim() == 0 || inferred_shape.Size() == 0U) {
+      const std::string& arg_name = idx[nid].source->attrs.name;
+      oss << arg_name << ": " << inferred_shape << ", ";
+      if (--cnt == 0) {
+        oss << "...";
+        break;
+      }
+    }
+  }
+  LOG(FATAL) << "InferShape pass cannot decide shapes for the following arguments "
+                "(0s means unknown dimensions). Please consider providing them as inputs:\n"
+             << oss.str();
+}
+
+// prints a helpful message after type inference errors in executor.
+inline void HandleInferTypeError(const size_t num_forward_inputs,
+                                 const nnvm::IndexedGraph& idx,
+                                 const nnvm::DTypeVector& inferred_dtypes) {
+  int cnt = 10;
+  std::ostringstream oss;
+  for (size_t i = 0; i < num_forward_inputs; ++i) {
+    const uint32_t nid = idx.input_nodes().at(i);
+    const uint32_t eid = idx.entry_id(nid, 0);
+    const int inferred_dtype = inferred_dtypes[eid];
+    if (inferred_dtype == -1) {
+      const std::string& arg_name = idx[nid].source->attrs.name;
+      oss << arg_name << ": " << inferred_dtype << ", ";
+      if (--cnt == 0) {
+        oss << "...";
+        break;
+      }
+    }
+  }
+  LOG(FATAL) << "InferType pass cannot decide dtypes for the following arguments "
+                "(-1 means unknown dtype). Please consider providing them as inputs:\n"
+             << oss.str();
+}
+
+// prints a helpful message after storage type checking errors in executor.
+inline void HandleInferStorageTypeError(const size_t num_forward_inputs,
+                                        const nnvm::IndexedGraph& idx,
+                                        const StorageTypeVector& inferred_stypes) {
+  int cnt = 10;
+  std::ostringstream oss;
+  for (size_t i = 0; i < num_forward_inputs; ++i) {
+    const uint32_t nid = idx.input_nodes().at(i);
+    const uint32_t eid = idx.entry_id(nid, 0);
+    const int inferred_stype = inferred_stypes[eid];
+    if (inferred_stype == -1) {
+      const std::string& arg_name = idx[nid].source->attrs.name;
+      oss << arg_name << ": " << common::stype_string(inferred_stype) << ", ";
+      if (--cnt == 0) {
+        oss << "...";
+        break;
+      }
+    }
+  }
+  LOG(FATAL) << "InferStorageType pass cannot decide storage type for the following arguments "
+                "(-1 means unknown stype). Please consider providing them as inputs:\n"
+             << oss.str();
+}
+
+/*!
+ * \brief If the requested ndarray's shape size is less than
+ * the corresponding shared_data_array's shape size and the
+ * storage type is shareable, reuse the memory allocation
+ * in shared_buffer; otherwise, create a zero ndarray.
+ * Shareable storages include both default storage and row_sparse storage
+ * if enable_row_sparse_sharing is `True`, otherwise default storage only.
+ */
+inline NDArray ReshapeOrCreate(const std::string& name,
+                               const TShape& dest_arg_shape,
+                               const int dest_arg_dtype,
+                               const NDArrayStorageType dest_arg_stype,
+                               const Context& ctx,
+                               std::unordered_map<std::string, NDArray>* shared_buffer,
+                               bool enable_row_sparse_sharing) {
+  bool stype_shareable = dest_arg_stype == kDefaultStorage;
+  if (enable_row_sparse_sharing) {
+    stype_shareable = stype_shareable || dest_arg_stype == kRowSparseStorage;
+  }
+  auto it = shared_buffer->find(name);
+  if (it != shared_buffer->end()) {
+    // check if size is large enough for sharing
+    bool size_shareable = it->second.shape().Size() >= dest_arg_shape.Size();
+    if (size_shareable && stype_shareable) {  // memory can be reused
+      CHECK_EQ(it->second.dtype(), dest_arg_dtype)
+          << "Requested arg array's dtype does not match that of the reusable ndarray";
+      CHECK_EQ(it->second.storage_type(), dest_arg_stype)
+          << "Requested arg array's stype does not match that of the reusable ndarray";
+      return it->second.Reshape(dest_arg_shape);
+    } else if (stype_shareable) {
+      LOG(WARNING) << "Bucketing: data " << name << " has a shape " << dest_arg_shape
+                   << ", which is larger than already allocated shape " << it->second.shape()
+                   << ". Need to re-allocate. Consider putting default bucket key to be "
+                   << "the bucket taking the largest input for better memory sharing.";
+      // size is not large enough, creating a larger one for sharing
+      // the NDArrays in shared_buffer are guaranteed to be of shareable storages
+      it->second = InitZeros(dest_arg_stype, dest_arg_shape, ctx, dest_arg_dtype);
+      return it->second;
+    } else {
+      // not shareable storage
+      return InitZeros(dest_arg_stype, dest_arg_shape, ctx, dest_arg_dtype);
+    }
+  } else {
+    auto ret = InitZeros(dest_arg_stype, dest_arg_shape, ctx, dest_arg_dtype);
+    if (stype_shareable) {
+      shared_buffer->emplace(name, ret);
+    }
+    return ret;
+  }  // if (it != shared_buffer->end())
+}
+
+/*!
+ * \brief Assign context to the graph.
+ * This is triggered by both simple_bind and bind flows.
+ */
+inline nnvm::Graph AssignContext(nnvm::Graph g,
+                                 const Context& default_ctx,
+                                 const std::map<std::string, Context>& ctx_map,
+                                 const std::vector<Context>& in_arg_ctxes,
+                                 const std::vector<Context>& arg_grad_ctxes,
+                                 const std::vector<Context>& aux_state_ctxes,
+                                 const std::vector<OpReqType>& grad_req_types,
+                                 size_t num_forward_inputs,
+                                 size_t num_forward_outputs) {
+  const auto& idx = g.indexed_graph();
+  const auto& mutable_nodes = idx.mutable_input_nodes();
+  // default use default context.
+  if (ctx_map.size() == 0) {
+    g.attrs["context"] = std::make_shared<nnvm::any>(
+        exec::ContextVector(idx.num_nodes(), default_ctx));
+    for (const auto& x : in_arg_ctxes) {
+      CHECK(x == default_ctx)
+          << "Input array is in " << x << " while binding with ctx=" << default_ctx
+          << ". All arguments must be in global context (" << default_ctx
+          << ") unless group2ctx is specified for cross-device graph.";
+    }
+    for (const auto& x : arg_grad_ctxes) {
+      CHECK(x == default_ctx)
+          << "Gradient array is in " << x << " while binding with ctx="
+          << default_ctx << ". All gradients must be in global context (" << default_ctx
+          << ") unless group2ctx is specified for cross-device graph.";
+    }
+    return g;
+  }
+
+  // otherwise, use context assignment.
+  std::map<Context, int> ctx2id;  // map ctx to device id
+  std::vector<Context> ctx_list;  // index is device id
+  nnvm::DeviceVector device(idx.num_nodes(), -1);  // index is node id
+  nnvm::DeviceAssignMap device_map;  // map arg name to device id
+
+  // loop through the user input ctx_map and
+  // populate maps and lists
+  for (auto &kv : ctx_map) {
+    if (ctx2id.count(kv.second) == 0) {  // if context has no device id, create one
+      ctx2id[kv.second] = static_cast<int>(ctx_list.size());  // assign device id to ctx
+      ctx_list.push_back(kv.second);  // save ctx to the list
+    }
+    // assign device id to the arg name with the corresponding ctx
+    device_map[kv.first] = ctx2id.at(kv.second);
+  }
+
+  // loop through all the rest of input nodes not specified
+  // in the ctx_map and populate maps and lists
+  size_t arg_top = 0, aux_top = 0;
+  for (size_t i = 0; i < num_forward_inputs; ++i) {
+    const uint32_t nid = idx.input_nodes().at(i);
+    Context ctx;
+    if (mutable_nodes.count(nid)) {  // aux node is mutable
+      CHECK_LT(aux_top, aux_state_ctxes.size());
+      ctx = aux_state_ctxes[aux_top];
+      ++aux_top;
+    } else {  // regular input node is immutable
+      CHECK_LT(arg_top, in_arg_ctxes.size());
+      ctx = in_arg_ctxes[arg_top];
+      ++arg_top;
+    }
+    if (ctx2id.count(ctx) == 0) {  // if the current ctx is not in the map of ctx and device id
+      ctx2id[ctx] = static_cast<int>(ctx_list.size());  // assign the current ctx with device id
+      ctx_list.push_back(ctx);  // save the current ctx in the list
+    }
+    device[nid] = ctx2id.at(ctx);  // assign device id to the current node
+  }
+
+  // loop through backward input nodes and populate maps and lists
+  // the backward input nodes is the gradient of the loss wrt the output
+  size_t arg_grad_offset = 0;
+  // keep an offset into the arg_grad_ctxes vector,
+  // since g.outputs exclude arg_grad whose req == null
+  CHECK_GE(grad_req_types.size(), g.outputs.size() - num_forward_outputs)
+      << "insufficient number of grad_reqs";
+  for (size_t i = num_forward_outputs; i < g.outputs.size(); ++i, ++arg_grad_offset) {
+    while (grad_req_types[arg_grad_offset] == kNullOp) ++arg_grad_offset;
+    const uint32_t nid = idx.outputs()[i].node_id;
+    Context ctx = arg_grad_ctxes[arg_grad_offset];
+    if (ctx2id.count(ctx) == 0) {
+      ctx2id[ctx] = static_cast<int>(ctx_list.size());
+      ctx_list.push_back(ctx);
+    }
+    int devid = ctx2id.at(ctx);
+    if (device[nid] != -1) {
+      CHECK_EQ(device[nid], devid) << "device of same output not equal to each other";
+    } else {
+      device[nid] = devid;
+    }
+  }
+
+  g.attrs["device"] = std::make_shared<dmlc::any>(std::move(device));
+  g = nnvm::pass::PlaceDevice(g, "__ctx_group__", device_map, "_CrossDeviceCopy");
+  const auto& assigned_device = g.GetAttr<nnvm::DeviceVector>("device");
+
+  exec::ContextVector vcontext;
+  for (size_t i = 0; i < assigned_device.size(); ++i) {
+    if (assigned_device[i] == -1) {
+      vcontext.push_back(default_ctx);
+    } else {
+      vcontext.push_back(ctx_list[assigned_device[i]]);
+    }
+  }
+
+  // after device planning, we should check again
+  // if the assigned device of gradient node
+  // corresponds to storage of grads
+  auto &new_idx = g.indexed_graph();
+  arg_grad_offset = 0;
+  for (size_t i = num_forward_outputs; i < g.outputs.size(); ++i, ++arg_grad_offset) {
+    while (grad_req_types[arg_grad_offset] == kNullOp) ++arg_grad_offset;
+    const uint32_t nid = new_idx.outputs()[i].node_id;
+    Context ctx = arg_grad_ctxes[arg_grad_offset];
+    CHECK(ctx == vcontext[nid])
+        << "Trying to save gradient to " << ctx
+        << " while its source node \"" << new_idx[nid].source->attrs.name
+        << "\" computes it on " << vcontext[nid]
+        << ". Check your ctx in NDArray allocation.";
+  }
+
+  g.attrs["context"] = std::make_shared<nnvm::any>(std::move(vcontext));
+  return g;
+}
 
 }  // namespace common
 }  // namespace mxnet
diff --git a/src/common/serialization.h b/src/common/serialization.h
new file mode 100644
index 0000000..8a1bcc6
--- /dev/null
+++ b/src/common/serialization.h
@@ -0,0 +1,319 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2015 by Contributors
+ * \file serialization.h
+ * \brief Serialization of some STL and nnvm data-structures
+ * \author Clement Fuji Tsang
+ */
+
+#ifndef MXNET_COMMON_SERIALIZATION_H_
+#define MXNET_COMMON_SERIALIZATION_H_
+
+#include <dmlc/logging.h>
+#include <mxnet/graph_attr_types.h>
+#include <nnvm/graph_attr_types.h>
+#include <nnvm/tuple.h>
+
+#include <cstring>
+#include <map>
+#include <set>
+#include <string>
+#include <tuple>
+#include <unordered_map>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
+
+namespace mxnet {
+namespace common {
+
+template<typename T>
+inline size_t SerializedSize(const T &obj);
+
+template<typename T>
+inline size_t SerializedSize(const nnvm::Tuple <T> &obj);
+
+template<typename K, typename V>
+inline size_t SerializedSize(const std::map <K, V> &obj);
+
+template<>
+inline size_t SerializedSize(const std::string &obj);
+
+template<typename... Args>
+inline size_t SerializedSize(const std::tuple<Args...> &obj);
+
+template<typename T>
+inline void Serialize(const T &obj, char **buffer);
+
+template<typename T>
+inline void Serialize(const nnvm::Tuple <T> &obj, char **buffer);
+
+template<typename K, typename V>
+inline void Serialize(const std::map <K, V> &obj, char **buffer);
+
+template<>
+inline void Serialize(const std::string &obj, char **buffer);
+
+template<typename... Args>
+inline void Serialize(const std::tuple<Args...> &obj, char **buffer);
+
+template<typename T>
+inline void Deserialize(T *obj, const std::string &buffer, size_t *curr_pos);
+
+template<typename T>
+inline void Deserialize(nnvm::Tuple <T> *obj, const std::string &buffer, size_t *curr_pos);
+
+template<typename K, typename V>
+inline void Deserialize(std::map <K, V> *obj, const std::string &buffer, size_t *curr_pos);
+
+template<>
+inline void Deserialize(std::string *obj, const std::string &buffer, size_t *curr_pos);
+
+template<typename... Args>
+inline void Deserialize(std::tuple<Args...> *obj, const std::string &buffer, size_t *curr_pos);
+
+
+template<typename T>
+struct is_container {
+  static const bool value = !std::is_pod<T>::value;
+};
+
+template<typename T>
+inline size_t SerializedSize(const T &obj) {
+  return sizeof(T);
+}
+
+template<typename T>
+inline size_t SerializedSize(const nnvm::Tuple <T> &obj) {
+  if (is_container<T>::value) {
+    size_t sum_val = 4;
+    for (const auto& el : obj) {
+      sum_val += SerializedSize(el);
+    }
+    return sum_val;
+  } else {
+    return 4 + (obj.ndim() * sizeof(T));
+  }
+}
+
+template<typename K, typename V>
+inline size_t SerializedSize(const std::map <K, V> &obj) {
+  size_t sum_val = 4;
+  if (is_container<K>::value && is_container<V>::value) {
+    for (const auto& p : obj) {
+      sum_val += SerializedSize(p.first) + SerializedSize(p.second);
+    }
+  } else if (is_container<K>::value) {
+    for (const auto& p : obj) {
+      sum_val += SerializedSize(p.first);
+    }
+    sum_val += sizeof(V) * obj.size();
+  } else if (is_container<V>::value) {
+    for (const auto& p : obj) {
+      sum_val += SerializedSize(p.second);
+    }
+    sum_val += sizeof(K) * obj.size();
+  } else {
+    sum_val += (sizeof(K) + sizeof(V)) * obj.size();
+  }
+  return sum_val;
+}
+
+template<>
+inline size_t SerializedSize(const std::string &obj) {
+  return obj.size() + 4;
+}
+
+template<int I>
+struct serialized_size_tuple {
+  template<typename... Args>
+  static inline size_t Compute(const std::tuple<Args...> &obj) {
+    return SerializedSize(std::get<I>(obj)) + serialized_size_tuple<I-1>::Compute(obj);
+  }
+};
+
+template<>
+struct serialized_size_tuple<0> {
+  template<typename... Args>
+  static inline size_t Compute(const std::tuple<Args...> &obj) {
+    return SerializedSize(std::get<0>(obj));
+  }
+};
+
+template<typename... Args>
+inline size_t SerializedSize(const std::tuple<Args...> &obj) {
+  return serialized_size_tuple<sizeof... (Args)-1>::Compute(obj);
+}
+
+//  Serializer
+
+template<typename T>
+inline size_t SerializedContainerSize(const T &obj, char **buffer) {
+  uint32_t size = obj.size();
+  std::memcpy(*buffer, &size, 4);
+  *buffer += 4;
+  return (size_t) size;
+}
+
+template<typename T>
+inline void Serialize(const T &obj, char **buffer) {
+  std::memcpy(*buffer, &obj, sizeof(T));
+  *buffer += sizeof(T);
+}
+
+template<typename T>
+inline void Serialize(const nnvm::Tuple <T> &obj, char **buffer) {
+  uint32_t size = obj.ndim();
+  std::memcpy(*buffer, &size, 4);
+  *buffer += 4;
+  for (auto& el : obj) {
+    Serialize(el, buffer);
+  }
+}
+
+template<typename K, typename V>
+inline void Serialize(const std::map <K, V> &obj, char **buffer) {
+  SerializedContainerSize(obj, buffer);
+  for (auto& p : obj) {
+    Serialize(p.first, buffer);
+    Serialize(p.second, buffer);
+  }
+}
+
+template<>
+inline void Serialize(const std::string &obj, char **buffer) {
+  auto size = SerializedContainerSize(obj, buffer);
+  std::memcpy(*buffer, &obj[0], size);
+  *buffer += size;
+}
+
+template<int I>
+struct serialize_tuple {
+  template<typename... Args>
+  static inline void Compute(const std::tuple<Args...> &obj, char **buffer) {
+    serialize_tuple<I-1>::Compute(obj, buffer);
+    Serialize(std::get<I>(obj), buffer);
+  }
+};
+
+template<>
+struct serialize_tuple<0> {
+  template<typename... Args>
+  static inline void Compute(const std::tuple<Args...> &obj, char **buffer) {
+    Serialize(std::get<0>(obj), buffer);
+  }
+};
+
+template<typename... Args>
+inline void Serialize(const std::tuple<Args...> &obj, char **buffer) {
+  serialize_tuple<sizeof... (Args)-1>::Compute(obj, buffer);
+}
+
+// Deserializer
+
+template<typename T>
+inline size_t DeserializedContainerSize(T *obj, const std::string &buffer, size_t *curr_pos) {
+  uint32_t size = obj->size();
+  std::memcpy(&size, &buffer[*curr_pos], 4);
+  *curr_pos += 4;
+  return (size_t) size;
+}
+
+template<typename T>
+inline void Deserialize(T *obj, const std::string &buffer, size_t *curr_pos) {
+  std::memcpy(obj, &buffer[*curr_pos], sizeof(T));
+  *curr_pos += sizeof(T);
+}
+
+template<typename T>
+inline void Deserialize(nnvm::Tuple <T> *obj, const std::string &buffer, size_t *curr_pos) {
+  uint32_t size = obj->ndim();
+  std::memcpy(&size, &buffer[*curr_pos], 4);
+  *curr_pos += 4;
+  obj->SetDim(size);
+  for (size_t i = 0; i < size; ++i) {
+    Deserialize((*obj)[i], buffer, curr_pos);
+  }
+}
+
+template<typename K, typename V>
+inline void Deserialize(std::map <K, V> *obj, const std::string &buffer, size_t *curr_pos) {
+  auto size = DeserializedContainerSize(obj, buffer, curr_pos);
+  K first;
+  for (size_t i = 0; i < size; ++i) {
+    Deserialize(&first, buffer, curr_pos);
+    Deserialize(&(*obj)[first], buffer, curr_pos);
+  }
+}
+
+template<>
+inline void Deserialize(std::string *obj, const std::string &buffer, size_t *curr_pos) {
+  auto size = DeserializedContainerSize(obj, buffer, curr_pos);
+  obj->resize(size);
+  std::memcpy(&(obj->front()), &buffer[*curr_pos], size);
+  *curr_pos += size;
+}
+
+template<int I>
+struct deserialize_tuple {
+  template<typename... Args>
+  static inline void Compute(std::tuple<Args...> *obj,
+                             const std::string &buffer, size_t *curr_pos) {
+    deserialize_tuple<I-1>::Compute(obj, buffer, curr_pos);
+    Deserialize(&std::get<I>(*obj), buffer, curr_pos);
+  }
+};
+
+template<>
+struct deserialize_tuple<0> {
+  template<typename... Args>
+  static inline void Compute(std::tuple<Args...> *obj,
+                             const std::string &buffer, size_t *curr_pos) {
+    Deserialize(&std::get<0>(*obj), buffer, curr_pos);
+  }
+};
+
+template<typename... Args>
+inline void Deserialize(std::tuple<Args...> *obj, const std::string &buffer, size_t *curr_pos) {
+  deserialize_tuple<sizeof... (Args)-1>::Compute(obj, buffer, curr_pos);
+}
+
+
+template<typename T>
+inline void Serialize(const T& obj, std::string* serialized_data) {
+  serialized_data->resize(SerializedSize(obj));
+  char* curr_pos = &(serialized_data->front());
+  Serialize(obj, &curr_pos);
+  CHECK_EQ((int64_t)curr_pos - (int64_t)&(serialized_data->front()),
+           serialized_data->size());
+}
+
+template<typename T>
+inline void Deserialize(T* obj, const std::string& serialized_data) {
+  size_t curr_pos = 0;
+  Deserialize(obj, serialized_data, &curr_pos);
+  CHECK_EQ(curr_pos, serialized_data.size());
+}
+
+}  // namespace common
+}  // namespace mxnet
+#endif  // MXNET_COMMON_SERIALIZATION_H_
diff --git a/src/common/utils.h b/src/common/utils.h
index 96949a0..fcc3da8 100644
--- a/src/common/utils.h
+++ b/src/common/utils.h
@@ -675,6 +675,37 @@ MSHADOW_XINLINE int ilog2ui(unsigned int a) {
   return k;
 }
 
+/*!
+ * \brief Return an NDArray of all zeros.
+ */
+inline NDArray InitZeros(const NDArrayStorageType stype, const TShape &shape,
+                         const Context &ctx, const int dtype) {
+  // NDArray with default storage
+  if (stype == kDefaultStorage) {
+    NDArray ret(shape, ctx, false, dtype);
+    ret = 0;
+    return ret;
+  }
+  // NDArray with non-default storage. Storage allocation is always delayed.
+  return NDArray(stype, shape, ctx, true, dtype);
+}
+
+/*!
+ * \brief Helper to add a NDArray of zeros to a std::vector.
+ */
+inline void EmplaceBackZeros(const NDArrayStorageType stype, const TShape &shape,
+                             const Context &ctx, const int dtype,
+                             std::vector<NDArray> *vec) {
+  // NDArray with default storage
+  if (stype == kDefaultStorage) {
+    vec->emplace_back(shape, ctx, false, dtype);
+    vec->back() = 0;
+  } else {
+    // NDArray with non-default storage. Storage allocation is always delayed.
+    vec->emplace_back(stype, shape, ctx, true, dtype);
+  }
+}
+
 }  // namespace common
 }  // namespace mxnet
 #endif  // MXNET_COMMON_UTILS_H_
diff --git a/src/executor/exec_pass.h b/src/executor/exec_pass.h
index 26a2491..8c483e9 100644
--- a/src/executor/exec_pass.h
+++ b/src/executor/exec_pass.h
@@ -198,6 +198,18 @@ Graph InferStorageType(Graph&& graph,
                        StorageTypeVector&& storage_type_inputs = StorageTypeVector(),
                        const std::string& storage_type_attr_key = "");
 
+#if MXNET_USE_TENSORRT
+/*!
+ * \brief Replace subgraphs by TRT (forward only)
+ */
+Graph ReplaceSubgraph(Graph&& g,
+                      const std::unordered_set<nnvm::Node*>& set_subgraph,
+                      std::unordered_map<std::string, NDArray>* const params_map);
+
+std::vector<std::unordered_set<nnvm::Node*>> GetTrtCompatibleSubsets(const Graph& g,
+    std::unordered_map<std::string, NDArray>* const params_map);
+#endif
+
 }  // namespace exec
 }  // namespace mxnet
 
diff --git a/src/executor/graph_executor.cc b/src/executor/graph_executor.cc
index 33c6f57..0e80706 100644
--- a/src/executor/graph_executor.cc
+++ b/src/executor/graph_executor.cc
@@ -37,6 +37,8 @@
 namespace mxnet {
 namespace exec {
 
+using namespace mxnet::common;
+
 GraphExecutor::GraphExecutor() {
   log_verbose_ = dmlc::GetEnv("MXNET_EXEC_VERBOSE_LOGGING", false);
   need_grad_ = false;
@@ -56,30 +58,6 @@ GraphExecutor::~GraphExecutor() {
   }
 }
 
-inline NDArray InitZeros(const NDArrayStorageType stype, const TShape &shape,
-                                const Context &ctx, const int dtype) {
-  // NDArray with default storage
-  if (stype == kDefaultStorage) {
-    NDArray ret(shape, ctx, false, dtype);
-    ret = 0;
-    return ret;
-  }
-  // NDArray with non-default storage. Storage allocation is always delayed.
-  return NDArray(stype, shape, ctx, true, dtype);
-}
-
-inline void EmplaceBackZeros(const NDArrayStorageType stype, const TShape &shape,
-                             const Context &ctx, const int dtype,
-                             std::vector<NDArray> *vec) {
-  // NDArray with default storage
-  if (stype == kDefaultStorage) {
-    vec->emplace_back(shape, ctx, false, dtype);
-    vec->back() = 0;
-  } else {
-    // NDArray with non-default storage. Storage allocation is always delayed.
-    vec->emplace_back(stype, shape, ctx, true, dtype);
-  }
-}
 void GraphExecutor::Forward(bool is_train) {
   RunOps(is_train, 0, num_forward_nodes_);
 }
@@ -309,204 +287,6 @@ nnvm::Graph GraphExecutor::InitFullGraph(nnvm::Symbol symbol,
 }
 
 /*!
- * \brief Assign context to the graph.
- * This is triggered by both simple_bind and bind flows.
- */
-static Graph AssignContext(Graph g,
-                    const Context& default_ctx,
-                    const std::map<std::string, Context>& ctx_map,
-                    const std::vector<Context>& in_arg_ctxes,
-                    const std::vector<Context>& arg_grad_ctxes,
-                    const std::vector<Context>& aux_state_ctxes,
-                    const std::vector<OpReqType>& grad_req_types,
-                    size_t num_forward_inputs,
-                    size_t num_forward_outputs) {
-  const auto& idx = g.indexed_graph();
-  const auto& mutable_nodes = idx.mutable_input_nodes();
-  // default use default context.
-  if (ctx_map.size() == 0) {
-    g.attrs["context"] = std::make_shared<nnvm::any>(
-        ContextVector(idx.num_nodes(), default_ctx));
-    for (const auto& x : in_arg_ctxes) {
-      CHECK(x == default_ctx)
-        << "Input array is in " << x << " while binding with ctx=" << default_ctx
-        << ". All arguments must be in global context (" << default_ctx
-        << ") unless group2ctx is specified for cross-device graph.";
-    }
-    for (const auto& x : arg_grad_ctxes) {
-      CHECK(x == default_ctx)
-        << "Gradient array is in " << x << " while binding with ctx="
-        << default_ctx << ". All gradients must be in global context (" << default_ctx
-        << ") unless group2ctx is specified for cross-device graph.";
-    }
-    return g;
-  }
-
-  // otherwise, use context assignment.
-  std::map<Context, int> ctx2id;  // map ctx to device id
-  std::vector<Context> ctx_list;  // index is device id
-  nnvm::DeviceVector device(idx.num_nodes(), -1);  // index is node id
-  nnvm::DeviceAssignMap device_map;  // map arg name to device id
-
-  // loop through the user input ctx_map and
-  // populate maps and lists
-  for (auto &kv : ctx_map) {
-    if (ctx2id.count(kv.second) == 0) {  // if context has no device id, create one
-      ctx2id[kv.second] = static_cast<int>(ctx_list.size());  // assign device id to ctx
-      ctx_list.push_back(kv.second);  // save ctx to the list
-    }
-    // assign device id to to the arg name with the corresponding ctx
-    device_map[kv.first] = ctx2id.at(kv.second);
-  }
-
-  // loop through all the rest of input nodes not specified
-  // in the ctx_map and populate maps and lists
-  size_t arg_top = 0, aux_top = 0;
-  for (size_t i = 0; i < num_forward_inputs; ++i) {
-    const uint32_t nid = idx.input_nodes().at(i);
-    Context ctx;
-    if (mutable_nodes.count(nid)) {  // aux node is mutable
-      CHECK_LT(aux_top, aux_state_ctxes.size());
-      ctx = aux_state_ctxes[aux_top];
-      ++aux_top;
-    } else {  // regular input node is immutable
-      CHECK_LT(arg_top, in_arg_ctxes.size());
-      ctx = in_arg_ctxes[arg_top];
-      ++arg_top;
-    }
-    if (ctx2id.count(ctx) == 0) {  // if the current ctx is not in the map of ctx and device id
-      ctx2id[ctx] = static_cast<int>(ctx_list.size());  // assign the current ctx with device id
-      ctx_list.push_back(ctx);  // save the current ctx in the list
-    }
-    device[nid] = ctx2id.at(ctx);  // assign device id to the current node
-  }
-
-  // loop through backward input nodes and populate maps and lists
-  // the backward input nodes is the gradient of the loss wrt the output
-  size_t arg_grad_offset = 0;
-  // keep an offset into the arg_grad_ctxes vector,
-  // since g.outputs exclude arg_grad whose req == null
-  CHECK_GE(grad_req_types.size(), g.outputs.size() - num_forward_outputs)
-           << "insufficient number of grad_reqs";
-  for (size_t i = num_forward_outputs; i < g.outputs.size(); ++i, ++arg_grad_offset) {
-    while (grad_req_types[arg_grad_offset] == kNullOp) ++arg_grad_offset;
-    const uint32_t nid = idx.outputs()[i].node_id;
-    Context ctx = arg_grad_ctxes[arg_grad_offset];
-    if (ctx2id.count(ctx) == 0) {
-      ctx2id[ctx] = static_cast<int>(ctx_list.size());
-      ctx_list.push_back(ctx);
-    }
-    int devid = ctx2id.at(ctx);
-    if (device[nid] != -1) {
-      CHECK_EQ(device[nid], devid) << "device of same output not equal to each other";
-    } else {
-      device[nid] = devid;
-    }
-  }
-
-  g.attrs["device"] = std::make_shared<dmlc::any>(std::move(device));
-  g = nnvm::pass::PlaceDevice(g, "__ctx_group__", device_map, "_CrossDeviceCopy");
-  const auto& assigned_device = g.GetAttr<nnvm::DeviceVector>("device");
-
-  ContextVector vcontext;
-  for (size_t i = 0; i < assigned_device.size(); ++i) {
-    if (assigned_device[i] == -1) {
-      vcontext.push_back(default_ctx);
-    } else {
-      vcontext.push_back(ctx_list[assigned_device[i]]);
-    }
-  }
-
-  // after device planning, we should check again
-  // if the assigned device of gradient node
-  // corresponds to storage of grads
-  auto &new_idx = g.indexed_graph();
-  arg_grad_offset = 0;
-  for (size_t i = num_forward_outputs; i < g.outputs.size(); ++i, ++arg_grad_offset) {
-    while (grad_req_types[arg_grad_offset] == kNullOp) ++arg_grad_offset;
-    const uint32_t nid = new_idx.outputs()[i].node_id;
-    Context ctx = arg_grad_ctxes[arg_grad_offset];
-    CHECK(ctx == vcontext[nid])
-      << "Trying to save gradient to " << ctx
-      << " while its source node \"" << new_idx[nid].source->attrs.name
-      << "\" computes it on " << vcontext[nid]
-      << ". Check your ctx in NDArray allocation.";
-  }
-
-  g.attrs["context"] = std::make_shared<nnvm::any>(std::move(vcontext));
-  return g;
-}
-
-static void HandleInferShapeError(const size_t num_forward_inputs,
-                           const nnvm::IndexedGraph& idx,
-                           const nnvm::ShapeVector& inferred_shapes) {
-  int cnt = 10;
-  std::ostringstream oss;
-  for (size_t i = 0; i < num_forward_inputs; ++i) {
-    const uint32_t nid = idx.input_nodes().at(i);
-    const uint32_t eid = idx.entry_id(nid, 0);
-    const TShape& inferred_shape = inferred_shapes[eid];
-    if (inferred_shape.ndim() == 0 || inferred_shape.Size() == 0U) {
-      const std::string& arg_name = idx[nid].source->attrs.name;
-      oss << arg_name << ": " << inferred_shape << ", ";
-      if (--cnt == 0) {
-        oss << "...";
-        break;
-      }
-    }
-  }
-  LOG(FATAL) << "InferShape pass cannot decide shapes for the following arguments "
-                "(0s means unknown dimensions). Please consider providing them as inputs:\n"
-             << oss.str();
-}
-
-static void HandleInferTypeError(const size_t num_forward_inputs,
-                          const nnvm::IndexedGraph& idx,
-                          const nnvm::DTypeVector& inferred_dtypes) {
-  int cnt = 10;
-  std::ostringstream oss;
-  for (size_t i = 0; i < num_forward_inputs; ++i) {
-    const uint32_t nid = idx.input_nodes().at(i);
-    const uint32_t eid = idx.entry_id(nid, 0);
-    const int inferred_dtype = inferred_dtypes[eid];
-    if (inferred_dtype == -1) {
-      const std::string& arg_name = idx[nid].source->attrs.name;
-      oss << arg_name << ": " << inferred_dtype << ", ";
-      if (--cnt == 0) {
-        oss << "...";
-        break;
-      }
-    }
-  }
-  LOG(FATAL) << "InferType pass cannot decide dtypes for the following arguments "
-                "(-1 means unknown dtype). Please consider providing them as inputs:\n"
-             << oss.str();
-}
-
-static void HandleInferStorageTypeError(const size_t num_forward_inputs,
-                                 const nnvm::IndexedGraph& idx,
-                                 const StorageTypeVector& inferred_stypes) {
-  int cnt = 10;
-  std::ostringstream oss;
-  for (size_t i = 0; i < num_forward_inputs; ++i) {
-    const uint32_t nid = idx.input_nodes().at(i);
-    const uint32_t eid = idx.entry_id(nid, 0);
-    const int inferred_stype = inferred_stypes[eid];
-    if (inferred_stype == -1) {
-      const std::string& arg_name = idx[nid].source->attrs.name;
-      oss << arg_name << ": " << common::stype_string(inferred_stype) << ", ";
-      if (--cnt == 0) {
-        oss << "...";
-        break;
-      }
-    }
-  }
-  LOG(FATAL) << "InferStorageType pass cannot decide storage type for the following arguments "
-                "(-1 means unknown stype). Please consider providing them as inputs:\n"
-             << oss.str();
-}
-
-/*!
  * \brief GraphExecutor initializer for regular bind flow in which
  * input arguments and gradients are provided by users. This initializer
  * uses the user provided NDArrays to populate data entries of the graph.
@@ -681,57 +461,6 @@ void GraphExecutor::InitArguments(const nnvm::IndexedGraph& idx,
 }
 
 /*!
- * \brief If the requested ndarray's shape size is less than
- * the corresponding shared_data_array's shape size and the
- * storage type is shareable, reuse the memory allocation
- * in shared_buffer; otherwise, create a zero ndarray.
- * Shareable storages include both default storage and row_sparse storage
- * if enable_row_sparse_sharing is `True`, otherwise default storage only.
- */
-static NDArray ReshapeOrCreate(const std::string& name,
-                        const TShape& dest_arg_shape,
-                        const int dest_arg_dtype,
-                        const NDArrayStorageType dest_arg_stype,
-                        const Context& ctx,
-                        std::unordered_map<std::string, NDArray>* shared_buffer,
-                        bool enable_row_sparse_sharing) {
-  bool stype_shareable = dest_arg_stype == kDefaultStorage;
-  if (enable_row_sparse_sharing) {
-    stype_shareable = stype_shareable || dest_arg_stype == kRowSparseStorage;
-  }
-  auto it = shared_buffer->find(name);
-  if (it != shared_buffer->end()) {
-    // check if size is large enough for sharing
-    bool size_shareable = it->second.shape().Size() >= dest_arg_shape.Size();
-    if (size_shareable && stype_shareable) {  // memory can be reused
-      CHECK_EQ(it->second.dtype(), dest_arg_dtype)
-        << "Requested arg array's dtype does not match that of the reusable ndarray";
-      CHECK_EQ(it->second.storage_type(), dest_arg_stype)
-        << "Requested arg array's stype does not match that of the reusable ndarray";
-      return it->second.Reshape(dest_arg_shape);
-    } else if (stype_shareable) {
-      LOG(WARNING) << "Bucketing: data " << name << " has a shape " << dest_arg_shape
-                   << ", which is larger than already allocated shape " << it->second.shape()
-                   << ". Need to re-allocate. Consider putting default bucket key to be "
-                   << "the bucket taking the largest input for better memory sharing.";
-      // size is not large enough, creating a larger one for sharing
-      // the NDArrays in shared_buffer are guaranteed to be of shareable storages
-      it->second = InitZeros(dest_arg_stype, dest_arg_shape, ctx, dest_arg_dtype);
-      return it->second;
-    } else {
-      // not shareable storage
-      return InitZeros(dest_arg_stype, dest_arg_shape, ctx, dest_arg_dtype);
-    }
-  } else {
-    auto ret = InitZeros(dest_arg_stype, dest_arg_shape, ctx, dest_arg_dtype);
-    if (stype_shareable) {
-      shared_buffer->emplace(name, ret);
-    }
-    return ret;
-  }  // if (it != shared_buffer->end())
-}
-
-/*!
  * \brief Initialize in_args, arg_grads, and aux_states
  * and their data_entry_ of the executor using
  * shared_buffer from DataParallelExecutorGroup
diff --git a/src/executor/graph_executor.h b/src/executor/graph_executor.h
index bfc415b..7b936c3 100644
--- a/src/executor/graph_executor.h
+++ b/src/executor/graph_executor.h
@@ -163,20 +163,21 @@ class GraphExecutor : public Executor {
                      std::vector<NDArray>* aux_state_vec);
   // Initialize in_args, arg_grads and aux_states with
   // shared_buffer and shared_exec
-  void InitArguments(const nnvm::IndexedGraph& idx,
-                     const nnvm::ShapeVector& inferred_shapes,
-                     const nnvm::DTypeVector& inferred_dtypes,
-                     const StorageTypeVector& inferred_stypes,
-                     const std::vector<Context>& in_arg_ctxes,
-                     const std::vector<Context>& arg_grad_ctxes,
-                     const std::vector<Context>& aux_state_ctxes,
-                     const std::vector<OpReqType>& grad_req_types,
-                     const std::unordered_set<std::string>& shared_arg_names,
-                     const Executor* shared_exec,
-                     std::unordered_map<std::string, NDArray>* shared_buffer,
-                     std::vector<NDArray>* in_arg_vec,
-                     std::vector<NDArray>* arg_grad_vec,
-                     std::vector<NDArray>* aux_state_vec);
+  virtual void InitArguments(const nnvm::IndexedGraph& idx,
+                             const nnvm::ShapeVector& inferred_shapes,
+                             const nnvm::DTypeVector& inferred_dtypes,
+                             const StorageTypeVector& inferred_stypes,
+                             const std::vector<Context>& in_arg_ctxes,
+                             const std::vector<Context>& arg_grad_ctxes,
+                             const std::vector<Context>& aux_state_ctxes,
+                             const std::vector<OpReqType>& grad_req_types,
+                             const std::unordered_set<std::string>& shared_arg_names,
+                             const Executor* shared_exec,
+                             std::unordered_map<std::string, NDArray>* shared_buffer,
+                             std::vector<NDArray>* in_arg_vec,
+                             std::vector<NDArray>* arg_grad_vec,
+                             std::vector<NDArray>* aux_state_vec);
+
   // internal initialization of the graph for simple bind
   Graph InitGraph(nnvm::Symbol symbol,
                   const Context& default_ctx,
@@ -212,7 +213,6 @@ class GraphExecutor : public Executor {
   void BulkInferenceOpSegs();
   // perform bulking and segmentation on a training graph
   void BulkTrainingOpSegs(size_t total_num_nodes);
-
   // indicate whether there is a backward graph for gradients.
   bool need_grad_;
   // internal graph
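
Making InitArguments virtual is what allows the TensorRT executor introduced in
this commit to plug in its own argument initialization while reusing the rest of
GraphExecutor. A bare-bones sketch of an overriding subclass (the class name is
hypothetical; the real override is TrtGraphExecutor::InitArguments, defined below):

    namespace mxnet {
    namespace exec {

    class CustomGraphExecutor : public GraphExecutor {
     protected:
      // Same signature as the virtual method declared above; override it to
      // customize how in_args, arg_grads and aux_states are created.
      void InitArguments(const nnvm::IndexedGraph& idx,
                         const nnvm::ShapeVector& inferred_shapes,
                         const nnvm::DTypeVector& inferred_dtypes,
                         const StorageTypeVector& inferred_stypes,
                         const std::vector<Context>& in_arg_ctxes,
                         const std::vector<Context>& arg_grad_ctxes,
                         const std::vector<Context>& aux_state_ctxes,
                         const std::vector<OpReqType>& grad_req_types,
                         const std::unordered_set<std::string>& shared_arg_names,
                         const Executor* shared_exec,
                         std::unordered_map<std::string, NDArray>* shared_buffer,
                         std::vector<NDArray>* in_arg_vec,
                         std::vector<NDArray>* arg_grad_vec,
                         std::vector<NDArray>* aux_state_vec) override {
        // ... custom creation of in_args / arg_grads / aux_states ...
      }
    };

    }  // namespace exec
    }  // namespace mxnet
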
diff --git a/src/executor/onnx_to_tensorrt.cc b/src/executor/onnx_to_tensorrt.cc
new file mode 100644
index 0000000..0b4d91b
--- /dev/null
+++ b/src/executor/onnx_to_tensorrt.cc
@@ -0,0 +1,148 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2018 by Contributors
+ * \file onnx_to_tensorrt.cc
+ * \brief TensorRT integration with the MXNet executor
+ * \author Marek Kolodziej, Clement Fuji Tsang
+ */
+
+#if MXNET_USE_TENSORRT
+
+#include "./onnx_to_tensorrt.h"
+
+#include <onnx/onnx.pb.h>
+
+#include <NvInfer.h>
+#include <google/protobuf/io/coded_stream.h>
+#include <google/protobuf/io/zero_copy_stream_impl.h>
+#include <google/protobuf/text_format.h>
+#include <onnx-tensorrt/NvOnnxParser.h>
+#include <onnx-tensorrt/NvOnnxParserRuntime.h>
+#include <onnx-tensorrt/PluginFactory.hpp>
+#include <onnx-tensorrt/plugin_common.hpp>
+
+using std::cout;
+using std::cerr;
+using std::endl;
+
+namespace onnx_to_tensorrt {
+
+struct InferDeleter {
+  template<typename T>
+    void operator()(T* obj) const {
+      if ( obj ) {
+        obj->destroy();
+      }
+    }
+};
+
+template<typename T>
+inline std::shared_ptr<T> InferObject(T* obj) {
+  if ( !obj ) {
+    throw std::runtime_error("Failed to create object");
+  }
+  return std::shared_ptr<T>(obj, InferDeleter());
+}
+
+std::string onnx_ir_version_string(int64_t ir_version = onnx::IR_VERSION) {
+  int onnx_ir_major = ir_version / 1000000;
+  int onnx_ir_minor = ir_version % 1000000 / 10000;
+  int onnx_ir_patch = ir_version % 10000;
+  return (std::to_string(onnx_ir_major) + "." +
+    std::to_string(onnx_ir_minor) + "." +
+    std::to_string(onnx_ir_patch));
+}
+
+void PrintVersion() {
+  cout << "Parser built against:" << endl;
+  cout << "  ONNX IR version:  " << onnx_ir_version_string(onnx::IR_VERSION) << endl;
+  cout << "  TensorRT version: "
+    << NV_TENSORRT_MAJOR << "."
+    << NV_TENSORRT_MINOR << "."
+    << NV_TENSORRT_PATCH << endl;
+}
+
+nvinfer1::ICudaEngine* onnxToTrtCtx(
+        const std::string& onnx_model,
+        int32_t max_batch_size,
+        size_t max_workspace_size,
+        nvinfer1::ILogger::Severity verbosity,
+        bool debug_builder) {
+  GOOGLE_PROTOBUF_VERIFY_VERSION;
+
+  TRT_Logger trt_logger(verbosity);
+  auto trt_builder = InferObject(nvinfer1::createInferBuilder(trt_logger));
+  auto trt_network = InferObject(trt_builder->createNetwork());
+  auto trt_parser  = InferObject(nvonnxparser::createParser(
+      *trt_network, trt_logger));
+  ::ONNX_NAMESPACE::ModelProto parsed_model;
+  // We check for a valid parse, but the main effect is the side effect
+  // of populating parsed_model
+  if (!parsed_model.ParseFromString(onnx_model)) {
+    throw dmlc::Error("Could not parse ONNX from string");
+  }
+
+  if ( !trt_parser->parse(onnx_model.c_str(), onnx_model.size()) ) {
+      int nerror = trt_parser->getNbErrors();
+      for ( int i=0; i < nerror; ++i ) {
+        nvonnxparser::IParserError const* error = trt_parser->getError(i);
+        if ( error->node() != -1 ) {
+          ::ONNX_NAMESPACE::NodeProto const& node =
+            parsed_model.graph().node(error->node());
+          cerr << "While parsing node number " << error->node()
+               << " [" << node.op_type();
+          if ( !node.output().empty() ) {
+            cerr << " -> \"" << node.output(0) << "\"";
+          }
+          cerr << "]:" << endl;
+          if ( static_cast<int>(verbosity) >= \
+            static_cast<int>(nvinfer1::ILogger::Severity::kINFO) ) {
+            cerr << "--- Begin node ---" << endl;
+            cerr << node.DebugString() << endl;
+            cerr << "--- End node ---" << endl;
+          }
+        }
+        cerr << "ERROR: "
+             << error->file() << ":" << error->line()
+             << " In function " << error->func() << ":\n"
+             << "[" << static_cast<int>(error->code()) << "] " << error->desc()
+             << endl;
+      }
+      throw dmlc::Error("Cannot parse ONNX into TensorRT Engine");
+  }
+
+  bool fp16 = trt_builder->platformHasFastFp16();
+
+  trt_builder->setMaxBatchSize(max_batch_size);
+  trt_builder->setMaxWorkspaceSize(max_workspace_size);
+  if ( fp16 && dmlc::GetEnv("MXNET_TENSORRT_USE_FP16_FOR_FP32", false) ) {
+    LOG(INFO) << "WARNING: TensorRT using fp16 given original MXNet graph in fp32 !!!";
+    trt_builder->setHalf2Mode(true);
+  }
+
+  trt_builder->setDebugSync(debug_builder);
+  nvinfer1::ICudaEngine* trt_engine = trt_builder->buildCudaEngine(*trt_network.get());
+  return trt_engine;
+}
+
+}  // namespace onnx_to_tensorrt
+
+#endif  // MXNET_USE_TENSORRT
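
Note that the fp16 path above is opt-in: even on hardware with fast fp16, the
builder keeps fp32 unless MXNET_TENSORRT_USE_FP16_FOR_FP32 is set in the
environment before the engine is built. A sketch of enabling it from the hosting
process (POSIX-only; this call is an illustration, not part of the patch):

    #include <cstdlib>

    // Ask TensorRT to run the fp32 MXNet graph with fp16 kernels where possible.
    // Only honored when trt_builder->platformHasFastFp16() returns true.
    setenv("MXNET_TENSORRT_USE_FP16_FOR_FP32", "1", /*overwrite=*/1);
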
diff --git a/src/executor/onnx_to_tensorrt.h b/src/executor/onnx_to_tensorrt.h
new file mode 100644
index 0000000..259cfce
--- /dev/null
+++ b/src/executor/onnx_to_tensorrt.h
@@ -0,0 +1,77 @@
+#ifndef MXNET_EXECUTOR_ONNX_TO_TENSORRT_H_
+#define MXNET_EXECUTOR_ONNX_TO_TENSORRT_H_
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2018 by Contributors
+ * \file onnx_to_tensorrt.h
+ * \brief TensorRT integration with the MXNet executor
+ * \author Marek Kolodziej, Clement Fuji Tsang
+ */
+
+#if MXNET_USE_TENSORRT
+
+#include <fstream>
+#include <iostream>
+#include <NvInfer.h>
+#include <sstream>
+#include <string>
+
+#include "../operator/contrib/tensorrt-inl.h"
+
+namespace onnx_to_tensorrt {
+
+class TRT_Logger : public nvinfer1::ILogger {
+        nvinfer1::ILogger::Severity _verbosity;
+        std::ostream* _ostream;
+ public:
+        TRT_Logger(Severity verbosity = Severity::kWARNING,
+                   std::ostream& ostream = std::cout)
+                : _verbosity(verbosity), _ostream(&ostream) {}
+        void log(Severity severity, const char* msg) override {
+                if ( severity <= _verbosity ) {
+                        time_t rawtime = std::time(0);
+                        char buf[256];
+                        strftime(&buf[0], 256,
+                                 "%Y-%m-%d %H:%M:%S",
+                                 std::gmtime(&rawtime));
+                        const char* sevstr = (severity == Severity::kINTERNAL_ERROR ? "    BUG" :
+                                              severity == Severity::kERROR          ? "  ERROR" :
+                                              severity == Severity::kWARNING        ? "WARNING" :
+                                              severity == Severity::kINFO           ? "   INFO" :
+                                              "UNKNOWN");
+                        (*_ostream) << "[" << buf << " " << sevstr << "] "
+                                    << msg
+                                    << std::endl;
+                }
+        }
+};
+
+nvinfer1::ICudaEngine* onnxToTrtCtx(
+        const std::string& onnx_model,
+        int32_t max_batch_size = 32,
+        size_t max_workspace_size = 1L << 30,
+        nvinfer1::ILogger::Severity verbosity = nvinfer1::ILogger::Severity::kWARNING,
+        bool debug_builder = false);
+}  // namespace onnx_to_tensorrt
+
+#endif  // MXNET_USE_TENSORRT
+
+#endif  // MXNET_EXECUTOR_ONNX_TO_TENSORRT_H_
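
A minimal sketch of how onnxToTrtCtx is meant to be called (the serialized model
string is a placeholder; ownership of the returned engine stays with the caller,
which must destroy it when done):

    #if MXNET_USE_TENSORRT
    #include "onnx_to_tensorrt.h"

    void BuildEngineExample(const std::string& serialized_onnx_model) {
      nvinfer1::ICudaEngine* engine = onnx_to_tensorrt::onnxToTrtCtx(
          serialized_onnx_model,
          /*max_batch_size=*/1,
          /*max_workspace_size=*/1UL << 30,
          nvinfer1::ILogger::Severity::kWARNING,
          /*debug_builder=*/false);
      // ... create an IExecutionContext from the engine and run inference ...
      engine->destroy();
    }
    #endif  // MXNET_USE_TENSORRT
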
diff --git a/src/executor/tensorrt_pass.cc b/src/executor/tensorrt_pass.cc
new file mode 100644
index 0000000..b5fc8d1
--- /dev/null
+++ b/src/executor/tensorrt_pass.cc
@@ -0,0 +1,596 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2018 by Contributors
+ * \file tensorrt_pass.cc
+ * \brief Replace TRT compatible subgraphs by TRT engines
+ * \author Clement Fuji Tsang
+ */
+
+#if MXNET_USE_TENSORRT
+
+#include <NvInfer.h>
+#include <mxnet/base.h>
+#include <mxnet/op_attr_types.h>
+#include <mxnet/operator.h>
+#include <nnvm/graph_attr_types.h>
+#include <onnx/onnx.pb.h>
+
+#include "../operator/contrib/nnvm_to_onnx-inl.h"
+#include "./exec_pass.h"
+#include "./onnx_to_tensorrt.h"
+
+namespace mxnet {
+namespace exec {
+
+using NodePtr = nnvm::NodePtr;
+
+/*!
+ * \brief Custom graph class containing bidirectional nodes, which
+ * we need in order to compute DFS and reverse DFS for graph partitioning
+ */
+class BidirectionalGraph {
+ public:
+  struct Node {
+    nnvm::Node* nnvmptr;
+    std::vector<Node*> inputs;
+    std::vector<Node*> outputs;
+  };
+  std::vector<Node> nodes;
+  std::unordered_map<nnvm::Node*, uint32_t> nnvm2nid;
+  std::vector<Node*> outputs;
+  static const std::unordered_set<std::string> unconditionalTRTop;
+
+  explicit BidirectionalGraph(const Graph &g) {
+    auto& idx = g.indexed_graph();
+    auto num_nodes = idx.num_nodes();
+    nodes.reserve(num_nodes);
+    nnvm2nid.reserve(num_nodes);
+    outputs.reserve(idx.outputs().size());
+    DFSVisit(g.outputs, [this](const nnvm::NodePtr& n) {
+      BidirectionalGraph::Node new_node;
+      new_node.nnvmptr = n.get();
+      nnvm2nid[n.get()] = static_cast<uint32_t>(nodes.size());
+      nodes.emplace_back(std::move(new_node));
+    });
+    for (const auto& it : nnvm2nid) {
+      nnvm::Node* nnvmnode = it.first;
+      uint32_t nid = it.second;
+      for (auto& n : nnvmnode->inputs) {
+        uint32_t input_nid = nnvm2nid[n.node.get()];
+        nodes[input_nid].outputs.emplace_back(&nodes[nid]);
+        nodes[nid].inputs.emplace_back(&nodes[input_nid]);
+      }
+    }
+    for (auto& e : g.outputs) {
+      uint32_t nid = nnvm2nid[e.node.get()];
+      outputs.emplace_back(&nodes[nid]);
+    }
+  }
+
+  template <typename FVisit>
+  void DFS(const std::vector<Node*>& heads, bool reverse, FVisit fvisit) {
+    std::unordered_set<Node*> visited;
+    std::vector<Node*> vec(heads.begin(), heads.end());
+    visited.reserve(heads.size());
+    while (!vec.empty()) {
+      Node* vertex = vec.back();
+      vec.pop_back();
+      if (visited.count(vertex) == 0) {
+        visited.insert(vertex);
+        fvisit(vertex);
+        std::vector<Node*> nexts = reverse ? vertex->inputs : vertex->outputs;
+        for (Node* node : nexts) {
+          if (visited.count(node) == 0) {
+            vec.emplace_back(node);
+          }
+        }
+      }
+    }
+  }
+
+  using t_pairset = std::pair<std::unordered_set<Node*>, std::unordered_set<Node*>>;
+  using t_pairvec = std::pair<std::vector<Node*>, std::vector<Node*>>;
+  using t_uncomp_map = std::unordered_map<Node*, std::unordered_set<Node*>>;
+
+  std::unordered_set<Node*> naive_grow_subgraph(Node* head,
+                                                std::unordered_set<Node*>* set_unused,
+                                                t_uncomp_map* uncomp_map) {
+    std::unordered_set<Node*> subgraph;
+    std::unordered_set<Node*> uncomp_set;
+    std::deque<Node*> stack;
+    stack.emplace_back(head);
+    while (!stack.empty()) {
+      Node* vertex = stack.back();
+      stack.pop_back();
+      if (set_unused->count(vertex) && !uncomp_set.count(vertex)) {
+        set_unused->erase(vertex);
+        subgraph.insert(vertex);
+        uncomp_set.insert((*uncomp_map)[vertex].begin(), (*uncomp_map)[vertex].end());
+        for (Node* input : vertex->inputs) {
+          if (set_unused->count(input) && !uncomp_set.count(input)) {
+            stack.emplace_back(input);
+          }
+        }
+        for (Node* output : vertex->outputs) {
+          if (set_unused->count(output) && !uncomp_set.count(output)) {
+            stack.emplace_back(output);
+          }
+        }
+      }
+    }
+    return subgraph;
+  }
+
+  std::vector<std::unordered_set<Node*>> get_subsets(
+    std::unordered_map<std::string, NDArray>* const params_map) {
+    std::vector<std::unordered_set<Node*>> subgraphs;
+    std::unordered_set<Node*> set_nonTRTnodes;
+    std::unordered_set<Node*> set_allnodes(nodes.size());
+    std::vector<t_pairset> separation_sets;
+    for (Node& node : nodes) {
+      if (!IsTRTCompatible(node.nnvmptr)) {
+        set_nonTRTnodes.insert(&node);
+        std::unordered_set<Node*> in_graph;
+        std::unordered_set<Node*> out_graph;
+        std::vector<Node*> dummy_head;
+        dummy_head.emplace_back(&node);
+        DFS(dummy_head, false, [&out_graph](Node* node) {
+          out_graph.insert(node);
+        });
+        DFS(dummy_head, true, [&in_graph](Node* node) {
+          in_graph.insert(node);
+        });
+        separation_sets.emplace_back(std::make_pair(in_graph, out_graph));
+      }
+      set_allnodes.emplace(&node);
+    }
+    t_uncomp_map uncomp_map;
+    std::unordered_set<Node*> set_TRTnodes;
+    set_TRTnodes.insert(set_allnodes.begin(), set_allnodes.end());
+    for (Node* n : set_nonTRTnodes) {
+      set_TRTnodes.erase(n);
+    }
+    for (Node* n : set_TRTnodes) {
+      for (t_pairset p : separation_sets) {
+        if (p.first.count(n)) {
+          uncomp_map[n].insert(p.second.begin(), p.second.end());
+        } else if (p.second.count(n)) {
+          uncomp_map[n].insert(p.first.begin(), p.first.end());
+        }
+      }
+      for (Node* nonTRTn : set_nonTRTnodes) {
+        uncomp_map[n].erase(nonTRTn);
+      }
+    }
+    std::unordered_set<Node*> set_unused;
+    set_unused.reserve(set_TRTnodes.size());
+
+    for (auto& n : set_TRTnodes) {
+      if (n->nnvmptr->attrs.op != nullptr || params_map->count(n->nnvmptr->attrs.name)) {
+        set_unused.insert(n);
+      }
+    }
+    std::unordered_set<Node*> visited;
+    std::deque<Node*> stack(outputs.begin(), outputs.end());
+    while (!stack.empty()) {
+      Node* vertex = stack.front();
+      stack.pop_front();
+      if (!visited.count(vertex)) {
+        visited.insert(vertex);
+        if (set_unused.count(vertex)) {
+          subgraphs.emplace_back(naive_grow_subgraph(vertex, &set_unused, &uncomp_map));
+        }
+        for (Node* input : vertex->inputs) {
+          stack.emplace_back(input);
+        }
+      }
+    }
+
+    return subgraphs;
+  }
+
+
+ private:
+  friend class Graph;
+
+  bool IsTRTCompatible(nnvm::Node* nodeptr) {
+    if (nodeptr->op() == nullptr) {
+      return true;
+    }
+
+    const std::string op_name = nodeptr->op()->name;
+    if (op_name == "Pooling") {
+      return (nodeptr->attrs.dict.at("pool_type") == "avg" ||
+          nodeptr->attrs.dict.at("pool_type") == "max");
+    }
+
+    if (unconditionalTRTop.count(op_name)) {
+      return true;
+    }
+
+    if (op_name == "Activation") {
+      return nodeptr->attrs.dict.at("act_type") == "relu" ||
+        nodeptr->attrs.dict.at("act_type") == "tanh" ||
+        nodeptr->attrs.dict.at("act_type") == "sigmoid";
+    }
+
+    return false;
+  }
+};  // class BidirectionalGraph
+
+/*!
+ * \brief Operators that are always TensorRT-compatible, regardless of their attributes
+ */
+const std::unordered_set<std::string> BidirectionalGraph::unconditionalTRTop = {
+  "Convolution",
+  "BatchNorm",
+  "elemwise_add",
+  "elemwise_sub",
+  "elemwise_mul",
+  "rsqrt",
+  "pad",
+  "Pad",
+  "mean",
+  "FullyConnected",
+  "Flatten",
+  "SoftmaxOutput",
+};
+
+
+using NodeEntrySet = std::unordered_set<nnvm::NodeEntry, nnvm::NodeEntryHash,
+                                        nnvm::NodeEntryEqual>;
+
+/*!
+ * \brief get the output nodes of the subgraph in the main graph
+ * \return a vector of the output nodes
+*/
+std::vector<nnvm::NodeEntry> GetSubgraphNodeEntries(Graph g,
+    std::unordered_set<nnvm::Node*> set_subgraph) {
+  std::vector<nnvm::NodeEntry> outputs;
+  NodeEntrySet _outputs;
+  for (auto& e : g.outputs) {
+    if (set_subgraph.count(e.node.get())) {
+      _outputs.insert(e);
+    }
+  }
+  DFSVisit(g.outputs, [&set_subgraph, &_outputs](const nnvm::NodePtr &node){
+    if (!set_subgraph.count(node.get())) {
+      for (auto& e : node->inputs) {
+        if (set_subgraph.count(e.node.get())) {
+          _outputs.insert(e);
+        }
+      }
+    }
+  });
+  outputs.insert(outputs.begin(), _outputs.begin(), _outputs.end());
+  return outputs;
+}
+
+
+/*!
+ * \brief get the node entries from outside the subgraph whose outputs are used inside the subgraph
+ * \return a vector of the node entries
+*/
+std::vector<nnvm::NodeEntry> GetSubgraphInterfaceNodes(Graph g,
+    std::unordered_set<nnvm::Node*> set_subgraph) {
+  std::vector<nnvm::NodeEntry> inputs;
+  NodeEntrySet _inputs;
+  DFSVisit(g.outputs, [&set_subgraph, &_inputs](const nnvm::NodePtr &node){
+    if (set_subgraph.count(node.get())) {
+      for (auto& e : node->inputs) {
+        if (!set_subgraph.count(e.node.get())) {
+          _inputs.insert(e);
+        }
+      }
+    }
+  });
+  inputs.insert(inputs.begin(), _inputs.begin(), _inputs.end());
+  return inputs;
+}
+
+std::unordered_map<uint32_t, uint32_t> GetGraphInputsMap(const Graph& g) {
+  std::unordered_map<uint32_t, uint32_t> outputs;
+  auto& idx = g.indexed_graph();
+  outputs.reserve(idx.num_nodes());
+  std::vector<uint32_t> input_nodes = idx.input_nodes();
+  for (size_t i = 0; i < input_nodes.size(); ++i) {
+    outputs[input_nodes[i]] = static_cast<uint32_t>(i);
+  }
+  return outputs;
+}
+
+/*!
+ * \brief Create the TensorRT (_trt_op) node that stands in for a subgraph, by serializing the subgraph to ONNX
+ */
+nnvm::NodePtr ConvertNnvmGraphToOnnx(const nnvm::Graph &g,
+                                     std::unordered_map<std::string, NDArray>* const params_map) {
+  auto p = nnvm::Node::Create();
+  p->attrs.op = nnvm::Op::Get("_trt_op");
+  op::TRTParam trt_param = op::nnvm_to_onnx::ConvertNnvmGraphToOnnx(g, params_map);
+  p->attrs.dict["serialized_output_map"] = trt_param.serialized_output_map;
+  p->attrs.dict["serialized_input_map"]  = trt_param.serialized_input_map;
+  p->attrs.dict["serialized_onnx_graph"] = trt_param.serialized_onnx_graph;
+  if (p->op()->attr_parser != nullptr) {
+    p->op()->attr_parser(&(p->attrs));
+  }
+  return p;
+}
+
+/*!
+ * \brief Update attributes of the subgraph (such as shapes and dtypes of its inputs) from the main graph
+ */
+Graph UpdateSubgraphAttrs(Graph&& subgraph, const Graph& g,
+                          const std::unordered_map<nnvm::Node*, nnvm::NodePtr>& old2new,
+                          const nnvm::NodeEntryMap<nnvm::NodeEntry>& main_input_entry_to_sub) {
+  const auto& idx     = g.indexed_graph();
+  const auto& sub_idx = subgraph.indexed_graph();
+
+  const auto& shape               = g.GetAttr<nnvm::ShapeVector>("shape");
+  const auto& dtype               = g.GetAttr<nnvm::DTypeVector>("dtype");
+  const auto& storage_type        = g.GetAttr<StorageTypeVector>("storage_type");
+  const auto& shape_inputs        = g.GetAttr<nnvm::ShapeVector>("shape_inputs");
+  const auto& dtype_inputs        = g.GetAttr<nnvm::DTypeVector>("dtype_inputs");
+  const auto& storage_type_inputs = g.GetAttr<StorageTypeVector>("storage_type_inputs");
+
+  nnvm::ShapeVector sub_shape(sub_idx.num_node_entries());
+  nnvm::DTypeVector sub_dtype(sub_idx.num_node_entries());
+  StorageTypeVector sub_storage_type(sub_idx.num_node_entries());
+  nnvm::ShapeVector sub_shape_inputs(sub_idx.input_nodes().size());
+  nnvm::DTypeVector sub_dtype_inputs(sub_idx.input_nodes().size());
+  StorageTypeVector sub_storage_type_inputs(sub_idx.input_nodes().size());
+
+  const std::unordered_map<uint32_t, uint32_t> inputsindex2pos     = GetGraphInputsMap(g);
+  const std::unordered_map<uint32_t, uint32_t> sub_inputsindex2pos = GetGraphInputsMap(subgraph);
+  // map attributes from graph to subgraph
+  for (auto& p : old2new) {
+    const uint32_t nid     = idx.node_id(p.first);
+    const uint32_t sub_nid = sub_idx.node_id(p.second.get());
+    const nnvm::Op* op = sub_idx[sub_nid].source->op();
+    if (op == nullptr) {  // if it's an input node, there is only one output node entry
+      const uint32_t sub_i       = sub_idx.entry_id(sub_nid, 0);
+      const uint32_t sub_input_i = sub_inputsindex2pos.at(sub_nid);
+      const uint32_t i           = idx.entry_id(nid, 0);
+
+      sub_shape[sub_i] = shape[i];
+      sub_dtype[sub_i] = dtype[i];
+      sub_storage_type[sub_i]       = storage_type[i];
+      sub_shape_inputs[sub_input_i] = shape_inputs[inputsindex2pos.at(nid)];
+      sub_dtype_inputs[sub_input_i] = dtype_inputs[inputsindex2pos.at(nid)];
+      sub_storage_type_inputs[sub_input_i] = storage_type_inputs[inputsindex2pos.at(nid)];
+
+    } else {
+      for (size_t oi = 0; oi < op->num_outputs; ++oi) {
+        const uint32_t sub_i = sub_idx.entry_id(sub_nid, oi);
+        const uint32_t i = idx.entry_id(nid, oi);
+          sub_shape[sub_i] = shape[i];
+          sub_dtype[sub_i] = dtype[i];
+          sub_storage_type[sub_i] = storage_type[i];
+      }
+    }
+  }
+  // old2new doesn't contain placeholder / interfaces
+  for (auto& p : main_input_entry_to_sub) {
+    nnvm::NodeEntry main_entry = p.first;
+    nnvm::NodeEntry sub_entry = p.second;
+    const uint32_t sub_nid = sub_idx.node_id(sub_entry.node.get());
+    const uint32_t sub_i = sub_idx.entry_id(sub_entry);
+    const uint32_t i = idx.entry_id(main_entry);
+    const uint32_t sub_input_i = sub_inputsindex2pos.at(sub_nid);
+    sub_shape[sub_i] = shape[i];
+    sub_dtype[sub_i] = dtype[i];
+    sub_storage_type[sub_i] = storage_type[i];
+    sub_shape_inputs[sub_input_i] = sub_shape[sub_i];
+    sub_dtype_inputs[sub_input_i] = sub_dtype[sub_i];
+    sub_storage_type_inputs[sub_input_i] = sub_storage_type[sub_i];
+  }
+  subgraph.attrs["shape"] =
+      std::make_shared<dmlc::any>(std::move(sub_shape));
+  subgraph.attrs["dtype"] =
+      std::make_shared<dmlc::any>(std::move(sub_dtype));
+  subgraph.attrs["storage_type"] =
+      std::make_shared<dmlc::any>(std::move(sub_storage_type));
+  subgraph.attrs["shape_inputs"] =
+      std::make_shared<dmlc::any>(std::move(sub_shape_inputs));
+  subgraph.attrs["dtype_inputs"] =
+      std::make_shared<dmlc::any>(std::move(sub_dtype_inputs));
+  subgraph.attrs["storage_type_inputs"] =
+      std::make_shared<dmlc::any>(std::move(sub_storage_type_inputs));
+
+  return subgraph;
+}
+
+/*!
+ * \brief Generate a name for a new TRT node, avoiding collisions with TRT_node names that are already defined
+ */
+const std::string GetNewTrtName(const Graph& g, const Graph& subgraph) {
+  const std::string name_prefix("TRT_node");
+  std::unordered_set<std::string> name_set;
+  DFSVisit(g.outputs, [&name_set, &name_prefix](const nnvm::NodePtr& node) {
+    if (node->attrs.name.compare(0, name_prefix.size(), name_prefix) == 0) {
+      name_set.insert(node->attrs.name);
+    }
+  });
+  // names inside the subgraph will become available, since those nodes will be removed
+  DFSVisit(subgraph.outputs, [&name_set, &name_prefix](const nnvm::NodePtr& node) {
+    if (node->attrs.name.compare(0, name_prefix.size(), name_prefix) == 0) {
+      name_set.erase(node->attrs.name);
+    }
+  });
+  uint32_t name_suffix = 0;
+  std::string full_name = name_prefix + std::to_string(name_suffix);
+  while (name_set.count(full_name)) {
+    full_name = name_prefix + std::to_string(++name_suffix);
+  }
+  return full_name;
+}
+
+/*!
+ * \brief helper function to display what nodes are in a specific subset
+ */
+void dispNodesSet(Graph g, std::unordered_set<nnvm::Node*> s) {
+  DFSVisit(g.outputs, [&s](const nnvm::NodePtr n){
+    if (s.count(n.get())) {
+      std::cout << "  Y " << n->attrs.name << std::endl;
+    } else {
+      std::cout << "  N " << n->attrs.name << std::endl;
+    }
+  });
+}
+
+/*!
+ * \brief Replace a set of nodes by a TensorRT node
+ */
+Graph ReplaceSubgraph(Graph&& g,
+                      const std::unordered_set<nnvm::Node*>& set_subgraph,
+                      std::unordered_map<std::string, NDArray>* const params_map) {
+  // Create MXNet subgraph
+  Graph subgraph;
+
+  const auto sub_outputs_in_main = GetSubgraphNodeEntries(g, set_subgraph);
+  subgraph.outputs = sub_outputs_in_main;
+  // old2new will link the raw pointers of the nodes in the main graph to
+  // the corresponding shared_ptr of the nodes in the generated subgraph
+  std::unordered_map<nnvm::Node*, nnvm::NodePtr> old2new;
+  std::deque<nnvm::Node*> stack;
+  std::unordered_set<nnvm::Node*> visited;
+  int32_t reservation = set_subgraph.size();
+  old2new.reserve(reservation);
+  visited.reserve(reservation);
+
+  // Create a copy of each subgraph node, keyed by its original raw pointer
+  for (auto& n : set_subgraph) {
+    old2new[n] = std::make_shared<nnvm::Node>(*n);
+  }
+
+  // To generate a subgraph, each input has to be replaced by a data node (no op),
+  // and it has to be agnostic to the node from which it is an output
+  // (for example, even if two inputs are two different outputs of the same node)
+  nnvm::NodeEntryMap<nnvm::NodeEntry> main_input_entry_to_sub;
+  for (auto& e : GetSubgraphInterfaceNodes(g, set_subgraph)) {
+    auto node = nnvm::Node::Create();
+    node->attrs.name = e.node->attrs.name + "_" + std::to_string(e.index);
+    auto new_e = nnvm::NodeEntry{node, 0, 0};
+    main_input_entry_to_sub[e] = new_e;
+  }
+
+  for (nnvm::NodeEntry& e : subgraph.outputs) {
+    e.node = old2new[e.node.get()];
+    stack.emplace_back(e.node.get());
+  }
+  // re-link the inputs of all subgraph nodes to subgraph nodes instead of main-graph nodes
+  while (!stack.empty()) {
+    auto vertex = stack.front();
+    stack.pop_front();
+    if (!visited.count(vertex)) {
+      visited.insert(vertex);
+      for (auto& e : vertex->inputs) {
+        auto it = main_input_entry_to_sub.find(e);
+        if (it != main_input_entry_to_sub.end()) {
+          e = it->second;
+        } else {
+          e.node = old2new[e.node.get()];
+        }
+      stack.emplace_back(e.node.get());
+      }
+    }
+  }
+  // Remove control dependencies of subgraph nodes on nodes that are not in the subgraph
+  DFSVisit(subgraph.outputs, [&set_subgraph, &old2new](const nnvm::NodePtr& node) {
+    std::remove_if(node->control_deps.begin(),
+                   node->control_deps.end(),
+                   [&set_subgraph](nnvm::NodePtr n_ptr) {
+                    return !set_subgraph.count(n_ptr.get());
+                   });
+    for (nnvm::NodePtr& n_ptr : node->control_deps) {
+      n_ptr = old2new[n_ptr.get()];
+    }
+  });
+
+  subgraph = UpdateSubgraphAttrs(std::move(subgraph), g, old2new, main_input_entry_to_sub);
+  auto& sub_idx = subgraph.indexed_graph();
+
+  auto trtnodeptr = ConvertNnvmGraphToOnnx(subgraph, params_map);
+  trtnodeptr->attrs.name = GetNewTrtName(g, subgraph);
+
+  // Insert new trt node and unplug replaced nodes
+  std::unordered_map<uint32_t, nnvm::NodeEntry> sub_input_entryid_to_main;
+  for (auto& p : main_input_entry_to_sub) {
+    sub_input_entryid_to_main[sub_idx.entry_id(p.second)] = p.first;
+  }
+
+  // Plug the nodes from the main graph as inputs of the trt node
+  trtnodeptr->inputs.resize(main_input_entry_to_sub.size());
+  {
+    uint32_t counter = 0;
+    for (uint32_t i : sub_idx.input_nodes()) {
+      auto it = sub_input_entryid_to_main.find(sub_idx.entry_id(i, 0));
+      if (it != sub_input_entryid_to_main.end()) {
+        trtnodeptr->inputs[counter++] = it->second;
+      }
+    }
+  }
+  nnvm::NodeEntryMap<uint32_t> sub_outputs_in_main_to_pos;
+  for (uint32_t i = 0; i < sub_outputs_in_main.size(); ++i) {
+    sub_outputs_in_main_to_pos[sub_outputs_in_main[i]] = i;
+  }
+  // Plug the trt node as inputs to the main graph nodes
+  DFSVisit(g.outputs, [&sub_outputs_in_main_to_pos, &trtnodeptr](const nnvm::NodePtr& n) {
+    for (auto& e : n->inputs) {
+      auto it = sub_outputs_in_main_to_pos.find(e);
+      if (it != sub_outputs_in_main_to_pos.end()) {
+        e.index = it->second;
+        e.node = trtnodeptr;
+      }
+    }
+  });
+
+  for (auto& output : g.outputs) {
+    auto it = sub_outputs_in_main_to_pos.find(output);
+    if (it != sub_outputs_in_main_to_pos.end()) {
+      output.index = it->second;
+      output.node = trtnodeptr;
+    }
+  }
+
+  Graph new_graph;
+  new_graph.outputs = g.outputs;
+  return new_graph;
+}
+
+std::vector<std::unordered_set<nnvm::Node*>> GetTrtCompatibleSubsets(const Graph& g,
+    std::unordered_map<std::string, NDArray>* const params_map) {
+  BidirectionalGraph biG = BidirectionalGraph(g);
+  std::vector<std::unordered_set<BidirectionalGraph::Node*>> subsets = biG.get_subsets(params_map);
+  std::vector<std::unordered_set<nnvm::Node*>> nnvm_subsets(subsets.size(),
+                                                            std::unordered_set<nnvm::Node*>());
+  for (size_t i = 0; i < subsets.size(); ++i) {
+    nnvm_subsets[i].reserve(subsets[i].size());
+    for (auto& n : subsets[i]) {
+      nnvm_subsets[i].insert(n->nnvmptr);
+    }
+  }
+  return nnvm_subsets;
+}
+
+}  // namespace exec
+}  // namespace mxnet
+
+#endif  // MXNET_USE_TENSORRT
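
Putting the two entry points together, the intended partitioning flow (mirrored
by the TrtGraphExecutor below) is: run shape/dtype/storage-type inference, find
the TRT-compatible subsets, then replace every subset with more than one node by
a single _trt_op. A condensed sketch, assuming g already carries the inference
attributes and params_map holds the weights (the real executor additionally
re-runs inference via ReinitGraph after each replacement, which this sketch omits):

    #if MXNET_USE_TENSORRT
    nnvm::Graph PartitionForTrt(nnvm::Graph g,
                                std::unordered_map<std::string, mxnet::NDArray>* params_map) {
      auto trt_groups = mxnet::exec::GetTrtCompatibleSubsets(g, params_map);
      for (const auto& trt_group : trt_groups) {
        if (trt_group.size() > 1) {  // single nodes are not worth a TRT engine
          g = mxnet::exec::ReplaceSubgraph(std::move(g), trt_group, params_map);
        }
      }
      return g;
    }
    #endif  // MXNET_USE_TENSORRT
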
diff --git a/src/executor/trt_graph_executor.cc b/src/executor/trt_graph_executor.cc
new file mode 100644
index 0000000..65dbb29
--- /dev/null
+++ b/src/executor/trt_graph_executor.cc
@@ -0,0 +1,450 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#if MXNET_USE_TENSORRT
+
+#include "trt_graph_executor.h"
+
+#include <onnx/onnx.pb.h>
+#include <NvInfer.h>
+#include "./onnx_to_tensorrt.h"
+#include "../operator/contrib/tensorrt-inl.h"
+#include "../common/utils.h"
+#include "../common/exec_utils.h"
+
+
+namespace mxnet {
+namespace exec {
+
+using namespace mxnet::common;
+
+  /*!
+ * \brief TrtGraphExecutor initializer for simple bind flow in
+ * which only certain input shapes and dtypes are provided by users.
+ * The initializer uses these shapes and dtypes to perform
+ * shape and dtype inferences, and then create NDArrays
+ * to populate data entries of the graph. The created NDArrays
+ * for in_args, arg_grads and aux_states are passed to the
+ * front end to attach the created executor.
+ * In the front end, if the simple_bind flow is triggered by
+ * _bind_ith_exec, the shared data arrays of DataParallelExecutorGroup
+ * and shared executor will be taken into account in creating
+ * NDArrays for in_args, arg_grads, and aux_states for reusing
+ * already allocated memory.
+ *
+ * This version of the executor exports the computation graph to TensorRT to make use of fused
+ * kernels and other runtime enhancements.  TRT will compile the sub-graphs to executable fused
+ * operators without intervention from the user.  Operators in the original graph that are not
+ * supported by TRT will continue to be executed normally by MXNet.
+ *
+ */
+void TrtGraphExecutor::Init(nnvm::Symbol symbol,
+                            const Context& default_ctx,
+                            const std::map<std::string, Context>& ctx_map,
+                            std::vector<Context> *in_arg_ctxes,
+                            std::vector<Context> *arg_grad_ctxes,
+                            std::vector<Context> *aux_state_ctxes,
+                            std::unordered_map<std::string, TShape> *arg_shape_map,
+                            std::unordered_map<std::string, int> *arg_dtype_map,
+                            std::unordered_map<std::string, int> *arg_stype_map,
+                            std::vector<OpReqType> *grad_req_types,
+                            const std::unordered_set<std::string>& shared_arg_names,
+                            std::vector<NDArray>* in_arg_vec,
+                            std::vector<NDArray>* arg_grad_vec,
+                            std::vector<NDArray>* aux_state_vec,
+                            std::unordered_map<std::string, NDArray>* shared_buffer,
+                            Executor* shared_exec,
+                            const nnvm::NodeEntryMap<NDArray>& feed_dict) {
+  symbol = symbol.Copy();
+  nnvm::Graph g = InitGraph(symbol, default_ctx, ctx_map, *in_arg_ctxes, *arg_grad_ctxes,
+                            *aux_state_ctxes, *grad_req_types);
+
+  if (need_grad_) {
+    LOG(FATAL) << "You may be attempting to use TensorRT for training.  TensorRT is an inference "
+                  "only library.  To re-enable legacy MXNet graph execution, which will support "
+                  "training, set the MXNET_USE_TENSORRT environment variable to 0, or call "
+                  "mx.contrib.tensorrt.set_use_tensorrt(False)";
+  }
+
+  if (shared_buffer == nullptr || shared_buffer->empty()) {
+    LOG(FATAL) << "MXNET_USE_TENSORRT = 1 but shared_buffer is empty. "
+               << "Please provide weights and other parameters, such as "
+               << "BatchNorm moments, via the shared_buffer, during simple bind call.";
+  }
+
+  // The following code of shape and dtype inferences and argument
+  // initialization is for simple_bind only. Regular bind operation
+  // should do this differently.
+
+  // Initialize arg_shapes and arg_dtypes for shape and type inferences.
+  // It contains all in_args and aux_states' shapes and types in a certain order.
+  const nnvm::IndexedGraph& idx = g.indexed_graph();
+  nnvm::ShapeVector arg_shapes(idx.input_nodes().size(), TShape());
+  nnvm::DTypeVector arg_dtypes(idx.input_nodes().size(), -1);
+  StorageTypeVector arg_stypes(idx.input_nodes().size(), kUndefinedStorage);
+  for (size_t i = 0; i < num_forward_inputs_; ++i) {
+    const uint32_t nid = idx.input_nodes().at(i);
+    const std::string& name = idx[nid].source->attrs.name;
+    auto it1 = arg_shape_map->find(name);
+    if (arg_shape_map->end() != it1) {
+      arg_shapes[i] = it1->second;
+    }
+    auto it2 = arg_dtype_map->find(name);
+    if (arg_dtype_map->end() != it2) {
+      arg_dtypes[i] = it2->second;
+    }
+    auto it3 = arg_stype_map->find(name);
+    if (arg_stype_map->end() != it3) {
+      arg_stypes[i] = it3->second;
+    }
+  }
+  g = InferShape(std::move(g), std::move(arg_shapes), "__shape__");
+  if (g.GetAttr<size_t>("shape_num_unknown_nodes") != 0U) {
+    HandleInferShapeError(num_forward_inputs_, g.indexed_graph(),
+                          g.GetAttr<nnvm::ShapeVector>("shape"));
+  }
+
+  g = InferType(std::move(g), std::move(arg_dtypes), "__dtype__");
+  if (g.GetAttr<size_t>("dtype_num_unknown_nodes") != 0U) {
+    HandleInferTypeError(num_forward_inputs_, g.indexed_graph(),
+                         g.GetAttr<nnvm::DTypeVector>("dtype"));
+  }
+
+  g = InferStorageType(std::move(g), std::move(arg_stypes), "__storage_type__");
+  if (g.GetAttr<size_t>("storage_type_num_unknown_nodes") != 0U) {
+    HandleInferStorageTypeError(num_forward_inputs_, g.indexed_graph(),
+                                g.GetAttr<StorageTypeVector>("storage_type"));
+  }
+
+  auto trt_groups = GetTrtCompatibleSubsets(g, shared_buffer);
+  for (auto trt_group : trt_groups) {
+    if (trt_group.size() > 1) {
+      g = ReplaceSubgraph(std::move(g), trt_group, shared_buffer);
+      g = ReinitGraph(std::move(g), default_ctx, ctx_map, in_arg_ctxes, arg_grad_ctxes,
+                      aux_state_ctxes, grad_req_types, arg_shape_map, arg_dtype_map,
+                      arg_stype_map, shared_buffer);
+    }
+  }
+
+
+  InitArguments(g.indexed_graph(), g.GetAttr<nnvm::ShapeVector>("shape"),
+                g.GetAttr<nnvm::DTypeVector>("dtype"),
+                g.GetAttr<StorageTypeVector>("storage_type"),
+                *in_arg_ctxes, *arg_grad_ctxes, *aux_state_ctxes,
+                *grad_req_types, shared_arg_names, shared_exec,
+                shared_buffer, in_arg_vec, arg_grad_vec, aux_state_vec);
+
+  // The above code of shape and dtype inferences and argument
+  // initialization is for simple_bind only. Regular bind operation
+  // should do this differently.
+
+  // Initialize the rest attributes of the graph.
+  // This function can be called by regular bind
+  // operation flow as well.
+  FinishInitGraph(symbol, g, shared_exec, feed_dict);
+}
+/*!
+ * \brief Initialize in_args, arg_grads, and aux_states
+ * and their data_entry_ of the executor using
+ * shared_buffer from DataParallelExecutorGroup
+ * and shared_exec if available.
+ */
+void TrtGraphExecutor::InitArguments(const nnvm::IndexedGraph& idx,
+                                  const nnvm::ShapeVector& inferred_shapes,
+                                  const nnvm::DTypeVector& inferred_dtypes,
+                                  const StorageTypeVector& inferred_stypes,
+                                  const std::vector<Context>& in_arg_ctxes,
+                                  const std::vector<Context>& arg_grad_ctxes,
+                                  const std::vector<Context>& aux_state_ctxes,
+                                  const std::vector<OpReqType>& grad_req_types,
+                                  const std::unordered_set<std::string>& shared_arg_names,
+                                  const Executor* shared_exec,
+                                  std::unordered_map<std::string, NDArray>* shared_buffer,
+                                  std::vector<NDArray>* in_arg_vec,
+                                  std::vector<NDArray>* arg_grad_vec,
+                                  std::vector<NDArray>* aux_state_vec) {
+  // initialize in_args, arg_grads, and aux_states and populate grad_store_
+  data_entry_.resize(idx.num_node_entries());
+  size_t arg_top = 0, aux_top = 0;
+  const auto& mutable_nodes = idx.mutable_input_nodes();
+  for (size_t i = 0; i < num_forward_inputs_; ++i) {
+    const uint32_t nid = idx.input_nodes().at(i);
+    const uint32_t eid = idx.entry_id(nid, 0);
+    const TShape& inferred_shape = inferred_shapes[eid];
+    const int inferred_dtype = inferred_dtypes[eid];
+    const NDArrayStorageType inferred_stype = (NDArrayStorageType) inferred_stypes[eid];
+    const std::string& arg_name = idx[nid].source->attrs.name;
+    // aux_states
+    if (mutable_nodes.count(nid)) {
+      if (nullptr != shared_exec) {
+        const NDArray& aux_nd = shared_exec->aux_state_map().at(arg_name);
+        CHECK(inferred_stype == kDefaultStorage && aux_nd.storage_type() == kDefaultStorage)
+          << "Non-default storage type detected when creating auxiliary NDArray. The allocated "
+          << "memory of shared_exec.aux_array cannot be reused for argument: "
+          << arg_name << " for the current executor";
+        CHECK_EQ(inferred_shape, aux_nd.shape())
+          << "Inferred shape does not match shared_exec.aux_array's shape."
+             " Therefore, the allocated memory for shared_exec.aux_array cannot"
+             " be reused for creating auxiliary NDArray of the argument: "
+          << arg_name << " for the current executor";
+        CHECK_EQ(inferred_dtype, aux_nd.dtype())
+          << "Inferred dtype does not match shared_exec.aux_array's dtype."
+             " Therefore, the allocated memory for shared_exec.aux_array cannot"
+             " be reused for creating auxiliary NDArray of the argument: "
+          << arg_name << " for the current executor";
+        aux_state_vec->emplace_back(aux_nd);
+      } else {
+        auto it = shared_buffer->find(arg_name);
+        if (it != shared_buffer->end()) {
+          aux_state_vec->push_back(std::move(it->second.Copy(aux_state_ctxes[aux_top])));
+        } else {
+          aux_state_vec->push_back(std::move(InitZeros(inferred_stype, inferred_shape,
+                                                       aux_state_ctxes[aux_top], inferred_dtype)));
+        }
+      }  // if (has_shared_exec)
+      data_entry_[eid] = aux_state_vec->back();
+      aux_state_map_.emplace(arg_name, aux_state_vec->back());
+      ++aux_top;
+    } else {  // in_args and grad for in_args
+      if (shared_arg_names.count(arg_name)) {  // model parameter
+        // model parameter
+        if (nullptr != shared_exec) {
+          const NDArray& in_arg_nd = shared_exec->in_arg_map().at(arg_name);
+          auto arg_nd_stype = in_arg_nd.storage_type();
+          // for model parameter, both default storage and row_sparse storage can be shared
+          bool shareable_arg_stype = inferred_stype == kDefaultStorage ||
+                                     inferred_stype == kRowSparseStorage;
+          // try to reuse memory from shared_exec
+          CHECK(shareable_arg_stype) << "Inferred storage type "
+            << common::stype_string(inferred_stype)
+            << " does not support memory sharing with shared_exec.arg_array";
+          CHECK_EQ(inferred_stype, arg_nd_stype)
+            << "Inferred stype does not match shared_exec.arg_array's stype"
+               " Therefore, the allocated memory for shared_exec.arg_array cannot"
+               " be resued for creating NDArray of the argument "
+            << arg_name << " for the current executor";
+          CHECK_EQ(inferred_shape, in_arg_nd.shape())
+            << "Inferred shape does not match shared_exec.arg_array's shape"
+               " Therefore, the allocated memory for shared_exec.arg_array cannot"
+               " be resued for creating NDArray of the argument "
+            << arg_name << " for the current executor";
+          CHECK_EQ(inferred_dtype, in_arg_nd.dtype())
+            << "Inferred dtype does not match shared_exec.arg_array's dtype"
+               " Therefore, the allocated memory for shared_exec.arg_array cannot"
+               " be resued for creating NDArray of the argument "
+            << arg_name << " for the current executor";
+          in_arg_vec->emplace_back(in_arg_nd);
+        } else {
+          // no shared_exec to reuse memory from, so initialize the argument with zeros
+          EmplaceBackZeros(inferred_stype, inferred_shape, in_arg_ctxes[arg_top],
+                           inferred_dtype, in_arg_vec);
+        }
+        // gradient for model parameter
+        if (kNullOp == grad_req_types[arg_top]) {
+          arg_grad_vec->emplace_back();
+        } else {
+          auto grad_oid = grad_store_.size() + num_forward_outputs_;
+          auto grad_eid = idx.entry_id(idx.outputs()[grad_oid]);
+          auto grad_stype = (NDArrayStorageType) inferred_stypes[grad_eid];
+          if (nullptr != shared_exec && grad_stype == kDefaultStorage &&
+              shared_exec->arg_grad_map().at(arg_name).storage_type() == kDefaultStorage) {
+            // try to reuse memory from shared_exec
+            arg_grad_vec->emplace_back(shared_exec->arg_grad_map().at(arg_name));
+          } else {
+            // no need to reuse memory from shared_exec for gradient of non-default storage
+            EmplaceBackZeros(grad_stype, inferred_shape, arg_grad_ctxes[arg_top],
+                             inferred_dtype, arg_grad_vec);
+          }
+          grad_store_.emplace_back(grad_req_types[arg_top], arg_grad_vec->back());
+        }
+      } else {  // !shared_arg_names.count(arg_name)
+        // argument not in shared_arg_names: copy it from shared_buffer if present, else use zeros
+        auto it = shared_buffer->find(arg_name);
+        if (it != shared_buffer->end()) {
+          in_arg_vec->push_back(std::move(it->second.Copy(in_arg_ctxes[arg_top])));
+        } else {
+          in_arg_vec->push_back(std::move(InitZeros(inferred_stype, inferred_shape,
+                                                    in_arg_ctxes[arg_top], inferred_dtype)));
+        }
+        // gradient for model parameter, row_sparse ndarray sharing disabled
+        if (kNullOp == grad_req_types[arg_top]) {
+          arg_grad_vec->emplace_back();
+        } else {
+          auto grad_oid = grad_store_.size() + num_forward_outputs_;
+          auto grad_eid = idx.entry_id(idx.outputs()[grad_oid]);
+          auto grad_stype = (NDArrayStorageType) inferred_stypes[grad_eid];
+          bool enable_row_sparse_sharing = false;
+          arg_grad_vec->emplace_back(ReshapeOrCreate("grad of " + arg_name, inferred_shape,
+                                                     inferred_dtype, grad_stype,
+                                                     arg_grad_ctxes[arg_top], shared_buffer,
+                                                     enable_row_sparse_sharing));
+          grad_store_.emplace_back(grad_req_types[arg_top], arg_grad_vec->back());
+        }  // if (kNullOp == grad_req_types[arg_top])
+      }  // if (shared_arg_names.count(arg_name))
+      in_arg_map_.emplace(arg_name, in_arg_vec->back());
+      if (!arg_grad_vec->back().is_none()) {
+        arg_grad_map_.emplace(arg_name, arg_grad_vec->back());
+      }
+      data_entry_[eid] = in_arg_vec->back();
+      ++arg_top;
+    }
+  }
+}
+
+
+/*!
+ * \brief This function is triggered after each TensorRT subgraph replacement pass.
+ * It resets the arguments of GraphExecutor::Init(...) because some variables (weights and
+ * biases) are absorbed into the TRT engine, and it reruns attribute inference for the
+ * new topology.
+ */
+Graph TrtGraphExecutor::ReinitGraph(Graph&& g, const Context &default_ctx,
+                                 const std::map<std::string, Context> &ctx_map,
+                                 std::vector<Context> *in_arg_ctxes,
+                                 std::vector<Context> *arg_grad_ctxes,
+                                 std::vector<Context> *aux_state_ctxes,
+                                 std::vector<OpReqType> *grad_req_types,
+                                 std::unordered_map<std::string, TShape> *arg_shape_map,
+                                 std::unordered_map<std::string, int> *arg_dtype_map,
+                                 std::unordered_map<std::string, int> *arg_stype_map,
+                                 std::unordered_map<std::string, NDArray> *params_map) {
+  std::unordered_set<std::string> to_remove_params;
+  for (auto& el : *params_map) {
+    to_remove_params.insert(el.first);
+  }
+
+  DFSVisit(g.outputs, [&to_remove_params](const nnvm::NodePtr n) {
+    to_remove_params.erase(n->attrs.name);
+  });
+
+  for (auto& el : to_remove_params) {
+    params_map->erase(el);
+    arg_shape_map->erase(el);
+    arg_dtype_map->erase(el);
+    arg_stype_map->erase(el);
+  }
+  const auto &idx = g.indexed_graph();
+  num_forward_inputs_ = idx.input_nodes().size();
+  in_arg_ctxes->resize(num_forward_inputs_ - idx.mutable_input_nodes().size());
+  arg_grad_ctxes->resize(num_forward_inputs_ - idx.mutable_input_nodes().size());
+  grad_req_types->resize(num_forward_inputs_ - idx.mutable_input_nodes().size());
+  aux_state_ctxes->resize(idx.mutable_input_nodes().size());
+
+  // create "device" and "context" attrs for the graph
+  g = AssignContext(g, default_ctx, ctx_map, *in_arg_ctxes, *arg_grad_ctxes,
+                    *aux_state_ctxes, *grad_req_types, num_forward_inputs_,
+                    num_forward_outputs_);
+
+  // get number of nodes used in forward pass
+  num_forward_nodes_ = 0;
+  for (size_t i = 0; i < num_forward_outputs_; ++i) {
+    num_forward_nodes_ = std::max(
+        num_forward_nodes_, static_cast<size_t>(idx.outputs()[i].node_id + 1));
+  }
+  nnvm::ShapeVector arg_shapes(idx.input_nodes().size(), TShape());
+  nnvm::DTypeVector arg_dtypes(idx.input_nodes().size(), -1);
+  StorageTypeVector arg_stypes(idx.input_nodes().size(), kUndefinedStorage);
+  for (size_t i = 0; i < num_forward_inputs_; ++i) {
+    const uint32_t nid = idx.input_nodes().at(i);
+    const std::string &name = idx[nid].source->attrs.name;
+    auto it1 = arg_shape_map->find(name);
+    if (arg_shape_map->end() != it1) {
+      arg_shapes[i] = it1->second;
+    }
+    auto it2 = arg_dtype_map->find(name);
+    if (arg_dtype_map->end() != it2) {
+      arg_dtypes[i] = it2->second;
+    }
+    auto it3 = arg_stype_map->find(name);
+    if (arg_stype_map->end() != it3) {
+      arg_stypes[i] = it3->second;
+    }
+  }
+  g = InferShape(std::move(g), std::move(arg_shapes), "__shape__");
+  if (g.GetAttr<size_t>("shape_num_unknown_nodes") != 0U) {
+    HandleInferShapeError(num_forward_inputs_, g.indexed_graph(),
+                          g.GetAttr<nnvm::ShapeVector>("shape"));
+  }
+
+  g = InferType(std::move(g), std::move(arg_dtypes), "__dtype__");
+  if (g.GetAttr<size_t>("dtype_num_unknown_nodes") != 0U) {
+    HandleInferTypeError(num_forward_inputs_, g.indexed_graph(),
+                         g.GetAttr<nnvm::DTypeVector>("dtype"));
+  }
+
+  g = InferStorageType(std::move(g), std::move(arg_stypes), "__storage_type__");
+
+  if (g.GetAttr<size_t>("storage_type_num_unknown_nodes") != 0U) {
+    HandleInferStorageTypeError(num_forward_inputs_, g.indexed_graph(),
+                                g.GetAttr<StorageTypeVector>("storage_type"));
+  }
+
+  return g;
+}
+
+
+/*!
+ * \brief Return the "optimized" symbol contained in the graph.
+ * Useful for inspecting the result of optimization passes such as the TensorRT pass.
+ */
+nnvm::Symbol TrtGraphExecutor::GetOptimizedSymbol() {
+  Symbol ret;
+  ret.outputs = std::vector<nnvm::NodeEntry>(graph_.outputs.begin(),
+                                             graph_.outputs.begin() + num_forward_outputs_);
+  ret = ret.Copy();
+  static const Op* trt_op = Op::Get("_trt_op");
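+  // The attribute dict of each _trt_op node carries the serialized ONNX graph and I/O maps;
+  // clear it so the returned symbol stays small and easy to compare.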
+  DFSVisit(ret.outputs, [](const nnvm::NodePtr n) {
+    if (n->op() == trt_op) {
+      n->attrs.dict.clear();
+    }
+  });
+  return ret;
+}
+
+Executor *TrtGraphExecutor::TensorRTBind(nnvm::Symbol symbol,
+                                         const Context &default_ctx,
+                                         const std::map<std::string, Context> &group2ctx,
+                                         std::vector<Context> *in_arg_ctxes,
+                                         std::vector<Context> *arg_grad_ctxes,
+                                         std::vector<Context> *aux_state_ctxes,
+                                         std::unordered_map<std::string, TShape> *arg_shape_map,
+                                         std::unordered_map<std::string, int> *arg_dtype_map,
+                                         std::unordered_map<std::string, int> *arg_stype_map,
+                                         std::vector<OpReqType> *grad_req_types,
+                                         const std::unordered_set<std::string> &param_names,
+                                         std::vector<NDArray> *in_args,
+                                         std::vector<NDArray> *arg_grads,
+                                         std::vector<NDArray> *aux_states,
+                                         std::unordered_map<std::string, NDArray> *shared_buffer,
+                                         Executor *shared_exec) {
+  auto exec = new exec::TrtGraphExecutor();
+  exec->Init(symbol, default_ctx, group2ctx,
+             in_arg_ctxes, arg_grad_ctxes, aux_state_ctxes,
+             arg_shape_map, arg_dtype_map, arg_stype_map,
+             grad_req_types, param_names,
+             in_args, arg_grads, aux_states,
+             shared_buffer, shared_exec);
+  return exec;
+}
+
+}  // namespace exec
+
+}  // namespace mxnet
+
+#endif  // MXNET_USE_TENSORRT
diff --git a/src/executor/trt_graph_executor.h b/src/executor/trt_graph_executor.h
new file mode 100644
index 0000000..96ac442
--- /dev/null
+++ b/src/executor/trt_graph_executor.h
@@ -0,0 +1,111 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#ifndef MXNET_EXECUTOR_TRT_GRAPH_EXECUTOR_H_
+#define MXNET_EXECUTOR_TRT_GRAPH_EXECUTOR_H_
+
+#if MXNET_USE_TENSORRT
+
+#include <map>
+#include <string>
+#include <vector>
+
+#include "./graph_executor.h"
+
+namespace mxnet {
+
+namespace exec {
+
+class TrtGraphExecutor : public GraphExecutor {
+ public:
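+  /*!
+   * \brief Factory that constructs a TrtGraphExecutor, forwards all arguments to its
+   * Init(...), and returns it as a plain Executor pointer.
+   */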
+  static Executor* TensorRTBind(nnvm::Symbol symbol,
+                                const Context& default_ctx,
+                                const std::map<std::string, Context>& group2ctx,
+                                std::vector<Context> *in_arg_ctxes,
+                                std::vector<Context>* arg_grad_ctxes,
+                                std::vector<Context>* aux_state_ctxes,
+                                std::unordered_map<std::string, TShape>* arg_shape_map,
+                                std::unordered_map<std::string, int>* arg_dtype_map,
+                                std::unordered_map<std::string, int>* arg_stype_map,
+                                std::vector<OpReqType>* grad_req_types,
+                                const std::unordered_set<std::string>& param_names,
+                                std::vector<NDArray>* in_args,
+                                std::vector<NDArray>* arg_grads,
+                                std::vector<NDArray>* aux_states,
+                                std::unordered_map<std::string, NDArray>*
+                                shared_data_arrays = nullptr,
+                                Executor* shared_exec = nullptr);
+
+  virtual void Init(nnvm::Symbol symbol,
+                    const Context& default_ctx,
+                    const std::map<std::string, Context>& ctx_map,
+                    std::vector<Context> *in_arg_ctxes,
+                    std::vector<Context> *arg_grad_ctxes,
+                    std::vector<Context> *aux_state_ctxes,
+                    std::unordered_map<std::string, TShape> *arg_shape_map,
+                    std::unordered_map<std::string, int> *arg_dtype_map,
+                    std::unordered_map<std::string, int> *arg_stype_map,
+                    std::vector<OpReqType> *grad_req_types,
+                    const std::unordered_set<std::string>& shared_arg_names,
+                    std::vector<NDArray>* in_arg_vec,
+                    std::vector<NDArray>* arg_grad_vec,
+                    std::vector<NDArray>* aux_state_vec,
+                    std::unordered_map<std::string, NDArray>* shared_buffer = nullptr,
+                    Executor* shared_exec = nullptr,
+                    const nnvm::NodeEntryMap<NDArray>& feed_dict
+                      = nnvm::NodeEntryMap<NDArray>());
+
+  // Returns symbol representing the TRT optimized graph for comparison purposes.
+  nnvm::Symbol GetOptimizedSymbol();
+
+ protected:
+  Graph ReinitGraph(Graph&& g, const Context &default_ctx,
+        const std::map<std::string, Context> &ctx_map,
+        std::vector<Context> *in_arg_ctxes,
+        std::vector<Context> *arg_grad_ctxes,
+        std::vector<Context> *aux_state_ctxes,
+        std::vector<OpReqType> *grad_req_types,
+        std::unordered_map<std::string, TShape> *arg_shape_map,
+        std::unordered_map<std::string, int> *arg_dtype_map,
+        std::unordered_map<std::string, int> *arg_stype_map,
+        std::unordered_map<std::string, NDArray> *params_map);
+
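+  // Initializes in_args, arg_grads and aux_states, reusing memory from shared_exec or
+  // shared_buffer where the inferred storage types allow it.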
+  void InitArguments(const nnvm::IndexedGraph& idx,
+                     const nnvm::ShapeVector& inferred_shapes,
+                     const nnvm::DTypeVector& inferred_dtypes,
+                     const StorageTypeVector& inferred_stypes,
+                     const std::vector<Context>& in_arg_ctxes,
+                     const std::vector<Context>& arg_grad_ctxes,
+                     const std::vector<Context>& aux_state_ctxes,
+                     const std::vector<OpReqType>& grad_req_types,
+                     const std::unordered_set<std::string>& shared_arg_names,
+                     const Executor* shared_exec,
+                     std::unordered_map<std::string, NDArray>* shared_buffer,
+                     std::vector<NDArray>* in_arg_vec,
+                     std::vector<NDArray>* arg_grad_vec,
+                     std::vector<NDArray>* aux_state_vec) override;
+};
+
+}  // namespace exec
+
+}  // namespace mxnet
+
+#endif  // MXNET_USE_TENSORRT
+
+#endif  // MXNET_EXECUTOR_TRT_GRAPH_EXECUTOR_H_
diff --git a/src/operator/contrib/nnvm_to_onnx-inl.h b/src/operator/contrib/nnvm_to_onnx-inl.h
new file mode 100644
index 0000000..58f88b0
--- /dev/null
+++ b/src/operator/contrib/nnvm_to_onnx-inl.h
@@ -0,0 +1,156 @@
+#ifndef MXNET_OPERATOR_CONTRIB_NNVM_TO_ONNX_INL_H_
+#define MXNET_OPERATOR_CONTRIB_NNVM_TO_ONNX_INL_H_
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2018 by Contributors
+ * \file nnvm_to_onnx-inl.h
+ * \brief Conversion from NNVM graphs to ONNX protobufs
+ * \author Marek Kolodziej, Clement Fuji Tsang
+*/
+
+#if MXNET_USE_TENSORRT
+
+#include <dmlc/logging.h>
+#include <dmlc/memory_io.h>
+#include <dmlc/serializer.h>
+#include <dmlc/parameter.h>
+#include <mxnet/base.h>
+#include <mxnet/operator.h>
+#include <nnvm/graph.h>
+#include <nnvm/pass_functions.h>
+
+#include <NvInfer.h>
+#include <onnx/onnx.pb.h>
+
+#include <algorithm>
+#include <iostream>
+#include <map>
+#include <vector>
+#include <tuple>
+#include <unordered_map>
+#include <utility>
+#include <string>
+
+#include "./tensorrt-inl.h"
+#include "../operator_common.h"
+#include "../../common/utils.h"
+#include "../../common/serialization.h"
+
+namespace mxnet {
+namespace op {
+namespace nnvm_to_onnx {
+
+using namespace nnvm;
+using namespace ::onnx;
+using int64 = ::google::protobuf::int64;
+
+std::unordered_map<std::string, TShape> GetPlaceholderShapes(const ShapeVector& shape_inputs,
+    const nnvm::IndexedGraph& ig);
+
+std::unordered_map<std::string, uint32_t> GetOutputLookup(const nnvm::IndexedGraph& ig);
+
+void ConvertPlaceholder(
+  const std::string& node_name,
+  const std::unordered_map<std::string, TShape>& placeholder_shapes,
+  GraphProto* const graph_proto);
+
+void ConvertConstant(GraphProto* const graph_proto,
+  const std::string& node_name,
+  std::unordered_map<std::string, NDArray>* const shared_buffer);
+
+void ConvertOutput(op::tensorrt::InferenceMap_t* const trt_output_map,
+                   GraphProto* const graph_proto,
+                   const std::unordered_map<std::string, uint32_t>::iterator& out_iter,
+                   const std::string& node_name,
+                   const nnvm::Graph& g,
+                   const StorageTypeVector& storage_types,
+                   const DTypeVector& dtypes);
+
+typedef void (*ConverterFunction)(NodeProto *node_proto,
+                                  const NodeAttrs &attrs,
+                                  const nnvm::IndexedGraph &ig,
+                                  const array_view<IndexedGraph::NodeEntry> &inputs);
+
+
+// Forward declarations
+void ConvertConvolution(
+                        NodeProto *node_proto,
+                        const NodeAttrs &attrs,
+                        const nnvm::IndexedGraph &ig,
+                        const array_view<IndexedGraph::NodeEntry> &inputs);
+
+
+void ConvertPooling(NodeProto *node_proto,
+                    const NodeAttrs &attrs,
+                    const nnvm::IndexedGraph &ig,
+                    const array_view<IndexedGraph::NodeEntry> &inputs);
+
+void ConvertActivation(NodeProto *node_proto,
+                       const NodeAttrs &attrs,
+                       const nnvm::IndexedGraph &ig,
+                       const array_view<IndexedGraph::NodeEntry> &inputs);
+
+void ConvertFullyConnected(NodeProto *node_proto,
+                           const NodeAttrs &attrs,
+                           const nnvm::IndexedGraph &ig,
+                           const array_view<IndexedGraph::NodeEntry> &inputs);
+
+void ConvertSoftmaxOutput(NodeProto *node_proto,
+                          const NodeAttrs &attrs,
+                          const nnvm::IndexedGraph &ig,
+                          const array_view<IndexedGraph::NodeEntry> &inputs);
+
+void ConvertFlatten(NodeProto *node_proto,
+                    const NodeAttrs &attrs,
+                    const nnvm::IndexedGraph &ig,
+                    const array_view<IndexedGraph::NodeEntry> &inputs);
+
+void ConvertBatchNorm(NodeProto *node_proto,
+                    const NodeAttrs &attrs,
+                    const nnvm::IndexedGraph &ig,
+                    const array_view<IndexedGraph::NodeEntry> &inputs);
+
+void ConvertElementwiseAdd(NodeProto *node_proto,
+                    const NodeAttrs &attrs,
+                    const nnvm::IndexedGraph &ig,
+                    const array_view<IndexedGraph::NodeEntry> &inputs);
+
+TRTParam ConvertNnvmGraphToOnnx(
+    const nnvm::Graph &g,
+    std::unordered_map<std::string, NDArray> *const shared_buffer);
+
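+// Maps each supported MXNet op name to its NNVM-to-ONNX converter; ops missing from this
+// table cause ConvertNnvmGraphToOnnx to fail with an error.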
+static const std::unordered_map<std::string, ConverterFunction> converter_map = {
+  {"Convolution", ConvertConvolution},
+  {"Pooling", ConvertPooling},
+  {"Activation", ConvertActivation},
+  {"FullyConnected", ConvertFullyConnected},
+  {"SoftmaxOutput", ConvertSoftmaxOutput},
+  {"Flatten", ConvertFlatten},
+  {"BatchNorm", ConvertBatchNorm},
+  {"elemwise_add", ConvertElementwiseAdd}};
+
+}  // namespace nnvm_to_onnx
+}  // namespace op
+}  // namespace mxnet
+
+#endif  // MXNET_USE_TENSORRT
+
+#endif  // MXNET_OPERATOR_CONTRIB_NNVM_TO_ONNX_INL_H_
diff --git a/src/operator/contrib/nnvm_to_onnx.cc b/src/operator/contrib/nnvm_to_onnx.cc
new file mode 100644
index 0000000..9024666
--- /dev/null
+++ b/src/operator/contrib/nnvm_to_onnx.cc
@@ -0,0 +1,527 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2018 by Contributors
+ * \file nnvm_to_onnx.cc
+ * \brief Conversion from NNVM graphs to ONNX protobufs
+ * \author Marek Kolodziej, Clement Fuji Tsang
+*/
+
+#if MXNET_USE_TENSORRT
+
+#include "./nnvm_to_onnx-inl.h"
+
+#include <mxnet/base.h>
+#include <nnvm/graph.h>
+#include <nnvm/pass_functions.h>
+
+#include <algorithm>
+#include <fstream>
+#include <iostream>
+#include <unordered_map>
+#include <vector>
+
+#include "../../common/serialization.h"
+#include "../../common/utils.h"
+#include "../../ndarray/ndarray_function.h"
+#include "../../operator/nn/activation-inl.h"
+#include "../../operator/nn/batch_norm-inl.h"
+#include "../../operator/nn/convolution-inl.h"
+#include "../../operator/nn/fully_connected-inl.h"
+#include "../../operator/nn/pooling-inl.h"
+#include "../../operator/softmax_output-inl.h"
+#include "./tensorrt-inl.h"
+
+#if MXNET_USE_TENSORRT_ONNX_CHECKER
+#include <onnx/checker.h>
+#endif  // MXNET_USE_TENSORRT_ONNX_CHECKER
+
+namespace mxnet {
+namespace op {
+namespace nnvm_to_onnx {
+
+op::TRTParam ConvertNnvmGraphToOnnx(
+    const nnvm::Graph& g,
+    std::unordered_map<std::string, NDArray>* const shared_buffer) {
+    op::TRTParam trt_param;
+    op::tensorrt::NameToIdx_t trt_input_map;
+    op::tensorrt::InferenceMap_t trt_output_map;
+
+  const nnvm::IndexedGraph& ig = g.indexed_graph();
+  const auto& storage_types = g.GetAttr<StorageTypeVector>("storage_type");
+  const auto& dtypes = g.GetAttr<DTypeVector>("dtype");
+  const auto& shape_inputs = g.GetAttr<ShapeVector>("shape_inputs");
+
+  for (auto& e : dtypes) {
+    if (e != mshadow::kFloat32) {
+      LOG(FATAL) << "ONNX converter does not support dtypes other than float32 "
+                    "right now.";
+    }
+  }
+
+  ModelProto model_proto;
+  // Need to determine IR versions and features to support
+  model_proto.set_ir_version(static_cast<int64>(2));
+  GraphProto* graph_proto = model_proto.mutable_graph();
+
+  std::unordered_map<std::string, TShape> placeholder_shapes =
+      GetPlaceholderShapes(shape_inputs, ig);
+  std::unordered_map<std::string, uint32_t> output_lookup = GetOutputLookup(ig);
+  uint32_t current_input = 0;
+
+  // Can't do a foreach over IndexedGraph since it doesn't implement begin(), etc.
+  for (uint32_t node_idx = 0; node_idx < ig.num_nodes(); ++node_idx) {
+    const IndexedGraph::Node& node = ig[node_idx];
+    const nnvm::Node* source = node.source;
+    const NodeAttrs& attrs = source->attrs;
+    const Op* op = source->op();
+
+    std::string node_name = attrs.name;
+    // Here, "variable" means anything that is not an op, i.e. a constant (weights) or a
+    // placeholder.
+    if (source->is_variable()) {
+      // Is this a placeholder?
+      if (shared_buffer->count(node_name) == 0) {
+        // This fixes the problem with a SoftmaxOutput node during inference, but it's hacky.
+        // Need to figure out how to properly fix it.
+        if (node_name.find("label") != std::string::npos) {
+          current_input++;
+          continue;
+        }
+        trt_input_map.emplace(node_name, current_input++);
+        ConvertPlaceholder(node_name, placeholder_shapes, graph_proto);
+      } else {
+        // If it's not a placeholder, then by exclusion it's a constant.
+        ConvertConstant(graph_proto, node_name, shared_buffer);
+      }  // is_placeholder
+    } else {
+      // It's an op, rather than a "variable" (constant or placeholder)
+      NodeProto* node_proto = graph_proto->add_node();
+      node_proto->set_name(node_name);
+      if (converter_map.count(op->name) == 0) {
+        LOG(FATAL) << "Conversion for node of type " << op->name << " (node "
+                   << node_name << ") is not supported yet.";
+      }
+      // Find the converter function for this op name and invoke it. The find is guaranteed to
+      // succeed because we already checked converter_map.count(op->name) above.
+      converter_map.find(op->name)->second(node_proto, attrs, ig, node.inputs);
+      // Add all inputs to the current node (i.e. add graph edges)
+      for (const nnvm::IndexedGraph::NodeEntry& entry : node.inputs) {
+        std::string in_node_name = ig[entry.node_id].source->attrs.name;
+        // As before, we're not adding labels e.g. for SoftmaxOutput, but I wish there was a less
+        // hacky way to do it than name matching.
+        if (in_node_name.find("label") != std::string::npos) {
+          continue;
+        }
+        node_proto->add_input(in_node_name);
+      }
+      // The node's output will have the same name as the node name.
+      node_proto->add_output(node_name);
+      // See if the current node is an output node
+      auto out_iter = output_lookup.find(node_name);
+      // We found an output
+      if (out_iter != output_lookup.end()) {
+        ConvertOutput(&trt_output_map, graph_proto, out_iter, node_name, g,
+                      storage_types, dtypes);
+      }  // output found
+    }    // conversion function exists
+  }      // loop over i from 0 to num_nodes
+
+  model_proto.SerializeToString(&trt_param.serialized_onnx_graph);
+  common::Serialize<op::tensorrt::NameToIdx_t>(trt_input_map,
+                                          &trt_param.serialized_input_map);
+  common::Serialize<op::tensorrt::InferenceMap_t>(trt_output_map,
+                                             &trt_param.serialized_output_map);
+
+#if MXNET_USE_TENSORRT_ONNX_CHECKER
+  onnx::checker::check_model(model_proto);
+#endif  // MXNET_USE_TENSORRT_ONNX_CHECKER
+
+  return trt_param;
+}
+
+void ConvertConvolution(NodeProto* node_proto, const NodeAttrs& attrs,
+                        const nnvm::IndexedGraph& /*ig*/,
+                        const array_view<IndexedGraph::NodeEntry>& /*inputs*/) {
+  const auto& conv_param = nnvm::get<op::ConvolutionParam>(attrs.parsed);
+
+  node_proto->set_op_type("Conv");
+
+  const TShape kernel = conv_param.kernel;
+  const TShape stride = conv_param.stride;
+  const TShape dilate = conv_param.dilate;
+  const TShape pad = conv_param.pad;
+  const uint32_t num_group = conv_param.num_group;
+  // const bool no_bias = conv_param.no_bias;
+  const dmlc::optional<int> layout = conv_param.layout;
+
+  // kernel shape
+  AttributeProto* const kernel_shape = node_proto->add_attribute();
+  kernel_shape->set_name("kernel_shape");
+  kernel_shape->set_type(AttributeProto::INTS);
+
+  for (const dim_t kval : kernel) {
+    kernel_shape->add_ints(static_cast<int64>(kval));
+  }
+
+  // pads
+  AttributeProto* const pads = node_proto->add_attribute();
+  pads->set_name("pads");
+  pads->set_type(AttributeProto::INTS);
+
+  for (const dim_t kval : pad) {
+    pads->add_ints(static_cast<int64>(kval));
+    pads->add_ints(static_cast<int64>(kval));
+  }
+
+  // dilations
+  AttributeProto* const dilations = node_proto->add_attribute();
+  dilations->set_name("dilations");
+  dilations->set_type(AttributeProto::INTS);
+  for (const dim_t kval : dilate) {
+    dilations->add_ints(static_cast<int64>(kval));
+  }
+
+  // strides
+  AttributeProto* const strides = node_proto->add_attribute();
+  strides->set_name("strides");
+  strides->set_type(AttributeProto::INTS);
+  for (const dim_t kval : stride) {
+    strides->add_ints(static_cast<int64>(kval));
+  }
+
+  // group
+  AttributeProto* const group = node_proto->add_attribute();
+  group->set_name("group");
+  group->set_type(AttributeProto::INT);
+  group->set_i(static_cast<int64>(num_group));
+}  // end ConvertConvolution
+
+void ConvertPooling(NodeProto* node_proto, const NodeAttrs& attrs,
+                    const nnvm::IndexedGraph& /*ig*/,
+                    const array_view<IndexedGraph::NodeEntry>& /*inputs*/) {
+  const auto& pooling_param = nnvm::get<op::PoolingParam>(attrs.parsed);
+
+  const TShape kernel = pooling_param.kernel;
+  const TShape stride = pooling_param.stride;
+  const TShape pad = pooling_param.pad;
+  const int pool_type = pooling_param.pool_type;
+  const bool global_pool = pooling_param.global_pool;
+
+  if (global_pool) {
+    if (pool_type == 0) {
+      node_proto->set_op_type("GlobalMaxPool");
+    } else {
+      node_proto->set_op_type("GlobalAveragePool");
+    }
+    return;
+  }
+
+  // kernel_shape
+  AttributeProto* const kernel_shape = node_proto->add_attribute();
+  kernel_shape->set_name("kernel_shape");
+  kernel_shape->set_type(AttributeProto::INTS);
+  for (int kval : kernel) {
+    kernel_shape->add_ints(static_cast<int64>(kval));
+  }
+
+  // pads
+  AttributeProto* const pads = node_proto->add_attribute();
+  pads->set_name("pads");
+  pads->set_type(AttributeProto::INTS);
+  for (int kval : pad) {
+    pads->add_ints(static_cast<int64>(kval));
+  }
+
+  // strides
+  AttributeProto* const strides = node_proto->add_attribute();
+  strides->set_name("strides");
+  strides->set_type(AttributeProto::INTS);
+  for (int kval : stride) {
+    strides->add_ints(static_cast<int64>(kval));
+  }
+
+  if (pool_type == 0) {
+    node_proto->set_op_type("MaxPool");
+  } else {
+    node_proto->set_op_type("AveragePool");
+  }  // average pooling
+  // not global pooling
+}  // end ConvertPooling
+
+void ConvertActivation(NodeProto* node_proto, const NodeAttrs& attrs,
+                       const nnvm::IndexedGraph& /*ig*/,
+                       const array_view<IndexedGraph::NodeEntry>& /*inputs*/) {
+  const auto& act_param = nnvm::get<op::ActivationParam>(attrs.parsed);
+  std::string act_type;
+  switch (act_param.act_type) {
+    case op::activation::kReLU:
+      act_type = "Relu";
+      break;
+    case op::activation::kSigmoid:
+      act_type = "Sigmoid";
+      break;
+    case op::activation::kTanh:
+      act_type = "Tanh";
+      break;
+    case op::activation::kSoftReLU:
+      // act_type = "SoftReLU";
+      throw dmlc::Error("SoftReLU is not supported in ONNX");
+      break;
+    default:
+      throw dmlc::Error("Activation of such type doesn't exist");
+  }
+
+  node_proto->set_op_type(act_type);
+}
+
+void ConvertFullyConnected(NodeProto* node_proto, const NodeAttrs& attrs,
+                           const nnvm::IndexedGraph& /*ig*/,
+                           const array_view<IndexedGraph::NodeEntry>& /*inputs*/) {
+  const auto& act_param = nnvm::get<op::FullyConnectedParam>(attrs.parsed);
+  if (act_param.no_bias) {
+      node_proto->set_op_type("MatMul");
+  } else {
+      node_proto->set_op_type("Gemm");
+
+      AttributeProto* const alpha = node_proto->add_attribute();
+      alpha->set_name("alpha");
+      alpha->set_type(AttributeProto::FLOAT);
+      alpha->set_f(1.0f);
+
+      AttributeProto* const beta = node_proto->add_attribute();
+      beta->set_name("beta");
+      beta->set_type(AttributeProto::FLOAT);
+      beta->set_f(1.0f);
+
+      AttributeProto* const broadcast = node_proto->add_attribute();
+      broadcast->set_name("broadcast");
+      broadcast->set_type(AttributeProto::INT);
+      broadcast->set_i(1);
+
+      AttributeProto* const transA = node_proto->add_attribute();
+      transA->set_name("transA");
+      transA->set_type(AttributeProto::INT);
+      transA->set_i(0);
+
+      AttributeProto* const transB = node_proto->add_attribute();
+      transB->set_name("transB");
+      transB->set_type(AttributeProto::INT);
+      transB->set_i(1);
+  }
+}
+
+void ConvertSoftmaxOutput(NodeProto* node_proto, const NodeAttrs& /*attrs*/,
+                          const nnvm::IndexedGraph& /*ig*/,
+                          const array_view<IndexedGraph::NodeEntry>& /*inputs*/) {
+  node_proto->set_op_type("Softmax");
+
+  // Setting by default to 1 since MXNet doesn't provide such an attribute for softmax in its
+  // node params. This attribute is only relevant when the input is coerced to 2D, and in that
+  // case dimension 0 is assumed to be the batch dimension.
+  AttributeProto* const axis = node_proto->add_attribute();
+  axis->set_name("axis");
+  axis->set_type(AttributeProto::INT);
+  axis->set_i(1);
+}
+
+void ConvertFlatten(NodeProto* node_proto, const NodeAttrs& /*attrs*/,
+                    const nnvm::IndexedGraph& /*ig*/,
+                    const array_view<IndexedGraph::NodeEntry>& /*inputs*/) {
+  node_proto->set_op_type("Flatten");
+
+  // Setting by default to 1 since MXNet doesn't provide such an attribute for Flatten in its
+  // node params. This attribute is only relevant when the input is coerced to 2D, and in that
+  // case dimension 0 is assumed to be the batch dimension.
+  AttributeProto* const axis = node_proto->add_attribute();
+  axis->set_name("axis");
+  axis->set_type(AttributeProto::INT);
+  axis->set_i(1);
+}
+
+void ConvertBatchNorm(NodeProto* node_proto, const NodeAttrs& attrs,
+                      const nnvm::IndexedGraph& /*ig*/,
+                      const array_view<IndexedGraph::NodeEntry>& /*inputs*/) {
+  node_proto->set_op_type("BatchNormalization");
+  const auto& param = nnvm::get<op::BatchNormParam>(attrs.parsed);
+
+  AttributeProto* const epsilon = node_proto->add_attribute();
+  epsilon->set_name("epsilon");
+  epsilon->set_type(AttributeProto::FLOAT);
+  epsilon->set_f(static_cast<float>(param.eps));
+
+  AttributeProto* const is_test = node_proto->add_attribute();
+  is_test->set_name("is_test");
+  is_test->set_type(AttributeProto::INT);
+  is_test->set_i(1);
+
+  AttributeProto* const momentum = node_proto->add_attribute();
+  momentum->set_name("momentum");
+  momentum->set_type(AttributeProto::FLOAT);
+  momentum->set_f(param.momentum);
+
+  AttributeProto* const spatial = node_proto->add_attribute();
+  spatial->set_name("spatial");
+  spatial->set_type(AttributeProto::INT);
+  spatial->set_i(1);
+
+  AttributeProto* const consumed = node_proto->add_attribute();
+  consumed->set_name("consumed_inputs");
+  consumed->set_type(AttributeProto::INTS);
+
+  for (int i = 0; i < 5; i++) {
+    int val = (i < 3) ? 0 : 1;
+    consumed->add_ints(static_cast<int64>(val));
+  }
+}
+
+void ConvertElementwiseAdd(NodeProto* node_proto, const NodeAttrs& /*attrs*/,
+                           const nnvm::IndexedGraph& /*ig*/,
+                           const array_view<IndexedGraph::NodeEntry>& /*inputs*/) {
+  node_proto->set_op_type("Add");
+  AttributeProto* const axis = node_proto->add_attribute();
+  axis->set_name("axis");
+  axis->set_type(AttributeProto::INT);
+  axis->set_i(1);
+
+  AttributeProto* const broadcast = node_proto->add_attribute();
+  broadcast->set_name("broadcast");
+  broadcast->set_type(AttributeProto::INT);
+  broadcast->set_i(0);
+}
+
+std::unordered_map<std::string, TShape> GetPlaceholderShapes(
+    const ShapeVector& shape_inputs, const nnvm::IndexedGraph& ig) {
+  std::unordered_map<std::string, TShape> placeholder_shapes;
+  for (uint32_t i = 0; i < shape_inputs.size(); ++i) {
+    std::string name = ig[ig.input_nodes()[i]].source->attrs.name;
+    TShape shp = shape_inputs[i];
+    if (shp.ndim() > 0) {
+      placeholder_shapes.emplace(name, shp);
+    }
+  }
+  return placeholder_shapes;
+}
+
+std::unordered_map<std::string, uint32_t> GetOutputLookup(
+    const nnvm::IndexedGraph& ig) {
+  std::unordered_map<std::string, uint32_t> output_lookup;
+  const std::vector<nnvm::IndexedGraph::NodeEntry>& graph_outputs =
+      ig.outputs();
+  for (uint32_t i = 0; i < graph_outputs.size(); ++i) {
+    const uint32_t id = graph_outputs[i].node_id;
+    const IndexedGraph::Node ig_node = ig[id];
+    const nnvm::Node* const source = ig_node.source;
+    const std::string name = source->attrs.name;
+    output_lookup.emplace(name, i);
+  }
+  return output_lookup;
+}
+
+void ConvertPlaceholder(
+    const std::string& node_name,
+    const std::unordered_map<std::string, TShape>& placeholder_shapes,
+    GraphProto* const graph_proto) {
+  auto val_info_proto = graph_proto->add_input();
+  auto type_proto = val_info_proto->mutable_type()->mutable_tensor_type();
+  auto shape_proto = type_proto->mutable_shape();
+
+  val_info_proto->set_name(node_name);
+  // Will support fp16, etc. in the near future
+  type_proto->set_elem_type(TensorProto_DataType_FLOAT);
+  auto entry_shape = placeholder_shapes.find(node_name)->second;
+
+  for (const auto& elem : entry_shape) {
+    TensorShapeProto_Dimension* const tsp_dim = shape_proto->add_dim();
+    tsp_dim->set_dim_value(static_cast<int64>(elem));
+  }
+}
+
+void ConvertConstant(
+    GraphProto* const graph_proto, const std::string& node_name,
+    std::unordered_map<std::string, NDArray>* const shared_buffer) {
+  NodeProto* const node_proto = graph_proto->add_node();
+  node_proto->set_name(node_name);
+  node_proto->add_output(node_name);
+  node_proto->set_op_type("Constant");
+
+  const NDArray nd = shared_buffer->find(node_name)->second;
+  const TBlob& blob = nd.data();
+  const TShape shape = blob.shape_;
+  const int32_t size = shape.Size();
+
+  std::shared_ptr<float> shared_data_ptr(new float[size]);
+  float* const data_ptr = shared_data_ptr.get();
+  nd.SyncCopyToCPU(static_cast<void*>(data_ptr), size);
+
+  AttributeProto* const tensor_attr = node_proto->add_attribute();
+  tensor_attr->set_name("value");
+  tensor_attr->set_type(AttributeProto::TENSOR);
+
+  TensorProto* const tensor_proto = tensor_attr->mutable_t();
+  tensor_proto->set_data_type(TensorProto_DataType_FLOAT);
+  for (auto& dim : shape) {
+    tensor_proto->add_dims(static_cast<int64>(dim));
+  }
+
+  for (int blob_idx = 0; blob_idx < size; ++blob_idx) {
+    tensor_proto->add_float_data(data_ptr[blob_idx]);
+  }
+}
+
+void ConvertOutput(
+    op::tensorrt::InferenceMap_t* const trt_output_map,
+    GraphProto* const graph_proto,
+    const std::unordered_map<std::string, uint32_t>::iterator& out_iter,
+    const std::string& node_name, const nnvm::Graph& g,
+    const StorageTypeVector& storage_types, const DTypeVector& dtypes) {
+  const nnvm::IndexedGraph& ig = g.indexed_graph();
+  uint32_t out_idx = ig.entry_id(ig.outputs()[out_iter->second]);
+  TShape out_shape = g.GetAttr<nnvm::ShapeVector>("shape")[out_idx];
+  int storage_type = storage_types[out_idx];
+  int dtype = dtypes[out_idx];
+
+  // This should work with fp16 as well
+  op::tensorrt::InferenceTuple_t out_tuple{out_iter->second, out_shape, storage_type,
+                                           dtype};
+
+  trt_output_map->emplace(node_name, out_tuple);
+
+  auto graph_out = graph_proto->add_output();
+  auto tensor_type = graph_out->mutable_type()->mutable_tensor_type();
+  auto tensor_shape_proto = tensor_type->mutable_shape();
+  graph_out->set_name(node_name);
+
+  // Also support fp16.
+  tensor_type->set_elem_type(TensorProto_DataType_FLOAT);
+
+  for (int64_t dim_shp : out_shape) {
+    TensorShapeProto_Dimension* const tsp_dim = tensor_shape_proto->add_dim();
+    tsp_dim->set_dim_value(static_cast<int64>(dim_shp));
+  }
+}
+
+}  // namespace nnvm_to_onnx
+}  // namespace op
+}  // namespace mxnet
+
+#endif  // MXNET_USE_TENSORRT
diff --git a/src/operator/contrib/tensorrt-inl.h b/src/operator/contrib/tensorrt-inl.h
new file mode 100644
index 0000000..be335ab
--- /dev/null
+++ b/src/operator/contrib/tensorrt-inl.h
@@ -0,0 +1,113 @@
+#ifndef MXNET_OPERATOR_CONTRIB_TENSORRT_INL_H_
+#define MXNET_OPERATOR_CONTRIB_TENSORRT_INL_H_
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2018 by Contributors
+ * \file tensorrt-inl.h
+ * \brief TensorRT Operator
+ * \author Marek Kolodziej, Clement Fuji Tsang
+*/
+
+#if MXNET_USE_TENSORRT
+
+#include <dmlc/logging.h>
+#include <dmlc/memory_io.h>
+#include <dmlc/serializer.h>
+#include <dmlc/parameter.h>
+#include <mxnet/base.h>
+#include <mxnet/operator.h>
+#include <nnvm/graph.h>
+#include <nnvm/pass_functions.h>
+
+#include <NvInfer.h>
+#include <onnx/onnx.pb.h>
+
+#include <algorithm>
+#include <iostream>
+#include <map>
+#include <vector>
+#include <tuple>
+#include <unordered_map>
+#include <utility>
+#include <string>
+
+#include "../operator_common.h"
+#include "../../common/utils.h"
+#include "../../common/serialization.h"
+#include "../../executor/exec_pass.h"
+#include "../../executor/graph_executor.h"
+#include "../../executor/onnx_to_tensorrt.h"
+
+namespace mxnet {
+namespace op {
+
+using namespace nnvm;
+using namespace ::onnx;
+using int64 = ::google::protobuf::int64;
+
+namespace tensorrt {
+  enum class TypeIO { Inputs = 0, Outputs = 1 };
+  using NameToIdx_t = std::map<std::string, int32_t>;
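+  // InferenceTuple_t fields: (index into the graph outputs, shape, storage type, dtype).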
+  using InferenceTuple_t = std::tuple<uint32_t, TShape, int, int>;
+  using InferenceMap_t = std::map<std::string, InferenceTuple_t>;
+}  // namespace tensorrt
+
+using trt_name_to_idx = std::map<std::string, uint32_t>;
+
+struct TRTParam : public dmlc::Parameter<TRTParam> {
+  std::string serialized_onnx_graph;
+  std::string serialized_input_map;
+  std::string serialized_output_map;
+  tensorrt::NameToIdx_t input_map;
+  tensorrt::InferenceMap_t output_map;
+  ::onnx::ModelProto onnx_pb_graph;
+
+  TRTParam() {}
+
+  TRTParam(const ::onnx::ModelProto& onnx_graph,
+           const tensorrt::InferenceMap_t& input_map,
+           const tensorrt::NameToIdx_t& output_map) {
+    common::Serialize(input_map, &serialized_input_map);
+    common::Serialize(output_map, &serialized_output_map);
+    onnx_graph.SerializeToString(&serialized_onnx_graph);
+  }
+
+  DMLC_DECLARE_PARAMETER(TRTParam) {
+    DMLC_DECLARE_FIELD(serialized_onnx_graph)
+    .describe("Serialized ONNX graph");
+    DMLC_DECLARE_FIELD(serialized_input_map)
+    .describe("Map from inputs to topological order as input.");
+    DMLC_DECLARE_FIELD(serialized_output_map)
+    .describe("Map from outputs to order in g.outputs.");
+  }
+};
+
+struct TRTEngineParam {
+  nvinfer1::IExecutionContext* trt_executor;
+  std::vector<std::pair<uint32_t, tensorrt::TypeIO> > binding_map;
+};
+
+}  // namespace op
+}  // namespace mxnet
+
+#endif  // MXNET_USE_TENSORRT
+
+#endif  // MXNET_OPERATOR_CONTRIB_TENSORRT_INL_H_
diff --git a/src/operator/contrib/tensorrt.cc b/src/operator/contrib/tensorrt.cc
new file mode 100644
index 0000000..619fe1e
--- /dev/null
+++ b/src/operator/contrib/tensorrt.cc
@@ -0,0 +1,183 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2018 by Contributors
+ * \file tensorrt.cc
+ * \brief TensorRT operation registration
+ * \author Marek Kolodziej, Clement Fuji Tsang
+*/
+
+#if MXNET_USE_TENSORRT
+
+#include "./tensorrt-inl.h"
+
+#include <mxnet/base.h>
+#include <nnvm/graph.h>
+#include <nnvm/pass_functions.h>
+
+#include <algorithm>
+#include <fstream>
+#include <iostream>
+#include <unordered_map>
+#include <vector>
+
+#include "../../common/serialization.h"
+#include "../../common/utils.h"
+
+namespace mxnet {
+namespace op {
+
+DMLC_REGISTER_PARAMETER(TRTParam);
+
+OpStatePtr GetPtrMapping(nvinfer1::ICudaEngine* trt_engine,
+                         tensorrt::NameToIdx_t input_map,
+                         tensorrt::NameToIdx_t output_map) {
+  TRTEngineParam param;
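+  // Map every TensorRT engine binding slot to the index of the corresponding MXNet input or
+  // output blob, so TRTCompute can hand device pointers to the engine in binding order.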
+  for (int b = 0; b < trt_engine->getNbBindings(); ++b) {
+    const std::string& binding_name = trt_engine->getBindingName(b);
+    if (trt_engine->bindingIsInput(b)) {
+      param.binding_map.emplace_back(input_map[binding_name],
+                                     tensorrt::TypeIO::Inputs);
+    } else {
+      param.binding_map.emplace_back(output_map[binding_name],
+                                     tensorrt::TypeIO::Outputs);
+    }
+  }
+  param.trt_executor = trt_engine->createExecutionContext();
+  return OpStatePtr::Create<TRTEngineParam>(param);
+}
+
+OpStatePtr TRTCreateState(const nnvm::NodeAttrs& attrs, Context /*ctx*/,
+                          const std::vector<TShape>& /*ishape*/,
+                          const std::vector<int>& /*itype*/) {
+  const auto& node_param = nnvm::get<TRTParam>(attrs.parsed);
+
+  ::onnx::ModelProto model_proto;
+  bool success = model_proto.ParseFromString(node_param.serialized_onnx_graph);
+  if (!success) {
+    LOG(FATAL) << "Problems parsing serialized ONNX model.";
+  }
+  auto graph = model_proto.graph();
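+  // Use dim 0 of the first graph input as the engine's max batch size.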
+  auto first_input_type = graph.input(0).type().tensor_type();
+  auto dim_value = first_input_type.shape().dim(0).dim_value();
+  auto batch_size = static_cast<int32_t>(dim_value);
+  // Need to set up max workspace size based on device properties
+  nvinfer1::ICudaEngine* const trt_engine = ::onnx_to_tensorrt::onnxToTrtCtx(
+      node_param.serialized_onnx_graph, batch_size, 1 << 30);
+
+  tensorrt::NameToIdx_t output_map;
+  for (auto& el : node_param.output_map) {
+    output_map[el.first] = std::get<0>(el.second);
+  }
+  return GetPtrMapping(trt_engine, node_param.input_map, output_map);
+}
+
+void TRTParamParser(nnvm::NodeAttrs* attrs) {
+  TRTParam param_;
+
+  try {
+    param_.Init(attrs->dict);
+    common::Deserialize(&param_.input_map, param_.serialized_input_map);
+    common::Deserialize(&param_.output_map, param_.serialized_output_map);
+    param_.onnx_pb_graph.ParseFromString(param_.serialized_onnx_graph);
+  } catch (const dmlc::ParamError& e) {
+    std::ostringstream os;
+    os << e.what();
+    os << ", in operator " << attrs->op->name << "("
+       << "name=\"" << attrs->name << "\"";
+    for (const auto& k : attrs->dict) {
+      os << ", " << k.first << "=\"" << k.second << "\"";
+    }
+    os << ")";
+    throw dmlc::ParamError(os.str());
+  }
+
+  attrs->parsed = std::move(param_);
+}
+
+inline bool TRTInferShape(const NodeAttrs& attrs, std::vector<TShape>* /*in_shape*/,
+                          std::vector<TShape>* out_shape) {
+  const auto &node_param = nnvm::get<TRTParam>(attrs.parsed);
+  for (auto& el : node_param.output_map) {
+    (*out_shape)[std::get<0>(el.second)] = std::get<1>(el.second);
+  }
+  return true;
+}
+
+inline bool TRTInferStorageType(const NodeAttrs& /*attrs*/, const int /*dev_mask*/,
+                                DispatchMode* dispatch_mode,
+                                std::vector<int>* /*in_storage_type*/,
+                                std::vector<int>* out_storage_type) {
+  return storage_type_assign(out_storage_type, mxnet::kDefaultStorage,
+                             dispatch_mode, DispatchMode::kFCompute);
+}
+
+inline bool TRTInferType(const NodeAttrs& attrs, std::vector<int>* /*in_dtype*/,
+                         std::vector<int>* out_dtype) {
+  const auto& node_param = nnvm::get<TRTParam>(attrs.parsed);
+  for (auto& el : node_param.output_map) {
+    (*out_dtype)[std::get<0>(el.second)] = std::get<3>(el.second);
+  }
+  return true;
+}
+
+inline std::vector<std::string> TRTListInputNames(const NodeAttrs& attrs) {
+  std::vector<std::string> output;
+  const auto& node_param = nnvm::get<TRTParam>(attrs.parsed);
+  output.resize(node_param.input_map.size());
+  for (auto& el : node_param.input_map) {
+    output[el.second] = el.first;
+  }
+  return output;
+}
+
+inline std::vector<std::string> TRTListOutputNames(const NodeAttrs& attrs) {
+  std::vector<std::string> output;
+  const auto& node_param = nnvm::get<TRTParam>(attrs.parsed);
+  output.resize(node_param.output_map.size());
+  for (auto& el : node_param.output_map) {
+    output[std::get<0>(el.second)] = el.first;
+  }
+  return output;
+}
+
+NNVM_REGISTER_OP(_trt_op)
+    .describe(R"code(TRT operation (one engine)
+)code" ADD_FILELINE)
+    .set_num_inputs([](const NodeAttrs& attrs) {
+      const auto& node_param = nnvm::get<TRTParam>(attrs.parsed);
+      return node_param.input_map.size();
+    })
+    .set_num_outputs([](const NodeAttrs& attrs) {
+      const auto& node_param = nnvm::get<TRTParam>(attrs.parsed);
+      return node_param.output_map.size();
+    })
+    .set_attr_parser(TRTParamParser)
+    .set_attr<nnvm::FInferShape>("FInferShape", TRTInferShape)
+    .set_attr<nnvm::FInferType>("FInferType", TRTInferType)
+    .set_attr<nnvm::FListInputNames>("FListInputNames", TRTListInputNames)
+    .set_attr<nnvm::FListOutputNames>("FListOutputNames", TRTListOutputNames)
+    .set_attr<FCreateOpState>("FCreateOpState", TRTCreateState)
+    .set_attr<FInferStorageType>("FInferStorageType", TRTInferStorageType);
+
+}  // namespace op
+}  // namespace mxnet
+
+#endif  // MXNET_USE_TENSORRT
diff --git a/src/operator/contrib/tensorrt.cu b/src/operator/contrib/tensorrt.cu
new file mode 100644
index 0000000..2fe8727
--- /dev/null
+++ b/src/operator/contrib/tensorrt.cu
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2018 by Contributors
+ * \file tensorrt.cu
+ * \brief TensorRT GPU operation
+ * \author Marek Kolodziej, Clement Fuji Tsang
+*/
+
+#if MXNET_USE_TENSORRT
+
+#include "./tensorrt-inl.h"
+
+namespace mxnet {
+namespace op {
+
+#define CHECK_CUDART(x) do { \
+  cudaError_t res = (x); \
+  if (res != cudaSuccess) { \
+    fprintf(stderr, "CUDART: %s = %d (%s) at (%s:%d)\n", \
+      #x, res, cudaGetErrorString(res), __FILE__, __LINE__); \
+    exit(1); \
+  } \
+} while (0)
+
+void TRTCompute(const OpStatePtr& state, const OpContext& ctx,
+                     const std::vector<TBlob>& inputs, const std::vector<OpReqType>& req,
+                     const std::vector<TBlob>& outputs) {
+  using namespace mshadow;
+  using namespace mshadow::expr;
+
+  Stream<gpu>* s = ctx.get_stream<gpu>();
+  cudaStream_t cuda_s = Stream<gpu>::GetStream(s);
+  const auto& param = state.get_state<TRTEngineParam>();
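+  // Gather raw device pointers in TensorRT binding order; each binding refers to either an
+  // MXNet input or an output blob.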
+  std::vector<void*> bindings;
+  bindings.reserve(param.binding_map.size());
+  for (auto& p : param.binding_map) {
+    if (p.second == tensorrt::TypeIO::Inputs) {
+      bindings.emplace_back(inputs[p.first].dptr_);
+    } else {
+      bindings.emplace_back(outputs[p.first].dptr_);
+    }
+  }
+
+  const int batch_size = static_cast<int>(inputs[0].shape_[0]);
+  param.trt_executor->enqueue(batch_size, bindings.data(), cuda_s, nullptr);
+  CHECK_CUDART(cudaStreamSynchronize(cuda_s));
+}
+
+NNVM_REGISTER_OP(_trt_op)
+.set_attr<FStatefulCompute>("FStatefulCompute<gpu>", TRTCompute);
+
+}  // namespace op
+}  // namespace mxnet
+
+#endif  // MXNET_USE_TENSORRT
diff --git a/tests/.gitignore b/tests/.gitignore
index d645908..3e5eed6 100644
--- a/tests/.gitignore
+++ b/tests/.gitignore
@@ -1 +1,2 @@
 *_unittest
+*.gz
diff --git a/tests/cpp/misc/serialization.cc b/tests/cpp/misc/serialization.cc
new file mode 100644
index 0000000..96f8b6c
--- /dev/null
+++ b/tests/cpp/misc/serialization.cc
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <gtest/gtest.h>
+#include "../../../src/common/serialization.h"
+
+using namespace mxnet;
+using namespace std;
+
+/*
+ * Test that the data structures used are properly serialized and deserialized.
+ */
+
+TEST(SerializerTest, InputMapCorrect) {
+    std::map<std::string, int32_t> input_map;
+    input_map.emplace("input_0", 2);
+    input_map.emplace("another_input", 0);
+    input_map.emplace("last_input", 1);
+    std::string serialized_data;
+    common::Serialize(input_map, &serialized_data);
+    std::map<std::string, int32_t> deserialized_input_map;
+    common::Deserialize(&deserialized_input_map, serialized_data);
+    ASSERT_EQ(input_map.size(), deserialized_input_map.size());
+    for (auto& p : input_map) {
+        auto it = deserialized_input_map.find(p.first);
+        ASSERT_NE(it, deserialized_input_map.end());
+        ASSERT_EQ(it->second, p.second);
+    }
+}
+
+TEST(SerializerTest, OutputMapCorrect) {
+    std::map<std::string, std::tuple<uint32_t, TShape, int, int> > output_map;
+    output_map.emplace("output_0", std::make_tuple(1, TShape({23, 12, 63, 432}), 0, 1));
+    output_map.emplace("another_output", std::make_tuple(2, TShape({23, 123}), 14, -23));
+    output_map.emplace("last_output", std::make_tuple(0, TShape({0}), -1, 0));
+    std::string serialized_data;
+    common::Serialize(output_map, &serialized_data);
+    std::map<std::string, std::tuple<uint32_t, TShape, int, int> > deserialized_output_map;
+    common::Deserialize(&deserialized_output_map, serialized_data);
+    ASSERT_EQ(output_map.size(), deserialized_output_map.size());
+    for (auto& p : output_map) {
+        auto it = deserialized_output_map.find(p.first);
+        ASSERT_NE(it, deserialized_output_map.end());
+        auto lhs = it->second;
+        auto rhs = p.second;
+        ASSERT_EQ(std::get<0>(lhs), std::get<0>(rhs));
+        ASSERT_EQ(std::get<1>(lhs), std::get<1>(rhs));
+        ASSERT_EQ(std::get<2>(lhs), std::get<2>(rhs));
+        ASSERT_EQ(std::get<3>(lhs), std::get<3>(rhs));
+    }
+}
+
diff --git a/python/mxnet/contrib/__init__.py b/tests/python/tensorrt/common.py
similarity index 56%
copy from python/mxnet/contrib/__init__.py
copy to tests/python/tensorrt/common.py
index fbfd346..eb599f6 100644
--- a/python/mxnet/contrib/__init__.py
+++ b/tests/python/tensorrt/common.py
@@ -15,20 +15,25 @@
 # specific language governing permissions and limitations
 # under the License.
 
-# coding: utf-8
-"""Experimental contributions"""
+import os
+from ctypes.util import find_library
 
-from . import symbol
-from . import ndarray
 
-from . import symbol as sym
-from . import ndarray as nd
+def check_tensorrt_installation():
+    assert find_library('nvinfer') is not None, "Can't find the TensorRT shared library"
 
-from . import autograd
-from . import tensorboard
 
-from . import text
-from . import onnx
-from . import io
-from . import quantization
-from . import quantization as quant
+def merge_dicts(*dict_args):
+    """Merge arg_params and aux_params to populate shared_buffer"""
+    result = {}
+    for dictionary in dict_args:
+        result.update(dictionary)
+    return result
+
+
+def get_fp16_infer_for_fp16_graph():
+    """Return 1 if MXNET_TENSORRT_USE_FP16_FOR_FP32 is set (FP16 TensorRT inference for FP32 graphs)."""
+    return int(os.environ.get("MXNET_TENSORRT_USE_FP16_FOR_FP32", 0))
+
+
+def set_fp16_infer_for_fp16_graph(status=False):
+    """Toggle the MXNET_TENSORRT_USE_FP16_FOR_FP32 environment variable."""
+    os.environ["MXNET_TENSORRT_USE_FP16_FOR_FP32"] = str(int(status))
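
A short usage sketch for the helpers above (assumes it is run from
tests/python/tensorrt so that common.py is importable; the parameter values are
placeholders rather than real NDArrays):

    from common import (merge_dicts, set_fp16_infer_for_fp16_graph,
                        get_fp16_infer_for_fp16_graph)

    # merge_dicts builds the single parameter dict that the TensorRT bind call
    # expects from separate arg_params and aux_params.
    arg_params = {'fc_weight': 1, 'fc_bias': 2}
    aux_params = {'bn_moving_mean': 3}
    all_params = merge_dicts(arg_params, aux_params)
    assert set(all_params) == {'fc_weight', 'fc_bias', 'bn_moving_mean'}

    # The FP16 switch is just an environment variable; reading it back confirms it.
    set_fp16_infer_for_fp16_graph(True)
    assert get_fp16_infer_for_fp16_graph() == 1
    set_fp16_infer_for_fp16_graph(False)
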
diff --git a/python/mxnet/contrib/__init__.py b/tests/python/tensorrt/lenet5_common.py
similarity index 56%
copy from python/mxnet/contrib/__init__.py
copy to tests/python/tensorrt/lenet5_common.py
index fbfd346..347d6f3 100644
--- a/python/mxnet/contrib/__init__.py
+++ b/tests/python/tensorrt/lenet5_common.py
@@ -15,20 +15,17 @@
 # specific language governing permissions and limitations
 # under the License.
 
-# coding: utf-8
-"""Experimental contributions"""
+import numpy as np
+import mxnet as mx
+from common import *
 
-from . import symbol
-from . import ndarray
-
-from . import symbol as sym
-from . import ndarray as nd
-
-from . import autograd
-from . import tensorboard
-
-from . import text
-from . import onnx
-from . import io
-from . import quantization
-from . import quantization as quant
+def get_iters(mnist, batch_size):
+    """Get MNIST iterators."""
+    train_iter = mx.io.NDArrayIter(mnist['train_data'],
+                                   mnist['train_label'],
+                                   batch_size,
+                                   shuffle=True)
+    val_iter = mx.io.NDArrayIter(mnist['test_data'], mnist['test_label'], batch_size)
+    test_iter = mx.io.NDArrayIter(mnist['test_data'], mnist['test_label'], batch_size)
+    all_test_labels = np.array(mnist['test_label'])
+    return train_iter, val_iter, test_iter, all_test_labels
diff --git a/tests/python/tensorrt/lenet5_train.py b/tests/python/tensorrt/lenet5_train.py
new file mode 100644
index 0000000..8edd9ab
--- /dev/null
+++ b/tests/python/tensorrt/lenet5_train.py
@@ -0,0 +1,84 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import os
+import mxnet as mx
+from lenet5_common import get_iters
+
+
+def lenet5():
+    """LeNet-5 Symbol"""
+    #pylint: disable=no-member
+    data = mx.sym.Variable('data')
+    conv1 = mx.sym.Convolution(data=data, kernel=(5, 5), num_filter=20)
+    tanh1 = mx.sym.Activation(data=conv1, act_type="tanh")
+    pool1 = mx.sym.Pooling(data=tanh1, pool_type="max",
+                           kernel=(2, 2), stride=(2, 2))
+    # second conv
+    conv2 = mx.sym.Convolution(data=pool1, kernel=(5, 5), num_filter=50)
+    tanh2 = mx.sym.Activation(data=conv2, act_type="tanh")
+    pool2 = mx.sym.Pooling(data=tanh2, pool_type="max",
+                           kernel=(2, 2), stride=(2, 2))
+    # first fullc
+    flatten = mx.sym.Flatten(data=pool2)
+    fc1 = mx.sym.FullyConnected(data=flatten, num_hidden=500)
+    tanh3 = mx.sym.Activation(data=fc1, act_type="tanh")
+    # second fullc
+    fc2 = mx.sym.FullyConnected(data=tanh3, num_hidden=10)
+    # loss
+    lenet = mx.sym.SoftmaxOutput(data=fc2, name='softmax')
+    #pylint: enable=no-member
+    return lenet
+
+
+def train_lenet5(num_epochs, batch_size, train_iter, val_iter, test_iter):
+    """train LeNet-5 model on MNIST data"""
+    ctx = mx.gpu(0)
+    lenet_model = mx.mod.Module(lenet5(), context=ctx)
+
+    lenet_model.fit(train_iter,
+                    eval_data=val_iter,
+                    optimizer='sgd',
+                    optimizer_params={'learning_rate': 0.1, 'momentum': 0.9},
+                    eval_metric='acc',
+                    batch_end_callback=mx.callback.Speedometer(batch_size, 1),
+                    num_epoch=num_epochs)
+
+    # predict accuracy for lenet
+    acc = mx.metric.Accuracy()
+    lenet_model.score(test_iter, acc)
+    accuracy = acc.get()[1]
+    assert accuracy > 0.95, "LeNet-5 training accuracy on MNIST was too low"
+    return lenet_model
+
+
+if __name__ == '__main__':
+    num_epochs = 10
+    batch_size = 128
+    model_name = 'lenet5'
+    model_dir = os.getenv("LENET_MODEL_DIR", "/tmp")
+    model_file = '%s/%s-symbol.json' % (model_dir, model_name)
+    params_file = '%s/%s-%04d.params' % (model_dir, model_name, num_epochs)
+
+    if not (os.path.exists(model_file) and os.path.exists(params_file)):
+        mnist = mx.test_utils.get_mnist()
+
+        train_iter, val_iter, test_iter, _ = get_iters(mnist, batch_size)
+
+        trained_lenet = train_lenet5(num_epochs, batch_size,
+                                     train_iter, val_iter, test_iter)
+        trained_lenet.save_checkpoint(model_name, num_epochs)
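
The checkpoint written here (lenet5-symbol.json / lenet5-0010.params, relative
to the working directory) is what the TensorRT test below loads. A minimal
sketch of re-loading it for plain MXNet inference, mirroring the non-TensorRT
branch of that test:

    import mxnet as mx

    batch_size = 128
    sym, arg_params, aux_params = mx.model.load_checkpoint('lenet5', 10)
    executor = sym.simple_bind(ctx=mx.gpu(0), data=(batch_size, 1, 28, 28),
                               softmax_label=(batch_size,), grad_req='null')
    executor.copy_params_from(arg_params, aux_params)
    executor.forward(is_train=False, data=mx.nd.zeros((batch_size, 1, 28, 28)))
    print(executor.outputs[0].shape)  # (128, 10) softmax outputs
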
diff --git a/tests/python/tensorrt/test_cvnets.py b/tests/python/tensorrt/test_cvnets.py
new file mode 100644
index 0000000..4fdd522
--- /dev/null
+++ b/tests/python/tensorrt/test_cvnets.py
@@ -0,0 +1,179 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import gc
+import gluoncv
+import mxnet as mx
+import numpy as np
+
+from mxnet import gluon
+from time import time
+
+from mxnet.gluon.data.vision import transforms
+
+
+def get_classif_model(model_name, use_tensorrt, ctx=mx.gpu(0), batch_size=128):
+    mx.contrib.tensorrt.set_use_tensorrt(use_tensorrt)
+    h, w = 32, 32
+    net = gluoncv.model_zoo.get_model(model_name, pretrained=True)
+    data = mx.sym.var('data')
+
+    if use_tensorrt:
+        out = net(data)
+        softmax = mx.sym.SoftmaxOutput(out, name='softmax')
+        all_params = dict([(k, v.data()) for k, v in net.collect_params().items()])
+        executor = mx.contrib.tensorrt.tensorrt_bind(softmax, ctx=ctx, all_params=all_params,
+                                                     data=(batch_size, 3, h, w),
+                                                     softmax_label=(batch_size,), grad_req='null',
+                                                     force_rebind=True)
+    else:
+        # Convert gluon model to Symbolic
+        net.hybridize()
+        net.forward(mx.ndarray.zeros((batch_size, 3, h, w)))
+        net.export(model_name)
+        symbol, arg_params, aux_params = mx.model.load_checkpoint(model_name, 0)
+        executor = symbol.simple_bind(ctx=ctx, data=(batch_size, 3, h, w),
+                                      softmax_label=(batch_size,))
+        executor.copy_params_from(arg_params, aux_params)
+    return executor
+
+
+def cifar10_infer(model_name, use_tensorrt, num_workers, ctx=mx.gpu(0), batch_size=128):
+    executor = get_classif_model(model_name, use_tensorrt, ctx, batch_size)
+
+    num_ex = 10000
+    all_preds = np.zeros([num_ex, 10])
+
+    all_label_test = np.zeros(num_ex)
+
+    transform_test = transforms.Compose([
+        transforms.ToTensor(),
+        transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])
+    ])
+
+    data_loader = lambda: gluon.data.DataLoader(
+        gluon.data.vision.CIFAR10(train=False).transform_first(transform_test),
+        batch_size=batch_size, shuffle=False, num_workers=num_workers)
+
+    val_data = data_loader()
+
+    for idx, (data, label) in enumerate(val_data):
+        # Skip last batch if it's undersized.
+        if data.shape[0] < batch_size:
+            continue
+        offset = idx * batch_size
+        all_label_test[offset:offset + batch_size] = label.asnumpy()
+
+        # warm-up, but don't use result
+        executor.forward(is_train=False, data=data)
+        executor.outputs[0].wait_to_read()
+
+    gc.collect()
+    val_data = data_loader()
+    example_ct = 0
+    start = time()
+
+    for idx, (data, label) in enumerate(val_data):
+        # Skip last batch if it's undersized.
+        if data.shape[0] < batch_size:
+            continue
+        executor.forward(is_train=False, data=data)
+        preds = executor.outputs[0].asnumpy()
+        offset = idx * batch_size
+        all_preds[offset:offset + batch_size, :] = preds[:batch_size]
+        example_ct += batch_size
+
+    all_preds = np.argmax(all_preds, axis=1)
+    matches = (all_preds[:example_ct] == all_label_test[:example_ct]).sum()
+    duration = time() - start
+
+    return duration, 100.0 * matches / example_ct
+
+
+def run_experiment_for(model_name, batch_size, num_workers):
+    print("\n===========================================")
+    print("Model: %s" % model_name)
+    print("===========================================")
+    print("*** Running inference using pure MXNet ***\n")
+    mx_duration, mx_pct = cifar10_infer(model_name=model_name, batch_size=batch_size,
+                                        num_workers=num_workers, use_tensorrt=False)
+    print("\nMXNet: time elapsed: %.3fs, accuracy: %.2f%%" % (mx_duration, mx_pct))
+    print("\n*** Running inference using MXNet + TensorRT ***\n")
+    trt_duration, trt_pct = cifar10_infer(model_name=model_name, batch_size=batch_size,
+                                          num_workers=num_workers, use_tensorrt=True)
+    print("TensorRT: time elapsed: %.3fs, accuracy: %.2f%%" % (trt_duration, trt_pct))
+    speedup = mx_duration / trt_duration
+    print("TensorRT speed-up (not counting compilation): %.2fx" % speedup)
+
+    acc_diff = abs(mx_pct - trt_pct)
+    print("Absolute accuracy difference: %f" % acc_diff)
+    return speedup, acc_diff
+
+
+def test_tensorrt_on_cifar_resnets(batch_size=32, tolerance=0.1, num_workers=1):
+    original_trt_value = mx.contrib.tensorrt.get_use_tensorrt()
+    try:
+        models = [
+            'cifar_resnet20_v1',
+            'cifar_resnet56_v1',
+            'cifar_resnet110_v1',
+            'cifar_resnet20_v2',
+            'cifar_resnet56_v2',
+            'cifar_resnet110_v2',
+            'cifar_wideresnet16_10',
+            'cifar_wideresnet28_10',
+            'cifar_wideresnet40_8',
+            'cifar_resnext29_16x64d'
+        ]
+
+        num_models = len(models)
+
+        speedups = np.zeros(num_models, dtype=np.float32)
+        acc_diffs = np.zeros(num_models, dtype=np.float32)
+
+        test_start = time()
+
+        for idx, model in enumerate(models):
+            speedup, acc_diff = run_experiment_for(model, batch_size, num_workers)
+            speedups[idx] = speedup
+            acc_diffs[idx] = acc_diff
+            assert acc_diff < tolerance, "Accuracy difference between MXNet and TensorRT > %.2f%% for model %s" % (
+                tolerance, model)
+
+        print("Perf and correctness checks run on the following models:")
+        print(models)
+        mean_speedup = np.mean(speedups)
+        std_speedup = np.std(speedups)
+        print("\nSpeedups:")
+        print(speedups)
+        print("Speedup range: [%.2f, %.2f]" % (np.min(speedups), np.max(speedups)))
+        print("Mean speedup: %.2f" % mean_speedup)
+        print("St. dev. of speedups: %.2f" % std_speedup)
+        print("\nAcc. differences: %s" % str(acc_diffs))
+
+        test_duration = time() - test_start
+
+        print("Test duration: %.2f seconds" % test_duration)
+    finally:
+        mx.contrib.tensorrt.set_use_tensorrt(original_trt_value)
+
+
+if __name__ == '__main__':
+    import nose
+
+    nose.runmodule()
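
cifar10_infer above deliberately makes one untimed pass over the validation set
before the timed pass, so that one-off costs (engine construction, kernel
autotuning, data pipeline start-up) are not counted in the reported speed-up.
The same pattern in isolation, with run_batch standing in for any executor call
(a generic sketch, not MXNet-specific):

    from time import time

    def benchmark(run_batch, batches, warmup=True):
        """Time run_batch over batches, optionally after one untimed warm-up pass."""
        if warmup:
            for b in batches:
                run_batch(b)  # first pass absorbs one-time setup costs
        start = time()
        for b in batches:
            run_batch(b)
        return time() - start

    # Trivial example workload:
    print(benchmark(lambda b: sum(b), [list(range(1000))] * 10))
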
diff --git a/tests/python/tensorrt/test_cycle.py b/tests/python/tensorrt/test_cycle.py
new file mode 100644
index 0000000..25f515a
--- /dev/null
+++ b/tests/python/tensorrt/test_cycle.py
@@ -0,0 +1,69 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import mxnet as mx
+from common import *
+
+
+def detect_cycle_from(sym, visited, stack):
+    """DFS from sym; returns True if a back-edge (cycle) is reachable."""
+    visited.add(sym.handle.value)
+    stack.add(sym.handle.value)
+    children = sym.get_children()
+    if children is not None:
+        for s in children:
+            if s.handle.value not in visited:
+                if detect_cycle_from(s, visited, stack):
+                    return True
+            elif s.handle.value in stack:
+                return True
+    stack.remove(sym.handle.value)
+    return False
+
+
+def has_no_cycle(sym):
+    visited = set()
+    stack = set()
+    for s in sym.get_internals():
+        if s.handle.value not in visited:
+            if detect_cycle_from(s, visited, stack):
+                return False
+    return True
+
+
+def test_simple_cycle():
+    inp = mx.sym.Variable('input', shape=[1,10])
+    A = mx.sym.FullyConnected(data=inp, num_hidden=10, no_bias=False, name='A')
+    B = mx.sym.FullyConnected(data=A, num_hidden=10, no_bias=False, name='B')
+    D = mx.sym.sin(data=A, name='D')
+    C = mx.sym.elemwise_add(lhs=B, rhs=D, name='C')
+    arg_params = {
+                'I_weight': mx.nd.zeros([10,10]),
+                'I_bias': mx.nd.zeros([10]),
+                'A_weight': mx.nd.zeros([10,10]),
+                'A_bias': mx.nd.zeros([10]),
+                'B_weight': mx.nd.zeros([10,10]),
+                'B_bias': mx.nd.zeros([10]),
+               }
+
+    executor = C.simple_bind(ctx=mx.gpu(0), data=(1,10), softmax_label=(1,),
+                           shared_buffer=arg_params, grad_req='null', force_rebind=True)
+    optimized_graph = mx.contrib.tensorrt.get_optimized_symbol(executor)
+    assert has_no_cycle(optimized_graph), "The graph optimized by TRT contains a cycle"
+
+
+if __name__ == '__main__':
+    import nose
+    nose.runmodule()
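
has_no_cycle above is a standard depth-first search that keeps a global
'visited' set and a per-path recursion 'stack'; a cycle exists exactly when the
DFS reaches a node that is already on the current path. The same idea on a
plain adjacency dict, which may be easier to follow than the symbol-handle
version:

    def has_cycle(graph):
        """graph: dict mapping node -> list of successor nodes."""
        visited, stack = set(), set()

        def dfs(node):
            visited.add(node)
            stack.add(node)
            for nxt in graph.get(node, []):
                if nxt not in visited:
                    if dfs(nxt):
                        return True
                elif nxt in stack:  # back-edge: nxt is on the current DFS path
                    return True
            stack.remove(node)
            return False

        return any(dfs(n) for n in graph if n not in visited)

    assert not has_cycle({'A': ['B', 'D'], 'B': ['C'], 'D': ['C'], 'C': []})  # DAG
    assert has_cycle({'A': ['B'], 'B': ['C'], 'C': ['A']})  # 3-node cycle
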
diff --git a/tests/python/tensorrt/test_tensorrt_lenet5.py b/tests/python/tensorrt/test_tensorrt_lenet5.py
new file mode 100644
index 0000000..2586864
--- /dev/null
+++ b/tests/python/tensorrt/test_tensorrt_lenet5.py
@@ -0,0 +1,108 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import os
+import numpy as np
+import mxnet as mx
+from common import *
+from lenet5_common import get_iters
+
+
+def run_inference(sym, arg_params, aux_params, mnist, all_test_labels, batch_size, use_tensorrt):
+    """Run inference with either MXNet or TensorRT"""
+    mx.contrib.tensorrt.set_use_tensorrt(use_tensorrt)
+
+    data_size = (batch_size,) + mnist['test_data'].shape[1:]
+    if use_tensorrt:
+        all_params = merge_dicts(arg_params, aux_params)
+        executor = mx.contrib.tensorrt.tensorrt_bind(sym, ctx=mx.gpu(0), all_params=all_params,
+                                                     data=data_size,
+                                                     softmax_label=(batch_size,),
+                                                     grad_req='null',
+                                                     force_rebind=True)
+    else:
+        executor = sym.simple_bind(ctx=mx.gpu(0),
+                                   data=data_size,
+                                   softmax_label=(batch_size,),
+                                   grad_req='null',
+                                   force_rebind=True)
+        executor.copy_params_from(arg_params, aux_params)
+
+    # MNIST test set size and number of classes are fixed (10,000 examples, 10 classes).
+    num_ex = 10000
+    all_preds = np.zeros([num_ex, 10])
+    test_iter = mx.io.NDArrayIter(mnist['test_data'], mnist['test_label'], batch_size)
+
+    example_ct = 0
+
+    for idx, dbatch in enumerate(test_iter):
+        executor.arg_dict["data"][:] = dbatch.data[0]
+        executor.forward(is_train=False)
+        offset = idx*batch_size
+        extent = batch_size if num_ex - offset > batch_size else num_ex - offset
+        all_preds[offset:offset+extent, :] = executor.outputs[0].asnumpy()[:extent]
+        example_ct += extent
+
+    all_preds = np.argmax(all_preds, axis=1)
+    matches = (all_preds[:example_ct] == all_test_labels[:example_ct]).sum()
+
+    percentage = 100.0 * matches / example_ct
+
+    return percentage
+
+
+def test_tensorrt_inference():
+    """Run LeNet-5 inference comparison between MXNet and TensorRT."""
+    original_trt_value = mx.contrib.tensorrt.get_use_tensorrt()
+    try:
+        check_tensorrt_installation()
+        mnist = mx.test_utils.get_mnist()
+        num_epochs = 10
+        batch_size = 128
+        model_name = 'lenet5'
+        model_dir = os.getenv("LENET_MODEL_DIR", "/tmp")
+        model_file = '%s/%s-symbol.json' % (model_dir, model_name)
+        params_file = '%s/%s-%04d.params' % (model_dir, model_name, num_epochs)
+
+        _, _, _, all_test_labels = get_iters(mnist, batch_size)
+
+        # Load serialized MXNet model (model-symbol.json + model-epoch.params)
+        sym, arg_params, aux_params = mx.model.load_checkpoint(model_name, num_epochs)
+
+        print("LeNet-5 test")
+        print("Running inference in MXNet")
+        mx_pct = run_inference(sym, arg_params, aux_params, mnist, all_test_labels,
+                               batch_size=batch_size, use_tensorrt=False)
+
+        print("Running inference in MXNet-TensorRT")
+        trt_pct = run_inference(sym, arg_params, aux_params, mnist, all_test_labels,
+                                batch_size=batch_size, use_tensorrt=True)
+
+        print("MXNet accuracy: %f" % mx_pct)
+        print("MXNet-TensorRT accuracy: %f" % trt_pct)
+
+        assert abs(mx_pct - trt_pct) < 1e-2, \
+            """Diff. between MXNet & TensorRT accuracy too high:
+               MXNet = %f, TensorRT = %f""" % (mx_pct, trt_pct)
+    finally:
+        mx.contrib.tensorrt.set_use_tensorrt(original_trt_value)
+
+
+if __name__ == '__main__':
+    import nose
+    nose.runmodule()
diff --git a/tests/python/tensorrt/test_training_warning.py b/tests/python/tensorrt/test_training_warning.py
new file mode 100644
index 0000000..fdac859
--- /dev/null
+++ b/tests/python/tensorrt/test_training_warning.py
@@ -0,0 +1,70 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import gluoncv
+import mxnet as mx
+
+from tests.python.unittest.common import assertRaises
+
+
+def test_training_without_trt():
+    run_resnet(is_train=True, use_tensorrt=False)
+
+
+def test_inference_without_trt():
+    run_resnet(is_train=False, use_tensorrt=False)
+
+
+def test_training_with_trt():
+    assertRaises(RuntimeError, run_resnet, is_train=True, use_tensorrt=True)
+
+
+def test_inference_with_trt():
+    run_resnet(is_train=False, use_tensorrt=True)
+
+
+def run_resnet(is_train, use_tensorrt):
+    original_trt_value = mx.contrib.tensorrt.get_use_tensorrt()
+    try:
+        mx.contrib.tensorrt.set_use_tensorrt(use_tensorrt)
+        ctx = mx.gpu(0)
+        batch_size = 1
+        h = 32
+        w = 32
+        model_name = 'cifar_resnet20_v1'
+        resnet = gluoncv.model_zoo.get_model(model_name, pretrained=True)
+        data = mx.sym.var('data')
+        out = resnet(data)
+        softmax = mx.sym.SoftmaxOutput(out, name='softmax')
+        if is_train:
+            grad_req = 'write'
+        else:
+            grad_req = 'null'
+        if use_tensorrt:
+            all_params = dict([(k, v.data()) for k, v in resnet.collect_params().items()])
+            mx.contrib.tensorrt.tensorrt_bind(softmax, ctx=ctx, all_params=all_params,
+                                              data=(batch_size, 3, h, w), softmax_label=(batch_size,),
+                                              force_rebind=True, grad_req=grad_req)
+        else:
+            softmax.simple_bind(ctx=ctx, data=(batch_size, 3, h, w), softmax_label=(batch_size,),
+                                force_rebind=True, grad_req=grad_req)
+    finally:
+        mx.contrib.tensorrt.set_use_tensorrt(original_trt_value)
+
+
+if __name__ == '__main__':
+    import nose
+    nose.runmodule()