You are viewing a plain-text version of this content; the canonical link to it (a hyperlink in the original page) is not preserved in this copy.
Posted to commits@mxnet.apache.org by sa...@apache.org on 2020/12/04 21:37:48 UTC
[incubator-mxnet] branch v1.x updated: Support destructors for custom stateful ops (#19607)
This is an automated email from the ASF dual-hosted git repository.
samskalicky pushed a commit to branch v1.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/v1.x by this push:
new 35cab31 Support destructors for custom stateful ops (#19607)
35cab31 is described below
commit 35cab31c89170023ae5d0b305af962bb84caefe0
Author: Sam Skalicky <sa...@gmail.com>
AuthorDate: Fri Dec 4 13:35:46 2020 -0800
Support destructors for custom stateful ops (#19607)
* initial commit
* added warning message
* added destroy function to call delete in custom library
* try removing six per #19604
* fixed lint
* fixed readability
---
example/extensions/lib_custom_op/gemm_lib.cc | 6 +++-
.../extensions/lib_custom_op/transposerowsp_lib.cc | 1 +
include/mxnet/lib_api.h | 42 +++++++++++++++++-----
src/c_api/c_api.cc | 11 +++++-
src/lib_api.cc | 14 ++++++++
5 files changed, 64 insertions(+), 10 deletions(-)
diff --git a/example/extensions/lib_custom_op/gemm_lib.cc b/example/extensions/lib_custom_op/gemm_lib.cc
index f7c2e1a..47efd7b 100644
--- a/example/extensions/lib_custom_op/gemm_lib.cc
+++ b/example/extensions/lib_custom_op/gemm_lib.cc
@@ -184,6 +184,10 @@ class MyStatefulGemm : public CustomStatefulOp {
const std::unordered_map<std::string, std::string>& attrs)
: count(count), attrs_(attrs) {}
+ ~MyStatefulGemm() {
+ std::cout << "Info: destructing MyStatefulGemm" << std::endl;
+ }
+
MXReturnValue Forward(std::vector<MXTensor>* inputs,
std::vector<MXTensor>* outputs,
const OpResource& op_res) {
@@ -210,7 +214,7 @@ MXReturnValue createOpState(const std::unordered_map<std::string, std::string>&
// testing passing of keyword arguments
int count = attrs.count("test_kw") > 0 ? std::stoi(attrs.at("test_kw")) : 0;
// creating stateful operator instance
- *op_inst = new MyStatefulGemm(count, attrs);
+ *op_inst = CustomStatefulOp::create<MyStatefulGemm>(count, attrs);
std::cout << "Info: stateful operator created" << std::endl;
return MX_SUCCESS;
}
diff --git a/example/extensions/lib_custom_op/transposerowsp_lib.cc b/example/extensions/lib_custom_op/transposerowsp_lib.cc
index 4689a0e..9a4c29a 100644
--- a/example/extensions/lib_custom_op/transposerowsp_lib.cc
+++ b/example/extensions/lib_custom_op/transposerowsp_lib.cc
@@ -185,6 +185,7 @@ MXReturnValue createOpState(const std::unordered_map<std::string, std::string>&
int count = attrs.count("test_kw") > 0 ? std::stoi(attrs.at("test_kw")) : 0;
// creating stateful operator instance
*op_inst = new MyStatefulTransposeRowSP(count, attrs);
+ (*op_inst)->ignore_warn = true;
std::cout << "Info: stateful operator created" << std::endl;
return MX_SUCCESS;
}
diff --git a/include/mxnet/lib_api.h b/include/mxnet/lib_api.h
index 2f7864f..57d75bf 100644
--- a/include/mxnet/lib_api.h
+++ b/include/mxnet/lib_api.h
@@ -53,7 +53,7 @@
#endif
/* Make sure to update the version number everytime you make changes */
-#define MX_LIBRARY_VERSION 10
+#define MX_LIBRARY_VERSION 11
/*!
* \brief For loading multiple custom op libraries in Linux, exporting same symbol multiple
@@ -686,6 +686,18 @@ class CustomOpSelector {
*/
class CustomStatefulOp {
public:
+ CustomStatefulOp();
+ virtual ~CustomStatefulOp();
+
+ template<class A, typename ...Ts>
+ static CustomStatefulOp* create(Ts...args) {
+ CustomStatefulOp* op = new A(args...);
+ op->created = true;
+ return op;
+ }
+
+ bool wasCreated() { return created; }
+
virtual MXReturnValue Forward(std::vector<MXTensor>* inputs,
std::vector<MXTensor>* outputs,
const OpResource& op_res) = 0;
@@ -695,15 +707,11 @@ class CustomStatefulOp {
MX_ERROR_MSG << "Error! Operator does not support backward" << std::endl;
return MX_FAIL;
}
-};
-/*! \brief StatefulOp wrapper class to pass to backend OpState */
-class CustomStatefulOpWrapper {
- public:
- explicit CustomStatefulOpWrapper(CustomStatefulOp* inst) : instance(inst) {}
- CustomStatefulOp* get_instance() { return instance; }
+ bool ignore_warn;
+
private:
- CustomStatefulOp* instance;
+ bool created;
};
/*! \brief Custom Operator function templates */
@@ -1009,6 +1017,9 @@ typedef int (*opCallCreateOpState_t)(createOpState_t create_op, const char* cons
int dev_id, unsigned int** inshapes, int* indims,
int num_in, const int* intypes, void** state_op);
+#define MXLIB_OPCALLDESTROYOPSTATE_STR "_opCallDestroyOpState"
+typedef void (*opCallDestroyOpState_t)(void* state_op);  // _opCallDestroyOpState returns void
+
#define MXLIB_OPCALLFSTATEFULCOMP_STR "_opCallFStatefulCompute"
typedef int (*opCallFStatefulComp_t)(int is_forward, void* state_op,
const int64_t** inshapes, int* indims,
@@ -1118,6 +1129,18 @@ typedef int (*msgSize_t)(void);
#define MXLIB_MSGGET_STR "_msgGet"
typedef int (*msgGet_t)(int idx, const char** msg);
+/*! \brief StatefulOp wrapper passed to backend OpState; destructor calls destroy_(instance) */
+class CustomStatefulOpWrapper {
+ public:
+ ~CustomStatefulOpWrapper();
+ explicit CustomStatefulOpWrapper(CustomStatefulOp* inst, opCallDestroyOpState_t destroy)
+ : instance(inst), destroy_(destroy) {}
+ CustomStatefulOp* get_instance() { return instance; }
+ private:
+ CustomStatefulOp* instance;
+ opCallDestroyOpState_t destroy_;
+};
+
#if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__)
#define MX_INT_RET __declspec(dllexport) int __cdecl
#define MX_VOID_RET __declspec(dllexport) void __cdecl
@@ -1200,6 +1223,9 @@ extern "C" {
int dev_id, unsigned int** inshapes, int* indims,
int num_in, const int* intypes, void** state_op);
+ /*! \brief deletes StatefulOp instance for operator from library; returns no status */
+ MX_VOID_RET _opCallDestroyOpState(void* state_op);
+
/*! \brief returns status of calling Stateful Forward/Backward for operator from library */
MX_INT_RET _opCallFStatefulCompute(int is_forward, void* state_op, const int64_t** inshapes,
int* indims, void** indata, int* intypes, size_t* inIDs,
diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc
index 6c57606..1cb5583 100644
--- a/src/c_api/c_api.cc
+++ b/src/c_api/c_api.cc
@@ -544,6 +544,9 @@ void registerOperators(void *lib, int verbose, mxnet::ext::msgSize_t msgSize,
opCallCreateOpState_t callCreateOpState =
get_func<opCallCreateOpState_t>(lib, const_cast<char*>(MXLIB_OPCALLCREATEOPSTATE_STR));
+ opCallDestroyOpState_t callDestroyOpState =
+ get_func<opCallDestroyOpState_t>(lib, const_cast<char*>(MXLIB_OPCALLDESTROYOPSTATE_STR));
+
opCallFStatefulComp_t callFStatefulComp =
get_func<opCallFStatefulComp_t>(lib, const_cast<char*>(MXLIB_OPCALLFSTATEFULCOMP_STR));
@@ -1171,7 +1174,13 @@ void registerOperators(void *lib, int verbose, mxnet::ext::msgSize_t msgSize,
<< "Error custom library failed to create stateful operator '" << name_str << "'" << msgs;
CustomStatefulOp* state_op = reinterpret_cast<CustomStatefulOp*>(state_op_inst);
- return OpStatePtr::Create<CustomStatefulOpWrapper>(state_op);
+ if (!state_op->wasCreated() && !state_op->ignore_warn)
+ LOG(INFO) << "WARNING! Custom stateful op " << state_op_inst << " was created without "
+ << "calling CustomStatefulOp::create(). Please ensure this object was "
+ << "allocated with 'new' since it will be destructed with 'delete'. "
+ << "To suppress this message without calling CustomStatefulOp::create() "
+ << "set ignore_warn to 'true' on custom stateful op instance.";
+ return OpStatePtr::Create<CustomStatefulOpWrapper>(state_op, callDestroyOpState);
};
/* -------------- BELOW IS THE REGISTRATION FOR CUSTOM OPERATORS --------------- */
diff --git a/src/lib_api.cc b/src/lib_api.cc
index c273678..49c739b 100644
--- a/src/lib_api.cc
+++ b/src/lib_api.cc
@@ -858,6 +858,13 @@ void mxnet::ext::CustomOp::raiseDuplicateContextError() {
+ op_name_str + "'");
}
+mxnet::ext::CustomStatefulOp::CustomStatefulOp() : ignore_warn(false), created(false) {}
+mxnet::ext::CustomStatefulOp::~CustomStatefulOp() {}
+
+mxnet::ext::CustomStatefulOpWrapper::~CustomStatefulOpWrapper() {
+ destroy_(instance);
+}
+
mxnet::ext::CustomPass::CustomPass() : name("ERROR") {}
mxnet::ext::CustomPass::CustomPass(const char* pass_name)
: name(pass_name) {}
@@ -1248,6 +1255,13 @@ MX_INT_RET _opCallCreateOpState(mxnet::ext::createOpState_t create_op, const cha
return create_op(attrs, ctx, in_shapes, in_types, op_ptr);
}
+/*! \brief calls StatefulOp destructor for operator from library */
+MX_VOID_RET _opCallDestroyOpState(void* state_op) {
+ mxnet::ext::CustomStatefulOp* op_ptr =
+ reinterpret_cast<mxnet::ext::CustomStatefulOp*>(state_op);
+ delete op_ptr;
+}
+
/*! \brief returns status of calling Stateful Forward/Backward for operator from library */
MX_INT_RET _opCallFStatefulCompute(int is_forward, void* state_op, const int64_t** inshapes,
int* indims, void** indata, int* intypes, size_t* inIDs,