This is an automated email from the ASF dual-hosted git repository.

haibin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 6aa8c27  [MXNET-1327] Allow RNN Layers to be initialized to fp16 
(#14219)
6aa8c27 is described below

commit 6aa8c27a69cafc65eb97c4e077cb399462b617e2
Author: Thomas Delteil <[email protected]>
AuthorDate: Tue Mar 12 14:30:19 2019 -0700

    [MXNET-1327] Allow RNN Layers to be initialized to fp16 (#14219)
    
    * update rnn for fp16
    
    * fix typo in test
    
    * fix tests
    
    * fix tests
    
    * fix gpu tests
    
    * Update test_gluon_rnn.py
    
    * Update test_gluon_rnn.py
    
    * trigger
    
    * try removing checks for unix
---
 python/mxnet/gluon/rnn/rnn_layer.py     | 59 ++++++++++++--------
 tests/python/unittest/test_gluon_rnn.py | 98 ++++++++++++++++++++++-----------
 2 files changed, 101 insertions(+), 56 deletions(-)

diff --git a/python/mxnet/gluon/rnn/rnn_layer.py 
b/python/mxnet/gluon/rnn/rnn_layer.py
index c43dc85..6dfec43 100644
--- a/python/mxnet/gluon/rnn/rnn_layer.py
+++ b/python/mxnet/gluon/rnn/rnn_layer.py
@@ -37,7 +37,7 @@ class _RNNLayer(HybridBlock):
                  i2h_bias_initializer, h2h_bias_initializer,
                  mode, projection_size, h2r_weight_initializer,
                  lstm_state_clip_min, lstm_state_clip_max, lstm_state_clip_nan,
-                 **kwargs):
+                 dtype, **kwargs):
         super(_RNNLayer, self).__init__(**kwargs)
         assert layout in ('TNC', 'NTC'), \
             "Invalid layout %s; must be one of ['TNC' or 'NTC']"%layout
@@ -57,6 +57,7 @@ class _RNNLayer(HybridBlock):
         self._lstm_state_clip_min = lstm_state_clip_min
         self._lstm_state_clip_max = lstm_state_clip_max
         self._lstm_state_clip_nan = lstm_state_clip_nan
+        self._dtype = dtype
 
         self._gates = {'rnn_relu': 1, 'rnn_tanh': 1, 'lstm': 4, 'gru': 3}[mode]
 
@@ -66,16 +67,16 @@ class _RNNLayer(HybridBlock):
                 for j in ['l', 'r'][:self._dir]:
                     self._register_param('{}{}_i2h_weight'.format(j, i),
                                          shape=(ng*nh, ni),
-                                         init=i2h_weight_initializer)
+                                         init=i2h_weight_initializer, 
dtype=dtype)
                     self._register_param('{}{}_h2h_weight'.format(j, i),
                                          shape=(ng*nh, nh),
-                                         init=h2h_weight_initializer)
+                                         init=h2h_weight_initializer, 
dtype=dtype)
                     self._register_param('{}{}_i2h_bias'.format(j, i),
                                          shape=(ng*nh,),
-                                         init=i2h_bias_initializer)
+                                         init=i2h_bias_initializer, 
dtype=dtype)
                     self._register_param('{}{}_h2h_bias'.format(j, i),
                                          shape=(ng*nh,),
-                                         init=h2h_bias_initializer)
+                                         init=h2h_bias_initializer, 
dtype=dtype)
                 ni = nh * self._dir
         else:
             np = self._projection_size
@@ -83,24 +84,24 @@ class _RNNLayer(HybridBlock):
                 for j in ['l', 'r'][:self._dir]:
                     self._register_param('{}{}_i2h_weight'.format(j, i),
                                          shape=(ng*nh, ni),
-                                         init=i2h_weight_initializer)
+                                         init=i2h_weight_initializer, 
dtype=dtype)
                     self._register_param('{}{}_h2h_weight'.format(j, i),
                                          shape=(ng*nh, np),
-                                         init=h2h_weight_initializer)
+                                         init=h2h_weight_initializer, 
dtype=dtype)
                     self._register_param('{}{}_i2h_bias'.format(j, i),
                                          shape=(ng*nh,),
-                                         init=i2h_bias_initializer)
+                                         init=i2h_bias_initializer, 
dtype=dtype)
                     self._register_param('{}{}_h2h_bias'.format(j, i),
                                          shape=(ng*nh,),
-                                         init=h2h_bias_initializer)
+                                         init=h2h_bias_initializer, 
dtype=dtype)
                     self._register_param('{}{}_h2r_weight'.format(j, i),
                                          shape=(np, nh),
-                                         init=h2r_weight_initializer)
+                                         init=h2r_weight_initializer, 
dtype=dtype)
                 ni = np * self._dir
 
