This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new ab43ba8ae4 [Unity][Pass] LambdaLift pass (#14012)
ab43ba8ae4 is described below
commit ab43ba8ae453c09ad19142d013530faaf3bff58c
Author: Yong Wu <[email protected]>
AuthorDate: Thu Feb 16 12:09:47 2023 -0800
[Unity][Pass] LambdaLift pass (#14012)
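
For reviewers, a minimal usage sketch distilled from the tests in this commit
(the module below is illustrative, mirroring test_basic):

    import tvm
    from tvm.script import relax as R
    from tvm.relax import transform

    @tvm.script.ir_module
    class Mod:
        @R.function
        def main(
            x: R.Tensor((10, 5), "float32"), y: R.Tensor((10, 5), "float32")
        ) -> R.Tensor((10, 5), "float32"):
            @R.function
            def inner(
                a: R.Tensor((10, 5), "float32"), b: R.Tensor((10, 5), "float32")
            ) -> R.Tensor((10, 5), "float32"):
                s: R.Tensor((10, 5), "float32") = R.add(a, b)
                return s

            gv: R.Tensor((10, 5), "float32") = inner(x, y)
            return gv

    # "inner" is hoisted to a top-level lifted_func_0; it captures no free
    # vars, so no closure is created.
    lifted = transform.LambdaLift()(Mod)
    assert len(lifted.functions) == 2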
---
include/tvm/relax/analysis.h | 57 +++++
python/tvm/relax/transform/transform.py | 10 +
src/relax/analysis/analysis.cc | 173 +++++++++++++
src/relax/transform/lambda_lift.cc | 266 ++++++++++++++++++++
src/relax/utils.cc | 45 ++++
tests/python/relax/test_transform_lambda_lift.py | 304 +++++++++++++++++++++++
6 files changed, 855 insertions(+)
diff --git a/include/tvm/relax/analysis.h b/include/tvm/relax/analysis.h
index a55fe6797d..ff576d4ebb 100644
--- a/include/tvm/relax/analysis.h
+++ b/include/tvm/relax/analysis.h
@@ -260,6 +260,63 @@ TVM_DLL bool IsBaseOf(const StructInfo& base, const StructInfo& derived,
TVM_DLL StructInfo StructInfoLCA(const StructInfo& lhs, const StructInfo& rhs,
arith::Analyzer* ana = nullptr);
+//-----------------------------------
+// General IR analysis
+//-----------------------------------
+/*!
+ * \brief Get all bound variables from expression expr.
+ *
+ * Bound variables are all variables that are declared in the expr.
+ * They only have meaning inside that expr, and can only be used in it.
+ *
+ * \param expr the expression.
+ *
+ * \return List of bound vars, in the PostDFS order in the expression.
+ */
+TVM_DLL tvm::Array<Var> BoundVars(const Expr& expr);
+
+/*!
+ * \brief Get free variables from expression expr.
+ *
+ * Free variables are variables that are not bound by a
+ * VarBinding or a function parameter in the context.
+ *
+ * \param expr the expression.
+ *
+ * \return List of free vars, in the PostDFS order in the expression.
+ */
+TVM_DLL tvm::Array<Var> FreeVars(const Expr& expr);
+
+/*!
+ * \brief Get all variables from expression expr.
+ *
+ * AllVars is a superset of BoundVars and FreeVars:
+ * the union of BoundVars and FreeVars is AllVars.
+ *
+ * \param expr the expression.
+ *
+ * \return List of all vars, in the PostDFS order in the expression.
+ */
+TVM_DLL tvm::Array<Var> AllVars(const Expr& expr);
+
+/*!
+ * \brief Get all global variables used in calls in expression expr.
+ *
+ * \param expr the expression.
+ *
+ * \return List of all global variables called in expr.
+ */
+TVM_DLL tvm::Array<GlobalVar> CalledGlobalVars(const Expr& expr);
+
+/*!
+ * \brief Get all global variables from expression expr.
+ *
+ * \param expr the expression.
+ *
+ * \return List of all global variables, in the PostDFS order in the expression.
+ */
+TVM_DLL tvm::Array<GlobalVar> AllGlobalVars(const Expr& expr);
+
/*!
* \brief Annotate Op Pattern Kind for PrimFunc, which is used in relax FuseOps.
*
diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py
index 0f973db290..1a525431dd 100644
--- a/python/tvm/relax/transform/transform.py
+++ b/python/tvm/relax/transform/transform.py
@@ -47,6 +47,16 @@ def ToNonDataflow() -> tvm.ir.transform.Pass:
return _ffi_api.ToNonDataflow() # type: ignore
+def LambdaLift() -> tvm.ir.transform.Pass:
+ """A pass that lifts local functions into global functions.
+
+ Returns
+ -------
+ ret : tvm.ir.transform.Pass
+ The registered pass that lifts local functions to the module level.
+ """
+ return _ffi_api.LambdaLift() # type: ignore
+
+
def CallTIRRewrite() -> tvm.ir.transform.Pass:
"""Perform explicit tensor allocation for call_tir.
diff --git a/src/relax/analysis/analysis.cc b/src/relax/analysis/analysis.cc
new file mode 100644
index 0000000000..33197308fa
--- /dev/null
+++ b/src/relax/analysis/analysis.cc
@@ -0,0 +1,173 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *
+ * \file analysis.cc
+ *
+ * \brief Analysis functions for Relax.
+ */
+
+#include <tvm/relax/analysis.h>
+#include <tvm/relax/expr_functor.h>
+#include <tvm/tir/expr_functor.h>
+
+namespace tvm {
+namespace relax {
+
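+/*! \brief A set that records elements in first-insertion order, so the analyses below return vars in deterministic (PostDFS) order. */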
+template <typename T>
+struct InsertionSet {
+ std::unordered_set<T, ObjectPtrHash, ObjectPtrEqual> set;
+ std::vector<T> data;
+ void Insert(const T& t) {
+ if (set.count(t) == 0) {
+ set.insert(t);
+ data.push_back(t);
+ }
+ }
+};
+
+class VarVisitor : protected ExprVisitor {
+ public:
+ Array<Var> Free(const Expr& expr) {
+ this->VisitExpr(expr);
+ Array<Var> ret;
+ for (const auto& v : vars_.data) {
+ if (bound_vars_.set.count(v) == 0) {
+ ret.push_back(v);
+ }
+ }
+ return ret;
+ }
+
+ Array<Var> Collect() {
+ Array<Var> ret;
+ for (const auto& v : bound_vars_.data) {
+ ret.push_back(v);
+ }
+ return ret;
+ }
+
+ Array<Var> Bound(const Expr& expr) {
+ this->VisitExpr(expr);
+ return Collect();
+ }
+
+ Array<Var> All(const Expr& expr) {
+ this->VisitExpr(expr);
+ Array<Var> ret;
+ for (const auto& v : vars_.data) {
+ ret.push_back(v);
+ }
+ return ret;
+ }
+
+ Array<GlobalVar> AllGlobalVars(const Expr& expr) {
+ this->VisitExpr(expr);
+ Array<GlobalVar> ret;
+ for (const auto& v : global_vars_.data) {
+ ret.push_back(v);
+ }
+ return ret;
+ }
+
+ Array<GlobalVar> CalledGlobalVars(const Expr& expr) {
+ this->VisitExpr(expr);
+ Array<GlobalVar> ret;
+ for (const auto& v : called_global_vars_.data) {
+ ret.push_back(v);
+ }
+ return ret;
+ }
+
+ void MarkBounded(const Var& v) {
+ bound_vars_.Insert(v);
+ vars_.Insert(v);
+ }
+
+ void VisitExpr_(const VarNode* var) final { vars_.Insert(GetRef<Var>(var)); }
+
+ void VisitExpr_(const FunctionNode* op) final {
+ for (const auto& param : op->params) {
+ MarkBounded(param);
+ }
+ VisitExpr(op->body);
+ }
+
+ void VisitExpr_(const GlobalVarNode* op) final { global_vars_.Insert(GetRef<GlobalVar>(op)); }
+
+ void VisitExpr_(const CallNode* call_node) final {
+ VisitSpan(call_node->span);
+ VisitExpr(call_node->op);
+
+ for (StructInfo sinfo_arg : call_node->sinfo_args) {
+ VisitExprDepStructInfoField(sinfo_arg);
+ }
+
+ for (Expr arg : call_node->args) {
+ VisitExpr(arg);
+ }
+
+ if (const GlobalVarNode* global_var_node = call_node->op.as<GlobalVarNode>()) {
+ called_global_vars_.Insert(GetRef<GlobalVar>(global_var_node));
+ }
+ }
+
+ void VisitBinding_(const VarBindingNode* binding) final {
+ MarkBounded(binding->var);
+ VisitExpr(binding->value);
+ VisitVarDef(binding->var);
+ }
+
+ void VisitBinding_(const MatchCastNode* binding) final {
+ MarkBounded(binding->var);
+ ExprVisitor::VisitBinding_(binding);
+ }
+
+ private:
+ InsertionSet<Var> vars_;
+ InsertionSet<Var> bound_vars_;
+ InsertionSet<GlobalVar> global_vars_;
+ InsertionSet<GlobalVar> called_global_vars_;
+};
+
+tvm::Array<Var> FreeVars(const Expr& expr) { return VarVisitor().Free(expr); }
+
+tvm::Array<Var> BoundVars(const Expr& expr) { return VarVisitor().Bound(expr); }
+
+tvm::Array<Var> AllVars(const Expr& expr) { return VarVisitor().All(expr); }
+
+tvm::Array<GlobalVar> AllGlobalVars(const Expr& expr) { return VarVisitor().AllGlobalVars(expr); }
+
+tvm::Array<GlobalVar> CalledGlobalVars(const Expr& expr) {
+ return VarVisitor().CalledGlobalVars(expr);
+}
+
+TVM_REGISTER_GLOBAL("relax.analysis.free_vars").set_body_typed(FreeVars);
+
+TVM_REGISTER_GLOBAL("relax.analysis.bound_vars").set_body_typed(BoundVars);
+
+TVM_REGISTER_GLOBAL("relax.analysis.all_vars").set_body_typed(AllVars);
+
+TVM_REGISTER_GLOBAL("relax.analysis.all_global_vars").set_body_typed(AllGlobalVars);
+
+TVM_REGISTER_GLOBAL("relax.analysis.called_global_vars").set_body_typed(CalledGlobalVars);
+
+} // namespace relax
+} // namespace tvm
diff --git a/src/relax/transform/lambda_lift.cc b/src/relax/transform/lambda_lift.cc
new file mode 100644
index 0000000000..f08499036b
--- /dev/null
+++ b/src/relax/transform/lambda_lift.cc
@@ -0,0 +1,266 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/relax/transform/lambda_lift.cc
+ * \brief Lift local functions into global functions.
+ */
+
+#include <tvm/relax/analysis.h>
+#include <tvm/relax/expr.h>
+#include <tvm/relax/expr_functor.h>
+#include <tvm/relax/transform.h>
+#include <tvm/runtime/logging.h>
+
+#include <iostream>
+#include <vector>
+
+namespace tvm {
+namespace relax {
+
+/* The goal of this class is to lift out any nested functions into top-level
+ * functions.
+ *
+ * We will lift a function out into a global which takes the set of the free
+ * vars and then returns the newly created function.
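+ *
+ * For example (an illustrative sketch): a local function defined inside
+ * main(x) that reads the free var x is lifted to a top-level lifted_func_N
+ * taking x as an extra parameter, and its uses are rewritten to go through
+ * relax.make_closure / relax.invoke_closure.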
+ */
+class LambdaLifter : public ExprMutator {
+ public:
+ explicit LambdaLifter(const IRModule& module) : ExprMutator(module) { mod_ = module; }
+
+ using ExprMutator::VisitExpr_;
+
+ Expr VisitExpr_(const CallNode* call_node) final {
+ auto call = Downcast<Call>(ExprMutator::VisitExpr_(call_node));
+ if (auto const* var = call_node->op.as<VarNode>()) {
+ bool has_closure = HasClosure(GetRef<Var>(var));
+ auto val = builder_->LookupBinding(GetRef<Var>(var));
+ // Call "relax.invoke_closure" to invoke closure
+ if (has_closure && val.as<CallNode>()) {
+ Var clo_arg = GetRef<Var>(var);
+ if (this->var_remap_.find(var->vid) != this->var_remap_.end()) {
+ clo_arg = this->var_remap_.at(var->vid);
+ }
+ return Call(invoke_closure_op_, {clo_arg, Tuple(call_node->args)}, {}, {GetStructInfo(GetRef<Expr>(call_node))});
+ }
+ }
+ if (auto global_var_node = call_node->op.as<GlobalVarNode>()) {
+ auto global_var = GetRef<GlobalVar>(global_var_node);
+ auto it = lambda_map_.find(global_var);
+ if (it != lambda_map_.end()) {
+ // flatten nested calls, e.g. call(y)(x) -> call(x, y)
+ Array<Expr> new_args;
+ for (const auto& arg : call->args) {
+ new_args.push_back(arg);
+ }
+ if (const auto* nest_call = it->second.as<CallNode>()) {
+ for (const auto& arg : nest_call->args) {
+ new_args.push_back(arg);
+ }
+ return Call(nest_call->op, new_args, call_node->attrs, call_node->sinfo_args);
+ }
+ return Call(it->second, call->args, call_node->attrs, call_node->sinfo_args);
+ }
+ }
+ return std::move(call);
+ }
+
+ Expr VisitExpr_(const FunctionNode* func_node) final {
+ auto func = GetRef<Function>(func_node);
+
+ // TODO(@yongwww): consider appending inner func name into the lifted func name
+ String lift_func_name = "lifted_func_" + std::to_string(lift_func_num_++);
+ auto global = GlobalVar(lift_func_name);
+ Array<Var> captured_vars = FreeVars(func);
+ recur_vars_ = CalledGlobalVars(func);
+ auto all_global_vars = AllGlobalVars(func);
+
+ Array<Var> typed_captured_vars;
+ Map<Var, Expr> rebinding_map;
+ for (auto free_var : captured_vars) {
+ Var var = Var(free_var->name_hint(), GetStructInfo(free_var), free_var->span);
+ typed_captured_vars.push_back(var);
+ rebinding_map.Set(free_var, var);
+ }
+
+ // recursive call
+ if (!recur_vars_.empty()) {
+ if (!captured_vars.empty()) {
+ Array<Expr> fvs;
+ for (auto fv : captured_vars) {
+ fvs.push_back(fv);
+ }
+ lambda_map_.emplace(recur_vars_.back(), Call(global, fvs));
+ } else {
+ lambda_map_.emplace(recur_vars_.back(), global);
+ }
+ }
+
+ tvm::Array<Var> params;
+ bool all_params_unchanged = true;
+ for (Var param : func_node->params) {
+ Var new_param = this->VisitVarDef(param);
+ params.push_back(new_param);
+ all_params_unchanged &= param.same_as(new_param);
+ }
+
+ Expr body = this->VisitWithNewScope(func_node->body);
+ Expr visited_func;
+
+ if (all_params_unchanged && body.same_as(func_node->body)) {
+ visited_func = GetRef<Expr>(func_node);
+ } else if (const auto& body_sinfo = MatchStructInfo<ObjectStructInfo>(body)) {
+ visited_func = Function(params, body, body_sinfo.value(), func_node->attrs);
+ } else {
+ visited_func = Function(params, body, func_node->ret_struct_info, func_node->attrs);
+ }
+ auto new_func = Downcast<Function>(visited_func);
+
+ Function lifted_func;
+ bool is_closure = IsClosure(captured_vars);
+ if (!is_closure) {
+ lifted_func = Function(
+ /*params=*/new_func->params,
+ /*body=*/new_func->body,
+ /*ret_struct_info=*/new_func->ret_struct_info,
+ /*attrs=*/new_func->attrs,
+ /*span=*/new_func->span);
+ } else {
+ // Flatten the Closure
+ std::vector<Var> closure_params;
+ closure_params.reserve(func->params.size() + typed_captured_vars.size());
+ for (size_t i = 0; i < func->params.size(); ++i) {
+ closure_params.emplace_back(func->params[i]);
+ }
+ for (size_t i = 0; i < typed_captured_vars.size(); ++i) {
+ closure_params.emplace_back(typed_captured_vars[i]);
+ }
+
+ lifted_func = Function(/*params=*/closure_params,
+ /*body=*/Bind(new_func->body, rebinding_map),
+ /*ret_struct_info=*/new_func->ret_struct_info,
+ /*attrs=*/new_func->attrs,
+ /*span=*/func->span);
+
+ Array<Type> param_types;
+ for (Var param : closure_params) {
+ CHECK(param->checked_type_.defined())
+ << "relax.Function requires params to contain checked_type_";
+ param_types.push_back(param->checked_type_);
+ }
+ }
+
+ ICHECK(lifted_func.defined());
+
+ // Add the lifted function to the module.
+ UpdateStructInfo(global, GetStructInfo(lifted_func));
+ builder_->UpdateFunction(global, lifted_func);
+
+ if (!is_closure) {
+ return std::move(global);
+ } else {
+ // If we need to allocate a closure,
+ // we pass the variables in its environment here.
+ Array<Expr> fvs;
+ for (auto fv : captured_vars) {
+ fvs.push_back(fv);
+ }
+ // Call make_closure intrinsic
+ return Call(make_closure_op_, {global, Tuple(fvs)}, {}, {});
+ }
+ }
+
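+ /*! \brief Check whether var's binding resolves, possibly through intermediate vars, to a relax.make_closure call. */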
+ bool HasClosure(const Var& var) {
+ auto val = builder_->LookupBinding(var);
+ if (const auto* value = val.as<GlobalVarNode>()) {
+ IRModule ctx_mod = builder_->GetContextIRModule();
+ ICHECK(ctx_mod->functions.size() > 0);
+ BaseFunc func = ctx_mod->Lookup(GetRef<GlobalVar>(value));
+ if (const auto* func_node = func.as<FunctionNode>()) {
+ if (const auto* call_node = func_node->body.as<CallNode>()) {
+ if (call_node->op == make_closure_op_) {
+ return true;
+ }
+ } else if (const auto* seq_expr_node = func_node->body.as<SeqExprNode>()) {
+ // the return var points to a make_closure intrinsic
+ if (const auto* var = seq_expr_node->body.as<VarNode>()) {
+ return HasClosure(GetRef<Var>(var));
+ }
+ }
+ }
+ } else if (const auto* func_node = val.as<FunctionNode>()) {
+ if (const auto* call_node = func_node->body.as<CallNode>()) {
+ if (call_node->op == make_closure_op_) {
+ return true;
+ }
+ }
+ } else if (const auto* call_node = val.as<relax::CallNode>()) {
+ // recursive call
+ auto op = call_node->op;
+ if (make_closure_op_ == op) {
+ return true;
+ }
+ if (const auto* lv = op.as<VarNode>()) {
+ return HasClosure(GetRef<Var>(lv));
+ }
+ }
+ return false;
+ }
+
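+ /*! \brief A local function must be lifted as a closure iff it captures free variables. */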
+ bool IsClosure(const Array<Var>& captured_vars) { return captured_vars.size() > 0; }
+
+ IRModule Lift() {
+ auto glob_funcs = mod_->functions;
+ for (auto pair : glob_funcs) {
+ if (auto* n = pair.second.as<FunctionNode>()) {
+ auto func = GetRef<Function>(n);
+ func = Function(func->params, VisitExpr(func->body), func->ret_struct_info, func->attrs);
+ builder_->UpdateFunction(pair.first, func);
+ }
+ }
+ return builder_->GetContextIRModule();
+ }
+
+ private:
+ std::unordered_map<GlobalVar, Expr, ObjectPtrHash, ObjectPtrEqual> lambda_map_;
+ Array<GlobalVar> recur_vars_;
+ IRModule mod_;
+ size_t lift_func_num_ = 0;
+ /*! \brief Cache ops that would be used later to reduce lookup overhead. */
+ const Op& make_closure_op_ = Op::Get("relax.make_closure");
+ const Op& invoke_closure_op_ = Op::Get("relax.invoke_closure");
+};
+
+namespace transform {
+
+Pass LambdaLift() {
+ runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =
+ [=](IRModule m, PassContext pc) { return relax::LambdaLifter(m).Lift(); };
+ return CreateModulePass(pass_func, 1, "LambdaLift", {});
+}
+
+TVM_REGISTER_GLOBAL("relax.transform.LambdaLift").set_body_typed(LambdaLift);
+
+} // namespace transform
+} // namespace relax
+} // namespace tvm
diff --git a/src/relax/utils.cc b/src/relax/utils.cc
index 5846f8116d..24414f250c 100644
--- a/src/relax/utils.cc
+++ b/src/relax/utils.cc
@@ -22,6 +22,51 @@
namespace tvm {
namespace relax {
+/*! \brief Helper to implement bind params. */
+class ExprBinder : public ExprMutator {
+ public:
+ explicit ExprBinder(const tvm::Map<Var, Expr>& args_map) : args_map_(args_map) {}
+
+ Expr VisitExpr_(const VarNode* op) final {
+ auto id = GetRef<Var>(op);
+ auto it = args_map_.find(id);
+ if (it != args_map_.end()) {
+ return (*it).second;
+ } else {
+ return ExprMutator::VisitExpr_(op);
+ }
+ }
+
+ private:
+ const tvm::Map<Var, Expr>& args_map_;
+};
+
+/*!
+ * \brief Bind params on expr
+ * \param expr The expr where to bind params
+ * \param args_map The map from param var to the expr it binds to
+ * \return The result expr after bind params
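+ * For example (illustrative): binding {x: c} on fn(x) { x } yields fn() { c }.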
+ */
+Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& args_map) {
+ if (const FunctionNode* func = expr.as<FunctionNode>()) {
+ Expr new_body = ExprBinder(args_map).VisitExpr(func->body);
+ Array<Var> new_params;
+ for (size_t i = 0; i < func->params.size(); ++i) {
+ if (!args_map.count(func->params[i])) {
+ new_params.push_back(func->params[i]);
+ }
+ }
+ if (new_body.same_as(func->body) && new_params.size() == func->params.size()) {
+ return expr;
+ }
+ // The checked_type_ of the new function is deduced from the function body
+ // TODO(@relax-team): Should infer the shape from the body as well
+ return Function(new_params, new_body, NullOpt, func->attrs);
+ } else {
+ return ExprBinder(args_map).VisitExpr(expr);
+ }
+}
+
bool IsBoolScalarType(const Type& ty, bool permit_unknown_rank, bool permit_unknown_dtype) {
const DynTensorTypeNode* tt = ty.as<DynTensorTypeNode>();
if (!tt) {
diff --git a/tests/python/relax/test_transform_lambda_lift.py b/tests/python/relax/test_transform_lambda_lift.py
new file mode 100644
index 0000000000..fbdb1fbdce
--- /dev/null
+++ b/tests/python/relax/test_transform_lambda_lift.py
@@ -0,0 +1,304 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import pytest
+import tvm
+import tvm.testing
+from tvm import relax
+import tvm.script
+from tvm.script import relax as R, tir as T
+from tvm.relax import transform
+from tvm.ir.base import assert_structural_equal
+
+
+def _check_equal(x, y):
+ tvm.ir.assert_structural_equal(x, y)
+ tvm.ir.assert_structural_equal(y, x)
+
+ xhash = tvm.ir.structural_hash(x, map_free_vars=True)
+ yhash = tvm.ir.structural_hash(y, map_free_vars=True)
+ assert xhash == yhash
+
+
+def _check_save_roundtrip(x):
+ y = tvm.ir.load_json(tvm.ir.save_json(x))
+ _check_equal(x, y)
+
+
+def test_basic():
+ # the target IRModule
+ @tvm.script.ir_module
+ class Expected:
+ @R.function
+ def lifted_func_0(
+ x2: R.Tensor((10, 5), "float32"), y2: R.Tensor((10, 5), "float32")
+ ) -> R.Tensor((10, 5), "float32"):
+ s: R.Tensor((10, 5), "float32") = R.add(x2, y2)
+ return s
+
+ @R.function
+ def main(
+ x1: R.Tensor((10, 5), "float32"), y1: R.Tensor((10, 5), "float32")
+ ) -> R.Tensor((10, 5), "float32"):
+ inner = lifted_func_0
+ gv1: R.Tensor((10, 5), "float32") = inner(x1, y1)
+ return gv1
+
+ @tvm.script.ir_module
+ class Before:
+ @R.function
+ def main(
+ x1: R.Tensor((10, 5), "float32"), y1: R.Tensor((10, 5), "float32")
+ ) -> R.Tensor((10, 5), "float32"):
+ @R.function
+ def inner(
x2: R.Tensor((10, 5), "float32"), y2: R.Tensor((10, 5), "float32")
+ ) -> R.Tensor((10, 5), "float32"):
+ s: R.Tensor((10, 5), "float32") = R.add(x2, y2)
+ return s
+
+ gv1: R.Tensor((10, 5), "float32") = inner(x1, y1)
+ return gv1
+
+ before = Before
+ expected = Expected
+ # Perform Lambda Lifting
+ after = transform.LambdaLift()(before)
+ assert len(after.functions) == 2
+ assert_structural_equal(after, expected, map_free_vars=True)
+ _check_save_roundtrip(after)
+
+
+def test_closure():
+ # the expected IRModule
+ @tvm.script.ir_module
+ class Expected:
+ @R.function
+ def main(
+ x: R.Tensor((2, 3), "float32"), y: R.Tensor((2, 3), "float32")
+ ) -> R.Tensor((2, 3), "float32"):
+ outer_func = lifted_func_0
+ in_call = outer_func(x)
res = R.invoke_closure(in_call, (y,), sinfo_args=(R.Tensor((2, 3), dtype="float32")))
+ return res
+
+ @R.function
def lifted_func_1(x1: R.Tensor((2, 3), "float32"), c1: R.Tensor((2, 3), "float32")):
+ r_1: R.Tensor((2, 3), "float32") = R.add(x1, c1)
+ return r_1
+
+ @R.function
+ def lifted_func_0(y: R.Tensor((2, 3), "float32")) -> R.Object:
+ inner_func = R.make_closure(lifted_func_1, (y,))
+ return inner_func
+
+ # IRModule to perform Lambda Lifting
+ @tvm.script.ir_module
+ class Before:
+ @R.function
+ def main(
+ x: R.Tensor((2, 3), "float32"), y: R.Tensor((2, 3), "float32")
+ ) -> R.Tensor((2, 3), "float32"):
+ @R.function
+ def outer_func(c1: R.Tensor((2, 3), "float32")):
+ @R.function
def inner_func(x1: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"):
+ s: R.Tensor((2, 3), "float32") = R.add(x1, c1)
+ return s
+
+ return inner_func
+
+ in_call = outer_func(x)
+ res = in_call(y)
+ return res
+
+ before = Before
+ after = transform.LambdaLift()(before)
+ expected = Expected
+ assert_structural_equal(after, expected, map_free_vars=True)
+ _check_save_roundtrip(after)
+
+
+@pytest.mark.skip(reason="Need fix after parser switch over")
+def test_recursive():
+ # the expected IRModule
+ @tvm.script.ir_module
+ class Expected:
+ @R.function
+ def lifted_func_0(
+ i: R.Tensor((), "int32"), s: R.Tensor((2, 3), "float32"), x:
R.Tensor((2, 3), "float32")
+ ) -> R.Tensor((2, 3), "float32"):
+ cond: R.Tensor((), "bool") = R.call_packed(
+ "test.vm.less", i, R.const(10), sinfo_args=(R.Tensor((),
dtype="bool"))
+ )
+ c: R.Tensor((), "int32") = R.const(1, dtype="int32")
+ if cond:
+ new_i: R.Tensor((), "int32") = R.add(i, c)
+ new_s: R.Tensor((2, 3), "float32") = R.add(s, x)
+ r = lifted_func_0(new_i, new_s, x)
+ else:
+ r = s
+ return r
+
+ @R.function
+ def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor:
+ while_loop = R.make_closure(lifted_func_0, (x,))
+ gv = R.invoke_closure(
+ while_loop,
+ (relax.const(0), x),
+ sinfo_args=(R.Tensor(ndim=2, dtype="float32")),
+ )
+ return gv
+
+ # the IRModule to apply lambda lifting
+ @tvm.script.ir_module
+ class Before:
+ @R.function
+ def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor:
+ @R.function
+ def while_loop(
+ i: R.Tensor((), "int32"), s: R.Tensor((2, 3), "float32")
+ ) -> R.Tensor((2, 3), "float32"):
+ cond: R.Tensor((), "bool") = R.call_packed(
+ "test.vm.less", i, R.const(10), sinfo_args=(R.Tensor((),
dtype="bool"))
+ )
+ c: R.Tensor((), "int32") = R.const(1, dtype="int32")
+ if cond:
+ new_i: R.Tensor((), "int32") = R.add(i, c)
+ new_s: R.Tensor((2, 3), "float32") = R.add(s, x)
+ r: R.Tensor((2, 3), "float32") = while_loop(new_i, new_s)
+ else:
+ r: R.Tensor((2, 3), "float32") = s
+ return r
+
+ gv: R.Tensor((2, 3), "float32") = while_loop(relax.const(0), x)
+ return gv
+
+ before = Before
+ expected = Expected
+ # Perform Lambda Lifting
+ after = transform.LambdaLift()(before)
+ assert len(after.functions) == 2
+
+ assert_structural_equal(after, expected, map_free_vars=True)
+ _check_save_roundtrip(after)
+
+
+@pytest.mark.skip(reason="Need fix after parser switch over")
+def test_multi_func():
+ # expected IRModule
+ @tvm.script.ir_module
+ class Expected:
+ @R.function
+ def glob_func_1(
+ x1: R.Tensor((10, 5), "float32"), y1: R.Tensor((10, 5), "float32")
+ ) -> R.Tensor(None, "float32", ndim=2):
+ inner = lifted_func_1
+ gv1 = inner(x1, y1)
+ return gv1
+
+ @R.function
+ def glob_func_2(
x11: R.Tensor((10, 5), "float32"), y11: R.Tensor((10, 5), "float32")
+ ) -> R.Tensor(None, "float32", ndim=2):
+ inner1 = lifted_func_0
+ gv11 = inner1(x11, y11)
+ return gv11
+
+ @R.function
+ def lifted_func_0(
+ x2: R.Tensor((10, 5), "float32"), y2: R.Tensor((10, 5), "float32")
+ ) -> R.Tensor(None, "float32", ndim=2):
+ s: R.Tensor((10, 5), "float32") = R.add(x2, y2)
+ return s
+
+ @R.function
+ def lifted_func_1(
x21: R.Tensor((10, 5), "float32"), y21: R.Tensor((10, 5), "float32")
+ ) -> R.Tensor(None, "float32", ndim=2):
+ s1: R.Tensor((10, 5), "float32") = R.add(x21, y21)
+ return s1
+
+ # the IRModule to apply lambda lifting
+ @tvm.script.ir_module
+ class Before:
+ @R.function
+ def glob_func_1(
+ x1: R.Tensor((10, 5), "float32"), y1: R.Tensor((10, 5), "float32")
+ ) -> R.Tensor((10, 5), "float32"):
+ @R.function
+ def inner(
x2: R.Tensor((10, 5), "float32"), y2: R.Tensor((10, 5), "float32")
+ ) -> R.Tensor((10, 5), "float32"):
+ s: R.Tensor((10, 5), "float32") = R.add(x2, y2)
+ return s
+
+ gv1: R.Tensor((10, 5), "float32") = inner(x1, y1)
+ return gv1
+
+ @R.function
+ def glob_func_2(
+ x1: R.Tensor((10, 5), "float32"), y1: R.Tensor((10, 5), "float32")
+ ) -> R.Tensor((10, 5), "float32"):
+ @R.function
+ def inner(
x2: R.Tensor((10, 5), "float32"), y2: R.Tensor((10, 5), "float32")
+ ) -> R.Tensor((10, 5), "float32"):
+ s: R.Tensor((10, 5), "float32") = R.add(x2, y2)
+ return s
+
+ gv1: R.Tensor((10, 5), "float32") = inner(x1, y1)
+ return gv1
+
+ before = Before
+ expected = Expected
+ # Perform Lambda Lifting
+ after = transform.LambdaLift()(before)
+ assert len(after.functions) == 4
+ assert_structural_equal(after, expected, map_free_vars=True)
+ _check_save_roundtrip(after)
+
+
+def test_no_local_func():
+ @tvm.script.ir_module
+ class Before:
+ @T.prim_func
+ def sub(
+ A: T.Buffer[(16, 16), "float32"],
+ B: T.Buffer[(16, 16), "float32"],
+ C: T.Buffer[(16, 16), "float32"],
+ ) -> None:
+ for i, j in T.grid(16, 16):
+ with T.block("sub"):
+ vi, vj = T.axis.remap("SS", [i, j])
+ C[vi, vj] = A[vi, vj] - B[vi, vj]
+
+ @R.function
def before(c0: R.Tensor((16, 16), "float32"), x: R.Tensor(dtype="float32", ndim=2)):
+ s = R.call_tir(sub, (c0, x), R.Tensor((16, 16), dtype="float32"))
+ return s
+
+ before = Before
+ # Perform lambda lifting
+ after = transform.LambdaLift()(before)
+ # No local functions are lifted
+ assert_structural_equal(after, before, map_free_vars=True)
+ _check_save_roundtrip(after)
+
+
+if __name__ == "__main__":
+ tvm.testing.main()