You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by jx...@apache.org on 2017/11/16 23:12:01 UTC

[incubator-mxnet] branch master updated: Support sparse for custom python operators (#8620)

This is an automated email from the ASF dual-hosted git repository.

jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 938eda9  Support sparse for custom python operators (#8620)
938eda9 is described below

commit 938eda901e10e87d469fd3cff43091060a6e1033
Author: Anirudh Subramanian <an...@gmail.com>
AuthorDate: Thu Nov 16 15:11:55 2017 -0800

    Support sparse for custom python operators (#8620)
    
    * Add `asscipy` support and COO format support
    
    * Fix comment misalignment
    
    * Add documentation for sparse NDArray
    
    * Change comment
    
    * Adding comments and support for dtype
    
    * Modifying tests
    
    * Add spsp None check
    
    * Fix lint
    
    * Custom operators for sparse
    
    * Use DISPATCH_MODE_ASSIGN_CHECK
    
    * Change NDArray to _ndarray_cls
    
    * Remove redundant code
    
    * Add a test to make sure the NDArray is an instance of CSRNDArray
    
    * Fix lint
    
    * Fix test
    
    * Trigger CI
---
 include/mxnet/c_api.h                  |   8 +-
 python/mxnet/operator.py               | 161 +++++++++++++++++++++++++++++----
 src/operator/custom/custom.cc          | 100 ++++++++++++++++++--
 tests/python/unittest/test_operator.py |  31 +++++--
 4 files changed, 266 insertions(+), 34 deletions(-)

diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h
index 8ea2b0e..0726566 100644
--- a/include/mxnet/c_api.h
+++ b/include/mxnet/c_api.h
@@ -146,7 +146,9 @@ enum CustomOpPropCallbacks {
   kCustomOpPropInferShape,
   kCustomOpPropDeclareBackwardDependency,
   kCustomOpPropCreateOperator,
-  kCustomOpPropInferType
+  kCustomOpPropInferType,
+  kCustomOpPropInferStorageType,
+  kCustomOpPropBackwardInferStorageType
 };
 
 
@@ -158,6 +160,10 @@ typedef int (*CustomOpListFunc)(char*** /*args*/, void* /*state*/);
 typedef int (*CustomOpInferShapeFunc)(int /*num_input*/, int* /*ndims*/,
                                       unsigned** /*shapes*/, void* /*state*/);
 typedef int (*CustomOpInferTypeFunc)(int /*num_input*/, int* /*types*/, void* /*state*/);
+typedef int (*CustomOpInferStorageTypeFunc)(int /*num_input*/, int* /*stypes*/, void* /*state*/);
+typedef int (*CustomOpBackwardInferStorageTypeFunc)(int /*num_input*/,
+                                                    int * /*stypes*/,
+                                                    void * /*state*/);
 typedef int (*CustomOpBwdDepFunc)(const int* /*out_grad*/, const int* /*in_data*/,
                                   const int* /*out_data*/, int* /*num_deps*/,
                                   int** /*rdeps*/, void* /*state*/);
diff --git a/python/mxnet/operator.py b/python/mxnet/operator.py
index 1337bbc..8fcf127 100644
--- a/python/mxnet/operator.py
+++ b/python/mxnet/operator.py
@@ -16,7 +16,7 @@
 # under the License.
 
 # coding: utf-8
-# pylint: disable=invalid-name, protected-access, too-many-arguments, no-self-use, too-many-locals, broad-except
+# pylint: disable=invalid-name, protected-access, too-many-arguments, no-self-use, too-many-locals, broad-except, too-many-lines
 """numpy interface for operators."""
 from __future__ import absolute_import
 
@@ -30,6 +30,9 @@ from .base import _LIB, check_call, MXCallbackList
 from .base import c_array, c_str, mx_uint, mx_float, ctypes2numpy_shared, NDArrayHandle, py_str
 from . import symbol, context
 from .ndarray import NDArray, _DTYPE_NP_TO_MX, _DTYPE_MX_TO_NP
+from .ndarray.ndarray import _STORAGE_TYPE_STR_TO_ID, _STORAGE_TYPE_ID_TO_STR
+from .ndarray import _ndarray_cls
+
 
 c_int_p = POINTER(c_int)
 
@@ -513,6 +516,51 @@ class CustomOpProp(object):
         return in_type, [in_type[0]]*len(self.list_outputs()), \
             [in_type[0]]*len(self.list_auxiliary_states())
 
+    def infer_storage_type(self, in_stype):
+        """infer_storage_type interface. Used to infer storage type of
+        inputs and outputs in the forward pass.
+
+        Parameters
+        ----------
+        in_stype : list of stypes, Valid stypes are default, row_sparse and
+            csr
+
+        Returns
+        -------
+        in_stype : list
+            list of argument stypes.
+        out_stype : list
+            list of output types calculated from in_stype,
+            in the same order as declared in list_outputs.
+        aux_type : Optional, list
+            list of aux types calculated from in_stype,
+            in the same order as declared in list_auxiliary_states.
+        """
+        return in_stype, [in_stype[0]]*len(self.list_outputs()), \
+            [in_stype[0]]*len(self.list_auxiliary_states())
+
+    def infer_storage_type_backward(self, in_stype):
+        """infer_storage_type_backward interface. Used to infer storage
+        type of inputs and outputs in the backward pass.
+
+        Parameters
+        ----------
+        in_stype : list of stypes. Provide the in_stypes in the
+        following order: output_grads, in_data, out_data, aux_data(optional)
+
+        Returns
+        -------
+        in_stype : list
+            list of input stypes.
+        out_stype : list
+            list of output stypes calculated from in_stype.
+        aux_stype : list
+            list of aux stypes calculated from in_stype,
+            in the same order as declared in list_auxiliary_states.
+        """
+        return in_stype, [in_stype[0]]*len(self.list_outputs()), \
+            [in_stype[0]]*len(self.list_auxiliary_states())
+
     def list_outputs(self):
         """list_outputs interface. Can override when creating new operators.
 
@@ -601,6 +649,8 @@ def register(reg_name):
         infershape_functype = CFUNCTYPE(c_int, c_int, POINTER(c_int),
                                         POINTER(POINTER(mx_uint)), c_void_p)
         infertype_functype = CFUNCTYPE(c_int, c_int, POINTER(c_int), c_void_p)
+        inferstorage_functype = CFUNCTYPE(c_int, c_int, POINTER(c_int), c_void_p)
+        inferstorage_backward_functype = CFUNCTYPE(c_int, c_int, POINTER(c_int), c_void_p)
         list_functype = CFUNCTYPE(c_int, POINTER(POINTER(POINTER(c_char))), c_void_p)
         deps_functype = CFUNCTYPE(c_int, c_int_p, c_int_p, c_int_p,
                                   c_int_p, POINTER(c_int_p), c_void_p)
@@ -654,6 +704,81 @@ def register(reg_name):
                     return False
                 return True
 
+            def infer_storage_type_backward_entry(num_tensor, tensor_stypes, _):
+                """C Callback for CustomOpProp::InferStorageTypeBackward"""
+                try:
+                    n_in = len(op_prop.list_arguments())
+                    n_out = len(op_prop.list_outputs())
+                    n_aux = len(op_prop.list_auxiliary_states())
+                    total_inputs = n_in + 2 * n_out
+                    total_aux = n_aux
+                    total_outputs = n_in
+                    assert num_tensor == (2 * n_in + 2 * n_out + n_aux)
+
+                    stypes = [_STORAGE_TYPE_ID_TO_STR[tensor_stypes[i]] \
+                             for i in range(total_inputs + total_aux)]
+                    ret = op_prop.infer_storage_type_backward(stypes)
+                    if len(ret) == 2:
+                        istype, ostype = ret
+                        astype = []
+                    elif len(ret) == 3:
+                        istype, ostype, astype = ret
+                    else:
+                        raise AssertionError("infer_storage_type backward must return 2 or 3 lists")
+                    assert len(ostype) == total_outputs, \
+                        "InferStorageTypeBackward Error: expecting %d entries in returned output " \
+                        "stypes, got %d."%(total_outputs, len(ostype))
+                    assert len(istype) == (total_inputs), \
+                        "InferStorageTypeBackward Error: expecting %d entries in returned output " \
+                        "stypes, got %d."%(total_inputs, len(istype))
+                    rtype = list(istype) + list(ostype) + list(astype)
+                    for i, dtype in enumerate(rtype):
+                        tensor_stypes[i] = _STORAGE_TYPE_STR_TO_ID[dtype]
+                    infer_storage_type_backward_entry._ref_holder = [tensor_stypes]
+                except Exception:
+                    print('Error in %s.infer_type: %s' % (reg_name, traceback.format_exc()))
+                    return False
+                return True
+
+
+            def infer_storage_type_entry(num_tensor, tensor_stypes, _):
+                """C Callback for CustomOpProp::InferStorageType"""
+                try:
+                    n_in = len(op_prop.list_arguments())
+                    n_out = len(op_prop.list_outputs())
+                    n_aux = len(op_prop.list_auxiliary_states())
+                    assert num_tensor == n_in + n_out + n_aux
+
+                    stypes = [_STORAGE_TYPE_ID_TO_STR[tensor_stypes[i]] for i in range(n_in)]
+                    ret = op_prop.infer_storage_type(stypes)
+                    if len(ret) == 2:
+                        istype, ostype = ret
+                        astype = []
+                    elif len(ret) == 3:
+                        istype, ostype, astype = ret
+                    else:
+                        raise AssertionError("infer_storage_type must return 2 or 3 lists")
+
+                    assert len(ostype) == n_out, \
+                        "InferStorageType Error: expecting %d entries in returned output " \
+                        "stypes, got %d."%(n_out, len(ostype))
+                    assert len(istype) == n_in, \
+                        "InferStorageType Error: expecting %d entries in returned input " \
+                        "stypes, got %d."%(n_in, len(istype))
+                    assert len(astype) == n_aux, \
+                        "InferStorageType Error: expecting %d entries in returned aux state " \
+                        "stypes, got %d."%(n_aux, len(astype))
+                    rtype = list(istype) + list(ostype) + list(astype)
+                    for i, dtype in enumerate(rtype):
+                        tensor_stypes[i] = _STORAGE_TYPE_STR_TO_ID[dtype]
+
+                    infer_storage_type_entry._ref_holder = [tensor_stypes]
+                except Exception:
+                    print('Error in %s.infer_type: %s' % (reg_name, traceback.format_exc()))
+                    return False
+                return True
+
+
             def infer_type_entry(num_tensor, tensor_types, _):
                 """C Callback for CustomOpProp::InferType"""
                 try:
@@ -673,13 +798,13 @@ def register(reg_name):
                         raise AssertionError("infer_type must return 2 or 3 lists")
                     assert len(otype) == n_out, \
                         "InferType Error: expecting %d entries in returned output " \
-                        "shapes, got %d."%(n_out, len(otype))
+                        "types, got %d."%(n_out, len(otype))
                     assert len(itype) == n_in, \
                         "InferType Error: expecting %d entries in returned input " \
-                        "shapes, got %d."%(n_in, len(itype))
+                        "types, got %d."%(n_in, len(itype))
                     assert len(atype) == n_aux, \
                         "InferType Error: expecting %d entries in returned aux state " \
-                        "shapes, got %d."%(n_aux, len(atype))
+                        "types, got %d."%(n_aux, len(atype))
                     rtype = list(itype) + list(otype) + list(atype)
                     for i, dtype in enumerate(rtype):
                         tensor_types[i] = _DTYPE_NP_TO_MX[dtype]
@@ -768,13 +893,13 @@ def register(reg_name):
                             tensors = [[] for i in range(5)]
                             for i in range(num_ndarray):
                                 if tags[i] == 1 or tags[i] == 4:
-                                    tensors[tags[i]].append(NDArray(cast(ndarraies[i],
-                                                                         NDArrayHandle),
-                                                                    writable=True))
+                                    tensors[tags[i]].append(_ndarray_cls(cast(ndarraies[i],
+                                                                              NDArrayHandle),
+                                                                         writable=True))
                                 else:
-                                    tensors[tags[i]].append(NDArray(cast(ndarraies[i],
-                                                                         NDArrayHandle),
-                                                                    writable=False))
+                                    tensors[tags[i]].append(_ndarray_cls(cast(ndarraies[i],
+                                                                              NDArrayHandle),
+                                                                         writable=False))
                             reqs = [req_enum[reqs[i]] for i in range(len(tensors[1]))]
                             with ctx:
                                 op.forward(is_train=is_train, req=reqs,
@@ -792,13 +917,13 @@ def register(reg_name):
                             tensors = [[] for i in range(5)]
                             for i in range(num_ndarray):
                                 if tags[i] == 2 or tags[i] == 4:
-                                    tensors[tags[i]].append(NDArray(cast(ndarraies[i],
-                                                                         NDArrayHandle),
-                                                                    writable=True))
+                                    tensors[tags[i]].append(_ndarray_cls(cast(ndarraies[i],
+                                                                              NDArrayHandle),
+                                                                         writable=True))
                                 else:
-                                    tensors[tags[i]].append(NDArray(cast(ndarraies[i],
-                                                                         NDArrayHandle),
-                                                                    writable=False))
+                                    tensors[tags[i]].append(_ndarray_cls(cast(ndarraies[i],
+                                                                              NDArrayHandle),
+                                                                         writable=False))
                             reqs = [req_enum[reqs[i]] for i in range(len(tensors[2]))]
                             with ctx:
                                 op.backward(req=reqs,
@@ -856,7 +981,9 @@ def register(reg_name):
                          infershape_functype(infer_shape_entry),
                          deps_functype(declare_backward_dependency_entry),
                          createop_functype(create_operator_entry),
-                         infertype_functype(infer_type_entry)]
+                         infertype_functype(infer_type_entry),
+                         inferstorage_functype(infer_storage_type_entry),
+                         inferstorage_backward_functype(infer_storage_type_backward_entry)]
             callbacks = [cast(i, CFUNCTYPE(c_int)) for i in callbacks]
             contexts = [None]*len(callbacks)
             ret[0] = MXCallbackList(c_int(len(callbacks)),
diff --git a/src/operator/custom/custom.cc b/src/operator/custom/custom.cc
index 683423f..5e35e90 100644
--- a/src/operator/custom/custom.cc
+++ b/src/operator/custom/custom.cc
@@ -350,20 +350,100 @@ void Backward(const OpStatePtr& state,
   Imperative::Get()->set_is_recording(prev_recording);
 }
 
+inline bool BackwardInferStorageType(const nnvm::NodeAttrs& attrs,
+                                     const int dev_mask,
+                                     DispatchMode* dispatch_mode,
+                                     std::vector<int>* iattr,
+                                     std::vector<int>* oattr) {
+  const CustomParam& params = nnvm::get<CustomParam>(attrs.parsed);
+
+  if (params.info->num_callbacks <= kCustomOpPropBackwardInferStorageType) {
+    for (size_t i = 0; i < iattr->size(); i++) {
+      STORAGE_TYPE_ASSIGN_CHECK(*iattr, i, kDefaultStorage);
+    }
+    for (size_t i = 0; i < oattr->size(); i++) {
+      STORAGE_TYPE_ASSIGN_CHECK(*oattr, i, kDefaultStorage);
+    }
+    DISPATCH_MODE_ASSIGN_CHECK(dispatch_mode, 0, DispatchMode::kFComputeEx);
+    return true;
+  }
+
+  std::vector<int> stypes;
+  stypes.reserve(params.num_outs * 2 + params.num_args * 2 + params.num_auxs);
+  for (size_t i = 0; i < iattr->size(); ++i) {
+    stypes.push_back((*iattr)[i]);
+  }
+  for (size_t i = 0; i < oattr->size(); ++i) {
+    stypes.push_back((*oattr)[i]);
+  }
+
+  CHECK(reinterpret_cast<CustomOpBackwardInferStorageTypeFunc>(
+      params.info->callbacks[kCustomOpPropBackwardInferStorageType])(
+      stypes.size(), stypes.data(),
+      params.info->contexts[kCustomOpPropBackwardInferStorageType]));
+  for (size_t i = 0; i < 2 * params.num_outs + params.num_args; ++i) {
+    STORAGE_TYPE_ASSIGN_CHECK(*iattr, i, stypes[i]);
+  }
+  for (size_t i = 0; i < params.num_args; ++i) {
+    STORAGE_TYPE_ASSIGN_CHECK(
+        *oattr, i, stypes[i + 2 * params.num_outs + params.num_args]);
+  }
+  for (size_t i = 0; i < params.num_auxs; ++i) {
+    STORAGE_TYPE_ASSIGN_CHECK(
+        *iattr, i + 2 * params.num_outs + params.num_args,
+        stypes[i + 2 * params.num_outs + 2 * params.num_args]);
+  }
+
+  DISPATCH_MODE_ASSIGN_CHECK(dispatch_mode, 0, DispatchMode::kFComputeEx);
+  return true;
+}
+
 // infer storage function for custom op, which assigns kDefaultStorage for
 // all undefined stypes, and dispatch on DispatchMode::kFComputeEx.
-inline bool InferStorageType(const nnvm::NodeAttrs& attrs,
-                             const int dev_mask,
+inline bool InferStorageType(const nnvm::NodeAttrs& attrs, const int dev_mask,
                              DispatchMode* dispatch_mode,
-                             std::vector<int> *iattr,
-                             std::vector<int> *oattr) {
-  for (int& v : *oattr) {
-    if (v == -1) v = kDefaultStorage;
+                             std::vector<int>* iattr, std::vector<int>* oattr) {
+  const CustomParam& params = nnvm::get<CustomParam>(attrs.parsed);
+
+  if (params.info->num_callbacks <= kCustomOpPropInferStorageType) {
+    for (size_t i = 0; i < iattr->size(); i++) {
+      STORAGE_TYPE_ASSIGN_CHECK(*iattr, i, kDefaultStorage);
+    }
+    for (size_t i = 0; i < oattr->size(); i++) {
+      STORAGE_TYPE_ASSIGN_CHECK(*oattr, i, kDefaultStorage);
+    }
+    DISPATCH_MODE_ASSIGN_CHECK(dispatch_mode, 0, DispatchMode::kFComputeEx);
+    return true;
   }
-  for (int& v : *iattr) {
-    if (v == -1) v = kDefaultStorage;
+
+  std::vector<int> stypes;
+  stypes.reserve(params.num_args + params.num_outs + params.num_auxs);
+  for (size_t i = 0; i < params.num_args; ++i) {
+    stypes.push_back((*iattr)[i]);
+  }
+  for (const auto& i : *oattr) {
+    stypes.push_back(i);
+  }
+  for (size_t i = 0; i < params.num_auxs; ++i) {
+    stypes.push_back((*iattr)[params.num_args + i]);
   }
-  dispatch_mode_assign(dispatch_mode, DispatchMode::kFComputeEx);
+
+  CHECK(reinterpret_cast<CustomOpInferStorageTypeFunc>(
+      params.info->callbacks[kCustomOpPropInferStorageType])(
+      stypes.size(), stypes.data(),
+      params.info->contexts[kCustomOpPropInferStorageType]));
+  for (size_t i = 0; i < params.num_args; ++i) {
+    STORAGE_TYPE_ASSIGN_CHECK(*iattr, i, stypes[i]);
+  }
+  for (size_t i = 0; i < params.num_outs; ++i) {
+    STORAGE_TYPE_ASSIGN_CHECK(*oattr, i, stypes[params.num_args + i]);
+  }
+  for (size_t i = 0; i < params.num_auxs; ++i) {
+    STORAGE_TYPE_ASSIGN_CHECK(*iattr, params.num_args + i,
+                              stypes[params.num_args + params.num_outs + i]);
+  }
+
+  DISPATCH_MODE_ASSIGN_CHECK(dispatch_mode, 0, DispatchMode::kFComputeEx);
   return true;
 }
 
@@ -429,7 +509,7 @@ NNVM_REGISTER_OP(_backward_Custom)
   })
 .set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", Backward)
 .set_attr<FStatefulComputeEx>("FStatefulComputeEx<gpu>", Backward)
-.set_attr<FInferStorageType>("FInferStorageType", InferStorageType);
+.set_attr<FInferStorageType>("FInferStorageType", BackwardInferStorageType);
 
 }  // namespace custom
 }  // namespace op
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 3484b18..d322fa4 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -3570,12 +3570,18 @@ def test_rcbrt_op():
 def test_custom_op():
     class Sqr(mx.operator.CustomOp):
         def forward(self, is_train, req, in_data, out_data, aux):
-            self.assign(out_data[0], req[0], in_data[0]*in_data[0])
-            aux[0][:] = 1
+            if in_data[0].stype == 'default':
+                aux[0][:] = 1
+                self.assign(out_data[0], req[0], in_data[0]*in_data[0])
+            else:
+                self.assign(out_data[0], req[0], mx.nd.sparse.square(in_data[0]))
+                if in_data[0].stype == 'csr':
+                    assert(isinstance(in_data[0], mx.nd.sparse.CSRNDArray))
 
         def backward(self, req, out_grad, in_data, out_data, in_grad, aux):
             self.assign(in_grad[0], req[0], 2*in_data[0]*out_grad[0])
-            assert (aux[0].asnumpy() == 1).all()
+            if in_data[0].stype == 'default':
+                assert (aux[0].asnumpy() == 1).all()
 
     @mx.operator.register("sqr")
     class SqrProp(mx.operator.CustomOpProp):
@@ -3597,6 +3603,16 @@ def test_custom_op():
         def infer_type(self, in_type):
             return in_type, [in_type[0]], [in_type[0]]
 
+        def infer_storage_type(self, in_stype):
+            if in_stype[0] == 'default':
+                return ['default'], ['default'], ['default']
+            return ['csr'], ['csr'], ['csr']
+
+        def infer_storage_type_backward(self, in_stype):
+            if in_stype[1] == 'default':
+                return ['default', 'default', 'default'], ['default'], ['default']
+            return ['default', 'csr', 'csr'], ['csr'], ['csr']
+
         def create_operator(self, ctx, shapes, dtypes):
             return Sqr()
 
@@ -3609,15 +3625,18 @@ def test_custom_op():
 
     data = mx.symbol.cast(data, dtype='float64')
     op = mx.symbol.cast(op, dtype='float32')
-    x = mx.nd.array(np.random.uniform(-1, 1, size=(4, 10)))
-    aux = mx.nd.zeros_like(x)
     check_numeric_gradient(op, [x], [aux])
 
+    x = x.tostype('csr')
+    aux = mx.nd.zeros_like(x)
     x.attach_grad()
     with mx.contrib.autograd.train_section():
         y = mx.nd.Custom(x, aux, op_type='sqr')
         y.backward()
-
+    mx.nd.waitall()
+    assert (x.grad.stype == 'csr')
+    assert (y.stype == 'csr')
+    assert (aux.stype == 'csr')
 
 def test_psroipooling():
     for num_rois in [1, 2]:

-- 
To stop receiving notification emails like this one, please contact
commits@mxnet.apache.org.