roywei commented on a change in pull request #14685: [Fit API] improve event handlers
URL: https://github.com/apache/incubator-mxnet/pull/14685#discussion_r277094415
##########
File path: python/mxnet/gluon/contrib/estimator/estimator.py
##########
@@ -226,111 +230,72 @@ def fit(self, train_data,
custom batch function to extract data and label
from a data batch and load into contexts(devices)
"""
-
- self.max_epoch = epochs
- self.stop_training = False
- self.processed_samples = None
- self.batch_idx = 0
-
+ self.max_epochs = epochs
event_handlers = event_handlers or []
# provide default logging handler
- if not event_handlers or \
- not any(isinstance(handler, LoggingHandler) for handler in
event_handlers):
- event_handlers.append(LoggingHandler())
- warnings.warn("No Event Handler specified, default
`LoggingHandler()` "
- "is used with
verbose=LoggingHandler.LOG_VERBOSITY_PER_EPOCH. "
- "Please look at gluon.estimator.event_handler for
more detail.")
+ if not event_handlers:
+ train_metrics, val_metrics = self.prepare_loss_and_metrics()
+ event_handlers.append(MetricHandler(train_metrics=train_metrics))
+ if val_data:
+ event_handlers.append(ValidationHandler(val_data=val_data,
eval_fn=self.evaluate,
+
val_metrics=val_metrics))
+ event_handlers.append(LoggingHandler(train_metrics=train_metrics,
+ val_metrics=val_metrics))
+ warnings.warn("No Event Handler specified, default %s are used. "
+ "Please look at
gluon.contrib.estimator.event_handler for more detail." %
+ ", ".join([handler.__class__.__name__ for handler in
event_handlers]))
+
+ event_handlers.sort(key=lambda handler: getattr(handler, 'rank', 0),
reverse=True)
train_begin, epoch_begin, batch_begin, \
batch_end, epoch_end, train_end =
self._categorize_handlers(event_handlers)
- # passing estimator to event handlers so they can access estimator
information
- # when a event is triggered
- for handler in event_handlers:
- handler.estimator = self
-
+ # only pass a weak reference to all event handlers
+ estimator_ref = weakref.proxy(self)
# training begin
for handler in train_begin:
- handler.train_begin()
+ handler.train_begin(estimator_ref)
- for epoch in range(self.max_epoch):
+ for epoch in range(epochs):
# epoch begin
- self.current_epoch = epoch
- # Number of samples trained after every batch
- completed_samples = 0
-
for handler in epoch_begin:
- handler.epoch_begin()
-
- for metric in self.train_metrics + self.train_loss_metrics:
- metric.reset()
+ handler.epoch_begin(estimator_ref)
for i, batch in enumerate(train_data):
- if not batch_fn:
- if isinstance(train_data, gluon.data.DataLoader):
- data, label = self._batch_fn(batch, self.context)
- else:
- raise ValueError("You are using a custom iteration,
please also provide "
- "batch_fn to extract data and label.
Alternatively, you "
- "can provide the data as
gluon.data.DataLoader")
- else:
- data, label = batch_fn(batch, self.context)
+ if not isinstance(train_data, gluon.data.DataLoader):
+ raise ValueError("Estimator only support input as Gluon
DataLoader. Alternatively, you "
+ "can transform your DataIter or any
NDArray into Gluon DataLoader. "
+ "Refer to gluon.data.dataloader")
+ data, label = self._get_data_and_label(batch, self.context)
batch_size = batch[0].shape[0]
# batch begin
for handler in batch_begin:
- handler.batch_begin()
+ handler.batch_begin(estimator_ref, batch=batch)
with autograd.record():
pred = [self.net(x) for x in data]
- losses = []
- for loss in self.loss:
- losses.append([loss(y_hat, y) for y_hat, y in
zip(pred, label)])
+ loss = [self.loss[0](y_hat, y) for y_hat, y in zip(pred,
label)]
Review comment:
Same as above; see https://issues.apache.org/jira/browse/MXNET-1395
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
With regards,
Apache Git Services