This is an automated email from the ASF dual-hosted git repository.

wuwei pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new b654852b15 [Bugfix] Allow import of TVM when current directory is read-only (#17142)
b654852b15 is described below

commit b654852b155d667a0c86adc8ff92d5eb7ca2c44b
Author: Eric Lunderberg <[email protected]>
AuthorDate: Mon Jul 15 12:54:05 2024 -0500

    [Bugfix] Allow import of TVM when current directory is read-only (#17142)
    
    * [Bugfix] Allow import of TVM when current directory is read-only
    
    Prior to this commit, TVM could only be imported if the current
    working directory was writable.  This was because
    `tvm.contrib.pickle_memoize`, which is used to cache the winograd
    transformation matrices, created its cache directory inside the
    current working directory.
    
    This commit makes several related fixes to ensure that (1) TVM can
    be imported regardless of directory permissions, (2) the working
    directory is not left in a cluttered state, and (3) cache files are
    generated in a predictable location so they can be reused later.
    
    * The cache directory is created only when needed, immediately
      before a cache file is saved.
    
    * The cache directory defaults to `$HOME/.cache/tvm/pkl_memoize_py3`,
      rather than `.pkl_memoize_py3` in the working directory.
    
    * The cache directory respects `XDG_CACHE_HOME`, using
      `$XDG_CACHE_HOME/tvm/pkl_memoize_py3` if set.
    
    * lint fix
---
 python/tvm/contrib/pickle_memoize.py          |  58 ++++++++----
 tests/python/contrib/pickle_memoize_script.py |  48 ++++++++++
 tests/python/contrib/test_memoize.py          | 126 ++++++++++++++++++++++++++
 3 files changed, 214 insertions(+), 18 deletions(-)

diff --git a/python/tvm/contrib/pickle_memoize.py 
b/python/tvm/contrib/pickle_memoize.py
index 6d2ffbac06..4f3aff8fb5 100644
--- a/python/tvm/contrib/pickle_memoize.py
+++ b/python/tvm/contrib/pickle_memoize.py
@@ -15,10 +15,13 @@
 # specific language governing permissions and limitations
 # under the License.
 """Memoize result of function via pickle, used for cache testcases."""
+
 # pylint: disable=broad-except,superfluous-parens
+import atexit
 import os
+import pathlib
 import sys
-import atexit
+
 from decorator import decorate
 from .._ffi.base import string_types
 
@@ -28,6 +31,17 @@ except ImportError:
     import pickle
 
 
+def _get_global_cache_dir() -> pathlib.Path:
+    if "XDG_CACHE_HOME" in os.environ:
+        cache_home = pathlib.Path(os.environ.get("XDG_CACHE_HOME"))
+    else:
+        cache_home = pathlib.Path.home().joinpath(".cache")
+    return cache_home.joinpath("tvm", f"pkl_memoize_py{sys.version_info[0]}")
+
+
+GLOBAL_CACHE_DIR = _get_global_cache_dir()
+
+
 class Cache(object):
     """A cache object for result cache.
 
@@ -42,28 +56,36 @@ class Cache(object):
     cache_by_key = {}
 
     def __init__(self, key, save_at_exit):
-        cache_dir = f".pkl_memoize_py{sys.version_info[0]}"
-        try:
-            os.mkdir(cache_dir)
-        except FileExistsError:
-            pass
-        else:
-            self.cache = {}
-        self.path = os.path.join(cache_dir, key)
-        if os.path.exists(self.path):
-            try:
-                self.cache = pickle.load(open(self.path, "rb"))
-            except Exception:
-                self.cache = {}
-        else:
-            self.cache = {}
+        self._cache = None
+
+        self.path = GLOBAL_CACHE_DIR.joinpath(key)
         self.dirty = False
         self.save_at_exit = save_at_exit
 
+    @property
+    def cache(self):
+        """Return the cache, initializing on first use."""
+
+        if self._cache is not None:
+            return self._cache
+
+        if self.path.exists():
+            with self.path.open("rb") as cache_file:
+                try:
+                    cache = pickle.load(cache_file)
+                except pickle.UnpicklingError:
+                    cache = {}
+        else:
+            cache = {}
+
+        self._cache = cache
+        return self._cache
+
     def save(self):
         if self.dirty:
-            print(f"Save memoize result to {self.path}")
-            with open(self.path, "wb") as out_file:
+            self.path.parent.mkdir(parents=True, exist_ok=True)
+
+            with self.path.open("wb") as out_file:
                 pickle.dump(self.cache, out_file, pickle.HIGHEST_PROTOCOL)
 
 
diff --git a/tests/python/contrib/pickle_memoize_script.py 
b/tests/python/contrib/pickle_memoize_script.py
new file mode 100755
index 0000000000..f0d73e3910
--- /dev/null
+++ b/tests/python/contrib/pickle_memoize_script.py
@@ -0,0 +1,48 @@
+#!/usr/bin/env python3
+
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import sys
+
+import tvm
+
+
[email protected]_memoize.memoize("test_memoize_save_data", 
save_at_exit=True)
+def get_data_saved():
+    return 42
+
+
[email protected]_memoize.memoize("test_memoize_transient_data", 
save_at_exit=False)
+def get_data_transient():
+    return 42
+
+
+def main():
+    assert len(sys.argv) == 3, "Expect arguments SCRIPT NUM_SAVED 
NUM_TRANSIENT"
+
+    num_iter_saved = int(sys.argv[1])
+    num_iter_transient = int(sys.argv[2])
+
+    for _ in range(num_iter_saved):
+        get_data_saved()
+    for _ in range(num_iter_transient):
+        get_data_transient()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/tests/python/contrib/test_memoize.py 
b/tests/python/contrib/test_memoize.py
new file mode 100644
index 0000000000..6881940e50
--- /dev/null
+++ b/tests/python/contrib/test_memoize.py
@@ -0,0 +1,126 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Tests for tvm.contrib.pickle_memoize"""
+
+import os
+import pathlib
+import tempfile
+import subprocess
+import sys
+
+import tvm.testing
+
+TEST_SCRIPT_FILE = 
pathlib.Path(__file__).with_name("pickle_memoize_script.py").resolve()
+
+
+def test_cache_dir_not_in_current_working_dir():
+    with tempfile.TemporaryDirectory(prefix="tvm_") as temp_dir:
+        temp_dir = pathlib.Path(temp_dir)
+        subprocess.check_call([TEST_SCRIPT_FILE, "1", "1"], cwd=temp_dir)
+
+        new_files = list(temp_dir.iterdir())
+        assert (
+            not new_files
+        ), "Use of tvm.contrib.pickle_memorize may not write to current 
directory."
+
+
+def test_current_directory_is_not_required_to_be_writable():
+    """TVM may be imported without directory permissions
+
+    This is a regression test.  In previous implementations, the
+    `tvm.contrib.pickle_memoize.memoize` function would write to the
+    current directory when importing TVM.  Import of a Python module
+    should not write to any directory.
+
+    """
+
+    with tempfile.TemporaryDirectory(prefix="tvm_") as temp_dir:
+        temp_dir = pathlib.Path(temp_dir)
+
+        # User may read/cd into the temp dir, nobody may write to temp
+        # dir.
+        temp_dir.chmod(0o500)
+        subprocess.check_call([sys.executable, "-c", "import tvm"], 
cwd=temp_dir)
+
+
+def test_cache_dir_defaults_to_home_config_cache():
+    with tempfile.TemporaryDirectory(prefix="tvm_") as temp_dir:
+        temp_dir = pathlib.Path(temp_dir)
+
+        subprocess.check_call([TEST_SCRIPT_FILE, "1", "0"], cwd=temp_dir)
+
+        new_files = list(temp_dir.iterdir())
+        assert (
+            not new_files
+        ), "Use of tvm.contrib.pickle_memorize may not write to current 
directory."
+
+        cache_dir = pathlib.Path.home().joinpath(".cache", "tvm", 
"pkl_memoize_py3")
+        assert cache_dir.exists()
+        cache_files = list(cache_dir.iterdir())
+        assert len(cache_files) >= 1
+
+
+def test_cache_dir_respects_xdg_cache_home():
+    with tempfile.TemporaryDirectory(
+        prefix="tvm_"
+    ) as temp_working_dir, tempfile.TemporaryDirectory(prefix="tvm_") as 
temp_cache_dir:
+        temp_cache_dir = pathlib.Path(temp_cache_dir)
+        temp_working_dir = pathlib.Path(temp_working_dir)
+
+        subprocess.check_call(
+            [TEST_SCRIPT_FILE, "1", "0"],
+            cwd=temp_working_dir,
+            env={
+                **os.environ,
+                "XDG_CACHE_HOME": temp_cache_dir.as_posix(),
+            },
+        )
+
+        new_files = list(temp_working_dir.iterdir())
+        assert (
+            not new_files
+        ), "Use of tvm.contrib.pickle_memorize may not write to current 
directory."
+
+        cache_dir = temp_cache_dir.joinpath("tvm", "pkl_memoize_py3")
+        assert cache_dir.exists()
+        cache_files = list(cache_dir.iterdir())
+        assert len(cache_files) == 1
+
+
+def test_cache_dir_only_created_when_used():
+    with tempfile.TemporaryDirectory(
+        prefix="tvm_"
+    ) as temp_working_dir, tempfile.TemporaryDirectory(prefix="tvm_") as 
temp_cache_dir:
+        temp_cache_dir = pathlib.Path(temp_cache_dir)
+        temp_working_dir = pathlib.Path(temp_working_dir)
+
+        subprocess.check_call(
+            [TEST_SCRIPT_FILE, "0", "1"],
+            cwd=temp_working_dir,
+            env={
+                **os.environ,
+                "XDG_CACHE_HOME": temp_cache_dir.as_posix(),
+            },
+        )
+
+        cache_dir = temp_cache_dir.joinpath("tvm", "pkl_memoize_py3")
+        assert not cache_dir.exists()
+
+
+if __name__ == "__main__":
+    tvm.testing.main()

Reply via email to