You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by ma...@apache.org on 2019/01/18 18:00:07 UTC

[incubator-mxnet] branch master updated: Support populating errors back to MXNet engine in callback (#13922)

This is an automated email from the ASF dual-hosted git repository.

marcoabreu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 0c85665  Support populating errors back to MXNet engine in callback (#13922)
0c85665 is described below

commit 0c8566525fcff31abbb2fdf36343aced58c0a1de
Author: Yuxi Hu <da...@gmail.com>
AuthorDate: Fri Jan 18 09:59:45 2019 -0800

    Support populating errors back to MXNet engine in callback (#13922)
    
    * add an optional error_msg in engine on_complete callback
    
    * use dmlc::Error struct to make error population extendable
---
 include/mxnet/engine.h        | 8 ++++----
 src/engine/naive_engine.cc    | 3 ++-
 src/engine/threaded_engine.cc | 8 ++++++--
 src/engine/threaded_engine.h  | 3 ++-
 4 files changed, 14 insertions(+), 8 deletions(-)

diff --git a/include/mxnet/engine.h b/include/mxnet/engine.h
index e02b995..408a70a 100644
--- a/include/mxnet/engine.h
+++ b/include/mxnet/engine.h
@@ -74,15 +74,15 @@ class CallbackOnComplete {
  public:
   // use implicit copy and assign
   /*! \brief involve the callback */
-  inline void operator()() const {
-    (*callback_)(engine_, param_);
+  inline void operator()(const dmlc::Error* error = nullptr) const {
+    (*callback_)(engine_, param_, error);
   }
 
  private:
   /*! \brief engine can see content of callback */
   friend class ::mxnet::Engine;
   /*! \brief the real callback */
-  void (*callback_)(Engine *, void *);
+  void (*callback_)(Engine *, void *, const dmlc::Error *);
   /*! \brief the engine class passed to callback */
   Engine* engine_;
   /*! \brief the parameter set on callback */
@@ -275,7 +275,7 @@ class MXNET_API Engine {
    * \param param the paramter passed to callback.
    */
   inline CallbackOnComplete CreateCallback(
-      void (*callback)(Engine *, void *), void *param) {
+      void (*callback)(Engine *, void *, const dmlc::Error *), void *param) {
     CallbackOnComplete ret;
     ret.callback_ = callback;
     ret.engine_ = this;
diff --git a/src/engine/naive_engine.cc b/src/engine/naive_engine.cc
index daff530..05b72d2 100644
--- a/src/engine/naive_engine.cc
+++ b/src/engine/naive_engine.cc
@@ -208,7 +208,8 @@ class NaiveEngine final : public Engine {
 
  private:
   // callback to oncomplete
-  static void OnComplete(Engine *engine, void *param) {
+  static void OnComplete(Engine *engine, void *param,
+                         const dmlc::Error* error) {
     static_cast<NaiveEngine*>(engine)->req_completed_ = true;
   }
   // whether action is completed
diff --git a/src/engine/threaded_engine.cc b/src/engine/threaded_engine.cc
index 3a7587f..6a60040 100644
--- a/src/engine/threaded_engine.cc
+++ b/src/engine/threaded_engine.cc
@@ -478,10 +478,14 @@ inline void ThreadedEngine::ThrowException(ThreadedVar* threaded_var) {
   return;
 }
 
-void ThreadedEngine::OnCompleteStatic(
-    Engine *engine, void *opr_block_) {
+void ThreadedEngine::OnCompleteStatic(Engine *engine, void *opr_block_,
+                                      const dmlc::Error* error) {
   OprBlock *opr_block = static_cast<OprBlock*>(opr_block_);
   ThreadedOpr *threaded_opr = opr_block->opr;
+  if (error != nullptr) {
+    auto ex_p = std::make_exception_ptr(*error);
+    threaded_opr->opr_exception = std::make_shared<std::exception_ptr>(ex_p);
+  }
   if (opr_block->profiling && threaded_opr->opr_name) {
     // record operator end timestamp
     opr_block->opr_profile->stop();
diff --git a/src/engine/threaded_engine.h b/src/engine/threaded_engine.h
index ccfd09d..fae120d 100644
--- a/src/engine/threaded_engine.h
+++ b/src/engine/threaded_engine.h
@@ -465,7 +465,8 @@ class ThreadedEngine : public Engine {
     }
   }
 
-  static void OnCompleteStatic(Engine *engine, void *threaded_opr);
+  static void OnCompleteStatic(Engine *engine, void *threaded_opr,
+                               const dmlc::Error* error);
   /*! \brief append an operator to bulk */
   inline void BulkAppend(SyncFn exec_fn, Context exec_ctx,
                          std::vector<VarHandle> const& const_vars,