Lunderberg commented on code in PR #12863:
URL: https://github.com/apache/tvm/pull/12863#discussion_r989392512
##########
tests/python/unittest/test_tir_transform_simplify.py:
##########
@@ -547,5 +561,129 @@ def before(A: T.Buffer[16, "float32"]):
expected = before
+class TestRemoveTransitivelyProvableCondition(BaseBeforeAfter):
+ """Remove comparisons that may be proven using multiple others
+
+ For example, the `0 < i` and `i <= j` conditions can be used to prove
+ that `0 < j`.
+ """
+
+ i, j, k = [tvm.tir.Var(name, "int32") for name in "ijk"]
+ zero = tvm.tir.IntImm("int32", 0)
+
+ test_case = tvm.testing.parameter(
+ (tvm.tir.all(zero < i, i <= j), zero < j, True),
+ # Transitive comparisons from LT
+ (tvm.tir.all(i < j, j < k), i < k, True),
+ (tvm.tir.all(i < j, j == k), i < k, True),
+ (tvm.tir.all(i < j, j <= k), i < k, True),
+ (tvm.tir.all(i < j, j > k), i < k, False),
+ (tvm.tir.all(i < j, j >= k), i < k, False),
+ (tvm.tir.all(i < j, j != k), i < k, False),
+ # Transitive comparisons from LE
+ (tvm.tir.all(i <= j, j < k), i < k, True),
+ (tvm.tir.all(i <= j, j == k), i == k, False),
+ (tvm.tir.all(i <= j, j == k), i <= k, True),
+ (tvm.tir.all(i <= j, j <= k), i <= k, True),
+ (tvm.tir.all(i <= j, j <= k), i < k, False),
+ (tvm.tir.all(i <= j, j > k), i < k, False),
+ (tvm.tir.all(i <= j, j >= k), i < k, False),
+ (tvm.tir.all(i <= j, j != k), i < k, False),
+ # Transitive comparisons from GT
+ (tvm.tir.all(i > j, j > k), i > k, True),
+ (tvm.tir.all(i > j, j == k), i > k, True),
+ (tvm.tir.all(i > j, j >= k), i > k, True),
+ (tvm.tir.all(i > j, j < k), i > k, False),
+ (tvm.tir.all(i > j, j <= k), i > k, False),
+ (tvm.tir.all(i > j, j != k), i > k, False),
+ # Transitive comparisons from GE
+ (tvm.tir.all(i >= j, j > k), i > k, True),
+ (tvm.tir.all(i >= j, j == k), i == k, False),
+ (tvm.tir.all(i >= j, j == k), i >= k, True),
+ (tvm.tir.all(i >= j, j >= k), i >= k, True),
+ (tvm.tir.all(i >= j, j >= k), i > k, False),
+ (tvm.tir.all(i >= j, j < k), i > k, False),
+ (tvm.tir.all(i >= j, j <= k), i > k, False),
+ (tvm.tir.all(i >= j, j != k), i > k, False),
+ # GT or LT may be used to prove NE
+ (tvm.tir.all(i == j, j != k), i != k, True),
+ (tvm.tir.all(i == j, j < k), i != k, True),
+ (tvm.tir.all(i == j, j > k), i != k, True),
+ (tvm.tir.all(i == j, j != k), i < k, False),
+ (tvm.tir.all(i == j, j != k), i > k, False),
+ # Because these are integers, x<y is equivalent to x <= y-1,
+ # and may be used in equivalent simplifications.
+ (tvm.tir.all(i < j, j < k), i < k, True),
Review Comment:
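Thank you for the catch. Yes, this was intended to be an `i <= j - 1` test: it should exercise the integer equivalence of `i < j` and `i <= j - 1`, rather than duplicate the earlier `i < j, j < k` case.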
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]