You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by jxie@apache.org on 2018/01/05 19:29:26 UTC
[incubator-mxnet] branch master updated: Fix custom op multi-gpu
scaling (#9283)
This is an automated email from the ASF dual-hosted git repository.
jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 004dead Fix custom op multi-gpu scaling (#9283)
004dead is described below
commit 004dead77f7c731fdd7d32a1a123ab6044e4db59
Author: Eric Junyuan Xie <pi...@users.noreply.github.com>
AuthorDate: Fri Jan 5 11:29:22 2018 -0800
Fix custom op multi-gpu scaling (#9283)
* refactor custom op
* fix
* fix
* fix
* fix
---
include/mxnet/op_attr_types.h | 2 -
src/c_api/c_api.cc | 2 +-
src/c_api/c_api_function.cc | 72 +++++++++++---------------
src/executor/graph_executor.cc | 3 --
src/imperative/imperative_utils.h | 89 ++++++++++++++++----------------
src/ndarray/ndarray.cc | 2 +-
src/operator/custom/custom-inl.h | 76 +++++++++++++++++++++++++--
src/operator/custom/custom.cc | 105 +++++++++++++++++---------------------
8 files changed, 194 insertions(+), 157 deletions(-)
diff --git a/include/mxnet/op_attr_types.h b/include/mxnet/op_attr_types.h
index 8cb8a99..fb41d39 100644
--- a/include/mxnet/op_attr_types.h
+++ b/include/mxnet/op_attr_types.h
@@ -92,8 +92,6 @@ enum class ExecType {
* will call OpContext.async_on_complete when operation finishes.
*/
kAsync,
- /*! \brief Run this operator on the scheduling thread without pushing to engine. */
- kLocal,
/*!
* \brief Cross device copy operation, this is a special operator
* That indicates copy across devices, the input and output can sit on different device.
diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc
index 027f00b..c55f6c5 100644
--- a/src/c_api/c_api.cc
+++ b/src/c_api/c_api.cc
@@ -1170,7 +1170,7 @@ int MXRtcFree(RtcHandle handle) {
int MXCustomOpRegister(const char* op_type, CustomOpPropCreator creator) {
API_BEGIN();
- mxnet::op::custom::Registry::Get()->Register(op_type, creator);
+ mxnet::op::custom::CustomOperator::Get()->Register(op_type, creator);
API_END();
}
diff --git a/src/c_api/c_api_function.cc b/src/c_api/c_api_function.cc
index 3cd4f66..e8ca189 100644
--- a/src/c_api/c_api_function.cc
+++ b/src/c_api/c_api_function.cc
@@ -29,6 +29,7 @@
#include "./c_api_common.h"
#include "../operator/operator_common.h"
+#include "../operator/custom/custom-inl.h"
namespace mxnet {
namespace custom_function {
@@ -62,68 +63,55 @@ std::vector<nnvm::NodeEntry> Gradient(
}
OpStatePtr CreateState(const nnvm::NodeAttrs& attrs,
- Context ctx,
- const std::vector<TShape>& ishape,
- const std::vector<int>& itype) {
+ Context ctx,
+ const std::vector<TShape>& ishape,
+ const std::vector<int>& itype) {
LOG(FATAL) << "Not reached";
return OpStatePtr::Create<void*>(nullptr);
}
void Forward(const OpStatePtr& state,
const OpContext& ctx,
- const std::vector<NDArray>& inputs,
+ const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
- const std::vector<NDArray>& outputs) {
+ const std::vector<TBlob>& outputs) {
LOG(FATAL) << "Not reached";
}
void Backward(const OpStatePtr& state,
const OpContext& ctx,
- const std::vector<NDArray>& inputs,
+ const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
- const std::vector<NDArray>& outputs) {
+ const std::vector<TBlob>& outputs) {
const CustomFunctionParam& params = state.get_state<CustomFunctionParam>();
std::vector<NDArrayHandle> ptrs;
+ std::vector<NDArray> cpys;
+
+ auto dev_id = ctx.run_ctx.ctx.dev_id;
for (const auto& i : inputs) {
- NDArray* nd = new NDArray(i.Detach());
+ NDArray* nd = new NDArray(i, dev_id);
ptrs.push_back(reinterpret_cast<NDArrayHandle>(nd));
+ cpys.push_back(*nd);
}
for (const auto& i : outputs) {
- NDArray* nd = new NDArray(i.Detach());
+ NDArray* nd = new NDArray(i, dev_id);
ptrs.push_back(reinterpret_cast<NDArrayHandle>(nd));
+ cpys.push_back(*nd);
}
- bool prev_recording = Imperative::Get()->set_is_recording(false);
- bool prev_training = Imperative::Get()->set_is_training(ctx.is_train);
-
- CHECK(reinterpret_cast<CustomFunctionBwdFunc>(
- params.info->callbacks[kCustomFunctionBackward])(
- inputs.size(), outputs.size(), ptrs.data(),
- reinterpret_cast<const int*>(req.data()), ctx.is_train,
- params.info->contexts[kCustomFunctionBackward]));
-
- Imperative::Get()->set_is_training(prev_training);
- Imperative::Get()->set_is_recording(prev_recording);
+ op::custom::CustomOperator::Get()->Push(
+ [=]() {
+ CHECK(reinterpret_cast<CustomFunctionBwdFunc>(
+ params.info->callbacks[kCustomFunctionBackward])(
+ inputs.size(), outputs.size(),
+ const_cast<NDArrayHandle*>(ptrs.data()),
+ reinterpret_cast<const int*>(req.data()), ctx.is_train,
+ params.info->contexts[kCustomFunctionBackward]));
+ }, ctx, false, ctx.is_train, cpys);
}
-// infer storage function for custom op, which assigns kDefaultStorage for
-// all undefined stypes, and dispatch on DispatchMode::kFComputeEx.
-inline bool InferStorageType(const nnvm::NodeAttrs& attrs,
- const int dev_mask,
- DispatchMode* dispatch_mode,
- std::vector<int> *iattr,
- std::vector<int> *oattr) {
- for (int& v : *oattr) {
- if (v == -1) v = kDefaultStorage;
- }
- for (int& v : *iattr) {
- if (v == -1) v = kDefaultStorage;
- }
- op::dispatch_mode_assign(dispatch_mode, DispatchMode::kFComputeEx);
- return true;
-}
NNVM_REGISTER_OP(_CustomFunction)
.set_num_inputs([](const NodeAttrs& attrs) {
@@ -150,9 +138,8 @@ NNVM_REGISTER_OP(_CustomFunction)
})
.set_attr<FCreateOpState>("FCreateOpState", CreateState)
.set_attr<nnvm::FGradient>("FGradient", Gradient)
-.set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", Forward)
-.set_attr<FStatefulComputeEx>("FStatefulComputeEx<gpu>", Forward)
-.set_attr<FInferStorageType>("FInferStorageType", InferStorageType);
+.set_attr<FStatefulCompute>("FStatefulCompute<cpu>", Forward)
+.set_attr<FStatefulCompute>("FStatefulCompute<gpu>", Forward);
NNVM_REGISTER_OP(_backward_CustomFunction)
@@ -167,11 +154,10 @@ NNVM_REGISTER_OP(_backward_CustomFunction)
.set_attr<bool>("TIsBackward", true)
.set_attr<bool>("TIsLayerOpBackward", true)
.set_attr<FExecType>("FExecType", [](const NodeAttrs& attrs) {
- return ExecType::kLocal;
+ return ExecType::kAsync;
})
-.set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", Backward)
-.set_attr<FStatefulComputeEx>("FStatefulComputeEx<gpu>", Backward)
-.set_attr<FInferStorageType>("FInferStorageType", InferStorageType);
+.set_attr<FStatefulCompute>("FStatefulCompute<cpu>", Backward)
+.set_attr<FStatefulCompute>("FStatefulCompute<gpu>", Backward);
} // namespace custom_function
} // namespace mxnet
diff --git a/src/executor/graph_executor.cc b/src/executor/graph_executor.cc
index 77853a6..5f95df3 100644
--- a/src/executor/graph_executor.cc
+++ b/src/executor/graph_executor.cc
@@ -1473,9 +1473,6 @@ void GraphExecutor::RunOps(bool is_train, size_t topo_start, size_t topo_end) {
CHECK_EQ(opnode.exec->in_array.size(), 1U);
CHECK_EQ(opnode.exec->out_array.size(), 1U);
CopyFromTo(opnode.exec->in_array[0], &(opnode.exec->out_array[0]));
- } else if (opnode.exec->exec_type() == ExecType::kLocal) {
- bool is_gpu = opnode.ctx.dev_mask() == gpu::kDevMask;
- opnode.exec->Run(RunContext{opnode.ctx, nullptr}, is_gpu);
} else if (opnode.cached_opr != nullptr) {
#if MXNET_USE_PROFILER
bool profiling = engine::Profiler::Get()->GetState() == engine::Profiler::kRunning;
diff --git a/src/imperative/imperative_utils.h b/src/imperative/imperative_utils.h
index 8be1eb4..add568d 100644
--- a/src/imperative/imperative_utils.h
+++ b/src/imperative/imperative_utils.h
@@ -341,12 +341,15 @@ inline void PushFCompute(const FCompute& fn,
const std::vector<uint32_t>& mutate_idx,
const std::vector<OpReqType>& req) {
using namespace common;
+ static auto& fexec_type = nnvm::Op::GetAttr<FExecType>("FExecType");
+
bool is_train = Imperative::Get()->is_training();
+ ExecType exec_type = fexec_type.count(op) ? fexec_type[op](attrs) : ExecType::kSync;
+ CHECK(exec_type == ExecType::kSync);
std::vector<NDArray> inputs, outputs;
DerefInputOutput(p_inputs, p_outputs, &inputs, &outputs);
Engine::Get()->PushSync(
- [ctx, attrs, fn, inputs, outputs, requested, is_train, mutate_idx, req](
- RunContext rctx) {
+ [=](RunContext rctx) {
std::vector<TBlob> input_blobs, output_blobs;
// pre-fcompute and post-fcompute storage fallback src NDArrays and dst NDArrays
std::vector<NDArray> pre_temp_src, pre_temp_dst, post_temp_dst, post_temp_src;
@@ -354,8 +357,8 @@ inline void PushFCompute(const FCompute& fn,
std::unordered_map<uint32_t, uint32_t> in_temp_idx_map;
// setup blobs
SetupDefaultBlobsInOut(inputs, outputs, &input_blobs, &output_blobs,
- &pre_temp_src, &pre_temp_dst, &post_temp_src, &post_temp_dst,
- &in_temp_idx_map, mutate_idx);
+ &pre_temp_src, &pre_temp_dst, &post_temp_src,
+ &post_temp_dst, &in_temp_idx_map, mutate_idx);
// setup context
OpContext opctx{is_train, rctx, engine::CallbackOnComplete(), requested};
bool is_gpu = ctx.dev_mask() == gpu::kDevMask;
@@ -384,27 +387,23 @@ inline void PushFComputeEx(const FComputeEx& fn,
static auto& fexec_type = nnvm::Op::GetAttr<FExecType>("FExecType");
bool is_train = Imperative::Get()->is_training();
- ExecType exec_type = ExecType::kSync;
- if (fexec_type.count(op)) {
- exec_type = fexec_type[op](attrs);
- }
+ ExecType exec_type = fexec_type.count(op) ? fexec_type[op](attrs) : ExecType::kSync;
std::vector<NDArray> inputs, outputs;
DerefInputOutput(p_inputs, p_outputs, &inputs, &outputs);
- const auto& run = [ctx, exec_type, is_train, attrs, fn, inputs, outputs, requested, req](
- RunContext rctx) {
+ const auto& run = [=](RunContext rctx) {
OpContext opctx{is_train, rctx, engine::CallbackOnComplete(), requested};
fn(attrs, opctx, inputs, req, outputs);
- if (exec_type == ExecType::kSync) {
- if (rctx.get_ctx().dev_mask() == gpu::kDevMask) {
- rctx.get_stream<gpu>()->Wait();
- }
+ if (ctx.dev_mask() == gpu::kDevMask && exec_type == ExecType::kSync) {
+ rctx.get_stream<gpu>()->Wait();
}
};
- if (exec_type == ExecType::kLocal) {
+
+ if (exec_type == ExecType::kCrossDeviceCopy) {
run(RunContext{ctx, nullptr});
} else {
+ CHECK(exec_type == ExecType::kSync);
Engine::Get()->PushSync(run, ctx, read_vars, write_vars, FnProperty::kNormal,
- 0, PROFILER_MESSAGE(op->name.c_str()));
+ 0, PROFILER_MESSAGE(op->name.c_str()));
}
}
@@ -424,42 +423,30 @@ inline void PushOperator(const OpStatePtr& state,
static auto& fexec_type = nnvm::Op::GetAttr<FExecType>("FExecType");
bool is_train = Imperative::Get()->is_training();
- ExecType exec_type = ExecType::kSync;
- if (fexec_type.count(op)) {
- exec_type = fexec_type[op](attrs);
- }
+ ExecType exec_type = fexec_type.count(op) ? fexec_type[op](attrs) : ExecType::kSync;
std::vector<NDArray> inputs, outputs;
DerefInputOutput(p_inputs, p_outputs, &inputs, &outputs);
auto fcompute = common::GetFCompute<FStatefulCompute>(op, "FStatefulCompute", ctx);
auto fcompute_ex = common::GetFCompute<FStatefulComputeEx>(op, "FStatefulComputeEx", ctx);
if (fcompute_ex != nullptr && dispatch_mode == DispatchMode::kFComputeEx) {
- const auto& run = [state, fcompute_ex, inputs, outputs, requested, is_train,
- exec_type, req](
- RunContext rctx) {
- OpContext opctx{is_train, rctx, engine::CallbackOnComplete(), requested};
- fcompute_ex(state, opctx, inputs, req, outputs);
- if (exec_type == ExecType::kSync) {
- if (rctx.get_ctx().dev_mask() == gpu::kDevMask) {
+ CHECK(exec_type == ExecType::kSync);
+ Engine::Get()->PushSync(
+ [=](RunContext rctx) {
+ OpContext opctx{is_train, rctx, engine::CallbackOnComplete(), requested};
+ fcompute_ex(state, opctx, inputs, req, outputs);
+ if (ctx.dev_mask() == gpu::kDevMask) {
rctx.get_stream<gpu>()->Wait();
}
- }
- };
- if (exec_type == ExecType::kLocal) {
- run(RunContext{ctx, nullptr});
- } else {
- Engine::Get()->PushSync(run, ctx, read_vars, write_vars, FnProperty::kNormal,
- 0, PROFILER_MESSAGE(op->name.c_str()));
- }
+ }, ctx, read_vars, write_vars, FnProperty::kNormal,
+ 0, PROFILER_MESSAGE(op->name.c_str()));
} else {
CHECK(fcompute != nullptr)
<< "One of FStatefulCompute and FStatefulComputeEx must be registered "
<< "for stateful operator " << op->name;
- CHECK(exec_type == ExecType::kSync || exec_type == ExecType::kAsync);
- Engine::Get()->PushSync(
- [state, fcompute, inputs, outputs, requested, is_train, exec_type, mutate_idx, req](
- RunContext rctx) {
- OpContext opctx{is_train, rctx, engine::CallbackOnComplete(), requested};
+
+ const auto& run = [=](RunContext rctx, engine::CallbackOnComplete on_complete) {
+ OpContext opctx{is_train, rctx, on_complete, requested};
std::vector<TBlob> input_blobs, output_blobs;
// pre-fcompute and post-fcompute storage fallback src NDArrays and dst NDArrays
@@ -477,13 +464,23 @@ inline void PushOperator(const OpStatePtr& state,
fcompute(state, opctx, input_blobs, req, output_blobs);
// post-fcompute fallback, cast to original storage type, if necessary
CastNonDefaultStorage(post_temp_src, post_temp_dst, opctx, is_gpu);
- if (exec_type == ExecType::kSync) {
- if (is_gpu) {
- rctx.get_stream<gpu>()->Wait();
- }
+ if (is_gpu && exec_type == ExecType::kSync) {
+ rctx.get_stream<gpu>()->Wait();
}
- }, ctx, read_vars, write_vars, FnProperty::kNormal,
- 0, PROFILER_MESSAGE(op->name.c_str()));
+ };
+
+ if (exec_type == ExecType::kSync) {
+ Engine::Get()->PushSync(
+ [=](RunContext rctx) {
+ run(rctx, engine::CallbackOnComplete());
+ }, ctx, read_vars, write_vars, FnProperty::kNormal,
+ 0, PROFILER_MESSAGE(op->name.c_str()));
+ } else {
+ CHECK(exec_type == ExecType::kAsync);
+ Engine::Get()->PushAsync(
+ run, ctx, read_vars, write_vars, FnProperty::kAsync,
+ 0, PROFILER_MESSAGE(op->name.c_str()));
+ }
}
}
diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc
index 212fd7c..8a3bb8d 100644
--- a/src/ndarray/ndarray.cc
+++ b/src/ndarray/ndarray.cc
@@ -1355,7 +1355,7 @@ NNVM_REGISTER_OP(_copyto)
return true;
})
.set_attr<FExecType>("FExecType", [](const NodeAttrs& attrs) {
- return ExecType::kLocal;
+ return ExecType::kCrossDeviceCopy;
})
.set_attr<nnvm::FGradient>("FGradient", op::ElemwiseGradUseNone{"_copyto"})
.set_attr<bool>("TIsBackward", true)
diff --git a/src/operator/custom/custom-inl.h b/src/operator/custom/custom-inl.h
index 13101da..38aeefd 100644
--- a/src/operator/custom/custom-inl.h
+++ b/src/operator/custom/custom-inl.h
@@ -30,6 +30,7 @@
#include <dmlc/parameter.h>
#include <mxnet/operator.h>
#include <mxnet/c_api.h>
+#include <mxnet/imperative.h>
#include <map>
#include <vector>
#include <string>
@@ -46,7 +47,7 @@ namespace mxnet {
namespace op {
namespace custom {
-class Registry {
+class CustomOperator {
public:
void Register(const std::string &op_type, CustomOpPropCreator creator) {
std::lock_guard<std::mutex> lock(mutex_);
@@ -63,11 +64,80 @@ class Registry {
return nullptr;
}
- static Registry* Get();
+ template<typename Func>
+ void Push(const Func& func,
+ const OpContext& ctx,
+ bool recording,
+ bool training,
+ const std::vector<NDArray>& arrs) {
+ if (naive_engine_) {
+ func();
+ ctx.async_on_complete();
+ return;
+ }
+ std::unique_lock<std::mutex> lock(mutex_);
+ q_.push(
+ [=]() mutable {
+ bool prev_recording = Imperative::Get()->set_is_recording(recording);
+ bool prev_training = Imperative::Get()->set_is_training(training);
+
+ func();
+
+ Imperative::Get()->set_is_training(prev_training);
+ Imperative::Get()->set_is_recording(prev_recording);
+
+ std::vector<Engine::VarHandle> vars;
+ for (const auto& i : arrs) vars.push_back(i.var());
+ Engine::Get()->PushSync([=](RunContext rctx) {
+ ctx.async_on_complete();
+ }, ctx.run_ctx.ctx, vars, {},
+ FnProperty::kNormal, 0, PROFILER_MESSAGE("CustomOperator"));
+ });
+ cv_.notify_all();
+ }
+
+ ~CustomOperator() {
+ if (naive_engine_) return;
+ {
+ std::unique_lock<std::mutex> lock(mutex_);
+ destructing_ = true;
+ cv_.notify_all();
+ }
+ worker_.join();
+ }
+
+ static CustomOperator* Get();
+
private:
- Registry() {}
+ CustomOperator() {
+ destructing_ = false;
+ naive_engine_ = true;
+ if (std::string("NaiveEngine") != dmlc::GetEnv("MXNET_ENGINE_TYPE", std::string())) {
+ naive_engine_ = false;
+ worker_ = std::thread(
+ [&]() {
+ std::unique_lock<std::mutex> lock(mutex_);
+ while (!q_.empty() || !destructing_) {
+ cv_.wait(lock, [&] {return !q_.empty() || destructing_;});
+ while (!q_.empty()) {
+ auto fn = q_.front();
+ lock.unlock();
+ fn();
+ lock.lock();
+ q_.pop();
+ }
+ }
+ });
+ }
+ }
std::mutex mutex_;
std::map<std::string, CustomOpPropCreator> registry_;
+ // async worker
+ std::condition_variable cv_;
+ std::thread worker_;
+ std::queue<std::function<void(void)> > q_;
+ bool naive_engine_;
+ bool destructing_;
};
} // namespace custom
diff --git a/src/operator/custom/custom.cc b/src/operator/custom/custom.cc
index 280b01b..beb5f3d 100644
--- a/src/operator/custom/custom.cc
+++ b/src/operator/custom/custom.cc
@@ -26,7 +26,6 @@
#include "./custom-inl.h"
#include <mxnet/base.h>
#include <mxnet/ndarray.h>
-#include <mxnet/imperative.h>
#include "../elemwise_op_common.h"
@@ -34,8 +33,8 @@ namespace mxnet {
namespace op {
namespace custom {
-Registry* Registry::Get() {
- static Registry inst;
+CustomOperator* CustomOperator::Get() {
+ static CustomOperator inst;
return &inst;
}
@@ -75,8 +74,8 @@ void AttrParser(NodeAttrs* attrs) {
}
}
CHECK(!params.op_type.empty()) << "Required argument `op_type` is missing.";
- CustomOpPropCreator creator = Registry::Get()->Find(params.op_type);
- CHECK(Registry::Get()->Find(params.op_type) != nullptr)
+ CustomOpPropCreator creator = CustomOperator::Get()->Find(params.op_type);
+ CHECK(CustomOperator::Get()->Find(params.op_type) != nullptr)
<< "Cannot find custom operator " << params.op_type;
params.info.reset(new MXCallbackList, [](MXCallbackList* ptr){
reinterpret_cast<CustomOpDelFunc>(ptr->callbacks[kCustomOpPropDelete])(
@@ -269,103 +268,95 @@ OpStatePtr CreateState(const NodeAttrs& attrs, Context ctx,
void Forward(const OpStatePtr& state,
const OpContext& ctx,
- const std::vector<NDArray>& inputs,
+ const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
- const std::vector<NDArray>& outputs) {
+ const std::vector<TBlob>& outputs) {
const CustomParam& params = state.get_state<CustomParam>();
std::vector<void*> ptrs;
std::vector<int> tags;
+ std::vector<NDArray> cpys;
+
+ auto dev_id = ctx.run_ctx.ctx.dev_id;
for (size_t i = 0; i < params.num_args; ++i) {
- NDArray *nd = new NDArray(inputs[i].Detach());
+ NDArray *nd = new NDArray(inputs[i], dev_id);
+ cpys.push_back(*nd);
ptrs.push_back(reinterpret_cast<void*>(nd));
tags.push_back(0);
}
for (size_t i = 0; i < params.num_outs; ++i) {
- NDArray *nd = new NDArray(outputs[i].Detach());
+ NDArray *nd = new NDArray(outputs[i], dev_id);
+ cpys.push_back(*nd);
ptrs.push_back(reinterpret_cast<void*>(nd));
tags.push_back(1);
}
for (size_t i = 0; i < params.num_auxs; ++i) {
- NDArray *nd = new NDArray(inputs[i+params.num_args].Detach());
+ NDArray *nd = new NDArray(inputs[i+params.num_args], dev_id);
+ cpys.push_back(*nd);
ptrs.push_back(reinterpret_cast<void*>(nd));
tags.push_back(4);
}
- bool prev_recording = Imperative::Get()->set_is_recording(false);
- bool prev_training = Imperative::Get()->set_is_training(ctx.is_train);
-
- CHECK(reinterpret_cast<CustomOpFBFunc>(params.info->callbacks[kCustomOpForward])(
- ptrs.size(), ptrs.data(), tags.data(), reinterpret_cast<const int*>(req.data()),
- static_cast<int>(ctx.is_train), params.info->contexts[kCustomOpForward]));
-
- Imperative::Get()->set_is_training(prev_training);
- Imperative::Get()->set_is_recording(prev_recording);
+ CustomOperator::Get()->Push(
+ [=]() {
+ CHECK(reinterpret_cast<CustomOpFBFunc>(params.info->callbacks[kCustomOpForward])(
+ ptrs.size(), const_cast<void**>(ptrs.data()), const_cast<int*>(tags.data()),
+ reinterpret_cast<const int*>(req.data()), static_cast<int>(ctx.is_train),
+ params.info->contexts[kCustomOpForward]));
+ }, ctx, false, ctx.is_train, cpys);
}
void Backward(const OpStatePtr& state,
const OpContext& ctx,
- const std::vector<NDArray>& inputs,
+ const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
- const std::vector<NDArray>& outputs) {
+ const std::vector<TBlob>& outputs) {
const CustomParam& params = state.get_state<CustomParam>();
size_t total = 2*params.num_args + 2*params.num_outs + params.num_auxs;
std::vector<void*> ptrs(params.num_args + 2*params.num_outs, nullptr);
std::vector<int> tags;
+ std::vector<NDArray> cpys;
+
ptrs.reserve(total);
tags.reserve(total);
for (size_t i = 0; i < params.num_outs; ++i) tags.push_back(3);
for (size_t i = 0; i < params.num_args; ++i) tags.push_back(0);
for (size_t i = 0; i < params.num_outs; ++i) tags.push_back(1);
+ auto dev_id = ctx.run_ctx.ctx.dev_id;
+
for (size_t i = 0; i < params.bwd_idx.size(); ++i) {
- NDArray *nd = new NDArray(inputs[i].Detach());
+ NDArray *nd = new NDArray(inputs[i], dev_id);
+ cpys.push_back(*nd);
ptrs[params.bwd_idx[i]] = reinterpret_cast<void*>(nd);
}
for (size_t i = 0; i < ptrs.size(); ++i) {
if (ptrs[i] == nullptr) ptrs[i] = reinterpret_cast<void*>(new NDArray());
}
for (const auto& i : outputs) {
- NDArray* nd = new NDArray(i.Detach());
+ NDArray* nd = new NDArray(i, dev_id);
+ cpys.push_back(*nd);
ptrs.push_back(reinterpret_cast<void*>(nd));
tags.push_back(2);
}
for (size_t i = 0; i < params.num_auxs; ++i) {
- NDArray* nd = new NDArray(inputs[inputs.size()-params.num_auxs+i].Detach());
+ NDArray* nd = new NDArray(inputs[inputs.size()-params.num_auxs+i], dev_id);
+ cpys.push_back(*nd);
ptrs.push_back(reinterpret_cast<void*>(nd));
tags.push_back(4);
}
- bool prev_recording = Imperative::Get()->set_is_recording(false);
- bool prev_training = Imperative::Get()->set_is_training(ctx.is_train);
-
- CHECK(reinterpret_cast<CustomOpFBFunc>(params.info->callbacks[kCustomOpBackward])(
- ptrs.size(), ptrs.data(), tags.data(), reinterpret_cast<const int*>(req.data()),
- static_cast<int>(ctx.is_train), params.info->contexts[kCustomOpBackward]));
-
- Imperative::Get()->set_is_training(prev_training);
- Imperative::Get()->set_is_recording(prev_recording);
-}
-
-// infer storage function for custom op, which assigns kDefaultStorage for
-// all undefined stypes, and dispatch on DispatchMode::kFComputeEx.
-inline bool InferStorageType(const nnvm::NodeAttrs& attrs,
- const int dev_mask,
- DispatchMode* dispatch_mode,
- std::vector<int> *iattr,
- std::vector<int> *oattr) {
- for (int& v : *oattr) {
- if (v == -1) v = kDefaultStorage;
- }
- for (int& v : *iattr) {
- if (v == -1) v = kDefaultStorage;
- }
- dispatch_mode_assign(dispatch_mode, DispatchMode::kFComputeEx);
- return true;
+ CustomOperator::Get()->Push(
+ [=]() {
+ CHECK(reinterpret_cast<CustomOpFBFunc>(params.info->callbacks[kCustomOpBackward])(
+ ptrs.size(), const_cast<void**>(ptrs.data()), const_cast<int*>(tags.data()),
+ reinterpret_cast<const int*>(req.data()), static_cast<int>(ctx.is_train),
+ params.info->contexts[kCustomOpBackward]));
+ }, ctx, false, ctx.is_train, cpys);
}
NNVM_REGISTER_OP(Custom)
@@ -401,13 +392,12 @@ Please check the tutorial here: http://mxnet.io/how_to/new_op.html.
return ret;
})
.set_attr<FExecType>("FExecType", [](const NodeAttrs& attrs) {
- return ExecType::kLocal;
+ return ExecType::kAsync;
})
.set_attr<nnvm::FGradient>("FGradient", Gradient)
.set_attr<FCreateOpState>("FCreateOpState", CreateState)
-.set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", Forward)
-.set_attr<FStatefulComputeEx>("FStatefulComputeEx<gpu>", Forward)
-.set_attr<FInferStorageType>("FInferStorageType", InferStorageType)
+.set_attr<FStatefulCompute>("FStatefulCompute<cpu>", Forward)
+.set_attr<FStatefulCompute>("FStatefulCompute<gpu>", Forward)
.add_argument("data", "NDArray-or-Symbol[]", "Input data for the custom operator.")
.add_argument("op_type", "string", "Name of the custom operator. "
"This is the name that is passed to `mx.operator.register` "
@@ -426,11 +416,10 @@ NNVM_REGISTER_OP(_backward_Custom)
.set_attr<bool>("TIsLayerOpBackward", true)
.set_attr<bool>("TIsBackward", true)
.set_attr<FExecType>("FExecType", [](const NodeAttrs& attrs) {
- return ExecType::kLocal;
+ return ExecType::kAsync;
})
-.set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", Backward)
-.set_attr<FStatefulComputeEx>("FStatefulComputeEx<gpu>", Backward)
-.set_attr<FInferStorageType>("FInferStorageType", InferStorageType);
+.set_attr<FStatefulCompute>("FStatefulCompute<cpu>", Backward)
+.set_attr<FStatefulCompute>("FStatefulCompute<gpu>", Backward);
} // namespace custom
} // namespace op
--
To stop receiving notification emails like this one, please contact
["commits@mxnet.apache.org" <commits@mxnet.apache.org>].