szha closed pull request #9633: Gluon image-classification example improvement
URL: https://github.com/apache/incubator-mxnet/pull/9633
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below (GitHub would not otherwise display it):

diff --git a/example/gluon/data.py b/example/gluon/data.py
index dc8f12e81f6..c996c9af9ed 100644
--- a/example/gluon/data.py
+++ b/example/gluon/data.py
@@ -19,8 +19,14 @@
 """ data iterator for mnist """
 import os
 import random
+import logging
+logging.basicConfig(level=logging.INFO)
+
 import mxnet as mx
 from mxnet.test_utils import get_cifar10
+from mxnet.gluon.data.vision import ImageFolderDataset
+from mxnet.gluon.data import DataLoader
+from mxnet.contrib.io import DataLoaderIter
 
 def get_cifar10_iterator(batch_size, data_shape, resize=-1, num_parts=1, 
part_index=0):
     get_cifar10()
@@ -49,50 +55,38 @@ def get_cifar10_iterator(batch_size, data_shape, resize=-1, 
num_parts=1, part_in
 
     return train, val
 
-
-def get_imagenet_iterator(train_data, val_data, batch_size, data_shape, 
resize=-1, num_parts=1, part_index=0):
-    train = mx.io.ImageRecordIter(
-        path_imgrec             = train_data,
-        data_shape              = data_shape,
-        mean_r                  = 123.68,
-        mean_g                  = 116.779,
-        mean_b                  = 103.939,
-        std_r                   = 58.395,
-        std_g                   = 57.12,
-        std_b                   = 57.375,
-        preprocess_threads      = 32,
-        shuffle                 = True,
-        batch_size              = batch_size,
-        rand_crop               = True,
-        resize                  = resize,
-        random_mirror           = True,
-        max_random_h            = 36,
-        max_random_s            = 50,
-        max_random_l            = 50,
-        max_random_rotate_angle = 10,
-        max_random_shear_ratio  = 0.1,
-        max_random_aspect_ratio = 0.25,
-        fill_value              = 127,
-        min_random_scale        = 0.533,
-        num_parts               = num_parts,
-        part_index              = part_index)
-
-    val = mx.io.ImageRecordIter(
-        path_imgrec        = val_data,
-        data_shape         = data_shape,
-        mean_r             = 123.68,
-        mean_g             = 116.779,
-        mean_b             = 103.939,
-        std_r              = 58.395,
-        std_g              = 57.12,
-        std_b              = 57.375,
-        preprocess_threads = 32,
-        batch_size         = batch_size,
-        resize             = resize,
-        num_parts          = num_parts,
-        part_index         = part_index)
-
-    return train, val
+def get_imagenet_transforms(data_shape=224, dtype='float32'):
+    """Build the (train_transform, val_transform) pair for ImageNet samples.
+
+    Each transform maps ``(image, label)`` to ``(tensor, label)``: the image is
+    cropped to ``data_shape x data_shape``, converted with ``to_tensor``,
+    normalized with the per-channel mean/std below and cast to ``dtype``.
+    The label is passed through unchanged.
+    """
+    mean_rgb = (0.485, 0.456, 0.406)
+    std_rgb = (0.229, 0.224, 0.225)
+    crop_size = (data_shape, data_shape)
+
+    def _to_normalized(img, lbl):
+        # shared tail of both pipelines: to_tensor -> normalize -> cast
+        tensor = mx.nd.image.to_tensor(img)
+        tensor = mx.nd.image.normalize(tensor, mean=mean_rgb, std=std_rgb)
+        return mx.nd.cast(tensor, dtype), lbl
+
+    def train_transform(image, label):
+        # random-size crop (min area 0.08, aspect ratio 3/4..4/3), then random horizontal flip
+        cropped, _ = mx.image.random_size_crop(image, crop_size, 0.08, (3/4., 4/3.))
+        flipped = mx.nd.image.random_flip_left_right(cropped)
+        return _to_normalized(flipped, label)
+
+    def val_transform(image, label):
+        # deterministic: resize shorter side to data_shape + 32, then center crop
+        resized = mx.image.resize_short(image, data_shape + 32)
+        cropped, _ = mx.image.center_crop(resized, crop_size)
+        return _to_normalized(cropped, label)
+
+    return train_transform, val_transform
+
+def get_imagenet_iterator(root, batch_size, num_workers, data_shape=224, dtype='float32'):
+    """Dataset loader with preprocessing.
+
+    Builds ImageNet train/val DataLoaders over raw image folders and wraps them
+    in ``DataLoaderIter`` so they can be consumed like ``mx.io`` iterators.
+
+    Parameters
+    ----------
+    root : str
+        Directory with ``train`` and ``val`` subdirectories, each holding one
+        subdirectory per class. A leading ``~`` is expanded.
+    batch_size : int
+        Number of samples per batch.
+    num_workers : int
+        Number of DataLoader worker processes.
+    data_shape : int
+        Side length of the square network input (e.g. 224, or 299 for inceptionv3).
+    dtype : str
+        Data type the images are cast to, e.g. 'float32' or 'float16'.
+
+    Returns
+    -------
+    tuple of DataLoaderIter
+        ``(train_iter, val_iter)``. Training data is shuffled and drops the
+        incomplete last batch; validation keeps it.
+
+    Raises
+    ------
+    ValueError
+        If ``root/val`` is not laid out as one subdirectory per class.
+    """
+    # expanduser takes a single path; expand once and build both dirs from it
+    root = os.path.expanduser(root)
+    train_dir = os.path.join(root, 'train')
+    val_dir = os.path.join(root, 'val')
+    # Validate the val layout before scanning the (slow) train folder so a
+    # misconfigured dataset fails fast. 'n01440764' is used as a sentinel
+    # class directory expected in a correctly organized val set.
+    if not os.path.isdir(os.path.join(val_dir, 'n01440764')):
+        user_warning = 'Make sure validation images are stored in one subdir per category, a helper script is available at https://git.io/vNQv1'
+        raise ValueError(user_warning)
+    train_transform, val_transform = get_imagenet_transforms(data_shape, dtype)
+    logging.info("Loading image folder %s, this may take a bit long...", train_dir)
+    train_dataset = ImageFolderDataset(train_dir, transform=train_transform)
+    train_data = DataLoader(train_dataset, batch_size, shuffle=True,
+                            last_batch='discard', num_workers=num_workers)
+    logging.info("Loading image folder %s, this may take a bit long...", val_dir)
+    val_dataset = ImageFolderDataset(val_dir, transform=val_transform)
+    val_data = DataLoader(val_dataset, batch_size, last_batch='keep', num_workers=num_workers)
+    # dtype must be passed by keyword: the second positional parameter of
+    # DataLoaderIter is data_name, not dtype
+    return DataLoaderIter(train_data, dtype=dtype), DataLoaderIter(val_data, dtype=dtype)
 
 
 class DummyIter(mx.io.DataIter):
diff --git a/example/gluon/image_classification.py 
b/example/gluon/image_classification.py
index 529b977a790..9acfda51d17 100644
--- a/example/gluon/image_classification.py
+++ b/example/gluon/image_classification.py
@@ -17,9 +17,8 @@
 
 from __future__ import division
 
-import argparse, time
+import argparse, time, os
 import logging
-logging.basicConfig(level=logging.INFO)
 
 import mxnet as mx
 from mxnet import gluon
@@ -27,26 +26,40 @@
 from mxnet.gluon.model_zoo import vision as models
 from mxnet import autograd as ag
 from mxnet.test_utils import get_mnist_iterator
+from mxnet.metric import Accuracy, TopKAccuracy, CompositeEvalMetric
+import numpy as np
 
 from data import *
 
+# logging: messages go to the console handler installed by basicConfig and,
+# through `fh`, to image-classification.log
+logging.basicConfig(level=logging.INFO)
+fh = logging.FileHandler('image-classification.log')
+logger = logging.getLogger()
+logger.addHandler(fh)
+fh.setLevel(logging.DEBUG)
+# Write a bare separator line to the log file only, so runs appended to the
+# same file are easy to tell apart. The record is handed to `fh` directly:
+# a logging.debug() call would be discarded by the root logger's INFO level
+# before ever reaching the file handler.
+fh.setFormatter(logging.Formatter('%(message)s'))
+fh.handle(logging.makeLogRecord({'msg': '\n%s', 'args': ('-' * 100,),
+                                 'levelno': logging.DEBUG, 'levelname': 'DEBUG'}))
+formatter = logging.Formatter('%(asctime)s %(levelname)s %(message)s')
+fh.setFormatter(formatter)
+
 # CLI
 parser = argparse.ArgumentParser(description='Train a model for image classification.')
 parser.add_argument('--dataset', type=str, default='cifar10',
-                    help='dataset to use. options are mnist, cifar10, and dummy.')
-parser.add_argument('--train-data', type=str, default='',
-                    help='training record file to use, required for imagenet.')
-parser.add_argument('--val-data', type=str, default='',
-                    help='validation record file to use, required for imagenet.')
+                    help='dataset to use. options are mnist, cifar10, imagenet and dummy.')
+parser.add_argument('--data-dir', type=str, default='',
+                    help='root directory of imagenet images, must contain train/ and val/ subdirs; required for imagenet.')
 parser.add_argument('--batch-size', type=int, default=32,
                     help='training batch size per device (CPU/GPU).')
