You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by jx...@apache.org on 2017/08/08 02:41:28 UTC

[incubator-mxnet] branch master updated: Update custom.cc (#7373)

This is an automated email from the ASF dual-hosted git repository.

jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new ada6d4e  Update custom.cc (#7373)
ada6d4e is described below

commit ada6d4e0bbfb6a244a868c8ef6edf40529dd996d
Author: Eric Junyuan Xie <pi...@users.noreply.github.com>
AuthorDate: Mon Aug 7 19:41:25 2017 -0700

    Update custom.cc (#7373)
---
 src/operator/custom/custom.cc | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/src/operator/custom/custom.cc b/src/operator/custom/custom.cc
index ee42063..5a40be9 100644
--- a/src/operator/custom/custom.cc
+++ b/src/operator/custom/custom.cc
@@ -268,13 +268,13 @@ void Forward(const OpStatePtr& state,
     tags.push_back(4);
   }
 
-  bool old = autograd::AutogradRuntime::Get()->SetIsTraining(false);
+  bool old = autograd::AutogradRuntime::Get()->SetIsRecording(false);
 
   CHECK(reinterpret_cast<CustomOpFBFunc>(params.info->callbacks[kCustomOpForward])(
     ptrs.size(), ptrs.data(), tags.data(), reinterpret_cast<const int*>(req.data()),
     static_cast<int>(ctx.is_train), params.info->contexts[kCustomOpForward]));
 
-  autograd::AutogradRuntime::Get()->SetIsTraining(old);
+  autograd::AutogradRuntime::Get()->SetIsRecording(old);
 }
 
 
@@ -312,13 +312,13 @@ void Backward(const OpStatePtr& state,
     tags.push_back(4);
   }
 
-  bool old = autograd::AutogradRuntime::Get()->SetIsTraining(false);
+  bool old = autograd::AutogradRuntime::Get()->SetIsRecording(false);
 
   CHECK(reinterpret_cast<CustomOpFBFunc>(params.info->callbacks[kCustomOpBackward])(
-    ptrs.size(), ptrs.data(), tags.data(), reinterpret_cast<const int*>(req.data()), 1,
-    params.info->contexts[kCustomOpBackward]));
+    ptrs.size(), ptrs.data(), tags.data(), reinterpret_cast<const int*>(req.data()),
+    static_cast<int>(ctx.is_train), params.info->contexts[kCustomOpBackward]));
 
-  autograd::AutogradRuntime::Get()->SetIsTraining(old);
+  autograd::AutogradRuntime::Get()->SetIsRecording(old);
 }
 
 

-- 
To stop receiving notification emails like this one, please contact
commits@mxnet.apache.org.