This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 987ac35b66 [Unity] Relax Recursive function (#14092)
987ac35b66 is described below
commit 987ac35b66e685dba5f2f38bdfa633fbd9ee8c79
Author: Yong Wu <[email protected]>
AuthorDate: Wed Feb 22 16:00:06 2023 -0800
[Unity] Relax Recursive function (#14092)
This PR adds TVMScript support for local recursive functions. It also updates
the lambda lifting pass, removes CalledGlobalVars (which is no longer used), and
updates the well-formed check to allow a function's own binding var to be
referenced inside its body, so that recursive calls are accepted.
---
include/tvm/relax/analysis.h | 9 ---
include/tvm/script/ir_builder/relax/ir.h | 7 ++
python/tvm/script/ir_builder/relax/ir.py | 17 ++++-
python/tvm/script/parser/relax/parser.py | 62 +++++++++++++++--
src/relax/analysis/analysis.cc | 20 ------
src/relax/analysis/well_formed.cc | 11 +++-
src/relax/transform/lambda_lift.cc | 84 ++++++++++++++++++------
src/script/ir_builder/relax/ir.cc | 9 +++
tests/python/relax/test_analysis_well_formed.py | 27 ++++++++
tests/python/relax/test_transform_lambda_lift.py | 34 +++++-----
tests/python/relax/test_utils.py | 6 +-
11 files changed, 213 insertions(+), 73 deletions(-)
diff --git a/include/tvm/relax/analysis.h b/include/tvm/relax/analysis.h
index b9866577e9..39ecfd9e13 100644
--- a/include/tvm/relax/analysis.h
+++ b/include/tvm/relax/analysis.h
@@ -296,15 +296,6 @@ TVM_DLL tvm::Array<Var> FreeVars(const Expr& expr);
*/
TVM_DLL tvm::Array<Var> AllVars(const Expr& expr);
-/*!
- * \brief Get all global variables used in calls in expression expr.
- *
- * \param expr the expression.
- *
- * \return List of all global variables called in expr.
- */
-TVM_DLL tvm::Array<GlobalVar> CalledGlobalVars(const Expr& expr);
-
/*!
* \brief Get all global variables from expression expr.
*
diff --git a/include/tvm/script/ir_builder/relax/ir.h
b/include/tvm/script/ir_builder/relax/ir.h
index 72aab6684e..42aa591a95 100644
--- a/include/tvm/script/ir_builder/relax/ir.h
+++ b/include/tvm/script/ir_builder/relax/ir.h
@@ -110,6 +110,13 @@ TVM_DLL tvm::relax::Var Emit(
TVM_DLL tvm::relax::Var EmitMatchCast(const tvm::relax::Expr& value,
const tvm::relax::StructInfo&
struct_info);
+/*!
+ * \brief Emit an existing var binding to the last binding block frame.
+ *
+ * The var on the left-hand side of \p binding is used as-is, so a var that
+ * was created in advance (e.g. one reserved by the TVMScript parser so that a
+ * local function can call itself) can become the target of the binding.
+ *
+ * \param binding The binding to be emitted.
+ * \return The left-hand-side var of the emitted binding, i.e. binding->var.
+ */
+TVM_DLL tvm::relax::Var EmitVarBinding(const tvm::relax::VarBinding& binding);
+
///////////////////////////// If Then Else /////////////////////////////
/*!
diff --git a/python/tvm/script/ir_builder/relax/ir.py
b/python/tvm/script/ir_builder/relax/ir.py
index 43918ce7ec..63efea135c 100644
--- a/python/tvm/script/ir_builder/relax/ir.py
+++ b/python/tvm/script/ir_builder/relax/ir.py
@@ -25,7 +25,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
import tvm
from tvm import DataType, relax
from tvm.ir import PrimExpr
-from tvm.relax import Call, Expr, ExternFunc, TupleGetItem, Var, const
+from tvm.relax import Call, Expr, ExternFunc, TupleGetItem, Var, VarBinding,
const
############################### Operators ###############################
from tvm.relax.op import (
@@ -342,6 +342,20 @@ def emit_match_cast(value: Expr, struct_info: StructInfo)
-> Var:
return _ffi_api.EmitMatchCast(value, struct_info) # type: ignore
+def emit_var_binding(value: VarBinding) -> Var:
+ """Emit a binding to the last binding block frame.
+ Parameters
+ ----------
+ value: VarBinding
+ The binding to be emitted.
+ Returns
+ -------
+ var: Var
+ The left side var of the emitted binding.
+ """
+ return _ffi_api.EmitVarBinding(value) # type: ignore
+
+
############################# If Then Else #############################
@@ -497,6 +511,7 @@ __all__ = [
"divide",
"dtype",
"emit",
+ "emit_var_binding",
"emit_match_cast",
"equal",
"ewise_fma",
diff --git a/python/tvm/script/parser/relax/parser.py
b/python/tvm/script/parser/relax/parser.py
index e5e5bb2743..e1af1c1df3 100644
--- a/python/tvm/script/parser/relax/parser.py
+++ b/python/tvm/script/parser/relax/parser.py
@@ -96,8 +96,7 @@ def eval_struct_info_proxy(self: Parser, node: doc.expr) ->
StructInfoProxy:
annotation = annotation()
if isinstance(annotation, StructInfoProxy):
return annotation
- else:
- raise TypeError(f"Expected StructInfoProxy but got
{type(annotation)}.")
+ raise TypeError(f"Expected StructInfoProxy but got
{type(annotation)}.")
except Exception as err:
self.report_error(node, str(err))
raise err
@@ -112,6 +111,38 @@ def eval_struct_info(self: Parser, node: doc.expr,
eval_str: bool = False) -> St
raise err
+def is_called(node: Any, func_name: str) -> bool:
+ # Check if it calls into a func
+ if isinstance(node, doc.Call):
+ # Recursive call was found
+ if isinstance(node.func, doc.Name) and node.func.id == func_name:
+ return True
+ elif isinstance(node, (list, tuple)):
+ for stmt in node:
+ if is_called(stmt, func_name):
+ return True
+ elif isinstance(node, (doc.AnnAssign, doc.Assign, doc.Return, doc.Expr)):
+ return is_called(node.value, func_name)
+ elif isinstance(node, doc.With):
+ return is_called(node.body, func_name)
+ elif isinstance(node, doc.If):
+ smts = []
+ if node.body is not None:
+ smts = smts + list(node.body)
+ if node.orelse is not None:
+ smts = smts + list(node.orelse)
+ return is_called(smts, func_name)
+ return False
+
+
+def is_recursive(node: doc.FunctionDef) -> bool:
+ # Check if it is a recursive function
+ for stmt in node.body:
+ if is_called(stmt, node.name):
+ return True
+ return False
+
+
def collect_symbolic_var_from_params(self: Parser, node: doc.FunctionDef) ->
None:
# Collect symbolic vars from parameters
symbolic_vars = set()
@@ -128,6 +159,24 @@ def collect_symbolic_var_from_params(self: Parser, node:
doc.FunctionDef) -> Non
@dispatch.register(token="relax", type_name="FunctionDef")
def visit_function_def(self: Parser, node: doc.FunctionDef) -> None:
+ # reserve a var for local function
+ func_val = self.var_table.get().get(node.name)
+ if not func_val and is_recursive(node):
+ collect_symbolic_var_from_params(self, node)
+ if node.returns is None:
+ ret_sinfo = relax.TupleStructInfo([])
+ else:
+ ret_sinfo = eval_struct_info(self, node.returns, eval_str=True)
+ params_sinfo = []
+ for arg in node.args.args:
+ if arg.annotation is None:
+ self.report_error(arg, "Type annotation is required for
function parameters.")
+ param_sinfo = eval_struct_info(self, arg.annotation, eval_str=True)
+ params_sinfo.append(param_sinfo)
+ # created a var for the local function, the same var could be used for
recursive call
+ local_func_var = relax.Var(node.name,
relax.FuncStructInfo(params_sinfo, ret_sinfo))
+ self.var_table.add(node.name, local_func_var)
+
with self.var_table.with_frame():
with self.with_dispatch_token("relax"):
with R.function():
@@ -164,12 +213,10 @@ def visit_tvm_declare_function(self: Parser, node:
doc.FunctionDef) -> None:
else:
ret_sinfo = eval_struct_info(self, node.returns, eval_str=True)
params = []
- params_sinfo = []
for arg in node.args.args:
if arg.annotation is None:
self.report_error(arg, "Type annotation is required for
function parameters.")
param_sinfo = eval_struct_info(self, arg.annotation, eval_str=True)
- params_sinfo.append(param_sinfo)
params.append(relax.Var(arg.arg, param_sinfo))
func_signature = relax.Function.create_empty(params, ret_sinfo)
@@ -188,7 +235,12 @@ def post_token_switch(self: Parser, node: doc.Expr) ->
None:
ir_builder = IRBuilder.current()
result = ir_builder.get()
ir_builder.__exit__(None, None, None)
- var = R.emit(result)
+ # reuse var if it is reserved
+ reserved_var = self.var_table.get().get(node.name)
+ if reserved_var:
+ var = R.emit_var_binding(relax.VarBinding(reserved_var, result))
+ else:
+ var = R.emit(result)
IRBuilder.name(node.name, var)
self.var_table.add(node.name, var, allow_shadowing=False)
diff --git a/src/relax/analysis/analysis.cc b/src/relax/analysis/analysis.cc
index 33197308fa..4132039a5e 100644
--- a/src/relax/analysis/analysis.cc
+++ b/src/relax/analysis/analysis.cc
@@ -87,15 +87,6 @@ class VarVisitor : protected ExprVisitor {
return ret;
}
- Array<GlobalVar> CalledGlobalVars(const Expr& expr) {
- this->VisitExpr(expr);
- Array<GlobalVar> ret;
- for (const auto& v : called_global_vars_.data) {
- ret.push_back(v);
- }
- return ret;
- }
-
void MarkBounded(const Var& v) {
bound_vars_.Insert(v);
vars_.Insert(v);
@@ -123,10 +114,6 @@ class VarVisitor : protected ExprVisitor {
for (Expr arg : call_node->args) {
VisitExpr(arg);
}
-
- if (const GlobalVarNode* global_var_node =
call_node->op.as<GlobalVarNode>()) {
- called_global_vars_.Insert(GetRef<GlobalVar>(global_var_node));
- }
}
void VisitBinding_(const VarBindingNode* binding) final {
@@ -144,7 +131,6 @@ class VarVisitor : protected ExprVisitor {
InsertionSet<Var> vars_;
InsertionSet<Var> bound_vars_;
InsertionSet<GlobalVar> global_vars_;
- InsertionSet<GlobalVar> called_global_vars_;
};
tvm::Array<Var> FreeVars(const Expr& expr) { return VarVisitor().Free(expr); }
@@ -155,10 +141,6 @@ tvm::Array<Var> AllVars(const Expr& expr) { return
VarVisitor().All(expr); }
tvm::Array<GlobalVar> AllGlobalVars(const Expr& expr) { return
VarVisitor().AllGlobalVars(expr); }
-tvm::Array<GlobalVar> CalledGlobalVars(const Expr& expr) {
- return VarVisitor().CalledGlobalVars(expr);
-}
-
TVM_REGISTER_GLOBAL("relax.analysis.free_vars").set_body_typed(FreeVars);
TVM_REGISTER_GLOBAL("relax.analysis.bound_vars").set_body_typed(BoundVars);
@@ -167,7 +149,5 @@
TVM_REGISTER_GLOBAL("relax.analysis.all_vars").set_body_typed(AllVars);
TVM_REGISTER_GLOBAL("relax.analysis.all_global_vars").set_body_typed(AllGlobalVars);
-TVM_REGISTER_GLOBAL("relax.analysis.called_global_vars").set_body_typed(CalledGlobalVars);
-
} // namespace relax
} // namespace tvm
diff --git a/src/relax/analysis/well_formed.cc
b/src/relax/analysis/well_formed.cc
index 05ad0954bb..25b9155d77 100644
--- a/src/relax/analysis/well_formed.cc
+++ b/src/relax/analysis/well_formed.cc
@@ -177,7 +177,7 @@ class WellFormedChecker : public relax::ExprVisitor,
void VisitExpr_(const VarNode* op) final {
Var var = GetRef<Var>(op);
- if (var_set_.count(var) == 0) {
+ if (var_set_.count(var) == 0 && recur_vars_.count(var) == 0) {
Malformed(Diagnostic::Error(var) << "Var " << op->name_hint() << " is
not defined.");
}
CheckStructInfo(op);
@@ -316,12 +316,20 @@ class WellFormedChecker : public relax::ExprVisitor,
}
void VisitBinding_(const VarBindingNode* binding) final {
+ bool is_lambda = false;
+ if (binding->value->IsInstance<FunctionNode>()) {
+ is_lambda = true;
+ recur_vars_.insert(binding->var);
+ }
if (binding->value->IsInstance<tir::PrimFuncNode>()) {
Malformed(Diagnostic::Error(binding->value) << "Inline PrimFunc is
disallowed in Relax IR.");
} else {
this->VisitExpr(binding->value);
}
this->VisitVarDef(binding->var);
+ if (is_lambda) {
+ recur_vars_.erase(binding->var);
+ }
}
void VisitBinding_(const MatchCastNode* binding) final {
@@ -451,6 +459,7 @@ class WellFormedChecker : public relax::ExprVisitor,
VisitMode mode_ = VisitMode::kDefault;
// set of context variables.
std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> var_set_;
+ std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> recur_vars_;
std::unordered_set<DataflowVar, ObjectPtrHash, ObjectPtrEqual>
dataflow_var_set_;
std::unordered_set<tir::Var, ObjectPtrHash, ObjectPtrEqual>
symbolic_var_set_;
std::unordered_map<Var, Function, ObjectPtrHash, ObjectPtrEqual>
param_var_func_map_;
diff --git a/src/relax/transform/lambda_lift.cc
b/src/relax/transform/lambda_lift.cc
index f08499036b..7492082310 100644
--- a/src/relax/transform/lambda_lift.cc
+++ b/src/relax/transform/lambda_lift.cc
@@ -46,35 +46,72 @@ class LambdaLifter : public ExprMutator {
using ExprMutator::VisitExpr_;
+ // Visit a var binding. When the bound value is a function, its binding var is
+ // pushed onto recur_vars_ while that function is visited, so a free reference
+ // to the var inside the function body can be recognized as a recursive call
+ // when the function is lifted (compared against recur_vars_.back()).
+ // NOTE(review): binding->var is emitted as-is (no VisitVarDef / remapping
+ // here) -- confirm that is intended.
+ void VisitBinding_(const VarBindingNode* binding) final {
+ bool is_lambda = false;
+ if (binding->value->IsInstance<FunctionNode>()) {
+ is_lambda = true;
+ recur_vars_.push_back(binding->var);
+ }
+ Expr new_value = this->VisitExpr(binding->value);
+ // If the rewritten value carries struct info that differs from the var's,
+ // update the var's struct info and checked type to match.
+ // NOTE(review): this mutates the existing Var node in place, so every other
+ // reference to this var observes the new struct info.
+ if (new_value->struct_info_.defined() &&
+ !new_value->struct_info_.same_as(binding->var->struct_info_)) {
+ binding->var->struct_info_ = GetStructInfo(new_value);
+ binding->var->checked_type_ = new_value->checked_type_;
+ }
+ // Re-emit the original binding when the value is unchanged; otherwise bind
+ // the same var to the rewritten value.
+ if (new_value.same_as(binding->value)) {
+ builder_->EmitNormalized(GetRef<VarBinding>(binding));
+ } else {
+ builder_->EmitNormalized(VarBinding(binding->var, new_value));
+ }
+ // The function has been visited and emitted; its var is no longer the
+ // innermost candidate for a recursive reference.
+ // NOTE(review): the pop is skipped if VisitExpr throws; acceptable only if an
+ // exception aborts the whole pass.
+ if (is_lambda) {
+ recur_vars_.pop_back();
+ }
+ }
+
Expr VisitExpr_(const CallNode* call_node) final {
auto call = Downcast<Call>(ExprMutator::VisitExpr_(call_node));
- if (auto const* var = call_node->op.as<VarNode>()) {
- bool has_closure = HasClosure(GetRef<Var>(var));
- auto val = builder_->LookupBinding(GetRef<Var>(var));
+ if (const auto* var_node = call_node->op.as<VarNode>()) {
+ auto var = GetRef<Var>(var_node);
+ bool has_closure = HasClosure(var);
+ auto val = builder_->LookupBinding(var);
+ if (const auto* fsinfo_node =
GetStructInfo(var).as<FuncStructInfoNode>()) {
+ auto fsinfo = GetRef<FuncStructInfo>(fsinfo_node);
+ if (!GetStructInfo(call).same_as(fsinfo)) {
+ call->struct_info_ = fsinfo->ret;
+ call->checked_type_ = GetStaticType(fsinfo->ret);
+ }
+ }
// Call "relax.invoke_closure" to invoke closure
- if (has_closure && val.as<CallNode>()) {
- Var clo_arg = GetRef<Var>(var);
+ Var clo_arg = var;
+ if (has_closure && val->IsInstance<CallNode>()) {
if (this->var_remap_.find(var->vid) != this->var_remap_.end()) {
clo_arg = this->var_remap_.at(var->vid);
}
return Call(invoke_closure_op_, {clo_arg, Tuple(call_node->args)}, {},
{GetStructInfo(GetRef<Expr>(call_node))});
}
- }
- if (auto global_var_node = call_node->op.as<GlobalVarNode>()) {
- String rec_name = global_var_node->name_hint;
- auto global_var = GetRef<GlobalVar>(global_var_node);
- auto it = lambda_map_.find(global_var);
+ auto it = lambda_map_.find(var);
if (it != lambda_map_.end()) {
// flatten nested call, e.g. call(y)(x) -> call(x, y))
Array<relay::Expr> new_args;
+ Array<StructInfo> params;
for (const auto arg : call->args) {
new_args.push_back(arg);
+ params.push_back(StructInfoFromType(arg->checked_type()));
}
if (const auto* nest_call = it->second.as<CallNode>()) {
+ // Update the StructInfo accordingly
for (const auto arg : nest_call->args) {
new_args.push_back(arg);
+ params.push_back(StructInfoFromType(arg->checked_type()));
}
+ StructInfo new_func_sinfo;
+ if (const auto* fsinfo =
GetStructInfo(nest_call->op).as<FuncStructInfoNode>()) {
+ auto func_sinfo = GetRef<FuncStructInfo>(fsinfo);
+ new_func_sinfo = FuncStructInfo(params, func_sinfo->ret);
+ }
+ nest_call->op->struct_info_ = new_func_sinfo;
+ nest_call->op->checked_type_ = GetStaticType(new_func_sinfo);
return Call(nest_call->op, new_args, call_node->attrs,
call_node->sinfo_args);
}
return Call(it->second, call->args, call_node->attrs,
call_node->sinfo_args);
@@ -89,11 +126,19 @@ class LambdaLifter : public ExprMutator {
// TODO(@yongwww): consider appending inner func name into the lifted func
name
String lift_func_name = "lifted_func_" + std::to_string(lift_func_num_++);
auto global = GlobalVar(lift_func_name);
- Array<Var> captured_vars = FreeVars(func);
- recur_vars_ = CalledGlobalVars(func);
- auto all_global_vars = AllGlobalVars(func);
+ Array<Var> free_vars = FreeVars(func);
+ Array<Var> captured_vars;
Array<Var> typed_captured_vars;
+ bool recursive = false;
+ for (const auto& var : free_vars) {
+ if (!recur_vars_.empty() && var == recur_vars_.back()) {
+ recursive = true;
+ } else {
+ captured_vars.push_back(var);
+ }
+ }
+
Map<Var, Expr> rebinding_map;
for (auto free_var : captured_vars) {
Var var = Var(free_var->name_hint(), GetStructInfo(free_var),
free_var->span);
@@ -102,12 +147,14 @@ class LambdaLifter : public ExprMutator {
}
// recursive call
- if (!recur_vars_.empty()) {
+ if (recursive) {
if (!captured_vars.empty()) {
Array<Expr> fvs;
for (auto fv : captured_vars) {
fvs.push_back(fv);
}
+ // required by the block builder; the struct info is updated later once the function is lifted
+ UpdateStructInfo(global, GetStructInfo(recur_vars_.back()));
lambda_map_.emplace(recur_vars_.back(), Call(global, fvs));
} else {
if (recur_vars_.size() > 0) {
@@ -162,18 +209,17 @@ class LambdaLifter : public ExprMutator {
/*attrs=*/new_func->attrs,
/*span=*/func->span);
- Array<Type> param_types;
for (Var param : closure_params) {
CHECK(param->checked_type_.defined())
<< "relax.Function requires params to contain checked_type_";
- param_types.push_back(param->checked_type_);
}
}
ICHECK(lifted_func.defined());
// Add the lifted function to the module.
- UpdateStructInfo(global, GetStructInfo(lifted_func));
+ global->struct_info_ = GetStructInfo(lifted_func);
+ global->checked_type_ = lifted_func->checked_type_;
builder_->UpdateFunction(global, lifted_func);
if (!is_closure) {
@@ -242,8 +288,8 @@ class LambdaLifter : public ExprMutator {
}
private:
- std::unordered_map<GlobalVar, Expr, ObjectPtrHash, ObjectPtrEqual>
lambda_map_;
- Array<GlobalVar> recur_vars_;
+ std::unordered_map<Var, Expr, ObjectPtrHash, ObjectPtrEqual> lambda_map_;
+ Array<Var> recur_vars_;
IRModule mod_;
size_t lift_func_num_ = 0;
/*! \brief Cache ops that would be used later to reduce lookup overhead. */
diff --git a/src/script/ir_builder/relax/ir.cc
b/src/script/ir_builder/relax/ir.cc
index ece645243c..ddfb1ddfa3 100644
--- a/src/script/ir_builder/relax/ir.cc
+++ b/src/script/ir_builder/relax/ir.cc
@@ -203,8 +203,17 @@ tvm::relax::Var EmitMatchCast(const tvm::relax::Expr&
value,
return var;
}
+/*!
+ * \brief Emit an existing var binding into the last (innermost) block frame.
+ *
+ * The binding's own var is reused: it is recorded in the frame's emitted_vars
+ * and returned to the caller.
+ */
+tvm::relax::Var EmitVarBinding(const tvm::relax::VarBinding& binding) {
+  // Obtain the block frame that will own the emitted var.
+  BlockFrame frame = CheckBlockFrameExistAndUnended();
+  tvm::relax::Var lhs = binding->var;
+  GetBlockBuilder()->EmitNormalized(binding);
+  frame->emitted_vars.push_back(lhs);
+  return lhs;
+}
+
TVM_REGISTER_GLOBAL("script.ir_builder.relax.Emit").set_body_typed(Emit);
TVM_REGISTER_GLOBAL("script.ir_builder.relax.EmitMatchCast").set_body_typed(EmitMatchCast);
+TVM_REGISTER_GLOBAL("script.ir_builder.relax.EmitVarBinding").set_body_typed(EmitVarBinding);
///////////////////////////// If Then Else /////////////////////////////
diff --git a/tests/python/relax/test_analysis_well_formed.py
b/tests/python/relax/test_analysis_well_formed.py
index 67da772741..ee5814eb7b 100644
--- a/tests/python/relax/test_analysis_well_formed.py
+++ b/tests/python/relax/test_analysis_well_formed.py
@@ -173,6 +173,33 @@ def test_seq_expr():
assert not rx.analysis.well_formed(mod, check_struct_info=False)
+def test_recursive():
+ """Check that a function bound to a var may reference that var (a recursive call).
+
+ Inside `inner_func`, `y = f(x0)` calls `f`, the var that `inner_func` itself is
+ bound to in the outer block; well_formed must accept this reference.
+ """
+ scalar_struct_info = rx.TensorStructInfo(shape=[], dtype="int32")
+ gv0 = rx.Var("gv0", scalar_struct_info)
+ f = rx.Var("f", rx.FuncStructInfo([scalar_struct_info],
scalar_struct_info))
+ ipt = rx.Var("ipt", scalar_struct_info)
+ x0 = rx.Var("x0", scalar_struct_info)
+ x1 = rx.Var("x1", scalar_struct_info)
+ x2 = rx.Var("x2", scalar_struct_info)
+ y = rx.Var("y", scalar_struct_info)
+ # Body of the local function: it calls `f`, the var it is being bound to.
+ inner_block = rx.BindingBlock(
+ [rx.VarBinding(x0, rx.const(2, "int32")), rx.VarBinding(y, rx.Call(f,
[x0]))]
+ )
+ inner_func = rx.Function([ipt], rx.SeqExpr([inner_block], y),
scalar_struct_info)
+ # Outer function: bind `f` to the local function, then call it.
+ outer_block = rx.BindingBlock(
+ [
+ rx.VarBinding(f, inner_func),
+ rx.VarBinding(x1, rx.const(1, "int32")),
+ rx.VarBinding(x2, rx.op.add(x1, rx.Call(f, [x1]))),
+ rx.VarBinding(gv0, x2),
+ ]
+ )
+ func = rx.Function([], rx.SeqExpr([outer_block], gv0), scalar_struct_info)
+ mod = tvm.IRModule.from_expr(func)
+ normalized = rx.transform.Normalize()(mod)
+ assert rx.analysis.well_formed(normalized)
+
+
def test_if():
# Error: Var defined in true/false branch is invisible in the outer scope
# except the return Var, i.e the var in the last stmt
diff --git a/tests/python/relax/test_transform_lambda_lift.py
b/tests/python/relax/test_transform_lambda_lift.py
index c9bbc0fb91..5a137f22cb 100644
--- a/tests/python/relax/test_transform_lambda_lift.py
+++ b/tests/python/relax/test_transform_lambda_lift.py
@@ -114,7 +114,9 @@ def test_closure():
x: R.Tensor((2, 3), "float32"), y: R.Tensor((2, 3), "float32")
) -> R.Tensor((2, 3), "float32"):
@R.function
- def outer_func(c1: R.Tensor((2, 3), "float32")):
+ def outer_func(
+ c1: R.Tensor((2, 3), "float32")
+ ) -> R.Callable((R.Tensor((2, 3), "float32"),), R.Tensor((2, 3),
"float32")):
@R.function
def inner_func(x1: R.Tensor((2, 3), "float32")) ->
R.Tensor((2, 3), "float32"):
s: R.Tensor((2, 3), "float32") = R.add(x1, c1)
@@ -133,7 +135,6 @@ def test_closure():
_check_save_roundtrip(after)
[email protected](reason="Need fix after parser switch over")
def test_recursive():
# the expected IRModule
@tvm.script.ir_module
@@ -149,18 +150,19 @@ def test_recursive():
if cond:
new_i: R.Tensor((), "int32") = R.add(i, c)
new_s: R.Tensor((2, 3), "float32") = R.add(s, x)
- r = lifted_func_0(new_i, new_s, x)
+ new_r = lifted_func_0(new_i, new_s, x)
+ r = new_r
else:
r = s
return r
@R.function
- def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor:
+ def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3),
dtype="float32"):
while_loop = R.make_closure(lifted_func_0, (x,))
- gv = R.invoke_closure(
+ gv: R.Tensor((2, 3), dtype="float32") = R.invoke_closure(
while_loop,
- (relax.const(0), x),
- sinfo_args=(R.Tensor(ndim=2, dtype="float32")),
+ (R.const(0), x),
+ sinfo_args=(R.Tensor((2, 3), dtype="float32")),
)
return gv
@@ -185,11 +187,14 @@ def test_recursive():
r: R.Tensor((2, 3), "float32") = s
return r
- gv: R.Tensor((2, 3), "float32") = while_loop(relax.const(0), x)
+ gv: R.Tensor((2, 3), "float32") = while_loop(R.const(0), x)
return gv
before = Before
expected = Expected
+ # check well-formedness of recursive call
+ assert relax.analysis.well_formed(before)
+
# Perform Lambda Lifting
after = transform.LambdaLift()(before)
assert len(after.functions) == 2
@@ -198,7 +203,6 @@ def test_recursive():
_check_save_roundtrip(after)
[email protected](reason="Need fix after parser switch over")
def test_multi_func():
# expected IRModule
@tvm.script.ir_module
@@ -207,29 +211,29 @@ def test_multi_func():
def glob_func_1(
x1: R.Tensor((10, 5), "float32"), y1: R.Tensor((10, 5), "float32")
) -> R.Tensor(None, "float32", ndim=2):
- inner = lifted_func_1
- gv1 = inner(x1, y1)
+ inner = lifted_func_0
+ gv1: R.Tensor((10, 5), "float32") = inner(x1, y1)
return gv1
@R.function
def glob_func_2(
x11: R.Tensor((10, 5), "float32"), y11: R.Tensor((10, 5),
"float32")
) -> R.Tensor(None, "float32", ndim=2):
- inner1 = lifted_func_0
- gv11 = inner1(x11, y11)
+ inner = lifted_func_1
+ gv11: R.Tensor((10, 5), "float32") = inner(x11, y11)
return gv11
@R.function
def lifted_func_0(
x2: R.Tensor((10, 5), "float32"), y2: R.Tensor((10, 5), "float32")
- ) -> R.Tensor(None, "float32", ndim=2):
+ ) -> R.Tensor((10, 5), "float32"):
s: R.Tensor((10, 5), "float32") = R.add(x2, y2)
return s
@R.function
def lifted_func_1(
x21: R.Tensor((10, 5), "float32"), y21: R.Tensor((10, 5),
"float32")
- ) -> R.Tensor(None, "float32", ndim=2):
+ ) -> R.Tensor((10, 5), "float32"):
s1: R.Tensor((10, 5), "float32") = R.add(x21, y21)
return s1
diff --git a/tests/python/relax/test_utils.py b/tests/python/relax/test_utils.py
index fbeb57564f..15122dab37 100644
--- a/tests/python/relax/test_utils.py
+++ b/tests/python/relax/test_utils.py
@@ -69,7 +69,7 @@ def test_copy_with_new_vars_on_ir_module_nested_function():
@R.function
def func(x: R.Tensor((3,), "float32"), y: R.Tensor((3,), "float32")):
@R.function
- def inner(x: R.Tensor((3,), "float32")):
+ def inner(x: R.Tensor((3,), "float32")) -> R.Tensor((3,),
dtype="float32"):
gv = R.add(x, x)
return gv
@@ -81,7 +81,7 @@ def test_copy_with_new_vars_on_ir_module_nested_function():
@R.function
def func(x: R.Tensor((3,), "float32"), y: R.Tensor((3,), "float32")):
@R.function
- def inner(x: R.Tensor((3,), "float32")):
+ def inner(x: R.Tensor((3,), "float32")) -> R.Tensor((3,),
dtype="float32"):
gv = R.add(x, x)
return gv
@@ -91,7 +91,7 @@ def test_copy_with_new_vars_on_ir_module_nested_function():
@R.function
def func_copied(x: R.Tensor((3,), "float32"), y: R.Tensor((3,),
"float32")):
@R.function
- def inner(x: R.Tensor((3,), "float32")):
+ def inner(x: R.Tensor((3,), "float32")) -> R.Tensor((3,),
dtype="float32"):
gv = R.add(x, x)
return gv