You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by ha...@apache.org on 2018/09/23 21:45:56 UTC
[incubator-mxnet] branch master updated: [MXNET-876] make CachedOp a normal operator (#11641)
This is an automated email from the ASF dual-hosted git repository.
haibin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 3caf2ca [MXNET-876] make CachedOp a normal operator (#11641)
3caf2ca is described below
commit 3caf2cacfaae17595713d060061822655d7c5fcd
Author: Da Zheng <zh...@gmail.com>
AuthorDate: Sun Sep 23 14:45:39 2018 -0700
[MXNET-876] make CachedOp a normal operator (#11641)
* extend _CachedOp to be a regular operator.
* use default subgraph infer.
* fix test.
* fix compilation error.
* use default subgraph stuff.
* add comments.
* fix.
* use a more general InferStorage.
* use cachedOp as default subgraph operator.
* remove default subgraph op.
* fix.
* fix.
* rename.
* add comment.
* retrigger
* add comments.
---
src/imperative/cached_op.cc | 259 ++++++++++++++++++---
src/imperative/cached_op.h | 28 ++-
src/operator/operator_common.h | 12 +-
src/operator/subgraph/common.h | 62 +++--
src/operator/subgraph/default_subgraph_op.cc | 112 ---------
src/operator/subgraph/default_subgraph_op.cu | 44 ----
src/operator/subgraph/default_subgraph_property.cc | 11 +-
tests/python/unittest/test_subgraph.py | 149 ++++++++++++
8 files changed, 453 insertions(+), 224 deletions(-)
diff --git a/src/imperative/cached_op.cc b/src/imperative/cached_op.cc
index 0c4c1e6..1f115cd 100644
--- a/src/imperative/cached_op.cc
+++ b/src/imperative/cached_op.cc
@@ -23,6 +23,7 @@
#include "../executor/exec_pass.h"
#include "../profiler/profiler.h"
#include "../operator/operator_common.h"
+#include "../operator/subgraph/common.h"
namespace mxnet {
@@ -874,7 +875,6 @@ OpStatePtr CachedOp::Forward(
return op_state;
}
-
void CachedOp::DynamicBackward(
const bool retain_graph,
const OpStatePtr& op_state,
@@ -1067,34 +1067,153 @@ void CachedOp::Backward(
Engine::Get()->set_bulk_size(prev_bulk_size);
}
-bool CachedOp::ForwardStorageType(const nnvm::NodeAttrs& attrs,
- const int dev_mask,
- DispatchMode* dispatch_mode,
- std::vector<int> *in_attrs,
- std::vector<int> *out_attrs) {
- using namespace imperative;
- nnvm::Graph g(fwd_graph_);
- const auto& idx = g.indexed_graph();
- const auto &outputs = idx.outputs();
+/*
+ * This is the operator state of CachedOp when CachedOp is used in the symbol
+ * executor. This is different from the OpState returned by CachedOp::Forward.
+ * The main reason why we need this OpState is that CachedOp and the symbol executor
+ * maintain OpState differently. The symbol executor generates OpState in advance
+ * while CachedOp generates OpState after Forward is called. We need this data
+ * structure to keep the OpState generated by CachedOp::Forward and pass it to
+ * Backward.
+ */
+struct CachedOpActualState {
+ std::shared_ptr<CachedOp> op;
+ OpStatePtr forward_state;
- // Prepare stypes and contexts based on inputs
- StorageTypeVector storage_type_inputs;
- storage_type_inputs.reserve(in_attrs->size());
- for (size_t i = 0; i < in_attrs->size(); ++i) {
- storage_type_inputs.emplace_back(in_attrs->at(i));
+ explicit CachedOpActualState(std::shared_ptr<CachedOp> op) {
+ this->op = op;
}
- exec::DevMaskVector dev_masks(idx.num_nodes(), dev_mask);
+};
- // Forward graph storage type inference
- CheckAndInferStorageType(&g, std::move(dev_masks), std::move(storage_type_inputs), true);
- // Retrieve result and set outputs
- const auto& inferred_stypes = g.GetAttr<StorageTypeVector>("storage_type");
- for (size_t i = 0; i < out_attrs->size(); i++) {
- const auto eid = idx.entry_id(outputs[i]);
- STORAGE_TYPE_ASSIGN_CHECK(*out_attrs, i, inferred_stypes[eid]);
+/*
+ * This is the forward computation when CachedOp is used as an operator in
+ * a symbol executor.
+ */
+void CachedOpForward(const OpStatePtr& state_ptr,
+ const OpContext& ctx,
+ const std::vector<NDArray>& inputs,
+ const std::vector<OpReqType>& req,
+ const std::vector<NDArray>& outputs) {
+ CachedOpActualState &s = state_ptr.get_state<CachedOpActualState>();
+ std::vector<NDArray> in_bufs = inputs;
+ std::vector<NDArray> out_bufs = outputs;
+ std::vector<NDArray *> in_ptrs(in_bufs.size());
+ std::vector<NDArray *> out_ptrs(out_bufs.size());
+ for (size_t i = 0; i < in_ptrs.size(); i++)
+ in_ptrs[i] = &in_bufs[i];
+ for (size_t i = 0; i < out_ptrs.size(); i++)
+ out_ptrs[i] = &out_bufs[i];
+
+ // Set is_recording correctly for the imperative executor.
+ bool orig_is_record;
+ if (ctx.need_grad)
+ orig_is_record = Imperative::Get()->set_is_recording(true);
+ else
+ orig_is_record = Imperative::Get()->is_recording();
+ // Set is_training correctly for the imperative executor.
+ bool orig_is_train;
+ if (ctx.is_train)
+ orig_is_train = Imperative::Get()->set_is_training(true);
+ else
+ orig_is_train = Imperative::Get()->is_training();
+ s.forward_state = s.op->Forward(nullptr, in_ptrs, out_ptrs);
+ Imperative::Get()->set_is_training(orig_is_train);
+ Imperative::Get()->set_is_recording(orig_is_record);
+ // The arrays in out_ptrs may be replaced by CachedOp.
+ // If they are, we need to copy the data back.
+ for (size_t i = 0; i < out_bufs.size(); i++)
+ if (!out_bufs[i].IsSame(outputs[i]))
+ CopyFromTo(out_bufs[i], outputs[i]);
+}
+
+/*
+ * This is the backward computation when CachedOp is used as an operator in
+ * a symbol executor.
+ */
+void CachedOpBackward(const OpStatePtr& state_ptr,
+ const OpContext& ctx,
+ const std::vector<NDArray>& inputs,
+ const std::vector<OpReqType>& req,
+ const std::vector<NDArray>& outputs) {
+ using namespace nnvm;
+ using namespace imperative;
+ CachedOpActualState &s = state_ptr.get_state<CachedOpActualState>();
+ std::vector<NDArray> in_bufs = inputs;
+ std::vector<NDArray> out_bufs = outputs;
+ std::vector<NDArray *> in_ptrs;
+ std::vector<NDArray *> out_ptrs;
+ CHECK_EQ(s.op->num_backward_inputs(), inputs.size());
+ in_ptrs.reserve(s.op->num_backward_inputs());
+ out_ptrs.reserve(s.op->num_inputs());
+
+ const std::vector<bool> &save_inputs = s.op->save_inputs();
+ const std::vector<bool> &save_outputs = s.op->save_outputs();
+ size_t bwd_in_dep = s.op->num_inputs();
+ size_t bwd_out_dep = s.op->num_outputs();
+ CHECK(s.op->num_backward_inputs() > bwd_in_dep + bwd_out_dep);
+ size_t bwd_ograd_dep = s.op->num_backward_inputs() - bwd_in_dep - bwd_out_dep;
+
+ // Find inputs, outputs and ograds
+ auto ograds_begin = in_bufs.begin();
+ auto ograds_end = in_bufs.begin() + bwd_ograd_dep;
+ auto in_begin = ograds_end;
+ auto in_end = in_begin + bwd_in_dep;
+ auto out_begin = in_end;
+ auto out_end = in_bufs.end();
+
+ for (auto it = ograds_begin; it != ograds_end; it++)
+ in_ptrs.push_back(&(*it));
+
+ CHECK_EQ(save_inputs.size(), in_end - in_begin);
+ CHECK_EQ(s.op->num_outputs(), out_end - out_begin);
+ for (auto it = in_begin; it != in_end; it++) {
+ auto i = it - in_begin;
+ if (save_inputs[i])
+ in_ptrs.push_back(&(*it));
}
- DISPATCH_MODE_ASSIGN_CHECK(dispatch_mode, 0, DispatchMode::kFComputeEx);
- return true;
+ for (auto it = out_begin; it != out_end; it++) {
+ auto i = it - out_begin;
+ if (save_outputs[i])
+ in_ptrs.push_back(&(*it));
+ }
+ CHECK_EQ(in_ptrs.size(), s.op->num_backward_inputs());
+ for (size_t i = 0; i < out_bufs.size(); i++)
+ out_ptrs.push_back(&out_bufs[i]);
+ CHECK_EQ(out_ptrs.size(), s.op->num_backward_outputs());
+ // Set is_training correctly for the imperative executor.
+ bool orig_is_train;
+ if (ctx.is_train)
+ orig_is_train = Imperative::Get()->set_is_training(true);
+ else
+ orig_is_train = Imperative::Get()->is_training();
+ // TODO(zhengda) CachedOp supports recording computation when running
+ // the backward path. This is necessary if we want to support the second-order
+ // differentiation. However, MXNet operator doesn't have an interface to
+ // pass a flag to determine whether to record computation inside an operator.
+ // Let's use false here for now and design a solution when the second-order
+ // differentiation is supported.
+ s.op->Backward(false, s.forward_state, in_ptrs, req, out_ptrs);
+ Imperative::Get()->set_is_training(orig_is_train);
+
+ // Clean up what we recorded.
+ s.forward_state.reset();
+
+ // The arrays in out_ptrs may be replaced by CachedOp.
+ // If they are, we need to copy the data back.
+ // For example, when the inputs and outputs share the same NDArrays,
+ // the outputs will be replaced by inputs.
+ // https://github.com/apache/incubator-mxnet/blob/v1.2.0/src/imperative/cached_op.cc#L385
+ for (size_t i = 0; i < out_bufs.size(); i++)
+ if (!out_bufs[i].IsSame(outputs[i]))
+ CopyFromTo(out_bufs[i], outputs[i]);
+}
+
+OpStatePtr CreateCachedOpState(const NodeAttrs& attrs,
+ Context ctx,
+ const std::vector<TShape>& in_shapes,
+ const std::vector<int>& in_types) {
+ const CachedOpPtr& op = nnvm::get<CachedOpPtr>(attrs.parsed);
+ return OpStatePtr::Create<CachedOpActualState>(op);
}
bool CachedOp::BackwardStorageType(const nnvm::NodeAttrs& attrs,
@@ -1143,6 +1262,32 @@ bool CachedOp::BackwardStorageType(const nnvm::NodeAttrs& attrs,
return true;
}
+void CachedOpParamParser(nnvm::NodeAttrs* attrs) {
+ CachedOpConfig param;
+ try {
+ param.Init(attrs->dict);
+ } catch (const dmlc::ParamError& e) {
+ std::ostringstream os;
+ os << e.what();
+ os << ", in operator " << attrs->op->name << "("
+ << "name=\"" << attrs->name << "\"";
+ for (const auto& k : attrs->dict) {
+ os << ", " << k.first << "=\"" << k.second << "\"";
+ }
+ os << ")";
+ throw dmlc::ParamError(os.str());
+ }
+ if (!param.subgraph.empty()) {
+ nnvm::Graph g = nnvm::pass::LoadJSON(param.subgraph);
+ CHECK(!g.outputs.empty());
+ nnvm::Symbol sym;
+ sym.outputs = g.outputs;
+ std::vector<std::pair<std::string, std::string> > flags;
+ for (auto it = attrs->dict.begin(); it != attrs->dict.end(); it++)
+ flags.emplace_back(it->first, it->second);
+ attrs->parsed = CachedOpPtr(new CachedOp(sym, flags));
+ }
+}
NNVM_REGISTER_OP(_CachedOp)
.set_num_inputs([](const NodeAttrs& attrs) {
@@ -1153,19 +1298,62 @@ NNVM_REGISTER_OP(_CachedOp)
const CachedOpPtr& op = nnvm::get<CachedOpPtr>(attrs.parsed);
return op->num_outputs();
})
-.set_attr<FInferStorageType>("FInferStorageType", [](const nnvm::NodeAttrs& attrs,
- const int dev_mask,
- DispatchMode* dispatch_mode,
- std::vector<int> *in_attrs,
- std::vector<int> *out_attrs) {
- const CachedOpPtr& op = nnvm::get<CachedOpPtr>(attrs.parsed);
- return op->ForwardStorageType(attrs, dev_mask, dispatch_mode, in_attrs, out_attrs);
- })
+.set_attr_parser(CachedOpParamParser)
.set_attr<nnvm::FGradient>("FGradient",
[](const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
const CachedOpPtr& op = nnvm::get<CachedOpPtr>(n->attrs.parsed);
return op->Gradient(n, ograds);
- });
+ })
+.set_attr<nnvm::FListInputNames>("FListInputNames",
+ [](const nnvm::NodeAttrs& attrs) {
+ const CachedOpPtr& op = nnvm::get<CachedOpPtr>(attrs.parsed);
+ return op->ListForwardInputNames();
+ })
+.set_attr<nnvm::FListOutputNames>("FListOutputNames",
+ [](const nnvm::NodeAttrs& attrs) {
+ const CachedOpPtr& op = nnvm::get<CachedOpPtr>(attrs.parsed);
+ return op->ListForwardOutputNames();
+ })
+.set_attr<FCreateOpState>("FCreateOpState", CreateCachedOpState)
+.set_attr<nnvm::FInferShape>("FInferShape",
+ [](const nnvm::NodeAttrs& attrs,
+ std::vector<TShape> *in_shapes,
+ std::vector<TShape> *out_shapes) {
+ const CachedOpPtr& op = nnvm::get<CachedOpPtr>(attrs.parsed);
+ return op::DefaultSubgraphOpShapeHelper(op->GetForwardSym(), in_shapes, out_shapes);
+ })
+.set_attr<nnvm::FInferType>("FInferType",
+ [](const nnvm::NodeAttrs& attrs,
+ std::vector<int> *in_types,
+ std::vector<int> *out_types) {
+ const CachedOpPtr& op = nnvm::get<CachedOpPtr>(attrs.parsed);
+ return op::DefaultSubgraphOpTypeHelper(op->GetForwardSym(), in_types, out_types);
+ })
+.set_attr<FInferStorageType>("FInferStorageType",
+ [](const nnvm::NodeAttrs& attrs,
+ const int dev_mask,
+ DispatchMode* dispatch_mode,
+ std::vector<int>* in_stypes,
+ std::vector<int>* out_stypes) {
+ const CachedOpPtr& op = nnvm::get<CachedOpPtr>(attrs.parsed);
+ return op::DefaultSubgraphOpStorageTypeHelper(op->GetForwardSym(),
+ dev_mask, dispatch_mode,
+ in_stypes, out_stypes);
+ })
+.set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", CachedOpForward)
+.set_attr<FStatefulComputeEx>("FStatefulComputeEx<gpu>", CachedOpForward)
+.set_attr<nnvm::FMutateInputs>("FMutateInputs",
+ [](const nnvm::NodeAttrs& attrs) {
+ const CachedOpPtr& op = nnvm::get<CachedOpPtr>(attrs.parsed);
+ return op::DefaultSubgraphOpMutableInputsHelper(op->GetForwardSym());
+ })
+.set_attr<FResourceRequest>("FResourceRequest",
+ [](const nnvm::NodeAttrs& attrs) {
+ const CachedOpPtr& op = nnvm::get<CachedOpPtr>(attrs.parsed);
+ return op::DefaultSubgraphOpResourceRequestHelper(op->GetForwardSym());
+ })
+.set_attr<FExecType>("FExecType", op::DefaultSubgraphOpExecType)
+.add_argument("data", "NDArray-or-Symbol[]", "input data list");
NNVM_REGISTER_OP(_backward_CachedOp)
.set_num_inputs([](const NodeAttrs& attrs){
@@ -1184,6 +1372,9 @@ NNVM_REGISTER_OP(_backward_CachedOp)
const CachedOpPtr& op = nnvm::get<CachedOpPtr>(attrs.parsed);
return op->BackwardStorageType(attrs, dev_mask, dispatch_mode, in_attrs, out_attrs);
})
+.set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", CachedOpBackward)
+.set_attr<FStatefulComputeEx>("FStatefulComputeEx<gpu>", CachedOpBackward)
+.set_attr<FExecType>("FExecType", op::DefaultSubgraphOpExecType)
.set_attr<bool>("TIsLayerOpBackward", true)
.set_attr<bool>("TIsBackward", true);
diff --git a/src/imperative/cached_op.h b/src/imperative/cached_op.h
index 4f4dfdc..59a793e 100644
--- a/src/imperative/cached_op.h
+++ b/src/imperative/cached_op.h
@@ -37,6 +37,7 @@ struct CachedOpConfig : public dmlc::Parameter<CachedOpConfig> {
bool static_shape;
nnvm::Tuple<uint32_t> data_indices;
nnvm::Tuple<uint32_t> param_indices;
+ std::string subgraph;
DMLC_DECLARE_PARAMETER(CachedOpConfig) {
DMLC_DECLARE_FIELD(static_alloc)
.set_default(false)
@@ -62,6 +63,9 @@ struct CachedOpConfig : public dmlc::Parameter<CachedOpConfig> {
DMLC_DECLARE_FIELD(param_indices)
.set_default(nnvm::Tuple<uint32_t>())
.describe("Position of parameters.");
+ DMLC_DECLARE_FIELD(subgraph)
+ .set_default(std::string(""))
+ .describe("JSON string of a subgraph.");
}
};
@@ -80,6 +84,10 @@ class CachedOp {
uint32_t num_backward_inputs() const {
return bwd_ograd_dep_.size() + bwd_in_dep_.size() + bwd_out_dep_.size();
}
+ uint32_t num_backward_outputs() const {
+ auto &idx = fwd_graph_.indexed_graph();
+ return idx.input_nodes().size() - idx.mutable_input_nodes().size();
+ }
std::vector<bool>& save_inputs() {
return save_inputs_;
}
@@ -102,13 +110,6 @@ class CachedOp {
const std::vector<NDArray*>& inputs,
const std::vector<OpReqType>& reqs,
const std::vector<NDArray*>& outputs);
- // forward storage type inference
- bool ForwardStorageType(
- const nnvm::NodeAttrs& attrs,
- const int dev_mask,
- DispatchMode* dispatch_mode,
- std::vector<int> *in_attrs,
- std::vector<int> *out_attrs);
// backward storage type inference
bool BackwardStorageType(
const nnvm::NodeAttrs& attrs,
@@ -116,6 +117,19 @@ class CachedOp {
DispatchMode* dispatch_mode,
std::vector<int> *in_attrs,
std::vector<int> *out_attrs);
+ std::vector<std::string> ListForwardInputNames() const {
+ nnvm::Symbol sym = GetForwardSym();
+ return sym.ListInputNames(nnvm::Symbol::kAll);
+ }
+ std::vector<std::string> ListForwardOutputNames() const {
+ nnvm::Symbol sym = GetForwardSym();
+ return sym.ListOutputNames();
+ }
+ nnvm::Symbol GetForwardSym() const {
+ nnvm::Symbol sym;
+ sym.outputs = fwd_graph_.outputs;
+ return sym;
+ }
private:
struct GraphInfo;
diff --git a/src/operator/operator_common.h b/src/operator/operator_common.h
index 2911293..6a4c3d0 100644
--- a/src/operator/operator_common.h
+++ b/src/operator/operator_common.h
@@ -221,7 +221,7 @@ inline bool dispatch_mode_assign(DispatchMode *y, const DispatchMode& x) {
*/
#define SHAPE_ASSIGN_CHECK(shape_array, index, shape) \
{ \
- if (!shape_assign(&(shape_array)[index], TShape(shape))) { \
+ if (!::mxnet::op::shape_assign(&(shape_array)[index], TShape(shape))) { \
std::ostringstream os; \
os << "Shape inconsistent, Provided = " << (shape_array)[index] << ','\
<< " inferred shape=" << shape; \
@@ -238,11 +238,11 @@ inline bool dispatch_mode_assign(DispatchMode *y, const DispatchMode& x) {
*/
#define TYPE_ASSIGN_CHECK(type_array, index, type) \
{ \
- if (!type_assign(&(type_array)[index], type)) { \
+ if (!::mxnet::op::type_assign(&(type_array)[index], type)) { \
std::ostringstream os; \
os << "Type inconsistent, Provided = " \
- << type_string((type_array)[index]) << ',' \
- << " inferred type = " << type_string(type); \
+ << ::mxnet::op::type_string((type_array)[index]) << ',' \
+ << " inferred type = " << ::mxnet::op::type_string(type); \
throw ::mxnet::op::InferTypeError(os.str(), index); \
} \
}
@@ -291,8 +291,8 @@ inline bool dispatch_mode_assign(DispatchMode *y, const DispatchMode& x) {
#define UNIFORM_TYPE_CHECK(type, expected, arg) \
{ \
CHECK_EQ(type, expected) << "This layer requires uniform type. " \
- << "Expected '" << type_string(expected) \
- << "' v.s. given '" << type_string(type) \
+ << "Expected '" << ::mxnet::op::type_string(expected) \
+ << "' v.s. given '" << ::mxnet::op::type_string(type) \
<< "' at '" << arg << "'"; \
}
diff --git a/src/operator/subgraph/common.h b/src/operator/subgraph/common.h
index 22058d5..4e1cd66 100644
--- a/src/operator/subgraph/common.h
+++ b/src/operator/subgraph/common.h
@@ -49,11 +49,10 @@ inline std::vector<std::string> DefaultSubgraphOpListOutputs(const nnvm::NodeAtt
return sym.ListOutputNames();
}
-inline bool DefaultSubgraphOpShape(const nnvm::NodeAttrs& attrs,
- std::vector<TShape> *in_shapes,
- std::vector<TShape> *out_shapes) {
+inline bool DefaultSubgraphOpShapeHelper(const nnvm::Symbol& subgraph_sym,
+ std::vector<TShape> *in_shapes,
+ std::vector<TShape> *out_shapes) {
using namespace exec;
- const nnvm::Symbol& subgraph_sym = *attrs.subgraphs[0];
nnvm::Graph g;
g.outputs = subgraph_sym.outputs;
const auto& idx_g = g.indexed_graph();
@@ -94,10 +93,15 @@ inline bool DefaultSubgraphOpShape(const nnvm::NodeAttrs& attrs,
return g.GetAttr<size_t>("shape_num_unknown_nodes") == 0;
}
-inline bool DefaultSubgraphOpType(const nnvm::NodeAttrs& attrs,
- std::vector<int> *in_types,
- std::vector<int> *out_types) {
- const nnvm::Symbol& subgraph_sym = *attrs.subgraphs[0];
+inline bool DefaultSubgraphOpShape(const nnvm::NodeAttrs& attrs,
+ std::vector<TShape> *in_shapes,
+ std::vector<TShape> *out_shapes) {
+ return DefaultSubgraphOpShapeHelper(*attrs.subgraphs[0], in_shapes, out_shapes);
+}
+
+inline bool DefaultSubgraphOpTypeHelper(const nnvm::Symbol& subgraph_sym,
+ std::vector<int> *in_types,
+ std::vector<int> *out_types) {
nnvm::Graph g;
g.outputs = subgraph_sym.outputs;
const auto& idx_g = g.indexed_graph();
@@ -137,12 +141,17 @@ inline bool DefaultSubgraphOpType(const nnvm::NodeAttrs& attrs,
return g.GetAttr<size_t>("dtype_num_unknown_nodes") == 0;
}
-inline bool DefaultSubgraphOpStorageType(const nnvm::NodeAttrs& attrs,
- const int dev_mask,
- DispatchMode* dispatch_mode,
- std::vector<int>* in_stypes,
- std::vector<int>* out_stypes) {
- const nnvm::Symbol& subgraph_sym = *attrs.subgraphs[0];
+inline bool DefaultSubgraphOpType(const nnvm::NodeAttrs& attrs,
+ std::vector<int> *in_types,
+ std::vector<int> *out_types) {
+ return DefaultSubgraphOpTypeHelper(*attrs.subgraphs[0], in_types, out_types);
+}
+
+inline bool DefaultSubgraphOpStorageTypeHelper(const nnvm::Symbol& subgraph_sym,
+ const int dev_mask,
+ DispatchMode* dispatch_mode,
+ std::vector<int>* in_stypes,
+ std::vector<int>* out_stypes) {
nnvm::Graph g;
g.outputs = subgraph_sym.outputs;
const auto& idx_g = g.indexed_graph();
@@ -190,12 +199,21 @@ inline bool DefaultSubgraphOpStorageType(const nnvm::NodeAttrs& attrs,
return g.GetAttr<size_t>("storage_type_num_unknown_nodes") == 0;
}
+inline bool DefaultSubgraphOpStorageType(const nnvm::NodeAttrs& attrs,
+ const int dev_mask,
+ DispatchMode* dispatch_mode,
+ std::vector<int>* in_stypes,
+ std::vector<int>* out_stypes) {
+ return DefaultSubgraphOpStorageTypeHelper(*attrs.subgraphs[0], dev_mask, dispatch_mode,
+ in_stypes, out_stypes);
+}
+
inline ExecType DefaultSubgraphOpExecType(const nnvm::NodeAttrs& attrs) {
return ExecType::kSubgraphExec;
}
-inline std::vector<uint32_t> DefaultSubgraphOpMutableInputs(const nnvm::NodeAttrs& attrs) {
- const nnvm::Symbol& subgraph_sym = *attrs.subgraphs[0];
+inline std::vector<uint32_t> DefaultSubgraphOpMutableInputsHelper(
+ const nnvm::Symbol& subgraph_sym) {
const std::vector<std::string> input_names = subgraph_sym.ListInputNames(nnvm::Symbol::kAll);
const std::vector<std::string> immutable_input_names =
subgraph_sym.ListInputNames(nnvm::Symbol::kReadOnlyArgs);
@@ -217,8 +235,12 @@ inline std::vector<uint32_t> DefaultSubgraphOpMutableInputs(const nnvm::NodeAttr
return ret;
}
-inline std::vector<ResourceRequest> DefaultSubgraphOpResourceRequest(const nnvm::NodeAttrs& attrs) {
- const nnvm::Symbol& subgraph_sym = *attrs.subgraphs[0];
+inline std::vector<uint32_t> DefaultSubgraphOpMutableInputs(const nnvm::NodeAttrs& attrs) {
+ return DefaultSubgraphOpMutableInputsHelper(*attrs.subgraphs[0]);
+}
+
+inline std::vector<ResourceRequest> DefaultSubgraphOpResourceRequestHelper(
+ const nnvm::Symbol& subgraph_sym) {
static auto& fresource = Op::GetAttr<FResourceRequest>("FResourceRequest");
std::set<ResourceRequest::Type> resource_types;
DFSVisit(subgraph_sym.outputs, [&](const nnvm::NodePtr& node) {
@@ -231,6 +253,10 @@ inline std::vector<ResourceRequest> DefaultSubgraphOpResourceRequest(const nnvm:
return std::vector<ResourceRequest>(resource_types.begin(), resource_types.end());
}
+inline std::vector<ResourceRequest> DefaultSubgraphOpResourceRequest(const nnvm::NodeAttrs& attrs) {
+ return DefaultSubgraphOpResourceRequestHelper(*attrs.subgraphs[0]);
+}
+
} // namespace op
} // namespace mxnet
diff --git a/src/operator/subgraph/default_subgraph_op.cc b/src/operator/subgraph/default_subgraph_op.cc
deleted file mode 100644
index d5fb7ee..0000000
--- a/src/operator/subgraph/default_subgraph_op.cc
+++ /dev/null
@@ -1,112 +0,0 @@
-/*
-* Licensed to the Apache Software Foundation (ASF) under one
-* or more contributor license agreements. See the NOTICE file
-* distributed with this work for additional information
-* regarding copyright ownership. The ASF licenses this file
-* to you under the Apache License, Version 2.0 (the
-* "License"); you may not use this file except in compliance
-* with the License. You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing,
-* software distributed under the License is distributed on an
-* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-* KIND, either express or implied. See the License for the
-* specific language governing permissions and limitations
-* under the License.
-*/
-
-#include <mxnet/ndarray.h>
-#include "./common.h"
-#include "../../imperative/imperative_utils.h"
-#include "../../imperative/cached_op.h"
-
-namespace mxnet {
-namespace op {
-
-#define DEBUG_SUBGRAPH 0
-
-class DefaultSubgraphOperator {
- public:
- explicit DefaultSubgraphOperator(const Symbol& sym) : subgraph_sym_(sym) {
- subgraph_exec_.reset(new CachedOp(sym, {{"static_alloc", "true"},
- {"static_shape", "true"}}));
- }
-
- void Forward(const OpContext& ctx,
- const std::vector<NDArray>& inputs,
- const std::vector<OpReqType>& req,
- const std::vector<NDArray>& outputs);
- void Backward(const OpContext& ctx,
- const std::vector<NDArray>& inputs,
- const std::vector<OpReqType>& req,
- const std::vector<NDArray>& outputs) {
- LOG(FATAL) << "Not implemented";
- }
-
- private:
- nnvm::Symbol subgraph_sym_;
- CachedOpPtr subgraph_exec_;
-};
-
-void DefaultSubgraphOperator::Forward(const OpContext& ctx,
- const std::vector<NDArray>& inputs,
- const std::vector<OpReqType>& req,
- const std::vector<NDArray>& outputs) {
- std::vector<NDArray> tmp_inputs = inputs;
- std::vector<NDArray*> input_ptrs;
- input_ptrs.reserve(inputs.size());
- for (auto& nd : tmp_inputs) {
- input_ptrs.push_back(&nd);
- }
- std::vector<NDArray> tmp_outputs = outputs;
- std::vector<NDArray*> output_ptrs;
- for (auto& nd : tmp_outputs) {
- output_ptrs.push_back(&nd);
- }
-#if DEBUG_SUBGRAPH
- for (size_t i = 0; i < inputs.size(); ++i) {
- LOG(INFO) << "inputs[" << i << "].version = " << inputs[i].version();
- }
- for (size_t i = 0; i < outputs.size(); ++i) {
- LOG(INFO) << "outputs[" << i << "].version = " << outputs[i].version();
- }
-#endif
- subgraph_exec_->Forward(subgraph_exec_, input_ptrs, output_ptrs);
-}
-
-OpStatePtr CreateDefaultSubgraphOpState(const NodeAttrs& attrs,
- Context ctx,
- const std::vector<TShape>& in_shapes,
- const std::vector<int>& in_types) {
- return OpStatePtr::Create<DefaultSubgraphOperator>(*attrs.subgraphs[0]);
-}
-
-void DefaultSubgraphOpForward(const OpStatePtr& state_ptr,
- const OpContext& ctx,
- const std::vector<NDArray>& inputs,
- const std::vector<OpReqType>& req,
- const std::vector<NDArray>& outputs) {
- DefaultSubgraphOperator& op = state_ptr.get_state<DefaultSubgraphOperator>();
- op.Forward(ctx, inputs, req, outputs);
-}
-
-NNVM_REGISTER_OP(_default_subgraph_op)
-.describe(R"code(_default_subgraph_op)code" ADD_FILELINE)
-.set_num_inputs(DefaultSubgraphOpNumInputs)
-.set_num_outputs(DefaultSubgraphOpNumOutputs)
-.set_attr<nnvm::FListInputNames>("FListInputNames", DefaultSubgraphOpListInputs)
-.set_attr<nnvm::FListOutputNames>("FListOutputNames", DefaultSubgraphOpListOutputs)
-.set_attr<FCreateOpState>("FCreateOpState", CreateDefaultSubgraphOpState)
-.set_attr<nnvm::FInferShape>("FInferShape", DefaultSubgraphOpShape)
-.set_attr<nnvm::FInferType>("FInferType", DefaultSubgraphOpType)
-.set_attr<FInferStorageType>("FInferStorageType", DefaultSubgraphOpStorageType)
-.set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", DefaultSubgraphOpForward)
-.set_attr<nnvm::FMutateInputs>("FMutateInputs", DefaultSubgraphOpMutableInputs)
-.set_attr<std::string>("key_var_num_args", "num_args")
-.set_attr<FExecType>("FExecType", DefaultSubgraphOpExecType)
-.add_argument("data", "NDArray-or-Symbol[]", "input data list");
-
-} // namespace op
-} // namespace mxnet
diff --git a/src/operator/subgraph/default_subgraph_op.cu b/src/operator/subgraph/default_subgraph_op.cu
deleted file mode 100644
index 008826b..0000000
--- a/src/operator/subgraph/default_subgraph_op.cu
+++ /dev/null
@@ -1,44 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * Copyright (c) 2018 by Contributors
- * \file default_subgraph_op.cu
- * \brief GPU Implementation of subgraph operations
- */
-
-#include <mxnet/ndarray.h>
-#include "./common.h"
-#include "../../imperative/imperative_utils.h"
-#include "../../imperative/cached_op.h"
-
-namespace mxnet {
-namespace op {
-
-void DefaultSubgraphOpForward(const OpStatePtr& state_ptr,
- const OpContext& ctx,
- const std::vector<NDArray>& inputs,
- const std::vector<OpReqType>& req,
- const std::vector<NDArray>& outputs);
-
-NNVM_REGISTER_OP(_default_subgraph_op)
-.set_attr<FStatefulComputeEx>("FStatefulComputeEx<gpu>", DefaultSubgraphOpForward);
-
-} // namespace op
-} // namespace mxnet
diff --git a/src/operator/subgraph/default_subgraph_property.cc b/src/operator/subgraph/default_subgraph_property.cc
index c8d3e9f..0152344 100644
--- a/src/operator/subgraph/default_subgraph_property.cc
+++ b/src/operator/subgraph/default_subgraph_property.cc
@@ -21,6 +21,7 @@
#include <string>
#include "./common.h"
#include "./subgraph_property.h"
+#include "../../imperative/cached_op.h"
namespace mxnet {
namespace op {
@@ -51,7 +52,7 @@ class ContainOpSelector: public SubgraphSelector {
/*
* This subgraph property finds a subgraph whose nodes have only operators
- * within a set. The operators in the subgraph will be executed by _default_subgraph_op.
+ * within a set. The operators in the subgraph will be executed by _CachedOp.
*/
class DefaultSubgraphProperty: public SubgraphProperty {
public:
@@ -59,9 +60,13 @@ class DefaultSubgraphProperty: public SubgraphProperty {
virtual nnvm::NodePtr CreateSubgraphNode(const nnvm::Symbol &sym,
const int subgraph_id = 0) const {
nnvm::NodePtr n = nnvm::Node::Create();
- n->attrs.op = Op::Get("_default_subgraph_op");
- n->attrs.name = "_default_subgraph_op" + std::to_string(subgraph_id);
+ n->attrs.op = Op::Get("_CachedOp");
+ n->attrs.name = "_CachedOp" + std::to_string(subgraph_id);
n->attrs.subgraphs.push_back(std::make_shared<nnvm::Symbol>(sym));
+
+ std::vector<std::pair<std::string, std::string> > flags{{"static_alloc", "true"}};
+ n->attrs.parsed = CachedOpPtr(new CachedOp(sym, flags));
+
return n;
}
virtual SubgraphSelectorPtr CreateSubgraphSelector() const {
diff --git a/tests/python/unittest/test_subgraph.py b/tests/python/unittest/test_subgraph.py
new file mode 100644
index 0000000..b5577d4
--- /dev/null
+++ b/tests/python/unittest/test_subgraph.py
@@ -0,0 +1,149 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# pylint: skip-file
+from __future__ import print_function
+import numpy as np
+import mxnet as mx
+import copy
+import math
+import ctypes
+import random
+import itertools
+from numpy.testing import assert_allclose, assert_array_equal
+from mxnet.test_utils import *
+from mxnet.base import py_str, MXNetError, _as_list, SymbolHandle, check_call, _LIB, c_handle_array, mx_uint
+from common import setup_module, with_seed, teardown
+import unittest
+from mxnet.gluon.model_zoo.vision import get_model
+
+def make_subgraph(subg, *args):
+ js = subg.tojson()
+ return mx.sym._internal._CachedOp(*args, subgraph=js)
+
+@with_seed()
+def test_make_subgraph():
+ def make_subgraph1(stype):
+ a = mx.symbol.Variable(name='a', stype=stype)
+ b = mx.symbol.Variable(name='b', stype=stype)
+ c = a * b
+ d = c * 2
+
+ a1 = mx.symbol.Variable(name='a', stype=stype)
+ b1 = mx.symbol.Variable(name='b', stype=stype)
+ y = make_subgraph(c, a1, b1)
+ y = y * 2
+
+ s = (10, 10)
+ a_arr = mx.nd.array(np.random.normal(-0.1, 0.1, size=s),
+ ctx=default_context()).tostype(stype)
+ b_arr = mx.nd.array(np.random.normal(-0.1, 0.1, size=s),
+ ctx=default_context()).tostype(stype)
+ return (d, y, {'a': a_arr, 'b': b_arr}, {})
+
+ def create_weights(shapes, names):
+ nd_dict = {}
+ sym_dict = {}
+ assert len(shapes) == len(names)
+ for i in range(len(shapes)):
+ sym_dict[names[i]] = mx.symbol.Variable(names[i])
+ nd_dict[names[i]] = mx.nd.array(np.ones(shapes[i]), ctx=default_context())
+ return (nd_dict, sym_dict)
+
+ def make_subgraph_weight(orig, shape, stype):
+ arg_shapes, out_shapes, aux_shapes = orig.infer_shape(data=shape)
+ weight_shapes = arg_shapes[1:]
+ weight_names = orig.list_arguments()[1:]
+ weight_dict, weight_sym_dict = create_weights(weight_shapes, weight_names)
+ aux_dict, aux_sym_dict = create_weights(aux_shapes, orig.list_auxiliary_states())
+
+ input_dict = copy.deepcopy(weight_sym_dict)
+ input_dict.update(aux_sym_dict)
+ input_dict['data'] = mx.symbol.Variable('data', stype=stype)
+ input_list = []
+ for name in orig.list_inputs():
+ assert name in input_dict.keys()
+ input_list.append(input_dict[name])
+ subg = make_subgraph(orig, *input_list)
+
+ arr = mx.nd.random.uniform(-1, 1, shape=shape, ctx=default_context()).tostype(stype)
+ arg_dict = weight_dict
+ arg_dict['data'] = arr
+ return (orig, subg, arg_dict, aux_dict)
+
+ def make_subgraph2(stype, out_mean_var):
+ data = mx.symbol.Variable('data', stype=stype)
+ orig = mx.symbol.BatchNorm(data, fix_gamma=False,
+ output_mean_var=out_mean_var, name="batchnorm")
+ s = (10, 10)
+ return make_subgraph_weight(orig, s, stype)
+
+ def make_subgraph3(stype):
+ data = mx.symbol.Variable('data', stype=stype)
+ conv1 = mx.symbol.Convolution(data=data, kernel=(3, 3), num_filter=16, no_bias=True)
+ bn1 = mx.symbol.BatchNorm(conv1, fix_gamma=False, output_mean_var=False)
+ conv2 = mx.symbol.Convolution(data=data, kernel=(3, 3), num_filter=16, no_bias=True)
+ bn2 = mx.symbol.BatchNorm(conv2, fix_gamma=False, output_mean_var=False)
+ orig = bn1 + bn2
+ s = (1, 3, 32, 32)
+ return make_subgraph_weight(orig, s, stype)
+
+ def make_subgraph4(stype):
+ model = get_model('resnet18_v1')
+ model.hybridize()
+ model.initialize()
+ s = (1, 3, 32, 32)
+ data = mx.nd.random.normal(shape=s)
+ out = model(data)
+ model.export('resnet18')
+ orig = mx.sym.load('resnet18-symbol.json')
+ return make_subgraph_weight(orig, s, stype)
+
+ make_subgraphs = [make_subgraph1,
+ lambda stype: make_subgraph2(stype, False),
+ lambda stype: make_subgraph2(stype, True),
+ make_subgraph3, make_subgraph4]
+ stypes = ['default', 'row_sparse']
+ for make_subg in make_subgraphs:
+ for stype in stypes:
+ orig, subg, inputs, aux_states = make_subg(stype)
+ all_inputs = copy.deepcopy(inputs)
+ all_inputs.update(aux_states)
+ args_grad = {key : mx.nd.empty(shape=all_inputs[key].shape) for key in all_inputs.keys()}
+ e1 = orig.bind(ctx=default_context(), args=all_inputs, args_grad=args_grad,
+ aux_states=all_inputs)
+ args_grad = {key : mx.nd.empty(shape=all_inputs[key].shape) for key in all_inputs.keys()}
+ e2 = subg.bind(ctx=default_context(), args=all_inputs, args_grad=args_grad,
+ aux_states=all_inputs)
+ e1.forward(is_train=True)
+ e2.forward(is_train=True)
+ for i in range(len(e1.outputs)):
+ assert_almost_equal(e1.outputs[i].asnumpy(), e2.outputs[i].asnumpy(),
+ rtol=0.001, atol=0.0001)
+
+ out_grads = [mx.nd.random.uniform(-1, 1, shape=out.shape, ctx=default_context())
+ for out in e1.outputs]
+ e1.backward(out_grads)
+ e2.backward(out_grads)
+ for i in range(len(e1.grad_arrays)):
+ assert_almost_equal(e1.grad_arrays[i].asnumpy(), e2.grad_arrays[i].asnumpy(),
+ rtol=0.001, atol=0.0001)
+
+
+if __name__ == '__main__':
+ import nose
+ nose.runmodule()