zheng-da commented on a change in pull request #11760: [MXNET-684] Add `ifelse` operator
URL: https://github.com/apache/incubator-mxnet/pull/11760#discussion_r203915519
##########
File path: python/mxnet/symbol/contrib.py
##########
@@ -556,3 +556,154 @@ def _union_inputs(*graphs):
outputs = [result[i] for i in range(num_out_data)]
final_loop_vars = [result[i] for i in range(num_out_data, num_outputs)]
return outputs, final_loop_vars
+
+def ifelse(cond, then_func, else_func, inputs, name="ifelse"):
+ """Run a if-then-else using user-defined condition and computation
+
+ This operator simulates a if-like branch which chooses to do one of
+ the two customized computations according to the specified condition.
+
+ `inputs` is a list of Symbols on which the condition and computations
reply on.
+
+ `cond` is a user-defined function, used as the if condition.
+ It consumes `inputs`, and produces a scalar MXNet symbol,
+ indicating which branch of computation should be used.
+ The `cond` is variadic, and its signature should be
+ `cond(*loop_vars) => Symbol`.
+
+ `then_func` is a user-defined function, used as computation of the then
branch.
+ It consumes `inputs`, and produces `outputs`.
+ The `then_func` is variadic, and its signature should be
+ `then_func(*loop_vars) => List[Symbol]`.
+
+ `else_func` is a user-defined function, used as computation of the else
branch.
+ It also consumes `inputs`, and produces `outputs`.
+ The `else_func` is variadic, and its signature should be
+ `else_func(*loop_vars) => List[Symbol]`.
+
+ The `outputs` produces by `then_func` and `else_func` should have the same
number
+ of elements, all of which should be in the same shape, of the same dtype
and stype.
+
+ This function returns a list of symbols, representing the computation
result.
+
+ Parameters
+ ----------
+ cond: a Python function.
+ The branch condition.
+ then_func: a Python function.
+ The computation to be executed if `cond` is true.
+ else_func: a Python function.
+ The computation to be executed if `cond` is false.
+ inputs: list of Symbols.
+ The variables fed to `cond`, `then_func` and `else_func`.
+
+ Returns
+ -------
+ outputs: a list of Symbols, representing the result of computation.
+
+ Examples
+ --------
+ >>> cond = lambda a, b: a * b < 5
+ >>> then_func = lambda a, b: (a + 5) * (b + 5)
+ >>> else_func = lambda a, b: (a - 5) * (b - 5)
+ >>> inputs = (mx.sym.var('a'), mx.sym.var('b'))
+ >>> outputs = mx.sym.contrib.ifelse(cond, then_func, else_func, inputs)
+ """
+ def _to_symbol_tuple(inputs, name):
+ """Converts "inputs", possibly a single mxnet Symbol, a list of mxnet
Symbol,
+ a tuple of mxnet Symbol, into a tuple of Symbol
+ """
+ if isinstance(inputs, list):
+ inputs = tuple(inputs)
+ if isinstance(inputs, Symbol):
+ inputs = (inputs, )
+ if not isinstance(inputs, tuple):
+ raise ValueError("%s must be a Symbol, or a tuple or list of
Symbol" % (name, ))
+ for item in inputs:
+ if not isinstance(item, Symbol):
+ raise ValueError("%s must be a Symbol, or a tuple or list of
Symbol" % (name, ))
+ return inputs
+
+ def _create_subgraph(graph_vars, graph_func, subgraph_name):
Review comment:
It seems this function and the one below are the same as the ones in
`while_loop`. Can you move them out and reuse them?
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
With regards,
Apache Git Services