This is an automated email from the ASF dual-hosted git repository.

zhasheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 5474b08  make gluon rnn layers hybrid blocks (#11482)
5474b08 is described below

commit 5474b086757a8df94984fb95622ead0047ac78b4
Author: Sheng Zha <s...@users.noreply.github.com>
AuthorDate: Sat Aug 4 12:17:50 2018 -0700

    make gluon rnn layers hybrid blocks (#11482)
    
    * make Gluon RNN layer hybrid block
    
    * separate gluon gpu tests
    
    * remove excess assert_raises_cudnn_disabled usage
    
    * add comments and refactor
    
    * add bidirectional test
    
    * temporarily remove hybridize in test_gluon_rnn.test_layer_fill_shape
---
 python/mxnet/gluon/rnn/rnn_layer.py     | 132 ++++++++++-----------
 src/operator/nn/concat.cc               | 127 ++++++++++++++++----
 src/operator/nn/concat.cu               |   4 +
 src/operator/rnn.cc                     |   6 +-
 tests/python/gpu/test_gluon_gpu.py      | 203 ++++++++++++++++++++++++++++++++
 tests/python/gpu/test_operator_gpu.py   | 124 -------------------
 tests/python/unittest/test_gluon_rnn.py |  91 +++++++++++---
 7 files changed, 449 insertions(+), 238 deletions(-)

diff --git a/python/mxnet/gluon/rnn/rnn_layer.py b/python/mxnet/gluon/rnn/rnn_layer.py
index 418c497..4a7a0be 100644
--- a/python/mxnet/gluon/rnn/rnn_layer.py
+++ b/python/mxnet/gluon/rnn/rnn_layer.py
@@ -23,12 +23,11 @@
 from __future__ import print_function
 __all__ = ['RNN', 'LSTM', 'GRU']
 
-from ... import ndarray
-from .. import Block
+from ... import ndarray, symbol
+from .. import HybridBlock, tensor_types
 from . import rnn_cell
 
-
-class _RNNLayer(Block):
+class _RNNLayer(HybridBlock):
     """Implementation of recurrent layers."""
     def __init__(self, hidden_size, num_layers, layout,
                  dropout, bidirectional, input_size,
@@ -52,33 +51,28 @@ class _RNNLayer(Block):
 
         self._gates = {'rnn_relu': 1, 'rnn_tanh': 1, 'lstm': 4, 'gru': 3}[mode]
 
-        self.i2h_weight = []
-        self.h2h_weight = []
-        self.i2h_bias = []
-        self.h2h_bias = []
-
         ng, ni, nh = self._gates, input_size, hidden_size
         for i in range(num_layers):
-            for j in (['l', 'r'] if self._dir == 2 else ['l']):
-                self.i2h_weight.append(
-                    self.params.get('%s%d_i2h_weight'%(j, i), shape=(ng*nh, ni),
-                                    init=i2h_weight_initializer,
-                                    allow_deferred_init=True))
-                self.h2h_weight.append(
-                    self.params.get('%s%d_h2h_weight'%(j, i), shape=(ng*nh, nh),
-                                    init=h2h_weight_initializer,
-                                    allow_deferred_init=True))
-                self.i2h_bias.append(
-                    self.params.get('%s%d_i2h_bias'%(j, i), shape=(ng*nh,),
-                                    init=i2h_bias_initializer,
-                                    allow_deferred_init=True))
-                self.h2h_bias.append(
-                    self.params.get('%s%d_h2h_bias'%(j, i), shape=(ng*nh,),
-                                    init=h2h_bias_initializer,
-                                    allow_deferred_init=True))
+            for j in ['l', 'r'][:self._dir]:
+                self._register_param('{}{}_i2h_weight'.format(j, i),
+                                     shape=(ng*nh, ni),
+                                     init=i2h_weight_initializer)
+                self._register_param('{}{}_h2h_weight'.format(j, i),
+                                     shape=(ng*nh, nh),
+                                     init=h2h_weight_initializer)
+                self._register_param('{}{}_i2h_bias'.format(j, i),
+                                     shape=(ng*nh,),
+                                     init=i2h_bias_initializer)
+                self._register_param('{}{}_h2h_bias'.format(j, i),
+                                     shape=(ng*nh,),
+                                     init=h2h_bias_initializer)
             ni = nh * self._dir
 
-        self._unfused = self._unfuse()
+    def _register_param(self, name, shape, init):
+        p = self.params.get(name, shape=shape, init=init,
+                            allow_deferred_init=True)
+        setattr(self, name, p)
+        return p
 
     def __repr__(self):
         s = '{name}({mapping}, {_layout}'
@@ -89,12 +83,23 @@ class _RNNLayer(Block):
         if self._dir == 2:
             s += ', bidirectional'
         s += ')'
-        shape = self.i2h_weight[0].shape
+        shape = self.l0_i2h_weight.shape
         mapping = '{0} -> {1}'.format(shape[1] if shape[1] else None, shape[0] // self._gates)
         return s.format(name=self.__class__.__name__,
                         mapping=mapping,
                         **self.__dict__)
 
+    def _collect_params_with_prefix(self, prefix=''):
+        if prefix:
+            prefix += '.'
+        def convert_key(key): # for compatibility with old parameter format
+            key = key.split('_')
+            return '_unfused.{}.{}_cell.{}'.format(key[0][1:], key[0][0], '_'.join(key[1:]))
+        ret = {prefix + convert_key(key) : val for key, val in self._reg_params.items()}
+        for name, child in self._children.items():
+            ret.update(child._collect_params_with_prefix(prefix + name))
+        return ret
+
     def state_info(self, batch_size=0):
         raise NotImplementedError
 
@@ -111,7 +116,7 @@ class _RNNLayer(Block):
                     'gru': lambda **kwargs: rnn_cell.GRUCell(self._hidden_size,
                                                              
**kwargs)}[self._mode]
 
-        stack = rnn_cell.SequentialRNNCell(prefix=self.prefix, params=self.params)
+        stack = rnn_cell.HybridSequentialRNNCell(prefix=self.prefix, params=self.params)
         with stack.name_scope():
             ni = self._input_size
             for i in range(self._num_layers):
@@ -169,55 +174,42 @@ class _RNNLayer(Block):
             states.append(func(name='%sh0_%d'%(self.prefix, i), **info))
         return states
 
-    def forward(self, inputs, states=None):
-        batch_size = inputs.shape[self._layout.find('N')]
+    def hybrid_forward(self, F, inputs, states=None, **kwargs):
+        if F is ndarray:
+            batch_size = inputs.shape[self._layout.find('N')]
         skip_states = states is None
         if skip_states:
-            states = self.begin_state(batch_size, ctx=inputs.context)
-        if isinstance(states, ndarray.NDArray):
+            if F is ndarray:
+                states = self.begin_state(batch_size, ctx=inputs.context)
+            else:
+                states = self.begin_state(0, func=symbol.zeros)
+        if isinstance(states, tensor_types):
             states = [states]
-        for state, info in zip(states, self.state_info(batch_size)):
-            if state.shape != info['shape']:
-                raise ValueError(
-                    "Invalid recurrent state shape. Expecting %s, got %s."%(
-                        str(info['shape']), str(state.shape)))
-        if self._input_size == 0:
-            for i in range(self._dir):
-                self.i2h_weight[i].shape = (self._gates*self._hidden_size, inputs.shape[2])
-                self.i2h_weight[i]._finish_deferred_init()
-        out = self._forward_kernel(inputs, states)
+        if F is ndarray:
+            for state, info in zip(states, self.state_info(batch_size)):
+                if state.shape != info['shape']:
+                    raise ValueError(
+                        "Invalid recurrent state shape. Expecting %s, got %s."%(
+                            str(info['shape']), str(state.shape)))
+        out = self._forward_kernel(F, inputs, states, **kwargs)
 
         # out is (output, state)
         return out[0] if skip_states else out
 
-    def _forward(self, inputs, states):
-        """forward using gluon cell"""
-        ns = len(states)
-        axis = self._layout.find('T')
-        states = sum(zip(*((j for j in i) for i in states)), ())
-        outputs, states = self._unfused.unroll(
-            inputs.shape[axis], inputs, states,
-            layout=self._layout, merge_outputs=True)
-        new_states = []
-        for i in range(ns):
-            state = ndarray.concat(*(j.reshape((1,)+j.shape) for j in states[i::ns]), dim=0)
-            new_states.append(state)
-
-        return outputs, new_states
-
-    def _forward_kernel(self, inputs, states):
+    def _forward_kernel(self, F, inputs, states, **kwargs):
         """ forward using CUDNN or CPU kenrel"""
         if self._layout == 'NTC':
-            inputs = ndarray.swapaxes(inputs, dim1=0, dim2=1)
-        ctx = inputs.context
-        params = sum(zip(self.i2h_weight, self.h2h_weight), ())
-        params += sum(zip(self.i2h_bias, self.h2h_bias), ())
-        params = (i.data(ctx).reshape((-1,)) for i in params)
-        params = ndarray.concat(*params, dim=0)
-
-        rnn = ndarray.RNN(inputs, params, *states, state_size=self._hidden_size,
-                          num_layers=self._num_layers, bidirectional=self._dir == 2,
-                          p=self._dropout, state_outputs=True, mode=self._mode)
+            inputs = F.swapaxes(inputs, dim1=0, dim2=1)
+        params = (kwargs['{}{}_{}_{}'.format(d, l, g, t)].reshape(-1)
+                  for t in ['weight', 'bias']
+                  for l in range(self._num_layers)
+                  for d in ['l', 'r'][:self._dir]
+                  for g in ['i2h', 'h2h'])
+        params = F._internal._rnn_param_concat(*params, dim=0)
+
+        rnn = F.RNN(inputs, params, *states, state_size=self._hidden_size,
+                    num_layers=self._num_layers, bidirectional=self._dir == 2,
+                    p=self._dropout, state_outputs=True, mode=self._mode)
 
         if self._mode == 'lstm':
             outputs, states = rnn[0], [rnn[1], rnn[2]]
@@ -225,7 +217,7 @@ class _RNNLayer(Block):
             outputs, states = rnn[0], [rnn[1]]
 
         if self._layout == 'NTC':
-            outputs = ndarray.swapaxes(outputs, dim1=0, dim2=1)
+            outputs = F.swapaxes(outputs, dim1=0, dim2=1)
 
         return outputs, states
 
diff --git a/src/operator/nn/concat.cc b/src/operator/nn/concat.cc
index 266ccb1..7c7f403 100644
--- a/src/operator/nn/concat.cc
+++ b/src/operator/nn/concat.cc
@@ -74,6 +74,65 @@ static bool ConcatShape(const nnvm::NodeAttrs& attrs,
   return dshape.Size() != 0;
 }
 
+// Concat for RNN param deals with the reverse shape inference from output
+// for the special case of concatenating RNN parameters.
+// The first (and sometimes the second) input may be unknown on the target axis.
+// If the two inputs are unknown, they always have the same shape.
+static bool RNNParamConcatShape(const nnvm::NodeAttrs& attrs,
+                                std::vector<TShape> *in_shape,
+                                std::vector<TShape> *out_shape) {
+  using namespace mshadow;
+  const ConcatParam& param_ = nnvm::get<ConcatParam>(attrs.parsed);
+  CHECK_EQ(in_shape->size(), static_cast<size_t>(param_.num_args));
+  TShape dshape;
+  index_t size = 0;
+  int num_zero = 0;
+  int axis = -1;
+  for (int i = 0; i < param_.num_args; ++i) {
+    TShape tmp = (*in_shape)[i];
+    if (tmp.ndim()) {
+      axis = CheckAxis(param_.dim, tmp.ndim());
+      num_zero += tmp[axis] == 0;
+      size += tmp[axis];
+      tmp[axis] = 0;
+      shape_assign(&dshape, tmp);
+    }
+  }
+
+  TShape tmp = (*out_shape)[0];
+  if (tmp.ndim()) {
+    axis = CheckAxis(param_.dim, tmp.ndim());
+    tmp[axis] = 0;
+    shape_assign(&dshape, tmp);
+  }
+
+  if (dshape.ndim() == 0) return false;
+
+  for (int i = 0; i < param_.num_args; ++i) {
+    CHECK(shape_assign(&(*in_shape)[i], dshape))
+        << "Incompatible input shape: expected " << dshape << ", got " << (*in_shape)[i];
+  }
+
+  if (!num_zero) dshape[axis] = size;
+  CHECK(shape_assign(&(*out_shape)[0], dshape))
+      << "Incompatible output shape: expected " << dshape << ", got " << (*out_shape)[0];
+  if ((*out_shape)[0][axis] != 0 && num_zero) {
+    int residual = (*out_shape)[0][axis] - size;
+    CHECK_GE(residual, 0)
+        << "Input size already exceeds output size. Residual: " << residual;
+    CHECK(num_zero <= 2 && num_zero >= 0)
+        << "Expecting 1 or 2 inputs that need shape inference. Got: " << num_zero;
+    bool need_infer = !(*out_shape)[0].Size();
+    for (int i = 0; i < num_zero; i++) {
+      (*in_shape)[i*2][axis] = residual / num_zero;
+      need_infer = need_infer || !(*in_shape)[i].Size();
+    }
+    return !need_infer;
+  }
+
+  return dshape.Size() != 0;
+}
+
 static bool ConcatType(const nnvm::NodeAttrs& attrs,
                        std::vector<int> *in_type,
                        std::vector<int> *out_type) {
@@ -228,6 +287,34 @@ struct ConcatGrad {
 
 DMLC_REGISTER_PARAMETER(ConcatParam);
 
+#define CONCAT_FORWARD_ATTRS \
+.set_num_inputs([](const NodeAttrs& attrs) { \
+  const ConcatParam& params = nnvm::get<ConcatParam>(attrs.parsed); \
+  return params.num_args; \
+}) \
+.set_num_outputs(1) \
+.set_attr_parser(ParamParser<ConcatParam>) \
+.set_attr<nnvm::FListInputNames>("FListInputNames", \
+    [](const NodeAttrs& attrs) { \
+  const ConcatParam& params = nnvm::get<ConcatParam>(attrs.parsed); \
+  std::vector<std::string> ret; \
+  for (int i = 0; i < params.num_args; ++i) { \
+    ret.push_back(std::string("arg") + std::to_string(i)); \
+  } \
+  return ret; \
+}) \
+.set_attr<nnvm::FListOutputNames>("FListOutputNames", \
+    [](const NodeAttrs& attrs) { \
+    return std::vector<std::string>{"output"}; \
+}) \
+.set_attr<nnvm::FInferType>("FInferType", ConcatType) \
+.set_attr<FInferStorageType>("FInferStorageType", ConcatForwardInferStorageType) \
+.set_attr<FCompute>("FCompute<cpu>", ConcatCompute<cpu>) \
+.set_attr<FComputeEx>("FComputeEx<cpu>", ConcatComputeExCPU) \
+.set_attr<nnvm::FGradient>("FGradient", ConcatGrad{"_backward_Concat"}) \
+.set_attr<std::string>("key_var_num_args", "num_args")
+
+
 NNVM_REGISTER_OP(Concat)
 MXNET_ADD_SPARSE_OP_ALIAS(concat)
 .add_alias("concat")
@@ -268,37 +355,13 @@ Example::
                          [ 5.,  5.,  8.,  8.]]
 
 )code" ADD_FILELINE)
-.set_num_inputs([](const NodeAttrs& attrs) {
-  const ConcatParam& params = nnvm::get<ConcatParam>(attrs.parsed);
-  return params.num_args;
-})
-.set_num_outputs(1)
-.set_attr_parser(ParamParser<ConcatParam>)
-.set_attr<nnvm::FListInputNames>("FListInputNames",
-    [](const NodeAttrs& attrs) {
-  const ConcatParam& params = nnvm::get<ConcatParam>(attrs.parsed);
-  std::vector<std::string> ret;
-  for (int i = 0; i < params.num_args; ++i) {
-    ret.push_back(std::string("arg") + std::to_string(i));
-  }
-  return ret;
-})
-.set_attr<nnvm::FListOutputNames>("FListOutputNames",
-    [](const NodeAttrs& attrs) {
-    return std::vector<std::string>{"output"};
-})
 #if MXNET_USE_MKLDNN == 1
 .set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& n) {
   return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
 })
 #endif
+CONCAT_FORWARD_ATTRS
 .set_attr<nnvm::FInferShape>("FInferShape", ConcatShape)
-.set_attr<nnvm::FInferType>("FInferType", ConcatType)
-.set_attr<FInferStorageType>("FInferStorageType", ConcatForwardInferStorageType)
-.set_attr<FCompute>("FCompute<cpu>", ConcatCompute<cpu>)
-.set_attr<FComputeEx>("FComputeEx<cpu>", ConcatComputeExCPU)
-.set_attr<nnvm::FGradient>("FGradient", ConcatGrad{"_backward_Concat"})
-.set_attr<std::string>("key_var_num_args", "num_args")
 .add_argument("data", "NDArray-or-Symbol[]", "List of arrays to concatenate")
 .add_arguments(ConcatParam::__FIELDS__());
 
@@ -320,5 +383,19 @@ NNVM_REGISTER_OP(_backward_Concat)
 #endif
 .set_attr<FCompute>("FCompute<cpu>", ConcatGradCompute<cpu>);
 
+// _rnn_param_concat is a custom concat op with specialized infer_shape,
+// which handles the case where the first one or two inputs may have
+// unknown shape that can be inferred from output shape.
+NNVM_REGISTER_OP(_rnn_param_concat)
+#if MXNET_USE_MKLDNN == 1
+.set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& n) {
+  return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
+})
+#endif
+CONCAT_FORWARD_ATTRS
+.set_attr<nnvm::FInferShape>("FInferShape", RNNParamConcatShape)
+.add_argument("data", "NDArray-or-Symbol[]", "List of arrays to concatenate")
+.add_arguments(ConcatParam::__FIELDS__());
+
 }  // namespace op
 }  // namespace mxnet
