This is an automated email from the ASF dual-hosted git repository.
ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 7ec2d35665 [TIR] Add support for conditional expressions in TVMScript (#18323)
7ec2d35665 is described below
commit 7ec2d356653254a2bf9aab7b9b66a25e42a30a53
Author: Siyuan Feng <[email protected]>
AuthorDate: Sat Sep 20 21:13:52 2025 +0800
[TIR] Add support for conditional expressions in TVMScript (#18323)
Add support for conditional expressions in TVMScript
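This PR adds support for conditional expressions in the TVMScript parser,
which allows developers to write Python-style conditional expressions
(`a if cond else b`). When the condition is a TIR boolean expression, the
expression is parsed as `T.if_then_else(cond, a, b)`; when it is a plain
Python bool, the selected branch is used directly. For example, the two
functions below are structurally equal: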
```python
@T.prim_func
def func(A: T.buffer((128, 128), "float32")):
    for i, j in T.grid(128, 128):
        A[i, j] = i if i < j else j

@T.prim_func
def expected(A: T.buffer((128, 128), "float32")):
    for i, j in T.grid(128, 128):
        A[i, j] = T.if_then_else(i < j, i, j)
```
---
python/tvm/script/parser/core/evaluator.py | 41 ++++++++++++++++++----
.../python/tvmscript/test_tvmscript_parser_tir.py | 14 ++++++++
2 files changed, 49 insertions(+), 6 deletions(-)
diff --git a/python/tvm/script/parser/core/evaluator.py
b/python/tvm/script/parser/core/evaluator.py
index 9d09df3d8e..9969dd80f5 100644
--- a/python/tvm/script/parser/core/evaluator.py
+++ b/python/tvm/script/parser/core/evaluator.py
@@ -19,6 +19,8 @@
import ast
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Type,
Union
+import tvm
+
from . import dispatch, doc
from .error import ParserError
@@ -173,18 +175,19 @@ class ExprEvaluator:
isinstance(node, doc.Call)
and hasattr(node.func, "attr")
and node.func.attr not in ["reads", "writes", "match_buffer",
"realize"]
- ) or isinstance(node, (doc.BinOp, doc.UnaryOp, doc.Compare,
doc.BoolOp)):
+ ) or isinstance(node, (doc.BinOp, doc.UnaryOp, doc.Compare,
doc.BoolOp, doc.IfExp)):
if isinstance(node, doc.BinOp):
args = [node.left, node.right]
elif isinstance(node, doc.UnaryOp):
args = [node.operand]
elif isinstance(node, doc.Compare):
args = [node.left, *node.comparators]
- else:
- if isinstance(node, doc.Call):
- args = node.args
- elif isinstance(node, doc.BoolOp):
- args = node.values
+ elif isinstance(node, doc.IfExp):
+ args = [node.test, node.body, node.orelse]
+ elif isinstance(node, doc.Call):
+ args = node.args
+ elif isinstance(node, doc.BoolOp):
+ args = node.values
for arg in args:
if isinstance(arg, doc.Subscript) and isinstance(arg.slice,
(doc.Slice, doc.Tuple)):
if isinstance(arg.slice, doc.Slice):
@@ -256,6 +259,8 @@ class ExprEvaluator:
value = self._eval_unary_op(fields)
elif isinstance(node, doc.BinOp):
value = self._eval_bin_op(fields)
+ elif isinstance(node, doc.IfExp):
+ value = self._eval_if_exp(fields)
elif isinstance(node, doc.Slice):
value = self._eval_slice(fields)
else:
@@ -364,6 +369,30 @@ class ExprEvaluator:
],
)
+ def _eval_if_exp(self, fields: Dict[str, Any]) -> Any:
+ """The doc AST if-else expression node evaluating method.
+
+ Parameters
+ ----------
+ fields : Dict[str, Any]
+ The dictionary of if-else expression information,
+ e.g., test, body, orelse.
+
+ Returns
+ -------
+ res : Any
+ The evaluation result.
+ """
+ test = self._eval_expr(fields["test"])
+ body = self._eval_expr(fields["body"])
+ orelse = self._eval_expr(fields["orelse"])
+ if isinstance(test, bool):
+ return body if test else orelse
+ elif isinstance(test, tvm.tir.PrimExpr) and test.dtype == "bool":
+ return tvm.tir.op.if_then_else(test, body, orelse)
+ else:
+ raise TypeError(f"Expected Python bool or TIR bool, but got
{type(test)}")
+
def _eval_slice(self, fields: Dict[str, Any]) -> slice:
"""The doc AST slice node evaluating method.
diff --git a/tests/python/tvmscript/test_tvmscript_parser_tir.py
b/tests/python/tvmscript/test_tvmscript_parser_tir.py
index fd196be72a..d28e4680ae 100644
--- a/tests/python/tvmscript/test_tvmscript_parser_tir.py
+++ b/tests/python/tvmscript/test_tvmscript_parser_tir.py
@@ -612,5 +612,19 @@ def test_alloc_inside_block():
tvm.ir.assert_structural_equal(func, expected)
+def test_ifexp():
+ @T.prim_func(private=True)
+ def func(A: T.buffer((128, 128), "float32")):
+ for i, j in T.grid(128, 128):
+ A[i, j] = i if i < j else j
+
+ @T.prim_func(private=True)
+ def expected(A: T.buffer((128, 128), "float32")):
+ for i, j in T.grid(128, 128):
+ A[i, j] = T.if_then_else(i < j, i, j)
+
+ tvm.ir.assert_structural_equal(func, expected)
+
+
if __name__ == "__main__":
tvm.testing.main()