This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new c265cdae97 [TVMScript] Add `__name__` attr for parsed PrimFunc and IRModule (#14786)
c265cdae97 is described below

commit c265cdae97a9449f6fae2d26db79088aec11e8cc
Author: Siyuan Feng <[email protected]>
AuthorDate: Sun May 7 06:29:46 2023 +0800

    [TVMScript] Add `__name__` attr for parsed PrimFunc and IRModule (#14786)
    
    This PR adds a `__name__` attribute to parsed PrimFunc and IRModule objects, set to the name of the decorated Python function or class.
---
 python/tvm/script/parser/ir/entry.py               |  4 +++-
 python/tvm/script/parser/tir/entry.py              |  4 +++-
 tests/python/unittest/test_tvmscript_parser_ir.py  |  1 +
 tests/python/unittest/test_tvmscript_parser_tir.py | 16 ++++++++++++++--
 4 files changed, 21 insertions(+), 4 deletions(-)

diff --git a/python/tvm/script/parser/ir/entry.py b/python/tvm/script/parser/ir/entry.py
index 94fc3d2e2c..5878a1ce55 100644
--- a/python/tvm/script/parser/ir/entry.py
+++ b/python/tvm/script/parser/ir/entry.py
@@ -40,7 +40,9 @@ def ir_module(mod: Type) -> IRModule:
     if not inspect.isclass(mod):
         raise TypeError(f"Expect a class, but got: {mod}")
 
-    return parse(mod, utils.inspect_class_capture(mod))
+    m = parse(mod, utils.inspect_class_capture(mod))
+    setattr(m, "__name__", mod.__name__)
+    return m
 
 
 setattr(ir_module, "dispatch_token", "ir")
diff --git a/python/tvm/script/parser/tir/entry.py b/python/tvm/script/parser/tir/entry.py
index 649f817411..d5bff7a856 100644
--- a/python/tvm/script/parser/tir/entry.py
+++ b/python/tvm/script/parser/tir/entry.py
@@ -42,7 +42,9 @@ def prim_func(func: Callable) -> Union[PrimFunc, Callable]:
         raise TypeError(f"Expect a function, but got: {func}")
     if utils.is_defined_in_class(inspect.stack(), func):
         return func
-    return parse(func, utils.inspect_function_capture(func))
+    f = parse(func, utils.inspect_function_capture(func))
+    setattr(f, "__name__", func.__name__)
+    return f
 
 
 setattr(prim_func, "dispatch_token", "tir")
diff --git a/tests/python/unittest/test_tvmscript_parser_ir.py b/tests/python/unittest/test_tvmscript_parser_ir.py
index d3e758fbe1..d33594794f 100644
--- a/tests/python/unittest/test_tvmscript_parser_ir.py
+++ b/tests/python/unittest/test_tvmscript_parser_ir.py
@@ -29,6 +29,7 @@ def test_ir_base():
         pass
 
     assert isinstance(BlankIRModule, IRModule) and len(BlankIRModule.functions.items()) == 0
+    assert BlankIRModule.__name__ == "BlankIRModule"
 
 
 if __name__ == "__main__":
diff --git a/tests/python/unittest/test_tvmscript_parser_tir.py b/tests/python/unittest/test_tvmscript_parser_tir.py
index 20be6d1498..31bf5cc101 100644
--- a/tests/python/unittest/test_tvmscript_parser_tir.py
+++ b/tests/python/unittest/test_tvmscript_parser_tir.py
@@ -16,8 +16,6 @@
 # under the License.
 """Unittests for tvm.script.parser.tir"""
 
-import pytest
-import inspect
 import tvm.testing
 from tvm.script.parser import tir as T
 from tvm import ir, tir
@@ -59,5 +57,19 @@ def test_tir_ptr_proxy():
     )
 
 
+def test_tir_func_name():
+    @T.prim_func
+    def matmul(a: T.handle, b: T.handle, c: T.handle) -> None:
+        A = T.match_buffer(a, [128, 128])
+        B = T.match_buffer(b, [128, 128])
+        C = T.match_buffer(c, [128, 128])
+        for i, j, k in T.grid(128, 128, 128):
+            with T.block("update"):
+                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
+                C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
+
+    assert matmul.__name__ == "matmul"
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to