This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 987ac35b66 [Unity] Relax Recursive function (#14092)
987ac35b66 is described below

commit 987ac35b66e685dba5f2f38bdfa633fbd9ee8c79
Author: Yong Wu <[email protected]>
AuthorDate: Wed Feb 22 16:00:06 2023 -0800

    [Unity] Relax Recursive function (#14092)
    
    This PR adds support for local recursive functions in TVMScript and
    updates the lambda lifting pass accordingly. It removes CalledGlobalVars,
    which was no longer used, and updates the well-formed pass to allow
    not-yet-defined vars in recursive calls.
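    
    For illustration, the kind of local recursive function this change lets
    the parser accept might look like the sketch below. It is adapted from the
    updated test_transform_lambda_lift.py in this diff; the "test.vm.less"
    packed function and the exact annotations are illustrative assumptions
    rather than part of the patch.
    
        import tvm
        from tvm.script import relax as R
    
        @tvm.script.ir_module
        class Before:
            @R.function
            def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"):
                # The parser now reserves a Var for the local function before
                # visiting its body, so the body can call it by name.
                @R.function
                def while_loop(
                    i: R.Tensor((), "int32"), s: R.Tensor((2, 3), "float32")
                ) -> R.Tensor((2, 3), "float32"):
                    cond: R.Tensor((), "bool") = R.call_packed(
                        "test.vm.less", i, R.const(10), sinfo_args=R.Tensor((), "bool")
                    )
                    c: R.Tensor((), "int32") = R.const(1, dtype="int32")
                    if cond:
                        new_i: R.Tensor((), "int32") = R.add(i, c)
                        new_s: R.Tensor((2, 3), "float32") = R.add(s, x)
                        # Recursive call to the enclosing local function.
                        r: R.Tensor((2, 3), "float32") = while_loop(new_i, new_s)
                    else:
                        r: R.Tensor((2, 3), "float32") = s
                    return r
    
                gv: R.Tensor((2, 3), "float32") = while_loop(R.const(0), x)
                return gv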
---
 include/tvm/relax/analysis.h                     |  9 ---
 include/tvm/script/ir_builder/relax/ir.h         |  7 ++
 python/tvm/script/ir_builder/relax/ir.py         | 17 ++++-
 python/tvm/script/parser/relax/parser.py         | 62 +++++++++++++++--
 src/relax/analysis/analysis.cc                   | 20 ------
 src/relax/analysis/well_formed.cc                | 11 +++-
 src/relax/transform/lambda_lift.cc               | 84 ++++++++++++++++++------
 src/script/ir_builder/relax/ir.cc                |  9 +++
 tests/python/relax/test_analysis_well_formed.py  | 27 ++++++++
 tests/python/relax/test_transform_lambda_lift.py | 34 +++++-----
 tests/python/relax/test_utils.py                 |  6 +-
 11 files changed, 213 insertions(+), 73 deletions(-)

diff --git a/include/tvm/relax/analysis.h b/include/tvm/relax/analysis.h
index b9866577e9..39ecfd9e13 100644
--- a/include/tvm/relax/analysis.h
+++ b/include/tvm/relax/analysis.h
@@ -296,15 +296,6 @@ TVM_DLL tvm::Array<Var> FreeVars(const Expr& expr);
  */
 TVM_DLL tvm::Array<Var> AllVars(const Expr& expr);
 
-/*!
- * \brief Get all global variables used in calls in expression expr.
- *
- * \param expr the expression.
- *
- * \return List of all global variables called in expr.
- */
-TVM_DLL tvm::Array<GlobalVar> CalledGlobalVars(const Expr& expr);
-
 /*!
  * \brief Get all global variables from expression expr.
  *
diff --git a/include/tvm/script/ir_builder/relax/ir.h b/include/tvm/script/ir_builder/relax/ir.h
index 72aab6684e..42aa591a95 100644
--- a/include/tvm/script/ir_builder/relax/ir.h
+++ b/include/tvm/script/ir_builder/relax/ir.h
@@ -110,6 +110,13 @@ TVM_DLL tvm::relax::Var Emit(
 TVM_DLL tvm::relax::Var EmitMatchCast(const tvm::relax::Expr& value,
                                       const tvm::relax::StructInfo& struct_info);
 
+/*!
+ * \brief Emit a binding to the last binding block frame.
+ * \param binding The binding to be emitted.
+ * \return The left side var of the emitted binding.
+ */
+TVM_DLL tvm::relax::Var EmitVarBinding(const tvm::relax::VarBinding& binding);
+
 ///////////////////////////// If Then Else /////////////////////////////
 
 /*!
diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py
index 43918ce7ec..63efea135c 100644
--- a/python/tvm/script/ir_builder/relax/ir.py
+++ b/python/tvm/script/ir_builder/relax/ir.py
@@ -25,7 +25,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
 import tvm
 from tvm import DataType, relax
 from tvm.ir import PrimExpr
-from tvm.relax import Call, Expr, ExternFunc, TupleGetItem, Var, const
+from tvm.relax import Call, Expr, ExternFunc, TupleGetItem, Var, VarBinding, const
 
 ############################### Operators ###############################
 from tvm.relax.op import (
@@ -342,6 +342,20 @@ def emit_match_cast(value: Expr, struct_info: StructInfo) -> Var:
     return _ffi_api.EmitMatchCast(value, struct_info)  # type: ignore
 
 
+def emit_var_binding(value: VarBinding) -> Var:
+    """Emit a binding to the last binding block frame.
+    Parameters
+    ----------
+    value: VarBinding
+        The binding to be emitted.
+    Returns
+    -------
+    var: Var
+        The left side var of the emitted binding.
+    """
+    return _ffi_api.EmitVarBinding(value)  # type: ignore
+
+
 ############################# If Then Else #############################
 
 
@@ -497,6 +511,7 @@ __all__ = [
     "divide",
     "dtype",
     "emit",
+    "emit_var_binding",
     "emit_match_cast",
     "equal",
     "ewise_fma",
diff --git a/python/tvm/script/parser/relax/parser.py b/python/tvm/script/parser/relax/parser.py
index e5e5bb2743..e1af1c1df3 100644
--- a/python/tvm/script/parser/relax/parser.py
+++ b/python/tvm/script/parser/relax/parser.py
@@ -96,8 +96,7 @@ def eval_struct_info_proxy(self: Parser, node: doc.expr) -> StructInfoProxy:
             annotation = annotation()
         if isinstance(annotation, StructInfoProxy):
             return annotation
-        else:
-            raise TypeError(f"Expected StructInfoProxy but got {type(annotation)}.")
+        raise TypeError(f"Expected StructInfoProxy but got {type(annotation)}.")
     except Exception as err:
         self.report_error(node, str(err))
         raise err
@@ -112,6 +111,38 @@ def eval_struct_info(self: Parser, node: doc.expr, eval_str: bool = False) -> St
         raise err
 
 
+def is_called(node: Any, func_name: str) -> bool:
+    # Check if it calls into a func
+    if isinstance(node, doc.Call):
+        # Recursive call was found
+        if isinstance(node.func, doc.Name) and node.func.id == func_name:
+            return True
+    elif isinstance(node, (list, tuple)):
+        for stmt in node:
+            if is_called(stmt, func_name):
+                return True
+    elif isinstance(node, (doc.AnnAssign, doc.Assign, doc.Return, doc.Expr)):
+        return is_called(node.value, func_name)
+    elif isinstance(node, doc.With):
+        return is_called(node.body, func_name)
+    elif isinstance(node, doc.If):
+        smts = []
+        if node.body is not None:
+            smts = smts + list(node.body)
+        if node.orelse is not None:
+            smts = smts + list(node.orelse)
+        return is_called(smts, func_name)
+    return False
+
+
+def is_recursive(node: doc.FunctionDef) -> bool:
+    # Check if it is a recursive function
+    for stmt in node.body:
+        if is_called(stmt, node.name):
+            return True
+    return False
+
+
 def collect_symbolic_var_from_params(self: Parser, node: doc.FunctionDef) -> None:
     # Collect symbolic vars from parameters
     symbolic_vars = set()
@@ -128,6 +159,24 @@ def collect_symbolic_var_from_params(self: Parser, node: doc.FunctionDef) -> Non
 
 @dispatch.register(token="relax", type_name="FunctionDef")
 def visit_function_def(self: Parser, node: doc.FunctionDef) -> None:
+    # reserve a var for local function
+    func_val = self.var_table.get().get(node.name)
+    if not func_val and is_recursive(node):
+        collect_symbolic_var_from_params(self, node)
+        if node.returns is None:
+            ret_sinfo = relax.TupleStructInfo([])
+        else:
+            ret_sinfo = eval_struct_info(self, node.returns, eval_str=True)
+        params_sinfo = []
+        for arg in node.args.args:
+            if arg.annotation is None:
+                self.report_error(arg, "Type annotation is required for function parameters.")
+            param_sinfo = eval_struct_info(self, arg.annotation, eval_str=True)
+            params_sinfo.append(param_sinfo)
+        # Create a var for the local function; the same var can be used for the recursive call
+        local_func_var = relax.Var(node.name, relax.FuncStructInfo(params_sinfo, ret_sinfo))
+        self.var_table.add(node.name, local_func_var)
+
     with self.var_table.with_frame():
         with self.with_dispatch_token("relax"):
             with R.function():
@@ -164,12 +213,10 @@ def visit_tvm_declare_function(self: Parser, node: doc.FunctionDef) -> None:
         else:
             ret_sinfo = eval_struct_info(self, node.returns, eval_str=True)
         params = []
-        params_sinfo = []
         for arg in node.args.args:
             if arg.annotation is None:
                 self.report_error(arg, "Type annotation is required for function parameters.")
             param_sinfo = eval_struct_info(self, arg.annotation, eval_str=True)
-            params_sinfo.append(param_sinfo)
             params.append(relax.Var(arg.arg, param_sinfo))
 
     func_signature = relax.Function.create_empty(params, ret_sinfo)
@@ -188,7 +235,12 @@ def post_token_switch(self: Parser, node: doc.Expr) -> None:
     ir_builder = IRBuilder.current()
     result = ir_builder.get()
     ir_builder.__exit__(None, None, None)
-    var = R.emit(result)
+    # reuse var if it is reserved
+    reserved_var = self.var_table.get().get(node.name)
+    if reserved_var:
+        var = R.emit_var_binding(relax.VarBinding(reserved_var, result))
+    else:
+        var = R.emit(result)
     IRBuilder.name(node.name, var)
     self.var_table.add(node.name, var, allow_shadowing=False)
 
diff --git a/src/relax/analysis/analysis.cc b/src/relax/analysis/analysis.cc
index 33197308fa..4132039a5e 100644
--- a/src/relax/analysis/analysis.cc
+++ b/src/relax/analysis/analysis.cc
@@ -87,15 +87,6 @@ class VarVisitor : protected ExprVisitor {
     return ret;
   }
 
-  Array<GlobalVar> CalledGlobalVars(const Expr& expr) {
-    this->VisitExpr(expr);
-    Array<GlobalVar> ret;
-    for (const auto& v : called_global_vars_.data) {
-      ret.push_back(v);
-    }
-    return ret;
-  }
-
   void MarkBounded(const Var& v) {
     bound_vars_.Insert(v);
     vars_.Insert(v);
@@ -123,10 +114,6 @@ class VarVisitor : protected ExprVisitor {
     for (Expr arg : call_node->args) {
       VisitExpr(arg);
     }
-
-    if (const GlobalVarNode* global_var_node = call_node->op.as<GlobalVarNode>()) {
-      called_global_vars_.Insert(GetRef<GlobalVar>(global_var_node));
-    }
   }
 
   void VisitBinding_(const VarBindingNode* binding) final {
@@ -144,7 +131,6 @@ class VarVisitor : protected ExprVisitor {
   InsertionSet<Var> vars_;
   InsertionSet<Var> bound_vars_;
   InsertionSet<GlobalVar> global_vars_;
-  InsertionSet<GlobalVar> called_global_vars_;
 };
 
 tvm::Array<Var> FreeVars(const Expr& expr) { return VarVisitor().Free(expr); }
@@ -155,10 +141,6 @@ tvm::Array<Var> AllVars(const Expr& expr) { return VarVisitor().All(expr); }
 
 tvm::Array<GlobalVar> AllGlobalVars(const Expr& expr) { return VarVisitor().AllGlobalVars(expr); }
 
-tvm::Array<GlobalVar> CalledGlobalVars(const Expr& expr) {
-  return VarVisitor().CalledGlobalVars(expr);
-}
-
 TVM_REGISTER_GLOBAL("relax.analysis.free_vars").set_body_typed(FreeVars);
 
 TVM_REGISTER_GLOBAL("relax.analysis.bound_vars").set_body_typed(BoundVars);
@@ -167,7 +149,5 @@ TVM_REGISTER_GLOBAL("relax.analysis.all_vars").set_body_typed(AllVars);
 
 TVM_REGISTER_GLOBAL("relax.analysis.all_global_vars").set_body_typed(AllGlobalVars);
 
-TVM_REGISTER_GLOBAL("relax.analysis.called_global_vars").set_body_typed(CalledGlobalVars);
-
 }  // namespace relax
 }  // namespace tvm
diff --git a/src/relax/analysis/well_formed.cc b/src/relax/analysis/well_formed.cc
index 05ad0954bb..25b9155d77 100644
--- a/src/relax/analysis/well_formed.cc
+++ b/src/relax/analysis/well_formed.cc
@@ -177,7 +177,7 @@ class WellFormedChecker : public relax::ExprVisitor,
 
   void VisitExpr_(const VarNode* op) final {
     Var var = GetRef<Var>(op);
-    if (var_set_.count(var) == 0) {
+    if (var_set_.count(var) == 0 && recur_vars_.count(var) == 0) {
       Malformed(Diagnostic::Error(var) << "Var " << op->name_hint() << " is not defined.");
     }
     CheckStructInfo(op);
@@ -316,12 +316,20 @@ class WellFormedChecker : public relax::ExprVisitor,
   }
 
   void VisitBinding_(const VarBindingNode* binding) final {
+    bool is_lambda = false;
+    if (binding->value->IsInstance<FunctionNode>()) {
+      is_lambda = true;
+      recur_vars_.insert(binding->var);
+    }
     if (binding->value->IsInstance<tir::PrimFuncNode>()) {
       Malformed(Diagnostic::Error(binding->value) << "Inline PrimFunc is disallowed in Relax IR.");
     } else {
       this->VisitExpr(binding->value);
     }
     this->VisitVarDef(binding->var);
+    if (is_lambda) {
+      recur_vars_.erase(binding->var);
+    }
   }
 
   void VisitBinding_(const MatchCastNode* binding) final {
@@ -451,6 +459,7 @@ class WellFormedChecker : public relax::ExprVisitor,
   VisitMode mode_ = VisitMode::kDefault;
   // set of context variables.
   std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> var_set_;
+  std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> recur_vars_;
   std::unordered_set<DataflowVar, ObjectPtrHash, ObjectPtrEqual> dataflow_var_set_;
   std::unordered_set<tir::Var, ObjectPtrHash, ObjectPtrEqual> symbolic_var_set_;
   std::unordered_map<Var, Function, ObjectPtrHash, ObjectPtrEqual> param_var_func_map_;
diff --git a/src/relax/transform/lambda_lift.cc b/src/relax/transform/lambda_lift.cc
index f08499036b..7492082310 100644
--- a/src/relax/transform/lambda_lift.cc
+++ b/src/relax/transform/lambda_lift.cc
@@ -46,35 +46,72 @@ class LambdaLifter : public ExprMutator {
 
   using ExprMutator::VisitExpr_;
 
+  void VisitBinding_(const VarBindingNode* binding) final {
+    bool is_lambda = false;
+    if (binding->value->IsInstance<FunctionNode>()) {
+      is_lambda = true;
+      recur_vars_.push_back(binding->var);
+    }
+    Expr new_value = this->VisitExpr(binding->value);
+    if (new_value->struct_info_.defined() &&
+        !new_value->struct_info_.same_as(binding->var->struct_info_)) {
+      binding->var->struct_info_ = GetStructInfo(new_value);
+      binding->var->checked_type_ = new_value->checked_type_;
+    }
+    if (new_value.same_as(binding->value)) {
+      builder_->EmitNormalized(GetRef<VarBinding>(binding));
+    } else {
+      builder_->EmitNormalized(VarBinding(binding->var, new_value));
+    }
+    if (is_lambda) {
+      recur_vars_.pop_back();
+    }
+  }
+
   Expr VisitExpr_(const CallNode* call_node) final {
     auto call = Downcast<Call>(ExprMutator::VisitExpr_(call_node));
-    if (auto const* var = call_node->op.as<VarNode>()) {
-      bool has_closure = HasClosure(GetRef<Var>(var));
-      auto val = builder_->LookupBinding(GetRef<Var>(var));
+    if (const auto* var_node = call_node->op.as<VarNode>()) {
+      auto var = GetRef<Var>(var_node);
+      bool has_closure = HasClosure(var);
+      auto val = builder_->LookupBinding(var);
+      if (const auto* fsinfo_node = GetStructInfo(var).as<FuncStructInfoNode>()) {
+        auto fsinfo = GetRef<FuncStructInfo>(fsinfo_node);
+        if (!GetStructInfo(call).same_as(fsinfo)) {
+          call->struct_info_ = fsinfo->ret;
+          call->checked_type_ = GetStaticType(fsinfo->ret);
+        }
+      }
       // Call "relax.invoke_closure" to invoke closure
-      if (has_closure && val.as<CallNode>()) {
-        Var clo_arg = GetRef<Var>(var);
+      Var clo_arg = var;
+      if (has_closure && val->IsInstance<CallNode>()) {
         if (this->var_remap_.find(var->vid) != this->var_remap_.end()) {
           clo_arg = this->var_remap_.at(var->vid);
         }
         return Call(invoke_closure_op_, {clo_arg, Tuple(call_node->args)}, {},
                     {GetStructInfo(GetRef<Expr>(call_node))});
       }
-    }
-    if (auto global_var_node = call_node->op.as<GlobalVarNode>()) {
-      String rec_name = global_var_node->name_hint;
-      auto global_var = GetRef<GlobalVar>(global_var_node);
-      auto it = lambda_map_.find(global_var);
+      auto it = lambda_map_.find(var);
       if (it != lambda_map_.end()) {
         // flatten nested call, e.g. call(y)(x) -> call(x, y)
         Array<relay::Expr> new_args;
+        Array<StructInfo> params;
         for (const auto arg : call->args) {
           new_args.push_back(arg);
+          params.push_back(StructInfoFromType(arg->checked_type()));
         }
         if (const auto* nest_call = it->second.as<CallNode>()) {
+          // Update the StructInfo accordingly
           for (const auto arg : nest_call->args) {
             new_args.push_back(arg);
+            params.push_back(StructInfoFromType(arg->checked_type()));
           }
+          StructInfo new_func_sinfo;
+          if (const auto* fsinfo = GetStructInfo(nest_call->op).as<FuncStructInfoNode>()) {
+            auto func_sinfo = GetRef<FuncStructInfo>(fsinfo);
+            new_func_sinfo = FuncStructInfo(params, func_sinfo->ret);
+          }
+          nest_call->op->struct_info_ = new_func_sinfo;
+          nest_call->op->checked_type_ = GetStaticType(new_func_sinfo);
           return Call(nest_call->op, new_args, call_node->attrs, call_node->sinfo_args);
         }
         return Call(it->second, call->args, call_node->attrs, call_node->sinfo_args);
@@ -89,11 +126,19 @@ class LambdaLifter : public ExprMutator {
     // TODO(@yongwww): consider appending inner func name into the lifted func name
     String lift_func_name = "lifted_func_" + std::to_string(lift_func_num_++);
     auto global = GlobalVar(lift_func_name);
-    Array<Var> captured_vars = FreeVars(func);
-    recur_vars_ = CalledGlobalVars(func);
-    auto all_global_vars = AllGlobalVars(func);
+    Array<Var> free_vars = FreeVars(func);
+    Array<Var> captured_vars;
 
     Array<Var> typed_captured_vars;
+    bool recursive = false;
+    for (const auto& var : free_vars) {
+      if (!recur_vars_.empty() && var == recur_vars_.back()) {
+        recursive = true;
+      } else {
+        captured_vars.push_back(var);
+      }
+    }
+
     Map<Var, Expr> rebinding_map;
     for (auto free_var : captured_vars) {
       Var var = Var(free_var->name_hint(), GetStructInfo(free_var), free_var->span);
@@ -102,12 +147,14 @@ class LambdaLifter : public ExprMutator {
     }
 
     // recursive call
-    if (!recur_vars_.empty()) {
+    if (recursive) {
       if (!captured_vars.empty()) {
         Array<Expr> fvs;
         for (auto fv : captured_vars) {
           fvs.push_back(fv);
         }
+        // it is required by the block builder, will be updated later
+        UpdateStructInfo(global, GetStructInfo(recur_vars_.back()));
         lambda_map_.emplace(recur_vars_.back(), Call(global, fvs));
       } else {
         if (recur_vars_.size() > 0) {
@@ -162,18 +209,17 @@ class LambdaLifter : public ExprMutator {
                              /*attrs=*/new_func->attrs,
                              /*span=*/func->span);
 
-      Array<Type> param_types;
       for (Var param : closure_params) {
         CHECK(param->checked_type_.defined())
             << "relax.Function requires params to contain checked_type_";
-        param_types.push_back(param->checked_type_);
       }
     }
 
     ICHECK(lifted_func.defined());
 
     // Add the lifted function to the module.
-    UpdateStructInfo(global, GetStructInfo(lifted_func));
+    global->struct_info_ = GetStructInfo(lifted_func);
+    global->checked_type_ = lifted_func->checked_type_;
     builder_->UpdateFunction(global, lifted_func);
 
     if (!is_closure) {
@@ -242,8 +288,8 @@ class LambdaLifter : public ExprMutator {
   }
 
  private:
-  std::unordered_map<GlobalVar, Expr, ObjectPtrHash, ObjectPtrEqual> lambda_map_;
-  Array<GlobalVar> recur_vars_;
+  std::unordered_map<Var, Expr, ObjectPtrHash, ObjectPtrEqual> lambda_map_;
+  Array<Var> recur_vars_;
   IRModule mod_;
   size_t lift_func_num_ = 0;
   /*! \brief Cache ops that would be used later to reduce lookup overhead. */
diff --git a/src/script/ir_builder/relax/ir.cc b/src/script/ir_builder/relax/ir.cc
index ece645243c..ddfb1ddfa3 100644
--- a/src/script/ir_builder/relax/ir.cc
+++ b/src/script/ir_builder/relax/ir.cc
@@ -203,8 +203,17 @@ tvm::relax::Var EmitMatchCast(const tvm::relax::Expr& value,
   return var;
 }
 
+tvm::relax::Var EmitVarBinding(const tvm::relax::VarBinding& binding) {
+  BlockFrame block_frame = CheckBlockFrameExistAndUnended();
+  const tvm::relax::BlockBuilder& block_builder = GetBlockBuilder();
+  block_builder->EmitNormalized(binding);
+  block_frame->emitted_vars.push_back(binding->var);
+  return binding->var;
+}
+
 TVM_REGISTER_GLOBAL("script.ir_builder.relax.Emit").set_body_typed(Emit);
 TVM_REGISTER_GLOBAL("script.ir_builder.relax.EmitMatchCast").set_body_typed(EmitMatchCast);
+TVM_REGISTER_GLOBAL("script.ir_builder.relax.EmitVarBinding").set_body_typed(EmitVarBinding);
 
 ///////////////////////////// If Then Else /////////////////////////////
 
diff --git a/tests/python/relax/test_analysis_well_formed.py b/tests/python/relax/test_analysis_well_formed.py
index 67da772741..ee5814eb7b 100644
--- a/tests/python/relax/test_analysis_well_formed.py
+++ b/tests/python/relax/test_analysis_well_formed.py
@@ -173,6 +173,33 @@ def test_seq_expr():
     assert not rx.analysis.well_formed(mod, check_struct_info=False)
 
 
+def test_recursive():
+    scalar_struct_info = rx.TensorStructInfo(shape=[], dtype="int32")
+    gv0 = rx.Var("gv0", scalar_struct_info)
+    f = rx.Var("f", rx.FuncStructInfo([scalar_struct_info], scalar_struct_info))
+    ipt = rx.Var("ipt", scalar_struct_info)
+    x0 = rx.Var("x0", scalar_struct_info)
+    x1 = rx.Var("x1", scalar_struct_info)
+    x2 = rx.Var("x2", scalar_struct_info)
+    y = rx.Var("y", scalar_struct_info)
+    inner_block = rx.BindingBlock(
+        [rx.VarBinding(x0, rx.const(2, "int32")), rx.VarBinding(y, rx.Call(f, [x0]))]
+    )
+    inner_func = rx.Function([ipt], rx.SeqExpr([inner_block], y), scalar_struct_info)
+    outer_block = rx.BindingBlock(
+        [
+            rx.VarBinding(f, inner_func),
+            rx.VarBinding(x1, rx.const(1, "int32")),
+            rx.VarBinding(x2, rx.op.add(x1, rx.Call(f, [x1]))),
+            rx.VarBinding(gv0, x2),
+        ]
+    )
+    func = rx.Function([], rx.SeqExpr([outer_block], gv0), scalar_struct_info)
+    mod = tvm.IRModule.from_expr(func)
+    normalized = rx.transform.Normalize()(mod)
+    assert rx.analysis.well_formed(normalized)
+
+
 def test_if():
     # Error: Var defined in true/false branch is invisible in the outer scope
     # except the return Var, i.e the var in the last stmt
diff --git a/tests/python/relax/test_transform_lambda_lift.py b/tests/python/relax/test_transform_lambda_lift.py
index c9bbc0fb91..5a137f22cb 100644
--- a/tests/python/relax/test_transform_lambda_lift.py
+++ b/tests/python/relax/test_transform_lambda_lift.py
@@ -114,7 +114,9 @@ def test_closure():
             x: R.Tensor((2, 3), "float32"), y: R.Tensor((2, 3), "float32")
         ) -> R.Tensor((2, 3), "float32"):
             @R.function
-            def outer_func(c1: R.Tensor((2, 3), "float32")):
+            def outer_func(
+                c1: R.Tensor((2, 3), "float32")
+            ) -> R.Callable((R.Tensor((2, 3), "float32"),), R.Tensor((2, 3), "float32")):
                 @R.function
                 def inner_func(x1: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"):
                     s: R.Tensor((2, 3), "float32") = R.add(x1, c1)
@@ -133,7 +135,6 @@ def test_closure():
     _check_save_roundtrip(after)
 
 
[email protected](reason="Need fix after parser switch over")
 def test_recursive():
     # the expected IRModule
     @tvm.script.ir_module
@@ -149,18 +150,19 @@ def test_recursive():
             if cond:
                 new_i: R.Tensor((), "int32") = R.add(i, c)
                 new_s: R.Tensor((2, 3), "float32") = R.add(s, x)
-                r = lifted_func_0(new_i, new_s, x)
+                new_r = lifted_func_0(new_i, new_s, x)
+                r = new_r
             else:
                 r = s
             return r
 
         @R.function
-        def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor:
+        def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), dtype="float32"):
             while_loop = R.make_closure(lifted_func_0, (x,))
-            gv = R.invoke_closure(
+            gv: R.Tensor((2, 3), dtype="float32") = R.invoke_closure(
                 while_loop,
-                (relax.const(0), x),
-                sinfo_args=(R.Tensor(ndim=2, dtype="float32")),
+                (R.const(0), x),
+                sinfo_args=(R.Tensor((2, 3), dtype="float32")),
             )
             return gv
 
@@ -185,11 +187,14 @@ def test_recursive():
                     r: R.Tensor((2, 3), "float32") = s
                 return r
 
-            gv: R.Tensor((2, 3), "float32") = while_loop(relax.const(0), x)
+            gv: R.Tensor((2, 3), "float32") = while_loop(R.const(0), x)
             return gv
 
     before = Before
     expected = Expected
+    # check well-formedness of the recursive call
+    assert relax.analysis.well_formed(before)
+
     # Perform Lambda Lifting
     after = transform.LambdaLift()(before)
     assert len(after.functions) == 2
@@ -198,7 +203,6 @@ def test_recursive():
     _check_save_roundtrip(after)
 
 
[email protected](reason="Need fix after parser switch over")
 def test_multi_func():
     # expected IRModule
     @tvm.script.ir_module
@@ -207,29 +211,29 @@ def test_multi_func():
         def glob_func_1(
             x1: R.Tensor((10, 5), "float32"), y1: R.Tensor((10, 5), "float32")
         ) -> R.Tensor(None, "float32", ndim=2):
-            inner = lifted_func_1
-            gv1 = inner(x1, y1)
+            inner = lifted_func_0
+            gv1: R.Tensor((10, 5), "float32") = inner(x1, y1)
             return gv1
 
         @R.function
         def glob_func_2(
             x11: R.Tensor((10, 5), "float32"), y11: R.Tensor((10, 5), "float32")
         ) -> R.Tensor(None, "float32", ndim=2):
-            inner1 = lifted_func_0
-            gv11 = inner1(x11, y11)
+            inner = lifted_func_1
+            gv11: R.Tensor((10, 5), "float32") = inner(x11, y11)
             return gv11
 
         @R.function
         def lifted_func_0(
             x2: R.Tensor((10, 5), "float32"), y2: R.Tensor((10, 5), "float32")
-        ) -> R.Tensor(None, "float32", ndim=2):
+        ) -> R.Tensor((10, 5), "float32"):
             s: R.Tensor((10, 5), "float32") = R.add(x2, y2)
             return s
 
         @R.function
         def lifted_func_1(
             x21: R.Tensor((10, 5), "float32"), y21: R.Tensor((10, 5), "float32")
-        ) -> R.Tensor(None, "float32", ndim=2):
+        ) -> R.Tensor((10, 5), "float32"):
             s1: R.Tensor((10, 5), "float32") = R.add(x21, y21)
             return s1
 
diff --git a/tests/python/relax/test_utils.py b/tests/python/relax/test_utils.py
index fbeb57564f..15122dab37 100644
--- a/tests/python/relax/test_utils.py
+++ b/tests/python/relax/test_utils.py
@@ -69,7 +69,7 @@ def test_copy_with_new_vars_on_ir_module_nested_function():
         @R.function
         def func(x: R.Tensor((3,), "float32"), y: R.Tensor((3,), "float32")):
             @R.function
-            def inner(x: R.Tensor((3,), "float32")):
+            def inner(x: R.Tensor((3,), "float32")) -> R.Tensor((3,), dtype="float32"):
                 gv = R.add(x, x)
                 return gv
 
@@ -81,7 +81,7 @@ def test_copy_with_new_vars_on_ir_module_nested_function():
         @R.function
         def func(x: R.Tensor((3,), "float32"), y: R.Tensor((3,), "float32")):
             @R.function
-            def inner(x: R.Tensor((3,), "float32")):
+            def inner(x: R.Tensor((3,), "float32")) -> R.Tensor((3,), dtype="float32"):
                 gv = R.add(x, x)
                 return gv
 
@@ -91,7 +91,7 @@ def test_copy_with_new_vars_on_ir_module_nested_function():
         @R.function
         def func_copied(x: R.Tensor((3,), "float32"), y: R.Tensor((3,), "float32")):
             @R.function
-            def inner(x: R.Tensor((3,), "float32")):
+            def inner(x: R.Tensor((3,), "float32")) -> R.Tensor((3,), dtype="float32"):
                 gv = R.add(x, x)
                 return gv
 

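For a quick end-to-end check of the new behavior, the recursive sketch near the
top of this mail could be pushed through the updated passes roughly as follows
(a hedged sketch: "Before" refers to that module, and the assertions mirror the
updated tests above):

    import tvm
    from tvm import relax

    # The well-formed checker now accepts the recursive call to the
    # locally bound function (see the recur_vars_ handling above).
    assert relax.analysis.well_formed(Before)

    # Lambda lifting hoists while_loop to a module-level lifted_func_0 and
    # rewrites the recursive call against the new GlobalVar.
    after = relax.transform.LambdaLift()(Before)
    assert len(after.functions) == 2
    print(after.script())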