Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/03/19 05:07:57 UTC

[GitHub] cjolivier01 commented on a change in pull request #10150: [WIP] [DO NOT MERGE] Sparse operator broadcast_mul/div(csr, dense) = csr
URL: https://github.com/apache/incubator-mxnet/pull/10150#discussion_r175332618
 
 

 ##########
 File path: src/operator/tensor/elemwise_binary_broadcast_op.h
 ##########
 @@ -185,6 +227,75 @@ void BinaryBroadcastCompute(const nnvm::NodeAttrs& attrs,
   }
 }
 
+template<typename xpu, typename OP>
+void BinaryBroadCastCsrDnsCsrImpl(const OpContext& ctx,
+                                  const NDArray& csr,
+                                  const NDArray& dns,
+                                  const OpReqType req,
+                                  const NDArray& output) {
+  using namespace mshadow;
+  using namespace mxnet_op;
+  using namespace csr;
+  CHECK_EQ(dns.shape().ndim(), 1) << "the dense input should be a vector";
+  mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
+  const bool col_vec = (dns.shape()[0] == csr.shape()[0]);
+  if (!csr.storage_initialized()) {
+    FillZerosCsrImpl(s, output);
+    return;
+  }
+  const nnvm::dim_t nnz = csr.storage_shape()[0];
+  const nnvm::dim_t num_rows = output.shape()[0];
+  output.CheckAndAlloc({Shape1(num_rows + 1), Shape1(nnz)});
+
+  MSHADOW_TYPE_SWITCH(output.dtype(), DType, {
+    MSHADOW_IDX_TYPE_SWITCH(output.aux_type(kIdx), CType, {
+      MSHADOW_IDX_TYPE_SWITCH(output.aux_type(kIndPtr), RType, {
+        MXNET_ASSIGN_REQ_SWITCH(req, req_type, {
+          Kernel<csr_dns_csr_broadcast_kernel<DType, CType, RType, req_type, OP>, xpu>::Launch(
+            s, num_rows, csr.data().dptr<DType>(), csr.aux_data(kIdx).dptr<CType>(),
+            csr.aux_data(kIndPtr).dptr<RType>(), dns.data().dptr<DType>(),
+            output.data().dptr<DType>(), csr.shape()[1], col_vec);
+          Copy(output.aux_data(kIdx).FlatTo1D<xpu, CType>(),
+               csr.aux_data(kIdx).FlatTo1D<xpu, CType>());
+          Copy(output.aux_data(kIndPtr).FlatTo1D<xpu, RType>(),
+               csr.aux_data(kIndPtr).FlatTo1D<xpu, RType>());
+        });
+      });
+    });
+  });
+}
+
+template<typename xpu, typename OP>
+void BinaryBroadcastComputeCsrEx(const nnvm::NodeAttrs& attrs,
+                                 const OpContext& ctx,
+                                 const std::vector<NDArray>& inputs,
+                                 const std::vector<OpReqType>& req,
+                                 const std::vector<NDArray>& outputs) {
+  CHECK_EQ(inputs.size(), 2U);
+  CHECK_EQ(outputs.size(), 1U);
+  CHECK_EQ(req.size(), 1U);
+  CHECK_LE(inputs[1].shape().ndim(), 2U) << "the dense input should have at most 2 dimensions";
+  const auto in1_stype = inputs[0].storage_type();
+  const auto in2_stype = inputs[1].storage_type();
+  const auto out_stype = outputs[0].storage_type();
+  if (inputs[1].shape().ndim() != 1U) {
+    ElemwiseBinaryOp::ComputeEx<xpu, OP>(attrs, ctx, inputs, req, outputs);
+  } else {
+    if (req[0] != kNullOp) {
+      // broadcast(CSR, Dense(1D)) = CSR
+      if (in1_stype == kCSRStorage && in2_stype == kDefaultStorage && out_stype == kCSRStorage) {
+        BinaryBroadCastCsrDnsCsrImpl<xpu, OP>(ctx, inputs[0], inputs[1], req[0], outputs[0]);
+      // broadcast(CSR, Dense(1D)) = Dense
+      //} else if (in1_stype == kCSRStorage && in2_stype == kDefaultStorage &&
+      //           out_stype == kDefaultStorage) {
+      //  BinaryBroadCastCsrDnsDnsImpl(ctx, inputs[0], inputs[1], req[0], outputs[0]);
+      } else {
+        LogUnimplementedOp(attrs, ctx, inputs, req, outputs);
 
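The Launch call above references csr_dns_csr_broadcast_kernel, which is defined outside this hunk. Reconstructed purely from the arguments passed to Launch, a minimal sketch of such a kernel's Map function could look like the following; the per-row loop and the KERNEL_ASSIGN usage are illustrative, not necessarily the PR's actual kernel body:

template<typename DType, typename CType, typename RType, int req, typename OP>
struct csr_dns_csr_broadcast_kernel {
  // One thread per CSR row: apply OP to every stored element of the row,
  // pairing it with the dense element chosen by row index (column vector)
  // or by the element's own column index (row vector).
  MSHADOW_XINLINE static void Map(int row, const DType* csr_data,
                                  const CType* csr_indices,
                                  const RType* csr_indptr,
                                  const DType* dns_data, DType* out,
                                  const nnvm::dim_t num_cols,
                                  const bool col_vec) {
    // num_cols mirrors the Launch signature but is not needed here:
    // csr_indices already carries the column of each stored element.
    for (RType i = csr_indptr[row]; i < csr_indptr[row + 1]; ++i) {
      const DType dval = col_vec ? dns_data[row] : dns_data[csr_indices[i]];
      KERNEL_ASSIGN(out[i], req, OP::Map(csr_data[i], dval));
    }
  }
};

Because the kernel only rewrites the nnz value slots and the indices/indptr arrays are copied over verbatim afterwards, the output keeps exactly the sparsity pattern of the input CSR, which is what makes the csr = broadcast_op(csr, dense) storage combination cheap for mul/div.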
 Review comment:
  Would catching this in the storage type inference and then doing a fallback not work for this case?
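For context, the fallback the comment refers to is normally wired up in the operator's FInferStorageType rather than in the FComputeEx body. A minimal sketch of that approach, assuming the standard storage_type_assign / dispatch_fallback helpers from src/operator/operator_common.h; the name BinaryBroadcastStorageType and the exact branch layout are illustrative, not necessarily what this PR registers:

inline bool BinaryBroadcastStorageType(const nnvm::NodeAttrs& attrs,
                                       const int dev_mask,
                                       DispatchMode* dispatch_mode,
                                       std::vector<int>* in_attrs,
                                       std::vector<int>* out_attrs) {
  CHECK_EQ(in_attrs->size(), 2U);
  CHECK_EQ(out_attrs->size(), 1U);
  bool dispatched = false;
  if (common::ContainsOnlyStorage(*in_attrs, kDefaultStorage)) {
    // dense, dense -> dense through the usual FCompute path
    dispatched = storage_type_assign(out_attrs, kDefaultStorage,
                                     dispatch_mode, DispatchMode::kFCompute);
  }
  if (!dispatched && in_attrs->at(0) == kCSRStorage &&
      in_attrs->at(1) == kDefaultStorage) {
    // csr, dense -> csr through BinaryBroadcastComputeCsrEx; the 1-D
    // requirement on the dense input is still checked at compute time,
    // since shapes are not visible during storage type inference
    dispatched = storage_type_assign(out_attrs, kCSRStorage,
                                     dispatch_mode, DispatchMode::kFComputeEx);
  }
  if (!dispatched) {
    // every other storage combination: mark for dense fallback
    dispatched = dispatch_fallback(out_attrs, dispatch_mode);
  }
  return dispatched;
}

With inference set up this way, an unsupported combination never reaches LogUnimplementedOp at run time: the executor converts the sparse inputs to dense and dispatches the ordinary dense kernel instead.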

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services