You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2019/06/25 17:12:47 UTC

[GitHub] [incubator-mxnet] anirudh2290 commented on a change in pull request #15298: Fix Cached_op with static_shape=true

anirudh2290 commented on a change in pull request #15298: Fix Cached_op with static_shape=true
URL: https://github.com/apache/incubator-mxnet/pull/15298#discussion_r297298941
 
 

 ##########
 File path: src/nnvm/legacy_op_util.cc
 ##########
 @@ -110,47 +109,39 @@ class OperatorState {
                const std::vector<TBlob>& inputs,
                const std::vector<OpReqType>& req,
                const std::vector<TBlob>& outputs) {
-    if (!fwd_init_) {
-      CHECK_EQ(inputs.size(), in_data_fwd_.size() + aux_data_.size());
-      CHECK_EQ(outputs.size(), out_data_.size());
-      // in_data_bwd_ has the same tblobs as the ones in in_data_fwd_, except that the ones
-      // referred by arg_data_ptr_ will be overriden
-      for (size_t i = 0; i < in_data_fwd_.size(); ++i) in_data_fwd_[i] = inputs[i];
-      for (size_t i = 0; i < in_data_fwd_.size(); ++i) in_data_bwd_[i] = inputs[i];
-      for (size_t i = 0; i < aux_data_.size(); ++i) {
-        aux_data_[i] = inputs[i + in_data_fwd_.size()];
-      }
-      for (size_t i = 0; i < out_data_.size(); ++i) out_data_[i] = outputs[i];
-      fwd_init_ = true;
+    CHECK_EQ(inputs.size(), in_data_fwd_.size() + aux_data_.size());
+    CHECK_EQ(outputs.size(), out_data_.size());
+    // in_data_bwd_ has the same tblobs as the ones in in_data_fwd_, except that the ones
+    // referred by arg_data_ptr_ will be overriden
+    for (size_t i = 0; i < in_data_fwd_.size(); ++i) in_data_fwd_[i] = inputs[i];
+    for (size_t i = 0; i < in_data_fwd_.size(); ++i) in_data_bwd_[i] = inputs[i];
+    for (size_t i = 0; i < aux_data_.size(); ++i) {
+      aux_data_[i] = inputs[i + in_data_fwd_.size()];
     }
+    for (size_t i = 0; i < out_data_.size(); ++i) out_data_[i] = outputs[i];
     opr_->Forward(ctx, in_data_fwd_, req, out_data_, aux_data_);
   }
 
   void Backward(const OpContext &ctx,
                 const std::vector<TBlob>& inputs,
                 const std::vector<OpReqType>& req,
                 const std::vector<TBlob>& outputs) {
-    if (!bwd_init_) {
-      CHECK(fwd_init_);
-      CHECK_EQ(arg_data_ptr_.size() + aux_data_.size(), inputs.size());
-      // override tblobs pointed by arg_data_ptr_ since they might not contain
-      // initialized data during forward pass.
-      for (size_t i = 0; i < arg_data_ptr_.size(); ++i) {
-        *arg_data_ptr_[i] = inputs[i];
-      }
-      for (size_t i = 0; i < aux_data_.size(); ++i) {
-        aux_data_[i] = inputs[inputs.size() - aux_data_.size() + i];
-      }
-      CHECK_EQ(outputs.size(), in_grad_.size());
-      for (size_t i = 0; i < outputs.size(); ++i) in_grad_[i] = outputs[i];
-      bwd_init_ = true;
 
 Review comment:
  This caching was first removed in #14738. I think this has some performance implications, since we are no longer caching the TBlobs. Is the use case similar here — is this also caused by the split operator?

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services