Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/06/21 00:31:37 UTC

[GitHub] piiswrong closed pull request #11306: [MXNET-555] Add subgraph storage type inference to CachedOp

piiswrong closed pull request #11306: [MXNET-555] Add subgraph storage type inference to CachedOp 
URL: https://github.com/apache/incubator-mxnet/pull/11306

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

diff --git a/src/imperative/cached_op.cc b/src/imperative/cached_op.cc
index c0e5e83eb90..5a3d44c04ce 100644
--- a/src/imperative/cached_op.cc
+++ b/src/imperative/cached_op.cc
@@ -22,6 +22,7 @@
 #include "./cached_op.h"
 #include "../executor/exec_pass.h"
 #include "../profiler/profiler.h"
+#include "../operator/operator_common.h"
 
 
 namespace mxnet {
@@ -95,7 +96,6 @@ CachedOp::CachedOp(
   using namespace imperative;
   static const std::vector<const Op*> zero_ops{Op::Get("zeros_like"), Op::Get("_zeros")};
   static const auto _copy = Op::Get("_copy");
-
   config_.Init(flags);
 
   if (config_.static_shape) {
@@ -204,26 +204,17 @@ CachedOp::CachedOp(
     size_t num_forward_outputs = num_outputs();
     for (uint32_t i = 0; i < ograd_entries_.size(); ++i) {
       if (!idx.exist(ograd_entries_[i].node.get())) continue;
-      auto eid = idx.entry_id(ograd_entries_[i]);
-      if (ref_count[eid] > 0) {
-        bwd_ograd_dep_.push_back(i);
-      }
+      bwd_ograd_dep_.push_back(i);
     }
     save_inputs_.resize(num_forward_inputs, false);
     for (uint32_t i = 0; i < num_forward_inputs; ++i) {
-      auto eid = idx.entry_id(idx.input_nodes()[i], 0);
-      if (ref_count[eid] > 0) {
-        save_inputs_[i] = true;
-        bwd_in_dep_.push_back(i);
-      }
+      save_inputs_[i] = true;
+      bwd_in_dep_.push_back(i);
     }
     save_outputs_.resize(idx.outputs().size(), false);
     for (uint32_t i = 0; i < num_forward_outputs; ++i) {
-      auto eid = idx.entry_id(idx.outputs()[i]);
-      if (ref_count[eid] > 0) {
-        save_outputs_[i] = true;
-        bwd_out_dep_.push_back(i);
-      }
+      save_outputs_[i] = true;
+      bwd_out_dep_.push_back(i);
     }
   }
 }
@@ -233,7 +224,7 @@ CachedOp::~CachedOp() {
 
 std::vector<nnvm::NodeEntry> CachedOp::Gradient(
     const nnvm::NodePtr& node,
-    const std::vector<nnvm::NodeEntry>& ograds) {
+    const std::vector<nnvm::NodeEntry>& ograds) const {
   using namespace nnvm;
   static const auto _backward_CachedOp = Op::Get("_backward_CachedOp");
   static const auto _NoGrad = Op::Get("_NoGradient");
@@ -328,6 +319,27 @@ bool CachedOp::SetForwardGraph(
   return false;
 }
 
+// Utility function to set backward input eids
+void SetBackwardInputEid(const std::vector<uint32_t>& bwd_in_dep,
+                         const std::vector<uint32_t>& bwd_out_dep,
+                         const std::vector<uint32_t>& bwd_ograd_dep,
+                         const std::vector<nnvm::NodeEntry>& ograd_entries,
+                         const nnvm::IndexedGraph& idx,
+                         std::vector<uint32_t> *bwd_input_eid) {
+  for (const auto& i : bwd_ograd_dep) {
+    auto eid = idx.entry_id(ograd_entries[i]);
+    bwd_input_eid->push_back(eid);
+  }
+  for (const auto& i : bwd_in_dep) {
+    auto eid = idx.entry_id(idx.input_nodes()[i], 0);
+    bwd_input_eid->push_back(eid);
+  }
+  for (const auto& i : bwd_out_dep) {
+    auto eid = idx.entry_id(idx.outputs()[i]);
+    bwd_input_eid->push_back(eid);
+  }
+}
+
 bool CachedOp::SetBackwardGraph(
     GraphInfo* info,
     const std::vector<OpReqType>& reqs,
@@ -356,18 +368,8 @@ bool CachedOp::SetBackwardGraph(
 
   if (info->bwd_input_eid.size() != inputs.size()) {
     info->bwd_input_eid.clear();
-    for (const auto& i : bwd_ograd_dep_) {
-      auto eid = idx.entry_id(ograd_entries_[i]);
-      info->bwd_input_eid.push_back(eid);
-    }
-    for (const auto& i : bwd_in_dep_) {
-      auto eid = idx.entry_id(idx.input_nodes()[i], 0);
-      info->bwd_input_eid.push_back(eid);
-    }
-    for (const auto& i : bwd_out_dep_) {
-      auto eid = idx.entry_id(idx.outputs()[i]);
-      info->bwd_input_eid.push_back(eid);
-    }
+    SetBackwardInputEid(bwd_in_dep_, bwd_out_dep_, bwd_ograd_dep_,
+                        ograd_entries_, idx, &info->bwd_input_eid);
     CHECK_EQ(inputs.size(), info->bwd_input_eid.size());
   }
 
@@ -1019,6 +1021,79 @@ void CachedOp::Backward(
   Engine::Get()->set_bulk_size(prev_bulk_size);
 }
 
+bool CachedOp::ForwardStorageType(const nnvm::NodeAttrs& attrs,
+                                  const int dev_mask,
+                                  DispatchMode* dispatch_mode,
+                                  std::vector<int> *in_attrs,
+                                  std::vector<int> *out_attrs) {
+  using namespace imperative;
+  nnvm::Graph g(fwd_graph_);
+  const auto& idx = g.indexed_graph();
+  const auto &outputs = idx.outputs();
+
+  // Prepare stypes and contexts based on inputs
+  StorageTypeVector storage_type_inputs;
+  storage_type_inputs.reserve(in_attrs->size());
+  for (size_t i = 0; i < in_attrs->size(); ++i) {
+    storage_type_inputs.emplace_back(in_attrs->at(i));
+  }
+  exec::DevMaskVector dev_masks(idx.num_nodes(), dev_mask);
+
+  // Forward graph storage type inference
+  CheckAndInferStorageType(&g, std::move(dev_masks), std::move(storage_type_inputs), true);
+  // Retrieve result and set outputs
+  const auto& inferred_stypes = g.GetAttr<StorageTypeVector>("storage_type");
+  for (size_t i = 0; i < out_attrs->size(); i++) {
+    const auto eid = idx.entry_id(outputs[i]);
+    STORAGE_TYPE_ASSIGN_CHECK(*out_attrs, i, inferred_stypes[eid]);
+  }
+  DISPATCH_MODE_ASSIGN_CHECK(dispatch_mode, 0, DispatchMode::kFComputeEx);
+  return true;
+}
+
+bool CachedOp::BackwardStorageType(const nnvm::NodeAttrs& attrs,
+                                   const int dev_mask,
+                                   DispatchMode* dispatch_mode,
+                                   std::vector<int> *in_attrs,
+                                   std::vector<int> *out_attrs) {
+  using namespace imperative;
+  nnvm::Graph g(full_graph_);
+  const auto& idx = g.indexed_graph();
+  const auto &outputs = idx.outputs();
+  const size_t num_forward_outputs = fwd_graph_.outputs.size();
+  CHECK_EQ(outputs.size(), num_forward_outputs + out_attrs->size());
+
+  // Construct bwd_input_eid
+  std::vector<uint32_t> bwd_input_eid;
+  SetBackwardInputEid(bwd_in_dep_, bwd_out_dep_, bwd_ograd_dep_,
+                      ograd_entries_, idx, &bwd_input_eid);
+  CHECK_EQ(in_attrs->size(), bwd_input_eid.size());
+
+  // Prepare stypes and contexts based on inputs
+  StorageTypeVector stypes(idx.num_node_entries(), -1);
+  for (size_t i = 0; i < in_attrs->size(); ++i) {
+    stypes[bwd_input_eid[i]] = in_attrs->at(i);
+  }
+  // Some out_attrs are known ahead of time (e.g. the grad stype is given by the user).
+  // Prepare these before invoking storage type inference on the subgraph.
+  for (size_t i = 0; i < out_attrs->size(); i++) {
+    const auto eid = idx.entry_id(outputs[i + num_forward_outputs]);
+    stypes[eid] = out_attrs->at(i);
+  }
+  exec::DevMaskVector dev_masks(idx.num_nodes(), dev_mask);
+
+  // Full graph storage type inference
+  CheckAndInferStorageType(&g, std::move(dev_masks), std::move(stypes), false);
+  // Retrieve result and set outputs
+  const auto& inferred_stypes = g.GetAttr<StorageTypeVector>("storage_type");
+  for (size_t i = 0; i < out_attrs->size(); i++) {
+    const auto eid = idx.entry_id(outputs[i + num_forward_outputs]);
+    STORAGE_TYPE_ASSIGN_CHECK(*out_attrs, i, inferred_stypes[eid]);
+  }
+  DISPATCH_MODE_ASSIGN_CHECK(dispatch_mode, 0, DispatchMode::kFComputeEx);
+  return true;
+}
+
 
 NNVM_REGISTER_OP(_CachedOp)
 .set_num_inputs([](const NodeAttrs& attrs) {
@@ -1029,6 +1104,14 @@ NNVM_REGISTER_OP(_CachedOp)
     const CachedOpPtr& op = nnvm::get<CachedOpPtr>(attrs.parsed);
     return op->num_outputs();
   })
+.set_attr<FInferStorageType>("FInferStorageType", [](const nnvm::NodeAttrs& attrs,
+                                                     const int dev_mask,
+                                                     DispatchMode* dispatch_mode,
+                                                     std::vector<int> *in_attrs,
+                                                     std::vector<int> *out_attrs) {
+    const CachedOpPtr& op = nnvm::get<CachedOpPtr>(attrs.parsed);
+    return op->ForwardStorageType(attrs, dev_mask, dispatch_mode, in_attrs, out_attrs);
+  })
 .set_attr<nnvm::FGradient>("FGradient",
   [](const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
     const CachedOpPtr& op = nnvm::get<CachedOpPtr>(n->attrs.parsed);
@@ -1044,6 +1127,14 @@ NNVM_REGISTER_OP(_backward_CachedOp)
     const CachedOpPtr& op = nnvm::get<CachedOpPtr>(attrs.parsed);
     return op->num_inputs() - op->mutable_input_nodes().size();
   })
+.set_attr<FInferStorageType>("FInferStorageType", [](const nnvm::NodeAttrs& attrs,
+                                                     const int dev_mask,
+                                                     DispatchMode* dispatch_mode,
+                                                     std::vector<int> *in_attrs,
+                                                     std::vector<int> *out_attrs) {
+    const CachedOpPtr& op = nnvm::get<CachedOpPtr>(attrs.parsed);
+    return op->BackwardStorageType(attrs, dev_mask, dispatch_mode, in_attrs, out_attrs);
+  })
 .set_attr<bool>("TIsLayerOpBackward", true)
 .set_attr<bool>("TIsBackward", true);
 
diff --git a/src/imperative/cached_op.h b/src/imperative/cached_op.h
index 60a40c5e4a5..6b94c67a94e 100644
--- a/src/imperative/cached_op.h
+++ b/src/imperative/cached_op.h
@@ -71,13 +71,13 @@ class CachedOp {
       const nnvm::Symbol& sym,
       const std::vector<std::pair<std::string, std::string> >& flags);
   ~CachedOp();
-  uint32_t num_inputs() {
+  uint32_t num_inputs() const {
     return fwd_graph_.indexed_graph().input_nodes().size();
   }
-  uint32_t num_outputs() {
+  uint32_t num_outputs() const {
     return fwd_graph_.outputs.size();
   }
-  uint32_t num_backward_inputs() {
+  uint32_t num_backward_inputs() const {
     return bwd_ograd_dep_.size() + bwd_in_dep_.size() + bwd_out_dep_.size();
   }
   std::vector<bool>& save_inputs() {
@@ -86,12 +86,12 @@ class CachedOp {
   std::vector<bool>& save_outputs() {
     return save_outputs_;
   }
-  const std::unordered_set<uint32_t>& mutable_input_nodes() {
+  const std::unordered_set<uint32_t>& mutable_input_nodes() const {
     return fwd_graph_.indexed_graph().mutable_input_nodes();
   }
   std::vector<nnvm::NodeEntry> Gradient(
       const nnvm::NodePtr& node,
-      const std::vector<nnvm::NodeEntry>& ograds);
+      const std::vector<nnvm::NodeEntry>& ograds) const;
   void Forward(
       const std::shared_ptr<CachedOp>& op_ptr,
       const std::vector<NDArray*>& inputs,
@@ -102,6 +102,20 @@ class CachedOp {
       const std::vector<NDArray*>& inputs,
       const std::vector<OpReqType>& reqs,
       const std::vector<NDArray*>& outputs);
+  // forward storage type inference
+  bool ForwardStorageType(
+      const nnvm::NodeAttrs& attrs,
+      const int dev_mask,
+      DispatchMode* dispatch_mode,
+      std::vector<int> *in_attrs,
+      std::vector<int> *out_attrs);
+  // backward storage type inference
+  bool BackwardStorageType(
+      const nnvm::NodeAttrs& attrs,
+      const int dev_mask,
+      DispatchMode* dispatch_mode,
+      std::vector<int> *in_attrs,
+      std::vector<int> *out_attrs);
 
  private:
   struct GraphInfo;
diff --git a/src/imperative/imperative_utils.h b/src/imperative/imperative_utils.h
index 726531d0299..faff5f173fe 100644
--- a/src/imperative/imperative_utils.h
+++ b/src/imperative/imperative_utils.h
@@ -668,7 +668,6 @@ inline bool CheckAndInferStorageType(nnvm::Graph* p_g, exec::DevMaskVector&& dev
     g.attrs["storage_type"] = std::make_shared<dmlc::any>(std::move(storage_types));
     g = exec::InferStorageType(std::move(g));
   }
-
   CHECK_EQ(g.GetAttr<size_t>("storage_type_num_unknown_nodes"), 0U);
   return false;
 }
diff --git a/src/operator/operator_common.h b/src/operator/operator_common.h
index 0a9cd08db81..02130eb32e5 100644
--- a/src/operator/operator_common.h
+++ b/src/operator/operator_common.h
@@ -256,7 +256,7 @@ inline bool dispatch_mode_assign(DispatchMode *y, const DispatchMode& x) {
  */
 #define STORAGE_TYPE_ASSIGN_CHECK(type_array, index, type)                  \
   {                                                                         \
-    if (!type_assign(&(type_array)[index], type)) {                         \
+    if (!::mxnet::op::type_assign(&(type_array)[index], type)) {            \
       std::ostringstream os;                                                \
       os << "Storage type inconsistent, Provided = "                        \
          << common::stype_string((type_array)[index]) << ','                \
@@ -274,7 +274,7 @@ inline bool dispatch_mode_assign(DispatchMode *y, const DispatchMode& x) {
  */
 #define DISPATCH_MODE_ASSIGN_CHECK(type_array, index, type)                 \
   {                                                                         \
-    if (!dispatch_mode_assign(&(type_array)[index], type)) {                \
+    if (!::mxnet::op::dispatch_mode_assign(&(type_array)[index], type)) {   \
       std::ostringstream os;                                                \
       os << "Dispatch mode inconsistent, Provided = "                       \
          << common::dispatch_mode_string((type_array)[index]) << ','        \
diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py
index 6fafb3671ff..cd3cc685bdd 100644
--- a/tests/python/unittest/test_gluon.py
+++ b/tests/python/unittest/test_gluon.py
@@ -1293,6 +1293,63 @@ def test_legacy_save_params():
     model.load_params('test.params', ctx=mx.cpu())
 
 
+@with_seed()
+def test_sparse_hybrid_block_grad():
+    class Embedding(mx.gluon.HybridBlock):
+        def __init__(self, num_tokens, embedding_size):
+            super(Embedding, self).__init__()
+            self.num_tokens = num_tokens
+
+            with self.name_scope():
+                self.embedding = mx.gluon.nn.Embedding(
+                    num_tokens, embedding_size, sparse_grad=True)
+
+        def hybrid_forward(self, F, words):
+            emb = self.embedding(words)
+            return emb + F.ones_like(emb)
+
+    embedding = Embedding(20, 3)
+    embedding.initialize()
+    embedding.hybridize()
+
+    with mx.autograd.record():
+        emb0 = embedding(mx.nd.arange(10)).sum()
+        emb1 = embedding(mx.nd.arange(10)).sum()
+        loss = emb0 + emb1
+    loss.backward()
+    grad = embedding.embedding.weight.grad().asnumpy()
+    assert (grad[:10] == 2).all()
+    assert (grad[10:] == 0).all()
+
+@with_seed()
+def test_sparse_hybrid_block():
+    class Linear(mx.gluon.HybridBlock):
+        def __init__(self, units):
+            super(Linear, self).__init__()
+            with self.name_scope():
+                self.w = self.params.get('w', shape=(units, units))
+
+        def hybrid_forward(self, F, x, w):
+            return F.dot(x, w)
+
+    class SparseBlock(mx.gluon.HybridBlock):
+        def __init__(self, units):
+            super(SparseBlock, self).__init__()
+            with self.name_scope():
+                self.net = Linear(units)
+
+        def hybrid_forward(self, F, x):
+            return self.net(x) * x
+
+    block = SparseBlock(2)
+    block.initialize()
+    block.hybridize()
+    x = mx.nd.ones((2,2)).tostype('csr')
+    with mx.autograd.record():
+        z = block(x) + block(x)
+    z.backward()
+    assert (block.net.w.grad().asnumpy() == 4).all()
+
 def test_hybrid_static_memory_recording():
     net = gluon.model_zoo.vision.get_resnet(
         1, 18, pretrained=True, ctx=mx.context.current_context())
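
For reference, the user-visible effect of this change is that storage type
information now propagates through the cached subgraph when a Gluon block is
hybridized. Below is a minimal sketch of the expected behavior (assuming a
build with this patch applied; it mirrors the test_sparse_hybrid_block_grad
test added above):

    import mxnet as mx

    # sparse_grad=True makes the embedding weight's gradient row_sparse
    embedding = mx.gluon.nn.Embedding(20, 3, sparse_grad=True)
    embedding.initialize()
    embedding.hybridize()  # forward/backward now run through _CachedOp

    with mx.autograd.record():
        loss = embedding(mx.nd.arange(10)).sum()
    loss.backward()

    # With subgraph storage type inference in place, the gradient keeps its
    # row_sparse storage type instead of falling back to dense.
    print(embedding.weight.grad().stype)  # expected: 'row_sparse'

Note that BackwardStorageType runs inference over the full (forward plus
backward) graph, seeding entries whose stypes are already known, such as a
user-provided gradient stype, before inferring the rest of the subgraph.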


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services