Posted to commits@mxnet.apache.org by zh...@apache.org on 2019/02/26 04:35:20 UTC

[incubator-mxnet] branch master updated: [MXNET-1325] Make InferShapeAttr a standalone pass (#14193)

This is an automated email from the ASF dual-hosted git repository.

zhasheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 0af40f7  [MXNET-1325] Make InferShapeAttr a standalone pass (#14193)
0af40f7 is described below

commit 0af40f7afe44464a7fbd4ac092c4377b69c56918
Author: Junru Shao <ju...@gmail.com>
AuthorDate: Mon Feb 25 20:34:52 2019 -0800

    [MXNET-1325] Make InferShapeAttr a standalone pass (#14193)
    
    * Make InferShapeAttr a standalone pass
    
    * Fix
    
    * Fix
    
    * Fix
---
 src/executor/infer_graph_attr_pass.cc | 268 +++++++++++++++++++++++++++++++++-
 1 file changed, 267 insertions(+), 1 deletion(-)

diff --git a/src/executor/infer_graph_attr_pass.cc b/src/executor/infer_graph_attr_pass.cc
index c14482a..e4dd3f6 100644
--- a/src/executor/infer_graph_attr_pass.cc
+++ b/src/executor/infer_graph_attr_pass.cc
@@ -322,6 +322,272 @@ nnvm::Graph InferAttr(nnvm::Graph &&ret,
   return ret;
 }
 
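+// Shape-specialized variant of the generic InferAttr pass above: it walks the
+// indexed graph applying each operator's registered FInferShape function, and
+// additionally tracks entries whose shape is dynamic (see is_dynamic below) so
+// that inference is not attempted across them.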
+template<typename IsNone, typename FDefault>
+nnvm::Graph InferShapeAttr(nnvm::Graph &&ret,
+                           const nnvm::TShape empty_val,
+                           const char* infer_name,
+                           const char* input_name,
+                           const char* attr_key_name,
+                           const char* attr_name,
+                           const char* unknown_name,
+                           IsNone fis_none,
+                           FDefault fdefault,
+                           bool bwd_identity_assign,
+                           const char* dispatch_mode_name,
+                           const DispatchMode default_mode_val = DispatchMode::kUndefined) {
+  using nnvm::IndexedGraph;
+  using nnvm::Op;
+  using AttrType = nnvm::TShape;
+  using FInferType = nnvm::FInferShape;
+  using AttrVector = std::vector<AttrType>;
+  using NodeAttrVector = std::vector<DispatchMode>;
+  using dmlc::any;
+  const IndexedGraph& idx = ret.indexed_graph();
+  static auto& finfer_shape =
+      Op::GetAttr<FInferType>(infer_name);
+  static auto& is_backward =
+      Op::GetAttr<nnvm::TIsBackward>("TIsBackward");
+  // gradient function, used to get node correspondence.
+  static auto& fgrad =
+      Op::GetAttr<nnvm::FGradient>("FGradient");
+  // result shape vector, one element per node entry
+  AttrVector rshape;
+  // dispatch mode vector
+  DispatchModeVector dispatch_modes;
+  if (ret.attrs.count(attr_name) != 0) {
+    rshape = ret.MoveCopyAttr<AttrVector>(attr_name);
+  } else {
+    rshape.resize(idx.num_node_entries(), empty_val);
+  }
+
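+  // apply shapes provided for the graph inputs, in input-node order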
+  if (ret.attrs.count(input_name) != 0) {
+    const AttrVector& shape_args = ret.GetAttr<AttrVector>(input_name);
+    CHECK_LE(shape_args.size(), idx.input_nodes().size())
+        << "More provided " << attr_name << "s than number of arguments.";
+    for (size_t i = 0; i < shape_args.size(); ++i) {
+      rshape[idx.entry_id(idx.input_nodes()[i], 0)] = shape_args[i];
+    }
+  }
+
+  // get the shape hints
+  std::string shape_hints_key = std::string(attr_name) + "_hints";
+  if (ret.attrs.count(shape_hints_key)) {
+    nnvm::NodeEntryMap<AttrType> shape_hints =
+      ret.GetAttr<nnvm::NodeEntryMap<AttrType>>(shape_hints_key);
+    for (const auto& kv : shape_hints) {
+      nnvm::NodeEntry e = kv.first;
+      if (idx.exist(e.node.get())) {
+        rshape[idx.entry_id(kv.first)] = kv.second;
+      }
+    }
+  }
+
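+  // name of a per-node attribute (e.g. "__shape__") from which variable nodes
+  // may provide their own shape; consulted when inferring variable nodes below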
+  std::string shape_attr_key;
+  if (ret.attrs.count(attr_key_name) != 0) {
+    shape_attr_key = ret.GetAttr<std::string>(attr_key_name);
+    // erase the provided arguments
+    ret.attrs.erase(attr_key_name);
+  }
+
+  // limit inference to part of the graph
+  uint32_t node_start = 0, node_end = idx.num_nodes();
+  if (ret.attrs.count("node_range")) {
+    const auto& range = ret.GetAttr<std::pair<uint32_t, uint32_t> >("node_range");
+    node_start = range.first;
+    node_end = range.second;
+    CHECK_GE(node_start, 0);
+    CHECK_LE(node_end, idx.num_nodes());
+    ret.attrs.erase("node_range");
+  }
+  uint32_t entry_start = 0, entry_end = idx.num_node_entries();
+  if (ret.attrs.count("entry_range")) {
+    const auto& range = ret.GetAttr<std::pair<uint32_t, uint32_t> >("entry_range");
+    entry_start = range.first;
+    entry_end = range.second;
+    CHECK_GE(entry_start, 0);
+    CHECK_LE(entry_end, idx.num_node_entries());
+    ret.attrs.erase("entry_range");
+  }
+  // populate the node attribute vector
+  if (dispatch_mode_name != nullptr) {
+    if (ret.attrs.count(dispatch_mode_name) != 0) {
+      dispatch_modes = ret.MoveCopyAttr<NodeAttrVector>(dispatch_mode_name);
+    } else {
+      LOG(FATAL) << "Node attribute " << dispatch_mode_name << " does not exist in the graph";
+    }
+  }
+
+  // Temp space for shape inference.
+  std::vector<AttrType> ishape, oshape;
+  // whether a shape is dynamic
+  std::vector<int> is_dynamic(rshape.size(), 0);
+  // inference step function for nid
+  auto infer_step = [&](uint32_t nid, bool last_iter) {
+    const auto& inode = idx[nid];
+    const std::string name = inode.source->attrs.name;
+    const uint32_t num_inputs = inode.inputs.size();
+    const uint32_t num_outputs = inode.source->num_outputs();
+    if (inode.source->is_variable()) {
+      // Variable node. No operator. Only one output entry.
+      CHECK(inode.source->op() == nullptr);
+      CHECK_EQ(num_outputs, 1U);
+      const uint32_t out_ent_id = idx.entry_id(nid, 0);
+      if (shape_attr_key.length() != 0 && fis_none(rshape[out_ent_id])) {
+        auto it = inode.source->attrs.dict.find(shape_attr_key);
+        if (it != inode.source->attrs.dict.end()) {
+          std::istringstream is(it->second);
+          CHECK(is >> rshape[out_ent_id]) << "Invalid attribute";
+        }
+      }
+      // assign a default value to the node attribute
+      if (dispatch_mode_name != nullptr) {
+        op::dispatch_mode_assign(&dispatch_modes[nid], default_mode_val);
+      }
+    } else if (is_backward.get(inode.source->op(), false) &&
+               inode.control_deps.size() && bwd_identity_assign) {
+      CHECK(dispatch_mode_name == nullptr)
+        << "Backward inference for node attributes is not available";
+      CHECK_GE(inode.control_deps.size(), 1U)
+        << "BackwardOp need to have control_deps to its forward op";
+      const IndexedGraph::Node& fnode = idx[inode.control_deps[0]];
+      nnvm::NodePtr fwd_ptr = inode.source->control_deps[0];
+      CHECK(fwd_ptr->op() != nullptr) << "Forward op cannot be a variable";
+      // use gradient function to find out the correspondence.
+      std::vector<nnvm::NodeEntry> ograd(fwd_ptr->num_outputs());
+      for (size_t i = 0; i < ograd.size(); ++i) {
+        ograd[i].index = static_cast<uint32_t>(i);
+      }
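+      // the ograd entries deliberately carry null node pointers; they act as
+      // placeholders identifying output-gradient inputs in the backward graph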
+      // input gradient list
+      auto igrad = fgrad[fwd_ptr->op()](fwd_ptr, ograd);
+      const nnvm::Node* igrad_node = nullptr;
+      // Input gradient assignment
+      for (size_t i = 0; i < igrad.size(); ++i) {
+        if (igrad[i].node->op() == inode.source->op()) {
+          uint32_t eid = idx.entry_id(nid, igrad[i].index);
+          if (fis_none(rshape[eid])) {
+            rshape[eid] = rshape[idx.entry_id(fnode.inputs[i])];
+          } else if (!fis_none(rshape[idx.entry_id(fnode.inputs[i])])) {
+            // Need to skip an empty forward shape, because it may not be
+            // available yet and it may still be inferred in one of the
+            // next few passes
+            CHECK_EQ(rshape[eid], rshape[idx.entry_id(fnode.inputs[i])])
+                << "Backward shape inconsistent with the forward shape";
+          }
+          if (igrad_node == nullptr) {
+            igrad_node = igrad[i].node.get();
+          } else {
+            CHECK(igrad_node == igrad[i].node.get());
+          }
+        }
+      }
+      // propagate forward output shapes to the matching output-gradient entries
+      CHECK(igrad_node != nullptr)
+        << "Cannot find matching backward op for " << inode.source->attrs.name;
+      for (size_t i = 0; i < igrad_node->inputs.size(); ++i) {
+        const nnvm::NodeEntry& e = igrad_node->inputs[i];
+        if (e.node == nullptr) {
+          uint32_t eid = idx.entry_id(inode.inputs[i]);
+          if (fis_none(rshape[eid])) {
+            rshape[eid] = rshape[idx.entry_id(inode.control_deps[0], e.index)];
+          }
+        }
+      }
+    } else {
+      DispatchMode* dispatch_mode = nullptr;
+      bool forward_known = true;
+      // Forward operator inference.
+      ishape.resize(num_inputs, empty_val);
+      bool is_input_dynamic_shape = false;
+      for (uint32_t i = 0; i < ishape.size(); ++i) {
+        ishape[i] = rshape[idx.entry_id(inode.inputs[i])];
+        if (ishape[i].ndim() == 0 && is_dynamic[idx.entry_id(inode.inputs[i])]) {
+          is_input_dynamic_shape = true;
+        }
+        if (fis_none(ishape[i])) forward_known = false;
+      }
+      oshape.resize(num_outputs, empty_val);
+      for (uint32_t i = 0; i < oshape.size(); ++i) {
+        oshape[i] = rshape[idx.entry_id(nid, i)];
+        if (fis_none(oshape[i])) forward_known = false;
+      }
+      if (dispatch_mode_name != nullptr) {
+        dispatch_mode = &dispatch_modes[nid];
+        if (dispatch_modes[nid] == DispatchMode::kUndefined) forward_known = false;
+      }
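+      // look up the shape-inference function registered for this op, falling
+      // back to fdefault when the op has none registered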
+      auto finfer = finfer_shape.get(inode.source->op(), fdefault);
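+      // with no inference function, or with a dynamic-shape input, mark any
+      // still-unknown outputs as dynamic instead of attempting inference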
+      if (finfer == nullptr || is_input_dynamic_shape) {
+        for (uint32_t i = 0; i < oshape.size(); ++i) {
+          if (oshape[i].ndim() == 0) {
+            is_dynamic[idx.entry_id(nid, i)] = 1;
+          }
+        }
+      } else if (!forward_known) {
+        if (finfer != nullptr) {
+          // Call inference function of the operator.
+          try {
+            forward_known = ApplyOpInferAttr(ret, finfer, inode.source->attrs,
+                                             nid, &ishape, &oshape, dispatch_mode);
+          } catch (const std::exception& e) {
+            throw dmlc::Error("Error in operator " + inode.source->attrs.name + ": " + e.what());
+          }
+        } else {
+          CHECK(!last_iter)
+              << "Attribute " << infer_name
+              << " is not registed by op " << inode.source->op()->name
+              << " we are not able to complete the inference because of this";
+        }
+      }
+      // Save to the result map.
+      for (uint32_t i = 0; i < num_inputs; ++i) {
+        rshape[idx.entry_id(inode.inputs[i])] = ishape[i];
+      }
+      for (uint32_t i = 0; i < num_outputs; ++i) {
+        rshape[idx.entry_id(nid, i)] = oshape[i];
+      }
+    }
+  };
+
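+  // run forward passes (even iterations) and backward passes (odd iterations)
+  // alternately over the node range, until a full pass resolves no additional
+  // unknown entries or dispatch modes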
+  size_t last_num_unknown;
+  size_t num_unknown_dispatch_mode = dispatch_mode_name ? node_end - node_start : 0;
+  size_t num_unknown_entry_attr = entry_end - entry_start;
+  size_t num_unknown = num_unknown_entry_attr + num_unknown_dispatch_mode;
+  int i = 0;
+  do {
+    if (i % 2 == 0) {
+      for (uint32_t nid = node_start; nid < node_end; ++nid) {
+        infer_step(nid, false);
+      }
+    } else {
+      // backward inference
+      for (uint32_t i = node_end; i != node_start; --i) {
+        infer_step(i - 1, false);
+      }
+    }
+    last_num_unknown = num_unknown;
+    num_unknown = 0;
+    for (size_t j = entry_start; j < entry_end; ++j) {
+      if (fis_none(rshape[j])) {
+        ++num_unknown;
+      }
+    }
+    if (dispatch_mode_name) {
+      for (size_t i = node_start; i < node_end; i++) {
+        if (dispatch_modes[i] == DispatchMode::kUndefined) ++num_unknown;
+      }
+    }
+    ++i;
+  } while (num_unknown > 0 && last_num_unknown > num_unknown);
+  // set the shapes
+  ret.attrs[attr_name] = std::make_shared<any>(std::move(rshape));
+  // set the dispatch modes
+  if (dispatch_mode_name) {
+    ret.attrs[dispatch_mode_name] = std::make_shared<any>(std::move(dispatch_modes));
+  }
+  // number of entries whose shape is still unknown.
+  ret.attrs[unknown_name] = std::make_shared<any>(num_unknown);
+  return ret;
+}
+
 nnvm::Graph InferShape(nnvm::Graph&& graph,
                        nnvm::ShapeVector&& shape_inputs,
                        const std::string& shape_attr_key) {
@@ -332,7 +598,7 @@ nnvm::Graph InferShape(nnvm::Graph&& graph,
   if (shape_attr_key.length() != 0) {
     graph.attrs["shape_attr_key"] = std::make_shared<any>(shape_attr_key);
   }
-  return InferAttr<nnvm::TShape, nnvm::FInferShape>(
+  return InferShapeAttr(
       std::move(graph), nnvm::TShape(),
       "FInferShape", "shape_inputs", "shape_attr_key",
       "shape", "shape_num_unknown_nodes",