Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/01/03 23:04:41 UTC

[GitHub] piiswrong commented on a change in pull request #9283: Fix custom op multi-gpu scaling

URL: https://github.com/apache/incubator-mxnet/pull/9283#discussion_r159549694
 
 

 ##########
 File path: src/operator/custom/custom-inl.h
 ##########
 @@ -63,11 +64,80 @@ class Registry {
     return nullptr;
   }
 
-  static Registry* Get();
+  template<typename Func>
+  void Push(const Func& func,
+            const OpContext& ctx,
+            bool recording,
+            bool training,
+            const std::vector<NDArray>& arrs) {
+    if (naive_engine_) {
+      func();
+      ctx.async_on_complete();
+      return;
+    }
+    std::unique_lock<std::mutex> lock(mutex_);
+    q_.push(
+      [=]() mutable {
+        bool prev_recording = Imperative::Get()->set_is_recording(recording);
+        bool prev_training = Imperative::Get()->set_is_training(training);
+
+        func();
+
+        Imperative::Get()->set_is_training(prev_training);
+        Imperative::Get()->set_is_recording(prev_recording);
+
+        std::vector<Engine::VarHandle> vars;
+        for (const auto& i : arrs) vars.push_back(i.var());
+        Engine::Get()->PushSync([=](RunContext rctx) {
 
 Review comment:
   That's fine. The engine is thread-safe.
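
 For context on the threading question: the quoted hunk only shows the
 enqueue side (Push() takes mutex_ and pushes a lambda into q_), so the
 sketch below fills in a plausible consumer, a single worker thread that
 pops and runs the callbacks. The class name CallbackQueue, the condition
 variable cv_, and the worker loop are illustrative assumptions for this
 note, not the PR's actual code; the point is just that the callbacks run
 on a separate thread, which is why the engine's thread-safety matters.

    // Minimal sketch (illustrative, not the PR's code) of the queue-plus-
    // worker pattern: Push() enqueues a callback under mutex_, and a
    // dedicated worker thread pops and runs callbacks one at a time. Those
    // callbacks may then safely call Engine::Get()->PushSync() from the
    // worker thread because, per the comment above, the engine is
    // thread-safe.
    #include <condition_variable>
    #include <functional>
    #include <iostream>
    #include <mutex>
    #include <queue>
    #include <thread>

    class CallbackQueue {
     public:
      CallbackQueue() : done_(false), worker_([this]() { Loop(); }) {}

      ~CallbackQueue() {
        {
          std::unique_lock<std::mutex> lock(mutex_);
          done_ = true;
        }
        cv_.notify_all();
        worker_.join();  // drains remaining callbacks before returning
      }

      // Mirrors the locking in the diff: take mutex_, push the callback,
      // wake the worker.
      void Push(std::function<void()> func) {
        {
          std::unique_lock<std::mutex> lock(mutex_);
          q_.push(std::move(func));
        }
        cv_.notify_one();
      }

     private:
      void Loop() {
        std::unique_lock<std::mutex> lock(mutex_);
        while (true) {
          cv_.wait(lock, [this]() { return done_ || !q_.empty(); });
          if (q_.empty()) return;  // done_ was set and the queue is drained
          std::function<void()> func = std::move(q_.front());
          q_.pop();
          lock.unlock();  // run the callback without holding the lock
          func();
          lock.lock();
        }
      }

      bool done_;
      std::queue<std::function<void()>> q_;
      std::mutex mutex_;
      std::condition_variable cv_;
      std::thread worker_;  // declared last so Loop() sees constructed members
    };

    int main() {
      CallbackQueue queue;
      // Callbacks execute sequentially on the single worker thread.
      for (int i = 0; i < 3; ++i) {
        queue.Push([i]() { std::cout << "callback " << i << " ran\n"; });
      }
      // ~CallbackQueue runs remaining callbacks, then joins the worker.
    }

 One property worth noting: assuming the PR drains q_ from one worker
 thread as sketched, the queued callbacks are serialized, so the
 save/restore of the recording/training flags in the diff's lambda cannot
 interleave between callbacks, while the thread-safe engine absorbs the
 cross-thread PushSync calls.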

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services