This is an automated email from the ASF dual-hosted git repository.
nswamy pushed a commit to branch fit-api
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/fit-api by this push:
new ed7f6e5 [MXNet-1340][Fit API]Update train stats (#14494)
ed7f6e5 is described below
commit ed7f6e56a4e372d5d460031186145065f5657893
Author: Lai Wei <[email protected]>
AuthorDate: Wed Apr 3 14:18:13 2019 -0700
[MXNet-1340][Fit API]Update train stats (#14494)
* add train history
* update history
* update test
* avoid calling empty methods
* remove train history object
* fix pylint
* add unit test
* fix test
* update categorize handlers
---
python/mxnet/gluon/estimator/estimator.py | 147 +++++++++-------
python/mxnet/gluon/estimator/event_handler.py | 102 +++++++-----
python/mxnet/gluon/trainer.py | 7 +
tests/python/unittest/test_gluon_estimator.py | 193 ++++++++++++++--------
tests/python/unittest/test_gluon_event_handler.py | 12 +-
5 files changed, 280 insertions(+), 181 deletions(-)
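
For context, a minimal sketch of the usage pattern after this change, assuming this commit's API (the toy network and data below are illustrative, adapted from the unit tests): event handlers are now constructed without an estimator argument, and a single `trainer` replaces the old `trainers` list.

    import mxnet as mx
    from mxnet import gluon
    from mxnet.gluon import nn
    from mxnet.gluon.estimator import Estimator
    from mxnet.gluon.estimator.event_handler import LoggingHandler

    # toy network and data, as in the unit tests
    net = nn.Sequential()
    net.add(nn.Dense(4))
    net.initialize(ctx=mx.cpu())
    loss = gluon.loss.L2Loss()
    trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.001})
    x = mx.nd.random.uniform(shape=(10, 3))
    y = mx.nd.random.uniform(shape=(10, 4))
    train_data = gluon.data.DataLoader(gluon.data.ArrayDataset(x, y), batch_size=5)

    est = Estimator(net=net, loss=loss, metrics=mx.metric.Accuracy(),
                    trainer=trainer, context=mx.cpu())
    # handlers no longer take the estimator in their constructors; fit()
    # assigns it through the new `estimator` property before dispatching events
    est.fit(train_data, event_handlers=[LoggingHandler()], epochs=1)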
diff --git a/python/mxnet/gluon/estimator/estimator.py b/python/mxnet/gluon/estimator/estimator.py
index e759fa7..c5da0c0 100644
--- a/python/mxnet/gluon/estimator/estimator.py
+++ b/python/mxnet/gluon/estimator/estimator.py
@@ -22,7 +22,7 @@
import copy
import warnings
-from .event_handler import LoggingHandler
+from .event_handler import EventHandler, LoggingHandler
from ... import gluon, autograd
from ...context import Context, cpu, gpu, num_gpus
from ...io import DataIter
@@ -39,27 +39,26 @@ class Estimator(object):
Parameters
----------
- loss : Loss or list of Loss
+ loss : gluon.loss.Loss or list of gluon.loss.Loss
Loss(objective functions) to calculate during training
metrics : EvalMetric or list of EvalMetric
Metrics for evaluating models
initializer : Initializer
initializer to initialize the network
- trainers : Trainer or list of Trainer
- Trainers to apply optimizers on network parameters
+ trainer : Trainer
+ Trainer to apply optimizer on network parameters
context : Context or list of Context
devices to run the training on
"""
def __init__(self, net,
- loss=None,
+ loss,
metrics=None,
initializer=None,
- trainers=None,
+ trainer=None,
context=None):
self.net = net
- self.stop_training = False
if isinstance(loss, gluon.loss.Loss):
self.loss = [loss]
@@ -86,27 +85,14 @@ class Estimator(object):
# store training statistics
self.train_stats = {}
- self.train_stats['epochs'] = []
- self.train_stats['learning_rate'] = []
- # current step of the epoch
- self.train_stats['step'] = ''
- for metric in self.train_metrics:
- # record a history of metrics over each epoch
- self.train_stats['train_' + metric.name] = []
- # only record the latest metric numbers after each batch
- self.train_stats['batch_' + metric.name] = 0.
- for metric in self.val_metrics:
- self.train_stats['val_' + metric.name] = []
+
+ # separate train and validation
self.train_loss_metrics = []
self.val_loss_metrics = []
# using the metric wrapper for loss to record loss value
for l in self.loss:
self.train_loss_metrics.append(Loss(l.name))
self.val_loss_metrics.append(Loss(l.name))
- self.train_stats['train_' + l.name] = []
- self.train_stats['val_' + l.name] = []
- # only record the latest loss numbers after each batch
- self.train_stats['batch_' + l.name] = 0.
# handle context
if isinstance(context, Context):
@@ -127,7 +113,6 @@ class Estimator(object):
raise ValueError("context must be a Context or a list of Context, "
"refer to mxnet.Context:{}".format(context))
-
# initialize the network
self.initializer = initializer
if self.initializer:
@@ -135,7 +120,7 @@ class Estimator(object):
# if already initialized, re-init with user specified initializer
warnings.warn("Network already initialized, re-initializing with %s. "
"You don't need to pass initializer if you already "
- "initialized your net."% type(self.initializer).__name__)
+ "initialized your net." % type(self.initializer).__name__)
self.net.initialize(init=self.initializer, ctx=self.context, force_reinit=True)
else:
# initialize with user specified initializer
@@ -144,16 +129,17 @@ class Estimator(object):
if not self._is_initialized():
self.net.initialize(ctx=self.context)
- # handle trainers
- if isinstance(trainers, gluon.Trainer):
- self.trainers = [trainers]
- elif not trainers:
+ # handle trainer
+ if not trainer:
warnings.warn("No trainer specified, default SGD optimizer "
"with learning rate 0.001 is used.")
- self.trainers = [gluon.Trainer(self.net.collect_params(),
- 'sgd', {'learning_rate': 0.001})]
+ self.trainer = gluon.Trainer(self.net.collect_params(),
+ 'sgd', {'learning_rate': 0.001})
+ elif not isinstance(trainer, gluon.Trainer):
+ raise ValueError("Trainer must be a Gluon Trainer instance, refer
to "
+ "gluon.Trainer:{}".format(trainer))
else:
- raise ValueError("Invalid trainer specified, please provide a
valid gluon.Trainer")
+ self.trainer = trainer
def _is_initialized(self):
param_dict = self.net.collect_params()
@@ -212,8 +198,12 @@ class Estimator(object):
# update metrics
for metric in self.val_metrics:
metric.update(label, pred)
+ name, value = metric.get()
+ self.train_stats['val_' + name] = value
for loss, loss_metric, in zip(losses, self.val_loss_metrics):
loss_metric.update(0, [l for l in loss])
+ name, value = loss_metric.get()
+ self.train_stats['val_' + name] = value
def fit(self, train_data,
val_data=None,
@@ -241,27 +231,38 @@ class Estimator(object):
from a data batch and load into contexts(devices)
"""
-
- self.epochs = epochs
+ self.max_epoch = epochs
if not batch_size:
- batch_size = 32 * len(self.context)
+ self.batch_size = 32 * len(self.context)
+ else:
+ self.batch_size = batch_size
+ self.stop_training = False
+ self.samples = None
+ self.batch_idx = 0
event_handlers = event_handlers or []
# provide default logging handler
if not event_handlers or \
not any(isinstance(handler, LoggingHandler) for handler in event_handlers):
- event_handlers.append(LoggingHandler(self))
+ event_handlers.append(LoggingHandler())
- # training begin
+ train_begin, epoch_begin, batch_begin, \
+ batch_end, epoch_end, train_end = self._categorize_handlers(event_handlers)
+
+ # passing estimator to event handlers so they can access estimator information
+ # when an event is triggered
for handler in event_handlers:
+ handler.estimator = self
+
+ # training begin
+ for handler in train_begin:
handler.train_begin()
- for epoch in range(epochs):
+ for epoch in range(self.max_epoch):
# epoch begin
- self.train_stats['epochs'].append(epoch)
- self.train_stats['learning_rate'].append(self.trainers[0].learning_rate)
+ self.current_epoch = epoch
- for handler in event_handlers:
+ for handler in epoch_begin:
handler.epoch_begin()
for metric in self.train_metrics + self.train_loss_metrics:
@@ -282,7 +283,7 @@ class Estimator(object):
data, label = batch_fn(batch, self.context)
# batch begin
- for handler in event_handlers:
+ for handler in batch_begin:
handler.batch_begin()
with autograd.record():
@@ -298,42 +299,64 @@ class Estimator(object):
# update train metrics
for metric in self.train_metrics:
metric.update(label, pred)
- self.train_stats['batch_' + metric.name] = metric.get()[1]
+ # get metric name and current value and update train stats
+ name, value = metric.get()
+ self.train_stats['train_' + name] = value
+
+ # update loss
for loss, loss_metric, in zip(losses, self.train_loss_metrics):
loss_metric.update(0, [l for l in loss])
- self.train_stats['batch_' + loss_metric.name] = loss_metric.get()[1]
-
- try:
- completed_samples = len(train_data._dataset) if i == len(train_data._dataset) - 1 \
- else batch_size * (i + 1)
- # We need to check if this is the last batch in the current epoch and select
- # the value to print appropriately
- self.train_stats['step'] = "{}/{}".format(completed_samples, len(train_data._dataset))
- except AttributeError:
- self.train_stats['step'] = i
+ name, value = loss_metric.get()
+ self.train_stats['train_' + name] = value
- for trainer in self.trainers:
- trainer.step(batch_size)
+ self.batch_idx = i
+ # record trained samples vs. total samples if using Gluon DataLoader
+ if isinstance(train_data, gluon.data.DataLoader):
+ self.samples = "{}/{}".format(self.batch_size * (i + 1), len(train_data._dataset))
+ self.trainer.step(self.batch_size)
# batch end
- for handler in event_handlers:
+ for handler in batch_end:
handler.batch_end()
if val_data:
self.evaluate(val_data, batch_fn)
- for metric in self.train_metrics + self.train_loss_metrics:
- self.train_stats['train_' + metric.name].append(metric.get()[1])
- for metric in self.val_metrics + self.val_loss_metrics:
- self.train_stats['val_' + metric.name].append(metric.get()[1])
-
# epoch end
- for handler in event_handlers:
+ for handler in epoch_end:
handler.epoch_end()
if self.stop_training:
break
# train end
- for handler in event_handlers:
+ for handler in train_end:
handler.train_end()
+
+ def _categorize_handlers(self, event_handlers):
+ """
+ categorize handlers into 6 event lists to avoid calling empty methods
+ for example, only event handlers with train_begin method
+ implemented will be called at train begin
+ """
+
+ train_begin = []
+ epoch_begin = []
+ batch_begin = []
+ batch_end = []
+ epoch_end = []
+ train_end = []
+ for handler in event_handlers:
+ if not handler.__class__.train_begin == EventHandler.train_begin:
+ train_begin.append(handler)
+ if not handler.__class__.epoch_begin == EventHandler.epoch_begin:
+ epoch_begin.append(handler)
+ if not handler.__class__.batch_begin == EventHandler.batch_begin:
+ batch_begin.append(handler)
+ if not handler.__class__.batch_end == EventHandler.batch_end:
+ batch_end.append(handler)
+ if not handler.__class__.epoch_end == EventHandler.epoch_end:
+ epoch_end.append(handler)
+ if not handler.__class__.train_end == EventHandler.train_end:
+ train_end.append(handler)
+ return train_begin, epoch_begin, batch_begin, batch_end, epoch_end, train_end
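
The categorization above detects overrides by comparing each handler's class attribute against the base EventHandler's; a standalone sketch of that idiom (the class names here are hypothetical):

    class Base(object):
        def train_begin(self):
            pass
        def epoch_begin(self):
            pass

    class OnlyTrainBegin(Base):
        def train_begin(self):
            print("called at train begin")

    handler = OnlyTrainBegin()
    # overridden: the subclass attribute differs from the base class's function
    assert not handler.__class__.train_begin == Base.train_begin
    # inherited: both names resolve to the same base-class function
    assert handler.__class__.epoch_begin == Base.epoch_begin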
diff --git a/python/mxnet/gluon/estimator/event_handler.py b/python/mxnet/gluon/estimator/event_handler.py
index c59644e..7810074 100644
--- a/python/mxnet/gluon/estimator/event_handler.py
+++ b/python/mxnet/gluon/estimator/event_handler.py
@@ -40,7 +40,16 @@ class EventHandler(object):
estimator : Estimator
The :py:class:`Estimator` to get training statistics
"""
- def __init__(self, estimator):
+
+ def __init__(self):
+ self._estimator = None
+
+ @property
+ def estimator(self):
+ return self._estimator
+
+ @estimator.setter
+ def estimator(self, estimator):
self._estimator = estimator
def train_begin(self):
@@ -78,8 +87,8 @@ class LoggingHandler(EventHandler):
file location to save the logs
"""
- def __init__(self, estimator, file_name=None, file_location=None, ):
- super(LoggingHandler, self).__init__(estimator)
+ def __init__(self, file_name=None, file_location=None):
+ super(LoggingHandler, self).__init__()
self.logger = logging.getLogger(__name__)
self.logger.setLevel(logging.INFO)
stream_handler = logging.StreamHandler()
@@ -92,22 +101,37 @@ class LoggingHandler(EventHandler):
self.logger.addHandler(file_handler)
def train_begin(self):
- pass
+ self.train_start = time.time()
+ self.logger.info("Training begin: using optimizer %s "
+ "with current learning rate %.4f ",
+ self.estimator.trainer.optimizer.__class__.__name__,
+ self.estimator.trainer.learning_rate)
+ self.logger.info("Train for %d epochs.", self.estimator.max_epoch)
def train_end(self):
- pass
+ train_time = time.time() - self.train_start
+ epoch = self.estimator.current_epoch
+ msg = 'Train finished using total %ds at epoch %d. ' % (train_time, epoch)
+ # log every result in train stats including train/validation loss & metrics
+ for key in self.estimator.train_stats:
+ msg += '%s : %.4f ' % (key, self.estimator.train_stats[key])
+ self.logger.info(msg)
def batch_begin(self):
self.batch_start = time.time()
def batch_end(self):
batch_time = time.time() - self.batch_start
- epoch = self._estimator.train_stats['epochs'][-1]
- step = self._estimator.train_stats['step']
- msg = '[Epoch %d] [Step %s] time/step: %.3fs ' % (epoch, step, batch_time)
- for key in self._estimator.train_stats.keys():
- if key.startswith('batch_'):
- msg += key[6:] + ': ' + '%.4f ' % self._estimator.train_stats[key]
+ epoch = self.estimator.current_epoch
+ batch = self.estimator.batch_idx
+ msg = '[Epoch %d] [Batch %d] ' % (epoch, batch)
+ if self.estimator.samples:
+ msg += '[Samples %s] ' % (self.estimator.samples)
+ msg += 'time/batch: %.3fs ' % batch_time
+ for key in self.estimator.train_stats:
+ # only log current training loss & metric after each batch
+ if key.startswith('train_'):
+ msg += key + ': ' + '%.4f ' % self.estimator.train_stats[key]
self.logger.info(msg)
def epoch_begin(self):
@@ -115,11 +139,11 @@ class LoggingHandler(EventHandler):
def epoch_end(self):
epoch_time = time.time() - self.epoch_start
- epoch = self._estimator.train_stats['epochs'][-1]
+ epoch = self.estimator.current_epoch
msg = '\n[Epoch %d] finished in %.3fs: ' % (epoch, epoch_time)
- for key in self._estimator.train_stats.keys():
- if key.startswith('train_') or key.startswith('val_'):
- msg += key + ': ' + '%.4f ' % self._estimator.train_stats[key][epoch]
+ # log every result in train stats including train/validation loss & metrics
+ for key in self.estimator.train_stats:
+ msg += '%s : %.4f ' % (key, self.estimator.train_stats[key])
self.logger.info(msg)
@@ -148,14 +172,14 @@ class CheckpointHandler(EventHandler):
intervals between saving the network
"""
- def __init__(self, estimator,
+ def __init__(self,
filepath,
- monitor='val_loss',
+ monitor='val_accuracy',
verbose=0,
save_best_only=False,
mode='auto',
period=1):
- super(CheckpointHandler, self).__init__(estimator)
+ super(CheckpointHandler, self).__init__()
self.monitor = monitor
self.verbose = verbose
self.filepath = filepath
@@ -186,7 +210,7 @@ class CheckpointHandler(EventHandler):
self.best = np.Inf
def epoch_end(self, ):
- epoch = self._estimator.train_stats['epochs'][-1]
+ epoch = self.estimator.current_epoch
# add extension for weights
if '.params' not in self.filepath:
self.filepath += '.params'
@@ -194,20 +218,21 @@ class CheckpointHandler(EventHandler):
if self.epochs_since_last_save >= self.period:
self.epochs_since_last_save = 0
if self.save_best_only:
- # check if monitor exists in train_stats
- if self.monitor not in self._estimator.train_stats:
- warnings.warn(RuntimeWarning('Unable to find %s in training statistics, make sure'
- 'you are passing one of the metric names as monitor', self.monitor))
- self._estimator.net.save_parameters(self.filepath)
+ # check if monitor exists in train stats
+ if self.monitor not in self.estimator.train_stats:
+ warnings.warn(RuntimeWarning('Unable to find %s in training statistics, make sure the monitor value '
+ 'starts with `train_` or `val_` and contains loss/metric name, '
+ 'for example val_accuracy', self.monitor))
+ self.estimator.net.save_parameters(self.filepath)
else:
- current = self._estimator.train_stats[self.monitor][-1]
+ current = self.estimator.train_stats[self.monitor]
if self.monitor_op(current, self.best):
if self.verbose > 0:
self.logger.info('\n[Epoch %d] %s improved from %0.5f to %0.5f,'
' saving model to %s',
epoch, self.monitor, self.best, current, self.filepath)
self.best = current
- self._estimator.net.save_parameters(self.filepath)
+ self.estimator.net.save_parameters(self.filepath)
else:
if self.verbose > 0:
self.logger.info('\n[Epoch %d] %s did not improve from %0.5f, skipping save model',
@@ -215,7 +240,7 @@ class CheckpointHandler(EventHandler):
else:
if self.verbose > 0:
logging.info('\nEpoch %d: saving model to %s', epoch, self.filepath)
- self._estimator.net.save_parameters(self.filepath)
+ self.estimator.net.save_parameters(self.filepath)
class EarlyStoppingHandler(EventHandler):
@@ -238,15 +263,14 @@ class EarlyStoppingHandler(EventHandler):
baseline value to compare the monitored value with
"""
- def __init__(self, estimator,
- monitor='val_loss',
+ def __init__(self,
+ monitor='val_accuracy',
min_delta=0,
patience=0,
mode='auto',
baseline=None):
- super(EarlyStoppingHandler, self).__init__(estimator)
+ super(EarlyStoppingHandler, self).__init__()
- self._estimator = estimator
self.monitor = monitor
self.baseline = baseline
self.patience = patience
@@ -284,15 +308,13 @@ class EarlyStoppingHandler(EventHandler):
self.best = np.Inf if self.monitor_op == np.less else -np.Inf
def epoch_end(self):
- epoch = self._estimator.train_stats['epochs'][-1]
- if self.monitor not in self._estimator.train_stats:
- warnings.warn(RuntimeWarning('Unable to find %s in training statistics, make sure'
- 'you are passing one of the metric names as monitor', self.monitor))
+ epoch = self.estimator.current_epoch
+ if self.monitor not in self.estimator.train_stats:
+ warnings.warn(RuntimeWarning('Unable to find %s in training statistics, make sure the monitor value '
+ 'starts with `train_` or `val_` and contains loss/metric name, '
+ 'for example val_accuracy', self.monitor))
else:
- current = self._estimator.train_stats[self.monitor][-1]
- if current is None:
- return
-
+ current = self.estimator.train_stats[self.monitor]
if self.monitor_op(current - self.min_delta, self.best):
self.best = current
self.wait = 0
@@ -300,7 +322,7 @@ class EarlyStoppingHandler(EventHandler):
self.wait += 1
if self.wait >= self.patience:
self.stopped_epoch = epoch
- self._estimator.stop_training = True
+ self.estimator.stop_training = True
def train_end(self):
if self.stopped_epoch > 0:
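
Under the new contract a custom handler needs no constructor arguments; fit() injects the estimator through the property setter, so overrides can reach training state directly. A hypothetical example:

    from mxnet.gluon.estimator import EventHandler

    class StopAfterFirstEpoch(EventHandler):
        # epoch_end is the only override, so _categorize_handlers will
        # invoke this handler at epoch end and nowhere else
        def epoch_end(self):
            self.estimator.stop_training = True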
diff --git a/python/mxnet/gluon/trainer.py b/python/mxnet/gluon/trainer.py
index 8060f38..44e8954 100644
--- a/python/mxnet/gluon/trainer.py
+++ b/python/mxnet/gluon/trainer.py
@@ -259,6 +259,13 @@ class Trainer(object):
else:
return self._optimizer.learning_rate
+ @property
+ def optimizer(self):
+ if isinstance(self._optimizer, opt.Optimizer):
+ return self._optimizer
+ else:
+ raise UserWarning("Optimizer has not been initialized yet")
+
def set_learning_rate(self, lr):
"""Sets a new learning rate of the optimizer.
diff --git a/tests/python/unittest/test_gluon_estimator.py b/tests/python/unittest/test_gluon_estimator.py
index 85e61ce..25a410e 100644
--- a/tests/python/unittest/test_gluon_estimator.py
+++ b/tests/python/unittest/test_gluon_estimator.py
@@ -17,14 +17,15 @@
''' Unit tests for Gluon Estimator '''
-import unittest
import sys
+import unittest
import warnings
-from nose.tools import assert_raises
+
import mxnet as mx
from mxnet import gluon
from mxnet.gluon import nn
-from mxnet.gluon.estimator import estimator
+from mxnet.gluon.estimator import Estimator, EventHandler
+from nose.tools import assert_raises
def get_model():
@@ -43,11 +44,11 @@ def test_fit():
acc = mx.metric.Accuracy()
net.initialize(ctx=ctx)
trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.001})
- est = estimator.Estimator(net=net,
- loss=loss,
- metrics=acc,
- trainers=trainer,
- context=ctx)
+ est = Estimator(net=net,
+ loss=loss,
+ metrics=acc,
+ trainer=trainer,
+ context=ctx)
in_data = mx.nd.random.uniform(shape=(10, 3))
out_data = mx.nd.random.uniform(shape=(10, 4))
# Input dataloader
@@ -80,11 +81,11 @@ def test_validation():
acc = mx.metric.Accuracy()
net.initialize(ctx=ctx)
trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.001})
- est = estimator.Estimator(net=net,
- loss=loss,
- metrics=acc,
- trainers=trainer,
- context=ctx)
+ est = Estimator(net=net,
+ loss=loss,
+ metrics=acc,
+ trainer=trainer,
+ context=ctx)
in_data = mx.nd.random.uniform(shape=(10, 3))
out_data = mx.nd.random.uniform(shape=(10, 4))
# Input dataloader
@@ -125,10 +126,10 @@ def test_initializer():
loss = gluon.loss.L2Loss()
acc = mx.metric.Accuracy()
# no initializer
- est = estimator.Estimator(net=net,
- loss=loss,
- metrics=acc,
- context=ctx)
+ est = Estimator(net=net,
+ loss=loss,
+ metrics=acc,
+ context=ctx)
est.fit(train_data=train_data,
epochs=num_epochs,
batch_size=batch_size)
@@ -139,12 +140,12 @@ def test_initializer():
trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.001})
# catch reinit warning
with warnings.catch_warnings(record=True) as w:
- est = estimator.Estimator(net=net,
- loss=loss,
- metrics=acc,
- initializer=mx.init.MSRAPrelu(),
- trainers=trainer,
- context=ctx)
+ est = Estimator(net=net,
+ loss=loss,
+ metrics=acc,
+ initializer=mx.init.MSRAPrelu(),
+ trainer=trainer,
+ context=ctx)
assert 'Network already initialized' in str(w[-1].message)
est.fit(train_data=train_data,
epochs=num_epochs,
@@ -167,10 +168,10 @@ def test_trainer():
net.initialize(ctx=ctx)
# input no trainer
with warnings.catch_warnings(record=True) as w:
- est = estimator.Estimator(net=net,
- loss=loss,
- metrics=acc,
- context=ctx)
+ est = Estimator(net=net,
+ loss=loss,
+ metrics=acc,
+ context=ctx)
assert 'No trainer specified' in str(w[-1].message)
est.fit(train_data=train_data,
epochs=num_epochs,
@@ -179,11 +180,11 @@ def test_trainer():
# input invalid trainer
trainer = 'sgd'
with assert_raises(ValueError):
- est = estimator.Estimator(net=net,
- loss=loss,
- metrics=acc,
- trainers=trainer,
- context=ctx)
+ est = Estimator(net=net,
+ loss=loss,
+ metrics=acc,
+ trainer=trainer,
+ context=ctx)
def test_metric():
@@ -200,59 +201,54 @@ def test_metric():
net.initialize(ctx=ctx)
trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.001})
# input no metric
- est = estimator.Estimator(net=net,
- loss=loss,
- trainers=trainer,
- context=ctx)
+ est = Estimator(net=net,
+ loss=loss,
+ trainer=trainer,
+ context=ctx)
est.fit(train_data=train_data,
epochs=num_epochs,
batch_size=batch_size)
# input list of metrics
metrics = [mx.metric.Accuracy(), mx.metric.Accuracy()]
- est = estimator.Estimator(net=net,
- loss=loss,
- metrics=metrics,
- trainers=trainer,
- context=ctx)
+ est = Estimator(net=net,
+ loss=loss,
+ metrics=metrics,
+ trainer=trainer,
+ context=ctx)
est.fit(train_data=train_data,
epochs=num_epochs,
batch_size=batch_size)
# input invalid metric
with assert_raises(ValueError):
- est = estimator.Estimator(net=net,
- loss=loss,
- metrics='acc',
- trainers=trainer,
- context=ctx)
+ est = Estimator(net=net,
+ loss=loss,
+ metrics='acc',
+ trainer=trainer,
+ context=ctx)
# test default metric
loss = gluon.loss.SoftmaxCrossEntropyLoss()
- est = estimator.Estimator(net=net,
- loss=loss,
- trainers=trainer,
- context=ctx)
+ est = Estimator(net=net,
+ loss=loss,
+ trainer=trainer,
+ context=ctx)
assert isinstance(est.train_metrics[0], mx.metric.Accuracy)
def test_loss():
- ''' test with no loss, invalid loss '''
+ ''' test with invalid loss '''
net = get_model()
ctx = mx.cpu()
acc = mx.metric.Accuracy()
net.initialize(ctx=ctx)
trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.001})
- # input no loss
- with assert_raises(ValueError):
- est = estimator.Estimator(net=net,
- trainers=trainer,
- metrics=acc,
- context=ctx)
# input invalid loss
with assert_raises(ValueError):
- est = estimator.Estimator(net=net,
- loss='mse',
- metrics=acc,
- trainers=trainer,
- context=ctx)
+ est = Estimator(net=net,
+ loss='mse',
+ metrics=acc,
+ trainer=trainer,
+ context=ctx)
+
def test_context():
''' test with no context, list of context, invalid context '''
@@ -260,18 +256,69 @@ def test_context():
loss = gluon.loss.L2Loss()
metrics = mx.metric.Accuracy()
# input no context
- est = estimator.Estimator(net=net,
- loss=loss,
- metrics=metrics)
+ est = Estimator(net=net,
+ loss=loss,
+ metrics=metrics)
# input list of context
ctx = [mx.gpu(0), mx.gpu(1)]
- est = estimator.Estimator(net=net,
- loss=loss,
- metrics=metrics,
- context=ctx)
+ est = Estimator(net=net,
+ loss=loss,
+ metrics=metrics,
+ context=ctx)
# input invalid context
with assert_raises(ValueError):
- est = estimator.Estimator(net=net,
- loss=loss,
- metrics=metrics,
- context='cpu')
+ est = Estimator(net=net,
+ loss=loss,
+ metrics=metrics,
+ context='cpu')
+
+
+def test_categorize_handlers():
+ class CustomHandler1(EventHandler):
+ def __init__(self):
+ super(CustomHandler1, self).__init__()
+
+ def train_begin(self):
+ print("custom train begin")
+
+ class CustomHandler2(EventHandler):
+ def __init__(self):
+ super(CustomHandler2, self).__init__()
+
+ def epoch_begin(self):
+ print("custom epoch begin")
+
+ def batch_begin(self):
+ print("custom batch begin")
+
+ def train_end(self):
+ print("custom train end")
+
+ class CustomHandler3(EventHandler):
+ def __init__(self):
+ super(CustomHandler3, self).__init__()
+
+ def epoch_begin(self):
+ print("custom epoch begin")
+
+ def batch_begin(self):
+ print("custom batch begin")
+
+ def batch_end(self):
+ print("custom batch end")
+
+ def train_end(self):
+ print("custom train end")
+
+ net = nn.Sequential()
+ net.add(nn.Dense(10))
+ loss = gluon.loss.SoftmaxCrossEntropyLoss()
+ est = Estimator(net, loss=loss)
+ event_handlers = [CustomHandler1(), CustomHandler2(), CustomHandler3()]
+ train_begin, epoch_begin, batch_begin, \
+ batch_end, epoch_end, train_end = est._categorize_handlers(event_handlers)
+ assert len(train_begin) == 1
+ assert len(epoch_begin) == 2
+ assert len(batch_begin) == 2
+ assert len(batch_end) == 1
+ assert len(train_end) == 2
diff --git a/tests/python/unittest/test_gluon_event_handler.py b/tests/python/unittest/test_gluon_event_handler.py
index a551594..ccbcb54 100644
--- a/tests/python/unittest/test_gluon_event_handler.py
+++ b/tests/python/unittest/test_gluon_event_handler.py
@@ -45,7 +45,7 @@ def test_checkpoint_handler():
ce_loss = loss.SoftmaxCrossEntropyLoss()
acc = mx.metric.Accuracy()
est = estimator.Estimator(net, loss=ce_loss, metrics=acc)
- checkpoint_handler = [event_handler.CheckpointHandler(est, file_path,
+ checkpoint_handler = [event_handler.CheckpointHandler(file_path,
save_best_only=save_best_only,
mode=mode)]
est.fit(test_data, event_handlers=checkpoint_handler, epochs=1)
@@ -63,15 +63,15 @@ def test_early_stopping():
ce_loss = loss.SoftmaxCrossEntropyLoss()
acc = mx.metric.Accuracy()
est = estimator.Estimator(net, loss=ce_loss, metrics=acc)
- early_stopping = [event_handler.EarlyStoppingHandler(est, monitor,
+ early_stopping = [event_handler.EarlyStoppingHandler(monitor,
patience=patience,
- mode=mode)]
- est.fit(test_data, event_handlers=early_stopping, epochs=1)
+ mode=mode)]
+ est.fit(test_data, event_handlers=early_stopping, epochs=3)
mode = 'auto'
monitor = 'train_accuracy'
patience = 2
- early_stopping = [event_handler.EarlyStoppingHandler(est, monitor,
+ early_stopping = [event_handler.EarlyStoppingHandler(monitor,
patience=patience,
mode=mode)]
est.fit(test_data, event_handlers=early_stopping, epochs=1)
@@ -86,7 +86,7 @@ def test_logging():
ce_loss = loss.SoftmaxCrossEntropyLoss()
acc = mx.metric.Accuracy()
est = estimator.Estimator(net, loss=ce_loss, metrics=acc)
- logging_handler = [event_handler.LoggingHandler(est, file_name=file_name, file_location=tmpdir)]
+ logging_handler = [event_handler.LoggingHandler(file_name=file_name, file_location=tmpdir)]
est.fit(test_data, event_handlers=logging_handler, epochs=1)
assert os.path.isfile(output_dir)
os.remove(output_dir)
\ No newline at end of file