This is an automated email from the ASF dual-hosted git repository.

zhasheng pushed a commit to branch fit-api
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/fit-api by this push:
     new 7e10355  [Fit API] improve event handlers (#14685)
7e10355 is described below

commit 7e10355a2adb349d1b5ff1f3c46f367a43ee1706
Author: Lai Wei <[email protected]>
AuthorDate: Thu Apr 18 22:27:52 2019 -0700

    [Fit API] improve event handlers (#14685)
    
    * improve event handlers
    
    * update tests
    
    * passing weakref of estimator
    
    * fix unit test
    
    * fix test
    
    * fix pylint
    
    * fix test
    
    * fix pylint
    
    * move default metric logic
    
    * combine nightly tests
---
 ci/docker/runtime_functions.sh                     |  18 +-
 python/mxnet/gluon/contrib/estimator/estimator.py  | 271 ++++++++---------
 .../mxnet/gluon/contrib/estimator/event_handler.py | 337 ++++++++++++++-------
 tests/nightly/Jenkinsfile                          |  12 +-
 tests/nightly/JenkinsfileForBinaries               |  16 -
 tests/nightly/estimator/test_estimator_cnn.py      |   2 +-
 tests/nightly/estimator/test_sentiment_rnn.py      |   6 +-
 tests/python/unittest/test_gluon_estimator.py      |  18 +-
 tests/python/unittest/test_gluon_event_handler.py  |  46 +--
 9 files changed, 388 insertions(+), 338 deletions(-)

diff --git a/ci/docker/runtime_functions.sh b/ci/docker/runtime_functions.sh
index 59ff221..b194ebb 100755
--- a/ci/docker/runtime_functions.sh
+++ b/ci/docker/runtime_functions.sh
@@ -1296,31 +1296,19 @@ nightly_scala_demo_test_cpu() {
     bash bin/run_im.sh
 }
 
-nightly_estimator_cnn_gpu() {
+nightly_estimator_gpu() {
     set -ex
     cd /work/mxnet/tests/nightly/estimator
     export PYTHONPATH=/work/mxnet/python/
     python test_estimator_cnn.py --type gpu
-}
-
-nightly_estimator_cnn_cpu() {
-    set -ex
-    cd /work/mxnet/tests/nightly/estimator
-    export PYTHONPATH=/work/mxnet/python/
-    python test_estimator_cnn.py --type cpu
-}
-
-nightly_estimator_rnn_gpu() {
-    set -ex
-    cd /work/mxnet/tests/nightly/estimator
-    export PYTHONPATH=/work/mxnet/python/
     python test_sentiment_rnn.py --type gpu
 }
 
-nightly_estimator_rnn_cpu() {
+nightly_estimator_cpu() {
     set -ex
     cd /work/mxnet/tests/nightly/estimator
     export PYTHONPATH=/work/mxnet/python/
+    python test_estimator_cnn.py --type cpu
     python test_sentiment_rnn.py --type cpu
 }
 
diff --git a/python/mxnet/gluon/contrib/estimator/estimator.py b/python/mxnet/gluon/contrib/estimator/estimator.py
index f7c97c4..78672d2 100644
--- a/python/mxnet/gluon/contrib/estimator/estimator.py
+++ b/python/mxnet/gluon/contrib/estimator/estimator.py
@@ -16,12 +16,15 @@
 # under the License.
 
 # coding: utf-8
-# pylint: disable=wildcard-import
+# pylint: disable=wildcard-import, unused-variable
 """Gluon Estimator"""
 
 import copy
 import warnings
-from .event_handler import EventHandler, LoggingHandler
+import weakref
+
+from .event_handler import MetricHandler, ValidationHandler, LoggingHandler
+from .event_handler import TrainBegin, EpochBegin, BatchBegin, BatchEnd, EpochEnd, TrainEnd
 from .... import gluon, autograd
 from ....context import Context, cpu, gpu, num_gpus
 from ....metric import EvalMetric, Loss, Accuracy
@@ -46,7 +49,7 @@ class Estimator(object):
     trainer : Trainer
         Trainer to apply optimizer on network parameters
     context : Context or list of Context
-        devices to run the training on
+        device(s) to run the training on
     """
 
     def __init__(self, net,
@@ -57,46 +60,39 @@ class Estimator(object):
                  context=None):
 
         self.net = net
+        self.loss = self._check_loss(loss)
+        self.train_metrics = self._check_metrics(metrics)
+
+        self.context = self._check_context(context)
+        self._initialize(initializer)
+        self.trainer = self._check_trainer(trainer)
 
+    def _check_loss(self, loss):
         if isinstance(loss, gluon.loss.Loss):
-            self.loss = [loss]
-        elif isinstance(loss, list) and all([isinstance(l, gluon.loss.Loss) for l in loss]):
-            self.loss = loss
+            loss = [loss]
+        elif isinstance(loss, list) or all([isinstance(l, gluon.loss.Loss) for l in loss]):
+            loss = loss
         else:
             raise ValueError("loss must be a Loss or a list of Loss, "
                              "refer to gluon.loss.Loss:{}".format(loss))
+        return loss
 
+    def _check_metrics(self, metrics):
         if isinstance(metrics, EvalMetric):
-            self.train_metrics = [metrics]
+            metrics = [metrics]
         else:
-            self.train_metrics = metrics or []
-            if not all([isinstance(metric, EvalMetric) for metric in self.train_metrics]):
+            metrics = metrics or []
+            if not all([isinstance(metric, EvalMetric) for metric in metrics]):
                 raise ValueError("metrics must be a Metric or a list of Metric, "
                                  "refer to mxnet.metric.EvalMetric:{}".format(metrics))
+        return metrics
 
-        # Use default mx.metric.Accuracy() for gluon.loss.SoftmaxCrossEntropyLoss()
-        if not self.train_metrics and any([isinstance(l, gluon.loss.SoftmaxCrossEntropyLoss) for l in self.loss]):
-            self.train_metrics = [Accuracy()]
-
-        # Use same metrics for validation
-        self.val_metrics = copy.deepcopy(self.train_metrics)
-
-        # store training statistics
-        self.train_stats = {}
-
-        # separate train and validation
-        self.train_loss_metrics = []
-        self.val_loss_metrics = []
-        # using the metric wrapper for loss to record loss value
-        for l in self.loss:
-            self.train_loss_metrics.append(Loss(l.name))
-            self.val_loss_metrics.append(Loss(l.name))
-
+    def _check_context(self, context):
         # handle context
         if isinstance(context, Context):
-            self.context = [context]
+            context = [context]
         elif isinstance(context, list) and all([isinstance(c, Context) for c in context]):
-            self.context = context
+            context = context
         elif not context:
             if num_gpus() > 0:
                 # only use 1 GPU by default
@@ -104,40 +100,41 @@ class Estimator(object):
                     warnings.warn("You have multiple GPUs, gpu(0) will be used by default."
                                   "To utilize all your GPUs, specify context as a list of gpus, "
                                   "e.g. context=[mx.gpu(0), mx.gpu(1)] ")
-                self.context = [gpu(0)]
+                context = [gpu(0)]
             else:
-                self.context = [cpu()]
+                context = [cpu()]
         else:
             raise ValueError("context must be a Context or a list of Context, "
                              "refer to mxnet.Context:{}".format(context))
+        return context
 
+    def _initialize(self, initializer):
         # initialize the network
-        self.initializer = initializer
-        if self.initializer:
+        if initializer:
             if self._is_initialized():
                 # if already initialized, re-init with user specified initializer
                 warnings.warn("Network already initialized, re-initializing with %s. "
                               "You don't need to pass initializer if you already "
-                              "initialized your net." % type(self.initializer).__name__)
-                self.net.initialize(init=self.initializer, ctx=self.context, force_reinit=True)
+                              "initialized your net." % type(initializer).__name__)
+                self.net.initialize(init=initializer, ctx=self.context, force_reinit=True)
             else:
                 # initialize with user specified initializer
-                self.net.initialize(init=self.initializer, ctx=self.context, force_reinit=False)
+                self.net.initialize(init=initializer, ctx=self.context, force_reinit=False)
         else:
             if not self._is_initialized():
                 self.net.initialize(ctx=self.context)
 
+    def _check_trainer(self, trainer):
         # handle trainer
         if not trainer:
             warnings.warn("No trainer specified, default SGD optimizer "
                           "with learning rate 0.001 is used.")
-            self.trainer = gluon.Trainer(self.net.collect_params(),
-                                         'sgd', {'learning_rate': 0.001})
+            trainer = gluon.Trainer(self.net.collect_params(),
+                                    'sgd', {'learning_rate': 0.001})
         elif not isinstance(trainer, gluon.Trainer):
             raise ValueError("Trainer must be a Gluon Trainer instance, refer to "
                              "gluon.Trainer:{}".format(trainer))
-        else:
-            self.trainer = trainer
+        return trainer
 
     def _is_initialized(self):
         param_dict = self.net.collect_params()
@@ -148,63 +145,70 @@ class Estimator(object):
                 return False
         return True
 
-    def _batch_fn(self, batch, ctx, is_iterator=False):
-        if is_iterator:
-            data = batch.data[0]
-            label = batch.label[0]
-        else:
-            data = batch[0]
-            label = batch[1]
+    def _get_data_and_label(self, batch, ctx):
+        data = batch[0]
+        label = batch[1]
         data = gluon.utils.split_and_load(data, ctx_list=ctx, batch_axis=0)
         label = gluon.utils.split_and_load(label, ctx_list=ctx, batch_axis=0)
         return data, label
 
+    def prepare_loss_and_metrics(self):
+        """
+        Based on loss functions and training metrics in estimator
+        Create metric wrappers to record loss values,
+        Create copies of train loss/metric objects to record validation values
+        """
+        if any(not hasattr(self, attribute) for attribute in
+               ['train_metrics', 'val_metrics']):
+            # Use default mx.metric.Accuracy() for gluon.loss.SoftmaxCrossEntropyLoss()
+            if not self.train_metrics and any([isinstance(l, gluon.loss.SoftmaxCrossEntropyLoss) for l in self.loss]):
+                self.train_metrics = [Accuracy()]
+            self.val_metrics = []
+            for loss in self.loss:
+                self.train_metrics.append(Loss("Train " + ''.join([i for i in loss.name if not i.isdigit()])))
+                self.val_metrics.append(Loss("Validation " + ''.join([i for i in loss.name if not i.isdigit()])))
+            for metric in self.train_metrics:
+                val_metric = copy.deepcopy(metric)
+                metric.name = "Train " + metric.name
+                val_metric.name = "Validation " + val_metric.name
+                self.val_metrics.append(val_metric)
+        return self.train_metrics, self.val_metrics
+
     def evaluate(self,
                  val_data,
-                 batch_fn=None):
+                 val_metrics):
         """Evaluate model on validation data
 
          Parameters
          ----------
          val_data : DataLoader
              validation data with data and labels
-         batch_fn : function
-             custom batch function to extract data and label
-             from a data batch and load into contexts(devices)
+         val_metrics : EvalMetric or list of EvalMetrics
+             metrics to update validation result
          """
 
-        for metric in self.val_metrics + self.val_loss_metrics:
+        for metric in val_metrics:
             metric.reset()
 
         for _, batch in enumerate(val_data):
-            if not batch_fn:
-                if isinstance(val_data, gluon.data.DataLoader):
-                    data, label = self._batch_fn(batch, self.context)
-                else:
-                    raise ValueError("You are using a custom iteration, please also provide "
-                                     "batch_fn to extract data and label. Alternatively, you "
-                                     "can provide the data as gluon.data.DataLoader.")
-            else:
-                data, label = batch_fn(batch, self.context)
+            if not isinstance(val_data, gluon.data.DataLoader):
+                raise ValueError("Estimator only support input as Gluon DataLoader. Alternatively, you "
+                                 "can transform your DataIter or any NDArray into Gluon DataLoader. "
+                                 "Refer to gluon.data.dataloader")
+            data, label = self._get_data_and_label(batch, self.context)
             pred = [self.net(x) for x in data]
-            losses = []
-            for loss in self.loss:
-                losses.append([loss(y_hat, y) for y_hat, y in zip(pred, label)])
+            loss = [self.loss[0](y_hat, y) for y_hat, y in zip(pred, label)]
             # update metrics
-            for metric in self.val_metrics:
-                metric.update(label, pred)
-                name, value = metric.get()
-                self.train_stats['val_' + name] = value
-            for loss, loss_metric, in zip(losses, self.val_loss_metrics):
-                loss_metric.update(0, [l for l in loss])
-                name, value = loss_metric.get()
-                self.train_stats['val_' + name] = value
+            for metric in val_metrics:
+                if isinstance(metric, Loss):
+                    metric.update(0, loss)
+                else:
+                    metric.update(label, pred)
 
     def fit(self, train_data,
             val_data=None,
             epochs=1,
-            event_handlers=None,
-            batch_fn=None):
+            event_handlers=None):
         """Trains the model on a given dataset for a specified
         number of epochs. Also, the batch size is inferred from the
         DataLoader's batch_size.
@@ -226,111 +230,72 @@ class Estimator(object):
             custom batch function to extract data and label
             from a data batch and load into contexts(devices)
         """
-
-        self.max_epoch = epochs
-        self.stop_training = False
-        self.processed_samples = None
-        self.batch_idx = 0
-
+        self.max_epochs = epochs
         event_handlers = event_handlers or []
         # provide default logging handler
-        if not event_handlers or \
-                not any(isinstance(handler, LoggingHandler) for handler in event_handlers):
-            event_handlers.append(LoggingHandler())
-            warnings.warn("No Event Handler specified, default `LoggingHandler()` "
-                          "is used with verbose=LoggingHandler.LOG_VERBOSITY_PER_EPOCH. "
-                          "Please look at gluon.estimator.event_handler for more detail.")
+        if not event_handlers:
+            train_metrics, val_metrics = self.prepare_loss_and_metrics()
+            event_handlers.append(MetricHandler(train_metrics=train_metrics))
+            if val_data:
+                event_handlers.append(ValidationHandler(val_data=val_data, eval_fn=self.evaluate,
+                                                        val_metrics=val_metrics))
+            event_handlers.append(LoggingHandler(train_metrics=train_metrics,
+                                                 val_metrics=val_metrics))
+            warnings.warn("No Event Handler specified, default %s are used. "
+                          "Please look at gluon.contrib.estimator.event_handler for more detail." %
+                          ", ".join([handler.__class__.__name__ for handler in event_handlers]))
+
+        event_handlers.sort(key=lambda handler: getattr(handler, 'rank', 0), reverse=True)
 
         train_begin, epoch_begin, batch_begin, \
         batch_end, epoch_end, train_end = self._categorize_handlers(event_handlers)
 
-        # passing estimator to event handlers so they can access estimator information
-        # when a event is triggered
-        for handler in event_handlers:
-            handler.estimator = self
-
+        # only pass a weak reference to all event handlers
+        estimator_ref = weakref.proxy(self)
         # training begin
         for handler in train_begin:
-            handler.train_begin()
+            handler.train_begin(estimator_ref)
 
-        for epoch in range(self.max_epoch):
+        for epoch in range(epochs):
             # epoch begin
-            self.current_epoch = epoch
-            # Number of samples trained after every batch
-            completed_samples = 0
-
             for handler in epoch_begin:
-                handler.epoch_begin()
-
-            for metric in self.train_metrics + self.train_loss_metrics:
-                metric.reset()
+                handler.epoch_begin(estimator_ref)
 
             for i, batch in enumerate(train_data):
-                if not batch_fn:
-                    if isinstance(train_data, gluon.data.DataLoader):
-                        data, label = self._batch_fn(batch, self.context)
-                    else:
-                        raise ValueError("You are using a custom iteration, please also provide "
-                                         "batch_fn to extract data and label. Alternatively, you "
-                                         "can provide the data as gluon.data.DataLoader")
-                else:
-                    data, label = batch_fn(batch, self.context)
+                if not isinstance(train_data, gluon.data.DataLoader):
+                    raise ValueError("Estimator only support input as Gluon DataLoader. Alternatively, you "
+                                     "can transform your DataIter or any NDArray into Gluon DataLoader. "
+                                     "Refer to gluon.data.dataloader")
+                data, label = self._get_data_and_label(batch, self.context)
 
                 batch_size = batch[0].shape[0]
 
                 # batch begin
                 for handler in batch_begin:
-                    handler.batch_begin()
+                    handler.batch_begin(estimator_ref, batch=batch)
 
                 with autograd.record():
                     pred = [self.net(x) for x in data]
-                    losses = []
-                    for loss in self.loss:
-                        losses.append([loss(y_hat, y) for y_hat, y in zip(pred, label)])
+                    loss = [self.loss[0](y_hat, y) for y_hat, y in zip(pred, label)]
 
-                for loss in losses:
-                    for l in loss:
-                        l.backward()
-
-                # update train metrics
-                for metric in self.train_metrics:
-                    metric.update(label, pred)
-                    # get metric name and current value and update train stats
-                    name, value = metric.get()
-                    self.train_stats['train_' + name] = value
-
-                # update loss
-                for loss, loss_metric, in zip(losses, self.train_loss_metrics):
-                    loss_metric.update(0, [l for l in loss])
-                    name, value = loss_metric.get()
-                    self.train_stats['train_' + name] = value
-
-                completed_samples += batch_size
-
-                self.batch_idx = i
-                # record trained samples v.s. total samples if using Gluon DataLoader
-                if isinstance(train_data, gluon.data.DataLoader):
-                    self.processed_samples = "{}/{}".format(completed_samples,
-                                                            
len(train_data._dataset))
+                for l in loss:
+                    l.backward()
 
                 self.trainer.step(batch_size)
                 # batch end
                 for handler in batch_end:
-                    handler.batch_end()
-
-            if val_data:
-                self.evaluate(val_data, batch_fn)
+                    if handler.batch_end(estimator_ref, batch=batch,
+                                         pred=pred, label=label, loss=loss):
+                        break
 
             # epoch end
             for handler in epoch_end:
-                handler.epoch_end()
-
-            if self.stop_training:
-                break
+                if handler.epoch_end(estimator_ref):
+                    break
 
         # train end
         for handler in train_end:
-            handler.train_end()
+            handler.train_end(estimator_ref)
 
     def _categorize_handlers(self, event_handlers):
         """
@@ -346,16 +311,16 @@ class Estimator(object):
         epoch_end = []
         train_end = []
         for handler in event_handlers:
-            if not handler.__class__.train_begin == EventHandler.train_begin:
+            if isinstance(handler, TrainBegin):
                 train_begin.append(handler)
-            if not handler.__class__.epoch_begin == EventHandler.epoch_begin:
+            if isinstance(handler, EpochBegin):
                 epoch_begin.append(handler)
-            if not handler.__class__.batch_begin == EventHandler.batch_begin:
+            if isinstance(handler, BatchBegin):
                 batch_begin.append(handler)
-            if not handler.__class__.batch_end == EventHandler.batch_end:
+            if isinstance(handler, BatchEnd):
                 batch_end.append(handler)
-            if not handler.__class__.epoch_end == EventHandler.epoch_end:
+            if isinstance(handler, EpochEnd):
                 epoch_end.append(handler)
-            if not handler.__class__.train_end == EventHandler.train_end:
+            if isinstance(handler, TrainEnd):
                 train_end.append(handler)
         return train_begin, epoch_begin, batch_begin, batch_end, epoch_end, train_end
diff --git a/python/mxnet/gluon/contrib/estimator/event_handler.py b/python/mxnet/gluon/contrib/estimator/event_handler.py
index 53c0bf5..220aa31 100644
--- a/python/mxnet/gluon/contrib/estimator/event_handler.py
+++ b/python/mxnet/gluon/contrib/estimator/event_handler.py
@@ -16,10 +16,9 @@
 # under the License.
 
 # coding: utf-8
-# pylint: disable=wildcard-import
+# pylint: disable=wildcard-import, unused-argument
 """Gluon EventHandlers for Estimators"""
 
-__all__ = ['EventHandler', 'LoggingHandler']
 import logging
 import os
 import time
@@ -27,51 +26,130 @@ import warnings
 
 import numpy as np
 
+from ....metric import EvalMetric, Loss
 
-class EventHandler(object):
-    """Basic for event handlers
 
-        :py:class:`EventHandler` can perform user defined functions at
-        different stages of training: train begin, epoch begin, batch begin,
-        batch end, epoch end, train end.
-
-        Parameters
-        ----------
-        estimator : Estimator
-            The :py:class:`Estimator` to get training statistics
-        """
+class TrainBegin(object):
+    def train_begin(self, estimator, *args, **kwargs):
+        pass
 
-    def __init__(self):
-        self._estimator = None
 
-    @property
-    def estimator(self):
-        return self._estimator
+class TrainEnd(object):
+    def train_end(self, estimator, *args, **kwargs):
+        pass
 
-    @estimator.setter
-    def estimator(self, estimator):
-        self._estimator = estimator
 
-    def train_begin(self):
+class EpochBegin(object):
+    def epoch_begin(self, estimator, *args, **kwargs):
         pass
 
-    def train_end(self):
-        pass
 
-    def batch_begin(self):
-        pass
+class EpochEnd(object):
+    def epoch_end(self, estimator, *args, **kwargs):
+        return False
 
-    def batch_end(self):
-        pass
 
-    def epoch_begin(self):
+class BatchBegin(object):
+    def batch_begin(self, estimator, *args, **kwargs):
         pass
 
-    def epoch_end(self):
-        pass
 
+class BatchEnd(object):
+    def batch_end(self, estimator, *args, **kwargs):
+        return False
+
+
+class MetricHandler(EpochBegin, BatchEnd):
+    """Metric Handler that update metric values at batch end
+
+    :py:class:`MetricHandler` takes model predictions and true labels
+    and update the metrics, it also update metric wrapper for loss with loss values
+    Validation loss and metrics will be handled by :py:class:`ValidationHandler`
+
+    Parameters
+    ----------
+    train_metrics : List of EvalMetrics
+        training metrics to be updated at batch end
+    """
+
+    def __init__(self, train_metrics):
+        self.train_metrics = train_metrics or []
+        # order to be called among all callbacks
+        # metrics need to be calculated before other callbacks can access them
+        self.priority = -np.Inf
+
+    def epoch_begin(self, estimator, *args, **kwargs):
+        for metric in self.train_metrics:
+            metric.reset()
+
+    def batch_end(self, estimator, *args, **kwargs):
+        pred = kwargs['pred']
+        label = kwargs['label']
+        loss = kwargs['loss']
+        for metric in self.train_metrics:
+            if isinstance(metric, Loss):
+                # metric wrapper for loss values
+                metric.update(0, loss)
+            else:
+                metric.update(label, pred)
 
-class LoggingHandler(EventHandler):
+
+class ValidationHandler(BatchEnd, EpochEnd):
+    """"Validation Handler that evaluate model on validation dataset
+
+    :py:class:`ValidationHandler` takes validation dataset, an evaluation function,
+    metrics to be evaluated, and how often to run the validation. You can provide custom
+    evaluation function or use the one provided my :py:class:`Estimator`
+
+    Parameters
+    ----------
+    val_data : DataLoader
+        validation data set to run evaluation
+    eval_fn : function
+        a function defines how to run evaluation and
+        calculate loss and metrics
+    val_metrics : List of EvalMetrics
+        validation metrics to be updated
+    epoch_period : int, default 1
+        how often to run validation at epoch end, by default
+        validate every epoch
+    batch_period : int, default None
+        how often to run validation at batch end, by default
+        does not validate at batch end
+    """
+
+    def __init__(self,
+                 val_data,
+                 eval_fn,
+                 val_metrics=None,
+                 epoch_period=1,
+                 batch_period=None):
+        self.val_data = val_data
+        self.eval_fn = eval_fn
+        self.epoch_period = epoch_period
+        self.batch_period = batch_period
+        self.val_metrics = val_metrics
+        self.num_batches = 0
+        self.num_epochs = 0
+        # order to be called among all callbacks
+        # validation metrics need to be calculated before other callbacks can access them
+        self.priority = -np.Inf
+
+    def batch_end(self, estimator, *args, **kwargs):
+        if self.batch_period and self.num_batches % self.batch_period == 0:
+            self.eval_fn(val_data=self.val_data,
+                         val_metrics=self.val_metrics)
+        self.num_batches += 1
+
+    def epoch_end(self, estimator, *args, **kwargs):
+        if self.num_epochs % self.epoch_period == 0:
+            self.eval_fn(val_data=self.val_data,
+                         val_metrics=self.val_metrics)
+
+        self.num_epochs += 1
+
+
+class LoggingHandler(TrainBegin, TrainEnd, EpochBegin, EpochEnd, BatchBegin, BatchEnd):
     """Basic Logging Handler that applies to every Gluon estimator by default.
 
     :py:class:`LoggingHandler` logs hyper-parameters, training statistics,
@@ -79,22 +157,28 @@ class LoggingHandler(EventHandler):
 
     Parameters
     ----------
-    estimator : Estimator
-        The :py:class:`Estimator` to get training statistics
     file_name : str
         file name to save the logs
-    file_location: str
+    file_location : str
         file location to save the logs
-    verbose: int, default LOG_VERBOSITY_PER_EPOCH
+    verbose : int, default LOG_VERBOSITY_PER_EPOCH
         Limit the granularity of metrics displayed during training process
         verbose=LOG_VERBOSITY_PER_EPOCH: display metrics every epoch
         verbose=LOG_VERBOSITY_PER_BATCH: display metrics every batch
+    train_metrics : list of EvalMetrics
+        training metrics to be logged, logged at batch end, epoch end, train end
+    val_metrics : list of EvalMetrics
+        validation metrics to be logged, logged at epoch end, train end
     """
 
     LOG_VERBOSITY_PER_EPOCH = 1
     LOG_VERBOSITY_PER_BATCH = 2
 
-    def __init__(self, file_name=None, file_location=None, verbose=LOG_VERBOSITY_PER_EPOCH):
+    def __init__(self, file_name=None,
+                 file_location=None,
+                 verbose=LOG_VERBOSITY_PER_EPOCH,
+                 train_metrics=None,
+                 val_metrics=None):
         super(LoggingHandler, self).__init__()
         self.logger = logging.getLogger(__name__)
         self.logger.setLevel(logging.INFO)
@@ -112,71 +196,83 @@ class LoggingHandler(EventHandler):
             file_location = file_location or './'
             file_handler = logging.FileHandler(os.path.join(file_location, file_name))
             self.logger.addHandler(file_handler)
-
-    def train_begin(self):
+        self.train_metrics = train_metrics or []
+        self.val_metrics = val_metrics or []
+        self.batch_index = 0
+        self.current_epoch = 0
+        self.processed_samples = 0
+        # logging handler need to be called at last to make sure all states are updated
+        # it will also shut down logging at train end
+        self.priority = np.Inf
+
+    def train_begin(self, estimator, *args, **kwargs):
         self.train_start = time.time()
+        trainer = estimator.trainer
+        optimizer = trainer.optimizer.__class__.__name__
+        lr = trainer.learning_rate
         self.logger.info("Training begin: using optimizer %s "
                          "with current learning rate %.4f ",
-                         self.estimator.trainer.optimizer.__class__.__name__,
-                         self.estimator.trainer.learning_rate)
-        self.logger.info("Train for %d epochs.", self.estimator.max_epoch)
+                         optimizer, lr)
+        self.logger.info("Train for %d epochs.", estimator.max_epochs)
 
-    def train_end(self):
+    def train_end(self, estimator, *args, **kwargs):
         train_time = time.time() - self.train_start
-        epoch = self.estimator.current_epoch
-        msg = 'Train finished using total %ds at epoch %d. ' % (train_time, epoch)
+        msg = 'Train finished using total %ds with %d epochs.' % (train_time, self.current_epoch)
         # log every result in train stats including train/validation loss & metrics
-        for key in self.estimator.train_stats:
-            msg += '%s : %.4f ' % (key, self.estimator.train_stats[key])
+        for metric in self.train_metrics + self.val_metrics:
+            name, value = metric.get()
+            msg += '%s : %.4f ' % (name, value)
         self.logger.info(msg)
+        for handler in self.logger.handlers:
+            handler.close()
+            self.logger.removeHandler(handler)
+        logging.shutdown()
 
-    def batch_begin(self):
+    def batch_begin(self, estimator, *args, **kwargs):
         if self.verbose == self.LOG_VERBOSITY_PER_BATCH:
             self.batch_start = time.time()
 
-    def batch_end(self):
+    def batch_end(self, estimator, *args, **kwargs):
         if self.verbose == self.LOG_VERBOSITY_PER_BATCH:
             batch_time = time.time() - self.batch_start
-            epoch = self.estimator.current_epoch
-            batch = self.estimator.batch_idx
-            msg = '[Epoch %d] [Batch %d] ' % (epoch, batch)
-            if self.estimator.processed_samples:
-                msg += '[Samples %s] ' % (self.estimator.processed_samples)
+            msg = '[Epoch %d] [Batch %d] ' % (self.current_epoch, self.batch_index)
+            self.processed_samples += kwargs['batch'][0].shape[0]
+            msg += '[Samples %s] ' % (self.processed_samples)
             msg += 'time/batch: %.3fs ' % batch_time
-            for key in self.estimator.train_stats:
+            for metric in self.train_metrics:
                 # only log current training loss & metric after each batch
-                if key.startswith('train_'):
-                    msg += key + ': ' + '%.4f ' % self.estimator.train_stats[key]
+                name, value = metric.get()
+                msg += '%s : %.4f ' % (name, value)
             self.logger.info(msg)
+            self.batch_index += 1
 
-    def epoch_begin(self):
+    def epoch_begin(self, estimator, *args, **kwargs):
         if self.verbose >= self.LOG_VERBOSITY_PER_EPOCH:
             self.epoch_start = time.time()
 
-    def epoch_end(self):
+    def epoch_end(self, estimator, *args, **kwargs):
         if self.verbose >= self.LOG_VERBOSITY_PER_EPOCH:
             epoch_time = time.time() - self.epoch_start
-            epoch = self.estimator.current_epoch
-            msg = '\n[Epoch %d] finished in %.3fs: ' % (epoch, epoch_time)
-            # log every result in train stats including train/validation loss & metrics
-            for key in self.estimator.train_stats:
-                msg += '%s : %.4f ' % (key, self.estimator.train_stats[key])
+            msg = '\n[Epoch %d] finished in %.3fs: ' % (self.current_epoch, epoch_time)
+            for monitor in self.train_metrics + self.val_metrics:
+                name, value = monitor.get()
+                msg += '%s : %.4f ' % (name, value)
             self.logger.info(msg)
+            self.current_epoch += 1
+            self.batch_index = 0
 
 
-class CheckpointHandler(EventHandler):
+class CheckpointHandler(BatchEnd, EpochEnd):
     """Save the model after every epoch.
 
     :py:class:`CheckpointHandler` save the network parameters every epoch
 
     Parameters
     ----------
-    estimator : Estimator
-        The :py:class:`Estimator` to get training statistics
     filepath : str
         file name to save the parameters, it can contain directories,
         for example: ./saved_model/resnet.params
-    monitor: str
+    monitor: EvalMetric
         the metrics to monitor
     verbose: int, default 0
         verbosity mode
@@ -191,18 +287,23 @@ class CheckpointHandler(EventHandler):
 
     def __init__(self,
                  filepath,
-                 monitor='val_accuracy',
+                 monitor=None,
                  verbose=0,
                  save_best_only=False,
                  mode='auto',
-                 period=1):
-        super(CheckpointHandler, self).__init__()
+                 epoch_period=1,
+                 batch_period=None):
         self.monitor = monitor
         self.verbose = verbose
         self.filepath = filepath
         self.save_best_only = save_best_only
-        self.period = period
-        self.epochs_since_last_save = 0
+        if self.save_best_only and not isinstance(self.monitor, EvalMetric):
+            raise ValueError("To save best model only, please provide one of 
the metric objects as monitor, "
+                             "You can create these objects using 
estimator.prepare_loss_and_metric()")
+        self.epoch_period = epoch_period
+        self.batch_period = batch_period
+        self.num_batches = 0
+        self.num_epochs = 0
         self.logger = logging.getLogger(__name__)
 
         if mode not in ['auto', 'min', 'max']:
@@ -219,55 +320,61 @@ class CheckpointHandler(EventHandler):
             self.best = -np.Inf
         else:
             # use greater for accuracy and less otherwise
-            if 'acc' in self.monitor:
+            if 'acc' in self.monitor.get()[0].lower():
                 self.monitor_op = np.greater
                 self.best = -np.Inf
             else:
                 self.monitor_op = np.less
                 self.best = np.Inf
 
-    def epoch_end(self, ):
-        epoch = self.estimator.current_epoch
+    def batch_end(self, estimator, *args, **kwargs):
+        self._save_checkpoint(estimator.net, "Batch", self.num_batches)
+        self.num_batches += 1
+
+    def epoch_end(self, estimator, *args, **kwargs):
+        self._save_checkpoint(estimator.net, "Epoch", self.num_epochs)
+        self.num_epochs += 1
+
+    def _save_checkpoint(self, net, period_name, period_value):
         # add extension for weights
         if '.params' not in self.filepath:
             self.filepath += '.params'
-        self.epochs_since_last_save += 1
-        if self.epochs_since_last_save >= self.period:
-            self.epochs_since_last_save = 0
+        if self.num_epochs % self.epoch_period == 0:
             if self.save_best_only:
+                monitor_name, monitor_value = self.monitor.get()
                 # check if monitor exists in train stats
-                if self.monitor not in self.estimator.train_stats:
-                    warnings.warn(RuntimeWarning('Unable to find %s in 
training statistics, make sure the monitor value'
-                                                 'starts with `train_ `or 
`val_` and contains loss/metric name, ',
-                                                 'for example val_accuracy', 
self.monitor))
-                    self.estimator.net.save_parameters(self.filepath)
+                if np.isnan(monitor_value):
+                    warnings.warn(RuntimeWarning('%s is not updated, make sure 
you pass one of the metric objects'
+                                                 'as monitor, you can use 
estimator.prepare_loss_and_metrics to'
+                                                 'create all metric objects', 
monitor_name))
+                    net.save_parameters(self.filepath)
                 else:
-                    current = self.estimator.train_stats[self.monitor]
-                    if self.monitor_op(current, self.best):
+                    if self.monitor_op(monitor_value, self.best):
                         if self.verbose > 0:
-                            self.logger.info('\n[Epoch %d] %s improved from 
%0.5f to %0.5f,'
+                            self.logger.info('\n[%s %d] %s improved from %0.5f 
to %0.5f,'
                                              ' saving model to %s',
-                                             epoch, self.monitor, self.best, 
current, self.filepath)
-                        self.best = current
-                        self.estimator.net.save_parameters(self.filepath)
+                                             period_name, period_value, 
monitor_name,
+                                             self.best, monitor_value, 
self.filepath)
+                        self.best = monitor_value
+                        net.save_parameters(self.filepath)
                     else:
                         if self.verbose > 0:
-                            self.logger.info('\n[Epoch %d] %s did not improve 
from %0.5f, skipping save model',
-                                             epoch, self.monitor, self.best)
+                            self.logger.info('\n[%s %d] %s did not improve 
from %0.5f, skipping save model',
+                                             period_name, period_value, 
monitor_name, self.best)
             else:
                 if self.verbose > 0:
-                    logging.info('\nEpoch %d: saving model to %s', epoch, 
self.filepath)
-                self.estimator.net.save_parameters(self.filepath)
+                    logging.info('\n%s %d: saving model to %s', period_name, 
period_value, self.filepath)
+                net.save_parameters(self.filepath)
 
 
-class EarlyStoppingHandler(EventHandler):
+class EarlyStoppingHandler(TrainBegin, EpochEnd, TrainEnd):
     """Early stop training if monitored value is not improving
 
     Parameters
     ----------
     estimator : Estimator
         The :py:class:`Estimator` to get training statistics
-    monitor: str
+    monitor: EvalMetric
         the metrics to monitor
     min_delta: float, default 0
         minimal change in monitored value to be considered as an improvement
@@ -281,19 +388,24 @@ class EarlyStoppingHandler(EventHandler):
     """
 
     def __init__(self,
-                 monitor='val_accuracy',
+                 monitor,
                  min_delta=0,
                  patience=0,
                  mode='auto',
                  baseline=None):
         super(EarlyStoppingHandler, self).__init__()
 
+        if not isinstance(monitor, EvalMetric):
+            raise ValueError("Please provide one of the metric objects as 
monitor, "
+                             "You can create these objects using 
estimator.prepare_loss_and_metric()")
         self.monitor = monitor
         self.baseline = baseline
         self.patience = patience
         self.min_delta = min_delta
         self.wait = 0
         self.stopped_epoch = 0
+        self.num_epochs = 0
+        self.stop_training = False
         self.logger = logging.getLogger(__name__)
 
         if mode not in ['auto', 'min', 'max']:
@@ -306,7 +418,7 @@ class EarlyStoppingHandler(EventHandler):
         elif mode == 'max':
             self.monitor_op = np.greater
         else:
-            if 'acc' in self.monitor:
+            if 'acc' in self.monitor.get()[0].lower():
                 self.monitor_op = np.greater
             else:
                 self.monitor_op = np.less
@@ -316,7 +428,7 @@ class EarlyStoppingHandler(EventHandler):
         else:
             self.min_delta *= -1
 
-    def train_begin(self):
+    def train_begin(self, estimator, *args, **kwargs):
         self.wait = 0
         self.stopped_epoch = 0
         if self.baseline is not None:
@@ -324,23 +436,24 @@ class EarlyStoppingHandler(EventHandler):
         else:
             self.best = np.Inf if self.monitor_op == np.less else -np.Inf
 
-    def epoch_end(self):
-        epoch = self.estimator.current_epoch
-        if self.monitor not in self.estimator.train_stats:
-            warnings.warn(RuntimeWarning('Unable to find %s in training 
statistics, make sure the monitor value'
-                                         'starts with `train_ `or `val_` and 
contains loss/metric name, ',
-                                         'for example val_accuracy', 
self.monitor))
+    def epoch_end(self, estimator, *args, **kwargs):
+        monitor_name, monitor_value = self.monitor.get()
+        if np.isnan(monitor_value):
+            warnings.warn(RuntimeWarning('%s is not updated, make sure you 
pass one of the metric objects'
+                                         'as monitor, you can use 
estimator.prepare_loss_and_metrics to'
+                                         'create all metric objects', 
monitor_name))
         else:
-            current = self.estimator.train_stats[self.monitor]
-            if self.monitor_op(current - self.min_delta, self.best):
-                self.best = current
+            if self.monitor_op(monitor_value - self.min_delta, self.best):
+                self.best = monitor_value
                 self.wait = 0
             else:
                 self.wait += 1
                 if self.wait >= self.patience:
-                    self.stopped_epoch = epoch
-                    self.estimator.stop_training = True
+                    self.stopped_epoch = self.num_epochs
+                    self.stop_training = True
+        return self.stop_training
 
-    def train_end(self):
+    def train_end(self, estimator, *args, **kwargs):
         if self.stopped_epoch > 0:
-            self.logger.info('Epoch %d: early stopping due to %s not 
improving', self.stopped_epoch, self.monitor)
+            self.logger.info('Epoch %d: early stopping due to %s not 
improving',
+                             self.stopped_epoch, self.monitor.get()[0])
diff --git a/tests/nightly/Jenkinsfile b/tests/nightly/Jenkinsfile
index a65da2d..1be084c 100755
--- a/tests/nightly/Jenkinsfile
+++ b/tests/nightly/Jenkinsfile
@@ -137,19 +137,19 @@ core_logic: {
         }
       }
     },
-    'estimator: RNN GPU': {
+    'Gluon estimator: GPU': {
       node(NODE_LINUX_GPU) {
-        ws('workspace/estimator-test-rnn-gpu') {
+        ws('workspace/estimator-test-gpu') {
           utils.unpack_and_init('gpu', mx_lib)
-          utils.docker_run('ubuntu_nightly_gpu', 
'nightly_estimator_test_rnn_gpu', true)
+          utils.docker_run('ubuntu_nightly_gpu', 'nightly_estimator_gpu', true)
         }
       }
     },
-    'estimator: RNN CPU': {
+    'Gluon estimator: CPU': {
       node(NODE_LINUX_CPU) {
-        ws('workspace/estimator-test-rnn-cpu') {
+        ws('workspace/estimator-test-cpu') {
           utils.unpack_and_init('cpu', mx_lib)
-          utils.docker_run('ubuntu_nightly_cpu', 
'nightly_estimator_test_rnn_cpu', false)
+          utils.docker_run('ubuntu_nightly_cpu', 'nightly_estimator_cpu', 
false)
         }
       }
     }
diff --git a/tests/nightly/JenkinsfileForBinaries 
b/tests/nightly/JenkinsfileForBinaries
index 53572c8..53e1c30 100755
--- a/tests/nightly/JenkinsfileForBinaries
+++ b/tests/nightly/JenkinsfileForBinaries
@@ -106,22 +106,6 @@ core_logic: {
           utils.docker_run('ubuntu_nightly_gpu', 
'nightly_tutorial_test_ubuntu_python3_gpu', true, '1500m')
         }
       }
-    },
-    'estimator: CNN GPU': {
-      node(NODE_LINUX_GPU) {
-        ws('workspace/estimator-test-cnn-gpu') {
-          utils.unpack_and_init('gpu', mx_lib)
-          utils.docker_run('ubuntu_nightly_gpu', 
'nightly_estimator_test_cnn_gpu', true)
-        }
-      }
-    },
-    'estimator: CNN CPU': {
-      node(NODE_LINUX_CPU) {
-        ws('workspace/estimator-test-cnn-cpu') {
-          utils.unpack_and_init('cpu', mx_lib)
-          utils.docker_run('ubuntu_nightly_cpu', 
'nightly_estimator_test_cnn_cpu', true)
-        }
-      }
     }
   }
 }
diff --git a/tests/nightly/estimator/test_estimator_cnn.py 
b/tests/nightly/estimator/test_estimator_cnn.py
index 7d0018b..c60dc54 100644
--- a/tests/nightly/estimator/test_estimator_cnn.py
+++ b/tests/nightly/estimator/test_estimator_cnn.py
@@ -137,7 +137,7 @@ def test_estimator_gpu():
             val_data=test_data,
             epochs=num_epochs)
 
-    assert est.train_stats['train_'+acc.name] > 0.80
+    assert acc.get()[1] > 0.80
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser(description='test gluon estimator')
diff --git a/tests/nightly/estimator/test_sentiment_rnn.py 
b/tests/nightly/estimator/test_sentiment_rnn.py
index 5fd93c1..404bf83 100644
--- a/tests/nightly/estimator/test_sentiment_rnn.py
+++ b/tests/nightly/estimator/test_sentiment_rnn.py
@@ -183,7 +183,7 @@ def run(net, train_dataloader, test_dataloader, **kwargs):
     # Begin training
     est.fit(train_data=train_dataloader, val_data=test_dataloader,
             epochs=num_epochs)
-    return est
+    return acc
 
 
 def test_estimator_cpu(**kwargs):
@@ -250,9 +250,9 @@ def test_estimator_gpu(**kwargs):
     net.embedding.weight.set_data(glove_embedding.idx_to_vec)
     net.embedding.collect_params().setattr('grad_req', 'null')
 
-    est = run(net, train_dataloader, test_dataloader, **kwargs)
+    acc = run(net, train_dataloader, test_dataloader, **kwargs)
 
-    assert est.train_stats['train_accuracy'] > 0.70
+    assert acc.get()[1] > 0.70
 
 
 parser = argparse.ArgumentParser(description='test gluon estimator')
diff --git a/tests/python/unittest/test_gluon_estimator.py 
b/tests/python/unittest/test_gluon_estimator.py
index 13fcd96..6f19f43 100644
--- a/tests/python/unittest/test_gluon_estimator.py
+++ b/tests/python/unittest/test_gluon_estimator.py
@@ -19,12 +19,11 @@
 
 import sys
 import unittest
-import warnings
 
 import mxnet as mx
 from mxnet import gluon
 from mxnet.gluon import nn
-from mxnet.gluon.contrib.estimator import Estimator, EventHandler
+from mxnet.gluon.contrib.estimator import *
 from nose.tools import assert_raises
 
 
@@ -222,6 +221,7 @@ def test_metric():
                     loss=loss,
                     trainer=trainer,
                     context=ctx)
+    est.prepare_loss_and_metrics()
     assert isinstance(est.train_metrics[0], mx.metric.Accuracy)
 
 
@@ -252,7 +252,7 @@ def test_context():
                     metrics=metrics)
     # input list of context
     gpus = mx.context.num_gpus()
-    ctx = [mx.gpu(i) for i in gpus] if gpus > 0 else [mx.cpu()]
+    ctx = [mx.gpu(i) for i in range(gpus)] if gpus > 0 else [mx.cpu()]
     net = get_model()
     est = Estimator(net=net,
                     loss=loss,
@@ -267,16 +267,12 @@ def test_context():
 
 
 def test_categorize_handlers():
-    class CustomHandler1(EventHandler):
-        def __init__(self):
-            super(CustomHandler1, self).__init__()
+    class CustomHandler1(TrainBegin):
 
         def train_begin(self):
             print("custom train begin")
 
-    class CustomHandler2(EventHandler):
-        def __init__(self):
-            super(CustomHandler2, self).__init__()
+    class CustomHandler2(EpochBegin, BatchBegin, TrainEnd):
 
         def epoch_begin(self):
             print("custom epoch begin")
@@ -287,9 +283,7 @@ def test_categorize_handlers():
         def train_end(self):
             print("custom train end")
 
-    class CustomHandler3(EventHandler):
-        def __init__(self):
-            super(CustomHandler3, self).__init__()
+    class CustomHandler3(EpochBegin, BatchBegin, BatchEnd, TrainEnd):
 
         def epoch_begin(self):
             print("custom epoch begin")
diff --git a/tests/python/unittest/test_gluon_event_handler.py 
b/tests/python/unittest/test_gluon_event_handler.py
index dd2e60d..e151281 100644
--- a/tests/python/unittest/test_gluon_event_handler.py
+++ b/tests/python/unittest/test_gluon_event_handler.py
@@ -17,18 +17,21 @@
 
 import os
 import tempfile
+
 import mxnet as mx
 from mxnet import nd
 from mxnet.gluon import nn, loss
 from mxnet.gluon.contrib.estimator import estimator, event_handler
+from common import TemporaryDirectory
 
 def _get_test_network():
     net = nn.Sequential()
     net.add(nn.Dense(128, activation='relu', in_units=100, flatten=False),
-              nn.Dense(64, activation='relu', in_units=128),
-              nn.Dense(10, activation='relu', in_units=64))
+            nn.Dense(64, activation='relu', in_units=128),
+            nn.Dense(10, activation='relu', in_units=64))
     return net
 
+
 def _get_test_data():
     data = nd.ones((32, 100))
     label = nd.random.randint(0, 10, (32, 1))
@@ -39,57 +42,60 @@ def _get_test_data():
 def test_checkpoint_handler():
     tmpdir = tempfile.mkdtemp()
     file_path = os.path.join(tmpdir, "model.params")
-    test_data  = _get_test_data()
+    test_data = _get_test_data()
 
     save_best_only = False
     mode = 'auto'
 
     net = _get_test_network()
     ce_loss = loss.SoftmaxCrossEntropyLoss()
+    ce_loss_metric = mx.metric.Loss(ce_loss.name)
     acc = mx.metric.Accuracy()
     est = estimator.Estimator(net, loss=ce_loss, metrics=acc)
     checkpoint_handler = [event_handler.CheckpointHandler(file_path,
+                                                          monitor=acc,
                                                           
save_best_only=save_best_only,
                                                           mode=mode)]
     est.fit(test_data, event_handlers=checkpoint_handler, epochs=1)
     assert os.path.isfile(file_path)
     os.remove(file_path)
 
+
 def test_early_stopping():
     test_data = _get_test_data()
 
     mode = 'max'
-    monitor = 'train_accuracy'
     patience = 0
 
     net = _get_test_network()
     ce_loss = loss.SoftmaxCrossEntropyLoss()
     acc = mx.metric.Accuracy()
     est = estimator.Estimator(net, loss=ce_loss, metrics=acc)
-    early_stopping = [event_handler.EarlyStoppingHandler(monitor,
+    early_stopping = [event_handler.EarlyStoppingHandler(monitor=acc,
                                                          patience=patience,
                                                          mode=mode)]
     est.fit(test_data, event_handlers=early_stopping, epochs=3)
 
     mode = 'auto'
-    monitor = 'train_accuracy'
     patience = 2
-    early_stopping = [event_handler.EarlyStoppingHandler(monitor,
+    early_stopping = [event_handler.EarlyStoppingHandler(monitor=acc,
                                                          patience=patience,
-                                                          mode=mode)]
+                                                         mode=mode)]
     est.fit(test_data, event_handlers=early_stopping, epochs=1)
 
+
 def test_logging():
-    tmpdir = tempfile.mkdtemp()
-    test_data = _get_test_data()
-    file_name = 'test_log'
-    output_dir = os.path.join(tmpdir, file_name)
+    with TemporaryDirectory() as tmpdir:
+        test_data = _get_test_data()
+        file_name = 'test_log'
+        output_dir = os.path.join(tmpdir, file_name)
 
-    net = _get_test_network()
-    ce_loss = loss.SoftmaxCrossEntropyLoss()
-    acc = mx.metric.Accuracy()
-    est = estimator.Estimator(net, loss=ce_loss, metrics=acc)
-    logging_handler = [event_handler.LoggingHandler(file_name=file_name, 
file_location=tmpdir)]
-    est.fit(test_data, event_handlers=logging_handler, epochs=1)
-    assert os.path.isfile(output_dir)
-    os.remove(output_dir)
\ No newline at end of file
+        net = _get_test_network()
+        ce_loss = loss.SoftmaxCrossEntropyLoss()
+        ce_loss_metric = mx.metric.Loss(ce_loss.name)
+        acc = mx.metric.Accuracy()
+        est = estimator.Estimator(net, loss=ce_loss, metrics=acc)
+        logging_handler = [event_handler.LoggingHandler(file_name=file_name,
+                                                        file_location=tmpdir, 
train_metrics=[acc, ce_loss_metric])]
+        est.fit(test_data, event_handlers=logging_handler, epochs=1)
+        assert os.path.isfile(output_dir)
\ No newline at end of file

Reply via email to