Posted to commits@mxnet.apache.org by pa...@apache.org on 2020/04/16 02:40:41 UTC

[incubator-mxnet] branch v1.7.x updated: [1.7] MXNet Extension PRs (#17623, #17569, #17762) (#18063)

This is an automated email from the ASF dual-hosted git repository.

patriczhao pushed a commit to branch v1.7.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/v1.7.x by this push:
     new bf99f27  [1.7] MXNet Extension PRs (#17623, #17569, #17762) (#18063)
bf99f27 is described below

commit bf99f275b53549b94dff16b8ad74291448d1d47e
Author: Sam Skalicky <sa...@gmail.com>
AuthorDate: Wed Apr 15 19:37:53 2020 -0700

    [1.7] MXNet Extension PRs (#17623, #17569, #17762) (#18063)
    
    * Dynamic subgraph compile support (#17623)
    
    This PR adds support for passing the NDArrays from the existing optimize_for API down to the reviewSubgraph function in an external library. It also adds a new HybridBlock API, optimize_for, that can partition the model without running a forward pass.
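    
    A minimal sketch of both flows, assuming a Symbol `sym`, its input list `inputs`, and the 'myProp' partitioner from the example subgraph library (see the updated README and test_subgraph.py below):
    
    import mxnet as mx
    from mxnet.gluon import nn
    
    # Symbol flow: pass args so shapes/types (and param values) reach the backend
    arg_array = [mx.nd.ones((3,2), dtype='float32')]
    mysym = sym.optimize_for("myProp", arg_array, reqArgs=True)
    
    # Gluon flow: partition immediately, then export without running a forward pass
    sym_block = nn.SymbolBlock(sym, inputs)
    sym_block.initialize()
    sym_block.optimize_for(mx.nd.ones((3,2)), backend='myProp')
    sym_block.export('partitioned')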
    
    Feature changes
    
        Adds a new HybridBlock API, optimize_for, that partitions the model but does not call the CachedOp
        Modifies the subgraph library example to optionally require args to be provided
        Adds an annotation on each subgraph input with the name of the original param, so inputs can be mapped, and passes these annotations to the input nodes of subgraphs
        Adds support for tensors in MKLDNN format by calling Reorder2Default
    
    New tests
    
        Adds a new test that partitions operators which directly consume params
        Adds a new test model where the ops to be partitioned have args/params
    
    Bug Fixes
    
        Fixes a bug where the ids vector was passed by value instead of by reference
        Fixes a bug where attributes were passed as copies instead of by reference
        Fixes a bug where _cached_graph was not updated after partitioning
        Fixes a memory leak where user-specified attributes on subgraph ops were not freed if the subgraph was rejected
        Fixes incorrect indexing into the shape/dtype maps when annotating the graph
    
    Docs
    
        Updates the README doc with the latest changes described above
    
    * Adding sparse support to MXTensor for custom operators (#17569)
    
    * Added enum for sparse storage
    
    * Add structure for Dense and Sparse
    
    * redesign the data structure for MXSparse
    
    * pull out aux data from sparse NDArray
    
    * Added more sparse arguments to API interface
    
    * Passed sparse from c_api to lib_api.h and set in MXTensor
    
    * Fix indent
    
    * fix segfault
    
    * Fix NDArray to MXTensor errors
    
    * Add a sample of sparse(CSR) transpose
    
    * Make CSR transpose temporarily work by hardcoding
    
    * Fixed sparse output size (refined)
    
    * Add tests for symbolic and stateful ops
    
    * Added a sample for row sparse transpose
    
    * Added real row sparse transpose
    
    * Fix output size issue by adding lambda for CheckAndAlloc()
    
    * Fix mixed storage formats error
    
    * Added infer storage type function
    
    * resolve comments
    
    * Set inferSType as optional function
    
    * Resolve comments
    
    * Add error messages
    
    * Resolve comments
    
    * verify transpose ops results
    
    * fix sanity check
    
    * update MX_LIBRARY_VERSION to 5
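    
    Taken together, these changes let a custom operator consume and produce sparse NDArrays through the new MXSparse view. A minimal usage sketch, assuming the libtransposecsr_lib.so example library from this PR has been built and is on the current path (as in test_transposecsr.py):
    
    import os
    import mxnet as mx
    
    mx.library.load(os.path.abspath('libtransposecsr_lib.so'))  # registers mx.nd.my_transposecsr
    a = mx.nd.array([[1,3,0],[0,1,0]]).tostype('csr')           # CSR input
    b = mx.nd.my_transposecsr(a)                                 # custom sparse transpose, CSR output
    print(b.data.asnumpy(), b.indices.asnumpy(), b.indptr.asnumpy())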
    
    * Custom Operator Random Number Generator Support (#17762)
    
    Add random number generator support for custom operator libraries.
    
    Design: MXNet passes its initialized and seeded random states, located on CPU and GPU, to the custom library, so users can generate deterministic values from a given seed passed to MXNet. Basically this workflow:
    
    mx.random.seed(128)
    r1 = mx.nd.some_custom_random_op(data)
    mx.random.seed(128)
    r2 = mx.nd.some_custom_random_op(data)
    assert (r1 == r2)
    
    Note: this PR does not make the custom library generate exactly the same sequence of random numbers as MXNet itself.
    
    This is a continuation of the custom operator project #15921 and #17270
    
    Co-authored-by: guanxinq <58...@users.noreply.github.com>
    Co-authored-by: Ziyi Mu <zi...@columbia.edu>
---
 CMakeLists.txt                                     |  15 +-
 example/extensions/lib_custom_op/Makefile          |  10 +-
 example/extensions/lib_custom_op/relu_lib.cu       |  90 ++++-
 example/extensions/lib_custom_op/test_relu.py      |  43 ++-
 .../extensions/lib_custom_op/test_transposecsr.py  |  78 ++++
 .../lib_custom_op/test_transposerowsp.py           |  73 ++++
 .../extensions/lib_custom_op/transposecsr_lib.cc   | 197 ++++++++++
 .../extensions/lib_custom_op/transposerowsp_lib.cc | 199 ++++++++++
 example/extensions/lib_subgraph/README.md          |  69 ++--
 example/extensions/lib_subgraph/subgraph_lib.cc    |  45 ++-
 example/extensions/lib_subgraph/test_subgraph.py   |  62 ++-
 include/mxnet/c_api.h                              |   4 +-
 include/mxnet/lib_api.h                            | 422 +++++++++++++++++++--
 include/mxnet/random_generator.h                   |   8 +
 perl-package/AI-MXNetCAPI/mxnet.i                  |   2 +
 python/mxnet/gluon/block.py                        |  71 +++-
 python/mxnet/symbol/symbol.py                      |  23 +-
 src/c_api/c_api.cc                                 | 213 ++++++++---
 src/c_api/c_api_symbolic.cc                        |  70 +++-
 src/common/random_generator.cu                     |   5 +
 src/operator/subgraph/build_subgraph.cc            |   9 +-
 .../partitioner/custom_subgraph_property.h         | 247 ++++++++++--
 src/operator/subgraph/subgraph_property.h          |   8 +
 tests/python/gpu/test_extensions_gpu.py            |  18 +-
 tests/python/unittest/test_extensions.py           |  14 +
 tests/python/unittest/test_subgraph_op.py          |   4 +-
 26 files changed, 1779 insertions(+), 220 deletions(-)

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 4731663..fac6e64 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -590,7 +590,7 @@ if(USE_CUDA)
   message("-- CUDA: Using the following NVCC architecture flags ${CUDA_ARCH_FLAGS}")
   set(arch_code_list)
   foreach(arch_str ${CUDA_ARCH_FLAGS})
-    if((arch_str MATCHES ".*sm_[0-9]+")) 
+    if((arch_str MATCHES ".*sm_[0-9]+"))
       string( REGEX REPLACE  ".*sm_([0-9]+)" "\\1" arch_code ${arch_str} )
       list(APPEND arch_code_list ${arch_code})
     endif()
@@ -730,26 +730,21 @@ elseif(MSVC)
 
 endif()
 
+# extension libraries (custom operators, custom subgraphs) are built by default
 add_library(customop_lib SHARED ${CMAKE_CURRENT_SOURCE_DIR}/example/extensions/lib_custom_op/gemm_lib.cc)
 add_library(subgraph_lib SHARED ${CMAKE_CURRENT_SOURCE_DIR}/example/extensions/lib_subgraph/subgraph_lib.cc)
 target_include_directories(customop_lib PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/include/mxnet)
 target_include_directories(subgraph_lib PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/include/mxnet)
-if (USE_CUDA)
+if(USE_CUDA)
   add_library(customop_gpu_lib SHARED ${CMAKE_CURRENT_SOURCE_DIR}/example/extensions/lib_custom_op/relu_lib.cu)
   target_include_directories(customop_gpu_lib PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/include/mxnet)
 endif()
-if(UNIX)
-  target_compile_options(customop_lib PUBLIC -shared)
-  target_compile_options(subgraph_lib PUBLIC -shared)
-  if (USE_CUDA)
-    target_compile_options(customop_gpu_lib PUBLIC -shared)
-  endif()
-elseif(MSVC)
+if(MSVC)
   target_compile_options(customop_lib PUBLIC /LD)
   target_compile_options(subgraph_lib PUBLIC /LD)
   set_target_properties(customop_lib PROPERTIES PREFIX "lib")
   set_target_properties(subgraph_lib PROPERTIES PREFIX "lib")
-  if (USE_CUDA)
+  if(USE_CUDA)
     target_compile_options(customop_gpu_lib PUBLIC "$<$<COMPILE_LANGUAGE:CUDA>:-Xcompiler=-fPIC>")
     set_target_properties(customop_gpu_lib PROPERTIES PREFIX "lib")
   endif()
diff --git a/example/extensions/lib_custom_op/Makefile b/example/extensions/lib_custom_op/Makefile
index edd753b..feded29 100644
--- a/example/extensions/lib_custom_op/Makefile
+++ b/example/extensions/lib_custom_op/Makefile
@@ -15,7 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 
-all: gemm_lib relu_lib
+all: gemm_lib relu_lib transposecsr_lib transposerowsp_lib
 
 gemm_lib:
 	g++ -shared -fPIC -std=c++11 gemm_lib.cc -o libgemm_lib.so -I ../../../include/mxnet
@@ -23,5 +23,11 @@ gemm_lib:
 relu_lib:
 	nvcc -shared -std=c++11 -Xcompiler -fPIC relu_lib.cu -o librelu_lib.so -I ../../../include/mxnet
 
+transposecsr_lib:
+	g++ -shared -fPIC -std=c++11 transposecsr_lib.cc -o libtransposecsr_lib.so -I ../../../include/mxnet
+
+transposerowsp_lib:
+	g++ -shared -fPIC -std=c++11 transposerowsp_lib.cc -o libtransposerowsp_lib.so -I ../../../include/mxnet
+
 clean:
-	rm -rf libgemm_lib.so librelu_lib.so
+	rm -rf libgemm_lib.so librelu_lib.so libtransposecsr_lib.so libtransposerowsp_lib.so
diff --git a/example/extensions/lib_custom_op/relu_lib.cu b/example/extensions/lib_custom_op/relu_lib.cu
index 3beb68c..60112ee 100644
--- a/example/extensions/lib_custom_op/relu_lib.cu
+++ b/example/extensions/lib_custom_op/relu_lib.cu
@@ -20,12 +20,14 @@
 /*!
  * Copyright (c) 2020 by Contributors
  * \file relu_lib.cu
- * \brief simple custom relu operator implemented using CUDA function
+ * \brief simple custom relu and noisy relu operator implemented using CUDA function
  */
 
 #include <iostream>
 #include "lib_api.h"
 
+#define NumThreadPerBlock 256 // mxnet recommended cuda thread number per block
+
 __global__ void relu_gpu_forward(float *out, float *in, int64_t N) {
     int tid = blockIdx.x * blockDim.x + threadIdx.x;
     if (tid < N)
@@ -72,9 +74,9 @@ MXReturnValue forwardGPU(std::map<std::string, std::string> attrs,
 
     mx_stream_t cuda_stream = res.get_cuda_stream();
     int64_t N = inputs[0].size();
-    int block = 256;
-    int grid = (N + (block - 1)) / block;
-    relu_gpu_forward<<<grid,block,0,cuda_stream>>>(out_data, in_data, N);
+    int num_block = (N + NumThreadPerBlock - 1) / NumThreadPerBlock;
+
+    relu_gpu_forward<<<num_block,NumThreadPerBlock,0,cuda_stream>>>(out_data, in_data, N);
 
     return MX_SUCCESS;
 }
@@ -89,9 +91,9 @@ MXReturnValue backwardGPU(std::map<std::string, std::string> attrs,
 
     mx_stream_t cuda_stream = res.get_cuda_stream();
     int64_t N = inputs[0].size();
-    int block = 256;
-    int grid = (N + (block - 1)) / block;
-    relu_gpu_backward<<<grid,block,0,cuda_stream>>>(in_grad, out_grad, in_data, N);
+    int num_block = (N + NumThreadPerBlock - 1) / NumThreadPerBlock;
+
+    relu_gpu_backward<<<num_block,NumThreadPerBlock,0,cuda_stream>>>(in_grad, out_grad, in_data, N);
 
     return MX_SUCCESS;
 }
@@ -180,6 +182,80 @@ REGISTER_OP(my_state_relu)
 .setCreateOpState(createOpStateCPU, "cpu")
 .setCreateOpState(createOpStateGPU, "gpu");
 
+/*
+ * Below is the noisy ReLU operator example.
+ * Noisy ReLU extends ReLU with Gaussian noise:
+ * forward - add Gaussian noise drawn from a normal distribution to each unit
+ * backward - the gradient does not change, since the noise is constant
+ */
+
+#define NumRandomPerThread 64 // mxnet recommended random numbers generated per thread
+
+__global__ void noisy_relu_gpu_forward(float *out, float *in, int64_t N, mx_gpu_rand_t* states, int step) {
+    // the launcher logic ensures tid less than NumGPURandomStates
+    int tid = blockIdx.x * blockDim.x + threadIdx.x;
+    // each thread generates unique sequence of random numbers
+    mx_gpu_rand_t thread_state = states[tid];
+    // each thread works on <step> calculations
+    int start = tid * step;
+    int end = start + step;
+    for (int i=start; i<end && i<N; ++i) {
+        float noise = curand_normal(&thread_state);
+        out[i] = in[i] + noise > 0 ? in[i] + noise : 0;
+    }
+}
+
+MXReturnValue noisyForwardCPU(std::map<std::string, std::string> attrs,
+                              std::vector<MXTensor> inputs,
+                              std::vector<MXTensor> outputs,
+                              OpResource res) {
+    float* in_data = inputs[0].data<float>();
+    float* out_data = outputs[0].data<float>();
+
+    mx_cpu_rand_t* states = res.get_cpu_rand_states();
+    std::normal_distribution<float> dist_normal;
+
+    for (int i=0; i<inputs[0].size(); ++i) {
+        float noise = dist_normal(*states);
+        out_data[i] = in_data[i] + noise > 0 ? in_data[i] + noise : 0;
+    }
+    return MX_SUCCESS;
+}
+
+MXReturnValue noisyForwardGPU(std::map<std::string, std::string> attrs,
+                              std::vector<MXTensor> inputs,
+                              std::vector<MXTensor> outputs,
+                              OpResource res) {
+    float* in_data = inputs[0].data<float>();
+    float* out_data = outputs[0].data<float>();
+
+    mx_stream_t cuda_stream = res.get_cuda_stream();
+    int64_t N = inputs[0].size();
+
+    // below is the mxnet recommended workflow for parallel random number generation
+    int nthread = (N + NumRandomPerThread - 1) / NumRandomPerThread;
+    // we should not launch more threads than mxnet supported random number GPU states
+    int num_thread_need = nthread < MX_NUM_GPU_RANDOM_STATES ? nthread : MX_NUM_GPU_RANDOM_STATES;
+    // each cuda thread processes the [step * tid, step * tid + step) slice of the input tensor
+    int step = (N + num_thread_need - 1) / num_thread_need;
+    // this ensures the number of parallel threads does not exceed the number of random states mxnet supports
+    int num_block = (num_thread_need + NumThreadPerBlock - 1) / NumThreadPerBlock;
+
+    noisy_relu_gpu_forward<<<num_block,NumThreadPerBlock,0,cuda_stream>>>(
+                                out_data, in_data, N, res.get_gpu_rand_states(), step);
+
+    return MX_SUCCESS;
+}
+
+REGISTER_OP(my_noisy_relu)
+.setParseAttrs(parseAttrs)
+.setInferType(inferType)
+.setInferShape(inferShape)
+.setForward(noisyForwardCPU, "cpu")
+.setForward(noisyForwardGPU, "gpu")
+.setBackward(backwardCPU, "cpu")
+.setBackward(backwardGPU, "gpu");
+
 MXReturnValue initialize(int version) {
     if (version >= 10400) {
         std::cout << "MXNet version " << version << " supported" << std::endl;
diff --git a/example/extensions/lib_custom_op/test_relu.py b/example/extensions/lib_custom_op/test_relu.py
index 03d02f3..a37ea25 100644
--- a/example/extensions/lib_custom_op/test_relu.py
+++ b/example/extensions/lib_custom_op/test_relu.py
@@ -35,13 +35,13 @@ if (os.name=='posix'):
 a = mx.nd.array([[-2,-1],[1,2]], ctx=mx.cpu())
 b = mx.nd.array([[-2,-1],[1,2]], ctx=mx.gpu())
 
-print("--------start ndarray compute---------")
+print("--------ndarray compute---------")
 print(mx.nd.my_relu(a))
 print(mx.nd.my_relu(b))
 print(mx.nd.my_state_relu(a))
 print(mx.nd.my_state_relu(b))
 
-print("--------start symbolic compute--------")
+print("--------symbolic compute--------")
 c = mx.sym.Variable('c')
 d = mx.sym.Variable('d')
 e = mx.sym.my_relu(c)
@@ -55,30 +55,41 @@ out_base = exe_base.forward()
 print(out)
 print(out_base)
 
-print("--------start backward compute--------")
+print("--------backward compute--------")
 out_grad = mx.nd.ones((2,2), ctx=mx.gpu())
 exe.backward([out_grad])
 exe_base.backward([out_grad])
 print(in_grad)
 print(in_grad_base)
 
-print("--------start testing larger ndarray---------")
-a = mx.nd.uniform(shape=(100,100,100), ctx=mx.cpu())
+print("--------test ndarray with size of 1 million---------")
 b = mx.nd.uniform(shape=(100,100,100), ctx=mx.gpu())
 mx.nd.waitall()
 t1 = time.time()
-r1 = mx.nd.my_relu(a)
+r1 = mx.nd.my_relu(b)
 mx.nd.waitall()
 t2 = time.time()
-r2 = mx.nd.my_relu(b)
+r2 = mx.nd.relu(b)
 mx.nd.waitall()
 t3 = time.time()
-r3 = mx.nd.relu(b)
-mx.nd.waitall()
-t4 = time.time()
-print("CPU running time:")
-print(t2 - t1)
-print("GPU running time:")
-print(t3 - t2)
-print("Baseline GPU running time:")
-print(t4 - t3)
+print("Custom ReLU running time in ms:")
+print((t2 - t1) * 1000)
+print("Native ReLU running time in ms:")
+print((t3 - t2) * 1000)
+
+print("--------test noisy relu identical sequence---------")
+
+a = mx.nd.ones(shape=(13,5), ctx=mx.cpu())
+b = mx.nd.ones(shape=(13,5), ctx=mx.gpu())
+
+mx.random.seed(128, ctx=mx.cpu())
+print(mx.nd.my_noisy_relu(a))
+
+mx.random.seed(128, ctx=mx.cpu())
+print(mx.nd.my_noisy_relu(a))
+
+mx.random.seed(128, ctx=mx.gpu())
+print(mx.nd.my_noisy_relu(b))
+
+mx.random.seed(128, ctx=mx.gpu())
+print(mx.nd.my_noisy_relu(b))
diff --git a/example/extensions/lib_custom_op/test_transposecsr.py b/example/extensions/lib_custom_op/test_transposecsr.py
new file mode 100644
index 0000000..37d066a
--- /dev/null
+++ b/example/extensions/lib_custom_op/test_transposecsr.py
@@ -0,0 +1,78 @@
+#!/usr/bin/env python3
+
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# coding: utf-8
+# pylint: disable=arguments-differ
+
+# This test checks the dynamic loading of the custom library into MXNet
+# and the end-to-end compute of a custom CSR transpose operator
+
+import mxnet as mx
+import os
+
+#load library
+if (os.name=='posix'):
+    path = os.path.abspath('libtransposecsr_lib.so')
+    mx.library.load(path)
+elif (os.name=='nt'):
+    path = os.path.abspath('libtransposecsr_lib.dll')
+    mx.library.load(path)
+
+a = mx.nd.array([[1,3,0,2,1],[0,1,0,0,0],[0,2,4,5,3]])
+a = a.tostype('csr')
+print("--------Input CSR Array---------")
+print("data:", a.data.asnumpy())
+print("indices:", a.indices.asnumpy())
+print("indptr:", a.indptr.asnumpy())
+
+print("--------Start NDArray Compute---------")
+b = mx.nd.my_transposecsr(a)
+print("Compute Results:")
+print("data:", b.data.asnumpy())
+print("indices:", b.indices.asnumpy())
+print("indptr:", b.indptr.asnumpy())
+
+print("Stateful Compute Result:")
+c = mx.nd.my_state_transposecsr(a, test_kw=100)
+print("data:", c.data.asnumpy())
+print("indices:", c.indices.asnumpy())
+print("indptr:", c.indptr.asnumpy())
+
+print("--------start symbolic compute--------")
+d = mx.sym.Variable('d')
+e = mx.sym.my_transposecsr(d)
+f = mx.sym.my_state_transposecsr(d, test_kw=200)
+
+exe = e.bind(ctx=mx.cpu(),args={'d':a})
+exe2 = f.bind(ctx=mx.cpu(),args={'d':a})
+out = exe.forward()
+print("Compute Results:")
+print("data:", out[0].data.asnumpy())
+print("indices:", out[0].indices.asnumpy())
+print("indptr:", out[0].indptr.asnumpy())
+
+out2 = exe2.forward()
+out2 = exe2.forward()
+print("Stateful Compute Result:")
+print("data:", out2[0].data.asnumpy())
+print("indices:", out2[0].indices.asnumpy())
+print("indptr:", out2[0].indptr.asnumpy())
+
+print("--------Baseline(dense)--------")
+print(mx.nd.transpose(a.tostype('default')))
diff --git a/example/extensions/lib_custom_op/test_transposerowsp.py b/example/extensions/lib_custom_op/test_transposerowsp.py
new file mode 100644
index 0000000..cea62ec
--- /dev/null
+++ b/example/extensions/lib_custom_op/test_transposerowsp.py
@@ -0,0 +1,73 @@
+#!/usr/bin/env python3
+
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# coding: utf-8
+# pylint: disable=arguments-differ
+
+# This test checks the dynamic loading of the custom library into MXNet
+# and the end-to-end compute of a custom row-sparse transpose operator
+
+import mxnet as mx
+import os
+
+#load library
+if (os.name=='posix'):
+    path = os.path.abspath('libtransposerowsp_lib.so')
+    mx.library.load(path)
+elif (os.name=='nt'):
+    path = os.path.abspath('libtransposerowsp_lib.dll')
+    mx.library.load(path)
+
+a = mx.nd.array([[1,2,3],[0,0,0],[4,0,5],[0,0,0],[0,0,0]])
+a = a.tostype('row_sparse')
+print("--------Input CSR Array---------")
+print("data:", a.data.asnumpy())
+print("indices:", a.indices.asnumpy())
+
+print("--------Start NDArray Compute---------")
+b = mx.nd.my_transposerowsp(a)
+print("Compute Results:")
+print("data:", b.data.asnumpy())
+print("indices:", b.indices.asnumpy())
+
+print("Stateful Compute Result:")
+c = mx.nd.my_state_transposerowsp(a, test_kw=100)
+print("data:", c.data.asnumpy())
+print("indices:", c.indices.asnumpy())
+
+print("--------start symbolic compute--------")
+d = mx.sym.Variable('d')
+e = mx.sym.my_transposerowsp(d)
+f = mx.sym.my_state_transposerowsp(d, test_kw=200)
+
+exe = e.bind(ctx=mx.cpu(),args={'d':a})
+exe2 = f.bind(ctx=mx.cpu(),args={'d':a})
+out = exe.forward()
+print("Compute Results:")
+print("data:", out[0].data.asnumpy())
+print("indices:", out[0].indices.asnumpy())
+
+out2 = exe2.forward()
+out2 = exe2.forward()
+print("Stateful Compute Result:")
+print("data:", out2[0].data.asnumpy())
+print("indices:", out2[0].indices.asnumpy())
+
+print("--------Baseline(dense)--------")
+print(mx.nd.transpose(a.tostype('default')))
diff --git a/example/extensions/lib_custom_op/transposecsr_lib.cc b/example/extensions/lib_custom_op/transposecsr_lib.cc
new file mode 100644
index 0000000..0daeb3e
--- /dev/null
+++ b/example/extensions/lib_custom_op/transposecsr_lib.cc
@@ -0,0 +1,197 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2020 by Contributors
+ * \file transposecsr_lib.cc
+ * \brief Sample 2D transpose custom operator.
+ */
+
+#include <iostream>
+#include "lib_api.h"
+
+void transpose(MXTensor src, MXTensor dst, OpResource res) {
+  MXSparse* A = src.data<MXSparse>();
+  MXSparse* B = dst.data<MXSparse>(); 
+  std::vector<int64_t> shape = src.shape;
+  int64_t h = shape[0];
+  int64_t w = shape[1];
+  if(src.stype == kCSRStorage) {
+    float *Aval = (float*) (A->data);
+    // Here we need one extra element to help compute the transposed index (see the index calculation in the nested loop below).
+    std::vector<int64_t> rowPtr(w + 2, 0);
+    // count column
+    for(int i = 0; i < A->data_len; i++) {
+      rowPtr[A->indices[i] + 2]++;
+    }
+    // Accumulated sum. After this for loop, rowPtr[1:w+2) stores the correct 
+    // result of transposed rowPtr.
+    for(int i = 2; i < rowPtr.size(); i++) {
+      rowPtr[i] += rowPtr[i - 1];
+    }
+    
+    // Alloc memory for sparse data, where 0 is the index
+    // of B in output vector.
+    res.alloc_sparse(B, 0, A->data_len, w + 1);
+    float *Bval = (float*) (B->data);
+    for(int i = 0; i < h; i++) {
+      for(int j = A->indptr[i]; j < A->indptr[i + 1]; j++) {
+        // Helps calculate index and after that rowPtr[0:w+1) stores the 
+        // correct result of transposed rowPtr.
+        int index = rowPtr[A->indices[j] + 1]++;
+        Bval[index] = Aval[j];
+        B->indices[index] = i;
+      }
+    }
+    memcpy(B->indptr, rowPtr.data(), sizeof(int64_t) * (w + 1));
+  }
+}
+
+MXReturnValue forward(std::map<std::string, std::string> attrs,
+                      std::vector<MXTensor> inputs,
+                      std::vector<MXTensor> outputs,
+                      OpResource res) {
+  // The data types and storage types of inputs and outputs should be the same.  
+  if(inputs[0].dtype != outputs[0].dtype || inputs[0].stype != outputs[0].stype) {
+    std::cout << "Error! Expected all inputs and outputs to be the same type." 
+              << "Found input storage type:" << inputs[0].stype
+              << " Found output storage type:" << outputs[0].stype
+              << " Found input data type:" << inputs[0].dtype
+              << " Found output data type:" << outputs[0].dtype << std::endl;
+    return MX_FAIL;
+  }
+
+  transpose(inputs[0], outputs[0], res);
+  return MX_SUCCESS;
+}
+
+MXReturnValue backward(std::map<std::string, std::string> attrs,
+                       std::vector<MXTensor> inputs,
+                       std::vector<MXTensor> outputs,
+                       OpResource res) {
+  return MX_SUCCESS;
+}
+
+MXReturnValue parseAttrs(std::map<std::string, std::string> attrs, int* num_in, int* num_out) {
+  *num_in = 1;
+  *num_out = 1;
+  return MX_SUCCESS;
+}
+
+MXReturnValue inferType(std::map<std::string, std::string> attrs,
+                        std::vector<int> &intypes,
+                        std::vector<int> &outtypes) {
+  // validate inputs
+  if (intypes.size() != 1) {
+    std::cout << "Expected 1 inputs to inferType" << std::endl;
+    return MX_FAIL;
+  }
+  if (intypes[0] != kFloat32) {
+    std::cout << "Expected input to have float32 type" << std::endl;
+    return MX_FAIL;
+  }
+
+  outtypes[0] = intypes[0];
+  return MX_SUCCESS;
+}
+
+MXReturnValue inferSType(std::map<std::string, std::string> attrs,
+                        std::vector<int> &instypes,
+                        std::vector<int> &outstypes) {
+  if (instypes[0] != kCSRStorage) {
+    std::cout << "Expected storage type is kCSRStorage" << std::endl;
+    return MX_FAIL;
+  }
+  outstypes[0] = instypes[0];
+  return MX_SUCCESS;
+}
+
+MXReturnValue inferShape(std::map<std::string, std::string> attrs,
+                         std::vector<std::vector<unsigned int>> &inshapes,
+                         std::vector<std::vector<unsigned int>> &outshapes) {
+  // validate inputs
+  if (inshapes.size() != 1) {
+    std::cout << "Expected 1 inputs to inferShape" << std::endl;
+    return MX_FAIL;
+  }
+
+  outshapes[0].push_back(inshapes[0][1]);
+  outshapes[0].push_back(inshapes[0][0]);
+  return MX_SUCCESS;
+}
+
+REGISTER_OP(my_transposecsr)
+.setForward(forward, "cpu")
+.setBackward(backward, "cpu")
+.setParseAttrs(parseAttrs)
+.setInferType(inferType)
+.setInferSType(inferSType)
+.setInferShape(inferShape);
+
+/* ------------------------------------------------------------------------- */
+
+class MyStatefulTransposeCSR : public CustomStatefulOp {
+ public:
+  explicit MyStatefulTransposeCSR(int count) : count(count) {}
+
+  MXReturnValue Forward(std::vector<MXTensor> inputs,
+                        std::vector<MXTensor> outputs,
+                        OpResource op_res) {
+    std::cout << "Info: keyword + number of forward: " << ++count << std::endl;
+    std::map<std::string, std::string> attrs;
+    return forward(attrs, inputs, outputs, op_res);
+  }
+
+  MXReturnValue Backward(std::vector<MXTensor> inputs,
+                         std::vector<MXTensor> outputs,
+                         OpResource op_res) {
+    std::map<std::string, std::string> attrs;
+    return backward(attrs, inputs, outputs, op_res);
+  }
+
+ private:
+  int count;
+};
+
+MXReturnValue createOpState(std::map<std::string, std::string> attrs,
+                            CustomStatefulOp** op_inst) {
+  // testing passing of keyword arguments
+  int count = attrs.count("test_kw") > 0 ? std::stoi(attrs["test_kw"]) : 0;
+  // creating stateful operator instance
+  *op_inst = new MyStatefulTransposeCSR(count);
+  std::cout << "Info: stateful operator created" << std::endl;
+  return MX_SUCCESS;
+}
+
+REGISTER_OP(my_state_transposecsr)
+.setParseAttrs(parseAttrs)
+.setInferType(inferType)
+.setInferSType(inferSType)
+.setInferShape(inferShape)
+.setCreateOpState(createOpState, "cpu");
+
+MXReturnValue initialize(int version) {
+  if (version >= 10400) {
+    std::cout << "MXNet version " << version << " supported" << std::endl;
+    return MX_SUCCESS;
+  } else {
+    std::cout << "MXNet version " << version << " not supported" << std::endl;
+    return MX_FAIL;
+  }
+}
diff --git a/example/extensions/lib_custom_op/transposerowsp_lib.cc b/example/extensions/lib_custom_op/transposerowsp_lib.cc
new file mode 100644
index 0000000..883d816
--- /dev/null
+++ b/example/extensions/lib_custom_op/transposerowsp_lib.cc
@@ -0,0 +1,199 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2020 by Contributors
+ * \file transposerowsp_lib.cc
+ * \brief Sample 2D transpose custom operator.
+ */
+
+#include <iostream>
+#include "lib_api.h"
+
+void transpose(MXTensor src, MXTensor dst, OpResource res) {
+  MXSparse* A = src.data<MXSparse>();
+  MXSparse* B = dst.data<MXSparse>(); 
+
+  std::vector<int64_t> shape = src.shape;
+  int64_t h = shape[0];
+  int64_t w = shape[1];
+  if(src.stype == kRowSparseStorage) {
+    // Keys of the map is the row index of transposed tensors.
+    // Values of the map is the rows which have non-zero elements.    
+    std::map<int, std::vector<float>> mp;
+    float *Aval = (float*) (A->data);
+    for(int i = 0; i < A->data_len; i++) {
+      int row = i / w;
+      int col = i % w;
+      row = A->indices[row];
+      if(Aval[i] != 0) {
+        if(mp.find(col) == mp.end()) {
+          mp[col] = std::vector<float>(h, 0);
+          mp[col][row] = Aval[i];
+        }
+        else {
+          mp[col][row] = Aval[i];
+        }
+      }
+    }
+
+    // Alloc memory for output tensors.
+    res.alloc_sparse(B, 0, mp.size());
+    float *Bval = (float*) (B->data);
+    int didx = 0, iidx = 0;
+    for(auto i : mp) {
+      B->indices[iidx++] = i.first;
+      for(auto j : i.second) {
+        Bval[didx++] = j;
+      }
+    }
+  }
+}
+
+MXReturnValue forward(std::map<std::string, std::string> attrs,
+                      std::vector<MXTensor> inputs,
+                      std::vector<MXTensor> outputs,
+                      OpResource res) {
+  // The data types and storage types of inputs and outputs should be the same.
+  if(inputs[0].dtype != outputs[0].dtype || inputs[0].stype != outputs[0].stype) {
+    std::cout << "Error! Expected all inputs and outputs to be the same type."
+              << "Found input storage type:" << inputs[0].stype
+              << " Found output storage type:" << outputs[0].stype
+              << " Found input data type:" << inputs[0].dtype
+              << " Found output data type:" << outputs[0].dtype << std::endl;
+    return MX_FAIL;
+  }
+  transpose(inputs[0], outputs[0], res);
+  return MX_SUCCESS;
+}
+
+MXReturnValue backward(std::map<std::string, std::string> attrs,
+                       std::vector<MXTensor> inputs,
+                       std::vector<MXTensor> outputs,
+                       OpResource res) {
+  return MX_SUCCESS;
+}
+
+MXReturnValue parseAttrs(std::map<std::string, std::string> attrs, int* num_in, int* num_out) {
+  *num_in = 1;
+  *num_out = 1;
+  return MX_SUCCESS;
+}
+
+MXReturnValue inferType(std::map<std::string, std::string> attrs,
+                        std::vector<int> &intypes,
+                        std::vector<int> &outtypes) {
+  // validate inputs
+  if (intypes.size() != 1) {
+    std::cout << "Expected 1 inputs to inferType" << std::endl;
+    return MX_FAIL;
+  }
+  if (intypes[0] != kFloat32) {
+    std::cout << "Expected input to have float32 type" << std::endl;
+    return MX_FAIL;
+  }
+
+  outtypes[0] = intypes[0];
+  return MX_SUCCESS;
+}
+
+MXReturnValue inferSType(std::map<std::string, std::string> attrs,
+                        std::vector<int> &instypes,
+                        std::vector<int> &outstypes) {
+  if (instypes[0] != kRowSparseStorage) {
+    std::cout << "Expected storage type is kRowSparseStorage" << std::endl;
+    return MX_FAIL;
+  }
+  outstypes[0] = instypes[0];
+  return MX_SUCCESS;
+}
+
+MXReturnValue inferShape(std::map<std::string, std::string> attrs,
+                         std::vector<std::vector<unsigned int>> &inshapes,
+                         std::vector<std::vector<unsigned int>> &outshapes) {
+  // validate inputs
+  if (inshapes.size() != 1) {
+    std::cout << "Expected 1 inputs to inferShape" << std::endl;
+    return MX_FAIL;
+  }
+
+  outshapes[0].push_back(inshapes[0][1]);
+  outshapes[0].push_back(inshapes[0][0]);
+  return MX_SUCCESS;
+}
+
+REGISTER_OP(my_transposerowsp)
+.setForward(forward, "cpu")
+.setBackward(backward, "cpu")
+.setParseAttrs(parseAttrs)
+.setInferType(inferType)
+.setInferSType(inferSType)
+.setInferShape(inferShape);
+
+/* ------------------------------------------------------------------------- */
+
+class MyStatefulTransposeRowSP : public CustomStatefulOp {
+ public:
+  explicit MyStatefulTransposeRowSP(int count) : count(count) {}
+
+  MXReturnValue Forward(std::vector<MXTensor> inputs,
+                        std::vector<MXTensor> outputs,
+                        OpResource op_res) {
+    std::cout << "Info: keyword + number of forward: " << ++count << std::endl;
+    std::map<std::string, std::string> attrs;
+    return forward(attrs, inputs, outputs, op_res);
+  }
+
+  MXReturnValue Backward(std::vector<MXTensor> inputs,
+                         std::vector<MXTensor> outputs,
+                         OpResource op_res) {
+    std::map<std::string, std::string> attrs;
+    return backward(attrs, inputs, outputs, op_res);
+  }
+
+ private:
+  int count;
+};
+
+MXReturnValue createOpState(std::map<std::string, std::string> attrs,
+                            CustomStatefulOp** op_inst) {
+  // testing passing of keyword arguments
+  int count = attrs.count("test_kw") > 0 ? std::stoi(attrs["test_kw"]) : 0;
+  // creating stateful operator instance
+  *op_inst = new MyStatefulTransposeRowSP(count);
+  std::cout << "Info: stateful operator created" << std::endl;
+  return MX_SUCCESS;
+}
+
+REGISTER_OP(my_state_transposerowsp)
+.setParseAttrs(parseAttrs)
+.setInferType(inferType)
+.setInferSType(inferSType)
+.setInferShape(inferShape)
+.setCreateOpState(createOpState, "cpu");
+
+MXReturnValue initialize(int version) {
+  if (version >= 10400) {
+    std::cout << "MXNet version " << version << " supported" << std::endl;
+    return MX_SUCCESS;
+  } else {
+    std::cout << "MXNet version " << version << " not supported" << std::endl;
+    return MX_FAIL;
+  }
+}
diff --git a/example/extensions/lib_subgraph/README.md b/example/extensions/lib_subgraph/README.md
index b113be2..83c8236 100644
--- a/example/extensions/lib_subgraph/README.md
+++ b/example/extensions/lib_subgraph/README.md
@@ -53,9 +53,11 @@ You can start getting familiar with custom partitioners by running an example pr
 
 * **lib_subgraph/test_subgraph.py**: This file calls `mx.library.load(‘libsubgraph_lib.so’)` to load the library containing the custom components, partitions the model using the `optimize_for` API, and prints outputs of the forward passes. The outputs should be the same as the regular MXNet forward pass without partitioning.
 
+* **include/mxnet/lib_api.h**: This file from MXNet source code is the single header file needed to include all necessary data types and function prototypes for writing a custom operator library. You can either specify the include path in the `Makefile`, or copy the header file over to `example/extensions/lib_subgraph` folder. Note that apart from this header, the custom operator library is independent of MXNet source.
+
 ## Writing Custom Partitioner Library
 
-For building a library containing your own custom partitioner, compose a C++ source file like `mypart_lib.cc`, include `lib_api.h` header file, and write your custom partitioner with these essential functions:
+To build your own library containing a custom partitioner, compose a C++ source file like `mypart_lib.cc`, include `lib_api.h` header file, and write your custom partitioner with these essential functions:
 - `initialize` - Library Initialization Function
 - `REGISTER_PARTITIONER ` - Partitioner Registration Macro
 - `mySupportedOps ` - Operator Support
@@ -76,34 +78,60 @@ sym, _, _ = mx.model.load_checkpoint('mymodel', 0)
 # Symbol/Module flow
 sym2 = sym.optimize_for("myPart")
 
-# Gluon flow
+# Gluon flow 1
 sym_block = nn.SymbolBlock(sym, inputs)
 sym_block.hybridize(backend='myPart')
+
+# Gluon flow 2
+sym_block = nn.SymbolBlock(sym, inputs)
+sym_block.optimize_for(x, backend='myPart')
 ```
 
+In the Gluon hybridize flow, the model is actually hybridized during the first inference, rather than immediately when calling `hybridize`. This hybridize-based flow is useful if a user expects to run inference immediately after hybridizing. But for users that just want to partition but not run a whole forward pass, the `optimize_for` API combines the hybridize/forward APIs but does not run a forward pass. After calling `optimize_for` users can `export` their model immediately without run [...]
+
 ### Using a Custom Partitioner Library
 
 Partitioning APIs in MXNet are available in both Symbol and Gluon APIs. For the Symbol API, the `optimize_for` API can be called on Symbol objects to return a partitioned Symbol.
 
 ```
-optimize_for(backend, args=None, ctx=None, **kwargs)
+optimize_for(backend, args=None, aux=None, ctx=None, **kwargs)
 ```
 
-The `optimize_for` API takes at least 1 argument, `backend` which is a string that identifies which backend to partition the model for. The `args` argument is optional and takes a list of NDArray or dict of str to NDArray. It is used to infer shapes and types and before partitioning. The `ctx` argument is optional and takes a device context to infer storage types. It also take any other user-specified options that will be passed to the backend partitioning APIs.
+The `optimize_for` API takes at least 1 argument, `backend`, which is a string that identifies which backend to partition the model for. The `args` and `aux` arguments are optional and take a list of NDArray or dict of str to NDArray. They are used to infer shapes and types before partitioning, and are passed to the backend to use during compilation. The `ctx` argument is optional and takes a device context to infer storage types. It also takes any other user-specified options that will b [...]
 
 For the Gluon API, the `hybridize` API can be called on HybridBlocks to partition the internal CachedOp Symbol.
 
 ```
-hybridize(backend=None, backend_opts=None)
+hybridize(backend=None, backend_opts=None, **kwargs)
+```
+
+The `hybridize` function prepares the HybridBlock to be converted into a backend symbol. The `backend` argument is a string that identifies which backend will partition the model. The `backend_opts` argument takes other user-specified options that will be passed to the backend partitioning APIs. The actual partitioning takes place during the forward pass.
+
+If you just want to partition the HybridBlock but not run a complete forward pass, you can use the `optimize_for` API that combines the work done in the `hybridize` API with part of the work done in the forward pass.
+
+```
+optimize_for(x, backend=None, backend_opts=None, **kwargs)
+```
+
+When the `optimize_for` API is called on a HybridBlock it partitions immediately. This lets users export the partitioned model without running a complete forward pass.
+
+```
+block.optimize_for(x, backend='myPart')
+block.export('partitioned')
 ```
 
-When the `hybridize` function is called, Gluon will convert the program’s execution into the style used in symbolic programming. The `backend` argument is a string that identifies which backend to partition the model for. The `backend_opts` takes other user-specified options that will be passed to the backend partitioning APIs.
+But you can also use `optimize_for` in place of `hybridize` and run inference immediately after too.
+
+```
+block.optimize_for(x, backend='myPart')
+block(x)
+```
 
 ### Writing A Custom Partitioner
 
 There are several essential building blocks for making a custom partitioner:
 
-* [initialize](./subgraph_lib.cc#L242):
+* [initialize](./subgraph_lib.cc#L261):
     * This function is the library initialization function necessary for any dynamic libraries. It lets you check if the user is using a compatible version of MXNet. Note that this `version` parameter is passed from MXNet when library is loaded.
 
             MXReturnValue initialize(int version)
@@ -116,40 +144,37 @@ There are several essential building blocks for making a custom partitioner:
                 std::vector<bool>& ids,
                 std::unordered_map<std::string, std::string>& options)
 
-* [REGISTER_PARTITIONER(my_part_name)](./subgraph_lib.cc#L238):
+* [REGISTER_PARTITIONER(my_part_name)](./subgraph_lib.cc#L257):
     * This macro registers the custom partitioner and its properties to MXNet by its name. Notice that a partitioner can have multiple partitioning strategies. This enables multiple *passes* to be run in a single partitioning call from the user. The first argument to `addStrategy` is a user-specified name. The second argument is the `supportedOps` function. The third argument is the name of the subgraph operator to create for each subgraph created during partitioning (see below for more  [...]
 
             REGISTER_PARTITIONER(my_part_name)
-            .addStrategy("strategy1", 
-                          supportedOps, 
-                          "_custom_subgraph_op")
-            .setReviewSubgraph("strategy1", 
-                                reviewSubgraph);
+            .addStrategy("strategy1", supportedOps, "_custom_subgraph_op")
+            .setReviewSubgraph("strategy1", reviewSubgraph);
 
 
 Also there are some optional functions you can specify:
 
-* [reviewSubgraph](./subgraph_lib.cc#L220):
+* [reviewSubgraph](./subgraph_lib.cc#L219):
     * This function provides an opportunity to accept/reject a subgraph after MXNet partitions it. It also allows specifying custom attributes on the subgraph (ie. user-generated IDs). If you do not register this function, subgraphs will be accepted by default. 
 
             MXReturnValue reviewSubgraph(
                 std::string json,
-                int subraph_id,
+                int subgraph_id,
                 bool* accept,
-                std::unordered_map<std::string, 
-                                   std::string>& options,
-                std::unordered_map<std::string, 
-                                   std::string>& attrs)
+                std::unordered_map<std::string, std::string>& options,
+                std::unordered_map<std::string, std::string>& attrs,
+                std::map<std::string, MXTensor>& args,
+                std::map<std::string, MXTensor>& aux)
 
 Let’s take a closer look at those registry functions:
 
-* **supportedOps**: This function takes four arguments. The 1st argument is a JSON string of the model architecture graph, where nodes are inputs/params/weights and edges are data dependencies. The graph is pre-sorted in topological order. The 2nd argument is an array of booleans, one for each operator in the model. When traversing the graph, operators to be partitioned into subgraphs are identified and an entry is set to `true` for the node ID in the `ids` array. The last argument is th [...]
+* **supportedOps**: This function takes four arguments. The 1st argument is a JSON string of the model architecture graph, where nodes are inputs/params/weights and edges are data dependencies. The graph is pre-sorted in topological order. The 2nd argument is an array of booleans, one for each operator in the model. When traversing the graph, operators to be partitioned into subgraphs are identified and an entry is set to `true` for the index in the `ids` array corresponding to the node  [...]
 
-* **reviewSubgraph**: This function takes five arguments. The 1st argument is a JSON string of the newly partitioned subgraph. The 2nd argument is the subgraph ID, this is just a number MXNet uses to identify this particular subgraph (it starts at zero and increments). The 3rd argument is an output to be set in this function to tell MXNet whether to accept (value: `true`) or reject (value: `false`) the subgraph. The 4th argument is the map of options specified by the user. The last argum [...]
+* **reviewSubgraph**: This function takes five arguments. The 1st argument is a JSON string of the newly partitioned subgraph. The 2nd argument is the subgraph ID; this is just a number MXNet uses to identify this particular subgraph (it starts at zero and increments, and is unique for each subgraph in the model). The 3rd argument is an output to be set in this function to tell MXNet whether to accept (value: `true`) or reject (value: `false`) the subgraph. You might want to reject a subgraph i [...]
 
 ### Writing A Custom Subgraph Operator
 
-A partitioning strategy specifies how to partition a model and isolate operators into subgraphs. In MXNet, subgraphs are just a [stateful operator](../lib_custom_op#writing-stateful-custom-operator). Subgraph operators have an extra attribute called `SUBGRAPH_SYM_JSON` that maps to a JSON string of the subgraph. The expectation is that when a subgraph operator executes a forward/backward call, it executes all of the operators in the subgraph. 
+A partitioning strategy specifies how to partition a model and isolate operators into subgraphs. In MXNet, subgraphs are just a [stateful operator](../lib_custom_op#writing-stateful-custom-operator). Subgraph operators have an extra attribute called `MX_STR_SUBGRAPH_SYM_JSON` that maps to a JSON string of the subgraph. The expectation is that when a subgraph operator executes a forward/backward call, it executes all of the operators in the subgraph. 
 
 When registering a custom subgraph operator, all thats needed is to register a `createOpState` function and to set that the operator is a subgraph operator by calling the `setIsSubgraphOp` API like:
 
diff --git a/example/extensions/lib_subgraph/subgraph_lib.cc b/example/extensions/lib_subgraph/subgraph_lib.cc
index da888fd..d821bdb 100644
--- a/example/extensions/lib_subgraph/subgraph_lib.cc
+++ b/example/extensions/lib_subgraph/subgraph_lib.cc
@@ -84,7 +84,7 @@ MXReturnValue myExecutor(std::vector<MXTensor> inputs,
       // get input tensor based on node ID inputs from data storage
       MXTensor &input = data[node_inputs.list[0].list[0].num];
       // create temporary storage
-      MXTensor tmp(malloc(input.size()*4), input.shape, input.dtype, 0, {"cpu", 0});
+      MXTensor tmp(malloc(input.size()*4), input.shape, input.dtype, 0, {"cpu", 0}, kDefaultStorage);
       // save allocated ptr to free later
       to_free.push_back(tmp.data_ptr);
       // execute log operator
@@ -95,7 +95,7 @@ MXReturnValue myExecutor(std::vector<MXTensor> inputs,
       // get input tensor based on node ID inputs from data storage
       MXTensor &input = data[node_inputs.list[0].list[0].num];
       // create temporary storage
-      MXTensor tmp(malloc(input.size()*4), input.shape, input.dtype, 0, {"cpu", 0});
+      MXTensor tmp(malloc(input.size()*4), input.shape, input.dtype, 0, {"cpu", 0}, kDefaultStorage);
       // save allocated ptr to free later
       to_free.push_back(tmp.data_ptr);
       // execute exp operator 
@@ -160,11 +160,11 @@ MXReturnValue createOpState(std::map<std::string, std::string> attrs,
   std::string serialized_subgraph = "[empty]";
   // MXNet subgraph is stored as Symbol in operator node attrs subgraphs field
   // custom subgraph is stored as json string in custom operator attrs map entry
-  if (attrs.count(SUBGRAPH_SYM_JSON)) {
+  if (attrs.count(MX_STR_SUBGRAPH_SYM_JSON)) {
     // user can now parse json and run other custom ops inside subgraph
-    serialized_subgraph = attrs[SUBGRAPH_SYM_JSON];
+    serialized_subgraph = attrs[MX_STR_SUBGRAPH_SYM_JSON];
   }
-  attrs.erase(SUBGRAPH_SYM_JSON);
+  attrs.erase(MX_STR_SUBGRAPH_SYM_JSON);
   *op_inst = new MyStatefulOp(serialized_subgraph, attrs);
   std::cout << "Info: stateful operator created" << std::endl;
   return MX_SUCCESS;
@@ -177,7 +177,7 @@ REGISTER_OP(_custom_subgraph_op)
 const std::vector<std::string> op_names({"exp","log"});
 
 MXReturnValue mySupportedOps(std::string json,
-                             std::vector<bool> ids,
+                             std::vector<bool>& ids,
                              std::unordered_map<std::string, std::string>& options) {
   for (auto kv : options) {
     std::cout << "option: " << kv.first << " ==> " << kv.second << std::endl;
@@ -204,8 +204,8 @@ MXReturnValue mySupportedOps(std::string json,
         dtype = std::stoi(attrs.map[JsonVal("dtype")].str);
     }
 
-    //check if op dtype is float
-    if(dtype == kFloat32) {
+    //check if op dtype is float, and if option was specified to require float types
+    if((dtype == kFloat32 && options.count("reqFloat") > 0) || options.count("reqFloat") == 0) {
       //check if op is in whitelist
       if(std::find(op_names.begin(),op_names.end(),op.str.c_str()) != op_names.end()) {
         // found op in whitelist, set value to 1 to include op in subgraph
@@ -216,14 +216,34 @@ MXReturnValue mySupportedOps(std::string json,
   return MX_SUCCESS;
 }
 
-MXReturnValue myReviewSubgraph(std::string json, int subraph_id, bool* accept,
+MXReturnValue myReviewSubgraph(std::string json, int subgraph_id, bool* accept,
                                std::unordered_map<std::string, std::string>& options,
-                               std::unordered_map<std::string, std::string>& attrs) {
+                               std::unordered_map<std::string, std::string>& attrs,
+                               std::map<std::string, MXTensor>& args,
+                               std::map<std::string, MXTensor>& aux) {
   for (auto kv : options) {
     std::cout << "option: " << kv.first << " ==> " << kv.second << std::endl;
   }
-  if(options.find("reject") != options.end() &&
-     options["reject"].compare("True") == 0) {
+  for (auto kv : args) {
+    std::cout << "arg: " << kv.first << " ==> (";
+    for (auto s : kv.second.shape)
+      std::cout << s << ",";
+    std::cout << ") [";
+    for (int i=0; i<kv.second.size(); i++)
+      std::cout << kv.second.data<float>()[i] << ", ";
+    std::cout << "]" << std::endl;
+  }
+
+  // check if option `reqArgs` was specified, and if so check if args were provided
+  if(options.count("reqArgs") > 0 && args.size() == 0) {
+    *accept = false;
+    std::cout << "rejecting subgraph since args were not provided" << std::endl;
+    return MX_SUCCESS;
+  }
+
+  // check if option `reject` was specified, and if so check if value is 'True'
+  if(options.count("reject") > 0 && options["reject"].compare("True") == 0) {
+    // if specified, reject the subgraph. this is only used for testing
     *accept = false;
     std::cout << "rejecting subgraph" << std::endl;
   } else {
@@ -231,7 +251,6 @@ MXReturnValue myReviewSubgraph(std::string json, int subraph_id, bool* accept,
     std::cout << "accepting subgraph" << std::endl;
     attrs["myKey"] = "myVal";
   }
-  std::cout << json << std::endl;
   return MX_SUCCESS;
 }
 
diff --git a/example/extensions/lib_subgraph/test_subgraph.py b/example/extensions/lib_subgraph/test_subgraph.py
index 1bcecae..55a4051 100644
--- a/example/extensions/lib_subgraph/test_subgraph.py
+++ b/example/extensions/lib_subgraph/test_subgraph.py
@@ -23,8 +23,10 @@
 # This test checks if dynamic loading of library into MXNet is successful
 # and checks the end of end computation of custom operator
 
-import mxnet as mx
 import os, ctypes
+import mxnet as mx
+from mxnet.gluon import nn
+from mxnet import nd
 from mxnet.base import _LIB, check_call, mx_uint, c_str, c_str_array, SymbolHandle
 
 # load library
@@ -35,6 +37,10 @@ elif (os.name=='nt'):
     path = os.path.abspath('libsubgraph_lib.dll')
     mx.library.load(path)
 
+###############################################
+# Test with subgraph not consuming params
+###############################################
+# example model, ops to be partitioned do not have args (use outputs from other ops as inputs)
 a = mx.sym.var('a')
 b = mx.sym.var('b')
 c = a + b
@@ -75,9 +81,6 @@ exe3 = mysym3.bind(ctx=mx.cpu(), args={'a':mx.nd.ones((3,2)), 'b':mx.nd.ones((3,
 out3 = exe3.forward()
 print(out3)
 
-from mxnet.gluon import nn
-from mxnet import nd
-
 # Gluon Hybridize partitioning with shapes/types
 print('-------------------------------')
 print('Testing Gluon Hybridize partitioning with shapes/types')
@@ -88,3 +91,54 @@ sym_block.hybridize(backend='myProp')
 out4 = sym_block(mx.nd.ones((3,2)),mx.nd.ones((3,2)))
 print(out4)
 
+# Gluon Hybridize partitioning with shapes/types without inference
+print('-------------------------------')
+print('Testing Gluon Hybridize partitioning with shapes/types without inference')
+inputs = [a,b]
+sym_block2 = nn.SymbolBlock(sym, inputs)
+sym_block2.initialize()
+sym_block2.optimize_for(mx.nd.ones((3,2)), mx.nd.ones((3,2)), backend='myProp')
+sym_block2.export('partitioned')
+
+
+###############################################
+# Test with subgraph directly consuming params
+###############################################
+# example model, ops to be partitioned have args
+d2 = mx.sym.exp(a)
+sym2 = mx.sym.log(d2)
+
+#execute in MXNet
+print('-------------------------------')
+print('Testing regular MXNet execution')
+exe5 = sym2.bind(ctx=mx.cpu(), args={'a':mx.nd.ones((3,2))})
+out5 = exe5.forward()
+print(out5)
+
+# with propogating shapes/types
+print('-------------------------------')
+print('Testing partitioning with shapes/types')
+arg_array = [mx.nd.ones((3,2),dtype='float32')]
+mysym6 = sym2.optimize_for("myProp", arg_array, reqArgs=True)
+print(mysym6.tojson())
+exe6 = mysym6.bind(ctx=mx.cpu(), args={'a':mx.nd.ones((3,2))})
+out6 = exe6.forward()
+print(out6)
+
+# without propogating shapes/types
+print('-------------------------------')
+print('Testing partitioning without shapes/types')
+mysym7 = sym2.optimize_for("myProp", reqArgs=True)
+exe7 = mysym7.bind(ctx=mx.cpu(), args={'a':mx.nd.ones((3,2))})
+out7 = exe7.forward()
+print(out7)
+
+# Gluon Hybridize partitioning with shapes/types
+print('-------------------------------')
+print('Testing Gluon Hybridize partitioning with shapes/types')
+inputs = [a]
+sym2_block = nn.SymbolBlock(sym2, inputs)
+sym2_block.initialize()
+sym2_block.hybridize(backend='myProp')
+out8 = sym2_block(mx.nd.ones((3,2)))
+print(out8)
diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h
index bb2a568..e3d9062 100644
--- a/include/mxnet/c_api.h
+++ b/include/mxnet/c_api.h
@@ -2170,8 +2170,10 @@ MXNET_DLL int MXOptimizeForBackend(SymbolHandle sym_handle,
                                    const char* backend_name,
                                    const int dev_type,
                                    SymbolHandle* ret_sym_handle,
-                                   const mx_uint len,
+                                   const mx_uint args_len,
                                    NDArrayHandle* in_args_handle,
+                                   const mx_uint aux_len,
+                                   NDArrayHandle* in_aux_handle,
                                    const mx_uint num_options,
                                    const char** keys,
                                    const char** vals);
diff --git a/include/mxnet/lib_api.h b/include/mxnet/lib_api.h
index d59f2b1..c793a30 100644
--- a/include/mxnet/lib_api.h
+++ b/include/mxnet/lib_api.h
@@ -38,8 +38,14 @@
 #include <iostream>
 #include <utility>
 #include <stdexcept>
+#include <random>
 
-#define MX_LIBRARY_VERSION 3
+#if defined(__NVCC__)
+  #include <curand_kernel.h>
+#endif
+
+/* Make sure to update the version number every time you make changes */
+#define MX_LIBRARY_VERSION 6
 
 /*!
  * \brief For loading multiple custom op libraries in Linux, exporting same symbol multiple
@@ -214,6 +220,18 @@ enum MXDType {
   kUNSET = 100,
 };
 
+/*
+ * MXTensor storage type.
+ */
+enum MXStorageType {
+  // dense
+  kDefaultStorage = 0,
+  // row sparse
+  kRowSparseStorage = 1,
+  // csr
+  kCSRStorage = 2,
+};
+
 /*!
  * \brief Context info passing from MXNet OpContext
  * dev_type is string repr of supported context, currently only "cpu" and "gpu"
@@ -229,20 +247,64 @@ enum MXReturnValue {
   MX_SUCCESS = 1,
 };
 
+// For sparse tensors, read/write the data from the NDArray via pointers.
+struct MXSparse {
+  // Pointer to data.
+  void *data{nullptr};
+  // length of (non-zero) data.
+  int64_t data_len;
+
+  // To store aux data for sparse.
+  // For CSR, indices stores the column indices of the non-zero elements.
+  // For row sparse, indices stores the row indices of the rows that have non-zero elements.
+  int64_t* indices;
+  int64_t indices_len;
+
+  // For CSR, indptr gives the start and end index of data for each row.
+  // For row sparse, indptr is not used.
+  int64_t* indptr = nullptr;
+  int64_t indptr_len;
+
+  void set(void *data_ptr, const int64_t* dims, int ndims, void *idx,
+          int64_t num_idx, void *idx_ptr = nullptr, int64_t num_idx_ptr = 0) {
+    data = data_ptr;
+    // If CSR, the number of non-zero elements is num_idx;
+    // if row sparse, the number of elements is num_idx * width.
+    data_len = num_idx;
+    if (!idx_ptr) {
+      for (int i = 1; i < ndims; ++i)
+         data_len *= dims[i];
+    }
+
+    indices = reinterpret_cast<int64_t*>(idx);
+    indices_len = num_idx;
+
+    if (idx_ptr) {
+      indptr = reinterpret_cast<int64_t*>(idx_ptr);
+      indptr_len = num_idx_ptr;
+    }
+  }
+};
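
A minimal sketch of how a custom-operator library might walk a CSR input laid out as
above (the helper name printCSR and the float32 assumption are illustrative, not part
of the library):

    #include "lib_api.h"

    // Hypothetical helper: print the non-zero elements of a CSR-formatted MXTensor.
    // Assumes the tensor holds float32 data.
    void printCSR(const MXTensor &in) {
      if (in.stype != kCSRStorage) return;
      MXSparse *csr = reinterpret_cast<MXSparse*>(in.data_ptr);
      float *val = reinterpret_cast<float*>(csr->data);
      for (int64_t i = 0; i < in.shape[0]; i++) {
        // row i owns the value range [indptr[i], indptr[i+1])
        for (int64_t j = csr->indptr[i]; j < csr->indptr[i+1]; j++)
          std::cout << "(" << i << "," << csr->indices[j] << ") = " << val[j] << std::endl;
      }
    }
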
+
 /*!
  * \brief Tensor data structure used by custom operator
  */
 struct MXTensor {
-  MXTensor() : data_ptr(nullptr), dtype(kUNSET), verID(0) {}
-
+  MXTensor() : data_ptr(nullptr), dtype(kUNSET), verID(0), stype(kDefaultStorage) {}
+  MXTensor(const MXTensor& oth) : data_ptr(oth.data_ptr), shape(oth.shape),
+    dtype(oth.dtype), verID(oth.verID), ctx(oth.ctx), stype(oth.stype) {
+    setDLTensor();
+  }
   MXTensor(void *data_ptr, const std::vector<int64_t> &shape, MXDType dtype,
-           size_t vID, MXContext mx_ctx)
-  : data_ptr(data_ptr), shape(shape), dtype(dtype), verID(vID), ctx(mx_ctx) {}
+           size_t vID, MXContext mx_ctx, MXStorageType stype = kDefaultStorage)
+  : data_ptr(data_ptr), shape(shape), dtype(dtype), verID(vID), ctx(mx_ctx), stype(stype) {
+    setDLTensor();
+  }
 
   /*! \brief populate internal tensor fields */
   void setTensor(void *dptr, MXDType type, const int64_t* dims, int ndims,
-                 size_t vID, MXContext mx_ctx) {
-    data_ptr = dptr; dtype = type; verID = vID; ctx = mx_ctx;
+                 size_t vID, MXContext mx_ctx, MXStorageType storage_type) {
+    data_ptr = dptr; dtype = type; verID = vID; ctx = mx_ctx; stype = storage_type;
     shape.clear();
     for (int j = 0; j < ndims; j++) {
       shape.push_back(dims[j]);
@@ -335,11 +397,12 @@ struct MXTensor {
            verID == oth.verID &&
            ctx.dev_type == oth.ctx.dev_type &&
            ctx.dev_id == oth.ctx.dev_id &&
-           shape == oth.shape;
+           shape == oth.shape &&
+           stype == oth.stype;
   }
 
-  // data is flatten 1D repr of tensor, elements are in continuous memory
-  // user can access each element using the shape of tensor
+  // For dense, data_ptr points to 1D flattened tensor data
+  // For sparse, data_ptr points to MXSparse
   void *data_ptr;
 
   // shape is in [2,3,4] format to represent high-dim tensor
@@ -357,16 +420,29 @@ struct MXTensor {
   // corresponding DLTensor repr of MXTensor
   // easy way to reuse functions taking DLTensor
   DLTensor dltensor;
+
+  // storage type
+  MXStorageType stype;
 };
 
 /*! \brief resource malloc function to allocate memory inside Forward/Backward functions */
 typedef void* (*xpu_malloc_t)(void*, int);
 
+typedef void (*sparse_malloc_t)(void*, int, int, int, void**, int64_t**, int64_t**);
+
 #if defined(__NVCC__)
   typedef cudaStream_t mx_stream_t;
+  typedef curandStatePhilox4_32_10_t mx_gpu_rand_t;
 #else
   typedef void* mx_stream_t;
+  typedef void* mx_gpu_rand_t;
 #endif
+typedef std::mt19937 mx_cpu_rand_t;
+
+/*! \brief MXNet initialized random states for each device, used for parallelism */
+/* Each thread should generate a unique random number sequence from a different state */
+#define MX_NUM_CPU_RANDOM_STATES 1024
+#define MX_NUM_GPU_RANDOM_STATES 32768
 
 /*!
  * \brief provide resource APIs memory allocation mechanism to Forward/Backward functions
@@ -374,9 +450,13 @@ typedef void* (*xpu_malloc_t)(void*, int);
 class OpResource {
  public:
   OpResource(xpu_malloc_t cpu_malloc_fp, void* cpu_alloc_fp,
-             xpu_malloc_t gpu_malloc_fp, void* gpu_alloc_fp, void* stream)
+             xpu_malloc_t gpu_malloc_fp, void* gpu_alloc_fp, void* stream,
+             sparse_malloc_t sparse_malloc_fp, void* sparse_alloc_fp,
+             void* rng_cpu_states, void* rng_gpu_states)
     : cpu_malloc(cpu_malloc_fp), gpu_malloc(gpu_malloc_fp),
-      cpu_alloc(cpu_alloc_fp), gpu_alloc(gpu_alloc_fp), cuda_stream(stream) {}
+      cpu_alloc(cpu_alloc_fp), gpu_alloc(gpu_alloc_fp), cuda_stream(stream),
+      sparse_malloc(sparse_malloc_fp), sparse_alloc(sparse_alloc_fp),
+      rand_cpu_states(rng_cpu_states), rand_gpu_states(rng_gpu_states) {}
 
   /*! \brief allocate cpu memory controlled by MXNet */
   void* alloc_cpu(int size) {
@@ -393,6 +473,25 @@ class OpResource {
     return static_cast<mx_stream_t>(cuda_stream);
   }
 
+  /*! \brief allocate sparse memory controlled by MXNet */
+  void alloc_sparse(MXSparse* sparse, int index, int indices_len, int indptr_len = 0) {
+    sparse_malloc(sparse_alloc, index, indices_len, indptr_len,
+                   &(sparse->data), &(sparse->indices), &(sparse->indptr));
+  }
+
+  /*! \brief get pointer to initialized and seeded random number states located on CPU */
+  /* Access each state as states[id]; id must be less than MX_NUM_CPU_RANDOM_STATES */
+  mx_cpu_rand_t* get_cpu_rand_states() {
+    return static_cast<mx_cpu_rand_t*>(rand_cpu_states);
+  }
+
+  /*! \brief get pointer to initialized and seeded random number states located on GPU */
+  /* Access each state as states[id]; id must be less than MX_NUM_GPU_RANDOM_STATES */
+  /* Note that a CPU-only build returns a nullptr */
+  mx_gpu_rand_t* get_gpu_rand_states() {
+    return static_cast<mx_gpu_rand_t*>(rand_gpu_states);
+  }
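
A rough sketch of how a Forward function could use these new resources: allocate a
row-sparse output through alloc_sparse and fill it from one of MXNet's pre-seeded CPU
random states. The function name, the output index 0, the value of nnz, and the
single-column (width 1) output are assumptions for illustration only:

    // Hypothetical fragment: "out" is outputs[0] (row sparse), "res" is the OpResource.
    void fillRowSparse(MXTensor &out, OpResource &res) {
      MXSparse *out_sparse = reinterpret_cast<MXSparse*>(out.data_ptr);
      int nnz = 3;  // number of non-zero rows to store
      res.alloc_sparse(out_sparse, 0 /*output index*/, nnz);
      // out_sparse->data and out_sparse->indices now point to MXNet-managed memory
      float *val = reinterpret_cast<float*>(out_sparse->data);

      mx_cpu_rand_t *states = res.get_cpu_rand_states();
      std::uniform_real_distribution<float> dist(0.0f, 1.0f);
      for (int i = 0; i < nnz; i++) {
        out_sparse->indices[i] = i;   // which rows are stored
        val[i] = dist(states[0]);     // id must be less than MX_NUM_CPU_RANDOM_STATES
      }
    }
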
+
  private:
   /*! \brief allocation lambda function */
   xpu_malloc_t cpu_malloc, gpu_malloc;
@@ -400,13 +499,47 @@ class OpResource {
   void *cpu_alloc, *gpu_alloc;
   /*! \brief cuda stream passed from MXNet */
   void *cuda_stream;
+  /*! \brief sparse allocation lambda function */
+  sparse_malloc_t sparse_malloc;
+  /*! \brief lambda function to return allocated sparse memory handle */
+  void *sparse_alloc;
+  /*! \brief cpu and gpu rng fully inited and seeded states */
+  void *rand_cpu_states, *rand_gpu_states;
 };
 
 /*!
  * \brief Json utility to parse serialized subgraph symbol
  */
 /*! \brief Macro to help passing serialized subgraph through attribute dict */
-#define SUBGRAPH_SYM_JSON "subgraph_sym_json"
+#define MX_STR_SUBGRAPH_SYM_JSON "subgraph_sym_json"
+#define MX_STR_DTYPE "__dtype__"
+#define MX_STR_SHAPE "__shape__"
+
+/* \brief get shape value from list of shapes string
+ * format: [[1]] or [[1],[2]]
+ */
+std::string getShapeAt(const std::string& shape, unsigned index) {
+  int idx = 1;  // start at 1 to skip the first square bracket [
+  // find the beginning of the output shape for the particular output index
+  for (unsigned x=0; x < index; x++)
+    idx = shape.find("[", idx+1);
+  int stop = shape.find("]", idx);  // find stop index for this output shape
+  // add this shape to the list
+  return shape.substr(idx, stop-idx+1);
+}
+
+/* \brief get dtype value from list of dtypes string
+ * format: [1] or [1,2]
+ */
+std::string getDtypeAt(const std::string& dtype, unsigned index) {
+  // find the beginning of the output dtype for the particular output index
+  int idx = 0;
+  for (unsigned x=0; x < index; x++)
+    idx = dtype.find(",", idx+1);
+  int stop = dtype.find(",", idx+1);  // find stop index for this output dtype
+  if (stop == -1) stop = dtype.find("]", idx+1);
+  return dtype.substr(idx+1, stop-idx-1);
+}
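
For reference, the list-formatted strings that getShapeAt and getDtypeAt parse look
like the following (the literal values are illustrative, e.g. two float32 inputs of
shape (3,2)):

    std::string shapes = "[[3,2],[3,2]]";        // e.g. the value of attrs[MX_STR_SHAPE]
    std::string dtypes = "[0,0]";                // e.g. the value of attrs[MX_STR_DTYPE]
    std::string shape0 = getShapeAt(shapes, 0);  // -> "[3,2]"
    std::string dtype1 = getDtypeAt(dtypes, 1);  // -> "0"
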
 
 /*! \brief Types of JSON objects */
 enum JsonType {ERR, STR, NUM, LIST, MAP};
@@ -614,6 +747,8 @@ typedef MXReturnValue (*parseAttrs_t)(std::map<std::string, std::string>,
                                       int*, int*);
 typedef MXReturnValue (*inferType_t)(std::map<std::string, std::string>,
                                      std::vector<int>&, std::vector<int>&);
+typedef MXReturnValue (*inferSType_t)(std::map<std::string, std::string>,
+                                     std::vector<int>&, std::vector<int>&);
 typedef MXReturnValue (*inferShape_t)(std::map<std::string, std::string>,
                                       std::vector<std::vector<unsigned int> >&,
                                       std::vector<std::vector<unsigned int> >&);
@@ -627,9 +762,9 @@ typedef MXReturnValue (*createOpState_t)(std::map<std::string, std::string>,
  */
 class CustomOp {
  public:
-  explicit CustomOp(const char* op_name) :
-      name(op_name), parse_attrs(nullptr), infer_type(nullptr),
-      infer_shape(nullptr), mutate_inputs(nullptr), isSGop(false) {}
+  explicit CustomOp(const char* op_name) : name(op_name),
+    parse_attrs(NULL), infer_type(NULL), infer_storage_type(NULL), infer_shape(NULL),
+    mutate_inputs(NULL), isSGop(false) {}
   CustomOp& setForward(fcomp_t fcomp, const char* ctx) {
     if (forward_ctx_map.count(ctx) > 0)
       raiseDuplicateContextError();
@@ -650,6 +785,10 @@ class CustomOp {
     infer_type = func;
     return *this;
   }
+  CustomOp& setInferSType(inferSType_t func) {
+    infer_storage_type = func;
+    return *this;
+  }
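
As an illustration, a minimal inferSType that a library could register through
setInferSType; it simply propagates the first input's storage type to the output
(the op name my_transpose is hypothetical):

    // storage type values follow MXStorageType: 0 = dense, 1 = row sparse, 2 = CSR
    MXReturnValue inferSType(std::map<std::string, std::string> attrs,
                             std::vector<int> &instypes,
                             std::vector<int> &outstypes) {
      outstypes[0] = instypes[0];
      return MX_SUCCESS;
    }
    // registered alongside the op's other functions, e.g.
    //   REGISTER_OP(my_transpose).setInferSType(inferSType) ...
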
   CustomOp& setInferShape(inferShape_t func) {
     infer_shape = func;
     return *this;
@@ -690,6 +829,7 @@ class CustomOp {
   /*! \brief operator functions */
   parseAttrs_t parse_attrs;
   inferType_t infer_type;
+  inferSType_t infer_storage_type;
   inferShape_t infer_shape;
   mutateInputs_t mutate_inputs;
   bool isSGop;
@@ -713,11 +853,13 @@ class CustomOp {
 };
 
 /*! \brief Custom Subgraph Create function template */
-typedef MXReturnValue (*supportedOps_t)(std::string, std::vector<bool>,
+typedef MXReturnValue (*supportedOps_t)(std::string, std::vector<bool>&,
                                         std::unordered_map<std::string, std::string>&);
 typedef MXReturnValue (*reviewSubgraph_t)(std::string, int, bool*,
                                           std::unordered_map<std::string, std::string>&,
-                                          std::unordered_map<std::string, std::string>&);
+                                          std::unordered_map<std::string, std::string>&,
+                                          std::map<std::string, MXTensor>&,
+                                          std::map<std::string, MXTensor>&);
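
To make the expanded reviewSubgraph signature concrete, a minimal sketch of a
library-side implementation that accepts every candidate subgraph and inspects the
args/aux tensors now passed down (the attribute key myKey is purely illustrative):

    MXReturnValue reviewSubgraph(std::string json, int subgraph_id, bool* accept,
                                 std::unordered_map<std::string, std::string>& opts,
                                 std::unordered_map<std::string, std::string>& attrs,
                                 std::map<std::string, MXTensor>& args,
                                 std::map<std::string, MXTensor>& aux) {
      for (auto &kv : args)   // params consumed by this subgraph, keyed by name
        std::cout << "arg: " << kv.first << " ndims: " << kv.second.shape.size() << std::endl;
      attrs["myKey"] = "myValue";  // anything set here is placed on the subgraph op node
      *accept = true;              // keep the candidate subgraph
      return MX_SUCCESS;
    }
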
 
 /*!
  * \brief An abstract class for subgraph property
@@ -841,7 +983,7 @@ typedef int (*opRegGet_t)(int idx, const char** name, int *isSGop,
                           const char*** backward_ctx, fcomp_t** backward_fp, int* backward_count,
                           const char*** create_op_ctx, createOpState_t** create_op_fp,
                           int* create_op_count,
-                          parseAttrs_t* parse, inferType_t* type,
+                          parseAttrs_t* parse, inferType_t* type, inferSType_t* stype,
                           inferShape_t* shape, mutateInputs_t* mutate);
 
 #define MXLIB_OPCALLFREE_STR "_opCallFree"
@@ -863,6 +1005,11 @@ typedef int (*opCallInferType_t)(inferType_t inferType, const char* const* keys,
                                  const char* const* vals, int num,
                                  int* intypes, int num_in, int* outtypes, int num_out);
 
+#define MXLIB_OPCALLINFERSTYPE_STR "_opCallInferSType"
+typedef int (*opCallInferSType_t)(inferSType_t inferSType, const char* const* keys,
+                                 const char* const* vals, int num,
+                                 int* intypes, int num_in, int* outtypes, int num_out);
+
 #define MXLIB_OPCALLFCOMP_STR "_opCallFCompute"
 typedef int (*opCallFComp_t)(fcomp_t fcomp, const char* const* keys,
                              const char* const* vals, int num,
@@ -875,7 +1022,14 @@ typedef int (*opCallFComp_t)(fcomp_t fcomp, const char* const* keys,
                              size_t* outIDs, const char** outdev_type,
                              int* outdev_id, int num_out,
                              xpu_malloc_t cpu_malloc, void* cpu_alloc,
-                             xpu_malloc_t gpu_malloc, void* gpu_alloc, void* cuda_stream);
+                             xpu_malloc_t gpu_malloc, void* gpu_alloc, void* cuda_stream,
+                             sparse_malloc_t sparse_malloc, void* sparse_alloc,
+                             int* instypes, int* outstypes,
+                             void** in_indices, void** out_indices,
+                             void** in_indptr, void** out_indptr,
+                             int64_t* in_indices_shapes, int64_t* out_indices_shapes,
+                             int64_t* in_indptr_shapes, int64_t* out_indptr_shapes,
+                             void* rng_cpu_states, void* rng_gpu_states);
 
 #define MXLIB_OPCALLMUTATEINPUTS_STR "_opCallMutateInputs"
 typedef int (*opCallMutateInputs_t)(mutateInputs_t mutate, const char* const* keys,
@@ -898,7 +1052,14 @@ typedef int (*opCallFStatefulComp_t)(int is_forward, void* state_op,
                                      size_t* outIDs, const char** outdev_type,
                                      int* outdev_id, int num_out,
                                      xpu_malloc_t cpu_malloc, void* cpu_alloc,
-                                     xpu_malloc_t gpu_malloc, void* gpu_alloc, void* stream);
+                                     xpu_malloc_t gpu_malloc, void* gpu_alloc, void* stream,
+                                     sparse_malloc_t sparse_malloc, void* sparse_alloc,
+                                     int* instypes, int* outstypes,
+                                     void** in_indices, void** out_indices,
+                                     void** in_indptr, void** out_indptr,
+                                     int64_t* in_indices_shapes, int64_t* out_indices_shapes,
+                                     int64_t* in_indptr_shapes, int64_t* out_indptr_shapes,
+                                     void* rng_cpu_states, void* rng_gpu_states);
 
 #define MXLIB_PARTREGSIZE_STR "_partRegSize"
 typedef int (*partRegSize_t)(void);
@@ -920,7 +1081,17 @@ typedef int (*partCallSupportedOps_t)(supportedOps_t supportedOps, const char *j
 typedef int (*partCallReviewSubgraph_t)(reviewSubgraph_t reviewSubgraph, const char *json,
                                         int subgraph_id, int *accept, const char* const* opt_keys,
                                         const char* const* opt_vals, int num_opts,
-                                        char*** attr_keys, char*** attr_vals, int *num_attrs);
+                                        char*** attr_keys, char*** attr_vals, int *num_attrs,
+                                        const char* const* arg_names, int num_args,
+                                        void* const* arg_data, const int64_t* const* arg_shapes,
+                                        const int* arg_dims, const int* arg_types,
+                                        const size_t* arg_IDs, const char* const* arg_dev_type,
+                                        const int* arg_dev_id,
+                                        const char* const* aux_names, int num_aux,
+                                        void* const* aux_data, const int64_t* const* aux_shapes,
+                                        const int* aux_dims, const int* aux_types,
+                                        const size_t* aux_IDs, const char* const* aux_dev_type,
+                                        const int* aux_dev_id);
 
 #define MXLIB_INITIALIZE_STR "initialize"
 typedef int (*initialize_t)(int version);
@@ -959,12 +1130,13 @@ extern "C" {
             const char*** forward_ctx, fcomp_t** forward_fp, int* forward_count,
             const char*** backward_ctx, fcomp_t** backward_fp, int* backward_count,
             const char*** create_op_ctx, createOpState_t** create_op_fp, int* create_op_count,
-            parseAttrs_t* parse, inferType_t* type,
+            parseAttrs_t* parse, inferType_t* type, inferSType_t* stype,
             inferShape_t* shape, mutateInputs_t* mutate) {
     CustomOp &op = Registry<CustomOp>::get()->get(idx);
     *name = op.name;
     *parse = op.parse_attrs;
     *type = op.infer_type;
+    *stype = op.infer_storage_type;
     *shape = op.infer_shape;
     *mutate = op.mutate_inputs;
     *isSGop = op.isSGop;
@@ -1091,6 +1263,43 @@ extern "C" {
     return retval;
   }
 
+  /*! \brief returns status of calling inferSType function for operator from library */
+#if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__)
+  __declspec(dllexport) int __cdecl
+#else
+  int
+#endif
+  _opCallInferSType(inferSType_t inferSType, const char* const* keys,
+                   const char* const* vals, int num,
+                   int* instypes, int num_in, int* outstypes, int num_out) {
+    // create map of attributes from list
+    std::map<std::string, std::string> attrs;
+    for (int i = 0; i < num; i++) {
+      attrs[std::string(keys[i])] = std::string(vals[i]);
+    }
+
+    // create a vector of types for inputs
+    std::vector<int> in_stypes(num_in);
+    for (int i = 0; i < num_in; i++) {
+      in_stypes[i] = instypes[i];
+    }
+
+    // create a vector of types for outputs
+    std::vector<int> out_stypes(num_out, -1);
+
+    int retval = inferSType(attrs, in_stypes, out_stypes);
+
+    if (!retval)
+      return retval;
+
+    // copy output storage types
+    for (int i = 0; i < num_out; i++) {
+      outstypes[i] = out_stypes[i];
+    }
+
+    return retval;
+  }
+
   /*! \brief returns status of calling Forward/Backward function for operator from library */
 #if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__)
   __declspec(dllexport) int __cdecl
@@ -1103,7 +1312,13 @@ extern "C" {
                   const int64_t** outshapes, int* outdims, void** outdata, int* outtypes,
                   size_t* outIDs, const char** outdev_type, int* outdev_id, int num_out,
                   xpu_malloc_t cpu_malloc, void* cpu_alloc,
-                  xpu_malloc_t gpu_malloc, void* gpu_alloc, void* cuda_stream) {
+                  xpu_malloc_t gpu_malloc, void* gpu_alloc, void* cuda_stream,
+                  sparse_malloc_t sparse_malloc, void* sparse_alloc,
+                  int* instypes, int* outstypes, void** in_indices, void** out_indices,
+                  void** in_indptr, void** out_indptr,
+                  int64_t* in_indices_shapes, int64_t* out_indices_shapes,
+                  int64_t* in_indptr_shapes, int64_t* out_indptr_shapes,
+                  void* rng_cpu_states, void* rng_gpu_states) {
     // create map of attributes from list
     std::map<std::string, std::string> attrs;
     for (int i = 0; i < num; i++) {
@@ -1112,20 +1327,59 @@ extern "C" {
 
     // create a vector of tensors for inputs
     std::vector<MXTensor> inputs(num_in);
+    // create a vector for sparse inputs
+    std::vector<MXSparse> in_sparse(num_in);
+
     for (int i = 0; i < num_in; i++) {
-      inputs[i].setTensor(indata[i], (MXDType)intypes[i], inshapes[i], indims[i],
-                          inIDs[i], {indev_type[i], indev_id[i]});
+      // Dense representation.
+      if (instypes[i] == 0) {
+        inputs[i].setTensor(indata[i], (MXDType)intypes[i], inshapes[i], indims[i],
+                            inIDs[i], {indev_type[i], indev_id[i]}, kDefaultStorage);
+      } else {
+        // Sparse representation.
+        MXStorageType type;
+        if (instypes[i] == 1) {
+          type = kRowSparseStorage;
+          in_sparse[i].set(indata[i], inshapes[i], indims[i], in_indices[i], in_indices_shapes[i]);
+        } else {
+          type = kCSRStorage;
+          in_sparse[i].set(indata[i], inshapes[i], indims[i], in_indices[i],
+                           in_indices_shapes[i], in_indptr[i], in_indptr_shapes[i]);
+        }
+        inputs[i].setTensor(reinterpret_cast<void*>(&in_sparse[i]), (MXDType)intypes[i],
+                            inshapes[i], indims[i], inIDs[i], {indev_type[i], indev_id[i]}, type);
+      }
     }
 
     // create a vector of tensors for outputs
     std::vector<MXTensor> outputs(num_out);
+    std::vector<MXSparse> out_sparse(num_out);
+
     for (int i = 0; i < num_out; i++) {
-      outputs[i].setTensor(outdata[i], (MXDType)outtypes[i], outshapes[i], outdims[i],
-                           outIDs[i], {outdev_type[i], outdev_id[i]});
+      // Dense representation.
+      if (outstypes[i] == 0) {
+        outputs[i].setTensor(outdata[i], (MXDType)outtypes[i], outshapes[i], outdims[i],
+                            outIDs[i], {outdev_type[i], outdev_id[i]}, kDefaultStorage);
+      } else {
+        // Sparse representation.
+        MXStorageType type;
+        if (outstypes[i] == 1) {
+          type = kRowSparseStorage;
+          out_sparse[i].set(outdata[i], outshapes[i], outdims[i],
+                            out_indices[i], out_indices_shapes[i]);
+        } else {
+          type = kCSRStorage;
+          out_sparse[i].set(outdata[i], outshapes[i], outdims[i], out_indices[i],
+                            out_indices_shapes[i], out_indptr[i], out_indptr_shapes[i]);
+        }
+        outputs[i].setTensor(reinterpret_cast<void*>(&out_sparse[i]), (MXDType)outtypes[i],
+                            outshapes[i], outdims[i], outIDs[i], {outdev_type[i],
+                            outdev_id[i]}, type);
+      }
     }
 
-    OpResource res(cpu_malloc, cpu_alloc, gpu_malloc, gpu_alloc, cuda_stream);
-
+    OpResource res(cpu_malloc, cpu_alloc, gpu_malloc, gpu_alloc,
+                   cuda_stream, sparse_malloc, sparse_alloc, rng_cpu_states, rng_gpu_states);
     return fcomp(attrs, inputs, outputs, res);
   }
 
@@ -1194,22 +1448,70 @@ extern "C" {
                           const int64_t** outshapes, int* outdims, void** outdata, int* outtypes,
                           size_t* outIDs, const char** outdev_type, int* outdev_id, int num_out,
                           xpu_malloc_t cpu_malloc, void* cpu_alloc,
-                          xpu_malloc_t gpu_malloc, void* gpu_alloc, void* stream) {
+                          xpu_malloc_t gpu_malloc, void* gpu_alloc, void* stream,
+                          sparse_malloc_t sparse_malloc, void* sparse_alloc,
+                          int* instypes, int* outstypes, void** in_indices, void** out_indices,
+                          void** in_indptr, void** out_indptr,
+                          int64_t* in_indices_shapes, int64_t* out_indices_shapes,
+                          int64_t* in_indptr_shapes, int64_t* out_indptr_shapes,
+                          void* rng_cpu_states, void* rng_gpu_states) {
     // create a vector of tensors for inputs
     std::vector<MXTensor> inputs(num_in);
+    // create a vector for sparse inputs
+    std::vector<MXSparse> in_sparse(num_in);
+
     for (int i = 0; i < num_in; i++) {
-      inputs[i].setTensor(indata[i], (MXDType)intypes[i], inshapes[i], indims[i],
-                          inIDs[i], {indev_type[i], indev_id[i]});
+      if (instypes[i] == 0) {
+        // Dense representation.
+        inputs[i].setTensor(indata[i], (MXDType)intypes[i], inshapes[i], indims[i],
+                            inIDs[i], {indev_type[i], indev_id[i]}, kDefaultStorage);
+      } else {
+        // Sparse representation.
+        MXStorageType type;
+        if (instypes[i] == 1) {
+          type = kRowSparseStorage;
+          in_sparse[i].set(indata[i], inshapes[i], indims[i], in_indices[i], in_indices_shapes[i]);
+        } else {
+          type = kCSRStorage;
+          in_sparse[i].set(indata[i], inshapes[i], indims[i], in_indices[i],
+                           in_indices_shapes[i], in_indptr[i], in_indptr_shapes[i]);
+        }
+        inputs[i].setTensor(reinterpret_cast<void*>(&in_sparse[i]), (MXDType)intypes[i],
+                            inshapes[i], indims[i], inIDs[i], {indev_type[i],
+                            indev_id[i]}, type);
+      }
     }
 
     // create a vector of tensors for outputs
     std::vector<MXTensor> outputs(num_out);
+    // create a vector for sparse outputs
+    std::vector<MXSparse> out_sparse(num_out);
+
     for (int i = 0; i < num_out; i++) {
-      outputs[i].setTensor(outdata[i], (MXDType)outtypes[i], outshapes[i], outdims[i],
-                           outIDs[i], {outdev_type[i], outdev_id[i]});
+      if (outstypes[i] == 0) {
+        // Dense representation.
+        outputs[i].setTensor(outdata[i], (MXDType)outtypes[i], outshapes[i], outdims[i],
+                             outIDs[i], {outdev_type[i], outdev_id[i]}, kDefaultStorage);
+      } else {
+        // Sparse representation.
+        MXStorageType type;
+        if (outstypes[i] == 1) {
+          type = kRowSparseStorage;
+          out_sparse[i].set(outdata[i], outshapes[i], outdims[i], out_indices[i],
+                            out_indices_shapes[i]);
+        } else {
+          type = kCSRStorage;
+          out_sparse[i].set(outdata[i], outshapes[i], outdims[i], out_indices[i],
+                            out_indices_shapes[i], out_indptr[i], out_indptr_shapes[i]);
+        }
+        outputs[i].setTensor(reinterpret_cast<void*>(&out_sparse[i]), (MXDType)outtypes[i],
+                             outshapes[i], outdims[i], outIDs[i], {outdev_type[i],
+                             outdev_id[i]}, type);
+      }
     }
 
-    OpResource res(cpu_malloc, cpu_alloc, gpu_malloc, gpu_alloc, stream);
+    OpResource res(cpu_malloc, cpu_alloc, gpu_malloc, gpu_alloc,
+                   stream, sparse_malloc, sparse_alloc, rng_cpu_states, rng_gpu_states);
 
     CustomStatefulOp* op_ptr = reinterpret_cast<CustomStatefulOp*>(state_op);
     if (is_forward) {
@@ -1266,11 +1568,11 @@ extern "C" {
                         int num_ids, int *ids, const char* const* opt_keys,
                         const char* const* opt_vals, int num_opts) {
     std::string subgraph_json(json);
-    // create map of attributes from list
+    // create map of options from list
     std::unordered_map<std::string, std::string> opts;
-    for (int i = 0; i < num_opts; i++) {
+    for (int i = 0; i < num_opts; i++)
       opts[std::string(opt_keys[i])] = std::string(opt_vals[i]);
-    }
+
     // create array of bools for operator support
     std::vector<bool> _ids(num_ids, false);
     // call user's supportedOps function
@@ -1293,19 +1595,55 @@ extern "C" {
   _partCallReviewSubgraph(reviewSubgraph_t reviewSubgraph, const char *json,
                           int subgraph_id, int *accept, const char* const* opt_keys,
                           const char* const* opt_vals, int num_opts,
-                          char*** attr_keys, char*** attr_vals, int *num_attrs) {
+                          char*** attr_keys, char*** attr_vals, int *num_attrs,
+                          const char* const* arg_names, int num_args,
+                          void* const* arg_data, const int64_t* const* arg_shapes,
+                          const int* arg_dims, const int* arg_types,
+                          const size_t* arg_IDs, const char* const* arg_dev_type,
+                          const int* arg_dev_id,
+                          const char* const* aux_names, int num_aux,
+                          void* const* aux_data, const int64_t* const* aux_shapes,
+                          const int* aux_dims, const int* aux_types,
+                          const size_t* aux_IDs, const char* const* aux_dev_type,
+                          const int* aux_dev_id) {
     std::string subgraph_json(json);
     bool accept_bool = false;
     // create map of attributes from list
     std::unordered_map<std::string, std::string> opts;
-    for (int i = 0; i < num_opts; i++) {
+    for (int i = 0; i < num_opts; i++)
       opts[std::string(opt_keys[i])] = std::string(opt_vals[i]);
+
+    // create a map of named tensors for args
+    std::map<std::string, MXTensor> args;
+    for (int i = 0; i < num_args; i++) {
+      std::vector<int64_t> shapes;
+      for (int j = 0; j < arg_dims[i]; j++)
+        shapes.push_back(arg_shapes[i][j]);
+
+      MXTensor tensor(arg_data[i], shapes, (MXDType)arg_types[i],
+            arg_IDs[i], {arg_dev_type[i], arg_dev_id[i]});
+      args[arg_names[i]] = tensor;
+    }
+    // create a map of named tensors for aux
+    std::map<std::string, MXTensor> aux;
+    for (int i = 0; i < num_aux; i++) {
+      std::vector<int64_t> shapes;
+      for (int j = 0; j < aux_dims[i]; j++)
+        shapes.push_back(aux_shapes[i][j]);
+
+      MXTensor tensor(aux_data[i], shapes, (MXDType)aux_types[i],
+            aux_IDs[i], {aux_dev_type[i], aux_dev_id[i]});
+      aux[aux_names[i]] = tensor;
     }
 
+
     // attributes to set on subgraph node
     std::unordered_map<std::string, std::string> attrs;
 
-    MXReturnValue retval = reviewSubgraph(subgraph_json, subgraph_id, &accept_bool, opts, attrs);
+    MXReturnValue retval = reviewSubgraph(subgraph_json, subgraph_id, &accept_bool,
+                                          opts, attrs, args, aux);
+    if (!retval) return retval;
+
     *accept = accept_bool;
 
     if (attrs.size() > 0) {
diff --git a/include/mxnet/random_generator.h b/include/mxnet/random_generator.h
index e7b4193..a5a9b8e 100644
--- a/include/mxnet/random_generator.h
+++ b/include/mxnet/random_generator.h
@@ -96,6 +96,11 @@ class RandGenerator<cpu, DType> {
     for (int i = 0; i < kNumRandomStates; ++i) (states_ + i)->seed(seed + i);
   }
 
+  // export global random states, used by c++ custom operator
+  MSHADOW_XINLINE void* GetStates() {
+    return static_cast<void*>(states_);
+  }
+
  private:
   std::mt19937 *states_;
 };  // class RandGenerator<cpu, DType>
@@ -165,6 +170,9 @@ class RandGenerator<gpu, DType> {
 
   void Seed(mshadow::Stream<gpu> *s, uint32_t seed);
 
+  // export global random states, used by c++ custom operator
+  void* GetStates();
+
  private:
   curandStatePhilox4_32_10_t *states_;
 };  // class RandGenerator<gpu, DType>
diff --git a/perl-package/AI-MXNetCAPI/mxnet.i b/perl-package/AI-MXNetCAPI/mxnet.i
index 3bc53d6..846b28f 100644
--- a/perl-package/AI-MXNetCAPI/mxnet.i
+++ b/perl-package/AI-MXNetCAPI/mxnet.i
@@ -1633,6 +1633,8 @@ int MXOptimizeForBackend(SymbolHandle sym_handle,
                                    const mx_uint in,
                                    NDArrayHandle* in,
                                    const mx_uint in,
+                                   NDArrayHandle* in,
+                                   const mx_uint in,
                                    const char** keys,
                                    const char** vals);
 
diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py
index e925b31..bed6679 100644
--- a/python/mxnet/gluon/block.py
+++ b/python/mxnet/gluon/block.py
@@ -973,11 +973,15 @@ class HybridBlock(Block):
             # get list of params in the order of out.list_arguments
             arg_array = [args[data_names[name]] if name in data_names.keys() else params[name].data()
                          for name in out.list_arguments()]
+            aux_array = [args[data_names[name]] if name in data_names.keys() else params[name].data()
+                         for name in out.list_auxiliary_states()]
             # Partition the graph.
-            out = out.optimize_for(self._backend, arg_array, ctx, **self._backend_opts)
-
+            out = out.optimize_for(self._backend, arg_array, aux_array, ctx, **self._backend_opts)
+            # update the cached graph with the partitioned graph
+            self._cached_graph = data, out
         self._cached_op = ndarray.CachedOp(out, flags)
 
+
     def _deferred_infer_shape(self, *args):
         try:
             self.infer_shape(*args)
@@ -1026,6 +1030,69 @@ class HybridBlock(Block):
             out = [out]
         return _regroup(out, self._out_format)
 
+    def optimize_for(self, x, *args, backend=None, backend_opts=None, **kwargs):
+        """Partitions the current HybridBlock and optimizes it for a given backend
+        without executing a forward pass. Modifies the HybridBlock in-place.
+
+        Immediately partitions a HybridBlock using the specified backend. Combines
+        the work done in the hybridize API with part of the work done in the forward
+        pass without calling the CachedOp. Can be used in place of hybridize;
+        afterwards `export` can be called or inference can be run. See
+        example/extensions/lib_subgraph/README.md for more details.
+
+        Examples
+        --------
+        # partition and then export to file
+        block.optimize_for(x, backend='myPart')
+        block.export('partitioned')
+
+        # partition and then run inference
+        block.optimize_for(x, backend='myPart')
+        block(x)
+
+        Parameters
+        ----------
+        x : NDArray
+            first input to model
+        *args : NDArray
+            other inputs to model
+        backend : str
+            The name of backend, as registered in `SubgraphBackendRegistry`, default None
+        backend_opts : dict of user-specified options to pass to the backend for partitioning, optional
+            Passed on to `PrePartition` and `PostPartition` functions of `SubgraphProperty`
+        static_alloc : bool, default False
+            Statically allocate memory to improve speed. Memory usage may increase.
+        static_shape : bool, default False
+            Optimize for invariant input shapes between iterations. Must also
+            set static_alloc to True. Change of input shapes is still allowed
+            but slower.
+        """
+
+        # do the hybridize API call
+        self.hybridize(True, backend, backend_opts, **kwargs)
+
+        # do part of forward API call
+        has_symbol, has_ndarray, ctx_set, _ = _gather_type_ctx_info([x] + list(args))
+        if has_symbol:
+            raise ValueError('Inputs must be NDArrays for the optimize_for API.'
+                             ' Please check the type of the args.\n')
+        if not has_symbol and not has_ndarray:
+            raise ValueError('In HybridBlock, there must be one NDArray as input.'
+                             ' Please check the type of the args.\n')
+        if len(ctx_set) > 1:
+            raise ValueError('Found multiple contexts in the input. '
+                             'After hybridization, the HybridBlock only supports one input '
+                             'context. You can print ele.ctx for the '
+                             'input arguments to inspect their contexts. '
+                             'Found contexts = {}'.format(ctx_set))
+
+        self._build_cache(x, *args)
+        assert self._cached_op, "Gluon failed to build the cache. " \
+                                "This should never happen. " \
+                                "Please submit an issue on Github" \
+                                " https://github.com/apache/incubator-mxnet."
+        # do not actually call the cached_op
+
     def _clear_cached_op(self):
         self._cached_graph = ()
         self._cached_op = None
diff --git a/python/mxnet/symbol/symbol.py b/python/mxnet/symbol/symbol.py
index a4599c8..0a19018 100644
--- a/python/mxnet/symbol/symbol.py
+++ b/python/mxnet/symbol/symbol.py
@@ -1445,7 +1445,7 @@ class Symbol(SymbolBase):
         return Symbol(handle)
 
 
-    def optimize_for(self, backend, args=None, ctx=None, **kwargs):
+    def optimize_for(self, backend, args=None, aux=None, ctx=None, **kwargs):
         """Partitions current symbol and optimizes it for a given backend,
         returns new partitioned symbol.
 
@@ -1461,6 +1461,13 @@ class Symbol(SymbolBase):
             - If type is a dict of str to `NDArray`, then it maps the name of arguments
               to the corresponding `NDArray`.
 
+        aux : list of NDArray or dict of str to NDArray, optional
+            Input auxiliary arguments to the symbol
+
+            - If type is a list of `NDArray`, the order is the same as `list_auxiliary_states()`.
+            - If type is a dict of str to `NDArray`, then it maps the name of auxiliary states
+              to the corresponding `NDArray`.
+
         ctx : Context, optional
             Device context, used to infer stypes
 
@@ -1475,13 +1482,19 @@ class Symbol(SymbolBase):
         out = SymbolHandle()
         assert isinstance(backend, str)
 
-        if args is None:
+        if args is None or len(args) == 0:
             args = []
             args_handle = c_array(NDArrayHandle, [])
         else:
-            listed_arguments = self.list_arguments()
-            args_handle, args = self._get_ndarray_inputs('args', args, listed_arguments, False)
+            args_handle, args = self._get_ndarray_inputs('args', args,
+                                                         self.list_arguments(), False)
 
+        if aux is None or len(aux) == 0:
+            aux = []
+            aux_handle = c_array(NDArrayHandle, [])
+        else:
+            aux_handle, aux = self._get_ndarray_inputs('aux_states', aux,
+                                                       self.list_auxiliary_states(), False)
         if ctx is None:
             ctx = current_context()
         assert isinstance(ctx, Context)
@@ -1497,6 +1510,8 @@ class Symbol(SymbolBase):
                                              ctypes.byref(out),
                                              mx_uint(len(args)),
                                              args_handle,
+                                             mx_uint(len(aux)),
+                                             aux_handle,
                                              mx_uint(len(key_list)),
                                              c_str_array(key_list),
                                              c_str_array(val_list)))
diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc
index 31b9d84..09fede6 100644
--- a/src/c_api/c_api.cc
+++ b/src/c_api/c_api.cc
@@ -113,23 +113,54 @@ void CustomFComputeDispatcher(const std::string op_name,
                               const std::vector<OpReqType>& req,
                               const std::vector<NDArray>& outputs) {
   std::vector<void*> in_data, out_data;
-  std::vector<const int64_t *> in_shapes, out_shapes;
+  std::vector<const int64_t*> in_shapes, out_shapes;
   std::vector<int> in_dims, out_dims;
   std::vector<int> in_types, out_types;
   std::vector<size_t> in_verIDs, out_verIDs;
   std::vector<const char*> in_dev_type, out_dev_type;
   std::vector<int> in_dev_id, out_dev_id;
+  std::vector<NDArray> conv_mkl;  // converted NDArrays from MKLDNN format
+
+  // Extra data for sparse inputs and outputs.
+  std::vector<int> in_stypes(inputs.size(), 0), out_stypes(outputs.size(), 0);
+  std::vector<void*> in_indices(inputs.size(), nullptr), out_indices(outputs.size(), nullptr);
+  std::vector<void*> in_indptr(inputs.size(), nullptr), out_indptr(outputs.size(), nullptr);
+  std::vector<int64_t> in_indices_shapes(inputs.size(), 0), out_indices_shapes(outputs.size(), 0);
+  std::vector<int64_t> in_indptr_shapes(inputs.size(), 0), out_indptr_shapes(outputs.size(), 0);
 
   // convert inputs/outpus NDArray to C types to be passed to lib_api.h
   for (size_t i = 0; i < inputs.size(); i++) {
-    in_data.push_back(inputs[i].data().dptr_);
-    in_shapes.push_back(inputs[i].shape().data());
-    in_dims.push_back(inputs[i].shape().ndim());
-    in_types.push_back(inputs[i].dtype());
-    in_verIDs.push_back(inputs[i].version());
-    const char* ctx_str = inputs[i].ctx().dev_mask() == Context::kCPU ? "cpu" : "gpu";
+    NDArray const* in_nd = &(inputs[i]);
+#if MXNET_USE_MKLDNN == 1
+    // reorder data if in MKLDNN format
+    if (in_nd->IsMKLDNNData()) {
+      // convert from MKLDNN
+      conv_mkl.push_back(in_nd->Reorder2Default());
+      in_nd = &(conv_mkl.back());
+    }
+#endif
+    // pull out parts to pass over to library
+    in_data.push_back(in_nd->data().dptr_);
+    in_shapes.push_back(in_nd->shape().data());
+    in_dims.push_back(in_nd->shape().ndim());
+    in_types.push_back(in_nd->dtype());
+    in_verIDs.push_back(in_nd->version());
+    // string repr of supported context for custom library, currently only "cpu" and "gpu"
+    const char* ctx_str = in_nd->ctx().dev_mask() == Context::kCPU ? "cpu" : "gpu";
     in_dev_type.push_back(ctx_str);
-    in_dev_id.push_back(inputs[i].ctx().real_dev_id());
+
+    in_dev_id.push_back(in_nd->ctx().real_dev_id());
+    if (inputs[i].storage_type() == mxnet::kRowSparseStorage) {
+      in_stypes[i] = 1;
+      in_indices[i] = inputs[i].aux_data(rowsparse::kIdx).dptr_;
+      in_indices_shapes[i] = inputs[i].aux_shape(rowsparse::kIdx).Size();
+    } else if (inputs[i].storage_type() == mxnet::kCSRStorage) {
+      in_stypes[i] = 2;
+      in_indices[i] = inputs[i].aux_data(csr::kIdx).dptr_;
+      in_indptr[i] = inputs[i].aux_data(csr::kIndPtr).dptr_;
+      in_indices_shapes[i] = inputs[i].aux_shape(csr::kIdx).Size();
+      in_indptr_shapes[i] = inputs[i].aux_shape(csr::kIndPtr).Size();
+    }
   }
 
   for (size_t i = 0; i < outputs.size(); i++) {
@@ -141,10 +172,24 @@ void CustomFComputeDispatcher(const std::string op_name,
     const char* ctx_str = outputs[i].ctx().dev_mask() == Context::kCPU ? "cpu" : "gpu";
     out_dev_type.push_back(ctx_str);
     out_dev_id.push_back(outputs[i].ctx().real_dev_id());
+
+    if (outputs[i].storage_type() == mxnet::kRowSparseStorage) {
+      out_stypes[i] = 1;
+      out_indices[i] = outputs[i].aux_data(rowsparse::kIdx).dptr_;
+      out_indices_shapes[i] = outputs[i].aux_shape(rowsparse::kIdx).Size();
+    } else if (outputs[i].storage_type() == mxnet::kCSRStorage) {
+      out_stypes[i] = 2;
+      out_indices[i] = outputs[i].aux_data(csr::kIdx).dptr_;
+      out_indptr[i] = outputs[i].aux_data(csr::kIndPtr).dptr_;
+      out_indices_shapes[i] = outputs[i].aux_shape(csr::kIdx).Size();
+      out_indptr_shapes[i] = outputs[i].aux_shape(csr::kIndPtr).Size();
+    }
   }
 
   // get memory resource and mxnet backend streams
-  const Resource &resource = ctx.requested[0];
+  CHECK(ctx.requested.size() >= 2)
+    << "Custom operator should register at least memory resource and parallel random resource";
+  const Resource &resource = ctx.requested.at(0);
   mshadow::Stream<mxnet::cpu> *cpu_stream = ctx.get_stream<mxnet::cpu>();
   mshadow::Stream<mxnet::gpu> *gpu_stream = ctx.get_stream<mxnet::gpu>();
 
@@ -161,7 +206,25 @@ void CustomFComputeDispatcher(const std::string op_name,
     return workspace.dptr_;
   };
 
-  // create lambda without captures so that we can cast it to function pointer
+  // create lambda that allocates memory for sparse and
+  // returns allocated arrays for data, indices and indptr.
+  auto sparse_alloc = [&](int index, int indices_len, int idxptr_len,
+                           void** data, int64_t** indices, int64_t** indptr) {
+    if (idxptr_len == 0) {
+      // Row Sparse
+      outputs[index].CheckAndAlloc({mshadow::Shape1(indices_len)});
+      *data = outputs[index].data().dptr_;
+      *indices = reinterpret_cast<int64_t*>(outputs[index].aux_data(rowsparse::kIdx).dptr_);
+    } else {
+      // CSR
+      outputs[index].CheckAndAlloc({mshadow::Shape1(idxptr_len), mshadow::Shape1(indices_len)});
+      *data = outputs[index].data().dptr_;
+      *indices = reinterpret_cast<int64_t*>(outputs[index].aux_data(csr::kIdx).dptr_);
+      *indptr = reinterpret_cast<int64_t*>(outputs[index].aux_data(csr::kIndPtr).dptr_);
+    }
+  };
+
+  // create no-capture lambda so that we can cast it to function pointer
   // lambda with captures cannot be cast to function pointer and pass to lib_api.h
   // this needs to be a lambda function so that we can do the decltype cast
   typedef decltype(cpu_alloc) alloc_type_cpu;
@@ -171,20 +234,39 @@ void CustomFComputeDispatcher(const std::string op_name,
     // call cpu_alloc to actually allocate memory and return the pointer
     return static_cast<void*>((*cpualloc)(size));
   };
+
   typedef decltype(gpu_alloc) alloc_type_gpu;
   auto gpu_malloc = [](void* _gpu_alloc, int size) {
     alloc_type_gpu* gpualloc = static_cast<alloc_type_gpu*>(_gpu_alloc);
     return static_cast<void*>((*gpualloc)(size));
   };
 
+  typedef decltype(sparse_alloc) alloc_type_sparse;
+  auto sparse_malloc = [](void* _sparse_alloc, int index, int indices_len, int idxptr_len,
+                           void** data, int64_t** indices, int64_t** indptr) {
+    alloc_type_sparse* sparsealloc = static_cast<alloc_type_sparse*>(_sparse_alloc);
+    (*sparsealloc)(index, indices_len, idxptr_len, data, indices, indptr);
+  };
+
   // get actual cudaStream_t out of mxnet gpu stream and pass to lib_api.h
   void *cuda_stream = nullptr;
 #if MXNET_USE_CUDA
-  if (inputs[0].ctx().dev_mask() == Context::kGPU) {
+  if ((inputs.size() > 0 && inputs[0].ctx().dev_mask() == Context::kGPU) ||
+      (outputs.size() > 0 && outputs[0].ctx().dev_mask() == Context::kGPU)) {
     cuda_stream = static_cast<void*>(gpu_stream->stream_);
   }
 #endif
 
+  // get mxnet initialized and seeded RNG states and pass to lib_api.h
+  void *rng_cpu_states = nullptr, *rng_gpu_states = nullptr;
+  using mxnet::common::random::RandGenerator;
+  RandGenerator<cpu, float> *pgen_cpu = ctx.requested.at(1).get_parallel_random<cpu, float>();
+  rng_cpu_states = pgen_cpu->GetStates();
+#if MXNET_USE_CUDA
+  RandGenerator<gpu, float> *pgen_gpu = ctx.requested.at(1).get_parallel_random<gpu, float>();
+  rng_gpu_states = pgen_gpu->GetStates();
+#endif
+
   CHECK((fcomp_fp != nullptr && state_ptr == nullptr)
         || (fcomp_fp == nullptr && state_ptr != nullptr))
     << "Can only register either regular op or stateful op for '" << op_name << "'";
@@ -192,17 +274,23 @@ void CustomFComputeDispatcher(const std::string op_name,
   if (fcomp_fp != nullptr) {
     // convert attributes to vector of char*
     std::vector<const char*> attr_keys, attr_vals;
-    for (auto kv : attrs->dict) {
+    for (auto &kv : attrs->dict) {
       attr_keys.push_back(kv.first.c_str());
       attr_vals.push_back(kv.second.c_str());
     }
+
     // call fcompute function
     CHECK(callFComp(fcomp_fp, attr_keys.data(), attr_vals.data(), attr_keys.size(),
                     in_shapes.data(), in_dims.data(), in_data.data(), in_types.data(),
                     in_verIDs.data(), in_dev_type.data(), in_dev_id.data(), in_data.size(),
                     out_shapes.data(), out_dims.data(), out_data.data(), out_types.data(),
                     out_verIDs.data(), out_dev_type.data(), out_dev_id.data(), out_data.size(),
-                    cpu_malloc, &cpu_alloc, gpu_malloc, &gpu_alloc, cuda_stream))
+                    cpu_malloc, &cpu_alloc, gpu_malloc, &gpu_alloc, cuda_stream,
+                    sparse_malloc, &sparse_alloc, in_stypes.data(), out_stypes.data(),
+                    in_indices.data(), out_indices.data(), in_indptr.data(), out_indptr.data(),
+                    in_indices_shapes.data(), out_indices_shapes.data(),
+                    in_indptr_shapes.data(), out_indptr_shapes.data(),
+                    rng_cpu_states, rng_gpu_states))
       << "Error calling FCompute for custom operator '" << op_name << "'";
   }
 
@@ -221,7 +309,13 @@ void CustomFComputeDispatcher(const std::string op_name,
                             out_shapes.data(), out_dims.data(), out_data.data(), out_types.data(),
                             out_verIDs.data(), out_dev_type.data(), out_dev_id.data(),
                             out_data.size(),
-                            cpu_malloc, &cpu_alloc, gpu_malloc, &gpu_alloc, cuda_stream))
+                            cpu_malloc, &cpu_alloc, gpu_malloc, &gpu_alloc, cuda_stream,
+                            sparse_malloc, &sparse_alloc, in_stypes.data(), out_stypes.data(),
+                            in_indices.data(), out_indices.data(),
+                            in_indptr.data(), out_indptr.data(),
+                            in_indices_shapes.data(), out_indices_shapes.data(),
+                            in_indptr_shapes.data(), out_indptr_shapes.data(),
+                            rng_cpu_states, rng_gpu_states))
       << "Error calling FStatefulCompute for custom operator '" << op_name << "'";
   }
 }
@@ -260,6 +354,9 @@ int MXLoadLib(const char *path) {
   opCallInferType_t callInferType =
     get_func<opCallInferType_t>(lib, const_cast<char*>(MXLIB_OPCALLINFERTYPE_STR));
 
+  opCallInferSType_t callInferSType =
+    get_func<opCallInferSType_t>(lib, const_cast<char*>(MXLIB_OPCALLINFERSTYPE_STR));
+
   opCallFComp_t callFComp =
     get_func<opCallFComp_t>(lib, const_cast<char*>(MXLIB_OPCALLFCOMP_STR));
 
@@ -275,7 +372,6 @@ int MXLoadLib(const char *path) {
   partCallSupportedOps_t callSupportedOps =
     get_func<partCallSupportedOps_t>(lib, const_cast<char*>(MXLIB_PARTCALLSUPPORTEDOPS_STR));
 
-
   partCallReviewSubgraph_t callReviewSubgraph =
     get_func<partCallReviewSubgraph_t>(lib, const_cast<char*>(MXLIB_PARTCALLREVIEWSUBGRAPH_STR));
 
@@ -294,6 +390,7 @@ int MXLoadLib(const char *path) {
     // function pointers holding implementation from custom library
     parseAttrs_t parse_fp = nullptr;
     inferType_t type_fp = nullptr;
+    inferSType_t stype_fp = nullptr;
     inferShape_t shape_fp = nullptr;
     // optional attributes
     mutateInputs_t mutate_fp = nullptr;
@@ -310,7 +407,7 @@ int MXLoadLib(const char *path) {
              &forward_ctx, &forward_fcomp, &forward_count,
              &backward_ctx, &backward_fcomp, &backward_count,
              &createop_ctx, &createop_fp, &createop_count,
-             &parse_fp, &type_fp, &shape_fp, &mutate_fp);
+             &parse_fp, &type_fp, &stype_fp, &shape_fp, &mutate_fp);
 
     // construct maps of context to forward/backward custom library function
     std::unordered_map<std::string, fcomp_t> forward_ctx_map;
@@ -353,14 +450,14 @@ int MXLoadLib(const char *path) {
     /*
      * Below are a series of lambda functions that will be registered in the NNVM op registration
      * Each one has the standard MXNet signature and converts to types supported by externally
-     * registered operators. 
+     * registered operators.
      */
 
     // lambda function to call parse attributes
     auto attr_parser = [=](const NodeAttrs* attrs) {
       // convert attributes to vector of char
       std::vector<const char*> attr_keys, attr_vals;
-      for (auto kv : attrs->dict) {
+      for (auto &kv : attrs->dict) {
         attr_keys.push_back(kv.first.c_str());
         attr_vals.push_back(kv.second.c_str());
       }
@@ -370,7 +467,7 @@ int MXLoadLib(const char *path) {
         nnvm::Graph g;
         g.outputs = attrs->subgraphs[0].get()->outputs;
         subgraph_json = nnvm::pass::SaveJSON(g);
-        attr_keys.push_back(SUBGRAPH_SYM_JSON);
+        attr_keys.push_back(MX_STR_SUBGRAPH_SYM_JSON);
         attr_vals.push_back(subgraph_json.c_str());
       }
 
@@ -387,7 +484,7 @@ int MXLoadLib(const char *path) {
     auto num_inputs = [=](const NodeAttrs& attrs) {
       // convert attributes to vector of char
       std::vector<const char*> attr_keys, attr_vals;
-      for (auto kv : attrs.dict) {
+      for (auto &kv : attrs.dict) {
         attr_keys.push_back(kv.first.c_str());
         attr_vals.push_back(kv.second.c_str());
       }
@@ -405,7 +502,7 @@ int MXLoadLib(const char *path) {
     auto num_outputs = [=](const NodeAttrs& attrs) {
       // convert attributes to vector of char*
       std::vector<const char*> attr_keys, attr_vals;
-      for (auto kv : attrs.dict) {
+      for (auto &kv : attrs.dict) {
         attr_keys.push_back(kv.first.c_str());
         attr_vals.push_back(kv.second.c_str());
       }
@@ -424,7 +521,7 @@ int MXLoadLib(const char *path) {
     auto num_inouts = [=](const NodeAttrs& attrs) {
       // convert attributes to vector of char*
       std::vector<const char*> attr_keys, attr_vals;
-      for (auto kv : attrs.dict) {
+      for (auto &kv : attrs.dict) {
         attr_keys.push_back(kv.first.c_str());
         attr_vals.push_back(kv.second.c_str());
       }
@@ -444,7 +541,7 @@ int MXLoadLib(const char *path) {
                             mxnet::ShapeVector *out_shape) {
       // convert attributes to vector of char*
       std::vector<const char*> attr_keys, attr_vals;
-      for (auto kv : attrs.dict) {
+      for (auto &kv : attrs.dict) {
         attr_keys.push_back(kv.first.c_str());
         attr_vals.push_back(kv.second.c_str());
       }
@@ -515,7 +612,7 @@ int MXLoadLib(const char *path) {
                             std::vector<int> *out_type) {
       // convert attributes to vector of char*
       std::vector<const char*> attr_keys, attr_vals;
-      for (auto kv : attrs.dict) {
+      for (auto &kv : attrs.dict) {
         attr_keys.push_back(kv.first.c_str());
         attr_vals.push_back(kv.second.c_str());
       }
@@ -543,7 +640,7 @@ int MXLoadLib(const char *path) {
     auto mutate_inputs = [=](const nnvm::NodeAttrs& attrs) {
       // convert attributes to vector of char*
       std::vector<const char*> attr_keys, attr_vals;
-      for (auto kv : attrs.dict) {
+      for (auto &kv : attrs.dict) {
         attr_keys.push_back(kv.first.c_str());
         attr_vals.push_back(kv.second.c_str());
       }
@@ -571,12 +668,39 @@ int MXLoadLib(const char *path) {
                                 DispatchMode* dispatch_mode,
                                 std::vector<int>* in_stypes,
                                 std::vector<int>* out_stypes) {
-      // TODO(ziyimu): remove this dense enforce check after supporting sparse tensor
-      CHECK(mxnet::common::ContainsOnlyStorage(*in_stypes, mxnet::kDefaultStorage))
-      << "Error input tensors are not dense for custom operator '" << name_str << "'";
-      // set outputs as dense
-      return op::storage_type_assign(out_stypes, mxnet::kDefaultStorage,
-                                     dispatch_mode, DispatchMode::kFComputeEx);
+      if (stype_fp == nullptr) {
+        // InferSType is not defined in the custom library.
+        CHECK(mxnet::common::ContainsOnlyStorage(*in_stypes, mxnet::kDefaultStorage))
+        << "Error input tensors are not dense for custom operator '" << name_str << "'";
+        // set outputs as dense
+        return op::storage_type_assign(out_stypes, mxnet::kDefaultStorage,
+                                       dispatch_mode, DispatchMode::kFComputeEx);
+      } else {
+        // InferSType is defined in the custom library.
+        // convert attributes to vector of char*
+        std::vector<const char*> attr_keys, attr_vals;
+        for (auto kv : attrs.dict) {
+          attr_keys.push_back(kv.first.c_str());
+          attr_vals.push_back(kv.second.c_str());
+        }
+        // copy input storage types from in_stypes
+        std::vector<int> instypes(*in_stypes);
+
+        // output storage types will be populated by the inferSType function
+        std::vector<int> outstypes(out_stypes->size());
+        CHECK(callInferSType(stype_fp, attr_keys.data(), attr_vals.data(), attr_keys.size(),
+                             instypes.data(), in_stypes->size(),
+                             outstypes.data(), out_stypes->size()))
+        << "Error calling InferSType for custom operator '" << name_str << "'";
+
+        // copy and assign output storage types from custom op to MXNet memory.
+        for (size_t i = 0; i < out_stypes->size(); i++) {
+          STORAGE_TYPE_ASSIGN_CHECK(*out_stypes, i, outstypes[i]);
+        }
+        // assign dispatch mode
+        DISPATCH_MODE_ASSIGN_CHECK(dispatch_mode, 0, DispatchMode::kFComputeEx);
+        return true;
+      }
     };
 
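
Note on the new FInferStorageType path above: when the library registers no InferSType callback the behavior stays as before (inputs must be dense, outputs are forced dense); when it does, the input storage types are forwarded to the library and the returned output storage types are assigned back. A condensed, self-contained sketch of that dispatch, using stand-in enum values rather than the real MXNet storage-type constants:

    #include <algorithm>
    #include <functional>
    #include <vector>

    enum StorageKind { kDense = 0, kRowSparse = 1, kCSR = 2 };  // stand-ins only

    using InferSTypeFn =
        std::function<bool(const std::vector<int>&, std::vector<int>*)>;

    bool infer_storage(const InferSTypeFn& user_fn,
                       const std::vector<int>& in_stypes,
                       std::vector<int>* out_stypes) {
      if (!user_fn) {
        // no callback registered: require dense inputs, produce dense outputs
        for (int s : in_stypes)
          if (s != kDense) return false;
        std::fill(out_stypes->begin(), out_stypes->end(), kDense);
        return true;
      }
      // callback registered: let the library decide the output storage types
      std::vector<int> out(out_stypes->size());
      if (!user_fn(in_stypes, &out)) return false;
      *out_stypes = out;
      return true;
    }
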
     // FGradient register lambda
@@ -617,7 +741,8 @@ int MXLoadLib(const char *path) {
     };
 
     auto resc_req = [=](const NodeAttrs& attrs) {
-      return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
+      return std::vector<ResourceRequest>{ResourceRequest::kTempSpace,
+                                          ResourceRequest::kParallelRandom};
     };
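
Note: requesting ResourceRequest::kParallelRandom here is what gives custom operators access to MXNet-managed RNG state, which the noisy-relu determinism test below depends on. As a rough sketch of what a library-side kernel could do with such a state, assuming it is handed over as a pointer to a std::mt19937 (the exact accessor and state layout exposed by lib_api.h are an assumption here, not taken from this patch):

    #include <cstddef>
    #include <random>

    // Hypothetical noisy-relu style CPU kernel: adds Gaussian noise drawn from
    // an externally owned mt19937 state, then applies relu.
    void noisy_relu_cpu(const float* in, float* out, size_t n, void* rng_state) {
      std::mt19937* gen = static_cast<std::mt19937*>(rng_state);  // assumed layout
      std::normal_distribution<float> noise(0.0f, 1.0f);
      for (size_t i = 0; i < n; ++i) {
        float v = in[i] + noise(*gen);
        out[i] = v > 0.0f ? v : 0.0f;
      }
    }

Seeding the framework-side generator before two identical calls should then yield identical outputs, which is what the new GPU test checks.
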
 
     // library author should implement and return a 'state' which points to an instance
@@ -628,7 +753,7 @@ int MXLoadLib(const char *path) {
                                const std::vector<int>& in_types) {
       // convert attributes to vector of char*
       std::vector<const char*> attr_keys, attr_vals;
-      for (auto kv : attrs.dict) {
+      for (auto &kv : attrs.dict) {
         attr_keys.push_back(kv.first.c_str());
         attr_vals.push_back(kv.second.c_str());
       }
@@ -639,7 +764,7 @@ int MXLoadLib(const char *path) {
         nnvm::Graph g;
         g.outputs = attrs.subgraphs[0].get()->outputs;
         subgraph_json = nnvm::pass::SaveJSON(g);
-        attr_keys.push_back(SUBGRAPH_SYM_JSON);
+        attr_keys.push_back(MX_STR_SUBGRAPH_SYM_JSON);
         attr_vals.push_back(subgraph_json.c_str());
       }
 
@@ -681,14 +806,15 @@ int MXLoadLib(const char *path) {
       // TODO(samskalicky): enable constant overwriting of registertion multiple times
       plevel++;
     }
+    // define supported resources for both subgraph ops and regular ops
+    regOp.set_attr<FResourceRequest>("FResourceRequest", resc_req, plevel);
     if (!isSubgraphOp) {
       regOp.set_attr_parser(attr_parser);
       regOp.set_num_inputs(num_inputs);
       regOp.set_num_outputs(num_outputs);
       regOp.set_attr<nnvm::FInferType>("FInferType", infer_type, plevel);
-      regOp.set_attr<mxnet::FInferShape>("FInferShape", infer_shape, plevel);
       regOp.set_attr<FInferStorageType>("FInferStorageType", infer_storage_type, plevel);
-      regOp.set_attr<FResourceRequest>("FResourceRequest", resc_req, plevel);
+      regOp.set_attr<mxnet::FInferShape>("FInferShape", infer_shape, plevel);
       // optionally add fmutate inputs if user specified a function
       if (mutate_fp != nullptr)
         regOp.set_attr<nnvm::FMutateInputs>("FMutateInputs", mutate_inputs, plevel);
@@ -700,8 +826,6 @@ int MXLoadLib(const char *path) {
       regOp.set_attr<mxnet::FInferShape>("FInferShape", DefaultSubgraphOpShape, plevel);
       regOp.set_attr<FInferStorageType>("FInferStorageType",
                                         DefaultSubgraphOpStorageType, plevel);
-      regOp.set_attr<FResourceRequest>("FResourceRequest",
-                                       DefaultSubgraphOpResourceRequest, plevel);
       regOp.set_attr<nnvm::FMutateInputs>("FMutateInputs",
                                           DefaultSubgraphOpMutableInputs, plevel);
     }
@@ -779,8 +903,7 @@ int MXLoadLib(const char *path) {
                                    const std::vector<OpReqType>& req,
                                    const std::vector<NDArray>& outputs) {
           CustomFComputeDispatcher(name_str, nullptr, nullptr, nullptr,
-                                   callFStatefulComp, 0, &state_ptr,
-                                   ctx, inputs, req, outputs);
+                                   callFStatefulComp, 0, &state_ptr, ctx, inputs, req, outputs);
         };
         gradOp.set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", fstate_backward, plevel);
         gradOp.set_attr<FStatefulComputeEx>("FStatefulComputeEx<gpu>", fstate_backward, plevel);
@@ -857,12 +980,10 @@ int MXLoadLib(const char *path) {
       std::string op_name_str(op_name);
       LOG(INFO) << "\t\tStrategy[" << j << "] " << strategy_str
                 << " subgraphOp: '" << op_name_str << "'";
-
-      // MXNET_REGISTER_SUBGRAPH_PROPERTY(customBackend, CustomSubgraphProperty);
-      mxnet::op::SubgraphBackendRegistry::Get()->__REGISTER_CUSTOM_PROPERTY__(name_str,
-                            std::make_shared<mxnet::op::CustomSubgraphProperty>(
-                           strategy_str, callSupportedOps, supportedOps_fp,
-                           callReviewSubgraph, reviewSubgraph_fp, callFree, op_name_str));
+      mxnet::op::SubgraphBackendRegistry::Get()->__REGISTER_CUSTOM_PROPERTY__
+        (name_str, std::make_shared<mxnet::op::CustomSubgraphProperty>
+          (strategy_str, callSupportedOps, supportedOps_fp,
+           callReviewSubgraph, reviewSubgraph_fp, callFree, op_name_str));
     }
   }
   API_END();
diff --git a/src/c_api/c_api_symbolic.cc b/src/c_api/c_api_symbolic.cc
index 8f78fc1..d2b17a9 100644
--- a/src/c_api/c_api_symbolic.cc
+++ b/src/c_api/c_api_symbolic.cc
@@ -1343,32 +1343,54 @@ int MXOptimizeForBackend(SymbolHandle sym_handle,
                          const char* backend_name,
                          const int dev_type,
                          SymbolHandle* ret_sym_handle,
-                         const mx_uint len,
+                         const mx_uint args_len,
                          NDArrayHandle* in_args_handle,
+                         const mx_uint aux_len,
+                         NDArrayHandle* in_aux_handle,
                          const mx_uint num_options,
                          const char** keys,
                          const char** vals) {
+  // create a copy of the input symbol
   nnvm::Symbol *s = new nnvm::Symbol();
   API_BEGIN();
   nnvm::Symbol *sym = static_cast<nnvm::Symbol *>(sym_handle);
   *s = sym->Copy();
   nnvm::Graph g = Symbol2Graph(*s);
-  if (len) {
+  const auto& indexed_graph = g.indexed_graph();
+  const auto& mutable_nodes = indexed_graph.mutable_input_nodes();
+  std::vector<std::string> input_names = sym->ListInputNames(nnvm::Symbol::kAll);
+  size_t num_forward_inputs = input_names.size();
+  if (args_len || aux_len) {
     NDArray **in_args_ptr = reinterpret_cast<NDArray**>(in_args_handle);
+    NDArray **in_aux_ptr = reinterpret_cast<NDArray**>(in_aux_handle);
     Context default_ctx = Context::Create(static_cast<Context::DeviceType>(dev_type), 0);
-    mxnet::ShapeVector arg_shapes(len);
-    nnvm::DTypeVector arg_dtypes(len);
-    StorageTypeVector arg_stypes(len);
-    for (mx_uint i = 0; i < len; i++) {
-      const auto &in_arg = *(in_args_ptr[i]);
-      arg_shapes[i] = in_arg.shape();
-      arg_dtypes[i] = in_arg.dtype();
-      arg_stypes[i] = in_arg.storage_type();
+    mxnet::ShapeVector arg_shapes(args_len + aux_len);
+    nnvm::DTypeVector arg_dtypes(args_len + aux_len);
+    StorageTypeVector arg_stypes(args_len + aux_len);
+    size_t args_top = 0, aux_top = 0;
+    // loop over the symbol's inputs in order, reading from aux if the input is mutable, else from args
+    for (size_t i = 0; i < num_forward_inputs; ++i) {
+      const uint32_t nid = indexed_graph.input_nodes().at(i);
+      if (mutable_nodes.count(nid)) {
+        CHECK_LT(aux_top, aux_len)
+          << "Cannot find aux '" << input_names[i] << "' in provided aux to optimize_for";
+        const auto &in_arg = *(in_aux_ptr[aux_top++]);
+        arg_shapes[i] = in_arg.shape();
+        arg_dtypes[i] = in_arg.dtype();
+        arg_stypes[i] = in_arg.storage_type();
+      } else {
+        CHECK_LT(args_top, args_len)
+          << "Cannot find arg '" << input_names[i] << "' in provided args to optimize_for";
+        const auto &in_arg = *(in_args_ptr[args_top++]);
+        arg_shapes[i] = in_arg.shape();
+        arg_dtypes[i] = in_arg.dtype();
+        arg_stypes[i] = in_arg.storage_type();
+      }
     }
-    const auto& indexed_graph = g.indexed_graph();
-    const auto num_forward_inputs = indexed_graph.input_nodes().size();
+
     g.attrs["context"] = std::make_shared<nnvm::any>(
         exec::ContextVector(indexed_graph.num_nodes(), default_ctx));
+
     // infer shapes
     g = exec::InferShape(std::move(g), std::move(arg_shapes), "__shape__");
     // infer dtypes
@@ -1383,11 +1405,31 @@ int MXOptimizeForBackend(SymbolHandle sym_handle,
       common::HandleInferStorageTypeError(num_forward_inputs, indexed_graph,
                                           g.GetAttr<StorageTypeVector>("storage_type"));
     }
+    // set args/aux as graph attributes so the subgraph property can use them
+    std::vector<std::string> arg_names = sym->ListInputNames(nnvm::Symbol::kReadOnlyArgs);
+    g.attrs["in_args"] = std::make_shared<nnvm::any>(in_args_ptr);
+    g.attrs["in_arg_names"] = std::make_shared<nnvm::any>(arg_names);
+
+    std::vector<std::string> aux_names = sym->ListInputNames(nnvm::Symbol::kAuxiliaryStates);
+    g.attrs["in_aux"] = std::make_shared<nnvm::any>(in_aux_ptr);
+    g.attrs["in_aux_names"] = std::make_shared<nnvm::any>(aux_names);
+  } else {
+    // args/aux were not specified, so set null pointers and empty name lists
+    NDArray **in_args_ptr = static_cast<NDArray**>(nullptr);
+    std::vector<std::string> arg_names;
+    g.attrs["in_args"] = std::make_shared<nnvm::any>(in_args_ptr);
+    g.attrs["in_arg_names"] = std::make_shared<nnvm::any>(arg_names);
+
+    NDArray **in_aux_ptr = static_cast<NDArray**>(nullptr);
+    std::vector<std::string> aux_names;
+    g.attrs["in_aux"] = std::make_shared<nnvm::any>(in_aux_ptr);
+    g.attrs["in_aux_names"] = std::make_shared<nnvm::any>(aux_names);
   }
+  // convert the options key/value pointer arrays into a vector of pairs
   std::vector<std::pair<std::string, std::string>> options_map;
-  for (mx_uint i = 0; i < num_options; ++i) {
+  for (mx_uint i = 0; i < num_options; ++i)
     options_map.emplace_back(keys[i], vals[i]);
-  }
+
   const auto backend = mxnet::op::SubgraphBackendRegistry::Get()->GetSubgraphBackend(backend_name);
   const auto& subgraph_prop_list = backend->GetSubgraphProperties();
   for (auto property : subgraph_prop_list) {
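
Note on the MXOptimizeForBackend changes above: the reworked loop walks the symbol's inputs in indexed-graph order and pulls shape/dtype/storage-type from the aux list when the input node is mutable, and from the args list otherwise. A stripped-down, self-contained sketch of that split, keyed on names instead of node ids:

    #include <string>
    #include <unordered_set>
    #include <vector>

    // Split the ordered input list of a symbol into args and aux the same way
    // the loop above does: mutable inputs are aux states, the rest are args.
    void split_inputs(const std::vector<std::string>& all_inputs,
                      const std::unordered_set<std::string>& mutable_inputs,
                      std::vector<std::string>* args,
                      std::vector<std::string>* aux) {
      for (const std::string& name : all_inputs) {
        if (mutable_inputs.count(name))
          aux->push_back(name);   // e.g. BatchNorm running mean/var
        else
          args->push_back(name);  // weights, biases, data
      }
    }
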
diff --git a/src/common/random_generator.cu b/src/common/random_generator.cu
index a2d3e0d..8f7b959 100644
--- a/src/common/random_generator.cu
+++ b/src/common/random_generator.cu
@@ -70,6 +70,11 @@ void RandGenerator<gpu, float>::FreeState(RandGenerator<gpu> *inst) {
   CUDA_CALL(cudaFree(inst->states_));
 }
 
+template<>
+void* RandGenerator<gpu, float>::GetStates() {
+  return static_cast<void*>(states_);
+}
+
 }  // namespace random
 }  // namespace common
 }  // namespace mxnet
diff --git a/src/operator/subgraph/build_subgraph.cc b/src/operator/subgraph/build_subgraph.cc
index a66e8a1..2d5501d 100644
--- a/src/operator/subgraph/build_subgraph.cc
+++ b/src/operator/subgraph/build_subgraph.cc
@@ -560,11 +560,7 @@ void CutGraphInputs(const std::vector<nnvm::NodeEntry*> &input_entries,
     }
     nnvm::ObjectPtr n = nnvm::CreateVariableNode(
         var_name + std::to_string(name_count_map[var_name]));
-    // set attribute for subgraph input to indicate if it is from an arg/param to model
-    if (e->node->is_variable())
-      n->attrs.dict["isArg"] = "True";
-    else
-      n->attrs.dict["isArg"] = "False";
+
     *e = nnvm::NodeEntry{n, 0, 0};
   }
 }
@@ -583,7 +579,7 @@ void ReattachGraphInputs(const std::vector<nnvm::NodeEntry*> &input_entries,
 }
 
 /*!
- * \brief Replace a set of nodes belonging to the same subgraph with a subgrpah node
+ * \brief Replace a set of nodes belonging to the same subgraph with a subgraph node
  * and keep the subgraph in the subgraph node.
  */
 void CreateSubgraphNode(nnvm::Graph* g,
@@ -613,6 +609,7 @@ void CreateSubgraphNode(nnvm::Graph* g,
     sym.outputs[i] = *output_entries[i];
   }
   const SubgraphPropertyPtr& subg_prop = g->GetAttr<SubgraphPropertyPtr>("subgraph_property");
+  subg_prop->InitSubgraphInputs(&input_entries, &orig_input_entries);
   nnvm::ObjectPtr n = subg_prop->CreateSubgraphNode(sym, subgraph_selector, subgraph_id);
   // CreateSubgraphNode returns NULL if subgraph property determines that subgraph is sub-optimal
   // In that case, subgraph node is not created and graph is not modified
diff --git a/src/operator/subgraph/partitioner/custom_subgraph_property.h b/src/operator/subgraph/partitioner/custom_subgraph_property.h
index 410d983..b7f2cc2 100644
--- a/src/operator/subgraph/partitioner/custom_subgraph_property.h
+++ b/src/operator/subgraph/partitioner/custom_subgraph_property.h
@@ -33,6 +33,7 @@
 #include <string>
 #include <utility>
 #include <vector>
+#include <map>
 #include "../common.h"
 #include "../subgraph_property.h"
 #include "../../include/mxnet/lib_api.h"
@@ -99,6 +100,75 @@ class  CustomSubgraphProperty: public SubgraphProperty {
     const std::vector<std::pair<std::string, std::string>>& options_map) {
     // clear supported_nodes to remove state from previous calls
     supported_nodes.clear();
+    // get input args/aux and their names from the graph attributes
+    in_arg_names = g.GetAttr<std::vector<std::string>>("in_arg_names");
+    in_args_ptr = g.GetAttr<NDArray**>("in_args");
+    in_aux_names = g.GetAttr<std::vector<std::string>>("in_aux_names");
+    in_aux_ptr = g.GetAttr<NDArray**>("in_aux");
+
+    // convert input args
+    arg_names.clear();
+    arg_data.clear();
+    arg_shapes.clear();
+    arg_dims.clear();
+    arg_types.clear();
+    arg_verIDs.clear();
+    arg_dev_type.clear();
+    arg_dev_id.clear();
+    for (size_t i=0; i < in_arg_names.size(); i++) {
+      arg_names.push_back(in_arg_names[i].c_str());
+      const NDArray &in_arg = *(in_args_ptr[i]);
+
+#if MXNET_USE_MKLDNN == 1
+      // reorder data if in MKLDNN format
+      if (in_arg.IsMKLDNNData()) {
+        in_arg.Reorder2DefaultAsync();
+        in_arg.WaitToRead();
+      }
+#endif
+
+      // pull out parts of NDArray to send to backend
+      arg_data.push_back(in_arg.data().dptr_);
+      arg_shapes.push_back(in_arg.shape().data());
+      arg_dims.push_back(in_arg.shape().ndim());
+      arg_types.push_back(in_arg.dtype());
+      arg_verIDs.push_back(in_arg.version());
+      const char* arg_ctx_str = in_arg.ctx().dev_mask() == Context::kCPU ? "cpu" : "gpu";
+      arg_dev_type.push_back(arg_ctx_str);
+      arg_dev_id.push_back(in_arg.ctx().real_dev_id());
+    }
+
+    // convert input aux
+    aux_names.clear();
+    aux_data.clear();
+    aux_shapes.clear();
+    aux_dims.clear();
+    aux_types.clear();
+    aux_verIDs.clear();
+    aux_dev_type.clear();
+    aux_dev_id.clear();
+    for (size_t i=0; i < in_aux_names.size(); i++) {
+      aux_names.push_back(in_aux_names[i].c_str());
+      const auto &in_aux = *(in_aux_ptr[i]);
+
+#if MXNET_USE_MKLDNN == 1
+      // reorder data if in MKLDNN format
+      if (in_aux.IsMKLDNNData()) {
+        in_aux.Reorder2DefaultAsync();
+        in_aux.WaitToRead();
+      }
+#endif
+
+      // pull out parts of NDArray to send to backend
+      aux_data.push_back(in_aux.data().dptr_);
+      aux_shapes.push_back(in_aux.shape().data());
+      aux_dims.push_back(in_aux.shape().ndim());
+      aux_types.push_back(in_aux.dtype());
+      aux_verIDs.push_back(in_aux.version());
+      const char* aux_ctx_str = in_aux.ctx().dev_mask() == Context::kCPU ? "cpu" : "gpu";
+      aux_dev_type.push_back(aux_ctx_str);
+      aux_dev_id.push_back(in_aux.ctx().real_dev_id());
+    }
 
     // remove all graph attrs, some cannot be saved to json
     nnvm::Graph graph = std::move(g);
@@ -108,23 +178,37 @@ class  CustomSubgraphProperty: public SubgraphProperty {
     // set shape attrs for each node in the graph
     if (g.HasAttr("shape")) {
       mxnet::ShapeVector shapes = g.GetAttr<mxnet::ShapeVector>("shape");
-      for (unsigned i = 0; i < indexed_graph.num_nodes(); i++) {
-        nnvm::Node* node = const_cast<nnvm::Node*>(indexed_graph[i].source);
-        mxnet::TShape shape = shapes[i];
+      for (unsigned nid = 0; nid < indexed_graph.num_nodes(); nid++) {
+        nnvm::Node* node = const_cast<nnvm::Node*>(indexed_graph[nid].source);
         std::stringstream ss;
-        ss << shape;
-        node->attrs.dict["shape"] = ss.str();
+        ss << "[";
+        // set the output shapes for this node
+        for (unsigned oid = 0; oid < node->num_outputs(); oid++) {
+          const uint32_t out_entry_id = indexed_graph.entry_id(nid, oid);
+          mxnet::TShape& shape = shapes[out_entry_id];
+          ss << shape;
+          if (oid < node->num_outputs()-1) ss << ",";
+        }
+        ss << "]";
+        node->attrs.dict[MX_STR_SHAPE] = ss.str();
       }
     }
     // set dtype attrs for each node in the graph
     if (g.HasAttr("dtype")) {
       std::vector<int> dtypes = g.GetAttr<std::vector<int> >("dtype");
-      for (unsigned i = 0; i < indexed_graph.num_nodes(); i++) {
-        nnvm::Node* node = const_cast<nnvm::Node*>(indexed_graph[i].source);
-        int dtype = dtypes[i];
+      for (unsigned nid = 0; nid < indexed_graph.num_nodes(); nid++) {
+        nnvm::Node* node = const_cast<nnvm::Node*>(indexed_graph[nid].source);
         std::stringstream ss;
-        ss << dtype;
-        node->attrs.dict["dtype"] = ss.str();
+        ss << "[";
+        // set the output dtypes for this node
+        for (unsigned oid = 0; oid < node->num_outputs(); oid++) {
+          const uint32_t out_entry_id = indexed_graph.entry_id(nid, oid);
+          int dtype = dtypes[out_entry_id];
+          ss << dtype;
+          if (oid < node->num_outputs()-1) ss << ",";
+        }
+        ss << "]";
+        node->attrs.dict[MX_STR_DTYPE] = ss.str();
       }
     }
 
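
Note on the new per-node annotations above: MX_STR_SHAPE and MX_STR_DTYPE now carry one entry per node output, serialized as bracketed lists such as "[[2,3],[3,1]]" for shapes and "[0,0]" for dtypes. A small illustration of how the i-th entry can be pulled back out of such a shape string; this mirrors, but is not, the getShapeAt helper referenced later in this patch:

    #include <cstddef>
    #include <string>

    // Return the index-th bracketed shape (e.g. "[2,3]") from a string of the
    // form "[[2,3],[3,1]]"; returns an empty string if index is out of range.
    std::string shape_at(const std::string& attr, unsigned index) {
      size_t pos = 1;  // skip the outer '['
      for (unsigned cur = 0;; ++cur) {
        size_t open = attr.find('[', pos);
        size_t close = attr.find(']', open);
        if (open == std::string::npos || close == std::string::npos) return "";
        if (cur == index) return attr.substr(open, close - open + 1);
        pos = close + 1;
      }
    }
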
@@ -142,10 +226,14 @@ class  CustomSubgraphProperty: public SubgraphProperty {
     opt_keys_.clear();
     opt_vals_.clear();
     options_map_.clear();
-    for (auto kv : options_map) {
+    // store the options in the subgraph property so they can be reused later by reviewSubgraph
+    for (auto& kv : options_map) {
       options_map_.push_back(kv);
-      opt_keys_.push_back(options_map_.back().first.c_str());
-      opt_vals_.push_back(options_map_.back().second.c_str());
+    }
+    // convert options_map_ entries to char* arrays to pass to the backend library
+    for (auto& kv : options_map_) {
+      opt_keys_.push_back(kv.first.c_str());
+      opt_vals_.push_back(kv.second.c_str());
     }
 
     CHECK(call_supported_ops_(supported_ops_, json, supported_node_IDs.size(), ids,
@@ -162,9 +250,10 @@ class  CustomSubgraphProperty: public SubgraphProperty {
   }
   // override CreateSubgraphNode
   virtual nnvm::ObjectPtr CreateSubgraphNode(const nnvm::Symbol &sym,
-                                           const int subgraph_id = 0) const {
+                                             const int subgraph_id = 0) const {
     int accept = 1;
     int num_attr = 0;
+    std::map<std::string, std::string> user_attrs;
     char** attr_keys = nullptr;
     char** attr_vals = nullptr;
     if (review_subgraph_) {
@@ -173,8 +262,9 @@ class  CustomSubgraphProperty: public SubgraphProperty {
       const auto& idx = g.indexed_graph();
 
       // set isArg/isAux for each null op/param in the graph
-      const std::vector<std::string> aux_names = sym.ListInputNames(nnvm::Symbol::kAuxiliaryStates);
-      std::unordered_set<std::string> aux_set(aux_names.begin(), aux_names.end());
+      const std::vector<std::string> aux_state_names =
+        sym.ListInputNames(nnvm::Symbol::kAuxiliaryStates);
+      std::unordered_set<std::string> aux_set(aux_state_names.begin(), aux_state_names.end());
       for (unsigned i = 0; i < idx.num_nodes(); i++) {
         nnvm::Node* node = const_cast<nnvm::Node*>(idx[i].source);
         // check if this node is input to subgraph
@@ -188,31 +278,121 @@ class  CustomSubgraphProperty: public SubgraphProperty {
       }
 
       std::string subgraph_json = nnvm::pass::SaveJSON(g);
-      CHECK(call_review_subgraph_(review_subgraph_, subgraph_json.c_str(),
-                                subgraph_id, &accept, opt_keys_.data(),
-                                opt_vals_.data(), opt_keys_.size(),
-                                &attr_keys, &attr_vals, &num_attr))
+      CHECK(call_review_subgraph_(review_subgraph_, subgraph_json.c_str(),  subgraph_id,
+                                  &accept, opt_keys_.data(), opt_vals_.data(),
+                                  opt_keys_.size(),  &attr_keys, &attr_vals, &num_attr,
+                                  arg_names.data(), arg_names.size(), arg_data.data(),
+                                  arg_shapes.data(), arg_dims.data(), arg_types.data(),
+                                  arg_verIDs.data(), arg_dev_type.data(),
+                                  arg_dev_id.data(), aux_names.data(), aux_names.size(),
+                                  aux_data.data(), aux_shapes.data(), aux_dims.data(),
+                                  aux_types.data(), aux_verIDs.data(),
+                                  aux_dev_type.data(), aux_dev_id.data()))
         << "Error calling review_subgraph for '" << subgraph_prop << "'";
+
+      if (num_attr > 0) {
+        // copy user-specified attributes before freeing the library-allocated strings
+        for (int i=0; i < num_attr; i++) {
+          user_attrs[attr_keys[i]] = attr_vals[i];
+          call_free_(attr_vals[i]);
+          call_free_(attr_keys[i]);
+        }
+        // free memory used by custom op to allocate attributes
+        call_free_(attr_vals);
+        call_free_(attr_keys);
+      }
     }
+
     if (accept) {
       nnvm::ObjectPtr n = nnvm::Node::Create();
       n->attrs.op = Op::Get(subgraph_op_name);
       n->attrs.name = "_op" + std::to_string(subgraph_id);
       n->attrs.subgraphs.push_back(std::make_shared<nnvm::Symbol>(sym));
-      // set user specified attributes
-      for (int i=0; i < num_attr; i++) {
-        n->attrs.dict[attr_keys[i]] = attr_vals[i];
-        call_free_(attr_vals[i]);
-        call_free_(attr_keys[i]);
+
+      // set shapes
+      {
+        std::stringstream ss;
+        ss << "[";
+        for (unsigned i=0; i < sym.outputs.size(); i++) {
+          const nnvm::NodeEntry& e = sym.outputs[i];
+          if (e.node->attrs.dict.count("__shape__") > 0) {
+            std::string& shape = e.node->attrs.dict["__shape__"];
+            // add this shape to the list
+            ss << getShapeAt(shape, e.index);
+          }
+          if (i < sym.outputs.size()-1)
+            ss << ",";
+        }
+        ss << "]";
+        n->attrs.dict["__shape__"] = ss.str();
+      }
+      // set dtypes
+      {
+        std::stringstream ss;
+        ss << "[";
+        for (unsigned i=0; i < sym.outputs.size(); i++) {
+          const nnvm::NodeEntry& e = sym.outputs[i];
+          if (e.node->attrs.dict.count("__dtype__") > 0) {
+            std::string& dtype = e.node->attrs.dict["__dtype__"];
+            // add this dtype to the list
+            ss << getDtypeAt(dtype, e.index);
+          }
+          if (i < sym.outputs.size()-1)
+            ss << ",";
+        }
+        ss << "]";
+        n->attrs.dict["__dtype__"] = ss.str();
       }
-      // free memory used by custom op to allocate attributes
-      call_free_(attr_vals);
-      call_free_(attr_keys);
+      // set user specified attributes
+      for (auto attr : user_attrs)
+        n->attrs.dict[attr.first] = attr.second;
       return n;
     } else {
       return nullptr;
     }
   }
+
+  virtual void InitSubgraphInputs(std::vector<nnvm::NodeEntry*>* input_entries,
+                                  std::vector<nnvm::NodeEntry>* orig_input_entries) const {
+    for (size_t i = 0; i < input_entries->size(); ++i) {
+      nnvm::NodeEntry *e = input_entries->at(i);
+      nnvm::NodeEntry& orig = orig_input_entries->at(i);
+
+      // set attributes on the subgraph input to indicate whether it comes from an arg/param of the model
+      if (orig.node->is_variable()) {
+        // get name of original output entry
+        nnvm::Symbol sym;
+        sym.outputs.push_back(orig);
+        const auto output_names = sym.ListOutputNames();
+        CHECK_EQ(output_names.size(), 1U);
+        const std::string& var_name = output_names[0];
+
+        e->node->attrs.dict["isArg"] = "True";
+        e->node->attrs.dict["argName"] = var_name;
+      } else {
+        e->node->attrs.dict["isArg"] = "False";
+      }
+
+      // pass down other attributes if available
+      if (orig.node->attrs.dict.count("__dtype__") > 0) {
+        // get dtype string from other node
+        std::string& dtype = orig.node->attrs.dict["__dtype__"];
+        std::stringstream ss;
+        ss << "[" << getDtypeAt(dtype, orig.index) << "]";
+        e->node->attrs.dict["__dtype__"] = ss.str();
+      }
+
+      if (orig.node->attrs.dict.count("__shape__") > 0) {
+        // get shape string from other node
+        std::string& shape = orig.node->attrs.dict["__shape__"];
+        // create new shape string for this node
+        std::stringstream ss;
+        ss << "[" << getShapeAt(shape, orig.index) << "]";
+        e->node->attrs.dict["__shape__"] = ss.str();
+      }
+    }
+  }
+
   // override CreateSubgraphSelector
   virtual SubgraphSelectorPtr CreateSubgraphSelector() const {
     return std::make_shared<CustomContainOpSelector>(supported_nodes);
@@ -228,6 +408,17 @@ class  CustomSubgraphProperty: public SubgraphProperty {
   std::string subgraph_op_name;
   std::vector<std::pair<std::string, std::string>> options_map_;
   std::vector<const char*> opt_keys_, opt_vals_;
+  std::vector<std::string> in_arg_names, in_aux_names;
+  NDArray **in_args_ptr;
+  NDArray **in_aux_ptr;
+  std::vector<const char*> arg_names, aux_names;
+  std::vector<void*> arg_data, aux_data;
+  std::vector<const int64_t*> arg_shapes, aux_shapes;
+  std::vector<int> arg_dims, aux_dims;
+  std::vector<int> arg_types, aux_types;
+  std::vector<size_t> arg_verIDs, aux_verIDs;
+  std::vector<const char*> arg_dev_type, aux_dev_type;
+  std::vector<int> arg_dev_id, aux_dev_id;
 };
 }  // namespace op
 }  // namespace mxnet
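
Note on the new arg_*/aux_* members above: they keep one parallel vector per NDArray field (name, data pointer, shape pointer, ndim, dtype, version, device) so the whole set can be handed to the C-style callReviewSubgraph entry point as plain pointer/length pairs. A minimal sketch of that struct-of-arrays convention, with illustrative names only (not the actual member layout):

    #include <cstdint>
    #include <string>
    #include <vector>

    // Struct-of-arrays view over a set of tensors; every field lives in its own
    // parallel vector so .data() of each can be handed across a C boundary.
    struct TensorTable {
      std::vector<const char*>    names;   // borrowed: caller keeps strings alive
      std::vector<void*>          data;
      std::vector<const int64_t*> shapes;
      std::vector<int>            ndims;
      std::vector<int>            dtypes;

      void add(const std::string& name, void* ptr,
               const int64_t* shape, int ndim, int dtype) {
        names.push_back(name.c_str());
        data.push_back(ptr);
        shapes.push_back(shape);
        ndims.push_back(ndim);
        dtypes.push_back(dtype);
      }
    };
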
diff --git a/src/operator/subgraph/subgraph_property.h b/src/operator/subgraph/subgraph_property.h
index f765aba..a710c5e 100644
--- a/src/operator/subgraph/subgraph_property.h
+++ b/src/operator/subgraph/subgraph_property.h
@@ -358,6 +358,14 @@ class SubgraphProperty {
     subgraph_node->inputs = *orig_input_entries;
   }
   /*!
+   * \brief Initialize the subgraph's internal input entries from the external input entries.
+   * Called before CreateSubgraphNode; optional.
+   * \param input_entries input entries inside subgraph
+   * \param orig_input_entries input entries outside subgraph
+   */
+  virtual void InitSubgraphInputs(std::vector<nnvm::NodeEntry*>* input_entries,
+                                  std::vector<nnvm::NodeEntry>* orig_input_entries) const {}
+  /*!
    * \brief Set an attr with name in the attr map.
    */
   template <typename T>
diff --git a/tests/python/gpu/test_extensions_gpu.py b/tests/python/gpu/test_extensions_gpu.py
index 08930a3..8315b49 100644
--- a/tests/python/gpu/test_extensions_gpu.py
+++ b/tests/python/gpu/test_extensions_gpu.py
@@ -68,8 +68,24 @@ def test_custom_op_gpu():
     out_base = exe_base.forward()
     assert_almost_equal(out_base[0].asnumpy(), out[0].asnumpy(), rtol=1e-3, atol=1e-3)
 
-    # test backward
+    # test custom relu backward
     out_grad = mx.nd.ones((2,2), ctx=mx.gpu())
     exe.backward([out_grad])
     exe_base.backward([out_grad])
     assert_almost_equal(in_grad_base[0].asnumpy(), in_grad[0].asnumpy(), rtol=1e-3, atol=1e-3)
+
+    # test that custom noisy relu produces deterministic results given the same MXNet-managed seed
+    d1 = mx.nd.ones(shape=(10,10,10), ctx=mx.cpu())
+    d2 = mx.nd.ones(shape=(10,10,10), ctx=mx.gpu())
+
+    mx.random.seed(128, ctx=mx.cpu())
+    r1 = mx.nd.my_noisy_relu(d1)
+    mx.random.seed(128, ctx=mx.cpu())
+    r2 = mx.nd.my_noisy_relu(d1)
+    assert_almost_equal(r1.asnumpy(), r2.asnumpy(), rtol=1e-3, atol=1e-3)
+
+    mx.random.seed(128, ctx=mx.gpu())
+    r3 = mx.nd.my_noisy_relu(d2)
+    mx.random.seed(128, ctx=mx.gpu())
+    r4 = mx.nd.my_noisy_relu(d2)
+    assert_almost_equal(r3.asnumpy(), r4.asnumpy(), rtol=1e-3, atol=1e-3)
diff --git a/tests/python/unittest/test_extensions.py b/tests/python/unittest/test_extensions.py
index 799615b..d00f149 100644
--- a/tests/python/unittest/test_extensions.py
+++ b/tests/python/unittest/test_extensions.py
@@ -167,3 +167,17 @@ def test_subgraph():
     out4 = sym_block(mx.nd.ones((3,2)),mx.nd.ones((3,2)))
     # check that result matches one executed by MXNet
     assert_almost_equal(out[0].asnumpy(), out4[0].asnumpy(), rtol=1e-3, atol=1e-3)
+
+    # Gluon Hybridize partitioning with shapes/types
+    sym_block2 = nn.SymbolBlock(sym, [a,b])
+    sym_block2.initialize()
+    a_data = mx.nd.ones((3,2))
+    b_data = mx.nd.ones((3,2))
+    sym_block2.optimize_for(a_data, b_data, backend='myProp')
+    sym_block2.export('optimized')
+    sym_block3 = nn.SymbolBlock.imports('optimized-symbol.json',['a','b'],
+                                        'optimized-0000.params')
+
+    out5 = sym_block3(a_data, b_data)
+    # check that result matches one executed by MXNet
+    assert_almost_equal(out[0].asnumpy(), out5[0].asnumpy(), rtol=1e-3, atol=1e-3)
diff --git a/tests/python/unittest/test_subgraph_op.py b/tests/python/unittest/test_subgraph_op.py
index f1572e7..e414a98 100644
--- a/tests/python/unittest/test_subgraph_op.py
+++ b/tests/python/unittest/test_subgraph_op.py
@@ -282,7 +282,7 @@ def check_subgraph_exe6(sym, subgraph_backend, op_names):
     # infer shape/type before partition before simple_bind
     check_call(_LIB.MXSetSubgraphPropertyOpNamesV2(c_str(subgraph_backend), mx_uint(len(op_names)),
                                                  c_str_array(op_names)))
-    part_sym = sym.optimize_for(subgraph_backend, exe1.arg_dict)
+    part_sym = sym.optimize_for(subgraph_backend, exe1.arg_dict, exe1.aux_dict)
     check_call(_LIB.MXRemoveSubgraphPropertyOpNamesV2(c_str(subgraph_backend)))
 
     exe2 = part_sym.simple_bind(ctx=mx.current_context(), grad_req='null')
@@ -335,7 +335,7 @@ def check_subgraph_exe8(sym, subgraph_backend, op_names):
     # infer shape/type before partition before bind
     check_call(_LIB.MXSetSubgraphPropertyOpNamesV2(c_str(subgraph_backend), mx_uint(len(op_names)),
                                                  c_str_array(op_names)))
-    part_sym = sym.optimize_for(subgraph_backend, arg_array)
+    part_sym = sym.optimize_for(subgraph_backend, arg_array, aux_array)
     check_call(_LIB.MXRemoveSubgraphPropertyOpNamesV2(c_str(subgraph_backend)))
 
     exe2 = part_sym.bind(ctx=mx.current_context(), args=arg_array, aux_states=aux_array, grad_req='null')