Posted to commits@mxnet.apache.org by ha...@apache.org on 2018/06/29 15:12:57 UTC

[incubator-mxnet] branch master updated: [MXNET-537] add_n(dense, csr, dense) = dense and add_n([dense, csr, rsp]*, dense, [dense, csr, rsp]*) = dense on CPU & GPU (#11330)

This is an automated email from the ASF dual-hosted git repository.

haibin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new ca60b94  [MXNET-537] add_n(dense, csr, dense) = dense and add_n([dense, csr, rsp]*, dense, [dense, csr, rsp]*) = dense on CPU & GPU (#11330)
ca60b94 is described below

commit ca60b94e3ed5280c94d521ead35ca2a03306772a
Author: Hao Jin <ha...@users.noreply.github.com>
AuthorDate: Fri Jun 29 11:12:51 2018 -0400

    [MXNET-537] add_n(dense, csr, dense) = dense and add_n([dense, csr, rsp]*, dense, [dense, csr, rsp]*) = dense on CPU & GPU (#11330)
    
    * support for add_n(dense, csr, dense) = dense with tests
    
    * eliminate magic number
---
 src/common/utils.h                                 |  30 ++++++
 src/ndarray/ndarray_function.cc                    | 108 ++++++++++++++++++++-
 src/ndarray/ndarray_function.cu                    | 108 ++++++++++++++++++++-
 src/operator/elemwise_op_common.h                  |  10 ++
 src/operator/tensor/elemwise_binary_broadcast_op.h |   3 +-
 src/operator/tensor/elemwise_binary_op-inl.h       |  45 +++++++--
 src/operator/tensor/elemwise_binary_op.h           |  17 +++-
 src/operator/tensor/elemwise_binary_op_basic.cu    |  54 +++++++++++
 src/operator/tensor/elemwise_sum.cc                |  10 +-
 src/operator/tensor/elemwise_sum.cu                |   6 +-
 src/operator/tensor/indexing_op-inl.cuh            |   1 +
 tests/python/unittest/test_sparse_operator.py      |  24 +++--
 12 files changed, 393 insertions(+), 23 deletions(-)
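
For context, a minimal sketch of what this change enables from the Python API. This assumes the mx.nd.sparse namespace of MXNet 1.x exposes add_n (as the sparse alias registered in elemwise_sum.cc suggests); the shapes and values are illustrative, not taken from the commit:

    import mxnet as mx

    shape = (4, 5)
    dns0 = mx.nd.ones(shape)                  # default (dense) storage
    csr1 = mx.nd.ones(shape).tostype('csr')   # csr storage
    dns2 = mx.nd.ones(shape)

    # add_n(dense, csr, dense) = dense, with no storage fallback
    out = mx.nd.sparse.add_n(dns0, csr1, dns2)
    assert out.stype == 'default'

    # more than 4 inputs with at least one dense input also yield dense
    rsp = mx.nd.ones(shape).tostype('row_sparse')
    out = mx.nd.sparse.add_n(dns0, csr1, rsp, dns2, csr1)
    assert out.stype == 'default'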

diff --git a/src/common/utils.h b/src/common/utils.h
index be78bf4..d7ed4dd 100644
--- a/src/common/utils.h
+++ b/src/common/utils.h
@@ -319,6 +319,36 @@ inline bool ContainsOnlyStorage(const std::vector<NDArray>& ndarrays,
   return false;
 }
 
+/*! \brief returns true if the storage type of any array in `ndarrays`
+ *         matches the target `stype`. returns false for empty inputs.
+ */
+inline bool ContainsStorageType(const std::vector<NDArray>& ndarrays,
+                                const NDArrayStorageType stype) {
+  if (!ndarrays.empty()) {
+    for (const auto& nd : ndarrays) {
+      if (nd.storage_type() == stype) {
+        return true;
+      }
+    }
+  }
+  return false;
+}
+
+/*! \brief returns true if any storage type in `ndstypes`
+ *         matches the target `stype`. returns false for empty inputs.
+ */
+inline bool ContainsStorageType(const std::vector<int>& ndstypes,
+                                const NDArrayStorageType stype) {
+  if (!ndstypes.empty()) {
+    for (const auto& ndstype : ndstypes) {
+      if (ndstype == stype) {
+        return true;
+      }
+    }
+  }
+  return false;
+}
+
 /*! \brief get string representation of dispatch_mode */
 inline std::string dispatch_mode_string(const DispatchMode x) {
   switch (x) {
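
The new helper is the existential counterpart of the existing ContainsOnlyStorage: it asks whether at least one input has the target storage type rather than whether all of them do. In Python terms (an illustrative sketch, not code from the repository):

    # ContainsOnlyStorage(nds, stype): every array matches the target stype
    # ContainsStorageType(nds, stype): at least one array matches it
    def contains_storage_type(ndarrays, stype):
        return any(nd.stype == stype for nd in ndarrays)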
diff --git a/src/ndarray/ndarray_function.cc b/src/ndarray/ndarray_function.cc
index 552555a..022302a 100644
--- a/src/ndarray/ndarray_function.cc
+++ b/src/ndarray/ndarray_function.cc
@@ -26,6 +26,9 @@
 #include "./ndarray_function.h"
 #include "./ndarray_function-inl.h"
 #include "../common/utils.h"
+#include "../operator/mxnet_op.h"
+#include "../operator/tensor/elemwise_binary_op-inl.h"
+#include "../operator/tensor/elemwise_sum.h"
 
 namespace mxnet {
 namespace ndarray {
@@ -165,6 +168,102 @@ void ElementwiseSumRsp(mshadow::Stream<cpu>* s,
   });
 }
 
+void ElementwiseSumDnsCsrDnsImpl(mshadow::Stream<cpu>* s,
+                                 const Resource& rsc,
+                                 const std::vector<NDArray>& nds,
+                                 NDArray* out) {
+  using namespace mxnet::op;
+  using namespace mxnet::op::mxnet_op;
+  const TBlob& out_data = out->data();
+  MSHADOW_TYPE_SWITCH(out->dtype(), DType, {  // data type
+    Kernel<Sum, cpu>::Launch(
+      s, out_data.Size(), out_data.dptr<DType>(), kWriteTo, nds[0].data().dptr<DType>(),
+      nds[2].data().dptr<DType>());
+    const TBlob& csr_data = nds[1].data();
+    const TBlob& csr_indices = nds[1].aux_data(csr::kIdx);
+    const TBlob& csr_indptr = nds[1].aux_data(csr::kIndPtr);
+    const nnvm::dim_t num_rows = nds[1].shape()[0];
+    const nnvm::dim_t num_cols = nds[1].shape()[1];
+    MSHADOW_IDX_TYPE_SWITCH(csr_indices.type_flag_, IType, {  // indices type
+      MSHADOW_IDX_TYPE_SWITCH(csr_indptr.type_flag_, CType, {  // indptr type
+        if (nds[1].storage_initialized()) {
+          Kernel<ElemwiseDnsCsrDnsKernel<kWriteTo, mshadow_op::plus>, cpu>::Launch(
+            s, num_rows, out_data.dptr<DType>(), out_data.dptr<DType>(),
+            csr_data.dptr<DType>(), csr_indices.dptr<IType>(),
+            csr_indptr.dptr<CType>(), num_rows, num_cols);
+        }
+      });
+    });
+  });
+}
+
+void ElementwiseSumContainsDnsImpl(mshadow::Stream<cpu>* s,
+                                   const Resource& rsc,
+                                   const std::vector<NDArray>& nds,
+                                   NDArray* out) {
+  using namespace mxnet::op;
+  using namespace mxnet::op::mxnet_op;
+  const TBlob& out_data = out->data();
+  MSHADOW_TYPE_SWITCH(out->dtype(), DType, {  // data type
+    Kernel<set_zero, cpu>::Launch(s, out_data.Size(), out_data.dptr<DType>());
+    for (size_t i = 0; i < nds.size(); ++i) {
+      const NDArray& nd = nds[i];
+      const nnvm::dim_t num_rows = nd.shape()[0];
+      const nnvm::dim_t num_cols = nd.shape()[1];
+      const TBlob& nd_data = nd.data();
+
+      if (i == 0) {
+        if (nd.storage_type() == kDefaultStorage) {
+          Kernel<op_with_req<mshadow_op::identity, kWriteTo>, cpu>::Launch(
+            s, out_data.Size(), out_data.dptr<DType>(), nd_data.dptr<DType>());
+          continue;
+        } else {
+          Kernel<set_zero, cpu>::Launch(s, out_data.Size(), out_data.dptr<DType>());
+        }
+      }
+
+      switch (nd.storage_type()) {
+        case kDefaultStorage: {
+          Kernel<op_with_req<mshadow_op::plus, kWriteTo>, cpu>::Launch(
+            s, out_data.Size(), out_data.dptr<DType>(), out_data.dptr<DType>(),
+            nd_data.dptr<DType>());
+          break;
+        }
+        case kCSRStorage: {
+          const TBlob& nd_indices = nd.aux_data(csr::kIdx);
+          const TBlob& nd_indptr = nd.aux_data(csr::kIndPtr);
+          MSHADOW_IDX_TYPE_SWITCH(nd_indices.type_flag_, IType, {  // indices type
+            MSHADOW_IDX_TYPE_SWITCH(nd_indptr.type_flag_, CType, {  // indptr type
+              if (nd.storage_initialized()) {
+                Kernel<ElemwiseDnsCsrDnsKernel<kWriteTo, mshadow_op::plus>, cpu>::Launch(
+                  s, num_rows, out_data.dptr<DType>(), out_data.dptr<DType>(),
+                  nd_data.dptr<DType>(), nd_indices.dptr<IType>(),
+                  nd_indptr.dptr<CType>(), num_rows, num_cols);
+              }
+            });
+          });
+          break;
+        }
+        case kRowSparseStorage: {
+          const TBlob& nd_indices = nd.aux_data(rowsparse::kIdx);
+          MSHADOW_IDX_TYPE_SWITCH(nd_indices.type_flag_, IType, {  // indices type
+            if (nd.storage_initialized()) {
+              const nnvm::dim_t nz_rows = nd_indices.Size();
+              Kernel<ElemwiseDnsRspDnsKernel<kWriteTo, mshadow_op::plus>, cpu>::Launch(
+                s, nz_rows * num_cols, out_data.dptr<DType>(),
+                out_data.dptr<DType>(), nd_data.dptr<DType>(), nd_indices.dptr<IType>(),
+                num_rows, nz_rows, num_cols);
+            }
+          });
+          break;
+        }
+        default:
+          LOG(FATAL) << "unknown storage type " << nd.storage_type() << "encountered...";
+      }
+    }
+  });
+}
+
 /*!
  * \brief Parallel cpu impl of elemwise sum for sparse tensors.
 * Supports row sparse sum and the mixed dense/sparse cases dispatched below.
@@ -175,8 +274,15 @@ void ElementwiseSum<cpu>(mshadow::Stream<cpu>* s,
                          const std::vector<NDArray>& nds,
                          NDArray* out) {
   if (nds.empty()) return;
-  if (nds[0].storage_type() == kRowSparseStorage) {
+  if (common::ContainsOnlyStorage(nds, kRowSparseStorage)) {
     ElementwiseSumRsp(s, rsc, nds, out);
+  } else if (nds.size() == 3U && nds[0].storage_type() == kDefaultStorage &&
+             nds[1].storage_type() == kCSRStorage && nds[2].storage_type() == kDefaultStorage &&
+             out->storage_type() == kDefaultStorage) {
+    ElementwiseSumDnsCsrDnsImpl(s, rsc, nds, out);
+  } else if (nds.size() > 4U && common::ContainsStorageType(nds, kDefaultStorage) &&
+             out->storage_type() == kDefaultStorage) {
+    ElementwiseSumContainsDnsImpl(s, rsc, nds, out);
   } else {
     LOG(FATAL) << "ElementwiseSum<cpu> has not been implemented for storage_type = << "
                << nds[0].storage_type();
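
The (dense, csr, dense) path above runs in two steps: the Sum kernel writes dns0 + dns2 into the output, then ElemwiseDnsCsrDnsKernel scatters the csr nonzeros on top, one row per thread. A NumPy/SciPy emulation of those two steps (illustrative only, not part of the commit):

    import numpy as np
    import scipy.sparse as sp

    def add_n_dns_csr_dns(dns0, csr, dns2):
        out = dns0 + dns2                # step 1: sum the two dense inputs
        for row in range(csr.shape[0]):  # step 2: scatter csr nonzeros on top
            for j in range(csr.indptr[row], csr.indptr[row + 1]):
                out[row, csr.indices[j]] += csr.data[j]
        return out

    dns0, dns2 = np.random.rand(4, 5), np.random.rand(4, 5)
    csr = sp.random(4, 5, density=0.3, format='csr')
    assert np.allclose(add_n_dns_csr_dns(dns0, csr, dns2),
                       dns0 + csr.toarray() + dns2)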
diff --git a/src/ndarray/ndarray_function.cu b/src/ndarray/ndarray_function.cu
index 06b5ad4..0ab40a7 100644
--- a/src/ndarray/ndarray_function.cu
+++ b/src/ndarray/ndarray_function.cu
@@ -25,7 +25,9 @@
 // this will be invoked by nvcc and compile GPU version
 #include <cub/cub.cuh>
 #include <dmlc/logging.h>
-#include "../operator/mxnet_op.h"
+#include "../operator/tensor/elemwise_binary_op-inl.h"
+#include "../operator/tensor/elemwise_sum.h"
+#include "../operator/tensor/indexing_op.h"
 #include "../operator/tensor/init_op.h"
 #include "../operator/tensor/util/tensor_util-inl.h"
 #include "../operator/tensor/util/tensor_util-inl.cuh"
@@ -185,6 +187,101 @@ void ElementwiseSumRspImpl(mshadow::Stream<gpu>* s,
   });
 }
 
+void ElementwiseSumDnsCsrDnsImpl(mshadow::Stream<gpu>* s,
+                                 const Resource& rsc,
+                                 const std::vector<NDArray>& nds,
+                                 NDArray* out) {
+  using namespace mxnet::op;
+  using namespace mxnet::op::mxnet_op;
+  const TBlob& out_data = out->data();
+  MSHADOW_TYPE_SWITCH(out->dtype(), DType, {  // data type
+    Kernel<Sum, gpu>::Launch(
+      s, out_data.Size(), out_data.dptr<DType>(), kWriteTo, nds[0].data().dptr<DType>(),
+      nds[2].data().dptr<DType>());
+    const TBlob& csr_data = nds[1].data();
+    const TBlob& csr_indices = nds[1].aux_data(csr::kIdx);
+    const TBlob& csr_indptr = nds[1].aux_data(csr::kIndPtr);
+    const nnvm::dim_t num_rows = nds[1].shape()[0];
+    const nnvm::dim_t num_cols = nds[1].shape()[1];
+    MSHADOW_IDX_TYPE_SWITCH(csr_indices.type_flag_, IType, {  // indices type
+      MSHADOW_IDX_TYPE_SWITCH(csr_indptr.type_flag_, CType, {  // indptr type
+        if (nds[1].storage_initialized()) {
+          Kernel<ElemwiseDnsCsrDnsWarpKernel<kWriteTo, mshadow_op::plus>, gpu>::Launch(
+            s, kWarpSize * num_rows, out_data.dptr<DType>(), out_data.dptr<DType>(),
+            csr_data.dptr<DType>(), csr_indices.dptr<IType>(),
+            csr_indptr.dptr<CType>(), num_rows, num_cols);
+        }
+      });
+    });
+  });
+}
+
+void ElementwiseSumContainsDnsImpl(mshadow::Stream<gpu>* s,
+                                   const Resource& rsc,
+                                   const std::vector<NDArray>& nds,
+                                   NDArray* out) {
+  using namespace mxnet::op;
+  using namespace mxnet::op::mxnet_op;
+  const TBlob& out_data = out->data();
+  MSHADOW_TYPE_SWITCH(out->dtype(), DType, {  // data type
+    for (size_t i = 0; i < nds.size(); ++i) {
+      const NDArray& nd = nds[i];
+      const nnvm::dim_t num_rows = nd.shape()[0];
+      const nnvm::dim_t num_cols = nd.shape()[1];
+      const TBlob& nd_data = nd.data();
+
+      if (i == 0) {
+        if (nd.storage_type() == kDefaultStorage) {
+          Kernel<op_with_req<mshadow_op::identity, kWriteTo>, gpu>::Launch(
+            s, out_data.Size(), out_data.dptr<DType>(), nd_data.dptr<DType>());
+          continue;
+        } else {
+          Kernel<set_zero, gpu>::Launch(s, out_data.Size(), out_data.dptr<DType>());
+        }
+      }
+
+      switch (nd.storage_type()) {
+        case kDefaultStorage: {
+          Kernel<op_with_req<mshadow_op::plus, kWriteTo>, gpu>::Launch(
+            s, out_data.Size(), out_data.dptr<DType>(), out_data.dptr<DType>(),
+            nd_data.dptr<DType>());
+          break;
+        }
+        case kCSRStorage: {
+          const TBlob& nd_indices = nd.aux_data(csr::kIdx);
+          const TBlob& nd_indptr = nd.aux_data(csr::kIndPtr);
+          MSHADOW_IDX_TYPE_SWITCH(nd_indices.type_flag_, IType, {  // indices type
+            MSHADOW_IDX_TYPE_SWITCH(nd_indptr.type_flag_, CType, {  // indptr type
+              if (nd.storage_initialized()) {
+                Kernel<ElemwiseDnsCsrDnsWarpKernel<kWriteTo, mshadow_op::plus>, gpu>::Launch(
+                  s, kWarpSize * num_rows, out_data.dptr<DType>(), out_data.dptr<DType>(),
+                  nd_data.dptr<DType>(), nd_indices.dptr<IType>(),
+                  nd_indptr.dptr<CType>(), num_rows, num_cols);
+              }
+            });
+          });
+          break;
+        }
+        case kRowSparseStorage: {
+          const TBlob& nd_indices = nd.aux_data(rowsparse::kIdx);
+          MSHADOW_IDX_TYPE_SWITCH(nd_indices.type_flag_, IType, {  // indices type
+            if (nd.storage_initialized()) {
+              const nnvm::dim_t nz_rows = nd_indices.Size();
+              Kernel<ElemwiseDnsRspDnsKernel<kWriteTo, mshadow_op::plus>, gpu>::Launch(
+                s, nz_rows * num_cols, out_data.dptr<DType>(),
+                out_data.dptr<DType>(), nd_data.dptr<DType>(), nd_indices.dptr<IType>(),
+                num_rows, nz_rows, num_cols);
+            }
+          });
+          break;
+        }
+        default:
+          LOG(FATAL) << "unknown storage type " << nd.storage_type() << "encountered...";
+      }
+    }
+  });
+}
+
 /*!
  * \brief Parallel gpu impl of elemwise sum for sparse tensors.
 * Supports row sparse sum and the mixed dense/sparse cases dispatched below.
@@ -195,8 +292,15 @@ void ElementwiseSum<gpu>(mshadow::Stream<gpu>* s,
                          const std::vector<NDArray>& nds,
                          NDArray* out) {
   if (nds.empty()) return;
-  if (nds[0].storage_type() == kRowSparseStorage) {
+  if (common::ContainsOnlyStorage(nds, kRowSparseStorage)) {
     ElementwiseSumRspImpl(s, rsc, nds, out);
+  } else if (nds.size() == 3U && nds[0].storage_type() == kDefaultStorage &&
+             nds[1].storage_type() == kCSRStorage && nds[2].storage_type() == kDefaultStorage &&
+             out->storage_type() == kDefaultStorage) {
+    ElementwiseSumDnsCsrDnsImpl(s, rsc, nds, out);
+  } else if (nds.size() > 4U && common::ContainsStorageType(nds, kDefaultStorage) &&
+             out->storage_type() == kDefaultStorage) {
+    ElementwiseSumContainsDnsImpl(s, rsc, nds, out);
   } else {
     LOG(FATAL) << "ElementwiseSum<gpu> has not been implemented for storage_type = << "
         << nds[0].storage_type();
diff --git a/src/operator/elemwise_op_common.h b/src/operator/elemwise_op_common.h
index e22e23c..16aa0c3 100644
--- a/src/operator/elemwise_op_common.h
+++ b/src/operator/elemwise_op_common.h
@@ -73,6 +73,16 @@ inline bool ElemwiseStorageAttr(const nnvm::NodeAttrs& attrs,
     dispatched = storage_type_assign(out_attrs, kCSRStorage,
                                      dispatch_mode, dispatch_ex);
   }
+  if (!dispatched && in_attrs->size() == 3U && in_attrs->at(0) == kDefaultStorage &&
+      in_attrs->at(1) == kCSRStorage && in_attrs->at(2) == kDefaultStorage) {
+    dispatched = storage_type_assign(out_attrs, kDefaultStorage,
+                                     dispatch_mode, dispatch_ex);
+  }
+  if (!dispatched && in_attrs->size() > 4U && ContainsStorageType(*in_attrs, kDefaultStorage)) {
+    // *, dense, * -> dense
+    dispatched = storage_type_assign(out_attrs, kDefaultStorage,
+                                     dispatch_mode, dispatch_ex);
+  }
   if (!dispatched) {
     dispatch_fallback(out_attrs, dispatch_mode);
   }
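
Condensed, the inference added here resolves the two new input patterns to dense output storage before the generic fallback. A Python sketch of the decision order (illustrative pseudocode; the real ElemwiseStorageAttr also handles the uniform-storage cases first):

    def infer_add_n_stype(in_stypes):
        if len(set(in_stypes)) == 1:
            return in_stypes[0]                        # uniform inputs keep their stype
        if in_stypes == ['default', 'csr', 'default']:
            return 'default'                           # new: dns, csr, dns -> dns
        if len(in_stypes) > 4 and 'default' in in_stypes:
            return 'default'                           # new: *, dense, * -> dense
        return 'fallback'                              # cast all inputs to dense first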
diff --git a/src/operator/tensor/elemwise_binary_broadcast_op.h b/src/operator/tensor/elemwise_binary_broadcast_op.h
index e5b77e1..6779826 100644
--- a/src/operator/tensor/elemwise_binary_broadcast_op.h
+++ b/src/operator/tensor/elemwise_binary_broadcast_op.h
@@ -511,8 +511,7 @@ void BinaryBroadcastComputeDenseEx(const nnvm::NodeAttrs& attrs,
     // If the input is a matrix with the same shape, should be elemwise
     if (!ndim) {
       mshadow::Stream<xpu>* s = ctx.get_stream<xpu>();
-      ElemwiseBinaryOp::DnsCsrDnsOp<xpu, OP>(
-        s, attrs, ctx, dns, csr, req[0], outputs[0], !reverse);
+      ElemwiseBinaryOp::DnsCsrDnsOp<OP>(s, attrs, ctx, dns, csr, req[0], outputs[0], !reverse);
     } else {
       // broadcast(CSR, Dense(1D)) = CSR
       BinaryBroadcastCsrDnsDnsImpl<xpu, OP>(ctx, csr, dns, req[0], out,
diff --git a/src/operator/tensor/elemwise_binary_op-inl.h b/src/operator/tensor/elemwise_binary_op-inl.h
index 878dfb2..1b1a1d2 100644
--- a/src/operator/tensor/elemwise_binary_op-inl.h
+++ b/src/operator/tensor/elemwise_binary_op-inl.h
@@ -27,6 +27,9 @@
 #include <vector>
 #include <algorithm>
 #include "./elemwise_binary_op.h"
+#include "../mxnet_op.h"
+#define WARP_SIZE 32
+#define WARP_SIZE_BITS 5
 
 namespace mxnet {
 namespace op {
@@ -426,9 +429,39 @@ struct ElemwiseDnsCsrDnsKernel {
   }
 };
 
+/*!
+ * \brief Warp kernel for performing elemwise op between dense and csr matrix,
+ *        with one warp of WARP_SIZE threads striding over each csr row's nonzeros
+ * \param tid          global thread id
+ * \param req          type of request
+ * \param out          output array
+ * \param dns_data     data array of dense input
+ * \param csr_data     data array of csr input
+ * \param csr_indices  indices array of csr input
+ * \param csr_indptr   indptr array of csr input
+ * \param num_rows     number of rows of both inputs
+ * \param num_cols     number of columns of both inputs
+ */
+template<int req, typename OP>
+struct ElemwiseDnsCsrDnsWarpKernel {
+  template<typename DType, typename IType, typename CType>
+  MSHADOW_XINLINE static void Map(int tid, DType* out, DType* dns_data,
+                                  const DType* csr_data, const IType* csr_indices,
+                                  const CType* csr_indptr, const nnvm::dim_t num_rows,
+                                  const nnvm::dim_t num_cols) {
+    if (tid < WARP_SIZE * num_rows) {
+      const int row_id = tid >> WARP_SIZE_BITS;
+      const int warp_id = tid & (WARP_SIZE - 1);
+      for (int j = csr_indptr[row_id] + warp_id; j < csr_indptr[row_id+1]; j += WARP_SIZE) {
+        KERNEL_ASSIGN(out[row_id * num_cols + csr_indices[j]], req,
+                      OP::Map(dns_data[row_id * num_cols + csr_indices[j]], csr_data[j]));
+      }
+    }
+  }
+};
+
 /*! \brief DNS -op- CSR binary operator for non-canonical NDArray */
-template<typename xpu, typename OP>
-void ElemwiseBinaryOp::DnsCsrDnsOp(mshadow::Stream<xpu> *s,
+template<typename OP>
+void ElemwiseBinaryOp::DnsCsrDnsOp(mshadow::Stream<cpu> *s,
                                    const nnvm::NodeAttrs &attrs,
                                    const OpContext &ctx,
                                    const NDArray &dns,
@@ -455,20 +488,20 @@ void ElemwiseBinaryOp::DnsCsrDnsOp(mshadow::Stream<xpu> *s,
       MSHADOW_IDX_TYPE_SWITCH(csr_indptr.type_flag_, CType, {
         MXNET_ASSIGN_REQ_SWITCH(req, Req, {
           if (reverse && std::is_same<OP, mshadow_op::minus>::value) {
-            mxnet_op::Kernel<mxnet_op::op_with_req<mshadow_op::negation, Req>, xpu>::Launch(
+            mxnet_op::Kernel<mxnet_op::op_with_req<mshadow_op::negation, Req>, cpu>::Launch(
               s, output.data().Size(), output.data().dptr<DType>(), dns.data().dptr<DType>());
             if (!csr.storage_initialized()) { return; }
-            mxnet_op::Kernel<ElemwiseDnsCsrDnsKernel<Req, mshadow_op::plus>, xpu>::Launch(
+            mxnet_op::Kernel<ElemwiseDnsCsrDnsKernel<Req, mshadow_op::plus>, cpu>::Launch(
               s, num_csr_rows, output.data().dptr<DType>(),
               output.data().dptr<DType>(), csr_data.dptr<DType>(), csr_indices.dptr<IType>(),
               csr_indptr.dptr<CType>(), num_csr_rows, num_csr_cols);
           } else {
             if (req == kWriteTo) {
-              mxnet_op::Kernel<mxnet_op::op_with_req<mshadow_op::identity, Req>, xpu>::Launch(
+              mxnet_op::Kernel<mxnet_op::op_with_req<mshadow_op::identity, Req>, cpu>::Launch(
                 s, output.data().Size(), output.data().dptr<DType>(), dns.data().dptr<DType>());
             }
             if (!csr.storage_initialized()) { return; }
-            mxnet_op::Kernel<ElemwiseDnsCsrDnsKernel<Req, OP>, xpu>::Launch(
+            mxnet_op::Kernel<ElemwiseDnsCsrDnsKernel<Req, OP>, cpu>::Launch(
               s, num_csr_rows, output.data().dptr<DType>(),
               output.data().dptr<DType>(), csr_data.dptr<DType>(), csr_indices.dptr<IType>(),
               csr_indptr.dptr<CType>(), num_csr_rows, num_csr_cols);
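
ElemwiseDnsCsrDnsWarpKernel assigns WARP_SIZE (32) consecutive thread ids to each csr row; within a row, lane k handles nonzeros k, k + 32, k + 64, ..., so neighboring threads read neighboring csr entries. A serial Python emulation of that index math (illustrative only):

    WARP_SIZE = 32      # mirrors the WARP_SIZE macro
    WARP_SIZE_BITS = 5  # log2(WARP_SIZE), mirrors WARP_SIZE_BITS

    def warp_kernel_plus(out, dns, csr_data, csr_indices, csr_indptr,
                         num_rows, num_cols):
        for tid in range(WARP_SIZE * num_rows):  # one warp's worth of ids per row
            row = tid >> WARP_SIZE_BITS          # which row this thread serves
            lane = tid & (WARP_SIZE - 1)         # position within the warp
            for j in range(csr_indptr[row] + lane, csr_indptr[row + 1], WARP_SIZE):
                col = csr_indices[j]
                out[row * num_cols + col] = dns[row * num_cols + col] + csr_data[j]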
diff --git a/src/operator/tensor/elemwise_binary_op.h b/src/operator/tensor/elemwise_binary_op.h
index fbd79bb..cb1db0e 100644
--- a/src/operator/tensor/elemwise_binary_op.h
+++ b/src/operator/tensor/elemwise_binary_op.h
@@ -276,8 +276,19 @@ class ElemwiseBinaryOp : public OpBase {
                        const NDArray &output);
 
   /*! \brief DNS -op- CSR binary operator for non-canonical NDArray */
-  template<typename xpu, typename OP>
-  static void DnsCsrDnsOp(mshadow::Stream<xpu> *s,
+  template<typename OP>
+  static void DnsCsrDnsOp(mshadow::Stream<cpu> *s,
+                          const nnvm::NodeAttrs &attrs,
+                          const OpContext &ctx,
+                          const NDArray &lhs,
+                          const NDArray &rhs,
+                          OpReqType req,
+                          const NDArray &output,
+                          const bool reverse);
+
+  /*! \brief DNS -op- CSR binary operator for non-canonical NDArray */
+  template<typename OP>
+  static void DnsCsrDnsOp(mshadow::Stream<gpu> *s,
                           const nnvm::NodeAttrs &attrs,
                           const OpContext &ctx,
                           const NDArray &lhs,
@@ -537,7 +548,7 @@ class ElemwiseBinaryOp : public OpBase {
       const NDArray& csr = (lhs_stype == kCSRStorage)? inputs[0] : inputs[1];
       const bool reverse = (lhs_stype == kCSRStorage);
 
-      DnsCsrDnsOp<xpu, OP>(s, attrs, ctx, dns, csr, req[0], outputs[0], reverse);
+      DnsCsrDnsOp<OP>(s, attrs, ctx, dns, csr, req[0], outputs[0], reverse);
     } else if (((lhs_stype == kRowSparseStorage && rhs_stype == kDefaultStorage) ||
                 (lhs_stype == kDefaultStorage && rhs_stype == kRowSparseStorage)) &&
                 out_stype == kDefaultStorage) {
diff --git a/src/operator/tensor/elemwise_binary_op_basic.cu b/src/operator/tensor/elemwise_binary_op_basic.cu
index ea8c1fb..981e7ab 100644
--- a/src/operator/tensor/elemwise_binary_op_basic.cu
+++ b/src/operator/tensor/elemwise_binary_op_basic.cu
@@ -25,6 +25,7 @@
 #include <cub/cub.cuh>
 #include "./elemwise_binary_op.h"
 #include "./elemwise_binary_op-inl.h"
+#include "./indexing_op.h"
 
 namespace mxnet {
 namespace op {
@@ -162,6 +163,59 @@ void ElemwiseBinaryOp::RspRspOp(mshadow::Stream<gpu> *s,
   });
 }
 
+/*! \brief DNS -op- CSR binary operator for non-canonical NDArray */
+template<typename OP>
+void ElemwiseBinaryOp::DnsCsrDnsOp(mshadow::Stream<gpu> *s,
+                                   const nnvm::NodeAttrs &attrs,
+                                   const OpContext &ctx,
+                                   const NDArray &dns,
+                                   const NDArray &csr,
+                                   const OpReqType req,
+                                   const NDArray &output,
+                                   const bool reverse) {
+  using namespace mshadow;
+  using namespace mxnet_op;
+  CHECK_EQ(dns.storage_type(), kDefaultStorage);
+  CHECK_EQ(csr.storage_type(), kCSRStorage);
+  CHECK(req != kAddTo);
+  CHECK(req != kNullOp);
+  const bool supported_op = std::is_same<OP, mshadow_op::minus>::value ||
+                            std::is_same<OP, mshadow_op::plus>::value;
+  CHECK(supported_op == true);
+  const nnvm::dim_t num_csr_rows = csr.shape()[0];
+  const nnvm::dim_t num_csr_cols = csr.shape()[1];
+  TBlob csr_data = csr.data();
+  TBlob csr_indices = csr.aux_data(csr::kIdx);
+  TBlob csr_indptr = csr.aux_data(csr::kIndPtr);
+  MSHADOW_SGL_DBL_TYPE_SWITCH(csr_data.type_flag_, DType, {
+    MSHADOW_IDX_TYPE_SWITCH(csr_indices.type_flag_, IType, {
+      MSHADOW_IDX_TYPE_SWITCH(csr_indptr.type_flag_, CType, {
+        MXNET_ASSIGN_REQ_SWITCH(req, Req, {
+          if (reverse && std::is_same<OP, mshadow_op::minus>::value) {
+            mxnet_op::Kernel<mxnet_op::op_with_req<mshadow_op::negation, Req>, gpu>::Launch(
+              s, output.data().Size(), output.data().dptr<DType>(), dns.data().dptr<DType>());
+            if (!csr.storage_initialized()) { return; }
+            mxnet_op::Kernel<ElemwiseDnsCsrDnsWarpKernel<Req, mshadow_op::plus>, gpu>::Launch(
+              s, kWarpSize * num_csr_rows, output.data().dptr<DType>(),
+              output.data().dptr<DType>(), csr_data.dptr<DType>(), csr_indices.dptr<IType>(),
+              csr_indptr.dptr<CType>(), num_csr_rows, num_csr_cols);
+          } else {
+            if (req == kWriteTo) {
+              mxnet_op::Kernel<mxnet_op::op_with_req<mshadow_op::identity, Req>, gpu>::Launch(
+                s, output.data().Size(), output.data().dptr<DType>(), dns.data().dptr<DType>());
+            }
+            if (!csr.storage_initialized()) { return; }
+            mxnet_op::Kernel<ElemwiseDnsCsrDnsWarpKernel<Req, OP>, gpu>::Launch(
+              s, kWarpSize * num_csr_rows, output.data().dptr<DType>(),
+              output.data().dptr<DType>(), csr_data.dptr<DType>(), csr_indices.dptr<IType>(),
+              csr_indptr.dptr<CType>(), num_csr_rows, num_csr_cols);
+          }
+        });
+      });
+    });
+  });
+}
+
 NNVM_REGISTER_OP(elemwise_add)
 .set_attr<FCompute>("FCompute<gpu>", ElemwiseBinaryOp::ComputeWithHalf2<gpu, op::mshadow_op::plus>)
 .set_attr<FComputeEx>("FComputeEx<gpu>", ElemwiseBinaryOp::ComputeEx<gpu, op::mshadow_op::plus>);
diff --git a/src/operator/tensor/elemwise_sum.cc b/src/operator/tensor/elemwise_sum.cc
index 8efeb85..9630988 100644
--- a/src/operator/tensor/elemwise_sum.cc
+++ b/src/operator/tensor/elemwise_sum.cc
@@ -114,7 +114,11 @@ void ElementWiseSumComputeExCPU(const nnvm::NodeAttrs& attrs,
   CHECK_EQ(outputs.size(), 1U);
   CHECK_EQ(req.size(), 1U);
   if (req[0] == kNullOp) return;
-  if (inputs[0].storage_type() == kRowSparseStorage) {
+  if (common::ContainsOnlyStorage(inputs, kRowSparseStorage) ||
+      (inputs.size() == 3U && inputs[0].storage_type() == kDefaultStorage &&
+       inputs[1].storage_type() == kCSRStorage && inputs[2].storage_type() == kDefaultStorage) ||
+      (inputs.size() > 4U && common::ContainsStorageType(inputs, kDefaultStorage) &&
+       outputs[0].storage_type() == kDefaultStorage)) {
     mshadow::Stream<cpu>* s = ctx.get_stream<cpu>();
     Resource rsc = ResourceManager::Get()->Request(ctx.run_ctx.get_ctx(),
         ResourceRequest(ResourceRequest::kTempSpace));
@@ -145,7 +149,9 @@ MXNET_ADD_SPARSE_OP_ALIAS(ElementWiseSum)
 The storage type of ``add_n`` output depends on storage types of inputs
 
 - add_n(row_sparse, row_sparse, ..) = row_sparse
-- otherwise, ``add_n`` generates output with default storage
+- add_n(default, csr, default) = default
+- add_n(more than 4 inputs with at least one default-storage input) = default
+- otherwise, ``add_n`` falls back to casting all inputs to default storage and generates a default storage output
 
 )doc" ADD_FILELINE)
 .set_attr_parser(ParamParser<ElementWiseSumParam>)
diff --git a/src/operator/tensor/elemwise_sum.cu b/src/operator/tensor/elemwise_sum.cu
index 820c8d1..e2b41e3 100644
--- a/src/operator/tensor/elemwise_sum.cu
+++ b/src/operator/tensor/elemwise_sum.cu
@@ -38,7 +38,11 @@ void ElementWiseSumComputeExGPU(const nnvm::NodeAttrs& attrs,
   CHECK_EQ(req.size(), 1U);
   if (req[0] == kNullOp) return;
   CHECK_EQ(req[0], kWriteTo) << "ElementWiseSumComputeExGPU only supports req = kWriteTo";
-  if (inputs[0].storage_type() == kRowSparseStorage) {
+  if (common::ContainsOnlyStorage(inputs, kRowSparseStorage) ||
+      (inputs.size() == 3U && inputs[0].storage_type() == kDefaultStorage &&
+       inputs[1].storage_type() == kCSRStorage && inputs[2].storage_type() == kDefaultStorage) ||
+      (inputs.size() > 4U && common::ContainsStorageType(inputs, kDefaultStorage) &&
+       outputs[0].storage_type() == kDefaultStorage)) {
     mshadow::Stream<gpu>* s = ctx.get_stream<gpu>();
     NDArray out_nd = outputs[0];
     mxnet::ndarray::ElementwiseSum<gpu>(s, ctx.requested[0], inputs, &out_nd);
diff --git a/src/operator/tensor/indexing_op-inl.cuh b/src/operator/tensor/indexing_op-inl.cuh
index 34cc263..67dc2bb 100644
--- a/src/operator/tensor/indexing_op-inl.cuh
+++ b/src/operator/tensor/indexing_op-inl.cuh
@@ -27,6 +27,7 @@
 #define MXNET_OPERATOR_TENSOR_INDEXING_OP_CUH_
 #include <cub/device/device_run_length_encode.cuh>
 #include <cub/device/device_scan.cuh>
+#include "../mxnet_op.h"
 
 #if CUDA_VERSION >= 9000
 #define FULLMASK 0xFFFFFFFF
diff --git a/tests/python/unittest/test_sparse_operator.py b/tests/python/unittest/test_sparse_operator.py
index 09546ac..e02121a 100644
--- a/tests/python/unittest/test_sparse_operator.py
+++ b/tests/python/unittest/test_sparse_operator.py
@@ -1731,14 +1731,14 @@ def test_sparse_storage_fallback():
 
 @with_seed()
 def test_sparse_elementwise_sum():
-    def check_sparse_elementwise_sum_with_shape(stype, shape, n):
+    def check_sparse_elementwise_sum_with_shape(stypes, shape, n):
         # forward
         inputs = [mx.symbol.Variable('arg%d' % i) for i in range(n)]
         out = mx.symbol.sparse.add_n(*inputs, name='esum')
         arr = []
-        arr_grad = [mx.nd.empty(shape, stype=stype) for _ in range(n)]
+        arr_grad = [mx.nd.empty(shape, stype=stype) for stype in stypes]
         densities = [0, 0.01, 0.5, 1.0]
-        for i in range(n):
+        for stype in stypes:
             arr.append(rand_ndarray(shape, stype, densities[np.random.randint(0, len(densities))]))
 
         exec1 = out.bind(default_context(),
@@ -1747,18 +1747,30 @@ def test_sparse_elementwise_sum():
         exec1.forward(is_train=True)
         out1 = exec1.outputs[0].asnumpy()
         out = sum(a.asnumpy() for a in arr)
-        assert_almost_equal(out, out1)
+        assert_almost_equal(out, out1, atol=1e-5)
 
         out_grad = mx.nd.empty(shape)
         out_grad[:] = np.random.uniform(-10, 10, shape)
         # backward
         exec1.backward([out_grad])
         for a in arr_grad:
-            assert_almost_equal(a.asnumpy(), out_grad.asnumpy())
+            assert_almost_equal(a.asnumpy(), out_grad.asnumpy(), atol=1e-5)
 
+    all_stypes = ['default', 'csr', 'row_sparse']
     for dim in range(2, 4):
         shape = tuple(np.random.randint(5, 10, size=dim))
-        check_sparse_elementwise_sum_with_shape('row_sparse', shape, np.random.randint(1, 9))
+        rsp_test_cnt = np.random.randint(1, 9)
+        check_sparse_elementwise_sum_with_shape(['row_sparse'] * rsp_test_cnt, shape, rsp_test_cnt)
+        if dim == 2:
+            check_sparse_elementwise_sum_with_shape(['default', 'csr', 'default'], shape, 3)
+            test_len = np.random.randint(5, 10)
+            # at least one default type
+            stypes = ['default']
+            for _ in range(test_len):
+                pick_side = np.random.randint(2)
+                pick_type = np.random.randint(3)
+                stypes = (([all_stypes[pick_type]] if pick_side == 0 else []) + stypes
+                          + ([all_stypes[pick_type]] if pick_side == 1 else []))
+            check_sparse_elementwise_sum_with_shape(stypes, shape, test_len + 1)
 
 
 @with_seed()