marcoabreu closed pull request #11636: [MXNET-769] set MXNET_HOME as base for downloaded models through base.data_dir()
URL: https://github.com/apache/incubator-mxnet/pull/11636
 
 
   

This PR was merged from a forked repository.
Because GitHub hides the original diff on merge, it is reproduced below for
the sake of provenance:

Because this is a foreign pull request (from a fork), the diff is supplied
below; GitHub would not otherwise display it:

diff --git a/ci/docker_cache.py b/ci/docker_cache.py
index 6637ec37716..7a6d1106d38 100755
--- a/ci/docker_cache.py
+++ b/ci/docker_cache.py
@@ -31,7 +31,6 @@
 import subprocess
 import json
 import build as build_util
-from joblib import Parallel, delayed
 
 
 
@@ -43,6 +42,7 @@ def build_save_containers(platforms, registry, load_cache) -> 
int:
     :param load_cache: Load cache before building
     :return: 1 if error occurred, 0 otherwise
     """
+    from joblib import Parallel, delayed
     if len(platforms) == 0:
         return 0
 
diff --git a/contrib/clojure-package/examples/scripts/get_cifar_data.sh 
b/contrib/clojure-package/examples/scripts/get_cifar_data.sh
index 372c7bb5781..12b3770c270 100755
--- a/contrib/clojure-package/examples/scripts/get_cifar_data.sh
+++ b/contrib/clojure-package/examples/scripts/get_cifar_data.sh
@@ -20,8 +20,8 @@
 
 set -evx
 
-if [ ! -z "$MXNET_DATA_DIR" ]; then
-  data_path="$MXNET_DATA_DIR"
+if [ ! -z "$MXNET_HOME" ]; then
+  data_path="$MXNET_HOME"
 else
   data_path="./data"
 fi
diff --git a/contrib/clojure-package/examples/scripts/get_mnist_data.sh 
b/contrib/clojure-package/examples/scripts/get_mnist_data.sh
index 6f32b85f480..703ece207a1 100755
--- a/contrib/clojure-package/examples/scripts/get_mnist_data.sh
+++ b/contrib/clojure-package/examples/scripts/get_mnist_data.sh
@@ -20,8 +20,8 @@
 
 set -evx
 
-if [ ! -z "$MXNET_DATA_DIR" ]; then
-  data_path="$MXNET_DATA_DIR"
+if [ ! -z "$MXNET_HOME" ]; then
+  data_path="$MXNET_HOME"
 else
   data_path="./data"
 fi
diff --git a/contrib/clojure-package/scripts/get_cifar_data.sh 
b/contrib/clojure-package/scripts/get_cifar_data.sh
index 372c7bb5781..12b3770c270 100755
--- a/contrib/clojure-package/scripts/get_cifar_data.sh
+++ b/contrib/clojure-package/scripts/get_cifar_data.sh
@@ -20,8 +20,8 @@
 
 set -evx
 
-if [ ! -z "$MXNET_DATA_DIR" ]; then
-  data_path="$MXNET_DATA_DIR"
+if [ ! -z "$MXNET_HOME" ]; then
+  data_path="$MXNET_HOME"
 else
   data_path="./data"
 fi
diff --git a/contrib/clojure-package/scripts/get_mnist_data.sh 
b/contrib/clojure-package/scripts/get_mnist_data.sh
index 6f32b85f480..703ece207a1 100755
--- a/contrib/clojure-package/scripts/get_mnist_data.sh
+++ b/contrib/clojure-package/scripts/get_mnist_data.sh
@@ -20,8 +20,8 @@
 
 set -evx
 
-if [ ! -z "$MXNET_DATA_DIR" ]; then
-  data_path="$MXNET_DATA_DIR"
+if [ ! -z "$MXNET_HOME" ]; then
+  data_path="$MXNET_HOME"
 else
   data_path="./data"
 fi
diff --git a/docs/faq/env_var.md b/docs/faq/env_var.md
index 881bc14fdc8..6e9a3594168 100644
--- a/docs/faq/env_var.md
+++ b/docs/faq/env_var.md
@@ -152,6 +152,10 @@ When USE_PROFILER is enabled in Makefile or CMake, the 
following environments ca
   - Values: String 
```(default='https://apache-mxnet.s3-accelerate.dualstack.amazonaws.com/'```
   - The repository url to be used for Gluon datasets and pre-trained models.
 
+* MXNET_HOME
+  - Data directory in the filesystem for storage, for example when downloading Gluon models.
+  - Default is `~/.mxnet` on *nix and `%APPDATA%\mxnet` on Windows.
+
 Settings for Minimum Memory Usage
 ---------------------------------
 - Make sure ```min(MXNET_EXEC_NUM_TEMP, MXNET_GPU_WORKER_NTHREADS) = 1```
diff --git a/python/mxnet/base.py b/python/mxnet/base.py
index 4df794bdfe3..3d8ee019175 100644
--- a/python/mxnet/base.py
+++ b/python/mxnet/base.py
@@ -22,11 +22,11 @@
 
 import atexit
 import ctypes
-import inspect
 import os
 import sys
 import warnings
-
+import inspect
+import platform
 import numpy as np
 
 from . import libinfo
@@ -59,6 +59,26 @@
     py_str = lambda x: x
 
 
+def data_dir_default():
+    """
+
+    :return: default data directory depending on the platform and environment 
variables
+    """
+    system = platform.system()
+    if system == 'Windows':
+        return os.path.join(os.environ.get('APPDATA'), 'mxnet')
+    else:
+        return os.path.join(os.path.expanduser("~"), '.mxnet')
+
+
+def data_dir():
+    """
+
+    :return: data directory in the filesystem for storage, for example when 
downloading models
+    """
+    return os.getenv('MXNET_HOME', data_dir_default())
+
+
 class _NullType(object):
     """Placeholder for arguments"""
     def __repr__(self):
diff --git a/python/mxnet/contrib/text/embedding.py 
b/python/mxnet/contrib/text/embedding.py
index 6598718e6b0..38defb4b90b 100644
--- a/python/mxnet/contrib/text/embedding.py
+++ b/python/mxnet/contrib/text/embedding.py
@@ -34,6 +34,7 @@
 from . import vocab
 from ... import ndarray as nd
 from ... import registry
+from ... import base
 
 
 def register(embedding_cls):
@@ -496,7 +497,7 @@ class GloVe(_TokenEmbedding):
     ----------
     pretrained_file_name : str, default 'glove.840B.300d.txt'
         The name of the pre-trained token embedding file.
-    embedding_root : str, default os.path.join('~', '.mxnet', 'embeddings')
+    embedding_root : str, default $MXNET_HOME/embeddings
         The root directory for storing embedding-related files.
     init_unknown_vec : callback
         The callback used to initialize the embedding vector for the unknown 
token.
@@ -541,7 +542,7 @@ def _get_download_file_name(cls, pretrained_file_name):
         return archive
 
     def __init__(self, pretrained_file_name='glove.840B.300d.txt',
-                 embedding_root=os.path.join('~', '.mxnet', 'embeddings'),
+                 embedding_root=os.path.join(base.data_dir(), 'embeddings'),
                  init_unknown_vec=nd.zeros, vocabulary=None, **kwargs):
         GloVe._check_pretrained_file_names(pretrained_file_name)
 
@@ -600,7 +601,7 @@ class FastText(_TokenEmbedding):
     ----------
     pretrained_file_name : str, default 'wiki.en.vec'
         The name of the pre-trained token embedding file.
-    embedding_root : str, default os.path.join('~', '.mxnet', 'embeddings')
+    embedding_root : str, default $MXNET_HOME/embeddings
         The root directory for storing embedding-related files.
     init_unknown_vec : callback
         The callback used to initialize the embedding vector for the unknown 
token.
@@ -642,7 +643,7 @@ def _get_download_file_name(cls, pretrained_file_name):
         return '.'.join(pretrained_file_name.split('.')[:-1])+'.zip'
 
     def __init__(self, pretrained_file_name='wiki.simple.vec',
-                 embedding_root=os.path.join('~', '.mxnet', 'embeddings'),
+                 embedding_root=os.path.join(base.data_dir(), 'embeddings'),
                  init_unknown_vec=nd.zeros, vocabulary=None, **kwargs):
         FastText._check_pretrained_file_names(pretrained_file_name)
 
diff --git a/python/mxnet/gluon/contrib/data/text.py 
b/python/mxnet/gluon/contrib/data/text.py
index 98fe6b657f2..9e78e3c2e23 100644
--- a/python/mxnet/gluon/contrib/data/text.py
+++ b/python/mxnet/gluon/contrib/data/text.py
@@ -30,8 +30,7 @@
 from ...data import dataset
 from ...utils import download, check_sha1, _get_repo_file_url
 from ....contrib import text
-from .... import nd
-
+from .... import nd, base
 
 class _LanguageModelDataset(dataset._DownloadedDataset): # pylint: 
disable=abstract-method
     def __init__(self, root, namespace, vocabulary):
@@ -116,7 +115,7 @@ class WikiText2(_WikiText):
 
     Parameters
     ----------
-    root : str, default '~/.mxnet/datasets/wikitext-2'
+    root : str, default $MXNET_HOME/datasets/wikitext-2
         Path to temp folder for storing data.
     segment : str, default 'train'
         Dataset segment. Options are 'train', 'validation', 'test'.
@@ -127,7 +126,7 @@ class WikiText2(_WikiText):
         The sequence length of each sample, regardless of the sentence 
boundary.
 
     """
