hypercubestart commented on a change in pull request #5881:
URL: https://github.com/apache/incubator-tvm/pull/5881#discussion_r475823330
##########
File path: tests/python/relay/test_type_infer.py
##########
@@ -362,6 +365,147 @@ def test_let_polymorphism():
int32 = relay.TensorType((), "int32")
tvm.ir.assert_structural_equal(body.checked_type, relay.TupleType([int32,
relay.TupleType([])]))
+def test_mutual_recursion():
+ # f(x) = if x > 0 then g(x - 1) else 0
+ # g(y) = if y > 0 then f(y - 1) else 0
+ tensortype = relay.TensorType((), 'float32')
+
+ x = relay.Var("x")
+ y = relay.Var("y")
+
+ zero = relay.Constant(tvm.nd.array(np.array(0, dtype='float32')))
+ one = relay.Constant(tvm.nd.array(np.array(1, dtype='float32')))
+
+ f_gv = relay.GlobalVar('f')
+ g_gv = relay.GlobalVar('g')
+
+ def body(var, call_func):
+ subtract_one = relay.op.subtract(var, one)
+ cond = relay.If(relay.op.greater(var, zero),
+ relay.Call(call_func, [subtract_one]),
+ zero)
+ func = relay.Function([var], cond)
+ return func
+
+ f = body(x, g_gv)
+ g = body(y, f_gv)
+
+ mod = tvm.IRModule()
+ # p = Prelude(mod)
+ mod.add_unchecked(f_gv, f)
+ mod.add_unchecked(g_gv, g)
+ mod = transform.InferTypeAll()(mod)
+
+ expected = relay.FuncType([tensortype], tensortype)
+ tvm.ir.assert_structural_equal(mod[f_gv].checked_type, expected)
+ tvm.ir.assert_structural_equal(mod[g_gv].checked_type, expected)
+
+def test_mutual_recursion_adt():
+ # f[A](x: A) = match x {
+ # Cons(a, Nil) => a
+ # Cons(_, b) => g(b)
+ # }
+ # g[B](y: B) = match y {
+ # Cons(a, Nil) => a
+ # Cons(_, b) => f(b)
+ # }
+ p = Prelude()
+ l = p.l
+
+ A = relay.TypeVar("A")
+ B = relay.TypeVar("B")
+
+ x = relay.Var("x")
+ y = relay.Var("y")
+
+ f_gv = relay.GlobalVar('f')
+ g_gv = relay.GlobalVar('g')
+
+ def body(var, call_func, type_param):
+ a = relay.Var("a", type_param)
+ b = relay.Var("b")
+ body = relay.Match(
+ var,
+ [
+ relay.Clause(relay.PatternConstructor(p.cons,
[relay.PatternVar(a), relay.PatternConstructor(p.nil)]), a),
+ relay.Clause(relay.PatternConstructor(p.cons,
[relay.PatternWildcard(), relay.PatternVar(b)]), relay.Call(call_func, [b]))
+ ],
+ complete=False
+ )
+ func = relay.Function([var], body, type_params=[type_param])
+ return func
+
+ f = body(x, g_gv, A)
+ g = body(y, f_gv, B)
+
+ mod = p.mod
+ mod.add_unchecked(f_gv, f)
+ mod.add_unchecked(g_gv, g)
+ mod = transform.InferTypeAll()(mod)
+
+ tv = relay.TypeVar("test")
+ expected = relay.FuncType([l(tv)], tv, [tv])
+ tvm.ir.assert_structural_equal(mod[f_gv].checked_type, expected)
+ tvm.ir.assert_structural_equal(mod[g_gv].checked_type, expected)
+
+def test_mutual_recursion_peano():
+ # even and odd function for peano function
+ # even(x) = match x {
+ # z => true
+ # s(a: nat) => odd(a)
+ # }
+ # odd(x) = match x {
+ # z => false
+ # s(a: nat) => even(a)
+ # }
+ p = Prelude()
+ add_nat_definitions(p)
+ z = p.z
Review comment:
The parser fails to handle a module with mutually recursive functions; supporting
that would require changes to the parser, which feels out of scope for this
PR.
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]