-parser.add_argument('--num-gpus', type=int, default=0,
-                    help='number of gpus to use.')
-parser.add_argument('--epochs', type=int, default=3,
+parser.add_argument('--num-worker', '--num-workers', '-j', dest='num_workers', default=4, type=int,
+                    help='number of DataLoader worker processes.')
+parser.add_argument('--gpus', type=str, default='',
+                    help='ids of gpus to use, e.g. "0,1,2"; leave empty for cpu only.')
+parser.add_argument('--epochs', type=int, default=120,
                     help='number of training epochs.')
-parser.add_argument('--lr', type=float, default=0.01,
-                    help='learning rate. default is 0.01.')
-parser.add_argument('-momentum', type=float, default=0.9,
+parser.add_argument('--lr', type=float, default=0.1,
+                    help='learning rate. default is 0.1.')
+parser.add_argument('--momentum', type=float, default=0.9,
                     help='momentum value for optimizer, default is 0.9.')
 parser.add_argument('--wd', type=float, default=0.0001,
                     help='weight decay rate. default is 0.0001.')
@@ -62,39 +75,64 @@
                     help='enable batch normalization or not in vgg. default is 
false.')
 parser.add_argument('--use-pretrained', action='store_true',
                     help='enable using pretrained model from gluon.')
+parser.add_argument('--prefix', default='', type=str,
+                    help='directory to save checkpoints in, default is current working dir')
+parser.add_argument('--start-epoch', default=0, type=int,
+                    help='starting epoch, 0 for fresh training, > 0 to resume')
+parser.add_argument('--resume', type=str, default='',
+                    help='path to saved weights to resume training from')
+parser.add_argument('--lr-factor', default=0.1, type=float,
+                    help='learning rate decay ratio')
+parser.add_argument('--lr-steps', default='30,60,90', type=str,
+                    help='comma-separated list of epochs at which to decay the learning rate')
+parser.add_argument('--dtype', default='float32', type=str,
+                    help='data type, float32 or float16 if applicable')
+parser.add_argument('--save-frequency', default=10, type=int,
+                    help='save a checkpoint every N epochs (0 disables), best model will always be saved')
 parser.add_argument('--kvstore', type=str, default='device',
                     help='kvstore to use for trainer/module.')
-parser.add_argument('--log-interval', type=int, default=50, help='Number of batches to wait before logging.')
+parser.add_argument('--log-interval', type=int, default=50,
+                    help='Number of batches to wait before logging.')
 parser.add_argument('--profile', action='store_true',
                     help='Option to turn on memory profiling for front-end, '\
                          'and prints out the memory usage by python function at the end.')
 opt = parser.parse_args()
 
-logging.info(opt)
-
+# global variables
+logger.info('Starting new image-classification task: %s', opt)
 mx.random.seed(opt.seed)
-
+model_name = opt.model
 dataset_classes = {'mnist': 10, 'cifar10': 10, 'imagenet': 1000, 'dummy': 1000}
-
 batch_size, dataset, classes = opt.batch_size, opt.dataset, dataset_classes[opt.dataset]
-
-num_gpus = opt.num_gpus
-
+# Parse --gpus like --lr-steps below: ignore empty entries (stray commas,
+# spaces); no ids at all means CPU-only training.
+gpu_ids = [int(i) for i in opt.gpus.split(',') if i.strip()]
+context = [mx.gpu(i) for i in gpu_ids] if gpu_ids else [mx.cpu()]
+# number of devices in `context` (1 when running on CPU)
+num_gpus = len(context)
 batch_size *= max(1, num_gpus)
-context = [mx.gpu(i) for i in range(num_gpus)] if num_gpus > 0 else [mx.cpu()]
+lr_steps = [int(x) for x in opt.lr_steps.split(',') if x.strip()]
+metric = CompositeEvalMetric([Accuracy(), TopKAccuracy(5)])
 
-model_name = opt.model
+def get_model(model, ctx, opt):
+    """Model initialization.
+
+    Builds ``model`` from the gluon model zoo, then loads or initializes its
+    parameters and casts it to ``opt.dtype``.
+
+    Parameters
+    ----------
+    model : str
+        Model-zoo name passed to ``models.get_model``.
+    ctx : list of Context
+        Forwarded as the ``ctx`` keyword of ``models.get_model``.
+    opt : argparse.Namespace
+        Parsed CLI options; reads ``use_pretrained``, ``use_thumbnail``
+        (resnet*), ``batch_norm`` (vgg*), ``resume`` and ``dtype``.
+
+    Returns the network. The number of output classes comes from the
+    module-level ``classes`` global.
+    """
+    kwargs = {'ctx': ctx, 'pretrained': opt.use_pretrained, 'classes': classes}
+    # architecture-specific constructor options
+    if model.startswith('resnet'):
+        kwargs['thumbnail'] = opt.use_thumbnail
+    elif model.startswith('vgg'):
+        kwargs['batch_norm'] = opt.batch_norm
 
-kwargs = {'ctx': context, 'pretrained': opt.use_pretrained, 'classes': classes}
-if model_name.startswith('resnet'):
-    kwargs['thumbnail'] = opt.use_thumbnail
-elif model_name.startswith('vgg'):
-    kwargs['batch_norm'] = opt.batch_norm
+    net = models.get_model(model, **kwargs)
+    # Parameter source, in priority order: the --resume file; otherwise, when
+    # pretrained=True, whatever models.get_model provided (presumably the
+    # zoo's pretrained weights); otherwise a fresh initialization.
+    # NOTE(review): initialize() is called without ctx here; the caller is
+    # expected to move parameters (train() uses reset_ctx) — confirm.
+    if opt.resume:
+        net.load_params(opt.resume)
+    elif not opt.use_pretrained:
+        if model in ['alexnet']:
+            net.initialize(mx.init.Normal())
+        else:
+            net.initialize(mx.init.Xavier(magnitude=2))
+    net.cast(opt.dtype)
+    return net
 
-net = models.get_model(opt.model, **kwargs)
+net = get_model(opt.model, context, opt)
 
 def get_data_iters(dataset, batch_size, num_workers=1, rank=0):
-    # get dataset iterators
+    """get dataset iterators"""
     if dataset == 'mnist':
         train_data, val_data = get_mnist_iterator(batch_size, (1, 28, 28),
                                                   num_parts=num_workers, 
part_index=rank)
@@ -102,14 +140,12 @@ def get_data_iters(dataset, batch_size, num_workers=1, 
rank=0):
         train_data, val_data = get_cifar10_iterator(batch_size, (3, 32, 32),
                                                     num_parts=num_workers, 
part_index=rank)
     elif dataset == 'imagenet':
+        if not opt.data_dir:
+            raise ValueError('Dir containing raw images in train/val is 
required for imagenet, plz specify "--data-dir"')
         if model_name == 'inceptionv3':
-            train_data, val_data = get_imagenet_iterator(opt.train_data, 
opt.val_data,
-                                                         batch_size, (3, 299, 
299),
-                                                         
num_parts=num_workers, part_index=rank)
+            train_data, val_data = get_imagenet_iterator(opt.data_dir, 
batch_size, opt.num_workers, 299, opt.dtype)
         else:
-            train_data, val_data = get_imagenet_iterator(opt.train_data, 
opt.val_data,
-                                                         batch_size, (3, 224, 
224),
-                                                         
num_parts=num_workers, part_index=rank)
+            train_data, val_data = get_imagenet_iterator(opt.data_dir, 
batch_size, opt.num_workers, 224, opt.dtype)
     elif dataset == 'dummy':
         if model_name == 'inceptionv3':
             train_data, val_data = dummy_iterator(batch_size, (3, 299, 299))
@@ -118,7 +154,7 @@ def get_data_iters(dataset, batch_size, num_workers=1, 
rank=0):
     return train_data, val_data
 
 def test(ctx, val_data):
-    metric = mx.metric.Accuracy()
+    metric.reset()
     val_data.reset()
     for batch in val_data:
         data = gluon.utils.split_and_load(batch.data[0], ctx_list=ctx, 
batch_axis=0)
@@ -129,27 +165,45 @@ def test(ctx, val_data):
         metric.update(label, outputs)
     return metric.get()
 
+def update_learning_rate(lr, trainer, epoch, ratio, steps):
+    """Set the trainer's learning rate for ``epoch`` under a step schedule.
+
+    The rate is ``lr * ratio ** k``, where ``k`` is the number of entries in
+    ``steps`` that are <= ``epoch``. The decay therefore takes effect at the
+    start of each listed epoch (``--lr-steps 30,60,90`` decays at epochs 30,
+    60 and 90). Recomputing from the base rate every epoch keeps resumed runs
+    (``--start-epoch``) on the same schedule. An empty ``steps`` keeps ``lr``.
+
+    Returns the same ``trainer`` (mutated in place) for convenience.
+    """
+    num_decays = sum(1 for step in steps if step <= epoch)
+    new_lr = lr * (ratio ** num_decays)
+    trainer.set_learning_rate(new_lr)
+    return trainer
+
+def save_checkpoint(epoch, top1, best_acc):
+    """Persist ``net`` parameters at the end of an epoch.
+
+    A periodic snapshot is written every ``opt.save_frequency`` epochs (never
+    when it is 0). ``<model>_best.params`` is rewritten whenever ``top1``
+    beats the running best held in the one-element list ``best_acc``, which
+    is updated in place.
+    """
+    def _save(basename):
+        # write parameters under opt.prefix and log where they went
+        path = os.path.join(opt.prefix, basename)
+        net.save_params(path)
+        logger.info('[Epoch %d] Saving checkpoint to %s with Accuracy: %.4f', epoch, path, top1)
+
+    is_periodic = opt.save_frequency and (epoch + 1) % opt.save_frequency == 0
+    if is_periodic:
+        _save('%s_%d_acc_%.4f.params' % (opt.model, epoch, top1))
+    if top1 > best_acc[0]:
+        best_acc[0] = top1
+        _save('%s_best.params' % (opt.model))
 
-def train(epochs, ctx):
+def train(opt, ctx):
     if isinstance(ctx, mx.Context):
         ctx = [ctx]
-    net.initialize(mx.init.Xavier(magnitude=2), ctx=ctx)
     kv = mx.kv.create(opt.kvstore)
     train_data, val_data = get_data_iters(dataset, batch_size, kv.num_workers, 
kv.rank)
+    net.collect_params().reset_ctx(ctx)
     trainer = gluon.Trainer(net.collect_params(), 'sgd',
-                            {'learning_rate': opt.lr, 'wd': opt.wd, 
'momentum': opt.momentum},
+                            {'learning_rate': opt.lr, 'wd': opt.wd, 
'momentum': opt.momentum,
+                             'multi_precision': True},
                             kvstore = kv)
-    metric = mx.metric.Accuracy()
     loss = gluon.loss.SoftmaxCrossEntropyLoss()
 
-    for epoch in range(epochs):
+    best_acc = [0]
+    for epoch in range(opt.start_epoch, opt.epochs):
+        trainer = update_learning_rate(opt.lr, trainer, epoch, opt.lr_factor, 
lr_steps)
         tic = time.time()
         train_data.reset()
         metric.reset()
         btic = time.time()
         for i, batch in enumerate(train_data):
-            data = gluon.utils.split_and_load(batch.data[0], ctx_list=ctx, 
batch_axis=0)
-            label = gluon.utils.split_and_load(batch.label[0], ctx_list=ctx, 
batch_axis=0)
+            data = gluon.utils.split_and_load(batch.data[0].astype(opt.dtype), 
ctx_list=ctx, batch_axis=0)
+            label = 
gluon.utils.split_and_load(batch.label[0].astype(opt.dtype), ctx_list=ctx, 
batch_axis=0)
             outputs = []
             Ls = []
             with ag.record():
@@ -160,23 +214,23 @@ def train(epochs, ctx):
                     # on all GPUs for better speed on multiple GPUs.
                     Ls.append(L)
                     outputs.append(z)
-                for L in Ls:
-                    L.backward()
+                ag.backward(Ls)
             trainer.step(batch.data[0].shape[0])
             metric.update(label, outputs)
             if opt.log_interval and not (i+1)%opt.log_interval:
                 name, acc = metric.get()
-                logging.info('Epoch[%d] Batch [%d]\tSpeed: %f 
samples/sec\t%s=%f'%(
-                               epoch, i, batch_size/(time.time()-btic), name, 
acc))
+                logger.info('Epoch[%d] Batch [%d]\tSpeed: %f 
samples/sec\t%s=%f, %s=%f'%(
+                               epoch, i, batch_size/(time.time()-btic), 
name[0], acc[0], name[1], acc[1]))
             btic = time.time()
 
         name, acc = metric.get()
-        logging.info('[Epoch %d] training: %s=%f'%(epoch, name, acc))
-        logging.info('[Epoch %d] time cost: %f'%(epoch, time.time()-tic))
+        logger.info('[Epoch %d] training: %s=%f, %s=%f'%(epoch, name[0], 
acc[0], name[1], acc[1]))
+        logger.info('[Epoch %d] time cost: %f'%(epoch, time.time()-tic))
         name, val_acc = test(ctx, val_data)
-        logging.info('[Epoch %d] validation: %s=%f'%(epoch, name, val_acc))
+        logger.info('[Epoch %d] validation: %s=%f, %s=%f'%(epoch, name[0], 
val_acc[0], name[1], val_acc[1]))
 
-    net.save_params('image-classifier-%s-%d.params'%(opt.model, epochs))
+        # save model if meet requirements
+        save_checkpoint(epoch, val_acc[0], best_acc)
 
 def main():
     if opt.mode == 'symbolic':
@@ -193,13 +247,13 @@ def main():
                 batch_end_callback = mx.callback.Speedometer(batch_size, 
max(1, opt.log_interval)),
                 epoch_end_callback = 
mx.callback.do_checkpoint('image-classifier-%s'% opt.model),
                 optimizer = 'sgd',
-                optimizer_params = {'learning_rate': opt.lr, 'wd': opt.wd, 
'momentum': opt.momentum},
+                optimizer_params = {'learning_rate': opt.lr, 'wd': opt.wd, 
'momentum': opt.momentum, 'multi_precision': True},
                 initializer = mx.init.Xavier(magnitude=2))
         mod.save_params('image-classifier-%s-%d-final.params'%(opt.model, 
opt.epochs))
     else:
         if opt.mode == 'hybrid':
             net.hybridize()
