This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
     new 81bf988675 [Unity][Transform] LiftTransformParams handling multiple functions (#14192)
81bf988675 is described below
commit 81bf98867566d81a6677dbcca9ad5a638a3fbe88
Author: Ruihang Lai <[email protected]>
AuthorDate: Sat Mar 4 13:26:02 2023 -0500
[Unity][Transform] LiftTransformParams handling multiple functions (#14192)
Previously, the LiftTransformParams pass only worked on the function
`"main"`. This is restrictive: in our recent practice on Stable
Diffusion, there are cases where multiple Relax functions inside an
IRModule all need to be transformed.

Therefore, this PR enhances the LiftTransformParams pass so that it
now transforms **all** functions **with the attribute `num_input`**.
Functions without this attribute are simply skipped.
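
As a minimal usage sketch (distilled from the new tests below; the
module, function, and variable names here are illustrative), a function
opts in by carrying the `num_input` attribute, which marks how many
leading parameters are runtime inputs; the remaining parameters are
treated as weights whose transformations are lifted into a new
`<function name>_transform_params` function:

    import tvm
    from tvm import relax
    from tvm.script import relax as R

    @tvm.script.ir_module
    class Module:
        @R.function
        def main(
            x: R.Tensor((256, 256), "float32"),  # runtime input
            w: R.Tensor((256, 256), "float32"),  # weight, liftable
        ) -> R.Tensor((256, 256), "float32"):
            # The first parameter (x) is a runtime input; the rest (w)
            # are weights available at transform time.
            R.func_attr({"num_input": 1})
            with R.dataflow():
                w_t = R.permute_dims(w, [1, 0])
                y = R.matmul(x, w_t)
                R.output(y)
            return y

    # The transposition of `w` is lifted into "main_transform_params";
    # functions without the `num_input` attribute are left untouched.
    mod = relax.transform.LiftTransformParams()(Module)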
---
src/relax/transform/lift_transform_params.cc | 37 +++++---
.../relax/test_transform_lift_transform_params.py | 105 ++++++++++++++++++++-
2 files changed, 124 insertions(+), 18 deletions(-)
diff --git a/src/relax/transform/lift_transform_params.cc b/src/relax/transform/lift_transform_params.cc
index 97ed8b24a0..401a03dbe2 100644
--- a/src/relax/transform/lift_transform_params.cc
+++ b/src/relax/transform/lift_transform_params.cc
@@ -207,35 +207,41 @@ class TransformParamsLifter : public ExprMutator {
   IRModule Lift() {
     auto mod = builder_->GetContextIRModule();
-    GlobalVar gv_main = mod->GetGlobalVar("main");
-    Function func = Downcast<Function>(mod->Lookup(gv_main));
-    func = RewriteFunc(func);
-    builder_->UpdateFunction(gv_main, func);
+    for (const auto& [gv, base_func] : mod->functions) {
+      // Skip non-Relax functions.
+      const auto* func_ = base_func.as<FunctionNode>();
+      if (func_ == nullptr) {
+        continue;
+      }
+      // Skip functions that do not have the `num_input` attribute.
+      Optional<Integer> opt_num_input = func_->attrs.GetAttr<Integer>(attr_num_input_);
+      if (!opt_num_input.defined()) {
+        continue;
+      }
+      Function func = RewriteFunc(GetRef<Function>(func_), opt_num_input.value()->value,
+                                  gv->name_hint + "_transform_params");
+      builder_->UpdateFunction(gv, func);
+    }
+
     return builder_->GetContextIRModule();
   }
  private:
-  Function RewriteFunc(const Function& func) {
-    const std::string attr_num_input = "num_input";
-    auto opt_num_input = func->attrs.GetAttr<Integer>(attr_num_input);
-    if (!opt_num_input.defined()) {
-      return func;
-    }
+  Function RewriteFunc(const Function& func, int num_input, String new_func_name) {
     LiftTransformParamsPlanner planner;
-    int64_t params_begin = opt_num_input.value()->value;
     // Step 1: Create the plan of lifting transform params
-    lift_plan_ = planner.Plan(func, params_begin);
+    lift_plan_ = planner.Plan(func, num_input);
     // Step 2: Add the lifted function to the module
-    builder_->AddFunction(lift_plan_.f_transform_params, "transform_params");
+    builder_->AddFunction(lift_plan_.f_transform_params, new_func_name);
     // Step 3: Update the current function.
     // Step 3.1: Update the function signature
     Var params("params", lift_plan_.f_transform_params->ret_struct_info);
     Array<Var> new_params;
-    for (int i = 0; i < params_begin; ++i) {
+    for (int i = 0; i < num_input; ++i) {
       new_params.push_back(func->params[i]);
     }
     new_params.push_back(params);
@@ -249,7 +255,7 @@ class TransformParamsLifter : public ExprMutator {
     // Step 3.3: Remove function attributes that are not needed
     auto new_attrs = func->attrs;
     auto* new_attrs_node = new_attrs.CopyOnWrite();
-    new_attrs_node->dict.erase(attr_num_input);
+    new_attrs_node->dict.erase(attr_num_input_);
     if (new_attrs->dict.empty()) {
       new_attrs = NullValue<DictAttrs>();
     }
@@ -277,6 +283,7 @@ class TransformParamsLifter : public ExprMutator {
     return VisitExpr_(static_cast<const VarNode*>(var));
   }
 
+  const char* attr_num_input_ = "num_input";
   // Remap the original parameters to TupleGetItem from the packed tuple of transformed parameters.
   std::unordered_map<Var, Expr, ObjectPtrHash, ObjectPtrEqual> param_remap_;
   // The plan of lifting the transform params
diff --git a/tests/python/relax/test_transform_lift_transform_params.py b/tests/python/relax/test_transform_lift_transform_params.py
index a1f67d41da..8c2eae684a 100644
--- a/tests/python/relax/test_transform_lift_transform_params.py
+++ b/tests/python/relax/test_transform_lift_transform_params.py
@@ -106,7 +106,7 @@ def test_basic():
                     out[o, i, h, w] = w1[i, o, h, w]
 
         @R.function
-        def transform_params(
+        def main_transform_params(
            params: R.Tuple(
                R.Tensor((3, 16, 3, 3), dtype="float32"), R.Tensor((16, 16, 3, 3), dtype="float32")
            )
@@ -193,7 +193,7 @@ def test_tuple():
             return conv2
 
         @R.function
-        def transform_params(
+        def main_transform_params(
            params: R.Tuple(R.Tensor((16, 16, 3, 3), dtype="float32"))
        ) -> R.Tuple(
            R.Tensor((16, 16, 3, 3), dtype="float32"), R.Tensor((16, 16, 3, 3), dtype="float32")
@@ -242,7 +242,7 @@ def test_condition():
     @tvm.script.ir_module
     class Expected:
         @R.function
-        def transform_params(
+        def main_transform_params(
            params: R.Tuple(
                R.Tensor((16, 16, 3, 3), dtype="float32"),
                R.Tensor((16, 16, 3, 3), dtype="float32"),
@@ -291,5 +291,104 @@ def test_condition():
     tvm.ir.assert_structural_equal(after, Expected)
 
 
+def test_multiple_functions():
+    @tvm.script.ir_module
+    class Before:
+        @R.function
+        def func1(
+            x: R.Tensor((256, 256), "float32"),
+            w1: R.Tensor((256, 256), "float32"),
+        ) -> R.Tensor((256, 256), "float32"):
+            R.func_attr({"num_input": 1})
+            with R.dataflow():
+                w1_t = R.permute_dims(w1, [1, 0])
+                y = R.matmul(x, w1_t)
+                R.output(y)
+            return y
+
+        @R.function
+        def func2(
+            x: R.Tensor((256, 256), "float32"),
+            w1: R.Tensor((128, 256), "float32"),
+        ) -> R.Tensor((256, 128), "float32"):
+            R.func_attr({"num_input": 1})
+            with R.dataflow():
+                w1_t = R.permute_dims(w1, [1, 0])
+                y = R.matmul(x, w1_t)
+                R.output(y)
+            return y
+
+        @R.function
+        def func3(
+            x: R.Tensor((256, 256), "float32"),
+            w1: R.Tensor((256, 256), "float32"),
+        ) -> R.Tensor((256, 256), "float32"):
+            with R.dataflow():
+                w1_t = R.permute_dims(w1, [1, 0])
+                y = R.matmul(x, w1_t)
+                R.output(y)
+            return y
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def func1(
+            x: R.Tensor((256, 256), dtype="float32"),
+            params: R.Tuple(R.Tensor((256, 256), dtype="float32")),
+        ) -> R.Tensor((256, 256), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((256, 256), dtype="float32") = params[0]
+                y: R.Tensor((256, 256), dtype="float32") = R.matmul(x, lv, out_dtype="void")
+                R.output(y)
+            return y
+
+        @R.function
+        def func1_transform_params(
+            params: R.Tuple(R.Tensor((256, 256), dtype="float32"))
+        ) -> R.Tuple(R.Tensor((256, 256), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((256, 256), dtype="float32") = params[0]
+                lv1: R.Tensor((256, 256), dtype="float32") = R.permute_dims(lv, axes=[1, 0])
+                gv: R.Tuple(R.Tensor((256, 256), dtype="float32")) = (lv1,)
+                R.output(gv)
+            return gv
+
+        @R.function
+        def func2(
+            x: R.Tensor((256, 256), dtype="float32"),
+            params: R.Tuple(R.Tensor((256, 128), dtype="float32")),
+        ) -> R.Tensor((256, 128), dtype="float32"):
+            with R.dataflow():
+                lv1: R.Tensor((256, 128), dtype="float32") = params[0]
+                y: R.Tensor((256, 128), dtype="float32") = R.matmul(x, lv1, out_dtype="void")
+                R.output(y)
+            return y
+
+        @R.function
+        def func2_transform_params(
+            params: R.Tuple(R.Tensor((128, 256), dtype="float32"))
+        ) -> R.Tuple(R.Tensor((256, 128), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((128, 256), dtype="float32") = params[0]
+                lv1: R.Tensor((256, 128), dtype="float32") = R.permute_dims(lv, axes=[1, 0])
+                gv: R.Tuple(R.Tensor((256, 128), dtype="float32")) = (lv1,)
+                R.output(gv)
+            return gv
+
+        @R.function
+        def func3(
+            x: R.Tensor((256, 256), dtype="float32"), w1: R.Tensor((256, 256), dtype="float32")
+        ) -> R.Tensor((256, 256), dtype="float32"):
+            with R.dataflow():
+                w1_t: R.Tensor((256, 256), dtype="float32") = R.permute_dims(w1, axes=[1, 0])
+                y: R.Tensor((256, 256), dtype="float32") = R.matmul(x, w1_t, out_dtype="void")
+                R.output(y)
+            return y
+
+    mod = Before
+    after = relax.transform.LiftTransformParams()(mod)
+    tvm.ir.assert_structural_equal(after, Expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()