This is an automated email from the ASF dual-hosted git repository.

nswamy pushed a commit to branch fit-api
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/fit-api by this push:
     new ed7f6e5  [MXNet-1340][Fit API]Update train stats (#14494)
ed7f6e5 is described below

commit ed7f6e56a4e372d5d460031186145065f5657893
Author: Lai Wei <[email protected]>
AuthorDate: Wed Apr 3 14:18:13 2019 -0700

    [MXNet-1340][Fit API]Update train stats (#14494)
    
    * add train history
    
    * update history
    
    * update test
    
    * avoid calling empty methods
    
    * remove train history object
    
    * fix pylint
    
    * add unit test
    
    * fix test
    
    * update categorize handlers
---
 python/mxnet/gluon/estimator/estimator.py         | 147 +++++++++-------
 python/mxnet/gluon/estimator/event_handler.py     | 102 +++++++-----
 python/mxnet/gluon/trainer.py                     |   7 +
 tests/python/unittest/test_gluon_estimator.py     | 193 ++++++++++++++--------
 tests/python/unittest/test_gluon_event_handler.py |  12 +-
 5 files changed, 280 insertions(+), 181 deletions(-)
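
For context, a minimal usage sketch of the Estimator API after this patch: loss is
now a required argument, a single trainer replaces the old trainers list, and
train_stats keeps only the latest value per 'train_'/'val_' key instead of
per-epoch histories. The data shapes and hyperparameters below are illustrative
assumptions, not part of the commit:

    import mxnet as mx
    from mxnet import gluon
    from mxnet.gluon import nn
    from mxnet.gluon.estimator import Estimator

    net = nn.Dense(4)
    net.initialize(ctx=mx.cpu())
    loss = gluon.loss.L2Loss()
    acc = mx.metric.Accuracy()
    trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.001})

    est = Estimator(net=net, loss=loss, metrics=acc, trainer=trainer, context=mx.cpu())

    x = mx.nd.random.uniform(shape=(10, 3))
    y = mx.nd.random.uniform(shape=(10, 4))
    data = gluon.data.DataLoader(gluon.data.ArrayDataset(x, y), batch_size=5)
    est.fit(train_data=data, epochs=1, batch_size=5)
    # afterwards est.train_stats holds the latest values, e.g. under 'train_accuracy'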

diff --git a/python/mxnet/gluon/estimator/estimator.py b/python/mxnet/gluon/estimator/estimator.py
index e759fa7..c5da0c0 100644
--- a/python/mxnet/gluon/estimator/estimator.py
+++ b/python/mxnet/gluon/estimator/estimator.py
@@ -22,7 +22,7 @@
 import copy
 import warnings
 
-from .event_handler import LoggingHandler
+from .event_handler import EventHandler, LoggingHandler
 from ... import gluon, autograd
 from ...context import Context, cpu, gpu, num_gpus
 from ...io import DataIter
@@ -39,27 +39,26 @@ class Estimator(object):
 
     Parameters
     ----------
-    loss : Loss or list of Loss
+    loss : gluon.loss.Loss or list of gluon.loss.Loss
         Loss(objective functions) to calculate during training
     metrics : EvalMetric or list of EvalMetric
         Metrics for evaluating models
     initializer : Initializer
         initializer to initialize the network
-    trainers : Trainer or list of Trainer
-        Trainers to apply optimizers on network parameters
+    trainer : Trainer
+        Trainer to apply optimizer on network parameters
     context : Context or list of Context
         devices to run the training on
     """
 
     def __init__(self, net,
-                 loss=None,
+                 loss,
                  metrics=None,
                  initializer=None,
-                 trainers=None,
+                 trainer=None,
                  context=None):
 
         self.net = net
-        self.stop_training = False
 
         if isinstance(loss, gluon.loss.Loss):
             self.loss = [loss]
@@ -86,27 +85,14 @@ class Estimator(object):
 
         # store training statistics
         self.train_stats = {}
-        self.train_stats['epochs'] = []
-        self.train_stats['learning_rate'] = []
-        # current step of the epoch
-        self.train_stats['step'] = ''
-        for metric in self.train_metrics:
-            # record a history of metrics over each epoch
-            self.train_stats['train_' + metric.name] = []
-            # only record the latest metric numbers after each batch
-            self.train_stats['batch_' + metric.name] = 0.
-        for metric in self.val_metrics:
-            self.train_stats['val_' + metric.name] = []
+
+        # separate train and validation
         self.train_loss_metrics = []
         self.val_loss_metrics = []
         # using the metric wrapper for loss to record loss value
         for l in self.loss:
             self.train_loss_metrics.append(Loss(l.name))
             self.val_loss_metrics.append(Loss(l.name))
-            self.train_stats['train_' + l.name] = []
-            self.train_stats['val_' + l.name] = []
-            # only record the latest loss numbers after each batch
-            self.train_stats['batch_' + l.name] = 0.
 
         # handle context
         if isinstance(context, Context):
@@ -127,7 +113,6 @@ class Estimator(object):
             raise ValueError("context must be a Context or a list of Context, "
                              "refer to mxnet.Context:{}".format(context))
 
-
         # initialize the network
         self.initializer = initializer
         if self.initializer:
@@ -135,7 +120,7 @@ class Estimator(object):
                # if already initialized, re-init with user specified initializer
                warnings.warn("Network already initialized, re-initializing with %s. "
                              "You don't need to pass initializer if you already "
-                              "initialized your net."% type(self.initializer).__name__)
+                              "initialized your net." % type(self.initializer).__name__)
                self.net.initialize(init=self.initializer, ctx=self.context, force_reinit=True)
             else:
                 # initialize with user specified initializer
@@ -144,16 +129,17 @@ class Estimator(object):
             if not self._is_initialized():
                 self.net.initialize(ctx=self.context)
 
-        # handle trainers
-        if isinstance(trainers, gluon.Trainer):
-            self.trainers = [trainers]
-        elif not trainers:
+        # handle trainer
+        if not trainer:
             warnings.warn("No trainer specified, default SGD optimizer "
                           "with learning rate 0.001 is used.")
-            self.trainers = [gluon.Trainer(self.net.collect_params(),
-                                           'sgd', {'learning_rate': 0.001})]
+            self.trainer = gluon.Trainer(self.net.collect_params(),
+                                         'sgd', {'learning_rate': 0.001})
+        elif not isinstance(trainer, gluon.Trainer):
+            raise ValueError("Trainer must be a Gluon Trainer instance, refer 
to "
+                             "gluon.Trainer:{}".format(trainer))
         else:
-            raise ValueError("Invalid trainer specified, please provide a 
valid gluon.Trainer")
+            self.trainer = trainer
 
     def _is_initialized(self):
         param_dict = self.net.collect_params()
@@ -212,8 +198,12 @@ class Estimator(object):
             # update metrics
             for metric in self.val_metrics:
                 metric.update(label, pred)
+                name, value = metric.get()
+                self.train_stats['val_' + name] = value
             for loss, loss_metric, in zip(losses, self.val_loss_metrics):
                 loss_metric.update(0, [l for l in loss])
+                name, value = loss_metric.get()
+                self.train_stats['val_' + name] = value
 
     def fit(self, train_data,
             val_data=None,
@@ -241,27 +231,38 @@ class Estimator(object):
             from a data batch and load into contexts(devices)
         """
 
-
-        self.epochs = epochs
+        self.max_epoch = epochs
         if not batch_size:
-            batch_size = 32 * len(self.context)
+            self.batch_size = 32 * len(self.context)
+        else:
+            self.batch_size = batch_size
+        self.stop_training = False
+        self.samples = None
+        self.batch_idx = 0
 
         event_handlers = event_handlers or []
         # provide default logging handler
         if not event_handlers or \
                not any(isinstance(handler, LoggingHandler) for handler in event_handlers):
-            event_handlers.append(LoggingHandler(self))
+            event_handlers.append(LoggingHandler())
 
-        # training begin
+        train_begin, epoch_begin, batch_begin, \
+        batch_end, epoch_end, train_end = self._categorize_handlers(event_handlers)
+
+        # passing estimator to event handlers so they can access estimator information
+        # when an event is triggered
         for handler in event_handlers:
+            handler.estimator = self
+
+        # training begin
+        for handler in train_begin:
             handler.train_begin()
 
-        for epoch in range(epochs):
+        for epoch in range(self.max_epoch):
             # epoch begin
-            self.train_stats['epochs'].append(epoch)
-            self.train_stats['learning_rate'].append(self.trainers[0].learning_rate)
+            self.current_epoch = epoch
 
-            for handler in event_handlers:
+            for handler in epoch_begin:
                 handler.epoch_begin()
 
             for metric in self.train_metrics + self.train_loss_metrics:
@@ -282,7 +283,7 @@ class Estimator(object):
                     data, label = batch_fn(batch, self.context)
 
                 # batch begin
-                for handler in event_handlers:
+                for handler in batch_begin:
                     handler.batch_begin()
 
                 with autograd.record():
@@ -298,42 +299,64 @@ class Estimator(object):
                 # update train metrics
                 for metric in self.train_metrics:
                     metric.update(label, pred)
-                    self.train_stats['batch_' + metric.name] = metric.get()[1]
+                    # get metric name and current value and update train stats
+                    name, value = metric.get()
+                    self.train_stats['train_' + name] = value
+
+                # update loss
                 for loss, loss_metric, in zip(losses, self.train_loss_metrics):
                     loss_metric.update(0, [l for l in loss])
-                    self.train_stats['batch_' + loss_metric.name] = 
loss_metric.get()[1]
-
-                try:
-                    completed_samples = len(train_data._dataset) if i == len(train_data._dataset) - 1 \
-                                        else batch_size * (i + 1)
-                    # We need to check if this is the last batch in the current epoch and select
-                    # the value to print appropriately
-                    self.train_stats['step'] = "{}/{}".format(completed_samples, len(train_data._dataset))
-                except AttributeError:
-                    self.train_stats['step'] = i
+                    name, value = loss_metric.get()
+                    self.train_stats['train_' + name] = value
 
-                for trainer in self.trainers:
-                    trainer.step(batch_size)
+                self.batch_idx = i
+                # record trained samples vs. total samples if using Gluon DataLoader
+                if isinstance(train_data, gluon.data.DataLoader):
+                    self.samples = "{}/{}".format(self.batch_size * (i + 1), len(train_data._dataset))
 
+                self.trainer.step(self.batch_size)
                 # batch end
-                for handler in event_handlers:
+                for handler in batch_end:
                     handler.batch_end()
 
             if val_data:
                 self.evaluate(val_data, batch_fn)
 
-            for metric in self.train_metrics + self.train_loss_metrics:
-                self.train_stats['train_' + metric.name].append(metric.get()[1])
-            for metric in self.val_metrics + self.val_loss_metrics:
-                self.train_stats['val_' + metric.name].append(metric.get()[1])
-
             # epoch end
-            for handler in event_handlers:
+            for handler in epoch_end:
                 handler.epoch_end()
 
             if self.stop_training:
                 break
 
         # train end
-        for handler in event_handlers:
+        for handler in train_end:
             handler.train_end()
+
+    def _categorize_handlers(self, event_handlers):
+        """
+        categorize handlers into 6 event lists to avoid calling empty methods
+        for example, only event handlers that implement train_begin
+        will be called at train begin
+        """
+
+        train_begin = []
+        epoch_begin = []
+        batch_begin = []
+        batch_end = []
+        epoch_end = []
+        train_end = []
+        for handler in event_handlers:
+            if not handler.__class__.train_begin == EventHandler.train_begin:
+                train_begin.append(handler)
+            if not handler.__class__.epoch_begin == EventHandler.epoch_begin:
+                epoch_begin.append(handler)
+            if not handler.__class__.batch_begin == EventHandler.batch_begin:
+                batch_begin.append(handler)
+            if not handler.__class__.batch_end == EventHandler.batch_end:
+                batch_end.append(handler)
+            if not handler.__class__.epoch_end == EventHandler.epoch_end:
+                epoch_end.append(handler)
+            if not handler.__class__.train_end == EventHandler.train_end:
+                train_end.append(handler)
+        return train_begin, epoch_begin, batch_begin, batch_end, epoch_end, train_end
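
The categorization above keys off method identity: a subclass that does not
override, say, batch_end still resolves handler.__class__.batch_end to
EventHandler.batch_end, so the comparison weeds out no-op callbacks. A
standalone sketch of the same idiom, with made-up class names:

    class Base(object):
        def on_event(self):
            pass

    class Chatty(Base):
        def on_event(self):
            print("handled")

    class Silent(Base):
        pass

    handlers = [Chatty(), Silent()]
    # keep only handlers that actually override on_event
    active = [h for h in handlers if not h.__class__.on_event == Base.on_event]
    assert len(active) == 1 and isinstance(active[0], Chatty)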
diff --git a/python/mxnet/gluon/estimator/event_handler.py b/python/mxnet/gluon/estimator/event_handler.py
index c59644e..7810074 100644
--- a/python/mxnet/gluon/estimator/event_handler.py
+++ b/python/mxnet/gluon/estimator/event_handler.py
@@ -40,7 +40,16 @@ class EventHandler(object):
         estimator : Estimator
             The :py:class:`Estimator` to get training statistics
         """
-    def __init__(self, estimator):
+
+    def __init__(self):
+        self._estimator = None
+
+    @property
+    def estimator(self):
+        return self._estimator
+
+    @estimator.setter
+    def estimator(self, estimator):
         self._estimator = estimator
 
     def train_begin(self):
@@ -78,8 +87,8 @@ class LoggingHandler(EventHandler):
         file location to save the logs
     """
 
-    def __init__(self, estimator, file_name=None, file_location=None, ):
-        super(LoggingHandler, self).__init__(estimator)
+    def __init__(self, file_name=None, file_location=None):
+        super(LoggingHandler, self).__init__()
         self.logger = logging.getLogger(__name__)
         self.logger.setLevel(logging.INFO)
         stream_handler = logging.StreamHandler()
@@ -92,22 +101,37 @@ class LoggingHandler(EventHandler):
             self.logger.addHandler(file_handler)
 
     def train_begin(self):
-        pass
+        self.train_start = time.time()
+        self.logger.info("Training begin: using optimizer %s "
+                         "with current learning rate %.4f ",
+                         self.estimator.trainer.optimizer.__class__.__name__,
+                         self.estimator.trainer.learning_rate)
+        self.logger.info("Train for %d epochs.", self.estimator.max_epoch)
 
     def train_end(self):
-        pass
+        train_time = time.time() - self.train_start
+        epoch = self.estimator.current_epoch
+        msg = 'Train finished using total %ds at epoch %d. ' % (train_time, epoch)
+        # log every result in train stats including train/validation loss & metrics
+        for key in self.estimator.train_stats:
+            msg += '%s : %.4f ' % (key, self.estimator.train_stats[key])
+        self.logger.info(msg)
 
     def batch_begin(self):
         self.batch_start = time.time()
 
     def batch_end(self):
         batch_time = time.time() - self.batch_start
-        epoch = self._estimator.train_stats['epochs'][-1]
-        step = self._estimator.train_stats['step']
-        msg = '[Epoch %d] [Step %s] time/step: %.3fs ' % (epoch, step, batch_time)
-        for key in self._estimator.train_stats.keys():
-            if key.startswith('batch_'):
-                msg += key[6:] + ': ' + '%.4f ' % self._estimator.train_stats[key]
+        epoch = self.estimator.current_epoch
+        batch = self.estimator.batch_idx
+        msg = '[Epoch %d] [Batch %d] ' % (epoch, batch)
+        if self.estimator.samples:
+            msg += '[Samples %s] ' % (self.estimator.samples)
+        msg += 'time/batch: %.3fs ' % batch_time
+        for key in self.estimator.train_stats:
+            # only log current training loss & metric after each batch
+            if key.startswith('train_'):
+                msg += key + ': ' + '%.4f ' % self.estimator.train_stats[key]
         self.logger.info(msg)
 
     def epoch_begin(self):
@@ -115,11 +139,11 @@ class LoggingHandler(EventHandler):
 
     def epoch_end(self):
         epoch_time = time.time() - self.epoch_start
-        epoch = self._estimator.train_stats['epochs'][-1]
+        epoch = self.estimator.current_epoch
         msg = '\n[Epoch %d] finished in %.3fs: ' % (epoch, epoch_time)
-        for key in self._estimator.train_stats.keys():
-            if key.startswith('train_') or key.startswith('val_'):
-                msg += key + ': ' + '%.4f ' % self._estimator.train_stats[key][epoch]
+        # log every result in train stats including train/validation loss & metrics
+        for key in self.estimator.train_stats:
+            msg += '%s : %.4f ' % (key, self.estimator.train_stats[key])
         self.logger.info(msg)
 
 
@@ -148,14 +172,14 @@ class CheckpointHandler(EventHandler):
         intervals between saving the network
     """
 
-    def __init__(self, estimator,
+    def __init__(self,
                  filepath,
-                 monitor='val_loss',
+                 monitor='val_accuracy',
                  verbose=0,
                  save_best_only=False,
                  mode='auto',
                  period=1):
-        super(CheckpointHandler, self).__init__(estimator)
+        super(CheckpointHandler, self).__init__()
         self.monitor = monitor
         self.verbose = verbose
         self.filepath = filepath
@@ -186,7 +210,7 @@ class CheckpointHandler(EventHandler):
                 self.best = np.Inf
 
     def epoch_end(self, ):
-        epoch = self._estimator.train_stats['epochs'][-1]
+        epoch = self.estimator.current_epoch
         # add extension for weights
         if '.params' not in self.filepath:
             self.filepath += '.params'
@@ -194,20 +218,21 @@ class CheckpointHandler(EventHandler):
         if self.epochs_since_last_save >= self.period:
             self.epochs_since_last_save = 0
             if self.save_best_only:
-                # check if monitor exists in train_stats
-                if self.monitor not in self._estimator.train_stats:
-                    warnings.warn(RuntimeWarning('Unable to find %s in training statistics, make sure'
-                                                 'you are passing one of the metric names as monitor', self.monitor))
-                    self._estimator.net.save_parameters(self.filepath)
+                # check if monitor exists in train stats
+                if self.monitor not in self.estimator.train_stats:
+                    warnings.warn(RuntimeWarning('Unable to find %s in training statistics, make sure the monitor value'
+                                                 ' starts with `train_` or `val_` and contains a loss/metric name,'
+                                                 ' for example val_accuracy', self.monitor))
+                    self.estimator.net.save_parameters(self.filepath)
                 else:
-                    current = self._estimator.train_stats[self.monitor][-1]
+                    current = self.estimator.train_stats[self.monitor]
                     if self.monitor_op(current, self.best):
                         if self.verbose > 0:
+                            self.logger.info('\n[Epoch %d] %s improved from %0.5f to %0.5f,'
+                                             ' saving model to %s',
+                                             epoch, self.monitor, self.best, current, self.filepath)
                         self.best = current
-                        self._estimator.net.save_parameters(self.filepath)
+                        self.estimator.net.save_parameters(self.filepath)
                     else:
                         if self.verbose > 0:
+                            self.logger.info('\n[Epoch %d] %s did not improve from %0.5f, skipping save model',
@@ -215,7 +240,7 @@ class CheckpointHandler(EventHandler):
             else:
                 if self.verbose > 0:
                    logging.info('\nEpoch %d: saving model to %s', epoch, self.filepath)
-                self._estimator.net.save_parameters(self.filepath)
+                self.estimator.net.save_parameters(self.filepath)
 
 
 class EarlyStoppingHandler(EventHandler):
@@ -238,15 +263,14 @@ class EarlyStoppingHandler(EventHandler):
         baseline value to compare the monitored value with
     """
 
-    def __init__(self, estimator,
-                 monitor='val_loss',
+    def __init__(self,
+                 monitor='val_accuracy',
                  min_delta=0,
                  patience=0,
                  mode='auto',
                  baseline=None):
-        super(EarlyStoppingHandler, self).__init__(estimator)
+        super(EarlyStoppingHandler, self).__init__()
 
-        self._estimator = estimator
         self.monitor = monitor
         self.baseline = baseline
         self.patience = patience
@@ -284,15 +308,13 @@ class EarlyStoppingHandler(EventHandler):
             self.best = np.Inf if self.monitor_op == np.less else -np.Inf
 
     def epoch_end(self):
-        epoch = self._estimator.train_stats['epochs'][-1]
-        if self.monitor not in self._estimator.train_stats:
-            warnings.warn(RuntimeWarning('Unable to find %s in training statistics, make sure'
-                                         'you are passing one of the metric names as monitor', self.monitor))
+        epoch = self.estimator.current_epoch
+        if self.monitor not in self.estimator.train_stats:
+            warnings.warn(RuntimeWarning('Unable to find %s in training statistics, make sure the monitor value'
+                                         ' starts with `train_` or `val_` and contains a loss/metric name,'
+                                         ' for example val_accuracy', self.monitor))
         else:
-            current = self._estimator.train_stats[self.monitor][-1]
-            if current is None:
-                return
-
+            current = self.estimator.train_stats[self.monitor]
             if self.monitor_op(current - self.min_delta, self.best):
                 self.best = current
                 self.wait = 0
@@ -300,7 +322,7 @@ class EarlyStoppingHandler(EventHandler):
                 self.wait += 1
                 if self.wait >= self.patience:
                     self.stopped_epoch = epoch
-                    self._estimator.stop_training = True
+                    self.estimator.stop_training = True
 
     def train_end(self):
         if self.stopped_epoch > 0:
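
Note the inversion in this file: handlers no longer receive an estimator in
their constructors; fit() assigns one through the new estimator property
before dispatching events. A minimal sketch of that wiring (the monitor value
is just one example of the train_/val_ key convention):

    from mxnet.gluon.estimator.event_handler import EarlyStoppingHandler

    handler = EarlyStoppingHandler(monitor='val_accuracy', patience=2)
    assert handler.estimator is None  # not attached until fit() runs
    # est.fit(train_data, event_handlers=[handler], epochs=5)
    # inside fit(): handler.estimator = est, then events are dispatched
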
diff --git a/python/mxnet/gluon/trainer.py b/python/mxnet/gluon/trainer.py
index 8060f38..44e8954 100644
--- a/python/mxnet/gluon/trainer.py
+++ b/python/mxnet/gluon/trainer.py
@@ -259,6 +259,13 @@ class Trainer(object):
         else:
             return self._optimizer.learning_rate
 
+    @property
+    def optimizer(self):
+        if isinstance(self._optimizer, opt.Optimizer):
+            return self._optimizer
+        else:
+            raise UserWarning("Optimizer has not been initialized yet")
+
     def set_learning_rate(self, lr):
         """Sets a new learning rate of the optimizer.
 
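A quick illustrative check of the new read-only optimizer property (the network
and learning rate below are arbitrary): it exposes the optimizer instance so
that, for example, LoggingHandler.train_begin can log its class name.

    import mxnet as mx
    from mxnet import gluon
    from mxnet.gluon import nn

    net = nn.Dense(2)
    net.initialize()
    trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.1})
    print(trainer.optimizer.__class__.__name__)  # SGD
    print(trainer.learning_rate)                 # 0.1
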
diff --git a/tests/python/unittest/test_gluon_estimator.py b/tests/python/unittest/test_gluon_estimator.py
index 85e61ce..25a410e 100644
--- a/tests/python/unittest/test_gluon_estimator.py
+++ b/tests/python/unittest/test_gluon_estimator.py
@@ -17,14 +17,15 @@
 
 ''' Unit tests for Gluon Estimator '''
 
-import unittest
 import sys
+import unittest
 import warnings
-from nose.tools import assert_raises
+
 import mxnet as mx
 from mxnet import gluon
 from mxnet.gluon import nn
-from mxnet.gluon.estimator import estimator
+from mxnet.gluon.estimator import Estimator, EventHandler
+from nose.tools import assert_raises
 
 
 def get_model():
@@ -43,11 +44,11 @@ def test_fit():
     acc = mx.metric.Accuracy()
     net.initialize(ctx=ctx)
     trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.001})
-    est = estimator.Estimator(net=net,
-                              loss=loss,
-                              metrics=acc,
-                              trainers=trainer,
-                              context=ctx)
+    est = Estimator(net=net,
+                    loss=loss,
+                    metrics=acc,
+                    trainer=trainer,
+                    context=ctx)
     in_data = mx.nd.random.uniform(shape=(10, 3))
     out_data = mx.nd.random.uniform(shape=(10, 4))
     # Input dataloader
@@ -80,11 +81,11 @@ def test_validation():
     acc = mx.metric.Accuracy()
     net.initialize(ctx=ctx)
     trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.001})
-    est = estimator.Estimator(net=net,
-                              loss=loss,
-                              metrics=acc,
-                              trainers=trainer,
-                              context=ctx)
+    est = Estimator(net=net,
+                    loss=loss,
+                    metrics=acc,
+                    trainer=trainer,
+                    context=ctx)
     in_data = mx.nd.random.uniform(shape=(10, 3))
     out_data = mx.nd.random.uniform(shape=(10, 4))
     # Input dataloader
@@ -125,10 +126,10 @@ def test_initializer():
     loss = gluon.loss.L2Loss()
     acc = mx.metric.Accuracy()
     # no initializer
-    est = estimator.Estimator(net=net,
-                              loss=loss,
-                              metrics=acc,
-                              context=ctx)
+    est = Estimator(net=net,
+                    loss=loss,
+                    metrics=acc,
+                    context=ctx)
     est.fit(train_data=train_data,
             epochs=num_epochs,
             batch_size=batch_size)
@@ -139,12 +140,12 @@ def test_initializer():
     trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.001})
     # catch reinit warning
     with warnings.catch_warnings(record=True) as w:
-        est = estimator.Estimator(net=net,
-                                  loss=loss,
-                                  metrics=acc,
-                                  initializer=mx.init.MSRAPrelu(),
-                                  trainers=trainer,
-                                  context=ctx)
+        est = Estimator(net=net,
+                        loss=loss,
+                        metrics=acc,
+                        initializer=mx.init.MSRAPrelu(),
+                        trainer=trainer,
+                        context=ctx)
         assert 'Network already initialized' in str(w[-1].message)
     est.fit(train_data=train_data,
             epochs=num_epochs,
@@ -167,10 +168,10 @@ def test_trainer():
     net.initialize(ctx=ctx)
     # input no trainer
     with warnings.catch_warnings(record=True) as w:
-        est = estimator.Estimator(net=net,
-                                  loss=loss,
-                                  metrics=acc,
-                                  context=ctx)
+        est = Estimator(net=net,
+                        loss=loss,
+                        metrics=acc,
+                        context=ctx)
         assert 'No trainer specified' in str(w[-1].message)
     est.fit(train_data=train_data,
             epochs=num_epochs,
@@ -179,11 +180,11 @@ def test_trainer():
     # input invalid trainer
     trainer = 'sgd'
     with assert_raises(ValueError):
-        est = estimator.Estimator(net=net,
-                                  loss=loss,
-                                  metrics=acc,
-                                  trainers=trainer,
-                                  context=ctx)
+        est = Estimator(net=net,
+                        loss=loss,
+                        metrics=acc,
+                        trainer=trainer,
+                        context=ctx)
 
 
 def test_metric():
@@ -200,59 +201,54 @@ def test_metric():
     net.initialize(ctx=ctx)
     trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.001})
     # input no metric
-    est = estimator.Estimator(net=net,
-                              loss=loss,
-                              trainers=trainer,
-                              context=ctx)
+    est = Estimator(net=net,
+                    loss=loss,
+                    trainer=trainer,
+                    context=ctx)
     est.fit(train_data=train_data,
             epochs=num_epochs,
             batch_size=batch_size)
     # input list of metrics
     metrics = [mx.metric.Accuracy(), mx.metric.Accuracy()]
-    est = estimator.Estimator(net=net,
-                              loss=loss,
-                              metrics=metrics,
-                              trainers=trainer,
-                              context=ctx)
+    est = Estimator(net=net,
+                    loss=loss,
+                    metrics=metrics,
+                    trainer=trainer,
+                    context=ctx)
     est.fit(train_data=train_data,
             epochs=num_epochs,
             batch_size=batch_size)
     # input invalid metric
     with assert_raises(ValueError):
-        est = estimator.Estimator(net=net,
-                                  loss=loss,
-                                  metrics='acc',
-                                  trainers=trainer,
-                                  context=ctx)
+        est = Estimator(net=net,
+                        loss=loss,
+                        metrics='acc',
+                        trainer=trainer,
+                        context=ctx)
     # test default metric
     loss = gluon.loss.SoftmaxCrossEntropyLoss()
-    est = estimator.Estimator(net=net,
-                              loss=loss,
-                              trainers=trainer,
-                              context=ctx)
+    est = Estimator(net=net,
+                    loss=loss,
+                    trainer=trainer,
+                    context=ctx)
     assert isinstance(est.train_metrics[0], mx.metric.Accuracy)
 
 
 def test_loss():
-    ''' test with no loss, invalid loss '''
+    ''' test with invalid loss '''
     net = get_model()
     ctx = mx.cpu()
     acc = mx.metric.Accuracy()
     net.initialize(ctx=ctx)
     trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.001})
-    # input no loss
-    with assert_raises(ValueError):
-        est = estimator.Estimator(net=net,
-                                  trainers=trainer,
-                                  metrics=acc,
-                                  context=ctx)
     # input invalid loss
     with assert_raises(ValueError):
-        est = estimator.Estimator(net=net,
-                                  loss='mse',
-                                  metrics=acc,
-                                  trainers=trainer,
-                                  context=ctx)
+        est = Estimator(net=net,
+                        loss='mse',
+                        metrics=acc,
+                        trainer=trainer,
+                        context=ctx)
+
 
 def test_context():
     ''' test with no context, list of context, invalid context '''
@@ -260,18 +256,69 @@ def test_context():
     loss = gluon.loss.L2Loss()
     metrics = mx.metric.Accuracy()
     # input no context
-    est = estimator.Estimator(net=net,
-                              loss=loss,
-                              metrics=metrics)
+    est = Estimator(net=net,
+                    loss=loss,
+                    metrics=metrics)
     # input list of context
     ctx = [mx.gpu(0), mx.gpu(1)]
-    est = estimator.Estimator(net=net,
-                              loss=loss,
-                              metrics=metrics,
-                              context=ctx)
+    est = Estimator(net=net,
+                    loss=loss,
+                    metrics=metrics,
+                    context=ctx)
     # input invalid context
     with assert_raises(ValueError):
-        est = estimator.Estimator(net=net,
-                                  loss=loss,
-                                  metrics=metrics,
-                                  context='cpu')
+        est = Estimator(net=net,
+                        loss=loss,
+                        metrics=metrics,
+                        context='cpu')
+
+
+def test_categorize_handlers():
+    class CustomHandler1(EventHandler):
+        def __init__(self):
+            super(CustomHandler1, self).__init__()
+
+        def train_begin(self):
+            print("custom train begin")
+
+    class CustomHandler2(EventHandler):
+        def __init__(self):
+            super(CustomHandler2, self).__init__()
+
+        def epoch_begin(self):
+            print("custom epoch begin")
+
+        def batch_begin(self):
+            print("custom batch begin")
+
+        def train_end(self):
+            print("custom train end")
+
+    class CustomHandler3(EventHandler):
+        def __init__(self):
+            super(CustomHandler3, self).__init__()
+
+        def epoch_begin(self):
+            print("custom epoch begin")
+
+        def batch_begin(self):
+            print("custom batch begin")
+
+        def batch_end(self):
+            print("custom batch end")
+
+        def train_end(self):
+            print("custom train end")
+
+    net = nn.Sequential()
+    net.add(nn.Dense(10))
+    loss = gluon.loss.SoftmaxCrossEntropyLoss()
+    est = Estimator(net, loss=loss)
+    event_handlers = [CustomHandler1(), CustomHandler2(), CustomHandler3()]
+    train_begin, epoch_begin, batch_begin, \
+    batch_end, epoch_end, train_end = est._categorize_handlers(event_handlers)
+    assert len(train_begin) == 1
+    assert len(epoch_begin) == 2
+    assert len(batch_begin) == 2
+    assert len(batch_end) == 1
+    assert len(train_end) == 2
diff --git a/tests/python/unittest/test_gluon_event_handler.py b/tests/python/unittest/test_gluon_event_handler.py
index a551594..ccbcb54 100644
--- a/tests/python/unittest/test_gluon_event_handler.py
+++ b/tests/python/unittest/test_gluon_event_handler.py
@@ -45,7 +45,7 @@ def test_checkpoint_handler():
     ce_loss = loss.SoftmaxCrossEntropyLoss()
     acc = mx.metric.Accuracy()
     est = estimator.Estimator(net, loss=ce_loss, metrics=acc)
-    checkpoint_handler = [event_handler.CheckpointHandler(est, file_path,
+    checkpoint_handler = [event_handler.CheckpointHandler(file_path,
                                                           save_best_only=save_best_only,
                                                           mode=mode)]
     est.fit(test_data, event_handlers=checkpoint_handler, epochs=1)
@@ -63,15 +63,15 @@ def test_early_stopping():
     ce_loss = loss.SoftmaxCrossEntropyLoss()
     acc = mx.metric.Accuracy()
     est = estimator.Estimator(net, loss=ce_loss, metrics=acc)
-    early_stopping = [event_handler.EarlyStoppingHandler(est, monitor,
+    early_stopping = [event_handler.EarlyStoppingHandler(monitor,
                                                          patience=patience,
-                                                          mode=mode)]
-    est.fit(test_data, event_handlers=early_stopping, epochs=1)
+                                                         mode=mode)]
+    est.fit(test_data, event_handlers=early_stopping, epochs=3)
 
     mode = 'auto'
     monitor = 'train_accuracy'
     patience = 2
-    early_stopping = [event_handler.EarlyStoppingHandler(est, monitor,
+    early_stopping = [event_handler.EarlyStoppingHandler(monitor,
                                                          patience=patience,
                                                           mode=mode)]
     est.fit(test_data, event_handlers=early_stopping, epochs=1)
@@ -86,7 +86,7 @@ def test_logging():
     ce_loss = loss.SoftmaxCrossEntropyLoss()
     acc = mx.metric.Accuracy()
     est = estimator.Estimator(net, loss=ce_loss, metrics=acc)
-    logging_handler = [event_handler.LoggingHandler(est, file_name=file_name, file_location=tmpdir)]
+    logging_handler = [event_handler.LoggingHandler(file_name=file_name, file_location=tmpdir)]
     est.fit(test_data, event_handlers=logging_handler, epochs=1)
     assert os.path.isfile(output_dir)
     os.remove(output_dir)
\ No newline at end of file
