You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by zh...@apache.org on 2019/02/26 04:35:20 UTC
[incubator-mxnet] branch master updated: [MXNET-1325] Make
InferShapeAttr a standalone pass (#14193)
This is an automated email from the ASF dual-hosted git repository.
zhasheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 0af40f7 [MXNET-1325] Make InferShapeAttr a standalone pass (#14193)
0af40f7 is described below
commit 0af40f7afe44464a7fbd4ac092c4377b69c56918
Author: Junru Shao <ju...@gmail.com>
AuthorDate: Mon Feb 25 20:34:52 2019 -0800
[MXNET-1325] Make InferShapeAttr a standalone pass (#14193)
* Make InferShapeAttr a standalone pass
* Fix
* Fix
* Fix
---
src/executor/infer_graph_attr_pass.cc | 268 +++++++++++++++++++++++++++++++++-
1 file changed, 267 insertions(+), 1 deletion(-)
diff --git a/src/executor/infer_graph_attr_pass.cc b/src/executor/infer_graph_attr_pass.cc
index c14482a..e4dd3f6 100644
--- a/src/executor/infer_graph_attr_pass.cc
+++ b/src/executor/infer_graph_attr_pass.cc
@@ -322,6 +322,272 @@ nnvm::Graph InferAttr(nnvm::Graph &&ret,
return ret;
}
+// Shape-specialized variant of InferAttr (AttrType fixed to nnvm::TShape).
+// Iterates forward/backward inference passes over the node range until the
+// count of unknown shapes (and, optionally, dispatch modes) stops decreasing,
+// then writes the resulting shape vector back into the graph attributes.
+// Unlike the generic pass, it tracks per-entry dynamic shapes (`is_dynamic`)
+// so nodes fed by dynamic-shape outputs are skipped instead of erroring.
+template<typename IsNone, typename FDefault>
+nnvm::Graph InferShapeAttr(nnvm::Graph &&ret,
+ const nnvm::TShape empty_val,
+ const char* infer_name,
+ const char* input_name,
+ const char* attr_key_name,
+ const char* attr_name,
+ const char* unknown_name,
+ IsNone fis_none,
+ FDefault fdefault,
+ bool bwd_identity_assign,
+ const char* dispatch_mode_name,
+ const DispatchMode default_mode_val = DispatchMode::kUndefined) {
+ using nnvm::IndexedGraph;
+ using nnvm::Op;
+ using AttrType = nnvm::TShape;
+ using FInferType = nnvm::FInferShape;
+ using AttrVector = std::vector<AttrType>;
+ using NodeAttrVector = std::vector<DispatchMode>;
+ using dmlc::any;
+ const IndexedGraph& idx = ret.indexed_graph();
+ static auto& finfer_shape =
+ Op::GetAttr<FInferType>(infer_name);
+ static auto& is_backward =
+ Op::GetAttr<nnvm::TIsBackward>("TIsBackward");
+ // gradient function, used to get node correspondence.
+ static auto& fgrad =
+ Op::GetAttr<nnvm::FGradient>("FGradient");
+ // per-entry result shape vector; seeded from a previous pass if present
+ AttrVector rshape;
+ // dispatch mode vector
+ DispatchModeVector dispatch_modes;
+ if (ret.attrs.count(attr_name) != 0) {
+ rshape = ret.MoveCopyAttr<AttrVector>(attr_name);
+ } else {
+ rshape.resize(idx.num_node_entries(), empty_val);
+ }
+
+ // seed shapes of graph inputs from the caller-provided argument shapes
+ if (ret.attrs.count(input_name) != 0) {
+ const AttrVector& shape_args = ret.GetAttr<AttrVector>(input_name);
+ CHECK_LE(shape_args.size(), idx.input_nodes().size())
+ << "More provided " << attr_name << "s than number of arguments.";
+ for (size_t i = 0; i < shape_args.size(); ++i) {
+ rshape[idx.entry_id(idx.input_nodes()[i], 0)] = shape_args[i];
+ }
+ }
+
+ // get the shape hints
+ std::string shape_hints_key = std::string(attr_name) + "_hints";
+ if (ret.attrs.count(shape_hints_key)) {
+ nnvm::NodeEntryMap<AttrType> shape_hints =
+ ret.GetAttr<nnvm::NodeEntryMap<AttrType>>(shape_hints_key);
+ for (const auto& kv : shape_hints) {
+ nnvm::NodeEntry e = kv.first;
+ if (idx.exist(e.node.get())) {
+ rshape[idx.entry_id(kv.first)] = kv.second;
+ }
+ }
+ }
+
+ std::string shape_attr_key;
+ if (ret.attrs.count(attr_key_name) != 0) {
+ shape_attr_key = ret.GetAttr<std::string>(attr_key_name);
+ // erase the provided arguments
+ ret.attrs.erase(attr_key_name);
+ }
+
+ // limit inference to part of the graph
+ uint32_t node_start = 0, node_end = idx.num_nodes();
+ if (ret.attrs.count("node_range")) {
+ const auto& range = ret.GetAttr<std::pair<uint32_t, uint32_t> >("node_range");
+ node_start = range.first;
+ node_end = range.second;
+ CHECK_GE(node_start, 0);
+ CHECK_LE(node_end, idx.num_nodes());
+ ret.attrs.erase("node_range");
+ }
+ uint32_t entry_start = 0, entry_end = idx.num_node_entries();
+ if (ret.attrs.count("entry_range")) {
+ const auto& range = ret.GetAttr<std::pair<uint32_t, uint32_t> >("entry_range");
+ entry_start = range.first;
+ entry_end = range.second;
+ CHECK_GE(entry_start, 0);
+ CHECK_LE(entry_end, idx.num_node_entries());
+ ret.attrs.erase("entry_range");
+ }
+ // populate the node attribute vector
+ if (dispatch_mode_name != nullptr) {
+ if (ret.attrs.count(dispatch_mode_name) != 0) {
+ dispatch_modes = ret.MoveCopyAttr<NodeAttrVector>(dispatch_mode_name);
+ } else {
+ LOG(FATAL) << "Node attribute " << dispatch_mode_name << " does not exist in the graph";
+ }
+ }
+
+ // Temp space for shape inference.
+ std::vector<AttrType> ishape, oshape;
+ // whether a shape is dynamic
+ std::vector<int> is_dynamic(rshape.size(), 0);
+ // inference step function for nid
+ auto infer_step = [&](uint32_t nid, bool last_iter) {
+ const auto& inode = idx[nid];
+ const std::string name = inode.source->attrs.name;
+ const uint32_t num_inputs = inode.inputs.size();
+ const uint32_t num_outputs = inode.source->num_outputs();
+ if (inode.source->is_variable()) {
+ // Variable node. No operator. Only one output entry.
+ CHECK(inode.source->op() == nullptr);
+ CHECK_EQ(num_outputs, 1U);
+ const uint32_t out_ent_id = idx.entry_id(nid, 0);
+ // fall back to the node's attribute dict (e.g. "__shape__") when no
+ // shape was provided for this variable
+ if (shape_attr_key.length() != 0 && fis_none(rshape[out_ent_id])) {
+ auto it = inode.source->attrs.dict.find(shape_attr_key);
+ if (it != inode.source->attrs.dict.end()) {
+ std::istringstream is(it->second);
+ CHECK(is >> rshape[out_ent_id]) << "Invalid attribute";
+ }
+ }
+ // assign a default value to node attribute
+ if (dispatch_mode_name != nullptr) {
+ op::dispatch_mode_assign(&dispatch_modes[nid], default_mode_val);
+ }
+ } else if (is_backward.get(inode.source->op(), false) &&
+ inode.control_deps.size() && bwd_identity_assign) {
+ CHECK(dispatch_mode_name == nullptr)
+ << "Backward inference for node attributes is not available";
+ CHECK_GE(inode.control_deps.size(), 1U)
+ << "BackwardOp need to have control_deps to its forward op";
+ const IndexedGraph::Node& fnode = idx[inode.control_deps[0]];
+ nnvm::NodePtr fwd_ptr = inode.source->control_deps[0];
+ CHECK(fwd_ptr->op() != nullptr) << "Forward op cannot be a variable";
+ // use gradient function to find out the correspondence.
+ std::vector<nnvm::NodeEntry> ograd(fwd_ptr->num_outputs());
+ for (size_t i = 0; i < ograd.size(); ++i) {
+ ograd[i].index = static_cast<uint32_t>(i);
+ }
+ // input gradient list
+ auto igrad = fgrad[fwd_ptr->op()](fwd_ptr, ograd);
+ const nnvm::Node* igrad_node = nullptr;
+ // Input gradient assignment: gradient entries mirror the shapes of the
+ // corresponding forward inputs
+ for (size_t i = 0; i < igrad.size(); ++i) {
+ if (igrad[i].node->op() == inode.source->op()) {
+ uint32_t eid = idx.entry_id(nid, igrad[i].index);
+ if (fis_none(rshape[eid])) {
+ rshape[eid] = rshape[idx.entry_id(fnode.inputs[i])];
+ } else if (!fis_none(rshape[idx.entry_id(fnode.inputs[i])])) {
+ // Need to skip empty forward shape, because it may not be
+ // available now and it is possible to infer the forward
+ // shape in one of the next few passes
+ CHECK_EQ(rshape[eid], rshape[idx.entry_id(fnode.inputs[i])])
+ << "Backward shape inconsistent with the forward shape";
+ }
+ if (igrad_node == nullptr) {
+ igrad_node = igrad[i].node.get();
+ } else {
+ CHECK(igrad_node == igrad[i].node.get());
+ }
+ }
+ }
+ // out grad entries
+ CHECK(igrad_node != nullptr)
+ << "Cannot find matching backward op for " << inode.source->attrs.name;
+ for (size_t i = 0; i < igrad_node->inputs.size(); ++i) {
+ const nnvm::NodeEntry& e = igrad_node->inputs[i];
+ if (e.node == nullptr) {
+ uint32_t eid = idx.entry_id(inode.inputs[i]);
+ if (fis_none(rshape[eid])) {
+ rshape[eid] = rshape[idx.entry_id(inode.control_deps[0], e.index)];
+ }
+ }
+ }
+ } else {
+ DispatchMode* dispatch_mode = nullptr;
+ bool forward_known = true;
+ // Forward operator inference.
+ ishape.resize(num_inputs, empty_val);
+ bool is_input_dynamic_shape = false;
+ for (uint32_t i = 0; i < ishape.size(); ++i) {
+ ishape[i] = rshape[idx.entry_id(inode.inputs[i])];
+ if (ishape[i].ndim() == 0 && is_dynamic[idx.entry_id(inode.inputs[i])]) {
+ is_input_dynamic_shape = true;
+ }
+ if (fis_none(ishape[i])) forward_known = false;
+ }
+ oshape.resize(num_outputs, empty_val);
+ for (uint32_t i = 0; i < oshape.size(); ++i) {
+ oshape[i] = rshape[idx.entry_id(nid, i)];
+ if (fis_none(oshape[i])) forward_known = false;
+ }
+ if (dispatch_mode_name != nullptr) {
+ dispatch_mode = &dispatch_modes[nid];
+ if (dispatch_modes[nid] == DispatchMode::kUndefined) forward_known = false;
+ }
+ auto finfer = finfer_shape.get(inode.source->op(), fdefault);
+ if (finfer == nullptr || is_input_dynamic_shape) {
+ // no inference function, or a dynamic-shape input: mark the unknown
+ // outputs as dynamic instead of attempting inference
+ for (uint32_t i = 0; i < oshape.size(); ++i) {
+ if (oshape[i].ndim() == 0) {
+ is_dynamic[idx.entry_id(nid, i)] = 1;
+ }
+ }
+ } else if (!forward_known) {
+ if (finfer != nullptr) {
+ // Call inference function of the operator.
+ try {
+ forward_known = ApplyOpInferAttr(ret, finfer, inode.source->attrs,
+ nid, &ishape, &oshape, dispatch_mode);
+ } catch (const std::exception& e) {
+ throw dmlc::Error("Error in operator " + inode.source->attrs.name + ": " + e.what());
+ }
+ } else {
+ CHECK(!last_iter)
+ << "Attribute " << infer_name
+ << " is not registed by op " << inode.source->op()->name
+ << " we are not able to complete the inference because of this";
+ }
+ }
+ // Save to the result map.
+ for (uint32_t i = 0; i < num_inputs; ++i) {
+ rshape[idx.entry_id(inode.inputs[i])] = ishape[i];
+ }
+ for (uint32_t i = 0; i < num_outputs; ++i) {
+ rshape[idx.entry_id(nid, i)] = oshape[i];
+ }
+ }
+ };
+
+ size_t last_num_unknown;
+ size_t num_unknown_dispatch_mode = dispatch_mode_name ? node_end - node_start : 0;
+ size_t num_unknown_entry_attr = entry_end - entry_start;
+ size_t num_unknown = num_unknown_entry_attr + num_unknown_dispatch_mode;
+ int i = 0;
+ // alternate forward and backward sweeps until no further progress is made
+ do {
+ if (i % 2 == 0) {
+ for (uint32_t nid = node_start; nid < node_end; ++nid) {
+ infer_step(nid, false);
+ }
+ } else {
+ // backward inference
+ for (uint32_t i = node_end; i != node_start; --i) {
+ infer_step(i - 1, false);
+ }
+ }
+ last_num_unknown = num_unknown;
+ num_unknown = 0;
+ for (size_t j = entry_start; j < entry_end; ++j) {
+ if (fis_none(rshape[j])) {
+ ++num_unknown;
+ }
+ }
+ if (dispatch_mode_name) {
+ for (size_t i = node_start; i < node_end; i++) {
+ if (dispatch_modes[i] == DispatchMode::kUndefined) ++num_unknown;
+ }
+ }
+ ++i;
+ } while (num_unknown > 0 && last_num_unknown > num_unknown);
+ // set the shapes
+ ret.attrs[attr_name] = std::make_shared<any>(std::move(rshape));
+ // set the dispatch modes
+ if (dispatch_mode_name) {
+ ret.attrs[dispatch_mode_name] = std::make_shared<any>(std::move(dispatch_modes));
+ }
+ // record how many shapes/dispatch modes are still unknown.
+ ret.attrs[unknown_name] = std::make_shared<any>(num_unknown);
+ return ret;
+}
+
+
nnvm::Graph InferShape(nnvm::Graph&& graph,
nnvm::ShapeVector&& shape_inputs,
const std::string& shape_attr_key) {
@@ -332,7 +598,7 @@ nnvm::Graph InferShape(nnvm::Graph&& graph,
if (shape_attr_key.length() != 0) {
graph.attrs["shape_attr_key"] = std::make_shared<any>(shape_attr_key);
}
- return InferAttr<nnvm::TShape, nnvm::FInferShape>(
+ return InferShapeAttr(
std::move(graph), nnvm::TShape(),
"FInferShape", "shape_inputs", "shape_attr_key",
"shape", "shape_num_unknown_nodes",