This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new a4be2ed904 [TVMScript] Support inlined function call as a sugar
(#11324)
a4be2ed904 is described below
commit a4be2ed9046a97fa826da9beba64c791e2c36ccf
Author: Masahiro Masuda <[email protected]>
AuthorDate: Wed May 18 17:56:10 2022 +0900
[TVMScript] Support inlined function call as a sugar (#11324)
* [TVMScript] Support function call to help construct AST
* add test
* update test
* more comment
* fix for avoiding Buffer.vload(...) case
* update parse error msg
* wrap func call with try / catch, emit error msg
* silence pylint
---
python/tvm/script/parser.py | 44 ++++++++++--
.../python/unittest/test_tvmscript_syntax_sugar.py | 81 ++++++++++++++++++++++
2 files changed, 121 insertions(+), 4 deletions(-)
diff --git a/python/tvm/script/parser.py b/python/tvm/script/parser.py
index fe71b06432..daeb018ea9 100644
--- a/python/tvm/script/parser.py
+++ b/python/tvm/script/parser.py
@@ -20,7 +20,8 @@ We use [synr](https://synr.readthedocs.io) to get an AST that
is stable over
different python versions. Synr also provides an error handling context that we
use for error reporting.
"""
-# pylint: disable=invalid-name, inconsistent-return-statements, no-else-return
+# pylint: disable=invalid-name, inconsistent-return-statements,
no-else-return, broad-except
+import types
import json
import operator
import inspect
@@ -543,7 +544,7 @@ class TVMScriptParser(Transformer):
AST abstract grammar:
Assign(expr* targets, expr value, string? type_comment)
- By now 3 patterns of Assign is supported:
+ Currently, 5 patterns of Assign are supported:
1. special stmts with return value
1.1 Buffer = T.match_buffer()/T.buffer_decl()
1.2 Var = T.var()
@@ -552,6 +553,9 @@ class TVMScriptParser(Transformer):
3. (Store) Var[PrimExpr] = PrimExpr
4. with scope handlers with concise scoping and var def
4.1 var = T.allocate()
+ 5. A call to a pure python function, consuming and producing
TVMScript values.
+ The outputs are inlined into the following body (no variable is
created).
+ x, y = f(...)
"""
if isinstance(node.rhs, ast.Call):
@@ -577,6 +581,35 @@ class TVMScriptParser(Transformer):
arg_list = self.parse_arg_list(func, node.rhs)
func.handle(node, self.context, arg_list,
node.rhs.func_name.span)
return self.parse_body(node)
+ elif isinstance(func, types.FunctionType):
+ # Pattern 5
+ args = [self.transform(arg) for arg in node.rhs.params]
+ try:
+ out = func(*args)
+ except Exception as e:
+ self.report_error(
+ "Error occured when invoking the function "
+ + func.__name__
+ + ": \n"
+ + str(e),
+ node.rhs.span,
+ )
+
+ if len(node.lhs) == 1 and not isinstance(out, list):
+ out = [out]
+
+ assert len(out) == len(node.lhs)
+
+ for var, value in zip(node.lhs, out):
+ self.context.update_symbol(var.id.name, value, node)
+
+ body = self.parse_body(node)
+
+ for var, value in zip(node.lhs, out):
+ self.context.remove_symbol(var.id.name)
+
+ return body
+
if isinstance(node.rhs, (ast.Call, ast.Constant)):
# Pattern 4 of let binding
value = self.transform(node.rhs)
@@ -606,7 +639,7 @@ class TVMScriptParser(Transformer):
return tvm.tir.LetStmt(var, value, body,
span=tvm_span_from_synr(node.span))
self.report_error(
- """Assignments should be either
+ """Assignments should be one of:
1. A "special statement" with return value
1.1 Buffer = T.match_buffer()/T.buffer_decl()
1.2 Var = T.var()
@@ -614,7 +647,10 @@ class TVMScriptParser(Transformer):
2. A store into a buffer: Buffer[PrimExpr, PrimExpr, ...,
PrimExpr] = PrimExpr
3. A store into a variable: Var[PrimExpr] = PrimExpr
4. A with scope handler with concise scoping and var def
- 4.1 var = T.allocate()""",
+ 4.1 var = T.allocate()
+ 5. The right-hand side being a call to a pure python function,
consuming and
+ producing TVMScript values.
+ x, y = f(...)""",
node.span,
)
diff --git a/tests/python/unittest/test_tvmscript_syntax_sugar.py
b/tests/python/unittest/test_tvmscript_syntax_sugar.py
index a0964ea4d7..b3fe5674a8 100644
--- a/tests/python/unittest/test_tvmscript_syntax_sugar.py
+++ b/tests/python/unittest/test_tvmscript_syntax_sugar.py
@@ -265,5 +265,86 @@ def test_letstmt_bind_with_constant():
assert_structural_equal(constant_binds, constant_binds_wrapped)
+def test_func_call():
+ """Check that calling a plain Python function inside a prim_func is inlined.
+
+ The values returned by the helper are substituted directly into the
+ following body (no variable is bound for them), so the sugared kernel must
+ be structurally equal to one with the same index expressions written by hand.
+ """
+ def shared_16x16_to_ldmatrix_32x8_layout(i, j):
+ # Returns (thread_id, local_id) used to index the (32, 8) warp buffers
+ # below for element (i, j) of the 16x16 loop nest.
+ thread_id = (i % 8) * 4 + (j % 8) // 2
+ return thread_id, (j // 8) * 4 + (i // 8) * 2 + (j % 2)
+
+ # Kernel written with the sugar: each `x, y = helper(...)` line is expanded
+ # in place by the parser rather than producing a binding.
+ @T.prim_func
+ def mma_sync_m16n16k16_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
+ A = T.match_buffer(a, (32, 8), "float16", align=128, offset_factor=16,
scope="warp")
+ B = T.match_buffer(b, (32, 8), "float16", align=128, offset_factor=16,
scope="warp")
+ C = T.match_buffer(c, (32, 8), "float16", align=128, offset_factor=16,
scope="warp")
+
+ with T.block("root"):
+ T.reads(C[0:32, 0:8], A[0:32, 0:8], B[0:32, 0:8])
+ T.writes(C[0:32, 0:8])
+ for i, j, k in T.grid(16, 16, 16):
+ with T.block("C"):
+ i, j, k = T.axis.remap("SSR", [i, j, k])
+ thread_id_C, local_id_C =
shared_16x16_to_ldmatrix_32x8_layout(i, j)
+ thread_id_A, local_id_A =
shared_16x16_to_ldmatrix_32x8_layout(i, k)
+ thread_id_B, local_id_B =
shared_16x16_to_ldmatrix_32x8_layout(k, j)
+
+ T.reads(
+ C[thread_id_C, local_id_C],
+ A[thread_id_A, local_id_A],
+ B[thread_id_B, local_id_B],
+ )
+ T.writes(C[thread_id_C, local_id_C])
+
+ C[thread_id_C, local_id_C] += (
+ A[thread_id_A, local_id_A] * B[thread_id_B, local_id_B]
+ )
+
+ # Reference kernel: identical, but with the helper's index expressions
+ # spelled out manually at every buffer access.
+ @T.prim_func
+ def mma_sync_m16n16k16_desc_manual(a: T.handle, b: T.handle, c: T.handle)
-> None:
+ A = T.match_buffer(a, (32, 8), "float16", align=128, offset_factor=16,
scope="warp")
+ B = T.match_buffer(b, (32, 8), "float16", align=128, offset_factor=16,
scope="warp")
+ C = T.match_buffer(c, (32, 8), "float16", align=128, offset_factor=16,
scope="warp")
+
+ with T.block("root"):
+ T.reads(C[0:32, 0:8], A[0:32, 0:8], B[0:32, 0:8])
+ T.writes(C[0:32, 0:8])
+ for i, j, k in T.grid(16, 16, 16):
+ with T.block("C"):
+ i, j, k = T.axis.remap("SSR", [i, j, k])
+ T.reads(
+ C[i % 8 * 4 + j % 8 // 2, j // 8 * 4 + i // 8 * 2 + j
% 2],
+ A[i % 8 * 4 + k % 8 // 2, k // 8 * 4 + i // 8 * 2 + k
% 2],
+ B[k % 8 * 4 + j % 8 // 2, j // 8 * 4 + k // 8 * 2 + j
% 2],
+ )
+ T.writes(C[i % 8 * 4 + j % 8 // 2, j // 8 * 4 + i // 8 * 2
+ j % 2])
+ # Written as `C = C + A * B`; the structural-equality check below
+ # relies on this matching the `+=` form in the sugared kernel.
+ C[i % 8 * 4 + j % 8 // 2, j // 8 * 4 + i // 8 * 2 + j % 2]
= (
+ C[i % 8 * 4 + j % 8 // 2, j // 8 * 4 + i // 8 * 2 + j
% 2]
+ + A[i % 8 * 4 + k % 8 // 2, k // 8 * 4 + i // 8 * 2 +
k % 2]
+ * B[k % 8 * 4 + j % 8 // 2, j // 8 * 4 + k // 8 * 2 +
j % 2]
+ )
+
+ assert_structural_equal(mma_sync_m16n16k16_desc,
mma_sync_m16n16k16_desc_manual)
+
+ # The following is an example of an error message from calling an invalid
function
+
+ # error: Error occured when invoking the function sqrt:
+ # loop of ufunc does not support argument 0 of type Var which has no
callable sqrt method
+ # --> test_tvmscript_syntax_sugar.py:334:19
+ # |
+ # 334 | ind = sqrt(i)
+ # | ^^^^^^^
+ # note: run with `TVM_BACKTRACE=1` environment variable to display a
backtrace.
+
+ # Uncomment to see the error above.
+ # def sqrt(x):
+ # import numpy as np
+ # return np.sqrt(x)
+
+ # @T.prim_func
+ # def loop(a: T.handle) -> None:
+ # A = T.match_buffer(a, (128,))
+ # for i in T.serial(128):
+ # ind = sqrt(i)
+ # A[i] = A[ind]
+
+
if __name__ == "__main__":
sys.exit(pytest.main([__file__] + sys.argv[1:]))