szha closed pull request #11566: [MXNET-626] Add while_loop URL: https://github.com/apache/incubator-mxnet/pull/11566
This PR was merged from a forked repository. Because GitHub hides the original diff of a merged foreign pull request, the diff is reproduced below for the sake of provenance: diff --git a/3rdparty/tvm b/3rdparty/tvm index 6ab4da67834..290226e1c9a 160000 --- a/3rdparty/tvm +++ b/3rdparty/tvm @@ -1 +1 @@ -Subproject commit 6ab4da6783417d8afdeb6b0426b44959b2afc709 +Subproject commit 290226e1c9adbb3e598f9ed9184018df1c12be33 diff --git a/benchmark/python/control_flow/foreach_rnn.py b/benchmark/python/control_flow/foreach_rnn.py new file mode 100644 index 00000000000..4ce7a429ee9 --- /dev/null +++ b/benchmark/python/control_flow/foreach_rnn.py @@ -0,0 +1,195 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License.
+ +import subprocess +import mxnet as mx +from mxnet import gluon +import time +import copy + +def get_gpus(): + """ + return a list of GPUs + """ + try: + re = subprocess.check_output(["nvidia-smi", "-L"], universal_newlines=True) + except OSError: + return [] + return range(len([i for i in re.split('\n') if 'GPU' in i])) + +class TestRNNLayer(gluon.HybridBlock): + def __init__(self, cell, prefix=None, params=None): + super(TestRNNLayer, self).__init__(prefix=prefix, params=params) + self.cell = cell + + def hybrid_forward(self, F, inputs, states): + out, states = F.contrib.foreach(self.cell, inputs, states) + return out + +def benchmark_rnn(cell, rnn_data, states): + ctx = rnn_data.context + num_batches = 20 + + # Imperative + cell0 = copy.deepcopy(cell) + layer0 = TestRNNLayer(cell0) + layer0.initialize(ctx=ctx) + + # Hybridize + cell1 = copy.deepcopy(cell) + cell1.hybridize() + layer1 = TestRNNLayer(cell1) + layer1.initialize(ctx=ctx) + + # Hybridize + cell2 = copy.deepcopy(cell) + layer2 = TestRNNLayer(cell2) + layer2.initialize(ctx=ctx) + layer2.hybridize() + layer2(rnn_data, states) + + # Hybridize + cell3 = copy.deepcopy(cell) + cell3.hybridize(static_alloc=True) + layer3 = TestRNNLayer(cell3) + layer3.initialize(ctx=ctx) + + tic = time.time() + for i in range(num_batches): + res0 = layer0(rnn_data, states) + mx.nd.waitall() + print("Imperative inference takes " + str(time.time() - tic)) + + tic = time.time() + for i in range(num_batches): + res1 = layer1(rnn_data, states) + mx.nd.waitall() + print("Hybrid-cell inference takes " + str(time.time() - tic)) + + tic = time.time() + for i in range(num_batches): + res3 = layer3(rnn_data, states) + mx.nd.waitall() + print("Static-hybrid-cell inference takes " + str(time.time() - tic)) + + tic = time.time() + for i in range(num_batches): + res2 = layer2(rnn_data, states) + mx.nd.waitall() + print("Hybrid inference takes " + str(time.time() - tic)) + + layer2.export("foreach_rnn") + symnet = 
mx.symbol.load('foreach_rnn-symbol.json') + args1 = {} + params = layer2.collect_params() + for key in params.keys(): + args1[key] = params[key].data() + args1['data0'] = rnn_data + for i in range(len(states)): + args1['data' + str(i + 1)] = states[i] + exe = symnet.bind(ctx=ctx, args=args1) + tic = time.time() + for i in range(num_batches): + exe.forward(is_train=False) + mx.nd.waitall() + print("Symbol inference takes " + str(time.time() - tic)) + + tic = time.time() + for i in range(num_batches): + with mx.autograd.record(): + res0 = layer0(rnn_data, states) + res0.backward() + mx.nd.waitall() + print("Imperative training takes " + str(time.time() - tic)) + + tic = time.time() + for i in range(num_batches): + with mx.autograd.record(): + res1 = layer1(rnn_data, states) + res1.backward() + mx.nd.waitall() + print("Hybrid-cell training takes " + str(time.time() - tic)) + + tic = time.time() + for i in range(num_batches): + with mx.autograd.record(): + res3 = layer3(rnn_data, states) + res3.backward() + mx.nd.waitall() + print("Static-hybrid-cell training takes " + str(time.time() - tic)) + + tic = time.time() + for i in range(num_batches): + with mx.autograd.record(): + res2 = layer2(rnn_data, states) + res2.backward() + mx.nd.waitall() + print("Hybrid training takes " + str(time.time() - tic)) + + # gradients for the backward of the foreach symbol + args_grad1 = {} + for key in args1.keys(): + args_grad1[key] = mx.nd.empty(args1[key].shape, ctx=ctx) + exe = symnet.bind(ctx=ctx, args=args1, args_grad=args_grad1) + tic = time.time() + for i in range(num_batches): + exe.forward(is_train=True) + exe.backward(res2) + mx.nd.waitall() + print("Symbol training takes " + str(time.time() - tic)) + print("") + +if __name__ == '__main__': + ndim = 512 + seq_len = 100 + batch_sizes = [1, 32] + cells = [gluon.rnn.RNNCell(ndim, prefix='rnn_'), + gluon.rnn.GRUCell(ndim, prefix='rnn_'), + gluon.rnn.LSTMCell(ndim, prefix='rnn_')] + ctxs = [mx.cpu(0), mx.gpu(0)] + for cell in 
cells: + for ctx in ctxs: + for batch_size in batch_sizes: + if len(get_gpus()) == 0 and ctx == mx.gpu(0): + continue + if isinstance(cell, gluon.rnn.RNNCell): + rnn_data = mx.nd.normal(loc=0, scale=1, shape=(seq_len, batch_size, ndim), + ctx=mx.cpu(0)) + states = [] + states.append(mx.nd.normal(loc=0, scale=1, shape=(batch_size, ndim), + ctx=mx.cpu(0))) + elif isinstance(cell, gluon.rnn.GRUCell): + rnn_data = mx.nd.normal(loc=0, scale=1, shape=(seq_len, batch_size, ndim), + ctx=mx.cpu(0)) + states = [] + states.append(mx.nd.normal(loc=0, scale=1, shape=(batch_size, ndim), + ctx=mx.cpu(0))) + elif isinstance(cell, gluon.rnn.LSTMCell): + rnn_data = mx.nd.normal(loc=0, scale=1, shape=(seq_len, batch_size, ndim), + ctx=mx.cpu(0)) + states = [] + states.append(mx.nd.normal(loc=0, scale=1, shape=(batch_size, ndim), + ctx=mx.cpu(0))) + states.append(mx.nd.normal(loc=0, scale=1, shape=(batch_size, ndim), + ctx=mx.cpu(0))) + if ctx == mx.gpu(0): + dev = "GPU" + else: + dev = "CPU" + print("Benchmark {} in {} (batch size: {})".format(cell._alias(), dev, + batch_size)) + benchmark_rnn(cell, rnn_data, states) diff --git a/benchmark/python/control_flow/rnn.py b/benchmark/python/control_flow/rnn.py index 5e41b7508b6..8a44a9cab17 100644 --- a/benchmark/python/control_flow/rnn.py +++ b/benchmark/python/control_flow/rnn.py @@ -15,175 +15,128 @@ # specific language governing permissions and limitations # under the License. 
+from __future__ import print_function +from six.moves import range + +import argparse import subprocess +from itertools import product +from time import time + import mxnet as mx +import numpy as np from mxnet import gluon -import time -import copy -def get_gpus(): - """ - return a list of GPUs - """ - try: - re = subprocess.check_output(["nvidia-smi", "-L"], universal_newlines=True) - except OSError: - return [] - return range(len([i for i in re.split('\n') if 'GPU' in i])) -class TestRNNLayer(gluon.HybridBlock): - def __init__(self, cell, prefix=None, params=None): - super(TestRNNLayer, self).__init__(prefix=prefix, params=params) +_parser = argparse.ArgumentParser(description='Benchmark foreach and while_loop on RNN tasks.') +_parser.add_argument('--benchmark', choices=["foreach", "while_loop"], required=True) +_parser.add_argument('--warmup_rounds', type=int, default=20) +_parser.add_argument('--test_rounds', type=int, default=100) +args = _parser.parse_args() + + +class ForeachRNN(gluon.HybridBlock): + def __init__(self, cell, length, prefix=None, params=None): + super(ForeachRNN, self).__init__(prefix=prefix, params=params) + self.length = length self.cell = cell def hybrid_forward(self, F, inputs, states): out, states = F.contrib.foreach(self.cell, inputs, states) return out -def benchmark_rnn(cell, rnn_data, states): - ctx = rnn_data.context - num_batches = 20 - - # Imperative - cell0 = copy.deepcopy(cell) - layer0 = TestRNNLayer(cell0) - layer0.initialize(ctx=ctx) - - # Hybridize - cell1 = copy.deepcopy(cell) - cell1.hybridize() - layer1 = TestRNNLayer(cell1) - layer1.initialize(ctx=ctx) - - # Hybridize - cell2 = copy.deepcopy(cell) - layer2 = TestRNNLayer(cell2) - layer2.initialize(ctx=ctx) - layer2.hybridize() - layer2(rnn_data, states) - - # Hybridize - cell3 = copy.deepcopy(cell) - cell3.hybridize(static_alloc=True) - layer3 = TestRNNLayer(cell3) - layer3.initialize(ctx=ctx) - - tic = time.time() - for i in range(num_batches): - res0 = 
layer0(rnn_data, states) - mx.nd.waitall() - print("Imperative inference takes " + str(time.time() - tic)) - - tic = time.time() - for i in range(num_batches): - res1 = layer1(rnn_data, states) - mx.nd.waitall() - print("Hybrid-cell inference takes " + str(time.time() - tic)) - - tic = time.time() - for i in range(num_batches): - res3 = layer3(rnn_data, states) - mx.nd.waitall() - print("Static-hybrid-cell inference takes " + str(time.time() - tic)) - - tic = time.time() - for i in range(num_batches): - res2 = layer2(rnn_data, states) - mx.nd.waitall() - print("Hybrid inference takes " + str(time.time() - tic)) - - layer2.export("foreach_rnn") - symnet = mx.symbol.load('foreach_rnn-symbol.json') - args1 = {} - params = layer2.collect_params() - for key in params.keys(): - args1[key] = params[key].data() - args1['data0'] = rnn_data - for i in range(len(states)): - args1['data' + str(i + 1)] = states[i] - exe = symnet.bind(ctx=ctx, args=args1) - tic = time.time() - for i in range(num_batches): - exe.forward(is_train=False) - mx.nd.waitall() - print("Symbol inference takes " + str(time.time() - tic)) - - tic = time.time() - for i in range(num_batches): - with mx.autograd.record(): - res0 = layer0(rnn_data, states) - res0.backward() - mx.nd.waitall() - print("Imperative training takes " + str(time.time() - tic)) - - tic = time.time() - for i in range(num_batches): - with mx.autograd.record(): - res1 = layer1(rnn_data, states) - res1.backward() - mx.nd.waitall() - print("Hybrid-cell training takes " + str(time.time() - tic)) - - tic = time.time() - for i in range(num_batches): - with mx.autograd.record(): - res3 = layer3(rnn_data, states) - res3.backward() - mx.nd.waitall() - print("Static-hybrid-cell training takes " + str(time.time() - tic)) - - tic = time.time() - for i in range(num_batches): - with mx.autograd.record(): - res2 = layer2(rnn_data, states) - res2.backward() - mx.nd.waitall() - print("Hybrid training takes " + str(time.time() - tic)) - - # gradients for 
the backward of the foreach symbol - args_grad1 = {} - for key in args1.keys(): - args_grad1[key] = mx.nd.empty(args1[key].shape, ctx=ctx) - exe = symnet.bind(ctx=ctx, args=args1, args_grad=args_grad1) - tic = time.time() - for i in range(num_batches): - exe.forward(is_train=True) - exe.backward(res2) - mx.nd.waitall() - print("Symbol training takes " + str(time.time() - tic)) - print("") - -if __name__ == '__main__': - ndim = 512 - seq_len = 100 + +class WhileRNN(gluon.HybridBlock): + def __init__(self, cell, length, prefix=None, params=None): + super(WhileRNN, self).__init__(prefix=prefix, params=params) + self.length = length + self.cell = cell + + def hybrid_forward(self, F, inputs, states): + def _func(*states): + i = states[0] + s = states[1: ] + data = inputs.take(i).squeeze(axis=0) + out, new_s = self.cell(data, s) + new_s = [i + 1] + new_s + return out, new_s + out, states = F.contrib.while_loop( + cond=lambda i, *_: i < self.length, + func=_func, + loop_vars=states, + max_iterations=self.length, + ) + assert len(out) == 1 + return out[0] + + +def _zeros(shape, ctx): + return mx.nd.zeros(shape=shape, ctx=ctx) + + +def _array(shape, ctx): + return mx.nd.normal(loc=0.0, scale=1.0, shape=shape, ctx=ctx) + + +def _get_gpus(): + try: + re = subprocess.check_output(["nvidia-smi", "-L"], universal_newlines=True) + except OSError: + return [] + return range(len([i for i in re.split('\n') if 'GPU' in i])) + + +def run_benchmark(cell_type, ctx, seq_len, batch_size, hidden_dim): + obj = {"foreach": ForeachRNN, "while_loop": WhileRNN}[args.benchmark] + inputs = _array((seq_len, batch_size, hidden_dim), ctx) + states = [_array((batch_size, hidden_dim), ctx) for _ in cell_type(0).state_info()] + if args.benchmark == "while_loop": + states.insert(0, _zeros((1, ), ctx)) + + for is_train, is_hyb_cell, is_hyb_layer in product([True, False], [False, True], [False, True]): + cell = cell_type(hidden_dim) + if is_hyb_cell: + cell.hybridize(static_alloc=True) + layer = obj(cell, 
seq_len) + layer.initialize(ctx=ctx) + if is_hyb_layer: + layer.hybridize(static_alloc=True) + print("is_train = %r, hybridize_cell = %r, hybridize_layer = %r" % (is_train, is_hyb_cell, is_hyb_layer)) + times = [] + for _ in range(args.warmup_rounds + args.test_rounds): + tick = time() + if not is_train: + res = layer(inputs, states) + else: + with mx.autograd.record(): + res = layer(inputs, states) + if is_train: + res.backward() + mx.nd.waitall() + tock = time() + times.append((tock - tick) * 1000.0) + times = times[args.warmup_rounds: ] + print("Time used: mean = %.3f ms, std = %.3f ms" % (np.mean(times), np.std(times))) + + +def main(): + # testing configurations + cell_types = [gluon.rnn.RNNCell, + gluon.rnn.GRUCell, + gluon.rnn.LSTMCell] + ctxs = [mx.cpu(0)] + [mx.gpu(i) for i in _get_gpus()] + seq_lens = [100] batch_sizes = [1, 32] - cells = [gluon.rnn.GRUCell(ndim, prefix='rnn_'), - gluon.rnn.LSTMCell(ndim, prefix='rnn_')] - ctxs = [mx.cpu(0), mx.gpu(0)] - for cell in cells: - for ctx in ctxs: - for batch_size in batch_sizes: - if len(get_gpus()) == 0 and ctx == mx.gpu(0): - continue - - if isinstance(cell, gluon.rnn.GRUCell): - rnn_data = mx.nd.normal(loc=0, scale=1, shape=(seq_len, batch_size, ndim), - ctx=mx.cpu(0)) - states = [] - states.append(mx.nd.normal(loc=0, scale=1, shape=(batch_size, ndim), - ctx=mx.cpu(0))) - elif isinstance(cell, gluon.rnn.LSTMCell): - rnn_data = mx.nd.normal(loc=0, scale=1, shape=(seq_len, batch_size, ndim), - ctx=mx.cpu(0)) - states = [] - states.append(mx.nd.normal(loc=0, scale=1, shape=(batch_size, ndim), - ctx=mx.cpu(0))) - states.append(mx.nd.normal(loc=0, scale=1, shape=(batch_size, ndim), - ctx=mx.cpu(0))) - if ctx == mx.gpu(0): - dev = "GPU" - else: - dev = "CPU" - print("Benchmark {} in {} (batch size: {})".format(cell._alias(), dev, - batch_size)) - benchmark_rnn(cell, rnn_data, states) + hidden_dims = [512] + print("--------------------------------------") + print("Benchmarking", args.benchmark) + for cell_type, 
ctx, seq_len, batch_size, hidden_dim in product( \ + cell_types, ctxs, seq_lens, batch_sizes, hidden_dims): + print("--------------------------------------") + print("cell: %s ctx: %s length: %d batch size: %d dim: %d" % \ + (cell_type.__name__, str(ctx), seq_len, batch_size, hidden_dim)) + run_benchmark(cell_type, ctx, seq_len, batch_size, hidden_dim) + + +if __name__ == "__main__": + main() diff --git a/benchmark/python/control_flow/while_loop_rnn.py b/benchmark/python/control_flow/while_loop_rnn.py new file mode 100644 index 00000000000..42aaee5840d --- /dev/null +++ b/benchmark/python/control_flow/while_loop_rnn.py @@ -0,0 +1,213 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +# Code borrowed from ./benchmark/python/control_flow/foreach_rnn.py + +import subprocess +import mxnet as mx +from mxnet import gluon +import time +import copy + +def get_gpus(): + """ + return a list of GPUs + """ + try: + re = subprocess.check_output(["nvidia-smi", "-L"], universal_newlines=True) + except OSError: + return [] + return range(len([i for i in re.split('\n') if 'GPU' in i])) + +class TestRNNLayer(gluon.HybridBlock): + def __init__(self, cell, length, prefix=None, params=None): + super(TestRNNLayer, self).__init__(prefix=prefix, params=params) + self.length = length + self.cell = cell + + def hybrid_forward(self, F, inputs, states): + def _func(*states): + i = states[0] + s = states[1: ] + data = inputs.take(i).squeeze(axis=0) + out, new_s = self.cell(data, s) + new_s = [i + 1] + new_s + return out, new_s + out, states = F.contrib.while_loop( + cond=lambda i, *_: i < self.length, + func=_func, + loop_vars=states, + max_iterations=self.length, + ) + return out + states + +def benchmark_rnn(cell, rnn_data, states, length): + ctx = rnn_data.context + num_batches = 20 + + # Imperative + cell0 = copy.deepcopy(cell) + layer0 = TestRNNLayer(cell0, length) + layer0.initialize(ctx=ctx) + + # Hybrid-cell + cell1 = copy.deepcopy(cell) + cell1.hybridize() + layer1 = TestRNNLayer(cell1, length) + layer1.initialize(ctx=ctx) + + # Hybrid + cell2 = copy.deepcopy(cell) + layer2 = TestRNNLayer(cell2, length) + layer2.initialize(ctx=ctx) + layer2.hybridize() + layer2(rnn_data, states) + + # Static-hybrid-cell + cell3 = copy.deepcopy(cell) + cell3.hybridize(static_alloc=True) + layer3 = TestRNNLayer(cell3, length) + layer3.initialize(ctx=ctx) + + tic = time.time() + for i in range(num_batches): + res0 = layer0(rnn_data, states) + mx.nd.waitall() + print("Imperative inference takes " + str(time.time() - tic)) + + tic = time.time() + for i in range(num_batches): + res1 = layer1(rnn_data, states) + mx.nd.waitall() + print("Hybrid-cell inference takes " + str(time.time() 
- tic)) + + tic = time.time() + for i in range(num_batches): + res3 = layer3(rnn_data, states) + mx.nd.waitall() + print("Static-hybrid-cell inference takes " + str(time.time() - tic)) + + tic = time.time() + for i in range(num_batches): + res2 = layer2(rnn_data, states) + mx.nd.waitall() + print("Hybrid inference takes " + str(time.time() - tic)) + + layer2.export("while_loop_rnn") + symnet = mx.symbol.load('while_loop_rnn-symbol.json') + args1 = {} + params = layer2.collect_params() + for key in params.keys(): + args1[key] = params[key].data() + args1['data0'] = rnn_data + for i in range(len(states)): + args1['data' + str(i + 1)] = states[i] + exe = symnet.bind(ctx=ctx, args=args1) + tic = time.time() + for i in range(num_batches): + exe.forward(is_train=False) + mx.nd.waitall() + print("Symbol inference takes " + str(time.time() - tic)) + + tic = time.time() + for i in range(num_batches): + with mx.autograd.record(): + res0 = layer0(rnn_data, states) + res0[0].backward() + mx.nd.waitall() + print("Imperative training takes " + str(time.time() - tic)) + + tic = time.time() + for i in range(num_batches): + with mx.autograd.record(): + res1 = layer1(rnn_data, states) + res1[0].backward() + mx.nd.waitall() + print("Hybrid-cell training takes " + str(time.time() - tic)) + + tic = time.time() + for i in range(num_batches): + with mx.autograd.record(): + res3 = layer3(rnn_data, states) + res3[0].backward() + mx.nd.waitall() + print("Static-hybrid-cell training takes " + str(time.time() - tic)) + + tic = time.time() + for i in range(num_batches): + with mx.autograd.record(): + res2 = layer2(rnn_data, states) + res2[0].backward() + mx.nd.waitall() + print("Hybrid training takes " + str(time.time() - tic)) + + # gradients for the backward of the while_loop symbol + args_grad1 = {} + for key in args1.keys(): + if key != "data1": + args_grad1[key] = mx.nd.empty(args1[key].shape, ctx=ctx) + exe = symnet.bind(ctx=ctx, args=args1, args_grad=args_grad1) + tic = time.time() + 
for i in range(num_batches): + exe.forward(is_train=True) + exe.backward(res2) + mx.nd.waitall() + print("Symbol training takes " + str(time.time() - tic)) + print("") + +if __name__ == '__main__': + def _zeros(shape): + return mx.nd.zeros(shape=shape, ctx=mx.cpu(0)) + def _array(shape): + return mx.nd.normal(loc=0.0, scale=1.0, shape=shape, ctx=mx.cpu(0)) + ndim = 512 + seq_len = 100 + batch_sizes = [1, 32] + cells = [gluon.rnn.RNNCell(ndim, prefix='rnn_'), + gluon.rnn.GRUCell(ndim, prefix='rnn_'), + gluon.rnn.LSTMCell(ndim, prefix='rnn_')] + ctxs = [mx.cpu(0), mx.gpu(0)] + for cell in cells: + for ctx in ctxs: + for batch_size in batch_sizes: + if len(get_gpus()) == 0 and ctx == mx.gpu(0): + continue + if isinstance(cell, gluon.rnn.RNNCell): + rnn_data = _array((seq_len, batch_size, ndim)) + states = [ + _zeros((1, )), + _array((batch_size, ndim)), + ] + if isinstance(cell, gluon.rnn.GRUCell): + rnn_data = _array((seq_len, batch_size, ndim)) + states = [ + _zeros((1, )), + _array((batch_size, ndim)), + ] + elif isinstance(cell, gluon.rnn.LSTMCell): + rnn_data = _array((seq_len, batch_size, ndim)) + states = [ + _zeros((1, )), + _array((batch_size, ndim)), + _array((batch_size, ndim)), + ] + if ctx == mx.gpu(0): + dev = "GPU" + else: + dev = "CPU" + print("Benchmark {} in {} (batch size: {})".format(cell._alias(), dev, batch_size)) + benchmark_rnn(cell, rnn_data, states, seq_len) diff --git a/docs/api/python/ndarray/contrib.md b/docs/api/python/ndarray/contrib.md index 36a2c151e85..0cf8724de30 100644 --- a/docs/api/python/ndarray/contrib.md +++ b/docs/api/python/ndarray/contrib.md @@ -53,6 +53,7 @@ In the rest of this document, we list routines provided by the `ndarray.contrib` ifft quantize foreach + while_loop ``` ## API Reference diff --git a/docs/api/python/symbol/contrib.md b/docs/api/python/symbol/contrib.md index 66471656050..ba43f2d6633 100644 --- a/docs/api/python/symbol/contrib.md +++ b/docs/api/python/symbol/contrib.md @@ -53,6 +53,7 @@ In the rest of 
this document, we list routines provided by the `symbol.contrib` ifft quantize foreach + while_loop ``` ## API Reference diff --git a/python/mxnet/ndarray/contrib.py b/python/mxnet/ndarray/contrib.py index b1f065e9f82..b67cf5a55da 100644 --- a/python/mxnet/ndarray/contrib.py +++ b/python/mxnet/ndarray/contrib.py @@ -28,7 +28,7 @@ except ImportError: pass -__all__ = ["rand_zipfian"] +__all__ = ["rand_zipfian", "foreach", "while_loop"] # pylint: disable=line-too-long def rand_zipfian(true_classes, num_sampled, range_max, ctx=None): @@ -191,3 +191,175 @@ def check_input(inputs, in_type, msg): if not_data_list and len(outputs) == 1: outputs = outputs[0] return (outputs, states) + + +def while_loop(cond, func, loop_vars, max_iterations=None): + """Run a while loop with user-defined computation and loop condition. + + This operator simulates a while loop which iterately does customized computation + as long as the condition is satisfied. + + `loop_vars` is a list of NDArrays on which the computation uses. + + `cond` is a user-defined function, used as the loop condition. + It consumes `loop_vars`, and produces a scalar MXNet NDArray, + indicating the termination of the loop. + The loop ends when `cond` returns false (zero). + The `cond` is variadic, and its signature should be + `cond(*loop_vars) => NDArray`. + + `func` is a user-defined function, used as the loop body. + It also consumes `loop_vars`, and produces `step_output` and `new_loop_vars` at each step. + In each step, `step_output` should contain the same number elements. + Through all steps, the i-th element of `step_output` should have the same shape and dtype. + Also, `new_loop_vars` should contain the same number of elements as `loop_vars`, + and the corresponding element should have the same shape and dtype. + The `func` is variadic, and its signature should be + `func(*loop_vars) => (List[NDArray] step_output, List[NDArray] new_loop_vars)`. 
+ + `max_iterations` is a scalar that defines the maximum number of iterations allowed. + + This function returns two lists. + The first list has the length of `|step_output|`, + in which the i-th element are all i-th elements of + `step_output` from all steps, stacked along axis 0. + The second list has the length of `|loop_vars|`, + which represents final states of loop variables. + + .. warning:: + + For now, the axis 0 of all NDArrays in the first list are `max_iterations`, + due to lack of dynamic shape inference. + + .. warning:: + + When `cond` is never satisfied, we assume `step_output` is empty, + because it cannot be inferred. This is different from the symbolic version. + + Parameters + ---------- + cond: a Python function. + The loop condition. + func: a Python function. + The loop body. + loop_vars: list of NDArrays. + The initial values of the loop variables. + max_iterations: a python int. + Maximum number of iterations. + + Returns + ------ + outputs: list of NDArrays + stacked output from each step + states: list of NDArrays + final state + + Examples + -------- + >>> cond = lambda i, s: i <= 5 + >>> func = lambda i, s: ([i + s], [i + 1, s + i]) + >>> loop_vars = (mx.nd.array([0], dtype="int64"), mx.nd.array([1], dtype="int64")) + >>> outputs, states = mx.nd.contrib.while_loop(cond, func, loop_vars, max_iterations=10) + >>> outputs + [ + [[ 1] + [ 2] + [ 4] + [ 7] + [11] + [16] + [...] # undefined value + [...] + [...] 
+ [...]] + <NDArray 6x1 @cpu(0)>] + >>> states + [ + [6] + <NDArray 1 @cpu(0)>, + [16] + <NDArray 1 @cpu(0)>] + """ + def _to_python_scalar(inputs, type_, name): + """Converts "inputs", possibly typed mxnet NDArray, a numpy ndarray, other python types, + to the given type + """ + if isinstance(inputs, ndarray.NDArray): + inputs = inputs.asscalar() + try: + inputs = type_(inputs) + except: + raise ValueError("Cannot convert %s to python %s" % (name, type_.__name__)) + return inputs + + def _to_ndarray_tuple(inputs, name): + """Converts "inputs", possibly a single mxnet NDArray, a list of mxnet NDArray, + a tuple of mxnet NDArray, into a tuple of NDArray + """ + if isinstance(inputs, list): + inputs = tuple(inputs) + if isinstance(inputs, ndarray.NDArray): + inputs = (inputs, ) + if not isinstance(inputs, tuple): + raise ValueError("%s must be an NDArray, or a tuple or list of NDArrays" % (name, )) + for item in inputs: + if not isinstance(item, ndarray.NDArray): + raise ValueError("%s must be an NDArray, or a tuple or list of NDArrays" % (name, )) + return inputs + + def _func_wrapper(loop_vars): + """This wrapper unifies + "func: loop_vars -> new_loop_vars" + and "func: loop_vars -> (step_output, new_loop_vars)" + into "func: loop_vars -> (None or tuple of step_outputs, tuple of new_loop_vars) + """ + step_output, new_loop_vars = func(*loop_vars) + if step_output is None: + step_output = [] + if new_loop_vars is None: + new_loop_vars = [] + step_output = _to_ndarray_tuple(step_output, "step_output") + new_loop_vars = _to_ndarray_tuple(new_loop_vars, "new_loop_vars") + if len(loop_vars) != len(new_loop_vars): + raise ValueError("The length of loop_vars should be consistent during the loop") + return step_output, new_loop_vars + + if max_iterations is None: + raise ValueError("max_iterations should be specified") + max_iterations = _to_python_scalar(max_iterations, int, "max_iteration") + loop_vars = _to_ndarray_tuple(loop_vars, "loop_vars") + # It should be work as 
fine if loop_vars are empty I guess, + # but it is semantically unnecessary to include this case. + if len(loop_vars) == 0: + raise ValueError("loop_vars should contain at least one element") + + steps = 0 + outputs = [] + while steps < max_iterations and \ + _to_python_scalar(cond(*loop_vars), bool, "Return value of cond"): # loop condition + step_output, loop_vars = _func_wrapper(loop_vars) + outputs.append(step_output) + steps += 1 + if len(outputs) != steps or len(step_output) != len(outputs[0]): + raise ValueError("Number of elements in step_output should be the same in each step") + stacked_outputs = [] + for i_th, items in enumerate(zip(*outputs), 1): + # `mx.ndarray.pad` only support 4-D or 5-D inputs for now + # so we could not use it. + items = [x.expand_dims(0) for x in items] + if steps != max_iterations and items: + pad_shape = [max_iterations - steps] + list(items[0].shape[1: ]) + pad = ndarray.empty( + shape=pad_shape, + ctx=items[0].context, + dtype=items[0].dtype, + ) + items = list(items) + [pad] + try: + stacked_outputs.append(ndarray.op.concat(*items, dim=0)) + except ValueError: + raise ValueError("\n".join( + ["Shapes of %d-th elements in step_outputs are inconsistent, which are:" % i_th] + + [" Step %d, shape is %s" % (i, str(x.shape)) for i, x in enumerate(items)] + )) + return stacked_outputs, list(loop_vars) diff --git a/python/mxnet/symbol/contrib.py b/python/mxnet/symbol/contrib.py index 28bb507dd13..2c11921383c 100644 --- a/python/mxnet/symbol/contrib.py +++ b/python/mxnet/symbol/contrib.py @@ -34,7 +34,7 @@ from ..base import SymbolHandle, _as_list from ..attribute import AttrScope -__all__ = ["rand_zipfian", "foreach"] +__all__ = ["rand_zipfian", "foreach", "while_loop"] def rand_zipfian(true_classes, num_sampled, range_max): """Draw random samples from an approximately log-uniform or Zipfian distribution. 
@@ -336,3 +336,223 @@ def check_data(inputs, in_type, msg): states = states[0] return (outs, states) + +def while_loop(cond, func, loop_vars, max_iterations=None, name="while_loop"): + """Run a while loop with user-defined computation and loop condition. + + This operator simulates a while loop which iterately does customized computation + as long as the condition is satisfied. + + `loop_vars` is a list of Symbols on which the computation uses. + + `cond` is a user-defined function, used as the loop condition. + It consumes `loop_vars`, and produces a scalar MXNet symbol, + indicating the termination of the loop. + The loop ends when `cond` returns false (zero). + The `cond` is variadic, and its signature should be + `cond(*loop_vars) => Symbol`. + + `func` is a user-defined function, used as the loop body. + It also consumes `loop_vars`, and produces `step_output` and `new_loop_vars` at each step. + In each step, `step_output` should contain the same number elements. + Through all steps, the i-th element of `step_output` should have the same shape and dtype. + Also, `new_loop_vars` should contain the same number of elements as `loop_vars`, + and the corresponding element should have the same shape and dtype. + The `func` is variadic, and its signature should be + `func(*loop_vars) => (List[Symbol] step_output, List[Symbol] new_loop_vars)`. + + `max_iterations` is a scalar that defines the maximum number of iterations allowed. + + This function returns two lists. + The first list has the length of `|step_output|`, + in which the i-th element are all i-th elements of + `step_output` from all steps, stacked along axis 0. + The second list has the length of `|loop_vars|`, + which represents final states of loop variables. + + .. warning:: + + For now, the axis 0 of all Symbols in the first list are `max_iterations`, + due to lack of dynamic shape inference. + + .. 
warning:: + + Even if `cond` is never satisfied, + while_loop returns a list of outputs with inferred dtype and shape. + This is different from the Symbol version, + where in this case `step_outputs` are assumed as an empty list. + + Parameters + ---------- + cond: a Python function. + The loop condition. + func: a Python function. + The loop body. + loop_vars: list of Symbol. + The initial values of the loop variables. + max_iterations: a python int. + Maximum number of iterations. + + Returns + ------ + outputs: list of Symbols + stacked output from each step + states: list of Symbols + final state + + Examples + -------- + >>> cond = lambda i, s: i <= 5 + >>> func = lambda i, s: ([i + s], [i + 1, s + i]) + >>> loop_vars = (mx.sym.var('i'), mx.sym.var('s')) + >>> outputs, states = mx.sym.contrib.while_loop(cond, func, loop_vars, max_iterations=10) + """ + def _to_python_scalar(inputs, type_, name): + """Converts "inputs", possibly typed mxnet NDArray, a numpy ndarray, other python types, + to the given type + """ + if hasattr(inputs, "asscalar"): + inputs = inputs.asscalar() + try: + inputs = type_(inputs) + except: + raise ValueError("Cannot convert %s to python %s" % (name, type_.__name__)) + return inputs + + def _to_symbol_tuple(inputs, name): + """Converts "inputs", possibly a single mxnet Symbol, a list of mxnet Symbol, + a tuple of mxnet Symbol, into a tuple of Symbol + """ + if isinstance(inputs, list): + inputs = tuple(inputs) + if isinstance(inputs, Symbol): + inputs = (inputs, ) + if not isinstance(inputs, tuple): + raise ValueError("%s must be a Symbol, or a tuple or list of Symbol" % (name, )) + for item in inputs: + if not isinstance(item, Symbol): + raise ValueError("%s must be a Symbol, or a tuple or list of Symbol" % (name, )) + return inputs + + def _cond_wrapper(loop_vars): + result = cond(*loop_vars) + if not isinstance(result, Symbol): + raise ValueError("Return of cond must be a Symbol") + return [], [result] + + def 
_func_wrapper(loop_vars): + """This wrapper unifies + "func: loop_vars -> new_loop_vars" + and "func: loop_vars -> (step_output, new_loop_vars)" + into "func: loop_vars -> (list of step_outputs, tuple of new_loop_vars) + """ + step_output, new_loop_vars = func(*loop_vars) + if step_output is None: + step_output = [] + if new_loop_vars is None: + new_loop_vars = [] + step_output = _to_symbol_tuple(step_output, "step_output") + new_loop_vars = _to_symbol_tuple(new_loop_vars, "new_loop_vars") + if len(loop_vars) != len(new_loop_vars): + raise ValueError("The number of loop_vars should be consistent during the loop") + return list(step_output), list(new_loop_vars) + + def _create_subgraph(graph_vars, graph_func, subgraph_name): + with AttrScope(__subgraph_name__=subgraph_name): + # create new variables with the same name, + # them feed them to the given func + new_graph_vars = [symbol.var(sym.name) for sym in graph_vars] + outputs, final_state = graph_func(new_graph_vars) + # first `num_out_data` elements belong to `outputs` + # other elements belong to `final_state` + num_out_data = len(outputs) + num_outputs = len(outputs) + len(final_state) + # nnvm cut-graph does not allow inputs and outputs overlap + # so we calculate the name of inputs, and copy outputs once it overlaps with inputs + all_input_names = symbol.Group(outputs + final_state).list_inputs() + make_identity = lambda x: symbol.op.identity(x) if x.name in all_input_names else x + # group all outputs of graph_func + graph = symbol.Group(list(map(make_identity, outputs + final_state))) + return graph, num_out_data, num_outputs + + def _union_inputs(*graphs): + # Given a list of graphs, each whose inputs are either from loop_vars or other variables. + # 1) calculate a list `inputs`, the union of their inputs. 
+ # 2) for each graph, determine in which indices their inputs reside in `inputs` + # 3) for each variable in the input of `graph`, find which index it is + inputs = [] # List[Symbol], result of 1) + locs = [] # List[Tuple(List[Int], List[Int])], a list of tuples, + # where tuples are results of 2) and 3) + input_id_to_loc = {} # Dict[int, int], given id(sym), input_id_to_loc maps it + # to a `loc`, where inputs[loc] = sym + for graph in graphs: + # input_syms: all inputs to the `graph` + name_to_input_syms = {sym.name: sym for sym in _get_graph_inputs(graph)} + # some loop_vars are inputs to `graph`, some are not + name_to_loop_vars = {sym.name: sym for sym in loop_vars} + # other inputs to `graph` created by cut_graph + name_to_cut_g_syms = {sym.list_outputs()[0]: sym for sym in _cut_subgraph(graph)} + # also we collect the mapping from var's name to var's loc in loop_vars + name_to_var_locs = {sym.name: i for i, sym in enumerate(loop_vars)} + # collect arguments for each subgraph + input_locs = [] # results from the second step + var_locs = [-1] * len(loop_vars) # results from the third step + for name in graph.list_inputs(): + assert name in name_to_input_syms # it should obviously hold + # name -> sym + if name in name_to_loop_vars: + sym = name_to_loop_vars[name] + elif name in name_to_cut_g_syms: + sym = name_to_cut_g_syms[name] + else: + sym = copy.deepcopy(name_to_input_syms[name]) + # do 2), and 1) is implicitly done + if id(sym) in input_id_to_loc: + loc = input_id_to_loc[id(sym)] + else: + loc = len(input_id_to_loc) + inputs.append(sym) + input_id_to_loc[id(sym)] = loc + input_locs.append(loc) + # do 3) + if name in name_to_var_locs: + var_locs[name_to_var_locs[name]] = len(input_locs) - 1 + locs.append((input_locs, var_locs)) + return inputs, locs + if max_iterations is None: + raise ValueError("max_iterations should be specified") + max_iterations = _to_python_scalar(max_iterations, int, "max_iteration") + loop_vars = _to_symbol_tuple(loop_vars, 
"loop_vars") + # It should be work as fine if loop_vars are empty I guess, + # but it is semantically unnecessary to include this case. + if len(loop_vars) == 0: + raise ValueError("loop_vars should contain at least one element") + # create graph for `cond' + cond_g, num_out_data, num_outputs = \ + _create_subgraph(loop_vars, _cond_wrapper, name + "_cond") + assert num_out_data == 0 + assert num_outputs == 1 + # create graph for `func` + func_g, num_out_data, num_outputs = \ + _create_subgraph(loop_vars, _func_wrapper, name + "_func") + # find symbols used in either cond_g or func_g + input_syms, ((cond_input_locs, _), (func_input_locs, func_var_locs)) = \ + _union_inputs(cond_g, func_g) + for i_th, loc in enumerate(func_var_locs, 1): + if loc == -1: + raise ValueError("The %d-th loop_var doesn't involve into the computation" % i_th) + result = symbol._internal._while_loop( + # [cond, func_g, *input_syms] + cond_g, + func_g, + *input_syms, + max_iterations=max_iterations, + cond_input_locs=cond_input_locs, + func_input_locs=func_input_locs, + func_var_locs=func_var_locs, + num_out_data=num_out_data, + num_outputs=num_outputs + ) + outputs = [result[i] for i in range(num_out_data)] + final_loop_vars = [result[i] for i in range(num_out_data, num_outputs)] + return outputs, final_loop_vars diff --git a/src/operator/control_flow.cc b/src/operator/control_flow.cc index c091fdb67e0..b00ed9b19d8 100644 --- a/src/operator/control_flow.cc +++ b/src/operator/control_flow.cc @@ -480,6 +480,503 @@ ForeachGradient(const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ogra return entries; } +struct WhileLoopParam : public dmlc::Parameter<WhileLoopParam> { + int num_args; + int num_outputs; + int num_out_data; + int max_iterations; + // `cond' and `func' each takes a subset of while_loop's inputs as that to their subgraphs + // `cond_input_locs' contains indices of inputs fed to `cond', and + // `func_input_locs' contains indices of inputs fed to `func'. 
+ // `func_var_locs' are indices in which input "variables" are stored in func's inputs. + nnvm::Tuple<dim_t> cond_input_locs; + nnvm::Tuple<dim_t> func_input_locs; + nnvm::Tuple<dim_t> func_var_locs; + DMLC_DECLARE_PARAMETER(WhileLoopParam) { + DMLC_DECLARE_FIELD(num_args).set_lower_bound(2) + .describe("Number of input arguments, including cond and func as two symbol inputs."); + DMLC_DECLARE_FIELD(num_outputs).set_lower_bound(1) + .describe("The number of outputs of the subgraph."); + DMLC_DECLARE_FIELD(num_out_data).set_lower_bound(0) + .describe("The number of outputs from the function body."); + DMLC_DECLARE_FIELD(max_iterations).set_lower_bound(1) + .describe("Maximum number of iterations."); + DMLC_DECLARE_FIELD(cond_input_locs) + .describe("The locations of cond's inputs in the given inputs."); + DMLC_DECLARE_FIELD(func_input_locs) + .describe("The locations of func's inputs in the given inputs."); + DMLC_DECLARE_FIELD(func_var_locs) + .describe("The locations of loop_vars among func's inputs."); + } +}; // struct WhileLoopParam + +DMLC_REGISTER_PARAMETER(WhileLoopParam); + +class WhileLoopState: public LoopState { + public: + WhileLoopParam params; + size_t n_iterations; // the actual number of steps taken in this while loop, <= max_iterations + CachedOpPtr cond_op; + // abbrev for output_input_mapping + // indicates to which index the output of `func' will be copied to the input of `cond' + std::vector<int> oi_map; + + WhileLoopState(const WhileLoopParam ¶ms, const Symbol &cond, const Symbol &func) : + LoopState(func), + params(params), + n_iterations(0U), + cond_op(LoopState::MakeSharedOp(cond)), + oi_map(params.func_var_locs.ndim(), -1) { + const nnvm::Tuple<dim_t> &func_input_locs = params.func_input_locs; + const nnvm::Tuple<dim_t> &func_var_locs = params.func_var_locs; + const nnvm::Tuple<dim_t> &cond_input_locs = params.cond_input_locs; + for (size_t i = 0; i < func_var_locs.ndim(); ++i) { + dim_t pos_i = func_input_locs[func_var_locs[i]]; + for 
(size_t j = 0; j < cond_input_locs.ndim(); ++j) { + dim_t pos_j = cond_input_locs[j]; + if (pos_i == pos_j) { + this->oi_map[i] = j; + } + } + } + } + template <typename T> + static void extract_by_loc(const std::vector<T> &array, + const nnvm::Tuple<dim_t> input_locs, + std::vector<T> *out) { + out->clear(); + out->reserve(input_locs.ndim()); + for (dim_t i : input_locs) { + out->push_back(array[i]); + } + } + static bool is_shape_udf(const TShape &x) { + return x.ndim() == 0 || x.Size() == 0; + } + static bool is_stype_udf(const int &x) { + return x == exec::kBadStorageID; + } + static bool is_type_udf(const int &x) { + return x == -1; + } + template <typename T> + static bool fill_value(T *x, T *y, bool x_empty, bool y_empty) { + if (*x == *y || (x_empty && y_empty)) { + return true; + } + if (!x_empty && !y_empty) { + return false; + } + if (x_empty) { + *x = *y; + } + if (y_empty) { + *y = *x; + } + return true; + } + template <typename T> + static bool sync_in_in(const nnvm::Tuple<dim_t> &input_locs, + std::vector<T> *in, + std::vector<T> *subg_in, + std::function<bool(const T &)> is_empty) { + for (size_t i = 0; i < input_locs.ndim(); ++i) { + T &x = in->at(input_locs[i]); + T &y = subg_in->at(i); + fill_value(&x, &y, is_empty(x), is_empty(y)); + } + return true; + } + template <typename T> + static bool sync_in_out(const WhileLoopParam& params, + std::vector<T> *in, + std::vector<T> *out, + std::function<bool(const T &)> is_empty) { + for (int i = params.num_out_data; i < params.num_outputs; ++i) { + // each out->at(i) is a params, loop_var + T &x = in->at(params.func_input_locs[params.func_var_locs[i - params.num_out_data]]); + T &y = out->at(i); + fill_value(&x, &y, is_empty(x), is_empty(y)); + } + return true; + } +}; + +template <typename T> +T _asscalar(const NDArray &a) { + CHECK_EQ(a.shape().Size(), 1U); + T data; + a.SyncCopyToCPU(&data, 1U); + return data; +} + +bool as_bool_scalar(const NDArray &a) { + MSHADOW_TYPE_SWITCH(a.dtype(), DType, { + 
return static_cast<bool>(_asscalar<DType>(a)); + }); + LOG(FATAL) << "Unknown dtype"; + return false; +} + +static void WhileLoopComputeExCPU(const OpStatePtr& state_ptr, + const OpContext& ctx, + const std::vector<NDArray>& inputs, + const std::vector<OpReqType>& req, + const std::vector<NDArray>& outputs) { + // The argument `inputs' are loop_vars and other inputs + // loop_vars are stored in stored in `loop_vars_locs' + // The argument `outputs' are output and new_loop_vars + // [0: num_out_data) are outputs at each step. + // [num_out_data: ) are new_loop_vars + // TODO(Junru): avoid dynamic NDArray allocation + WhileLoopState &state = state_ptr.get_state<WhileLoopState>(); + const WhileLoopParam& params = state.params; + // a helper function, converting std::vector<NDArray> to std::vector<NDArray*> + const auto to_ptr_vec = [](std::vector<NDArray> &in, std::vector<NDArray*> *out) { + out->clear(); + out->reserve(in.size()); + std::transform(std::begin(in), + std::end(in), + std::back_inserter(*out), + [](NDArray &a) {return &a;}); + }; + // sanity checks + CHECK_EQ(inputs.size() + 2U, (size_t) params.num_args); + CHECK_EQ(outputs.size(), (size_t) params.num_outputs); + CHECK_EQ(outputs.size(), req.size()); + for (size_t i = 0; i < (size_t) params.num_out_data; i++) + CHECK_EQ(params.max_iterations, outputs[i].shape()[0]); + // construct inputs and outputs for cond + std::vector<NDArray> cond_inputs, cond_outputs = {NDArray()}; + WhileLoopState::extract_by_loc(inputs, params.cond_input_locs, &cond_inputs); + std::vector<NDArray*> cond_input_ptr, cond_output_ptr; + to_ptr_vec(cond_inputs, &cond_input_ptr); + to_ptr_vec(cond_outputs, &cond_output_ptr); + // construct inputs and outputs for func + std::vector<NDArray> func_inputs, func_outputs(outputs.size()); + WhileLoopState::extract_by_loc(inputs, params.func_input_locs, &func_inputs); + for (size_t &step = state.n_iterations = 0; step < (size_t) params.max_iterations; ++step) { + 
state.cond_op->Forward(nullptr, cond_input_ptr, cond_output_ptr); + if (!as_bool_scalar(*cond_output_ptr[0])) { + break; + } + // we create func_outputs for the current step: + // func_outputs[0: num_out_data] is a slice of outputs[][step] + for (size_t i = 0; i < (size_t) params.num_out_data; ++i) { + func_outputs[i] = outputs[i].At(step); + } + // func_outputs[num_out_data: ] are new_loop_vars, need to allocate new memory + for (size_t i = params.num_out_data; i < outputs.size(); ++i) { + func_outputs[i] = NDArray(outputs[i].shape(), outputs[i].ctx(), true, outputs[i].dtype()); + } + state.Forward(step, func_inputs, req, func_outputs, ctx.need_grad); + // func_inputs on the next step: + // the output (new_loop_vars) will become the new inputs (loop_vars) + for (size_t i = params.num_out_data; i < outputs.size(); ++i) { + size_t j = params.func_var_locs[i - params.num_out_data]; + CHECK_EQ(func_inputs[j].shape(), func_outputs[i].shape()); + func_inputs[j] = func_outputs[i]; + int k = state.oi_map[i - params.num_out_data]; + if (k != -1) { + // I actually don't need to update cond_inputs + cond_inputs[k] = func_outputs[i]; + cond_input_ptr[k] = &func_outputs[i]; + } + } + } + // copy output data to `outputs' + // case 1: at least one step is executed, + // the final_loop_vars must be stored in func_inputs + // case 2: no step is executed + // the final_loop_vars is the same as loop_vars, which are also stored in func_inputs + // therefore, we copy func_inputs[:] to outputs[num_out_data: ] + for (size_t i = params.num_out_data; i < outputs.size(); ++i) { + size_t j = params.func_var_locs[i - params.num_out_data]; + mxnet::CopyFromTo(func_inputs[j], &outputs[i]); + } +} + +static void WhileLoopGradComputeExCPU(const OpStatePtr& state_ptr, + const OpContext& ctx, + const std::vector<NDArray>& inputs, + const std::vector<OpReqType>& _req, + const std::vector<NDArray>& _outputs) { + // inputs are dl / df(x) + // outputs are dl / dx + // where f is the current function, 
+ // x is the input to the current function, + // TODO(Junru): avoid dynamic NDArray allocation + WhileLoopState &state = state_ptr.get_state<WhileLoopState>(); + const WhileLoopParam& params = state.params; + // sanity checks + CHECK_EQ(_outputs.size() + 2U, (size_t) params.num_args); + CHECK_EQ(_outputs.size(), _req.size()); + for (auto x : _req) { + CHECK_NE(x, kWriteInplace); + } + std::vector<NDArray> outputs; + std::vector<OpReqType> req; + WhileLoopState::extract_by_loc(_outputs, params.func_input_locs, &outputs); + WhileLoopState::extract_by_loc(_req, params.func_input_locs, &req); + if (state.n_iterations == 0) { + for (int i = params.num_out_data; i < params.num_outputs; ++i) { + int j = params.func_var_locs[i - params.num_out_data]; + mxnet::CopyFromTo(inputs[i], &outputs[j]); + } + state.Cleanup(); + return; + } + // collect var_locs and out_locs, positions other than var_locs are out_locs, i.e. + // [0, var_locs[0]) + // (var_locs[1], var_locs[2]), + // (var_locs[2], var_locs[3]), + // ... 
+ // (var_locs[-2], var_locs[-1] = params.num_args - 2) + std::vector<dim_t> var_locs(params.func_var_locs.begin(), params.func_var_locs.end()); + var_locs.push_back((dim_t) params.num_args - 2U); + sort(var_locs.begin(), var_locs.end()); + // vectors for the backward loop + std::vector<NDArray> ograds(params.num_outputs); + std::vector<NDArray> igrads(outputs.size()); + std::vector<OpReqType> iter_req(req.size()); + for (int i = params.num_out_data; i < params.num_outputs; ++i) + ograds[i] = inputs[i]; + const int n_iter = state.n_iterations; + for (int step = n_iter - 1; step >= 0; --step) { + // ograds[ : num_out_data] = inputs[ : num_out_data][step] + // ograds[num_out_data: ] is maintained in the end of each loop + std::transform(std::begin(inputs), + std::begin(inputs) + params.num_out_data, + std::begin(ograds), + [step] (const NDArray &a) { return a.At(step); } ); + // igrads[i] = + // outputs[i] (step == 0) + // outputs[i] (step != 0 && i not in loop_var_locs) + // ArrayLike(outputs[i]) (step != 0 && i in loop_var_locs) + // iter_req = + // kWriteTo (step != 0 && i in loop_var_locs) + // req[i] (step == 0 && i in loop_var_locs) + // kAddTo (step != n_iters - 1 && i not in loop_var_locs) + // req[i] (step == n_iters - 1 && i not in loop_var_locs) + { + size_t i = 0; + for (size_t loc : var_locs) { + for ( ; i < loc; ++i) { + // locs other that var_locs + igrads[i] = outputs[i]; + iter_req[i] = (step + 1 == n_iter || req[i] == kNullOp) + ? req[i] + : kAddTo; + } + if (i < (size_t) params.num_args - 2U) { + // a var + igrads[i] = (step == 0) + ? outputs[i] + : NDArray(outputs[i].shape(), outputs[i].ctx(), true, outputs[i].dtype()); + iter_req[i] = (step == 0 || req[i] == kNullOp) + ? 
req[i] + : kWriteTo; + ++i; + } else { + break; + } + } + } + state.Backward(step, ograds, iter_req, igrads); + for (int i = params.num_out_data; i < params.num_outputs; ++i) { + size_t j = params.func_var_locs[i - params.num_out_data]; + ograds[i] = igrads[j]; + } + } + state.Cleanup(); +} + +static bool WhileLoopShape(const nnvm::NodeAttrs& attrs, + std::vector<TShape> *in_shape, + std::vector<TShape> *out_shape) { + using nnvm::ShapeVector; + const WhileLoopParam& params = nnvm::get<WhileLoopParam>(attrs.parsed); + static const std::function<bool(const TShape &)> is_udf = WhileLoopState::is_shape_udf; + // sanity checks + CHECK_EQ(in_shape->size() + 2U, (size_t) params.num_args); + CHECK_EQ(out_shape->size(), (size_t) params.num_outputs); + CHECK_EQ(attrs.subgraphs.size(), 2U); + CHECK_EQ(attrs.subgraphs[0]->outputs.size(), 1U); + // infer shape for cond and func + auto infer_subg = [¶ms, in_shape, out_shape](std::shared_ptr<Symbol> subg, + ShapeVector *_subg_out, + const nnvm::Tuple<dim_t> &input_locs, + int num_out_data, + bool fill_out_shape) { + // create subg_in + ShapeVector subg_in; + ShapeVector &subg_out = *_subg_out; + WhileLoopState::extract_by_loc(*in_shape, input_locs, &subg_in); + // create an indexed graph + nnvm::Graph g; + g.outputs = subg->outputs; + const auto& idx = g.indexed_graph(); + // get input nodes + const auto &input_nids = idx.input_nodes(); + // sanity checks + CHECK_EQ(input_nids.size(), subg_in.size()); + CHECK_EQ(g.outputs.size(), subg_out.size()); + CHECK_EQ(idx.input_nodes().size(), subg_in.size()); + CHECK_EQ(idx.outputs().size(), subg_out.size()); + // create empty shapes for inference + ShapeVector shapes(idx.num_node_entries()); + // copy subg_in into shapes + for (size_t i = 0; i < subg_in.size(); ++i) { + auto eid = idx.entry_id(input_nids[i], 0); + shapes[eid] = subg_in[i]; + } + // copy subg_out into shapes + // note that ndim of out_data is not increased + // because subg is only one step + for (size_t i = 0; i < 
subg_out.size(); ++i) { + auto eid = idx.entry_id(g.outputs[i]); + shapes[eid] = subg_out[i]; + } + // copy done, call InferShape + g.attrs["shape"] = std::make_shared<dmlc::any>(std::move(shapes)); + g = exec::InferShape(std::move(g)); + // now `shapes' won't be used anymore, use new_shapes instead + const auto& new_shapes = g.GetAttr<ShapeVector>("shape"); + // copy subg_in back to in_shape + for (size_t i = 0; i < subg_in.size(); ++i) { + auto eid = idx.entry_id(input_nids[i], 0); + auto g_out_shape = new_shapes[eid]; + if (g_out_shape.ndim() == 0 || g_out_shape.Size() == 0) { + // when the shape is not fully inferred + continue; + } + SHAPE_ASSIGN_CHECK(*in_shape, input_locs[i], g_out_shape); + } + if (!fill_out_shape) { + return true; + } + // copy subg_out back to out_shape + // for results in [0, num_out_data), ndim should increase by 1 + for (int i = 0; i < num_out_data; ++i) { + auto eid = idx.entry_id(g.outputs[i]); + auto g_out_shape = new_shapes[eid]; + if (g_out_shape.ndim() == 0 || g_out_shape.Size() == 0) { + // when the shape is not fully inferred + continue; + } + auto out = TShape(g_out_shape.ndim() + 1); + out[0] = params.max_iterations; + for (size_t i = 1; i < out.ndim(); i++) + out[i] = g_out_shape[i - 1]; + SHAPE_ASSIGN_CHECK(*out_shape, i, out); + } + // for results in [num_out_data, ...), ndim does not change + for (size_t i = num_out_data; i < g.outputs.size(); ++i) { + auto eid = idx.entry_id(g.outputs[i]); + auto g_out_shape = new_shapes[eid]; + if (g_out_shape.ndim() == 0 || g_out_shape.Size() == 0) { + // when the shape is not fully inferred + continue; + } + SHAPE_ASSIGN_CHECK(*out_shape, i, g_out_shape); + } + return g.GetAttr<size_t>("shape_num_unknown_nodes") == 0; + }; + ShapeVector cond_out_shape{TShape(1U)}; // this means: [(1, )] + ShapeVector func_out_shape(params.num_outputs); + CHECK(WhileLoopState::sync_in_out(params, in_shape, out_shape, is_udf)); + bool succ_0 = infer_subg(attrs.subgraphs[0], &cond_out_shape, 
params.cond_input_locs, 0, false); + CHECK(WhileLoopState::sync_in_out(params, in_shape, out_shape, is_udf)); + bool succ_1 = infer_subg(attrs.subgraphs[1], &func_out_shape, \ + params.func_input_locs, params.num_out_data, true); + CHECK(WhileLoopState::sync_in_out(params, in_shape, out_shape, is_udf)); + return succ_0 && succ_1; +} + +static bool WhileLoopType(const nnvm::NodeAttrs& attrs, + std::vector<int> *in_type, std::vector<int> *out_type) { + const WhileLoopParam& params = nnvm::get<WhileLoopParam>(attrs.parsed); + static const std::function<bool(const int &)> is_udf = WhileLoopState::is_type_udf; + CHECK_EQ(in_type->size() + 2U, (size_t) params.num_args); + CHECK_EQ(out_type->size(), (size_t) params.num_outputs); + CHECK_EQ(attrs.subgraphs.size(), 2U); + CHECK_EQ(attrs.subgraphs[0]->outputs.size(), 1U); + std::vector<int> cond_in_type; + std::vector<int> func_in_type; + WhileLoopState::extract_by_loc(*in_type, params.cond_input_locs, &cond_in_type); + WhileLoopState::extract_by_loc(*in_type, params.func_input_locs, &func_in_type); + std::vector<int> cond_out_type = {0}; + CHECK(WhileLoopState::sync_in_out(params, in_type, out_type, is_udf)); + bool succ_0 = InferSubgraphDataType(*attrs.subgraphs[0], &cond_in_type, &cond_out_type); + CHECK(WhileLoopState::sync_in_out(params, in_type, out_type, is_udf)); + CHECK(WhileLoopState::sync_in_in(params.cond_input_locs, in_type, &cond_in_type, is_udf)); + bool succ_1 = InferSubgraphDataType(*attrs.subgraphs[1], &func_in_type, out_type); + CHECK(WhileLoopState::sync_in_out(params, in_type, out_type, is_udf)); + CHECK(WhileLoopState::sync_in_in(params.func_input_locs, in_type, &func_in_type, is_udf)); + return succ_0 && succ_1; +} + +static bool WhileLoopStorageType(const nnvm::NodeAttrs& attrs, + const int dev_mask, + DispatchMode* dispatch_mode, + std::vector<int> *in_attrs, + std::vector<int> *out_attrs) { + const WhileLoopParam& params = nnvm::get<WhileLoopParam>(attrs.parsed); + static const 
std::function<bool(const int &)> is_udf = WhileLoopState::is_stype_udf; + CHECK_EQ(in_attrs->size() + 2U, (size_t) params.num_args); + CHECK_EQ(out_attrs->size(), (size_t) params.num_outputs); + CHECK_EQ(attrs.subgraphs.size(), 2U); + CHECK_EQ(attrs.subgraphs[0]->outputs.size(), 1U); + std::vector<int> cond_in_attrs; + std::vector<int> func_in_attrs; + WhileLoopState::extract_by_loc(*in_attrs, params.cond_input_locs, &cond_in_attrs); + WhileLoopState::extract_by_loc(*in_attrs, params.func_input_locs, &func_in_attrs); + std::vector<int> cond_out_attrs = {kDefaultStorage}; + DispatchMode cond_mode = DispatchMode::kUndefined; + DispatchMode func_mode = DispatchMode::kUndefined; + *dispatch_mode = DispatchMode::kFComputeEx; + CHECK(WhileLoopState::sync_in_out(params, in_attrs, out_attrs, is_udf)); + bool succ_0 = InferSubgraphStorage(*attrs.subgraphs[0], dev_mask, \ + &cond_mode, &cond_in_attrs, &cond_out_attrs); + CHECK(WhileLoopState::sync_in_out(params, in_attrs, out_attrs, is_udf)); + CHECK(WhileLoopState::sync_in_in(params.cond_input_locs, in_attrs, &cond_in_attrs, is_udf)); + bool succ_1 = InferSubgraphStorage(*attrs.subgraphs[1], dev_mask, \ + &func_mode, &func_in_attrs, out_attrs); + CHECK(WhileLoopState::sync_in_out(params, in_attrs, out_attrs, is_udf)); + CHECK(WhileLoopState::sync_in_in(params.func_input_locs, in_attrs, &func_in_attrs, is_udf)); + return succ_0 && succ_1; +} + +static bool BackwardWhileLoopStorageType(const nnvm::NodeAttrs& attrs, + const int dev_mask, + DispatchMode* dispatch_mode, + std::vector<int> *in_attrs, + std::vector<int> *out_attrs) { + // `cond' is not backwarded, don't check + const WhileLoopParam& params = nnvm::get<WhileLoopParam>(attrs.parsed); + CHECK_EQ(out_attrs->size() + 2U, (size_t) params.num_args); + CHECK_EQ(attrs.subgraphs.size(), 2U); + CachedOp op(*attrs.subgraphs[1], {}); + return op.BackwardStorageType(attrs, dev_mask, dispatch_mode, + in_attrs, out_attrs); +} + +static OpStatePtr CreateWhileLoopState(const 
NodeAttrs& attrs, + Context ctx, + const std::vector<TShape>& ishape, + const std::vector<int>& itype) { + const WhileLoopParam& params = nnvm::get<WhileLoopParam>(attrs.parsed); + return OpStatePtr::Create<WhileLoopState>(params, *attrs.subgraphs[0], *attrs.subgraphs[1]); +} + +static std::vector<nnvm::NodeEntry> +WhileLoopGradient(const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) { + ElemwiseGradUseInOut fgrad{"_backward_while_loop"}; + std::vector<nnvm::NodeEntry> entries = fgrad(n, ograds); + entries[0].node->attrs.subgraphs = n->attrs.subgraphs; + return entries; +} + NNVM_REGISTER_OP(_foreach) .MXNET_DESCRIBE("Run a for loop over an NDArray with user-defined computation") .set_attr_parser(ParamParser<ForeachParam>) @@ -526,11 +1023,11 @@ NNVM_REGISTER_OP(_backward_foreach) .set_num_inputs([](const NodeAttrs& attrs){ const ForeachParam& params = nnvm::get<ForeachParam>(attrs.parsed); return params.num_outputs * 2 + params.num_args - 1; - }) +}) .set_num_outputs([](const NodeAttrs& attrs){ const ForeachParam& params = nnvm::get<ForeachParam>(attrs.parsed); return params.num_args - 1; - }) +}) .set_attr<FExecType>("FExecType", [](const NodeAttrs& attrs) { return ExecType::kSubgraphExec; }) @@ -541,5 +1038,67 @@ NNVM_REGISTER_OP(_backward_foreach) .set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", ForeachGradComputeExCPU) .set_attr<FStatefulComputeEx>("FStatefulComputeEx<gpu>", ForeachGradComputeExCPU); +NNVM_REGISTER_OP(_while_loop) +.MXNET_DESCRIBE("Run a while loop over with user-defined condition and computation") +.set_attr_parser(ParamParser<WhileLoopParam>) +.set_attr<FInferStorageType>("FInferStorageType", WhileLoopStorageType) +.set_num_inputs([](const NodeAttrs& attrs) { + const WhileLoopParam& params = nnvm::get<WhileLoopParam>(attrs.parsed); + return params.num_args; +}) +.set_num_outputs([](const NodeAttrs& attrs) { + const WhileLoopParam& params = nnvm::get<WhileLoopParam>(attrs.parsed); + return params.num_outputs; +}) 
+.set_attr<nnvm::FListInputNames>("FListInputNames", + [](const NodeAttrs& attrs) { + const WhileLoopParam& params = nnvm::get<WhileLoopParam>(attrs.parsed); + std::vector<std::string> names; + names.reserve(params.num_args); + names.push_back("cond"); + names.push_back("func"); + for (int i = 2; i < params.num_args; i++) + names.push_back("data" + std::to_string(i - 2)); + return names; +}) +.set_attr<nnvm::FInputGraph>("FInputGraph", + [](const NodeAttrs& attrs) { + return std::vector<uint32_t>{0, 1}; +}) +.set_attr<nnvm::FGradient>("FGradient", WhileLoopGradient) +.set_attr<FCreateOpState>("FCreateOpState", CreateWhileLoopState) +.set_attr<nnvm::FInferShape>("FInferShape", WhileLoopShape) +.set_attr<nnvm::FInferType>("FInferType", WhileLoopType) +.set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", WhileLoopComputeExCPU) +.set_attr<FExecType>("FExecType", [](const NodeAttrs& attrs) { + return ExecType::kSubgraphExec; +}) +.set_attr<FStatefulComputeEx>("FStatefulComputeEx<gpu>", WhileLoopComputeExCPU) +.set_attr<std::string>("key_var_num_args", "num_args") +.add_argument("cond", "Symbol", "Input graph for the loop condition.") +.add_argument("func", "Symbol", "Input graph for the loop body.") +.add_argument("data", "NDArray-or-Symbol[]", + "The input arrays that include data arrays and states.") +.add_arguments(WhileLoopParam::__FIELDS__()); + +NNVM_REGISTER_OP(_backward_while_loop) +.set_num_inputs([](const NodeAttrs& attrs){ + const WhileLoopParam& params = nnvm::get<WhileLoopParam>(attrs.parsed); + return params.num_outputs * 2 + params.num_args - 2; +}) +.set_num_outputs([](const NodeAttrs& attrs){ + const WhileLoopParam& params = nnvm::get<WhileLoopParam>(attrs.parsed); + return params.num_args - 2; +}) +.set_attr<FExecType>("FExecType", [](const NodeAttrs& attrs) { + return ExecType::kSubgraphExec; +}) +.set_attr<FInferStorageType>("FInferStorageType", BackwardWhileLoopStorageType) +.set_attr_parser(ParamParser<WhileLoopParam>) 
+.set_attr<bool>("TIsLayerOpBackward", true) +.set_attr<nnvm::TIsBackward>("TIsBackward", true) +.set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", WhileLoopGradComputeExCPU) +.set_attr<FStatefulComputeEx>("FStatefulComputeEx<gpu>", WhileLoopGradComputeExCPU); + } // namespace op } // namespace mxnet diff --git a/src/operator/subgraph_op_common.cc b/src/operator/subgraph_op_common.cc index 71a9a21c28c..d845aa907d3 100644 --- a/src/operator/subgraph_op_common.cc +++ b/src/operator/subgraph_op_common.cc @@ -164,14 +164,7 @@ bool InferSubgraphShape(const nnvm::Symbol &subgraph, LoopState::LoopState(const Symbol &g) { this->subgraph_sym = g; this->subgraph.outputs = g.outputs; - - std::vector<std::pair<std::string, std::string> > kwargs; - kwargs.push_back(std::pair<std::string, std::string>("inline_limit", "0")); - // We turn on static_alloc for two reasons. - // It avoids the overhead of unnecessary memory allocation. - // only static_alloc supports nested call of CachedOp. - kwargs.push_back(std::pair<std::string, std::string>("static_alloc", "1")); - iter_op = std::make_shared<CachedOp>(subgraph_sym, kwargs); + this->iter_op = LoopState::MakeSharedOp(g); } void LoopState::Forward(int iter_no, diff --git a/src/operator/subgraph_op_common.h b/src/operator/subgraph_op_common.h index 79078409e21..f73f09cd5c8 100644 --- a/src/operator/subgraph_op_common.h +++ b/src/operator/subgraph_op_common.h @@ -24,6 +24,8 @@ #include <mxnet/base.h> #include <mxnet/op_attr_types.h> #include <vector> +#include <utility> +#include <string> #include "../imperative/cached_op.h" #include "../imperative/imperative_utils.h" @@ -69,8 +71,8 @@ class LoopState { // For training, each iteration has a cached op because each iteration // needs to maintain a set of memory buffers for all computation states, // which will be used in the backward. 
- CachedOpPtr iter_op; std::vector<OpStatePtr> all_states; + CachedOpPtr iter_op; Symbol subgraph_sym; nnvm::Graph subgraph; @@ -91,6 +93,16 @@ class LoopState { all_inputs.clear(); all_states.clear(); } + static CachedOpPtr MakeSharedOp(const Symbol &sym) { + // We turn on static_alloc for two reasons. + // It avoids the overhead of unnecessary memory allocation. + // only static_alloc supports nested call of CachedOp. + std::vector<std::pair<std::string, std::string> > kwargs = { + {"inline_limit", "0"}, + {"static_alloc", "1"} + }; + return std::make_shared<CachedOp>(sym, kwargs); + } }; } // namespace op diff --git a/tests/python/unittest/test_contrib_control_flow.py b/tests/python/unittest/test_contrib_control_flow.py new file mode 100644 index 00000000000..9dd5c4397be --- /dev/null +++ b/tests/python/unittest/test_contrib_control_flow.py @@ -0,0 +1,978 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +import mxnet as mx +from mxnet import gluon +import numpy as np +import copy +from numpy.testing import assert_allclose +import unittest +from mxnet.test_utils import almost_equal, default_context +from numpy.testing import assert_allclose as assert_almost_equal # This is more restrictive +from mxnet.base import _as_list + + +def test_while_loop_simple_forward(): + + class _TestBlock(gluon.HybridBlock): + + def __init__(self, cond, func, max_iterations): + super(_TestBlock, self).__init__() + self.cond = cond + self.func = func + self.max_iterations = max_iterations + + def hybrid_forward(self, F, *loop_vars): + return F.contrib.while_loop( + cond=self.cond, + func=self.func, + loop_vars=loop_vars, + max_iterations=self.max_iterations + ) + + for hybridize in [False, True]: + # Case 1.1: result should be sum([1, 2, 3 ... 100]) + model = _TestBlock( + cond=lambda i, s: i <= 5, + func=lambda i, s: (None, (i + 1, s + i)), + max_iterations=10, + ) + if hybridize: + model.hybridize() + _, result = model( + mx.nd.array([1], dtype="int64"), # i + mx.nd.array([0], dtype="int64"), # s + ) + assert result[0].asscalar() == 6 + assert result[1].asscalar() == 15 + # Case 1.2: result should be sum([1, 2, 3 ... 
1000]) + model = _TestBlock( + cond=lambda i, s, true: true, + func=lambda i, s, true: (None, (i + 1, s + i, true)), + max_iterations=1000, + ) + if hybridize: + model.hybridize() + _, result = model( + mx.nd.array([1], dtype="int64"), # i + mx.nd.array([0], dtype="int64"), # s + mx.nd.array([1], dtype="int64"), # true + ) + assert result[0].asscalar() == 1001 + assert result[1].asscalar() == 500500 + assert result[2].asscalar() == 1 + # Case 1.3: result should be sum([]) + model = _TestBlock( + cond=lambda i, s, false: false, + func=lambda i, s, false: (None, (i + 1, s + i, false)), + max_iterations=1000, + ) + if hybridize: + model.hybridize() + _, result = model( + mx.nd.array([1], dtype="int64"), # i + mx.nd.array([0], dtype="int64"), # s + mx.nd.array([0], dtype="int64"), # false + ) + assert result[0].asscalar() == 1 + assert result[1].asscalar() == 0 + assert result[2].asscalar() == 0 + # Case 2.1: result should be sum([1, 2, 3 ... 100]) + model = _TestBlock( + cond=lambda i, s: i <= 100, + func=lambda i, s: (i, (i + 1, s + i)), + max_iterations=1000, + ) + if hybridize: + model.hybridize() + (outputs, ), (result_i, result_s) = model( + mx.nd.array([1], dtype="int64"), # i + mx.nd.array([0], dtype="int64"), # s + ) + assert all(outputs.asnumpy()[ : 100] == np.arange(1, 101).reshape(100, 1)) + assert result_i.asscalar() == 101 + assert result_s.asscalar() == 5050 + # Case 2.2: result should be sum([1, 2, 3 ... 
1000]) + model = _TestBlock( + cond=lambda i, s, true: true, + func=lambda i, s, true: (i, (i + 1, s + i, true)), + max_iterations=1000, + ) + if hybridize: + model.hybridize() + (outputs, ), (result_i, result_s, _) = model( + mx.nd.array([1], dtype="int64"), # i + mx.nd.array([0], dtype="int64"), # s + mx.nd.array([1], dtype="int64"), # true + ) + assert all(outputs.asnumpy() == np.arange(1, 1001).reshape(1000, 1)) + assert result_i.asscalar() == 1001 + assert result_s.asscalar() == 500500 + # Case 2.3: a corner case, in which loop body is never executed + model = _TestBlock( + cond=lambda i, s, false: false, + func=lambda i, s, false: (i, (i + 1, s + i, false)), + max_iterations=1000, + ) + if hybridize: + model.hybridize() + _, (result_i, result_s, _) = model( + mx.nd.array([1], dtype="int64"), # i + mx.nd.array([0], dtype="int64"), # s + mx.nd.array([0], dtype="int64"), # false + ) + assert result_i.asscalar() == 1 + assert result_s.asscalar() == 0 + + +def _verify_while_loop(cond, func, loop_var_shapes, free_var_shapes, is_train, max_iterations, is_for, n_steps): + + def _create_vars(num, prefix): + return [mx.sym.var(prefix + str(i)) for i in range(num)] + + def _create_arrays(shapes): + return [mx.nd.random.uniform(-1.0, 1.0, shape=x) for x in shapes] + + def _create_dict(prefix, shapes): + return {prefix + str(i): mx.nd.random.uniform(-1.0, 1.0, shape=x) for i, x in enumerate(shapes)} + + def _merge_dict(*dicts): + result = {} + for item in dicts: + result.update(item) + return result + + def _to_numpy_list(arrays): + return [x.asnumpy() if x is not None else x for x in arrays] + + def _get_imperative_result(n_steps): + free_vars = [args["FreeVar" + str(i)].copy() for i, _ in enumerate(free_var_shapes)] + loop_vars = [args["LoopVar" + str(i)].copy() for i, _ in enumerate(loop_var_shapes)] + loop_var_start = int(is_for) + if is_train: + for var in free_vars + loop_vars[loop_var_start: ]: + var.attach_grad() + with mx.autograd.record(train_mode=is_train): + 
outputs, final_loop_vars = mx.nd.contrib.while_loop( + cond=lambda *_loop_vars: cond(_loop_vars, free_vars), + func=lambda *_loop_vars: func(_loop_vars, free_vars), + loop_vars=loop_vars, + max_iterations=max_iterations, + ) + outputs = [x[: n_steps] for x in outputs] + out_grads = _create_arrays(x.shape for x in outputs) \ + + _create_arrays(x.shape for x in final_loop_vars) + loop_result_nd = [x * 2 for x in outputs] + [x * 3 for x in final_loop_vars] + grads = [] + if is_train: + cat_out = mx.nd.concat(*[x.reshape(-1) for x in loop_result_nd], dim=0) + cat_out.backward(out_grad=mx.nd.concat(*[x.reshape(-1) for x in out_grads], dim=0)) + grads = [free_vars[i].grad for i, _ in enumerate(free_var_shapes)] \ + + [loop_vars[i].grad for i, _ in enumerate(loop_var_shapes) if i >= loop_var_start] + return _to_numpy_list(loop_result_nd), _to_numpy_list(grads), out_grads + + def _get_symbolic_result(out_grads, n_steps): + + def _copy_args_dict(name_list): + return {name: args[name].copy() for name in name_list} + + def _zeros_like_dict(name_list): + return {name: mx.nd.zeros_like(args[name]) for name in name_list} + + free_syms = _create_vars(len(free_var_shapes), "FreeVar") + loop_syms = _create_vars(len(loop_var_shapes), "LoopVar") + outputs, final_loop_syms = mx.sym.contrib.while_loop( + cond=lambda *_loop_vars: cond(_loop_vars, free_syms), + func=lambda *_loop_vars: func(_loop_vars, free_syms), + loop_vars=loop_syms, + max_iterations=max_iterations, + ) + if n_steps == 0: + outputs = [] + else: + outputs = [x.slice_axis(axis=0, begin=0, end=n_steps) for x in outputs] + loop_result_sym = [x * 2 for x in outputs] + [x * 3 for x in final_loop_syms] + loop_result_sym = mx.sym.Group(loop_result_sym) + + loop_var_start = int(is_for) + args_names = ["FreeVar" + str(i) for i, _ in enumerate(free_var_shapes)] \ + + ["LoopVar" + str(i) for i, _ in enumerate(loop_var_shapes) if i >= loop_var_start] + args_grad = None if not is_train else _zeros_like_dict(x for x in args_names) + 
executor = loop_result_sym.bind( + ctx=default_context(), + args=_copy_args_dict(loop_result_sym.list_inputs()), + args_grad=args_grad, + ) + loop_result_nd = executor.forward(is_train=is_train) + grads = [] + if is_train: + executor.backward(out_grads=out_grads) + grads = [executor.grad_dict.get("FreeVar" + str(i), None) for i, _ in enumerate(free_var_shapes)] \ + + [executor.grad_dict.get("LoopVar" + str(i), None) for i, _ in enumerate(loop_var_shapes) if i >= loop_var_start] + return _to_numpy_list(loop_result_nd), _to_numpy_list(grads) + + args = _merge_dict( + _create_dict("FreeVar", free_var_shapes), + _create_dict("LoopVar", loop_var_shapes), + ) + if is_for: + assert loop_var_shapes[0] == (1, ) + args["LoopVar0"] = mx.nd.array([0]) + imp_outs, imp_grads, out_grads = _get_imperative_result(n_steps) + sym_outs, sym_grads = _get_symbolic_result(out_grads, n_steps) + for imp_out, sym_out in zip(imp_outs, sym_outs): + if imp_out is None or sym_out is None: + continue + assert_almost_equal(imp_out, sym_out, rtol=1e-4, atol=1e-4) + for imp_grad, sym_grad in zip(imp_grads, sym_grads): + if imp_grad is None or sym_grad is None: + continue + assert_almost_equal(imp_grad, sym_grad, rtol=1e-4, atol=1e-4) + + +def test_while_loop_for_foreach(): + + def make_true_cond(): + return lambda loop_vars, _: (loop_vars[0] < 1e200).prod() + + def make_false_cond(): + return lambda loop_vars, _: (loop_vars[0] > 1e200).prod() + + def make_for_cond(length): + return lambda loop_vars, _: loop_vars[0] < length + + def case_0(): + # This is a simple testcase that all loop steps are independent' + # It basically scans the array and outputs itself + # There is 1 output + # There is 1 state: i + def _simple_func(loop, free): + (i, ), (scanned, ) = loop, free + in_ = scanned.take(i).squeeze(axis=0) + return (in_, i + 1) + _verify_while_loop( + cond=make_true_cond(), + func=_simple_func, + max_iterations=1, + is_train=True, + is_for=True, + loop_var_shapes=[ + (1, ), # i + ], + 
free_var_shapes=[ + (1, 3), # scanned + ], + n_steps=1, + ) + + def case_1(**params): + # This is a simple testcase that simulates a cumulative sum + # There is 1 output + # There is 1 state: s + step_funcs = [ + lambda a, b, s: s, + lambda a, b, s: a * 1.5 + b * 2.5 - s * 3.5, + lambda a, b, s: a * 1.5 - s * 3.5 + b * 2.5, + lambda a, b, s: b * 2.5 + a * 1.5 - s * 3.5, + lambda a, b, s: b * 2.5 - s * 3.5 + a * 1.5, + lambda a, b, s: s * -3.5 + a * 1.5 + b * 2.5, + lambda a, b, s: s * -3.5 + b * 2.5 + a * 1.5, + lambda a, b, s: a * 2.5 * b + s * 0.3, + lambda a, b, s: b * 2.5 * a + s * 0.3, + lambda a, b, s: 2.5 * a * b + s * 0.3, + lambda a, b, s: b * a * 2.5 + s * 0.3, + lambda a, b, s: 2.5 * b * a + s * 0.3, + lambda a, b, s: b * a * 2.5 + s * 0.3, + lambda a, b, s: s * 0.3 + a * 2.5 * b, + lambda a, b, s: s * 0.3 + b * 2.5 * a, + lambda a, b, s: s * 0.3 + 2.5 * a * b, + lambda a, b, s: s * 0.3 + b * a * 2.5, + lambda a, b, s: s * 0.3 + 2.5 * b * a, + lambda a, b, s: s * 0.3 + b * a * 2.5, + ] + def make_func(step_func): + def step(loop, free): + (s, ), (a, b) = loop, free + out = step_func(a, b, s) + return (out, out) + return step + case_id = 0 + for is_train in [True, False]: + for step_func in step_funcs: + case_id += 1 + _verify_while_loop( + func=make_func(step_func), + is_train=is_train, + is_for=False, + **params + ) + + def case_2(**params): + # This is a testcase that involves non-differentiable operators + # There is 1 output + # There is 2 states: i, s + step_funcs = [ + lambda in_, s, f_1: (in_ * 2) * s * f_1, + lambda in_, s, f_1: (in_ * 2) * f_1 * s, + lambda in_, s, f_1: s * (in_ * 2) * f_1, + lambda in_, s, f_1: s * f_1 * (in_ * 2), + lambda in_, s, f_1: f_1 * (in_ * 2) * s, + lambda in_, s, f_1: f_1 * s * (in_ * 2), + lambda in_, s, f_1: (2 * in_) * s * f_1, + lambda in_, s, f_1: (2 * in_) * f_1 * s, + lambda in_, s, f_1: s * (2 * in_) * f_1, + lambda in_, s, f_1: s * f_1 * (2 * in_), + lambda in_, s, f_1: f_1 * (2 * in_) * s, + lambda in_, s, 
f_1: f_1 * s * (2 * in_), + ] + def make_func(step_func): + """This simulates: + def compute(s, inputs, f_1, length): + outputs = [] + for i in range(length): + s += inputs[i] * 2 + f_1 + outputs.append(s) + return outputs, s + """ + def step(loop, free): + (i, s), (scanned, f_1, _) = loop, free + in_ = scanned.take(i).squeeze(axis=0) + out = step_func(in_, s, f_1) + return (out, (i + 1, out)) + return step + case_id = 0 + for is_train in [True, False]: + for step_func in step_funcs: + case_id += 1 + _verify_while_loop( + func=make_func(step_func), + max_iterations=1000, + is_train=is_train, + is_for=True, + **params + ) + + def case_3(length, **params): + # This is a testcase for multiple non-differentiable operators and different ways of slicing + # There are 2 outputs + # There are 3 states: i, s_0, s_1 + step_funcs = [ + lambda i_0, i_1, s_0, s_1, f_0: i_0 * (i_1 * 2) * s_0 * (s_1 * 2) * f_0, + lambda i_0, i_1, s_0, s_1, f_0: i_0 * (i_1 * 2) * s_0 * f_0 * (s_1 * 2), + lambda i_0, i_1, s_0, s_1, f_0: i_0 * (i_1 * 2) * (s_1 * 2) * s_0 * f_0, + lambda i_0, i_1, s_0, s_1, f_0: i_0 * (i_1 * 2) * (s_1 * 2) * f_0 * s_0, + lambda i_0, i_1, s_0, s_1, f_0: (i_1 * 2) * i_0 * s_0 * (s_1 * 2) * f_0, + lambda i_0, i_1, s_0, s_1, f_0: (i_1 * 2) * i_0 * s_0 * f_0 * (s_1 * 2), + lambda i_0, i_1, s_0, s_1, f_0: (i_1 * 2) * i_0 * (s_1 * 2) * s_0 * f_0, + lambda i_0, i_1, s_0, s_1, f_0: (i_1 * 2) * i_0 * (s_1 * 2) * f_0 * s_0, + lambda i_0, i_1, s_0, s_1, f_0: i_0, + lambda i_0, i_1, s_0, s_1, f_0: i_1, + lambda i_0, i_1, s_0, s_1, f_0: s_0, + lambda i_0, i_1, s_0, s_1, f_0: s_1, + lambda i_0, i_1, s_0, s_1, f_0: f_0, + ] + def make_func(step_func): + """This simulates: + def compute(input_0, input_1, s_0, s_1, f_0, length): + output_0 = [] + output_1 = [] + for i in range(length): + i_0 = input_0[i] + i_1 = input_1[length - 1 - i] + out = i_0 + (i_1 * 2) + s_0 + (s_1 * 2) + f_0 + s_0 = (s_0 + out) * 1.05 + s_1 = (s_1 - out * 0.5) * 0.95 + output_0.append(out) + 
output_1.append(out * 1.5) + return outputs, s_0, s_1 + """ + def step(loop, free): + (i, s_0, s_1), (sc_0, sc_1, f_0, _) = loop, free + i_0 = sc_0.take(i).squeeze(axis=0) + i_1 = sc_1.take(length - 1 - i).squeeze(axis=0) + out = step_func(i_0, i_1, s_0, s_1, f_0) + return ([out, out * 1.5], [i + 1, (s_0 + out) * 1.05, (s_1 - out * 0.5) * 0.95]) + return step + case_id = 0 + for is_train in [True, False]: + for step_func in step_funcs: + case_id += 1 + _verify_while_loop( + func=make_func(step_func), + max_iterations=1000, + is_train=is_train, + is_for=True, + **params + ) + + def case_4(length, single_shape, **params): + # It is for the case that inputs & outputs are the same + # There are 3 outputs + # There are 4 states: i, s_0, s_1, s_2 + # i is used in both non-differentiable (take) and differentiable (+) occasions + step_funcs = [ + lambda i_0, i_1, s_0, s_1, f_0: i_0 * (i_1 * 2) * s_0 * (s_1 * 2) * f_0, + lambda i_0, i_1, s_0, s_1, f_0: i_0 * (i_1 * 2) * s_0 * f_0 * (s_1 * 2), + lambda i_0, i_1, s_0, s_1, f_0: i_0 * (i_1 * 2) * (s_1 * 2) * s_0 * f_0, + lambda i_0, i_1, s_0, s_1, f_0: i_0 * (i_1 * 2) * (s_1 * 2) * f_0 * s_0, + lambda i_0, i_1, s_0, s_1, f_0: (i_1 * 2) * i_0 * s_0 * (s_1 * 2) * f_0, + lambda i_0, i_1, s_0, s_1, f_0: (i_1 * 2) * i_0 * s_0 * f_0 * (s_1 * 2), + lambda i_0, i_1, s_0, s_1, f_0: (i_1 * 2) * i_0 * (s_1 * 2) * s_0 * f_0, + lambda i_0, i_1, s_0, s_1, f_0: (i_1 * 2) * i_0 * (s_1 * 2) * f_0 * s_0, + lambda i_0, i_1, s_0, s_1, f_0: i_0, + lambda i_0, i_1, s_0, s_1, f_0: i_1, + lambda i_0, i_1, s_0, s_1, f_0: s_0, + lambda i_0, i_1, s_0, s_1, f_0: s_1, + lambda i_0, i_1, s_0, s_1, f_0: f_0, + ] + def make_func(step_func): + """This simulates: + def compute(input_0, input_1, s_0, s_1, s_2, f_0, length): + # here s_2 remains untouched + output_0 = [] + output_1 = [] + output_2 = [] + for i in range(length): + i_0 = input_0[i] + i_1 = input_1[length - 1 - i] + out = i_0 + (i_1 * 2) + s_0 + (s_1 * 2) + f_0 + out = out * i * i_0 * i_1 + s_0 = 
(s_0 + out) * 1.05 + s_1 = (s_1 - out * 0.5) * 0.95 + output_0.append(out) + output_1.append(f_0) + output_2.append(out * 1.5) + return output_0, output_1, output_2, s_0, s_1, s_2 + """ + def step(loop, free): + (i, s_0, s_1, s_2), (sc_0, sc_1, f_0, _) = loop, free + i_0 = sc_0.take(i).squeeze(axis=0) + i_1 = sc_1.take(length - 1 - i).squeeze(axis=0) + out = step_func(i_0, i_1, s_0, s_1, f_0) + out = out * i.reshape([1] * len(single_shape)).broadcast_to(single_shape) + out = out * i_0 * i_1 + return ([out, f_0, out * 1.5], [i + 1, (s_0 + out) * 1.05, (s_1 - out * 0.5) * 0.95, s_2]) + return step + case_id = 0 + for is_train in [True, False]: + for step_func in step_funcs: + case_id += 1 + _verify_while_loop( + func=make_func(step_func), + max_iterations=1000, + is_train=is_train, + is_for=True, + **params + ) + + def case_5(length, single_shape, **params): + # It is for the case that inputs & outputs are the same + # There are 0 outputs + # There are 4 states: i, s_0, s_1, s_2 + # i is used in both differentiable (take) and non-differentiable (+) occasions + step_funcs = [ + lambda i_0, i_1, s_0, s_1, f_0: i_0 * (i_1 * 2) * s_0 * (s_1 * 2) * f_0, + lambda i_0, i_1, s_0, s_1, f_0: i_0 * (i_1 * 2) * s_0 * f_0 * (s_1 * 2), + lambda i_0, i_1, s_0, s_1, f_0: i_0 * (i_1 * 2) * (s_1 * 2) * s_0 * f_0, + lambda i_0, i_1, s_0, s_1, f_0: i_0 * (i_1 * 2) * (s_1 * 2) * f_0 * s_0, + lambda i_0, i_1, s_0, s_1, f_0: (i_1 * 2) * i_0 * s_0 * (s_1 * 2) * f_0, + lambda i_0, i_1, s_0, s_1, f_0: (i_1 * 2) * i_0 * s_0 * f_0 * (s_1 * 2), + lambda i_0, i_1, s_0, s_1, f_0: (i_1 * 2) * i_0 * (s_1 * 2) * s_0 * f_0, + lambda i_0, i_1, s_0, s_1, f_0: (i_1 * 2) * i_0 * (s_1 * 2) * f_0 * s_0, + lambda i_0, i_1, s_0, s_1, f_0: i_0, + lambda i_0, i_1, s_0, s_1, f_0: i_1, + lambda i_0, i_1, s_0, s_1, f_0: s_0, + lambda i_0, i_1, s_0, s_1, f_0: s_1, + lambda i_0, i_1, s_0, s_1, f_0: f_0, + ] + def make_func(step_func): + """This simulates: + def compute(input_0, input_1, s_0, s_1, s_2, f_0, length): 
+ # here s_2 remains untouched + output_0 = [] + output_1 = [] + output_2 = [] + for i in range(length): + i_0 = input_0[i] + i_1 = input_1[length - 1 - i] + out = i_0 + (i_1 * 2) + s_0 + (s_1 * 2) + f_0 + out = out * i * i_0 * i_1 + s_0 = (s_0 + out) * 1.05 + s_1 = (s_1 - out * 0.5) * 0.95 + output_0.append(out) + output_1.append(f_0) + output_2.append(out * 1.5) + return output_0, output_1, output_2, s_0, s_1, s_2 + """ + def step(loop, free): + (i, s_0, s_1, s_2), (sc_0, sc_1, f_0, _) = loop, free + i_0 = sc_0.take(i).squeeze(axis=0) + i_1 = sc_1.take(length - 1 - i).squeeze(axis=0) + out = step_func(i_0, i_1, s_0, s_1, f_0) + out = out * i.reshape([1] * len(single_shape)).broadcast_to(single_shape) + out = out * i_0 * i_1 + return ([], [i + 1, (s_0 + out) * 1.05, (s_1 - out * 0.5) * 0.95, s_2]) + return step + case_id = 0 + for is_train in [True, False]: + for step_func in step_funcs: + case_id += 1 + _verify_while_loop( + func=make_func(step_func), + max_iterations=1000, + is_train=is_train, + is_for=True, + **params + ) + + def case_6(length, single_shape, **params): + # It is for the case that inputs & outputs are the same + # There are 3 outputs + # There are 4 states: i, s_0, s_1, s_2 + # i is used in both differentiable (take) and non-differentiable (+) occasions + step_funcs = [ + lambda i_0, i_1, s_0, s_1, f_0: i_0 * (i_1 * 2) * s_0 * (s_1 * 2) * f_0, + lambda i_0, i_1, s_0, s_1, f_0: i_0 * (i_1 * 2) * s_0 * f_0 * (s_1 * 2), + lambda i_0, i_1, s_0, s_1, f_0: i_0 * (i_1 * 2) * (s_1 * 2) * s_0 * f_0, + lambda i_0, i_1, s_0, s_1, f_0: i_0 * (i_1 * 2) * (s_1 * 2) * f_0 * s_0, + lambda i_0, i_1, s_0, s_1, f_0: (i_1 * 2) * i_0 * s_0 * (s_1 * 2) * f_0, + lambda i_0, i_1, s_0, s_1, f_0: (i_1 * 2) * i_0 * s_0 * f_0 * (s_1 * 2), + lambda i_0, i_1, s_0, s_1, f_0: (i_1 * 2) * i_0 * (s_1 * 2) * s_0 * f_0, + lambda i_0, i_1, s_0, s_1, f_0: (i_1 * 2) * i_0 * (s_1 * 2) * f_0 * s_0, + lambda i_0, i_1, s_0, s_1, f_0: i_0, + lambda i_0, i_1, s_0, s_1, f_0: i_1, + lambda 
i_0, i_1, s_0, s_1, f_0: s_0, + lambda i_0, i_1, s_0, s_1, f_0: s_1, + lambda i_0, i_1, s_0, s_1, f_0: f_0, + ] + def make_func(step_func): + """This simulates: + def compute(input_0, input_1, s_0, s_1, s_2, f_0, length): + # here s_2 remains untouched + output_0 = [] + output_1 = [] + output_2 = [] + for i in range(length): + i_0 = input_0[i] + i_1 = input_1[length - 1 - i] + out = i_0 + (i_1 * 2) + s_0 + (s_1 * 2) + f_0 + out = out * i * i_0 * i_1 + s_0 = (s_0 + out) * 1.05 + s_1 = (s_1 - out * 0.5) * 0.95 + output_0.append(out) + output_1.append(f_0) + output_2.append(out * 1.5) + return output_0, output_1, output_2, s_0, s_1, s_2 + """ + def step(loop, free): + (i, s_0, s_1, s_2), (sc_0, sc_1, f_0, _) = loop, free + F = mx.sym if isinstance(i, mx.sym.Symbol) else mx.nd + i_0 = sc_0.take(i).squeeze(axis=0) + i_1 = sc_1.take(length - 1 - i).squeeze(axis=0) + out_0 = step_func(i_0, i_1, s_0, s_1, f_0) + out_0 = out_0 * i.reshape([1] * len(single_shape)).broadcast_to(single_shape) + out_1 = step_func(i_1, s_0, f_0, s_1, i_0) + out_1 = out_1 * i.reshape([1] * len(single_shape)).broadcast_to(single_shape) + return ([F.dot(out_0, s_2), f_0, F.dot(s_2, out_1) * 1.5], [i + 1, (s_0 + out_1) * 1.05, (s_1 - out_0 * 0.5) * 0.95, s_2]) + return step + case_id = 0 + for is_train in [True, False]: + for step_func in step_funcs: + case_id += 1 + _verify_while_loop( + func=make_func(step_func), + max_iterations=1000, + is_train=is_train, + is_for=True, + **params + ) + + # Case 0: the simpest case + case_0() + # Case 1.1.* + case_1( + cond=make_true_cond(), + loop_var_shapes=[ + (1, ), # s + ], + free_var_shapes=[ + (1, ), # a + (1, ), # b + ], + max_iterations=23, + n_steps=23, + ) + # Case 1.2.* + case_1( + cond=make_true_cond(), + loop_var_shapes=[ + (2, 3, 4), # s + ], + free_var_shapes=[ + (2, 3, 4), # a + (2, 3, 4), # b + ], + max_iterations=31, + n_steps=31, + ) + # Case 1.3.* + case_1( + cond=make_false_cond(), + loop_var_shapes=[ + (2, 3, 4), # s + ], + 
free_var_shapes=[ + (2, 3, 4), # a + (2, 3, 4), # b + ], + max_iterations=20, + n_steps=0, + ) + # Case 2.1.* + case_2( + cond=make_for_cond(length=31), + loop_var_shapes=[ + (1, ), # i + (2, ), # s + ], + free_var_shapes=[ + (100, 2), # scanned + (2, ), # f_1 + (3, 4, 5, 6), # f_2, unused + ], + n_steps=31, + ) + # Case 2.2.* + case_2( + cond=make_for_cond(length=25), + loop_var_shapes=[ + (1, ), # i + (2, ), # s + ], + free_var_shapes=[ + (30, 2), # scanned + (2, ), # f_1 + (3, 4, 5, 6), # f_2, unused + ], + n_steps=25, + ) + # Case 3.* + case_3( + length=11, + cond=make_for_cond(length=11), + loop_var_shapes=[ + (1, ), # i + (2, ), # s_0 + (2, ), # s_1 + ], + free_var_shapes=[ + (30, 2), # sc_0 + (30, 2), # sc_1 + (2, ), # f_0 + (3, 4, 5, 6), # f_1, unused + ], + n_steps=11, + ) + # Case 4.1.* + case_4( + length=4, + cond=make_for_cond(length=4), + single_shape=[5], + loop_var_shapes=[ + (1, ), # i + (5, ), # s_0 + (5, ), # s_1 + (23, 6, 8), # s_2 + ], + free_var_shapes=[ + (30, 5), # sc_0 + (30, 5), # sc_1 + (5, ), # f_0 + (3, 4, 5, 6), # f_1, unused + ], + n_steps=4, + ) + # Case 4.2.* + case_4( + length=5, + cond=make_for_cond(length=5), + single_shape=[5, 12], + loop_var_shapes=[ + (1, ), # i + (5, 12), # s_0 + (5, 12), # s_1 + (23, 6, 8), # s_2 + ], + free_var_shapes=[ + (30, 5, 12), # sc_0 + (30, 5, 12), # sc_1 + (5, 12), # f_0 + (3, 4, 5, 6), # f_1, unused + ], + n_steps=5, + ) + # Case 5.1.* + case_5( + length=4, + cond=make_for_cond(length=4), + single_shape=[5], + loop_var_shapes=[ + (1, ), # i + (5, ), # s_0 + (5, ), # s_1 + (23, 6, 8), # s_2 + ], + free_var_shapes=[ + (30, 5), # sc_0 + (30, 5), # sc_1 + (5, ), # f_0 + (3, 4, 5, 6), # f_1, unused + ], + n_steps=4, + ) + # Case 5.2.* + case_5( + length=5, + cond=make_for_cond(length=5), + single_shape=[3, 4, 2], + loop_var_shapes=[ + (1, ), # i + (3, 4, 2), # s_0 + (3, 4, 2), # s_1 + (23, 6, 8), # s_2 + ], + free_var_shapes=[ + (30, 3, 4, 2), # sc_0 + (30, 3, 4, 2), # sc_1 + (3, 4, 2), # f_0 + (3, 4, 
5, 6), # f_1, unused + ], + n_steps=5, + ) + # Case 6.* + case_6( + length=5, + cond=make_for_cond(length=5), + single_shape=[5, 3], + loop_var_shapes=[ + (1, ), # i + (5, 3), # s_0 + (5, 3), # s_1 + (3, 5), # s_2 + ], + free_var_shapes=[ + (30, 5, 3), # sc_0 + (30, 5, 3), # sc_1 + (5, 3), # f_0 + (3, 4, 5, 6), # f_1, unused + ], + n_steps=5, + ) + + +def test_while_loop_nested(): + + def _to_np_list(arrays): + return [x.asnumpy() if x is not None else x for x in arrays] + + def _array(shape): + return mx.nd.random.uniform(-1.0, 1.0, shape=shape) + + def inner_cond(i, j, x_sum, sc): + return j < 2 + + def inner_body(i, j, x_sum, sc): + x_ij = sc.take(j).squeeze(axis=0) + return (x_ij, x_ij), (i, j + 1, x_sum, sc) + + def outer_cond(i, j, x_sum, sc): + return i < 2 + + def outer_body(i, j, x_sum, sc): + F = mx.sym if isinstance(i, mx.sym.Symbol) else mx.nd + (x_ij, x_ji), (i_p, j_p, x_sum_p, sc_p) = F.contrib.while_loop( + cond=inner_cond, + func=inner_body, + loop_vars=(i, j, x_sum, sc), + max_iterations=2, + ) + return (x_ij, x_ji), (i_p + 1, j_p - 2, x_sum_p, sc_p) + + def make_loop(i, j, x_sum, sc): + F = mx.sym if isinstance(i, mx.sym.Symbol) else mx.nd + (x_ij, x_ji), (new_i, new_j, x_sum_p, sc_p) = F.contrib.while_loop( + cond=outer_cond, + func=outer_body, + loop_vars=(i, j, x_sum, sc), + max_iterations=2, + ) + return new_i, new_j, x_sum_p, sc_p, x_ij, x_ji + + args = { + "i": mx.nd.array([0]), + "j": mx.nd.array([0]), + "x_sum": _array([5, 3]), + "sc": _array([10, 10, 5, 3]), + } + args_grad = { + "x_sum": _array([5, 3]), + "sc": _array([10, 10, 5, 3]), + } + out_grad = [ + _array([1]), + _array([1]), + _array([5, 3]), + _array([10, 10, 5, 3]), + _array([2, 2, 10, 5, 3]), + _array([2, 2, 10, 5, 3]), + ] + def _get_imp_result(is_train, args, args_grad, out_grad): + args = {k: v.copy() for k, v in args.items()} + args_grad = {k: v.copy() for k, v in args_grad.items()} + i, j, x_sum, sc = [args[x].copy() for x in ["i", "j", "x_sum", "sc"]] + if is_train: + 
x_sum.attach_grad() + sc.attach_grad() + with mx.autograd.record(train_mode=is_train): + results = make_loop(i, j, x_sum, sc) + cat_res = mx.nd.concat(*[x.reshape(-1) for x in results], dim=0) + if not is_train: + return _to_np_list(results), [] + cat_grad = mx.nd.concat(*[x.reshape(-1) for x in out_grad], dim=0) + assert cat_grad.shape == cat_res.shape + cat_res.backward(out_grad=cat_grad) + grads = [x_sum.grad, sc.grad] + return _to_np_list(results), _to_np_list(grads) + + def _get_sym_result(is_train, args, args_grad, out_grad): + args = {k: v.copy() for k, v in args.items()} + args_grad = {k: v.copy() for k, v in args_grad.items()} + i, j, x_sum, sc = [ + mx.sym.var("i"), + mx.sym.var("j"), + mx.sym.var("x_sum"), + mx.sym.var("sc"), + ] + result_sym = mx.sym.Group(make_loop(i, j, x_sum, sc)) + executor = result_sym.bind( + ctx=default_context(), + args=args, + args_grad=args_grad, + ) + results = executor.forward(is_train=is_train) + if not is_train: + return _to_np_list(results), [] + executor.backward(out_grads=out_grad) + grads = [executor.grad_dict["x_sum"], executor.grad_dict["sc"]] + return _to_np_list(results), _to_np_list(grads) + + for is_train in [True, False]: + imp_out, imp_grad = _get_imp_result(is_train=is_train, args=args, args_grad=args_grad, out_grad=out_grad) + sym_out, sym_grad = _get_sym_result(is_train=is_train, args=args, args_grad=args_grad, out_grad=out_grad) + assert len(imp_out) == len(sym_out) + assert len(imp_grad) == len(sym_grad) + for x, y in zip(imp_out, sym_out): + assert_almost_equal(x, y, rtol=1e-4, atol=1e-4) + for x, y in zip(imp_grad, sym_grad): + assert_almost_equal(x, y, rtol=1e-4, atol=1e-4) + + +def test_while_loop_rnn(): + def _array(shape): + return mx.nd.random.uniform(-1.0, 1.0, shape=shape) + + cell_types = [mx.rnn.LSTMCell] + num_params = [2] + + batch_size = 2 + hidden_dim = 3 + input_dim = 4 + seq_len = 3 + + for cell, n_param in zip(cell_types, num_params): + # using while_loop + params = mx.rnn.RNNParams() + 
data = mx.sym.var("data") + iter_i = mx.sym.var("i") + def _cond(*states): + i = states[0] + return i < seq_len + def _func(*states): + i = states[0] + states = states[1:] + in_ = data.take(i).squeeze(axis=0) + rnn = cell(hidden_dim, prefix='', params=params) + next_hidden, next_states = rnn(in_, states) + return [next_hidden], [i + 1] + list(next_states) + states = [mx.sym.var("s_" + str(i)) for i in range(n_param)] + result = mx.sym.contrib.while_loop( + cond=_cond, + func=_func, + loop_vars=[iter_i] + states, + max_iterations=seq_len + ) + result = mx.sym.Group(result[0] + result[1][1: ]) + arg_shapes, _, _ = result.infer_shape( + data=(seq_len, batch_size, input_dim), + s_0=(batch_size, hidden_dim), + ) + rnn_inputs = result.list_inputs() + args = {name: _array(arg_shapes[i]) for i, name in enumerate(rnn_inputs) if name != "i"} + args["i"] = mx.nd.zeros([1]) + args_grad = {name: _array(arg_shapes[i]) for i, name in enumerate(rnn_inputs)} + e_1 = result.bind(ctx=default_context(), + args={name: array.copy() for name, array in args.items()}, + args_grad={name: array.copy() for name, array in args_grad.items() if name != "i"}, + ) + # using unrolled rnn + rnn = cell(hidden_dim, prefix='') + unroll_outs = [] + for inputs in mx.sym.split(data, num_outputs=seq_len, axis=0, squeeze_axis=True): + h, states = rnn(inputs, states) + unroll_outs.append(mx.sym.expand_dims(h, axis=0)) + unroll_outs = _as_list(mx.sym.concat(*unroll_outs, dim=0)) + unroll_outs.extend(states) + result = mx.sym.Group(unroll_outs) + e_2 = result.bind(ctx=default_context(), + args={name: array.copy() for name, array in args.items() if name != "i"}, + args_grad={name: array.copy() for name, array in args_grad.items() if name != "i"}, + ) + for case_id in range(100): + out_grads = [_array(arr.shape) for arr in e_1.outputs] + args = {name: array.copy() for name, array in args.items()} + e_1.forward(is_train=True, **args) + e_1.backward(out_grads) + args = {name: array.copy() for name, array in 
args.items() if name != "i"} + e_2.forward(is_train=True, **args) + e_2.backward(out_grads) + assert len(e_1.outputs) == len(e_2.outputs) + for x, y in zip(e_1.outputs, e_2.outputs): + x = x.asnumpy() + y = y.asnumpy() + assert_almost_equal(x, y, rtol=1e-4, atol=1e-4) + grad_keys = list(e_2.grad_dict.keys()) + e_1_grad = [e_1.grad_dict[x] for x in grad_keys] + e_2_grad = [e_2.grad_dict[x] for x in grad_keys] + for x, y in zip(e_1_grad, e_2_grad): + x = x.asnumpy() + y = y.asnumpy() + assert_almost_equal(x, y, rtol=1e-4, atol=1e-4) + + +if __name__ == '__main__': + import nose + nose.runmodule() ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: [email protected] With regards, Apache Git Services
