You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by ha...@apache.org on 2019/07/18 00:00:39 UTC

[incubator-mxnet] 03/42: Enable np op compat check with name prefix (#14897)

This is an automated email from the ASF dual-hosted git repository.

haoj pushed a commit to branch numpy
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git

commit 0eec6fc66775b9aafed1eba1b033b8049f824e15
Author: reminisce <wu...@gmail.com>
AuthorDate: Mon May 6 16:56:36 2019 -0700

    Enable np op compat check with name prefix (#14897)
---
 src/c_api/c_api_common.h                           | 17 ++++++++++++++++-
 src/operator/numpy/np_broadcast_reduce_op_value.cc |  3 +--
 2 files changed, 17 insertions(+), 3 deletions(-)

diff --git a/src/c_api/c_api_common.h b/src/c_api/c_api_common.h
index 118341d..ab1f5f7 100644
--- a/src/c_api/c_api_common.h
+++ b/src/c_api/c_api_common.h
@@ -163,10 +163,25 @@ inline void CopyAttr(const nnvm::IndexedGraph& idx,
 extern const std::vector<std::string> kHiddenKeys;
 }  // namespace mxnet
 
+/*!
+ * An operator is considered numpy compatible if it satisfies either one
+ * of the following conditions.
+ * 1. The op has the attribute mxnet::TIsNumpyCompatible registered as true.
+ * 2. The op's name starts with "_numpy_" (so _backward_numpy_sum does not qualify).
+ * The first condition is usually for the ops registered as internal ops, such
+ * as _np_add, _true_divide, etc. They are wrapped by some user-facing op
+ * APIs in the Python end.
+ * The second condition is for the ops registered in the backend and exposed
+ * directly to users as is, such as _numpy_sum.
+ */
 inline bool IsNumpyCompatOp(const nnvm::Op* op) {
   static const auto& is_np_compat =
       nnvm::Op::GetAttr<mxnet::TIsNumpyCompatible>("TIsNumpyCompatible");
-  return is_np_compat.get(op, false);
+  if (is_np_compat.get(op, false)) {
+    return true;
+  }
+  static const std::string prefix = "_numpy_";
+  return op->name.compare(0, prefix.size(), prefix) == 0;  // anchored at index 0, not a substring search
 }
 
 #endif  // MXNET_C_API_C_API_COMMON_H_
diff --git a/src/operator/numpy/np_broadcast_reduce_op_value.cc b/src/operator/numpy/np_broadcast_reduce_op_value.cc
index 13b575a..6c81bf6 100644
--- a/src/operator/numpy/np_broadcast_reduce_op_value.cc
+++ b/src/operator/numpy/np_broadcast_reduce_op_value.cc
@@ -65,8 +65,7 @@ NNVM_REGISTER_OP(_numpy_sum)
   [](const NodeAttrs& attrs) {
     return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
   })
-.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_numpy_sum"})
-.set_attr<mxnet::TIsNumpyCompatible>("TIsNumpyCompatible", true);
+.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_numpy_sum"});
 
 NNVM_REGISTER_OP(_backward_numpy_sum)
 .set_num_outputs(1)