-    def _register_param(self, name, shape, init):
+    def _register_param(self, name, shape, init, dtype):
         p = self.params.get(name, shape=shape, init=init,
-                            allow_deferred_init=True)
+                            allow_deferred_init=True, dtype=dtype)
         setattr(self, name, p)
         return p
 
@@ -179,6 +180,10 @@ class _RNNLayer(HybridBlock):
 
         return stack
 
+    def cast(self, dtype):
+        super(_RNNLayer, self).cast(dtype)
+        self._dtype = dtype
+
     def begin_state(self, batch_size=0, func=ndarray.zeros, **kwargs):
         """Initial state for this cell.
 
@@ -317,6 +322,8 @@ class RNN(_RNNLayer):
     input_size: int, default 0
         The number of expected features in the input x.
         If not specified, it will be inferred from input.
+    dtype : str, default 'float32'
+        Type to initialize the parameters and default states to
     prefix : str or None
         Prefix of this `Block`.
     params : ParameterDict or None
@@ -357,17 +364,17 @@ class RNN(_RNNLayer):
                  layout='TNC', dropout=0, bidirectional=False,
                  i2h_weight_initializer=None, h2h_weight_initializer=None,
                  i2h_bias_initializer='zeros', h2h_bias_initializer='zeros',
-                 input_size=0, **kwargs):
+                 input_size=0, dtype='float32', **kwargs):
         super(RNN, self).__init__(hidden_size, num_layers, layout,
                                   dropout, bidirectional, input_size,
                                   i2h_weight_initializer, 
h2h_weight_initializer,
                                   i2h_bias_initializer, h2h_bias_initializer,
                                   'rnn_'+activation, None, None, None, None, 
False,
-                                  **kwargs)
+                                  dtype, **kwargs)
 
     def state_info(self, batch_size=0):
         return [{'shape': (self._num_layers * self._dir, batch_size, 
self._hidden_size),
-                 '__layout__': 'LNC'}]
+                 '__layout__': 'LNC', 'dtype': self._dtype}]
 
 
 class LSTM(_RNNLayer):
@@ -432,6 +439,8 @@ class LSTM(_RNNLayer):
     state_clip_nan : boolean, default False
         Whether to stop NaN from propagating in state by clipping it to 
min/max.
         If the clipping range is not specified, this option is ignored.
+    dtype : str, default 'float32'
+        Type to initialize the parameters and default states to
     input_size: int, default 0
         The number of expected features in the input x.
         If not specified, it will be inferred from input.
@@ -477,26 +486,26 @@ class LSTM(_RNNLayer):
                  i2h_bias_initializer='zeros', h2h_bias_initializer='zeros',
                  projection_size=None, h2r_weight_initializer=None,
                  state_clip_min=None, state_clip_max=None, 
state_clip_nan=False,
-                 **kwargs):
+                 dtype='float32', **kwargs):
         super(LSTM, self).__init__(hidden_size, num_layers, layout,
                                    dropout, bidirectional, input_size,
                                    i2h_weight_initializer, 
h2h_weight_initializer,
                                    i2h_bias_initializer, h2h_bias_initializer,
                                    'lstm', projection_size, 
h2r_weight_initializer,
                                    state_clip_min, state_clip_max, 
state_clip_nan,
-                                   **kwargs)
+                                   dtype, **kwargs)
 
     def state_info(self, batch_size=0):
         if self._projection_size is None:
             return [{'shape': (self._num_layers * self._dir, batch_size, 
self._hidden_size),
-                     '__layout__': 'LNC'},
+                     '__layout__': 'LNC', 'dtype': self._dtype},
                     {'shape': (self._num_layers * self._dir, batch_size, 
self._hidden_size),
-                     '__layout__': 'LNC'}]
+                     '__layout__': 'LNC', 'dtype': self._dtype}]
         else:
             return [{'shape': (self._num_layers * self._dir, batch_size, 
self._projection_size),
-                     '__layout__': 'LNC'},
+                     '__layout__': 'LNC', 'dtype': self._dtype},
                     {'shape': (self._num_layers * self._dir, batch_size, 
self._hidden_size),
-                     '__layout__': 'LNC'}]
+                     '__layout__': 'LNC', 'dtype': self._dtype}]
 
 
 class GRU(_RNNLayer):
@@ -544,6 +553,8 @@ class GRU(_RNNLayer):
         Initializer for the bias vector.
     h2h_bias_initializer : str or Initializer
         Initializer for the bias vector.
+    dtype : str, default 'float32'
+        Type to initialize the parameters and default states to
     input_size: int, default 0
         The number of expected features in the input x.
         If not specified, it will be inferred from input.
@@ -586,14 +597,14 @@ class GRU(_RNNLayer):
                  dropout=0, bidirectional=False, input_size=0,
                  i2h_weight_initializer=None, h2h_weight_initializer=None,
                  i2h_bias_initializer='zeros', h2h_bias_initializer='zeros',
-                 **kwargs):
+                 dtype='float32', **kwargs):
         super(GRU, self).__init__(hidden_size, num_layers, layout,
                                   dropout, bidirectional, input_size,
                                   i2h_weight_initializer, 
h2h_weight_initializer,
                                   i2h_bias_initializer, h2h_bias_initializer,
                                   'gru', None, None, None, None, False,
-                                  **kwargs)
+                                  dtype, **kwargs)
 
     def state_info(self, batch_size=0):
         return [{'shape': (self._num_layers * self._dir, batch_size, 
self._hidden_size),
-                 '__layout__': 'LNC'}]
+                 '__layout__': 'LNC', 'dtype': self._dtype}]
diff --git a/tests/python/unittest/test_gluon_rnn.py 
b/tests/python/unittest/test_gluon_rnn.py
index edc43d2..b410362 100644
--- a/tests/python/unittest/test_gluon_rnn.py
+++ b/tests/python/unittest/test_gluon_rnn.py
@@ -427,9 +427,15 @@ def test_rnn_cells_export_import():
         assert_almost_equal(output1.asnumpy(), output2.asnumpy())
 
 
-def check_rnn_layer_forward(layer, inputs, states=None, run_only=False):
-    layer.collect_params().initialize()
+def check_rnn_layer_forward(layer, inputs, states=None, run_only=False, 
ctx=mx.cpu()):
+    layer.collect_params().initialize(ctx=ctx)
+    inputs = inputs.as_in_context(ctx)
     inputs.attach_grad()
+    if states is not None:
+        if isinstance(states, (list, tuple)):
+            states = [s.as_in_context(ctx) for s in states]
+        else:
+            states = states.as_in_context(ctx)
     with mx.autograd.record():
         if states is None:
             out = layer(inputs)
@@ -467,47 +473,76 @@ def check_rnn_layer_forward(layer, inputs, states=None, 
run_only=False):
         mx.test_utils.assert_almost_equal(np_dx, inputs.grad.asnumpy(), 
rtol=1e-3, atol=1e-5)
 
 
-@assert_raises_cudnn_not_satisfied(min_version='5.1.10')
-def test_rnn_layers():
-    check_rnn_layer_forward(gluon.rnn.RNN(10, 2), mx.nd.ones((8, 3, 20)))
-    check_rnn_layer_forward(gluon.rnn.RNN(10, 2, bidirectional=True), 
mx.nd.ones((8, 3, 20)), mx.nd.ones((4, 3, 10)))
-    check_rnn_layer_forward(gluon.rnn.LSTM(10, 2), mx.nd.ones((8, 3, 20)))
-    check_rnn_layer_forward(gluon.rnn.LSTM(10, 2, bidirectional=True), 
mx.nd.ones((8, 3, 20)), [mx.nd.ones((4, 3, 10)), mx.nd.ones((4, 3, 10))])
-    check_rnn_layer_forward(gluon.rnn.GRU(10, 2), mx.nd.ones((8, 3, 20)))
-    check_rnn_layer_forward(gluon.rnn.GRU(10, 2, bidirectional=True), 
mx.nd.ones((8, 3, 20)), mx.nd.ones((4, 3, 10)))
-
-    check_rnn_layer_forward(gluon.rnn.RNN(10, 2, dropout=0.5), mx.nd.ones((8, 
3, 20)),
-                            run_only=True)
-    check_rnn_layer_forward(gluon.rnn.RNN(10, 2, bidirectional=True, 
dropout=0.5),
-                            mx.nd.ones((8, 3, 20)), mx.nd.ones((4, 3, 10)), 
run_only=True)
-    check_rnn_layer_forward(gluon.rnn.LSTM(10, 2, dropout=0.5), mx.nd.ones((8, 
3, 20)),
-                            run_only=True)
-    check_rnn_layer_forward(gluon.rnn.LSTM(10, 2, bidirectional=True, 
dropout=0.5),
-                            mx.nd.ones((8, 3, 20)),
-                            [mx.nd.ones((4, 3, 10)), mx.nd.ones((4, 3, 10))], 
run_only=True)
-    check_rnn_layer_forward(gluon.rnn.GRU(10, 2, dropout=0.5), mx.nd.ones((8, 
3, 20)),
-                            run_only=True)
-    check_rnn_layer_forward(gluon.rnn.GRU(10, 2, bidirectional=True, 
dropout=0.5),
-                            mx.nd.ones((8, 3, 20)), mx.nd.ones((4, 3, 10)), 
run_only=True)
+
+def run_rnn_layers(dtype, dtype2, ctx=mx.cpu()):
+
+    check_rnn_layer_forward(gluon.rnn.RNN(10, 2, dtype=dtype), mx.nd.ones((8, 
3, 20), dtype=dtype), ctx=ctx)
+    check_rnn_layer_forward(gluon.rnn.RNN(10, 2, dtype=dtype, 
bidirectional=True), mx.nd.ones((8, 3, 20),  dtype=dtype), mx.nd.ones((4, 3, 
10),  dtype=dtype), ctx=ctx)
+    check_rnn_layer_forward(gluon.rnn.LSTM(10, 2,dtype=dtype), mx.nd.ones((8, 
3, 20),  dtype=dtype), ctx=ctx)
+    check_rnn_layer_forward(gluon.rnn.LSTM(10, 2,dtype=dtype,  
bidirectional=True), mx.nd.ones((8, 3, 20),  dtype=dtype), [mx.nd.ones((4, 3, 
10),  dtype=dtype), mx.nd.ones((4, 3, 10),  dtype=dtype)],ctx=ctx)
+    check_rnn_layer_forward(gluon.rnn.GRU(10, 2, dtype=dtype, ), 
mx.nd.ones((8, 3, 20), dtype=dtype),ctx=ctx)
+    check_rnn_layer_forward(gluon.rnn.GRU(10, 2, dtype=dtype, 
bidirectional=True), mx.nd.ones((8, 3, 20),  dtype=dtype), mx.nd.ones((4, 3, 
10),  dtype=dtype),ctx=ctx)
+
+
+    check_rnn_layer_forward(gluon.rnn.RNN(10, 2, dtype=dtype, dropout=0.5), 
mx.nd.ones((8, 3, 20), dtype=dtype),
+                            run_only=True, ctx=ctx)
+    check_rnn_layer_forward(gluon.rnn.RNN(10, 2, bidirectional=True, 
dropout=0.5, dtype=dtype),
+                            mx.nd.ones((8, 3, 20), dtype=dtype), 
mx.nd.ones((4, 3, 10), dtype=dtype), run_only=True, ctx=ctx)
+    check_rnn_layer_forward(gluon.rnn.LSTM(10, 2, dropout=0.5, dtype=dtype), 
mx.nd.ones((8, 3, 20), dtype=dtype),
+                            run_only=True, ctx=ctx)
+    check_rnn_layer_forward(gluon.rnn.LSTM(10, 2, bidirectional=True, 
dropout=0.5, dtype=dtype),
+                            mx.nd.ones((8, 3, 20), dtype=dtype),
+                            [mx.nd.ones((4, 3, 10), dtype=dtype), 
mx.nd.ones((4, 3, 10), dtype=dtype)], run_only=True, ctx=ctx)
+    check_rnn_layer_forward(gluon.rnn.GRU(10, 2, dropout=0.5, dtype=dtype), 
mx.nd.ones((8, 3, 20), dtype=dtype),
+                            run_only=True, ctx=ctx)
+    check_rnn_layer_forward(gluon.rnn.GRU(10, 2, bidirectional=True, 
dropout=0.5, dtype=dtype),
+                            mx.nd.ones((8, 3, 20), dtype=dtype), 
mx.nd.ones((4, 3, 10), dtype=dtype), run_only=True, ctx=ctx)
 
     net = gluon.nn.Sequential()
-    net.add(gluon.rnn.LSTM(10, bidirectional=True))
+    net.add(gluon.rnn.LSTM(10, bidirectional=True, dtype=dtype2))
     net.add(gluon.nn.BatchNorm(axis=2))
     net.add(gluon.nn.Flatten())
     net.add(gluon.nn.Dense(3, activation='relu'))
-    net.collect_params().initialize()
+    net.collect_params().initialize(ctx=ctx)
+    net.cast(dtype)
     with mx.autograd.record():
-        net(mx.nd.ones((2, 3, 10))).backward()
+        out = net(mx.nd.ones((2, 3, 10), dtype=dtype, ctx=ctx))
+        out.backward()
+        out = out.asnumpy()
 
     net2 = gluon.nn.HybridSequential()
-    net2.add(gluon.rnn.LSTM(10, bidirectional=True))
+    net2.add(gluon.rnn.LSTM(10, bidirectional=True, dtype=dtype2))
     net2.add(gluon.nn.BatchNorm(axis=2))
     net2.add(gluon.nn.Flatten())
     net2.add(gluon.nn.Dense(3, activation='relu'))
     net2.hybridize()
-    net2.collect_params().initialize()
+    net2.collect_params().initialize(ctx=ctx)
+    net2.cast(dtype)
+    with mx.autograd.record():
+        out = net2(mx.nd.ones((2, 3, 10), dtype=dtype, ctx=ctx))
+        out.backward()
+        out = out.asnumpy()
+
+    net3 = gluon.nn.HybridSequential()
+    net3.add(gluon.rnn.LSTM(10, bidirectional=True, dtype=dtype))
+    net3.add(gluon.nn.BatchNorm(axis=2))
+    net3.add(gluon.nn.Flatten())
+    net3.add(gluon.nn.Dense(3, activation='relu'))
+    net3.hybridize()
+    net3.collect_params().initialize(ctx=ctx)
+    net3.cast(dtype2)
     with mx.autograd.record():
-        net2(mx.nd.ones((2, 3, 10))).backward()
+        out = net3(mx.nd.ones((2, 3, 10), dtype=dtype2, ctx=ctx))
+        out.backward()
+        out = out.asnumpy()
+
+def test_rnn_layers_fp32():
+    run_rnn_layers('float32', 'float32')
+
+@assert_raises_cudnn_not_satisfied(min_version='5.1.10')
+@unittest.skipIf(mx.context.num_gpus() == 0, "RNN FP16 only implemented for 
GPU for now")
+def test_rnn_layers_fp16():
+    run_rnn_layers('float16', 'float32', mx.gpu())
 
 
 def test_rnn_unroll_variant_length():
@@ -590,8 +625,6 @@ def test_cell_fill_shape():
     check_rnn_forward(cell, mx.nd.ones((2, 3, 7)))
     assert cell.i2h_weight.shape[1] == 7, cell.i2h_weight.shape[1]
 
-
-@assert_raises_cudnn_not_satisfied(min_version='5.1.10')
 def test_layer_fill_shape():
     layer = gluon.rnn.LSTM(10)
     layer.hybridize()
@@ -603,6 +636,7 @@ def test_layer_fill_shape():
 def test_bidirectional_unroll_valid_length():
     # Test BidirectionalCell.
     # In 1.3.1 version, after hybridize( ), BidirectionalCell would failed 
when pass valid_length to unroll( ).
+    
     class BiLSTM(gluon.nn.HybridBlock):
         def __init__(self, rnn_size, time_step, **kwargs):
             super(BiLSTM, self).__init__(**kwargs)

Reply via email to