This is an automated email from the ASF dual-hosted git repository.

zhasheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 19ac41d  Manually check node existence in CachedOp (#11545)
19ac41d is described below

commit 19ac41d4d7fc0666699c59831c688bc38f726eeb
Author: Junru Shao <[email protected]>
AuthorDate: Thu Jul 5 12:31:31 2018 -0700

    Manually check node existence in CachedOp (#11545)
    
    * Manually check node existence in CachedOp
    
    * Fix lint
    
    * Trigger CI
    
    * Improve readability, replace `numeric_limits::max` with `kEidNotExist`
    
    * Add testcase
    
    * Trigger CI
    
    * Remove commented lines in unittests
    
    * Trigger CI
---
 src/imperative/cached_op.cc         | 28 +++++++++++++++++++++++++---
 tests/python/unittest/test_gluon.py | 15 +++++++++++++++
 2 files changed, 40 insertions(+), 3 deletions(-)

diff --git a/src/imperative/cached_op.cc b/src/imperative/cached_op.cc
index 5e48c5a..d4da99e 100644
--- a/src/imperative/cached_op.cc
+++ b/src/imperative/cached_op.cc
@@ -29,6 +29,8 @@ namespace mxnet {
 
 DMLC_REGISTER_PARAMETER(CachedOpConfig);
 
+constexpr uint32_t kEidNotExist = std::numeric_limits<uint32_t>::max();
+
 struct CachedOp::GraphInfo {
   nnvm::Graph fwd_graph;
   nnvm::Graph full_graph;
@@ -327,8 +329,12 @@ void SetBackwardInputEid(const std::vector<uint32_t>& bwd_in_dep,
                          const nnvm::IndexedGraph& idx,
                          std::vector<uint32_t> *bwd_input_eid) {
   for (const auto& i : bwd_ograd_dep) {
-    auto eid = idx.entry_id(ograd_entries[i]);
-    bwd_input_eid->push_back(eid);
+    auto ograd = ograd_entries[i];
+    if (idx.exist(ograd.node.get())) {
+      bwd_input_eid->push_back(idx.entry_id(ograd));
+    } else {
+      bwd_input_eid->push_back(kEidNotExist);
+    }
   }
   for (const auto& i : bwd_in_dep) {
     auto eid = idx.entry_id(idx.input_nodes()[i], 0);
@@ -381,7 +387,11 @@ bool CachedOp::SetBackwardGraph(
     for (size_t i = num_forward_nodes; i < idx.num_nodes(); ++i) {
       for (const auto& j : idx[i].inputs) ++ref_count[idx.entry_id(j)];
     }
-    for (size_t i = 0; i < inputs.size(); ++i) ++ref_count[info->bwd_input_eid[i]];
+    for (size_t i = 0; i < inputs.size(); ++i) {
+      if (info->bwd_input_eid[i] != kEidNotExist) {
+        ++ref_count[info->bwd_input_eid[i]];
+      }
+    }
     for (const auto& i : idx.outputs()) ++ref_count[idx.entry_id(i)];
     g.attrs["backward_ref_count"] = std::make_shared<dmlc::any>(std::move(ref_count));
   }
@@ -394,6 +404,9 @@ bool CachedOp::SetBackwardGraph(
   stypes.resize(idx.num_node_entries(), -1);
 
   for (size_t i = 0; i < inputs.size(); ++i) {
+    if (info->bwd_input_eid[i] == kEidNotExist) {
+      continue;
+    }
     shapes[info->bwd_input_eid[i]] = inputs[i]->shape();
     dtypes[info->bwd_input_eid[i]] = inputs[i]->dtype();
     stypes[info->bwd_input_eid[i]] = inputs[i]->storage_type();
@@ -896,6 +909,9 @@ void CachedOp::DynamicBackward(
   arrays.reserve(buff.size());
   for (size_t i = 0; i < buff.size(); ++i) arrays.push_back(&buff[i]);
   for (size_t i = 0; i < inputs.size(); ++i) {
+    if (runtime.info.bwd_input_eid[i] == kEidNotExist) {
+      continue;
+    }
     arrays[runtime.info.bwd_input_eid[i]] = inputs[i];
   }
   for (size_t i = 0, j = num_forward_outputs; i < reqs.size(); ++i) {
@@ -971,6 +987,9 @@ void CachedOp::StaticBackward(
   auto arrays = state.arrays;
   for (size_t i = 0; i < state.info.bwd_input_eid.size(); ++i) {
     auto eid = state.info.bwd_input_eid[i];
+    if (eid == kEidNotExist) {
+      continue;
+    }
     if (state.dynamic_entries[eid]) arrays[eid] = inputs[i];
   }
 
@@ -1104,6 +1123,9 @@ bool CachedOp::BackwardStorageType(const nnvm::NodeAttrs& attrs,
   // Prepare these to before invoking infer storage on the subgraph
   for (size_t i = 0; i < out_attrs->size(); i++) {
     const auto eid = idx.entry_id(outputs[i + num_forward_outputs]);
+    if (bwd_input_eid[i] == kEidNotExist) {
+      continue;
+    }
     stypes[eid] = out_attrs->at(i);
   }
   exec::DevMaskVector dev_masks(idx.num_nodes(), dev_mask);
diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py
index 43777bb..67e8b9e 100644
--- a/tests/python/unittest/test_gluon.py
+++ b/tests/python/unittest/test_gluon.py
@@ -1414,6 +1414,21 @@ def test_share_inputs_outputs():
             assert_almost_equal(out_grad.asnumpy(), d2.grad.asnumpy())
 
 
+def test_grad_graph_change():
+    class Model(mx.gluon.HybridBlock):
+        def hybrid_forward(self, F, array, index):
+            row = array.take(index)
+            return row, index
+    array = mx.nd.arange(3)
+    index = mx.nd.array([2])
+    array.attach_grad()
+    model = Model()
+    model.hybridize(inline_limit=0)
+    with mx.autograd.record(train_mode=True):
+        row, _ = model(array, index)
+    row.backward()
+
+
 if __name__ == '__main__':
     import nose
     nose.runmodule()

Reply via email to