MarisaKirisame commented on a change in pull request #5881:
URL: https://github.com/apache/incubator-tvm/pull/5881#discussion_r475539746
##########
File path: src/relay/analysis/type_solver.h
##########
@@ -65,6 +65,9 @@ class TypeSolver {
public:
TypeSolver(const GlobalVar& current_func, const IRModule& _mod,
ErrorReporter* err_reporter);
~TypeSolver();
+
+ void SetCurrentFunc(GlobalVar current_func) { this->current_func =
current_func; }
Review comment:
```suggestion
void SetCurrentFunc(const GlobalVar& current_func) { this->current_func =
current_func; }
```
##########
File path: src/relay/transforms/type_infer.cc
##########
@@ -86,6 +86,31 @@ struct ResolvedTypeInfo {
Array<Type> type_args = Array<Type>(ObjectPtr<Object>(nullptr));
};
+// helper class to dedup typevars of a type
+// - types do not have to be already typechecked
+//
+// This is used to Dedup GlobalVar type to avoid
+// incorrect type resolving across different usages
+class DeDupType : public TypeMutator, public ExprMutator, public
PatternMutator {
Review comment:
Please move this helper class to a common/shared file — deduplicating type variables looks generally useful, and keeping it in `type_infer.cc` prevents reuse by other passes.
##########
File path: tests/python/relay/test_type_infer.py
##########
@@ -362,6 +365,147 @@ def test_let_polymorphism():
int32 = relay.TensorType((), "int32")
tvm.ir.assert_structural_equal(body.checked_type, relay.TupleType([int32,
relay.TupleType([])]))
+def test_mutual_recursion():
+ # f(x) = if x > 0 then g(x - 1) else 0
+ # g(y) = if y > 0 then f(y - 1) else 0
+ tensortype = relay.TensorType((), 'float32')
+
+ x = relay.Var("x")
+ y = relay.Var("y")
+
+ zero = relay.Constant(tvm.nd.array(np.array(0, dtype='float32')))
+ one = relay.Constant(tvm.nd.array(np.array(1, dtype='float32')))
+
+ f_gv = relay.GlobalVar('f')
+ g_gv = relay.GlobalVar('g')
+
+ def body(var, call_func):
+ subtract_one = relay.op.subtract(var, one)
+ cond = relay.If(relay.op.greater(var, zero),
+ relay.Call(call_func, [subtract_one]),
+ zero)
+ func = relay.Function([var], cond)
+ return func
+
+ f = body(x, g_gv)
+ g = body(y, f_gv)
+
+ mod = tvm.IRModule()
+ # p = Prelude(mod)
+ mod.add_unchecked(f_gv, f)
+ mod.add_unchecked(g_gv, g)
+ mod = transform.InferTypeAll()(mod)
+
+ expected = relay.FuncType([tensortype], tensortype)
+ tvm.ir.assert_structural_equal(mod[f_gv].checked_type, expected)
+ tvm.ir.assert_structural_equal(mod[g_gv].checked_type, expected)
+
+def test_mutual_recursion_adt():
+ # f[A](x: A) = match x {
+ # Cons(a, Nil) => a
+ # Cons(_, b) => g(b)
+ # }
+ # g[B](y: B) = match y {
+ # Cons(a, Nil) => a
+ # Cons(_, b) => f(b)
+ # }
+ p = Prelude()
+ l = p.l
+
+ A = relay.TypeVar("A")
+ B = relay.TypeVar("B")
+
+ x = relay.Var("x")
+ y = relay.Var("y")
+
+ f_gv = relay.GlobalVar('f')
+ g_gv = relay.GlobalVar('g')
+
+ def body(var, call_func, type_param):
+ a = relay.Var("a", type_param)
+ b = relay.Var("b")
+ body = relay.Match(
+ var,
+ [
+ relay.Clause(relay.PatternConstructor(p.cons,
[relay.PatternVar(a), relay.PatternConstructor(p.nil)]), a),
+ relay.Clause(relay.PatternConstructor(p.cons,
[relay.PatternWildcard(), relay.PatternVar(b)]), relay.Call(call_func, [b]))
+ ],
+ complete=False
+ )
+ func = relay.Function([var], body, type_params=[type_param])
+ return func
+
+ f = body(x, g_gv, A)
+ g = body(y, f_gv, B)
+
+ mod = p.mod
+ mod.add_unchecked(f_gv, f)
+ mod.add_unchecked(g_gv, g)
+ mod = transform.InferTypeAll()(mod)
+
+ tv = relay.TypeVar("test")
+ expected = relay.FuncType([l(tv)], tv, [tv])
+ tvm.ir.assert_structural_equal(mod[f_gv].checked_type, expected)
+ tvm.ir.assert_structural_equal(mod[g_gv].checked_type, expected)
+
+def test_mutual_recursion_peano():
+ # even and odd function for peano function
+ # even(x) = match x {
+ # z => true
+ # s(a: nat) => odd(a)
+ # }
+ # odd(x) = match x {
+ # z => false
+ # s(a: nat) => even(a)
+ # }
+ p = Prelude()
+ add_nat_definitions(p)
+ z = p.z
Review comment:
Can you try writing these test programs with the Relay text parser instead of constructing the AST by hand? It would make the tests much shorter and easier to read.
##########
File path: src/relay/op/type_relations.cc
##########
@@ -126,6 +126,7 @@ bool BroadcastCompRel(const Array<Type>& types, int
num_inputs, const Attrs& att
return true;
}
}
+ reporter->Assign(types[0], types[1]);
Review comment:
Why is this extra `Assign` needed? `BroadcastCompRel` did not previously unify `types[0]` with `types[1]` here — please explain the motivation (and add a code comment if it is required for this change).
##########
File path: src/relay/transforms/type_infer.cc
##########
@@ -109,6 +134,44 @@ class TypeInferencer : private ExprFunctor<Type(const
Expr&)>,
// inference the type of expr.
Expr Infer(Expr expr);
+ void SetCurrentFunc(GlobalVar current_func) {
+ this->current_func_ = current_func;
+ this->solver_.SetCurrentFunc(current_func);
+ }
+
+ void Solve();
+ Expr ResolveType(Expr expr);
+
+ // Lazily get type for expr
+ // expression, we will populate it now, and return the result.
+ Type GetType(const Expr& expr) {
+ auto it = type_map_.find(expr);
+ if (it != type_map_.end() && it->second.checked_type.defined()) {
+ if (expr.as<GlobalVarNode>() != nullptr) {
+ // if we don't dedup GlobalVarNode, two functions that use the same
GlobalVar
+ // may resolve to the same type incorrectly
+ return DeDupType().VisitType(it->second.checked_type);
Review comment:
What happens if you always dedup the cached type here, rather than only for `GlobalVarNode`? If that is safe, it would simplify the logic; if not, a comment explaining why the dedup is GlobalVar-specific would help.
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]