This is an automated email from the ASF dual-hosted git repository.
nswamy pushed a commit to branch fit-api
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/fit-api by this push:
new c2e2f80 [MXNET-1344, 1346][FIT API] Retrieve Batch size and Logging
verbose support for Gluon fit() API (#14587)
c2e2f80 is described below
commit c2e2f80474652cae2eb52d3614ef00a05472a679
Author: Karan Jariwala <[email protected]>
AuthorDate: Fri Apr 5 11:04:04 2019 -0700
[MXNET-1344, 1346][FIT API] Retrieve Batch size and Logging verbose support
for Gluon fit() API (#14587)
* Retrieve Batch size and Logging verbose support for Gluon fit() API
* NIT changes
* Addressed review comments: shifted the batch size code to a separate
method, sentence correction
* Modified unittest
* removed redundant parameter
* Resolve CI test failure
* Only support DataLoader for now; future PRs will include a converter from
DataIter to DataLoader
* Get the number of samples from the shape attribute instead of the length,
for lower space complexity
* Simplified batch size retrieval code
* removed batch_size parameter from fit() method and fixed the tests
* Verbose exception handling
* Assign constants to the verbose levels
* Modified exception message
* Resolved undefined class reference
* Addressed review comments: Modified verbose level names, docs, variable
names
* Update estimator.py
---
python/mxnet/gluon/estimator/estimator.py | 43 ++++++++--------
python/mxnet/gluon/estimator/event_handler.py | 61 +++++++++++++++--------
tests/nightly/estimator/test_estimator_cnn.py | 12 ++---
tests/nightly/estimator/test_sentiment_rnn.py | 6 +--
tests/python/unittest/test_gluon_estimator.py | 45 +++++++----------
tests/python/unittest/test_gluon_event_handler.py | 5 +-
6 files changed, 91 insertions(+), 81 deletions(-)
diff --git a/python/mxnet/gluon/estimator/estimator.py
b/python/mxnet/gluon/estimator/estimator.py
index c5da0c0..5294991 100644
--- a/python/mxnet/gluon/estimator/estimator.py
+++ b/python/mxnet/gluon/estimator/estimator.py
@@ -21,11 +21,9 @@
import copy
import warnings
-
from .event_handler import EventHandler, LoggingHandler
from ... import gluon, autograd
from ...context import Context, cpu, gpu, num_gpus
-from ...io import DataIter
from ...metric import EvalMetric, Loss, Accuracy
__all__ = ['Estimator']
@@ -168,7 +166,7 @@ class Estimator(object):
Parameters
----------
- val_data : DataLoader or DataIter
+ val_data : DataLoader
validation data with data and labels
batch_fn : function
custom batch function to extract data and label
@@ -182,13 +180,10 @@ class Estimator(object):
if not batch_fn:
if isinstance(val_data, gluon.data.DataLoader):
data, label = self._batch_fn(batch, self.context)
- elif isinstance(val_data, DataIter):
- data, label = self._batch_fn(batch, self.context,
is_iterator=True)
else:
raise ValueError("You are using a custom iteration, please
also provide "
"batch_fn to extract data and label.
Alternatively, you "
- "can provide the data as
gluon.data.DataLoader or "
- "mx.io.DataIter")
+ "can provide the data as
gluon.data.DataLoader.")
else:
data, label = batch_fn(batch, self.context)
pred = [self.net(x) for x in data]
@@ -208,16 +203,17 @@ class Estimator(object):
def fit(self, train_data,
val_data=None,
epochs=1,
- batch_size=None,
event_handlers=None,
batch_fn=None):
- """Main training loop
+ """Trains the model on a given dataset for a specified
+ number of epochs. Also, the batch size is inferred from the
+ DataLoader's batch_size.
Parameters
----------
- train_data : DataLoader or DataIter
+ train_data : DataLoader
training data with data and labels
- val_data : DataLoader or DataIter
+ val_data : DataLoader
validation data with data and labels
epochs : int, default 1
number of epochs to iterate on the training data.
@@ -232,12 +228,8 @@ class Estimator(object):
"""
self.max_epoch = epochs
- if not batch_size:
- self.batch_size = 32 * len(self.context)
- else:
- self.batch_size = batch_size
self.stop_training = False
- self.samples = None
+ self.processed_samples = None
self.batch_idx = 0
event_handlers = event_handlers or []
@@ -245,6 +237,9 @@ class Estimator(object):
if not event_handlers or \
not any(isinstance(handler, LoggingHandler) for handler in
event_handlers):
event_handlers.append(LoggingHandler())
+ warnings.warn("No Event Handler specified, default
`LoggingHandler()` "
+ "is used with
verbose=LoggingHandler.LOG_VERBOSITY_PER_EPOCH. "
+ "Please look at gluon.estimator.event_handler for
more detail.")
train_begin, epoch_begin, batch_begin, \
batch_end, epoch_end, train_end =
self._categorize_handlers(event_handlers)
@@ -261,6 +256,8 @@ class Estimator(object):
for epoch in range(self.max_epoch):
# epoch begin
self.current_epoch = epoch
+ # Number of samples trained after every batch
+ completed_samples = 0
for handler in epoch_begin:
handler.epoch_begin()
@@ -272,16 +269,15 @@ class Estimator(object):
if not batch_fn:
if isinstance(train_data, gluon.data.DataLoader):
data, label = self._batch_fn(batch, self.context)
- elif isinstance(train_data, DataIter):
- data, label = self._batch_fn(batch, self.context,
is_iterator=True)
else:
raise ValueError("You are using a custom iteration,
please also provide "
"batch_fn to extract data and label.
Alternatively, you "
- "can provide the data as
gluon.data.DataLoader or "
- "mx.io.DataIter")
+ "can provide the data as
gluon.data.DataLoader")
else:
data, label = batch_fn(batch, self.context)
+ batch_size = batch[0].shape[0]
+
# batch begin
for handler in batch_begin:
handler.batch_begin()
@@ -309,12 +305,15 @@ class Estimator(object):
name, value = loss_metric.get()
self.train_stats['train_' + name] = value
+ completed_samples += batch_size
+
self.batch_idx = i
# record trained samples v.s. total samples if using Gluon
DataLoader
if isinstance(train_data, gluon.data.DataLoader):
- self.samples = "{}/{}".format(self.batch_size * (i + 1),
len(train_data._dataset))
+ self.processed_samples = "{}/{}".format(completed_samples,
+
len(train_data._dataset))
- self.trainer.step(self.batch_size)
+ self.trainer.step(batch_size)
# batch end
for handler in batch_end:
handler.batch_end()
diff --git a/python/mxnet/gluon/estimator/event_handler.py
b/python/mxnet/gluon/estimator/event_handler.py
index 7810074..53c0bf5 100644
--- a/python/mxnet/gluon/estimator/event_handler.py
+++ b/python/mxnet/gluon/estimator/event_handler.py
@@ -85,14 +85,27 @@ class LoggingHandler(EventHandler):
file name to save the logs
file_location: str
file location to save the logs
+ verbose: int, default LOG_VERBOSITY_PER_EPOCH
+ Limit the granularity of metrics displayed during training process
+ verbose=LOG_VERBOSITY_PER_EPOCH: display metrics every epoch
+ verbose=LOG_VERBOSITY_PER_BATCH: display metrics every batch
"""
- def __init__(self, file_name=None, file_location=None):
+ LOG_VERBOSITY_PER_EPOCH = 1
+ LOG_VERBOSITY_PER_BATCH = 2
+
+ def __init__(self, file_name=None, file_location=None,
verbose=LOG_VERBOSITY_PER_EPOCH):
super(LoggingHandler, self).__init__()
self.logger = logging.getLogger(__name__)
self.logger.setLevel(logging.INFO)
stream_handler = logging.StreamHandler()
self.logger.addHandler(stream_handler)
+ if verbose not in [self.LOG_VERBOSITY_PER_EPOCH,
self.LOG_VERBOSITY_PER_BATCH]:
+ raise ValueError("verbose level must be either
LOG_VERBOSITY_PER_EPOCH or "
+ "LOG_VERBOSITY_PER_BATCH, received %s. "
+ "E.g:
LoggingHandler(verbose=LoggingHandler.LOG_VERBOSITY_PER_EPOCH)"
+ % verbose)
+ self.verbose = verbose
# save logger to file only if file name or location is specified
if file_name or file_location:
file_name = file_name or 'estimator_log'
@@ -118,33 +131,37 @@ class LoggingHandler(EventHandler):
self.logger.info(msg)
def batch_begin(self):
- self.batch_start = time.time()
+ if self.verbose == self.LOG_VERBOSITY_PER_BATCH:
+ self.batch_start = time.time()
def batch_end(self):
- batch_time = time.time() - self.batch_start
- epoch = self.estimator.current_epoch
- batch = self.estimator.batch_idx
- msg = '[Epoch %d] [Batch %d] ' % (epoch, batch)
- if self.estimator.samples:
- msg += '[Samples %s] ' % (self.estimator.samples)
- msg += 'time/batch: %.3fs ' % batch_time
- for key in self.estimator.train_stats:
- # only log current training loss & metric after each batch
- if key.startswith('train_'):
- msg += key + ': ' + '%.4f ' % self.estimator.train_stats[key]
- self.logger.info(msg)
+ if self.verbose == self.LOG_VERBOSITY_PER_BATCH:
+ batch_time = time.time() - self.batch_start
+ epoch = self.estimator.current_epoch
+ batch = self.estimator.batch_idx
+ msg = '[Epoch %d] [Batch %d] ' % (epoch, batch)
+ if self.estimator.processed_samples:
+ msg += '[Samples %s] ' % (self.estimator.processed_samples)
+ msg += 'time/batch: %.3fs ' % batch_time
+ for key in self.estimator.train_stats:
+ # only log current training loss & metric after each batch
+ if key.startswith('train_'):
+ msg += key + ': ' + '%.4f ' %
self.estimator.train_stats[key]
+ self.logger.info(msg)
def epoch_begin(self):
- self.epoch_start = time.time()
+ if self.verbose >= self.LOG_VERBOSITY_PER_EPOCH:
+ self.epoch_start = time.time()
def epoch_end(self):
- epoch_time = time.time() - self.epoch_start
- epoch = self.estimator.current_epoch
- msg = '\n[Epoch %d] finished in %.3fs: ' % (epoch, epoch_time)
- # log every result in train stats including train/validation loss &
metrics
- for key in self.estimator.train_stats:
- msg += '%s : %.4f ' % (key, self.estimator.train_stats[key])
- self.logger.info(msg)
+ if self.verbose >= self.LOG_VERBOSITY_PER_EPOCH:
+ epoch_time = time.time() - self.epoch_start
+ epoch = self.estimator.current_epoch
+ msg = '\n[Epoch %d] finished in %.3fs: ' % (epoch, epoch_time)
+ # log every result in train stats including train/validation loss
& metrics
+ for key in self.estimator.train_stats:
+ msg += '%s : %.4f ' % (key, self.estimator.train_stats[key])
+ self.logger.info(msg)
class CheckpointHandler(EventHandler):
diff --git a/tests/nightly/estimator/test_estimator_cnn.py
b/tests/nightly/estimator/test_estimator_cnn.py
index b99e99a..b4311b3 100644
--- a/tests/nightly/estimator/test_estimator_cnn.py
+++ b/tests/nightly/estimator/test_estimator_cnn.py
@@ -105,13 +105,12 @@ def test_estimator_cpu():
est = estimator.Estimator(net=net,
loss=loss,
metrics=mx.metric.Accuracy(),
- trainers=trainer,
+ trainer=trainer,
context=context)
# Call fit()
est.fit(train_data=train_data,
val_data=val_data,
- epochs=1,
- batch_size=1)
+ epochs=1)
def test_estimator_gpu():
'''
@@ -131,15 +130,14 @@ def test_estimator_gpu():
est = estimator.Estimator(net=net,
loss=loss,
metrics=acc,
- trainers=trainer,
+ trainer=trainer,
context=context)
# Call fit()
est.fit(train_data=train_data,
val_data=test_data,
- epochs=num_epochs,
- batch_size=batch_size)
+ epochs=num_epochs)
- assert est.train_stats['train_'+acc.name][num_epochs-1] > 0.80
+ assert est.train_stats['train_'+acc.name] > 0.80
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='test gluon estimator')
diff --git a/tests/nightly/estimator/test_sentiment_rnn.py
b/tests/nightly/estimator/test_sentiment_rnn.py
index 7e42831..c9dcbd2 100644
--- a/tests/nightly/estimator/test_sentiment_rnn.py
+++ b/tests/nightly/estimator/test_sentiment_rnn.py
@@ -179,10 +179,10 @@ def run(net, train_dataloader, test_dataloader, **kwargs):
# Define estimator
est = estimator.Estimator(net=net, loss=loss, metrics=acc,
- trainers=trainer, context=ctx)
+ trainer=trainer, context=ctx)
# Begin training
est.fit(train_data=train_dataloader, val_data=test_dataloader,
- epochs=num_epochs, batch_size=batch_size)
+ epochs=num_epochs)
return est
@@ -252,7 +252,7 @@ def test_estimator_gpu(**kwargs):
est = run(net, train_dataloader, test_dataloader, **kwargs)
- assert est.train_stats['train_accuracy'][num_epochs - 1] > 0.70
+ assert est.train_stats['train_accuracy'] > 0.70
parser = argparse.ArgumentParser(description='test gluon estimator')
diff --git a/tests/python/unittest/test_gluon_estimator.py
b/tests/python/unittest/test_gluon_estimator.py
index 25a410e..c86f4ff 100644
--- a/tests/python/unittest/test_gluon_estimator.py
+++ b/tests/python/unittest/test_gluon_estimator.py
@@ -55,20 +55,18 @@ def test_fit():
dataset = gluon.data.dataset.ArrayDataset(in_data, out_data)
train_dataloader = gluon.data.DataLoader(dataset, batch_size=batch_size)
est.fit(train_data=train_dataloader,
- epochs=num_epochs,
- batch_size=batch_size)
+ epochs=num_epochs)
# Input dataiter
train_dataiter = mx.io.NDArrayIter(data=in_data, label=out_data,
batch_size=batch_size)
- est.fit(train_data=train_dataiter,
- epochs=num_epochs,
- batch_size=batch_size)
+ with assert_raises(ValueError):
+ est.fit(train_data=train_dataiter,
+ epochs=num_epochs)
# Input NDArray
with assert_raises(ValueError):
est.fit(train_data=[in_data, out_data],
- epochs=num_epochs,
- batch_size=batch_size)
+ epochs=num_epochs)
def test_validation():
@@ -94,22 +92,20 @@ def test_validation():
val_dataloader = gluon.data.DataLoader(dataset, batch_size=batch_size)
est.fit(train_data=train_dataloader,
val_data=val_dataloader,
- epochs=num_epochs,
- batch_size=batch_size)
+ epochs=num_epochs)
# Input dataiter
train_dataiter = mx.io.NDArrayIter(data=in_data, label=out_data,
batch_size=batch_size)
val_dataiter = mx.io.NDArrayIter(data=in_data, label=out_data,
batch_size=batch_size)
- est.fit(train_data=train_dataiter,
- val_data=val_dataiter,
- epochs=num_epochs,
- batch_size=batch_size)
+ with assert_raises(ValueError):
+ est.fit(train_data=train_dataiter,
+ val_data=val_dataiter,
+ epochs=num_epochs)
# Input NDArray
with assert_raises(ValueError):
est.fit(train_data=[in_data, out_data],
val_data=[in_data, out_data],
- epochs=num_epochs,
- batch_size=batch_size)
+ epochs=num_epochs)
@unittest.skipIf(sys.version_info.major < 3, 'Test on python 3')
@@ -131,8 +127,7 @@ def test_initializer():
metrics=acc,
context=ctx)
est.fit(train_data=train_data,
- epochs=num_epochs,
- batch_size=batch_size)
+ epochs=num_epochs)
# different initializer for net and estimator
net = get_model()
@@ -148,8 +143,7 @@ def test_initializer():
context=ctx)
assert 'Network already initialized' in str(w[-1].message)
est.fit(train_data=train_data,
- epochs=num_epochs,
- batch_size=batch_size)
+ epochs=num_epochs)
@unittest.skipIf(sys.version_info.major < 3, 'Test on python 3')
@@ -174,8 +168,7 @@ def test_trainer():
context=ctx)
assert 'No trainer specified' in str(w[-1].message)
est.fit(train_data=train_data,
- epochs=num_epochs,
- batch_size=batch_size)
+ epochs=num_epochs)
# input invalid trainer
trainer = 'sgd'
@@ -206,8 +199,7 @@ def test_metric():
trainer=trainer,
context=ctx)
est.fit(train_data=train_data,
- epochs=num_epochs,
- batch_size=batch_size)
+ epochs=num_epochs)
# input list of metrics
metrics = [mx.metric.Accuracy(), mx.metric.Accuracy()]
est = Estimator(net=net,
@@ -216,8 +208,7 @@ def test_metric():
trainer=trainer,
context=ctx)
est.fit(train_data=train_data,
- epochs=num_epochs,
- batch_size=batch_size)
+ epochs=num_epochs)
# input invalid metric
with assert_raises(ValueError):
est = Estimator(net=net,
@@ -260,7 +251,9 @@ def test_context():
loss=loss,
metrics=metrics)
# input list of context
- ctx = [mx.gpu(0), mx.gpu(1)]
+ gpus = mx.context.num_gpus()
+ ctx = [mx.gpu(i) for i in gpus] if gpus > 0 else [mx.cpu()]
+ net = get_model()
est = Estimator(net=net,
loss=loss,
metrics=metrics,
diff --git a/tests/python/unittest/test_gluon_event_handler.py
b/tests/python/unittest/test_gluon_event_handler.py
index ccbcb54..023b046 100644
--- a/tests/python/unittest/test_gluon_event_handler.py
+++ b/tests/python/unittest/test_gluon_event_handler.py
@@ -30,7 +30,10 @@ def _get_test_network():
return net
def _get_test_data():
- return mx.io.NDArrayIter(data=nd.ones((32, 100)),
label=nd.random.randint(0, 10, (32, 1)))
+ data = nd.ones((32, 100))
+ label = nd.random.randint(0, 10, (32, 1))
+ data_arr = mx.gluon.data.dataset.ArrayDataset(data, label)
+ return mx.gluon.data.DataLoader(data_arr, batch_size=32)
def test_checkpoint_handler():