szha commented on a change in pull request #11566: [MXNET-626] Add while_loop
URL: https://github.com/apache/incubator-mxnet/pull/11566#discussion_r203214074
##########
File path: python/mxnet/symbol/contrib.py
##########
@@ -336,3 +336,205 @@ def check_data(inputs, in_type, msg):
states = states[0]
return (outs, states)
+
+def while_loop(cond, func, loop_vars, max_iterations, name="while_loop"):
+ """Run a while loop with user-defined computation and loop condition.
+
+ This operator simulates a while loop which iteratively does customized
computation
+ as long as the condition is satisfied.
+
+ `loop_vars` is a list of Symbols that the computation uses.
+
+ `cond` is a user-defined function as the loop condition.
+ It consumes `loop_vars`, and produces a scalar MXNet symbol,
+ indicating the termination of the loop.
+ The loop ends when `cond` returns false (zero).
+ The `cond` is variadic, and its signature should be
+ `cond(*loop_vars) => Symbol`.
+
+ `func` is a user-defined function as the loop body.
+ It also consumes `loop_vars`, and produces `step_output` and
`new_loop_vars` at each step.
+ The number of elements, shape, dtype of each element in `step_output`
should be consistent.
+ The `new_loop_vars` should be consistent with `loop_vars` on each step.
+ The `func` is variadic, and its signature should be
+ `func(*loop_vars) => (List[Symbol] step_output, List[Symbol]
new_loop_vars)`.
+
+ `max_iterations` is a scalar that defines the maximum number of iterations
allowed.
+
+ This function returns a list of Symbols of length `|step_output| +
|loop_vars|`.
+ The i-th element in the first `|step_output|` ones of the list represents
+ the i-th `step_output` at all steps, stacked along axis 0.
+ The i-th element in the last `|loop_vars|` ones of the list
+ represents the final state of each loop variable.
+
+ Parameters
+ ----------
+ loop_vars: list of Symbol.
+ The initial values of the loop variables.
+ cond: a Python function.
+ The loop condition.
+ func: a Python function.
+ The loop body.
+ max_iterations: a python int.
+ Maximum number of iterations.
+
+ Returns
+ -------
+ outputs: a tuple of two lists, both of which contain 0, 1 or more Symbols.
+ The first list contains the stacked output from each step,
+ The second list contains the final state.
+
+ Examples
+ --------
+ >>> cond = lambda i, s: i <= 5
+ >>> func = lambda i, s: ([i + s], [i + 1, s + i])
+ >>> loop_vars = (mx.sym.var('i'), mx.sym.var('s'))
+ >>> outputs, states = mx.sym.contrib.while_loop(cond, func, loop_vars,
max_iterations=10)
+ """
+ def _to_python_scalar(inputs, type_, name):
+ """Converts "inputs", possibly typed mxnet NDArray, a numpy ndarray,
other python types,
+ to the given type
+ """
+ if hasattr(inputs, "asscalar"):
+ inputs = inputs.asscalar()
+ try:
+ inputs = type_(inputs)
+ except:
+ raise ValueError("Cannot convert %s to python %s" % (name,
type_.__name__))
+ return inputs
+
+ def _to_symbol_tuple(inputs, name):
+ """Converts "inputs", possibly a single mxnet Symbol, a list of mxnet
Symbol,
+ a tuple of mxnet Symbol, into a tuple of Symbol
+ """
+ if isinstance(inputs, list):
+ inputs = tuple(inputs)
+ if isinstance(inputs, Symbol):
+ inputs = (inputs, )
+ if not isinstance(inputs, tuple):
+ raise ValueError("%s must be a Symbol, or a tuple or list of
Symbol" % (name, ))
+ for item in inputs:
+ if not isinstance(item, Symbol):
+ raise ValueError("%s must be a Symbol, or a tuple or list of
Symbol" % (name, ))
+ return inputs
+
+ def _cond_wrapper(loop_vars):
+ result = cond(*loop_vars)
+ if not isinstance(result, Symbol):
+ raise ValueError("Return of cond must be a Symbol")
+ return [], [result]
+
+ def _func_wrapper(loop_vars):
+ """This wrapper unifies
+ "func: loop_vars -> new_loop_vars"
+ and "func: loop_vars -> (step_output, new_loop_vars)"
+ into "func: loop_vars -> (list of step_outputs, tuple of new_loop_vars)"
+ """
+ step_output, new_loop_vars = func(*loop_vars)
+ if step_output is None:
+ step_output = []
+ if new_loop_vars is None:
+ new_loop_vars = []
+ step_output = _to_symbol_tuple(step_output, "step_output")
+ new_loop_vars = _to_symbol_tuple(new_loop_vars, "new_loop_vars")
+ if len(loop_vars) != len(new_loop_vars):
+ raise ValueError("The number of loop_vars should be consistent
during the loop")
+ return list(step_output), list(new_loop_vars)
+
+ def _create_subgraph(graph_vars, graph_func, subgraph_name):
+ with AttrScope(__subgraph_name__=subgraph_name):
+ # create new variables with the same name,
+ # then feed them to the given func
+ new_graph_vars = [symbol.var(sym.name) for sym in graph_vars]
+ outputs, final_state = graph_func(new_graph_vars)
+ # first `num_out_data` elements belong to `outputs`
+ # other elements belong to `final_state`
+ num_out_data = len(outputs)
+ num_outputs = len(outputs) + len(final_state)
+ # nnvm cut-graph does not allow inputs and outputs overlap
+ # so we calculate the name of inputs, and copy outputs once it
overlaps with inputs
+ all_input_names = symbol.Group(outputs + final_state).list_inputs()
+ make_identity = lambda x: symbol.op.identity(x) if x.name in
all_input_names else x
+ # group all outputs of graph_func
+ graph = symbol.Group(list(map(make_identity, outputs +
final_state)))
+ return graph, num_out_data, num_outputs
+
+ def _union_inputs(*graphs):
+ # Given a list of graphs, each whose inputs are either from loop_vars
or other variables.
+ # 1) calculate a list `inputs`, the union of their inputs.
+ # 2) for each graph, determine in which indices their inputs reside in
`inputs`
+ # 3) for each variable in the input of `graph`, find which index it is
+ inputs = [] # List[Symbol], result of 1)
+ locs = [] # List[Tuple(List[Int], List[Int])], a list of
tuples,
+ # where tuples are results of 2) and 3)
+ input_id_to_loc = {} # Dict[int, int], given id(sym),
input_id_to_loc maps it
+ # to a `loc`, where inputs[loc] = sym
+ for graph in graphs:
+ # input_syms: all inputs to the `graph`
+ name_to_input_syms = {sym.name: sym for sym in
_get_graph_inputs(graph)}
+ # some loop_vars are inputs to `graph`, some are not
+ name_to_loop_vars = {sym.name: sym for sym in loop_vars}
+ # other inputs to `graph` created by cut_graph
+ name_to_cut_g_syms = {sym.list_outputs()[0]: sym for sym in
_cut_subgraph(graph)}
+ # also we collect the mapping from var's name to var's loc in
loop_vars
+ name_to_var_locs = {sym.name: i for i, sym in enumerate(loop_vars)}
+ # collect arguments for each subgraph
+ input_locs = [] # results from the second
step
+ var_locs = [-1] * len(loop_vars) # results from the third
step
+ for name in graph.list_inputs():
+ assert name in name_to_input_syms # it should obviously hold
+ # name -> sym
+ if name in name_to_loop_vars:
+ sym = name_to_loop_vars[name]
+ elif name in name_to_cut_g_syms:
+ sym = name_to_cut_g_syms[name]
+ else:
+ sym = copy.deepcopy(name_to_input_syms[name])
+ # do 2), and 1) is implicitly done
+ if id(sym) in input_id_to_loc:
Review comment:
Why the id instead of checking for the symbol directly?
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
With regards,
Apache Git Services