This is an automated email from the ASF dual-hosted git repository.

jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new d6a8ca7  handle the case that inputs and outputs of a graph share NDArrays (#11436)
d6a8ca7 is described below

commit d6a8ca79b7dd8e253fe3b81a345e8f375b6a0b38
Author: Da Zheng <zhengda1...@gmail.com>
AuthorDate: Mon Jul 2 10:42:33 2018 -0700

    handle the case that inputs and outputs of a graph share NDArrays (#11436)
    
    * handle the case that inputs and outputs of a graph share NDArrays
    
    * add test.
    
    * test multiple times.
    
    * don't change the state's array list.
    
    * retrigger
---
 src/imperative/cached_op.cc         | 65 +++++++++++++++++++++++++------------
 src/imperative/cached_op.h          |  1 +
 tests/python/unittest/test_gluon.py | 50 ++++++++++++++++++++++++++++
 3 files changed, 96 insertions(+), 20 deletions(-)
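
The user-visible scenario behind this fix: when a hybridized block returns one of its inputs (or an array that aliases one), the cached graph's input and output entries point at the same NDArray, and CachedOp's static execution path used to bind both into the single array list stored in its state. Below is a minimal sketch of that scenario, assuming MXNet 1.x with Gluon; the block name is illustrative and simply mirrors the TestIOForward case added in tests/python/unittest/test_gluon.py further down.

    import mxnet as mx
    from mxnet import gluon

    # Illustrative block, not part of the patch: its output *is* its input,
    # so the cached graph's input and output entries share one NDArray.
    class PassThrough(gluon.HybridBlock):
        def hybrid_forward(self, F, x):
            return x

    net = PassThrough()
    # static_alloc/static_shape route execution through CachedOp::StaticForward,
    # the code path patched below.
    net.hybridize(inline_limit=0, static_alloc=True, static_shape=True)

    x = mx.nd.arange(10)
    for _ in range(5):  # repeated runs exercise the cached state
        y = net(x)
        assert (y.asnumpy() == x.asnumpy()).all()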

diff --git a/src/imperative/cached_op.cc b/src/imperative/cached_op.cc
index 5a3d44c..2181c5c 100644
--- a/src/imperative/cached_op.cc
+++ b/src/imperative/cached_op.cc
@@ -591,6 +591,7 @@ void CachedOp::StaticRunOps(
     const Context& default_ctx,
     const nnvm::Graph& g,
     const OpStatePtr& state_ptr,
+    const std::vector<NDArray *> &state_arrays,
     size_t start_nid,
     size_t end_nid) {
   static auto& createop = nnvm::Op::GetAttr<FCreateOpState>("FCreateOpState");
@@ -624,7 +625,7 @@ void CachedOp::StaticRunOps(
       ndinputs.clear();
       ndinputs.reserve(node.inputs.size());
       for (const auto& j : node.inputs) {
-        ndinputs.emplace_back(state.arrays[idx.entry_id(j)]);
+        ndinputs.emplace_back(state_arrays[idx.entry_id(j)]);
         CHECK(!ndinputs.back()->is_none());
       }
       ndoutputs.clear();
@@ -633,7 +634,7 @@ void CachedOp::StaticRunOps(
       req.reserve(num_outputs);
       for (size_t j = 0; j < num_outputs; ++j) {
         size_t eid = idx.entry_id(i, j);
-        ndoutputs.emplace_back(state.arrays[eid]);
+        ndoutputs.emplace_back(state_arrays[eid]);
         req.push_back(state.array_reqs[eid]);
         CHECK(req.back() == kNullOp || !ndoutputs.back()->is_none());
       }
@@ -688,25 +689,29 @@ OpStatePtr CachedOp::StaticForward(
     StaticAllocMemory(state_ptr, recording, false);
   }
 
+  // We are going to add input and output arrays to the array list.
+  // The input and output arrays should only be valid for this run,
+  // so we shouldn't modify the state's array list.
+  auto arrays = state.arrays;
   if (config_.static_shape) {
     for (auto i : config_.param_indices) {
       auto nid = idx.input_nodes()[i];
-      if (!state.arrays[idx.entry_id(nid, 0)]->IsSame(*inputs[i])) {
+      if (!arrays[idx.entry_id(nid, 0)]->IsSame(*inputs[i])) {
         match = false;
         auto ptr = &state.buff[idx.entry_id(nid, 0)];
-        CHECK_EQ(state.arrays[idx.entry_id(nid, 0)], ptr);
-        *state.arrays[idx.entry_id(nid, 0)] = *inputs[i];
+        CHECK_EQ(arrays[idx.entry_id(nid, 0)], ptr);
+        *arrays[idx.entry_id(nid, 0)] = *inputs[i];
         state.dynamic_entries[idx.entry_id(nid, 0)] = false;
       }
     }
     for (auto i : config_.data_indices) {
       auto eid = idx.entry_id(idx.input_nodes()[i], 0);
-      state.arrays[eid] = inputs[i];
+      arrays[eid] = inputs[i];
     }
   } else {
     for (size_t i = 0; i < num_inputs(); ++i) {
       auto nid = idx.input_nodes()[i];
-      state.arrays[idx.entry_id(nid, 0)] = inputs[i];
+      arrays[idx.entry_id(nid, 0)] = inputs[i];
     }
   }
 
@@ -720,13 +725,16 @@ OpStatePtr CachedOp::StaticForward(
 
   for (size_t i = 0; i < outputs.size(); ++i) {
     auto eid = idx.entry_id(idx.outputs()[i]);
-    state.arrays[eid] = outputs[i];
+    // An input and an output may share the same array.
+    if (!arrays[eid]->is_none())
+      *outputs[i] = arrays[eid]->Detach();
+    arrays[eid] = outputs[i];
     if (!outputs[i]->is_none()) continue;
     *outputs[i] = NDArray(static_cast<NDArrayStorageType>(stypes[eid]),
                           shapes[eid], default_ctx, true, dtypes[eid]);
   }
 
-  StaticRunOps(default_ctx, g, state_ptr, 0, idx.num_nodes());
+  StaticRunOps(default_ctx, g, state_ptr, arrays, 0, idx.num_nodes());
 
   return recording ? state_ptr : OpStatePtr();
 }
@@ -891,7 +899,11 @@ void CachedOp::DynamicBackward(
   }
   for (size_t i = 0, j = num_forward_outputs; i < reqs.size(); ++i) {
     if (reqs[i] == kNullOp) continue;
-    arrays[idx.entry_id(idx.outputs()[j++])] = outputs[i];
+    auto eid = idx.entry_id(idx.outputs()[j++]);
+    // An input and an output may share the same array.
+    if (!arrays[eid]->is_none())
+      *outputs[i] = arrays[eid]->Detach();
+    arrays[eid] = outputs[i];
   }
 
   // Allocate NDArrays
@@ -952,6 +964,15 @@ void CachedOp::StaticBackward(
     StaticAllocMemory(state_ptr, true, true);
   }
 
+  // We are going to add input and output arrays to the array list.
+  // The input and output arrays should only be valid for this run,
+  // so we shouldn't modify the state's array list.
+  auto arrays = state.arrays;
+  for (size_t i = 0; i < state.info.bwd_input_eid.size(); ++i) {
+    auto eid = state.info.bwd_input_eid[i];
+    if (state.dynamic_entries[eid]) arrays[eid] = inputs[i];
+  }
+
   if (config_.static_shape) {
     for (auto i : config_.param_indices) {
       const auto iter = fwd_input_to_grad_output_.find(i);
@@ -959,11 +980,14 @@ void CachedOp::StaticBackward(
       auto entry = grad_graph_.outputs[iter->second];
       if (!idx.exist(entry.node.get())) continue;
       auto eid = idx.entry_id(entry);
-      if (!state.arrays[eid]->IsSame(*outputs[iter->second]) ||
+      if (!arrays[eid]->IsSame(*outputs[iter->second]) ||
           !(state.array_reqs[eid] == reqs[iter->second])) {
         match = false;
         state.array_reqs[eid] = reqs[iter->second];
-        *state.arrays[eid] = *outputs[iter->second];
+        // An input and an output may share the same array.
+        if (!arrays[eid]->is_none())
+          *outputs[iter->second] = arrays[eid]->Detach();
+        *arrays[eid] = *outputs[iter->second];
         state.dynamic_entries[eid] = false;
       }
     }
@@ -974,7 +998,10 @@ void CachedOp::StaticBackward(
       if (!idx.exist(entry.node.get())) continue;
       auto eid = idx.entry_id(entry);
       state.array_reqs[eid] = reqs[iter->second];
-      state.arrays[eid] = outputs[iter->second];
+      // An input and an output may share the same array.
+      if (!arrays[eid]->is_none())
+        *outputs[iter->second] = arrays[eid]->Detach();
+      arrays[eid] = outputs[iter->second];
     }
   } else {
     for (size_t i = 0; i < grad_graph_.outputs.size(); ++i) {
@@ -982,7 +1009,10 @@ void CachedOp::StaticBackward(
       if (!idx.exist(entry.node.get())) continue;
       auto eid = idx.entry_id(entry);
       state.array_reqs[eid] = reqs[i];
-      state.arrays[eid] = outputs[i];
+      // An input and an output may share the same array.
+      if (!arrays[eid]->is_none())
+        *outputs[i] = arrays[eid]->Detach();
+      arrays[eid] = outputs[i];
     }
   }
 
@@ -990,12 +1020,7 @@ void CachedOp::StaticBackward(
     StaticInitExec(state_ptr, true, true);
   }
 
-  for (size_t i = 0; i < state.info.bwd_input_eid.size(); ++i) {
-    auto eid = state.info.bwd_input_eid[i];
-    if (state.dynamic_entries[eid]) state.arrays[eid] = inputs[i];
-  }
-
-  StaticRunOps(default_ctx, g, state_ptr, num_forward_nodes, idx.num_nodes());
+  StaticRunOps(default_ctx, g, state_ptr, arrays, num_forward_nodes, idx.num_nodes());
 }
 
 void CachedOp::Backward(
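
The same aliasing can occur on the backward path: DynamicBackward and StaticBackward above now detach an entry that already holds data before rebinding it to the caller's gradient array, and StaticBackward, like StaticForward, binds inputs and outputs into a per-run copy of state.arrays instead of mutating the state's list. Below is a hedged sketch of the backward scenario, mirroring the TestIOBackward case added further down (the block name is illustrative):

    import mxnet as mx
    from mxnet import autograd, gluon

    # The gradient of an elementwise add just passes the head gradient through,
    # so the backward graph's inputs and outputs can end up sharing NDArrays.
    class AddBlock(gluon.HybridBlock):
        def hybrid_forward(self, F, a, b):
            return a + b

    net = AddBlock()
    net.hybridize(inline_limit=0, static_alloc=True, static_shape=True)

    a, b = mx.nd.arange(10), mx.nd.arange(10)
    a.attach_grad()
    b.attach_grad()
    out_grad = mx.nd.random.uniform(shape=(10,))
    with autograd.record():
        out = net(a, b)
    out.backward(out_grad)
    assert (a.grad.asnumpy() == out_grad.asnumpy()).all()
    assert (b.grad.asnumpy() == out_grad.asnumpy()).all()
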
diff --git a/src/imperative/cached_op.h b/src/imperative/cached_op.h
index 6b94c67..370ef02 100644
--- a/src/imperative/cached_op.h
+++ b/src/imperative/cached_op.h
@@ -154,6 +154,7 @@ class CachedOp {
       const Context& default_ctx,
       const nnvm::Graph& g,
       const OpStatePtr& state_ptr,
+      const std::vector<NDArray *> &state_arrays,
       size_t start_nid,
       size_t end_nid);
   OpStatePtr StaticForward(
diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py
index 4c552da..43777bb 100644
--- a/tests/python/unittest/test_gluon.py
+++ b/tests/python/unittest/test_gluon.py
@@ -1364,6 +1364,56 @@ def test_hybrid_static_memory_recording():
     net(x)
 
 
+def test_share_inputs_outputs():
+    class TestIOBackward(gluon.HybridBlock):
+        def __init__(self, prefix=None, params=None):
+            super(TestIOBackward, self).__init__(prefix=prefix, params=params)
+
+        def hybrid_forward(self, F, in1, in2):
+            return in1 + in2
+
+    class TestIOForward(gluon.HybridBlock):
+        def __init__(self, prefix=None, params=None):
+            super(TestIOForward, self).__init__(prefix=prefix, params=params)
+
+        def hybrid_forward(self, F, in1):
+            return in1
+
+    d1 = mx.nd.arange(10)
+    d2 = mx.nd.arange(10)
+
+    params=[{'inline_limit':0},
+            {'inline_limit':0, 'static_alloc':True},
+            {'inline_limit':0, 'static_alloc':True, 'static_shape':True}]
+    # Test the case that inputs and outputs of a forward graph share NDArrays.
+    for param in params:
+        t = TestIOForward()
+        t.hybridize(**param)
+        for i in range(5):
+            d1.attach_grad()
+            out_grad = mx.nd.random.uniform(shape=(10))
+            res = t(d1)
+            assert_almost_equal(res.asnumpy(), d1.asnumpy())
+
+    param = deepcopy(params[2])
+    param['param_indices'] = (1)
+    param['data_indices'] = (0)
+    params.append(param)
+    # Test the case that inputs and outputs of a backward graph share NDArrays.
+    for param in params:
+        t = TestIOBackward()
+        t.hybridize(**param)
+        for i in range(5):
+            d1.attach_grad()
+            d2.attach_grad()
+            out_grad = mx.nd.random.uniform(shape=(10))
+            with mx.autograd.record():
+                res = t(d1, d2)
+            res.backward(out_grad=out_grad)
+            assert_almost_equal(out_grad.asnumpy(), d1.grad.asnumpy())
+            assert_almost_equal(out_grad.asnumpy(), d2.grad.asnumpy())
+
+
 if __name__ == '__main__':
     import nose
     nose.runmodule()
