This is an automated email from the ASF dual-hosted git repository. haibin pushed a commit to branch v1.6.x in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
commit 69f4f3161364f290be053ecbd48931a40bd7ab68 Author: Haibin Lin <[email protected]> AuthorDate: Thu Jan 23 19:32:21 2020 -0800 [BUGFIX] fix model zoo parallel download (#17372) * use temp file * fix dependency * Update model_store.py * Update test_gluon_model_zoo.py * remove NamedTempFile --- python/mxnet/gluon/model_zoo/model_store.py | 22 +++++++++++++------- python/mxnet/gluon/utils.py | 30 +++++++++++++++++++-------- tests/python/unittest/test_gluon_model_zoo.py | 16 ++++++++++++++ 3 files changed, 52 insertions(+), 16 deletions(-) diff --git a/python/mxnet/gluon/model_zoo/model_store.py b/python/mxnet/gluon/model_zoo/model_store.py index 11ac47b..6da7dd1 100644 --- a/python/mxnet/gluon/model_zoo/model_store.py +++ b/python/mxnet/gluon/model_zoo/model_store.py @@ -22,8 +22,11 @@ __all__ = ['get_model_file', 'purge'] import os import zipfile import logging +import tempfile +import uuid +import shutil -from ..utils import download, check_sha1 +from ..utils import download, check_sha1, replace_file from ... import base, util _model_sha1 = {name: checksum for checksum, name in [ @@ -103,16 +106,21 @@ def get_model_file(name, root=os.path.join(base.data_dir(), 'models')): util.makedirs(root) - zip_file_path = os.path.join(root, file_name+'.zip') repo_url = os.environ.get('MXNET_GLUON_REPO', apache_repo_url) if repo_url[-1] != '/': repo_url = repo_url + '/' + + random_uuid = str(uuid.uuid4()) + temp_zip_file_path = os.path.join(root, file_name+'.zip'+random_uuid) download(_url_format.format(repo_url=repo_url, file_name=file_name), - path=zip_file_path, - overwrite=True) - with zipfile.ZipFile(zip_file_path) as zf: - zf.extractall(root) - os.remove(zip_file_path) + path=temp_zip_file_path, overwrite=True) + with zipfile.ZipFile(temp_zip_file_path) as zf: + temp_dir = tempfile.mkdtemp(dir=root) + zf.extractall(temp_dir) + temp_file_path = os.path.join(temp_dir, file_name+'.params') + replace_file(temp_file_path, file_path) + shutil.rmtree(temp_dir) + os.remove(temp_zip_file_path) if check_sha1(file_path, sha1_hash): return file_path diff --git a/python/mxnet/gluon/utils.py b/python/mxnet/gluon/utils.py index 81a8dba..63e11ea 100644 --- a/python/mxnet/gluon/utils.py +++ b/python/mxnet/gluon/utils.py @@ -21,7 +21,7 @@ from __future__ import absolute_import __all__ = ['split_data', 'split_and_load', 'clip_global_norm', - 'check_sha1', 'download'] + 'check_sha1', 'download', 'replace_file'] import os import sys @@ -35,7 +35,7 @@ import requests import numpy as np from .. import ndarray -from ..util import is_np_shape, is_np_array +from ..util import is_np_shape, is_np_array, makedirs from .. import numpy as _mx_np # pylint: disable=reimported @@ -209,8 +209,14 @@ def check_sha1(filename, sha1_hash): if not sys.platform.startswith('win32'): # refer to https://github.com/untitaker/python-atomicwrites - def _replace_atomic(src, dst): - """Implement atomic os.replace with linux and OSX. Internal use only""" + def replace_file(src, dst): + """Implement atomic os.replace with linux and OSX. + + Parameters + ---------- + src : source file path + dst : destination file path + """ try: os.rename(src, dst) except OSError: @@ -252,11 +258,17 @@ else: finally: raise OSError(msg) - def _replace_atomic(src, dst): + def replace_file(src, dst): """Implement atomic os.replace with windows. + refer to https://docs.microsoft.com/en-us/windows/desktop/api/winbase/nf-winbase-movefileexw The function fails when one of the process(copy, flush, delete) fails. - Internal use only""" + + Parameters + ---------- + src : source file path + dst : destination file path + """ _handle_errors(ctypes.windll.kernel32.MoveFileExW( _str_to_unicode(src), _str_to_unicode(dst), _windows_default_flags | _MOVEFILE_REPLACE_EXISTING @@ -264,7 +276,7 @@ else: def download(url, path=None, overwrite=False, sha1_hash=None, retries=5, verify_ssl=True): - """Download an given URL + """Download a given URL Parameters ---------- @@ -310,7 +322,7 @@ def download(url, path=None, overwrite=False, sha1_hash=None, retries=5, verify_ if overwrite or not os.path.exists(fname) or (sha1_hash and not check_sha1(fname, sha1_hash)): dirname = os.path.dirname(os.path.abspath(os.path.expanduser(fname))) if not os.path.exists(dirname): - os.makedirs(dirname) + makedirs(dirname) while retries + 1 > 0: # Disable pyling too broad Exception # pylint: disable=W0703 @@ -330,7 +342,7 @@ def download(url, path=None, overwrite=False, sha1_hash=None, retries=5, verify_ # delete the temporary file if not os.path.exists(fname) or (sha1_hash and not check_sha1(fname, sha1_hash)): # atmoic operation in the same file system - _replace_atomic('{}.{}'.format(fname, random_uuid), fname) + replace_file('{}.{}'.format(fname, random_uuid), fname) else: try: os.remove('{}.{}'.format(fname, random_uuid)) diff --git a/tests/python/unittest/test_gluon_model_zoo.py b/tests/python/unittest/test_gluon_model_zoo.py index a646684..d53dd40 100644 --- a/tests/python/unittest/test_gluon_model_zoo.py +++ b/tests/python/unittest/test_gluon_model_zoo.py @@ -20,6 +20,7 @@ import mxnet as mx from mxnet.gluon.model_zoo.vision import get_model import sys from common import setup_module, with_seed, teardown +import multiprocessing def eprint(*args, **kwargs): @@ -49,6 +50,21 @@ def test_models(): model.collect_params().initialize() model(mx.nd.random.uniform(shape=data_shape)).wait_to_read() +def parallel_download(model_name): + model = get_model(model_name, pretrained=True, root='./parallel_download') + print(type(model)) + +@with_seed() +def test_parallel_download(): + processes = [] + name = 'mobilenetv2_0.25' + for _ in range(10): + p = multiprocessing.Process(target=parallel_download, args=(name,)) + processes.append(p) + for p in processes: + p.start() + for p in processes: + p.join() if __name__ == '__main__': import nose
