This is an automated email from the ASF dual-hosted git repository.
ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 36e473f58b [TIR] Support sequence comparisons in TVMScript (#18341)
36e473f58b is described below
commit 36e473f58bda75e03f745d30da09033d0ab880f9
Author: Siyuan Feng <[email protected]>
AuthorDate: Fri Sep 26 03:12:31 2025 +0800
[TIR] Support sequence comparisons in TVMScript (#18341)
Implement proper parsing and evaluation of chained comparison operators
(e.g., `0 < i < 128`) in TVMScript. Previously, the evaluator folded the
operands left to right, so `0 < i < 128` was evaluated as `(0 < i) < 128`.
Chained comparisons are now expanded to their logical equivalents (e.g.,
`(0 < i and i < 128)`).
Changes:
- Updated the expression evaluator to expand a chained comparison into the
  conjunction (`and`) of its pairwise comparisons
- Added test case to verify sequence comparison functionality
---
python/tvm/script/parser/core/evaluator.py | 16 ++++++++++++----
tests/python/tvmscript/test_tvmscript_parser_tir.py | 20 ++++++++++++++++++++
2 files changed, 32 insertions(+), 4 deletions(-)
diff --git a/python/tvm/script/parser/core/evaluator.py b/python/tvm/script/parser/core/evaluator.py
index 9969dd80f5..7668fa99e6 100644
--- a/python/tvm/script/parser/core/evaluator.py
+++ b/python/tvm/script/parser/core/evaluator.py
@@ -324,10 +324,18 @@ class ExprEvaluator:
res : Any
The evaluation result.
"""
- value = self._eval_expr(fields["left"])
- for op, rhs in zip(fields["ops"], fields["comparators"]):
- value = _eval_op(op, values=[value, self._eval_expr(rhs)])
- return value
+ values = [self._eval_expr(fields["left"])]
+ values.extend([self._eval_expr(rhs) for rhs in fields["comparators"]])
+ result = None
+ assert len(fields["ops"]) == len(values) - 1
+
+ for index, op in enumerate(fields["ops"]):
+ sub_result = _eval_op(op, values=[values[index], values[index +
1]])
+ if result is None:
+ result = sub_result
+ else:
+ result = _eval_op(doc.And(), values=[result, sub_result])
+ return result
def _eval_unary_op(self, fields: Dict[str, Any]) -> Any:
"""The doc AST unary operation node evaluating method.
diff --git a/tests/python/tvmscript/test_tvmscript_parser_tir.py b/tests/python/tvmscript/test_tvmscript_parser_tir.py
index d28e4680ae..f1569be5b1 100644
--- a/tests/python/tvmscript/test_tvmscript_parser_tir.py
+++ b/tests/python/tvmscript/test_tvmscript_parser_tir.py
@@ -626,5 +626,25 @@ def test_ifexp():
tvm.ir.assert_structural_equal(func, expected)
+def test_sequence_compare():  # `0 < i < 128` must parse as `0 < i and i < 128`
+    @T.prim_func(private=True)
+    def tir_func(A: T.Buffer((128, 128), "float32")):
+        for i, j in T.grid(128, 128):
+            if 0 < i < 128 and 0 < j < 128:
+                A[i, j] = 1
+            else:
+                A[i, j] = 0
+
+    @T.prim_func(private=True)
+    def expected(A: T.Buffer((128, 128), "float32")):
+        for i, j in T.grid(128, 128):
+            if (0 < i and i < 128) and (0 < j and j < 128):
+                A[i, j] = 1
+            else:
+                A[i, j] = 0
+
+    tvm.ir.assert_structural_equal(tir_func, expected)
+
+
if __name__ == "__main__":
tvm.testing.main()