This is an automated email from the ASF dual-hosted git repository.
lunderberg pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 4c45b828be [Relax] Unit-test for structural equal of recursive function (#16796)
4c45b828be is described below
commit 4c45b828be94d7e13fb6f8f87cbdacb4c462bb93
Author: Eric Lunderberg <[email protected]>
AuthorDate: Thu Mar 28 07:33:19 2024 -0500
[Relax] Unit-test for structural equal of recursive function (#16796)
A follow-up PR to https://github.com/apache/tvm/pull/16756, adding an
explicit unit test for `tvm.ir.assert_structural_equal` of two
distinct recursive functions.
---
tests/python/relax/test_utils.py | 65 ++++++++++++++++++++++++++++++++++++++++
1 file changed, 65 insertions(+)
diff --git a/tests/python/relax/test_utils.py b/tests/python/relax/test_utils.py
index 9abc53484b..41b0e714d1 100644
--- a/tests/python/relax/test_utils.py
+++ b/tests/python/relax/test_utils.py
@@ -206,5 +206,70 @@ def test_structural_equal_with_recursive_lambda_function():
tvm.ir.assert_structural_equal(func_1, func_2)
+def test_structural_equal_with_distinct_recursive_lambda_function():
+    """Distinct recursive lambda functions are reported as unequal
+
+    Like `test_structural_equal_with_recursive_lambda_function`, but the
+    two functions differ, so the error must name the first mismatch.
+    """
+
+    @R.function(private=True)
+    def func_a(n: R.Prim("int64")):
+        @R.function
+        def recursive_lambda(i_arg: R.Prim(value="i")) -> R.Prim("int64"):
+            i = T.int64()
+            if R.prim_value(i == 0):
+                output = R.prim_value(T.int64(0))
+                #                             ^
+                # First mismatch: func_b uses T.int64(1) here.
+            else:
+                remainder_relax = recursive_lambda(R.prim_value(i - 1))
+                remainder_tir = T.int64()
+                _ = R.match_cast(remainder_relax, R.Prim(value=remainder_tir))
+                output = R.prim_value(i + remainder_tir)
+            return output
+
+        return recursive_lambda(n)
+
+    @R.function(private=True)
+    def func_b(n: R.Prim("int64")):
+        @R.function
+        def recursive_lambda(i_arg: R.Prim(value="i")) -> R.Prim("int64"):
+            i = T.int64()
+            if R.prim_value(i == 0):
+                output = R.prim_value(T.int64(1))
+                #                             ^
+                # First mismatch: func_a uses T.int64(0) here.
+            else:
+                remainder_relax = recursive_lambda(R.prim_value(i - 1))
+                remainder_tir = T.int64()
+                _ = R.match_cast(remainder_relax, R.Prim(value=remainder_tir))
+                output = R.prim_value(i * remainder_tir)
+            return output
+
+        return recursive_lambda(n)
+
+    # Attribute path to the first mismatch (the true-branch literal),
+    # which must appear in the ValueError message.
+    mismatch_path = [
+        "<root>",
+        "body",
+        "blocks[0]",
+        "bindings[0]",
+        "value",
+        "body",
+        "blocks[0]",
+        "bindings[0]",
+        "value",
+        "true_branch",
+        "body",
+        "value",
+        "value",
+    ]
+
+    with pytest.raises(ValueError, match=re.escape(".".join(mismatch_path))):
+        tvm.ir.assert_structural_equal(func_a, func_b)
+
+
 if __name__ == "__main__":
     pytest.main([__file__])