Posted to commits@mxnet.apache.org by zh...@apache.org on 2019/06/27 16:47:00 UTC

[incubator-mxnet] branch v1.5.x updated: [backport 1.5.x]Fix Cached_op with static_shape=true (#15298) (#15380)

This is an automated email from the ASF dual-hosted git repository.

zhasheng pushed a commit to branch v1.5.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/v1.5.x by this push:
     new 75a9e18  [backport 1.5.x]Fix Cached_op with static_shape=true (#15298) (#15380)
75a9e18 is described below

commit 75a9e187d00a8b7ebc71412a02ed0e3ae489d91f
Author: Lai Wei <ro...@gmail.com>
AuthorDate: Thu Jun 27 09:46:32 2019 -0700

    [backport 1.5.x]Fix Cached_op with static_shape=true (#15298) (#15380)
    
    * Fix Cached_op with static_shape=true (#15298)
    
    * Fix
    
    * run ci
    
    * trigger
---
 src/imperative/cached_op.cc |  7 +++++--
 src/nnvm/legacy_op_util.cc  | 47 ++++++++++++++++++---------------------------
 2 files changed, 24 insertions(+), 30 deletions(-)

diff --git a/src/imperative/cached_op.cc b/src/imperative/cached_op.cc
index d7e1543..efe3801 100644
--- a/src/imperative/cached_op.cc
+++ b/src/imperative/cached_op.cc
@@ -81,6 +81,7 @@ struct CachedOp::CachedOpState {
 
   std::vector<NDArray> buff;
   std::vector<NDArray*> arrays;
+  std::vector<NDArray*> arrays_with_in_out;
   std::vector<OpReqType> array_reqs;
 
   std::vector<OpStatePtr> op_states;
@@ -762,7 +763,8 @@ OpStatePtr CachedOp::StaticForward(
   // We are going to add input and output arrays to the array list.
   // The input and output arrays should only be valid for this run,
   // so we shouldn't modify the state's array list.
-  auto arrays = state.arrays;
+  state.arrays_with_in_out = state.arrays;
+  auto& arrays = state.arrays_with_in_out;
   if (config_.static_shape) {
     for (auto i : config_.param_indices) {
       auto nid = idx.input_nodes()[i];
@@ -1063,7 +1065,8 @@ void CachedOp::StaticBackward(
   // We are going to add input and output arrays to the array list.
   // The input and output arrays should only be valid for this run,
   // so we shouldn't modify the state's array list.
-  auto arrays = state.arrays;
+  state.arrays_with_in_out = state.arrays;
+  auto& arrays = state.arrays_with_in_out;
   for (size_t i = 0; i < state.info.bwd_input_eid.size(); ++i) {
     auto eid = state.info.bwd_input_eid[i];
     if (eid == kEidNotExist) {
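
[Editor's note] The two cached_op.cc hunks above replace the function-local copy
`auto arrays = state.arrays;` with a copy stored on the state itself
(`arrays_with_in_out`). The per-run list, which also receives this run's input and
output arrays, then lives as long as the CachedOpState rather than only as long as
the StaticForward/StaticBackward call, presumably so that work dispatched from those
calls can still dereference a valid list afterwards. Below is a minimal,
self-contained sketch of that lifetime pattern, not MXNet code: `State`, `Schedule`,
and `RunOnce` are illustrative names only.

#include <functional>
#include <vector>

struct State {
  std::vector<int*> arrays;              // long-lived arrays owned by the state
  std::vector<int*> arrays_with_in_out;  // per-run copy that also gets inputs/outputs
};

// Stand-in for an engine: in MXNet the pushed work may run after the caller returns.
void Schedule(const std::function<void()>& fn) { fn(); }

void RunOnce(State* state, int* input) {
  // Before the fix this was a function-local `auto arrays = state->arrays;`,
  // which nothing running later could safely keep referring to.
  state->arrays_with_in_out = state->arrays;   // per-run copy kept alive by the state
  auto& arrays = state->arrays_with_in_out;
  arrays.push_back(input);                     // add this run's input/output arrays
  Schedule([state]() {                         // safe: the list outlives this call
    for (int* p : state->arrays_with_in_out)
      if (p) ++(*p);
  });
}

int main() {
  int a = 0, b = 0;
  State state;
  state.arrays.push_back(&a);
  RunOnce(&state, &b);
  return (a == 1 && b == 1) ? 0 : 1;
}

Because the copy is overwritten at the start of every run, the state's own
`arrays` list is still never mutated, which is what the in-code comment above asks for.
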
diff --git a/src/nnvm/legacy_op_util.cc b/src/nnvm/legacy_op_util.cc
index 698666f..3e03b6b 100644
--- a/src/nnvm/legacy_op_util.cc
+++ b/src/nnvm/legacy_op_util.cc
@@ -79,7 +79,6 @@ class OperatorState {
  public:
   OperatorState(Operator *opr, const OperatorProperty *prop) {
     opr_ = opr;
-    fwd_init_ = bwd_init_ = false;
 
     in_data_fwd_.resize(prop->ListArguments().size());
     in_data_bwd_.resize(prop->ListArguments().size());
@@ -110,19 +109,16 @@ class OperatorState {
                const std::vector<TBlob>& inputs,
                const std::vector<OpReqType>& req,
                const std::vector<TBlob>& outputs) {
-    if (!fwd_init_) {
-      CHECK_EQ(inputs.size(), in_data_fwd_.size() + aux_data_.size());
-      CHECK_EQ(outputs.size(), out_data_.size());
-      // in_data_bwd_ has the same tblobs as the ones in in_data_fwd_, except that the ones
-      // referred by arg_data_ptr_ will be overriden
-      for (size_t i = 0; i < in_data_fwd_.size(); ++i) in_data_fwd_[i] = inputs[i];
-      for (size_t i = 0; i < in_data_fwd_.size(); ++i) in_data_bwd_[i] = inputs[i];
-      for (size_t i = 0; i < aux_data_.size(); ++i) {
-        aux_data_[i] = inputs[i + in_data_fwd_.size()];
-      }
-      for (size_t i = 0; i < out_data_.size(); ++i) out_data_[i] = outputs[i];
-      fwd_init_ = true;
+    CHECK_EQ(inputs.size(), in_data_fwd_.size() + aux_data_.size());
+    CHECK_EQ(outputs.size(), out_data_.size());
+    // in_data_bwd_ has the same tblobs as the ones in in_data_fwd_, except that the ones
+    // referred by arg_data_ptr_ will be overriden
+    for (size_t i = 0; i < in_data_fwd_.size(); ++i) in_data_fwd_[i] = inputs[i];
+    for (size_t i = 0; i < in_data_fwd_.size(); ++i) in_data_bwd_[i] = inputs[i];
+    for (size_t i = 0; i < aux_data_.size(); ++i) {
+      aux_data_[i] = inputs[i + in_data_fwd_.size()];
     }
+    for (size_t i = 0; i < out_data_.size(); ++i) out_data_[i] = outputs[i];
     opr_->Forward(ctx, in_data_fwd_, req, out_data_, aux_data_);
   }
 
@@ -130,27 +126,22 @@ class OperatorState {
                 const std::vector<TBlob>& inputs,
                 const std::vector<OpReqType>& req,
                 const std::vector<TBlob>& outputs) {
-    if (!bwd_init_) {
-      CHECK(fwd_init_);
-      CHECK_EQ(arg_data_ptr_.size() + aux_data_.size(), inputs.size());
-      // override tblobs pointed by arg_data_ptr_ since they might not contain
-      // initialized data during forward pass.
-      for (size_t i = 0; i < arg_data_ptr_.size(); ++i) {
-        *arg_data_ptr_[i] = inputs[i];
-      }
-      for (size_t i = 0; i < aux_data_.size(); ++i) {
-        aux_data_[i] = inputs[inputs.size() - aux_data_.size() + i];
-      }
-      CHECK_EQ(outputs.size(), in_grad_.size());
-      for (size_t i = 0; i < outputs.size(); ++i) in_grad_[i] = outputs[i];
-      bwd_init_ = true;
+    CHECK_EQ(arg_data_ptr_.size() + aux_data_.size(), inputs.size());
+    // override tblobs pointed by arg_data_ptr_ since they might not contain
+    // initialized data during forward pass.
+    for (size_t i = 0; i < arg_data_ptr_.size(); ++i) {
+      *arg_data_ptr_[i] = inputs[i];
+    }
+    for (size_t i = 0; i < aux_data_.size(); ++i) {
+      aux_data_[i] = inputs[inputs.size() - aux_data_.size() + i];
     }
+    CHECK_EQ(outputs.size(), in_grad_.size());
+    for (size_t i = 0; i < outputs.size(); ++i) in_grad_[i] = outputs[i];
     opr_->Backward(ctx, out_grad_, in_data_bwd_, out_data_, req, in_grad_, aux_data_);
   }
 
  private:
   Operator *opr_;
-  bool fwd_init_, bwd_init_;
   // input data blobs for forward and backward
   // in_data_fwd_ and in_data_bwd_ will hold different tblobs when StorageFallbackOpExecutor
   // performs storage fallback on a non-default input NDArray. The one in in_data_fwd_ is
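
[Editor's note] The legacy_op_util.cc hunks drop the fwd_init_/bwd_init_ guards, so
OperatorState re-copies the input/output TBlobs on every Forward and Backward call
instead of only on the first one. With static_shape=true the cached op reuses the same
operator state across runs, and the buffers behind those blobs are presumably not
guaranteed to be the ones seen on the first call, so caching only the first call's
blobs could leave later runs reading stale pointers. The following sketch shows the
"rebind on every call" pattern with illustrative names (Blob, LegacyOpState), not the
real MXNet classes.

#include <cassert>
#include <vector>

struct Blob { float* dptr = nullptr; };   // stand-in for a TBlob-like view

class LegacyOpState {
 public:
  explicit LegacyOpState(size_t num_inputs) : in_data_(num_inputs) {}

  void Forward(const std::vector<Blob>& inputs, const std::vector<Blob>& outputs) {
    assert(inputs.size() == in_data_.size());
    // Rebind the views unconditionally: even with static shapes, the memory
    // behind `inputs`/`outputs` may differ from what the first call saw.
    for (size_t i = 0; i < in_data_.size(); ++i) in_data_[i] = inputs[i];
    out_data_ = outputs;
    // ... dispatch to the wrapped operator with in_data_ / out_data_ ...
  }

 private:
  std::vector<Blob> in_data_;
  std::vector<Blob> out_data_;
};

int main() {
  float x = 0.f, y = 0.f;
  LegacyOpState state(1);
  state.Forward({Blob{&x}}, {});   // first run binds x
  state.Forward({Blob{&y}}, {});   // second run must rebind to y, not keep x
  return 0;
}
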