This is an automated email from the ASF dual-hosted git repository.

jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 7a83a75  [MXNET-555] Add subgraph storage type inference to CachedOp  (#11306)
7a83a75 is described below

commit 7a83a7589e46cae1ec98d43c71d001797ad57b08
Author: Haibin Lin <[email protected]>
AuthorDate: Wed Jun 20 17:31:35 2018 -0700

    [MXNET-555] Add subgraph storage type inference to CachedOp  (#11306)
    
    * copy paste
    
    * pass unit test
    
    * remove lock
    
    * save all inputs and outputs
    
    * add one more test
    
    * update test
    
    * update backward stype inference
    
    * + fwd inference
---
 src/imperative/cached_op.cc         | 147 +++++++++++++++++++++++++++++-------
 src/imperative/cached_op.h          |  24 ++++--
 src/imperative/imperative_utils.h   |   1 -
 src/operator/operator_common.h      |   4 +-
 tests/python/unittest/test_gluon.py |  57 ++++++++++++++
 5 files changed, 197 insertions(+), 36 deletions(-)

diff --git a/src/imperative/cached_op.cc b/src/imperative/cached_op.cc
index c0e5e83..5a3d44c 100644
--- a/src/imperative/cached_op.cc
+++ b/src/imperative/cached_op.cc
@@ -22,6 +22,7 @@
 #include "./cached_op.h"
 #include "../executor/exec_pass.h"
 #include "../profiler/profiler.h"
+#include "../operator/operator_common.h"
 
 
 namespace mxnet {
@@ -95,7 +96,6 @@ CachedOp::CachedOp(
   using namespace imperative;
   static const std::vector<const Op*> zero_ops{Op::Get("zeros_like"), Op::Get("_zeros")};
   static const auto _copy = Op::Get("_copy");
-
   config_.Init(flags);
 
   if (config_.static_shape) {
@@ -204,26 +204,17 @@ CachedOp::CachedOp(
     size_t num_forward_outputs = num_outputs();
     for (uint32_t i = 0; i < ograd_entries_.size(); ++i) {
       if (!idx.exist(ograd_entries_[i].node.get())) continue;
-      auto eid = idx.entry_id(ograd_entries_[i]);
-      if (ref_count[eid] > 0) {
-        bwd_ograd_dep_.push_back(i);
-      }
+      bwd_ograd_dep_.push_back(i);
     }
     save_inputs_.resize(num_forward_inputs, false);
     for (uint32_t i = 0; i < num_forward_inputs; ++i) {
-      auto eid = idx.entry_id(idx.input_nodes()[i], 0);
-      if (ref_count[eid] > 0) {
-        save_inputs_[i] = true;
-        bwd_in_dep_.push_back(i);
-      }
+      save_inputs_[i] = true;
+      bwd_in_dep_.push_back(i);
     }
     save_outputs_.resize(idx.outputs().size(), false);
     for (uint32_t i = 0; i < num_forward_outputs; ++i) {
-      auto eid = idx.entry_id(idx.outputs()[i]);
-      if (ref_count[eid] > 0) {
-        save_outputs_[i] = true;
-        bwd_out_dep_.push_back(i);
-      }
+      save_outputs_[i] = true;
+      bwd_out_dep_.push_back(i);
     }
   }
 }
@@ -233,7 +224,7 @@ CachedOp::~CachedOp() {
 
 std::vector<nnvm::NodeEntry> CachedOp::Gradient(
     const nnvm::NodePtr& node,
-    const std::vector<nnvm::NodeEntry>& ograds) {
+    const std::vector<nnvm::NodeEntry>& ograds) const {
   using namespace nnvm;
   static const auto _backward_CachedOp = Op::Get("_backward_CachedOp");
   static const auto _NoGrad = Op::Get("_NoGradient");
@@ -328,6 +319,27 @@ bool CachedOp::SetForwardGraph(
   return false;
 }
 
+// Utility function to set backward input eids
+void SetBackwardInputEid(const std::vector<uint32_t>& bwd_in_dep,
+                         const std::vector<uint32_t>& bwd_out_dep,
+                         const std::vector<uint32_t>& bwd_ograd_dep,
+                         const std::vector<nnvm::NodeEntry>& ograd_entries,
+                         const nnvm::IndexedGraph& idx,
+                         std::vector<uint32_t> *bwd_input_eid) {
+  for (const auto& i : bwd_ograd_dep) {
+    auto eid = idx.entry_id(ograd_entries[i]);
+    bwd_input_eid->push_back(eid);
+  }
+  for (const auto& i : bwd_in_dep) {
+    auto eid = idx.entry_id(idx.input_nodes()[i], 0);
+    bwd_input_eid->push_back(eid);
+  }
+  for (const auto& i : bwd_out_dep) {
+    auto eid = idx.entry_id(idx.outputs()[i]);
+    bwd_input_eid->push_back(eid);
+  }
+}
+
 bool CachedOp::SetBackwardGraph(
     GraphInfo* info,
     const std::vector<OpReqType>& reqs,
@@ -356,18 +368,8 @@ bool CachedOp::SetBackwardGraph(
 
   if (info->bwd_input_eid.size() != inputs.size()) {
     info->bwd_input_eid.clear();
-    for (const auto& i : bwd_ograd_dep_) {
-      auto eid = idx.entry_id(ograd_entries_[i]);
-      info->bwd_input_eid.push_back(eid);
-    }
-    for (const auto& i : bwd_in_dep_) {
-      auto eid = idx.entry_id(idx.input_nodes()[i], 0);
-      info->bwd_input_eid.push_back(eid);
-    }
-    for (const auto& i : bwd_out_dep_) {
-      auto eid = idx.entry_id(idx.outputs()[i]);
-      info->bwd_input_eid.push_back(eid);
-    }
+    SetBackwardInputEid(bwd_in_dep_, bwd_out_dep_, bwd_ograd_dep_,
+                        ograd_entries_, idx, &info->bwd_input_eid);
     CHECK_EQ(inputs.size(), info->bwd_input_eid.size());
   }
 
@@ -1019,6 +1021,79 @@ void CachedOp::Backward(
   Engine::Get()->set_bulk_size(prev_bulk_size);
 }
 
+bool CachedOp::ForwardStorageType(const nnvm::NodeAttrs& attrs,
+                                  const int dev_mask,
+                                  DispatchMode* dispatch_mode,
+                                  std::vector<int> *in_attrs,
+                                  std::vector<int> *out_attrs) {
+  using namespace imperative;
+  nnvm::Graph g(fwd_graph_);
+  const auto& idx = g.indexed_graph();
+  const auto &outputs = idx.outputs();
+
+  // Prepare stypes and contexts based on inputs
+  StorageTypeVector storage_type_inputs;
+  storage_type_inputs.reserve(in_attrs->size());
+  for (size_t i = 0; i < in_attrs->size(); ++i) {
+    storage_type_inputs.emplace_back(in_attrs->at(i));
+  }
+  exec::DevMaskVector dev_masks(idx.num_nodes(), dev_mask);
+
+  // Forward graph storage type inference
+  CheckAndInferStorageType(&g, std::move(dev_masks), std::move(storage_type_inputs), true);
+  // Retrieve result and set outputs
+  const auto& inferred_stypes = g.GetAttr<StorageTypeVector>("storage_type");
+  for (size_t i = 0; i < out_attrs->size(); i++) {
+    const auto eid = idx.entry_id(outputs[i]);
+    STORAGE_TYPE_ASSIGN_CHECK(*out_attrs, i, inferred_stypes[eid]);
+  }
+  DISPATCH_MODE_ASSIGN_CHECK(dispatch_mode, 0, DispatchMode::kFComputeEx);
+  return true;
+}
+
+bool CachedOp::BackwardStorageType(const nnvm::NodeAttrs& attrs,
+                                   const int dev_mask,
+                                   DispatchMode* dispatch_mode,
+                                   std::vector<int> *in_attrs,
+                                   std::vector<int> *out_attrs) {
+  using namespace imperative;
+  nnvm::Graph g(full_graph_);
+  const auto& idx = g.indexed_graph();
+  const auto &outputs = idx.outputs();
+  const size_t num_forward_outputs = fwd_graph_.outputs.size();
+  CHECK_EQ(outputs.size(), num_forward_outputs + out_attrs->size());
+
+  // Construct bwd_input_eid
+  std::vector<uint32_t> bwd_input_eid;
+  SetBackwardInputEid(bwd_in_dep_, bwd_out_dep_, bwd_ograd_dep_,
+                      ograd_entries_, idx, &bwd_input_eid);
+  CHECK_EQ(in_attrs->size(), bwd_input_eid.size());
+
+  // Prepare stypes and contexts based on inputs
+  StorageTypeVector stypes(idx.num_node_entries(), -1);
+  for (size_t i = 0; i < in_attrs->size(); ++i) {
+    stypes[bwd_input_eid[i]] = in_attrs->at(i);
+  }
+  // Some out_attr is known ahead of time (e.g. the grad stype is given by users).
+  // Prepare these to before invoking infer storage on the subgraph
+  for (size_t i = 0; i < out_attrs->size(); i++) {
+    const auto eid = idx.entry_id(outputs[i + num_forward_outputs]);
+    stypes[eid] = out_attrs->at(i);
+  }
+  exec::DevMaskVector dev_masks(idx.num_nodes(), dev_mask);
+
+  // Full graph storage type inference
+  CheckAndInferStorageType(&g, std::move(dev_masks), std::move(stypes), false);
+  // Retrieve result and set outputs
+  const auto& inferred_stypes = g.GetAttr<StorageTypeVector>("storage_type");
+  for (size_t i = 0; i < out_attrs->size(); i++) {
+    const auto eid = idx.entry_id(outputs[i + num_forward_outputs]);
+    STORAGE_TYPE_ASSIGN_CHECK(*out_attrs, i, inferred_stypes[eid]);
+  }
+  DISPATCH_MODE_ASSIGN_CHECK(dispatch_mode, 0, DispatchMode::kFComputeEx);
+  return true;
+}
+
 
 NNVM_REGISTER_OP(_CachedOp)
 .set_num_inputs([](const NodeAttrs& attrs) {
@@ -1029,6 +1104,14 @@ NNVM_REGISTER_OP(_CachedOp)
     const CachedOpPtr& op = nnvm::get<CachedOpPtr>(attrs.parsed);
     return op->num_outputs();
   })
+.set_attr<FInferStorageType>("FInferStorageType", [](const nnvm::NodeAttrs& attrs,
+                                                     const int dev_mask,
+                                                     DispatchMode* dispatch_mode,
+                                                     std::vector<int> *in_attrs,
+                                                     std::vector<int> *out_attrs) {
+    const CachedOpPtr& op = nnvm::get<CachedOpPtr>(attrs.parsed);
+    return op->ForwardStorageType(attrs, dev_mask, dispatch_mode, in_attrs, out_attrs);
+  })
 .set_attr<nnvm::FGradient>("FGradient",
   [](const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
     const CachedOpPtr& op = nnvm::get<CachedOpPtr>(n->attrs.parsed);
@@ -1044,6 +1127,14 @@ NNVM_REGISTER_OP(_backward_CachedOp)
     const CachedOpPtr& op = nnvm::get<CachedOpPtr>(attrs.parsed);
     return op->num_inputs() - op->mutable_input_nodes().size();
   })
+.set_attr<FInferStorageType>("FInferStorageType", [](const nnvm::NodeAttrs& attrs,
+                                                     const int dev_mask,
+                                                     DispatchMode* dispatch_mode,
+                                                     std::vector<int> *in_attrs,
+                                                     std::vector<int> *out_attrs) {
+    const CachedOpPtr& op = nnvm::get<CachedOpPtr>(attrs.parsed);
+    return op->BackwardStorageType(attrs, dev_mask, dispatch_mode, in_attrs, out_attrs);
+  })
 .set_attr<bool>("TIsLayerOpBackward", true)
 .set_attr<bool>("TIsBackward", true);
 
diff --git a/src/imperative/cached_op.h b/src/imperative/cached_op.h
index 60a40c5..6b94c67 100644
--- a/src/imperative/cached_op.h
+++ b/src/imperative/cached_op.h
@@ -71,13 +71,13 @@ class CachedOp {
       const nnvm::Symbol& sym,
       const std::vector<std::pair<std::string, std::string> >& flags);
   ~CachedOp();
-  uint32_t num_inputs() {
+  uint32_t num_inputs() const {
     return fwd_graph_.indexed_graph().input_nodes().size();
   }
-  uint32_t num_outputs() {
+  uint32_t num_outputs() const {
     return fwd_graph_.outputs.size();
   }
-  uint32_t num_backward_inputs() {
+  uint32_t num_backward_inputs() const {
     return bwd_ograd_dep_.size() + bwd_in_dep_.size() + bwd_out_dep_.size();
   }
   std::vector<bool>& save_inputs() {
@@ -86,12 +86,12 @@ class CachedOp {
   std::vector<bool>& save_outputs() {
     return save_outputs_;
   }
-  const std::unordered_set<uint32_t>& mutable_input_nodes() {
+  const std::unordered_set<uint32_t>& mutable_input_nodes() const {
     return fwd_graph_.indexed_graph().mutable_input_nodes();
   }
   std::vector<nnvm::NodeEntry> Gradient(
       const nnvm::NodePtr& node,
-      const std::vector<nnvm::NodeEntry>& ograds);
+      const std::vector<nnvm::NodeEntry>& ograds) const;
   void Forward(
       const std::shared_ptr<CachedOp>& op_ptr,
       const std::vector<NDArray*>& inputs,
@@ -102,6 +102,20 @@ class CachedOp {
       const std::vector<NDArray*>& inputs,
       const std::vector<OpReqType>& reqs,
       const std::vector<NDArray*>& outputs);
+  // forward storage type inference
+  bool ForwardStorageType(
+      const nnvm::NodeAttrs& attrs,
+      const int dev_mask,
+      DispatchMode* dispatch_mode,
+      std::vector<int> *in_attrs,
+      std::vector<int> *out_attrs);
+  // backward storage type inference
+  bool BackwardStorageType(
+      const nnvm::NodeAttrs& attrs,
+      const int dev_mask,
+      DispatchMode* dispatch_mode,
+      std::vector<int> *in_attrs,
+      std::vector<int> *out_attrs);
 
  private:
   struct GraphInfo;
diff --git a/src/imperative/imperative_utils.h b/src/imperative/imperative_utils.h
index 726531d..faff5f1 100644
--- a/src/imperative/imperative_utils.h
+++ b/src/imperative/imperative_utils.h
@@ -668,7 +668,6 @@ inline bool CheckAndInferStorageType(nnvm::Graph* p_g, exec::DevMaskVector&& dev
     g.attrs["storage_type"] = std::make_shared<dmlc::any>(std::move(storage_types));
     g = exec::InferStorageType(std::move(g));
   }
-
   CHECK_EQ(g.GetAttr<size_t>("storage_type_num_unknown_nodes"), 0U);
   return false;
 }
diff --git a/src/operator/operator_common.h b/src/operator/operator_common.h
index 0a9cd08..02130eb 100644
--- a/src/operator/operator_common.h
+++ b/src/operator/operator_common.h
@@ -256,7 +256,7 @@ inline bool dispatch_mode_assign(DispatchMode *y, const DispatchMode& x) {
  */
 #define STORAGE_TYPE_ASSIGN_CHECK(type_array, index, type)                  \
   {                                                                         \
-    if (!type_assign(&(type_array)[index], type)) {                         \
+    if (!::mxnet::op::type_assign(&(type_array)[index], type)) {            \
       std::ostringstream os;                                                \
       os << "Storage type inconsistent, Provided = "                        \
          << common::stype_string((type_array)[index]) << ','                \
@@ -274,7 +274,7 @@ inline bool dispatch_mode_assign(DispatchMode *y, const DispatchMode& x) {
  */
 #define DISPATCH_MODE_ASSIGN_CHECK(type_array, index, type)                 \
   {                                                                         \
-    if (!dispatch_mode_assign(&(type_array)[index], type)) {                \
+    if (!::mxnet::op::dispatch_mode_assign(&(type_array)[index], type)) {   \
       std::ostringstream os;                                                \
       os << "Dispatch mode inconsistent, Provided = "                       \
          << common::dispatch_mode_string((type_array)[index]) << ','        \
diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py
index 6fafb36..cd3cc68 100644
--- a/tests/python/unittest/test_gluon.py
+++ b/tests/python/unittest/test_gluon.py
@@ -1293,6 +1293,63 @@ def test_legacy_save_params():
     model.load_params('test.params', ctx=mx.cpu())
 
 
+@with_seed()
+def test_sparse_hybrid_block_grad():
+    class Embedding(mx.gluon.HybridBlock):
+        def __init__(self, num_tokens, embedding_size):
+            super(Embedding, self).__init__()
+            self.num_tokens = num_tokens
+
+            with self.name_scope():
+                self.embedding = mx.gluon.nn.Embedding(
+                    num_tokens, embedding_size, sparse_grad=True)
+
+        def hybrid_forward(self, F, words):
+            emb = self.embedding(words)
+            return emb + F.ones_like(emb)
+
+    embedding = Embedding(20, 3)
+    embedding.initialize()
+    embedding.hybridize()
+
+    with mx.autograd.record():
+        emb0 = embedding(mx.nd.arange(10)).sum()
+        emb1 = embedding(mx.nd.arange(10)).sum()
+        loss = emb0 + emb1
+    loss.backward()
+    grad = embedding.embedding.weight.grad().asnumpy()
+    assert (grad[:10] == 2).all()
+    assert (grad[10:] == 0).all()
+
+@with_seed()
+def test_sparse_hybrid_block():
+    class Linear(mx.gluon.HybridBlock):
+        def __init__(self, units):
+            super(Linear, self).__init__()
+            with self.name_scope():
+                self.w = self.params.get('w', shape=(units, units))
+
+        def hybrid_forward(self, F, x, w):
+            return F.dot(x, w)
+
+    class SparseBlock(mx.gluon.HybridBlock):
+        def __init__(self, units):
+            super(SparseBlock, self).__init__()
+            with self.name_scope():
+                self.net = Linear(units)
+
+        def hybrid_forward(self, F, x):
+            return self.net(x) * x
+
+    block = SparseBlock(2)
+    block.initialize()
+    block.hybridize()
+    x = mx.nd.ones((2,2)).tostype('csr')
+    with mx.autograd.record():
+        z = block(x) + block(x)
+    z.backward()
+    assert (block.net.w.grad().asnumpy() == 4).all()
+
 def test_hybrid_static_memory_recording():
     net = gluon.model_zoo.vision.get_resnet(
         1, 18, pretrained=True, ctx=mx.context.current_context())

Reply via email to