You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by sa...@apache.org on 2020/12/04 21:37:48 UTC

[incubator-mxnet] branch v1.x updated: Support destructors for custom stateful ops (#19607)

This is an automated email from the ASF dual-hosted git repository.

samskalicky pushed a commit to branch v1.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/v1.x by this push:
     new 35cab31  Support destructors for custom stateful ops (#19607)
35cab31 is described below

commit 35cab31c89170023ae5d0b305af962bb84caefe0
Author: Sam Skalicky <sa...@gmail.com>
AuthorDate: Fri Dec 4 13:35:46 2020 -0800

    Support destructors for custom stateful ops (#19607)
    
    * initial commit
    
    * added warning message
    
    * added destroy function to call delete in custom library
    
    * try removing six per #19604
    
    * fixed lint
    
    * fixed readability
---
 example/extensions/lib_custom_op/gemm_lib.cc       |  6 +++-
 .../extensions/lib_custom_op/transposerowsp_lib.cc |  1 +
 include/mxnet/lib_api.h                            | 42 +++++++++++++++++-----
 src/c_api/c_api.cc                                 | 11 +++++-
 src/lib_api.cc                                     | 14 ++++++++
 5 files changed, 64 insertions(+), 10 deletions(-)

diff --git a/example/extensions/lib_custom_op/gemm_lib.cc b/example/extensions/lib_custom_op/gemm_lib.cc
index f7c2e1a..47efd7b 100644
--- a/example/extensions/lib_custom_op/gemm_lib.cc
+++ b/example/extensions/lib_custom_op/gemm_lib.cc
@@ -184,6 +184,10 @@ class MyStatefulGemm : public CustomStatefulOp {
                           const std::unordered_map<std::string, std::string>& attrs)
     : count(count), attrs_(attrs) {}
 
+  ~MyStatefulGemm() {  // printed when MXNet releases this op state
+    std::cout << "Info: destructing MyStatefulGemm" << std::endl;
+  }
+
   MXReturnValue Forward(std::vector<MXTensor>* inputs,
                         std::vector<MXTensor>* outputs,
                         const OpResource& op_res) {
@@ -210,7 +214,7 @@ MXReturnValue createOpState(const std::unordered_map<std::string, std::string>&
   // testing passing of keyword arguments
   int count = attrs.count("test_kw") > 0 ? std::stoi(attrs.at("test_kw")) : 0;
   // creating stateful operator instance
-  *op_inst = new MyStatefulGemm(count, attrs);
+  *op_inst = CustomStatefulOp::create<MyStatefulGemm>(count, attrs);
   std::cout << "Info: stateful operator created" << std::endl;
   return MX_SUCCESS;
 }
diff --git a/example/extensions/lib_custom_op/transposerowsp_lib.cc b/example/extensions/lib_custom_op/transposerowsp_lib.cc
index 4689a0e..9a4c29a 100644
--- a/example/extensions/lib_custom_op/transposerowsp_lib.cc
+++ b/example/extensions/lib_custom_op/transposerowsp_lib.cc
@@ -185,6 +185,7 @@ MXReturnValue createOpState(const std::unordered_map<std::string, std::string>&
   int count = attrs.count("test_kw") > 0 ? std::stoi(attrs.at("test_kw")) : 0;
   // creating stateful operator instance
   *op_inst = new MyStatefulTransposeRowSP(count, attrs);
+  (*op_inst)->ignore_warn = true;  // allocated with plain 'new' above; skip create() warning
   std::cout << "Info: stateful operator created" << std::endl;
   return MX_SUCCESS;
 }
diff --git a/include/mxnet/lib_api.h b/include/mxnet/lib_api.h
index 2f7864f..57d75bf 100644
--- a/include/mxnet/lib_api.h
+++ b/include/mxnet/lib_api.h
@@ -53,7 +53,7 @@
 #endif
 
 /* Make sure to update the version number everytime you make changes */
-#define MX_LIBRARY_VERSION 10
+#define MX_LIBRARY_VERSION 11
 
 /*!
  * \brief For loading multiple custom op libraries in Linux, exporting same symbol multiple
@@ -686,6 +686,20 @@ class CustomOpSelector {
  */
 class CustomStatefulOp {
  public:
+  CustomStatefulOp();
+  virtual ~CustomStatefulOp();
+
+  /*! \brief allocates an op of type A with 'new' (forwarding args to its constructor) and
+   *  marks it as created, so MXNet does not warn when it later destroys it with 'delete' */
+  template<class A, typename ...Ts>
+  static CustomStatefulOp* create(Ts&&... args) {
+    CustomStatefulOp* op = new A(std::forward<Ts>(args)...);
+    op->created = true;
+    return op;
+  }
+
+  bool wasCreated() const { return created; }  // true only if allocated via create()
+
   virtual MXReturnValue Forward(std::vector<MXTensor>* inputs,
                                 std::vector<MXTensor>* outputs,
                                 const OpResource& op_res) = 0;
@@ -695,15 +709,11 @@ class CustomStatefulOp {
     MX_ERROR_MSG << "Error! Operator does not support backward" << std::endl;
     return MX_FAIL;
   }
-};
 
-/*! \brief StatefulOp wrapper class to pass to backend OpState */
-class CustomStatefulOpWrapper {
- public:
-  explicit CustomStatefulOpWrapper(CustomStatefulOp* inst) : instance(inst) {}
-  CustomStatefulOp* get_instance() { return instance; }
+  bool ignore_warn;  // set true to skip the "created without create()" warning
+
  private:
-  CustomStatefulOp* instance;
+  bool created;  // set to true only by create()
 };
 
 /*! \brief Custom Operator function templates */
@@ -1009,6 +1019,9 @@ typedef int (*opCallCreateOpState_t)(createOpState_t create_op, const char* cons
                                      int dev_id, unsigned int** inshapes, int* indims,
                                      int num_in, const int* intypes, void** state_op);
 
+#define MXLIB_OPCALLDESTROYOPSTATE_STR "_opCallDestroyOpState"
+typedef void (*opCallDestroyOpState_t)(void* state_op);
+
 #define MXLIB_OPCALLFSTATEFULCOMP_STR "_opCallFStatefulCompute"
 typedef int (*opCallFStatefulComp_t)(int is_forward, void* state_op,
                                      const int64_t** inshapes, int* indims,
@@ -1118,6 +1131,21 @@ typedef int (*msgSize_t)(void);
 #define MXLIB_MSGGET_STR "_msgGet"
 typedef int (*msgGet_t)(int idx, const char** msg);
 
+/*! \brief StatefulOp wrapper class to pass to backend OpState */
+class CustomStatefulOpWrapper {
+ public:
+  ~CustomStatefulOpWrapper();
+  explicit CustomStatefulOpWrapper(CustomStatefulOp* inst, opCallDestroyOpState_t destroy)
+    : instance(inst), destroy_(destroy) {}
+  // the destructor releases `instance` through destroy_, so a copy would free it twice
+  CustomStatefulOpWrapper(const CustomStatefulOpWrapper&) = delete;
+  CustomStatefulOpWrapper& operator=(const CustomStatefulOpWrapper&) = delete;
+  CustomStatefulOp* get_instance() { return instance; }
+ private:
+  CustomStatefulOp* instance;
+  opCallDestroyOpState_t destroy_;
+};
+
 #if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__)
 #define MX_INT_RET  __declspec(dllexport) int __cdecl
 #define MX_VOID_RET __declspec(dllexport) void __cdecl
@@ -1200,6 +1228,9 @@ extern "C" {
                                   int dev_id, unsigned int** inshapes, int* indims,
                                   int num_in, const int* intypes, void** state_op);
 
+  /*! \brief deletes a StatefulOp instance for operator from library (returns nothing) */
+  MX_VOID_RET _opCallDestroyOpState(void* state_op);
+
   /*! \brief returns status of calling Stateful Forward/Backward for operator from library */
   MX_INT_RET _opCallFStatefulCompute(int is_forward, void* state_op, const int64_t** inshapes,
                                      int* indims, void** indata, int* intypes, size_t* inIDs,
diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc
index 6c57606..1cb5583 100644
--- a/src/c_api/c_api.cc
+++ b/src/c_api/c_api.cc
@@ -544,6 +544,9 @@ void registerOperators(void *lib, int verbose, mxnet::ext::msgSize_t msgSize,
   opCallCreateOpState_t callCreateOpState =
     get_func<opCallCreateOpState_t>(lib, const_cast<char*>(MXLIB_OPCALLCREATEOPSTATE_STR));
 
+  opCallDestroyOpState_t callDestroyOpState =
+    get_func<opCallDestroyOpState_t>(lib, const_cast<char*>(MXLIB_OPCALLDESTROYOPSTATE_STR));
+
   opCallFStatefulComp_t callFStatefulComp =
     get_func<opCallFStatefulComp_t>(lib, const_cast<char*>(MXLIB_OPCALLFSTATEFULCOMP_STR));
 
@@ -1171,7 +1174,13 @@ void registerOperators(void *lib, int verbose, mxnet::ext::msgSize_t msgSize,
      << "Error custom library failed to create stateful operator '" << name_str << "'" << msgs;

       CustomStatefulOp* state_op = reinterpret_cast<CustomStatefulOp*>(state_op_inst);
-      return OpStatePtr::Create<CustomStatefulOpWrapper>(state_op);
+      if (!state_op->wasCreated() && !state_op->ignore_warn)
+        LOG(WARNING) << "WARNING! Custom stateful op " << state_op_inst << " was created without "
+                     << "calling CustomStatefulOp::create(). Please ensure this object was "
+                     << "allocated with 'new' since it will be destructed with 'delete'. "
+                     << "To suppress this message without calling CustomStatefulOp::create() "
+                     << "set ignore_warn to 'true' on custom stateful op instance.";
+      return OpStatePtr::Create<CustomStatefulOpWrapper>(state_op, callDestroyOpState);
     };
 
     /* -------------- BELOW IS THE REGISTRATION FOR CUSTOM OPERATORS --------------- */
diff --git a/src/lib_api.cc b/src/lib_api.cc
index c273678..49c739b 100644
--- a/src/lib_api.cc
+++ b/src/lib_api.cc
@@ -858,6 +858,13 @@ void mxnet::ext::CustomOp::raiseDuplicateContextError() {
     + op_name_str + "'");
 }
 
+mxnet::ext::CustomStatefulOp::CustomStatefulOp() : ignore_warn{false}, created{false} {}
+mxnet::ext::CustomStatefulOp::~CustomStatefulOp() = default;
+
+mxnet::ext::CustomStatefulOpWrapper::~CustomStatefulOpWrapper() {
+  this->destroy_(this->instance);  // let the library that allocated the op delete it
+}
+
 mxnet::ext::CustomPass::CustomPass() : name("ERROR") {}
 mxnet::ext::CustomPass::CustomPass(const char* pass_name)
   : name(pass_name) {}
@@ -1248,6 +1255,13 @@ MX_INT_RET _opCallCreateOpState(mxnet::ext::createOpState_t create_op, const cha
   return create_op(attrs, ctx, in_shapes, in_types, op_ptr);
 }
 
+/*! \brief deletes a StatefulOp instance made by the library, running its virtual destructor */
+MX_VOID_RET _opCallDestroyOpState(void* state_op) {
+  // state_op is the opaque pointer handed out by _opCallCreateOpState
+  auto* const op_ptr = static_cast<mxnet::ext::CustomStatefulOp*>(state_op);
+  delete op_ptr;
+}
+
 /*! \brief returns status of calling Stateful Forward/Backward for operator from library */
 MX_INT_RET _opCallFStatefulCompute(int is_forward, void* state_op, const int64_t** inshapes,
                                    int* indims, void** indata, int* intypes, size_t* inIDs,