This is an automated email from the ASF dual-hosted git repository. jxie pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push: new 7fc0396 fix group2ctx with null reqs (#8717) 7fc0396 is described below commit 7fc039639b288f80fa7fe6482de1a25e04261e5e Author: Haibin Lin <linhaibin.e...@gmail.com> AuthorDate: Sun Nov 19 21:55:31 2017 -0800 fix group2ctx with null reqs (#8717) --- src/executor/graph_executor.cc | 18 +++++++++++++---- tests/python/unittest/test_multi_device_exec.py | 26 +++++++++++++++++-------- 2 files changed, 32 insertions(+), 12 deletions(-) diff --git a/src/executor/graph_executor.cc b/src/executor/graph_executor.cc index ade8e83..01484da 100644 --- a/src/executor/graph_executor.cc +++ b/src/executor/graph_executor.cc @@ -321,6 +321,7 @@ Graph AssignContext(Graph g, const std::vector<Context>& in_arg_ctxes, const std::vector<Context>& arg_grad_ctxes, const std::vector<Context>& aux_state_ctxes, + const std::vector<OpReqType>& grad_req_types, size_t num_forward_inputs, size_t num_forward_outputs) { const auto& idx = g.indexed_graph(); @@ -385,9 +386,15 @@ Graph AssignContext(Graph g, // loop through backward input nodes and populate maps and lists // the backward input nodes is the gradient of the loss wrt the output - for (size_t i = num_forward_outputs; i < g.outputs.size(); ++i) { + size_t arg_grad_offset = 0; + // keep an offset into the arg_grad_ctxes vector, + // since g.outputs exclude arg_grad whose req == null + CHECK_GE(grad_req_types.size(), g.outputs.size() - num_forward_outputs) + << "insufficient number of grad_reqs"; + for (size_t i = num_forward_outputs; i < g.outputs.size(); ++i, ++arg_grad_offset) { + while (grad_req_types[arg_grad_offset] == kNullOp) ++arg_grad_offset; const uint32_t nid = idx.outputs()[i].node_id; - Context ctx = arg_grad_ctxes[i - num_forward_outputs]; + Context ctx = arg_grad_ctxes[arg_grad_offset]; if (ctx2id.count(ctx) == 0) { ctx2id[ctx] = static_cast<int>(ctx_list.size()); ctx_list.push_back(ctx); @@ -417,9 +424,11 @@ Graph AssignContext(Graph g, // if the 
assigned device of gradient node // corresponds to storage of grads auto &new_idx = g.indexed_graph(); - for (size_t i = num_forward_outputs; i < g.outputs.size(); ++i) { + arg_grad_offset = 0; + for (size_t i = num_forward_outputs; i < g.outputs.size(); ++i, ++arg_grad_offset) { + while (grad_req_types[arg_grad_offset] == kNullOp) ++arg_grad_offset; const uint32_t nid = new_idx.outputs()[i].node_id; - Context ctx = arg_grad_ctxes[i - num_forward_outputs]; + Context ctx = arg_grad_ctxes[arg_grad_offset]; CHECK(ctx == vcontext[nid]) << "Trying to save gradient to " << ctx << " while its source node \"" << new_idx[nid].source->attrs.name @@ -1055,6 +1064,7 @@ Graph GraphExecutor::InitGraph(nnvm::Symbol symbol, in_arg_ctxes, arg_grad_ctxes, aux_state_ctxes, + grad_req_types, num_forward_inputs_, num_forward_outputs_); diff --git a/tests/python/unittest/test_multi_device_exec.py b/tests/python/unittest/test_multi_device_exec.py index 0a2739d..aa279b1 100644 --- a/tests/python/unittest/test_multi_device_exec.py +++ b/tests/python/unittest/test_multi_device_exec.py @@ -20,6 +20,17 @@ import numpy as np import mxnet as mx def test_ctx_group(): + def check_ctx_group(group2ctx, grad_req, mlp, set_stage1): + texec = mlp.simple_bind(mx.cpu(0), + group2ctx=group2ctx, + data=(1,200), grad_req=grad_req) + + for arr, name in zip(texec.arg_arrays, mlp.list_arguments()): + if name in set_stage1: + assert arr.context == group2ctx['stage1'] + else: + assert arr.context == group2ctx['stage2'] + with mx.AttrScope(ctx_group='stage1'): data = mx.symbol.Variable('data') fc1 = mx.symbol.FullyConnected(data = data, name='fc1', num_hidden=128) @@ -40,15 +51,14 @@ def test_ctx_group(): 'stage2' : mx.cpu(2) } - texec = mlp.simple_bind(mx.cpu(0), - group2ctx=group2ctx, - data=(1,200)) + # generate reqs with null + grad_req_with_null = {} + for arg in mlp.list_arguments(): + grad_req_with_null[arg] = 'null' if arg == 'data' else 'write' - for arr, name in zip(texec.arg_arrays, 
mlp.list_arguments()): - if name in set_stage1: - assert arr.context == group2ctx['stage1'] - else: - assert arr.context == group2ctx['stage2'] + grad_reqs = ['write', grad_req_with_null] + for grad_req in grad_reqs: + check_ctx_group(group2ctx, grad_req, mlp, set_stage1) def test_ctx_group_sparse(): with mx.AttrScope(ctx_group='stage1'): -- To stop receiving notification emails like this one, please contact comm...@mxnet.apache.org.