This is an automated email from the ASF dual-hosted git repository.
zhasheng pushed a commit to branch fit-api
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/fit-api by this push:
new 7e10355 [Fit API] improve event handlers (#14685)
7e10355 is described below
commit 7e10355a2adb349d1b5ff1f3c46f367a43ee1706
Author: Lai Wei <[email protected]>
AuthorDate: Thu Apr 18 22:27:52 2019 -0700
[Fit API] improve event handlers (#14685)
* improve event handlers
* update tests
* passing weakref of estimator
* fix unit test
* fix test
* fix pylint
* fix test
* fix pylint
* move default metric logic
* combine nightly tests
---
ci/docker/runtime_functions.sh | 18 +-
python/mxnet/gluon/contrib/estimator/estimator.py | 271 ++++++++---------
.../mxnet/gluon/contrib/estimator/event_handler.py | 337 ++++++++++++++-------
tests/nightly/Jenkinsfile | 12 +-
tests/nightly/JenkinsfileForBinaries | 16 -
tests/nightly/estimator/test_estimator_cnn.py | 2 +-
tests/nightly/estimator/test_sentiment_rnn.py | 6 +-
tests/python/unittest/test_gluon_estimator.py | 18 +-
tests/python/unittest/test_gluon_event_handler.py | 46 +--
9 files changed, 388 insertions(+), 338 deletions(-)
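
For context, a short hypothetical usage sketch (not part of this commit; the toy network and random data are placeholders) of how the reworked Estimator API is meant to be driven after this change. The Estimator, prepare_loss_and_metrics and handler names are taken from the diff that follows.

import mxnet as mx
from mxnet import gluon, nd
from mxnet.gluon import nn
from mxnet.gluon.contrib.estimator import estimator, event_handler

# toy network and data, illustrative only
net = nn.Dense(10, in_units=100)
ce_loss = gluon.loss.SoftmaxCrossEntropyLoss()
acc = mx.metric.Accuracy()
est = estimator.Estimator(net=net, loss=ce_loss, metrics=acc)

data = nd.random.uniform(shape=(32, 100))
label = nd.random.randint(0, 10, (32, 1))
loader = gluon.data.DataLoader(gluon.data.ArrayDataset(data, label), batch_size=8)

# With no event_handlers argument, fit() now installs MetricHandler,
# ValidationHandler and LoggingHandler itself and warns about the defaults.
est.fit(train_data=loader, val_data=loader, epochs=1)

# To customize, create the shared metric objects first and wire them into handlers.
train_metrics, val_metrics = est.prepare_loss_and_metrics()
handlers = [
    event_handler.MetricHandler(train_metrics=train_metrics),
    event_handler.ValidationHandler(val_data=loader, eval_fn=est.evaluate,
                                    val_metrics=val_metrics),
    event_handler.LoggingHandler(train_metrics=train_metrics, val_metrics=val_metrics),
    event_handler.EarlyStoppingHandler(monitor=train_metrics[0], patience=2),
]
est.fit(train_data=loader, event_handlers=handlers, epochs=2)
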
diff --git a/ci/docker/runtime_functions.sh b/ci/docker/runtime_functions.sh
index 59ff221..b194ebb 100755
--- a/ci/docker/runtime_functions.sh
+++ b/ci/docker/runtime_functions.sh
@@ -1296,31 +1296,19 @@ nightly_scala_demo_test_cpu() {
bash bin/run_im.sh
}
-nightly_estimator_cnn_gpu() {
+nightly_estimator_gpu() {
set -ex
cd /work/mxnet/tests/nightly/estimator
export PYTHONPATH=/work/mxnet/python/
python test_estimator_cnn.py --type gpu
-}
-
-nightly_estimator_cnn_cpu() {
- set -ex
- cd /work/mxnet/tests/nightly/estimator
- export PYTHONPATH=/work/mxnet/python/
- python test_estimator_cnn.py --type cpu
-}
-
-nightly_estimator_rnn_gpu() {
- set -ex
- cd /work/mxnet/tests/nightly/estimator
- export PYTHONPATH=/work/mxnet/python/
python test_sentiment_rnn.py --type gpu
}
-nightly_estimator_rnn_cpu() {
+nightly_estimator_cpu() {
set -ex
cd /work/mxnet/tests/nightly/estimator
export PYTHONPATH=/work/mxnet/python/
+ python test_estimator_cnn.py --type cpu
python test_sentiment_rnn.py --type cpu
}
diff --git a/python/mxnet/gluon/contrib/estimator/estimator.py b/python/mxnet/gluon/contrib/estimator/estimator.py
index f7c97c4..78672d2 100644
--- a/python/mxnet/gluon/contrib/estimator/estimator.py
+++ b/python/mxnet/gluon/contrib/estimator/estimator.py
@@ -16,12 +16,15 @@
# under the License.
# coding: utf-8
-# pylint: disable=wildcard-import
+# pylint: disable=wildcard-import, unused-variable
"""Gluon Estimator"""
import copy
import warnings
-from .event_handler import EventHandler, LoggingHandler
+import weakref
+
+from .event_handler import MetricHandler, ValidationHandler, LoggingHandler
+from .event_handler import TrainBegin, EpochBegin, BatchBegin, BatchEnd, EpochEnd, TrainEnd
from .... import gluon, autograd
from ....context import Context, cpu, gpu, num_gpus
from ....metric import EvalMetric, Loss, Accuracy
@@ -46,7 +49,7 @@ class Estimator(object):
trainer : Trainer
Trainer to apply optimizer on network parameters
context : Context or list of Context
- devices to run the training on
+ device(s) to run the training on
"""
def __init__(self, net,
@@ -57,46 +60,39 @@ class Estimator(object):
context=None):
self.net = net
+ self.loss = self._check_loss(loss)
+ self.train_metrics = self._check_metrics(metrics)
+
+ self.context = self._check_context(context)
+ self._initialize(initializer)
+ self.trainer = self._check_trainer(trainer)
+ def _check_loss(self, loss):
if isinstance(loss, gluon.loss.Loss):
- self.loss = [loss]
-        elif isinstance(loss, list) and all([isinstance(l, gluon.loss.Loss) for l in loss]):
-            self.loss = loss
+            loss = [loss]
+        elif isinstance(loss, list) or all([isinstance(l, gluon.loss.Loss) for l in loss]):
+ loss = loss
else:
raise ValueError("loss must be a Loss or a list of Loss, "
"refer to gluon.loss.Loss:{}".format(loss))
+ return loss
+ def _check_metrics(self, metrics):
if isinstance(metrics, EvalMetric):
- self.train_metrics = [metrics]
+ metrics = [metrics]
else:
- self.train_metrics = metrics or []
-        if not all([isinstance(metric, EvalMetric) for metric in self.train_metrics]):
+ metrics = metrics or []
+ if not all([isinstance(metric, EvalMetric) for metric in metrics]):
raise ValueError("metrics must be a Metric or a list of
Metric, "
"refer to
mxnet.metric.EvalMetric:{}".format(metrics))
+ return metrics
-        # Use default mx.metric.Accuracy() for gluon.loss.SoftmaxCrossEntropyLoss()
-        if not self.train_metrics and any([isinstance(l, gluon.loss.SoftmaxCrossEntropyLoss) for l in self.loss]):
- self.train_metrics = [Accuracy()]
-
- # Use same metrics for validation
- self.val_metrics = copy.deepcopy(self.train_metrics)
-
- # store training statistics
- self.train_stats = {}
-
- # separate train and validation
- self.train_loss_metrics = []
- self.val_loss_metrics = []
- # using the metric wrapper for loss to record loss value
- for l in self.loss:
- self.train_loss_metrics.append(Loss(l.name))
- self.val_loss_metrics.append(Loss(l.name))
-
+ def _check_context(self, context):
# handle context
if isinstance(context, Context):
- self.context = [context]
+ context = [context]
elif isinstance(context, list) and all([isinstance(c, Context) for c in context]):
- self.context = context
+ context = context
elif not context:
if num_gpus() > 0:
# only use 1 GPU by default
@@ -104,40 +100,41 @@ class Estimator(object):
warnings.warn("You have multiple GPUs, gpu(0) will be used
by default."
"To utilize all your GPUs, specify context
as a list of gpus, "
"e.g. context=[mx.gpu(0), mx.gpu(1)] ")
- self.context = [gpu(0)]
+ context = [gpu(0)]
else:
- self.context = [cpu()]
+ context = [cpu()]
else:
raise ValueError("context must be a Context or a list of Context, "
"refer to mxnet.Context:{}".format(context))
+ return context
+ def _initialize(self, initializer):
# initialize the network
- self.initializer = initializer
- if self.initializer:
+ if initializer:
if self._is_initialized():
# if already initialized, re-init with user specified initializer
warnings.warn("Network already initialized, re-initializing with %s. "
"You don't need to pass initializer if you already "
-                          "initialized your net." % type(self.initializer).__name__)
-        self.net.initialize(init=self.initializer, ctx=self.context, force_reinit=True)
+                          "initialized your net." % type(initializer).__name__)
+        self.net.initialize(init=initializer, ctx=self.context, force_reinit=True)
else:
# initialize with user specified initializer
-            self.net.initialize(init=self.initializer, ctx=self.context, force_reinit=False)
+            self.net.initialize(init=initializer, ctx=self.context, force_reinit=False)
else:
if not self._is_initialized():
self.net.initialize(ctx=self.context)
+ def _check_trainer(self, trainer):
# handle trainer
if not trainer:
warnings.warn("No trainer specified, default SGD optimizer "
"with learning rate 0.001 is used.")
- self.trainer = gluon.Trainer(self.net.collect_params(),
- 'sgd', {'learning_rate': 0.001})
+ trainer = gluon.Trainer(self.net.collect_params(),
+ 'sgd', {'learning_rate': 0.001})
elif not isinstance(trainer, gluon.Trainer):
raise ValueError("Trainer must be a Gluon Trainer instance, refer
to "
"gluon.Trainer:{}".format(trainer))
- else:
- self.trainer = trainer
+ return trainer
def _is_initialized(self):
param_dict = self.net.collect_params()
@@ -148,63 +145,70 @@ class Estimator(object):
return False
return True
- def _batch_fn(self, batch, ctx, is_iterator=False):
- if is_iterator:
- data = batch.data[0]
- label = batch.label[0]
- else:
- data = batch[0]
- label = batch[1]
+ def _get_data_and_label(self, batch, ctx):
+ data = batch[0]
+ label = batch[1]
data = gluon.utils.split_and_load(data, ctx_list=ctx, batch_axis=0)
label = gluon.utils.split_and_load(label, ctx_list=ctx, batch_axis=0)
return data, label
+ def prepare_loss_and_metrics(self):
+ """
+ Based on loss functions and training metrics in estimator
+ Create metric wrappers to record loss values,
+ Create copies of train loss/metric objects to record validation values
+ """
+ if any(not hasattr(self, attribute) for attribute in
+ ['train_metrics', 'val_metrics']):
+            # Use default mx.metric.Accuracy() for gluon.loss.SoftmaxCrossEntropyLoss()
+            if not self.train_metrics and any([isinstance(l, gluon.loss.SoftmaxCrossEntropyLoss) for l in self.loss]):
+ self.train_metrics = [Accuracy()]
+ self.val_metrics = []
+ for loss in self.loss:
+                self.train_metrics.append(Loss("Train " + ''.join([i for i in loss.name if not i.isdigit()])))
+                self.val_metrics.append(Loss("Validation " + ''.join([i for i in loss.name if not i.isdigit()])))
+ for metric in self.train_metrics:
+ val_metric = copy.deepcopy(metric)
+ metric.name = "Train " + metric.name
+ val_metric.name = "Validation " + val_metric.name
+ self.val_metrics.append(val_metric)
+ return self.train_metrics, self.val_metrics
+
def evaluate(self,
val_data,
- batch_fn=None):
+ val_metrics):
"""Evaluate model on validation data
Parameters
----------
val_data : DataLoader
validation data with data and labels
- batch_fn : function
- custom batch function to extract data and label
- from a data batch and load into contexts(devices)
+ val_metrics : EvalMetric or list of EvalMetrics
+ metrics to update validation result
"""
- for metric in self.val_metrics + self.val_loss_metrics:
+ for metric in val_metrics:
metric.reset()
for _, batch in enumerate(val_data):
- if not batch_fn:
- if isinstance(val_data, gluon.data.DataLoader):
- data, label = self._batch_fn(batch, self.context)
- else:
-                    raise ValueError("You are using a custom iteration, please also provide "
-                                     "batch_fn to extract data and label. Alternatively, you "
-                                     "can provide the data as gluon.data.DataLoader.")
- else:
- data, label = batch_fn(batch, self.context)
+ if not isinstance(val_data, gluon.data.DataLoader):
+                raise ValueError("Estimator only support input as Gluon DataLoader. Alternatively, you "
+                                 "can transform your DataIter or any NDArray into Gluon DataLoader. "
+ "Refer to gluon.data.dataloader")
+ data, label = self._get_data_and_label(batch, self.context)
pred = [self.net(x) for x in data]
- losses = []
- for loss in self.loss:
-                losses.append([loss(y_hat, y) for y_hat, y in zip(pred, label)])
+ loss = [self.loss[0](y_hat, y) for y_hat, y in zip(pred, label)]
# update metrics
- for metric in self.val_metrics:
- metric.update(label, pred)
- name, value = metric.get()
- self.train_stats['val_' + name] = value
- for loss, loss_metric, in zip(losses, self.val_loss_metrics):
- loss_metric.update(0, [l for l in loss])
- name, value = loss_metric.get()
- self.train_stats['val_' + name] = value
+ for metric in val_metrics:
+ if isinstance(metric, Loss):
+ metric.update(0, loss)
+ else:
+ metric.update(label, pred)
def fit(self, train_data,
val_data=None,
epochs=1,
- event_handlers=None,
- batch_fn=None):
+ event_handlers=None):
"""Trains the model on a given dataset for a specified
number of epochs. Also, the batch size is inferred from the
DataLoader's batch_size.
@@ -226,111 +230,72 @@ class Estimator(object):
custom batch function to extract data and label
from a data batch and load into contexts(devices)
"""
-
- self.max_epoch = epochs
- self.stop_training = False
- self.processed_samples = None
- self.batch_idx = 0
-
+ self.max_epochs = epochs
event_handlers = event_handlers or []
# provide default logging handler
- if not event_handlers or \
-            not any(isinstance(handler, LoggingHandler) for handler in event_handlers):
- event_handlers.append(LoggingHandler())
-            warnings.warn("No Event Handler specified, default `LoggingHandler()` "
-                          "is used with verbose=LoggingHandler.LOG_VERBOSITY_PER_EPOCH. "
-                          "Please look at gluon.estimator.event_handler for more detail.")
+ if not event_handlers:
+ train_metrics, val_metrics = self.prepare_loss_and_metrics()
+ event_handlers.append(MetricHandler(train_metrics=train_metrics))
+ if val_data:
+                event_handlers.append(ValidationHandler(val_data=val_data, eval_fn=self.evaluate,
+                                                        val_metrics=val_metrics))
+ event_handlers.append(LoggingHandler(train_metrics=train_metrics,
+ val_metrics=val_metrics))
+ warnings.warn("No Event Handler specified, default %s are used. "
+                          "Please look at gluon.contrib.estimator.event_handler for more detail." %
+                          ", ".join([handler.__class__.__name__ for handler in event_handlers]))
+
+        event_handlers.sort(key=lambda handler: getattr(handler, 'rank', 0), reverse=True)
train_begin, epoch_begin, batch_begin, \
batch_end, epoch_end, train_end = self._categorize_handlers(event_handlers)
-        # passing estimator to event handlers so they can access estimator information
- # when a event is triggered
- for handler in event_handlers:
- handler.estimator = self
-
+ # only pass a weak reference to all event handlers
+ estimator_ref = weakref.proxy(self)
# training begin
for handler in train_begin:
- handler.train_begin()
+ handler.train_begin(estimator_ref)
- for epoch in range(self.max_epoch):
+ for epoch in range(epochs):
# epoch begin
- self.current_epoch = epoch
- # Number of samples trained after every batch
- completed_samples = 0
-
for handler in epoch_begin:
- handler.epoch_begin()
-
- for metric in self.train_metrics + self.train_loss_metrics:
- metric.reset()
+ handler.epoch_begin(estimator_ref)
for i, batch in enumerate(train_data):
- if not batch_fn:
- if isinstance(train_data, gluon.data.DataLoader):
- data, label = self._batch_fn(batch, self.context)
- else:
-                    raise ValueError("You are using a custom iteration, please also provide "
-                                     "batch_fn to extract data and label. Alternatively, you "
-                                     "can provide the data as gluon.data.DataLoader")
- else:
- data, label = batch_fn(batch, self.context)
+ if not isinstance(train_data, gluon.data.DataLoader):
+                    raise ValueError("Estimator only support input as Gluon DataLoader. Alternatively, you "
+                                     "can transform your DataIter or any NDArray into Gluon DataLoader. "
+ "Refer to gluon.data.dataloader")
+ data, label = self._get_data_and_label(batch, self.context)
batch_size = batch[0].shape[0]
# batch begin
for handler in batch_begin:
- handler.batch_begin()
+ handler.batch_begin(estimator_ref, batch=batch)
with autograd.record():
pred = [self.net(x) for x in data]
- losses = []
- for loss in self.loss:
-                        losses.append([loss(y_hat, y) for y_hat, y in zip(pred, label)])
+                    loss = [self.loss[0](y_hat, y) for y_hat, y in zip(pred, label)]
- for loss in losses:
- for l in loss:
- l.backward()
-
- # update train metrics
- for metric in self.train_metrics:
- metric.update(label, pred)
- # get metric name and current value and update train stats
- name, value = metric.get()
- self.train_stats['train_' + name] = value
-
- # update loss
- for loss, loss_metric, in zip(losses, self.train_loss_metrics):
- loss_metric.update(0, [l for l in loss])
- name, value = loss_metric.get()
- self.train_stats['train_' + name] = value
-
- completed_samples += batch_size
-
- self.batch_idx = i
-                # record trained samples v.s. total samples if using Gluon DataLoader
-                if isinstance(train_data, gluon.data.DataLoader):
-                    self.processed_samples = "{}/{}".format(completed_samples,
-                                                            len(train_data._dataset))
+ for l in loss:
+ l.backward()
self.trainer.step(batch_size)
# batch end
for handler in batch_end:
- handler.batch_end()
-
- if val_data:
- self.evaluate(val_data, batch_fn)
+ if handler.batch_end(estimator_ref, batch=batch,
+ pred=pred, label=label, loss=loss):
+ break
# epoch end
for handler in epoch_end:
- handler.epoch_end()
-
- if self.stop_training:
- break
+ if handler.epoch_end(estimator_ref):
+ break
# train end
for handler in train_end:
- handler.train_end()
+ handler.train_end(estimator_ref)
def _categorize_handlers(self, event_handlers):
"""
@@ -346,16 +311,16 @@ class Estimator(object):
epoch_end = []
train_end = []
for handler in event_handlers:
- if not handler.__class__.train_begin == EventHandler.train_begin:
+ if isinstance(handler, TrainBegin):
train_begin.append(handler)
- if not handler.__class__.epoch_begin == EventHandler.epoch_begin:
+ if isinstance(handler, EpochBegin):
epoch_begin.append(handler)
- if not handler.__class__.batch_begin == EventHandler.batch_begin:
+ if isinstance(handler, BatchBegin):
batch_begin.append(handler)
- if not handler.__class__.batch_end == EventHandler.batch_end:
+ if isinstance(handler, BatchEnd):
batch_end.append(handler)
- if not handler.__class__.epoch_end == EventHandler.epoch_end:
+ if isinstance(handler, EpochEnd):
epoch_end.append(handler)
- if not handler.__class__.train_end == EventHandler.train_end:
+ if isinstance(handler, TrainEnd):
train_end.append(handler)
return train_begin, epoch_begin, batch_begin, batch_end, epoch_end, train_end
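
A minimal sketch of the isinstance-based dispatch introduced above, assuming the mixin classes land in event_handler.py as shown further below; the handler name here is made up. A handler only receives the events whose base classes it inherits, gets a weakref proxy of the estimator, and can stop training by returning True from batch_end or epoch_end.

from mxnet.gluon.contrib.estimator.event_handler import EpochEnd, TrainEnd

class StopAfterTwoEpochs(EpochEnd, TrainEnd):
    """Toy handler: _categorize_handlers puts it only in the epoch_end and
    train_end buckets, and a truthy return from epoch_end breaks the epoch loop."""

    def __init__(self):
        self.num_epochs = 0

    def epoch_end(self, estimator, *args, **kwargs):
        # estimator is a weakref.proxy of the Estimator, not a strong reference
        self.num_epochs += 1
        return self.num_epochs >= 2

    def train_end(self, estimator, *args, **kwargs):
        print("training stopped after %d epochs" % self.num_epochs)
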
diff --git a/python/mxnet/gluon/contrib/estimator/event_handler.py b/python/mxnet/gluon/contrib/estimator/event_handler.py
index 53c0bf5..220aa31 100644
--- a/python/mxnet/gluon/contrib/estimator/event_handler.py
+++ b/python/mxnet/gluon/contrib/estimator/event_handler.py
@@ -16,10 +16,9 @@
# under the License.
# coding: utf-8
-# pylint: disable=wildcard-import
+# pylint: disable=wildcard-import, unused-argument
"""Gluon EventHandlers for Estimators"""
-__all__ = ['EventHandler', 'LoggingHandler']
import logging
import os
import time
@@ -27,51 +26,130 @@ import warnings
import numpy as np
+from ....metric import EvalMetric, Loss
-class EventHandler(object):
- """Basic for event handlers
- :py:class:`EventHandler` can perform user defined functions at
- different stages of training: train begin, epoch begin, batch begin,
- batch end, epoch end, train end.
-
- Parameters
- ----------
- estimator : Estimator
- The :py:class:`Estimator` to get training statistics
- """
+class TrainBegin(object):
+ def train_begin(self, estimator, *args, **kwargs):
+ pass
- def __init__(self):
- self._estimator = None
- @property
- def estimator(self):
- return self._estimator
+class TrainEnd(object):
+ def train_end(self, estimator, *args, **kwargs):
+ pass
- @estimator.setter
- def estimator(self, estimator):
- self._estimator = estimator
- def train_begin(self):
+class EpochBegin(object):
+ def epoch_begin(self, estimator, *args, **kwargs):
pass
- def train_end(self):
- pass
- def batch_begin(self):
- pass
+class EpochEnd(object):
+ def epoch_end(self, estimator, *args, **kwargs):
+ return False
- def batch_end(self):
- pass
- def epoch_begin(self):
+class BatchBegin(object):
+ def batch_begin(self, estimator, *args, **kwargs):
pass
- def epoch_end(self):
- pass
+class BatchEnd(object):
+ def batch_end(self, estimator, *args, **kwargs):
+ return False
+
+
+class MetricHandler(EpochBegin, BatchEnd):
+ """Metric Handler that update metric values at batch end
+
+ :py:class:`MetricHandler` takes model predictions and true labels
+ and update the metrics, it also update metric wrapper for loss with loss
values
+ Validation loss and metrics will be handled by
:py:class:`ValidationHandler`
+
+ Parameters
+ ----------
+ train_metrics : List of EvalMetrics
+ training metrics to be updated at batch end
+ """
+
+ def __init__(self, train_metrics):
+ self.train_metrics = train_metrics or []
+ # order to be called among all callbacks
+ # metrics need to be calculated before other callbacks can access them
+ self.priority = -np.Inf
+
+ def epoch_begin(self, estimator, *args, **kwargs):
+ for metric in self.train_metrics:
+ metric.reset()
+
+ def batch_end(self, estimator, *args, **kwargs):
+ pred = kwargs['pred']
+ label = kwargs['label']
+ loss = kwargs['loss']
+ for metric in self.train_metrics:
+ if isinstance(metric, Loss):
+ # metric wrapper for loss values
+ metric.update(0, loss)
+ else:
+ metric.update(label, pred)
-class LoggingHandler(EventHandler):
+
+class ValidationHandler(BatchEnd, EpochEnd):
+ """"Validation Handler that evaluate model on validation dataset
+
+    :py:class:`ValidationHandler` takes validation dataset, an evaluation function,
+    metrics to be evaluated, and how often to run the validation. You can provide custom
+ evaluation function or use the one provided my :py:class:`Estimator`
+
+ Parameters
+ ----------
+ val_data : DataLoader
+ validation data set to run evaluation
+ eval_fn : function
+ a function defines how to run evaluation and
+ calculate loss and metrics
+ val_metrics : List of EvalMetrics
+ validation metrics to be updated
+ epoch_period : int, default 1
+ how often to run validation at epoch end, by default
+ validate every epoch
+ batch_period : int, default None
+ how often to run validation at batch end, by default
+ does not validate at batch end
+ """
+
+ def __init__(self,
+ val_data,
+ eval_fn,
+ val_metrics=None,
+ epoch_period=1,
+ batch_period=None):
+ self.val_data = val_data
+ self.eval_fn = eval_fn
+ self.epoch_period = epoch_period
+ self.batch_period = batch_period
+ self.val_metrics = val_metrics
+ self.num_batches = 0
+ self.num_epochs = 0
+ # order to be called among all callbacks
+        # validation metrics need to be calculated before other callbacks can access them
+ self.priority = -np.Inf
+
+ def batch_end(self, estimator, *args, **kwargs):
+ if self.batch_period and self.num_batches % self.batch_period == 0:
+ self.eval_fn(val_data=self.val_data,
+ val_metrics=self.val_metrics)
+ self.num_batches += 1
+
+ def epoch_end(self, estimator, *args, **kwargs):
+ if self.num_epochs % self.epoch_period == 0:
+ self.eval_fn(val_data=self.val_data,
+ val_metrics=self.val_metrics)
+
+ self.num_epochs += 1
+
+
+class LoggingHandler(TrainBegin, TrainEnd, EpochBegin, EpochEnd, BatchBegin, BatchEnd):
"""Basic Logging Handler that applies to every Gluon estimator by default.
:py:class:`LoggingHandler` logs hyper-parameters, training statistics,
@@ -79,22 +157,28 @@ class LoggingHandler(EventHandler):
Parameters
----------
- estimator : Estimator
- The :py:class:`Estimator` to get training statistics
file_name : str
file name to save the logs
- file_location: str
+ file_location : str
file location to save the logs
- verbose: int, default LOG_VERBOSITY_PER_EPOCH
+ verbose : int, default LOG_VERBOSITY_PER_EPOCH
Limit the granularity of metrics displayed during training process
verbose=LOG_VERBOSITY_PER_EPOCH: display metrics every epoch
verbose=LOG_VERBOSITY_PER_BATCH: display metrics every batch
+ train_metrics : list of EvalMetrics
+        training metrics to be logged, logged at batch end, epoch end, train end
+ val_metrics : list of EvalMetrics
+ validation metrics to be logged, logged at epoch end, train end
"""
LOG_VERBOSITY_PER_EPOCH = 1
LOG_VERBOSITY_PER_BATCH = 2
-    def __init__(self, file_name=None, file_location=None, verbose=LOG_VERBOSITY_PER_EPOCH):
+ def __init__(self, file_name=None,
+ file_location=None,
+ verbose=LOG_VERBOSITY_PER_EPOCH,
+ train_metrics=None,
+ val_metrics=None):
super(LoggingHandler, self).__init__()
self.logger = logging.getLogger(__name__)
self.logger.setLevel(logging.INFO)
@@ -112,71 +196,83 @@ class LoggingHandler(EventHandler):
file_location = file_location or './'
file_handler = logging.FileHandler(os.path.join(file_location, file_name))
self.logger.addHandler(file_handler)
-
- def train_begin(self):
+ self.train_metrics = train_metrics or []
+ self.val_metrics = val_metrics or []
+ self.batch_index = 0
+ self.current_epoch = 0
+ self.processed_samples = 0
+        # logging handler need to be called at last to make sure all states are updated
+ # it will also shut down logging at train end
+ self.priority = np.Inf
+
+ def train_begin(self, estimator, *args, **kwargs):
self.train_start = time.time()
+ trainer = estimator.trainer
+ optimizer = trainer.optimizer.__class__.__name__
+ lr = trainer.learning_rate
self.logger.info("Training begin: using optimizer %s "
"with current learning rate %.4f ",
- self.estimator.trainer.optimizer.__class__.__name__,
- self.estimator.trainer.learning_rate)
- self.logger.info("Train for %d epochs.", self.estimator.max_epoch)
+ optimizer, lr)
+ self.logger.info("Train for %d epochs.", estimator.max_epochs)
- def train_end(self):
+ def train_end(self, estimator, *args, **kwargs):
train_time = time.time() - self.train_start
- epoch = self.estimator.current_epoch
-        msg = 'Train finished using total %ds at epoch %d. ' % (train_time, epoch)
+        msg = 'Train finished using total %ds with %d epochs.' % (train_time, self.current_epoch)
# log every result in train stats including train/validation loss & metrics
- for key in self.estimator.train_stats:
- msg += '%s : %.4f ' % (key, self.estimator.train_stats[key])
+ for metric in self.train_metrics + self.val_metrics:
+ name, value = metric.get()
+ msg += '%s : %.4f ' % (name, value)
self.logger.info(msg)
+ for handler in self.logger.handlers:
+ handler.close()
+ self.logger.removeHandler(handler)
+ logging.shutdown()
- def batch_begin(self):
+ def batch_begin(self, estimator, *args, **kwargs):
if self.verbose == self.LOG_VERBOSITY_PER_BATCH:
self.batch_start = time.time()
- def batch_end(self):
+ def batch_end(self, estimator, *args, **kwargs):
if self.verbose == self.LOG_VERBOSITY_PER_BATCH:
batch_time = time.time() - self.batch_start
- epoch = self.estimator.current_epoch
- batch = self.estimator.batch_idx
- msg = '[Epoch %d] [Batch %d] ' % (epoch, batch)
- if self.estimator.processed_samples:
- msg += '[Samples %s] ' % (self.estimator.processed_samples)
+            msg = '[Epoch %d] [Batch %d] ' % (self.current_epoch, self.batch_index)
+ self.processed_samples += kwargs['batch'][0].shape[0]
+ msg += '[Samples %s] ' % (self.processed_samples)
msg += 'time/batch: %.3fs ' % batch_time
- for key in self.estimator.train_stats:
+ for metric in self.train_metrics:
# only log current training loss & metric after each batch
- if key.startswith('train_'):
-                    msg += key + ': ' + '%.4f ' % self.estimator.train_stats[key]
+ name, value = metric.get()
+ msg += '%s : %.4f ' % (name, value)
self.logger.info(msg)
+ self.batch_index += 1
- def epoch_begin(self):
+ def epoch_begin(self, estimator, *args, **kwargs):
if self.verbose >= self.LOG_VERBOSITY_PER_EPOCH:
self.epoch_start = time.time()
- def epoch_end(self):
+ def epoch_end(self, estimator, *args, **kwargs):
if self.verbose >= self.LOG_VERBOSITY_PER_EPOCH:
epoch_time = time.time() - self.epoch_start
- epoch = self.estimator.current_epoch
- msg = '\n[Epoch %d] finished in %.3fs: ' % (epoch, epoch_time)
-            # log every result in train stats including train/validation loss & metrics
-            for key in self.estimator.train_stats:
-                msg += '%s : %.4f ' % (key, self.estimator.train_stats[key])
+            msg = '\n[Epoch %d] finished in %.3fs: ' % (self.current_epoch, epoch_time)
+ for monitor in self.train_metrics + self.val_metrics:
+ name, value = monitor.get()
+ msg += '%s : %.4f ' % (name, value)
self.logger.info(msg)
+ self.current_epoch += 1
+ self.batch_index = 0
-class CheckpointHandler(EventHandler):
+class CheckpointHandler(BatchEnd, EpochEnd):
"""Save the model after every epoch.
:py:class:`CheckpointHandler` save the network parameters every epoch
Parameters
----------
- estimator : Estimator
- The :py:class:`Estimator` to get training statistics
filepath : str
file name to save the parameters, it can contain directories,
for example: ./saved_model/resnet.params
- monitor: str
+ monitor: EvalMetric
the metrics to monitor
verbose: int, default 0
verbosity mode
@@ -191,18 +287,23 @@ class CheckpointHandler(EventHandler):
def __init__(self,
filepath,
- monitor='val_accuracy',
+ monitor=None,
verbose=0,
save_best_only=False,
mode='auto',
- period=1):
- super(CheckpointHandler, self).__init__()
+ epoch_period=1,
+ batch_period=None):
self.monitor = monitor
self.verbose = verbose
self.filepath = filepath
self.save_best_only = save_best_only
- self.period = period
- self.epochs_since_last_save = 0
+ if self.save_best_only and not isinstance(self.monitor, EvalMetric):
+            raise ValueError("To save best model only, please provide one of the metric objects as monitor, "
+                             "You can create these objects using estimator.prepare_loss_and_metric()")
+ self.epoch_period = epoch_period
+ self.batch_period = batch_period
+ self.num_batches = 0
+ self.num_epochs = 0
self.logger = logging.getLogger(__name__)
if mode not in ['auto', 'min', 'max']:
@@ -219,55 +320,61 @@ class CheckpointHandler(EventHandler):
self.best = -np.Inf
else:
# use greater for accuracy and less otherwise
- if 'acc' in self.monitor:
+ if 'acc' in self.monitor.get()[0].lower():
self.monitor_op = np.greater
self.best = -np.Inf
else:
self.monitor_op = np.less
self.best = np.Inf
- def epoch_end(self, ):
- epoch = self.estimator.current_epoch
+ def batch_end(self, estimator, *args, **kwargs):
+ self._save_checkpoint(estimator.net, "Batch", self.num_batches)
+ self.num_batches += 1
+
+ def epoch_end(self, estimator, *args, **kwargs):
+ self._save_checkpoint(estimator.net, "Epoch", self.num_epochs)
+ self.num_epochs += 1
+
+ def _save_checkpoint(self, net, period_name, period_value):
# add extension for weights
if '.params' not in self.filepath:
self.filepath += '.params'
- self.epochs_since_last_save += 1
- if self.epochs_since_last_save >= self.period:
- self.epochs_since_last_save = 0
+ if self.num_epochs % self.epoch_period == 0:
if self.save_best_only:
+ monitor_name, monitor_value = self.monitor.get()
# check if monitor exists in train stats
- if self.monitor not in self.estimator.train_stats:
-                warnings.warn(RuntimeWarning('Unable to find %s in training statistics, make sure the monitor value'
-                                             'starts with `train_ `or `val_` and contains loss/metric name, ',
-                                             'for example val_accuracy', self.monitor))
- self.estimator.net.save_parameters(self.filepath)
+ if np.isnan(monitor_value):
+                warnings.warn(RuntimeWarning('%s is not updated, make sure you pass one of the metric objects'
+                                             'as monitor, you can use estimator.prepare_loss_and_metrics to'
+                                             'create all metric objects', monitor_name))
+ net.save_parameters(self.filepath)
else:
- current = self.estimator.train_stats[self.monitor]
- if self.monitor_op(current, self.best):
+ if self.monitor_op(monitor_value, self.best):
if self.verbose > 0:
-                        self.logger.info('\n[Epoch %d] %s improved from %0.5f to %0.5f,'
+                        self.logger.info('\n[%s %d] %s improved from %0.5f to %0.5f,'
' saving model to %s',
-                                         epoch, self.monitor, self.best, current, self.filepath)
- self.best = current
- self.estimator.net.save_parameters(self.filepath)
+                                         period_name, period_value, monitor_name,
+                                         self.best, monitor_value, self.filepath)
+ self.best = monitor_value
+ net.save_parameters(self.filepath)
else:
if self.verbose > 0:
-                        self.logger.info('\n[Epoch %d] %s did not improve from %0.5f, skipping save model',
-                                         epoch, self.monitor, self.best)
+                        self.logger.info('\n[%s %d] %s did not improve from %0.5f, skipping save model',
+                                         period_name, period_value, monitor_name, self.best)
else:
if self.verbose > 0:
-                    logging.info('\nEpoch %d: saving model to %s', epoch, self.filepath)
-                    self.estimator.net.save_parameters(self.filepath)
+                    logging.info('\n%s %d: saving model to %s', period_name, period_value, self.filepath)
+ net.save_parameters(self.filepath)
-class EarlyStoppingHandler(EventHandler):
+class EarlyStoppingHandler(TrainBegin, EpochEnd, TrainEnd):
"""Early stop training if monitored value is not improving
Parameters
----------
estimator : Estimator
The :py:class:`Estimator` to get training statistics
- monitor: str
+ monitor: EvalMetric
the metrics to monitor
min_delta: float, default 0
minimal change in monitored value to be considered as an improvement
@@ -281,19 +388,24 @@ class EarlyStoppingHandler(EventHandler):
"""
def __init__(self,
- monitor='val_accuracy',
+ monitor,
min_delta=0,
patience=0,
mode='auto',
baseline=None):
super(EarlyStoppingHandler, self).__init__()
+ if not isinstance(monitor, EvalMetric):
+            raise ValueError("Please provide one of the metric objects as monitor, "
+                             "You can create these objects using estimator.prepare_loss_and_metric()")
self.monitor = monitor
self.baseline = baseline
self.patience = patience
self.min_delta = min_delta
self.wait = 0
self.stopped_epoch = 0
+ self.num_epochs = 0
+ self.stop_training = False
self.logger = logging.getLogger(__name__)
if mode not in ['auto', 'min', 'max']:
@@ -306,7 +418,7 @@ class EarlyStoppingHandler(EventHandler):
elif mode == 'max':
self.monitor_op = np.greater
else:
- if 'acc' in self.monitor:
+ if 'acc' in self.monitor.get()[0].lower():
self.monitor_op = np.greater
else:
self.monitor_op = np.less
@@ -316,7 +428,7 @@ class EarlyStoppingHandler(EventHandler):
else:
self.min_delta *= -1
- def train_begin(self):
+ def train_begin(self, estimator, *args, **kwargs):
self.wait = 0
self.stopped_epoch = 0
if self.baseline is not None:
@@ -324,23 +436,24 @@ class EarlyStoppingHandler(EventHandler):
else:
self.best = np.Inf if self.monitor_op == np.less else -np.Inf
- def epoch_end(self):
- epoch = self.estimator.current_epoch
- if self.monitor not in self.estimator.train_stats:
- warnings.warn(RuntimeWarning('Unable to find %s in training
statistics, make sure the monitor value'
- 'starts with `train_ `or `val_` and
contains loss/metric name, ',
- 'for example val_accuracy',
self.monitor))
+ def epoch_end(self, estimator, *args, **kwargs):
+ monitor_name, monitor_value = self.monitor.get()
+ if np.isnan(monitor_value):
+            warnings.warn(RuntimeWarning('%s is not updated, make sure you pass one of the metric objects'
+                                         'as monitor, you can use estimator.prepare_loss_and_metrics to'
+                                         'create all metric objects', monitor_name))
else:
- current = self.estimator.train_stats[self.monitor]
- if self.monitor_op(current - self.min_delta, self.best):
- self.best = current
+ if self.monitor_op(monitor_value - self.min_delta, self.best):
+ self.best = monitor_value
self.wait = 0
else:
self.wait += 1
if self.wait >= self.patience:
- self.stopped_epoch = epoch
- self.estimator.stop_training = True
+ self.stopped_epoch = self.num_epochs
+ self.stop_training = True
+ return self.stop_training
- def train_end(self):
+ def train_end(self, estimator, *args, **kwargs):
if self.stopped_epoch > 0:
-            self.logger.info('Epoch %d: early stopping due to %s not improving', self.stopped_epoch, self.monitor)
+            self.logger.info('Epoch %d: early stopping due to %s not improving',
+ self.stopped_epoch, self.monitor.get()[0])
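
One more hedged sketch, reusing est, loader and the imports from the earlier example: CheckpointHandler and EarlyStoppingHandler now take one of the shared metric objects as monitor instead of a 'val_accuracy' style string key, so the monitored metric should come from prepare_loss_and_metrics() and be kept fresh by a ValidationHandler listed ahead of them.

train_metrics, val_metrics = est.prepare_loss_and_metrics()
# pick the copied "Validation accuracy" metric object to monitor
val_acc = [m for m in val_metrics if 'accuracy' in m.name.lower()][0]

handlers = [
    # listed first so the validation metrics are fresh when the monitors read them
    event_handler.ValidationHandler(val_data=loader, eval_fn=est.evaluate,
                                    val_metrics=val_metrics),
    event_handler.CheckpointHandler('best_model.params', monitor=val_acc,
                                    save_best_only=True, verbose=1),
    event_handler.EarlyStoppingHandler(monitor=val_acc, patience=3, mode='max'),
]
est.fit(train_data=loader, event_handlers=handlers, epochs=5)
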
diff --git a/tests/nightly/Jenkinsfile b/tests/nightly/Jenkinsfile
index a65da2d..1be084c 100755
--- a/tests/nightly/Jenkinsfile
+++ b/tests/nightly/Jenkinsfile
@@ -137,19 +137,19 @@ core_logic: {
}
}
},
- 'estimator: RNN GPU': {
+ 'Gluon estimator: GPU': {
node(NODE_LINUX_GPU) {
- ws('workspace/estimator-test-rnn-gpu') {
+ ws('workspace/estimator-test-gpu') {
utils.unpack_and_init('gpu', mx_lib)
-      utils.docker_run('ubuntu_nightly_gpu', 'nightly_estimator_test_rnn_gpu', true)
+ utils.docker_run('ubuntu_nightly_gpu', 'nightly_estimator_gpu', true)
}
}
},
- 'estimator: RNN CPU': {
+ 'Gluon estimator: CPU': {
node(NODE_LINUX_CPU) {
- ws('workspace/estimator-test-rnn-cpu') {
+ ws('workspace/estimator-test-cpu') {
utils.unpack_and_init('cpu', mx_lib)
-      utils.docker_run('ubuntu_nightly_cpu', 'nightly_estimator_test_rnn_cpu', false)
+      utils.docker_run('ubuntu_nightly_cpu', 'nightly_estimator_cpu', false)
}
}
}
diff --git a/tests/nightly/JenkinsfileForBinaries b/tests/nightly/JenkinsfileForBinaries
index 53572c8..53e1c30 100755
--- a/tests/nightly/JenkinsfileForBinaries
+++ b/tests/nightly/JenkinsfileForBinaries
@@ -106,22 +106,6 @@ core_logic: {
utils.docker_run('ubuntu_nightly_gpu', 'nightly_tutorial_test_ubuntu_python3_gpu', true, '1500m')
}
}
- },
- 'estimator: CNN GPU': {
- node(NODE_LINUX_GPU) {
- ws('workspace/estimator-test-cnn-gpu') {
- utils.unpack_and_init('gpu', mx_lib)
-        utils.docker_run('ubuntu_nightly_gpu', 'nightly_estimator_test_cnn_gpu', true)
- }
- }
- },
- 'estimator: CNN CPU': {
- node(NODE_LINUX_CPU) {
- ws('workspace/estimator-test-cnn-cpu') {
- utils.unpack_and_init('cpu', mx_lib)
-        utils.docker_run('ubuntu_nightly_cpu', 'nightly_estimator_test_cnn_cpu', true)
- }
- }
}
}
}
diff --git a/tests/nightly/estimator/test_estimator_cnn.py b/tests/nightly/estimator/test_estimator_cnn.py
index 7d0018b..c60dc54 100644
--- a/tests/nightly/estimator/test_estimator_cnn.py
+++ b/tests/nightly/estimator/test_estimator_cnn.py
@@ -137,7 +137,7 @@ def test_estimator_gpu():
val_data=test_data,
epochs=num_epochs)
- assert est.train_stats['train_'+acc.name] > 0.80
+ assert acc.get()[1] > 0.80
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='test gluon estimator')
diff --git a/tests/nightly/estimator/test_sentiment_rnn.py b/tests/nightly/estimator/test_sentiment_rnn.py
index 5fd93c1..404bf83 100644
--- a/tests/nightly/estimator/test_sentiment_rnn.py
+++ b/tests/nightly/estimator/test_sentiment_rnn.py
@@ -183,7 +183,7 @@ def run(net, train_dataloader, test_dataloader, **kwargs):
# Begin training
est.fit(train_data=train_dataloader, val_data=test_dataloader,
epochs=num_epochs)
- return est
+ return acc
def test_estimator_cpu(**kwargs):
@@ -250,9 +250,9 @@ def test_estimator_gpu(**kwargs):
net.embedding.weight.set_data(glove_embedding.idx_to_vec)
net.embedding.collect_params().setattr('grad_req', 'null')
- est = run(net, train_dataloader, test_dataloader, **kwargs)
+ acc = run(net, train_dataloader, test_dataloader, **kwargs)
- assert est.train_stats['train_accuracy'] > 0.70
+ assert acc.get()[1] > 0.70
parser = argparse.ArgumentParser(description='test gluon estimator')
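
Since est.train_stats is gone, results are read straight off the shared metric objects, which is what the updated assertions above do; a tiny illustration, reusing acc from the earlier sketch (the printed value is hypothetical):

name, value = acc.get()            # e.g. ('Train accuracy', 0.7312)
print('%s: %.4f' % (name, value))  # the nightly tests assert acc.get()[1] instead of est.train_stats
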
diff --git a/tests/python/unittest/test_gluon_estimator.py b/tests/python/unittest/test_gluon_estimator.py
index 13fcd96..6f19f43 100644
--- a/tests/python/unittest/test_gluon_estimator.py
+++ b/tests/python/unittest/test_gluon_estimator.py
@@ -19,12 +19,11 @@
import sys
import unittest
-import warnings
import mxnet as mx
from mxnet import gluon
from mxnet.gluon import nn
-from mxnet.gluon.contrib.estimator import Estimator, EventHandler
+from mxnet.gluon.contrib.estimator import *
from nose.tools import assert_raises
@@ -222,6 +221,7 @@ def test_metric():
loss=loss,
trainer=trainer,
context=ctx)
+ est.prepare_loss_and_metrics()
assert isinstance(est.train_metrics[0], mx.metric.Accuracy)
@@ -252,7 +252,7 @@ def test_context():
metrics=metrics)
# input list of context
gpus = mx.context.num_gpus()
- ctx = [mx.gpu(i) for i in gpus] if gpus > 0 else [mx.cpu()]
+ ctx = [mx.gpu(i) for i in range(gpus)] if gpus > 0 else [mx.cpu()]
net = get_model()
est = Estimator(net=net,
loss=loss,
@@ -267,16 +267,12 @@ def test_context():
def test_categorize_handlers():
- class CustomHandler1(EventHandler):
- def __init__(self):
- super(CustomHandler1, self).__init__()
+ class CustomHandler1(TrainBegin):
def train_begin(self):
print("custom train begin")
- class CustomHandler2(EventHandler):
- def __init__(self):
- super(CustomHandler2, self).__init__()
+ class CustomHandler2(EpochBegin, BatchBegin, TrainEnd):
def epoch_begin(self):
print("custom epoch begin")
@@ -287,9 +283,7 @@ def test_categorize_handlers():
def train_end(self):
print("custom train end")
- class CustomHandler3(EventHandler):
- def __init__(self):
- super(CustomHandler3, self).__init__()
+ class CustomHandler3(EpochBegin, BatchBegin, BatchEnd, TrainEnd):
def epoch_begin(self):
print("custom epoch begin")
diff --git a/tests/python/unittest/test_gluon_event_handler.py b/tests/python/unittest/test_gluon_event_handler.py
index dd2e60d..e151281 100644
--- a/tests/python/unittest/test_gluon_event_handler.py
+++ b/tests/python/unittest/test_gluon_event_handler.py
@@ -17,18 +17,21 @@
import os
import tempfile
+
import mxnet as mx
from mxnet import nd
from mxnet.gluon import nn, loss
from mxnet.gluon.contrib.estimator import estimator, event_handler
+from common import TemporaryDirectory
def _get_test_network():
net = nn.Sequential()
net.add(nn.Dense(128, activation='relu', in_units=100, flatten=False),
- nn.Dense(64, activation='relu', in_units=128),
- nn.Dense(10, activation='relu', in_units=64))
+ nn.Dense(64, activation='relu', in_units=128),
+ nn.Dense(10, activation='relu', in_units=64))
return net
+
def _get_test_data():
data = nd.ones((32, 100))
label = nd.random.randint(0, 10, (32, 1))
@@ -39,57 +42,60 @@ def _get_test_data():
def test_checkpoint_handler():
tmpdir = tempfile.mkdtemp()
file_path = os.path.join(tmpdir, "model.params")
- test_data = _get_test_data()
+ test_data = _get_test_data()
save_best_only = False
mode = 'auto'
net = _get_test_network()
ce_loss = loss.SoftmaxCrossEntropyLoss()
+ ce_loss_metric = mx.metric.Loss(ce_loss.name)
acc = mx.metric.Accuracy()
est = estimator.Estimator(net, loss=ce_loss, metrics=acc)
checkpoint_handler = [event_handler.CheckpointHandler(file_path,
+ monitor=acc,
save_best_only=save_best_only,
mode=mode)]
est.fit(test_data, event_handlers=checkpoint_handler, epochs=1)
assert os.path.isfile(file_path)
os.remove(file_path)
+
def test_early_stopping():
test_data = _get_test_data()
mode = 'max'
- monitor = 'train_accuracy'
patience = 0
net = _get_test_network()
ce_loss = loss.SoftmaxCrossEntropyLoss()
acc = mx.metric.Accuracy()
est = estimator.Estimator(net, loss=ce_loss, metrics=acc)
- early_stopping = [event_handler.EarlyStoppingHandler(monitor,
+ early_stopping = [event_handler.EarlyStoppingHandler(monitor=acc,
patience=patience,
mode=mode)]
est.fit(test_data, event_handlers=early_stopping, epochs=3)
mode = 'auto'
- monitor = 'train_accuracy'
patience = 2
- early_stopping = [event_handler.EarlyStoppingHandler(monitor,
+ early_stopping = [event_handler.EarlyStoppingHandler(monitor=acc,
patience=patience,
- mode=mode)]
+ mode=mode)]
est.fit(test_data, event_handlers=early_stopping, epochs=1)
+
def test_logging():
- tmpdir = tempfile.mkdtemp()
- test_data = _get_test_data()
- file_name = 'test_log'
- output_dir = os.path.join(tmpdir, file_name)
+ with TemporaryDirectory() as tmpdir:
+ test_data = _get_test_data()
+ file_name = 'test_log'
+ output_dir = os.path.join(tmpdir, file_name)
- net = _get_test_network()
- ce_loss = loss.SoftmaxCrossEntropyLoss()
- acc = mx.metric.Accuracy()
- est = estimator.Estimator(net, loss=ce_loss, metrics=acc)
-    logging_handler = [event_handler.LoggingHandler(file_name=file_name, file_location=tmpdir)]
- est.fit(test_data, event_handlers=logging_handler, epochs=1)
- assert os.path.isfile(output_dir)
- os.remove(output_dir)
\ No newline at end of file
+ net = _get_test_network()
+ ce_loss = loss.SoftmaxCrossEntropyLoss()
+ ce_loss_metric = mx.metric.Loss(ce_loss.name)
+ acc = mx.metric.Accuracy()
+ est = estimator.Estimator(net, loss=ce_loss, metrics=acc)
+ logging_handler = [event_handler.LoggingHandler(file_name=file_name,
+                                                        file_location=tmpdir, train_metrics=[acc, ce_loss_metric])]
+ est.fit(test_data, event_handlers=logging_handler, epochs=1)
+ assert os.path.isfile(output_dir)
\ No newline at end of file