You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/01/19 05:10:44 UTC

[GitHub] eric-haibin-lin commented on a change in pull request #9481: add where op with sparse condition

eric-haibin-lin commented on a change in pull request #9481: add where op with sparse condition
URL: https://github.com/apache/incubator-mxnet/pull/9481#discussion_r162538838
 
 

 ##########
 File path: src/operator/tensor/control_flow_op.h
 ##########
 @@ -185,6 +303,62 @@ void WhereOpForward(const nnvm::NodeAttrs& attrs,
   });
 }
 
+template<typename xpu>
+void WhereOpForwardCsrImpl(mshadow::Stream<xpu> *s,
+                           const NDArray& cond,
+                           const TBlob& x,
+                           const TBlob& y,
+                           const OpReqType req,
+                           const TBlob& out) {
+  using namespace mxnet_op;
+  using namespace csr;
+  if (out.Size() == 0 || req == kNullOp) return;
+  CHECK(cond.shape() == x.shape_)
+    << "WhereOpForwardCsrImpl only supports inputs of same 2-D shapes";
+  CHECK(req == kWriteInplace || req == kWriteTo)
+    << "WhereOpForwardCsrImpl doesn't support req = " << req;
+  MSHADOW_TYPE_SWITCH(out.type_flag_, DType, {
+    MSHADOW_TYPE_SWITCH(cond.dtype(), CType, {
+      MSHADOW_TYPE_SWITCH(cond.aux_type(kIdx), IType, {
+        MXNET_ASSIGN_REQ_SWITCH(req, req_type, {
+          mshadow::Copy(out.FlatTo1D<xpu, DType>(s), y.FlatTo1D<xpu, DType>(s), s);
+          // no condition is satisfied
+          if (!cond.storage_initialized()) return;
+          IType* cond_idx = cond.aux_data(kIdx).dptr<IType>();
+          IType* cond_indptr = cond.aux_data(kIndPtr).dptr<IType>();
+          CType* cond_data = cond.data().dptr<CType>();
+          Kernel<where_csr<req_type>, xpu>::Launch(s, cond.shape()[0], out.dptr<DType>(),
+                 cond_idx, cond_indptr, cond_data, cond.shape()[1], x.dptr<DType>());
+        });
+      });
+    });
+  });
+}
+
+template<typename xpu>
+void WhereOpForwardEx(const nnvm::NodeAttrs& attrs,
+                      const OpContext& ctx,
+                      const std::vector<NDArray>& inputs,
+                      const std::vector<OpReqType>& req,
+                      const std::vector<NDArray>& outputs) {
+  CHECK_EQ(inputs.size(), 3U);
+  CHECK_EQ(outputs.size(), 1U);
+  CHECK_EQ(req.size(), 1U);
+  const auto& cond_stype = inputs[0].storage_type();
 
 Review comment:
   Good catch!

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services