You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/01/31 18:45:11 UTC

[GitHub] anirudh2290 closed pull request #8738: Fix custom op - infer_storage_type_backward

anirudh2290 closed pull request #8738: Fix custom op - infer_storage_type_backward
URL: https://github.com/apache/incubator-mxnet/pull/8738
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below (as it won't show otherwise due to GitHub magic):

diff --git a/example/numpy-ops/README.md b/example/numpy-ops/README.md
index 1ec8a404c3..ffe24aa78e 100644
--- a/example/numpy-ops/README.md
+++ b/example/numpy-ops/README.md
@@ -4,4 +4,21 @@ Uses the same setup as example/mnist/mlp.py. Except the loss symbol is
 custom defined with NumpyOp. mxnet.operator.NumpyOp help move computation
 in a symbol's forward/backward operation to python frontend. This is for
 fast implementation/experimentation of non-performance-critical symbols.
-If it is becoming a bottleneck, please consider write a C++/CUDA version.
\ No newline at end of file
+If it is becoming a bottleneck, please consider writing a C++/CUDA version.
+
+# Example operator with CustomOp
+
+You can find the example of a custom operator which performs elementwise
+square for sparse ndarray: `custom_sparse_sqr.py`. The example contains
+implementations for `infer_storage_type` and `infer_storage_type_backward`
+interfaces which can be used to infer sparse storage types `csr`
+and `row_sparse` in the forward and backward pass respectively.
+
+To run the example:
+```
+python custom_sparse_sqr.py
+```
+OR
+```
+python3 custom_sparse_sqr.py
+```
diff --git a/example/numpy-ops/custom_sparse_sqr.py b/example/numpy-ops/custom_sparse_sqr.py
new file mode 100644
index 0000000000..15bbff55f1
--- /dev/null
+++ b/example/numpy-ops/custom_sparse_sqr.py
@@ -0,0 +1,92 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# pylint: skip-file
+import mxnet as mx
+import numpy as np
+
+class Sqr(mx.operator.CustomOp):
+    '''Example of how to use custom op with sparse ndarrays
+    '''
+    def forward(self, is_train, req, in_data, out_data, aux):
+        self.assign(out_data[0], req[0], mx.nd.sparse.square(in_data[0]))
+
+    def backward(self, req, out_grad, in_data, out_data, in_grad, aux):
+        self.assign(in_grad[0], req[0], 2 * in_data[0] * out_grad[0])
+
+@mx.operator.register("sqr")
+class SqrProp(mx.operator.CustomOpProp):
+    def __init__(self):
+        super(SqrProp, self).__init__(need_top_grad=True)
+
+    def list_arguments(self):
+        return ['data']
+
+    def list_outputs(self):
+        return ['output']
+
+    def infer_shape(self, in_shape):
+        return in_shape, [in_shape[0]], []
+
+    def infer_type(self, in_type):
+        return in_type, [in_type[0]], []
+
+    def infer_storage_type(self, in_stype):
+        '''Infer storage type logic for the forward pass
+        Takes a list of storage types for inputs
+        Returns three lists, one for input storage types inferred,
+        second for output storage types inferred and third for aux storage
+        types inferred
+        The in_stype is the list containing storage type for inputs
+        If the input is a dense ndarray then we infer the input
+        and output to be dense. If input is csr then input and output
+        are inferred as csr.
+        '''
+        if in_stype[0] == 'default':
+            return ['default'], ['default'], []
+        return ['csr'], ['csr'], []
+
+    def infer_storage_type_backward(self, ograd_stype, in_stype, out_stype, igrad_stype, aux_stype):
+        '''Infer storage type logic for the backward pass
+        Takes storage type of output gradients(ograd_stype), inputs(in_stype),
+        outputs(out_stype) and aux(aux_stype).
+        Returns inferred storage types in the following order:
+        ograd_stype, in_stype, out_stype, igrad_stype (Storage type for input gradients)
+        and aux_stype.
+        '''
+        if in_stype[0] == 'default':
+            return ['default'], ['default'], ['default'], ['default'], []
+        return ['csr'], ['csr'], ['csr'], ['default'], []
+
+    def create_operator(self, ctx, shapes, dtypes):
+        return Sqr()
+
+x = mx.nd.array(np.random.uniform(1, 10, size=(4,10)))
+x = x.tostype('csr')
+x.attach_grad(stype='default')
+z = mx.nd.zeros_like(x)
+with mx.contrib.autograd.train_section():
+    y = mx.nd.Custom(x, op_type='sqr')
+    y.backward(out_grad=z)
+print("Original ndarray")
+print("--------------")
+print(x.asnumpy())
+print("Squared ndarray")
+print("--------------")
+print(y.asnumpy())
+print("stype of input is {}".format(x.stype))
+print("stype of output is {}".format(y.stype))
diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h
index d34b194554..514ca3e631 100644
--- a/include/mxnet/c_api.h
+++ b/include/mxnet/c_api.h
@@ -147,7 +147,9 @@ enum CustomOpPropCallbacks {
   kCustomOpPropInferShape,
   kCustomOpPropDeclareBackwardDependency,
   kCustomOpPropCreateOperator,
-  kCustomOpPropInferType
+  kCustomOpPropInferType,
+  kCustomOpPropInferStorageType,
+  kCustomOpPropBackwardInferStorageType
 };
 
 
@@ -159,6 +161,11 @@ typedef int (*CustomOpListFunc)(char*** /*args*/, void* /*state*/);
 typedef int (*CustomOpInferShapeFunc)(int /*num_input*/, int* /*ndims*/,
                                       unsigned** /*shapes*/, void* /*state*/);
 typedef int (*CustomOpInferTypeFunc)(int /*num_input*/, int* /*types*/, void* /*state*/);
+typedef int (*CustomOpInferStorageTypeFunc)(int /*num_input*/, int* /*stypes*/, void* /*state*/);
+typedef int (*CustomOpBackwardInferStorageTypeFunc)(int /*num_input*/,
+                                                    int * /*stypes*/,
+                                                    int * /*tags*/,
+                                                    void * /*state*/);
 typedef int (*CustomOpBwdDepFunc)(const int* /*out_grad*/, const int* /*in_data*/,
                                   const int* /*out_data*/, int* /*num_deps*/,
                                   int** /*rdeps*/, void* /*state*/);
diff --git a/python/mxnet/operator.py b/python/mxnet/operator.py
index 141a33806a..0d5e6c551c 100644
--- a/python/mxnet/operator.py
+++ b/python/mxnet/operator.py
@@ -16,7 +16,7 @@
 # under the License.
 
 # coding: utf-8
-# pylint: disable=invalid-name, protected-access, too-many-arguments, no-self-use, too-many-locals, broad-except
+# pylint: disable=invalid-name, protected-access, too-many-arguments, no-self-use, too-many-locals, broad-except, too-many-lines
 """numpy interface for operators."""
 from __future__ import absolute_import
 
@@ -31,6 +31,10 @@
 from .base import c_str, mx_uint, mx_float, ctypes2numpy_shared, NDArrayHandle, py_str
 from . import symbol, context
 from .ndarray import NDArray, _DTYPE_NP_TO_MX, _DTYPE_MX_TO_NP
+from .ndarray.ndarray import _STORAGE_TYPE_STR_TO_ID, _STORAGE_TYPE_ID_TO_STR
+from .ndarray.ndarray import _STORAGE_TYPE_UNDEFINED, _STORAGE_TYPE_DEFAULT
+from .ndarray import _ndarray_cls
+
 
 c_int_p = POINTER(c_int)
 
@@ -518,6 +522,71 @@ def infer_type(self, in_type):
         return in_type, [in_type[0]]*len(self.list_outputs()), \
             [in_type[0]]*len(self.list_auxiliary_states())
 
+    def infer_storage_type(self, in_stype):
+        """infer_storage_type interface. Used to infer storage type of
+        inputs and outputs in the forward pass.
+
+        Parameters
+        ----------
+        in_stype : list of stypes, Valid stypes are default, row_sparse and
+            csr
+
+        Returns
+        -------
+        in_stype : list
+            list of argument stypes.
+        out_stype : list
+            list of output stypes calculated from in_stype,
+            in the same order as declared in list_outputs.
+        aux_stype : Optional, list
+            list of aux stypes calculated from in_stype,
+            in the same order as declared in list_auxiliary_states.
+        """
+        return in_stype, [in_stype[0]]*len(self.list_outputs()), \
+            [in_stype[0]]*len(self.list_auxiliary_states())
+
+    def infer_storage_type_backward(self, ograd_stype, in_stype, out_stype, igrad_stype, aux_stype):
+        """infer_storage_type_backward interface. Used to infer storage
+        type of inputs and outputs in the backward pass.
+
+        If "undefined" is provided for any storage types it will fallback
+        to the existing storage type. If the existing storage type is
+        also "undefined" it will fallback to "default".
+        If returned lists are empty or less than the length of corresponding
+        input lists the missing values are populated from the input or fallback
+        to "default" if input storage types are "undefined".
+        If returned lists are longer than the corresponding
+        input lists, it will throw an exception.
+
+        Parameters
+        ----------
+        ograd_stype : list
+            list of output gradient storage types.
+        in_stype : list
+            list of input storage types
+        out_stype : list
+            list of output storage types
+        igrad_stype : list
+            list of input gradient storage types.
+        aux_stype : list
+            list of auxiliary storage types
+
+        Returns
+        -------
+        ograd_stype : list
+            list of inferred output gradient storage types
+        in_stype : list
+            list of inferred input storage types
+        out_stype : list
+            list of inferred output storage types
+        igrad_stype : list
+            list of inferred input gradient storage types
+        aux_stype : list
+            list of inferred storage types for auxiliary states.
+        """
+        return list(ograd_stype), list(in_stype), list(out_stype), \
+               list(igrad_stype), list(aux_stype)
+
     def list_outputs(self):
         """list_outputs interface. Can override when creating new operators.
 
@@ -583,6 +652,7 @@ class _Registry(object):
     def __init__(self):
         self.ref_holder = {}
         self.counter = 0
+        self.result_deps = set()
         self.lock = Lock()
 
     def inc(self):
@@ -606,6 +676,9 @@ def do_register(prop_cls):
         infershape_functype = CFUNCTYPE(c_int, c_int, POINTER(c_int),
                                         POINTER(POINTER(mx_uint)), c_void_p)
         infertype_functype = CFUNCTYPE(c_int, c_int, POINTER(c_int), c_void_p)
+        inferstorage_functype = CFUNCTYPE(c_int, c_int, POINTER(c_int), c_void_p)
+        inferstorage_backward_functype = CFUNCTYPE(c_int, c_int, POINTER(c_int), \
+                                                   POINTER(c_int), c_void_p)
         list_functype = CFUNCTYPE(c_int, POINTER(POINTER(POINTER(c_char))), c_void_p)
         deps_functype = CFUNCTYPE(c_int, c_int_p, c_int_p, c_int_p,
                                   c_int_p, POINTER(c_int_p), c_void_p)
@@ -661,6 +734,109 @@ def infer_shape_entry(num_tensor, tensor_dims,
                     return False
                 return True
 
+            def infer_storage_type_backward_entry(num_tensor, tensor_stypes, tags, _):
+                """C Callback for CustomOpProp::InferStorageTypeBackward; returns False on error."""
+                try:
+                    tensors = [[] for i in range(5)]
+                    for i in range(num_tensor):
+                        tensors[tags[i]].append(_STORAGE_TYPE_ID_TO_STR[tensor_stypes[i]])
+                    # Ordering of stypes: ograd, input, output, igrad, aux
+                    tensors = [tensors[3], tensors[0], tensors[1], tensors[2], tensors[4]]
+                    ret = op_prop.infer_storage_type_backward(tensors[0],
+                                                              tensors[1],
+                                                              tensors[2],
+                                                              tensors[3],
+                                                              tensors[4])
+                    if len(ret) == 4:
+                        ret = list(ret) + [[]]  # aux omitted; also handles tuple returns
+                    elif len(ret) == 5:
+                        pass
+                    else:
+                        raise AssertionError("infer_storage_type_backward must return 4 or 5 lists")
+                    assert len(ret[0]) <= len(tensors[0]), \
+                        "InferStorageTypeBackward Error: expecting <= %d " \
+                        "entries in returned output gradient " \
+                        "stypes, got %d."%(len(tensors[0]), len(ret[0]))
+                    assert len(ret[1]) <= len(tensors[1]), \
+                        "InferStorageTypeBackward Error: expecting <= %d " \
+                        "entries in returned input stypes, " \
+                        "got %d."%(len(tensors[1]), len(ret[1]))
+                    assert len(ret[2]) <= len(tensors[2]), \
+                        "InferStorageTypeBackward Error: expecting <= %d " \
+                        "entries in returned output stypes, " \
+                        "got %d."%(len(tensors[2]), len(ret[2]))
+                    assert len(ret[3]) <= len(tensors[3]), \
+                        "InferStorageTypeBackward Error: expecting <= %d " \
+                        "entries in returned input gradient stypes, " \
+                        "got %d."%(len(tensors[3]), len(ret[3]))
+                    assert len(ret[4]) <= len(tensors[4]), \
+                        "InferStorageTypeBackward Error: expecting <= %d " \
+                        "entries in returned aux stypes, " \
+                        "got %d."%(len(tensors[4]), len(ret[4]))
+                    rstype = []
+                    for i, ret_list in enumerate(ret):
+                        if len(ret_list) < len(tensors[i]):
+                            ret_list.extend(tensors[i][len(ret_list):])
+                        for j, stype in enumerate(ret_list):
+                            if stype == _STORAGE_TYPE_ID_TO_STR[_STORAGE_TYPE_UNDEFINED] \
+                            and tensors[i][j] == _STORAGE_TYPE_ID_TO_STR[_STORAGE_TYPE_UNDEFINED]:
+                                # Fallback to default if user provided undefined
+                                # and existing stype is undefined
+                                ret_list[j] = _STORAGE_TYPE_ID_TO_STR[_STORAGE_TYPE_DEFAULT]
+                            elif stype == _STORAGE_TYPE_ID_TO_STR[_STORAGE_TYPE_UNDEFINED]:
+                                # If user provided undefined for a stype which was defined,
+                                # use the defined stype
+                                ret_list[j] = tensors[i][j]
+                        rstype.extend(ret_list)
+
+                    for i, stype in enumerate(rstype):
+                        tensor_stypes[i] = _STORAGE_TYPE_STR_TO_ID[stype]
+
+                    infer_storage_type_backward_entry._ref_holder = [tensor_stypes]
+                except Exception:
+                    print('Error in %s.infer_storage_type_backward: %s' % (reg_name, traceback.format_exc()))
+                    return False
+                return True
+
+
+            def infer_storage_type_entry(num_tensor, tensor_stypes, _):
+                """C Callback for CustomOpProp::InferStorageType; returns False on error."""
+                try:
+                    n_in = len(op_prop.list_arguments())
+                    n_out = len(op_prop.list_outputs())
+                    n_aux = len(op_prop.list_auxiliary_states())
+                    assert num_tensor == n_in + n_out + n_aux
+
+                    stypes = [_STORAGE_TYPE_ID_TO_STR[tensor_stypes[i]] for i in range(n_in)]
+                    ret = op_prop.infer_storage_type(stypes)
+                    if len(ret) == 2:
+                        istype, ostype = ret
+                        astype = []
+                    elif len(ret) == 3:
+                        istype, ostype, astype = ret
+                    else:
+                        raise AssertionError("infer_storage_type must return 2 or 3 lists")
+
+                    assert len(ostype) == n_out, \
+                        "InferStorageType Error: expecting %d entries in returned output " \
+                        "stypes, got %d."%(n_out, len(ostype))
+                    assert len(istype) == n_in, \
+                        "InferStorageType Error: expecting %d entries in returned input " \
+                        "stypes, got %d."%(n_in, len(istype))
+                    assert len(astype) == n_aux, \
+                        "InferStorageType Error: expecting %d entries in returned aux state " \
+                        "stypes, got %d."%(n_aux, len(astype))
+                    rstype = list(istype) + list(ostype) + list(astype)  # inputs, outputs, aux
+                    for i, stype in enumerate(rstype):
+                        tensor_stypes[i] = _STORAGE_TYPE_STR_TO_ID[stype]
+
+                    infer_storage_type_entry._ref_holder = [tensor_stypes]
+                except Exception:
+                    print('Error in %s.infer_storage_type: %s' % (reg_name, traceback.format_exc()))
+                    return False
+                return True
+
+
             def infer_type_entry(num_tensor, tensor_types, _):
                 """C Callback for CustomOpProp::InferType"""
                 try:
@@ -680,13 +856,13 @@ def infer_type_entry(num_tensor, tensor_types, _):
                         raise AssertionError("infer_type must return 2 or 3 lists")
                     assert len(otype) == n_out, \
                         "InferType Error: expecting %d entries in returned output " \
-                        "shapes, got %d."%(n_out, len(otype))
+                        "types, got %d."%(n_out, len(otype))
                     assert len(itype) == n_in, \
                         "InferType Error: expecting %d entries in returned input " \
-                        "shapes, got %d."%(n_in, len(itype))
+                        "types, got %d."%(n_in, len(itype))
                     assert len(atype) == n_aux, \
                         "InferType Error: expecting %d entries in returned aux state " \
-                        "shapes, got %d."%(n_aux, len(atype))
+                        "types, got %d."%(n_aux, len(atype))
                     rtype = list(itype) + list(otype) + list(atype)
                     for i, dtype in enumerate(rtype):
                         tensor_types[i] = _DTYPE_NP_TO_MX[dtype]
@@ -748,6 +924,9 @@ def declare_backward_dependency_entry(out_grad, in_data, out_data, num_dep, deps
                     out_data = [out_data[i] for i in range(len(op_prop.list_outputs()))]
                     rdeps = op_prop.declare_backward_dependency(out_grad, in_data, out_data)
                     num_dep[0] = len(rdeps)
+                    _registry.result_deps = set()
+                    for dep in rdeps:
+                        _registry.result_deps.add(dep)
                     rdeps = cast(c_array_buf(c_int, array('i', rdeps)), c_int_p)
                     deps[0] = rdeps
 
@@ -775,13 +954,13 @@ def forward_entry(num_ndarray, ndarraies, tags, reqs, is_train, _):
                             tensors = [[] for i in range(5)]
                             for i in range(num_ndarray):
                                 if tags[i] == 1 or tags[i] == 4:
-                                    tensors[tags[i]].append(NDArray(cast(ndarraies[i],
-                                                                         NDArrayHandle),
-                                                                    writable=True))
+                                    tensors[tags[i]].append(_ndarray_cls(cast(ndarraies[i],
+                                                                              NDArrayHandle),
+                                                                         writable=True))
                                 else:
-                                    tensors[tags[i]].append(NDArray(cast(ndarraies[i],
-                                                                         NDArrayHandle),
-                                                                    writable=False))
+                                    tensors[tags[i]].append(_ndarray_cls(cast(ndarraies[i],
+                                                                              NDArrayHandle),
+                                                                         writable=False))
                             reqs = [req_enum[reqs[i]] for i in range(len(tensors[1]))]
                             with ctx:
                                 op.forward(is_train=is_train, req=reqs,
@@ -797,15 +976,29 @@ def backward_entry(num_ndarray, ndarraies, tags, reqs, is_train, _):
                         # pylint: disable=W0613
                         try:
                             tensors = [[] for i in range(5)]
+                            num_outputs = len(op_prop.list_outputs())
+                            num_args = len(op_prop.list_arguments())
                             for i in range(num_ndarray):
+                                if i in _registry.result_deps or i >= (num_outputs * 2 + num_args):
+                                    # If it is a backward dependency or output or aux:
+                                    # Set stype as undefined so that it returns
+                                    # ndarray based on existing stype
+                                    stype = _STORAGE_TYPE_UNDEFINED
+                                else:
+                                    # If it is some input, output or out grad ndarray not part of
+                                    # backward dependency it is empty and thus the ndarray should
+                                    # be set to default
+                                    stype = _STORAGE_TYPE_DEFAULT
                                 if tags[i] == 2 or tags[i] == 4:
-                                    tensors[tags[i]].append(NDArray(cast(ndarraies[i],
-                                                                         NDArrayHandle),
-                                                                    writable=True))
+                                    tensors[tags[i]].append(_ndarray_cls(cast(ndarraies[i],
+                                                                              NDArrayHandle),
+                                                                         writable=True,
+                                                                         stype=stype))
                                 else:
-                                    tensors[tags[i]].append(NDArray(cast(ndarraies[i],
-                                                                         NDArrayHandle),
-                                                                    writable=False))
+                                    tensors[tags[i]].append(_ndarray_cls(cast(ndarraies[i],
+                                                                              NDArrayHandle),
+                                                                         writable=False,
+                                                                         stype=stype))
                             reqs = [req_enum[reqs[i]] for i in range(len(tensors[2]))]
                             with ctx:
                                 op.backward(req=reqs,
@@ -863,7 +1056,9 @@ def delete_entry(_):
                          infershape_functype(infer_shape_entry),
                          deps_functype(declare_backward_dependency_entry),
                          createop_functype(create_operator_entry),
-                         infertype_functype(infer_type_entry)]
+                         infertype_functype(infer_type_entry),
+                         inferstorage_functype(infer_storage_type_entry),
+                         inferstorage_backward_functype(infer_storage_type_backward_entry)]
             callbacks = [cast(i, CFUNCTYPE(c_int)) for i in callbacks]
             contexts = [None]*len(callbacks)
             ret[0] = MXCallbackList(c_int(len(callbacks)),
diff --git a/src/operator/custom/custom.cc b/src/operator/custom/custom.cc
index 280b01b22e..b6703787f1 100644
--- a/src/operator/custom/custom.cc
+++ b/src/operator/custom/custom.cc
@@ -351,20 +351,118 @@ void Backward(const OpStatePtr& state,
   Imperative::Get()->set_is_recording(prev_recording);
 }
 
+inline bool BackwardInferStorageType(const nnvm::NodeAttrs& attrs,
+                                     const int dev_mask,
+                                     DispatchMode* dispatch_mode,
+                                     std::vector<int>* iattr,
+                                     std::vector<int>* oattr) {
+  const CustomParam& params = nnvm::get<CustomParam>(attrs.parsed);
+
+  if (params.info->num_callbacks <= kCustomOpPropBackwardInferStorageType) {  // no callback
+    for (size_t i = 0; i < iattr->size(); i++) {
+      STORAGE_TYPE_ASSIGN_CHECK(*iattr, i, kDefaultStorage);
+    }
+    for (size_t i = 0; i < oattr->size(); i++) {
+      STORAGE_TYPE_ASSIGN_CHECK(*oattr, i, kDefaultStorage);
+    }
+    DISPATCH_MODE_ASSIGN_CHECK(dispatch_mode, 0, DispatchMode::kFComputeEx);
+    return true;
+  }
+
+  size_t total = 2 * params.num_args + 2 * params.num_outs + params.num_auxs;
+  size_t bwd_deps_size = params.bwd_idx.size();
+  std::vector<int> stypes(bwd_deps_size, -1);
+  std::vector<int> tags;
+  stypes.reserve(total);
+  tags.reserve(total);
+
+  for (size_t i = 0; i < bwd_deps_size; i++) {  // stypes layout: deps, then oattr, then aux
+    if (params.bwd_idx[i] < static_cast<int>(params.num_outs))
+      tags.push_back(3);
+    else if (params.bwd_idx[i] <
+             static_cast<int>(params.num_outs + params.num_args))
+      tags.push_back(0);
+    else
+      tags.push_back(1);
+    stypes[i] = (*iattr)[i];
+  }
+
+  for (size_t i = 0; i < oattr->size(); i++) {
+    stypes.push_back((*oattr)[i]);
+    tags.push_back(2);
+  }
+
+  for (size_t i = (iattr->size() - params.num_auxs); i < iattr->size(); i++) {
+    stypes.push_back((*iattr)[i]);
+    tags.push_back(4);
+  }
+
+  CHECK(reinterpret_cast<CustomOpBackwardInferStorageTypeFunc>(
+      params.info->callbacks[kCustomOpPropBackwardInferStorageType])(
+      stypes.size(), stypes.data(), tags.data(),
+      params.info->contexts[kCustomOpPropBackwardInferStorageType]));
+
+  for (size_t i = 0; i < bwd_deps_size; ++i) {
+    STORAGE_TYPE_ASSIGN_CHECK(*iattr, i, stypes[i]);
+  }
+  for (size_t i = 0; i < oattr->size(); ++i) {
+    STORAGE_TYPE_ASSIGN_CHECK(*oattr, i, stypes[i + bwd_deps_size]);
+  }
+  for (size_t i = 0; i < params.num_auxs; ++i) {  // aux entries follow the oattr entries
+    STORAGE_TYPE_ASSIGN_CHECK(
+        *iattr, (i + iattr->size() - params.num_auxs), stypes[i + oattr->size() + bwd_deps_size]);
+  }
+
+  DISPATCH_MODE_ASSIGN_CHECK(dispatch_mode, 0, DispatchMode::kFComputeEx);
+  return true;
+}
+
 // infer storage function for custom op, which assigns kDefaultStorage for
 // all undefined stypes, and dispatch on DispatchMode::kFComputeEx.
-inline bool InferStorageType(const nnvm::NodeAttrs& attrs,
-                             const int dev_mask,
+inline bool InferStorageType(const nnvm::NodeAttrs& attrs, const int dev_mask,
                              DispatchMode* dispatch_mode,
-                             std::vector<int> *iattr,
-                             std::vector<int> *oattr) {
-  for (int& v : *oattr) {
-    if (v == -1) v = kDefaultStorage;
+                             std::vector<int>* iattr, std::vector<int>* oattr) {
+  const CustomParam& params = nnvm::get<CustomParam>(attrs.parsed);
+
+  if (params.info->num_callbacks <= kCustomOpPropInferStorageType) {
+    for (size_t i = 0; i < iattr->size(); i++) {
+      STORAGE_TYPE_ASSIGN_CHECK(*iattr, i, kDefaultStorage);
+    }
+    for (size_t i = 0; i < oattr->size(); i++) {
+      STORAGE_TYPE_ASSIGN_CHECK(*oattr, i, kDefaultStorage);
+    }
+    DISPATCH_MODE_ASSIGN_CHECK(dispatch_mode, 0, DispatchMode::kFComputeEx);
+    return true;
+  }
+
+  std::vector<int> stypes;
+  stypes.reserve(params.num_args + params.num_outs + params.num_auxs);
+  for (size_t i = 0; i < params.num_args; ++i) {
+    stypes.push_back((*iattr)[i]);
   }
-  for (int& v : *iattr) {
-    if (v == -1) v = kDefaultStorage;
+  for (const auto& i : *oattr) {
+    stypes.push_back(i);
   }
-  dispatch_mode_assign(dispatch_mode, DispatchMode::kFComputeEx);
+  for (size_t i = 0; i < params.num_auxs; ++i) {
+    stypes.push_back((*iattr)[params.num_args + i]);
+  }
+
+  CHECK(reinterpret_cast<CustomOpInferStorageTypeFunc>(
+      params.info->callbacks[kCustomOpPropInferStorageType])(
+      stypes.size(), stypes.data(),
+      params.info->contexts[kCustomOpPropInferStorageType]));
+  for (size_t i = 0; i < params.num_args; ++i) {
+    STORAGE_TYPE_ASSIGN_CHECK(*iattr, i, stypes[i]);
+  }
+  for (size_t i = 0; i < params.num_outs; ++i) {
+    STORAGE_TYPE_ASSIGN_CHECK(*oattr, i, stypes[params.num_args + i]);
+  }
+  for (size_t i = 0; i < params.num_auxs; ++i) {
+    STORAGE_TYPE_ASSIGN_CHECK(*iattr, params.num_args + i,
+                              stypes[params.num_args + params.num_outs + i]);
+  }
+
+  DISPATCH_MODE_ASSIGN_CHECK(dispatch_mode, 0, DispatchMode::kFComputeEx);
   return true;
 }
 
@@ -430,7 +528,7 @@ NNVM_REGISTER_OP(_backward_Custom)
   })
 .set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", Backward)
 .set_attr<FStatefulComputeEx>("FStatefulComputeEx<gpu>", Backward)
-.set_attr<FInferStorageType>("FInferStorageType", InferStorageType);
+.set_attr<FInferStorageType>("FInferStorageType", BackwardInferStorageType);
 
 }  // namespace custom
 }  // namespace op
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 19c4e65d3d..49dbc1f9c1 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -3585,12 +3585,18 @@ def test_rcbrt_op():
 def test_custom_op():
     class Sqr(mx.operator.CustomOp):
         def forward(self, is_train, req, in_data, out_data, aux):
-            self.assign(out_data[0], req[0], in_data[0]*in_data[0])
-            aux[0][:] = 1
+            if in_data[0].stype == 'default':
+                aux[0][:] = 1
+                self.assign(out_data[0], req[0], in_data[0]*in_data[0])
+            else:
+                self.assign(out_data[0], req[0], mx.nd.sparse.square(in_data[0]))
+                if in_data[0].stype == 'csr':
+                    assert(isinstance(in_data[0], mx.nd.sparse.CSRNDArray))
 
         def backward(self, req, out_grad, in_data, out_data, in_grad, aux):
             self.assign(in_grad[0], req[0], 2*in_data[0]*out_grad[0])
-            assert (aux[0].asnumpy() == 1).all()
+            if in_data[0].stype == 'default':
+                assert (aux[0].asnumpy() == 1).all()
 
     @mx.operator.register("sqr")
     class SqrProp(mx.operator.CustomOpProp):
@@ -3612,6 +3618,16 @@ def infer_shape(self, in_shape):
         def infer_type(self, in_type):
             return in_type, [in_type[0]], [in_type[0]]
 
+        def infer_storage_type(self, in_stype):
+            if in_stype[0] == 'default':
+                return ['default'], ['default'], ['default']
+            return ['csr'], ['csr'], ['csr']
+
+        def infer_storage_type_backward(self, ograd_stype, in_stype, out_stype, igrad_stype, aux_stype):
+            if in_stype[0] == 'default':
+                return ['default'], ['default'], ['default'], ['default'], ['default']
+            return ['default'], ['csr'], ['csr'], ['csr'], ['csr']
+
         def create_operator(self, ctx, shapes, dtypes):
             return Sqr()
 
@@ -3624,15 +3640,85 @@ def create_operator(self, ctx, shapes, dtypes):
 
     data = mx.symbol.cast(data, dtype='float64')
     op = mx.symbol.cast(op, dtype='float32')
-    x = mx.nd.array(np.random.uniform(-1, 1, size=(4, 10)))
-    aux = mx.nd.zeros_like(x)
     check_numeric_gradient(op, [x], [aux])
 
+    x = x.tostype('csr')
+    aux = mx.nd.zeros_like(x)
     x.attach_grad()
     with mx.contrib.autograd.train_section():
-        y = mx.nd.Custom(x, aux, op_type='sqr')
+        y = mx.nd.Custom(x, aux, name='sqr', op_type='sqr')
         y.backward()
 
+    # test for backward compatibility, i.e. the correctness of the default implementation of
+    # storage type inference in custom operators
+    class Mult(mx.operator.CustomOp):
+        def forward(self, is_train, req, in_data, out_data, aux):
+            self.assign(out_data[0], req[0], in_data[0]*in_data[1])
+
+        def backward(self, req, out_grad, in_data, out_data, in_grad, aux):
+            self.assign(in_grad[0], req[0], in_data[1])
+            self.assign(in_grad[1], req[1], in_data[0])
+
+    @mx.operator.register("mult")
+    class MultProp(mx.operator.CustomOpProp):
+        def __init__(self):
+            super(MultProp, self).__init__(need_top_grad=True)
+
+        def list_arguments(self):
+            return ['lhs', 'rhs']
+
+        def list_outputs(self):
+            return ['output']
+
+        def infer_shape(self, in_shape):
+            return in_shape, [in_shape[0]], []
+
+        def create_operator(self, ctx, shapes, dtypes):
+            return Mult()
+
+    lhs = mx.nd.array(np.random.uniform(-1, 1, size=(4, 10)))
+    rhs = mx.nd.array(np.random.uniform(-1, 1, size=(4, 10)))
+    lhs.attach_grad()
+    rhs.attach_grad()
+    with mx.contrib.autograd.train_section():
+        y = mx.nd.Custom(lhs, rhs, name='mult', op_type='mult')
+        y.backward()
+    assert_almost_equal(rhs.asnumpy(), lhs.grad.asnumpy())
+    assert_almost_equal(lhs.asnumpy(), rhs.grad.asnumpy())
+
+    class MultNoGrad(mx.operator.CustomOp):
+        def forward(self, is_train, req, in_data, out_data, aux):
+            self.assign(out_data[0], req[0], in_data[0]*in_data[1])
+
+        def backward(self, req, out_grad, in_data, out_data, in_grad, aux):
+            self.assign(in_grad[0], req[0], in_data[1])
+            self.assign(in_grad[1], req[1], in_data[0])
+
+    @mx.operator.register("mult_no_grad")
+    class MultNoGradProp(mx.operator.CustomOpProp):
+        def __init__(self):
+            super(MultNoGradProp, self).__init__(need_top_grad=False)
+
+        def list_arguments(self):
+            return ['lhs', 'rhs']
+
+        def list_outputs(self):
+            return ['output']
+
+        def infer_shape(self, in_shape):
+            return in_shape, [in_shape[0]], []
+
+        def create_operator(self, ctx, shapes, dtypes):
+            return MultNoGrad()
+
+        def infer_storage_type_backward(self, ograd_stype, in_stype, out_stype, igrad_stype, aux_stype):
+            return [], [], [], ['default'], []
+
+    with mx.contrib.autograd.train_section():
+        y2 = mx.nd.Custom(lhs, rhs, name="mult_no_grad", op_type="mult_no_grad")
+        y2.backward()
+    assert_almost_equal(rhs.asnumpy(), lhs.grad.asnumpy())
+    assert_almost_equal(lhs.asnumpy(), rhs.grad.asnumpy())
 
 def test_psroipooling():
     for num_rois in [1, 2]:


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services