This is an automated email from the ASF dual-hosted git repository.
wuwei pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new ff6ce9c2b3 Enable Shared Function in LiftTransformParam Pass (#16717)
ff6ce9c2b3 is described below
commit ff6ce9c2b32c4175a30a23f8d19c8f6191615a23
Author: Xiyou Zhou <[email protected]>
AuthorDate: Tue Mar 19 08:45:03 2024 -0700
Enable Shared Function in LiftTransformParam Pass (#16717)
* [WIP] LiftTransformParams for multiple functions
* pass test
* [In-Progress] Define desired behavior for shared LiftTransformParams
Currently, the `relax.transform.LiftTransformParams` pass produces a
separate `transform_params` function for every function in the
`IRModule`. In most cases, the functions in an `IRModule` all accept
the same set of model weights (e.g. `"prefill"` and `"decode"` in a
transformer model). However, the lifted `*_transform_params`
functions may be different for each inference function.
The goal is to introduce a new optional parameter `shared_transform`
for `LiftTransformParams`. If set, a single parameter transformation
function should be generated for the entire `IRModule`, rather than
one parameter transformation function for each original function.
Because the shared parameter transformation function must be
compatible with all existing functions, it should only contain
parameter transformation steps that are common across all input
functions.
* [TIR] Implemented shared lift transform params
* Comments & skip test.
* Linting.
* Avoid c++20 feature to pass CI.
* Remove unused code.
* Fix interface as suggested.
* Fix docs.
* Fix interface as suggested.
* Move code for readability.
---------
Co-authored-by: Wuwei Lin <[email protected]>
Co-authored-by: Eric Lunderberg <[email protected]>
---
include/tvm/relax/transform.h | 14 +-
python/tvm/relax/transform/transform.py | 22 +-
src/relax/transform/lift_transform_params.cc | 638 ++++++++++-----
.../relax/test_transform_lift_transform_params.py | 880 +++++++++++++++++++++
4 files changed, 1358 insertions(+), 196 deletions(-)
diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h
index f3544d8613..82cbf3d12d 100644
--- a/include/tvm/relax/transform.h
+++ b/include/tvm/relax/transform.h
@@ -265,9 +265,21 @@ TVM_DLL Pass RealizeVDevice();
* Users are expected to invoke the `transform_params` function in runtime and
pass the transformed
* parameters to the original function as input.
*
+ * \param shared_transform Indicates how the parameter transformation function
will be produced.
+ * - `False` (default): A separate parameter transformation function will
be produced for each
+ * function with the `"num_input"` attribute.
+ *
+ * - `True`: A single parameter transformation function will be produced,
containing the
+ * preprocessing steps common across all functions with the `"num_input"`
attribute.
+ *
+ * - List[str]: A single parameter transformation function will be
produced, containing the
+ * preprocessing steps common across each function whose name is in the
list. Passing a list of
+ * all functions with the `"num_input"` attribute or an empty list is
equivalent to passing
+ * `True`.
+ *
* \return The Pass.
*/
-TVM_DLL Pass LiftTransformParams();
+TVM_DLL Pass LiftTransformParams(Variant<Bool, Array<String>> shared_transform
= Bool(false));
/*!
* \brief Update virtual device.
diff --git a/python/tvm/relax/transform/transform.py
b/python/tvm/relax/transform/transform.py
index 9ef5133b71..ef10f5791d 100644
--- a/python/tvm/relax/transform/transform.py
+++ b/python/tvm/relax/transform/transform.py
@@ -855,7 +855,7 @@ def MergeCompositeFunctions() -> tvm.ir.transform.Pass:
return _ffi_api.MergeCompositeFunctions() # type: ignore
-def LiftTransformParams() -> tvm.ir.transform.Pass:
+def LiftTransformParams(shared_transform: Union[bool, List[str]] = False) ->
tvm.ir.transform.Pass:
"""Lift transformation of the parameters of a function.
When some inputs of the function is marked as 'parameters' (the model
weights), this pass
@@ -867,12 +867,30 @@ def LiftTransformParams() -> tvm.ir.transform.Pass:
Users are expected to invoke the `transform_params` function in runtime
and pass the transformed
parameters to the original function as input.
+ Parameters
+ ----------
+ shared_transform: Union[bool, List[str]]
+
+    Indicates how the parameter transformation function will be produced.
+
+ - `False` (default): A separate parameter transformation function will
be
+ produced for each function with the `"num_input"` attribute.
+
+ - `True`: A single parameter transformation function will be produced,
+ containing the preprocessing steps common across all functions with
+ the `"num_input"` attribute.
+
+ - List[str]: A single parameter transformation function will be
produced,
+ containing the preprocessing steps common across each function whose
+ name is in the list. Passing a list of all functions with the
`"num_input"`
+ attribute or an empty list is equivalent to passing `True`.
+
Returns
-------
ret : tvm.transform.Pass
The registered pass for lifting transformation of parameters.
"""
- return _ffi_api.LiftTransformParams() # type: ignore
+ return _ffi_api.LiftTransformParams(shared_transform) # type: ignore
def BundleModelParams(param_tuple_name: Optional[str] = None) ->
tvm.ir.transform.Pass:
diff --git a/src/relax/transform/lift_transform_params.cc
b/src/relax/transform/lift_transform_params.cc
index cdf1abc38e..abf21189e4 100644
--- a/src/relax/transform/lift_transform_params.cc
+++ b/src/relax/transform/lift_transform_params.cc
@@ -29,6 +29,7 @@
#include <tvm/runtime/logging.h>
#include <iostream>
+#include <optional>
#include <tuple>
#include <vector>
@@ -42,14 +43,8 @@ constexpr const char* kLiftTransformConsumeParams =
"relax.lift_transform_params
TVM_REGISTER_PASS_CONFIG_OPTION(kLiftTransformConsumeParams, Bool);
namespace {
-
-struct CollectInfo {
- /* \brief The analyzed function */
- Function orig_func;
-
- /* \brief The number of parameters unknown until runtime */
- size_t num_runtime_params;
-
+struct BaseCollectInfo {
+ public:
/*! \brief Bindings that can be lifted out into a pre-processing
*
* - All bindings in `computable_at_compile_time` are suitable for
@@ -74,6 +69,104 @@ struct CollectInfo {
std::unordered_set<Variant<relax::Var, tir::Var>, ObjectPtrHash,
ObjectPtrEqual>
required_at_runtime;
+ protected:
+ Array<Var> GetCompileTimeOutputsHelper(const Array<Var>& params) const {
+ // The output of the compile-time function is in the following order:
+ // 1) Any parameter that is required at runtime in the original order,
followed by,
+ // 2) Any binding that is computable at compile-time and required at
runtime in the original
+ // order.
+ Array<Var> output;
+ for (const auto& param : params) {
+ if (required_at_runtime.count(param)) {
+ output.push_back(param);
+ }
+ }
+ for (const auto& binding : computable_at_compile_time) {
+ if (requires_compile_time_param.count(binding->var) &&
+ required_at_runtime.count(binding->var)) {
+ output.push_back(binding->var);
+ }
+ }
+
+ return output;
+ }
+
+ Function MakeCompileTimeFunctionHelper(const Array<Var> params, const
Array<Binding>& bindings,
+ const Array<tir::Var>&
output_symbolic_vars,
+ const Array<Var>& outputs) const {
+ Array<Binding> output_var_binding;
+ Array<Expr> output_exprs;
+ if (output_symbolic_vars.size()) {
+ output_exprs.push_back(
+ ShapeExpr(output_symbolic_vars.Map([](tir::Var var) -> PrimExpr {
return var; })));
+ }
+
+ for (const auto& var : outputs) {
+ Var out_var(var->name_hint() + "_output", GetStructInfo(var));
+ output_var_binding.push_back(VarBinding(out_var, var));
+ output_exprs.push_back(out_var);
+ }
+
+ Var tuple_var("output_tuple",
TupleStructInfo(output_exprs.Map(GetStructInfo)));
+ output_var_binding.push_back(VarBinding(tuple_var, Tuple(output_exprs)));
+
+ SeqExpr body(
+ {
+ DataflowBlock(bindings),
+ DataflowBlock(output_var_binding),
+ },
+ tuple_var);
+ Function func(params, body, GetStructInfo(tuple_var));
+ func = WithAttr(func, attr::kNumInput, Integer(0));
+ func = CopyWithNewVars(func);
+ func = Downcast<Function>(CanonicalizeBindings(func));
+ return func;
+ }
+};
+
+struct GlobalCollectInfo : public BaseCollectInfo {
+ // The original functions
+ Array<Function> orig_functions;
+ // The parameters of the compile-time function.
+ Array<Var> params;
+ // The cross-function mapping between variables.
+ Map<relax::Var, Expr> var_remap;
+ // The cross-function between between TIR variables.
+ Map<tir::Var, PrimExpr> tir_var_remap;
+ Array<tir::Var> GetPropagatedSymbolicVariables() const {
+ auto vars_from_original_params =
+
DefinableTIRVarsInStructInfo(TupleStructInfo(params.Map(GetStructInfo)));
+ auto vars_from_transformed_params =
+ [&]() -> std::unordered_set<tir::Var, ObjectPtrHash, ObjectPtrEqual> {
+ auto tir_vars =
+
DefinableTIRVarsInStructInfo(TupleStructInfo(GetCompileTimeOutputs().Map(GetStructInfo)));
+ return {tir_vars.begin(), tir_vars.end()};
+ }();
+
+ Array<tir::Var> output;
+ for (const auto& tir_var : vars_from_original_params) {
+ if (required_at_runtime.count(tir_var) &&
!vars_from_transformed_params.count(tir_var)) {
+ output.push_back(tir_var);
+ }
+ }
+ return output;
+ }
+
+ Function MakeCompileTimeFunc() {
+ return MakeCompileTimeFunctionHelper(params, computable_at_compile_time,
+ GetPropagatedSymbolicVariables(),
GetCompileTimeOutputs());
+ }
+ Array<Var> GetCompileTimeOutputs() const { return
GetCompileTimeOutputsHelper(params); }
+};
+struct LocalCollectInfo : public BaseCollectInfo {
+ /* \brief The analyzed function */
+ Function orig_func;
+
+ /* \brief The number of parameters unknown until runtime */
+ size_t num_runtime_params;
+
+ GlobalCollectInfo* global_info = nullptr;
+
Array<Var> GetCompileTimeInputs() const {
return Array<Var>(orig_func->params.begin() + num_runtime_params,
orig_func->params.end());
}
@@ -111,65 +204,13 @@ struct CollectInfo {
}
Array<Var> GetCompileTimeOutputs() const {
- Array<Var> params;
-
- // Any value that is available at compile-time, but is also
- // required at runtime, must be passed through the compile-time
- // function.
- for (size_t i = num_runtime_params; i < orig_func->params.size(); i++) {
- Var var = orig_func->params[i];
- if (required_at_runtime.count(var)) {
- params.push_back(var);
- }
- }
-
- // Any variable that is computed at compile-time, but is required
- // at runtime, must be provided as a parameter.
- for (const auto& binding : computable_at_compile_time) {
- if (requires_compile_time_param.count(binding->var) &&
- required_at_runtime.count(binding->var)) {
- params.push_back(binding->var);
- }
- }
-
- return params;
+ return GetCompileTimeOutputsHelper(GetCompileTimeInputs());
}
Function MakeCompileTimeFunction() const {
- auto compile_time_params = GetCompileTimeInputs();
-
- Array<Binding> output_var_binding;
- Array<Expr> output_exprs;
-
- // Any symbolic variables that are inferrable from compile-time
- // parameters, but are not inferrable from run-time parameters,
- // must be propagated to the output.
- if (auto propagated_tir_vars = GetPropagatedSymbolicVariables();
propagated_tir_vars.size()) {
- output_exprs.push_back(
- ShapeExpr(propagated_tir_vars.Map([](tir::Var var) -> PrimExpr {
return var; })));
- }
-
- for (const auto& var : GetCompileTimeOutputs()) {
- Var out_var(var->name_hint() + "_output", GetStructInfo(var));
- output_var_binding.push_back(VarBinding(out_var, var));
- output_exprs.push_back(out_var);
- }
-
- Var tuple_var("output_tuple",
TupleStructInfo(output_exprs.Map(GetStructInfo)));
- output_var_binding.push_back(VarBinding(tuple_var, Tuple(output_exprs)));
-
- SeqExpr body(
- {
- DataflowBlock(computable_at_compile_time),
- DataflowBlock(output_var_binding),
- },
- tuple_var);
-
- Function func(compile_time_params, body, GetStructInfo(tuple_var));
- func = WithAttr(func, attr::kNumInput, Integer(0));
- func = CopyWithNewVars(func);
- func = Downcast<Function>(CanonicalizeBindings(func));
- return func;
+ ICHECK(!global_info); // This function is only called for local lifting
+ return MakeCompileTimeFunctionHelper(GetCompileTimeInputs(),
computable_at_compile_time,
+ GetPropagatedSymbolicVariables(),
GetCompileTimeOutputs());
}
Function MakeRuntimeFunction() const {
@@ -181,13 +222,64 @@ struct CollectInfo {
// serve as the parameter. This trivial binding will later be
// removed with CanonicalizeBindings.
Array<Var> params = GetRuntimeInputs();
- if (auto propagated_tir_vars = GetPropagatedSymbolicVariables();
propagated_tir_vars.size()) {
+ auto propagated_tir_vars = [&]() {
+ Array<tir::Var> local_tir_vars = GetPropagatedSymbolicVariables();
+ if (!global_info) {
+ return local_tir_vars;
+ }
+ // When global lifting is enabled, the propagated symbolic variables are those of the
shared transform, so the
+ // global TIR variables are mapped back to this function's local TIR variables.
+ Map<tir::Var, tir::Var> reverse_map;
+ for (const auto& var : local_tir_vars) {
+ if (auto it = global_info->tir_var_remap.find(var);
+ it != global_info->tir_var_remap.end()) {
+ reverse_map.Set(Downcast<tir::Var>((*it).second), var);
+ }
+ }
+ Array<tir::Var> global_tir_vars =
global_info->GetPropagatedSymbolicVariables();
+ global_tir_vars = global_tir_vars.Map([&](const tir::Var& var) {
+ if (auto it = reverse_map.find(var); it != reverse_map.end()) {
+ return Downcast<tir::Var>((*it).second);
+ } else {
+ // This happens when some of the outputs of the shared
transform are not used in
+ // this function.
+ return var;
+ }
+ });
+ return global_tir_vars;
+ }();
+ if (propagated_tir_vars.size()) {
ShapeStructInfo shape_sinfo(
propagated_tir_vars.Map([](tir::Var var) -> PrimExpr { return var;
}));
Var shape_expr("vars_from_compile_time_params", shape_sinfo);
params.push_back(shape_expr);
}
- for (const auto& var : GetCompileTimeOutputs()) {
+ Array<Var> compile_time_outputs = [&]() {
+ Array<Var> local_outputs = GetCompileTimeOutputs();
+ if (!global_info) {
+ return local_outputs;
+ }
+ // When global lifting is enabled, the compile-time outputs are the
global outputs, so the
+ // variables in the global outputs are mapped back to the local variables.
+ Map<Var, Var> reverse_map;
+ for (const auto& var : local_outputs) {
+ if (auto it = global_info->var_remap.find(var); it !=
global_info->var_remap.end()) {
+ reverse_map.Set(Downcast<Var>((*it).second), var);
+ }
+ }
+ Array<Var> global_outputs = global_info->GetCompileTimeOutputs();
+ global_outputs = global_outputs.Map([&](const Var& var) {
+ if (auto it = reverse_map.find(var); it != reverse_map.end()) {
+ return Downcast<Var>((*it).second);
+ } else {
+ // This happens when some of the outputs of the shared
transform are not used in
+ // this function.
+ return var;
+ }
+ });
+ return global_outputs;
+ }();
+ for (const auto& var : compile_time_outputs) {
Var param_var(var->name_hint(), GetStructInfo(var));
bindings.push_back(VarBinding(var, param_var));
params.push_back(param_var);
@@ -231,86 +323,111 @@ struct CollectInfo {
body = SeqExpr({DataflowBlock(bindings)}, body);
Function func(params, body, orig_func->ret_struct_info,
orig_func->is_pure, orig_func->attrs);
- func = WithoutAttr(func, tvm::attr::kGlobalSymbol);
func = CopyWithNewVars(func);
+ func = Downcast<Function>(CanonicalizeBindings(func));
return func;
}
+};
- Function MakePartitionedFunction() const {
- Array<Binding> inner_func_bindings;
- Var compile_time_func = [&]() {
- auto func = MakeCompileTimeFunction();
- Var var("transform_params", GetStructInfo(func));
- inner_func_bindings.push_back(VarBinding(var, std::move(func)));
- return var;
- }();
- Var runtime_func = [&]() {
- auto func = MakeRuntimeFunction();
- Var var("runtime", GetStructInfo(func));
- inner_func_bindings.push_back(VarBinding(var, std::move(func)));
- return var;
- }();
+class BaseLiftableBindingCollector : public ExprVisitor {
+ protected:
+ void VisitBindingBlock_(const DataflowBlockNode* block) final {
+ bool cache = is_in_dataflow_block_;
+ is_in_dataflow_block_ = true;
+ ExprVisitor::VisitBindingBlock_(block);
+ is_in_dataflow_block_ = cache;
+ }
- Array<Binding> calling_scope;
+ bool CanLiftBinding(const Binding& binding) const {
+ auto value = GetBoundValue(binding);
- Call compile_time_preprocess(
- compile_time_func, GetCompileTimeInputs().Map([](const Var& var) ->
Expr { return var; }));
+ // Cond 1. Do not lift bindings outside dataflow blocks.
+ if (!is_in_dataflow_block_) {
+ return false;
+ }
- // Use a fresh variable in case it is passed through unmodified in
- // the compile-time function.
- Array<Var> compile_time_outputs;
- if (auto propagated_tir_vars = GetPropagatedSymbolicVariables();
propagated_tir_vars.size()) {
- ShapeStructInfo shape_sinfo(
- propagated_tir_vars.Map([](tir::Var var) -> PrimExpr { return var;
}));
- Var shape_expr("vars_from_compile_time_params", shape_sinfo);
- compile_time_outputs.push_back(shape_expr);
- }
- for (const auto& relax_var : GetCompileTimeOutputs()) {
- compile_time_outputs.push_back(
- Var(relax_var->name_hint(), GetStructInfo(relax_var),
relax_var->span));
- }
- {
- Var tuple_output("compile_time_output",
-
TupleStructInfo(compile_time_outputs.Map(GetStructInfo)));
- calling_scope.push_back(VarBinding(tuple_output,
compile_time_preprocess));
- for (size_t i = 0; i < compile_time_outputs.size(); i++) {
- calling_scope.push_back(VarBinding(compile_time_outputs[i],
TupleGetItem(tuple_output, i)));
+ // Cond 2. Do not lift regarding the "builtin.stop_lift_params" op.
+ if (const auto* call = value.as<CallNode>()) {
+ static const Op& stop_lift_params_op =
Op::Get("relax.builtin.stop_lift_params");
+ if (call->op.same_as(stop_lift_params_op)) {
+ return false;
}
}
- Array<Expr> runtime_args = GetRuntimeInputs().Map([](const Var& var) ->
Expr { return var; });
- for (const auto& var : compile_time_outputs) {
- runtime_args.push_back(var);
+ // Cond 3. Do not lift when involving Vars that are not liftable.
+ for (const auto& var : FreeVars(value)) {
+ if (!liftable_vars_.count(var)) {
+ return false;
+ }
}
- Call runtime_execution(runtime_func, runtime_args);
- Var output_var("output", orig_func->ret_struct_info);
- calling_scope.push_back(VarBinding(output_var, runtime_execution));
+ // Cond 4. Do not lift when its struct info contains symbolic variables
that do not appear in
+ // params.
+ for (const auto& var : TIRVarsInStructInfo(GetStructInfo(binding->var))) {
+ if (!liftable_vars_.count(var)) {
+ return false;
+ }
+ }
- SeqExpr body(
- {
- BindingBlock(inner_func_bindings),
- DataflowBlock(calling_scope),
- },
- output_var);
+ // Cond 5. Do not lift declarations of external functions
+ if (value.as<relax::ExternFuncNode>()) {
+ return false;
+ }
- Function func = orig_func;
- func.CopyOnWrite()->body = body;
- func = Downcast<Function>(CanonicalizeBindings(func));
- return func;
+ return true;
}
+
+ std::unordered_set<Variant<Var, tir::Var>, ObjectPtrHash, ObjectPtrEqual>
liftable_vars_;
+ bool is_in_dataflow_block_{false};
};
-class LiftableBindingCollector : ExprVisitor {
+class LocalLiftableBindingCollector : public BaseLiftableBindingCollector {
public:
- static CollectInfo Collect(const Function& func) {
- LiftableBindingCollector visitor;
+ static LocalCollectInfo Collect(const Function& func, GlobalCollectInfo*
global_info) {
+ LocalLiftableBindingCollector visitor(global_info);
visitor(func);
visitor.info_.orig_func = func;
+
+ auto set_union =
+ [&](std::unordered_set<Variant<relax::Var, tir::Var>, ObjectPtrHash,
ObjectPtrEqual>&
+ target_set,
+ const std::unordered_set<Variant<relax::Var, tir::Var>,
ObjectPtrHash, ObjectPtrEqual>&
+ source_set,
+ const Map<relax::Var, Expr>& var_remap, const Map<tir::Var,
PrimExpr>& tir_var_remap) {
+ // In-place update the set in global info by unioning with the local
set, variable
+ // mappings are applied.
+ for (const auto& relax_or_tir_var : source_set) {
+ if (relax_or_tir_var->IsInstance<relax::VarNode>()) {
+ if (auto it = var_remap.find(Downcast<Var>(relax_or_tir_var));
+ it != var_remap.end()) {
+ target_set.insert(Downcast<relax::Var>((*it).second));
+ } else {
+ target_set.insert(Downcast<relax::Var>(relax_or_tir_var));
+ }
+ } else {
+ if (auto it =
tir_var_remap.find(Downcast<tir::Var>(relax_or_tir_var));
+ it != tir_var_remap.end()) {
+ target_set.insert(Downcast<tir::Var>((*it).second));
+ } else {
+ target_set.insert(Downcast<tir::Var>(relax_or_tir_var));
+ }
+ }
+ }
+ };
+
+ if (global_info) {
+ set_union(global_info->requires_compile_time_param,
visitor.info_.requires_compile_time_param,
+ global_info->var_remap, global_info->tir_var_remap);
+ set_union(global_info->required_at_runtime,
visitor.info_.required_at_runtime,
+ global_info->var_remap, global_info->tir_var_remap);
+ }
return visitor.info_;
}
private:
+ explicit LocalLiftableBindingCollector(GlobalCollectInfo* global_info) {
+ info_.global_info = global_info;
+ }
void VisitExpr_(const FunctionNode* func) override {
size_t num_runtime_params = func->params.size();
if (auto opt = func->attrs.GetAttr<Integer>(attr::kNumInput)) {
@@ -329,17 +446,13 @@ class LiftableBindingCollector : ExprVisitor {
ExprVisitor::VisitExpr_(func);
}
- void VisitBindingBlock_(const DataflowBlockNode* block) final {
- bool cache = is_in_dataflow_block_;
- is_in_dataflow_block_ = true;
- ExprVisitor::VisitBindingBlock_(block);
- is_in_dataflow_block_ = cache;
- }
-
void VisitBinding(const Binding& binding) override {
auto bound_value = GetBoundValue(binding);
- if (CanLiftBinding(binding)) {
+ if (CanLiftBinding(binding) &&
+ (!info_.global_info ||
info_.global_info->var_remap.count(binding->var))) {
+ // The binding is liftable and can be shared with other functions (if
global lifting is
+ // enabled)
info_.computable_at_compile_time.push_back(binding);
liftable_vars_.insert(binding->var);
@@ -388,63 +501,156 @@ class LiftableBindingCollector : ExprVisitor {
}
}
- bool CanLiftBinding(const Binding& binding) const {
- auto value = GetBoundValue(binding);
-
- // Cond 1. Do not lift bindings outside dataflow blocks.
- if (!is_in_dataflow_block_) {
- return false;
- }
+ LocalCollectInfo info_;
+};
- // Cond 2. Do not lift regarding the "builtin.stop_lift_params" op.
- if (const auto* call = value.as<CallNode>()) {
- static const Op& stop_lift_params_op =
Op::Get("relax.builtin.stop_lift_params");
- if (call->op.same_as(stop_lift_params_op)) {
- return false;
+/*! \brief Visitor to find the correspondence between parameters in multiple
functions. */
+class ParamRemapper : private ExprFunctor<void(const Expr&, const Expr&)> {
+ public:
+ static std::pair<Map<Var, Expr>, Map<tir::Var, PrimExpr>> GetParamMapping(
+ const Array<Function>& functions) {
+ ParamRemapper mapper;
+ if (functions.size()) {
+ auto num_inputs_0 =
functions[0]->GetAttr<Integer>(attr::kNumInput).value()->value;
+ int num_params = static_cast<int>(functions[0]->params.size()) -
num_inputs_0;
+ for (int i = 0; i < static_cast<int>(functions.size()); i++) {
+ auto num_inputs_i =
functions[i]->GetAttr<Integer>(attr::kNumInput).value()->value;
+ CHECK_EQ(num_params, static_cast<int>(functions[i]->params.size()) -
num_inputs_i)
+ << "The number of parameters should be the same for all target
functions";
+
+ for (int j = 0; j < num_params; j++) {
+ // Map the parameters to the first function
+ int index_i = j + num_inputs_i;
+ int index_0 = j + num_inputs_0;
+ mapper.VisitExpr(functions[i]->params[index_i],
functions[0]->params[index_0]);
+ StructuralEqual eq;
+ eq(functions[i]->params[index_i]->struct_info_,
+ functions[0]->params[index_0]->struct_info_);
+ }
}
}
+ return {mapper.var_remap_, mapper.tir_var_remap_};
+ }
- // Cond 3. Do not lift when involving Vars that are not liftable.
- for (const auto& var : FreeVars(value)) {
- if (!liftable_vars_.count(var)) {
- return false;
+ private:
+ void VisitExpr_(const VarNode* lhs_var, const Expr& rhs_expr) final {
+ auto rhs_var = Downcast<Var>(rhs_expr);
+ if (auto it = var_remap_.find(GetRef<Var>(lhs_var)); it !=
var_remap_.end()) {
+ CHECK((*it).second.same_as(rhs_var));
+ } else {
+ var_remap_.Set(GetRef<Var>(lhs_var), rhs_var);
+ }
+ CHECK(structural_equal.Equal(lhs_var->struct_info_, rhs_var->struct_info_,
+ /*map_free_vars=*/true))
+ << "The struct info of the parameters should be the same for all
target functions";
+ auto lhs_tir_vars =
DefinableTIRVarsInStructInfo(GetStructInfo(GetRef<Var>(lhs_var)));
+ auto rhs_tir_vars = DefinableTIRVarsInStructInfo(GetStructInfo(rhs_expr));
+ ICHECK_EQ(lhs_tir_vars.size(), rhs_tir_vars.size());
+ for (size_t i = 0; i < lhs_tir_vars.size(); i++) {
+ if (auto it = tir_var_remap_.find(lhs_tir_vars[i]); it !=
tir_var_remap_.end()) {
+ CHECK((*it).second.same_as(rhs_tir_vars[i]));
+ } else {
+ tir_var_remap_.Set(lhs_tir_vars[i], rhs_tir_vars[i]);
}
}
+ }
- // Cond 4. Do not lift when its struct info contains symbolic variables
that do not appear in
- // params.
- for (const auto& var : TIRVarsInStructInfo(GetStructInfo(binding->var))) {
- if (!liftable_vars_.count(var)) {
- return false;
+ SEqualHandlerDefault structural_equal{/*assert_mode=*/false,
/*first_mismatch=*/nullptr,
+ /*defer_fail=*/false};
+ Map<Var, Expr> var_remap_;
+ Map<tir::Var, PrimExpr> tir_var_remap_;
+};
+
+class GlobalLiftableBindingCollector : public BaseLiftableBindingCollector {
+ public:
+ static GlobalCollectInfo Collect(const Array<Function>& functions,
+ const Map<Var, Expr>& var_remap,
+ const Map<tir::Var, PrimExpr>&
tir_var_remap) {
+ GlobalLiftableBindingCollector collector(var_remap, tir_var_remap);
+ ICHECK(functions.size());
+ for (const auto& func : functions) {
+ int num_inputs = func->GetAttr<Integer>(attr::kNumInput).value()->value;
+ for (int i = num_inputs; i < static_cast<int>(func->params.size()); i++)
{
+ collector.liftable_vars_.insert(func->params[i]);
}
+ collector(func);
}
+ Array<Var> params(functions[0]->params.begin() +
+
functions[0]->GetAttr<Integer>(attr::kNumInput).value()->value,
+ functions[0]->params.end());
+ // todo(@tvm-team): use c++20 designated initializers when windows CI
supports it
+ GlobalCollectInfo info = GlobalCollectInfo();
+ info.orig_functions = functions;
+ info.params = std::move(params);
+ info.var_remap = var_remap;
+ info.tir_var_remap = tir_var_remap;
+ // Find shared bindings among transform_params. Re-compute var_remap based
on the shared
+ // bindings as collector.var_remap_ may contain invalid mappings.
+ for (const auto& unified_binding : collector.unified_bindings_) {
+ const auto& original_bindings =
collector.original_bindings_[GetBoundValue(unified_binding)];
+ // Note: it is possible that one or more functions have common
subexpressions such as:
+ //
+ // func1:
+ // w1_t = w.transpose
+ // w2_t = w.transpose
+ //
+ // func2:
+ // w1_t = w.transpose
+ // w2_t = w.transpose
+ //
+ // In this case, original_bindings.size() != functions.size(). Ideally we
should still consider
+ // w and w.transpose as a shared binding, but the size check below currently skips it.
- // Cond 5. Do not lift declarations of external functions
- if (value.as<relax::ExternFuncNode>()) {
- return false;
+ if (original_bindings.size() == functions.size()) {
+ info.computable_at_compile_time.push_back(unified_binding);
+ for (const auto& original_binding : original_bindings) {
+ info.var_remap.Set(original_binding->var, unified_binding->var);
+ }
+ }
}
-
- return true;
+ return info;
}
- CollectInfo info_;
- std::unordered_set<Variant<Var, tir::Var>, ObjectPtrHash, ObjectPtrEqual>
liftable_vars_;
- bool is_in_dataflow_block_{false};
-};
-
-class PreprocessPartitioner : public ExprMutator {
- public:
- using ExprMutator::VisitExpr_;
- Expr VisitExpr_(const FunctionNode* op) override {
- auto func = GetRef<Function>(op);
- if (func->attrs.GetAttr<Integer>(attr::kNumInput)) {
- auto info = LiftableBindingCollector::Collect(func);
- return info.MakePartitionedFunction();
- } else {
- return func;
+ private:
+ GlobalLiftableBindingCollector(const Map<Var, Expr>& var_remap,
+ const Map<tir::Var, PrimExpr> tir_var_remap)
+ : var_remap_(var_remap), tir_var_remap_(tir_var_remap) {}
+ void VisitBinding(const Binding& binding) override {
+ CHECK(!binding->IsInstance<MatchCastNode>()) << "MatchCast is not
supported in global lifting";
+ if (CanLiftBinding(binding)) {
+ liftable_vars_.insert(binding->var);
+ auto bound_value = GetBoundValue(binding);
+ auto new_value = Bind(bound_value, var_remap_, tir_var_remap_);
+ if (auto it = original_bindings_.find(new_value); it !=
original_bindings_.end()) {
+ it->second.push_back(binding);
+ } else {
+ unified_bindings_.push_back(binding);
+ original_bindings_[new_value].push_back(binding);
+ }
+ var_remap_.Set(binding->var, original_bindings_[new_value].front()->var);
}
}
-};
+
+ // The cross-function mapping between variables. This is initialized with
the mapping from the
+ // function parameters, and is updated with the mapping between binding
variables as the collector
+ // visits the bindings.
+ Map<Var, Expr> var_remap_;
+ // The cross-function mapping between TIR variables.
+ Map<tir::Var, PrimExpr> tir_var_remap_;
+ std::vector<Binding> unified_bindings_;
+ // The mapping between the unified bindings and the original bindings in
different functions.
+ // The unified binding is the binding with all variables replaced by the
unified variables as
+ // defined in var_remap_.
+ std::unordered_map<Expr, std::vector<Binding>, StructuralHash,
StructuralEqual>
+ original_bindings_;
+}; // class GlobalLiftableBindingCollector
+
+GlobalCollectInfo MakeGlobalLiftPlan(const IRModule& mod,
+ const std::vector<Function>&
target_functions) {
+ ParamRemapper remapper;
+ auto [var_remap, tir_var_remap] =
ParamRemapper::GetParamMapping(target_functions);
+ return GlobalLiftableBindingCollector::Collect(target_functions, var_remap,
tir_var_remap);
+}
// Adapted from https://stackoverflow.com/a/2072890
inline bool ends_with(const std::string& value, const std::string& ending) {
@@ -494,21 +700,76 @@ class ConsumeBundledParams : public ExprMutator {
std::unordered_map<int, Expr> param_remap_;
};
+std::vector<std::pair<GlobalVar, Function>> GetTargetFunctions(
+ const IRModule& mod, const Variant<Bool, Array<String>>& shared_transform)
{
+ std::vector<std::pair<GlobalVar, Function>> target_functions;
+ if (shared_transform.as<Array<String>>().value_or(Array<String>{}).size()) {
+ for (const auto& name : shared_transform.as<Array<String>>().value()) {
+ auto gvar = mod->GetGlobalVar(name);
+ target_functions.push_back({gvar,
Downcast<Function>(mod->Lookup(gvar))});
+ }
+ } else {
+ // Get all the functions that have the `num_input` attribute.
+ for (const auto& [gvar, func] : mod->functions) {
+ if (func->IsInstance<FunctionNode>()) {
+ auto opt_num_input = func->GetAttr<Integer>(attr::kNumInput);
+ if (opt_num_input) {
+ target_functions.emplace_back(gvar, Downcast<Function>(func));
+ }
+ }
+ }
+ std::sort(target_functions.begin(), target_functions.end(),
+ [](const auto& lhs, const auto& rhs) {
+ return lhs.first->name_hint < rhs.first->name_hint;
+ });
+ }
+ return target_functions;
+}
+
} // namespace
namespace transform {
-Pass PartitionTransformParams() {
+Pass PartitionTransformParams(Variant<Bool, Array<String>> shared_transform) {
auto pass_func = [=](IRModule mod, PassContext pc) {
- PreprocessPartitioner mutator;
-
IRModule updates;
- for (const auto& [gvar, func] : mod->functions) {
- if (auto opt = func.as<relax::Function>()) {
- auto new_func = Downcast<Function>(mutator(opt.value()));
- if (!new_func.same_as(func)) {
- updates->Add(gvar, new_func);
- }
+ std::optional<GlobalCollectInfo> global_collect_info;
+
+ CHECK(shared_transform.defined()) << "shared_transform is not defined";
+ CHECK((shared_transform.as<Bool>() ||
shared_transform.as<Array<String>>()))
+ << "shared_transform should be a boolean or an array of function
names";
+
+ auto target_functions = GetTargetFunctions(mod, shared_transform);
+
+ if (shared_transform.as<Bool>().value_or(Bool(true))) {
+ std::vector<Function> functions;
+ for (const auto& [_, func] : target_functions) {
+ functions.push_back(func);
+ }
+ global_collect_info = MakeGlobalLiftPlan(mod, functions);
+ }
+
+ std::unordered_map<GlobalVar, LocalCollectInfo, ObjectPtrHash,
ObjectPtrEqual>
+ local_collect_info;
+ for (const auto& [gvar, func] : target_functions) {
+ auto info = LocalLiftableBindingCollector::Collect(
+ func, global_collect_info.has_value() ? &global_collect_info.value()
: nullptr);
+ local_collect_info[gvar] = info;
+ }
+
+ for (const auto& [gvar, info] : local_collect_info) {
+ auto new_runtime_func = info.MakeRuntimeFunction();
+ updates->Add(gvar, new_runtime_func);
+ }
+
+ if (global_collect_info.has_value()) {
+ auto global_transform =
global_collect_info.value().MakeCompileTimeFunc();
+ updates->Add(GlobalVar("transform_params"), global_transform);
+ } else {
+ for (const auto& [gvar, info] : local_collect_info) {
+ // transform_params is emitted for each function if global lifting is
not enabled
+ updates->Add(GlobalVar(gvar->name_hint + "_transform_params"),
+ info.MakeCompileTimeFunction());
}
}
@@ -521,7 +782,7 @@ Pass PartitionTransformParams() {
return tvm::transform::CreateModulePass(pass_func, 1,
"PartitionTransformParams", {});
}
-Pass LiftTransformParams() {
+Pass LiftTransformParams(Variant<Bool, Array<String>> shared_transform) {
// A post-proc utility as as the third step in LiftTransformParams
//
// 1. PartitionTransformParams: Partition each function into a
@@ -533,7 +794,6 @@ Pass LiftTransformParams() {
// 3. Post-proc: Expose the compile-time and run-time functions for
// external use, replacing the end-to-end functions.
auto post_proc_func = [=](IRModule mod, PassContext pc) {
- std::unordered_set<GlobalVar, ObjectPtrHash, ObjectPtrEqual> to_remove;
std::unordered_map<GlobalVar, Function, ObjectPtrHash, ObjectPtrEqual>
to_add;
for (const auto& [gvar, base_func] : mod->functions) {
if (auto opt = base_func.as<Function>()) {
@@ -547,20 +807,12 @@ Pass LiftTransformParams() {
func = Downcast<Function>(ConsumeBundledParams()(func));
}
to_add[gvar] = func;
- } else if (ends_with(func_name, "_runtime")) {
- std::string name(func_name.begin(), func_name.end() -
sizeof("_runtime") + 1);
- to_remove.insert(mod->GetGlobalVar(name));
- to_remove.insert(gvar);
- to_add[GlobalVar(name)] = WithAttr(func, tvm::attr::kGlobalSymbol,
String(name));
}
}
}
- if (to_remove.size() || to_add.size()) {
+ if (to_add.size()) {
auto write_ptr = mod.CopyOnWrite();
- for (const auto& gvar : to_remove) {
- write_ptr->Remove(gvar);
- }
for (const auto& [gvar, func] : to_add) {
write_ptr->Add(gvar, func);
}
@@ -573,7 +825,7 @@ Pass LiftTransformParams() {
return tvm::transform::Sequential(
{
- PartitionTransformParams(),
+ PartitionTransformParams(shared_transform),
LambdaLift(),
post_proc,
},
diff --git a/tests/python/relax/test_transform_lift_transform_params.py b/tests/python/relax/test_transform_lift_transform_params.py
index 80de52ca66..508664f1ef 100644
--- a/tests/python/relax/test_transform_lift_transform_params.py
+++ b/tests/python/relax/test_transform_lift_transform_params.py
@@ -482,6 +482,886 @@ def test_multiple_functions():
tvm.ir.assert_structural_equal(after, Expected)
+def test_share_identical_transform_across_multiple_functions():
+    """Like test_multiple_functions, but producing a single transform_params
+
+    `func1` and `func2` contain the same values `w1_t` and `w2_t`.
+    When `shared_transform=True`, all eligible publicly-exposed
+    functions must be usable with the same shared transform.
+    """
+
+    @I.ir_module
+    class Before:
+        @R.function
+        def func1(
+            x: R.Tensor((256, 256), "float32"),
+            w1: R.Tensor((256, 256), "float32"),
+            w2: R.Tensor((256, 256), "float32"),
+        ) -> R.Tensor((256, 256), "float32"):
+            R.func_attr({"num_input": 1})
+            with R.dataflow():
+                w1_t = R.permute_dims(w1)
+                y1 = R.matmul(x, w1_t)
+                w2_t = R.permute_dims(w2)
+                y2 = R.matmul(x, w2_t)
+                output = R.add(y1, y2)
+                R.output(output)
+            return output
+
+        @R.function
+        def func2(
+            x: R.Tensor((256, 256), "float32"),
+            w1: R.Tensor((256, 256), "float32"),
+            w2: R.Tensor((256, 256), "float32"),
+        ) -> R.Tensor((256, 256), "float32"):
+            R.func_attr({"num_input": 1})
+            with R.dataflow():
+                w1_t = R.permute_dims(w1)
+                y1 = R.matmul(x, w1_t)
+                w2_t = R.permute_dims(w2)
+                y2 = R.matmul(x, w2_t)
+                output = R.multiply(y1, y2)
+                R.output(output)
+            return output
+
+    # Both permutations move into the single shared `transform_params`;
+    # `func1` and `func2` then receive the pre-transposed weights directly.
+    @I.ir_module
+    class Expected:
+        @R.function
+        def transform_params(
+            params: R.Tuple(
+                R.Tensor((256, 256), dtype="float32"),
+                R.Tensor((256, 256), dtype="float32"),
+            )
+        ):
+            R.func_attr({"num_input": 0})
+            with R.dataflow():
+                w1 = params[0]
+                w1_t = R.permute_dims(w1)
+                w2 = params[1]
+                w2_t = R.permute_dims(w2)
+                output = (w1_t, w2_t)
+                R.output(output)
+            return output
+
+        @R.function
+        def func1(
+            x: R.Tensor((256, 256), "float32"),
+            w1_t: R.Tensor((256, 256), "float32"),
+            w2_t: R.Tensor((256, 256), "float32"),
+        ) -> R.Tensor((256, 256), "float32"):
+            R.func_attr({"num_input": 1})
+            with R.dataflow():
+                y1 = R.matmul(x, w1_t)
+                y2 = R.matmul(x, w2_t)
+                output = R.add(y1, y2)
+                R.output(output)
+            return output
+
+        @R.function
+        def func2(
+            x: R.Tensor((256, 256), "float32"),
+            w1_t: R.Tensor((256, 256), "float32"),
+            w2_t: R.Tensor((256, 256), "float32"),
+        ) -> R.Tensor((256, 256), "float32"):
+            R.func_attr({"num_input": 1})
+            with R.dataflow():
+                y1 = R.matmul(x, w1_t)
+                y2 = R.matmul(x, w2_t)
+                output = R.multiply(y1, y2)
+                R.output(output)
+            return output
+
+    after = relax.transform.LiftTransformParams(shared_transform=True)(Before)
+    tvm.ir.assert_structural_equal(after, Expected)
+
+
+def test_incompatible_weights_in_shared_transform_raises_error():
+    """Model weights must have a matching count for shared_transform
+
+    Here, `func1` accepts one model weight, but `func2` accepts two.
+    """
+
+    @I.ir_module
+    class Before:
+        @R.function
+        def func1(
+            x: R.Tensor((256, 256), "float32"),
+            w1: R.Tensor((256, 256), "float32"),
+        ) -> R.Tensor((256, 256), "float32"):
+            R.func_attr({"num_input": 1})
+            with R.dataflow():
+                w1_t = R.permute_dims(w1)
+                y1 = R.matmul(x, w1_t)
+                output = y1
+                R.output(output)
+            return output
+
+        @R.function
+        def func2(
+            x: R.Tensor((256, 256), "float32"),
+            w1: R.Tensor((256, 256), "float32"),
+            w2: R.Tensor((256, 256), "float32"),
+        ) -> R.Tensor((256, 256), "float32"):
+            R.func_attr({"num_input": 1})
+            with R.dataflow():
+                w1_t = R.permute_dims(w1)
+                y1 = R.matmul(x, w1_t)
+                w2_t = R.permute_dims(w2)
+                y2 = R.matmul(x, w2_t)
+                output = R.multiply(y1, y2)
+                R.output(output)
+            return output
+
+    # The two functions take different numbers of weights, so no single
+    # shared transform can serve both and the pass is expected to raise.
+    with pytest.raises(tvm.TVMError):
+        relax.transform.LiftTransformParams(shared_transform=True)(Before)
+
+
+def test_incompatible_shape_in_shared_transform_raises_error():
+    """Model weights must have matched shape for shared_transform
+
+    Here, `func1` accepts `w1` and `w2` with shape `[256, 256]`, but `func2`
+    requires shape `[128, 256]`.
+    """
+
+    @I.ir_module
+    class Before:
+        @R.function
+        def func1(
+            x: R.Tensor((256, 256), "float32"),
+            w1: R.Tensor((256, 256), "float32"),
+            w2: R.Tensor((256, 256), "float32"),
+        ) -> R.Tensor((256, 256), "float32"):
+            R.func_attr({"num_input": 1})
+            with R.dataflow():
+                w1_t = R.permute_dims(w1)
+                y1 = R.matmul(x, w1_t)
+                w2_t = R.permute_dims(w2)
+                y2 = R.matmul(x, w2_t)
+                output = R.add(y1, y2)
+                R.output(output)
+            return output
+
+        @R.function
+        def func2(
+            x: R.Tensor((256, 256), "float32"),
+            w1: R.Tensor((128, 256), "float32"),
+            w2: R.Tensor((128, 256), "float32"),
+        ) -> R.Tensor((256, 128), "float32"):
+            R.func_attr({"num_input": 1})
+            with R.dataflow():
+                w1_t = R.permute_dims(w1)
+                y1 = R.matmul(x, w1_t)
+                w2_t = R.permute_dims(w2)
+                y2 = R.matmul(x, w2_t)
+                output = R.multiply(y1, y2)
+                R.output(output)
+            return output
+
+    # Same weight count and dtype, but the weight shapes differ between the
+    # two functions, so the pass is expected to raise.
+    with pytest.raises(tvm.TVMError):
+        relax.transform.LiftTransformParams(shared_transform=True)(Before)
+
+
+def test_incompatible_dtype_in_shared_transform_raises_error():
+    """Model weights must have matched dtype for shared_transform
+
+    Here, `func1` accepts `w1` and `w2` with "float32" dtype, but
+    `func2` requires "float16".  The shapes are identical in both
+    functions, so the dtype is the only difference between the weights.
+    """
+
+    @I.ir_module
+    class Before:
+        @R.function
+        def func1(
+            x: R.Tensor((256, 256), "float32"),
+            w1: R.Tensor((256, 256), "float32"),
+            w2: R.Tensor((256, 256), "float32"),
+        ) -> R.Tensor((256, 256), "float32"):
+            R.func_attr({"num_input": 1})
+            with R.dataflow():
+                w1_t = R.permute_dims(w1)
+                y1 = R.matmul(x, w1_t)
+                w2_t = R.permute_dims(w2)
+                y2 = R.matmul(x, w2_t)
+                output = R.add(y1, y2)
+                R.output(output)
+            return output
+
+        # Shapes deliberately match func1: if they also differed, a shape
+        # check alone would satisfy pytest.raises and the dtype check would
+        # go untested.
+        @R.function
+        def func2(
+            x: R.Tensor((256, 256), "float16"),
+            w1: R.Tensor((256, 256), "float16"),
+            w2: R.Tensor((256, 256), "float16"),
+        ) -> R.Tensor((256, 256), "float16"):
+            R.func_attr({"num_input": 1})
+            with R.dataflow():
+                w1_t = R.permute_dims(w1)
+                y1 = R.matmul(x, w1_t)
+                w2_t = R.permute_dims(w2)
+                y2 = R.matmul(x, w2_t)
+                output = R.multiply(y1, y2)
+                R.output(output)
+            return output
+
+    with pytest.raises(tvm.TVMError):
+        relax.transform.LiftTransformParams(shared_transform=True)(Before)
+
+
+def test_share_transform_across_multiple_functions_has_intersection_of_transforms():
+    """Like test_multiple_functions, but producing a single transform_params
+
+    In `func1`, both `w1_t` and `w2_t` could be lifted out.  In
+    `func2`, only `w1_t` could be lifted out of the function.
+    Therefore, the shared `transform_params` can pre-compute `w1_t`,
+    but must preserve `w2`.
+
+    When `shared_transform=True`, all eligible publicly-exposed
+    functions must be usable with the same shared transform.
+    """
+
+    @I.ir_module
+    class Before:
+        @R.function
+        def func1(
+            x: R.Tensor((256, 256), "float32"),
+            w1: R.Tensor((256, 256), "float32"),
+            w2: R.Tensor((256, 256), "float32"),
+        ) -> R.Tensor((256, 256), "float32"):
+            R.func_attr({"num_input": 1})
+            with R.dataflow():
+                w1_t = R.permute_dims(w1)
+                y1 = R.matmul(x, w1_t)
+                w2_t = R.permute_dims(w2)
+                y2 = R.matmul(x, w2_t)
+                output = R.add(y1, y2)
+                R.output(output)
+            return output
+
+        @R.function
+        def func2(
+            x: R.Tensor((256, 256), "float32"),
+            w1: R.Tensor((256, 256), "float32"),
+            w2: R.Tensor((256, 256), "float32"),
+        ) -> R.Tensor((256, 256), "float32"):
+            R.func_attr({"num_input": 1})
+            with R.dataflow():
+                w1_t = R.permute_dims(w1)
+                y1 = R.matmul(x, w1_t)
+                # The transpose of w2 is hidden inside a fused subroutine,
+                # so it is not a liftable binding in this function.
+                y2 = Before.fused_permute_dims_matmul(x, w2)
+                output = R.add(y1, y2)
+                R.output(output)
+            return output
+
+        @R.function(private=True)
+        def fused_permute_dims_matmul(
+            x: R.Tensor((256, 256), "float32"),
+            weight: R.Tensor((256, 256), "float32"),
+        ) -> R.Tensor((256, 256), "float32"):
+            with R.dataflow():
+                weight_t = R.permute_dims(weight)
+                y = R.matmul(x, weight_t)
+                R.output(y)
+            return y
+
+    # Only `w1_t` is lifted; `w2` passes through `transform_params`
+    # unchanged, and `func1` still computes `w2_t` itself.
+    @I.ir_module
+    class Expected:
+        @R.function
+        def transform_params(
+            params: R.Tuple(
+                R.Tensor((256, 256), dtype="float32"),
+                R.Tensor((256, 256), dtype="float32"),
+            )
+        ):
+            R.func_attr({"num_input": 0})
+            with R.dataflow():
+                w1 = params[0]
+                w1_t = R.permute_dims(w1)
+                w2 = params[1]
+                output = (w2, w1_t)
+                R.output(output)
+            return output
+
+        @R.function
+        def func1(
+            x: R.Tensor((256, 256), "float32"),
+            w2: R.Tensor((256, 256), "float32"),
+            w1_t: R.Tensor((256, 256), "float32"),
+        ) -> R.Tensor((256, 256), "float32"):
+            R.func_attr({"num_input": 1})
+            with R.dataflow():
+                y1 = R.matmul(x, w1_t)
+                w2_t = R.permute_dims(w2)
+                y2 = R.matmul(x, w2_t)
+                output = R.add(y1, y2)
+                R.output(output)
+            return output
+
+        @R.function
+        def func2(
+            x: R.Tensor((256, 256), "float32"),
+            w2: R.Tensor((256, 256), "float32"),
+            w1_t: R.Tensor((256, 256), "float32"),
+        ) -> R.Tensor((256, 256), "float32"):
+            R.func_attr({"num_input": 1})
+            with R.dataflow():
+                y1 = R.matmul(x, w1_t)
+                y2 = Expected.fused_permute_dims_matmul(x, w2)
+                output = R.add(y1, y2)
+                R.output(output)
+            return output
+
+        @R.function(private=True)
+        def fused_permute_dims_matmul(
+            x: R.Tensor((256, 256), "float32"),
+            weight: R.Tensor((256, 256), "float32"),
+        ) -> R.Tensor((256, 256), "float32"):
+            with R.dataflow():
+                weight_t = R.permute_dims(weight)
+                y = R.matmul(x, weight_t)
+                R.output(y)
+            return y
+
+    after = relax.transform.LiftTransformParams(shared_transform=True)(Before)
+    tvm.ir.assert_structural_equal(after, Expected)
+
+
+def test_share_transforms_with_different_binding_order():
+    """Like test_share_identical_transform_across_multiple_functions, but the
+    lifted bindings are in different order for each function.
+
+    Both `func1` and `func2` compute the same value for `w1_t` and
+    `w2_t`.  However, the bindings occur in different orders.  The
+    shared `transform_params` can pre-compute both `w1_t` and `w2_t`,
+    even though they occur in different orders.
+
+    For consistency in testing and pre-computing weights, the order of
+    `transform_params` should be deterministic.  When lifting from a
+    single function, the bindings in `transform_params` may be
+    determined from the order in that function.  When lifting from
+    multiple functions, the order should be deterministic.  Since
+    `IRModule::functions` has unspecified order, the order in this
+    test assumes that public functions are visited in alphabetical
+    order by name.
+    """
+
+    @I.ir_module
+    class Before:
+        @R.function
+        def func1(
+            x: R.Tensor((256, 256), "float32"),
+            w1: R.Tensor((256, 256), "float32"),
+            w2: R.Tensor((256, 256), "float32"),
+        ) -> R.Tensor((256, 256), "float32"):
+            R.func_attr({"num_input": 1})
+            with R.dataflow():
+                w2_t = R.permute_dims(w2)
+                w1_t = R.permute_dims(w1)
+                y1 = R.matmul(x, w1_t)
+                y2 = R.matmul(x, w2_t)
+                output = R.add(y1, y2)
+                R.output(output)
+            return output
+
+        @R.function
+        def func2(
+            x: R.Tensor((256, 256), "float32"),
+            w1: R.Tensor((256, 256), "float32"),
+            w2: R.Tensor((256, 256), "float32"),
+        ) -> R.Tensor((256, 256), "float32"):
+            R.func_attr({"num_input": 1})
+            with R.dataflow():
+                w1_t = R.permute_dims(w1)
+                y1 = R.matmul(x, w1_t)
+                w2_t = R.permute_dims(w2)
+                y2 = R.matmul(x, w2_t)
+                output = R.multiply(y1, y2)
+                R.output(output)
+            return output
+
+    # The lifted bindings follow `func1`'s order (w2_t before w1_t), since
+    # `func1` is visited first alphabetically.
+    @I.ir_module
+    class Expected:
+        @R.function
+        def transform_params(
+            params: R.Tuple(
+                R.Tensor((256, 256), dtype="float32"),
+                R.Tensor((256, 256), dtype="float32"),
+            )
+        ):
+            R.func_attr({"num_input": 0})
+            with R.dataflow():
+                w2 = params[1]
+                w2_t = R.permute_dims(w2)
+                w1 = params[0]
+                w1_t = R.permute_dims(w1)
+
+                output = (w2_t, w1_t)
+                R.output(output)
+            return output
+
+        @R.function
+        def func1(
+            x: R.Tensor((256, 256), "float32"),
+            w2_t: R.Tensor((256, 256), "float32"),
+            w1_t: R.Tensor((256, 256), "float32"),
+        ) -> R.Tensor((256, 256), "float32"):
+            R.func_attr({"num_input": 1})
+            with R.dataflow():
+                y1 = R.matmul(x, w1_t)
+                y2 = R.matmul(x, w2_t)
+                output = R.add(y1, y2)
+                R.output(output)
+            return output
+
+        @R.function
+        def func2(
+            x: R.Tensor((256, 256), "float32"),
+            w2_t: R.Tensor((256, 256), "float32"),
+            w1_t: R.Tensor((256, 256), "float32"),
+        ) -> R.Tensor((256, 256), "float32"):
+            R.func_attr({"num_input": 1})
+            with R.dataflow():
+                y1 = R.matmul(x, w1_t)
+                y2 = R.matmul(x, w2_t)
+                output = R.multiply(y1, y2)
+                R.output(output)
+            return output
+
+    after = relax.transform.LiftTransformParams(shared_transform=True)(Before)
+    tvm.ir.assert_structural_equal(after, Expected)
+
+
+def test_share_transforms_resulting_in_identical_functions():
+    """Functions in the public interface must be preserved
+
+    When lifting functions, the resulting functions may be identical.
+    Even though the `relax.BlockBuilder` de-duplicates identical
+    functions, functions that are part of the IRModule's public
+    interface must be preserved.
+    """
+
+    @I.ir_module
+    class Before:
+        @R.function
+        def func1(
+            x: R.Tensor((256, 256), "float32"),
+            w1: R.Tensor((256, 256), "float32"),
+            w2: R.Tensor((256, 256), "float32"),
+        ) -> R.Tensor((256, 256), "float32"):
+            R.func_attr({"num_input": 1})
+            with R.dataflow():
+                w2_t = R.permute_dims(w2)
+                w1_t = R.permute_dims(w1)
+                y1 = R.matmul(x, w1_t)
+                y2 = R.matmul(x, w2_t)
+                output = R.add(y1, y2)
+                R.output(output)
+            return output
+
+        @R.function
+        def func2(
+            x: R.Tensor((256, 256), "float32"),
+            w1: R.Tensor((256, 256), "float32"),
+            w2: R.Tensor((256, 256), "float32"),
+        ) -> R.Tensor((256, 256), "float32"):
+            R.func_attr({"num_input": 1})
+            with R.dataflow():
+                w1_t = R.permute_dims(w1)
+                y1 = R.matmul(x, w1_t)
+                w2_t = R.permute_dims(w2)
+                y2 = R.matmul(x, w2_t)
+                output = R.add(y1, y2)
+                R.output(output)
+            return output
+
+    # After lifting, `func1` and `func2` have identical bodies, yet both
+    # must still be present under their own names.
+    @I.ir_module
+    class Expected:
+        @R.function
+        def transform_params(
+            params: R.Tuple(
+                R.Tensor((256, 256), dtype="float32"),
+                R.Tensor((256, 256), dtype="float32"),
+            )
+        ):
+            R.func_attr({"num_input": 0})
+            with R.dataflow():
+                w2 = params[1]
+                w2_t = R.permute_dims(w2)
+                w1 = params[0]
+                w1_t = R.permute_dims(w1)
+                output = (w2_t, w1_t)
+                R.output(output)
+            return output
+
+        @R.function
+        def func1(
+            x: R.Tensor((256, 256), "float32"),
+            w2_t: R.Tensor((256, 256), "float32"),
+            w1_t: R.Tensor((256, 256), "float32"),
+        ) -> R.Tensor((256, 256), "float32"):
+            R.func_attr({"num_input": 1})
+            with R.dataflow():
+                y1 = R.matmul(x, w1_t)
+                y2 = R.matmul(x, w2_t)
+                output = R.add(y1, y2)
+                R.output(output)
+            return output
+
+        @R.function
+        def func2(
+            x: R.Tensor((256, 256), "float32"),
+            w2_t: R.Tensor((256, 256), "float32"),
+            w1_t: R.Tensor((256, 256), "float32"),
+        ) -> R.Tensor((256, 256), "float32"):
+            R.func_attr({"num_input": 1})
+            with R.dataflow():
+                y1 = R.matmul(x, w1_t)
+                y2 = R.matmul(x, w2_t)
+                output = R.add(y1, y2)
+                R.output(output)
+            return output
+
+    after = relax.transform.LiftTransformParams(shared_transform=True)(Before)
+    tvm.ir.assert_structural_equal(after, Expected)
+
+
+def test_share_transform_across_specified_functions():
+    """Like test_multiple_functions, but producing a single transform_params
+
+    In `func1`, both `w1_t` and `w2_t` could be lifted out.  In
+    `func2`, only `w1_t` could be lifted out of the function.
+    Therefore, the shared `transform_params` can pre-compute `w1_t`,
+    but must preserve `w2`.
+
+    If `func3` were included in the `transform_params`, the same logic
+    would prevent `w1_t` from being computed in the shared
+    `transform_params`.  However, the
+    `shared_transform=['func1','func2']` argument means that `func3`
+    does not have any parameter transformations lifted out.
+    """
+
+    @I.ir_module
+    class Before:
+        @R.function
+        def func1(
+            x: R.Tensor((256, 256), "float32"),
+            w1: R.Tensor((256, 256), "float32"),
+            w2: R.Tensor((256, 256), "float32"),
+        ) -> R.Tensor((256, 256), "float32"):
+            R.func_attr({"num_input": 1})
+            with R.dataflow():
+                w1_t = R.permute_dims(w1)
+                y1 = R.matmul(x, w1_t)
+                w2_t = R.permute_dims(w2)
+                y2 = R.matmul(x, w2_t)
+                output = R.add(y1, y2)
+                R.output(output)
+            return output
+
+        @R.function
+        def func2(
+            x: R.Tensor((256, 256), "float32"),
+            w1: R.Tensor((256, 256), "float32"),
+            w2: R.Tensor((256, 256), "float32"),
+        ) -> R.Tensor((256, 256), "float32"):
+            R.func_attr({"num_input": 1})
+            with R.dataflow():
+                w1_t = R.permute_dims(w1)
+                y1 = R.matmul(x, w1_t)
+                y2 = Before.fused_permute_dims_matmul(x, w2)
+                output = R.add(y1, y2)
+                R.output(output)
+            return output
+
+        @R.function
+        def func3(
+            x: R.Tensor((256, 256), "float32"),
+            w1: R.Tensor((256, 256), "float32"),
+            w2: R.Tensor((256, 256), "float32"),
+        ) -> R.Tensor((256, 256), "float32"):
+            R.func_attr({"num_input": 1})
+            with R.dataflow():
+                y1 = Before.fused_permute_dims_matmul(x, w1)
+                y2 = Before.fused_permute_dims_matmul(x, w2)
+                output = R.add(y1, y2)
+                R.output(output)
+            return output
+
+        @R.function(private=True)
+        def fused_permute_dims_matmul(
+            x: R.Tensor((256, 256), "float32"),
+            weight: R.Tensor((256, 256), "float32"),
+        ) -> R.Tensor((256, 256), "float32"):
+            with R.dataflow():
+                weight_t = R.permute_dims(weight)
+                y = R.matmul(x, weight_t)
+                R.output(y)
+            return y
+
+    # `func3` is not listed in `shared_transform`, so it is left unchanged
+    # and still takes the original `w1` and `w2`.
+    @I.ir_module
+    class Expected:
+        @R.function
+        def transform_params(
+            params: R.Tuple(
+                R.Tensor((256, 256), dtype="float32"),
+                R.Tensor((256, 256), dtype="float32"),
+            )
+        ):
+            R.func_attr({"num_input": 0})
+            with R.dataflow():
+                w1 = params[0]
+                w1_t = R.permute_dims(w1)
+                w2 = params[1]
+                output = (w2, w1_t)
+                R.output(output)
+            return output
+
+        @R.function
+        def func1(
+            x: R.Tensor((256, 256), "float32"),
+            w2: R.Tensor((256, 256), "float32"),
+            w1_t: R.Tensor((256, 256), "float32"),
+        ) -> R.Tensor((256, 256), "float32"):
+            R.func_attr({"num_input": 1})
+            with R.dataflow():
+                y1 = R.matmul(x, w1_t)
+                w2_t = R.permute_dims(w2)
+                y2 = R.matmul(x, w2_t)
+                output = R.add(y1, y2)
+                R.output(output)
+            return output
+
+        @R.function
+        def func2(
+            x: R.Tensor((256, 256), "float32"),
+            w2: R.Tensor((256, 256), "float32"),
+            w1_t: R.Tensor((256, 256), "float32"),
+        ) -> R.Tensor((256, 256), "float32"):
+            R.func_attr({"num_input": 1})
+            with R.dataflow():
+                y1 = R.matmul(x, w1_t)
+                y2 = Expected.fused_permute_dims_matmul(x, w2)
+                output = R.add(y1, y2)
+                R.output(output)
+            return output
+
+        @R.function
+        def func3(
+            x: R.Tensor((256, 256), "float32"),
+            w1: R.Tensor((256, 256), "float32"),
+            w2: R.Tensor((256, 256), "float32"),
+        ) -> R.Tensor((256, 256), "float32"):
+            R.func_attr({"num_input": 1})
+            with R.dataflow():
+                y1 = Expected.fused_permute_dims_matmul(x, w1)
+                y2 = Expected.fused_permute_dims_matmul(x, w2)
+                output = R.add(y1, y2)
+                R.output(output)
+            return output
+
+        @R.function(private=True)
+        def fused_permute_dims_matmul(
+            x: R.Tensor((256, 256), "float32"),
+            weight: R.Tensor((256, 256), "float32"),
+        ) -> R.Tensor((256, 256), "float32"):
+            with R.dataflow():
+                weight_t = R.permute_dims(weight)
+                y = R.matmul(x, weight_t)
+                R.output(y)
+            return y
+
+    after = relax.transform.LiftTransformParams(shared_transform=["func1", "func2"])(Before)
+    tvm.ir.assert_structural_equal(after, Expected)
+
+
+def test_share_transform_with_unused_parameter():
+    """Like test_share_transform_across_specified_functions, but not
+    all functions use every model weight.
+
+    In `func1`, both `w1_t` and `w2_t` could be lifted out.  In
+    `func2`, only `w1_t` could be lifted out of the function, and
+    `w2` is not used at all.
+
+    Ideally, because `func2` doesn't use `w2`, `w2_t` could still be
+    pre-computed.  For example, an `embed_vocab` function would only
+    use the embedding weights; it could accept the full set of model
+    weights for consistency, and transformations performed on those
+    unused weights in other functions could still be lifted out.
+
+    However, the `Expected` module below only lifts `w1_t`, which every
+    function computes: `w2` is passed through `transform_params`
+    unchanged, and `w2_t` is still computed inside `func1`.
+    """
+
+    @I.ir_module
+    class Before:
+        @R.function
+        def func1(
+            x: R.Tensor((256, 256), "float32"),
+            w1: R.Tensor((256, 256), "float32"),
+            w2: R.Tensor((256, 256), "float32"),
+        ) -> R.Tensor((256, 256), "float32"):
+            R.func_attr({"num_input": 1})
+            with R.dataflow():
+                w1_t = R.permute_dims(w1)
+                y1 = R.matmul(x, w1_t)
+                w2_t = R.permute_dims(w2)
+                y2 = R.matmul(x, w2_t)
+                output = R.add(y1, y2)
+                R.output(output)
+            return output
+
+        @R.function
+        def func2(
+            x: R.Tensor((256, 256), "float32"),
+            w1: R.Tensor((256, 256), "float32"),
+            w2: R.Tensor((256, 256), "float32"),
+        ) -> R.Tensor((256, 256), "float32"):
+            R.func_attr({"num_input": 1})
+            with R.dataflow():
+                w1_t = R.permute_dims(w1)
+                y1 = R.matmul(x, w1_t)
+                R.output(y1)
+            return y1
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def transform_params(
+            params: R.Tuple(
+                R.Tensor((256, 256), dtype="float32"),
+                R.Tensor((256, 256), dtype="float32"),
+            )
+        ):
+            R.func_attr({"num_input": 0})
+            with R.dataflow():
+                w1 = params[0]
+                w1_t = R.permute_dims(w1)
+                w2 = params[1]
+                output = (w2, w1_t)
+                R.output(output)
+            return output
+
+        @R.function
+        def func1(
+            x: R.Tensor((256, 256), "float32"),
+            w2: R.Tensor((256, 256), "float32"),
+            w1_t: R.Tensor((256, 256), "float32"),
+        ) -> R.Tensor((256, 256), "float32"):
+            R.func_attr({"num_input": 1})
+            with R.dataflow():
+                y1 = R.matmul(x, w1_t)
+                w2_t = R.permute_dims(w2)
+                y2 = R.matmul(x, w2_t)
+                output = R.add(y1, y2)
+                R.output(output)
+            return output
+
+        @R.function
+        def func2(
+            x: R.Tensor((256, 256), "float32"),
+            w2: R.Tensor((256, 256), "float32"),
+            w1_t: R.Tensor((256, 256), "float32"),
+        ) -> R.Tensor((256, 256), "float32"):
+            R.func_attr({"num_input": 1})
+            with R.dataflow():
+                y1 = R.matmul(x, w1_t)
+                R.output(y1)
+            return y1
+
+    after = relax.transform.LiftTransformParams(shared_transform=True)(Before)
+    tvm.ir.assert_structural_equal(after, Expected)
+
+
+# NOTE(review): this decorator was garbled into an obfuscated "[email protected]"
+# line in this copy. The commit log ("Comments & skip test.") indicates a
+# skip marker; confirm `skip` (vs. `xfail`) against the upstream test file.
+@pytest.mark.skip
+def test_share_transform_with_no_shared_preprocessing():
+    """Like test_share_transform_with_unused_parameter, but each
+    function uses a single model weight.
+
+    In `func1`, `w2_t` can be lifted out and `w1` is unused.  In
+    `func2`, `w1_t` can be lifted out, and `w2` is unused.  In their
+    shared `transform_params`, both `w1_t` and `w2_t` can be computed.
+
+    For consistency in testing and pre-computing weights, the order of
+    `transform_params` should be deterministic.  When lifting from a
+    single function, the bindings in `transform_params` may be
+    determined from the order in that function.  When lifting from
+    multiple functions, the order should be deterministic.  Since
+    `IRModule::functions` has unspecified order, the order in this
+    test assumes that public functions are visited in alphabetical
+    order by name.
+    """
+
+    @I.ir_module
+    class Before:
+        @R.function
+        def func1(
+            x: R.Tensor((256, 256), "float32"),
+            w1: R.Tensor((256, 256), "float32"),
+            w2: R.Tensor((256, 256), "float32"),
+        ) -> R.Tensor((256, 256), "float32"):
+            R.func_attr({"num_input": 1})
+            with R.dataflow():
+                w2_t = R.permute_dims(w2)
+                y2 = R.matmul(x, w2_t)
+                R.output(y2)
+            return y2
+
+        @R.function
+        def func2(
+            x: R.Tensor((256, 256), "float32"),
+            w1: R.Tensor((256, 256), "float32"),
+            w2: R.Tensor((256, 256), "float32"),
+        ) -> R.Tensor((256, 256), "float32"):
+            R.func_attr({"num_input": 1})
+            with R.dataflow():
+                w1_t = R.permute_dims(w1)
+                y1 = R.matmul(x, w1_t)
+                R.output(y1)
+            return y1
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def transform_params(
+            params: R.Tuple(
+                R.Tensor((256, 256), dtype="float32"),
+                R.Tensor((256, 256), dtype="float32"),
+            )
+        ):
+            R.func_attr({"num_input": 0})
+            with R.dataflow():
+                w1 = params[0]
+                w1_t = R.permute_dims(w1)
+                w2 = params[1]
+                w2_t = R.permute_dims(w2)
+                output = (w2_t, w1_t)
+                R.output(output)
+            return output
+
+        @R.function
+        def func1(
+            x: R.Tensor((256, 256), "float32"),
+            w2_t: R.Tensor((256, 256), "float32"),
+            w1_t: R.Tensor((256, 256), "float32"),
+        ) -> R.Tensor((256, 256), "float32"):
+            R.func_attr({"num_input": 1})
+            with R.dataflow():
+                y2 = R.matmul(x, w2_t)
+                R.output(y2)
+            return y2
+
+        @R.function
+        def func2(
+            x: R.Tensor((256, 256), "float32"),
+            w2_t: R.Tensor((256, 256), "float32"),
+            w1_t: R.Tensor((256, 256), "float32"),
+        ) -> R.Tensor((256, 256), "float32"):
+            R.func_attr({"num_input": 1})
+            with R.dataflow():
+                y1 = R.matmul(x, w1_t)
+                R.output(y1)
+            return y1
+
+    after = relax.transform.LiftTransformParams(shared_transform=True)(Before)
+    tvm.ir.assert_structural_equal(after, Expected)
+
+
def test_stop_lifting():
@tvm.script.ir_module
class Before: