zheng-da commented on a change in pull request #10451: [MXNET-432] Add Foreach
URL: https://github.com/apache/incubator-mxnet/pull/10451#discussion_r189734691
 
 

 ##########
 File path: python/mxnet/symbol/contrib.py
 ##########
 @@ -91,3 +98,176 @@ def rand_zipfian(true_classes, num_sampled, range_max):
     expected_prob_sampled = ((sampled_cls_fp64 + 2.0) / (sampled_cls_fp64 + 
1.0)).log() / log_range
     expected_count_sampled = expected_prob_sampled * num_sampled
     return sampled_classes, expected_count_true, expected_count_sampled
+
+def _get_graph_inputs(subg):
+    num_handles = ctypes.c_int(1000)
+    handles = c_array(SymbolHandle, [SymbolHandle(0) for i in range(1000)])
+    check_call(_LIB.MXSymbolGetInputSymbols(subg.handle, handles, 
ctypes.byref(num_handles)))
+
+    syms = []
+    for i in range(num_handles.value):
+        s = Symbol(handles[i])
+        syms.append(s)
+    return syms
+
+def foreach(func, data, init_states, name="foreach"):
+    """Run a for loop with user-defined computation over Symbols on dimension 
0.
+
+    This operator simulates a for loop and func has the computation for an 
iteration
+    of the for loop. It runs the computation in func on each slice from the 
input
+    Symbols.
+
+    func takes two arguments as input and outputs a tuple of two elements,
+    as illustrated below:
+
+    out, states = func(data1, states)
+
+    data1 can be either a symbol or a list of symbols. If data is a symbol,
+    data1 is a symbol. Otherwise, data1 is a list of symbols and has the same
+    size as data. states is a list of symbols and has the same size as 
init_states.
+    Similarly, out can be either a symbol or a list of symbols, which are 
concatenated
+    as the first output of foreach; states from the last execution of func
+    are the second output of foreach.
+
+    The computation done by this operator is equivalent to the pseudo code 
below
+    when the input data is NDArray:
+
+    states = init_states
+    outs = []
+    for i in range(data.shape[0]):
+        s = data[i]
+        out, states = func(s, states)
+        outs.append(out)
+    outs = stack(*outs)
+
+
+    Parameters
+    ----------
+    func : a Python function.
+        Define computation in an iteration.
+    data: a symbol or a list of symbols.
+        The input data.
+    init_states: a symbol or a list of symbols.
+        The initial values of the loop states.
+    name: string.
+        The name of the operator.
+
+    Returns
+    -------
+    outputs: a Symbol or a list of Symbols.
+        The output data concatenated from the output of all iterations.
+    states: a list of Symbols.
+        The loop states in the last iteration.
+
+    Examples
+    --------
+    >>> step = lambda data, states: (data + states[0], [states[0] * 2])
+    >>> data = mx.sym.var('data')
+    >>> states = [mx.sym.var('state')]
+    >>> outs, states = mx.sym.contrib.foreach(step, data, states)
+    """
+
+    def check_data(inputs, in_type, msg):
+        is_NDArray_or_list = True
+        if isinstance(inputs, list):
+            for i in inputs:
+                if not isinstance(i, in_type):
+                    is_NDArray_or_list = False
+                    break
+        else:
+            is_NDArray_or_list = isinstance(inputs, in_type)
+        assert is_NDArray_or_list, msg
+
+    check_data(data, symbol.Symbol, "data should be an NDArray or a list of 
NDArrays")
+    check_data(init_states, symbol.Symbol,
+            "init_states should be an NDArray or a list of NDArrays")
+    not_state_list = isinstance(init_states, symbol.Symbol)
+
+    # TODO(zhengda) If the input python function references to the symbols 
outside
+    # the python function, we need to prune the computation graph constructed 
from
+    # the function. One way of doing it is to mark the nodes in the 
computation graph
+    # with AttrScope and prune the nodes without the special attribute.
+    with AttrScope(subgraph_name=name):
 
 Review comment:
   alternative of what?

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

Reply via email to