This is an automated email from the ASF dual-hosted git repository. jxie pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push: new 29f7f43 gluon dataset download refactor (#10008) 29f7f43 is described below commit 29f7f43d9aca7a9abd67a0b7baea655b7aa0ed48 Author: Sheng Zha <s...@users.noreply.github.com> AuthorDate: Tue Mar 6 14:47:09 2018 -0500 gluon dataset download refactor (#10008) --- python/mxnet/gluon/contrib/data/text.py | 21 ++++++++++++++------- python/mxnet/gluon/data/dataset.py | 23 ++++++----------------- python/mxnet/gluon/data/vision/datasets.py | 20 ++++++++++++-------- python/mxnet/gluon/utils.py | 30 +++++++++++++++++++++++++++--- 4 files changed, 59 insertions(+), 35 deletions(-) diff --git a/python/mxnet/gluon/contrib/data/text.py b/python/mxnet/gluon/contrib/data/text.py index 1f52954..8f69dbd 100644 --- a/python/mxnet/gluon/contrib/data/text.py +++ b/python/mxnet/gluon/contrib/data/text.py @@ -28,16 +28,17 @@ import numpy as np from . import _constants as C from ...data import dataset -from ...utils import download, check_sha1 +from ...utils import download, check_sha1, _get_repo_file_url from ....contrib import text from .... 
import nd class _LanguageModelDataset(dataset._DownloadedDataset): # pylint: disable=abstract-method - def __init__(self, repo_dir, root, vocabulary): + def __init__(self, root, namespace, vocabulary): self._vocab = vocabulary self._counter = None - super(_LanguageModelDataset, self).__init__(repo_dir, root, None) + self._namespace = namespace + super(_LanguageModelDataset, self).__init__(root, None) @property def vocabulary(self): @@ -76,7 +77,7 @@ class _WikiText(_LanguageModelDataset): data_file_name, data_hash = self._data_file[self._segment] path = os.path.join(self._root, data_file_name) if not os.path.exists(path) or not check_sha1(path, data_hash): - downloaded_file_path = download(self._get_url(archive_file_name), + downloaded_file_path = download(_get_repo_file_url(self._namespace, archive_file_name), path=self._root, sha1_hash=archive_hash) @@ -89,11 +90,17 @@ class _WikiText(_LanguageModelDataset): open(dest, "wb") as target: shutil.copyfileobj(source, target) - data, label = self._read_batch(os.path.join(self._root, data_file_name)) + data, label = self._read_batch(path) self._data = nd.array(data, dtype=data.dtype).reshape((-1, self._seq_len)) self._label = nd.array(label, dtype=label.dtype).reshape((-1, self._seq_len)) + def __getitem__(self, idx): + return self._data[idx], self._label[idx] + + def __len__(self): + return len(self._label) + class WikiText2(_WikiText): """WikiText-2 word-level dataset for language modeling, from Salesforce research. 
@@ -130,7 +137,7 @@ class WikiText2(_WikiText): 'c7b8ce0aa086fb34dab808c5c49224211eb2b172')} self._segment = segment self._seq_len = seq_len - super(WikiText2, self).__init__('wikitext-2', root, vocab) + super(WikiText2, self).__init__(root, 'wikitext-2', vocab) class WikiText103(_WikiText): @@ -167,4 +174,4 @@ class WikiText103(_WikiText): '8a5befc548865cec54ed4273cf87dbbad60d1e47')} self._segment = segment self._seq_len = seq_len - super(WikiText103, self).__init__('wikitext-103', root, vocab) + super(WikiText103, self).__init__(root, 'wikitext-103', vocab) diff --git a/python/mxnet/gluon/data/dataset.py b/python/mxnet/gluon/data/dataset.py index fe1e813..bf5fa0a 100644 --- a/python/mxnet/gluon/data/dataset.py +++ b/python/mxnet/gluon/data/dataset.py @@ -182,24 +182,18 @@ class RecordFileDataset(Dataset): def __len__(self): return len(self._record.keys) -apache_repo_url = 'https://apache-mxnet.s3-accelerate.dualstack.amazonaws.com/' class _DownloadedDataset(Dataset): """Base class for MNIST, cifar10, etc.""" - def __init__(self, repo_dir, root, transform): - self._root = os.path.expanduser(root) - self._repo_dir = repo_dir + def __init__(self, root, transform): + super(_DownloadedDataset, self).__init__() self._transform = transform self._data = None self._label = None - - repo_url = os.environ.get('MXNET_GLUON_REPO', apache_repo_url) - if repo_url[-1] != '/': - repo_url = repo_url+'/' - self._base_url = repo_url - - if not os.path.isdir(self._root): - os.makedirs(self._root) + root = os.path.expanduser(root) + self._root = root + if not os.path.isdir(root): + os.makedirs(root) self._get_data() def __getitem__(self, idx): @@ -212,8 +206,3 @@ class _DownloadedDataset(Dataset): def _get_data(self): raise NotImplementedError - - def _get_url(self, filename): - return '{base_url}gluon/dataset/{repo_dir}/{filename}'.format(base_url=self._base_url, - repo_dir=self._repo_dir, - filename=filename) diff --git a/python/mxnet/gluon/data/vision/datasets.py 
b/python/mxnet/gluon/data/vision/datasets.py index 4568da6..d199cad 100644 --- a/python/mxnet/gluon/data/vision/datasets.py +++ b/python/mxnet/gluon/data/vision/datasets.py @@ -29,7 +29,7 @@ import warnings import numpy as np from .. import dataset -from ...utils import download, check_sha1 +from ...utils import download, check_sha1, _get_repo_file_url from .... import nd, image, recordio @@ -62,7 +62,8 @@ class MNIST(dataset._DownloadedDataset): 'c3a25af1f52dad7f726cce8cacb138654b760d48') self._test_label = ('t10k-labels-idx1-ubyte.gz', '763e7fa3757d93b0cdec073cef058b2004252c17') - super(MNIST, self).__init__('mnist', root, transform) + self._namespace = 'mnist' + super(MNIST, self).__init__(root, transform) def _get_data(self): if self._train: @@ -70,10 +71,10 @@ class MNIST(dataset._DownloadedDataset): else: data, label = self._test_data, self._test_label - data_file = download(self._get_url(data[0]), + data_file = download(_get_repo_file_url(self._namespace, data[0]), path=self._root, sha1_hash=data[1]) - label_file = download(self._get_url(label[0]), + label_file = download(_get_repo_file_url(self._namespace, label[0]), path=self._root, sha1_hash=label[1]) @@ -121,7 +122,8 @@ class FashionMNIST(MNIST): '626ed6a7c06dd17c0eec72fa3be1740f146a2863') self._test_label = ('t10k-labels-idx1-ubyte.gz', '17f9ab60e7257a1620f4ad76bbbaf857c3920701') - super(MNIST, self).__init__('fashion-mnist', root, transform) # pylint: disable=bad-super-call + self._namespace = 'fashion-mnist' + super(MNIST, self).__init__(root, transform) # pylint: disable=bad-super-call class CIFAR10(dataset._DownloadedDataset): @@ -152,7 +154,8 @@ class CIFAR10(dataset._DownloadedDataset): ('data_batch_4.bin', 'aab85764eb3584312d3c7f65fd2fd016e36a258e'), ('data_batch_5.bin', '26e2849e66a845b7f1e4614ae70f4889ae604628')] self._test_data = [('test_batch.bin', '67eb016db431130d61cd03c7ad570b013799c88c')] - super(CIFAR10, self).__init__('cifar10', root, transform) + self._namespace = 'cifar10' + 
super(CIFAR10, self).__init__(root, transform) def _read_batch(self, filename): with open(filename, 'rb') as fin: @@ -165,7 +168,7 @@ class CIFAR10(dataset._DownloadedDataset): if any(not os.path.exists(path) or not check_sha1(path, sha1) for path, sha1 in ((os.path.join(self._root, name), sha1) for name, sha1 in self._train_data + self._test_data)): - filename = download(self._get_url(self._archive_file[0]), + filename = download(_get_repo_file_url(self._namespace, self._archive_file[0]), path=self._root, sha1_hash=self._archive_file[1]) @@ -212,7 +215,8 @@ class CIFAR100(CIFAR10): self._train_data = [('train.bin', 'e207cd2e05b73b1393c74c7f5e7bea451d63e08e')] self._test_data = [('test.bin', '8fb6623e830365ff53cf14adec797474f5478006')] self._fine_label = fine_label - super(CIFAR10, self).__init__('cifar100', root, transform) # pylint: disable=bad-super-call + self._namespace = 'cifar100' + super(CIFAR10, self).__init__(root, transform) # pylint: disable=bad-super-call def _read_batch(self, filename): with open(filename, 'rb') as fin: diff --git a/python/mxnet/gluon/utils.py b/python/mxnet/gluon/utils.py index 88effc9..9ebfe4c 100644 --- a/python/mxnet/gluon/utils.py +++ b/python/mxnet/gluon/utils.py @@ -192,10 +192,12 @@ def download(url, path=None, overwrite=False, sha1_hash=None): """ if path is None: fname = url.split('/')[-1] - elif os.path.isdir(path): - fname = os.path.join(path, url.split('/')[-1]) else: - fname = path + path = os.path.expanduser(path) + if os.path.isdir(path): + fname = os.path.join(path, url.split('/')[-1]) + else: + fname = path if overwrite or not os.path.exists(fname) or (sha1_hash and not check_sha1(fname, sha1_hash)): dirname = os.path.dirname(os.path.abspath(os.path.expanduser(fname))) @@ -218,3 +220,25 @@ def download(url, path=None, overwrite=False, sha1_hash=None): 'the default repo.'.format(fname)) return fname + +def _get_repo_url(): + """Return the base URL for Gluon dataset and model repository.""" + default_repo = 
'https://apache-mxnet.s3-accelerate.dualstack.amazonaws.com/' + repo_url = os.environ.get('MXNET_GLUON_REPO', default_repo) + if repo_url[-1] != '/': + repo_url = repo_url+'/' + return repo_url + +def _get_repo_file_url(namespace, filename): + """Return the URL for hosted files in the Gluon repository. + + Parameters + ---------- + namespace : str + Namespace of the file. + filename : str + Name of the file + """ + return '{base_url}gluon/dataset/{namespace}/{filename}'.format(base_url=_get_repo_url(), + namespace=namespace, + filename=filename) -- To stop receiving notification emails like this one, please contact j...@apache.org.