diff --git a/src/operator/nn/concat.cu b/src/operator/nn/concat.cu
index 4f6b8fc..2872d52 100644
--- a/src/operator/nn/concat.cu
+++ b/src/operator/nn/concat.cu
@@ -50,6 +50,10 @@ NNVM_REGISTER_OP(Concat)
 .set_attr<FCompute>("FCompute<gpu>", ConcatCompute<gpu>)
 .set_attr<FComputeEx>("FComputeEx<gpu>", ConcatComputeExGPU);
 
+NNVM_REGISTER_OP(_rnn_param_concat)
+.set_attr<FCompute>("FCompute<gpu>", ConcatCompute<gpu>)
+.set_attr<FComputeEx>("FComputeEx<gpu>", ConcatComputeExGPU);
+
 NNVM_REGISTER_OP(_backward_Concat)
 .set_attr<FCompute>("FCompute<gpu>", ConcatGradCompute<gpu>);
 
diff --git a/src/operator/rnn.cc b/src/operator/rnn.cc
index 1e670a9..73ef4f0 100644
--- a/src/operator/rnn.cc
+++ b/src/operator/rnn.cc
@@ -45,12 +45,12 @@ Operator *RNNProp::CreateOperatorEx(Context ctx,
 DMLC_REGISTER_PARAMETER(RNNParam);
 
 MXNET_REGISTER_OP_PROPERTY(RNN, RNNProp)
-.describe(R"code(Applies recurrent layers to input data. Currently, vanilla RNN, LSTM and GRU are 
+.describe(R"code(Applies recurrent layers to input data. Currently, vanilla RNN, LSTM and GRU are
 implemented, with both multi-layer and bidirectional support.
 
 **Vanilla RNN**
 
-Applies a single-gate recurrent layer to input X. Two kinds of activation function are supported: 
+Applies a single-gate recurrent layer to input X. Two kinds of activation function are supported:
 ReLU and Tanh.
 
 With ReLU activation function:
@@ -63,7 +63,7 @@ With Tanh activtion function:
 .. math::
     h_t = \tanh(W_{ih} * x_t + b_{ih}  +  W_{hh} * h_{(t-1)} + b_{hh})
 
-Reference paper: Finding structure in time - Elman, 1988. 
+Reference paper: Finding structure in time - Elman, 1988.
 https://crl.ucsd.edu/~elman/Papers/fsit.pdf
 
 **LSTM**
diff --git a/tests/python/gpu/test_gluon_gpu.py b/tests/python/gpu/test_gluon_gpu.py
new file mode 100644
index 0000000..42d65da
--- /dev/null
+++ b/tests/python/gpu/test_gluon_gpu.py
@@ -0,0 +1,203 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import print_function
+import sys
+import os
+import time
+import multiprocessing as mp
+import unittest
+import mxnet as mx
+import numpy as np
+import unittest
+from nose.tools import assert_raises
+from mxnet.test_utils import check_consistency, set_default_context, assert_almost_equal
+from mxnet.base import MXNetError
+from mxnet import autograd
+from numpy.testing import assert_allclose
+
+curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
+sys.path.insert(0, os.path.join(curr_path, '../unittest'))
+from common import setup_module, with_seed, teardown, assert_raises_cudnn_disabled
+from test_gluon import *
+from test_loss import *
+from test_gluon_rnn import *
+
+set_default_context(mx.gpu(0))
+
+def check_rnn_layer(layer):
+    layer.collect_params().initialize(ctx=[mx.cpu(0), mx.gpu(0)])
+    with mx.gpu(0):
+        x = mx.nd.ones((10, 16, 30))
+        states = layer.begin_state(16)
+        go, gs = layer(x, states)
+
+    with mx.cpu(0):
+        x = mx.nd.ones((10, 16, 30))
+        states = layer.begin_state(16)
+        co, cs = layer(x, states)
+
+    # atol of 1e-6 required, as exposed by seed 2124685726
+    assert_almost_equal(go.asnumpy(), co.asnumpy(), rtol=1e-2, atol=1e-6)
+    for g, c in zip(gs, cs):
+        assert_almost_equal(g.asnumpy(), c.asnumpy(), rtol=1e-2, atol=1e-6)
+
+
+def check_rnn_layer_w_rand_inputs(layer):
+    layer.collect_params().initialize(ctx=[mx.cpu(0), mx.gpu(0)])
+    x = mx.nd.uniform(shape=(10, 16, 30))
+    with mx.gpu(0):
+        x = x.copyto(mx.gpu(0))
+        states = layer.begin_state(16)
+        go, gs = layer(x, states)
+
+    with mx.cpu(0):
+        x = x.copyto(mx.cpu(0))
+        states = layer.begin_state(16)
+        co, cs = layer(x, states)
+
+    assert_almost_equal(go.asnumpy(), co.asnumpy(), rtol=1e-2, atol=1e-6)
+    for g, c in zip(gs, cs):
+        assert_almost_equal(g.asnumpy(), c.asnumpy(), rtol=1e-2, atol=1e-6)
+
+
+@with_seed()
+@assert_raises_cudnn_disabled()
+def test_rnn_layer():
+    check_rnn_layer(gluon.rnn.RNN(100, num_layers=3))
+    check_rnn_layer(gluon.rnn.RNN(100, activation='tanh', num_layers=3))
+    check_rnn_layer(gluon.rnn.LSTM(100, num_layers=3))
+    check_rnn_layer(gluon.rnn.GRU(100, num_layers=3))
+
+    check_rnn_layer(gluon.rnn.LSTM(100, num_layers=3, bidirectional=True))
+    check_rnn_layer_w_rand_inputs(gluon.rnn.LSTM(100, num_layers=3, bidirectional=True))
+
+
+@with_seed()
+def test_gluon_ctc_consistency():
+    loss = mx.gluon.loss.CTCLoss()
+    data = mx.nd.arange(0, 4, repeat=40, ctx=mx.gpu(0)).reshape((2,20,4)).flip(axis=0)
+    cpu_label = mx.nd.array([[2,1,-1,-1],[3,2,2,-1]], ctx=mx.cpu(0))
+    gpu_label = mx.nd.array([[2,1,-1,-1],[3,2,2,-1]], ctx=mx.gpu(0))
+
+    cpu_data = data.copy().as_in_context(mx.cpu(0))
+    cpu_data.attach_grad()
+    with mx.autograd.record():
+        l_cpu = loss(cpu_data, cpu_label)
+        l_cpu.backward()
+
+    gpu_data = data.copyto(mx.gpu(0))
+    gpu_data.attach_grad()
+    with mx.autograd.record():
+        l_gpu = loss(gpu_data, gpu_label)
+        l_gpu.backward()
+
+    assert_almost_equal(cpu_data.grad.asnumpy(), gpu_data.grad.asnumpy(), atol=1e-3, rtol=1e-3)
+
+
+@with_seed()
+def test_global_norm_clip_multi_device():
+    x1 = mx.nd.ones((3,3), ctx=mx.gpu(0))
+    x2 = mx.nd.ones((4,4), ctx=mx.cpu(0))
+    norm = gluon.utils.clip_global_norm([x1, x2], 1.0)
+    assert norm == 5.0
+    assert_almost_equal(x1.asnumpy(), np.ones((3,3))/5)
+    assert_almost_equal(x2.asnumpy(), np.ones((4,4))/5)
+
+
+def _check_batchnorm_result(input, num_devices=1, cuda=False):
+    from mxnet.gluon.utils import split_and_load
+    def _find_bn(module):
+        if isinstance(module, (mx.gluon.nn.BatchNorm, mx.gluon.contrib.nn.SyncBatchNorm)):
+            return module
+        elif isinstance(module.module, (mx.gluon.nn.BatchNorm, mx.gluon.contrib.nn.SyncBatchNorm)):
+            return module.module
+
+        raise RuntimeError('BN not found')
+
+    def _syncParameters(bn1, bn2, ctx):
+        ctx = input.context
+        bn2.gamma.set_data(bn1.gamma.data(ctx))
+        bn2.beta.set_data(bn1.beta.data(ctx))
+        bn2.running_mean.set_data(bn1.running_mean.data(ctx))
+        bn2.running_var.set_data(bn1.running_var.data(ctx))
+
+    input1 = input.copy()
+    input2 = input.copy()
+
+    if cuda:
+        input1 = input.as_in_context(mx.gpu(0))
+        ctx_list = [mx.gpu(i) for i in range(num_devices)]
+    else:
+        ctx_list = [mx.cpu(0) for _ in range(num_devices)]
+
+    nch = input.shape[1]
+    bn1 = mx.gluon.nn.BatchNorm(in_channels=nch)
+    bn2 = mx.gluon.contrib.nn.SyncBatchNorm(in_channels=nch, num_devices=num_devices)
+
+    bn1.initialize(ctx=ctx_list[0])
+    bn2.initialize(ctx=ctx_list)
+
+    # using the same values for gamma and beta
+    #_syncParameters(_find_bn(bn1), _find_bn(bn2), ctx_list[0])
+
+    input1.attach_grad()
+    inputs2 = split_and_load(input2, ctx_list, batch_axis=0)
+    for xi in inputs2:
+        xi.attach_grad()
+
+    with mx.autograd.record():
+        output1 = bn1(input1)
+        output2  = [bn2(xi) for xi in inputs2]
+        loss1 = (output1 ** 2).sum()
+        loss2 = [(output ** 2).sum() for output in output2]
+        mx.autograd.backward(loss1)
+        mx.autograd.backward(loss2)
+
+    output2 = mx.nd.concat(*[output.as_in_context(input.context) for output in output2], dim=0)
+    # assert forwarding
+    assert_almost_equal(input1.asnumpy(), input2.asnumpy(), atol=1e-3, rtol=1e-3)
+    assert_almost_equal(output1.asnumpy(), output2.asnumpy(), atol=1e-3, rtol=1e-3)
+    assert_almost_equal(_find_bn(bn1).running_mean.data(ctx_list[0]).asnumpy(),
+                        _find_bn(bn2).running_mean.data(ctx_list[0]).asnumpy(),
+                        atol=1e-3, rtol=1e-3)
+    assert_almost_equal(_find_bn(bn1).running_var.data(ctx_list[0]).asnumpy(),
+                        _find_bn(bn2).running_var.data(ctx_list[0]).asnumpy(),
+                        atol=1e-3, rtol=1e-3)
+    input2grad = mx.nd.concat(*[output.grad.as_in_context(input.context) for output in inputs2], dim=0)
+    assert_almost_equal(input1.grad.asnumpy(), input2grad.asnumpy(), atol=1e-3, rtol=1e-3)
+
+
+def test_sync_batchnorm():
+    def get_num_devices():
+        for i in range(100):
+            try:
+                mx.nd.zeros((1,), ctx=mx.gpu(i))
+            except:
+                return i
+    # no need to use SyncBN with 1 gpu
+    if get_num_devices() < 2:
+        return
+    ndev = 2
+    # check with unsync version
+    for i in range(10):
+        _check_batchnorm_result(mx.nd.random.uniform(shape=(4, 1, 4, 4)),
+                                num_devices=ndev, cuda=True)
+
+if __name__ == '__main__':
+    import nose
+    nose.runmodule()
diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py
index a3e663a..3d799aa 100644
--- a/tests/python/gpu/test_operator_gpu.py
+++ b/tests/python/gpu/test_operator_gpu.py
@@ -36,11 +36,8 @@ from common import setup_module, with_seed, teardown, assert_raises_cudnn_disabl
 from test_operator import *
 from test_optimizer import *
 from test_random import *
-from test_gluon import *
-from test_loss import *
 from test_exc_handling import *
 #from test_rnn import *
-from test_gluon_rnn import *
 from test_sparse_ndarray import *
 from test_sparse_operator import *
 from test_ndarray import *
@@ -1661,17 +1658,6 @@ def check_rnn_layer_w_rand_inputs(layer):
         assert_almost_equal(g.asnumpy(), c.asnumpy(), rtol=1e-2, atol=1e-6)
 
 @with_seed()
-@assert_raises_cudnn_disabled()
-def test_rnn_layer():
-    check_rnn_layer(gluon.rnn.RNN(100, num_layers=3))
-    check_rnn_layer(gluon.rnn.RNN(100, activation='tanh', num_layers=3))
-    check_rnn_layer(gluon.rnn.LSTM(100, num_layers=3))
-    check_rnn_layer(gluon.rnn.GRU(100, num_layers=3))
-
-    check_rnn_layer(gluon.rnn.LSTM(100, num_layers=3, bidirectional=True))
-    check_rnn_layer_w_rand_inputs(gluon.rnn.LSTM(100, num_layers=3, bidirectional=True))
-
-@with_seed()
 def test_sequence_reverse():
     check_sequence_reverse(mx.gpu(0))
 
@@ -1689,28 +1675,6 @@ def test_autograd_save_memory():
 
 
 @with_seed()
-def test_gluon_ctc_consistency():
-    loss = mx.gluon.loss.CTCLoss()
-    data = mx.nd.arange(0, 4, repeat=40, ctx=mx.gpu(0)).reshape((2,20,4)).flip(axis=0)
-    cpu_label = mx.nd.array([[2,1,-1,-1],[3,2,2,-1]], ctx=mx.cpu(0))
-    gpu_label = mx.nd.array([[2,1,-1,-1],[3,2,2,-1]], ctx=mx.gpu(0))
-
-    cpu_data = data.copy().as_in_context(mx.cpu(0))
-    cpu_data.attach_grad()
-    with mx.autograd.record():
-        l_cpu = loss(cpu_data, cpu_label)
-        l_cpu.backward()
-
-    gpu_data = data.copyto(mx.gpu(0))
-    gpu_data.attach_grad()
-    with mx.autograd.record():
-        l_gpu = loss(gpu_data, gpu_label)
-        l_gpu.backward()
-
-    assert_almost_equal(cpu_data.grad.asnumpy(), gpu_data.grad.asnumpy(), atol=1e-3, rtol=1e-3)
-
-
-@with_seed()
 def test_cuda_rtc():
     source = r'''
     extern "C" __global__ void axpy(const float *x, float *y, float alpha) {
@@ -1741,16 +1705,6 @@ def test_cuda_rtc():
 
 
 @with_seed()
-def test_global_norm_clip_multi_device():
-    x1 = mx.nd.ones((3,3), ctx=mx.gpu(0))
-    x2 = mx.nd.ones((4,4), ctx=mx.cpu(0))
-    norm = gluon.utils.clip_global_norm([x1, x2], 1.0)
-    assert norm == 5.0
-    assert_almost_equal(x1.asnumpy(), np.ones((3,3))/5)
-    assert_almost_equal(x2.asnumpy(), np.ones((4,4))/5)
-
-
-@with_seed()
 def test_cross_device_autograd():
     x = mx.nd.random.uniform(shape=(10,))
     x.attach_grad()
@@ -1968,84 +1922,6 @@ def test_context_num_gpus():
     # Test that num_gpus reports at least one GPU, as the test is run on a GPU host.
     assert mx.context.num_gpus() > 0
 
-def _check_batchnorm_result(input, num_devices=1, cuda=False):
-    from mxnet.gluon.utils import split_and_load
-    def _find_bn(module):
-        if isinstance(module, (mx.gluon.nn.BatchNorm, mx.gluon.contrib.nn.SyncBatchNorm)):
-            return module
-        elif isinstance(module.module, (mx.gluon.nn.BatchNorm, mx.gluon.contrib.nn.SyncBatchNorm)):
-            return module.module
-
-        raise RuntimeError('BN not found')
-
-    def _syncParameters(bn1, bn2, ctx):
-        ctx = input.context
-        bn2.gamma.set_data(bn1.gamma.data(ctx))
-        bn2.beta.set_data(bn1.beta.data(ctx))
-        bn2.running_mean.set_data(bn1.running_mean.data(ctx))
-        bn2.running_var.set_data(bn1.running_var.data(ctx))
-
-    input1 = input.copy()
-    input2 = input.copy()
-
-    if cuda:
-        input1 = input.as_in_context(mx.gpu(0))
-        ctx_list = [mx.gpu(i) for i in range(num_devices)]
-    else:
-        ctx_list = [mx.cpu(0) for _ in range(num_devices)]
-
-    nch = input.shape[1]
-    bn1 = mx.gluon.nn.BatchNorm(in_channels=nch)
-    bn2 = mx.gluon.contrib.nn.SyncBatchNorm(in_channels=nch, num_devices=num_devices)
-
-    bn1.initialize(ctx=ctx_list[0])
-    bn2.initialize(ctx=ctx_list)
-
-    # using the same values for gamma and beta
-    #_syncParameters(_find_bn(bn1), _find_bn(bn2), ctx_list[0])
-
-    input1.attach_grad()
-    inputs2 = split_and_load(input2, ctx_list, batch_axis=0)
-    for xi in inputs2:
-        xi.attach_grad()
-
-    with mx.autograd.record():
-        output1 = bn1(input1)
-        output2  = [bn2(xi) for xi in inputs2]
-        loss1 = (output1 ** 2).sum()
-        loss2 = [(output ** 2).sum() for output in output2]
-        mx.autograd.backward(loss1)
-        mx.autograd.backward(loss2)
-
-    output2 = mx.nd.concat(*[output.as_in_context(input.context) for output in output2], dim=0)
-    # assert forwarding
-    assert_almost_equal(input1.asnumpy(), input2.asnumpy(), atol=1e-3, rtol=1e-3)
-    assert_almost_equal(output1.asnumpy(), output2.asnumpy(), atol=1e-3, rtol=1e-3)
-    assert_almost_equal(_find_bn(bn1).running_mean.data(ctx_list[0]).asnumpy(),
-                        _find_bn(bn2).running_mean.data(ctx_list[0]).asnumpy(),
-                        atol=1e-3, rtol=1e-3)
-    assert_almost_equal(_find_bn(bn1).running_var.data(ctx_list[0]).asnumpy(),
-                        _find_bn(bn2).running_var.data(ctx_list[0]).asnumpy(),
-                        atol=1e-3, rtol=1e-3)
-    input2grad = mx.nd.concat(*[output.grad.as_in_context(input.context) for output in inputs2], dim=0)
-    assert_almost_equal(input1.grad.asnumpy(), input2grad.asnumpy(), atol=1e-3, rtol=1e-3)
-
-def test_sync_batchnorm():
-    def get_num_devices():
-        for i in range(100):
-            try:
-                mx.nd.zeros((1,), ctx=mx.gpu(i))
-            except:
-                return i
-    # no need to use SyncBN with 1 gpu
-    if get_num_devices() < 2:
-        return
-    ndev = 2
-    # check with unsync version
-    for i in range(10):
-        _check_batchnorm_result(mx.nd.random.uniform(shape=(4, 1, 4, 4)),
-                                num_devices=ndev, cuda=True)
-
 if __name__ == '__main__':
     import nose
     nose.runmodule()
diff --git a/tests/python/unittest/test_gluon_rnn.py b/tests/python/unittest/test_gluon_rnn.py
index a9a2904..4e8241f 100644
--- a/tests/python/unittest/test_gluon_rnn.py
+++ b/tests/python/unittest/test_gluon_rnn.py
@@ -16,7 +16,7 @@
 # under the License.
 
 import mxnet as mx
-from mxnet import gluon
+from mxnet import gluon, nd
 import numpy as np
 import copy
 from numpy.testing import assert_allclose
@@ -25,7 +25,6 @@ from mxnet.test_utils import almost_equal, assert_almost_equal
 from common import assert_raises_cudnn_disabled
 
 
-@assert_raises_cudnn_disabled()
 def test_rnn():
     cell = gluon.rnn.RNNCell(100, prefix='rnn_')
     inputs = [mx.sym.Variable('rnn_t%d_data'%i) for i in range(3)]
@@ -51,7 +50,6 @@ def test_lstm():
     assert outs == [(10, 100), (10, 100), (10, 100)]
 
 
-@assert_raises_cudnn_disabled()
 def test_lstm_forget_bias():
     forget_bias = 2.0
     stack = gluon.rnn.SequentialRNNCell()
@@ -77,19 +75,23 @@ def test_lstm_forget_bias():
 def test_lstm_cpu_inference():
     # should behave the same as lstm cell
     EXPECTED_LSTM_OUTPUT = np.array([[[0.72045636, 0.72045636, 0.95215213, 
0.95215213],
-                                  [0.72045636, 0.72045636, 0.95215213, 
0.95215213]],
-                                 [[0.95215213, 0.95215213, 0.72045636, 
0.72045636],
-                                  [0.95215213, 0.95215213, 0.72045636, 
0.72045636]]])
+                                      [0.72045636, 0.72045636, 0.95215213, 
0.95215213]],
+                                     [[0.95215213, 0.95215213, 0.72045636, 
0.72045636],
+                                      [0.95215213, 0.95215213, 0.72045636, 
0.72045636]]])
     x = mx.nd.ones(shape=(2, 2, 2))
     model = mx.gluon.rnn.LSTM(2, num_layers=6, bidirectional=True)
+    model_cell = model._unfuse()
     model.initialize(mx.init.One())
+
     y = model(x).asnumpy()
+    y_cell = model_cell.unroll(2, x, layout='TNC', 
merge_outputs=True)[0].asnumpy()
 
+    mx.test_utils.assert_almost_equal(y_cell, EXPECTED_LSTM_OUTPUT,
+                                      rtol=1e-3, atol=1e-5)
     mx.test_utils.assert_almost_equal(y, EXPECTED_LSTM_OUTPUT,
                                       rtol=1e-3, atol=1e-5)
 
 
-@assert_raises_cudnn_disabled()
 def test_gru():
     cell = gluon.rnn.GRUCell(100, prefix='rnn_')
     inputs = [mx.sym.Variable('rnn_t%d_data'%i) for i in range(3)]
@@ -241,6 +243,46 @@ def test_bidirectional():
     assert outs == [(10, 200), (10, 200), (10, 200)]
 
 
+@assert_raises_cudnn_disabled()
+def test_layer_bidirectional():
+    class RefBiLSTM(gluon.Block):
+        def __init__(self, size, **kwargs):
+            super(RefBiLSTM, self).__init__(**kwargs)
+            with self.name_scope():
+                self._lstm_fwd = gluon.rnn.LSTM(size, bidirectional=False, 
prefix='l0')
+                self._lstm_bwd = gluon.rnn.LSTM(size, bidirectional=False, 
prefix='r0')
+
+        def forward(self, inpt):
+            fwd = self._lstm_fwd(inpt)
+            bwd_inpt = nd.flip(inpt, 0)
+            bwd = self._lstm_bwd(bwd_inpt)
+            bwd = nd.flip(bwd, 0)
+            return nd.concat(fwd, bwd, dim=2)
+
+    size = 7
+    in_size = 5
+    weights = {}
+    for d in ['l', 'r']:
+        weights['lstm_{}0_i2h_weight'.format(d)] = 
mx.random.uniform(shape=(size*4, in_size))
+        weights['lstm_{}0_h2h_weight'.format(d)] = 
mx.random.uniform(shape=(size*4, size))
+        weights['lstm_{}0_i2h_bias'.format(d)] = 
mx.random.uniform(shape=(size*4,))
+        weights['lstm_{}0_h2h_bias'.format(d)] = 
mx.random.uniform(shape=(size*4,))
+
+    net = gluon.rnn.LSTM(size, bidirectional=True, prefix='lstm_')
+    ref_net = RefBiLSTM(size, prefix='lstm_')
+    net.initialize()
+    ref_net.initialize()
+    net_params = net.collect_params()
+    ref_net_params = ref_net.collect_params()
+    for k in weights:
+        net_params[k].set_data(weights[k])
+        ref_net_params[k.replace('l0', 'l0l0').replace('r0', 
'r0l0')].set_data(weights[k])
+
+    data = mx.random.uniform(shape=(3, 10, in_size))
+    assert_allclose(net(data).asnumpy(), ref_net(data).asnumpy())
+
+
+
 def test_zoneout():
     cell = gluon.rnn.ZoneoutCell(gluon.rnn.RNNCell(100, prefix='rnn_'), 
zoneout_outputs=0.5,
                               zoneout_states=0.5)
@@ -341,9 +383,12 @@ def check_rnn_layer_forward(layer, inputs, states=None, 
run_only=False):
     layer.collect_params().initialize()
     inputs.attach_grad()
     with mx.autograd.record():
-        out = layer(inputs, states)
+        if states is None:
+            out = layer(inputs)
+        else:
+            out = layer(inputs, states)
         if states is not None:
-            assert isinstance(out, tuple) and len(out) == 2
+            assert isinstance(out, (list, tuple)) and len(out) == 2
             out = out[0]
         else:
             assert isinstance(out, mx.nd.NDArray)
@@ -355,15 +400,19 @@ def check_rnn_layer_forward(layer, inputs, states=None, 
run_only=False):
     layer.hybridize()
 
     with mx.autograd.record():
-        out = layer(inputs, states)
         if states is not None:
-            assert isinstance(out, tuple) and len(out) == 2
+            out = layer(inputs, states)
+            assert isinstance(out, (list, tuple)) and len(out) == 2
             out = out[0]
         else:
+            out = layer(inputs)
             assert isinstance(out, mx.nd.NDArray)
         out.backward()
 
-    layer(inputs, states) # test is_training = false
+    if states is not None:
+        layer(inputs, states) # test is_training = false
+    else:
+        layer(inputs)
 
     if not run_only:
         mx.test_utils.assert_almost_equal(np_out, out.asnumpy(), rtol=1e-3, 
atol=1e-5)
@@ -393,15 +442,26 @@ def test_rnn_layers():
     check_rnn_layer_forward(gluon.rnn.GRU(10, 2, bidirectional=True, 
dropout=0.5),
                             mx.nd.ones((8, 3, 20)), mx.nd.ones((4, 3, 10)), 
run_only=True)
 
-    net = gluon.nn.Sequential()
-    net.add(gluon.rnn.LSTM(10, 2, bidirectional=True))
+    net = gluon.nn.HybridSequential()
+    net.add(gluon.rnn.LSTM(10, bidirectional=True))
     net.add(gluon.nn.BatchNorm(axis=2))
     net.add(gluon.nn.Flatten())
     net.add(gluon.nn.Dense(3, activation='relu'))
+    net.hybridize()
     net.collect_params().initialize()
     with mx.autograd.record():
         net(mx.nd.ones((2, 3, 10))).backward()
 
+    net2 = gluon.nn.HybridSequential()
+    net2.add(gluon.rnn.LSTM(10, bidirectional=True))
+    net2.add(gluon.nn.BatchNorm(axis=2))
+    net2.add(gluon.nn.Flatten())
+    net2.add(gluon.nn.Dense(3, activation='relu'))
+    net2.hybridize()
+    net2.collect_params().initialize()
+    with mx.autograd.record():
+        net2(mx.nd.ones((2, 3, 10))).backward()
+
 
 def test_rnn_unroll_variant_length():
     # Test for imperative usage
@@ -487,10 +547,9 @@ def test_cell_fill_shape():
 @assert_raises_cudnn_disabled()
 def test_layer_fill_shape():
     layer = gluon.rnn.LSTM(10)
-    layer.hybridize()
     check_rnn_layer_forward(layer, mx.nd.ones((3, 2, 7)))
     print(layer)
-    assert layer.i2h_weight[0].shape[1] == 7, layer.i2h_weight[0].shape[1]
+    assert layer.l0_i2h_weight.shape[1] == 7, layer.l0_i2h_weight.shape[1]
 
 
 if __name__ == '__main__':

Reply via email to