-    def __init__(self, root=os.path.join('~', '.mxnet', 'datasets', 
'wikitext-2'),
+    def __init__(self, root=os.path.join(base.data_dir(), 'datasets', 
'wikitext-2'),
                  segment='train', vocab=None, seq_len=35):
         self._archive_file = ('wikitext-2-v1.zip', 
'3c914d17d80b1459be871a5039ac23e752a53cbe')
         self._data_file = {'train': ('wiki.train.tokens',
@@ -154,7 +153,7 @@ class WikiText103(_WikiText):
 
     Parameters
     ----------
-    root : str, default '~/.mxnet/datasets/wikitext-103'
+    root : str, default $MXNET_HOME/datasets/wikitext-103
         Path to temp folder for storing data.
     segment : str, default 'train'
         Dataset segment. Options are 'train', 'validation', 'test'.
@@ -164,7 +163,7 @@ class WikiText103(_WikiText):
     seq_len : int, default 35
         The sequence length of each sample, regardless of the sentence 
boundary.
     """
-    def __init__(self, root=os.path.join('~', '.mxnet', 'datasets', 
'wikitext-103'),
+    def __init__(self, root=os.path.join(base.data_dir(), 'datasets', 
'wikitext-103'),
                  segment='train', vocab=None, seq_len=35):
         self._archive_file = ('wikitext-103-v1.zip', 
'0aec09a7537b58d4bb65362fee27650eeaba625a')
         self._data_file = {'train': ('wiki.train.tokens',
diff --git a/python/mxnet/gluon/data/vision/datasets.py 
b/python/mxnet/gluon/data/vision/datasets.py
index 74a5aebf17b..2c98000389a 100644
--- a/python/mxnet/gluon/data/vision/datasets.py
+++ b/python/mxnet/gluon/data/vision/datasets.py
@@ -30,7 +30,7 @@
 
 from .. import dataset
 from ...utils import download, check_sha1, _get_repo_file_url
-from .... import nd, image, recordio
+from .... import nd, image, recordio, base
 
 
 class MNIST(dataset._DownloadedDataset):
@@ -40,7 +40,7 @@ class MNIST(dataset._DownloadedDataset):
 
     Parameters
     ----------
-    root : str, default '~/.mxnet/datasets/mnist'
+    root : str, default $MXNET_HOME/datasets/mnist
         Path to temp folder for storing data.
     train : bool, default True
         Whether to load the training or testing set.
@@ -51,7 +51,7 @@ class MNIST(dataset._DownloadedDataset):
         transform=lambda data, label: (data.astype(np.float32)/255, label)
 
     """
-    def __init__(self, root=os.path.join('~', '.mxnet', 'datasets', 'mnist'),
+    def __init__(self, root=os.path.join(base.data_dir(), 'datasets', 'mnist'),
                  train=True, transform=None):
         self._train = train
         self._train_data = ('train-images-idx3-ubyte.gz',
@@ -101,7 +101,7 @@ class FashionMNIST(MNIST):
 
     Parameters
     ----------
-    root : str, default '~/.mxnet/datasets/fashion-mnist'
+    root : str, default $MXNET_HOME/datasets/fashion-mnist
         Path to temp folder for storing data.
     train : bool, default True
         Whether to load the training or testing set.
@@ -112,7 +112,7 @@ class FashionMNIST(MNIST):
         transform=lambda data, label: (data.astype(np.float32)/255, label)
 
     """
-    def __init__(self, root=os.path.join('~', '.mxnet', 'datasets', 
'fashion-mnist'),
+    def __init__(self, root=os.path.join(base.data_dir(), 'datasets', 
'fashion-mnist'),
                  train=True, transform=None):
         self._train = train
         self._train_data = ('train-images-idx3-ubyte.gz',
@@ -134,7 +134,7 @@ class CIFAR10(dataset._DownloadedDataset):
 
     Parameters
     ----------
-    root : str, default '~/.mxnet/datasets/cifar10'
+    root : str, default $MXNET_HOME/datasets/cifar10
         Path to temp folder for storing data.
     train : bool, default True
         Whether to load the training or testing set.
@@ -145,7 +145,7 @@ class CIFAR10(dataset._DownloadedDataset):
         transform=lambda data, label: (data.astype(np.float32)/255, label)
 
     """
-    def __init__(self, root=os.path.join('~', '.mxnet', 'datasets', 'cifar10'),
+    def __init__(self, root=os.path.join(base.data_dir(), 'datasets', 
'cifar10'),
                  train=True, transform=None):
         self._train = train
         self._archive_file = ('cifar-10-binary.tar.gz', 
'fab780a1e191a7eda0f345501ccd62d20f7ed891')
@@ -197,7 +197,7 @@ class CIFAR100(CIFAR10):
 
     Parameters
     ----------
-    root : str, default '~/.mxnet/datasets/cifar100'
+    root : str, default $MXNET_HOME/datasets/cifar100
         Path to temp folder for storing data.
     fine_label : bool, default False
         Whether to load the fine-grained (100 classes) or coarse-grained (20 
super-classes) labels.
@@ -210,7 +210,7 @@ class CIFAR100(CIFAR10):
         transform=lambda data, label: (data.astype(np.float32)/255, label)
 
     """
-    def __init__(self, root=os.path.join('~', '.mxnet', 'datasets', 
'cifar100'),
+    def __init__(self, root=os.path.join(base.data_dir(), 'datasets', 
'cifar100'),
                  fine_label=False, train=True, transform=None):
         self._train = train
         self._archive_file = ('cifar-100-binary.tar.gz', 
'a0bb982c76b83111308126cc779a992fa506b90b')
diff --git a/python/mxnet/gluon/model_zoo/model_store.py 
b/python/mxnet/gluon/model_zoo/model_store.py
index 7eead68f0db..11ac47bae90 100644
--- a/python/mxnet/gluon/model_zoo/model_store.py
+++ b/python/mxnet/gluon/model_zoo/model_store.py
@@ -21,8 +21,10 @@
 __all__ = ['get_model_file', 'purge']
 import os
 import zipfile
+import logging
 
 from ..utils import download, check_sha1
+from ... import base, util
 
 _model_sha1 = {name: checksum for checksum, name in [
     ('44335d1f0046b328243b32a26a4fbd62d9057b45', 'alexnet'),
@@ -68,7 +70,7 @@ def short_hash(name):
         raise ValueError('Pretrained model for {name} is not 
available.'.format(name=name))
     return _model_sha1[name][:8]
 
-def get_model_file(name, root=os.path.join('~', '.mxnet', 'models')):
+def get_model_file(name, root=os.path.join(base.data_dir(), 'models')):
     r"""Return location for the pretrained on local file system.
 
     This function will download from online model zoo when model cannot be 
found or has mismatch.
@@ -78,7 +80,7 @@ def get_model_file(name, root=os.path.join('~', '.mxnet', 
'models')):
     ----------
     name : str
         Name of the model.
-    root : str, default '~/.mxnet/models'
+    root : str, default $MXNET_HOME/models
         Location for keeping the model parameters.
 
     Returns
@@ -95,12 +97,11 @@ def get_model_file(name, root=os.path.join('~', '.mxnet', 
'models')):
         if check_sha1(file_path, sha1_hash):
             return file_path
         else:
-            print('Mismatch in the content of model file detected. Downloading 
again.')
+            logging.warning('Mismatch in the content of model file detected. 
Downloading again.')
     else:
-        print('Model file is not found. Downloading.')
+        logging.info('Model file not found. Downloading to %s.', file_path)
 
-    if not os.path.exists(root):
-        os.makedirs(root)
+    util.makedirs(root)
 
     zip_file_path = os.path.join(root, file_name+'.zip')
     repo_url = os.environ.get('MXNET_GLUON_REPO', apache_repo_url)
@@ -118,12 +119,12 @@ def get_model_file(name, root=os.path.join('~', '.mxnet', 
'models')):
     else:
         raise ValueError('Downloaded file has different hash. Please try 
again.')
 
-def purge(root=os.path.join('~', '.mxnet', 'models')):
+def purge(root=os.path.join(base.data_dir(), 'models')):
     r"""Purge all pretrained model files in local file store.
 
     Parameters
     ----------
-    root : str, default '~/.mxnet/models'
+    root : str, default '$MXNET_HOME/models'
         Location for keeping the model parameters.
     """
     root = os.path.expanduser(root)
diff --git a/python/mxnet/gluon/model_zoo/vision/__init__.py 
b/python/mxnet/gluon/model_zoo/vision/__init__.py
index a6e5dc137d4..7d33ce409b2 100644
--- a/python/mxnet/gluon/model_zoo/vision/__init__.py
+++ b/python/mxnet/gluon/model_zoo/vision/__init__.py
@@ -101,7 +101,7 @@ def get_model(name, **kwargs):
         Number of classes for the output layer.
     ctx : Context, default CPU
         The context in which to load the pretrained weights.
-    root : str, default '~/.mxnet/models'
+    root : str, default '$MXNET_HOME/models'
         Location for keeping the model parameters.
 
     Returns
diff --git a/python/mxnet/gluon/model_zoo/vision/alexnet.py 
b/python/mxnet/gluon/model_zoo/vision/alexnet.py
index fdb006258c2..daf4617cd12 100644
--- a/python/mxnet/gluon/model_zoo/vision/alexnet.py
+++ b/python/mxnet/gluon/model_zoo/vision/alexnet.py
@@ -25,6 +25,7 @@
 from ....context import cpu
 from ...block import HybridBlock
 from ... import nn
+from .... import base
 
 # Net
 class AlexNet(HybridBlock):
@@ -68,7 +69,7 @@ def hybrid_forward(self, F, x):
 
 # Constructor
 def alexnet(pretrained=False, ctx=cpu(),
-            root=os.path.join('~', '.mxnet', 'models'), **kwargs):
+            root=os.path.join(base.data_dir(), 'models'), **kwargs):
     r"""AlexNet model from the `"One weird trick..." 
<https://arxiv.org/abs/1404.5997>`_ paper.
 
     Parameters
@@ -77,7 +78,7 @@ def alexnet(pretrained=False, ctx=cpu(),
         Whether to load the pretrained weights for model.
     ctx : Context, default CPU
         The context in which to load the pretrained weights.
-    root : str, default '~/.mxnet/models'
+    root : str, default $MXNET_HOME/models
         Location for keeping the model parameters.
     """
     net = AlexNet(**kwargs)
diff --git a/python/mxnet/gluon/model_zoo/vision/densenet.py 
b/python/mxnet/gluon/model_zoo/vision/densenet.py
index b03f5ce8d52..83febd3658c 100644
--- a/python/mxnet/gluon/model_zoo/vision/densenet.py
+++ b/python/mxnet/gluon/model_zoo/vision/densenet.py
@@ -26,6 +26,7 @@
 from ...block import HybridBlock
 from ... import nn
 from ...contrib.nn import HybridConcurrent, Identity
+from .... import base
 
 # Helpers
 def _make_dense_block(num_layers, bn_size, growth_rate, dropout, stage_index):
@@ -122,7 +123,7 @@ def hybrid_forward(self, F, x):
 
 # Constructor
 def get_densenet(num_layers, pretrained=False, ctx=cpu(),
-                 root=os.path.join('~', '.mxnet', 'models'), **kwargs):
+                 root=os.path.join(base.data_dir(), 'models'), **kwargs):
     r"""Densenet-BC model from the
     `"Densely Connected Convolutional Networks" 
<https://arxiv.org/pdf/1608.06993.pdf>`_ paper.
 
@@ -134,7 +135,7 @@ def get_densenet(num_layers, pretrained=False, ctx=cpu(),
         Whether to load the pretrained weights for model.
     ctx : Context, default CPU
         The context in which to load the pretrained weights.
-    root : str, default '~/.mxnet/models'
+    root : str, default $MXNET_HOME/models
         Location for keeping the model parameters.
     """
     num_init_features, growth_rate, block_config = densenet_spec[num_layers]
@@ -154,7 +155,7 @@ def densenet121(**kwargs):
         Whether to load the pretrained weights for model.
     ctx : Context, default CPU
         The context in which to load the pretrained weights.
-    root : str, default '~/.mxnet/models'
+    root : str, default '$MXNET_HOME/models'
         Location for keeping the model parameters.
     """
     return get_densenet(121, **kwargs)
@@ -169,7 +170,7 @@ def densenet161(**kwargs):
         Whether to load the pretrained weights for model.
     ctx : Context, default CPU
         The context in which to load the pretrained weights.
-    root : str, default '~/.mxnet/models'
+    root : str, default '$MXNET_HOME/models'
         Location for keeping the model parameters.
     """
     return get_densenet(161, **kwargs)
@@ -184,7 +185,7 @@ def densenet169(**kwargs):
         Whether to load the pretrained weights for model.
     ctx : Context, default CPU
         The context in which to load the pretrained weights.
-    root : str, default '~/.mxnet/models'
+    root : str, default '$MXNET_HOME/models'
         Location for keeping the model parameters.
     """
     return get_densenet(169, **kwargs)
@@ -199,7 +200,7 @@ def densenet201(**kwargs):
         Whether to load the pretrained weights for model.
     ctx : Context, default CPU
         The context in which to load the pretrained weights.
-    root : str, default '~/.mxnet/models'
+    root : str, default '$MXNET_HOME/models'
         Location for keeping the model parameters.
     """
     return get_densenet(201, **kwargs)
diff --git a/python/mxnet/gluon/model_zoo/vision/inception.py 
b/python/mxnet/gluon/model_zoo/vision/inception.py
index 7c54691f1b5..6bdc526a6a1 100644
--- a/python/mxnet/gluon/model_zoo/vision/inception.py
+++ b/python/mxnet/gluon/model_zoo/vision/inception.py
@@ -26,6 +26,7 @@
 from ...block import HybridBlock
 from ... import nn
 from ...contrib.nn import HybridConcurrent
+from .... import base
 
 # Helpers
 def _make_basic_conv(**kwargs):
@@ -199,7 +200,7 @@ def hybrid_forward(self, F, x):
 
 # Constructor
 def inception_v3(pretrained=False, ctx=cpu(),
-                 root=os.path.join('~', '.mxnet', 'models'), **kwargs):
+                 root=os.path.join(base.data_dir(), 'models'), **kwargs):
     r"""Inception v3 model from
     `"Rethinking the Inception Architecture for Computer Vision"
     <http://arxiv.org/abs/1512.00567>`_ paper.
@@ -210,7 +211,7 @@ def inception_v3(pretrained=False, ctx=cpu(),
         Whether to load the pretrained weights for model.
     ctx : Context, default CPU
         The context in which to load the pretrained weights.
-    root : str, default '~/.mxnet/models'
+    root : str, default $MXNET_HOME/models
         Location for keeping the model parameters.
     """
     net = Inception3(**kwargs)
diff --git a/python/mxnet/gluon/model_zoo/vision/mobilenet.py 
b/python/mxnet/gluon/model_zoo/vision/mobilenet.py
index 1a2c9b94619..1a84e05af20 100644
--- a/python/mxnet/gluon/model_zoo/vision/mobilenet.py
+++ b/python/mxnet/gluon/model_zoo/vision/mobilenet.py
@@ -30,6 +30,7 @@
 from ... import nn
 from ....context import cpu
 from ...block import HybridBlock
+from .... import base
 
 
 # Helpers
@@ -188,7 +189,7 @@ def hybrid_forward(self, F, x):
 
 # Constructor
 def get_mobilenet(multiplier, pretrained=False, ctx=cpu(),
-                  root=os.path.join('~', '.mxnet', 'models'), **kwargs):
+                  root=os.path.join(base.data_dir(), 'models'), **kwargs):
     r"""MobileNet model from the
     `"MobileNets: Efficient Convolutional Neural Networks for Mobile Vision 
Applications"
     <https://arxiv.org/abs/1704.04861>`_ paper.
@@ -203,7 +204,7 @@ def get_mobilenet(multiplier, pretrained=False, ctx=cpu(),
         Whether to load the pretrained weights for model.
     ctx : Context, default CPU
         The context in which to load the pretrained weights.
-    root : str, default '~/.mxnet/models'
+    root : str, default $MXNET_HOME/models
         Location for keeping the model parameters.
     """
     net = MobileNet(multiplier, **kwargs)
@@ -219,7 +220,7 @@ def get_mobilenet(multiplier, pretrained=False, ctx=cpu(),
 
 
 def get_mobilenet_v2(multiplier, pretrained=False, ctx=cpu(),
-                     root=os.path.join('~', '.mxnet', 'models'), **kwargs):
+                     root=os.path.join(base.data_dir(), 'models'), **kwargs):
     r"""MobileNetV2 model from the
     `"Inverted Residuals and Linear Bottlenecks:
       Mobile Networks for Classification, Detection and Segmentation"
@@ -235,7 +236,7 @@ def get_mobilenet_v2(multiplier, pretrained=False, 
ctx=cpu(),
         Whether to load the pretrained weights for model.
     ctx : Context, default CPU
         The context in which to load the pretrained weights.
-    root : str, default '~/.mxnet/models'
+    root : str, default $MXNET_HOME/models
         Location for keeping the model parameters.
     """
     net = MobileNetV2(multiplier, **kwargs)
diff --git a/python/mxnet/gluon/model_zoo/vision/resnet.py 
b/python/mxnet/gluon/model_zoo/vision/resnet.py
index da279b89583..48390decb11 100644
--- a/python/mxnet/gluon/model_zoo/vision/resnet.py
+++ b/python/mxnet/gluon/model_zoo/vision/resnet.py
@@ -32,6 +32,7 @@
 from ....context import cpu
 from ...block import HybridBlock
 from ... import nn
+from .... import base
 
 # Helpers
 def _conv3x3(channels, stride, in_channels):
@@ -356,7 +357,7 @@ def hybrid_forward(self, F, x):
 
 # Constructor
 def get_resnet(version, num_layers, pretrained=False, ctx=cpu(),
-               root=os.path.join('~', '.mxnet', 'models'), **kwargs):
+               root=os.path.join(base.data_dir(), 'models'), **kwargs):
     r"""ResNet V1 model from `"Deep Residual Learning for Image Recognition"
     <http://arxiv.org/abs/1512.03385>`_ paper.
     ResNet V2 model from `"Identity Mappings in Deep Residual Networks"
@@ -372,7 +373,7 @@ def get_resnet(version, num_layers, pretrained=False, 
ctx=cpu(),
         Whether to load the pretrained weights for model.
     ctx : Context, default CPU
         The context in which to load the pretrained weights.
-    root : str, default '~/.mxnet/models'
+    root : str, default $MXNET_HOME/models
         Location for keeping the model parameters.
     """
     assert num_layers in resnet_spec, \
@@ -400,7 +401,7 @@ def resnet18_v1(**kwargs):
         Whether to load the pretrained weights for model.
     ctx : Context, default CPU
         The context in which to load the pretrained weights.
-    root : str, default '~/.mxnet/models'
+    root : str, default '$MXNET_HOME/models'
         Location for keeping the model parameters.
     """
     return get_resnet(1, 18, **kwargs)
@@ -415,7 +416,7 @@ def resnet34_v1(**kwargs):
         Whether to load the pretrained weights for model.
     ctx : Context, default CPU
         The context in which to load the pretrained weights.
-    root : str, default '~/.mxnet/models'
+    root : str, default '$MXNET_HOME/models'
         Location for keeping the model parameters.
     """
     return get_resnet(1, 34, **kwargs)
@@ -430,7 +431,7 @@ def resnet50_v1(**kwargs):
         Whether to load the pretrained weights for model.
     ctx : Context, default CPU
         The context in which to load the pretrained weights.
-    root : str, default '~/.mxnet/models'
+    root : str, default '$MXNET_HOME/models'
         Location for keeping the model parameters.
     """
     return get_resnet(1, 50, **kwargs)
@@ -445,7 +446,7 @@ def resnet101_v1(**kwargs):
         Whether to load the pretrained weights for model.
     ctx : Context, default CPU
         The context in which to load the pretrained weights.
-    root : str, default '~/.mxnet/models'
+    root : str, default '$MXNET_HOME/models'
         Location for keeping the model parameters.
     """
     return get_resnet(1, 101, **kwargs)
@@ -460,7 +461,7 @@ def resnet152_v1(**kwargs):
         Whether to load the pretrained weights for model.
     ctx : Context, default CPU
         The context in which to load the pretrained weights.
-    root : str, default '~/.mxnet/models'
+    root : str, default '$MXNET_HOME/models'
         Location for keeping the model parameters.
     """
     return get_resnet(1, 152, **kwargs)
@@ -475,7 +476,7 @@ def resnet18_v2(**kwargs):
         Whether to load the pretrained weights for model.
     ctx : Context, default CPU
         The context in which to load the pretrained weights.
-    root : str, default '~/.mxnet/models'
+    root : str, default '$MXNET_HOME/models'
         Location for keeping the model parameters.
     """
     return get_resnet(2, 18, **kwargs)
@@ -490,7 +491,7 @@ def resnet34_v2(**kwargs):
         Whether to load the pretrained weights for model.
     ctx : Context, default CPU
         The context in which to load the pretrained weights.
-    root : str, default '~/.mxnet/models'
+    root : str, default '$MXNET_HOME/models'
         Location for keeping the model parameters.
     """
     return get_resnet(2, 34, **kwargs)
@@ -505,7 +506,7 @@ def resnet50_v2(**kwargs):
         Whether to load the pretrained weights for model.
     ctx : Context, default CPU
         The context in which to load the pretrained weights.
-    root : str, default '~/.mxnet/models'
+    root : str, default '$MXNET_HOME/models'
         Location for keeping the model parameters.
     """
     return get_resnet(2, 50, **kwargs)
@@ -520,7 +521,7 @@ def resnet101_v2(**kwargs):
         Whether to load the pretrained weights for model.
     ctx : Context, default CPU
         The context in which to load the pretrained weights.
-    root : str, default '~/.mxnet/models'
+    root : str, default '$MXNET_HOME/models'
         Location for keeping the model parameters.
     """
     return get_resnet(2, 101, **kwargs)
@@ -535,7 +536,7 @@ def resnet152_v2(**kwargs):
         Whether to load the pretrained weights for model.
     ctx : Context, default CPU
         The context in which to load the pretrained weights.
-    root : str, default '~/.mxnet/models'
+    root : str, default '$MXNET_HOME/models'
         Location for keeping the model parameters.
     """
     return get_resnet(2, 152, **kwargs)
diff --git a/python/mxnet/gluon/model_zoo/vision/squeezenet.py 
b/python/mxnet/gluon/model_zoo/vision/squeezenet.py
index aaff4c36dfa..b97d1274a6f 100644
--- a/python/mxnet/gluon/model_zoo/vision/squeezenet.py
+++ b/python/mxnet/gluon/model_zoo/vision/squeezenet.py
@@ -26,6 +26,7 @@
 from ...block import HybridBlock
 from ... import nn
 from ...contrib.nn import HybridConcurrent
+from .... import base
 
 # Helpers
 def _make_fire(squeeze_channels, expand1x1_channels, expand3x3_channels):
@@ -110,7 +111,7 @@ def hybrid_forward(self, F, x):
 
 # Constructor
 def get_squeezenet(version, pretrained=False, ctx=cpu(),
-                   root=os.path.join('~', '.mxnet', 'models'), **kwargs):
+                   root=os.path.join(base.data_dir(), 'models'), **kwargs):
     r"""SqueezeNet model from the `"SqueezeNet: AlexNet-level accuracy with 
50x fewer parameters
     and <0.5MB model size" <https://arxiv.org/abs/1602.07360>`_ paper.
     SqueezeNet 1.1 model from the `official SqueezeNet repo
@@ -126,7 +127,7 @@ def get_squeezenet(version, pretrained=False, ctx=cpu(),
         Whether to load the pretrained weights for model.
     ctx : Context, default CPU
         The context in which to load the pretrained weights.
-    root : str, default '~/.mxnet/models'
+    root : str, default $MXNET_HOME/models
         Location for keeping the model parameters.
     """
     net = SqueezeNet(version, **kwargs)
@@ -145,7 +146,7 @@ def squeezenet1_0(**kwargs):
         Whether to load the pretrained weights for model.
     ctx : Context, default CPU
         The context in which to load the pretrained weights.
-    root : str, default '~/.mxnet/models'
+    root : str, default '$MXNET_HOME/models'
         Location for keeping the model parameters.
     """
     return get_squeezenet('1.0', **kwargs)
@@ -162,7 +163,7 @@ def squeezenet1_1(**kwargs):
         Whether to load the pretrained weights for model.
     ctx : Context, default CPU
         The context in which to load the pretrained weights.
-    root : str, default '~/.mxnet/models'
+    root : str, default '$MXNET_HOME/models'
         Location for keeping the model parameters.
     """
     return get_squeezenet('1.1', **kwargs)
diff --git a/python/mxnet/gluon/model_zoo/vision/vgg.py 
b/python/mxnet/gluon/model_zoo/vision/vgg.py
index a3b1685b413..9a740e63318 100644
--- a/python/mxnet/gluon/model_zoo/vision/vgg.py
+++ b/python/mxnet/gluon/model_zoo/vision/vgg.py
@@ -30,6 +30,7 @@
 from ....initializer import Xavier
 from ...block import HybridBlock
 from ... import nn
+from .... import base
 
 
 class VGG(HybridBlock):
@@ -94,7 +95,7 @@ def hybrid_forward(self, F, x):
 
 # Constructors
 def get_vgg(num_layers, pretrained=False, ctx=cpu(),
-            root=os.path.join('~', '.mxnet', 'models'), **kwargs):
+            root=os.path.join(base.data_dir(), 'models'), **kwargs):
     r"""VGG model from the `"Very Deep Convolutional Networks for Large-Scale 
Image Recognition"
     <https://arxiv.org/abs/1409.1556>`_ paper.
 
@@ -106,7 +107,7 @@ def get_vgg(num_layers, pretrained=False, ctx=cpu(),
         Whether to load the pretrained weights for model.
     ctx : Context, default CPU
         The context in which to load the pretrained weights.
-    root : str, default '~/.mxnet/models'
+    root : str, default '$MXNET_HOME/models'
         Location for keeping the model parameters.
     """
     layers, filters = vgg_spec[num_layers]
@@ -128,7 +129,7 @@ def vgg11(**kwargs):
         Whether to load the pretrained weights for model.
     ctx : Context, default CPU
         The context in which to load the pretrained weights.
-    root : str, default '~/.mxnet/models'
+    root : str, default '$MXNET_HOME/models'
         Location for keeping the model parameters.
     """
     return get_vgg(11, **kwargs)
@@ -143,7 +144,7 @@ def vgg13(**kwargs):
         Whether to load the pretrained weights for model.
     ctx : Context, default CPU
         The context in which to load the pretrained weights.
-    root : str, default '~/.mxnet/models'
+    root : str, default '$MXNET_HOME/models'
         Location for keeping the model parameters.
     """
     return get_vgg(13, **kwargs)
@@ -158,7 +159,7 @@ def vgg16(**kwargs):
         Whether to load the pretrained weights for model.
     ctx : Context, default CPU
         The context in which to load the pretrained weights.
-    root : str, default '~/.mxnet/models'
+    root : str, default '$MXNET_HOME/models'
         Location for keeping the model parameters.
     """
     return get_vgg(16, **kwargs)
@@ -173,7 +174,7 @@ def vgg19(**kwargs):
         Whether to load the pretrained weights for model.
     ctx : Context, default CPU
         The context in which to load the pretrained weights.
-    root : str, default '~/.mxnet/models'
+    root : str, default '$MXNET_HOME/models'
         Location for keeping the model parameters.
     """
     return get_vgg(19, **kwargs)
@@ -189,7 +190,7 @@ def vgg11_bn(**kwargs):
         Whether to load the pretrained weights for model.
     ctx : Context, default CPU
         The context in which to load the pretrained weights.
-    root : str, default '~/.mxnet/models'
+    root : str, default '$MXNET_HOME/models'
         Location for keeping the model parameters.
     """
     kwargs['batch_norm'] = True
@@ -206,7 +207,7 @@ def vgg13_bn(**kwargs):
         Whether to load the pretrained weights for model.
     ctx : Context, default CPU
         The context in which to load the pretrained weights.
-    root : str, default '~/.mxnet/models'
+    root : str, default '$MXNET_HOME/models'
         Location for keeping the model parameters.
     """
     kwargs['batch_norm'] = True
@@ -223,7 +224,7 @@ def vgg16_bn(**kwargs):
         Whether to load the pretrained weights for model.
     ctx : Context, default CPU
         The context in which to load the pretrained weights.
-    root : str, default '~/.mxnet/models'
+    root : str, default '$MXNET_HOME/models'
         Location for keeping the model parameters.
     """
     kwargs['batch_norm'] = True
@@ -240,7 +241,7 @@ def vgg19_bn(**kwargs):
         Whether to load the pretrained weights for model.
     ctx : Context, default CPU
         The context in which to load the pretrained weights.
-    root : str, default '~/.mxnet/models'
+    root : str, default '$MXNET_HOME/models'
         Location for keeping the model parameters.
     """
     kwargs['batch_norm'] = True
diff --git a/python/mxnet/util.py b/python/mxnet/util.py
new file mode 100644
index 00000000000..57bc2bf7638
--- /dev/null
+++ b/python/mxnet/util.py
@@ -0,0 +1,30 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""general utility functions"""
+
+import os
+import sys
+
+
+def makedirs(d):
+    """Create directories recursively if they don't exist. os.makedirs(exist_ok=True) is not
+    available in Python2"""
+    if sys.version_info[0] < 3:
+        from distutils.dir_util import mkpath
+        mkpath(d)
+    else:
+        os.makedirs(d, exist_ok=True)
diff --git a/scala-package/core/scripts/get_cifar_data.sh b/scala-package/core/scripts/get_cifar_data.sh
index 9ec1c39a4f9..b061c1895e4 100755
--- a/scala-package/core/scripts/get_cifar_data.sh
+++ b/scala-package/core/scripts/get_cifar_data.sh
@@ -20,8 +20,8 @@
 
 set -e
 
-if [ ! -z "$MXNET_DATA_DIR" ]; then
-  data_path="$MXNET_DATA_DIR"
+if [ ! -z "$MXNET_HOME" ]; then
+  data_path="$MXNET_HOME"
 else
   data_path="./data"
 fi
diff --git a/scala-package/core/scripts/get_mnist_data.sh b/scala-package/core/scripts/get_mnist_data.sh
index 97e151bf833..ded206fbb13 100755
--- a/scala-package/core/scripts/get_mnist_data.sh
+++ b/scala-package/core/scripts/get_mnist_data.sh
@@ -20,8 +20,8 @@
 
 set -e
 
-if [ ! -z "$MXNET_DATA_DIR" ]; then
-  data_path="$MXNET_DATA_DIR"
+if [ ! -z "$MXNET_HOME" ]; then
+  data_path="$MXNET_HOME"
 else
   data_path="./data"
 fi
diff --git a/scala-package/core/src/test/scala/org/apache/mxnet/TestUtil.scala b/scala-package/core/src/test/scala/org/apache/mxnet/TestUtil.scala
index 1187757a033..4fc8ec9826c 100644
--- a/scala-package/core/src/test/scala/org/apache/mxnet/TestUtil.scala
+++ b/scala-package/core/src/test/scala/org/apache/mxnet/TestUtil.scala
@@ -24,7 +24,7 @@ class TestUtil {
     * @return Data direcotry path ()may be relative)
     */
   def getDataDirectory: String = {
-    var dataDir = System.getenv("MXNET_DATA_DIR")
+    var dataDir = System.getenv("MXNET_HOME")
     if(dataDir == null) {
       dataDir = "data"
     } else {
diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/gan/GanMnist.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/gan/GanMnist.scala
index 6186989b74f..70846eebfb8 100644
--- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/gan/GanMnist.scala
+++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/gan/GanMnist.scala
@@ -181,7 +181,7 @@ object GanMnist {
     try {
       parser.parseArgument(args.toList.asJava)
 
-      val dataPath = if (anst.mnistDataPath == null) System.getenv("MXNET_DATA_DIR")
+      val dataPath = if (anst.mnistDataPath == null) System.getenv("MXNET_HOME")
       else anst.mnistDataPath
 
       assert(dataPath != null)
diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/TrainMnist.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/TrainMnist.scala
index b0ecc7d29cc..bd0ce45ffe5 100644
--- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/TrainMnist.scala
+++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/TrainMnist.scala
@@ -112,7 +112,7 @@ object TrainMnist {
     try {
       parser.parseArgument(args.toList.asJava)
 
-      val dataPath = if (inst.dataDir == null) System.getenv("MXNET_DATA_DIR")
+      val dataPath = if (inst.dataDir == null) System.getenv("MXNET_HOME")
         else inst.dataDir
 
       val (dataShape, net) =
diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/infer/imageclassifier/ImageClassifierExample.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/infer/imageclassifier/ImageClassifierExample.scala
index e886b908ba2..3bbd780d39b 100644
--- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/infer/imageclassifier/ImageClassifierExample.scala
+++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/infer/imageclassifier/ImageClassifierExample.scala
@@ -119,13 +119,13 @@ object ImageClassifierExample {
       parser.parseArgument(args.toList.asJava)
 
 
-      val modelPathPrefix = if (inst.modelPathPrefix == null) System.getenv("MXNET_DATA_DIR")
+      val modelPathPrefix = if (inst.modelPathPrefix == null) System.getenv("MXNET_HOME")
       else inst.modelPathPrefix
 
-      val inputImagePath = if (inst.inputImagePath == null) System.getenv("MXNET_DATA_DIR")
+      val inputImagePath = if (inst.inputImagePath == null) System.getenv("MXNET_HOME")
       else inst.inputImagePath
 
-      val inputImageDir = if (inst.inputImageDir == null) System.getenv("MXNET_DATA_DIR")
+      val inputImageDir = if (inst.inputImageDir == null) System.getenv("MXNET_HOME")
       else inst.inputImageDir
 
       val singleOutput = runInferenceOnSingleImage(modelPathPrefix, inputImagePath, context)
diff --git a/tests/python/unittest/test_base.py b/tests/python/unittest/test_base.py
new file mode 100644
index 00000000000..3189729e1d1
--- /dev/null
+++ b/tests/python/unittest/test_base.py
@@ -0,0 +1,50 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import mxnet as mx
+from mxnet.base import data_dir
+from nose.tools import *
+import os
+import unittest
+import logging
+import os.path as op
+import platform
+
+class MXNetDataDirTest(unittest.TestCase):
+    def setUp(self):
+        self.mxnet_data_dir = os.environ.get('MXNET_HOME')
+        if 'MXNET_HOME' in os.environ:
+            del os.environ['MXNET_HOME']
+
+    def tearDown(self):
+        if self.mxnet_data_dir:
+            os.environ['MXNET_HOME'] = self.mxnet_data_dir
+        else:
+            if 'MXNET_HOME' in os.environ:
+                del os.environ['MXNET_HOME']
+
+    def test_data_dir(self,):
+        prev_data_dir = data_dir()
+        system = platform.system()
+        if system != 'Windows':
+            self.assertEqual(data_dir(), op.join(op.expanduser('~'), '.mxnet'))
+        os.environ['MXNET_HOME'] = '/tmp/mxnet_data'
+        self.assertEqual(data_dir(), '/tmp/mxnet_data')
+        del os.environ['MXNET_HOME']
+        self.assertEqual(data_dir(), prev_data_dir)
+
+


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
[email protected]


With regards,
Apache Git Services

Reply via email to