You are viewing a plain-text version of this content. The hyperlink to the canonical version was not preserved in this plain-text copy.
Posted to commits@mxnet.apache.org by zh...@apache.org on 2019/05/03 17:40:47 UTC

[incubator-mxnet] branch master updated: Revert "Improve cached_op performance for static mode (#14785)" (#14868)

This is an automated email from the ASF dual-hosted git repository.

zhasheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 204f3f2  Revert "Improve cached_op performance for static mode (#14785)" (#14868)
204f3f2 is described below

commit 204f3f2de812e99ca64c8464f3b9e652719ec67d
Author: Anirudh <aa...@amazon.com>
AuthorDate: Fri May 3 10:40:05 2019 -0700

    Revert "Improve cached_op performance for static mode (#14785)" (#14868)
    
    This reverts commit 369b66d0f10ba479ce96f78f7c838bd7bc41d951.
---
 src/executor/attach_op_execs_pass.cc |  8 ++------
 src/executor/exec_pass.h             |  9 +--------
 src/imperative/cached_op.cc          | 10 ++++------
 src/imperative/imperative_utils.h    | 26 ++++++++++++++------------
 4 files changed, 21 insertions(+), 32 deletions(-)

diff --git a/src/executor/attach_op_execs_pass.cc b/src/executor/attach_op_execs_pass.cc
index 8f47bc2..b04d132 100644
--- a/src/executor/attach_op_execs_pass.cc
+++ b/src/executor/attach_op_execs_pass.cc
@@ -261,7 +261,7 @@ class FComputeExExecutor : public OpExecutor {
   ExecType exec_type_;
 };
 
-void CreateOpExecs(const Graph& g, OpExecVector* p_ret, OpStateVector* p_state, size_t i) {
+void CreateOpExecs(const Graph& g, OpExecVector* p_ret, size_t i) {
   using nnvm::DTypeVector;
   using mxnet::ShapeVector;
   using nnvm::FMutateInputs;
@@ -302,10 +302,6 @@ void CreateOpExecs(const Graph& g, OpExecVector* p_ret, OpStateVector* p_state,
 
     OpStatePtr state = fcreate_op_state[op](
         inode.source->attrs, vctx[i], ishape, itype);
-    if (p_state) {
-      CHECK_GT(p_state->size(), i);
-      p_state->at(i) = state;
-    }
     FStatefulComputeEx fcompute_ex = common::GetFCompute<FStatefulComputeEx>(
         op, "FStatefulComputeEx", vctx[i]);
     // FStatefulComputeEx is dispatched only when dispatch_mode is DispatchMode::kFComputeEx
@@ -363,7 +359,7 @@ Graph AttachOpExecs(Graph g) {
   const auto& idx = g.indexed_graph();
   OpExecVector ret(idx.num_nodes());
   for (size_t i = 0; i < idx.num_nodes(); ++i) {
-    CreateOpExecs(g, &ret, nullptr, i);
+    CreateOpExecs(g, &ret, i);
   }
   g.attrs["op_execs"] = std::make_shared<nnvm::any>(ret);
   return g;
diff --git a/src/executor/exec_pass.h b/src/executor/exec_pass.h
index f544d6b..acf20de 100644
--- a/src/executor/exec_pass.h
+++ b/src/executor/exec_pass.h
@@ -99,12 +99,6 @@ class OpExecutor {
 using OpExecVector = std::vector<std::shared_ptr<OpExecutor> >;
 
 /*!
- * \brief per node vector of operator states.
- * \note stored under attribute "op_states"
- */
-using OpStateVector = std::vector<OpStatePtr>;
-
-/*!
  * \brief per node context vector
  * \node stored under "context"
  */
@@ -121,10 +115,9 @@ using DevMaskVector = std::vector<int>;
  *
  * \param g input graph
  * \param p_ret OpExecVector for input and output
- * \param p_state OpStateVector if it has.
  * \param i the id of the node
  */
-void CreateOpExecs(const Graph& g, OpExecVector* p_ret, OpStateVector* p_state, size_t i);
+void CreateOpExecs(const Graph& g, OpExecVector* p_ret, size_t i);
 /*!
  * \brief Attach OpExecutor to the graph attributes.
  *
diff --git a/src/imperative/cached_op.cc b/src/imperative/cached_op.cc
index 7a5ed21..c9215c5 100644
--- a/src/imperative/cached_op.cc
+++ b/src/imperative/cached_op.cc
@@ -285,7 +285,7 @@ bool CachedOp::CheckDynamicShapeExists(const Context& default_ctx,
   CheckAndInferShape(&g, std::move(shape_inputs), true,
                      {0, 0}, {0, 0},
                      &contain_dynamic_shape);
-  if (contain_dynamic_shape && erase_result) {
+  if (erase_result) {
     g.attrs.erase("shape");
     g.attrs.erase("shape_inputs");
   }
@@ -603,7 +603,7 @@ void CachedOp::StaticInitExec(
     }
   } else {
     for (size_t i = start_nid; i < end_nid; ++i) {
-      exec::CreateOpExecs(g, &state.execs, &state.op_states, i);
+      exec::CreateOpExecs(g, &state.execs, i);
     }
     exec::AttachOpResources(g, state.execs, start_nid, end_nid);
 
@@ -705,10 +705,8 @@ void CachedOp::StaticRunOps(
           arg_shapes.emplace_back(ndinput->shape());
           arg_dtypes.emplace_back(ndinput->dtype());
         }
-        if (!state.op_states[i]) {
-          state.op_states[i] =
-              createop[node.source->op()](node.source->attrs, default_ctx, arg_shapes, arg_dtypes);
-        }
+        state.op_states[i] = createop[node.source->op()](
+            node.source->attrs, default_ctx, arg_shapes, arg_dtypes);
         Imperative::Get()->InvokeOp(
             default_ctx, node.source->attrs, ndinputs, ndoutputs, req,
             dispatch_mode, state.op_states[i]);
diff --git a/src/imperative/imperative_utils.h b/src/imperative/imperative_utils.h
index 5c97068..9d4e4bd 100644
--- a/src/imperative/imperative_utils.h
+++ b/src/imperative/imperative_utils.h
@@ -595,21 +595,23 @@ inline bool CheckAndInferShape(nnvm::Graph* p_g, mxnet::ShapeVector&& shapes,
     *contain_unknown = false;
   }
   nnvm::Graph& g = *p_g;
-  if (g.attrs.count("shape")) {
+  if (use_inputs) {
+    if (g.attrs.count("shape_inputs") &&
+        g.GetAttr<mxnet::ShapeVector>("shape_inputs") == shapes) return true;
+  } else if (g.attrs.count("shape")) {
     const auto& prev_shapes = g.GetAttr<mxnet::ShapeVector>("shape");
-    if (prev_shapes.size() == shapes.size()) {
-      bool match = true;
-      for (size_t i = 0; i < shapes.size(); ++i) {
-        if (i == entry_range.first) {
-          i = entry_range.second;
-          if (i >= shapes.size()) break;
-        }
-        if (shapes[i] == prev_shapes[i]) continue;
-        match = false;
-        break;
+    CHECK_EQ(prev_shapes.size(), shapes.size());
+    bool match = true;
+    for (size_t i = 0; i < shapes.size(); ++i) {
+      if (i == entry_range.first) {
+        i = entry_range.second;
+        if (i >= shapes.size()) break;
       }
-      if (match) return true;
+      if (shapes[i] == prev_shapes[i]) continue;
+      match = false;
+      break;
     }
+    if (match) return true;
   }
   g.attrs.erase("shape");
   g.attrs.erase("shape_inputs");