Posted to commits@mxnet.apache.org by sa...@apache.org on 2020/11/09 17:42:50 UTC

[incubator-mxnet] branch master updated: Extension bug fixes (#19469)

This is an automated email from the ASF dual-hosted git repository.

samskalicky pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 564c6d3  Extension bug fixes (#19469)
564c6d3 is described below

commit 564c6d307e3439c1e5bb9bbd7e82d6744bea6a83
Author: Sam Skalicky <sa...@gmail.com>
AuthorDate: Mon Nov 9 09:41:45 2020 -0800

    Extension bug fixes (#19469)
    
    * initial commit
    
    * syntax fix
    
    * spacing
    
    * added test case
---
 python/mxnet/gluon/block.py                        | 44 +++++++++----
 src/c_api/c_api.cc                                 | 68 ++++++++++----------
 .../partitioner/custom_subgraph_property.h         | 74 ++++++++++++----------
 tests/python/unittest/test_extensions.py           | 17 +++--
 4 files changed, 120 insertions(+), 83 deletions(-)

diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py
index a13c606..7655796 100644
--- a/python/mxnet/gluon/block.py
+++ b/python/mxnet/gluon/block.py
@@ -1030,14 +1030,35 @@ class HybridBlock(Block):
 
         arg_dict, aux_dict = dict(), dict()
         if self._backend:
-            ctx = args[0].context
+            # set context for inputs
+            _, _, ctx_set, _ = _gather_type_ctx_info(list(args))
+            ctx = ctx_set.pop() if len(ctx_set) > 0 else None
             # get list of params in the order of out.list_arguments
-            arg_dict.update({name:args[data_names[name]] if name in data_names.keys() else params[name].data()
-                             for name in out.list_arguments()})
-            aux_dict.update({name:args[data_names[name]] if name in data_names.keys() else params[name].data()
-                             for name in out.list_auxiliary_states()})
-            # Partition the graph.
-            out = out.optimize_for(self._backend, arg_dict, aux_dict, ctx, **self._backend_opts)
+            input_shapes = dict()
+            for name in out.list_arguments():
+                if name in data_names.keys() and data_names[name] < len(args):
+                    if isinstance(args[data_names[name]], NDArray):
+                        arg_dict[name] = args[data_names[name]]
+                    elif (isinstance(args[data_names[name]], symbol.Symbol) and
+                          '__shape__' in args[data_names[name]].list_attr()):
+                        shape_str = args[data_names[name]].list_attr()['__shape__']
+                        input_shapes[name] = tuple(map(int, shape_str.strip('()').split(',')))
+                elif name in params:
+                    arg_dict[name] = params[name].data()
+
+            for name in out.list_auxiliary_states():
+                if name in data_names.keys() and data_names[name] < len(args):
+                    if isinstance(args[data_names[name]], NDArray):
+                        aux_dict[name] = args[data_names[name]]
+                    elif (isinstance(args[data_names[name]], symbol.Symbol) and
+                          '__shape__' in args[data_names[name]].list_attr()):
+                        shape_str = args[data_names[name]].list_attr()['__shape__']
+                        input_shapes[name] = tuple(map(int, shape_str.strip('()').split(',')))
+                elif name in params:
+                    aux_dict[name] = params[name].data()
+
+            # Partition the graph
+            out = out.optimize_for(self._backend, arg_dict, aux_dict, ctx, input_shapes, **self._backend_opts)
 
             # convert to numpy symbol if needed
             if _mx_npx.is_np_array():
@@ -1080,7 +1101,7 @@ class HybridBlock(Block):
                     param = Parameter(name, dtype=param_data.dtype)
                     param._var_name = name
                     serialization_name = name  # HybridBlock.export
-                    param._load_init(param_data, args[0].context)
+                    param._load_init(param_data, param_data.context)
                 triple = (False, serialization_name, param)
 
             self._cached_op_args.append(triple)
@@ -1182,14 +1203,11 @@ class HybridBlock(Block):
 
         # do part of forward API call
         has_symbol, has_ndarray, ctx_set, _ = _gather_type_ctx_info([x] + list(args))
-        if has_symbol:
-            raise ValueError('Inputs must be NDArrays for the optimize_for API'
-                             ' Please check the type of the args.\n')
         if not has_symbol and not has_ndarray:
-            raise ValueError('In HybridBlock, there must be one NDArray as input.'
+            raise ValueError('In HybridBlock, there must be one NDArray or one Symbol in the input.'
                              ' Please check the type of the args.\n')
         if len(ctx_set) > 1:
-            raise ValueError('Find multiple contexts in the input, '
+            raise ValueError('Found multiple contexts in the input, '
                              'After hybridized, the HybridBlock only supports one input '
                              'context. You can print the ele.ctx in the '
                              'input arguments to inspect their contexts. '
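[Editor's note] The net effect of the block.py changes: optimize_for now accepts mx.sym.var placeholders alongside concrete NDArrays, reads each placeholder's '__shape__' attribute, and forwards the parsed shapes to the backend through the new input_shapes argument instead of binding real data. A minimal standalone sketch of just the shape-string parsing step (no MXNet required; parse_shape_attr is a hypothetical helper name, not part of the codebase):

    def parse_shape_attr(shape_str):
        # A '__shape__' attribute is stored as a string such as '(3, 2)';
        # strip the parentheses and convert each dimension to an int,
        # exactly as the new block.py code does inline.
        return tuple(map(int, shape_str.strip('()').split(',')))

    assert parse_shape_attr('(3, 2)') == (3, 2)
    assert parse_shape_attr('(3,2)') == (3, 2)
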
diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc
index a02c0a6..959f2e0 100644
--- a/src/c_api/c_api.cc
+++ b/src/c_api/c_api.cc
@@ -1377,50 +1377,54 @@ void registerPasses(void *lib, int verbose, mxnet::ext::msgSize_t msgSize,
 
       // convert input args
       for (size_t i=0; i < in_arg_names.size(); i++) {
-        arg_names.push_back(in_arg_names[i].c_str());
-        const NDArray &in_arg = *(in_args_ptr[i]);
+        if (in_args_ptr[i] != nullptr) {
+          arg_names.push_back(in_arg_names[i].c_str());
+          const NDArray &in_arg = *(in_args_ptr[i]);
 
 #if MXNET_USE_MKLDNN == 1
-        // reorder data if in MKLDNN format
-        if (in_arg.IsMKLDNNData()) {
-          in_arg.Reorder2DefaultAsync();
-          in_arg.WaitToRead();
-        }
+          // reorder data if in MKLDNN format
+          if (in_arg.IsMKLDNNData()) {
+            in_arg.Reorder2DefaultAsync();
+            in_arg.WaitToRead();
+          }
 #endif
 
-        // pull out parts of NDArray to send to backend
-        arg_data.push_back(in_arg.data().dptr_);
-        arg_shapes.push_back(in_arg.shape().data());
-        arg_dims.push_back(in_arg.shape().ndim());
-        arg_types.push_back(in_arg.dtype());
-        arg_verIDs.push_back(in_arg.version());
-        const char* arg_ctx_str = in_arg.ctx().dev_mask() == Context::kCPU ? "cpu" : "gpu";
-        arg_dev_type.push_back(arg_ctx_str);
-        arg_dev_id.push_back(in_arg.ctx().real_dev_id());
+          // pull out parts of NDArray to send to backend
+          arg_data.push_back(in_arg.data().dptr_);
+          arg_shapes.push_back(in_arg.shape().data());
+          arg_dims.push_back(in_arg.shape().ndim());
+          arg_types.push_back(in_arg.dtype());
+          arg_verIDs.push_back(in_arg.version());
+          const char* arg_ctx_str = in_arg.ctx().dev_mask() == Context::kCPU ? "cpu" : "gpu";
+          arg_dev_type.push_back(arg_ctx_str);
+          arg_dev_id.push_back(in_arg.ctx().real_dev_id());
+        }
       }
 
       // convert input aux
       for (size_t i=0; i < in_aux_names.size(); i++) {
-        aux_names.push_back(in_aux_names[i].c_str());
-        const auto &in_aux = *(in_aux_ptr[i]);
+        if (in_aux_ptr[i] != nullptr) {
+          aux_names.push_back(in_aux_names[i].c_str());
+          const auto &in_aux = *(in_aux_ptr[i]);
 
 #if MXNET_USE_MKLDNN == 1
-        // reorder data if in MKLDNN format
-        if (in_aux.IsMKLDNNData()) {
-          in_aux.Reorder2DefaultAsync();
-          in_aux.WaitToRead();
-        }
+          // reorder data if in MKLDNN format
+          if (in_aux.IsMKLDNNData()) {
+            in_aux.Reorder2DefaultAsync();
+            in_aux.WaitToRead();
+          }
 #endif
 
-        // pull out parts of NDArray to send to backend
-        aux_data.push_back(in_aux.data().dptr_);
-        aux_shapes.push_back(in_aux.shape().data());
-        aux_dims.push_back(in_aux.shape().ndim());
-        aux_types.push_back(in_aux.dtype());
-        aux_verIDs.push_back(in_aux.version());
-        const char* aux_ctx_str = in_aux.ctx().dev_mask() == Context::kCPU ? "cpu" : "gpu";
-        aux_dev_type.push_back(aux_ctx_str);
-        aux_dev_id.push_back(in_aux.ctx().real_dev_id());
+          // pull out parts of NDArray to send to backend
+          aux_data.push_back(in_aux.data().dptr_);
+          aux_shapes.push_back(in_aux.shape().data());
+          aux_dims.push_back(in_aux.shape().ndim());
+          aux_types.push_back(in_aux.dtype());
+          aux_verIDs.push_back(in_aux.version());
+          const char* aux_ctx_str = in_aux.ctx().dev_mask() == Context::kCPU ? "cpu" : "gpu";
+          aux_dev_type.push_back(aux_ctx_str);
+          aux_dev_id.push_back(in_aux.ctx().real_dev_id());
+        }
       }
 
       // convert graph to string
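[Editor's note] The registerPasses change is purely defensive: now that shape-only inputs carry no bound data, some entries of in_args_ptr / in_aux_ptr can legitimately be nullptr, so each NDArray is dereferenced only inside a null check. The same filtering pattern sketched in Python terms (collect_bound is a hypothetical name used only for illustration):

    def collect_bound(names, arrays):
        # Keep only inputs that actually have data bound, mirroring the
        # nullptr guard added around each NDArray dereference above.
        return [(name, arr) for name, arr in zip(names, arrays)
                if arr is not None]

    # 'b' is a shape-only placeholder with no bound array, so it is skipped.
    assert collect_bound(['a', 'b'], [1.0, None]) == [('a', 1.0)]
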
diff --git a/src/operator/subgraph/partitioner/custom_subgraph_property.h b/src/operator/subgraph/partitioner/custom_subgraph_property.h
index 49d5a8f..cd41035 100644
--- a/src/operator/subgraph/partitioner/custom_subgraph_property.h
+++ b/src/operator/subgraph/partitioner/custom_subgraph_property.h
@@ -208,26 +208,28 @@ class  CustomSubgraphProperty: public SubgraphProperty {
     arg_dev_type.clear();
     arg_dev_id.clear();
     for (size_t i=0; i < in_arg_names.size(); i++) {
-      arg_names.push_back(in_arg_names[i].c_str());
-      const NDArray &in_arg = *(in_args_ptr[i]);
+      if (in_args_ptr[i] != nullptr) {
+        arg_names.push_back(in_arg_names[i].c_str());
+        const NDArray &in_arg = *(in_args_ptr[i]);
 
 #if MXNET_USE_MKLDNN == 1
-      // reorder data if in MKLDNN format
-      if (in_arg.IsMKLDNNData()) {
-        in_arg.Reorder2DefaultAsync();
-        in_arg.WaitToRead();
-      }
+        // reorder data if in MKLDNN format
+        if (in_arg.IsMKLDNNData()) {
+          in_arg.Reorder2DefaultAsync();
+          in_arg.WaitToRead();
+        }
 #endif
 
-      // pull out parts of NDArray to send to backend
-      arg_data.push_back(in_arg.data().dptr_);
-      arg_shapes.push_back(in_arg.shape().data());
-      arg_dims.push_back(in_arg.shape().ndim());
-      arg_types.push_back(in_arg.dtype());
-      arg_verIDs.push_back(in_arg.version());
-      const char* arg_ctx_str = in_arg.ctx().dev_mask() == Context::kCPU ? "cpu" : "gpu";
-      arg_dev_type.push_back(arg_ctx_str);
-      arg_dev_id.push_back(in_arg.ctx().real_dev_id());
+        // pull out parts of NDArray to send to backend
+        arg_data.push_back(in_arg.data().dptr_);
+        arg_shapes.push_back(in_arg.shape().data());
+        arg_dims.push_back(in_arg.shape().ndim());
+        arg_types.push_back(in_arg.dtype());
+        arg_verIDs.push_back(in_arg.version());
+        const char* arg_ctx_str = in_arg.ctx().dev_mask() == Context::kCPU ? "cpu" : "gpu";
+        arg_dev_type.push_back(arg_ctx_str);
+        arg_dev_id.push_back(in_arg.ctx().real_dev_id());
+      }
     }
 
     // convert input aux
@@ -240,26 +242,28 @@ class  CustomSubgraphProperty: public SubgraphProperty {
     aux_dev_type.clear();
     aux_dev_id.clear();
     for (size_t i=0; i < in_aux_names.size(); i++) {
-      aux_names.push_back(in_aux_names[i].c_str());
-      const auto &in_aux = *(in_aux_ptr[i]);
+      if (in_aux_ptr[i] != nullptr) {
+        aux_names.push_back(in_aux_names[i].c_str());
+        const auto &in_aux = *(in_aux_ptr[i]);
 
 #if MXNET_USE_MKLDNN == 1
-      // reorder data if in MKLDNN format
-      if (in_aux.IsMKLDNNData()) {
-        in_aux.Reorder2DefaultAsync();
-        in_aux.WaitToRead();
-      }
+        // reorder data if in MKLDNN format
+        if (in_aux.IsMKLDNNData()) {
+          in_aux.Reorder2DefaultAsync();
+          in_aux.WaitToRead();
+        }
 #endif
 
-      // pull out parts of NDArray to send to backend
-      aux_data.push_back(in_aux.data().dptr_);
-      aux_shapes.push_back(in_aux.shape().data());
-      aux_dims.push_back(in_aux.shape().ndim());
-      aux_types.push_back(in_aux.dtype());
-      aux_verIDs.push_back(in_aux.version());
-      const char* aux_ctx_str = in_aux.ctx().dev_mask() == Context::kCPU ? "cpu" : "gpu";
-      aux_dev_type.push_back(aux_ctx_str);
-      aux_dev_id.push_back(in_aux.ctx().real_dev_id());
+        // pull out parts of NDArray to send to backend
+        aux_data.push_back(in_aux.data().dptr_);
+        aux_shapes.push_back(in_aux.shape().data());
+        aux_dims.push_back(in_aux.shape().ndim());
+        aux_types.push_back(in_aux.dtype());
+        aux_verIDs.push_back(in_aux.version());
+        const char* aux_ctx_str = in_aux.ctx().dev_mask() == Context::kCPU ? "cpu" : "gpu";
+        aux_dev_type.push_back(aux_ctx_str);
+        aux_dev_id.push_back(in_aux.ctx().real_dev_id());
+      }
     }
 
     // remove all graph attrs, some cannot be saved to json
@@ -285,13 +289,17 @@ class  CustomSubgraphProperty: public SubgraphProperty {
         for (unsigned oid = 0; oid < node->num_outputs(); oid++) {
           const uint32_t out_entry_id = indexed_graph.entry_id(nid, oid);
           mxnet::TShape& shape = shapes[out_entry_id];
-          ss << shape;
+          if (shape.ndim() == -1)
+            ss << "[None]";
+          else
+            ss << shape;
           if (oid < node->num_outputs()-1) ss << ",";
         }
         ss << "]";
         node->attrs.dict[MX_STR_SHAPE] = ss.str();
       }
     }
+
     // set dtype attrs for each node in the graph
     if (g.HasAttr("dtype")) {
       std::vector<int> dtypes = g.GetAttr<std::vector<int> >("dtype");
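[Editor's note] Besides repeating the nullptr guards from c_api.cc, this file fixes shape serialization: a shape with ndim() == -1 (fully unknown) is now written as "[None]" into the node's MX_STR_SHAPE attribute, so the custom partitioner receives something parseable. A rough Python rendering of that rule (format_out_shapes is a hypothetical name; unknown shapes are modeled here as None):

    def format_out_shapes(shapes):
        # Each known shape prints as a bracketed list of dims; an unknown
        # shape (ndim == -1 in the C++ code, None here) prints as '[None]'.
        parts = ['[None]' if s is None else str(list(s)) for s in shapes]
        return '[' + ','.join(parts) + ']'

    assert format_out_shapes([(3, 2), None]) == '[[3, 2],[None]]'
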
diff --git a/tests/python/unittest/test_extensions.py b/tests/python/unittest/test_extensions.py
index 52f9995..8d94680 100644
--- a/tests/python/unittest/test_extensions.py
+++ b/tests/python/unittest/test_extensions.py
@@ -166,16 +166,23 @@ def test_subgraph():
     # check that result matches one executed by MXNet
     assert_almost_equal(out[0].asnumpy(), out4[0].asnumpy(), rtol=1e-3, atol=1e-3)
 
-    # Gluon Hybridize partitioning with shapes/types
+    # Gluon Hybridize partitioning with sym.var
     sym_block2 = nn.SymbolBlock(sym, [a,b])
     sym_block2.initialize()
+    a_var = mx.sym.var('a',shape=(3,2))
+    b_var = mx.sym.var('b',shape=(3,2))
+    sym_block2.optimize_for(a_var, b_var, backend='myProp')
+
+    # Gluon Hybridize partitioning with shapes/types
+    sym_block3 = nn.SymbolBlock(sym, [a,b])
+    sym_block3.initialize()
     a_data = mx.nd.ones((3,2))
     b_data = mx.nd.ones((3,2))
-    sym_block2.optimize_for(a_data, b_data, backend='myProp')
-    sym_block2.export('optimized')
-    sym_block3 = nn.SymbolBlock.imports('optimized-symbol.json',['a','b'],
+    sym_block3.optimize_for(a_data, b_data, backend='myProp')
+    sym_block3.export('optimized')
+    sym_block4 = nn.SymbolBlock.imports('optimized-symbol.json',['a','b'],
                                         'optimized-0000.params')
 
-    out5 = sym_block3(a_data, b_data)
+    out5 = sym_block4(a_data, b_data)
     # check that result matches one executed by MXNet
     assert_almost_equal(out[0].asnumpy(), out5[0].asnumpy(), rtol=1e-3, atol=1e-3)
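
[Editor's note] The new test case exercises the shape-only path end to end: a hybridized block can now be partitioned either with concrete NDArrays or with shape-annotated sym.var placeholders. Condensed into a usage sketch (this assumes the test's 'myProp' custom partitioner library has already been loaded, as done elsewhere in test_extensions.py via mx.library.load):

    import mxnet as mx
    from mxnet.gluon import nn

    # a trivial graph: c = a + b
    a = mx.sym.var('a')
    b = mx.sym.var('b')
    sym = a + b

    block = nn.SymbolBlock(sym, [a, b])
    block.initialize()

    # partition with symbolic placeholders: only shapes are supplied,
    # so no NDArray data needs to be bound at partition time
    a_var = mx.sym.var('a', shape=(3, 2))
    b_var = mx.sym.var('b', shape=(3, 2))
    block.optimize_for(a_var, b_var, backend='myProp')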