This is an automated email from the ASF dual-hosted git repository. zhasheng pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push: new e1d8c66 Enable the reporting of cross-entropy or nll loss value when training CNN network using the models defined by example/image-classification (#9805) e1d8c66 is described below commit e1d8c66c1af09841a6e0b2f74d40b09fc23051e7 Author: Shufan <33112206+juliusshu...@users.noreply.github.com> AuthorDate: Fri Feb 23 02:37:56 2018 +0800 Enable the reporting of cross-entropy or nll loss value when training CNN network using the models defined by example/image-classification (#9805) * Enable the reporting of cross-entropy or nll loss value during training * Set the default value of loss as a '' to avoid a Python runtime issue when loss argument is not set --- example/image-classification/common/fit.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/example/image-classification/common/fit.py b/example/image-classification/common/fit.py index d9f96d0..0e0cd52 100755 --- a/example/image-classification/common/fit.py +++ b/example/image-classification/common/fit.py @@ -117,6 +117,8 @@ def add_fit_args(parser): help='load the model on an epoch using the model-load-prefix') train.add_argument('--top-k', type=int, default=0, help='report the top-k accuracy. 0 means no report.') + train.add_argument('--loss', type=str, default='', + help='show the cross-entropy or nll loss. ce strands for cross-entropy, nll-loss stands for likelihood loss') train.add_argument('--test-io', type=int, default=0, help='1 means test reading speed without training') train.add_argument('--dtype', type=str, default='float32', @@ -260,6 +262,23 @@ def fit(args, network, data_loader, **kwargs): eval_metrics.append(mx.metric.create( 'top_k_accuracy', top_k=args.top_k)) + supported_loss = ['ce', 'nll_loss'] + if len(args.loss) > 0: + # ce or nll loss is only applicable to softmax output + loss_type_list = args.loss.split(',') + if 'softmax_output' in network.list_outputs(): + for loss_type in loss_type_list: + loss_type = loss_type.strip() + if loss_type == 'nll': + loss_type = 'nll_loss' + if loss_type not in supported_loss: + logging.warning(loss_type + ' is not an valid loss type, only cross-entropy or ' \ + 'negative likelihood loss is supported!') + else: + eval_metrics.append(mx.metric.create(loss_type)) + else: + logging.warning("The output is not softmax_output, loss argument will be skipped!") + # callbacks that run after each batch batch_end_callbacks = [mx.callback.Speedometer( args.batch_size, args.disp_batches)] -- To stop receiving notification emails like this one, please contact zhash...@apache.org.