junrushao1994 commented on a change in pull request #11760: [MXNET-684] Add `ifelse` operator
URL: https://github.com/apache/incubator-mxnet/pull/11760#discussion_r204124259
 
 

 ##########
 File path: python/mxnet/symbol/contrib.py
 ##########
 @@ -556,3 +556,154 @@ def _union_inputs(*graphs):
     outputs = [result[i] for i in range(num_out_data)]
     final_loop_vars = [result[i] for i in range(num_out_data, num_outputs)]
     return outputs, final_loop_vars
+
+def ifelse(cond, then_func, else_func, inputs, name="ifelse"):
+    """Run a if-then-else using user-defined condition and computation
+
+    This operator simulates a if-like branch which chooses to do one of
+    the two customized computations according to the specified condition.
+
+    `inputs` is a list of Symbols on which the condition and computations 
reply on.
+
+    `cond` is a user-defined function, used as the if condition.
+    It consumes `inputs`, and produces a scalar MXNet symbol,
+    indicating which branch of computation should be used.
+    The `cond` is variadic, and its signature should be
+    `cond(*loop_vars) => Symbol`.
+
+    `then_func` is a user-defined function, used as computation of the then 
branch.
+    It consumes `inputs`, and produces `outputs`.
+    The `then_func` is variadic, and its signature should be
+    `then_func(*loop_vars) => List[Symbol]`.
+
+    `else_func` is a user-defined function, used as computation of the else 
branch.
+    It also consumes `inputs`, and produces `outputs`.
+    The `else_func` is variadic, and its signature should be
+    `else_func(*loop_vars) => List[Symbol]`.
+
+    The `outputs` produces by `then_func` and `else_func` should have the same 
number
+    of elements, all of which should be in the same shape, of the same dtype 
and stype.
+
+    This function returns a list of symbols, representing the computation 
result.
+
+    Parameters
+    ----------
+    cond: a Python function.
+        The branch condition.
+    then_func: a Python function.
+        The computation to be executed if `cond` is true.
+    else_func: a Python function.
+        The computation to be executed if `cond` is false.
+    inputs: list of Symbols.
+        The variables fed to `cond`, `then_func` and `else_func`.
+
+    Returns
+    -------
+    outputs: a list of Symbols, representing the result of computation.
+
+    Examples
+    --------
+    >>> cond = lambda a, b: a * b < 5
+    >>> then_func = lambda a, b: (a + 5) * (b + 5)
+    >>> else_func = lambda a, b: (a - 5) * (b - 5)
+    >>> inputs = (mx.sym.var('a'), mx.sym.var('b'))
+    >>> outputs = mx.sym.contrib.ifelse(cond, then_func, else_func, inputs)
+    """
+    def _to_symbol_tuple(inputs, name):
+        """Converts "inputs", possibly a single mxnet Symbol, a list of mxnet 
Symbol,
+        a tuple of mxnet Symbol, into a tuple of Symbol
+        """
+        if isinstance(inputs, list):
+            inputs = tuple(inputs)
+        if isinstance(inputs, Symbol):
+            inputs = (inputs, )
+        if not isinstance(inputs, tuple):
+            raise ValueError("%s must be a Symbol, or a tuple or list of 
Symbol" % (name, ))
+        for item in inputs:
+            if not isinstance(item, Symbol):
+                raise ValueError("%s must be a Symbol, or a tuple or list of 
Symbol" % (name, ))
+        return inputs
+
+    def _create_subgraph(graph_vars, graph_func, subgraph_name):
 
 Review comment:
   They are not exactly the same: one searches for `var_locs`, and the other doesn't.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

Reply via email to