zheng-da commented on a change in pull request #11760: [MXNET-684] Add `ifelse`
operator
URL: https://github.com/apache/incubator-mxnet/pull/11760#discussion_r203920239
##########
File path: python/mxnet/ndarray/contrib.py
##########
@@ -363,3 +362,97 @@ def _func_wrapper(loop_vars):
[" Step %d, shape is %s" % (i, str(x.shape)) for i, x in
enumerate(items)]
))
return stacked_outputs, list(loop_vars)
+
+def ifelse(cond, then_func, else_func, inputs):
+ """Run a if-then-else using user-defined condition and computation
+
+ This operator simulates an if-like branch which chooses to do one of
+ the two customized computations according to the specified condition.
+
+ `inputs` is a list of NDArrays on which the condition and computations
rely on.
+
+ `cond` is a user-defined function, used as the if condition.
+ It consumes `inputs`, and produces a scalar MXNet NDArray,
+ indicating which branch of computation should be used.
+ The `cond` is variadic, and its signature should be
+ `cond(*loop_vars) => NDArray`.
+
+ `then_func` is a user-defined function, used as computation of the then
branch.
+ It consumes `inputs`, and produces `outputs`.
+ The `then_func` is variadic, and its signature should be
+ `then_func(*loop_vars) => List[NDArray]`.
+
+ `else_func` is a user-defined function, used as computation of the else
branch.
+ It also consumes `inputs`, and produces `outputs`.
+ The `else_func` is variadic, and its signature should be
+ `else_func(*loop_vars) => List[NDArray]`.
+
+ The `outputs` produced by `then_func` and `else_func` should have the same
number
+ of elements, all of which should be in the same shape, of the same dtype
and stype.
+
+ This function returns a list of NDArrays, representing the computation
result.
+
+ Parameters
+ ----------
+ cond: a Python function.
+ The branch condition.
+ then_func: a Python function.
+ The computation to be executed if `cond` is true.
+ else_func: a Python function.
+ The computation to be executed if `cond` is false.
+ inputs: list of NDArrays.
+ The variables fed to `cond`, `then_func` and `else_func`.
+
+ Returns
+ -------
+ outputs: a list of NDArrays, representing the result of computation.
+
+ Examples
+ --------
+ >>> cond = lambda a, b: a * b < 5
+ >>> then_func = lambda a, b: (a + 5) * (b + 5)
+ >>> else_func = lambda a, b: (a - 5) * (b - 5)
+ >>> inputs = (mx.nd.array([1]), mx.nd.array([2]))
+ >>> outputs = mx.nd.contrib.ifelse(cond, then_func, else_func, inputs)
+ >>> outputs[0]
+ [42.]
+ <NDArray 1 @cpu(0)>
+ """
+ def _to_python_scalar(inputs, type_, name):
+ """Converts "inputs", possibly typed mxnet NDArray, a numpy ndarray,
other python types,
+ to the given type
+ """
+ if hasattr(inputs, "asscalar"):
+ inputs = inputs.asscalar()
+ try:
+ inputs = type_(inputs)
+ except:
+ raise ValueError("Cannot convert %s to python %s" % (name,
type_.__name__))
+ return inputs
+
+ def _to_ndarray_tuple(inputs, name):
+ """Converts "inputs", possibly a single mxnet NDArray, a list of mxnet
NDArray,
+ a tuple of mxnet NDArray, into a tuple of NDArray
+ """
+ if isinstance(inputs, list):
+ inputs = tuple(inputs)
+ if isinstance(inputs, ndarray.NDArray):
+ inputs = (inputs, )
+ if not isinstance(inputs, tuple):
+ raise ValueError("%s must be an NDArray, or a tuple or list of
NDArrays" % (name, ))
+ for item in inputs:
+ if not isinstance(item, ndarray.NDArray):
+ raise ValueError("%s must be an NDArray, or a tuple or list of
NDArrays" % (name, ))
+ return inputs
+
+ inputs = _to_ndarray_tuple(inputs, "inputs")
+ if len(inputs) == 0:
+ raise ValueError("inputs should contain at least one element")
+ branch = _to_python_scalar(cond(*inputs), bool, "Return value of cond")
+ if branch:
+ outputs = then_func(*inputs)
+ outputs = _to_ndarray_tuple(outputs, "outputs of then_func")
+ else:
+ outputs = else_func(*inputs)
+ outputs = _to_ndarray_tuple(outputs, "outputs of else_func")
Review comment:
Is there a way of checking whether the then branch and the else branch
produce the same number of outputs, with the same types, etc.?
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
With regards,
Apache Git Services