This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 9a99fc89a2 [Utils] Allow classmethod and staticmethod in
TVMDerivedObject (#14249)
9a99fc89a2 is described below
commit 9a99fc89a2970b9fca151a573de7a5e409b5d9ee
Author: Eric Lunderberg <[email protected]>
AuthorDate: Sun Mar 12 21:33:57 2023 -0500
[Utils] Allow classmethod and staticmethod in TVMDerivedObject (#14249)
Instance methods that exist in the user-defined class but not in the
TVM base are forwarded using `__getattr__`. However, this is only
applied for attribute lookup on instances, and doesn't apply for
attribute lookup on the class object itself, such as when calling a
classmethod or staticmethod.
This commit exposes class methods and static methods in the wrapper
class, if they are defined in the user-defined subclass.
---
python/tvm/meta_schedule/utils.py | 3 +++
.../unittest/test_meta_schedule_post_order_apply.py | 21 +++++++++++++++++++++
2 files changed, 24 insertions(+)
diff --git a/python/tvm/meta_schedule/utils.py
b/python/tvm/meta_schedule/utils.py
index 401fdab08a..fb1ddd6585 100644
--- a/python/tvm/meta_schedule/utils.py
+++ b/python/tvm/meta_schedule/utils.py
@@ -128,6 +128,9 @@ def derived_object(cls: type) -> type:
TVMDerivedObject.__name__ = cls.__name__
TVMDerivedObject.__doc__ = cls.__doc__
TVMDerivedObject.__module__ = cls.__module__
+ for key, value in cls.__dict__.items():
+ if isinstance(value, (classmethod, staticmethod)):
+ setattr(TVMDerivedObject, key, value)
return TVMDerivedObject
diff --git a/tests/python/unittest/test_meta_schedule_post_order_apply.py
b/tests/python/unittest/test_meta_schedule_post_order_apply.py
index c1d2dc3d07..716f829653 100644
--- a/tests/python/unittest/test_meta_schedule_post_order_apply.py
+++ b/tests/python/unittest/test_meta_schedule_post_order_apply.py
@@ -404,5 +404,26 @@ def test_target_blocks_search_space():
assert len(schs) == 8
+def test_meta_schedule_derived_object():
+ @derived_object
+ class RemoveBlock(PyScheduleRule):
+ @classmethod
+ def class_construct(cls):
+ return cls()
+
+ @staticmethod
+ def static_construct():
+ return RemoveBlock()
+
+ inst_by_init = RemoveBlock()
+ assert isinstance(inst_by_init, RemoveBlock)
+
+ inst_by_classmethod = RemoveBlock.class_construct()
+ assert isinstance(inst_by_classmethod, RemoveBlock)
+
+ inst_by_staticmethod = RemoveBlock.static_construct()
+ assert isinstance(inst_by_staticmethod, RemoveBlock)
+
+
if __name__ == "__main__":
tvm.testing.main()