This is an automated email from the ASF dual-hosted git repository.
marcoabreu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 564e01a [MXNET-769] set MXNET_HOME as base for downloaded models
through base.data_dir() (#11636)
564e01a is described below
commit 564e01acdf460535d4ab7340db39b0d10028b453
Author: Pedro Larroy <[email protected]>
AuthorDate: Thu Aug 2 09:56:33 2018 +0200
[MXNET-769] set MXNET_HOME as base for downloaded models through
base.data_dir() (#11636)
* set MXNET_DATA_DIR as base for downloaded models through base.data_dir()
move the joblib import into build_save_containers so joblib is only required when saving containers
* MXNET_DATA_DIR -> MXNET_HOME
---
ci/docker_cache.py | 2 +-
.../examples/scripts/get_cifar_data.sh | 4 +-
.../examples/scripts/get_mnist_data.sh | 4 +-
contrib/clojure-package/scripts/get_cifar_data.sh | 4 +-
contrib/clojure-package/scripts/get_mnist_data.sh | 4 +-
docs/faq/env_var.md | 4 ++
python/mxnet/base.py | 24 ++++++++++-
python/mxnet/contrib/text/embedding.py | 9 ++--
python/mxnet/gluon/contrib/data/text.py | 11 +++--
python/mxnet/gluon/data/vision/datasets.py | 18 ++++----
python/mxnet/gluon/model_zoo/model_store.py | 17 ++++----
python/mxnet/gluon/model_zoo/vision/__init__.py | 2 +-
python/mxnet/gluon/model_zoo/vision/alexnet.py | 5 ++-
python/mxnet/gluon/model_zoo/vision/densenet.py | 13 +++---
python/mxnet/gluon/model_zoo/vision/inception.py | 5 ++-
python/mxnet/gluon/model_zoo/vision/mobilenet.py | 9 ++--
python/mxnet/gluon/model_zoo/vision/resnet.py | 25 +++++------
python/mxnet/gluon/model_zoo/vision/squeezenet.py | 9 ++--
python/mxnet/gluon/model_zoo/vision/vgg.py | 21 ++++-----
.../get_mnist_data.sh => python/mxnet/util.py | 30 +++++--------
scala-package/core/scripts/get_cifar_data.sh | 4 +-
scala-package/core/scripts/get_mnist_data.sh | 4 +-
.../src/test/scala/org/apache/mxnet/TestUtil.scala | 2 +-
.../org/apache/mxnetexamples/gan/GanMnist.scala | 2 +-
.../imclassification/TrainMnist.scala | 2 +-
.../imageclassifier/ImageClassifierExample.scala | 6 +--
tests/python/unittest/test_base.py | 50 ++++++++++++++++++++++
27 files changed, 182 insertions(+), 108 deletions(-)
diff --git a/ci/docker_cache.py b/ci/docker_cache.py
index 6637ec3..7a6d110 100755
--- a/ci/docker_cache.py
+++ b/ci/docker_cache.py
@@ -31,7 +31,6 @@ import sys
import subprocess
import json
import build as build_util
-from joblib import Parallel, delayed
@@ -43,6 +42,7 @@ def build_save_containers(platforms, registry, load_cache) ->
int:
:param load_cache: Load cache before building
:return: 1 if error occurred, 0 otherwise
"""
+ from joblib import Parallel, delayed
if len(platforms) == 0:
return 0
diff --git a/contrib/clojure-package/examples/scripts/get_cifar_data.sh
b/contrib/clojure-package/examples/scripts/get_cifar_data.sh
index 372c7bb..12b3770 100755
--- a/contrib/clojure-package/examples/scripts/get_cifar_data.sh
+++ b/contrib/clojure-package/examples/scripts/get_cifar_data.sh
@@ -20,8 +20,8 @@
set -evx
-if [ ! -z "$MXNET_DATA_DIR" ]; then
- data_path="$MXNET_DATA_DIR"
+if [ ! -z "$MXNET_HOME" ]; then
+ data_path="$MXNET_HOME"
else
data_path="./data"
fi
diff --git a/contrib/clojure-package/examples/scripts/get_mnist_data.sh
b/contrib/clojure-package/examples/scripts/get_mnist_data.sh
index 6f32b85..703ece2 100755
--- a/contrib/clojure-package/examples/scripts/get_mnist_data.sh
+++ b/contrib/clojure-package/examples/scripts/get_mnist_data.sh
@@ -20,8 +20,8 @@
set -evx
-if [ ! -z "$MXNET_DATA_DIR" ]; then
- data_path="$MXNET_DATA_DIR"
+if [ ! -z "$MXNET_HOME" ]; then
+ data_path="$MXNET_HOME"
else
data_path="./data"
fi
diff --git a/contrib/clojure-package/scripts/get_cifar_data.sh
b/contrib/clojure-package/scripts/get_cifar_data.sh
index 372c7bb..12b3770 100755
--- a/contrib/clojure-package/scripts/get_cifar_data.sh
+++ b/contrib/clojure-package/scripts/get_cifar_data.sh
@@ -20,8 +20,8 @@
set -evx
-if [ ! -z "$MXNET_DATA_DIR" ]; then
- data_path="$MXNET_DATA_DIR"
+if [ ! -z "$MXNET_HOME" ]; then
+ data_path="$MXNET_HOME"
else
data_path="./data"
fi
diff --git a/contrib/clojure-package/scripts/get_mnist_data.sh
b/contrib/clojure-package/scripts/get_mnist_data.sh
index 6f32b85..703ece2 100755
--- a/contrib/clojure-package/scripts/get_mnist_data.sh
+++ b/contrib/clojure-package/scripts/get_mnist_data.sh
@@ -20,8 +20,8 @@
set -evx
-if [ ! -z "$MXNET_DATA_DIR" ]; then
- data_path="$MXNET_DATA_DIR"
+if [ ! -z "$MXNET_HOME" ]; then
+ data_path="$MXNET_HOME"
else
data_path="./data"
fi
diff --git a/docs/faq/env_var.md b/docs/faq/env_var.md
index 881bc14..6e9a359 100644
--- a/docs/faq/env_var.md
+++ b/docs/faq/env_var.md
@@ -152,6 +152,10 @@ When USE_PROFILER is enabled in Makefile or CMake, the
following environments ca
- Values: String
```(default='https://apache-mxnet.s3-accelerate.dualstack.amazonaws.com/'```
- The repository url to be used for Gluon datasets and pre-trained models.
+* MXNET_HOME
+ - Data directory in the filesystem for storage, for example when downloading gluon models.
+ - Default on *nix is ~/.mxnet; on Windows it is %APPDATA%/mxnet.
+
Settings for Minimum Memory Usage
---------------------------------
- Make sure ```min(MXNET_EXEC_NUM_TEMP, MXNET_GPU_WORKER_NTHREADS) = 1```
diff --git a/python/mxnet/base.py b/python/mxnet/base.py
index 4df794b..3d8ee01 100644
--- a/python/mxnet/base.py
+++ b/python/mxnet/base.py
@@ -22,11 +22,11 @@ from __future__ import absolute_import
import atexit
import ctypes
-import inspect
import os
import sys
import warnings
-
+import inspect
+import platform
import numpy as np
from . import libinfo
@@ -59,6 +59,26 @@ else:
py_str = lambda x: x
+def data_dir_default():
+ """
+
+ :return: default data directory depending on the platform and environment
variables
+ """
+ system = platform.system()
+ if system == 'Windows':
+ return os.path.join(os.environ.get('APPDATA'), 'mxnet')
+ else:
+ return os.path.join(os.path.expanduser("~"), '.mxnet')
+
+
+def data_dir():
+ """
+
+ :return: data directory in the filesystem for storage, for example when
downloading models
+ """
+ return os.getenv('MXNET_HOME', data_dir_default())
+
+
class _NullType(object):
"""Placeholder for arguments"""
def __repr__(self):
diff --git a/python/mxnet/contrib/text/embedding.py
b/python/mxnet/contrib/text/embedding.py
index 6598718..38defb4 100644
--- a/python/mxnet/contrib/text/embedding.py
+++ b/python/mxnet/contrib/text/embedding.py
@@ -34,6 +34,7 @@ from . import _constants as C
from . import vocab
from ... import ndarray as nd
from ... import registry
+from ... import base
def register(embedding_cls):
@@ -496,7 +497,7 @@ class GloVe(_TokenEmbedding):
----------
pretrained_file_name : str, default 'glove.840B.300d.txt'
The name of the pre-trained token embedding file.
- embedding_root : str, default os.path.join('~', '.mxnet', 'embeddings')
+ embedding_root : str, default $MXNET_HOME/embeddings
The root directory for storing embedding-related files.
init_unknown_vec : callback
The callback used to initialize the embedding vector for the unknown
token.
@@ -541,7 +542,7 @@ class GloVe(_TokenEmbedding):
return archive
def __init__(self, pretrained_file_name='glove.840B.300d.txt',
- embedding_root=os.path.join('~', '.mxnet', 'embeddings'),
+ embedding_root=os.path.join(base.data_dir(), 'embeddings'),
init_unknown_vec=nd.zeros, vocabulary=None, **kwargs):
GloVe._check_pretrained_file_names(pretrained_file_name)
@@ -600,7 +601,7 @@ class FastText(_TokenEmbedding):
----------
pretrained_file_name : str, default 'wiki.en.vec'
The name of the pre-trained token embedding file.
- embedding_root : str, default os.path.join('~', '.mxnet', 'embeddings')
+ embedding_root : str, default $MXNET_HOME/embeddings
The root directory for storing embedding-related files.
init_unknown_vec : callback
The callback used to initialize the embedding vector for the unknown
token.
@@ -642,7 +643,7 @@ class FastText(_TokenEmbedding):
return '.'.join(pretrained_file_name.split('.')[:-1])+'.zip'
def __init__(self, pretrained_file_name='wiki.simple.vec',
- embedding_root=os.path.join('~', '.mxnet', 'embeddings'),
+ embedding_root=os.path.join(base.data_dir(), 'embeddings'),
init_unknown_vec=nd.zeros, vocabulary=None, **kwargs):
FastText._check_pretrained_file_names(pretrained_file_name)
diff --git a/python/mxnet/gluon/contrib/data/text.py
b/python/mxnet/gluon/contrib/data/text.py
index 98fe6b6..9e78e3c 100644
--- a/python/mxnet/gluon/contrib/data/text.py
+++ b/python/mxnet/gluon/contrib/data/text.py
@@ -30,8 +30,7 @@ from . import _constants as C
from ...data import dataset
from ...utils import download, check_sha1, _get_repo_file_url
from ....contrib import text
-from .... import nd
-
+from .... import nd, base
class _LanguageModelDataset(dataset._DownloadedDataset): # pylint:
disable=abstract-method
def __init__(self, root, namespace, vocabulary):
@@ -116,7 +115,7 @@ class WikiText2(_WikiText):
Parameters
----------
- root : str, default '~/.mxnet/datasets/wikitext-2'
+ root : str, default $MXNET_HOME/datasets/wikitext-2
Path to temp folder for storing data.
segment : str, default 'train'
Dataset segment. Options are 'train', 'validation', 'test'.
@@ -127,7 +126,7 @@ class WikiText2(_WikiText):
The sequence length of each sample, regardless of the sentence
boundary.
"""
- def __init__(self, root=os.path.join('~', '.mxnet', 'datasets',
'wikitext-2'),
+ def __init__(self, root=os.path.join(base.data_dir(), 'datasets',
'wikitext-2'),
segment='train', vocab=None, seq_len=35):
self._archive_file = ('wikitext-2-v1.zip',
'3c914d17d80b1459be871a5039ac23e752a53cbe')
self._data_file = {'train': ('wiki.train.tokens',
@@ -154,7 +153,7 @@ class WikiText103(_WikiText):
Parameters
----------
- root : str, default '~/.mxnet/datasets/wikitext-103'
+ root : str, default $MXNET_HOME/datasets/wikitext-103
Path to temp folder for storing data.
segment : str, default 'train'
Dataset segment. Options are 'train', 'validation', 'test'.
@@ -164,7 +163,7 @@ class WikiText103(_WikiText):
seq_len : int, default 35
The sequence length of each sample, regardless of the sentence
boundary.
"""
- def __init__(self, root=os.path.join('~', '.mxnet', 'datasets',
'wikitext-103'),
+ def __init__(self, root=os.path.join(base.data_dir(), 'datasets',
'wikitext-103'),
segment='train', vocab=None, seq_len=35):
self._archive_file = ('wikitext-103-v1.zip',
'0aec09a7537b58d4bb65362fee27650eeaba625a')
self._data_file = {'train': ('wiki.train.tokens',
diff --git a/python/mxnet/gluon/data/vision/datasets.py
b/python/mxnet/gluon/data/vision/datasets.py
index 74a5aeb..2c98000 100644
--- a/python/mxnet/gluon/data/vision/datasets.py
+++ b/python/mxnet/gluon/data/vision/datasets.py
@@ -30,7 +30,7 @@ import numpy as np
from .. import dataset
from ...utils import download, check_sha1, _get_repo_file_url
-from .... import nd, image, recordio
+from .... import nd, image, recordio, base
class MNIST(dataset._DownloadedDataset):
@@ -40,7 +40,7 @@ class MNIST(dataset._DownloadedDataset):
Parameters
----------
- root : str, default '~/.mxnet/datasets/mnist'
+ root : str, default $MXNET_HOME/datasets/mnist
Path to temp folder for storing data.
train : bool, default True
Whether to load the training or testing set.
@@ -51,7 +51,7 @@ class MNIST(dataset._DownloadedDataset):
transform=lambda data, label: (data.astype(np.float32)/255, label)
"""
- def __init__(self, root=os.path.join('~', '.mxnet', 'datasets', 'mnist'),
+ def __init__(self, root=os.path.join(base.data_dir(), 'datasets', 'mnist'),
train=True, transform=None):
self._train = train
self._train_data = ('train-images-idx3-ubyte.gz',
@@ -101,7 +101,7 @@ class FashionMNIST(MNIST):
Parameters
----------
- root : str, default '~/.mxnet/datasets/fashion-mnist'
+ root : str, default $MXNET_HOME/datasets/fashion-mnist
Path to temp folder for storing data.
train : bool, default True
Whether to load the training or testing set.
@@ -112,7 +112,7 @@ class FashionMNIST(MNIST):
transform=lambda data, label: (data.astype(np.float32)/255, label)
"""
- def __init__(self, root=os.path.join('~', '.mxnet', 'datasets',
'fashion-mnist'),
+ def __init__(self, root=os.path.join(base.data_dir(), 'datasets',
'fashion-mnist'),
train=True, transform=None):
self._train = train
self._train_data = ('train-images-idx3-ubyte.gz',
@@ -134,7 +134,7 @@ class CIFAR10(dataset._DownloadedDataset):
Parameters
----------
- root : str, default '~/.mxnet/datasets/cifar10'
+ root : str, default $MXNET_HOME/datasets/cifar10
Path to temp folder for storing data.
train : bool, default True
Whether to load the training or testing set.
@@ -145,7 +145,7 @@ class CIFAR10(dataset._DownloadedDataset):
transform=lambda data, label: (data.astype(np.float32)/255, label)
"""
- def __init__(self, root=os.path.join('~', '.mxnet', 'datasets', 'cifar10'),
+ def __init__(self, root=os.path.join(base.data_dir(), 'datasets',
'cifar10'),
train=True, transform=None):
self._train = train
self._archive_file = ('cifar-10-binary.tar.gz',
'fab780a1e191a7eda0f345501ccd62d20f7ed891')
@@ -197,7 +197,7 @@ class CIFAR100(CIFAR10):
Parameters
----------
- root : str, default '~/.mxnet/datasets/cifar100'
+ root : str, default $MXNET_HOME/datasets/cifar100
Path to temp folder for storing data.
fine_label : bool, default False
Whether to load the fine-grained (100 classes) or coarse-grained (20
super-classes) labels.
@@ -210,7 +210,7 @@ class CIFAR100(CIFAR10):
transform=lambda data, label: (data.astype(np.float32)/255, label)
"""
- def __init__(self, root=os.path.join('~', '.mxnet', 'datasets',
'cifar100'),
+ def __init__(self, root=os.path.join(base.data_dir(), 'datasets',
'cifar100'),
fine_label=False, train=True, transform=None):
self._train = train
self._archive_file = ('cifar-100-binary.tar.gz',
'a0bb982c76b83111308126cc779a992fa506b90b')
diff --git a/python/mxnet/gluon/model_zoo/model_store.py
b/python/mxnet/gluon/model_zoo/model_store.py
index 7eead68..11ac47b 100644
--- a/python/mxnet/gluon/model_zoo/model_store.py
+++ b/python/mxnet/gluon/model_zoo/model_store.py
@@ -21,8 +21,10 @@ from __future__ import print_function
__all__ = ['get_model_file', 'purge']
import os
import zipfile
+import logging
from ..utils import download, check_sha1
+from ... import base, util
_model_sha1 = {name: checksum for checksum, name in [
('44335d1f0046b328243b32a26a4fbd62d9057b45', 'alexnet'),
@@ -68,7 +70,7 @@ def short_hash(name):
raise ValueError('Pretrained model for {name} is not
available.'.format(name=name))
return _model_sha1[name][:8]
-def get_model_file(name, root=os.path.join('~', '.mxnet', 'models')):
+def get_model_file(name, root=os.path.join(base.data_dir(), 'models')):
r"""Return location for the pretrained on local file system.
This function will download from online model zoo when model cannot be
found or has mismatch.
@@ -78,7 +80,7 @@ def get_model_file(name, root=os.path.join('~', '.mxnet',
'models')):
----------
name : str
Name of the model.
- root : str, default '~/.mxnet/models'
+ root : str, default $MXNET_HOME/models
Location for keeping the model parameters.
Returns
@@ -95,12 +97,11 @@ def get_model_file(name, root=os.path.join('~', '.mxnet',
'models')):
if check_sha1(file_path, sha1_hash):
return file_path
else:
- print('Mismatch in the content of model file detected. Downloading
again.')
+ logging.warning('Mismatch in the content of model file detected.
Downloading again.')
else:
- print('Model file is not found. Downloading.')
+ logging.info('Model file not found. Downloading to %s.', file_path)
- if not os.path.exists(root):
- os.makedirs(root)
+ util.makedirs(root)
zip_file_path = os.path.join(root, file_name+'.zip')
repo_url = os.environ.get('MXNET_GLUON_REPO', apache_repo_url)
@@ -118,12 +119,12 @@ def get_model_file(name, root=os.path.join('~', '.mxnet',
'models')):
else:
raise ValueError('Downloaded file has different hash. Please try
again.')
-def purge(root=os.path.join('~', '.mxnet', 'models')):
+def purge(root=os.path.join(base.data_dir(), 'models')):
r"""Purge all pretrained model files in local file store.
Parameters
----------
- root : str, default '~/.mxnet/models'
+ root : str, default '$MXNET_HOME/models'
Location for keeping the model parameters.
"""
root = os.path.expanduser(root)
diff --git a/python/mxnet/gluon/model_zoo/vision/__init__.py
b/python/mxnet/gluon/model_zoo/vision/__init__.py
index a6e5dc1..7d33ce4 100644
--- a/python/mxnet/gluon/model_zoo/vision/__init__.py
+++ b/python/mxnet/gluon/model_zoo/vision/__init__.py
@@ -101,7 +101,7 @@ def get_model(name, **kwargs):
Number of classes for the output layer.
ctx : Context, default CPU
The context in which to load the pretrained weights.
- root : str, default '~/.mxnet/models'
+ root : str, default '$MXNET_HOME/models'
Location for keeping the model parameters.
Returns
diff --git a/python/mxnet/gluon/model_zoo/vision/alexnet.py
b/python/mxnet/gluon/model_zoo/vision/alexnet.py
index fdb0062..daf4617 100644
--- a/python/mxnet/gluon/model_zoo/vision/alexnet.py
+++ b/python/mxnet/gluon/model_zoo/vision/alexnet.py
@@ -25,6 +25,7 @@ import os
from ....context import cpu
from ...block import HybridBlock
from ... import nn
+from .... import base
# Net
class AlexNet(HybridBlock):
@@ -68,7 +69,7 @@ class AlexNet(HybridBlock):
# Constructor
def alexnet(pretrained=False, ctx=cpu(),
- root=os.path.join('~', '.mxnet', 'models'), **kwargs):
+ root=os.path.join(base.data_dir(), 'models'), **kwargs):
r"""AlexNet model from the `"One weird trick..."
<https://arxiv.org/abs/1404.5997>`_ paper.
Parameters
@@ -77,7 +78,7 @@ def alexnet(pretrained=False, ctx=cpu(),
Whether to load the pretrained weights for model.
ctx : Context, default CPU
The context in which to load the pretrained weights.
- root : str, default '~/.mxnet/models'
+ root : str, default $MXNET_HOME/models
Location for keeping the model parameters.
"""
net = AlexNet(**kwargs)
diff --git a/python/mxnet/gluon/model_zoo/vision/densenet.py
b/python/mxnet/gluon/model_zoo/vision/densenet.py
index b03f5ce..83febd3 100644
--- a/python/mxnet/gluon/model_zoo/vision/densenet.py
+++ b/python/mxnet/gluon/model_zoo/vision/densenet.py
@@ -26,6 +26,7 @@ from ....context import cpu
from ...block import HybridBlock
from ... import nn
from ...contrib.nn import HybridConcurrent, Identity
+from .... import base
# Helpers
def _make_dense_block(num_layers, bn_size, growth_rate, dropout, stage_index):
@@ -122,7 +123,7 @@ densenet_spec = {121: (64, 32, [6, 12, 24, 16]),
# Constructor
def get_densenet(num_layers, pretrained=False, ctx=cpu(),
- root=os.path.join('~', '.mxnet', 'models'), **kwargs):
+ root=os.path.join(base.data_dir(), 'models'), **kwargs):
r"""Densenet-BC model from the
`"Densely Connected Convolutional Networks"
<https://arxiv.org/pdf/1608.06993.pdf>`_ paper.
@@ -134,7 +135,7 @@ def get_densenet(num_layers, pretrained=False, ctx=cpu(),
Whether to load the pretrained weights for model.
ctx : Context, default CPU
The context in which to load the pretrained weights.
- root : str, default '~/.mxnet/models'
+ root : str, default $MXNET_HOME/models
Location for keeping the model parameters.
"""
num_init_features, growth_rate, block_config = densenet_spec[num_layers]
@@ -154,7 +155,7 @@ def densenet121(**kwargs):
Whether to load the pretrained weights for model.
ctx : Context, default CPU
The context in which to load the pretrained weights.
- root : str, default '~/.mxnet/models'
+ root : str, default '$MXNET_HOME/models'
Location for keeping the model parameters.
"""
return get_densenet(121, **kwargs)
@@ -169,7 +170,7 @@ def densenet161(**kwargs):
Whether to load the pretrained weights for model.
ctx : Context, default CPU
The context in which to load the pretrained weights.
- root : str, default '~/.mxnet/models'
+ root : str, default '$MXNET_HOME/models'
Location for keeping the model parameters.
"""
return get_densenet(161, **kwargs)
@@ -184,7 +185,7 @@ def densenet169(**kwargs):
Whether to load the pretrained weights for model.
ctx : Context, default CPU
The context in which to load the pretrained weights.
- root : str, default '~/.mxnet/models'
+ root : str, default '$MXNET_HOME/models'
Location for keeping the model parameters.
"""
return get_densenet(169, **kwargs)
@@ -199,7 +200,7 @@ def densenet201(**kwargs):
Whether to load the pretrained weights for model.
ctx : Context, default CPU
The context in which to load the pretrained weights.
- root : str, default '~/.mxnet/models'
+ root : str, default '$MXNET_HOME/models'
Location for keeping the model parameters.
"""
return get_densenet(201, **kwargs)
diff --git a/python/mxnet/gluon/model_zoo/vision/inception.py
b/python/mxnet/gluon/model_zoo/vision/inception.py
index 7c54691..6bdc526 100644
--- a/python/mxnet/gluon/model_zoo/vision/inception.py
+++ b/python/mxnet/gluon/model_zoo/vision/inception.py
@@ -26,6 +26,7 @@ from ....context import cpu
from ...block import HybridBlock
from ... import nn
from ...contrib.nn import HybridConcurrent
+from .... import base
# Helpers
def _make_basic_conv(**kwargs):
@@ -199,7 +200,7 @@ class Inception3(HybridBlock):
# Constructor
def inception_v3(pretrained=False, ctx=cpu(),
- root=os.path.join('~', '.mxnet', 'models'), **kwargs):
+ root=os.path.join(base.data_dir(), 'models'), **kwargs):
r"""Inception v3 model from
`"Rethinking the Inception Architecture for Computer Vision"
<http://arxiv.org/abs/1512.00567>`_ paper.
@@ -210,7 +211,7 @@ def inception_v3(pretrained=False, ctx=cpu(),
Whether to load the pretrained weights for model.
ctx : Context, default CPU
The context in which to load the pretrained weights.
- root : str, default '~/.mxnet/models'
+ root : str, default $MXNET_HOME/models
Location for keeping the model parameters.
"""
net = Inception3(**kwargs)
diff --git a/python/mxnet/gluon/model_zoo/vision/mobilenet.py
b/python/mxnet/gluon/model_zoo/vision/mobilenet.py
index 1a2c9b9..1a84e05 100644
--- a/python/mxnet/gluon/model_zoo/vision/mobilenet.py
+++ b/python/mxnet/gluon/model_zoo/vision/mobilenet.py
@@ -30,6 +30,7 @@ import os
from ... import nn
from ....context import cpu
from ...block import HybridBlock
+from .... import base
# Helpers
@@ -188,7 +189,7 @@ class MobileNetV2(nn.HybridBlock):
# Constructor
def get_mobilenet(multiplier, pretrained=False, ctx=cpu(),
- root=os.path.join('~', '.mxnet', 'models'), **kwargs):
+ root=os.path.join(base.data_dir(), 'models'), **kwargs):
r"""MobileNet model from the
`"MobileNets: Efficient Convolutional Neural Networks for Mobile Vision
Applications"
<https://arxiv.org/abs/1704.04861>`_ paper.
@@ -203,7 +204,7 @@ def get_mobilenet(multiplier, pretrained=False, ctx=cpu(),
Whether to load the pretrained weights for model.
ctx : Context, default CPU
The context in which to load the pretrained weights.
- root : str, default '~/.mxnet/models'
+ root : str, default $MXNET_HOME/models
Location for keeping the model parameters.
"""
net = MobileNet(multiplier, **kwargs)
@@ -219,7 +220,7 @@ def get_mobilenet(multiplier, pretrained=False, ctx=cpu(),
def get_mobilenet_v2(multiplier, pretrained=False, ctx=cpu(),
- root=os.path.join('~', '.mxnet', 'models'), **kwargs):
+ root=os.path.join(base.data_dir(), 'models'), **kwargs):
r"""MobileNetV2 model from the
`"Inverted Residuals and Linear Bottlenecks:
Mobile Networks for Classification, Detection and Segmentation"
@@ -235,7 +236,7 @@ def get_mobilenet_v2(multiplier, pretrained=False,
ctx=cpu(),
Whether to load the pretrained weights for model.
ctx : Context, default CPU
The context in which to load the pretrained weights.
- root : str, default '~/.mxnet/models'
+ root : str, default $MXNET_HOME/models
Location for keeping the model parameters.
"""
net = MobileNetV2(multiplier, **kwargs)
diff --git a/python/mxnet/gluon/model_zoo/vision/resnet.py
b/python/mxnet/gluon/model_zoo/vision/resnet.py
index da279b8..48390de 100644
--- a/python/mxnet/gluon/model_zoo/vision/resnet.py
+++ b/python/mxnet/gluon/model_zoo/vision/resnet.py
@@ -32,6 +32,7 @@ import os
from ....context import cpu
from ...block import HybridBlock
from ... import nn
+from .... import base
# Helpers
def _conv3x3(channels, stride, in_channels):
@@ -356,7 +357,7 @@ resnet_block_versions = [{'basic_block': BasicBlockV1,
'bottle_neck': Bottleneck
# Constructor
def get_resnet(version, num_layers, pretrained=False, ctx=cpu(),
- root=os.path.join('~', '.mxnet', 'models'), **kwargs):
+ root=os.path.join(base.data_dir(), 'models'), **kwargs):
r"""ResNet V1 model from `"Deep Residual Learning for Image Recognition"
<http://arxiv.org/abs/1512.03385>`_ paper.
ResNet V2 model from `"Identity Mappings in Deep Residual Networks"
@@ -372,7 +373,7 @@ def get_resnet(version, num_layers, pretrained=False,
ctx=cpu(),
Whether to load the pretrained weights for model.
ctx : Context, default CPU
The context in which to load the pretrained weights.
- root : str, default '~/.mxnet/models'
+ root : str, default $MXNET_HOME/models
Location for keeping the model parameters.
"""
assert num_layers in resnet_spec, \
@@ -400,7 +401,7 @@ def resnet18_v1(**kwargs):
Whether to load the pretrained weights for model.
ctx : Context, default CPU
The context in which to load the pretrained weights.
- root : str, default '~/.mxnet/models'
+ root : str, default '$MXNET_HOME/models'
Location for keeping the model parameters.
"""
return get_resnet(1, 18, **kwargs)
@@ -415,7 +416,7 @@ def resnet34_v1(**kwargs):
Whether to load the pretrained weights for model.
ctx : Context, default CPU
The context in which to load the pretrained weights.
- root : str, default '~/.mxnet/models'
+ root : str, default '$MXNET_HOME/models'
Location for keeping the model parameters.
"""
return get_resnet(1, 34, **kwargs)
@@ -430,7 +431,7 @@ def resnet50_v1(**kwargs):
Whether to load the pretrained weights for model.
ctx : Context, default CPU
The context in which to load the pretrained weights.
- root : str, default '~/.mxnet/models'
+ root : str, default '$MXNET_HOME/models'
Location for keeping the model parameters.
"""
return get_resnet(1, 50, **kwargs)
@@ -445,7 +446,7 @@ def resnet101_v1(**kwargs):
Whether to load the pretrained weights for model.
ctx : Context, default CPU
The context in which to load the pretrained weights.
- root : str, default '~/.mxnet/models'
+ root : str, default '$MXNET_HOME/models'
Location for keeping the model parameters.
"""
return get_resnet(1, 101, **kwargs)
@@ -460,7 +461,7 @@ def resnet152_v1(**kwargs):
Whether to load the pretrained weights for model.
ctx : Context, default CPU
The context in which to load the pretrained weights.
- root : str, default '~/.mxnet/models'
+ root : str, default '$MXNET_HOME/models'
Location for keeping the model parameters.
"""
return get_resnet(1, 152, **kwargs)
@@ -475,7 +476,7 @@ def resnet18_v2(**kwargs):
Whether to load the pretrained weights for model.
ctx : Context, default CPU
The context in which to load the pretrained weights.
- root : str, default '~/.mxnet/models'
+ root : str, default '$MXNET_HOME/models'
Location for keeping the model parameters.
"""
return get_resnet(2, 18, **kwargs)
@@ -490,7 +491,7 @@ def resnet34_v2(**kwargs):
Whether to load the pretrained weights for model.
ctx : Context, default CPU
The context in which to load the pretrained weights.
- root : str, default '~/.mxnet/models'
+ root : str, default '$MXNET_HOME/models'
Location for keeping the model parameters.
"""
return get_resnet(2, 34, **kwargs)
@@ -505,7 +506,7 @@ def resnet50_v2(**kwargs):
Whether to load the pretrained weights for model.
ctx : Context, default CPU
The context in which to load the pretrained weights.
- root : str, default '~/.mxnet/models'
+ root : str, default '$MXNET_HOME/models'
Location for keeping the model parameters.
"""
return get_resnet(2, 50, **kwargs)
@@ -520,7 +521,7 @@ def resnet101_v2(**kwargs):
Whether to load the pretrained weights for model.
ctx : Context, default CPU
The context in which to load the pretrained weights.
- root : str, default '~/.mxnet/models'
+ root : str, default '$MXNET_HOME/models'
Location for keeping the model parameters.
"""
return get_resnet(2, 101, **kwargs)
@@ -535,7 +536,7 @@ def resnet152_v2(**kwargs):
Whether to load the pretrained weights for model.
ctx : Context, default CPU
The context in which to load the pretrained weights.
- root : str, default '~/.mxnet/models'
+ root : str, default '$MXNET_HOME/models'
Location for keeping the model parameters.
"""
return get_resnet(2, 152, **kwargs)
diff --git a/python/mxnet/gluon/model_zoo/vision/squeezenet.py
b/python/mxnet/gluon/model_zoo/vision/squeezenet.py
index aaff4c3..b97d127 100644
--- a/python/mxnet/gluon/model_zoo/vision/squeezenet.py
+++ b/python/mxnet/gluon/model_zoo/vision/squeezenet.py
@@ -26,6 +26,7 @@ from ....context import cpu
from ...block import HybridBlock
from ... import nn
from ...contrib.nn import HybridConcurrent
+from .... import base
# Helpers
def _make_fire(squeeze_channels, expand1x1_channels, expand3x3_channels):
@@ -110,7 +111,7 @@ class SqueezeNet(HybridBlock):
# Constructor
def get_squeezenet(version, pretrained=False, ctx=cpu(),
- root=os.path.join('~', '.mxnet', 'models'), **kwargs):
+ root=os.path.join(base.data_dir(), 'models'), **kwargs):
r"""SqueezeNet model from the `"SqueezeNet: AlexNet-level accuracy with
50x fewer parameters
and <0.5MB model size" <https://arxiv.org/abs/1602.07360>`_ paper.
SqueezeNet 1.1 model from the `official SqueezeNet repo
@@ -126,7 +127,7 @@ def get_squeezenet(version, pretrained=False, ctx=cpu(),
Whether to load the pretrained weights for model.
ctx : Context, default CPU
The context in which to load the pretrained weights.
- root : str, default '~/.mxnet/models'
+ root : str, default $MXNET_HOME/models
Location for keeping the model parameters.
"""
net = SqueezeNet(version, **kwargs)
@@ -145,7 +146,7 @@ def squeezenet1_0(**kwargs):
Whether to load the pretrained weights for model.
ctx : Context, default CPU
The context in which to load the pretrained weights.
- root : str, default '~/.mxnet/models'
+ root : str, default '$MXNET_HOME/models'
Location for keeping the model parameters.
"""
return get_squeezenet('1.0', **kwargs)
@@ -162,7 +163,7 @@ def squeezenet1_1(**kwargs):
Whether to load the pretrained weights for model.
ctx : Context, default CPU
The context in which to load the pretrained weights.
- root : str, default '~/.mxnet/models'
+ root : str, default '$MXNET_HOME/models'
Location for keeping the model parameters.
"""
return get_squeezenet('1.1', **kwargs)
diff --git a/python/mxnet/gluon/model_zoo/vision/vgg.py
b/python/mxnet/gluon/model_zoo/vision/vgg.py
index a3b1685..9a740e6 100644
--- a/python/mxnet/gluon/model_zoo/vision/vgg.py
+++ b/python/mxnet/gluon/model_zoo/vision/vgg.py
@@ -30,6 +30,7 @@ from ....context import cpu
from ....initializer import Xavier
from ...block import HybridBlock
from ... import nn
+from .... import base
class VGG(HybridBlock):
@@ -94,7 +95,7 @@ vgg_spec = {11: ([1, 1, 2, 2, 2], [64, 128, 256, 512, 512]),
# Constructors
def get_vgg(num_layers, pretrained=False, ctx=cpu(),
- root=os.path.join('~', '.mxnet', 'models'), **kwargs):
+ root=os.path.join(base.data_dir(), 'models'), **kwargs):
r"""VGG model from the `"Very Deep Convolutional Networks for Large-Scale
Image Recognition"
<https://arxiv.org/abs/1409.1556>`_ paper.
@@ -106,7 +107,7 @@ def get_vgg(num_layers, pretrained=False, ctx=cpu(),
Whether to load the pretrained weights for model.
ctx : Context, default CPU
The context in which to load the pretrained weights.
- root : str, default '~/.mxnet/models'
+ root : str, default $MXNET_HOME/models
Location for keeping the model parameters.
"""
layers, filters = vgg_spec[num_layers]
@@ -128,7 +129,7 @@ def vgg11(**kwargs):
Whether to load the pretrained weights for model.
ctx : Context, default CPU
The context in which to load the pretrained weights.
- root : str, default '~/.mxnet/models'
+ root : str, default '$MXNET_HOME/models'
Location for keeping the model parameters.
"""
return get_vgg(11, **kwargs)
@@ -143,7 +144,7 @@ def vgg13(**kwargs):
Whether to load the pretrained weights for model.
ctx : Context, default CPU
The context in which to load the pretrained weights.
- root : str, default '~/.mxnet/models'
+ root : str, default '$MXNET_HOME/models'
Location for keeping the model parameters.
"""
return get_vgg(13, **kwargs)
@@ -158,7 +159,7 @@ def vgg16(**kwargs):
Whether to load the pretrained weights for model.
ctx : Context, default CPU
The context in which to load the pretrained weights.
- root : str, default '~/.mxnet/models'
+ root : str, default '$MXNET_HOME/models'
Location for keeping the model parameters.
"""
return get_vgg(16, **kwargs)
@@ -173,7 +174,7 @@ def vgg19(**kwargs):
Whether to load the pretrained weights for model.
ctx : Context, default CPU
The context in which to load the pretrained weights.
- root : str, default '~/.mxnet/models'
+ root : str, default '$MXNET_HOME/models'
Location for keeping the model parameters.
"""
return get_vgg(19, **kwargs)
@@ -189,7 +190,7 @@ def vgg11_bn(**kwargs):
Whether to load the pretrained weights for model.
ctx : Context, default CPU
The context in which to load the pretrained weights.
- root : str, default '~/.mxnet/models'
+ root : str, default '$MXNET_HOME/models'
Location for keeping the model parameters.
"""
kwargs['batch_norm'] = True
@@ -206,7 +207,7 @@ def vgg13_bn(**kwargs):
Whether to load the pretrained weights for model.
ctx : Context, default CPU
The context in which to load the pretrained weights.
- root : str, default '~/.mxnet/models'
+ root : str, default '$MXNET_HOME/models'
Location for keeping the model parameters.
"""
kwargs['batch_norm'] = True
@@ -223,7 +224,7 @@ def vgg16_bn(**kwargs):
Whether to load the pretrained weights for model.
ctx : Context, default CPU
The context in which to load the pretrained weights.
- root : str, default '~/.mxnet/models'
+ root : str, default '$MXNET_HOME/models'
Location for keeping the model parameters.
"""
kwargs['batch_norm'] = True
@@ -240,7 +241,7 @@ def vgg19_bn(**kwargs):
Whether to load the pretrained weights for model.
ctx : Context, default CPU
The context in which to load the pretrained weights.
- root : str, default '~/.mxnet/models'
+ root : str, default '$MXNET_HOME/models'
Location for keeping the model parameters.
"""
kwargs['batch_norm'] = True
diff --git a/scala-package/core/scripts/get_mnist_data.sh b/python/mxnet/util.py
old mode 100755
new mode 100644
similarity index 69%
copy from scala-package/core/scripts/get_mnist_data.sh
copy to python/mxnet/util.py
index 97e151b..57bc2bf
--- a/scala-package/core/scripts/get_mnist_data.sh
+++ b/python/mxnet/util.py
@@ -1,5 +1,3 @@
-#!/bin/bash
-
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
@@ -16,23 +14,17 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+"""general utility functions"""
+import os
+import sys
-set -e
-
-if [ ! -z "$MXNET_DATA_DIR" ]; then
- data_path="$MXNET_DATA_DIR"
-else
- data_path="./data"
-fi
-
-if [ ! -d "$data_path" ]; then
- mkdir -p "$data_path"
-fi
-mnist_data_path="$data_path/mnist.zip"
-if [ ! -f "$mnist_data_path" ]; then
- wget http://data.mxnet.io/mxnet/data/mnist.zip -P $data_path
- cd $data_path
- unzip -u mnist.zip
-fi
+def makedirs(d):
+ """Create directories recursively if they don't exist. os.makedirs(exist_ok=True) is not
+ available in Python2"""
+ if sys.version_info[0] < 3:
+ from distutils.dir_util import mkpath
+ mkpath(d)
+ else:
+ os.makedirs(d, exist_ok=True)
diff --git a/scala-package/core/scripts/get_cifar_data.sh b/scala-package/core/scripts/get_cifar_data.sh
index 9ec1c39..b061c18 100755
--- a/scala-package/core/scripts/get_cifar_data.sh
+++ b/scala-package/core/scripts/get_cifar_data.sh
@@ -20,8 +20,8 @@
set -e
-if [ ! -z "$MXNET_DATA_DIR" ]; then
- data_path="$MXNET_DATA_DIR"
+if [ ! -z "$MXNET_HOME" ]; then
+ data_path="$MXNET_HOME"
else
data_path="./data"
fi
diff --git a/scala-package/core/scripts/get_mnist_data.sh b/scala-package/core/scripts/get_mnist_data.sh
index 97e151b..ded206f 100755
--- a/scala-package/core/scripts/get_mnist_data.sh
+++ b/scala-package/core/scripts/get_mnist_data.sh
@@ -20,8 +20,8 @@
set -e
-if [ ! -z "$MXNET_DATA_DIR" ]; then
- data_path="$MXNET_DATA_DIR"
+if [ ! -z "$MXNET_HOME" ]; then
+ data_path="$MXNET_HOME"
else
data_path="./data"
fi
diff --git a/scala-package/core/src/test/scala/org/apache/mxnet/TestUtil.scala b/scala-package/core/src/test/scala/org/apache/mxnet/TestUtil.scala
index 1187757..4fc8ec9 100644
--- a/scala-package/core/src/test/scala/org/apache/mxnet/TestUtil.scala
+++ b/scala-package/core/src/test/scala/org/apache/mxnet/TestUtil.scala
@@ -24,7 +24,7 @@ class TestUtil {
* @return Data direcotry path ()may be relative)
*/
def getDataDirectory: String = {
- var dataDir = System.getenv("MXNET_DATA_DIR")
+ var dataDir = System.getenv("MXNET_HOME")
if(dataDir == null) {
dataDir = "data"
} else {
diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/gan/GanMnist.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/gan/GanMnist.scala
index 6186989..70846ee 100644
--- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/gan/GanMnist.scala
+++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/gan/GanMnist.scala
@@ -181,7 +181,7 @@ object GanMnist {
try {
parser.parseArgument(args.toList.asJava)
- val dataPath = if (anst.mnistDataPath == null) System.getenv("MXNET_DATA_DIR")
+ val dataPath = if (anst.mnistDataPath == null) System.getenv("MXNET_HOME")
else anst.mnistDataPath
assert(dataPath != null)
diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/TrainMnist.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/TrainMnist.scala
index b0ecc7d..bd0ce45 100644
--- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/TrainMnist.scala
+++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/TrainMnist.scala
@@ -112,7 +112,7 @@ object TrainMnist {
try {
parser.parseArgument(args.toList.asJava)
- val dataPath = if (inst.dataDir == null) System.getenv("MXNET_DATA_DIR")
+ val dataPath = if (inst.dataDir == null) System.getenv("MXNET_HOME")
else inst.dataDir
val (dataShape, net) =
diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/infer/imageclassifier/ImageClassifierExample.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/infer/imageclassifier/ImageClassifierExample.scala
index e886b90..3bbd780 100644
--- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/infer/imageclassifier/ImageClassifierExample.scala
+++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/infer/imageclassifier/ImageClassifierExample.scala
@@ -119,13 +119,13 @@ object ImageClassifierExample {
parser.parseArgument(args.toList.asJava)
- val modelPathPrefix = if (inst.modelPathPrefix == null) System.getenv("MXNET_DATA_DIR")
+ val modelPathPrefix = if (inst.modelPathPrefix == null) System.getenv("MXNET_HOME")
else inst.modelPathPrefix
- val inputImagePath = if (inst.inputImagePath == null) System.getenv("MXNET_DATA_DIR")
+ val inputImagePath = if (inst.inputImagePath == null) System.getenv("MXNET_HOME")
else inst.inputImagePath
- val inputImageDir = if (inst.inputImageDir == null) System.getenv("MXNET_DATA_DIR")
+ val inputImageDir = if (inst.inputImageDir == null) System.getenv("MXNET_HOME")
else inst.inputImageDir
val singleOutput = runInferenceOnSingleImage(modelPathPrefix, inputImagePath, context)
diff --git a/tests/python/unittest/test_base.py b/tests/python/unittest/test_base.py
new file mode 100644
index 0000000..3189729
--- /dev/null
+++ b/tests/python/unittest/test_base.py
@@ -0,0 +1,50 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import mxnet as mx
+from mxnet.base import data_dir
+from nose.tools import *
+import os
+import unittest
+import logging
+import os.path as op
+import platform
+
+class MXNetDataDirTest(unittest.TestCase):
+ def setUp(self):
+ self.mxnet_data_dir = os.environ.get('MXNET_HOME')
+ if 'MXNET_HOME' in os.environ:
+ del os.environ['MXNET_HOME']
+
+ def tearDown(self):
+ if self.mxnet_data_dir:
+ os.environ['MXNET_HOME'] = self.mxnet_data_dir
+ else:
+ if 'MXNET_HOME' in os.environ:
+ del os.environ['MXNET_HOME']
+
+ def test_data_dir(self,):
+ prev_data_dir = data_dir()
+ system = platform.system()
+ if system != 'Windows':
+ self.assertEqual(data_dir(), op.join(op.expanduser('~'), '.mxnet'))
+ os.environ['MXNET_HOME'] = '/tmp/mxnet_data'
+ self.assertEqual(data_dir(), '/tmp/mxnet_data')
+ del os.environ['MXNET_HOME']
+ self.assertEqual(data_dir(), prev_data_dir)
+
+