Posted to commits@mxnet.apache.org by pt...@apache.org on 2021/03/05 16:53:55 UTC

[incubator-mxnet] branch v1.8.x updated: Backport TRT fixes to 1.8 (#19983)

This is an automated email from the ASF dual-hosted git repository.

ptrendx pushed a commit to branch v1.8.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/v1.8.x by this push:
     new c16aa91  Backport TRT fixes to 1.8 (#19983)
c16aa91 is described below

commit c16aa91765e56aa6f411e8c3e51b6e6e3c939fe9
Author: Serge Panev <sp...@nvidia.com>
AuthorDate: Sat Mar 6 01:52:18 2021 +0900

    Backport TRT fixes to 1.8 (#19983)
    
    * Update MXNet-TRT doc with the new optimize_for API
    
    Signed-off-by: Serge Panev <sp...@nvidia.com>
    
    * [1.x] Move block.optimize_for backend_opts to kwargs (#19386)
    
    * Move block.optimize_for backend_opts to kwargs
    
    Signed-off-by: Serge Panev <sp...@nvidia.com>
    
    * Update Hybridize to use kwargs as backend opts
    
    Signed-off-by: Serge Panev <sp...@nvidia.com>
    
    * Fix lint
    
    Signed-off-by: Serge Panev <sp...@nvidia.com>
    
    * Change clear default to False and allow hybridize+optimize_for calls
    
    Signed-off-by: Serge Panev <sp...@nvidia.com>
    
    * Fix nit
    
    Signed-off-by: Serge Panev <sp...@nvidia.com>
    
    * Address review comments
    
    Signed-off-by: Serge Panev <sp...@nvidia.com>
    
    * Address more review comments
    
    Signed-off-by: Serge Panev <sp...@nvidia.com>
    
    * Address more more review comments
    
    Signed-off-by: Serge Panev <sp...@nvidia.com>
    
    * Fix nit
    
    Signed-off-by: Serge Panev <sp...@nvidia.com>
    
    * Add 1:many conversions in nnvm_to_onnx and non-flatten GEMM (#19652)
    
    Signed-off-by: Serge Panev <sp...@nvidia.com>
---
 .../performance/backend/tensorrt/tensorrt.md       |  97 ++---
 example/extensions/lib_pass/README.md              |   6 +-
 example/extensions/lib_subgraph/README.md          |   6 +-
 example/extensions/lib_subgraph/test_subgraph.py   |   6 +-
 python/mxnet/gluon/block.py                        |  72 +++-
 src/operator/subgraph/tensorrt/nnvm_to_onnx-inl.h  |  69 +++-
 src/operator/subgraph/tensorrt/nnvm_to_onnx.cc     | 389 ++++++++++++++++-----
 tests/python/unittest/test_gluon.py                |   4 -
 8 files changed, 467 insertions(+), 182 deletions(-)

diff --git a/docs/python_docs/python/tutorials/performance/backend/tensorrt/tensorrt.md b/docs/python_docs/python/tutorials/performance/backend/tensorrt/tensorrt.md
index 44082f9..5b0775a 100644
--- a/docs/python_docs/python/tutorials/performance/backend/tensorrt/tensorrt.md
+++ b/docs/python_docs/python/tutorials/performance/backend/tensorrt/tensorrt.md
@@ -33,74 +33,81 @@ from mxnet.gluon.model_zoo import vision
 import time
 import os
 
+ctx = mx.gpu(0)
+
 batch_shape = (1, 3, 224, 224)
-resnet18 = vision.resnet18_v2(pretrained=True)
-resnet18.hybridize()
-resnet18.forward(mx.nd.zeros(batch_shape))
-resnet18.export('resnet18_v2')
-sym, arg_params, aux_params = mx.model.load_checkpoint('resnet18_v2', 0)
+x = mx.nd.zeros(batch_shape, ctx=ctx)
+
+model = vision.resnet18_v2(pretrained=True, ctx=ctx)
+model.hybridize(static_shape=True, static_alloc=True)
+
 ```
-In our first section of code we import the modules needed to run MXNet, and to time our benchmark runs.  We then download a pretrained version of Resnet18, hybridize it, and load it symbolically.  It's important to note that the experimental version of TensorRT integration will only work with the symbolic MXNet API. If you're using Gluon, you must [hybridize](https://gluon.mxnet.io/chapter07_distributed-learning/hybridize.html) your computation graph and export it as a symbol before runn [...]
+In our first section of code we import the modules needed to run MXNet and to time our benchmark runs.  We then download a pretrained version of ResNet-18 and [hybridize](https://gluon.mxnet.io/chapter07_distributed-learning/hybridize.html) it with static_alloc and static_shape to get the best performance.
 
 ## MXNet Baseline Performance
 ```python
-# Create sample input
-input = mx.nd.zeros(batch_shape)
-
-# Execute with MXNet
-executor = sym.simple_bind(ctx=mx.gpu(0), data=batch_shape, grad_req='null', force_rebind=True)
-executor.copy_params_from(arg_params, aux_params)
-
 # Warmup
-print('Warming up MXNet')
-for i in range(0, 10):
-    y_gen = executor.forward(is_train=False, data=input)
-    y_gen[0].wait_to_read()
+for i in range(0, 1000):
+	out = model(x)
+	mx.nd.waitall()
 
 # Timing
-print('Starting MXNet timed run')
-start = time.process_time()
+start = time.time()
 for i in range(0, 10000):
-    y_gen = executor.forward(is_train=False, data=input)
-    y_gen[0].wait_to_read()
-end = time.time()
-print(time.process_time() - start)
+	out = model(x)
+	mx.nd.waitall()
+print(time.time() - start)
 ```
 
-We are interested in inference performance, so to simplify the benchmark we'll pass a tensor filled with zeros as an input.  We bind a symbol as usual, returning an MXNet executor, and we run forward on this executor in a loop.  To help improve the accuracy of our benchmarks we run a small number of predictions as a warmup before running our timed loop.  On a modern PC with an RTX 2070 GPU the time taken for our MXNet baseline is **17.20s**.  Next we'll run the same model with TensorRT e [...]
+For this experiment we are strictly interested in inference performance, so to simplify the benchmark we'll pass a tensor filled with zeros as an input. 
+To help improve the accuracy of our benchmarks we run a small number of predictions as a warmup before running our timed loop. This will ensure various lazy operations, which do not represent real-world usage, have completed before we measure relative performance improvement. On a system with a V100 GPU, the time taken for our MXNet baseline is **19.5s** (512 samples/s).
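+
+The samples/s figures quoted in this tutorial are derived directly from the loop count, the batch size, and the measured wall-clock time, for example:
+
+```python
+# Throughput of the baseline run: iterations * batch size / elapsed seconds
+elapsed = 19.5  # measured wall-clock time in seconds
+print('%.1f samples/s' % (10000 * batch_shape[0] / elapsed))  # ~512 samples/s
+```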
 
 ## MXNet with TensorRT Integration Performance
 ```python
-# Execute with TensorRT
-print('Building TensorRT engine')
-trt_sym = sym.get_backend_symbol('TensorRT')
-arg_params, aux_params = mx.contrib.tensorrt.init_tensorrt_params(trt_sym, arg_params, aux_params)
-mx.contrib.tensorrt.set_use_fp16(True)
-executor = trt_sym.simple_bind(ctx=mx.gpu(), data=batch_shape,
-                               grad_req='null', force_rebind=True)
-executor.copy_params_from(arg_params, aux_params)
+[...]
+
+model.optimize_for(x, backend='TensorRT', static_alloc=True, static_shape=True)
+
+[...]
 ```
 
-We use a few TensorRT specific API calls from the contrib package here to setup our parameters and indicate we'd like to run inference in fp16 mode. We then call simple_bind as normal and copy our parameter dictionaries to our executor.
+Next we'll run the same model with TensorRT enabled, and see how the performance compares.
+
+To use TensorRT optimization with the Gluon API, we need to call `optimize_for` with the TensorRT backend and provide some input data that will be used to infer shapes and types (any sample representative of the inference data). The TensorRT backend only supports static shapes, so we need to set `static_alloc` and `static_shape` to `True`.
+
+This will run the subgraph partitioning and replace TensorRT-compatible subgraphs with TensorRT ops containing the TensorRT engines. The model is then ready to be used.
 
 ```python
-#Warmup
-print('Warming up TensorRT')
-for i in range(0, 10):
-    y_gen = executor.forward(is_train=False, data=input)
-    y_gen[0].wait_to_read()
+# Warmup
+for i in range(0, 1000):
+	out = model(x)
+	out[0].wait_to_read()
 
 # Timing
-print('Starting TensorRT timed run')
-start = time.process_time()
+start = time.time()
 for i in range(0, 10000):
-    y_gen = executor.forward(is_train=False, data=input)
-    y_gen[0].wait_to_read()
-end = time.time()
-print(time.process_time() - start)
+	out = model(x)
+	out[0].wait_to_read()
+print(time.time() - start)
 ```
 
-We run timing with a warmup once more, and on the same machine, run in **9.83s**. A 1.75x speed improvement!  Speed improvements when using libraries like TensorRT can come from a variety of optimizations, but in this case our speedups are coming from a technique known as [operator fusion](http://ziheng.org/2016/11/21/fusion-and-runtime-compilation-for-nnvm-and-tinyflow/).
+We run timing with a warmup once again, and on the same machine the run completes in **12.7s** (787 samples/s). That's a 1.5x speed improvement!  Speed improvements when using libraries like TensorRT can come from a variety of optimizations, but in this case our speedups are coming from a technique known as [operator fusion](http://ziheng.org/2016/11/21/fusion-and-runtime-compilation-for-nnvm-and-tinyflow/).
+
+## FP16
+
+We can get a further speedup by turning on TensorRT FP16. This optimization comes almost for free: the only extra user effort is adding the `precision` option to the `optimize_for` call.
+
+```python
+[...]
+
+model.optimize_for(x, backend='TensorRT', static_alloc=True, static_shape=True, precision='fp16')
+
+[...]
+```
+
+We run timing with a warmup once more and get **7.8s** (1282 samples/s). That's a 2.5x speedup compared to the MXNet baseline!
+All the ops used in ResNet-18 are FP16-compatible, so the TensorRT engine was able to run FP16 kernels, hence the extra speedup.
+
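+For reference, here is a minimal end-to-end sketch assembling the snippets above into a single FP16 benchmark script:
+
+```python
+import mxnet as mx
+from mxnet.gluon.model_zoo import vision
+import time
+
+ctx = mx.gpu(0)
+batch_shape = (1, 3, 224, 224)
+x = mx.nd.zeros(batch_shape, ctx=ctx)
+
+model = vision.resnet18_v2(pretrained=True, ctx=ctx)
+# Partition the graph for TensorRT and build FP16 engines; x is only used to infer shapes and types
+model.optimize_for(x, backend='TensorRT', static_alloc=True, static_shape=True, precision='fp16')
+
+# Warmup
+for i in range(0, 1000):
+    out = model(x)
+    out[0].wait_to_read()
+
+# Timing
+start = time.time()
+for i in range(0, 10000):
+    out = model(x)
+    out[0].wait_to_read()
+print(time.time() - start)
+```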
 
 ## Operators and Subgraph Fusion
 
diff --git a/example/extensions/lib_pass/README.md b/example/extensions/lib_pass/README.md
index 18272c0..6f975fd 100644
--- a/example/extensions/lib_pass/README.md
+++ b/example/extensions/lib_pass/README.md
@@ -88,15 +88,15 @@ The `optimize_for` API takes at least 1 argument, `backend` which is a string th
 For the Gluon API, `hybridize` can be called on HybridBlocks to execute a graph pass on the internal CachedOp Symbol.
 
 ```python
-block.hybridize(backend=None, backend_opts=None, **kwargs)
+block.hybridize(backend=None, **kwargs)
 ```
 
-The `hybridize` function prepares the HybridBlock to be converted into a backend symbol. The `backend` argument is a string that identifies which pass that will be executed on the model. The `backend_opts` takes other user-specified options that will be passed to the backend APIs. The actual pass runs once just before the first the forward pass.
+The `hybridize` function prepares the HybridBlock to be converted into a backend symbol. The `backend` argument is a string that identifies which pass will be executed on the model. `**kwargs` may contain other user-specified options that will be passed to the backend APIs. The actual pass runs once, just before the first forward pass.
 
 If you just want to run a graph pass on the HybridBlock but not run a complete forward pass, you can use the `optimize_for` API that combines the work done in the `hybridize` API with part of the work done in the forward pass.
 
 ```python
-block.optimize_for(x, backend=None, backend_opts=None, **kwargs)
+block.optimize_for(x, backend=None, **kwargs)
 ```
 
 When the `optimize_for` API is called on a HybridBlock it runs the graph pass immediately. This lets users export the modified model without running a complete forward pass.
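+
+For example, a minimal sketch (the pass name `myPass` and the library file `libpass_lib.so` are placeholders for your own pass library):
+
+```python
+import mxnet as mx
+from mxnet.gluon import nn
+
+mx.library.load('libpass_lib.so')  # load the custom pass library
+
+block = nn.Dense(10)
+block.initialize()
+# The graph pass runs immediately; no complete forward pass is needed before exporting
+block.optimize_for(mx.nd.ones((1, 2)), backend='myPass')
+block.export('modified_model')
+```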
diff --git a/example/extensions/lib_subgraph/README.md b/example/extensions/lib_subgraph/README.md
index 2752d27..446b659 100644
--- a/example/extensions/lib_subgraph/README.md
+++ b/example/extensions/lib_subgraph/README.md
@@ -107,15 +107,15 @@ The `optimize_for` API takes at least 1 argument, `backend` which is a string th
 For the Gluon API, `hybridize` can be called on HybridBlocks to partition the internal CachedOp Symbol.
 
 ```python
-block.hybridize(backend=None, backend_opts=None, clear=True, **kwargs)
+block.hybridize(backend=None, clear=True, **kwargs)
 ```
 
-The `hybridize` function prepares the HybridBlock to be converted into a backend symbol. The `backend` argument is a string that identifies which backend that will partition the model. The `backend_opts` are other user-specified options (as a Python dictionary of strings mapped to strings) that will be passed to the backend partitioning APIs. The `clear` argument defaults to `True` and clears any previous optimizations done on the block. If you want to chain optimizations together, set ` [...]
+The `hybridize` function prepares the HybridBlock to be converted into a backend symbol. The `backend` argument is a string that identifies which backend will partition the model. `**kwargs` are other user-specified keyword options that will be passed to the backend partitioning APIs. The `clear` argument defaults to `False`, so it will chain optimizations together. If you want to clear any previous optimizations done on the block, set ` [...]
 
 If you just want to partition the HybridBlock but not run a complete forward pass, you can use the `optimize_for` API that combines the work done in the `hybridize` API with part of the work done in the forward pass.
 
 ```python
-block.optimize_for(x, backend=None, backend_opts=None, clear=True, **kwargs)
+block.optimize_for(x, backend=None, clear=False, **kwargs)
 ```
 
 When the `optimize_for` API is called on a HybridBlock it partitions immediately. This lets users export the partitioned model without running a complete forward pass. Chaining multiple optimizations is as simple as calling `optimize_for` multiple times, no need to execute a forward pass (as opposed to `hybridize`).
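+
+For example, a minimal chaining sketch (`myProp` and `addInputPass` stand for backends registered in the `SubgraphBackendRegistry`; keyword options such as `dedup_subgraph` are forwarded to the backend):
+
+```python
+# With clear defaulting to False, each call partitions on top of the previous result,
+# and no forward pass is needed in between
+block.optimize_for(x, backend='myProp', dedup_subgraph=True)
+block.optimize_for(x, backend='addInputPass', dedup_subgraph=True)
+block.export('partitioned')
+```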
diff --git a/example/extensions/lib_subgraph/test_subgraph.py b/example/extensions/lib_subgraph/test_subgraph.py
index a8b6690..ad400dd 100644
--- a/example/extensions/lib_subgraph/test_subgraph.py
+++ b/example/extensions/lib_subgraph/test_subgraph.py
@@ -92,7 +92,7 @@ def test(backend):
     inputs = [a,b]
     sym_block = nn.SymbolBlock(sym, inputs)
     sym_block.initialize()
-    sym_block.hybridize(backend=backend, backend_opts={'dedup_subgraph':True})
+    sym_block.hybridize(backend=backend, dedup_subgraph=True)
     out2 = sym_block(mx.nd.ones((3,2)),mx.nd.ones((3,2)))
     print(out2)
 
@@ -103,14 +103,14 @@ def test(backend):
     sym_block2 = nn.SymbolBlock(sym, inputs)
     sym_block2.initialize()
     sym_block2.optimize_for(mx.nd.ones((3,2)), mx.nd.ones((3,2)), backend=backend,
-                            backend_opts={'dedup_subgraph':True})
+                            dedup_subgraph=True)
     sym_block2.export('partitioned')
 
     # Test with additional input to subgraph op
     print('-------------------------------')
     print('Testing %s Gluon Hybridize partitioning with extra input' % backend)
     sym_block2.optimize_for(mx.nd.ones((3,2)), mx.nd.ones((3,2)), backend="addInputPass",
-                            clear=False, backend_opts={'dedup_subgraph':True})
+                            dedup_subgraph=True)
     out3 = sym_block2(mx.nd.ones((3,2)),mx.nd.ones((3,2)))
     print(out3)
     
diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py
index 10df150..edd3372 100644
--- a/python/mxnet/gluon/block.py
+++ b/python/mxnet/gluon/block.py
@@ -1080,7 +1080,13 @@ class HybridBlock(Block):
             out = [out]
         return _regroup(out, self._out_format)
 
-    def optimize_for(self, x, *args, backend=None, backend_opts=None, clear=True, **kwargs):
+    def optimize_for(self, x, *args, backend=None, clear=False,
+                     static_alloc=False,
+                     static_shape=False,
+                     inline_limit=2,
+                     forward_bulk_size=None,
+                     backward_bulk_size=None,
+                     **kwargs):
         """Partitions the current HybridBlock and optimizes it for a given backend
         without executing a forward pass. Modifies the HybridBlock in-place.
 
@@ -1108,19 +1114,29 @@ class HybridBlock(Block):
             other inputs to model
         backend : str
             The name of backend, as registered in `SubgraphBackendRegistry`, default None
-        backend_opts : dict of user-specified options to pass to the backend for partitioning, optional
-            Passed on to `PrePartition` and `PostPartition` functions of `SubgraphProperty`
-        clear : clears any previous optimizations
+        clear : bool, default False
+            Clears any previous optimizations
         static_alloc : bool, default False
             Statically allocate memory to improve speed. Memory usage may increase.
         static_shape : bool, default False
             Optimize for invariant input shapes between iterations. Must also
             set static_alloc to True. Change of input shapes is still allowed
             but slower.
+        inline_limit : optional int, default 2
+            Maximum number of operators that can be inlined.
+        forward_bulk_size : optional int, default None
+            Segment size of bulk execution during forward pass.
+        backward_bulk_size : optional int, default None
+            Segment size of bulk execution during backward pass.
+        **kwargs : optional
+            The backend options, passed on to `PrePartition` and `PostPartition`
+            functions of `SubgraphProperty`
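+
+        Examples
+        --------
+        A sketch, assuming a backend named 'myBackend' is registered in the
+        `SubgraphBackendRegistry`; extra keyword options are forwarded to the backend:
+
+        >>> x = mx.nd.ones((1, 3, 224, 224))
+        >>> block.optimize_for(x, backend='myBackend', dedup_subgraph=True)
+        >>> block.export('partitioned')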
         """
+        if len(kwargs) > 0:
+            self._backend_opts = kwargs
 
-        # do hybrize API call
-        self.hybridize(True, backend, backend_opts, clear, **kwargs)
+        if clear or not self._active:
+            self.hybridize(True, backend, clear, static_alloc, static_shape,
+                           inline_limit, forward_bulk_size, backward_bulk_size)
 
         # do part of forward API call
         has_symbol, has_ndarray, ctx_set, _ = _gather_type_ctx_info([x] + list(args))
@@ -1155,7 +1171,12 @@ class HybridBlock(Block):
         super(HybridBlock, self).register_child(block, name)
         self._clear_cached_op()
 
-    def hybridize(self, active=True, backend=None, backend_opts=None, clear=True, **kwargs):
+    def hybridize(self, active=True, backend=None, clear=True,
+                  static_alloc=False, static_shape=False,
+                  inline_limit=2,
+                  forward_bulk_size=None,
+                  backward_bulk_size=None,
+                  **kwargs):
         """Activates or deactivates :py:class:`HybridBlock` s recursively. Has no effect on
         non-hybrid children.
 
@@ -1165,32 +1186,47 @@ class HybridBlock(Block):
             Whether to turn hybrid on or off.
         backend : str
             The name of backend, as registered in `SubgraphBackendRegistry`, default None
-        backend_opts : dict of user-specified options to pass to the backend for partitioning, optional
-            Passed on to `PrePartition` and `PostPartition` functions of `SubgraphProperty`
-        clear : clears any previous optimizations
-        static_alloc : bool, default False
+        clear : bool, default True
+            Clears any previous optimizations
+        static_alloc : optional bool, default False
             Statically allocate memory to improve speed. Memory usage may increase.
-        static_shape : bool, default False
+        static_shape : optional bool, default False
             Optimize for invariant input shapes between iterations. Must also
             set static_alloc to True. Change of input shapes is still allowed
             but slower.
+        inline_limit : optional int, default 2
+            Maximum number of operators that can be inlined.
+        forward_bulk_size : optional int, default None
+            Segment size of bulk execution during forward pass.
+        backward_bulk_size : optional int, default None
+            Segment size of bulk execution during backward pass.
+        **kwargs : optional
+            Backend options.
         """
+        if len(kwargs) > 0:
+            self._backend_opts = kwargs
 
         self._backend = backend
-        if backend_opts is not None:
-            assert isinstance(backend_opts, dict), \
-            "HybridBlock hybridize requires backend_opts to be a dictionary."
-            self._backend_opts = backend_opts
 
         self._active = active
-        self._flags = list(kwargs.items())
+        self._flags = [("static_alloc", static_alloc), ("static_shape", static_shape),
+                       ("inline_limit", inline_limit)]
+        if forward_bulk_size is not None:
+            self._flags.append(("forward_bulk_size", forward_bulk_size))
+        if backward_bulk_size is not None:
+            self._flags.append(("backward_bulk_size", backward_bulk_size))
         if clear:
             self._clear_cached_op()
         if active and self._forward_hooks or self._forward_pre_hooks:
             warnings.warn('"{block}" is being hybridized while still having forward hook/pre-hook. '
                           'If "{block}" is a child of HybridBlock, the hooks will not take effect.'
                           .format(block=self))
-        super(HybridBlock, self).hybridize(active, **kwargs)
+        super(HybridBlock, self).hybridize(active,
+                                           static_alloc=static_alloc,
+                                           static_shape=static_shape,
+                                           inline_limit=inline_limit,
+                                           forward_bulk_size=forward_bulk_size,
+                                           backward_bulk_size=backward_bulk_size)
 
     def cast(self, dtype):
         self._clear_cached_op()
diff --git a/src/operator/subgraph/tensorrt/nnvm_to_onnx-inl.h b/src/operator/subgraph/tensorrt/nnvm_to_onnx-inl.h
index d444e7a..be6ebd0 100644
--- a/src/operator/subgraph/tensorrt/nnvm_to_onnx-inl.h
+++ b/src/operator/subgraph/tensorrt/nnvm_to_onnx-inl.h
@@ -47,7 +47,8 @@ using namespace nnvm;
 using namespace ::onnx;
 using int64 = ::google::protobuf::int64;
 
-std::unordered_map<std::string, mxnet::TShape> GetPlaceholderShapes(const ShapeVector& shape_inputs,
+std::unordered_map<std::string, mxnet::TShape> GetPlaceholderShapes(
+    const ShapeVector& shape_inputs,
     const nnvm::IndexedGraph& ig);
 
 std::unordered_map<std::string, int> GetPlaceholderDTypes(const DTypeVector& dtype_inputs,
@@ -70,7 +71,12 @@ void ConvertOutput(GraphProto* graph_proto,
                    const std::string& node_name, const ShapeVector& shapes,
                    const DTypeVector& dtypes, const nnvm::IndexedGraph &ig);
 
-typedef void (*ConverterFunction)(NodeProto *node_proto,
+void DefaultConnectInputsOutputs(NodeProto* node_proto,
+                                 const array_view<IndexedGraph::NodeEntry>& inputs,
+                                 const nnvm::IndexedGraph& ig,
+                                 const std::string& node_name);
+
+typedef void (*ConverterFunction)(GraphProto *graph_proto,
+                                  const std::string& node_name,
                                   const NodeAttrs &attrs,
                                   const nnvm::IndexedGraph &ig,
                                   const array_view<IndexedGraph::NodeEntry> &inputs);
@@ -84,88 +90,112 @@ void ConvDeconvConvertHelper(NodeProto *node_proto,
                              ConvDeconvType type);
 
 // Forward declarations
-void ConvertIdentity(NodeProto* node_proto,
+void ConvertIdentity(GraphProto *graph_proto,
+                     const std::string& node_name,
                      const NodeAttrs &attrs,
                      const nnvm::IndexedGraph& ig,
                      const array_view<IndexedGraph::NodeEntry> &inputs);
 
 void ConvertConvolution(
-                        NodeProto *node_proto,
+                        GraphProto *graph_proto,
+                        const std::string& node_name,
                         const NodeAttrs &attrs,
                         const nnvm::IndexedGraph &ig,
                         const array_view<IndexedGraph::NodeEntry> &inputs);
 
-void ConvertDeconvolution(NodeProto *node_proto,
+void ConvertDeconvolution(GraphProto *graph_proto,
+                        const std::string& node_name,
                         const NodeAttrs &attrs,
                         const nnvm::IndexedGraph &ig,
                         const array_view<IndexedGraph::NodeEntry> &inputs);
 
-void ConvertPooling(NodeProto *node_proto,
+void ConvertPooling(GraphProto *graph_proto,
+                    const std::string& node_name,
                     const NodeAttrs &attrs,
                     const nnvm::IndexedGraph &ig,
                     const array_view<IndexedGraph::NodeEntry> &inputs);
 
-void ConvertRelu(NodeProto *node_proto,
+void ConvertRelu(GraphProto *graph_proto,
+                 const std::string& node_name,
                  const NodeAttrs &attrs,
                  const nnvm::IndexedGraph &ig,
                  const array_view<IndexedGraph::NodeEntry> &inputs);
 
-void ConvertActivation(NodeProto *node_proto,
+void ConvertActivation(GraphProto *graph_proto,
+                       const std::string& node_name,
                        const NodeAttrs &attrs,
                        const nnvm::IndexedGraph &ig,
                        const array_view<IndexedGraph::NodeEntry> &inputs);
 
-void ConvertFullyConnected(NodeProto *node_proto,
+void ConvertFullyConnected(GraphProto *graph_proto,
+                           const std::string& node_name,
                            const NodeAttrs &attrs,
                            const nnvm::IndexedGraph &ig,
                            const array_view<IndexedGraph::NodeEntry> &inputs);
 
-void ConvertSoftmaxOutput(NodeProto *node_proto,
+
+void ConvertSlice(GraphProto *graph_proto,
+                  const std::string& node_name,
+                  const NodeAttrs &attrs,
+                  const nnvm::IndexedGraph &ig,
+                  const array_view<IndexedGraph::NodeEntry> &inputs);
+
+void ConvertSoftmaxOutput(GraphProto *graph_proto,
+                          const std::string& node_name,
                           const NodeAttrs &attrs,
                           const nnvm::IndexedGraph &ig,
                           const array_view<IndexedGraph::NodeEntry> &inputs);
 
-void ConvertFlatten(NodeProto *node_proto,
+void ConvertFlatten(GraphProto *graph_proto,
+                    const std::string& node_name,
                     const NodeAttrs &attrs,
                     const nnvm::IndexedGraph &ig,
                     const array_view<IndexedGraph::NodeEntry> &inputs);
 
-void ConvertDropout(NodeProto *node_proto,
+void ConvertDropout(GraphProto *graph_proto,
+                    const std::string& node_name,
                     const NodeAttrs &attrs,
                     const nnvm::IndexedGraph &ig,
                     const array_view<IndexedGraph::NodeEntry> &inputs);
 
-void ConvertBatchNorm(NodeProto *node_proto,
+void ConvertBatchNorm(GraphProto *graph_proto,
+                    const std::string& node_name,
                     const NodeAttrs &attrs,
                     const nnvm::IndexedGraph &ig,
                     const array_view<IndexedGraph::NodeEntry> &inputs);
 
-void ConvertElementwiseAdd(NodeProto *node_proto,
+void ConvertElementwiseAdd(GraphProto *graph_proto,
+                    const std::string& node_name,
                     const NodeAttrs &attrs,
                     const nnvm::IndexedGraph &ig,
                     const array_view<IndexedGraph::NodeEntry> &inputs);
 
-void ConvertElementwiseMul(NodeProto *node_proto,
+void ConvertElementwiseMul(GraphProto *graph_proto,
+                    const std::string& node_name,
                     const NodeAttrs &attrs,
                     const nnvm::IndexedGraph &ig,
                     const array_view<IndexedGraph::NodeEntry> &inputs);
 
-void ConvertElementwiseSub(NodeProto *node_proto,
+void ConvertElementwiseSub(GraphProto *graph_proto,
+                    const std::string& node_name,
                     const NodeAttrs &attrs,
                     const nnvm::IndexedGraph &ig,
                     const array_view<IndexedGraph::NodeEntry> &inputs);
 
-void ConvertConcatenate(NodeProto *node_proto,
+void ConvertConcatenate(GraphProto *graph_proto,
+                    const std::string& node_name,
                     const NodeAttrs &attrs,
                     const nnvm::IndexedGraph &ig,
                     const array_view<IndexedGraph::NodeEntry> &inputs);
 
-void ConvertClip(NodeProto *node_proto,
+void ConvertClip(GraphProto *graph_proto,
+                 const std::string& node_name,
                  const NodeAttrs &attrs,
                  const nnvm::IndexedGraph &ig,
                  const array_view<IndexedGraph::NodeEntry> &inputs);
 
-void ConvertPad(NodeProto* node_proto,
+void ConvertPad(GraphProto *graph_proto,
+                const std::string& node_name,
                 const NodeAttrs & attrs,
                 const nnvm::IndexedGraph &ig,
                 const array_view<IndexedGraph::NodeEntry> &inputs);
@@ -190,6 +220,7 @@ static const std::unordered_map<std::string, ConverterFunction> converter_map =
   {"Pad", ConvertPad},
   {"Pooling", ConvertPooling},
   {"relu", ConvertRelu},
+  {"slice", ConvertSlice},
   {"SoftmaxOutput", ConvertSoftmaxOutput}
 };
 
diff --git a/src/operator/subgraph/tensorrt/nnvm_to_onnx.cc b/src/operator/subgraph/tensorrt/nnvm_to_onnx.cc
index 4f80d27..cdc7151 100644
--- a/src/operator/subgraph/tensorrt/nnvm_to_onnx.cc
+++ b/src/operator/subgraph/tensorrt/nnvm_to_onnx.cc
@@ -130,8 +130,6 @@ std::string ConvertNnvmGraphToOnnx(
       }  // is_placeholder
     } else {
       // It's an op, rather than a "variable" (constant or placeholder)
-      NodeProto* node_proto = graph_proto->add_node();
-      node_proto->set_name(node_name);
       if (converter_map.count(op->name) == 0) {
         LOG(FATAL) << "Conversion for node of type " << op->name << " (node "
                    << node_name << ") "
@@ -140,19 +138,7 @@ std::string ConvertNnvmGraphToOnnx(
       // Find function ptr to a converter based on the op name, and invoke the converter. This
       // looks unsafe because find may not succeed, but it does because we're in the operator
       // logic after testing that this node name does not represent a variable.
-      converter_map.find(op->name)->second(node_proto, attrs, ig, node.inputs);
-      // Add all inputs to the current node (i.e. add graph edges)
-      for (const nnvm::IndexedGraph::NodeEntry& entry : node.inputs) {
-        std::string in_node_name = ig[entry.node_id].source->attrs.name;
-        // As before, we're not adding labels e.g. for SoftmaxOutput, but I wish there was a less
-        // hacky way to do it than name matching.
-        if (in_node_name.find("label") != std::string::npos) {
-          continue;
-        }
-        node_proto->add_input(in_node_name);
-      }
-      // The node's output will have the same name as the node name.
-      node_proto->add_output(node_name);
+      converter_map.find(op->name)->second(graph_proto, node_name, attrs, ig, node.inputs);
       // See if the current node is an output node
       auto out_iter = output_lookup.find(node_name);
       // We found an output
@@ -171,16 +157,113 @@ std::string ConvertNnvmGraphToOnnx(
   return serialized_onnx_graph;
 }
 
-void ConvertIdentity(NodeProto* node_proto, const NodeAttrs& attrs,
-                     const nnvm::IndexedGraph& /*ig*/,
-                     const array_view<IndexedGraph::NodeEntry>& /*inputs*/) {
+void DefaultConnectInputsOutputs(NodeProto *node_proto,
+                                 const array_view<IndexedGraph::NodeEntry>& inputs,
+                                 const nnvm::IndexedGraph& ig,
+                                 const std::string& node_name) {
+  for (const nnvm::IndexedGraph::NodeEntry& entry : inputs) {
+    std::string in_node_name = ig[entry.node_id].source->attrs.name;
+    // As before, we're not adding labels e.g. for SoftmaxOutput, but I wish there was a less
+    // hacky way to do it than name matching.
+    if (in_node_name.find("label") != std::string::npos) {
+      continue;
+    }
+    node_proto->add_input(in_node_name);
+  }
+  // The node's output will have the same name as the node name.
+  node_proto->add_output(node_name);
+}
+
+TensorProto* const Make1DTensor(GraphProto* const graph_proto, const int64_t& size,
+                                const std::string& name, const TensorProto_DataType& dtype) {
+  TensorProto* const initializer_proto = graph_proto->add_initializer();
+  initializer_proto->set_name(name);
+  initializer_proto->set_data_type(dtype);
+  initializer_proto->add_dims(static_cast<int64>(size));
+
+  ValueInfoProto* const input_proto = graph_proto->add_input();
+  input_proto->set_name(name);
+  auto var = input_proto->mutable_type()->mutable_tensor_type();
+  var->set_elem_type(dtype);
+  var->mutable_shape()->add_dim()->set_dim_value(static_cast<int64>(size));
+  return initializer_proto;
+}
+
+// Keep for when ONNX version will be updated
+/*
+void ConvertSlice(GraphProto* const graph_proto, const Node* node, const Graph& g) {
+  const auto& params = nnvm::get<SliceParam>(node->attrs.parsed);
+  int64 nb_slices = static_cast<int64>(params.begin.ndim());
+
+  // starts
+  auto init_starts = Make1DTensor(graph_proto, nb_slices, node->attrs.name + "_starts",
+                                  TensorProto_DataType_INT64);
+  for (auto& opt : params.begin) {
+    if (opt.has_value()) {
+      init_starts->add_int64_data(static_cast<int64>(opt.value()));
+    } else {
+      init_starts->add_int64_data(static_cast<int64>(0));
+    }
+  }
+
+  // ends
+  auto init_ends = Make1DTensor(graph_proto, nb_slices, node->attrs.name + "_ends",
+                                TensorProto_DataType_INT64);
+  for (auto& opt : params.end) {
+    if (opt.has_value()) {
+      init_ends->add_int64_data(static_cast<int64>(opt.value()));
+    } else {
+      init_ends->add_int64_data(static_cast<int64>(INT_MAX));
+    }
+  }
+
+  // axes
+  auto init_axes = Make1DTensor(graph_proto, nb_slices, node->attrs.name + "_axes",
+                                TensorProto_DataType_INT64);
+  for (int64_t i = 0; i < nb_slices; ++i) {
+    init_axes->add_int64_data(static_cast<int64>(i));
+  }
+
+  // slice node
+  NodeProto* node_proto = graph_proto->add_node();
+  node_proto->set_name(node->attrs.name);
+  node_proto->set_op_type("Slice");
+  node_proto->add_input(node->inputs[0].node->attrs.name);
+  node_proto->add_input(node->attrs.name + "_starts");
+  node_proto->add_input(node->attrs.name + "_ends");
+  node_proto->add_input(node->attrs.name + "_axes");
+
+  // steps
+  if (params.step.ndim() != 0) {
+    auto init_steps = Make1DTensor(graph_proto, nb_slices, node->attrs.name + "_steps",
+                                   TensorProto_DataType_INT64);
+    for (auto& opt : params.step) {
+      if (opt.has_value()) {
+        init_steps->add_int64_data(static_cast<int64>(opt.value()));
+      } else {
+        init_steps->add_int64_data(static_cast<int64>(1));
+      }
+    }
+    node_proto->add_input(node->attrs.name + "_steps");
+  }
+
+  node_proto->add_output(node->attrs.name);
+}
+*/
+
+void ConvertIdentity(GraphProto *graph_proto, const std::string& node_name, const NodeAttrs& attrs,
+                     const nnvm::IndexedGraph& ig,
+                     const array_view<IndexedGraph::NodeEntry>& inputs) {
+  NodeProto* node_proto = graph_proto->add_node();
+  node_proto->set_name(node_name);
   node_proto->set_op_type("Identity");
+  DefaultConnectInputsOutputs(node_proto, inputs, ig, node_name);
 }
 
 template <class ConvDeconvParam>
-void ConvDeconvConvertHelper(NodeProto* node_proto, const NodeAttrs& attrs,
-                             const nnvm::IndexedGraph& /*ig*/,
-                             const array_view<IndexedGraph::NodeEntry>& /*input*/,
+void ConvDeconvConvertHelper(NodeProto *node_proto, const NodeAttrs& attrs,
+                             const nnvm::IndexedGraph& ig,
+                             const array_view<IndexedGraph::NodeEntry>& inputs,
                              const ConvDeconvParam& param,
                              ConvDeconvType type) {
   if (type == ConvDeconvType::Convolution) {
@@ -240,25 +323,36 @@ void ConvDeconvConvertHelper(NodeProto* node_proto, const NodeAttrs& attrs,
   }
 }
 
-void ConvertConvolution(NodeProto* node_proto, const NodeAttrs& attrs,
+void ConvertConvolution(GraphProto *graph_proto, const std::string& node_name,
+                        const NodeAttrs& attrs,
                         const nnvm::IndexedGraph& ig,
                         const array_view<IndexedGraph::NodeEntry>& inputs) {
+  NodeProto* node_proto = graph_proto->add_node();
+  node_proto->set_name(node_name);
   const auto& conv_param = nnvm::get<op::ConvolutionParam>(attrs.parsed);
   ConvDeconvConvertHelper(node_proto, attrs, ig, inputs, conv_param,
       ConvDeconvType::Convolution);
+  DefaultConnectInputsOutputs(node_proto, inputs, ig, node_name);
 }  // end ConvertConvolution
 
-void ConvertDeconvolution(NodeProto* node_proto, const NodeAttrs& attrs,
+void ConvertDeconvolution(GraphProto *graph_proto, const std::string& node_name,
+                          const NodeAttrs& attrs,
                           const nnvm::IndexedGraph& ig,
                           const array_view<IndexedGraph::NodeEntry>& inputs) {
+  NodeProto* node_proto = graph_proto->add_node();
+  node_proto->set_name(node_name);
   const auto& deconv_param = nnvm::get<op::DeconvolutionParam>(attrs.parsed);
   ConvDeconvConvertHelper(node_proto, attrs, ig, inputs, deconv_param,
       ConvDeconvType::Deconvolution);
+  DefaultConnectInputsOutputs(node_proto, inputs, ig, node_name);
 }  // end ConvertDeconvolution
 
-void ConvertPooling(NodeProto* node_proto, const NodeAttrs& attrs,
-                    const nnvm::IndexedGraph& /*ig*/,
-                    const array_view<IndexedGraph::NodeEntry>& /*inputs*/) {
+void ConvertPooling(GraphProto *graph_proto, const std::string& node_name,
+                    const NodeAttrs& attrs,
+                    const nnvm::IndexedGraph& ig,
+                    const array_view<IndexedGraph::NodeEntry>& inputs) {
+  NodeProto* node_proto = graph_proto->add_node();
+  node_proto->set_name(node_name);
   const auto& pooling_param = nnvm::get<op::PoolingParam>(attrs.parsed);
 
   const mxnet::TShape kernel = pooling_param.kernel;
@@ -275,6 +369,7 @@ void ConvertPooling(NodeProto* node_proto, const NodeAttrs& attrs,
     } else {
       LOG(FATAL) << "Pool type of node '" << attrs.name << "' unsupported: " << attrs.name;
     }
+    DefaultConnectInputsOutputs(node_proto, inputs, ig, node_name);
     return;
   }
 
@@ -329,17 +424,24 @@ void ConvertPooling(NodeProto* node_proto, const NodeAttrs& attrs,
   } else {
     count_include_pad->set_i(1);
   }
+  DefaultConnectInputsOutputs(node_proto, inputs, ig, node_name);
 }  // end ConvertPooling
 
-void ConvertRelu(NodeProto* node_proto, const NodeAttrs& /*attrs*/,
-                 const nnvm::IndexedGraph& /*ig*/,
-                 const array_view<IndexedGraph::NodeEntry>& /*inputs*/) {
+void ConvertRelu(GraphProto *graph_proto, const std::string& node_name, const NodeAttrs& /*attrs*/,
+                 const nnvm::IndexedGraph& ig,
+                 const array_view<IndexedGraph::NodeEntry>& inputs) {
+  NodeProto* node_proto = graph_proto->add_node();
+  node_proto->set_name(node_name);
   node_proto->set_op_type("Relu");
+  DefaultConnectInputsOutputs(node_proto, inputs, ig, node_name);
 }
 
-void ConvertActivation(NodeProto* node_proto, const NodeAttrs& attrs,
-                       const nnvm::IndexedGraph& /*ig*/,
-                       const array_view<IndexedGraph::NodeEntry>& /*inputs*/) {
+void ConvertActivation(GraphProto *graph_proto, const std::string& node_name,
+                       const NodeAttrs& attrs,
+                       const nnvm::IndexedGraph& ig,
+                       const array_view<IndexedGraph::NodeEntry>& inputs) {
+  NodeProto* node_proto = graph_proto->add_node();
+  node_proto->set_name(node_name);
   const auto& act_param = nnvm::get<op::ActivationParam>(attrs.parsed);
   std::string act_type;
   switch (act_param.act_type) {
@@ -361,42 +463,120 @@ void ConvertActivation(NodeProto* node_proto, const NodeAttrs& attrs,
   }
 
   node_proto->set_op_type(act_type);
+  DefaultConnectInputsOutputs(node_proto, inputs, ig, node_name);
 }
 
-void ConvertFullyConnected(NodeProto* node_proto, const NodeAttrs& attrs,
-                           const nnvm::IndexedGraph& /*ig*/,
-                           const array_view<IndexedGraph::NodeEntry>& /*inputs*/) {
+void ConvertFullyConnected(GraphProto *graph_proto, const std::string& node_name,
+                           const NodeAttrs& attrs,
+                           const nnvm::IndexedGraph& ig,
+                           const array_view<IndexedGraph::NodeEntry>& inputs) {
   const auto& act_param = nnvm::get<op::FullyConnectedParam>(attrs.parsed);
-  if (act_param.no_bias) {
-      node_proto->set_op_type("MatMul");
+  // ONNX spec doesn't support GEMMs with input of different dims, so we need to replace it
+  // by Transpose+MatMul+Add
+  if (!act_param.flatten && !act_param.no_bias) {
+    NodeProto* tranpose_node_proto = graph_proto->add_node();
+    NodeProto* matmul_node_proto = graph_proto->add_node();
+    NodeProto* add_node_proto = graph_proto->add_node();
+    tranpose_node_proto->set_name(node_name+"_Transpose");
+    matmul_node_proto->set_name(node_name+"_MatMul");
+    add_node_proto->set_name(node_name+"_Add");
+
+    tranpose_node_proto->set_op_type("Transpose");
+    matmul_node_proto->set_op_type("MatMul");
+    add_node_proto->set_op_type("Add");
+
+    std::string input_node_name = ig[inputs[op::conv::kData].node_id].source->attrs.name;
+    std::string weight_node_name = ig[inputs[op::conv::kWeight].node_id].source->attrs.name;
+    std::string bias_node_name = ig[inputs[op::conv::kBias].node_id].source->attrs.name;
+
+    tranpose_node_proto->add_input(weight_node_name);
+    tranpose_node_proto->add_output(node_name+"_Transpose");
+
+    matmul_node_proto->add_input(input_node_name);
+    matmul_node_proto->add_input(node_name+"_Transpose");
+    matmul_node_proto->add_output(node_name+"_MatMul");
+
+    add_node_proto->add_input(node_name+"_MatMul");
+    add_node_proto->add_input(bias_node_name);
+    // Add's output is the output of the Transpose+MatMul+Add subgraph
+    add_node_proto->add_output(node_name);
   } else {
-      node_proto->set_op_type("Gemm");
-
-      AttributeProto* const alpha = node_proto->add_attribute();
-      alpha->set_name("alpha");
-      alpha->set_type(AttributeProto::FLOAT);
-      alpha->set_f(1.0f);
-
-      AttributeProto* const beta = node_proto->add_attribute();
-      beta->set_name("beta");
-      beta->set_type(AttributeProto::FLOAT);
-      beta->set_f(1.0f);
-
-      AttributeProto* const transA = node_proto->add_attribute();
-      transA->set_name("transA");
-      transA->set_type(AttributeProto::INT);
-      transA->set_i(0);
+    NodeProto* node_proto = graph_proto->add_node();
+    node_proto->set_name(node_name);
+    if (act_param.no_bias) {
+        node_proto->set_op_type("MatMul");
+    } else {
+        node_proto->set_op_type("Gemm");
+
+        AttributeProto* const alpha = node_proto->add_attribute();
+        alpha->set_name("alpha");
+        alpha->set_type(AttributeProto::FLOAT);
+        alpha->set_f(1.0f);
+
+        AttributeProto* const beta = node_proto->add_attribute();
+        beta->set_name("beta");
+        beta->set_type(AttributeProto::FLOAT);
+        beta->set_f(1.0f);
+
+        AttributeProto* const transA = node_proto->add_attribute();
+        transA->set_name("transA");
+        transA->set_type(AttributeProto::INT);
+        transA->set_i(0);
+
+        AttributeProto* const transB = node_proto->add_attribute();
+        transB->set_name("transB");
+        transB->set_type(AttributeProto::INT);
+        transB->set_i(1);
+    }
+    DefaultConnectInputsOutputs(node_proto, inputs, ig, node_name);
+  }
+}
 
-      AttributeProto* const transB = node_proto->add_attribute();
-      transB->set_name("transB");
-      transB->set_type(AttributeProto::INT);
-      transB->set_i(1);
+void ConvertSlice(GraphProto *graph_proto, const std::string& node_name, const NodeAttrs& attrs,
+                  const nnvm::IndexedGraph& ig,
+                  const array_view<IndexedGraph::NodeEntry>& inputs) {
+  NodeProto* node_proto = graph_proto->add_node();
+  node_proto->set_name(node_name);
+  const auto& params = nnvm::get<SliceParam>(attrs.parsed);
+  node_proto->set_op_type("Slice");
+
+  // starts
+  AttributeProto* const starts = node_proto->add_attribute();
+  starts->set_name("starts");
+  starts->set_type(AttributeProto::INTS);
+
+  // ends
+  AttributeProto* const ends = node_proto->add_attribute();
+  ends->set_name("ends");
+  ends->set_type(AttributeProto::INTS);
+
+  // axes
+  AttributeProto* const axes = node_proto->add_attribute();
+  axes->set_name("axes");
+  axes->set_type(AttributeProto::INTS);
+
+  for (int64_t i = 1; i < params.begin.ndim(); ++i) {
+    if (params.begin[i].has_value()) {
+      starts->add_ints(static_cast<int64>(params.begin[i].value()));
+    } else {
+      starts->add_ints(static_cast<int64>(0));
+    }
+    if (params.end[i].has_value()) {
+      ends->add_ints(static_cast<int64>(params.end[i].value()));
+    } else {
+      ends->add_ints(static_cast<int64>(INT_MAX));
+    }
+    axes->add_ints(static_cast<int64>(i));
   }
+  DefaultConnectInputsOutputs(node_proto, inputs, ig, node_name);
 }
 
-void ConvertSoftmaxOutput(NodeProto* node_proto, const NodeAttrs& /*attrs*/,
-                          const nnvm::IndexedGraph& /*ig*/,
-                          const array_view<IndexedGraph::NodeEntry>& /*inputs*/) {
+void ConvertSoftmaxOutput(GraphProto *graph_proto, const std::string& node_name,
+                          const NodeAttrs& /*attrs*/,
+                          const nnvm::IndexedGraph& ig,
+                          const array_view<IndexedGraph::NodeEntry>& inputs) {
+  NodeProto* node_proto = graph_proto->add_node();
+  node_proto->set_name(node_name);
   node_proto->set_op_type("Softmax");
 
   // Setting by default to 1 since MXNet doesn't provide such an attribute for softmax in its
@@ -406,11 +586,16 @@ void ConvertSoftmaxOutput(NodeProto* node_proto, const NodeAttrs& /*attrs*/,
   axis->set_name("axis");
   axis->set_type(AttributeProto::INT);
   axis->set_i(1);
+  DefaultConnectInputsOutputs(node_proto, inputs, ig, node_name);
 }
 
-void ConvertFlatten(NodeProto* node_proto, const NodeAttrs& /*attrs*/,
-                    const nnvm::IndexedGraph& /*ig*/,
-                    const array_view<IndexedGraph::NodeEntry>& /*inputs*/) {
+
+void ConvertFlatten(GraphProto *graph_proto, const std::string& node_name,
+                    const NodeAttrs& /*attrs*/,
+                    const nnvm::IndexedGraph& ig,
+                    const array_view<IndexedGraph::NodeEntry>& inputs) {
+  NodeProto* node_proto = graph_proto->add_node();
+  node_proto->set_name(node_name);
   node_proto->set_op_type("Flatten");
 
   // Setting by default to 1 since MXNet doesn't provide such an attribute for Flatten in its
@@ -420,11 +605,15 @@ void ConvertFlatten(NodeProto* node_proto, const NodeAttrs& /*attrs*/,
   axis->set_name("axis");
   axis->set_type(AttributeProto::INT);
   axis->set_i(1);
+  DefaultConnectInputsOutputs(node_proto, inputs, ig, node_name);
 }
 
-void ConvertBatchNorm(NodeProto* node_proto, const NodeAttrs& attrs,
-                      const nnvm::IndexedGraph& /*ig*/,
-                      const array_view<IndexedGraph::NodeEntry>& /*inputs*/) {
+void ConvertBatchNorm(GraphProto *graph_proto, const std::string& node_name,
+                      const NodeAttrs& attrs,
+                      const nnvm::IndexedGraph& ig,
+                      const array_view<IndexedGraph::NodeEntry>& inputs) {
+  NodeProto* node_proto = graph_proto->add_node();
+  node_proto->set_name(node_name);
   node_proto->set_op_type("BatchNormalization");
   const auto& param = nnvm::get<op::BatchNormParam>(attrs.parsed);
 
@@ -445,29 +634,45 @@ void ConvertBatchNorm(NodeProto* node_proto, const NodeAttrs& attrs,
   // (default in ONNX3) implies running batchnorm on all spatial features so we need to explicitly
   // disable this for MXNet's BatchNorm.
   spatial->set_i(0);
+  DefaultConnectInputsOutputs(node_proto, inputs, ig, node_name);
 }
 
-void ConvertElementwiseAdd(NodeProto* node_proto, const NodeAttrs& /*attrs*/,
-                           const nnvm::IndexedGraph& /*ig*/,
-                           const array_view<IndexedGraph::NodeEntry>& /*inputs*/) {
+void ConvertElementwiseAdd(GraphProto *graph_proto, const std::string& node_name,
+                           const NodeAttrs& /*attrs*/,
+                           const nnvm::IndexedGraph& ig,
+                           const array_view<IndexedGraph::NodeEntry>& inputs) {
+  NodeProto* node_proto = graph_proto->add_node();
+  node_proto->set_name(node_name);
   node_proto->set_op_type("Add");
+  DefaultConnectInputsOutputs(node_proto, inputs, ig, node_name);
 }
 
-void ConvertElementwiseSub(NodeProto* node_proto, const NodeAttrs& /*attrs*/,
-                           const nnvm::IndexedGraph& /*ig*/,
-                           const array_view<IndexedGraph::NodeEntry>& /*inputs*/) {
+void ConvertElementwiseSub(GraphProto *graph_proto, const std::string& node_name,
+                           const NodeAttrs& /*attrs*/,
+                           const nnvm::IndexedGraph& ig,
+                           const array_view<IndexedGraph::NodeEntry>& inputs) {
+  NodeProto* node_proto = graph_proto->add_node();
+  node_proto->set_name(node_name);
   node_proto->set_op_type("Sub");
+  DefaultConnectInputsOutputs(node_proto, inputs, ig, node_name);
 }
 
-void ConvertElementwiseMul(NodeProto* node_proto, const NodeAttrs& /*attrs*/,
-                           const nnvm::IndexedGraph& /*ig*/,
-                           const array_view<IndexedGraph::NodeEntry>& /*inputs*/) {
+void ConvertElementwiseMul(GraphProto *graph_proto, const std::string& node_name,
+                           const NodeAttrs& /*attrs*/,
+                           const nnvm::IndexedGraph& ig,
+                           const array_view<IndexedGraph::NodeEntry>& inputs) {
+  NodeProto* node_proto = graph_proto->add_node();
+  node_proto->set_name(node_name);
   node_proto->set_op_type("Mul");
+  DefaultConnectInputsOutputs(node_proto, inputs, ig, node_name);
 }
 
-void ConvertConcatenate(NodeProto* node_proto, const NodeAttrs& attrs,
-                        const nnvm::IndexedGraph& /*ig*/,
-                        const array_view<IndexedGraph::NodeEntry>& /*inputs*/) {
+void ConvertConcatenate(GraphProto *graph_proto, const std::string& node_name,
+                        const NodeAttrs& attrs,
+                        const nnvm::IndexedGraph& ig,
+                        const array_view<IndexedGraph::NodeEntry>& inputs) {
+  NodeProto* node_proto = graph_proto->add_node();
+  node_proto->set_name(node_name);
   const auto& _param = nnvm::get<ConcatParam>(attrs.parsed);
   node_proto->set_op_type("Concat");
   node_proto->set_name(attrs.name);
@@ -476,6 +681,7 @@ void ConvertConcatenate(NodeProto* node_proto, const NodeAttrs& attrs,
   axis->set_name("axis");
   axis->set_type(AttributeProto::INT);
   axis->set_i(static_cast<int64_t>(_param.dim));
+  DefaultConnectInputsOutputs(node_proto, inputs, ig, node_name);
 }
 
 inline TensorProto_DataType ConvertDType(int dtype) {
@@ -630,9 +836,11 @@ void ConvertOutput(
   }
 }
 
-void ConvertClip(NodeProto* node_proto, const NodeAttrs& attrs,
-                 const nnvm::IndexedGraph& /*ig*/,
-                 const array_view<IndexedGraph::NodeEntry>& /*inputs*/) {
+void ConvertClip(GraphProto *graph_proto, const std::string& node_name, const NodeAttrs& attrs,
+                 const nnvm::IndexedGraph& ig,
+                 const array_view<IndexedGraph::NodeEntry>& inputs) {
+  NodeProto* node_proto = graph_proto->add_node();
+  node_proto->set_name(node_name);
   const auto& param = nnvm::get<ClipParam>(attrs.parsed);
 
   node_proto->set_op_type("Clip");
@@ -648,11 +856,14 @@ void ConvertClip(NodeProto* node_proto, const NodeAttrs& attrs,
   a_min->set_name("min");
   a_min->set_type(AttributeProto::FLOAT);
   a_min->set_f(static_cast<float>(param.a_min));
+  DefaultConnectInputsOutputs(node_proto, inputs, ig, node_name);
 }
 
-void ConvertPad(NodeProto* node_proto, const NodeAttrs& attrs,
-                const nnvm::IndexedGraph& /*ig*/,
-                const array_view<IndexedGraph::NodeEntry>& /*inputs*/) {
+void ConvertPad(GraphProto *graph_proto, const std::string& node_name, const NodeAttrs& attrs,
+                const nnvm::IndexedGraph& ig,
+                const array_view<IndexedGraph::NodeEntry>& inputs) {
+  NodeProto* node_proto = graph_proto->add_node();
+  node_proto->set_name(node_name);
   const auto& param = nnvm::get<PadParam>(attrs.parsed);
 
   node_proto->set_op_type("Pad");
@@ -694,12 +905,16 @@ void ConvertPad(NodeProto* node_proto, const NodeAttrs& attrs,
   value->set_name("value");
   value->set_type(AttributeProto::FLOAT);
   value->set_f(param.constant_value);
+  DefaultConnectInputsOutputs(node_proto, inputs, ig, node_name);
 }
 
-void ConvertDropout(NodeProto* node_proto, const NodeAttrs& attrs,
-                    const nnvm::IndexedGraph& /*ig*/,
-                    const array_view<IndexedGraph::NodeEntry>& /*inputs*/) {
+void ConvertDropout(GraphProto *graph_proto, const std::string& node_name, const NodeAttrs& attrs,
+                    const nnvm::IndexedGraph& ig,
+                    const array_view<IndexedGraph::NodeEntry>& inputs) {
+  NodeProto* node_proto = graph_proto->add_node();
+  node_proto->set_name(node_name);
   node_proto->set_op_type("Dropout");
+  DefaultConnectInputsOutputs(node_proto, inputs, ig, node_name);
 }
 
 void PreprocessBatchNorm(const NodeAttrs &attrs,
diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py
index 49b84a2..0eb8340 100644
--- a/tests/python/unittest/test_gluon.py
+++ b/tests/python/unittest/test_gluon.py
@@ -2021,10 +2021,6 @@ def test_share_inputs_outputs():
             res = t(d1)
             assert_almost_equal(res.asnumpy(), d1.asnumpy())
 
-    param = deepcopy(params[2])
-    param['param_indices'] = (1)
-    param['data_indices'] = (0)
-    params.append(param)
     # Test the case that inputs and outputs of a backward graph share NDArrays.
     for param in params:
         t = TestIOBackward()