This is an automated email from the ASF dual-hosted git repository.

jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 4bb141d  [MXNET-684] Add `cond` operator (#11760)
4bb141d is described below

commit 4bb141d6826b60e6b3dca25c007c36f6d4585c33
Author: Junru Shao <junrushao1...@users.noreply.github.com>
AuthorDate: Mon Jul 23 21:18:17 2018 -0700

    [MXNET-684] Add `cond` operator (#11760)
    
    * Initial commit for `Ifelse`
    
    * Address comments
    
    * Rename ifelse to condition
    
    * API change
    
    * Trigger CI
    
    * Rename condition to cond
    
    * Fix lint
---
 benchmark/python/control_flow/foreach_rnn.py       | 195 --------
 benchmark/python/control_flow/while_loop_rnn.py    | 213 --------
 docs/api/python/ndarray/contrib.md                 |   1 +
 docs/api/python/symbol/contrib.md                  |   1 +
 python/mxnet/ndarray/contrib.py                    |  89 +++-
 python/mxnet/symbol/contrib.py                     | 146 +++++-
 src/operator/control_flow.cc                       | 538 +++++++++++++++++----
 src/operator/subgraph_op_common.cc                 |  28 ++
 src/operator/subgraph_op_common.h                  |  62 +++
 tests/python/unittest/test_contrib_control_flow.py | 159 +++++-
 10 files changed, 916 insertions(+), 516 deletions(-)

diff --git a/benchmark/python/control_flow/foreach_rnn.py b/benchmark/python/control_flow/foreach_rnn.py
deleted file mode 100644
index 4ce7a42..0000000
--- a/benchmark/python/control_flow/foreach_rnn.py
+++ /dev/null
@@ -1,195 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-import subprocess
-import mxnet as mx
-from mxnet import gluon
-import time
-import copy
-
-def get_gpus():
-    """
-    return a list of GPUs
-    """
-    try:
-        re = subprocess.check_output(["nvidia-smi", "-L"], 
universal_newlines=True)
-    except OSError:
-        return []
-    return range(len([i for i in re.split('\n') if 'GPU' in i]))
-
-class TestRNNLayer(gluon.HybridBlock):
-    def __init__(self, cell, prefix=None, params=None):
-        super(TestRNNLayer, self).__init__(prefix=prefix, params=params)
-        self.cell = cell
-
-    def hybrid_forward(self, F, inputs, states):
-        out, states = F.contrib.foreach(self.cell, inputs, states)
-        return out
-
-def benchmark_rnn(cell, rnn_data, states):
-    ctx = rnn_data.context
-    num_batches = 20
-
-    # Imperative
-    cell0 = copy.deepcopy(cell)
-    layer0 = TestRNNLayer(cell0)
-    layer0.initialize(ctx=ctx)
-
-    # Hybridize
-    cell1 = copy.deepcopy(cell)
-    cell1.hybridize()
-    layer1 = TestRNNLayer(cell1)
-    layer1.initialize(ctx=ctx)
-
-    # Hybridize
-    cell2 = copy.deepcopy(cell)
-    layer2 = TestRNNLayer(cell2)
-    layer2.initialize(ctx=ctx)
-    layer2.hybridize()
-    layer2(rnn_data, states)
-
-    # Hybridize
-    cell3 = copy.deepcopy(cell)
-    cell3.hybridize(static_alloc=True)
-    layer3 = TestRNNLayer(cell3)
-    layer3.initialize(ctx=ctx)
-
-    tic = time.time()
-    for i in range(num_batches):
-        res0 = layer0(rnn_data, states)
-        mx.nd.waitall()
-    print("Imperative inference takes " + str(time.time() - tic))
-
-    tic = time.time()
-    for i in range(num_batches):
-        res1 = layer1(rnn_data, states)
-        mx.nd.waitall()
-    print("Hybrid-cell inference takes " + str(time.time() - tic))
-
-    tic = time.time()
-    for i in range(num_batches):
-        res3 = layer3(rnn_data, states)
-        mx.nd.waitall()
-    print("Static-hybrid-cell inference takes " + str(time.time() - tic))
-
-    tic = time.time()
-    for i in range(num_batches):
-        res2 = layer2(rnn_data, states)
-        mx.nd.waitall()
-    print("Hybrid inference takes " + str(time.time() - tic))
-
-    layer2.export("foreach_rnn")
-    symnet = mx.symbol.load('foreach_rnn-symbol.json')
-    args1 = {}
-    params = layer2.collect_params()
-    for key in params.keys():
-        args1[key] = params[key].data()
-    args1['data0'] = rnn_data
-    for i in range(len(states)):
-        args1['data' + str(i + 1)] = states[i]
-    exe = symnet.bind(ctx=ctx, args=args1)
-    tic = time.time()
-    for i in range(num_batches):
-        exe.forward(is_train=False)
-        mx.nd.waitall()
-    print("Symbol inference takes " + str(time.time() - tic))
-
-    tic = time.time()
-    for i in range(num_batches):
-        with mx.autograd.record():
-            res0 = layer0(rnn_data, states)
-        res0.backward()
-        mx.nd.waitall()
-    print("Imperative training takes " + str(time.time() - tic))
-
-    tic = time.time()
-    for i in range(num_batches):
-        with mx.autograd.record():
-            res1 = layer1(rnn_data, states)
-        res1.backward()
-        mx.nd.waitall()
-    print("Hybrid-cell training takes " + str(time.time() - tic))
-
-    tic = time.time()
-    for i in range(num_batches):
-        with mx.autograd.record():
-            res3 = layer3(rnn_data, states)
-        res3.backward()
-        mx.nd.waitall()
-    print("Static-hybrid-cell training takes " + str(time.time() - tic))
-
-    tic = time.time()
-    for i in range(num_batches):
-        with mx.autograd.record():
-            res2 = layer2(rnn_data, states)
-        res2.backward()
-        mx.nd.waitall()
-    print("Hybrid training takes " + str(time.time() - tic))
-
-    # gradients for the backward of the foreach symbol
-    args_grad1 = {}
-    for key in args1.keys():
-        args_grad1[key] = mx.nd.empty(args1[key].shape, ctx=ctx)
-    exe = symnet.bind(ctx=ctx, args=args1, args_grad=args_grad1)
-    tic = time.time()
-    for i in range(num_batches):
-        exe.forward(is_train=True)
-        exe.backward(res2)
-        mx.nd.waitall()
-    print("Symbol training takes " + str(time.time() - tic))
-    print("")
-
-if __name__ == '__main__':
-    ndim = 512
-    seq_len = 100
-    batch_sizes = [1, 32]
-    cells = [gluon.rnn.RNNCell(ndim, prefix='rnn_'),
-             gluon.rnn.GRUCell(ndim, prefix='rnn_'),
-             gluon.rnn.LSTMCell(ndim, prefix='rnn_')]
-    ctxs = [mx.cpu(0), mx.gpu(0)]
-    for cell in cells:
-        for ctx in ctxs:
-            for batch_size in batch_sizes:
-                if len(get_gpus()) == 0 and ctx == mx.gpu(0):
-                    continue
-                if isinstance(cell, gluon.rnn.RNNCell):
-                    rnn_data = mx.nd.normal(loc=0, scale=1, shape=(seq_len, 
batch_size, ndim),
-                                            ctx=mx.cpu(0))
-                    states = []
-                    states.append(mx.nd.normal(loc=0, scale=1, 
shape=(batch_size, ndim),
-                                               ctx=mx.cpu(0)))
-                elif isinstance(cell, gluon.rnn.GRUCell):
-                    rnn_data = mx.nd.normal(loc=0, scale=1, shape=(seq_len, 
batch_size, ndim),
-                                            ctx=mx.cpu(0))
-                    states = []
-                    states.append(mx.nd.normal(loc=0, scale=1, 
shape=(batch_size, ndim),
-                                               ctx=mx.cpu(0)))
-                elif isinstance(cell, gluon.rnn.LSTMCell):
-                    rnn_data = mx.nd.normal(loc=0, scale=1, shape=(seq_len, 
batch_size, ndim),
-                                            ctx=mx.cpu(0))
-                    states = []
-                    states.append(mx.nd.normal(loc=0, scale=1, 
shape=(batch_size, ndim),
-                                               ctx=mx.cpu(0)))
-                    states.append(mx.nd.normal(loc=0, scale=1, 
shape=(batch_size, ndim),
-                                               ctx=mx.cpu(0)))
-                if ctx == mx.gpu(0):
-                    dev = "GPU"
-                else:
-                    dev = "CPU"
-                print("Benchmark {} in {} (batch size: 
{})".format(cell._alias(), dev,
-                                                                   batch_size))
-                benchmark_rnn(cell, rnn_data, states)
diff --git a/benchmark/python/control_flow/while_loop_rnn.py b/benchmark/python/control_flow/while_loop_rnn.py
deleted file mode 100644
index 42aaee5..0000000
--- a/benchmark/python/control_flow/while_loop_rnn.py
+++ /dev/null
@@ -1,213 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-# Code borrowed from ./benchmark/python/control_flow/foreach_rnn.py
-
-import subprocess
-import mxnet as mx
-from mxnet import gluon
-import time
-import copy
-
-def get_gpus():
-    """
-    return a list of GPUs
-    """
-    try:
-        re = subprocess.check_output(["nvidia-smi", "-L"], 
universal_newlines=True)
-    except OSError:
-        return []
-    return range(len([i for i in re.split('\n') if 'GPU' in i]))
-
-class TestRNNLayer(gluon.HybridBlock):
-    def __init__(self, cell, length, prefix=None, params=None):
-        super(TestRNNLayer, self).__init__(prefix=prefix, params=params)
-        self.length = length
-        self.cell = cell
-
-    def hybrid_forward(self, F, inputs, states):
-        def _func(*states):
-            i = states[0]
-            s = states[1: ]
-            data = inputs.take(i).squeeze(axis=0)
-            out, new_s = self.cell(data, s)
-            new_s = [i + 1] + new_s
-            return out, new_s
-        out, states = F.contrib.while_loop(
-            cond=lambda i, *_: i < self.length,
-            func=_func,
-            loop_vars=states,
-            max_iterations=self.length,
-        )
-        return out + states
-
-def benchmark_rnn(cell, rnn_data, states, length):
-    ctx = rnn_data.context
-    num_batches = 20
-
-    # Imperative
-    cell0 = copy.deepcopy(cell)
-    layer0 = TestRNNLayer(cell0, length)
-    layer0.initialize(ctx=ctx)
-
-    # Hybrid-cell
-    cell1 = copy.deepcopy(cell)
-    cell1.hybridize()
-    layer1 = TestRNNLayer(cell1, length)
-    layer1.initialize(ctx=ctx)
-
-    # Hybrid
-    cell2 = copy.deepcopy(cell)
-    layer2 = TestRNNLayer(cell2, length)
-    layer2.initialize(ctx=ctx)
-    layer2.hybridize()
-    layer2(rnn_data, states)
-
-    # Static-hybrid-cell
-    cell3 = copy.deepcopy(cell)
-    cell3.hybridize(static_alloc=True)
-    layer3 = TestRNNLayer(cell3, length)
-    layer3.initialize(ctx=ctx)
-
-    tic = time.time()
-    for i in range(num_batches):
-        res0 = layer0(rnn_data, states)
-        mx.nd.waitall()
-    print("Imperative inference takes " + str(time.time() - tic))
-
-    tic = time.time()
-    for i in range(num_batches):
-        res1 = layer1(rnn_data, states)
-        mx.nd.waitall()
-    print("Hybrid-cell inference takes " + str(time.time() - tic))
-
-    tic = time.time()
-    for i in range(num_batches):
-        res3 = layer3(rnn_data, states)
-        mx.nd.waitall()
-    print("Static-hybrid-cell inference takes " + str(time.time() - tic))
-
-    tic = time.time()
-    for i in range(num_batches):
-        res2 = layer2(rnn_data, states)
-        mx.nd.waitall()
-    print("Hybrid inference takes " + str(time.time() - tic))
-
-    layer2.export("while_loop_rnn")
-    symnet = mx.symbol.load('while_loop_rnn-symbol.json')
-    args1 = {}
-    params = layer2.collect_params()
-    for key in params.keys():
-        args1[key] = params[key].data()
-    args1['data0'] = rnn_data
-    for i in range(len(states)):
-        args1['data' + str(i + 1)] = states[i]
-    exe = symnet.bind(ctx=ctx, args=args1)
-    tic = time.time()
-    for i in range(num_batches):
-        exe.forward(is_train=False)
-        mx.nd.waitall()
-    print("Symbol inference takes " + str(time.time() - tic))
-
-    tic = time.time()
-    for i in range(num_batches):
-        with mx.autograd.record():
-            res0 = layer0(rnn_data, states)
-        res0[0].backward()
-        mx.nd.waitall()
-    print("Imperative training takes " + str(time.time() - tic))
-
-    tic = time.time()
-    for i in range(num_batches):
-        with mx.autograd.record():
-            res1 = layer1(rnn_data, states)
-        res1[0].backward()
-        mx.nd.waitall()
-    print("Hybrid-cell training takes " + str(time.time() - tic))
-
-    tic = time.time()
-    for i in range(num_batches):
-        with mx.autograd.record():
-            res3 = layer3(rnn_data, states)
-        res3[0].backward()
-        mx.nd.waitall()
-    print("Static-hybrid-cell training takes " + str(time.time() - tic))
-
-    tic = time.time()
-    for i in range(num_batches):
-        with mx.autograd.record():
-            res2 = layer2(rnn_data, states)
-        res2[0].backward()
-        mx.nd.waitall()
-    print("Hybrid training takes " + str(time.time() - tic))
-
-    # gradients for the backward of the while_loop symbol
-    args_grad1 = {}
-    for key in args1.keys():
-        if key != "data1":
-            args_grad1[key] = mx.nd.empty(args1[key].shape, ctx=ctx)
-    exe = symnet.bind(ctx=ctx, args=args1, args_grad=args_grad1)
-    tic = time.time()
-    for i in range(num_batches):
-        exe.forward(is_train=True)
-        exe.backward(res2)
-        mx.nd.waitall()
-    print("Symbol training takes " + str(time.time() - tic))
-    print("")
-
-if __name__ == '__main__':
-    def _zeros(shape):
-        return mx.nd.zeros(shape=shape, ctx=mx.cpu(0))
-    def _array(shape):
-        return mx.nd.normal(loc=0.0, scale=1.0, shape=shape, ctx=mx.cpu(0))
-    ndim = 512
-    seq_len = 100
-    batch_sizes = [1, 32]
-    cells = [gluon.rnn.RNNCell(ndim, prefix='rnn_'),
-             gluon.rnn.GRUCell(ndim, prefix='rnn_'),
-             gluon.rnn.LSTMCell(ndim, prefix='rnn_')]
-    ctxs = [mx.cpu(0), mx.gpu(0)]
-    for cell in cells:
-        for ctx in ctxs:
-            for batch_size in batch_sizes:
-                if len(get_gpus()) == 0 and ctx == mx.gpu(0):
-                    continue
-                if isinstance(cell, gluon.rnn.RNNCell):
-                    rnn_data = _array((seq_len, batch_size, ndim))
-                    states = [
-                        _zeros((1, )),
-                        _array((batch_size, ndim)),
-                    ]
-                if isinstance(cell, gluon.rnn.GRUCell):
-                    rnn_data = _array((seq_len, batch_size, ndim))
-                    states = [
-                        _zeros((1, )),
-                        _array((batch_size, ndim)),
-                    ]
-                elif isinstance(cell, gluon.rnn.LSTMCell):
-                    rnn_data = _array((seq_len, batch_size, ndim))
-                    states = [
-                        _zeros((1, )),
-                        _array((batch_size, ndim)),
-                        _array((batch_size, ndim)),
-                    ]
-                if ctx == mx.gpu(0):
-                    dev = "GPU"
-                else:
-                    dev = "CPU"
-                print("Benchmark {} in {} (batch size: 
{})".format(cell._alias(), dev, batch_size))
-                benchmark_rnn(cell, rnn_data, states, seq_len)
diff --git a/docs/api/python/ndarray/contrib.md b/docs/api/python/ndarray/contrib.md
index 0cf8724..97f38a5 100644
--- a/docs/api/python/ndarray/contrib.md
+++ b/docs/api/python/ndarray/contrib.md
@@ -54,6 +54,7 @@ In the rest of this document, we list routines provided by the `ndarray.contrib`
     quantize
     foreach
     while_loop
+    cond
 ```
 
 ## API Reference
diff --git a/docs/api/python/symbol/contrib.md b/docs/api/python/symbol/contrib.md
index ba43f2d..c0a4da5 100644
--- a/docs/api/python/symbol/contrib.md
+++ b/docs/api/python/symbol/contrib.md
@@ -54,6 +54,7 @@ In the rest of this document, we list routines provided by the `symbol.contrib`
     quantize
     foreach
     while_loop
+    cond
 ```
 
 ## API Reference
diff --git a/python/mxnet/ndarray/contrib.py b/python/mxnet/ndarray/contrib.py
index b67cf5a..aae898a 100644
--- a/python/mxnet/ndarray/contrib.py
+++ b/python/mxnet/ndarray/contrib.py
@@ -16,7 +16,7 @@
 # under the License.
 
 # coding: utf-8
-# pylint: disable=wildcard-import, unused-wildcard-import
+# pylint: disable=wildcard-import, unused-wildcard-import,redefined-outer-name
 """Contrib NDArray API of MXNet."""
 import math
 from ..context import current_context
@@ -28,7 +28,7 @@ try:
 except ImportError:
     pass
 
-__all__ = ["rand_zipfian", "foreach", "while_loop"]
+__all__ = ["rand_zipfian", "foreach", "while_loop", "cond"]
 
 # pylint: disable=line-too-long
 def rand_zipfian(true_classes, num_sampled, range_max, ctx=None):
@@ -192,7 +192,6 @@ def foreach(body, data, init_states):
         outputs = outputs[0]
     return (outputs, states)
 
-
 def while_loop(cond, func, loop_vars, max_iterations=None):
     """Run a while loop with user-defined computation and loop condition.
 
@@ -363,3 +362,87 @@ def while_loop(cond, func, loop_vars, max_iterations=None):
                 ["  Step %d, shape is %s" % (i, str(x.shape)) for i, x in 
enumerate(items)]
             ))
     return stacked_outputs, list(loop_vars)
+
+def cond(pred, then_func, else_func):
+    """Run an if-then-else using user-defined condition and computation
+
+    This operator simulates an if-like branch which chooses to do one of
+    the two customized computations according to the specified condition.
+
+    `pred` is a scalar MXNet NDArray,
+    indicating which branch of computation should be used.
+
+    `then_func` is a user-defined function, used as the computation of the then branch.
+    It produces `outputs`, which is a list of NDArrays.
+    The signature of `then_func` should be
+    `then_func() => List[NDArray]`.
+
+    `else_func` is a user-defined function, used as the computation of the else branch.
+    It produces `outputs`, which is a list of NDArrays.
+    The signature of `else_func` should be
+    `else_func() => List[NDArray]`.
+
+    The `outputs` produced by `then_func` and `else_func` should have the same number
+    of elements, all of which should have the same shape, dtype, and stype.
+
+    This function returns a list of NDArrays, representing the computation result.
+
+    Parameters
+    ----------
+    pred: an MXNet NDArray representing a scalar.
+        The branch condition.
+    then_func: a Python function.
+        The computation to be executed if `pred` is true.
+    else_func: a Python function.
+        The computation to be executed if `pred` is false.
+
+    Returns
+    -------
+    outputs: a list of NDArrays, representing the result of computation.
+
+    Examples
+    --------
+    >>> a, b = mx.nd.array([1]), mx.nd.array([2])
+    >>> pred = a * b < 5
+    >>> then_func = lambda: (a + 5) * (b + 5)
+    >>> else_func = lambda: (a - 5) * (b - 5)
+    >>> outputs = mx.nd.contrib.cond(pred, then_func, else_func)
+    >>> outputs[0]
+    [42.]
+    <NDArray 1 @cpu(0)>
+    """
+    def _to_python_scalar(inputs, type_, name):
+        """Converts "inputs", possibly an mxnet NDArray, a numpy ndarray, or another python type,
+        to the given type
+        """
+        if hasattr(inputs, "asscalar"):
+            inputs = inputs.asscalar()
+        try:
+            inputs = type_(inputs)
+        except:
+            raise ValueError("Cannot convert %s to python %s" % (name, type_.__name__))
+        return inputs
+
+    def _to_ndarray_tuple(inputs, name):
+        """Converts "inputs", possibly a single mxnet NDArray, a list of mxnet NDArray,
+        a tuple of mxnet NDArray, into a tuple of NDArray
+        """
+        if isinstance(inputs, list):
+            inputs = tuple(inputs)
+        if isinstance(inputs, ndarray.NDArray):
+            inputs = (inputs, )
+        if not isinstance(inputs, tuple):
+            raise ValueError("%s must be an NDArray, or a tuple or list of NDArrays" % (name, ))
+        for item in inputs:
+            if not isinstance(item, ndarray.NDArray):
+                raise ValueError("%s must be an NDArray, or a tuple or list of NDArrays" % (name, ))
+        return inputs
+
+    branch = _to_python_scalar(pred, bool, "pred")
+    if branch:
+        outputs = then_func()
+        outputs = _to_ndarray_tuple(outputs, "outputs of then_func")
+    else:
+        outputs = else_func()
+        outputs = _to_ndarray_tuple(outputs, "outputs of else_func")
+    return list(outputs)
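
A minimal usage sketch of the imperative `cond` added above (illustration only, not part of the patch; it assumes a build that includes this commit). The branch functions take no arguments and close over the operands:

    import mxnet as mx

    a, b = mx.nd.array([1]), mx.nd.array([2])
    pred = a * b < 5                       # scalar NDArray used as the branch condition
    # Both branches must return the same number of NDArrays
    # with matching shape, dtype and storage type.
    then_func = lambda: (a + 5) * (b + 5)
    else_func = lambda: (a - 5) * (b - 5)
    outputs = mx.nd.contrib.cond(pred, then_func, else_func)
    print(outputs[0].asscalar())           # 42.0, since 1 * 2 < 5 selects the then branch
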
diff --git a/python/mxnet/symbol/contrib.py b/python/mxnet/symbol/contrib.py
index 2c11921..8842883 100644
--- a/python/mxnet/symbol/contrib.py
+++ b/python/mxnet/symbol/contrib.py
@@ -16,7 +16,7 @@
 # under the License.
 
 # coding: utf-8
-# pylint: disable=wildcard-import, unused-wildcard-import
+# pylint: disable=wildcard-import, unused-wildcard-import,redefined-outer-name
 """Contrib Symbol API of MXNet."""
 import math
 import ctypes
@@ -34,7 +34,7 @@ from ..base import _LIB, check_call
 from ..base import SymbolHandle, _as_list
 from ..attribute import AttrScope
 
-__all__ = ["rand_zipfian", "foreach", "while_loop"]
+__all__ = ["rand_zipfian", "foreach", "while_loop", "cond"]
 
 def rand_zipfian(true_classes, num_sampled, range_max):
     """Draw random samples from an approximately log-uniform or Zipfian 
distribution.
@@ -556,3 +556,145 @@ def while_loop(cond, func, loop_vars, max_iterations=None, name="while_loop"):
     outputs = [result[i] for i in range(num_out_data)]
     final_loop_vars = [result[i] for i in range(num_out_data, num_outputs)]
     return outputs, final_loop_vars
+
+def cond(pred, then_func, else_func, name="cond"):
+    """Run an if-then-else using user-defined condition and computation
+
+    This operator simulates an if-like branch which chooses to do one of
+    the two customized computations according to the specified condition.
+
+    `pred` is a scalar MXNet Symbol,
+    indicating which branch of computation should be used.
+
+    `then_func` is a user-defined function, used as the computation of the then branch.
+    It produces `outputs`, which is a list of Symbols.
+    The signature of `then_func` should be
+    `then_func() => List[Symbol]`.
+
+    `else_func` is a user-defined function, used as the computation of the else branch.
+    It produces `outputs`, which is a list of Symbols.
+    The signature of `else_func` should be
+    `else_func() => List[Symbol]`.
+
+    The `outputs` produced by `then_func` and `else_func` should have the same number
+    of elements, all of which should have the same shape, dtype, and stype.
+
+    This function returns a list of symbols, representing the computation result.
+
+    Parameters
+    ----------
+    pred: an MXNet Symbol representing a scalar.
+        The branch condition.
+    then_func: a Python function.
+        The computation to be executed if `pred` is true.
+    else_func: a Python function.
+        The computation to be executed if `pred` is false.
+
+    Returns
+    -------
+    outputs: a list of Symbols, representing the result of computation.
+
+    Examples
+    --------
+    >>> a, b = mx.sym.var('a'), mx.sym.var('b')
+    >>> pred = a * b < 5
+    >>> then_func = lambda: (a + 5) * (b + 5)
+    >>> else_func = lambda: (a - 5) * (b - 5)
+    >>> outputs = mx.sym.contrib.cond(pred, then_func, else_func)
+    """
+    def _to_symbol_tuple(inputs, name):
+        """Converts "inputs", possibly a single mxnet Symbol, a list of mxnet Symbol,
+        a tuple of mxnet Symbol, into a tuple of Symbol
+        """
+        if isinstance(inputs, list):
+            inputs = tuple(inputs)
+        if isinstance(inputs, Symbol):
+            inputs = (inputs, )
+        if not isinstance(inputs, tuple):
+            raise ValueError("%s must be a Symbol, or a tuple or list of Symbol" % (name, ))
+        for item in inputs:
+            if not isinstance(item, Symbol):
+                raise ValueError("%s must be a Symbol, or a tuple or list of Symbol" % (name, ))
+        return inputs
+
+    def _create_subgraph(graph_vars, graph_func, subgraph_name):
+        with AttrScope(__subgraph_name__=subgraph_name):
+            # create new variables with the same name,
+            # then feed them to the given func
+            new_graph_vars = [symbol.var(sym.name) for sym in graph_vars]
+            outputs = graph_func(*new_graph_vars)
+            outputs = _to_symbol_tuple(outputs, "outputs")
+            num_outputs = len(outputs)
+            # nnvm cut-graph does not allow inputs and outputs overlap
+            # so we calculate the names of the inputs, and copy an output if it overlaps with the inputs
+            all_input_names = symbol.Group(outputs).list_inputs()
+            make_identity = lambda x: symbol.op.identity(x) if x.name in all_input_names else x
+            # group all outputs of graph_func
+            graph = symbol.Group(list(map(make_identity, outputs)))
+        return graph, num_outputs
+
+    def _union_inputs(*graphs):
+        # Given a list of graphs, each of whose inputs are either from input_vars or other variables.
+        # 1) calculate a list `inputs`, the union of their inputs.
+        # 2) for each graph, determine at which indices its inputs reside in `inputs`
+        # 3) for each variable in the input of `graph`, find which index it is
+        inputs = []             # List[Symbol], result of 1)
+        locs = []               # List[Tuple(List[Int], List[Int])], a list of tuples,
+                                # where tuples are results of 2) and 3)
+        input_id_to_loc = {}    # Dict[int, int], given id(sym), input_id_to_loc maps it
+                                # to a `loc`, where inputs[loc] = sym
+        for graph in graphs:
+            # input_syms: all inputs to the `graph`
+            name_to_input_syms = {sym.name: sym for sym in _get_graph_inputs(graph)}
+            # some input_vars are inputs to `graph`, some are not
+            name_to_input_vars = {sym.name: sym for sym in inputs}
+            # other inputs to `graph` created by cut_graph
+            name_to_cut_g_syms = {sym.list_outputs()[0]: sym for sym in _cut_subgraph(graph)}
+            # collect arguments for each subgraph
+            input_locs = []                         # results from the second step
+            for name in graph.list_inputs():
+                assert name in name_to_input_syms   # it should obviously hold
+                # name -> sym
+                if name in name_to_input_vars:
+                    sym = name_to_input_vars[name]
+                elif name in name_to_cut_g_syms:
+                    sym = name_to_cut_g_syms[name]
+                else:
+                    sym = copy.deepcopy(name_to_input_syms[name])
+                # do 2), and 1) is implicitly done
+                if id(sym) in input_id_to_loc:
+                    loc = input_id_to_loc[id(sym)]
+                else:
+                    loc = len(input_id_to_loc)
+                    inputs.append(sym)
+                    input_id_to_loc[id(sym)] = loc
+                input_locs.append(loc)
+            locs.append(input_locs)
+        return inputs, locs
+    inputs = []
+    # create graph for `pred'
+    cond_g, cond_num_outputs = _create_subgraph(inputs, lambda: pred, name + "_pred")
+    if cond_num_outputs != 1:
+        raise ValueError("pred should always be a single output")
+    # create graph for `then`
+    then_g, then_num_outputs = _create_subgraph(inputs, then_func, name + "_then")
+    # create graph for `else`
+    else_g, else_num_outputs = _create_subgraph(inputs, else_func, name + "_else")
+    if then_num_outputs != else_num_outputs:
+        raise ValueError("Number of outputs differs between then-branch and else-branch")
+    # find symbols used in cond_g, then_g or else_g
+    input_syms, (cond_input_locs, then_input_locs, else_input_locs) = \
+        _union_inputs(cond_g, then_g, else_g)
+    result = symbol._internal._cond(
+        # [cond, then_g, else_g, *input_syms]
+        cond_g,
+        then_g,
+        else_g,
+        *input_syms,
+        cond_input_locs=cond_input_locs,
+        then_input_locs=then_input_locs,
+        else_input_locs=else_input_locs,
+        num_outputs=then_num_outputs
+    )
+    result = _to_symbol_tuple(result, "result")
+    return list(result)
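
A minimal sketch of the symbolic `cond` added above, bound and evaluated with the standard executor API (`bind`, `forward`), following the pattern of the removed benchmark scripts; not part of the patch:

    import mxnet as mx

    a, b = mx.sym.var('a'), mx.sym.var('b')
    pred = a * b < 5
    out = mx.sym.contrib.cond(pred,
                              then_func=lambda: (a + 5) * (b + 5),
                              else_func=lambda: (a - 5) * (b - 5))
    # `cond` returns a list of Symbols; group them to bind an executor.
    net = mx.sym.Group(out)
    exe = net.bind(ctx=mx.cpu(),
                   args={'a': mx.nd.array([1.0]), 'b': mx.nd.array([2.0])})
    print(exe.forward(is_train=False)[0].asnumpy())   # [42.], the then branch is taken
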
diff --git a/src/operator/control_flow.cc b/src/operator/control_flow.cc
index b00ed9b..7c1becc 100644
--- a/src/operator/control_flow.cc
+++ b/src/operator/control_flow.cc
@@ -508,6 +508,18 @@ struct WhileLoopParam : public dmlc::Parameter<WhileLoopParam> {
     DMLC_DECLARE_FIELD(func_var_locs)
     .describe("The locations of loop_vars among func's inputs.");
   }
+  template <typename T>
+  bool sync_in_out(std::vector<T> *in,
+                   std::vector<T> *out,
+                   std::function<bool(const T &)> is_empty) const {
+    for (int i = this->num_out_data; i < this->num_outputs; ++i) {
+      // each out->at(i) is a params, loop_var
+      T &x = in->at(this->func_input_locs[this->func_var_locs[i - this->num_out_data]]);
+      T &y = out->at(i);
+      fill_value(&x, &y, is_empty(x), is_empty(y));
+    }
+    return true;
+  }
 };  // struct WhileLoopParam
 
 DMLC_REGISTER_PARAMETER(WhileLoopParam);
@@ -540,84 +552,8 @@ class WhileLoopState: public LoopState {
       }
     }
   }
-  template <typename T>
-  static void extract_by_loc(const std::vector<T> &array,
-                             const nnvm::Tuple<dim_t> input_locs,
-                             std::vector<T> *out) {
-    out->clear();
-    out->reserve(input_locs.ndim());
-    for (dim_t i : input_locs) {
-      out->push_back(array[i]);
-    }
-  }
-  static bool is_shape_udf(const TShape &x) {
-    return x.ndim() == 0 || x.Size() == 0;
-  }
-  static bool is_stype_udf(const int &x) {
-    return x == exec::kBadStorageID;
-  }
-  static bool is_type_udf(const int &x) {
-    return x == -1;
-  }
-  template <typename T>
-  static bool fill_value(T *x, T *y, bool x_empty, bool y_empty) {
-    if (*x == *y || (x_empty && y_empty)) {
-      return true;
-    }
-    if (!x_empty && !y_empty) {
-      return false;
-    }
-    if (x_empty) {
-      *x = *y;
-    }
-    if (y_empty) {
-      *y = *x;
-    }
-    return true;
-  }
-  template <typename T>
-  static bool sync_in_in(const nnvm::Tuple<dim_t> &input_locs,
-                         std::vector<T> *in,
-                         std::vector<T> *subg_in,
-                         std::function<bool(const T &)> is_empty) {
-    for (size_t i = 0; i < input_locs.ndim(); ++i) {
-      T &x = in->at(input_locs[i]);
-      T &y = subg_in->at(i);
-      fill_value(&x, &y, is_empty(x), is_empty(y));
-    }
-    return true;
-  }
-  template <typename T>
-  static bool sync_in_out(const WhileLoopParam& params,
-                          std::vector<T> *in,
-                          std::vector<T> *out,
-                          std::function<bool(const T &)> is_empty) {
-    for (int i = params.num_out_data; i < params.num_outputs; ++i) {
-      // each out->at(i) is a params, loop_var
-      T &x = in->at(params.func_input_locs[params.func_var_locs[i - 
params.num_out_data]]);
-      T &y = out->at(i);
-      fill_value(&x, &y, is_empty(x), is_empty(y));
-    }
-    return true;
-  }
 };
 
-template <typename T>
-T _asscalar(const NDArray &a) {
-  CHECK_EQ(a.shape().Size(), 1U);
-  T data;
-  a.SyncCopyToCPU(&data, 1U);
-  return data;
-}
-
-bool as_bool_scalar(const NDArray &a) {
-  MSHADOW_TYPE_SWITCH(a.dtype(), DType, {
-    return static_cast<bool>(_asscalar<DType>(a));
-  });
-  LOG(FATAL) << "Unknown dtype";
-  return false;
-}
-
 static void WhileLoopComputeExCPU(const OpStatePtr& state_ptr,
                                   const OpContext& ctx,
                                   const std::vector<NDArray>& inputs,
@@ -648,13 +584,13 @@ static void WhileLoopComputeExCPU(const OpStatePtr& state_ptr,
     CHECK_EQ(params.max_iterations, outputs[i].shape()[0]);
   // construct inputs and outputs for cond
   std::vector<NDArray> cond_inputs, cond_outputs = {NDArray()};
-  WhileLoopState::extract_by_loc(inputs, params.cond_input_locs, &cond_inputs);
+  extract_by_loc(inputs, params.cond_input_locs, &cond_inputs);
   std::vector<NDArray*> cond_input_ptr, cond_output_ptr;
   to_ptr_vec(cond_inputs, &cond_input_ptr);
   to_ptr_vec(cond_outputs, &cond_output_ptr);
   // construct inputs and outputs for func
   std::vector<NDArray> func_inputs, func_outputs(outputs.size());
-  WhileLoopState::extract_by_loc(inputs, params.func_input_locs, &func_inputs);
+  extract_by_loc(inputs, params.func_input_locs, &func_inputs);
   for (size_t &step = state.n_iterations = 0; step < (size_t) 
params.max_iterations; ++step) {
     state.cond_op->Forward(nullptr, cond_input_ptr, cond_output_ptr);
     if (!as_bool_scalar(*cond_output_ptr[0])) {
@@ -716,8 +652,8 @@ static void WhileLoopGradComputeExCPU(const OpStatePtr& state_ptr,
   }
   std::vector<NDArray> outputs;
   std::vector<OpReqType> req;
-  WhileLoopState::extract_by_loc(_outputs, params.func_input_locs, &outputs);
-  WhileLoopState::extract_by_loc(_req, params.func_input_locs, &req);
+  extract_by_loc(_outputs, params.func_input_locs, &outputs);
+  extract_by_loc(_req, params.func_input_locs, &req);
   if (state.n_iterations == 0) {
     for (int i = params.num_out_data; i < params.num_outputs; ++i) {
       int j = params.func_var_locs[i - params.num_out_data];
@@ -796,7 +732,7 @@ static bool WhileLoopShape(const nnvm::NodeAttrs& attrs,
                            std::vector<TShape> *out_shape) {
   using nnvm::ShapeVector;
   const WhileLoopParam& params = nnvm::get<WhileLoopParam>(attrs.parsed);
-  static const std::function<bool(const TShape &)> is_udf = 
WhileLoopState::is_shape_udf;
+  static const std::function<bool(const TShape &)> is_udf = is_shape_udf;
   // sanity checks
   CHECK_EQ(in_shape->size() + 2U, (size_t) params.num_args);
   CHECK_EQ(out_shape->size(), (size_t) params.num_outputs);
@@ -811,7 +747,7 @@ static bool WhileLoopShape(const nnvm::NodeAttrs& attrs,
     // create subg_in
     ShapeVector subg_in;
     ShapeVector &subg_out = *_subg_out;
-    WhileLoopState::extract_by_loc(*in_shape, input_locs, &subg_in);
+    extract_by_loc(*in_shape, input_locs, &subg_in);
     // create an indexed graph
     nnvm::Graph g;
     g.outputs = subg->outputs;
@@ -884,35 +820,35 @@ static bool WhileLoopShape(const nnvm::NodeAttrs& attrs,
   };
   ShapeVector cond_out_shape{TShape(1U)};  // this means: [(1, )]
   ShapeVector func_out_shape(params.num_outputs);
-  CHECK(WhileLoopState::sync_in_out(params, in_shape, out_shape, is_udf));
+  CHECK(params.sync_in_out(in_shape, out_shape, is_udf));
   bool succ_0 = infer_subg(attrs.subgraphs[0], &cond_out_shape, 
params.cond_input_locs, 0, false);
-  CHECK(WhileLoopState::sync_in_out(params, in_shape, out_shape, is_udf));
+  CHECK(params.sync_in_out(in_shape, out_shape, is_udf));
   bool succ_1 = infer_subg(attrs.subgraphs[1], &func_out_shape, \
                            params.func_input_locs, params.num_out_data, true);
-  CHECK(WhileLoopState::sync_in_out(params, in_shape, out_shape, is_udf));
+  CHECK(params.sync_in_out(in_shape, out_shape, is_udf));
   return succ_0 && succ_1;
 }
 
 static bool WhileLoopType(const nnvm::NodeAttrs& attrs,
                           std::vector<int> *in_type, std::vector<int> 
*out_type) {
   const WhileLoopParam& params = nnvm::get<WhileLoopParam>(attrs.parsed);
-  static const std::function<bool(const int &)> is_udf = 
WhileLoopState::is_type_udf;
+  static const std::function<bool(const int &)> is_udf = is_type_udf;
   CHECK_EQ(in_type->size() + 2U, (size_t) params.num_args);
   CHECK_EQ(out_type->size(), (size_t) params.num_outputs);
   CHECK_EQ(attrs.subgraphs.size(), 2U);
   CHECK_EQ(attrs.subgraphs[0]->outputs.size(), 1U);
   std::vector<int> cond_in_type;
   std::vector<int> func_in_type;
-  WhileLoopState::extract_by_loc(*in_type, params.cond_input_locs, 
&cond_in_type);
-  WhileLoopState::extract_by_loc(*in_type, params.func_input_locs, 
&func_in_type);
+  extract_by_loc(*in_type, params.cond_input_locs, &cond_in_type);
+  extract_by_loc(*in_type, params.func_input_locs, &func_in_type);
   std::vector<int> cond_out_type = {0};
-  CHECK(WhileLoopState::sync_in_out(params, in_type, out_type, is_udf));
+  CHECK(params.sync_in_out(in_type, out_type, is_udf));
   bool succ_0 = InferSubgraphDataType(*attrs.subgraphs[0], &cond_in_type, 
&cond_out_type);
-  CHECK(WhileLoopState::sync_in_out(params, in_type, out_type, is_udf));
-  CHECK(WhileLoopState::sync_in_in(params.cond_input_locs, in_type, 
&cond_in_type, is_udf));
+  CHECK(params.sync_in_out(in_type, out_type, is_udf));
+  CHECK(sync_in_in(params.cond_input_locs, in_type, &cond_in_type, is_udf));
   bool succ_1 = InferSubgraphDataType(*attrs.subgraphs[1], &func_in_type, 
out_type);
-  CHECK(WhileLoopState::sync_in_out(params, in_type, out_type, is_udf));
-  CHECK(WhileLoopState::sync_in_in(params.func_input_locs, in_type, 
&func_in_type, is_udf));
+  CHECK(params.sync_in_out(in_type, out_type, is_udf));
+  CHECK(sync_in_in(params.func_input_locs, in_type, &func_in_type, is_udf));
   return succ_0 && succ_1;
 }
 
@@ -922,28 +858,28 @@ static bool WhileLoopStorageType(const nnvm::NodeAttrs& attrs,
                                  std::vector<int> *in_attrs,
                                  std::vector<int> *out_attrs) {
   const WhileLoopParam& params = nnvm::get<WhileLoopParam>(attrs.parsed);
-  static const std::function<bool(const int &)> is_udf = 
WhileLoopState::is_stype_udf;
+  static const std::function<bool(const int &)> is_udf = is_stype_udf;
   CHECK_EQ(in_attrs->size() + 2U, (size_t) params.num_args);
   CHECK_EQ(out_attrs->size(), (size_t) params.num_outputs);
   CHECK_EQ(attrs.subgraphs.size(), 2U);
   CHECK_EQ(attrs.subgraphs[0]->outputs.size(), 1U);
   std::vector<int> cond_in_attrs;
   std::vector<int> func_in_attrs;
-  WhileLoopState::extract_by_loc(*in_attrs, params.cond_input_locs, 
&cond_in_attrs);
-  WhileLoopState::extract_by_loc(*in_attrs, params.func_input_locs, 
&func_in_attrs);
+  extract_by_loc(*in_attrs, params.cond_input_locs, &cond_in_attrs);
+  extract_by_loc(*in_attrs, params.func_input_locs, &func_in_attrs);
   std::vector<int> cond_out_attrs = {kDefaultStorage};
   DispatchMode cond_mode = DispatchMode::kUndefined;
   DispatchMode func_mode = DispatchMode::kUndefined;
   *dispatch_mode = DispatchMode::kFComputeEx;
-  CHECK(WhileLoopState::sync_in_out(params, in_attrs, out_attrs, is_udf));
+  CHECK(params.sync_in_out(in_attrs, out_attrs, is_udf));
   bool succ_0 = InferSubgraphStorage(*attrs.subgraphs[0], dev_mask, \
                                      &cond_mode, &cond_in_attrs, 
&cond_out_attrs);
-  CHECK(WhileLoopState::sync_in_out(params, in_attrs, out_attrs, is_udf));
-  CHECK(WhileLoopState::sync_in_in(params.cond_input_locs, in_attrs, 
&cond_in_attrs, is_udf));
+  CHECK(params.sync_in_out(in_attrs, out_attrs, is_udf));
+  CHECK(sync_in_in(params.cond_input_locs, in_attrs, &cond_in_attrs, is_udf));
   bool succ_1 = InferSubgraphStorage(*attrs.subgraphs[1], dev_mask, \
                                      &func_mode, &func_in_attrs, out_attrs);
-  CHECK(WhileLoopState::sync_in_out(params, in_attrs, out_attrs, is_udf));
-  CHECK(WhileLoopState::sync_in_in(params.func_input_locs, in_attrs, 
&func_in_attrs, is_udf));
+  CHECK(params.sync_in_out(in_attrs, out_attrs, is_udf));
+  CHECK(sync_in_in(params.func_input_locs, in_attrs, &func_in_attrs, is_udf));
   return succ_0 && succ_1;
 }
 
@@ -977,6 +913,343 @@ WhileLoopGradient(const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& og
   return entries;
 }
 
+struct CondParam : public dmlc::Parameter<CondParam> {
+  int num_args;
+  int num_outputs;
+  nnvm::Tuple<dim_t> cond_input_locs;
+  nnvm::Tuple<dim_t> then_input_locs;
+  nnvm::Tuple<dim_t> else_input_locs;
+  DMLC_DECLARE_PARAMETER(CondParam) {
+    DMLC_DECLARE_FIELD(num_args).set_lower_bound(3)
+    .describe("Number of input arguments, including cond, then and else as three symbol inputs.");
+    DMLC_DECLARE_FIELD(num_outputs).set_lower_bound(1)
+    .describe("The number of outputs of the subgraph.");
+    DMLC_DECLARE_FIELD(cond_input_locs)
+    .describe("The locations of cond's inputs in the given inputs.");
+    DMLC_DECLARE_FIELD(then_input_locs)
+    .describe("The locations of then's inputs in the given inputs.");
+    DMLC_DECLARE_FIELD(else_input_locs)
+    .describe("The locations of else's inputs in the given inputs.");
+  }
+};  // struct CondParam
+
+DMLC_REGISTER_PARAMETER(CondParam);
+
+class CondState {
+ public:
+  CondParam params;
+  CachedOpPtr cond_op;
+  LoopState then_branch;
+  LoopState else_branch;
+  int branch_selection;  // 1 if then branch; 0 if else branch; -1 if undefined
+
+  CondState(const CondParam &params,
+            const Symbol &cond,
+            const Symbol &then_sym,
+            const Symbol &else_sym):
+            params(params),
+            cond_op(LoopState::MakeSharedOp(cond)),
+            then_branch(then_sym),
+            else_branch(else_sym),
+            branch_selection(-1) {
+  }
+};
+
+static void CondComputeExCPU(const OpStatePtr& state_ptr,
+                             const OpContext& ctx,
+                             const std::vector<NDArray>& inputs,
+                             const std::vector<OpReqType>& req,
+                             const std::vector<NDArray>& outputs) {
+  // The argument `inputs' holds the union of the inputs of cond, then and else;
+  // `cond_input_locs', `then_input_locs' and `else_input_locs' give the locations
+  // of each subgraph's inputs within it.
+  // The argument `outputs' holds the outputs of the branch that is executed.
+  CondState &state = state_ptr.get_state<CondState>();
+  const CondParam& params = state.params;
+  // a helper function, converting std::vector<NDArray> to std::vector<NDArray*>
+  const auto to_ptr_vec = [](std::vector<NDArray> &in, std::vector<NDArray*> *out) {
+    out->clear();
+    out->reserve(in.size());
+    std::transform(std::begin(in),
+                   std::end(in),
+                   std::back_inserter(*out),
+                   [](NDArray &a) {return &a;});
+  };
+  // sanity checks
+  CHECK_EQ(inputs.size() + 3U, (size_t) params.num_args);
+  CHECK_EQ(outputs.size(), (size_t) params.num_outputs);
+  CHECK_EQ(outputs.size(), req.size());
+  // construct inputs and outputs for cond
+  std::vector<NDArray> cond_inputs;
+  std::vector<NDArray> cond_outputs = {NDArray()};
+  std::vector<NDArray*> cond_input_ptr;
+  std::vector<NDArray*> cond_output_ptr;
+  extract_by_loc(inputs, params.cond_input_locs, &cond_inputs);
+  to_ptr_vec(cond_inputs, &cond_input_ptr);
+  to_ptr_vec(cond_outputs, &cond_output_ptr);
+  int &branch_selection = state.branch_selection;
+  // run cond
+  state.cond_op->Forward(nullptr, cond_input_ptr, cond_output_ptr);
+  branch_selection = as_bool_scalar(*cond_output_ptr[0]);
+  // select the right branch
+  const nnvm::Tuple<dim_t> &func_input_locs = branch_selection
+                                            ? params.then_input_locs
+                                            : params.else_input_locs;
+  LoopState &loop_state = branch_selection
+                        ? state.then_branch
+                        : state.else_branch;
+  // extract inputs for the branch
+  std::vector<NDArray> func_inputs;
+  extract_by_loc(inputs, func_input_locs, &func_inputs);
+  loop_state.Forward(0, func_inputs, req, outputs, ctx.need_grad);
+}
+
+static void CondGradComputeExCPU(const OpStatePtr& state_ptr,
+                                 const OpContext& ctx,
+                                 const std::vector<NDArray>& inputs,
+                                 const std::vector<OpReqType>& _req,
+                                 const std::vector<NDArray>& outputs) {
+  CondState &state = state_ptr.get_state<CondState>();
+  const CondParam& params = state.params;
+  // sanity checks
+  CHECK_EQ(outputs.size() + 3U, (size_t) params.num_args);
+  CHECK_EQ(outputs.size(), _req.size());
+  // select the right branch
+  int branch_selection = state.branch_selection;
+  CHECK_NE(branch_selection, -1);
+  const nnvm::Tuple<dim_t> &func_input_locs = branch_selection
+                                            ? params.then_input_locs
+                                            : params.else_input_locs;
+  LoopState &loop_state = branch_selection
+                        ? state.then_branch
+                        : state.else_branch;
+  // construct parameters
+  std::vector<NDArray> ograds(inputs.begin(), inputs.begin() + params.num_outputs);
+  std::vector<OpReqType> req;
+  extract_by_loc(_req, func_input_locs, &req);
+  std::vector<NDArray> igrads;
+  extract_by_loc(outputs, func_input_locs, &igrads);
+  loop_state.Backward(0, ograds, req, igrads);
+  loop_state.Cleanup();
+}
+
+static bool CondShape(const nnvm::NodeAttrs& attrs,
+                      std::vector<TShape> *in_shape,
+                      std::vector<TShape> *out_shape) {
+  using nnvm::ShapeVector;
+  const CondParam& params = nnvm::get<CondParam>(attrs.parsed);
+  static const std::function<bool(const TShape &)> is_udf = is_shape_udf;
+  // sanity checks
+  CHECK_EQ(in_shape->size() + 3U, (size_t) params.num_args);
+  CHECK_EQ(out_shape->size(), (size_t) params.num_outputs);
+  CHECK_EQ(attrs.subgraphs.size(), 3U);
+  CHECK_EQ(attrs.subgraphs[0]->outputs.size(), 1U);
+  CHECK_EQ(attrs.subgraphs[1]->outputs.size(), attrs.subgraphs[2]->outputs.size());
+  // infer shape for cond, then and else
+  auto infer_subg = [&params, in_shape, out_shape](std::shared_ptr<Symbol> subg,
+                                                   ShapeVector *_subg_out,
+                                                   const nnvm::Tuple<dim_t> &input_locs,
+                                                   bool fill_out_shape) {
+    // create subg_in
+    ShapeVector subg_in;
+    ShapeVector &subg_out = *_subg_out;
+    extract_by_loc(*in_shape, input_locs, &subg_in);
+    // create an indexed graph
+    nnvm::Graph g;
+    g.outputs = subg->outputs;
+    const auto& idx = g.indexed_graph();
+    // get input nodes
+    const auto &input_nids = idx.input_nodes();
+    // sanity checks
+    CHECK_EQ(input_nids.size(), subg_in.size());
+    CHECK_EQ(g.outputs.size(), subg_out.size());
+    CHECK_EQ(idx.input_nodes().size(), subg_in.size());
+    CHECK_EQ(idx.outputs().size(), subg_out.size());
+    // create empty shapes for inference
+    ShapeVector shapes(idx.num_node_entries());
+    // copy subg_in into shapes
+    for (size_t i = 0; i < subg_in.size(); ++i) {
+      auto eid = idx.entry_id(input_nids[i], 0);
+      shapes[eid] = subg_in[i];
+    }
+    // copy subg_out into shapes
+    for (size_t i = 0; i < subg_out.size(); ++i) {
+      auto eid = idx.entry_id(g.outputs[i]);
+      shapes[eid] = subg_out[i];
+    }
+    // copy done, call InferShape
+    g.attrs["shape"] = std::make_shared<dmlc::any>(std::move(shapes));
+    g = exec::InferShape(std::move(g));
+    // now `shapes' won't be used anymore, use new_shapes instead
+    const auto& new_shapes = g.GetAttr<ShapeVector>("shape");
+    // copy subg_in back to in_shape
+    for (size_t i = 0; i < subg_in.size(); ++i) {
+      auto eid = idx.entry_id(input_nids[i], 0);
+      auto g_out_shape = new_shapes[eid];
+      if (g_out_shape.ndim() == 0 || g_out_shape.Size() == 0) {
+        // when the shape is not fully inferred
+        continue;
+      }
+      SHAPE_ASSIGN_CHECK(*in_shape, input_locs[i], g_out_shape);
+    }
+    if (!fill_out_shape) {
+      return true;
+    }
+    // copy subg_out back to out_shape
+    for (size_t i = 0; i < g.outputs.size(); ++i) {
+      auto eid = idx.entry_id(g.outputs[i]);
+      auto g_out_shape = new_shapes[eid];
+      if (g_out_shape.ndim() == 0 || g_out_shape.Size() == 0) {
+        // when the shape is not fully inferred
+        continue;
+      }
+      SHAPE_ASSIGN_CHECK(*out_shape, i, g_out_shape);
+    }
+    return g.GetAttr<size_t>("shape_num_unknown_nodes") == 0;
+  };
+  ShapeVector cond_out_shape{TShape(1U)};  // this means: [(1, )]
+  ShapeVector then_out_shape(params.num_outputs);
+  ShapeVector else_out_shape(params.num_outputs);
+  bool succ_0 = infer_subg(attrs.subgraphs[0], &cond_out_shape, \
+                           params.cond_input_locs, false);
+  bool succ_1 = infer_subg(attrs.subgraphs[1], &then_out_shape, \
+                           params.then_input_locs, true);
+  bool succ_2 = infer_subg(attrs.subgraphs[2], &else_out_shape, \
+                           params.else_input_locs, true);
+  sync_out_out(&then_out_shape, &else_out_shape, is_udf);
+  return succ_0 && succ_1 && succ_2;
+}
+
+static bool CondType(const nnvm::NodeAttrs& attrs,
+                     std::vector<int> *in_type,
+                     std::vector<int> *out_type) {
+  const CondParam& params = nnvm::get<CondParam>(attrs.parsed);
+  static const std::function<bool(const int &)> is_udf = is_type_udf;
+  CHECK_EQ(in_type->size() + 3U, (size_t) params.num_args);
+  CHECK_EQ(out_type->size(), (size_t) params.num_outputs);
+  CHECK_EQ(attrs.subgraphs.size(), 3U);
+  CHECK_EQ(attrs.subgraphs[0]->outputs.size(), 1U);
+  CHECK_EQ(attrs.subgraphs[1]->outputs.size(), attrs.subgraphs[2]->outputs.size());
+  std::vector<int> cond_in_type;
+  std::vector<int> then_in_type;
+  std::vector<int> else_in_type;
+  extract_by_loc(*in_type, params.cond_input_locs, &cond_in_type);
+  extract_by_loc(*in_type, params.then_input_locs, &then_in_type);
+  extract_by_loc(*in_type, params.else_input_locs, &else_in_type);
+  std::vector<int> cond_out_type = {0};
+  bool succ_0 = InferSubgraphDataType(*attrs.subgraphs[0], &cond_in_type, &cond_out_type);
+  CHECK(sync_in_in(params.cond_input_locs, in_type, &cond_in_type, is_udf));
+  bool succ_1 = InferSubgraphDataType(*attrs.subgraphs[1], &then_in_type, out_type);
+  CHECK(sync_in_in(params.then_input_locs, in_type, &then_in_type, is_udf));
+  bool succ_2 = InferSubgraphDataType(*attrs.subgraphs[2], &else_in_type, out_type);
+  CHECK(sync_in_in(params.else_input_locs, in_type, &else_in_type, is_udf));
+  return succ_0 && succ_1 && succ_2;
+}
+
+static bool CondStorageType(const nnvm::NodeAttrs& attrs,
+                            const int dev_mask,
+                            DispatchMode* dispatch_mode,
+                            std::vector<int> *in_attrs,
+                            std::vector<int> *out_attrs) {
+  const CondParam& params = nnvm::get<CondParam>(attrs.parsed);
+  static const std::function<bool(const int &)> is_udf = is_stype_udf;
+  CHECK_EQ(in_attrs->size() + 3U, (size_t) params.num_args);
+  CHECK_EQ(out_attrs->size(), (size_t) params.num_outputs);
+  CHECK_EQ(attrs.subgraphs.size(), 3U);
+  CHECK_EQ(attrs.subgraphs[0]->outputs.size(), 1U);
+  CHECK_EQ(attrs.subgraphs[1]->outputs.size(), attrs.subgraphs[2]->outputs.size());
+  std::vector<int> cond_in_attrs;
+  std::vector<int> then_in_attrs;
+  std::vector<int> else_in_attrs;
+  extract_by_loc(*in_attrs, params.cond_input_locs, &cond_in_attrs);
+  extract_by_loc(*in_attrs, params.then_input_locs, &then_in_attrs);
+  extract_by_loc(*in_attrs, params.else_input_locs, &else_in_attrs);
+  std::vector<int> cond_out_attrs = {kDefaultStorage};
+  DispatchMode cond_mode = DispatchMode::kUndefined;
+  DispatchMode then_mode = DispatchMode::kUndefined;
+  DispatchMode else_mode = DispatchMode::kUndefined;
+  *dispatch_mode = DispatchMode::kFComputeEx;
+  bool succ_0 = InferSubgraphStorage(*attrs.subgraphs[0], dev_mask, \
+                                     &cond_mode, &cond_in_attrs, &cond_out_attrs);
+  CHECK(sync_in_in(params.cond_input_locs, in_attrs, &cond_in_attrs, is_udf));
+  bool succ_1 = InferSubgraphStorage(*attrs.subgraphs[1], dev_mask, \
+                                     &then_mode, &then_in_attrs, out_attrs);
+  CHECK(sync_in_in(params.then_input_locs, in_attrs, &then_in_attrs, is_udf));
+  bool succ_2 = InferSubgraphStorage(*attrs.subgraphs[2], dev_mask, \
+                                     &else_mode, &else_in_attrs, out_attrs);
+  CHECK(sync_in_in(params.else_input_locs, in_attrs, &else_in_attrs, is_udf));
+  return succ_0 && succ_1 && succ_2;
+}
+
+static bool BackwardCondStorageType(const nnvm::NodeAttrs& attrs,
+                                    const int dev_mask,
+                                    DispatchMode* dispatch_mode,
+                                    std::vector<int> *in_attrs,
+                                    std::vector<int> *out_attrs) {
+  const CondParam& params = nnvm::get<CondParam>(attrs.parsed);
+  CHECK_EQ(out_attrs->size() + 3U, (size_t) params.num_args);
+  CHECK_EQ(attrs.subgraphs.size(), 3U);
+  static const std::function<bool(const int &)> is_udf = is_stype_udf;
+  auto sub_pass = [&](const std::shared_ptr<Symbol> &subg, const nnvm::Tuple<dim_t> &input_locs) {
+    // A. first construct subg_in_attrs
+    // need subg_in_attrs as subg_bwd_out (copy), subg_fwd_in (extract), subg_fwd_out (copy)
+    std::vector<int> subg_in_attrs;
+    size_t num_elts = params.num_outputs * 2 + input_locs.ndim();
+    subg_in_attrs.reserve(num_elts);
+    // part 1. subg_bwd_out (copy)
+    subg_in_attrs.insert(subg_in_attrs.end(),
+                         in_attrs->begin(),
+                         in_attrs->begin() + params.num_outputs);
+    // part 2. subg_fwd_in (extract)
+    std::vector<int> fwd_in(in_attrs->begin() + params.num_outputs,
+                            in_attrs->begin() + params.num_outputs + params.num_args - 3);
+    std::vector<int> subg_fwd_in;
+    extract_by_loc(fwd_in, input_locs, &subg_fwd_in);
+    subg_in_attrs.insert(subg_in_attrs.end(),
+                         subg_fwd_in.begin(),
+                         subg_fwd_in.end());
+    // part 3. subg_fwd_out (copy)
+    subg_in_attrs.insert(subg_in_attrs.end(),
+                         in_attrs->begin() + params.num_outputs + params.num_args - 3,
+                         in_attrs->end());
+    // check correctness of the number of elements
+    CHECK_EQ(subg_in_attrs.size(), num_elts);
+    // B. then we construct subg_out_attrs by extracting from out_attrs
+    std::vector<int> subg_out_attrs;
+    extract_by_loc(*out_attrs, input_locs, &subg_out_attrs);
+    // then we construct the subgraph and do inference
+    CachedOp op(*subg, {});
+    bool ret = op.BackwardStorageType(attrs, dev_mask, dispatch_mode, \
+                                      &subg_in_attrs, &subg_out_attrs);
+    CHECK(sync_in_in(input_locs, out_attrs, &subg_out_attrs, is_udf));
+    return ret;
+  };
+  bool succ_0 = sub_pass(attrs.subgraphs[1], params.then_input_locs);
+  bool succ_1 = sub_pass(attrs.subgraphs[2], params.else_input_locs);
+  return succ_0 && succ_1;
+}
+
+static OpStatePtr CreateCondState(const NodeAttrs& attrs,
+                                  Context ctx,
+                                  const std::vector<TShape>& ishape,
+                                  const std::vector<int>& itype) {
+  const CondParam& params = nnvm::get<CondParam>(attrs.parsed);
+  return OpStatePtr::Create<CondState>(
+    params,
+    *attrs.subgraphs[0],
+    *attrs.subgraphs[1],
+    *attrs.subgraphs[2]);
+}
+
+static std::vector<nnvm::NodeEntry>
+CondGradient(const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
+  ElemwiseGradUseInOut fgrad{"_backward_cond"};
+  std::vector<nnvm::NodeEntry> entries = fgrad(n, ograds);
+  entries[0].node->attrs.subgraphs = n->attrs.subgraphs;
+  return entries;
+}
+
 NNVM_REGISTER_OP(_foreach)
 .MXNET_DESCRIBE("Run a for loop over an NDArray with user-defined computation")
 .set_attr_parser(ParamParser<ForeachParam>)
@@ -1100,5 +1373,68 @@ NNVM_REGISTER_OP(_backward_while_loop)
 .set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", 
WhileLoopGradComputeExCPU)
 .set_attr<FStatefulComputeEx>("FStatefulComputeEx<gpu>", 
WhileLoopGradComputeExCPU);
 
+NNVM_REGISTER_OP(_cond)
+.MXNET_DESCRIBE("Run a if-then-else using user-defined condition and 
computation")
+.set_attr_parser(ParamParser<CondParam>)
+.set_attr<FInferStorageType>("FInferStorageType", CondStorageType)
+.set_num_inputs([](const NodeAttrs& attrs) {
+  const CondParam& params = nnvm::get<CondParam>(attrs.parsed);
+  return params.num_args;
+})
+.set_num_outputs([](const NodeAttrs& attrs) {
+  const CondParam& params = nnvm::get<CondParam>(attrs.parsed);
+  return params.num_outputs;
+})
+.set_attr<nnvm::FListInputNames>("FListInputNames",
+    [](const NodeAttrs& attrs) {
+  const CondParam& params = nnvm::get<CondParam>(attrs.parsed);
+  std::vector<std::string> names;
+  names.reserve(params.num_args);
+  names.push_back("cond");
+  names.push_back("then_branch");
+  names.push_back("else_branch");
+  for (int i = 3; i < params.num_args; ++i)
+    names.push_back("data" + std::to_string(i - 3));
+  return names;
+})
+.set_attr<nnvm::FInputGraph>("FInputGraph",
+    [](const NodeAttrs& attrs) {
+  return std::vector<uint32_t>{0, 1, 2};
+})
+.set_attr<nnvm::FGradient>("FGradient", CondGradient)
+.set_attr<FCreateOpState>("FCreateOpState", CreateCondState)
+.set_attr<nnvm::FInferShape>("FInferShape", CondShape)
+.set_attr<nnvm::FInferType>("FInferType", CondType)
+.set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", CondComputeExCPU)
+.set_attr<FExecType>("FExecType", [](const NodeAttrs& attrs) {
+  return ExecType::kSubgraphExec;
+})
+.set_attr<FStatefulComputeEx>("FStatefulComputeEx<gpu>", CondComputeExCPU)
+.set_attr<std::string>("key_var_num_args", "num_args")
+.add_argument("cond", "Symbol", "Input graph for the condition.")
+.add_argument("then_branch", "Symbol", "Input graph for the then branch.")
+.add_argument("else_branch", "Symbol", "Input graph for the else branch.")
+.add_argument("data", "NDArray-or-Symbol[]",
+              "The input arrays that include data arrays and states.")
+.add_arguments(CondParam::__FIELDS__());
+
+NNVM_REGISTER_OP(_backward_cond)
+.set_num_inputs([](const NodeAttrs& attrs){
+  const CondParam& params = nnvm::get<CondParam>(attrs.parsed);
+  return params.num_outputs * 2 + params.num_args - 3;
+})
+.set_num_outputs([](const NodeAttrs& attrs){
+  const CondParam& params = nnvm::get<CondParam>(attrs.parsed);
+  return params.num_args - 3;
+})
+.set_attr<FExecType>("FExecType", [](const NodeAttrs& attrs) {
+  return ExecType::kSubgraphExec;
+})
+.set_attr<FInferStorageType>("FInferStorageType", BackwardCondStorageType)
+.set_attr_parser(ParamParser<CondParam>)
+.set_attr<bool>("TIsLayerOpBackward", true)
+.set_attr<nnvm::TIsBackward>("TIsBackward", true)
+.set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", CondGradComputeExCPU)
+.set_attr<FStatefulComputeEx>("FStatefulComputeEx<gpu>", CondGradComputeExCPU);
 }  // namespace op
 }  // namespace mxnet
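
The `_cond` registration above is surfaced to users as `contrib.cond` in both the NDArray and Symbol front ends (see the Python changes in this patch). As a quick orientation, here is a minimal imperative usage sketch; the arrays and values are illustrative only, and the two branch callables are expected to produce the same number of outputs:

    import mxnet as mx

    # Minimal sketch of the imperative front end added by this patch;
    # the data values are illustrative only.
    a = mx.nd.array([1.0, 2.0, 3.0])
    b = mx.nd.array([4.0, 5.0, 6.0])

    out = mx.nd.contrib.cond(
        pred=(a.sum() < b.sum()),   # single-element NDArray, read as a boolean
        then_func=lambda: a + b,    # runs only when pred is true
        else_func=lambda: a - b,    # runs only when pred is false
    )

In the imperative case only the selected branch is executed; the symbolic version captures both branches as subgraphs of a single `_cond` node and picks one at run time.
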
diff --git a/src/operator/subgraph_op_common.cc b/src/operator/subgraph_op_common.cc
index d845aa9..7a99aed 100644
--- a/src/operator/subgraph_op_common.cc
+++ b/src/operator/subgraph_op_common.cc
@@ -161,6 +161,34 @@ bool InferSubgraphShape(const nnvm::Symbol &subgraph,
   return g.GetAttr<size_t>("shape_num_unknown_nodes") == 0;
 }
 
+template <typename T>
+T _asscalar(const NDArray &a) {
+  CHECK_EQ(a.shape().Size(), 1U);
+  T data;
+  a.SyncCopyToCPU(&data, 1U);
+  return data;
+}
+
+bool as_bool_scalar(const NDArray &a) {
+  MSHADOW_TYPE_SWITCH(a.dtype(), DType, {
+    return static_cast<bool>(_asscalar<DType>(a));
+  });
+  LOG(FATAL) << "Unknown dtype";
+  return false;
+}
+
+bool is_shape_udf(const TShape &x) {
+  return x.ndim() == 0 || x.Size() == 0;
+}
+
+bool is_stype_udf(const int &x) {
+  return x == exec::kBadStorageID;
+}
+
+bool is_type_udf(const int &x) {
+  return x == -1;
+}
+
 LoopState::LoopState(const Symbol &g) {
   this->subgraph_sym = g;
   this->subgraph.outputs = g.outputs;
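
`as_bool_scalar` is what turns the output of the `cond` subgraph into the branch decision: it requires a single-element NDArray, syncs it to the CPU, and reads it as a boolean regardless of dtype, while the `is_*_udf` helpers flag shape/type/storage entries that are still undefined during inference. A rough Python analogue of the boolean conversion (the helper name here is hypothetical, for illustration only):

    import mxnet as mx

    def as_bool_scalar(arr):
        # Rough Python analogue of the C++ helper above (illustration only):
        # the array must hold exactly one element, which is copied to the
        # CPU and interpreted as a boolean.
        assert arr.size == 1, "predicate must be a single-element NDArray"
        return bool(arr.asscalar())

    pred = mx.nd.array([0.3]) < 0.5
    print(as_bool_scalar(pred))   # True
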
diff --git a/src/operator/subgraph_op_common.h b/src/operator/subgraph_op_common.h
index f73f09c..ebf727f 100644
--- a/src/operator/subgraph_op_common.h
+++ b/src/operator/subgraph_op_common.h
@@ -57,6 +57,68 @@ bool InferSubgraphStorage(const nnvm::Symbol &subgraph,
                           std::vector<int> *in_attrs,
                           std::vector<int> *out_attrs);
 
+bool as_bool_scalar(const NDArray &a);
+
+bool is_shape_udf(const TShape &x);
+
+bool is_stype_udf(const int &x);
+
+bool is_type_udf(const int &x);
+
+template <typename T>
+void extract_by_loc(const std::vector<T> &array,
+                    const nnvm::Tuple<dim_t> input_locs,
+                    std::vector<T> *out) {
+  out->clear();
+  out->reserve(input_locs.ndim());
+  for (dim_t i : input_locs) {
+    out->push_back(array[i]);
+  }
+}
+
+template <typename T>
+bool fill_value(T *x, T *y, bool x_empty, bool y_empty) {
+  if (*x == *y || (x_empty && y_empty)) {
+    return true;
+  }
+  if (!x_empty && !y_empty) {
+    return false;
+  }
+  if (x_empty) {
+    *x = *y;
+  }
+  if (y_empty) {
+    *y = *x;
+  }
+  return true;
+}
+
+template <typename T>
+bool sync_in_in(const nnvm::Tuple<dim_t> &input_locs,
+                         std::vector<T> *in,
+                         std::vector<T> *subg_in,
+                         std::function<bool(const T &)> is_empty) {
+  for (size_t i = 0; i < input_locs.ndim(); ++i) {
+    T &x = in->at(input_locs[i]);
+    T &y = subg_in->at(i);
+    fill_value(&x, &y, is_empty(x), is_empty(y));
+  }
+  return true;
+}
+
+template <typename T>
+bool sync_out_out(std::vector<T> *out_1,
+                  std::vector<T> *out_2,
+                  std::function<bool(const T &)> is_empty) {
+  CHECK_EQ(out_1->size(), out_2->size());
+  for (size_t i = 0; i < out_1->size(); ++i) {
+    T &x = out_1->at(i);
+    T &y = out_2->at(i);
+    fill_value(&x, &y, is_empty(x), is_empty(y));
+  }
+  return true;
+}
+
 /*
  * This contains the states for running a loop and provides methods
  * of running the subgraph computation for an iteration.
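
The `fill_value` / `sync_in_in` / `sync_out_out` templates reconcile attributes (shapes, dtypes, storage types) between the outer graph and a subgraph in both directions during inference: whichever side is still undefined is filled from the other, and two defined values are only accepted when they agree. A conceptual Python model of that behaviour, using `None` for the "undefined" markers checked by `is_shape_udf` / `is_type_udf` / `is_stype_udf` (an illustration, not the C++ code itself):

    def fill_value(x, y):
        # None stands for an undefined attribute; a defined value on either
        # side is copied to the undefined side, and two defined values must
        # already agree for the reconciliation to succeed.
        if x is None:
            return y, y, True
        if y is None:
            return x, x, True
        return x, y, x == y

    def sync_in_in(input_locs, outer_attrs, subgraph_attrs):
        # Reconcile the subgraph's input attributes with the outer graph's
        # attributes at the positions listed in input_locs.
        for i, loc in enumerate(input_locs):
            outer_attrs[loc], subgraph_attrs[i], _ = fill_value(outer_attrs[loc], subgraph_attrs[i])
        return True

    outer = [(2, 3), None, (2, 3)]   # e.g. shapes; None means "unknown so far"
    sub = [None, (4, 5)]             # subgraph inputs mapped to outer slots 0 and 1
    sync_in_in([0, 1], outer, sub)
    print(outer, sub)                # [(2, 3), (4, 5), (2, 3)] [(2, 3), (4, 5)]
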
diff --git a/tests/python/unittest/test_contrib_control_flow.py b/tests/python/unittest/test_contrib_control_flow.py
index 83eebec..1c4e491 100644
--- a/tests/python/unittest/test_contrib_control_flow.py
+++ b/tests/python/unittest/test_contrib_control_flow.py
@@ -975,6 +975,160 @@ def test_while_loop_rnn():
                 y = y.asnumpy()
                 assert_almost_equal(x, y, rtol=1e-4, atol=1e-4)
 
+def _verify_cond(cond_func, then_func, else_func, input_var_shapes, free_var_shapes, is_train):
+
+    def _create_symbol(prefix, i):
+        return mx.sym.var(prefix + str(i))
+
+    def _create_array(shape):
+        return mx.nd.random.uniform(-1.0, 1.0, shape=shape)
+
+    def _to_numpy_list(arrays):
+        return [x.asnumpy() if x is not None else x for x in arrays]
+
+    def _merge_dict(*dicts):
+        result = {}
+        for item in dicts:
+            result.update(item)
+        return result
+
+    _input_syms = [_create_symbol("InputVar", i) for i, _ in enumerate(input_var_shapes)]
+    _free_syms = [_create_symbol("FreeVar", i) for i, _ in enumerate(free_var_shapes)]
+    _input_vars = [_create_array(x) for x in input_var_shapes]
+    _free_vars = [_create_array(x) for x in free_var_shapes]
+    _args_dict = _merge_dict(
+        {"InputVar" + str(i): x for i, x in enumerate(_input_vars)},
+        {"FreeVar" + str(i): x for i, x in enumerate(_free_vars)},
+    )
+
+    def _get_imperative_result():
+        free_vars = [x.copy() for x in _free_vars]
+        input_vars = [x.copy() for x in _input_vars]
+        out_grads = []
+        if is_train:
+            for var in free_vars + input_vars:
+                var.attach_grad()
+        with mx.autograd.record(train_mode=is_train):
+            outputs = mx.nd.contrib.cond(
+                pred=cond_func(input_vars, free_vars),
+                then_func=lambda: then_func(input_vars, free_vars),
+                else_func=lambda: else_func(input_vars, free_vars),
+            )
+            outputs = [x * 2 for x in outputs]
+            grads = []
+            if is_train:
+                out_grads = [_create_array(x.shape) for x in outputs]
+                cat_out = mx.nd.concat(*[x.reshape(-1) for x in outputs], dim=0)
+                cat_out.backward(out_grad=mx.nd.concat(*[x.reshape(-1) for x in out_grads], dim=0))
+                grads = [free_vars[i].grad for i, _ in enumerate(free_var_shapes)] \
+                      + [input_vars[i].grad for i, _ in enumerate(input_var_shapes)]
+            return _to_numpy_list(outputs), _to_numpy_list(grads), out_grads
+
+    def _get_symbolic_result(out_grads):
+        outputs_sym = mx.sym.contrib.cond(
+            pred=cond_func(_input_syms, _free_syms),
+            then_func=lambda: then_func(_input_syms, _free_syms),
+            else_func=lambda: else_func(_input_syms, _free_syms),
+        )
+        outputs_sym = [x * 2 for x in outputs_sym]
+        outputs_sym = mx.sym.Group(outputs_sym)
+        executor = outputs_sym.bind(
+            ctx=default_context(),
+            args={name: _args_dict[name].copy() for name in outputs_sym.list_inputs()},
+            args_grad=None if not is_train else _merge_dict(
+                {"InputVar" + str(i): mx.nd.zeros(s) for i, s in enumerate(input_var_shapes)},
+                {"FreeVar" + str(i): mx.nd.zeros(s) for i, s in enumerate(free_var_shapes)},
+            ),
+        )
+        outputs = executor.forward(is_train=is_train)
+        grads = []
+        if is_train:
+            executor.backward(out_grads=out_grads)
+            grads = [executor.grad_dict.get("FreeVar" + str(i), None) for i, _ in enumerate(free_var_shapes)] \
+                  + [executor.grad_dict.get("InputVar" + str(i), None) for i, _ in enumerate(input_var_shapes)]
+        return _to_numpy_list(outputs), _to_numpy_list(grads)
+
+    imp_outs, imp_grads, out_grads = _get_imperative_result()
+    sym_outs, sym_grads = _get_symbolic_result(out_grads)
+    for imp_out, sym_out in zip(imp_outs, sym_outs):
+        if imp_out is None or sym_out is None:
+            continue
+        assert_almost_equal(imp_out, sym_out, rtol=1e-5, atol=1e-5)
+    for imp_grad, sym_grad in zip(imp_grads, sym_grads):
+        if imp_grad is None or sym_grad is None:
+            continue
+        assert_almost_equal(imp_grad, sym_grad, rtol=1e-5, atol=1e-5)
+
+
+@with_seed()
+def test_cond():
+    # whether there are free variables in three graphs
+    # whether these three graphs contain input_vars
+    # whether to use all input_vars
+    # which branch to choose
+    def run_case(cond_func, then_func, else_func, **params):
+        def make_cond(is_inverse):
+            def cond(inputs, free):
+                x = cond_func(inputs, free)
+                if is_inverse:
+                    if isinstance(x, mx.sym.Symbol):
+                        return mx.sym.logical_not(x)
+                    else:
+                        return mx.nd.logical_not(x)
+                return x
+            return cond
+        for is_train in [True, False]:
+            for is_inverse in [False, True]:
+                _verify_cond(
+                    cond_func=make_cond(is_inverse),
+                    then_func=then_func,
+                    else_func=else_func,
+                    is_train=is_train,
+                    **params
+                )
+    # Each function can
+    # 1. use_free_vars or not: T/F
+    # 2. use_input_vars or not: T/F
+    # 3. use_all_input_vars or not: T/F
+    # (a, b, c) are inputs, (d, e, f) are free_vars
+    cond_funcs = [
+        lambda a, b, c, d, e, f: (a * b).sum() < 0.5,               # F, T, F
+        lambda a, b, c, d, e, f: (a + b + c).sum() < 0.5,           # F, T, T
+        lambda a, b, c, d, e, f: (d + e).sum() < 0.5,               # T, F, F
+        lambda a, b, c, d, e, f: (d + e * a).sum() < 0.5,           # T, T, F
+        lambda a, b, c, d, e, f: (d + e * a + b * c).sum() < 0.5,   # T, T, T
+    ]
+    body_funcs = [
+        lambda a, b, c, d, e, f: a * b,                             # F, T, F
+        lambda a, b, c, d, e, f: a * b * c,                         # F, T, T
+        lambda a, b, c, d, e, f: d * e,                             # T, F, F
+        lambda a, b, c, d, e, f: d * e * a,                         # T, T, F
+        lambda a, b, c, d, e, f: d * e * a * b * c,                 # T, T, T
+        # some extra tests
+        lambda a, b, c, d, e, f: b * c,
+        lambda a, b, c, d, e, f: a * c,
+        lambda a, b, c, d, e, f: (a + b) * c,
+        lambda a, b, c, d, e, f: c * (b - a),
+    ]
+    # enumerate all kinds of possible combinations
+    for cond_func in cond_funcs:
+        for then_func in body_funcs:
+            for else_func in body_funcs:
+                run_case(
+                    cond_func=lambda x, y: cond_func(x[0], x[1], x[2], y[0], y[1], y[2]),
+                    then_func=lambda x, y: then_func(x[0], x[1], x[2], y[0], y[1], y[2]),
+                    else_func=lambda x, y: else_func(x[0], x[1], x[2], y[0], y[1], y[2]),
+                    input_var_shapes=[
+                        (2, 3),
+                        (2, 3),
+                        (2, 3),
+                    ],
+                    free_var_shapes=[
+                        (2, 3),
+                        (2, 3),
+                        (2, 3),
+                    ]
+                )
 
 class TestRNNLayer(gluon.HybridBlock):
     def __init__(self, cell_type, hidden_size, prefix=None, params=None):
@@ -1510,5 +1664,6 @@ def test_foreach_rnn():
 
 
 if __name__ == '__main__':
-    import nose
-    nose.runmodule()
+    # import nose
+    # nose.runmodule()
+    test_cond()
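
`test_cond` above cross-checks the imperative result (`mx.nd.contrib.cond`) against the symbolic one built through `mx.sym.contrib.cond`. For reference, the symbolic construction looks like the following sketch (variable names are illustrative); the resulting symbol can then be grouped, bound, and executed exactly as `_get_symbolic_result` does above:

    import mxnet as mx

    # Sketch of the symbolic front end: both branches become subgraphs of a
    # single _cond node, and the branch to run is chosen from pred at run time.
    x = mx.sym.var("x")
    y = mx.sym.var("y")
    out = mx.sym.contrib.cond(
        pred=((x * y).sum() < 0.5),
        then_func=lambda: x + y,
        else_func=lambda: x - y,
    )
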
