You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by ma...@apache.org on 2019/01/18 18:00:07 UTC
[incubator-mxnet] branch master updated: Support populating errors back to MXNet engine in callback (#13922)
This is an automated email from the ASF dual-hosted git repository.
marcoabreu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 0c85665 Support populating errors back to MXNet engine in callback (#13922)
0c85665 is described below
commit 0c8566525fcff31abbb2fdf36343aced58c0a1de
Author: Yuxi Hu <da...@gmail.com>
AuthorDate: Fri Jan 18 09:59:45 2019 -0800
Support populating errors back to MXNet engine in callback (#13922)
* add an optional error_msg in engine on_complete callback
* use dmlc::Error struct to make error population extendable
---
include/mxnet/engine.h | 8 ++++----
src/engine/naive_engine.cc | 3 ++-
src/engine/threaded_engine.cc | 8 ++++++--
src/engine/threaded_engine.h | 3 ++-
4 files changed, 14 insertions(+), 8 deletions(-)
diff --git a/include/mxnet/engine.h b/include/mxnet/engine.h
index e02b995..408a70a 100644
--- a/include/mxnet/engine.h
+++ b/include/mxnet/engine.h
@@ -74,15 +74,15 @@ class CallbackOnComplete {
public:
// use implicit copy and assign
/*! \brief involve the callback */
- inline void operator()() const {
- (*callback_)(engine_, param_);
+ inline void operator()(const dmlc::Error* error = nullptr) const {
+ (*callback_)(engine_, param_, error);
}
private:
/*! \brief engine can see content of callback */
friend class ::mxnet::Engine;
/*! \brief the real callback */
- void (*callback_)(Engine *, void *);
+ void (*callback_)(Engine *, void *, const dmlc::Error *);
/*! \brief the engine class passed to callback */
Engine* engine_;
/*! \brief the parameter set on callback */
@@ -275,7 +275,7 @@ class MXNET_API Engine {
* \param param the paramter passed to callback.
*/
inline CallbackOnComplete CreateCallback(
- void (*callback)(Engine *, void *), void *param) {
+ void (*callback)(Engine *, void *, const dmlc::Error *), void *param) {
CallbackOnComplete ret;
ret.callback_ = callback;
ret.engine_ = this;
diff --git a/src/engine/naive_engine.cc b/src/engine/naive_engine.cc
index daff530..05b72d2 100644
--- a/src/engine/naive_engine.cc
+++ b/src/engine/naive_engine.cc
@@ -208,7 +208,8 @@ class NaiveEngine final : public Engine {
private:
// callback to oncomplete
- static void OnComplete(Engine *engine, void *param) {
+ static void OnComplete(Engine *engine, void *param,
+ const dmlc::Error* error) {
static_cast<NaiveEngine*>(engine)->req_completed_ = true;
}
// whether action is completed
diff --git a/src/engine/threaded_engine.cc b/src/engine/threaded_engine.cc
index 3a7587f..6a60040 100644
--- a/src/engine/threaded_engine.cc
+++ b/src/engine/threaded_engine.cc
@@ -478,10 +478,14 @@ inline void ThreadedEngine::ThrowException(ThreadedVar* threaded_var) {
return;
}
-void ThreadedEngine::OnCompleteStatic(
- Engine *engine, void *opr_block_) {
+void ThreadedEngine::OnCompleteStatic(Engine *engine, void *opr_block_,
+ const dmlc::Error* error) {
OprBlock *opr_block = static_cast<OprBlock*>(opr_block_);
ThreadedOpr *threaded_opr = opr_block->opr;
+ if (error != nullptr) {
+ auto ex_p = std::make_exception_ptr(*error);
+ threaded_opr->opr_exception = std::make_shared<std::exception_ptr>(ex_p);
+ }
if (opr_block->profiling && threaded_opr->opr_name) {
// record operator end timestamp
opr_block->opr_profile->stop();
diff --git a/src/engine/threaded_engine.h b/src/engine/threaded_engine.h
index ccfd09d..fae120d 100644
--- a/src/engine/threaded_engine.h
+++ b/src/engine/threaded_engine.h
@@ -465,7 +465,8 @@ class ThreadedEngine : public Engine {
}
}
- static void OnCompleteStatic(Engine *engine, void *threaded_opr);
+ static void OnCompleteStatic(Engine *engine, void *threaded_opr,
+ const dmlc::Error* error);
/*! \brief append an operator to bulk */
inline void BulkAppend(SyncFn exec_fn, Context exec_ctx,
std::vector<VarHandle> const& const_vars,