This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new a4be2ed904 [TVMScript] Support inlined function call as a sugar (#11324)
a4be2ed904 is described below

commit a4be2ed9046a97fa826da9beba64c791e2c36ccf
Author: Masahiro Masuda <[email protected]>
AuthorDate: Wed May 18 17:56:10 2022 +0900

    [TVMScript] Support inlined function call as a sugar (#11324)
    
    * [TVMScript] Support function call to help construct AST
    
    * add test
    
    * update test
    
    * more comment
    
    * fix for avoiding Buffer.vload(...) case
    
    * update parse error msg
    
    * wrap func call with try / catch, emit error msg
    
    * silence pylint
---
 python/tvm/script/parser.py                        | 44 ++++++++++--
 .../python/unittest/test_tvmscript_syntax_sugar.py | 81 ++++++++++++++++++++++
 2 files changed, 121 insertions(+), 4 deletions(-)

diff --git a/python/tvm/script/parser.py b/python/tvm/script/parser.py
index fe71b06432..daeb018ea9 100644
--- a/python/tvm/script/parser.py
+++ b/python/tvm/script/parser.py
@@ -20,7 +20,8 @@ We use [synr](https://synr.readthedocs.io) to get an AST that is stable over
 different python versions. Synr also provides an error handling context that we
 use for error reporting.
 """
-# pylint: disable=invalid-name, inconsistent-return-statements, no-else-return
+# pylint: disable=invalid-name, inconsistent-return-statements, no-else-return, broad-except
+import types
 import json
 import operator
 import inspect
@@ -543,7 +544,7 @@ class TVMScriptParser(Transformer):
         AST abstract grammar:
             Assign(expr* targets, expr value, string? type_comment)
 
-        By now 3 patterns of Assign is supported:
+        By now 5 patterns of Assign is supported:
             1. special stmts with return value
                 1.1 Buffer = T.match_buffer()/T.buffer_decl()
                 1.2 Var = T.var()
@@ -552,6 +553,9 @@ class TVMScriptParser(Transformer):
             3. (Store)       Var[PrimExpr] = PrimExpr
             4. with scope handlers with concise scoping and var def
                 4.1 var = T.allocate()
+            5. A call to a pure python function, consuming and producing TVMScript values.
+               The outputs are inlined into the following body (no variable is created).
+               x, y = f(...)
         """
 
         if isinstance(node.rhs, ast.Call):
@@ -577,6 +581,35 @@ class TVMScriptParser(Transformer):
                 arg_list = self.parse_arg_list(func, node.rhs)
                 func.handle(node, self.context, arg_list, node.rhs.func_name.span)
                 return self.parse_body(node)
+            elif isinstance(func, types.FunctionType):
+                # Pattern 5
+                args = [self.transform(arg) for arg in node.rhs.params]
+                try:
+                    out = func(*args)
+                except Exception as e:
+                    self.report_error(
+                        "Error occured when invoking the function "
+                        + func.__name__
+                        + ": \n"
+                        + str(e),
+                        node.rhs.span,
+                    )
+
+                if len(node.lhs) == 1 and not isinstance(out, list):
+                    out = [out]
+
+                assert len(out) == len(node.lhs)
+
+                for var, value in zip(node.lhs, out):
+                    self.context.update_symbol(var.id.name, value, node)
+
+                body = self.parse_body(node)
+
+                for var, value in zip(node.lhs, out):
+                    self.context.remove_symbol(var.id.name)
+
+                return body
+
         if isinstance(node.rhs, (ast.Call, ast.Constant)):
             # Pattern 4 of let binding
             value = self.transform(node.rhs)
@@ -606,7 +639,7 @@ class TVMScriptParser(Transformer):
             return tvm.tir.LetStmt(var, value, body, span=tvm_span_from_synr(node.span))
 
         self.report_error(
-            """Assignments should be either
+            """Assignments should be one of:
             1. A "special statement" with return value
                 1.1 Buffer = T.match_buffer()/T.buffer_decl()
                 1.2 Var = T.var()
@@ -614,7 +647,10 @@ class TVMScriptParser(Transformer):
             2. A store into a buffer: Buffer[PrimExpr, PrimExpr, ..., PrimExpr] = PrimExpr
             3. A store into a variable: Var[PrimExpr] = PrimExpr
             4. A with scope handler with concise scoping and var def
-                4.1 var = T.allocate()""",
+                4.1 var = T.allocate()
+            5. The right-hand side being a call to a pure python function, consuming and
+               producing TVMScript values.
+               x, y = f(...)""",
             node.span,
         )
 
diff --git a/tests/python/unittest/test_tvmscript_syntax_sugar.py b/tests/python/unittest/test_tvmscript_syntax_sugar.py
index a0964ea4d7..b3fe5674a8 100644
--- a/tests/python/unittest/test_tvmscript_syntax_sugar.py
+++ b/tests/python/unittest/test_tvmscript_syntax_sugar.py
@@ -265,5 +265,86 @@ def test_letstmt_bind_with_constant():
     assert_structural_equal(constant_binds, constant_binds_wrapped)
 
 
+def test_func_call():
+    def shared_16x16_to_ldmatrix_32x8_layout(i, j):
+        thread_id = (i % 8) * 4 + (j % 8) // 2
+        return thread_id, (j // 8) * 4 + (i // 8) * 2 + (j % 2)
+
+    @T.prim_func
+    def mma_sync_m16n16k16_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
+        A = T.match_buffer(a, (32, 8), "float16", align=128, offset_factor=16, scope="warp")
+        B = T.match_buffer(b, (32, 8), "float16", align=128, offset_factor=16, scope="warp")
+        C = T.match_buffer(c, (32, 8), "float16", align=128, offset_factor=16, scope="warp")
+
+        with T.block("root"):
+            T.reads(C[0:32, 0:8], A[0:32, 0:8], B[0:32, 0:8])
+            T.writes(C[0:32, 0:8])
+            for i, j, k in T.grid(16, 16, 16):
+                with T.block("C"):
+                    i, j, k = T.axis.remap("SSR", [i, j, k])
+                    thread_id_C, local_id_C = shared_16x16_to_ldmatrix_32x8_layout(i, j)
+                    thread_id_A, local_id_A = shared_16x16_to_ldmatrix_32x8_layout(i, k)
+                    thread_id_B, local_id_B = shared_16x16_to_ldmatrix_32x8_layout(k, j)
+
+                    T.reads(
+                        C[thread_id_C, local_id_C],
+                        A[thread_id_A, local_id_A],
+                        B[thread_id_B, local_id_B],
+                    )
+                    T.writes(C[thread_id_C, local_id_C])
+
+                    C[thread_id_C, local_id_C] += (
+                        A[thread_id_A, local_id_A] * B[thread_id_B, local_id_B]
+                    )
+
+    @T.prim_func
+    def mma_sync_m16n16k16_desc_manual(a: T.handle, b: T.handle, c: T.handle) 
-> None:
+        A = T.match_buffer(a, (32, 8), "float16", align=128, offset_factor=16, 
scope="warp")
+        B = T.match_buffer(b, (32, 8), "float16", align=128, offset_factor=16, 
scope="warp")
+        C = T.match_buffer(c, (32, 8), "float16", align=128, offset_factor=16, 
scope="warp")
+
+        with T.block("root"):
+            T.reads(C[0:32, 0:8], A[0:32, 0:8], B[0:32, 0:8])
+            T.writes(C[0:32, 0:8])
+            for i, j, k in T.grid(16, 16, 16):
+                with T.block("C"):
+                    i, j, k = T.axis.remap("SSR", [i, j, k])
+                    T.reads(
+                        C[i % 8 * 4 + j % 8 // 2, j // 8 * 4 + i // 8 * 2 + j % 2],
+                        A[i % 8 * 4 + k % 8 // 2, k // 8 * 4 + i // 8 * 2 + k % 2],
+                        B[k % 8 * 4 + j % 8 // 2, j // 8 * 4 + k // 8 * 2 + j % 2],
+                    )
+                    T.writes(C[i % 8 * 4 + j % 8 // 2, j // 8 * 4 + i // 8 * 2 + j % 2])
+                    C[i % 8 * 4 + j % 8 // 2, j // 8 * 4 + i // 8 * 2 + j % 2] = (
+                        C[i % 8 * 4 + j % 8 // 2, j // 8 * 4 + i // 8 * 2 + j % 2]
+                        + A[i % 8 * 4 + k % 8 // 2, k // 8 * 4 + i // 8 * 2 + k % 2]
+                        * B[k % 8 * 4 + j % 8 // 2, j // 8 * 4 + k // 8 * 2 + j % 2]
+                    )
+
+    assert_structural_equal(mma_sync_m16n16k16_desc, mma_sync_m16n16k16_desc_manual)
+
+    # The following is an example of an error message from calling an invalid function
+
+    # error: Error occured when invoking the function sqrt:
+    # loop of ufunc does not support argument 0 of type Var which has no callable sqrt method
+    #  --> test_tvmscript_syntax_sugar.py:334:19
+    #      |
+    #  334 |              ind = sqrt(i)
+    #      |                    ^^^^^^^
+    # note: run with `TVM_BACKTRACE=1` environment variable to display a backtrace.
+
+    # Uncomment to see the error above.
+    # def sqrt(x):
+    #     import numpy as np
+    #     return np.sqrt(x)
+
+    # @T.prim_func
+    # def loop(a: T.handle) -> None:
+    #     A = T.match_buffer(a, (128,))
+    #     for i in T.serial(128):
+    #         ind = sqrt(i)
+    #         A[i] = A[ind]
+
+
 if __name__ == "__main__":
     sys.exit(pytest.main([__file__] + sys.argv[1:]))

Reply via email to