This is an automated email from the ASF dual-hosted git repository.
zhasheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 54632bc [MXNET-626] Add while_loop (#11566)
54632bc is described below
commit 54632bcb38064a0ed1f23dd652897562d3a0036a
Author: Junru Shao <[email protected]>
AuthorDate: Wed Jul 18 17:09:10 2018 -0700
[MXNET-626] Add while_loop (#11566)
* Add while_loop
* Avoid input/output overlap for nnvm graph cut
* Add more testcases
* Enhance test 4.2
* Add more complicated testcases; Add testcase for nested loop
* Check unused loop_vars in while_loop
* Add testcases for RNN
* Make lint happy
* Make lint happy
* Address TODOs
* Fix flaky test for while_loop
* Address comments
* Improve docstring
* Improve error message
* Add benchmark code
* Update benchmarks
* Allow sparse types
* Make max_iterations default to None
* Add while_loop to docs/api/python/{symbol|ndarray}/contrib.md
* Pad imperative while_loop so that it has the same shape as the symbolic
one
* Add example result into the example section
* Remove unused class member
* Rename unittest to test_contrib_control_flow.py
* Update docstring
* Update docstring
* Trigger CI
* Change threshold for assert_almost_equal
* Trigger CI
* Address comments from szha
* Rewrite benchmark code
* Fix sphinx warning
---
3rdparty/tvm | 2 +-
.../python/control_flow/{rnn.py => foreach_rnn.py} | 12 +-
benchmark/python/control_flow/rnn.py | 273 +++---
.../control_flow/{rnn.py => while_loop_rnn.py} | 98 ++-
docs/api/python/ndarray/contrib.md | 1 +
docs/api/python/symbol/contrib.md | 1 +
python/mxnet/ndarray/contrib.py | 174 +++-
python/mxnet/symbol/contrib.py | 222 ++++-
src/operator/control_flow.cc | 563 +++++++++++-
src/operator/subgraph_op_common.cc | 9 +-
src/operator/subgraph_op_common.h | 14 +-
tests/python/unittest/test_contrib_control_flow.py | 978 +++++++++++++++++++++
12 files changed, 2133 insertions(+), 214 deletions(-)
diff --git a/3rdparty/tvm b/3rdparty/tvm
index 6ab4da6..290226e 160000
--- a/3rdparty/tvm
+++ b/3rdparty/tvm
@@ -1 +1 @@
-Subproject commit 6ab4da6783417d8afdeb6b0426b44959b2afc709
+Subproject commit 290226e1c9adbb3e598f9ed9184018df1c12be33
diff --git a/benchmark/python/control_flow/rnn.py
b/benchmark/python/control_flow/foreach_rnn.py
similarity index 92%
copy from benchmark/python/control_flow/rnn.py
copy to benchmark/python/control_flow/foreach_rnn.py
index 5e41b75..4ce7a42 100644
--- a/benchmark/python/control_flow/rnn.py
+++ b/benchmark/python/control_flow/foreach_rnn.py
@@ -157,7 +157,8 @@ if __name__ == '__main__':
ndim = 512
seq_len = 100
batch_sizes = [1, 32]
- cells = [gluon.rnn.GRUCell(ndim, prefix='rnn_'),
+ cells = [gluon.rnn.RNNCell(ndim, prefix='rnn_'),
+ gluon.rnn.GRUCell(ndim, prefix='rnn_'),
gluon.rnn.LSTMCell(ndim, prefix='rnn_')]
ctxs = [mx.cpu(0), mx.gpu(0)]
for cell in cells:
@@ -165,8 +166,13 @@ if __name__ == '__main__':
for batch_size in batch_sizes:
if len(get_gpus()) == 0 and ctx == mx.gpu(0):
continue
-
- if isinstance(cell, gluon.rnn.GRUCell):
+ if isinstance(cell, gluon.rnn.RNNCell):
+ rnn_data = mx.nd.normal(loc=0, scale=1, shape=(seq_len,
batch_size, ndim),
+ ctx=mx.cpu(0))
+ states = []
+ states.append(mx.nd.normal(loc=0, scale=1,
shape=(batch_size, ndim),
+ ctx=mx.cpu(0)))
+ elif isinstance(cell, gluon.rnn.GRUCell):
rnn_data = mx.nd.normal(loc=0, scale=1, shape=(seq_len,
batch_size, ndim),
ctx=mx.cpu(0))
states = []
diff --git a/benchmark/python/control_flow/rnn.py
b/benchmark/python/control_flow/rnn.py
index 5e41b75..8a44a9c 100644
--- a/benchmark/python/control_flow/rnn.py
+++ b/benchmark/python/control_flow/rnn.py
@@ -15,175 +15,128 @@
# specific language governing permissions and limitations
# under the License.
+from __future__ import print_function
+from six.moves import range
+
+import argparse
import subprocess
+from itertools import product
+from time import time
+
import mxnet as mx
+import numpy as np
from mxnet import gluon
-import time
-import copy
-def get_gpus():
- """
- return a list of GPUs
- """
- try:
- re = subprocess.check_output(["nvidia-smi", "-L"],
universal_newlines=True)
- except OSError:
- return []
- return range(len([i for i in re.split('\n') if 'GPU' in i]))
-class TestRNNLayer(gluon.HybridBlock):
- def __init__(self, cell, prefix=None, params=None):
- super(TestRNNLayer, self).__init__(prefix=prefix, params=params)
+_parser = argparse.ArgumentParser(description='Benchmark foreach and
while_loop on RNN tasks.')
+_parser.add_argument('--benchmark', choices=["foreach", "while_loop"],
required=True)
+_parser.add_argument('--warmup_rounds', type=int, default=20)
+_parser.add_argument('--test_rounds', type=int, default=100)
+args = _parser.parse_args()
+
+
+class ForeachRNN(gluon.HybridBlock):
+ def __init__(self, cell, length, prefix=None, params=None):
+ super(ForeachRNN, self).__init__(prefix=prefix, params=params)
+ self.length = length
self.cell = cell
def hybrid_forward(self, F, inputs, states):
out, states = F.contrib.foreach(self.cell, inputs, states)
return out
-def benchmark_rnn(cell, rnn_data, states):
- ctx = rnn_data.context
- num_batches = 20
-
- # Imperative
- cell0 = copy.deepcopy(cell)
- layer0 = TestRNNLayer(cell0)
- layer0.initialize(ctx=ctx)
-
- # Hybridize
- cell1 = copy.deepcopy(cell)
- cell1.hybridize()
- layer1 = TestRNNLayer(cell1)
- layer1.initialize(ctx=ctx)
-
- # Hybridize
- cell2 = copy.deepcopy(cell)
- layer2 = TestRNNLayer(cell2)
- layer2.initialize(ctx=ctx)
- layer2.hybridize()
- layer2(rnn_data, states)
-
- # Hybridize
- cell3 = copy.deepcopy(cell)
- cell3.hybridize(static_alloc=True)
- layer3 = TestRNNLayer(cell3)
- layer3.initialize(ctx=ctx)
-
- tic = time.time()
- for i in range(num_batches):
- res0 = layer0(rnn_data, states)
- mx.nd.waitall()
- print("Imperative inference takes " + str(time.time() - tic))
-
- tic = time.time()
- for i in range(num_batches):
- res1 = layer1(rnn_data, states)
- mx.nd.waitall()
- print("Hybrid-cell inference takes " + str(time.time() - tic))
-
- tic = time.time()
- for i in range(num_batches):
- res3 = layer3(rnn_data, states)
- mx.nd.waitall()
- print("Static-hybrid-cell inference takes " + str(time.time() - tic))
-
- tic = time.time()
- for i in range(num_batches):
- res2 = layer2(rnn_data, states)
- mx.nd.waitall()
- print("Hybrid inference takes " + str(time.time() - tic))
-
- layer2.export("foreach_rnn")
- symnet = mx.symbol.load('foreach_rnn-symbol.json')
- args1 = {}
- params = layer2.collect_params()
- for key in params.keys():
- args1[key] = params[key].data()
- args1['data0'] = rnn_data
- for i in range(len(states)):
- args1['data' + str(i + 1)] = states[i]
- exe = symnet.bind(ctx=ctx, args=args1)
- tic = time.time()
- for i in range(num_batches):
- exe.forward(is_train=False)
- mx.nd.waitall()
- print("Symbol inference takes " + str(time.time() - tic))
-
- tic = time.time()
- for i in range(num_batches):
- with mx.autograd.record():
- res0 = layer0(rnn_data, states)
- res0.backward()
- mx.nd.waitall()
- print("Imperative training takes " + str(time.time() - tic))
-
- tic = time.time()
- for i in range(num_batches):
- with mx.autograd.record():
- res1 = layer1(rnn_data, states)
- res1.backward()
- mx.nd.waitall()
- print("Hybrid-cell training takes " + str(time.time() - tic))
-
- tic = time.time()
- for i in range(num_batches):
- with mx.autograd.record():
- res3 = layer3(rnn_data, states)
- res3.backward()
- mx.nd.waitall()
- print("Static-hybrid-cell training takes " + str(time.time() - tic))
-
- tic = time.time()
- for i in range(num_batches):
- with mx.autograd.record():
- res2 = layer2(rnn_data, states)
- res2.backward()
- mx.nd.waitall()
- print("Hybrid training takes " + str(time.time() - tic))
-
- # gradients for the backward of the foreach symbol
- args_grad1 = {}
- for key in args1.keys():
- args_grad1[key] = mx.nd.empty(args1[key].shape, ctx=ctx)
- exe = symnet.bind(ctx=ctx, args=args1, args_grad=args_grad1)
- tic = time.time()
- for i in range(num_batches):
- exe.forward(is_train=True)
- exe.backward(res2)
- mx.nd.waitall()
- print("Symbol training takes " + str(time.time() - tic))
- print("")
-
-if __name__ == '__main__':
- ndim = 512
- seq_len = 100
+
+class WhileRNN(gluon.HybridBlock):
+ def __init__(self, cell, length, prefix=None, params=None):
+ super(WhileRNN, self).__init__(prefix=prefix, params=params)
+ self.length = length
+ self.cell = cell
+
+ def hybrid_forward(self, F, inputs, states):
+ def _func(*states):
+ i = states[0]
+ s = states[1: ]
+ data = inputs.take(i).squeeze(axis=0)
+ out, new_s = self.cell(data, s)
+ new_s = [i + 1] + new_s
+ return out, new_s
+ out, states = F.contrib.while_loop(
+ cond=lambda i, *_: i < self.length,
+ func=_func,
+ loop_vars=states,
+ max_iterations=self.length,
+ )
+ assert len(out) == 1
+ return out[0]
+
+
+def _zeros(shape, ctx):
+ return mx.nd.zeros(shape=shape, ctx=ctx)
+
+
+def _array(shape, ctx):
+ return mx.nd.normal(loc=0.0, scale=1.0, shape=shape, ctx=ctx)
+
+
+def _get_gpus():
+ try:
+ re = subprocess.check_output(["nvidia-smi", "-L"],
universal_newlines=True)
+ except OSError:
+ return []
+ return range(len([i for i in re.split('\n') if 'GPU' in i]))
+
+
+def run_benchmark(cell_type, ctx, seq_len, batch_size, hidden_dim):
+ obj = {"foreach": ForeachRNN, "while_loop": WhileRNN}[args.benchmark]
+ inputs = _array((seq_len, batch_size, hidden_dim), ctx)
+ states = [_array((batch_size, hidden_dim), ctx) for _ in
cell_type(0).state_info()]
+ if args.benchmark == "while_loop":
+ states.insert(0, _zeros((1, ), ctx))
+
+ for is_train, is_hyb_cell, is_hyb_layer in product([True, False], [False,
True], [False, True]):
+ cell = cell_type(hidden_dim)
+ if is_hyb_cell:
+ cell.hybridize(static_alloc=True)
+ layer = obj(cell, seq_len)
+ layer.initialize(ctx=ctx)
+ if is_hyb_layer:
+ layer.hybridize(static_alloc=True)
+ print("is_train = %r, hybridize_cell = %r, hybridize_layer = %r" %
(is_train, is_hyb_cell, is_hyb_layer))
+ times = []
+ for _ in range(args.warmup_rounds + args.test_rounds):
+ tick = time()
+ if not is_train:
+ res = layer(inputs, states)
+ else:
+ with mx.autograd.record():
+ res = layer(inputs, states)
+ if is_train:
+ res.backward()
+ mx.nd.waitall()
+ tock = time()
+ times.append((tock - tick) * 1000.0)
+ times = times[args.warmup_rounds: ]
+ print("Time used: mean = %.3f ms, std = %.3f ms" % (np.mean(times),
np.std(times)))
+
+
+def main():
+ # testing configurations
+ cell_types = [gluon.rnn.RNNCell,
+ gluon.rnn.GRUCell,
+ gluon.rnn.LSTMCell]
+ ctxs = [mx.cpu(0)] + [mx.gpu(i) for i in _get_gpus()]
+ seq_lens = [100]
batch_sizes = [1, 32]
- cells = [gluon.rnn.GRUCell(ndim, prefix='rnn_'),
- gluon.rnn.LSTMCell(ndim, prefix='rnn_')]
- ctxs = [mx.cpu(0), mx.gpu(0)]
- for cell in cells:
- for ctx in ctxs:
- for batch_size in batch_sizes:
- if len(get_gpus()) == 0 and ctx == mx.gpu(0):
- continue
-
- if isinstance(cell, gluon.rnn.GRUCell):
- rnn_data = mx.nd.normal(loc=0, scale=1, shape=(seq_len,
batch_size, ndim),
- ctx=mx.cpu(0))
- states = []
- states.append(mx.nd.normal(loc=0, scale=1,
shape=(batch_size, ndim),
- ctx=mx.cpu(0)))
- elif isinstance(cell, gluon.rnn.LSTMCell):
- rnn_data = mx.nd.normal(loc=0, scale=1, shape=(seq_len,
batch_size, ndim),
- ctx=mx.cpu(0))
- states = []
- states.append(mx.nd.normal(loc=0, scale=1,
shape=(batch_size, ndim),
- ctx=mx.cpu(0)))
- states.append(mx.nd.normal(loc=0, scale=1,
shape=(batch_size, ndim),
- ctx=mx.cpu(0)))
- if ctx == mx.gpu(0):
- dev = "GPU"
- else:
- dev = "CPU"
- print("Benchmark {} in {} (batch size:
{})".format(cell._alias(), dev,
- batch_size))
- benchmark_rnn(cell, rnn_data, states)
+ hidden_dims = [512]
+ print("--------------------------------------")
+ print("Benchmarking", args.benchmark)
+ for cell_type, ctx, seq_len, batch_size, hidden_dim in product( \
+ cell_types, ctxs, seq_lens, batch_sizes, hidden_dims):
+ print("--------------------------------------")
+ print("cell: %s ctx: %s length: %d batch size: %d dim: %d" % \
+ (cell_type.__name__, str(ctx), seq_len, batch_size, hidden_dim))
+ run_benchmark(cell_type, ctx, seq_len, batch_size, hidden_dim)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/benchmark/python/control_flow/rnn.py
b/benchmark/python/control_flow/while_loop_rnn.py
similarity index 67%
copy from benchmark/python/control_flow/rnn.py
copy to benchmark/python/control_flow/while_loop_rnn.py
index 5e41b75..42aaee5 100644
--- a/benchmark/python/control_flow/rnn.py
+++ b/benchmark/python/control_flow/while_loop_rnn.py
@@ -15,6 +15,8 @@
# specific language governing permissions and limitations
# under the License.
+# Code borrowed from ./benchmark/python/control_flow/foreach_rnn.py
+
import subprocess
import mxnet as mx
from mxnet import gluon
@@ -32,40 +34,53 @@ def get_gpus():
return range(len([i for i in re.split('\n') if 'GPU' in i]))
class TestRNNLayer(gluon.HybridBlock):
- def __init__(self, cell, prefix=None, params=None):
+ def __init__(self, cell, length, prefix=None, params=None):
super(TestRNNLayer, self).__init__(prefix=prefix, params=params)
+ self.length = length
self.cell = cell
def hybrid_forward(self, F, inputs, states):
- out, states = F.contrib.foreach(self.cell, inputs, states)
- return out
-
-def benchmark_rnn(cell, rnn_data, states):
+ def _func(*states):
+ i = states[0]
+ s = states[1: ]
+ data = inputs.take(i).squeeze(axis=0)
+ out, new_s = self.cell(data, s)
+ new_s = [i + 1] + new_s
+ return out, new_s
+ out, states = F.contrib.while_loop(
+ cond=lambda i, *_: i < self.length,
+ func=_func,
+ loop_vars=states,
+ max_iterations=self.length,
+ )
+ return out + states
+
+def benchmark_rnn(cell, rnn_data, states, length):
ctx = rnn_data.context
num_batches = 20
# Imperative
cell0 = copy.deepcopy(cell)
- layer0 = TestRNNLayer(cell0)
+ layer0 = TestRNNLayer(cell0, length)
layer0.initialize(ctx=ctx)
- # Hybridize
+ # Hybrid-cell
cell1 = copy.deepcopy(cell)
cell1.hybridize()
- layer1 = TestRNNLayer(cell1)
+ layer1 = TestRNNLayer(cell1, length)
layer1.initialize(ctx=ctx)
- # Hybridize
+ # Hybrid
cell2 = copy.deepcopy(cell)
- layer2 = TestRNNLayer(cell2)
+ layer2 = TestRNNLayer(cell2, length)
layer2.initialize(ctx=ctx)
layer2.hybridize()
layer2(rnn_data, states)
- # Hybridize
+ # Static-hybrid-cell
cell3 = copy.deepcopy(cell)
cell3.hybridize(static_alloc=True)
- layer3 = TestRNNLayer(cell3)
+ layer3 = TestRNNLayer(cell3, length)
layer3.initialize(ctx=ctx)
tic = time.time()
@@ -92,8 +107,8 @@ def benchmark_rnn(cell, rnn_data, states):
mx.nd.waitall()
print("Hybrid inference takes " + str(time.time() - tic))
- layer2.export("foreach_rnn")
- symnet = mx.symbol.load('foreach_rnn-symbol.json')
+ layer2.export("while_loop_rnn")
+ symnet = mx.symbol.load('while_loop_rnn-symbol.json')
args1 = {}
params = layer2.collect_params()
for key in params.keys():
@@ -112,7 +127,7 @@ def benchmark_rnn(cell, rnn_data, states):
for i in range(num_batches):
with mx.autograd.record():
res0 = layer0(rnn_data, states)
- res0.backward()
+ res0[0].backward()
mx.nd.waitall()
print("Imperative training takes " + str(time.time() - tic))
@@ -120,7 +135,7 @@ def benchmark_rnn(cell, rnn_data, states):
for i in range(num_batches):
with mx.autograd.record():
res1 = layer1(rnn_data, states)
- res1.backward()
+ res1[0].backward()
mx.nd.waitall()
print("Hybrid-cell training takes " + str(time.time() - tic))
@@ -128,7 +143,7 @@ def benchmark_rnn(cell, rnn_data, states):
for i in range(num_batches):
with mx.autograd.record():
res3 = layer3(rnn_data, states)
- res3.backward()
+ res3[0].backward()
mx.nd.waitall()
print("Static-hybrid-cell training takes " + str(time.time() - tic))
@@ -136,14 +151,15 @@ def benchmark_rnn(cell, rnn_data, states):
for i in range(num_batches):
with mx.autograd.record():
res2 = layer2(rnn_data, states)
- res2.backward()
+ res2[0].backward()
mx.nd.waitall()
print("Hybrid training takes " + str(time.time() - tic))
- # gradients for the backward of the foreach symbol
+ # gradients for the backward of the while_loop symbol
args_grad1 = {}
for key in args1.keys():
- args_grad1[key] = mx.nd.empty(args1[key].shape, ctx=ctx)
+ if key != "data1":
+ args_grad1[key] = mx.nd.empty(args1[key].shape, ctx=ctx)
exe = symnet.bind(ctx=ctx, args=args1, args_grad=args_grad1)
tic = time.time()
for i in range(num_batches):
@@ -154,10 +170,15 @@ def benchmark_rnn(cell, rnn_data, states):
print("")
if __name__ == '__main__':
+ def _zeros(shape):
+ return mx.nd.zeros(shape=shape, ctx=mx.cpu(0))
+ def _array(shape):
+ return mx.nd.normal(loc=0.0, scale=1.0, shape=shape, ctx=mx.cpu(0))
ndim = 512
seq_len = 100
batch_sizes = [1, 32]
- cells = [gluon.rnn.GRUCell(ndim, prefix='rnn_'),
+ cells = [gluon.rnn.RNNCell(ndim, prefix='rnn_'),
+ gluon.rnn.GRUCell(ndim, prefix='rnn_'),
gluon.rnn.LSTMCell(ndim, prefix='rnn_')]
ctxs = [mx.cpu(0), mx.gpu(0)]
for cell in cells:
@@ -165,25 +186,28 @@ if __name__ == '__main__':
for batch_size in batch_sizes:
if len(get_gpus()) == 0 and ctx == mx.gpu(0):
continue
-
+ if isinstance(cell, gluon.rnn.RNNCell):
+ rnn_data = _array((seq_len, batch_size, ndim))
+ states = [
+ _zeros((1, )),
+ _array((batch_size, ndim)),
+ ]
if isinstance(cell, gluon.rnn.GRUCell):
- rnn_data = mx.nd.normal(loc=0, scale=1, shape=(seq_len,
batch_size, ndim),
- ctx=mx.cpu(0))
- states = []
- states.append(mx.nd.normal(loc=0, scale=1,
shape=(batch_size, ndim),
- ctx=mx.cpu(0)))
+ rnn_data = _array((seq_len, batch_size, ndim))
+ states = [
+ _zeros((1, )),
+ _array((batch_size, ndim)),
+ ]
elif isinstance(cell, gluon.rnn.LSTMCell):
- rnn_data = mx.nd.normal(loc=0, scale=1, shape=(seq_len,
batch_size, ndim),
- ctx=mx.cpu(0))
- states = []
- states.append(mx.nd.normal(loc=0, scale=1,
shape=(batch_size, ndim),
- ctx=mx.cpu(0)))
- states.append(mx.nd.normal(loc=0, scale=1,
shape=(batch_size, ndim),
- ctx=mx.cpu(0)))
+ rnn_data = _array((seq_len, batch_size, ndim))
+ states = [
+ _zeros((1, )),
+ _array((batch_size, ndim)),
+ _array((batch_size, ndim)),
+ ]
if ctx == mx.gpu(0):
dev = "GPU"
else:
dev = "CPU"
- print("Benchmark {} in {} (batch size:
{})".format(cell._alias(), dev,
- batch_size))
- benchmark_rnn(cell, rnn_data, states)
+ print("Benchmark {} in {} (batch size:
{})".format(cell._alias(), dev, batch_size))
+ benchmark_rnn(cell, rnn_data, states, seq_len)
diff --git a/docs/api/python/ndarray/contrib.md
b/docs/api/python/ndarray/contrib.md
index 36a2c15..0cf8724 100644
--- a/docs/api/python/ndarray/contrib.md
+++ b/docs/api/python/ndarray/contrib.md
@@ -53,6 +53,7 @@ In the rest of this document, we list routines provided by
the `ndarray.contrib`
ifft
quantize
foreach
+ while_loop
```
## API Reference
diff --git a/docs/api/python/symbol/contrib.md
b/docs/api/python/symbol/contrib.md
index 6647165..ba43f2d 100644
--- a/docs/api/python/symbol/contrib.md
+++ b/docs/api/python/symbol/contrib.md
@@ -53,6 +53,7 @@ In the rest of this document, we list routines provided by
the `symbol.contrib`
ifft
quantize
foreach
+ while_loop
```
## API Reference
diff --git a/python/mxnet/ndarray/contrib.py b/python/mxnet/ndarray/contrib.py
index b1f065e..b67cf5a 100644
--- a/python/mxnet/ndarray/contrib.py
+++ b/python/mxnet/ndarray/contrib.py
@@ -28,7 +28,7 @@ try:
except ImportError:
pass
-__all__ = ["rand_zipfian"]
+__all__ = ["rand_zipfian", "foreach", "while_loop"]
# pylint: disable=line-too-long
def rand_zipfian(true_classes, num_sampled, range_max, ctx=None):
@@ -191,3 +191,175 @@ def foreach(body, data, init_states):
if not_data_list and len(outputs) == 1:
outputs = outputs[0]
return (outputs, states)
+
+
+def while_loop(cond, func, loop_vars, max_iterations=None):
+ """Run a while loop with user-defined computation and loop condition.
+
+    This operator simulates a while loop which iteratively does customized
computation
+ as long as the condition is satisfied.
+
+    `loop_vars` is a list of NDArrays that the computation uses.
+
+ `cond` is a user-defined function, used as the loop condition.
+ It consumes `loop_vars`, and produces a scalar MXNet NDArray,
+ indicating the termination of the loop.
+ The loop ends when `cond` returns false (zero).
+ The `cond` is variadic, and its signature should be
+ `cond(*loop_vars) => NDArray`.
+
+ `func` is a user-defined function, used as the loop body.
+ It also consumes `loop_vars`, and produces `step_output` and
`new_loop_vars` at each step.
+    In each step, `step_output` should contain the same number of elements.
+ Through all steps, the i-th element of `step_output` should have the same
shape and dtype.
+ Also, `new_loop_vars` should contain the same number of elements as
`loop_vars`,
+ and the corresponding element should have the same shape and dtype.
+ The `func` is variadic, and its signature should be
+ `func(*loop_vars) => (List[NDArray] step_output, List[NDArray]
new_loop_vars)`.
+
+ `max_iterations` is a scalar that defines the maximum number of iterations
allowed.
+
+ This function returns two lists.
+ The first list has the length of `|step_output|`,
+ in which the i-th element are all i-th elements of
+ `step_output` from all steps, stacked along axis 0.
+ The second list has the length of `|loop_vars|`,
+ which represents final states of loop variables.
+
+ .. warning::
+
+ For now, the axis 0 of all NDArrays in the first list are
`max_iterations`,
+ due to lack of dynamic shape inference.
+
+ .. warning::
+
+ When `cond` is never satisfied, we assume `step_output` is empty,
+ because it cannot be inferred. This is different from the symbolic
version.
+
+ Parameters
+ ----------
+ cond: a Python function.
+ The loop condition.
+ func: a Python function.
+ The loop body.
+ loop_vars: list of NDArrays.
+ The initial values of the loop variables.
+ max_iterations: a python int.
+ Maximum number of iterations.
+
+ Returns
+ ------
+ outputs: list of NDArrays
+ stacked output from each step
+ states: list of NDArrays
+ final state
+
+ Examples
+ --------
+ >>> cond = lambda i, s: i <= 5
+ >>> func = lambda i, s: ([i + s], [i + 1, s + i])
+ >>> loop_vars = (mx.nd.array([0], dtype="int64"), mx.nd.array([1],
dtype="int64"))
+ >>> outputs, states = mx.nd.contrib.while_loop(cond, func, loop_vars,
max_iterations=10)
+ >>> outputs
+ [
+ [[ 1]
+ [ 2]
+ [ 4]
+ [ 7]
+ [11]
+ [16]
+ [...] # undefined value
+ [...]
+ [...]
+ [...]]
+ <NDArray 6x1 @cpu(0)>]
+ >>> states
+ [
+ [6]
+ <NDArray 1 @cpu(0)>,
+ [16]
+ <NDArray 1 @cpu(0)>]
+ """
+ def _to_python_scalar(inputs, type_, name):
+ """Converts "inputs", possibly typed mxnet NDArray, a numpy ndarray,
other python types,
+ to the given type
+ """
+ if isinstance(inputs, ndarray.NDArray):
+ inputs = inputs.asscalar()
+ try:
+ inputs = type_(inputs)
+ except:
+ raise ValueError("Cannot convert %s to python %s" % (name,
type_.__name__))
+ return inputs
+
+ def _to_ndarray_tuple(inputs, name):
+ """Converts "inputs", possibly a single mxnet NDArray, a list of mxnet
NDArray,
+ a tuple of mxnet NDArray, into a tuple of NDArray
+ """
+ if isinstance(inputs, list):
+ inputs = tuple(inputs)
+ if isinstance(inputs, ndarray.NDArray):
+ inputs = (inputs, )
+ if not isinstance(inputs, tuple):
+ raise ValueError("%s must be an NDArray, or a tuple or list of
NDArrays" % (name, ))
+ for item in inputs:
+ if not isinstance(item, ndarray.NDArray):
+ raise ValueError("%s must be an NDArray, or a tuple or list of
NDArrays" % (name, ))
+ return inputs
+
+ def _func_wrapper(loop_vars):
+ """This wrapper unifies
+ "func: loop_vars -> new_loop_vars"
+ and "func: loop_vars -> (step_output, new_loop_vars)"
+ into "func: loop_vars -> (None or tuple of step_outputs, tuple of
new_loop_vars)
+ """
+ step_output, new_loop_vars = func(*loop_vars)
+ if step_output is None:
+ step_output = []
+ if new_loop_vars is None:
+ new_loop_vars = []
+ step_output = _to_ndarray_tuple(step_output, "step_output")
+ new_loop_vars = _to_ndarray_tuple(new_loop_vars, "new_loop_vars")
+ if len(loop_vars) != len(new_loop_vars):
+ raise ValueError("The length of loop_vars should be consistent
during the loop")
+ return step_output, new_loop_vars
+
+ if max_iterations is None:
+ raise ValueError("max_iterations should be specified")
+ max_iterations = _to_python_scalar(max_iterations, int, "max_iteration")
+ loop_vars = _to_ndarray_tuple(loop_vars, "loop_vars")
+    # It should work fine even if loop_vars are empty,
+ # but it is semantically unnecessary to include this case.
+ if len(loop_vars) == 0:
+ raise ValueError("loop_vars should contain at least one element")
+
+ steps = 0
+ outputs = []
+ while steps < max_iterations and \
+ _to_python_scalar(cond(*loop_vars), bool, "Return value of cond"):
# loop condition
+ step_output, loop_vars = _func_wrapper(loop_vars)
+ outputs.append(step_output)
+ steps += 1
+ if len(outputs) != steps or len(step_output) != len(outputs[0]):
+ raise ValueError("Number of elements in step_output should be the
same in each step")
+ stacked_outputs = []
+ for i_th, items in enumerate(zip(*outputs), 1):
+ # `mx.ndarray.pad` only support 4-D or 5-D inputs for now
+ # so we could not use it.
+ items = [x.expand_dims(0) for x in items]
+ if steps != max_iterations and items:
+ pad_shape = [max_iterations - steps] + list(items[0].shape[1: ])
+ pad = ndarray.empty(
+ shape=pad_shape,
+ ctx=items[0].context,
+ dtype=items[0].dtype,
+ )
+ items = list(items) + [pad]
+ try:
+ stacked_outputs.append(ndarray.op.concat(*items, dim=0))
+ except ValueError:
+ raise ValueError("\n".join(
+ ["Shapes of %d-th elements in step_outputs are inconsistent,
which are:" % i_th] +
+ [" Step %d, shape is %s" % (i, str(x.shape)) for i, x in
enumerate(items)]
+ ))
+ return stacked_outputs, list(loop_vars)
diff --git a/python/mxnet/symbol/contrib.py b/python/mxnet/symbol/contrib.py
index 28bb507..2c11921 100644
--- a/python/mxnet/symbol/contrib.py
+++ b/python/mxnet/symbol/contrib.py
@@ -34,7 +34,7 @@ from ..base import _LIB, check_call
from ..base import SymbolHandle, _as_list
from ..attribute import AttrScope
-__all__ = ["rand_zipfian", "foreach"]
+__all__ = ["rand_zipfian", "foreach", "while_loop"]
def rand_zipfian(true_classes, num_sampled, range_max):
"""Draw random samples from an approximately log-uniform or Zipfian
distribution.
@@ -336,3 +336,223 @@ def foreach(body, data, init_states, name="foreach"):
states = states[0]
return (outs, states)
+
+def while_loop(cond, func, loop_vars, max_iterations=None, name="while_loop"):
+ """Run a while loop with user-defined computation and loop condition.
+
+    This operator simulates a while loop which iteratively does customized
computation
+ as long as the condition is satisfied.
+
+    `loop_vars` is a list of Symbols that the computation uses.
+
+ `cond` is a user-defined function, used as the loop condition.
+ It consumes `loop_vars`, and produces a scalar MXNet symbol,
+ indicating the termination of the loop.
+ The loop ends when `cond` returns false (zero).
+ The `cond` is variadic, and its signature should be
+ `cond(*loop_vars) => Symbol`.
+
+ `func` is a user-defined function, used as the loop body.
+ It also consumes `loop_vars`, and produces `step_output` and
`new_loop_vars` at each step.
+    In each step, `step_output` should contain the same number of elements.
+ Through all steps, the i-th element of `step_output` should have the same
shape and dtype.
+ Also, `new_loop_vars` should contain the same number of elements as
`loop_vars`,
+ and the corresponding element should have the same shape and dtype.
+ The `func` is variadic, and its signature should be
+ `func(*loop_vars) => (List[Symbol] step_output, List[Symbol]
new_loop_vars)`.
+
+ `max_iterations` is a scalar that defines the maximum number of iterations
allowed.
+
+ This function returns two lists.
+ The first list has the length of `|step_output|`,
+ in which the i-th element are all i-th elements of
+ `step_output` from all steps, stacked along axis 0.
+ The second list has the length of `|loop_vars|`,
+ which represents final states of loop variables.
+
+ .. warning::
+
+ For now, the axis 0 of all Symbols in the first list are
`max_iterations`,
+ due to lack of dynamic shape inference.
+
+ .. warning::
+
+ Even if `cond` is never satisfied,
+ while_loop returns a list of outputs with inferred dtype and shape.
+    This is different from the NDArray version,
+ where in this case `step_outputs` are assumed as an empty list.
+
+ Parameters
+ ----------
+ cond: a Python function.
+ The loop condition.
+ func: a Python function.
+ The loop body.
+ loop_vars: list of Symbol.
+ The initial values of the loop variables.
+ max_iterations: a python int.
+ Maximum number of iterations.
+
+ Returns
+ ------
+ outputs: list of Symbols
+ stacked output from each step
+ states: list of Symbols
+ final state
+
+ Examples
+ --------
+ >>> cond = lambda i, s: i <= 5
+ >>> func = lambda i, s: ([i + s], [i + 1, s + i])
+ >>> loop_vars = (mx.sym.var('i'), mx.sym.var('s'))
+ >>> outputs, states = mx.sym.contrib.while_loop(cond, func, loop_vars,
max_iterations=10)
+ """
+ def _to_python_scalar(inputs, type_, name):
+ """Converts "inputs", possibly typed mxnet NDArray, a numpy ndarray,
other python types,
+ to the given type
+ """
+ if hasattr(inputs, "asscalar"):
+ inputs = inputs.asscalar()
+ try:
+ inputs = type_(inputs)
+ except:
+ raise ValueError("Cannot convert %s to python %s" % (name,
type_.__name__))
+ return inputs
+
+ def _to_symbol_tuple(inputs, name):
+ """Converts "inputs", possibly a single mxnet Symbol, a list of mxnet
Symbol,
+ a tuple of mxnet Symbol, into a tuple of Symbol
+ """
+ if isinstance(inputs, list):
+ inputs = tuple(inputs)
+ if isinstance(inputs, Symbol):
+ inputs = (inputs, )
+ if not isinstance(inputs, tuple):
+ raise ValueError("%s must be a Symbol, or a tuple or list of
Symbol" % (name, ))
+ for item in inputs:
+ if not isinstance(item, Symbol):
+ raise ValueError("%s must be a Symbol, or a tuple or list of
Symbol" % (name, ))
+ return inputs
+
+ def _cond_wrapper(loop_vars):
+ result = cond(*loop_vars)
+ if not isinstance(result, Symbol):
+ raise ValueError("Return of cond must be a Symbol")
+ return [], [result]
+
+ def _func_wrapper(loop_vars):
+ """This wrapper unifies
+ "func: loop_vars -> new_loop_vars"
+ and "func: loop_vars -> (step_output, new_loop_vars)"
+ into "func: loop_vars -> (list of step_outputs, tuple of new_loop_vars)
+ """
+ step_output, new_loop_vars = func(*loop_vars)
+ if step_output is None:
+ step_output = []
+ if new_loop_vars is None:
+ new_loop_vars = []
+ step_output = _to_symbol_tuple(step_output, "step_output")
+ new_loop_vars = _to_symbol_tuple(new_loop_vars, "new_loop_vars")
+ if len(loop_vars) != len(new_loop_vars):
+ raise ValueError("The number of loop_vars should be consistent
during the loop")
+ return list(step_output), list(new_loop_vars)
+
+ def _create_subgraph(graph_vars, graph_func, subgraph_name):
+ with AttrScope(__subgraph_name__=subgraph_name):
+ # create new variables with the same name,
+ # them feed them to the given func
+ new_graph_vars = [symbol.var(sym.name) for sym in graph_vars]
+ outputs, final_state = graph_func(new_graph_vars)
+ # first `num_out_data` elements belong to `outputs`
+ # other elements belong to `final_state`
+ num_out_data = len(outputs)
+ num_outputs = len(outputs) + len(final_state)
+ # nnvm cut-graph does not allow inputs and outputs overlap
+ # so we calculate the name of inputs, and copy outputs once it
overlaps with inputs
+ all_input_names = symbol.Group(outputs + final_state).list_inputs()
+ make_identity = lambda x: symbol.op.identity(x) if x.name in
all_input_names else x
+ # group all outputs of graph_func
+ graph = symbol.Group(list(map(make_identity, outputs +
final_state)))
+ return graph, num_out_data, num_outputs
+
+ def _union_inputs(*graphs):
+ # Given a list of graphs, each whose inputs are either from loop_vars
or other variables.
+ # 1) calculate a list `inputs`, the union of their inputs.
+ # 2) for each graph, determine in which indices their inputs reside in
`inputs`
+ # 3) for each variable in the input of `graph`, find which index it is
+ inputs = [] # List[Symbol], result of 1)
+ locs = [] # List[Tuple(List[Int], List[Int])], a list of
tuples,
+ # where tuples are results of 2) and 3)
+ input_id_to_loc = {} # Dict[int, int], given id(sym),
input_id_to_loc maps it
+ # to a `loc`, where inputs[loc] = sym
+ for graph in graphs:
+ # input_syms: all inputs to the `graph`
+ name_to_input_syms = {sym.name: sym for sym in
_get_graph_inputs(graph)}
+ # some loop_vars are inputs to `graph`, some are not
+ name_to_loop_vars = {sym.name: sym for sym in loop_vars}
+ # other inputs to `graph` created by cut_graph
+ name_to_cut_g_syms = {sym.list_outputs()[0]: sym for sym in
_cut_subgraph(graph)}
+ # also we collect the mapping from var's name to var's loc in
loop_vars
+ name_to_var_locs = {sym.name: i for i, sym in enumerate(loop_vars)}
+ # collect arguments for each subgraph
+ input_locs = [] # results from the second
step
+ var_locs = [-1] * len(loop_vars) # results from the third
step
+ for name in graph.list_inputs():
+ assert name in name_to_input_syms # it should obviously hold
+ # name -> sym
+ if name in name_to_loop_vars:
+ sym = name_to_loop_vars[name]
+ elif name in name_to_cut_g_syms:
+ sym = name_to_cut_g_syms[name]
+ else:
+ sym = copy.deepcopy(name_to_input_syms[name])
+ # do 2), and 1) is implicitly done
+ if id(sym) in input_id_to_loc:
+ loc = input_id_to_loc[id(sym)]
+ else:
+ loc = len(input_id_to_loc)
+ inputs.append(sym)
+ input_id_to_loc[id(sym)] = loc
+ input_locs.append(loc)
+ # do 3)
+ if name in name_to_var_locs:
+ var_locs[name_to_var_locs[name]] = len(input_locs) - 1
+ locs.append((input_locs, var_locs))
+ return inputs, locs
+ if max_iterations is None:
+ raise ValueError("max_iterations should be specified")
+ max_iterations = _to_python_scalar(max_iterations, int, "max_iteration")
+ loop_vars = _to_symbol_tuple(loop_vars, "loop_vars")
+    # It should work fine even if loop_vars are empty,
+ # but it is semantically unnecessary to include this case.
+ if len(loop_vars) == 0:
+ raise ValueError("loop_vars should contain at least one element")
+ # create graph for `cond'
+ cond_g, num_out_data, num_outputs = \
+ _create_subgraph(loop_vars, _cond_wrapper, name + "_cond")
+ assert num_out_data == 0
+ assert num_outputs == 1
+ # create graph for `func`
+ func_g, num_out_data, num_outputs = \
+ _create_subgraph(loop_vars, _func_wrapper, name + "_func")
+ # find symbols used in either cond_g or func_g
+ input_syms, ((cond_input_locs, _), (func_input_locs, func_var_locs)) = \
+ _union_inputs(cond_g, func_g)
+ for i_th, loc in enumerate(func_var_locs, 1):
+ if loc == -1:
+ raise ValueError("The %d-th loop_var doesn't involve into the
computation" % i_th)
+ result = symbol._internal._while_loop(
+ # [cond, func_g, *input_syms]
+ cond_g,
+ func_g,
+ *input_syms,
+ max_iterations=max_iterations,
+ cond_input_locs=cond_input_locs,
+ func_input_locs=func_input_locs,
+ func_var_locs=func_var_locs,
+ num_out_data=num_out_data,
+ num_outputs=num_outputs
+ )
+ outputs = [result[i] for i in range(num_out_data)]
+ final_loop_vars = [result[i] for i in range(num_out_data, num_outputs)]
+ return outputs, final_loop_vars
diff --git a/src/operator/control_flow.cc b/src/operator/control_flow.cc
index c091fdb..b00ed9b 100644
--- a/src/operator/control_flow.cc
+++ b/src/operator/control_flow.cc
@@ -480,6 +480,503 @@ ForeachGradient(const nnvm::NodePtr& n, const
std::vector<nnvm::NodeEntry>& ogra
return entries;
}
// Operator parameters shared by the _while_loop forward and backward ops.
struct WhileLoopParam : public dmlc::Parameter<WhileLoopParam> {
  int num_args;        // total number of inputs, including cond and func as two symbol inputs
  int num_outputs;     // total number of outputs (per-step data outputs + final loop_vars)
  int num_out_data;    // number of per-step data outputs, stored as a prefix of the outputs
  int max_iterations;  // upper bound on the number of loop iterations
  // `cond' and `func' each takes a subset of while_loop's inputs as inputs to their subgraphs
  // `cond_input_locs' contains indices of inputs fed to `cond', and
  // `func_input_locs' contains indices of inputs fed to `func'.
  // `func_var_locs' are indices in which input "variables" are stored in func's inputs.
  nnvm::Tuple<dim_t> cond_input_locs;
  nnvm::Tuple<dim_t> func_input_locs;
  nnvm::Tuple<dim_t> func_var_locs;
  DMLC_DECLARE_PARAMETER(WhileLoopParam) {
    DMLC_DECLARE_FIELD(num_args).set_lower_bound(2)
    .describe("Number of input arguments, including cond and func as two symbol inputs.");
    DMLC_DECLARE_FIELD(num_outputs).set_lower_bound(1)
    .describe("The number of outputs of the subgraph.");
    DMLC_DECLARE_FIELD(num_out_data).set_lower_bound(0)
    .describe("The number of outputs from the function body.");
    DMLC_DECLARE_FIELD(max_iterations).set_lower_bound(1)
    .describe("Maximum number of iterations.");
    DMLC_DECLARE_FIELD(cond_input_locs)
    .describe("The locations of cond's inputs in the given inputs.");
    DMLC_DECLARE_FIELD(func_input_locs)
    .describe("The locations of func's inputs in the given inputs.");
    DMLC_DECLARE_FIELD(func_var_locs)
    .describe("The locations of loop_vars among func's inputs.");
  }
};  // struct WhileLoopParam

DMLC_REGISTER_PARAMETER(WhileLoopParam);
+
// Per-instance runtime state of _while_loop.
// Extends LoopState (which caches the `func' subgraph op and the per-step op
// states needed by backward) with a cached op for `cond' and bookkeeping that
// maps while_loop's flat input list onto the cond/func subgraph input lists.
class WhileLoopState: public LoopState {
 public:
  WhileLoopParam params;
  // the actual number of steps taken in this while loop, <= max_iterations
  size_t n_iterations;
  CachedOpPtr cond_op;
  // abbrev for output_input_mapping
  // indicates to which index the output of `func' will be copied to the input of `cond'
  std::vector<int> oi_map;

  WhileLoopState(const WhileLoopParam &params, const Symbol &cond, const Symbol &func) :
                 LoopState(func),
                 params(params),
                 n_iterations(0U),
                 cond_op(LoopState::MakeSharedOp(cond)),
                 oi_map(params.func_var_locs.ndim(), -1) {
    const nnvm::Tuple<dim_t> &func_input_locs = params.func_input_locs;
    const nnvm::Tuple<dim_t> &func_var_locs = params.func_var_locs;
    const nnvm::Tuple<dim_t> &cond_input_locs = params.cond_input_locs;
    // For each loop variable, find which of cond's inputs (if any) refers to
    // the same while_loop input, so that the new value produced by `func' can
    // be forwarded into `cond' on the next iteration.
    for (size_t i = 0; i < func_var_locs.ndim(); ++i) {
      dim_t pos_i = func_input_locs[func_var_locs[i]];
      for (size_t j = 0; j < cond_input_locs.ndim(); ++j) {
        dim_t pos_j = cond_input_locs[j];
        if (pos_i == pos_j) {
          this->oi_map[i] = j;
        }
      }
    }
  }
  // Gather array[input_locs[0]], array[input_locs[1]], ... into *out.
  // NOTE(review): `input_locs' is taken by value; a const reference was
  // probably intended — confirm before changing.
  template <typename T>
  static void extract_by_loc(const std::vector<T> &array,
                             const nnvm::Tuple<dim_t> input_locs,
                             std::vector<T> *out) {
    out->clear();
    out->reserve(input_locs.ndim());
    for (dim_t i : input_locs) {
      out->push_back(array[i]);
    }
  }
  // "udf" = undefined. True when the shape has not been inferred yet.
  static bool is_shape_udf(const TShape &x) {
    return x.ndim() == 0 || x.Size() == 0;
  }
  // True when the storage type has not been inferred yet.
  static bool is_stype_udf(const int &x) {
    return x == exec::kBadStorageID;
  }
  // True when the dtype has not been inferred yet.
  static bool is_type_udf(const int &x) {
    return x == -1;
  }
  // Make *x and *y agree: when exactly one side is empty (undefined), copy
  // the defined value over it. Returns false only when both sides are
  // defined but differ, i.e. on an inference conflict.
  template <typename T>
  static bool fill_value(T *x, T *y, bool x_empty, bool y_empty) {
    if (*x == *y || (x_empty && y_empty)) {
      return true;
    }
    if (!x_empty && !y_empty) {
      return false;
    }
    if (x_empty) {
      *x = *y;
    }
    if (y_empty) {
      *y = *x;
    }
    return true;
  }
  // Propagate inferred attributes between while_loop's inputs (*in) and a
  // subgraph's inputs (*subg_in), which correspond via `input_locs'.
  // NOTE(review): fill_value's conflict result is discarded here, so a
  // genuine mismatch is silently ignored and this always returns true.
  template <typename T>
  static bool sync_in_in(const nnvm::Tuple<dim_t> &input_locs,
                         std::vector<T> *in,
                         std::vector<T> *subg_in,
                         std::function<bool(const T &)> is_empty) {
    for (size_t i = 0; i < input_locs.ndim(); ++i) {
      T &x = in->at(input_locs[i]);
      T &y = subg_in->at(i);
      fill_value(&x, &y, is_empty(x), is_empty(y));
    }
    return true;
  }
  // Propagate inferred attributes between while_loop's loop-variable inputs
  // and the corresponding loop-variable outputs: a loop variable keeps the
  // same attribute across iterations, so input and output must match.
  template <typename T>
  static bool sync_in_out(const WhileLoopParam& params,
                          std::vector<T> *in,
                          std::vector<T> *out,
                          std::function<bool(const T &)> is_empty) {
    for (int i = params.num_out_data; i < params.num_outputs; ++i) {
      // each out->at(i) is a loop_var
      T &x = in->at(params.func_input_locs[params.func_var_locs[i - params.num_out_data]]);
      T &y = out->at(i);
      fill_value(&x, &y, is_empty(x), is_empty(y));
    }
    return true;
  }
};
+
+template <typename T>
+T _asscalar(const NDArray &a) {
+ CHECK_EQ(a.shape().Size(), 1U);
+ T data;
+ a.SyncCopyToCPU(&data, 1U);
+ return data;
+}
+
// Read a 1-element NDArray of any supported dtype as a boolean scalar.
// Used to evaluate the loop condition's output each iteration.
bool as_bool_scalar(const NDArray &a) {
  MSHADOW_TYPE_SWITCH(a.dtype(), DType, {
    return static_cast<bool>(_asscalar<DType>(a));
  });
  // unreachable when the dtype is supported by the switch above
  LOG(FATAL) << "Unknown dtype";
  return false;
}
+
// Forward pass of _while_loop (stateful FComputeEx on CPU).
// Runs `cond' then `func' repeatedly, at most params.max_iterations times,
// recording the actual step count in state.n_iterations for the backward pass.
static void WhileLoopComputeExCPU(const OpStatePtr& state_ptr,
                                  const OpContext& ctx,
                                  const std::vector<NDArray>& inputs,
                                  const std::vector<OpReqType>& req,
                                  const std::vector<NDArray>& outputs) {
  // The argument `inputs' are loop_vars and other inputs
  // loop_vars sit at the positions selected by `func_var_locs' within func's inputs
  // The argument `outputs' are output and new_loop_vars
  //   [0: num_out_data) are outputs at each step.
  //   [num_out_data: ) are new_loop_vars
  // TODO(Junru): avoid dynamic NDArray allocation
  WhileLoopState &state = state_ptr.get_state<WhileLoopState>();
  const WhileLoopParam& params = state.params;
  // a helper function, converting std::vector<NDArray> to std::vector<NDArray*>
  const auto to_ptr_vec = [](std::vector<NDArray> &in, std::vector<NDArray*> *out) {
    out->clear();
    out->reserve(in.size());
    std::transform(std::begin(in),
                   std::end(in),
                   std::back_inserter(*out),
                   [](NDArray &a) {return &a;});
  };
  // sanity checks
  CHECK_EQ(inputs.size() + 2U, (size_t) params.num_args);
  CHECK_EQ(outputs.size(), (size_t) params.num_outputs);
  CHECK_EQ(outputs.size(), req.size());
  // per-step outputs are pre-allocated with a leading max_iterations axis
  for (size_t i = 0; i < (size_t) params.num_out_data; i++)
    CHECK_EQ(params.max_iterations, outputs[i].shape()[0]);
  // construct inputs and outputs for cond
  std::vector<NDArray> cond_inputs, cond_outputs = {NDArray()};
  WhileLoopState::extract_by_loc(inputs, params.cond_input_locs, &cond_inputs);
  std::vector<NDArray*> cond_input_ptr, cond_output_ptr;
  to_ptr_vec(cond_inputs, &cond_input_ptr);
  to_ptr_vec(cond_outputs, &cond_output_ptr);
  // construct inputs and outputs for func
  std::vector<NDArray> func_inputs, func_outputs(outputs.size());
  WhileLoopState::extract_by_loc(inputs, params.func_input_locs, &func_inputs);
  // `step' aliases state.n_iterations so the actual iteration count is
  // recorded for the backward pass as a side effect of the loop itself
  for (size_t &step = state.n_iterations = 0; step < (size_t) params.max_iterations; ++step) {
    state.cond_op->Forward(nullptr, cond_input_ptr, cond_output_ptr);
    if (!as_bool_scalar(*cond_output_ptr[0])) {
      break;
    }
    // we create func_outputs for the current step:
    // func_outputs[0: num_out_data] is a slice of outputs[][step]
    for (size_t i = 0; i < (size_t) params.num_out_data; ++i) {
      func_outputs[i] = outputs[i].At(step);
    }
    // func_outputs[num_out_data: ] are new_loop_vars, need to allocate new memory
    for (size_t i = params.num_out_data; i < outputs.size(); ++i) {
      func_outputs[i] = NDArray(outputs[i].shape(), outputs[i].ctx(), true, outputs[i].dtype());
    }
    state.Forward(step, func_inputs, req, func_outputs, ctx.need_grad);
    // func_inputs on the next step:
    // the output (new_loop_vars) will become the new inputs (loop_vars)
    for (size_t i = params.num_out_data; i < outputs.size(); ++i) {
      size_t j = params.func_var_locs[i - params.num_out_data];
      CHECK_EQ(func_inputs[j].shape(), func_outputs[i].shape());
      func_inputs[j] = func_outputs[i];
      // forward the new value to cond's input list as well, via oi_map
      int k = state.oi_map[i - params.num_out_data];
      if (k != -1) {
        // I actually don't need to update cond_inputs
        cond_inputs[k] = func_outputs[i];
        cond_input_ptr[k] = &func_outputs[i];
      }
    }
  }
  // copy output data to `outputs'
  // case 1: at least one step is executed,
  //   the final_loop_vars must be stored in func_inputs
  // case 2: no step is executed
  //   the final_loop_vars is the same as loop_vars, which are also stored in func_inputs
  // therefore, we copy func_inputs[:] to outputs[num_out_data: ]
  for (size_t i = params.num_out_data; i < outputs.size(); ++i) {
    size_t j = params.func_var_locs[i - params.num_out_data];
    mxnet::CopyFromTo(func_inputs[j], &outputs[i]);
  }
}
+
// Backward pass of _while_loop: replays the recorded iterations in reverse,
// chaining loop-variable gradients from one step into the previous one.
static void WhileLoopGradComputeExCPU(const OpStatePtr& state_ptr,
                                      const OpContext& ctx,
                                      const std::vector<NDArray>& inputs,
                                      const std::vector<OpReqType>& _req,
                                      const std::vector<NDArray>& _outputs) {
  // inputs are dl / df(x)
  // outputs are dl / dx
  // where f is the current function,
  // x is the input to the current function,
  // TODO(Junru): avoid dynamic NDArray allocation
  WhileLoopState &state = state_ptr.get_state<WhileLoopState>();
  const WhileLoopParam& params = state.params;
  // sanity checks
  CHECK_EQ(_outputs.size() + 2U, (size_t) params.num_args);
  CHECK_EQ(_outputs.size(), _req.size());
  for (auto x : _req) {
    CHECK_NE(x, kWriteInplace);
  }
  // reorder the gradient buffers and reqs into func's input order
  std::vector<NDArray> outputs;
  std::vector<OpReqType> req;
  WhileLoopState::extract_by_loc(_outputs, params.func_input_locs, &outputs);
  WhileLoopState::extract_by_loc(_req, params.func_input_locs, &req);
  if (state.n_iterations == 0) {
    // the body never ran: gradients pass straight through the loop_vars
    for (int i = params.num_out_data; i < params.num_outputs; ++i) {
      int j = params.func_var_locs[i - params.num_out_data];
      mxnet::CopyFromTo(inputs[i], &outputs[j]);
    }
    state.Cleanup();
    return;
  }
  // collect var_locs and out_locs, positions other than var_locs are out_locs, i.e.
  //   [0, var_locs[0])
  //   (var_locs[1], var_locs[2]),
  //   (var_locs[2], var_locs[3]),
  //   ...
  //   (var_locs[-2], var_locs[-1] = params.num_args - 2)
  std::vector<dim_t> var_locs(params.func_var_locs.begin(), params.func_var_locs.end());
  var_locs.push_back((dim_t) params.num_args - 2U);
  sort(var_locs.begin(), var_locs.end());
  // vectors for the backward loop
  std::vector<NDArray> ograds(params.num_outputs);
  std::vector<NDArray> igrads(outputs.size());
  std::vector<OpReqType> iter_req(req.size());
  // seed the loop-var ograds from the incoming gradients
  for (int i = params.num_out_data; i < params.num_outputs; ++i)
    ograds[i] = inputs[i];
  const int n_iter = state.n_iterations;
  for (int step = n_iter - 1; step >= 0; --step) {
    // ograds[ : num_out_data] = inputs[ : num_out_data][step]
    // ograds[num_out_data: ] is maintained in the end of each loop
    std::transform(std::begin(inputs),
                   std::begin(inputs) + params.num_out_data,
                   std::begin(ograds),
                   [step] (const NDArray &a) { return a.At(step); } );
    // igrads[i] =
    //    outputs[i]            (step == 0)
    //    outputs[i]            (step != 0 && i not in loop_var_locs)
    //    ArrayLike(outputs[i]) (step != 0 && i in loop_var_locs)
    // iter_req =
    //    kWriteTo              (step != 0 && i in loop_var_locs)
    //    req[i]                (step == 0 && i in loop_var_locs)
    //    kAddTo                (step != n_iters - 1 && i not in loop_var_locs)
    //    req[i]                (step == n_iters - 1 && i not in loop_var_locs)
    {
      size_t i = 0;
      for (size_t loc : var_locs) {
        for ( ; i < loc; ++i) {
          // locs other than var_locs: non-loop-var inputs accumulate across steps
          igrads[i] = outputs[i];
          iter_req[i] = (step + 1 == n_iter || req[i] == kNullOp)
                        ? req[i]
                        : kAddTo;
        }
        if (i < (size_t) params.num_args - 2U) {
          // a var: use a scratch buffer until the first (step 0) iteration
          igrads[i] = (step == 0)
                      ? outputs[i]
                      : NDArray(outputs[i].shape(), outputs[i].ctx(), true, outputs[i].dtype());
          iter_req[i] = (step == 0 || req[i] == kNullOp)
                        ? req[i]
                        : kWriteTo;
          ++i;
        } else {
          break;
        }
      }
    }
    state.Backward(step, ograds, iter_req, igrads);
    // the input gradients of this step become the output gradients of the
    // previous step for all loop variables
    for (int i = params.num_out_data; i < params.num_outputs; ++i) {
      size_t j = params.func_var_locs[i - params.num_out_data];
      ograds[i] = igrads[j];
    }
  }
  state.Cleanup();
}
+
// Shape inference for _while_loop: infers through both subgraphs and keeps
// while_loop's inputs, the subgraph inputs, and the loop-var outputs in sync.
// Per-step data outputs gain a leading max_iterations axis.
static bool WhileLoopShape(const nnvm::NodeAttrs& attrs,
                           std::vector<TShape> *in_shape,
                           std::vector<TShape> *out_shape) {
  using nnvm::ShapeVector;
  const WhileLoopParam& params = nnvm::get<WhileLoopParam>(attrs.parsed);
  static const std::function<bool(const TShape &)> is_udf = WhileLoopState::is_shape_udf;
  // sanity checks
  CHECK_EQ(in_shape->size() + 2U, (size_t) params.num_args);
  CHECK_EQ(out_shape->size(), (size_t) params.num_outputs);
  CHECK_EQ(attrs.subgraphs.size(), 2U);
  CHECK_EQ(attrs.subgraphs[0]->outputs.size(), 1U);
  // infer shape for cond and func
  auto infer_subg = [&params, in_shape, out_shape](std::shared_ptr<Symbol> subg,
                                                   ShapeVector *_subg_out,
                                                   const nnvm::Tuple<dim_t> &input_locs,
                                                   int num_out_data,
                                                   bool fill_out_shape) {
    // create subg_in
    ShapeVector subg_in;
    ShapeVector &subg_out = *_subg_out;
    WhileLoopState::extract_by_loc(*in_shape, input_locs, &subg_in);
    // create an indexed graph
    nnvm::Graph g;
    g.outputs = subg->outputs;
    const auto& idx = g.indexed_graph();
    // get input nodes
    const auto &input_nids = idx.input_nodes();
    // sanity checks
    CHECK_EQ(input_nids.size(), subg_in.size());
    CHECK_EQ(g.outputs.size(), subg_out.size());
    CHECK_EQ(idx.input_nodes().size(), subg_in.size());
    CHECK_EQ(idx.outputs().size(), subg_out.size());
    // create empty shapes for inference
    ShapeVector shapes(idx.num_node_entries());
    // copy subg_in into shapes
    for (size_t i = 0; i < subg_in.size(); ++i) {
      auto eid = idx.entry_id(input_nids[i], 0);
      shapes[eid] = subg_in[i];
    }
    // copy subg_out into shapes
    // note that ndim of out_data is not increased
    // because subg is only one step
    for (size_t i = 0; i < subg_out.size(); ++i) {
      auto eid = idx.entry_id(g.outputs[i]);
      shapes[eid] = subg_out[i];
    }
    // copy done, call InferShape
    g.attrs["shape"] = std::make_shared<dmlc::any>(std::move(shapes));
    g = exec::InferShape(std::move(g));
    // now `shapes' won't be used anymore, use new_shapes instead
    const auto& new_shapes = g.GetAttr<ShapeVector>("shape");
    // copy subg_in back to in_shape
    for (size_t i = 0; i < subg_in.size(); ++i) {
      auto eid = idx.entry_id(input_nids[i], 0);
      auto g_out_shape = new_shapes[eid];
      if (g_out_shape.ndim() == 0 || g_out_shape.Size() == 0) {
        // when the shape is not fully inferred
        continue;
      }
      SHAPE_ASSIGN_CHECK(*in_shape, input_locs[i], g_out_shape);
    }
    if (!fill_out_shape) {
      return true;
    }
    // copy subg_out back to out_shape
    // for results in [0, num_out_data), ndim should increase by 1
    for (int i = 0; i < num_out_data; ++i) {
      auto eid = idx.entry_id(g.outputs[i]);
      auto g_out_shape = new_shapes[eid];
      if (g_out_shape.ndim() == 0 || g_out_shape.Size() == 0) {
        // when the shape is not fully inferred
        continue;
      }
      // prepend the time axis of length max_iterations
      auto out = TShape(g_out_shape.ndim() + 1);
      out[0] = params.max_iterations;
      // NOTE(review): the inner `i' deliberately shadows the loop counter;
      // it is scoped to this copy loop only
      for (size_t i = 1; i < out.ndim(); i++)
        out[i] = g_out_shape[i - 1];
      SHAPE_ASSIGN_CHECK(*out_shape, i, out);
    }
    // for results in [num_out_data, ...), ndim does not change
    for (size_t i = num_out_data; i < g.outputs.size(); ++i) {
      auto eid = idx.entry_id(g.outputs[i]);
      auto g_out_shape = new_shapes[eid];
      if (g_out_shape.ndim() == 0 || g_out_shape.Size() == 0) {
        // when the shape is not fully inferred
        continue;
      }
      SHAPE_ASSIGN_CHECK(*out_shape, i, g_out_shape);
    }
    return g.GetAttr<size_t>("shape_num_unknown_nodes") == 0;
  };
  ShapeVector cond_out_shape{TShape(1U)};  // this means: [(1, )]
  ShapeVector func_out_shape(params.num_outputs);
  CHECK(WhileLoopState::sync_in_out(params, in_shape, out_shape, is_udf));
  bool succ_0 = infer_subg(attrs.subgraphs[0], &cond_out_shape, params.cond_input_locs, 0, false);
  CHECK(WhileLoopState::sync_in_out(params, in_shape, out_shape, is_udf));
  bool succ_1 = infer_subg(attrs.subgraphs[1], &func_out_shape, \
                           params.func_input_locs, params.num_out_data, true);
  CHECK(WhileLoopState::sync_in_out(params, in_shape, out_shape, is_udf));
  return succ_0 && succ_1;
}
+
// Dtype inference for _while_loop: infers through cond and func, syncing
// dtypes between while_loop's inputs, each subgraph's inputs, and the
// loop-variable outputs.
static bool WhileLoopType(const nnvm::NodeAttrs& attrs,
                          std::vector<int> *in_type, std::vector<int> *out_type) {
  const WhileLoopParam& params = nnvm::get<WhileLoopParam>(attrs.parsed);
  static const std::function<bool(const int &)> is_udf = WhileLoopState::is_type_udf;
  CHECK_EQ(in_type->size() + 2U, (size_t) params.num_args);
  CHECK_EQ(out_type->size(), (size_t) params.num_outputs);
  CHECK_EQ(attrs.subgraphs.size(), 2U);
  CHECK_EQ(attrs.subgraphs[0]->outputs.size(), 1U);
  std::vector<int> cond_in_type;
  std::vector<int> func_in_type;
  WhileLoopState::extract_by_loc(*in_type, params.cond_input_locs, &cond_in_type);
  WhileLoopState::extract_by_loc(*in_type, params.func_input_locs, &func_in_type);
  // cond produces a single output; its dtype slot starts at 0
  std::vector<int> cond_out_type = {0};
  CHECK(WhileLoopState::sync_in_out(params, in_type, out_type, is_udf));
  bool succ_0 = InferSubgraphDataType(*attrs.subgraphs[0], &cond_in_type, &cond_out_type);
  CHECK(WhileLoopState::sync_in_out(params, in_type, out_type, is_udf));
  CHECK(WhileLoopState::sync_in_in(params.cond_input_locs, in_type, &cond_in_type, is_udf));
  // func's outputs are while_loop's outputs, so out_type is passed directly
  bool succ_1 = InferSubgraphDataType(*attrs.subgraphs[1], &func_in_type, out_type);
  CHECK(WhileLoopState::sync_in_out(params, in_type, out_type, is_udf));
  CHECK(WhileLoopState::sync_in_in(params.func_input_locs, in_type, &func_in_type, is_udf));
  return succ_0 && succ_1;
}
+
// Storage-type inference for _while_loop, mirroring WhileLoopType: infers
// through both subgraphs and syncs storage types across inputs/outputs.
// The op itself always dispatches as FComputeEx.
static bool WhileLoopStorageType(const nnvm::NodeAttrs& attrs,
                                 const int dev_mask,
                                 DispatchMode* dispatch_mode,
                                 std::vector<int> *in_attrs,
                                 std::vector<int> *out_attrs) {
  const WhileLoopParam& params = nnvm::get<WhileLoopParam>(attrs.parsed);
  static const std::function<bool(const int &)> is_udf = WhileLoopState::is_stype_udf;
  CHECK_EQ(in_attrs->size() + 2U, (size_t) params.num_args);
  CHECK_EQ(out_attrs->size(), (size_t) params.num_outputs);
  CHECK_EQ(attrs.subgraphs.size(), 2U);
  CHECK_EQ(attrs.subgraphs[0]->outputs.size(), 1U);
  std::vector<int> cond_in_attrs;
  std::vector<int> func_in_attrs;
  WhileLoopState::extract_by_loc(*in_attrs, params.cond_input_locs, &cond_in_attrs);
  WhileLoopState::extract_by_loc(*in_attrs, params.func_input_locs, &func_in_attrs);
  // cond's single boolean output is always dense
  std::vector<int> cond_out_attrs = {kDefaultStorage};
  DispatchMode cond_mode = DispatchMode::kUndefined;
  DispatchMode func_mode = DispatchMode::kUndefined;
  *dispatch_mode = DispatchMode::kFComputeEx;
  CHECK(WhileLoopState::sync_in_out(params, in_attrs, out_attrs, is_udf));
  bool succ_0 = InferSubgraphStorage(*attrs.subgraphs[0], dev_mask, \
                                     &cond_mode, &cond_in_attrs, &cond_out_attrs);
  CHECK(WhileLoopState::sync_in_out(params, in_attrs, out_attrs, is_udf));
  CHECK(WhileLoopState::sync_in_in(params.cond_input_locs, in_attrs, &cond_in_attrs, is_udf));
  bool succ_1 = InferSubgraphStorage(*attrs.subgraphs[1], dev_mask, \
                                     &func_mode, &func_in_attrs, out_attrs);
  CHECK(WhileLoopState::sync_in_out(params, in_attrs, out_attrs, is_udf));
  CHECK(WhileLoopState::sync_in_in(params.func_input_locs, in_attrs, &func_in_attrs, is_udf));
  return succ_0 && succ_1;
}
+
// Storage-type inference for _backward_while_loop: delegate to the backward
// storage inference of the `func' subgraph's CachedOp.
static bool BackwardWhileLoopStorageType(const nnvm::NodeAttrs& attrs,
                                         const int dev_mask,
                                         DispatchMode* dispatch_mode,
                                         std::vector<int> *in_attrs,
                                         std::vector<int> *out_attrs) {
  // `cond' is not backwarded, don't check
  const WhileLoopParam& params = nnvm::get<WhileLoopParam>(attrs.parsed);
  CHECK_EQ(out_attrs->size() + 2U, (size_t) params.num_args);
  CHECK_EQ(attrs.subgraphs.size(), 2U);
  // subgraphs[1] is `func', the loop body
  CachedOp op(*attrs.subgraphs[1], {});
  return op.BackwardStorageType(attrs, dev_mask, dispatch_mode,
                                in_attrs, out_attrs);
}
+
+static OpStatePtr CreateWhileLoopState(const NodeAttrs& attrs,
+ Context ctx,
+ const std::vector<TShape>& ishape,
+ const std::vector<int>& itype) {
+ const WhileLoopParam& params = nnvm::get<WhileLoopParam>(attrs.parsed);
+ return OpStatePtr::Create<WhileLoopState>(params, *attrs.subgraphs[0],
*attrs.subgraphs[1]);
+}
+
+static std::vector<nnvm::NodeEntry>
+WhileLoopGradient(const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>&
ograds) {
+ ElemwiseGradUseInOut fgrad{"_backward_while_loop"};
+ std::vector<nnvm::NodeEntry> entries = fgrad(n, ograds);
+ entries[0].node->attrs.subgraphs = n->attrs.subgraphs;
+ return entries;
+}
+
NNVM_REGISTER_OP(_foreach)
.MXNET_DESCRIBE("Run a for loop over an NDArray with user-defined computation")
.set_attr_parser(ParamParser<ForeachParam>)
@@ -526,11 +1023,11 @@ NNVM_REGISTER_OP(_backward_foreach)
.set_num_inputs([](const NodeAttrs& attrs){
const ForeachParam& params = nnvm::get<ForeachParam>(attrs.parsed);
return params.num_outputs * 2 + params.num_args - 1;
- })
+})
.set_num_outputs([](const NodeAttrs& attrs){
const ForeachParam& params = nnvm::get<ForeachParam>(attrs.parsed);
return params.num_args - 1;
- })
+})
.set_attr<FExecType>("FExecType", [](const NodeAttrs& attrs) {
return ExecType::kSubgraphExec;
})
@@ -541,5 +1038,67 @@ NNVM_REGISTER_OP(_backward_foreach)
.set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>",
ForeachGradComputeExCPU)
.set_attr<FStatefulComputeEx>("FStatefulComputeEx<gpu>",
ForeachGradComputeExCPU);
+NNVM_REGISTER_OP(_while_loop)
+.MXNET_DESCRIBE("Run a while loop over with user-defined condition and
computation")
+.set_attr_parser(ParamParser<WhileLoopParam>)
+.set_attr<FInferStorageType>("FInferStorageType", WhileLoopStorageType)
+.set_num_inputs([](const NodeAttrs& attrs) {
+ const WhileLoopParam& params = nnvm::get<WhileLoopParam>(attrs.parsed);
+ return params.num_args;
+})
+.set_num_outputs([](const NodeAttrs& attrs) {
+ const WhileLoopParam& params = nnvm::get<WhileLoopParam>(attrs.parsed);
+ return params.num_outputs;
+})
+.set_attr<nnvm::FListInputNames>("FListInputNames",
+ [](const NodeAttrs& attrs) {
+ const WhileLoopParam& params = nnvm::get<WhileLoopParam>(attrs.parsed);
+ std::vector<std::string> names;
+ names.reserve(params.num_args);
+ names.push_back("cond");
+ names.push_back("func");
+ for (int i = 2; i < params.num_args; i++)
+ names.push_back("data" + std::to_string(i - 2));
+ return names;
+})
+.set_attr<nnvm::FInputGraph>("FInputGraph",
+ [](const NodeAttrs& attrs) {
+ return std::vector<uint32_t>{0, 1};
+})
+.set_attr<nnvm::FGradient>("FGradient", WhileLoopGradient)
+.set_attr<FCreateOpState>("FCreateOpState", CreateWhileLoopState)
+.set_attr<nnvm::FInferShape>("FInferShape", WhileLoopShape)
+.set_attr<nnvm::FInferType>("FInferType", WhileLoopType)
+.set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", WhileLoopComputeExCPU)
+.set_attr<FExecType>("FExecType", [](const NodeAttrs& attrs) {
+ return ExecType::kSubgraphExec;
+})
+.set_attr<FStatefulComputeEx>("FStatefulComputeEx<gpu>", WhileLoopComputeExCPU)
+.set_attr<std::string>("key_var_num_args", "num_args")
+.add_argument("cond", "Symbol", "Input graph for the loop condition.")
+.add_argument("func", "Symbol", "Input graph for the loop body.")
+.add_argument("data", "NDArray-or-Symbol[]",
+ "The input arrays that include data arrays and states.")
+.add_arguments(WhileLoopParam::__FIELDS__());
+
// Registration of the backward operator for _while_loop.
NNVM_REGISTER_OP(_backward_while_loop)
// inputs: ograds (num_outputs) + forward data inputs (num_args - 2) +
// forward outputs (num_outputs), as assembled by ElemwiseGradUseInOut
.set_num_inputs([](const NodeAttrs& attrs){
  const WhileLoopParam& params = nnvm::get<WhileLoopParam>(attrs.parsed);
  return params.num_outputs * 2 + params.num_args - 2;
})
// one gradient per data input; the cond/func symbol inputs take no gradient
.set_num_outputs([](const NodeAttrs& attrs){
  const WhileLoopParam& params = nnvm::get<WhileLoopParam>(attrs.parsed);
  return params.num_args - 2;
})
.set_attr<FExecType>("FExecType", [](const NodeAttrs& attrs) {
  return ExecType::kSubgraphExec;
})
.set_attr<FInferStorageType>("FInferStorageType", BackwardWhileLoopStorageType)
.set_attr_parser(ParamParser<WhileLoopParam>)
.set_attr<bool>("TIsLayerOpBackward", true)
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", WhileLoopGradComputeExCPU)
.set_attr<FStatefulComputeEx>("FStatefulComputeEx<gpu>", WhileLoopGradComputeExCPU);
+
} // namespace op
} // namespace mxnet
diff --git a/src/operator/subgraph_op_common.cc
b/src/operator/subgraph_op_common.cc
index 71a9a21..d845aa9 100644
--- a/src/operator/subgraph_op_common.cc
+++ b/src/operator/subgraph_op_common.cc
@@ -164,14 +164,7 @@ bool InferSubgraphShape(const nnvm::Symbol &subgraph,
LoopState::LoopState(const Symbol &g) {
this->subgraph_sym = g;
this->subgraph.outputs = g.outputs;
-
- std::vector<std::pair<std::string, std::string> > kwargs;
- kwargs.push_back(std::pair<std::string, std::string>("inline_limit", "0"));
- // We turn on static_alloc for two reasons.
- // It avoids the overhead of unnecessary memory allocation.
- // only static_alloc supports nested call of CachedOp.
- kwargs.push_back(std::pair<std::string, std::string>("static_alloc", "1"));
- iter_op = std::make_shared<CachedOp>(subgraph_sym, kwargs);
+ this->iter_op = LoopState::MakeSharedOp(g);
}
void LoopState::Forward(int iter_no,
diff --git a/src/operator/subgraph_op_common.h
b/src/operator/subgraph_op_common.h
index 7907840..f73f09c 100644
--- a/src/operator/subgraph_op_common.h
+++ b/src/operator/subgraph_op_common.h
@@ -24,6 +24,8 @@
#include <mxnet/base.h>
#include <mxnet/op_attr_types.h>
#include <vector>
+#include <utility>
+#include <string>
#include "../imperative/cached_op.h"
#include "../imperative/imperative_utils.h"
@@ -69,8 +71,8 @@ class LoopState {
// For training, each iteration has a cached op because each iteration
// needs to maintain a set of memory buffers for all computation states,
// which will be used in the backward.
- CachedOpPtr iter_op;
std::vector<OpStatePtr> all_states;
+ CachedOpPtr iter_op;
Symbol subgraph_sym;
nnvm::Graph subgraph;
@@ -91,6 +93,16 @@ class LoopState {
all_inputs.clear();
all_states.clear();
}
  // Build a CachedOp for `sym' with the settings control-flow ops require.
  static CachedOpPtr MakeSharedOp(const Symbol &sym) {
    // We turn on static_alloc for two reasons.
    // It avoids the overhead of unnecessary memory allocation.
    // only static_alloc supports nested call of CachedOp.
    std::vector<std::pair<std::string, std::string> > kwargs = {
      {"inline_limit", "0"},
      {"static_alloc", "1"}
    };
    return std::make_shared<CachedOp>(sym, kwargs);
  }
};
} // namespace op
diff --git a/tests/python/unittest/test_contrib_control_flow.py
b/tests/python/unittest/test_contrib_control_flow.py
new file mode 100644
index 0000000..9dd5c43
--- /dev/null
+++ b/tests/python/unittest/test_contrib_control_flow.py
@@ -0,0 +1,978 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import mxnet as mx
+from mxnet import gluon
+import numpy as np
+import copy
+from numpy.testing import assert_allclose
+import unittest
+from mxnet.test_utils import almost_equal, default_context
+from numpy.testing import assert_allclose as assert_almost_equal # This is
more restrictive
+from mxnet.base import _as_list
+
+
+def test_while_loop_simple_forward():
+
+ class _TestBlock(gluon.HybridBlock):
+
+ def __init__(self, cond, func, max_iterations):
+ super(_TestBlock, self).__init__()
+ self.cond = cond
+ self.func = func
+ self.max_iterations = max_iterations
+
+ def hybrid_forward(self, F, *loop_vars):
+ return F.contrib.while_loop(
+ cond=self.cond,
+ func=self.func,
+ loop_vars=loop_vars,
+ max_iterations=self.max_iterations
+ )
+
+ for hybridize in [False, True]:
+        # Case 1.1: result should be sum([1, 2, 3, 4, 5])
+ model = _TestBlock(
+ cond=lambda i, s: i <= 5,
+ func=lambda i, s: (None, (i + 1, s + i)),
+ max_iterations=10,
+ )
+ if hybridize:
+ model.hybridize()
+ _, result = model(
+ mx.nd.array([1], dtype="int64"), # i
+ mx.nd.array([0], dtype="int64"), # s
+ )
+ assert result[0].asscalar() == 6
+ assert result[1].asscalar() == 15
+ # Case 1.2: result should be sum([1, 2, 3 ... 1000])
+ model = _TestBlock(
+ cond=lambda i, s, true: true,
+ func=lambda i, s, true: (None, (i + 1, s + i, true)),
+ max_iterations=1000,
+ )
+ if hybridize:
+ model.hybridize()
+ _, result = model(
+ mx.nd.array([1], dtype="int64"), # i
+ mx.nd.array([0], dtype="int64"), # s
+ mx.nd.array([1], dtype="int64"), # true
+ )
+ assert result[0].asscalar() == 1001
+ assert result[1].asscalar() == 500500
+ assert result[2].asscalar() == 1
+ # Case 1.3: result should be sum([])
+ model = _TestBlock(
+ cond=lambda i, s, false: false,
+ func=lambda i, s, false: (None, (i + 1, s + i, false)),
+ max_iterations=1000,
+ )
+ if hybridize:
+ model.hybridize()
+ _, result = model(
+ mx.nd.array([1], dtype="int64"), # i
+ mx.nd.array([0], dtype="int64"), # s
+ mx.nd.array([0], dtype="int64"), # false
+ )
+ assert result[0].asscalar() == 1
+ assert result[1].asscalar() == 0
+ assert result[2].asscalar() == 0
+ # Case 2.1: result should be sum([1, 2, 3 ... 100])
+ model = _TestBlock(
+ cond=lambda i, s: i <= 100,
+ func=lambda i, s: (i, (i + 1, s + i)),
+ max_iterations=1000,
+ )
+ if hybridize:
+ model.hybridize()
+ (outputs, ), (result_i, result_s) = model(
+ mx.nd.array([1], dtype="int64"), # i
+ mx.nd.array([0], dtype="int64"), # s
+ )
+ assert all(outputs.asnumpy()[ : 100] == np.arange(1, 101).reshape(100,
1))
+ assert result_i.asscalar() == 101
+ assert result_s.asscalar() == 5050
+ # Case 2.2: result should be sum([1, 2, 3 ... 1000])
+ model = _TestBlock(
+ cond=lambda i, s, true: true,
+ func=lambda i, s, true: (i, (i + 1, s + i, true)),
+ max_iterations=1000,
+ )
+ if hybridize:
+ model.hybridize()
+ (outputs, ), (result_i, result_s, _) = model(
+ mx.nd.array([1], dtype="int64"), # i
+ mx.nd.array([0], dtype="int64"), # s
+ mx.nd.array([1], dtype="int64"), # true
+ )
+ assert all(outputs.asnumpy() == np.arange(1, 1001).reshape(1000, 1))
+ assert result_i.asscalar() == 1001
+ assert result_s.asscalar() == 500500
+ # Case 2.3: a corner case, in which loop body is never executed
+ model = _TestBlock(
+ cond=lambda i, s, false: false,
+ func=lambda i, s, false: (i, (i + 1, s + i, false)),
+ max_iterations=1000,
+ )
+ if hybridize:
+ model.hybridize()
+ _, (result_i, result_s, _) = model(
+ mx.nd.array([1], dtype="int64"), # i
+ mx.nd.array([0], dtype="int64"), # s
+ mx.nd.array([0], dtype="int64"), # false
+ )
+ assert result_i.asscalar() == 1
+ assert result_s.asscalar() == 0
+
+
+def _verify_while_loop(cond, func, loop_var_shapes, free_var_shapes, is_train,
max_iterations, is_for, n_steps):
+
+ def _create_vars(num, prefix):
+ return [mx.sym.var(prefix + str(i)) for i in range(num)]
+
+ def _create_arrays(shapes):
+ return [mx.nd.random.uniform(-1.0, 1.0, shape=x) for x in shapes]
+
+ def _create_dict(prefix, shapes):
+ return {prefix + str(i): mx.nd.random.uniform(-1.0, 1.0, shape=x) for
i, x in enumerate(shapes)}
+
+ def _merge_dict(*dicts):
+ result = {}
+ for item in dicts:
+ result.update(item)
+ return result
+
+ def _to_numpy_list(arrays):
+ return [x.asnumpy() if x is not None else x for x in arrays]
+
+ def _get_imperative_result(n_steps):
+ free_vars = [args["FreeVar" + str(i)].copy() for i, _ in
enumerate(free_var_shapes)]
+ loop_vars = [args["LoopVar" + str(i)].copy() for i, _ in
enumerate(loop_var_shapes)]
+ loop_var_start = int(is_for)
+ if is_train:
+ for var in free_vars + loop_vars[loop_var_start: ]:
+ var.attach_grad()
+ with mx.autograd.record(train_mode=is_train):
+ outputs, final_loop_vars = mx.nd.contrib.while_loop(
+ cond=lambda *_loop_vars: cond(_loop_vars, free_vars),
+ func=lambda *_loop_vars: func(_loop_vars, free_vars),
+ loop_vars=loop_vars,
+ max_iterations=max_iterations,
+ )
+ outputs = [x[: n_steps] for x in outputs]
+ out_grads = _create_arrays(x.shape for x in outputs) \
+ + _create_arrays(x.shape for x in final_loop_vars)
+ loop_result_nd = [x * 2 for x in outputs] + [x * 3 for x in
final_loop_vars]
+ grads = []
+ if is_train:
+ cat_out = mx.nd.concat(*[x.reshape(-1) for x in
loop_result_nd], dim=0)
+ cat_out.backward(out_grad=mx.nd.concat(*[x.reshape(-1) for x
in out_grads], dim=0))
+ grads = [free_vars[i].grad for i, _ in
enumerate(free_var_shapes)] \
+ + [loop_vars[i].grad for i, _ in
enumerate(loop_var_shapes) if i >= loop_var_start]
+ return _to_numpy_list(loop_result_nd), _to_numpy_list(grads),
out_grads
+
+ def _get_symbolic_result(out_grads, n_steps):
+
+ def _copy_args_dict(name_list):
+ return {name: args[name].copy() for name in name_list}
+
+ def _zeros_like_dict(name_list):
+ return {name: mx.nd.zeros_like(args[name]) for name in name_list}
+
+ free_syms = _create_vars(len(free_var_shapes), "FreeVar")
+ loop_syms = _create_vars(len(loop_var_shapes), "LoopVar")
+ outputs, final_loop_syms = mx.sym.contrib.while_loop(
+ cond=lambda *_loop_vars: cond(_loop_vars, free_syms),
+ func=lambda *_loop_vars: func(_loop_vars, free_syms),
+ loop_vars=loop_syms,
+ max_iterations=max_iterations,
+ )
+ if n_steps == 0:
+ outputs = []
+ else:
+ outputs = [x.slice_axis(axis=0, begin=0, end=n_steps) for x in
outputs]
+ loop_result_sym = [x * 2 for x in outputs] + [x * 3 for x in
final_loop_syms]
+ loop_result_sym = mx.sym.Group(loop_result_sym)
+
+ loop_var_start = int(is_for)
+ args_names = ["FreeVar" + str(i) for i, _ in
enumerate(free_var_shapes)] \
+ + ["LoopVar" + str(i) for i, _ in
enumerate(loop_var_shapes) if i >= loop_var_start]
+ args_grad = None if not is_train else _zeros_like_dict(x for x in
args_names)
+ executor = loop_result_sym.bind(
+ ctx=default_context(),
+ args=_copy_args_dict(loop_result_sym.list_inputs()),
+ args_grad=args_grad,
+ )
+ loop_result_nd = executor.forward(is_train=is_train)
+ grads = []
+ if is_train:
+ executor.backward(out_grads=out_grads)
+ grads = [executor.grad_dict.get("FreeVar" + str(i), None) for i, _
in enumerate(free_var_shapes)] \
+ + [executor.grad_dict.get("LoopVar" + str(i), None) for i, _
in enumerate(loop_var_shapes) if i >= loop_var_start]
+ return _to_numpy_list(loop_result_nd), _to_numpy_list(grads)
+
+ args = _merge_dict(
+ _create_dict("FreeVar", free_var_shapes),
+ _create_dict("LoopVar", loop_var_shapes),
+ )
+ if is_for:
+ assert loop_var_shapes[0] == (1, )
+ args["LoopVar0"] = mx.nd.array([0])
+ imp_outs, imp_grads, out_grads = _get_imperative_result(n_steps)
+ sym_outs, sym_grads = _get_symbolic_result(out_grads, n_steps)
+ for imp_out, sym_out in zip(imp_outs, sym_outs):
+ if imp_out is None or sym_out is None:
+ continue
+ assert_almost_equal(imp_out, sym_out, rtol=1e-4, atol=1e-4)
+ for imp_grad, sym_grad in zip(imp_grads, sym_grads):
+ if imp_grad is None or sym_grad is None:
+ continue
+ assert_almost_equal(imp_grad, sym_grad, rtol=1e-4, atol=1e-4)
+
+
+def test_while_loop_for_foreach():
+
+ def make_true_cond():
+ return lambda loop_vars, _: (loop_vars[0] < 1e200).prod()
+
+ def make_false_cond():
+ return lambda loop_vars, _: (loop_vars[0] > 1e200).prod()
+
+ def make_for_cond(length):
+ return lambda loop_vars, _: loop_vars[0] < length
+
+ def case_0():
+        # This is a simple testcase in which all loop steps are independent
+ # It basically scans the array and outputs itself
+ # There is 1 output
+ # There is 1 state: i
+ def _simple_func(loop, free):
+ (i, ), (scanned, ) = loop, free
+ in_ = scanned.take(i).squeeze(axis=0)
+ return (in_, i + 1)
+ _verify_while_loop(
+ cond=make_true_cond(),
+ func=_simple_func,
+ max_iterations=1,
+ is_train=True,
+ is_for=True,
+ loop_var_shapes=[
+ (1, ), # i
+ ],
+ free_var_shapes=[
+ (1, 3), # scanned
+ ],
+ n_steps=1,
+ )
+
+ def case_1(**params):
+ # This is a simple testcase that simulates a cumulative sum
+ # There is 1 output
+ # There is 1 state: s
+ step_funcs = [
+ lambda a, b, s: s,
+ lambda a, b, s: a * 1.5 + b * 2.5 - s * 3.5,
+ lambda a, b, s: a * 1.5 - s * 3.5 + b * 2.5,
+ lambda a, b, s: b * 2.5 + a * 1.5 - s * 3.5,
+ lambda a, b, s: b * 2.5 - s * 3.5 + a * 1.5,
+ lambda a, b, s: s * -3.5 + a * 1.5 + b * 2.5,
+ lambda a, b, s: s * -3.5 + b * 2.5 + a * 1.5,
+ lambda a, b, s: a * 2.5 * b + s * 0.3,
+ lambda a, b, s: b * 2.5 * a + s * 0.3,
+ lambda a, b, s: 2.5 * a * b + s * 0.3,
+ lambda a, b, s: b * a * 2.5 + s * 0.3,
+ lambda a, b, s: 2.5 * b * a + s * 0.3,
+ lambda a, b, s: b * a * 2.5 + s * 0.3,
+ lambda a, b, s: s * 0.3 + a * 2.5 * b,
+ lambda a, b, s: s * 0.3 + b * 2.5 * a,
+ lambda a, b, s: s * 0.3 + 2.5 * a * b,
+ lambda a, b, s: s * 0.3 + b * a * 2.5,
+ lambda a, b, s: s * 0.3 + 2.5 * b * a,
+ lambda a, b, s: s * 0.3 + b * a * 2.5,
+ ]
+ def make_func(step_func):
+ def step(loop, free):
+ (s, ), (a, b) = loop, free
+ out = step_func(a, b, s)
+ return (out, out)
+ return step
+ case_id = 0
+ for is_train in [True, False]:
+ for step_func in step_funcs:
+ case_id += 1
+ _verify_while_loop(
+ func=make_func(step_func),
+ is_train=is_train,
+ is_for=False,
+ **params
+ )
+
+ def case_2(**params):
+ # This is a testcase that involves non-differentiable operators
+ # There is 1 output
+ # There is 2 states: i, s
+ step_funcs = [
+ lambda in_, s, f_1: (in_ * 2) * s * f_1,
+ lambda in_, s, f_1: (in_ * 2) * f_1 * s,
+ lambda in_, s, f_1: s * (in_ * 2) * f_1,
+ lambda in_, s, f_1: s * f_1 * (in_ * 2),
+ lambda in_, s, f_1: f_1 * (in_ * 2) * s,
+ lambda in_, s, f_1: f_1 * s * (in_ * 2),
+ lambda in_, s, f_1: (2 * in_) * s * f_1,
+ lambda in_, s, f_1: (2 * in_) * f_1 * s,
+ lambda in_, s, f_1: s * (2 * in_) * f_1,
+ lambda in_, s, f_1: s * f_1 * (2 * in_),
+ lambda in_, s, f_1: f_1 * (2 * in_) * s,
+ lambda in_, s, f_1: f_1 * s * (2 * in_),
+ ]
+ def make_func(step_func):
+ """This simulates:
+ def compute(s, inputs, f_1, length):
+ outputs = []
+ for i in range(length):
+ s += inputs[i] * 2 + f_1
+ outputs.append(s)
+ return outputs, s
+ """
+ def step(loop, free):
+ (i, s), (scanned, f_1, _) = loop, free
+ in_ = scanned.take(i).squeeze(axis=0)
+ out = step_func(in_, s, f_1)
+ return (out, (i + 1, out))
+ return step
+ case_id = 0
+ for is_train in [True, False]:
+ for step_func in step_funcs:
+ case_id += 1
+ _verify_while_loop(
+ func=make_func(step_func),
+ max_iterations=1000,
+ is_train=is_train,
+ is_for=True,
+ **params
+ )
+
+ def case_3(length, **params):
+ # This is a testcase for multiple non-differentiable operators and
different ways of slicing
+ # There are 2 outputs
+ # There are 3 states: i, s_0, s_1
+ step_funcs = [
+ lambda i_0, i_1, s_0, s_1, f_0: i_0 * (i_1 * 2) * s_0 * (s_1 * 2)
* f_0,
+ lambda i_0, i_1, s_0, s_1, f_0: i_0 * (i_1 * 2) * s_0 * f_0 * (s_1
* 2),
+ lambda i_0, i_1, s_0, s_1, f_0: i_0 * (i_1 * 2) * (s_1 * 2) * s_0
* f_0,
+ lambda i_0, i_1, s_0, s_1, f_0: i_0 * (i_1 * 2) * (s_1 * 2) * f_0
* s_0,
+ lambda i_0, i_1, s_0, s_1, f_0: (i_1 * 2) * i_0 * s_0 * (s_1 * 2)
* f_0,
+ lambda i_0, i_1, s_0, s_1, f_0: (i_1 * 2) * i_0 * s_0 * f_0 * (s_1
* 2),
+ lambda i_0, i_1, s_0, s_1, f_0: (i_1 * 2) * i_0 * (s_1 * 2) * s_0
* f_0,
+ lambda i_0, i_1, s_0, s_1, f_0: (i_1 * 2) * i_0 * (s_1 * 2) * f_0
* s_0,
+ lambda i_0, i_1, s_0, s_1, f_0: i_0,
+ lambda i_0, i_1, s_0, s_1, f_0: i_1,
+ lambda i_0, i_1, s_0, s_1, f_0: s_0,
+ lambda i_0, i_1, s_0, s_1, f_0: s_1,
+ lambda i_0, i_1, s_0, s_1, f_0: f_0,
+ ]
+ def make_func(step_func):
+ """This simulates:
+ def compute(input_0, input_1, s_0, s_1, f_0, length):
+ output_0 = []
+ output_1 = []
+ for i in range(length):
+ i_0 = input_0[i]
+ i_1 = input_1[length - 1 - i]
+ out = i_0 + (i_1 * 2) + s_0 + (s_1 * 2) + f_0
+ s_0 = (s_0 + out) * 1.05
+ s_1 = (s_1 - out * 0.5) * 0.95
+ output_0.append(out)
+ output_1.append(out * 1.5)
+                return output_0, output_1, s_0, s_1
+ """
+ def step(loop, free):
+ (i, s_0, s_1), (sc_0, sc_1, f_0, _) = loop, free
+ i_0 = sc_0.take(i).squeeze(axis=0)
+ i_1 = sc_1.take(length - 1 - i).squeeze(axis=0)
+ out = step_func(i_0, i_1, s_0, s_1, f_0)
+ return ([out, out * 1.5], [i + 1, (s_0 + out) * 1.05, (s_1 -
out * 0.5) * 0.95])
+ return step
+ case_id = 0
+ for is_train in [True, False]:
+ for step_func in step_funcs:
+ case_id += 1
+ _verify_while_loop(
+ func=make_func(step_func),
+ max_iterations=1000,
+ is_train=is_train,
+ is_for=True,
+ **params
+ )
+
+ def case_4(length, single_shape, **params):
+ # It is for the case that inputs & outputs are the same
+ # There are 3 outputs
+ # There are 4 states: i, s_0, s_1, s_2
+ # i is used in both non-differentiable (take) and differentiable (+)
occasions
+ step_funcs = [
+ lambda i_0, i_1, s_0, s_1, f_0: i_0 * (i_1 * 2) * s_0 * (s_1 * 2)
* f_0,
+ lambda i_0, i_1, s_0, s_1, f_0: i_0 * (i_1 * 2) * s_0 * f_0 * (s_1
* 2),
+ lambda i_0, i_1, s_0, s_1, f_0: i_0 * (i_1 * 2) * (s_1 * 2) * s_0
* f_0,
+ lambda i_0, i_1, s_0, s_1, f_0: i_0 * (i_1 * 2) * (s_1 * 2) * f_0
* s_0,
+ lambda i_0, i_1, s_0, s_1, f_0: (i_1 * 2) * i_0 * s_0 * (s_1 * 2)
* f_0,
+ lambda i_0, i_1, s_0, s_1, f_0: (i_1 * 2) * i_0 * s_0 * f_0 * (s_1
* 2),
+ lambda i_0, i_1, s_0, s_1, f_0: (i_1 * 2) * i_0 * (s_1 * 2) * s_0
* f_0,
+ lambda i_0, i_1, s_0, s_1, f_0: (i_1 * 2) * i_0 * (s_1 * 2) * f_0
* s_0,
+ lambda i_0, i_1, s_0, s_1, f_0: i_0,
+ lambda i_0, i_1, s_0, s_1, f_0: i_1,
+ lambda i_0, i_1, s_0, s_1, f_0: s_0,
+ lambda i_0, i_1, s_0, s_1, f_0: s_1,
+ lambda i_0, i_1, s_0, s_1, f_0: f_0,
+ ]
+ def make_func(step_func):
+ """This simulates:
+ def compute(input_0, input_1, s_0, s_1, s_2, f_0, length):
+ # here s_2 remains untouched
+ output_0 = []
+ output_1 = []
+ output_2 = []
+ for i in range(length):
+ i_0 = input_0[i]
+ i_1 = input_1[length - 1 - i]
+ out = i_0 + (i_1 * 2) + s_0 + (s_1 * 2) + f_0
+ out = out * i * i_0 * i_1
+ s_0 = (s_0 + out) * 1.05
+ s_1 = (s_1 - out * 0.5) * 0.95
+ output_0.append(out)
+ output_1.append(f_0)
+ output_2.append(out * 1.5)
+ return output_0, output_1, output_2, s_0, s_1, s_2
+ """
+ def step(loop, free):
+ (i, s_0, s_1, s_2), (sc_0, sc_1, f_0, _) = loop, free
+ i_0 = sc_0.take(i).squeeze(axis=0)
+ i_1 = sc_1.take(length - 1 - i).squeeze(axis=0)
+ out = step_func(i_0, i_1, s_0, s_1, f_0)
+ out = out * i.reshape([1] *
len(single_shape)).broadcast_to(single_shape)
+ out = out * i_0 * i_1
+ return ([out, f_0, out * 1.5], [i + 1, (s_0 + out) * 1.05,
(s_1 - out * 0.5) * 0.95, s_2])
+ return step
+ case_id = 0
+ for is_train in [True, False]:
+ for step_func in step_funcs:
+ case_id += 1
+ _verify_while_loop(
+ func=make_func(step_func),
+ max_iterations=1000,
+ is_train=is_train,
+ is_for=True,
+ **params
+ )
+
+ def case_5(length, single_shape, **params):
+ # It is for the case that inputs & outputs are the same
+ # There are 0 outputs
+ # There are 4 states: i, s_0, s_1, s_2
+        # i is used in both non-differentiable (take) and differentiable (+) occasions
+ step_funcs = [
+ lambda i_0, i_1, s_0, s_1, f_0: i_0 * (i_1 * 2) * s_0 * (s_1 * 2)
* f_0,
+ lambda i_0, i_1, s_0, s_1, f_0: i_0 * (i_1 * 2) * s_0 * f_0 * (s_1
* 2),
+ lambda i_0, i_1, s_0, s_1, f_0: i_0 * (i_1 * 2) * (s_1 * 2) * s_0
* f_0,
+ lambda i_0, i_1, s_0, s_1, f_0: i_0 * (i_1 * 2) * (s_1 * 2) * f_0
* s_0,
+ lambda i_0, i_1, s_0, s_1, f_0: (i_1 * 2) * i_0 * s_0 * (s_1 * 2)
* f_0,
+ lambda i_0, i_1, s_0, s_1, f_0: (i_1 * 2) * i_0 * s_0 * f_0 * (s_1
* 2),
+ lambda i_0, i_1, s_0, s_1, f_0: (i_1 * 2) * i_0 * (s_1 * 2) * s_0
* f_0,
+ lambda i_0, i_1, s_0, s_1, f_0: (i_1 * 2) * i_0 * (s_1 * 2) * f_0
* s_0,
+ lambda i_0, i_1, s_0, s_1, f_0: i_0,
+ lambda i_0, i_1, s_0, s_1, f_0: i_1,
+ lambda i_0, i_1, s_0, s_1, f_0: s_0,
+ lambda i_0, i_1, s_0, s_1, f_0: s_1,
+ lambda i_0, i_1, s_0, s_1, f_0: f_0,
+ ]
+ def make_func(step_func):
+ """This simulates:
+ def compute(input_0, input_1, s_0, s_1, s_2, f_0, length):
+ # here s_2 remains untouched
+ output_0 = []
+ output_1 = []
+ output_2 = []
+ for i in range(length):
+ i_0 = input_0[i]
+ i_1 = input_1[length - 1 - i]
+ out = i_0 + (i_1 * 2) + s_0 + (s_1 * 2) + f_0
+ out = out * i * i_0 * i_1
+ s_0 = (s_0 + out) * 1.05
+ s_1 = (s_1 - out * 0.5) * 0.95
+ output_0.append(out)
+ output_1.append(f_0)
+ output_2.append(out * 1.5)
+ return output_0, output_1, output_2, s_0, s_1, s_2
+ """
+ def step(loop, free):
+ (i, s_0, s_1, s_2), (sc_0, sc_1, f_0, _) = loop, free
+ i_0 = sc_0.take(i).squeeze(axis=0)
+ i_1 = sc_1.take(length - 1 - i).squeeze(axis=0)
+ out = step_func(i_0, i_1, s_0, s_1, f_0)
+ out = out * i.reshape([1] *
len(single_shape)).broadcast_to(single_shape)
+ out = out * i_0 * i_1
+ return ([], [i + 1, (s_0 + out) * 1.05, (s_1 - out * 0.5) *
0.95, s_2])
+ return step
+ case_id = 0
+ for is_train in [True, False]:
+ for step_func in step_funcs:
+ case_id += 1
+ _verify_while_loop(
+ func=make_func(step_func),
+ max_iterations=1000,
+ is_train=is_train,
+ is_for=True,
+ **params
+ )
+
+ def case_6(length, single_shape, **params):
+ # It is for the case that inputs & outputs are the same
+ # There are 3 outputs
+ # There are 4 states: i, s_0, s_1, s_2
+        # i is used in both non-differentiable (take) and differentiable (+) occasions
+ step_funcs = [
+ lambda i_0, i_1, s_0, s_1, f_0: i_0 * (i_1 * 2) * s_0 * (s_1 * 2)
* f_0,
+ lambda i_0, i_1, s_0, s_1, f_0: i_0 * (i_1 * 2) * s_0 * f_0 * (s_1
* 2),
+ lambda i_0, i_1, s_0, s_1, f_0: i_0 * (i_1 * 2) * (s_1 * 2) * s_0
* f_0,
+ lambda i_0, i_1, s_0, s_1, f_0: i_0 * (i_1 * 2) * (s_1 * 2) * f_0
* s_0,
+ lambda i_0, i_1, s_0, s_1, f_0: (i_1 * 2) * i_0 * s_0 * (s_1 * 2)
* f_0,
+ lambda i_0, i_1, s_0, s_1, f_0: (i_1 * 2) * i_0 * s_0 * f_0 * (s_1
* 2),
+ lambda i_0, i_1, s_0, s_1, f_0: (i_1 * 2) * i_0 * (s_1 * 2) * s_0
* f_0,
+ lambda i_0, i_1, s_0, s_1, f_0: (i_1 * 2) * i_0 * (s_1 * 2) * f_0
* s_0,
+ lambda i_0, i_1, s_0, s_1, f_0: i_0,
+ lambda i_0, i_1, s_0, s_1, f_0: i_1,
+ lambda i_0, i_1, s_0, s_1, f_0: s_0,
+ lambda i_0, i_1, s_0, s_1, f_0: s_1,
+ lambda i_0, i_1, s_0, s_1, f_0: f_0,
+ ]
+ def make_func(step_func):
+ """This simulates:
+ def compute(input_0, input_1, s_0, s_1, s_2, f_0, length):
+ # here s_2 remains untouched
+ output_0 = []
+ output_1 = []
+ output_2 = []
+ for i in range(length):
+ i_0 = input_0[i]
+ i_1 = input_1[length - 1 - i]
+ out = i_0 + (i_1 * 2) + s_0 + (s_1 * 2) + f_0
+ out = out * i * i_0 * i_1
+ s_0 = (s_0 + out) * 1.05
+ s_1 = (s_1 - out * 0.5) * 0.95
+ output_0.append(out)
+ output_1.append(f_0)
+ output_2.append(out * 1.5)
+ return output_0, output_1, output_2, s_0, s_1, s_2
+ """
+ def step(loop, free):
+ (i, s_0, s_1, s_2), (sc_0, sc_1, f_0, _) = loop, free
+ F = mx.sym if isinstance(i, mx.sym.Symbol) else mx.nd
+ i_0 = sc_0.take(i).squeeze(axis=0)
+ i_1 = sc_1.take(length - 1 - i).squeeze(axis=0)
+ out_0 = step_func(i_0, i_1, s_0, s_1, f_0)
+ out_0 = out_0 * i.reshape([1] *
len(single_shape)).broadcast_to(single_shape)
+ out_1 = step_func(i_1, s_0, f_0, s_1, i_0)
+ out_1 = out_1 * i.reshape([1] *
len(single_shape)).broadcast_to(single_shape)
+ return ([F.dot(out_0, s_2), f_0, F.dot(s_2, out_1) * 1.5], [i
+ 1, (s_0 + out_1) * 1.05, (s_1 - out_0 * 0.5) * 0.95, s_2])
+ return step
+ case_id = 0
+ for is_train in [True, False]:
+ for step_func in step_funcs:
+ case_id += 1
+ _verify_while_loop(
+ func=make_func(step_func),
+ max_iterations=1000,
+ is_train=is_train,
+ is_for=True,
+ **params
+ )
+
+    # Case 0: the simplest case
+ case_0()
+ # Case 1.1.*
+ case_1(
+ cond=make_true_cond(),
+ loop_var_shapes=[
+ (1, ), # s
+ ],
+ free_var_shapes=[
+ (1, ), # a
+ (1, ), # b
+ ],
+ max_iterations=23,
+ n_steps=23,
+ )
+ # Case 1.2.*
+ case_1(
+ cond=make_true_cond(),
+ loop_var_shapes=[
+ (2, 3, 4), # s
+ ],
+ free_var_shapes=[
+ (2, 3, 4), # a
+ (2, 3, 4), # b
+ ],
+ max_iterations=31,
+ n_steps=31,
+ )
+ # Case 1.3.*
+ case_1(
+ cond=make_false_cond(),
+ loop_var_shapes=[
+ (2, 3, 4), # s
+ ],
+ free_var_shapes=[
+ (2, 3, 4), # a
+ (2, 3, 4), # b
+ ],
+ max_iterations=20,
+ n_steps=0,
+ )
+ # Case 2.1.*
+ case_2(
+ cond=make_for_cond(length=31),
+ loop_var_shapes=[
+ (1, ), # i
+ (2, ), # s
+ ],
+ free_var_shapes=[
+ (100, 2), # scanned
+ (2, ), # f_1
+ (3, 4, 5, 6), # f_2, unused
+ ],
+ n_steps=31,
+ )
+ # Case 2.2.*
+ case_2(
+ cond=make_for_cond(length=25),
+ loop_var_shapes=[
+ (1, ), # i
+ (2, ), # s
+ ],
+ free_var_shapes=[
+ (30, 2), # scanned
+ (2, ), # f_1
+ (3, 4, 5, 6), # f_2, unused
+ ],
+ n_steps=25,
+ )
+ # Case 3.*
+ case_3(
+ length=11,
+ cond=make_for_cond(length=11),
+ loop_var_shapes=[
+ (1, ), # i
+ (2, ), # s_0
+ (2, ), # s_1
+ ],
+ free_var_shapes=[
+ (30, 2), # sc_0
+ (30, 2), # sc_1
+ (2, ), # f_0
+ (3, 4, 5, 6), # f_1, unused
+ ],
+ n_steps=11,
+ )
+ # Case 4.1.*
+ case_4(
+ length=4,
+ cond=make_for_cond(length=4),
+ single_shape=[5],
+ loop_var_shapes=[
+ (1, ), # i
+ (5, ), # s_0
+ (5, ), # s_1
+ (23, 6, 8), # s_2
+ ],
+ free_var_shapes=[
+ (30, 5), # sc_0
+ (30, 5), # sc_1
+ (5, ), # f_0
+ (3, 4, 5, 6), # f_1, unused
+ ],
+ n_steps=4,
+ )
+ # Case 4.2.*
+ case_4(
+ length=5,
+ cond=make_for_cond(length=5),
+ single_shape=[5, 12],
+ loop_var_shapes=[
+ (1, ), # i
+ (5, 12), # s_0
+ (5, 12), # s_1
+ (23, 6, 8), # s_2
+ ],
+ free_var_shapes=[
+ (30, 5, 12), # sc_0
+ (30, 5, 12), # sc_1
+ (5, 12), # f_0
+ (3, 4, 5, 6), # f_1, unused
+ ],
+ n_steps=5,
+ )
+ # Case 5.1.*
+ case_5(
+ length=4,
+ cond=make_for_cond(length=4),
+ single_shape=[5],
+ loop_var_shapes=[
+ (1, ), # i
+ (5, ), # s_0
+ (5, ), # s_1
+ (23, 6, 8), # s_2
+ ],
+ free_var_shapes=[
+ (30, 5), # sc_0
+ (30, 5), # sc_1
+ (5, ), # f_0
+ (3, 4, 5, 6), # f_1, unused
+ ],
+ n_steps=4,
+ )
+ # Case 5.2.*
+ case_5(
+ length=5,
+ cond=make_for_cond(length=5),
+ single_shape=[3, 4, 2],
+ loop_var_shapes=[
+ (1, ), # i
+ (3, 4, 2), # s_0
+ (3, 4, 2), # s_1
+ (23, 6, 8), # s_2
+ ],
+ free_var_shapes=[
+ (30, 3, 4, 2), # sc_0
+ (30, 3, 4, 2), # sc_1
+ (3, 4, 2), # f_0
+ (3, 4, 5, 6), # f_1, unused
+ ],
+ n_steps=5,
+ )
+ # Case 6.*
+ case_6(
+ length=5,
+ cond=make_for_cond(length=5),
+ single_shape=[5, 3],
+ loop_var_shapes=[
+ (1, ), # i
+ (5, 3), # s_0
+ (5, 3), # s_1
+ (3, 5), # s_2
+ ],
+ free_var_shapes=[
+ (30, 5, 3), # sc_0
+ (30, 5, 3), # sc_1
+ (5, 3), # f_0
+ (3, 4, 5, 6), # f_1, unused
+ ],
+ n_steps=5,
+ )
+
+
+def test_while_loop_nested():
+
+ def _to_np_list(arrays):
+ return [x.asnumpy() if x is not None else x for x in arrays]
+
+ def _array(shape):
+ return mx.nd.random.uniform(-1.0, 1.0, shape=shape)
+
+ def inner_cond(i, j, x_sum, sc):
+ return j < 2
+
+ def inner_body(i, j, x_sum, sc):
+ x_ij = sc.take(j).squeeze(axis=0)
+ return (x_ij, x_ij), (i, j + 1, x_sum, sc)
+
+ def outer_cond(i, j, x_sum, sc):
+ return i < 2
+
+ def outer_body(i, j, x_sum, sc):
+ F = mx.sym if isinstance(i, mx.sym.Symbol) else mx.nd
+ (x_ij, x_ji), (i_p, j_p, x_sum_p, sc_p) = F.contrib.while_loop(
+ cond=inner_cond,
+ func=inner_body,
+ loop_vars=(i, j, x_sum, sc),
+ max_iterations=2,
+ )
+ return (x_ij, x_ji), (i_p + 1, j_p - 2, x_sum_p, sc_p)
+
+ def make_loop(i, j, x_sum, sc):
+ F = mx.sym if isinstance(i, mx.sym.Symbol) else mx.nd
+ (x_ij, x_ji), (new_i, new_j, x_sum_p, sc_p) = F.contrib.while_loop(
+ cond=outer_cond,
+ func=outer_body,
+ loop_vars=(i, j, x_sum, sc),
+ max_iterations=2,
+ )
+ return new_i, new_j, x_sum_p, sc_p, x_ij, x_ji
+
+ args = {
+ "i": mx.nd.array([0]),
+ "j": mx.nd.array([0]),
+ "x_sum": _array([5, 3]),
+ "sc": _array([10, 10, 5, 3]),
+ }
+ args_grad = {
+ "x_sum": _array([5, 3]),
+ "sc": _array([10, 10, 5, 3]),
+ }
+ out_grad = [
+ _array([1]),
+ _array([1]),
+ _array([5, 3]),
+ _array([10, 10, 5, 3]),
+ _array([2, 2, 10, 5, 3]),
+ _array([2, 2, 10, 5, 3]),
+ ]
+ def _get_imp_result(is_train, args, args_grad, out_grad):
+ args = {k: v.copy() for k, v in args.items()}
+ args_grad = {k: v.copy() for k, v in args_grad.items()}
+ i, j, x_sum, sc = [args[x].copy() for x in ["i", "j", "x_sum", "sc"]]
+ if is_train:
+ x_sum.attach_grad()
+ sc.attach_grad()
+ with mx.autograd.record(train_mode=is_train):
+ results = make_loop(i, j, x_sum, sc)
+ cat_res = mx.nd.concat(*[x.reshape(-1) for x in results], dim=0)
+ if not is_train:
+ return _to_np_list(results), []
+ cat_grad = mx.nd.concat(*[x.reshape(-1) for x in out_grad], dim=0)
+ assert cat_grad.shape == cat_res.shape
+ cat_res.backward(out_grad=cat_grad)
+ grads = [x_sum.grad, sc.grad]
+ return _to_np_list(results), _to_np_list(grads)
+
+ def _get_sym_result(is_train, args, args_grad, out_grad):
+ args = {k: v.copy() for k, v in args.items()}
+ args_grad = {k: v.copy() for k, v in args_grad.items()}
+ i, j, x_sum, sc = [
+ mx.sym.var("i"),
+ mx.sym.var("j"),
+ mx.sym.var("x_sum"),
+ mx.sym.var("sc"),
+ ]
+ result_sym = mx.sym.Group(make_loop(i, j, x_sum, sc))
+ executor = result_sym.bind(
+ ctx=default_context(),
+ args=args,
+ args_grad=args_grad,
+ )
+ results = executor.forward(is_train=is_train)
+ if not is_train:
+ return _to_np_list(results), []
+ executor.backward(out_grads=out_grad)
+ grads = [executor.grad_dict["x_sum"], executor.grad_dict["sc"]]
+ return _to_np_list(results), _to_np_list(grads)
+
+ for is_train in [True, False]:
+ imp_out, imp_grad = _get_imp_result(is_train=is_train, args=args,
args_grad=args_grad, out_grad=out_grad)
+ sym_out, sym_grad = _get_sym_result(is_train=is_train, args=args,
args_grad=args_grad, out_grad=out_grad)
+ assert len(imp_out) == len(sym_out)
+ assert len(imp_grad) == len(sym_grad)
+ for x, y in zip(imp_out, sym_out):
+ assert_almost_equal(x, y, rtol=1e-4, atol=1e-4)
+ for x, y in zip(imp_grad, sym_grad):
+ assert_almost_equal(x, y, rtol=1e-4, atol=1e-4)
+
+
+def test_while_loop_rnn():
+ def _array(shape):
+ return mx.nd.random.uniform(-1.0, 1.0, shape=shape)
+
+ cell_types = [mx.rnn.LSTMCell]
+ num_params = [2]
+
+ batch_size = 2
+ hidden_dim = 3
+ input_dim = 4
+ seq_len = 3
+
+ for cell, n_param in zip(cell_types, num_params):
+ # using while_loop
+ params = mx.rnn.RNNParams()
+ data = mx.sym.var("data")
+ iter_i = mx.sym.var("i")
+ def _cond(*states):
+ i = states[0]
+ return i < seq_len
+ def _func(*states):
+ i = states[0]
+ states = states[1:]
+ in_ = data.take(i).squeeze(axis=0)
+ rnn = cell(hidden_dim, prefix='', params=params)
+ next_hidden, next_states = rnn(in_, states)
+ return [next_hidden], [i + 1] + list(next_states)
+ states = [mx.sym.var("s_" + str(i)) for i in range(n_param)]
+ result = mx.sym.contrib.while_loop(
+ cond=_cond,
+ func=_func,
+ loop_vars=[iter_i] + states,
+ max_iterations=seq_len
+ )
+ result = mx.sym.Group(result[0] + result[1][1: ])
+ arg_shapes, _, _ = result.infer_shape(
+ data=(seq_len, batch_size, input_dim),
+ s_0=(batch_size, hidden_dim),
+ )
+ rnn_inputs = result.list_inputs()
+ args = {name: _array(arg_shapes[i]) for i, name in
enumerate(rnn_inputs) if name != "i"}
+ args["i"] = mx.nd.zeros([1])
+ args_grad = {name: _array(arg_shapes[i]) for i, name in
enumerate(rnn_inputs)}
+ e_1 = result.bind(ctx=default_context(),
+ args={name: array.copy() for name, array in args.items()},
+ args_grad={name: array.copy() for name, array in args_grad.items()
if name != "i"},
+ )
+ # using unrolled rnn
+ rnn = cell(hidden_dim, prefix='')
+ unroll_outs = []
+ for inputs in mx.sym.split(data, num_outputs=seq_len, axis=0,
squeeze_axis=True):
+ h, states = rnn(inputs, states)
+ unroll_outs.append(mx.sym.expand_dims(h, axis=0))
+ unroll_outs = _as_list(mx.sym.concat(*unroll_outs, dim=0))
+ unroll_outs.extend(states)
+ result = mx.sym.Group(unroll_outs)
+ e_2 = result.bind(ctx=default_context(),
+ args={name: array.copy() for name, array in args.items() if name
!= "i"},
+ args_grad={name: array.copy() for name, array in args_grad.items()
if name != "i"},
+ )
+ for case_id in range(100):
+ out_grads = [_array(arr.shape) for arr in e_1.outputs]
+ args = {name: array.copy() for name, array in args.items()}
+ e_1.forward(is_train=True, **args)
+ e_1.backward(out_grads)
+ args = {name: array.copy() for name, array in args.items() if name
!= "i"}
+ e_2.forward(is_train=True, **args)
+ e_2.backward(out_grads)
+ assert len(e_1.outputs) == len(e_2.outputs)
+ for x, y in zip(e_1.outputs, e_2.outputs):
+ x = x.asnumpy()
+ y = y.asnumpy()
+ assert_almost_equal(x, y, rtol=1e-4, atol=1e-4)
+ grad_keys = list(e_2.grad_dict.keys())
+ e_1_grad = [e_1.grad_dict[x] for x in grad_keys]
+ e_2_grad = [e_2.grad_dict[x] for x in grad_keys]
+ for x, y in zip(e_1_grad, e_2_grad):
+ x = x.asnumpy()
+ y = y.asnumpy()
+ assert_almost_equal(x, y, rtol=1e-4, atol=1e-4)
+
+
+if __name__ == '__main__':
+ import nose
+ nose.runmodule()