szha commented on a change in pull request #11566: [MXNET-626] Add while_loop
URL: https://github.com/apache/incubator-mxnet/pull/11566#discussion_r203513937
 
 

 ##########
 File path: python/mxnet/symbol/contrib.py
 ##########
 @@ -336,3 +336,221 @@ def check_data(inputs, in_type, msg):
         states = states[0]
 
     return (outs, states)
+
+def while_loop(cond, func, loop_vars, max_iterations=None, name="while_loop"):
+    """Run a while loop with user-defined computation and loop condition.
+
+    This operator simulates a while loop which iteratively does customized computation
+    as long as the condition is satisfied.
+
+    `loop_vars` is a list of Symbols that the computation uses.
+
+    `cond` is a user-defined function, used as the loop condition.
+    It consumes `loop_vars`, and produces a scalar MXNet symbol,
+    indicating the termination of the loop.
+    The loop ends when `cond` returns false (zero).
+    The `cond` is variadic, and its signature should be
+    `cond(*loop_vars) => Symbol`.
+
+    `func` is a user-defined function, used as the loop body.
+    It also consumes `loop_vars`, and produces `step_output` and `new_loop_vars` at each step.
+    In each step, `step_output` should contain the same number of elements.
+    Through all steps, the i-th element of `step_output` should have the same shape and dtype.
+    Also, `new_loop_vars` should contain the same number of elements as `loop_vars`,
+    and the corresponding element should have the same shape and dtype.
+    The `func` is variadic, and its signature should be
+    `func(*loop_vars) => (List[Symbol] step_output, List[Symbol] new_loop_vars)`.
+
+    `max_iterations` is a scalar that defines the maximum number of iterations allowed.
+
+    This function returns two lists.
+    The first list has the length of `|step_output|`,
+    in which the i-th element are all i-th elements of
+    `step_output` from all steps, stacked along axis 0.
+    The second list has the length of `|loop_vars|`,
+    which represents final states of loop variables.
+
+    .. warning::
+    For now, the axis 0 of all Symbols in the first list are `max_iterations`,
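
For readers following the docstring quoted above, the loop semantics it describes can be modelled in plain Python. This is only a rough sketch, not the MXNet implementation: `while_loop_reference` is a hypothetical helper, plain Python values stand in for Symbols, "stacking along axis 0" is modelled as appending to a list, and `max_iterations` is assumed to be required here.

```python
def while_loop_reference(cond, func, loop_vars, max_iterations):
    """Plain-Python model of the documented semantics (hypothetical helper)."""
    loop_vars = list(loop_vars)
    stacked = None  # one list per element of step_output, filled step by step
    steps = 0
    # The loop ends when cond returns a false value or the step cap is reached.
    while steps < max_iterations and cond(*loop_vars):
        step_output, new_loop_vars = func(*loop_vars)
        if len(new_loop_vars) != len(loop_vars):
            raise ValueError("new_loop_vars must have as many elements as loop_vars")
        if stacked is None:
            stacked = [[] for _ in step_output]
        elif len(step_output) != len(stacked):
            raise ValueError("step_output must have the same length at every step")
        # Collect the i-th step output into the i-th result list.
        for bucket, value in zip(stacked, step_output):
            bucket.append(value)
        loop_vars = list(new_loop_vars)
        steps += 1
    return (stacked or []), loop_vars


def cond(i, s):
    # Keep looping while the counter is below 5.
    return i < 5


def func(i, s):
    # Emit the running sum as the step output and advance both loop variables.
    return [s + i], [i + 1, s + i]


outputs, final_vars = while_loop_reference(cond, func, [0, 0], max_iterations=10)
```

Here `outputs` holds one list per step output, and `final_vars` holds the loop variables after the last step.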
 
 Review comment:
   There should be an empty line before and after the warning block. Also, the
block's content needs to be indented for the directive to work. Refer to the
Sphinx documentation.
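
As a sketch of the layout being requested, the docstring below puts a blank line before and after the `.. warning::` directive and indents the warning body under it. The function name `while_loop_doc_example` and the warning wording are placeholders for illustration, not the PR's actual text.

```python
import inspect


def while_loop_doc_example():
    """Run a while loop with user-defined computation and loop condition.

    This function returns two lists.

    .. warning::

       Placeholder warning text: the body is indented relative to the
       directive, with a blank line before and after the block.

    Text following the warning starts after a blank line, at the
    original indentation.
    """


# cleandoc removes the docstring's common leading indentation, which makes
# the relative indentation of the warning body easy to inspect.
doc_lines = inspect.cleandoc(while_loop_doc_example.__doc__).splitlines()
idx = doc_lines.index(".. warning::")
```

After dedenting, the directive sits at column 0, its neighbours on both sides are empty lines, and the body lines keep their extra indentation.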
