You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by ma...@apache.org on 2018/06/15 21:52:24 UTC
[incubator-mxnet] branch master updated: Revert "[WIP] Do Not Merge. Static memory allocation for cached_op (#10817)" (#11311)
This is an automated email from the ASF dual-hosted git repository.
marcoabreu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new e48a8fd Revert "[WIP] Do Not Merge. Static memory allocation for cached_op (#10817)" (#11311)
e48a8fd is described below
commit e48a8fd45112b7542b9f3e6499cdd8dc044a0a43
Author: Marco de Abreu <ma...@users.noreply.github.com>
AuthorDate: Fri Jun 15 14:52:16 2018 -0700
Revert "[WIP] Do Not Merge. Static memory allocation for cached_op (#10817)" (#11311)
This reverts commit 2dbd143e4892bb9ad4aa1835c79f0046603e3531.
---
include/mxnet/c_api.h | 5 +
include/mxnet/imperative.h | 89 ++++
include/mxnet/ndarray.h | 8 -
include/mxnet/op_attr_types.h | 33 +-
python/mxnet/_ctypes/ndarray.py | 16 +-
python/mxnet/gluon/block.py | 74 ++--
src/c_api/c_api_ndarray.cc | 26 +-
src/engine/threaded_engine.cc | 3 +-
src/executor/attach_op_execs_pass.cc | 165 ++++---
src/executor/attach_op_resource_pass.cc | 16 +-
src/executor/exec_pass.h | 28 +-
src/executor/graph_executor.cc | 2 +-
src/imperative/cached_op.cc | 750 ++++++--------------------------
src/imperative/cached_op.h | 174 --------
src/imperative/imperative.cc | 90 +++-
src/imperative/imperative_utils.cc | 120 -----
src/imperative/imperative_utils.h | 256 ++---------
tests/python/unittest/test_gluon.py | 56 +--
18 files changed, 523 insertions(+), 1388 deletions(-)
diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h
index 4dd858a..55c26bc 100644
--- a/include/mxnet/c_api.h
+++ b/include/mxnet/c_api.h
@@ -987,6 +987,11 @@ MXNET_DLL int MXCreateCachedOpEx(SymbolHandle handle,
int num_flags,
const char** keys,
const char** vals,
+ int num_inputs,
+ const char** input_names,
+ int num_params,
+ const char** param_names,
+ NDArrayHandle* params,
CachedOpHandle *out);
/*!
* \brief free cached operator
diff --git a/include/mxnet/imperative.h b/include/mxnet/imperative.h
index 7ea60df..758ce85 100644
--- a/include/mxnet/imperative.h
+++ b/include/mxnet/imperative.h
@@ -35,6 +35,23 @@
#include "./ndarray.h"
namespace mxnet {
+/*! \brief CachedOp Parameters */
+struct CachedOpConfig : public dmlc::Parameter<CachedOpConfig> {
+ uint32_t inline_limit;
+ uint32_t forward_bulk_size;
+ uint32_t backward_bulk_size;
+ DMLC_DECLARE_PARAMETER(CachedOpConfig) {
+ DMLC_DECLARE_FIELD(inline_limit)
+ .set_default(2)
+ .describe("Maximum number of operators that can be inlined.");
+ DMLC_DECLARE_FIELD(forward_bulk_size)
+ .set_default(dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN", 15))
+ .describe("Segment size of bulk execution during forward pass.");
+ DMLC_DECLARE_FIELD(backward_bulk_size)
+ .set_default(dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN", 15))
+ .describe("Segment size of bulk execution during backward pass.");
+ }
+};
/*! \brief runtime functions for NDArray */
class Imperative {
public:
@@ -77,6 +94,67 @@ class Imperative {
&& info.out_grads.size() == 1;
}
};
+ class CachedOp {
+ public:
+ CachedOp(
+ const nnvm::Symbol& sym,
+ const std::vector<std::pair<std::string, std::string> >& flags,
+ const std::vector<std::string> arg_names,
+ const std::unordered_map<std::string, std::vector<NDArray> >& params);
+ uint32_t num_inputs() {
+ return fwd_graph_.indexed_graph().input_nodes().size();
+ }
+ uint32_t num_outputs() {
+ return fwd_graph_.outputs.size();
+ }
+ uint32_t num_backward_inputs() {
+ return bwd_ograd_dep_.size() + bwd_in_dep_.size() + bwd_out_dep_.size();
+ }
+ std::vector<bool>& save_inputs() {
+ return save_inputs_;
+ }
+ std::vector<bool>& save_outputs() {
+ return save_outputs_;
+ }
+ const std::unordered_set<uint32_t>& mutable_input_nodes() {
+ return fwd_graph_.indexed_graph().mutable_input_nodes();
+ }
+ nnvm::Graph GetForwardGraph(const bool recording,
+ const std::vector<NDArray*>& inputs);
+ nnvm::Graph GetBackwardGraph(const OpStatePtr& state,
+ const std::vector<OpReqType>& reqs,
+ const std::vector<NDArray*>& inputs);
+ std::vector<nnvm::NodeEntry> Gradient(const nnvm::NodePtr& node,
+ const std::vector<nnvm::NodeEntry>& ograds);
+ void Forward(const std::shared_ptr<CachedOp>& op_ptr,
+ const std::vector<NDArray*>& args,
+ const std::vector<NDArray*>& outputs);
+ void Backward(const bool retain_graph,
+ const OpStatePtr& state,
+ const std::vector<NDArray*>& inputs,
+ const std::vector<OpReqType>& reqs,
+ const std::vector<NDArray*>& outputs);
+
+ private:
+ struct CachedOpState {
+ std::vector<NDArray> buff;
+ std::vector<OpStatePtr> states;
+ };
+ std::mutex mutex_;
+ CachedOpConfig config_;
+ nnvm::Graph fwd_graph_;
+ nnvm::Graph grad_graph_;
+ nnvm::Graph full_graph_;
+ std::unordered_map<Context, std::vector<NDArray> > params_;
+ bool inlining_;
+ std::vector<nnvm::NodeEntry> ograd_entries_;
+ std::vector<bool> curr_grad_req_;
+ std::vector<uint32_t> bwd_in_dep_, bwd_out_dep_, bwd_ograd_dep_;
+ std::vector<uint32_t> fwd_args_idx_;
+ std::vector<uint32_t> fwd_params_idx_;
+ std::vector<uint32_t> bwd_input_eid_;
+ std::vector<bool> save_inputs_, save_outputs_;
+ };
/*! \brief whether operator recording is on. */
bool is_training() const {
return is_train_;
@@ -144,6 +222,15 @@ class Imperative {
uint32_t num_inputs, uint32_t num_outputs,
std::vector<bool> *p_save_inputs,
std::vector<bool> *p_save_outputs);
+ void RunGraph(
+ const bool retain_graph,
+ const nnvm::IndexedGraph& idx,
+ const std::vector<NDArray*> arrays,
+ size_t node_start, size_t node_end,
+ std::vector<OpReqType>&& array_reqs,
+ std::vector<uint32_t>&& ref_count,
+ std::vector<OpStatePtr> *p_states,
+ const DispatchModeVector& dispatch_modes);
/*! \brief indicate whether is training. */
#if DMLC_CXX11_THREAD_LOCAL
static thread_local bool is_train_;
@@ -160,5 +247,7 @@ class Imperative {
int backward_bulk_size_{0};
};
+using CachedOpPtr = std::shared_ptr<Imperative::CachedOp>;
+
} // namespace mxnet
#endif // MXNET_IMPERATIVE_H_
diff --git a/include/mxnet/ndarray.h b/include/mxnet/ndarray.h
index ae96fd8..e243eb7 100644
--- a/include/mxnet/ndarray.h
+++ b/include/mxnet/ndarray.h
@@ -155,14 +155,6 @@ class NDArray {
return byte_offset_ > 0 || shape() != ptr_->storage_shape;
}
- /* \brief Check whether the two arrays are the same array */
- inline bool IsSame(const NDArray& other) {
- return ptr_ == other.ptr_ &&
- shape_ == other.shape_ &&
- byte_offset_ == other.byte_offset_ &&
- dtype_ == other.dtype_;
- }
-
/*!
* \return the shape of current NDArray.
*/
diff --git a/include/mxnet/op_attr_types.h b/include/mxnet/op_attr_types.h
index f4694ef..3969d84 100644
--- a/include/mxnet/op_attr_types.h
+++ b/include/mxnet/op_attr_types.h
@@ -126,36 +126,25 @@ class OpStatePtr {
template<typename T, typename... Args>
static OpStatePtr Create(Args&&... args) {
OpStatePtr ret;
- auto state = new T(std::forward<Args>(args)...);
- auto var = Engine::Get()->NewVariable();
- ret.ptr_.reset(
- new OpState(var, state),
- [](OpState* p) {
- Engine::Get()->DeleteVariable([](RunContext s) {}, Context::CPU(), p->var);
- delete reinterpret_cast<T*>(p->state);
- delete p;
- });
+ ret.ptr_ = std::make_shared<OpState>();
+ ret.ptr_->var_ = Engine::Get()->NewVariable();
+ ret.ptr_->state_.construct<T>(std::forward<Args>(args)...);
return ret;
}
/* \brief Get engine variable associated with this state */
engine::VarHandle get_var() const {
- return ptr_->var;
+ return ptr_->var_;
}
/* \brief Get state of type T */
template<typename T>
T& get_state() const {
- return *reinterpret_cast<T*>(ptr_->state);
+ return dmlc::get<T>(ptr_->state_);
}
/* \brief clear state */
void reset() {
ptr_.reset();
}
- /* \brief checks whether the managed object is managed only by the current
- OpStatePtr instance */
- bool unique() const {
- return ptr_.unique();
- }
/* \brief Whether state is empty */
explicit operator bool() const {
return ptr_ ? true : false;
@@ -164,12 +153,16 @@ class OpStatePtr {
private:
/* \brief state structure */
struct OpState {
- engine::VarHandle var;
- void* state;
-
- OpState(engine::VarHandle var_, void* state_) : var(var_), state(state_) {}
+ OpState() {}
OpState(const OpState& other) = delete;
OpState& operator=(const OpState& other) = delete;
+
+ ~OpState() {
+ Engine::Get()->DeleteVariable([](RunContext s) {}, Context::CPU(), var_);
+ }
+
+ engine::VarHandle var_;
+ dmlc::any state_;
};
/* \brief shared pointer to state */
std::shared_ptr<OpState> ptr_;
diff --git a/python/mxnet/_ctypes/ndarray.py b/python/mxnet/_ctypes/ndarray.py
index f324545..d2cae0c 100644
--- a/python/mxnet/_ctypes/ndarray.py
+++ b/python/mxnet/_ctypes/ndarray.py
@@ -105,14 +105,28 @@ def _imperative_invoke(handle, ndargs, keys, vals, out):
class CachedOp(object):
"""Cached operator handle."""
__slots__ = ["handle"]
- def __init__(self, sym, flags=()):
+ def __init__(self, sym, flags=(), inputs=None, params=None):
self.handle = CachedOpHandle()
+ param_names = []
+ param_arrays = []
+ if inputs is None:
+ assert params is None, "When inputs is None params must also be None."
+ inputs = sym.list_inputs()
+ elif params is not None:
+ for name, arrs in params.items():
+ param_arrays.extend(arrs)
+ param_names.extend([name] * len(arrs))
check_call(_LIB.MXCreateCachedOpEx(
sym.handle,
len(flags),
c_str_array([key for key, _ in flags]),
c_str_array([str(val) for _, val in flags]),
+ len(inputs),
+ c_str_array(inputs),
+ len(param_names),
+ c_str_array(param_names),
+ c_handle_array(param_arrays),
ctypes.byref(self.handle)))
def __del__(self):
diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py
index 293fafa..3b97c05 100644
--- a/python/mxnet/gluon/block.py
+++ b/python/mxnet/gluon/block.py
@@ -502,16 +502,8 @@ class Block(object):
----------
active : bool, default True
Whether to turn hybrid on or off.
- static_alloc : bool, default False
- Statically allocate memory to improve speed. Memory usage may increase.
- static_shape : bool, default False
- Optimize for invariant input shapes between iterations. Must also
- set static_alloc to True. Change of input shapes is still allowed
- but slower.
- forward_bulk_size : int, default 15
- Segment size of bulk execution during forward pass.
- backward_bulk_size : int, default 15
- Segment size of bulk execution during backward pass.
+ **kwargs : string
+ Additional flags for hybridized operator.
"""
for cld in self._children.values():
cld.hybridize(active, **kwargs)
@@ -704,7 +696,7 @@ class HybridBlock(Block):
self._out_format = None
self._in_format = None
self._active = False
- self._flags = []
+ self._flags = {}
def __setattr__(self, name, value):
"""Registers parameters."""
@@ -731,43 +723,39 @@ class HybridBlock(Block):
return self._cached_graph
def _build_cache(self, *args):
- data, out = self._get_graph(*args)
- data_names = {data.name : i for i, data in enumerate(data)}
- params = self.collect_params()
- input_names = out.list_inputs()
+ inputs, out = self._get_graph(*args)
+ input_names = [i.name for i in inputs]
+ params = self.collect_params()
param_names = set(params.keys())
- expected_names = set(input_names)
+ expected_names = set(out.list_inputs())
for name in expected_names:
- assert name in param_names or name in data_names, \
+ assert name in param_names or name in input_names, \
"Unknown input to HybridBlock: %s"%name
- used_data_names = [i for i in data_names if i in expected_names]
- if len(used_data_names) != len(data_names):
- unused = ', '.join(['%d-th'%i for name, i in data_names.items()
+ used_input_names = [i for i in input_names if i in expected_names]
+ if len(used_input_names) != len(input_names):
+ unused = ', '.join(['%d-th'%i for i, name in enumerate(input_names)
if name not in expected_names])
warnings.warn("The %s input to HybridBlock is not used by any "
"computation. Is this intended?"%unused, stacklevel=4)
- used_param_names = [i for i in param_names if i in expected_names]
+ used_param_names = set(i for i in param_names if i in expected_names)
if len(used_param_names) != len(param_names):
- unused = ', '.join(list(param_names - set(used_param_names)))
+ unused = ', '.join(list(param_names - used_param_names))
warnings.warn("Parameter %s is not used by any computation. "
"Is this intended?"%unused, stacklevel=4)
- data_indices = []
- param_indices = []
- self._cached_op_args = []
- for i, name in enumerate(input_names):
- if name in data_names:
- data_indices.append(i)
- self._cached_op_args.append((True, data_names[name]))
- else:
- param_indices.append(i)
- self._cached_op_args.append((False, params[name]))
- flags = [('data_indices', data_indices), ('param_indices', param_indices)] + \
- self._flags
- self._cached_op = ndarray.CachedOp(out, flags)
+ used_params = {k: params[k] for k in used_param_names}
+ try:
+ param_dict = {k: v.list_data() for k, v in used_params.items()}
+ except DeferredInitializationError:
+ self._deferred_infer_shape(*args)
+ for i in used_params.values():
+ i._finish_deferred_init()
+ param_dict = {k: v.list_data() for k, v in used_params.items()}
+
+ self._cached_op = ndarray.CachedOp(out, self._flags, input_names, param_dict)
def _deferred_infer_shape(self, *args):
try:
@@ -783,19 +771,7 @@ class HybridBlock(Block):
args, fmt = _flatten(args, "input")
assert fmt == self._in_format, "Invalid input format"
- try:
- cargs = [args[i] if is_arg else i.data()
- for is_arg, i in self._cached_op_args]
- except DeferredInitializationError:
- self._deferred_infer_shape(*args)
- cargs = []
- for is_arg, i in self._cached_op_args:
- if is_arg:
- cargs.append(args[i])
- else:
- i._finish_deferred_init()
- cargs.append(i.data())
- out = self._cached_op(*cargs)
+ out = self._cached_op(*args)
if isinstance(out, NDArray):
out = [out]
return _regroup(out, self._out_format)[0]
@@ -816,7 +792,7 @@ class HybridBlock(Block):
def hybridize(self, active=True, **kwargs):
self._active = active
- self._flags = list(kwargs.items())
+ self._flags = kwargs.items()
self._clear_cached_op()
if active and self._forward_hooks or self._forward_pre_hooks:
warnings.warn('"{}" is being hybridized while still having forward hook/pre-hook. '
diff --git a/src/c_api/c_api_ndarray.cc b/src/c_api/c_api_ndarray.cc
index 34bd4b2..9aabe04 100644
--- a/src/c_api/c_api_ndarray.cc
+++ b/src/c_api/c_api_ndarray.cc
@@ -36,7 +36,6 @@
#include "../common/utils.h"
#include "../common/exec_utils.h"
#include "../imperative/imperative_utils.h"
-#include "../imperative/cached_op.h"
using namespace mxnet;
@@ -161,8 +160,12 @@ int MXCreateCachedOp(SymbolHandle handle,
std::vector<std::string> input_names;
input_names.reserve(inputs.size());
for (const auto& i : inputs) input_names.push_back(i->attrs.name);
- *out = new CachedOpPtr(new CachedOp(
- *sym, std::vector<std::pair<std::string, std::string> >()));
+ *out = new std::shared_ptr<Imperative::CachedOp>(
+ new Imperative::CachedOp(
+ *sym,
+ std::vector<std::pair<std::string, std::string> >(),
+ input_names,
+ std::unordered_map<std::string, std::vector<NDArray> >()));
API_END();
}
@@ -170,6 +173,11 @@ int MXCreateCachedOpEx(SymbolHandle handle,
int num_flags,
const char** keys,
const char** vals,
+ int num_args,
+ const char** arg_names,
+ int num_params,
+ const char** param_names,
+ NDArrayHandle* params,
CachedOpHandle *out) {
nnvm::Symbol* sym = static_cast<nnvm::Symbol*>(handle);
@@ -178,7 +186,17 @@ int MXCreateCachedOpEx(SymbolHandle handle,
for (int i = 0; i < num_flags; ++i) {
flags.push_back({keys[i], vals[i]});
}
- *out = new CachedOpPtr(new CachedOp(*sym, flags));
+ std::vector<std::string> args;
+ for (int i = 0; i < num_args; ++i) {
+ args.push_back(arg_names[i]);
+ }
+ std::unordered_map<std::string, std::vector<NDArray> > param_dict;
+ for (int i = 0; i < num_params; ++i) {
+ param_dict[param_names[i]].emplace_back(
+ *reinterpret_cast<NDArray*>(params[i]));
+ }
+ *out = new std::shared_ptr<Imperative::CachedOp>(
+ new Imperative::CachedOp(*sym, flags, args, param_dict));
API_END();
}
diff --git a/src/engine/threaded_engine.cc b/src/engine/threaded_engine.cc
index e70cc19..dc0436e 100644
--- a/src/engine/threaded_engine.cc
+++ b/src/engine/threaded_engine.cc
@@ -278,8 +278,6 @@ void ThreadedEngine::DeleteOperator(OprHandle op) {
}
void ThreadedEngine::Push(OprHandle op, Context exec_ctx, int priority, bool profiling) {
- BulkFlush();
-
ThreadedOpr* threaded_opr = ThreadedOpr::CastFromBase(op);
OprBlock* opr_block = OprBlock::New();
opr_block->opr = threaded_opr;
@@ -325,6 +323,7 @@ void ThreadedEngine::PushAsync(AsyncFn fn, Context exec_ctx,
<< device_count_;
}
#endif
+ BulkFlush();
ThreadedOpr *opr = NewOperator(std::move(fn), const_vars, mutable_vars, prop, opr_name, wait);
opr->temporary = true;
const bool profiling = profiler_->IsProfiling(profiler::Profiler::kImperative);
diff --git a/src/executor/attach_op_execs_pass.cc b/src/executor/attach_op_execs_pass.cc
index 72919d9..697e486 100644
--- a/src/executor/attach_op_execs_pass.cc
+++ b/src/executor/attach_op_execs_pass.cc
@@ -134,10 +134,6 @@ class StatefulComputeExecutor : public StorageFallbackOpExecutor {
return state_.get_var();
}
- OpStatePtr state() const override {
- return state_;
- }
-
explicit StatefulComputeExecutor(const OpStatePtr& state,
const FStatefulCompute& fcompute,
ExecType exec_type,
@@ -146,6 +142,7 @@ class StatefulComputeExecutor : public StorageFallbackOpExecutor {
state_(state), fcompute_(fcompute), exec_type_(exec_type) {}
private:
+ friend Graph AttachOpExecs(Graph g);
OpStatePtr state_;
FStatefulCompute fcompute_;
ExecType exec_type_;
@@ -173,16 +170,13 @@ class StatefulComputeExExecutor : public OpExecutor {
return state_.get_var();
}
- OpStatePtr state() const override {
- return state_;
- }
-
explicit StatefulComputeExExecutor(const OpStatePtr& state,
const FStatefulComputeEx& fcompute,
ExecType exec_type)
: state_(state), fcompute_(fcompute), exec_type_(exec_type) {}
private:
+ friend Graph AttachOpExecs(Graph g);
OpStatePtr state_;
FStatefulComputeEx fcompute_;
ExecType exec_type_;
@@ -247,15 +241,16 @@ class FComputeExExecutor : public OpExecutor {
ExecType exec_type_;
};
-void CreateOpExecs(const Graph& g, OpExecVector* p_ret, size_t i) {
+// pass to attach operator executors
+Graph AttachOpExecs(Graph g) {
using nnvm::DTypeVector;
using nnvm::ShapeVector;
using nnvm::FMutateInputs;
- static auto& fcreate_op_state = nnvm::Op::GetAttr<FCreateOpState>("FCreateOpState");
- static auto& fmutate_inputs = nnvm::Op::GetAttr<FMutateInputs>("FMutateInputs");
- static auto& fexec_type = nnvm::Op::GetAttr<FExecType>("FExecType");
- static auto& is_layer_backward = nnvm::Op::GetAttr<bool>("TIsLayerOpBackward");
+ auto& fcreate_op_state = nnvm::Op::GetAttr<FCreateOpState>("FCreateOpState");
+ auto& fmutate_inputs = nnvm::Op::GetAttr<FMutateInputs>("FMutateInputs");
+ auto& fexec_type = nnvm::Op::GetAttr<FExecType>("FExecType");
+ auto& is_layer_backward = nnvm::Op::GetAttr<bool>("TIsLayerOpBackward");
const auto& vdtype = g.GetAttr<DTypeVector>("dtype");
const auto& vshape = g.GetAttr<ShapeVector>("shape");
@@ -264,88 +259,82 @@ void CreateOpExecs(const Graph& g, OpExecVector* p_ret, size_t i) {
// get the graph
const auto& idx = g.indexed_graph();
- OpExecVector& ret = *p_ret;
+ std::vector<std::shared_ptr<OpExecutor> > ret(idx.num_nodes());
// initialize the nodes
- const auto& inode = idx[i];
- if (inode.source->is_variable()) return;
- const nnvm::Op *op = inode.source->op();
- ExecType exec_type = ExecType::kSync;
- std::vector<uint32_t> mutate_index;
- if (fmutate_inputs.count(op)) {
- mutate_index = fmutate_inputs[op](inode.source->attrs);
- }
- if (fexec_type.count(op)) {
- exec_type = fexec_type[op](inode.source->attrs);
- }
- CHECK(dispatch_modes[i] != DispatchMode::kUndefined);
- if (fcreate_op_state.count(op)) {
- std::vector<TShape> ishape;
- std::vector<int> itype;
- for (const auto& e : inode.inputs) {
- ishape.emplace_back(vshape[idx.entry_id(e)]);
- itype.emplace_back(vdtype[idx.entry_id(e)]);
- }
-
- OpStatePtr state = fcreate_op_state[op](
- inode.source->attrs, vctx[i], ishape, itype);
- FStatefulComputeEx fcompute_ex = common::GetFCompute<FStatefulComputeEx>(
- op, "FStatefulComputeEx", vctx[i]);
- // FStatefulComputeEx is dispatched only when dispatch_mode is DispatchMode::kFComputeEx
- if (fcompute_ex != nullptr && dispatch_modes[i] == DispatchMode::kFComputeEx) {
- ret[i] = std::make_shared<StatefulComputeExExecutor>(state, fcompute_ex, exec_type);
- } else {
- FStatefulCompute fcompute = common::GetFCompute<FStatefulCompute>(
- op, "FStatefulCompute", vctx[i]);
- CHECK(fcompute != nullptr)
- << "One of FStatefulCompute and FStatefulComputeEx must be registered "
- << "for stateful operator " << op->name;
- ret[i] = std::make_shared<StatefulComputeExecutor>(state, fcompute,
- exec_type, mutate_index);
+ for (size_t i = 0; i < idx.num_nodes(); ++i) {
+ const auto& inode = idx[i];
+ if (inode.source->is_variable()) continue;
+ const nnvm::Op *op = inode.source->op();
+ ExecType exec_type = ExecType::kSync;
+ std::vector<uint32_t> mutate_index;
+ if (fmutate_inputs.count(op)) {
+ mutate_index = fmutate_inputs[op](inode.source->attrs);
}
- } else if (is_layer_backward.get(op, false)) {
- CHECK_GE(inode.control_deps.size(), 1);
- uint32_t fwd_id = inode.control_deps[0];
- CHECK(vctx[fwd_id] == vctx[i]);
- CHECK(ret[fwd_id] != nullptr);
- FStatefulComputeEx fcompute_ex = common::GetFCompute<FStatefulComputeEx>(
- op, "FStatefulComputeEx", vctx[i]);
- // FStatefulComputeEx is dispatched only when dispatch_mode is DispatchMode::kFComputeEx
- if (fcompute_ex != nullptr && dispatch_modes[i] == DispatchMode::kFComputeEx) {
- ret[i] = std::make_shared<StatefulComputeExExecutor>(
- ret[fwd_id].get()->state(), fcompute_ex, exec_type);
- } else {
- FStatefulCompute fcompute = common::GetFCompute<FStatefulCompute>(
- op, "FStatefulCompute", vctx[i]);
- CHECK(fcompute != nullptr)
- << "One of FStatefulCompute and FStatefulComputeEx must be registered "
- << "for stateful operator " << op->name;
- ret[i] = std::make_shared<StatefulComputeExecutor>(
- ret[fwd_id].get()->state(), fcompute, exec_type, mutate_index);
+ if (fexec_type.count(op)) {
+ exec_type = fexec_type[op](inode.source->attrs);
}
- } else {
- FCompute fcompute = common::GetFCompute<FCompute>(op, "FCompute", vctx[i]);
- FComputeEx fcomp_ex = common::GetFCompute<FComputeEx>(op, "FComputeEx", vctx[i]);
- if (fcomp_ex != nullptr && dispatch_modes[i] == DispatchMode::kFComputeEx) {
- ret[i] = std::make_shared<FComputeExExecutor>(
- inode.source->attrs, fcomp_ex, exec_type);
- } else if (fcompute != nullptr) {
- ret[i] = std::make_shared<FComputeExecutor>(
- inode.source->attrs, fcompute, exec_type, mutate_index);
+ CHECK(dispatch_modes[i] != DispatchMode::kUndefined);
+ if (fcreate_op_state.count(op)) {
+ std::vector<TShape> ishape;
+ std::vector<int> itype;
+ for (const auto& e : inode.inputs) {
+ ishape.emplace_back(vshape[idx.entry_id(e)]);
+ itype.emplace_back(vdtype[idx.entry_id(e)]);
+ }
+
+ OpStatePtr state = fcreate_op_state[op](
+ inode.source->attrs, vctx[i], ishape, itype);
+ FStatefulComputeEx fcompute_ex = common::GetFCompute<FStatefulComputeEx>(
+ op, "FStatefulComputeEx", vctx[i]);
+ // FStatefulComputeEx is dispatched only when dispatch_mode is DispatchMode::kFComputeEx
+ if (fcompute_ex != nullptr && dispatch_modes[i] == DispatchMode::kFComputeEx) {
+ ret[i] = std::make_shared<StatefulComputeExExecutor>(state, fcompute_ex, exec_type);
+ } else {
+ FStatefulCompute fcompute = common::GetFCompute<FStatefulCompute>(
+ op, "FStatefulCompute", vctx[i]);
+ CHECK(fcompute != nullptr)
+ << "One of FStatefulCompute and FStatefulComputeEx must be registered "
+ << "for stateful operator " << op->name;
+ ret[i] = std::make_shared<StatefulComputeExecutor>(state, fcompute,
+ exec_type, mutate_index);
+ }
+ } else if (is_layer_backward.get(op, false)) {
+ CHECK_GE(inode.control_deps.size(), 1);
+ uint32_t fwd_id = inode.control_deps[0];
+ CHECK(vctx[fwd_id] == vctx[i]);
+ CHECK(ret[fwd_id] != nullptr);
+ FStatefulComputeEx fcompute_ex = common::GetFCompute<FStatefulComputeEx>(
+ op, "FStatefulComputeEx", vctx[i]);
+ // FStatefulComputeEx is dispatched only when dispatch_mode is DispatchMode::kFComputeEx
+ if (fcompute_ex != nullptr && dispatch_modes[i] == DispatchMode::kFComputeEx) {
+ ret[i] = std::make_shared<StatefulComputeExExecutor>(
+ dynamic_cast<StatefulComputeExExecutor*>(ret[fwd_id].get())->state_,
+ fcompute_ex, exec_type);
+ } else {
+ FStatefulCompute fcompute = common::GetFCompute<FStatefulCompute>(
+ op, "FStatefulCompute", vctx[i]);
+ CHECK(fcompute != nullptr)
+ << "One of FStatefulCompute and FStatefulComputeEx must be registered "
+ << "for stateful operator " << op->name;
+ ret[i] = std::make_shared<StatefulComputeExecutor>(
+ dynamic_cast<StatefulComputeExecutor*>(ret[fwd_id].get())->state_,
+ fcompute, exec_type, mutate_index);
+ }
} else {
- LOG(INFO) << "Neither FCompute nor FComputeEx registered " << op->name;
+ FCompute fcompute = common::GetFCompute<FCompute>(op, "FCompute", vctx[i]);
+ FComputeEx fcomp_ex = common::GetFCompute<FComputeEx>(op, "FComputeEx", vctx[i]);
+ if (fcomp_ex != nullptr && dispatch_modes[i] == DispatchMode::kFComputeEx) {
+ ret[i] = std::make_shared<FComputeExExecutor>(
+ inode.source->attrs, fcomp_ex, exec_type);
+ } else if (fcompute != nullptr) {
+ ret[i] = std::make_shared<FComputeExecutor>(
+ inode.source->attrs, fcompute, exec_type, mutate_index);
+ } else {
+ LOG(INFO) << "Neither FCompute nor FComputeEx registered " << op->name;
+ }
}
}
-}
-
-
-// pass to attach operator executors
-Graph AttachOpExecs(Graph g) {
- const auto& idx = g.indexed_graph();
- OpExecVector ret(idx.num_nodes());
- for (size_t i = 0; i < idx.num_nodes(); ++i) {
- CreateOpExecs(g, &ret, i);
- }
g.attrs["op_execs"] = std::make_shared<nnvm::any>(ret);
return g;
}
diff --git a/src/executor/attach_op_resource_pass.cc b/src/executor/attach_op_resource_pass.cc
index 56122cd..6818662 100644
--- a/src/executor/attach_op_resource_pass.cc
+++ b/src/executor/attach_op_resource_pass.cc
@@ -30,15 +30,12 @@
namespace mxnet {
namespace exec {
-void AttachOpResources(
- const Graph& g,
- const OpExecVector& op_execs,
- size_t start_nid,
- size_t end_nid) {
+Graph AttachOpResources(Graph g) {
static auto& fresource =
nnvm::Op::GetAttr<FResourceRequest>("FResourceRequest");
static auto& fresource_ex =
nnvm::Op::GetAttr<FResourceRequestEx>("FResourceRequestEx");
+ auto& op_execs = nnvm::get<OpExecVector>(*g.attrs.at("op_execs"));
const auto& vctx = g.GetAttr<ContextVector>("context");
const auto& vdispatch = g.GetAttr<DispatchModeVector>("dispatch_mode");
const auto& dev_masks = g.GetAttr<DevMaskVector>("dev_mask");
@@ -46,7 +43,7 @@ void AttachOpResources(
// Use global resource pool for each executor for now.
std::map<Context, Resource> cached_temp;
// Resource allocation
- for (uint32_t nid = start_nid; nid < end_nid; ++nid) {
+ for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
const auto& inode = idx[nid];
if (inode.source->is_variable()) continue;
const Context &ctx = vctx[nid];
@@ -87,12 +84,7 @@ void AttachOpResources(
requested.push_back(ResourceManager::Get()->Request(ctx, ResourceRequest::kTempSpace));
}
}
+ return g;
}
-
-void AttachOpResources(const Graph& g) {
- const auto& op_execs = g.GetAttr<OpExecVector>("op_execs");
- AttachOpResources(g, op_execs, 0, g.indexed_graph().num_nodes());
-}
-
} // namespace exec
} // namespace mxnet
diff --git a/src/executor/exec_pass.h b/src/executor/exec_pass.h
index 26a2491..99b1b16 100644
--- a/src/executor/exec_pass.h
+++ b/src/executor/exec_pass.h
@@ -82,10 +82,6 @@ class OpExecutor {
virtual engine::VarHandle var() const {
return nullptr;
}
- /*! \return return operator state */
- virtual OpStatePtr state() const {
- return OpStatePtr();
- }
};
/*!
@@ -107,14 +103,6 @@ using ContextVector = std::vector<Context>;
using DevMaskVector = std::vector<int>;
/*!
- * \brief create OpExecutor for a node in graph
- *
- * \param g input graph
- * \param p_ret OpExecVector for input and output
- * \param i the id of the node
- */
-void CreateOpExecs(const Graph& g, OpExecVector* p_ret, size_t i);
-/*!
* \brief Attach OpExecutor to the graph attributes.
*
* \param g input graph
@@ -127,20 +115,12 @@ Graph AttachOpExecs(Graph g);
* \brief Attach Resource to the OpExecVector of the graph.
*
* \param g input graph need to contain op_exec attribute.
- */
-void AttachOpResources(const Graph& g);
-/*!
- * \brief Attach Resource to the OpExecVector
*
- * \param g input graph
- * \param op_execs OpExecutor vector
- * \param start_nid starting node id
- * \param end_nid end node id
+ * \return graph with new attribute "op_exec" of type OpExecVector
+ * The fields on the OpExecVector are not yet been setup.
*/
-void AttachOpResources(const Graph& g,
- const OpExecVector& op_execs,
- size_t start_nid,
- size_t end_nid);
+Graph AttachOpResources(Graph g);
+
/*!
* \brief Discover chance of inplace addto operators.
* i.e. z = plus(z, source_op), and encourage it to become z += source_op.
diff --git a/src/executor/graph_executor.cc b/src/executor/graph_executor.cc
index 831b5f9..e28867d 100644
--- a/src/executor/graph_executor.cc
+++ b/src/executor/graph_executor.cc
@@ -912,7 +912,7 @@ void GraphExecutor::FinishInitGraph(nnvm::Symbol symbol,
}
g = AttachOpExecs(g);
- AttachOpResources(g);
+ g = AttachOpResources(g);
graph_ = std::move(g);
if (shared_exec != nullptr) {
diff --git a/src/imperative/cached_op.cc b/src/imperative/cached_op.cc
index b40605b..140b5a5 100644
--- a/src/imperative/cached_op.cc
+++ b/src/imperative/cached_op.cc
@@ -19,78 +19,16 @@
#include <unordered_set>
#include <iostream>
#include "./imperative_utils.h"
-#include "./cached_op.h"
-#include "../executor/exec_pass.h"
-#include "../profiler/profiler.h"
-
namespace mxnet {
DMLC_REGISTER_PARAMETER(CachedOpConfig);
-struct CachedOp::GraphInfo {
- nnvm::Graph fwd_graph;
- nnvm::Graph full_graph;
- std::vector<OpReqType> bwd_output_reqs;
- std::vector<uint32_t> bwd_input_eid;
-};
-
-struct CachedOp::DynamicRuntime {
- GraphInfo info;
- std::vector<NDArray> buff;
- std::vector<OpStatePtr> op_states;
-};
-
-struct CachedOp::CachedOpState {
- CachedOpState(const Context& context_,
- const nnvm::Graph& fwd_graph_,
- const nnvm::Graph& full_graph_) {
- context = context_;
- info.fwd_graph = fwd_graph_;
- info.full_graph = full_graph_;
-
- size_t max_nodes = info.full_graph.indexed_graph().num_nodes();
- size_t max_entries = info.full_graph.indexed_graph().num_node_entries();
- info.fwd_graph.attrs["context"] = std::make_shared<dmlc::any>(
- std::vector<Context>(info.fwd_graph.indexed_graph().num_nodes(), context));
- info.full_graph.attrs["context"] = std::make_shared<dmlc::any>(
- std::vector<Context>(max_nodes, context));
-
- buff.resize(max_entries);
- arrays.resize(max_entries);
- array_reqs.resize(max_entries);
- dynamic_entries.resize(max_entries, false);
- op_states.resize(max_nodes);
- execs.resize(max_nodes);
- opr_segs.resize(max_nodes);
- }
-
- std::mutex mutex;
- Context context;
- GraphInfo info;
-
- bool recording = false;
- bool fwd_alloc = false;
- bool bwd_alloc = false;
- bool fwd_exec_init = false;
- bool bwd_exec_init = false;
-
- std::vector<NDArray> buff;
- std::vector<NDArray*> arrays;
- std::vector<OpReqType> array_reqs;
-
- std::vector<OpStatePtr> op_states;
- std::vector<std::shared_ptr<exec::OpExecutor> > execs;
- std::vector<imperative::EngineOprSeg> opr_segs;
-
- std::vector<bool> dynamic_entries;
- std::multimap<size_t, NDArray> fwd_reuse_pool;
- std::multimap<size_t, NDArray> bwd_reuse_pool;
-};
-
-CachedOp::CachedOp(
+Imperative::CachedOp::CachedOp(
const nnvm::Symbol& sym,
- const std::vector<std::pair<std::string, std::string> >& flags) {
+ const std::vector<std::pair<std::string, std::string> >& flags,
+ const std::vector<std::string> arg_names,
+ const std::unordered_map<std::string, std::vector<NDArray> >& params) {
using namespace nnvm;
using namespace imperative;
static const std::vector<const Op*> zero_ops{Op::Get("zeros_like"), Op::Get("_zeros")};
@@ -130,22 +68,34 @@ CachedOp::CachedOp(
fwd_graph_.attrs["forward_ref_count"] =
std::make_shared<dmlc::any>(std::move(ref_count));
- inlining_ = !config_.static_alloc &&
- (idx.num_nodes() - idx.input_nodes().size()) <= config_.inline_limit;
+ inlining_ = (idx.num_nodes() - idx.input_nodes().size()) <= config_.inline_limit;
}
// Set params
{
const auto& idx = fwd_graph_.indexed_graph();
- if (config_.data_indices.ndim() || config_.param_indices.ndim()) {
- CHECK_EQ(config_.data_indices.ndim() + config_.param_indices.ndim(),
- idx.input_nodes().size());
- } else {
- std::vector<uint32_t> tmp;
- for (size_t i = 0; i < idx.input_nodes().size(); ++i) {
- tmp.push_back(i);
+ std::unordered_map<std::string, size_t> arg_name_to_id;
+ for (size_t i = 0; i < idx.input_nodes().size(); ++i) {
+ const auto& name = idx[idx.input_nodes()[i]].source->attrs.name;
+ auto iter = params.find(name);
+ if (iter == params.end()) {
+ arg_name_to_id[name] = i;
+ continue;
+ }
+ fwd_params_idx_.push_back(i);
+ for (const auto& param : iter->second) {
+ params_[param.ctx()].emplace_back(param);
}
- config_.data_indices.assign(tmp.begin(), tmp.end());
+ }
+
+ CHECK_EQ(arg_name_to_id.size(), arg_names.size())
+ << "CachedOp expects " << arg_name_to_id.size()
+ << " inputs, given " << arg_names.size();
+
+ for (const auto& name : arg_names) {
+ auto iter = arg_name_to_id.find(name);
+ CHECK(iter != arg_name_to_id.end()) << "Unexpected input name " << name;
+ fwd_args_idx_.push_back(iter->second);
}
}
@@ -157,14 +107,9 @@ CachedOp::CachedOp(
}
std::vector<NodeEntry> xs;
- const auto& idx = fwd_graph_.indexed_graph();
- for (size_t i = 0; i < idx.input_nodes().size(); ++i) {
- auto nid = idx.input_nodes()[i];
- if (idx.mutable_input_nodes().count(nid)) continue;
- fwd_input_to_grad_output_[i] = xs.size();
- xs.emplace_back(NodeEntry{idx[nid].weak_ref.lock(), 0, 0});
- }
-
+ std::vector<NodePtr> args = sym.ListInputs(Symbol::kReadOnlyArgs);
+ xs.reserve(args.size());
+ for (const auto& i : args) xs.emplace_back(NodeEntry{i, 0, 0});
CHECK_GT(xs.size(), 0)
<< "There are no inputs in computation graph that require gradients.";
@@ -180,7 +125,7 @@ CachedOp::CachedOp(
size_t num_forward_entries = fwd_graph_.indexed_graph().num_node_entries();
full_graph_.outputs = fwd_graph_.outputs;
- bwd_output_reqs_ = std::vector<OpReqType>(grad_graph_.outputs.size(), kWriteTo);
+ curr_grad_req_ = std::vector<bool>(grad_graph_.outputs.size(), true);
for (const auto& i : grad_graph_.outputs) full_graph_.outputs.emplace_back(i);
const auto& idx = full_graph_.indexed_graph();
@@ -224,10 +169,7 @@ CachedOp::CachedOp(
}
}
-CachedOp::~CachedOp() {
-}
-
-std::vector<nnvm::NodeEntry> CachedOp::Gradient(
+std::vector<nnvm::NodeEntry> Imperative::CachedOp::Gradient(
const nnvm::NodePtr& node,
const std::vector<nnvm::NodeEntry>& ograds) {
using namespace nnvm;
@@ -264,15 +206,13 @@ std::vector<nnvm::NodeEntry> CachedOp::Gradient(
return ret;
}
-
-bool CachedOp::SetForwardGraph(
- GraphInfo* info,
- const bool recording,
- const std::vector<NDArray*>& inputs) {
+nnvm::Graph Imperative::CachedOp::GetForwardGraph(
+ const bool recording, const std::vector<NDArray*>& inputs) {
using namespace nnvm;
using namespace imperative;
+ std::lock_guard<std::mutex> lock(mutex_);
CHECK_EQ(inputs.size(), num_inputs());
- nnvm::Graph& g = info->fwd_graph;
+ nnvm::Graph& g = fwd_graph_;
ShapeVector shape_inputs;
DTypeVector dtype_inputs;
@@ -297,22 +237,18 @@ bool CachedOp::SetForwardGraph(
g.attrs.erase("forward_mem_plan");
g.attrs.erase("full_mem_plan");
} else if (g.attrs.count(recording ? "full_mem_plan" : "forward_mem_plan")) {
- return true;
+ return g;
}
const auto& idx = g.indexed_graph();
StorageVector storage(idx.num_node_entries(), exec::kBadStorageID);
+ for (const auto i : idx.input_nodes()) storage[idx.entry_id(i, 0)] = exec::kExternalStorageID;
const auto& stypes = g.GetAttr<StorageTypeVector>("storage_type");
CHECK_EQ(stypes.size(), storage.size());
for (size_t i = 0; i < stypes.size(); i++) {
- if (stypes[i] != kDefaultStorage) storage[i] = exec::kDynamicStorageID;
- }
- for (const auto i : idx.input_nodes()) {
- storage[idx.entry_id(i, 0)] = exec::kExternalStorageID;
- }
- for (size_t i = 0; i < idx.outputs().size(); ++i) {
- storage[idx.entry_id(idx.outputs()[i])] = exec::kExternalStorageID;
+ if (stypes[i] != kDefaultStorage)
+ storage[i] = exec::kDynamicStorageID;
}
auto mem_plan = PlanMemory(
@@ -321,50 +257,51 @@ bool CachedOp::SetForwardGraph(
g.attrs[recording ? "full_mem_plan" : "forward_mem_plan"] =
std::make_shared<dmlc::any>(std::move(mem_plan));
- return false;
+ return g;
}
-bool CachedOp::SetBackwardGraph(
- GraphInfo* info,
+nnvm::Graph Imperative::CachedOp::GetBackwardGraph(
+ const OpStatePtr& op_state,
const std::vector<OpReqType>& reqs,
- const std::vector<NDArray*>& inputs,
- bool detect_inplace_addto) {
+ const std::vector<NDArray*>& inputs) {
using namespace nnvm;
using namespace imperative;
std::lock_guard<std::mutex> lock(mutex_);
- Context default_ctx = inputs[0]->ctx();
- nnvm::Graph& g = info->full_graph;
-
- if (info->bwd_output_reqs != reqs) {
- info->bwd_output_reqs = reqs;
- info->bwd_input_eid.clear();
+ nnvm::Graph& g = full_graph_;
+ auto& state = op_state.get_state<CachedOpState>();
+ bool req_match = true;
+ for (size_t i = 0; i < reqs.size(); ++i) {
+ if (curr_grad_req_[i] != (reqs[i] != kNullOp)) {
+ curr_grad_req_[i] = reqs[i] != kNullOp;
+ req_match = false;
+ }
+ }
+ if (!req_match) {
g = nnvm::Graph();
g.outputs = fwd_graph_.outputs;
for (size_t i = 0; i < grad_graph_.outputs.size(); ++i) {
- if (info->bwd_output_reqs[i] == kNullOp) continue;
- g.outputs.emplace_back(grad_graph_.outputs[i]);
+ if (curr_grad_req_[i]) g.outputs.emplace_back(grad_graph_.outputs[i]);
}
- g.attrs["context"] = std::make_shared<dmlc::any>(
- std::vector<Context>(g.indexed_graph().num_nodes(), default_ctx));
+ bwd_input_eid_.clear();
}
const auto& idx = g.indexed_graph();
- if (info->bwd_input_eid.size() != inputs.size()) {
- info->bwd_input_eid.clear();
+ if (bwd_input_eid_.size() != inputs.size()) {
+ bwd_input_eid_.clear();
for (const auto& i : bwd_ograd_dep_) {
auto eid = idx.entry_id(ograd_entries_[i]);
- info->bwd_input_eid.push_back(eid);
+ bwd_input_eid_.push_back(eid);
}
for (const auto& i : bwd_in_dep_) {
auto eid = idx.entry_id(idx.input_nodes()[i], 0);
- info->bwd_input_eid.push_back(eid);
+ bwd_input_eid_.push_back(eid);
}
for (const auto& i : bwd_out_dep_) {
auto eid = idx.entry_id(idx.outputs()[i]);
- info->bwd_input_eid.push_back(eid);
+ bwd_input_eid_.push_back(eid);
}
- CHECK_EQ(inputs.size(), info->bwd_input_eid.size());
+ CHECK_EQ(inputs.size(), bwd_input_eid_.size());
}
size_t num_forward_nodes = fwd_graph_.indexed_graph().num_nodes();
@@ -375,22 +312,25 @@ bool CachedOp::SetBackwardGraph(
for (size_t i = num_forward_nodes; i < idx.num_nodes(); ++i) {
for (const auto& j : idx[i].inputs) ++ref_count[idx.entry_id(j)];
}
- for (size_t i = 0; i < inputs.size(); ++i) ++ref_count[info->bwd_input_eid[i]];
+ for (size_t i = 0; i < inputs.size(); ++i) ++ref_count[bwd_input_eid_[i]];
for (const auto& i : idx.outputs()) ++ref_count[idx.entry_id(i)];
g.attrs["backward_ref_count"] = std::make_shared<dmlc::any>(std::move(ref_count));
}
- auto shapes = info->fwd_graph.GetAttr<ShapeVector>("shape");
- shapes.resize(idx.num_node_entries(), TShape());
- auto dtypes = info->fwd_graph.GetAttr<DTypeVector>("dtype");
- dtypes.resize(idx.num_node_entries(), -1);
- auto stypes = info->fwd_graph.GetAttr<StorageTypeVector>("storage_type");
- stypes.resize(idx.num_node_entries(), -1);
+ ShapeVector shapes(idx.num_node_entries(), TShape());
+ DTypeVector dtypes(idx.num_node_entries(), -1);
+ StorageTypeVector stypes(idx.num_node_entries(), -1);
+
+ for (size_t i = 0; i < num_forward_entries; ++i) {
+ shapes[i] = state.buff[i].shape();
+ dtypes[i] = state.buff[i].dtype();
+ stypes[i] = state.buff[i].storage_type();
+ }
for (size_t i = 0; i < inputs.size(); ++i) {
- shapes[info->bwd_input_eid[i]] = inputs[i]->shape();
- dtypes[info->bwd_input_eid[i]] = inputs[i]->dtype();
- stypes[info->bwd_input_eid[i]] = inputs[i]->storage_type();
+ shapes[bwd_input_eid_[i]] = inputs[i]->shape();
+ dtypes[bwd_input_eid_[i]] = inputs[i]->dtype();
+ stypes[bwd_input_eid_[i]] = inputs[i]->storage_type();
}
std::pair<uint32_t, uint32_t> node_range, entry_range;
@@ -402,353 +342,79 @@ bool CachedOp::SetBackwardGraph(
node_range, entry_range);
match &= CheckAndInferType(&g, std::move(dtypes), false,
node_range, entry_range);
- exec::DevMaskVector dev_mask(idx.num_nodes(), default_ctx.dev_mask());
+ exec::DevMaskVector dev_mask(idx.num_nodes(), inputs[0]->ctx().dev_mask());
match &= CheckAndInferStorageType(&g, std::move(dev_mask), std::move(stypes),
false, node_range, entry_range);
if (!match) {
g.attrs.erase("backward_mem_plan");
} else if (g.attrs.count("backward_mem_plan")) {
- return true;
+ return g;
}
StorageVector storage(idx.num_node_entries(), exec::kBadStorageID);
- const auto& bwd_stypes = g.GetAttr<StorageTypeVector>("storage_type");
- for (size_t i = 0; i < bwd_stypes.size(); i++) {
- if (bwd_stypes[i] != kDefaultStorage) storage[i] = exec::kDynamicStorageID;
- }
for (size_t i = 0; i < num_forward_entries; ++i) storage[i] = exec::kExternalStorageID;
for (const auto i : idx.input_nodes()) storage[idx.entry_id(i, 0)] = exec::kExternalStorageID;
for (const auto i : idx.outputs()) storage[idx.entry_id(i)] = exec::kExternalStorageID;
+ for (size_t i = 0; i < stypes.size(); i++) {
+ if (stypes[i] != kDefaultStorage)
+ storage[i] = exec::kDynamicStorageID;
+ }
auto mem_plan = PlanMemory(
&g, std::move(storage), g.GetAttr<std::vector<uint32_t> >("backward_ref_count"),
- {num_forward_nodes, idx.num_nodes()},
- {num_forward_entries, idx.num_node_entries()},
- detect_inplace_addto);
+ {num_forward_nodes, idx.num_nodes()}, {num_forward_entries, idx.num_node_entries()});
g.attrs["backward_mem_plan"] = std::make_shared<dmlc::any>(std::move(mem_plan));
- return false;
-}
-
-OpStatePtr CachedOp::GetCachedOpState(
- const Context& ctx) {
- std::lock_guard<std::mutex> lock(mutex_);
- for (const auto& i : cached_op_states_[ctx]) {
- // only create one state per device when not using static memory
- if (!config_.static_alloc || i.unique()) {
- return i;
- }
- }
- auto state_ptr = OpStatePtr::Create<CachedOpState>(ctx, fwd_graph_, full_graph_);
-
- cached_op_states_[ctx].push_back(state_ptr);
- return state_ptr;
-}
-
-void CachedOp::StaticAllocMemory(
- const OpStatePtr& state_ptr,
- bool recording,
- bool keep_fwd) {
- using namespace nnvm;
- using namespace imperative;
-
- auto& state = state_ptr.get_state<CachedOpState>();
- const auto& default_ctx = state.context;
- nnvm::Graph& g = keep_fwd ? state.info.full_graph : state.info.fwd_graph;
- const auto& idx = g.indexed_graph();
- const auto& vstorage_inplace = g.GetAttr<std::vector<int> >("storage_inplace_index");
- const auto& mem_plan = g.GetAttr<MemoryPlanVector>(
- keep_fwd ? "backward_mem_plan" : (recording ? "full_mem_plan" : "forward_mem_plan"));
- std::vector<int> addto_entry;
- if (g.attrs.count("addto_entry")) {
- addto_entry = g.GetAttr<std::vector<int> >("addto_entry");
- }
- size_t start_eid =
- keep_fwd ? state.info.fwd_graph.indexed_graph().num_node_entries() : 0;
- size_t end_eid = idx.num_node_entries();
-
- if (!keep_fwd) state.fwd_alloc = false;
- state.bwd_alloc = false;
- for (size_t i = start_eid; i < state.buff.size(); ++i) {
- state.buff[i] = NDArray();
- state.arrays[i] = &state.buff[i];
- state.array_reqs[i] = kNullOp;
- state.dynamic_entries[i] = false;
- }
-
- for (auto i : idx.input_nodes()) {
- auto eid = idx.entry_id(i, 0);
- if (eid >= start_eid) state.dynamic_entries[eid] = true;
- }
- for (auto i : idx.outputs()) {
- auto eid = idx.entry_id(i);
- if (eid >= start_eid) state.dynamic_entries[eid] = true;
- }
-
- for (size_t i = start_eid; i < end_eid; ++i) {
- if (addto_entry.size() && addto_entry[i]) {
- state.array_reqs[i] = kAddTo;
- } else if (vstorage_inplace[i] >= 0) {
- state.array_reqs[i] = kWriteInplace;
- } else if (vstorage_inplace[i] == -2) {
- // -2 indicate that the entry is never referenced.
- state.array_reqs[i] = kNullOp;
- } else {
- state.array_reqs[i] = kWriteTo;
- }
- }
-
- auto& reuse_pool = keep_fwd ? state.bwd_reuse_pool : state.fwd_reuse_pool;
- reuse_pool = imperative::AllocateMemory(
- g, idx, default_ctx, start_eid, end_eid, mem_plan,
- state.arrays, &state.array_reqs, std::move(reuse_pool));
-
- state.recording = recording;
- if (keep_fwd) {
- state.bwd_alloc = true;
- } else {
- state.fwd_alloc = true;
- }
+ return g;
}
-void CachedOp::StaticInitExec(
- const OpStatePtr& state_ptr,
- bool recording,
- bool keep_fwd) {
+void Imperative::CachedOp::Forward(
+ const std::shared_ptr<CachedOp>& op_ptr,
+ const std::vector<NDArray*>& args,
+ const std::vector<NDArray*>& outputs) {
using namespace nnvm;
using namespace imperative;
+ static const auto cached_op = nnvm::Op::Get("_CachedOp");
- auto& state = state_ptr.get_state<CachedOpState>();
- const auto& default_ctx = state.context;
- nnvm::Graph& g = keep_fwd ? state.info.full_graph : state.info.fwd_graph;
- const auto& idx = g.indexed_graph();
- std::vector<int> skip_plus_node;
- if (g.attrs.count("skip_plus_node")) {
- skip_plus_node = g.GetAttr<std::vector<int> >("skip_plus_node");
- }
- size_t start_nid =
- keep_fwd ? state.info.fwd_graph.indexed_graph().num_nodes() : 0;
- size_t end_nid = idx.num_nodes();
-
- if (!keep_fwd) state.fwd_exec_init = false;
- state.bwd_exec_init = false;
-
- for (size_t i = start_nid; i < state.execs.size(); ++i) {
- state.execs[i].reset();
- state.opr_segs[i] = EngineOprSeg();
- }
-
- if (!config_.static_shape) {
- for (size_t i = start_nid; i < end_nid; ++i) {
- state.opr_segs[i].next_nid = i + 1;
- state.opr_segs[i].skip = skip_plus_node.size() && skip_plus_node[i];
- }
- } else {
- for (size_t i = start_nid; i < end_nid; ++i) {
- exec::CreateOpExecs(g, &state.execs, i);
- }
- exec::AttachOpResources(g, state.execs, start_nid, end_nid);
-
- for (size_t i = start_nid; i < end_nid; ++i) {
- bool skip = idx[i].source->is_variable();
- for (size_t j = 0; !skip && j < idx[i].inputs.size(); ++j) {
- skip = state.dynamic_entries[idx.entry_id(idx[i].inputs[j])];
- }
- for (size_t j = 0; !skip && j < idx[i].source->num_outputs(); ++j) {
- skip = state.dynamic_entries[idx.entry_id(i, j)];
- }
- if (skip) continue;
- SetupOpExec(g, i, state.execs[i], state.arrays, state.array_reqs);
- }
+ CHECK_EQ(args.size(), fwd_args_idx_.size())
+ << "CachedOp requires " << fwd_args_idx_.size()
+ << " inputs but got " << args.size();
- size_t bulk_size = idx.num_nodes();
- std::unordered_set<uint32_t> excludes;
- if (recording || keep_fwd) {
- bulk_size = keep_fwd ? config_.backward_bulk_size : config_.forward_bulk_size;
- for (const auto& i : idx.outputs()) excludes.insert(idx.entry_id(i));
- for (const auto& i : idx.input_nodes()) excludes.insert(idx.entry_id(i, 0));
- }
+ Context default_ctx = args[0]->ctx();
- CreateEngineOpSeg(idx, default_ctx, start_nid, end_nid, bulk_size, excludes,
- state.execs, skip_plus_node, &state.opr_segs);
- }
- if (keep_fwd) {
- state.bwd_exec_init = true;
- } else {
- state.fwd_exec_init = true;
+ std::vector<NDArray*> inputs(num_inputs());
+ for (index_t i = 0; i < fwd_args_idx_.size(); ++i) {
+ inputs[fwd_args_idx_[i]] = args[i];
}
-}
-
-void CachedOp::StaticRunOps(
- const Context& default_ctx,
- const nnvm::Graph& g,
- const OpStatePtr& state_ptr,
- size_t start_nid,
- size_t end_nid) {
- static auto& createop = nnvm::Op::GetAttr<FCreateOpState>("FCreateOpState");
- static auto& is_layer_backward = Op::GetAttr<bool>("TIsLayerOpBackward");
-
- bool profiling = profiler::Profiler::Get()->GetState() == profiler::Profiler::kRunning;
- bool is_training = Imperative::Get()->is_training();
- auto& state = state_ptr.get_state<CachedOpState>();
- const auto& idx = g.indexed_graph();
- const auto& dispatch_modes = g.GetAttr<DispatchModeVector>("dispatch_mode");
- const auto& op_execs = state.execs;
-
- std::vector<NDArray*> ndinputs, ndoutputs;
- nnvm::ShapeVector arg_shapes;
- nnvm::DTypeVector arg_dtypes;
- std::vector<OpReqType> req;
+ if (fwd_params_idx_.size()) {
+ CHECK(params_.find(default_ctx) != params_.end())
+ << "CachedOp is not initialized on context " << default_ctx;
- for (size_t i = start_nid; config_.static_shape && i < end_nid; ++i) {
- if (op_execs[i]) op_execs[i]->op_ctx.is_train = is_training;
- }
-
- for (size_t i = start_nid; i < end_nid; i = state.opr_segs[i].next_nid) {
- const auto& opr_seg = state.opr_segs[i];
- if (opr_seg.skip) continue;
- if (opr_seg.opr != nullptr) {
- Engine::Get()->Push(opr_seg.opr.get(), default_ctx, 0, profiling);
- } else {
- const nnvm::IndexedGraph::Node& node = idx[i];
- if (node.source->is_variable()) continue;
- auto num_outputs = node.source->num_outputs();
- ndinputs.clear();
- ndinputs.reserve(node.inputs.size());
- for (const auto& j : node.inputs) {
- ndinputs.emplace_back(state.arrays[idx.entry_id(j)]);
- CHECK(!ndinputs.back()->is_none());
- }
- ndoutputs.clear();
- ndoutputs.reserve(num_outputs);
- req.clear();
- req.reserve(num_outputs);
- for (size_t j = 0; j < num_outputs; ++j) {
- size_t eid = idx.entry_id(i, j);
- ndoutputs.emplace_back(state.arrays[eid]);
- req.push_back(state.array_reqs[eid]);
- CHECK(req.back() == kNullOp || !ndoutputs.back()->is_none());
- }
- const DispatchMode dispatch_mode = dispatch_modes[i];
- if (createop.count(node.source->op())) {
- arg_shapes.clear();
- arg_dtypes.clear();
- arg_shapes.reserve(ndinputs.size());
- arg_dtypes.reserve(ndinputs.size());
- for (size_t i = 0; i < ndinputs.size(); ++i) {
- arg_shapes.emplace_back(ndinputs[i]->shape());
- arg_dtypes.emplace_back(ndinputs[i]->dtype());
- }
- state.op_states[i] = createop[node.source->op()](
- node.source->attrs, default_ctx, arg_shapes, arg_dtypes);
- Imperative::Get()->InvokeOp(
- default_ctx, node.source->attrs, ndinputs, ndoutputs, req,
- dispatch_mode, state.op_states[i]);
- } else if (is_layer_backward.get(node.source->op(), false)) {
- nnvm::Node* fwd_node = node.source->control_deps[0].get();
- auto fwd_node_id = idx.node_id(fwd_node);
- Imperative::Get()->InvokeOp(
- default_ctx, node.source->attrs, ndinputs, ndoutputs,
- req, dispatch_mode, state.op_states[fwd_node_id]);
- } else {
- Imperative::Get()->InvokeOp(
- default_ctx, node.source->attrs, ndinputs, ndoutputs, req,
- dispatch_mode);
- }
+ for (size_t i = 0; i < fwd_params_idx_.size(); ++i) {
+ inputs[fwd_params_idx_[i]] = &params_[default_ctx][i];
}
}
-}
-
-OpStatePtr CachedOp::StaticForward(
- const Context& default_ctx,
- const std::vector<NDArray*>& inputs,
- const std::vector<NDArray*>& outputs) {
- using namespace nnvm;
- using namespace imperative;
+ // Initialize
bool recording = Imperative::Get()->is_recording();
- auto state_ptr = GetCachedOpState(default_ctx);
- auto& state = state_ptr.get_state<CachedOpState>();
- std::lock_guard<std::mutex> lock(state.mutex);
-
- bool match = SetForwardGraph(&state.info, recording, inputs);
- match = match && state.recording != recording;
-
- nnvm::Graph& g = state.info.fwd_graph;
+ nnvm::Graph g = GetForwardGraph(recording, inputs);
const auto& idx = g.indexed_graph();
- if (!state.fwd_alloc || !match) {
- StaticAllocMemory(state_ptr, recording, false);
- }
-
- if (config_.static_shape) {
- for (auto i : config_.param_indices) {
- auto nid = idx.input_nodes()[i];
- if (!state.arrays[idx.entry_id(nid, 0)]->IsSame(*inputs[i])) {
- match = false;
- auto ptr = &state.buff[idx.entry_id(nid, 0)];
- CHECK_EQ(state.arrays[idx.entry_id(nid, 0)], ptr);
- *state.arrays[idx.entry_id(nid, 0)] = *inputs[i];
- state.dynamic_entries[idx.entry_id(nid, 0)] = false;
- }
- }
- for (auto i : config_.data_indices) {
- auto eid = idx.entry_id(idx.input_nodes()[i], 0);
- state.arrays[eid] = inputs[i];
- }
- } else {
- for (size_t i = 0; i < num_inputs(); ++i) {
- auto nid = idx.input_nodes()[i];
- state.arrays[idx.entry_id(nid, 0)] = inputs[i];
- }
- }
-
- if (!state.fwd_exec_init || !match) {
- StaticInitExec(state_ptr, recording, false);
- }
-
- const auto& dtypes = g.GetAttr<DTypeVector>("dtype");
- const auto& shapes = g.GetAttr<ShapeVector>("shape");
- const auto& stypes = g.GetAttr<StorageTypeVector>("storage_type");
+ size_t num_inputs = idx.input_nodes().size();
- for (size_t i = 0; i < outputs.size(); ++i) {
- auto eid = idx.entry_id(idx.outputs()[i]);
- state.arrays[eid] = outputs[i];
- if (!outputs[i]->is_none()) continue;
- *outputs[i] = NDArray(static_cast<NDArrayStorageType>(stypes[eid]),
- shapes[eid], default_ctx, true, dtypes[eid]);
+ for (size_t i = 0; i < inputs.size(); ++i) {
+ CHECK_EQ(inputs[i]->ctx(), default_ctx)
+ << "CachedOp requires all inputs to live on the same context. But "
+ << idx[idx.input_nodes()[0]].source->attrs.name << " is on " << default_ctx
+ << " while " << idx[idx.input_nodes()[i]].source->attrs.name << " is on "
+ << inputs[i]->ctx();
}
- StaticRunOps(default_ctx, g, state_ptr, 0, idx.num_nodes());
-
- return recording ? state_ptr : OpStatePtr();
-}
-
-
-OpStatePtr CachedOp::DynamicForward(
- const Context& default_ctx,
- const std::vector<NDArray*>& inputs,
- const std::vector<NDArray*>& outputs) {
- using namespace nnvm;
- using namespace imperative;
-
- // Initialize
- bool recording = Imperative::Get()->is_recording();
- auto op_state = OpStatePtr::Create<DynamicRuntime>();
- auto& runtime = op_state.get_state<DynamicRuntime>();
- {
- auto state_ptr = GetCachedOpState(default_ctx);
- auto& state = state_ptr.get_state<CachedOpState>();
- std::lock_guard<std::mutex> lock(state.mutex);
- SetForwardGraph(&state.info, recording, inputs);
- runtime.info.fwd_graph = state.info.fwd_graph;
- }
- nnvm::Graph& g = runtime.info.fwd_graph;
- const auto& idx = g.indexed_graph();
- size_t num_inputs = idx.input_nodes().size();
- auto& buff = runtime.buff;
- auto& states = runtime.op_states;
+ auto op_state_ptr = OpStatePtr::Create<CachedOpState>();
+ auto& cached_op_state = op_state_ptr.get_state<CachedOpState>();
+ auto& buff = cached_op_state.buff;
+ auto& states = cached_op_state.states;
// Allocate entries
states.resize(idx.num_nodes());
@@ -780,98 +446,57 @@ OpStatePtr CachedOp::DynamicForward(
AllocateMemory(g, idx, default_ctx, 0, idx.num_node_entries(),
mem_plan, arrays, &array_reqs);
- const auto& dtypes = g.GetAttr<DTypeVector>("dtype");
- const auto& shapes = g.GetAttr<ShapeVector>("shape");
- const auto& stypes = g.GetAttr<StorageTypeVector>("storage_type");
-
- for (size_t i = 0; i < outputs.size(); ++i) {
- auto eid = idx.entry_id(idx.outputs()[i]);
- arrays[eid] = outputs[i];
- if (!outputs[i]->is_none()) continue;
- *outputs[i] = NDArray(static_cast<NDArrayStorageType>(stypes[eid]),
- shapes[eid], default_ctx, true, dtypes[eid]);
- }
-
const auto& dispatch_modes = g.GetAttr<DispatchModeVector>("dispatch_mode");
if (recording && !inlining_) Imperative::Get()->set_is_recording(false);
+ int prev_bulk_size = Engine::Get()->set_bulk_size(config_.forward_bulk_size);
- RunGraph(false, idx, arrays, 0, idx.num_nodes(), std::move(array_reqs),
- std::move(ref_count), &states, dispatch_modes);
+ Imperative::Get()->RunGraph(
+ false, idx, arrays, 0, idx.num_nodes(), std::move(array_reqs),
+ std::move(ref_count), &states, dispatch_modes);
+ Engine::Get()->set_bulk_size(prev_bulk_size);
Imperative::Get()->set_is_recording(recording);
- return op_state;
-}
-
-void CachedOp::Forward(
- const std::shared_ptr<CachedOp>& op_ptr,
- const std::vector<NDArray*>& inputs,
- const std::vector<NDArray*>& outputs) {
- static const auto cached_op = nnvm::Op::Get("_CachedOp");
-
- CHECK_EQ(inputs.size(), num_inputs());
-
- Context default_ctx = inputs[0]->ctx();
-
- const auto& idx = fwd_graph_.indexed_graph();
- for (size_t i = 0; i < inputs.size(); ++i) {
- CHECK_EQ(inputs[i]->ctx(), default_ctx)
- << "CachedOp requires all inputs to live on the same context. But "
- << idx[idx.input_nodes()[0]].source->attrs.name
- << " is on " << default_ctx << " while "
- << idx[idx.input_nodes()[i]].source->attrs.name
- << " is on " << inputs[i]->ctx();
- }
-
- int prev_bulk_size = Engine::Get()->set_bulk_size(config_.forward_bulk_size);
-
- OpStatePtr op_state;
- if (config_.static_alloc) {
- op_state = StaticForward(default_ctx, inputs, outputs);
- } else {
- op_state = DynamicForward(default_ctx, inputs, outputs);
+ for (size_t i = 0; i < idx.num_node_entries(); ++i) {
+ if (arrays[i] == &buff[i]) continue;
+ buff[i].shape_ = arrays[i]->shape_;
+ buff[i].dtype_ = arrays[i]->dtype_;
+ buff[i].storage_type_ = arrays[i]->storage_type_;
}
- Engine::Get()->set_bulk_size(prev_bulk_size);
-
- if (Imperative::Get()->is_recording() && !inlining_) {
+ if (recording && !inlining_) {
nnvm::NodeAttrs attrs;
attrs.op = cached_op;
attrs.name = "_cachedop";
attrs.parsed = op_ptr;
Imperative::Get()->RecordOp(
- std::move(attrs), inputs, outputs, op_state,
+ std::move(attrs), inputs, outputs, op_state_ptr,
&save_inputs(), &save_outputs());
}
}
-void CachedOp::DynamicBackward(
+void Imperative::CachedOp::Backward(
const bool retain_graph,
- const OpStatePtr& op_state,
+ const OpStatePtr& state,
const std::vector<NDArray*>& inputs,
const std::vector<OpReqType>& reqs,
const std::vector<NDArray*>& outputs) {
using namespace nnvm;
using namespace imperative;
+ CHECK(!Imperative::Get()->is_recording())
+ << "CachedOp does not support higher order gradients. "
+ << "If you want to do backward with create_graph=True please "
+ << "do not use hybridize.";
// Initialize
- Context default_ctx = outputs[0]->ctx();
- auto& runtime = op_state.get_state<DynamicRuntime>();
- {
- auto state_ptr = GetCachedOpState(default_ctx);
- auto& state = state_ptr.get_state<CachedOpState>();
- std::lock_guard<std::mutex> lock(state.mutex);
- state.info.fwd_graph = runtime.info.fwd_graph;
- SetBackwardGraph(&state.info, reqs, inputs);
- runtime.info.full_graph = state.info.full_graph;
- runtime.info.bwd_input_eid = state.info.bwd_input_eid;
- }
- nnvm::Graph& g = runtime.info.full_graph;
+ nnvm::Graph g = GetBackwardGraph(state, reqs, inputs);
const auto& idx = g.indexed_graph();
- auto& buff = runtime.buff;
- auto& states = runtime.op_states;
+
+ auto& cached_op_state = state.get_state<CachedOpState>();
+ auto& buff = cached_op_state.buff;
+ auto& states = cached_op_state.states;
size_t num_forward_outputs = fwd_graph_.outputs.size();
size_t num_forward_nodes = fwd_graph_.indexed_graph().num_nodes();
@@ -881,7 +506,7 @@ void CachedOp::DynamicBackward(
arrays.reserve(buff.size());
for (size_t i = 0; i < buff.size(); ++i) arrays.push_back(&buff[i]);
for (size_t i = 0; i < inputs.size(); ++i) {
- arrays[runtime.info.bwd_input_eid[i]] = inputs[i];
+ arrays[bwd_input_eid_[i]] = inputs[i];
}
for (size_t i = 0, j = num_forward_outputs; i < reqs.size(); ++i) {
if (reqs[i] == kNullOp) continue;
@@ -905,14 +530,20 @@ void CachedOp::DynamicBackward(
if (ref_count[i] == 0) array_reqs[i] = kNullOp;
}
+ Context default_ctx = outputs[0]->ctx();
const auto& mem_plan = g.GetAttr<MemoryPlanVector >("backward_mem_plan");
AllocateMemory(g, idx, default_ctx, num_forward_entries, idx.num_node_entries(),
mem_plan, arrays, &array_reqs);
const auto& dispatch_modes = g.GetAttr<DispatchModeVector>("dispatch_mode");
- RunGraph(retain_graph, idx, arrays, num_forward_nodes, idx.num_nodes(),
- std::move(array_reqs), std::move(ref_count), &states, dispatch_modes);
+ int prev_bulk_size = Engine::Get()->set_bulk_size(config_.backward_bulk_size);
+
+ Imperative::Get()->RunGraph(
+ retain_graph, idx, arrays, num_forward_nodes, idx.num_nodes(),
+ std::move(array_reqs), std::move(ref_count), &states, dispatch_modes);
+
+ Engine::Get()->set_bulk_size(prev_bulk_size);
if (retain_graph) {
buff.resize(num_forward_entries);
@@ -922,99 +553,6 @@ void CachedOp::DynamicBackward(
}
}
-void CachedOp::StaticBackward(
- const bool retain_graph,
- const OpStatePtr& state_ptr,
- const std::vector<NDArray*>& inputs,
- const std::vector<OpReqType>& reqs,
- const std::vector<NDArray*>& outputs) {
- using namespace nnvm;
- using namespace imperative;
-
- Context default_ctx = outputs[0]->ctx();
-
- auto& state = state_ptr.get_state<CachedOpState>();
- std::lock_guard<std::mutex> lock(state.mutex);
-
- bool match = SetBackwardGraph(&state.info, reqs, inputs, true);
-
- nnvm::Graph& g = state.info.full_graph;
- const auto& idx = g.indexed_graph();
- auto num_forward_nodes = state.info.fwd_graph.indexed_graph().num_nodes();
-
- if (!state.bwd_alloc || !match) {
- StaticAllocMemory(state_ptr, true, true);
- }
-
- if (config_.static_shape) {
- for (auto i : config_.param_indices) {
- const auto iter = fwd_input_to_grad_output_.find(i);
- if (iter == fwd_input_to_grad_output_.end()) continue;
- auto entry = grad_graph_.outputs[iter->second];
- if (!idx.exist(entry.node.get())) continue;
- auto eid = idx.entry_id(entry);
- if (!state.arrays[eid]->IsSame(*outputs[iter->second]) ||
- !(state.array_reqs[eid] == reqs[iter->second])) {
- match = false;
- state.array_reqs[eid] = reqs[iter->second];
- *state.arrays[eid] = *outputs[iter->second];
- state.dynamic_entries[eid] = false;
- }
- }
- for (auto i : config_.data_indices) {
- const auto iter = fwd_input_to_grad_output_.find(i);
- if (iter == fwd_input_to_grad_output_.end()) continue;
- auto entry = grad_graph_.outputs[iter->second];
- if (!idx.exist(entry.node.get())) continue;
- auto eid = idx.entry_id(entry);
- state.array_reqs[eid] = reqs[iter->second];
- state.arrays[eid] = outputs[iter->second];
- }
- } else {
- for (size_t i = 0; i < grad_graph_.outputs.size(); ++i) {
- auto entry = grad_graph_.outputs[i];
- if (!idx.exist(entry.node.get())) continue;
- auto eid = idx.entry_id(entry);
- state.array_reqs[eid] = reqs[i];
- state.arrays[eid] = outputs[i];
- }
- }
-
- if (!state.bwd_exec_init || !match) {
- StaticInitExec(state_ptr, true, true);
- }
-
- for (size_t i = 0; i < state.info.bwd_input_eid.size(); ++i) {
- auto eid = state.info.bwd_input_eid[i];
- if (state.dynamic_entries[eid]) state.arrays[eid] = inputs[i];
- }
-
- StaticRunOps(default_ctx, g, state_ptr, num_forward_nodes, idx.num_nodes());
-}
-
-void CachedOp::Backward(
- const bool retain_graph,
- const OpStatePtr& state,
- const std::vector<NDArray*>& inputs,
- const std::vector<OpReqType>& reqs,
- const std::vector<NDArray*>& outputs) {
- using namespace imperative;
- CHECK(!Imperative::Get()->is_recording())
- << "CachedOp does not support higher order gradients. "
- << "If you want to do backward with create_graph=True please "
- << "do not use hybridize.";
-
- int prev_bulk_size = Engine::Get()->set_bulk_size(config_.backward_bulk_size);
-
- if (config_.static_alloc) {
- StaticBackward(retain_graph, state, inputs, reqs, outputs);
- } else {
- DynamicBackward(retain_graph, state, inputs, reqs, outputs);
- }
-
- Engine::Get()->set_bulk_size(prev_bulk_size);
-}
-
NNVM_REGISTER_OP(_CachedOp)
.set_num_inputs([](const NodeAttrs& attrs) {
diff --git a/src/imperative/cached_op.h b/src/imperative/cached_op.h
deleted file mode 100644
index 60a40c5..0000000
--- a/src/imperative/cached_op.h
+++ /dev/null
@@ -1,174 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-#ifndef MXNET_IMPERATIVE_CACHED_OP_H_
-#define MXNET_IMPERATIVE_CACHED_OP_H_
-
-#include <mxnet/imperative.h>
-#include <vector>
-#include <atomic>
-#include <utility>
-#include <string>
-#include <unordered_map>
-
-namespace mxnet {
-/*! \brief CachedOp Parameters */
-struct CachedOpConfig : public dmlc::Parameter<CachedOpConfig> {
- uint32_t inline_limit;
- uint32_t forward_bulk_size;
- uint32_t backward_bulk_size;
- bool static_alloc;
- bool static_shape;
- nnvm::Tuple<uint32_t> data_indices;
- nnvm::Tuple<uint32_t> param_indices;
- DMLC_DECLARE_PARAMETER(CachedOpConfig) {
- DMLC_DECLARE_FIELD(static_alloc)
- .set_default(false)
- .describe("Statically allocate memory to improve speed. "
- "Memory usage may increase.");
- DMLC_DECLARE_FIELD(static_shape)
- .set_default(false)
- .describe("Optimize for invariant input shapes between iterations. "
- "Must also set static_alloc to True. "
- "Change of input shapes is still allowed but slower.");
- DMLC_DECLARE_FIELD(inline_limit)
- .set_default(2)
- .describe("Maximum number of operators that can be inlined.");
- DMLC_DECLARE_FIELD(forward_bulk_size)
- .set_default(dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN", 15))
- .describe("Segment size of bulk execution during forward pass.");
- DMLC_DECLARE_FIELD(backward_bulk_size)
- .set_default(dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN", 15))
- .describe("Segment size of bulk execution during backward pass.");
- DMLC_DECLARE_FIELD(data_indices)
- .set_default(nnvm::Tuple<uint32_t>())
- .describe("Position of argument variables.");
- DMLC_DECLARE_FIELD(param_indices)
- .set_default(nnvm::Tuple<uint32_t>())
- .describe("Position of parameters.");
- }
-};
-
-class CachedOp {
- public:
- CachedOp(
- const nnvm::Symbol& sym,
- const std::vector<std::pair<std::string, std::string> >& flags);
- ~CachedOp();
- uint32_t num_inputs() {
- return fwd_graph_.indexed_graph().input_nodes().size();
- }
- uint32_t num_outputs() {
- return fwd_graph_.outputs.size();
- }
- uint32_t num_backward_inputs() {
- return bwd_ograd_dep_.size() + bwd_in_dep_.size() + bwd_out_dep_.size();
- }
- std::vector<bool>& save_inputs() {
- return save_inputs_;
- }
- std::vector<bool>& save_outputs() {
- return save_outputs_;
- }
- const std::unordered_set<uint32_t>& mutable_input_nodes() {
- return fwd_graph_.indexed_graph().mutable_input_nodes();
- }
- std::vector<nnvm::NodeEntry> Gradient(
- const nnvm::NodePtr& node,
- const std::vector<nnvm::NodeEntry>& ograds);
- void Forward(
- const std::shared_ptr<CachedOp>& op_ptr,
- const std::vector<NDArray*>& inputs,
- const std::vector<NDArray*>& outputs);
- void Backward(
- const bool retain_graph,
- const OpStatePtr& state,
- const std::vector<NDArray*>& inputs,
- const std::vector<OpReqType>& reqs,
- const std::vector<NDArray*>& outputs);
-
- private:
- struct GraphInfo;
- struct DynamicRuntime;
- struct CachedOpState;
-
- OpStatePtr GetCachedOpState(const Context& ctx);
- bool SetForwardGraph(
- GraphInfo* info,
- const bool recording,
- const std::vector<NDArray*>& inputs);
- bool SetBackwardGraph(
- GraphInfo* info,
- const std::vector<OpReqType>& reqs,
- const std::vector<NDArray*>& inputs,
- bool detect_inplace_addto = false);
- OpStatePtr DynamicForward(
- const Context& default_ctx,
- const std::vector<NDArray*>& inputs,
- const std::vector<NDArray*>& outputs);
- void DynamicBackward(
- const bool retain_graph,
- const OpStatePtr& op_state,
- const std::vector<NDArray*>& inputs,
- const std::vector<OpReqType>& reqs,
- const std::vector<NDArray*>& outputs);
- void StaticAllocMemory(
- const OpStatePtr& state_ptr,
- bool recording,
- bool keep_fwd);
- void StaticInitExec(
- const OpStatePtr& state_ptr,
- bool recording,
- bool keep_fwd);
- void StaticRunOps(
- const Context& default_ctx,
- const nnvm::Graph& g,
- const OpStatePtr& state_ptr,
- size_t start_nid,
- size_t end_nid);
- OpStatePtr StaticForward(
- const Context& default_ctx,
- const std::vector<NDArray*>& inputs,
- const std::vector<NDArray*>& outputs);
- void StaticBackward(
- const bool retain_graph,
- const OpStatePtr& state_ptr,
- const std::vector<NDArray*>& inputs,
- const std::vector<OpReqType>& reqs,
- const std::vector<NDArray*>& outputs);
-
- CachedOpConfig config_;
- nnvm::Graph fwd_graph_;
- nnvm::Graph grad_graph_;
- nnvm::Graph full_graph_;
- bool inlining_;
- std::vector<nnvm::NodeEntry> ograd_entries_;
- std::vector<uint32_t> bwd_in_dep_, bwd_out_dep_, bwd_ograd_dep_;
- std::unordered_map<uint32_t, uint32_t> fwd_input_to_grad_output_;
- std::vector<bool> save_inputs_, save_outputs_;
- std::vector<OpReqType> bwd_output_reqs_;
-
- std::mutex mutex_;
- std::unordered_map<Context, std::vector<OpStatePtr> > cached_op_states_;
-};
-
-using CachedOpPtr = std::shared_ptr<CachedOp>;
-
-} // namespace mxnet
-#endif // MXNET_IMPERATIVE_CACHED_OP_H_
diff --git a/src/imperative/imperative.cc b/src/imperative/imperative.cc
index e165425..7caf305 100644
--- a/src/imperative/imperative.cc
+++ b/src/imperative/imperative.cc
@@ -19,7 +19,6 @@
#include <unordered_set>
#include <iostream>
#include "./imperative_utils.h"
-#include "./cached_op.h"
namespace mxnet {
#if DMLC_CXX11_THREAD_LOCAL
@@ -267,6 +266,95 @@ void Imperative::RecordOp(
}
}
+void Imperative::RunGraph(
+ const bool retain_graph,
+ const nnvm::IndexedGraph& idx,
+ const std::vector<NDArray*> arrays,
+ size_t node_start, size_t node_end,
+ std::vector<OpReqType>&& array_reqs,
+ std::vector<uint32_t>&& ref_count,
+ std::vector<OpStatePtr> *p_states,
+ const DispatchModeVector &dispatch_modes) {
+ using namespace nnvm;
+ using namespace imperative;
+ static auto& createop = nnvm::Op::GetAttr<FCreateOpState>("FCreateOpState");
+ static auto& is_layer_backward = Op::GetAttr<bool>("TIsLayerOpBackward");
+ static const auto bwd_cached_op = Op::Get("_backward_CachedOp");
+
+ std::vector<OpStatePtr>& states = *p_states;
+ bool recording = is_recording();
+
+ std::vector<NDArray*> ndinputs, ndoutputs;
+ ShapeVector arg_shapes;
+ DTypeVector arg_dtypes;
+ std::vector<OpReqType> req;
+
+ for (size_t i = node_start; i < node_end; ++i) {
+ const nnvm::IndexedGraph::Node& node = idx[i];
+ if (node.source->op() == nullptr) continue;
+ auto num_outputs = node.source->num_outputs();
+ ndinputs.clear();
+ ndinputs.reserve(node.inputs.size());
+ for (const auto& j : node.inputs) {
+ ndinputs.emplace_back(arrays[idx.entry_id(j)]);
+ CHECK(!ndinputs.back()->is_none()) << idx[j.node_id].source->attrs.name << " " << j.index;
+ }
+ ndoutputs.clear();
+ ndoutputs.reserve(num_outputs);
+ req.clear();
+ req.reserve(num_outputs);
+ for (size_t j = 0; j < num_outputs; ++j) {
+ size_t eid = idx.entry_id(i, j);
+ ndoutputs.emplace_back(arrays[eid]);
+ req.push_back(array_reqs[eid]);
+ CHECK(!ndoutputs.back()->is_none());
+ }
+ const Context& ctx = ndoutputs[0]->ctx();
+ const DispatchMode dispatch_mode = dispatch_modes[i];
+ if (node.source->op() == bwd_cached_op) {
+ const auto& cached_op = dmlc::get<CachedOpPtr>(node.source->attrs.parsed);
+ nnvm::Node* fwd_node = node.source->control_deps[0].get();
+ auto fwd_node_id = idx.node_id(fwd_node);
+ cached_op->Backward(retain_graph, states[fwd_node_id], ndinputs, req, ndoutputs);
+ } else if (createop.count(node.source->op())) {
+ arg_shapes.clear();
+ arg_dtypes.clear();
+ arg_shapes.reserve(ndinputs.size());
+ arg_dtypes.reserve(ndinputs.size());
+ for (size_t i = 0; i < ndinputs.size(); ++i) {
+ arg_shapes.emplace_back(ndinputs[i]->shape());
+ arg_dtypes.emplace_back(ndinputs[i]->dtype());
+ }
+ states[i] = createop[node.source->op()](
+ node.source->attrs, ctx, arg_shapes, arg_dtypes);
+ InvokeOp(ctx, node.source->attrs, ndinputs, ndoutputs, req, dispatch_mode, states[i]);
+ if (recording) RecordOp(NodeAttrs(node.source->attrs), ndinputs, ndoutputs, states[i]);
+ } else if (is_layer_backward.get(node.source->op(), false)) {
+ nnvm::Node* fwd_node = node.source->control_deps[0].get();
+ auto fwd_node_id = idx.node_id(fwd_node);
+ InvokeOp(ctx, node.source->attrs, ndinputs, ndoutputs,
+ req, dispatch_mode, states[fwd_node_id]);
+ if (recording) {
+ RecordOp(NodeAttrs(node.source->attrs), ndinputs, ndoutputs, states[fwd_node_id]);
+ }
+ } else {
+ InvokeOp(ctx, node.source->attrs, ndinputs, ndoutputs, req, dispatch_mode);
+ if (recording) RecordOp(NodeAttrs(node.source->attrs), ndinputs, ndoutputs);
+ }
+
+ for (const auto& j : node.inputs) {
+ size_t eid = idx.entry_id(j);
+ --ref_count[eid];
+ if (ref_count[eid] == 0) arrays[eid]->ptr_.reset();
+ }
+ for (size_t j = 0; j < ndoutputs.size(); ++j) {
+ size_t eid = idx.entry_id(i, j);
+ if (ref_count[eid] == 0) arrays[eid]->ptr_.reset();
+ }
+ }
+}
+
+
std::vector<NDArray*> Imperative::Backward(
const std::vector<NDArray*>& outputs,
const std::vector<NDArray*>& ograds,
diff --git a/src/imperative/imperative_utils.cc b/src/imperative/imperative_utils.cc
deleted file mode 100644
index 464aefc..0000000
--- a/src/imperative/imperative_utils.cc
+++ /dev/null
@@ -1,120 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-#include "./imperative_utils.h"
-#include "./cached_op.h"
-
-namespace mxnet {
-namespace imperative {
-void RunGraph(
- const bool retain_graph,
- const nnvm::IndexedGraph& idx,
- const std::vector<NDArray*> arrays,
- size_t node_start, size_t node_end,
- std::vector<OpReqType>&& array_reqs,
- std::vector<uint32_t>&& ref_count,
- std::vector<OpStatePtr> *p_states,
- const DispatchModeVector &dispatch_modes) {
- using namespace nnvm;
- using namespace imperative;
- static auto& createop = nnvm::Op::GetAttr<FCreateOpState>("FCreateOpState");
- static auto& is_layer_backward = Op::GetAttr<bool>("TIsLayerOpBackward");
- static const auto bwd_cached_op = Op::Get("_backward_CachedOp");
-
- const auto imp = Imperative::Get();
-
- std::vector<OpStatePtr>& states = *p_states;
- bool recording = imp->is_recording();
-
- std::vector<NDArray*> ndinputs, ndoutputs;
- ShapeVector arg_shapes;
- DTypeVector arg_dtypes;
- std::vector<OpReqType> req;
-
- for (size_t i = node_start; i < node_end; ++i) {
- const nnvm::IndexedGraph::Node& node = idx[i];
- if (node.source->op() == nullptr) continue;
- auto num_outputs = node.source->num_outputs();
- ndinputs.clear();
- ndinputs.reserve(node.inputs.size());
- for (const auto& j : node.inputs) {
- ndinputs.emplace_back(arrays[idx.entry_id(j)]);
- CHECK(!ndinputs.back()->is_none()) << idx[j.node_id].source->attrs.name << " " << j.index;
- }
- ndoutputs.clear();
- ndoutputs.reserve(num_outputs);
- req.clear();
- req.reserve(num_outputs);
- for (size_t j = 0; j < num_outputs; ++j) {
- size_t eid = idx.entry_id(i, j);
- ndoutputs.emplace_back(arrays[eid]);
- req.push_back(array_reqs[eid]);
- CHECK(array_reqs[eid] == kNullOp || !ndoutputs.back()->is_none());
- }
- const Context& ctx = ndoutputs[0]->ctx();
- const DispatchMode dispatch_mode = dispatch_modes[i];
- if (node.source->op() == bwd_cached_op) {
- const auto& cached_op = dmlc::get<CachedOpPtr>(node.source->attrs.parsed);
- nnvm::Node* fwd_node = node.source->control_deps[0].get();
- auto fwd_node_id = idx.node_id(fwd_node);
- cached_op->Backward(retain_graph, states[fwd_node_id], ndinputs, req, ndoutputs);
- } else if (createop.count(node.source->op())) {
- arg_shapes.clear();
- arg_dtypes.clear();
- arg_shapes.reserve(ndinputs.size());
- arg_dtypes.reserve(ndinputs.size());
- for (size_t i = 0; i < ndinputs.size(); ++i) {
- arg_shapes.emplace_back(ndinputs[i]->shape());
- arg_dtypes.emplace_back(ndinputs[i]->dtype());
- }
- states[i] = createop[node.source->op()](
- node.source->attrs, ctx, arg_shapes, arg_dtypes);
- imp->InvokeOp(ctx, node.source->attrs, ndinputs, ndoutputs, req, dispatch_mode, states[i]);
- if (recording) {
- imp->RecordOp(NodeAttrs(node.source->attrs), ndinputs, ndoutputs, states[i]);
- }
- } else if (is_layer_backward.get(node.source->op(), false)) {
- nnvm::Node* fwd_node = node.source->control_deps[0].get();
- auto fwd_node_id = idx.node_id(fwd_node);
- imp->InvokeOp(ctx, node.source->attrs, ndinputs, ndoutputs,
- req, dispatch_mode, states[fwd_node_id]);
- if (recording) {
- imp->RecordOp(NodeAttrs(node.source->attrs), ndinputs, ndoutputs, states[fwd_node_id]);
- }
- } else {
- imp->InvokeOp(ctx, node.source->attrs, ndinputs, ndoutputs, req, dispatch_mode);
- if (recording) {
- imp->RecordOp(NodeAttrs(node.source->attrs), ndinputs, ndoutputs);
- }
- }
-
- for (const auto& j : node.inputs) {
- size_t eid = idx.entry_id(j);
- --ref_count[eid];
- if (ref_count[eid] == 0) *arrays[eid] = NDArray();
- }
- for (size_t j = 0; j < ndoutputs.size(); ++j) {
- size_t eid = idx.entry_id(i, j);
- if (ref_count[eid] == 0) *arrays[eid] = NDArray();
- }
- }
-}
-
-} // namespace imperative
-} // namespace mxnet
diff --git a/src/imperative/imperative_utils.h b/src/imperative/imperative_utils.h
index 726531d..06b7e05 100644
--- a/src/imperative/imperative_utils.h
+++ b/src/imperative/imperative_utils.h
@@ -23,7 +23,6 @@
#include <utility>
#include <algorithm>
#include <vector>
-#include <map>
#include <string>
#include "../executor/graph_executor.h"
#include "../executor/exec_pass.h"
@@ -39,24 +38,11 @@ namespace mxnet {
namespace imperative {
struct MemoryPlanInfo {
- int storage_id;
- uint32_t root;
+ uint32_t sid;
size_t size;
bool inplace;
};
-struct EngineOprDeleter {
- void operator()(engine::Opr* handle) {
- Engine::Get()->DeleteOperator(handle);
- }
-};
-
-struct EngineOprSeg {
- bool skip;
- size_t next_nid;
- std::unique_ptr<engine::Opr, EngineOprDeleter> opr;
-};
-
using MemoryPlanVector = std::vector<MemoryPlanInfo>;
inline Context GetContext(const nnvm::NodeAttrs& attrs,
@@ -729,12 +715,10 @@ inline std::vector<Context> PlaceDevice(const nnvm::IndexedGraph& idx) {
inline MemoryPlanVector PlanMemory(
- nnvm::Graph* p_g,
- nnvm::StorageVector&& storage,
+ nnvm::Graph* p_g, nnvm::StorageVector&& storage,
const std::vector<uint32_t>& ref_count,
const std::pair<uint32_t, uint32_t>& node_range = {0, 0},
- const std::pair<uint32_t, uint32_t>& entry_range = {0, 0},
- bool detect_inplace_addto = false) {
+ const std::pair<uint32_t, uint32_t>& entry_range = {0, 0}) {
using namespace nnvm;
nnvm::Graph& g = *p_g;
const auto& idx = g.indexed_graph();
@@ -744,31 +728,31 @@ inline MemoryPlanVector PlanMemory(
g.attrs["ref_count"] = std::make_shared<dmlc::any>(ref_count);
g.attrs["storage"] = std::make_shared<dmlc::any>(std::move(storage));
g = nnvm::ApplyPass(g, "PlanMemory");
- if (detect_inplace_addto) g = exec::DetectInplaceAddTo(g);
const auto& dtypes = g.GetAttr<DTypeVector>("dtype");
const auto& shapes = g.GetAttr<ShapeVector>("shape");
- const auto& storage_inplace = g.GetAttr<std::vector<int> >("storage_inplace_index");
- const auto& storage_ids = g.GetAttr<StorageVector>("storage_id");
+ const auto& stypes = g.GetAttr<StorageTypeVector>("storage_type");
+ auto storage_ids = g.MoveCopyAttr<StorageVector>("storage_id");
+ auto storage_inplace = g.MoveCopyAttr<std::vector<int> >("storage_inplace_index");
uint32_t entry_start = entry_range.first;
uint32_t entry_end =
entry_range.second > entry_start ? entry_range.second : idx.num_node_entries();
MemoryPlanVector mem_plan(idx.num_node_entries());
- std::unordered_map<int, uint32_t> sid_to_root;
+ std::unordered_map<int, uint32_t> sid_to_loc;
for (uint32_t i = entry_start; i < entry_end; ++i) {
+ if (stypes[i] != kDefaultStorage) continue;
if (storage_ids[i] < 0) {
- mem_plan[i] = {storage_ids[i], i, 0, false};
- } else if (!sid_to_root.count(storage_ids[i])) {
+ mem_plan[i] = {i, mshadow::mshadow_sizeof(dtypes[i]) * shapes[i].Size(), false};
+ } else if (!sid_to_loc.count(storage_ids[i])) {
CHECK_LT(storage_inplace[i], 0);
- sid_to_root[storage_ids[i]] = i;
- mem_plan[i] = {storage_ids[i], i,
- mshadow::mshadow_sizeof(dtypes[i]) * shapes[i].Size(),
- false};
+ sid_to_loc[storage_ids[i]] = i;
+ mem_plan[i].sid = i;
+ mem_plan[i].size = mshadow::mshadow_sizeof(dtypes[i]) * shapes[i].Size();
} else {
- uint32_t root = sid_to_root[storage_ids[i]];
- mem_plan[i] = {storage_ids[i], root, 0, storage_inplace[i] >= 0};
- mem_plan[root].size = std::max(mem_plan[root].size,
+ uint32_t loc = sid_to_loc[storage_ids[i]];
+ mem_plan[i] = {loc, 0, storage_inplace[i] >= 0};
+ mem_plan[loc].size = std::max(mem_plan[loc].size,
mshadow::mshadow_sizeof(dtypes[i]) * shapes[i].Size());
}
}
@@ -777,213 +761,39 @@ inline MemoryPlanVector PlanMemory(
}
-inline std::multimap<size_t, NDArray> AllocateMemory(
- const nnvm::Graph& g,
- const nnvm::IndexedGraph& idx,
- const Context& default_ctx,
- const uint32_t entry_start, const uint32_t entry_end,
- const MemoryPlanVector& mem_plan,
- const std::vector<NDArray*>& arrays,
- std::vector<OpReqType> *array_reqs,
- std::multimap<size_t, NDArray>&& pool = std::multimap<size_t, NDArray>()) {
+inline void AllocateMemory(const nnvm::Graph& g,
+ const nnvm::IndexedGraph& idx,
+ const Context& default_ctx,
+ const uint32_t entry_start, const uint32_t entry_end,
+ const MemoryPlanVector& mem_plan,
+ const std::vector<NDArray*>& arrays,
+ std::vector<OpReqType> *array_reqs) {
using namespace nnvm;
const auto& dtypes = g.GetAttr<DTypeVector>("dtype");
const auto& shapes = g.GetAttr<ShapeVector>("shape");
const auto& stypes = g.GetAttr<StorageTypeVector>("storage_type");
- std::multimap<size_t, NDArray> new_pool;
-
for (uint32_t i = entry_start; i < entry_end; ++i) {
- if (mem_plan[i].storage_id == exec::kExternalStorageID) continue;
- CHECK(arrays[i]->is_none());
- if (mem_plan[i].storage_id == exec::kDynamicStorageID) {
- *arrays[i] = NDArray(static_cast<NDArrayStorageType>(stypes[i]),
- shapes[i], default_ctx, true, dtypes[i]);
- continue;
- }
- CHECK_EQ(stypes[i], kDefaultStorage);
- if (mem_plan[i].root == i) {
- CHECK_GT(mem_plan[i].size, 0);
- auto iter = pool.lower_bound(mem_plan[i].size);
- if (iter != pool.end()) {
- *arrays[i] = iter->second.AsArray(shapes[i], dtypes[i]);
- new_pool.insert(*iter);
- pool.erase(iter);
- } else {
+ if (!arrays[i]->is_none()) continue;
+ if (stypes[i] == kDefaultStorage) {
+ if (mem_plan[i].sid == i) {
+ CHECK_GT(mem_plan[i].size, 0);
NDArray buff(TShape({static_cast<nnvm::dim_t>(mem_plan[i].size)}),
default_ctx, true, mshadow::kUint8);
*arrays[i] = buff.AsArray(shapes[i], dtypes[i]);
- new_pool.insert({mem_plan[i].size, buff});
- }
- } else {
- CHECK_GE(mem_plan[mem_plan[i].root].storage_id, 0);
- *arrays[i] = arrays[mem_plan[i].root]->AsArray(shapes[i], dtypes[i]);
- if (mem_plan[i].inplace && array_reqs->at(i) == kWriteTo) {
- array_reqs->at(i) = kWriteInplace;
- }
- }
- }
-
- return new_pool;
-}
-
-inline void SetupOpExec(
- const nnvm::Graph& g,
- size_t nid,
- const std::shared_ptr<exec::OpExecutor>& exec,
- const std::vector<NDArray*> arrays,
- const std::vector<OpReqType> array_reqs) {
- const auto& idx = g.indexed_graph();
- const auto& inode = idx[nid];
- CHECK_EQ(exec->in_array.size(), 0U);
- CHECK_EQ(exec->out_array.size(), 0U);
- for (const auto& e : inode.inputs) {
- CHECK(!arrays[idx.entry_id(e)]->is_none()) << inode.source->attrs.name;
- exec->in_array.push_back(*arrays[idx.entry_id(e)]);
- }
- for (uint32_t index = 0; index < inode.source->num_outputs(); ++index) {
- uint32_t eid = idx.entry_id(nid, index);
- CHECK(!arrays[eid]->is_none()) << inode.source->attrs.name;
- exec->out_array.push_back(*arrays[eid]);
- exec->req.push_back(array_reqs[eid]);
- }
-
- exec->Setup();
-}
-
-inline Engine::OprHandle CreateEngineOp(
- const Context& default_ctx,
- const std::vector<std::shared_ptr<exec::OpExecutor> >& execs) {
- CHECK_GT(execs.size(), 0);
- std::vector<Engine::VarHandle> use_vars, mutate_vars;
-
- for (const auto& exec : execs) {
- CHECK_GT(exec->out_array.size(), 0);
- CHECK(execs.size() == 1 || exec->exec_type() == ExecType::kSync);
-
- // the variables
- for (const auto& nd : exec->in_array) {
- use_vars.push_back(nd.var());
- }
- for (auto& r : exec->op_ctx.requested) {
- mutate_vars.push_back(r.var);
- }
- for (auto& nd : exec->out_array) {
- mutate_vars.push_back(nd.var());
- }
- if (exec->var() != nullptr) {
- mutate_vars.push_back(exec->var());
- }
- }
-
- // dedup vars
- Engine::Get()->DeduplicateVarHandle(&use_vars, &mutate_vars);
- bool is_gpu = default_ctx.dev_mask() == gpu::kDevMask;
- bool is_async = execs.size() > 1 ? false : execs[0]->exec_type() == ExecType::kAsync;
-
- auto exec_fun = [execs, is_async, is_gpu] (
- RunContext ctx, Engine::CallbackOnComplete on_complete) {
- if (is_async) {
- execs[0]->op_ctx.async_on_complete = on_complete;
- }
- for (const auto& exec : execs) exec->Run(ctx, is_gpu);
- // call on complete only if it is async op
- if (!is_async) {
- if (is_gpu) {
- #if MXNET_USE_CUDA
- // Wait GPU kernel to finish.
- ctx.get_stream<gpu>()->Wait();
- #else
- LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR;
- #endif
- }
- on_complete();
- }
- };
-
- return Engine::Get()->NewOperator(
- exec_fun, use_vars, mutate_vars, FnProperty::kNormal);
-}
-
-inline void CreateEngineOpSeg(
- const nnvm::IndexedGraph& idx,
- const Context default_ctx,
- const size_t start_nid,
- const size_t end_nid,
- const size_t bulk_size,
- const std::unordered_set<uint32_t>& excludes,
- const std::vector<std::shared_ptr<exec::OpExecutor> >& execs,
- const std::vector<int> skip_plus_node,
- std::vector<EngineOprSeg> *opr_segs) {
- size_t seg_start = start_nid;
- std::vector<std::shared_ptr<exec::OpExecutor> > seg_execs;
- for (size_t nid = start_nid; nid < end_nid; ++nid) {
- const auto& node = idx[nid];
- if (node.source->is_variable()) continue;
- if (skip_plus_node.size() && skip_plus_node[nid]) continue;
- auto& exec = execs[nid];
- bool is_async = exec->exec_type() != ExecType::kSync;
- bool valid = exec->out_array.size() > 0;
-
- // Stop at async nodes and invalid node (due to input/output is not allocated)
- bool stop = is_async || !valid || seg_execs.size() >= bulk_size;
- for (size_t i = 0; i < node.inputs.size() && !stop; ++i) {
- if (excludes.count(idx.entry_id(node.inputs[i]))) stop = true;
- }
- auto num_outputs = node.source->num_outputs();
- for (size_t i = 0; i < num_outputs && !stop; ++i) {
- if (excludes.count(idx.entry_id(nid, i))) stop = true;
- }
-
- // Create opr segment for previous nodes.
- if (stop && nid > seg_start) {
- auto& seg = (*opr_segs)[seg_start];
- if (seg_execs.size()) {
- seg = EngineOprSeg{false, nid};
- seg.opr.reset(CreateEngineOp(default_ctx, seg_execs));
} else {
- seg = EngineOprSeg{true, nid, nullptr};
+ *arrays[i] = arrays[mem_plan[i].sid]->AsArray(shapes[i], dtypes[i]);
+ if (mem_plan[i].inplace && array_reqs->at(i) == kWriteTo) {
+ array_reqs->at(i) = kWriteInplace;
+ }
}
- seg_start = nid;
- seg_execs.clear();
- }
-
- seg_execs.push_back(exec);
-
- auto& seg = (*opr_segs)[nid];
- if (is_async) {
- seg = EngineOprSeg{false, nid + 1};
- seg.opr.reset(CreateEngineOp(default_ctx, seg_execs));
- seg_execs.clear();
- seg_start = nid + 1;
- } else if (!valid) {
- seg = EngineOprSeg{false, nid + 1, nullptr};
- seg_execs.clear();
- seg_start = nid + 1;
- }
- }
- // The last segment
- if (end_nid > seg_start) {
- auto& seg = (*opr_segs)[seg_start];
- if (seg_execs.size()) {
- seg = EngineOprSeg{false, end_nid};
- seg.opr.reset(CreateEngineOp(default_ctx, seg_execs));
} else {
- seg = EngineOprSeg{true, end_nid, nullptr};
+ *arrays[i] = NDArray(static_cast<NDArrayStorageType>(stypes[i]),
+ shapes[i], default_ctx, true, dtypes[i]);
}
}
}
-
-void RunGraph(const bool retain_graph,
- const nnvm::IndexedGraph& idx,
- const std::vector<NDArray*> arrays,
- size_t node_start, size_t node_end,
- std::vector<OpReqType>&& array_reqs,
- std::vector<uint32_t>&& ref_count,
- std::vector<OpStatePtr> *p_states,
- const DispatchModeVector &dispatch_modes);
-
} // namespace imperative
} // namespace mxnet
diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py
index bb61af1..451fde2 100644
--- a/tests/python/unittest/test_gluon.py
+++ b/tests/python/unittest/test_gluon.py
@@ -22,7 +22,6 @@ from mxnet.test_utils import assert_almost_equal
from mxnet.ndarray.ndarray import _STORAGE_TYPE_STR_TO_ID
from common import setup_module, with_seed, assertRaises, teardown
import numpy as np
-from numpy.testing import assert_array_equal
from nose.tools import raises, assert_raises
from copy import deepcopy
import warnings
@@ -1125,6 +1124,7 @@ def test_hybrid_multi_context():
net.hybridize()
net(mx.nd.zeros((1, 3, 32, 32), ctx=mx.cpu(0))).asnumpy()
+
@with_seed()
def test_zero_grad():
data = mx.nd.random.uniform(shape=(3,3))
@@ -1137,60 +1137,6 @@ def test_zero_grad():
grad = net.collect_params()['test_zero_grad_weight'].grad()
assert_almost_equal(grad.asnumpy(), grad.asnumpy() * 0)
-def check_hybrid_static_memory(**kwargs):
- x = mx.nd.random.uniform(shape=(2, 3, 32, 32))
- x.attach_grad()
-
- net1 = gluon.model_zoo.vision.get_resnet(
- 1, 18, pretrained=True, prefix='net_', ctx=mx.context.current_context())
- net2 = gluon.model_zoo.vision.get_resnet(
- 1, 18, pretrained=True, prefix='net_', ctx=mx.context.current_context())
- net2.hybridize(**kwargs)
- net1(x)
- net2(x)
-
- def test(net, x):
- with mx.autograd.record():
- y = net(x) + net(x)
- y.backward()
-
- grads = {k: v.grad() for k, v in net.collect_params().items() if v.grad_req != 'null'}
-
- return y, grads
-
- y1, grads1 = test(net1, x)
- y2, grads2 = test(net2, x)
-
- assert_almost_equal(y1.asnumpy(), y2.asnumpy(), rtol=1e-3, atol=1e-5)
- for key in grads1:
- assert_almost_equal(grads1[key].asnumpy(), grads2[key].asnumpy(), rtol=1e-3, atol=1e-5)
-
-def test_hybrid_static_memory():
- check_hybrid_static_memory()
- check_hybrid_static_memory(static_alloc=True)
- check_hybrid_static_memory(static_alloc=True, static_shape=True)
-
-def check_hybrid_static_memory_switching(**kwargs):
- net = gluon.model_zoo.vision.get_resnet(
- 1, 18, pretrained=True, ctx=mx.context.current_context())
- net.hybridize(**kwargs)
-
- x = mx.nd.random.uniform(shape=(4, 3, 32, 32))
- net(x)
- with mx.autograd.record():
- y = net(x)
- y.backward()
- x = mx.nd.random.uniform(shape=(2, 3, 32, 32))
- net(x)
- with mx.autograd.record():
- y = net(x)
- y.backward()
- mx.nd.waitall()
-
-def test_hybrid_static_memory_switching():
- check_hybrid_static_memory_switching()
- check_hybrid_static_memory_switching(static_alloc=True)
- check_hybrid_static_memory_switching(static_alloc=True, static_shape=True)
@with_seed()
def test_hook():
--
To stop receiving notification emails like this one, please contact
marcoabreu@apache.org.