hypercubestart commented on a change in pull request #5881:
URL: https://github.com/apache/incubator-tvm/pull/5881#discussion_r475823330
##########
File path: tests/python/relay/test_type_infer.py
##########
@@ -362,6 +365,147 @@ def test_let_polymorphism():
int32 = relay.TensorType((), "int32")
tvm.ir.assert_structural_equal(body.checked_type, relay.TupleType([int32,
relay.TupleType([])]))
+def test_mutual_recursion():
+ # f(x) = if x > 0 then g(x - 1) else 0
+ # g(y) = if y > 0 then f(y - 1) else 0
+ tensortype = relay.TensorType((), 'float32')
+
+ x = relay.Var("x")
+ y = relay.Var("y")
+
+ zero = relay.Constant(tvm.nd.array(np.array(0, dtype='float32')))
+ one = relay.Constant(tvm.nd.array(np.array(1, dtype='float32')))
+
+ f_gv = relay.GlobalVar('f')
+ g_gv = relay.GlobalVar('g')
+
+ def body(var, call_func):
+ subtract_one = relay.op.subtract(var, one)
+ cond = relay.If(relay.op.greater(var, zero),
+ relay.Call(call_func, [subtract_one]),
+ zero)
+ func = relay.Function([var], cond)
+ return func
+
+ f = body(x, g_gv)
+ g = body(y, f_gv)
+
+ mod = tvm.IRModule()
+ # p = Prelude(mod)
+ mod.add_unchecked(f_gv, f)
+ mod.add_unchecked(g_gv, g)
+ mod = transform.InferTypeAll()(mod)
+
+ expected = relay.FuncType([tensortype], tensortype)
+ tvm.ir.assert_structural_equal(mod[f_gv].checked_type, expected)
+ tvm.ir.assert_structural_equal(mod[g_gv].checked_type, expected)
+
+def test_mutual_recursion_adt():
+ # f[A](x: A) = match x {
+ # Cons(a, Nil) => a
+ # Cons(_, b) => g(b)
+ # }
+ # g[B](y: B) = match y {
+ # Cons(a, Nil) => a
+ # Cons(_, b) => f(b)
+ # }
+ p = Prelude()
+ l = p.l
+
+ A = relay.TypeVar("A")
+ B = relay.TypeVar("B")
+
+ x = relay.Var("x")
+ y = relay.Var("y")
+
+ f_gv = relay.GlobalVar('f')
+ g_gv = relay.GlobalVar('g')
+
+ def body(var, call_func, type_param):
+ a = relay.Var("a", type_param)
+ b = relay.Var("b")
+ body = relay.Match(
+ var,
+ [
+ relay.Clause(relay.PatternConstructor(p.cons, [relay.PatternVar(a), relay.PatternConstructor(p.nil)]), a),
+ relay.Clause(relay.PatternConstructor(p.cons, [relay.PatternWildcard(), relay.PatternVar(b)]), relay.Call(call_func, [b]))
+ ],
+ complete=False
+ )
+ func = relay.Function([var], body, type_params=[type_param])
+ return func
+
+ f = body(x, g_gv, A)
+ g = body(y, f_gv, B)
+
+ mod = p.mod
+ mod.add_unchecked(f_gv, f)
+ mod.add_unchecked(g_gv, g)
+ mod = transform.InferTypeAll()(mod)
+
+ tv = relay.TypeVar("test")
+ expected = relay.FuncType([l(tv)], tv, [tv])
+ tvm.ir.assert_structural_equal(mod[f_gv].checked_type, expected)
+ tvm.ir.assert_structural_equal(mod[g_gv].checked_type, expected)
+
+def test_mutual_recursion_peano():
+ # even and odd function for peano function
+ # even(x) = match x {
+ # z => true
+ # s(a: nat) => odd(a)
+ # }
+ # odd(x) = match x {
+ # z => false
+ # s(a: nat) => even(a)
+ # }
+ p = Prelude()
+ add_nat_definitions(p)
+ z = p.z
Review comment:
The parser isn't set up to handle mutually recursive functions; supporting them
would require changes to the parser, which feels out of scope for this PR.
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
users@infra.apache.org