This is an automated email from the ASF dual-hosted git repository.

ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 36e473f58b [TIR] Support sequence comparisons in TVMScript (#18341)
36e473f58b is described below

commit 36e473f58bda75e03f745d30da09033d0ab880f9
Author: Siyuan Feng <[email protected]>
AuthorDate: Fri Sep 26 03:12:31 2025 +0800

    [TIR] Support sequence comparisons in TVMScript (#18341)
    
    Implement proper parsing and evaluation of chained comparison operators
    (e.g., `0 < i < 128`) in TVMScript. The sequence comparisons are now
    correctly expanded to their logical equivalents (e.g., `(0 < i and i < 128)`).
    
    Changes:
    - Updated expression evaluator to handle sequence comparisons correctly
    - Added test case to verify sequence comparison functionality
---
 python/tvm/script/parser/core/evaluator.py          | 16 ++++++++++++----
 tests/python/tvmscript/test_tvmscript_parser_tir.py | 20 ++++++++++++++++++++
 2 files changed, 32 insertions(+), 4 deletions(-)

diff --git a/python/tvm/script/parser/core/evaluator.py b/python/tvm/script/parser/core/evaluator.py
index 9969dd80f5..7668fa99e6 100644
--- a/python/tvm/script/parser/core/evaluator.py
+++ b/python/tvm/script/parser/core/evaluator.py
@@ -324,10 +324,18 @@ class ExprEvaluator:
         res : Any
             The evaluation result.
         """
-        value = self._eval_expr(fields["left"])
-        for op, rhs in zip(fields["ops"], fields["comparators"]):
-            value = _eval_op(op, values=[value, self._eval_expr(rhs)])
-        return value
+        values = [self._eval_expr(fields["left"])]
+        values.extend([self._eval_expr(rhs) for rhs in fields["comparators"]])
+        result = None
+        assert len(fields["ops"]) == len(values) - 1
+
+        for index, op in enumerate(fields["ops"]):
+            sub_result = _eval_op(op, values=[values[index], values[index + 1]])
+            if result is None:
+                result = sub_result
+            else:
+                result = _eval_op(doc.And(), values=[result, sub_result])
+        return result
 
     def _eval_unary_op(self, fields: Dict[str, Any]) -> Any:
         """The doc AST unary operation node evaluating method.
diff --git a/tests/python/tvmscript/test_tvmscript_parser_tir.py b/tests/python/tvmscript/test_tvmscript_parser_tir.py
index d28e4680ae..f1569be5b1 100644
--- a/tests/python/tvmscript/test_tvmscript_parser_tir.py
+++ b/tests/python/tvmscript/test_tvmscript_parser_tir.py
@@ -626,5 +626,25 @@ def test_ifexp():
     tvm.ir.assert_structural_equal(func, expected)
 
 
+def test_sequence_compare():
+    @T.prim_func(private=True)
+    def tir_func(A: T.Buffer((128, 128), "float32")):
+        for i, j in T.grid(128, 128):
+            if 0 < i < 128 and 0 < j < 128:
+                A[i, j] = 1
+            else:
+                A[i, j] = 0
+
+    @T.prim_func(private=True)
+    def expected(A: T.buffer((128, 128), "float32")):
+        for i, j in T.grid(128, 128):
+            if (0 < i and i < 128) and (0 < j and j < 128):
+                A[i, j] = 1
+            else:
+                A[i, j] = 0
+
+    tvm.ir.assert_structural_equal(tir_func, expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to