Lunderberg commented on a change in pull request #8451:
URL: https://github.com/apache/tvm/pull/8451#discussion_r668205945



##########
File path: python/tvm/testing.py
##########
@@ -1160,6 +1162,43 @@ def wraps(func):
     return wraps(func)
 
 
+class _DeepCopyAllowedClasses(dict):
+    def __init__(self, *allowed_class_list):
+        self.allowed_class_list = allowed_class_list
+        super().__init__()
+
+    def get(self, key, *args, **kwargs):
+        obj = ctypes.cast(key, ctypes.py_object).value
+        cls = type(obj)
+        if (
+            cls in copy._deepcopy_dispatch
+            or issubclass(cls, type)
+            or getattr(obj, "__deepcopy__", None)
+            or copyreg.dispatch_table.get(cls)
+            or cls.__reduce__ is not object.__reduce__
+            or cls.__reduce_ex__ is not object.__reduce_ex__
+            or cls in self.allowed_class_list
+        ):
+            return super().get(key, *args, **kwargs)
+
+        rfc_url = (
+            
"https://github.com/apache/tvm-rfcs/blob/main/rfcs/0007-parametrized-unit-tests.md";
+        )
+        raise TypeError(
+            (
+                "Cannot copy fixture of type {}.  TVM fixture caching "
+                "is limited to objects that explicitly provide the ability "
+                "to be copied (e.g. through __deepcopy__, __getstate__, or 
__setstate__),"
+                "and forbids the use of the default `object.__reduce__` and "
+                "`object.__reduce_ex__`.  For third-party classes that are 
known to be "
+                "safe to use with copy.deepcopy, please add the class to "

Review comment:
       Good call, the "known to be" is unnecessarily wordy.  Edited as suggested.

##########
File path: python/tvm/testing.py
##########
@@ -1199,18 +1238,15 @@ def wrapper(*args, **kwargs):
             except KeyError:
                 cached_value = cache[cache_key] = func(*args, **kwargs)
 
-            try:
-                yield copy.deepcopy(cached_value)
-            except TypeError as e:
-                rfc_url = (
-                    "https://github.com/apache/tvm-rfcs/blob/main/rfcs/";
-                    "0007-parametrized-unit-tests.md#unresolved-questions"
-                )
-                message = (
-                    "TVM caching of fixtures can only be used on serializable 
data types, not {}.\n"
-                    "Please see {} for details/discussion."
-                ).format(type(cached_value), rfc_url)
-                raise TypeError(message) from e
+            yield copy.deepcopy(
+                cached_value,
+                _DeepCopyAllowedClasses(
+                    # *args should be a list of classes that are known
+                    # to be safe to copy using copy.deepcopy, but do

Review comment:
       Agreed as above.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to