Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/06/22 15:45:02 UTC

[GitHub] KellenSunderland closed pull request #11351: WIP

KellenSunderland closed pull request #11351: WIP
URL: https://github.com/apache/incubator-mxnet/pull/11351
 
 
   

As this is a foreign pull request (from a fork), the diff is supplied
below for the sake of provenance (it won't show otherwise due to GitHub magic):

diff --git a/.gitmodules b/.gitmodules
index 9aeb1c75498..836d824a6f5 100644
--- a/.gitmodules
+++ b/.gitmodules
@@ -26,3 +26,6 @@
 [submodule "3rdparty/tvm"]
 	path = 3rdparty/tvm
 	url = https://github.com/dmlc/tvm
+[submodule "3rdparty/onnx-tensorrt"]
+	path = 3rdparty/onnx-tensorrt
+	url = https://github.com/onnx/onnx-tensorrt.git
diff --git a/3rdparty/onnx-tensorrt b/3rdparty/onnx-tensorrt
new file mode 160000
index 00000000000..e7be19cff37
--- /dev/null
+++ b/3rdparty/onnx-tensorrt
@@ -0,0 +1 @@
+Subproject commit e7be19cff377a95817503e8525e20de34cdc574a
diff --git a/Jenkinsfile b/Jenkinsfile
index 56fbf3d74af..9e50ed0378b 100644
--- a/Jenkinsfile
+++ b/Jenkinsfile
@@ -336,6 +336,17 @@ try {
         }
       }
     },
+    'TensorRT': {
+      node('mxnetlinux-cpu') {
+        ws('workspace/build-tensorrt') {
+          timeout(time: max_time, unit: 'MINUTES') {
+            init_git()
+            docker_run('ubuntu_gpu_tensorrt', 'build_ubuntu_gpu_tensorrt', false)
+            pack_lib('tensorrt')
+          }
+        }
+      }
+    },
     'Build CPU windows':{
       node('mxnetwindows-cpu') {
         timeout(time: max_time, unit: 'MINUTES') {
diff --git a/Makefile b/Makefile
index ff4446ab80c..2fd7bfbf4e0 100644
--- a/Makefile
+++ b/Makefile
@@ -94,6 +94,14 @@ else
 endif
 CFLAGS += -I$(TPARTYDIR)/mshadow/ -I$(TPARTYDIR)/dmlc-core/include -fPIC -I$(NNVM_PATH)/include -I$(DLPACK_PATH)/include -I$(TPARTYDIR)/tvm/include -Iinclude $(MSHADOW_CFLAGS)
 LDFLAGS = -pthread $(MSHADOW_LDFLAGS) $(DMLC_LDFLAGS)
+
+
+ifeq ($(USE_TENSORRT), 1)
+	CFLAGS +=  -I$(ROOTDIR) -I$(TPARTYDIR) -DONNX_NAMESPACE=$(ONNX_NAMESPACE) -DMXNET_USE_TENSORRT=1
+	LDFLAGS += -lprotobuf -pthread -lonnx -lonnx_proto -lnvonnxparser -lnvonnxparser_runtime -lnvinfer -lnvinfer_plugin
+endif
+# -L/usr/local/lib
+
 ifeq ($(DEBUG), 1)
 	NVCCFLAGS += -std=c++11 -Xcompiler -D_FORCE_INLINES -g -G -O0 -ccbin $(CXX) $(MSHADOW_NVCCFLAGS)
 else
diff --git a/amalgamation/amalgamation.py b/amalgamation/amalgamation.py
index 52d775b7692..a3c28f7118e 100644
--- a/amalgamation/amalgamation.py
+++ b/amalgamation/amalgamation.py
@@ -23,13 +23,12 @@
 import platform
 
 blacklist = [
-    'Windows.h', 'cublas_v2.h', 'cuda/tensor_gpu-inl.cuh',
-    'cuda_runtime.h', 'cudnn.h', 'cudnn_lrn-inl.h', 'curand.h', 'curand_kernel.h',
-    'glog/logging.h', 'io/azure_filesys.h', 'io/hdfs_filesys.h', 'io/s3_filesys.h',
-    'kvstore_dist.h', 'mach/clock.h', 'mach/mach.h',
-    'malloc.h', 'mkl.h', 'mkl_cblas.h', 'mkl_vsl.h', 'mkl_vsl_functions.h',
-    'nvml.h', 'opencv2/opencv.hpp', 'sys/stat.h', 'sys/types.h', 'cuda.h', 'cuda_fp16.h',
-    'omp.h', 'execinfo.h', 'packet/sse-inl.h', 'emmintrin.h', 'thrust/device_vector.h',
+    'Windows.h', 'cublas_v2.h', 'cuda/tensor_gpu-inl.cuh', 'cuda_runtime.h', 'cudnn.h',
+    'cudnn_lrn-inl.h', 'curand.h', 'curand_kernel.h', 'glog/logging.h', 'io/azure_filesys.h',
+    'io/hdfs_filesys.h', 'io/s3_filesys.h', 'kvstore_dist.h', 'mach/clock.h', 'mach/mach.h',
+    'malloc.h', 'mkl.h', 'mkl_cblas.h', 'mkl_vsl.h', 'mkl_vsl_functions.h', 'NvInfer.h', 'nvml.h',
+    'opencv2/opencv.hpp', 'sys/stat.h', 'sys/types.h', 'cuda.h', 'cuda_fp16.h', 'omp.h',
+    'onnx/onnx.pb.h', 'execinfo.h', 'packet/sse-inl.h', 'emmintrin.h', 'thrust/device_vector.h',
     'cusolverDn.h', 'internal/concurrentqueue_internal_debug.h', 'relacy/relacy_std.hpp',
     'relacy_shims.h', 'ittnotify.h', 'shared_mutex'
     ]
@@ -150,6 +149,7 @@ def expand(x, pending, stage):
                     h not in sysheaders and
                     'mkl' not in h and
                     'nnpack' not in h and
+                    'tensorrt' not in h and
                     not h.endswith('.cuh')): sysheaders.append(h)
             else:
                 expand.treeDepth += 1
diff --git a/ci/docker/Dockerfile.build.ubuntu_gpu_tensorrt b/ci/docker/Dockerfile.build.ubuntu_gpu_tensorrt
new file mode 100755
index 00000000000..9f72e0a420b
--- /dev/null
+++ b/ci/docker/Dockerfile.build.ubuntu_gpu_tensorrt
@@ -0,0 +1,61 @@
+# -*- mode: dockerfile -*-
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+# Dockerfile to build and run MXNet with TensorRT on Ubuntu 16.04 for GPU
+
+FROM nvidia/cuda:9.0-cudnn7-devel
+
+WORKDIR /work/deps
+
+COPY install/ubuntu_core.sh /work/
+RUN /work/ubuntu_core.sh
+COPY install/ubuntu_ccache.sh /work/
+RUN /work/ubuntu_ccache.sh
+COPY install/ubuntu_python.sh /work/
+RUN /work/ubuntu_python.sh
+COPY install/ubuntu_scala.sh /work/
+RUN /work/ubuntu_scala.sh
+COPY install/ubuntu_r.sh /work/
+RUN /work/ubuntu_r.sh
+COPY install/ubuntu_perl.sh /work/
+RUN /work/ubuntu_perl.sh
+COPY install/ubuntu_clang.sh /work/
+RUN /work/ubuntu_clang.sh
+COPY install/ubuntu_mklml.sh /work/
+RUN /work/ubuntu_mklml.sh
+COPY install/ubuntu_tvm.sh /work/
+RUN /work/ubuntu_tvm.sh
+COPY install/ubuntu_llvm.sh /work/
+RUN /work/ubuntu_llvm.sh
+COPY install/ubuntu_caffe.sh /work/
+RUN /work/ubuntu_caffe.sh
+COPY install/ubuntu_docs.sh /work/
+RUN /work/ubuntu_docs.sh
+COPY install/ubuntu_tutorials.sh /work/
+RUN /work/ubuntu_tutorials.sh
+COPY install/tensorrt.sh /work
+RUN /work/tensorrt.sh
+
+ARG USER_ID=0
+COPY install/ubuntu_adduser.sh /work/
+RUN /work/ubuntu_adduser.sh
+
+COPY runtime_functions.sh /work/
+
+WORKDIR /work/mxnet
+ENV LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/usr/local/lib
diff --git a/ci/docker/Dockerfile.build.ubuntu_gpu_tensorrt_old b/ci/docker/Dockerfile.build.ubuntu_gpu_tensorrt_old
new file mode 100755
index 00000000000..fb8f2514ebf
--- /dev/null
+++ b/ci/docker/Dockerfile.build.ubuntu_gpu_tensorrt_old
@@ -0,0 +1,179 @@
+FROM nvidia/cuda:9.0-cudnn7-devel
+
+ENV MXNET_VERSION 1.2.0+
+LABEL com.nvidia.mxnet.version="${MXNET_VERSION}"
+ENV NVIDIA_MXNET_VERSION 18.07
+
+ARG USE_TRT=1
+ARG PYVER=3.5
+ENV ONNX_NAMESPACE onnx
+
+RUN PYSFX=`[ "$PYVER" != "2.7" ] && echo "$PYVER" | cut -c1-1 || echo ""` && \
+    apt-get update && apt-get install -y --no-install-recommends \
+        build-essential \
+        ca-certificates \
+        curl \
+        wget \
+        git \
+        libatlas-base-dev \
+        pkg-config \
+        libtiff5-dev \
+        libjpeg8-dev \
+        zlib1g-dev \
+        python$PYVER-dev \
+        autoconf \
+        automake \
+        libtool \
+        nasm \
+        unzip && \
+    rm -rf /var/lib/apt/lists/*
+
+# Need a newer version of CMake for ONNX and onnx-tensorrt
+RUN cd /usr/local/src && \
+    wget https://cmake.org/files/v3.8/cmake-3.8.2.tar.gz && \
+    tar -xvf cmake-3.8.2.tar.gz && \
+    cd cmake-3.8.2 && \
+    ./bootstrap && \
+    make -j$(nproc) && \
+    make install && \
+    cd .. && \
+    rm -rf cmake*
+
+# Make sure symlinks exist for either python 2 or 3
+RUN rm -f /usr/bin/python && ln -s /usr/bin/python$PYVER /usr/bin/python
+RUN MAJ=`echo "$PYVER" | cut -c1-1` && \
+    rm -f /usr/bin/python$MAJ && ln -s /usr/bin/python$PYVER /usr/bin/python$MAJ
+
+RUN curl -O https://bootstrap.pypa.io/get-pip.py && \
+    python get-pip.py && \
+    rm get-pip.py
+
+# We need to force NumPy 1.13.3 because default is 1.14.1 right now
+# and that issues MxNet warnings since it's not officially supported
+# Install NumPy before the pip install --upgrade
+RUN pip install numpy==1.13.3
+RUN pip install --upgrade --no-cache-dir setuptools requests
+
+# The following are needed for Sockeye on python 3+ only.
+RUN [ "$PYVER" = "2.7" ] || pip install unidecode tqdm pyyaml
+
+RUN OPENCV_VERSION=3.1.0 && \
+    wget -q -O - https://github.com/Itseez/opencv/archive/${OPENCV_VERSION}.tar.gz | tar -xzf - && \
+    cd /opencv-${OPENCV_VERSION} && \
+    cmake -DCMAKE_BUILD_TYPE=RELEASE -DCMAKE_INSTALL_PREFIX=/usr \
+          -DWITH_CUDA=OFF -DWITH_1394=OFF \
+          -DBUILD_opencv_cudalegacy=OFF -DBUILD_opencv_stitching=OFF -DWITH_IPP=OFF . && \
+    make -j"$(nproc)" install && \
+    rm -rf /opencv-${OPENCV_VERSION}
+
+# libjpeg-turbo
+RUN JPEG_TURBO_VERSION=1.5.2 && \
+    wget -q -O - https://github.com/libjpeg-turbo/libjpeg-turbo/archive/${JPEG_TURBO_VERSION}.tar.gz | tar -xzf - && \
+    cd /libjpeg-turbo-${JPEG_TURBO_VERSION} && \
+    autoreconf -fiv && \
+    ./configure --enable-shared --prefix=/usr 2>&1 >/dev/null && \
+    make -j"$(nproc)" install 2>&1 >/dev/null && \
+    rm -rf /libjpeg-turbo-${JPEG_TURBO_VERSION}
+
+WORKDIR /
+
+# Install protoc 3.5 and build protobuf here (for onnx and onnx-tensorrt)
+RUN if [ $USE_TRT = "1" ]; \
+    then \
+      echo "TensorRT build enabled. Installing Protobuf."; \
+      git clone --recursive -b 3.5.1.1 https://github.com/google/protobuf.git; \
+      cd protobuf; \
+      ./autogen.sh; \
+      ./configure; \
+      make -j$(nproc); \
+      make install; \
+      ldconfig; \
+    else \
+      echo "TensorRT build disabled. Not installing Protobuf."; \
+    fi
+
+# Install TensorRT 4.0 for CUDA 9
+RUN if [ $USE_TRT = "1" ]; \
+    then \
+      echo "TensorRT build enabled. Installing TensorRT."; \
+      wget -qO tensorrt.deb https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1604/x86_64/nvinfer-runtime-trt-repo-ubuntu1604-3.0.4-ga-cuda9.0_1.0-1_amd64.deb; \
+      dpkg -i tensorrt.deb; \
+      apt-get update; \
+      apt-get install -y --allow-downgrades libnvinfer-dev; \
+      rm tensorrt.deb; \
+    else \
+        echo "TensorRT build disabled. Not installing TensorRT."; \
+    fi
+
+WORKDIR /opt/mxnet
+COPY . .
+
+ENV MXNET_HOME "/opt/mxnet"
+ENV MXNET_CUDNN_AUTOTUNE_DEFAULT 2
+
+RUN cp make/config.mk . && \
+   echo "USE_CUDA=1" >> config.mk && \
+   echo "USE_CUDNN=1" >> config.mk && \
+   echo "CUDA_ARCH :=" \
+        "-gencode arch=compute_52,code=sm_52" \
+        "-gencode arch=compute_60,code=sm_60" \
+        "-gencode arch=compute_61,code=sm_61" \
+        "-gencode arch=compute_70,code=sm_70" \
+        "-gencode arch=compute_70,code=compute_70" >> config.mk && \
+    echo "USE_CUDA_PATH=/usr/local/cuda" >> config.mk && \
+    echo "USE_LIBJPEG_TURBO=1" >> config.mk && \
+    echo "USE_LIBJPEG_TURBO_PATH=/usr" >> config.mk
+
+RUN if [ $USE_TRT = "1" ]; \
+    then \
+      echo "TensorRT build enabled. Adding flags to config.mk."; \
+      echo "USE_TENSORRT=1" >> config.mk; \
+      echo "ONNX_NAMESPACE=$ONNX_NAMESPACE" >> config.mk; \
+    else \
+      echo "TensorRT build disabled. Not adding TensorRT flags to config.mk."; \
+    fi
+
+ENV LD_LIBRARY_PATH $LD_LIBRARY_PATH:/usr/local/lib
+
+# Building ONNX, then onnx-tensorrt
+WORKDIR /opt/mxnet/3rdparty/onnx-tensorrt/third_party/onnx
+
+RUN if [ $USE_TRT = "1" ]; \
+  then \
+    echo "TensorRT build enabled. Installing ONNX."; \
+    rm -rf build; \
+    mkdir build; \
+    cd build; \
+    cmake -DCMAKE_CXX_FLAGS=-I/usr/include/python${PYVER} -DBUILD_SHARED_LIBS=ON ..; \
+    make -j$(nproc); \
+    make install; \
+    ldconfig; \
+    cd ..; \
+    mkdir /usr/include/x86_64-linux-gnu/onnx; \
+    cp build/onnx/onnx*pb.* /usr/include/x86_64-linux-gnu/onnx; \
+    cp build/libonnx.so /usr/local/lib && ldconfig; \
+  else \
+    echo "TensorRT build disabled. Not installing ONNX."; \
+  fi
+
+WORKDIR /opt/mxnet/3rdparty/onnx-tensorrt
+
+RUN if [ $USE_TRT = "1" ]; \
+  then \
+    echo "TensorRT build enabled. Installing onnx-tensorrt."; \
+    mkdir build && cd build && cmake ..; \
+    make -j$(nproc); \
+    make install; \
+    ldconfig; \
+  else \
+    echo "TensorRT build disabled. Not installing onnx-tensorrt."; \
+  fi
+
+WORKDIR /opt/mxnet
+
+RUN make -j$(nproc) && \
+    mv lib/libmxnet.so /usr/local/lib && \
+    ldconfig && \
+    make clean && \
+    cd python && \
+    pip install -e .
diff --git a/ci/docker/install/tensorrt.sh b/ci/docker/install/tensorrt.sh
new file mode 100755
index 00000000000..77259b81ed1
--- /dev/null
+++ b/ci/docker/install/tensorrt.sh
@@ -0,0 +1,83 @@
+#!/bin/bash
+
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# Install Protobuf
+# Install protoc 3.5 and build protobuf here (for onnx and onnx-tensorrt)
+pushd .
+cd ..
+git clone --recursive -b 3.5.1.1 https://github.com/google/protobuf.git
+cd protobuf
+./autogen.sh
+./configure
+make -j$(nproc)
+make install
+ldconfig
+popd
+
+# Install TensorRT
+echo "TensorRT build enabled. Installing TensorRT."
+wget -qO tensorrt.deb https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1604/x86_64/nvinfer-runtime-trt-repo-ubuntu1604-3.0.4-ga-cuda9.0_1.0-1_amd64.deb
+dpkg -i tensorrt.deb
+apt-get update
+apt-get install -y --allow-downgrades libnvinfer-dev
+rm tensorrt.deb
+
+# Install ONNX
+#pushd .
+#cd 3rdparty/onnx-tensorrt/third_party/onnx
+#rm -rf build
+#mkdir build
+#cd build
+#cmake -DCMAKE_CXX_FLAGS=-I/usr/include/python${PYVER} -DBUILD_SHARED_LIBS=ON ..
+#make -j$(nproc)
+#make install
+#ldconfig
+#cd ..
+#mkdir /usr/include/x86_64-linux-gnu/onnx
+#cp build/onnx/onnx*pb.* /usr/include/x86_64-linux-gnu/onnx
+#cp build/libonnx.so /usr/local/lib
+#ldconfig
+#popd
+#
+## Install ONNX-TensorRT
+#echo "==============================================================="
+#pwd
+#ls -la
+#cd ..
+#ls -la
+#cd /
+#ls -R 
+#echo "==============================================================="
+#cd 3rdparty/onnx-tensorrt/
+#mkdir build
+#cd build
+#cmake ..
+#make -j$(nproc)
+#make install
+#ldconfig
diff --git a/ci/docker/runtime_functions.sh b/ci/docker/runtime_functions.sh
index 6cefeea9fbc..1c53ea19e67 100755
--- a/ci/docker/runtime_functions.sh
+++ b/ci/docker/runtime_functions.sh
@@ -437,6 +437,49 @@ build_ubuntu_gpu() {
     build_ubuntu_gpu_cuda91_cudnn7
 }
 
+build_ubuntu_gpu_tensorrt() {
+
+    set -ex
+    pushd .
+    pushd .
+    # Install ONNX
+    echo "TensorRT build enabled. Installing ONNX."
+    cd 3rdparty/onnx-tensorrt/third_party/onnx
+    rm -rf build
+    mkdir build
+    cd build
+    cmake -DCMAKE_CXX_FLAGS=-I/usr/include/python${PYVER} -DBUILD_SHARED_LIBS=ON ..
+    make -j$(nproc)
+    make install
+    ldconfig
+    cd ..
+    mkdir /usr/include/x86_64-linux-gnu/onnx
+    cp build/onnx/onnx*pb.* /usr/include/x86_64-linux-gnu/onnx
+    cp build/libonnx.so /usr/local/lib
+    ldconfig
+    popd
+    # Install ONNX-TensorRT
+    cd 3rdparty/onnx-tensorrt/
+    mkdir build
+    cd build
+    cmake ..
+    make -j$(nproc)
+    make install
+    ldconfig
+    popd
+
+    make \
+        DEV=1                         \
+        USE_BLAS=openblas             \
+        USE_CUDA=1                    \
+        USE_CUDA_PATH=/usr/local/cuda \
+        USE_CUDNN=1                   \
+        USE_CPP_PACKAGE=1             \
+        USE_DIST_KVSTORE=1            \
+        USE_TENSORRT=1                \
+        ONNX_NAMESPACE=onnx           \
+        -j$(nproc)
+}
+
 build_ubuntu_gpu_mkldnn() {
     set -ex
 
diff --git a/docs/api/python/contrib/tensorrt.md b/docs/api/python/contrib/tensorrt.md
new file mode 100644
index 00000000000..d2ee646ad1e
--- /dev/null
+++ b/docs/api/python/contrib/tensorrt.md
@@ -0,0 +1,117 @@
+# MxNet-TensorRT Runtime Integration
+## What is this?
+
+This document describes how to use the [MxNet](http://mxnet.incubator.apache.org/)-[TensorRT](https://developer.nvidia.com/tensorrt) runtime integration to accelerate model inference.
+
+## Why is TensorRT integration useful? 
+
+TensorRT can greatly speed up inference of deep learning models. One experiment on a Titan V (V100) GPU shows that with MxNet 1.2, we can get an approximately 3x speed-up when running inference of the ResNet-50 model on the CIFAR-10 dataset in single precision (fp32). As batch sizes and image sizes go up (for CNN inference), the benefit may be less, but in general TensorRT helps especially in cases with:
+- many bandwidth-bound layers (e.g. pointwise operations) that benefit from GPU kernel fusion
+- tight latency requirements, where the client application can't wait for large batches to be queued up
+- embedded systems, where memory constraints are tighter than on servers
+- inference in reduced precision, especially integer (e.g. int8) inference.
+
+In the past, the main hindrance for a user wishing to benefit from TensorRT was that the model first needed to be exported from the framework. Once the model was exported through some means (an NNVM-to-TensorRT graph rewrite, ONNX, etc.), one then had to write a TensorRT client application to feed data into the TensorRT engine. At that point the model was independent of the original framework, and since TensorRT could only compute the neural network layers while the user had to bring their own data pipeline, this increased the burden on the user and reduced reproducibility (e.g. different frameworks may have slightly different data pipelines, or differ in the flexibility of data pipeline operation ordering). Moreover, since frameworks typically support more operators than TensorRT, one might have to resort to TensorRT plugins for operations that aren't already available via the TensorRT graph API.
+
+The current experimental runtime integration of TensorRT with MxNet resolves the above concerns by ensuring that:
+- the graph is still executed by MxNet
+- the MxNet data pipeline is preserved
+- the TensorRT runtime integration logic partitions the graph into subgraphs that are either TensorRT compatible or incompatible
+- the graph partitioner collects the TensorRT-compatible subgraphs, hands them over to TensorRT, and substitutes the TensorRT compatible subgraph with a TensorRT library call, represented as a TensorRT node in NNVM.
+- if a node is not TensorRT compatible, it won't be extracted and substituted with a TensorRT call, and will still execute within MxNet
+
+The above points ensure that we find a compromise between the flexibility of MxNet, and fast inference in TensorRT, without putting a burden on the user to learn how TensorRT APIs work, without the need to write one's own client application and data pipeline, etc.
+
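+In practice the user-facing code does not change: the `MXNET_USE_TENSORRT` environment variable controls whether the TensorRT graph pass is applied when the symbol is bound. A minimal sketch of the toggle (the `bind_and_run` helper is a hypothetical placeholder; the full workflow is detailed under "Building your own models" below):
+
+```python
+import os
+
+def set_use_tensorrt(status=False):
+    """Enable or disable the TensorRT graph pass via MXNET_USE_TENSORRT."""
+    os.environ["MXNET_USE_TENSORRT"] = str(int(status))
+
+set_use_tensorrt(False)
+# ... bind the symbol and run inference with plain MxNet, e.g. bind_and_run(sym) ...
+
+set_use_tensorrt(True)
+# ... re-bind the symbol and run the same inference code; TensorRT-compatible
+# subgraphs are now executed by TensorRT ...
+```
+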
+## How do I build MxNet with TensorRT integration?
+
+Building MxNet together with TensorRT is somewhat complex. The recipe will hopefully be simplified in the near future, but for now, it's easiest to build a Docker container with an Ubuntu 16.04 base. This Dockerfile can be found under the ci subdirectory of the MxNet repository. You can build the container as follows:
+
+```
+docker build -f ci/docker/Dockerfile.build.ubuntu_gpu_tensorrt -t mxnet_with_tensorrt ci/docker
+```
+
+Next, we can run this container as follows (don't forget to install [nvidia-docker](https://github.com/NVIDIA/nvidia-docker)):
+
+```no-highlight
+nvidia-docker run -ti --rm mxnet_with_tensorrt
+```
+
+After starting the container, you will find yourself in the /opt/mxnet directory by default.
+
+## Running a "hello, world" model / unit test:
+
+You can then run the LeNet-5 unit test, which trains LeNet-5 on MNIST and then runs inference both with plain MxNet and with the MxNet-TensorRT runtime integration, comparing the results. The test can be run as follows:
+
+```no-highlight
+python tests/python/tensorrt/test_tensorrt_lenet5.py
+```
+
+You should get a result similar to the following:
+
+```no-highlight
+Running inference in MxNet
+[03:31:18] src/operator/nn/./cudnn/./cudnn_algoreg-inl.h:107: Running performance tests to find the best convolution algorithm, this can take a while... (setting env variable MXNET_CUDNN_AUTOTUNE_DEFAULT to 0 to disable)
+Running inference in MxNet-TensorRT
+[03:31:18] src/operator/contrib/nnvm_to_onnx.cc:152: ONNX graph construction complete.
+Building TensorRT engine, FP16 available:1
+    Max batch size:     1024
+    Max workspace size: 1024 MiB
+[03:31:18] src/operator/contrib/tensorrt.cc:85: TensorRT engine instantiated!!!
+MxNet accuracy: 98.680000
+MxNet-TensorRT accuracy: 98.680000
+```
+
+## Running a more complex model
+
+To show that the runtime integration handles more complex models such as ResNet-50 (which includes batch normalization as well as skip connections), the relevant script is included in the `example/image-classification/tensorrt` directory.
+
+## Building your own models
+
+When building your own models, feel free to use the above ResNet-50 model as an example. Here, we highlight a small number of issues that need to be taken into account.
+
+1. When loading a pre-trained model, the inference will be handled using the Symbol API, rather than the Module API.
+2. In order to provide the weights to the MxNet (NNVM) to TensorRT graph converter before the symbol is fully bound (before the memory is allocated, etc.), the `arg_params` and `aux_params` need to be provided to the symbol's `simple_bind` method. The weights and other values (e.g. moments learned from data by batch normalization, provided via `aux_params`) are passed via the `shared_buffer` argument to `simple_bind` as follows:
+```python
+executor = sym.simple_bind(ctx=ctx, data = data_shape,
+    softmax_label=sm_shape, grad_req='null', shared_buffer=all_params, force_rebind=True)
+```
+3. To collect `arg_params` and `aux_params` from the dictionaries loaded by `model.load()`, we need to combine them into one dictionary:
+```python
+def merge_dicts(*dict_args):
+    result = {}
+    for dictionary in dict_args:
+        result.update(dictionary)
+    return result
+
+sym, arg_params, aux_params = mx.model.load_checkpoint(model_prefix, epoch)
+
+all_params = merge_dicts(arg_params, aux_params)
+```
+This `all_params` dictionary can be seen in use in the `simple_bind` call in step 2.
+4. Once the symbol is bound, we need to feed the data and run the `forward()` method. Let's say we're using a test set data iterator called `test_iter`. We can run inference as follows:
+```python
+for idx, dbatch in enumerate(test_iter):
+    data = dbatch.data[0]
+    executor.arg_dict["data"][:] = data
+    executor.forward(is_train=False)
+    preds = executor.outputs[0].asnumpy() 
+    top1 = np.argmax(preds, axis=1)
+```
+5. **Note:** One can choose between running inference with and without TensorRT. This can be selected by changing the state of the `MXNET_USE_TENSORRT` environment variable. Let's first write a convenience function to change the state of this environment variable:
+```python
+def set_use_tensorrt(status = False):
+    os.environ["MXNET_USE_TENSORRT"] = str(int(status))
+```
+Now, assuming that the logic to bind a symbol and run inference in batches of `batch_size` on dataset `dataset` is wrapped in the `run_inference` function, we can do the following:
+```python
+print("Running inference in MxNet")
+set_use_tensorrt(False)
+mx_pct = run_inference(sym, arg_params, aux_params, mnist,
+                       all_test_labels, batch_size=batch_size)
+
+print("Running inference in MxNet-TensorRT")
+set_use_tensorrt(True)
+trt_pct = run_inference(sym, arg_params, aux_params, mnist,
+                        all_test_labels,  batch_size=batch_size)
+```
+Simply switching the flag allows us to go back and forth between MxNet and MxNet-TensorRT inference. See the details in the unit test at `tests/python/tensorrt/test_tensorrt_lenet5.py`.
diff --git a/example/image-classification/tensorrt/test_tensorrt_resnet50.py b/example/image-classification/tensorrt/test_tensorrt_resnet50.py
new file mode 100644
index 00000000000..848967d69fb
--- /dev/null
+++ b/example/image-classification/tensorrt/test_tensorrt_resnet50.py
@@ -0,0 +1,186 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import print_function
+
+import os.path
+import errno
+import subprocess
+import mxnet as mx
+import numpy as np
+from time import time
+import sys
+import urllib
+
+def get_use_tensorrt():
+    return int(os.environ.get("MXNET_USE_TENSORRT", 0))
+
+def set_use_tensorrt(status = False):
+    os.environ["MXNET_USE_TENSORRT"] = str(int(status))
+
+def download_file(url, local_fname=None, force_write=False):
+    # requests is not default installed
+    import requests
+    if local_fname is None:
+        local_fname = url.split('/')[-1]
+    if not force_write and os.path.exists(local_fname):
+        return local_fname
+
+    dir_name = os.path.dirname(local_fname)
+
+    if dir_name != "":
+        if not os.path.exists(dir_name):
+            try: # try to create the directory if it doesn't exists
+                os.makedirs(dir_name)
+            except OSError as exc:
+                if exc.errno != errno.EEXIST:
+                    raise
+
+    r = requests.get(url, stream=True)
+    assert r.status_code == 200, "failed to open %s" % url
+    with open(local_fname, 'wb') as f:
+        for chunk in r.iter_content(chunk_size=1024):
+            if chunk: # filter out keep-alive new chunks
+                f.write(chunk)
+    return local_fname
+
+def download_cifar10(data_dir):
+    fnames = (os.path.join(data_dir, "cifar10_train.rec"),
+              os.path.join(data_dir, "cifar10_val.rec"))
+    download_file('http://data.mxnet.io/data/cifar10/cifar10_val.rec', fnames[1])
+    download_file('http://data.mxnet.io/data/cifar10/cifar10_train.rec', fnames[0])
+    return fnames
+
+def get_cifar10_iterator(args, kv):
+    data_shape = (3, 32, 32)
+    data_dir = args['data_dir']
+    if os.name == "nt":
+        data_dir = data_dir[:-1] + "\\"
+    if '://' not in args['data_dir']:
+        print("Did not find data.")
+        download_cifar10(data_dir)
+
+    train = mx.io.ImageRecordIter(
+        path_imgrec = os.path.join(data_dir, "cifar10_train.rec"),
+        mean_img    = os.path.join(data_dir, "mean.bin"),
+        data_shape  = data_shape,
+        batch_size  = args['batch_size'],
+        rand_crop   = True,
+        rand_mirror = True,
+        num_parts   = kv['num_workers'],
+        part_index  = kv['rank'])
+
+    val = mx.io.ImageRecordIter(
+        path_imgrec = os.path.join(data_dir, "cifar10_val.rec"),
+        mean_img    = os.path.join(data_dir, "mean.bin"),
+        rand_crop   = False,
+        rand_mirror = False,
+        data_shape  = data_shape,
+        batch_size  = args['batch_size'],
+        num_parts   = kv['num_workers'],
+        part_index  = kv['rank'])
+
+    return (train, val)
+
+
+# To support Python 2 and 3.x < 3.5
+def merge_dicts(*dict_args):
+    result = {}
+    for dictionary in dict_args:
+        result.update(dictionary)
+    return result
+
+def get_exec(model_prefix='resnet50', image_size=(32, 32), batch_size = 128, ctx=mx.gpu(0), epoch=1):
+
+    sym, arg_params, aux_params = mx.model.load_checkpoint(model_prefix, epoch)
+
+    h, w = image_size
+    data_shape=(batch_size, 3, h, w)
+    sm_shape=(batch_size,)
+
+    data = mx.sym.Variable("data")
+    softmax_label = mx.sym.Variable("softmax_label")
+
+    all_params = merge_dicts(arg_params, aux_params)
+
+    if not get_use_tensorrt():
+        all_params = dict([(k, v.as_in_context(mx.gpu(0))) for k, v in all_params.items()])
+
+    executor = sym.simple_bind(ctx=ctx, data = data_shape,
+        softmax_label=sm_shape, grad_req='null', shared_buffer=all_params, force_rebind=True)
+
+    return executor, h, w
+
+def compute(model_prefix, epoch, data_dir, batch_size=128):
+
+    executor, h, w = get_exec(model_prefix=model_prefix,
+                              image_size=(32, 32), 
+                              batch_size=batch_size, 
+                              ctx=mx.gpu(0),
+                              epoch=epoch)
+    num_ex = 10000
+    all_preds = np.zeros([num_ex, 10])
+
+    train_iter, test_iter = get_cifar10_iterator(args={'data_dir':data_dir, 'batch_size':batch_size}, kv={'num_workers':1, 'rank':0})
+
+    train_iter2, test_iter2 = get_cifar10_iterator(args={'data_dir':data_dir, 'batch_size':num_ex}, kv={'num_workers':1, 'rank':0})
+
+    all_label_train = train_iter2.next().label[0].asnumpy()
+    all_label_test = test_iter2.next().label[0].asnumpy().astype(np.int32)
+
+    train_iter, test_iter = get_cifar10_iterator(args={'data_dir':data_dir, 'batch_size':batch_size}, kv={'num_workers':1, 'rank':0})
+
+    start = time()
+
+    example_ct = 0
+
+    for idx, dbatch in enumerate(test_iter):
+        data = dbatch.data[0]
+        executor.arg_dict["data"][:] = data
+        executor.forward(is_train=False)
+        preds = executor.outputs[0].asnumpy()
+        offset = idx*batch_size
+        extent = batch_size if num_ex - offset > batch_size else num_ex - offset
+        all_preds[offset:offset+extent, :] = preds[:extent]
+        example_ct += extent
+
+    all_preds = np.argmax(all_preds, axis=1)
+
+    matches = (all_preds[:example_ct] == all_label_test[:example_ct]).sum()
+
+    percentage = 100.0 * matches / example_ct
+
+    return percentage, time() - start
+
+if __name__ == '__main__':
+
+    model_prefix = sys.argv[1]
+    epoch = int(sys.argv[2])
+    data_dir = sys.argv[3]
+    batch_size = 1024
+
+    print("\nRunning inference in MxNet\n")
+    set_use_tensorrt(False)
+    mxnet_pct, mxnet_time = compute(model_prefix, epoch, data_dir, batch_size)
+
+    print("\nRunning inference in MxNet-TensorRT\n")
+    set_use_tensorrt(True)
+    trt_pct, trt_time = compute(model_prefix, epoch, data_dir, batch_size)
+
+    print("MxNet time: %f" % mxnet_time)
+    print("MxNet-TensorRT time: %f" % trt_time)
+    print("Speedup: %fx" % (mxnet_time / trt_time))
+
diff --git a/example/image-classification/tensorrt/test_tensorrt_resnet50.sh b/example/image-classification/tensorrt/test_tensorrt_resnet50.sh
new file mode 100755
index 00000000000..c0ec30238af
--- /dev/null
+++ b/example/image-classification/tensorrt/test_tensorrt_resnet50.sh
@@ -0,0 +1,37 @@
+#!/bin/bash
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+EPOCH=20
+MODEL_PREFIX="resnet50"
+SYMBOL="${MODEL_PREFIX}-symbol.json"
+PARAMS="${MODEL_PREFIX}-$(printf "%04d" $EPOCH).params"
+DATA_DIR="./data"
+
+if [[ ! -f $SYMBOL || ! -f $PARAMS ]]; then
+  echo -e "\nTrained model does not exist. Training - please wait...\n"
+  python $MXNET_HOME/example/image-classification/train_cifar10.py \
+     --network resnet --num-layers 50 --num-epochs ${EPOCH} \
+     --model-prefix ./${MODEL_PREFIX} --gpus 0
+else
+   echo "Pre-trained model exists. Skipping training."
+fi
+
+echo "Running inference script."
+
+python test_tensorrt_resnet50.py $MODEL_PREFIX $EPOCH $DATA_DIR
+
diff --git a/include/mxnet/executor.h b/include/mxnet/executor.h
index 842653f8653..20b5deff9ed 100644
--- a/include/mxnet/executor.h
+++ b/include/mxnet/executor.h
@@ -152,19 +152,19 @@ class Executor {
   static Executor* SimpleBind(nnvm::Symbol symbol,
                               const Context& default_ctx,
                               const std::map<std::string, Context>& group2ctx,
-                              const std::vector<Context>& in_arg_ctxes,
-                              const std::vector<Context>& arg_grad_ctxes,
-                              const std::vector<Context>& aux_state_ctxes,
-                              const std::unordered_map<std::string, TShape>& arg_shape_map,
-                              const std::unordered_map<std::string, int>& arg_dtype_map,
-                              const std::unordered_map<std::string, int>& arg_stype_map,
-                              const std::vector<OpReqType>& grad_req_types,
-                              const std::unordered_set<std::string>& param_names,
+                              std::vector<Context>* in_arg_ctxes,
+                              std::vector<Context>* arg_grad_ctxes,
+                              std::vector<Context>* aux_state_ctxes,
+                              std::unordered_map<std::string, TShape>* arg_shape_map,
+                              std::unordered_map<std::string, int>* arg_dtype_map,
+                              std::unordered_map<std::string, int>* arg_stype_map,
+                              std::vector<OpReqType>* grad_req_types,
+                              std::unordered_set<std::string>* param_names,
                               std::vector<NDArray>* in_args,
                               std::vector<NDArray>* arg_grads,
                               std::vector<NDArray>* aux_states,
                               std::unordered_map<std::string, NDArray>*
-                                shared_data_arrays = nullptr,
+                              shared_data_arrays = nullptr,
                               Executor* shared_exec = nullptr);
   /*!
    * \brief the prototype of user-defined monitor callback
diff --git a/python/mxnet/cuda_utils.py b/python/mxnet/cuda_utils.py
new file mode 100644
index 00000000000..11f8fa43995
--- /dev/null
+++ b/python/mxnet/cuda_utils.py
@@ -0,0 +1,90 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# Copyright (c) 2015 by Contributors
+# File: cuda_utils.py
+# Purpose: Functions to query GPU count, arch, etc.
+# Author: Dick Carter
+
+"""Provides information on the visible CUDA GPUs on the system."""
+# pylint: disable=broad-except
+# As a stand-alone program, it prints a list of unique cuda SM architectures
+import ctypes as C
+from ctypes.util import find_library
+
+def cint(init_val=0):
+    """create a C int with an optional initial value"""
+    return C.c_int(init_val)
+
+def int_addr(x):
+    """given a c_int, return it's address as an int ptr"""
+    x_addr = C.addressof(x)
+    INTP = C.POINTER(C.c_int)
+    x_int_addr = C.cast(x_addr, INTP)
+    return x_int_addr
+
+def checked_call(f, *args):
+    """call a cuda function and check for success"""
+    error_t = f(*args)
+    assert error_t == 0, "Failing cuda call %s returns %s." % (f.__name__, error_t)
+
+def find_cuda_lib(candidates):
+    for candidate in candidates:
+        try:
+            lib = find_library(candidate)
+            if lib is not None:
+                return lib
+        except Exception:
+            pass
+    return None
+
+# Find cuda library in an os-independent way ('nvcuda' needed for Windows)
+try:
+    cuda = C.cdll.LoadLibrary(find_cuda_lib(['cuda', 'nvcuda']))
+    checked_call(cuda.cuInit, cint(0))
+except Exception:
+    cuda = None
+
+def get_device_count():
+    """get number of cuda devices on the system"""
+    if cuda is None:
+        return 0
+    else:
+        device_count = cint()
+        checked_call(cuda.cuDeviceGetCount, int_addr(device_count))
+        return device_count.value
+
+def get_sm_arch(device_id):
+    """get SM architecture of the device at the given index"""
+    major = cint()
+    minor = cint()
+    checked_call(cuda.cuDeviceComputeCapability, int_addr(major),
+                 int_addr(minor),
+                 cint(device_id))
+    return 10 * major.value + minor.value
+
+def unique_sm_arches():
+    """returns a list of unique cuda SM architectures on the system"""
+    archs = set()
+    device_count = get_device_count()
+    for device_id in range(device_count):
+        archs.add(get_sm_arch(device_id))
+    return sorted(archs)
+
+# print a list of unique cuda SM architectures on the system
+if __name__ == '__main__':
+    print(' '.join(str(x) for x in unique_sm_arches()))
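+
+# A hedged usage sketch from other Python code (assuming this module is importable
+# as mxnet.cuda_utils and at least one CUDA device is visible):
+#
+#   from mxnet import cuda_utils
+#   if cuda_utils.get_device_count() > 0:
+#       print("SM arch of device 0:", cuda_utils.get_sm_arch(0))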
diff --git a/python/mxnet/module/executor_group.py b/python/mxnet/module/executor_group.py
index 5d8e95077c4..c4050699bd5 100755
--- a/python/mxnet/module/executor_group.py
+++ b/python/mxnet/module/executor_group.py
@@ -592,8 +592,8 @@ def backward(self, out_grads=None):
                     # pylint: disable=no-member
                     og_my_slice = nd.slice_axis(grad, axis=axis, begin=islice.start,
                                                 end=islice.stop)
-                    # pylint: enable=no-member
                     out_grads_slice.append(og_my_slice.as_in_context(self.contexts[i]))
+                    # pylint: enable=no-member
                 else:
                     out_grads_slice.append(grad.copyto(self.contexts[i]))
             exec_.backward(out_grads=out_grads_slice)
diff --git a/src/c_api/c_api_executor.cc b/src/c_api/c_api_executor.cc
index 09bc23934e5..d7f21223116 100644
--- a/src/c_api/c_api_executor.cc
+++ b/src/c_api/c_api_executor.cc
@@ -26,6 +26,7 @@
 #include <mxnet/c_api.h>
 #include <mxnet/executor.h>
 #include "./c_api_common.h"
+#include "../executor/graph_executor.h"
 
 int MXExecutorPrint(ExecutorHandle handle, const char **out_str) {
   Executor *exec = static_cast<Executor*>(handle);
@@ -440,9 +441,9 @@ int MXExecutorSimpleBind(SymbolHandle symbol_handle,
   std::vector<NDArray> arg_grad_vec;
   std::vector<NDArray> aux_state_vec;
 
-  *out = Executor::SimpleBind(*sym, ctx, ctx_map, in_arg_ctx_vec, arg_grad_ctx_vec,
-                              aux_state_ctx_vec, arg_shape_map, arg_dtype_map, arg_stype_map,
-                              grad_req_type_vec, shared_arg_name_set, &in_arg_vec,
+  *out = Executor::SimpleBind(*sym, ctx, ctx_map, &in_arg_ctx_vec, &arg_grad_ctx_vec,
+                              &aux_state_ctx_vec, &arg_shape_map, &arg_dtype_map, &arg_stype_map,
+                              &grad_req_type_vec, &shared_arg_name_set, &in_arg_vec,
                               &arg_grad_vec, &aux_state_vec,
                               use_shared_buffer ? &shared_buffer_map : nullptr,
                               reinterpret_cast<Executor*>(shared_exec_handle));
diff --git a/src/common/serialization.h b/src/common/serialization.h
new file mode 100644
index 00000000000..5fb72ad3c99
--- /dev/null
+++ b/src/common/serialization.h
@@ -0,0 +1,526 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2015 by Contributors
+ * \file serialization.h
+ * \brief Serialization of some STL and nnvm data-structures
+ * \author Clement Fuji Tsang
+ */
+
+#ifndef MXNET_COMMON_SERIALIZATION_H_
+#define MXNET_COMMON_SERIALIZATION_H_
+
+#include <dmlc/logging.h>
+#include <mxnet/graph_attr_types.h>
+#include <nnvm/graph_attr_types.h>
+#include <nnvm/tuple.h>
+
+#include <cstring>
+#include <map>
+#include <set>
+#include <string>
+#include <tuple>
+#include <unordered_map>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
+
+namespace mxnet {
+namespace common {
+
+template<typename T>
+inline size_t serialized_size(const T& obj);
+
+template<typename T>
+inline size_t serialized_size(const nnvm::Tuple<T>& obj);
+
+template<typename T>
+inline size_t serialized_size(const std::vector<T>& obj);
+
+template<typename K, typename V>
+inline size_t serialized_size(const std::pair<K, V>& obj);
+
+template<typename K, typename V>
+inline size_t serialized_size(const std::map<K, V>& obj);
+
+template<typename K, typename V>
+inline size_t serialized_size(const std::unordered_map<K, V>& obj);
+
+template<typename K>
+inline size_t serialized_size(const std::set<K>& obj);
+
+template<typename K>
+inline size_t serialized_size(const std::unordered_set<K>& obj);
+
+template<>
+inline size_t serialized_size(const std::string& obj);
+
+template<typename... Args>
+inline size_t serialized_size(const std::tuple<Args...>& obj);
+
+template<typename T>
+inline void serialize(const T& obj, char** buffer);
+
+template<typename T>
+inline void serialize(const nnvm::Tuple<T>& obj, char** buffer);
+
+template<typename T>
+inline void serialize(const std::vector<T>& obj, char** buffer);
+
+template<typename K, typename V>
+inline void serialize(const std::pair<K, V>& obj, char** buffer);
+
+template<typename K, typename V>
+inline void serialize(const std::map<K, V>& obj, char** buffer);
+
+template<typename K, typename V>
+inline void serialize(const std::unordered_map<K, V>& obj, char** buffer);
+
+template<typename K>
+inline void serialize(const std::set<K>& obj, char** buffer);
+
+template<typename K>
+inline void serialize(const std::unordered_set<K>& obj, char** buffer);
+
+template<>
+inline void serialize(const std::string& obj, char** buffer);
+
+template<typename... Args>
+inline void serialize(const std::tuple<Args...>& obj, char** buffer);
+
+template<typename T>
+inline void deserialize(T* obj, const std::string& buffer, size_t* curr_pos);
+
+template<typename T>
+inline void deserialize(nnvm::Tuple<T>* obj, const std::string& buffer, size_t* curr_pos);
+
+template<typename T>
+inline void deserialize(std::vector<T>* obj, const std::string& buffer, size_t* curr_pos);
+
+template<typename K, typename V>
+inline void deserialize(std::pair<K, V>* obj, const std::string& buffer, size_t* curr_pos);
+
+template<typename K, typename V>
+inline void deserialize(std::map<K, V>* obj, const std::string& buffer, size_t* curr_pos);
+
+template<typename K, typename V>
+inline void deserialize(std::unordered_map<K, V>* obj, const std::string& buffer, size_t* curr_pos);
+
+template<typename K>
+inline void deserialize(std::set<K>* obj, const std::string& buffer, size_t* curr_pos);
+
+template<typename K>
+inline void deserialize(std::unordered_set<K>* obj, const std::string& buffer, size_t* curr_pos);
+
+template<>
+inline void deserialize(std::string* obj, const std::string& buffer, size_t* curr_pos);
+
+template<typename... Args>
+inline void deserialize(std::tuple<Args...>* obj, const std::string& buffer, size_t* curr_pos);
+
+
+template<typename T>
+struct is_cont {
+  static const bool value = !std::is_pod<T>::value;
+};
+
+template<typename T>
+inline size_t serialized_size(const T& obj) {
+  return sizeof(T);
+}
+
+template<typename T>
+inline size_t serialized_size(const nnvm::Tuple<T>& obj) {
+  if (is_cont<T>::value) {
+    size_t sum_val = 4;
+    for (auto& el : obj) {
+      sum_val += serialized_size(el);
+    }
+    return sum_val;
+  } else {
+    return 4 + (obj.ndim() * sizeof(T));
+  }
+}
+
+template<typename T>
+inline size_t serialized_size(const std::vector<T>& obj) {
+  if (is_cont<T>::value) {
+    size_t sum_val = 4;
+    for (T i : obj) {
+      sum_val += serialized_size(i);
+    }
+    return sum_val;
+  } else {
+    return sizeof(T) * obj.size() + 4;
+  }
+}
+
+template<typename K, typename V>
+inline size_t serialized_size(const std::pair<K, V>& obj) {
+  return serialized_size(obj.first) + serialized_size(obj.second);
+}
+
+template<typename K, typename V>
+inline size_t serialized_size(const std::map<K, V>& obj) {
+  size_t sum_val = 4;
+  if (is_cont<K>::value && is_cont<V>::value) {
+    for (auto p : obj) {
+      sum_val += serialized_size(p.first) + serialized_size(p.second);
+    }
+  } else if (is_cont<K>::value) {
+    for (auto p : obj) {
+      sum_val += serialized_size(p.first);
+    }
+    sum_val += sizeof(V) * obj.size();
+  } else if (is_cont<V>::value) {
+    for (auto p : obj) {
+      sum_val += serialized_size(p.second);
+    }
+    sum_val += sizeof(K) * obj.size();
+  } else {
+    sum_val += (sizeof(K) + sizeof(V)) * obj.size();
+  }
+  return sum_val;
+}
+
+template<typename K, typename V>
+inline size_t serialized_size(const std::unordered_map<K, V>& obj) {
+  size_t sum_val = 4;
+  if (is_cont<K>::value && is_cont<V>::value) {
+    for (auto p : obj) {
+      sum_val += serialized_size(p.first) + serialized_size(p.second);
+    }
+  } else if (is_cont<K>::value) {
+    for (auto p : obj) {
+      sum_val += serialized_size(p.first);
+    }
+    sum_val += sizeof(V) * obj.size();
+  } else if (is_cont<V>::value) {
+    for (auto p : obj) {
+      sum_val += serialized_size(p.second);
+    }
+    sum_val += sizeof(K) * obj.size();
+  } else {
+    sum_val += (sizeof(K) + sizeof(V)) * obj.size();
+  }
+  return sum_val;
+}
+
+template<typename K>
+inline size_t serialized_size(const std::set<K>& obj) {
+  if (is_cont<K>::value) {
+    size_t sum_val = 4;
+    for (auto& el : obj) {
+      sum_val += serialized_size(el);
+    }
+    return sum_val;
+  } else {
+    return (sizeof(K) * obj.size()) + 4;
+  }
+}
+
+template<typename K>
+inline size_t serialized_size(const std::unordered_set<K>& obj) {
+  if (is_cont<K>::value) {
+    size_t sum_val = 4;
+    for (auto& el : obj) {
+      sum_val += serialized_size(el);
+    }
+    return sum_val;
+  } else {
+    return (sizeof(K) * obj.size()) + 4;
+  }
+}
+
+template<>
+inline size_t serialized_size(const std::string& obj) {
+  return obj.size() + 4;
+}
+
+template<int I>
+struct serialized_size_tuple {
+  template<typename... Args>
+  static inline size_t compute(const std::tuple<Args...>& obj) {
+    return serialized_size(std::get<I>(obj)) + serialized_size_tuple<I-1>::compute(obj);
+  }
+};
+
+template<>
+struct serialized_size_tuple<0> {
+  template<typename... Args>
+  static inline size_t compute(const std::tuple<Args...>& obj) {
+    return serialized_size(std::get<0>(obj));
+  }
+};
+
+template<typename... Args>
+inline size_t serialized_size(const std::tuple<Args...>& obj) {
+  return serialized_size_tuple<sizeof... (Args)-1>::compute(obj);
+}
+
+//  SERIALIZE
+
+template<typename T>
+inline size_t serialize_cont_size(const T& obj, char** buffer) {
+  uint32_t size = obj.size();
+  std::memcpy(*buffer, &size, 4);
+  *buffer += 4;
+  return (size_t) size;
+}
+
+template<typename T>
+inline void serialize(const T& obj, char** buffer) {
+  std::memcpy(*buffer, &obj, sizeof(T));
+  *buffer += sizeof(T);
+}
+
+template<typename T>
+inline void serialize(const nnvm::Tuple<T>& obj, char** buffer) {
+  uint32_t size = obj.ndim();
+  std::memcpy(*buffer, &size, 4);
+  *buffer += 4;
+  for (auto& el : obj) {
+    serialize(el, buffer);
+  }
+}
+
+template<typename T>
+inline void serialize(const std::vector<T>& obj, char** buffer) {
+  auto size = serialize_cont_size(obj, buffer);
+  if (is_cont<T>::value) {
+    for (auto& el : obj) {
+      serialize(el, buffer);
+    }
+  } else {
+    std::memcpy(*buffer, &obj[0], sizeof(T) * size);
+    *buffer += sizeof(T) * size;
+  }
+}
+
+template<typename K, typename V>
+inline void serialize(const std::pair<K, V>& obj, char** buffer) {
+  serialize(obj.first, buffer);
+  serialize(obj.second, buffer);
+}
+
+template<typename K, typename V>
+inline void serialize(const std::map<K, V>& obj, char** buffer) {
+  serialize_cont_size(obj, buffer);
+  for (auto& p : obj) {
+    serialize(p.first, buffer);
+    serialize(p.second, buffer);
+  }
+}
+
+template<typename K, typename V>
+inline void serialize(const std::unordered_map<K, V>& obj, char** buffer) {
+  serialize_cont_size(obj, buffer);
+  for (auto& p : obj) {
+    serialize(p.first, buffer);
+    serialize(p.second, buffer);
+  }
+}
+
+template<typename K>
+inline void serialize(const std::set<K>& obj, char** buffer) {
+  serialize_cont_size(obj, buffer);
+  for (auto& el : obj) {
+    serialize(el, buffer);
+  }
+}
+
+template<typename K>
+inline void serialize(const std::unordered_set<K>& obj, char** buffer) {
+  serialize_cont_size(obj, buffer);
+  for (auto& el : obj) {
+    serialize(el, buffer);
+  }
+}
+
+template<>
+inline void serialize(const std::string& obj, char** buffer) {
+  auto size = serialize_cont_size(obj, buffer);
+  std::memcpy(*buffer, &obj[0], size);
+  *buffer += size;
+}
+
+template<int I>
+struct serialize_tuple {
+  template<typename... Args>
+  static inline void compute(const std::tuple<Args...>& obj, char** buffer) {
+    serialize_tuple<I-1>::compute(obj, buffer);
+    serialize(std::get<I>(obj), buffer);
+  }
+};
+
+template<>
+struct serialize_tuple<0> {
+  template<typename... Args>
+  static inline void compute(const std::tuple<Args...>& obj, char** buffer) {
+    serialize(std::get<0>(obj), buffer);
+  }
+};
+
+template<typename... Args>
+inline void serialize(const std::tuple<Args...>& obj, char** buffer) {
+  serialize_tuple<sizeof... (Args)-1>::compute(obj, buffer);
+}
+
+
+// Deserializer
+template<typename T>
+inline size_t deserialize_cont_size(T* obj, const std::string& buffer, size_t* curr_pos) {
+  uint32_t size = obj->size();
+  std::memcpy(&size, &buffer[*curr_pos], 4);
+  *curr_pos += 4;
+  return (size_t) size;
+}
+
+template<typename T>
+inline void deserialize(T* obj, const std::string& buffer, size_t* curr_pos) {
+  std::memcpy(obj, &buffer[*curr_pos], sizeof(T));
+  *curr_pos += sizeof(T);
+}
+
+template<typename T>
+inline void deserialize(nnvm::Tuple<T>* obj, const std::string& buffer, size_t* curr_pos) {
+  uint32_t size = obj->ndim();
+  std::memcpy(&size, &buffer[*curr_pos], 4);
+  *curr_pos += 4;
+  obj->SetDim(size);
+  for (size_t i = 0; i < size; ++i) {
+    deserialize(&(*obj)[i], buffer, curr_pos);
+  }
+}
+
+template<typename T>
+inline void deserialize(std::vector<T>* obj, const std::string& buffer, size_t* curr_pos) {
+  auto size = deserialize_cont_size(obj, buffer, curr_pos);
+  obj->resize(size);
+  if (is_cont<T>::value) {
+    for (size_t i = 0; i < size; ++i) {
+      deserialize(&(*obj)[i], buffer, curr_pos);
+    }
+  } else {
+    std::memcpy(&(obj->front()), &buffer[*curr_pos], sizeof(T) * size);
+    *curr_pos += sizeof(T) * size;
+  }
+}
+
+template<typename K, typename V>
+inline void deserialize(std::pair<K, V>* obj, const std::string& buffer, size_t* curr_pos) {
+  deserialize(&obj->first, buffer, curr_pos);
+  deserialize(&obj->second, buffer, curr_pos);
+}
+
+template<typename K, typename V>
+inline void deserialize(std::map<K, V>* obj, const std::string& buffer, size_t* curr_pos) {
+  auto size = deserialize_cont_size(obj, buffer, curr_pos);
+  K first;
+  for (size_t i = 0; i < size; ++i) {
+    deserialize(&first, buffer, curr_pos);
+    deserialize(&(*obj)[first], buffer, curr_pos);
+  }
+}
+
+template<typename K, typename V>
+inline void deserialize(std::unordered_map<K, V>* obj,
+                        const std::string& buffer, size_t* curr_pos) {
+  auto size = deserialize_cont_size(obj, buffer, curr_pos);
+  K first;
+  for (size_t i = 0; i < size; ++i) {
+    deserialize(&first, buffer, curr_pos);
+    deserialize(&(*obj)[first], buffer, curr_pos);
+  }
+}
+
+template<typename K>
+inline void deserialize(std::set<K>* obj, const std::string& buffer, size_t* curr_pos) {
+  auto size = deserialize_cont_size(obj, buffer, curr_pos);
+  K first;
+  for (size_t i = 0; i < size; ++i) {
+    deserialize(&first, buffer, curr_pos);
+    obj->insert(first);
+  }
+}
+
+template<typename K>
+inline void deserialize(std::unordered_set<K>* obj, const std::string& buffer, size_t* curr_pos) {
+  auto size = deserialize_cont_size(obj, buffer, curr_pos);
+  K first;
+  for (size_t i = 0; i < size; ++i) {
+    deserialize(&first, buffer, curr_pos);
+    obj->insert(first);
+  }
+}
+
+template<>
+inline void deserialize(std::string* obj, const std::string& buffer, size_t* curr_pos) {
+  auto size = deserialize_cont_size(obj, buffer, curr_pos);
+  obj->resize(size);
+  std::memcpy(&(obj->front()), &buffer[*curr_pos], size);
+  *curr_pos += size;
+}
+
+template<int I>
+struct deserialize_tuple {
+  template<typename... Args>
+  static inline void compute(std::tuple<Args...>* obj,
+                             const std::string& buffer, size_t* curr_pos) {
+    deserialize_tuple<I-1>::compute(obj, buffer, curr_pos);
+    deserialize(&std::get<I>(*obj), buffer, curr_pos);
+  }
+};
+
+template<>
+struct deserialize_tuple<0> {
+  template<typename... Args>
+  static inline void compute(std::tuple<Args...>* obj,
+                             const std::string& buffer, size_t* curr_pos) {
+    deserialize(&std::get<0>(*obj), buffer, curr_pos);
+  }
+};
+
+template<typename... Args>
+inline void deserialize(std::tuple<Args...>* obj, const std::string& buffer, size_t* curr_pos) {
+  deserialize_tuple<sizeof... (Args)-1>::compute(obj, buffer, curr_pos);
+}
+
+
+template<typename T>
+inline void Serialize(const T& obj, std::string* serialized_data) {
+  serialized_data->resize(serialized_size(obj));
+  char* curr_pos = &(serialized_data->front());
+  serialize(obj, &curr_pos);
+  CHECK_EQ((int64_t)curr_pos - (int64_t)&(serialized_data->front()),
+           serialized_data->size());
+}
+
+template<typename T>
+inline void Deserialize(T* obj, const std::string& serialized_data) {
+  size_t curr_pos = 0;
+  deserialize(obj, serialized_data, &curr_pos);
+  CHECK_EQ(curr_pos, serialized_data.size());
+}
+
+}  // namespace common
+}  // namespace mxnet
+#endif  // MXNET_COMMON_SERIALIZATION_H_
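For orientation, here is a minimal round-trip sketch of the Serialize/Deserialize helpers added above. It is illustrative only and not part of this patch; it assumes the serialize overloads earlier in this header mirror the deserialize overloads shown here, and that the new header is reachable on the include path.

#include <map>
#include <string>
#include <vector>
#include "common/serialization.h"   // include path is an assumption for this sketch

void RoundTripExample() {
  std::map<std::string, std::vector<int>> original{{"fc1_weight", {128, 784}}};
  std::string buffer;
  // Sizes the buffer via serialized_size(), then packs the object in place.
  mxnet::common::Serialize(original, &buffer);
  std::map<std::string, std::vector<int>> restored;
  // Walks the buffer and CHECKs that the whole buffer was consumed.
  mxnet::common::Deserialize(&restored, buffer);
}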
diff --git a/src/executor/exec_pass.h b/src/executor/exec_pass.h
index 99b1b162eae..9e4276f283d 100644
--- a/src/executor/exec_pass.h
+++ b/src/executor/exec_pass.h
@@ -178,6 +178,27 @@ Graph InferStorageType(Graph&& graph,
                        StorageTypeVector&& storage_type_inputs = StorageTypeVector(),
                        const std::string& storage_type_attr_key = "");
 
+/*! \brief The default storage type inference function, which assigns all undefined
+ *         storage types to kDefaultStorage. If all input and output storage types
+ *         are kDefaultStorage, DispatchMode::kFCompute is assigned to dispatch_mode. Otherwise,
+ *         DispatchMode::kFComputeFallback is assigned to dispatch_mode.
+ */
+bool DefaultStorageType(const nnvm::NodeAttrs& attrs,
+                        const int dev_mask,
+                        DispatchMode* dispatch_mode,
+                        std::vector<int> *iattr,
+                        std::vector<int> *oattr);
+
+/*!
+ * \brief Replace subgraphs by TRT (forward only)
+ */
+Graph ReplaceSubgraph(Graph&& g,
+                      std::unordered_set<nnvm::Node*> set_subgraph,
+                      std::unordered_map<std::string, NDArray>* const params_map);
+
+std::vector<std::unordered_set<nnvm::Node*>> GetTrtCompatibleSubsets(const Graph& g,
+    std::unordered_map<std::string, NDArray>* const params_map);
+
 }  // namespace exec
 }  // namespace mxnet
 
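The two new declarations above are meant to be used together by the executor: GetTrtCompatibleSubsets partitions the graph into TensorRT-compatible node sets, and each set large enough to be worth fusing is handed to ReplaceSubgraph. A condensed sketch of that flow (it mirrors the graph_executor.cc changes further below and introduces no new API; it assumes exec_pass.h and the usual nnvm headers are included):

nnvm::Graph ApplyTensorRTPass(nnvm::Graph g,
                              std::unordered_map<std::string, mxnet::NDArray>* params_map) {
  auto trt_groups = mxnet::exec::GetTrtCompatibleSubsets(g, params_map);
  for (const auto& trt_group : trt_groups) {
    // Single-node groups are skipped: wrapping one node in an engine buys nothing.
    if (trt_group.size() > 1) {
      g = mxnet::exec::ReplaceSubgraph(std::move(g), trt_group, params_map);
    }
  }
  return g;
}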
diff --git a/src/executor/graph_executor.cc b/src/executor/graph_executor.cc
index e28867d5488..f2cb99c0910 100644
--- a/src/executor/graph_executor.cc
+++ b/src/executor/graph_executor.cc
@@ -34,6 +34,13 @@
 #include "../common/utils.h"
 #include "../common/exec_utils.h"
 
+#if MXNET_USE_TENSORRT
+#include <onnx/onnx.pb.h>
+#include <NvInfer.h>
+#include "./onnx_to_tensorrt.h"
+#include "../operator/contrib/tensorrt-inl.h"
+#endif  // MXNET_USE_TENSORRT
+
 namespace mxnet {
 namespace exec {
 
@@ -781,8 +788,13 @@ void GraphExecutor::InitArguments(const nnvm::IndexedGraph& idx,
           << arg_name << " for the current executor";
         aux_state_vec->emplace_back(aux_nd);
       } else {
-        EmplaceBackZeros(inferred_stype, inferred_shape, aux_state_ctxes[aux_top],
-                         inferred_dtype, aux_state_vec);
+        auto it = shared_buffer->find(arg_name);
+        if ( it != shared_buffer->end() ) {
+          aux_state_vec->push_back(std::move(it->second.Copy(aux_state_ctxes[aux_top])));
+        } else {
+          aux_state_vec->push_back(std::move(InitZeros(inferred_stype, inferred_shape,
+                                                    aux_state_ctxes[aux_top], inferred_dtype)));
+        }
       }  // if (has_shared_exec)
       data_entry_[eid] = aux_state_vec->back();
       aux_state_map_.emplace(arg_name, aux_state_vec->back());
@@ -842,9 +854,25 @@ void GraphExecutor::InitArguments(const nnvm::IndexedGraph& idx,
       } else {  // !shared_arg_names.count(arg_name)
         // model parameter, row_sparse ndarray sharing enabled
         bool enable_row_sparse_sharing = true;
-        in_arg_vec->emplace_back(ReshapeOrCreate(arg_name, inferred_shape, inferred_dtype,
-                                                 inferred_stype, in_arg_ctxes[arg_top],
-                                                 shared_buffer, enable_row_sparse_sharing));
+        if (dmlc::GetEnv("MXNET_USE_TENSORRT", false)) {
+            #if MXNET_USE_TENSORRT
+                auto it = shared_buffer->find(arg_name);
+                if (it != shared_buffer->end()) {
+                  in_arg_vec->push_back(std::move(it->second.Copy(in_arg_ctxes[arg_top])));
+                } else {
+                  in_arg_vec->push_back(std::move(InitZeros(inferred_stype, inferred_shape,
+                                                  in_arg_ctxes[arg_top], inferred_dtype)));
+                }
+            #else
+                LOG(FATAL) << "Env. var. MXNET_USE_TENSORRT = 1 set, but MXNet wasn't "
+                  << "built with TensorRT. Add USE_TENSORRT = 1 to config.mk";
+            #endif
+        } else {
+            in_arg_vec->emplace_back(ReshapeOrCreate(
+                arg_name, inferred_shape, inferred_dtype, inferred_stype,
+                in_arg_ctxes[arg_top], shared_buffer, enable_row_sparse_sharing));
+        }
+
         // gradient for model parameter, row_sparse ndarray sharing disabled
         if (kNullOp == grad_req_types[arg_top]) {
           arg_grad_vec->emplace_back();
@@ -940,6 +968,91 @@ void GraphExecutor::FinishInitGraph(nnvm::Symbol symbol,
   this->InitOpSegs();
 }
 
+
+Graph GraphExecutor::ReinitGraph(Graph&& g, const Context &default_ctx,
+                                 const std::map<std::string, Context> &ctx_map,
+                                 std::vector<Context> *in_arg_ctxes,
+                                 std::vector<Context> *arg_grad_ctxes,
+                                 std::vector<Context> *aux_state_ctxes,
+                                 std::vector<OpReqType> *grad_req_types,
+                                 std::unordered_map<std::string, TShape> *arg_shape_map,
+                                 std::unordered_map<std::string, int> *arg_dtype_map,
+                                 std::unordered_map<std::string, int> *arg_stype_map,
+                                 std::unordered_map<std::string, NDArray> *params_map) {
+  std::unordered_set<std::string> to_remove_params;
+  for (auto& el : *params_map) {
+    to_remove_params.insert(el.first);
+  }
+
+  DFSVisit(g.outputs, [&to_remove_params](const nnvm::NodePtr n) {
+    to_remove_params.erase(n->attrs.name);
+  });
+
+  for (auto& el : to_remove_params) {
+    params_map->erase(el);
+    arg_shape_map->erase(el);
+    arg_dtype_map->erase(el);
+    arg_stype_map->erase(el);
+  }
+  const auto &idx = g.indexed_graph();
+  num_forward_inputs_ = idx.input_nodes().size();
+  in_arg_ctxes->resize(num_forward_inputs_ - idx.mutable_input_nodes().size());
+  arg_grad_ctxes->resize(num_forward_inputs_ - idx.mutable_input_nodes().size());
+  grad_req_types->resize(num_forward_inputs_ - idx.mutable_input_nodes().size());
+  aux_state_ctxes->resize(idx.mutable_input_nodes().size());
+
+  // create "device" and "context" attrs for the graph
+  g = AssignContext(g, default_ctx, ctx_map, *in_arg_ctxes, *arg_grad_ctxes,
+                    *aux_state_ctxes, *grad_req_types, num_forward_inputs_,
+                    num_forward_outputs_);
+
+  // get number of nodes used in forward pass
+  num_forward_nodes_ = 0;
+  for (size_t i = 0; i < num_forward_outputs_; ++i) {
+    num_forward_nodes_ = std::max(
+        num_forward_nodes_, static_cast<size_t>(idx.outputs()[i].node_id + 1));
+  }
+  nnvm::ShapeVector arg_shapes(idx.input_nodes().size(), TShape());
+  nnvm::DTypeVector arg_dtypes(idx.input_nodes().size(), -1);
+  StorageTypeVector arg_stypes(idx.input_nodes().size(), kUndefinedStorage);
+  for (size_t i = 0; i < num_forward_inputs_; ++i) {
+    const uint32_t nid = idx.input_nodes().at(i);
+    const std::string &name = idx[nid].source->attrs.name;
+    auto it1 = arg_shape_map->find(name);
+    if (arg_shape_map->end() != it1) {
+      arg_shapes[i] = it1->second;
+    }
+    auto it2 = arg_dtype_map->find(name);
+    if (arg_dtype_map->end() != it2) {
+      arg_dtypes[i] = it2->second;
+    }
+    auto it3 = arg_stype_map->find(name);
+    if (arg_stype_map->end() != it3) {
+      arg_stypes[i] = it3->second;
+    }
+  }
+  g = InferShape(std::move(g), std::move(arg_shapes), "__shape__");
+  if (g.GetAttr<size_t>("shape_num_unknown_nodes") != 0U) {
+    HandleInferShapeError(num_forward_inputs_, g.indexed_graph(),
+                          g.GetAttr<nnvm::ShapeVector>("shape"));
+  }
+
+  g = InferType(std::move(g), std::move(arg_dtypes), "__dtype__");
+  if (g.GetAttr<size_t>("dtype_num_unknown_nodes") != 0U) {
+    HandleInferTypeError(num_forward_inputs_, g.indexed_graph(),
+                         g.GetAttr<nnvm::DTypeVector>("dtype"));
+  }
+
+  g = InferStorageType(std::move(g), std::move(arg_stypes), "__storage_type__");
+
+  if (g.GetAttr<size_t>("storage_type_num_unknown_nodes") != 0U) {
+    HandleInferStorageTypeError(num_forward_inputs_, g.indexed_graph(),
+                                g.GetAttr<StorageTypeVector>("storage_type"));
+  }
+
+  return g;
+}
+
 /*!
  * \brief GraphExecutor initializer for simple bind flow in
  * which only certain input shapes and dtypes are provided by users.
@@ -957,22 +1070,22 @@ void GraphExecutor::FinishInitGraph(nnvm::Symbol symbol,
 void GraphExecutor::Init(nnvm::Symbol symbol,
                          const Context& default_ctx,
                          const std::map<std::string, Context>& ctx_map,
-                         const std::vector<Context>& in_arg_ctxes,
-                         const std::vector<Context>& arg_grad_ctxes,
-                         const std::vector<Context>& aux_state_ctxes,
-                         const std::unordered_map<std::string, TShape>& arg_shape_map,
-                         const std::unordered_map<std::string, int>& arg_dtype_map,
-                         const std::unordered_map<std::string, int>& arg_stype_map,
-                         const std::vector<OpReqType>& grad_req_types,
-                         const std::unordered_set<std::string>& shared_arg_names,
+                         std::vector<Context>* in_arg_ctxes,
+                         std::vector<Context>* arg_grad_ctxes,
+                         std::vector<Context>* aux_state_ctxes,
+                         std::unordered_map<std::string, TShape>* arg_shape_map,
+                         std::unordered_map<std::string, int>* arg_dtype_map,
+                         std::unordered_map<std::string, int>* arg_stype_map,
+                         std::vector<OpReqType>* grad_req_types,
+                         std::unordered_set<std::string>* shared_arg_names,
                          std::vector<NDArray>* in_arg_vec,
                          std::vector<NDArray>* arg_grad_vec,
                          std::vector<NDArray>* aux_state_vec,
                          std::unordered_map<std::string, NDArray>* shared_buffer,
                          Executor* shared_exec,
                          const nnvm::NodeEntryMap<NDArray>& feed_dict) {
-  nnvm::Graph g = InitGraph(symbol, default_ctx, ctx_map, in_arg_ctxes, arg_grad_ctxes,
-                            aux_state_ctxes, grad_req_types);
+  nnvm::Graph g = InitGraph(symbol, default_ctx, ctx_map, *in_arg_ctxes, *arg_grad_ctxes,
+                            *aux_state_ctxes, *grad_req_types);
   // The following code of shape and dtype inferences and argument
   // initialization is for simple_bind only. Regular bind operation
   // should do this differently.
@@ -986,16 +1099,16 @@ void GraphExecutor::Init(nnvm::Symbol symbol,
   for (size_t i = 0; i < num_forward_inputs_; ++i) {
     const uint32_t nid = idx.input_nodes().at(i);
     const std::string& name = idx[nid].source->attrs.name;
-    auto it1 = arg_shape_map.find(name);
-    if (arg_shape_map.end() != it1) {
+    auto it1 = arg_shape_map->find(name);
+    if (arg_shape_map->end() != it1) {
       arg_shapes[i] = it1->second;
     }
-    auto it2 = arg_dtype_map.find(name);
-    if (arg_dtype_map.end() != it2) {
+    auto it2 = arg_dtype_map->find(name);
+    if (arg_dtype_map->end() != it2) {
       arg_dtypes[i] = it2->second;
     }
-    auto it3 = arg_stype_map.find(name);
-    if (arg_stype_map.end() != it3) {
+    auto it3 = arg_stype_map->find(name);
+    if (arg_stype_map->end() != it3) {
       arg_stypes[i] = it3->second;
     }
   }
@@ -1017,20 +1130,37 @@ void GraphExecutor::Init(nnvm::Symbol symbol,
                                 g.GetAttr<StorageTypeVector>("storage_type"));
   }
 
+  if (dmlc::GetEnv("MXNET_USE_TENSORRT", false)) {
+      #if MXNET_USE_TENSORRT
+          auto trt_groups = GetTrtCompatibleSubsets(g, shared_buffer);
+          for (auto trt_group : trt_groups) {
+            if (trt_group.size() > 1) {
+              g = ReplaceSubgraph(std::move(g), trt_group, shared_buffer);
+              g = ReinitGraph(std::move(g), default_ctx, ctx_map, in_arg_ctxes, arg_grad_ctxes,
+                              aux_state_ctxes, grad_req_types, arg_shape_map, arg_dtype_map,
+                              arg_stype_map, shared_buffer);
+            }
+          }
+      #else
+          LOG(FATAL) << "Env. var. MXNET_USE_TENSORRT = 1 set, but MXNet wasn't "
+            << "built with TensorRT. Add USE_TENSORRT = 1 to config.mk";
+      #endif
+  }
+
   // Create in_args, arg_grads, and aux_states using
   // the inferred shapes and dtypes.
   if (nullptr == shared_buffer) {  // regular simple bind
-    InitArguments(idx, g.GetAttr<nnvm::ShapeVector>("shape"),
+    InitArguments(g.indexed_graph(), g.GetAttr<nnvm::ShapeVector>("shape"),
                   g.GetAttr<nnvm::DTypeVector>("dtype"),
                   g.GetAttr<StorageTypeVector>("storage_type"),
-                  in_arg_ctxes, arg_grad_ctxes, aux_state_ctxes,
-                  grad_req_types, in_arg_vec, arg_grad_vec, aux_state_vec);
+                  *in_arg_ctxes, *arg_grad_ctxes, *aux_state_ctxes,
+                  *grad_req_types, in_arg_vec, arg_grad_vec, aux_state_vec);
   } else {  // simple bind using shared data arrays and shared_exec
-    InitArguments(idx, g.GetAttr<nnvm::ShapeVector>("shape"),
+    InitArguments(g.indexed_graph(), g.GetAttr<nnvm::ShapeVector>("shape"),
                   g.GetAttr<nnvm::DTypeVector>("dtype"),
                   g.GetAttr<StorageTypeVector>("storage_type"),
-                  in_arg_ctxes, arg_grad_ctxes, aux_state_ctxes,
-                  grad_req_types, shared_arg_names, shared_exec,
+                  *in_arg_ctxes, *arg_grad_ctxes, *aux_state_ctxes,
+                  *grad_req_types, *shared_arg_names, shared_exec,
                   shared_buffer, in_arg_vec, arg_grad_vec, aux_state_vec);
   }
   // The above code of shape and dtype inferences and argument
@@ -1698,14 +1828,14 @@ GraphExecutor::CachedSegOpr GraphExecutor::CreateCachedSegOpr(size_t topo_start,
 Executor *Executor::SimpleBind(nnvm::Symbol symbol,
                                const Context& default_ctx,
                                const std::map<std::string, Context>& group2ctx,
-                               const std::vector<Context>& in_arg_ctxes,
-                               const std::vector<Context>& arg_grad_ctxes,
-                               const std::vector<Context>& aux_state_ctxes,
-                               const std::unordered_map<std::string, TShape>& arg_shape_map,
-                               const std::unordered_map<std::string, int>& arg_dtype_map,
-                               const std::unordered_map<std::string, int>& arg_stype_map,
-                               const std::vector<OpReqType>& grad_req_types,
-                               const std::unordered_set<std::string>& shared_arg_names,
+                               std::vector<Context>* in_arg_ctxes,
+                               std::vector<Context>* arg_grad_ctxes,
+                               std::vector<Context>* aux_state_ctxes,
+                               std::unordered_map<std::string, TShape>* arg_shape_map,
+                               std::unordered_map<std::string, int>* arg_dtype_map,
+                               std::unordered_map<std::string, int>* arg_stype_map,
+                               std::vector<OpReqType>* grad_req_types,
+                               std::unordered_set<std::string>* shared_arg_names,
                                std::vector<NDArray>* in_args,
                                std::vector<NDArray>* arg_grads,
                                std::vector<NDArray>* aux_states,
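A note on the pattern used throughout the graph_executor.cc changes above: every TensorRT code path is gated twice, once at compile time by MXNET_USE_TENSORRT and once at run time by an environment variable of the same name, with a fatal error when the variable is set but the build lacks TensorRT. A condensed sketch of that guard (illustrative only, not a new function in this patch):

#include <dmlc/logging.h>
#include <dmlc/parameter.h>

bool TensorRTRequested() {
  if (!dmlc::GetEnv("MXNET_USE_TENSORRT", false)) return false;
#if MXNET_USE_TENSORRT
  return true;
#else
  LOG(FATAL) << "Env. var. MXNET_USE_TENSORRT = 1 set, but MXNet wasn't built with "
             << "TensorRT. Add USE_TENSORRT = 1 to config.mk";
  return false;  // unreachable; LOG(FATAL) aborts
#endif
}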
diff --git a/src/executor/graph_executor.h b/src/executor/graph_executor.h
index 24f98894912..052df707498 100644
--- a/src/executor/graph_executor.h
+++ b/src/executor/graph_executor.h
@@ -91,14 +91,14 @@ class GraphExecutor : public Executor {
   void Init(nnvm::Symbol symbol,
             const Context& default_ctx,
             const std::map<std::string, Context>& ctx_map,
-            const std::vector<Context>& in_arg_ctxes,
-            const std::vector<Context>& arg_grad_ctxes,
-            const std::vector<Context>& aux_state_ctxes,
-            const std::unordered_map<std::string, TShape>& arg_shape_map,
-            const std::unordered_map<std::string, int>& arg_dtype_map,
-            const std::unordered_map<std::string, int>& arg_stype_map,
-            const std::vector<OpReqType>& grad_req_types,
-            const std::unordered_set<std::string>& shared_arg_names,
+            std::vector<Context>* in_arg_ctxes,
+            std::vector<Context>* arg_grad_ctxes,
+            std::vector<Context>* aux_state_ctxes,
+            std::unordered_map<std::string, TShape>* arg_shape_map,
+            std::unordered_map<std::string, int>* arg_dtype_map,
+            std::unordered_map<std::string, int>* arg_stype_map,
+            std::vector<OpReqType>* grad_req_types,
+            std::unordered_set<std::string>* shared_arg_names,
             std::vector<NDArray>* in_arg_vec,
             std::vector<NDArray>* arg_grad_vec,
             std::vector<NDArray>* aux_state_vec,
@@ -213,6 +213,17 @@ class GraphExecutor : public Executor {
   // perform bulking and segmentation on a training graph
   void BulkTrainingOpSegs(size_t total_num_nodes);
 
+  Graph ReinitGraph(Graph&& g, const Context &default_ctx,
+                    const std::map<std::string, Context> &ctx_map,
+                    std::vector<Context>* in_arg_ctxes,
+                    std::vector<Context>* arg_grad_ctxes,
+                    std::vector<Context>* aux_state_ctxes,
+                    std::vector<OpReqType>* grad_req_types,
+                    std::unordered_map<std::string, TShape>* arg_shape_map,
+                    std::unordered_map<std::string, int>* arg_dtype_map,
+                    std::unordered_map<std::string, int>* arg_stype_map,
+                    std::unordered_map<std::string, NDArray> *params_map);
+
   // internal graph
   nnvm::Graph graph_;
   // operator node
diff --git a/src/executor/onnx_to_tensorrt.cc b/src/executor/onnx_to_tensorrt.cc
new file mode 100644
index 00000000000..0705575e926
--- /dev/null
+++ b/src/executor/onnx_to_tensorrt.cc
@@ -0,0 +1,166 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2018 by Contributors
+ * \file onnx_to_tensorrt.cc
+ * \brief TensorRT integration with the MXNet executor
+ * \author Marek Kolodziej, Clement Fuji Tsang
+ */
+
+#if MXNET_USE_TENSORRT
+
+#include "./onnx_to_tensorrt.h"
+
+#include <google/protobuf/io/coded_stream.h>
+#include <google/protobuf/io/zero_copy_stream_impl.h>
+#include <google/protobuf/text_format.h>
+#include <NvInfer.h>
+#include <onnx/onnx.pb.h>
+#include <onnx-tensorrt/NvOnnxParser.h>
+#include <onnx-tensorrt/NvOnnxParserRuntime.h>
+#include <onnx-tensorrt/PluginFactory.hpp>
+#include <onnx-tensorrt/plugin_common.hpp>
+
+using std::cout;
+using std::cerr;
+using std::endl;
+
+namespace onnx_to_tensorrt {
+
+struct InferDeleter {
+  template<typename T>
+    void operator()(T* obj) const {
+      if ( obj ) {
+        obj->destroy();
+      }
+    }
+};
+
+template<typename T>
+inline std::shared_ptr<T> infer_object(T* obj) {
+  if ( !obj ) {
+    throw std::runtime_error("Failed to create object");
+  }
+  return std::shared_ptr<T>(obj, InferDeleter());
+}
+
+std::string onnx_ir_version_string(int64_t ir_version = onnx::IR_VERSION) {
+  int onnx_ir_major = ir_version / 1000000;
+  int onnx_ir_minor = ir_version % 1000000 / 10000;
+  int onnx_ir_patch = ir_version % 10000;
+  return (std::to_string(onnx_ir_major) + "." +
+    std::to_string(onnx_ir_minor) + "." +
+    std::to_string(onnx_ir_patch));
+}
+
+void print_version() {
+  cout << "Parser built against:" << endl;
+  cout << "  ONNX IR version:  " << onnx_ir_version_string(onnx::IR_VERSION) << endl;
+  cout << "  TensorRT version: "
+    << NV_TENSORRT_MAJOR << "."
+    << NV_TENSORRT_MINOR << "."
+    << NV_TENSORRT_PATCH << endl;
+}
+
+nvinfer1::ICudaEngine* onnxToTrtCtx(
+        std::string onnx_model,
+        int32_t max_batch_size,
+        size_t max_workspace_size,
+        int model_dtype_nbits,
+        nvinfer1::ILogger::Severity verbosity,
+        bool print_layer_info,
+        bool debug_builder) {
+  GOOGLE_PROTOBUF_VERIFY_VERSION;
+
+  nvinfer1::DataType model_dtype;
+  switch ( model_dtype_nbits ) {
+    case 32:
+      model_dtype = nvinfer1::DataType::kFLOAT;
+      break;
+    case 16:
+      model_dtype = nvinfer1::DataType::kHALF;
+      break;
+    default:
+      cerr << "ERROR: Invalid model data type bit depth: " << model_dtype_nbits << endl;
+      break;
+  }
+
+  TRT_Logger trt_logger((nvinfer1::ILogger::Severity)verbosity);
+  auto trt_builder = infer_object(nvinfer1::createInferBuilder(trt_logger));
+  auto trt_network = infer_object(trt_builder->createNetwork());
+  auto trt_parser  = infer_object(nvonnxparser::createParser(
+                                      *trt_network, trt_logger));
+
+  ::ONNX_NAMESPACE::ModelProto parsed_model;
+  bool can_parse_onnx = parsed_model.ParseFromString(onnx_model);
+
+  if ( !trt_parser->parse(onnx_model.c_str(), onnx_model.size()) ) {
+      int nerror = trt_parser->getNbErrors();
+      for ( int i=0; i < nerror; ++i ) {
+        nvonnxparser::IParserError const* error = trt_parser->getError(i);
+        if ( error->node() != -1 ) {
+          ::ONNX_NAMESPACE::NodeProto const& node =
+            parsed_model.graph().node(error->node());
+          cerr << "While parsing node number " << error->node()
+               << " [" << node.op_type();
+          if ( node.output().size() ) {
+            cerr << " -> \"" << node.output(0) << "\"";
+          }
+          cerr << "]:" << endl;
+          if ( static_cast<int>(verbosity) >=
+            static_cast<int>(nvinfer1::ILogger::Severity::kINFO) ) {
+            cerr << "--- Begin node ---" << endl;
+            cerr << node.DebugString() << endl;
+            cerr << "--- End node ---" << endl;
+          }
+        }
+        cerr << "ERROR: "
+             << error->file() << ":" << error->line()
+             << " In function " << error->func() << ":\n"
+             << "[" << static_cast<int>(error->code()) << "] " << error->desc()
+             << endl;
+      }
+      throw dmlc::Error("Cannot parse ONNX into TensorRT Engine");
+  }
+
+  bool fp16 = trt_builder->platformHasFastFp16();
+
+  if ( static_cast<int>(verbosity) >= static_cast<int>(nvinfer1::ILogger::Severity::kWARNING) ) {
+    cout << "Building TensorRT engine, FP16 available:"<< fp16 << endl;
+    cout << "    Max batch size:     " << max_batch_size << endl;
+    cout << "    Max workspace size: " << max_workspace_size / (1024. * 1024) << " MiB" << endl;
+  }
+  trt_builder->setMaxBatchSize(max_batch_size);
+  trt_builder->setMaxWorkspaceSize(max_workspace_size);
+  if ( fp16 && model_dtype == nvinfer1::DataType::kHALF ) {
+    trt_builder->setHalf2Mode(true);
+  } else if ( model_dtype == nvinfer1::DataType::kINT8 ) {
+    // Add Int8 support
+    // trt_builder->setInt8Mode(true);
+    LOG(FATAL) << "ERROR: Int8 mode not yet supported";
+  }
+  trt_builder->setDebugSync(debug_builder);
+  nvinfer1::ICudaEngine* trt_engine = trt_builder->buildCudaEngine(*trt_network.get());
+  return trt_engine;
+}
+
+}  // namespace onnx_to_tensorrt
+
+#endif  // MXNET_USE_TENSORRT
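The InferDeleter/infer_object helpers above wrap TensorRT's destroy()-based ownership model in std::shared_ptr so that builder, network and parser objects are released on every exit path. A self-contained sketch of the same idea (the Destroyable type below is a hypothetical stand-in for a TensorRT interface, not part of this patch):

#include <memory>

struct Destroyable {               // stand-in; real TensorRT objects expose destroy()
  void destroy() { delete this; }
};

template <typename T>
std::shared_ptr<T> make_owned(T* obj) {
  // The custom deleter calls destroy() instead of delete, matching the TensorRT API.
  return std::shared_ptr<T>(obj, [](T* p) { if (p) p->destroy(); });
}

void Example() {
  auto holder = make_owned(new Destroyable());
  // destroy() runs automatically when the last shared_ptr reference goes away.
}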
diff --git a/src/executor/onnx_to_tensorrt.h b/src/executor/onnx_to_tensorrt.h
new file mode 100644
index 00000000000..6f83f476bd5
--- /dev/null
+++ b/src/executor/onnx_to_tensorrt.h
@@ -0,0 +1,80 @@
+#ifndef MXNET_EXECUTOR_ONNX_TO_TENSORRT_H_
+#define MXNET_EXECUTOR_ONNX_TO_TENSORRT_H_
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2018 by Contributors
+ * \file onnx_to_tensorrt.h
+ * \brief TensorRT integration with the MXNet executor
+ * \author Marek Kolodziej, Clement Fuji Tsang
+ */
+
+#if MXNET_USE_TENSORRT
+
+#include <fstream>
+#include <iostream>
+#include <NvInfer.h>
+#include <sstream>
+#include <string>
+
+#include "../operator/contrib/tensorrt-inl.h"
+
+namespace onnx_to_tensorrt {
+
+class TRT_Logger : public nvinfer1::ILogger {
+        nvinfer1::ILogger::Severity _verbosity;
+        std::ostream* _ostream;
+ public:
+        TRT_Logger(Severity verbosity = Severity::kWARNING,
+                   std::ostream& ostream = std::cout)
+                : _verbosity(verbosity), _ostream(&ostream) {}
+        void log(Severity severity, const char* msg) override {
+                if ( severity <= _verbosity ) {
+                        time_t rawtime = std::time(0);
+                        char buf[256];
+                        strftime(&buf[0], 256,
+                                 "%Y-%m-%d %H:%M:%S",
+                                 std::gmtime(&rawtime));
+                        const char* sevstr = (severity == Severity::kINTERNAL_ERROR ? "    BUG" :
+                                              severity == Severity::kERROR          ? "  ERROR" :
+                                              severity == Severity::kWARNING        ? "WARNING" :
+                                              severity == Severity::kINFO           ? "   INFO" :
+                                              "UNKNOWN");
+                        (*_ostream) << "[" << buf << " " << sevstr << "] "
+                                    << msg
+                                    << std::endl;
+                }
+        }
+};
+
+nvinfer1::ICudaEngine* onnxToTrtCtx(
+        std::string onnx_model,
+        int32_t max_batch_size = 32,
+        size_t max_workspace_size = 1L << 30,
+        int model_dtype_nbits = 32,
+        nvinfer1::ILogger::Severity verbosity = nvinfer1::ILogger::Severity::kWARNING,
+        bool print_layer_info = false,
+        bool debug_builder = false);
+
+}  // namespace onnx_to_tensorrt
+
+#endif  // MXNET_USE_TENSORRT
+
+#endif  // MXNET_EXECUTOR_ONNX_TO_TENSORRT_H_
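For reference, a sketch of how the entry point declared above would be called once a subgraph has been serialized to ONNX (illustrative only; the serialized string comes from the NNVM-to-ONNX converter added later in this patch, and error handling is omitted):

#include <string>
#include "onnx_to_tensorrt.h"   // include path is an assumption for this sketch

nvinfer1::ICudaEngine* BuildEngine(const std::string& serialized_onnx_graph) {
  // Defaults from the declaration: batch size 32, 1 GiB workspace, FP32 weights,
  // warning-level logging, no per-layer info, no debug builder.
  nvinfer1::ICudaEngine* engine = onnx_to_tensorrt::onnxToTrtCtx(serialized_onnx_graph);
  // Ownership stays with the caller, which must call engine->destroy() when done.
  return engine;
}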
diff --git a/src/executor/tensorrt_pass.cc b/src/executor/tensorrt_pass.cc
new file mode 100644
index 00000000000..5902be725d1
--- /dev/null
+++ b/src/executor/tensorrt_pass.cc
@@ -0,0 +1,583 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2018 by Contributors
+ * \file tensorrt_pass.cc
+ * \brief Replace TRT compatible subgraphs by TRT engines
+ * \author Clement Fuji Tsang
+ */
+
+#if MXNET_USE_TENSORRT
+
+#include <mxnet/base.h>
+#include <mxnet/operator.h>
+#include <mxnet/op_attr_types.h>
+#include <nnvm/graph_attr_types.h>
+#include <onnx/onnx.pb.h>
+#include <NvInfer.h>
+
+#include "./onnx_to_tensorrt.h"
+#include "./exec_pass.h"
+#include "../operator/contrib/nnvm_to_onnx-inl.h"
+
+namespace mxnet {
+namespace exec {
+
+using NodePtr = nnvm::NodePtr;
+
+/*!
+ * \brief Custom bidirectional graph class: each node keeps links to both its inputs
+ * and outputs, so that forward and reverse DFS can be run for graph partitioning
+ */
+class BidirectionalGraph {
+ public:
+  struct Node {
+    nnvm::Node* nnvmptr;
+    std::vector<Node*> inputs;
+    std::vector<Node*> outputs;
+  };
+  std::vector<Node> nodes;
+  std::unordered_map<nnvm::Node*, uint32_t> nnvm2nid;
+  std::vector<Node*> outputs;
+  static const std::unordered_set<std::string> unconditionalTRTop;
+
+  explicit BidirectionalGraph(const Graph &g) {
+    auto& idx = g.indexed_graph();
+    auto num_nodes = idx.num_nodes();
+    nodes.reserve(num_nodes);
+    nnvm2nid.reserve(num_nodes);
+    outputs.reserve(idx.outputs().size());
+    DFSVisit(g.outputs, [this](const nnvm::NodePtr& n) {
+      BidirectionalGraph::Node new_node;
+      new_node.nnvmptr = n.get();
+      nnvm2nid[n.get()] = static_cast<uint32_t>(nodes.size());
+      nodes.emplace_back(std::move(new_node));
+    });
+    for (const auto& it : nnvm2nid) {
+      nnvm::Node* nnvmnode = it.first;
+      uint32_t nid = it.second;
+      for (auto& n : nnvmnode->inputs) {
+        uint32_t input_nid = nnvm2nid[n.node.get()];
+        nodes[input_nid].outputs.emplace_back(&nodes[nid]);
+        nodes[nid].inputs.emplace_back(&nodes[input_nid]);
+      }
+    }
+    for (auto& e : g.outputs) {
+      uint32_t nid = nnvm2nid[e.node.get()];
+      outputs.emplace_back(&nodes[nid]);
+    }
+  }
+
+  template <typename FVisit>
+  void DFS(const std::vector<Node*>& heads, bool reverse, FVisit fvisit) {
+    std::unordered_set<Node*> visited;
+    std::deque<Node*> stack(heads.begin(), heads.end());
+    visited.reserve(heads.size());
+    while (!stack.empty()) {
+      Node* vertex = stack.back();
+      stack.pop_back();
+      if (visited.count(vertex) == 0) {
+        visited.insert(vertex);
+        fvisit(vertex);
+        std::vector<Node*> nexts = reverse ? vertex->inputs : vertex->outputs;
+        for (Node* node : nexts) {
+          if (visited.count(node) == 0) {
+            stack.emplace_back(node);
+          }
+        }
+      }
+    }
+  }
+
+  using t_pairset = std::pair<std::unordered_set<Node*>, std::unordered_set<Node*>>;
+  using t_pairvec = std::pair<std::vector<Node*>, std::vector<Node*>>;
+  using t_uncomp_map = std::unordered_map<Node*, std::unordered_set<Node*>>;
+
+  std::unordered_set<Node*> naive_grow_subgraph(Node* head,
+                                                std::unordered_set<Node*>* set_unused,
+                                                t_uncomp_map* uncomp_map) {
+    std::unordered_set<Node*> subgraph;
+    std::unordered_set<Node*> uncomp_set;
+    std::deque<Node*> stack;
+    stack.emplace_back(head);
+    while (!stack.empty()) {
+      Node* vertex = stack.back();
+      stack.pop_back();
+      if (set_unused->count(vertex) && !uncomp_set.count(vertex)) {
+        set_unused->erase(vertex);
+        subgraph.insert(vertex);
+        uncomp_set.insert((*uncomp_map)[vertex].begin(), (*uncomp_map)[vertex].end());
+        for (Node* input : vertex->inputs) {
+          if (set_unused->count(input) && !uncomp_set.count(input)) {
+            stack.emplace_back(input);
+          }
+        }
+        for (Node* output : vertex->outputs) {
+          if (set_unused->count(output) && !uncomp_set.count(output)) {
+            stack.emplace_back(output);
+          }
+        }
+      }
+    }
+    return subgraph;
+  }
+
+  std::vector<std::unordered_set<Node*>> get_subsets(
+    std::unordered_map<std::string, NDArray>* const params_map) {
+    std::vector<std::unordered_set<Node*>> subgraphs;
+    std::unordered_set<Node*> set_nonTRTnodes;
+    std::unordered_set<Node*> set_allnodes(nodes.size());
+    std::vector<t_pairset> separation_sets;
+    for (Node& node : nodes) {
+      if (!IsTRTCompatible(node.nnvmptr)) {
+        set_nonTRTnodes.insert(&node);
+        std::unordered_set<Node*> in_graph;
+        std::unordered_set<Node*> out_graph;
+        std::vector<Node*> dummy_head;
+        dummy_head.emplace_back(&node);
+        DFS(dummy_head, false, [&out_graph](Node* node) {
+          out_graph.insert(node);
+        });
+        DFS(dummy_head, true, [&in_graph](Node* node) {
+          in_graph.insert(node);
+        });
+        separation_sets.emplace_back(std::make_pair(in_graph, out_graph));
+      }
+      set_allnodes.emplace(&node);
+    }
+    t_uncomp_map uncomp_map;
+    std::unordered_set<Node*> set_TRTnodes;
+    set_TRTnodes.insert(set_allnodes.begin(), set_allnodes.end());
+    for (Node* n : set_nonTRTnodes) {
+      set_TRTnodes.erase(n);
+    }
+    for (Node* n : set_TRTnodes) {
+      for (t_pairset p : separation_sets) {
+        if (p.first.count(n)) {
+          uncomp_map[n].insert(p.second.begin(), p.second.end());
+        } else if (p.second.count(n)) {
+          uncomp_map[n].insert(p.first.begin(), p.first.end());
+        }
+      }
+      for (Node* nonTRTn : set_nonTRTnodes) {
+        uncomp_map[n].erase(nonTRTn);
+      }
+    }
+    std::unordered_set<Node*> set_unused;
+    set_unused.reserve(set_TRTnodes.size());
+
+    for (auto& n : set_TRTnodes) {
+      if (n->nnvmptr->attrs.op != nullptr || params_map->count(n->nnvmptr->attrs.name)) {
+        set_unused.insert(n);
+      }
+    }
+    std::unordered_set<Node*> visited;
+    std::deque<Node*> stack(outputs.begin(), outputs.end());
+    while (!stack.empty()) {
+      Node* vertex = stack.front();
+      stack.pop_front();
+      if (!visited.count(vertex)) {
+        visited.insert(vertex);
+        if (set_unused.count(vertex)) {
+          subgraphs.emplace_back(naive_grow_subgraph(vertex, &set_unused, &uncomp_map));
+        }
+        for (Node* input : vertex->inputs) {
+          stack.emplace_back(input);
+        }
+      }
+    }
+
+    return subgraphs;
+  }
+
+
+ private:
+  friend class Graph;
+
+  bool IsTRTCompatible(nnvm::Node* nodeptr) {
+    if (nodeptr->op() != nullptr) {
+      const std::string op_name = nodeptr->op()->name;
+      if (op_name == "Pooling") {
+        return (nodeptr->attrs.dict.at("pool_type") == "avg" ||
+          nodeptr->attrs.dict.at("pool_type") == "max");
+      } else if (unconditionalTRTop.count(op_name)) {
+        return true;
+      } else if (op_name == "Activation") {
+        return nodeptr->attrs.dict.at("act_type") == "relu" ||
+          nodeptr->attrs.dict.at("act_type") == "tanh" ||
+          nodeptr->attrs.dict.at("act_type") == "sigmoid";
+      }
+      return false;
+    }
+    return true;
+  }
+};  // class BidirectionalGraph
+
+/*!
+ * \brief Set of operators that can always be offloaded to TensorRT, whatever their attributes
+ */
+const std::unordered_set<std::string> BidirectionalGraph::unconditionalTRTop = {
+  "Convolution",
+  "BatchNorm",
+  "elemwise_add",
+  "elemwise_sub",
+  "elemwise_mul",
+  "rsqrt",
+  "pad",
+  "Pad",
+  "mean",
+  "FullyConnected",
+  "Flatten",
+  "SoftmaxOutput",
+};
+
+
+using NodeEntrySet = std::unordered_set<nnvm::NodeEntry, nnvm::NodeEntryHash,
+                                        nnvm::NodeEntryEqual>;
+
+/*!
+ * \brief get the output nodes of the subgraph in the main graph
+ * \return a vector of the output nodes
+*/
+std::vector<nnvm::NodeEntry> GetSubgraphOutputs(Graph g,
+    std::unordered_set<nnvm::Node*> set_subgraph) {
+  std::vector<nnvm::NodeEntry> outputs;
+  NodeEntrySet _outputs;
+  for (auto& e : g.outputs) {
+    if (set_subgraph.count(e.node.get())) {
+      _outputs.insert(e);
+    }
+  }
+  DFSVisit(g.outputs, [&set_subgraph, &_outputs](const nnvm::NodePtr &node){
+    if (!set_subgraph.count(node.get())) {
+      for (auto& e : node->inputs) {
+        if (set_subgraph.count(e.node.get())) {
+          _outputs.insert(e);
+        }
+      }
+    }
+  });
+  outputs.insert(outputs.begin(), _outputs.begin(), _outputs.end());
+  return outputs;
+}
+
+
+/*!
+ * \brief get the nodes outside the subgraph whose outputs are used inside the subgraph
+ * \return a vector of these nodes
+*/
+std::vector<nnvm::NodeEntry> GetSubgraphInterfaceNodes(Graph g,
+    std::unordered_set<nnvm::Node*> set_subgraph) {
+  std::vector<nnvm::NodeEntry> inputs;
+  NodeEntrySet _inputs;
+  DFSVisit(g.outputs, [&set_subgraph, &_inputs](const nnvm::NodePtr &node){
+    if (set_subgraph.count(node.get())) {
+      for (auto& e : node->inputs) {
+        if (!set_subgraph.count(e.node.get())) {
+          _inputs.insert(e);
+        }
+      }
+    }
+  });
+  inputs.insert(inputs.begin(), _inputs.begin(), _inputs.end());
+  return inputs;
+}
+
+std::unordered_map<uint32_t, uint32_t> GetGraphInputsMap(Graph g) {
+  std::unordered_map<uint32_t, uint32_t> outputs;
+  auto& idx = g.indexed_graph();
+  outputs.reserve(idx.num_nodes());
+  std::vector<uint32_t> input_nodes = idx.input_nodes();
+  for (size_t i = 0; i < input_nodes.size(); ++i) {
+    outputs[input_nodes[i]] = static_cast<uint32_t>(i);
+  }
+  return outputs;
+}
+
+/*!
+ * \brief Create the _trt_op node that replaces a subgraph, carrying its serialized ONNX graph and I/O maps
+ */
+nnvm::NodePtr ConvertNnvmGraphToOnnx(const nnvm::Graph &g,
+                                     std::unordered_map<std::string, NDArray>* const params_map) {
+  auto p = nnvm::Node::Create();
+  p->attrs.op = nnvm::Op::Get("_trt_op");
+  op::TRTParam trt_param = op::nnvm_to_onnx::ConvertNnvmGraphToOnnx(g, params_map);
+  p->attrs.dict["serialized_output_map"] = trt_param.serialized_output_map;
+  p->attrs.dict["serialized_input_map"]  = trt_param.serialized_input_map;
+  p->attrs.dict["serialized_onnx_graph"] = trt_param.serialized_onnx_graph;
+  if (p->op()->attr_parser != nullptr) {
+    p->op()->attr_parser(&(p->attrs));
+  }
+  return p;
+}
+
+/*!
+ * \brief Propagate shape, dtype and storage-type attributes from the main graph onto the subgraph
+ */
+Graph UpdateSubgraphAttrs(Graph&& subgraph, const Graph& g,
+                          const std::unordered_map<nnvm::Node*, nnvm::NodePtr>& old2new,
+                          const nnvm::NodeEntryMap<nnvm::NodeEntry>& main_input_entry_to_sub) {
+  const auto& idx     = g.indexed_graph();
+  const auto& sub_idx = subgraph.indexed_graph();
+
+  const auto shape               = g.GetAttr<nnvm::ShapeVector>("shape");
+  const auto dtype               = g.GetAttr<nnvm::DTypeVector>("dtype");
+  const auto storage_type        = g.GetAttr<StorageTypeVector>("storage_type");
+  const auto shape_inputs        = g.GetAttr<nnvm::ShapeVector>("shape_inputs");
+  const auto dtype_inputs        = g.GetAttr<nnvm::DTypeVector>("dtype_inputs");
+  const auto storage_type_inputs = g.GetAttr<StorageTypeVector>("storage_type_inputs");
+
+  nnvm::ShapeVector sub_shape(sub_idx.num_node_entries());
+  nnvm::DTypeVector sub_dtype(sub_idx.num_node_entries());
+  StorageTypeVector sub_storage_type(sub_idx.num_node_entries());
+  nnvm::ShapeVector sub_shape_inputs(sub_idx.input_nodes().size());
+  nnvm::DTypeVector sub_dtype_inputs(sub_idx.input_nodes().size());
+  StorageTypeVector sub_storage_type_inputs(sub_idx.input_nodes().size());
+
+  const std::unordered_map<uint32_t, uint32_t> inputsindex2pos     = GetGraphInputsMap(g);
+  const std::unordered_map<uint32_t, uint32_t> sub_inputsindex2pos = GetGraphInputsMap(subgraph);
+  // map attributes from graph to subgraph
+  for (auto& p : old2new) {
+    const uint32_t nid     = idx.node_id(p.first);
+    const uint32_t sub_nid = sub_idx.node_id(p.second.get());
+    const nnvm::Op* op = sub_idx[sub_nid].source->op();
+    if (op == nullptr) {  // if it's an input node, there is only one output node entry
+      const uint32_t sub_i       = sub_idx.entry_id(sub_nid, 0);
+      const uint32_t sub_input_i = sub_inputsindex2pos.at(sub_nid);
+      const uint32_t i           = idx.entry_id(nid, 0);
+
+      sub_shape[sub_i] = shape[i];
+      sub_dtype[sub_i] = dtype[i];
+      sub_storage_type[sub_i]       = storage_type[i];
+      sub_shape_inputs[sub_input_i] = shape_inputs[inputsindex2pos.at(nid)];
+      sub_dtype_inputs[sub_input_i] = dtype_inputs[inputsindex2pos.at(nid)];
+      sub_storage_type_inputs[sub_input_i] = storage_type_inputs[inputsindex2pos.at(nid)];
+
+    } else {
+      for (size_t oi = 0; oi < op->num_outputs; ++oi) {
+        const uint32_t sub_i = sub_idx.entry_id(sub_nid, oi);
+        const uint32_t i = idx.entry_id(nid, oi);
+        sub_shape[sub_i] = shape[i];
+        sub_dtype[sub_i] = dtype[i];
+        sub_storage_type[sub_i] = storage_type[i];
+      }
+    }
+  }
+  // old2new doesn't contain placeholder / interfaces
+  for (auto& p : main_input_entry_to_sub) {
+    nnvm::NodeEntry main_entry = p.first;
+    nnvm::NodeEntry sub_entry = p.second;
+    const uint32_t sub_nid = sub_idx.node_id(sub_entry.node.get());
+    const uint32_t sub_i = sub_idx.entry_id(sub_entry);
+    const uint32_t i = idx.entry_id(main_entry);
+    const uint32_t sub_input_i = sub_inputsindex2pos.at(sub_nid);
+    sub_shape[sub_i] = shape[i];
+    sub_dtype[sub_i] = dtype[i];
+    sub_storage_type[sub_i] = storage_type[i];
+    sub_shape_inputs[sub_input_i] = sub_shape[sub_i];
+    sub_dtype_inputs[sub_input_i] = sub_dtype[sub_i];
+    sub_storage_type_inputs[sub_input_i] = sub_storage_type[sub_i];
+  }
+  subgraph.attrs["shape"] =
+      std::make_shared<dmlc::any>(std::move(sub_shape));
+  subgraph.attrs["dtype"] =
+      std::make_shared<dmlc::any>(std::move(sub_dtype));
+  subgraph.attrs["storage_type"] =
+      std::make_shared<dmlc::any>(std::move(sub_storage_type));
+  subgraph.attrs["shape_inputs"] =
+      std::make_shared<dmlc::any>(std::move(sub_shape_inputs));
+  subgraph.attrs["dtype_inputs"] =
+      std::make_shared<dmlc::any>(std::move(sub_dtype_inputs));
+  subgraph.attrs["storage_type_inputs"] =
+      std::make_shared<dmlc::any>(std::move(sub_storage_type_inputs));
+
+  return subgraph;
+}
+
+/*!
+ * \brief Generate a unique name for a new TRT node, avoiding collisions with TRT_node names already in the graph
+ */
+const std::string GetNewTrtName(const Graph& g, const Graph& subgraph) {
+  const std::string name_prefix("TRT_node");
+  std::unordered_set<std::string> name_set;
+  DFSVisit(g.outputs, [&name_set, &name_prefix](const nnvm::NodePtr& node) {
+    if (node->attrs.name.compare(0, name_prefix.size(), name_prefix) == 0) {
+      name_set.insert(node->attrs.name);
+    }
+  });
+  // names inside the subgraph will become available again, as those nodes will be removed
+  DFSVisit(subgraph.outputs, [&name_set, &name_prefix](const nnvm::NodePtr& node) {
+    if (node->attrs.name.compare(0, name_prefix.size(), name_prefix) == 0) {
+      name_set.erase(node->attrs.name);
+    }
+  });
+  uint32_t name_suffix = 0;
+  std::string full_name = name_prefix + std::to_string(name_suffix);
+  while (name_set.count(full_name)) {
+    full_name = name_prefix + std::to_string(++name_suffix);
+  }
+  return full_name;
+}
+
+/*!
+ * \brief helper function to display what nodes are in a specific subset
+ */
+void dispNodesSet(Graph g, std::unordered_set<nnvm::Node*> s) {
+  DFSVisit(g.outputs, [&s](const nnvm::NodePtr n){
+    if (s.count(n.get())) {
+      std::cout << "  Y " << n->attrs.name << std::endl;
+    } else {
+      std::cout << "  N " << n->attrs.name << std::endl;
+    }
+  });
+}
+
+/*!
+ * \brief Replace a set of nodes by a TensorRT node
+ */
+Graph ReplaceSubgraph(Graph&& g,
+                      std::unordered_set<nnvm::Node*> set_subgraph,
+                      std::unordered_map<std::string, NDArray>* const params_map) {
+  // Create MxNet subgraph
+  Graph subgraph;
+
+  const auto sub_outputs_in_main = GetSubgraphOutputs(g, set_subgraph);
+  subgraph.outputs = sub_outputs_in_main;
+  std::unordered_map<nnvm::Node*, nnvm::NodePtr> old2new;
+  std::deque<nnvm::Node*> stack;
+  std::unordered_set<nnvm::Node*> visited;
+  int32_t reservation = set_subgraph.size();
+  old2new.reserve(reservation);
+  visited.reserve(reservation);
+
+  for (auto& n : set_subgraph) {
+    old2new[n] = std::make_shared<nnvm::Node>(*n);
+  }
+
+  nnvm::NodeEntryMap<nnvm::NodeEntry> main_input_entry_to_sub;
+  for (auto& e : GetSubgraphInterfaceNodes(g, set_subgraph)) {
+    auto node = nnvm::Node::Create();
+    node->attrs.name = e.node->attrs.name + "_" + std::to_string(e.index);
+    auto new_e = nnvm::NodeEntry{node, 0, 0};
+    main_input_entry_to_sub[e] = new_e;
+  }
+
+  for (nnvm::NodeEntry& e : subgraph.outputs) {
+    e.node = old2new[e.node.get()];
+    stack.emplace_back(e.node.get());
+  }
+
+  while (!stack.empty()) {
+    auto vertex = stack.front();
+    stack.pop_front();
+    if (!visited.count(vertex)) {
+      visited.insert(vertex);
+      for (auto& e : vertex->inputs) {
+        auto it = main_input_entry_to_sub.find(e);
+        if (it != main_input_entry_to_sub.end()) {
+          e = it->second;
+        } else {
+          e.node = old2new[e.node.get()];
+        }
+        stack.emplace_back(e.node.get());
+      }
+    }
+  }
+
+  DFSVisit(subgraph.outputs, [&set_subgraph, &old2new](const nnvm::NodePtr& node) {
+    node->control_deps.erase(
+        std::remove_if(node->control_deps.begin(), node->control_deps.end(),
+                       [&set_subgraph](const nnvm::NodePtr& n_ptr) {
+                         return !set_subgraph.count(n_ptr.get());
+                       }),
+        node->control_deps.end());
+    for (nnvm::NodePtr& n_ptr : node->control_deps) {
+      n_ptr = old2new[n_ptr.get()];
+    }
+  });
+
+  subgraph = UpdateSubgraphAttrs(std::move(subgraph), g, old2new, main_input_entry_to_sub);
+  auto& sub_idx = subgraph.indexed_graph();
+
+  auto trtnodeptr = ConvertNnvmGraphToOnnx(subgraph, params_map);
+  trtnodeptr->attrs.name = GetNewTrtName(g, subgraph);
+
+  // Insert new trt node and unplug replaced nodes
+  std::unordered_map<uint32_t, nnvm::NodeEntry> sub_input_entryid_to_main;
+  for (auto& p : main_input_entry_to_sub) {
+    sub_input_entryid_to_main[sub_idx.entry_id(p.second)] = p.first;
+  }
+
+  trtnodeptr->inputs.resize(main_input_entry_to_sub.size());
+  {
+    uint32_t counter = 0;
+    for (uint32_t i : sub_idx.input_nodes()) {
+      auto it = sub_input_entryid_to_main.find(sub_idx.entry_id(i, 0));
+      if (it != sub_input_entryid_to_main.end()) {
+        trtnodeptr->inputs[counter++] = it->second;
+      }
+    }
+  }
+
+  nnvm::NodeEntryMap<uint32_t> sub_outputs_in_main_to_pos;
+  for (uint32_t i = 0; i < sub_outputs_in_main.size(); ++i) {
+    sub_outputs_in_main_to_pos[sub_outputs_in_main[i]] = i;
+  }
+  DFSVisit(g.outputs, [&sub_outputs_in_main_to_pos, &trtnodeptr](const nnvm::NodePtr& n) {
+    for (auto& e : n->inputs) {
+      auto it = sub_outputs_in_main_to_pos.find(e);
+      if (it != sub_outputs_in_main_to_pos.end()) {
+        e.index = it->second;
+        e.node = trtnodeptr;
+      }
+    }
+  });
+
+  for (size_t i = 0; i < g.outputs.size(); ++i) {
+    auto it = sub_outputs_in_main_to_pos.find(g.outputs[i]);
+    if (it != sub_outputs_in_main_to_pos.end()) {
+      g.outputs[i].index = it->second;
+      g.outputs[i].node = trtnodeptr;
+    }
+  }
+
+  Graph new_graph;
+  new_graph.outputs = g.outputs;
+  return new_graph;
+}
+
+std::vector<std::unordered_set<nnvm::Node*>> GetTrtCompatibleSubsets(const Graph& g,
+    std::unordered_map<std::string, NDArray>* const params_map) {
+  BidirectionalGraph biG = BidirectionalGraph(g);
+  std::vector<std::unordered_set<BidirectionalGraph::Node*>> subsets = biG.get_subsets(params_map);
+  std::vector<std::unordered_set<nnvm::Node*>> nnvm_subsets(subsets.size(),
+                                                            std::unordered_set<nnvm::Node*>());
+  for (size_t i = 0; i < subsets.size(); ++i) {
+    nnvm_subsets[i].reserve(subsets[i].size());
+    for (auto& n : subsets[i]) {
+      nnvm_subsets[i].insert(n->nnvmptr);
+    }
+  }
+  return nnvm_subsets;
+}
+
+}  // namespace exec
+}  // namespace mxnet
+
+#endif  // MXNET_USE_TENSORRT
diff --git a/src/operator/contrib/nnvm_to_onnx-inl.h b/src/operator/contrib/nnvm_to_onnx-inl.h
new file mode 100644
index 00000000000..58f88b05143
--- /dev/null
+++ b/src/operator/contrib/nnvm_to_onnx-inl.h
@@ -0,0 +1,156 @@
+#ifndef MXNET_OPERATOR_CONTRIB_NNVM_TO_ONNX_INL_H_
+#define MXNET_OPERATOR_CONTRIB_NNVM_TO_ONNX_INL_H_
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2018 by Contributors
+ * \file nnvm_to_onnx-inl.h
+ * \brief Conversion from NNVM graphs to ONNX protobufs for TensorRT
+ * \author Marek Kolodziej, Clement Fuji Tsang
+*/
+
+#if MXNET_USE_TENSORRT
+
+#include <dmlc/logging.h>
+#include <dmlc/memory_io.h>
+#include <dmlc/serializer.h>
+#include <dmlc/parameter.h>
+#include <mxnet/base.h>
+#include <mxnet/operator.h>
+#include <nnvm/graph.h>
+#include <nnvm/pass_functions.h>
+
+#include <NvInfer.h>
+#include <onnx/onnx.pb.h>
+
+#include <algorithm>
+#include <iostream>
+#include <map>
+#include <vector>
+#include <tuple>
+#include <unordered_map>
+#include <utility>
+#include <string>
+
+#include "./tensorrt-inl.h"
+#include "../operator_common.h"
+#include "../../common/utils.h"
+#include "../../common/serialization.h"
+
+namespace mxnet {
+namespace op {
+namespace nnvm_to_onnx {
+
+using namespace nnvm;
+using namespace ::onnx;
+using int64 = ::google::protobuf::int64;
+
+std::unordered_map<std::string, TShape> GetPlaceholderShapes(const ShapeVector& shape_inputs,
+    const nnvm::IndexedGraph& ig);
+
+std::unordered_map<std::string, uint32_t> GetOutputLookup(const nnvm::IndexedGraph& ig);
+
+void ConvertPlaceholder(
+  const std::string& node_name,
+  const std::unordered_map<std::string, TShape>& placeholder_shapes,
+  GraphProto* const graph_proto);
+
+void ConvertConstant(GraphProto* const graph_proto,
+  const std::string& node_name,
+  std::unordered_map<std::string, NDArray>* const shared_buffer);
+
+void ConvertOutput(op::tensorrt::InferenceMap_t* const trt_output_map,
+                   GraphProto* const graph_proto,
+                   const std::unordered_map<std::string, uint32_t>::iterator& out_iter,
+                   const std::string& node_name,
+                   const nnvm::Graph& g,
+                   const StorageTypeVector& storage_types,
+                   const DTypeVector& dtypes);
+
+typedef void (*ConverterFunction)(NodeProto *node_proto,
+                                  const NodeAttrs &attrs,
+                                  const nnvm::IndexedGraph &ig,
+                                  const array_view<IndexedGraph::NodeEntry> &inputs);
+
+
+// Forward declarations
+void ConvertConvolution(
+                        NodeProto *node_proto,
+                        const NodeAttrs &attrs,
+                        const nnvm::IndexedGraph &ig,
+                        const array_view<IndexedGraph::NodeEntry> &inputs);
+
+
+void ConvertPooling(NodeProto *node_proto,
+                    const NodeAttrs &attrs,
+                    const nnvm::IndexedGraph &ig,
+                    const array_view<IndexedGraph::NodeEntry> &inputs);
+
+void ConvertActivation(NodeProto *node_proto,
+                       const NodeAttrs &attrs,
+                       const nnvm::IndexedGraph &ig,
+                       const array_view<IndexedGraph::NodeEntry> &inputs);
+
+void ConvertFullyConnected(NodeProto *node_proto,
+                           const NodeAttrs &attrs,
+                           const nnvm::IndexedGraph &ig,
+                           const array_view<IndexedGraph::NodeEntry> &inputs);
+
+void ConvertSoftmaxOutput(NodeProto *node_proto,
+                          const NodeAttrs &attrs,
+                          const nnvm::IndexedGraph &ig,
+                          const array_view<IndexedGraph::NodeEntry> &inputs);
+
+void ConvertFlatten(NodeProto *node_proto,
+                    const NodeAttrs &attrs,
+                    const nnvm::IndexedGraph &ig,
+                    const array_view<IndexedGraph::NodeEntry> &inputs);
+
+void ConvertBatchNorm(NodeProto *node_proto,
+                    const NodeAttrs &attrs,
+                    const nnvm::IndexedGraph &ig,
+                    const array_view<IndexedGraph::NodeEntry> &inputs);
+
+void ConvertElementwiseAdd(NodeProto *node_proto,
+                    const NodeAttrs &attrs,
+                    const nnvm::IndexedGraph &ig,
+                    const array_view<IndexedGraph::NodeEntry> &inputs);
+
+TRTParam ConvertNnvmGraphToOnnx(
+    const nnvm::Graph &g,
+    std::unordered_map<std::string, NDArray> *const shared_buffer);
+
+static const std::unordered_map<std::string, ConverterFunction> converter_map = {
+  {"Convolution", ConvertConvolution},
+  {"Pooling", ConvertPooling},
+  {"Activation", ConvertActivation},
+  {"FullyConnected", ConvertFullyConnected},
+  {"SoftmaxOutput", ConvertSoftmaxOutput},
+  {"Flatten", ConvertFlatten},
+  {"BatchNorm", ConvertBatchNorm},
+  {"elemwise_add", ConvertElementwiseAdd}};
+
+}  // namespace nnvm_to_onnx
+}  // namespace op
+}  // namespace mxnet
+
+#endif  // MXNET_USE_TENSORRT
+
+#endif  // MXNET_OPERATOR_CONTRIB_NNVM_TO_ONNX_INL_H_
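Extending the converter table above to a new operator amounts to declaring another ConverterFunction and adding a map entry; the op also has to be whitelisted in IsTRTCompatible in tensorrt_pass.cc. A sketch with a hypothetical Concat converter (not part of this patch; Convolution shown only to illustrate the existing entries):

// Hypothetical converter; the signature follows the ConverterFunction typedef above.
void ConvertConcat(NodeProto* node_proto,
                   const NodeAttrs& attrs,
                   const nnvm::IndexedGraph& ig,
                   const array_view<IndexedGraph::NodeEntry>& inputs);

static const std::unordered_map<std::string, ConverterFunction> converter_map_extended = {
  {"Convolution", ConvertConvolution},
  {"Concat", ConvertConcat},  // new entry; remaining ops elided
};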
diff --git a/src/operator/contrib/nnvm_to_onnx.cc b/src/operator/contrib/nnvm_to_onnx.cc
new file mode 100644
index 00000000000..1a68d7aae53
--- /dev/null
+++ b/src/operator/contrib/nnvm_to_onnx.cc
@@ -0,0 +1,548 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2018 by Contributors
+ * \file nnvm_to_onnx.cc
+ * \brief Conversion from NNVM graphs to ONNX protobufs for TensorRT
+ * \author Marek Kolodziej, Clement Fuji Tsang
+*/
+
+#if MXNET_USE_TENSORRT
+
+#include "./nnvm_to_onnx-inl.h"
+
+#include <mxnet/base.h>
+#include <nnvm/graph.h>
+#include <nnvm/pass_functions.h>
+
+#include <algorithm>
+#include <fstream>
+#include <iostream>
+#include <unordered_map>
+#include <vector>
+
+#include "./tensorrt-inl.h"
+#include "../../common/serialization.h"
+#include "../../common/utils.h"
+#include "../../ndarray/ndarray_function.h"
+#include "../../operator/nn/activation-inl.h"
+#include "../../operator/nn/batch_norm-inl.h"
+#include "../../operator/nn/convolution-inl.h"
+#include "../../operator/nn/fully_connected-inl.h"
+#include "../../operator/nn/pooling-inl.h"
+#include "../../operator/softmax_output-inl.h"
+
+// #include <onnx/checker.h>
+
+namespace mxnet {
+namespace op {
+namespace nnvm_to_onnx {
+
+op::TRTParam ConvertNnvmGraphToOnnx(
+    const nnvm::Graph& g,
+    std::unordered_map<std::string, NDArray>* const shared_buffer) {
+  op::TRTParam trt_param;
+  op::tensorrt::NameToIdx_t trt_input_map;
+  op::tensorrt::InferenceMap_t trt_output_map;
+
+  const nnvm::IndexedGraph& ig = g.indexed_graph();
+  const StorageTypeVector& storage_types =
+      g.GetAttr<StorageTypeVector>("storage_type");
+  const DTypeVector& dtypes = g.GetAttr<DTypeVector>("dtype");
+  const ShapeVector& shape_inputs = g.GetAttr<ShapeVector>("shape_inputs");
+
+  for (auto& e : dtypes) {
+    if (e != mshadow::kFloat32) {
+      LOG(FATAL) << "ONNX converter does not support types other than float32 "
+                    "right now.";
+    }
+  }
+
+  ModelProto model_proto;
+  // Need to determine IR versions and features to support
+  model_proto.set_ir_version(static_cast<int64>(2));
+  GraphProto* graph_proto = model_proto.mutable_graph();
+  //  graph_proto->set_name(graph_name);
+
+  std::unordered_map<std::string, TShape> placeholder_shapes =
+      GetPlaceholderShapes(shape_inputs, ig);
+  std::unordered_map<std::string, uint32_t> output_lookup = GetOutputLookup(ig);
+  uint32_t current_input = 0;
+
+  // can't do a foreach over IndexedGraph since it doesn't
+  // implement begin(), etc.
+  for (uint32_t node_idx = 0; node_idx < ig.num_nodes(); ++node_idx) {
+    const IndexedGraph::Node& node = ig[node_idx];
+    const nnvm::Node* source = node.source;
+    const NodeAttrs& attrs = source->attrs;
+    const Op* op = source->op();
+
+    std::string node_name = attrs.name;
+    // Here, "variable" actually means anything that's not an op,
+    // i.e. a constant (weights) or a placeholder
+    if (source->is_variable()) {
+      // Is this a placeholder?
+      if (shared_buffer->count(node_name) == 0) {
+        // This fixes the problem with a SoftmaxOutput node during inference,
+        // but it's hacky.
+        // Need to figure out how to properly fix it.
+        if (node_name.find("label") != std::string::npos) {
+          current_input++;
+          continue;
+        }
+        trt_input_map.emplace(node_name, current_input++);
+        ConvertPlaceholder(node_name, placeholder_shapes, graph_proto);
+      } else {
+        // If it's not a placeholder, then by exclusion it's a constant.
+        ConvertConstant(graph_proto, node_name, shared_buffer);
+      }  // is_placeholder
+    } else {
+      // It's an op, rather than a "variable" (constant or placeholder)
+      NodeProto* node_proto = graph_proto->add_node();
+      node_proto->set_name(node_name);
+      if (converter_map.count(op->name) == 0) {
+        LOG(FATAL) << "Conversion for node of type " << op->name << " (node "
+                   << node_name << ") "
+                   << " is not supported yet.";
+      }
+      // Find function ptr to a converter based on the op name, and
+      // invoke the converter. This looks unsafe because find may not
+      // succeed, but it does because we're in the operator logic after
+      // testing that this node name does not represent a variable.
+      converter_map.find(op->name)->second(node_proto, attrs, ig, node.inputs);
+      // Add all inputs to the current node (i.e. add graph edges)
+      for (const nnvm::IndexedGraph::NodeEntry& entry : node.inputs) {
+        std::string in_node_name = ig[entry.node_id].source->attrs.name;
+        // As before, we're not adding labels e.g. for SoftmaxOutput,
+        // but I wish there was a less hacky way to do it than name matching.
+        if (in_node_name.find("label") != std::string::npos) {
+          continue;
+        }
+        node_proto->add_input(in_node_name);
+      }
+      // The node's output will have the same name as the node name.
+      node_proto->add_output(node_name);
+      // See if the current node is an output node
+      auto out_iter = output_lookup.find(node_name);
+      // We found an output
+      if (out_iter != output_lookup.end()) {
+        ConvertOutput(&trt_output_map, graph_proto, out_iter, node_name, g,
+                      storage_types, dtypes);
+      }  // output found
+    }    // conversion function exists
+  }      // loop over i from 0 to num_nodes
+
+  LOG(INFO) << "ONNX graph construction complete.";
+  model_proto.SerializeToString(&trt_param.serialized_onnx_graph);
+  common::Serialize<op::tensorrt::NameToIdx_t>(trt_input_map,
+                                          &trt_param.serialized_input_map);
+  common::Serialize<op::tensorrt::InferenceMap_t>(trt_output_map,
+                                             &trt_param.serialized_output_map);
+  //  onnx::checker::check_model(model_proto);
+  return trt_param;
+}
+
+void ConvertConvolution(NodeProto* node_proto, const NodeAttrs& attrs,
+                        const nnvm::IndexedGraph& ig,
+                        const array_view<IndexedGraph::NodeEntry>& inputs) {
+  const op::ConvolutionParam& conv_param =
+      nnvm::get<op::ConvolutionParam>(attrs.parsed);
+
+  node_proto->set_op_type("Conv");
+
+  const TShape kernel = conv_param.kernel;
+  const TShape stride = conv_param.stride;
+  const TShape dilate = conv_param.dilate;
+  const TShape pad = conv_param.pad;
+  // const uint32_t num_filter = conv_param.num_filter;
+  const uint32_t num_group = conv_param.num_group;
+  // const bool no_bias = conv_param.no_bias;
+  const dmlc::optional<int> layout = conv_param.layout;
+
+  // kernel shape
+  AttributeProto* const kernel_shape = node_proto->add_attribute();
+  kernel_shape->set_name("kernel_shape");
+  kernel_shape->set_type(AttributeProto::INTS);
+
+  for (int kval : kernel) {
+    kernel_shape->add_ints(static_cast<int64>(kval));
+  }
+
+  // pads
+  AttributeProto* const pads = node_proto->add_attribute();
+  pads->set_name("pads");
+  pads->set_type(AttributeProto::INTS);
+
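+  // MXNet's pad is symmetric, so each padding value is emitted twice to cover
+  // ONNX's separate begin and end padding entries.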
+  for (int kval : pad) {
+    pads->add_ints(static_cast<int64>(kval));
+    pads->add_ints(static_cast<int64>(kval));
+  }
+
+  // dilations
+  AttributeProto* const dilations = node_proto->add_attribute();
+  dilations->set_name("dilations");
+  dilations->set_type(AttributeProto::INTS);
+  for (int kval : dilate) {
+    dilations->add_ints(static_cast<int64>(kval));
+  }
+
+  // strides
+  AttributeProto* const strides = node_proto->add_attribute();
+  strides->set_name("strides");
+  strides->set_type(AttributeProto::INTS);
+  for (int kval : stride) {
+    strides->add_ints(static_cast<int64>(kval));
+  }
+
+  // group
+  AttributeProto* const group = node_proto->add_attribute();
+  group->set_name("group");
+  group->set_type(AttributeProto::INT);
+  group->set_i(static_cast<int64>(num_group));
+}  // end ConvertConvolution
+
+void ConvertPooling(NodeProto* node_proto, const NodeAttrs& attrs,
+                    const nnvm::IndexedGraph& ig,
+                    const array_view<IndexedGraph::NodeEntry>& inputs) {
+  const op::PoolingParam& pooling_param =
+      nnvm::get<op::PoolingParam>(attrs.parsed);
+
+  const TShape kernel = pooling_param.kernel;
+  const TShape stride = pooling_param.stride;
+  const TShape pad = pooling_param.pad;
+  const int pool_type = pooling_param.pool_type;
+  // const int pooling_convention = pooling_param.pooling_convention;
+  const bool global_pool = pooling_param.global_pool;
+
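+  // Global pooling maps directly onto ONNX GlobalMaxPool/GlobalAveragePool,
+  // which take no kernel, pad or stride attributes.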
+  if (global_pool) {
+    if (pool_type == 0) {
+      node_proto->set_op_type("GlobalMaxPool");
+    } else {
+      node_proto->set_op_type("GlobalAveragePool");
+    }
+    return;
+  } else {
+    // kernel_shape
+    AttributeProto* const kernel_shape = node_proto->add_attribute();
+    kernel_shape->set_name("kernel_shape");
+    kernel_shape->set_type(AttributeProto::INTS);
+    for (int kval : kernel) {
+      kernel_shape->add_ints(static_cast<int64>(kval));
+    }
+
+    // pads
+    AttributeProto* const pads = node_proto->add_attribute();
+    pads->set_name("pads");
+    pads->set_type(AttributeProto::INTS);
+    for (int kval : pad) {
+      pads->add_ints(static_cast<int64>(kval));
+    }
+
+    // strides
+    AttributeProto* const strides = node_proto->add_attribute();
+    strides->set_name("strides");
+    strides->set_type(AttributeProto::INTS);
+    for (int kval : stride) {
+      strides->add_ints(static_cast<int64>(kval));
+    }
+
+    if (pool_type == 0) {
+      node_proto->set_op_type("MaxPool");
+
+    } else {
+      node_proto->set_op_type("AveragePool");
+    }  // average pooling
+    // not global pooling
+  }
+}  // end ConvertPooling
+
+void ConvertActivation(NodeProto* node_proto, const NodeAttrs& attrs,
+                       const nnvm::IndexedGraph& ig,
+                       const array_view<IndexedGraph::NodeEntry>& inputs) {
+  const op::ActivationParam& act_param =
+      nnvm::get<op::ActivationParam>(attrs.parsed);
+  std::string act_type;
+  switch (act_param.act_type) {
+    case op::activation::kReLU:
+      act_type = "Relu";
+      break;
+    case op::activation::kSigmoid:
+      act_type = "Sigmoid";
+      break;
+    case op::activation::kTanh:
+      act_type = "Tanh";
+      break;
+    case op::activation::kSoftReLU:
+      // act_type = "SoftReLU";
+      throw dmlc::Error("SoftReLU is not supported in ONNX");
+      break;
+    default:
+      throw dmlc::Error("Activation of such type doesn't exist");
+  }
+
+  node_proto->set_op_type(act_type);
+}
+
+void ConvertFullyConnected(NodeProto* node_proto, const NodeAttrs& attrs,
+                           const nnvm::IndexedGraph& ig,
+                           const array_view<IndexedGraph::NodeEntry>& inputs) {
+  // const op::FullyConnectedParam &act_param =
+  //    nnvm::get<op::FullyConnectedParam>(attrs.parsed);
+
+  node_proto->set_op_type("Gemm");
+
+  // const int num_hidden = act_param.num_hidden;
+  // const bool no_bias = act_param.no_bias;
+  // Whether to collapse all but the first axis of the input data tensor.
+  // const bool flatten = act_param.flatten;
+
+  AttributeProto* const alpha = node_proto->add_attribute();
+  alpha->set_name("alpha");
+  alpha->set_type(AttributeProto::FLOAT);
+  alpha->set_f(1.0f);
+
+  AttributeProto* const beta = node_proto->add_attribute();
+  beta->set_name("beta");
+  beta->set_type(AttributeProto::FLOAT);
+  beta->set_f(1.0f);
+
+  AttributeProto* const broadcast = node_proto->add_attribute();
+  broadcast->set_name("broadcast");
+  broadcast->set_type(AttributeProto::INT);
+  broadcast->set_i(1);
+
+  AttributeProto* const transA = node_proto->add_attribute();
+  transA->set_name("transA");
+  transA->set_type(AttributeProto::INT);
+  transA->set_i(0);
+
+  AttributeProto* const transB = node_proto->add_attribute();
+  transB->set_name("transB");
+  transB->set_type(AttributeProto::INT);
+  transB->set_i(1);
+}
+
+void ConvertSoftmaxOutput(NodeProto* node_proto, const NodeAttrs& attrs,
+                          const nnvm::IndexedGraph& ig,
+                          const array_view<IndexedGraph::NodeEntry>& inputs) {
+  node_proto->set_op_type("Softmax");
+
+  // Setting this to 1 by default since MXNet doesn't provide such an
+  // attribute for softmax in its node params. The attribute is only relevant
+  // when the input is coerced to 2D, in which case dimension 0 is assumed to
+  // be the batch dimension.
+  AttributeProto* const axis = node_proto->add_attribute();
+  axis->set_name("axis");
+  axis->set_type(AttributeProto::INT);
+  axis->set_i(1);
+}
+
+void ConvertFlatten(NodeProto* node_proto, const NodeAttrs& attrs,
+                    const nnvm::IndexedGraph& ig,
+                    const array_view<IndexedGraph::NodeEntry>& inputs) {
+  node_proto->set_op_type("Flatten");
+
+  // Setting this to 1 by default since MXNet doesn't provide such an
+  // attribute for Flatten in its node params. The attribute is only relevant
+  // when the input is coerced to 2D, in which case dimension 0 is assumed to
+  // be the batch dimension.
+  AttributeProto* const axis = node_proto->add_attribute();
+  axis->set_name("axis");
+  axis->set_type(AttributeProto::INT);
+  axis->set_i(1);
+}
+
+void ConvertBatchNorm(NodeProto* node_proto, const NodeAttrs& attrs,
+                      const nnvm::IndexedGraph& ig,
+                      const array_view<IndexedGraph::NodeEntry>& inputs) {
+  node_proto->set_op_type("BatchNormalization");
+  const op::BatchNormParam& param = nnvm::get<op::BatchNormParam>(attrs.parsed);
+
+  AttributeProto* const epsilon = node_proto->add_attribute();
+  epsilon->set_name("epsilon");
+  epsilon->set_type(AttributeProto::FLOAT);
+  epsilon->set_f(static_cast<float>(param.eps));
+
+  AttributeProto* const is_test = node_proto->add_attribute();
+  is_test->set_name("is_test");
+  is_test->set_type(AttributeProto::INT);
+  is_test->set_i(1);
+
+  AttributeProto* const momentum = node_proto->add_attribute();
+  momentum->set_name("momentum");
+  momentum->set_type(AttributeProto::FLOAT);
+  momentum->set_f(param.momentum);
+
+  AttributeProto* const spatial = node_proto->add_attribute();
+  spatial->set_name("spatial");
+  spatial->set_type(AttributeProto::INT);
+  spatial->set_i(1);
+
+  AttributeProto* const consumed = node_proto->add_attribute();
+  consumed->set_name("consumed_inputs");
+  consumed->set_type(AttributeProto::INTS);
+
+  for (int i = 0; i < 5; i++) {
+    int val = (i < 3) ? 0 : 1;
+    consumed->add_ints(static_cast<int64>(val));
+  }
+}
+
+void ConvertElementwiseAdd(NodeProto* node_proto, const NodeAttrs& attrs,
+                           const nnvm::IndexedGraph& ig,
+                           const array_view<IndexedGraph::NodeEntry>& inputs) {
+  node_proto->set_op_type("Add");
+  AttributeProto* const axis = node_proto->add_attribute();
+  axis->set_name("axis");
+  axis->set_type(AttributeProto::INT);
+  axis->set_i(1);
+
+  AttributeProto* const broadcast = node_proto->add_attribute();
+  broadcast->set_name("broadcast");
+  broadcast->set_type(AttributeProto::INT);
+  broadcast->set_i(0);
+}
+
+std::unordered_map<std::string, TShape> GetPlaceholderShapes(
+    const ShapeVector& shape_inputs, const nnvm::IndexedGraph& ig) {
+  std::unordered_map<std::string, TShape> placeholder_shapes;
+  for (uint32_t i = 0; i < shape_inputs.size(); ++i) {
+    std::string name = ig[ig.input_nodes()[i]].source->attrs.name;
+    TShape shp = shape_inputs[i];
+    if (shp.ndim() > 0) {
+      placeholder_shapes.emplace(name, shp);
+    }
+  }
+  return placeholder_shapes;
+}
+
+std::unordered_map<std::string, uint32_t> GetOutputLookup(
+    const nnvm::IndexedGraph& ig) {
+  std::unordered_map<std::string, uint32_t> output_lookup;
+  const std::vector<nnvm::IndexedGraph::NodeEntry>& graph_outputs =
+      ig.outputs();
+  for (uint32_t i = 0; i < graph_outputs.size(); ++i) {
+    const uint32_t id = graph_outputs[i].node_id;
+    const IndexedGraph::Node ig_node = ig[id];
+    const nnvm::Node* const source = ig_node.source;
+    const std::string name = source->attrs.name;
+    output_lookup.emplace(name, i);
+  }
+  return output_lookup;
+}
+
+void ConvertPlaceholder(
+    const std::string& node_name,
+    const std::unordered_map<std::string, TShape>& placeholder_shapes,
+    GraphProto* const graph_proto) {
+  auto val_info_proto = graph_proto->add_input();
+  auto type_proto = val_info_proto->mutable_type()->mutable_tensor_type();
+  auto shape_proto = type_proto->mutable_shape();
+
+  val_info_proto->set_name(node_name);
+  // Will support fp16, etc. in the near future
+  type_proto->set_elem_type(TensorProto_DataType_FLOAT);
+  auto entry_shape = placeholder_shapes.find(node_name)->second;
+
+  for (const auto& elem : entry_shape) {
+    TensorShapeProto_Dimension* const tsp_dim = shape_proto->add_dim();
+    tsp_dim->set_dim_value(static_cast<int64>(elem));
+  }
+}
+
+void ConvertConstant(
+    GraphProto* const graph_proto, const std::string& node_name,
+    std::unordered_map<std::string, NDArray>* const shared_buffer) {
+  NodeProto* const node_proto = graph_proto->add_node();
+  node_proto->set_name(node_name);
+  node_proto->add_output(node_name);
+  node_proto->set_op_type("Constant");
+
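+  // Weights are looked up in the shared buffer, copied to the host and
+  // embedded in the graph as an ONNX Constant tensor.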
+  const NDArray nd = shared_buffer->find(node_name)->second;
+  const TBlob blob = nd.data();
+  const TShape shape = blob.shape_;
+  // const int type_flag = blob.type_flag_;
+  const int32_t size = shape.Size();
+
+  // Use MSHADOW_TYPE_SWITCH for e.g. fp16 support !!!
+  // const int dev_mask = blob.dev_mask();
+  // const int dev_id = blob.dev_id();
+
+  std::shared_ptr<float> shared_data_ptr(new float[size],
+                                         std::default_delete<float[]>());
+  float* const data_ptr = shared_data_ptr.get();
+  nd.SyncCopyToCPU(static_cast<void*>(data_ptr), size);
+
+  AttributeProto* const tensor_attr = node_proto->add_attribute();
+  tensor_attr->set_name("value");
+  tensor_attr->set_type(AttributeProto::TENSOR);
+
+  TensorProto* const tensor_proto = tensor_attr->mutable_t();
+  tensor_proto->set_data_type(TensorProto_DataType_FLOAT);
+  for (auto& dim : shape) {
+    tensor_proto->add_dims(static_cast<int64>(dim));
+  }
+
+  for (int blob_idx = 0; blob_idx < size; ++blob_idx) {
+    tensor_proto->add_float_data(data_ptr[blob_idx]);
+  }
+}
+
+void ConvertOutput(
+    op::tensorrt::InferenceMap_t* const trt_output_map,
+    GraphProto* const graph_proto,
+    const std::unordered_map<std::string, uint32_t>::iterator& out_iter,
+    const std::string& node_name, const nnvm::Graph& g,
+    const StorageTypeVector& storage_types, const DTypeVector& dtypes) {
+  const nnvm::IndexedGraph& ig = g.indexed_graph();
+  uint32_t out_idx = ig.entry_id(ig.outputs()[out_iter->second]);
+  TShape out_shape = g.GetAttr<nnvm::ShapeVector>("shape")[out_idx];
+  int storage_type = storage_types[out_idx];
+  int dtype = dtypes[out_idx];
+
+  // This should work with fp16 as well
+  op::tensorrt::InferenceTuple_t out_tuple{out_iter->second, out_shape, storage_type,
+                                      dtype};
+
+  trt_output_map->emplace(node_name, out_tuple);
+
+  auto graph_out = graph_proto->add_output();
+  auto tensor_type = graph_out->mutable_type()->mutable_tensor_type();
+  auto tensor_shape_proto = tensor_type->mutable_shape();
+  graph_out->set_name(node_name);
+
+  // Also support fp16.
+  tensor_type->set_elem_type(TensorProto_DataType_FLOAT);
+
+  for (int64_t dim_shp : out_shape) {
+    TensorShapeProto_Dimension* const tsp_dim = tensor_shape_proto->add_dim();
+    tsp_dim->set_dim_value(static_cast<int64>(dim_shp));
+  }
+}
+
+}  // namespace nnvm_to_onnx
+}  // namespace op
+}  // namespace mxnet
+
+#endif  // MXNET_USE_TENSORRT
diff --git a/src/operator/contrib/tensorrt-inl.h b/src/operator/contrib/tensorrt-inl.h
new file mode 100644
index 00000000000..be4248fc762
--- /dev/null
+++ b/src/operator/contrib/tensorrt-inl.h
@@ -0,0 +1,140 @@
+#ifndef MXNET_OPERATOR_CONTRIB_TENSORRT_INL_H_
+#define MXNET_OPERATOR_CONTRIB_TENSORRT_INL_H_
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2018 by Contributors
+ * \file tensorrt-inl.h
+ * \brief TensorRT Operator
+ * \author Marek Kolodziej, Clement Fuji Tsang
+*/
+
+#if MXNET_USE_TENSORRT
+
+#include <dmlc/logging.h>
+#include <dmlc/memory_io.h>
+#include <dmlc/serializer.h>
+#include <dmlc/parameter.h>
+#include <mxnet/base.h>
+#include <mxnet/operator.h>
+#include <nnvm/graph.h>
+#include <nnvm/pass_functions.h>
+
+#include <NvInfer.h>
+#include <onnx/onnx.pb.h>
+
+#include <algorithm>
+#include <iostream>
+#include <map>
+#include <vector>
+#include <tuple>
+#include <unordered_map>
+#include <utility>
+#include <string>
+
+#include "../operator_common.h"
+#include "../../common/utils.h"
+#include "../../common/serialization.h"
+#include "../../executor/exec_pass.h"
+#include "../../executor/graph_executor.h"
+#include "../../executor/onnx_to_tensorrt.h"
+
+namespace mxnet {
+namespace op {
+
+using namespace nnvm;
+using namespace ::onnx;
+using int64 = ::google::protobuf::int64;
+
+namespace tensorrt {
+  enum class TypeIO { Inputs = 0, Outputs = 1 };
+  using NameToIdx_t = std::map<std::string, int32_t>;
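+  // For each graph output: (index in g.outputs, shape, storage type, dtype).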
+  using InferenceTuple_t = std::tuple<uint32_t, TShape, int, int>;
+  using InferenceMap_t = std::map<std::string, InferenceTuple_t>;
+}  // namespace tensorrt
+
+using trt_name_to_idx = std::map<std::string, uint32_t>;
+
+struct TRTParam : public dmlc::Parameter<TRTParam> {
+  std::string serialized_onnx_graph;
+  std::string serialized_input_map;
+  std::string serialized_output_map;
+  tensorrt::NameToIdx_t input_map;
+  tensorrt::InferenceMap_t output_map;
+  ::onnx::ModelProto onnx_pb_graph;
+
+  TRTParam() {}
+
+  TRTParam(const ::onnx::ModelProto& onnx_graph,
+           const tensorrt::NameToIdx_t& input_map,
+           const tensorrt::InferenceMap_t& output_map) {
+    common::Serialize(input_map, &serialized_input_map);
+    common::Serialize(output_map, &serialized_output_map);
+    onnx_graph.SerializeToString(&serialized_onnx_graph);
+  }
+
+  DMLC_DECLARE_PARAMETER(TRTParam) {
+    DMLC_DECLARE_FIELD(serialized_onnx_graph)
+    .describe("Serialized ONNX graph");
+    DMLC_DECLARE_FIELD(serialized_input_map)
+    .describe("Map from input names to their topological order as inputs.");
+    DMLC_DECLARE_FIELD(serialized_output_map)
+    .describe("Map from output names to their order in g.outputs.");
+  }
+};
+
+struct TRTEngineParam {
+  nvinfer1::IExecutionContext* trt_executor;
+  std::vector<std::pair<uint32_t, tensorrt::TypeIO> > binding_map;
+};
+
+OpStatePtr TRTCreateState(const nnvm::NodeAttrs& attrs, Context ctx,
+                          const std::vector<TShape>& ishape,
+                          const std::vector<int>& itype);
+
+template<typename xpu>
+void TRTCompute(const OpStatePtr& state, const OpContext& ctx,
+                const std::vector<TBlob>& inputs, const std::vector<OpReqType>& req,
+                const std::vector<TBlob>& outputs);
+
+inline bool TRTInferShape(const NodeAttrs& attrs,
+                          std::vector<TShape> *in_shape,
+                          std::vector<TShape> *out_shape);
+
+inline bool TRTInferStorageType(const NodeAttrs& attrs,
+                                const int dev_mask,
+                                DispatchMode* dispatch_mode,
+                                std::vector<int> *in_storage_type,
+                                std::vector<int> *out_storage_type);
+
+inline bool TRTInferType(const NodeAttrs& attrs,
+                         std::vector<int> *in_dtype,
+                         std::vector<int> *out_dtype);
+
+inline std::vector<std::string> TRTListInputNames(const NodeAttrs& attrs);
+
+inline std::vector<std::string> TRTListOutputNames(const NodeAttrs& attrs);
+
+}  // namespace op
+}  // namespace mxnet
+
+#endif  // MXNET_USE_TENSORRT
+
+#endif  // MXNET_OPERATOR_CONTRIB_TENSORRT_INL_H_
diff --git a/src/operator/contrib/tensorrt.cc b/src/operator/contrib/tensorrt.cc
new file mode 100644
index 00000000000..535c7875e33
--- /dev/null
+++ b/src/operator/contrib/tensorrt.cc
@@ -0,0 +1,196 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2018 by Contributors
+ * \file tensorrt.cc
+ * \brief TensorRT operation registration
+ * \author Marek Kolodziej, Clement Fuji Tsang
+*/
+
+#if MXNET_USE_TENSORRT
+
+#include "./tensorrt-inl.h"
+
+#include <mxnet/base.h>
+#include <nnvm/graph.h>
+#include <nnvm/pass_functions.h>
+
+#include <algorithm>
+#include <fstream>
+#include <iostream>
+#include <unordered_map>
+#include <vector>
+
+#include "../../common/serialization.h"
+#include "../../common/utils.h"
+
+namespace mxnet {
+namespace op {
+
+DMLC_REGISTER_PARAMETER(TRTParam);
+
+OpStatePtr GetPtrMapping(nvinfer1::ICudaEngine* trt_engine,
+                         tensorrt::NameToIdx_t input_map,
+                         tensorrt::NameToIdx_t output_map) {
+  TRTEngineParam param;
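+  // Record, in TensorRT binding order, which MXNet input/output index each
+  // engine binding corresponds to, so TRTCompute can build the bindings array.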
+  for (int b = 0; b < trt_engine->getNbBindings(); ++b) {
+    const std::string& binding_name = trt_engine->getBindingName(b);
+    if (trt_engine->bindingIsInput(b)) {
+      param.binding_map.emplace_back(input_map[binding_name],
+                                     tensorrt::TypeIO::Inputs);
+    } else {
+      param.binding_map.emplace_back(output_map[binding_name],
+                                     tensorrt::TypeIO::Outputs);
+    }
+  }
+  param.trt_executor = trt_engine->createExecutionContext();
+  return OpStatePtr::Create<TRTEngineParam>(param);
+}
+
+OpStatePtr TRTCreateState(const nnvm::NodeAttrs& attrs, Context ctx,
+                          const std::vector<TShape>& ishape,
+                          const std::vector<int>& itype) {
+  const TRTParam& node_param = nnvm::get<TRTParam>(attrs.parsed);
+
+  ::onnx::ModelProto model_proto;
+  bool success = model_proto.ParseFromString(node_param.serialized_onnx_graph);
+  if (!success) {
+    LOG(FATAL) << "Problems parsing serialized ONNX model.";
+  }
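+  // The engine is built for a fixed max batch size, taken from the leading
+  // dimension of the first input recorded in the serialized ONNX graph.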
+  auto graph = model_proto.graph();
+  auto first_input_type = graph.input(0).type().tensor_type();
+  auto dim_value = first_input_type.shape().dim(0).dim_value();
+  uint64_t batch_size = static_cast<uint64_t>(dim_value);
+  // Need to set up max workspace size based on device properties
+  nvinfer1::ICudaEngine* const trt_engine = ::onnx_to_tensorrt::onnxToTrtCtx(
+      node_param.serialized_onnx_graph, batch_size, 1 << 30);
+
+  LOG(INFO) << "TensorRT engine instantiated!!!";
+
+  tensorrt::NameToIdx_t output_map;
+  for (auto& el : node_param.output_map) {
+    output_map[el.first] = std::get<0>(el.second);
+  }
+  return GetPtrMapping(trt_engine, node_param.input_map, output_map);
+}
+
+void TRTParamParser(nnvm::NodeAttrs* attrs) {
+  using namespace mshadow;
+
+  TRTParam param_;
+
+  try {
+    param_.Init(attrs->dict);
+    common::Deserialize(&param_.input_map, param_.serialized_input_map);
+    common::Deserialize(&param_.output_map, param_.serialized_output_map);
+    param_.onnx_pb_graph.ParseFromString(param_.serialized_onnx_graph);
+  } catch (const dmlc::ParamError& e) {
+    std::ostringstream os;
+    os << e.what();
+    os << ", in operator " << attrs->op->name << "("
+       << "name=\"" << attrs->name << "\"";
+    for (const auto& k : attrs->dict) {
+      os << ", " << k.first << "=\"" << k.second << "\"";
+    }
+    os << ")";
+    throw dmlc::ParamError(os.str());
+  }
+
+  attrs->parsed = std::move(param_);
+}
+
+template <>
+void TRTCompute<cpu>(const OpStatePtr& state, const OpContext& ctx,
+                     const std::vector<TBlob>& inputs,
+                     const std::vector<OpReqType>& req,
+                     const std::vector<TBlob>& outputs) {
+  LOG(FATAL) << "TRTCompute not implemented on the CPU";
+}
+
+inline bool TRTInferShape(const NodeAttrs& attrs, std::vector<TShape>* in_shape,
+                          std::vector<TShape>* out_shape) {
+  const auto node_param = nnvm::get<TRTParam>(attrs.parsed);
+  for (auto& el : node_param.output_map) {
+    (*out_shape)[std::get<0>(el.second)] = std::get<1>(el.second);
+  }
+  return true;
+}
+
+inline bool TRTInferStorageType(const NodeAttrs& attrs, const int dev_mask,
+                                DispatchMode* dispatch_mode,
+                                std::vector<int>* in_storage_type,
+                                std::vector<int>* out_storage_type) {
+  return storage_type_assign(out_storage_type, mxnet::kDefaultStorage,
+                             dispatch_mode, DispatchMode::kFCompute);
+}
+
+inline bool TRTInferType(const NodeAttrs& attrs, std::vector<int>* in_dtype,
+                         std::vector<int>* out_dtype) {
+  const auto node_param = nnvm::get<TRTParam>(attrs.parsed);
+  for (auto& el : node_param.output_map) {
+    (*out_dtype)[std::get<0>(el.second)] = std::get<3>(el.second);
+  }
+  return true;
+}
+
+inline std::vector<std::string> TRTListInputNames(const NodeAttrs& attrs) {
+  const auto node_param = nnvm::get<TRTParam>(attrs.parsed);
+  std::vector<std::string> output(node_param.input_map.size());
+  for (auto& el : node_param.input_map) {
+    output[el.second] = el.first;
+  }
+  return output;
+}
+
+inline std::vector<std::string> TRTListOutputNames(const NodeAttrs& attrs) {
+  const auto node_param = nnvm::get<TRTParam>(attrs.parsed);
+  std::vector<std::string> output(node_param.output_map.size());
+  for (auto& el : node_param.output_map) {
+    output[std::get<0>(el.second)] = el.first;
+  }
+  return output;
+}
+
+NNVM_REGISTER_OP(_trt_op)
+    .describe(R"code(TRT operation (one engine)
+)code" ADD_FILELINE)
+    .set_num_inputs([](const NodeAttrs& attrs) {
+      const auto node_param = nnvm::get<TRTParam>(attrs.parsed);
+      return node_param.input_map.size();
+    })
+    .set_num_outputs([](const NodeAttrs& attrs) {
+      const auto node_param = nnvm::get<TRTParam>(attrs.parsed);
+      return node_param.output_map.size();
+    })
+    .set_attr_parser(TRTParamParser)
+    .set_attr<nnvm::FInferShape>("FInferShape", TRTInferShape)
+    .set_attr<nnvm::FInferType>("FInferType", TRTInferType)
+    .set_attr<nnvm::FListInputNames>("FListInputNames", TRTListInputNames)
+    .set_attr<nnvm::FListOutputNames>("FListOutputNames", TRTListOutputNames)
+    .set_attr<FCreateOpState>("FCreateOpState", TRTCreateState)
+    .set_attr<FStatefulCompute>("FStatefulCompute<cpu>", TRTCompute<cpu>)
+    .set_attr<FInferStorageType>("FInferStorageType", TRTInferStorageType);
+
+}  // namespace op
+}  // namespace mxnet
+
+#endif  // MXNET_USE_TENSORRT
diff --git a/src/operator/contrib/tensorrt.cu b/src/operator/contrib/tensorrt.cu
new file mode 100644
index 00000000000..5211b0b9b03
--- /dev/null
+++ b/src/operator/contrib/tensorrt.cu
@@ -0,0 +1,74 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2018 by Contributors
+ * \file tensorrt.cu
+ * \brief TensorRT GPU operation
+ * \author Marek Kolodziej, Clement Fuji Tsang
+*/
+
+#if MXNET_USE_TENSORRT
+
+#include "./tensorrt-inl.h"
+
+namespace mxnet {
+namespace op {
+
+#define CHECK_CUDART(x) do { \
+  cudaError_t res = (x); \
+  if (res != cudaSuccess) { \
+    fprintf(stderr, "CUDART: %s = %d (%s) at (%s:%d)\n", \
+      #x, res, cudaGetErrorString(res), __FILE__, __LINE__); \
+    exit(1); \
+  } \
+} while (0)
+
+template<>
+void TRTCompute<gpu>(const OpStatePtr& state, const OpContext& ctx,
+                     const std::vector<TBlob>& inputs, const std::vector<OpReqType>& req,
+                     const std::vector<TBlob>& outputs) {
+  using namespace mshadow;
+  using namespace mshadow::expr;
+
+  Stream<gpu>* s = ctx.get_stream<gpu>();
+  cudaStream_t cuda_s = Stream<gpu>::GetStream(s);
+  const auto& param = state.get_state<TRTEngineParam>();
+  std::vector<void*> bindings;
+  bindings.reserve(param.binding_map.size());
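+  // binding_map preserves the engine's binding order, so device pointers are
+  // pushed in exactly the order enqueue() expects.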
+  for (auto& p : param.binding_map) {
+    if (p.second == tensorrt::TypeIO::Inputs) {
+      bindings.emplace_back(inputs[p.first].dptr_);
+    } else {
+      bindings.emplace_back(outputs[p.first].dptr_);
+    }
+  }
+
+  const int batch_size = static_cast<int>(inputs[0].shape_[0]);
+  param.trt_executor->enqueue(batch_size, bindings.data(), cuda_s, nullptr);
+  CHECK_CUDART(cudaStreamSynchronize(cuda_s));
+}
+
+NNVM_REGISTER_OP(_trt_op)
+.set_attr<FStatefulCompute>("FStatefulCompute<gpu>", TRTCompute<gpu>);
+
+}  // namespace op
+}  // namespace mxnet
+
+#endif  // MXNET_USE_TENSORRT
diff --git a/tests/python/tensorrt/test_tensorrt_lenet5.py b/tests/python/tensorrt/test_tensorrt_lenet5.py
new file mode 100644
index 00000000000..d51814fff89
--- /dev/null
+++ b/tests/python/tensorrt/test_tensorrt_lenet5.py
@@ -0,0 +1,175 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import os
+#pylint: disable=unused-import
+import unittest
+#pylint: enable=unused-import
+import numpy as np
+import mxnet as mx
+from ctypes.util import find_library
+from mxnet.cuda_utils import get_device_count
+
+assert get_device_count() > 0, "No GPUs available to test TensorRT"
+
+assert find_library('nvinfer') is not None, "Can't find the TensorRT shared library"
+
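+# The MXNET_USE_TENSORRT environment variable toggles MXNet's TensorRT graph
+# pass at bind time, so the same symbol can be run with and without TensorRT.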
+def get_use_tensorrt():
+    return int(os.environ.get("MXNET_USE_TENSORRT", 0))
+
+def set_use_tensorrt(status=False):
+    os.environ["MXNET_USE_TENSORRT"] = str(int(status))
+
+def merge_dicts(*dict_args):
+    """Merge arg_params and aux_params to populate shared_buffer"""
+    result = {}
+    for dictionary in dict_args:
+        result.update(dictionary)
+    return result
+
+def get_iters(mnist, batch_size):
+    """Get MNIST iterators."""
+    train_iter = mx.io.NDArrayIter(mnist['train_data'],
+                                   mnist['train_label'],
+                                   batch_size,
+                                   shuffle=True)
+    val_iter = mx.io.NDArrayIter(mnist['test_data'], mnist['test_label'], batch_size)
+    test_iter = mx.io.NDArrayIter(mnist['test_data'], mnist['test_label'], batch_size)
+    all_test_labels = np.array(mnist['test_label'])
+    return train_iter, val_iter, test_iter, all_test_labels
+
+def lenet5():
+    """LeNet-5 Symbol"""
+    #pylint: disable=no-member
+    data = mx.sym.Variable('data')
+    conv1 = mx.sym.Convolution(data=data, kernel=(5, 5), num_filter=20)
+    tanh1 = mx.sym.Activation(data=conv1, act_type="tanh")
+    pool1 = mx.sym.Pooling(data=tanh1, pool_type="max",
+                           kernel=(2, 2), stride=(2, 2))
+    # second conv
+    conv2 = mx.sym.Convolution(data=pool1, kernel=(5, 5), num_filter=50)
+    tanh2 = mx.sym.Activation(data=conv2, act_type="tanh")
+    pool2 = mx.sym.Pooling(data=tanh2, pool_type="max",
+                           kernel=(2, 2), stride=(2, 2))
+    # first fullc
+    flatten = mx.sym.Flatten(data=pool2)
+    fc1 = mx.sym.FullyConnected(data=flatten, num_hidden=500)
+    tanh3 = mx.sym.Activation(data=fc1, act_type="tanh")
+    # second fullc
+    fc2 = mx.sym.FullyConnected(data=tanh3, num_hidden=10)
+    # loss
+    lenet = mx.sym.SoftmaxOutput(data=fc2, name='softmax')
+    #pylint: enable=no-member
+    return lenet
+
+def train_lenet5(num_epochs, batch_size, train_iter, val_iter, test_iter):
+    """train LeNet-5 model on MNIST data"""
+    ctx = mx.gpu(0)
+    lenet_model = mx.mod.Module(lenet5(), context=ctx)
+
+    lenet_model.fit(train_iter,
+                    eval_data=val_iter,
+                    optimizer='sgd',
+                    optimizer_params={'learning_rate': 0.1, 'momentum': 0.9},
+                    eval_metric='acc',
+                    batch_end_callback=mx.callback.Speedometer(batch_size, 1),
+                    num_epoch=num_epochs)
+
+    # predict accuracy for lenet
+    acc = mx.metric.Accuracy()
+    lenet_model.score(test_iter, acc)
+    accuracy = acc.get()[1]
+    assert accuracy > 0.95, "LeNet-5 training accuracy on MNIST was too low"
+    return lenet_model
+
+def run_inference(sym, arg_params, aux_params, mnist, all_test_labels, batch_size):
+    """Run inference with either MxNet or TensorRT"""
+
+    shared_buffer = merge_dicts(arg_params, aux_params)
+    if not get_use_tensorrt():
+        shared_buffer = dict([(k, v.as_in_context(mx.gpu(0))) for k, v in shared_buffer.items()])
+    executor = sym.simple_bind(ctx=mx.gpu(0),
+                               data=(batch_size,) + mnist['test_data'].shape[1:],
+                               softmax_label=(batch_size,),
+                               shared_buffer=shared_buffer,
+                               grad_req='null',
+                               force_rebind=True)
+
+    # Test-set size and number of classes; ideally these would be derived
+    # from all_test_labels and the dataset rather than hardcoded.
+    num_ex = 10000
+    all_preds = np.zeros([num_ex, 10])
+    test_iter = mx.io.NDArrayIter(mnist['test_data'], mnist['test_label'], batch_size)
+
+    example_ct = 0
+
+    for idx, dbatch in enumerate(test_iter):
+        executor.arg_dict["data"][:] = dbatch.data[0]
+        executor.forward(is_train=False)
+        offset = idx*batch_size
+        extent = batch_size if num_ex - offset > batch_size else num_ex - offset
+        all_preds[offset:offset+extent, :] = executor.outputs[0].asnumpy()[:extent]
+        example_ct += extent
+
+    all_preds = np.argmax(all_preds, axis=1)
+    matches = (all_preds[:example_ct] == all_test_labels[:example_ct]).sum()
+
+    percentage = 100.0 * matches / example_ct
+
+    return percentage
+
+
+def test_tensorrt_inference():
+    """Run inference comparison between MXNet and MXNet-TensorRT.
+       This could be used stand-alone or with nosetests."""
+    mnist = mx.test_utils.get_mnist()
+    num_epochs = 10
+    batch_size = 1024
+    model_name = 'lenet5'
+    model_file = '%s-symbol.json' % model_name
+    params_file = '%s-%04d.params' % (model_name, num_epochs)
+
+    _, _, _, all_test_labels = get_iters(mnist, batch_size)
+
+    if not (os.path.exists(model_file) and os.path.exists(params_file)):
+        trained_lenet = train_lenet5(num_epochs, batch_size,
+                                     *get_iters(mnist, batch_size)[:-1])
+        trained_lenet.save_checkpoint(model_name, num_epochs)
+
+    # Load serialized MxNet model (model-symbol.json + model-epoch.params)
+    sym, arg_params, aux_params = mx.model.load_checkpoint(model_name, num_epochs)
+
+    print("Running inference in MxNet")
+    set_use_tensorrt(False)
+    mx_pct = run_inference(sym, arg_params, aux_params, mnist,
+                           all_test_labels, batch_size=batch_size)
+
+    print("Running inference in MxNet-TensorRT")
+    set_use_tensorrt(True)
+    trt_pct = run_inference(sym, arg_params, aux_params, mnist,
+                            all_test_labels,  batch_size=batch_size)
+
+    print("MxNet accuracy: %f" % mx_pct)
+    print("MxNet-TensorRT accuracy: %f" % trt_pct)
+
+    assert abs(mx_pct - trt_pct) < 1e-2, \
+        """Diff. between MxNet & TensorRT accuracy too high:
+           MxNet = %f, TensorRT = %f""" % (mx_pct, trt_pct)
+
+if __name__ == '__main__':
+    test_tensorrt_inference()
+
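
For reference, the TensorRT path exercised by run_inference above reduces to the
sketch below. It assumes MXNet was built with USE_TENSORRT=1 and that sym,
arg_params, aux_params, batch_size and a data batch already exist (e.g. from the
LeNet-5 checkpoint used in this test); it illustrates the intended API and is not
part of the patch itself:

    import os
    import mxnet as mx

    os.environ["MXNET_USE_TENSORRT"] = "1"       # enable the TensorRT graph pass
    all_params = dict(arg_params, **aux_params)  # weights go in via shared_buffer
    executor = sym.simple_bind(ctx=mx.gpu(0),
                               data=(batch_size,) + mnist['test_data'].shape[1:],
                               softmax_label=(batch_size,),
                               shared_buffer=all_params,
                               grad_req='null',
                               force_rebind=True)
    executor.arg_dict["data"][:] = batch         # copy one batch and run inference
    executor.forward(is_train=False)
    preds = executor.outputs[0].asnumpy()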


 
