You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2019/06/25 22:20:49 UTC

[GitHub] [incubator-mxnet] ZhennanQin commented on a change in pull request #15298: Fix Cached_op with static_shape=true

ZhennanQin commented on a change in pull request #15298: Fix Cached_op with static_shape=true
URL: https://github.com/apache/incubator-mxnet/pull/15298#discussion_r297417332
 
 

 ##########
 File path: src/nnvm/legacy_op_util.cc
 ##########
 @@ -110,47 +109,39 @@ class OperatorState {
                const std::vector<TBlob>& inputs,
                const std::vector<OpReqType>& req,
                const std::vector<TBlob>& outputs) {
-    if (!fwd_init_) {
-      CHECK_EQ(inputs.size(), in_data_fwd_.size() + aux_data_.size());
-      CHECK_EQ(outputs.size(), out_data_.size());
-      // in_data_bwd_ has the same tblobs as the ones in in_data_fwd_, except that the ones
-      // referred by arg_data_ptr_ will be overriden
-      for (size_t i = 0; i < in_data_fwd_.size(); ++i) in_data_fwd_[i] = inputs[i];
-      for (size_t i = 0; i < in_data_fwd_.size(); ++i) in_data_bwd_[i] = inputs[i];
-      for (size_t i = 0; i < aux_data_.size(); ++i) {
-        aux_data_[i] = inputs[i + in_data_fwd_.size()];
-      }
-      for (size_t i = 0; i < out_data_.size(); ++i) out_data_[i] = outputs[i];
-      fwd_init_ = true;
+    CHECK_EQ(inputs.size(), in_data_fwd_.size() + aux_data_.size());
+    CHECK_EQ(outputs.size(), out_data_.size());
+    // in_data_bwd_ has the same tblobs as the ones in in_data_fwd_, except that the ones
+    // referred by arg_data_ptr_ will be overriden
+    for (size_t i = 0; i < in_data_fwd_.size(); ++i) in_data_fwd_[i] = inputs[i];
+    for (size_t i = 0; i < in_data_fwd_.size(); ++i) in_data_bwd_[i] = inputs[i];
+    for (size_t i = 0; i < aux_data_.size(); ++i) {
+      aux_data_[i] = inputs[i + in_data_fwd_.size()];
     }
+    for (size_t i = 0; i < out_data_.size(); ++i) out_data_[i] = outputs[i];
     opr_->Forward(ctx, in_data_fwd_, req, out_data_, aux_data_);
   }
 
   void Backward(const OpContext &ctx,
                 const std::vector<TBlob>& inputs,
                 const std::vector<OpReqType>& req,
                 const std::vector<TBlob>& outputs) {
-    if (!bwd_init_) {
-      CHECK(fwd_init_);
-      CHECK_EQ(arg_data_ptr_.size() + aux_data_.size(), inputs.size());
-      // override tblobs pointed by arg_data_ptr_ since they might not contain
-      // initialized data during forward pass.
-      for (size_t i = 0; i < arg_data_ptr_.size(); ++i) {
-        *arg_data_ptr_[i] = inputs[i];
-      }
-      for (size_t i = 0; i < aux_data_.size(); ++i) {
-        aux_data_[i] = inputs[inputs.size() - aux_data_.size() + i];
-      }
-      CHECK_EQ(outputs.size(), in_grad_.size());
-      for (size_t i = 0; i < outputs.size(); ++i) in_grad_[i] = outputs[i];
-      bwd_init_ = true;
 
 Review comment:
  When using legacy ops in Cached_op, this caching is not correct: even with static_alloc=true and static_shape=true, the input or output TBlobs may change if they are also the inputs or outputs of the Cached_op itself.
   
  Consider a simple case where the end user hybridizes only a single legacy op: its input is then the Cached_op's input, and its output is the Cached_op's output. The end user may pass different NDArrays to this Cached_op on each call, so the cached TBlobs become stale and incorrect.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services