This is an automated email from the ASF dual-hosted git repository.

zhasheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 54632bc  [MXNET-626] Add while_loop (#11566)
54632bc is described below

commit 54632bcb38064a0ed1f23dd652897562d3a0036a
Author: Junru Shao <[email protected]>
AuthorDate: Wed Jul 18 17:09:10 2018 -0700

    [MXNET-626] Add while_loop (#11566)
    
    * Add while_loop
    
    * Avoid input/output overlap for nnvm graph cut
    
    * Add more testcases
    
    * Enhance test 4.2
    
    * Add more complicated testcases; Add testcase for nested loop
    
    * Check unused loop_vars in while_loop
    
    * Add testcases for RNN
    
    * Make lint happy
    
    * Make lint happy
    
    * Address TODOs
    
    * Fix flaky test for while_loop
    
    * Address comments
    
    * Improve docstring
    
    * Improve error message
    
    * Add benchmark code
    
    * Update benchmarks
    
    * Allow sparse types
    
    * Make max_iterations default to None
    
    * Add while_loop to docs/api/python/{symbol|ndarray}/contrib.md
    
    * Pad imperative while_loop so that it has the same shape as the symbolic one
    
    * Add example result into the example section
    
    * Remove unused class member
    
    * Rename unittest to test_contrib_control_flow.py
    
    * Update docstring
    
    * Update docstring
    
    * Trigger CI
    
    * Change threshold for assert_almost_equal
    
    * Trigger CI
    
    * Address comments from szha
    
    * Rewrite benchmark code
    
    * Fix sphinx warning
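    
    A minimal usage sketch of the new operator (taken from the docstring example
    added to python/mxnet/ndarray/contrib.py in this patch; it assumes an MXNet
    build that includes this change):
    
        >>> import mxnet as mx
        >>> cond = lambda i, s: i <= 5
        >>> func = lambda i, s: ([i + s], [i + 1, s + i])
        >>> loop_vars = (mx.nd.array([0], dtype="int64"), mx.nd.array([1], dtype="int64"))
        >>> outputs, states = mx.nd.contrib.while_loop(cond, func, loop_vars, max_iterations=10)
        >>> # outputs[0] has shape (max_iterations, 1); rows beyond the 6 executed steps are undefined
        >>> # states holds the final loop variables, here [6] and [16]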
---
 3rdparty/tvm                                       |   2 +-
 .../python/control_flow/{rnn.py => foreach_rnn.py} |  12 +-
 benchmark/python/control_flow/rnn.py               | 273 +++---
 .../control_flow/{rnn.py => while_loop_rnn.py}     |  98 ++-
 docs/api/python/ndarray/contrib.md                 |   1 +
 docs/api/python/symbol/contrib.md                  |   1 +
 python/mxnet/ndarray/contrib.py                    | 174 +++-
 python/mxnet/symbol/contrib.py                     | 222 ++++-
 src/operator/control_flow.cc                       | 563 +++++++++++-
 src/operator/subgraph_op_common.cc                 |   9 +-
 src/operator/subgraph_op_common.h                  |  14 +-
 tests/python/unittest/test_contrib_control_flow.py | 978 +++++++++++++++++++++
 12 files changed, 2133 insertions(+), 214 deletions(-)

diff --git a/3rdparty/tvm b/3rdparty/tvm
index 6ab4da6..290226e 160000
--- a/3rdparty/tvm
+++ b/3rdparty/tvm
@@ -1 +1 @@
-Subproject commit 6ab4da6783417d8afdeb6b0426b44959b2afc709
+Subproject commit 290226e1c9adbb3e598f9ed9184018df1c12be33
diff --git a/benchmark/python/control_flow/rnn.py b/benchmark/python/control_flow/foreach_rnn.py
similarity index 92%
copy from benchmark/python/control_flow/rnn.py
copy to benchmark/python/control_flow/foreach_rnn.py
index 5e41b75..4ce7a42 100644
--- a/benchmark/python/control_flow/rnn.py
+++ b/benchmark/python/control_flow/foreach_rnn.py
@@ -157,7 +157,8 @@ if __name__ == '__main__':
     ndim = 512
     seq_len = 100
     batch_sizes = [1, 32]
-    cells = [gluon.rnn.GRUCell(ndim, prefix='rnn_'),
+    cells = [gluon.rnn.RNNCell(ndim, prefix='rnn_'),
+             gluon.rnn.GRUCell(ndim, prefix='rnn_'),
              gluon.rnn.LSTMCell(ndim, prefix='rnn_')]
     ctxs = [mx.cpu(0), mx.gpu(0)]
     for cell in cells:
@@ -165,8 +166,13 @@ if __name__ == '__main__':
             for batch_size in batch_sizes:
                 if len(get_gpus()) == 0 and ctx == mx.gpu(0):
                     continue
-
-                if isinstance(cell, gluon.rnn.GRUCell):
+                if isinstance(cell, gluon.rnn.RNNCell):
+                    rnn_data = mx.nd.normal(loc=0, scale=1, shape=(seq_len, 
batch_size, ndim),
+                                            ctx=mx.cpu(0))
+                    states = []
+                    states.append(mx.nd.normal(loc=0, scale=1, 
shape=(batch_size, ndim),
+                                               ctx=mx.cpu(0)))
+                elif isinstance(cell, gluon.rnn.GRUCell):
                     rnn_data = mx.nd.normal(loc=0, scale=1, shape=(seq_len, 
batch_size, ndim),
                                             ctx=mx.cpu(0))
                     states = []
diff --git a/benchmark/python/control_flow/rnn.py b/benchmark/python/control_flow/rnn.py
index 5e41b75..8a44a9c 100644
--- a/benchmark/python/control_flow/rnn.py
+++ b/benchmark/python/control_flow/rnn.py
@@ -15,175 +15,128 @@
 # specific language governing permissions and limitations
 # under the License.
 
+from __future__ import print_function
+from six.moves import range
+
+import argparse
 import subprocess
+from itertools import product
+from time import time
+
 import mxnet as mx
+import numpy as np
 from mxnet import gluon
-import time
-import copy
 
-def get_gpus():
-    """
-    return a list of GPUs
-    """
-    try:
-        re = subprocess.check_output(["nvidia-smi", "-L"], 
universal_newlines=True)
-    except OSError:
-        return []
-    return range(len([i for i in re.split('\n') if 'GPU' in i]))
 
-class TestRNNLayer(gluon.HybridBlock):
-    def __init__(self, cell, prefix=None, params=None):
-        super(TestRNNLayer, self).__init__(prefix=prefix, params=params)
+_parser = argparse.ArgumentParser(description='Benchmark foreach and 
while_loop on RNN tasks.')
+_parser.add_argument('--benchmark', choices=["foreach", "while_loop"], 
required=True)
+_parser.add_argument('--warmup_rounds', type=int, default=20)
+_parser.add_argument('--test_rounds', type=int, default=100)
+args = _parser.parse_args()
+
+
+class ForeachRNN(gluon.HybridBlock):
+    def __init__(self, cell, length, prefix=None, params=None):
+        super(ForeachRNN, self).__init__(prefix=prefix, params=params)
+        self.length = length
         self.cell = cell
 
     def hybrid_forward(self, F, inputs, states):
         out, states = F.contrib.foreach(self.cell, inputs, states)
         return out
 
-def benchmark_rnn(cell, rnn_data, states):
-    ctx = rnn_data.context
-    num_batches = 20
-
-    # Imperative
-    cell0 = copy.deepcopy(cell)
-    layer0 = TestRNNLayer(cell0)
-    layer0.initialize(ctx=ctx)
-
-    # Hybridize
-    cell1 = copy.deepcopy(cell)
-    cell1.hybridize()
-    layer1 = TestRNNLayer(cell1)
-    layer1.initialize(ctx=ctx)
-
-    # Hybridize
-    cell2 = copy.deepcopy(cell)
-    layer2 = TestRNNLayer(cell2)
-    layer2.initialize(ctx=ctx)
-    layer2.hybridize()
-    layer2(rnn_data, states)
-
-    # Hybridize
-    cell3 = copy.deepcopy(cell)
-    cell3.hybridize(static_alloc=True)
-    layer3 = TestRNNLayer(cell3)
-    layer3.initialize(ctx=ctx)
-
-    tic = time.time()
-    for i in range(num_batches):
-        res0 = layer0(rnn_data, states)
-        mx.nd.waitall()
-    print("Imperative inference takes " + str(time.time() - tic))
-
-    tic = time.time()
-    for i in range(num_batches):
-        res1 = layer1(rnn_data, states)
-        mx.nd.waitall()
-    print("Hybrid-cell inference takes " + str(time.time() - tic))
-
-    tic = time.time()
-    for i in range(num_batches):
-        res3 = layer3(rnn_data, states)
-        mx.nd.waitall()
-    print("Static-hybrid-cell inference takes " + str(time.time() - tic))
-
-    tic = time.time()
-    for i in range(num_batches):
-        res2 = layer2(rnn_data, states)
-        mx.nd.waitall()
-    print("Hybrid inference takes " + str(time.time() - tic))
-
-    layer2.export("foreach_rnn")
-    symnet = mx.symbol.load('foreach_rnn-symbol.json')
-    args1 = {}
-    params = layer2.collect_params()
-    for key in params.keys():
-        args1[key] = params[key].data()
-    args1['data0'] = rnn_data
-    for i in range(len(states)):
-        args1['data' + str(i + 1)] = states[i]
-    exe = symnet.bind(ctx=ctx, args=args1)
-    tic = time.time()
-    for i in range(num_batches):
-        exe.forward(is_train=False)
-        mx.nd.waitall()
-    print("Symbol inference takes " + str(time.time() - tic))
-
-    tic = time.time()
-    for i in range(num_batches):
-        with mx.autograd.record():
-            res0 = layer0(rnn_data, states)
-        res0.backward()
-        mx.nd.waitall()
-    print("Imperative training takes " + str(time.time() - tic))
-
-    tic = time.time()
-    for i in range(num_batches):
-        with mx.autograd.record():
-            res1 = layer1(rnn_data, states)
-        res1.backward()
-        mx.nd.waitall()
-    print("Hybrid-cell training takes " + str(time.time() - tic))
-
-    tic = time.time()
-    for i in range(num_batches):
-        with mx.autograd.record():
-            res3 = layer3(rnn_data, states)
-        res3.backward()
-        mx.nd.waitall()
-    print("Static-hybrid-cell training takes " + str(time.time() - tic))
-
-    tic = time.time()
-    for i in range(num_batches):
-        with mx.autograd.record():
-            res2 = layer2(rnn_data, states)
-        res2.backward()
-        mx.nd.waitall()
-    print("Hybrid training takes " + str(time.time() - tic))
-
-    # gradients for the backward of the foreach symbol
-    args_grad1 = {}
-    for key in args1.keys():
-        args_grad1[key] = mx.nd.empty(args1[key].shape, ctx=ctx)
-    exe = symnet.bind(ctx=ctx, args=args1, args_grad=args_grad1)
-    tic = time.time()
-    for i in range(num_batches):
-        exe.forward(is_train=True)
-        exe.backward(res2)
-        mx.nd.waitall()
-    print("Symbol training takes " + str(time.time() - tic))
-    print("")
-
-if __name__ == '__main__':
-    ndim = 512
-    seq_len = 100
+
+class WhileRNN(gluon.HybridBlock):
+    def __init__(self, cell, length, prefix=None, params=None):
+        super(WhileRNN, self).__init__(prefix=prefix, params=params)
+        self.length = length
+        self.cell = cell
+
+    def hybrid_forward(self, F, inputs, states):
+        def _func(*states):
+            i = states[0]
+            s = states[1: ]
+            data = inputs.take(i).squeeze(axis=0)
+            out, new_s = self.cell(data, s)
+            new_s = [i + 1] + new_s
+            return out, new_s
+        out, states = F.contrib.while_loop(
+            cond=lambda i, *_: i < self.length,
+            func=_func,
+            loop_vars=states,
+            max_iterations=self.length,
+        )
+        assert len(out) == 1
+        return out[0]
+
+
+def _zeros(shape, ctx):
+    return mx.nd.zeros(shape=shape, ctx=ctx)
+
+
+def _array(shape, ctx):
+    return mx.nd.normal(loc=0.0, scale=1.0, shape=shape, ctx=ctx)
+
+
+def _get_gpus():
+    try:
+        re = subprocess.check_output(["nvidia-smi", "-L"], 
universal_newlines=True)
+    except OSError:
+        return []
+    return range(len([i for i in re.split('\n') if 'GPU' in i]))
+
+
+def run_benchmark(cell_type, ctx, seq_len, batch_size, hidden_dim):
+    obj = {"foreach": ForeachRNN, "while_loop": WhileRNN}[args.benchmark]
+    inputs = _array((seq_len, batch_size, hidden_dim), ctx)
+    states = [_array((batch_size, hidden_dim), ctx) for _ in 
cell_type(0).state_info()]
+    if args.benchmark == "while_loop":
+        states.insert(0, _zeros((1, ), ctx))
+
+    for is_train, is_hyb_cell, is_hyb_layer in product([True, False], [False, 
True], [False, True]):
+        cell = cell_type(hidden_dim)
+        if is_hyb_cell:
+            cell.hybridize(static_alloc=True)
+        layer = obj(cell, seq_len)
+        layer.initialize(ctx=ctx)
+        if is_hyb_layer:
+            layer.hybridize(static_alloc=True)
+        print("is_train = %r, hybridize_cell = %r, hybridize_layer = %r" % 
(is_train, is_hyb_cell, is_hyb_layer))
+        times = []
+        for _ in range(args.warmup_rounds + args.test_rounds):
+            tick = time()
+            if not is_train:
+                res = layer(inputs, states)
+            else:
+                with mx.autograd.record():
+                    res = layer(inputs, states)
+            if is_train:
+                res.backward()
+            mx.nd.waitall()
+            tock = time()
+            times.append((tock - tick) * 1000.0)
+        times = times[args.warmup_rounds: ]
+        print("Time used: mean = %.3f ms, std = %.3f ms" % (np.mean(times), 
np.std(times)))
+
+
+def main():
+    # testing configurations
+    cell_types = [gluon.rnn.RNNCell,
+                  gluon.rnn.GRUCell,
+                  gluon.rnn.LSTMCell]
+    ctxs = [mx.cpu(0)] + [mx.gpu(i) for i in _get_gpus()]
+    seq_lens = [100]
     batch_sizes = [1, 32]
-    cells = [gluon.rnn.GRUCell(ndim, prefix='rnn_'),
-             gluon.rnn.LSTMCell(ndim, prefix='rnn_')]
-    ctxs = [mx.cpu(0), mx.gpu(0)]
-    for cell in cells:
-        for ctx in ctxs:
-            for batch_size in batch_sizes:
-                if len(get_gpus()) == 0 and ctx == mx.gpu(0):
-                    continue
-
-                if isinstance(cell, gluon.rnn.GRUCell):
-                    rnn_data = mx.nd.normal(loc=0, scale=1, shape=(seq_len, 
batch_size, ndim),
-                                            ctx=mx.cpu(0))
-                    states = []
-                    states.append(mx.nd.normal(loc=0, scale=1, 
shape=(batch_size, ndim),
-                                               ctx=mx.cpu(0)))
-                elif isinstance(cell, gluon.rnn.LSTMCell):
-                    rnn_data = mx.nd.normal(loc=0, scale=1, shape=(seq_len, 
batch_size, ndim),
-                                            ctx=mx.cpu(0))
-                    states = []
-                    states.append(mx.nd.normal(loc=0, scale=1, 
shape=(batch_size, ndim),
-                                               ctx=mx.cpu(0)))
-                    states.append(mx.nd.normal(loc=0, scale=1, 
shape=(batch_size, ndim),
-                                               ctx=mx.cpu(0)))
-                if ctx == mx.gpu(0):
-                    dev = "GPU"
-                else:
-                    dev = "CPU"
-                print("Benchmark {} in {} (batch size: 
{})".format(cell._alias(), dev,
-                                                                   batch_size))
-                benchmark_rnn(cell, rnn_data, states)
+    hidden_dims = [512]
+    print("--------------------------------------")
+    print("Benchmarking", args.benchmark)
+    for cell_type, ctx, seq_len, batch_size, hidden_dim in product(  \
+        cell_types, ctxs, seq_lens, batch_sizes, hidden_dims):
+        print("--------------------------------------")
+        print("cell: %s  ctx: %s  length: %d  batch size: %d dim: %d" % \
+              (cell_type.__name__, str(ctx), seq_len, batch_size, hidden_dim))
+        run_benchmark(cell_type, ctx, seq_len, batch_size, hidden_dim)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/benchmark/python/control_flow/rnn.py b/benchmark/python/control_flow/while_loop_rnn.py
similarity index 67%
copy from benchmark/python/control_flow/rnn.py
copy to benchmark/python/control_flow/while_loop_rnn.py
index 5e41b75..42aaee5 100644
--- a/benchmark/python/control_flow/rnn.py
+++ b/benchmark/python/control_flow/while_loop_rnn.py
@@ -15,6 +15,8 @@
 # specific language governing permissions and limitations
 # under the License.
 
+# Code borrowed from ./benchmark/python/control_flow/foreach_rnn.py
+
 import subprocess
 import mxnet as mx
 from mxnet import gluon
@@ -32,40 +34,53 @@ def get_gpus():
     return range(len([i for i in re.split('\n') if 'GPU' in i]))
 
 class TestRNNLayer(gluon.HybridBlock):
-    def __init__(self, cell, prefix=None, params=None):
+    def __init__(self, cell, length, prefix=None, params=None):
         super(TestRNNLayer, self).__init__(prefix=prefix, params=params)
+        self.length = length
         self.cell = cell
 
     def hybrid_forward(self, F, inputs, states):
-        out, states = F.contrib.foreach(self.cell, inputs, states)
-        return out
-
-def benchmark_rnn(cell, rnn_data, states):
+        def _func(*states):
+            i = states[0]
+            s = states[1: ]
+            data = inputs.take(i).squeeze(axis=0)
+            out, new_s = self.cell(data, s)
+            new_s = [i + 1] + new_s
+            return out, new_s
+        out, states = F.contrib.while_loop(
+            cond=lambda i, *_: i < self.length,
+            func=_func,
+            loop_vars=states,
+            max_iterations=self.length,
+        )
+        return out + states
+
+def benchmark_rnn(cell, rnn_data, states, length):
     ctx = rnn_data.context
     num_batches = 20
 
     # Imperative
     cell0 = copy.deepcopy(cell)
-    layer0 = TestRNNLayer(cell0)
+    layer0 = TestRNNLayer(cell0, length)
     layer0.initialize(ctx=ctx)
 
-    # Hybridize
+    # Hybrid-cell
     cell1 = copy.deepcopy(cell)
     cell1.hybridize()
-    layer1 = TestRNNLayer(cell1)
+    layer1 = TestRNNLayer(cell1, length)
     layer1.initialize(ctx=ctx)
 
-    # Hybridize
+    # Hybrid
     cell2 = copy.deepcopy(cell)
-    layer2 = TestRNNLayer(cell2)
+    layer2 = TestRNNLayer(cell2, length)
     layer2.initialize(ctx=ctx)
     layer2.hybridize()
     layer2(rnn_data, states)
 
-    # Hybridize
+    # Static-hybrid-cell
     cell3 = copy.deepcopy(cell)
     cell3.hybridize(static_alloc=True)
-    layer3 = TestRNNLayer(cell3)
+    layer3 = TestRNNLayer(cell3, length)
     layer3.initialize(ctx=ctx)
 
     tic = time.time()
@@ -92,8 +107,8 @@ def benchmark_rnn(cell, rnn_data, states):
         mx.nd.waitall()
     print("Hybrid inference takes " + str(time.time() - tic))
 
-    layer2.export("foreach_rnn")
-    symnet = mx.symbol.load('foreach_rnn-symbol.json')
+    layer2.export("while_loop_rnn")
+    symnet = mx.symbol.load('while_loop_rnn-symbol.json')
     args1 = {}
     params = layer2.collect_params()
     for key in params.keys():
@@ -112,7 +127,7 @@ def benchmark_rnn(cell, rnn_data, states):
     for i in range(num_batches):
         with mx.autograd.record():
             res0 = layer0(rnn_data, states)
-        res0.backward()
+        res0[0].backward()
         mx.nd.waitall()
     print("Imperative training takes " + str(time.time() - tic))
 
@@ -120,7 +135,7 @@ def benchmark_rnn(cell, rnn_data, states):
     for i in range(num_batches):
         with mx.autograd.record():
             res1 = layer1(rnn_data, states)
-        res1.backward()
+        res1[0].backward()
         mx.nd.waitall()
     print("Hybrid-cell training takes " + str(time.time() - tic))
 
@@ -128,7 +143,7 @@ def benchmark_rnn(cell, rnn_data, states):
     for i in range(num_batches):
         with mx.autograd.record():
             res3 = layer3(rnn_data, states)
-        res3.backward()
+        res3[0].backward()
         mx.nd.waitall()
     print("Static-hybrid-cell training takes " + str(time.time() - tic))
 
@@ -136,14 +151,15 @@ def benchmark_rnn(cell, rnn_data, states):
     for i in range(num_batches):
         with mx.autograd.record():
             res2 = layer2(rnn_data, states)
-        res2.backward()
+        res2[0].backward()
         mx.nd.waitall()
     print("Hybrid training takes " + str(time.time() - tic))
 
-    # gradients for the backward of the foreach symbol
+    # gradients for the backward of the while_loop symbol
     args_grad1 = {}
     for key in args1.keys():
-        args_grad1[key] = mx.nd.empty(args1[key].shape, ctx=ctx)
+        if key != "data1":
+            args_grad1[key] = mx.nd.empty(args1[key].shape, ctx=ctx)
     exe = symnet.bind(ctx=ctx, args=args1, args_grad=args_grad1)
     tic = time.time()
     for i in range(num_batches):
@@ -154,10 +170,15 @@ def benchmark_rnn(cell, rnn_data, states):
     print("")
 
 if __name__ == '__main__':
+    def _zeros(shape):
+        return mx.nd.zeros(shape=shape, ctx=mx.cpu(0))
+    def _array(shape):
+        return mx.nd.normal(loc=0.0, scale=1.0, shape=shape, ctx=mx.cpu(0))
     ndim = 512
     seq_len = 100
     batch_sizes = [1, 32]
-    cells = [gluon.rnn.GRUCell(ndim, prefix='rnn_'),
+    cells = [gluon.rnn.RNNCell(ndim, prefix='rnn_'),
+             gluon.rnn.GRUCell(ndim, prefix='rnn_'),
              gluon.rnn.LSTMCell(ndim, prefix='rnn_')]
     ctxs = [mx.cpu(0), mx.gpu(0)]
     for cell in cells:
@@ -165,25 +186,28 @@ if __name__ == '__main__':
             for batch_size in batch_sizes:
                 if len(get_gpus()) == 0 and ctx == mx.gpu(0):
                     continue
-
+                if isinstance(cell, gluon.rnn.RNNCell):
+                    rnn_data = _array((seq_len, batch_size, ndim))
+                    states = [
+                        _zeros((1, )),
+                        _array((batch_size, ndim)),
+                    ]
                 if isinstance(cell, gluon.rnn.GRUCell):
-                    rnn_data = mx.nd.normal(loc=0, scale=1, shape=(seq_len, 
batch_size, ndim),
-                                            ctx=mx.cpu(0))
-                    states = []
-                    states.append(mx.nd.normal(loc=0, scale=1, 
shape=(batch_size, ndim),
-                                               ctx=mx.cpu(0)))
+                    rnn_data = _array((seq_len, batch_size, ndim))
+                    states = [
+                        _zeros((1, )),
+                        _array((batch_size, ndim)),
+                    ]
                 elif isinstance(cell, gluon.rnn.LSTMCell):
-                    rnn_data = mx.nd.normal(loc=0, scale=1, shape=(seq_len, 
batch_size, ndim),
-                                            ctx=mx.cpu(0))
-                    states = []
-                    states.append(mx.nd.normal(loc=0, scale=1, 
shape=(batch_size, ndim),
-                                               ctx=mx.cpu(0)))
-                    states.append(mx.nd.normal(loc=0, scale=1, 
shape=(batch_size, ndim),
-                                               ctx=mx.cpu(0)))
+                    rnn_data = _array((seq_len, batch_size, ndim))
+                    states = [
+                        _zeros((1, )),
+                        _array((batch_size, ndim)),
+                        _array((batch_size, ndim)),
+                    ]
                 if ctx == mx.gpu(0):
                     dev = "GPU"
                 else:
                     dev = "CPU"
-                print("Benchmark {} in {} (batch size: 
{})".format(cell._alias(), dev,
-                                                                   batch_size))
-                benchmark_rnn(cell, rnn_data, states)
+                print("Benchmark {} in {} (batch size: 
{})".format(cell._alias(), dev, batch_size))
+                benchmark_rnn(cell, rnn_data, states, seq_len)
diff --git a/docs/api/python/ndarray/contrib.md b/docs/api/python/ndarray/contrib.md
index 36a2c15..0cf8724 100644
--- a/docs/api/python/ndarray/contrib.md
+++ b/docs/api/python/ndarray/contrib.md
@@ -53,6 +53,7 @@ In the rest of this document, we list routines provided by 
the `ndarray.contrib`
     ifft
     quantize
     foreach
+    while_loop
 ```
 
 ## API Reference
diff --git a/docs/api/python/symbol/contrib.md b/docs/api/python/symbol/contrib.md
index 6647165..ba43f2d 100644
--- a/docs/api/python/symbol/contrib.md
+++ b/docs/api/python/symbol/contrib.md
@@ -53,6 +53,7 @@ In the rest of this document, we list routines provided by 
the `symbol.contrib`
     ifft
     quantize
     foreach
+    while_loop
 ```
 
 ## API Reference
diff --git a/python/mxnet/ndarray/contrib.py b/python/mxnet/ndarray/contrib.py
index b1f065e..b67cf5a 100644
--- a/python/mxnet/ndarray/contrib.py
+++ b/python/mxnet/ndarray/contrib.py
@@ -28,7 +28,7 @@ try:
 except ImportError:
     pass
 
-__all__ = ["rand_zipfian"]
+__all__ = ["rand_zipfian", "foreach", "while_loop"]
 
 # pylint: disable=line-too-long
 def rand_zipfian(true_classes, num_sampled, range_max, ctx=None):
@@ -191,3 +191,175 @@ def foreach(body, data, init_states):
     if not_data_list and len(outputs) == 1:
         outputs = outputs[0]
     return (outputs, states)
+
+
+def while_loop(cond, func, loop_vars, max_iterations=None):
+    """Run a while loop with user-defined computation and loop condition.
+
+    This operator simulates a while loop that iteratively performs customized computation
+    as long as the condition is satisfied.
+
+    `loop_vars` is a list of NDArrays that the computation uses.
+
+    `cond` is a user-defined function, used as the loop condition.
+    It consumes `loop_vars`, and produces a scalar MXNet NDArray,
+    indicating the termination of the loop.
+    The loop ends when `cond` returns false (zero).
+    The `cond` is variadic, and its signature should be
+    `cond(*loop_vars) => NDArray`.
+
+    `func` is a user-defined function, used as the loop body.
+    It also consumes `loop_vars`, and produces `step_output` and `new_loop_vars` at each step.
+    In each step, `step_output` should contain the same number of elements.
+    Through all steps, the i-th element of `step_output` should have the same shape and dtype.
+    Also, `new_loop_vars` should contain the same number of elements as `loop_vars`,
+    and the corresponding elements should have the same shape and dtype.
+    `func` is variadic, and its signature should be
+    `func(*loop_vars) => (List[NDArray] step_output, List[NDArray] new_loop_vars)`.
+
+    `max_iterations` is a scalar that defines the maximum number of iterations allowed.
+
+    This function returns two lists.
+    The first list has the length of `|step_output|`,
+    and its i-th element is the stack (along axis 0) of the i-th elements of
+    `step_output` from all steps.
+    The second list has the length of `|loop_vars|`,
+    and represents the final states of the loop variables.
+
+    .. warning::
+
+       For now, the length of axis 0 of all NDArrays in the first list is `max_iterations`,
+       due to the lack of dynamic shape inference.
+
+    .. warning::
+
+       When `cond` is never satisfied, we assume `step_output` is empty,
+       because it cannot be inferred. This is different from the symbolic version.
+
+    Parameters
+    ----------
+    cond: a Python function.
+        The loop condition.
+    func: a Python function.
+        The loop body.
+    loop_vars: list of NDArrays.
+        The initial values of the loop variables.
+    max_iterations: a python int.
+        Maximum number of iterations.
+
+    Returns
+    -------
+    outputs: list of NDArrays
+        stacked output from each step
+    states: list of NDArrays
+        final state
+
+    Examples
+    --------
+    >>> cond = lambda i, s: i <= 5
+    >>> func = lambda i, s: ([i + s], [i + 1, s + i])
+    >>> loop_vars = (mx.nd.array([0], dtype="int64"), mx.nd.array([1], 
dtype="int64"))
+    >>> outputs, states = mx.nd.contrib.while_loop(cond, func, loop_vars, 
max_iterations=10)
+    >>> outputs
+    [
+    [[ 1]
+    [ 2]
+    [ 4]
+    [ 7]
+    [11]
+    [16]
+    [...]  # undefined value
+    [...]
+    [...]
+    [...]]
+    <NDArray 6x1 @cpu(0)>]
+    >>> states
+    [
+    [6]
+    <NDArray 1 @cpu(0)>,
+    [16]
+    <NDArray 1 @cpu(0)>]
+    """
+    def _to_python_scalar(inputs, type_, name):
+        """Converts "inputs", possibly typed mxnet NDArray, a numpy ndarray, 
other python types,
+        to the given type
+        """
+        if isinstance(inputs, ndarray.NDArray):
+            inputs = inputs.asscalar()
+        try:
+            inputs = type_(inputs)
+        except:
+            raise ValueError("Cannot convert %s to python %s" % (name, 
type_.__name__))
+        return inputs
+
+    def _to_ndarray_tuple(inputs, name):
+        """Converts "inputs", possibly a single mxnet NDArray, a list of mxnet 
NDArray,
+        a tuple of mxnet NDArray, into a tuple of NDArray
+        """
+        if isinstance(inputs, list):
+            inputs = tuple(inputs)
+        if isinstance(inputs, ndarray.NDArray):
+            inputs = (inputs, )
+        if not isinstance(inputs, tuple):
+            raise ValueError("%s must be an NDArray, or a tuple or list of 
NDArrays" % (name, ))
+        for item in inputs:
+            if not isinstance(item, ndarray.NDArray):
+                raise ValueError("%s must be an NDArray, or a tuple or list of 
NDArrays" % (name, ))
+        return inputs
+
+    def _func_wrapper(loop_vars):
+        """This wrapper unifies
+             "func: loop_vars -> new_loop_vars"
+         and "func: loop_vars -> (step_output, new_loop_vars)"
+        into "func: loop_vars -> (None or tuple of step_outputs, tuple of 
new_loop_vars)
+        """
+        step_output, new_loop_vars = func(*loop_vars)
+        if step_output is None:
+            step_output = []
+        if new_loop_vars is None:
+            new_loop_vars = []
+        step_output = _to_ndarray_tuple(step_output, "step_output")
+        new_loop_vars = _to_ndarray_tuple(new_loop_vars, "new_loop_vars")
+        if len(loop_vars) != len(new_loop_vars):
+            raise ValueError("The length of loop_vars should be consistent 
during the loop")
+        return step_output, new_loop_vars
+
+    if max_iterations is None:
+        raise ValueError("max_iterations should be specified")
+    max_iterations = _to_python_scalar(max_iterations, int, "max_iteration")
+    loop_vars = _to_ndarray_tuple(loop_vars, "loop_vars")
+    # It would probably work fine even if loop_vars were empty,
+    # but it is semantically unnecessary to support this case.
+    if len(loop_vars) == 0:
+        raise ValueError("loop_vars should contain at least one element")
+
+    steps = 0
+    outputs = []
+    while steps < max_iterations and \
+            _to_python_scalar(cond(*loop_vars), bool, "Return value of cond"): 
# loop condition
+        step_output, loop_vars = _func_wrapper(loop_vars)
+        outputs.append(step_output)
+        steps += 1
+        if len(outputs) != steps or len(step_output) != len(outputs[0]):
+            raise ValueError("Number of elements in step_output should be the 
same in each step")
+    stacked_outputs = []
+    for i_th, items in enumerate(zip(*outputs), 1):
+        # `mx.ndarray.pad` only supports 4-D or 5-D inputs for now,
+        # so we cannot use it here.
+        items = [x.expand_dims(0) for x in items]
+        if steps != max_iterations and items:
+            pad_shape = [max_iterations - steps] + list(items[0].shape[1: ])
+            pad = ndarray.empty(
+                shape=pad_shape,
+                ctx=items[0].context,
+                dtype=items[0].dtype,
+            )
+            items = list(items) + [pad]
+        try:
+            stacked_outputs.append(ndarray.op.concat(*items, dim=0))
+        except ValueError:
+            raise ValueError("\n".join(
+                ["Shapes of %d-th elements in step_outputs are inconsistent, 
which are:" % i_th] +
+                ["  Step %d, shape is %s" % (i, str(x.shape)) for i, x in 
enumerate(items)]
+            ))
+    return stacked_outputs, list(loop_vars)
diff --git a/python/mxnet/symbol/contrib.py b/python/mxnet/symbol/contrib.py
index 28bb507..2c11921 100644
--- a/python/mxnet/symbol/contrib.py
+++ b/python/mxnet/symbol/contrib.py
@@ -34,7 +34,7 @@ from ..base import _LIB, check_call
 from ..base import SymbolHandle, _as_list
 from ..attribute import AttrScope
 
-__all__ = ["rand_zipfian", "foreach"]
+__all__ = ["rand_zipfian", "foreach", "while_loop"]
 
 def rand_zipfian(true_classes, num_sampled, range_max):
     """Draw random samples from an approximately log-uniform or Zipfian 
distribution.
@@ -336,3 +336,223 @@ def foreach(body, data, init_states, name="foreach"):
         states = states[0]
 
     return (outs, states)
+
+def while_loop(cond, func, loop_vars, max_iterations=None, name="while_loop"):
+    """Run a while loop with user-defined computation and loop condition.
+
+    This operator simulates a while loop that iteratively performs customized computation
+    as long as the condition is satisfied.
+
+    `loop_vars` is a list of Symbols that the computation uses.
+
+    `cond` is a user-defined function, used as the loop condition.
+    It consumes `loop_vars`, and produces a scalar MXNet symbol,
+    indicating the termination of the loop.
+    The loop ends when `cond` returns false (zero).
+    The `cond` is variadic, and its signature should be
+    `cond(*loop_vars) => Symbol`.
+
+    `func` is a user-defined function, used as the loop body.
+    It also consumes `loop_vars`, and produces `step_output` and `new_loop_vars` at each step.
+    In each step, `step_output` should contain the same number of elements.
+    Through all steps, the i-th element of `step_output` should have the same shape and dtype.
+    Also, `new_loop_vars` should contain the same number of elements as `loop_vars`,
+    and the corresponding elements should have the same shape and dtype.
+    `func` is variadic, and its signature should be
+    `func(*loop_vars) => (List[Symbol] step_output, List[Symbol] new_loop_vars)`.
+
+    `max_iterations` is a scalar that defines the maximum number of iterations allowed.
+
+    This function returns two lists.
+    The first list has the length of `|step_output|`,
+    and its i-th element is the stack (along axis 0) of the i-th elements of
+    `step_output` from all steps.
+    The second list has the length of `|loop_vars|`,
+    and represents the final states of the loop variables.
+
+    .. warning::
+
+       For now, the length of axis 0 of all Symbols in the first list is `max_iterations`,
+       due to the lack of dynamic shape inference.
+
+    .. warning::
+
+       Even if `cond` is never satisfied,
+       while_loop returns a list of outputs with inferred dtype and shape.
+       This is different from the NDArray version,
+       where in this case `step_outputs` are assumed to be an empty list.
+
+    Parameters
+    ----------
+    cond: a Python function.
+        The loop condition.
+    func: a Python function.
+        The loop body.
+    loop_vars: list of Symbol.
+        The initial values of the loop variables.
+    max_iterations: a python int.
+        Maximum number of iterations.
+
+    Returns
+    -------
+    outputs: list of Symbols
+        stacked output from each step
+    states: list of Symbols
+        final state
+
+    Examples
+    --------
+    >>> cond = lambda i, s: i <= 5
+    >>> func = lambda i, s: ([i + s], [i + 1, s + i])
+    >>> loop_vars = (mx.sym.var('i'), mx.sym.var('s'))
+    >>> outputs, states = mx.sym.contrib.while_loop(cond, func, loop_vars, 
max_iterations=10)
+    """
+    def _to_python_scalar(inputs, type_, name):
+        """Converts "inputs", possibly typed mxnet NDArray, a numpy ndarray, 
other python types,
+        to the given type
+        """
+        if hasattr(inputs, "asscalar"):
+            inputs = inputs.asscalar()
+        try:
+            inputs = type_(inputs)
+        except:
+            raise ValueError("Cannot convert %s to python %s" % (name, 
type_.__name__))
+        return inputs
+
+    def _to_symbol_tuple(inputs, name):
+        """Converts "inputs", possibly a single mxnet Symbol, a list of mxnet 
Symbol,
+        a tuple of mxnet Symbol, into a tuple of Symbol
+        """
+        if isinstance(inputs, list):
+            inputs = tuple(inputs)
+        if isinstance(inputs, Symbol):
+            inputs = (inputs, )
+        if not isinstance(inputs, tuple):
+            raise ValueError("%s must be a Symbol, or a tuple or list of 
Symbol" % (name, ))
+        for item in inputs:
+            if not isinstance(item, Symbol):
+                raise ValueError("%s must be a Symbol, or a tuple or list of 
Symbol" % (name, ))
+        return inputs
+
+    def _cond_wrapper(loop_vars):
+        result = cond(*loop_vars)
+        if not isinstance(result, Symbol):
+            raise ValueError("Return of cond must be a Symbol")
+        return [], [result]
+
+    def _func_wrapper(loop_vars):
+        """This wrapper unifies
+             "func: loop_vars -> new_loop_vars"
+         and "func: loop_vars -> (step_output, new_loop_vars)"
+        into "func: loop_vars -> (list of step_outputs, tuple of new_loop_vars)
+        """
+        step_output, new_loop_vars = func(*loop_vars)
+        if step_output is None:
+            step_output = []
+        if new_loop_vars is None:
+            new_loop_vars = []
+        step_output = _to_symbol_tuple(step_output, "step_output")
+        new_loop_vars = _to_symbol_tuple(new_loop_vars, "new_loop_vars")
+        if len(loop_vars) != len(new_loop_vars):
+            raise ValueError("The number of loop_vars should be consistent 
during the loop")
+        return list(step_output), list(new_loop_vars)
+
+    def _create_subgraph(graph_vars, graph_func, subgraph_name):
+        with AttrScope(__subgraph_name__=subgraph_name):
+            # create new variables with the same name,
+            # then feed them to the given func
+            new_graph_vars = [symbol.var(sym.name) for sym in graph_vars]
+            outputs, final_state = graph_func(new_graph_vars)
+            # first `num_out_data` elements belong to `outputs`
+            # other elements belong to `final_state`
+            num_out_data = len(outputs)
+            num_outputs = len(outputs) + len(final_state)
+            # nnvm cut-graph does not allow inputs and outputs to overlap,
+            # so we collect the names of the inputs, and copy an output whenever it overlaps with the inputs
+            all_input_names = symbol.Group(outputs + final_state).list_inputs()
+            make_identity = lambda x: symbol.op.identity(x) if x.name in 
all_input_names else x
+            # group all outputs of graph_func
+            graph = symbol.Group(list(map(make_identity, outputs + 
final_state)))
+        return graph, num_out_data, num_outputs
+
+    def _union_inputs(*graphs):
+        # Given a list of graphs, each of whose inputs are either from loop_vars or other variables.
+        # 1) calculate a list `inputs`, the union of their inputs.
+        # 2) for each graph, determine in which indices their inputs reside in 
`inputs`
+        # 3) for each variable in the input of `graph`, find which index it is
+        inputs = []             # List[Symbol], result of 1)
+        locs = []               # List[Tuple(List[Int], List[Int])], a list of 
tuples,
+                                # where tuples are results of 2) and 3)
+        input_id_to_loc = {}    # Dict[int, int], given id(sym), 
input_id_to_loc maps it
+                                # to a `loc`, where inputs[loc] = sym
+        for graph in graphs:
+            # input_syms: all inputs to the `graph`
+            name_to_input_syms = {sym.name: sym for sym in 
_get_graph_inputs(graph)}
+            # some loop_vars are inputs to `graph`, some are not
+            name_to_loop_vars = {sym.name: sym for sym in loop_vars}
+            # other inputs to `graph` created by cut_graph
+            name_to_cut_g_syms = {sym.list_outputs()[0]: sym for sym in 
_cut_subgraph(graph)}
+            # also we collect the mapping from var's name to var's loc in 
loop_vars
+            name_to_var_locs = {sym.name: i for i, sym in enumerate(loop_vars)}
+            # collect arguments for each subgraph
+            input_locs = []                         # results from the second 
step
+            var_locs = [-1] * len(loop_vars)        # results from the third 
step
+            for name in graph.list_inputs():
+                assert name in name_to_input_syms   # it should obviously hold
+                # name -> sym
+                if name in name_to_loop_vars:
+                    sym = name_to_loop_vars[name]
+                elif name in name_to_cut_g_syms:
+                    sym = name_to_cut_g_syms[name]
+                else:
+                    sym = copy.deepcopy(name_to_input_syms[name])
+                # do 2), and 1) is implicitly done
+                if id(sym) in input_id_to_loc:
+                    loc = input_id_to_loc[id(sym)]
+                else:
+                    loc = len(input_id_to_loc)
+                    inputs.append(sym)
+                    input_id_to_loc[id(sym)] = loc
+                input_locs.append(loc)
+                # do 3)
+                if name in name_to_var_locs:
+                    var_locs[name_to_var_locs[name]] = len(input_locs) - 1
+            locs.append((input_locs, var_locs))
+        return inputs, locs
+    if max_iterations is None:
+        raise ValueError("max_iterations should be specified")
+    max_iterations = _to_python_scalar(max_iterations, int, "max_iteration")
+    loop_vars = _to_symbol_tuple(loop_vars, "loop_vars")
+    # It would probably work fine even if loop_vars were empty,
+    # but it is semantically unnecessary to support this case.
+    if len(loop_vars) == 0:
+        raise ValueError("loop_vars should contain at least one element")
+    # create graph for `cond'
+    cond_g, num_out_data, num_outputs = \
+        _create_subgraph(loop_vars, _cond_wrapper, name + "_cond")
+    assert num_out_data == 0
+    assert num_outputs == 1
+    # create graph for `func`
+    func_g, num_out_data, num_outputs = \
+        _create_subgraph(loop_vars, _func_wrapper, name + "_func")
+    # find symbols used in either cond_g or func_g
+    input_syms, ((cond_input_locs, _), (func_input_locs, func_var_locs)) = \
+        _union_inputs(cond_g, func_g)
+    for i_th, loc in enumerate(func_var_locs, 1):
+        if loc == -1:
+            raise ValueError("The %d-th loop_var doesn't involve into the 
computation" % i_th)
+    result = symbol._internal._while_loop(
+        # [cond, func_g, *input_syms]
+        cond_g,
+        func_g,
+        *input_syms,
+        max_iterations=max_iterations,
+        cond_input_locs=cond_input_locs,
+        func_input_locs=func_input_locs,
+        func_var_locs=func_var_locs,
+        num_out_data=num_out_data,
+        num_outputs=num_outputs
+    )
+    outputs = [result[i] for i in range(num_out_data)]
+    final_loop_vars = [result[i] for i in range(num_out_data, num_outputs)]
+    return outputs, final_loop_vars
diff --git a/src/operator/control_flow.cc b/src/operator/control_flow.cc
index c091fdb..b00ed9b 100644
--- a/src/operator/control_flow.cc
+++ b/src/operator/control_flow.cc
@@ -480,6 +480,503 @@ ForeachGradient(const nnvm::NodePtr& n, const 
std::vector<nnvm::NodeEntry>& ogra
   return entries;
 }
 
+struct WhileLoopParam : public dmlc::Parameter<WhileLoopParam> {
+  int num_args;
+  int num_outputs;
+  int num_out_data;
+  int max_iterations;
+  // `cond' and `func' each takes a subset of while_loop's inputs as the inputs to their subgraphs.
+  // `cond_input_locs' contains indices of inputs fed to `cond', and
+  // `func_input_locs' contains indices of inputs fed to `func'.
+  // `func_var_locs' are the indices at which the loop "variables" are stored among func's inputs.
+  nnvm::Tuple<dim_t> cond_input_locs;
+  nnvm::Tuple<dim_t> func_input_locs;
+  nnvm::Tuple<dim_t> func_var_locs;
+  DMLC_DECLARE_PARAMETER(WhileLoopParam) {
+    DMLC_DECLARE_FIELD(num_args).set_lower_bound(2)
+    .describe("Number of input arguments, including cond and func as two 
symbol inputs.");
+    DMLC_DECLARE_FIELD(num_outputs).set_lower_bound(1)
+    .describe("The number of outputs of the subgraph.");
+    DMLC_DECLARE_FIELD(num_out_data).set_lower_bound(0)
+    .describe("The number of outputs from the function body.");
+    DMLC_DECLARE_FIELD(max_iterations).set_lower_bound(1)
+    .describe("Maximum number of iterations.");
+    DMLC_DECLARE_FIELD(cond_input_locs)
+    .describe("The locations of cond's inputs in the given inputs.");
+    DMLC_DECLARE_FIELD(func_input_locs)
+    .describe("The locations of func's inputs in the given inputs.");
+    DMLC_DECLARE_FIELD(func_var_locs)
+    .describe("The locations of loop_vars among func's inputs.");
+  }
+};  // struct WhileLoopParam
+
+DMLC_REGISTER_PARAMETER(WhileLoopParam);
+
+class WhileLoopState: public LoopState {
+ public:
+  WhileLoopParam params;
+  size_t n_iterations;  // the actual number of steps taken in this while loop, <= max_iterations
+  CachedOpPtr cond_op;
+  // abbrev for output_input_mapping
+  // indicates the index in `cond's inputs to which each output of `func' will be copied
+  std::vector<int> oi_map;
+
+  WhileLoopState(const WhileLoopParam &params, const Symbol &cond, const 
Symbol &func) :
+                 LoopState(func),
+                 params(params),
+                 n_iterations(0U),
+                 cond_op(LoopState::MakeSharedOp(cond)),
+                 oi_map(params.func_var_locs.ndim(), -1) {
+    const nnvm::Tuple<dim_t> &func_input_locs = params.func_input_locs;
+    const nnvm::Tuple<dim_t> &func_var_locs = params.func_var_locs;
+    const nnvm::Tuple<dim_t> &cond_input_locs = params.cond_input_locs;
+    for (size_t i = 0; i < func_var_locs.ndim(); ++i) {
+      dim_t pos_i = func_input_locs[func_var_locs[i]];
+      for (size_t j = 0; j < cond_input_locs.ndim(); ++j) {
+        dim_t pos_j = cond_input_locs[j];
+        if (pos_i == pos_j) {
+          this->oi_map[i] = j;
+        }
+      }
+    }
+  }
+  template <typename T>
+  static void extract_by_loc(const std::vector<T> &array,
+                             const nnvm::Tuple<dim_t> input_locs,
+                             std::vector<T> *out) {
+    out->clear();
+    out->reserve(input_locs.ndim());
+    for (dim_t i : input_locs) {
+      out->push_back(array[i]);
+    }
+  }
+  static bool is_shape_udf(const TShape &x) {
+    return x.ndim() == 0 || x.Size() == 0;
+  }
+  static bool is_stype_udf(const int &x) {
+    return x == exec::kBadStorageID;
+  }
+  static bool is_type_udf(const int &x) {
+    return x == -1;
+  }
+  template <typename T>
+  static bool fill_value(T *x, T *y, bool x_empty, bool y_empty) {
+    if (*x == *y || (x_empty && y_empty)) {
+      return true;
+    }
+    if (!x_empty && !y_empty) {
+      return false;
+    }
+    if (x_empty) {
+      *x = *y;
+    }
+    if (y_empty) {
+      *y = *x;
+    }
+    return true;
+  }
+  template <typename T>
+  static bool sync_in_in(const nnvm::Tuple<dim_t> &input_locs,
+                         std::vector<T> *in,
+                         std::vector<T> *subg_in,
+                         std::function<bool(const T &)> is_empty) {
+    for (size_t i = 0; i < input_locs.ndim(); ++i) {
+      T &x = in->at(input_locs[i]);
+      T &y = subg_in->at(i);
+      fill_value(&x, &y, is_empty(x), is_empty(y));
+    }
+    return true;
+  }
+  template <typename T>
+  static bool sync_in_out(const WhileLoopParam& params,
+                          std::vector<T> *in,
+                          std::vector<T> *out,
+                          std::function<bool(const T &)> is_empty) {
+    for (int i = params.num_out_data; i < params.num_outputs; ++i) {
+      // each out->at(i) is a params, loop_var
+      T &x = in->at(params.func_input_locs[params.func_var_locs[i - 
params.num_out_data]]);
+      T &y = out->at(i);
+      fill_value(&x, &y, is_empty(x), is_empty(y));
+    }
+    return true;
+  }
+};
+
+template <typename T>
+T _asscalar(const NDArray &a) {
+  CHECK_EQ(a.shape().Size(), 1U);
+  T data;
+  a.SyncCopyToCPU(&data, 1U);
+  return data;
+}
+
+bool as_bool_scalar(const NDArray &a) {
+  MSHADOW_TYPE_SWITCH(a.dtype(), DType, {
+    return static_cast<bool>(_asscalar<DType>(a));
+  });
+  LOG(FATAL) << "Unknown dtype";
+  return false;
+}
+
+static void WhileLoopComputeExCPU(const OpStatePtr& state_ptr,
+                                  const OpContext& ctx,
+                                  const std::vector<NDArray>& inputs,
+                                  const std::vector<OpReqType>& req,
+                                  const std::vector<NDArray>& outputs) {
+  // The argument `inputs' contains loop_vars and other inputs
+  // loop_vars are stored in `loop_vars_locs'
+  // The argument `outputs' are output and new_loop_vars
+  // [0: num_out_data) are outputs at each step.
+  // [num_out_data: ) are new_loop_vars
+  // TODO(Junru): avoid dynamic NDArray allocation
+  WhileLoopState &state = state_ptr.get_state<WhileLoopState>();
+  const WhileLoopParam& params = state.params;
+  // a helper function, converting std::vector<NDArray> to 
std::vector<NDArray*>
+  const auto to_ptr_vec = [](std::vector<NDArray> &in, std::vector<NDArray*> 
*out) {
+    out->clear();
+    out->reserve(in.size());
+    std::transform(std::begin(in),
+                   std::end(in),
+                   std::back_inserter(*out),
+                   [](NDArray &a) {return &a;});
+  };
+  // sanity checks
+  CHECK_EQ(inputs.size() + 2U, (size_t) params.num_args);
+  CHECK_EQ(outputs.size(), (size_t) params.num_outputs);
+  CHECK_EQ(outputs.size(), req.size());
+  for (size_t i = 0; i < (size_t) params.num_out_data; i++)
+    CHECK_EQ(params.max_iterations, outputs[i].shape()[0]);
+  // construct inputs and outputs for cond
+  std::vector<NDArray> cond_inputs, cond_outputs = {NDArray()};
+  WhileLoopState::extract_by_loc(inputs, params.cond_input_locs, &cond_inputs);
+  std::vector<NDArray*> cond_input_ptr, cond_output_ptr;
+  to_ptr_vec(cond_inputs, &cond_input_ptr);
+  to_ptr_vec(cond_outputs, &cond_output_ptr);
+  // construct inputs and outputs for func
+  std::vector<NDArray> func_inputs, func_outputs(outputs.size());
+  WhileLoopState::extract_by_loc(inputs, params.func_input_locs, &func_inputs);
+  for (size_t &step = state.n_iterations = 0; step < (size_t) 
params.max_iterations; ++step) {
+    state.cond_op->Forward(nullptr, cond_input_ptr, cond_output_ptr);
+    if (!as_bool_scalar(*cond_output_ptr[0])) {
+      break;
+    }
+    // we create func_outputs for the current step:
+    // func_outputs[0: num_out_data] is a slice of outputs[][step]
+    for (size_t i = 0; i < (size_t) params.num_out_data; ++i) {
+      func_outputs[i] = outputs[i].At(step);
+    }
+    // func_outputs[num_out_data: ] are new_loop_vars, need to allocate new 
memory
+    for (size_t i = params.num_out_data; i < outputs.size(); ++i) {
+      func_outputs[i] = NDArray(outputs[i].shape(), outputs[i].ctx(), true, 
outputs[i].dtype());
+    }
+    state.Forward(step, func_inputs, req, func_outputs, ctx.need_grad);
+    // func_inputs on the next step:
+    // the output (new_loop_vars) will become the new inputs (loop_vars)
+    for (size_t i = params.num_out_data; i < outputs.size(); ++i) {
+      size_t j = params.func_var_locs[i - params.num_out_data];
+      CHECK_EQ(func_inputs[j].shape(), func_outputs[i].shape());
+      func_inputs[j] = func_outputs[i];
+      int k = state.oi_map[i - params.num_out_data];
+      if (k != -1) {
+        // I actually don't need to update cond_inputs
+        cond_inputs[k] = func_outputs[i];
+        cond_input_ptr[k] = &func_outputs[i];
+      }
+    }
+  }
+  // copy output data to `outputs'
+  // case 1: at least one step is executed,
+  // the final_loop_vars must be stored in func_inputs
+  // case 2: no step is executed
+  // the final_loop_vars is the same as loop_vars, which are also stored in 
func_inputs
+  // therefore, we copy func_inputs[:] to outputs[num_out_data: ]
+  for (size_t i = params.num_out_data; i < outputs.size(); ++i) {
+    size_t j = params.func_var_locs[i - params.num_out_data];
+    mxnet::CopyFromTo(func_inputs[j], &outputs[i]);
+  }
+}
+
+static void WhileLoopGradComputeExCPU(const OpStatePtr& state_ptr,
+                                      const OpContext& ctx,
+                                      const std::vector<NDArray>& inputs,
+                                      const std::vector<OpReqType>& _req,
+                                      const std::vector<NDArray>& _outputs) {
+  // inputs are dl / df(x)
+  // outputs are dl / dx
+  // where f is the current function,
+  // x is the input to the current function,
+  // TODO(Junru): avoid dynamic NDArray allocation
+  WhileLoopState &state = state_ptr.get_state<WhileLoopState>();
+  const WhileLoopParam& params = state.params;
+  // sanity checks
+  CHECK_EQ(_outputs.size() + 2U, (size_t) params.num_args);
+  CHECK_EQ(_outputs.size(), _req.size());
+  for (auto x : _req) {
+    CHECK_NE(x, kWriteInplace);
+  }
+  std::vector<NDArray> outputs;
+  std::vector<OpReqType> req;
+  WhileLoopState::extract_by_loc(_outputs, params.func_input_locs, &outputs);
+  WhileLoopState::extract_by_loc(_req, params.func_input_locs, &req);
+  if (state.n_iterations == 0) {
+    for (int i = params.num_out_data; i < params.num_outputs; ++i) {
+      int j = params.func_var_locs[i - params.num_out_data];
+      mxnet::CopyFromTo(inputs[i], &outputs[j]);
+    }
+    state.Cleanup();
+    return;
+  }
+  // collect var_locs and out_locs, positions other than var_locs are 
out_locs, i.e.
+  // [0, var_locs[0])
+  // (var_locs[1], var_locs[2]),
+  // (var_locs[2], var_locs[3]),
+  // ...
+  // (var_locs[-2], var_locs[-1] = params.num_args - 2)
+  std::vector<dim_t> var_locs(params.func_var_locs.begin(), 
params.func_var_locs.end());
+  var_locs.push_back((dim_t) params.num_args - 2U);
+  sort(var_locs.begin(), var_locs.end());
+  // vectors for the backward loop
+  std::vector<NDArray> ograds(params.num_outputs);
+  std::vector<NDArray> igrads(outputs.size());
+  std::vector<OpReqType> iter_req(req.size());
+  for (int i = params.num_out_data; i < params.num_outputs; ++i)
+    ograds[i] = inputs[i];
+  const int n_iter = state.n_iterations;
+  for (int step = n_iter - 1; step >= 0; --step) {
+    // ograds[ : num_out_data] = inputs[ : num_out_data][step]
+    // ograds[num_out_data: ] is maintained in the end of each loop
+    std::transform(std::begin(inputs),
+                   std::begin(inputs) + params.num_out_data,
+                   std::begin(ograds),
+                   [step] (const NDArray &a) { return a.At(step); } );
+    // igrads[i] =
+    //    outputs[i]            (step == 0)
+    //    outputs[i]            (step != 0 && i not in loop_var_locs)
+    //    ArrayLike(outputs[i]) (step != 0 && i in loop_var_locs)
+    // iter_req =
+    //    kWriteTo              (step != 0           && i in loop_var_locs)
+    //    req[i]                (step == 0           && i in loop_var_locs)
+    //    kAddTo                (step != n_iters - 1 && i not in loop_var_locs)
+    //    req[i]                (step == n_iters - 1 && i not in loop_var_locs)
+    {
+      size_t i = 0;
+      for (size_t loc : var_locs) {
+        for ( ; i < loc; ++i) {
+          // locs other than var_locs
+          igrads[i] = outputs[i];
+          iter_req[i] = (step + 1 == n_iter || req[i] == kNullOp)
+                      ? req[i]
+                      : kAddTo;
+        }
+        if (i < (size_t) params.num_args - 2U) {
+          // a var
+          igrads[i] = (step == 0)
+                    ? outputs[i]
+                    : NDArray(outputs[i].shape(), outputs[i].ctx(), true, 
outputs[i].dtype());
+          iter_req[i] = (step == 0 || req[i] == kNullOp)
+                      ? req[i]
+                      : kWriteTo;
+          ++i;
+        } else {
+          break;
+        }
+      }
+    }
+    state.Backward(step, ograds, iter_req, igrads);
+    for (int i = params.num_out_data; i < params.num_outputs; ++i) {
+      size_t j = params.func_var_locs[i - params.num_out_data];
+      ograds[i] = igrads[j];
+    }
+  }
+  state.Cleanup();
+}
+
+static bool WhileLoopShape(const nnvm::NodeAttrs& attrs,
+                           std::vector<TShape> *in_shape,
+                           std::vector<TShape> *out_shape) {
+  using nnvm::ShapeVector;
+  const WhileLoopParam& params = nnvm::get<WhileLoopParam>(attrs.parsed);
+  static const std::function<bool(const TShape &)> is_udf = 
WhileLoopState::is_shape_udf;
+  // sanity checks
+  CHECK_EQ(in_shape->size() + 2U, (size_t) params.num_args);
+  CHECK_EQ(out_shape->size(), (size_t) params.num_outputs);
+  CHECK_EQ(attrs.subgraphs.size(), 2U);
+  CHECK_EQ(attrs.subgraphs[0]->outputs.size(), 1U);
+  // infer shape for cond and func
+  auto infer_subg = [&params, in_shape, out_shape](std::shared_ptr<Symbol> 
subg,
+                                                   ShapeVector *_subg_out,
+                                                   const nnvm::Tuple<dim_t> 
&input_locs,
+                                                   int num_out_data,
+                                                   bool fill_out_shape) {
+    // create subg_in
+    ShapeVector subg_in;
+    ShapeVector &subg_out = *_subg_out;
+    WhileLoopState::extract_by_loc(*in_shape, input_locs, &subg_in);
+    // create an indexed graph
+    nnvm::Graph g;
+    g.outputs = subg->outputs;
+    const auto& idx = g.indexed_graph();
+    // get input nodes
+    const auto &input_nids = idx.input_nodes();
+    // sanity checks
+    CHECK_EQ(input_nids.size(), subg_in.size());
+    CHECK_EQ(g.outputs.size(), subg_out.size());
+    CHECK_EQ(idx.input_nodes().size(), subg_in.size());
+    CHECK_EQ(idx.outputs().size(), subg_out.size());
+    // create empty shapes for inference
+    ShapeVector shapes(idx.num_node_entries());
+    // copy subg_in into shapes
+    for (size_t i = 0; i < subg_in.size(); ++i) {
+      auto eid = idx.entry_id(input_nids[i], 0);
+      shapes[eid] = subg_in[i];
+    }
+    // copy subg_out into shapes
+    // note that ndim of out_data is not increased
+    // because subg is only one step
+    for (size_t i = 0; i < subg_out.size(); ++i) {
+      auto eid = idx.entry_id(g.outputs[i]);
+      shapes[eid] = subg_out[i];
+    }
+    // copy done, call InferShape
+    g.attrs["shape"] = std::make_shared<dmlc::any>(std::move(shapes));
+    g = exec::InferShape(std::move(g));
+    // now `shapes' won't be used anymore, use new_shapes instead
+    const auto& new_shapes = g.GetAttr<ShapeVector>("shape");
+    // copy subg_in back to in_shape
+    for (size_t i = 0; i < subg_in.size(); ++i) {
+      auto eid = idx.entry_id(input_nids[i], 0);
+      auto g_out_shape = new_shapes[eid];
+      if (g_out_shape.ndim() == 0 || g_out_shape.Size() == 0) {
+        // when the shape is not fully inferred
+        continue;
+      }
+      SHAPE_ASSIGN_CHECK(*in_shape, input_locs[i], g_out_shape);
+    }
+    if (!fill_out_shape) {
+      return true;
+    }
+    // copy subg_out back to out_shape
+    // for results in [0, num_out_data), ndim should increase by 1
+    for (int i = 0; i < num_out_data; ++i) {
+      auto eid = idx.entry_id(g.outputs[i]);
+      auto g_out_shape = new_shapes[eid];
+      if (g_out_shape.ndim() == 0 || g_out_shape.Size() == 0) {
+        // when the shape is not fully inferred
+        continue;
+      }
+      auto out = TShape(g_out_shape.ndim() + 1);
+      out[0] = params.max_iterations;
+      for (size_t i = 1; i < out.ndim(); i++)
+        out[i] = g_out_shape[i - 1];
+      SHAPE_ASSIGN_CHECK(*out_shape, i, out);
+    }
+    // for results in [num_out_data, ...), ndim does not change
+    for (size_t i = num_out_data; i < g.outputs.size(); ++i) {
+      auto eid = idx.entry_id(g.outputs[i]);
+      auto g_out_shape = new_shapes[eid];
+      if (g_out_shape.ndim() == 0 || g_out_shape.Size() == 0) {
+        // when the shape is not fully inferred
+        continue;
+      }
+      SHAPE_ASSIGN_CHECK(*out_shape, i, g_out_shape);
+    }
+    return g.GetAttr<size_t>("shape_num_unknown_nodes") == 0;
+  };
+  ShapeVector cond_out_shape{TShape(1U)};  // this means: [(1, )]
+  ShapeVector func_out_shape(params.num_outputs);
+  CHECK(WhileLoopState::sync_in_out(params, in_shape, out_shape, is_udf));
+  bool succ_0 = infer_subg(attrs.subgraphs[0], &cond_out_shape, params.cond_input_locs, 0, false);
+  CHECK(WhileLoopState::sync_in_out(params, in_shape, out_shape, is_udf));
+  bool succ_1 = infer_subg(attrs.subgraphs[1], &func_out_shape, \
+                           params.func_input_locs, params.num_out_data, true);
+  CHECK(WhileLoopState::sync_in_out(params, in_shape, out_shape, is_udf));
+  return succ_0 && succ_1;
+}
+
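To make the shape rule above concrete, here is an illustrative check (not part of this patch; variable names and shapes are arbitrary) showing that each per-step output gains a leading axis of length max_iterations, while the final loop states keep their per-step shape:

    import mxnet as mx

    i = mx.sym.var("i")
    s = mx.sym.var("s")
    outs, states = mx.sym.contrib.while_loop(
        cond=lambda i, s: i < 5,
        func=lambda i, s: (s, (i + 1, s + i)),   # one per-step output, two states
        loop_vars=(i, s),
        max_iterations=10,
    )
    _, out_shapes, _ = outs[0].infer_shape(i=(1,), s=(2,))
    print(out_shapes)   # expected [(10, 2)]: the leading 10 comes from max_iterations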
+static bool WhileLoopType(const nnvm::NodeAttrs& attrs,
+                          std::vector<int> *in_type, std::vector<int> *out_type) {
+  const WhileLoopParam& params = nnvm::get<WhileLoopParam>(attrs.parsed);
+  static const std::function<bool(const int &)> is_udf = WhileLoopState::is_type_udf;
+  CHECK_EQ(in_type->size() + 2U, (size_t) params.num_args);
+  CHECK_EQ(out_type->size(), (size_t) params.num_outputs);
+  CHECK_EQ(attrs.subgraphs.size(), 2U);
+  CHECK_EQ(attrs.subgraphs[0]->outputs.size(), 1U);
+  std::vector<int> cond_in_type;
+  std::vector<int> func_in_type;
+  WhileLoopState::extract_by_loc(*in_type, params.cond_input_locs, &cond_in_type);
+  WhileLoopState::extract_by_loc(*in_type, params.func_input_locs, &func_in_type);
+  std::vector<int> cond_out_type = {0};
+  CHECK(WhileLoopState::sync_in_out(params, in_type, out_type, is_udf));
+  bool succ_0 = InferSubgraphDataType(*attrs.subgraphs[0], &cond_in_type, &cond_out_type);
+  CHECK(WhileLoopState::sync_in_out(params, in_type, out_type, is_udf));
+  CHECK(WhileLoopState::sync_in_in(params.cond_input_locs, in_type, &cond_in_type, is_udf));
+  bool succ_1 = InferSubgraphDataType(*attrs.subgraphs[1], &func_in_type, out_type);
+  CHECK(WhileLoopState::sync_in_out(params, in_type, out_type, is_udf));
+  CHECK(WhileLoopState::sync_in_in(params.func_input_locs, in_type, &func_in_type, is_udf));
+  return succ_0 && succ_1;
+}
+
+static bool WhileLoopStorageType(const nnvm::NodeAttrs& attrs,
+                                 const int dev_mask,
+                                 DispatchMode* dispatch_mode,
+                                 std::vector<int> *in_attrs,
+                                 std::vector<int> *out_attrs) {
+  const WhileLoopParam& params = nnvm::get<WhileLoopParam>(attrs.parsed);
+  static const std::function<bool(const int &)> is_udf = WhileLoopState::is_stype_udf;
+  CHECK_EQ(in_attrs->size() + 2U, (size_t) params.num_args);
+  CHECK_EQ(out_attrs->size(), (size_t) params.num_outputs);
+  CHECK_EQ(attrs.subgraphs.size(), 2U);
+  CHECK_EQ(attrs.subgraphs[0]->outputs.size(), 1U);
+  std::vector<int> cond_in_attrs;
+  std::vector<int> func_in_attrs;
+  WhileLoopState::extract_by_loc(*in_attrs, params.cond_input_locs, &cond_in_attrs);
+  WhileLoopState::extract_by_loc(*in_attrs, params.func_input_locs, &func_in_attrs);
+  std::vector<int> cond_out_attrs = {kDefaultStorage};
+  DispatchMode cond_mode = DispatchMode::kUndefined;
+  DispatchMode func_mode = DispatchMode::kUndefined;
+  *dispatch_mode = DispatchMode::kFComputeEx;
+  CHECK(WhileLoopState::sync_in_out(params, in_attrs, out_attrs, is_udf));
+  bool succ_0 = InferSubgraphStorage(*attrs.subgraphs[0], dev_mask, \
+                                     &cond_mode, &cond_in_attrs, &cond_out_attrs);
+  CHECK(WhileLoopState::sync_in_out(params, in_attrs, out_attrs, is_udf));
+  CHECK(WhileLoopState::sync_in_in(params.cond_input_locs, in_attrs, &cond_in_attrs, is_udf));
+  bool succ_1 = InferSubgraphStorage(*attrs.subgraphs[1], dev_mask, \
+                                     &func_mode, &func_in_attrs, out_attrs);
+  CHECK(WhileLoopState::sync_in_out(params, in_attrs, out_attrs, is_udf));
+  CHECK(WhileLoopState::sync_in_in(params.func_input_locs, in_attrs, &func_in_attrs, is_udf));
+  return succ_0 && succ_1;
+}
+
+static bool BackwardWhileLoopStorageType(const nnvm::NodeAttrs& attrs,
+                                         const int dev_mask,
+                                         DispatchMode* dispatch_mode,
+                                         std::vector<int> *in_attrs,
+                                         std::vector<int> *out_attrs) {
+  // `cond' is not backwarded, don't check
+  const WhileLoopParam& params = nnvm::get<WhileLoopParam>(attrs.parsed);
+  CHECK_EQ(out_attrs->size() + 2U, (size_t) params.num_args);
+  CHECK_EQ(attrs.subgraphs.size(), 2U);
+  CachedOp op(*attrs.subgraphs[1], {});
+  return op.BackwardStorageType(attrs, dev_mask, dispatch_mode,
+                                in_attrs, out_attrs);
+}
+
+static OpStatePtr CreateWhileLoopState(const NodeAttrs& attrs,
+                                       Context ctx,
+                                       const std::vector<TShape>& ishape,
+                                       const std::vector<int>& itype) {
+  const WhileLoopParam& params = nnvm::get<WhileLoopParam>(attrs.parsed);
+  return OpStatePtr::Create<WhileLoopState>(params, *attrs.subgraphs[0], *attrs.subgraphs[1]);
+}
+
+static std::vector<nnvm::NodeEntry>
+WhileLoopGradient(const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
+  ElemwiseGradUseInOut fgrad{"_backward_while_loop"};
+  std::vector<nnvm::NodeEntry> entries = fgrad(n, ograds);
+  entries[0].node->attrs.subgraphs = n->attrs.subgraphs;
+  return entries;
+}
+
 NNVM_REGISTER_OP(_foreach)
 .MXNET_DESCRIBE("Run a for loop over an NDArray with user-defined computation")
 .set_attr_parser(ParamParser<ForeachParam>)
@@ -526,11 +1023,11 @@ NNVM_REGISTER_OP(_backward_foreach)
 .set_num_inputs([](const NodeAttrs& attrs){
   const ForeachParam& params = nnvm::get<ForeachParam>(attrs.parsed);
   return params.num_outputs * 2 + params.num_args - 1;
-  })
+})
 .set_num_outputs([](const NodeAttrs& attrs){
   const ForeachParam& params = nnvm::get<ForeachParam>(attrs.parsed);
   return params.num_args - 1;
-  })
+})
 .set_attr<FExecType>("FExecType", [](const NodeAttrs& attrs) {
   return ExecType::kSubgraphExec;
 })
@@ -541,5 +1038,67 @@ NNVM_REGISTER_OP(_backward_foreach)
 .set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", ForeachGradComputeExCPU)
 .set_attr<FStatefulComputeEx>("FStatefulComputeEx<gpu>", ForeachGradComputeExCPU);
 
+NNVM_REGISTER_OP(_while_loop)
+.MXNET_DESCRIBE("Run a while loop over with user-defined condition and 
computation")
+.set_attr_parser(ParamParser<WhileLoopParam>)
+.set_attr<FInferStorageType>("FInferStorageType", WhileLoopStorageType)
+.set_num_inputs([](const NodeAttrs& attrs) {
+  const WhileLoopParam& params = nnvm::get<WhileLoopParam>(attrs.parsed);
+  return params.num_args;
+})
+.set_num_outputs([](const NodeAttrs& attrs) {
+  const WhileLoopParam& params = nnvm::get<WhileLoopParam>(attrs.parsed);
+  return params.num_outputs;
+})
+.set_attr<nnvm::FListInputNames>("FListInputNames",
+    [](const NodeAttrs& attrs) {
+  const WhileLoopParam& params = nnvm::get<WhileLoopParam>(attrs.parsed);
+  std::vector<std::string> names;
+  names.reserve(params.num_args);
+  names.push_back("cond");
+  names.push_back("func");
+  for (int i = 2; i < params.num_args; i++)
+    names.push_back("data" + std::to_string(i - 2));
+  return names;
+})
+.set_attr<nnvm::FInputGraph>("FInputGraph",
+    [](const NodeAttrs& attrs) {
+  return std::vector<uint32_t>{0, 1};
+})
+.set_attr<nnvm::FGradient>("FGradient", WhileLoopGradient)
+.set_attr<FCreateOpState>("FCreateOpState", CreateWhileLoopState)
+.set_attr<nnvm::FInferShape>("FInferShape", WhileLoopShape)
+.set_attr<nnvm::FInferType>("FInferType", WhileLoopType)
+.set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", WhileLoopComputeExCPU)
+.set_attr<FExecType>("FExecType", [](const NodeAttrs& attrs) {
+  return ExecType::kSubgraphExec;
+})
+.set_attr<FStatefulComputeEx>("FStatefulComputeEx<gpu>", WhileLoopComputeExCPU)
+.set_attr<std::string>("key_var_num_args", "num_args")
+.add_argument("cond", "Symbol", "Input graph for the loop condition.")
+.add_argument("func", "Symbol", "Input graph for the loop body.")
+.add_argument("data", "NDArray-or-Symbol[]",
+              "The input arrays that include data arrays and states.")
+.add_arguments(WhileLoopParam::__FIELDS__());
+
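For context, here is an illustrative use of the operator registered above through the imperative frontend added by this commit (values are arbitrary; only the first five rows of the padded output carry data):

    import mxnet as mx

    # Count i up to 5 while accumulating s; the per-step output records i.
    outputs, states = mx.nd.contrib.while_loop(
        cond=lambda i, s: i < 5,
        func=lambda i, s: (i, (i + 1, s + i)),
        loop_vars=(mx.nd.array([0]), mx.nd.array([0])),
        max_iterations=8,
    )
    print(outputs[0].shape)                  # (8, 1): padded to max_iterations rows,
                                             # matching the symbolic result's shape
    print([x.asscalar() for x in states])    # [5.0, 10.0]: final values of (i, s)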
+NNVM_REGISTER_OP(_backward_while_loop)
+.set_num_inputs([](const NodeAttrs& attrs){
+  const WhileLoopParam& params = nnvm::get<WhileLoopParam>(attrs.parsed);
+  return params.num_outputs * 2 + params.num_args - 2;
+})
+.set_num_outputs([](const NodeAttrs& attrs){
+  const WhileLoopParam& params = nnvm::get<WhileLoopParam>(attrs.parsed);
+  return params.num_args - 2;
+})
+.set_attr<FExecType>("FExecType", [](const NodeAttrs& attrs) {
+  return ExecType::kSubgraphExec;
+})
+.set_attr<FInferStorageType>("FInferStorageType", BackwardWhileLoopStorageType)
+.set_attr_parser(ParamParser<WhileLoopParam>)
+.set_attr<bool>("TIsLayerOpBackward", true)
+.set_attr<nnvm::TIsBackward>("TIsBackward", true)
+.set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", 
WhileLoopGradComputeExCPU)
+.set_attr<FStatefulComputeEx>("FStatefulComputeEx<gpu>", 
WhileLoopGradComputeExCPU);
+
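As a worked instance of the counts registered above: a loop with three loop variables has num_args = 2 + 3 = 5 (cond, func, and three data arguments); if its body emits one stacked output and three final states, num_outputs = 4. The backward node then takes num_outputs * 2 + num_args - 2 = 11 inputs (the head gradients, the forward outputs, and the three data inputs; cond and func are carried as subgraph attributes) and produces num_args - 2 = 3 outputs, one gradient per data argument.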
 }  // namespace op
 }  // namespace mxnet
diff --git a/src/operator/subgraph_op_common.cc b/src/operator/subgraph_op_common.cc
index 71a9a21..d845aa9 100644
--- a/src/operator/subgraph_op_common.cc
+++ b/src/operator/subgraph_op_common.cc
@@ -164,14 +164,7 @@ bool InferSubgraphShape(const nnvm::Symbol &subgraph,
 LoopState::LoopState(const Symbol &g) {
   this->subgraph_sym = g;
   this->subgraph.outputs = g.outputs;
-
-  std::vector<std::pair<std::string, std::string> > kwargs;
-  kwargs.push_back(std::pair<std::string, std::string>("inline_limit", "0"));
-  // We turn on static_alloc for two reasons.
-  // It avoids the overhead of unnecessary memory allocation.
-  // only static_alloc supports nested call of CachedOp.
-  kwargs.push_back(std::pair<std::string, std::string>("static_alloc", "1"));
-  iter_op = std::make_shared<CachedOp>(subgraph_sym, kwargs);
+  this->iter_op = LoopState::MakeSharedOp(g);
 }
 
 void LoopState::Forward(int iter_no,
diff --git a/src/operator/subgraph_op_common.h b/src/operator/subgraph_op_common.h
index 7907840..f73f09c 100644
--- a/src/operator/subgraph_op_common.h
+++ b/src/operator/subgraph_op_common.h
@@ -24,6 +24,8 @@
 #include <mxnet/base.h>
 #include <mxnet/op_attr_types.h>
 #include <vector>
+#include <utility>
+#include <string>
 #include "../imperative/cached_op.h"
 #include "../imperative/imperative_utils.h"
 
@@ -69,8 +71,8 @@ class LoopState {
   // For training, each iteration has a cached op because each iteration
   // needs to maintain a set of memory buffers for all computation states,
   // which will be used in the backward.
-  CachedOpPtr iter_op;
   std::vector<OpStatePtr> all_states;
+  CachedOpPtr iter_op;
   Symbol subgraph_sym;
   nnvm::Graph subgraph;
 
@@ -91,6 +93,16 @@ class LoopState {
     all_inputs.clear();
     all_states.clear();
   }
+  static CachedOpPtr MakeSharedOp(const Symbol &sym) {
+    // We turn on static_alloc for two reasons:
+    // it avoids the overhead of unnecessary memory allocation, and
+    // only static_alloc supports nested calls of CachedOp.
+    std::vector<std::pair<std::string, std::string> > kwargs = {
+      {"inline_limit", "0"},
+      {"static_alloc", "1"}
+    };
+    return std::make_shared<CachedOp>(sym, kwargs);
+  }
 };
 
 }  // namespace op
diff --git a/tests/python/unittest/test_contrib_control_flow.py b/tests/python/unittest/test_contrib_control_flow.py
new file mode 100644
index 0000000..9dd5c43
--- /dev/null
+++ b/tests/python/unittest/test_contrib_control_flow.py
@@ -0,0 +1,978 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import mxnet as mx
+from mxnet import gluon
+import numpy as np
+import copy
+from numpy.testing import assert_allclose
+import unittest
+from mxnet.test_utils import almost_equal, default_context
+from numpy.testing import assert_allclose as assert_almost_equal  # This is more restrictive
+from mxnet.base import _as_list
+
+
+def test_while_loop_simple_forward():
+
+    class _TestBlock(gluon.HybridBlock):
+
+        def __init__(self, cond, func, max_iterations):
+            super(_TestBlock, self).__init__()
+            self.cond = cond
+            self.func = func
+            self.max_iterations = max_iterations
+
+        def hybrid_forward(self, F, *loop_vars):
+            return F.contrib.while_loop(
+                cond=self.cond,
+                func=self.func,
+                loop_vars=loop_vars,
+                max_iterations=self.max_iterations
+            )
+
+    for hybridize in [False, True]:
+        # Case 1.1: result should be sum([1, 2, 3, 4, 5])
+        model = _TestBlock(
+            cond=lambda i, s: i <= 5,
+            func=lambda i, s: (None, (i + 1, s + i)),
+            max_iterations=10,
+        )
+        if hybridize:
+            model.hybridize()
+        _, result = model(
+            mx.nd.array([1], dtype="int64"), # i
+            mx.nd.array([0], dtype="int64"), # s
+        )
+        assert result[0].asscalar() == 6
+        assert result[1].asscalar() == 15
+        # Case 1.2: result should be sum([1, 2, 3 ... 1000])
+        model = _TestBlock(
+            cond=lambda i, s, true: true,
+            func=lambda i, s, true: (None, (i + 1, s + i, true)),
+            max_iterations=1000,
+        )
+        if hybridize:
+            model.hybridize()
+        _, result = model(
+            mx.nd.array([1], dtype="int64"), # i
+            mx.nd.array([0], dtype="int64"), # s
+            mx.nd.array([1], dtype="int64"), # true
+        )
+        assert result[0].asscalar() == 1001
+        assert result[1].asscalar() == 500500
+        assert result[2].asscalar() == 1
+        # Case 1.3: result should be sum([])
+        model = _TestBlock(
+            cond=lambda i, s, false: false,
+            func=lambda i, s, false: (None, (i + 1, s + i, false)),
+            max_iterations=1000,
+        )
+        if hybridize:
+            model.hybridize()
+        _, result = model(
+            mx.nd.array([1], dtype="int64"), # i
+            mx.nd.array([0], dtype="int64"), # s
+            mx.nd.array([0], dtype="int64"), # false
+        )
+        assert result[0].asscalar() == 1
+        assert result[1].asscalar() == 0
+        assert result[2].asscalar() == 0
+        # Case 2.1: result should be sum([1, 2, 3 ... 100])
+        model = _TestBlock(
+            cond=lambda i, s: i <= 100,
+            func=lambda i, s: (i, (i + 1, s + i)),
+            max_iterations=1000,
+        )
+        if hybridize:
+            model.hybridize()
+        (outputs, ), (result_i, result_s) = model(
+            mx.nd.array([1], dtype="int64"), # i
+            mx.nd.array([0], dtype="int64"), # s
+        )
+        assert all(outputs.asnumpy()[ : 100] == np.arange(1, 101).reshape(100, 1))
+        assert result_i.asscalar() == 101
+        assert result_s.asscalar() == 5050
+        # Case 2.2: result should be sum([1, 2, 3 ... 1000])
+        model = _TestBlock(
+            cond=lambda i, s, true: true,
+            func=lambda i, s, true: (i, (i + 1, s + i, true)),
+            max_iterations=1000,
+        )
+        if hybridize:
+            model.hybridize()
+        (outputs, ), (result_i, result_s, _) = model(
+            mx.nd.array([1], dtype="int64"), # i
+            mx.nd.array([0], dtype="int64"), # s
+            mx.nd.array([1], dtype="int64"), # true
+        )
+        assert all(outputs.asnumpy() == np.arange(1, 1001).reshape(1000, 1))
+        assert result_i.asscalar() == 1001
+        assert result_s.asscalar() == 500500
+        # Case 2.3: a corner case, in which loop body is never executed
+        model = _TestBlock(
+            cond=lambda i, s, false: false,
+            func=lambda i, s, false: (i, (i + 1, s + i, false)),
+            max_iterations=1000,
+        )
+        if hybridize:
+            model.hybridize()
+        _, (result_i, result_s, _) = model(
+            mx.nd.array([1], dtype="int64"), # i
+            mx.nd.array([0], dtype="int64"), # s
+            mx.nd.array([0], dtype="int64"), # false
+        )
+        assert result_i.asscalar() == 1
+        assert result_s.asscalar() == 0
+
+
+def _verify_while_loop(cond, func, loop_var_shapes, free_var_shapes, is_train, max_iterations, is_for, n_steps):
+
+    def _create_vars(num, prefix):
+        return [mx.sym.var(prefix + str(i)) for i in range(num)]
+
+    def _create_arrays(shapes):
+        return [mx.nd.random.uniform(-1.0, 1.0, shape=x) for x in shapes]
+
+    def _create_dict(prefix, shapes):
+        return {prefix + str(i): mx.nd.random.uniform(-1.0, 1.0, shape=x) for i, x in enumerate(shapes)}
+
+    def _merge_dict(*dicts):
+        result = {}
+        for item in dicts:
+            result.update(item)
+        return result
+
+    def _to_numpy_list(arrays):
+        return [x.asnumpy() if x is not None else x for x in arrays]
+
+    def _get_imperative_result(n_steps):
+        free_vars = [args["FreeVar" + str(i)].copy() for i, _ in 
enumerate(free_var_shapes)]
+        loop_vars = [args["LoopVar" + str(i)].copy() for i, _ in 
enumerate(loop_var_shapes)]
+        loop_var_start = int(is_for)
+        if is_train:
+            for var in free_vars + loop_vars[loop_var_start: ]:
+                var.attach_grad()
+        with mx.autograd.record(train_mode=is_train):
+            outputs, final_loop_vars = mx.nd.contrib.while_loop(
+                cond=lambda *_loop_vars: cond(_loop_vars, free_vars),
+                func=lambda *_loop_vars: func(_loop_vars, free_vars),
+                loop_vars=loop_vars,
+                max_iterations=max_iterations,
+            )
+            outputs = [x[: n_steps] for x in outputs]
+            out_grads = _create_arrays(x.shape for x in outputs)  \
+                      + _create_arrays(x.shape for x in final_loop_vars)
+            loop_result_nd = [x * 2 for x in outputs] + [x * 3 for x in final_loop_vars]
+            grads = []
+            if is_train:
+                cat_out = mx.nd.concat(*[x.reshape(-1) for x in loop_result_nd], dim=0)
+                cat_out.backward(out_grad=mx.nd.concat(*[x.reshape(-1) for x in out_grads], dim=0))
+                grads = [free_vars[i].grad for i, _ in enumerate(free_var_shapes)] \
+                      + [loop_vars[i].grad for i, _ in enumerate(loop_var_shapes) if i >= loop_var_start]
+            return _to_numpy_list(loop_result_nd), _to_numpy_list(grads), out_grads
+
+    def _get_symbolic_result(out_grads, n_steps):
+
+        def _copy_args_dict(name_list):
+            return {name: args[name].copy() for name in name_list}
+
+        def _zeros_like_dict(name_list):
+            return {name: mx.nd.zeros_like(args[name]) for name in name_list}
+
+        free_syms = _create_vars(len(free_var_shapes), "FreeVar")
+        loop_syms = _create_vars(len(loop_var_shapes), "LoopVar")
+        outputs, final_loop_syms = mx.sym.contrib.while_loop(
+            cond=lambda *_loop_vars: cond(_loop_vars, free_syms),
+            func=lambda *_loop_vars: func(_loop_vars, free_syms),
+            loop_vars=loop_syms,
+            max_iterations=max_iterations,
+        )
+        if n_steps == 0:
+            outputs = []
+        else:
+            outputs = [x.slice_axis(axis=0, begin=0, end=n_steps) for x in outputs]
+        loop_result_sym = [x * 2 for x in outputs] + [x * 3 for x in final_loop_syms]
+        loop_result_sym = mx.sym.Group(loop_result_sym)
+
+        loop_var_start = int(is_for)
+        args_names = ["FreeVar" + str(i) for i, _ in 
enumerate(free_var_shapes)] \
+                   + ["LoopVar" + str(i) for i, _ in 
enumerate(loop_var_shapes) if i >= loop_var_start]
+        args_grad = None if not is_train else _zeros_like_dict(x for x in 
args_names)
+        executor = loop_result_sym.bind(
+            ctx=default_context(),
+            args=_copy_args_dict(loop_result_sym.list_inputs()),
+            args_grad=args_grad,
+        )
+        loop_result_nd = executor.forward(is_train=is_train)
+        grads = []
+        if is_train:
+            executor.backward(out_grads=out_grads)
+            grads = [executor.grad_dict.get("FreeVar" + str(i), None) for i, _ 
in enumerate(free_var_shapes)] \
+                  + [executor.grad_dict.get("LoopVar" + str(i), None) for i, _ 
in enumerate(loop_var_shapes) if i >= loop_var_start]
+        return _to_numpy_list(loop_result_nd), _to_numpy_list(grads)
+
+    args = _merge_dict(
+        _create_dict("FreeVar", free_var_shapes),
+        _create_dict("LoopVar", loop_var_shapes),
+    )
+    if is_for:
+        assert loop_var_shapes[0] == (1, )
+        args["LoopVar0"] = mx.nd.array([0])
+    imp_outs, imp_grads, out_grads = _get_imperative_result(n_steps)
+    sym_outs, sym_grads = _get_symbolic_result(out_grads, n_steps)
+    for imp_out, sym_out in zip(imp_outs, sym_outs):
+        if imp_out is None or sym_out is None:
+            continue
+        assert_almost_equal(imp_out, sym_out, rtol=1e-4, atol=1e-4)
+    for imp_grad, sym_grad in zip(imp_grads, sym_grads):
+        if imp_grad is None or sym_grad is None:
+            continue
+        assert_almost_equal(imp_grad, sym_grad, rtol=1e-4, atol=1e-4)
+
+
+def test_while_loop_for_foreach():
+
+    def make_true_cond():
+        return lambda loop_vars, _: (loop_vars[0] < 1e200).prod()
+
+    def make_false_cond():
+        return lambda loop_vars, _: (loop_vars[0] > 1e200).prod()
+
+    def make_for_cond(length):
+        return lambda loop_vars, _: loop_vars[0] < length
+
+    def case_0():
+        # This is a simple testcase in which all loop steps are independent
+        # It simply scans the input array and outputs it
+        # There is 1 output
+        # There is 1 state: i
+        def _simple_func(loop, free):
+            (i, ), (scanned, ) = loop, free
+            in_ = scanned.take(i).squeeze(axis=0)
+            return (in_, i + 1)
+        _verify_while_loop(
+            cond=make_true_cond(),
+            func=_simple_func,
+            max_iterations=1,
+            is_train=True,
+            is_for=True,
+            loop_var_shapes=[
+                (1, ),          # i
+            ],
+            free_var_shapes=[
+                (1, 3),         # scanned
+            ],
+            n_steps=1,
+        )
+
+    def case_1(**params):
+        # This is a simple testcase that simulates a cumulative sum
+        # There is 1 output
+        # There is 1 state: s
+        step_funcs = [
+            lambda a, b, s: s,
+            lambda a, b, s: a * 1.5 + b * 2.5 - s * 3.5,
+            lambda a, b, s: a * 1.5 - s * 3.5 + b * 2.5,
+            lambda a, b, s: b * 2.5 + a * 1.5 - s * 3.5,
+            lambda a, b, s: b * 2.5 - s * 3.5 + a * 1.5,
+            lambda a, b, s: s * -3.5 + a * 1.5 + b * 2.5,
+            lambda a, b, s: s * -3.5 + b * 2.5 + a * 1.5,
+            lambda a, b, s: a * 2.5 * b + s * 0.3,
+            lambda a, b, s: b * 2.5 * a + s * 0.3,
+            lambda a, b, s: 2.5 * a * b + s * 0.3,
+            lambda a, b, s: b * a * 2.5 + s * 0.3,
+            lambda a, b, s: 2.5 * b * a + s * 0.3,
+            lambda a, b, s: b * a * 2.5 + s * 0.3,
+            lambda a, b, s: s * 0.3 + a * 2.5 * b,
+            lambda a, b, s: s * 0.3 + b * 2.5 * a,
+            lambda a, b, s: s * 0.3 + 2.5 * a * b,
+            lambda a, b, s: s * 0.3 + b * a * 2.5,
+            lambda a, b, s: s * 0.3 + 2.5 * b * a,
+            lambda a, b, s: s * 0.3 + b * a * 2.5,
+        ]
+        def make_func(step_func):
+            def step(loop, free):
+                (s, ), (a, b) = loop, free
+                out = step_func(a, b, s)
+                return (out, out)
+            return step
+        case_id = 0
+        for is_train in [True, False]:
+            for step_func in step_funcs:
+                case_id += 1
+                _verify_while_loop(
+                    func=make_func(step_func),
+                    is_train=is_train,
+                    is_for=False,
+                    **params
+                )
+
+    def case_2(**params):
+        # This is a testcase that involves non-differentiable operators
+        # There is 1 output
+        # There are 2 states: i, s
+        step_funcs = [
+            lambda in_, s, f_1: (in_ * 2) * s * f_1,
+            lambda in_, s, f_1: (in_ * 2) * f_1 * s,
+            lambda in_, s, f_1: s * (in_ * 2) * f_1,
+            lambda in_, s, f_1: s * f_1 * (in_ * 2),
+            lambda in_, s, f_1: f_1 * (in_ * 2) * s,
+            lambda in_, s, f_1: f_1 * s * (in_ * 2),
+            lambda in_, s, f_1: (2 * in_) * s * f_1,
+            lambda in_, s, f_1: (2 * in_) * f_1 * s,
+            lambda in_, s, f_1: s * (2 * in_) * f_1,
+            lambda in_, s, f_1: s * f_1 * (2 * in_),
+            lambda in_, s, f_1: f_1 * (2 * in_) * s,
+            lambda in_, s, f_1: f_1 * s * (2 * in_),
+        ]
+        def make_func(step_func):
+            """This simulates:
+            def compute(s, inputs, f_1, length):
+                outputs = []
+                for i in range(length):
+                    s += inputs[i] * 2 + f_1
+                    outputs.append(s)
+                return outputs, s
+            """
+            def step(loop, free):
+                (i, s), (scanned, f_1, _) = loop, free
+                in_ = scanned.take(i).squeeze(axis=0)
+                out = step_func(in_, s, f_1)
+                return (out, (i + 1, out))
+            return step
+        case_id = 0
+        for is_train in [True, False]:
+            for step_func in step_funcs:
+                case_id += 1
+                _verify_while_loop(
+                    func=make_func(step_func),
+                    max_iterations=1000,
+                    is_train=is_train,
+                    is_for=True,
+                    **params
+                )
+
+    def case_3(length, **params):
+        # This is a testcase for multiple non-differentiable operators and different ways of slicing
+        # There are 2 outputs
+        # There are 3 states: i, s_0, s_1
+        step_funcs = [
+            lambda i_0, i_1, s_0, s_1, f_0: i_0 * (i_1 * 2) * s_0 * (s_1 * 2) * f_0,
+            lambda i_0, i_1, s_0, s_1, f_0: i_0 * (i_1 * 2) * s_0 * f_0 * (s_1 * 2),
+            lambda i_0, i_1, s_0, s_1, f_0: i_0 * (i_1 * 2) * (s_1 * 2) * s_0 * f_0,
+            lambda i_0, i_1, s_0, s_1, f_0: i_0 * (i_1 * 2) * (s_1 * 2) * f_0 * s_0,
+            lambda i_0, i_1, s_0, s_1, f_0: (i_1 * 2) * i_0 * s_0 * (s_1 * 2) * f_0,
+            lambda i_0, i_1, s_0, s_1, f_0: (i_1 * 2) * i_0 * s_0 * f_0 * (s_1 * 2),
+            lambda i_0, i_1, s_0, s_1, f_0: (i_1 * 2) * i_0 * (s_1 * 2) * s_0 * f_0,
+            lambda i_0, i_1, s_0, s_1, f_0: (i_1 * 2) * i_0 * (s_1 * 2) * f_0 * s_0,
+            lambda i_0, i_1, s_0, s_1, f_0: i_0,
+            lambda i_0, i_1, s_0, s_1, f_0: i_1,
+            lambda i_0, i_1, s_0, s_1, f_0: s_0,
+            lambda i_0, i_1, s_0, s_1, f_0: s_1,
+            lambda i_0, i_1, s_0, s_1, f_0: f_0,
+        ]
+        def make_func(step_func):
+            """This simulates:
+            def compute(input_0, input_1, s_0, s_1, f_0, length):
+                output_0 = []
+                output_1 = []
+                for i in range(length):
+                    i_0 = input_0[i]
+                    i_1 = input_1[length - 1 - i]
+                    out = i_0 + (i_1 * 2) + s_0 + (s_1 * 2) + f_0
+                    s_0 = (s_0 + out) * 1.05
+                    s_1 = (s_1 - out * 0.5) * 0.95
+                    output_0.append(out)
+                    output_1.append(out * 1.5)
+                return output_0, output_1, s_0, s_1
+            """
+            def step(loop, free):
+                (i, s_0, s_1), (sc_0, sc_1, f_0, _) = loop, free
+                i_0 = sc_0.take(i).squeeze(axis=0)
+                i_1 = sc_1.take(length - 1 - i).squeeze(axis=0)
+                out = step_func(i_0, i_1, s_0, s_1, f_0)
+                return ([out, out * 1.5], [i + 1, (s_0 + out) * 1.05, (s_1 - out * 0.5) * 0.95])
+            return step
+        case_id = 0
+        for is_train in [True, False]:
+            for step_func in step_funcs:
+                case_id += 1
+                _verify_while_loop(
+                    func=make_func(step_func),
+                    max_iterations=1000,
+                    is_train=is_train,
+                    is_for=True,
+                    **params
+                )
+
+    def case_4(length, single_shape, **params):
+        # It is for the case that inputs & outputs are the same
+        # There are 3 outputs
+        # There are 4 states: i, s_0, s_1, s_2
+        # i is used in both non-differentiable (take) and differentiable (+) contexts
+        step_funcs = [
+            lambda i_0, i_1, s_0, s_1, f_0: i_0 * (i_1 * 2) * s_0 * (s_1 * 2) * f_0,
+            lambda i_0, i_1, s_0, s_1, f_0: i_0 * (i_1 * 2) * s_0 * f_0 * (s_1 * 2),
+            lambda i_0, i_1, s_0, s_1, f_0: i_0 * (i_1 * 2) * (s_1 * 2) * s_0 * f_0,
+            lambda i_0, i_1, s_0, s_1, f_0: i_0 * (i_1 * 2) * (s_1 * 2) * f_0 * s_0,
+            lambda i_0, i_1, s_0, s_1, f_0: (i_1 * 2) * i_0 * s_0 * (s_1 * 2) * f_0,
+            lambda i_0, i_1, s_0, s_1, f_0: (i_1 * 2) * i_0 * s_0 * f_0 * (s_1 * 2),
+            lambda i_0, i_1, s_0, s_1, f_0: (i_1 * 2) * i_0 * (s_1 * 2) * s_0 * f_0,
+            lambda i_0, i_1, s_0, s_1, f_0: (i_1 * 2) * i_0 * (s_1 * 2) * f_0 * s_0,
+            lambda i_0, i_1, s_0, s_1, f_0: i_0,
+            lambda i_0, i_1, s_0, s_1, f_0: i_1,
+            lambda i_0, i_1, s_0, s_1, f_0: s_0,
+            lambda i_0, i_1, s_0, s_1, f_0: s_1,
+            lambda i_0, i_1, s_0, s_1, f_0: f_0,
+        ]
+        def make_func(step_func):
+            """This simulates:
+            def compute(input_0, input_1, s_0, s_1, s_2, f_0, length):
+                # here s_2 remains untouched
+                output_0 = []
+                output_1 = []
+                output_2 = []
+                for i in range(length):
+                    i_0 = input_0[i]
+                    i_1 = input_1[length - 1 - i]
+                    out = i_0 + (i_1 * 2) + s_0 + (s_1 * 2) + f_0
+                    out = out * i * i_0 * i_1
+                    s_0 = (s_0 + out) * 1.05
+                    s_1 = (s_1 - out * 0.5) * 0.95
+                    output_0.append(out)
+                    output_1.append(f_0)
+                    output_2.append(out * 1.5)
+                return output_0, output_1, output_2, s_0, s_1, s_2
+            """
+            def step(loop, free):
+                (i, s_0, s_1, s_2), (sc_0, sc_1, f_0, _) = loop, free
+                i_0 = sc_0.take(i).squeeze(axis=0)
+                i_1 = sc_1.take(length - 1 - i).squeeze(axis=0)
+                out = step_func(i_0, i_1, s_0, s_1, f_0)
+                out = out * i.reshape([1] * len(single_shape)).broadcast_to(single_shape)
+                out = out * i_0 * i_1
+                return ([out, f_0, out * 1.5], [i + 1, (s_0 + out) * 1.05, (s_1 - out * 0.5) * 0.95, s_2])
+            return step
+        case_id = 0
+        for is_train in [True, False]:
+            for step_func in step_funcs:
+                case_id += 1
+                _verify_while_loop(
+                    func=make_func(step_func),
+                    max_iterations=1000,
+                    is_train=is_train,
+                    is_for=True,
+                    **params
+                )
+
+    def case_5(length, single_shape, **params):
+        # It is for the case that inputs & outputs are the same
+        # There are 0 outputs
+        # There are 4 states: i, s_0, s_1, s_2
+        # i is used in both non-differentiable (take) and differentiable (+) contexts
+        step_funcs = [
+            lambda i_0, i_1, s_0, s_1, f_0: i_0 * (i_1 * 2) * s_0 * (s_1 * 2) * f_0,
+            lambda i_0, i_1, s_0, s_1, f_0: i_0 * (i_1 * 2) * s_0 * f_0 * (s_1 * 2),
+            lambda i_0, i_1, s_0, s_1, f_0: i_0 * (i_1 * 2) * (s_1 * 2) * s_0 * f_0,
+            lambda i_0, i_1, s_0, s_1, f_0: i_0 * (i_1 * 2) * (s_1 * 2) * f_0 * s_0,
+            lambda i_0, i_1, s_0, s_1, f_0: (i_1 * 2) * i_0 * s_0 * (s_1 * 2) * f_0,
+            lambda i_0, i_1, s_0, s_1, f_0: (i_1 * 2) * i_0 * s_0 * f_0 * (s_1 * 2),
+            lambda i_0, i_1, s_0, s_1, f_0: (i_1 * 2) * i_0 * (s_1 * 2) * s_0 * f_0,
+            lambda i_0, i_1, s_0, s_1, f_0: (i_1 * 2) * i_0 * (s_1 * 2) * f_0 * s_0,
+            lambda i_0, i_1, s_0, s_1, f_0: i_0,
+            lambda i_0, i_1, s_0, s_1, f_0: i_1,
+            lambda i_0, i_1, s_0, s_1, f_0: s_0,
+            lambda i_0, i_1, s_0, s_1, f_0: s_1,
+            lambda i_0, i_1, s_0, s_1, f_0: f_0,
+        ]
+        def make_func(step_func):
+            """This simulates:
+            def compute(input_0, input_1, s_0, s_1, s_2, f_0, length):
+                # here s_2 remains untouched
+                output_0 = []
+                output_1 = []
+                output_2 = []
+                for i in range(length):
+                    i_0 = input_0[i]
+                    i_1 = input_1[length - 1 - i]
+                    out = i_0 + (i_1 * 2) + s_0 + (s_1 * 2) + f_0
+                    out = out * i * i_0 * i_1
+                    s_0 = (s_0 + out) * 1.05
+                    s_1 = (s_1 - out * 0.5) * 0.95
+                    output_0.append(out)
+                    output_1.append(f_0)
+                    output_2.append(out * 1.5)
+                return output_0, output_1, output_2, s_0, s_1, s_2
+            """
+            def step(loop, free):
+                (i, s_0, s_1, s_2), (sc_0, sc_1, f_0, _) = loop, free
+                i_0 = sc_0.take(i).squeeze(axis=0)
+                i_1 = sc_1.take(length - 1 - i).squeeze(axis=0)
+                out = step_func(i_0, i_1, s_0, s_1, f_0)
+                out = out * i.reshape([1] * len(single_shape)).broadcast_to(single_shape)
+                out = out * i_0 * i_1
+                return ([], [i + 1, (s_0 + out) * 1.05, (s_1 - out * 0.5) * 0.95, s_2])
+            return step
+        case_id = 0
+        for is_train in [True, False]:
+            for step_func in step_funcs:
+                case_id += 1
+                _verify_while_loop(
+                    func=make_func(step_func),
+                    max_iterations=1000,
+                    is_train=is_train,
+                    is_for=True,
+                    **params
+                )
+
+    def case_6(length, single_shape, **params):
+        # It is for the case that inputs & outputs are the same
+        # There are 3 outputs
+        # There are 4 states: i, s_0, s_1, s_2
+        # i is used in both non-differentiable (take) and differentiable (+) contexts
+        step_funcs = [
+            lambda i_0, i_1, s_0, s_1, f_0: i_0 * (i_1 * 2) * s_0 * (s_1 * 2) * f_0,
+            lambda i_0, i_1, s_0, s_1, f_0: i_0 * (i_1 * 2) * s_0 * f_0 * (s_1 * 2),
+            lambda i_0, i_1, s_0, s_1, f_0: i_0 * (i_1 * 2) * (s_1 * 2) * s_0 * f_0,
+            lambda i_0, i_1, s_0, s_1, f_0: i_0 * (i_1 * 2) * (s_1 * 2) * f_0 * s_0,
+            lambda i_0, i_1, s_0, s_1, f_0: (i_1 * 2) * i_0 * s_0 * (s_1 * 2) * f_0,
+            lambda i_0, i_1, s_0, s_1, f_0: (i_1 * 2) * i_0 * s_0 * f_0 * (s_1 * 2),
+            lambda i_0, i_1, s_0, s_1, f_0: (i_1 * 2) * i_0 * (s_1 * 2) * s_0 * f_0,
+            lambda i_0, i_1, s_0, s_1, f_0: (i_1 * 2) * i_0 * (s_1 * 2) * f_0 * s_0,
+            lambda i_0, i_1, s_0, s_1, f_0: i_0,
+            lambda i_0, i_1, s_0, s_1, f_0: i_1,
+            lambda i_0, i_1, s_0, s_1, f_0: s_0,
+            lambda i_0, i_1, s_0, s_1, f_0: s_1,
+            lambda i_0, i_1, s_0, s_1, f_0: f_0,
+        ]
+        def make_func(step_func):
+            """This simulates:
+            def compute(input_0, input_1, s_0, s_1, s_2, f_0, length):
+                # here s_2 remains untouched
+                output_0 = []
+                output_1 = []
+                output_2 = []
+                for i in range(length):
+                    i_0 = input_0[i]
+                    i_1 = input_1[length - 1 - i]
+                    out = i_0 + (i_1 * 2) + s_0 + (s_1 * 2) + f_0
+                    out = out * i * i_0 * i_1
+                    s_0 = (s_0 + out) * 1.05
+                    s_1 = (s_1 - out * 0.5) * 0.95
+                    output_0.append(out)
+                    output_1.append(f_0)
+                    output_2.append(out * 1.5)
+                return output_0, output_1, output_2, s_0, s_1, s_2
+            """
+            def step(loop, free):
+                (i, s_0, s_1, s_2), (sc_0, sc_1, f_0, _) = loop, free
+                F = mx.sym if isinstance(i, mx.sym.Symbol) else mx.nd
+                i_0 = sc_0.take(i).squeeze(axis=0)
+                i_1 = sc_1.take(length - 1 - i).squeeze(axis=0)
+                out_0 = step_func(i_0, i_1, s_0, s_1, f_0)
+                out_0 = out_0 * i.reshape([1] * len(single_shape)).broadcast_to(single_shape)
+                out_1 = step_func(i_1, s_0, f_0, s_1, i_0)
+                out_1 = out_1 * i.reshape([1] * len(single_shape)).broadcast_to(single_shape)
+                return ([F.dot(out_0, s_2), f_0, F.dot(s_2, out_1) * 1.5], [i + 1, (s_0 + out_1) * 1.05, (s_1 - out_0 * 0.5) * 0.95, s_2])
+            return step
+        case_id = 0
+        for is_train in [True, False]:
+            for step_func in step_funcs:
+                case_id += 1
+                _verify_while_loop(
+                    func=make_func(step_func),
+                    max_iterations=1000,
+                    is_train=is_train,
+                    is_for=True,
+                    **params
+                )
+
+    # Case 0: the simplest case
+    case_0()
+    # Case 1.1.*
+    case_1(
+        cond=make_true_cond(),
+        loop_var_shapes=[
+            (1, ),          # s
+        ],
+        free_var_shapes=[
+            (1, ),          # a
+            (1, ),          # b
+        ],
+        max_iterations=23,
+        n_steps=23,
+    )
+    # Case 1.2.*
+    case_1(
+        cond=make_true_cond(),
+        loop_var_shapes=[
+            (2, 3, 4),      # s
+        ],
+        free_var_shapes=[
+            (2, 3, 4),      # a
+            (2, 3, 4),      # b
+        ],
+        max_iterations=31,
+        n_steps=31,
+    )
+    # Case 1.3.*
+    case_1(
+        cond=make_false_cond(),
+        loop_var_shapes=[
+            (2, 3, 4),      # s
+        ],
+        free_var_shapes=[
+            (2, 3, 4),      # a
+            (2, 3, 4),      # b
+        ],
+        max_iterations=20,
+        n_steps=0,
+    )
+    # Case 2.1.*
+    case_2(
+        cond=make_for_cond(length=31),
+        loop_var_shapes=[
+            (1, ),          # i
+            (2, ),          # s
+        ],
+        free_var_shapes=[
+            (100, 2),       # scanned
+            (2, ),          # f_1
+            (3, 4, 5, 6),   # f_2, unused
+        ],
+        n_steps=31,
+    )
+    # Case 2.2.*
+    case_2(
+        cond=make_for_cond(length=25),
+        loop_var_shapes=[
+            (1, ),          # i
+            (2, ),          # s
+        ],
+        free_var_shapes=[
+            (30, 2),        # scanned
+            (2, ),          # f_1
+            (3, 4, 5, 6),   # f_2, unused
+        ],
+        n_steps=25,
+    )
+    # Case 3.*
+    case_3(
+        length=11,
+        cond=make_for_cond(length=11),
+        loop_var_shapes=[
+            (1, ),          # i
+            (2, ),          # s_0
+            (2, ),          # s_1
+        ],
+        free_var_shapes=[
+            (30, 2),        # sc_0
+            (30, 2),        # sc_1
+            (2, ),          # f_0
+            (3, 4, 5, 6),   # f_1, unused
+        ],
+        n_steps=11,
+    )
+    # Case 4.1.*
+    case_4(
+        length=4,
+        cond=make_for_cond(length=4),
+        single_shape=[5],
+        loop_var_shapes=[
+            (1, ),          # i
+            (5, ),          # s_0
+            (5, ),          # s_1
+            (23, 6, 8),     # s_2
+        ],
+        free_var_shapes=[
+            (30, 5),        # sc_0
+            (30, 5),        # sc_1
+            (5, ),          # f_0
+            (3, 4, 5, 6),   # f_1, unused
+        ],
+        n_steps=4,
+    )
+    # Case 4.2.*
+    case_4(
+        length=5,
+        cond=make_for_cond(length=5),
+        single_shape=[5, 12],
+        loop_var_shapes=[
+            (1, ),          # i
+            (5, 12),        # s_0
+            (5, 12),        # s_1
+            (23, 6, 8),     # s_2
+        ],
+        free_var_shapes=[
+            (30, 5, 12),    # sc_0
+            (30, 5, 12),    # sc_1
+            (5, 12),        # f_0
+            (3, 4, 5, 6),   # f_1, unused
+        ],
+        n_steps=5,
+    )
+    # Case 5.1.*
+    case_5(
+        length=4,
+        cond=make_for_cond(length=4),
+        single_shape=[5],
+        loop_var_shapes=[
+            (1, ),          # i
+            (5, ),          # s_0
+            (5, ),          # s_1
+            (23, 6, 8),     # s_2
+        ],
+        free_var_shapes=[
+            (30, 5),        # sc_0
+            (30, 5),        # sc_1
+            (5, ),          # f_0
+            (3, 4, 5, 6),   # f_1, unused
+        ],
+        n_steps=4,
+    )
+    # Case 5.2.*
+    case_5(
+        length=5,
+        cond=make_for_cond(length=5),
+        single_shape=[3, 4, 2],
+        loop_var_shapes=[
+            (1, ),          # i
+            (3, 4, 2),      # s_0
+            (3, 4, 2),      # s_1
+            (23, 6, 8),     # s_2
+        ],
+        free_var_shapes=[
+            (30, 3, 4, 2),  # sc_0
+            (30, 3, 4, 2),  # sc_1
+            (3, 4, 2),      # f_0
+            (3, 4, 5, 6),   # f_1, unused
+        ],
+        n_steps=5,
+    )
+    # Case 6.*
+    case_6(
+        length=5,
+        cond=make_for_cond(length=5),
+        single_shape=[5, 3],
+        loop_var_shapes=[
+            (1, ),          # i
+            (5, 3),         # s_0
+            (5, 3),         # s_1
+            (3, 5),         # s_2
+        ],
+        free_var_shapes=[
+            (30, 5, 3),     # sc_0
+            (30, 5, 3),     # sc_1
+            (5, 3),         # f_0
+            (3, 4, 5, 6),   # f_1, unused
+        ],
+        n_steps=5,
+    )
+
+
+def test_while_loop_nested():
+
+    def _to_np_list(arrays):
+        return [x.asnumpy() if x is not None else x for x in arrays]
+
+    def _array(shape):
+        return mx.nd.random.uniform(-1.0, 1.0, shape=shape)
+
+    def inner_cond(i, j, x_sum, sc):
+        return j < 2
+
+    def inner_body(i, j, x_sum, sc):
+        x_ij = sc.take(j).squeeze(axis=0)
+        return (x_ij, x_ij), (i, j + 1, x_sum, sc)
+
+    def outer_cond(i, j, x_sum, sc):
+        return i < 2
+
+    def outer_body(i, j, x_sum, sc):
+        F = mx.sym if isinstance(i, mx.sym.Symbol) else mx.nd
+        (x_ij, x_ji), (i_p, j_p, x_sum_p, sc_p) = F.contrib.while_loop(
+            cond=inner_cond,
+            func=inner_body,
+            loop_vars=(i, j, x_sum, sc),
+            max_iterations=2,
+        )
+        return (x_ij, x_ji), (i_p + 1, j_p - 2, x_sum_p, sc_p)
+
+    def make_loop(i, j, x_sum, sc):
+        F = mx.sym if isinstance(i, mx.sym.Symbol) else mx.nd
+        (x_ij, x_ji), (new_i, new_j, x_sum_p, sc_p) = F.contrib.while_loop(
+            cond=outer_cond,
+            func=outer_body,
+            loop_vars=(i, j, x_sum, sc),
+            max_iterations=2,
+        )
+        return new_i, new_j, x_sum_p, sc_p, x_ij, x_ji
+
+    args = {
+        "i": mx.nd.array([0]),
+        "j": mx.nd.array([0]),
+        "x_sum": _array([5, 3]),
+        "sc": _array([10, 10, 5, 3]),
+    }
+    args_grad = {
+        "x_sum": _array([5, 3]),
+        "sc": _array([10, 10, 5, 3]),
+    }
+    out_grad = [
+        _array([1]),
+        _array([1]),
+        _array([5, 3]),
+        _array([10, 10, 5, 3]),
+        _array([2, 2, 10, 5, 3]),
+        _array([2, 2, 10, 5, 3]),
+    ]
+    def _get_imp_result(is_train, args, args_grad, out_grad):
+        args = {k: v.copy() for k, v in args.items()}
+        args_grad = {k: v.copy() for k, v in args_grad.items()}
+        i, j, x_sum, sc = [args[x].copy() for x in ["i", "j", "x_sum", "sc"]]
+        if is_train:
+            x_sum.attach_grad()
+            sc.attach_grad()
+        with mx.autograd.record(train_mode=is_train):
+            results = make_loop(i, j, x_sum, sc)
+            cat_res = mx.nd.concat(*[x.reshape(-1) for x in results], dim=0)
+        if not is_train:
+            return _to_np_list(results), []
+        cat_grad = mx.nd.concat(*[x.reshape(-1) for x in out_grad], dim=0)
+        assert cat_grad.shape == cat_res.shape
+        cat_res.backward(out_grad=cat_grad)
+        grads = [x_sum.grad, sc.grad]
+        return _to_np_list(results), _to_np_list(grads)
+
+    def _get_sym_result(is_train, args, args_grad, out_grad):
+        args = {k: v.copy() for k, v in args.items()}
+        args_grad = {k: v.copy() for k, v in args_grad.items()}
+        i, j, x_sum, sc = [
+            mx.sym.var("i"),
+            mx.sym.var("j"),
+            mx.sym.var("x_sum"),
+            mx.sym.var("sc"),
+        ]
+        result_sym = mx.sym.Group(make_loop(i, j, x_sum, sc))
+        executor = result_sym.bind(
+            ctx=default_context(),
+            args=args,
+            args_grad=args_grad,
+        )
+        results = executor.forward(is_train=is_train)
+        if not is_train:
+            return _to_np_list(results), []
+        executor.backward(out_grads=out_grad)
+        grads = [executor.grad_dict["x_sum"], executor.grad_dict["sc"]]
+        return _to_np_list(results), _to_np_list(grads)
+
+    for is_train in [True, False]:
+        imp_out, imp_grad = _get_imp_result(is_train=is_train, args=args, args_grad=args_grad, out_grad=out_grad)
+        sym_out, sym_grad = _get_sym_result(is_train=is_train, args=args, args_grad=args_grad, out_grad=out_grad)
+        assert len(imp_out) == len(sym_out)
+        assert len(imp_grad) == len(sym_grad)
+        for x, y in zip(imp_out, sym_out):
+            assert_almost_equal(x, y, rtol=1e-4, atol=1e-4)
+        for x, y in zip(imp_grad, sym_grad):
+            assert_almost_equal(x, y, rtol=1e-4, atol=1e-4)
+
+
+def test_while_loop_rnn():
+    def _array(shape):
+        return mx.nd.random.uniform(-1.0, 1.0, shape=shape)
+
+    cell_types = [mx.rnn.LSTMCell]
+    num_params = [2]
+
+    batch_size = 2
+    hidden_dim = 3
+    input_dim = 4
+    seq_len = 3
+
+    for cell, n_param in zip(cell_types, num_params):
+        # using while_loop
+        params = mx.rnn.RNNParams()
+        data = mx.sym.var("data")
+        iter_i = mx.sym.var("i")
+        def _cond(*states):
+            i = states[0]
+            return i < seq_len
+        def _func(*states):
+            i = states[0]
+            states = states[1:]
+            in_ = data.take(i).squeeze(axis=0)
+            rnn = cell(hidden_dim, prefix='', params=params)
+            next_hidden, next_states = rnn(in_, states)
+            return [next_hidden], [i + 1] + list(next_states)
+        states = [mx.sym.var("s_" + str(i)) for i in range(n_param)]
+        result = mx.sym.contrib.while_loop(
+                    cond=_cond,
+                    func=_func,
+                    loop_vars=[iter_i] + states,
+                    max_iterations=seq_len
+                )
+        result = mx.sym.Group(result[0] + result[1][1: ])
+        arg_shapes, _, _ = result.infer_shape(
+            data=(seq_len, batch_size, input_dim),
+            s_0=(batch_size, hidden_dim),
+        )
+        rnn_inputs = result.list_inputs()
+        args = {name: _array(arg_shapes[i]) for i, name in enumerate(rnn_inputs) if name != "i"}
+        args["i"] = mx.nd.zeros([1])
+        args_grad = {name: _array(arg_shapes[i]) for i, name in enumerate(rnn_inputs)}
+        e_1 = result.bind(ctx=default_context(),
+            args={name: array.copy() for name, array in args.items()},
+            args_grad={name: array.copy() for name, array in args_grad.items() if name != "i"},
+        )
+        # using unrolled rnn
+        rnn = cell(hidden_dim, prefix='')
+        unroll_outs = []
+        for inputs in mx.sym.split(data, num_outputs=seq_len, axis=0, squeeze_axis=True):
+            h, states = rnn(inputs, states)
+            unroll_outs.append(mx.sym.expand_dims(h, axis=0))
+        unroll_outs = _as_list(mx.sym.concat(*unroll_outs, dim=0))
+        unroll_outs.extend(states)
+        result = mx.sym.Group(unroll_outs)
+        e_2 = result.bind(ctx=default_context(),
+            args={name: array.copy() for name, array in args.items() if name != "i"},
+            args_grad={name: array.copy() for name, array in args_grad.items() if name != "i"},
+        )
+        for case_id in range(100):
+            out_grads = [_array(arr.shape) for arr in e_1.outputs]
+            args = {name: array.copy() for name, array in args.items()}
+            e_1.forward(is_train=True, **args)
+            e_1.backward(out_grads)
+            args = {name: array.copy() for name, array in args.items() if name != "i"}
+            e_2.forward(is_train=True, **args)
+            e_2.backward(out_grads)
+            assert len(e_1.outputs) == len(e_2.outputs)
+            for x, y in zip(e_1.outputs, e_2.outputs):
+                x = x.asnumpy()
+                y = y.asnumpy()
+                assert_almost_equal(x, y, rtol=1e-4, atol=1e-4)
+            grad_keys = list(e_2.grad_dict.keys())
+            e_1_grad = [e_1.grad_dict[x] for x in grad_keys]
+            e_2_grad = [e_2.grad_dict[x] for x in grad_keys]
+            for x, y in zip(e_1_grad, e_2_grad):
+                x = x.asnumpy()
+                y = y.asnumpy()
+                assert_almost_equal(x, y, rtol=1e-4, atol=1e-4)
+
+
+if __name__ == '__main__':
+    import nose
+    nose.runmodule()
