This is an automated email from the ASF dual-hosted git repository.

zhasheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new e1d8c66  Enable the reporting of cross-entropy or nll loss value when 
training a CNN network using the models defined by example/image-classification 
(#9805)
e1d8c66 is described below

commit e1d8c66c1af09841a6e0b2f74d40b09fc23051e7
Author: Shufan <33112206+juliusshu...@users.noreply.github.com>
AuthorDate: Fri Feb 23 02:37:56 2018 +0800

    Enable the reporting of cross-entropy or nll loss value when training a CNN 
network using the models defined by example/image-classification (#9805)
    
    * Enable the reporting of cross-entropy or nll loss value during training
    
    * Set the default value of loss to '' to avoid a Python runtime issue 
when the loss argument is not set
---
 example/image-classification/common/fit.py | 19 +++++++++++++++++++
 1 file changed, 19 insertions(+)

diff --git a/example/image-classification/common/fit.py 
b/example/image-classification/common/fit.py
index d9f96d0..0e0cd52 100755
--- a/example/image-classification/common/fit.py
+++ b/example/image-classification/common/fit.py
@@ -117,6 +117,8 @@ def add_fit_args(parser):
                        help='load the model on an epoch using the 
model-load-prefix')
     train.add_argument('--top-k', type=int, default=0,
                        help='report the top-k accuracy. 0 means no report.')
+    train.add_argument('--loss', type=str, default='',
+                       help='show the cross-entropy or nll loss. ce stands 
for cross-entropy, nll stands for negative log-likelihood loss')
     train.add_argument('--test-io', type=int, default=0,
                        help='1 means test reading speed without training')
     train.add_argument('--dtype', type=str, default='float32',
@@ -260,6 +262,23 @@ def fit(args, network, data_loader, **kwargs):
         eval_metrics.append(mx.metric.create(
             'top_k_accuracy', top_k=args.top_k))
 
+    supported_loss = ['ce', 'nll_loss']
+    if len(args.loss) > 0:
+        # ce or nll loss is only applicable to softmax output
+        loss_type_list = args.loss.split(',')
+        if 'softmax_output' in network.list_outputs():
+            for loss_type in loss_type_list:
+                loss_type = loss_type.strip()
+                if loss_type == 'nll':
+                    loss_type = 'nll_loss'
+                if loss_type not in supported_loss:
+                    logging.warning(loss_type + ' is not a valid loss type, 
only cross-entropy or ' \
+                                    'negative log-likelihood loss is supported!')
+                else:
+                    eval_metrics.append(mx.metric.create(loss_type))
+        else:
+            logging.warning("The output is not softmax_output, loss argument 
will be skipped!")
+
     # callbacks that run after each batch
     batch_end_callbacks = [mx.callback.Speedometer(
         args.batch_size, args.disp_batches)]

-- 
To stop receiving notification emails like this one, please contact
zhash...@apache.org.

Reply via email to