You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by zh...@apache.org on 2021/10/30 15:07:19 UTC

[incubator-mxnet] branch master updated: Add async GPU dependency Engine (#20331)

This is an automated email from the ASF dual-hosted git repository.

zhasheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 922d9f5  Add async GPU dependency Engine (#20331)
922d9f5 is described below

commit 922d9f59a013e7bb1ac17c3eed34174b26585fbe
Author: Serge Panev <sp...@nvidia.com>
AuthorDate: Sat Oct 30 08:01:45 2021 -0700

    Add async GPU dependency Engine (#20331)
    
    * Add async GPU dependency Engine
    
    Signed-off-by: Serge Panev <sp...@nvidia.com>
    
    * Temporarily skip byteps test
    
    Signed-off-by: Serge Panev <sp...@nvidia.com>
    
    * Fix typo
    
    Signed-off-by: Serge Panev <sp...@nvidia.com>
    
    * Fix lint
    
    Signed-off-by: Serge Panev <sp...@nvidia.com>
    
    * Fix bad cast
    
    Signed-off-by: Serge Panev <sp...@nvidia.com>
    
    * Move Async engine tag to MXNET_ENGINE_TYPE
    
    Signed-off-by: Serge Panev <sp...@nvidia.com>
    
    * clang-format
    
    Signed-off-by: Serge Panev <sp...@nvidia.com>
    
    * Fix rebase errors
    
    Signed-off-by: Serge Panev <sp...@nvidia.com>
---
 ci/jenkins/Jenkinsfile_unix_gpu          |   3 +-
 include/mxnet/base.h                     |   4 +-
 include/mxnet/c_api.h                    |   2 +-
 include/mxnet/engine.h                   | 133 +++++++++++++++++-
 include/mxnet/storage.h                  |  19 ++-
 src/c_api/c_api.cc                       |  11 +-
 src/common/object_pool.h                 |   4 +-
 src/engine/engine.cc                     |  28 +++-
 src/engine/naive_engine.cc               |  15 ++-
 src/engine/stream_manager.h              |  24 +++-
 src/engine/threaded_engine.cc            | 222 ++++++++++++++++++++++++++++++-
 src/engine/threaded_engine.h             |  38 ++++--
 src/engine/threaded_engine_perdevice.cc  |  58 +++++++-
 src/engine/threaded_engine_pooled.cc     |  29 +++-
 src/imperative/imperative_utils.h        |  75 +++++------
 src/io/batchify.cc                       |   2 +-
 src/io/dataset.cc                        |   4 +-
 src/kvstore/comm.h                       |  23 +++-
 src/kvstore/gradient_compression.cc      |   4 -
 src/kvstore/kvstore_dist.h               |  21 ++-
 src/kvstore/kvstore_dist_server.h        |  10 +-
 src/kvstore/kvstore_local.h              |   7 +-
 src/kvstore/p3store_dist.h               |   7 +-
 src/ndarray/ndarray.cc                   | 137 +++++++++++++------
 src/operator/custom/ndarray_op.cc        |  10 +-
 src/operator/operator_util.cc            |  15 ---
 src/resource.cc                          |  15 ++-
 src/storage/gpu_device_storage.h         |   8 ++
 src/storage/pooled_storage_manager.h     |  46 ++++---
 src/storage/storage.cc                   |   2 +-
 tests/cpp/engine/threaded_engine_test.cc |  59 +++++---
 tests/python/gpu/test_gluon_gpu.py       |  86 ------------
 tests/python/gpu/test_operator_gpu.py    |  73 ----------
 33 files changed, 829 insertions(+), 365 deletions(-)

diff --git a/ci/jenkins/Jenkinsfile_unix_gpu b/ci/jenkins/Jenkinsfile_unix_gpu
index 2beb0f4..53224e9 100644
--- a/ci/jenkins/Jenkinsfile_unix_gpu
+++ b/ci/jenkins/Jenkinsfile_unix_gpu
@@ -49,7 +49,8 @@ core_logic: {
     custom_steps.test_unix_cpp_package_gpu('gpu'),
     // TODO(szha): fix and reenable the hanging issue. tracked in #18098
     // custom_steps.test_unix_distributed_kvstore_gpu('gpu'),
-    custom_steps.test_unix_byteps_gpu('gpu'),
+    // TODO(spanev): reenable when byteps is updated with the new dep engine API
+    // custom_steps.test_unix_byteps_gpu('gpu'),
   ]) 
 }
 ,
diff --git a/include/mxnet/base.h b/include/mxnet/base.h
index e374523..dc428da 100644
--- a/include/mxnet/base.h
+++ b/include/mxnet/base.h
@@ -348,9 +348,9 @@ struct RunContext {
    */
   void *aux_stream;
   /*!
-   * \brief indicator of whether this execution is run in bulk mode
+   * \brief pointer to the cuda event pool used by the dependency engine
    */
-  bool is_bulk;
+  void* event_pool = nullptr;
   /*!
    * \brief get mshadow stream from Context
    * \return the mshadow stream
diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h
index 6e668a4..0aff747 100644
--- a/include/mxnet/c_api.h
+++ b/include/mxnet/c_api.h
@@ -110,7 +110,7 @@ typedef const void *EngineFnPropertyHandle;
 typedef void *EngineVarHandle;
 
 /*! \brief Engine asynchronous operation */
-typedef void (*EngineAsyncFunc)(void*, void*, void*);
+typedef void (*EngineAsyncFunc)(void*, void*, void*, void*);
 /*! \brief Engine synchronous operation */
 typedef void (*EngineSyncFunc)(void*, void*);
 /*! \brief Callback to free the param for EngineAsyncFunc/EngineSyncFunc */
diff --git a/include/mxnet/engine.h b/include/mxnet/engine.h
index 1a9582e..cdb8998 100644
--- a/include/mxnet/engine.h
+++ b/include/mxnet/engine.h
@@ -29,6 +29,7 @@
 #include <memory>
 #include <functional>
 #endif
+#include <utility>
 #include <vector>
 #include "./base.h"
 
@@ -39,6 +40,73 @@ class Engine;
 
 /*! \brief namespace of engine internal types. */
 namespace engine {
+#if MXNET_USE_CUDA
+/* \brief The class wrapping CUDA event with timing disabled. */
+class CUDAEvent final {
+ public:
+  explicit CUDAEvent(Context const& ctx);
+
+  CUDAEvent(CUDAEvent&& other) : event_(other.event_), dev_id_(other.dev_id_) {
+    other.event_ = nullptr;
+  }
+
+  CUDAEvent(const CUDAEvent& other) = delete;
+  void operator=(const CUDAEvent& other) = delete;
+
+  ~CUDAEvent();
+
+  inline std::weak_ptr<cudaEvent_t> GetEvent() noexcept {
+    return event_;
+  }
+
+ private:
+  std::shared_ptr<cudaEvent_t> event_;
+  int dev_id_;
+};
+
+class CUDAEventPool final {
+ public:
+  explicit CUDAEventPool(Context const& ctx) : counter_(0) {
+    for (size_t i = 0; i < kPoolSize; ++i) {
+      events_.emplace_back(ctx);
+    }
+  }
+
+  inline std::weak_ptr<cudaEvent_t> GetEvent(size_t i) noexcept {
+    return events_.at(i).GetEvent();
+  }
+
+  inline std::pair<std::weak_ptr<cudaEvent_t>, uint64_t> GetNextEvent() noexcept {
+    uint64_t c = counter_++;
+    return {events_.at((c) % kPoolSize).GetEvent(), c};
+  }
+
+  inline uint64_t GetCounterValue() noexcept {
+    return counter_.load();
+  }
+
+ private:
+  static constexpr size_t kPoolSize = 64;
+  std::vector<CUDAEvent> events_;
+  std::atomic<uint64_t> counter_;
+};
+
+/*! \brief full event info for the sync object.*/
+struct EventInfo {
+  std::weak_ptr<cudaEvent_t> event;
+  cudaStream_t stream;
+  uint64_t pool_index;
+};
+/*! \brief struct containing cuda events and variables needed for the dependencies.*/
+struct SyncObject {
+  // vector can carry multiple reader events
+  std::vector<EventInfo> reader_events;
+  // vector should carry only 1 writer event
+  std::vector<EventInfo> writer_event;
+  std::mutex mutex;
+};
+#endif
+
 /*! \brief base class of engine variables.*/
 struct Var {
   virtual size_t version() {
@@ -57,6 +125,12 @@ struct Var {
    * is modified, the version number is incremented by 1.
    */
   size_t version_{0};
+#if MXNET_USE_CUDA
+  /*!
+   * \brief struct containing cuda events and variables needed for the dependencies.
+   */
+  SyncObject sync_object;
+#endif
 };  // struct Var
 
 /*! \brief Internal representation of operator.  */
@@ -66,6 +140,29 @@ typedef Var* VarHandle;
 /*! \brief Operator pointer type, usually hold by user.*/
 typedef Opr* OprHandle;
 /*!
+ * \brief OnStart callback to the engine,
+ *  called by AsyncFn before the action
+ */
+class CallbackOnStart {
+ public:
+  // use implicit copy and assign
+  /*! \brief involve the callback */
+  inline void operator()(const dmlc::Error* error = nullptr) const {
+    if (callback_ != nullptr)
+      (*callback_)(engine_, param_, error);
+  }
+
+ private:
+  /*! \brief engine can see content of callback */
+  friend class ::mxnet::Engine;
+  /*! \brief the real callback */
+  void (*callback_)(Engine*, void*, const dmlc::Error*);
+  /*! \brief the engine class passed to callback */
+  Engine* engine_;
+  /*! \brief the parameter set on callback */
+  void* param_;
+};
+/*!
  * \brief OnComplete Callback to the engine,
  *  called by AsyncFn when action completes
  */
@@ -115,12 +212,14 @@ enum class FnProperty {
 */
 class MXNET_API Engine {
  public:
+  /*! \brief on start*/
+  typedef engine::CallbackOnStart CallbackOnStart;
   /*! \brief callback on complete*/
   typedef engine::CallbackOnComplete CallbackOnComplete;
   /*! \brief Synchronous operation to pass to engine. */
   typedef std::function<void(RunContext)> SyncFn;
   /*! \brief Asynchronous operation to pass to engine. */
-  typedef std::function<void(RunContext, CallbackOnComplete)> AsyncFn;
+  typedef std::function<void(RunContext, CallbackOnStart, CallbackOnComplete)> AsyncFn;
   /*! \brief Variable pointer */
   typedef engine::VarHandle VarHandle;
   /*! \brief Operator pointer */
@@ -247,7 +346,7 @@ class MXNET_API Engine {
    *
    * \return A shared pointer to Engine singleton.
    */
-  static std::shared_ptr<Engine> _GetSharedRef();
+  static const std::shared_ptr<Engine>& _GetSharedRef();
   /*!
    * \brief Push an synchronous operation to the engine.
    * \param exec_fn Execution function that executes the operation.
@@ -266,10 +365,32 @@ class MXNET_API Engine {
                         FnProperty prop = FnProperty::kNormal,
                         int priority = 0,
                         const char* opr_name = nullptr) {
-    this->PushAsync([exec_fn](RunContext ctx, CallbackOnComplete on_complete) {
-        exec_fn(ctx);
-        on_complete();
-      }, exec_ctx, const_vars, mutable_vars, prop, priority, opr_name);
+    this->PushAsync(
+        [exec_fn](RunContext ctx, CallbackOnStart on_start, CallbackOnComplete on_complete) {
+          on_start();
+          exec_fn(ctx);
+          on_complete();
+        },
+        exec_ctx,
+        const_vars,
+        mutable_vars,
+        prop,
+        priority,
+        opr_name);
+  }
+
+  /*!
+   * \brief factory function to create OnStart callback.
+   * \param callback th static callback function.
+   * \param param the paramter passed to callback.
+   */
+  inline CallbackOnStart CreateOnStart(void (*callback)(Engine*, void*, const dmlc::Error*),
+                                       void* param) {
+    CallbackOnStart ret;
+    ret.callback_ = callback;
+    ret.engine_   = this;
+    ret.param_    = param;
+    return ret;
   }
 
   /*!
diff --git a/include/mxnet/storage.h b/include/mxnet/storage.h
index a72da9e..06db6ce 100644
--- a/include/mxnet/storage.h
+++ b/include/mxnet/storage.h
@@ -26,6 +26,7 @@
 
 #include <memory>
 #include <string>
+#include <vector>
 #include "./base.h"
 
 namespace mxnet {
@@ -39,6 +40,17 @@ namespace mxnet {
 class Storage {
  public:
   /*!
+   * \brief Storage sync object.
+   */
+  struct SyncObj {
+#if MXNET_USE_CUDA
+    /*!
+     * \brief All the events from the engine variable.
+     */
+    std::vector<std::weak_ptr<cudaEvent_t>> events;
+#endif
+  };
+  /*!
    * \brief Storage handle.
    */
   struct Handle {
@@ -64,6 +76,11 @@ class Storage {
      */
     std::string profiler_scope{MXNET_STORAGE_DEFAULT_PROFILER_SCOPE_CSTR};
     std::string name{MXNET_STORAGE_DEFAULT_NAME_CSTR};
+    /*!
+     * \brief Used to pass events back and forth between the engine Var
+     * and the storage manager.
+     */
+    SyncObj sync_obj;
   };
   /*!
    * \brief Allocate a new contiguous memory for a given size.
@@ -137,7 +154,7 @@ class Storage {
    *
    * \return A shared pointer to Storage singleton.
    */
-  static std::shared_ptr<Storage> _GetSharedRef();
+  static const std::shared_ptr<Storage>& _GetSharedRef();
 
  private:
   std::mutex cpu_mutex_;
diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc
index edd2e55..8bb2b54 100644
--- a/src/c_api/c_api.cc
+++ b/src/c_api/c_api.cc
@@ -3764,6 +3764,7 @@ int MXNDArrayCreateFromSharedMem(int shared_pid,
 }
 
 using VarHandle          = Engine::VarHandle;
+using CallbackOnStart    = Engine::CallbackOnStart;
 using CallbackOnComplete = Engine::CallbackOnComplete;
 
 void AssertValidNumberVars(int num_const_vars, int num_mutable_vars) {
@@ -3795,15 +3796,17 @@ int MXEnginePushAsync(EngineAsyncFunc async_func,
 
   Engine::AsyncFn exec_fn;
   if (deleter == nullptr) {
-    exec_fn = [async_func, func_param](RunContext rctx, CallbackOnComplete on_complete) {
-      async_func(&rctx, &on_complete, func_param);
+    exec_fn = [async_func, func_param](
+                  RunContext rctx, CallbackOnStart on_start, CallbackOnComplete on_complete) {
+      async_func(&rctx, &on_start, &on_complete, func_param);
     };
   } else {
     // Wrap func_param in a shared_ptr with deleter such that deleter
     // will be called when the lambda goes out of scope.
     std::shared_ptr<void> shared_func_param(func_param, deleter);
-    exec_fn = [async_func, shared_func_param](RunContext rctx, CallbackOnComplete on_complete) {
-      async_func(&rctx, &on_complete, shared_func_param.get());
+    exec_fn = [async_func, shared_func_param](
+                  RunContext rctx, CallbackOnStart on_start, CallbackOnComplete on_complete) {
+      async_func(&rctx, &on_start, &on_complete, shared_func_param.get());
     };
   }
 
diff --git a/src/common/object_pool.h b/src/common/object_pool.h
index 72ba387..66385b9 100644
--- a/src/common/object_pool.h
+++ b/src/common/object_pool.h
@@ -61,7 +61,7 @@ class ObjectPool {
    * \brief Get a shared ptr of the singleton instance of pool.
    * \return Shared pointer to the Object Pool.
    */
-  static std::shared_ptr<ObjectPool> _GetSharedRef();
+  static const std::shared_ptr<ObjectPool>& _GetSharedRef();
 
  private:
   /*!
@@ -170,7 +170,7 @@ ObjectPool<T>* ObjectPool<T>::Get() {
 }
 
 template <typename T>
-std::shared_ptr<ObjectPool<T> > ObjectPool<T>::_GetSharedRef() {
+const std::shared_ptr<ObjectPool<T> >& ObjectPool<T>::_GetSharedRef() {
   static std::shared_ptr<ObjectPool<T> > inst_ptr(new ObjectPool<T>());
   return inst_ptr;
 }
diff --git a/src/engine/engine.cc b/src/engine/engine.cc
index 1d236e1..2e1e050 100644
--- a/src/engine/engine.cc
+++ b/src/engine/engine.cc
@@ -25,6 +25,7 @@
 #include <memory>
 #include <cstdlib>
 #include "./engine_impl.h"
+#include "../common/cuda/utils.h"
 
 namespace mxnet {
 namespace engine {
@@ -35,6 +36,13 @@ inline Engine* CreateEngine() {
     type = "ThreadedEnginePerDevice";
   std::string stype = type;
 
+  // The async tag is used later to determine if we use the GPU dependecy engine
+  std::string async_engine_tag = "Async";
+  auto tag_pos                 = stype.find(async_engine_tag);
+  if (tag_pos != std::string::npos && tag_pos + async_engine_tag.length() == stype.length()) {
+    stype = stype.substr(0, tag_pos);
+  }
+
   Engine* ret = nullptr;
 #if MXNET_PREDICT_ONLY == 0
   if (stype == "NaiveEngine") {
@@ -56,9 +64,27 @@ inline Engine* CreateEngine() {
   }
   return ret;
 }
+
+#if MXNET_USE_CUDA
+CUDAEvent::CUDAEvent(Context const& ctx)
+    : event_(std::make_shared<cudaEvent_t>()), dev_id_(ctx.dev_id) {
+  cudaEvent_t ev;
+  common::cuda::DeviceStore device_store(dev_id_);
+  CUDA_CALL(cudaEventCreateWithFlags(&ev, cudaEventDisableTiming));
+  *event_ = ev;
+}
+
+CUDAEvent::~CUDAEvent() {
+  if (event_ && *event_ != nullptr) {
+    common::cuda::DeviceStore device_store(dev_id_);
+    CUDA_CALL(cudaEventSynchronize(*event_));
+    CUDA_CALL(cudaEventDestroy(*event_));
+  }
+}
+#endif
 }  // namespace engine
 
-std::shared_ptr<Engine> Engine::_GetSharedRef() {
+const std::shared_ptr<Engine>& Engine::_GetSharedRef() {
   static std::shared_ptr<Engine> sptr(engine::CreateEngine());
   return sptr;
 }
diff --git a/src/engine/naive_engine.cc b/src/engine/naive_engine.cc
index 0693574..ad24af1 100644
--- a/src/engine/naive_engine.cc
+++ b/src/engine/naive_engine.cc
@@ -118,7 +118,7 @@ class NaiveEngine final : public Engine {
     NaiveOpr* opr                = op->Cast<NaiveOpr>();
     opr->profiling = profiling && profiler->IsProfiling(profiler::Profiler::kSymbolic);
     this->PushAsync(
-        [&](RunContext ctx, CallbackOnComplete on_complete) {
+        [&](RunContext ctx, CallbackOnStart on_start, CallbackOnComplete on_complete) {
           if (opr->profiling) {
             std::unique_ptr<profiler::ProfileOperator::Attributes> attrs;
             if (profiler->AggregateEnabled()) {
@@ -128,7 +128,7 @@ class NaiveEngine final : public Engine {
                 std::make_unique<profiler::ProfileOperator>(opr->opr_name.c_str(), attrs.release());
             opr->opr_profile->startForDevice(exec_ctx.dev_type, exec_ctx.dev_id);
           }
-          opr->fn(ctx, on_complete);
+          opr->fn(ctx, on_start, on_complete);
           if (opr->profiling) {
             opr->opr_profile->stop();
           }
@@ -156,6 +156,7 @@ class NaiveEngine final : public Engine {
                  bool wait            = false) override {
     std::promise<void> promise;
     std::future<void> future     = promise.get_future();
+    CallbackOnStart on_start     = CreateOnStart(NaiveEngine::OnStart, &promise);
     CallbackOnComplete callback  = CreateCallback(NaiveEngine::OnComplete, &promise);
     profiler::Profiler* profiler = profiler::Profiler::Get();
     auto opr_deleter             = [this](NaiveOpr* p) { this->DeleteOperator(p); };
@@ -189,12 +190,12 @@ class NaiveEngine final : public Engine {
         streams_[dev_id]     = mshadow::NewStream<gpu>(true, MXNET_USE_CUDNN != 0, dev_id);
         aux_streams_[dev_id] = new GPUAuxStream(streams_[dev_id]);
       }
-      exec_fun(RunContext{exec_ctx, streams_[dev_id], aux_streams_[dev_id], false}, callback);
+      exec_fun(RunContext{exec_ctx, streams_[dev_id], aux_streams_[dev_id]}, on_start, callback);
 #else
       LOG(FATAL) << "GPU is not enabled";
 #endif
     } else {
-      exec_fun(RunContext{exec_ctx, &cpu_stream_, nullptr, false}, callback);
+      exec_fun(RunContext{exec_ctx, &cpu_stream_, nullptr}, on_start, callback);
     }
     future.wait();
     // increment mutable var version
@@ -209,7 +210,9 @@ class NaiveEngine final : public Engine {
   void DeleteVariable(SyncFn delete_fn, Context exec_ctx, VarHandle var) override {
     NaiveVar* naive_var = NaiveVar::CastFromBase(var);
     this->PushAsync(
-        [delete_fn, naive_var](RunContext ctx, CallbackOnComplete on_complete) mutable {
+        [delete_fn, naive_var](
+            RunContext ctx, CallbackOnStart on_start, CallbackOnComplete on_complete) mutable {
+          on_start();
           delete_fn(ctx);
           NaiveVar::Delete(naive_var);
           on_complete();
@@ -233,6 +236,8 @@ class NaiveEngine final : public Engine {
   }
 
  private:
+  // onstart
+  static void OnStart(Engine* engine, void* param, const dmlc::Error* error) {}
   // callback to oncomplete
   static void OnComplete(Engine* engine, void* param, const dmlc::Error* error) {
     static_cast<std::promise<void>*>(param)->set_value();
diff --git a/src/engine/stream_manager.h b/src/engine/stream_manager.h
index 87df70e..2384e1f 100644
--- a/src/engine/stream_manager.h
+++ b/src/engine/stream_manager.h
@@ -25,7 +25,9 @@
 #include <cstddef>
 #include <array>
 #include <string>
+#include <memory>
 #include <mutex>
+#include "./engine_impl.h"
 #include "../common/cuda/utils.h"
 
 namespace mxnet {
@@ -55,6 +57,7 @@ class StreamManager {
   std::array<std::array<GPUAuxStream*, kStreams>, kNumGpus> gpu_aux_streams_;
   std::array<mshadow::Stream<gpu>*, kNumGpus> gpu_io_streams_;
   std::array<int, kNumGpus> gpu_cnt_;
+  std::array<std::unique_ptr<CUDAEventPool>, kNumGpus> event_pools_;
 #endif  // MXNET_USE_CUDA
   DISALLOW_COPY_AND_ASSIGN(StreamManager);
 };  // class StreamManager
@@ -64,11 +67,12 @@ RunContext StreamManager<kNumGpus, kStreams>::GetRunContext(Context const& ctx)
   RunContext ret;
   switch (ctx.dev_mask()) {
     case cpu::kDevMask:
-      ret = RunContext{ctx, nullptr, nullptr, false};
+      ret = RunContext{ctx, nullptr, nullptr};
       break;
     case gpu::kDevMask: {
 #if MXNET_USE_CUDA
       std::size_t use_counter;
+      CUDAEventPool* event_pool;
       {
         std::lock_guard<std::mutex> lock{mutex_};
         auto&& counter = gpu_cnt_.at(ctx.dev_id);
@@ -84,13 +88,17 @@ RunContext StreamManager<kNumGpus, kStreams>::GetRunContext(Context const& ctx)
           }
           counter = 0;
         }
+        if (event_pools_.at(ctx.dev_id) == nullptr) {
+          event_pools_[ctx.dev_id] = std::make_unique<CUDAEventPool>(ctx);
+        }
+        event_pool  = event_pools_.at(ctx.dev_id).get();
         use_counter = counter;
         counter     = (counter + 1) % kStreams;
       }
       ret = RunContext{ctx,
                        gpu_streams_.at(ctx.dev_id).at(use_counter),
                        gpu_aux_streams_.at(ctx.dev_id).at(use_counter),
-                       false};
+                       event_pool};
       break;
 #else
       LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR;
@@ -107,18 +115,23 @@ RunContext StreamManager<kNumGpus, kStreams>::GetIORunContext(Context const& ctx
   RunContext ret;
   switch (ctx.dev_mask()) {
     case cpu::kDevMask:
-      ret = RunContext{ctx, nullptr, nullptr, false};
+      ret = RunContext{ctx, nullptr, nullptr};
       break;
     case gpu::kDevMask: {
 #if MXNET_USE_CUDA
+      CUDAEventPool* event_pool;
       {
         std::lock_guard<std::mutex> lock{mutex_};
         if (gpu_io_streams_.at(ctx.dev_id) == nullptr) {
           mxnet::common::cuda::DeviceStore device_store(ctx.dev_id);
           gpu_io_streams_.at(ctx.dev_id) = mshadow::NewStream<gpu>(false, false, ctx.dev_id);
         }
+        if (event_pools_.at(ctx.dev_id) == nullptr) {
+          event_pools_[ctx.dev_id] = std::make_unique<CUDAEventPool>(ctx);
+        }
+        event_pool = event_pools_.at(ctx.dev_id).get();
       }
-      ret = RunContext{ctx, gpu_io_streams_.at(ctx.dev_id), nullptr, false};
+      ret = RunContext{ctx, gpu_io_streams_.at(ctx.dev_id), nullptr, event_pool};
       break;
 #else
       LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR;
@@ -147,6 +160,9 @@ void StreamManager<kNumGpus, kStreams>::Finalize() {
 #if MXNET_USE_CUDA
   for (std::size_t i = 0; i < kNumGpus; ++i) {
     if (gpu_cnt_.at(i) != -1) {
+      if (event_pools_.at(i) != nullptr) {
+        event_pools_[i].reset();
+      }
       for (auto&& primary_stream : gpu_streams_.at(i)) {
         // Catch exception for CUDA driver shutdown
         MSHADOW_CATCH_ERROR(mshadow::DeleteStream<gpu>(primary_stream));
diff --git a/src/engine/threaded_engine.cc b/src/engine/threaded_engine.cc
index 58af6df..40d852b 100644
--- a/src/engine/threaded_engine.cc
+++ b/src/engine/threaded_engine.cc
@@ -272,7 +272,8 @@ void ThreadedEngine::DeleteOperator(OprHandle op) {
   deps.insert(deps.end(), threaded_opr->const_vars.begin(), threaded_opr->const_vars.end());
   deps.insert(deps.end(), threaded_opr->mutable_vars.begin(), threaded_opr->mutable_vars.end());
   this->PushAsync(
-      [threaded_opr](RunContext, CallbackOnComplete on_complete) {
+      [threaded_opr](RunContext, CallbackOnStart on_start, CallbackOnComplete on_complete) {
+        on_start();
         ThreadedOpr::Delete(threaded_opr);
         on_complete();
       },
@@ -349,7 +350,8 @@ void ThreadedEngine::PushSync(SyncFn exec_fn,
                               const char* opr_name) {
   if (!bulk_size() || prop != FnProperty::kNormal || priority) {
     this->PushAsync(
-        [exec_fn](RunContext ctx, CallbackOnComplete on_complete) {
+        [exec_fn](RunContext ctx, CallbackOnStart on_start, CallbackOnComplete on_complete) {
+          on_start();
           exec_fn(ctx);
           on_complete();
         },
@@ -371,9 +373,11 @@ void ThreadedEngine::PushSync(SyncFn exec_fn,
 void ThreadedEngine::DeleteVariable(SyncFn delete_fn, Context exec_ctx, VarHandle var) {
   ThreadedVar* threaded_var = ThreadedVar::CastFromBase(var);
   this->PushAsync(
-      [delete_fn, threaded_var](RunContext ctx, CallbackOnComplete on_complete) {
+      [delete_fn, threaded_var](
+          RunContext ctx, CallbackOnStart on_start, CallbackOnComplete on_complete) {
         // Mark variable as orphan,
         // so during `ThreadedEngine::OnComplete` it could be recycled.
+        on_start();
         threaded_var->SetToDelete();
         delete_fn(ctx);
         on_complete();
@@ -399,7 +403,8 @@ void ThreadedEngine::WaitForVar(VarHandle var) {
   }
   std::atomic<bool> done{false};
   this->PushAsync(
-      [this, &done](RunContext, CallbackOnComplete on_complete) {
+      [this, &done](RunContext, CallbackOnStart on_start, CallbackOnComplete on_complete) {
+        on_start();
         if (engine_info_) {
           LOG(INFO) << "Sync is executed";
         }
@@ -480,6 +485,14 @@ inline void ThreadedEngine::OnComplete(ThreadedOpr* threaded_opr) {
       }
     });
     if (to_delete) {
+#if MXNET_USE_CUDA
+      auto& sync_obj = i->sync_object;
+      {
+        std::lock_guard<std::mutex> l(sync_obj.mutex);
+        sync_obj.reader_events.clear();
+        sync_obj.writer_event.clear();
+      }
+#endif
       ThreadedVar::Delete(i);
     }
   }
@@ -533,5 +546,206 @@ void ThreadedEngine::OnCompleteStatic(Engine* engine, void* opr_block_, const dm
   OprBlock::Delete(opr_block);
 }
 
+void ThreadedEngine::OnStartStatic(Engine* engine, void* opr_block, const dmlc::Error* error) {
+  // no-op
+}
+
+#if MXNET_USE_CUDA
+static inline void AddEventHelper(std::unordered_map<cudaStream_t, EventInfo>* events_per_stream,
+                                  const EventInfo& cuda_event) {
+  auto event_stream = cuda_event.stream;
+  if (events_per_stream->count(event_stream) > 0) {
+    if ((*events_per_stream)[event_stream].pool_index < cuda_event.pool_index) {
+      (*events_per_stream)[event_stream] = cuda_event;
+    }
+  } else {
+    (*events_per_stream).emplace(event_stream, cuda_event);
+  }
+}
+
+static inline bool IsEngineAsync() {
+  std::string type = dmlc::GetEnv("MXNET_ENGINE_TYPE", std::string(""));
+  std::string async_engine_tag("Async");
+  auto tag_pos = type.find(async_engine_tag);
+  return tag_pos != std::string::npos;
+}
+
+void ThreadedEngine::OnStartCPU(Engine* engine, void* opr_block, const dmlc::Error* error) {
+  static bool use_new_dep_engine = IsEngineAsync();
+  if (!use_new_dep_engine) {
+    return;
+  }
+  ThreadedOpr* threaded_opr = static_cast<OprBlock*>(opr_block)->opr;
+  std::unordered_map<cudaStream_t, EventInfo> event_per_stream;
+  for (auto* read_var : threaded_opr->const_vars) {
+    auto& sync_obj = read_var->sync_object;
+    std::lock_guard<std::mutex> l(sync_obj.mutex);
+    auto& reader_events = sync_obj.reader_events;
+    // check for expired events and delete them
+    reader_events.erase(std::remove_if(reader_events.begin(),
+                                       reader_events.end(),
+                                       [&](const EventInfo e_i) { return e_i.event.expired(); }),
+                        reader_events.end());
+    for (auto& cuda_event : reader_events) {
+      AddEventHelper(&event_per_stream, cuda_event);
+    }
+    if (!sync_obj.writer_event.empty()) {
+      if (sync_obj.writer_event[0].event.expired()) {
+        sync_obj.writer_event.clear();
+      } else {
+        AddEventHelper(&event_per_stream, sync_obj.writer_event[0]);
+      }
+    }
+  }
+
+  for (auto* write_var : threaded_opr->mutable_vars) {
+    auto& sync_obj = write_var->sync_object;
+    std::lock_guard<std::mutex> l(sync_obj.mutex);
+    auto& reader_events = sync_obj.reader_events;
+    // check for expired events and delete them
+    reader_events.erase(std::remove_if(reader_events.begin(),
+                                       reader_events.end(),
+                                       [&](const EventInfo e_i) { return e_i.event.expired(); }),
+                        reader_events.end());
+    for (auto& cuda_event : reader_events) {
+      AddEventHelper(&event_per_stream, cuda_event);
+    }
+    if (!sync_obj.writer_event.empty()) {
+      if (sync_obj.writer_event[0].event.expired()) {
+        sync_obj.writer_event.clear();
+      } else {
+        AddEventHelper(&event_per_stream, sync_obj.writer_event[0]);
+      }
+    }
+  }
+  for (auto event : event_per_stream) {
+    auto ev = event.second.event.lock();
+    MSHADOW_CUDA_CALL(cudaEventSynchronize(*ev));
+  }
+}
+
+void ThreadedEngine::OnStartGPU(Engine* engine, void* sync_info, const dmlc::Error* error) {
+  static bool use_new_dep_engine = IsEngineAsync();
+  if (!use_new_dep_engine) {
+    return;
+  }
+  auto* info = reinterpret_cast<GPUWorkerSyncInfo*>(sync_info);
+  CHECK(info->stream != nullptr);
+  auto* worker_stream       = reinterpret_cast<mshadow::Stream<gpu>*>(info->stream);
+  ThreadedOpr* threaded_opr = static_cast<OprBlock*>(info->opr_block)->opr;
+  std::unordered_map<cudaStream_t, EventInfo> event_per_stream;
+  for (auto* read_var : threaded_opr->const_vars) {
+    auto& sync_obj = read_var->sync_object;
+    std::lock_guard<std::mutex> l(sync_obj.mutex);
+    auto& reader_events = sync_obj.reader_events;
+    // check for expired events and delete them
+    reader_events.erase(std::remove_if(reader_events.begin(),
+                                       reader_events.end(),
+                                       [&](const EventInfo e_i) { return e_i.event.expired(); }),
+                        reader_events.end());
+    for (auto& writer : sync_obj.writer_event) {
+      if (writer.event.expired()) {
+        sync_obj.writer_event.clear();
+        break;
+      }
+      if (writer.stream != worker_stream->stream_) {
+        // if there is already a reader on the same stream as us,
+        // it already synced with that writer and we can rely on
+        // the ongoing sync
+        bool found = false;
+        for (const auto& reader : reader_events) {
+          if (reader.stream == worker_stream->stream_) {
+            found = true;
+            break;
+          }
+        }
+        if (!found) {
+          AddEventHelper(&event_per_stream, writer);
+        }
+      }
+    }
+  }
+  for (auto* write_var : threaded_opr->mutable_vars) {
+    auto& sync_obj = write_var->sync_object;
+    std::lock_guard<std::mutex> l(sync_obj.mutex);
+    // check for expired events and delete them
+    auto& reader_events = sync_obj.reader_events;
+    reader_events.erase(std::remove_if(reader_events.begin(),
+                                       reader_events.end(),
+                                       [&](const EventInfo e_i) { return e_i.event.expired(); }),
+                        reader_events.end());
+    // if there are some readers, we wait for them
+    for (auto& cuda_event : reader_events) {
+      if (worker_stream->stream_ != cuda_event.stream) {
+        AddEventHelper(&event_per_stream, cuda_event);
+      }
+    }
+    if (!sync_obj.writer_event.empty()) {
+      if (sync_obj.writer_event[0].event.expired()) {
+        sync_obj.writer_event.clear();
+      } else {
+        if (worker_stream->stream_ != sync_obj.writer_event[0].stream) {
+          AddEventHelper(&event_per_stream, sync_obj.writer_event[0]);
+        }
+      }
+    }
+  }
+  for (auto event : event_per_stream) {
+    auto ev = event.second.event.lock();
+    MSHADOW_CUDA_CALL(cudaStreamWaitEvent(worker_stream->stream_, *ev, 0));
+  }
+}
+
+void ThreadedEngine::OnCompleteGPU(Engine* engine, void* sync_info, const dmlc::Error* error) {
+  auto* info = reinterpret_cast<GPUWorkerSyncInfo*>(sync_info);
+  CHECK(info->stream != nullptr);
+
+  auto* worker_stream            = reinterpret_cast<mshadow::Stream<gpu>*>(info->stream);
+  static bool use_new_dep_engine = IsEngineAsync();
+
+  if (!use_new_dep_engine) {
+    worker_stream->Wait();
+    ThreadedEngine::OnCompleteStatic(engine, info->opr_block, error);
+    GPUWorkerSyncInfo::Delete(info);
+    return;
+  }
+
+  ThreadedOpr* threaded_opr    = static_cast<OprBlock*>(info->opr_block)->opr;
+  auto* event_pool             = static_cast<CUDAEventPool*>(info->event_pool);
+  auto [event, event_pool_idx] = event_pool->GetNextEvent();
+  auto ev                      = event.lock();
+  MSHADOW_CUDA_CALL(cudaEventRecord(*ev, worker_stream->stream_));
+  for (auto* read_var : threaded_opr->const_vars) {
+    auto& sync_obj = read_var->sync_object;
+    std::lock_guard<std::mutex> l(sync_obj.mutex);
+    // If some reader event is already recorded on the same stream,
+    // we want to replace ourselves by it
+    int i;
+    for (i = 0; i < sync_obj.reader_events.size(); ++i) {
+      auto stream = sync_obj.reader_events[i].stream;
+      if (stream == worker_stream->stream_) {
+        sync_obj.reader_events[i].event      = event;
+        sync_obj.reader_events[i].pool_index = event_pool_idx;
+        break;
+      }
+    }
+    if (i == sync_obj.reader_events.size()) {
+      sync_obj.reader_events.push_back({event, worker_stream->stream_, event_pool_idx});
+    }
+  }
+
+  for (auto* write_var : threaded_opr->mutable_vars) {
+    auto& sync_obj = write_var->sync_object;
+    std::lock_guard<std::mutex> l(sync_obj.mutex);
+    sync_obj.reader_events.clear();
+    sync_obj.writer_event.clear();
+    sync_obj.writer_event.push_back({event, worker_stream->stream_, event_pool_idx});
+  }
+
+  ThreadedEngine::OnCompleteStatic(engine, info->opr_block, error);
+  GPUWorkerSyncInfo::Delete(info);
+}
+#endif
+
 }  // namespace engine
 }  // namespace mxnet
diff --git a/src/engine/threaded_engine.h b/src/engine/threaded_engine.h
index bd3f34c..a9e08a8 100644
--- a/src/engine/threaded_engine.h
+++ b/src/engine/threaded_engine.h
@@ -353,7 +353,10 @@ class ThreadedEngine : public Engine {
    * \param run_ctx runtime context used to execute the function.
    * \param opr_block the opr_block to be executed and deleted.
    */
-  void ExecuteOprBlock(RunContext run_ctx, OprBlock* opr_block) {
+  void ExecuteOprBlock(RunContext run_ctx,
+                       OprBlock* opr_block,
+                       CallbackOnStart on_start,
+                       CallbackOnComplete callback) {
     ThreadedOpr* threaded_opr = opr_block->opr;
     if (opr_block->profiling && threaded_opr->opr_name.size()) {
       std::unique_ptr<profiler::ProfileOperator::Attributes> attrs;
@@ -365,7 +368,6 @@ class ThreadedEngine : public Engine {
           new profiler::ProfileOperator(threaded_opr->opr_name.c_str(), attrs.release()));
       opr_block->opr_profile->startForDevice(ctx.dev_type, ctx.dev_id);
     }
-    CallbackOnComplete callback = this->CreateCallback(ThreadedEngine::OnCompleteStatic, opr_block);
     const bool debug_info       = (engine_info_ && debug_push_opr_ == opr_block);
     if (debug_info) {
       LOG(INFO) << "ExecuteOprBlock " << opr_block << "shutdown_phase=" << shutdown_phase_;
@@ -381,11 +383,13 @@ class ThreadedEngine : public Engine {
           if ((!(threaded_opr->opr_exception && *threaded_opr->opr_exception) ||
                threaded_opr->prop == FnProperty::kNoSkip) ||
               threaded_opr->wait) {
-            threaded_opr->fn(run_ctx, callback);
+            threaded_opr->fn(run_ctx, on_start, callback);
           } else {
+            on_start();
             callback();
           }
         } catch (const std::exception& e) {
+          on_start();
           threaded_opr->opr_exception =
               std::make_shared<std::exception_ptr>(std::current_exception());
           callback();
@@ -408,6 +412,7 @@ class ThreadedEngine : public Engine {
         }
       }
     } else {
+      on_start();
       callback();
     }
   }
@@ -429,6 +434,22 @@ class ThreadedEngine : public Engine {
     return bulk_size;
   }
 
+ protected:
+  static void OnStartStatic(Engine* engine, void* opr_block, const dmlc::Error* error);
+  static void OnCompleteStatic(Engine* engine, void* threaded_opr, const dmlc::Error* error);
+#if MXNET_USE_CUDA
+  static void OnStartCPU(Engine* engine, void* opr_block, const dmlc::Error* error);
+  static void OnStartGPU(Engine* engine, void* sync_info, const dmlc::Error* error);  // sync_info is a GPUWorkerSyncInfo*
+  static void OnCompleteGPU(Engine* engine, void* sync_info, const dmlc::Error* error);  // sync_info is a GPUWorkerSyncInfo*; deleted after completion
+  struct GPUWorkerSyncInfo : public common::ObjectPoolAllocatable<GPUWorkerSyncInfo> {  // per-op state handed to OnStartGPU/OnCompleteGPU
+    void* opr_block{nullptr};  // OprBlock* being executed (cast back with static_cast<OprBlock*>)
+    void* stream{nullptr};  // mshadow::Stream<gpu>* the op's work was issued on
+    void* event_pool{nullptr};  // CUDAEventPool* supplying the completion event recorded on `stream`
+  };
+
+  std::shared_ptr<common::ObjectPool<GPUWorkerSyncInfo>> objpool_gpu_sync_ref_;  // keeps the GPUWorkerSyncInfo pool alive at least as long as the engine
+#endif
+
  private:
   /*! \brief structure for holding bulk execution status */
   struct BulkStatus {
@@ -491,7 +512,6 @@ class ThreadedEngine : public Engine {
     }
   }
 
-  static void OnCompleteStatic(Engine* engine, void* threaded_opr, const dmlc::Error* error);
   /*!
    * \brief find exception in global_exception_refs and add it if missing
    * \param opr_exception the exception to be added to global_exception_refs
@@ -536,16 +556,11 @@ class ThreadedEngine : public Engine {
     DeduplicateVarHandle(&bulk_status.const_vars, &bulk_status.mutable_vars);
     auto functions = bulk_status.functions;
     this->PushAsync(
-        [functions](RunContext ctx, CallbackOnComplete on_complete) {
-          ctx.is_bulk = true;
+        [functions](RunContext ctx, CallbackOnStart on_start, CallbackOnComplete on_complete) {
+          on_start();
           for (auto& fn : *functions) {
             fn(ctx);
           }
-          ctx.is_bulk = false;
-          bool is_gpu = ctx.ctx.dev_mask() == gpu::kDevMask;
-          if (is_gpu) {
-            ctx.get_stream<gpu>()->Wait();
-          }
           on_complete();
         },
         bulk_status.ctx,
@@ -554,7 +569,6 @@ class ThreadedEngine : public Engine {
         FnProperty::kNormal,
         0,
         "ImperativeBulk");
-
     bulk_status.functions.reset(new std::vector<SyncFn>());
     bulk_status.functions->reserve(bulk_status.bulk_size);
     bulk_status.const_vars.clear();
diff --git a/src/engine/threaded_engine_perdevice.cc b/src/engine/threaded_engine_perdevice.cc
index b70823b..b566e44 100644
--- a/src/engine/threaded_engine_perdevice.cc
+++ b/src/engine/threaded_engine_perdevice.cc
@@ -53,8 +53,14 @@ class ThreadedEnginePerDevice : public ThreadedEngine {
   static auto constexpr kCopyQueue     = kPriority;
   static auto constexpr kPriorityQueue = kPriority;
   static auto constexpr kWorkerQueue   = kFIFO;
+  static int constexpr kMaxStreams     = 256;
 
   ThreadedEnginePerDevice() noexcept(false) {
+#if MXNET_USE_CUDA
+    // Make sure that the pool is not destroyed before the engine
+    objpool_gpu_sync_ref_ = common::ObjectPool<GPUWorkerSyncInfo>::_GetSharedRef();
+    streams_.reserve(kMaxStreams);
+#endif
     this->Start();
   }
   ~ThreadedEnginePerDevice() noexcept(false) override {
@@ -77,6 +83,15 @@ class ThreadedEnginePerDevice : public ThreadedEngine {
     StopNoWait();
   }
 
+#if MXNET_USE_CUDA
+  void WaitForAll() override {  // also drains every registered GPU worker stream
+    ThreadedEngine::WaitForAll();  // async mode may mark ops complete before their kernels finish, hence the stream sync below
+    for (auto s : streams_) {  // NOTE(review): streams_ is push_back'ed by GPU worker threads with no lock visible in this patch; confirm this cannot race with registration
+      s->Wait();
+    }
+  }
+#endif
+
   void Start() override {
     if (is_worker_)
       return;
@@ -107,7 +122,10 @@ class ThreadedEnginePerDevice : public ThreadedEngine {
         MSHADOW_CATCH_ERROR(mshadow::SetDevice<gpu>(ctx.dev_id));
 #endif
       }
-      this->ExecuteOprBlock(RunContext{ctx, nullptr, nullptr, false}, opr_block);
+      CallbackOnStart on_start = this->CreateOnStart(ThreadedEngine::OnStartStatic, opr_block);
+      CallbackOnComplete callback =
+          this->CreateCallback(ThreadedEngine::OnCompleteStatic, opr_block);
+      this->ExecuteOprBlock(RunContext{ctx, nullptr, nullptr}, opr_block, on_start, callback);
     } else {
       if (ctx.dev_mask() == Context::kCPU) {
         // CPU execution.
@@ -238,6 +256,12 @@ class ThreadedEnginePerDevice : public ThreadedEngine {
   common::LazyAllocArray<ThreadWorkerBlock<kCopyQueue>> gpu_copy_workers_;
   // gpu priority workers
   common::LazyAllocArray<ThreadWorkerBlock<kPriorityQueue>> gpu_priority_workers_;
+#if MXNET_USE_CUDA
+  std::vector<mshadow::Stream<gpu>*> streams_;
+
+  std::unordered_map<int, std::unique_ptr<CUDAEventPool>> cuda_event_pool_per_worker_;
+#endif
+
   /*!
    * \brief GPU worker that performs operations on a certain device.
    * \param dev_id The device id of the worker.
@@ -265,9 +289,20 @@ class ThreadedEnginePerDevice : public ThreadedEngine {
         aux_stream = new GPUAuxStream(stream);
       }
     } while (false);
+    // register stream
+    streams_.push_back(stream);
+    CUDAEventPool* event_pool;
+    auto event_pool_it = cuda_event_pool_per_worker_.find(ctx.dev_id);
+    if (event_pool_it != cuda_event_pool_per_worker_.end()) {
+      event_pool = event_pool_it->second.get();
+    } else {
+      auto res =
+          cuda_event_pool_per_worker_.emplace(ctx.dev_id, std::make_unique<CUDAEventPool>(ctx));
+      event_pool = res.first->second.get();
+    }
     // execute task
     OprBlock* opr_block;
-    RunContext run_ctx{ctx, stream, aux_stream, false};
+    RunContext run_ctx{ctx, stream, aux_stream};
     auto* task_queue = &(block->task_queue);
 
     // Don't eat up omp threads for GPU jobs.  They're probably best used elsewhere,
@@ -284,7 +319,13 @@ class ThreadedEnginePerDevice : public ThreadedEngine {
       auto color = common::cuda::nvtx::nameToColor(nvtx_name, name_prefix_len);
       common::cuda::nvtx::gpuRangeStart(color, nvtx_name);
 #endif
-      this->ExecuteOprBlock(run_ctx, opr_block);
+      auto* info                  = ThreadedEngine::GPUWorkerSyncInfo::New();
+      info->opr_block             = opr_block;
+      info->stream                = stream;
+      info->event_pool            = event_pool;
+      CallbackOnStart on_start    = this->CreateOnStart(ThreadedEngine::OnStartGPU, info);
+      CallbackOnComplete callback = this->CreateCallback(ThreadedEngine::OnCompleteGPU, info);
+      this->ExecuteOprBlock(run_ctx, opr_block, on_start, callback);
 #if MXNET_USE_NVTX
       common::cuda::nvtx::gpuRangeStop();
 #endif
@@ -303,7 +344,7 @@ class ThreadedEnginePerDevice : public ThreadedEngine {
                         const std::shared_ptr<dmlc::ManualEvent>& ready_event) {
     this->is_worker_ = true;
     auto* task_queue = &(block->task_queue);
-    RunContext run_ctx{ctx, nullptr, nullptr, false};
+    RunContext run_ctx{ctx, nullptr, nullptr};
 
     // execute task
     OprBlock* opr_block;
@@ -313,7 +354,14 @@ class ThreadedEnginePerDevice : public ThreadedEngine {
     OpenMP::Get()->on_start_worker_thread(true);
 
     while (task_queue->Pop(&opr_block)) {
-      this->ExecuteOprBlock(run_ctx, opr_block);
+#if MXNET_USE_CUDA
+      CallbackOnStart on_start = this->CreateOnStart(ThreadedEngine::OnStartCPU, opr_block);
+#else
+      CallbackOnStart on_start = this->CreateOnStart(ThreadedEngine::OnStartStatic, opr_block);
+#endif
+      CallbackOnComplete callback =
+          this->CreateCallback(ThreadedEngine::OnCompleteStatic, opr_block);
+      this->ExecuteOprBlock(run_ctx, opr_block, on_start, callback);
     }
   }
 
diff --git a/src/engine/threaded_engine_pooled.cc b/src/engine/threaded_engine_pooled.cc
index c0ca039..0ec91b2 100644
--- a/src/engine/threaded_engine_pooled.cc
+++ b/src/engine/threaded_engine_pooled.cc
@@ -47,6 +47,10 @@ namespace engine {
 class ThreadedEnginePooled : public ThreadedEngine {
  public:
   ThreadedEnginePooled() {
+#if MXNET_USE_CUDA
+    // Make sure that the pool is not destroyed before the engine
+    objpool_gpu_sync_ref_ = common::ObjectPool<ThreadedEngine::GPUWorkerSyncInfo>::_GetSharedRef();
+#endif
     this->Start();
   }
 
@@ -55,13 +59,13 @@ class ThreadedEnginePooled : public ThreadedEngine {
   }
 
   void StopNoWait() {
-    streams_->Finalize();
     task_queue_->SignalForKill();
     io_task_queue_->SignalForKill();
     task_queue_     = nullptr;
     io_task_queue_  = nullptr;
     thread_pool_    = nullptr;
     io_thread_pool_ = nullptr;
+    streams_->Finalize();
     streams_        = nullptr;
   }
 
@@ -152,7 +156,28 @@ class ThreadedEnginePooled : public ThreadedEngine {
                     opr_block->opr->prop == FnProperty::kCopyToGPU);
     auto&& rctx  = is_copy ? streams_->GetIORunContext(opr_block->ctx)
                            : streams_->GetRunContext(opr_block->ctx);
-    this->ExecuteOprBlock(rctx, opr_block);
+#if MXNET_USE_CUDA
+    CallbackOnStart on_start;
+    CallbackOnComplete callback;
+    if (opr_block->ctx.dev_mask() == Context::kCPU) {
+      on_start = this->CreateOnStart(ThreadedEngine::OnStartCPU, opr_block);
+      callback = this->CreateCallback(ThreadedEngine::OnCompleteStatic, opr_block);
+    } else {
+      CHECK_EQ(opr_block->ctx.dev_mask(), Context::kGPU);
+      auto stream      = rctx.get_stream<gpu>();
+      auto event_pool  = static_cast<CUDAEventPool*>(rctx.event_pool);
+      auto* info       = ThreadedEngine::GPUWorkerSyncInfo::New();
+      info->opr_block  = opr_block;
+      info->stream     = stream;
+      info->event_pool = event_pool;
+      on_start         = this->CreateOnStart(ThreadedEngine::OnStartGPU, info);
+      callback         = this->CreateCallback(ThreadedEngine::OnCompleteGPU, info);
+    }
+#else   // MXNET_USE_CUDA
+    CallbackOnStart on_start = this->CreateOnStart(ThreadedEngine::OnStartStatic, opr_block);
+    CallbackOnComplete callback = this->CreateCallback(ThreadedEngine::OnCompleteStatic, opr_block);
+#endif  // MXNET_USE_CUDA
+    this->ExecuteOprBlock(rctx, opr_block, on_start, callback);
   }
   /*!
    * \brief Push the operation to the queue.
diff --git a/src/imperative/imperative_utils.h b/src/imperative/imperative_utils.h
index 7d506fa..b649958 100644
--- a/src/imperative/imperative_utils.h
+++ b/src/imperative/imperative_utils.h
@@ -698,14 +698,11 @@ inline void PushFCompute(const FCompute& fn,
     fn(attrs, opctx, input_blobs, tmp_req, output_blobs);
     // post-fcompute fallback, cast to original storage type
     CastNonDefaultStorage(post_temp_src, post_temp_dst, opctx, is_gpu);
-    if (is_gpu && !rctx.is_bulk) {
-      rctx.get_stream<gpu>()->Wait();
-    }
     DerefInputOutputRelease(inputs, outputs);
   };
   if (CheckIfSkipEngine(attrs)) {
     // execute without engine
-    run(RunContext{ctx, nullptr, nullptr, false});
+    run(RunContext{ctx, nullptr, nullptr});
   } else {
     Engine::Get()->PushSync(
         run, ctx, read_vars, write_vars, FnProperty::kNormal, 0, op->name.c_str());
@@ -736,12 +733,9 @@ inline void PushFComputeEx(const FComputeEx& fn,
     INVALIDATE_OUTPUTS_COND(!cross_device_copy, outputsA, req);
     CREATE_DEFAULT_INPUTS(!cross_device_copy, attrs, CreateDefaultInputs(&inputsA));
     fn(attrs, opctx, inputsA, req, outputsA);
-    if (ctx.dev_mask() == gpu::kDevMask && exec_type == ExecType::kSync && !rctx.is_bulk) {
-      rctx.get_stream<gpu>()->Wait();
-    }
   };
   if (cross_device_copy || CheckIfSkipEngine(attrs)) {
-    run(RunContext{ctx, nullptr, nullptr, false});
+    run(RunContext{ctx, nullptr, nullptr});
   } else {
     CHECK(exec_type == ExecType::kSync);
     Engine::Get()->PushSync(
@@ -772,7 +766,9 @@ inline void PushOperator(const OpStatePtr& state,
 
   auto fcompute_ex = common::GetFCompute<FStatefulComputeEx>(op, "FStatefulComputeEx", ctx);
   if (fcompute_ex != nullptr && dispatch_mode == DispatchMode::kFComputeEx) {
-    const auto& run = [=](RunContext rctx, engine::CallbackOnComplete on_complete) {
+    const auto& run = [=](RunContext rctx,
+                          engine::CallbackOnStart on_start,
+                          engine::CallbackOnComplete on_complete) {
       OpContext opctx{need_grad, is_train, rctx, on_complete, requested};
       REDEFINE_INPUTS_OUTPUTS(inputs, outputs, inputsA, outputsA);
       INVALIDATE_OUTPUTS_COND(
@@ -780,26 +776,26 @@ inline void PushOperator(const OpStatePtr& state,
       CREATE_DEFAULT_INPUTS(exec_type != ExecType::kCrossDeviceCopy && op->name != "_CachedOp",
                             attrs,
                             CreateDefaultInputs(&inputsA));
+      on_start();
       fcompute_ex(state, opctx, inputsA, req, outputsA);
-      if (ctx.dev_mask() == gpu::kDevMask && exec_type == ExecType::kSync &&
-          rctx.get_stream<gpu>() && !rctx.is_bulk) {
-        rctx.get_stream<gpu>()->Wait();
-      }
     };
 
     // For operators with subgraphs, we need to invoke them in the main thread
     // instead of the threaded engine.
     if (exec_type == ExecType::kSubgraphExec || CheckIfSkipEngine(attrs)) {
-      RunContext rctx{ctx, nullptr, nullptr, false};
-      run(rctx, engine::CallbackOnComplete());
+      RunContext rctx{ctx, nullptr, nullptr};
+      run(rctx, engine::CallbackOnStart(), engine::CallbackOnComplete());
     } else if (exec_type == ExecType::kSync) {
-      Engine::Get()->PushSync([=](RunContext rctx) { run(rctx, engine::CallbackOnComplete()); },
-                              ctx,
-                              read_vars,
-                              write_vars,
-                              FnProperty::kNormal,
-                              0,
-                              op->name.c_str());
+      Engine::Get()->PushSync(
+          [=](RunContext rctx) {
+            run(rctx, engine::CallbackOnStart(), engine::CallbackOnComplete());
+          },
+          ctx,
+          read_vars,
+          write_vars,
+          FnProperty::kNormal,
+          0,
+          op->name.c_str());
     } else {
       CHECK(exec_type == ExecType::kAsync);
       Engine::Get()->PushAsync(
@@ -811,7 +807,9 @@ inline void PushOperator(const OpStatePtr& state,
         << "One of FStatefulCompute and FStatefulComputeEx must be registered "
         << "for stateful operator " << op->name;
 
-    const auto& run = [=](RunContext rctx, engine::CallbackOnComplete on_complete) {
+    const auto& run = [=](RunContext rctx,
+                          engine::CallbackOnStart on_start,
+                          engine::CallbackOnComplete on_complete) {
       OpContext opctx{need_grad, is_train, rctx, on_complete, requested};
 
       std::vector<TBlob> input_blobs, output_blobs;
@@ -843,23 +841,23 @@ inline void PushOperator(const OpStatePtr& state,
       fcompute(state, opctx, input_blobs, tmp_req, output_blobs);
       // post-fcompute fallback, cast to original storage type, if necessary
       CastNonDefaultStorage(post_temp_src, post_temp_dst, opctx, is_gpu);
-      if (is_gpu && exec_type == ExecType::kSync && rctx.get_stream<gpu>() && !rctx.is_bulk) {
-        rctx.get_stream<gpu>()->Wait();
-      }
       DerefInputOutputRelease(inputs, outputs);
     };
 
     if (exec_type == ExecType::kSubgraphExec || CheckIfSkipEngine(attrs)) {
-      RunContext rctx{ctx, nullptr, nullptr, false};
-      run(rctx, engine::CallbackOnComplete());
+      RunContext rctx{ctx, nullptr};
+      run(rctx, engine::CallbackOnStart(), engine::CallbackOnComplete());
     } else if (exec_type == ExecType::kSync) {
-      Engine::Get()->PushSync([=](RunContext rctx) { run(rctx, engine::CallbackOnComplete()); },
-                              ctx,
-                              read_vars,
-                              write_vars,
-                              FnProperty::kNormal,
-                              0,
-                              op->name.c_str());
+      Engine::Get()->PushSync(
+          [=](RunContext rctx) {
+            run(rctx, engine::CallbackOnStart(), engine::CallbackOnComplete());
+          },
+          ctx,
+          read_vars,
+          write_vars,
+          FnProperty::kNormal,
+          0,
+          op->name.c_str());
     } else {
       CHECK(exec_type == ExecType::kAsync);
       Engine::Get()->PushAsync(
@@ -1251,7 +1249,9 @@ inline Engine::OprHandle CreateEngineOp(
   bool is_async = execs.size() > 1 ? false : execs[0]->exec_type() == ExecType::kAsync;
 
   auto exec_fun = [execs, is_async, is_gpu](RunContext ctx,
+                                            Engine::CallbackOnStart on_start,
                                             Engine::CallbackOnComplete on_complete) {
+    on_start();
     if (is_async) {
       execs[0]->op_ctx.async_on_complete = on_complete;
     }
@@ -1260,10 +1260,7 @@ inline Engine::OprHandle CreateEngineOp(
     // call on complete only if it is async op
     if (!is_async) {
       if (is_gpu) {
-#if MXNET_USE_CUDA
-        // Wait GPU kernel to finish.
-        ctx.get_stream<gpu>()->Wait();
-#else
+#if !MXNET_USE_CUDA
         LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR;
 #endif
       }
diff --git a/src/io/batchify.cc b/src/io/batchify.cc
index 7d602fb..944bbd8 100644
--- a/src/io/batchify.cc
+++ b/src/io/batchify.cc
@@ -166,7 +166,7 @@ class StackBatchify : public BatchifyFunction {
             // inputs[j][i].WaitToRead();
             DType* ptr = (*outputs)[i].data().dptr<DType>();
             auto asize = ashape.Size();
-            RunContext rctx{(*outputs)[i].ctx(), nullptr, nullptr, false};
+            RunContext rctx{(*outputs)[i].ctx(), nullptr, nullptr};
             auto dst = TBlob(ptr + asize * j, inputs[j][i].data().shape_, cpu::kDevMask, dtype, 0);
             mxnet::ndarray::Copy<cpu, cpu>(
                 inputs[j][i].data(), &dst, Context::CPU(), Context::CPU(), rctx);
diff --git a/src/io/dataset.cc b/src/io/dataset.cc
index a461187..153e3c4 100644
--- a/src/io/dataset.cc
+++ b/src/io/dataset.cc
@@ -95,7 +95,7 @@ class RecordFileDataset final : public Dataset {
       const size_t size = read_buff.size();
       out = NDArray(TShape({static_cast<dim_t>(size)}), Context::CPU(), false, mshadow::kInt8);
       TBlob dst = out.data();
-      RunContext rctx{Context::CPU(), nullptr, nullptr, false};
+      RunContext rctx{Context::CPU(), nullptr, nullptr};
       mxnet::ndarray::Copy<cpu, cpu>(TBlob(const_cast<void*>(reinterpret_cast<const void*>(buf)),
                                            out.shape(),
                                            cpu::kDevMask,
@@ -212,7 +212,7 @@ class ImageRecordFileDataset : public Dataset {
     size -= sizeof(header);
     s += sizeof(header);
     NDArray label = NDArray(Context::CPU(), mshadow::default_type_flag);
-    RunContext rctx{Context::CPU(), nullptr, nullptr, false};
+    RunContext rctx{Context::CPU(), nullptr, nullptr};
     if (header.flag > 0) {
       auto label_shape = header.flag <= 1 ? TShape(0, 1) : TShape({header.flag});
       label.ReshapeAndAlloc(label_shape);
diff --git a/src/kvstore/comm.h b/src/kvstore/comm.h
index c2dcd20..5a1df93 100644
--- a/src/kvstore/comm.h
+++ b/src/kvstore/comm.h
@@ -165,7 +165,10 @@ class CommCPU : public Comm {
       }
 
       Engine::Get()->PushAsync(
-          [reduce, this](RunContext rctx, Engine::CallbackOnComplete on_complete) {
+          [reduce, this](RunContext rctx,
+                         Engine::CallbackOnStart on_start,
+                         Engine::CallbackOnComplete on_complete) {
+            on_start();
             ReduceSumCPU(reduce);
             on_complete();
           },
@@ -175,7 +178,6 @@ class CommCPU : public Comm {
           FnProperty::kCPUPrioritized,
           priority,
           "KVStoreReduce");
-
     } else {
       // sparse reduce
       std::vector<Engine::VarHandle> const_vars(src.size());
@@ -199,7 +201,10 @@ class CommCPU : public Comm {
       Resource rsc = ResourceManager::Get()->Request(buf_merged.ctx(),
                                                      ResourceRequest(ResourceRequest::kTempSpace));
       Engine::Get()->PushAsync(
-          [reduce, buf_merged, rsc, this](RunContext rctx, Engine::CallbackOnComplete on_complete) {
+          [reduce, buf_merged, rsc, this](RunContext rctx,
+                                          Engine::CallbackOnStart on_start,
+                                          Engine::CallbackOnComplete on_complete) {
+            on_start();
             NDArray out = buf_merged;
             is_serial_push_
                 ? ReduceSumCPUExSerial(reduce, &out)
@@ -271,7 +276,10 @@ class CommCPU : public Comm {
                         "consider create a new NDArray buffer to store the output.");
       }
       Engine::Get()->PushAsync(
-          [=](RunContext rctx, Engine::CallbackOnComplete on_complete) {
+          [=](RunContext rctx,
+              Engine::CallbackOnStart on_start,
+              Engine::CallbackOnComplete on_complete) {
+            on_start();
             const TBlob& indices = row_id.data();
             NDArray temp         = retained_cpu;  // get rid the of const qualifier
             op::SparseRetainOpForwardRspImpl<cpu>(
@@ -679,7 +687,10 @@ class CommDevice : public Comm {
       }
       bool is_gpu = retained_gpu.ctx().dev_mask() == gpu::kDevMask;
       Engine::Get()->PushAsync(
-          [=](RunContext rctx, Engine::CallbackOnComplete on_complete) {
+          [=](RunContext rctx,
+              Engine::CallbackOnStart on_start,
+              Engine::CallbackOnComplete on_complete) {
+            on_start();
             const TBlob& indices = row_id.data();
             using namespace mxnet::common;
             NDArray temp = retained_gpu;
@@ -693,8 +704,6 @@ class CommDevice : public Comm {
               case gpu::kDevMask: {
                 SparseRetainOpForwardRspWrapper<gpu>(
                     rctx.get_stream<gpu>(), src, indices, kWriteTo, &temp);
-                // wait for GPU operations to complete
-                rctx.get_stream<gpu>()->Wait();
                 break;
               }
 #endif
diff --git a/src/kvstore/gradient_compression.cc b/src/kvstore/gradient_compression.cc
index bfec4fe..0205409 100644
--- a/src/kvstore/gradient_compression.cc
+++ b/src/kvstore/gradient_compression.cc
@@ -151,8 +151,6 @@ void GradientCompression::Quantize(const mxnet::NDArray& from,
             [from, to, residual, threshold](mxnet::RunContext ctx) {
               std::vector<mxnet::TBlob> inputs = {from.data(), residual->data(), to->data()};
               Quantize1BitImpl(ctx.get_stream<mshadow::gpu>(), inputs, threshold);
-              // Wait GPU kernel to complete
-              ctx.get_stream<mshadow::gpu>()->Wait();
             },
             from.ctx(),
             {from.var()},
@@ -165,8 +163,6 @@ void GradientCompression::Quantize(const mxnet::NDArray& from,
             [from, to, residual, threshold](mxnet::RunContext ctx) {
               std::vector<mxnet::TBlob> inputs = {from.data(), residual->data(), to->data()};
               Quantize2BitImpl(ctx.get_stream<mshadow::gpu>(), inputs, threshold);
-              // Wait GPU kernel to complete
-              ctx.get_stream<mshadow::gpu>()->Wait();
             },
             from.ctx(),
             {from.var()},
diff --git a/src/kvstore/kvstore_dist.h b/src/kvstore/kvstore_dist.h
index bc4ce42..09612a5 100644
--- a/src/kvstore/kvstore_dist.h
+++ b/src/kvstore/kvstore_dist.h
@@ -421,7 +421,9 @@ class KVStoreDist : public KVStoreLocal {
     }
     gradient_compression_->Quantize(comm_buf, &small_buf, &res_buf, priority);
     auto push_to_servers = [this, key, dtype, pskv, small_buf](RunContext rctx,
+                                                               Engine::CallbackOnStart on_start,
                                                                Engine::CallbackOnComplete cb) {
+      on_start();
       size_t size = small_buf.shape().Size() * mshadow::mshadow_sizeof(dtype);
       char* data  = static_cast<char*>(small_buf.data().dptr_);
       // do push. false means no delete
@@ -442,7 +444,9 @@ class KVStoreDist : public KVStoreLocal {
 
   virtual void PushDefault(int key, const NDArray& send_buf, const PSKV& pskv, int priority) {
     auto push_to_servers = [this, key, pskv, send_buf](RunContext rctx,
+                                                       Engine::CallbackOnStart on_start,
                                                        Engine::CallbackOnComplete cb) {
+      on_start();
       const int dtype = send_buf.dtype();
       // convert to ps keys
       const size_t size = send_buf.shape().Size() * mshadow::mshadow_sizeof(dtype);
@@ -464,7 +468,10 @@ class KVStoreDist : public KVStoreLocal {
   // push row sparse gradient
   virtual void PushRowSparse(int key, const NDArray& send_buf, int priority) {
     using namespace rowsparse;
-    auto push_to_servers = [this, key, send_buf](RunContext rctx, Engine::CallbackOnComplete cb) {
+    auto push_to_servers = [this, key, send_buf](RunContext rctx,
+                                                 Engine::CallbackOnStart on_start,
+                                                 Engine::CallbackOnComplete cb) {
+      on_start();
       char* data             = static_cast<char*>(send_buf.data().dptr_);
       const int64_t num_rows = send_buf.aux_shape(kIdx)[0];
       const auto offsets     = send_buf.aux_data(kIdx).dptr<int64_t>();
@@ -492,7 +499,10 @@ class KVStoreDist : public KVStoreLocal {
   }
 
   virtual void PullDefault(int key, const NDArray& recv_buf, int priority) {
-    auto pull_from_servers = [this, key, recv_buf](RunContext rctx, Engine::CallbackOnComplete cb) {
+    auto pull_from_servers = [this, key, recv_buf](RunContext rctx,
+                                                   Engine::CallbackOnStart on_start,
+                                                   Engine::CallbackOnComplete cb) {
+      on_start();
       // convert to ps keys
       size_t size         = recv_buf.shape().Size();
       const int dtype     = recv_buf.dtype();
@@ -531,7 +541,9 @@ class KVStoreDist : public KVStoreLocal {
                               int priority) {
     using namespace rowsparse;
     auto pull_from_servers = [this, key, recv_buf, indices](RunContext rctx,
+                                                            Engine::CallbackOnStart on_start,
                                                             Engine::CallbackOnComplete cb) {
+      on_start();
       // allocate memory for the buffer
       CHECK_EQ(indices.dtype(), mshadow::kInt64);
       const TBlob idx_data  = indices.data();
@@ -573,7 +585,10 @@ class KVStoreDist : public KVStoreLocal {
   }
 
   virtual void PushPullDefault(int key, const NDArray& comm_buf, int priority) {
-    auto pushpull = [this, key, comm_buf](RunContext rctx, Engine::CallbackOnComplete cb) {
+    auto pushpull = [this, key, comm_buf](RunContext rctx,
+                                          Engine::CallbackOnStart on_start,
+                                          Engine::CallbackOnComplete cb) {
+      on_start();
       size_t size         = comm_buf.shape().Size();
       const int dtype     = comm_buf.dtype();
       const int num_bytes = mshadow::mshadow_sizeof(dtype);
diff --git a/src/kvstore/kvstore_dist_server.h b/src/kvstore/kvstore_dist_server.h
index 29bc455..14276a9 100644
--- a/src/kvstore/kvstore_dist_server.h
+++ b/src/kvstore/kvstore_dist_server.h
@@ -430,7 +430,10 @@ class KVStoreDistServer {
     // accumulate row_sparse gradients
     using namespace mshadow;
     Engine::Get()->PushAsync(
-        [to_merge, updateBuf, out](RunContext ctx, Engine::CallbackOnComplete on_complete) {
+        [to_merge, updateBuf, out](RunContext ctx,
+                                   Engine::CallbackOnStart on_start,
+                                   Engine::CallbackOnComplete on_complete) {
+          on_start();
           op::ElemwiseBinaryOp::ComputeEx<cpu, op::mshadow_op::plus>(
               {}, {}, {to_merge, updateBuf->merged}, {kWriteTo}, {out});
           on_complete();
@@ -518,7 +521,10 @@ class KVStoreDistServer {
       store_[master_key] = NDArray(kRowSparseStorage, dshape, Context(), true, type.dtype);
     }
     Engine::Get()->PushAsync(
-        [this, recved, stored, type](RunContext ctx, Engine::CallbackOnComplete on_complete) {
+        [this, recved, stored, type](RunContext ctx,
+                                     Engine::CallbackOnStart on_start,
+                                     Engine::CallbackOnComplete on_complete) {
+          on_start();
           NDArray rsp = stored;
           stored.CheckAndAlloc({mshadow::Shape1(recved.shape()[0])});
           mshadow::Stream<cpu>* s = ctx.get_stream<cpu>();
diff --git a/src/kvstore/kvstore_local.h b/src/kvstore/kvstore_local.h
index 2a0ac3a..8f9dc9b 100644
--- a/src/kvstore/kvstore_local.h
+++ b/src/kvstore/kvstore_local.h
@@ -493,7 +493,10 @@ class KVStoreLocal : public KVStore {
     // GPU requires temp resources
     bool is_gpu = out.ctx().dev_mask() == gpu::kDevMask;
     Engine::Get()->PushAsync(
-        [=](RunContext rctx, Engine::CallbackOnComplete on_complete) {
+        [=](RunContext rctx,
+            Engine::CallbackOnStart on_start,
+            Engine::CallbackOnComplete on_complete) {
+          on_start();
           // copy data.data() to out.data()
           out.CheckAndAlloc({mshadow::Shape1(num_elements)});
           TBlob out_data = out.data();
@@ -510,8 +513,6 @@ class KVStoreLocal : public KVStore {
               mshadow::Stream<gpu>* s = rctx.get_stream<gpu>();
               ndarray::Copy<gpu, gpu>(data_in_ctx.data(), &out_data, ctx, ctx, rctx);
               UniqueImpl(&workspace, s, out);
-              // wait for GPU operations to complete
-              s->Wait();
               break;
             }
 #endif
diff --git a/src/kvstore/p3store_dist.h b/src/kvstore/p3store_dist.h
index 0e0aff0..56912cd 100644
--- a/src/kvstore/p3store_dist.h
+++ b/src/kvstore/p3store_dist.h
@@ -79,7 +79,9 @@ class P3StoreDist : public KVStoreDist {
 
   void PushDefault(int key, const NDArray& send_buf, const PSKV& pskv, int priority) override {
     auto push_to_servers = [this, key, pskv, send_buf, priority](RunContext rctx,
+                                                                 Engine::CallbackOnStart on_start,
                                                                  Engine::CallbackOnComplete cb) {
+      on_start();
       const int dtype = send_buf.dtype();
       // convert to ps keys
       const size_t size = send_buf.shape().Size() * mshadow::mshadow_sizeof(dtype);
@@ -87,7 +89,6 @@ class P3StoreDist : public KVStoreDist {
       // do push. false means no delete
       ps::SArray<char> vals(data, size, false);
       int cmd = GetCommandType(RequestType::kDefaultPushPull, dtype);
-
       size_t off   = 0;
       auto counter = new std::atomic<int>(pskv.keys.size());
       for (size_t idx = 0; idx < pskv.keys.size(); idx++) {
@@ -127,7 +128,9 @@ class P3StoreDist : public KVStoreDist {
     CHECK(gradient_compression_->get_type() == CompressionType::kNone)
         << "Gradient compression not supported in P3StoreDist.";
     auto pull_from_servers = [this, key, recv_buf, priority](RunContext rctx,
+                                                             Engine::CallbackOnStart on_start,
                                                              Engine::CallbackOnComplete cb) {
+      on_start();
       // convert to ps keys
       size_t size         = recv_buf.shape().Size();
       const int dtype     = recv_buf.dtype();
@@ -181,7 +184,9 @@ class P3StoreDist : public KVStoreDist {
     CHECK(gradient_compression_->get_type() == CompressionType::kNone)
         << "Compression not supported in P3StoreDist";
     auto pushpull = [this, key, comm_buf, priority](RunContext rctx,
+                                                    Engine::CallbackOnStart on_start,
                                                     Engine::CallbackOnComplete cb) {
+      on_start();
       size_t size         = comm_buf.shape().Size();
       const int dtype     = comm_buf.dtype();
       const int num_bytes = mshadow::mshadow_sizeof(dtype);
diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc
index d927ff8..cfcdab2 100644
--- a/src/ndarray/ndarray.cc
+++ b/src/ndarray/ndarray.cc
@@ -132,7 +132,22 @@ NDArray::Chunk::~Chunk() {
 #endif
   if (auto engine = engine_ref_.lock()) {
     engine->DeleteVariable(
-        [mem, skip_free](RunContext s) {
+        [mem, skip_free, var = this->var](RunContext s) mutable {
+#if MXNET_USE_CUDA
+          auto& sync_obj = var->sync_object;
+          Storage::SyncObj storage_sync_obj;
+          {
+            std::lock_guard<std::mutex> l(sync_obj.mutex);
+            for (const auto& ev : sync_obj.reader_events) {
+              storage_sync_obj.events.push_back(ev.event);
+            }
+            if (!sync_obj.writer_event.empty()) {
+              const auto& ev = sync_obj.writer_event[0];  // no copy of the record
+              storage_sync_obj.events.push_back(ev.event);
+            }
+          }
+          mem.h.sync_obj = std::move(storage_sync_obj);  // local is dead after this
+#endif
           if (skip_free == false) {
 #if MXNET_USE_ONEDNN == 1
             if (mem.mem) {
@@ -744,7 +759,10 @@ void NDArray::Reorder2DefaultAsync() const {
   std::vector<Engine::VarHandle> mutable_vars(1, this->var());
   NDArray tmp = *this;
   Engine::Get()->PushAsync(
-      [tmp](RunContext ctx, Engine::CallbackOnComplete on_complete) {
+      [tmp](RunContext ctx,
+            Engine::CallbackOnStart on_start,
+            Engine::CallbackOnComplete on_complete) {
+        on_start();
         tmp.ptr_->Reorder2Default();
         on_complete();
       },
@@ -776,7 +794,10 @@ void NDArray::DNNLDataReorderAsync(const dnnl::memory::desc& desc) const {
   NDArray tmp        = *this;
   const auto version = this->version();
   Engine::Get()->PushAsync(
-      [tmp, version, desc](RunContext ctx, Engine::CallbackOnComplete on_complete) {
+      [tmp, version, desc](RunContext ctx,
+                           Engine::CallbackOnStart on_start,
+                           Engine::CallbackOnComplete on_complete) {
+        on_start();
         // MXNet will try to reuse NDArray from memory planning, so we need to ensure
         // the NDArray is still holding the original trunk data.
         if (tmp.version() == version) {
@@ -990,8 +1011,6 @@ void TernaryOp(const NDArray& lhs, const NDArray& mhs, const NDArray& rhs, NDArr
           [lhs, mhs, rhs, ret](RunContext ctx) {
             TBlob tmp = ret.data();
             ndarray::Eval<gpu, OP>(lhs.data(), mhs.data(), rhs.data(), &tmp, ctx);
-            // Wait GPU kernel to complete
-            ctx.get_stream<gpu>()->Wait();
           },
           lhs.ctx(),
           const_vars,
@@ -1078,8 +1097,6 @@ void BinaryOpKernel(const NDArray& lhs, const NDArray& rhs, NDArray* out) {
             TBlob tmp               = ret.data();
             mshadow::Stream<gpu>* s = ctx.get_stream<gpu>();
             ndarray::BinaryOpKernelImpl<OP>(s, lhs.data(), rhs.data(), &tmp);
-            // Wait GPU kernel to complete
-            ctx.get_stream<gpu>()->Wait();
           },
           lhs.ctx(),
           const_vars,
@@ -1129,8 +1146,6 @@ void BinaryOp(const NDArray& lhs, const NDArray& rhs, NDArray* out) {
           [lhs, rhs, ret](RunContext ctx) {
             TBlob tmp = ret.data();
             ndarray::Eval<gpu, OP>(lhs.data(), rhs.data(), &tmp, ctx);
-            // Wait GPU kernel to complete
-            ctx.get_stream<gpu>()->Wait();
           },
           lhs.ctx(),
           const_vars,
@@ -1170,8 +1185,6 @@ void SetValueOp(const real_t& rhs, NDArray* out) {
             } else {
               ndarray::Eval(ctx.get_stream<gpu>(), rhs, ret);
             }
-            // Wait GPU kernel to complete
-            ctx.get_stream<gpu>()->Wait();
             break;
           }
 #endif
@@ -1231,8 +1244,6 @@ void ScalarOp(const NDArray& lhs, const real_t& rhs, NDArray* out) {
           [lhs, rhs, ret](RunContext ctx) {
             TBlob tmp = ret.data();
             ndarray::Eval<gpu, OP, reverse>(lhs.data(), rhs, &tmp, ctx);
-            // Wait GPU kernel to complete
-            ctx.get_stream<gpu>()->Wait();
           },
           lhs.ctx(),
           const_vars,
@@ -1458,7 +1469,10 @@ void CopyFromTo(const NDArray& from, const NDArray& to, int priority, bool is_op
 
   if (a == cpu::kDevMask && b == cpu::kDevMask) {
     Engine::Get()->PushAsync(
-        [from, to, requested](RunContext ctx, Engine::CallbackOnComplete on_complete) {
+        [from, to, requested](RunContext ctx,
+                              Engine::CallbackOnStart on_start,
+                              Engine::CallbackOnComplete on_complete) {
+          on_start();
           CopyFromToImpl<cpu, cpu>(from, to, ctx, requested);
           on_complete();
         },
@@ -1472,9 +1486,11 @@ void CopyFromTo(const NDArray& from, const NDArray& to, int priority, bool is_op
 #if MXNET_USE_CUDA
     if (a == cpu::kDevMask && b == gpu::kDevMask) {
       Engine::Get()->PushAsync(
-          [from, to, requested](RunContext ctx, Engine::CallbackOnComplete on_complete) {
+          [from, to, requested](RunContext ctx,
+                                Engine::CallbackOnStart on_start,
+                                Engine::CallbackOnComplete on_complete) {
+            on_start();
             CopyFromToImpl<cpu, gpu>(from, to, ctx, requested);
-            ctx.get_stream<gpu>()->Wait();
             on_complete();
           },
           to.ctx(),
@@ -1485,9 +1501,11 @@ void CopyFromTo(const NDArray& from, const NDArray& to, int priority, bool is_op
           "CopyCPU2GPU");
     } else if (a == gpu::kDevMask && b == cpu::kDevMask) {
       Engine::Get()->PushAsync(
-          [from, to, requested](RunContext ctx, Engine::CallbackOnComplete on_complete) {
+          [from, to, requested](RunContext ctx,
+                                Engine::CallbackOnStart on_start,
+                                Engine::CallbackOnComplete on_complete) {
+            on_start();
             CopyFromToImpl<gpu, cpu>(from, to, ctx, requested);
-            ctx.get_stream<gpu>()->Wait();
             on_complete();
           },
           from.ctx(),
@@ -1498,9 +1516,11 @@ void CopyFromTo(const NDArray& from, const NDArray& to, int priority, bool is_op
           "CopyGPU2CPU");
     } else if (a == gpu::kDevMask && b == gpu::kDevMask) {
       Engine::Get()->PushAsync(
-          [from, to, requested](RunContext ctx, Engine::CallbackOnComplete on_complete) {
+          [from, to, requested](RunContext ctx,
+                                Engine::CallbackOnStart on_start,
+                                Engine::CallbackOnComplete on_complete) {
+            on_start();
             CopyFromToImpl<gpu, gpu>(from, to, ctx, requested);
-            ctx.get_stream<gpu>()->Wait();
             on_complete();
           },
           from.ctx(),
@@ -1571,8 +1591,6 @@ void ElementwiseSum(const std::vector<NDArray>& source, NDArray* out, int priori
               }
               TBlob tmp = ret.data();
               ndarray::ElementwiseSum<gpu>(source_tblob, &tmp, ctx);
-              // Wait GPU kernel to complete
-              ctx.get_stream<gpu>()->Wait();
             },
             out->ctx(),
             const_vars,
@@ -1601,8 +1619,6 @@ void ElementwiseSum(const std::vector<NDArray>& source, NDArray* out, int priori
 #if MXNET_USE_CUDA
             case gpu::kDevMask: {
               mxnet::ndarray::ElementwiseSum(rctx.get_stream<gpu>(), rsc, source, &result);
-              // wait for GPU operations to complete
-              rctx.get_stream<gpu>()->Wait();
               break;
             }
 #endif
@@ -1696,8 +1712,6 @@ void SampleOP(const real_t& a, const real_t& b, NDArray* out) {
           [a, b, resource, ret](RunContext ctx) {
             TBlob tmp = ret.data();
             ndarray::EvalRandom<gpu, Distribution>(a, b, resource, &tmp, ctx);
-            // Wait GPU kernel to complete
-            ctx.get_stream<gpu>()->Wait();
           },
           out->ctx(),
           {},
@@ -2176,17 +2190,18 @@ void NDArray::SyncCopyFromCPU(const void* data, size_t size) const {
 
   if (this->ctx().dev_mask() == cpu::kDevMask) {
     this->WaitToWrite();
-    RunContext rctx{this->ctx(), nullptr, nullptr, false};
+    RunContext rctx{this->ctx(), nullptr, nullptr};
     TBlob dst = this->data();
     ndarray::Copy<cpu, cpu>(src, &dst, Context::CPU(), Context::CPU(), rctx);
   } else {
 #if MXNET_USE_CUDA
     Engine::Get()->PushAsync(
-        [&](RunContext rctx, Engine::CallbackOnComplete on_complete) {
+        [&](RunContext rctx,
+            Engine::CallbackOnStart on_start,
+            Engine::CallbackOnComplete on_complete) {
+          on_start();
           TBlob dst = this->data();
           ndarray::Copy<cpu, gpu>(src, &dst, Context::CPU(), this->ctx(), rctx);
-          // Wait GPU kernel to complete
-          rctx.get_stream<gpu>()->Wait();
           on_complete();
         },
         this->ctx(),
@@ -2262,11 +2277,13 @@ void NDArray::SyncCopyFromNDArray(const NDArray& src, int i, int j) {
 #if MXNET_USE_CUDA
     if (src_dev_mask == cpu::kDevMask && dst_dev_mask == gpu::kDevMask) {
       Engine::Get()->PushAsync(
-          [&](RunContext rctx, Engine::CallbackOnComplete on_complete) {
+          [&](RunContext rctx,
+              Engine::CallbackOnStart on_start,
+              Engine::CallbackOnComplete on_complete) {
+            on_start();
             const TBlob src_data = (i >= 0 ? src.aux_data(i) : src.data());
             TBlob dst_data       = get_dst_data(src_data.shape_);
             ndarray::Copy<cpu, gpu>(src_data, &dst_data, src.ctx(), this->ctx(), rctx);
-            rctx.get_stream<gpu>()->Wait();
             on_complete();
           },
           this->ctx(),
@@ -2277,11 +2294,13 @@ void NDArray::SyncCopyFromNDArray(const NDArray& src, int i, int j) {
           "SyncCopyFromNDArrayCPU2GPU");
     } else if (src_dev_mask == gpu::kDevMask && dst_dev_mask == cpu::kDevMask) {
       Engine::Get()->PushAsync(
-          [&](RunContext rctx, Engine::CallbackOnComplete on_complete) {
+          [&](RunContext rctx,
+              Engine::CallbackOnStart on_start,
+              Engine::CallbackOnComplete on_complete) {
+            on_start();
             const TBlob src_data = (i >= 0 ? src.aux_data(i) : src.data());
             TBlob dst_data       = get_dst_data(src_data.shape_);
             ndarray::Copy<gpu, cpu>(src_data, &dst_data, src.ctx(), this->ctx(), rctx);
-            rctx.get_stream<gpu>()->Wait();
             on_complete();
           },
           src.ctx(),
@@ -2292,11 +2311,13 @@ void NDArray::SyncCopyFromNDArray(const NDArray& src, int i, int j) {
           "SyncCopyFromNDArrayGPU2CPU");
     } else if (src_dev_mask == gpu::kDevMask && dst_dev_mask == gpu::kDevMask) {
       Engine::Get()->PushAsync(
-          [&](RunContext rctx, Engine::CallbackOnComplete on_complete) {
+          [&](RunContext rctx,
+              Engine::CallbackOnStart on_start,
+              Engine::CallbackOnComplete on_complete) {
+            on_start();
             const TBlob src_data = (i >= 0 ? src.aux_data(i) : src.data());
             TBlob dst_data       = get_dst_data(src_data.shape_);
             ndarray::Copy<gpu, gpu>(src_data, &dst_data, src.ctx(), this->ctx(), rctx);
-            rctx.get_stream<gpu>()->Wait();
             on_complete();
           },
           this->ctx(),
@@ -2340,7 +2361,7 @@ void NDArray::SyncCopyToCPU(void* data, size_t size) const {
   this->WaitToRead();
 
   if (this->ctx().dev_mask() == cpu::kDevMask) {
-    RunContext rctx{this->ctx(), nullptr, nullptr, false};
+    RunContext rctx{this->ctx(), nullptr, nullptr};
     NDArray src = *this;
 #if MXNET_USE_ONEDNN == 1
     if (src.IsDNNLData())
@@ -2350,10 +2371,40 @@ void NDArray::SyncCopyToCPU(void* data, size_t size) const {
   } else {
 #if MXNET_USE_CUDA
     Engine::Get()->PushAsync(
-        [&](RunContext rctx, Engine::CallbackOnComplete on_complete) {
+        [&](RunContext rctx,
+            Engine::CallbackOnStart on_start,
+            Engine::CallbackOnComplete on_complete) {
+          on_start();
+          {
+            auto var       = this->var();
+            auto& sync_obj = var->sync_object;
+            std::lock_guard<std::mutex> lock{sync_obj.mutex};
+            bool has_writer = false;
+            std::shared_ptr<cudaEvent_t> w_ev_ptr;
+            if (!sync_obj.writer_event.empty()) {
+              w_ev_ptr   = sync_obj.writer_event[0].event.lock();
+              has_writer = static_cast<bool>(w_ev_ptr);
+            }
+            for (const auto& ev : sync_obj.reader_events) {
+              auto event_ptr = ev.event.lock();
+              if (!event_ptr) {
+                continue;
+              }
+              cudaEvent_t event = *event_ptr;
+              if (has_writer) {
+                const auto& w_ev = sync_obj.writer_event[0];
+                if (w_ev.stream == ev.stream) {
+                  event      = w_ev.pool_index > ev.pool_index ? *w_ev_ptr : *event_ptr;
+                  has_writer = false;
+                }
+              }
+              CUDA_CALL(cudaEventSynchronize(event));
+            }
+            if (has_writer) {
+              CUDA_CALL(cudaEventSynchronize(*w_ev_ptr));
+            }
+          }
           ndarray::Copy<gpu, cpu>(this->data(), &dst, this->ctx(), Context::CPU(), rctx);
-          // Wait GPU kernel to complete
-          rctx.get_stream<gpu>()->Wait();
           on_complete();
         },
         this->ctx(),
@@ -2386,7 +2437,6 @@ void NDArray::SyncCheckFormat(const bool full_check) const {
     Engine::Get()->PushSync(
         [&](RunContext rctx) {
           common::CheckFormatWrapper<gpu>(rctx, *this, err_cpu, full_check);
-          rctx.get_stream<gpu>()->Wait();
         },
         this->ctx(),
         {this->var()},
@@ -2425,7 +2475,10 @@ void NDArray::WaitToWrite() const {
   Imperative::DCInfo::Compute(*this);
   // Push an empty mutable function to flush all preceding reads to the variable.
   Engine::Get()->PushAsync(
-      [](RunContext, Engine::CallbackOnComplete on_complete) { on_complete(); },
+      [](RunContext, Engine::CallbackOnStart on_start, Engine::CallbackOnComplete on_complete) {
+        on_start();
+        on_complete();
+      },
       Context{},
       {},
       {ptr_->var});
diff --git a/src/operator/custom/ndarray_op.cc b/src/operator/custom/ndarray_op.cc
index ac59d5f..fe07a3e 100644
--- a/src/operator/custom/ndarray_op.cc
+++ b/src/operator/custom/ndarray_op.cc
@@ -87,7 +87,10 @@ void NDArrayOp<xpu>::Forward(const OpContext& ctx,
 
   CHECK(param_.pinfo->forward(ptrs.size(), ptrs.data(), tags.data(), param_.pinfo->p_forward));
   Engine::Get()->PushAsync(
-      [ndcpy, ctx](RunContext rctx, Engine::CallbackOnComplete on_complete) {
+      [ndcpy, ctx](RunContext rctx,
+                   Engine::CallbackOnStart on_start,
+                   Engine::CallbackOnComplete on_complete) {
+        on_start();
         ctx.async_on_complete();
         on_complete();
       },
@@ -144,7 +147,10 @@ void NDArrayOp<xpu>::Backward(const OpContext& ctx,
 
   CHECK(param_.pinfo->backward(ptrs.size(), ptrs.data(), tags.data(), param_.pinfo->p_backward));
   Engine::Get()->PushAsync(
-      [ndcpy, ctx](RunContext rctx, Engine::CallbackOnComplete on_complete) {
+      [ndcpy, ctx](RunContext rctx,
+                   Engine::CallbackOnStart on_start,
+                   Engine::CallbackOnComplete on_complete) {
+        on_start();
         ctx.async_on_complete();
         on_complete();
       },
diff --git a/src/operator/operator_util.cc b/src/operator/operator_util.cc
index 07eba96..b2277b3 100644
--- a/src/operator/operator_util.cc
+++ b/src/operator/operator_util.cc
@@ -490,11 +490,6 @@ void SimpleOpRegEntryImpl::RegisterSourceImperative() {
         [ret, fun, dev_mask, req, env](RunContext ctx) {
           TBlob tmp = ret.data();
           (*fun)(env, &tmp, req, ctx);
-#if MXNET_USE_CUDA
-          if (dev_mask == gpu::kDevMask) {
-            ctx.get_stream<gpu>()->Wait();
-          }
-#endif
         },
         ret.ctx(),
         {},
@@ -672,11 +667,6 @@ void SimpleOpRegEntryImpl::RegisterUnaryImperative() {
         [src, ret, fun, dev_mask, req, env](RunContext ctx) {
           TBlob tmp = ret.data();
           (*fun)(src.data(), env, &tmp, req, ctx);
-#if MXNET_USE_CUDA
-          if (dev_mask == gpu::kDevMask) {
-            ctx.get_stream<gpu>()->Wait();
-          }
-#endif
         },
         src.ctx(),
         const_vars,
@@ -954,11 +944,6 @@ void SimpleOpRegEntryImpl::RegisterBinaryImperative() {
         [lhs, rhs, ret, fun, dev_mask, req, env](RunContext ctx) {
           TBlob tmp = ret.data();
           (*fun)(lhs.data(), rhs.data(), env, &tmp, req, ctx);
-#if MXNET_USE_CUDA
-          if (dev_mask == gpu::kDevMask) {
-            ctx.get_stream<gpu>()->Wait();
-          }
-#endif
         },
         lhs.ctx(),
         const_vars,
diff --git a/src/resource.cc b/src/resource.cc
index 899f58d..010481f 100644
--- a/src/resource.cc
+++ b/src/resource.cc
@@ -266,7 +266,10 @@ class ResourceManagerImpl : public ResourceManager {
     inline void Seed(uint32_t seed) {
       mshadow::Random<xpu>* r = prnd;
       Engine::Get()->PushAsync(
-          [r, seed](RunContext rctx, Engine::CallbackOnComplete on_complete) {
+          [r, seed](RunContext rctx,
+                    Engine::CallbackOnStart on_start,
+                    Engine::CallbackOnComplete on_complete) {
+            on_start();
             r->set_stream(rctx.get_stream<xpu>());
             r->Seed(seed);
             on_complete();
@@ -341,7 +344,10 @@ class ResourceManagerImpl : public ResourceManager {
       uint32_t current_seed = p->ctx.dev_id + i * kMaxNumGPUs + seed * kRandMagic;
       Resource* r           = &(p->resource[i]);
       Engine::Get()->PushAsync(
-          [r, current_seed](RunContext rctx, Engine::CallbackOnComplete on_complete) {
+          [r, current_seed](RunContext rctx,
+                            Engine::CallbackOnStart on_start,
+                            Engine::CallbackOnComplete on_complete) {
+            on_start();
             auto state_space             = static_cast<resource::SpaceAllocator*>(r->ptr_);
             mshadow::Stream<gpu>* stream = rctx.get_stream<gpu>();
             CHECK_EQ(state_space->ctx.dev_id, stream->dev_id)
@@ -448,7 +454,10 @@ class ResourceManagerImpl : public ResourceManager {
     inline void SeedOne(size_t i, uint32_t seed) {
       common::random::RandGenerator<xpu>* r = sampler[i];
       Engine::Get()->PushAsync(
-          [r, seed](RunContext rctx, Engine::CallbackOnComplete on_complete) {
+          [r, seed](RunContext rctx,
+                    Engine::CallbackOnStart on_start,
+                    Engine::CallbackOnComplete on_complete) {
+            on_start();
             r->Seed(rctx.get_stream<xpu>(), seed);
             on_complete();
           },
diff --git a/src/storage/gpu_device_storage.h b/src/storage/gpu_device_storage.h
index ee8be75..a7d7af4 100644
--- a/src/storage/gpu_device_storage.h
+++ b/src/storage/gpu_device_storage.h
@@ -61,6 +61,14 @@ inline void GPUDeviceStorage::Free(Storage::Handle handle) {
 #if MXNET_USE_NCCL
   std::lock_guard<std::mutex> l(Storage::Get()->GetMutex(Context::kGPU));
 #endif  // MXNET_USE_NCCL
+#if MXNET_USE_CUDA
+  for (const auto& ev : handle.sync_obj.events) {
+    auto valid_ev = ev.lock();
+    if (valid_ev) {
+      MSHADOW_CUDA_CALL(cudaEventSynchronize(*valid_ev));
+    }
+  }
+#endif
   CUDA_CALL(cudaFree(handle.dptr))
   profiler::GpuDeviceStorageProfiler::Get()->OnFree(handle);
 }
diff --git a/src/storage/pooled_storage_manager.h b/src/storage/pooled_storage_manager.h
index a58fc6e..0afff32 100644
--- a/src/storage/pooled_storage_manager.h
+++ b/src/storage/pooled_storage_manager.h
@@ -29,6 +29,7 @@
 #include <algorithm>
 #include <mutex>
 #include <tuple>
+#include <utility>
 #include "./storage_manager.h"
 #include "../profiler/storage_profiler.h"
 
@@ -129,7 +130,8 @@ class PooledStorageManager : public StorageManager, public BucketingStrategy, pu
   void Free(Storage::Handle handle) override {
     // Insert returned memory in cache
     std::lock_guard<std::mutex> lock(Storage::Get()->GetMutex(dev_type_));
-    StoringMethod::InsertInCache(BucketingStrategy::get_bucket(handle.size), handle.dptr);
+    StoringMethod::InsertInCache(
+        BucketingStrategy::get_bucket(handle.size), handle.dptr, std::move(handle.sync_obj));
   }
 
   void DirectFree(Storage::Handle handle) override {
@@ -154,7 +156,7 @@ class PooledStorageManager : public StorageManager, public BucketingStrategy, pu
     UNSET_DEVICE(device_store);
   }
 
-  bool MemoryIsAvalable(size_t roundSize) const {
+  bool MemoryIsAvailable(size_t roundSize) const {
     const auto free = contextHelper_->freeMemorySize();
     return free > roundSize && memory_allocation_limit_ <= free - roundSize;
   }
@@ -178,7 +180,7 @@ void PooledStorageManager<BucketingStrategy, StoringMethod>::Alloc(Storage::Hand
   if (!reuse_pool) {
     SET_DEVICE(device_store, contextHelper_, handle->ctx, true);
     roundSize = BucketingStrategy::RoundAllocSizeForBucket(bucket_id);
-    if (!MemoryIsAvalable(roundSize))
+    if (!MemoryIsAvailable(roundSize))
       ReleaseAllNoLock(false);
 
     void* ret = nullptr;
@@ -204,7 +206,19 @@ void PooledStorageManager<BucketingStrategy, StoringMethod>::Alloc(Storage::Hand
     handle->dptr = ret;
   } else {
     // Reusing memory
-    handle->dptr = reuse_pool->back();
+    auto ptr_syncobj = std::move(reuse_pool->back());  // popped right below
+    handle->dptr     = ptr_syncobj.first;
+    if (dev_type_ == Context::kGPU) {
+      handle->sync_obj = std::move(ptr_syncobj.second);
+#if MXNET_USE_CUDA
+      for (const auto& ev : handle->sync_obj.events) {
+        auto valid_ev = ev.lock();
+        if (valid_ev) {
+          MSHADOW_CUDA_CALL(cudaEventSynchronize(*valid_ev));
+        }
+      }
+#endif
+    }
     reuse_pool->pop_back();
   }
 #if MXNET_USE_CUDA
@@ -378,11 +392,11 @@ class RoundPower2 : public RoundHelper {
 class UnorderedMapContainer {
  protected:
   inline void InitContainer(const RoundHelper* p) {}
-  inline void InsertInCache(size_t key, void* dptr) {
-    memory_pool_[key].push_back(dptr);
+  inline void InsertInCache(size_t key, void* dptr, Storage::SyncObj sync_obj) {
+    memory_pool_[key].emplace_back(dptr, std::move(sync_obj));  // by-value sink
   }
 
-  inline std::vector<void*>* GetMemStorage(size_t key) {
+  inline std::vector<std::pair<void*, Storage::SyncObj>>* GetMemStorage(size_t key) {
     auto&& reuse_it = memory_pool_.find(key);
     return reuse_it != memory_pool_.end() && reuse_it->second.size() ? &reuse_it->second : nullptr;
   }
@@ -392,8 +406,8 @@ class UnorderedMapContainer {
     size_t released_memory = 0;
     for (auto&& i : memory_pool_) {
       for (auto&& j : i.second) {
-        contextHelper->Free(j);
-        GPU_PROFILER_ON_FREE(profilerGPU, j);
+        contextHelper->Free(j.first);
+        GPU_PROFILER_ON_FREE(profilerGPU, j.first);
       }
       released_memory += i.first * i.second.size();
       i.second.clear();
@@ -403,7 +417,7 @@ class UnorderedMapContainer {
   }
 
  private:
-  std::unordered_map<size_t, std::vector<void*>> memory_pool_;
+  std::unordered_map<size_t, std::vector<std::pair<void*, Storage::SyncObj>>> memory_pool_;
 };  // class UnorderedMapContainer
 
 /*!
@@ -422,11 +436,11 @@ class VectorContainer {
     memory_pool_.resize(vector_size);
   }
 
-  inline void InsertInCache(size_t idx, void* dptr) {
-    memory_pool_[idx].push_back(dptr);
+  inline void InsertInCache(size_t idx, void* dptr, Storage::SyncObj sync_obj) {
+    memory_pool_[idx].emplace_back(dptr, std::move(sync_obj));  // by-value sink
   }
 
-  std::vector<void*>* GetMemStorage(size_t idx) {
+  std::vector<std::pair<void*, Storage::SyncObj>>* GetMemStorage(size_t idx) {
     auto&& reuse_pool = memory_pool_[idx];
     return reuse_pool.size() ? &reuse_pool : nullptr;
   }
@@ -439,8 +453,8 @@ class VectorContainer {
         continue;
 
       for (auto& j : memory_pool_[i]) {
-        contextHelper->Free(j);
-        GPU_PROFILER_ON_FREE(profilerGPU, j);
+        contextHelper->Free(j.first);
+        GPU_PROFILER_ON_FREE(profilerGPU, j.first);
       }
       released_memory += rndHelper->get_size(i) * memory_pool_[i].size();
       memory_pool_[i].clear();
@@ -449,7 +463,7 @@ class VectorContainer {
   }
 
  private:
-  std::vector<std::vector<void*>> memory_pool_;
+  std::vector<std::vector<std::pair<void*, Storage::SyncObj>>> memory_pool_;
   size_t first_bucket_;
 };  // class VectorContainer
 
diff --git a/src/storage/storage.cc b/src/storage/storage.cc
index 04760b3..d11fde2 100644
--- a/src/storage/storage.cc
+++ b/src/storage/storage.cc
@@ -256,7 +256,7 @@ const std::string env_var_name(const char* dev_type, env_var_type type) {
 
 }  // namespace storage
 
-std::shared_ptr<Storage> Storage::_GetSharedRef() {
+const std::shared_ptr<Storage>& Storage::_GetSharedRef() {
 #ifdef __MXNET_JS__
   // dummy code needed for emscripten code to pass
   // do not know why, the new will be NULLPTR
diff --git a/tests/cpp/engine/threaded_engine_test.cc b/tests/cpp/engine/threaded_engine_test.cc
index 11ca2c9..465e387 100644
--- a/tests/cpp/engine/threaded_engine_test.cc
+++ b/tests/cpp/engine/threaded_engine_test.cc
@@ -110,8 +110,12 @@ double EvaluateWorkloads(const std::vector<Workload>& workloads,
     if (engine == nullptr) {
       EvaluateWorkload(wl, data);
     } else {
-      auto func = [wl, data](RunContext ctx, Engine::CallbackOnComplete cb) {
-        EvaluateWorkload(wl, data); cb();
+      auto func = [wl, data](RunContext ctx,
+                             Engine::CallbackOnStart on_start,
+                             Engine::CallbackOnComplete cb) {
+        on_start();
+        EvaluateWorkload(wl, data);
+        cb();
       };
       std::vector<Engine::VarHandle> reads;
       for (auto i : wl.reads) {
@@ -182,7 +186,7 @@ TEST(Engine, RandSumExpr) {
 
 void Foo(mxnet::RunContext, int i) { printf("The fox says %d\n", i); }
 
-void FooAsyncFunc(void*, void* cb_ptr, void* param) {
+void FooAsyncFunc(void*, void*, void* cb_ptr, void* param) {
   if (param == nullptr) {
     LOG(INFO) << "The fox asynchronously says receiving nothing.";
   } else {
@@ -346,12 +350,16 @@ TEST(Engine, basics) {
   printf("============= Test #1 ==============\n");
   for (int i = 0; i < 10; ++i) {
     oprs.push_back(engine->NewOperator(
-        [i](mxnet::RunContext ctx, mxnet::Engine::CallbackOnComplete cb) {
+        [i](mxnet::RunContext ctx,
+            mxnet::Engine::CallbackOnStart on_start,
+            mxnet::Engine::CallbackOnComplete cb) {
+          on_start();
           Foo(ctx, i);
           std::this_thread::sleep_for(std::chrono::seconds{1});
           cb();
         },
-        {var}, {}));
+        {var},
+        {}));
     engine->Push(oprs.at(i), mxnet::Context{});
   }
   engine->WaitForAll();
@@ -368,12 +376,16 @@ TEST(Engine, basics) {
   oprs.clear();
   for (int i = 0; i < 10; ++i) {
     oprs.push_back(engine->NewOperator(
-        [i](mxnet::RunContext ctx, mxnet::Engine::CallbackOnComplete cb) {
+        [i](mxnet::RunContext ctx,
+            mxnet::Engine::CallbackOnStart on_start,
+            mxnet::Engine::CallbackOnComplete cb) {
+          on_start();
           Foo(ctx, i);
           std::this_thread::sleep_for(std::chrono::milliseconds{500});
           cb();
         },
-        {}, {var}));
+        {},
+        {var}));
     engine->Push(oprs.at(i), mxnet::Context{});
   }
   // std::this_thread::sleep_for(std::chrono::seconds{1});
@@ -394,12 +406,17 @@ TEST(Engine, basics) {
   var = engine->NewVariable();
   oprs.clear();
   oprs.push_back(engine->NewOperator(
-      [](mxnet::RunContext ctx, mxnet::Engine::CallbackOnComplete cb) {
+      [](mxnet::RunContext ctx,
+         mxnet::Engine::CallbackOnStart on_start,
+         mxnet::Engine::CallbackOnComplete cb) {
         std::this_thread::sleep_for(std::chrono::seconds{2});
+        on_start();
         Foo(ctx, 42);
         cb();
       },
-      {}, {var}, mxnet::FnProperty::kCopyFromGPU));
+      {},
+      {var},
+      mxnet::FnProperty::kCopyFromGPU));
   engine->Push(oprs.at(0), mxnet::Context{});
   LOG(INFO) << "IO operator pushed, should wait for 2 seconds.";
   engine->WaitForVar(var);
@@ -414,12 +431,16 @@ TEST(Engine, basics) {
   var = engine->NewVariable();
   oprs.clear();
   oprs.push_back(engine->NewOperator(
-      [](mxnet::RunContext ctx, mxnet::Engine::CallbackOnComplete cb) {
+      [](mxnet::RunContext ctx,
+         mxnet::Engine::CallbackOnStart on_start,
+         mxnet::Engine::CallbackOnComplete cb) {
+        on_start();
         Foo(ctx, 42);
         std::this_thread::sleep_for(std::chrono::seconds{2});
         cb();
       },
-      {var}, {}));
+      {var},
+      {}));
   engine->Push(oprs.at(0), mxnet::Context{});
   LOG(INFO) << "Operator pushed, should not wait.";
   engine->WaitForVar(var);
@@ -452,11 +473,15 @@ TEST(Engine, VarVersion) {
     EXPECT_EQ(var->version(), 0U);
     for (int i = 0; i < 10; ++i) {
       oprs.push_back(engine->NewOperator(
-          [i](mxnet::RunContext ctx, mxnet::Engine::CallbackOnComplete cb) {
+          [i](mxnet::RunContext ctx,
+              mxnet::Engine::CallbackOnStart on_start,
+              mxnet::Engine::CallbackOnComplete cb) {
+            on_start();
             Foo(ctx, i);
             cb();
           },
-          {var}, {}));
+          {var},
+          {}));
       engine->Push(oprs.at(i), mxnet::Context{});
     }
     engine->WaitForAll();
@@ -473,11 +498,15 @@ TEST(Engine, VarVersion) {
     oprs.clear();
     for (int i = 0; i < 10; ++i) {
       oprs.push_back(engine->NewOperator(
-          [i](mxnet::RunContext ctx, mxnet::Engine::CallbackOnComplete cb) {
+          [i](mxnet::RunContext ctx,
+              mxnet::Engine::CallbackOnStart on_start,
+              mxnet::Engine::CallbackOnComplete cb) {
+            on_start();
             Foo(ctx, i);
             cb();
           },
-          {}, {var}));
+          {},
+          {var}));
       engine->Push(oprs.at(i), mxnet::Context{});
     }
     engine->WaitForAll();
diff --git a/tests/python/gpu/test_gluon_gpu.py b/tests/python/gpu/test_gluon_gpu.py
index 94190ab..6a3f3b9 100644
--- a/tests/python/gpu/test_gluon_gpu.py
+++ b/tests/python/gpu/test_gluon_gpu.py
@@ -529,92 +529,6 @@ def test_large_models():
         # Evaluate model
         net(data_in).asnumpy()
 
-# isolated execution bulking test function to be invoked with different env var settings
-
-
-@mx.util.use_np
-def _test_bulking_in_process(seed, time_per_iteration):
-    # Use flip since it's a simple function with same-sized I/O unlikely to ever be fused.
-    class Flip(gluon.HybridBlock):
-        def __init__(self, **kwargs):
-            super(Flip, self).__init__(**kwargs)
-
-        def forward(self, x):
-            return mx.np.flip(x, axis=0)
-
-    def get_net(num_ops):
-        net = nn.HybridSequential()
-        for _ in range(num_ops):
-            net.add(Flip())
-        return net
-
-    data_shape = (10,)
-    num_ops = 1000
-    num_iterations = 20
-
-    # build model
-    x = mx.np.zeros(data_shape)
-    x.attach_grad()
-    dy = mx.np.ones(data_shape)
-    net = get_net(num_ops)
-    net.hybridize(static_alloc=True, static_shape=True)
-
-    # time a number of forward() and backward() executions after some warm-up iterations
-    warmups = 1
-    for i in range(num_iterations + warmups):
-        with autograd.record():
-            if i == warmups:
-                start = time.time()
-            y = net(x)
-            y.backward(dy)
-            x.grad.wait_to_read()
-
-    time_per_iteration.value = (time.time() - start) / num_iterations
-
-def _test_bulking(test_bulking_func):
-    # test case format: (max_fwd_segment_size, max_bwd_segment_size, enable_bulking_in_training)
-    test_cases = [(0, 0, True), (1, 1, True), (15, 15, False),
-                  (15, 0, True), (0, 15, True), (15, 15, True)]
-    times = {}
-    times_str = ''
-    for seg_sizes in test_cases:
-        # Create shared variable to return measured time from test process
-        time_per_iteration = mp.Manager().Value('d', 0.0)
-
-        if not run_in_spawned_process(test_bulking_func,
-                                      {'MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN_FWD': str(seg_sizes[0]),
-                                       'MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN_BWD': str(seg_sizes[1]),
-                                       'MXNET_EXEC_BULK_EXEC_TRAIN': str(seg_sizes[2])},
-                                      time_per_iteration):
-            # skip test since the python version can't run it properly.  Warning msg was logged.
-            return
-        times[seg_sizes] = time_per_iteration.value
-        times_str += \
-            '\n    runtime of (fwd,bwd,enable) op seg setting ({},{},{}) =\t{:.1f} msec'.format(
-                seg_sizes[0], seg_sizes[1], seg_sizes[2], 1000.0 * times[seg_sizes])
-
-    fastest_non_bulked_time = min(times[(0, 0, True)], times[(1, 1, True)], times[(15, 15, False)])
-    slowest_half_bulked_time = max(times[(0, 15, True)], times[(15, 0, True)])
-    fastest_half_bulked_time = min(times[(0, 15, True)], times[(15, 0, True)])
-    fully_bulked_time = times[(15, 15, True)]
-
-    print(times_str)
-    # Non-bulked times[0,0,True], times[1,1,True] and times[15,15,False] should be about the same,
-    # slower than both half-bulked times[0,15,True] and times[15,0,True]
-    assert slowest_half_bulked_time < fastest_non_bulked_time, \
-        'A half-bulked exec time is slower than the non-bulked time by {} secs! {}' \
-        .format(slowest_half_bulked_time - fastest_non_bulked_time, times_str)
-    # The fully bulked times[15,15,True] should be faster than both half-bulked runs
-    assert fully_bulked_time < fastest_half_bulked_time, \
-        'The fully-bulked exec time is slower than a half-bulked time by {} secs! {}' \
-        .format(fully_bulked_time - fastest_half_bulked_time, times_str)
-
-@pytest.mark.skip(reason='skippping temporarily, tracked by https://github.com/apache/incubator-mxnet/issues/14970')
-def test_bulking_gluon_gpu():
-    _test_bulking(_test_bulking_in_process)
-
-
-@mx.util.use_np
 def test_hybridblock_mix_ctx_raise():
     class FooHybrid(gluon.HybridBlock):
         def forward(self, a, b):
diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py
index 9ce005b..195e409 100644
--- a/tests/python/gpu/test_operator_gpu.py
+++ b/tests/python/gpu/test_operator_gpu.py
@@ -48,7 +48,6 @@ from test_sparse_ndarray import *
 from test_sparse_operator import *
 from test_ndarray import *
 from test_subgraph_op import *
-from test_gluon_gpu import _test_bulking
 from test_contrib_operator import test_multibox_target_op
 from test_optimizer import test_adamW
 del test_custom_op_fork  #noqa
@@ -2115,78 +2114,6 @@ def test_bilinear_sampler_versions():
                 if req_dict['grid'] is 'write':
                     assert_almost_equal(exe.grad_dict['grid'], exe_list[ref_idx].grad_dict['grid'], rtol=1e-3, atol=1e-5)
 
-
-# isolated execution bulking test function to be invoked with different env var settings
-def _test_bulking_in_process(seed, time_per_iteration):
-    data_shape = (10,)
-    num_ops = 1000
-    num_iterations = 20
-
-    ctx = default_context()
-    # build symbol
-    X = mx.sym.Variable('X')
-    sym = mx.sym.flip(X, axis=0)
-    for _ in range(num_ops-1):
-        sym = mx.sym.flip(sym, axis=0)
-    x = mx.ndarray.zeros(data_shape)
-    dx = mx.ndarray.zeros(data_shape)
-    dy = mx.ndarray.ones(data_shape)
-    exe = sym._bind(ctx=ctx, args=[x], args_grad = {'X':dx})
-
-    # time a number of forward() and backward() executions after some warm-up iterations
-    warmups = 1
-    for i in range(num_iterations+warmups):
-        if i == warmups:
-            start = time.time()
-        exe.forward(is_train=True)
-        exe.backward(dy)
-        dx.wait_to_read()
-    time_per_iteration.value = (time.time() - start) / num_iterations
-
-
-@pytest.mark.skip(reason='skippping temporarily, tracked by https://github.com/apache/incubator-mxnet/issues/16517')
-def test_bulking_operator_gpu():
-    _test_bulking(_test_bulking_in_process)
-
-
-@pytest.mark.skip(reason='skippping temporarily, tracked by https://github.com/apache/incubator-mxnet/issues/14970')
-def test_bulking():
-    # test case format: (max_fwd_segment_size, max_bwd_segment_size, enable_bulking_in_training)
-    test_cases = [(0,0,True), (1,1,True), (15,15,False), (15,0,True), (0,15,True), (15,15,True)]
-    times = {}
-    times_str = ''
-    for seg_sizes in test_cases:
-        # Create shared variable to return measured time from test process
-        time_per_iteration = mp.Manager().Value('d', 0.0)
-        if not run_in_spawned_process(_test_bulking_in_process,
-                                      {'MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN_FWD' : str(seg_sizes[0]),
-                                       'MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN_BWD' : str(seg_sizes[1]),
-                                       'MXNET_EXEC_BULK_EXEC_TRAIN' : str(seg_sizes[2])},
-                                      time_per_iteration):
-            # skip test since the python version can't run it properly.  Warning msg was logged.
-            return
-        times[seg_sizes] = time_per_iteration.value
-        times_str += \
-            '\n    runtime of (fwd,bwd,enable) op seg setting ({},{},{}) =\t{:.1f} msec'.format(
-            seg_sizes[0], seg_sizes[1], seg_sizes[2], 1000.0 * times[seg_sizes])
-
-    fastest_non_bulked_time = min(times[(0,0,True)], times[(1,1,True)], times[(15,15,False)])
-    slowest_half_bulked_time = max(times[(0,15,True)], times[(15,0,True)])
-    fastest_half_bulked_time = min(times[(0,15,True)], times[(15,0,True)])
-    fully_bulked_time = times[(15,15,True)]
-
-    print(times_str)
-    # Non-bulked times[0,0,True], times[1,1,True] and times[15,15,False] should be about the same,
-    # slower than both half-bulked times[0,15,True] and times[15,0,True]
-    assert slowest_half_bulked_time < fastest_non_bulked_time, \
-        'A half-bulked exec time is slower than the non-bulked time by {} secs! {}' \
-            .format(slowest_half_bulked_time - fastest_non_bulked_time, times_str)
-    # The fully bulked times[15,15,True] should be faster than both half-bulked runs
-    assert fully_bulked_time < fastest_half_bulked_time, \
-        'The fully-bulked exec time is slower than a half-bulked time by {} secs! {}' \
-            .format(fully_bulked_time - fastest_half_bulked_time, times_str)
-
-
 @pytest.mark.serial
 def test_allclose_function_gpu():
     allclose_function([mx.cpu(), mx.gpu(0)])