Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/08/10 09:52:02 UTC

[GitHub] KellenSunderland closed pull request #12122: [MXNET-703] Fix relative import in TensorRT

URL: https://github.com/apache/incubator-mxnet/pull/12122
This is a pull request merged from a forked repository. Because GitHub hides the
original diff once a foreign (fork) pull request is merged, it is reproduced below
for the sake of provenance:

diff --git a/.gitmodules b/.gitmodules
index 9aeb1c75498..836d824a6f5 100644
--- a/.gitmodules
+++ b/.gitmodules
@@ -26,3 +26,6 @@
 [submodule "3rdparty/tvm"]
 	path = 3rdparty/tvm
 	url = https://github.com/dmlc/tvm
+[submodule "3rdparty/onnx-tensorrt"]
+	path = 3rdparty/onnx-tensorrt
+	url = https://github.com/onnx/onnx-tensorrt.git
diff --git a/3rdparty/onnx-tensorrt b/3rdparty/onnx-tensorrt
new file mode 160000
index 00000000000..e7be19cff37
--- /dev/null
+++ b/3rdparty/onnx-tensorrt
@@ -0,0 +1 @@
+Subproject commit e7be19cff377a95817503e8525e20de34cdc574a
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 483108a6841..8ff337ed159 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -37,6 +37,7 @@ mxnet_option(ENABLE_CUDA_RTC      "Build with CUDA runtime compilation support"
 mxnet_option(BUILD_CPP_EXAMPLES   "Build cpp examples" ON)
 mxnet_option(INSTALL_EXAMPLES     "Install the example source files." OFF)
 mxnet_option(USE_SIGNAL_HANDLER   "Print stack traces on segfaults." OFF)
+mxnet_option(USE_TENSORRT         "Enable inference optimization with TensorRT." OFF)
 
 message(STATUS "CMAKE_SYSTEM_NAME ${CMAKE_SYSTEM_NAME}")
 if(USE_CUDA AND NOT USE_OLDCMAKECUDA)
@@ -185,6 +186,36 @@ if(USE_VTUNE)
   list(APPEND mxnet_LINKER_LIBS dl)
 endif()
 
+if(USE_TENSORRT)
+  message(STATUS "Using TensorRT")
+  set(ONNX_PATH 3rdparty/onnx-tensorrt/third_party/onnx/build/)
+  set(ONNX_TRT_PATH 3rdparty/onnx-tensorrt/build/)
+
+  include_directories(${ONNX_PATH})
+  include_directories(3rdparty/onnx-tensorrt/)
+  include_directories(3rdparty/)
+  add_definitions(-DMXNET_USE_TENSORRT=1)
+  add_definitions(-DONNX_NAMESPACE=onnx)
+
+  find_package(Protobuf REQUIRED)
+
+  find_library(ONNX_LIBRARY NAMES libonnx.so REQUIRED
+          PATHS ${ONNX_PATH}
+          DOC "Path to onnx library.")
+  find_library(ONNX_PROTO_LIBRARY NAMES libonnx_proto.so REQUIRED
+          PATHS ${ONNX_PATH}
+          DOC "Path to onnx_proto library.")
+  find_library(ONNX_TRT_RUNTIME_LIBRARY NAMES libnvonnxparser_runtime.so REQUIRED
+          PATHS ${ONNX_TRT_PATH}
+          DOC "Path to onnx-tensorrt runtime parser library.")
+  find_library(ONNX_TRT_PARSER_LIBRARY NAMES libnvonnxparser.so REQUIRED
+          PATHS ${ONNX_TRT_PATH}
+          DOC "Path to onnx-tensorrt parser library.")
+
+  list(APPEND mxnet_LINKER_LIBS libnvinfer.so ${ONNX_TRT_PARSER_LIBRARY} ${ONNX_TRT_RUNTIME_LIBRARY}
+          ${ONNX_PROTO_LIBRARY} ${ONNX_LIBRARY} ${PROTOBUF_LIBRARY})
+endif()
+
 if(USE_MKLDNN)
   include(cmake/MklDnn.cmake)
   # CPU architecture (e.g., C5) can't run on another architecture (e.g., g3).
diff --git a/Jenkinsfile b/Jenkinsfile
index 6d21f496426..758e8e870ee 100644
--- a/Jenkinsfile
+++ b/Jenkinsfile
@@ -28,6 +28,7 @@ mx_dist_lib = 'lib/libmxnet.so, lib/libmxnet.a, 3rdparty/dmlc-core/libdmlc.a, 3r
 mx_cmake_lib = 'build/libmxnet.so, build/libmxnet.a, build/3rdparty/dmlc-core/libdmlc.a, build/tests/mxnet_unit_tests, build/3rdparty/openmp/runtime/src/libomp.so'
 mx_cmake_mkldnn_lib = 'build/libmxnet.so, build/libmxnet.a, build/3rdparty/dmlc-core/libdmlc.a, build/tests/mxnet_unit_tests, build/3rdparty/openmp/runtime/src/libomp.so, build/3rdparty/mkldnn/src/libmkldnn.so.0'
 mx_mkldnn_lib = 'lib/libmxnet.so, lib/libmxnet.a, lib/libiomp5.so, lib/libmkldnn.so.0, lib/libmklml_intel.so, 3rdparty/dmlc-core/libdmlc.a, 3rdparty/tvm/nnvm/lib/libnnvm.a'
+mx_tensorrt_lib = 'lib/libmxnet.so, lib/libnvonnxparser_runtime.so.0, lib/libnvonnxparser.so.0, lib/libonnx_proto.so, lib/libonnx.so'
 // timeout in minutes
 max_time = 120
 // assign any caught errors here
@@ -372,6 +373,17 @@ try {
         }
       }
     },
+    'TensorRT': {
+      node(NODE_LINUX_CPU) {
+        ws('workspace/build-tensorrt') {
+          timeout(time: max_time, unit: 'MINUTES') {
+            utils.init_git()
+            utils.docker_run('ubuntu_gpu_tensorrt', 'build_ubuntu_gpu_tensorrt', false)
+            utils.pack_lib('tensorrt', mx_tensorrt_lib)
+          }
+        }
+      }
+    },
     'Build CPU windows':{
       node('mxnetwindows-cpu') {
         timeout(time: max_time, unit: 'MINUTES') {
@@ -740,6 +752,22 @@ try {
         }
       }
     },
+    'Python3: TensorRT GPU': {
+      node(NODE_LINUX_GPU_P3) {
+        ws('workspace/build-tensorrt') {
+          timeout(time: max_time, unit: 'MINUTES') {
+            try {
+              utils.init_git()
+              utils.unpack_lib('tensorrt', mx_tensorrt_lib)
+              utils.docker_run('ubuntu_gpu_tensorrt', 'unittest_ubuntu_tensorrt_gpu', true)
+              utils.publish_test_coverage()
+            } finally {
+              utils.collect_test_results_unix('nosetests_tensorrt.xml', 'nosetests_python3_tensorrt_gpu.xml')
+            }
+          }
+        }
+      }
+    },
     'Scala: CPU': {
       node('mxnetlinux-cpu') {
         ws('workspace/ut-scala-cpu') {
diff --git a/Makefile b/Makefile
index 88f7dd9278c..b794e00f00a 100644
--- a/Makefile
+++ b/Makefile
@@ -91,6 +91,14 @@ else
 endif
 CFLAGS += -I$(TPARTYDIR)/mshadow/ -I$(TPARTYDIR)/dmlc-core/include -fPIC -I$(NNVM_PATH)/include -I$(DLPACK_PATH)/include -I$(TPARTYDIR)/tvm/include -Iinclude $(MSHADOW_CFLAGS)
 LDFLAGS = -pthread $(MSHADOW_LDFLAGS) $(DMLC_LDFLAGS)
+
+
+ifeq ($(USE_TENSORRT), 1)
+	CFLAGS +=  -I$(ROOTDIR) -I$(TPARTYDIR) -DONNX_NAMESPACE=$(ONNX_NAMESPACE) -DMXNET_USE_TENSORRT=1
+	LDFLAGS += -lprotobuf -pthread -lonnx -lonnx_proto -lnvonnxparser -lnvonnxparser_runtime -lnvinfer -lnvinfer_plugin
+endif
+# -L/usr/local/lib
+
 ifeq ($(DEBUG), 1)
 	NVCCFLAGS += -std=c++11 -Xcompiler -D_FORCE_INLINES -g -G -O0 -ccbin $(CXX) $(MSHADOW_NVCCFLAGS)
 else
diff --git a/amalgamation/amalgamation.py b/amalgamation/amalgamation.py
index 52d775b7692..a3c28f7118e 100644
--- a/amalgamation/amalgamation.py
+++ b/amalgamation/amalgamation.py
@@ -23,13 +23,12 @@
 import platform
 
 blacklist = [
-    'Windows.h', 'cublas_v2.h', 'cuda/tensor_gpu-inl.cuh',
-    'cuda_runtime.h', 'cudnn.h', 'cudnn_lrn-inl.h', 'curand.h', 'curand_kernel.h',
-    'glog/logging.h', 'io/azure_filesys.h', 'io/hdfs_filesys.h', 'io/s3_filesys.h',
-    'kvstore_dist.h', 'mach/clock.h', 'mach/mach.h',
-    'malloc.h', 'mkl.h', 'mkl_cblas.h', 'mkl_vsl.h', 'mkl_vsl_functions.h',
-    'nvml.h', 'opencv2/opencv.hpp', 'sys/stat.h', 'sys/types.h', 'cuda.h', 'cuda_fp16.h',
-    'omp.h', 'execinfo.h', 'packet/sse-inl.h', 'emmintrin.h', 'thrust/device_vector.h',
+    'Windows.h', 'cublas_v2.h', 'cuda/tensor_gpu-inl.cuh', 'cuda_runtime.h', 'cudnn.h',
+    'cudnn_lrn-inl.h', 'curand.h', 'curand_kernel.h', 'glog/logging.h', 'io/azure_filesys.h',
+    'io/hdfs_filesys.h', 'io/s3_filesys.h', 'kvstore_dist.h', 'mach/clock.h', 'mach/mach.h',
+    'malloc.h', 'mkl.h', 'mkl_cblas.h', 'mkl_vsl.h', 'mkl_vsl_functions.h', 'NvInfer.h', 'nvml.h',
+    'opencv2/opencv.hpp', 'sys/stat.h', 'sys/types.h', 'cuda.h', 'cuda_fp16.h', 'omp.h',
+    'onnx/onnx.pb.h', 'execinfo.h', 'packet/sse-inl.h', 'emmintrin.h', 'thrust/device_vector.h',
     'cusolverDn.h', 'internal/concurrentqueue_internal_debug.h', 'relacy/relacy_std.hpp',
     'relacy_shims.h', 'ittnotify.h', 'shared_mutex'
     ]
@@ -150,6 +149,7 @@ def expand(x, pending, stage):
                     h not in sysheaders and
                     'mkl' not in h and
                     'nnpack' not in h and
+                    'tensorrt' not in h and
                     not h.endswith('.cuh')): sysheaders.append(h)
             else:
                 expand.treeDepth += 1
diff --git a/ci/docker/Dockerfile.build.ubuntu_gpu_tensorrt b/ci/docker/Dockerfile.build.ubuntu_gpu_tensorrt
new file mode 100755
index 00000000000..255da316041
--- /dev/null
+++ b/ci/docker/Dockerfile.build.ubuntu_gpu_tensorrt
@@ -0,0 +1,41 @@
+# -*- mode: dockerfile -*-
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+# Dockerfile to build and run MXNet with TensorRT on Ubuntu 16.04 for GPU
+
+FROM nvidia/cuda:9.0-cudnn7-devel
+
+WORKDIR /work/deps
+
+COPY install/ubuntu_core.sh /work/
+RUN /work/ubuntu_core.sh
+COPY install/deb_ubuntu_ccache.sh /work/
+RUN /work/deb_ubuntu_ccache.sh
+COPY install/ubuntu_python.sh /work/
+RUN /work/ubuntu_python.sh
+COPY install/tensorrt.sh /work
+RUN /work/tensorrt.sh
+
+ARG USER_ID=0
+COPY install/ubuntu_adduser.sh /work/
+RUN /work/ubuntu_adduser.sh
+
+COPY runtime_functions.sh /work/
+
+WORKDIR /work/mxnet
+ENV LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/usr/local/lib
diff --git a/ci/docker/install/tensorrt.sh b/ci/docker/install/tensorrt.sh
new file mode 100755
index 00000000000..a6258d94f62
--- /dev/null
+++ b/ci/docker/install/tensorrt.sh
@@ -0,0 +1,45 @@
+#!/bin/bash
+
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# Install gluoncv since we're testing Gluon models as well
+pip2 install gluoncv==0.2.0
+pip3 install gluoncv==0.2.0
+
+# Install Protobuf
+# Install protoc 3.5 and build protobuf here (for onnx and onnx-tensorrt)
+pushd .
+cd ..
+apt-get update
+apt-get install -y automake libtool
+git clone --recursive -b 3.5.1.1 https://github.com/google/protobuf.git
+cd protobuf
+./autogen.sh
+./configure
+make -j$(nproc)
+make install
+ldconfig
+popd
+
+# Install TensorRT
+echo "TensorRT build enabled. Installing TensorRT."
+wget -qO tensorrt.deb https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1604/x86_64/nvinfer-runtime-trt-repo-ubuntu1604-4.0.1-ga-cuda9.0_1-1_amd64.deb
+dpkg -i tensorrt.deb
+apt-get update
+apt-get install -y --allow-downgrades libnvinfer-dev
+rm tensorrt.deb
diff --git a/ci/docker/runtime_functions.sh b/ci/docker/runtime_functions.sh
index a0795eb58a5..3e19eaf7004 100755
--- a/ci/docker/runtime_functions.sh
+++ b/ci/docker/runtime_functions.sh
@@ -436,6 +436,60 @@ build_ubuntu_gpu() {
     build_ubuntu_gpu_cuda91_cudnn7
 }
 
+build_ubuntu_gpu_tensorrt() {
+
+    set -ex
+
+    build_ccache_wrappers
+
+    # Build ONNX
+    pushd .
+    echo "Installing ONNX."
+    cd 3rdparty/onnx-tensorrt/third_party/onnx
+    rm -rf build
+    mkdir -p build
+    cd build
+    cmake \
+        -DCMAKE_CXX_FLAGS=-I/usr/include/python${PYVER}\
+        -DBUILD_SHARED_LIBS=ON ..\
+        -G Ninja
+    ninja -v
+    export LIBRARY_PATH=`pwd`:`pwd`/onnx/:$LIBRARY_PATH
+    export CPLUS_INCLUDE_PATH=`pwd`:$CPLUS_INCLUDE_PATH
+    popd
+
+    # Build ONNX-TensorRT
+    pushd .
+    cd 3rdparty/onnx-tensorrt/
+    mkdir -p build
+    cd build
+    cmake ..
+    make -j$(nproc)
+    export LIBRARY_PATH=`pwd`:$LIBRARY_PATH
+    popd
+
+    mkdir -p /work/mxnet/lib/
+    cp 3rdparty/onnx-tensorrt/third_party/onnx/build/*.so /work/mxnet/lib/
+    cp -L 3rdparty/onnx-tensorrt/build/libnvonnxparser_runtime.so.0 /work/mxnet/lib/
+    cp -L 3rdparty/onnx-tensorrt/build/libnvonnxparser.so.0 /work/mxnet/lib/
+
+    rm -rf build
+    make \
+        DEV=1                                               \
+        USE_BLAS=openblas                                   \
+        USE_CUDA=1                                          \
+        USE_CUDA_PATH=/usr/local/cuda                       \
+        USE_CUDNN=1                                         \
+        USE_OPENCV=0                                        \
+        USE_DIST_KVSTORE=0                                  \
+        USE_TENSORRT=1                                      \
+        USE_JEMALLOC=0                                      \
+        USE_GPERFTOOLS=0                                    \
+        ONNX_NAMESPACE=onnx                                 \
+        CUDA_ARCH="-gencode arch=compute_70,code=compute_70"\
+        -j$(nproc)
+}
+
 build_ubuntu_gpu_mkldnn() {
     set -ex
 
@@ -638,6 +692,15 @@ unittest_ubuntu_python3_gpu_nocudnn() {
     nosetests-3.4 $NOSE_COVERAGE_ARGUMENTS --with-xunit --xunit-file nosetests_gpu.xml --verbose tests/python/gpu
 }
 
+unittest_ubuntu_tensorrt_gpu() {
+    set -ex
+    export PYTHONPATH=./python/
+    export MXNET_STORAGE_FALLBACK_LOG_VERBOSE=0
+    export LD_LIBRARY_PATH=/work/mxnet/lib:$LD_LIBRARY_PATH
+    python tests/python/tensorrt/lenet5_train.py
+    nosetests-3.4 $NOSE_COVERAGE_ARGUMENTS --with-xunit --xunit-file nosetests_trt_gpu.xml --verbose tests/python/tensorrt/
+}
+
 # quantization gpu currently only runs on P3 instances
 # need to separate it from unittest_ubuntu_python2_gpu()
 unittest_ubuntu_python2_quantization_gpu() {
@@ -970,3 +1033,5 @@ EOF
     declare -F | cut -d' ' -f3
     echo
 fi
+
+
diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h
index 75147cfd706..58b1b1b4daf 100644
--- a/include/mxnet/c_api.h
+++ b/include/mxnet/c_api.h
@@ -1714,6 +1714,13 @@ MXNET_DLL int MXExecutorReshape(int partial_shaping,
                                 NDArrayHandle** aux_states,
                                 ExecutorHandle shared_exec,
                                 ExecutorHandle *out);
+
+/*!
+ * \brief get optimized graph from graph executor
+ */
+MXNET_DLL int MXExecutorGetOptimizedSymbol(ExecutorHandle handle,
+                                           SymbolHandle *out);
+
 /*!
  * \brief set a call back to notify the completion of operation
  */
diff --git a/include/mxnet/executor.h b/include/mxnet/executor.h
index 842653f8653..0ab04b86a0a 100644
--- a/include/mxnet/executor.h
+++ b/include/mxnet/executor.h
@@ -166,6 +166,7 @@ class Executor {
                               std::unordered_map<std::string, NDArray>*
                                 shared_data_arrays = nullptr,
                               Executor* shared_exec = nullptr);
+
   /*!
    * \brief the prototype of user-defined monitor callback
    */
diff --git a/python/mxnet/base.py b/python/mxnet/base.py
index 4df794bdfe3..5f4ae8bd0ac 100644
--- a/python/mxnet/base.py
+++ b/python/mxnet/base.py
@@ -709,3 +709,19 @@ def write_all_str(module_file, module_all_list):
     module_op_file.close()
     write_all_str(module_internal_file, module_internal_all)
     module_internal_file.close()
+
+def cint(init_val=0):
+    """create a C int with an optional initial value"""
+    return C.c_int(init_val)
+
+def int_addr(x):
+    """given a c_int, return it's address as an int ptr"""
+    x_addr = C.addressof(x)
+    int_p = C.POINTER(C.c_int)
+    x_int_addr = C.cast(x_addr, int_p)
+    return x_int_addr
+
+def checked_call(f, *args):
+    """call a cuda function and check for success"""
+    error_t = f(*args)
+    assert error_t == 0, "Failing cuda call %s returns %s." % (f.__name__, error_t)
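
The three helpers added to python/mxnet/base.py above are thin ctypes conveniences for
making raw CUDA calls from test code and asserting success. A minimal, hedged usage
sketch follows; the library name and the CUDA function chosen are illustrative
assumptions and are not part of this change:

    import ctypes

    from mxnet.base import cint, int_addr, checked_call

    # Load the CUDA runtime (assumed to be on the loader path) and query the number
    # of visible devices; checked_call asserts the call returned cudaSuccess (0).
    cudart = ctypes.CDLL("libcudart.so")
    device_count = cint()
    checked_call(cudart.cudaGetDeviceCount, int_addr(device_count))
    print("visible CUDA devices:", device_count.value)
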
diff --git a/python/mxnet/contrib/__init__.py b/python/mxnet/contrib/__init__.py
index fbfd3469678..606bb0ada54 100644
--- a/python/mxnet/contrib/__init__.py
+++ b/python/mxnet/contrib/__init__.py
@@ -32,3 +32,4 @@
 from . import io
 from . import quantization
 from . import quantization as quant
+from . import tensorrt
diff --git a/python/mxnet/contrib/tensorrt.py b/python/mxnet/contrib/tensorrt.py
new file mode 100644
index 00000000000..4ff39c4b482
--- /dev/null
+++ b/python/mxnet/contrib/tensorrt.py
@@ -0,0 +1,110 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+""" Module to enable the use of TensorRT optimized graphs."""
+
+import ctypes
+import logging
+import os
+
+from .. import symbol as sym
+
+from ..base import _LIB, SymbolHandle, MXNetError
+from ..base import check_call
+
+
+def set_use_tensorrt(status):
+    """
+    Set an environment variable which will enable or disable the use of TensorRT in the backend.
+    Note: this is useful for A/B testing purposes.
+    :param status: Boolean, True if TensorRT optimization should be applied, False for legacy
+    behaviour.
+    """
+    os.environ["MXNET_USE_TENSORRT"] = str(int(status))
+
+
+def get_use_tensorrt():
+    """
+    Get an environment variable which describes if TensorRT is currently enabled in the backend.
+    Note: this is useful for A/B testing purposes.
+    :return: Boolean, True if TensorRT optimization should be applied, False for legacy
+    behaviour.
+    """
+    return bool(int(os.environ.get("MXNET_USE_TENSORRT", 0)) == 1)
+
+
+def get_optimized_symbol(executor):
+    """
+    Take an executor's underlying symbol graph and return its generated optimized version.
+
+    Parameters
+    ----------
+    executor :
+        An executor for which you want to see an optimized symbol. Getting an optimized symbol
+        is useful for comparing and verifying the work TensorRT has done against the legacy behaviour.
+
+    Returns
+    -------
+    symbol : nnvm::Symbol
+        The optimized nnvm symbol.
+    """
+    handle = SymbolHandle()
+    try:
+        check_call(_LIB.MXExecutorGetOptimizedSymbol(executor.handle, ctypes.byref(handle)))
+        result = sym.Symbol(handle=handle)
+        return result
+    except MXNetError:
+        logging.error('Error while trying to fetch TRT optimized symbol for graph. Please ensure '
+                      'build was compiled with MXNET_USE_TENSORRT enabled.')
+        raise
+
+
+def tensorrt_bind(symbol, ctx, all_params, type_dict=None, stype_dict=None, group2ctx=None,
+                  **kwargs):
+    """Bind current symbol to get an optimized trt executor.
+
+    Parameters
+    ----------
+    symbol : Symbol
+        The symbol you wish to bind, and optimize with TensorRT.
+
+    ctx : Context
+        The device context on which the generated executor will run.
+
+    all_params : Dict of str->ndarray
+        A dictionary of mappings from parameter names to parameter NDArrays.
+
+    type_dict  : Dict of str->numpy.dtype
+        Input type dictionary, name->dtype
+
+    stype_dict  : Dict of str->str
+        Input storage type dictionary, name->storage_type
+
+    group2ctx : Dict of string to mx.Context
+        The dict mapping the `ctx_group` attribute to the context assignment.
+
+    kwargs : Dict of str->shape
+        Input shape dictionary, name->shape
+
+    Returns
+    -------
+    executor : mxnet.Executor
+        An optimized TensorRT executor.
+    """
+    kwargs['shared_buffer'] = all_params
+    return symbol.simple_bind(ctx, type_dict=type_dict, stype_dict=stype_dict,
+                              group2ctx=group2ctx, **kwargs)
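
The module above exposes a small public surface: set_use_tensorrt/get_use_tensorrt toggle
the MXNET_USE_TENSORRT environment variable, tensorrt_bind wraps simple_bind with the
parameters passed as a shared buffer, and get_optimized_symbol retrieves the
TensorRT-partitioned graph from the resulting executor. A minimal, hedged usage sketch on
GPU 0 follows; the toy network, shapes, and random parameter values are illustrative
assumptions and are not part of the change:

    import mxnet as mx

    batch_shape = (1, 3, 28, 28)
    data = mx.sym.Variable('data')
    conv = mx.sym.Convolution(data, kernel=(3, 3), num_filter=8, name='conv1')
    relu = mx.sym.Activation(conv, act_type='relu', name='relu1')
    net = mx.sym.FullyConnected(mx.sym.flatten(relu), num_hidden=10, name='fc1')

    # Weights are folded into the TensorRT engine at bind time; random values here.
    all_params = {
        'conv1_weight': mx.nd.random.uniform(shape=(8, 3, 3, 3), ctx=mx.gpu(0)),
        'conv1_bias': mx.nd.zeros((8,), ctx=mx.gpu(0)),
        'fc1_weight': mx.nd.random.uniform(shape=(10, 8 * 26 * 26), ctx=mx.gpu(0)),
        'fc1_bias': mx.nd.zeros((10,), ctx=mx.gpu(0)),
    }

    mx.contrib.tensorrt.set_use_tensorrt(True)        # route simple_bind through the TRT executor
    executor = mx.contrib.tensorrt.tensorrt_bind(net, ctx=mx.gpu(0), all_params=all_params,
                                                 data=batch_shape, grad_req='null')
    out = executor.forward(is_train=False, data=mx.nd.zeros(batch_shape, ctx=mx.gpu(0)))
    trt_sym = mx.contrib.tensorrt.get_optimized_symbol(executor)   # inspect the partitioned graph

Setting the environment variable through set_use_tensorrt before binding is what makes
MXExecutorSimpleBind (patched later in this diff) return the TrtGraphExecutor instead of
the regular executor.
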
diff --git a/python/mxnet/executor.py b/python/mxnet/executor.py
index c0272c5bb43..fcd5406236e 100644
--- a/python/mxnet/executor.py
+++ b/python/mxnet/executor.py
@@ -73,6 +73,7 @@ def __init__(self, handle, symbol, ctx, grad_req, group2ctx):
         self.aux_arrays = []
         self.outputs = self._get_outputs()
         self._symbol = copy.deepcopy(symbol)
+        self._optimized_symbol = None
         self._arg_dict = None
         self._grad_dict = None
         self._aux_dict = None
diff --git a/python/mxnet/module/executor_group.py b/python/mxnet/module/executor_group.py
index 5d8e95077c4..c4050699bd5 100755
--- a/python/mxnet/module/executor_group.py
+++ b/python/mxnet/module/executor_group.py
@@ -592,8 +592,8 @@ def backward(self, out_grads=None):
                     # pylint: disable=no-member
                     og_my_slice = nd.slice_axis(grad, axis=axis, begin=islice.start,
                                                 end=islice.stop)
-                    # pylint: enable=no-member
                     out_grads_slice.append(og_my_slice.as_in_context(self.contexts[i]))
+                    # pylint: enable=no-member
                 else:
                     out_grads_slice.append(grad.copyto(self.contexts[i]))
             exec_.backward(out_grads=out_grads_slice)
diff --git a/src/c_api/c_api_executor.cc b/src/c_api/c_api_executor.cc
index 09bc23934e5..b99350525bf 100644
--- a/src/c_api/c_api_executor.cc
+++ b/src/c_api/c_api_executor.cc
@@ -26,6 +26,10 @@
 #include <mxnet/c_api.h>
 #include <mxnet/executor.h>
 #include "./c_api_common.h"
+#include "../executor/graph_executor.h"
+#if MXNET_USE_TENSORRT
+#include "../executor/trt_graph_executor.h"
+#endif  // MXNET_USE_TENSORRT
 
 int MXExecutorPrint(ExecutorHandle handle, const char **out_str) {
   Executor *exec = static_cast<Executor*>(handle);
@@ -439,13 +443,38 @@ int MXExecutorSimpleBind(SymbolHandle symbol_handle,
   std::vector<NDArray> in_arg_vec;
   std::vector<NDArray> arg_grad_vec;
   std::vector<NDArray> aux_state_vec;
-
-  *out = Executor::SimpleBind(*sym, ctx, ctx_map, in_arg_ctx_vec, arg_grad_ctx_vec,
-                              aux_state_ctx_vec, arg_shape_map, arg_dtype_map, arg_stype_map,
-                              grad_req_type_vec, shared_arg_name_set, &in_arg_vec,
-                              &arg_grad_vec, &aux_state_vec,
-                              use_shared_buffer ? &shared_buffer_map : nullptr,
-                              reinterpret_cast<Executor*>(shared_exec_handle));
+#if MXNET_USE_TENSORRT
+  // If we've built with TensorRT support, we return a TrtGraphExecutor by default.
+  // Users can override this behaviour via env var, which is useful for example for A/B
+  // performance testing.
+  if (dmlc::GetEnv("MXNET_USE_TENSORRT", false)) {
+    *out = exec::TrtGraphExecutor::TensorRTBind(*sym, ctx, ctx_map, &in_arg_ctx_vec,
+                                                &arg_grad_ctx_vec, &aux_state_ctx_vec,
+                                                &arg_shape_map, &arg_dtype_map, &arg_stype_map,
+                                                &grad_req_type_vec, shared_arg_name_set,
+                                                &in_arg_vec, &arg_grad_vec, &aux_state_vec,
+                                                use_shared_buffer ? &shared_buffer_map : nullptr,
+                                                reinterpret_cast<Executor*>(shared_exec_handle));
+  } else {
+    // Checks to see if this env var has been set to true or false by the user.
+    // If the user is using a TensorRT build, but has not enabled TRT at inference time, warn
+    // them and describe further steps.
+    const int unset_indicator =  std::numeric_limits<int>::quiet_NaN();
+    if (dmlc::GetEnv("MXNET_USE_TENSORRT", unset_indicator) == unset_indicator) {
+      LOG(INFO) << "TensorRT not enabled by default.  Please set the MXNET_USE_TENSORRT "
+                   "environment variable to 1 or call mx.contrib.tensorrt.set_use_tensorrt(True) "
+                   "to enable.";
+    }
+#endif  // MXNET_USE_TENSORRT
+    *out = Executor::SimpleBind(*sym, ctx, ctx_map, in_arg_ctx_vec, arg_grad_ctx_vec,
+                                aux_state_ctx_vec, arg_shape_map, arg_dtype_map, arg_stype_map,
+                                grad_req_type_vec, shared_arg_name_set, &in_arg_vec,
+                                &arg_grad_vec, &aux_state_vec,
+                                use_shared_buffer ? &shared_buffer_map : nullptr,
+                                reinterpret_cast<Executor*>(shared_exec_handle));
+#if MXNET_USE_TENSORRT
+  }
+#endif  // MXNET_USE_TENSORRT
 
   // copy ndarray ptrs to ret->handles so that front end
   // can access them
@@ -597,6 +626,25 @@ int MXExecutorReshape(int partial_shaping,
   API_END_HANDLE_ERROR(delete out);
 }
 
+int MXExecutorGetOptimizedSymbol(ExecutorHandle handle,
+                                 SymbolHandle *out) {
+  auto s = new nnvm::Symbol();
+  API_BEGIN();
+
+#if MXNET_USE_TENSORRT
+  auto exec = static_cast<exec::TrtGraphExecutor*>(handle);
+  *s = exec->GetOptimizedSymbol();
+  *out = s;
+#else
+  LOG(FATAL) << "GetOptimizedSymbol may only be used when MXNet is compiled with "
+                "MXNET_USE_TENSORRT enabled.  Please re-compile MXNet with TensorRT support.";
+#endif  // MXNET_USE_TENSORRT
+
+  API_END_HANDLE_ERROR(delete s);
+}
+
+
+
 int MXExecutorSetMonitorCallback(ExecutorHandle handle,
                                  ExecutorMonitorCallback callback,
                                  void* callback_handle) {
diff --git a/src/common/exec_utils.h b/src/common/exec_utils.h
index 816599b955c..fbe544221a3 100644
--- a/src/common/exec_utils.h
+++ b/src/common/exec_utils.h
@@ -24,10 +24,14 @@
 #ifndef MXNET_COMMON_EXEC_UTILS_H_
 #define MXNET_COMMON_EXEC_UTILS_H_
 
+#include <nnvm/graph.h>
+#include <nnvm/pass_functions.h>
+#include <map>
 #include <vector>
 #include <string>
 #include <utility>
 #include "../common/utils.h"
+#include "../executor/exec_pass.h"
 
 namespace mxnet {
 namespace common {
@@ -366,6 +370,257 @@ inline void LogInferStorage(const nnvm::Graph& g) {
   }
 }
 
+// prints a helpful message after shape inference errors in executor.
+inline void HandleInferShapeError(const size_t num_forward_inputs,
+                                  const nnvm::IndexedGraph& idx,
+                                  const nnvm::ShapeVector& inferred_shapes) {
+  int cnt = 10;
+  std::ostringstream oss;
+  for (size_t i = 0; i < num_forward_inputs; ++i) {
+    const uint32_t nid = idx.input_nodes().at(i);
+    const uint32_t eid = idx.entry_id(nid, 0);
+    const TShape& inferred_shape = inferred_shapes[eid];
+    if (inferred_shape.ndim() == 0 || inferred_shape.Size() == 0U) {
+      const std::string& arg_name = idx[nid].source->attrs.name;
+      oss << arg_name << ": " << inferred_shape << ", ";
+      if (--cnt == 0) {
+        oss << "...";
+        break;
+      }
+    }
+  }
+  LOG(FATAL) << "InferShape pass cannot decide shapes for the following arguments "
+                "(0s means unknown dimensions). Please consider providing them as inputs:\n"
+             << oss.str();
+}
+
+// prints a helpful message after type inference errors in executor.
+inline void HandleInferTypeError(const size_t num_forward_inputs,
+                                 const nnvm::IndexedGraph& idx,
+                                 const nnvm::DTypeVector& inferred_dtypes) {
+  int cnt = 10;
+  std::ostringstream oss;
+  for (size_t i = 0; i < num_forward_inputs; ++i) {
+    const uint32_t nid = idx.input_nodes().at(i);
+    const uint32_t eid = idx.entry_id(nid, 0);
+    const int inferred_dtype = inferred_dtypes[eid];
+    if (inferred_dtype == -1) {
+      const std::string& arg_name = idx[nid].source->attrs.name;
+      oss << arg_name << ": " << inferred_dtype << ", ";
+      if (--cnt == 0) {
+        oss << "...";
+        break;
+      }
+    }
+  }
+  LOG(FATAL) << "InferType pass cannot decide dtypes for the following arguments "
+                "(-1 means unknown dtype). Please consider providing them as inputs:\n"
+             << oss.str();
+}
+
+// prints a helpful message after storage type checking errors in executor.
+inline void HandleInferStorageTypeError(const size_t num_forward_inputs,
+                                        const nnvm::IndexedGraph& idx,
+                                        const StorageTypeVector& inferred_stypes) {
+  int cnt = 10;
+  std::ostringstream oss;
+  for (size_t i = 0; i < num_forward_inputs; ++i) {
+    const uint32_t nid = idx.input_nodes().at(i);
+    const uint32_t eid = idx.entry_id(nid, 0);
+    const int inferred_stype = inferred_stypes[eid];
+    if (inferred_stype == -1) {
+      const std::string& arg_name = idx[nid].source->attrs.name;
+      oss << arg_name << ": " << common::stype_string(inferred_stype) << ", ";
+      if (--cnt == 0) {
+        oss << "...";
+        break;
+      }
+    }
+  }
+  LOG(FATAL) << "InferStorageType pass cannot decide storage type for the following arguments "
+                "(-1 means unknown stype). Please consider providing them as inputs:\n"
+             << oss.str();
+}
+
+/*!
+ * \brief If the requested ndarray's shape size is less than
+ * the corresponding shared_data_array's shape size and the
+ * storage type is shareable, reuse the memory allocation
+ * in shared_buffer; otherwise, create a zero ndarray.
+ * Shareable storages include both default storage and row_sparse storage
+ * if enable_row_sparse_sharing is `True`, otherwise default storage only.
+ */
+inline NDArray ReshapeOrCreate(const std::string& name,
+                               const TShape& dest_arg_shape,
+                               const int dest_arg_dtype,
+                               const NDArrayStorageType dest_arg_stype,
+                               const Context& ctx,
+                               std::unordered_map<std::string, NDArray>* shared_buffer,
+                               bool enable_row_sparse_sharing) {
+  bool stype_shareable = dest_arg_stype == kDefaultStorage;
+  if (enable_row_sparse_sharing) {
+    stype_shareable = stype_shareable || dest_arg_stype == kRowSparseStorage;
+  }
+  auto it = shared_buffer->find(name);
+  if (it != shared_buffer->end()) {
+    // check if size is large enough for sharing
+    bool size_shareable = it->second.shape().Size() >= dest_arg_shape.Size();
+    if (size_shareable && stype_shareable) {  // memory can be reused
+      CHECK_EQ(it->second.dtype(), dest_arg_dtype)
+          << "Requested arg array's dtype does not match that of the reusable ndarray";
+      CHECK_EQ(it->second.storage_type(), dest_arg_stype)
+          << "Requested arg array's stype does not match that of the reusable ndarray";
+      return it->second.Reshape(dest_arg_shape);
+    } else if (stype_shareable) {
+      LOG(WARNING) << "Bucketing: data " << name << " has a shape " << dest_arg_shape
+                   << ", which is larger than already allocated shape " << it->second.shape()
+                   << ". Need to re-allocate. Consider putting default bucket key to be "
+                   << "the bucket taking the largest input for better memory sharing.";
+      // size is not large enough, creating a larger one for sharing
+      // the NDArrays in shared_buffer are guaranteed to be of shareable storages
+      it->second = InitZeros(dest_arg_stype, dest_arg_shape, ctx, dest_arg_dtype);
+      return it->second;
+    } else {
+      // not shareable storage
+      return InitZeros(dest_arg_stype, dest_arg_shape, ctx, dest_arg_dtype);
+    }
+  } else {
+    auto ret = InitZeros(dest_arg_stype, dest_arg_shape, ctx, dest_arg_dtype);
+    if (stype_shareable) {
+      shared_buffer->emplace(name, ret);
+    }
+    return ret;
+  }  // if (it != shared_buffer->end())
+}
+
+/*!
+ * \brief Assign context to the graph.
+ * This is triggered by both simple_bind and bind flows.
+ */
+inline nnvm::Graph AssignContext(nnvm::Graph g,
+                                 const Context& default_ctx,
+                                 const std::map<std::string, Context>& ctx_map,
+                                 const std::vector<Context>& in_arg_ctxes,
+                                 const std::vector<Context>& arg_grad_ctxes,
+                                 const std::vector<Context>& aux_state_ctxes,
+                                 const std::vector<OpReqType>& grad_req_types,
+                                 size_t num_forward_inputs,
+                                 size_t num_forward_outputs) {
+  const auto& idx = g.indexed_graph();
+  const auto& mutable_nodes = idx.mutable_input_nodes();
+  // default use default context.
+  if (ctx_map.size() == 0) {
+    g.attrs["context"] = std::make_shared<nnvm::any>(
+        exec::ContextVector(idx.num_nodes(), default_ctx));
+    for (const auto& x : in_arg_ctxes) {
+      CHECK(x == default_ctx)
+          << "Input array is in " << x << " while binding with ctx=" << default_ctx
+          << ". All arguments must be in global context (" << default_ctx
+          << ") unless group2ctx is specified for cross-device graph.";
+    }
+    for (const auto& x : arg_grad_ctxes) {
+      CHECK(x == default_ctx)
+          << "Gradient array is in " << x << " while binding with ctx="
+          << default_ctx << ". All gradients must be in global context (" << default_ctx
+          << ") unless group2ctx is specified for cross-device graph.";
+    }
+    return g;
+  }
+
+  // otherwise, use context assignment.
+  std::map<Context, int> ctx2id;  // map ctx to device id
+  std::vector<Context> ctx_list;  // index is device id
+  nnvm::DeviceVector device(idx.num_nodes(), -1);  // index is node id
+  nnvm::DeviceAssignMap device_map;  // map arg name to device id
+
+  // loop through the user input ctx_map and
+  // populate maps and lists
+  for (auto &kv : ctx_map) {
+    if (ctx2id.count(kv.second) == 0) {  // if context has no device id, create one
+      ctx2id[kv.second] = static_cast<int>(ctx_list.size());  // assign device id to ctx
+      ctx_list.push_back(kv.second);  // save ctx to the list
+    }
+    // assign device id to the arg name with the corresponding ctx
+    device_map[kv.first] = ctx2id.at(kv.second);
+  }
+
+  // loop through all the rest of input nodes not specified
+  // in the ctx_map and populate maps and lists
+  size_t arg_top = 0, aux_top = 0;
+  for (size_t i = 0; i < num_forward_inputs; ++i) {
+    const uint32_t nid = idx.input_nodes().at(i);
+    Context ctx;
+    if (mutable_nodes.count(nid)) {  // aux node is mutable
+      CHECK_LT(aux_top, aux_state_ctxes.size());
+      ctx = aux_state_ctxes[aux_top];
+      ++aux_top;
+    } else {  // regular input node is immutable
+      CHECK_LT(arg_top, in_arg_ctxes.size());
+      ctx = in_arg_ctxes[arg_top];
+      ++arg_top;
+    }
+    if (ctx2id.count(ctx) == 0) {  // if the current ctx is not in the map of ctx and device id
+      ctx2id[ctx] = static_cast<int>(ctx_list.size());  // assign the current ctx with device id
+      ctx_list.push_back(ctx);  // save the current ctx in the list
+    }
+    device[nid] = ctx2id.at(ctx);  // assign device id to the current node
+  }
+
+  // loop through backward input nodes and populate maps and lists
+  // the backward input nodes are the gradients of the loss wrt the outputs
+  size_t arg_grad_offset = 0;
+  // keep an offset into the arg_grad_ctxes vector,
+  // since g.outputs exclude arg_grad whose req == null
+  CHECK_GE(grad_req_types.size(), g.outputs.size() - num_forward_outputs)
+      << "insufficient number of grad_reqs";
+  for (size_t i = num_forward_outputs; i < g.outputs.size(); ++i, ++arg_grad_offset) {
+    while (grad_req_types[arg_grad_offset] == kNullOp) ++arg_grad_offset;
+    const uint32_t nid = idx.outputs()[i].node_id;
+    Context ctx = arg_grad_ctxes[arg_grad_offset];
+    if (ctx2id.count(ctx) == 0) {
+      ctx2id[ctx] = static_cast<int>(ctx_list.size());
+      ctx_list.push_back(ctx);
+    }
+    int devid = ctx2id.at(ctx);
+    if (device[nid] != -1) {
+      CHECK_EQ(device[nid], devid) << "device of same output not equal to each other";
+    } else {
+      device[nid] = devid;
+    }
+  }
+
+  g.attrs["device"] = std::make_shared<dmlc::any>(std::move(device));
+  g = nnvm::pass::PlaceDevice(g, "__ctx_group__", device_map, "_CrossDeviceCopy");
+  const auto& assigned_device = g.GetAttr<nnvm::DeviceVector>("device");
+
+  exec::ContextVector vcontext;
+  for (size_t i = 0; i < assigned_device.size(); ++i) {
+    if (assigned_device[i] == -1) {
+      vcontext.push_back(default_ctx);
+    } else {
+      vcontext.push_back(ctx_list[assigned_device[i]]);
+    }
+  }
+
+  // after device planning, we should check again
+  // if the assigned device of gradient node
+  // corresponds to storage of grads
+  auto &new_idx = g.indexed_graph();
+  arg_grad_offset = 0;
+  for (size_t i = num_forward_outputs; i < g.outputs.size(); ++i, ++arg_grad_offset) {
+    while (grad_req_types[arg_grad_offset] == kNullOp) ++arg_grad_offset;
+    const uint32_t nid = new_idx.outputs()[i].node_id;
+    Context ctx = arg_grad_ctxes[arg_grad_offset];
+    CHECK(ctx == vcontext[nid])
+        << "Trying to save gradient to " << ctx
+        << " while its source node \"" << new_idx[nid].source->attrs.name
+        << "\" computes it on " << vcontext[nid]
+        << ". Check your ctx in NDArray allocation.";
+  }
+
+  g.attrs["context"] = std::make_shared<nnvm::any>(std::move(vcontext));
+  return g;
+}
 
 }  // namespace common
 }  // namespace mxnet
diff --git a/src/common/serialization.h b/src/common/serialization.h
new file mode 100644
index 00000000000..8a1bcc6e6ed
--- /dev/null
+++ b/src/common/serialization.h
@@ -0,0 +1,319 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2015 by Contributors
+ * \file serialization.h
+ * \brief Serialization of some STL and nnvm data-structures
+ * \author Clement Fuji Tsang
+ */
+
+#ifndef MXNET_COMMON_SERIALIZATION_H_
+#define MXNET_COMMON_SERIALIZATION_H_
+
+#include <dmlc/logging.h>
+#include <mxnet/graph_attr_types.h>
+#include <nnvm/graph_attr_types.h>
+#include <nnvm/tuple.h>
+
+#include <cstring>
+#include <map>
+#include <set>
+#include <string>
+#include <tuple>
+#include <unordered_map>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
+
+namespace mxnet {
+namespace common {
+
+template<typename T>
+inline size_t SerializedSize(const T &obj);
+
+template<typename T>
+inline size_t SerializedSize(const nnvm::Tuple <T> &obj);
+
+template<typename K, typename V>
+inline size_t SerializedSize(const std::map <K, V> &obj);
+
+template<>
+inline size_t SerializedSize(const std::string &obj);
+
+template<typename... Args>
+inline size_t SerializedSize(const std::tuple<Args...> &obj);
+
+template<typename T>
+inline void Serialize(const T &obj, char **buffer);
+
+template<typename T>
+inline void Serialize(const nnvm::Tuple <T> &obj, char **buffer);
+
+template<typename K, typename V>
+inline void Serialize(const std::map <K, V> &obj, char **buffer);
+
+template<>
+inline void Serialize(const std::string &obj, char **buffer);
+
+template<typename... Args>
+inline void Serialize(const std::tuple<Args...> &obj, char **buffer);
+
+template<typename T>
+inline void Deserialize(T *obj, const std::string &buffer, size_t *curr_pos);
+
+template<typename T>
+inline void Deserialize(nnvm::Tuple <T> *obj, const std::string &buffer, size_t *curr_pos);
+
+template<typename K, typename V>
+inline void Deserialize(std::map <K, V> *obj, const std::string &buffer, size_t *curr_pos);
+
+template<>
+inline void Deserialize(std::string *obj, const std::string &buffer, size_t *curr_pos);
+
+template<typename... Args>
+inline void Deserialize(std::tuple<Args...> *obj, const std::string &buffer, size_t *curr_pos);
+
+
+template<typename T>
+struct is_container {
+  static const bool value = !std::is_pod<T>::value;
+};
+
+template<typename T>
+inline size_t SerializedSize(const T &obj) {
+  return sizeof(T);
+}
+
+template<typename T>
+inline size_t SerializedSize(const nnvm::Tuple <T> &obj) {
+  if (is_container<T>::value) {
+    size_t sum_val = 4;
+    for (const auto& el : obj) {
+      sum_val += SerializedSize(el);
+    }
+    return sum_val;
+  } else {
+    return 4 + (obj.ndim() * sizeof(T));
+  }
+}
+
+template<typename K, typename V>
+inline size_t SerializedSize(const std::map <K, V> &obj) {
+  size_t sum_val = 4;
+  if (is_container<K>::value && is_container<V>::value) {
+    for (const auto& p : obj) {
+      sum_val += SerializedSize(p.first) + SerializedSize(p.second);
+    }
+  } else if (is_container<K>::value) {
+    for (const auto& p : obj) {
+      sum_val += SerializedSize(p.first);
+    }
+    sum_val += sizeof(V) * obj.size();
+  } else if (is_container<V>::value) {
+    for (const auto& p : obj) {
+      sum_val += SerializedSize(p.second);
+    }
+    sum_val += sizeof(K) * obj.size();
+  } else {
+    sum_val += (sizeof(K) + sizeof(V)) * obj.size();
+  }
+  return sum_val;
+}
+
+template<>
+inline size_t SerializedSize(const std::string &obj) {
+  return obj.size() + 4;
+}
+
+template<int I>
+struct serialized_size_tuple {
+  template<typename... Args>
+  static inline size_t Compute(const std::tuple<Args...> &obj) {
+    return SerializedSize(std::get<I>(obj)) + serialized_size_tuple<I-1>::Compute(obj);
+  }
+};
+
+template<>
+struct serialized_size_tuple<0> {
+  template<typename... Args>
+  static inline size_t Compute(const std::tuple<Args...> &obj) {
+    return SerializedSize(std::get<0>(obj));
+  }
+};
+
+template<typename... Args>
+inline size_t SerializedSize(const std::tuple<Args...> &obj) {
+  return serialized_size_tuple<sizeof... (Args)-1>::Compute(obj);
+}
+
+//  Serializer
+
+template<typename T>
+inline size_t SerializedContainerSize(const T &obj, char **buffer) {
+  uint32_t size = obj.size();
+  std::memcpy(*buffer, &size, 4);
+  *buffer += 4;
+  return (size_t) size;
+}
+
+template<typename T>
+inline void Serialize(const T &obj, char **buffer) {
+  std::memcpy(*buffer, &obj, sizeof(T));
+  *buffer += sizeof(T);
+}
+
+template<typename T>
+inline void Serialize(const nnvm::Tuple <T> &obj, char **buffer) {
+  uint32_t size = obj.ndim();
+  std::memcpy(*buffer, &size, 4);
+  *buffer += 4;
+  for (auto& el : obj) {
+    Serialize(el, buffer);
+  }
+}
+
+template<typename K, typename V>
+inline void Serialize(const std::map <K, V> &obj, char **buffer) {
+  SerializedContainerSize(obj, buffer);
+  for (auto& p : obj) {
+    Serialize(p.first, buffer);
+    Serialize(p.second, buffer);
+  }
+}
+
+template<>
+inline void Serialize(const std::string &obj, char **buffer) {
+  auto size = SerializedContainerSize(obj, buffer);
+  std::memcpy(*buffer, &obj[0], size);
+  *buffer += size;
+}
+
+template<int I>
+struct serialize_tuple {
+  template<typename... Args>
+  static inline void Compute(const std::tuple<Args...> &obj, char **buffer) {
+    serialize_tuple<I-1>::Compute(obj, buffer);
+    Serialize(std::get<I>(obj), buffer);
+  }
+};
+
+template<>
+struct serialize_tuple<0> {
+  template<typename... Args>
+  static inline void Compute(const std::tuple<Args...> &obj, char **buffer) {
+    Serialize(std::get<0>(obj), buffer);
+  }
+};
+
+template<typename... Args>
+inline void Serialize(const std::tuple<Args...> &obj, char **buffer) {
+  serialize_tuple<sizeof... (Args)-1>::Compute(obj, buffer);
+}
+
+// Deserializer
+
+template<typename T>
+inline size_t DeserializedContainerSize(T *obj, const std::string &buffer, size_t *curr_pos) {
+  uint32_t size = obj->size();
+  std::memcpy(&size, &buffer[*curr_pos], 4);
+  *curr_pos += 4;
+  return (size_t) size;
+}
+
+template<typename T>
+inline void Deserialize(T *obj, const std::string &buffer, size_t *curr_pos) {
+  std::memcpy(obj, &buffer[*curr_pos], sizeof(T));
+  *curr_pos += sizeof(T);
+}
+
+template<typename T>
+inline void Deserialize(nnvm::Tuple <T> *obj, const std::string &buffer, size_t *curr_pos) {
+  uint32_t size = obj->ndim();
+  std::memcpy(&size, &buffer[*curr_pos], 4);
+  *curr_pos += 4;
+  obj->SetDim(size);
+  for (size_t i = 0; i < size; ++i) {
+    Deserialize((*obj)[i], buffer, curr_pos);
+  }
+}
+
+template<typename K, typename V>
+inline void Deserialize(std::map <K, V> *obj, const std::string &buffer, size_t *curr_pos) {
+  auto size = DeserializedContainerSize(obj, buffer, curr_pos);
+  K first;
+  for (size_t i = 0; i < size; ++i) {
+    Deserialize(&first, buffer, curr_pos);
+    Deserialize(&(*obj)[first], buffer, curr_pos);
+  }
+}
+
+template<>
+inline void Deserialize(std::string *obj, const std::string &buffer, size_t *curr_pos) {
+  auto size = DeserializedContainerSize(obj, buffer, curr_pos);
+  obj->resize(size);
+  std::memcpy(&(obj->front()), &buffer[*curr_pos], size);
+  *curr_pos += size;
+}
+
+template<int I>
+struct deserialize_tuple {
+  template<typename... Args>
+  static inline void Compute(std::tuple<Args...> *obj,
+                             const std::string &buffer, size_t *curr_pos) {
+    deserialize_tuple<I-1>::Compute(obj, buffer, curr_pos);
+    Deserialize(&std::get<I>(*obj), buffer, curr_pos);
+  }
+};
+
+template<>
+struct deserialize_tuple<0> {
+  template<typename... Args>
+  static inline void Compute(std::tuple<Args...> *obj,
+                             const std::string &buffer, size_t *curr_pos) {
+    Deserialize(&std::get<0>(*obj), buffer, curr_pos);
+  }
+};
+
+template<typename... Args>
+inline void Deserialize(std::tuple<Args...> *obj, const std::string &buffer, size_t *curr_pos) {
+  deserialize_tuple<sizeof... (Args)-1>::Compute(obj, buffer, curr_pos);
+}
+
+
+template<typename T>
+inline void Serialize(const T& obj, std::string* serialized_data) {
+  serialized_data->resize(SerializedSize(obj));
+  char* curr_pos = &(serialized_data->front());
+  Serialize(obj, &curr_pos);
+  CHECK_EQ((int64_t)curr_pos - (int64_t)&(serialized_data->front()),
+           serialized_data->size());
+}
+
+template<typename T>
+inline void Deserialize(T* obj, const std::string& serialized_data) {
+  size_t curr_pos = 0;
+  Deserialize(obj, serialized_data, &curr_pos);
+  CHECK_EQ(curr_pos, serialized_data.size());
+}
+
+}  // namespace common
+}  // namespace mxnet
+#endif  // MXNET_COMMON_SERIALIZATION_H_
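
The serialization helpers above write containers as a 4-byte element count (copied in the
platform's native byte order) followed by the elements themselves, with POD values copied
byte-for-byte. A hedged Python sketch of that layout, restricted to strings and
string-to-string maps for brevity (illustrative only, not part of the change):

    import struct

    def serialize_str(s):
        # 4-byte native-endian length prefix, then the raw bytes.
        data = s.encode('utf-8')
        return struct.pack('=I', len(data)) + data

    def serialize_str_map(m):
        # 4-byte element count, then each key/value pair in turn
        # (std::map iterates keys in sorted order, so sort here too).
        out = struct.pack('=I', len(m))
        for key, value in sorted(m.items()):
            out += serialize_str(key) + serialize_str(value)
        return out

    def deserialize_str(buf, pos):
        (n,) = struct.unpack_from('=I', buf, pos)
        pos += 4
        return buf[pos:pos + n].decode('utf-8'), pos + n

    blob = serialize_str_map({'engine': 'trt', 'op': 'conv1'})
    count = struct.unpack_from('=I', blob, 0)[0]
    first_key, pos = deserialize_str(blob, 4)
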
diff --git a/src/common/utils.h b/src/common/utils.h
index 96949a047fb..fcc3da82b05 100644
--- a/src/common/utils.h
+++ b/src/common/utils.h
@@ -675,6 +675,37 @@ MSHADOW_XINLINE int ilog2ui(unsigned int a) {
   return k;
 }
 
+/*!
+ * \brief Return an NDArray of all zeros.
+ */
+inline NDArray InitZeros(const NDArrayStorageType stype, const TShape &shape,
+                         const Context &ctx, const int dtype) {
+  // NDArray with default storage
+  if (stype == kDefaultStorage) {
+    NDArray ret(shape, ctx, false, dtype);
+    ret = 0;
+    return ret;
+  }
+  // NDArray with non-default storage. Storage allocation is always delayed.
+  return NDArray(stype, shape, ctx, true, dtype);
+}
+
+/*!
+ * \brief Helper to add a NDArray of zeros to a std::vector.
+ */
+inline void EmplaceBackZeros(const NDArrayStorageType stype, const TShape &shape,
+                             const Context &ctx, const int dtype,
+                             std::vector<NDArray> *vec) {
+  // NDArray with default storage
+  if (stype == kDefaultStorage) {
+    vec->emplace_back(shape, ctx, false, dtype);
+    vec->back() = 0;
+  } else {
+    // NDArray with non-default storage. Storage allocation is always delayed.
+    vec->emplace_back(stype, shape, ctx, true, dtype);
+  }
+}
+
 }  // namespace common
 }  // namespace mxnet
 #endif  // MXNET_COMMON_UTILS_H_
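
InitZeros and EmplaceBackZeros, moved into common/utils.h (they are removed from
graph_executor.cc later in this diff so other executors can reuse them), eagerly
materialize a zero-filled array only for default (dense) storage and leave allocation
delayed for other storage types. A rough Python analogue of that distinction, illustrative
only and not part of the change:

    import mxnet as mx

    # Dense storage: the buffer is allocated and filled with zeros immediately.
    dense = mx.nd.zeros((2, 3), ctx=mx.cpu(), dtype='float32')

    # Non-default storage: a row_sparse zero array stores no rows at all,
    # mirroring the delayed-allocation path taken in InitZeros.
    sparse = mx.nd.zeros((2, 3), ctx=mx.cpu(), dtype='float32', stype='row_sparse')

    print(dense.stype, sparse.stype)   # default row_sparse
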
diff --git a/src/executor/exec_pass.h b/src/executor/exec_pass.h
index 26a24911894..8c483e9b2b8 100644
--- a/src/executor/exec_pass.h
+++ b/src/executor/exec_pass.h
@@ -198,6 +198,18 @@ Graph InferStorageType(Graph&& graph,
                        StorageTypeVector&& storage_type_inputs = StorageTypeVector(),
                        const std::string& storage_type_attr_key = "");
 
+#if MXNET_USE_TENSORRT
+/*!
+ * \brief Replace subgraphs by TRT (forward only)
+ */
+Graph ReplaceSubgraph(Graph&& g,
+                      const std::unordered_set<nnvm::Node*>& set_subgraph,
+                      std::unordered_map<std::string, NDArray>* const params_map);
+
+std::vector<std::unordered_set<nnvm::Node*>> GetTrtCompatibleSubsets(const Graph& g,
+    std::unordered_map<std::string, NDArray>* const params_map);
+#endif
+
 }  // namespace exec
 }  // namespace mxnet
 
diff --git a/src/executor/graph_executor.cc b/src/executor/graph_executor.cc
index 7386de4d12e..6810800c8b7 100644
--- a/src/executor/graph_executor.cc
+++ b/src/executor/graph_executor.cc
@@ -37,6 +37,8 @@
 namespace mxnet {
 namespace exec {
 
+using namespace mxnet::common;
+
 GraphExecutor::GraphExecutor() {
   log_verbose_ = dmlc::GetEnv("MXNET_EXEC_VERBOSE_LOGGING", false);
   need_grad_ = false;
@@ -56,30 +58,6 @@ GraphExecutor::~GraphExecutor() {
   }
 }
 
-inline NDArray InitZeros(const NDArrayStorageType stype, const TShape &shape,
-                                const Context &ctx, const int dtype) {
-  // NDArray with default storage
-  if (stype == kDefaultStorage) {
-    NDArray ret(shape, ctx, false, dtype);
-    ret = 0;
-    return ret;
-  }
-  // NDArray with non-default storage. Storage allocation is always delayed.
-  return NDArray(stype, shape, ctx, true, dtype);
-}
-
-inline void EmplaceBackZeros(const NDArrayStorageType stype, const TShape &shape,
-                             const Context &ctx, const int dtype,
-                             std::vector<NDArray> *vec) {
-  // NDArray with default storage
-  if (stype == kDefaultStorage) {
-    vec->emplace_back(shape, ctx, false, dtype);
-    vec->back() = 0;
-  } else {
-    // NDArray with non-default storage. Storage allocation is always delayed.
-    vec->emplace_back(stype, shape, ctx, true, dtype);
-  }
-}
 void GraphExecutor::Forward(bool is_train) {
   RunOps(is_train, 0, num_forward_nodes_);
 }
@@ -308,204 +286,6 @@ nnvm::Graph GraphExecutor::InitFullGraph(nnvm::Symbol symbol,
   return g;
 }
 
-/*!
- * \brief Assign context to the graph.
- * This is triggered by both simple_bind and bind flows.
- */
-static Graph AssignContext(Graph g,
-                    const Context& default_ctx,
-                    const std::map<std::string, Context>& ctx_map,
-                    const std::vector<Context>& in_arg_ctxes,
-                    const std::vector<Context>& arg_grad_ctxes,
-                    const std::vector<Context>& aux_state_ctxes,
-                    const std::vector<OpReqType>& grad_req_types,
-                    size_t num_forward_inputs,
-                    size_t num_forward_outputs) {
-  const auto& idx = g.indexed_graph();
-  const auto& mutable_nodes = idx.mutable_input_nodes();
-  // default use default context.
-  if (ctx_map.size() == 0) {
-    g.attrs["context"] = std::make_shared<nnvm::any>(
-        ContextVector(idx.num_nodes(), default_ctx));
-    for (const auto& x : in_arg_ctxes) {
-      CHECK(x == default_ctx)
-        << "Input array is in " << x << " while binding with ctx=" << default_ctx
-        << ". All arguments must be in global context (" << default_ctx
-        << ") unless group2ctx is specified for cross-device graph.";
-    }
-    for (const auto& x : arg_grad_ctxes) {
-      CHECK(x == default_ctx)
-        << "Gradient array is in " << x << " while binding with ctx="
-        << default_ctx << ". All gradients must be in global context (" << default_ctx
-        << ") unless group2ctx is specified for cross-device graph.";
-    }
-    return g;
-  }
-
-  // otherwise, use context assignment.
-  std::map<Context, int> ctx2id;  // map ctx to device id
-  std::vector<Context> ctx_list;  // index is device id
-  nnvm::DeviceVector device(idx.num_nodes(), -1);  // index is node id
-  nnvm::DeviceAssignMap device_map;  // map arg name to device id
-
-  // loop through the user input ctx_map and
-  // populate maps and lists
-  for (auto &kv : ctx_map) {
-    if (ctx2id.count(kv.second) == 0) {  // if context has no device id, create one
-      ctx2id[kv.second] = static_cast<int>(ctx_list.size());  // assign device id to ctx
-      ctx_list.push_back(kv.second);  // save ctx to the list
-    }
-    // assign device id to to the arg name with the corresponding ctx
-    device_map[kv.first] = ctx2id.at(kv.second);
-  }
-
-  // loop through all the rest of input nodes not specified
-  // in the ctx_map and populate maps and lists
-  size_t arg_top = 0, aux_top = 0;
-  for (size_t i = 0; i < num_forward_inputs; ++i) {
-    const uint32_t nid = idx.input_nodes().at(i);
-    Context ctx;
-    if (mutable_nodes.count(nid)) {  // aux node is mutable
-      CHECK_LT(aux_top, aux_state_ctxes.size());
-      ctx = aux_state_ctxes[aux_top];
-      ++aux_top;
-    } else {  // regular input node is immutable
-      CHECK_LT(arg_top, in_arg_ctxes.size());
-      ctx = in_arg_ctxes[arg_top];
-      ++arg_top;
-    }
-    if (ctx2id.count(ctx) == 0) {  // if the current ctx is not in the map of ctx and device id
-      ctx2id[ctx] = static_cast<int>(ctx_list.size());  // assign the current ctx with device id
-      ctx_list.push_back(ctx);  // save the current ctx in the list
-    }
-    device[nid] = ctx2id.at(ctx);  // assign device id to the current node
-  }
-
-  // loop through backward input nodes and populate maps and lists
-  // the backward input nodes is the gradient of the loss wrt the output
-  size_t arg_grad_offset = 0;
-  // keep an offset into the arg_grad_ctxes vector,
-  // since g.outputs exclude arg_grad whose req == null
-  CHECK_GE(grad_req_types.size(), g.outputs.size() - num_forward_outputs)
-           << "insufficient number of grad_reqs";
-  for (size_t i = num_forward_outputs; i < g.outputs.size(); ++i, ++arg_grad_offset) {
-    while (grad_req_types[arg_grad_offset] == kNullOp) ++arg_grad_offset;
-    const uint32_t nid = idx.outputs()[i].node_id;
-    Context ctx = arg_grad_ctxes[arg_grad_offset];
-    if (ctx2id.count(ctx) == 0) {
-      ctx2id[ctx] = static_cast<int>(ctx_list.size());
-      ctx_list.push_back(ctx);
-    }
-    int devid = ctx2id.at(ctx);
-    if (device[nid] != -1) {
-      CHECK_EQ(device[nid], devid) << "device of same output not equal to each other";
-    } else {
-      device[nid] = devid;
-    }
-  }
-
-  g.attrs["device"] = std::make_shared<dmlc::any>(std::move(device));
-  g = nnvm::pass::PlaceDevice(g, "__ctx_group__", device_map, "_CrossDeviceCopy");
-  const auto& assigned_device = g.GetAttr<nnvm::DeviceVector>("device");
-
-  ContextVector vcontext;
-  for (size_t i = 0; i < assigned_device.size(); ++i) {
-    if (assigned_device[i] == -1) {
-      vcontext.push_back(default_ctx);
-    } else {
-      vcontext.push_back(ctx_list[assigned_device[i]]);
-    }
-  }
-
-  // after device planning, we should check again
-  // if the assigned device of gradient node
-  // corresponds to storage of grads
-  auto &new_idx = g.indexed_graph();
-  arg_grad_offset = 0;
-  for (size_t i = num_forward_outputs; i < g.outputs.size(); ++i, ++arg_grad_offset) {
-    while (grad_req_types[arg_grad_offset] == kNullOp) ++arg_grad_offset;
-    const uint32_t nid = new_idx.outputs()[i].node_id;
-    Context ctx = arg_grad_ctxes[arg_grad_offset];
-    CHECK(ctx == vcontext[nid])
-      << "Trying to save gradient to " << ctx
-      << " while its source node \"" << new_idx[nid].source->attrs.name
-      << "\" computes it on " << vcontext[nid]
-      << ". Check your ctx in NDArray allocation.";
-  }
-
-  g.attrs["context"] = std::make_shared<nnvm::any>(std::move(vcontext));
-  return g;
-}
-
-static void HandleInferShapeError(const size_t num_forward_inputs,
-                           const nnvm::IndexedGraph& idx,
-                           const nnvm::ShapeVector& inferred_shapes) {
-  int cnt = 10;
-  std::ostringstream oss;
-  for (size_t i = 0; i < num_forward_inputs; ++i) {
-    const uint32_t nid = idx.input_nodes().at(i);
-    const uint32_t eid = idx.entry_id(nid, 0);
-    const TShape& inferred_shape = inferred_shapes[eid];
-    if (inferred_shape.ndim() == 0 || inferred_shape.Size() == 0U) {
-      const std::string& arg_name = idx[nid].source->attrs.name;
-      oss << arg_name << ": " << inferred_shape << ", ";
-      if (--cnt == 0) {
-        oss << "...";
-        break;
-      }
-    }
-  }
-  LOG(FATAL) << "InferShape pass cannot decide shapes for the following arguments "
-                "(0s means unknown dimensions). Please consider providing them as inputs:\n"
-             << oss.str();
-}
-
-static void HandleInferTypeError(const size_t num_forward_inputs,
-                          const nnvm::IndexedGraph& idx,
-                          const nnvm::DTypeVector& inferred_dtypes) {
-  int cnt = 10;
-  std::ostringstream oss;
-  for (size_t i = 0; i < num_forward_inputs; ++i) {
-    const uint32_t nid = idx.input_nodes().at(i);
-    const uint32_t eid = idx.entry_id(nid, 0);
-    const int inferred_dtype = inferred_dtypes[eid];
-    if (inferred_dtype == -1) {
-      const std::string& arg_name = idx[nid].source->attrs.name;
-      oss << arg_name << ": " << inferred_dtype << ", ";
-      if (--cnt == 0) {
-        oss << "...";
-        break;
-      }
-    }
-  }
-  LOG(FATAL) << "InferType pass cannot decide dtypes for the following arguments "
-                "(-1 means unknown dtype). Please consider providing them as inputs:\n"
-             << oss.str();
-}
-
-static void HandleInferStorageTypeError(const size_t num_forward_inputs,
-                                 const nnvm::IndexedGraph& idx,
-                                 const StorageTypeVector& inferred_stypes) {
-  int cnt = 10;
-  std::ostringstream oss;
-  for (size_t i = 0; i < num_forward_inputs; ++i) {
-    const uint32_t nid = idx.input_nodes().at(i);
-    const uint32_t eid = idx.entry_id(nid, 0);
-    const int inferred_stype = inferred_stypes[eid];
-    if (inferred_stype == -1) {
-      const std::string& arg_name = idx[nid].source->attrs.name;
-      oss << arg_name << ": " << common::stype_string(inferred_stype) << ", ";
-      if (--cnt == 0) {
-        oss << "...";
-        break;
-      }
-    }
-  }
-  LOG(FATAL) << "InferStorageType pass cannot decide storage type for the following arguments "
-                "(-1 means unknown stype). Please consider providing them as inputs:\n"
-             << oss.str();
-}
-
 /*!
  * \brief GraphExecutor initializer for regular bind flow in which
  * input arguments and gradients are provided by users. This initializer
@@ -680,57 +460,6 @@ void GraphExecutor::InitArguments(const nnvm::IndexedGraph& idx,
   }
 }
 
-/*!
- * \brief If the requested ndarray's shape size is less than
- * the corresponding shared_data_array's shape size and the
- * storage type is shareable, reuse the memory allocation
- * in shared_buffer; otherwise, create a zero ndarray.
- * Shareable storages include both default storage and row_sparse storage
- * if enable_row_sparse_sharing is `True`, otherwise default storage only.
- */
-static NDArray ReshapeOrCreate(const std::string& name,
-                        const TShape& dest_arg_shape,
-                        const int dest_arg_dtype,
-                        const NDArrayStorageType dest_arg_stype,
-                        const Context& ctx,
-                        std::unordered_map<std::string, NDArray>* shared_buffer,
-                        bool enable_row_sparse_sharing) {
-  bool stype_shareable = dest_arg_stype == kDefaultStorage;
-  if (enable_row_sparse_sharing) {
-    stype_shareable = stype_shareable || dest_arg_stype == kRowSparseStorage;
-  }
-  auto it = shared_buffer->find(name);
-  if (it != shared_buffer->end()) {
-    // check if size is large enough for sharing
-    bool size_shareable = it->second.shape().Size() >= dest_arg_shape.Size();
-    if (size_shareable && stype_shareable) {  // memory can be reused
-      CHECK_EQ(it->second.dtype(), dest_arg_dtype)
-        << "Requested arg array's dtype does not match that of the reusable ndarray";
-      CHECK_EQ(it->second.storage_type(), dest_arg_stype)
-        << "Requested arg array's stype does not match that of the reusable ndarray";
-      return it->second.Reshape(dest_arg_shape);
-    } else if (stype_shareable) {
-      LOG(WARNING) << "Bucketing: data " << name << " has a shape " << dest_arg_shape
-                   << ", which is larger than already allocated shape " << it->second.shape()
-                   << ". Need to re-allocate. Consider putting default bucket key to be "
-                   << "the bucket taking the largest input for better memory sharing.";
-      // size is not large enough, creating a larger one for sharing
-      // the NDArrays in shared_buffer are guaranteed to be of shareable storages
-      it->second = InitZeros(dest_arg_stype, dest_arg_shape, ctx, dest_arg_dtype);
-      return it->second;
-    } else {
-      // not shareable storage
-      return InitZeros(dest_arg_stype, dest_arg_shape, ctx, dest_arg_dtype);
-    }
-  } else {
-    auto ret = InitZeros(dest_arg_stype, dest_arg_shape, ctx, dest_arg_dtype);
-    if (stype_shareable) {
-      shared_buffer->emplace(name, ret);
-    }
-    return ret;
-  }  // if (it != shared_buffer->end())
-}
-
 /*!
  * \brief Initialize in_args, arg_grads, and aux_states
  * and their data_entry_ of the executor using
diff --git a/src/executor/graph_executor.h b/src/executor/graph_executor.h
index bfc415b4526..7b936c30025 100644
--- a/src/executor/graph_executor.h
+++ b/src/executor/graph_executor.h
@@ -163,20 +163,21 @@ class GraphExecutor : public Executor {
                      std::vector<NDArray>* aux_state_vec);
   // Initialize in_args, arg_grads and aux_states with
   // shared_buffer and shared_exec
-  void InitArguments(const nnvm::IndexedGraph& idx,
-                     const nnvm::ShapeVector& inferred_shapes,
-                     const nnvm::DTypeVector& inferred_dtypes,
-                     const StorageTypeVector& inferred_stypes,
-                     const std::vector<Context>& in_arg_ctxes,
-                     const std::vector<Context>& arg_grad_ctxes,
-                     const std::vector<Context>& aux_state_ctxes,
-                     const std::vector<OpReqType>& grad_req_types,
-                     const std::unordered_set<std::string>& shared_arg_names,
-                     const Executor* shared_exec,
-                     std::unordered_map<std::string, NDArray>* shared_buffer,
-                     std::vector<NDArray>* in_arg_vec,
-                     std::vector<NDArray>* arg_grad_vec,
-                     std::vector<NDArray>* aux_state_vec);
+  virtual void InitArguments(const nnvm::IndexedGraph& idx,
+                             const nnvm::ShapeVector& inferred_shapes,
+                             const nnvm::DTypeVector& inferred_dtypes,
+                             const StorageTypeVector& inferred_stypes,
+                             const std::vector<Context>& in_arg_ctxes,
+                             const std::vector<Context>& arg_grad_ctxes,
+                             const std::vector<Context>& aux_state_ctxes,
+                             const std::vector<OpReqType>& grad_req_types,
+                             const std::unordered_set<std::string>& shared_arg_names,
+                             const Executor* shared_exec,
+                             std::unordered_map<std::string, NDArray>* shared_buffer,
+                             std::vector<NDArray>* in_arg_vec,
+                             std::vector<NDArray>* arg_grad_vec,
+                             std::vector<NDArray>* aux_state_vec);
+
   // internal initialization of the graph for simple bind
   Graph InitGraph(nnvm::Symbol symbol,
                   const Context& default_ctx,
@@ -212,7 +213,6 @@ class GraphExecutor : public Executor {
   void BulkInferenceOpSegs();
   // perform bulking and segmentation on a training graph
   void BulkTrainingOpSegs(size_t total_num_nodes);
-
   // indicate whether there is a backward graph for gradients.
   bool need_grad_;
   // internal graph
diff --git a/src/executor/onnx_to_tensorrt.cc b/src/executor/onnx_to_tensorrt.cc
new file mode 100644
index 00000000000..0b4d91be700
--- /dev/null
+++ b/src/executor/onnx_to_tensorrt.cc
@@ -0,0 +1,148 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2018 by Contributors
+ * \file onnx_to_tensorrt.cc
+ * \brief TensorRT integration with the MXNet executor
+ * \author Marek Kolodziej, Clement Fuji Tsang
+ */
+
+#if MXNET_USE_TENSORRT
+
+#include "./onnx_to_tensorrt.h"
+
+#include <onnx/onnx.pb.h>
+
+#include <NvInfer.h>
+#include <google/protobuf/io/coded_stream.h>
+#include <google/protobuf/io/zero_copy_stream_impl.h>
+#include <google/protobuf/text_format.h>
+#include <onnx-tensorrt/NvOnnxParser.h>
+#include <onnx-tensorrt/NvOnnxParserRuntime.h>
+#include <onnx-tensorrt/PluginFactory.hpp>
+#include <onnx-tensorrt/plugin_common.hpp>
+
+using std::cout;
+using std::cerr;
+using std::endl;
+
+namespace onnx_to_tensorrt {
+
+struct InferDeleter {
+  template<typename T>
+    void operator()(T* obj) const {
+      if ( obj ) {
+        obj->destroy();
+      }
+    }
+};
+
+template<typename T>
+inline std::shared_ptr<T> InferObject(T* obj) {
+  if ( !obj ) {
+    throw std::runtime_error("Failed to create object");
+  }
+  return std::shared_ptr<T>(obj, InferDeleter());
+}
+
+std::string onnx_ir_version_string(int64_t ir_version = onnx::IR_VERSION) {
+  int onnx_ir_major = ir_version / 1000000;
+  int onnx_ir_minor = ir_version % 1000000 / 10000;
+  int onnx_ir_patch = ir_version % 10000;
+  return (std::to_string(onnx_ir_major) + "." +
+    std::to_string(onnx_ir_minor) + "." +
+    std::to_string(onnx_ir_patch));
+}
+
+void PrintVersion() {
+  cout << "Parser built against:" << endl;
+  cout << "  ONNX IR version:  " << onnx_ir_version_string(onnx::IR_VERSION) << endl;
+  cout << "  TensorRT version: "
+    << NV_TENSORRT_MAJOR << "."
+    << NV_TENSORRT_MINOR << "."
+    << NV_TENSORRT_PATCH << endl;
+}
+
+nvinfer1::ICudaEngine* onnxToTrtCtx(
+        const std::string& onnx_model,
+        int32_t max_batch_size,
+        size_t max_workspace_size,
+        nvinfer1::ILogger::Severity verbosity,
+        bool debug_builder) {
+  GOOGLE_PROTOBUF_VERIFY_VERSION;
+
+  TRT_Logger trt_logger(verbosity);
+  auto trt_builder = InferObject(nvinfer1::createInferBuilder(trt_logger));
+  auto trt_network = InferObject(trt_builder->createNetwork());
+  auto trt_parser  = InferObject(nvonnxparser::createParser(
+      *trt_network, trt_logger));
+  ::ONNX_NAMESPACE::ModelProto parsed_model;
+  // We check that the parse succeeded, but the main purpose of this call is
+  // the side effect of populating parsed_model.
+  if (!parsed_model.ParseFromString(onnx_model)) {
+    throw dmlc::Error("Could not parse ONNX from string");
+  }
+
+  if ( !trt_parser->parse(onnx_model.c_str(), onnx_model.size()) ) {
+      int nerror = trt_parser->getNbErrors();
+      for ( int i=0; i < nerror; ++i ) {
+        nvonnxparser::IParserError const* error = trt_parser->getError(i);
+        if ( error->node() != -1 ) {
+          ::ONNX_NAMESPACE::NodeProto const& node =
+            parsed_model.graph().node(error->node());
+          cerr << "While parsing node number " << error->node()
+               << " [" << node.op_type();
+          if ( !node.output().empty() ) {
+            cerr << " -> \"" << node.output(0) << "\"";
+          }
+          cerr << "]:" << endl;
+          if ( static_cast<int>(verbosity) >= \
+            static_cast<int>(nvinfer1::ILogger::Severity::kINFO) ) {
+            cerr << "--- Begin node ---" << endl;
+            cerr << node.DebugString() << endl;
+            cerr << "--- End node ---" << endl;
+          }
+        }
+        cerr << "ERROR: "
+             << error->file() << ":" << error->line()
+             << " In function " << error->func() << ":\n"
+             << "[" << static_cast<int>(error->code()) << "] " << error->desc()
+             << endl;
+      }
+      throw dmlc::Error("Cannot parse ONNX into TensorRT Engine");
+  }
+
+  bool fp16 = trt_builder->platformHasFastFp16();
+
+  trt_builder->setMaxBatchSize(max_batch_size);
+  trt_builder->setMaxWorkspaceSize(max_workspace_size);
+  if ( fp16 && dmlc::GetEnv("MXNET_TENSORRT_USE_FP16_FOR_FP32", false) ) {
+    LOG(INFO) << "WARNING: TensorRT using fp16 given original MXNet graph in fp32 !!!";
+    trt_builder->setHalf2Mode(true);
+  }
+
+  trt_builder->setDebugSync(debug_builder);
+  nvinfer1::ICudaEngine* trt_engine = trt_builder->buildCudaEngine(*trt_network.get());
+  return trt_engine;
+}
+
+}  // namespace onnx_to_tensorrt
+
+#endif  // MXNET_USE_TENSORRT
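
For orientation, here is a minimal caller sketch (not part of the patch) for the
onnxToTrtCtx() function added above, assuming a serialized ONNX model is already in
memory; the helper name BuildEngineSketch and the batch/workspace values are
illustrative only:

#if MXNET_USE_TENSORRT
#include <string>
#include "onnx_to_tensorrt.h"

// Build a TensorRT engine from serialized ONNX protobuf bytes, then release it.
void BuildEngineSketch(const std::string& serialized_onnx) {
  nvinfer1::ICudaEngine* engine = onnx_to_tensorrt::onnxToTrtCtx(
      serialized_onnx, /*max_batch_size=*/1, /*max_workspace_size=*/1UL << 28);
  // ... create an execution context and run inference here ...
  engine->destroy();
}
#endif  // MXNET_USE_TENSORRT
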
diff --git a/src/executor/onnx_to_tensorrt.h b/src/executor/onnx_to_tensorrt.h
new file mode 100644
index 00000000000..259cfce7c33
--- /dev/null
+++ b/src/executor/onnx_to_tensorrt.h
@@ -0,0 +1,77 @@
+#ifndef MXNET_EXECUTOR_ONNX_TO_TENSORRT_H_
+#define MXNET_EXECUTOR_ONNX_TO_TENSORRT_H_
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2018 by Contributors
+ * \file onnx_to_tensorrt.h
+ * \brief TensorRT integration with the MXNet executor
+ * \author Marek Kolodziej, Clement Fuji Tsang
+ */
+
+#if MXNET_USE_TENSORRT
+
+#include <fstream>
+#include <iostream>
+#include <NvInfer.h>
+#include <sstream>
+#include <string>
+
+#include "../operator/contrib/tensorrt-inl.h"
+
+namespace onnx_to_tensorrt {
+
+class TRT_Logger : public nvinfer1::ILogger {
+        nvinfer1::ILogger::Severity _verbosity;
+        std::ostream* _ostream;
+ public:
+        TRT_Logger(Severity verbosity = Severity::kWARNING,
+                   std::ostream& ostream = std::cout)
+                : _verbosity(verbosity), _ostream(&ostream) {}
+        void log(Severity severity, const char* msg) override {
+                if ( severity <= _verbosity ) {
+                        time_t rawtime = std::time(0);
+                        char buf[256];
+                        strftime(&buf[0], 256,
+                                 "%Y-%m-%d %H:%M:%S",
+                                 std::gmtime(&rawtime));
+                        const char* sevstr = (severity == Severity::kINTERNAL_ERROR ? "    BUG" :
+                                              severity == Severity::kERROR          ? "  ERROR" :
+                                              severity == Severity::kWARNING        ? "WARNING" :
+                                              severity == Severity::kINFO           ? "   INFO" :
+                                              "UNKNOWN");
+                        (*_ostream) << "[" << buf << " " << sevstr << "] "
+                                    << msg
+                                    << std::endl;
+                }
+        }
+};
+
+nvinfer1::ICudaEngine* onnxToTrtCtx(
+        const std::string& onnx_model,
+        int32_t max_batch_size = 32,
+        size_t max_workspace_size = 1L << 30,
+        nvinfer1::ILogger::Severity verbosity = nvinfer1::ILogger::Severity::kWARNING,
+        bool debug_builder = false);
+}  // namespace onnx_to_tensorrt
+
+#endif  // MXNET_USE_TENSORRT
+
+#endif  // MXNET_EXECUTOR_ONNX_TO_TENSORRT_H_
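
A similarly hedged sketch (not part of the patch) of the TRT_Logger declared above,
showing how its constructor parameters (verbosity and output stream) can be
overridden; the function name LoggerSketch is illustrative:

#if MXNET_USE_TENSORRT
#include <iostream>
#include "onnx_to_tensorrt.h"

// Route TensorRT messages of INFO severity or more severe to stderr.
void LoggerSketch() {
  onnx_to_tensorrt::TRT_Logger logger(nvinfer1::ILogger::Severity::kINFO, std::cerr);
  logger.log(nvinfer1::ILogger::Severity::kWARNING, "example warning message");
}
#endif  // MXNET_USE_TENSORRT
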
diff --git a/src/executor/tensorrt_pass.cc b/src/executor/tensorrt_pass.cc
new file mode 100644
index 00000000000..b5fc8d15f7a
--- /dev/null
+++ b/src/executor/tensorrt_pass.cc
@@ -0,0 +1,596 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2018 by Contributors
+ * \file tensorrt_pass.cc
+ * \brief Replace TRT compatible subgraphs by TRT engines
+ * \author Clement Fuji Tsang
+ */
+
+#if MXNET_USE_TENSORRT
+
+#include <NvInfer.h>
+#include <mxnet/base.h>
+#include <mxnet/op_attr_types.h>
+#include <mxnet/operator.h>
+#include <nnvm/graph_attr_types.h>
+#include <onnx/onnx.pb.h>
+
+#include "../operator/contrib/nnvm_to_onnx-inl.h"
+#include "./exec_pass.h"
+#include "./onnx_to_tensorrt.h"
+
+namespace mxnet {
+namespace exec {
+
+using NodePtr = nnvm::NodePtr;
+
+/*!
+ * \brief Custom graph class containing bi-directional nodes;
+ * needed to compute DFS and reverse DFS for graph partitioning
+ */
+class BidirectionalGraph {
+ public:
+  struct Node {
+    nnvm::Node* nnvmptr;
+    std::vector<Node*> inputs;
+    std::vector<Node*> outputs;
+  };
+  std::vector<Node> nodes;
+  std::unordered_map<nnvm::Node*, uint32_t> nnvm2nid;
+  std::vector<Node*> outputs;
+  static const std::unordered_set<std::string> unconditionalTRTop;
+
+  explicit BidirectionalGraph(const Graph &g) {
+    auto& idx = g.indexed_graph();
+    auto num_nodes = idx.num_nodes();
+    nodes.reserve(num_nodes);
+    nnvm2nid.reserve(num_nodes);
+    outputs.reserve(idx.outputs().size());
+    DFSVisit(g.outputs, [this](const nnvm::NodePtr& n) {
+      BidirectionalGraph::Node new_node;
+      new_node.nnvmptr = n.get();
+      nnvm2nid[n.get()] = static_cast<uint32_t>(nodes.size());
+      nodes.emplace_back(std::move(new_node));
+    });
+    for (const auto& it : nnvm2nid) {
+      nnvm::Node* nnvmnode = it.first;
+      uint32_t nid = it.second;
+      for (auto& n : nnvmnode->inputs) {
+        uint32_t input_nid = nnvm2nid[n.node.get()];
+        nodes[input_nid].outputs.emplace_back(&nodes[nid]);
+        nodes[nid].inputs.emplace_back(&nodes[input_nid]);
+      }
+    }
+    for (auto& e : g.outputs) {
+      uint32_t nid = nnvm2nid[e.node.get()];
+      outputs.emplace_back(&nodes[nid]);
+    }
+  }
+
+  template <typename FVisit>
+  void DFS(const std::vector<Node*>& heads, bool reverse, FVisit fvisit) {
+    std::unordered_set<Node*> visited;
+    std::vector<Node*> vec(heads.begin(), heads.end());
+    visited.reserve(heads.size());
+    while (!vec.empty()) {
+      Node* vertex = vec.back();
+      vec.pop_back();
+      if (visited.count(vertex) == 0) {
+        visited.insert(vertex);
+        fvisit(vertex);
+        std::vector<Node*> nexts = reverse ? vertex->inputs : vertex->outputs;
+        for (Node* node : nexts) {
+          if (visited.count(node) == 0) {
+            vec.emplace_back(node);
+          }
+        }
+      }
+    }
+  }
+
+  using t_pairset = std::pair<std::unordered_set<Node*>, std::unordered_set<Node*>>;
+  using t_pairvec = std::pair<std::vector<Node*>, std::vector<Node*>>;
+  using t_uncomp_map = std::unordered_map<Node*, std::unordered_set<Node*>>;
+
+  std::unordered_set<Node*> naive_grow_subgraph(Node* head,
+                                                std::unordered_set<Node*>* set_unused,
+                                                t_uncomp_map* uncomp_map) {
+    std::unordered_set<Node*> subgraph;
+    std::unordered_set<Node*> uncomp_set;
+    std::deque<Node*> stack;
+    stack.emplace_back(head);
+    while (!stack.empty()) {
+      Node* vertex = stack.back();
+      stack.pop_back();
+      if (set_unused->count(vertex) && !uncomp_set.count(vertex)) {
+        set_unused->erase(vertex);
+        subgraph.insert(vertex);
+        uncomp_set.insert((*uncomp_map)[vertex].begin(), (*uncomp_map)[vertex].end());
+        for (Node* input : vertex->inputs) {
+          if (set_unused->count(input) && !uncomp_set.count(input)) {
+            stack.emplace_back(input);
+          }
+        }
+        for (Node* output : vertex->outputs) {
+          if (set_unused->count(output) && !uncomp_set.count(output)) {
+            stack.emplace_back(output);
+          }
+        }
+      }
+    }
+    return subgraph;
+  }
+
+  std::vector<std::unordered_set<Node*>> get_subsets(
+    std::unordered_map<std::string, NDArray>* const params_map) {
+    std::vector<std::unordered_set<Node*>> subgraphs;
+    std::unordered_set<Node*> set_nonTRTnodes;
+    std::unordered_set<Node*> set_allnodes(nodes.size());
+    std::vector<t_pairset> separation_sets;
+    for (Node& node : nodes) {
+      if (!IsTRTCompatible(node.nnvmptr)) {
+        set_nonTRTnodes.insert(&node);
+        std::unordered_set<Node*> in_graph;
+        std::unordered_set<Node*> out_graph;
+        std::vector<Node*> dummy_head;
+        dummy_head.emplace_back(&node);
+        DFS(dummy_head, false, [&out_graph](Node* node) {
+          out_graph.insert(node);
+        });
+        DFS(dummy_head, true, [&in_graph](Node* node) {
+          in_graph.insert(node);
+        });
+        separation_sets.emplace_back(std::make_pair(in_graph, out_graph));
+      }
+      set_allnodes.emplace(&node);
+    }
+    t_uncomp_map uncomp_map;
+    std::unordered_set<Node*> set_TRTnodes;
+    set_TRTnodes.insert(set_allnodes.begin(), set_allnodes.end());
+    for (Node* n : set_nonTRTnodes) {
+      set_TRTnodes.erase(n);
+    }
+    for (Node* n : set_TRTnodes) {
+      for (t_pairset p : separation_sets) {
+        if (p.first.count(n)) {
+          uncomp_map[n].insert(p.second.begin(), p.second.end());
+        } else if (p.second.count(n)) {
+          uncomp_map[n].insert(p.first.begin(), p.first.end());
+        }
+      }
+      for (Node* nonTRTn : set_nonTRTnodes) {
+        uncomp_map[n].erase(nonTRTn);
+      }
+    }
+    std::unordered_set<Node*> set_unused;
+    set_unused.reserve(set_TRTnodes.size());
+
+    for (auto& n : set_TRTnodes) {
+      if (n->nnvmptr->attrs.op != nullptr || params_map->count(n->nnvmptr->attrs.name)) {
+        set_unused.insert(n);
+      }
+    }
+    std::unordered_set<Node*> visited;
+    std::deque<Node*> stack(outputs.begin(), outputs.end());
+    while (!stack.empty()) {
+      Node* vertex = stack.front();
+      stack.pop_front();
+      if (!visited.count(vertex)) {
+        visited.insert(vertex);
+        if (set_unused.count(vertex)) {
+          subgraphs.emplace_back(naive_grow_subgraph(vertex, &set_unused, &uncomp_map));
+        }
+        for (Node* input : vertex->inputs) {
+          stack.emplace_back(input);
+        }
+      }
+    }
+
+    return subgraphs;
+  }
+
+
+ private:
+  friend class Graph;
+
+  bool IsTRTCompatible(nnvm::Node* nodeptr) {
+    if (nodeptr->op() == nullptr) {
+      return true;
+    }
+
+    const std::string op_name = nodeptr->op()->name;
+    if (op_name == "Pooling") {
+      return (nodeptr->attrs.dict.at("pool_type") == "avg" ||
+          nodeptr->attrs.dict.at("pool_type") == "max");
+    }
+
+    if (unconditionalTRTop.count(op_name)) {
+      return true;
+    }
+
+    if (op_name == "Activation") {
+      return nodeptr->attrs.dict.at("act_type") == "relu" ||
+        nodeptr->attrs.dict.at("act_type") == "tanh" ||
+        nodeptr->attrs.dict.at("act_type") == "sigmoid";
+    }
+
+    return false;
+  }
+};  // class BidirectionalGraph
+
+/*!
+ * \brief Operators that TensorRT supports unconditionally, regardless of their attributes
+ */
+const std::unordered_set<std::string> BidirectionalGraph::unconditionalTRTop = {
+  "Convolution",
+  "BatchNorm",
+  "elemwise_add",
+  "elemwise_sub",
+  "elemwise_mul",
+  "rsqrt",
+  "pad",
+  "Pad",
+  "mean",
+  "FullyConnected",
+  "Flatten",
+  "SoftmaxOutput",
+};
+
+
+using NodeEntrySet = std::unordered_set<nnvm::NodeEntry, nnvm::NodeEntryHash,
+                                        nnvm::NodeEntryEqual>;
+
+/*!
+ * \brief get the output nodes of the subgraph in the main graph
+ * \return a vector of the output nodes
+*/
+std::vector<nnvm::NodeEntry> GetSubgraphNodeEntries(Graph g,
+    std::unordered_set<nnvm::Node*> set_subgraph) {
+  std::vector<nnvm::NodeEntry> outputs;
+  NodeEntrySet _outputs;
+  for (auto& e : g.outputs) {
+    if (set_subgraph.count(e.node.get())) {
+      _outputs.insert(e);
+    }
+  }
+  DFSVisit(g.outputs, [&set_subgraph, &_outputs](const nnvm::NodePtr &node){
+    if (!set_subgraph.count(node.get())) {
+      for (auto& e : node->inputs) {
+        if (set_subgraph.count(e.node.get())) {
+          _outputs.insert(e);
+        }
+      }
+    }
+  });
+  outputs.insert(outputs.begin(), _outputs.begin(), _outputs.end());
+  return outputs;
+}
+
+
+/*!
+ * \brief get the nodes outside the subgraph whose outputs are used inside the subgraph
+ * \return a vector of the nodes
+*/
+std::vector<nnvm::NodeEntry> GetSubgraphInterfaceNodes(Graph g,
+    std::unordered_set<nnvm::Node*> set_subgraph) {
+  std::vector<nnvm::NodeEntry> inputs;
+  NodeEntrySet _inputs;
+  DFSVisit(g.outputs, [&set_subgraph, &_inputs](const nnvm::NodePtr &node){
+    if (set_subgraph.count(node.get())) {
+      for (auto& e : node->inputs) {
+        if (!set_subgraph.count(e.node.get())) {
+          _inputs.insert(e);
+        }
+      }
+    }
+  });
+  inputs.insert(inputs.begin(), _inputs.begin(), _inputs.end());
+  return inputs;
+}
+
+std::unordered_map<uint32_t, uint32_t> GetGraphInputsMap(const Graph& g) {
+  std::unordered_map<uint32_t, uint32_t> outputs;
+  auto& idx = g.indexed_graph();
+  outputs.reserve(idx.num_nodes());
+  std::vector<uint32_t> input_nodes = idx.input_nodes();
+  for (size_t i = 0; i < input_nodes.size(); ++i) {
+    outputs[input_nodes[i]] = static_cast<uint32_t>(i);
+  }
+  return outputs;
+}
+
+/*!
+ * \brief Create the "_trt_op" node that carries the ONNX serialization of a subgraph
+ */
+nnvm::NodePtr ConvertNnvmGraphToOnnx(const nnvm::Graph &g,
+                                     std::unordered_map<std::string, NDArray>* const params_map) {
+  auto p = nnvm::Node::Create();
+  p->attrs.op = nnvm::Op::Get("_trt_op");
+  op::TRTParam trt_param = op::nnvm_to_onnx::ConvertNnvmGraphToOnnx(g, params_map);
+  p->attrs.dict["serialized_output_map"] = trt_param.serialized_output_map;
+  p->attrs.dict["serialized_input_map"]  = trt_param.serialized_input_map;
+  p->attrs.dict["serialized_onnx_graph"] = trt_param.serialized_onnx_graph;
+  if (p->op()->attr_parser != nullptr) {
+    p->op()->attr_parser(&(p->attrs));
+  }
+  return p;
+}
+
+/*!
+ * \brief Update attributes of the graph (such as some input properties)
+ */
+Graph UpdateSubgraphAttrs(Graph&& subgraph, const Graph& g,
+                          const std::unordered_map<nnvm::Node*, nnvm::NodePtr>& old2new,
+                          const nnvm::NodeEntryMap<nnvm::NodeEntry>& main_input_entry_to_sub) {
+  const auto& idx     = g.indexed_graph();
+  const auto& sub_idx = subgraph.indexed_graph();
+
+  const auto& shape               = g.GetAttr<nnvm::ShapeVector>("shape");
+  const auto& dtype               = g.GetAttr<nnvm::DTypeVector>("dtype");
+  const auto& storage_type        = g.GetAttr<StorageTypeVector>("storage_type");
+  const auto& shape_inputs        = g.GetAttr<nnvm::ShapeVector>("shape_inputs");
+  const auto& dtype_inputs        = g.GetAttr<nnvm::DTypeVector>("dtype_inputs");
+  const auto& storage_type_inputs = g.GetAttr<StorageTypeVector>("storage_type_inputs");
+
+  nnvm::ShapeVector sub_shape(sub_idx.num_node_entries());
+  nnvm::DTypeVector sub_dtype(sub_idx.num_node_entries());
+  StorageTypeVector sub_storage_type(sub_idx.num_node_entries());
+  nnvm::ShapeVector sub_shape_inputs(sub_idx.input_nodes().size());
+  nnvm::DTypeVector sub_dtype_inputs(sub_idx.input_nodes().size());
+  StorageTypeVector sub_storage_type_inputs(sub_idx.input_nodes().size());
+
+  const std::unordered_map<uint32_t, uint32_t> inputsindex2pos     = GetGraphInputsMap(g);
+  const std::unordered_map<uint32_t, uint32_t> sub_inputsindex2pos = GetGraphInputsMap(subgraph);
+  // map attributes from graph to subgraph
+  for (auto& p : old2new) {
+    const uint32_t nid     = idx.node_id(p.first);
+    const uint32_t sub_nid = sub_idx.node_id(p.second.get());
+    const nnvm::Op* op = sub_idx[sub_nid].source->op();
+    if (op == nullptr) {  // if it's an input node, there is only one output node entry
+      const uint32_t sub_i       = sub_idx.entry_id(sub_nid, 0);
+      const uint32_t sub_input_i = sub_inputsindex2pos.at(sub_nid);
+      const uint32_t i           = idx.entry_id(nid, 0);
+
+      sub_shape[sub_i] = shape[i];
+      sub_dtype[sub_i] = dtype[i];
+      sub_storage_type[sub_i]       = storage_type[i];
+      sub_shape_inputs[sub_input_i] = shape_inputs[inputsindex2pos.at(nid)];
+      sub_dtype_inputs[sub_input_i] = dtype_inputs[inputsindex2pos.at(nid)];
+      sub_storage_type_inputs[sub_input_i] = storage_type_inputs[inputsindex2pos.at(nid)];
+
+    } else {
+      for (size_t oi = 0; oi < op->num_outputs; ++oi) {
+        const uint32_t sub_i = sub_idx.entry_id(sub_nid, oi);
+        const uint32_t i = idx.entry_id(nid, oi);
+        sub_shape[sub_i] = shape[i];
+        sub_dtype[sub_i] = dtype[i];
+        sub_storage_type[sub_i] = storage_type[i];
+      }
+    }
+  }
+  // old2new doesn't contain placeholder / interfaces
+  for (auto& p : main_input_entry_to_sub) {
+    nnvm::NodeEntry main_entry = p.first;
+    nnvm::NodeEntry sub_entry = p.second;
+    const uint32_t sub_nid = sub_idx.node_id(sub_entry.node.get());
+    const uint32_t sub_i = sub_idx.entry_id(sub_entry);
+    const uint32_t i = idx.entry_id(main_entry);
+    const uint32_t sub_input_i = sub_inputsindex2pos.at(sub_nid);
+    sub_shape[sub_i] = shape[i];
+    sub_dtype[sub_i] = dtype[i];
+    sub_storage_type[sub_i] = storage_type[i];
+    sub_shape_inputs[sub_input_i] = sub_shape[sub_i];
+    sub_dtype_inputs[sub_input_i] = sub_dtype[sub_i];
+    sub_storage_type_inputs[sub_input_i] = sub_storage_type[sub_i];
+  }
+  subgraph.attrs["shape"] =
+      std::make_shared<dmlc::any>(std::move(sub_shape));
+  subgraph.attrs["dtype"] =
+      std::make_shared<dmlc::any>(std::move(sub_dtype));
+  subgraph.attrs["storage_type"] =
+      std::make_shared<dmlc::any>(std::move(sub_storage_type));
+  subgraph.attrs["shape_inputs"] =
+      std::make_shared<dmlc::any>(std::move(sub_shape_inputs));
+  subgraph.attrs["dtype_inputs"] =
+      std::make_shared<dmlc::any>(std::move(sub_dtype_inputs));
+  subgraph.attrs["storage_type_inputs"] =
+      std::make_shared<dmlc::any>(std::move(sub_storage_type_inputs));
+
+  return subgraph;
+}
+
+/*!
+ * \brief Generate a name for a new TRT node, avoiding collisions with already defined TRT nodes
+ */
+const std::string GetNewTrtName(const Graph& g, const Graph& subgraph) {
+  const std::string name_prefix("TRT_node");
+  std::unordered_set<std::string> name_set;
+  DFSVisit(g.outputs, [&name_set, &name_prefix](const nnvm::NodePtr& node) {
+    if (node->attrs.name.compare(0, name_prefix.size(), name_prefix) == 0) {
+      name_set.insert(node->attrs.name);
+    }
+  });
+  // names inside the subgraph will become available again, as those nodes will be removed
+  DFSVisit(subgraph.outputs, [&name_set, &name_prefix](const nnvm::NodePtr& node) {
+    if (node->attrs.name.compare(0, name_prefix.size(), name_prefix) == 0) {
+      name_set.erase(node->attrs.name);
+    }
+  });
+  uint32_t name_suffix = 0;
+  std::string full_name = name_prefix + std::to_string(name_suffix);
+  while (name_set.count(full_name)) {
+    full_name = name_prefix + std::to_string(++name_suffix);
+  }
+  return full_name;
+}
+
+/*!
+ * \brief helper function to display what nodes are in a specific subset
+ */
+void dispNodesSet(Graph g, std::unordered_set<nnvm::Node*> s) {
+  DFSVisit(g.outputs, [&s](const nnvm::NodePtr n){
+    if (s.count(n.get())) {
+      std::cout << "  Y " << n->attrs.name << std::endl;
+    } else {
+      std::cout << "  N " << n->attrs.name << std::endl;
+    }
+  });
+}
+
+/*!
+ * \brief Replace a set of nodes by a TensorRT node
+ */
+Graph ReplaceSubgraph(Graph&& g,
+                      const std::unordered_set<nnvm::Node*>& set_subgraph,
+                      std::unordered_map<std::string, NDArray>* const params_map) {
+  // Create MXNet subgraph
+  Graph subgraph;
+
+  const auto sub_outputs_in_main = GetSubgraphNodeEntries(g, set_subgraph);
+  subgraph.outputs = sub_outputs_in_main;
+  // old2new will link raw pointer of the nodes in the graph to
+  // the corresponding shared_ptr of the nodes in the generated subgraph
+  std::unordered_map<nnvm::Node*, nnvm::NodePtr> old2new;
+  std::deque<nnvm::Node*> stack;
+  std::unordered_set<nnvm::Node*> visited;
+  int32_t reservation = set_subgraph.size();
+  old2new.reserve(reservation);
+  visited.reserve(reservation);
+
+  // Copy each subgraph node into a new shared_ptr, keyed by its original raw pointer
+  for (auto& n : set_subgraph) {
+    old2new[n] = std::make_shared<nnvm::Node>(*n);
+  }
+
+  // To generate a subgraph, each input has to be replaced by a data node (no op),
+  // independent of the node from which it is an output in the main graph
+  // (for example, even if two inputs are two different outputs of the same node).
+  nnvm::NodeEntryMap<nnvm::NodeEntry> main_input_entry_to_sub;
+  for (auto& e : GetSubgraphInterfaceNodes(g, set_subgraph)) {
+    auto node = nnvm::Node::Create();
+    node->attrs.name = e.node->attrs.name + "_" + std::to_string(e.index);
+    auto new_e = nnvm::NodeEntry{node, 0, 0};
+    main_input_entry_to_sub[e] = new_e;
+  }
+
+  for (nnvm::NodeEntry& e : subgraph.outputs) {
+    e.node = old2new[e.node.get()];
+    stack.emplace_back(e.node.get());
+  }
+  // re-link all nodes in the subgraph to subgraph nodes instead of main-graph nodes
+  while (!stack.empty()) {
+    auto vertex = stack.front();
+    stack.pop_front();
+    if (!visited.count(vertex)) {
+      visited.insert(vertex);
+      for (auto& e : vertex->inputs) {
+        auto it = main_input_entry_to_sub.find(e);
+        if (it != main_input_entry_to_sub.end()) {
+          e = it->second;
+        } else {
+          e.node = old2new[e.node.get()];
+        }
+        stack.emplace_back(e.node.get());
+      }
+    }
+  }
+  // Remove the control dependencies of the subgraph to nodes that are not in the subgraph
+  DFSVisit(subgraph.outputs, [&set_subgraph, &old2new](const nnvm::NodePtr& node) {
+    node->control_deps.erase(
+        std::remove_if(node->control_deps.begin(),
+                       node->control_deps.end(),
+                       [&set_subgraph](const nnvm::NodePtr& n_ptr) {
+                         return !set_subgraph.count(n_ptr.get());
+                       }),
+        node->control_deps.end());
+    for (nnvm::NodePtr& n_ptr : node->control_deps) {
+      n_ptr = old2new[n_ptr.get()];
+    }
+  });
+
+  subgraph = UpdateSubgraphAttrs(std::move(subgraph), g, old2new, main_input_entry_to_sub);
+  auto& sub_idx = subgraph.indexed_graph();
+
+  auto trtnodeptr = ConvertNnvmGraphToOnnx(subgraph, params_map);
+  trtnodeptr->attrs.name = GetNewTrtName(g, subgraph);
+
+  // Insert new trt node and unplug replaced nodes
+  std::unordered_map<uint32_t, nnvm::NodeEntry> sub_input_entryid_to_main;
+  for (auto& p : main_input_entry_to_sub) {
+    sub_input_entryid_to_main[sub_idx.entry_id(p.second)] = p.first;
+  }
+
+  // Plug the nodes from the main graph as inputs of the trt node
+  trtnodeptr->inputs.resize(main_input_entry_to_sub.size());
+  {
+    uint32_t counter = 0;
+    for (uint32_t i : sub_idx.input_nodes()) {
+      auto it = sub_input_entryid_to_main.find(sub_idx.entry_id(i, 0));
+      if (it != sub_input_entryid_to_main.end()) {
+        trtnodeptr->inputs[counter++] = it->second;
+      }
+    }
+  }
+  nnvm::NodeEntryMap<uint32_t> sub_outputs_in_main_to_pos;
+  for (uint32_t i = 0; i < sub_outputs_in_main.size(); ++i) {
+    sub_outputs_in_main_to_pos[sub_outputs_in_main[i]] = i;
+  }
+  // Plug the trt node as inputs to the main graph nodes
+  DFSVisit(g.outputs, [&sub_outputs_in_main_to_pos, &trtnodeptr](const nnvm::NodePtr& n) {
+    for (auto& e : n->inputs) {
+      auto it = sub_outputs_in_main_to_pos.find(e);
+      if (it != sub_outputs_in_main_to_pos.end()) {
+        e.index = it->second;
+        e.node = trtnodeptr;
+      }
+    }
+  });
+
+  for (auto& output : g.outputs) {
+    auto it = sub_outputs_in_main_to_pos.find(output);
+    if (it != sub_outputs_in_main_to_pos.end()) {
+      output.index = it->second;
+      output.node = trtnodeptr;
+    }
+  }
+
+  Graph new_graph;
+  new_graph.outputs = g.outputs;
+  return new_graph;
+}
+
+std::vector<std::unordered_set<nnvm::Node*>> GetTrtCompatibleSubsets(const Graph& g,
+    std::unordered_map<std::string, NDArray>* const params_map) {
+  BidirectionalGraph biG = BidirectionalGraph(g);
+  std::vector<std::unordered_set<BidirectionalGraph::Node*>> subsets = biG.get_subsets(params_map);
+  std::vector<std::unordered_set<nnvm::Node*>> nnvm_subsets(subsets.size(),
+                                                            std::unordered_set<nnvm::Node*>());
+  for (size_t i = 0; i < subsets.size(); ++i) {
+    nnvm_subsets[i].reserve(subsets[i].size());
+    for (auto& n : subsets[i]) {
+      nnvm_subsets[i].insert(n->nnvmptr);
+    }
+  }
+  return nnvm_subsets;
+}
+
+}  // namespace exec
+}  // namespace mxnet
+
+#endif  // MXNET_USE_TENSORRT
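
The intended flow for the pass above is to partition the graph into TRT-compatible
subsets and then replace each multi-node subset with a single "_trt_op" node. A
condensed sketch (not part of the patch; it mirrors how the TrtGraphExecutor added
later in this diff drives the pass, and assumes a Graph g and a params_map are in
scope):

// Partition the graph, then swap every non-trivial subset for a TRT node.
auto subsets = mxnet::exec::GetTrtCompatibleSubsets(g, &params_map);
for (auto& subset : subsets) {
  if (subset.size() > 1) {  // mirrors the executor: single-node subsets are left alone
    g = mxnet::exec::ReplaceSubgraph(std::move(g), subset, &params_map);
  }
}
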
diff --git a/src/executor/trt_graph_executor.cc b/src/executor/trt_graph_executor.cc
new file mode 100644
index 00000000000..65dbb29792e
--- /dev/null
+++ b/src/executor/trt_graph_executor.cc
@@ -0,0 +1,450 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#if MXNET_USE_TENSORRT
+
+#include "trt_graph_executor.h"
+
+#include <onnx/onnx.pb.h>
+#include <NvInfer.h>
+#include "./onnx_to_tensorrt.h"
+#include "../operator/contrib/tensorrt-inl.h"
+#include "../common/utils.h"
+#include "../common/exec_utils.h"
+
+
+namespace mxnet {
+namespace exec {
+
+using namespace mxnet::common;
+
+/*!
+ * \brief TrtGraphExecutor initializer for simple bind flow in
+ * which only certain input shapes and dtypes are provided by users.
+ * The initializer uses these shapes and dtypes to perform
+ * shape and dtype inferences, and then create NDArrays
+ * to populate data entries of the graph. The created NDArrays
+ * for in_args, arg_grads and aux_states are passed to the
+ * front end to attach the created executor.
+ * In the front end, if the simple_bind flow is triggered by
+ * _bind_ith_exec, the shared data arrays of DataParallelExecutorGroup
+ * and shared executor will be taken into account in creating
+ * NDArrays for in_args, arg_grads, and aux_states for reusing
+ * already allocated memory.
+ *
+ * This version of the executor exports the computation graph to TensorRT to make use of fused
+ * kernels and other runtime enhancements. TRT compiles the sub-graphs into executable fused
+ * operators without intervention from the user. Operators in the original graph that are not
+ * supported by TRT continue to be executed normally by MXNet.
+ *
+ */
+void TrtGraphExecutor::Init(nnvm::Symbol symbol,
+                            const Context& default_ctx,
+                            const std::map<std::string, Context>& ctx_map,
+                            std::vector<Context> *in_arg_ctxes,
+                            std::vector<Context> *arg_grad_ctxes,
+                            std::vector<Context> *aux_state_ctxes,
+                            std::unordered_map<std::string, TShape> *arg_shape_map,
+                            std::unordered_map<std::string, int> *arg_dtype_map,
+                            std::unordered_map<std::string, int> *arg_stype_map,
+                            std::vector<OpReqType> *grad_req_types,
+                            const std::unordered_set<std::string>& shared_arg_names,
+                            std::vector<NDArray>* in_arg_vec,
+                            std::vector<NDArray>* arg_grad_vec,
+                            std::vector<NDArray>* aux_state_vec,
+                            std::unordered_map<std::string, NDArray>* shared_buffer,
+                            Executor* shared_exec,
+                            const nnvm::NodeEntryMap<NDArray>& feed_dict) {
+  symbol = symbol.Copy();
+  nnvm::Graph g = InitGraph(symbol, default_ctx, ctx_map, *in_arg_ctxes, *arg_grad_ctxes,
+                            *aux_state_ctxes, *grad_req_types);
+
+  if (need_grad_) {
+    LOG(FATAL) << "You may be attempting to use TensorRT for training.  TensorRT is an inference "
+                  "only library.  To re-enable legacy MXNet graph execution, which will support "
+                  "training, set the MXNET_USE_TENSORRT environment variable to 0, or call "
+                  "mx.contrib.tensorrt.set_use_tensorrt(False)";
+  }
+
+  if (shared_buffer == nullptr || shared_buffer->empty()) {
+    LOG(FATAL) << "MXNET_USE_TENSORRT = 1 but shared_buffer is empty. "
+               << "Please provide weights and other parameters, such as "
+               << "BatchNorm moments, via the shared_buffer, during simple bind call.";
+  }
+
+  // The following code of shape and dtype inferences and argument
+  // initialization is for simple_bind only. Regular bind operation
+  // should do this differently.
+
+  // Initialize arg_shapes and arg_dtypes for shape and type inferences.
+  // It contains all in_args and aux_states' shapes and types in a certain order.
+  const nnvm::IndexedGraph& idx = g.indexed_graph();
+  nnvm::ShapeVector arg_shapes(idx.input_nodes().size(), TShape());
+  nnvm::DTypeVector arg_dtypes(idx.input_nodes().size(), -1);
+  StorageTypeVector arg_stypes(idx.input_nodes().size(), kUndefinedStorage);
+  for (size_t i = 0; i < num_forward_inputs_; ++i) {
+    const uint32_t nid = idx.input_nodes().at(i);
+    const std::string& name = idx[nid].source->attrs.name;
+    auto it1 = arg_shape_map->find(name);
+    if (arg_shape_map->end() != it1) {
+      arg_shapes[i] = it1->second;
+    }
+    auto it2 = arg_dtype_map->find(name);
+    if (arg_dtype_map->end() != it2) {
+      arg_dtypes[i] = it2->second;
+    }
+    auto it3 = arg_stype_map->find(name);
+    if (arg_stype_map->end() != it3) {
+      arg_stypes[i] = it3->second;
+    }
+  }
+  g = InferShape(std::move(g), std::move(arg_shapes), "__shape__");
+  if (g.GetAttr<size_t>("shape_num_unknown_nodes") != 0U) {
+    HandleInferShapeError(num_forward_inputs_, g.indexed_graph(),
+                          g.GetAttr<nnvm::ShapeVector>("shape"));
+  }
+
+  g = InferType(std::move(g), std::move(arg_dtypes), "__dtype__");
+  if (g.GetAttr<size_t>("dtype_num_unknown_nodes") != 0U) {
+    HandleInferTypeError(num_forward_inputs_, g.indexed_graph(),
+                         g.GetAttr<nnvm::DTypeVector>("dtype"));
+  }
+
+  g = InferStorageType(std::move(g), std::move(arg_stypes), "__storage_type__");
+  if (g.GetAttr<size_t>("storage_type_num_unknown_nodes") != 0U) {
+    HandleInferStorageTypeError(num_forward_inputs_, g.indexed_graph(),
+                                g.GetAttr<StorageTypeVector>("storage_type"));
+  }
+
+  auto trt_groups = GetTrtCompatibleSubsets(g, shared_buffer);
+  for (auto trt_group : trt_groups) {
+    if (trt_group.size() > 1) {
+      g = ReplaceSubgraph(std::move(g), trt_group, shared_buffer);
+      g = ReinitGraph(std::move(g), default_ctx, ctx_map, in_arg_ctxes, arg_grad_ctxes,
+                      aux_state_ctxes, grad_req_types, arg_shape_map, arg_dtype_map,
+                      arg_stype_map, shared_buffer);
+    }
+  }
+
+
+  InitArguments(g.indexed_graph(), g.GetAttr<nnvm::ShapeVector>("shape"),
+                g.GetAttr<nnvm::DTypeVector>("dtype"),
+                g.GetAttr<StorageTypeVector>("storage_type"),
+                *in_arg_ctxes, *arg_grad_ctxes, *aux_state_ctxes,
+                *grad_req_types, shared_arg_names, shared_exec,
+                shared_buffer, in_arg_vec, arg_grad_vec, aux_state_vec);
+
+  // The above code of shape and dtype inferences and argument
+  // initialization is for simple_bind only. Regular bind operation
+  // should do this differently.
+
+  // Initialize the rest attributes of the graph.
+  // This function can be called by regular bind
+  // operation flow as well.
+  FinishInitGraph(symbol, g, shared_exec, feed_dict);
+}
+/*!
+ * \brief Initialize in_args, arg_grads, and aux_states
+ * and their data_entry_ of the executor using
+ * shared_buffer from DataParallelExecutorGroup
+ * and shared_exec if available.
+ */
+void TrtGraphExecutor::InitArguments(const nnvm::IndexedGraph& idx,
+                                  const nnvm::ShapeVector& inferred_shapes,
+                                  const nnvm::DTypeVector& inferred_dtypes,
+                                  const StorageTypeVector& inferred_stypes,
+                                  const std::vector<Context>& in_arg_ctxes,
+                                  const std::vector<Context>& arg_grad_ctxes,
+                                  const std::vector<Context>& aux_state_ctxes,
+                                  const std::vector<OpReqType>& grad_req_types,
+                                  const std::unordered_set<std::string>& shared_arg_names,
+                                  const Executor* shared_exec,
+                                  std::unordered_map<std::string, NDArray>* shared_buffer,
+                                  std::vector<NDArray>* in_arg_vec,
+                                  std::vector<NDArray>* arg_grad_vec,
+                                  std::vector<NDArray>* aux_state_vec) {
+  // initialize in_args, arg_grads, and aux_states and populate grad_store_
+  data_entry_.resize(idx.num_node_entries());
+  size_t arg_top = 0, aux_top = 0;
+  const auto& mutable_nodes = idx.mutable_input_nodes();
+  for (size_t i = 0; i < num_forward_inputs_; ++i) {
+    const uint32_t nid = idx.input_nodes().at(i);
+    const uint32_t eid = idx.entry_id(nid, 0);
+    const TShape& inferred_shape = inferred_shapes[eid];
+    const int inferred_dtype = inferred_dtypes[eid];
+    const NDArrayStorageType inferred_stype = (NDArrayStorageType) inferred_stypes[eid];
+    const std::string& arg_name = idx[nid].source->attrs.name;
+    // aux_states
+    if (mutable_nodes.count(nid)) {
+      if (nullptr != shared_exec) {
+        const NDArray& aux_nd = shared_exec->aux_state_map().at(arg_name);
+        CHECK(inferred_stype == kDefaultStorage && aux_nd.storage_type() == kDefaultStorage)
+          << "Non-default storage type detected when creating auxilliary NDArray. The allocated "
+          << "memory of shared_exec.aux_array cannot be resued for argument: "
+          << arg_name << " for the current executor";
+        CHECK_EQ(inferred_shape, aux_nd.shape())
+          << "Inferred shape does not match shared_exec.aux_array's shape."
+             " Therefore, the allocated memory for shared_exec.aux_array cannot"
+             " be resued for creating auxilliary NDArray of the argument: "
+          << arg_name << " for the current executor";
+        CHECK_EQ(inferred_dtype, aux_nd.dtype())
+          << "Inferred dtype does not match shared_exec.aux_array's dtype."
+             " Therefore, the allocated memory for shared_exec.aux_array cannot"
+             " be resued for creating auxilliary NDArray of the argument: "
+          << arg_name << " for the current executor";
+        aux_state_vec->emplace_back(aux_nd);
+      } else {
+        auto it = shared_buffer->find(arg_name);
+        if (it != shared_buffer->end()) {
+          aux_state_vec->push_back(std::move(it->second.Copy(aux_state_ctxes[aux_top])));
+        } else {
+          aux_state_vec->push_back(std::move(InitZeros(inferred_stype, inferred_shape,
+                                                       aux_state_ctxes[aux_top], inferred_dtype)));
+        }
+      }  // if (has_shared_exec)
+      data_entry_[eid] = aux_state_vec->back();
+      aux_state_map_.emplace(arg_name, aux_state_vec->back());
+      ++aux_top;
+    } else {  // in_args and grad for in_args
+      if (shared_arg_names.count(arg_name)) {  // model parameter
+        // model parameter
+        if (nullptr != shared_exec) {
+          const NDArray& in_arg_nd = shared_exec->in_arg_map().at(arg_name);
+          auto arg_nd_stype = in_arg_nd.storage_type();
+          // for model parameter, both default storage and row_sparse storage can be shared
+          bool shareable_arg_stype = inferred_stype == kDefaultStorage ||
+                                     inferred_stype == kRowSparseStorage;
+          // try to reuse memory from shared_exec
+          CHECK(shareable_arg_stype) << "Inferred storage type "
+            << common::stype_string(inferred_stype)
+            << " does not support memory sharing with shared_exec.arg_array";
+          CHECK_EQ(inferred_stype, arg_nd_stype)
+            << "Inferred stype does not match shared_exec.arg_array's stype"
+               " Therefore, the allocated memory for shared_exec.arg_array cannot"
+               " be resued for creating NDArray of the argument "
+            << arg_name << " for the current executor";
+          CHECK_EQ(inferred_shape, in_arg_nd.shape())
+            << "Inferred shape does not match shared_exec.arg_array's shape"
+               " Therefore, the allocated memory for shared_exec.arg_array cannot"
+               " be resued for creating NDArray of the argument "
+            << arg_name << " for the current executor";
+          CHECK_EQ(inferred_dtype, in_arg_nd.dtype())
+            << "Inferred dtype does not match shared_exec.arg_array's dtype"
+               " Therefore, the allocated memory for shared_exec.arg_array cannot"
+               " be resued for creating NDArray of the argument "
+            << arg_name << " for the current executor";
+          in_arg_vec->emplace_back(in_arg_nd);
+        } else {
+          // doesn't have shared_exec, or non-default storage
+          EmplaceBackZeros(inferred_stype, inferred_shape, in_arg_ctxes[arg_top],
+                           inferred_dtype, in_arg_vec);
+        }
+        // gradient for model parameter
+        if (kNullOp == grad_req_types[arg_top]) {
+          arg_grad_vec->emplace_back();
+        } else {
+          auto grad_oid = grad_store_.size() + num_forward_outputs_;
+          auto grad_eid = idx.entry_id(idx.outputs()[grad_oid]);
+          auto grad_stype = (NDArrayStorageType) inferred_stypes[grad_eid];
+          if (nullptr != shared_exec && grad_stype == kDefaultStorage &&
+              shared_exec->arg_grad_map().at(arg_name).storage_type() == kDefaultStorage) {
+            // try to reuse memory from shared_exec
+            arg_grad_vec->emplace_back(shared_exec->arg_grad_map().at(arg_name));
+          } else {
+            // no need to reuse memory from shared_exec for gradient of non-default storage
+            EmplaceBackZeros(grad_stype, inferred_shape, arg_grad_ctxes[arg_top],
+                             inferred_dtype, arg_grad_vec);
+          }
+          grad_store_.emplace_back(grad_req_types[arg_top], arg_grad_vec->back());
+        }
+      } else {  // !shared_arg_names.count(arg_name)
+        // model parameter, row_sparse ndarray sharing enabled
+        auto it = shared_buffer->find(arg_name);
+        if (it != shared_buffer->end()) {
+          in_arg_vec->push_back(std::move(it->second.Copy(in_arg_ctxes[arg_top])));
+        } else {
+          in_arg_vec->push_back(std::move(InitZeros(inferred_stype, inferred_shape,
+                                                    in_arg_ctxes[arg_top], inferred_dtype)));
+        }
+        // gradient for model parameter, row_sparse ndarray sharing disabled
+        if (kNullOp == grad_req_types[arg_top]) {
+          arg_grad_vec->emplace_back();
+        } else {
+          auto grad_oid = grad_store_.size() + num_forward_outputs_;
+          auto grad_eid = idx.entry_id(idx.outputs()[grad_oid]);
+          auto grad_stype = (NDArrayStorageType) inferred_stypes[grad_eid];
+          bool enable_row_sparse_sharing = false;
+          arg_grad_vec->emplace_back(ReshapeOrCreate("grad of " + arg_name, inferred_shape,
+                                                     inferred_dtype, grad_stype,
+                                                     arg_grad_ctxes[arg_top], shared_buffer,
+                                                     enable_row_sparse_sharing));
+          grad_store_.emplace_back(grad_req_types[arg_top], arg_grad_vec->back());
+        }  // if (kNullOp == grad_req_types[arg_top])
+      }  // if (shared_arg_names.count(arg_name))
+      in_arg_map_.emplace(arg_name, in_arg_vec->back());
+      if (!arg_grad_vec->back().is_none()) {
+        arg_grad_map_.emplace(arg_name, arg_grad_vec->back());
+      }
+      data_entry_[eid] = in_arg_vec->back();
+      ++arg_top;
+    }
+  }
+}
+
+
+/*!
+ * \brief This function is triggered after each TensorRT subgraph replacement pass.
+ * It resets the arguments of GraphExecutor::Init(...), since some variables (weights
+ * and biases) are absorbed into the TRT engine, and reruns the attribute inference
+ * passes on the new topology.
+ */
+Graph TrtGraphExecutor::ReinitGraph(Graph&& g, const Context &default_ctx,
+                                 const std::map<std::string, Context> &ctx_map,
+                                 std::vector<Context> *in_arg_ctxes,
+                                 std::vector<Context> *arg_grad_ctxes,
+                                 std::vector<Context> *aux_state_ctxes,
+                                 std::vector<OpReqType> *grad_req_types,
+                                 std::unordered_map<std::string, TShape> *arg_shape_map,
+                                 std::unordered_map<std::string, int> *arg_dtype_map,
+                                 std::unordered_map<std::string, int> *arg_stype_map,
+                                 std::unordered_map<std::string, NDArray> *params_map) {
+  std::unordered_set<std::string> to_remove_params;
+  for (auto& el : *params_map) {
+    to_remove_params.insert(el.first);
+  }
+
+  DFSVisit(g.outputs, [&to_remove_params](const nnvm::NodePtr n) {
+    to_remove_params.erase(n->attrs.name);
+  });
+
+  for (auto& el : to_remove_params) {
+    params_map->erase(el);
+    arg_shape_map->erase(el);
+    arg_dtype_map->erase(el);
+    arg_stype_map->erase(el);
+  }
+  const auto &idx = g.indexed_graph();
+  num_forward_inputs_ = idx.input_nodes().size();
+  in_arg_ctxes->resize(num_forward_inputs_ - idx.mutable_input_nodes().size());
+  arg_grad_ctxes->resize(num_forward_inputs_ - idx.mutable_input_nodes().size());
+  grad_req_types->resize(num_forward_inputs_ - idx.mutable_input_nodes().size());
+  aux_state_ctxes->resize(idx.mutable_input_nodes().size());
+
+  // create "device" and "context" attrs for the graph
+  g = AssignContext(g, default_ctx, ctx_map, *in_arg_ctxes, *arg_grad_ctxes,
+                    *aux_state_ctxes, *grad_req_types, num_forward_inputs_,
+                    num_forward_outputs_);
+
+  // get number of nodes used in forward pass
+  num_forward_nodes_ = 0;
+  for (size_t i = 0; i < num_forward_outputs_; ++i) {
+    num_forward_nodes_ = std::max(
+        num_forward_nodes_, static_cast<size_t>(idx.outputs()[i].node_id + 1));
+  }
+  nnvm::ShapeVector arg_shapes(idx.input_nodes().size(), TShape());
+  nnvm::DTypeVector arg_dtypes(idx.input_nodes().size(), -1);
+  StorageTypeVector arg_stypes(idx.input_nodes().size(), kUndefinedStorage);
+  for (size_t i = 0; i < num_forward_inputs_; ++i) {
+    const uint32_t nid = idx.input_nodes().at(i);
+    const std::string &name = idx[nid].source->attrs.name;
+    auto it1 = arg_shape_map->find(name);
+    if (arg_shape_map->end() != it1) {
+      arg_shapes[i] = it1->second;
+    }
+    auto it2 = arg_dtype_map->find(name);
+    if (arg_dtype_map->end() != it2) {
+      arg_dtypes[i] = it2->second;
+    }
+    auto it3 = arg_stype_map->find(name);
+    if (arg_stype_map->end() != it3) {
+      arg_stypes[i] = it3->second;
+    }
+  }
+  g = InferShape(std::move(g), std::move(arg_shapes), "__shape__");
+  if (g.GetAttr<size_t>("shape_num_unknown_nodes") != 0U) {
+    HandleInferShapeError(num_forward_inputs_, g.indexed_graph(),
+                          g.GetAttr<nnvm::ShapeVector>("shape"));
+  }
+
+  g = InferType(std::move(g), std::move(arg_dtypes), "__dtype__");
+  if (g.GetAttr<size_t>("dtype_num_unknown_nodes") != 0U) {
+    HandleInferTypeError(num_forward_inputs_, g.indexed_graph(),
+                         g.GetAttr<nnvm::DTypeVector>("dtype"));
+  }
+
+  g = InferStorageType(std::move(g), std::move(arg_stypes), "__storage_type__");
+
+  if (g.GetAttr<size_t>("storage_type_num_unknown_nodes") != 0U) {
+    HandleInferStorageTypeError(num_forward_inputs_, g.indexed_graph(),
+                                g.GetAttr<StorageTypeVector>("storage_type"));
+  }
+
+  return g;
+}
+
+
+/*!
+ * \brief Return the "optimized" symbol contained in the graph, i.e. the symbol
+ * produced by optimization passes such as the TensorRT subgraph pass.
+ */
+nnvm::Symbol TrtGraphExecutor::GetOptimizedSymbol() {
+  Symbol ret;
+  ret.outputs = std::vector<nnvm::NodeEntry>(graph_.outputs.begin(),
+                                             graph_.outputs.begin() + num_forward_outputs_);
+  ret = ret.Copy();
+  static const Op* trt_op = Op::Get("_trt_op");
+  DFSVisit(ret.outputs, [](const nnvm::NodePtr n) {
+    if (n->op() == trt_op) {
+      n->attrs.dict.clear();
+    }
+  });
+  return ret;
+}
+
+Executor *TrtGraphExecutor::TensorRTBind(nnvm::Symbol symbol,
+                                         const Context &default_ctx,
+                                         const std::map<std::string, Context> &group2ctx,
+                                         std::vector<Context> *in_arg_ctxes,
+                                         std::vector<Context> *arg_grad_ctxes,
+                                         std::vector<Context> *aux_state_ctxes,
+                                         std::unordered_map<std::string, TShape> *arg_shape_map,
+                                         std::unordered_map<std::string, int> *arg_dtype_map,
+                                         std::unordered_map<std::string, int> *arg_stype_map,
+                                         std::vector<OpReqType> *grad_req_types,
+                                         const std::unordered_set<std::string> &param_names,
+                                         std::vector<NDArray> *in_args,
+                                         std::vector<NDArray> *arg_grads,
+                                         std::vector<NDArray> *aux_states,
+                                         std::unordered_map<std::string, NDArray> *shared_buffer,
+                                         Executor *shared_exec) {
+  auto exec = new exec::TrtGraphExecutor();
+  exec->Init(symbol, default_ctx, group2ctx,
+             in_arg_ctxes, arg_grad_ctxes, aux_state_ctxes,
+             arg_shape_map, arg_dtype_map, arg_stype_map,
+             grad_req_types, param_names,
+             in_args, arg_grads, aux_states,
+             shared_buffer, shared_exec);
+  return exec;
+}
+
+}  // namespace exec
+
+}  // namespace mxnet
+
+#endif  // MXNET_USE_TENSORRT
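
Side note on the executor changes above: ReinitGraph drops every entry of params_map
(and of the shape/dtype/stype maps) whose variable no longer appears in the rewritten
graph, then reruns shape, type and storage inference. The standalone sketch below
illustrates only that pruning step; ToyNode, Visit and the sample names are invented
stand-ins, not the NNVM API.

    #include <iostream>
    #include <string>
    #include <unordered_map>
    #include <unordered_set>
    #include <vector>

    struct ToyNode {                          // hypothetical stand-in for nnvm::Node
      std::string name;
      std::vector<const ToyNode*> inputs;
    };

    void Visit(const ToyNode* node, std::unordered_set<std::string>* seen) {
      if (node == nullptr || seen->count(node->name)) return;
      seen->insert(node->name);
      for (const ToyNode* in : node->inputs) Visit(in, seen);
    }

    int main() {
      // After the TensorRT pass, "conv_weight" is baked into the engine and is no
      // longer an input of the rewritten graph, while "fc_weight" still is.
      ToyNode data{"data", {}}, fc_weight{"fc_weight", {}};
      ToyNode trt_node{"trt_subgraph", {&data}};
      ToyNode fc{"fc", {&trt_node, &fc_weight}};

      std::unordered_map<std::string, int> params{{"conv_weight", 0}, {"fc_weight", 1}};

      std::unordered_set<std::string> reachable;
      Visit(&fc, &reachable);                 // plays the role of nnvm::DFSVisit

      for (auto it = params.begin(); it != params.end();) {
        if (reachable.count(it->first)) ++it;
        else it = params.erase(it);           // same effect as the to_remove_params loop
      }
      std::cout << "remaining params: " << params.size() << std::endl;  // prints 1
      return 0;
    }
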
diff --git a/src/executor/trt_graph_executor.h b/src/executor/trt_graph_executor.h
new file mode 100644
index 00000000000..96ac4426270
--- /dev/null
+++ b/src/executor/trt_graph_executor.h
@@ -0,0 +1,111 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#ifndef MXNET_EXECUTOR_TRT_GRAPH_EXECUTOR_H_
+#define MXNET_EXECUTOR_TRT_GRAPH_EXECUTOR_H_
+
+#if MXNET_USE_TENSORRT
+
+#include <map>
+#include <string>
+#include <vector>
+
+#include "./graph_executor.h"
+
+namespace mxnet {
+
+namespace exec {
+
+class TrtGraphExecutor : public GraphExecutor {
+ public:
+  static Executor* TensorRTBind(nnvm::Symbol symbol,
+                                const Context& default_ctx,
+                                const std::map<std::string, Context>& group2ctx,
+                                std::vector<Context> *in_arg_ctxes,
+                                std::vector<Context>* arg_grad_ctxes,
+                                std::vector<Context>* aux_state_ctxes,
+                                std::unordered_map<std::string, TShape>* arg_shape_map,
+                                std::unordered_map<std::string, int>* arg_dtype_map,
+                                std::unordered_map<std::string, int>* arg_stype_map,
+                                std::vector<OpReqType>* grad_req_types,
+                                const std::unordered_set<std::string>& param_names,
+                                std::vector<NDArray>* in_args,
+                                std::vector<NDArray>* arg_grads,
+                                std::vector<NDArray>* aux_states,
+                                std::unordered_map<std::string, NDArray>*
+                                shared_data_arrays = nullptr,
+                                Executor* shared_exec = nullptr);
+
+  virtual void Init(nnvm::Symbol symbol,
+                    const Context& default_ctx,
+                    const std::map<std::string, Context>& ctx_map,
+                    std::vector<Context> *in_arg_ctxes,
+                    std::vector<Context> *arg_grad_ctxes,
+                    std::vector<Context> *aux_state_ctxes,
+                    std::unordered_map<std::string, TShape> *arg_shape_map,
+                    std::unordered_map<std::string, int> *arg_dtype_map,
+                    std::unordered_map<std::string, int> *arg_stype_map,
+                    std::vector<OpReqType> *grad_req_types,
+                    const std::unordered_set<std::string>& shared_arg_names,
+                    std::vector<NDArray>* in_arg_vec,
+                    std::vector<NDArray>* arg_grad_vec,
+                    std::vector<NDArray>* aux_state_vec,
+                    std::unordered_map<std::string, NDArray>* shared_buffer = nullptr,
+                    Executor* shared_exec = nullptr,
+                    const nnvm::NodeEntryMap<NDArray>& feed_dict
+                      = nnvm::NodeEntryMap<NDArray>());
+
+  // Returns symbol representing the TRT optimized graph for comparison purposes.
+  nnvm::Symbol GetOptimizedSymbol();
+
+ protected:
+  Graph ReinitGraph(Graph&& g, const Context &default_ctx,
+        const std::map<std::string, Context> &ctx_map,
+        std::vector<Context> *in_arg_ctxes,
+        std::vector<Context> *arg_grad_ctxes,
+        std::vector<Context> *aux_state_ctxes,
+        std::vector<OpReqType> *grad_req_types,
+        std::unordered_map<std::string, TShape> *arg_shape_map,
+        std::unordered_map<std::string, int> *arg_dtype_map,
+        std::unordered_map<std::string, int> *arg_stype_map,
+        std::unordered_map<std::string, NDArray> *params_map);
+
+  void InitArguments(const nnvm::IndexedGraph& idx,
+                     const nnvm::ShapeVector& inferred_shapes,
+                     const nnvm::DTypeVector& inferred_dtypes,
+                     const StorageTypeVector& inferred_stypes,
+                     const std::vector<Context>& in_arg_ctxes,
+                     const std::vector<Context>& arg_grad_ctxes,
+                     const std::vector<Context>& aux_state_ctxes,
+                     const std::vector<OpReqType>& grad_req_types,
+                     const std::unordered_set<std::string>& shared_arg_names,
+                     const Executor* shared_exec,
+                     std::unordered_map<std::string, NDArray>* shared_buffer,
+                     std::vector<NDArray>* in_arg_vec,
+                     std::vector<NDArray>* arg_grad_vec,
+                     std::vector<NDArray>* aux_state_vec) override;
+};
+
+}  // namespace exec
+
+}  // namespace mxnet
+
+#endif  // MXNET_USE_TENSORRT
+
+#endif  // MXNET_EXECUTOR_TRT_GRAPH_EXECUTOR_H_
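
The header above exposes the whole TensorRT path through one static factory plus two
protected overrides. As a rough sketch of that shape only (stand-in class names, not
the MXNet types), the pattern is:

    #include <iostream>
    #include <memory>

    struct Executor {                          // stand-in for mxnet::Executor
      virtual ~Executor() = default;
      virtual void Forward() = 0;
    };

    struct TrtExecutor : Executor {            // stand-in for exec::TrtGraphExecutor
      void Init() { std::cout << "partition graph, absorb weights into TRT engines\n"; }
      void Forward() override { std::cout << "run the optimized graph\n"; }

      // Mirrors the shape of TrtGraphExecutor::TensorRTBind: construct the derived
      // executor, run Init, hand it back through the base-class pointer.
      static Executor* Bind() {
        auto* exec = new TrtExecutor();
        exec->Init();
        return exec;
      }
    };

    int main() {
      std::unique_ptr<Executor> exec(TrtExecutor::Bind());
      exec->Forward();
      return 0;
    }
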
diff --git a/src/operator/contrib/nnvm_to_onnx-inl.h b/src/operator/contrib/nnvm_to_onnx-inl.h
new file mode 100644
index 00000000000..58f88b05143
--- /dev/null
+++ b/src/operator/contrib/nnvm_to_onnx-inl.h
@@ -0,0 +1,156 @@
+#ifndef MXNET_OPERATOR_CONTRIB_NNVM_TO_ONNX_INL_H_
+#define MXNET_OPERATOR_CONTRIB_NNVM_TO_ONNX_INL_H_
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2018 by Contributors
+ * \file nnvm_to_onnx-inl.h
+ * \brief Conversion of NNVM graphs to ONNX protobufs for the TensorRT operator
+ * \author Marek Kolodziej, Clement Fuji Tsang
+*/
+
+#if MXNET_USE_TENSORRT
+
+#include <dmlc/logging.h>
+#include <dmlc/memory_io.h>
+#include <dmlc/serializer.h>
+#include <dmlc/parameter.h>
+#include <mxnet/base.h>
+#include <mxnet/operator.h>
+#include <nnvm/graph.h>
+#include <nnvm/pass_functions.h>
+
+#include <NvInfer.h>
+#include <onnx/onnx.pb.h>
+
+#include <algorithm>
+#include <iostream>
+#include <map>
+#include <vector>
+#include <tuple>
+#include <unordered_map>
+#include <utility>
+#include <string>
+
+#include "./tensorrt-inl.h"
+#include "../operator_common.h"
+#include "../../common/utils.h"
+#include "../../common/serialization.h"
+
+namespace mxnet {
+namespace op {
+namespace nnvm_to_onnx {
+
+using namespace nnvm;
+using namespace ::onnx;
+using int64 = ::google::protobuf::int64;
+
+std::unordered_map<std::string, TShape> GetPlaceholderShapes(const ShapeVector& shape_inputs,
+    const nnvm::IndexedGraph& ig);
+
+std::unordered_map<std::string, uint32_t> GetOutputLookup(const nnvm::IndexedGraph& ig);
+
+void ConvertPlaceholder(
+  const std::string& node_name,
+  const std::unordered_map<std::string, TShape>& placeholder_shapes,
+  GraphProto* const graph_proto);
+
+void ConvertConstant(GraphProto* const graph_proto,
+  const std::string& node_name,
+  std::unordered_map<std::string, NDArray>* const shared_buffer);
+
+void ConvertOutput(op::tensorrt::InferenceMap_t* const trt_output_map,
+                   GraphProto* const graph_proto,
+                   const std::unordered_map<std::string, uint32_t>::iterator& out_iter,
+                   const std::string& node_name,
+                   const nnvm::Graph& g,
+                   const StorageTypeVector& storage_types,
+                   const DTypeVector& dtypes);
+
+typedef void (*ConverterFunction)(NodeProto *node_proto,
+                                  const NodeAttrs &attrs,
+                                  const nnvm::IndexedGraph &ig,
+                                  const array_view<IndexedGraph::NodeEntry> &inputs);
+
+
+// Forward declarations
+void ConvertConvolution(
+                        NodeProto *node_proto,
+                        const NodeAttrs &attrs,
+                        const nnvm::IndexedGraph &ig,
+                        const array_view<IndexedGraph::NodeEntry> &inputs);
+
+
+void ConvertPooling(NodeProto *node_proto,
+                    const NodeAttrs &attrs,
+                    const nnvm::IndexedGraph &ig,
+                    const array_view<IndexedGraph::NodeEntry> &inputs);
+
+void ConvertActivation(NodeProto *node_proto,
+                       const NodeAttrs &attrs,
+                       const nnvm::IndexedGraph &ig,
+                       const array_view<IndexedGraph::NodeEntry> &inputs);
+
+void ConvertFullyConnected(NodeProto *node_proto,
+                           const NodeAttrs &attrs,
+                           const nnvm::IndexedGraph &ig,
+                           const array_view<IndexedGraph::NodeEntry> &inputs);
+
+void ConvertSoftmaxOutput(NodeProto *node_proto,
+                          const NodeAttrs &attrs,
+                          const nnvm::IndexedGraph &ig,
+                          const array_view<IndexedGraph::NodeEntry> &inputs);
+
+void ConvertFlatten(NodeProto *node_proto,
+                    const NodeAttrs &attrs,
+                    const nnvm::IndexedGraph &ig,
+                    const array_view<IndexedGraph::NodeEntry> &inputs);
+
+void ConvertBatchNorm(NodeProto *node_proto,
+                    const NodeAttrs &attrs,
+                    const nnvm::IndexedGraph &ig,
+                    const array_view<IndexedGraph::NodeEntry> &inputs);
+
+void ConvertElementwiseAdd(NodeProto *node_proto,
+                    const NodeAttrs &attrs,
+                    const nnvm::IndexedGraph &ig,
+                    const array_view<IndexedGraph::NodeEntry> &inputs);
+
+TRTParam ConvertNnvmGraphToOnnx(
+    const nnvm::Graph &g,
+    std::unordered_map<std::string, NDArray> *const shared_buffer);
+
+static const std::unordered_map<std::string, ConverterFunction> converter_map = {
+  {"Convolution", ConvertConvolution},
+  {"Pooling", ConvertPooling},
+  {"Activation", ConvertActivation},
+  {"FullyConnected", ConvertFullyConnected},
+  {"SoftmaxOutput", ConvertSoftmaxOutput},
+  {"Flatten", ConvertFlatten},
+  {"BatchNorm", ConvertBatchNorm},
+  {"elemwise_add", ConvertElementwiseAdd}};
+
+}  // namespace nnvm_to_onnx
+}  // namespace op
+}  // namespace mxnet
+
+#endif  // MXNET_USE_TENSORRT
+
+#endif  // MXNET_OPERATOR_CONTRIB_NNVM_TO_ONNX_INL_H_
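
converter_map above is a plain dispatch table from MXNet op names to converter function
pointers. A minimal standalone sketch of that lookup-and-dispatch pattern (FakeNodeProto
and the two toy converters here are illustrative stand-ins, not the real ONNX types):

    #include <iostream>
    #include <stdexcept>
    #include <string>
    #include <unordered_map>

    struct FakeNodeProto { std::string op_type; };        // stand-in for onnx::NodeProto

    typedef void (*ConverterFunction)(FakeNodeProto*);

    void ConvertConvolution(FakeNodeProto* p) { p->op_type = "Conv"; }
    void ConvertPooling(FakeNodeProto* p)     { p->op_type = "MaxPool"; }

    static const std::unordered_map<std::string, ConverterFunction> converter_map = {
        {"Convolution", ConvertConvolution},
        {"Pooling", ConvertPooling}};

    int main() {
      FakeNodeProto proto;
      const std::string op_name = "Convolution";
      auto it = converter_map.find(op_name);
      if (it == converter_map.end())
        throw std::runtime_error("Conversion for op " + op_name + " is not supported yet.");
      it->second(&proto);                                  // dispatch to the converter
      std::cout << proto.op_type << std::endl;             // prints "Conv"
      return 0;
    }
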
diff --git a/src/operator/contrib/nnvm_to_onnx.cc b/src/operator/contrib/nnvm_to_onnx.cc
new file mode 100644
index 00000000000..902466614c7
--- /dev/null
+++ b/src/operator/contrib/nnvm_to_onnx.cc
@@ -0,0 +1,527 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2018 by Contributors
+ * \file nnvm_to_onnx.cc
+ * \brief Conversion of NNVM graphs to ONNX protobufs for the TensorRT operator
+ * \author Marek Kolodziej, Clement Fuji Tsang
+*/
+
+#if MXNET_USE_TENSORRT
+
+#include "./nnvm_to_onnx-inl.h"
+
+#include <mxnet/base.h>
+#include <nnvm/graph.h>
+#include <nnvm/pass_functions.h>
+
+#include <algorithm>
+#include <fstream>
+#include <iostream>
+#include <unordered_map>
+#include <vector>
+
+#include "../../common/serialization.h"
+#include "../../common/utils.h"
+#include "../../ndarray/ndarray_function.h"
+#include "../../operator/nn/activation-inl.h"
+#include "../../operator/nn/batch_norm-inl.h"
+#include "../../operator/nn/convolution-inl.h"
+#include "../../operator/nn/fully_connected-inl.h"
+#include "../../operator/nn/pooling-inl.h"
+#include "../../operator/softmax_output-inl.h"
+#include "./tensorrt-inl.h"
+
+#if MXNET_USE_TENSORRT_ONNX_CHECKER
+#include <onnx/checker.h>
+#endif  // MXNET_USE_TENSORRT_ONNX_CHECKER
+
+namespace mxnet {
+namespace op {
+namespace nnvm_to_onnx {
+
+op::TRTParam ConvertNnvmGraphToOnnx(
+    const nnvm::Graph& g,
+    std::unordered_map<std::string, NDArray>* const shared_buffer) {
+    op::TRTParam trt_param;
+    op::tensorrt::NameToIdx_t trt_input_map;
+    op::tensorrt::InferenceMap_t trt_output_map;
+
+  const nnvm::IndexedGraph& ig = g.indexed_graph();
+  const auto& storage_types = g.GetAttr<StorageTypeVector>("storage_type");
+  const auto& dtypes = g.GetAttr<DTypeVector>("dtype");
+  const auto& shape_inputs = g.GetAttr<ShapeVector>("shape_inputs");
+
+  for (auto& e : dtypes) {
+    if (e != mshadow::kFloat32) {
+      LOG(FATAL) << "ONNX converter does not support types other than float32 "
+                    "right now.";
+    }
+  }
+
+  ModelProto model_proto;
+  // Need to determine IR versions and features to support
+  model_proto.set_ir_version(static_cast<int64>(2));
+  GraphProto* graph_proto = model_proto.mutable_graph();
+
+  std::unordered_map<std::string, TShape> placeholder_shapes =
+      GetPlaceholderShapes(shape_inputs, ig);
+  std::unordered_map<std::string, uint32_t> output_lookup = GetOutputLookup(ig);
+  uint32_t current_input = 0;
+
+  // Can't do a foreach over IndexedGraph since it doesn't implement begin(), etc.
+  for (uint32_t node_idx = 0; node_idx < ig.num_nodes(); ++node_idx) {
+    const IndexedGraph::Node& node = ig[node_idx];
+    const nnvm::Node* source = node.source;
+    const NodeAttrs& attrs = source->attrs;
+    const Op* op = source->op();
+
+    std::string node_name = attrs.name;
+    // Here, "variable" actually means anything that's not an op i.e. a constant (weights) or a
+    // placeholder
+    if (source->is_variable()) {
+      // Is this a placeholder?
+      if (shared_buffer->count(node_name) == 0) {
+        // This fixes the problem with a SoftmaxOutput node during inference, but it's hacky.
+        // Need to figure out how to properly fix it.
+        if (node_name.find("label") != std::string::npos) {
+          current_input++;
+          continue;
+        }
+        trt_input_map.emplace(node_name, current_input++);
+        ConvertPlaceholder(node_name, placeholder_shapes, graph_proto);
+      } else {
+        // If it's not a placeholder, then by exclusion it's a constant.
+        ConvertConstant(graph_proto, node_name, shared_buffer);
+      }  // is_placeholder
+    } else {
+      // It's an op, rather than a "variable" (constant or placeholder)
+      NodeProto* node_proto = graph_proto->add_node();
+      node_proto->set_name(node_name);
+      if (converter_map.count(op->name) == 0) {
+        LOG(FATAL) << "Conversion for node of type " << op->name << " (node "
+                   << node_name << ") "
+                   << " is not supported yet.";
+      }
+      // Find function ptr to a converter based on the op name, and invoke the converter. This
+      // looks unsafe because find may not succeed, but it is guaranteed to succeed here
+      // because unsupported ops have already triggered LOG(FATAL) above.
+      converter_map.find(op->name)->second(node_proto, attrs, ig, node.inputs);
+      // Add all inputs to the current node (i.e. add graph edges)
+      for (const nnvm::IndexedGraph::NodeEntry& entry : node.inputs) {
+        std::string in_node_name = ig[entry.node_id].source->attrs.name;
+        // As before, we're not adding labels e.g. for SoftmaxOutput, but I wish there was a less
+        // hacky way to do it than name matching.
+        if (in_node_name.find("label") != std::string::npos) {
+          continue;
+        }
+        node_proto->add_input(in_node_name);
+      }
+      // The node's output will have the same name as the node name.
+      node_proto->add_output(node_name);
+      // See if the current node is an output node
+      auto out_iter = output_lookup.find(node_name);
+      // We found an output
+      if (out_iter != output_lookup.end()) {
+        ConvertOutput(&trt_output_map, graph_proto, out_iter, node_name, g,
+                      storage_types, dtypes);
+      }  // output found
+    }    // conversion function exists
+  }      // loop over i from 0 to num_nodes
+
+  model_proto.SerializeToString(&trt_param.serialized_onnx_graph);
+  common::Serialize<op::tensorrt::NameToIdx_t>(trt_input_map,
+                                          &trt_param.serialized_input_map);
+  common::Serialize<op::tensorrt::InferenceMap_t>(trt_output_map,
+                                             &trt_param.serialized_output_map);
+
+#if MXNET_USE_TENSORRT_ONNX_CHECKER
+  onnx::checker::check_model(model_proto);
+#endif  // MXNET_USE_TENSORRT_ONNX_CHECKER
+
+  return trt_param;
+}
+
+void ConvertConvolution(NodeProto* node_proto, const NodeAttrs& attrs,
+                        const nnvm::IndexedGraph& /*ig*/,
+                        const array_view<IndexedGraph::NodeEntry>& /*inputs*/) {
+  const auto& conv_param = nnvm::get<op::ConvolutionParam>(attrs.parsed);
+
+  node_proto->set_op_type("Conv");
+
+  const TShape kernel = conv_param.kernel;
+  const TShape stride = conv_param.stride;
+  const TShape dilate = conv_param.dilate;
+  const TShape pad = conv_param.pad;
+  const uint32_t num_group = conv_param.num_group;
+  // const bool no_bias = conv_param.no_bias;
+  const dmlc::optional<int> layout = conv_param.layout;
+
+  // kernel shape
+  AttributeProto* const kernel_shape = node_proto->add_attribute();
+  kernel_shape->set_name("kernel_shape");
+  kernel_shape->set_type(AttributeProto::INTS);
+
+  for (const dim_t kval : kernel) {
+    kernel_shape->add_ints(static_cast<int64>(kval));
+  }
+
+  // pads
+  AttributeProto* const pads = node_proto->add_attribute();
+  pads->set_name("pads");
+  pads->set_type(AttributeProto::INTS);
+
+  for (const dim_t kval : pad) {
+    pads->add_ints(static_cast<int64>(kval));
+    pads->add_ints(static_cast<int64>(kval));
+  }
+
+  // dilations
+  AttributeProto* const dilations = node_proto->add_attribute();
+  dilations->set_name("dilations");
+  dilations->set_type(AttributeProto::INTS);
+  for (const dim_t kval : dilate) {
+    dilations->add_ints(static_cast<int64>(kval));
+  }
+
+  // strides
+  AttributeProto* const strides = node_proto->add_attribute();
+  strides->set_name("strides");
+  strides->set_type(AttributeProto::INTS);
+  for (const dim_t kval : stride) {
+    strides->add_ints(static_cast<int64>(kval));
+  }
+
+  // group
+  AttributeProto* const group = node_proto->add_attribute();
+  group->set_name("group");
+  group->set_type(AttributeProto::INT);
+  group->set_i(static_cast<int64>(num_group));
+}  // end ConvertConvolution
+
+void ConvertPooling(NodeProto* node_proto, const NodeAttrs& attrs,
+                    const nnvm::IndexedGraph& /*ig*/,
+                    const array_view<IndexedGraph::NodeEntry>& /*inputs*/) {
+  const auto& pooling_param = nnvm::get<op::PoolingParam>(attrs.parsed);
+
+  const TShape kernel = pooling_param.kernel;
+  const TShape stride = pooling_param.stride;
+  const TShape pad = pooling_param.pad;
+  const int pool_type = pooling_param.pool_type;
+  const bool global_pool = pooling_param.global_pool;
+
+  if (global_pool) {
+    if (pool_type == 0) {
+      node_proto->set_op_type("GlobalMaxPool");
+    } else {
+      node_proto->set_op_type("GlobalAveragePool");
+    }
+    return;
+  }
+
+  // kernel_shape
+  AttributeProto* const kernel_shape = node_proto->add_attribute();
+  kernel_shape->set_name("kernel_shape");
+  kernel_shape->set_type(AttributeProto::INTS);
+  for (const dim_t kval : kernel) {
+    kernel_shape->add_ints(static_cast<int64>(kval));
+  }
+
+  // pads
+  AttributeProto* const pads = node_proto->add_attribute();
+  pads->set_name("pads");
+  pads->set_type(AttributeProto::INTS);
+  for (const dim_t kval : pad) {
+    pads->add_ints(static_cast<int64>(kval));
+  }
+
+  // strides
+  AttributeProto* const strides = node_proto->add_attribute();
+  strides->set_name("strides");
+  strides->set_type(AttributeProto::INTS);
+  for (const dim_t kval : stride) {
+    strides->add_ints(static_cast<int64>(kval));
+  }
+
+  if (pool_type == 0) {
+    node_proto->set_op_type("MaxPool");
+  } else {
+    node_proto->set_op_type("AveragePool");
+  }  // average pooling
+  // not global pooling
+}  // end ConvertPooling
+
+void ConvertActivation(NodeProto* node_proto, const NodeAttrs& attrs,
+                       const nnvm::IndexedGraph& /*ig*/,
+                       const array_view<IndexedGraph::NodeEntry>& /*inputs*/) {
+  const auto& act_param = nnvm::get<op::ActivationParam>(attrs.parsed);
+  std::string act_type;
+  switch (act_param.act_type) {
+    case op::activation::kReLU:
+      act_type = "Relu";
+      break;
+    case op::activation::kSigmoid:
+      act_type = "Sigmoid";
+      break;
+    case op::activation::kTanh:
+      act_type = "Tanh";
+      break;
+    case op::activation::kSoftReLU:
+      // act_type = "SoftReLU";
+      throw dmlc::Error("SoftReLU is not supported in ONNX");
+      break;
+    default:
+      throw dmlc::Error("Activation of such type doesn't exist");
+  }
+
+  node_proto->set_op_type(act_type);
+}
+
+void ConvertFullyConnected(NodeProto* node_proto, const NodeAttrs& attrs,
+                           const nnvm::IndexedGraph& /*ig*/,
+                           const array_view<IndexedGraph::NodeEntry>& /*inputs*/) {
+  const auto& act_param = nnvm::get<op::FullyConnectedParam>(attrs.parsed);
+  if (act_param.no_bias) {
+      node_proto->set_op_type("MatMul");
+  } else {
+      node_proto->set_op_type("Gemm");
+
+      AttributeProto* const alpha = node_proto->add_attribute();
+      alpha->set_name("alpha");
+      alpha->set_type(AttributeProto::FLOAT);
+      alpha->set_f(1.0f);
+
+      AttributeProto* const beta = node_proto->add_attribute();
+      beta->set_name("beta");
+      beta->set_type(AttributeProto::FLOAT);
+      beta->set_f(1.0f);
+
+      AttributeProto* const broadcast = node_proto->add_attribute();
+      broadcast->set_name("broadcast");
+      broadcast->set_type(AttributeProto::INT);
+      broadcast->set_i(1);
+
+      AttributeProto* const transA = node_proto->add_attribute();
+      transA->set_name("transA");
+      transA->set_type(AttributeProto::INT);
+      transA->set_i(0);
+
+      AttributeProto* const transB = node_proto->add_attribute();
+      transB->set_name("transB");
+      transB->set_type(AttributeProto::INT);
+      transB->set_i(1);
+  }
+}
+
+void ConvertSoftmaxOutput(NodeProto* node_proto, const NodeAttrs& /*attrs*/,
+                          const nnvm::IndexedGraph& /*ig*/,
+                          const array_view<IndexedGraph::NodeEntry>& /*inputs*/) {
+  node_proto->set_op_type("Softmax");
+
+  // Setting by default to 1 since MXNet doesn't provide such an attribute for softmax in its
+  // node params. This attribute is only relevant when the input is coerced to 2D, and in that
+  // case dimension 0 is assumed to be the batch dimension.
+  AttributeProto* const axis = node_proto->add_attribute();
+  axis->set_name("axis");
+  axis->set_type(AttributeProto::INT);
+  axis->set_i(1);
+}
+
+void ConvertFlatten(NodeProto* node_proto, const NodeAttrs& /*attrs*/,
+                    const nnvm::IndexedGraph& /*ig*/,
+                    const array_view<IndexedGraph::NodeEntry>& /*inputs*/) {
+  node_proto->set_op_type("Flatten");
+
+  // Setting by default to 1 since MXNet doesn't provide such an attribute for Flatten in its
+  // node params. This attribute is only relevant when the input is coerced to 2D, and in that
+  // case dimension 0 is assumed to be the batch dimension.
+  AttributeProto* const axis = node_proto->add_attribute();
+  axis->set_name("axis");
+  axis->set_type(AttributeProto::INT);
+  axis->set_i(1);
+}
+
+void ConvertBatchNorm(NodeProto* node_proto, const NodeAttrs& attrs,
+                      const nnvm::IndexedGraph& /*ig*/,
+                      const array_view<IndexedGraph::NodeEntry>& /*inputs*/) {
+  node_proto->set_op_type("BatchNormalization");
+  const auto& param = nnvm::get<op::BatchNormParam>(attrs.parsed);
+
+  AttributeProto* const epsilon = node_proto->add_attribute();
+  epsilon->set_name("epsilon");
+  epsilon->set_type(AttributeProto::FLOAT);
+  epsilon->set_f(static_cast<float>(param.eps));
+
+  AttributeProto* const is_test = node_proto->add_attribute();
+  is_test->set_name("is_test");
+  is_test->set_type(AttributeProto::INT);
+  is_test->set_i(1);
+
+  AttributeProto* const momentum = node_proto->add_attribute();
+  momentum->set_name("momentum");
+  momentum->set_type(AttributeProto::FLOAT);
+  momentum->set_f(param.momentum);
+
+  AttributeProto* const spatial = node_proto->add_attribute();
+  spatial->set_name("spatial");
+  spatial->set_type(AttributeProto::INT);
+  spatial->set_i(1);
+
+  AttributeProto* const consumed = node_proto->add_attribute();
+  consumed->set_name("consumed_inputs");
+  consumed->set_type(AttributeProto::INTS);
+
+  for (int i = 0; i < 5; i++) {
+    int val = (i < 3) ? 0 : 1;
+    consumed->add_ints(static_cast<int64>(val));
+  }
+}
+
+void ConvertElementwiseAdd(NodeProto* node_proto, const NodeAttrs& /*attrs*/,
+                           const nnvm::IndexedGraph& /*ig*/,
+                           const array_view<IndexedGraph::NodeEntry>& /*inputs*/) {
+  node_proto->set_op_type("Add");
+  AttributeProto* const axis = node_proto->add_attribute();
+  axis->set_name("axis");
+  axis->set_type(AttributeProto::INT);
+  axis->set_i(1);
+
+  AttributeProto* const broadcast = node_proto->add_attribute();
+  broadcast->set_name("broadcast");
+  broadcast->set_type(AttributeProto::INT);
+  broadcast->set_i(0);  // 1
+}
+
+std::unordered_map<std::string, TShape> GetPlaceholderShapes(
+    const ShapeVector& shape_inputs, const nnvm::IndexedGraph& ig) {
+  std::unordered_map<std::string, TShape> placeholder_shapes;
+  for (uint32_t i = 0; i < shape_inputs.size(); ++i) {
+    std::string name = ig[ig.input_nodes()[i]].source->attrs.name;
+    TShape shp = shape_inputs[i];
+    if (shp.ndim() > 0) {
+      placeholder_shapes.emplace(name, shp);
+    }
+  }
+  return placeholder_shapes;
+}
+
+std::unordered_map<std::string, uint32_t> GetOutputLookup(
+    const nnvm::IndexedGraph& ig) {
+  std::unordered_map<std::string, uint32_t> output_lookup;
+  const std::vector<nnvm::IndexedGraph::NodeEntry>& graph_outputs =
+      ig.outputs();
+  for (uint32_t i = 0; i < graph_outputs.size(); ++i) {
+    const uint32_t id = graph_outputs[i].node_id;
+    const IndexedGraph::Node ig_node = ig[id];
+    const nnvm::Node* const source = ig_node.source;
+    const std::string name = source->attrs.name;
+    output_lookup.emplace(name, i);
+  }
+  return output_lookup;
+}
+
+void ConvertPlaceholder(
+    const std::string& node_name,
+    const std::unordered_map<std::string, TShape>& placeholder_shapes,
+    GraphProto* const graph_proto) {
+  auto val_info_proto = graph_proto->add_input();
+  auto type_proto = val_info_proto->mutable_type()->mutable_tensor_type();
+  auto shape_proto = type_proto->mutable_shape();
+
+  val_info_proto->set_name(node_name);
+  // Will support fp16, etc. in the near future
+  type_proto->set_elem_type(TensorProto_DataType_FLOAT);
+  auto entry_shape = placeholder_shapes.find(node_name)->second;
+
+  for (const auto& elem : entry_shape) {
+    TensorShapeProto_Dimension* const tsp_dim = shape_proto->add_dim();
+    tsp_dim->set_dim_value(static_cast<int64>(elem));
+  }
+}
+
+void ConvertConstant(
+    GraphProto* const graph_proto, const std::string& node_name,
+    std::unordered_map<std::string, NDArray>* const shared_buffer) {
+  NodeProto* const node_proto = graph_proto->add_node();
+  node_proto->set_name(node_name);
+  node_proto->add_output(node_name);
+  node_proto->set_op_type("Constant");
+
+  const NDArray nd = shared_buffer->find(node_name)->second;
+  const TBlob& blob = nd.data();
+  const TShape shape = blob.shape_;
+  const int32_t size = shape.Size();
+
+  std::shared_ptr<float> shared_data_ptr(new float[size]);
+  float* const data_ptr = shared_data_ptr.get();
+  nd.SyncCopyToCPU(static_cast<void*>(data_ptr), size);
+
+  AttributeProto* const tensor_attr = node_proto->add_attribute();
+  tensor_attr->set_name("value");
+  tensor_attr->set_type(AttributeProto::TENSOR);
+
+  TensorProto* const tensor_proto = tensor_attr->mutable_t();
+  tensor_proto->set_data_type(TensorProto_DataType_FLOAT);
+  for (auto& dim : shape) {
+    tensor_proto->add_dims(static_cast<int64>(dim));
+  }
+
+  for (int blob_idx = 0; blob_idx < size; ++blob_idx) {
+    tensor_proto->add_float_data(data_ptr[blob_idx]);
+  }
+}
+
+void ConvertOutput(
+    op::tensorrt::InferenceMap_t* const trt_output_map,
+    GraphProto* const graph_proto,
+    const std::unordered_map<std::string, uint32_t>::iterator& out_iter,
+    const std::string& node_name, const nnvm::Graph& g,
+    const StorageTypeVector& storage_types, const DTypeVector& dtypes) {
+  const nnvm::IndexedGraph& ig = g.indexed_graph();
+  uint32_t out_idx = ig.entry_id(ig.outputs()[out_iter->second]);
+  TShape out_shape = g.GetAttr<nnvm::ShapeVector>("shape")[out_idx];
+  int storage_type = storage_types[out_idx];
+  int dtype = dtypes[out_idx];
+
+  // This should work with fp16 as well
+  op::tensorrt::InferenceTuple_t out_tuple{out_iter->second, out_shape, storage_type,
+                                      dtype};
+
+  trt_output_map->emplace(node_name, out_tuple);
+
+  auto graph_out = graph_proto->add_output();
+  auto tensor_type = graph_out->mutable_type()->mutable_tensor_type();
+  auto tensor_shape_proto = tensor_type->mutable_shape();
+  graph_out->set_name(node_name);
+
+  // Also support fp16.
+  tensor_type->set_elem_type(TensorProto_DataType_FLOAT);
+
+  for (int64_t dim_shp : out_shape) {
+    TensorShapeProto_Dimension* const tsp_dim = tensor_shape_proto->add_dim();
+    tsp_dim->set_dim_value(static_cast<int64>(dim_shp));
+  }
+}
+
+}  // namespace nnvm_to_onnx
+}  // namespace op
+}  // namespace mxnet
+
+#endif  // MXNET_USE_TENSORRT
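
The core of ConvertNnvmGraphToOnnx is the node walk: variables become either ONNX graph
inputs (placeholders) or Constant nodes depending on whether their weights sit in
shared_buffer, and every other node is dispatched through converter_map. A standalone
sketch of just that classification (toy node type and invented names, not the real
IndexedGraph API):

    #include <iostream>
    #include <string>
    #include <unordered_map>
    #include <vector>

    struct ToyNode { std::string name; bool is_variable; };  // stand-in for an IndexedGraph node

    int main() {
      // "data" has no entry in shared_buffer, so it becomes a graph input (placeholder);
      // "conv_weight" has weights in shared_buffer, so it becomes an ONNX Constant node;
      // "conv0" is an op and would be dispatched through converter_map.
      std::vector<ToyNode> nodes = {{"data", true}, {"conv_weight", true}, {"conv0", false}};
      std::unordered_map<std::string, int> shared_buffer = {{"conv_weight", 0}};

      for (const ToyNode& n : nodes) {
        if (n.is_variable) {
          if (shared_buffer.count(n.name) == 0)
            std::cout << n.name << " -> placeholder (ConvertPlaceholder)\n";
          else
            std::cout << n.name << " -> constant (ConvertConstant)\n";
        } else {
          std::cout << n.name << " -> op (converter_map dispatch)\n";
        }
      }
      return 0;
    }
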
diff --git a/src/operator/contrib/tensorrt-inl.h b/src/operator/contrib/tensorrt-inl.h
new file mode 100644
index 00000000000..be335ab1208
--- /dev/null
+++ b/src/operator/contrib/tensorrt-inl.h
@@ -0,0 +1,113 @@
+#ifndef MXNET_OPERATOR_CONTRIB_TENSORRT_INL_H_
+#define MXNET_OPERATOR_CONTRIB_TENSORRT_INL_H_
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2018 by Contributors
+ * \file tensorrt-inl.h
+ * \brief TensorRT Operator
+ * \author Marek Kolodziej, Clement Fuji Tsang
+*/
+
+#if MXNET_USE_TENSORRT
+
+#include <dmlc/logging.h>
+#include <dmlc/memory_io.h>
+#include <dmlc/serializer.h>
+#include <dmlc/parameter.h>
+#include <mxnet/base.h>
+#include <mxnet/operator.h>
+#include <nnvm/graph.h>
+#include <nnvm/pass_functions.h>
+
+#include <NvInfer.h>
+#include <onnx/onnx.pb.h>
+
+#include <algorithm>
+#include <iostream>
+#include <map>
+#include <vector>
+#include <tuple>
+#include <unordered_map>
+#include <utility>
+#include <string>
+
+#include "../operator_common.h"
+#include "../../common/utils.h"
+#include "../../common/serialization.h"
+#include "../../executor/exec_pass.h"
+#include "../../executor/graph_executor.h"
+#include "../../executor/onnx_to_tensorrt.h"
+
+namespace mxnet {
+namespace op {
+
+using namespace nnvm;
+using namespace ::onnx;
+using int64 = ::google::protobuf::int64;
+
+namespace tensorrt {
+  enum class TypeIO { Inputs = 0, Outputs = 1 };
+  using NameToIdx_t = std::map<std::string, int32_t>;
+  using InferenceTuple_t = std::tuple<uint32_t, TShape, int, int>;
+  using InferenceMap_t = std::map<std::string, InferenceTuple_t>;
+}  // namespace tensorrt
+
+using trt_name_to_idx = std::map<std::string, uint32_t>;
+
+struct TRTParam : public dmlc::Parameter<TRTParam> {
+  std::string serialized_onnx_graph;
+  std::string serialized_input_map;
+  std::string serialized_output_map;
+  tensorrt::NameToIdx_t input_map;
+  tensorrt::InferenceMap_t output_map;
+  ::onnx::ModelProto onnx_pb_graph;
+
+  TRTParam() {}
+
+  TRTParam(const ::onnx::ModelProto& onnx_graph,
+           const tensorrt::NameToIdx_t& input_map,
+           const tensorrt::InferenceMap_t& output_map) {
+    common::Serialize(input_map, &serialized_input_map);
+    common::Serialize(output_map, &serialized_output_map);
+    onnx_graph.SerializeToString(&serialized_onnx_graph);
+  }
+
+  DMLC_DECLARE_PARAMETER(TRTParam) {
+    DMLC_DECLARE_FIELD(serialized_onnx_graph)
+    .describe("Serialized ONNX graph");
+    DMLC_DECLARE_FIELD(serialized_input_map)
+    .describe("Map from inputs to topological order as input.");
+    DMLC_DECLARE_FIELD(serialized_output_map)
+    .describe("Map from outputs to order in g.outputs.");
+  }
+};
+
+struct TRTEngineParam {
+  nvinfer1::IExecutionContext* trt_executor;
+  std::vector<std::pair<uint32_t, tensorrt::TypeIO> > binding_map;
+};
+
+}  // namespace op
+}  // namespace mxnet
+
+#endif  // MXNET_USE_TENSORRT
+
+#endif  // MXNET_OPERATOR_CONTRIB_TENSORRT_INL_H_
diff --git a/src/operator/contrib/tensorrt.cc b/src/operator/contrib/tensorrt.cc
new file mode 100644
index 00000000000..619fe1e2b8f
--- /dev/null
+++ b/src/operator/contrib/tensorrt.cc
@@ -0,0 +1,183 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2018 by Contributors
+ * \file tensorrt.cc
+ * \brief TensorRT operation registration
+ * \author Marek Kolodziej, Clement Fuji Tsang
+*/
+
+#if MXNET_USE_TENSORRT
+
+#include "./tensorrt-inl.h"
+
+#include <mxnet/base.h>
+#include <nnvm/graph.h>
+#include <nnvm/pass_functions.h>
+
+#include <algorithm>
+#include <fstream>
+#include <iostream>
+#include <unordered_map>
+#include <vector>
+
+#include "../../common/serialization.h"
+#include "../../common/utils.h"
+
+namespace mxnet {
+namespace op {
+
+DMLC_REGISTER_PARAMETER(TRTParam);
+
+OpStatePtr GetPtrMapping(nvinfer1::ICudaEngine* trt_engine,
+                         tensorrt::NameToIdx_t input_map,
+                         tensorrt::NameToIdx_t output_map) {
+  TRTEngineParam param;
+  for (int b = 0; b < trt_engine->getNbBindings(); ++b) {
+    const std::string& binding_name = trt_engine->getBindingName(b);
+    if (trt_engine->bindingIsInput(b)) {
+      param.binding_map.emplace_back(input_map[binding_name],
+                                     tensorrt::TypeIO::Inputs);
+    } else {
+      param.binding_map.emplace_back(output_map[binding_name],
+                                     tensorrt::TypeIO::Outputs);
+    }
+  }
+  param.trt_executor = trt_engine->createExecutionContext();
+  return OpStatePtr::Create<TRTEngineParam>(param);
+}
+
+OpStatePtr TRTCreateState(const nnvm::NodeAttrs& attrs, Context /*ctx*/,
+                          const std::vector<TShape>& /*ishape*/,
+                          const std::vector<int>& /*itype*/) {
+  const auto& node_param = nnvm::get<TRTParam>(attrs.parsed);
+
+  ::onnx::ModelProto model_proto;
+  bool success = model_proto.ParseFromString(node_param.serialized_onnx_graph);
+  if (!success) {
+    LOG(FATAL) << "Problems parsing serialized ONNX model.";
+  }
+  auto graph = model_proto.graph();
+  auto first_input_type = graph.input(0).type().tensor_type();
+  auto dim_value = first_input_type.shape().dim(0).dim_value();
+  auto batch_size = static_cast<int32_t>(dim_value);
+  // Need to set up max workspace size based on device properties
+  nvinfer1::ICudaEngine* const trt_engine = ::onnx_to_tensorrt::onnxToTrtCtx(
+      node_param.serialized_onnx_graph, batch_size, 1 << 30);
+
+  tensorrt::NameToIdx_t output_map;
+  for (auto& el : node_param.output_map) {
+    output_map[el.first] = std::get<0>(el.second);
+  }
+  return GetPtrMapping(trt_engine, node_param.input_map, output_map);
+}
+
+void TRTParamParser(nnvm::NodeAttrs* attrs) {
+  TRTParam param_;
+
+  try {
+    param_.Init(attrs->dict);
+    common::Deserialize(&param_.input_map, param_.serialized_input_map);
+    common::Deserialize(&param_.output_map, param_.serialized_output_map);
+    param_.onnx_pb_graph.ParseFromString(param_.serialized_onnx_graph);
+  } catch (const dmlc::ParamError& e) {
+    std::ostringstream os;
+    os << e.what();
+    os << ", in operator " << attrs->op->name << "("
+       << "name=\"" << attrs->name << "\"";
+    for (const auto& k : attrs->dict) {
+      os << ", " << k.first << "=\"" << k.second << "\"";
+    }
+    os << ")";
+    throw dmlc::ParamError(os.str());
+  }
+
+  attrs->parsed = std::move(param_);
+}
+
+inline bool TRTInferShape(const NodeAttrs& attrs, std::vector<TShape>* /*in_shape*/,
+                          std::vector<TShape>* out_shape) {
+  const auto &node_param = nnvm::get<TRTParam>(attrs.parsed);
+  for (auto& el : node_param.output_map) {
+    (*out_shape)[std::get<0>(el.second)] = std::get<1>(el.second);
+  }
+  return true;
+}
+
+inline bool TRTInferStorageType(const NodeAttrs& /*attrs*/, const int /*dev_mask*/,
+                                DispatchMode* dispatch_mode,
+                                std::vector<int>* /*in_storage_type*/,
+                                std::vector<int>* out_storage_type) {
+  return storage_type_assign(out_storage_type, mxnet::kDefaultStorage,
+                             dispatch_mode, DispatchMode::kFCompute);
+}
+
+inline bool TRTInferType(const NodeAttrs& attrs, std::vector<int>* /*in_dtype*/,
+                         std::vector<int>* out_dtype) {
+  const auto& node_param = nnvm::get<TRTParam>(attrs.parsed);
+  for (auto& el : node_param.output_map) {
+    (*out_dtype)[std::get<0>(el.second)] = std::get<3>(el.second);
+  }
+  return true;
+}
+
+inline std::vector<std::string> TRTListInputNames(const NodeAttrs& attrs) {
+  std::vector<std::string> output;
+  const auto& node_param = nnvm::get<TRTParam>(attrs.parsed);
+  output.resize(node_param.input_map.size());
+  for (auto& el : node_param.input_map) {
+    output[el.second] = el.first;
+  }
+  return output;
+}
+
+inline std::vector<std::string> TRTListOutputNames(const NodeAttrs& attrs) {
+  std::vector<std::string> output;
+  const auto& node_param = nnvm::get<TRTParam>(attrs.parsed);
+  output.resize(node_param.output_map.size());
+  for (auto& el : node_param.output_map) {
+    output[std::get<0>(el.second)] = el.first;
+  }
+  return output;
+}
+
+NNVM_REGISTER_OP(_trt_op)
+    .describe(R"code(TRT operation (one engine)
+)code" ADD_FILELINE)
+    .set_num_inputs([](const NodeAttrs& attrs) {
+      const auto& node_param = nnvm::get<TRTParam>(attrs.parsed);
+      return node_param.input_map.size();
+    })
+    .set_num_outputs([](const NodeAttrs& attrs) {
+      const auto& node_param = nnvm::get<TRTParam>(attrs.parsed);
+      return node_param.output_map.size();
+    })
+    .set_attr_parser(TRTParamParser)
+    .set_attr<nnvm::FInferShape>("FInferShape", TRTInferShape)
+    .set_attr<nnvm::FInferType>("FInferType", TRTInferType)
+    .set_attr<nnvm::FListInputNames>("FListInputNames", TRTListInputNames)
+    .set_attr<nnvm::FListOutputNames>("FListOutputNames", TRTListOutputNames)
+    .set_attr<FCreateOpState>("FCreateOpState", TRTCreateState)
+    .set_attr<FInferStorageType>("FInferStorageType", TRTInferStorageType);
+
+}  // namespace op
+}  // namespace mxnet
+
+#endif  // MXNET_USE_TENSORRT
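
TRTInferShape and TRTInferType above simply copy what ConvertNnvmGraphToOnnx recorded
per output into the out_shape/out_dtype vectors, indexed by the first element of each
InferenceTuple_t. A standalone sketch of that fill (toy Shape type instead of TShape;
the sample values are invented):

    #include <cstdint>
    #include <iostream>
    #include <map>
    #include <string>
    #include <tuple>
    #include <vector>

    using Shape = std::vector<int64_t>;                       // stand-in for TShape
    // (output index, shape, storage type, dtype) -- mirrors tensorrt::InferenceTuple_t
    using InferenceTuple = std::tuple<uint32_t, Shape, int, int>;

    int main() {
      std::map<std::string, InferenceTuple> output_map;
      output_map.emplace("softmax", InferenceTuple(0, Shape{128, 10}, 0, 0));

      std::vector<Shape> out_shape(output_map.size());
      std::vector<int> out_dtype(output_map.size());
      for (const auto& el : output_map) {
        out_shape[std::get<0>(el.second)] = std::get<1>(el.second);   // as in TRTInferShape
        out_dtype[std::get<0>(el.second)] = std::get<3>(el.second);   // as in TRTInferType
      }
      std::cout << "output 0 has " << out_shape[0].size() << " dims, dtype "
                << out_dtype[0] << std::endl;
      return 0;
    }
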
diff --git a/src/operator/contrib/tensorrt.cu b/src/operator/contrib/tensorrt.cu
new file mode 100644
index 00000000000..2fe8727b73e
--- /dev/null
+++ b/src/operator/contrib/tensorrt.cu
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2018 by Contributors
+ * \file tensorrt.cu
+ * \brief TensorRT GPU operation
+ * \author Marek Kolodziej, Clement Fuji Tsang
+*/
+
+#if MXNET_USE_TENSORRT
+
+#include "./tensorrt-inl.h"
+
+namespace mxnet {
+namespace op {
+
+#define CHECK_CUDART(x) do { \
+  cudaError_t res = (x); \
+  if (res != cudaSuccess) { \
+    fprintf(stderr, "CUDART: %s = %d (%s) at (%s:%d)\n", \
+      #x, res, cudaGetErrorString(res), __FILE__, __LINE__); \
+    exit(1); \
+  } \
+} while (0)
+
+void TRTCompute(const OpStatePtr& state, const OpContext& ctx,
+                     const std::vector<TBlob>& inputs, const std::vector<OpReqType>& req,
+                     const std::vector<TBlob>& outputs) {
+  using namespace mshadow;
+  using namespace mshadow::expr;
+
+  Stream<gpu>* s = ctx.get_stream<gpu>();
+  cudaStream_t cuda_s = Stream<gpu>::GetStream(s);
+  const auto& param = state.get_state<TRTEngineParam>();
+  std::vector<void*> bindings;
+  bindings.reserve(param.binding_map.size());
+  for (auto& p : param.binding_map) {
+    if (p.second == tensorrt::TypeIO::Inputs) {
+      bindings.emplace_back(inputs[p.first].dptr_);
+    } else {
+      bindings.emplace_back(outputs[p.first].dptr_);
+    }
+  }
+
+  const int batch_size = static_cast<int>(inputs[0].shape_[0]);
+  param.trt_executor->enqueue(batch_size, bindings.data(), cuda_s, nullptr);
+  CHECK_CUDART(cudaStreamSynchronize(cuda_s));
+}
+
+NNVM_REGISTER_OP(_trt_op)
+.set_attr<FStatefulCompute>("FStatefulCompute<gpu>", TRTCompute);
+
+}  // namespace op
+}  // namespace mxnet
+
+#endif  // MXNET_USE_TENSORRT
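
TRTCompute must hand TensorRT the input/output device pointers in the engine's binding
order, not in MXNet's input/output order. The sketch below (standalone, no TensorRT
dependency; pointer values and the sample binding order are fakes) shows how the
binding_map built in GetPtrMapping is used to lay the pointers out:

    #include <cstdint>
    #include <iostream>
    #include <utility>
    #include <vector>

    enum class TypeIO { Inputs = 0, Outputs = 1 };            // mirrors tensorrt::TypeIO

    int main() {
      // Pretend device pointers for two inputs and one output of the fused op.
      float in0 = 0.f, in1 = 0.f, out0 = 0.f;
      std::vector<void*> input_ptrs  = {&in0, &in1};
      std::vector<void*> output_ptrs = {&out0};

      // One entry per engine binding slot, in the order the engine reports them
      // (this is what GetPtrMapping derives from bindingIsInput/getBindingName).
      std::vector<std::pair<uint32_t, TypeIO>> binding_map = {
          {1, TypeIO::Inputs}, {0, TypeIO::Outputs}, {0, TypeIO::Inputs}};

      std::vector<void*> bindings;
      bindings.reserve(binding_map.size());
      for (const auto& p : binding_map) {
        bindings.push_back(p.second == TypeIO::Inputs ? input_ptrs[p.first]
                                                      : output_ptrs[p.first]);
      }
      // bindings.data() is what TRTCompute hands to IExecutionContext::enqueue.
      std::cout << "assembled " << bindings.size() << " bindings" << std::endl;
      return 0;
    }
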
diff --git a/tests/.gitignore b/tests/.gitignore
index d6459089c24..3e5eed695f0 100644
--- a/tests/.gitignore
+++ b/tests/.gitignore
@@ -1 +1,2 @@
 *_unittest
+*.gz
diff --git a/tests/cpp/misc/serialization.cc b/tests/cpp/misc/serialization.cc
new file mode 100644
index 00000000000..96f8b6c3a3a
--- /dev/null
+++ b/tests/cpp/misc/serialization.cc
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <gtest/gtest.h>
+#include <../../../src/common/serialization.h>
+
+using namespace mxnet;
+using namespace std;
+
+/*
+ * Test that the data structures used are properly serialized and deserialized.
+ */
+
+TEST(SerializerTest, InputMapCorrect) {
+    std::map<std::string, int32_t> input_map;
+    input_map.emplace("input_0", 2);
+    input_map.emplace("another_input", 0);
+    input_map.emplace("last_input", 1);
+    std::string serialized_data;
+    common::Serialize(input_map, &serialized_data);
+    std::map<std::string, int32_t> deserialized_input_map;
+    common::Deserialize(&deserialized_input_map, serialized_data);
+    ASSERT_EQ(input_map.size(), deserialized_input_map.size());
+    for (auto& p : input_map) {
+        auto it = deserialized_input_map.find(p.first);
+        ASSERT_NE(it, deserialized_input_map.end());
+        ASSERT_EQ(it->second, p.second);
+    }
+}
+
+TEST(SerializerTest, OutputMapCorrect) {
+    std::map<std::string, std::tuple<uint32_t, TShape, int, int> > output_map;
+    output_map.emplace("output_0", std::make_tuple(1, TShape({23, 12, 63, 432}), 0, 1));
+    output_map.emplace("another_output", std::make_tuple(2, TShape({23, 123}), 14, -23));
+    output_map.emplace("last_output", std::make_tuple(0, TShape({0}), -1, 0));
+    std::string serialized_data;
+    common::Serialize(output_map, &serialized_data);
+    std::map<std::string, std::tuple<uint32_t, TShape, int, int> > deserialized_output_map;
+    common::Deserialize(&deserialized_output_map, serialized_data);
+    ASSERT_EQ(output_map.size(), deserialized_output_map.size());
+    for (auto& p : output_map) {
+        auto it = deserialized_output_map.find(p.first);
+        ASSERT_NE(it, deserialized_output_map.end());
+        auto lhs = it->second;
+        auto rhs = p.second;
+        ASSERT_EQ(std::get<0>(lhs), std::get<0>(rhs));
+        ASSERT_EQ(std::get<1>(lhs), std::get<1>(rhs));
+        ASSERT_EQ(std::get<2>(lhs), std::get<2>(rhs));
+        ASSERT_EQ(std::get<3>(lhs), std::get<3>(rhs));
+    }
+}
+
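
The tests above only exercise the round-trip property that Serialize/Deserialize must
satisfy; the actual implementation lives in src/common/serialization.h and is not part
of this diff. Purely to illustrate that property, here is a toy length-prefixed encoder
for one of the tested types (this is NOT the mxnet::common implementation):

    #include <cassert>
    #include <cstdint>
    #include <cstring>
    #include <map>
    #include <string>

    // Length-prefixed (key, value) records; purely illustrative.
    std::string Serialize(const std::map<std::string, int32_t>& m) {
      std::string out;
      for (const auto& p : m) {
        uint32_t len = static_cast<uint32_t>(p.first.size());
        out.append(reinterpret_cast<const char*>(&len), sizeof(len));
        out.append(p.first);
        out.append(reinterpret_cast<const char*>(&p.second), sizeof(p.second));
      }
      return out;
    }

    std::map<std::string, int32_t> Deserialize(const std::string& buf) {
      std::map<std::string, int32_t> m;
      size_t pos = 0;
      while (pos < buf.size()) {
        uint32_t len = 0;
        std::memcpy(&len, buf.data() + pos, sizeof(len));   pos += sizeof(len);
        std::string key(buf.data() + pos, len);             pos += len;
        int32_t val = 0;
        std::memcpy(&val, buf.data() + pos, sizeof(val));   pos += sizeof(val);
        m.emplace(std::move(key), val);
      }
      return m;
    }

    int main() {
      std::map<std::string, int32_t> input_map = {
          {"input_0", 2}, {"another_input", 0}, {"last_input", 1}};
      assert(Deserialize(Serialize(input_map)) == input_map);
      return 0;
    }
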
diff --git a/tests/python/tensorrt/common.py b/tests/python/tensorrt/common.py
new file mode 100644
index 00000000000..eb599f69973
--- /dev/null
+++ b/tests/python/tensorrt/common.py
@@ -0,0 +1,39 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import os
+from ctypes.util import find_library
+
+
+def check_tensorrt_installation():
+    assert find_library('nvinfer') is not None, "Can't find the TensorRT shared library"
+
+
+def merge_dicts(*dict_args):
+    """Merge arg_params and aux_params to populate shared_buffer"""
+    result = {}
+    for dictionary in dict_args:
+        result.update(dictionary)
+    return result
+
+
+def get_fp16_infer_for_fp16_graph():
+    return int(os.environ.get("MXNET_TENSORRT_USE_FP16_FOR_FP32", 0))
+
+
+def set_fp16_infer_for_fp16_graph(status=False):
+    os.environ["MXNET_TENSORRT_USE_FP16_FOR_FP32"] = str(int(status))
diff --git a/tests/python/tensorrt/lenet5_common.py b/tests/python/tensorrt/lenet5_common.py
new file mode 100644
index 00000000000..347d6f3c11b
--- /dev/null
+++ b/tests/python/tensorrt/lenet5_common.py
@@ -0,0 +1,31 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import numpy as np
+import mxnet as mx
+from common import *
+
+def get_iters(mnist, batch_size):
+    """Get MNIST iterators."""
+    train_iter = mx.io.NDArrayIter(mnist['train_data'],
+                                   mnist['train_label'],
+                                   batch_size,
+                                   shuffle=True)
+    val_iter = mx.io.NDArrayIter(mnist['test_data'], mnist['test_label'], batch_size)
+    test_iter = mx.io.NDArrayIter(mnist['test_data'], mnist['test_label'], batch_size)
+    all_test_labels = np.array(mnist['test_label'])
+    return train_iter, val_iter, test_iter, all_test_labels
diff --git a/tests/python/tensorrt/lenet5_train.py b/tests/python/tensorrt/lenet5_train.py
new file mode 100644
index 00000000000..8edd9abf70e
--- /dev/null
+++ b/tests/python/tensorrt/lenet5_train.py
@@ -0,0 +1,84 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import os
+import mxnet as mx
+from lenet5_common import get_iters
+
+
+def lenet5():
+    """LeNet-5 Symbol"""
+    #pylint: disable=no-member
+    data = mx.sym.Variable('data')
+    conv1 = mx.sym.Convolution(data=data, kernel=(5, 5), num_filter=20)
+    tanh1 = mx.sym.Activation(data=conv1, act_type="tanh")
+    pool1 = mx.sym.Pooling(data=tanh1, pool_type="max",
+                           kernel=(2, 2), stride=(2, 2))
+    # second conv
+    conv2 = mx.sym.Convolution(data=pool1, kernel=(5, 5), num_filter=50)
+    tanh2 = mx.sym.Activation(data=conv2, act_type="tanh")
+    pool2 = mx.sym.Pooling(data=tanh2, pool_type="max",
+                           kernel=(2, 2), stride=(2, 2))
+    # first fullc
+    flatten = mx.sym.Flatten(data=pool2)
+    fc1 = mx.sym.FullyConnected(data=flatten, num_hidden=500)
+    tanh3 = mx.sym.Activation(data=fc1, act_type="tanh")
+    # second fullc
+    fc2 = mx.sym.FullyConnected(data=tanh3, num_hidden=10)
+    # loss
+    lenet = mx.sym.SoftmaxOutput(data=fc2, name='softmax')
+    #pylint: enable=no-member
+    return lenet
+
+
+def train_lenet5(num_epochs, batch_size, train_iter, val_iter, test_iter):
+    """train LeNet-5 model on MNIST data"""
+    ctx = mx.gpu(0)
+    lenet_model = mx.mod.Module(lenet5(), context=ctx)
+
+    lenet_model.fit(train_iter,
+                    eval_data=val_iter,
+                    optimizer='sgd',
+                    optimizer_params={'learning_rate': 0.1, 'momentum': 0.9},
+                    eval_metric='acc',
+                    batch_end_callback=mx.callback.Speedometer(batch_size, 1),
+                    num_epoch=num_epochs)
+
+    # predict accuracy for lenet
+    acc = mx.metric.Accuracy()
+    lenet_model.score(test_iter, acc)
+    accuracy = acc.get()[1]
+    assert accuracy > 0.95, "LeNet-5 training accuracy on MNIST was too low"
+    return lenet_model
+
+
+if __name__ == '__main__':
+    num_epochs = 10
+    batch_size = 128
+    model_name = 'lenet5'
+    model_dir = os.getenv("LENET_MODEL_DIR", "/tmp")
+    model_file = '%s/%s-symbol.json' % (model_dir, model_name)
+    params_file = '%s/%s-%04d.params' % (model_dir, model_name, num_epochs)
+
+    if not (os.path.exists(model_file) and os.path.exists(params_file)):
+        mnist = mx.test_utils.get_mnist()
+
+        train_iter, val_iter, test_iter, _ = get_iters(mnist, batch_size)
+
+        trained_lenet = train_lenet5(num_epochs, batch_size,
+                                     train_iter, val_iter, test_iter)
+        trained_lenet.save_checkpoint(model_name, num_epochs)
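
For reference, a short sketch of how the checkpoint written by this script is
consumed (a sketch only; the shapes assume the standard 1x28x28 MNIST layout
and the batch size of 128 used above, and save_checkpoint() writes
lenet5-symbol.json / lenet5-0010.params into the current working directory):

    import mxnet as mx

    sym, arg_params, aux_params = mx.model.load_checkpoint('lenet5', 10)
    executor = sym.simple_bind(ctx=mx.gpu(0), data=(128, 1, 28, 28),
                               softmax_label=(128,), grad_req='null')
    executor.copy_params_from(arg_params, aux_params)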
diff --git a/tests/python/tensorrt/test_cvnets.py b/tests/python/tensorrt/test_cvnets.py
new file mode 100644
index 00000000000..4fdd522341b
--- /dev/null
+++ b/tests/python/tensorrt/test_cvnets.py
@@ -0,0 +1,179 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import gc
+import gluoncv
+import mxnet as mx
+import numpy as np
+
+from mxnet import gluon
+from time import time
+
+from mxnet.gluon.data.vision import transforms
+
+
+def get_classif_model(model_name, use_tensorrt, ctx=mx.gpu(0), batch_size=128):
+    mx.contrib.tensorrt.set_use_tensorrt(use_tensorrt)
+    h, w = 32, 32
+    net = gluoncv.model_zoo.get_model(model_name, pretrained=True)
+    data = mx.sym.var('data')
+
+    if use_tensorrt:
+        out = net(data)
+        softmax = mx.sym.SoftmaxOutput(out, name='softmax')
+        all_params = dict([(k, v.data()) for k, v in net.collect_params().items()])
+        executor = mx.contrib.tensorrt.tensorrt_bind(softmax, ctx=ctx, all_params=all_params,
+                                                     data=(batch_size,3, h, w),
+                                                     softmax_label=(batch_size,), grad_req='null',
+                                                     force_rebind=True)
+    else:
+        # Convert gluon model to Symbolic
+        net.hybridize()
+        net.forward(mx.ndarray.zeros((batch_size, 3, h, w)))
+        net.export(model_name)
+        symbol, arg_params, aux_params = mx.model.load_checkpoint(model_name, 0)
+        executor = symbol.simple_bind(ctx=ctx, data=(batch_size, 3, h, w),
+                                      softmax_label=(batch_size,))
+        executor.copy_params_from(arg_params, aux_params)
+    return executor
+
+
+def cifar10_infer(model_name, use_tensorrt, num_workers, ctx=mx.gpu(0), batch_size=128):
+    executor = get_classif_model(model_name, use_tensorrt, ctx, batch_size)
+
+    num_ex = 10000
+    all_preds = np.zeros([num_ex, 10])
+
+    all_label_test = np.zeros(num_ex)
+
+    transform_test = transforms.Compose([
+        transforms.ToTensor(),
+        transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])
+    ])
+
+    data_loader = lambda: gluon.data.DataLoader(
+        gluon.data.vision.CIFAR10(train=False).transform_first(transform_test),
+        batch_size=batch_size, shuffle=False, num_workers=num_workers)
+
+    val_data = data_loader()
+
+    for idx, (data, label) in enumerate(val_data):
+        # Skip last batch if it's undersized.
+        if data.shape[0] < batch_size:
+            continue
+        offset = idx * batch_size
+        all_label_test[offset:offset + batch_size] = label.asnumpy()
+
+        # warm-up, but don't use result
+        executor.forward(is_train=False, data=data)
+        executor.outputs[0].wait_to_read()
+
+    gc.collect()
+    val_data = data_loader()
+    example_ct = 0
+    start = time()
+
+    for idx, (data, label) in enumerate(val_data):
+        # Skip last batch if it's undersized.
+        if data.shape[0] < batch_size:
+            continue
+        executor.forward(is_train=False, data=data)
+        preds = executor.outputs[0].asnumpy()
+        offset = idx * batch_size
+        all_preds[offset:offset + batch_size, :] = preds[:batch_size]
+        example_ct += batch_size
+
+    all_preds = np.argmax(all_preds, axis=1)
+    matches = (all_preds[:example_ct] == all_label_test[:example_ct]).sum()
+    duration = time() - start
+
+    return duration, 100.0 * matches / example_ct
+
+
+def run_experiment_for(model_name, batch_size, num_workers):
+    print("\n===========================================")
+    print("Model: %s" % model_name)
+    print("===========================================")
+    print("*** Running inference using pure MXNet ***\n")
+    mx_duration, mx_pct = cifar10_infer(model_name=model_name, batch_size=batch_size,
+                                        num_workers=num_workers, use_tensorrt=False)
+    print("\nMXNet: time elapsed: %.3fs, accuracy: %.2f%%" % (mx_duration, mx_pct))
+    print("\n*** Running inference using MXNet + TensorRT ***\n")
+    trt_duration, trt_pct = cifar10_infer(model_name=model_name, batch_size=batch_size,
+                                          num_workers=num_workers, use_tensorrt=True)
+    print("TensorRT: time elapsed: %.3fs, accuracy: %.2f%%" % (trt_duration, trt_pct))
+    speedup = mx_duration / trt_duration
+    print("TensorRT speed-up (not counting compilation): %.2fx" % speedup)
+
+    acc_diff = abs(mx_pct - trt_pct)
+    print("Absolute accuracy difference: %f" % acc_diff)
+    return speedup, acc_diff
+
+
+def test_tensorrt_on_cifar_resnets(batch_size=32, tolerance=0.1, num_workers=1):
+    original_trt_value = mx.contrib.tensorrt.get_use_tensorrt()
+    try:
+        models = [
+            'cifar_resnet20_v1',
+            'cifar_resnet56_v1',
+            'cifar_resnet110_v1',
+            'cifar_resnet20_v2',
+            'cifar_resnet56_v2',
+            'cifar_resnet110_v2',
+            'cifar_wideresnet16_10',
+            'cifar_wideresnet28_10',
+            'cifar_wideresnet40_8',
+            'cifar_resnext29_16x64d'
+        ]
+
+        num_models = len(models)
+
+        speedups = np.zeros(num_models, dtype=np.float32)
+        acc_diffs = np.zeros(num_models, dtype=np.float32)
+
+        test_start = time()
+
+        for idx, model in enumerate(models):
+            speedup, acc_diff = run_experiment_for(model, batch_size, num_workers)
+            speedups[idx] = speedup
+            acc_diffs[idx] = acc_diff
+            assert acc_diff < tolerance, "Accuracy difference between MXNet and TensorRT > %.2f%% for model %s" % (
+                tolerance, model)
+
+        print("Perf and correctness checks run on the following models:")
+        print(models)
+        mean_speedup = np.mean(speedups)
+        std_speedup = np.std(speedups)
+        print("\nSpeedups:")
+        print(speedups)
+        print("Speedup range: [%.2f, %.2f]" % (np.min(speedups), np.max(speedups)))
+        print("Mean speedup: %.2f" % mean_speedup)
+        print("St. dev. of speedups: %.2f" % std_speedup)
+        print("\nAcc. differences: %s" % str(acc_diffs))
+
+        test_duration = time() - test_start
+
+        print("Test duration: %.2f seconds" % test_duration)
+    finally:
+        mx.contrib.tensorrt.set_use_tensorrt(original_trt_value)
+
+
+if __name__ == '__main__':
+    import nose
+
+    nose.runmodule()
diff --git a/tests/python/tensorrt/test_cycle.py b/tests/python/tensorrt/test_cycle.py
new file mode 100644
index 00000000000..25f515a106a
--- /dev/null
+++ b/tests/python/tensorrt/test_cycle.py
@@ -0,0 +1,69 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import mxnet as mx
+from common import *
+
+
+def detect_cycle_from(sym, visited, stack):
+    visited.add(sym.handle.value)
+    stack.add(sym.handle.value)
+    for s in sym.get_children() or []:
+        if s.handle.value not in visited:
+            if detect_cycle_from(s, visited, stack):
+                return True
+        elif s.handle.value in stack:
+            return True
+    stack.remove(sym.handle.value)
+    return False
+
+
+def has_no_cycle(sym):
+    visited = set()
+    stack = set()
+    all_nodes = sym.get_internals()
+    for s in all_nodes:
+        if s.handle.value not in visited:
+            if detect_cycle_from(s, visited, stack):
+                return False
+    return True
+
+
+def test_simple_cycle():
+    inp = mx.sym.Variable('input', shape=[1,10])
+    A = mx.sym.FullyConnected(data=inp, num_hidden=10, no_bias=False, name='A')
+    B = mx.sym.FullyConnected(data=A, num_hidden=10, no_bias=False, name='B')
+    D = mx.sym.sin(data=A, name='D')
+    C = mx.sym.elemwise_add(lhs=B, rhs=D, name='C')
+    arg_params = {
+                'I_weight': mx.nd.zeros([10,10]),
+                'I_bias': mx.nd.zeros([10]),
+                'A_weight': mx.nd.zeros([10,10]),
+                'A_bias': mx.nd.zeros([10]),
+                'B_weight': mx.nd.zeros([10,10]),
+                'B_bias': mx.nd.zeros([10]),
+               }
+
+    executor = C.simple_bind(ctx=mx.gpu(0), data=(1,10), softmax_label=(1,),
+                           shared_buffer=arg_params, grad_req='null', force_rebind=True)
+    optimized_graph = mx.contrib.tensorrt.get_optimized_symbol(executor)
+    assert has_no_cycle(optimized_graph), "The graph optimized by TRT contains a cycle"
+
+
+if __name__ == '__main__':
+    import nose
+    nose.runmodule()
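
As a rough illustration of what has_no_cycle checks (a sketch only, with
has_no_cycle defined as above; any symbol built through the regular MXNet
symbolic API is a DAG, so the check passes trivially on unoptimized graphs):

    import mxnet as mx

    a = mx.sym.Variable('a')
    b = mx.sym.sin(data=a, name='b')
    c = mx.sym.elemwise_add(lhs=a, rhs=b, name='c')
    # c depends on {a, b} and b depends on a; there is no back edge.
    assert has_no_cycle(c)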
diff --git a/tests/python/tensorrt/test_tensorrt_lenet5.py b/tests/python/tensorrt/test_tensorrt_lenet5.py
new file mode 100644
index 00000000000..258686428a4
--- /dev/null
+++ b/tests/python/tensorrt/test_tensorrt_lenet5.py
@@ -0,0 +1,108 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import os
+import numpy as np
+import mxnet as mx
+from common import *
+from lenet5_common import get_iters
+
+
+def run_inference(sym, arg_params, aux_params, mnist, all_test_labels, batch_size, use_tensorrt):
+    """Run inference with either MXNet or TensorRT"""
+    mx.contrib.tensorrt.set_use_tensorrt(use_tensorrt)
+
+    data_size = (batch_size,) + mnist['test_data'].shape[1:]
+    if use_tensorrt:
+        all_params = merge_dicts(arg_params, aux_params)
+        executor = mx.contrib.tensorrt.tensorrt_bind(sym, ctx=mx.gpu(0), all_params=all_params,
+                                                     data=data_size,
+                                                     softmax_label=(batch_size,),
+                                                     grad_req='null',
+                                                     force_rebind=True)
+    else:
+        executor = sym.simple_bind(ctx=mx.gpu(0),
+                                   data=data_size,
+                                   softmax_label=(batch_size,),
+                                   grad_req='null',
+                                   force_rebind=True)
+        executor.copy_params_from(arg_params, aux_params)
+
+    # The MNIST test set has 10,000 examples and 10 classes; both values are
+    # hard-coded here but could be derived from all_test_labels / the dataset.
+    num_ex = 10000
+    all_preds = np.zeros([num_ex, 10])
+    test_iter = mx.io.NDArrayIter(mnist['test_data'], mnist['test_label'], batch_size)
+
+    example_ct = 0
+
+    for idx, dbatch in enumerate(test_iter):
+        executor.arg_dict["data"][:] = dbatch.data[0]
+        executor.forward(is_train=False)
+        offset = idx*batch_size
+        extent = batch_size if num_ex - offset > batch_size else num_ex - offset
+        all_preds[offset:offset+extent, :] = executor.outputs[0].asnumpy()[:extent]
+        example_ct += extent
+
+    all_preds = np.argmax(all_preds, axis=1)
+    matches = (all_preds[:example_ct] == all_test_labels[:example_ct]).sum()
+
+    percentage = 100.0 * matches / example_ct
+
+    return percentage
+
+
+def test_tensorrt_inference():
+    """Run LeNet-5 inference comparison between MXNet and TensorRT."""
+    original_trt_value = mx.contrib.tensorrt.get_use_tensorrt()
+    try:
+        check_tensorrt_installation()
+        mnist = mx.test_utils.get_mnist()
+        num_epochs = 10
+        batch_size = 128
+        model_name = 'lenet5'
+        model_dir = os.getenv("LENET_MODEL_DIR", "/tmp")
+        model_file = '%s/%s-symbol.json' % (model_dir, model_name)
+        params_file = '%s/%s-%04d.params' % (model_dir, model_name, num_epochs)
+
+        _, _, _, all_test_labels = get_iters(mnist, batch_size)
+
+        # Load serialized MXNet model (model-symbol.json + model-epoch.params)
+        sym, arg_params, aux_params = mx.model.load_checkpoint(model_name, num_epochs)
+
+        print("LeNet-5 test")
+        print("Running inference in MXNet")
+        mx_pct = run_inference(sym, arg_params, aux_params, mnist, all_test_labels,
+                               batch_size=batch_size, use_tensorrt=False)
+
+        print("Running inference in MXNet-TensorRT")
+        trt_pct = run_inference(sym, arg_params, aux_params, mnist, all_test_labels,
+                                batch_size=batch_size, use_tensorrt=True)
+
+        print("MXNet accuracy: %f" % mx_pct)
+        print("MXNet-TensorRT accuracy: %f" % trt_pct)
+
+        assert abs(mx_pct - trt_pct) < 1e-2, \
+            """Diff. between MXNet & TensorRT accuracy too high:
+               MXNet = %f, TensorRT = %f""" % (mx_pct, trt_pct)
+    finally:
+        mx.contrib.tensorrt.set_use_tensorrt(original_trt_value)
+
+
+if __name__ == '__main__':
+    import nose
+    nose.runmodule()
diff --git a/tests/python/tensorrt/test_training_warning.py b/tests/python/tensorrt/test_training_warning.py
new file mode 100644
index 00000000000..fdac859aef6
--- /dev/null
+++ b/tests/python/tensorrt/test_training_warning.py
@@ -0,0 +1,70 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import gluoncv
+import mxnet as mx
+
+from tests.python.unittest.common import assertRaises
+
+
+def test_training_without_trt():
+    run_resnet(is_train=True, use_tensorrt=False)
+
+
+def test_inference_without_trt():
+    run_resnet(is_train=False, use_tensorrt=False)
+
+
+def test_training_with_trt():
+    assertRaises(RuntimeError, run_resnet, is_train=True, use_tensorrt=True)
+
+
+def test_inference_with_trt():
+    run_resnet(is_train=False, use_tensorrt=True)
+
+
+def run_resnet(is_train, use_tensorrt):
+    original_trt_value = mx.contrib.tensorrt.get_use_tensorrt()
+    try:
+        mx.contrib.tensorrt.set_use_tensorrt(use_tensorrt)
+        ctx = mx.gpu(0)
+        batch_size = 1
+        h = 32
+        w = 32
+        model_name = 'cifar_resnet20_v1'
+        resnet = gluoncv.model_zoo.get_model(model_name, pretrained=True)
+        data = mx.sym.var('data')
+        out = resnet(data)
+        softmax = mx.sym.SoftmaxOutput(out, name='softmax')
+        if is_train:
+            grad_req = 'write'
+        else:
+            grad_req = 'null'
+        if use_tensorrt:
+            all_params = dict([(k, v.data()) for k, v in resnet.collect_params().items()])
+            mx.contrib.tensorrt.tensorrt_bind(softmax, ctx=ctx, all_params=all_params,
+                                              data=(batch_size, 3, h, w), softmax_label=(batch_size,),
+                                              force_rebind=True, grad_req=grad_req)
+        else:
+            softmax.simple_bind(ctx=ctx, data=(batch_size, 3, h, w), softmax_label=(batch_size,),
+                                force_rebind=True, grad_req=grad_req)
+    finally:
+        mx.contrib.tensorrt.set_use_tensorrt(original_trt_value)
+
+
+if __name__ == '__main__':
+    import nose
+    nose.runmodule()


 
