piiswrong commented on a change in pull request #10451: [MXNET-432] Add Foreach
URL: https://github.com/apache/incubator-mxnet/pull/10451#discussion_r189501415
 
 

 ##########
 File path: python/mxnet/ndarray/contrib.py
 ##########
 @@ -96,3 +98,78 @@ def rand_zipfian(true_classes, num_sampled, range_max, 
ctx=None):
     expected_count_sampled = expected_prob_sampled * num_sampled
     return sampled_classes, expected_count_true, expected_count_sampled
 # pylint: enable=line-too-long
+
def foreach(func, data, init_states):
    """Run a for loop with user-defined computation over NDArrays on dimension 0.

    This operator simulates a for loop and ``func`` has the computation for an
    iteration of the for loop. It runs the computation in ``func`` on each slice
    (taken along dimension 0) of the input NDArrays.

    ``func`` takes two arguments as input and outputs a tuple of two elements,
    as illustrated below::

        out, states = func(data1, states)

    ``data1`` can be either an NDArray or a list of NDArrays. If ``data`` is an
    NDArray, ``data1`` is an NDArray. Otherwise, ``data1`` is a list of NDArrays
    and has the same size as ``data``. ``states`` is a list of NDArrays and has
    the same size as ``init_states``. Similarly, ``out`` can be either an
    NDArray or a list of NDArrays; the ``out`` values of all iterations are
    stacked along a new dimension 0 to form the first output of foreach. The
    states from the last execution of ``func`` are the second output of foreach.

    The computation done by this operator is equivalent to the pseudo code
    below when the input data is a single NDArray and ``func`` returns a
    single NDArray::

        states = init_states
        outs = []
        for i in range(data.shape[0]):
            s = data[i]
            out, states = func(s, states)
            outs.append(out)
        outs = stack(*outs)

    Parameters
    ----------
    func : a Python function.
        Define computation in an iteration.
    data : an NDArray or a list of NDArrays.
        The input data. When a list is given, it must be non-empty and all
        NDArrays must have the same size on dimension 0.
    init_states : a list of NDArrays.
        The initial values of the loop states.

    Returns
    -------
    outputs : an NDArray or a list of NDArrays.
        The outputs of all iterations, stacked along a new dimension 0. A
        single NDArray is returned when ``func`` returns a single (non-list)
        output; otherwise a list holding one stacked NDArray per output of
        ``func``. If the input has size 0 on dimension 0, ``func`` is never
        called and an empty list is returned.
    states : a list of NDArrays.
        The loop states in the last iteration (``init_states`` if ``func``
        is never called).

    Raises
    ------
    ValueError
        If ``data`` is an empty list, if the NDArrays in ``data`` differ in
        size on dimension 0, or if ``func`` returns a different number of
        outputs in different iterations.

    Examples
    --------
    >>> step = lambda data, states: (data + states[0], [states[0] * 2])
    >>> data = mx.nd.random.uniform(shape=(2, 10))
    >>> states = [mx.nd.random.uniform(shape=(10,))]
    >>> outs, states = mx.nd.contrib.foreach(step, data, states)
    """
    assert isinstance(init_states, list), "init_states should be a list"

    data_is_list = isinstance(data, list)
    if data_is_list:
        if not data:
            raise ValueError("data should be an NDArray or a non-empty list of NDArrays")
        num_iters = data[0].shape[0]
        for arr in data:
            if arr.shape[0] != num_iters:
                raise ValueError(
                    "all NDArrays in data must have the same size on dimension 0: "
                    "got %d and %d" % (num_iters, arr.shape[0]))
    else:
        num_iters = data.shape[0]

    states = init_states
    # outputs[j] collects the j-th output of func across all iterations.
    outputs = []
    single_output = False
    for i in range(num_iters):
        ele = [arr[i] for arr in data] if data_is_list else data[i]
        outs, states = func(ele, states)
        if i == 0:
            # Remember whether func returned a bare output so the stacked
            # result can be handed back in the same form, not wrapped in a list.
            single_output = not isinstance(outs, (list, tuple))
        outs = _as_list(outs)
        if i == 0:
            outputs = [[] for _ in outs]
        elif len(outs) != len(outputs):
            raise ValueError(
                "func must return the same number of outputs in every iteration: "
                "got %d in iteration %d, expected %d" % (len(outs), i, len(outputs)))
        for bucket, out in zip(outputs, outs):
            bucket.append(out)

    # Build a new list: rebinding a loop variable (``for out in outputs:
    # out = stack(*out)``) would leave ``outputs`` unstacked.
    outputs = [stack(*out) for out in outputs]
    if single_output and len(outputs) == 1:
        outputs = outputs[0]
    return (outputs, states)
 
 Review comment:
   Is the return value always a list?
   If func returns a single value, shouldn't outputs also be a single value?

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
[email protected]


With regards,
Apache Git Services

Reply via email to