This is an automated email from the ASF dual-hosted git repository.

ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 7ec2d35665 [TIR] Add support for conditional expressions in TVMScript (#18323)
7ec2d35665 is described below

commit 7ec2d356653254a2bf9aab7b9b66a25e42a30a53
Author: Siyuan Feng <[email protected]>
AuthorDate: Sat Sep 20 21:13:52 2025 +0800

    [TIR] Add support for conditional expressions in TVMScript (#18323)
    
    Add support for conditional expressions in TVMScript
    
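    This PR adds support for conditional expressions to the TVMScript parser,
    allowing developers to write Python-style conditional expressions
    (`a if cond else b`). When the condition is a TIR boolean expression it is
    lowered to `T.if_then_else`; when it is a plain Python bool the branch is
    selected at parse time. For example: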
    
    ```python
    @T.prim_func
    def func(A: T.buffer((128, 128), "float32")):
        for i, j in T.grid(128, 128):
            A[i, j] = i if i < j else j
    
    @T.prim_func
    def expected(A: T.buffer((128, 128), "float32")):
        for i, j in T.grid(128, 128):
            A[i, j] = T.if_then_else(i < j, i, j)
    ```
---
 python/tvm/script/parser/core/evaluator.py         | 41 ++++++++++++++++++----
 .../python/tvmscript/test_tvmscript_parser_tir.py  | 14 ++++++++
 2 files changed, 49 insertions(+), 6 deletions(-)

diff --git a/python/tvm/script/parser/core/evaluator.py b/python/tvm/script/parser/core/evaluator.py
index 9d09df3d8e..9969dd80f5 100644
--- a/python/tvm/script/parser/core/evaluator.py
+++ b/python/tvm/script/parser/core/evaluator.py
@@ -19,6 +19,8 @@
 import ast
 from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Type, Union
 
+import tvm
+
 from . import dispatch, doc
 from .error import ParserError
 
@@ -173,18 +175,19 @@ class ExprEvaluator:
             isinstance(node, doc.Call)
             and hasattr(node.func, "attr")
             and node.func.attr not in ["reads", "writes", "match_buffer", "realize"]
-        ) or isinstance(node, (doc.BinOp, doc.UnaryOp, doc.Compare, doc.BoolOp)):
+        ) or isinstance(node, (doc.BinOp, doc.UnaryOp, doc.Compare, doc.BoolOp, doc.IfExp)):
             if isinstance(node, doc.BinOp):
                 args = [node.left, node.right]
             elif isinstance(node, doc.UnaryOp):
                 args = [node.operand]
             elif isinstance(node, doc.Compare):
                 args = [node.left, *node.comparators]
-            else:
-                if isinstance(node, doc.Call):
-                    args = node.args
-                elif isinstance(node, doc.BoolOp):
-                    args = node.values
+            elif isinstance(node, doc.IfExp):
+                args = [node.test, node.body, node.orelse]
+            elif isinstance(node, doc.Call):
+                args = node.args
+            elif isinstance(node, doc.BoolOp):
+                args = node.values
         for arg in args:
             if isinstance(arg, doc.Subscript) and isinstance(arg.slice, (doc.Slice, doc.Tuple)):
                 if isinstance(arg.slice, doc.Slice):
@@ -256,6 +259,8 @@ class ExprEvaluator:
                 value = self._eval_unary_op(fields)
             elif isinstance(node, doc.BinOp):
                 value = self._eval_bin_op(fields)
+            elif isinstance(node, doc.IfExp):
+                value = self._eval_if_exp(fields)
             elif isinstance(node, doc.Slice):
                 value = self._eval_slice(fields)
             else:
@@ -364,6 +369,30 @@ class ExprEvaluator:
             ],
         )
 
+    def _eval_if_exp(self, fields: Dict[str, Any]) -> Any:
+        """The doc AST if-else expression node evaluating method.
+
+        Parameters
+        ----------
+        fields : Dict[str, Any]
+            The dictionary of if-else expression information,
+            e.g., test, body, orelse.
+
+        Returns
+        -------
+        res : Any
+            The evaluation result.
+        """
+        test = self._eval_expr(fields["test"])
+        body = self._eval_expr(fields["body"])
+        orelse = self._eval_expr(fields["orelse"])
+        if isinstance(test, bool):
+            return body if test else orelse
+        elif isinstance(test, tvm.tir.PrimExpr) and test.dtype == "bool":
+            return tvm.tir.op.if_then_else(test, body, orelse)
+        else:
+            raise TypeError(f"Expected Python bool or TIR bool, but got {type(test)}")
+
     def _eval_slice(self, fields: Dict[str, Any]) -> slice:
         """The doc AST slice node evaluating method.
 
diff --git a/tests/python/tvmscript/test_tvmscript_parser_tir.py b/tests/python/tvmscript/test_tvmscript_parser_tir.py
index fd196be72a..d28e4680ae 100644
--- a/tests/python/tvmscript/test_tvmscript_parser_tir.py
+++ b/tests/python/tvmscript/test_tvmscript_parser_tir.py
@@ -612,5 +612,19 @@ def test_alloc_inside_block():
     tvm.ir.assert_structural_equal(func, expected)
 
 
+def test_ifexp():
+    @T.prim_func(private=True)
+    def func(A: T.buffer((128, 128), "float32")):
+        for i, j in T.grid(128, 128):
+            A[i, j] = i if i < j else j
+
+    @T.prim_func(private=True)
+    def expected(A: T.buffer((128, 128), "float32")):
+        for i, j in T.grid(128, 128):
+            A[i, j] = T.if_then_else(i < j, i, j)
+
+    tvm.ir.assert_structural_equal(func, expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to