-        train(opt.epochs, context)
+        train(opt, context)
 
 if __name__ == '__main__':
     if opt.profile:
diff --git a/python/mxnet/contrib/__init__.py b/python/mxnet/contrib/__init__.py
index 21c77719b70..36ee21305bf 100644
--- a/python/mxnet/contrib/__init__.py
+++ b/python/mxnet/contrib/__init__.py
@@ -28,3 +28,5 @@
 from . import tensorboard
 
 from . import text
+
+from . import io
diff --git a/python/mxnet/contrib/io.py b/python/mxnet/contrib/io.py
new file mode 100644
index 00000000000..6020b3ee1c9
--- /dev/null
+++ b/python/mxnet/contrib/io.py
@@ -0,0 +1,95 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# coding: utf-8
+"""Contrib data iterators for common data formats."""
+from __future__ import absolute_import
+from ..io import DataIter, DataDesc
+from .. import ndarray as nd
+
+
+class DataLoaderIter(DataIter):
+    """Returns an iterator for ``mx.gluon.data.DataLoader`` so a gluon data
+    loader can be used with the symbolic Module API.
+
+    Batches are cast to ``dtype``. A final batch smaller than the first one is
+    zero-padded up to ``batch_size``, and ``DataBatch.pad`` reports how many
+    rows were added.
+
+    Parameters
+    ----------
+    loader : mxnet.gluon.data.DataLoader
+        Gluon dataloader instance; it must yield ``(data, label)`` pairs and at
+        least one batch (the first batch is used to infer shapes).
+    data_name : str, optional
+        The data name.
+    label_name : str, optional
+        The label name.
+    dtype : str, optional
+        The dtype specifier, can be float32 or float16
+
+    Examples
+    --------
+    >>> import mxnet as mx
+    >>> from mxnet.gluon.data.vision import MNIST
+    >>> from mxnet.gluon.data import DataLoader
+    >>> train_dataset = MNIST(train=True)
+    >>> train_data = mx.gluon.data.DataLoader(train_dataset, 32, shuffle=True, num_workers=4)
+    >>> dataiter = mx.contrib.io.DataLoaderIter(train_data)
+    >>> for batch in dataiter:
+    ...     batch.data[0].shape
+    ...
+    (32L, 28L, 28L, 1L)
+    """
+    def __init__(self, loader, data_name='data', label_name='softmax_label', dtype='float32'):
+        super(DataLoaderIter, self).__init__()
+        self._loader = loader
+        self._iter = iter(self._loader)
+        # Peek at the first batch to infer batch size and shapes; reset()
+        # below restarts iteration so this batch is not lost.
+        data, label = next(self._iter)
+        self.batch_size = data.shape[0]
+        self.dtype = dtype
+        self.provide_data = [DataDesc(data_name, data.shape, dtype)]
+        self.provide_label = [DataDesc(label_name, label.shape, dtype)]
+        self._current_batch = None
+        self.reset()
+
+    def reset(self):
+        self._iter = iter(self._loader)
+
+    def iter_next(self):
+        try:
+            self._current_batch = next(self._iter)
+        except StopIteration:
+            self._current_batch = None
+        return self._current_batch is not None
+
+    def _padded(self, index):
+        """Return field ``index`` (0 = data, 1 = label) of the current batch,
+        cast to ``self.dtype`` and zero-padded along axis 0 to ``batch_size``."""
+        arr = self._current_batch[index].astype(self.dtype)
+        if not self.getpad():
+            return arr
+        # zeros (not empty) so padded rows, e.g. labels, are well defined, and
+        # with self.dtype so the buffer matches provide_data/provide_label
+        ret = nd.zeros(shape=(self.batch_size,) + tuple(arr.shape[1:]),
+                       ctx=arr.context, dtype=self.dtype)
+        ret[:arr.shape[0]] = arr
+        return ret
+
+    def getdata(self):
+        return [self._padded(0)]
+
+    def getlabel(self):
+        return [self._padded(1)]
+
+    def getpad(self):
+        if self._current_batch is None:
+            return 0
+        return self.batch_size - self._current_batch[0].shape[0]
+
+    def getindex(self):
+        return None
diff --git a/tests/python/unittest/test_contrib_io.py 
b/tests/python/unittest/test_contrib_io.py
new file mode 100644
index 00000000000..dbae69fe729
--- /dev/null
+++ b/tests/python/unittest/test_contrib_io.py
@@ -0,0 +1,46 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import mxnet.ndarray as nd
+from mxnet.gluon.data.vision.datasets import *
+from mxnet.gluon.data.dataloader import *
+from mxnet.contrib.io import *
+from mxnet.test_utils import *
+
+def test_contrib_DataLoaderIter():
+    def test_mnist_batches(batch_size, expected, last_batch='discard'):
+        dataset = MNIST(train=False)
+        dataloader = DataLoader(dataset, batch_size, last_batch=last_batch)
+        test_iter = DataLoaderIter(dataloader)
+        batch = next(test_iter)
+        assert batch.data[0].shape == (batch_size, 28, 28, 1)
+        assert batch.label[0].shape == (batch_size,)
+        count = 0
+        test_iter.reset()
+        for batch in test_iter:
+            count += 1
+            # every batch, including a short final one, is padded to batch_size
+            assert batch.data[0].shape == (batch_size, 28, 28, 1)
+            assert batch.label[0].shape == (batch_size,)
+        assert count == expected, "expected {} batches, given {}".format(expected, count)
+        # only 'keep' can produce a short final batch; it must report the
+        # number of padded rows
+        expected_pad = (-len(dataset)) % batch_size if last_batch == 'keep' else 0
+        assert batch.pad == expected_pad, "expected pad {}, given {}".format(expected_pad, batch.pad)
+
+    num_examples = 10000
+    test_mnist_batches(50, num_examples // 50, 'discard')
+    test_mnist_batches(31, num_examples // 31, 'discard')
+    test_mnist_batches(31, num_examples // 31, 'rollover')
+    test_mnist_batches(31, num_examples // 31 + 1, 'keep')
+
+
+if __name__ == "__main__":
+    test_contrib_DataLoaderIter()


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

Reply via email to