Posted to commits@mxnet.apache.org by jx...@apache.org on 2018/01/31 03:07:24 UTC
[incubator-mxnet] branch master updated: add where op with sparse condition (#9481)
This is an automated email from the ASF dual-hosted git repository.
jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 8e9798c add where op with sparse condition (#9481)
8e9798c is described below
commit 8e9798cb9739cbb0eb3d05894a113b2f3921f069
Author: Haibin Lin <li...@gmail.com>
AuthorDate: Tue Jan 30 19:07:14 2018 -0800
add where op with sparse condition (#9481)
* add where op with csr condition
* remove typo
* fix lint
* fix doc
* rebase
* fix doc
* CR comments
* fix
---
docs/api/python/ndarray/sparse.md | 1 +
src/operator/tensor/control_flow_op.cc | 44 +++--
src/operator/tensor/control_flow_op.cu | 6 +-
src/operator/tensor/control_flow_op.h | 250 ++++++++++++++++++++++++++
tests/python/unittest/test_sparse_operator.py | 96 ++++++++++
5 files changed, 384 insertions(+), 13 deletions(-)
diff --git a/docs/api/python/ndarray/sparse.md b/docs/api/python/ndarray/sparse.md
index 9fca3c1..a7aaa1f 100644
--- a/docs/api/python/ndarray/sparse.md
+++ b/docs/api/python/ndarray/sparse.md
@@ -372,6 +372,7 @@ We summarize the interface for each class in the following sections.
slice
retain
+ where
```
## Mathematical functions
diff --git a/src/operator/tensor/control_flow_op.cc b/src/operator/tensor/control_flow_op.cc
index 9e1091e..164fd6a 100644
--- a/src/operator/tensor/control_flow_op.cc
+++ b/src/operator/tensor/control_flow_op.cc
@@ -28,16 +28,33 @@ namespace mxnet {
namespace op {
NNVM_REGISTER_OP(where)
-.MXNET_DESCRIBE("Given three ndarrays, condition, x, and y, return an ndarray"
- " with the elements from x or y, depending on the elements"
- " from condition are true or false. x and y must have the same"
- " shape. If condition has the same shape as x, each element"
- " in the output array is from x if the corresponding element"
- " in the condition is true, and from y if false. If condition"
- " does not have the same shape as x, it must be a 1D array"
- " whose size is the same as x's first dimension size. Each"
- " row of the output array is from x's row if the corresponding"
- " element from condition is true, and from y's row if false.")
+MXNET_ADD_SPARSE_OP_ALIAS(where)
+.describe(R"code(Return the elements, either from x or y, depending on the condition.
+
+Given three ndarrays, condition, x, and y, return an ndarray with the elements from x or y,
+depending on whether the elements from condition are true or false. x and y must have the same shape.
+If condition has the same shape as x, each element in the output array is from x if the
+corresponding element in the condition is true, and from y if false.
+
+If condition does not have the same shape as x, it must be a 1D array whose size is
+the same as x's first dimension size. Each row of the output array is from x's row
+if the corresponding element from condition is true, and from y's row if false.
+
+Note that all non-zero values are interpreted as ``True`` in condition.
+
+Examples::
+
+ x = [[1, 2], [3, 4]]
+ y = [[5, 6], [7, 8]]
+ cond = [[0, 1], [-1, 0]]
+
+ where(cond, x, y) = [[5, 2], [3, 8]]
+
+ csr_cond = cast_storage(cond, 'csr')
+
+ where(csr_cond, x, y) = [[5, 2], [3, 8]]
+
+)code" ADD_FILELINE)
.set_num_inputs(3)
.set_num_outputs(1)
.set_attr<nnvm::FListInputNames>("FListInputNames",
@@ -46,7 +63,9 @@ NNVM_REGISTER_OP(where)
})
.set_attr<nnvm::FInferShape>("FInferShape", WhereOpShape)
.set_attr<nnvm::FInferType>("FInferType", WhereOpType)
+.set_attr<FInferStorageType>("FInferStorageType", WhereOpForwardStorageType)
.set_attr<FCompute>("FCompute<cpu>", WhereOpForward<cpu>)
+.set_attr<FComputeEx>("FComputeEx<cpu>", WhereOpForwardEx<cpu>)
.set_attr<nnvm::FGradient>("FGradient",
// Use the following lambda function instead of ElemwiseGradUseIn
// for best efficiency. grad[condition] = 0; to calculate grad[x] and grad[y]
@@ -83,7 +102,10 @@ NNVM_REGISTER_OP(_backward_where)
.set_num_inputs(2)
.set_num_outputs(2)
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
-.set_attr<FCompute>("FCompute<cpu>", WhereOpBackward<cpu>);
+.set_attr<FInferStorageType>("FInferStorageType", WhereOpBackwardStorageType)
+.set_attr<FCompute>("FCompute<cpu>", WhereOpBackward<cpu>)
+.set_attr<FComputeEx>("FComputeEx<cpu>", WhereOpBackwardEx<cpu>);
+
} // namespace op
} // namespace mxnet
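A minimal Python usage sketch of the operator registered above, mirroring the docstring example (assumes an MXNet build that contains this commit; mx.nd.where and NDArray.tostype('csr') are the existing front-end entry points):

    import mxnet as mx

    x = mx.nd.array([[1, 2], [3, 4]])
    y = mx.nd.array([[5, 6], [7, 8]])
    cond = mx.nd.array([[0, 1], [-1, 0]])

    # Dense condition: handled by the existing FCompute path.
    print(mx.nd.where(cond, x, y).asnumpy())        # [[5. 2.] [3. 8.]]

    # CSR condition: dispatched to the FComputeEx path added in this commit.
    csr_cond = cond.tostype('csr')
    print(mx.nd.where(csr_cond, x, y).asnumpy())    # expected to match the dense result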
diff --git a/src/operator/tensor/control_flow_op.cu b/src/operator/tensor/control_flow_op.cu
index cc5198d..7c34a65 100644
--- a/src/operator/tensor/control_flow_op.cu
+++ b/src/operator/tensor/control_flow_op.cu
@@ -28,10 +28,12 @@ namespace mxnet {
namespace op {
NNVM_REGISTER_OP(where)
-.set_attr<FCompute>("FCompute<gpu>", WhereOpForward<gpu>);
+.set_attr<FCompute>("FCompute<gpu>", WhereOpForward<gpu>)
+.set_attr<FComputeEx>("FComputeEx<gpu>", WhereOpForwardEx<gpu>);
NNVM_REGISTER_OP(_backward_where)
-.set_attr<FCompute>("FCompute<gpu>", WhereOpBackward<gpu>);
+.set_attr<FCompute>("FCompute<gpu>", WhereOpBackward<gpu>)
+.set_attr<FComputeEx>("FComputeEx<gpu>", WhereOpBackwardEx<gpu>);
} // namespace op
} // namespace mxnet
diff --git a/src/operator/tensor/control_flow_op.h b/src/operator/tensor/control_flow_op.h
index f1136c8..503bc7c 100644
--- a/src/operator/tensor/control_flow_op.h
+++ b/src/operator/tensor/control_flow_op.h
@@ -31,6 +31,7 @@
#include "../mxnet_op.h"
#include "../operator_common.h"
#include "../elemwise_op_common.h"
+#include "../tensor/init_op.h"
namespace mxnet {
namespace op {
@@ -51,6 +52,35 @@ struct where {
}
};
+/*! \brief Choose elements from x or y depending on condition.
+ * The condition, x, and y have the same shape.
+ * The returned array is formed by elements from x or y
+ * depending on the elements of condition.
+ * The condition is a csr matrix, while x and y are both dense.
+ */
+template<int req>
+struct where_csr {
+ // DType is the output data type
+ // CType is condition data type
+ // i is for i-th row in the output
+ template<typename DType, typename CType, typename IType>
+ MSHADOW_XINLINE static void Map(int i, DType* out, const IType* cond_idx,
+ const IType* cond_indptr, const CType* cond_data,
+ const nnvm::dim_t num_cols, const DType* x) {
+ using nnvm::dim_t;
+ const dim_t offset = i * num_cols;
+ for (dim_t j = cond_indptr[i]; j < cond_indptr[i + 1]; j++) {
+ const CType data = cond_data[j];
+ if (data != 0) {
+ const IType col_idx = cond_idx[j];
+ const dim_t out_idx = offset + col_idx;
+ KERNEL_ASSIGN(out[out_idx], req, x[out_idx]);
+ }
+ }
+ }
+};
+
+
/*! \brief Choose elements from x or y depending on condition
* The condition is a vector whose size is the same as the
* x's first dim size.
@@ -92,6 +122,37 @@ struct where_backward {
* template argument req is OpReqType; negate indicates
* whether the output is grad_x (negate=true)
* or grad_y (negate=false).
+ * cond is a csr matrix, while others are dense ones.
+ */
+template<int req, bool negate>
+struct where_backward_csr {
+ // DType is the output data type
+ // CType is condition data type
+ // IType is condition aux data type
+ template<typename DType, typename CType, typename IType>
+ MSHADOW_XINLINE static void Map(int i, DType* grad_out,
+ const DType* grad_in,
+ const CType* cond_data,
+ const IType* cond_idx,
+ const IType* cond_indptr,
+ const nnvm::dim_t num_cols) {
+ const IType offset = i * num_cols;
+ const DType zero = static_cast<DType>(0);
+ for (IType j = cond_indptr[i]; j < cond_indptr[i + 1]; j++) {
+ const IType col = cond_idx[j];
+ const IType grad_offset = offset + col;
+ KERNEL_ASSIGN(grad_out[grad_offset], req,
+ ((0 == cond_data[j])^negate)? grad_in[grad_offset] : zero);
+ }
+ }
+};
+
+
+/*!
+ * \brief Template for calculating grad[x] and grad[y].
+ * template argument req is OpReqType; negate indicates
+ * whether the output is grad_x (negate=true)
+ * or grad_y (negate=false).
* The condition is a vector whose size is the same as the
* x's first dim size.
*/
@@ -152,6 +213,63 @@ inline bool WhereOpType(const nnvm::NodeAttrs& attrs,
return true;
}
+inline bool WhereOpForwardStorageType(const nnvm::NodeAttrs& attrs,
+ const int dev_mask,
+ DispatchMode* dispatch_mode,
+ std::vector<int>* in_attrs,
+ std::vector<int>* out_attrs) {
+ CHECK_EQ(in_attrs->size(), 3U);
+ CHECK_EQ(out_attrs->size(), 1U);
+ const int cond_stype = in_attrs->at(0);
+ const int x_stype = in_attrs->at(1);
+ const int y_stype = in_attrs->at(2);
+ auto& out_stype = out_attrs->at(0);
+ bool dispatched = false;
+ if (!dispatched && common::ContainsOnlyStorage(*in_attrs, kDefaultStorage)) {
+ // dns, dns -> dns
+ dispatched = storage_type_assign(&out_stype, kDefaultStorage, dispatch_mode,
+ DispatchMode::kFCompute);
+ }
+ if (!dispatched && cond_stype == kCSRStorage && x_stype == kDefaultStorage &&
+ y_stype == kDefaultStorage) {
+ // csr, dns, dns -> dns
+ dispatched = storage_type_assign(&out_stype, kDefaultStorage,
+ dispatch_mode, DispatchMode::kFComputeEx);
+ }
+ if (!dispatched) {
+ dispatched = dispatch_fallback(out_attrs, dispatch_mode);
+ }
+ return dispatched;
+}
+
+inline bool WhereOpBackwardStorageType(const nnvm::NodeAttrs& attrs,
+ const int dev_mask,
+ DispatchMode* dispatch_mode,
+ std::vector<int>* in_attrs,
+ std::vector<int>* out_attrs) {
+ CHECK_EQ(in_attrs->size(), 2U);
+ CHECK_EQ(out_attrs->size(), 2U);
+ const auto in_grad_stype = in_attrs->at(0);
+ const auto cond_stype = in_attrs->at(1);
+ bool dispatched = false;
+ if (!dispatched && common::ContainsOnlyStorage(*in_attrs, kDefaultStorage)) {
+ // dns, dns -> dns, dns
+ dispatched = storage_type_assign(out_attrs, kDefaultStorage,
+ dispatch_mode, DispatchMode::kFCompute);
+ }
+ if (!dispatched && cond_stype == kCSRStorage && in_grad_stype == kDefaultStorage) {
+ // dns, csr -> dns, dns
+ dispatched = storage_type_assign(out_attrs, kDefaultStorage,
+ dispatch_mode, DispatchMode::kFComputeEx);
+ }
+ if (!dispatched) {
+ dispatched = dispatch_fallback(out_attrs, dispatch_mode);
+ }
+ return dispatched;
+}
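The two storage-type inference functions above reduce to a small dispatch table; the following Python sketch is purely illustrative (the function names are hypothetical, not MXNet API) and only restates the rules encoded above:

    # Hypothetical helpers restating the dispatch rules; illustrative only.
    def where_forward_dispatch(cond_stype, x_stype, y_stype):
        if (cond_stype, x_stype, y_stype) == ('default', 'default', 'default'):
            return 'default out', 'FCompute'      # dns, dns, dns -> dns
        if cond_stype == 'csr' and x_stype == y_stype == 'default':
            return 'default out', 'FComputeEx'    # csr, dns, dns -> dns
        return 'default out', 'fallback'          # anything else falls back

    def where_backward_dispatch(grad_in_stype, cond_stype):
        if (grad_in_stype, cond_stype) == ('default', 'default'):
            return 'FCompute'                     # dns, dns -> dns, dns
        if cond_stype == 'csr' and grad_in_stype == 'default':
            return 'FComputeEx'                   # dns, csr -> dns, dns
        return 'fallback'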
+
+
+
template<typename xpu>
void WhereOpForward(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
@@ -185,6 +303,62 @@ void WhereOpForward(const nnvm::NodeAttrs& attrs,
});
}
+template<typename xpu>
+void WhereOpForwardCsrImpl(mshadow::Stream<xpu> *s,
+ const NDArray& cond,
+ const TBlob& x,
+ const TBlob& y,
+ const OpReqType req,
+ const TBlob& out) {
+ using namespace mxnet_op;
+ using namespace csr;
+ if (out.Size() == 0 || req == kNullOp) return;
+ CHECK(cond.shape() == x.shape_)
+ << "WhereOpForwardCsrImpl only supports inputs of same 2-D shapes";
+ CHECK(req == kWriteInplace || req == kWriteTo)
+ << "WhereOpForwardCsrImpl doesn't support req = " << req;
+ MSHADOW_TYPE_SWITCH(out.type_flag_, DType, {
+ MSHADOW_TYPE_SWITCH(cond.dtype(), CType, {
+ MSHADOW_TYPE_SWITCH(cond.aux_type(kIdx), IType, {
+ MXNET_ASSIGN_REQ_SWITCH(req, req_type, {
+ mshadow::Copy(out.FlatTo1D<xpu, DType>(s), y.FlatTo1D<xpu, DType>(s), s);
+ // no condition is satisfied
+ if (!cond.storage_initialized()) return;
+ IType* cond_idx = cond.aux_data(kIdx).dptr<IType>();
+ IType* cond_indptr = cond.aux_data(kIndPtr).dptr<IType>();
+ CType* cond_data = cond.data().dptr<CType>();
+ Kernel<where_csr<req_type>, xpu>::Launch(s, cond.shape()[0], out.dptr<DType>(),
+ cond_idx, cond_indptr, cond_data, cond.shape()[1], x.dptr<DType>());
+ });
+ });
+ });
+ });
+}
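WhereOpForwardCsrImpl above amounts to: copy y into the output, then overwrite every position that holds a non-zero condition entry with the corresponding value from x. A NumPy/SciPy sketch of the same logic (illustrative only, not the kernel itself; assumes scipy.sparse is available):

    import numpy as np
    import scipy.sparse as sp

    def where_csr_forward(cond_csr, x, y):
        out = y.copy()                                   # mshadow::Copy(out, y)
        indptr, indices, data = cond_csr.indptr, cond_csr.indices, cond_csr.data
        for i in range(cond_csr.shape[0]):               # one kernel invocation per row
            for j in range(indptr[i], indptr[i + 1]):    # stored entries of row i
                if data[j] != 0:                         # explicitly stored zeros are skipped
                    out[i, indices[j]] = x[i, indices[j]]
        return out

    cond = sp.csr_matrix(np.array([[0, 1], [-1, 0]]))
    x = np.array([[1, 2], [3, 4]])
    y = np.array([[5, 6], [7, 8]])
    print(where_csr_forward(cond, x, y))                 # [[5 2] [3 8]]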
+
+template<typename xpu>
+void WhereOpForwardEx(const nnvm::NodeAttrs& attrs,
+ const OpContext& ctx,
+ const std::vector<NDArray>& inputs,
+ const std::vector<OpReqType>& req,
+ const std::vector<NDArray>& outputs) {
+ CHECK_EQ(inputs.size(), 3U);
+ CHECK_EQ(outputs.size(), 1U);
+ CHECK_EQ(req.size(), 1U);
+ const int cond_stype = inputs[0].storage_type();
+ const int x_stype = inputs[1].storage_type();
+ const int y_stype = inputs[2].storage_type();
+ const auto& out_stype = outputs[0].storage_type();
+ mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
+ CHECK_NE(inputs[0].shape().ndim(), 1) << "WhereOpForwardEx with 1-D cond is not implemented";
+ if (cond_stype == kCSRStorage && x_stype == kDefaultStorage &&
+ y_stype == kDefaultStorage && out_stype == kDefaultStorage) {
+ WhereOpForwardCsrImpl(s, inputs[0], inputs[1].data(), inputs[2].data(), req[0],
+ outputs[0].data());
+ } else {
+ LogUnimplementedOp(attrs, ctx, inputs, req, outputs);
+ }
+}
+
/*!
* \brief Compute the gradient of the loss function
* with respect to condition, x, and y. The gradient
@@ -239,6 +413,82 @@ void WhereOpBackward(const nnvm::NodeAttrs& attrs,
});
}
+template<typename xpu>
+void WhereOpBackwardCsrImpl(mshadow::Stream<xpu> *s,
+ const TBlob& grad_in,
+ const NDArray& cond,
+ const std::vector<OpReqType>& req,
+ const TBlob& grad_x,
+ const TBlob& grad_y) {
+ using namespace mxnet_op;
+ using namespace csr;
+ if (grad_in.Size() == 0) return;
+ CHECK(cond.shape() == grad_x.shape_)
+ << "WhereOpForwardCsrImpl only supports inputs of same 2-D shapes";
+ CHECK_NE(req[0], kAddTo) << "WhereOpForwardCsrImpl doesn't support kAddTo";
+ CHECK_NE(req[1], kAddTo) << "WhereOpForwardCsrImpl doesn't support kAddTo";
+ MSHADOW_TYPE_SWITCH(grad_in.type_flag_, DType, {
+ MSHADOW_TYPE_SWITCH(cond.dtype(), CType, {
+ MSHADOW_IDX_TYPE_SWITCH(cond.aux_type(kIdx), IType, {
+ if (req[0] != kNullOp) {
+ Fill<false>(s, grad_x, req[0], 0);
+ // some conditions are satisfied
+ if (cond.storage_initialized()) {
+ const IType* cond_indptr = cond.aux_data(kIndPtr).dptr<IType>();
+ const IType* cond_idx = cond.aux_data(kIdx).dptr<IType>();
+ const CType* cond_data = cond.data().dptr<CType>();
+ MXNET_ASSIGN_REQ_SWITCH(req[0], req_type_x, {
+ Kernel<where_backward_csr<req_type_x, true>, xpu>::Launch(s, cond.shape()[0],
+ grad_x.dptr<DType>(), grad_in.dptr<DType>(), cond_data, cond_idx,
+ cond_indptr, cond.shape()[1]);
+ });
+ }
+ }
+ if (req[1] != kNullOp) {
+ mshadow::Copy(grad_y.FlatTo1D<xpu, DType>(s), grad_in.FlatTo1D<xpu, DType>(s), s);
+ CHECK_EQ(req[1], kWriteTo);
+ if (cond.storage_initialized()) {
+ const IType* cond_indptr = cond.aux_data(kIndPtr).dptr<IType>();
+ const IType* cond_idx = cond.aux_data(kIdx).dptr<IType>();
+ const CType* cond_data = cond.data().dptr<CType>();
+ MXNET_ASSIGN_REQ_SWITCH(req[1], req_type_y, {
+ Kernel<where_backward_csr<req_type_y, false>, xpu>::Launch(s, cond.shape()[0],
+ grad_y.dptr<DType>(), grad_in.dptr<DType>(), cond_data, cond_idx,
+ cond_indptr, cond.shape()[1]);
+ });
+ }
+ }
+ });
+ });
+ });
+}
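WhereOpBackwardCsrImpl is the mirror image: grad_x starts at zero and grad_y starts as a copy of grad_in, then every stored non-zero condition entry flips the pair at that position. A NumPy/SciPy sketch of that logic (illustrative only; kAddTo is not modelled, matching the CHECKs above):

    import numpy as np
    import scipy.sparse as sp

    def where_csr_backward(grad_in, cond_csr):
        grad_x = np.zeros_like(grad_in)      # Fill<false>(s, grad_x, req[0], 0)
        grad_y = grad_in.copy()              # mshadow::Copy(grad_y, grad_in)
        indptr, indices, data = cond_csr.indptr, cond_csr.indices, cond_csr.data
        for i in range(cond_csr.shape[0]):
            for j in range(indptr[i], indptr[i + 1]):
                if data[j] != 0:
                    col = indices[j]
                    grad_x[i, col] = grad_in[i, col]   # where_backward_csr, negate=true
                    grad_y[i, col] = 0                 # where_backward_csr, negate=false
        return grad_x, grad_y

    cond = sp.csr_matrix(np.array([[0, 1], [-1, 0]]))
    grad_in = np.array([[10, 20], [30, 40]])
    gx, gy = where_csr_backward(grad_in, cond)
    print(gx)   # [[ 0 20] [30  0]]
    print(gy)   # [[10  0] [ 0 40]]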
+
+template<typename xpu>
+void WhereOpBackwardEx(const nnvm::NodeAttrs& attrs,
+ const OpContext& ctx,
+ const std::vector<NDArray>& inputs,
+ const std::vector<OpReqType>& req,
+ const std::vector<NDArray>& outputs) {
+ CHECK_EQ(inputs.size(), 2U);
+ CHECK_EQ(req.size(), 2U);
+ CHECK_EQ(outputs.size(), 2U);
+ mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
+ if (inputs[1].shape().ndim() == 1) {
+ LOG(FATAL) << "WhereOpBackwardEx with 1-D cond is not implemented";
+ }
+ const auto grad_in_stype = inputs[0].storage_type();
+ const auto cond_stype = inputs[1].storage_type();
+ const auto grad_x_stype = outputs[0].storage_type();
+ const auto grad_y_stype = outputs[1].storage_type();
+ if (grad_in_stype == kDefaultStorage && cond_stype == kCSRStorage &&
+ grad_x_stype == kDefaultStorage && grad_y_stype == kDefaultStorage) {
+ WhereOpBackwardCsrImpl(s, inputs[0].data(), inputs[1], req, outputs[0].data(),
+ outputs[1].data());
+ } else {
+ LogUnimplementedOp(attrs, ctx, inputs, req, outputs);
+ }
+}
+
} // namespace op
} // namespace mxnet
diff --git a/tests/python/unittest/test_sparse_operator.py b/tests/python/unittest/test_sparse_operator.py
index 134cb26..84dfc58 100644
--- a/tests/python/unittest/test_sparse_operator.py
+++ b/tests/python/unittest/test_sparse_operator.py
@@ -1798,6 +1798,102 @@ def test_scatter_ops():
lambda l, r: l + r,
rhs_is_scalar=True, verbose=False, density=0.5)
+def test_sparse_nd_where():
+ def get_forward_expected_output(condition, x, y):
+ original_shape = x.shape
+ out = np.zeros(original_shape)
+ if condition.shape == x.shape:
+ for index, c in np.ndenumerate(condition):
+ if c != 0:
+ out[index] = x[index]
+ else:
+ out[index] = y[index]
+ else:
+ raise RuntimeError("Invalid condition shape for where op")
+
+ out = out.reshape(original_shape)
+ return out
+
+ def get_forward_inputs_same_shape(shape):
+ condition_np = np.random.randint(0, 2, np.prod(shape)).reshape(shape)
+ x_np = np.random.randint(1, 6, np.prod(shape)).reshape(shape)
+ y_np = np.random.randint(7, 11, np.prod(shape)).reshape(shape)
+ return condition_np, x_np, y_np
+
+ def get_backward_input(shape):
+ return np.random.randint(20, 30, np.prod(shape)).reshape(shape)
+
+ def get_backward_expected_outputs(grad_in, condition):
+ shape = grad_in.shape
+ grad_cond = np.zeros(condition.shape)
+ grad_x = np.empty(shape)
+ grad_y = np.empty(shape)
+
+ for index, c in np.ndenumerate(condition):
+ if 0 != c:
+ grad_x[index] = grad_in[index]
+ grad_y[index] = 0
+ else:
+ grad_x[index] = 0
+ grad_y[index] = grad_in[index]
+
+ return grad_cond, grad_x, grad_y
+
+ def test_where_helper(shape):
+ condition_np, x_np, y_np = get_forward_inputs_same_shape(shape)
+
+ out_expected = get_forward_expected_output(condition_np, x_np, y_np)
+
+ grad_in_np = get_backward_input(shape)
+ grad_expected_cond, grad_expected_x, grad_expected_y \
+ = get_backward_expected_outputs(grad_in_np, condition_np)
+
+ condition = mx.sym.Variable('condition', stype='csr')
+ x = mx.sym.Variable('x')
+ y = mx.sym.Variable('y')
+ grad_in_mx = mx.nd.array(grad_in_np, dtype=np.int32)
+ where_sym = mx.sym.where(condition, x, y)
+
+ # test req='write'
+ where_exe_write = where_sym.simple_bind(ctx=default_context(),
+ condition=condition_np.shape,
+ x=x_np.shape, y=y_np.shape,
+ grad_req='write')
+ # test forward req='write'
+ cond_nd = mx.nd.array(condition_np).tostype('csr')
+ outputs = where_exe_write.forward(is_train=True, \
+ condition=cond_nd, x=x_np, y=y_np)
+ assert same(outputs[0].asnumpy(), out_expected)
+ # test backward req='write'
+ where_exe_write.backward(grad_in_mx)
+ assert same(where_exe_write.grad_dict['x'].asnumpy(), grad_expected_x)
+ assert same(where_exe_write.grad_dict['y'].asnumpy(), grad_expected_y)
+ assert same(where_exe_write.grad_dict['condition'].asnumpy(), grad_expected_cond)
+
+ # test req='add'
+ x_grad_init = np.random.randint(30, 40, np.prod(shape)).reshape(shape)
+ y_grad_init = np.random.randint(40, 50, np.prod(shape)).reshape(shape)
+ where_exe_add = where_sym.simple_bind(ctx=default_context(),
+ condition=cond_nd.shape,
+ x=x_np.shape, y=y_np.shape,
+ grad_req='add')
+ where_exe_add.grad_dict['x'][:] = x_grad_init
+ where_exe_add.grad_dict['y'][:] = y_grad_init
+ # test forward req='add'
+ outputs = where_exe_add.forward(is_train=True, condition=cond_nd, x=x_np, y=y_np)
+ assert same(outputs[0].asnumpy(), out_expected)
+
+ def test_where_numeric_gradient(shape):
+ condition = mx.sym.Variable('condition', stype='csr')
+ x = mx.sym.Variable('x')
+ y = mx.sym.Variable('y')
+ where_sym = mx.sym.where(condition, x, y)
+ condition_np, x_np, y_np = get_forward_inputs_same_shape(shape)
+ check_numeric_gradient(where_sym, [condition_np, x_np, y_np], grad_nodes=['x', 'y'])
+
+ test_where_helper((5, 9))
+ test_where_numeric_gradient((5, 9))
+
if __name__ == '__main__':
import nose
--
To stop receiving notification emails like this one, please contact
jxie@apache.org.