This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new ab43ba8ae4 [Unity][Pass] LambdaLift pass (#14012)
ab43ba8ae4 is described below

commit ab43ba8ae453c09ad19142d013530faaf3bff58c
Author: Yong Wu <[email protected]>
AuthorDate: Thu Feb 16 12:09:47 2023 -0800

    [Unity][Pass] LambdaLift pass (#14012)
---
 include/tvm/relax/analysis.h                     |  57 +++++
 python/tvm/relax/transform/transform.py          |  10 +
 src/relax/analysis/analysis.cc                   | 173 +++++++++++++
 src/relax/transform/lambda_lift.cc               | 266 ++++++++++++++++++++
 src/relax/utils.cc                               |  45 ++++
 tests/python/relax/test_transform_lambda_lift.py | 304 +++++++++++++++++++++++
 6 files changed, 855 insertions(+)

diff --git a/include/tvm/relax/analysis.h b/include/tvm/relax/analysis.h
index a55fe6797d..ff576d4ebb 100644
--- a/include/tvm/relax/analysis.h
+++ b/include/tvm/relax/analysis.h
@@ -260,6 +260,63 @@ TVM_DLL bool IsBaseOf(const StructInfo& base, const StructInfo& derived,
 TVM_DLL StructInfo StructInfoLCA(const StructInfo& lhs, const StructInfo& rhs,
                                  arith::Analyzer* ana = nullptr);
 
+//-----------------------------------
+// General IR analysis
+//-----------------------------------
+/*!
+ * \brief Get all bound variables from expression expr.
+ *
+ * Bound variables are all variables that are declared in the expr.
+ * They only have meaning inside that expr, and can only be used in it.
+ *
+ * \param expr the expression.
+ *
+ * \return List of bound vars, in the PostDFS order in the expression.
+ */
+TVM_DLL tvm::Array<Var> BoundVars(const Expr& expr);
+
+/*!
+ * \brief Get free variables from expression expr.
+ *
+ * Free variables are variables that are not bound by a
+ * varbinding or a function parameter in the context.
+ *
+ * \param expr the expression.
+ *
+ * \return List of free vars, in the PostDFS order in the expression.
+ */
+TVM_DLL tvm::Array<Var> FreeVars(const Expr& expr);
+
+/*!
+ * \brief Get all variables from expression expr.
+ *
+ * AllVars is a superset of BoundVars and FreeVars.
+ * The union of BoundVars and FreeVars is AllVars.
+ *
+ * \param expr the expression.
+ *
+ * \return List of all vars, in the PostDFS order in the expression.
+ */
+TVM_DLL tvm::Array<Var> AllVars(const Expr& expr);
+
+/*!
+ * \brief Get all global variables used in calls in expression expr.
+ *
+ * \param expr the expression.
+ *
+ * \return List of all global variables called in expr.
+ */
+TVM_DLL tvm::Array<GlobalVar> CalledGlobalVars(const Expr& expr);
+
+/*!
+ * \brief Get all global variables from expression expr.
+ *
+ * \param expr the expression.
+ *
+ * \return List of all global variables, in the PostDFS order in the expression.
+ */
+TVM_DLL tvm::Array<GlobalVar> AllGlobalVars(const Expr& expr);
+
 /*!
  * \brief Annotate Op Pattern Kind for PrimFunc, which is used in relax FuseOps.
  *
diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py
index 0f973db290..1a525431dd 100644
--- a/python/tvm/relax/transform/transform.py
+++ b/python/tvm/relax/transform/transform.py
@@ -47,6 +47,16 @@ def ToNonDataflow() -> tvm.ir.transform.Pass:
     return _ffi_api.ToNonDataflow()  # type: ignore
 
 
+def LambdaLift():
+    """A pass that lifts local functions into global.
+
+    Returns
+    -------
+    ret : tvm.ir.transform.Pass
+    """
+    return _ffi_api.LambdaLift()
+
+
 def CallTIRRewrite() -> tvm.ir.transform.Pass:
     """Perform explicit tensor allocation for call_tir.
 
diff --git a/src/relax/analysis/analysis.cc b/src/relax/analysis/analysis.cc
new file mode 100644
index 0000000000..33197308fa
--- /dev/null
+++ b/src/relax/analysis/analysis.cc
@@ -0,0 +1,173 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *
+ * \file analysis.cc
+ *
+ * \brief Analysis functions for Relax.
+ */
+
+#include <tvm/relax/analysis.h>
+#include <tvm/relax/expr_functor.h>
+#include <tvm/tir/expr_functor.h>
+
+namespace tvm {
+namespace relax {
+
+template <typename T>
+struct InsertionSet {
+  std::unordered_set<T, ObjectPtrHash, ObjectPtrEqual> set;
+  std::vector<T> data;
+  void Insert(const T& t) {
+    if (set.count(t) == 0) {
+      set.insert(t);
+      data.push_back(t);
+    }
+  }
+};
+
+class VarVisitor : protected ExprVisitor {
+ public:
+  Array<Var> Free(const Expr& expr) {
+    this->VisitExpr(expr);
+    Array<Var> ret;
+    for (const auto& v : vars_.data) {
+      if (bound_vars_.set.count(v) == 0) {
+        ret.push_back(v);
+      }
+    }
+    return ret;
+  }
+
+  Array<Var> Collect() {
+    Array<Var> ret;
+    for (const auto& v : bound_vars_.data) {
+      ret.push_back(v);
+    }
+    return ret;
+  }
+
+  Array<Var> Bound(const Expr& expr) {
+    this->VisitExpr(expr);
+    return Collect();
+  }
+
+  Array<Var> All(const Expr& expr) {
+    this->VisitExpr(expr);
+    Array<Var> ret;
+    for (const auto& v : vars_.data) {
+      ret.push_back(v);
+    }
+    return ret;
+  }
+
+  Array<GlobalVar> AllGlobalVars(const Expr& expr) {
+    this->VisitExpr(expr);
+    Array<GlobalVar> ret;
+    for (const auto& v : global_vars_.data) {
+      ret.push_back(v);
+    }
+    return ret;
+  }
+
+  Array<GlobalVar> CalledGlobalVars(const Expr& expr) {
+    this->VisitExpr(expr);
+    Array<GlobalVar> ret;
+    for (const auto& v : called_global_vars_.data) {
+      ret.push_back(v);
+    }
+    return ret;
+  }
+
+  void MarkBounded(const Var& v) {
+    bound_vars_.Insert(v);
+    vars_.Insert(v);
+  }
+
+  void VisitExpr_(const VarNode* var) final { vars_.Insert(GetRef<Var>(var)); }
+
+  void VisitExpr_(const FunctionNode* op) final {
+    for (const auto& param : op->params) {
+      MarkBounded(param);
+    }
+    VisitExpr(op->body);
+  }
+
+  void VisitExpr_(const GlobalVarNode* op) final { 
global_vars_.Insert(GetRef<GlobalVar>(op)); }
+
+  void VisitExpr_(const CallNode* call_node) final {
+    VisitSpan(call_node->span);
+    VisitExpr(call_node->op);
+
+    for (StructInfo sinfo_arg : call_node->sinfo_args) {
+      VisitExprDepStructInfoField(sinfo_arg);
+    }
+
+    for (Expr arg : call_node->args) {
+      VisitExpr(arg);
+    }
+
+    if (const GlobalVarNode* global_var_node = 
call_node->op.as<GlobalVarNode>()) {
+      called_global_vars_.Insert(GetRef<GlobalVar>(global_var_node));
+    }
+  }
+
+  void VisitBinding_(const VarBindingNode* binding) final {
+    MarkBounded(binding->var);
+    VisitExpr(binding->value);
+    VisitVarDef(binding->var);
+  }
+
+  void VisitBinding_(const MatchCastNode* binding) final {
+    MarkBounded(binding->var);
+    ExprVisitor::VisitBinding_(binding);
+  }
+
+ private:
+  InsertionSet<Var> vars_;
+  InsertionSet<Var> bound_vars_;
+  InsertionSet<GlobalVar> global_vars_;
+  InsertionSet<GlobalVar> called_global_vars_;
+};
+
+tvm::Array<Var> FreeVars(const Expr& expr) { return VarVisitor().Free(expr); }
+
+tvm::Array<Var> BoundVars(const Expr& expr) { return VarVisitor().Bound(expr); 
}
+
+tvm::Array<Var> AllVars(const Expr& expr) { return VarVisitor().All(expr); }
+
+tvm::Array<GlobalVar> AllGlobalVars(const Expr& expr) { return 
VarVisitor().AllGlobalVars(expr); }
+
+tvm::Array<GlobalVar> CalledGlobalVars(const Expr& expr) {
+  return VarVisitor().CalledGlobalVars(expr);
+}
+
+TVM_REGISTER_GLOBAL("relax.analysis.free_vars").set_body_typed(FreeVars);
+
+TVM_REGISTER_GLOBAL("relax.analysis.bound_vars").set_body_typed(BoundVars);
+
+TVM_REGISTER_GLOBAL("relax.analysis.all_vars").set_body_typed(AllVars);
+
+TVM_REGISTER_GLOBAL("relax.analysis.all_global_vars").set_body_typed(AllGlobalVars);
+
+TVM_REGISTER_GLOBAL("relax.analysis.called_global_vars").set_body_typed(CalledGlobalVars);
+
+}  // namespace relax
+}  // namespace tvm
diff --git a/src/relax/transform/lambda_lift.cc b/src/relax/transform/lambda_lift.cc
new file mode 100644
index 0000000000..f08499036b
--- /dev/null
+++ b/src/relax/transform/lambda_lift.cc
@@ -0,0 +1,266 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/relax/transform/lambda_lift.cc
+ * \brief Lift local functions into global functions.
+ */
+
+#include <tvm/relax/analysis.h>
+#include <tvm/relax/expr.h>
+#include <tvm/relax/expr_functor.h>
+#include <tvm/relax/transform.h>
+#include <tvm/runtime/logging.h>
+
+#include <iostream>
+#include <vector>
+
+namespace tvm {
+namespace relax {
+
+/* The goal of this class is to lift out any nested functions into top-level
+ * functions.
+ *
+ * Each local function is lifted into a global whose params are its own params
+ * followed by its free vars, then replaced by that global or a make_closure.
+ */
+class LambdaLifter : public ExprMutator {
+ public:
+  explicit LambdaLifter(const IRModule& module) : ExprMutator(module) { mod_ = 
module; }
+
+  using ExprMutator::VisitExpr_;
+
+  Expr VisitExpr_(const CallNode* call_node) final {
+    auto call = Downcast<Call>(ExprMutator::VisitExpr_(call_node));
+    if (auto const* var = call_node->op.as<VarNode>()) {
+      bool has_closure = HasClosure(GetRef<Var>(var));
+      auto val = builder_->LookupBinding(GetRef<Var>(var));
+      // Call "relax.invoke_closure" to invoke closure
+      if (has_closure && val.as<CallNode>()) {
+        Var clo_arg = GetRef<Var>(var);
+        if (this->var_remap_.find(var->vid) != this->var_remap_.end()) {
+          clo_arg = this->var_remap_.at(var->vid);
+        }
+        return Call(invoke_closure_op_, {clo_arg, Tuple(call_node->args)}, {},
+                    {GetStructInfo(GetRef<Expr>(call_node))});
+      }
+    }
+    if (auto global_var_node = call_node->op.as<GlobalVarNode>()) {
+      String rec_name = global_var_node->name_hint;
+      auto global_var = GetRef<GlobalVar>(global_var_node);
+      auto it = lambda_map_.find(global_var);
+      if (it != lambda_map_.end()) {
+        // flatten nested call, e.g. call(y)(x) -> call(x, y)
+        Array<relay::Expr> new_args;
+        for (const auto arg : call->args) {
+          new_args.push_back(arg);
+        }
+        if (const auto* nest_call = it->second.as<CallNode>()) {
+          for (const auto arg : nest_call->args) {
+            new_args.push_back(arg);
+          }
+          return Call(nest_call->op, new_args, call_node->attrs, 
call_node->sinfo_args);
+        }
+        return Call(it->second, call->args, call_node->attrs, 
call_node->sinfo_args);
+      }
+    }
+    return std::move(call);
+  }
+
+  Expr VisitExpr_(const FunctionNode* func_node) final {
+    auto func = GetRef<Function>(func_node);
+
+    // TODO(@yongwww): consider appending inner func name into the lifted func name
+    String lift_func_name = "lifted_func_" + std::to_string(lift_func_num_++);
+    auto global = GlobalVar(lift_func_name);
+    Array<Var> captured_vars = FreeVars(func);
+    recur_vars_ = CalledGlobalVars(func);
+    auto all_global_vars = AllGlobalVars(func);
+
+    Array<Var> typed_captured_vars;
+    Map<Var, Expr> rebinding_map;
+    for (auto free_var : captured_vars) {
+      Var var = Var(free_var->name_hint(), GetStructInfo(free_var), 
free_var->span);
+      typed_captured_vars.push_back(var);
+      rebinding_map.Set(free_var, var);
+    }
+
+    // recursive call
+    if (!recur_vars_.empty()) {
+      if (!captured_vars.empty()) {
+        Array<Expr> fvs;
+        for (auto fv : captured_vars) {
+          fvs.push_back(fv);
+        }
+        lambda_map_.emplace(recur_vars_.back(), Call(global, fvs));
+      } else {
+        if (recur_vars_.size() > 0) {
+          lambda_map_.emplace(recur_vars_.back(), global);
+        }
+      }
+    }
+
+    tvm::Array<Var> params;
+    bool all_params_unchanged = true;
+    for (Var param : func_node->params) {
+      Var new_param = this->VisitVarDef(param);
+      params.push_back(new_param);
+      all_params_unchanged &= param.same_as(new_param);
+    }
+
+    Expr body = this->VisitWithNewScope(func_node->body);
+    Expr visited_func;
+
+    if (all_params_unchanged && body.same_as(func_node->body)) {
+      visited_func = GetRef<Expr>(func_node);
+    } else if (const auto& body_sinfo = 
MatchStructInfo<ObjectStructInfo>(body)) {
+      visited_func = Function(params, body, body_sinfo.value(), 
func_node->attrs);
+    } else {
+      visited_func = Function(params, body, func_node->ret_struct_info, 
func_node->attrs);
+    }
+    auto new_func = Downcast<Function>(visited_func);
+
+    Function lifted_func;
+    bool is_closure = IsClosure(captured_vars);
+    if (!is_closure) {
+      lifted_func = Function(
+          /*params=*/new_func->params,
+          /*body=*/new_func->body,
+          /*ret_struct_info=*/new_func->ret_struct_info,
+          /*attrs=*/new_func->attrs,
+          /*span=*/new_func->span);
+    } else {
+      // Flatten the Closure
+      std::vector<Var> closure_params;
+      closure_params.reserve(func->params.size() + typed_captured_vars.size());
+      for (size_t i = 0; i < func->params.size(); ++i) {
+        closure_params.emplace_back(func->params[i]);
+      }
+      for (size_t i = 0; i < typed_captured_vars.size(); ++i) {
+        closure_params.emplace_back(typed_captured_vars[i]);
+      }
+
+      lifted_func = Function(/*params=*/closure_params,
+                             /*body=*/Bind(new_func->body, rebinding_map),
+                             /*ret_struct_info=*/new_func->ret_struct_info,
+                             /*attrs=*/new_func->attrs,
+                             /*span=*/func->span);
+
+      Array<Type> param_types;
+      for (Var param : closure_params) {
+        CHECK(param->checked_type_.defined())
+            << "relax.Function requires params to contain checked_type_";
+        param_types.push_back(param->checked_type_);
+      }
+    }
+
+    ICHECK(lifted_func.defined());
+
+    // Add the lifted function to the module.
+    UpdateStructInfo(global, GetStructInfo(lifted_func));
+    builder_->UpdateFunction(global, lifted_func);
+
+    if (!is_closure) {
+      return std::move(global);
+    } else {
+      // If we need to allocate a closure,
+      // we pass the variables in its environment here.
+      Array<Expr> fvs;
+      for (auto fv : captured_vars) {
+        fvs.push_back(fv);
+      }
+      // Call make_closure intrinsic
+      return Call(make_closure_op_, {global, Tuple(fvs)}, {}, {});
+    }
+  }
+
+  bool HasClosure(const Var& var) {
+    auto val = builder_->LookupBinding(var);
+    if (const auto* value = val.as<GlobalVarNode>()) {
+      IRModule ctx_mod = builder_->GetContextIRModule();
+      ICHECK(ctx_mod->functions.size() > 0);
+      BaseFunc func = ctx_mod->Lookup(GetRef<GlobalVar>(value));
+      if (const auto* func_node = func.as<FunctionNode>()) {
+        if (const auto* call_node = func_node->body.as<CallNode>()) {
+          if (call_node->op == make_closure_op_) {
+            return true;
+          }
+        } else if (const auto* seq_expr_node = 
func_node->body.as<SeqExprNode>()) {
+          // the return var points to a make_closure intrinsic
+          if (const auto* var = seq_expr_node->body.as<VarNode>()) {
+            return HasClosure(GetRef<Var>(var));
+          }
+        }
+      }
+    } else if (const auto* func_node = val.as<FunctionNode>()) {
+      if (const auto* call_node = func_node->body.as<CallNode>()) {
+        if (call_node->op == make_closure_op_) {
+          return true;
+        }
+      }
+    } else if (const auto* call_node = val.as<relax::CallNode>()) {
+      // recursive call
+      auto op = call_node->op;
+      if (make_closure_op_ == op) {
+        return true;
+      }
+      if (const auto* lv = op.as<VarNode>()) {
+        return HasClosure(GetRef<Var>(lv));
+      }
+    }
+    return false;
+  }
+
+  bool IsClosure(const Array<Var>& captured_vars) { return 
captured_vars.size() > 0; }
+
+  IRModule Lift() {
+    auto glob_funcs = mod_->functions;
+    for (auto pair : glob_funcs) {
+      if (auto* n = pair.second.as<FunctionNode>()) {
+        auto func = GetRef<Function>(n);
+        func = Function(func->params, VisitExpr(func->body), 
func->ret_struct_info, func->attrs);
+        builder_->UpdateFunction(pair.first, func);
+      }
+    }
+    return builder_->GetContextIRModule();
+  }
+
+ private:
+  std::unordered_map<GlobalVar, Expr, ObjectPtrHash, ObjectPtrEqual> 
lambda_map_;
+  Array<GlobalVar> recur_vars_;
+  IRModule mod_;
+  size_t lift_func_num_ = 0;
+  /*! \brief Cache ops that would be used later to reduce lookup overhead. */
+  const Op& make_closure_op_ = Op::Get("relax.make_closure");
+  const Op& invoke_closure_op_ = Op::Get("relax.invoke_closure");
+};
+
+namespace transform {
+
+Pass LambdaLift() {
+  runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =
+      [=](IRModule m, PassContext pc) { return relax::LambdaLifter(m).Lift(); 
};
+  return CreateModulePass(pass_func, 1, "LambdaLift", {});
+}
+
+TVM_REGISTER_GLOBAL("relax.transform.LambdaLift").set_body_typed(LambdaLift);
+
+}  // namespace transform
+}  // namespace relax
+}  // namespace tvm
diff --git a/src/relax/utils.cc b/src/relax/utils.cc
index 5846f8116d..24414f250c 100644
--- a/src/relax/utils.cc
+++ b/src/relax/utils.cc
@@ -22,6 +22,51 @@
 namespace tvm {
 namespace relax {
 
+/*! \brief Helper to implement bind params.*/
+class ExprBinder : public ExprMutator {
+ public:
+  explicit ExprBinder(const tvm::Map<Var, Expr>& args_map) : 
args_map_(args_map) {}
+
+  Expr VisitExpr_(const VarNode* op) final {
+    auto id = GetRef<Var>(op);
+    auto it = args_map_.find(id);
+    if (it != args_map_.end()) {
+      return (*it).second;
+    } else {
+      return ExprMutator::VisitExpr_(op);
+    }
+  }
+
+ private:
+  const tvm::Map<Var, Expr>& args_map_;
+};
+
+/*!
+ * \brief Bind params on expr
+ * \param expr The expr where to bind params
+ * \param args_map The map from param var to the expr it binds to
+ * \return The result expr after bind params
+ */
+Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& args_map) {
+  if (const FunctionNode* func = expr.as<FunctionNode>()) {
+    Expr new_body = ExprBinder(args_map).VisitExpr(func->body);
+    Array<Var> new_params;
+    for (size_t i = 0; i < func->params.size(); ++i) {
+      if (!args_map.count(func->params[i])) {
+        new_params.push_back(func->params[i]);
+      }
+    }
+    if (new_body.same_as(func->body) && new_params.size() == 
func->params.size()) {
+      return expr;
+    }
+    // The checked_type_ of the new function is deduced from the function body
+    // TODO(@relax-team): Should infer the shape from the body as well
+    return Function(new_params, new_body, NullOpt, func->attrs);
+  } else {
+    return ExprBinder(args_map).VisitExpr(expr);
+  }
+}
+
 bool IsBoolScalarType(const Type& ty, bool permit_unknown_rank, bool 
permit_unknown_dtype) {
   const DynTensorTypeNode* tt = ty.as<DynTensorTypeNode>();
   if (!tt) {
diff --git a/tests/python/relax/test_transform_lambda_lift.py b/tests/python/relax/test_transform_lambda_lift.py
new file mode 100644
index 0000000000..fbdb1fbdce
--- /dev/null
+++ b/tests/python/relax/test_transform_lambda_lift.py
@@ -0,0 +1,304 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import pytest
+import tvm
+import tvm.testing
+from tvm import relax
+import tvm.script
+from tvm.script import relax as R, tir as T
+from tvm.relax import transform
+from tvm.ir.base import assert_structural_equal
+
+
+def _check_equal(x, y):
+    tvm.ir.assert_structural_equal(x, y)
+    tvm.ir.assert_structural_equal(y, x)
+
+    xhash = tvm.ir.structural_hash(x, map_free_vars=True)
+    yhash = tvm.ir.structural_hash(y, map_free_vars=True)
+    assert xhash == yhash
+
+
+def _check_save_roundtrip(x):
+    y = tvm.ir.load_json(tvm.ir.save_json(x))
+    _check_equal(x, y)
+
+
+def test_basic():
+    # the target IRModule
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def lifted_func_0(
+            x2: R.Tensor((10, 5), "float32"), y2: R.Tensor((10, 5), "float32")
+        ) -> R.Tensor((10, 5), "float32"):
+            s: R.Tensor((10, 5), "float32") = R.add(x2, y2)
+            return s
+
+        @R.function
+        def main(
+            x1: R.Tensor((10, 5), "float32"), y1: R.Tensor((10, 5), "float32")
+        ) -> R.Tensor((10, 5), "float32"):
+            inner = lifted_func_0
+            gv1: R.Tensor((10, 5), "float32") = inner(x1, y1)
+            return gv1
+
+    @tvm.script.ir_module
+    class Before:
+        @R.function
+        def main(
+            x1: R.Tensor((10, 5), "float32"), y1: R.Tensor((10, 5), "float32")
+        ) -> R.Tensor((10, 5), "float32"):
+            @R.function
+            def inner(
+                x2: R.Tensor((10, 5), "float32"), y2: R.Tensor((10, 5), 
"float32")
+            ) -> R.Tensor((10, 5), "float32"):
+                s: R.Tensor((10, 5), "float32") = R.add(x2, y2)
+                return s
+
+            gv1: R.Tensor((10, 5), "float32") = inner(x1, y1)
+            return gv1
+
+    before = Before
+    expected = Expected
+    # Perform Lambda Lifting
+    after = transform.LambdaLift()(before)
+    assert len(after.functions) == 2
+    assert_structural_equal(after, expected, map_free_vars=True)
+    _check_save_roundtrip(after)
+
+
+def test_closure():
+    # the expected IRModule
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor((2, 3), "float32"), y: R.Tensor((2, 3), "float32")
+        ) -> R.Tensor((2, 3), "float32"):
+            outer_func = lifted_func_0
+            in_call = outer_func(x)
+            res = R.invoke_closure(in_call, (y,), sinfo_args=(R.Tensor((2, 3), 
dtype="float32")))
+            return res
+
+        @R.function
+        def lifted_func_1(x1: R.Tensor((2, 3), "float32"), c1: R.Tensor((2, 
3), "float32")):
+            r_1: R.Tensor((2, 3), "float32") = R.add(x1, c1)
+            return r_1
+
+        @R.function
+        def lifted_func_0(y: R.Tensor((2, 3), "float32")) -> R.Object:
+            inner_func = R.make_closure(lifted_func_1, (y,))
+            return inner_func
+
+    # IRModule to perform Lambda Lifting
+    @tvm.script.ir_module
+    class Before:
+        @R.function
+        def main(
+            x: R.Tensor((2, 3), "float32"), y: R.Tensor((2, 3), "float32")
+        ) -> R.Tensor((2, 3), "float32"):
+            @R.function
+            def outer_func(c1: R.Tensor((2, 3), "float32")):
+                @R.function
+                def inner_func(x1: R.Tensor((2, 3), "float32")) -> 
R.Tensor((2, 3), "float32"):
+                    s: R.Tensor((2, 3), "float32") = R.add(x1, c1)
+                    return s
+
+                return inner_func
+
+            in_call = outer_func(x)
+            res = in_call(y)
+            return res
+
+    before = Before
+    after = transform.LambdaLift()(before)
+    expected = Expected
+    assert_structural_equal(after, expected, map_free_vars=True)
+    _check_save_roundtrip(after)
+
+
+@pytest.mark.skip(reason="Need fix after parser switch over")
+def test_recursive():
+    # the expected IRModule
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def lifted_func_0(
+            i: R.Tensor((), "int32"), s: R.Tensor((2, 3), "float32"), x: 
R.Tensor((2, 3), "float32")
+        ) -> R.Tensor((2, 3), "float32"):
+            cond: R.Tensor((), "bool") = R.call_packed(
+                "test.vm.less", i, R.const(10), sinfo_args=(R.Tensor((), 
dtype="bool"))
+            )
+            c: R.Tensor((), "int32") = R.const(1, dtype="int32")
+            if cond:
+                new_i: R.Tensor((), "int32") = R.add(i, c)
+                new_s: R.Tensor((2, 3), "float32") = R.add(s, x)
+                r = lifted_func_0(new_i, new_s, x)
+            else:
+                r = s
+            return r
+
+        @R.function
+        def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor:
+            while_loop = R.make_closure(lifted_func_0, (x,))
+            gv = R.invoke_closure(
+                while_loop,
+                (relax.const(0), x),
+                sinfo_args=(R.Tensor(ndim=2, dtype="float32")),
+            )
+            return gv
+
+    # the IRModule to apply lambda lifting
+    @tvm.script.ir_module
+    class Before:
+        @R.function
+        def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor:
+            @R.function
+            def while_loop(
+                i: R.Tensor((), "int32"), s: R.Tensor((2, 3), "float32")
+            ) -> R.Tensor((2, 3), "float32"):
+                cond: R.Tensor((), "bool") = R.call_packed(
+                    "test.vm.less", i, R.const(10), sinfo_args=(R.Tensor((), 
dtype="bool"))
+                )
+                c: R.Tensor((), "int32") = R.const(1, dtype="int32")
+                if cond:
+                    new_i: R.Tensor((), "int32") = R.add(i, c)
+                    new_s: R.Tensor((2, 3), "float32") = R.add(s, x)
+                    r: R.Tensor((2, 3), "float32") = while_loop(new_i, new_s)
+                else:
+                    r: R.Tensor((2, 3), "float32") = s
+                return r
+
+            gv: R.Tensor((2, 3), "float32") = while_loop(relax.const(0), x)
+            return gv
+
+    before = Before
+    expected = Expected
+    # Perform Lambda Lifting
+    after = transform.LambdaLift()(before)
+    assert len(after.functions) == 2
+
+    assert_structural_equal(after, expected, map_free_vars=True)
+    _check_save_roundtrip(after)
+
+
+@pytest.mark.skip(reason="Need fix after parser switch over")
+def test_multi_func():
+    # expected IRModule
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def glob_func_1(
+            x1: R.Tensor((10, 5), "float32"), y1: R.Tensor((10, 5), "float32")
+        ) -> R.Tensor(None, "float32", ndim=2):
+            inner = lifted_func_1
+            gv1 = inner(x1, y1)
+            return gv1
+
+        @R.function
+        def glob_func_2(
+            x11: R.Tensor((10, 5), "float32"), y11: R.Tensor((10, 5), 
"float32")
+        ) -> R.Tensor(None, "float32", ndim=2):
+            inner1 = lifted_func_0
+            gv11 = inner1(x11, y11)
+            return gv11
+
+        @R.function
+        def lifted_func_0(
+            x2: R.Tensor((10, 5), "float32"), y2: R.Tensor((10, 5), "float32")
+        ) -> R.Tensor(None, "float32", ndim=2):
+            s: R.Tensor((10, 5), "float32") = R.add(x2, y2)
+            return s
+
+        @R.function
+        def lifted_func_1(
+            x21: R.Tensor((10, 5), "float32"), y21: R.Tensor((10, 5), 
"float32")
+        ) -> R.Tensor(None, "float32", ndim=2):
+            s1: R.Tensor((10, 5), "float32") = R.add(x21, y21)
+            return s1
+
+    # the IRModule to apply lambda lifting
+    @tvm.script.ir_module
+    class Before:
+        @R.function
+        def glob_func_1(
+            x1: R.Tensor((10, 5), "float32"), y1: R.Tensor((10, 5), "float32")
+        ) -> R.Tensor((10, 5), "float32"):
+            @R.function
+            def inner(
+                x2: R.Tensor((10, 5), "float32"), y2: R.Tensor((10, 5), 
"float32")
+            ) -> R.Tensor((10, 5), "float32"):
+                s: R.Tensor((10, 5), "float32") = R.add(x2, y2)
+                return s
+
+            gv1: R.Tensor((10, 5), "float32") = inner(x1, y1)
+            return gv1
+
+        @R.function
+        def glob_func_2(
+            x1: R.Tensor((10, 5), "float32"), y1: R.Tensor((10, 5), "float32")
+        ) -> R.Tensor((10, 5), "float32"):
+            @R.function
+            def inner(
+                x2: R.Tensor((10, 5), "float32"), y2: R.Tensor((10, 5), 
"float32")
+            ) -> R.Tensor((10, 5), "float32"):
+                s: R.Tensor((10, 5), "float32") = R.add(x2, y2)
+                return s
+
+            gv1: R.Tensor((10, 5), "float32") = inner(x1, y1)
+            return gv1
+
+    before = Before
+    expected = Expected
+    # Perform Lambda Lifting
+    after = transform.LambdaLift()(before)
+    assert len(after.functions) == 4
+    assert_structural_equal(after, expected, map_free_vars=True)
+    _check_save_roundtrip(after)
+
+
+def test_no_local_func():
+    @tvm.script.ir_module
+    class Before:
+        @T.prim_func
+        def sub(
+            A: T.Buffer[(16, 16), "float32"],
+            B: T.Buffer[(16, 16), "float32"],
+            C: T.Buffer[(16, 16), "float32"],
+        ) -> None:
+            for i, j in T.grid(16, 16):
+                with T.block("sub"):
+                    vi, vj = T.axis.remap("SS", [i, j])
+                    C[vi, vj] = A[vi, vj] - B[vi, vj]
+
+        @R.function
+        def before(c0: R.Tensor((16, 16), "float32"), x: 
R.Tensor(dtype="float32", ndim=2)):
+            s = R.call_tir(sub, (c0, x), R.Tensor((16, 16), dtype="float32"))
+            return s
+
+    before = Before
+    # Perform lambda lifting
+    after = transform.LambdaLift()(before)
+    # No local functions are lifted
+    assert_structural_equal(after, before, map_free_vars=True)
+    _check_save_roundtrip(after)
+
+
+if __name__ == "__main__":
+    tvm.testing.main()

Reply via email to