This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 9a99fc89a2 [Utils] Allow classmethod and staticmethod in TVMDerivedObject (#14249)
9a99fc89a2 is described below

commit 9a99fc89a2970b9fca151a573de7a5e409b5d9ee
Author: Eric Lunderberg <[email protected]>
AuthorDate: Sun Mar 12 21:33:57 2023 -0500

    [Utils] Allow classmethod and staticmethod in TVMDerivedObject (#14249)
    
    Instance methods that exist in the user-defined class but not in the
    TVM base are forwarded using `__getattr__`.  However, this only
    applies to attribute lookup on instances, and doesn't apply to
    attribute lookup on the class object itself, such as when calling a
    classmethod or staticmethod.
    
    This commit exposes class methods and static methods in the wrapper
    class, if they are defined in the user-defined subclass.
---
 python/tvm/meta_schedule/utils.py                   |  3 +++
 .../unittest/test_meta_schedule_post_order_apply.py | 21 +++++++++++++++++++++
 2 files changed, 24 insertions(+)

diff --git a/python/tvm/meta_schedule/utils.py b/python/tvm/meta_schedule/utils.py
index 401fdab08a..fb1ddd6585 100644
--- a/python/tvm/meta_schedule/utils.py
+++ b/python/tvm/meta_schedule/utils.py
@@ -128,6 +128,9 @@ def derived_object(cls: type) -> type:
     TVMDerivedObject.__name__ = cls.__name__
     TVMDerivedObject.__doc__ = cls.__doc__
     TVMDerivedObject.__module__ = cls.__module__
+    for key, value in cls.__dict__.items():
+        if isinstance(value, (classmethod, staticmethod)):
+            setattr(TVMDerivedObject, key, value)
     return TVMDerivedObject
 
 
diff --git a/tests/python/unittest/test_meta_schedule_post_order_apply.py b/tests/python/unittest/test_meta_schedule_post_order_apply.py
index c1d2dc3d07..716f829653 100644
--- a/tests/python/unittest/test_meta_schedule_post_order_apply.py
+++ b/tests/python/unittest/test_meta_schedule_post_order_apply.py
@@ -404,5 +404,26 @@ def test_target_blocks_search_space():
     assert len(schs) == 8
 
 
+def test_meta_schedule_derived_object():
+    @derived_object
+    class RemoveBlock(PyScheduleRule):
+        @classmethod
+        def class_construct(cls):
+            return cls()
+
+        @staticmethod
+        def static_construct():
+            return RemoveBlock()
+
+    inst_by_init = RemoveBlock()
+    assert isinstance(inst_by_init, RemoveBlock)
+
+    inst_by_classmethod = RemoveBlock.class_construct()
+    assert isinstance(inst_by_classmethod, RemoveBlock)
+
+    inst_by_staticmethod = RemoveBlock.static_construct()
+    assert isinstance(inst_by_staticmethod, RemoveBlock)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to