You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by ha...@apache.org on 2019/07/18 00:00:39 UTC
[incubator-mxnet] 03/42: Enable np op compat check with name prefix (#14897)
This is an automated email from the ASF dual-hosted git repository.
haoj pushed a commit to branch numpy
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
commit 0eec6fc66775b9aafed1eba1b033b8049f824e15
Author: reminisce <wu...@gmail.com>
AuthorDate: Mon May 6 16:56:36 2019 -0700
Enable np op compat check with name prefix (#14897)
---
src/c_api/c_api_common.h | 17 ++++++++++++++++-
src/operator/numpy/np_broadcast_reduce_op_value.cc | 3 +--
2 files changed, 17 insertions(+), 3 deletions(-)
diff --git a/src/c_api/c_api_common.h b/src/c_api/c_api_common.h
index 118341d..ab1f5f7 100644
--- a/src/c_api/c_api_common.h
+++ b/src/c_api/c_api_common.h
@@ -163,10 +163,25 @@ inline void CopyAttr(const nnvm::IndexedGraph& idx,
extern const std::vector<std::string> kHiddenKeys;
} // namespace mxnet
+/*!
+ * An operator is considered numpy compatible if it satisfies either one
+ * of the following conditions.
+ * 1. The op has the attribute mxnet::TIsNumpyCompatible registered as true.
+ * 2. The op's name begins with "_numpy_" (so _backward_numpy_sum does not match).
+ * The first condition is usually for the ops registered as internal ops, such
+ * as _np_add, _true_divide, etc. They are wrapped by some user-facing op
+ * APIs in the Python frontend.
+ * The second condition is for the ops registered in the backend while exposed
+ * directly to users as is, such as _numpy_sum etc.
+ */
inline bool IsNumpyCompatOp(const nnvm::Op* op) {
static const auto& is_np_compat =
nnvm::Op::GetAttr<mxnet::TIsNumpyCompatible>("TIsNumpyCompatible");
- return is_np_compat.get(op, false);
+ if (is_np_compat.get(op, false)) {
+ return true;
+ }
+ static const std::string prefix = "_numpy_";  // prefix match, not a substring search
+ return op->name.compare(0, prefix.size(), prefix) == 0;
}
#endif // MXNET_C_API_C_API_COMMON_H_
diff --git a/src/operator/numpy/np_broadcast_reduce_op_value.cc b/src/operator/numpy/np_broadcast_reduce_op_value.cc
index 13b575a..6c81bf6 100644
--- a/src/operator/numpy/np_broadcast_reduce_op_value.cc
+++ b/src/operator/numpy/np_broadcast_reduce_op_value.cc
@@ -65,8 +65,7 @@ NNVM_REGISTER_OP(_numpy_sum)
[](const NodeAttrs& attrs) {
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
-.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_numpy_sum"})
-.set_attr<mxnet::TIsNumpyCompatible>("TIsNumpyCompatible", true);
+.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_numpy_sum"});
NNVM_REGISTER_OP(_backward_numpy_sum)
.set_num_outputs(1)