This is an automated email from the ASF dual-hosted git repository.

zhengda pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new d577b6f  [MXNET-1352] Allow dynamic shape in while_loop and if conditionals (#14393)
d577b6f is described below

commit d577b6ff8256c9f9bb809d6144686a4e8273581a
Author: Junru Shao <[email protected]>
AuthorDate: Sun May 12 16:01:51 2019 -0700

    [MXNET-1352] Allow dynamic shape in while_loop and if conditionals (#14393)
    
    * Initial commit
    
    * Rebase
    
    * WIP for fixing rebase issues
    
    * WIP for fixing rebase issues
    
    * fix wip
    
    * wip fix
    
    * wip fix
    
    * wip fix
    
    * wip fix
    
    * wip fix
    
    * wip fix
    
    * should be good to go
    
    * wip remove debug info
    
    * wip remove debug info
    
    * linter
    
    * linter
    
    * Retrigger
    
    * Address comments from Da
---
 include/mxnet/ndarray.h                            |   4 +-
 python/mxnet/executor.py                           |   2 +-
 src/executor/graph_executor.cc                     | 215 +++++++++++++++++--
 src/executor/graph_executor.h                      |   2 +
 src/nnvm/plan_memory.cc                            |   2 +-
 src/operator/control_flow.cc                       | 238 +++------------------
 tests/python/unittest/test_contrib_control_flow.py |  30 ++-
 7 files changed, 255 insertions(+), 238 deletions(-)
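
The patch lets the graph executor run graphs whose output shapes are only known at run time, which is what while_loop and cond need when a subgraph's result shape depends on the data. As a point of reference, a minimal while_loop call in the Python API looks roughly like the sketch below (values are illustrative, not taken from this patch):

    import mxnet as mx

    def cond(i, s):
        return i < 5                        # loop while the counter is small

    def body(i, s):
        return [s], [i + 1, s * 2]          # ([per-step outputs], [new loop vars])

    outputs, states = mx.nd.contrib.while_loop(
        cond, body,
        loop_vars=[mx.nd.zeros((1,)), mx.nd.ones((2,))],
        max_iterations=10)
    # outputs[0] stacks the per-step outputs along axis 0; states holds the
    # loop variables left over once cond becomes false.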

diff --git a/include/mxnet/ndarray.h b/include/mxnet/ndarray.h
index 05d3fa4..340c380 100644
--- a/include/mxnet/ndarray.h
+++ b/include/mxnet/ndarray.h
@@ -180,7 +180,9 @@ class NDArray {
   * \brief set the correct shape of NDArray directly from the storage_shape of its own chunk.
    */
   void SetShapeFromChunk() {
-    shape_ = ptr_->storage_shape;
+    if (!(ptr_->storage_shape.ndim() == 1 && ptr_->storage_shape[0] == 0)) {
+      shape_ = ptr_->storage_shape;
+    }
   }
   /*
    * This indicates whether an array is a view of another array (created by
diff --git a/python/mxnet/executor.py b/python/mxnet/executor.py
index 9dfe636..edc10df 100644
--- a/python/mxnet/executor.py
+++ b/python/mxnet/executor.py
@@ -149,7 +149,7 @@ class Executor(object):
         check_call(_LIB.MXExecutorForward(
             self.handle,
             ctypes.c_int(int(is_train))))
-
+        self.outputs = self._get_outputs()
         return self.outputs
 
     def backward(self, out_grads=None, is_train=True):
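
The executor.py change above re-reads the output arrays on every forward() call, so shapes that are resolved only while the operators execute become visible to the caller. A hypothetical sketch of the effect, using the data-dependent boolean_mask operator (the input values are invented):

    import mxnet as mx

    data = mx.sym.var("data")
    index = mx.sym.var("index")
    # The number of rows kept depends on the values in `index`, so the output
    # shape cannot be determined by symbolic shape inference.
    out = mx.sym.contrib.boolean_mask(data, index)
    exe = out.bind(mx.cpu(), args={
        "data": mx.nd.array([[1, 2], [3, 4], [5, 6]]),
        "index": mx.nd.array([0, 1, 1]),
    })
    exe.forward(is_train=False)
    print(exe.outputs[0].shape)             # e.g. (2, 2), known only after the run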
diff --git a/src/executor/graph_executor.cc b/src/executor/graph_executor.cc
index e726d29..da1f13b 100644
--- a/src/executor/graph_executor.cc
+++ b/src/executor/graph_executor.cc
@@ -44,6 +44,7 @@ using namespace mxnet::common;
 GraphExecutor::GraphExecutor() {
   log_verbose_ = dmlc::GetEnv("MXNET_EXEC_VERBOSE_LOGGING", false);
   need_grad_ = false;
+  is_dynamic_ = false;
   subgraph_property_ = dmlc::GetEnv("MXNET_SUBGRAPH_BACKEND", std::string());
   engine_ref_ = Engine::_GetSharedRef();
 }
@@ -76,20 +77,77 @@ void GraphExecutor::PartialForward(bool is_train, int step, int *step_left) {
 }
 
 void GraphExecutor::Backward(const std::vector<NDArray>& head_grads, bool is_train) {
-  const auto& idx = graph_.indexed_graph();
-  if (num_forward_inputs_ != idx.input_nodes().size()) {
-    for (size_t i = 0; i < head_grad_array_.size(); ++i) {
-      if (!head_grad_array_[i].is_none()) {
-        CHECK(i < head_grads.size() && !head_grads[i].is_none())
-            << "Because the last operator is not Loss function, "
-            << "head_gradient is required when calling backward. "
-            << "If you are attempting to minimize the output as "
-            << "an objective, please modify your network and "
-            << "pass it through the make_loss symbol.";
-        CopyFromTo(head_grads[i], &(head_grad_array_[i]));
+  {
+    const auto& idx = graph_.indexed_graph();
+    if (num_forward_inputs_ != idx.input_nodes().size()) {
+      for (size_t i = 0; i < head_grad_array_.size(); ++i) {
+        if (!head_grad_array_[i].is_none()) {
+          CHECK(i < head_grads.size() && !head_grads[i].is_none())
+              << "Because the last operator is not Loss function, "
+              << "head_gradient is required when calling backward. "
+              << "If you are attempting to minimize the output as "
+              << "an objective, please modify your network and "
+              << "pass it through the make_loss symbol.";
+          const NDArray &from = head_grads[i];
+          NDArray &to = head_grad_array_[i];
+          if (this->is_dynamic_) {
+            to.WaitToRead();
+            if (!shape_is_known(to.shape())) {
+              to.Init(from.shape());
+            }
+          }
+          CopyFromTo(from, &to);
+        }
       }
     }
   }
+  if (this->is_dynamic_) {
+    graph_ = InferShape(std::move(graph_), {}, "");
+    mxnet::ShapeVector rshape = graph_.MoveCopyAttr<mxnet::ShapeVector>("shape");
+    const auto& idx = graph_.indexed_graph();
+    for (size_t nid = 0; nid < idx.num_nodes(); ++nid) {
+      const auto& inode = idx[nid];
+      if (inode.source->is_variable()) continue;
+      OpNode& opnode = op_nodes_[nid];
+      if (opnode.skip_exec_node) continue;
+      for (NDArray &array : opnode.exec->in_array) {
+        array.WaitToRead();
+        if (!shape_is_known(array.shape())) {
+          array.SetShapeFromChunk();
+        }
+      }
+      int i = 0;
+      for (NDArray &array : opnode.exec->in_array) {
+        array.WaitToRead();
+        if (!shape_is_known(array.shape())) {
+          array.SetShapeFromChunk();
+        }
+        if (!shape_is_known(array.shape())) {
+          mxnet::TShape shape = rshape[idx.entry_id(inode.inputs[i])];
+          if (shape_is_known(shape)) {
+            array.ReshapeAndAlloc(shape);
+          }
+        }
+        ++i;
+      }
+      i = 0;
+      for (NDArray &array : opnode.exec->out_array) {
+        array.WaitToRead();
+        if (!shape_is_known(array.shape())) {
+          array.SetShapeFromChunk();
+        }
+        if (!shape_is_known(array.shape())) {
+          mxnet::TShape shape = rshape[idx.entry_id(nid, i)];
+          if (shape_is_known(shape)) {
+            array.ReshapeAndAlloc(shape);
+          }
+        }
+        ++i;
+      }
+    }
+    graph_.attrs["shape"] = std::make_shared<dmlc::any>(rshape);
+  }
+  const auto& idx = graph_.indexed_graph();
   RunOps(is_train, num_forward_nodes_, idx.num_nodes());
 }
 
@@ -119,6 +177,14 @@ void GraphExecutor::SetMonitorCallback(const MonitorCallback& callback, bool mon
 }
 
 const std::vector<NDArray>& GraphExecutor::outputs() const {
+  if (this->is_dynamic_) {
+    for (const NDArray &array : output_arrays_) {
+      array.WaitToRead();
+      if (!shape_is_known(array.shape())) {
+        const_cast<NDArray &>(array).SetShapeFromChunk();
+      }
+    }
+  }
   return output_arrays_;
 }
 
@@ -381,8 +447,7 @@ void GraphExecutor::Init(nnvm::Symbol symbol,
   arg_shapes.resize(idx.input_nodes().size(), mxnet::TShape());
   g = InferShape(std::move(g), std::move(arg_shapes), "__shape__");
   if (g.GetAttr<size_t>("shape_num_unknown_nodes") != 0U) {
-    HandleInferShapeError(num_forward_inputs_, g.indexed_graph(),
-                          g.GetAttr<mxnet::ShapeVector>("shape"));
+    this->is_dynamic_ = true;
   }
 
   arg_dtypes.resize(idx.input_nodes().size(), -1);
@@ -821,8 +886,7 @@ Executor* GraphExecutor::Reshape(const bool partial_shaping,
   }
   g = InferShape(std::move(g), std::move(arg_shapes), "__shape__");
   if (g.GetAttr<size_t>("shape_num_unknown_nodes") != 0U) {
-    HandleInferShapeError(num_forward_inputs_, g.indexed_graph(),
-                          g.GetAttr<mxnet::ShapeVector>("shape"));
+    this->is_dynamic_ = true;
   }
   const mxnet::ShapeVector& shape_vec = g.GetAttr<mxnet::ShapeVector>("shape");
   std::vector<OpReqType> grad_req_types;
@@ -977,14 +1041,16 @@ void GraphExecutor::InitDataEntryMemory(std::vector<NDArray>* shared_pool) {
     uint32_t oid = head_grad_map_.at(idx[nid].source);
     uint32_t eid = idx.entry_id(idx.outputs()[oid]);
     NDArrayStorageType stype = (NDArrayStorageType) vstorage_type[eid];
-    CHECK(mxnet::shape_is_known(vshape[eid]));
+    bool unknown_shape = !shape_is_known(vshape[eid]);
     CHECK_NE(vdtype[eid], -1);
     auto data_eid = idx.entry_id(nid, 0);
     // initialize based on storage_type
     if (stype != kDefaultStorage) {
       data_entry_[data_eid] = NDArray(stype, vshape[eid], data_context[eid], true, vdtype[eid]);
-    } else {
+    } else if (!unknown_shape) {
       data_entry_[data_eid] = NDArray(vshape[eid], data_context[eid], false, vdtype[eid]);
+    } else {
+      data_entry_[data_eid] = NDArray(data_context[eid], vdtype[eid]);
     }
     if (log_verbose_) {
       LOG(INFO) << "\tinit head_grad entry\t" << data_eid << "\tas "
@@ -994,7 +1060,11 @@ void GraphExecutor::InitDataEntryMemory(std::vector<NDArray>* shared_pool) {
   // get maximum bytes in each pool
   for (size_t i = 0; i < vshape.size(); ++i) {
     if (!data_entry_[i].is_none()) continue;
-    size_t bytes = vshape[i].Size() * mshadow::mshadow_sizeof(vdtype[i]);
+    size_t shape_size = 0;
+    if (shape_is_known(vshape[i])) {
+      shape_size = vshape[i].Size();
+    }
+    size_t bytes = shape_size * mshadow::mshadow_sizeof(vdtype[i]);
     int storage_id = vstorage[i];
     // skip pool allocation for kBadStorageID, kExternalStorageID and kDynamicStorageID
     if (storage_id < 0) continue;
@@ -1013,7 +1083,10 @@ void GraphExecutor::InitDataEntryMemory(std::vector<NDArray>* shared_pool) {
   std::multimap<size_t, NDArray> free_pool;
   if (shared_pool != nullptr) {
     for (const NDArray& nd : *shared_pool) {
-      size_t bytes = nd.shape().Size() * mshadow::mshadow_sizeof(nd.dtype());
+      size_t bytes = 0;
+      if (shape_is_known(nd.shape())) {
+        bytes = nd.shape().Size() * mshadow::mshadow_sizeof(nd.dtype());
+      }
       free_pool.insert(std::make_pair(bytes, nd));
     }
   }
@@ -1067,9 +1140,13 @@ void GraphExecutor::InitDataEntryMemory(std::vector<NDArray>* shared_pool) {
     int storage_id = vstorage[i];
     auto storage_type = (NDArrayStorageType) vstorage_type[i];
     if (storage_type == kDefaultStorage) {
-      CHECK_GE(storage_id, 0) << "Do not support runtime shape op yet";
-      const NDArray& src = data_pool_.at(storage_id);
-      data_entry_[i] = src.AsArray(vshape[i], vdtype[i]);
+      if (!shape_is_known(vshape[i])) {
+        data_entry_[i] = NDArray(data_context[i], vdtype[i]);
+      } else {
+        CHECK_GE(storage_id, 0) << "Do not support runtime shape op yet";
+        const NDArray& src = data_pool_.at(storage_id);
+        data_entry_[i] = src.AsArray(vshape[i], vdtype[i]);
+      }
     } else {
       data_entry_[i] = NDArray(storage_type, vshape[i], data_context[i],
                                true, vdtype[i]);
@@ -1209,7 +1286,10 @@ void GraphExecutor::InitOpSegs() {
   const profiler::Profiler *prof = profiler::Profiler::Get();
   bool prefer_bulk_exec_train = Imperative::PreferBulkExecTrain()
                                 && (!prof || !prof->AggregateEnabled());
-
+  if (this->is_dynamic_) {
+    prefer_bulk_exec_inference = false;
+    prefer_bulk_exec_train = false;
+  }
   bool is_training = num_forward_nodes_ != total_num_nodes;
 
   if (prefer_bulk_exec_train && is_training) {
@@ -1300,6 +1380,8 @@ void GraphExecutor::ExecuteMonOutputCallback(size_t nid) {
 }
 
 void GraphExecutor::RunOps(bool is_train, size_t topo_start, size_t topo_end) {
+  static auto& finfer_shape = nnvm::Op::GetAttr<mxnet::FInferShape>("FInferShape");
+  static auto& is_backward = Op::GetAttr<nnvm::TIsBackward>("TIsBackward");
   // Update context
   const auto& idx = graph_.indexed_graph();
   for (size_t nid = topo_start; nid < topo_end; ++nid) {
@@ -1311,6 +1393,7 @@ void GraphExecutor::RunOps(bool is_train, size_t topo_start, size_t topo_end) {
     opnode.exec->op_ctx.need_grad = need_grad_;
   }
 
+  mxnet::ShapeVector rshape = graph_.MoveCopyAttr<mxnet::ShapeVector>("shape");
   // Push Ops
   for (size_t nid = topo_start; nid < topo_end; ++nid) {
     auto seg_op = cached_seg_opr_[nid];
@@ -1323,6 +1406,8 @@ void GraphExecutor::RunOps(bool is_train, size_t topo_start, size_t topo_end) {
     }
     // Normal mode
     const auto& inode = idx[nid];
+    const uint32_t num_inputs = inode.inputs.size();
+    const uint32_t num_outputs = inode.source->num_outputs();
     if (inode.source->is_variable()) continue;
     OpNode& opnode = op_nodes_[nid];
     if (op_nodes_[nid].skip_exec_node) continue;
@@ -1330,6 +1415,69 @@ void GraphExecutor::RunOps(bool is_train, size_t topo_start, size_t topo_end) {
     if (monitor_callback_ && monitor_all_) {
       ExecuteMonInputCallback(nid);
     }
+    if (this->is_dynamic_) {
+      const auto &op = inode.source->op();
+      {
+        for (NDArray &array : opnode.exec->in_array) {
+          array.WaitToRead();
+          if (!shape_is_known(array.shape())) {
+            array.SetShapeFromChunk();
+          }
+        }
+        int i = 0;
+        for (NDArray &array : opnode.exec->out_array) {
+          array.WaitToRead();
+          if (!shape_is_known(array.shape())) {
+            array.SetShapeFromChunk();
+          }
+          if (!shape_is_known(array.shape())) {
+            mxnet::TShape shape = rshape[idx.entry_id(nid, i)];
+            if (shape_is_known(shape)) {
+              array.ReshapeAndAlloc(shape);
+            }
+          }
+          ++i;
+        }
+      }
+      if (finfer_shape.count(op)) {
+        mxnet::ShapeVector in_shapes;
+        mxnet::ShapeVector out_shapes;
+        for (NDArray &array : opnode.exec->in_array) {
+          in_shapes.push_back(array.shape());
+        }
+        for (NDArray &array : opnode.exec->out_array) {
+          out_shapes.push_back(array.shape());
+        }
+        auto finfer = finfer_shape[op];
+        try {
+          bool success = finfer(inode.source->attrs, &in_shapes, &out_shapes);
+          CHECK(success) << "InferShape failed in operator " << inode.source->attrs.name;
+        } catch (const std::exception& e) {
+          throw dmlc::Error("Error in operator " + inode.source->attrs.name + ": " + e.what());
+        }
+        int n_out = out_shapes.size();
+        for (int i = 0; i < n_out; ++i) {
+          NDArray &array = opnode.exec->out_array[i];
+          if (!shape_is_known(array.shape())) {
+            array.Init(out_shapes[i]);
+          }
+        }
+      } else if (is_backward.get(inode.source->op(), false) && inode.control_deps.size()) {
+        CHECK_GE(inode.control_deps.size(), 1U) <<
+          "BackwardOp need to have control_deps to its forward op";
+        uint32_t fid = inode.control_deps[0];
+        const OpNode& fopnode = op_nodes_[fid];
+        CHECK_EQ(fopnode.exec->in_array.size(), opnode.exec->out_array.size());
+        int nelem = fopnode.exec->in_array.size();
+        std::vector<NDArray> &from = fopnode.exec->in_array;
+        std::vector<NDArray> &to = opnode.exec->out_array;
+        for (int i = 0; i < nelem; ++i) {
+          if (!shape_is_known(to[i].shape())) {
+            to[i].Init(from[i].shape());
+          }
+        }
+      }
+    }
     opnode.exec->op_ctx.is_train = is_train;
     opnode.exec->op_ctx.need_grad = need_grad_;
     if (opnode.exec->exec_type() == ExecType::kCrossDeviceCopy) {
@@ -1343,14 +1491,35 @@ void GraphExecutor::RunOps(bool is_train, size_t topo_start, size_t topo_end) {
     } else if (opnode.cached_opr != nullptr) {
       bool profiling = profiler::Profiler::Get()->GetState() == profiler::Profiler::kRunning;
       Engine::Get()->Push(opnode.cached_opr, opnode.ctx, 0, profiling);
+      if (this->is_dynamic_) {
+        for (NDArray &array : opnode.exec->out_array) {
+          array.WaitToRead();
+          if (!shape_is_known(array.shape())) {
+            array.SetShapeFromChunk();
+          }
+        }
+      }
     } else {
       LOG(FATAL) << "Not accessed";
     }
+    for (uint32_t i = 0; i < num_inputs; ++i) {
+      int eid = idx.entry_id(inode.inputs[i]);
+      if (!shape_is_known(rshape[eid])) {
+        rshape[eid] = opnode.exec->in_array[i].shape();
+      }
+    }
+    for (uint32_t i = 0; i < num_outputs; ++i) {
+      int eid = idx.entry_id(nid, i);
+      if (!shape_is_known(rshape[eid])) {
+        rshape[eid] = opnode.exec->out_array[i].shape();
+      }
+    }
     // Monitor callbacks
     if (monitor_callback_) {
       ExecuteMonOutputCallback(nid);
     }
   }
+  graph_.attrs["shape"] = std::make_shared<dmlc::any>(rshape);
 }
 
 GraphExecutor::CachedSegOpr GraphExecutor::CreateCachedSegOpr(size_t topo_start, size_t topo_end) {
diff --git a/src/executor/graph_executor.h b/src/executor/graph_executor.h
index 9a86609..f150165 100644
--- a/src/executor/graph_executor.h
+++ b/src/executor/graph_executor.h
@@ -216,6 +216,8 @@ class GraphExecutor : public Executor {
   void ExecuteMonOutputCallback(size_t nid);
   // peform bulking and segmentation on the region [from_node, up_to_node) of a graph
   void BulkOpSegs(size_t from_node, size_t up_to_node, size_t segment_num_nodes_max);
+  // When infer shape fails, fall back to ensure dynamic-shaped operators executed correctly.
+  bool is_dynamic_;
   // indicate whether there is a backward graph for gradients.
   bool need_grad_;
   // internal graph
diff --git a/src/nnvm/plan_memory.cc b/src/nnvm/plan_memory.cc
index 41b8559..ac70848 100644
--- a/src/nnvm/plan_memory.cc
+++ b/src/nnvm/plan_memory.cc
@@ -268,7 +268,7 @@ size_t AllocMemory(const Graph& ret, const IndexedGraph& idx,
       // only request memory for kBadStorageID
       if (storage[eid] == GraphAllocator::kBadStorageID) {
         auto &eshape = shape_vec[eid];
-        size_t esize = eshape.Size();
+        size_t esize = ndim_is_known(shape_vec[eid]) ? eshape.Size() : 0;
         eids.insert(std::make_pair(esize, eid));
       }
     }
diff --git a/src/operator/control_flow.cc b/src/operator/control_flow.cc
index 4c0d67b..fd087ef 100644
--- a/src/operator/control_flow.cc
+++ b/src/operator/control_flow.cc
@@ -243,7 +243,6 @@ static void ForeachGradComputeExCPU(const OpStatePtr& state_ptr,
       // the user to write state gradients to the outputs.
       subg_req[loc] = iter_num != 0 ? kWriteTo : req[i + params.in_data_locs.ndim()];
     }
-
     state.Backward(iter_num, subg_ograds, subg_req, subg_igrads);
 
     size_t num_states = subg_ograds.size() - num_output_data;
@@ -579,8 +578,6 @@ static void WhileLoopComputeExCPU(const OpStatePtr& state_ptr,
   CHECK_EQ(inputs.size() + 2U, (size_t) params.num_args);
   CHECK_EQ(outputs.size(), (size_t) params.num_outputs);
   CHECK_EQ(outputs.size(), req.size());
-  for (size_t i = 0; i < (size_t) params.num_out_data; i++)
-    CHECK_EQ(params.max_iterations, outputs[i].shape()[0]);
   // construct inputs and outputs for cond
   std::vector<NDArray> cond_inputs, cond_outputs = {NDArray()};
   extract_by_loc(inputs, params.cond_input_locs, &cond_inputs);
@@ -596,20 +593,33 @@ static void WhileLoopComputeExCPU(const OpStatePtr& state_ptr,
       break;
     }
     // we create func_outputs for the current step:
-    // func_outputs[0: num_out_data] is a slice of outputs[][step]
-    for (size_t i = 0; i < (size_t) params.num_out_data; ++i) {
-      func_outputs[i] = outputs[i].At(step);
-    }
-    // func_outputs[num_out_data: ] are new_loop_vars, need to allocate new memory
-    for (size_t i = params.num_out_data; i < outputs.size(); ++i) {
-      func_outputs[i] = NDArray(outputs[i].shape(), outputs[i].ctx(), true, outputs[i].dtype());
+    for (size_t i = 0; i < outputs.size(); ++i) {
+      func_outputs[i] = NDArray(outputs[i].ctx(), outputs[i].dtype());
     }
     state.Forward(step, func_inputs, req, func_outputs, ctx.need_grad);
+    if (step == 0) {
+      for (int i = 0; i < params.num_out_data; ++i) {
+        func_outputs[i].WaitToRead();
+        if (!shape_is_known(func_outputs[i].shape())) {
+          func_outputs[i].SetShapeFromChunk();
+        }
+        mxnet::TShape step_shape = func_outputs[i].shape();
+        mxnet::TShape shape(step_shape.ndim() + 1, 0);
+        shape[0] = params.max_iterations;
+        for (int j = 0; j < step_shape.ndim(); ++j) {
+          shape[j + 1] = step_shape[j];
+        }
+        const_cast<NDArray &>(outputs[i]).Init(shape);
+      }
+    }
+    for (int i = 0; i < params.num_out_data; ++i) {
+      NDArray first_slot = outputs[i].At(step);
+      mxnet::CopyFromTo(func_outputs[i], &first_slot);
+    }
     // func_inputs on the next step:
     // the output (new_loop_vars) will become the new inputs (loop_vars)
     for (size_t i = params.num_out_data; i < outputs.size(); ++i) {
-      size_t j = params.func_var_locs[i - params.num_out_data];
-      CHECK_EQ(func_inputs[j].shape(), func_outputs[i].shape());
+      int j = params.func_var_locs[i - params.num_out_data];
       func_inputs[j] = func_outputs[i];
       int k = state.oi_map[i - params.num_out_data];
       if (k != -1) {
@@ -627,8 +637,21 @@ static void WhileLoopComputeExCPU(const OpStatePtr& state_ptr,
   // therefore, we copy func_inputs[:] to outputs[num_out_data: ]
   for (size_t i = params.num_out_data; i < outputs.size(); ++i) {
     size_t j = params.func_var_locs[i - params.num_out_data];
+    if (!shape_is_known(outputs[i].shape())) {
+      const_cast<NDArray &>(outputs[i]).Init(func_inputs[j].shape());
+    }
     mxnet::CopyFromTo(func_inputs[j], &outputs[i]);
   }
+  for (int i = 0; i < params.num_out_data; ++i) {
+    const_cast<NDArray &>(outputs[i]).SetShapeFromChunk();
+  }
+  if (state.n_iterations == 0) {
+    for (size_t i = 0; i < outputs.size(); ++i) {
+      if (!shape_is_known(outputs[i].shape())) {
+        const_cast<NDArray &>(outputs[i]).ReshapeAndAlloc({1});
+      }
+    }
+  }
 }
 
 static void WhileLoopGradComputeExCPU(const OpStatePtr& state_ptr,
@@ -726,108 +749,6 @@ static void WhileLoopGradComputeExCPU(const OpStatePtr& state_ptr,
   state.Cleanup();
 }
 
-static bool WhileLoopShape(const nnvm::NodeAttrs& attrs,
-                           mxnet::ShapeVector *in_shape,
-                           mxnet::ShapeVector *out_shape) {
-  using mxnet::ShapeVector;
-  const WhileLoopParam& params = nnvm::get<WhileLoopParam>(attrs.parsed);
-  static const std::function<bool(const mxnet::TShape &)> is_udf = is_shape_udf;
-  // sanity checks
-  CHECK_EQ(in_shape->size() + 2U, (size_t) params.num_args);
-  CHECK_EQ(out_shape->size(), (size_t) params.num_outputs);
-  CHECK_EQ(attrs.subgraphs.size(), 2U);
-  CHECK_EQ(attrs.subgraphs[0]->outputs.size(), 1U);
-  // infer shape for cond and func
-  auto infer_subg = [&params, in_shape, out_shape](std::shared_ptr<Symbol> subg,
-                                                   ShapeVector *_subg_out,
-                                                   const mxnet::Tuple<dim_t> &input_locs,
-                                                   int num_out_data,
-                                                   bool fill_out_shape) {
-    // create subg_in
-    ShapeVector subg_in;
-    ShapeVector &subg_out = *_subg_out;
-    extract_by_loc(*in_shape, input_locs, &subg_in);
-    // create an indexed graph
-    nnvm::Graph g;
-    g.outputs = subg->outputs;
-    const auto& idx = g.indexed_graph();
-    // get input nodes
-    const auto &input_nids = idx.input_nodes();
-    // sanity checks
-    CHECK_EQ(input_nids.size(), subg_in.size());
-    CHECK_EQ(g.outputs.size(), subg_out.size());
-    CHECK_EQ(idx.input_nodes().size(), subg_in.size());
-    CHECK_EQ(idx.outputs().size(), subg_out.size());
-    // create empty shapes for inference
-    ShapeVector shapes(idx.num_node_entries());
-    // copy subg_in into shapes
-    for (size_t i = 0; i < subg_in.size(); ++i) {
-      auto eid = idx.entry_id(input_nids[i], 0);
-      shapes[eid] = subg_in[i];
-    }
-    // copy subg_out into shapes
-    // note that ndim of out_data is not increased
-    // because subg is only one step
-    for (size_t i = 0; i < subg_out.size(); ++i) {
-      auto eid = idx.entry_id(g.outputs[i]);
-      shapes[eid] = subg_out[i];
-    }
-    // copy done, call InferShape
-    g.attrs["shape"] = std::make_shared<dmlc::any>(std::move(shapes));
-    g = exec::InferShape(std::move(g));
-    // now `shapes' won't be used anymore, use new_shapes instead
-    const auto& new_shapes = g.GetAttr<mxnet::ShapeVector>("shape");
-    // copy subg_in back to in_shape
-    for (size_t i = 0; i < subg_in.size(); ++i) {
-      auto eid = idx.entry_id(input_nids[i], 0);
-      auto g_out_shape = new_shapes[eid];
-      if (!shape_is_known(g_out_shape)) {
-        // when the shape is not fully inferred
-        continue;
-      }
-      SHAPE_ASSIGN_CHECK(*in_shape, input_locs[i], g_out_shape);
-    }
-    if (!fill_out_shape) {
-      return true;
-    }
-    // copy subg_out back to out_shape
-    // for results in [0, num_out_data), ndim should increase by 1
-    for (int i = 0; i < num_out_data; ++i) {
-      auto eid = idx.entry_id(g.outputs[i]);
-      auto g_out_shape = new_shapes[eid];
-      if (!shape_is_known(g_out_shape)) {
-        // when the shape is not fully inferred
-        continue;
-      }
-      auto out = mxnet::TShape(g_out_shape.ndim() + 1, -1);
-      out[0] = params.max_iterations;
-      for (int i = 1; i < out.ndim(); i++)
-        out[i] = g_out_shape[i - 1];
-      SHAPE_ASSIGN_CHECK(*out_shape, i, out);
-    }
-    // for results in [num_out_data, ...), ndim does not change
-    for (size_t i = num_out_data; i < g.outputs.size(); ++i) {
-      auto eid = idx.entry_id(g.outputs[i]);
-      auto g_out_shape = new_shapes[eid];
-      if (!shape_is_known(g_out_shape)) {
-        // when the shape is not fully inferred
-        continue;
-      }
-      SHAPE_ASSIGN_CHECK(*out_shape, i, g_out_shape);
-    }
-    return g.GetAttr<size_t>("shape_num_unknown_nodes") == 0;
-  };
-  mxnet::ShapeVector cond_out_shape{mxnet::TShape(1, 1)};  // this means: [(1, )]
-  mxnet::ShapeVector func_out_shape(params.num_outputs);
-  CHECK(params.sync_in_out(in_shape, out_shape, is_udf));
-  bool succ_0 = infer_subg(attrs.subgraphs[0], &cond_out_shape, params.cond_input_locs, 0, false);
-  CHECK(params.sync_in_out(in_shape, out_shape, is_udf));
-  bool succ_1 = infer_subg(attrs.subgraphs[1], &func_out_shape, \
-                           params.func_input_locs, params.num_out_data, true);
-  CHECK(params.sync_in_out(in_shape, out_shape, is_udf));
-  return succ_0 && succ_1;
-}
-
 static bool WhileLoopType(const nnvm::NodeAttrs& attrs,
                           std::vector<int> *in_type, std::vector<int> *out_type) {
   const WhileLoopParam& params = nnvm::get<WhileLoopParam>(attrs.parsed);
@@ -1033,93 +954,6 @@ static void CondGradComputeExCPU(const OpStatePtr& state_ptr,
   loop_state.Cleanup();
 }
 
-static bool CondShape(const nnvm::NodeAttrs& attrs,
-                      mxnet::ShapeVector *in_shape,
-                      mxnet::ShapeVector *out_shape) {
-  using mxnet::ShapeVector;
-  const CondParam& params = nnvm::get<CondParam>(attrs.parsed);
-  static const std::function<bool(const mxnet::TShape &)> is_udf = is_shape_udf;
-  // sanity checks
-  CHECK_EQ(in_shape->size() + 3U, (size_t) params.num_args);
-  CHECK_EQ(out_shape->size(), (size_t) params.num_outputs);
-  CHECK_EQ(attrs.subgraphs.size(), 3U);
-  CHECK_EQ(attrs.subgraphs[0]->outputs.size(), 1U);
-  CHECK_EQ(attrs.subgraphs[1]->outputs.size(), attrs.subgraphs[2]->outputs.size());
-  // infer shape for cond, then and else
-  auto infer_subg = [&params, in_shape, out_shape](std::shared_ptr<Symbol> subg,
-                                                   ShapeVector *_subg_out,
-                                                   const mxnet::Tuple<dim_t> &input_locs,
-                                                   bool fill_out_shape) {
-    // create subg_in
-    mxnet::ShapeVector subg_in;
-    mxnet::ShapeVector &subg_out = *_subg_out;
-    extract_by_loc(*in_shape, input_locs, &subg_in);
-    // create an indexed graph
-    nnvm::Graph g;
-    g.outputs = subg->outputs;
-    const auto& idx = g.indexed_graph();
-    // get input nodes
-    const auto &input_nids = idx.input_nodes();
-    // sanity checks
-    CHECK_EQ(input_nids.size(), subg_in.size());
-    CHECK_EQ(g.outputs.size(), subg_out.size());
-    CHECK_EQ(idx.input_nodes().size(), subg_in.size());
-    CHECK_EQ(idx.outputs().size(), subg_out.size());
-    // create empty shapes for inference
-    mxnet::ShapeVector shapes(idx.num_node_entries());
-    // copy subg_in into shapes
-    for (size_t i = 0; i < subg_in.size(); ++i) {
-      auto eid = idx.entry_id(input_nids[i], 0);
-      shapes[eid] = subg_in[i];
-    }
-    // copy subg_out into shapes
-    for (size_t i = 0; i < subg_out.size(); ++i) {
-      auto eid = idx.entry_id(g.outputs[i]);
-      shapes[eid] = subg_out[i];
-    }
-    // copy done, call InferShape
-    g.attrs["shape"] = std::make_shared<dmlc::any>(std::move(shapes));
-    g = exec::InferShape(std::move(g));
-    // now `shapes' won't be used anymore, use new_shapes instead
-    const auto& new_shapes = g.GetAttr<mxnet::ShapeVector>("shape");
-    // copy subg_in back to in_shape
-    for (size_t i = 0; i < subg_in.size(); ++i) {
-      auto eid = idx.entry_id(input_nids[i], 0);
-      auto g_out_shape = new_shapes[eid];
-      if (!shape_is_known(g_out_shape)) {
-        // when the shape is not fully inferred
-        continue;
-      }
-      SHAPE_ASSIGN_CHECK(*in_shape, input_locs[i], g_out_shape);
-    }
-    if (!fill_out_shape) {
-      return true;
-    }
-    // copy subg_out back to out_shape
-    for (size_t i = 0; i < g.outputs.size(); ++i) {
-      auto eid = idx.entry_id(g.outputs[i]);
-      auto g_out_shape = new_shapes[eid];
-      if (!shape_is_known(g_out_shape)) {
-        // when the shape is not fully inferred
-        continue;
-      }
-      SHAPE_ASSIGN_CHECK(*out_shape, i, g_out_shape);
-    }
-    return g.GetAttr<size_t>("shape_num_unknown_nodes") == 0;
-  };
-  ShapeVector cond_out_shape{mxnet::TShape(1, 1)};  // this means: [(1, )]
-  ShapeVector then_out_shape(params.num_outputs);
-  ShapeVector else_out_shape(params.num_outputs);
-  bool succ_0 = infer_subg(attrs.subgraphs[0], &cond_out_shape, \
-                           params.cond_input_locs, false);
-  bool succ_1 = infer_subg(attrs.subgraphs[1], &then_out_shape, \
-                           params.then_input_locs, true);
-  bool succ_2 = infer_subg(attrs.subgraphs[2], &else_out_shape, \
-                           params.else_input_locs, true);
-  sync_out_out(&then_out_shape, &else_out_shape, is_udf);
-  return succ_0 && succ_1 && succ_2;
-}
-
 static bool CondType(const nnvm::NodeAttrs& attrs,
                      std::vector<int> *in_type,
                      std::vector<int> *out_type) {
@@ -1342,7 +1176,6 @@ NNVM_REGISTER_OP(_while_loop)
 })
 .set_attr<nnvm::FGradient>("FGradient", WhileLoopGradient)
 .set_attr<FCreateOpState>("FCreateOpState", CreateWhileLoopState)
-.set_attr<mxnet::FInferShape>("FInferShape", WhileLoopShape)
 .set_attr<nnvm::FInferType>("FInferType", WhileLoopType)
 .set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", WhileLoopComputeExCPU)
 .set_attr<FExecType>("FExecType", [](const NodeAttrs& attrs) {
@@ -1405,7 +1238,6 @@ NNVM_REGISTER_OP(_cond)
 })
 .set_attr<nnvm::FGradient>("FGradient", CondGradient)
 .set_attr<FCreateOpState>("FCreateOpState", CreateCondState)
-.set_attr<mxnet::FInferShape>("FInferShape", CondShape)
 .set_attr<nnvm::FInferType>("FInferType", CondType)
 .set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", CondComputeExCPU)
 .set_attr<FExecType>("FExecType", [](const NodeAttrs& attrs) {
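
With the static WhileLoopShape and CondShape functions removed above, shape inference for _while_loop and _cond is deferred to run time. For completeness, a minimal cond call in the Python API (illustrative values only):

    import mxnet as mx

    a = mx.nd.array([1.0])
    b = mx.nd.array([4.0])
    out = mx.nd.contrib.cond(
        pred=a * 2 < b,                     # scalar-like predicate
        then_func=lambda: a + 5,            # taken when pred is true
        else_func=lambda: b - 5)            # taken otherwise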
diff --git a/tests/python/unittest/test_contrib_control_flow.py b/tests/python/unittest/test_contrib_control_flow.py
index dd5a4d6..a93c109 100644
--- a/tests/python/unittest/test_contrib_control_flow.py
+++ b/tests/python/unittest/test_contrib_control_flow.py
@@ -935,14 +935,27 @@ def test_while_loop_rnn():
                     max_iterations=seq_len
                 )
         result = mx.sym.Group(result[0] + result[1][1: ])
-        arg_shapes, _, _ = result.infer_shape(
-            data=(seq_len, batch_size, input_dim),
-            s_0=(batch_size, hidden_dim),
-        )
         rnn_inputs = result.list_inputs()
-        args = {name: _array(arg_shapes[i]) for i, name in enumerate(rnn_inputs) if name != "i"}
-        args["i"] = mx.nd.zeros([1])
-        args_grad = {name: _array(arg_shapes[i]) for i, name in enumerate(rnn_inputs)}
+        args = {
+            "i": mx.nd.zeros([1]),
+            "data": _array((seq_len, batch_size, input_dim)),
+            "i2h_weight": _array((input_dim * hidden_dim, input_dim)),
+            "i2h_bias": _array((input_dim * hidden_dim, )),
+            "s_0": _array((batch_size, hidden_dim)),
+            "h2h_weight": _array((input_dim * hidden_dim, seq_len)),
+            "h2h_bias": _array((input_dim * hidden_dim, )),
+            "s_1": _array((batch_size, hidden_dim)),
+        }
+        args_grad = {
+            "i": _array([1]),
+            "data": _array((seq_len, batch_size, input_dim)),
+            "i2h_weight": _array((input_dim * hidden_dim, input_dim)),
+            "i2h_bias": _array((input_dim * hidden_dim, )),
+            "s_0": _array((batch_size, hidden_dim)),
+            "h2h_weight": _array((input_dim * hidden_dim, seq_len)),
+            "h2h_bias": _array((input_dim * hidden_dim, )),
+            "s_1": _array((batch_size, hidden_dim)),
+        }
         e_1 = result.bind(ctx=default_context(),
             args={name: array.copy() for name, array in args.items()},
             args_grad={name: array.copy() for name, array in args_grad.items() if name != "i"},
@@ -961,9 +974,9 @@ def test_while_loop_rnn():
             args_grad={name: array.copy() for name, array in args_grad.items() if name != "i"},
         )
         for case_id in range(100):
-            out_grads = [_array(arr.shape) for arr in e_1.outputs]
             args = {name: array.copy() for name, array in args.items()}
             e_1.forward(is_train=True, **args)
+            out_grads = [_array(arr.shape) for arr in e_1.outputs]
             e_1.backward(out_grads)
             args = {name: array.copy() for name, array in args.items() if name != "i"}
             e_2.forward(is_train=True, **args)
@@ -1178,7 +1191,6 @@ def check_contrib_rnn(cell_type, num_states):
         params2 = layer.collect_params()
         for key, val in orig_params1.items():
             params2[key].set_data(copy.deepcopy(val.data()))
-
         trainer = gluon.Trainer(params2, 'sgd', {'learning_rate' : 0.03})
         with mx.autograd.record():
             res2 = layer(rnn_data, states)
