This is an automated email from the ASF dual-hosted git repository.

haibin pushed a commit to branch v1.6.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git

commit 69f4f3161364f290be053ecbd48931a40bd7ab68
Author: Haibin Lin <[email protected]>
AuthorDate: Thu Jan 23 19:32:21 2020 -0800

    [BUGFIX] fix model zoo parallel download (#17372)
    
    * use temp file
    
    * fix dependency
    
    * Update model_store.py
    
    * Update test_gluon_model_zoo.py
    
    * remove NamedTempFile
---
 python/mxnet/gluon/model_zoo/model_store.py   | 22 +++++++++++++-------
 python/mxnet/gluon/utils.py                   | 30 +++++++++++++++++++--------
 tests/python/unittest/test_gluon_model_zoo.py | 16 ++++++++++++++
 3 files changed, 52 insertions(+), 16 deletions(-)

diff --git a/python/mxnet/gluon/model_zoo/model_store.py b/python/mxnet/gluon/model_zoo/model_store.py
index 11ac47b..6da7dd1 100644
--- a/python/mxnet/gluon/model_zoo/model_store.py
+++ b/python/mxnet/gluon/model_zoo/model_store.py
@@ -22,8 +22,11 @@ __all__ = ['get_model_file', 'purge']
 import os
 import zipfile
 import logging
+import tempfile
+import uuid
+import shutil
 
-from ..utils import download, check_sha1
+from ..utils import download, check_sha1, replace_file
 from ... import base, util
 
 _model_sha1 = {name: checksum for checksum, name in [
@@ -103,16 +106,21 @@ def get_model_file(name, root=os.path.join(base.data_dir(), 'models')):
 
     util.makedirs(root)
 
-    zip_file_path = os.path.join(root, file_name+'.zip')
     repo_url = os.environ.get('MXNET_GLUON_REPO', apache_repo_url)
     if repo_url[-1] != '/':
         repo_url = repo_url + '/'
+
+    random_uuid = str(uuid.uuid4())
+    temp_zip_file_path = os.path.join(root, file_name+'.zip'+random_uuid)
     download(_url_format.format(repo_url=repo_url, file_name=file_name),
-             path=zip_file_path,
-             overwrite=True)
-    with zipfile.ZipFile(zip_file_path) as zf:
-        zf.extractall(root)
-    os.remove(zip_file_path)
+             path=temp_zip_file_path, overwrite=True)
+    with zipfile.ZipFile(temp_zip_file_path) as zf:
+        temp_dir = tempfile.mkdtemp(dir=root)
+        zf.extractall(temp_dir)
+        temp_file_path = os.path.join(temp_dir, file_name+'.params')
+        replace_file(temp_file_path, file_path)
+        shutil.rmtree(temp_dir)
+    os.remove(temp_zip_file_path)
 
     if check_sha1(file_path, sha1_hash):
         return file_path
diff --git a/python/mxnet/gluon/utils.py b/python/mxnet/gluon/utils.py
index 81a8dba..63e11ea 100644
--- a/python/mxnet/gluon/utils.py
+++ b/python/mxnet/gluon/utils.py
@@ -21,7 +21,7 @@
 from __future__ import absolute_import
 
 __all__ = ['split_data', 'split_and_load', 'clip_global_norm',
-           'check_sha1', 'download']
+           'check_sha1', 'download', 'replace_file']
 
 import os
 import sys
@@ -35,7 +35,7 @@ import requests
 import numpy as np
 
 from .. import ndarray
-from ..util import is_np_shape, is_np_array
+from ..util import is_np_shape, is_np_array, makedirs
 from .. import numpy as _mx_np  # pylint: disable=reimported
 
 
@@ -209,8 +209,14 @@ def check_sha1(filename, sha1_hash):
 
 if not sys.platform.startswith('win32'):
     # refer to https://github.com/untitaker/python-atomicwrites
-    def _replace_atomic(src, dst):
-        """Implement atomic os.replace with linux and OSX. Internal use only"""
+    def replace_file(src, dst):
+        """Implement atomic os.replace with linux and OSX.
+
+        Parameters
+        ----------
+        src : source file path
+        dst : destination file path
+        """
         try:
             os.rename(src, dst)
         except OSError:
@@ -252,11 +258,17 @@ else:
             finally:
                 raise OSError(msg)
 
-    def _replace_atomic(src, dst):
+    def replace_file(src, dst):
         """Implement atomic os.replace with windows.
+
         refer to https://docs.microsoft.com/en-us/windows/desktop/api/winbase/nf-winbase-movefileexw
         The function fails when one of the process(copy, flush, delete) fails.
-        Internal use only"""
+
+        Parameters
+        ----------
+        src : source file path
+        dst : destination file path
+        """
         _handle_errors(ctypes.windll.kernel32.MoveFileExW(
             _str_to_unicode(src), _str_to_unicode(dst),
             _windows_default_flags | _MOVEFILE_REPLACE_EXISTING
@@ -264,7 +276,7 @@ else:
 
 
 def download(url, path=None, overwrite=False, sha1_hash=None, retries=5, verify_ssl=True):
-    """Download an given URL
+    """Download a given URL
 
     Parameters
     ----------
@@ -310,7 +322,7 @@ def download(url, path=None, overwrite=False, sha1_hash=None, retries=5, verify_
     if overwrite or not os.path.exists(fname) or (sha1_hash and not check_sha1(fname, sha1_hash)):
         dirname = os.path.dirname(os.path.abspath(os.path.expanduser(fname)))
         if not os.path.exists(dirname):
-            os.makedirs(dirname)
+            makedirs(dirname)
         while retries + 1 > 0:
             # Disable pyling too broad Exception
             # pylint: disable=W0703
@@ -330,7 +342,7 @@ def download(url, path=None, overwrite=False, sha1_hash=None, retries=5, verify_
                 # delete the temporary file
                 if not os.path.exists(fname) or (sha1_hash and not check_sha1(fname, sha1_hash)):
                     # atmoic operation in the same file system
-                    _replace_atomic('{}.{}'.format(fname, random_uuid), fname)
+                    replace_file('{}.{}'.format(fname, random_uuid), fname)
                 else:
                     try:
                         os.remove('{}.{}'.format(fname, random_uuid))
diff --git a/tests/python/unittest/test_gluon_model_zoo.py b/tests/python/unittest/test_gluon_model_zoo.py
index a646684..d53dd40 100644
--- a/tests/python/unittest/test_gluon_model_zoo.py
+++ b/tests/python/unittest/test_gluon_model_zoo.py
@@ -20,6 +20,7 @@ import mxnet as mx
 from mxnet.gluon.model_zoo.vision import get_model
 import sys
 from common import setup_module, with_seed, teardown
+import multiprocessing
 
 
 def eprint(*args, **kwargs):
@@ -49,6 +50,21 @@ def test_models():
             model.collect_params().initialize()
         model(mx.nd.random.uniform(shape=data_shape)).wait_to_read()
 
+def parallel_download(model_name):
+    model = get_model(model_name, pretrained=True, root='./parallel_download')
+    print(type(model))
+
+@with_seed()
+def test_parallel_download():
+    processes = []
+    name = 'mobilenetv2_0.25'
+    for _ in range(10):
+        p = multiprocessing.Process(target=parallel_download, args=(name,))
+        processes.append(p)
+    for p in processes:
+        p.start()
+    for p in processes:
+        p.join()
 
 if __name__ == '__main__':
     import nose

Reply via email to