This is an automated email from the ASF dual-hosted git repository.

lunderberg pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 4c45b828be [Relax] Unit-test for structural equal of recursive function (#16796)
4c45b828be is described below

commit 4c45b828be94d7e13fb6f8f87cbdacb4c462bb93
Author: Eric Lunderberg <[email protected]>
AuthorDate: Thu Mar 28 07:33:19 2024 -0500

    [Relax] Unit-test for structural equal of recursive function (#16796)
    
    A follow-up PR to https://github.com/apache/tvm/pull/16756, adding an
    explicit unit test for `tvm.ir.assert_structural_equal` of two
    distinct recursive functions.
---
 tests/python/relax/test_utils.py | 65 ++++++++++++++++++++++++++++++++++++++++
 1 file changed, 65 insertions(+)

diff --git a/tests/python/relax/test_utils.py b/tests/python/relax/test_utils.py
index 9abc53484b..41b0e714d1 100644
--- a/tests/python/relax/test_utils.py
+++ b/tests/python/relax/test_utils.py
@@ -206,5 +206,70 @@ def test_structural_equal_with_recursive_lambda_function():
     tvm.ir.assert_structural_equal(func_1, func_2)
 
 
+def test_structural_equal_with_distinct_recursive_lambda_function():
+    """Distinct recursive lambdas must fail the structural-equality check
+
+    Like `test_structural_equal_with_recursive_lambda_function`, but the two
+    functions differ, and the error must name the path to the first mismatch.
+    """
+
+    @R.function(private=True)
+    def func_a(n: R.Prim("int64")):
+        @R.function
+        def recursive_lambda(i_arg: R.Prim(value="i")) -> R.Prim("int64"):
+            i = T.int64()
+            if R.prim_value(i == 0):
+                output = R.prim_value(T.int64(0))
+                #                             ^
+                # The first mismatch is here  ^
+            else:
+                remainder_relax = recursive_lambda(R.prim_value(i - 1))
+                remainder_tir = T.int64()
+                _ = R.match_cast(remainder_relax, R.Prim(value=remainder_tir))
+                output = R.prim_value(i + remainder_tir)  # func_b uses `*` here
+            return output
+
+        return recursive_lambda(n)
+
+    @R.function(private=True)
+    def func_b(n: R.Prim("int64")):
+        @R.function
+        def recursive_lambda(i_arg: R.Prim(value="i")) -> R.Prim("int64"):
+            i = T.int64()
+            if R.prim_value(i == 0):
+                output = R.prim_value(T.int64(1))
+                #                             ^
+                # The first mismatch is here  ^
+            else:
+                remainder_relax = recursive_lambda(R.prim_value(i - 1))
+                remainder_tir = T.int64()
+                _ = R.match_cast(remainder_relax, R.Prim(value=remainder_tir))
+                output = R.prim_value(i * remainder_tir)  # func_a uses `+` here
+            return output
+
+        return recursive_lambda(n)
+
+    # Path from the root to the first mismatch (the base-case constant).  Its
+    # dot-joined form is regex-escaped below and must appear in the error text.
+    mismatch_path = [
+        "<root>",
+        "body",
+        "blocks[0]",
+        "bindings[0]",
+        "value",
+        "body",
+        "blocks[0]",
+        "bindings[0]",
+        "value",
+        "true_branch",
+        "body",
+        "value",
+        "value",
+    ]
+
+    with pytest.raises(ValueError, match=re.escape(".".join(mismatch_path))):
+        tvm.ir.assert_structural_equal(func_a, func_b)
+
+
 if __name__ == "__main__":
     pytest.main([__file__])

Reply via email to