This is an automated email from the ASF dual-hosted git repository.

lausen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 18b6e05  Deprecate dataset transform= argument in gluon data API (#17852)
18b6e05 is described below

commit 18b6e0552ae89d86ece4501eb1e08d3ff65cefe4
Author: Joshua Z. Zhang <[email protected]>
AuthorDate: Mon May 11 16:34:19 2020 -0700

    Deprecate dataset transform= argument in gluon data API (#17852)
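
    A minimal migration sketch (the dataset, callback, and batch size below
    are illustrative; the same pattern applies to CIFAR10, ImageFolderDataset
    and the other gluon vision datasets touched by this change):

        import numpy as np
        from mxnet import gluon

        def transform(data, label):
            # scale pixels to [0, 1] and cast the label, as in the examples
            return data.astype(np.float32) / 255, label.astype(np.float32)

        # deprecated: passing transform= to the dataset constructor
        # dataset = gluon.data.vision.MNIST(train=True, transform=transform)

        # preferred: chain Dataset.transform() (or transform_first()) instead
        dataset = gluon.data.vision.MNIST(train=True).transform(transform)
        loader = gluon.data.DataLoader(dataset, batch_size=32, shuffle=True)
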
---
 .../tutorials/packages/gluon/blocks/save_load_params.md    |  2 +-
 example/distributed_training/README.md                     |  9 ++++-----
 example/distributed_training/cifar10_dist.py               |  4 ++--
 example/gluon/data.py                                      |  8 ++++----
 example/gluon/dc_gan/dcgan.py                              |  8 ++++----
 example/gluon/mnist/mnist.py                               |  4 ++--
 example/gluon/sn_gan/data.py                               |  2 +-
 example/restricted-boltzmann-machine/binary_rbm_gluon.py   |  6 +++---
 python/mxnet/gluon/data/dataset.py                         |  4 ++++
 python/mxnet/gluon/data/vision/datasets.py                 | 14 ++++++++++++++
 10 files changed, 39 insertions(+), 22 deletions(-)

diff --git a/docs/python_docs/python/tutorials/packages/gluon/blocks/save_load_params.md b/docs/python_docs/python/tutorials/packages/gluon/blocks/save_load_params.md
index ee72095..0515303 100644
--- a/docs/python_docs/python/tutorials/packages/gluon/blocks/save_load_params.md
+++ b/docs/python_docs/python/tutorials/packages/gluon/blocks/save_load_params.md
@@ -180,7 +180,7 @@ def verify_loaded_model(net):
         return data.astype(np.float32)/255, label.astype(np.float32)
 
     # Load ten random images from the test dataset
-    sample_data = mx.gluon.data.DataLoader(mx.gluon.data.vision.MNIST(train=False, transform=transform),
+    sample_data = mx.gluon.data.DataLoader(mx.gluon.data.vision.MNIST(train=False).transform(transform),
                                   10, shuffle=True)
 
     for data, label in sample_data:
diff --git a/example/distributed_training/README.md b/example/distributed_training/README.md
index af25b9e..7025b07 100644
--- a/example/distributed_training/README.md
+++ b/example/distributed_training/README.md
@@ -117,7 +117,7 @@ We can then create a `DataLoader` using the `SplitSampler` like shown below:
 
 ```python
 # Load the training data
-train_data = gluon.data.DataLoader(gluon.data.vision.CIFAR10(train=True, transform=transform),
+train_data = gluon.data.DataLoader(gluon.data.vision.CIFAR10(train=True).transform(transform),
                                       batch_size,
                                       sampler=SplitSampler(50000, store.num_workers, store.rank))
 ```
@@ -141,7 +141,7 @@ def train_batch(batch, ctx, net, trainer):
     # Split and load data into multiple GPUs
     data = batch[0]
     data = gluon.utils.split_and_load(data, ctx)
-    
+
     # Split and load label into multiple GPUs
     label = batch[1]
     label = gluon.utils.split_and_load(label, ctx)
@@ -204,7 +204,7 @@ python ~/mxnet/tools/launch.py -n 2 -s 2 -H hosts \
 Let's take a look at the `hosts` file.
 
 ```
-~/dist$ cat hosts 
+~/dist$ cat hosts
 d1
 d2
 ```
@@ -232,7 +232,7 @@ Last login: Wed Jan 31 18:06:45 2018 from 72.21.198.67
 Note that no authentication information was provided to login to the host. This can be done using multiple methods. One easy way is to specify the ssh certificates in `~/.ssh/config`. Example:
 
 ```
-~$ cat ~/.ssh/config 
+~$ cat ~/.ssh/config
 Host d1
     HostName ec2-34-201-108-233.compute-1.amazonaws.com
     port 22
@@ -269,4 +269,3 @@ Epoch 4: Test_acc 0.687900
 ```
 
 Note that the output from all hosts are merged and printed to the console.
-
diff --git a/example/distributed_training/cifar10_dist.py b/example/distributed_training/cifar10_dist.py
index d3ba515..b668457 100644
--- a/example/distributed_training/cifar10_dist.py
+++ b/example/distributed_training/cifar10_dist.py
@@ -86,11 +86,11 @@ class SplitSampler(gluon.data.sampler.Sampler):
 
 
 # Load the training data
-train_data = gluon.data.DataLoader(gluon.data.vision.CIFAR10(train=True, transform=transform), batch_size,
+train_data = gluon.data.DataLoader(gluon.data.vision.CIFAR10(train=True).transform(transform), batch_size,
                                    sampler=SplitSampler(50000, store.num_workers, store.rank))
 
 # Load the test data
-test_data = gluon.data.DataLoader(gluon.data.vision.CIFAR10(train=False, transform=transform),
+test_data = gluon.data.DataLoader(gluon.data.vision.CIFAR10(train=False).transform(transform),
                                   batch_size, shuffle=False)
 
 # Use ResNet from model zoo
diff --git a/example/gluon/data.py b/example/gluon/data.py
index f855c90..7d0f882 100644
--- a/example/gluon/data.py
+++ b/example/gluon/data.py
@@ -78,7 +78,7 @@ def get_imagenet_iterator(root, batch_size, num_workers, data_shape=224, dtype='
     train_dir = os.path.join(root, 'train')
     train_transform, val_transform = get_imagenet_transforms(data_shape, dtype)
     logging.info("Loading image folder %s, this may take a bit long...", 
train_dir)
-    train_dataset = ImageFolderDataset(train_dir, transform=train_transform)
+    train_dataset = ImageFolderDataset(train_dir).transform_first(train_transform)
     train_data = DataLoader(train_dataset, batch_size, shuffle=True,
                             last_batch='discard', num_workers=num_workers)
     val_dir = os.path.join(root, 'val')
@@ -86,7 +86,7 @@ def get_imagenet_iterator(root, batch_size, num_workers, data_shape=224, dtype='
         user_warning = 'Make sure validation images are stored in one subdir per category, a helper script is available at https://git.io/vNQv1'
         raise ValueError(user_warning)
     logging.info("Loading image folder %s, this may take a bit long...", 
val_dir)
-    val_dataset = ImageFolderDataset(val_dir, transform=val_transform)
+    val_dataset = ImageFolderDataset(val_dir).transform(val_transform)
     val_data = DataLoader(val_dataset, batch_size, last_batch='keep', num_workers=num_workers)
     return DataLoaderIter(train_data, dtype), DataLoaderIter(val_data, dtype)
 
@@ -118,8 +118,8 @@ def get_caltech101_iterator(batch_size, num_workers, dtype):
         return transposed, label
 
     training_path, testing_path = get_caltech101_data()
-    dataset_train = ImageFolderDataset(root=training_path, transform=transform)
-    dataset_test = ImageFolderDataset(root=testing_path, transform=transform)
+    dataset_train = ImageFolderDataset(root=training_path).transform(transform)
+    dataset_test = ImageFolderDataset(root=testing_path).transform(transform)
 
     train_data = DataLoader(dataset_train, batch_size, shuffle=True, num_workers=num_workers)
     test_data = DataLoader(dataset_test, batch_size, shuffle=False, num_workers=num_workers)
diff --git a/example/gluon/dc_gan/dcgan.py b/example/gluon/dc_gan/dcgan.py
index 93af13a..6e03aae 100644
--- a/example/gluon/dc_gan/dcgan.py
+++ b/example/gluon/dc_gan/dcgan.py
@@ -143,20 +143,20 @@ def get_dataset(dataset_name):
     # mnist
     if dataset == "mnist":
         train_data = gluon.data.DataLoader(
-            gluon.data.vision.MNIST('./data', train=True, transform=transformer),
+            gluon.data.vision.MNIST('./data', train=True).transform(transformer),
             batch_size, shuffle=True, last_batch='discard')
 
         val_data = gluon.data.DataLoader(
-            gluon.data.vision.MNIST('./data', train=False, transform=transformer),
+            gluon.data.vision.MNIST('./data', train=False).transform(transformer),
             batch_size, shuffle=False)
     # cifar10
     elif dataset == "cifar10":
         train_data = gluon.data.DataLoader(
-            gluon.data.vision.CIFAR10('./data', train=True, transform=transformer),
+            gluon.data.vision.CIFAR10('./data', train=True).transform(transformer),
             batch_size, shuffle=True, last_batch='discard')
 
         val_data = gluon.data.DataLoader(
-            gluon.data.vision.CIFAR10('./data', train=False, transform=transformer),
+            gluon.data.vision.CIFAR10('./data', train=False).transform(transformer),
             batch_size, shuffle=False)
 
     return train_data, val_data
diff --git a/example/gluon/mnist/mnist.py b/example/gluon/mnist/mnist.py
index 6aea3ab..5acaf14 100644
--- a/example/gluon/mnist/mnist.py
+++ b/example/gluon/mnist/mnist.py
@@ -60,11 +60,11 @@ def transformer(data, label):
     return data, label
 
 train_data = gluon.data.DataLoader(
-    gluon.data.vision.MNIST('./data', train=True, transform=transformer),
+    gluon.data.vision.MNIST('./data', train=True).transform(transformer),
     batch_size=opt.batch_size, shuffle=True, last_batch='discard')
 
 val_data = gluon.data.DataLoader(
-    gluon.data.vision.MNIST('./data', train=False, transform=transformer),
+    gluon.data.vision.MNIST('./data', train=False).transform(transformer),
     batch_size=opt.batch_size, shuffle=False)
 
 # train
diff --git a/example/gluon/sn_gan/data.py b/example/gluon/sn_gan/data.py
index 782f74f..754aa2c 100644
--- a/example/gluon/sn_gan/data.py
+++ b/example/gluon/sn_gan/data.py
@@ -38,5 +38,5 @@ def transformer(data, label):
 def get_training_data(batch_size):
     """ helper function to get dataloader"""
     return gluon.data.DataLoader(
-        CIFAR10(train=True, transform=transformer),
+        CIFAR10(train=True).transform(transformer),
         batch_size=batch_size, shuffle=True, last_batch='discard')
diff --git a/example/restricted-boltzmann-machine/binary_rbm_gluon.py b/example/restricted-boltzmann-machine/binary_rbm_gluon.py
index cdce2e6..994b8ea 100644
--- a/example/restricted-boltzmann-machine/binary_rbm_gluon.py
+++ b/example/restricted-boltzmann-machine/binary_rbm_gluon.py
@@ -62,8 +62,8 @@ ctx = mx.gpu(args.device_id) if args.cuda else mx.cpu()
 def data_transform(data, label):
     return data.astype(np.float32) / 255, label.astype(np.float32)
 
-mnist_train_dataset = mx.gluon.data.vision.MNIST(train=True, transform=data_transform)
-mnist_test_dataset = mx.gluon.data.vision.MNIST(train=False, transform=data_transform)
+mnist_train_dataset = mx.gluon.data.vision.MNIST(train=True).transform(data_transform)
+mnist_test_dataset = mx.gluon.data.vision.MNIST(train=False).transform(data_transform)
 img_height = mnist_train_dataset[0][0].shape[0]
 img_width = mnist_train_dataset[0][0].shape[1]
 num_visible = img_width * img_height
@@ -139,4 +139,4 @@ plt.axis('off')
 plt.axvline(showcase_num_samples_w * img_width, color='y')
 plt.show(s)
 
-print("Done")
\ No newline at end of file
+print("Done")
diff --git a/python/mxnet/gluon/data/dataset.py b/python/mxnet/gluon/data/dataset.py
index c70e792..9a03b0f 100644
--- a/python/mxnet/gluon/data/dataset.py
+++ b/python/mxnet/gluon/data/dataset.py
@@ -415,6 +415,10 @@ class _DownloadedDataset(Dataset):
     """Base class for MNIST, cifar10, etc."""
     def __init__(self, root, transform):
         super(_DownloadedDataset, self).__init__()
+        if transform is not None:
+            raise DeprecationWarning(
+                'Directly applying a transform to the dataset is deprecated. '
+                'Please use dataset.transform() or dataset.transform_first() instead.')
         self._transform = transform
         self._data = None
         self._label = None
diff --git a/python/mxnet/gluon/data/vision/datasets.py b/python/mxnet/gluon/data/vision/datasets.py
index 9912a13..c88648c 100644
--- a/python/mxnet/gluon/data/vision/datasets.py
+++ b/python/mxnet/gluon/data/vision/datasets.py
@@ -48,6 +48,7 @@ class MNIST(dataset._DownloadedDataset):
     train : bool, default True
         Whether to load the training or testing set.
     transform : function, default None
+        DEPRECATED FUNCTION ARGUMENT.
         A user defined callback that transforms each sample. For example::
 
             transform=lambda data, label: (data.astype(np.float32)/255, label)
@@ -111,6 +112,7 @@ class FashionMNIST(MNIST):
     train : bool, default True
         Whether to load the training or testing set.
     transform : function, default None
+        DEPRECATED FUNCTION ARGUMENT.
         A user defined callback that transforms each sample. For example::
 
             transform=lambda data, label: (data.astype(np.float32)/255, label)
@@ -143,6 +145,7 @@ class CIFAR10(dataset._DownloadedDataset):
     train : bool, default True
         Whether to load the training or testing set.
     transform : function, default None
+        DEPRECATED FUNCTION ARGUMENT.
         A user defined callback that transforms each sample. For example::
 
             transform=lambda data, label: (data.astype(np.float32)/255, label)
@@ -208,6 +211,7 @@ class CIFAR100(CIFAR10):
     train : bool, default True
         Whether to load the training or testing set.
     transform : function, default None
+        DEPRECATED FUNCTION ARGUMENT.
         A user defined callback that transforms each sample. For example::
 
             transform=lambda data, label: (data.astype(np.float32)/255, label)
@@ -244,6 +248,7 @@ class ImageRecordDataset(dataset.RecordFileDataset):
         If 0, always convert images to greyscale. \
         If 1, always convert images to colored (RGB).
     transform : function, default None
+        DEPRECATED FUNCTION ARGUMENT.
         A user defined callback that transforms each sample. For example::
 
             transform=lambda data, label: (data.astype(np.float32)/255, label)
@@ -251,6 +256,10 @@ class ImageRecordDataset(dataset.RecordFileDataset):
     """
     def __init__(self, filename, flag=1, transform=None):
         super(ImageRecordDataset, self).__init__(filename)
+        if transform is not None:
+            raise DeprecationWarning(
+                'Directly applying a transform to the dataset is deprecated. '
+                'Please use dataset.transform() or dataset.transform_first() instead.')
         self._flag = flag
         self._transform = transform
 
@@ -287,6 +296,7 @@ class ImageFolderDataset(dataset.Dataset):
         If 0, always convert loaded images to greyscale (1 channel).
         If 1, always convert loaded images to colored (3 channels).
     transform : callable, default None
+        DEPRECATED FUNCTION ARGUMENT.
         A function that takes data and label and transforms them::
 
             transform = lambda data, label: (data.astype(np.float32)/255, label)
@@ -301,6 +311,10 @@ class ImageFolderDataset(dataset.Dataset):
     def __init__(self, root, flag=1, transform=None):
         self._root = os.path.expanduser(root)
         self._flag = flag
+        if transform is not None:
+            raise DeprecationWarning(
+                'Directly applying a transform to the dataset is deprecated. '
+                'Please use dataset.transform() or dataset.transform_first() instead.')
         self._transform = transform
         self._exts = ['.jpg', '.jpeg', '.png']
         self._list_images(self._root)
