junrushao1994 commented on a change in pull request #11760: [MXNET-684] Add `ifelse` operator URL: https://github.com/apache/incubator-mxnet/pull/11760#discussion_r204124259
########## File path: python/mxnet/symbol/contrib.py ########## @@ -556,3 +556,154 @@ def _union_inputs(*graphs): outputs = [result[i] for i in range(num_out_data)] final_loop_vars = [result[i] for i in range(num_out_data, num_outputs)] return outputs, final_loop_vars + +def ifelse(cond, then_func, else_func, inputs, name="ifelse"): + """Run an if-then-else using user-defined condition and computation + + This operator simulates an if-like branch which chooses to do one of + the two customized computations according to the specified condition. + + `inputs` is a list of Symbols on which the condition and computations rely. + + `cond` is a user-defined function, used as the if condition. + It consumes `inputs`, and produces a scalar MXNet symbol, + indicating which branch of computation should be used. + The `cond` is variadic, and its signature should be + `cond(*loop_vars) => Symbol`. + + `then_func` is a user-defined function, used as computation of the then branch. + It consumes `inputs`, and produces `outputs`. + The `then_func` is variadic, and its signature should be + `then_func(*loop_vars) => List[Symbol]`. + + `else_func` is a user-defined function, used as computation of the else branch. + It also consumes `inputs`, and produces `outputs`. + The `else_func` is variadic, and its signature should be + `else_func(*loop_vars) => List[Symbol]`. + + The `outputs` produced by `then_func` and `else_func` should have the same number + of elements, all of which should be in the same shape, of the same dtype and stype. + + This function returns a list of symbols, representing the computation result. + + Parameters + ---------- + cond: a Python function. + The branch condition. + then_func: a Python function. + The computation to be executed if `cond` is true. + else_func: a Python function. + The computation to be executed if `cond` is false. + inputs: list of Symbols. + The variables fed to `cond`, `then_func` and `else_func`. 
+ + Returns + ------- + outputs: a list of Symbols, representing the result of computation. + + Examples + -------- + >>> cond = lambda a, b: a * b < 5 + >>> then_func = lambda a, b: (a + 5) * (b + 5) + >>> else_func = lambda a, b: (a - 5) * (b - 5) + >>> inputs = (mx.sym.var('a'), mx.sym.var('b')) + >>> outputs = mx.sym.contrib.ifelse(cond, then_func, else_func, inputs) + """ + def _to_symbol_tuple(inputs, name): + """Converts "inputs", possibly a single mxnet Symbol, a list of mxnet Symbol, + a tuple of mxnet Symbol, into a tuple of Symbol + """ + if isinstance(inputs, list): + inputs = tuple(inputs) + if isinstance(inputs, Symbol): + inputs = (inputs, ) + if not isinstance(inputs, tuple): + raise ValueError("%s must be a Symbol, or a tuple or list of Symbol" % (name, )) + for item in inputs: + if not isinstance(item, Symbol): + raise ValueError("%s must be a Symbol, or a tuple or list of Symbol" % (name, )) + return inputs + + def _create_subgraph(graph_vars, graph_func, subgraph_name): Review comment: They are not exactly the same. One would search for var_locs; the other doesn't. ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: us...@infra.apache.org With regards, Apache Git Services