This is an automated email from the ASF dual-hosted git repository.

jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new b848c24  fix example/rnn: Speedometer(..., auto_reset=False) (#6679)
b848c24 is described below

commit b848c241be41b6933923e9acc7caa40d6c2f76b4
Author: Leonard Lausen <[email protected]>
AuthorDate: Fri Aug 4 05:35:18 2017 +0900

    fix example/rnn: Speedometer(..., auto_reset=False) (#6679)
    
    If the Speedometer resets the eval_metric and, due to an unlucky number of
    batches, the end_of_batch is reached immediately afterwards, the Perplexity
    metric will throw a ZeroDivisionError because eval_metric.num_inst == 0.
---
 example/rnn/cudnn_lstm_bucketing.py | 4 ++--
 example/rnn/lstm_bucketing.py       | 2 +-
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/example/rnn/cudnn_lstm_bucketing.py b/example/rnn/cudnn_lstm_bucketing.py
index 140f2e6..fbf32bb 100644
--- a/example/rnn/cudnn_lstm_bucketing.py
+++ b/example/rnn/cudnn_lstm_bucketing.py
@@ -135,13 +135,13 @@ def train(args):
         eval_metric         = mx.metric.Perplexity(invalid_label),
         kvstore             = args.kv_store,
         optimizer           = args.optimizer,
-        optimizer_params    = opt_params, 
+        optimizer_params    = opt_params,
         initializer         = mx.init.Xavier(factor_type="in", magnitude=2.34),
         arg_params          = arg_params,
         aux_params          = aux_params,
         begin_epoch         = args.load_epoch,
         num_epoch           = args.num_epochs,
-        batch_end_callback  = mx.callback.Speedometer(args.batch_size, args.disp_batches),
+        batch_end_callback  = mx.callback.Speedometer(args.batch_size, args.disp_batches, auto_reset=False),
         epoch_end_callback  = mx.rnn.do_rnn_checkpoint(cell, args.model_prefix, 1)
                               if args.model_prefix else None)
 
diff --git a/example/rnn/lstm_bucketing.py b/example/rnn/lstm_bucketing.py
index 6c4371b..609276a 100644
--- a/example/rnn/lstm_bucketing.py
+++ b/example/rnn/lstm_bucketing.py
@@ -107,4 +107,4 @@ if __name__ == '__main__':
                                 'wd': args.wd },
         initializer         = mx.init.Xavier(factor_type="in", magnitude=2.34),
         num_epoch           = args.num_epochs,
-        batch_end_callback  = mx.callback.Speedometer(args.batch_size, args.disp_batches))
+        batch_end_callback  = mx.callback.Speedometer(args.batch_size, args.disp_batches, auto_reset=False))

-- 
To stop receiving notification emails like this one, please contact
['"[email protected]" <[email protected]>'].

Reply via email to