This is an automated email from the ASF dual-hosted git repository.
zhengda pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new d577b6f [MXNET-1352] Allow dynamic shape in while_loop and if conditionals (#14393)
d577b6f is described below
commit d577b6ff8256c9f9bb809d6144686a4e8273581a
Author: Junru Shao <[email protected]>
AuthorDate: Sun May 12 16:01:51 2019 -0700
[MXNET-1352] Allow dynamic shape in while_loop and if conditionals (#14393)
* Initial commit
* Rebase
* WIP for fixing rebase issues
* WIP for fixing rebase issues
* fix wip
* wip fix
* wip fix
* wip fix
* wip fix
* wip fix
* wip fix
* should be good to go
* wip remove debug info
* wip remove debug info
* linter
* linter
* Retrigger
* Address comments from Da
---
include/mxnet/ndarray.h | 4 +-
python/mxnet/executor.py | 2 +-
src/executor/graph_executor.cc | 215 +++++++++++++++++--
src/executor/graph_executor.h | 2 +
src/nnvm/plan_memory.cc | 2 +-
src/operator/control_flow.cc | 238 +++------------------
tests/python/unittest/test_contrib_control_flow.py | 30 ++-
7 files changed, 255 insertions(+), 238 deletions(-)
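For readers skimming the patch, here is a minimal sketch (not part of the patch) of the kind of graph this change targets: a while_loop whose output shapes cannot be determined at bind time, so the executor falls back to the new dynamic-shape path. It assumes the existing mx.sym.contrib.while_loop API; variable names and shapes are illustrative only.

    import mxnet as mx

    i = mx.sym.var("i")
    s = mx.sym.var("s")

    # The loop runs a data-dependent number of steps, so the leading axis of
    # the stacked output is unknown until the loop has actually executed.
    outputs, states = mx.sym.contrib.while_loop(
        cond=lambda i, s: i < 5,
        func=lambda i, s: ([s], [i + 1, s + 1]),
        loop_vars=[i, s],
        max_iterations=10)
    result = mx.sym.Group(outputs + states)

    exe = result.bind(ctx=mx.cpu(),
                      args={"i": mx.nd.zeros((1,)), "s": mx.nd.ones((2,))})
    exe.forward(is_train=False)
    # Output shapes are only known after forward has run.
    print([o.shape for o in exe.outputs])

Before this change, _while_loop carried an FInferShape implementation (WhileLoopShape, removed later in this patch) and required statically known output shapes; with the change, the executor resolves them at run time.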
diff --git a/include/mxnet/ndarray.h b/include/mxnet/ndarray.h
index 05d3fa4..340c380 100644
--- a/include/mxnet/ndarray.h
+++ b/include/mxnet/ndarray.h
@@ -180,7 +180,9 @@ class NDArray {
* \brief set the correct shape of NDArray directly from the storage_shape of its own chunk.
*/
void SetShapeFromChunk() {
- shape_ = ptr_->storage_shape;
+ if (!(ptr_->storage_shape.ndim() == 1 && ptr_->storage_shape[0] == 0)) {
+ shape_ = ptr_->storage_shape;
+ }
}
/*
* This indicates whether an array is a view of another array (created by
diff --git a/python/mxnet/executor.py b/python/mxnet/executor.py
index 9dfe636..edc10df 100644
--- a/python/mxnet/executor.py
+++ b/python/mxnet/executor.py
@@ -149,7 +149,7 @@ class Executor(object):
check_call(_LIB.MXExecutorForward(
self.handle,
ctypes.c_int(int(is_train))))
-
+ self.outputs = self._get_outputs()
return self.outputs
def backward(self, out_grads=None, is_train=True):
diff --git a/src/executor/graph_executor.cc b/src/executor/graph_executor.cc
index e726d29..da1f13b 100644
--- a/src/executor/graph_executor.cc
+++ b/src/executor/graph_executor.cc
@@ -44,6 +44,7 @@ using namespace mxnet::common;
GraphExecutor::GraphExecutor() {
log_verbose_ = dmlc::GetEnv("MXNET_EXEC_VERBOSE_LOGGING", false);
need_grad_ = false;
+ is_dynamic_ = false;
subgraph_property_ = dmlc::GetEnv("MXNET_SUBGRAPH_BACKEND", std::string());
engine_ref_ = Engine::_GetSharedRef();
}
@@ -76,20 +77,77 @@ void GraphExecutor::PartialForward(bool is_train, int step, int *step_left) {
}
void GraphExecutor::Backward(const std::vector<NDArray>& head_grads, bool is_train) {
- const auto& idx = graph_.indexed_graph();
- if (num_forward_inputs_ != idx.input_nodes().size()) {
- for (size_t i = 0; i < head_grad_array_.size(); ++i) {
- if (!head_grad_array_[i].is_none()) {
- CHECK(i < head_grads.size() && !head_grads[i].is_none())
- << "Because the last operator is not Loss function, "
- << "head_gradient is required when calling backward. "
- << "If you are attempting to minimize the output as "
- << "an objective, please modify your network and "
- << "pass it through the make_loss symbol.";
- CopyFromTo(head_grads[i], &(head_grad_array_[i]));
+ {
+ const auto& idx = graph_.indexed_graph();
+ if (num_forward_inputs_ != idx.input_nodes().size()) {
+ for (size_t i = 0; i < head_grad_array_.size(); ++i) {
+ if (!head_grad_array_[i].is_none()) {
+ CHECK(i < head_grads.size() && !head_grads[i].is_none())
+ << "Because the last operator is not Loss function, "
+ << "head_gradient is required when calling backward. "
+ << "If you are attempting to minimize the output as "
+ << "an objective, please modify your network and "
+ << "pass it through the make_loss symbol.";
+ const NDArray &from = head_grads[i];
+ NDArray &to = head_grad_array_[i];
+ if (this->is_dynamic_) {
+ to.WaitToRead();
+ if (!shape_is_known(to.shape())) {
+ to.Init(from.shape());
+ }
+ }
+ CopyFromTo(from, &to);
+ }
}
}
}
+ if (this->is_dynamic_) {
+ graph_ = InferShape(std::move(graph_), {}, "");
+ mxnet::ShapeVector rshape = graph_.MoveCopyAttr<mxnet::ShapeVector>("shape");
+ const auto& idx = graph_.indexed_graph();
+ for (size_t nid = 0; nid < idx.num_nodes(); ++nid) {
+ const auto& inode = idx[nid];
+ if (inode.source->is_variable()) continue;
+ OpNode& opnode = op_nodes_[nid];
+ if (opnode.skip_exec_node) continue;
+ for (NDArray &array : opnode.exec->in_array) {
+ array.WaitToRead();
+ if (!shape_is_known(array.shape())) {
+ array.SetShapeFromChunk();
+ }
+ }
+ int i = 0;
+ for (NDArray &array : opnode.exec->in_array) {
+ array.WaitToRead();
+ if (!shape_is_known(array.shape())) {
+ array.SetShapeFromChunk();
+ }
+ if (!shape_is_known(array.shape())) {
+ mxnet::TShape shape = rshape[idx.entry_id(inode.inputs[i])];
+ if (shape_is_known(shape)) {
+ array.ReshapeAndAlloc(shape);
+ }
+ }
+ ++i;
+ }
+ i = 0;
+ for (NDArray &array : opnode.exec->out_array) {
+ array.WaitToRead();
+ if (!shape_is_known(array.shape())) {
+ array.SetShapeFromChunk();
+ }
+ if (!shape_is_known(array.shape())) {
+ mxnet::TShape shape = rshape[idx.entry_id(nid, i)];
+ if (shape_is_known(shape)) {
+ array.ReshapeAndAlloc(shape);
+ }
+ }
+ ++i;
+ }
+ }
+ graph_.attrs["shape"] = std::make_shared<dmlc::any>(rshape);
+ }
+ const auto& idx = graph_.indexed_graph();
RunOps(is_train, num_forward_nodes_, idx.num_nodes());
}
@@ -119,6 +177,14 @@ void GraphExecutor::SetMonitorCallback(const MonitorCallback& callback, bool mon
}
const std::vector<NDArray>& GraphExecutor::outputs() const {
+ if (this->is_dynamic_) {
+ for (const NDArray &array : output_arrays_) {
+ array.WaitToRead();
+ if (!shape_is_known(array.shape())) {
+ const_cast<NDArray &>(array).SetShapeFromChunk();
+ }
+ }
+ }
return output_arrays_;
}
@@ -381,8 +447,7 @@ void GraphExecutor::Init(nnvm::Symbol symbol,
arg_shapes.resize(idx.input_nodes().size(), mxnet::TShape());
g = InferShape(std::move(g), std::move(arg_shapes), "__shape__");
if (g.GetAttr<size_t>("shape_num_unknown_nodes") != 0U) {
- HandleInferShapeError(num_forward_inputs_, g.indexed_graph(),
- g.GetAttr<mxnet::ShapeVector>("shape"));
+ this->is_dynamic_ = true;
}
arg_dtypes.resize(idx.input_nodes().size(), -1);
@@ -821,8 +886,7 @@ Executor* GraphExecutor::Reshape(const bool partial_shaping,
}
g = InferShape(std::move(g), std::move(arg_shapes), "__shape__");
if (g.GetAttr<size_t>("shape_num_unknown_nodes") != 0U) {
- HandleInferShapeError(num_forward_inputs_, g.indexed_graph(),
- g.GetAttr<mxnet::ShapeVector>("shape"));
+ this->is_dynamic_ = true;
}
const mxnet::ShapeVector& shape_vec = g.GetAttr<mxnet::ShapeVector>("shape");
std::vector<OpReqType> grad_req_types;
@@ -977,14 +1041,16 @@ void GraphExecutor::InitDataEntryMemory(std::vector<NDArray>* shared_pool) {
uint32_t oid = head_grad_map_.at(idx[nid].source);
uint32_t eid = idx.entry_id(idx.outputs()[oid]);
NDArrayStorageType stype = (NDArrayStorageType) vstorage_type[eid];
- CHECK(mxnet::shape_is_known(vshape[eid]));
+ bool unknown_shape = !shape_is_known(vshape[eid]);
CHECK_NE(vdtype[eid], -1);
auto data_eid = idx.entry_id(nid, 0);
// initialize based on storage_type
if (stype != kDefaultStorage) {
data_entry_[data_eid] = NDArray(stype, vshape[eid], data_context[eid],
true, vdtype[eid]);
- } else {
+ } else if (!unknown_shape) {
data_entry_[data_eid] = NDArray(vshape[eid], data_context[eid], false,
vdtype[eid]);
+ } else {
+ data_entry_[data_eid] = NDArray(data_context[eid], vdtype[eid]);
}
if (log_verbose_) {
LOG(INFO) << "\tinit head_grad entry\t" << data_eid << "\tas "
@@ -994,7 +1060,11 @@ void GraphExecutor::InitDataEntryMemory(std::vector<NDArray>* shared_pool) {
// get maximum bytes in each pool
for (size_t i = 0; i < vshape.size(); ++i) {
if (!data_entry_[i].is_none()) continue;
- size_t bytes = vshape[i].Size() * mshadow::mshadow_sizeof(vdtype[i]);
+ size_t shape_size = 0;
+ if (shape_is_known(vshape[i])) {
+ shape_size = vshape[i].Size();
+ }
+ size_t bytes = shape_size * mshadow::mshadow_sizeof(vdtype[i]);
int storage_id = vstorage[i];
// skip pool allocation for kBadStorageID, kExternalStorageID and kDynamicStorageID
if (storage_id < 0) continue;
@@ -1013,7 +1083,10 @@ void GraphExecutor::InitDataEntryMemory(std::vector<NDArray>* shared_pool) {
std::multimap<size_t, NDArray> free_pool;
if (shared_pool != nullptr) {
for (const NDArray& nd : *shared_pool) {
- size_t bytes = nd.shape().Size() * mshadow::mshadow_sizeof(nd.dtype());
+ size_t bytes = 0;
+ if (shape_is_known(nd.shape())) {
+ bytes = nd.shape().Size() * mshadow::mshadow_sizeof(nd.dtype());
+ }
free_pool.insert(std::make_pair(bytes, nd));
}
}
@@ -1067,9 +1140,13 @@ void GraphExecutor::InitDataEntryMemory(std::vector<NDArray>* shared_pool) {
int storage_id = vstorage[i];
auto storage_type = (NDArrayStorageType) vstorage_type[i];
if (storage_type == kDefaultStorage) {
- CHECK_GE(storage_id, 0) << "Do not support runtime shape op yet";
- const NDArray& src = data_pool_.at(storage_id);
- data_entry_[i] = src.AsArray(vshape[i], vdtype[i]);
+ if (!shape_is_known(vshape[i])) {
+ data_entry_[i] = NDArray(data_context[i], vdtype[i]);
+ } else {
+ CHECK_GE(storage_id, 0) << "Do not support runtime shape op yet";
+ const NDArray& src = data_pool_.at(storage_id);
+ data_entry_[i] = src.AsArray(vshape[i], vdtype[i]);
+ }
} else {
data_entry_[i] = NDArray(storage_type, vshape[i], data_context[i],
true, vdtype[i]);
@@ -1209,7 +1286,10 @@ void GraphExecutor::InitOpSegs() {
const profiler::Profiler *prof = profiler::Profiler::Get();
bool prefer_bulk_exec_train = Imperative::PreferBulkExecTrain()
&& (!prof || !prof->AggregateEnabled());
-
+ if (this->is_dynamic_) {
+ prefer_bulk_exec_inference = false;
+ prefer_bulk_exec_train = false;
+ }
bool is_training = num_forward_nodes_ != total_num_nodes;
if (prefer_bulk_exec_train && is_training) {
@@ -1300,6 +1380,8 @@ void GraphExecutor::ExecuteMonOutputCallback(size_t nid) {
}
void GraphExecutor::RunOps(bool is_train, size_t topo_start, size_t topo_end) {
+ static auto& finfer_shape = nnvm::Op::GetAttr<mxnet::FInferShape>("FInferShape");
+ static auto& is_backward = Op::GetAttr<nnvm::TIsBackward>("TIsBackward");
// Update context
const auto& idx = graph_.indexed_graph();
for (size_t nid = topo_start; nid < topo_end; ++nid) {
@@ -1311,6 +1393,7 @@ void GraphExecutor::RunOps(bool is_train, size_t topo_start, size_t topo_end) {
opnode.exec->op_ctx.need_grad = need_grad_;
}
+ mxnet::ShapeVector rshape = graph_.MoveCopyAttr<mxnet::ShapeVector>("shape");
// Push Ops
for (size_t nid = topo_start; nid < topo_end; ++nid) {
auto seg_op = cached_seg_opr_[nid];
@@ -1323,6 +1406,8 @@ void GraphExecutor::RunOps(bool is_train, size_t topo_start, size_t topo_end) {
}
// Normal mode
const auto& inode = idx[nid];
+ const uint32_t num_inputs = inode.inputs.size();
+ const uint32_t num_outputs = inode.source->num_outputs();
if (inode.source->is_variable()) continue;
OpNode& opnode = op_nodes_[nid];
if (op_nodes_[nid].skip_exec_node) continue;
@@ -1330,6 +1415,69 @@ void GraphExecutor::RunOps(bool is_train, size_t topo_start, size_t topo_end) {
if (monitor_callback_ && monitor_all_) {
ExecuteMonInputCallback(nid);
}
+ if (this->is_dynamic_) {
+ const auto &op = inode.source->op();
+ {
+ for (NDArray &array : opnode.exec->in_array) {
+ array.WaitToRead();
+ if (!shape_is_known(array.shape())) {
+ array.SetShapeFromChunk();
+ }
+ }
+ int i = 0;
+ for (NDArray &array : opnode.exec->out_array) {
+ array.WaitToRead();
+ if (!shape_is_known(array.shape())) {
+ array.SetShapeFromChunk();
+ }
+ if (!shape_is_known(array.shape())) {
+ mxnet::TShape shape = rshape[idx.entry_id(nid, i)];
+ if (shape_is_known(shape)) {
+ array.ReshapeAndAlloc(shape);
+ }
+ }
+ ++i;
+ }
+ }
+ if (finfer_shape.count(op)) {
+ mxnet::ShapeVector in_shapes;
+ mxnet::ShapeVector out_shapes;
+ for (NDArray &array : opnode.exec->in_array) {
+ in_shapes.push_back(array.shape());
+ }
+ for (NDArray &array : opnode.exec->out_array) {
+ out_shapes.push_back(array.shape());
+ }
+ auto finfer = finfer_shape[op];
+ try {
+ bool success = finfer(inode.source->attrs, &in_shapes, &out_shapes);
+ CHECK(success) << "InferShape failed in operator " << inode.source->attrs.name;
+ } catch (const std::exception& e) {
+ throw dmlc::Error("Error in operator " + inode.source->attrs.name + ": " + e.what());
+ }
+ int n_out = out_shapes.size();
+ for (int i = 0; i < n_out; ++i) {
+ NDArray &array = opnode.exec->out_array[i];
+ if (!shape_is_known(array.shape())) {
+ array.Init(out_shapes[i]);
+ }
+ }
+ } else if (is_backward.get(inode.source->op(), false) && inode.control_deps.size()) {
+ CHECK_GE(inode.control_deps.size(), 1U) <<
+ "BackwardOp need to have control_deps to its forward op";
+ uint32_t fid = inode.control_deps[0];
+ const OpNode& fopnode = op_nodes_[fid];
+ CHECK_EQ(fopnode.exec->in_array.size(), opnode.exec->out_array.size());
+ int nelem = fopnode.exec->in_array.size();
+ std::vector<NDArray> &from = fopnode.exec->in_array;
+ std::vector<NDArray> &to = opnode.exec->out_array;
+ for (int i = 0; i < nelem; ++i) {
+ if (!shape_is_known(to[i].shape())) {
+ to[i].Init(from[i].shape());
+ }
+ }
+ }
+ }
opnode.exec->op_ctx.is_train = is_train;
opnode.exec->op_ctx.need_grad = need_grad_;
if (opnode.exec->exec_type() == ExecType::kCrossDeviceCopy) {
@@ -1343,14 +1491,35 @@ void GraphExecutor::RunOps(bool is_train, size_t topo_start, size_t topo_end) {
} else if (opnode.cached_opr != nullptr) {
bool profiling = profiler::Profiler::Get()->GetState() == profiler::Profiler::kRunning;
Engine::Get()->Push(opnode.cached_opr, opnode.ctx, 0, profiling);
+ if (this->is_dynamic_) {
+ for (NDArray &array : opnode.exec->out_array) {
+ array.WaitToRead();
+ if (!shape_is_known(array.shape())) {
+ array.SetShapeFromChunk();
+ }
+ }
+ }
} else {
LOG(FATAL) << "Not accessed";
}
+ for (uint32_t i = 0; i < num_inputs; ++i) {
+ int eid = idx.entry_id(inode.inputs[i]);
+ if (!shape_is_known(rshape[eid])) {
+ rshape[eid] = opnode.exec->in_array[i].shape();
+ }
+ }
+ for (uint32_t i = 0; i < num_outputs; ++i) {
+ int eid = idx.entry_id(nid, i);
+ if (!shape_is_known(rshape[eid])) {
+ rshape[eid] = opnode.exec->out_array[i].shape();
+ }
+ }
// Monitor callbacks
if (monitor_callback_) {
ExecuteMonOutputCallback(nid);
}
}
+ graph_.attrs["shape"] = std::make_shared<dmlc::any>(rshape);
}
GraphExecutor::CachedSegOpr GraphExecutor::CreateCachedSegOpr(size_t topo_start, size_t topo_end) {
diff --git a/src/executor/graph_executor.h b/src/executor/graph_executor.h
index 9a86609..f150165 100644
--- a/src/executor/graph_executor.h
+++ b/src/executor/graph_executor.h
@@ -216,6 +216,8 @@ class GraphExecutor : public Executor {
void ExecuteMonOutputCallback(size_t nid);
// perform bulking and segmentation on the region [from_node, up_to_node) of a graph
void BulkOpSegs(size_t from_node, size_t up_to_node, size_t segment_num_nodes_max);
+ // When infer shape fails, fall back to ensure dynamic-shaped operators are executed correctly.
+ bool is_dynamic_;
// indicate whether there is a backward graph for gradients.
bool need_grad_;
// internal graph
diff --git a/src/nnvm/plan_memory.cc b/src/nnvm/plan_memory.cc
index 41b8559..ac70848 100644
--- a/src/nnvm/plan_memory.cc
+++ b/src/nnvm/plan_memory.cc
@@ -268,7 +268,7 @@ size_t AllocMemory(const Graph& ret, const IndexedGraph& idx,
// only request memory for kBadStorageID
if (storage[eid] == GraphAllocator::kBadStorageID) {
auto &eshape = shape_vec[eid];
- size_t esize = eshape.Size();
+ size_t esize = ndim_is_known(shape_vec[eid]) ? eshape.Size() : 0;
eids.insert(std::make_pair(esize, eid));
}
}
diff --git a/src/operator/control_flow.cc b/src/operator/control_flow.cc
index 4c0d67b..fd087ef 100644
--- a/src/operator/control_flow.cc
+++ b/src/operator/control_flow.cc
@@ -243,7 +243,6 @@ static void ForeachGradComputeExCPU(const OpStatePtr& state_ptr,
// the user to write state gradients to the outputs.
subg_req[loc] = iter_num != 0 ? kWriteTo : req[i + params.in_data_locs.ndim()];
}
-
state.Backward(iter_num, subg_ograds, subg_req, subg_igrads);
size_t num_states = subg_ograds.size() - num_output_data;
@@ -579,8 +578,6 @@ static void WhileLoopComputeExCPU(const OpStatePtr& state_ptr,
CHECK_EQ(inputs.size() + 2U, (size_t) params.num_args);
CHECK_EQ(outputs.size(), (size_t) params.num_outputs);
CHECK_EQ(outputs.size(), req.size());
- for (size_t i = 0; i < (size_t) params.num_out_data; i++)
- CHECK_EQ(params.max_iterations, outputs[i].shape()[0]);
// construct inputs and outputs for cond
std::vector<NDArray> cond_inputs, cond_outputs = {NDArray()};
extract_by_loc(inputs, params.cond_input_locs, &cond_inputs);
@@ -596,20 +593,33 @@ static void WhileLoopComputeExCPU(const OpStatePtr& state_ptr,
break;
}
// we create func_outputs for the current step:
- // func_outputs[0: num_out_data] is a slice of outputs[][step]
- for (size_t i = 0; i < (size_t) params.num_out_data; ++i) {
- func_outputs[i] = outputs[i].At(step);
- }
- // func_outputs[num_out_data: ] are new_loop_vars, need to allocate new memory
- for (size_t i = params.num_out_data; i < outputs.size(); ++i) {
- func_outputs[i] = NDArray(outputs[i].shape(), outputs[i].ctx(), true, outputs[i].dtype());
+ for (size_t i = 0; i < outputs.size(); ++i) {
+ func_outputs[i] = NDArray(outputs[i].ctx(), outputs[i].dtype());
}
state.Forward(step, func_inputs, req, func_outputs, ctx.need_grad);
+ if (step == 0) {
+ for (int i = 0; i < params.num_out_data; ++i) {
+ func_outputs[i].WaitToRead();
+ if (!shape_is_known(func_outputs[i].shape())) {
+ func_outputs[i].SetShapeFromChunk();
+ }
+ mxnet::TShape step_shape = func_outputs[i].shape();
+ mxnet::TShape shape(step_shape.ndim() + 1, 0);
+ shape[0] = params.max_iterations;
+ for (int j = 0; j < step_shape.ndim(); ++j) {
+ shape[j + 1] = step_shape[j];
+ }
+ const_cast<NDArray &>(outputs[i]).Init(shape);
+ }
+ }
+ for (int i = 0; i < params.num_out_data; ++i) {
+ NDArray first_slot = outputs[i].At(step);
+ mxnet::CopyFromTo(func_outputs[i], &first_slot);
+ }
// func_inputs on the next step:
// the output (new_loop_vars) will become the new inputs (loop_vars)
for (size_t i = params.num_out_data; i < outputs.size(); ++i) {
- size_t j = params.func_var_locs[i - params.num_out_data];
- CHECK_EQ(func_inputs[j].shape(), func_outputs[i].shape());
+ int j = params.func_var_locs[i - params.num_out_data];
func_inputs[j] = func_outputs[i];
int k = state.oi_map[i - params.num_out_data];
if (k != -1) {
@@ -627,8 +637,21 @@ static void WhileLoopComputeExCPU(const OpStatePtr& state_ptr,
// therefore, we copy func_inputs[:] to outputs[num_out_data: ]
for (size_t i = params.num_out_data; i < outputs.size(); ++i) {
size_t j = params.func_var_locs[i - params.num_out_data];
+ if (!shape_is_known(outputs[i].shape())) {
+ const_cast<NDArray &>(outputs[i]).Init(func_inputs[j].shape());
+ }
mxnet::CopyFromTo(func_inputs[j], &outputs[i]);
}
+ for (int i = 0; i < params.num_out_data; ++i) {
+ const_cast<NDArray &>(outputs[i]).SetShapeFromChunk();
+ }
+ if (state.n_iterations == 0) {
+ for (size_t i = 0; i < outputs.size(); ++i) {
+ if (!shape_is_known(outputs[i].shape())) {
+ const_cast<NDArray &>(outputs[i]).ReshapeAndAlloc({1});
+ }
+ }
+ }
}
static void WhileLoopGradComputeExCPU(const OpStatePtr& state_ptr,
@@ -726,108 +749,6 @@ static void WhileLoopGradComputeExCPU(const OpStatePtr& state_ptr,
state.Cleanup();
}
-static bool WhileLoopShape(const nnvm::NodeAttrs& attrs,
- mxnet::ShapeVector *in_shape,
- mxnet::ShapeVector *out_shape) {
- using mxnet::ShapeVector;
- const WhileLoopParam& params = nnvm::get<WhileLoopParam>(attrs.parsed);
- static const std::function<bool(const mxnet::TShape &)> is_udf = is_shape_udf;
- // sanity checks
- CHECK_EQ(in_shape->size() + 2U, (size_t) params.num_args);
- CHECK_EQ(out_shape->size(), (size_t) params.num_outputs);
- CHECK_EQ(attrs.subgraphs.size(), 2U);
- CHECK_EQ(attrs.subgraphs[0]->outputs.size(), 1U);
- // infer shape for cond and func
- auto infer_subg = [&params, in_shape, out_shape](std::shared_ptr<Symbol> subg,
- ShapeVector *_subg_out,
- const mxnet::Tuple<dim_t> &input_locs,
- int num_out_data,
- bool fill_out_shape) {
- // create subg_in
- ShapeVector subg_in;
- ShapeVector &subg_out = *_subg_out;
- extract_by_loc(*in_shape, input_locs, &subg_in);
- // create an indexed graph
- nnvm::Graph g;
- g.outputs = subg->outputs;
- const auto& idx = g.indexed_graph();
- // get input nodes
- const auto &input_nids = idx.input_nodes();
- // sanity checks
- CHECK_EQ(input_nids.size(), subg_in.size());
- CHECK_EQ(g.outputs.size(), subg_out.size());
- CHECK_EQ(idx.input_nodes().size(), subg_in.size());
- CHECK_EQ(idx.outputs().size(), subg_out.size());
- // create empty shapes for inference
- ShapeVector shapes(idx.num_node_entries());
- // copy subg_in into shapes
- for (size_t i = 0; i < subg_in.size(); ++i) {
- auto eid = idx.entry_id(input_nids[i], 0);
- shapes[eid] = subg_in[i];
- }
- // copy subg_out into shapes
- // note that ndim of out_data is not increased
- // because subg is only one step
- for (size_t i = 0; i < subg_out.size(); ++i) {
- auto eid = idx.entry_id(g.outputs[i]);
- shapes[eid] = subg_out[i];
- }
- // copy done, call InferShape
- g.attrs["shape"] = std::make_shared<dmlc::any>(std::move(shapes));
- g = exec::InferShape(std::move(g));
- // now `shapes' won't be used anymore, use new_shapes instead
- const auto& new_shapes = g.GetAttr<mxnet::ShapeVector>("shape");
- // copy subg_in back to in_shape
- for (size_t i = 0; i < subg_in.size(); ++i) {
- auto eid = idx.entry_id(input_nids[i], 0);
- auto g_out_shape = new_shapes[eid];
- if (!shape_is_known(g_out_shape)) {
- // when the shape is not fully inferred
- continue;
- }
- SHAPE_ASSIGN_CHECK(*in_shape, input_locs[i], g_out_shape);
- }
- if (!fill_out_shape) {
- return true;
- }
- // copy subg_out back to out_shape
- // for results in [0, num_out_data), ndim should increase by 1
- for (int i = 0; i < num_out_data; ++i) {
- auto eid = idx.entry_id(g.outputs[i]);
- auto g_out_shape = new_shapes[eid];
- if (!shape_is_known(g_out_shape)) {
- // when the shape is not fully inferred
- continue;
- }
- auto out = mxnet::TShape(g_out_shape.ndim() + 1, -1);
- out[0] = params.max_iterations;
- for (int i = 1; i < out.ndim(); i++)
- out[i] = g_out_shape[i - 1];
- SHAPE_ASSIGN_CHECK(*out_shape, i, out);
- }
- // for results in [num_out_data, ...), ndim does not change
- for (size_t i = num_out_data; i < g.outputs.size(); ++i) {
- auto eid = idx.entry_id(g.outputs[i]);
- auto g_out_shape = new_shapes[eid];
- if (!shape_is_known(g_out_shape)) {
- // when the shape is not fully inferred
- continue;
- }
- SHAPE_ASSIGN_CHECK(*out_shape, i, g_out_shape);
- }
- return g.GetAttr<size_t>("shape_num_unknown_nodes") == 0;
- };
- mxnet::ShapeVector cond_out_shape{mxnet::TShape(1, 1)}; // this means: [(1, )]
- mxnet::ShapeVector func_out_shape(params.num_outputs);
- CHECK(params.sync_in_out(in_shape, out_shape, is_udf));
- bool succ_0 = infer_subg(attrs.subgraphs[0], &cond_out_shape, params.cond_input_locs, 0, false);
- CHECK(params.sync_in_out(in_shape, out_shape, is_udf));
- bool succ_1 = infer_subg(attrs.subgraphs[1], &func_out_shape, \
- params.func_input_locs, params.num_out_data, true);
- CHECK(params.sync_in_out(in_shape, out_shape, is_udf));
- return succ_0 && succ_1;
-}
-
static bool WhileLoopType(const nnvm::NodeAttrs& attrs,
std::vector<int> *in_type, std::vector<int> *out_type) {
const WhileLoopParam& params = nnvm::get<WhileLoopParam>(attrs.parsed);
@@ -1033,93 +954,6 @@ static void CondGradComputeExCPU(const OpStatePtr& state_ptr,
loop_state.Cleanup();
}
-static bool CondShape(const nnvm::NodeAttrs& attrs,
- mxnet::ShapeVector *in_shape,
- mxnet::ShapeVector *out_shape) {
- using mxnet::ShapeVector;
- const CondParam& params = nnvm::get<CondParam>(attrs.parsed);
- static const std::function<bool(const mxnet::TShape &)> is_udf = is_shape_udf;
- // sanity checks
- CHECK_EQ(in_shape->size() + 3U, (size_t) params.num_args);
- CHECK_EQ(out_shape->size(), (size_t) params.num_outputs);
- CHECK_EQ(attrs.subgraphs.size(), 3U);
- CHECK_EQ(attrs.subgraphs[0]->outputs.size(), 1U);
- CHECK_EQ(attrs.subgraphs[1]->outputs.size(), attrs.subgraphs[2]->outputs.size());
- // infer shape for cond, then and else
- auto infer_subg = [&params, in_shape, out_shape](std::shared_ptr<Symbol> subg,
- ShapeVector *_subg_out,
- const mxnet::Tuple<dim_t> &input_locs,
- bool fill_out_shape) {
- // create subg_in
- mxnet::ShapeVector subg_in;
- mxnet::ShapeVector &subg_out = *_subg_out;
- extract_by_loc(*in_shape, input_locs, &subg_in);
- // create an indexed graph
- nnvm::Graph g;
- g.outputs = subg->outputs;
- const auto& idx = g.indexed_graph();
- // get input nodes
- const auto &input_nids = idx.input_nodes();
- // sanity checks
- CHECK_EQ(input_nids.size(), subg_in.size());
- CHECK_EQ(g.outputs.size(), subg_out.size());
- CHECK_EQ(idx.input_nodes().size(), subg_in.size());
- CHECK_EQ(idx.outputs().size(), subg_out.size());
- // create empty shapes for inference
- mxnet::ShapeVector shapes(idx.num_node_entries());
- // copy subg_in into shapes
- for (size_t i = 0; i < subg_in.size(); ++i) {
- auto eid = idx.entry_id(input_nids[i], 0);
- shapes[eid] = subg_in[i];
- }
- // copy subg_out into shapes
- for (size_t i = 0; i < subg_out.size(); ++i) {
- auto eid = idx.entry_id(g.outputs[i]);
- shapes[eid] = subg_out[i];
- }
- // copy done, call InferShape
- g.attrs["shape"] = std::make_shared<dmlc::any>(std::move(shapes));
- g = exec::InferShape(std::move(g));
- // now `shapes' won't be used anymore, use new_shapes instead
- const auto& new_shapes = g.GetAttr<mxnet::ShapeVector>("shape");
- // copy subg_in back to in_shape
- for (size_t i = 0; i < subg_in.size(); ++i) {
- auto eid = idx.entry_id(input_nids[i], 0);
- auto g_out_shape = new_shapes[eid];
- if (!shape_is_known(g_out_shape)) {
- // when the shape is not fully inferred
- continue;
- }
- SHAPE_ASSIGN_CHECK(*in_shape, input_locs[i], g_out_shape);
- }
- if (!fill_out_shape) {
- return true;
- }
- // copy subg_out back to out_shape
- for (size_t i = 0; i < g.outputs.size(); ++i) {
- auto eid = idx.entry_id(g.outputs[i]);
- auto g_out_shape = new_shapes[eid];
- if (!shape_is_known(g_out_shape)) {
- // when the shape is not fully inferred
- continue;
- }
- SHAPE_ASSIGN_CHECK(*out_shape, i, g_out_shape);
- }
- return g.GetAttr<size_t>("shape_num_unknown_nodes") == 0;
- };
- ShapeVector cond_out_shape{mxnet::TShape(1, 1)}; // this means: [(1, )]
- ShapeVector then_out_shape(params.num_outputs);
- ShapeVector else_out_shape(params.num_outputs);
- bool succ_0 = infer_subg(attrs.subgraphs[0], &cond_out_shape, \
- params.cond_input_locs, false);
- bool succ_1 = infer_subg(attrs.subgraphs[1], &then_out_shape, \
- params.then_input_locs, true);
- bool succ_2 = infer_subg(attrs.subgraphs[2], &else_out_shape, \
- params.else_input_locs, true);
- sync_out_out(&then_out_shape, &else_out_shape, is_udf);
- return succ_0 && succ_1 && succ_2;
-}
-
static bool CondType(const nnvm::NodeAttrs& attrs,
std::vector<int> *in_type,
std::vector<int> *out_type) {
@@ -1342,7 +1176,6 @@ NNVM_REGISTER_OP(_while_loop)
})
.set_attr<nnvm::FGradient>("FGradient", WhileLoopGradient)
.set_attr<FCreateOpState>("FCreateOpState", CreateWhileLoopState)
-.set_attr<mxnet::FInferShape>("FInferShape", WhileLoopShape)
.set_attr<nnvm::FInferType>("FInferType", WhileLoopType)
.set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", WhileLoopComputeExCPU)
.set_attr<FExecType>("FExecType", [](const NodeAttrs& attrs) {
@@ -1405,7 +1238,6 @@ NNVM_REGISTER_OP(_cond)
})
.set_attr<nnvm::FGradient>("FGradient", CondGradient)
.set_attr<FCreateOpState>("FCreateOpState", CreateCondState)
-.set_attr<mxnet::FInferShape>("FInferShape", CondShape)
.set_attr<nnvm::FInferType>("FInferType", CondType)
.set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", CondComputeExCPU)
.set_attr<FExecType>("FExecType", [](const NodeAttrs& attrs) {
diff --git a/tests/python/unittest/test_contrib_control_flow.py b/tests/python/unittest/test_contrib_control_flow.py
index dd5a4d6..a93c109 100644
--- a/tests/python/unittest/test_contrib_control_flow.py
+++ b/tests/python/unittest/test_contrib_control_flow.py
@@ -935,14 +935,27 @@ def test_while_loop_rnn():
max_iterations=seq_len
)
result = mx.sym.Group(result[0] + result[1][1: ])
- arg_shapes, _, _ = result.infer_shape(
- data=(seq_len, batch_size, input_dim),
- s_0=(batch_size, hidden_dim),
- )
rnn_inputs = result.list_inputs()
- args = {name: _array(arg_shapes[i]) for i, name in enumerate(rnn_inputs) if name != "i"}
- args["i"] = mx.nd.zeros([1])
- args_grad = {name: _array(arg_shapes[i]) for i, name in enumerate(rnn_inputs)}
+ args = {
+ "i": mx.nd.zeros([1]),
+ "data": _array((seq_len, batch_size, input_dim)),
+ "i2h_weight": _array((input_dim * hidden_dim, input_dim)),
+ "i2h_bias": _array((input_dim * hidden_dim, )),
+ "s_0": _array((batch_size, hidden_dim)),
+ "h2h_weight": _array((input_dim * hidden_dim, seq_len)),
+ "h2h_bias": _array((input_dim * hidden_dim, )),
+ "s_1": _array((batch_size, hidden_dim)),
+ }
+ args_grad = {
+ "i": _array([1]),
+ "data": _array((seq_len, batch_size, input_dim)),
+ "i2h_weight": _array((input_dim * hidden_dim, input_dim)),
+ "i2h_bias": _array((input_dim * hidden_dim, )),
+ "s_0": _array((batch_size, hidden_dim)),
+ "h2h_weight": _array((input_dim * hidden_dim, seq_len)),
+ "h2h_bias": _array((input_dim * hidden_dim, )),
+ "s_1": _array((batch_size, hidden_dim)),
+ }
e_1 = result.bind(ctx=default_context(),
args={name: array.copy() for name, array in args.items()},
args_grad={name: array.copy() for name, array in args_grad.items() if name != "i"},
@@ -961,9 +974,9 @@ def test_while_loop_rnn():
args_grad={name: array.copy() for name, array in args_grad.items() if name != "i"},
)
for case_id in range(100):
- out_grads = [_array(arr.shape) for arr in e_1.outputs]
args = {name: array.copy() for name, array in args.items()}
e_1.forward(is_train=True, **args)
+ out_grads = [_array(arr.shape) for arr in e_1.outputs]
e_1.backward(out_grads)
args = {name: array.copy() for name, array in args.items() if name != "i"}
e_2.forward(is_train=True, **args)
@@ -1178,7 +1191,6 @@ def check_contrib_rnn(cell_type, num_states):
params2 = layer.collect_params()
for key, val in orig_params1.items():
params2[key].set_data(copy.deepcopy(val.data()))
-
trainer = gluon.Trainer(params2, 'sgd', {'learning_rate' : 0.03})
with mx.autograd.record():
res2 = layer(rnn_data, states)
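A note on the reordering in test_while_loop_rnn above: with dynamic shapes the output arrays only receive their final shapes once forward has run, which is why the patch refreshes self.outputs inside Executor.forward and why out_grads is now built after e_1.forward rather than before it. A minimal illustration of the required ordering (assuming mxnet imported as mx and an executor exe already bound over such a dynamic-shape graph):

    exe.forward(is_train=True)                       # output shapes become known here
    out_grads = [mx.nd.ones(o.shape) for o in exe.outputs]
    exe.backward(out_grads)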