Posted to commits@mxnet.apache.org by sa...@apache.org on 2020/08/18 22:51:34 UTC

[incubator-mxnet] branch v1.x updated: [1.x] Backporting #18779 to v1.x (#18894)

This is an automated email from the ASF dual-hosted git repository.

samskalicky pushed a commit to branch v1.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/v1.x by this push:
     new d1ac7c8  [1.x] Backporting #18779 to v1.x (#18894)
d1ac7c8 is described below

commit d1ac7c849dd3e0107ed0964967131f8c27ef975b
Author: Sam Skalicky <sa...@gmail.com>
AuthorDate: Tue Aug 18 15:50:30 2020 -0700

    [1.x] Backporting #18779 to v1.x (#18894)
    
    * initial commit
    
    * Support extra inputs for subgraph ops (#18779)
    
    Support additional inputs to custom subgraph ops that are not direct dependencies of ops in the subgraph. This will enable various use cases: custom control flow ops, custom ops that maintain a state that should be saved/loaded, etc.
    
    Highlights:
    
    * Added test that uses a graph pass (addInputPass) to add a new custom input to the subgraph op
    
    * Added new optional argument (clear) to hybridize & optimize_for APIs in Gluon Block to enable multiple optimizations
    
    * refactored lib_api.h JSON utilities
    
    * added new Graph data structure utilities to simplify custom graph passes
    
    * refactored custom op registration
    
    * enhanced custom subgraph op to support additional inputs to the subgraph op that are not inputs to ops in the subgraph
    
    * updated subgraph & graph pass READMEs
    
    * Added error messaging from external library
    
    * changed messages
    
    * changed to pointers and types
    
    * added cast
    
    * updated cast
    
    * fixed signed int
    
    * whitespace
    
    * fixed pass resource
    
    Co-authored-by: Ubuntu <ub...@ip-172-31-6-220.us-west-2.compute.internal>
---
 example/extensions/lib_api/init_lib.cc             |   4 +-
 example/extensions/lib_custom_op/gemm_lib.cc       |  16 +-
 example/extensions/lib_custom_op/relu_lib.cu       |   4 +-
 .../extensions/lib_custom_op/transposecsr_lib.cc   |  22 +-
 .../extensions/lib_custom_op/transposerowsp_lib.cc |  22 +-
 example/extensions/lib_pass/README.md              | 103 ++-
 example/extensions/lib_pass/example_connection.png | Bin 0 -> 8443 bytes
 example/extensions/lib_pass/pass_lib.cc            |  64 +-
 example/extensions/lib_pass/test_pass.py           |   2 +-
 example/extensions/lib_subgraph/README.md          | 165 ++--
 example/extensions/lib_subgraph/subgraph_lib.cc    | 212 +++--
 example/extensions/lib_subgraph/test_subgraph.py   |  11 +-
 include/mxnet/lib_api.h                            | 850 ++++++++++++++++-----
 python/mxnet/gluon/block.py                        |  11 +-
 src/c_api/c_api.cc                                 | 706 +++++++++++------
 .../partitioner/custom_subgraph_property.h         |  80 +-
 16 files changed, 1489 insertions(+), 783 deletions(-)

diff --git a/example/extensions/lib_api/init_lib.cc b/example/extensions/lib_api/init_lib.cc
index fb3a104..0ed4376 100644
--- a/example/extensions/lib_api/init_lib.cc
+++ b/example/extensions/lib_api/init_lib.cc
@@ -26,12 +26,14 @@
 #include <iostream>
 #include "lib_api.h"
 
+using namespace mxnet::ext;
+
 MXReturnValue initialize(int version) {
   if (version >= 10700) {
     std::cout << "MXNet version " << version << " supported" << std::endl;
     return MX_SUCCESS;
   } else {
-    std::cout << "MXNet version " << version << " not supported" << std::endl;
+    MX_ERROR_MSG << "MXNet version " << version << " not supported";
     return MX_FAIL;
   }
 }
diff --git a/example/extensions/lib_custom_op/gemm_lib.cc b/example/extensions/lib_custom_op/gemm_lib.cc
index 4f8daba..59905c8 100644
--- a/example/extensions/lib_custom_op/gemm_lib.cc
+++ b/example/extensions/lib_custom_op/gemm_lib.cc
@@ -26,6 +26,8 @@
 #include <iostream>
 #include "lib_api.h"
 
+using namespace mxnet::ext;
+
 // main matrix multiplication routine
 void gemm(const float* A, const float* B, float* C,
           const unsigned n, const unsigned k, const unsigned m) {
@@ -127,12 +129,12 @@ MXReturnValue inferType(const std::unordered_map<std::string, std::string>& attr
                         std::vector<int> *outtypes) {
   // validate inputs
   if (intypes->size() != 2) {
-    std::cout << "Expected 2 inputs to inferType" << std::endl;
+    MX_ERROR_MSG << "Expected 2 inputs to inferType";
     return MX_FAIL;
   }
   for (unsigned i = 0; i < intypes->size(); i++) {
     if (intypes->at(i) != kFloat32) {
-      std::cout << "Expected input " << i << " to have float32 type" << std::endl;
+      MX_ERROR_MSG << "Expected input " << i << " to have float32 type";
       return MX_FAIL;
     }
   }
@@ -146,11 +148,11 @@ MXReturnValue inferShape(const std::unordered_map<std::string, std::string>& att
                          std::vector<std::vector<unsigned int>>* outshapes) {
   // validate inputs
   if (inshapes->size() != 2) {
-    std::cout << "Expected 2 inputs to inferShape" << std::endl;
+    MX_ERROR_MSG << "Expected 2 inputs to inferShape";
     return MX_FAIL;
   }
   if (inshapes->at(0).size() != 2 || inshapes->at(1).size() != 2) {
-    std::cout << "Expected 2D matrices for both inputs to inferShape" << std::endl;
+    MX_ERROR_MSG << "Expected 2D matrices for both inputs to inferShape";
     return MX_FAIL;
   }
 
@@ -159,7 +161,7 @@ MXReturnValue inferShape(const std::unordered_map<std::string, std::string>& att
   unsigned kk = inshapes->at(1)[0];
   unsigned m = inshapes->at(1)[1];
   if (k != kk) {
-    std::cout << "Exected first input axis 1 equals to second input axis 0" << std::endl;
+    MX_ERROR_MSG << "Expected first input axis 1 to equal second input axis 0";
     return MX_FAIL;
   }
 
@@ -195,8 +197,6 @@ class MyStatefulGemm : public CustomStatefulOp {
     return backward(attrs_, inputs, outputs, op_res);
   }
 
-  ~MyStatefulGemm() {}
-
  private:
   int count;
   const std::unordered_map<std::string, std::string> attrs_;
@@ -230,7 +230,7 @@ MXReturnValue initialize(int version) {
     std::cout << "MXNet version " << version << " supported" << std::endl;
     return MX_SUCCESS;
   } else {
-    std::cout << "MXNet version " << version << " not supported" << std::endl;
+    MX_ERROR_MSG << "MXNet version " << version << " not supported";
     return MX_FAIL;
   }
 }
diff --git a/example/extensions/lib_custom_op/relu_lib.cu b/example/extensions/lib_custom_op/relu_lib.cu
index a4711cb..7022c76 100644
--- a/example/extensions/lib_custom_op/relu_lib.cu
+++ b/example/extensions/lib_custom_op/relu_lib.cu
@@ -26,6 +26,8 @@
 #include <iostream>
 #include "lib_api.h"
 
+using namespace mxnet::ext;
+
 #define NumThreadPerBlock 256 // mxnet recommended cuda thread number per block
 
 __global__ void relu_gpu_forward(float *out, float *in, int64_t N) {
@@ -263,7 +265,7 @@ MXReturnValue initialize(int version) {
     std::cout << "MXNet version " << version << " supported" << std::endl;
     return MX_SUCCESS;
   } else {
-    std::cout << "MXNet version " << version << " not supported" << std::endl;
+    MX_ERROR_MSG << "MXNet version " << version << " not supported";
     return MX_FAIL;
   }
 }
diff --git a/example/extensions/lib_custom_op/transposecsr_lib.cc b/example/extensions/lib_custom_op/transposecsr_lib.cc
index 224cd6a..d3941d7 100644
--- a/example/extensions/lib_custom_op/transposecsr_lib.cc
+++ b/example/extensions/lib_custom_op/transposecsr_lib.cc
@@ -26,6 +26,8 @@
 #include <iostream>
 #include "lib_api.h"
 
+using namespace mxnet::ext;
+
 void transpose(MXTensor& src, MXTensor& dst, const OpResource& res) {
   MXSparse* A = src.data<MXSparse>();
   MXSparse* B = dst.data<MXSparse>(); 
@@ -70,11 +72,11 @@ MXReturnValue forward(const std::unordered_map<std::string, std::string>& attrs,
   // The data types and storage types of inputs and outputs should be the same.  
   if(inputs->at(0).dtype != outputs->at(0).dtype ||
      inputs->at(0).stype != outputs->at(0).stype) {
-    std::cout << "Error! Expected all inputs and outputs to be the same type." 
-              << "Found input storage type:" << inputs->at(0).stype
-              << " Found output storage type:" << outputs->at(0).stype
-              << " Found input data type:" << inputs->at(0).dtype
-              << " Found output data type:" << outputs->at(0).dtype << std::endl;
+    MX_ERROR_MSG << "Error! Expected all inputs and outputs to be the same type." 
+                 << "Found input storage type:" << inputs->at(0).stype
+                 << " Found output storage type:" << outputs->at(0).stype
+                 << " Found input data type:" << inputs->at(0).dtype
+                 << " Found output data type:" << outputs->at(0).dtype;
     return MX_FAIL;
   }
 
@@ -101,11 +103,11 @@ MXReturnValue inferType(const std::unordered_map<std::string, std::string>& attr
                         std::vector<int>* outtypes) {
   // validate inputs
   if (intypes->size() != 1) {
-    std::cout << "Expected 1 inputs to inferType" << std::endl;
+    MX_ERROR_MSG << "Expected 1 inputs to inferType";
     return MX_FAIL;
   }
   if (intypes->at(0) != kFloat32) {
-    std::cout << "Expected input to have float32 type" << std::endl;
+    MX_ERROR_MSG << "Expected input to have float32 type";
     return MX_FAIL;
   }
 
@@ -117,7 +119,7 @@ MXReturnValue inferSType(const std::unordered_map<std::string, std::string>& att
                          std::vector<int>* instypes,
                          std::vector<int>* outstypes) {
   if (instypes->at(0) != kCSRStorage) {
-    std::cout << "Expected storage type is kCSRStorage" << std::endl;
+    MX_ERROR_MSG << "Expected storage type is kCSRStorage";
     return MX_FAIL;
   }
   outstypes->at(0) = instypes->at(0);
@@ -129,7 +131,7 @@ MXReturnValue inferShape(const std::unordered_map<std::string, std::string>& att
                          std::vector<std::vector<unsigned int>>* outshapes) {
   // validate inputs
   if (inshapes->size() != 1) {
-    std::cout << "Expected 1 inputs to inferShape" << std::endl;
+    MX_ERROR_MSG << "Expected 1 inputs to inferShape";
     return MX_FAIL;
   }
 
@@ -194,7 +196,7 @@ MXReturnValue initialize(int version) {
     std::cout << "MXNet version " << version << " supported" << std::endl;
     return MX_SUCCESS;
   } else {
-    std::cout << "MXNet version " << version << " not supported" << std::endl;
+    MX_ERROR_MSG << "MXNet version " << version << " not supported";
     return MX_FAIL;
   }
 }
diff --git a/example/extensions/lib_custom_op/transposerowsp_lib.cc b/example/extensions/lib_custom_op/transposerowsp_lib.cc
index 46d3c4d..90ad594 100644
--- a/example/extensions/lib_custom_op/transposerowsp_lib.cc
+++ b/example/extensions/lib_custom_op/transposerowsp_lib.cc
@@ -26,6 +26,8 @@
 #include <iostream>
 #include "lib_api.h"
 
+using namespace mxnet::ext;
+
 void transpose(MXTensor& src, MXTensor& dst, const OpResource& res) {
   MXSparse* A = src.data<MXSparse>();
   MXSparse* B = dst.data<MXSparse>(); 
@@ -73,11 +75,11 @@ MXReturnValue forward(const std::unordered_map<std::string, std::string>& attrs,
   // The data types and storage types of inputs and outputs should be the same.
   if(inputs->at(0).dtype != outputs->at(0).dtype ||
      inputs->at(0).stype != outputs->at(0).stype) {
-    std::cout << "Error! Expected all inputs and outputs to be the same type."
-              << "Found input storage type:" << inputs->at(0).stype
-              << " Found output storage type:" << outputs->at(0).stype
-              << " Found input data type:" << inputs->at(0).dtype
-              << " Found output data type:" << outputs->at(0).dtype << std::endl;
+    MX_ERROR_MSG << "Error! Expected all inputs and outputs to be the same type."
+                 << "Found input storage type:" << inputs->at(0).stype
+                 << " Found output storage type:" << outputs->at(0).stype
+                 << " Found input data type:" << inputs->at(0).dtype
+                 << " Found output data type:" << outputs->at(0).dtype;
     return MX_FAIL;
   }
   transpose(inputs->at(0), outputs->at(0), res);
@@ -103,11 +105,11 @@ MXReturnValue inferType(const std::unordered_map<std::string, std::string>& attr
                         std::vector<int>* outtypes) {
   // validate inputs
   if (intypes->size() != 1) {
-    std::cout << "Expected 1 inputs to inferType" << std::endl;
+    MX_ERROR_MSG << "Expected 1 inputs to inferType";
     return MX_FAIL;
   }
   if (intypes->at(0) != kFloat32) {
-    std::cout << "Expected input to have float32 type" << std::endl;
+    MX_ERROR_MSG << "Expected input to have float32 type";
     return MX_FAIL;
   }
 
@@ -119,7 +121,7 @@ MXReturnValue inferSType(const std::unordered_map<std::string, std::string>& att
                          std::vector<int>* instypes,
                          std::vector<int>* outstypes) {
   if (instypes->at(0) != kRowSparseStorage) {
-    std::cout << "Expected storage type is kRowSparseStorage" << std::endl;
+    MX_ERROR_MSG << "Expected storage type is kRowSparseStorage";
     return MX_FAIL;
   }
   outstypes->at(0) = instypes->at(0);
@@ -131,7 +133,7 @@ MXReturnValue inferShape(const std::unordered_map<std::string, std::string>& att
                          std::vector<std::vector<unsigned int>>* outshapes) {
   // validate inputs
   if (inshapes->size() != 1) {
-    std::cout << "Expected 1 inputs to inferShape" << std::endl;
+    MX_ERROR_MSG << "Expected 1 inputs to inferShape";
     return MX_FAIL;
   }
 
@@ -196,7 +198,7 @@ MXReturnValue initialize(int version) {
     std::cout << "MXNet version " << version << " supported" << std::endl;
     return MX_SUCCESS;
   } else {
-    std::cout << "MXNet version " << version << " not supported" << std::endl;
+    MX_ERROR_MSG << "MXNet version " << version << " not supported";
     return MX_FAIL;
   }
 }
diff --git a/example/extensions/lib_pass/README.md b/example/extensions/lib_pass/README.md
index c277124..18272c0 100644
--- a/example/extensions/lib_pass/README.md
+++ b/example/extensions/lib_pass/README.md
@@ -32,22 +32,21 @@ To run the following example, the build type of MXNet doesn’t matter since the
 
 ### Run An Example
 
-You can start getting familiar with custom passes by running an example provided in the **example/extensions/lib_pass** directory. The `myPass` example just copies the input graph to the output. Go to the **lib_pass** directory and follow these steps:
+You can start getting familiar with custom passes by running an example provided in the **example/extensions/lib_pass** directory. The `myPass` example just prints out the graph. Go to the **lib_pass** directory and follow these steps:
 
 1. Run `make`. The Makefile will generate the dynamic library **libpass_lib.so** which is compiled from the `pass_lib.cc` file. This is the library you are going to load that contains everything for the custom pass.
-2. Run `python test_pass.py`. It’ll first load the above library, find the components, register them in the MXNet backend, then execute the pass on the model and execute the operators like a regular MXNet operator and output the result. Below is the output when running the `python test_pass.py` command. Notice that it loads 2 passes: myPass and jsonPass.
+2. Run `python test_pass.py`. It’ll first load the above library, find the components, register them in the MXNet backend, then execute the pass on the model, run the operators like regular MXNet operators, and output the result. Below is the output when running the `python test_pass.py` command. Notice that it loads 1 pass: `myPass`.
 
 ```
 [10:38:03] src/c_api/c_api.cc:286: Found 0 operators in library
 [10:38:03] src/c_api/c_api.cc:785: Found 0 partitioners in library
-[07:14:00] src/c_api/c_api.cc:887: Found 2 graph passes in library
+[07:14:00] src/c_api/c_api.cc:887: Found 1 graph passes in library
 [07:14:00] src/c_api/c_api.cc:902:       Graph Pass [0] myPass
-[07:14:00] src/c_api/c_api.cc:902:       Graph Pass [1] jsonPass
 ```
 
 ### Basic Files For Custom Pass Library
 * **lib_pass/pass_lib.cc**: This file has a source code implementation of all required components to make a custom pass; it also shows how to register them so that they can be loaded by MXNet.
-* **lib_pass/Makefile**: This file compiles the source code to a dynamic shared library, with a header file `include/mxnet/lib_api.h` from MXNet source code. Currently the custom pass is compatible with C++11 onwards.
+* **lib_pass/Makefile**: This file compiles the source code to a dynamic shared library, with a header file `include/mxnet/lib_api.h` from MXNet source code. Currently the custom pass is compatible with C++11 and above.
 * **lib_pass/test_pass.py**: This file calls `mx.library.load(‘libpass_lib.so’)` to load the library containing the custom components, executes the pass on the model using the `optimize_for` API, and prints outputs of the forward passes. The outputs should be the same as the regular MXNet forward pass without running the pass.
 * **include/mxnet/lib_api.h**: This file from MXNet source code is the single header file needed to include all necessary data types and function prototypes for writing a custom library. You can either specify the include path in the `Makefile`, or copy the header file over to `example/extensions/lib_pass` folder. Note that apart from this header, the custom library is independent of MXNet source.
 ## Writing Custom Pass Library
@@ -78,38 +77,38 @@ sym_block.optimize_for(x, backend='myPass')
 
 ### Using a Custom Pass Library
 
-APIs in MXNet are available in both Symbol and Gluon APIs. For the Symbol API, the `optimize_for` API can be called on Symbol objects to return a new Symbol post graph pass.
+APIs in MXNet are available in both Symbol and Gluon APIs. For the Symbol API, `optimize_for` can be called on Symbol objects to run the graph pass and return a new Symbol.
 
-```
-optimize_for(backend, args=None, aux=None, ctx=None, **kwargs)
+```python
+sym.optimize_for(backend, args=None, aux=None, ctx=None, **kwargs)
 ```
 
 The `optimize_for` API takes at least 1 argument, `backend`, which is a string that identifies which backend to use to optimize the model. The `args` and `aux` arguments are optional and take a list of NDArray or dict of str to NDArray. They are used to infer shapes and types before executing the graph pass. The `ctx` argument is optional and takes a device context to infer storage types. It also takes any other user-specified options that will be passed to the backend APIs.
 
-For the Gluon API, the `hybridize` API can be called on HybridBlocks to execute a graph pass on the internal CachedOp Symbol.
+For the Gluon API, `hybridize` can be called on HybridBlocks to execute a graph pass on the internal CachedOp Symbol.
 
-```
-hybridize(backend=None, backend_opts=None, **kwargs)
+```python
+block.hybridize(backend=None, backend_opts=None, **kwargs)
 ```
 
 The `hybridize` function prepares the HybridBlock to be converted into a backend symbol. The `backend` argument is a string that identifies which pass will be executed on the model. The `backend_opts` takes other user-specified options that will be passed to the backend APIs. The actual pass runs once just before the first forward pass.
 
 If you just want to run a graph pass on the HybridBlock but not run a complete forward pass, you can use the `optimize_for` API that combines the work done in the `hybridize` API with part of the work done in the forward pass.
 
-```
-optimize_for(x, backend=None, backend_opts=None, **kwargs)
+```python
+block.optimize_for(x, backend=None, backend_opts=None, **kwargs)
 ```
 
 When the `optimize_for` API is called on a HybridBlock it runs the graph pass immediately. This lets users export the modified model without running a complete forward pass.
 
-```
+```python
 block.optimize_for(x, backend='myPass')
 block.export('optimized')
 ```
 
 But you can also use `optimize_for` in place of `hybridize` and run inference immediately after too.
 
-```
+```python
 block.optimize_for(x, backend='myPass')
 block(x)
 ```
@@ -120,50 +119,80 @@ There are several essential building blocks for making a custom pass:
 
 * [initialize](./pass_lib.cc#44):
     * This function is the library initialization function necessary for any dynamic libraries. It lets you check if the user is using a compatible version of MXNet. Note that this `version` parameter is passed from MXNet when the library is loaded.
-
+```c++
             MXReturnValue initialize(int version)
-
+```
 * [graphPass](./pass_lib.cc#31):
-    * This function provides a copy of the model graph as a JSON string, and provides an interface for returning a modified model JSON string. Also this is where a custom pass can validate the options specified by the user.
-
+    * This function provides a copy of the model graph, and a map of any options specified by the user.
+```c++
             MXReturnValue graphPass(
-                const std::string& in_graph,
-                const std::string** out_graph,
-                const std::unordered_map<std::string, std::string>& options,
-                const std::unordered_map<std::string, MXTensor>& args,
-                const std::unordered_map<std::string, MXTensor>& aux,
-                const PassResource& res)
-
+                mxnet::ext::Graph *g,
+                const std::unordered_map<std::string, std::string>& options)
+```
 * [REGISTER_PASS(my_pass_name)](./pass_lib.cc#L41):
     * This macro registers the custom pass and its properties to MXNet by its name. The argument to `setBody` is the `graphPass` function.
-
+```c++
             REGISTER_PASS(my_pass_name)
             .setBody(graphPass);
-
+```
 Let’s take a closer look at those registry functions:
 
-* **graphPass**: This function takes six arguments. The 1st argument is a JSON string of the model architecture graph, where nodes are inputs/params/weights and edges are data dependencies. The graph is pre-sorted in topological order. The 2nd argument is a pointer to a pointer of a JSON model string. It is expected users will dereference and assign the address of their output string allocated with `new` and `delete` will be called on it automatically. The third argument is the map of op [...]
+* **graphPass**: This function takes two arguments. The first argument is the Graph of the model architecture, where nodes are inputs/params/weights and edges are data dependencies. The second argument is the map of options specified by the user. Users can pass custom options to the pass and they are passed to this function in the `options` map.
+
+### Graph representation
 
-### Pass Resource
+The `Graph` class represents the model's architecture. Each `Node` in the graph represents an operator or weight (ie. args/aux param). Since an operator in MXNet can take multiple inputs and produce multiple outputs, each input/output is represented by a `NodeEntry`. A `Node` contains the following:
+- `op` - [string] operator name
+- `name` - [string] unique node name
+- `inputs` - [vector of NodeEntry] set of inputs to the node
+- `outputs` - [vector of NodeEntry] set of outputs from the node
+- `subgraph` - [vector of Graph] set of subgraphs in the node
+- `attrs` - [map of string to string] set of attributes for the node
 
-Some graph passes require allocating new NDArrays to add/replace model params. The `alloc_arg` and `alloc_aux` APIs enabling allocating new NDArrays and integrating them with the user-provide args and aux params. Both APIs have the following signature:
+The `inputs` are a set of `NodeEntry` where each contains a pointer to a `Node` that produces the data, and an `entry` that is the index of the output on the other `Node`. Conversely, the `outputs` are a set of `NodeEntry` where each contains a pointer to a `Node` that consumes the data, and an `entry` that is the index of the input on the other `Node`. This bidirectional dependency enables you to easily traverse the graph.
 
+A `Graph` contains the following:
+- `nodes` - [vector of Node] set of nodes in the graph
+- `inputs` - [vector of Node] set of inputs to the graph
+- `outputs` - [vector of NodeEntry] set of outputs from the graph
+- `attrs` - [map of string to JSON object] set of attributes for the graph
+
+The `nodes` are all the nodes in the graph (a superset). The `inputs` are only those nodes that are model inputs (ie. input image) or weights (ie. arg/aux params). The `outputs` are the operator outputs that are true outputs of the model (ie. prediction results).
+
+Here's an example of creating a new node and adding it to the graph:
+```c++
+g->addNode("myConv","Convolution");
 ```
-    MXTensor* alloc_xxx(const std::string& name,
-                        const std::vector<int64_t>& shapes,
+Here's an example of creating an edge between two nodes:
+```c++
+n1->outputs.push_back({n2,1});
+n2->inputs.push_back({n1,0});
+```
+Here node `n1` produces an output at index 0 that is consumed by node `n2` on the input at index 1.
+
+![example connection](example_connection.png)
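+
+With these bidirectional links you can walk the whole graph. Below is a minimal sketch using the `Node` fields described above (the `entry` field name on `NodeEntry` follows the description above, and `size`/`getNode` follow the example code elsewhere in this library):
+```c++
+for (int i = 0; i < g->size(); i++) {
+  mxnet::ext::Node* n = g->getNode(i);
+  std::cout << n->name << " (" << n->op << ")" << std::endl;
+  for (auto& in : n->inputs)   // each NodeEntry points back at the producing node
+    std::cout << "  input from: " << in.node->name << " output #" << in.entry << std::endl;
+}
+```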
+
+Some graph passes require allocating new NDArrays to add/replace model params. The `alloc_arg` and `alloc_aux` APIs enable allocating new NDArrays and integrating them with the model args and aux params. Both APIs have the following signature:
+
+```c++
+    MXTensor* alloc_xxx(const std::vector<int64_t>& shapes,
                         const MXContext &ctx,
                         MXDType dtype)
 ```
 
-If the `name` provided matches the name of an existing param it replaces the previous one. Otherwise it adds a new param to the appropriate arg/aux set.
+This function can be called on a node in the graph to allocate a tensor for that node like:
+
+```c++
+node->alloc_arg({1},MXContext::CPU(0),kFloat32);
+```
+It adds a new param to the appropriate arg/aux set when the graph pass returns. If you wish to remove an existing param, just remove the node in the graph corresponding to that param. It will be deleted after the pass completes and removed from the dictionary of args or aux (whichever it is a member of).
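+
+For example, a pass could add a brand-new param and allocate its tensor. The sketch below is only illustrative: the pass and param names are made up, and it assumes `addNode` returns a pointer to the newly created `Node` (the earlier snippet ignored the return value):
+```c++
+MXReturnValue addWeightPass(mxnet::ext::Graph *g,
+                            const std::unordered_map<std::string, std::string>& options) {
+  // "null" is the op name used for graph inputs/params
+  mxnet::ext::Node* w = g->addNode("my_new_weight", "null");
+  // allocate a 1-element float32 arg tensor for the new param
+  w->alloc_arg({1}, MXContext::CPU(0), kFloat32);
+  return MX_SUCCESS;
+}
+```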
 
 ### Parsing a JSON string
 
 To simplify custom libraries, basic JSON parsing utility functions have been implemented in the `lib_api.h` header file. You parse a JSON string into a `JsonVal` object by calling the `JsonVal::parse` API like:
 
 ```c++
-JsonParser parser;
-JsonVal json_val = parser.parse_to_json(json_string);
+JsonVal json_val = JsonVal::parse(json);
 ```
 
 A `JsonVal` is a class that represents the nodes in a JSON structure. You can check the type of a node (num, str, list, or map) by comparing the `JsonVal.type` to `STR`, `NUM`, `LIST`, or `MAP`. Then you can get that value from the node like:
@@ -187,4 +216,4 @@ switch(json_val.type) {
 }
 ```
 
-There are also convenience constructors for creating `JsonVal` objects for strings and numbers like `JsonVal("myKey")` or `JsonVal(42)`. This makes it easy to get specific keys from a map like `json_val.map[JsonVal("nodes")]`.
+You call the `dump` function on a `JsonVal` object like `json_val.dump()` to get a JSON-compatible string. There are also convenience constructors for creating `JsonVal` objects for strings and numbers like `JsonVal("myKey")` or `JsonVal(42)`. This makes it easy to get specific keys from a map like `json_val.map[JsonVal("nodes")]`.
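+
+Putting these together, a short sketch (assuming `json_string` holds a serialized graph with a "nodes" key, as in the formats used above):
+```c++
+JsonVal val = JsonVal::parse(json_string);     // parse the string into a JsonVal tree
+JsonVal nodes = val.map[JsonVal("nodes")];     // look up the "nodes" entry
+if (nodes.type == LIST)
+  std::cout << "graph has " << nodes.list.size() << " nodes" << std::endl;
+std::string round_trip = val.dump();           // serialize back to a JSON string
+```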
diff --git a/example/extensions/lib_pass/example_connection.png b/example/extensions/lib_pass/example_connection.png
new file mode 100644
index 0000000..ef56c62
Binary files /dev/null and b/example/extensions/lib_pass/example_connection.png differ
diff --git a/example/extensions/lib_pass/pass_lib.cc b/example/extensions/lib_pass/pass_lib.cc
index bbdcd73..5f51373 100644
--- a/example/extensions/lib_pass/pass_lib.cc
+++ b/example/extensions/lib_pass/pass_lib.cc
@@ -28,77 +28,27 @@
 #include <algorithm>
 #include "lib_api.h"
 
-/* \brief a basic pass that copies the input to the output */
-MXReturnValue myPass(const std::string& in_graph, const std::string** out_graph,
-                     const std::unordered_map<std::string, std::string>& options,
-                     const std::unordered_map<std::string, MXTensor>& args,
-                     const std::unordered_map<std::string, MXTensor>& aux,
-                     const PassResource& res) {
+using namespace mxnet::ext;
+
+/* \brief a basic pass that prints out the options and the graph */
+MXReturnValue myPass(mxnet::ext::Graph *g,
+                     const std::unordered_map<std::string, std::string>& options) {
   for (auto kv : options) {
     std::cout << "option: " << kv.first << " ==> " << kv.second << std::endl;
   }
-
-  *out_graph = new std::string(in_graph);
+  g->print();
   return MX_SUCCESS;
 }
 
 REGISTER_PASS(myPass)
 .setBody(myPass);
 
-/* \brief a basic pass that parses the input string to JSON and then dumps it back */
-MXReturnValue jsonPass(const std::string& in_graph, const std::string** out_graph,
-                       const std::unordered_map<std::string, std::string>& options,
-                       const std::unordered_map<std::string, MXTensor>& args,
-                       const std::unordered_map<std::string, MXTensor>& aux,
-                       const PassResource& res) {
-  for (auto kv : options)
-    std::cout << "option: " << kv.first << " ==> " << kv.second << std::endl;
-
-  // add test arg/aux
-  
-  MXTensor* arg_ = res.alloc_arg("test_arg",{1},MXContext::CPU(0),kFloat32);
-  MXTensor* aux_ = res.alloc_aux("test_aux",{1},MXContext::CPU(0),kFloat32);
-  
-  // convert json string to json object
-  JsonParser parser;
-  JsonVal json_val = parser.parse_to_json(in_graph);
-
-  // get nodes list
-  JsonVal nodes = json_val.map[JsonVal("nodes")];
-
-  // loop over nodes
-  for(int i=0; i<nodes.list.size(); i++) {
-    JsonVal node = nodes.list[i];
-    // get the op name
-    std::string op = node.map[JsonVal("op")].str;
-    // get node ID inputs to op
-    JsonVal node_inputs = node.map[JsonVal("inputs")];
-
-    //get shape/type if available
-    std::string shape;
-    int dtype = -1;
-    if(node.map.find(JsonVal("attrs")) != node.map.end()) {
-      JsonVal attrs = node.map[JsonVal("attrs")];
-      if(attrs.map.find(JsonVal("shape")) != attrs.map.end()) 
-        shape = attrs.map[JsonVal("shape")].str;
-      if(attrs.map.find(JsonVal("dtype")) != attrs.map.end())
-        dtype = std::stoi(attrs.map[JsonVal("dtype")].str);
-    }
-  }
-  
-  *out_graph = new std::string(parser.dump(json_val));
-  return MX_SUCCESS;
-}
-
-REGISTER_PASS(jsonPass)
-.setBody(jsonPass);
-
 MXReturnValue initialize(int version) {
   if (version >= 10700) {
     std::cout << "MXNet version " << version << " supported" << std::endl;
     return MX_SUCCESS;
   } else {
-    std::cout << "MXNet version " << version << " not supported" << std::endl;
+    MX_ERROR_MSG << "MXNet version " << version << " not supported";
     return MX_FAIL;
   }
 }
diff --git a/example/extensions/lib_pass/test_pass.py b/example/extensions/lib_pass/test_pass.py
index 8930c94..01d6edd 100644
--- a/example/extensions/lib_pass/test_pass.py
+++ b/example/extensions/lib_pass/test_pass.py
@@ -51,6 +51,7 @@ def test_model(pass_name):
     # execute in MXNet
     print('-------------------------------')
     print('Testing regular MXNet execution')
+
     exe = sym.bind(ctx=mx.cpu(), args={'a':mx.nd.ones((3,2)), 'b':mx.nd.ones((3,2))})
     out = exe.forward()
     print(out)
@@ -95,4 +96,3 @@ def test_model(pass_name):
     sym_block2.export('modified')
 
 test_model('myPass')
-test_model('jsonPass')
diff --git a/example/extensions/lib_subgraph/README.md b/example/extensions/lib_subgraph/README.md
index 6644a1fd..2752d27 100644
--- a/example/extensions/lib_subgraph/README.md
+++ b/example/extensions/lib_subgraph/README.md
@@ -38,11 +38,16 @@ You can start getting familiar with custom partitioners by running an example pr
 2. Run `python test_subgraph.py`. It’ll first load the above library, find the components, register them in the MXNet backend, then partition the model, run the operators like regular MXNet operators, and output the result. Below is the output when running the `python test_subgraph.py` command. Notice that it loads 1 subgraph operator (`_custom_subgraph_op`), 2 partitioners (`myProp`, `mySelect`), and 1 graph pass (`addInputPass`).
 
 ```
-[10:38:03] src/c_api/c_api.cc:286: Found 1 operators in library
-[10:38:03] src/c_api/c_api.cc:350:       Op[0] _custom_subgraph_op
-[10:38:03] src/c_api/c_api.cc:785: Found 1 partitioners in library
-[10:38:03] src/c_api/c_api.cc:801:       Partitioner[0] myProp
-[10:38:03] src/c_api/c_api.cc:821:             Strategy[0] strategy1 subgraphOp: '_custom_subgraph_op'
+[02:01:18] src/c_api/c_api.cc:515: Found 1 operators in library
+[02:01:18] src/c_api/c_api.cc:580: 	Op[0] _custom_subgraph_op
+[02:01:18] src/c_api/c_api.cc:581: 		isSubgraphOp
+[02:01:18] src/c_api/c_api.cc:1121: Found 2 partitioners in library
+[02:01:18] src/c_api/c_api.cc:1137: 	Partitioner[0] myProp
+[02:01:18] src/c_api/c_api.cc:1159: 		Strategy[0] strategy1 subgraphOp: '_custom_subgraph_op'
+[02:01:18] src/c_api/c_api.cc:1137: 	Partitioner[1] mySelect
+[02:01:18] src/c_api/c_api.cc:1159: 		Strategy[0] strategy1 subgraphOp: '_custom_subgraph_op'
+[02:01:18] src/c_api/c_api.cc:1182: Found 1 graph passes in library
+[02:01:18] src/c_api/c_api.cc:1197: 	Graph Pass [0] addInputPass
 ```
 
 ### Basic Files For Custom Partitioner Library
@@ -91,38 +96,39 @@ In the Gluon hybridize flow, the model is actually hybridized during the first i
 
 ### Using a Custom Partitioner Library
 
-Partitioning APIs in MXNet are available in both Symbol and Gluon APIs. For the Symbol API, the `optimize_for` API can be called on Symbol objects to return a partitioned Symbol.
+Partitioning APIs in MXNet are available in both Symbol and Gluon APIs. For the Symbol API, `optimize_for` can be called on Symbol objects to return a partitioned Symbol.
 
-```
-optimize_for(backend, args=None, aux=None, ctx=None, **kwargs)
+```python
+sym.optimize_for(backend, args=None, aux=None, ctx=None, **kwargs)
 ```
 
 The `optimize_for` API takes at least 1 argument, `backend`, which is a string that identifies which backend to partition the model for. The `args` and `aux` arguments are optional and take a list of NDArray or dict of str to NDArray. They are used to infer shapes and types before partitioning, and passed to the backend to use during compilation. The `ctx` argument is optional and takes a device context to infer storage types. It also takes any other user-specified options that will b [...]
 
-For the Gluon API, the `hybridize` API can be called on HybridBlocks to partition the internal CachedOp Symbol.
+For the Gluon API, `hybridize` can be called on HybridBlocks to partition the internal CachedOp Symbol.
 
-```
-hybridize(backend=None, backend_opts=None, **kwargs)
+```python
+block.hybridize(backend=None, backend_opts=None, clear=True, **kwargs)
 ```
 
-The `hybridize` function prepares the HybridBlock to be converted into a backend symbol. The `backend` argument is a string that identifies which backend that will partition the model. The `backend_opts` takes other user-specified options that will be passed to the backend partitioning APIs. The actual partitioning takes place during the forward pass.
+The `hybridize` function prepares the HybridBlock to be converted into a backend symbol. The `backend` argument is a string that identifies which backend will partition the model. The `backend_opts` are other user-specified options (as a Python dictionary of strings mapped to strings) that will be passed to the backend partitioning APIs. The `clear` argument defaults to `True` and clears any previous optimizations done on the block. If you want to chain optimizations together, set ` [...]
 
 If you just want to partition the HybridBlock but not run a complete forward pass, you can use the `optimize_for` API that combines the work done in the `hybridize` API with part of the work done in the forward pass.
 
-```
-optimize_for(x, backend=None, backend_opts=None, **kwargs)
+```python
+block.optimize_for(x, backend=None, backend_opts=None, clear=True, **kwargs)
 ```
 
-When the `optimize_for` API is called on a HybridBlock it partitions immediately. This lets users export the partitioned model without running a complete forward pass.
+When the `optimize_for` API is called on a HybridBlock it partitions immediately. This lets users export the partitioned model without running a complete forward pass. Chaining multiple optimizations is as simple as calling `optimize_for` multiple times; unlike `hybridize`, no forward pass needs to be executed between calls.
 
-```
+```python
 block.optimize_for(x, backend='myPart')
+block.optimize_for(x, backend='myOtherPart', clear=False)
 block.export('partitioned')
 ```
 
 But you can also use `optimize_for` in place of `hybridize` and run inference immediately after too.
 
-```
+```python
 block.optimize_for(x, backend='myPart')
 block(x)
 ```
@@ -133,44 +139,105 @@ There are several essential building blocks for making a custom partitioner:
 
 * [initialize](./subgraph_lib.cc#L261):
     * This function is the library initialization function necessary for any dynamic libraries. It lets you check if the user is using a compatible version of MXNet. Note that this `version` parameter is passed from MXNet when the library is loaded.
-
+```c++
             MXReturnValue initialize(int version)
-
+```
 * [supportedOps](./subgraph_lib.cc#L179):
-    * This function provides a copy of the model graph as a JSON string, and provides an interface for identifying which operators should be partitioned into a subgraph. Also this is where a custom partitioner can validate the options specified by the user.
-
+    * This function provides a copy of the model Graph, and an interface for identifying which operators should be partitioned into a subgraph. Also this is where a custom partitioner can validate the options specified by the user.
+```c++
             MXReturnValue supportedOps(
-                std::string json,
-                std::vector<bool>& ids,
-                std::unordered_map<std::string, std::string>& options)
-
+                const mxnet::ext::Graph* graph,
+                std::vector<int>* ids,
+                const std::unordered_map<std::string, std::string>& options)
+```
 * [REGISTER_PARTITIONER(my_part_name)](./subgraph_lib.cc#L257):
-    * This macro registers the custom partitioner and its properties to MXNet by its name. Notice that a partitioner can have multiple partitioning strategies. This enables multiple *passes* to be run in a single partitioning call from the user. The first argument to `addStrategy` is a user-specified name. The second argument is the `supportedOps` function. The third argument is the name of the subgraph operator to create for each subgraph created during partitioning (see below for more  [...]
-
+    * This macro registers the custom partitioner and its properties to MXNet by its name. Notice that a partitioner can have multiple partitioning strategies. This enables multiple *passes* to be run in a single partitioning call from the user. The first argument to `addStrategy` is a user-specified name. The second argument is the name of the subgraph operator to create for each subgraph created during partitioning (see below for more info about subgraph operators). The `setSupportedOp [...]
+```c++
             REGISTER_PARTITIONER(my_part_name)
-            .addStrategy("strategy1", supportedOps, "_custom_subgraph_op")
+            .addStrategy("strategy1", "_custom_subgraph_op")
+            .setSupportedOps("strategy1", supportedOps)
             .setReviewSubgraph("strategy1", reviewSubgraph);
-
-
+```
 Also there are some optional functions you can specify:
 
 * [reviewSubgraph](./subgraph_lib.cc#L219):
     * This function provides an opportunity to accept/reject a subgraph after MXNet partitions it. It also allows specifying custom attributes on the subgraph (ie. user-generated IDs). If you do not register this function, subgraphs will be accepted by default. 
-
+```c++
             MXReturnValue reviewSubgraph(
-                std::string json,
+                const mxnet::ext::Graph* subgraph,
                 int subgraph_id,
                 bool* accept,
-                std::unordered_map<std::string, std::string>& options,
-                std::unordered_map<std::string, std::string>& attrs,
-                std::map<std::string, MXTensor>& args,
-                std::map<std::string, MXTensor>& aux)
-
+                const std::unordered_map<std::string, std::string>& options)
+```
 Let’s take a closer look at those registry functions:
 
-* **supportedOps**: This function takes four arguments. The 1st argument is a JSON string of the model architecture graph, where nodes are inputs/params/weights and edges are data dependencies. The graph is pre-sorted in topological order. The 2nd argument is an array of booleans, one for each operator in the model. When traversing the graph, operators to be partitioned into subgraphs are identified and an entry is set to `true` for the index in the `ids` array corresponding to the node  [...]
+* **supportedOps**: This function takes 3 arguments. The 1st argument is the model architecture graph, where nodes are inputs/params/weights and edges are data dependencies. The graph is pre-sorted in topological order. The 2nd argument is an array of integers, one for each operator in the model. When traversing the graph, operators to be partitioned into subgraphs are identified and an entry is set to a value for the index in the `ids` array corresponding to the node ID. Setting a non-n [...]
+
+* **reviewSubgraph**: This function takes four arguments. The 1st argument is the newly partitioned subgraph. The 2nd argument is the subgraph ID, this is just a number MXNet uses to identify this particular subgraph (it starts at zero and increments, unique for each subgraph in the model). The 3rd argument is an output to be set in this function to tell MXNet whether to accept (value: `true`) or reject (value: `false`) the subgraph. You might want to reject a subgraph if it doesnt inclu [...]
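+
+As a minimal sketch, a `supportedOps` callback that mirrors the example library (allow only `exp` and `log` ops into subgraphs) could look like:
+```c++
+MXReturnValue mySupportedOps(const mxnet::ext::Graph* graph,
+                             std::vector<int>* ids,
+                             const std::unordered_map<std::string, std::string>& options) {
+  for (int i = 0; i < graph->size(); i++) {
+    const mxnet::ext::Node* node = graph->getNode(i);
+    // a value of -1 tells MXNet this op can go into any subgraph
+    if (node->op == "exp" || node->op == "log")
+      ids->at(i) = -1;
+  }
+  return MX_SUCCESS;
+}
+```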
+
+### Writing a Custom Selector
+Instead of implementing the `supportedOps` API, you can implement a custom selector class for more control over partitioning.
 
-* **reviewSubgraph**: This function takes five arguments. The 1st argument is a JSON string of the newly partitioned subgraph. The 2nd argument is the subgraph ID, this is just a number MXNet uses to identify this particular subgraph (it starts at zero and increments, unique for each subgraph in the model). The 3rd argument is an output to be set in this function to tell MXNet whether to accept (value: `true`) or reject (value: `false`) the subgraph. You might want to reject a subgraph i [...]
+* [createSelector](./subgraph_lib.cc#L321):
+    * This function provides a copy of the model graph as the first argument. The 2nd argument is a placeholder for a `CustomOpSelector` object. You must define a class that inherits from the `CustomOpSelector` class and override the required functions. Then you need to create an instance of your class and assign it to the placeholder. The last argument is a map of user-specified options.
+```c++
+            MXReturnValue createSelector(
+                const mxnet::ext::Graph *graph,
+                CustomOpSelector** sel_inst,
+                const std::unordered_map<std::string, std::string>& options)
+```
+Instead of registering a `supportedOps` API, register the `setCreateSelector` API. 
+```c++
+            REGISTER_PARTITIONER(my_part_name)
+            .addStrategy("strategy1", "_custom_subgraph_op")
+            .setCreateSelector("strategy1", createSelector)
+            .setReviewSubgraph("strategy1", reviewSubgraph);
+```
+When implementing your own selector class, you must inherit from the `CustomOpSelector` class and define the following APIs:
+* [Select](./subgraph_lib.cc#L301):
+    * This function selects a node to include in a subgraph by the index of the node (`nodeID`) in the graph. Return `true` to include this node or `false` to reject this node. 
+```c++
+            bool Select(
+                int nodeID)
+```
+* [SelectInput](./subgraph_lib.cc#L304):
+    * This function grows the subgraph from a node (`nodeID`) to a node that produces one of its inputs (`input_nodeID`). Return `true` to include this node (`input_nodeID`) or `false` to reject this node. 
+```c++
+            bool SelectInput(
+                int nodeID,
+                int input_nodeID)
+```
+* [SelectOutput](./subgraph_lib.cc#L304):
+    * This function grows the subgraph from a node (`nodeID`) to a node that consumes one of its outputs (`output_nodeID`). Return `true` to include this node (`output_nodeID`) or `false` to reject this node. 
+```c++
+            bool SelectOutput(
+                int nodeID,
+                int output_nodeID)
+```
+All of these APIs refer to the model's graph that is provided to the `createSelector` API. When you implement your custom `createSelector` function, you can pass the graph and options to the constructor of your class like this:
+```c++
+MXReturnValue myCreateSelector(const mxnet::ext::Graph *graph,
+                               CustomOpSelector** sel_inst,
+                               const std::unordered_map<std::string, std::string>& options) {
+  *sel_inst = new MySelector(graph, options);
+  return MX_SUCCESS;
+}
+```
+In addition to the 3 required APIs shown above, you can also implement the following optional APIs for your `CustomOpSelector` class:
+* [Filter](./subgraph_lib.cc#L310):
+    * This function enables reviewing the candidate nodes to include in the subgraph. The `candidates` are the indices of nodes in the graph to be included in the subgraph. The 2nd argument `keep` is an empty vector to be filled with the indices of nodes you wish to keep in the subgraph. Any remaining candidate nodes not added to `keep` will be excluded from the subgraph. The following function body shows the default behavior when not overridden: keep all candidates:
+```c++
+            void Filter(
+                std::vector<int>& candidates,
+                std::vector<int>& keep) {
+              keep.insert(keep.end(), candidates.begin(), candidates.end());
+            }
+```
+* [Reset](./subgraph_lib.cc#L314):
+    * This function provides an opportunity to reset any selector state between subgraphs. It is called after growing a subgraph, and before `Filter`. There is no default behavior.
+```c++
+            virtual void Reset() {}
+```
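+
+Putting the required APIs together, a sketch of a selector class (the class name is illustrative; it selects only `exp` nodes, does not grow the subgraph, and assumes the `CustomOpSelector` methods are virtual as shown for `Reset` above):
+```c++
+class MyExpSelector : public CustomOpSelector {
+ public:
+  MyExpSelector(const mxnet::ext::Graph* graph,
+                const std::unordered_map<std::string, std::string>& options)
+    : graph_(graph) {}
+  bool Select(int nodeID) override {
+    return graph_->getNode(nodeID)->op == "exp";
+  }
+  bool SelectInput(int nodeID, int input_nodeID) override { return false; }
+  bool SelectOutput(int nodeID, int output_nodeID) override { return false; }
+ private:
+  const mxnet::ext::Graph* graph_;
+};
+```
+An instance of this class would then be created and assigned in a `createSelector` function, just like `MySelector` in the snippet above.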
 
 ### Writing A Custom Subgraph Operator
 
@@ -178,19 +245,31 @@ A partitioning strategy specifies how to partition a model and isolate operators
 
 When registering a custom subgraph operator, all that's needed is to register a `createOpState` function and to mark the operator as a subgraph operator by calling the `setIsSubgraphOp` API like:
 
-```
+```c++
 REGISTER_OP(my_subgraph_op)
 .setIsSubgraphOp()
 .setCreateOpState(createOpState, "cpu");
 ```
 
+### Converting a JSON string encoded graph
+
+A Graph object can be created from a JSON string containing a graph/subgraph like:
+
+```c++
+mxnet::ext::Graph* g = mxnet::ext::Graph::fromString(json);
+```
+
+It can be converted back to a JSON string just as easily:
+```c++
+std::string json = g->toString();
+```
+
 ### Parsing a JSON string
 
 To simplify custom partitioner libraries, basic JSON parsing utility functions have been implemented in the `lib_api.h` header file. You parse a JSON string into a `JsonVal` object by calling the `JsonVal::parse` API like:
 
 ```c++
-JsonParser parser;
-JsonVal json_val = parser.parse_to_json(json_string);
+JsonVal json_val = JsonVal::parse(json);
 ```
 
 A `JsonVal` is a class that represents the nodes in a JSON structure. You can check the type of a node (num, str, list, or map) by comparing the `JsonVal.type` to `STR`, `NUM`, `LIST`, or `MAP`. Then you can get that value from the node like:
@@ -214,4 +293,4 @@ switch(json_val.type) {
 }
 ```
 
-There are also convenience constructors for creating `JsonVal` objects for strings and numbers like `JsonVal("myKey")` or `JsonVal(42)`. This makes it easy to get specific keys from a map like `json_val.map[JsonVal("nodes")]`.
\ No newline at end of file
+You call the `dump` function on a `JsonVal` object like `json_val.dump()` to get a JSON-compatible string. There are also convenience constructors for creating `JsonVal` objects for strings and numbers like `JsonVal("myKey")` or `JsonVal(42)`. This makes it easy to get specific keys from a map like `json_val.map[JsonVal("nodes")]`.
diff --git a/example/extensions/lib_subgraph/subgraph_lib.cc b/example/extensions/lib_subgraph/subgraph_lib.cc
index 2844207..2f954e0 100644
--- a/example/extensions/lib_subgraph/subgraph_lib.cc
+++ b/example/extensions/lib_subgraph/subgraph_lib.cc
@@ -28,19 +28,21 @@
 #include <algorithm>
 #include "lib_api.h"
 
+using namespace mxnet::ext;
+
 /* function to execute log operator on floats */
-void myLog(MXTensor &in, MXTensor &out) {
-  float* inp = in.data<float>();
-  float* outp = out.data<float>();
-  for (int64_t i = 0; i < in.size(); i++) {
+void myLog(MXTensor *in, MXTensor *out) {
+  float* inp = in->data<float>();
+  float* outp = out->data<float>();
+  for (int64_t i = 0; i < in->size(); i++) {
     outp[i] = logf(inp[i]);
   }
 }
 /* function to execute exp operator on floats */
-void myExp(MXTensor &in, MXTensor &out) {
-  float* inp = in.data<float>();
-  float* outp =out.data<float>();
-  for (int64_t i = 0; i < in.size(); i++) {
+void myExp(MXTensor *in, MXTensor *out) {
+  float* inp = in->data<float>();
+  float* outp =out->data<float>();
+  for (int64_t i = 0; i < in->size(); i++) {
     outp[i] = expf(inp[i]);
   }
 }
@@ -52,15 +54,10 @@ void myExp(MXTensor &in, MXTensor &out) {
  */
 MXReturnValue myExecutor(std::vector<MXTensor>* inputs,
                          std::vector<MXTensor>* outputs,
-                         const std::string& subgraph_sym) {
-  std::cout << "Info: subgraph symbol is: " << std::endl;
-  std::cout << subgraph_sym << std::endl;
+                         mxnet::ext::Graph *subgraph) {
+  std::cout << "Info: subgraph is: " << std::endl;
+  subgraph->print();
 
-  // convert json string to json object
-  JsonParser parser;
-  JsonVal json_val = parser.parse_to_json(subgraph_sym);
-  // get nodes list
-  JsonVal nodes = json_val.map[JsonVal("nodes")];
   //counter for inputs
   int input_cnt = 0;
   // temporary tensor storage
@@ -69,41 +66,40 @@ MXReturnValue myExecutor(std::vector<MXTensor>* inputs,
   std::vector<void*> to_free;
 
   // loop over nodes
-  for(int i=0; i<nodes.list.size(); i++) {
-    JsonVal node = nodes.list[i];
-    // get the op name
-    std::string op = node.map[JsonVal("op")].str;
-    // get node ID inputs to op
-    JsonVal node_inputs = node.map[JsonVal("inputs")];
-    
+  for(int i=0; i<subgraph->size(); i++) {
+    mxnet::ext::Node* node = subgraph->getNode(i);
     // handle each op type
-    if (op.compare("null") == 0) {
-      // null is an input data to the subgraph, add to data storage
-      data.push_back(inputs->at(input_cnt++));
-    } else if (op.compare("log") == 0) {
+    if (node->op.compare("null") == 0) {
+      // set tensor for this input to the subgraph
+      node->tensor = &inputs->at(input_cnt++);
+    } else if (node->op.compare("log") == 0) {
       // get input tensor based on node ID inputs from data storage
-      MXTensor &input = data[node_inputs.list[0].list[0].num];
+      MXTensor *input = node->inputs.at(0).node->tensor;
       // create temporary storage
-      MXTensor tmp(malloc(input.size()*4), input.shape, input.dtype, 0, MXContext::CPU(0), kDefaultStorage);
+      MXTensor tmp(malloc(input->size()*4), input->shape, input->dtype, 0, MXContext::CPU(0), kDefaultStorage);  // NOLINT
       // save allocated ptr to free later
       to_free.push_back(tmp.data_ptr);
       // execute log operator
-      myLog(input,tmp);
+      myLog(input,&tmp);
       // add output tensor to data storage
       data.push_back(tmp);
-    } else if (op.compare("exp") == 0) {
+      // set tensor for this node so we can read it later
+      node->tensor = &data.back();
+    } else if (node->op.compare("exp") == 0) {
       // get input tensor based on node ID inputs from data storage
-      MXTensor &input = data[node_inputs.list[0].list[0].num];
+      MXTensor *input = node->inputs.at(0).node->tensor;
       // create temporary storage
-      MXTensor tmp(malloc(input.size()*4), input.shape, input.dtype, 0, MXContext::CPU(0), kDefaultStorage);
+      MXTensor tmp(malloc(input->size()*4), input->shape, input->dtype, 0, MXContext::CPU(0), kDefaultStorage);  // NOLINT
       // save allocated ptr to free later
       to_free.push_back(tmp.data_ptr);
       // execute exp operator 
-      myExp(input,tmp);
+      myExp(input,&tmp);
       // add output tensor to data storage
       data.push_back(tmp);
+      // set tensor for this node so we can read it later
+      node->tensor = &data.back();
     } else {
-      std::cout << "Error! Unsupported op '" << op << "' found in myExecutor";
+      MX_ERROR_MSG << "Error! Unsupported op '" << node->op << "' found in myExecutor";
       // free allocated temporary storage
       for (void* ptr : to_free)
         free(ptr);
@@ -111,18 +107,16 @@ MXReturnValue myExecutor(std::vector<MXTensor>* inputs,
     }
   }
   
-  // get list of outputs from subgraph
-  JsonVal heads = json_val.map[JsonVal("heads")];
   // copy all operator results to outputs of subgraph
-  for (int j = 0; j < heads.list.size(); j++) {
+  for (int j = 0; j < subgraph->outputs.size(); j++) {
     // get computed result
-    MXTensor &result = data[heads.list[0].list[0].num];
+    MXTensor *result = subgraph->outputs[j].node->tensor;
     // get output tensor to pass to MX
     MXTensor &out = outputs->at(j);
     float *out_data = out.data<float>();
-    float *res_data = result.data<float>();
+    float *res_data = result->data<float>();
     // loop and copy data
-    for (int64_t i = 0; i < result.size(); i++) {
+    for (int64_t i = 0; i < result->size(); i++) {
       out_data[i] = res_data[i];
     }
   }
@@ -137,22 +131,26 @@ MXReturnValue myExecutor(std::vector<MXTensor>* inputs,
 
 class MyStatefulOp : public CustomStatefulOp {
  public:
-  explicit MyStatefulOp(const std::string& sym,
+  explicit MyStatefulOp(std::string json,
                         const std::unordered_map<std::string, std::string>& attrs)
-    : subgraph_sym(sym), attrs_(attrs) {
-    for (auto kv : attrs) {
+    : attrs_(attrs) {
+    for (const auto &kv : attrs) {
       std::cout << "subgraphOp attributes: " << kv.first << " ==> " << kv.second << std::endl;
     }
+    subgraph_ = mxnet::ext::Graph::fromString(json);
   }
 
   MXReturnValue Forward(std::vector<MXTensor>* inputs,
                         std::vector<MXTensor>* outputs,
-                        const OpResource& op_res) {
-    return myExecutor(inputs, outputs, subgraph_sym);
+                        const OpResource& op_res) override {
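+    // if extra inputs were added to this subgraph op, report how many and the total number of input tensors received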
+    if(attrs_.count(MX_STR_EXTRA_INPUTS) > 0 && std::stoi(attrs_.at(MX_STR_EXTRA_INPUTS)) > 0)
+      std::cout << "forward::extra_inputs(" << attrs_.at(MX_STR_EXTRA_INPUTS) << ")::inputs ["
+		<< inputs->size() << "]" << std::endl;
+    return myExecutor(inputs, outputs, subgraph_);
   }
 
  private:
-  const std::string subgraph_sym;
+  mxnet::ext::Graph *subgraph_;
   const std::unordered_map<std::string, std::string> attrs_;
 };
 
@@ -176,39 +174,30 @@ REGISTER_OP(_custom_subgraph_op)
 
 const std::vector<std::string> op_names({"exp","log"});
 
-MXReturnValue mySupportedOps(const std::string& json,
+MXReturnValue mySupportedOps(const mxnet::ext::Graph* graph,
                              std::vector<int>* ids,
                              const std::unordered_map<std::string, std::string>& options) {
   for (auto kv : options) {
     std::cout << "option: " << kv.first << " ==> " << kv.second << std::endl;
   }
-  //convert json string to json object
-  JsonParser parser;
-  JsonVal json_val = parser.parse_to_json(json);
-  //get nodes list
-  JsonVal nodes = json_val.map[JsonVal("nodes")];
 
   //loop over nodes
-  for(int i=0; i<nodes.list.size(); i++) {
-    JsonVal node = nodes.list[i];
-    JsonVal op = node.map[JsonVal("op")];
+  for(int i=0; i<graph->size(); i++) {
+    const mxnet::ext::Node *node = graph->getNode(i);
 
     //get shape/type if available
     std::string shape;
     int dtype = -1;
-    if(node.map.find(JsonVal("attrs")) != node.map.end()) {
-      JsonVal attrs = node.map[JsonVal("attrs")];
-      if(attrs.map.find(JsonVal("shape")) != attrs.map.end()) 
-        shape = attrs.map[JsonVal("shape")].str;
-      if(attrs.map.find(JsonVal("dtype")) != attrs.map.end())
-        dtype = std::stoi(attrs.map[JsonVal("dtype")].str);
-    }
+    if(node->attrs.count("shape") > 0)
+      shape = node->attrs.at("shape");
+    if(node->attrs.count("dtype") > 0)
+      dtype = std::stoi(node->attrs.at("dtype"));
 
     //check if op dtype is float, and if option was specified to require float types
     if((dtype == kFloat32 && options.count("reqFloat") > 0) || options.count("reqFloat") == 0) {
-      //check if op is in whitelist
-      if(std::find(op_names.begin(),op_names.end(),op.str.c_str()) != op_names.end()) {
-        // found op in whitelist, set value to -1 to include op in any subgraph
+      //check if op is in allowlist
+      if(std::find(op_names.begin(),op_names.end(),node->op.c_str()) != op_names.end()) {
+        // found op in allowlist, set value to -1 to include op in any subgraph
         ids->at(i) = -1;
       }
     }
@@ -216,30 +205,11 @@ MXReturnValue mySupportedOps(const std::string& json,
   return MX_SUCCESS;
 }
 
-MXReturnValue myReviewSubgraph(const std::string& json, int subgraph_id, bool* accept,
-                               const std::unordered_map<std::string, std::string>& options,
-                               std::unordered_map<std::string, std::string>* attrs,
-                               const std::unordered_map<std::string, MXTensor>& args,
-                               const std::unordered_map<std::string, MXTensor>& aux) {
+MXReturnValue myReviewSubgraph(const mxnet::ext::Graph *subgraph, int subgraph_id, bool* accept,
+                               const std::unordered_map<std::string, std::string>& options) {
   for (auto kv : options) {
     std::cout << "option: " << kv.first << " ==> " << kv.second << std::endl;
   }
-  for (auto kv : args) {
-    std::cout << "arg: " << kv.first << " ==> (";
-    for (auto s : kv.second.shape)
-      std::cout << s << ",";
-    std::cout << ") [";
-    for (int i=0; i<kv.second.size(); i++)
-      std::cout << kv.second.data<float>()[i] << ", ";
-    std::cout << "]" << std::endl;
-  }
-
-  // check if option `reqArgs` was specified, and if so check if args were provided
-  if(options.count("reqArgs") > 0 && args.size() == 0) {
-    *accept = false;
-    std::cout << "rejecting subgraph since args were not provided" << std::endl;
-    return MX_SUCCESS;
-  }
 
   // check if option `reject` was specified, and if so check if value is 'True'
   if(options.count("reject") > 0 && options.at("reject").compare("True") == 0) {
@@ -249,7 +219,6 @@ MXReturnValue myReviewSubgraph(const std::string& json, int subgraph_id, bool* a
   } else {
     *accept = true;
     std::cout << "accepting subgraph" << std::endl;
-    attrs->insert(std::pair<std::string,std::string>("myKey","myVal"));
   }
   return MX_SUCCESS;
 }
@@ -261,39 +230,30 @@ REGISTER_PARTITIONER(myProp)
 
 class MySelector : public CustomOpSelector {
  public:
-  MySelector(const std::string& json,
+  MySelector(const mxnet::ext::Graph *graph,
              const std::unordered_map<std::string, std::string>& options) :
-    graph_json(json), options_(options) {
+    graph_(graph), options_(options) {
     for (auto kv : options) {
       std::cout << "selector options: " << kv.first
                 << " ==> " << kv.second << std::endl;
     }
-    //convert json string to json object
-    JsonParser parser;
-    JsonVal json_val = parser.parse_to_json(json);
-    //get nodes list
-    nodes = json_val.map[JsonVal("nodes")];
   }
   bool chooseNode(int nodeID) {
-    JsonVal node = nodes.list[nodeID];
-    JsonVal op = node.map[JsonVal("op")];
+    const mxnet::ext::Node *node = graph_->getNode(nodeID);
 
     //get shape/type if available
     std::string shape;
     int dtype = -1;
-    if(node.map.find(JsonVal("attrs")) != node.map.end()) {
-      JsonVal attrs = node.map[JsonVal("attrs")];
-      if(attrs.map.find(JsonVal("shape")) != attrs.map.end()) 
-        shape = attrs.map[JsonVal("shape")].str;
-      if(attrs.map.find(JsonVal("dtype")) != attrs.map.end())
-        dtype = std::stoi(attrs.map[JsonVal("dtype")].str);
-    }
+    if(node->attrs.count("shape") > 0)
+      shape = node->attrs.at("shape");
+    if(node->attrs.count("dtype") > 0)
+      dtype = std::stoi(node->attrs.at("dtype"));
 
     //check if op dtype is float, and if option was specified to require float types
     if((dtype == kFloat32 && options_.count("reqFloat") > 0) || options_.count("reqFloat") == 0) {
-      //check if op is in whitelist
-      if(std::find(op_names.begin(),op_names.end(),op.str.c_str()) != op_names.end()) {
-        // found op in whitelist, return true to include op subgraph
+      //check if op is in allowlist
+      if(std::find(op_names.begin(),op_names.end(),node->op.c_str()) != op_names.end()) {
+        // found op in allowlist, return true to include op in the subgraph
+        return true;
       }
     }
@@ -314,14 +274,13 @@ class MySelector : public CustomOpSelector {
   }
   virtual void Reset() {}
  private:
-  std::string graph_json;
-  JsonVal nodes;
+  const mxnet::ext::Graph *graph_;
   const std::unordered_map<std::string, std::string> options_;
 };
 
-MXReturnValue createSelector(const std::string& json, CustomOpSelector** sel_inst,
+MXReturnValue createSelector(const mxnet::ext::Graph *graph, CustomOpSelector** sel_inst,
                              const std::unordered_map<std::string, std::string>& options) {
-  *sel_inst = new MySelector(json, options);
+  *sel_inst = new MySelector(graph, options);
   std::cout << "Info: selector created" << std::endl;
   return MX_SUCCESS;
 }
@@ -331,12 +290,41 @@ REGISTER_PARTITIONER(mySelect)
 .setCreateSelector("strategy1", createSelector)
 .setReviewSubgraph("strategy1", myReviewSubgraph);
 
+/* \brief a basic pass that adds a new input for subgraph ops */
+MXReturnValue addInputPass(mxnet::ext::Graph *graph,
+                           const std::unordered_map<std::string, std::string>& options) {
+  //find node with '_custom_subgraph_op' op type
+  for(int i=0; i<graph->size(); i++) {
+    mxnet::ext::Node* n = graph->getNode(i);
+    if(n->op.compare("_custom_subgraph_op") == 0) {
+      //set extra input
+      n->attrs[MX_STR_EXTRA_INPUTS] = std::to_string(1);
+      
+      //create a new input Node
+      Node* input = graph->addNode(n->name + "_input", "null");
+      //set this node as an input in the graph
+      graph->inputs.push_back(input);
+      //connect new input to node
+      input->outputs.push_back({n,(int)(n->inputs.size())});
+      //connect node to new input
+      n->inputs.push_back({input,0});
+      // add a corresponding tensor for this input
+      input->alloc_arg({1},MXContext::CPU(0),kFloat32);
+    }
+  }
+
+  return MX_SUCCESS;
+}
+
+REGISTER_PASS(addInputPass)
+.setBody(addInputPass);
+
 MXReturnValue initialize(int version) {
   if (version >= 10700) {
     std::cout << "MXNet version " << version << " supported" << std::endl;
     return MX_SUCCESS;
   } else {
-    std::cout << "MXNet version " << version << " not supported" << std::endl;
+    MX_ERROR_MSG << "MXNet version " << version << " not supported by custom library" << std::endl;
     return MX_FAIL;
   }
 }
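
Note that MyStatefulOp::Forward above only reports how many inputs it received when the
MX_STR_EXTRA_INPUTS attribute is set. To actually consume the input appended by
addInputPass, a library author could locate it roughly as follows (illustrative sketch
only; getExtraInput is a hypothetical helper, it assumes the extra inputs follow the
original subgraph inputs in the inputs vector, and it uses the same headers and
using-declarations as subgraph_lib.cc above):

// Illustrative sketch: find the trailing extra input added by a graph pass.
const MXTensor* getExtraInput(const std::vector<MXTensor>& inputs,
                              const std::unordered_map<std::string, std::string>& attrs) {
  if (attrs.count(MX_STR_EXTRA_INPUTS) == 0)
    return nullptr;
  int num_extra = std::stoi(attrs.at(MX_STR_EXTRA_INPUTS));
  if (num_extra <= 0 || static_cast<size_t>(num_extra) > inputs.size())
    return nullptr;
  // extras are appended after the regular subgraph inputs, so read from the back
  return &inputs[inputs.size() - num_extra];
}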
diff --git a/example/extensions/lib_subgraph/test_subgraph.py b/example/extensions/lib_subgraph/test_subgraph.py
index eb7102a..267a417 100644
--- a/example/extensions/lib_subgraph/test_subgraph.py
+++ b/example/extensions/lib_subgraph/test_subgraph.py
@@ -93,8 +93,8 @@ def test(backend):
     sym_block = nn.SymbolBlock(sym, inputs)
     sym_block.initialize()
     sym_block.hybridize(backend=backend)
-    out4 = sym_block(mx.nd.ones((3,2)),mx.nd.ones((3,2)))
-    print(out4)
+    out2 = sym_block(mx.nd.ones((3,2)),mx.nd.ones((3,2)))
+    print(out2)
 
     # Gluon Hybridize partitioning with shapes/types without inference
     print('-------------------------------')
@@ -105,6 +105,13 @@ def test(backend):
     sym_block2.optimize_for(mx.nd.ones((3,2)), mx.nd.ones((3,2)), backend=backend)
     sym_block2.export('partitioned')
 
+    # Test with additional input to subgraph op
+    print('-------------------------------')
+    print('Testing %s Gluon Hybridize partitioning with extra input' % backend)
+    sym_block2.optimize_for(mx.nd.ones((3,2)), mx.nd.ones((3,2)), backend="addInputPass", clear=False)
+    out3 = sym_block2(mx.nd.ones((3,2)),mx.nd.ones((3,2)))
+    print(out3)
+    
     ###############################################
     # Test with subgraph directly consuming params
     ###############################################
diff --git a/include/mxnet/lib_api.h b/include/mxnet/lib_api.h
index c8ba712..edab4a4 100644
--- a/include/mxnet/lib_api.h
+++ b/include/mxnet/lib_api.h
@@ -37,19 +37,22 @@
 #include <string.h>
 #include <vector>
 #include <map>
+#include <unordered_set>
 #include <unordered_map>
 #include <string>
 #include <iostream>
 #include <utility>
 #include <stdexcept>
+#include <functional>
 #include <random>
+#include <sstream>
 
 #if defined(__NVCC__)
   #include <curand_kernel.h>
 #endif
 
 /* Make sure to update the version number every time you make changes */
-#define MX_LIBRARY_VERSION 7
+#define MX_LIBRARY_VERSION 8
 
 /*!
  * \brief For loading multiple custom op libraries in Linux, exporting same symbol multiple
@@ -210,6 +213,9 @@ extern "C" {
 #endif
 #endif
 
+namespace mxnet {
+namespace ext {
+
 /*!
  * \brief Tensor data type, consistent with mshadow data type
  */
@@ -462,12 +468,14 @@ typedef std::mt19937 mx_cpu_rand_t;
 #define MX_NUM_CPU_RANDOM_STATES 1024
 #define MX_NUM_GPU_RANDOM_STATES 32768
 
+/* \brief Class to help allocate new args/aux params in graph passes */
 class PassResource {
  public:
   PassResource(std::unordered_map<std::string, MXTensor>* new_args,
                std::unordered_map<std::string, MXTensor>* new_aux,
                nd_malloc_t nd_malloc, const void* nd_alloc)
     : new_args_(new_args), new_aux_(new_aux), nd_malloc_(nd_malloc), nd_alloc_(nd_alloc) {}
+  // allocate new arg param, adds to args map, returns newly allocated tensor
   MXTensor* alloc_arg(const std::string& name, const std::vector<int64_t>& shapes,
                       const MXContext &ctx, MXDType dtype) const {
     void* data;
@@ -477,6 +485,7 @@ class PassResource {
     (*new_args_)[name] = tensor;
     return &(new_args_->at(name));
   }
+  // allocate new aux param, adds to aux map, returns newly allocated tensor
   MXTensor* alloc_aux(const std::string& name, const std::vector<int64_t>& shapes,
                       const MXContext &ctx, MXDType dtype) const {
     void* data;
@@ -557,10 +566,14 @@ class OpResource {
   void *rand_cpu_states, *rand_gpu_states;
 };
 
-/*! \brief Macro to help passing serialized subgraph through attribute dict */
+/*! \brief attribute key to help passing serialized subgraph through subgraph op attribute */
 #define MX_STR_SUBGRAPH_SYM_JSON "subgraph_sym_json"
+/*! \brief dtype attribute key for ops after type propagation */
 #define MX_STR_DTYPE "__ext_dtype__"
+/*! \brief shape attribute key for ops after shape propagation */
 #define MX_STR_SHAPE "__ext_shape__"
+/*! \brief extra input attribute key for ops */
+#define MX_STR_EXTRA_INPUTS "__ext_extra_inputs__"
 
 /* \brief get shape value from list of shapes string
  *
@@ -638,52 +651,50 @@ struct JsonVal {
     }
     return type < o.type;
   }
-  JsonType type;
-  int num;
-  std::string str;
-  std::vector<JsonVal> list;
-  std::map<JsonVal, JsonVal> map;
-};
 
-/*! \brief functions used for parsing JSON */
-struct JsonParser {
-  JsonVal parse_to_json(const std::string& json) {
-    unsigned int idx = 0;
-    return parse(json, &idx);
-  }
-  void print_json_val(const JsonVal& val) {
-    std::cout << json_val_string(val) << std::endl;
-  }
-  // debug function to dump data structure to string
-  std::string json_val_string(const JsonVal &val) {
+  // convert JSON object back to JSON-compatible string
+  std::string dump() const {
     std::string ret;
-    switch (val.type) {
+    switch (type) {
     case ERR:
       ret = "json(Error)";
       break;
     case STR:
-      ret = "json(STR:" + val.str + ")";
+      ret = "\"" + str + "\"";
       break;
     case NUM:
-      ret = "json(INT:" + val.str + ")";
+      ret = str;
       break;
     case LIST:
-      ret = "json(LIST:[";
-      for (auto &item : val.list)
-        ret += json_val_string(item) + ",";
-      ret += "])";
+      ret = "[";
+      for (unsigned i=0; i < list.size(); i++) {
+        auto &item = list[i];
+        ret += item.dump();
+        if (i < list.size()-1)
+          ret += ",";
+      }
+      ret += "]";
       break;
     case MAP:
-      ret = "json(MAP:{";
-      for (auto &item : val.map)
-        ret += json_val_string(item.first) + " : " + json_val_string(item.second) + ",";
-      ret += "})";
+      ret = "{";
+      unsigned cnt = 0;
+      for (auto &item : map) {
+        ret += item.first.dump() + " : " + item.second.dump();
+        if (cnt++ < map.size()-1)
+          ret += ",";
+      }
+      ret += "}";
       break;
     }
     return ret;
   }
+  // convert JSON-compatible string to JSON object
+  static JsonVal parse(const std::string& json) {
+    unsigned int idx = 0;
+    return JsonVal::parse(json, &idx);
+  }
   // parse a string JSON object
-  JsonVal parse_string(const std::string& json, unsigned int* idx) {
+  static JsonVal parse_string(const std::string& json, unsigned int* idx) {
     JsonVal ret(STR);
     while (*idx < json.size()) {
       if (json[*idx] == '"') {
@@ -698,7 +709,7 @@ struct JsonParser {
     return JsonVal();
   }
   // parse a number JSON object
-  JsonVal parse_num(const std::string& json, unsigned int* idx) {
+  static JsonVal parse_num(const std::string& json, unsigned int* idx) {
     JsonVal ret(NUM);
     while (*idx < json.size()) {
       if (json[*idx] >= '0' && json[*idx] <= '9') {
@@ -712,14 +723,14 @@ struct JsonParser {
     return ret;
   }
   // parse a list of JSON objects
-  JsonVal parse_list(const std::string& json, unsigned int* idx) {
+  static JsonVal parse_list(const std::string& json, unsigned int* idx) {
     JsonVal ret(LIST);
     while (*idx < json.size()) {
       if (json[*idx] == ']') {
         ++(*idx);
         return ret;
       } else {
-        JsonVal item = parse(json, idx);
+        JsonVal item = JsonVal::parse(json, idx);
         if (item.type != ERR)
           ret.list.push_back(item);
       }
@@ -728,14 +739,14 @@ struct JsonParser {
     return JsonVal();
   }
   // parse a map of JSON objects
-  JsonVal parse_map(const std::string& json, unsigned int* idx) {
+  static JsonVal parse_map(const std::string& json, unsigned int* idx) {
     JsonVal ret(MAP), key;
     while (*idx < json.size()) {
       if (json[*idx] == '}') {
         ++(*idx);
         return ret;
       } else {
-        JsonVal item = parse(json, idx);
+        JsonVal item = JsonVal::parse(json, idx);
         if (key.type == ERR) {
           key = item;
         } else {
@@ -748,62 +759,409 @@ struct JsonParser {
     return JsonVal();
   }
   // generic parse function
-  JsonVal parse(const std::string& json, unsigned int *idx) {
+  static JsonVal parse(const std::string& json, unsigned int *idx) {
     JsonVal ret;
     while (*idx < json.size()) {
       if (json[*idx] == '"') {
         ++(*idx);
-        ret = parse_string(json, idx);
+        ret = JsonVal::parse_string(json, idx);
       } else if (json[*idx] >= '0' && json[*idx] <= '9') {
-        ret = parse_num(json, idx);
+        ret = JsonVal::parse_num(json, idx);
       } else if (json[*idx] == '[') {
         ++(*idx);
-        ret = parse_list(json, idx);
+        ret = JsonVal::parse_list(json, idx);
       } else if (json[*idx] == '{') {
         ++(*idx);
-        ret = parse_map(json, idx);
+        ret = JsonVal::parse_map(json, idx);
       } else if (json[*idx] == ']' || json[*idx] == '}') {return ret;}
       if (ret.type != ERR) return ret;
       ++(*idx);
     }
     return ret;
   }
-  // convert JSON object back to JSON-compatible string
-  std::string dump(const JsonVal &val) {
+  // debug function to convert the data structure to a human-readable debug string
+  std::string toString() const {
     std::string ret;
-    switch (val.type) {
+    switch (type) {
     case ERR:
       ret = "json(Error)";
       break;
     case STR:
-      ret = "\"" + val.str + "\"";
+      ret = "json(STR:" + str + ")";
       break;
     case NUM:
-      ret = val.str;
+      ret = "json(INT:" + str + ")";
       break;
     case LIST:
-      ret = "[";
-      for (unsigned i=0; i < val.list.size(); i++) {
-        auto &item = val.list[i];
-        ret += dump(item);
-        if (i < val.list.size()-1)
-          ret += ",";
-      }
-      ret += "]";
+      ret = "json(LIST:[";
+      for (auto &item : list)
+        ret += item.toString() + ",";
+      ret += "])";
       break;
     case MAP:
-      ret = "{";
-      unsigned cnt = 0;
-      for (auto &item : val.map) {
-        ret += dump(item.first) + " : " + dump(item.second);
-        if (cnt++ < val.map.size()-1)
-          ret += ",";
-      }
-      ret += "}";
+      ret = "json(MAP:{";
+      for (auto &item : map)
+        ret += item.first.toString() + " : " + item.second.toString() + ",";
+      ret += "})";
       break;
     }
     return ret;
   }
+  JsonType type;
+  int num;
+  std::string str;
+  std::vector<JsonVal> list;
+  std::map<JsonVal, JsonVal> map;
+};
+
+/*!
+ * \brief Graph utility to parse serialized subgraph symbol
+ */
+class Node;
+class Graph;
+
+// Representation of an input/output to a node
+struct NodeEntry {
+  Node* node;  // the other node that produces/consumes this entry
+  int entry;  // entry index into the other node (i.e. output index of the producing node)
+};
+
+// Representation of a node in the graph
+class Node {
+ public:
+  Node() {tensor = nullptr; res = nullptr;}
+  // internally set passResource to enable tensor allocation for graph passes
+  void _setPassResource(PassResource* res_) {res = res_;}
+  /* \brief allocate an arg tensor for this node */
+  void alloc_arg(const std::vector<int64_t>& shapes,
+                 const MXContext &ctx, MXDType dtype) {
+    if (!res)
+      throw std::runtime_error(
+                 "Node not initialized. Cannot use alloc_arg outside of graph passes.");
+    tensor = res->alloc_arg(name, shapes, ctx, dtype);
+  }
+  /* \brief allocate an aux tensor for this node */
+  void alloc_aux(const std::vector<int64_t>& shapes,
+                 const MXContext &ctx, MXDType dtype) {
+    if (!res)
+      throw std::runtime_error(
+                 "Node not initialized. Cannot use alloc_aux outside of graph passes.");
+    tensor = res->alloc_aux(name, shapes, ctx, dtype);
+  }
+  std::string op;  // operator name (e.g. Convolution)
+  std::string name;  // unique node name (e.g. conv_0 or conv_1)
+  MXTensor* tensor;  // tensor data for input nodes
+  std::vector<NodeEntry> inputs;  // set of inputs to the node
+  std::vector<NodeEntry> outputs;  // set of outputs from the node
+  std::vector<Graph*> subgraphs;  // set of subgraphs within this node
+  std::unordered_map<std::string, std::string> attrs;  // node attributes
+
+ private:
+  PassResource* res;
+};
+
+// Representation of the graph
+class Graph {
+ public:
+  Graph() : res(nullptr) {}
+  /* \brief delete all nodes when deleting the graph */
+  ~Graph() {
+    for (size_t i = 0; i < nodes.size(); i++)
+      delete nodes[i];
+  }
+
+  /* \brief create a graph object from an unparsed string */
+  static Graph* fromString(const std::string& json) {
+    JsonVal val = JsonVal::parse(json);
+    return fromJson(val);
+  }
+
+  /* \brief create a graph object from a parsed JSON object */
+  static Graph* fromJson(JsonVal val) {
+    // get nodes list
+    JsonVal nodes = val.map[JsonVal("nodes")];
+    Graph *g = new Graph();
+
+    std::map<int, Node*> nodeMap;
+    // loop over nodes
+    for (size_t i = 0; i < nodes.list.size(); i++) {
+      Node* n = new Node();
+      g->nodes.push_back(n);
+      JsonVal node = nodes.list[i];
+
+      // set the op info
+      n->op = node.map[JsonVal("op")].str;
+      n->name = node.map[JsonVal("name")].str;
+
+      // if op is null it is an input to the graph
+      if (n->op.compare("null") == 0)
+        g->inputs.push_back(n);
+
+      // set attrs
+      JsonVal attributes = node.map[JsonVal("attrs")];
+      for (auto& kv : attributes.map) {
+        n->attrs[kv.first.str] = kv.second.str;
+      }
+
+      // set subgraphs, parsing each into a graph
+      if (node.map.count(JsonVal("subgraphs")) > 0) {
+        JsonVal subgraphs = node.map[JsonVal("subgraphs")];
+        for (auto &subgraph : subgraphs.list) {
+          n->subgraphs.push_back(fromJson(subgraph));
+        }
+      }
+
+      // set node inputs
+      JsonVal node_inputs = node.map[JsonVal("inputs")];
+      n->inputs.resize(node_inputs.list.size());
+      for (size_t j = 0; j < node_inputs.list.size(); j++) {
+        JsonVal input = node_inputs.list[j];
+        NodeEntry& entry = n->inputs[j];
+        // get pointer to other node
+        entry.node = nodeMap[input.list[0].num];
+        // get the other node's output index
+        entry.entry = input.list[1].num;
+        // set other nodes output as connected to this node
+        entry.node->outputs.push_back({n, static_cast<int>(j)});
+      }
+      nodeMap[i] = n;
+    }
+
+    // set graph level outputs
+    JsonVal& heads = val.map[JsonVal("heads")];
+    g->outputs.resize(heads.list.size());
+    for (size_t i = 0; i < heads.list.size(); i++) {
+      JsonVal head = heads.list[i];
+      g->outputs[i].node = nodeMap[head.list[0].num];
+      g->outputs[i].entry = head.list[1].num;
+    }
+
+    // add all attributes to the graph
+    for (auto& kv : val.map) {
+      if (kv.first.str.compare("nodes") != 0 &&
+         kv.first.str.compare("heads") != 0 &&
+         kv.first.str.compare("node_row_ptr") != 0 &&
+         kv.first.str.compare("arg_nodes") != 0) {
+        g->attrs[kv.first.str] = kv.second;
+      }
+    }
+    return g;
+  }
+
+  /* \brief convert graph object back to JSON object */
+  JsonVal toJson() {
+    // top level object is a map
+    JsonVal val(MAP);
+
+    // add attributes
+    for (auto& kv : attrs) {
+      val.map[JsonVal(kv.first)] = kv.second;
+    }
+
+    // sort graph nodes in topological order, create mapping of node to index
+    std::map<Node*, int> nodeMap;
+    std::vector<Node*> sorted = topological_sort();
+    // nodes are in reverse topological order in the vector (back is first)
+    // so loop from end to front over the vector 'sorted'
+    for (int i = sorted.size()-1; i >= 0; i--) {
+      nodeMap[sorted[i]] = sorted.size()-1-i;
+    }
+
+    // create node_row_ptr entry
+    val.map[JsonVal("node_row_ptr")] = JsonVal(LIST);
+    JsonVal& node_row_ptr = val.map[JsonVal("node_row_ptr")];
+    for (size_t i = 0; i < nodes.size(); i++)
+      node_row_ptr.list.push_back(JsonVal(i));
+
+    // add all input nodes
+    val.map[JsonVal("arg_nodes")] = JsonVal(LIST);
+    JsonVal& arg_nodes = val.map[JsonVal("arg_nodes")];
+    for (size_t i = 0; i < inputs.size(); i++)
+      arg_nodes.list.push_back(JsonVal(nodeMap[inputs[i]]));
+
+    // add all output nodes
+    val.map[JsonVal("heads")] = JsonVal(LIST);
+    JsonVal& heads = val.map[JsonVal("heads")];
+    for (size_t i = 0; i < outputs.size(); i++) {
+      heads.list.push_back(JsonVal(LIST));
+      JsonVal& out = heads.list[i];
+      out.list.push_back(JsonVal(nodeMap[outputs[i].node]));
+      out.list.push_back(JsonVal(outputs[i].entry));
+      out.list.push_back(JsonVal(0));
+    }
+
+    // add all graph nodes
+    val.map[JsonVal("nodes")] = JsonVal(LIST);
+    JsonVal& nodes_ = val.map[JsonVal("nodes")];
+    for (int i = sorted.size()-1; i >= 0; i--) {
+      // each node is a map
+      nodes_.list.push_back(JsonVal(MAP));
+      Node* n = sorted[i];
+      JsonVal& n_ = nodes_.list[nodes_.list.size()-1];
+
+      n_.map[JsonVal("op")] = JsonVal(n->op);
+      n_.map[JsonVal("name")] = JsonVal(n->name);
+      n_.map[JsonVal("inputs")] = JsonVal(LIST);
+
+      // add inputs for this node
+      JsonVal& inputs_ = n_.map[JsonVal("inputs")];
+      for (size_t j = 0; j < n->inputs.size(); j++) {
+        inputs_.list.push_back(JsonVal(LIST));
+        NodeEntry& entry = n->inputs[j];
+        JsonVal& in = inputs_.list[j];
+        in.list.push_back(JsonVal(nodeMap[entry.node]));
+        in.list.push_back(JsonVal(entry.entry));
+        in.list.push_back(JsonVal(0));
+      }
+
+      // add subgraphs for this node, convert each back to JSON
+      if (n->subgraphs.size() > 0) {
+        n_.map[JsonVal("subgraphs")] = JsonVal(LIST);
+        JsonVal &subgraphs_ = n_.map[JsonVal("subgraphs")];
+        for (Graph *subgraph : n->subgraphs) {
+          subgraphs_.list.push_back(subgraph->toJson());
+        }
+      }
+
+      // add attributes for this node
+      n_.map[JsonVal("attrs")] = JsonVal(MAP);
+      JsonVal& attrs_ = n_.map[JsonVal("attrs")];
+      for (auto& kv : n->attrs) {
+        attrs_.map[JsonVal(kv.first)] = JsonVal(kv.second);
+      }
+    }
+    return val;
+  }
+
+  /* \brief convert graph object to JSON string */
+  std::string toString() {
+    return toJson().dump();
+  }
+
+  /* \brief visits a node "n" */
+  void _dfs_util(Node* n, std::unordered_set<Node*>* to_visit,
+                 std::function<void(Node*)> handler) const {
+    to_visit->erase(n);  // remove node now that we're visiting it
+    for (NodeEntry& e : n->outputs) {
+      Node* o = e.node;
+      if (to_visit->count(o) != 0) {
+        _dfs_util(o, to_visit, handler);  // visit neighbor
+      }
+    }
+    handler(n);  // post-order visit this node
+  }
+
+  /* \brief post-order DFS graph traversal */
+  void DFS(std::function<void(Node*)> handler) const {
+    std::unordered_set<Node*> to_visit;
+    // put all nodes in set to visit
+    for (auto& n : nodes)
+      to_visit.insert(n);
+    // visit all inputs first
+    for (auto& i : inputs)
+      if (to_visit.count(i) != 0)
+        _dfs_util(i, &to_visit, handler);
+    // visit any nodes left
+    while (to_visit.size() > 0)
+      _dfs_util(*(to_visit.begin()), &to_visit, handler);
+  }
+
+  /* \brief sort graph nodes in topological order */
+  std::vector<Node*> topological_sort() const {
+    std::vector<Node*> sorted;
+    auto handler = [&](Node* n) {
+      sorted.push_back(n);  // when visiting each node, add it in order to the vector
+    };
+    DFS(handler);
+    return sorted;
+  }
+
+  /* \brief print out graph details */
+  void print(int indent = 0) const {
+    std::string space = "";
+    for (int i = 0; i < indent; i++) space+=" ";
+
+    std::cout << space << "########### Graph #############" << std::endl;
+    std::cout << space << "attributes: " << std::endl;
+    for (auto &kv : attrs)
+      std::cout << space << "\t" << kv.first << " : " << kv.second.str << std::endl;
+    std::cout << space << "inputs: " << inputs.size() << std::endl;
+    std::cout << space << "outputs: " << outputs.size() << std::endl;
+    std::cout << space << "nodes: " << nodes.size() << std::endl;
+    std::vector<Node*> sorted = topological_sort();
+    // loop over each node and print out its inputs/outputs
+    for (int i = static_cast<int>(sorted.size()-1); i >= 0; i--) {
+      std::cout << space << "Node: " << sorted[i]->name << std::endl;
+      for (size_t j = 0; j < sorted[i]->inputs.size(); j++) {
+        std::cout << space << "\tInput: " << sorted[i]->inputs[j].node->name << " "
+                  << sorted[i]->inputs[j].entry << std::endl;
+      }
+      for (size_t j = 0; j < sorted[i]->outputs.size(); j++) {
+        std::cout << space << "\tOutput: " << sorted[i]->outputs[j].node->name << " "
+                  << sorted[i]->outputs[j].entry << std::endl;
+      }
+      if (sorted[i]->subgraphs.size() > 0) {
+        for (auto &subgraph : sorted[i]->subgraphs) {
+          std::cout << space << "\tSubgraph:" << std::endl;
+          subgraph->print(indent+2);
+        }
+      }
+    }
+    std::cout << space << "###############################" << std::endl;
+  }
+
+  /* \brief add a new node to this graph */
+  Node* addNode(const std::string& name, const std::string& op) {
+    Node* n = new Node();
+    nodes.push_back(n);  // track the node so it is owned, traversed, and serialized with the graph
+    n->name = name;
+    n->op = op;
+    if (res)
+      n->_setPassResource(res);
+    return n;
+  }
+  /* \brief get node at index in graph */
+  Node* getNode(size_t idx) {
+    return nodes[idx];
+  }
+  /* \brief get const node at index in const graph */
+  const Node* getNode(size_t idx) const {
+    return nodes.at(idx);
+  }
+  /* \brief get attribute on graph */
+  const JsonVal& getAttr(const std::string& key) const {
+    return attrs.at(key);
+  }
+  /* \brief get number of nodes in the graph */
+  size_t size() const {
+    return nodes.size();
+  }
+  // internally set passResource to enable tensor allocation for graph passes
+  void _setPassResource(PassResource* res_) {
+    res = res_;
+    // set passResource for each node
+    for (Node* node : nodes) {
+      node->_setPassResource(res);
+    }
+  }
+  // internally set arg/aux params when available
+  void _setParams(std::unordered_map<std::string, mxnet::ext::MXTensor>* args,
+                  std::unordered_map<std::string, mxnet::ext::MXTensor>* aux) {
+    // set params for each input node
+    for (Node* node : inputs) {
+      if (args->count(node->name) > 0)
+        node->tensor = &args->at(node->name);
+      else if (aux->count(node->name) > 0)
+        node->tensor = &aux->at(node->name);
+    }
+  }
+
+  std::vector<Node*> inputs;
+  std::vector<NodeEntry> outputs;
+  std::map<std::string, JsonVal> attrs;
+
+ private:
+  std::vector<Node*> nodes;
+  PassResource* res;
 };
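
For orientation, the Graph/Node utilities above replace the hand-rolled JsonParser code
that extensions previously carried: parse once, walk typed nodes, and serialize back. A
minimal sketch (inspectGraph is hypothetical and not part of this header; it assumes
lib_api.h is included):

// Illustrative sketch: round-trip a serialized symbol through the Graph utility.
void inspectGraph(const std::string& symbol_json) {
  mxnet::ext::Graph* g = mxnet::ext::Graph::fromString(symbol_json);
  // walk the nodes and print basic info for each one
  for (size_t i = 0; i < g->size(); i++) {
    const mxnet::ext::Node* n = g->getNode(i);
    std::cout << n->name << " (" << n->op << ") with "
              << n->inputs.size() << " inputs" << std::endl;
  }
  // dump inputs/outputs/subgraphs for debugging
  g->print();
  // serialize back to a JSON string that MXNet can consume
  std::string round_trip = g->toString();
  delete g;
}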
 
 /* \brief An abstract class for library authors creating custom
@@ -993,11 +1351,8 @@ class CustomOp {
 };
 
 /*! \brief Custom Pass Create function template */
-typedef MXReturnValue (*graphPass_t)(const std::string& in_graph, const std::string** out_graph,
-                                     const std::unordered_map<std::string, std::string>& options,
-                                     const std::unordered_map<std::string, MXTensor>& args,
-                                     const std::unordered_map<std::string, MXTensor>& aux,
-                                     const PassResource& res);
+typedef MXReturnValue (*graphPass_t)(mxnet::ext::Graph* graph,
+                                     const std::unordered_map<std::string, std::string>& options);
 
 /*!
  * \brief An abstract class for graph passes
@@ -1019,18 +1374,17 @@ class CustomPass {
 };
 
 /*! \brief Custom Subgraph Create function template */
-typedef MXReturnValue (*supportedOps_t)(const std::string& json, std::vector<int>* ids,
+typedef MXReturnValue (*supportedOps_t)(const mxnet::ext::Graph *graph, std::vector<int>* ids,
                                         const std::unordered_map<std::string,
                                                                  std::string>& options);
-typedef MXReturnValue (*createSelector_t)(const std::string& json, CustomOpSelector** sel_inst,
+typedef MXReturnValue (*createSelector_t)(const mxnet::ext::Graph *graph,
+                                          CustomOpSelector** sel_inst,
                                           const std::unordered_map<std::string,
                                                                    std::string>& options);
-typedef MXReturnValue (*reviewSubgraph_t)(const std::string& json, int subgraph_id, bool* accept,
+typedef MXReturnValue (*reviewSubgraph_t)(const mxnet::ext::Graph *subgraph, int subgraph_id,
+                                          bool* accept,
                                           const std::unordered_map<std::string,
-                                                                   std::string>& options,
-                                          std::unordered_map<std::string, std::string>* attrs,
-                                          const std::unordered_map<std::string, MXTensor>& args,
-                                          const std::unordered_map<std::string, MXTensor>& aux);
+                                                                   std::string>& options);
 
 /*!
  * \brief An abstract class for subgraph property
@@ -1165,6 +1519,47 @@ class Registry {
   MX_STR_CONCAT(MX_REGISTER_PASS_DEF_(Name), __COUNTER__) = \
     Registry<CustomPass>::get()->add(MX_TOSTRING(Name))
 
+/* \brief Class to store error messages from extensions to pass to MXNet */
+class MXerrorMsgs {
+ public:
+  /*!
+   * \brief get singleton pointer to class
+   * \returns pointer to class
+   */
+  static MXerrorMsgs* get() {
+    static MXerrorMsgs inst;
+    return &inst;
+  }
+  /*!
+   * \brief add a new error message
+   */
+  std::stringstream& add(const char* file, int line) {
+    messages.push_back(new std::stringstream());
+    *(messages.back()) << file << "[" << line << "]: ";
+    return *(messages.back());
+  }
+  /*!
+   * \brief return the number of error messages collected so far
+   */
+  int size() {
+    return messages.size();
+  }
+  /*!
+   * \brief get a copy of the error message at the given index
+   */
+  const std::string* get(int idx) {
+    return new std::string(messages.at(idx)->str());
+  }
+
+ private:
+  /*! \brief constructor */
+  MXerrorMsgs() {}
+  /*! \brief destructor */
+  ~MXerrorMsgs() {
+    for (auto &msg : messages)
+      delete msg;
+  }
+  /*! \brief error messages collected so far from the library */
+  std::vector<std::stringstream*> messages;
+};
+
+// Add a new error message, example: MX_ERROR_MSG << "my error msg";
+#define MX_ERROR_MSG MXerrorMsgs::get()->add(__FILE__, __LINE__)
+
 /* -------------- BELOW ARE CTYPE FUNCTIONS PROTOTYPES --------------- */
 
 /*!
@@ -1177,12 +1572,13 @@ typedef int (*opRegSize_t)(void);
 
 #define MXLIB_OPREGGET_STR "_opRegGet"
 typedef int (*opRegGet_t)(int idx, const char** name, int *isSGop,
-                          const char*** forward_ctx, fcomp_t** forward_fp, int* forward_count,
-                          const char*** backward_ctx, fcomp_t** backward_fp, int* backward_count,
-                          const char*** create_op_ctx, createOpState_t** create_op_fp,
-                          int* create_op_count,
-                          parseAttrs_t* parse, inferType_t* type, inferSType_t* stype,
-                          inferShape_t* shape, mutateInputs_t* mutate);
+                          const char*** forward_ctx, mxnet::ext::fcomp_t** forward_fp,
+                          int* forward_count, const char*** backward_ctx,
+                          mxnet::ext::fcomp_t** backward_fp, int* backward_count,
+                          const char*** create_op_ctx, mxnet::ext::createOpState_t** create_op_fp,
+                          int* create_op_count, mxnet::ext::parseAttrs_t* parse,
+                          mxnet::ext::inferType_t* type, mxnet::ext::inferSType_t* stype,
+                          mxnet::ext::inferShape_t* shape, mxnet::ext::mutateInputs_t* mutate);
 
 #define MXLIB_OPCALLFREE_STR "_opCallFree"
 typedef int (*opCallFree_t)(void* ptr);
@@ -1343,6 +1739,12 @@ typedef int (*initialize_t)(int version);
 #define MXLIB_OPVERSION_STR "_opVersion"
 typedef int (*opVersion_t)();
 
+#define MXLIB_MSGSIZE_STR "_msgSize"
+typedef int (*msgSize_t)(void);
+
+#define MXLIB_MSGGET_STR "_msgGet"
+typedef int (*msgGet_t)(int idx, const char** msg);
+
 #if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__)
 #define MX_INT_RET  __declspec(dllexport) int __cdecl
 #define MX_VOID_RET __declspec(dllexport) void __cdecl
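
The _msgSize/_msgGet entry points declared above are how MXNet retrieves whatever an
extension recorded with MX_ERROR_MSG after a failed call. From the library side the
pattern stays simple (illustrative sketch; myCheckedPass is a hypothetical pass, not part
of this commit, and it assumes lib_api.h is included with using namespace mxnet::ext):

// Illustrative sketch: report a failure from a graph pass so MXNet can surface it.
MXReturnValue myCheckedPass(mxnet::ext::Graph* graph,
                            const std::unordered_map<std::string, std::string>& options) {
  if (graph->size() == 0) {
    MX_ERROR_MSG << "myCheckedPass: received an empty graph";
    return MX_FAIL;
  }
  return MX_SUCCESS;
}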
@@ -1351,6 +1753,9 @@ typedef int (*opVersion_t)();
 #define MX_VOID_RET void
 #endif
 
+}  // namespace ext
+}  // namespace mxnet
+
 extern "C" {
   /*! \brief returns MXNet library version */
   MX_INT_RET _opVersion() {
@@ -1359,18 +1764,19 @@ extern "C" {
 
   /*! \brief returns number of ops registered in this library */
   MX_INT_RET _opRegSize() {
-    return Registry<CustomOp>::get()->size();
+    return mxnet::ext::Registry<mxnet::ext::CustomOp>::get()->size();
   }
 
   /*! \brief returns operator registration at specified index */
   MX_VOID_RET _opRegGet(int idx, const char** name, int *isSGop,
-                        const char*** forward_ctx, fcomp_t** forward_fp,
+                        const char*** forward_ctx, mxnet::ext::fcomp_t** forward_fp,
                         int* forward_count, const char*** backward_ctx,
-                        fcomp_t** backward_fp, int* backward_count,
-                        const char*** create_op_ctx, createOpState_t** create_op_fp,
-                        int* create_op_count, parseAttrs_t* parse, inferType_t* type,
-                        inferSType_t* stype, inferShape_t* shape, mutateInputs_t* mutate) {
-    CustomOp &op = Registry<CustomOp>::get()->get(idx);
+                        mxnet::ext::fcomp_t** backward_fp, int* backward_count,
+                        const char*** create_op_ctx, mxnet::ext::createOpState_t** create_op_fp,
+                        int* create_op_count, mxnet::ext::parseAttrs_t* parse,
+                        mxnet::ext::inferType_t* type, mxnet::ext::inferSType_t* stype,
+                        mxnet::ext::inferShape_t* shape, mxnet::ext::mutateInputs_t* mutate) {
+    mxnet::ext::CustomOp &op = mxnet::ext::Registry<mxnet::ext::CustomOp>::get()->get(idx);
     *name = op.name;
     *parse = op.parse_attrs;
     *type = op.infer_type;
@@ -1396,7 +1802,7 @@ extern "C" {
   }
 
   /*! \brief returns status of calling parse attributes function for operator from library */
-  MX_INT_RET _opCallParseAttrs(parseAttrs_t parseAttrs, const char* const* keys,
+  MX_INT_RET _opCallParseAttrs(mxnet::ext::parseAttrs_t parseAttrs, const char* const* keys,
                                const char* const* vals, int num,
                                int* num_in, int* num_out) {
     // create map of attributes from list
@@ -1409,7 +1815,7 @@ extern "C" {
   }
 
   /*! \brief returns status of calling inferShape function for operator from library */
-  MX_INT_RET _opCallInferShape(inferShape_t inferShape, const char* const* keys,
+  MX_INT_RET _opCallInferShape(mxnet::ext::inferShape_t inferShape, const char* const* keys,
                                const char* const* vals, int num,
                                unsigned int** inshapes, int* indims, int num_in,
                                unsigned int*** mod_inshapes, int** mod_indims,
@@ -1464,7 +1870,7 @@ extern "C" {
   }
 
   /*! \brief returns status of calling inferType function for operator from library */
-  MX_INT_RET _opCallInferType(inferType_t inferType, const char* const* keys,
+  MX_INT_RET _opCallInferType(mxnet::ext::inferType_t inferType, const char* const* keys,
                               const char* const* vals, int num,
                               int* intypes, int num_in, int* outtypes, int num_out) {
     // create map of attributes from list
@@ -1499,7 +1905,7 @@ extern "C" {
   }
 
   /*! \brief returns status of calling inferSType function for operator from library */
-  MX_INT_RET _opCallInferSType(inferSType_t inferSType, const char* const* keys,
+  MX_INT_RET _opCallInferSType(mxnet::ext::inferSType_t inferSType, const char* const* keys,
                                const char* const* vals, int num,
                                int* instypes, int num_in, int* outstypes, int num_out) {
     // create map of attributes from list
@@ -1535,14 +1941,17 @@ extern "C" {
   }
 
   /*! \brief returns status of calling Forward/Backward function for operator from library */
-  MX_INT_RET _opCallFCompute(fcomp_t fcomp, const char* const* keys, const char* const* vals,
+  MX_INT_RET _opCallFCompute(mxnet::ext::fcomp_t fcomp, const char* const* keys,
+                             const char* const* vals,
                              int num, const int64_t** inshapes, int* indims, void** indata,
                              int* intypes, size_t* inIDs, const char** indev_type, int* indev_id,
                              int num_in, const int64_t** outshapes, int* outdims, void** outdata,
                              int* outtypes, size_t* outIDs, const char** outdev_type,
-                             int* outdev_id, int num_out, xpu_malloc_t cpu_malloc, void* cpu_alloc,
-                             xpu_malloc_t gpu_malloc, void* gpu_alloc, void* cuda_stream,
-                             sparse_malloc_t sparse_malloc, void* sparse_alloc,
+                             int* outdev_id, int num_out, mxnet::ext::xpu_malloc_t cpu_malloc,
+                             void* cpu_alloc,
+                             mxnet::ext::xpu_malloc_t gpu_malloc, void* gpu_alloc,
+                             void* cuda_stream,
+                             mxnet::ext::sparse_malloc_t sparse_malloc, void* sparse_alloc,
                              int* instypes, int* outstypes, void** in_indices, void** out_indices,
                              void** in_indptr, void** out_indptr,
                              int64_t* in_indices_shapes, int64_t* out_indices_shapes,
@@ -1555,66 +1964,70 @@ extern "C" {
     }
 
     // create a vector of tensors for inputs
-    std::vector<MXTensor> inputs(num_in);
+    std::vector<mxnet::ext::MXTensor> inputs(num_in);
     // create a vector for sparse inputs
-    std::vector<MXSparse> in_sparse(num_in);
+    std::vector<mxnet::ext::MXSparse> in_sparse(num_in);
 
     for (int i = 0; i < num_in; i++) {
       // Dense representation.
       if (instypes[i] == 0) {
-        inputs[i].setTensor(indata[i], (MXDType)intypes[i], inshapes[i], indims[i],
-                            inIDs[i], MXContext(indev_type[i], indev_id[i]), kDefaultStorage);
+        inputs[i].setTensor(indata[i], (mxnet::ext::MXDType)intypes[i], inshapes[i], indims[i],
+                            inIDs[i], mxnet::ext::MXContext(indev_type[i], indev_id[i]),
+                            mxnet::ext::kDefaultStorage);
       } else {
         // Sparse representation.
-        MXStorageType type;
+        mxnet::ext::MXStorageType type;
         if (instypes[i] == 1) {
-          type = kRowSparseStorage;
+          type = mxnet::ext::kRowSparseStorage;
           in_sparse[i].set(indata[i], inshapes[i], indims[i], in_indices[i], in_indices_shapes[i]);
         } else {
-          type = kCSRStorage;
+          type = mxnet::ext::kCSRStorage;
           in_sparse[i].set(indata[i], inshapes[i], indims[i], in_indices[i],
                            in_indices_shapes[i], in_indptr[i], in_indptr_shapes[i]);
         }
-        inputs[i].setTensor(reinterpret_cast<void*>(&in_sparse[i]), (MXDType)intypes[i],
+        inputs[i].setTensor(reinterpret_cast<void*>(&in_sparse[i]), (mxnet::ext::MXDType)intypes[i],
                             inshapes[i], indims[i], inIDs[i],
-                            MXContext(indev_type[i], indev_id[i]), type);
+                            mxnet::ext::MXContext(indev_type[i], indev_id[i]), type);
       }
     }
 
     // create a vector of tensors for outputs
-    std::vector<MXTensor> outputs(num_out);
-    std::vector<MXSparse> out_sparse(num_out);
+    std::vector<mxnet::ext::MXTensor> outputs(num_out);
+    std::vector<mxnet::ext::MXSparse> out_sparse(num_out);
 
     for (int i = 0; i < num_out; i++) {
       // Dense representation.
       if (outstypes[i] == 0) {
-        outputs[i].setTensor(outdata[i], (MXDType)outtypes[i], outshapes[i], outdims[i],
-                             outIDs[i], MXContext(outdev_type[i], outdev_id[i]), kDefaultStorage);
+        outputs[i].setTensor(outdata[i], (mxnet::ext::MXDType)outtypes[i], outshapes[i], outdims[i],
+                             outIDs[i], mxnet::ext::MXContext(outdev_type[i], outdev_id[i]),
+                             mxnet::ext::kDefaultStorage);
       } else {
         // Sparse representation.
-        MXStorageType type;
+        mxnet::ext::MXStorageType type;
         if (outstypes[i] == 1) {
-          type = kRowSparseStorage;
+          type = mxnet::ext::kRowSparseStorage;
           out_sparse[i].set(outdata[i], outshapes[i], outdims[i],
                             out_indices[i], out_indices_shapes[i]);
         } else {
-          type = kCSRStorage;
+          type = mxnet::ext::kCSRStorage;
           out_sparse[i].set(outdata[i], outshapes[i], outdims[i], out_indices[i],
                             out_indices_shapes[i], out_indptr[i], out_indptr_shapes[i]);
         }
-        outputs[i].setTensor(reinterpret_cast<void*>(&out_sparse[i]), (MXDType)outtypes[i],
+        outputs[i].setTensor(reinterpret_cast<void*>(&out_sparse[i]),
+                             (mxnet::ext::MXDType)outtypes[i],
                              outshapes[i], outdims[i], outIDs[i],
-                             MXContext(outdev_type[i], outdev_id[i]), type);
+                             mxnet::ext::MXContext(outdev_type[i], outdev_id[i]), type);
       }
     }
 
-    OpResource res(cpu_malloc, cpu_alloc, gpu_malloc, gpu_alloc,
-                   cuda_stream, sparse_malloc, sparse_alloc, rng_cpu_states, rng_gpu_states);
+    mxnet::ext::OpResource res(cpu_malloc, cpu_alloc, gpu_malloc, gpu_alloc,
+                               cuda_stream, sparse_malloc, sparse_alloc,
+                               rng_cpu_states, rng_gpu_states);
     return fcomp(attrs, &inputs, &outputs, res);
   }
 
   /*! \brief returns status of calling mutateInputs function for operator from library */
-  MX_INT_RET _opCallMutateInputs(mutateInputs_t mutate, const char* const* keys,
+  MX_INT_RET _opCallMutateInputs(mxnet::ext::mutateInputs_t mutate, const char* const* keys,
                                  const char* const* vals, int num,
                                  int** mutate_indices, int* indices_size) {
     // create map of attributes from list
@@ -1641,7 +2054,7 @@ extern "C" {
   }
 
   /*! \brief returns status of calling createStatefulOp function for operator from library */
-  MX_INT_RET _opCallCreateOpState(createOpState_t create_op, const char* const* keys,
+  MX_INT_RET _opCallCreateOpState(mxnet::ext::createOpState_t create_op, const char* const* keys,
                                   const char* const* vals, int num,
                                   void** state_op) {
     // create map of attributes from list
@@ -1652,7 +2065,8 @@ extern "C" {
 
     // void pointer to hold custom state op instance created in custom library
     // eventually state_op pointer is populated by instance from custom library
-    CustomStatefulOp** op_ptr = reinterpret_cast<CustomStatefulOp**>(state_op);
+    mxnet::ext::CustomStatefulOp** op_ptr =
+      reinterpret_cast<mxnet::ext::CustomStatefulOp**>(state_op);
     return create_op(attrs, op_ptr);
   }
 
@@ -1662,9 +2076,11 @@ extern "C" {
                                      const char** indev_type, int* indev_id, int num_in,
                                      const int64_t** outshapes, int* outdims, void** outdata,
                                      int* outtypes, size_t* outIDs, const char** outdev_type,
-                                     int* outdev_id, int num_out, xpu_malloc_t cpu_malloc,
-                                     void* cpu_alloc, xpu_malloc_t gpu_malloc, void* gpu_alloc,
-                                     void* stream, sparse_malloc_t sparse_malloc,
+                                     int* outdev_id, int num_out,
+                                     mxnet::ext::xpu_malloc_t cpu_malloc,
+                                     void* cpu_alloc, mxnet::ext::xpu_malloc_t gpu_malloc,
+                                     void* gpu_alloc,
+                                     void* stream, mxnet::ext::sparse_malloc_t sparse_malloc,
                                      void* sparse_alloc, int* instypes, int* outstypes,
                                      void** in_indices, void** out_indices, void** in_indptr,
                                      void** out_indptr, int64_t* in_indices_shapes,
@@ -1672,64 +2088,68 @@ extern "C" {
                                      int64_t* out_indptr_shapes,
                                      void* rng_cpu_states, void* rng_gpu_states) {
     // create a vector of tensors for inputs
-    std::vector<MXTensor> inputs(num_in);
+    std::vector<mxnet::ext::MXTensor> inputs(num_in);
     // create a vector for sparse inputs
-    std::vector<MXSparse> in_sparse(num_in);
+    std::vector<mxnet::ext::MXSparse> in_sparse(num_in);
 
     for (int i = 0; i < num_in; i++) {
       if (instypes[i] == 0) {
         // Dense representation.
-        inputs[i].setTensor(indata[i], (MXDType)intypes[i], inshapes[i], indims[i],
-                            inIDs[i], MXContext(indev_type[i], indev_id[i]), kDefaultStorage);
+        inputs[i].setTensor(indata[i], (mxnet::ext::MXDType)intypes[i], inshapes[i], indims[i],
+                            inIDs[i], mxnet::ext::MXContext(indev_type[i], indev_id[i]),
+                            mxnet::ext::kDefaultStorage);
       } else {
         // Sparse representation.
-        MXStorageType type;
+        mxnet::ext::MXStorageType type;
         if (instypes[i] == 1) {
-          type = kRowSparseStorage;
+          type = mxnet::ext::kRowSparseStorage;
           in_sparse[i].set(indata[i], inshapes[i], indims[i], in_indices[i], in_indices_shapes[i]);
         } else {
-          type = kCSRStorage;
+          type = mxnet::ext::kCSRStorage;
           in_sparse[i].set(indata[i], inshapes[i], indims[i], in_indices[i],
                            in_indices_shapes[i], in_indptr[i], in_indptr_shapes[i]);
         }
-        inputs[i].setTensor(reinterpret_cast<void*>(&in_sparse[i]), (MXDType)intypes[i],
+        inputs[i].setTensor(reinterpret_cast<void*>(&in_sparse[i]), (mxnet::ext::MXDType)intypes[i],
                             inshapes[i], indims[i], inIDs[i],
-                            MXContext(indev_type[i], indev_id[i]), type);
+                            mxnet::ext::MXContext(indev_type[i], indev_id[i]), type);
       }
     }
 
     // create a vector of tensors for outputs
-    std::vector<MXTensor> outputs(num_out);
+    std::vector<mxnet::ext::MXTensor> outputs(num_out);
     // create a vector for sparse outputs
-    std::vector<MXSparse> out_sparse(num_out);
+    std::vector<mxnet::ext::MXSparse> out_sparse(num_out);
 
     for (int i = 0; i < num_out; i++) {
       if (outstypes[i] == 0) {
         // Dense representation.
-        outputs[i].setTensor(outdata[i], (MXDType)outtypes[i], outshapes[i], outdims[i],
-                             outIDs[i], MXContext(outdev_type[i], outdev_id[i]), kDefaultStorage);
+        outputs[i].setTensor(outdata[i], (mxnet::ext::MXDType)outtypes[i], outshapes[i], outdims[i],
+                             outIDs[i], mxnet::ext::MXContext(outdev_type[i], outdev_id[i]),
+                             mxnet::ext::kDefaultStorage);
       } else {
         // Sparse representation.
-        MXStorageType type;
+        mxnet::ext::MXStorageType type;
         if (outstypes[i] == 1) {
-          type = kRowSparseStorage;
+          type = mxnet::ext::kRowSparseStorage;
           out_sparse[i].set(outdata[i], outshapes[i], outdims[i], out_indices[i],
                             out_indices_shapes[i]);
         } else {
-          type = kCSRStorage;
+          type = mxnet::ext::kCSRStorage;
           out_sparse[i].set(outdata[i], outshapes[i], outdims[i], out_indices[i],
                             out_indices_shapes[i], out_indptr[i], out_indptr_shapes[i]);
         }
-        outputs[i].setTensor(reinterpret_cast<void*>(&out_sparse[i]), (MXDType)outtypes[i],
+        outputs[i].setTensor(reinterpret_cast<void*>(&out_sparse[i]),
+                             (mxnet::ext::MXDType)outtypes[i],
                              outshapes[i], outdims[i], outIDs[i],
-                             MXContext(outdev_type[i], outdev_id[i]), type);
+                             mxnet::ext::MXContext(outdev_type[i], outdev_id[i]), type);
       }
     }
 
-    OpResource res(cpu_malloc, cpu_alloc, gpu_malloc, gpu_alloc,
-                   stream, sparse_malloc, sparse_alloc, rng_cpu_states, rng_gpu_states);
+    mxnet::ext::OpResource res(cpu_malloc, cpu_alloc, gpu_malloc, gpu_alloc,
+                               stream, sparse_malloc, sparse_alloc, rng_cpu_states, rng_gpu_states);
 
-    CustomStatefulOp* op_ptr = reinterpret_cast<CustomStatefulOp*>(state_op);
+    mxnet::ext::CustomStatefulOp* op_ptr =
+      reinterpret_cast<mxnet::ext::CustomStatefulOp*>(state_op);
     if (is_forward) {
       return op_ptr->Forward(&inputs, &outputs, res);
     }
@@ -1738,22 +2158,25 @@ extern "C" {
 
   /*! \brief returns number of partitioners registered in this library */
   MX_INT_RET _partRegSize() {
-    return Registry<CustomPartitioner>::get()->size();
+    return mxnet::ext::Registry<mxnet::ext::CustomPartitioner>::get()->size();
   }
 
   /* returns number of strategies registered for partitioner
    * at specified index */
   MX_INT_RET _partRegGetCount(int idx, const char** name) {
-    CustomPartitioner part = Registry<CustomPartitioner>::get()->get(idx);
+    mxnet::ext::CustomPartitioner part =
+      mxnet::ext::Registry<mxnet::ext::CustomPartitioner>::get()->get(idx);
     *name = part.name;
     return part.strategies.size();
   }
 
   /*! \brief returns partitioner registration at specified index */
   MX_VOID_RET _partRegGet(int part_idx, int stg_idx, const char** strategy,
-                        supportedOps_t* supportedOps, createSelector_t* createSelector,
-                        reviewSubgraph_t* reviewSubgraph, const char** op_name) {
-    CustomPartitioner part = Registry<CustomPartitioner>::get()->get(part_idx);
+                          mxnet::ext::supportedOps_t* supportedOps,
+                          mxnet::ext::createSelector_t* createSelector,
+                          mxnet::ext::reviewSubgraph_t* reviewSubgraph, const char** op_name) {
+    mxnet::ext::CustomPartitioner part =
+      mxnet::ext::Registry<mxnet::ext::CustomPartitioner>::get()->get(part_idx);
     *strategy = part.strategies[stg_idx];
     *op_name = part.op_names[stg_idx];
     *supportedOps = part.getSupportedOps(stg_idx);
@@ -1762,10 +2185,10 @@ extern "C" {
   }
 
   /*! \brief returns status of calling supported ops function from library */
-  MX_INT_RET _partCallSupportedOps(supportedOps_t supportedOps, const char *json,
+  MX_INT_RET _partCallSupportedOps(mxnet::ext::supportedOps_t supportedOps, const char *json,
                                    int num_ids, int *ids, const char* const* opt_keys,
                                    const char* const* opt_vals, int num_opts) {
-    std::string subgraph_json(json);
+    mxnet::ext::Graph *graph = mxnet::ext::Graph::fromString(json);
     // create map of options from list
     std::unordered_map<std::string, std::string> opts;
     for (int i = 0; i < num_opts; i++)
@@ -1774,7 +2197,7 @@ extern "C" {
     // create array of subgraph IDs for operator support
     std::vector<int> _ids(num_ids, -2);
     // call user's supportedOps function
-    MXReturnValue retval = supportedOps(subgraph_json, &_ids, opts);
+    mxnet::ext::MXReturnValue retval = supportedOps(graph, &_ids, opts);
     if (!retval) return retval;
 
     // copy bools in ids to ints
@@ -1785,10 +2208,10 @@ extern "C" {
   }
 
   /*! \brief returns status of calling create selector function from library */
-  MX_INT_RET _partCallCreateSelector(createSelector_t createSelector, const char *json,
+  MX_INT_RET _partCallCreateSelector(mxnet::ext::createSelector_t createSelector, const char *json,
                                      void** selector, const char* const* opt_keys,
                                      const char* const* opt_vals, int num_opts) {
-    std::string symbol_json(json);
+    mxnet::ext::Graph *graph = mxnet::ext::Graph::fromString(json);
     // create map of options from list
     std::unordered_map<std::string, std::string> opts;
     for (int i = 0; i < num_opts; i++)
@@ -1796,36 +2219,41 @@ extern "C" {
 
     // void pointer to hold selector instance created in custom library
     // eventually pointer is populated by instance from custom library
-    CustomOpSelector** sel_ptr = reinterpret_cast<CustomOpSelector**>(selector);
+    mxnet::ext::CustomOpSelector** sel_ptr =
+      reinterpret_cast<mxnet::ext::CustomOpSelector**>(selector);
 
     // call user's createSelector function
-    return createSelector(symbol_json, sel_ptr, opts);
+    return createSelector(graph, sel_ptr, opts);
   }
 
   /*! \brief returns status of calling select function from library */
   MX_VOID_RET _partCallSelect(void* sel_inst, int nodeID, int* selected) {
-    CustomOpSelector* sel_ptr = reinterpret_cast<CustomOpSelector*>(sel_inst);
+    mxnet::ext::CustomOpSelector* sel_ptr =
+      reinterpret_cast<mxnet::ext::CustomOpSelector*>(sel_inst);
     *selected = sel_ptr->Select(nodeID);
   }
 
   /*! \brief returns status of calling select input function from library */
   MX_VOID_RET _partCallSelectInput(void* sel_inst, int nodeID,
                                   int input_nodeID, int* selected) {
-    CustomOpSelector* sel_ptr = reinterpret_cast<CustomOpSelector*>(sel_inst);
+    mxnet::ext::CustomOpSelector* sel_ptr =
+      reinterpret_cast<mxnet::ext::CustomOpSelector*>(sel_inst);
     *selected = sel_ptr->SelectInput(nodeID, input_nodeID);
   }
 
   /*! \brief returns status of calling select output function from library */
   MX_VOID_RET _partCallSelectOutput(void* sel_inst, int nodeID,
                                     int output_nodeID, int* selected) {
-    CustomOpSelector* sel_ptr = reinterpret_cast<CustomOpSelector*>(sel_inst);
+    mxnet::ext::CustomOpSelector* sel_ptr =
+      reinterpret_cast<mxnet::ext::CustomOpSelector*>(sel_inst);
     *selected = sel_ptr->SelectOutput(nodeID, output_nodeID);
   }
 
   /*! \brief returns status of calling filter function from library */
   MX_VOID_RET _partCallFilter(void* sel_inst, int* candidates, int num_candidates,
                               int** keep, int* num_keep) {
-    CustomOpSelector* sel_ptr = reinterpret_cast<CustomOpSelector*>(sel_inst);
+    mxnet::ext::CustomOpSelector* sel_ptr =
+      reinterpret_cast<mxnet::ext::CustomOpSelector*>(sel_inst);
     std::vector<int> candidates_(num_candidates);
     for (int i=0; i < num_candidates; i++) {
       candidates_[i] = candidates[i];
@@ -1842,12 +2270,13 @@ extern "C" {
 
   /*! \brief returns status of calling reset selector function from library */
   MX_VOID_RET _partCallReset(void* sel_inst) {
-    CustomOpSelector* sel_ptr = reinterpret_cast<CustomOpSelector*>(sel_inst);
+    mxnet::ext::CustomOpSelector* sel_ptr =
+      reinterpret_cast<mxnet::ext::CustomOpSelector*>(sel_inst);
     sel_ptr->Reset();
   }
 
   /*! \brief returns status of calling review subgraph function from library */
-  MX_INT_RET _partCallReviewSubgraph(reviewSubgraph_t reviewSubgraph, const char *json,
+  MX_INT_RET _partCallReviewSubgraph(mxnet::ext::reviewSubgraph_t reviewSubgraph, const char *json,
                                      int subgraph_id, int *accept, const char* const* opt_keys,
                                      const char* const* opt_vals, int num_opts,
                                      char*** attr_keys, char*** attr_vals, int *num_attrs,
@@ -1861,7 +2290,7 @@ extern "C" {
                                      const int* aux_dims, const int* aux_types,
                                      const size_t* aux_IDs, const char* const* aux_dev_type,
                                      const int* aux_dev_id) {
-    std::string subgraph_json(json);
+    mxnet::ext::Graph *subgraph = mxnet::ext::Graph::fromString(json);
     bool accept_bool = false;
     // create map of attributes from list
     std::unordered_map<std::string, std::string> opts;
@@ -1869,50 +2298,50 @@ extern "C" {
       opts[std::string(opt_keys[i])] = std::string(opt_vals[i]);
 
     // create a map of named tensors for args
-    std::unordered_map<std::string, MXTensor> args;
+    std::unordered_map<std::string, mxnet::ext::MXTensor> args;
     for (int i = 0; i < num_args; i++) {
       std::vector<int64_t> shapes;
       for (int j = 0; j < arg_dims[i]; j++)
         shapes.push_back(arg_shapes[i][j]);
 
-      MXTensor tensor(arg_data[i], shapes, (MXDType)arg_types[i],
-                      arg_IDs[i], MXContext(arg_dev_type[i], arg_dev_id[i]));
+      mxnet::ext::MXTensor tensor(arg_data[i], shapes, (mxnet::ext::MXDType)arg_types[i],
+                      arg_IDs[i], mxnet::ext::MXContext(arg_dev_type[i], arg_dev_id[i]));
       args[arg_names[i]] = tensor;
     }
     // create a map of named tensors for aux
-    std::unordered_map<std::string, MXTensor> aux;
+    std::unordered_map<std::string, mxnet::ext::MXTensor> aux;
     for (int i = 0; i < num_aux; i++) {
       std::vector<int64_t> shapes;
       for (int j = 0; j < aux_dims[i]; j++)
         shapes.push_back(aux_shapes[i][j]);
 
-      MXTensor tensor(aux_data[i], shapes, (MXDType)aux_types[i],
-                      aux_IDs[i], MXContext(aux_dev_type[i], aux_dev_id[i]));
+      mxnet::ext::MXTensor tensor(aux_data[i], shapes, (mxnet::ext::MXDType)aux_types[i],
+                                  aux_IDs[i], mxnet::ext::MXContext(aux_dev_type[i],
+                                                                    aux_dev_id[i]));
       aux[aux_names[i]] = tensor;
     }
 
-    // attributes to set on subgraph node
-    std::unordered_map<std::string, std::string> attrs;
-
-    MXReturnValue retval = reviewSubgraph(subgraph_json, subgraph_id, &accept_bool,
-                                          opts, &attrs, args, aux);
+    subgraph->_setParams(&args, &aux);
+    mxnet::ext::MXReturnValue retval = reviewSubgraph(subgraph, subgraph_id, &accept_bool,
+                                                      opts);
     if (!retval) return retval;
 
     *accept = accept_bool;
 
-    if (attrs.size() > 0) {
-      *num_attrs = attrs.size();
+    if (subgraph->attrs.size() > 0) {
+      *num_attrs = subgraph->attrs.size();
       // allocate space for attributes
-      *attr_keys = static_cast<char**>(malloc (attrs.size() * sizeof(char*)));
-      *attr_vals = static_cast<char**>(malloc (attrs.size() * sizeof(char*)));
+      *attr_keys = static_cast<char**>(malloc (*num_attrs * sizeof(char*)));
+      *attr_vals = static_cast<char**>(malloc (*num_attrs * sizeof(char*)));
 
       // copy attributes
       int i = 0;
-      for (auto kv : attrs) {
+      for (auto kv : subgraph->attrs) {
         (*attr_keys)[i] = static_cast<char*>(malloc ((kv.first.size()+1) * sizeof(char)));
-        (*attr_vals)[i] = static_cast<char*>(malloc ((kv.second.size()+1) * sizeof(char)));
+        std::string val = kv.second.dump();  // convert JsonVal back to string
+        (*attr_vals)[i] = static_cast<char*>(malloc ((val.size()+1) * sizeof(char)));
         snprintf((*attr_keys)[i], kv.first.size()+1, "%s", kv.first.c_str());
-        snprintf((*attr_vals)[i], kv.second.size()+1, "%s", kv.second.c_str());
+        snprintf((*attr_vals)[i], val.size()+1, "%s", val.c_str());
         i++;
       }
     }
@@ -1922,20 +2351,21 @@ extern "C" {
 
   /*! \brief returns number of graph passes registered in this library */
   MX_INT_RET _passRegSize() {
-    return Registry<CustomPass>::get()->size();
+    return mxnet::ext::Registry<mxnet::ext::CustomPass>::get()->size();
   }
 
   /*! \brief returns pass registration at specified index */
-  MX_VOID_RET _passRegGet(int pass_idx, graphPass_t* graphPass,
+  MX_VOID_RET _passRegGet(int pass_idx, mxnet::ext::graphPass_t* graphPass,
                           const char** pass_name) {
-    CustomPass pass = Registry<CustomPass>::get()->get(pass_idx);
+    mxnet::ext::CustomPass pass =
+      mxnet::ext::Registry<mxnet::ext::CustomPass>::get()->get(pass_idx);
     *graphPass = pass.pass;
     *pass_name = pass.name;
   }
 
   /*! \brief returns status of calling graph pass function from library */
-  MX_INT_RET _passCallGraphPass(graphPass_t graphPass, const char *json,
-                                char** graph, const char* const* opt_keys,
+  MX_INT_RET _passCallGraphPass(mxnet::ext::graphPass_t graphPass, const char *json,
+                                char** out_graph, const char* const* opt_keys,
                                 const char* const* opt_vals, int num_opts,
                                 const char* pass_name, const char* const* arg_names, int num_args,
                                 void* const* arg_data, const int64_t* const* arg_shapes,
@@ -1945,51 +2375,48 @@ extern "C" {
                                 void* const* aux_data, const int64_t* const* aux_shapes,
                                 const int* aux_dims, const int* aux_types,
                                 const size_t* aux_IDs, const char* const* aux_dev_type,
-                                const int* aux_dev_id, nd_malloc_t nd_malloc,
+                                const int* aux_dev_id, mxnet::ext::nd_malloc_t nd_malloc,
                                 const void* nd_alloc) {
-    std::string graph_json(json);
-    const std::string* out_graph = nullptr;
+    mxnet::ext::Graph *graph = mxnet::ext::Graph::fromString(json);
     // create map of attributes from list
     std::unordered_map<std::string, std::string> opts;
     for (int i = 0; i < num_opts; i++)
       opts[std::string(opt_keys[i])] = std::string(opt_vals[i]);
 
     // create a map of named tensors for args
-    std::unordered_map<std::string, MXTensor> args;
+    std::unordered_map<std::string, mxnet::ext::MXTensor> args;
     for (int i = 0; i < num_args; i++) {
       std::vector<int64_t> shapes;
       for (int j = 0; j < arg_dims[i]; j++)
         shapes.push_back(arg_shapes[i][j]);
 
-      MXTensor tensor(arg_data[i], shapes, (MXDType)arg_types[i],
-                      arg_IDs[i], MXContext(arg_dev_type[i], arg_dev_id[i]));
+      mxnet::ext::MXTensor tensor(arg_data[i], shapes, (mxnet::ext::MXDType)arg_types[i],
+                                  arg_IDs[i], mxnet::ext::MXContext(arg_dev_type[i],
+                                                                    arg_dev_id[i]));
       args[arg_names[i]] = tensor;
     }
     // create a map of named tensors for aux
-    std::unordered_map<std::string, MXTensor> aux;
+    std::unordered_map<std::string, mxnet::ext::MXTensor> aux;
     for (int i = 0; i < num_aux; i++) {
       std::vector<int64_t> shapes;
       for (int j = 0; j < aux_dims[i]; j++)
         shapes.push_back(aux_shapes[i][j]);
 
-      MXTensor tensor(aux_data[i], shapes, (MXDType)aux_types[i],
-                      aux_IDs[i], MXContext(aux_dev_type[i], aux_dev_id[i]));
+      mxnet::ext::MXTensor tensor(aux_data[i], shapes, (mxnet::ext::MXDType)aux_types[i],
+                                  aux_IDs[i], mxnet::ext::MXContext(aux_dev_type[i],
+                                                                    aux_dev_id[i]));
       aux[aux_names[i]] = tensor;
     }
 
-    std::unordered_map<std::string, MXTensor> new_args, new_aux;
-    PassResource res(&new_args, &new_aux, nd_malloc, nd_alloc);
-    MXReturnValue retval = graphPass(graph_json, &out_graph, opts, args, aux, res);
+    std::unordered_map<std::string, mxnet::ext::MXTensor> new_args, new_aux;
+    mxnet::ext::PassResource res(&new_args, &new_aux, nd_malloc, nd_alloc);
+    graph->_setParams(&args, &aux);
+    graph->_setPassResource(&res);
+    mxnet::ext::MXReturnValue retval = graphPass(graph, opts);
     if (!retval) return retval;
 
-    if (out_graph == nullptr) {
-      std::cout << "Error calling graph pass '" << pass_name
-                << "' returned out_graph string is null" << std::endl;
-      return MX_FAIL;
-    }
-    *graph = static_cast<char*>(malloc((out_graph->length()+1) * sizeof(char)));
-    out_graph->copy(*graph, out_graph->size()+1);
-    delete out_graph;
+    std::string *tmp = new std::string(graph->toString());
+    *out_graph = const_cast<char*>(tmp->c_str());
     return retval;
   }
 
@@ -2001,10 +2428,19 @@ extern "C" {
    * \return Non-zero value on error i.e. library incompatible with passed MXNet version
    */
 #if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__)
-  __declspec(dllexport) MXReturnValue __cdecl
+  __declspec(dllexport) mxnet::ext::MXReturnValue __cdecl
 #else
-  MXReturnValue
+  mxnet::ext::MXReturnValue
 #endif
   initialize(int version);
-}
+
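+  /*! \brief returns number of error messages registered in this library */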
+  MX_INT_RET _msgSize() {
+    return mxnet::ext::MXerrorMsgs::get()->size();
+  }
+
+  /*! \brief returns error message at specified index */
+  MX_VOID_RET _msgGet(int idx, const char** msg) {
+    *msg = mxnet::ext::MXerrorMsgs::get()->get(idx)->c_str();
+  }
+}  // extern "C"
 #endif  // MXNET_LIB_API_H_
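
The two entry points above are how MXNet reads back error messages queued by an extension library after a failed call. Below is a minimal sketch of how a host could format them, assuming the _msgSize/_msgGet symbols have already been resolved into the msgSize_t/msgGet_t function pointers declared in this header; the helper name drainExtensionMsgs is illustrative only, and the real counterpart is getExtensionMsgs() in the c_api.cc hunk further below.

    #include <string>
    #include "mxnet/lib_api.h"  // header shown above; provides mxnet::ext::msgSize_t / msgGet_t

    // Hedged sketch (not part of this patch): concatenate every message the
    // extension library has queued, in registration order.
    std::string drainExtensionMsgs(mxnet::ext::msgSize_t msg_size,
                                   mxnet::ext::msgGet_t msg_get) {
      std::string out;
      for (int i = 0; i < msg_size(); i++) {
        const char* msg = nullptr;
        msg_get(i, &msg);  // points msg at the i-th stored message
        out += "[" + std::to_string(i) + "] " + std::string(msg) + "\n";
      }
      return out;
    }
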
diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py
index 968c787..d7afd8a 100644
--- a/python/mxnet/gluon/block.py
+++ b/python/mxnet/gluon/block.py
@@ -1035,7 +1035,7 @@ class HybridBlock(Block):
             out = [out]
         return _regroup(out, self._out_format)
 
-    def optimize_for(self, x, *args, backend=None, backend_opts=None, **kwargs):
+    def optimize_for(self, x, *args, backend=None, backend_opts=None, clear=True, **kwargs):
         """Partitions the current HybridBlock and optimizes it for a given backend
         without executing a forward pass. Modifies the HybridBlock in-place.
 
@@ -1065,6 +1065,7 @@ class HybridBlock(Block):
             The name of backend, as registered in `SubgraphBackendRegistry`, default None
         backend_opts : dict of user-specified options to pass to the backend for partitioning, optional
             Passed on to `PrePartition` and `PostPartition` functions of `SubgraphProperty`
+        clear : bool, default True. Clears any previous optimizations.
         static_alloc : bool, default False
             Statically allocate memory to improve speed. Memory usage may increase.
         static_shape : bool, default False
@@ -1074,7 +1075,7 @@ class HybridBlock(Block):
         """
 
         # do hybrize API call
-        self.hybridize(True, backend, backend_opts, **kwargs)
+        self.hybridize(True, backend, backend_opts, clear, **kwargs)
 
         # do part of forward API call
         has_symbol, has_ndarray, ctx_set, _ = _gather_type_ctx_info([x] + list(args))
@@ -1112,7 +1113,7 @@ class HybridBlock(Block):
         super(HybridBlock, self).register_child(block, name)
         self._clear_cached_op()
 
-    def hybridize(self, active=True, backend=None, backend_opts=None, **kwargs):
+    def hybridize(self, active=True, backend=None, backend_opts=None, clear=True, **kwargs):
         """Activates or deactivates :py:class:`HybridBlock` s recursively. Has no effect on
         non-hybrid children.
 
@@ -1124,6 +1125,7 @@ class HybridBlock(Block):
             The name of backend, as registered in `SubgraphBackendRegistry`, default None
         backend_opts : dict of user-specified options to pass to the backend for partitioning, optional
             Passed on to `PrePartition` and `PostPartition` functions of `SubgraphProperty`
+        clear : bool, default True. Clears any previous optimizations.
         static_alloc : bool, default False
             Statically allocate memory to improve speed. Memory usage may increase.
         static_shape : bool, default False
@@ -1140,7 +1142,8 @@ class HybridBlock(Block):
 
         self._active = active
         self._flags = list(kwargs.items())
-        self._clear_cached_op()
+        if clear:
+            self._clear_cached_op()
         if active and self._forward_hooks or self._forward_pre_hooks:
             warnings.warn('"{block}" is being hybridized while still having forward hook/pre-hook. '
                           'If "{block}" is a child of HybridBlock, the hooks will not take effect.'
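
The new clear flag defaults to True, which keeps the old behavior of dropping any cached graph before re-optimizing; passing clear=False preserves the previous optimization so a second backend or graph pass can be layered on top, which is what enables the multiple-optimization use case from the commit message. A minimal usage sketch follows, assuming an extension library that registers the backends 'myPass' and 'mySubgraphBackend' has already been loaded with mx.library.load (both backend names are placeholders, not part of this patch):

    import mxnet as mx
    from mxnet.gluon import nn

    block = nn.Dense(10)
    block.initialize()
    x = mx.nd.ones((1, 4))

    # first call clears any cached graph (clear=True is the default)
    block.optimize_for(x, backend='myPass')
    # second call keeps the first optimization and applies another on top
    block.optimize_for(x, backend='mySubgraphBackend', clear=False)
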
diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc
index fdc79423..8384940 100644
--- a/src/c_api/c_api.cc
+++ b/src/c_api/c_api.cc
@@ -97,21 +97,41 @@ inline int MXAPIGetFunctionRegInfo(const FunRegType *e,
 
 // NOTE: return value is added in API_END
 
+std::string getExtensionMsgs(mxnet::ext::msgSize_t msgSize,
+                             mxnet::ext::msgGet_t msgGet) {
+  std::string str;
+  if (msgSize() > 0) {
+    str = "\nExtension Traceback:\n";
+    for (int i = 0; i < msgSize(); i++) {
+      const char* tmp;
+      msgGet(i, &tmp);
+      // format: [i] message
+      str += std::string("\t[") + std::to_string(i) + std::string("] ")
+        + std::string(tmp) + std::string("\n");
+    }
+  }
+  return str;
+}
+
 /*!
  * \brief Common compute function dispatcher for forward/backward and stateful forward/backward
  * state_ptr will be nullptr for regular ops; fcomp_fp is nullptr for stateful ops
  */
 void CustomFComputeDispatcher(const std::string op_name,
-                              const opCallFComp_t callFComp,
-                              const fcomp_t fcomp_fp,
+                              const mxnet::ext::opCallFComp_t callFComp,
+                              const mxnet::ext::fcomp_t fcomp_fp,
                               const nnvm::NodeAttrs* attrs,
-                              const opCallFStatefulComp_t callFStatefulComp,
+                              const mxnet::ext::opCallFStatefulComp_t callFStatefulComp,
                               int stateful_forward_flag,
                               const OpStatePtr* state_ptr,
                               const OpContext& ctx,
                               const std::vector<NDArray>& inputs,
                               const std::vector<OpReqType>& req,
-                              const std::vector<NDArray>& outputs) {
+                              const std::vector<NDArray>& outputs,
+                              mxnet::ext::msgSize_t msgSize,
+                              mxnet::ext::msgGet_t msgGet) {
+  using namespace mxnet::ext;
+
   std::vector<void*> in_data, out_data;
   std::vector<const int64_t*> in_shapes, out_shapes;
   std::vector<int> in_dims, out_dims;
@@ -280,47 +300,225 @@ void CustomFComputeDispatcher(const std::string op_name,
     }
 
     // call fcompute function
-    CHECK(callFComp(fcomp_fp, attr_keys.data(), attr_vals.data(), attr_keys.size(),
-                    in_shapes.data(), in_dims.data(), in_data.data(), in_types.data(),
-                    in_verIDs.data(), in_dev_type.data(), in_dev_id.data(), in_data.size(),
-                    out_shapes.data(), out_dims.data(), out_data.data(), out_types.data(),
-                    out_verIDs.data(), out_dev_type.data(), out_dev_id.data(), out_data.size(),
-                    cpu_malloc, &cpu_alloc, gpu_malloc, &gpu_alloc, cuda_stream,
-                    sparse_malloc, &sparse_alloc, in_stypes.data(), out_stypes.data(),
-                    in_indices.data(), out_indices.data(), in_indptr.data(), out_indptr.data(),
-                    in_indices_shapes.data(), out_indices_shapes.data(),
-                    in_indptr_shapes.data(), out_indptr_shapes.data(),
-                    rng_cpu_states, rng_gpu_states))
-      << "Error calling FCompute for custom operator '" << op_name << "'";
+    int retval = callFComp(fcomp_fp, attr_keys.data(), attr_vals.data(), attr_keys.size(),
+                           in_shapes.data(), in_dims.data(), in_data.data(), in_types.data(),
+                           in_verIDs.data(), in_dev_type.data(), in_dev_id.data(), in_data.size(),
+                           out_shapes.data(), out_dims.data(), out_data.data(), out_types.data(),
+                           out_verIDs.data(), out_dev_type.data(), out_dev_id.data(),
+                           out_data.size(),
+                           cpu_malloc, &cpu_alloc, gpu_malloc, &gpu_alloc, cuda_stream,
+                           sparse_malloc, &sparse_alloc, in_stypes.data(), out_stypes.data(),
+                           in_indices.data(), out_indices.data(), in_indptr.data(),
+                           out_indptr.data(),
+                           in_indices_shapes.data(), out_indices_shapes.data(),
+                           in_indptr_shapes.data(), out_indptr_shapes.data(),
+                           rng_cpu_states, rng_gpu_states);
+    std::string msgs = getExtensionMsgs(msgSize, msgGet);
+    CHECK(retval) << "Error calling FCompute for custom operator '" << op_name << "'" << msgs;
   }
 
   if (state_ptr != nullptr) {
     // retrieve op state object created from CreateOpState
     CustomStatefulOpWrapper& op = state_ptr->get_state<CustomStatefulOpWrapper>();
     CustomStatefulOp* state_op_inst = op.get_instance();
+    std::string msgs = getExtensionMsgs(msgSize, msgGet);
     CHECK(state_op_inst != nullptr)
-      << "Error custom stateful operator is null for operator '" << op_name << "'";
+      << "Error custom stateful operator is null for operator '" << op_name << "'" << msgs;
 
     // call fcompute function
-    CHECK(callFStatefulComp(stateful_forward_flag, state_op_inst,
-                            in_shapes.data(), in_dims.data(), in_data.data(), in_types.data(),
-                            in_verIDs.data(), in_dev_type.data(), in_dev_id.data(),
-                            in_data.size(),
-                            out_shapes.data(), out_dims.data(), out_data.data(), out_types.data(),
-                            out_verIDs.data(), out_dev_type.data(), out_dev_id.data(),
-                            out_data.size(),
-                            cpu_malloc, &cpu_alloc, gpu_malloc, &gpu_alloc, cuda_stream,
-                            sparse_malloc, &sparse_alloc, in_stypes.data(), out_stypes.data(),
-                            in_indices.data(), out_indices.data(),
-                            in_indptr.data(), out_indptr.data(),
-                            in_indices_shapes.data(), out_indices_shapes.data(),
-                            in_indptr_shapes.data(), out_indptr_shapes.data(),
-                            rng_cpu_states, rng_gpu_states))
-      << "Error calling FStatefulCompute for custom operator '" << op_name << "'";
-  }
-}
-
-void registerOperators(void *lib, int verbose) {
+    int retval = callFStatefulComp(stateful_forward_flag, state_op_inst,
+                                   in_shapes.data(), in_dims.data(), in_data.data(),
+                                   in_types.data(),
+                                   in_verIDs.data(), in_dev_type.data(), in_dev_id.data(),
+                                   in_data.size(),
+                                   out_shapes.data(), out_dims.data(), out_data.data(),
+                                   out_types.data(),
+                                   out_verIDs.data(), out_dev_type.data(), out_dev_id.data(),
+                                   out_data.size(),
+                                   cpu_malloc, &cpu_alloc, gpu_malloc, &gpu_alloc, cuda_stream,
+                                   sparse_malloc, &sparse_alloc, in_stypes.data(),
+                                   out_stypes.data(), in_indices.data(), out_indices.data(),
+                                   in_indptr.data(), out_indptr.data(),
+                                   in_indices_shapes.data(), out_indices_shapes.data(),
+                                   in_indptr_shapes.data(), out_indptr_shapes.data(),
+                                   rng_cpu_states, rng_gpu_states);
+    msgs = getExtensionMsgs(msgSize, msgGet);
+    CHECK(retval) << "Error calling FStatefulCompute for custom operator '" << op_name << "'"
+                  << msgs;
+  }
+}
+
+template <typename RescReq, typename AttrParser, typename NumInputs, typename NumOutputs,
+          typename NumInOuts,
+          typename InferType, typename InferShape, typename InferSType, typename MutateInputs,
+          typename SubgraphNumInputs, typename SubgraphInferType, typename SubgraphInferShape,
+          typename SubgraphInferSType, typename CreateOpState, typename GradReg>
+void registerOp(const char* name, const std::string& name_str, bool isSubgraphOp,
+                RescReq resc_req, AttrParser attr_parser, NumInputs num_inputs,
+                NumOutputs num_outputs, NumInOuts num_inouts, InferType infer_type,
+                InferShape infer_shape, InferSType infer_storage_type,
+                MutateInputs mutate_inputs, SubgraphNumInputs num_subgraph_inputs,
+                SubgraphInferType infer_subgraph_type, SubgraphInferShape infer_subgraph_shape,
+                SubgraphInferSType infer_subgraph_storage_type, CreateOpState create_opstate,
+                GradReg grad_reg, mxnet::ext::mutateInputs_t mutate_fp,
+                const std::unordered_map<std::string, mxnet::ext::createOpState_t> &createop_map,
+                const std::unordered_map<std::string, mxnet::ext::fcomp_t> &forward_ctx_map,
+                const std::unordered_map<std::string, mxnet::ext::fcomp_t> &backward_ctx_map,
+                mxnet::ext::opCallFComp_t callFComp,
+                mxnet::ext::opCallFStatefulComp_t callFStatefulComp,
+                mxnet::ext::msgSize_t msgSize,
+                mxnet::ext::msgGet_t msgGet) {
+  using namespace mxnet::ext;
+
+  // check if operator is already registered
+  const nnvm::Op *regOpPtr = dmlc::Registry<nnvm::Op>::Get()->Find(name);
+  nnvm::Op &regOp = dmlc::Registry<nnvm::Op>::Get()->__REGISTER_OR_GET__(name);
+  int plevel = 10;
+  if (regOpPtr != nullptr) {
+    // overwrite registration of existing op with custom op
+    regOp.arguments.clear();
+    // set attribute with higher plevel (11) to allow re-registering once
+    // TODO(samskalicky): enable overwriting the registration multiple times
+    plevel++;
+  }
+  // define supported resources for both subgraph ops and regular ops
+  regOp.set_attr<FResourceRequest>("FResourceRequest", resc_req, plevel);
+  if (!isSubgraphOp) {
+    regOp.set_attr_parser(attr_parser);
+    regOp.set_num_inputs(num_inputs);
+    regOp.set_num_outputs(num_outputs);
+    regOp.set_attr<nnvm::FInferType>("FInferType", infer_type, plevel);
+    regOp.set_attr<FInferStorageType>("FInferStorageType", infer_storage_type, plevel);
+    regOp.set_attr<mxnet::FInferShape>("FInferShape", infer_shape, plevel);
+    // optionally add fmutate inputs if user specified a function
+    if (mutate_fp != nullptr)
+      regOp.set_attr<nnvm::FMutateInputs>("FMutateInputs", mutate_inputs, plevel);
+  } else {
+    using namespace mxnet::op;
+    regOp.set_num_inputs(num_subgraph_inputs);
+    regOp.set_num_outputs(DefaultSubgraphOpNumOutputs);
+    regOp.set_attr<nnvm::FInferType>("FInferType", infer_subgraph_type, plevel);
+    regOp.set_attr<mxnet::FInferShape>("FInferShape", infer_subgraph_shape, plevel);
+    regOp.set_attr<FInferStorageType>("FInferStorageType",
+                                      infer_subgraph_storage_type, plevel);
+    regOp.set_attr<nnvm::FMutateInputs>("FMutateInputs",
+                                        DefaultSubgraphOpMutableInputs, plevel);
+  }
+  // optionally add stateful forward
+  if (createop_map.size() != 0) {
+    regOp.set_attr<FCreateOpState>("FCreateOpState", create_opstate, plevel);
+    auto fstate_forward = [=](const OpStatePtr& state_ptr,
+                              const OpContext& ctx,
+                              const std::vector<NDArray>& inputs,
+                              const std::vector<OpReqType>& req,
+                              const std::vector<NDArray>& outputs) {
+      CustomFComputeDispatcher(name_str, nullptr, nullptr, nullptr,
+                               callFStatefulComp, 1, &state_ptr, ctx, inputs, req, outputs,
+                               msgSize, msgGet);
+    };
+    if (createop_map.count("cpu") > 0)
+      regOp.set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", fstate_forward, plevel);
+    if (createop_map.count("gpu") > 0)
+      regOp.set_attr<FStatefulComputeEx>("FStatefulComputeEx<gpu>", fstate_forward, plevel);
+  } else {
+    auto forward_lambda = [=](const nnvm::NodeAttrs& attrs,
+                              const OpContext& ctx,
+                              const std::vector<NDArray>& inputs,
+                              const std::vector<OpReqType>& req,
+                              const std::vector<NDArray>& outputs) {
+      if (ctx.run_ctx.ctx.dev_mask() == Context::kCPU) {
+        CHECK_GT(forward_ctx_map.count("cpu"), 0);
+        fcomp_t fcomp = forward_ctx_map.at("cpu");
+        CustomFComputeDispatcher(name_str, callFComp, fcomp, &attrs,
+                                 nullptr, 0, nullptr, ctx, inputs, req, outputs, msgSize, msgGet);
+      } else if (ctx.run_ctx.ctx.dev_mask() == Context::kGPU) {
+        CHECK_GT(forward_ctx_map.count("gpu"), 0);
+        fcomp_t fcomp = forward_ctx_map.at("gpu");
+        CustomFComputeDispatcher(name_str, callFComp, fcomp, &attrs,
+                                 nullptr, 0, nullptr, ctx, inputs, req, outputs, msgSize, msgGet);
+      }
+    };
+    if (forward_ctx_map.count("cpu") > 0)
+      regOp.set_attr<FComputeEx>("FComputeEx<cpu>", forward_lambda, plevel);
+    if (forward_ctx_map.count("gpu") > 0)
+      regOp.set_attr<FComputeEx>("FComputeEx<gpu>", forward_lambda, plevel);
+  }
+  // optionally add fgradient if user specified a function, or for stateful ops
+  if (backward_ctx_map.size() != 0 || createop_map.size() != 0) {
+    std::string grad_name = "_backward_" + name_str;
+    nnvm::Op &gradOp = dmlc::Registry<nnvm::Op>::Get()->__REGISTER_OR_GET__(grad_name);
+    regOp.set_attr<nnvm::FGradient>("FGradient", grad_reg, plevel);
+    gradOp.set_attr<nnvm::TIsBackward>("TIsBackward", true, plevel);
+    gradOp.set_attr<FInferStorageType>("FInferStorageType", infer_storage_type, plevel);
+    gradOp.set_attr<FResourceRequest>("FResourceRequest", resc_req, plevel);
+
+    if (!isSubgraphOp) {
+      // register attr parser and standard functions for non-subgraph ops
+      gradOp.set_attr_parser(attr_parser);
+      gradOp.set_num_inputs(num_inouts);
+      gradOp.set_num_outputs(num_inputs);
+    } else {
+      // for subgraph ops use special functions that do not invoke attr_parser
+      using namespace mxnet::op;
+      auto grad_inouts = [=](const nnvm::NodeAttrs& attrs) {
+        // for backward passes, inputs + outputs + input gradients (one for each output)
+        uint32_t cnt = num_subgraph_inputs(attrs);
+        cnt += 2 * DefaultSubgraphOpNumOutputs(attrs);
+        return cnt;
+      };
+      gradOp.set_num_inputs(grad_inouts);
+      gradOp.set_num_outputs(num_subgraph_inputs);
+    }
+
+    if (createop_map.size() != 0) {
+      // for stateful operators
+      gradOp.set_attr<bool>("TIsLayerOpBackward", true, plevel);
+      auto fstate_backward = [=](const OpStatePtr& state_ptr,
+                                 const OpContext& ctx,
+                                 const std::vector<NDArray>& inputs,
+                                 const std::vector<OpReqType>& req,
+                                 const std::vector<NDArray>& outputs) {
+        CustomFComputeDispatcher(name_str, nullptr, nullptr, nullptr,
+                                 callFStatefulComp, 0, &state_ptr, ctx, inputs, req, outputs,
+                                 msgSize, msgGet);
+      };
+      gradOp.set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", fstate_backward, plevel);
+      gradOp.set_attr<FStatefulComputeEx>("FStatefulComputeEx<gpu>", fstate_backward, plevel);
+    } else {
+      // for stateless operators
+      if (backward_ctx_map.count("cpu") > 0) {
+        fcomp_t fcomp_back_cpu = backward_ctx_map.at("cpu");
+        auto backward_cpu_lambda = [=](const nnvm::NodeAttrs& attrs,
+                                       const OpContext& ctx,
+                                       const std::vector<NDArray>& inputs,
+                                       const std::vector<OpReqType>& req,
+                                       const std::vector<NDArray>& outputs) {
+          CustomFComputeDispatcher(name_str, callFComp, fcomp_back_cpu, &attrs,
+                                   nullptr, 0, nullptr, ctx, inputs, req, outputs, msgSize, msgGet);
+        };
+        gradOp.set_attr<FComputeEx>("FComputeEx<cpu>", backward_cpu_lambda, plevel);
+      }
+      if (backward_ctx_map.count("gpu") > 0) {
+        fcomp_t fcomp_back_gpu = backward_ctx_map.at("gpu");
+        auto backward_gpu_lambda = [=](const nnvm::NodeAttrs& attrs,
+                                       const OpContext& ctx,
+                                       const std::vector<NDArray>& inputs,
+                                       const std::vector<OpReqType>& req,
+                                       const std::vector<NDArray>& outputs) {
+          CustomFComputeDispatcher(name_str, callFComp, fcomp_back_gpu, &attrs,
+                                   nullptr, 0, nullptr, ctx, inputs, req, outputs, msgSize, msgGet);
+        };
+        gradOp.set_attr<FComputeEx>("FComputeEx<gpu>", backward_gpu_lambda, plevel);
+      }
+    }
+  }
+  regOp.add_argument("data", "NDArray[]", "Source inputs");
+}
+
+void registerOperators(void *lib, int verbose, mxnet::ext::msgSize_t msgSize,
+                       mxnet::ext::msgGet_t msgGet) {
+  using namespace mxnet::ext;
+
   // get C type interface functions
   opCallFree_t callFree = get_func<opCallFree_t>(lib, const_cast<char*>(MXLIB_OPCALLFREE_STR));
 
@@ -446,9 +644,10 @@ void registerOperators(void *lib, int verbose) {
 
       int num_in = -1;
       int num_out = -1;
-      CHECK(callParseAttrs(parse_fp, attr_keys.data(), attr_vals.data(), attr_keys.size(),
-                           &num_in, &num_out))
-      << "Error calling ParseAttrs for custom operator '" << name_str << "'";
+      int retval = callParseAttrs(parse_fp, attr_keys.data(), attr_vals.data(), attr_keys.size(),
+                                  &num_in, &num_out);
+      std::string msgs = getExtensionMsgs(msgSize, msgGet);
+      CHECK(retval) << "Error calling ParseAttrs for custom operator '" << name_str << "'" << msgs;
 
       // return type void
     };
@@ -464,11 +663,31 @@ void registerOperators(void *lib, int verbose) {
 
       int num_in = -1;
       int num_out = -1;
-      CHECK(callParseAttrs(parse_fp, attr_keys.data(), attr_vals.data(), attr_keys.size(),
-                           &num_in, &num_out))
-      << "Error calling ParseAttrs::num_inputs for custom operator '" << name_str << "'";
+      int retval = callParseAttrs(parse_fp, attr_keys.data(), attr_vals.data(), attr_keys.size(),
+                                  &num_in, &num_out);
+      std::string msgs = getExtensionMsgs(msgSize, msgGet);
+      CHECK(retval) << "Error calling ParseAttrs::num_inputs for custom operator '" << name_str
+      << "'" << msgs;
+
+      // get extra inputs, if any
+      size_t extra_inputs = 0;
+      if (attrs.dict.count(MX_STR_EXTRA_INPUTS) > 0)
+        extra_inputs = std::stoi(attrs.dict.at(MX_STR_EXTRA_INPUTS));
+
+      return num_in + extra_inputs;
+    };
+
+    // lambda function to return the number of inputs for subgraph ops (including extra inputs)
+    auto num_subgraph_inputs = [=](const NodeAttrs& attrs) {
+      // get number of inputs for subgraph
+      int num_in = mxnet::op::DefaultSubgraphOpNumInputs(attrs);
 
-      return num_in;
+      // get extra inputs, if any
+      size_t extra_inputs = 0;
+      if (attrs.dict.count(MX_STR_EXTRA_INPUTS) > 0)
+        extra_inputs = std::stoi(attrs.dict.at(MX_STR_EXTRA_INPUTS));
+
+      return num_in + extra_inputs;
     };
 
     // lambda function to call parse attributes and return the number of outputs
@@ -482,9 +701,11 @@ void registerOperators(void *lib, int verbose) {
 
       int num_in = -1;
       int num_out = -1;
-      CHECK(callParseAttrs(parse_fp, attr_keys.data(), attr_vals.data(), attr_keys.size(),
-                           &num_in, &num_out))
-      << "Error calling ParseAttrs::num_outputs for custom operator '" << name_str << "'";
+      int retval = callParseAttrs(parse_fp, attr_keys.data(), attr_vals.data(), attr_keys.size(),
+                                  &num_in, &num_out);
+      std::string msgs = getExtensionMsgs(msgSize, msgGet);
+      CHECK(retval) << "Error calling ParseAttrs::num_outputs for custom operator '" << name_str
+      << "'" << msgs;
 
       return num_out;
     };
@@ -501,11 +722,19 @@ void registerOperators(void *lib, int verbose) {
 
       int num_in = -1;
       int num_out = -1;
-      CHECK(callParseAttrs(parse_fp, attr_keys.data(), attr_vals.data(), attr_keys.size(),
-                           &num_in, &num_out))
-      << "Error calling ParseAttrs::num_outputs for custom operator '" << name_str << "'";
+      int retval = callParseAttrs(parse_fp, attr_keys.data(), attr_vals.data(), attr_keys.size(),
+                                  &num_in, &num_out);
+      std::string msgs = getExtensionMsgs(msgSize, msgGet);
+      CHECK(retval) << "Error calling ParseAttrs::num_outputs for custom operator '" << name_str
+      << "'" << msgs;
       // for backward passes, inputs + outputs + input gradients (one for each output)
-      return num_in + 2 * num_out;
+
+      // get extra inputs, if any
+      size_t extra_inputs = 0;
+      if (attrs.dict.count(MX_STR_EXTRA_INPUTS) > 0)
+        extra_inputs = std::stoi(attrs.dict.at(MX_STR_EXTRA_INPUTS));
+
+      return num_in + extra_inputs + 2 * num_out;
     };
 
     // lambda function to call infer shape
@@ -519,17 +748,24 @@ void registerOperators(void *lib, int verbose) {
         attr_vals.push_back(kv.second.c_str());
       }
 
-      std::vector<uint32_t*> inshapes(in_shape->size());
-      std::vector<int> indims(in_shape->size());
+      // get extra inputs, if any
+      size_t extra_inputs = 0;
+      if (attrs.dict.count(MX_STR_EXTRA_INPUTS) > 0)
+        extra_inputs = std::stoi(attrs.dict.at(MX_STR_EXTRA_INPUTS));
+      size_t num_inputs = in_shape->size() - extra_inputs;
+
+      std::vector<uint32_t*> inshapes(num_inputs);
+      std::vector<int> indims(num_inputs);
 
       // determine amount of memory needed to store all the input shapes
       size_t buff_size = 0;
-      for (const auto& i : *in_shape) buff_size += i.ndim();
+      for (size_t i = 0; i < num_inputs; ++i)
+        buff_size += (*in_shape)[i].ndim();
 
       // copy input shapes from ShapeVector to raw memory layout
       std::vector<uint32_t> inbuff(buff_size);
       uint32_t *ptr = inbuff.data();
-      for (size_t i = 0; i < in_shape->size(); ++i) {
+      for (size_t i = 0; i < num_inputs; ++i) {
         inshapes[i] = ptr;
         indims[i] = (*in_shape)[i].ndim();
         for (int j = 0; j < (*in_shape)[i].ndim(); ++j, ++ptr) {
@@ -544,23 +780,24 @@ void registerOperators(void *lib, int verbose) {
       uint32_t** outshapes = nullptr;
       int* outdims = nullptr;
 
-      CHECK(callInferShape(shape_fp, attr_keys.data(), attr_vals.data(), attr_keys.size(),
-                           inshapes.data(), indims.data(), in_shape->size(),
-                           &mod_inshapes, &mod_indims,
-                           &outshapes, &outdims, out_shape->size()))
-      << "Error calling InferShape for custom operator '" << name_str << "'";
+      int retval = callInferShape(shape_fp, attr_keys.data(), attr_vals.data(), attr_keys.size(),
+                                  inshapes.data(), indims.data(), num_inputs,
+                                  &mod_inshapes, &mod_indims,
+                                  &outshapes, &outdims, out_shape->size());
+      std::string msgs = getExtensionMsgs(msgSize, msgGet);
+      CHECK(retval) << "Error calling InferShape for custom operator '" << name_str << "'" << msgs;
 
-      std::vector<uint32_t*> in_shapes(in_shape->size());
+      std::vector<uint32_t*> in_shapes(num_inputs);
       // determine amount of memory needed to store all the modified input shapes
       buff_size = 0;
-      for (unsigned i = 0; i < in_shape->size(); i++) {
+      for (size_t i = 0; i < num_inputs; i++) {
         buff_size += mod_indims[i];
       }
 
       // copy modified input shapes from custom op memory to MXNet memory
       std::vector<uint32_t> mod_inbuff(buff_size);
       ptr = mod_inbuff.data();
-      for (unsigned i = 0; i < in_shape->size(); ++i) {
+      for (size_t i = 0; i < num_inputs; ++i) {
         in_shapes[i] = ptr;
         for (int j = 0; j < mod_indims[i]; ++j, ++ptr) {
           *ptr = static_cast<uint32_t>(mod_inshapes[i][j]);
@@ -568,7 +805,7 @@ void registerOperators(void *lib, int verbose) {
       }
 
       // assign modified input shapes to ShapeVector
-      for (unsigned i = 0; i < in_shape->size(); ++i) {
+      for (size_t i = 0; i < num_inputs; ++i) {
         SHAPE_ASSIGN_CHECK(*in_shape, i,
                            mxnet::TShape(in_shapes[i], in_shapes[i]+mod_indims[i]));
       }
@@ -576,14 +813,14 @@ void registerOperators(void *lib, int verbose) {
       std::vector<uint32_t*> out_shapes(out_shape->size());
       // determine amount of memory needed to store all the output shapes
       buff_size = 0;
-      for (unsigned i = 0; i < out_shape->size(); i++) {
+      for (size_t i = 0; i < out_shape->size(); i++) {
         buff_size += outdims[i];
       }
 
       // copy output shapes from custom op memory to MXNet memory
       std::vector<uint32_t> outbuff(buff_size);
       ptr = outbuff.data();
-      for (unsigned i = 0; i < out_shape->size(); ++i) {
+      for (size_t i = 0; i < out_shape->size(); ++i) {
         out_shapes[i] = ptr;
         for (int j = 0; j < outdims[i]; ++j, ++ptr) {
           *ptr = static_cast<uint32_t>(outshapes[i][j]);
@@ -591,20 +828,20 @@ void registerOperators(void *lib, int verbose) {
       }
 
       // assign output shapes to ShapeVector
-      for (unsigned i = 0; i < out_shape->size(); ++i) {
+      for (size_t i = 0; i < out_shape->size(); ++i) {
         SHAPE_ASSIGN_CHECK(*out_shape, i,
                            mxnet::TShape(out_shapes[i], out_shapes[i]+outdims[i]));
       }
 
       // free memory used by custom op to allocate shapes/dims
       callFree(mod_indims);
-      for (unsigned i = 0; i < in_shape->size(); i++) {
+      for (size_t i = 0; i < num_inputs; i++) {
         callFree(mod_inshapes[i]);
       }
       callFree(mod_inshapes);
 
       callFree(outdims);
-      for (unsigned i = 0; i < out_shape->size(); i++) {
+      for (size_t i = 0; i < out_shape->size(); i++) {
         callFree(outshapes[i]);
       }
       callFree(outshapes);
@@ -612,6 +849,28 @@ void registerOperators(void *lib, int verbose) {
       return true;
     };
 
+    // lambda function to call infer shape for subgraph ops
+    auto infer_subgraph_shape = [=] (const nnvm::NodeAttrs& attrs,
+                            mxnet::ShapeVector *in_shape,
+                            mxnet::ShapeVector *out_shape) {
+      // convert attributes to vector of char*
+      std::vector<const char*> attr_keys, attr_vals;
+      for (auto &kv : attrs.dict) {
+        attr_keys.push_back(kv.first.c_str());
+        attr_vals.push_back(kv.second.c_str());
+      }
+
+      // get extra inputs, if any
+      size_t extra_inputs = 0;
+      if (attrs.dict.count(MX_STR_EXTRA_INPUTS) > 0)
+        extra_inputs = std::stoi(attrs.dict.at(MX_STR_EXTRA_INPUTS));
+
+      auto in_first = in_shape->begin();
+      auto in_last  = in_first + in_shape->size() - extra_inputs;
+      mxnet::ShapeVector *sg_in_shapes = new mxnet::ShapeVector(in_first, in_last);
+      return mxnet::op::DefaultSubgraphOpShape(attrs, sg_in_shapes, out_shape);
+    };
+
     // lambda function to call infer type
     auto infer_type = [=] (const nnvm::NodeAttrs& attrs,
                             std::vector<int> *in_type,
@@ -623,19 +882,26 @@ void registerOperators(void *lib, int verbose) {
         attr_vals.push_back(kv.second.c_str());
       }
 
+      // get extra inputs, if any
+      size_t extra_inputs = 0;
+      if (attrs.dict.count(MX_STR_EXTRA_INPUTS) > 0)
+        extra_inputs = std::stoi(attrs.dict.at(MX_STR_EXTRA_INPUTS));
+      size_t num_inputs = in_type->size() - extra_inputs;
+
       // copy input types from in_type
       std::vector<int> intypes(*in_type);
 
       // output types will be populated by inferType function
       std::vector<int> outtypes(out_type->size());
 
-      CHECK(callInferType(type_fp, attr_keys.data(), attr_vals.data(), attr_keys.size(),
-                           intypes.data(), in_type->size(),
-                           outtypes.data(), out_type->size()))
-      << "Error calling InferType for custom operator '" << name_str << "'";
+      int retval = callInferType(type_fp, attr_keys.data(), attr_vals.data(), attr_keys.size(),
+                                 intypes.data(), num_inputs,
+                                 outtypes.data(), out_type->size());
+      std::string msgs = getExtensionMsgs(msgSize, msgGet);
+      CHECK(retval) << "Error calling InferType for custom operator '" << name_str << "'" << msgs;
 
       // copy and assign modified input types from custom op to MXNet memory
-      for (size_t i = 0; i < in_type->size(); i++) {
+      for (size_t i = 0; i < num_inputs; i++) {
         TYPE_ASSIGN_CHECK(*in_type, i, intypes[i]);
       }
       // copy and assign output types from custom op to MXNet memory
@@ -646,6 +912,29 @@ void registerOperators(void *lib, int verbose) {
       return true;
     };
 
+    // lambda function to call infer type for subgraph ops
+    auto infer_subgraph_type = [=] (const nnvm::NodeAttrs& attrs,
+                                    std::vector<int> *in_type,
+                                    std::vector<int> *out_type) {
+      // convert attributes to vector of char*
+      std::vector<const char*> attr_keys, attr_vals;
+      for (auto &kv : attrs.dict) {
+        attr_keys.push_back(kv.first.c_str());
+        attr_vals.push_back(kv.second.c_str());
+      }
+
+      // get extra inputs, if any
+      size_t extra_inputs = 0;
+      if (attrs.dict.count(MX_STR_EXTRA_INPUTS) > 0)
+        extra_inputs = std::stoi(attrs.dict.at(MX_STR_EXTRA_INPUTS));
+
+      auto in_first = in_type->begin();
+      auto in_last  = in_first + in_type->size() - extra_inputs;
+      std::vector<int> *sg_in_types = new std::vector<int>(in_first, in_last);
+
+      return mxnet::op::DefaultSubgraphOpType(attrs, sg_in_types, out_type);
+    };
+
     // lambda function to convert from external mutate_inputs to internal MXNet types
     auto mutate_inputs = [=](const nnvm::NodeAttrs& attrs) {
       // convert attributes to vector of char*
@@ -660,9 +949,11 @@ void registerOperators(void *lib, int verbose) {
       int indices_size = 0;
 
       // call mutate inputs function
-      CHECK(callMutateInputs(mutate_fp, attr_keys.data(), attr_vals.data(), attr_keys.size(),
-                      &mutate_indices, &indices_size))
-      << "Error calling MutateInputs for custom operator '" << name_str << "'";
+      int retval = callMutateInputs(mutate_fp, attr_keys.data(), attr_vals.data(), attr_keys.size(),
+                                    &mutate_indices, &indices_size);
+      std::string msgs = getExtensionMsgs(msgSize, msgGet);
+      CHECK(retval) << "Error calling MutateInputs for custom operator '" << name_str << "'"
+      << msgs;
 
       std::vector<uint32_t> mutate_indices_list(indices_size);
       for (int i=0; i < indices_size; i++) {
@@ -679,7 +970,7 @@ void registerOperators(void *lib, int verbose) {
                                 std::vector<int>* in_stypes,
                                 std::vector<int>* out_stypes) {
       if (stype_fp == nullptr) {
-        // InferSType is not defineid in customized lib.
+        // InferSType is not defined in customized lib.
         CHECK(mxnet::common::ContainsOnlyStorage(*in_stypes, mxnet::kDefaultStorage))
         << "Error input tensors are not dense for custom operator '" << name_str << "'";
         // set outputs as dense
@@ -693,18 +984,27 @@ void registerOperators(void *lib, int verbose) {
           attr_keys.push_back(kv.first.c_str());
           attr_vals.push_back(kv.second.c_str());
         }
+
+        // get extra inputs, if any
+        size_t extra_inputs = 0;
+        if (attrs.dict.count(MX_STR_EXTRA_INPUTS) > 0)
+          extra_inputs = std::stoi(attrs.dict.at(MX_STR_EXTRA_INPUTS));
+        size_t num_inputs = in_stypes->size() - extra_inputs;
+
         // copy input types from in_stype
         std::vector<int> instypes(*in_stypes);
 
         // output types will be populated by inferType function
         std::vector<int> outstypes(out_stypes->size());
-        CHECK(callInferSType(stype_fp, attr_keys.data(), attr_vals.data(), attr_keys.size(),
-                             instypes.data(), in_stypes->size(),
-                             outstypes.data(), out_stypes->size()))
-        << "Error calling InferSType for custom operator '" << name_str << "'";
+        int retval = callInferSType(stype_fp, attr_keys.data(), attr_vals.data(), attr_keys.size(),
+                                    instypes.data(), num_inputs,
+                                    outstypes.data(), out_stypes->size());
+        std::string msgs = getExtensionMsgs(msgSize, msgGet);
+        CHECK(retval) << "Error calling InferSType for custom operator '" << name_str << "'"
+        << msgs;
 
         // copy and assign modified input storage types from custom op to MXNet memory.
-        for (size_t i = 0; i < in_stypes->size(); i++) {
+        for (size_t i = 0; i < num_inputs; i++) {
           STORAGE_TYPE_ASSIGN_CHECK(*in_stypes, i, instypes[i]);
         }
         // copy and assign output storage types from custom op to MXNet memory.
@@ -717,6 +1017,25 @@ void registerOperators(void *lib, int verbose) {
       }
     };
 
+    // lambda function to set storage types for subgraph ops
+    auto infer_subgraph_storage_type = [=](const nnvm::NodeAttrs& attrs,
+                                           const int dev_mask,
+                                           DispatchMode* dispatch_mode,
+                                           std::vector<int>* in_stypes,
+                                           std::vector<int>* out_stypes) {
+        // get extra inputs, if any
+        size_t extra_inputs = 0;
+        if (attrs.dict.count(MX_STR_EXTRA_INPUTS) > 0)
+          extra_inputs = std::stoi(attrs.dict.at(MX_STR_EXTRA_INPUTS));
+
+        auto in_first = in_stypes->begin();
+        auto in_last  = in_first + in_stypes->size() - extra_inputs;
+        std::vector<int> *sg_in_stypes = new std::vector<int>(in_first, in_last);
+
+        return mxnet::op::DefaultSubgraphOpStorageType(attrs, dev_mask, dispatch_mode,
+                                                       sg_in_stypes, out_stypes);
+    };
+
     // FGradient register lambda
     auto grad_reg = [=](const nnvm::ObjectPtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
       // create node for gradient
@@ -789,19 +1108,24 @@ void registerOperators(void *lib, int verbose) {
       if (ctx.dev_mask() == Context::kCPU) {
         CHECK(createop_map.count("cpu") > 0)
           << "CPU CreateOpState not implemented for '" << name_str << "'";
-        CHECK(callCreateOpState(createop_map.at("cpu"), attr_keys.data(), attr_vals.data(),
-                                attr_keys.size(), &state_op_inst))
-          << "Error calling CreateOpState CPU for custom operator '" << name_str << "'";
+        int retval = callCreateOpState(createop_map.at("cpu"), attr_keys.data(), attr_vals.data(),
+                                       attr_keys.size(), &state_op_inst);
+        std::string msgs = getExtensionMsgs(msgSize, msgGet);
+        CHECK(retval) << "Error calling CreateOpState CPU for custom operator '" << name_str << "'"
+                      << msgs;
       } else if (ctx.dev_mask() == Context::kGPU) {
         CHECK(createop_map.count("gpu") > 0)
           << "GPU CreateOpState not implemented for '" << name_str << "'";
-        CHECK(callCreateOpState(createop_map.at("gpu"), attr_keys.data(), attr_vals.data(),
-                                attr_keys.size(), &state_op_inst))
-          << "Error calling CreateOpState GPU for custom operator '" << name_str << "'";
+        int retval = callCreateOpState(createop_map.at("gpu"), attr_keys.data(), attr_vals.data(),
+                                       attr_keys.size(), &state_op_inst);
+        std::string msgs = getExtensionMsgs(msgSize, msgGet);
+        CHECK(retval) << "Error calling CreateOpState GPU for custom operator '" << name_str << "'"
+        << msgs;
       }
 
+      std::string msgs = getExtensionMsgs(msgSize, msgGet);
       CHECK(state_op_inst != nullptr)
-        << "Error custom library failed to create stateful operator '" << name_str << "'";
+      << "Error custom library failed to create stateful operator '" << name_str << "'" << msgs;
 
       CustomStatefulOp* state_op = reinterpret_cast<CustomStatefulOp*>(state_op_inst);
       return OpStatePtr::Create<CustomStatefulOpWrapper>(state_op);
@@ -809,151 +1133,19 @@ void registerOperators(void *lib, int verbose) {
 
     /* -------------- BELOW IS THE REGISTRATION FOR CUSTOM OPERATORS --------------- */
 
-    // check if operator is already registered
-    const nnvm::Op *regOpPtr = dmlc::Registry<nnvm::Op>::Get()->Find(name);
-    nnvm::Op &regOp = dmlc::Registry<nnvm::Op>::Get()->__REGISTER_OR_GET__(name);
-    int plevel = 10;
-    if (regOpPtr != nullptr) {
-      // overwrite registration of existing op with custom op
-      regOp.arguments.clear();
-      // set attribute with higher plevel (11) to allow re-registering once
-      // TODO(samskalicky): enable constant overwriting of registertion multiple times
-      plevel++;
-    }
-    // define supported resources for both subgraph ops and regular ops
-    regOp.set_attr<FResourceRequest>("FResourceRequest", resc_req, plevel);
-    if (!isSubgraphOp) {
-      regOp.set_attr_parser(attr_parser);
-      regOp.set_num_inputs(num_inputs);
-      regOp.set_num_outputs(num_outputs);
-      regOp.set_attr<nnvm::FInferType>("FInferType", infer_type, plevel);
-      regOp.set_attr<FInferStorageType>("FInferStorageType", infer_storage_type, plevel);
-      regOp.set_attr<mxnet::FInferShape>("FInferShape", infer_shape, plevel);
-      // optionally add fmutate inputs if user specified a function
-      if (mutate_fp != nullptr)
-        regOp.set_attr<nnvm::FMutateInputs>("FMutateInputs", mutate_inputs, plevel);
-    } else {
-      using namespace mxnet::op;
-      regOp.set_num_inputs(DefaultSubgraphOpNumInputs);
-      regOp.set_num_outputs(DefaultSubgraphOpNumOutputs);
-      regOp.set_attr<nnvm::FInferType>("FInferType", DefaultSubgraphOpType, plevel);
-      regOp.set_attr<mxnet::FInferShape>("FInferShape", DefaultSubgraphOpShape, plevel);
-      regOp.set_attr<FInferStorageType>("FInferStorageType",
-                                        DefaultSubgraphOpStorageType, plevel);
-      regOp.set_attr<nnvm::FMutateInputs>("FMutateInputs",
-                                          DefaultSubgraphOpMutableInputs, plevel);
-    }
-    // optionally add stateful forward
-    if (createop_map.size() != 0) {
-      regOp.set_attr<FCreateOpState>("FCreateOpState", create_opstate, plevel);
-      auto fstate_forward = [=](const OpStatePtr& state_ptr,
-                                const OpContext& ctx,
-                                const std::vector<NDArray>& inputs,
-                                const std::vector<OpReqType>& req,
-                                const std::vector<NDArray>& outputs) {
-        CustomFComputeDispatcher(name_str, nullptr, nullptr, nullptr,
-                                 callFStatefulComp, 1, &state_ptr, ctx, inputs, req, outputs);
-      };
-      if (createop_map.count("cpu") > 0)
-        regOp.set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", fstate_forward, plevel);
-      if (createop_map.count("gpu") > 0)
-        regOp.set_attr<FStatefulComputeEx>("FStatefulComputeEx<gpu>", fstate_forward, plevel);
-    } else {
-      auto forward_lambda = [=](const nnvm::NodeAttrs& attrs,
-                                const OpContext& ctx,
-                                const std::vector<NDArray>& inputs,
-                                const std::vector<OpReqType>& req,
-                                const std::vector<NDArray>& outputs) {
-        if (ctx.run_ctx.ctx.dev_mask() == Context::kCPU) {
-          CHECK_GT(forward_ctx_map.count("cpu"), 0);
-          fcomp_t fcomp = forward_ctx_map.at("cpu");
-          CustomFComputeDispatcher(name_str, callFComp, fcomp, &attrs,
-                                   nullptr, 0, nullptr, ctx, inputs, req, outputs);
-        } else if (ctx.run_ctx.ctx.dev_mask() == Context::kGPU) {
-          CHECK_GT(forward_ctx_map.count("gpu"), 0);
-          fcomp_t fcomp = forward_ctx_map.at("gpu");
-          CustomFComputeDispatcher(name_str, callFComp, fcomp, &attrs,
-                                   nullptr, 0, nullptr, ctx, inputs, req, outputs);
-        }
-      };
-      if (forward_ctx_map.count("cpu") > 0)
-        regOp.set_attr<FComputeEx>("FComputeEx<cpu>", forward_lambda, plevel);
-      if (forward_ctx_map.count("gpu") > 0)
-        regOp.set_attr<FComputeEx>("FComputeEx<gpu>", forward_lambda, plevel);
-    }
-    // optionally add fgradient if user specified a function, or for stateful ops
-    if (backward_ctx_map.size() != 0 || createop_map.size() != 0) {
-      std::string grad_name = "_backward_" + name_str;
-      nnvm::Op &gradOp = dmlc::Registry<nnvm::Op>::Get()->__REGISTER_OR_GET__(grad_name);
-      regOp.set_attr<nnvm::FGradient>("FGradient", grad_reg, plevel);
-      gradOp.set_attr<nnvm::TIsBackward>("TIsBackward", true, plevel);
-      gradOp.set_attr<FInferStorageType>("FInferStorageType", infer_storage_type, plevel);
-      gradOp.set_attr<FResourceRequest>("FResourceRequest", resc_req, plevel);
-
-      if (!isSubgraphOp) {
-        // register attr parser and standard functions for non-subgraph ops
-        gradOp.set_attr_parser(attr_parser);
-        gradOp.set_num_inputs(num_inouts);
-        gradOp.set_num_outputs(num_inputs);
-      } else {
-        // for subgraph ops use special functions that do not invoke attr_parser
-        using namespace mxnet::op;
-        auto grad_inouts = [=](const nnvm::NodeAttrs& attrs) {
-          // for backward passes, inputs + outputs + input gradients (one for each output)
-          uint32_t cnt = DefaultSubgraphOpNumInputs(attrs);
-          cnt += 2 * DefaultSubgraphOpNumOutputs(attrs);
-          return cnt;
-        };
-        gradOp.set_num_inputs(grad_inouts);
-        gradOp.set_num_outputs(DefaultSubgraphOpNumInputs);
-      }
-
-      if (createop_map.size() != 0) {
-        // for stateful operators
-        gradOp.set_attr<bool>("TIsLayerOpBackward", true, plevel);
-        auto fstate_backward = [=](const OpStatePtr& state_ptr,
-                                   const OpContext& ctx,
-                                   const std::vector<NDArray>& inputs,
-                                   const std::vector<OpReqType>& req,
-                                   const std::vector<NDArray>& outputs) {
-          CustomFComputeDispatcher(name_str, nullptr, nullptr, nullptr,
-                                   callFStatefulComp, 0, &state_ptr, ctx, inputs, req, outputs);
-        };
-        gradOp.set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", fstate_backward, plevel);
-        gradOp.set_attr<FStatefulComputeEx>("FStatefulComputeEx<gpu>", fstate_backward, plevel);
-      } else {
-        // for stateless operators
-        if (backward_ctx_map.count("cpu") > 0) {
-          fcomp_t fcomp_back_cpu = backward_ctx_map.at("cpu");
-          auto backward_cpu_lambda = [=](const nnvm::NodeAttrs& attrs,
-                                         const OpContext& ctx,
-                                         const std::vector<NDArray>& inputs,
-                                         const std::vector<OpReqType>& req,
-                                         const std::vector<NDArray>& outputs) {
-            CustomFComputeDispatcher(name_str, callFComp, fcomp_back_cpu, &attrs,
-                                     nullptr, 0, nullptr, ctx, inputs, req, outputs);
-          };
-          gradOp.set_attr<FComputeEx>("FComputeEx<cpu>", backward_cpu_lambda, plevel);
-        }
-        if (backward_ctx_map.count("gpu") > 0) {
-          fcomp_t fcomp_back_gpu = backward_ctx_map.at("gpu");
-          auto backward_gpu_lambda = [=](const nnvm::NodeAttrs& attrs,
-                                         const OpContext& ctx,
-                                         const std::vector<NDArray>& inputs,
-                                         const std::vector<OpReqType>& req,
-                                         const std::vector<NDArray>& outputs) {
-            CustomFComputeDispatcher(name_str, callFComp, fcomp_back_gpu, &attrs,
-                                     nullptr, 0, nullptr, ctx, inputs, req, outputs);
-          };
-          gradOp.set_attr<FComputeEx>("FComputeEx<gpu>", backward_gpu_lambda, plevel);
-        }
-      }
-    }
-    regOp.add_argument("data", "NDArray[]", "Source inputs");
+    registerOp(name, name_str, isSubgraphOp, resc_req, attr_parser, num_inputs, num_outputs,
+               num_inouts, infer_type, infer_shape, infer_storage_type, mutate_inputs,
+               num_subgraph_inputs, infer_subgraph_type, infer_subgraph_shape,
+               infer_subgraph_storage_type, create_opstate, grad_reg, mutate_fp,
+               createop_map, forward_ctx_map, backward_ctx_map, callFComp, callFStatefulComp,
+               msgSize, msgGet);
   }
 }
 
-void registerPartitioners(void *lib, int verbose) {
+void registerPartitioners(void *lib, int verbose, mxnet::ext::msgSize_t msgSize,
+                          mxnet::ext::msgGet_t msgGet) {
+  using namespace mxnet::ext;
+
   // get C type interface functions
   opCallFree_t callFree = get_func<opCallFree_t>(lib, const_cast<char*>(MXLIB_OPCALLFREE_STR));
 
@@ -1035,7 +1227,10 @@ void registerPartitioners(void *lib, int verbose) {
   }
 }
 
-void registerPasses(void *lib, int verbose) {
+void registerPasses(void *lib, int verbose, mxnet::ext::msgSize_t msgSize,
+                    mxnet::ext::msgGet_t msgGet) {
+  using namespace mxnet::ext;
+
   // get C type interface functions
   opCallFree_t callFree = get_func<opCallFree_t>(lib, const_cast<char*>(MXLIB_OPCALLFREE_STR));
 
@@ -1230,17 +1425,18 @@ void registerPasses(void *lib, int verbose) {
       };
 
       char* out_json;
-      CHECK(callGraphPass(pass_fp, in_json.c_str(), &out_json, opt_keys.data(),
-                          opt_vals.data(), opt_keys.size(), pass_name,
-                          arg_names.data(), arg_names.size(), arg_data.data(),
-                          arg_shapes.data(), arg_dims.data(), arg_types.data(),
-                          arg_verIDs.data(), arg_dev_type.data(),
-                          arg_dev_id.data(), aux_names.data(), aux_names.size(),
-                          aux_data.data(), aux_shapes.data(), aux_dims.data(),
-                          aux_types.data(), aux_verIDs.data(),
-                          aux_dev_type.data(), aux_dev_id.data(),
-                          ndarray_malloc, &ndarray_alloc))
-      << "Error calling graph pass for '" << pass_name << "'";
+      int retval = callGraphPass(pass_fp, in_json.c_str(), &out_json, opt_keys.data(),
+                                 opt_vals.data(), opt_keys.size(), pass_name,
+                                 arg_names.data(), arg_names.size(), arg_data.data(),
+                                 arg_shapes.data(), arg_dims.data(), arg_types.data(),
+                                 arg_verIDs.data(), arg_dev_type.data(),
+                                 arg_dev_id.data(), aux_names.data(), aux_names.size(),
+                                 aux_data.data(), aux_shapes.data(), aux_dims.data(),
+                                 aux_types.data(), aux_verIDs.data(),
+                                 aux_dev_type.data(), aux_dev_id.data(),
+                                 ndarray_malloc, &ndarray_alloc);
+      std::string msgs = getExtensionMsgs(msgSize, msgGet);
+      CHECK(retval) << "Error calling graph pass for '" << pass_name << "'" << msgs;
 
       std::string out_string(out_json);
       nnvm::Graph out_graph = nnvm::pass::LoadJSON(out_string);
@@ -1271,21 +1467,31 @@ int MXLoadLib(const char *path, unsigned verbose) {
     LOG(FATAL) << "Unable to load library";
 
   // check that library and MXNet use same version of library API
-  opVersion_t opVersion = get_func<opVersion_t>(lib, const_cast<char*>(MXLIB_OPVERSION_STR));
+  mxnet::ext::opVersion_t opVersion =
+    get_func<mxnet::ext::opVersion_t>(lib, const_cast<char*>(MXLIB_OPVERSION_STR));
   int libVersion =  opVersion();
   if (MX_LIBRARY_VERSION != libVersion)
     LOG(FATAL) << "Library version (" << libVersion << ") does not match MXNet version ("
                << MX_LIBRARY_VERSION << ")";
 
+  // get error messaging APIs
+  mxnet::ext::msgSize_t msgSize =
+    get_func<mxnet::ext::msgSize_t>(lib, const_cast<char*>(MXLIB_MSGSIZE_STR));
+  mxnet::ext::msgGet_t msgGet =
+    get_func<mxnet::ext::msgGet_t>(lib, const_cast<char*>(MXLIB_MSGGET_STR));
+
   // initialize library by passing MXNet version
-  initialize_t initialize = get_func<initialize_t>(lib, const_cast<char*>(MXLIB_INITIALIZE_STR));
-  if (!initialize(static_cast<int>(MXNET_VERSION)))
-    LOG(FATAL) << "Library failed to initialize";
+  mxnet::ext::initialize_t initialize =
+    get_func<mxnet::ext::initialize_t>(lib, const_cast<char*>(MXLIB_INITIALIZE_STR));
+  if (!initialize(static_cast<int>(MXNET_VERSION))) {
+    std::string msgs = getExtensionMsgs(msgSize, msgGet);
+    LOG(FATAL) << "Library failed to initialize" << msgs;
+  }
 
   // find ops, partitioners, and passes in library
-  registerOperators(lib, verbose);
-  registerPartitioners(lib, verbose);
-  registerPasses(lib, verbose);
+  registerOperators(lib, verbose, msgSize, msgGet);
+  registerPartitioners(lib, verbose, msgSize, msgGet);
+  registerPasses(lib, verbose, msgSize, msgGet);
   API_END();
 }
 
diff --git a/src/operator/subgraph/partitioner/custom_subgraph_property.h b/src/operator/subgraph/partitioner/custom_subgraph_property.h
index ea721c5..b936b05 100644
--- a/src/operator/subgraph/partitioner/custom_subgraph_property.h
+++ b/src/operator/subgraph/partitioner/custom_subgraph_property.h
@@ -49,12 +49,12 @@ namespace op {
 class CustomContainOpSelector: public SubgraphSelector {
  public:
   explicit CustomContainOpSelector(std::unordered_map<std::string, int> supported_nodes,
-                                   void* sel_inst, partCallSelect_t callSelect,
-                                   partCallSelectInput_t callSelectInput,
-                                   partCallSelectOutput_t callSelectOutput,
-                                   partCallFilter_t callFilter,
-                                   partCallReset_t callReset,
-                                   opCallFree_t callFree,
+                                   void* sel_inst, mxnet::ext::partCallSelect_t callSelect,
+                                   mxnet::ext::partCallSelectInput_t callSelectInput,
+                                   mxnet::ext::partCallSelectOutput_t callSelectOutput,
+                                   mxnet::ext::partCallFilter_t callFilter,
+                                   mxnet::ext::partCallReset_t callReset,
+                                   mxnet::ext::opCallFree_t callFree,
                                    std::unordered_map<const nnvm::Node*, unsigned> node2id) :
   supported_nodes_(supported_nodes), sel_inst_(sel_inst), callSelect_(callSelect),
     callSelectInput_(callSelectInput), callSelectOutput_(callSelectOutput),
@@ -123,12 +123,12 @@ class CustomContainOpSelector: public SubgraphSelector {
 
   std::unordered_map<std::string, int> supported_nodes_;
   void* sel_inst_;
-  partCallSelect_t callSelect_;
-  partCallSelectInput_t callSelectInput_;
-  partCallSelectOutput_t callSelectOutput_;
-  partCallFilter_t callFilter_;
-  partCallReset_t callReset_;
-  opCallFree_t callFree_;
+  mxnet::ext::partCallSelect_t callSelect_;
+  mxnet::ext::partCallSelectInput_t callSelectInput_;
+  mxnet::ext::partCallSelectOutput_t callSelectOutput_;
+  mxnet::ext::partCallFilter_t callFilter_;
+  mxnet::ext::partCallReset_t callReset_;
+  mxnet::ext::opCallFree_t callFree_;
   std::unordered_map<const nnvm::Node*, unsigned> node2id_;
 };
 
@@ -155,18 +155,18 @@ class  CustomSubgraphProperty: public SubgraphProperty {
     review_subgraph_(nullptr),
     subgraph_op_name("error") {}
   CustomSubgraphProperty(std::string subgraph_prop_name,
-                         partCallSupportedOps_t call_supported_ops,
-                         supportedOps_t supported_ops,
-                         partCallCreateSelector_t call_create_selector,
-                         createSelector_t create_selector,
-                         partCallSelect_t callSelect,
-                         partCallSelectInput_t callSelectInput,
-                         partCallSelectOutput_t callSelectOutput,
-                         partCallFilter_t callFilter,
-                         partCallReset_t callReset,
-                         partCallReviewSubgraph_t call_review_subgraph,
-                         reviewSubgraph_t review_subgraph,
-                         opCallFree_t call_free,
+                         mxnet::ext::partCallSupportedOps_t call_supported_ops,
+                         mxnet::ext::supportedOps_t supported_ops,
+                         mxnet::ext::partCallCreateSelector_t call_create_selector,
+                         mxnet::ext::createSelector_t create_selector,
+                         mxnet::ext::partCallSelect_t callSelect,
+                         mxnet::ext::partCallSelectInput_t callSelectInput,
+                         mxnet::ext::partCallSelectOutput_t callSelectOutput,
+                         mxnet::ext::partCallFilter_t callFilter,
+                         mxnet::ext::partCallReset_t callReset,
+                         mxnet::ext::partCallReviewSubgraph_t call_review_subgraph,
+                         mxnet::ext::reviewSubgraph_t review_subgraph,
+                         mxnet::ext::opCallFree_t call_free,
                          std::string op_name) :
       subgraph_prop(subgraph_prop_name),
       call_supported_ops_(call_supported_ops),
@@ -429,7 +429,7 @@ class  CustomSubgraphProperty: public SubgraphProperty {
           if (e.node->attrs.dict.count(MX_STR_SHAPE) > 0) {
             std::string& shape = e.node->attrs.dict[MX_STR_SHAPE];
             // add this shape to the list
-            ss << getShapeAt(shape, e.index);
+            ss << mxnet::ext::getShapeAt(shape, e.index);
           }
           if (i < sym.outputs.size()-1)
             ss << ",";
@@ -446,7 +446,7 @@ class  CustomSubgraphProperty: public SubgraphProperty {
           if (e.node->attrs.dict.count(MX_STR_DTYPE) > 0) {
             std::string& dtype = e.node->attrs.dict[MX_STR_DTYPE];
             // add this dtype to the list
-            ss << getDtypeAt(dtype, e.index);
+            ss << mxnet::ext::getDtypeAt(dtype, e.index);
           }
           if (i < sym.outputs.size()-1)
             ss << ",";
@@ -489,7 +489,7 @@ class  CustomSubgraphProperty: public SubgraphProperty {
         // get dtype string from other node
         std::string& dtype = orig.node->attrs.dict[MX_STR_DTYPE];
         std::stringstream ss;
-        ss << "[" << getDtypeAt(dtype, orig.index) << "]";
+        ss << "[" << mxnet::ext::getDtypeAt(dtype, orig.index) << "]";
         e->node->attrs.dict[MX_STR_DTYPE] = ss.str();
       }
 
@@ -498,7 +498,7 @@ class  CustomSubgraphProperty: public SubgraphProperty {
         std::string& shape = orig.node->attrs.dict[MX_STR_SHAPE];
         // create new shape string for this node
         std::stringstream ss;
-        ss << "[" << getShapeAt(shape, orig.index) << "]";
+        ss << "[" << mxnet::ext::getShapeAt(shape, orig.index) << "]";
         e->node->attrs.dict[MX_STR_SHAPE] = ss.str();
       }
     }
@@ -512,18 +512,18 @@ class  CustomSubgraphProperty: public SubgraphProperty {
   }
 
   std::string subgraph_prop;
-  partCallSupportedOps_t call_supported_ops_;
-  supportedOps_t supported_ops_;
-  partCallCreateSelector_t call_create_selector_;
-  createSelector_t create_selector_;
-  partCallSelect_t callSelect_;
-  partCallSelectInput_t callSelectInput_;
-  partCallSelectOutput_t callSelectOutput_;
-  partCallFilter_t callFilter_;
-  partCallReset_t callReset_;
-  partCallReviewSubgraph_t call_review_subgraph_;
-  reviewSubgraph_t review_subgraph_;
-  opCallFree_t call_free_;
+  mxnet::ext::partCallSupportedOps_t call_supported_ops_;
+  mxnet::ext::supportedOps_t supported_ops_;
+  mxnet::ext::partCallCreateSelector_t call_create_selector_;
+  mxnet::ext::createSelector_t create_selector_;
+  mxnet::ext::partCallSelect_t callSelect_;
+  mxnet::ext::partCallSelectInput_t callSelectInput_;
+  mxnet::ext::partCallSelectOutput_t callSelectOutput_;
+  mxnet::ext::partCallFilter_t callFilter_;
+  mxnet::ext::partCallReset_t callReset_;
+  mxnet::ext::partCallReviewSubgraph_t call_review_subgraph_;
+  mxnet::ext::reviewSubgraph_t review_subgraph_;
+  mxnet::ext::opCallFree_t call_free_;
   std::unordered_map<std::string, int> supported_nodes;
   std::string subgraph_op_name;
   std::vector<std::pair<std::string, std::string>> options_map_;