You are viewing a plain-text version of this content; the link to the canonical version was not preserved in this rendering.
Posted to commits@mxnet.apache.org by jx...@apache.org on 2017/08/08 02:41:28 UTC
[incubator-mxnet] branch master updated: Update custom.cc (#7373)
This is an automated email from the ASF dual-hosted git repository.
jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new ada6d4e Update custom.cc (#7373)
ada6d4e is described below
commit ada6d4e0bbfb6a244a868c8ef6edf40529dd996d
Author: Eric Junyuan Xie <pi...@users.noreply.github.com>
AuthorDate: Mon Aug 7 19:41:25 2017 -0700
Update custom.cc (#7373)
---
src/operator/custom/custom.cc | 12 ++++++------
1 file changed, 6 insertions(+), 6 deletions(-)
diff --git a/src/operator/custom/custom.cc b/src/operator/custom/custom.cc
index ee42063..5a40be9 100644
--- a/src/operator/custom/custom.cc
+++ b/src/operator/custom/custom.cc
@@ -268,13 +268,13 @@ void Forward(const OpStatePtr& state,
tags.push_back(4);
}
- bool old = autograd::AutogradRuntime::Get()->SetIsTraining(false);
+ bool old = autograd::AutogradRuntime::Get()->SetIsRecording(false);
CHECK(reinterpret_cast<CustomOpFBFunc>(params.info->callbacks[kCustomOpForward])(
ptrs.size(), ptrs.data(), tags.data(), reinterpret_cast<const int*>(req.data()),
static_cast<int>(ctx.is_train), params.info->contexts[kCustomOpForward]));
- autograd::AutogradRuntime::Get()->SetIsTraining(old);
+ autograd::AutogradRuntime::Get()->SetIsRecording(old);
}
@@ -312,13 +312,13 @@ void Backward(const OpStatePtr& state,
tags.push_back(4);
}
- bool old = autograd::AutogradRuntime::Get()->SetIsTraining(false);
+ bool old = autograd::AutogradRuntime::Get()->SetIsRecording(false);
CHECK(reinterpret_cast<CustomOpFBFunc>(params.info->callbacks[kCustomOpBackward])(
- ptrs.size(), ptrs.data(), tags.data(), reinterpret_cast<const int*>(req.data()), 1,
- params.info->contexts[kCustomOpBackward]));
+ ptrs.size(), ptrs.data(), tags.data(), reinterpret_cast<const int*>(req.data()),
+ static_cast<int>(ctx.is_train), params.info->contexts[kCustomOpBackward]));
- autograd::AutogradRuntime::Get()->SetIsTraining(old);
+ autograd::AutogradRuntime::Get()->SetIsRecording(old);
}
--
To stop receiving notification emails like this one, please contact
"commits@mxnet.apache.org" <co...@mxnet.apache.org>.