This is an automated email from the ASF dual-hosted git repository.

jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 29f7f43  gluon dataset download refactor (#10008)
29f7f43 is described below

commit 29f7f43d9aca7a9abd67a0b7baea655b7aa0ed48
Author: Sheng Zha <s...@users.noreply.github.com>
AuthorDate: Tue Mar 6 14:47:09 2018 -0500

    gluon dataset download refactor (#10008)
---
 python/mxnet/gluon/contrib/data/text.py    | 21 ++++++++++++++-------
 python/mxnet/gluon/data/dataset.py         | 23 ++++++-----------------
 python/mxnet/gluon/data/vision/datasets.py | 20 ++++++++++++--------
 python/mxnet/gluon/utils.py                | 30 +++++++++++++++++++++++++++---
 4 files changed, 59 insertions(+), 35 deletions(-)
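
The refactor removes the per-dataset URL builder (_DownloadedDataset._get_url)
and centralizes URL composition in a single helper, _get_repo_file_url(namespace,
filename), in python/mxnet/gluon/utils.py. A minimal sketch of the resulting
behaviour, assuming MXNET_GLUON_REPO is unset (the helper is internal, so the
import path is an implementation detail):

    from mxnet.gluon.utils import _get_repo_file_url

    # '<namespace>' and '<filename>' are placeholders
    _get_repo_file_url('<namespace>', '<filename>')
    # -> 'https://apache-mxnet.s3-accelerate.dualstack.amazonaws.com/gluon/dataset/<namespace>/<filename>'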

diff --git a/python/mxnet/gluon/contrib/data/text.py b/python/mxnet/gluon/contrib/data/text.py
index 1f52954..8f69dbd 100644
--- a/python/mxnet/gluon/contrib/data/text.py
+++ b/python/mxnet/gluon/contrib/data/text.py
@@ -28,16 +28,17 @@ import numpy as np
 
 from . import _constants as C
 from ...data import dataset
-from ...utils import download, check_sha1
+from ...utils import download, check_sha1, _get_repo_file_url
 from ....contrib import text
 from .... import nd
 
 
 class _LanguageModelDataset(dataset._DownloadedDataset): # pylint: disable=abstract-method
-    def __init__(self, repo_dir, root, vocabulary):
+    def __init__(self, root, namespace, vocabulary):
         self._vocab = vocabulary
         self._counter = None
-        super(_LanguageModelDataset, self).__init__(repo_dir, root, None)
+        self._namespace = namespace
+        super(_LanguageModelDataset, self).__init__(root, None)
 
     @property
     def vocabulary(self):
@@ -76,7 +77,7 @@ class _WikiText(_LanguageModelDataset):
         data_file_name, data_hash = self._data_file[self._segment]
         path = os.path.join(self._root, data_file_name)
         if not os.path.exists(path) or not check_sha1(path, data_hash):
-            downloaded_file_path = download(self._get_url(archive_file_name),
+            downloaded_file_path = download(_get_repo_file_url(self._namespace, archive_file_name),
                                             path=self._root,
                                             sha1_hash=archive_hash)
 
@@ -89,11 +90,17 @@ class _WikiText(_LanguageModelDataset):
                              open(dest, "wb") as target:
                             shutil.copyfileobj(source, target)
 
-        data, label = self._read_batch(os.path.join(self._root, data_file_name))
+        data, label = self._read_batch(path)
 
         self._data = nd.array(data, dtype=data.dtype).reshape((-1, self._seq_len))
         self._label = nd.array(label, dtype=label.dtype).reshape((-1, self._seq_len))
 
+    def __getitem__(self, idx):
+        return self._data[idx], self._label[idx]
+
+    def __len__(self):
+        return len(self._label)
+
 
 class WikiText2(_WikiText):
     """WikiText-2 word-level dataset for language modeling, from Salesforce 
research.
@@ -130,7 +137,7 @@ class WikiText2(_WikiText):
                                     'c7b8ce0aa086fb34dab808c5c49224211eb2b172')}
         self._segment = segment
         self._seq_len = seq_len
-        super(WikiText2, self).__init__('wikitext-2', root, vocab)
+        super(WikiText2, self).__init__(root, 'wikitext-2', vocab)
 
 
 class WikiText103(_WikiText):
@@ -167,4 +174,4 @@ class WikiText103(_WikiText):
                                     '8a5befc548865cec54ed4273cf87dbbad60d1e47')}
         self._segment = segment
         self._seq_len = seq_len
-        super(WikiText103, self).__init__('wikitext-103', root, vocab)
+        super(WikiText103, self).__init__(root, 'wikitext-103', vocab)
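
With the __getitem__ and __len__ methods added above, the WikiText datasets
index directly into the pre-batched (data, label) pairs and can be fed to a
DataLoader. A hedged usage sketch; the keyword arguments and their defaults
are assumed from the class signatures, which these hunks do not show in full:

    from mxnet.gluon.contrib.data.text import WikiText2
    from mxnet.gluon.data import DataLoader

    train = WikiText2(segment='train', seq_len=35)        # downloads on first use
    data, label = train[0]                                 # each of shape (seq_len,)
    loader = DataLoader(train, batch_size=32, shuffle=True)
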
diff --git a/python/mxnet/gluon/data/dataset.py b/python/mxnet/gluon/data/dataset.py
index fe1e813..bf5fa0a 100644
--- a/python/mxnet/gluon/data/dataset.py
+++ b/python/mxnet/gluon/data/dataset.py
@@ -182,24 +182,18 @@ class RecordFileDataset(Dataset):
     def __len__(self):
         return len(self._record.keys)
 
-apache_repo_url = 'https://apache-mxnet.s3-accelerate.dualstack.amazonaws.com/'
 
 class _DownloadedDataset(Dataset):
     """Base class for MNIST, cifar10, etc."""
-    def __init__(self, repo_dir, root, transform):
-        self._root = os.path.expanduser(root)
-        self._repo_dir = repo_dir
+    def __init__(self, root, transform):
+        super(_DownloadedDataset, self).__init__()
         self._transform = transform
         self._data = None
         self._label = None
-
-        repo_url = os.environ.get('MXNET_GLUON_REPO', apache_repo_url)
-        if repo_url[-1] != '/':
-            repo_url = repo_url+'/'
-        self._base_url = repo_url
-
-        if not os.path.isdir(self._root):
-            os.makedirs(self._root)
+        root = os.path.expanduser(root)
+        self._root = root
+        if not os.path.isdir(root):
+            os.makedirs(root)
         self._get_data()
 
     def __getitem__(self, idx):
@@ -212,8 +206,3 @@ class _DownloadedDataset(Dataset):
 
     def _get_data(self):
         raise NotImplementedError
-
-    def _get_url(self, filename):
-        return '{base_url}gluon/dataset/{repo_dir}/{filename}'.format(base_url=self._base_url,
-                                                                      repo_dir=self._repo_dir,
-                                                                      filename=filename)
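
With _get_url gone, _DownloadedDataset no longer takes a repo_dir argument;
a subclass now sets self._namespace before calling the base constructor (which
still ends by calling self._get_data()) and resolves download URLs through the
shared helper. An illustrative sketch; the class name, namespace and file name
below are hypothetical:

    from mxnet.gluon.data import dataset
    from mxnet.gluon.utils import download, _get_repo_file_url

    class MyDataset(dataset._DownloadedDataset):
        """Hypothetical dataset showing the new constructor contract."""
        def __init__(self, root, transform=None):
            self._namespace = 'mydataset'                     # set before super().__init__,
            super(MyDataset, self).__init__(root, transform)  # which calls self._get_data()

        def _get_data(self):
            # 'data.bin' is a placeholder; a real dataset would also pass sha1_hash
            fname = download(_get_repo_file_url(self._namespace, 'data.bin'),
                             path=self._root)
            # read fname and assign self._data / self._label here
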
diff --git a/python/mxnet/gluon/data/vision/datasets.py b/python/mxnet/gluon/data/vision/datasets.py
index 4568da6..d199cad 100644
--- a/python/mxnet/gluon/data/vision/datasets.py
+++ b/python/mxnet/gluon/data/vision/datasets.py
@@ -29,7 +29,7 @@ import warnings
 import numpy as np
 
 from .. import dataset
-from ...utils import download, check_sha1
+from ...utils import download, check_sha1, _get_repo_file_url
 from .... import nd, image, recordio
 
 
@@ -62,7 +62,8 @@ class MNIST(dataset._DownloadedDataset):
                            'c3a25af1f52dad7f726cce8cacb138654b760d48')
         self._test_label = ('t10k-labels-idx1-ubyte.gz',
                             '763e7fa3757d93b0cdec073cef058b2004252c17')
-        super(MNIST, self).__init__('mnist', root, transform)
+        self._namespace = 'mnist'
+        super(MNIST, self).__init__(root, transform)
 
     def _get_data(self):
         if self._train:
@@ -70,10 +71,10 @@ class MNIST(dataset._DownloadedDataset):
         else:
             data, label = self._test_data, self._test_label
 
-        data_file = download(self._get_url(data[0]),
+        data_file = download(_get_repo_file_url(self._namespace, data[0]),
                              path=self._root,
                              sha1_hash=data[1])
-        label_file = download(self._get_url(label[0]),
+        label_file = download(_get_repo_file_url(self._namespace, label[0]),
                               path=self._root,
                               sha1_hash=label[1])
 
@@ -121,7 +122,8 @@ class FashionMNIST(MNIST):
                            '626ed6a7c06dd17c0eec72fa3be1740f146a2863')
         self._test_label = ('t10k-labels-idx1-ubyte.gz',
                             '17f9ab60e7257a1620f4ad76bbbaf857c3920701')
-        super(MNIST, self).__init__('fashion-mnist', root, transform) # pylint: disable=bad-super-call
+        self._namespace = 'fashion-mnist'
+        super(MNIST, self).__init__(root, transform) # pylint: disable=bad-super-call
 
 
 class CIFAR10(dataset._DownloadedDataset):
@@ -152,7 +154,8 @@ class CIFAR10(dataset._DownloadedDataset):
                             ('data_batch_4.bin', 'aab85764eb3584312d3c7f65fd2fd016e36a258e'),
                             ('data_batch_5.bin', '26e2849e66a845b7f1e4614ae70f4889ae604628')]
         self._test_data = [('test_batch.bin', '67eb016db431130d61cd03c7ad570b013799c88c')]
-        super(CIFAR10, self).__init__('cifar10', root, transform)
+        self._namespace = 'cifar10'
+        super(CIFAR10, self).__init__(root, transform)
 
     def _read_batch(self, filename):
         with open(filename, 'rb') as fin:
@@ -165,7 +168,7 @@ class CIFAR10(dataset._DownloadedDataset):
         if any(not os.path.exists(path) or not check_sha1(path, sha1)
                for path, sha1 in ((os.path.join(self._root, name), sha1)
                                   for name, sha1 in self._train_data + self._test_data)):
-            filename = download(self._get_url(self._archive_file[0]),
+            filename = download(_get_repo_file_url(self._namespace, self._archive_file[0]),
                                 path=self._root,
                                 sha1_hash=self._archive_file[1])
 
@@ -212,7 +215,8 @@ class CIFAR100(CIFAR10):
         self._train_data = [('train.bin', 'e207cd2e05b73b1393c74c7f5e7bea451d63e08e')]
         self._test_data = [('test.bin', '8fb6623e830365ff53cf14adec797474f5478006')]
         self._fine_label = fine_label
-        super(CIFAR10, self).__init__('cifar100', root, transform) # pylint: disable=bad-super-call
+        self._namespace = 'cifar100'
+        super(CIFAR10, self).__init__(root, transform) # pylint: disable=bad-super-call
 
     def _read_batch(self, filename):
         with open(filename, 'rb') as fin:
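
The pylint: disable=bad-super-call comments remain because FashionMNIST and
CIFAR100 deliberately skip their parent's __init__ (which would reset the file
lists and namespace) and jump straight to _DownloadedDataset.__init__. Usage
is unchanged; a hedged example, assuming the transform(data, label) convention
used by _DownloadedDataset.__getitem__:

    from mxnet.gluon.data.vision import FashionMNIST

    def transform(data, label):
        return data.astype('float32') / 255, label

    train = FashionMNIST(train=True, transform=transform)  # downloads on first use
    data, label = train[0]
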
diff --git a/python/mxnet/gluon/utils.py b/python/mxnet/gluon/utils.py
index 88effc9..9ebfe4c 100644
--- a/python/mxnet/gluon/utils.py
+++ b/python/mxnet/gluon/utils.py
@@ -192,10 +192,12 @@ def download(url, path=None, overwrite=False, sha1_hash=None):
     """
     if path is None:
         fname = url.split('/')[-1]
-    elif os.path.isdir(path):
-        fname = os.path.join(path, url.split('/')[-1])
     else:
-        fname = path
+        path = os.path.expanduser(path)
+        if os.path.isdir(path):
+            fname = os.path.join(path, url.split('/')[-1])
+        else:
+            fname = path
 
     if overwrite or not os.path.exists(fname) or (sha1_hash and not check_sha1(fname, sha1_hash)):
         dirname = os.path.dirname(os.path.abspath(os.path.expanduser(fname)))
@@ -218,3 +220,25 @@ def download(url, path=None, overwrite=False, sha1_hash=None):
                               'the default repo.'.format(fname))
 
     return fname
+
+def _get_repo_url():
+    """Return the base URL for Gluon dataset and model repository."""
+    default_repo = 'https://apache-mxnet.s3-accelerate.dualstack.amazonaws.com/'
+    repo_url = os.environ.get('MXNET_GLUON_REPO', default_repo)
+    if repo_url[-1] != '/':
+        repo_url = repo_url+'/'
+    return repo_url
+
+def _get_repo_file_url(namespace, filename):
+    """Return the URL for hoste file in Gluon repository.
+
+    Parameters
+    ----------
+    namespace : str
+        Namespace of the file.
+    filename : str
+        Name of the file.
+    """
+    return '{base_url}gluon/dataset/{namespace}/{filename}'.format(base_url=_get_repo_url(),
+                                                                   namespace=namespace,
+                                                                   filename=filename)
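
Two behavioural notes on the utils.py hunks above: download() now expands '~'
in path before deciding whether it is a directory, and _get_repo_url() still
honours MXNET_GLUON_REPO, appending a missing trailing slash. A sketch under
those assumptions; the mirror URL is illustrative, while the file name and
hash come from the MNIST hunk above:

    import os
    from mxnet.gluon.utils import download, _get_repo_file_url, _get_repo_url

    root = '~/.mxnet/datasets/mnist'
    if not os.path.isdir(os.path.expanduser(root)):
        os.makedirs(os.path.expanduser(root))
    fname = download(_get_repo_file_url('mnist', 't10k-labels-idx1-ubyte.gz'),
                     path=root,                            # '~' expanded by download()
                     sha1_hash='763e7fa3757d93b0cdec073cef058b2004252c17')

    os.environ['MXNET_GLUON_REPO'] = 'https://example.com/mirror'  # no trailing slash
    assert _get_repo_url() == 'https://example.com/mirror/'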

-- 
To stop receiving notification emails like this one, please contact
j...@apache.org.
