zheng-da closed pull request #11059: Cut subgraph
URL: https://github.com/apache/incubator-mxnet/pull/11059
This pull request was merged from a forked repository. Because GitHub hides
the original diff of a pull request from a fork once it is merged, the diff
is reproduced below for provenance:
diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h
index 06e39bfeb38..79c92edfa0a 100644
--- a/include/mxnet/c_api.h
+++ b/include/mxnet/c_api.h
@@ -1056,6 +1056,28 @@ MXNET_DLL int MXSymbolListAtomicSymbolCreators(mx_uint *out_size,
*/
MXNET_DLL int MXSymbolGetAtomicSymbolName(AtomicSymbolCreator creator,
const char **name);
+
+/*!
+ * \brief Get the input symbols of the graph.
+ * \param sym The graph.
+ * \param inputs The input symbols of the graph.
+ * \param input_size the number of input symbols returned.
+ */
+MXNET_DLL int MXSymbolGetInputSymbols(SymbolHandle sym, SymbolHandle **inputs,
+ int *input_size);
+
+/*!
+ * \brief Cut a subgraph whose nodes are marked with a subgraph attribute.
+ * The input graph will be modified. A variable node will be created for each
+ * edge that connects to nodes outside the subgraph. The outside nodes that
+ * connect to the subgraph will be returned.
+ * \param sym The graph.
+ * \param inputs The nodes that connect to the subgraph.
+ * \param input_size The number of such nodes.
+ */
+MXNET_DLL int MXSymbolCutSubgraph(SymbolHandle sym, SymbolHandle **inputs,
+ int *input_size);
+
/*!
* \brief Get the detailed information about atomic symbol.
* \param creator the AtomicSymbolCreator.
diff --git a/include/mxnet/ndarray.h b/include/mxnet/ndarray.h
index e243eb71c47..897b5e882aa 100644
--- a/include/mxnet/ndarray.h
+++ b/include/mxnet/ndarray.h
@@ -693,6 +693,10 @@ class NDArray {
NDArray MKLDNNDataReshape(const TShape &shape) const;
#endif
+ const nnvm::NodeEntry &entry() const {
+ return entry_;
+ }
+
/*!
* \brief Save list of ndarray into the Stream.x
* \param fo The stream of output.
diff --git a/include/mxnet/op_attr_types.h b/include/mxnet/op_attr_types.h
index 3969d8445be..23a318464f1 100644
--- a/include/mxnet/op_attr_types.h
+++ b/include/mxnet/op_attr_types.h
@@ -64,8 +64,10 @@ enum OpReqType {
* \sa Resource
*/
struct OpContext {
+ /*! \brief whether there is a backward phase to compute gradients. */
+ bool need_grad;
/*! \brief whether it is training phase */
- int is_train;
+ bool is_train;
/*! \brief RunContext related resources */
RunContext run_ctx;
/*! \brief the callback when operation completes, used by asynchronize ops */
diff --git a/python/mxnet/ndarray/contrib.py b/python/mxnet/ndarray/contrib.py
index ba402e6f3f8..d68698f71d6 100644
--- a/python/mxnet/ndarray/contrib.py
+++ b/python/mxnet/ndarray/contrib.py
@@ -21,6 +21,8 @@
import math
from ..context import current_context
from ..random import uniform
+from ..base import _as_list
+from . import ndarray
try:
from .gen_contrib import *
except ImportError:
@@ -96,3 +98,96 @@ def rand_zipfian(true_classes, num_sampled, range_max, ctx=None):
expected_count_sampled = expected_prob_sampled * num_sampled
return sampled_classes, expected_count_true, expected_count_sampled
# pylint: enable=line-too-long
+
+def foreach(body, data, init_states):
+ """Run a for loop with user-defined computation over NDArrays on dimension
0.
+
+ This operator simulates a for loop and body has the computation for an
iteration
+ of the for loop. It runs the computation in body on each slice from the
input
+ NDArrays.
+
+ body takes two arguments as input and outputs a tuple of two elements,
+ as illustrated below:
+
+ out, states = body(data1, states)
+
+ data1 can be either an NDArray or a list of NDArrays. If data is an
NDArray,
+ data1 is an NDArray. Otherwise, data1 is a list of NDArrays and has the
same
+ size as data. states is a list of NDArrays and have the same size as
init_states.
+ Similarly, out can be either an NDArray or a list of NDArrays, which are
concatenated
+ as the first output of foreach; states from the last execution of body
+ are the second output of foreach.
+
+ The computation done by this operator is equivalent to the pseudo code
below
+ when the input data is NDArray:
+
+ states = init_states
+ outs = []
+ for i in data.shape[0]:
+ s = data[i]
+ out, states = body(s, states)
+ outs.append(out)
+ outs = stack(*outs)
+
+
+ Parameters
+ ----------
+ body : a Python function.
+ Define computation in an iteration.
+ data: an NDArray or a list of NDArrays.
+ The input data.
+ init_states: an NDArray or a list of NDArrays.
+ The initial values of the loop states.
+ name: string.
+ The name of the operator.
+
+ Returns
+ -------
+ outputs: an NDArray or a list of NDArrays.
+ The output data concatenated from the output of all iterations.
+ states: a list of NDArrays.
+ The loop states in the last iteration.
+
+ Examples
+ --------
+ >>> step = lambda data, states: (data + states[0], [states[0] * 2])
+ >>> data = mx.nd.random.uniform(shape=(2, 10))
+ >>> states = [mx.nd.random.uniform(shape=(10))]
+ >>> outs, states = mx.nd.contrib.foreach(step, data, states)
+ """
+
+ def check_input(inputs, in_type, msg):
+ is_NDArray_or_list = True
+ if isinstance(inputs, list):
+ for i in inputs:
+ if not isinstance(i, in_type):
+ is_NDArray_or_list = False
+ break
+ else:
+ is_NDArray_or_list = isinstance(inputs, in_type)
+ assert is_NDArray_or_list, msg
+
+ check_input(data, ndarray.NDArray, "data should be an NDArray or a list of
NDArrays")
+ check_input(init_states, ndarray.NDArray,
+ "init_states should be an NDArray or a list of NDArrays")
+
+ not_data_list = isinstance(data, ndarray.NDArray)
+ not_state_list = isinstance(init_states, ndarray.NDArray)
+ num_iters = data.shape[0] if not_data_list else data[0].shape[0]
+ states = init_states
+ outputs = []
+ for i in range(num_iters):
+ if not_data_list:
+ eles = data[i]
+ else:
+ eles = [d[i] for d in data]
+ outs, states = body(eles, states)
+ outs = _as_list(outs)
+ outputs.append(outs)
+ outputs = zip(*outputs)
+ for j, out in enumerate(outputs):
+ outputs[j] = ndarray.op.stack(*out)
+
+ if not_data_list:
+ outputs = outputs[0]
+ return (outputs, states)
diff --git a/python/mxnet/symbol/contrib.py b/python/mxnet/symbol/contrib.py
index 83e90e68732..a1a1d23bbbe 100644
--- a/python/mxnet/symbol/contrib.py
+++ b/python/mxnet/symbol/contrib.py
@@ -19,6 +19,9 @@
# pylint: disable=wildcard-import, unused-wildcard-import
"""Contrib Symbol API of MXNet."""
import math
+import ctypes
+import re
+
from .random import uniform
from .symbol import Symbol
try:
@@ -26,6 +29,11 @@
except ImportError:
pass
+from . import symbol
+from ..base import _LIB, c_array, check_call
+from ..base import SymbolHandle, _as_list
+from ..attribute import AttrScope
+
__all__ = ["rand_zipfian"]
def rand_zipfian(true_classes, num_sampled, range_max):
@@ -91,3 +99,196 @@ def rand_zipfian(true_classes, num_sampled, range_max):
expected_prob_sampled = ((sampled_cls_fp64 + 2.0) / (sampled_cls_fp64 +
1.0)).log() / log_range
expected_count_sampled = expected_prob_sampled * num_sampled
return sampled_classes, expected_count_true, expected_count_sampled
+
+def _get_graph_inputs(subg):
+ num_handles = ctypes.c_int(1000)
+ handles = c_array(SymbolHandle, [SymbolHandle(0) for i in range(1000)])
+ check_call(_LIB.MXSymbolGetInputSymbols(subg.handle, handles,
ctypes.byref(num_handles)))
+
+ syms = []
+ for i in range(num_handles.value):
+ s = Symbol(handles[i])
+ syms.append(s)
+ return syms
+
+def _cut_subgraph(subg):
+ num_handles = ctypes.c_int(1000)
+ handles = c_array(SymbolHandle, [SymbolHandle(0) for i in range(1000)])
+ check_call(_LIB.MXSymbolCutSubgraph(subg.handle, handles,
ctypes.byref(num_handles)))
+
+ syms = []
+ for i in range(num_handles.value):
+ s = Symbol(handles[i])
+ syms.append(s)
+ return syms
+
+def foreach(body, data, init_states, name="foreach"):
+ """Run a for loop with user-defined computation over Symbols on dimension
0.
+
+ This operator simulates a for loop and body has the computation for an
iteration
+ of the for loop. It runs the computation in body on each slice from the
input
+ NDArrays.
+
+ body takes two arguments as input and outputs a tuple of two elements,
+ as illustrated below:
+
+ out, states = body(data1, states)
+
+ data1 can be either a symbol or a list of symbols. If data is a symbol,
+ data1 is a symbol. Otherwise, data1 is a list of symbols and has the same
+ size as data. states is a list of symbols and have the same size as
init_states.
+ Similarly, out can be either a symbol or a list of symbols, which are
concatenated
+ as the first output of foreach; states from the last execution of body
+ are the second output of foreach.
+
+ The computation done by this operator is equivalent to the pseudo code
below
+ when the input data is NDArray:
+
+ states = init_states
+ outs = []
+ for i in data.shape[0]:
+ s = data[i]
+ out, states = body(s, states)
+ outs.append(out)
+ outs = stack(*outs)
+
+
+ Parameters
+ ----------
+ body : a Python function.
+ Define computation in an iteration.
+ data: a symbol or a list of symbols.
+ The input data.
+ init_states: a symbol or a list of symbols.
+ The initial values of the loop states.
+ name: string.
+ The name of the operator.
+
+ Returns
+ -------
+ outputs: a Symbol or a list of Symbols.
+ The output data concatenated from the output of all iterations.
+ states: a list of Symbols.
+ The loop states in the last iteration.
+
+ Examples
+ --------
+ >>> step = lambda data, states: (data + states[0], [states[0] * 2])
+ >>> data = mx.sym.var('data')
+ >>> states = [mx.sym.var('state')]
+ >>> outs, states = mx.sym.contrib.foreach(step, data, states)
+ """
+
+ def check_data(inputs, in_type, msg):
+ is_NDArray_or_list = True
+ if isinstance(inputs, list):
+ for i in inputs:
+ if not isinstance(i, in_type):
+ is_NDArray_or_list = False
+ break
+ else:
+ is_NDArray_or_list = isinstance(inputs, in_type)
+ assert is_NDArray_or_list, msg
+
+ check_data(data, symbol.Symbol, "data should be an NDArray or a list of
NDArrays")
+ check_data(init_states, symbol.Symbol,
+ "init_states should be an NDArray or a list of NDArrays")
+ not_state_list = isinstance(init_states, symbol.Symbol)
+
+ # TODO(zhengda) If the input python function references to the symbols
outside
+ # the python function, we need to prune the computation graph constructed
from
+ # the function. One way of doing it is to mark the nodes in the
computation graph
+ # with AttrScope and prune the nodes without the special attribute.
+ with AttrScope(subgraph_name=name):
+ if isinstance(data, list):
+ in_eles = [symbol.var(sym.name) for sym in data]
+ else:
+ in_eles = symbol.var(data.name)
+ if isinstance(init_states, list):
+ states = [symbol.var(s.name) for s in init_states]
+ else:
+ states = symbol.var(init_states.name)
+ sym_out, sym_states = body(in_eles, states)
+
+ check_data(sym_out, symbol.Symbol,
+ "the output should be an NDArray or a list of NDArrays")
+ check_data(sym_states, symbol.Symbol,
+ "the output states should be an NDArray or a list of NDArrays")
+ if isinstance(sym_states, list):
+ assert isinstance(init_states, list) and len(sym_states) ==
len(init_states), \
+ "the number of output states (%d) should be the same as
input states (%d)" \
+ % (len(sym_states), len(init_states))
+
+ if isinstance(sym_out, list):
+ flat_out = sym_out
+ else:
+ flat_out = [sym_out]
+ num_out_data = len(flat_out)
+ if isinstance(sym_states, list):
+ for s in sym_states:
+ # There is a problem if the outputs are the same as the inputs
+ # or the first output. By calling identity, we can make sure
that
+ # all symbols will refer to different NDArrays.
+ flat_out.append(symbol.op.identity(s))
+ else:
+ flat_out.append(symbol.op.identity(sym_states))
+ g = symbol.Group(flat_out)
+
+ cut_syms = _cut_subgraph(g)
+ input_syms = _get_graph_inputs(g)
+
+ # Here we need to find out how the input symbols are ordered as well as
+ # where the loop states are located in the list of inputs.
+
+ # This dict contains the symbols of the subgraph.
+ input_syms = {sym.name:sym for sym in input_syms}
+ gin_names = input_syms.keys()
+ # This array contains the symbols for the inputs of foreach.
+ # They are ordered according to the inputs of the subgraph.
+ states_map = {sym.name:sym for sym in init_states}
+ state_names = states_map.keys()
+ data_syms = _as_list(data)
+ data_map = {sym.name:sym for sym in data_syms}
+ data_names = data_map.keys()
+
+ ordered_ins = []
+ in_state_locs = []
+ in_data_locs = []
+ for in_name in g.list_inputs():
+ assert in_name in gin_names, "The input variable %s can't be found in
graph inputs: %s" \
+ % (in_name, str(gin_names))
+ if in_name in state_names:
+ ordered_ins.append(states_map[in_name])
+ in_state_locs.append(len(ordered_ins) - 1)
+ elif in_name in data_names:
+ ordered_ins.append(data_map[in_name])
+ in_data_locs.append(len(ordered_ins) - 1)
+ else:
+ # The remaining inputs are the ones cut from the original graph.
+ # The names of these variable nodes contain the index in cut_syms.
+ m = re.search(r'\d+$', in_name)
+ idx = int(m.group()) if m else None
+ assert idx < len(cut_syms)
+ ordered_ins.append(cut_syms[idx])
+
+ num_outputs = len(flat_out)
+ num_states = len(state_names)
+ ret = symbol._internal._foreach(g, *ordered_ins, num_outputs=num_outputs,
+ num_out_data=num_out_data,
in_state_locs=in_state_locs,
+ in_data_locs=in_data_locs)
+ if num_outputs - num_states > 1:
+ outs = []
+ for i in range(num_outputs - num_states):
+ outs.append(ret[i])
+ else:
+ outs = ret[0]
+ states = []
+ for i in range(num_states):
+ states.append(ret[num_outputs - num_states + i])
+
+ if not_state_list:
+ # If there is only one input state, there should be only one output
state.
+ assert len(states) == 1
+ states = states[0]
+
+ return (outs, states)
diff --git a/src/c_api/c_api_symbolic.cc b/src/c_api/c_api_symbolic.cc
index 4666b6adf0c..030ab432228 100644
--- a/src/c_api/c_api_symbolic.cc
+++ b/src/c_api/c_api_symbolic.cc
@@ -38,10 +38,11 @@ void RegisterLegacyOpProp();
void RegisterLegacyNDFunc();
}
const std::vector<std::string> kHiddenKeys = {
- "ctx_group", "lr_mult", "wd_mult", "force_mirroring", "mirror_stage"
+ "ctx_group", "lr_mult", "wd_mult", "force_mirroring", "mirror_stage",
"subgraph_name"
};
const std::vector<std::string> kReplacedHiddenKeys = {
- "__ctx_group__", "__lr_mult__", "__wd_mult__", "__force_mirroring__",
"__mirror_stage__"
+ "__ctx_group__", "__lr_mult__", "__wd_mult__", "__force_mirroring__",
"__mirror_stage__",
+ "subgraph_name"
};
const char *kNamespaceSeparator = "$";
@@ -344,6 +345,75 @@ int MXSymbolGetAtomicSymbolName(AtomicSymbolCreator creator,
API_END();
}
+namespace mxnet {
+
+extern std::vector<nnvm::Symbol *> GetInputSymbols(const nnvm::Symbol &sym);
+extern bool CutGraph(const std::vector<nnvm::NodeEntry *> &input_entries,
+ const std::string &in_name_prefix, bool skip_var,
+ std::vector<nnvm::NodeEntry> *orig_entries,
+ std::vector<std::string> *new_var_names);
+
+}
+
+int MXSymbolGetInputSymbols(SymbolHandle sym, SymbolHandle **input_arr, int
*input_size) {
+ API_BEGIN();
+ nnvm::Symbol *s = static_cast<nnvm::Symbol*>(sym);
+ size_t max_input_size = *input_size;
+ std::vector<nnvm::Symbol *> input_syms = mxnet::GetInputSymbols(*s);
+ CHECK(input_syms.size() <= max_input_size);
+ *input_size = input_syms.size();
+ memcpy(input_arr, input_syms.data(), sizeof(*input_arr) * input_syms.size());
+ API_END_HANDLE_ERROR();
+}
+
+int MXSymbolCutSubgraph(SymbolHandle sym, SymbolHandle **input_symbols,
+ int *input_size) {
+ // Given a graph, we want to fetch the nodes that have been marked as part of
+ // a subgraph.
+ API_BEGIN();
+ nnvm::Symbol *s = static_cast<nnvm::Symbol*>(sym);
+ size_t max_input_size = *input_size;
+ std::string subg_attr = "__subgraph_name__";
+ auto out_node = s->outputs[0].node;
+ auto it = out_node->attrs.dict.find(subg_attr);
+ if (it != out_node->attrs.dict.end()) {
+ std::string subg_name = it->second;
+ std::vector<nnvm::NodeEntry *> input_entries;
+ DFSVisit(s->outputs, [subg_attr, subg_name, &input_entries]
+ (nnvm::NodePtr n) {
+ // If the node itself isn't in the subgraph, we ignore it.
+ auto it = n->attrs.dict.find(subg_attr);
+ if (it == n->attrs.dict.end() || it->second != subg_name)
+ return;
+
+ // We search for nodes whose node entries aren't in the subgraph.
+ for (size_t j = 0; j < n->inputs.size(); j++) {
+ auto in_node = n->inputs[j].node;
+ auto it = in_node->attrs.dict.find(subg_attr);
+ if (it == in_node->attrs.dict.end() || it->second != subg_name)
+ input_entries.push_back(&n->inputs[j]);
+ }
+ });
+
+ std::vector<nnvm::NodeEntry> orig_entries;
+ std::vector<std::string> new_var_names;
+ CutGraph(input_entries, subg_name + "_var", false, &orig_entries,
&new_var_names);
+
+ std::vector<nnvm::Symbol *> input_syms(orig_entries.size());
+ for (size_t i = 0; i < input_syms.size(); i++) {
+ input_syms[i] = new nnvm::Symbol();
+ input_syms[i]->outputs.push_back(orig_entries[i]);
+ }
+ CHECK(input_syms.size() <= max_input_size);
+ *input_size = input_syms.size();
+ memcpy(input_symbols, input_syms.data(), sizeof(*input_symbols) *
input_syms.size());
+ } else {
+ *input_size = 0;
+ }
+
+ API_END_HANDLE_ERROR();
+}
+
int MXSymbolCreateFromFile(const char *fname, SymbolHandle *out) {
nnvm::Symbol *s = new nnvm::Symbol();
API_BEGIN();
diff --git a/src/executor/attach_op_execs_pass.cc b/src/executor/attach_op_execs_pass.cc
index 697e4869a04..b90aa83099a 100644
--- a/src/executor/attach_op_execs_pass.cc
+++ b/src/executor/attach_op_execs_pass.cc
@@ -126,6 +126,10 @@ class StatefulComputeExecutor : public StorageFallbackOpExecutor {
PostFCompute(is_gpu);
}
+ bool HasSubgraph() const override {
+ return !attrs_.subgraphs.empty();
+ }
+
ExecType exec_type() const override {
return exec_type_;
}
@@ -134,15 +138,17 @@ class StatefulComputeExecutor : public StorageFallbackOpExecutor {
return state_.get_var();
}
- explicit StatefulComputeExecutor(const OpStatePtr& state,
+ explicit StatefulComputeExecutor(const NodeAttrs& attrs,
+ const OpStatePtr& state,
const FStatefulCompute& fcompute,
ExecType exec_type,
const std::vector<uint32_t> &mutate_idx)
- : StorageFallbackOpExecutor(mutate_idx),
+ : StorageFallbackOpExecutor(mutate_idx), attrs_(attrs),
state_(state), fcompute_(fcompute), exec_type_(exec_type) {}
private:
friend Graph AttachOpExecs(Graph g);
+ NodeAttrs attrs_;
OpStatePtr state_;
FStatefulCompute fcompute_;
ExecType exec_type_;
@@ -160,6 +166,10 @@ class StatefulComputeExExecutor : public OpExecutor {
fcompute_(state_, op_ctx, in_array, req, out_array);
}
+ bool HasSubgraph() const override {
+ return !attrs_.subgraphs.empty();
+ }
+
void Setup() override {}
ExecType exec_type() const override {
@@ -170,13 +180,14 @@ class StatefulComputeExExecutor : public OpExecutor {
return state_.get_var();
}
- explicit StatefulComputeExExecutor(const OpStatePtr& state,
+ explicit StatefulComputeExExecutor(const NodeAttrs& attrs, const OpStatePtr&
state,
const FStatefulComputeEx& fcompute,
ExecType exec_type)
- : state_(state), fcompute_(fcompute), exec_type_(exec_type) {}
+ : attrs_(attrs), state_(state), fcompute_(fcompute),
exec_type_(exec_type) {}
private:
friend Graph AttachOpExecs(Graph g);
+ NodeAttrs attrs_;
OpStatePtr state_;
FStatefulComputeEx fcompute_;
ExecType exec_type_;
@@ -201,6 +212,10 @@ class FComputeExecutor : public StorageFallbackOpExecutor {
return exec_type_;
}
+ bool HasSubgraph() const override {
+ return !attrs_.subgraphs.empty();
+ }
+
explicit FComputeExecutor(const NodeAttrs& attrs, FCompute fcompute,
ExecType exec_type, const std::vector<uint32_t>
&mutate_idx)
: StorageFallbackOpExecutor(mutate_idx),
@@ -226,6 +241,10 @@ class FComputeExExecutor : public OpExecutor {
void Setup() override {}
+ bool HasSubgraph() const override {
+ return !attrs_.subgraphs.empty();
+ }
+
ExecType exec_type() const override {
return exec_type_;
}
@@ -289,15 +308,17 @@ Graph AttachOpExecs(Graph g) {
op, "FStatefulComputeEx", vctx[i]);
// FStatefulComputeEx is dispatched only when dispatch_mode is
DispatchMode::kFComputeEx
if (fcompute_ex != nullptr && dispatch_modes[i] ==
DispatchMode::kFComputeEx) {
- ret[i] = std::make_shared<StatefulComputeExExecutor>(state,
fcompute_ex, exec_type);
+ ret[i] =
std::make_shared<StatefulComputeExExecutor>(inode.source->attrs, state,
+ fcompute_ex,
exec_type);
} else {
FStatefulCompute fcompute = common::GetFCompute<FStatefulCompute>(
op, "FStatefulCompute", vctx[i]);
CHECK(fcompute != nullptr)
<< "One of FStatefulCompute and FStatefulComputeEx must be
registered "
<< "for stateful operator " << op->name;
- ret[i] = std::make_shared<StatefulComputeExecutor>(state, fcompute,
- exec_type,
mutate_index);
+ ret[i] =
std::make_shared<StatefulComputeExecutor>(inode.source->attrs, state,
+ fcompute, exec_type,
+ mutate_index);
}
} else if (is_layer_backward.get(op, false)) {
CHECK_GE(inode.control_deps.size(), 1);
@@ -308,7 +329,7 @@ Graph AttachOpExecs(Graph g) {
op, "FStatefulComputeEx", vctx[i]);
// FStatefulComputeEx is dispatched only when dispatch_mode is
DispatchMode::kFComputeEx
if (fcompute_ex != nullptr && dispatch_modes[i] ==
DispatchMode::kFComputeEx) {
- ret[i] = std::make_shared<StatefulComputeExExecutor>(
+ ret[i] =
std::make_shared<StatefulComputeExExecutor>(inode.source->attrs,
dynamic_cast<StatefulComputeExExecutor*>(ret[fwd_id].get())->state_,
fcompute_ex, exec_type);
} else {
@@ -317,7 +338,7 @@ Graph AttachOpExecs(Graph g) {
CHECK(fcompute != nullptr)
<< "One of FStatefulCompute and FStatefulComputeEx must be
registered "
<< "for stateful operator " << op->name;
- ret[i] = std::make_shared<StatefulComputeExecutor>(
+ ret[i] = std::make_shared<StatefulComputeExecutor>(inode.source->attrs,
dynamic_cast<StatefulComputeExecutor*>(ret[fwd_id].get())->state_,
fcompute, exec_type, mutate_index);
}
diff --git a/src/executor/exec_pass.h b/src/executor/exec_pass.h
index 99b1b162eae..f49fcf61db2 100644
--- a/src/executor/exec_pass.h
+++ b/src/executor/exec_pass.h
@@ -64,6 +64,7 @@ class OpExecutor {
OpContext op_ctx;
/*! \brief virtual destructor */
virtual ~OpExecutor() {}
+ virtual bool HasSubgraph() const = 0;
/*!
* \brief Setup the executor for given NDArray member
* this can be called multiple times if NDArray changed during reshape.
diff --git a/src/executor/graph_executor.cc b/src/executor/graph_executor.cc
index 7a15f6c931c..ca06a12a5a0 100644
--- a/src/executor/graph_executor.cc
+++ b/src/executor/graph_executor.cc
@@ -39,6 +39,7 @@ namespace exec {
GraphExecutor::GraphExecutor() {
log_verbose_ = dmlc::GetEnv("MXNET_EXEC_VERBOSE_LOGGING", false);
+ need_grad_ = false;
}
GraphExecutor::~GraphExecutor() {
@@ -257,11 +258,11 @@ nnvm::Graph GraphExecutor::InitFullGraph(nnvm::Symbol symbol,
nnvm::Graph g;
g.outputs = symbol.outputs;
- bool need_grad = false;
+ need_grad_ = false;
for (OpReqType req : grad_req_types) {
- if (req != kNullOp) need_grad = true;
+ if (req != kNullOp) need_grad_ = true;
}
- if (!need_grad) return g;
+ if (!need_grad_) return g;
for (size_t i = 0; i < g.outputs.size(); ++i) {
NodeEntry ngrad{nnvm::Node::Create(), 0, 0};
head_grad_entry_.emplace_back(AttrHint(ngrad, g.outputs[i]));
@@ -1378,7 +1379,11 @@ void GraphExecutor::BulkTrainingOpSegs(size_t total_num_nodes) {
// check if the segment relies on external input, or exceeds maxinum
number of node,
// or requires async ops
if (node->is_variable() || nid - topo_start > num_nodes_threshold ||
- op_node.exec->exec_type() != ExecType::kSync) {
+ op_node.exec->exec_type() != ExecType::kSync ||
+ // If the node has a subgraph, we shouldn't add it to the segment.
+ // We'll execute the node separately from other nodes.
+ // CreateCachedSegOpr creates a segment excluding nodes with subgraphs.
+ op_node.exec->HasSubgraph()) {
// create a new segment for the previous nodes if the current one cannot
be bulked
cached_seg_opr_[topo_start] = this->CreateCachedSegOpr(topo_start, nid);
topo_start = nid + 1;
@@ -1403,7 +1408,11 @@ void GraphExecutor::BulkTrainingOpSegs(size_t total_num_nodes) {
continue;
}
if (idx[nid].source->is_variable() || nid - topo_start >
num_nodes_threshold ||
- op_node.exec->exec_type() != ExecType::kSync) {
+ op_node.exec->exec_type() != ExecType::kSync ||
+ // If the node has a subgraph, we shouldn't add it to the segment.
+ // We'll execute the node separately from other nodes.
+ // CreateCachedSegOpr creates a segment excluding nodes with subgraphs.
+ op_node.exec->HasSubgraph()) {
cached_seg_opr_[topo_start] = this->CreateCachedSegOpr(topo_start, nid);
topo_start = nid + 1;
} else {
@@ -1437,7 +1446,11 @@ void GraphExecutor::BulkInferenceOpSegs() {
// Variables do not need to be segmented at inference time.
if (node->is_variable()) continue;
- if (op_node.exec->exec_type() != ExecType::kSync) {
+ if (op_node.exec->exec_type() != ExecType::kSync ||
+ // If the node has a subgraph, we shouldn't add it to the segment.
+ // We'll execute the node separately from other nodes.
+ // CreateCachedSegOpr creates a segment excluding nodes with subgraphs.
+ op_node.exec->HasSubgraph()) {
cached_seg_opr_[topo_start] = this->CreateCachedSegOpr(topo_start, nid);
topo_start = nid + 1;
}
@@ -1480,6 +1493,7 @@ void GraphExecutor::RunOps(bool is_train, size_t topo_start, size_t topo_end) {
const auto& inode = idx[nid];
if (inode.source->is_variable()) continue;
opnode.exec->op_ctx.is_train = is_train;
+ opnode.exec->op_ctx.need_grad = need_grad_;
}
// Push Ops
@@ -1498,11 +1512,15 @@ void GraphExecutor::RunOps(bool is_train, size_t topo_start, size_t topo_end) {
OpNode& opnode = op_nodes_[nid];
if (op_nodes_[nid].skip_exec_node) continue;
opnode.exec->op_ctx.is_train = is_train;
+ opnode.exec->op_ctx.need_grad = need_grad_;
if (opnode.exec->exec_type() == ExecType::kCrossDeviceCopy) {
CHECK_EQ(inode.inputs.size(), 1U);
CHECK_EQ(opnode.exec->in_array.size(), 1U);
CHECK_EQ(opnode.exec->out_array.size(), 1U);
CopyFromTo(opnode.exec->in_array[0], &(opnode.exec->out_array[0]));
+ } else if (opnode.exec->HasSubgraph()) {
+ // If the node contains a subgraph, we can't execute it in the engine.
+ opnode.exec->Run(opnode.exec->op_ctx.run_ctx, false);
} else if (opnode.cached_opr != nullptr) {
bool profiling = profiler::Profiler::Get()->GetState() ==
profiler::Profiler::kRunning;
Engine::Get()->Push(opnode.cached_opr, opnode.ctx, 0, profiling);
@@ -1537,6 +1555,9 @@ GraphExecutor::CachedSegOpr GraphExecutor::CreateCachedSegOpr(size_t topo_start,
OpNode& op_node = op_nodes_[nid];
if (op_node.skip_exec_node) continue;
if (inode.source->is_variable()) continue;
+ // We shouldn't add control flow operators to a segment.
+ // We can't execute these operators in the engine.
+ if (op_node.exec->HasSubgraph()) return ret;
if (op_node.exec->exec_type() != ExecType::kSync) {
return ret;
}
diff --git a/src/executor/graph_executor.h b/src/executor/graph_executor.h
index bcde41d508e..fa2a156d3d7 100644
--- a/src/executor/graph_executor.h
+++ b/src/executor/graph_executor.h
@@ -203,6 +203,8 @@ class GraphExecutor : public Executor {
// perform bulking and segmentation on a training graph
void BulkTrainingOpSegs(size_t total_num_nodes);
+ // indicate whether there is a backward graph for gradients.
+ bool need_grad_;
// internal graph
nnvm::Graph graph_;
// operator node
diff --git a/src/imperative/imperative_utils.h b/src/imperative/imperative_utils.h
index d7bb37b7cfe..1135c0d2d41 100644
--- a/src/imperative/imperative_utils.h
+++ b/src/imperative/imperative_utils.h
@@ -359,6 +359,7 @@ inline void PushFCompute(const FCompute& fn,
static auto& fexec_type = nnvm::Op::GetAttr<FExecType>("FExecType");
bool is_train = Imperative::Get()->is_training();
+ bool need_grad = Imperative::Get()->is_recording();
ExecType exec_type = fexec_type.count(op) ? fexec_type[op](attrs) :
ExecType::kSync;
CHECK(exec_type == ExecType::kSync);
std::vector<NDArray> inputs, outputs;
@@ -379,7 +380,7 @@ inline void PushFCompute(const FCompute& fn,
&input_blobs, &output_blobs, &pre_temp_src,
&pre_temp_dst,
&post_temp_src, &post_temp_dst, &in_temp_idx_map,
mutate_idx);
// setup context
- OpContext opctx{is_train, rctx, engine::CallbackOnComplete(), requested};
+ OpContext opctx{need_grad, is_train, rctx, engine::CallbackOnComplete(),
requested};
bool is_gpu = ctx.dev_mask() == gpu::kDevMask;
// pre-fcompute fallback, cast to default storage type
CastNonDefaultStorage(pre_temp_src, pre_temp_dst, opctx, is_gpu);
@@ -406,11 +407,12 @@ inline void PushFComputeEx(const FComputeEx& fn,
static auto& fexec_type = nnvm::Op::GetAttr<FExecType>("FExecType");
bool is_train = Imperative::Get()->is_training();
+ bool need_grad = Imperative::Get()->is_recording();
ExecType exec_type = fexec_type.count(op) ? fexec_type[op](attrs) :
ExecType::kSync;
std::vector<NDArray> inputs, outputs;
DerefInputOutput(p_inputs, p_outputs, &inputs, &outputs);
const auto& run = [=](RunContext rctx) {
- OpContext opctx{is_train, rctx, engine::CallbackOnComplete(), requested};
+ OpContext opctx{need_grad, is_train, rctx, engine::CallbackOnComplete(),
requested};
#if MXNET_USE_MKLDNN == 1
InvalidateOutputs(outputs, req);
#endif
@@ -445,6 +447,7 @@ inline void PushOperator(const OpStatePtr& state,
static auto& fexec_type = nnvm::Op::GetAttr<FExecType>("FExecType");
bool is_train = Imperative::Get()->is_training();
+ bool need_grad = Imperative::Get()->is_recording();
ExecType exec_type = fexec_type.count(op) ? fexec_type[op](attrs) :
ExecType::kSync;
std::vector<NDArray> inputs, outputs;
DerefInputOutput(p_inputs, p_outputs, &inputs, &outputs);
@@ -456,17 +459,23 @@ inline void PushOperator(const OpStatePtr& state,
if (fcompute_ex != nullptr && dispatch_mode == DispatchMode::kFComputeEx) {
const auto& run = [=](RunContext rctx,
engine::CallbackOnComplete on_complete) {
- OpContext opctx{is_train, rctx, on_complete, requested};
+ OpContext opctx{need_grad, is_train, rctx, on_complete, requested};
#if MXNET_USE_MKLDNN == 1
InvalidateOutputs(outputs, req);
#endif
fcompute_ex(state, opctx, inputs, req, outputs);
- if (ctx.dev_mask() == gpu::kDevMask && exec_type == ExecType::kSync) {
+ if (ctx.dev_mask() == gpu::kDevMask && exec_type == ExecType::kSync
+ && rctx.get_stream<gpu>()) {
rctx.get_stream<gpu>()->Wait();
}
};
- if (exec_type == ExecType::kSync) {
+ // For operators with subgraphs, we need to invoke them in the main thread
+ // instead of the threaded engine.
+ if (!attrs.subgraphs.empty()) {
+ RunContext rctx{ctx, nullptr};
+ run(rctx, engine::CallbackOnComplete());
+ } else if (exec_type == ExecType::kSync) {
Engine::Get()->PushSync(
[=](RunContext rctx) { run(rctx, engine::CallbackOnComplete()); },
ctx, read_vars, write_vars, FnProperty::kNormal, 0,
@@ -483,7 +492,7 @@ inline void PushOperator(const OpStatePtr& state,
<< "for stateful operator " << op->name;
const auto& run = [=](RunContext rctx, engine::CallbackOnComplete
on_complete) {
- OpContext opctx{is_train, rctx, on_complete, requested};
+ OpContext opctx{need_grad, is_train, rctx, on_complete, requested};
std::vector<TBlob> input_blobs, output_blobs;
// pre-fcompute and post-fcompute storage fallback src NDArrays and
dst NDArrays
@@ -505,12 +514,16 @@ inline void PushOperator(const OpStatePtr& state,
fcompute(state, opctx, input_blobs, tmp_req, output_blobs);
// post-fcompute fallback, cast to original storage type, if necessary
CastNonDefaultStorage(post_temp_src, post_temp_dst, opctx, is_gpu);
- if (is_gpu && exec_type == ExecType::kSync) {
+ if (is_gpu && exec_type == ExecType::kSync
+ && rctx.get_stream<gpu>()) {
rctx.get_stream<gpu>()->Wait();
}
};
- if (exec_type == ExecType::kSync) {
+ if (!attrs.subgraphs.empty()) {
+ RunContext rctx{ctx, nullptr};
+ run(rctx, engine::CallbackOnComplete());
+ } else if (exec_type == ExecType::kSync) {
Engine::Get()->PushSync(
[=](RunContext rctx) {
run(rctx, engine::CallbackOnComplete());
diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc
index d87e8bc95ea..764711f020f 100644
--- a/src/ndarray/ndarray.cc
+++ b/src/ndarray/ndarray.cc
@@ -200,6 +200,7 @@ NDArray NDArray::MKLDNNDataReshape(const TShape &shape) const {
ret.ptr_->delay_alloc = false;
ret.ptr_->static_data = true;
ret.byte_offset_ = byte_offset_;
+ ret.reuse_ = false;
return ret;
}
}
@@ -217,6 +218,7 @@ NDArray NDArray::Reshape(const TShape &shape) const {
// Otherwise, reshape only works on the default layout.
CHECK_EQ(storage_type(), kDefaultStorage);
ret.shape_ = shape;
+ ret.reuse_ = false;
return ret;
}
@@ -249,6 +251,7 @@ NDArray NDArray::Slice(index_t begin, index_t end) const {
MSHADOW_TYPE_SWITCH(ret.dtype(), DType, {
ret.byte_offset_ += begin * length * sizeof(DType);
});
+ ret.reuse_ = false;
ret.shape_[0] = end - begin;
return ret;
}
@@ -555,6 +558,7 @@ NDArray NDArray::Reorder2Default() const {
// reshape as needed
ret.shape_ = shape_;
ret.byte_offset_ = byte_offset_;
+ ret.reuse_ = false;
return ret;
}
@@ -584,39 +588,39 @@ void NDArray::MKLDNNDataReorderAsync(const
mkldnn::memory::primitive_desc &desc)
const mkldnn::memory *NDArray::GetMKLDNNData() const {
CHECK(storage_type() == kDefaultStorage);
+ bool is_view = IsView();
if (IsMKLDNNData()) {
// If this array uses MKLDNN layout, we have to make sure it's not a view.
// Otherwise, we'll have to change the layout inside the array.
- CHECK(!IsView());
+ CHECK(!is_view);
MKLDNNStream::Get()->RegisterMem(ptr_->mkl_mem_->GetMem());
// If this array uses MKLDNN format, we should return now. Otherwise,
// SetMKLMem may mess up mkl_mem_.
return ptr_->mkl_mem_->GetRaw();
- }
- ptr_->SetMKLMem(IsView() ? ptr_->storage_shape : shape_, dtype_);
- MKLDNNStream::Get()->RegisterMem(ptr_->mkl_mem_->GetMem());
- if (IsView()) {
- mkldnn::memory::primitive_desc pd = ptr_->mkl_mem_->GetPrimitiveDesc();
- // Sliced array must use the default layout.
- CHECK_EQ(GetDefaultFormat(pd.desc()), pd.desc().data.format);
- void *off_addr = static_cast<char *>(ptr_->mkl_mem_->GetDataHandle())
- + byte_offset_;
-
+ } else if (is_view) {
+ // If this is a view, we can't create a MKLDNN memory for the chunk
+ // because we don't have the complete data type and shape information for
+ // the chunk.
+ void *off_addr = static_cast<char *>(ptr_->shandle.dptr) + byte_offset_;
// Create the primitive desc for the new mkldnn memory.
mkldnn::memory::dims dims(shape().ndim());
for (size_t i = 0; i < dims.size(); i++)
dims[i] = shape()[i];
mkldnn::memory::format cpp_format = static_cast<mkldnn::memory::format>(
GetDefaultFormat(shape().ndim()));
- mkldnn::memory::data_type cpp_type =
static_cast<mkldnn::memory::data_type>(
- pd.desc().data.data_type);
+ mkldnn::memory::data_type cpp_type = get_mkldnn_type(dtype_);
mkldnn::memory::desc data_md(dims, cpp_type, cpp_format);
- mkldnn::memory::primitive_desc new_pd(data_md, pd.get_engine());
+ mkldnn::memory::primitive_desc new_pd(data_md,
+ CpuEngine::Get()->get_engine());
std::shared_ptr<mkldnn::memory> ret(new mkldnn::memory(new_pd, off_addr));
MKLDNNStream::Get()->RegisterMem(ret);
return ret.get();
} else {
+ // If this isn't a view, we can create a MKLDNN memory and store it in the
+ // chunk.
+ ptr_->SetMKLMem(shape_, dtype_);
+ MKLDNNStream::Get()->RegisterMem(ptr_->mkl_mem_->GetMem());
return ptr_->mkl_mem_->GetRaw();
}
}
@@ -637,10 +641,9 @@ void NDArray::CopyFrom(const mkldnn::memory &mem) {
MKLDNNStream *stream = MKLDNNStream::Get();
// If this array uses MKLDNN layout, we have to make sure it's not a view.
// Otherwise, we'll have to change the layout inside the array.
- if (IsMKLDNNData())
- CHECK(!IsView());
- ptr_->SetMKLMem(IsView() ? ptr_->storage_shape : shape_,
- dtype_);
+
+ CHECK(!IsView());
+ ptr_->SetMKLMem(shape_, dtype_);
stream->RegisterMem(ptr_->mkl_mem_->GetMem());
mkldnn::memory::desc from_desc = mem.get_primitive_desc().desc();
mkldnn::memory::desc this_desc = ptr_->mkl_mem_->GetPrimitiveDesc().desc();
@@ -713,9 +716,6 @@ mkldnn::memory::primitive_desc
GetPrimitiveDesc(mkldnn::memory::primitive_desc p
mkldnn_memory_format_t format);
mkldnn::memory *NDArray::CreateMKLDNNData(const mkldnn::memory::primitive_desc
&desc) {
- // This array shouldn't be a view.
- CHECK(!IsView());
-
if (desc.get_size() != shape().Size() * GetTypeSize(dtype_)) {
LOG(FATAL) << "The size of NDArray doesn't match the requested MKLDNN
memory desc";
return nullptr;
@@ -726,10 +726,26 @@ mkldnn::memory *NDArray::CreateMKLDNNData(const
mkldnn::memory::primitive_desc &
mkldnn_memory_format_t def_format = GetDefaultFormat(_desc.desc());
// If the required format is a default format, we don't need to worry about
the shape.
// If the shape isn't the same, it actually implicitly reshapes data.
- if (required_format == def_format) {
+ if (required_format == def_format && !IsView()) {
ptr_->SetMKLMem(shape_, dtype_);
MKLDNNStream::Get()->RegisterMem(ptr_->mkl_mem_->GetMem());
return GetMKLDNNExact(ptr_->mkl_mem_->GetRaw(), desc);
+ } else if (required_format == def_format) {
+ ptr_->CheckAndAlloc();
+ CHECK(ptr_->shandle.dptr);
+ // When this is a view and a user wants the default layout, we can simply
+ // create a new mkldnn memory that points to the right memory.
+ std::shared_ptr<mkldnn::memory> mem(new mkldnn::memory(
+ desc, ptr_->shandle.dptr + byte_offset_));
+ MKLDNNStream::Get()->RegisterMem(mem);
+ return mem.get();
+ } else if (IsView()) {
+ // If this is a view and a user wants to write data to it with special
+ // a MKLDNN format, we should reorder the data in the array and return
NULL.
+ // In this way, the user will create a new NDArray for the special format
+ // and copy data back.
+ ptr_->Reorder2Default();
+ return nullptr;
}
if (ptr_->mkl_mem_)
@@ -1160,7 +1176,8 @@ void CopyFromToImpl(const NDArray& from, const NDArray&
to,
const Context to_ctx = to.ctx();
bool is_train = Imperative::Get()->is_training();
- OpContext opctx{is_train,
+ OpContext opctx{Imperative::Get()->is_recording(),
+ is_train,
rctx,
engine::CallbackOnComplete(),
requested};
diff --git a/src/nnvm/graph_editor.cc b/src/nnvm/graph_editor.cc
new file mode 100644
index 00000000000..98c99e2425d
--- /dev/null
+++ b/src/nnvm/graph_editor.cc
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file graph_editor.cc
+ * The functions in this file edit an NNVM graph. Potentially,
+ * these functions should be moved to NNVM in the future.
+ */
+
+#include <nnvm/symbolic.h>
+#include <nnvm/graph.h>
+#include <nnvm/node.h>
+
+namespace nnvm {
+// Forward declaration of an NNVM function used below; the included NNVM
+// headers presumably do not declare it. NOTE(review): confirm it stays exported.
+NodePtr CreateVariableNode(const std::string& name);
+}  // namespace nnvm
+
+namespace mxnet {
+
+/*
+ * Given a computation graph, this function finds the variable (input) nodes
+ * of the graph and creates one symbol per variable node. Each variable is
+ * reported exactly once, in the order in which it is first met while walking
+ * the indexed graph, even when several operators consume it. A variable that
+ * is itself a graph output is reported as well.
+ *
+ * The returned symbols are heap-allocated and owned by the caller.
+ */
+std::vector<nnvm::Symbol *> GetInputSymbols(const nnvm::Symbol &sym) {
+  nnvm::Graph g;
+  std::vector<nnvm::Symbol *> input_syms;
+  g.outputs = sym.outputs;
+  const nnvm::IndexedGraph& idx = g.indexed_graph();
+  // A variable consumed by several operators must only be reported once.
+  std::vector<bool> visited(idx.num_nodes(), false);
+  auto add_if_new_variable = [&](const nnvm::NodeEntry &e) {
+    if (!e.node->is_variable())
+      return;
+    const uint32_t nid = idx.node_id(e.node.get());
+    if (visited[nid])
+      return;
+    visited[nid] = true;
+    nnvm::Symbol *s = new nnvm::Symbol();
+    s->outputs.push_back(e);
+    input_syms.push_back(s);
+  };
+  // Go through all nodes and collect the variables they read from.
+  for (size_t i = 0; i < idx.num_nodes(); i++) {
+    for (const nnvm::NodeEntry &e : idx[i].source->inputs)
+      add_if_new_variable(e);
+  }
+  // A variable returned directly as a graph output is not consumed by any
+  // operator, so it has to be picked up separately.
+  for (const nnvm::NodeEntry &e : g.outputs)
+    add_if_new_variable(e);
+  return input_syms;
+}
+
+/*
+ * Given a computation graph and a set of input node entries, this function
+ * cuts the node entries and creates new variable nodes as the input nodes of
+ * the subgraph.
+ *
+ * Each entry pointed to by `input_entries` is overwritten in place with an
+ * entry of a fresh variable node named `in_name_prefix` followed by the
+ * entry's position in `input_entries`. The replaced entries (the nodes that
+ * connect to the subgraph directly) are appended to `orig_entries`, and the
+ * names of the new variable nodes to `new_var_names`, in the same order.
+ * When `skip_var` is true, entries that already come from a variable node are
+ * left untouched and not reported. Currently always returns true.
+ */
+bool CutGraph(const std::vector<nnvm::NodeEntry *> &input_entries,
+              const std::string &in_name_prefix, bool skip_var,
+              std::vector<nnvm::NodeEntry> *orig_entries,
+              std::vector<std::string> *new_var_names) {
+  orig_entries->reserve(orig_entries->size() + input_entries.size());
+  new_var_names->reserve(new_var_names->size() + input_entries.size());
+  for (size_t i = 0; i < input_entries.size(); i++) {
+    nnvm::NodeEntry *e = input_entries[i];
+    // If the node is a variable itself, we may want to skip the node.
+    if (e->node->is_variable() && skip_var)
+      continue;
+
+    orig_entries->push_back(*e);
+    new_var_names->push_back(in_name_prefix + std::to_string(i));
+    nnvm::NodePtr n = nnvm::CreateVariableNode(new_var_names->back());
+    *e = nnvm::NodeEntry{n, 0, 0};
+  }
+  return true;
+}
+
+}  // namespace mxnet
diff --git a/src/operator/control_flow.cc b/src/operator/control_flow.cc
new file mode 100644
index 00000000000..c42aca0944d
--- /dev/null
+++ b/src/operator/control_flow.cc
@@ -0,0 +1,419 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <mxnet/io.h>
+#include <mxnet/base.h>
+#include <mxnet/ndarray.h>
+#include <mxnet/operator.h>
+#include <mxnet/operator_util.h>
+#include <dmlc/logging.h>
+#include <dmlc/optional.h>
+#include "./operator_common.h"
+#include "./elemwise_op_common.h"
+#include "../imperative/imperative_utils.h"
+#include "./subgraph_op_common.h"
+
+namespace mxnet {
+namespace op {
+
+/*!
+ * \brief Parameters of the _foreach operator. _backward_foreach parses the
+ *        same attributes.
+ */
+struct ForeachParam : public dmlc::Parameter<ForeachParam> {
+ // Number of operator inputs: "fn" followed by num_args - 1 arrays.
+ int num_args;
+ // NOTE(review): parsed but not read by the compute/shape functions in this
+ // file, which always iterate over dimension 0.
+ int dim;
+ // Total number of subgraph outputs: output data followed by loop states.
+ int num_outputs;
+ // Number of leading outputs that are per-iteration data; they are stacked
+ // along a new leading dimension of length equal to the loop length.
+ int num_out_data;
+ // in_state_locs[k] is the input position fed by state output num_out_data + k.
+ nnvm::Tuple<dim_t> in_state_locs;
+ // Input positions that are sliced along dimension 0, one slice per iteration.
+ nnvm::Tuple<dim_t> in_data_locs;
+ DMLC_DECLARE_PARAMETER(ForeachParam) {
+ DMLC_DECLARE_FIELD(num_args).set_lower_bound(1)
+ .describe("Number of inputs.");
+ DMLC_DECLARE_FIELD(dim).set_default(1)
+ .describe("the dimension of the input array to iterate.");
+ DMLC_DECLARE_FIELD(num_outputs)
+ .describe("The number of outputs of the subgraph.");
+ DMLC_DECLARE_FIELD(num_out_data)
+ .describe("The number of output data of the subgraph.");
+ DMLC_DECLARE_FIELD(in_state_locs)
+ .describe("The locations of loop states among the inputs.");
+ DMLC_DECLARE_FIELD(in_data_locs)
+ .describe("The locations of input data among the inputs.");
+ }
+}; // struct ForeachParam
+
+DMLC_REGISTER_PARAMETER(ForeachParam);
+
+/*!
+ * \brief Operator state of _foreach: the loop machinery inherited from
+ *        LoopState plus a copy of the parsed operator parameters.
+ */
+class ForeachState: public LoopState {
+ public:
+  ForeachParam params;
+
+  ForeachState(const Symbol &g, const ForeachParam &foreach_params)
+      : LoopState(g), params(foreach_params) {}
+};
+
+/*!
+ * \brief Forward computation of _foreach.
+ *
+ * Iterates over dimension 0 of every input listed in in_data_locs (all of
+ * them must have the same length) and runs the subgraph once per slice.
+ * The first num_out_data outputs are output data: iteration i writes its
+ * result to outputs[j].At(i). The remaining outputs are loop states: the
+ * states produced by iteration i are fed to iteration i + 1 at the input
+ * positions given by in_state_locs, and for a non-empty loop the states of
+ * the last iteration end up in `outputs`. When ctx.need_grad is set, every
+ * iteration but the last gets freshly allocated state arrays so that the
+ * recorded iterations do not overwrite each other.
+ */
+static void ForeachComputeExCPU(const OpStatePtr& state_ptr,
+                                const OpContext& ctx,
+                                const std::vector<NDArray>& inputs,
+                                const std::vector<OpReqType>& req,
+                                const std::vector<NDArray>& outputs) {
+  ForeachState &state = state_ptr.get_state<ForeachState>();
+  const ForeachParam& params = state.params;
+  const size_t iter_dim = 0;
+  const size_t num_out_data = params.num_out_data;
+  CHECK_EQ(outputs.size(), (size_t) params.num_outputs);
+  CHECK_GT(params.in_data_locs.ndim(), 0);
+  // Every loop state output needs an input location to be fed back to.
+  CHECK_EQ(params.in_state_locs.ndim() + num_out_data, outputs.size())
+      << "in_state_locs must have one entry per loop state";
+  size_t loc0 = params.in_data_locs[0];
+  size_t len = inputs[loc0].shape()[iter_dim];
+  for (size_t i = 1; i < params.in_data_locs.ndim(); i++) {
+    size_t loc = params.in_data_locs[i];
+    CHECK_EQ(inputs[loc].shape()[iter_dim], len);
+  }
+  for (size_t i = 0; i < num_out_data; i++)
+    CHECK_EQ(len, outputs[i].shape()[iter_dim]);
+  for (const auto &arr : outputs)
+    CHECK_EQ(arr.storage_type(), kDefaultStorage)
+        << "The for operator doesn't support the sparse format";
+
+  // Initializing the outputs of the subgraph is a little trickier.
+  // The states from the previous iteration are used as the inputs of the next
+  // iteration, so two sets of state arrays are kept and used alternately; an
+  // iteration's state inputs and state outputs never share memory.
+  // Only the loop states (outputs from num_out_data on) need this: the output
+  // data slots are re-bound to a slice of the final output in every iteration.
+  std::vector<NDArray> subg_outputs1(outputs.size());
+  std::vector<NDArray> subg_outputs2(outputs.size());
+  std::vector<NDArray> *subg_outputs[2]{&subg_outputs1, &subg_outputs2};
+  // If the length is an odd number, the last iteration will use the first set
+  // of outputs. In this way, we don't need to copy the results from the
+  // subgraph to the final outputs of the loop.
+  if (len % 2 == 1) {
+    for (size_t i = num_out_data; i < subg_outputs1.size(); i++) {
+      subg_outputs1[i] = outputs[i];
+      subg_outputs2[i] = NDArray(outputs[i].shape(), outputs[i].ctx(), true,
+                                 outputs[i].dtype());
+    }
+  } else {
+    // Otherwise, we'll use the second set of outputs.
+    for (size_t i = num_out_data; i < subg_outputs1.size(); i++) {
+      subg_outputs1[i] = NDArray(outputs[i].shape(), outputs[i].ctx(), true,
+                                 outputs[i].dtype());
+      subg_outputs2[i] = outputs[i];
+    }
+  }
+
+  // Initialize the inputs for the subgraph.
+  // In each iteration, we need to update the subgraph inputs for input data
+  // and the loop states. This initialization provides the read-only arrays
+  // and the initial states.
+  std::vector<NDArray> subg_inputs(inputs.begin(), inputs.end());
+
+  // Here we iterate over the first dimension of the input data arrays.
+  for (size_t i = 0; i < len; i++) {
+    // Initialize outputs for the subgraph.
+    std::vector<NDArray> *subg_out_curr = subg_outputs[i % 2];
+    std::vector<NDArray> *subg_out_prev = subg_outputs[(i + 1) % 2];
+    for (size_t j = 0; j < num_out_data; j++)
+      (*subg_out_curr)[j] = outputs[j].At(i);
+    // When recording for backward computation, we should make sure
+    // that output arrays are actually different in each iteration.
+    if (ctx.need_grad && i < len - 1) {
+      for (size_t j = num_out_data; j < subg_out_curr->size(); j++)
+        (*subg_out_curr)[j] = NDArray(outputs[j].shape(), outputs[j].ctx(),
+                                      true, outputs[j].dtype());
+    } else if (ctx.need_grad && i == len - 1) {
+      // For the last iteration, we need to write data to the output array
+      // directly.
+      for (size_t j = num_out_data; j < subg_out_curr->size(); j++)
+        (*subg_out_curr)[j] = outputs[j];
+    }
+
+    // Initialize inputs for the subgraph.
+    // Get a slice from the input data arrays.
+    for (size_t j = 0; j < params.in_data_locs.ndim(); j++) {
+      size_t loc = params.in_data_locs[j];
+      subg_inputs[loc] = inputs[loc].At(i);
+    }
+    // For the rest of the iterations, the loop states are the outputs
+    // from the previous iteration.
+    if (i > 0) {
+      for (size_t j = num_out_data; j < subg_out_prev->size(); j++) {
+        size_t idx = j - num_out_data;
+        CHECK_LT(params.in_state_locs[idx], subg_inputs.size());
+        subg_inputs[params.in_state_locs[idx]] = (*subg_out_prev)[j];
+      }
+    }
+
+    state.Forward(subg_inputs, req, *subg_out_curr, ctx.need_grad);
+    // We need to wait for the iteration to complete before executing
+    // the next one or return from the loop. In this way, we can reuse
+    // the memory in the subgraph.
+    for (size_t j = 0; j < subg_out_curr->size(); j++) {
+      (*subg_out_curr)[j].WaitToRead();
+    }
+  }
+}
+
+/*!
+ * \brief Backward computation of _foreach.
+ *
+ * Replays the recorded iterations in reverse order through
+ * ForeachState::Backward. `inputs` starts with the gradients of the forward
+ * outputs (output data first, then loop states); `outputs` holds one gradient
+ * per forward input. In each iteration:
+ *  - gradients of input data arrays go to the matching slice of the output;
+ *  - gradients of loop states become the state output gradients of the
+ *    previous iteration; only iteration 0 writes them to `outputs`, using the
+ *    user's request;
+ *  - gradients of the other (read-only) inputs use the user's request in the
+ *    last iteration and are accumulated with kAddTo in the earlier ones.
+ * state.Cleanup() is called once all iterations are done.
+ */
+static void ForeachGradComputeExCPU(const OpStatePtr& state_ptr,
+                                    const OpContext& ctx,
+                                    const std::vector<NDArray>& inputs,
+                                    const std::vector<OpReqType>& req,
+                                    const std::vector<NDArray>& outputs) {
+  ForeachState &state = state_ptr.get_state<ForeachState>();
+  const ForeachParam& params = state.params;
+  CHECK_EQ(outputs.size(), (size_t) params.num_args - 1);
+  CHECK_GT(params.in_data_locs.ndim(), 0);
+  for (const auto &arr : outputs)
+    CHECK_EQ(arr.storage_type(), kDefaultStorage)
+        << "The for operator doesn't support the sparse format";
+  size_t iter_dim = 0;
+  std::unordered_set<size_t> in_data_locs(params.in_data_locs.begin(),
+                                          params.in_data_locs.end());
+  std::unordered_set<size_t> in_state_locs(params.in_state_locs.begin(),
+                                           params.in_state_locs.end());
+  // The inputs contain out gradients, inputs and outputs.
+  int len = inputs[0].shape()[iter_dim];
+  size_t num_output_data = params.num_out_data;
+
+  // In backward computation, we need to run iterations from backwards.
+  std::vector<NDArray> ograds(params.num_outputs);
+  std::vector<NDArray> igrads(outputs.size());
+  for (size_t i = num_output_data; i < ograds.size(); i++)
+    ograds[i] = inputs[i];
+  std::vector<OpReqType> iter_req(req.size());
+  for (auto r : req)
+    CHECK_NE(r, kWriteInplace);
+  for (int iter_num = len - 1; iter_num >= 0; iter_num--) {
+    for (int i = 0; i < params.num_out_data; i++)
+      ograds[i] = inputs[i].At(iter_num);
+
+    // There are three types of arrays in igrads.
+    // * data gradients.
+    // * loop variable gradients.
+    // * read-only variable gradients.
+    for (size_t i = 0; i < igrads.size(); i++) {
+      // data gradients.
+      if (in_data_locs.count(i)) {
+        igrads[i] = outputs[i].At(iter_num);
+        iter_req[i] = req[i];
+        continue;
+      }
+
+      bool in_state = in_state_locs.count(i);
+      if (iter_num != 0 && in_state) {
+        // For state gradients, we need to allocate new NDArrays
+        // because intermediate state gradients won't be returned to the users.
+        igrads[i] = NDArray(outputs[i].shape(), outputs[i].ctx(),
+                            true, outputs[i].dtype());
+      } else {
+        igrads[i] = outputs[i];
+      }
+      if (in_state) {
+        // For the first iteration, we need to use the request provided by
+        // the user to write state gradients to the outputs.
+        iter_req[i] = iter_num != 0 ? kWriteTo : req[i];
+      } else {
+        // For all read-only variable gradients, we need to use the request
+        // provided by the user in the last iteration and later on add
+        // gradients to the output arrays.
+        iter_req[i] = iter_num == len - 1 ? req[i] : kAddTo;
+      }
+    }
+
+    state.Backward(iter_num, ograds, iter_req, igrads);
+
+    // We need to wait for the iteration to complete before executing
+    // the next one or return from the loop. In this way, we can reuse
+    // the memory in the subgraph.
+    for (size_t i = 0; i < igrads.size(); i++) {
+      igrads[i].WaitToRead();
+    }
+
+    // The state gradients of this iteration are the state output gradients
+    // of the previous iteration.
+    size_t num_states = ograds.size() - num_output_data;
+    for (size_t i = 0; i < num_states; i++) {
+      size_t loc = params.in_state_locs[i];
+      CHECK_LT(loc, igrads.size());
+      ograds[i + num_output_data] = igrads[loc];
+    }
+  }
+  state.Cleanup();
+}
+
+/*!
+ * \brief Shape inference for _foreach.
+ *
+ * The subgraph sees one slice of each input data array, i.e. its shape
+ * without the leading dimension. Shape inference runs on the subgraph.
+ * Inferred shapes are copied back into inputs whose shape was unknown.
+ * Output data shapes get the loop length prepended. Loop-state output shapes
+ * must equal the shapes of the corresponding state inputs.
+ * Returns false while the shape of an input data array is still unknown,
+ * because the loop length cannot be determined yet.
+ */
+static bool ForeachShape(const nnvm::NodeAttrs& attrs,
+                         std::vector<TShape> *in_shape,
+                         std::vector<TShape> *out_shape) {
+  const ForeachParam& params = nnvm::get<ForeachParam>(attrs.parsed);
+  CHECK_EQ(out_shape->size(), (size_t) params.num_outputs);
+  CHECK_GT(params.in_data_locs.ndim(), 0);
+  // The loop length comes from the leading dimension of the input data
+  // arrays; indexing an unknown (ndim 0) shape would be out of bounds.
+  for (size_t i = 0; i < params.in_data_locs.ndim(); i++) {
+    if (in_shape->at(params.in_data_locs[i]).ndim() == 0)
+      return false;
+  }
+  nnvm::ShapeVector shape_inputs = *in_shape;
+  // foreach iterates over the input data arrays over the first dimension.
+  size_t loc0 = params.in_data_locs[0];
+  size_t len = in_shape->at(loc0)[0];
+  for (size_t i = 0; i < params.in_data_locs.ndim(); i++) {
+    size_t loc = params.in_data_locs[i];
+    CHECK_EQ(len, in_shape->at(loc)[0]);
+    shape_inputs[loc] = TShape(in_shape->at(loc).begin() + 1,
+                               in_shape->at(loc).end());
+  }
+  CHECK_EQ(attrs.subgraphs.size(), 1U);
+  nnvm::Graph g;
+  g.outputs = attrs.subgraphs[0]->outputs;
+  const auto& idx = g.indexed_graph();
+  CHECK_EQ(idx.input_nodes().size(), in_shape->size());
+  CHECK_EQ(idx.outputs().size(), out_shape->size());
+  imperative::CheckAndInferShape(&g, std::move(shape_inputs), true);
+
+  const auto& shapes = g.GetAttr<nnvm::ShapeVector>("shape");
+  // Inferring the shape in the subgraph may infer the shape of the inputs.
+  // We need to copy the inferred input shapes back.
+  const auto &input_nids = idx.input_nodes();
+  CHECK_EQ(input_nids.size(), in_shape->size());
+  for (size_t i = 0; i < in_shape->size(); i++) {
+    auto eid = idx.entry_id(input_nids[i], 0);
+    // If the input shape is none, we should update them.
+    if ((*in_shape)[i].ndim() == 0 || (*in_shape)[i].Size() == 0)
+      SHAPE_ASSIGN_CHECK(*in_shape, i, shapes[eid]);
+  }
+
+  // For the shape of output data: prepend the loop length.
+  for (int i = 0; i < params.num_out_data; i++) {
+    uint32_t eid = idx.entry_id(g.outputs[i]);
+    const auto& g_out_shape = shapes[eid];
+    auto out = TShape(g_out_shape.ndim() + 1);
+    out[0] = len;
+    for (size_t j = 1; j < out.ndim(); j++)
+      out[j] = g_out_shape[j - 1];
+    SHAPE_ASSIGN_CHECK(*out_shape, i, out);
+  }
+
+  // For the remaining shapes (loop states).
+  for (size_t i = params.num_out_data; i < g.outputs.size(); i++) {
+    uint32_t eid = idx.entry_id(g.outputs[i]);
+    SHAPE_ASSIGN_CHECK(*out_shape, i, shapes[eid]);
+  }
+  size_t num_states = g.outputs.size() - params.num_out_data;
+  for (size_t i = 0; i < num_states; i++) {
+    size_t loc = params.in_state_locs[i];
+    CHECK((*out_shape)[i + params.num_out_data] == (*in_shape)[loc]);
+  }
+  return true;
+}
+
+/*!
+ * \brief Type inference for _foreach. Checks the output count and the single
+ *        subgraph, then delegates to InferSubgraphDataType on that subgraph.
+ */
+static bool ForeachType(const nnvm::NodeAttrs& attrs,
+                        std::vector<int> *in_type, std::vector<int> *out_type) {
+  const ForeachParam& params = nnvm::get<ForeachParam>(attrs.parsed);
+  CHECK_EQ(out_type->size(), (size_t) params.num_outputs);
+  CHECK_EQ(attrs.subgraphs.size(), 1U);
+  return InferSubgraphDataType(*attrs.subgraphs[0], in_type, out_type);
+}
+
+// Storage-type inference for _foreach. After validating the output count and
+// that exactly one subgraph is attached, InferSubgraphStorage derives the
+// storage types and dispatch mode from that subgraph.
+static bool ForeachStorageType(const nnvm::NodeAttrs& attrs,
+                               const int dev_mask,
+                               DispatchMode* dispatch_mode,
+                               std::vector<int> *in_attrs,
+                               std::vector<int> *out_attrs) {
+  const ForeachParam& params = nnvm::get<ForeachParam>(attrs.parsed);
+  CHECK_EQ(out_attrs->size(), (size_t) params.num_outputs);
+  CHECK_EQ(attrs.subgraphs.size(), 1U);
+  const nnvm::Symbol &loop_body = *attrs.subgraphs[0];
+  const bool inferred = InferSubgraphStorage(loop_body, dev_mask,
+                                             dispatch_mode,
+                                             in_attrs, out_attrs);
+  return inferred;
+}
+
+// Storage-type inference for _backward_foreach. It expects one gradient output
+// per forward input (num_args - 1) and exactly one attached subgraph, and
+// hands the forward subgraph to InferSubgraphBackwardStorage.
+static bool BackwardForeachStorageType(const nnvm::NodeAttrs& attrs,
+                                       const int dev_mask,
+                                       DispatchMode* dispatch_mode,
+                                       std::vector<int> *in_attrs,
+                                       std::vector<int> *out_attrs) {
+  const ForeachParam& params = nnvm::get<ForeachParam>(attrs.parsed);
+  CHECK_EQ(out_attrs->size(), (size_t) params.num_args - 1);
+  CHECK_EQ(attrs.subgraphs.size(), 1U);
+  const nnvm::Symbol &fwd_body = *attrs.subgraphs[0];
+  const bool inferred = InferSubgraphBackwardStorage(fwd_body, dev_mask,
+                                                     dispatch_mode,
+                                                     in_attrs, out_attrs);
+  return inferred;
+}
+
+/*!
+ * \brief Creates the ForeachState of a _foreach node from its single subgraph
+ *        and its parsed parameters. ctx, ishape and itype are not used.
+ */
+static OpStatePtr CreateForeachState(const NodeAttrs& attrs,
+                                     Context ctx,
+                                     const std::vector<TShape>& ishape,
+                                     const std::vector<int>& itype) {
+  const ForeachParam& params = nnvm::get<ForeachParam>(attrs.parsed);
+  // Guard the dereference below, consistently with the inference functions.
+  CHECK_EQ(attrs.subgraphs.size(), 1U);
+  return OpStatePtr::Create<ForeachState>(*attrs.subgraphs[0], params);
+}
+
+/*!
+ * \brief Gradient function of _foreach: creates a _backward_foreach node via
+ *        ElemwiseGradUseInOut (its inputs are the output gradients, the
+ *        forward inputs and the forward outputs) and gives it the forward
+ *        node's subgraph.
+ */
+static std::vector<nnvm::NodeEntry>
+ForeachGradient(const nnvm::NodePtr& n,
+                const std::vector<nnvm::NodeEntry>& ograds) {
+  ElemwiseGradUseInOut fgrad{"_backward_foreach"};
+  std::vector<nnvm::NodeEntry> entries = fgrad(n, ograds);
+  entries[0].node->attrs.subgraphs = n->attrs.subgraphs;
+  return entries;
+}
+
+// Registers _foreach, which runs the subgraph attached in attrs.subgraphs
+// once per slice of the input data arrays (see ForeachComputeExCPU).
+NNVM_REGISTER_OP(_foreach)
+.MXNET_DESCRIBE("Run a for loop over an NDArray with user-defined computation")
+.set_attr_parser(ParamParser<ForeachParam>)
+.set_attr<FInferStorageType>("FInferStorageType", ForeachStorageType)
+.set_num_inputs([](const NodeAttrs& attrs) {
+ const ForeachParam& params = nnvm::get<ForeachParam>(attrs.parsed);
+ return params.num_args;
+})
+.set_num_outputs([](const NodeAttrs& attrs) {
+ const ForeachParam& params = nnvm::get<ForeachParam>(attrs.parsed);
+ return params.num_outputs;
+})
+// Input 0 is the loop body ("fn"); the other num_args - 1 inputs are the
+// data arrays, loop states and read-only arrays used by the body.
+.set_attr<nnvm::FListInputNames>("FListInputNames",
+ [](const NodeAttrs& attrs) {
+ const ForeachParam& params = nnvm::get<ForeachParam>(attrs.parsed);
+ std::vector<std::string> names;
+ names.push_back("fn");
+ for (int i = 0; i < params.num_args - 1; i++)
+ names.push_back("data" + std::to_string(i));
+ return names;
+})
+// Only input 0 ("fn", declared below as a Symbol argument) is a graph input.
+.set_attr<nnvm::FInputGraph>("FInputGraph",
+ [](const NodeAttrs& attrs) {
+ return std::vector<uint32_t>{0};
+})
+.set_attr<nnvm::FGradient>("FGradient", ForeachGradient)
+.set_attr<FCreateOpState>("FCreateOpState", CreateForeachState)
+.set_attr<nnvm::FInferShape>("FInferShape", ForeachShape)
+.set_attr<nnvm::FInferType>("FInferType", ForeachType)
+.set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", ForeachComputeExCPU)
+// Foreach operator works like an executor. Its code will always run on CPU.
+// So the same code can be registered for both CPU and GPU.
+.set_attr<FStatefulComputeEx>("FStatefulComputeEx<gpu>", ForeachComputeExCPU)
+.set_attr<std::string>("key_var_num_args", "num_args")
+.add_argument("fn", "Symbol", "Input graph.")
+.add_argument("data", "NDArray-or-Symbol[]",
+ "The input arrays that include data arrays and states.")
+.add_arguments(ForeachParam::__FIELDS__());
+
+// Backward of _foreach. Inputs: output gradients, forward inputs and forward
+// outputs (num_outputs * 2 + num_args - 1 in total). Outputs: one gradient
+// per forward input (num_args - 1).
+NNVM_REGISTER_OP(_backward_foreach)
+.set_num_inputs([](const NodeAttrs& attrs){
+  const ForeachParam& params = nnvm::get<ForeachParam>(attrs.parsed);
+  return params.num_outputs * 2 + params.num_args - 1;
+})
+.set_num_outputs([](const NodeAttrs& attrs){
+  const ForeachParam& params = nnvm::get<ForeachParam>(attrs.parsed);
+  return params.num_args - 1;
+})
+.set_attr<FInferStorageType>("FInferStorageType", BackwardForeachStorageType)
+.set_attr_parser(ParamParser<ForeachParam>)
+.set_attr<bool>("TIsLayerOpBackward", true)
+.set_attr<nnvm::TIsBackward>("TIsBackward", true)
+.set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>",
+                              ForeachGradComputeExCPU)
+.set_attr<FStatefulComputeEx>("FStatefulComputeEx<gpu>",
+                              ForeachGradComputeExCPU);
+
+} // namespace op
+} // namespace mxnet
diff --git a/src/operator/subgraph_op_common.cc
b/src/operator/subgraph_op_common.cc
new file mode 100644
index 00000000000..fa22898c13d
--- /dev/null
+++ b/src/operator/subgraph_op_common.cc
@@ -0,0 +1,256 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include "./subgraph_op_common.h"
+#include "./operator_common.h"
+#include "../imperative/imperative_utils.h"
+
+namespace mxnet {
+namespace op {
+
+/*
+ * Infer the data types of the inputs and outputs of an operator that contains
+ * a subgraph. Type inference runs on the subgraph. Types inferred for the
+ * subgraph inputs are copied back into in_type, and output types into
+ * out_type. The subgraph must have in_type->size() inputs and
+ * out_type->size() outputs.
+ */
+bool InferSubgraphDataType(const nnvm::Symbol &subgraph,
+                           std::vector<int> *in_type,
+                           std::vector<int> *out_type) {
+  nnvm::DTypeVector dtype_inputs = *in_type;
+  nnvm::Graph g;
+  g.outputs = subgraph.outputs;
+  const auto& idx = g.indexed_graph();
+  CHECK_EQ(idx.input_nodes().size(), in_type->size());
+  CHECK_EQ(idx.outputs().size(), out_type->size());
+  imperative::CheckAndInferType(&g, std::move(dtype_inputs), true);
+
+  const auto &dtypes = g.GetAttr<nnvm::DTypeVector>("dtype");
+
+  // Inferring the data type in the subgraph may infer the data type of the
+  // inputs. We need to copy the inferred input data types back.
+  const auto &input_nids = idx.input_nodes();
+  CHECK_EQ(input_nids.size(), in_type->size());
+  for (size_t i = 0; i < in_type->size(); i++) {
+    auto eid = idx.entry_id(input_nids[i], 0);
+    TYPE_ASSIGN_CHECK(*in_type, i, dtypes[eid]);
+  }
+
+  for (size_t i = 0; i < g.outputs.size(); i++)
+    TYPE_ASSIGN_CHECK(*out_type, i, dtypes[idx.entry_id(g.outputs[i])]);
+  return true;
+}
+
+/*
+ * Storage-type inference for an operator that contains a subgraph.
+ * Storage inference runs over the whole subgraph, with every node on dev_mask.
+ * Storage types found for the subgraph inputs are merged back into in_attrs,
+ * the dispatch mode is forced to kFComputeEx, and the subgraph output storage
+ * types are merged into out_attrs.
+ */
+bool InferSubgraphStorage(const nnvm::Symbol &subgraph,
+                          const int dev_mask,
+                          DispatchMode* dispatch_mode,
+                          std::vector<int> *in_attrs,
+                          std::vector<int> *out_attrs) {
+  nnvm::Graph g;
+  g.outputs = subgraph.outputs;
+  const auto& idx = g.indexed_graph();
+  CHECK_EQ(idx.input_nodes().size(), in_attrs->size());
+  CHECK_EQ(idx.outputs().size(), out_attrs->size());
+  imperative::CheckAndInferStorageType(
+      &g, exec::DevMaskVector(idx.num_nodes(), dev_mask),
+      StorageTypeVector(*in_attrs), true);
+
+  const auto& stypes = g.GetAttr<StorageTypeVector>("storage_type");
+
+  // The subgraph may have resolved storage types of its own inputs; merge
+  // them into the operator's input attributes.
+  const auto &input_nids = idx.input_nodes();
+  CHECK_EQ(input_nids.size(), in_attrs->size());
+  size_t pos = 0;
+  for (const uint32_t nid : input_nids) {
+    STORAGE_TYPE_ASSIGN_CHECK(*in_attrs, pos, stypes[idx.entry_id(nid, 0)]);
+    ++pos;
+  }
+
+  DISPATCH_MODE_ASSIGN_CHECK(dispatch_mode, 0, DispatchMode::kFComputeEx);
+  auto &outputs = idx.outputs();
+  CHECK(outputs.size() == out_attrs->size());
+  for (size_t k = 0; k < out_attrs->size(); ++k) {
+    const uint32_t out_eid = idx.entry_id(outputs[k]);
+    STORAGE_TYPE_ASSIGN_CHECK(*out_attrs, k, stypes[out_eid]);
+  }
+  return true;
+}
+
+/*
+ * Infer the storage types for the backward operator of an operator that
+ * contains a subgraph.
+ *
+ * The gradient graph of the subgraph is built with nnvm::pass::Gradient,
+ * with respect to the subgraph's read-only arguments. Every input node of the
+ * gradient graph must be one of the following, in this order:
+ *   - the output-gradient placeholders,
+ *   - the read-only arguments,
+ *   - the forward output nodes.
+ * Each such input node takes its storage type from in_attrs at its position
+ * in that concatenated list. The dispatch mode is set to kFComputeEx.
+ */
+bool InferSubgraphBackwardStorage(const nnvm::Symbol &subgraph,
+                                  const int dev_mask,
+                                  DispatchMode* dispatch_mode,
+                                  std::vector<int> *in_attrs,
+                                  std::vector<int> *out_attrs) {
+  using namespace nnvm;
+  // construct backward graph
+  nnvm::Graph grad_graph;
+  nnvm::Graph fwd_graph;
+  std::vector<Node *> potential_nodes;
+  {
+    fwd_graph.outputs = subgraph.outputs;
+    std::vector<nnvm::NodeEntry> ograd_entries;
+    ograd_entries.reserve(fwd_graph.outputs.size());
+    for (size_t i = 0; i < fwd_graph.outputs.size(); ++i) {
+      ograd_entries.emplace_back(NodeEntry{Node::Create(), 0, 0});
+    }
+
+    std::vector<NodeEntry> xs;
+    std::vector<NodePtr> args =
+        subgraph.ListInputs(nnvm::Symbol::kReadOnlyArgs);
+    xs.reserve(args.size());
+    for (const auto& i : args)
+      xs.emplace_back(NodeEntry{i, 0, 0});
+    CHECK_GT(xs.size(), 0)
+        << "There are no inputs in computation graph that require gradients.";
+
+    static const std::vector<const Op*> zero_ops{Op::Get("zeros_like"),
+                                                 Op::Get("_zeros")};
+    grad_graph = pass::Gradient(
+        fwd_graph, fwd_graph.outputs, xs, ograd_entries,
+        exec::AggregateGradient, nullptr, nullptr,
+        zero_ops, "_copy");
+    // Same order as the backward operator's inputs: output gradients,
+    // forward arguments, forward outputs.
+    potential_nodes.reserve(fwd_graph.outputs.size() + xs.size() +
+                            ograd_entries.size());
+    for (const auto& e : ograd_entries)
+      potential_nodes.push_back(e.node.get());
+    for (const auto& e : xs)
+      potential_nodes.push_back(e.node.get());
+    for (const auto& e : fwd_graph.outputs)
+      potential_nodes.push_back(e.node.get());
+  }
+
+  const auto& idx = grad_graph.indexed_graph();
+  const auto& input_nodes = idx.input_nodes();
+  StorageTypeVector storage_type_inputs(input_nodes.size());
+  for (size_t i = 0; i < input_nodes.size(); i++) {
+    const nnvm::IndexedGraph::Node &n = idx[input_nodes[i]];
+    auto it = std::find(potential_nodes.begin(), potential_nodes.end(),
+                        n.source);
+    CHECK(it != potential_nodes.end());
+    size_t pos = it - potential_nodes.begin();
+    CHECK_LT(pos, in_attrs->size());
+    storage_type_inputs[i] = in_attrs->at(pos);
+  }
+  CHECK_EQ(idx.outputs().size(), out_attrs->size());
+  exec::DevMaskVector dev_masks(idx.num_nodes(), dev_mask);
+  imperative::CheckAndInferStorageType(&grad_graph, std::move(dev_masks),
+                                       std::move(storage_type_inputs), true);
+
+  const auto& stypes = grad_graph.GetAttr<StorageTypeVector>("storage_type");
+  DISPATCH_MODE_ASSIGN_CHECK(dispatch_mode, 0, DispatchMode::kFComputeEx);
+  auto &outputs = idx.outputs();
+  CHECK(outputs.size() == out_attrs->size());
+  for (size_t i = 0; i < out_attrs->size(); i++)
+    STORAGE_TYPE_ASSIGN_CHECK(*out_attrs, i, stypes[idx.entry_id(outputs[i])]);
+  return true;
+}
+
+/*
+ * Run one iteration of the loop body through a newly created CachedOp
+ * (created with inline_limit=0 and no parameters).
+ *
+ * When is_recording is true:
+ *   - autograd recording is switched on for the call;
+ *   - the inputs are marked as autograd variables, each with a freshly
+ *     allocated gradient array and kWriteTo;
+ *   - the inputs, outputs and CachedOp of the iteration are kept for Backward.
+ *
+ * The previous recording flag is restored at the end.
+ */
+void LoopState::Forward(std::vector<NDArray> cinputs,
+                        const std::vector<OpReqType>& req,
+                        std::vector<NDArray> coutputs,
+                        bool is_recording) {
+  using namespace nnvm;
+  using namespace imperative;
+
+  bool orig_is_record;
+  if (is_recording)
+    orig_is_record = Imperative::Get()->set_is_recording(true);
+  else
+    orig_is_record = Imperative::Get()->is_recording();
+
+  std::vector<NDArray *> inputs(cinputs.size());
+  std::vector<NDArray *> outputs(coutputs.size());
+  for (size_t i = 0; i < inputs.size(); i++)
+    inputs[i] = &cinputs[i];
+  for (size_t i = 0; i < outputs.size(); i++)
+    outputs[i] = &coutputs[i];
+
+  if (is_recording) {
+    all_inputs.push_back(cinputs);
+    std::vector<NDArray> gradients(cinputs.size());
+    std::vector<NDArray *> input_ptrs(cinputs.size());
+    std::vector<NDArray *> gradient_ptrs(cinputs.size());
+    std::vector<mx_uint> grad_reqs(cinputs.size());
+    for (size_t i = 0; i < gradients.size(); i++) {
+      gradients[i] = NDArray(cinputs[i].shape(), cinputs[i].ctx(),
+                             true, cinputs[i].dtype());
+      input_ptrs[i] = &cinputs[i];
+      gradient_ptrs[i] = &gradients[i];
+      grad_reqs[i] = kWriteTo;
+    }
+    Imperative::Get()->MarkVariables(input_ptrs, grad_reqs, gradient_ptrs);
+  }
+
+  std::vector<std::pair<std::string, std::string> > kwargs;
+  kwargs.push_back(std::pair<std::string, std::string>("inline_limit", "0"));
+  // Get input names.
+  const auto& idx = subgraph.indexed_graph();
+  std::vector<std::string> arg_names(idx.input_nodes().size());
+  for (size_t i = 0; i < idx.input_nodes().size(); ++i)
+    arg_names[i] = idx[idx.input_nodes()[i]].source->attrs.name;
+  // We don't have parameters for the cached op.
+  std::unordered_map<std::string, std::vector<NDArray> > params;
+  CachedOpPtr op = std::make_shared<Imperative::CachedOp>(subgraph_sym, kwargs,
+                                                          arg_names, params);
+  // TODO(zhengda) we need to avoid shape inference and memory plan whenever
+  // the op is called. Currently, CachedOp allocates memory each time Forward
+  // is called. I need to fix this once the PR for static memory allocation in
+  // CachedOp is merged. https://github.com/apache/incubator-mxnet/pull/10817
+  op->Forward(nullptr, inputs, outputs);
+
+  if (is_recording) {
+    all_outputs.push_back(coutputs);
+    iter_ops.push_back(op);
+  }
+
+  Imperative::Get()->set_is_recording(orig_is_record);
+}
+
+/*
+ * Run the backward pass of the recorded iteration iter_no.
+ *
+ * The CachedOp's backward inputs are, in order:
+ *   - ograds;
+ *   - the saved inputs of that iteration the CachedOp reports as needed
+ *     (save_inputs());
+ *   - the saved outputs it reports as needed (save_outputs()).
+ *
+ * Gradients are written to igrads according to req.
+ */
+void LoopState::Backward(int iter_no,
+                         std::vector<NDArray> ograds,
+                         const std::vector<OpReqType> &req,
+                         std::vector<NDArray> igrads) {
+  using namespace nnvm;
+  using namespace imperative;
+
+  CHECK_GE(iter_no, 0);
+  CHECK_LT(static_cast<size_t>(iter_no), iter_ops.size())
+      << "We didn't record the computation for iteration " << iter_no;
+  auto op = iter_ops[iter_no];
+  std::vector<NDArray *> inputs;
+  std::vector<NDArray *> outputs;
+  inputs.reserve(op->num_backward_inputs());
+  outputs.reserve(op->num_inputs());
+  for (size_t i = 0; i < ograds.size(); i++)
+    inputs.push_back(&ograds[i]);
+
+  const std::vector<bool> &save_inputs = op->save_inputs();
+  const std::vector<bool> &save_outputs = op->save_outputs();
+  CHECK_EQ(save_inputs.size(), all_inputs[iter_no].size());
+  CHECK_EQ(op->num_outputs(), all_outputs[iter_no].size());
+  for (size_t i = 0; i < all_inputs[iter_no].size(); i++) {
+    if (save_inputs[i])
+      inputs.push_back(&all_inputs[iter_no][i]);
+  }
+  for (size_t i = 0; i < all_outputs[iter_no].size(); i++) {
+    if (save_outputs[i])
+      inputs.push_back(&all_outputs[iter_no][i]);
+  }
+  CHECK_EQ(inputs.size(), op->num_backward_inputs());
+  for (size_t i = 0; i < igrads.size(); i++)
+    outputs.push_back(&igrads[i]);
+  CHECK_EQ(outputs.size(), op->num_inputs());
+
+  // The CachedOp state recorded for this iteration is attached to the autograd
+  // entry of its outputs.
+  CHECK(!Imperative::AGInfo::IsNone(all_outputs[iter_no][0]));
+  const nnvm::NodeEntry &node_entry = all_outputs[iter_no][0].entry();
+  OpStatePtr state = Imperative::AGInfo::Get(node_entry.node).state;
+  op->Backward(false, state, inputs, req, outputs);
+}
+
+} // namespace op
+} // namespace mxnet
diff --git a/src/operator/subgraph_op_common.h
b/src/operator/subgraph_op_common.h
new file mode 100644
index 00000000000..74e7cb2d1cc
--- /dev/null
+++ b/src/operator/subgraph_op_common.h
@@ -0,0 +1,98 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#ifndef MXNET_OPERATOR_SUBGRAPH_OP_COMMON_H_
+#define MXNET_OPERATOR_SUBGRAPH_OP_COMMON_H_
+
+#include <mxnet/io.h>
+#include <mxnet/base.h>
+#include <mxnet/op_attr_types.h>
+#include <vector>
+#include "../imperative/imperative_utils.h"
+
+namespace mxnet {
+namespace op {
+
+/*
+ * Infer the data types of inputs and outputs of an operator that contains a
+ * subgraph.
+ *
+ * \param subgraph The subgraph executed by the operator.
+ * \param in_type Data types of the operator's inputs (presumably completed in
+ *        place).
+ * \param out_type Data types of the operator's outputs (presumably completed
+ *        in place).
+ * \return Presumably whether inference succeeded, following the FInferType
+ *         convention -- confirm against the definition.
+ */
+bool InferSubgraphDataType(const nnvm::Symbol &subgraph,
+                           std::vector<int> *in_type,
+                           std::vector<int> *out_type);
+
+/*
+ * Infer the storage types of inputs and outputs of an operator that contains a
+ * subgraph.
+ *
+ * \param subgraph The subgraph executed by the operator.
+ * \param dev_mask Device mask of the target context (as in the
+ *        FInferStorageType signature -- confirm).
+ * \param dispatch_mode Presumably receives the dispatch mode chosen for the
+ *        operator.
+ * \param in_attrs Storage types of the operator's inputs.
+ * \param out_attrs Storage types of the operator's outputs.
+ * \return Presumably whether inference succeeded -- confirm against the
+ *         definition.
+ */
+bool InferSubgraphStorage(const nnvm::Symbol &subgraph,
+                          const int dev_mask,
+                          DispatchMode* dispatch_mode,
+                          std::vector<int> *in_attrs,
+                          std::vector<int> *out_attrs);
+
+/*
+ * Infer the storage types of inputs and outputs of the backward computation of
+ * an operator that contains a subgraph.
+ *
+ * Parameters mirror InferSubgraphStorage; here in_attrs/out_attrs describe the
+ * inputs and outputs of the backward computation.
+ */
+bool InferSubgraphBackwardStorage(const nnvm::Symbol &subgraph,
+                                  const int dev_mask,
+                                  DispatchMode* dispatch_mode,
+                                  std::vector<int> *in_attrs,
+                                  std::vector<int> *out_attrs);
+
+/*
+ * This contains the states for running a loop and provides methods
+ * of running the subgraph computation for an iteration.
+ */
+class LoopState {
+  // Output arrays of every recorded iteration. Backward() also recovers each
+  // iteration's CachedOp state from the autograd info of these arrays.
+  std::vector<std::vector<NDArray> > all_outputs;
+  // Input arrays of every recorded iteration.
+  std::vector<std::vector<NDArray> > all_inputs;
+  std::vector<std::vector<NDArray> > all_gradients;
+  // The CachedOp used for each recorded iteration.
+  std::vector<CachedOpPtr> iter_ops;
+  // The loop body, as a symbol and as a graph sharing the symbol's outputs.
+  Symbol subgraph_sym;
+  nnvm::Graph subgraph;
+
+ public:
+  // explicit: a Symbol must not silently convert into a LoopState.
+  explicit LoopState(const Symbol &g) : subgraph_sym(g) {
+    subgraph.outputs = g.outputs;
+  }
+
+  /*
+   * Run the subgraph for one iteration on `cinputs`, producing `coutputs`.
+   * When `is_recording` is true, the iteration's outputs and CachedOp are
+   * kept so that Backward() can differentiate it later.
+   */
+  void Forward(std::vector<NDArray> cinputs,
+               const std::vector<OpReqType>& req,
+               std::vector<NDArray> coutputs,
+               bool is_recording);
+  /*
+   * Compute the gradients of iteration `iter_no`, which must have been run
+   * by Forward() with recording enabled. `ograds` are the gradients w.r.t.
+   * the iteration's outputs; input gradients go to `igrads` per `req`.
+   */
+  void Backward(int iter_no,
+                std::vector<NDArray> ograds,
+                const std::vector<OpReqType> &req,
+                std::vector<NDArray> igrads);
+  // Drop all per-iteration records; the subgraph itself is kept.
+  void Cleanup() {
+    all_outputs.clear();
+    all_inputs.clear();
+    all_gradients.clear();
+    iter_ops.clear();
+  }
+};
+
+} // namespace op
+} // namespace mxnet
+
+#endif // MXNET_OPERATOR_SUBGRAPH_OP_COMMON_H_
diff --git a/tests/python/unittest/test_gluon_rnn.py b/tests/python/unittest/test_gluon_rnn.py
index 24d5a932d7b..d4ac88900c5 100644
--- a/tests/python/unittest/test_gluon_rnn.py
+++ b/tests/python/unittest/test_gluon_rnn.py
@@ -18,9 +18,10 @@
import mxnet as mx
from mxnet import gluon
import numpy as np
+import copy
from numpy.testing import assert_allclose
import unittest
-from mxnet.test_utils import almost_equal
+from mxnet.test_utils import almost_equal, assert_almost_equal
def test_rnn():
@@ -28,13 +29,62 @@ def test_rnn():
inputs = [mx.sym.Variable('rnn_t%d_data'%i) for i in range(3)]
outputs, _ = cell.unroll(3, inputs)
outputs = mx.sym.Group(outputs)
- assert sorted(cell.collect_params().keys()) == ['rnn_h2h_bias',
'rnn_h2h_weight', 'rnn_i2h_bias', 'rnn_i2h_weight']
+ assert sorted(cell.collect_params().keys()) == ['rnn_h2h_bias',
'rnn_h2h_weight',
+ 'rnn_i2h_bias',
'rnn_i2h_weight']
assert outputs.list_outputs() == ['rnn_t0_out_output',
'rnn_t1_out_output', 'rnn_t2_out_output']
args, outs, auxs = outputs.infer_shape(rnn_t0_data=(10,50),
rnn_t1_data=(10,50), rnn_t2_data=(10,50))
assert outs == [(10, 100), (10, 100), (10, 100)]
+class TestRNNLayer(gluon.HybridBlock):
+ def __init__(self, hidden_size, prefix=None, params=None):
+ super(TestRNNLayer, self).__init__(prefix=prefix, params=params)
+ self.cell = gluon.rnn.RNNCell(hidden_size, prefix='rnn_')
+
+ def hybrid_forward(self, F, inputs, states):
+ states = [states]
+ out, states = F.contrib.foreach(self.cell, inputs, states)
+ return out
+
+def test_contrib_rnn():
+ batch_size = 10
+ hidden_size = 100
+ rnn_data = mx.nd.normal(loc=0, scale=1, shape=(5, batch_size, 50))
+ states = mx.nd.normal(loc=0, scale=1, shape=(batch_size, hidden_size))
+ layer = TestRNNLayer(hidden_size)
+ layer.initialize(ctx=mx.cpu(0))
+ res1 = layer(rnn_data, states)
+ params1 = layer.collect_params()
+ orig_params1 = copy.deepcopy(params1)
+
+ trainer = gluon.Trainer(params1, 'sgd', {'learning_rate' : 0.03})
+ with mx.autograd.record():
+ res1 = layer(rnn_data, states)
+ res1.backward()
+ trainer.step(batch_size)
+
+ layer = TestRNNLayer(hidden_size)
+ layer.initialize(ctx=mx.cpu(0))
+ layer.hybridize()
+ res2 = layer(rnn_data, states)
+ params2 = layer.collect_params()
+ for key, val in orig_params1.items():
+ params2[key].set_data(val.data())
+
+ trainer = gluon.Trainer(params2, 'sgd', {'learning_rate' : 0.03})
+ with mx.autograd.record():
+ res2 = layer(rnn_data, states)
+ assert_almost_equal(res1.asnumpy(), res2.asnumpy(), rtol=0.001,
atol=0.0001)
+ res2.backward()
+ trainer.step(batch_size)
+
+ for key, val in params1.items():
+ weight1 = val.data()
+ weight2 = params2[key].data()
+ assert_almost_equal(weight1.asnumpy(), weight2.asnumpy(), rtol=0.001,
atol=0.0001)
+
+
def test_lstm():
cell = gluon.rnn.LSTMCell(100, prefix='rnn_')
inputs = [mx.sym.Variable('rnn_t%d_data'%i) for i in range(3)]
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index e7976e01f9d..2b2a66725bb 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -24,7 +24,7 @@
import itertools
from numpy.testing import assert_allclose, assert_array_equal
from mxnet.test_utils import *
-from mxnet.base import py_str, MXNetError
+from mxnet.base import py_str, MXNetError, _as_list
from common import setup_module, with_seed
import unittest
@@ -5663,6 +5663,313 @@ def test_float16_min_max():
assert np.finfo('float16').max == mx.nd.max(a).asscalar()
+@with_seed()
+def test_foreach():
+ # Checks sym.contrib.foreach against an imperative mx.nd.contrib.foreach
+ # reimplementation of the same step functions: forward outputs always, and
+ # gradients of every input when is_train is True.
+ # v3..v7 hold the variables named "v0".."v4". The names matter:
+ # verify_foreach binds the i-th array it receives (inputs, then initial
+ # states, then free variables) under the name 'v' + str(i).
+ v3 = mx.sym.var("v0")
+ v4 = mx.sym.var("v1")
+ v5 = mx.sym.var("v2")
+ v6 = mx.sym.var("v3")
+ v7 = mx.sym.var("v4")
+
+ # This tests foreach with accumulation sum.
+ def step1(in1, states, free):
+ out = in1 * 2 + states[0] + free[0]
+ return (out, [out])
+ # Same as step1 with the operands in a different order.
+ def step2(in1, states, free):
+ out = states[0] + in1 * 2 + free[0]
+ return (out, [out])
+ # Two inputs, two states, two outputs.
+ def step3(in1, states, free):
+ out = in1[0] + in1[1] + states[0] + states[1] + free[0]
+ return ([out, out * 2], [out * 2, out * 3])
+
+ # Build the symbolic foreach (outputs doubled, final states appended), run
+ # it with an executor, then recompute the same thing imperatively under
+ # autograd and compare outputs and (if is_train) input gradients.
+ # out_grads = [gradients for the outputs, gradients for the final states].
+ # free_vars_func, if given, derives the imperative free variables from
+ # `frees` (to mirror symbolic free_syms such as v5 + v6).
+ def verify_foreach(step, in_syms, state_syms, free_syms,
+ in_arrs, init_states, frees, out_grads, is_train=True,
+ free_vars_func=None):
+ step_sym = lambda in_syms, state_syms : step(in_syms, state_syms,
free_syms)
+ res, states = mx.sym.contrib.foreach(step_sym, in_syms, state_syms)
+ out = _as_list(res)
+ for i in range(len(out)):
+ out[i] = out[i] * 2
+ out.extend(states)
+ out = mx.sym.Group(out)
+ arr_grads = []
+ arg_dict = {}
+ arg_grad_dict = {}
+ i = 0
+ for arr in _as_list(in_arrs):
+ arr_grad = mx.nd.empty(arr.shape)
+ arr_grads.append(arr_grad)
+ arg_dict['v'+str(i)] = arr
+ arg_grad_dict['v'+str(i)] = arr_grad
+ i = i + 1
+ for arr in init_states:
+ arr_grad = mx.nd.empty(arr.shape)
+ arr_grads.append(arr_grad)
+ arg_dict['v'+str(i)] = arr
+ arg_grad_dict['v'+str(i)] = arr_grad
+ i = i + 1
+ for arr in frees:
+ arr_grad = mx.nd.empty(arr.shape)
+ arr_grads.append(arr_grad)
+ arg_dict['v'+str(i)] = arr
+ arg_grad_dict['v'+str(i)] = arr_grad
+ i = i + 1
+
+ # gin_order[p] is the variable index k of the input named 'v<k>' found at
+ # position p of out.list_inputs().
+ # NOTE(review): the gradient check at the end indexes this with a
+ # variable index (e.grad_arrays[gin_order[i]]), which is only correct when
+ # the permutation is its own inverse (identity or pairwise swaps). Also
+ # grad_arrays presumably follows list_arguments(), not list_inputs() --
+ # equivalent only without auxiliary states. Consider the inverse mapping.
+ gin_order = []
+ for name in out.list_inputs():
+ name = name[1:]
+ gin_order.append(int(name))
+
+ e = out.bind(ctx=default_context(), args=arg_dict,
args_grad=arg_grad_dict)
+ e.forward(is_train=is_train)
+ if (is_train):
+ # backward
+ tmp_grads = out_grads[0][:]
+ tmp_grads.extend(out_grads[1])
+ e.backward(tmp_grads)
+
+ # Below we use imperative to reimplement foreach and compute its
gradients.
+ # NOTE(review): this `res` list is overwritten by the foreach call below,
+ # so the loop has no effect.
+ res = []
+ for i in range(len(_as_list(out_grads[0]))):
+ res.append([])
+ for arr in _as_list(in_arrs):
+ arr.attach_grad()
+ for arr in init_states:
+ arr.attach_grad()
+ for arr in frees:
+ arr.attach_grad()
+ with mx.autograd.record():
+ frees_imp = frees if free_vars_func is None else
free_vars_func(frees)
+ step_imp = lambda in_arrs, state_arrs : step(in_arrs, state_arrs,
frees_imp)
+ # NOTE(review): this `states` value is overwritten on the next line.
+ states = [mx.nd.expand_dims(s, 0) for s in init_states]
+ res, states = mx.nd.contrib.foreach(step_imp, in_arrs, init_states)
+
+ # Mirror the symbolic graph: double every output, append the final
+ # states with a new leading axis, and concatenate along axis 0.
+ res2 = _as_list(res)
+ for i in range(len(res2)):
+ res2[i] = res2[i] * 2
+ if isinstance(states, list):
+ states = [mx.nd.expand_dims(s, 0) for s in states]
+ res2.extend(states)
+ else:
+ states = mx.nd.expand_dims(states, 0)
+ res2.append(states)
+ res = mx.nd.concat(*res2, dim=0)
+
+ # Head gradients in the same concatenated layout as `res`.
+ tmp_grads = out_grads[0][:]
+ tmp_grads1 = [mx.nd.expand_dims(grad, 0) for grad in out_grads[1]]
+ tmp_grads.extend(tmp_grads1)
+ if (is_train):
+ res.backward(mx.nd.concat(*tmp_grads, dim=0))
+ for i in range(len(res2)):
+ assert_almost_equal(e.outputs[i].asnumpy(), res2[i].asnumpy(),
+ rtol=0.001, atol=0.0001)
+ if (is_train):
+ all_ins = _as_list(in_arrs)[:]
+ all_ins.extend(init_states)
+ all_ins.extend(frees)
+ for i in range(len(all_ins)):
+ assert_almost_equal(all_ins[i].grad.asnumpy(),
+ e.grad_arrays[gin_order[i]].asnumpy())
+
+ # Test cases:
+ # * graph inputs are stored in different orders.
+ # This is to test if foreach finds the data arrays and weight arrays
+ # in the right location.
+ # * the number of iterations: odd or even.
+ # * multiple inputs and multiple outputs.
+ # * inference.
+
+ # NOTE(review): the commented-out random inputs below are not used; the
+ # deterministic arange-based inputs that follow are what actually runs.
+ #states = [mx.nd.random.uniform(shape=(2))]
+
+ #frees1 = [mx.nd.random.uniform(shape=(2)),
mx.nd.random.uniform(shape=(2))]
+ #arrs = mx.nd.random.uniform(shape=(3, 2))
+ states = [mx.nd.arange(2)]
+
+ # Two free arrays bound to v5 and v6, used symbolically as v5 + v6.
+ frees1 = [mx.nd.arange(2), mx.nd.arange(2) + 1]
+ arrs = mx.nd.arange(6).reshape(shape=(3, 2))
+ out_grads = [[mx.nd.random.uniform(-10, 10, arrs.shape)],
+ [mx.nd.random.uniform(-10, 10, states[0].shape)]]
+ verify_foreach(step1, v3, [v4], [v5 + v6], arrs, states, frees1,
out_grads, True,
+ lambda frees : [frees[0] + frees[1]])
+ verify_foreach(step1, v3, [v4], [v5 + v6], arrs, states, frees1,
out_grads, False,
+ lambda frees : [frees[0] + frees[1]])
+
+ frees = [mx.nd.random.uniform(shape=(2))]
+ arrs = mx.nd.random.uniform(shape=(2, 2))
+ out_grads = [[mx.nd.random.uniform(-10, 10, arrs.shape)],
+ [mx.nd.random.uniform(-10, 10, states[0].shape)]]
+ verify_foreach(step1, v3, [v4], [v5], arrs, states, frees, out_grads)
+ verify_foreach(step1, v3, [v4], [v5], arrs, states, frees, out_grads,
False)
+
+ arrs = mx.nd.random.uniform(shape=(3, 2))
+ out_grads = [[mx.nd.random.uniform(-10, 10, arrs.shape)],
+ [mx.nd.random.uniform(-10, 10, states[0].shape)]]
+ verify_foreach(step1, v3, [v4], [v5], arrs, states, frees, out_grads)
+ verify_foreach(step1, v3, [v4], [v5], arrs, states, frees, out_grads,
False)
+
+ arrs = mx.nd.random.uniform(shape=(2, 2))
+ out_grads = [[mx.nd.random.uniform(-10, 10, arrs.shape)],
+ [mx.nd.random.uniform(-10, 10, states[0].shape)]]
+ verify_foreach(step2, v3, [v4], [v5], arrs, states, frees, out_grads)
+ verify_foreach(step2, v3, [v4], [v5], arrs, states, frees, out_grads,
False)
+
+ arrs = mx.nd.random.uniform(shape=(3, 2))
+ out_grads = [[mx.nd.random.uniform(-10, 10, arrs.shape)],
+ [mx.nd.random.uniform(-10, 10, states[0].shape)]]
+ verify_foreach(step2, v3, [v4], [v5], arrs, states, frees, out_grads)
+ verify_foreach(step2, v3, [v4], [v5], arrs, states, frees, out_grads,
False)
+
+ # Test multiple inputs and outputs.
+ arrs = [mx.nd.random.uniform(shape=(3, 2)), mx.nd.random.uniform(shape=(3,
2))]
+ states = [mx.nd.random.uniform(shape=(2)), mx.nd.random.uniform(shape=(2))]
+ out_grads = [[mx.nd.random.uniform(-10, 10, arrs[0].shape),
mx.nd.random.uniform(-10, 10, arrs[1].shape)],
+ [mx.nd.random.uniform(-10, 10, states[0].shape),
mx.nd.random.uniform(-10, 10, states[1].shape)]]
+ verify_foreach(step3, [v3, v4], [v5, v6], [v7], arrs, states, frees,
out_grads)
+ verify_foreach(step3, [v3, v4], [v5, v6], [v7], arrs, states, frees,
out_grads, False)
+
+
+@with_seed()
+def test_foreach_nested():
+ # Test nested foreach.
+ # The inner loop (step_in) accumulates in1 * 2 slice by slice, starting
+ # from the outer state. The outer step adds its state to that result and
+ # passes the squeezed first row on as the next outer state.
+ def step_in(in1, states):
+ out = in1 * 2 + states[0]
+ return (out, [out])
+
+ def step(in1, states):
+ out1 = mx.sym.contrib.foreach(step_in, in1, states)
+ out = mx.sym.broadcast_add(out1[0], states[0])
+ # NOTE(review): end=(1, 2) hard-codes the last dimension of `data` (2).
+ return (out, [mx.sym.squeeze(mx.sym.slice(out, begin=(0, 0), end=(1,
2)))])
+
+ data_sym = mx.sym.var("v1")
+ state_sym = mx.sym.var("v2")
+ out = mx.sym.contrib.foreach(step, data_sym, [state_sym])
+
+ # NOTE(review): the loop below is a no-op (out1[i] = out1[i]).
+ out1 = _as_list(out[0])
+ for i in range(len(out1)):
+ out1[i] = out1[i]
+ out1.extend(out[1])
+ out = mx.sym.Group(out1)
+
+ data = mx.nd.arange(4).reshape((1, 2, 2))
+ state = mx.nd.arange(2)
+ # NOTE(review): gradient arrays are bound, but backward is never run and
+ # only forward output 0 is checked below.
+ data_grad = mx.nd.empty(data.shape)
+ state_grad = mx.nd.empty(state.shape)
+ e = out.bind(ctx=default_context(), args={'v1':data, 'v2':state},
+ args_grad={'v1':data_grad, 'v2':state_grad})
+ e.forward(is_train=True)
+ # Imperative reference implementation of the nested loop.
+ out = mx.nd.zeros_like(data)
+ for i in range(data.shape[0]):
+ data1 = data[i]
+ out1 = mx.nd.zeros_like(data1)
+ for j in range(data1.shape[0]):
+ if (j > 0):
+ out1[j] = out1[j-1] + data1[j] * 2
+ else:
+ out1[j] = data1[j] * 2 + state
+ # NOTE(review): for i > 0 the inner loop above has already used the
+ # previous `state`, whereas the symbolic version feeds the updated outer
+ # state into the inner foreach. Only data.shape[0] == 1 is exercised here,
+ # so this branch is untested -- verify before adding outer iterations.
+ if (i > 0):
+ state = mx.nd.squeeze(mx.nd.slice(out[i-1], begin=(0, 0), end=(1,
2)))
+ out[i] = mx.nd.broadcast_add(out1, state)
+ else:
+ out[i] = mx.nd.broadcast_add(out1, state)
+ # NOTE(review): no-op assignment.
+ out = out
+ assert_almost_equal(out.asnumpy(), e.outputs[0].asnumpy(), rtol=0.001,
atol=0.0001)
+
+
+@with_seed()
+def test_foreach_lstm():
+ data = mx.sym.var("data")
+ init_h = mx.sym.var("h")
+ init_c = mx.sym.var("c")
+ i2h_weight = mx.sym.var("i2h_weight")
+ h2h_weight = mx.sym.var("h2h_weight")
+ i2h_bias = mx.sym.var("i2h_bias")
+ h2h_bias = mx.sym.var("h2h_bias")
+
+ # This tests foreach with accumulation sum.
+ def step(in1, states):
+ params = mx.rnn.RNNParams()
+ params._params['i2h_weight'] = i2h_weight
+ params._params['h2h_weight'] = h2h_weight
+ params._params['i2h_bias'] = i2h_bias
+ params._params['h2h_bias'] = h2h_bias
+ lstm = mx.rnn.LSTMCell(4, prefix='mylstm_', params=params)
+ next_h, [next_h, next_c] = lstm(in1, states)
+ # TODO This is problematic. We can't count on the user to define two
different symbols.
+ return (next_h, [next_h, next_c])
+
+ def sym_group(out):
+ if (isinstance(out[0], mx.sym.Symbol)):
+ ret = [out[0]]
+ else:
+ ret = out[0]
+ ret.extend(out[1])
+ return mx.sym.Group(ret)
+
+ data_arr = mx.nd.random.uniform(shape=(2, 2, 4))
+ h_arr = mx.nd.random.uniform(shape=(2, 4))
+ c_arr = mx.nd.random.uniform(shape=(2, 4))
+ i2h_warr = mx.nd.random.uniform(shape=(16, 4))
+ h2h_warr = mx.nd.random.uniform(shape=(16, 4))
+ i2h_barr = mx.nd.random.uniform(shape=(16))
+ h2h_barr = mx.nd.random.uniform(shape=(16))
+
+ data_arr_grad1 = mx.nd.empty(data_arr.shape)
+ h_arr_grad1 = mx.nd.empty(h_arr.shape)
+ c_arr_grad1 = mx.nd.empty(c_arr.shape)
+ i2h_warr_grad1 = mx.nd.empty(i2h_warr.shape)
+ h2h_warr_grad1 = mx.nd.empty(h2h_warr.shape)
+ i2h_barr_grad1 = mx.nd.empty(i2h_barr.shape)
+ h2h_barr_grad1 = mx.nd.empty(h2h_barr.shape)
+ out = mx.sym.contrib.foreach(step, data, [init_h, init_c])
+ out = sym_group(out)
+ e1 = out.bind(ctx=default_context(),
+ args={'data': data_arr, 'h': h_arr, 'c': c_arr,
+ 'i2h_weight': i2h_warr, 'h2h_weight': h2h_warr,
+ 'i2h_bias': i2h_barr, 'h2h_bias': h2h_barr},
+ args_grad={'data': data_arr_grad1, 'h': h_arr_grad1, 'c':
c_arr_grad1,
+ 'i2h_weight': i2h_warr_grad1, 'h2h_weight':
h2h_warr_grad1,
+ 'i2h_bias': i2h_barr_grad1, 'h2h_bias':
h2h_barr_grad1})
+ e1.forward(is_train=True)
+ outputs1 = e1.outputs
+ # backward
+ out_grads = []
+ for arr in e1.outputs:
+ out_grads.append(mx.nd.random.uniform(-10, 10, arr.shape))
+ e1.backward(out_grads)
+
+ data_arr_grad2 = mx.nd.empty(data_arr.shape)
+ h_arr_grad2 = mx.nd.empty(h_arr.shape)
+ c_arr_grad2 = mx.nd.empty(c_arr.shape)
+ i2h_warr_grad2 = mx.nd.empty(i2h_warr.shape)
+ h2h_warr_grad2 = mx.nd.empty(h2h_warr.shape)
+ i2h_barr_grad2 = mx.nd.empty(i2h_barr.shape)
+ h2h_barr_grad2 = mx.nd.empty(h2h_barr.shape)
+ lstm = mx.rnn.LSTMCell(4, prefix='mylstm_')
+ h = init_h
+ c = init_c
+ unroll_outs = []
+ for inputs in mx.sym.split(data, num_outputs=data_arr.shape[0], axis=0,
squeeze_axis=True):
+ h, [h, c] = lstm(inputs, [h, c])
+ unroll_outs.append(mx.sym.expand_dims(h, axis=0))
+ unroll_outs = mx.sym.concat(*unroll_outs, dim=0)
+ out = mx.sym.Group([unroll_outs, h, c])
+ e2 = out.bind(ctx=default_context(),
+ args={'data': data_arr, 'h': h_arr, 'c': c_arr,
+ 'mylstm_i2h_weight': i2h_warr, 'mylstm_h2h_weight':
h2h_warr,
+ 'mylstm_i2h_bias': i2h_barr, 'mylstm_h2h_bias':
h2h_barr},
+ args_grad={'data': data_arr_grad2, 'h': h_arr_grad2, 'c':
c_arr_grad2,
+ 'mylstm_i2h_weight': i2h_warr_grad2,
'mylstm_h2h_weight': h2h_warr_grad2,
+ 'mylstm_i2h_bias': i2h_barr_grad2,
'mylstm_h2h_bias': h2h_barr_grad2})
+ e2.forward(is_train=True)
+ outputs2 = e2.outputs
+ e2.backward(out_grads)
+
+ for i in range(len(outputs2)):
+ assert_almost_equal(outputs1[i].asnumpy(), outputs2[i].asnumpy(),
rtol=0.001, atol=0.0001)
+ for i in range(len(e1.grad_arrays)):
+ assert_almost_equal(e1.grad_arrays[i].asnumpy(),
e2.grad_arrays[i].asnumpy())
+
+
@with_seed()
def test_squeeze_op():
def check_squeeze_op(shape, axis=None):
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
users@infra.apache.org
With regards,
Apache Git Services