This is an automated email from the ASF dual-hosted git repository.
jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new b848c24 fix example/rnn: Speedometer(..., auto_reset=False) (#6679)
b848c24 is described below
commit b848c241be41b6933923e9acc7caa40d6c2f76b4
Author: Leonard Lausen <[email protected]>
AuthorDate: Fri Aug 4 05:35:18 2017 +0900
fix example/rnn: Speedometer(..., auto_reset=False) (#6679)
If the Speedometer resets the eval_metric and, due to an unlucky number of
batches, end_of_batch is reached immediately afterwards, the Perplexity metric
will throw a ZeroDivisionError because eval_metric.num_inst == 0.
---
example/rnn/cudnn_lstm_bucketing.py | 4 ++--
example/rnn/lstm_bucketing.py | 2 +-
2 files changed, 3 insertions(+), 3 deletions(-)
diff --git a/example/rnn/cudnn_lstm_bucketing.py b/example/rnn/cudnn_lstm_bucketing.py
index 140f2e6..fbf32bb 100644
--- a/example/rnn/cudnn_lstm_bucketing.py
+++ b/example/rnn/cudnn_lstm_bucketing.py
@@ -135,13 +135,13 @@ def train(args):
eval_metric = mx.metric.Perplexity(invalid_label),
kvstore = args.kv_store,
optimizer = args.optimizer,
- optimizer_params = opt_params,
+ optimizer_params = opt_params,
initializer = mx.init.Xavier(factor_type="in", magnitude=2.34),
arg_params = arg_params,
aux_params = aux_params,
begin_epoch = args.load_epoch,
num_epoch = args.num_epochs,
- batch_end_callback = mx.callback.Speedometer(args.batch_size, args.disp_batches),
+ batch_end_callback = mx.callback.Speedometer(args.batch_size, args.disp_batches, auto_reset=False),
epoch_end_callback = mx.rnn.do_rnn_checkpoint(cell,
args.model_prefix, 1)
if args.model_prefix else None)
diff --git a/example/rnn/lstm_bucketing.py b/example/rnn/lstm_bucketing.py
index 6c4371b..609276a 100644
--- a/example/rnn/lstm_bucketing.py
+++ b/example/rnn/lstm_bucketing.py
@@ -107,4 +107,4 @@ if __name__ == '__main__':
'wd': args.wd },
initializer = mx.init.Xavier(factor_type="in", magnitude=2.34),
num_epoch = args.num_epochs,
- batch_end_callback = mx.callback.Speedometer(args.batch_size, args.disp_batches))
+ batch_end_callback = mx.callback.Speedometer(args.batch_size, args.disp_batches, auto_reset=False))
--
To stop receiving notification emails like this one, please contact
['"[email protected]" <[email protected]>'].