This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 1f4573984d [Unity][Transform] Add LiftTransformParams pass (#14069)
1f4573984d is described below
commit 1f4573984d7a4b15ced592e9deab3bda3f213ba1
Author: Wuwei Lin <[email protected]>
AuthorDate: Tue Feb 21 12:21:56 2023 -0800
[Unity][Transform] Add LiftTransformParams pass (#14069)
This PR added a pass `LiftTransformParams`. It makes it possible to compile the
end-to-end model without the weights provided. The idea is to annotate the
input parameters that are weights, then identify and lift the
transformations applied to the weights, and compile them into a separate
function `transform_params` that can be executed at runtime. Users can run
`transform_params` with the weights to obtain the weights for the optimized
model as a preparation step before deployment. In this way, we perform the
same optimizations and defer the weight transformations to the user
side, while the overhead of the deferred weight transformation can be
ignored as it only needs to be run once.
This pass is integrated with the default `vm.build`. It is optional and
only necessary when the parameters are kept as inputs when importing the
model from the frontend.
---
include/tvm/relax/transform.h | 16 ++
python/tvm/relax/transform/transform.py | 21 +-
src/relax/transform/lift_transform_params.cc | 297 +++++++++++++++++++++
.../relax/test_transform_lift_transform_params.py | 295 ++++++++++++++++++++
4 files changed, 628 insertions(+), 1 deletion(-)
diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h
index 1934a9f9f2..7d9f3d64b0 100644
--- a/include/tvm/relax/transform.h
+++ b/include/tvm/relax/transform.h
@@ -174,6 +174,22 @@ TVM_DLL Pass Normalize();
*/
TVM_DLL Pass LegalizeOps(Optional<Map<String, PackedFunc>> cmap);
+/*!
+ * \brief Lift transformation of the parameters of a function.
+ *
+ * When some inputs of the function are marked as 'parameters' (the model
weights), this pass
+ * identifies the transformation of the parameters and lifts them to a
separate function called
+ * `transform_params`. `transform_params` takes a tuple of the original
parameters as input and
+ * returns a tuple of the transformed parameters. The original function will
be rewritten to accept
+ * a tuple of transformed parameters as input.
+ *
+ * Users are expected to invoke the `transform_params` function in runtime and
pass the transformed
+ * parameters to the original function as input.
+ *
+ * \return The Pass.
+ */
+TVM_DLL Pass LiftTransformParams();
+
} // namespace transform
} // namespace relax
} // namespace tvm
diff --git a/python/tvm/relax/transform/transform.py
b/python/tvm/relax/transform/transform.py
index 12ed27f73a..590059739c 100644
--- a/python/tvm/relax/transform/transform.py
+++ b/python/tvm/relax/transform/transform.py
@@ -285,11 +285,30 @@ def MergeCompositeFunctions() -> tvm.ir.transform.Pass:
-------
ret : tvm.transform.Pass
The registered pass for merging composite functions.
-
"""
return _ffi_api.MergeCompositeFunctions() # type: ignore
+def LiftTransformParams() -> tvm.ir.transform.Pass:
+ """Lift transformation of the parameters of a function.
+
+    When some inputs of the function are marked as 'parameters' (the model
weights), this pass
+ identifies the transformation of the parameters and lifts them to a
separate function called
+ `transform_params`. `transform_params` takes a tuple of the original
parameters as input and
+ returns a tuple of the transformed parameters. The original function will
be rewritten to accept
+ a tuple of transformed parameters as input.
+
+ Users are expected to invoke the `transform_params` function in runtime
and pass the transformed
+ parameters to the original function as input.
+
+ Returns
+ -------
+ ret : tvm.transform.Pass
+ The registered pass for lifting transformation of parameters.
+ """
+ return _ffi_api.LiftTransformParams() # type: ignore
+
+
def LegalizeOps(customize_legalize_map: Optional[Dict[str, LegalizeFunc]] =
None):
"""Legalize high-level operator calls in Relax functions to call_tir
with corresponding low-level TIR PrimFuncs.
diff --git a/src/relax/transform/lift_transform_params.cc
b/src/relax/transform/lift_transform_params.cc
new file mode 100644
index 0000000000..97ed8b24a0
--- /dev/null
+++ b/src/relax/transform/lift_transform_params.cc
@@ -0,0 +1,297 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/relax/transform/lift_transform_params.cc
+ * \brief Lift the transformations applied to function parameters (weights)
+ * into a separate `transform_params` function.
+ */
+
+#include <tvm/relax/analysis.h>
+#include <tvm/relax/expr.h>
+#include <tvm/relax/expr_functor.h>
+#include <tvm/relax/transform.h>
+#include <tvm/runtime/logging.h>
+
+#include <iostream>
+#include <vector>
+
+namespace tvm {
+namespace relax {
+
+/*! \brief Plan of lifting transform params */
+struct LiftTransformParamsInfoPlan {
+ Function f_transform_params; // the lifted function that transforms the
parameters
+ std::unordered_map<Var, int, ObjectPtrHash, ObjectPtrEqual>
+ output_to_index; // the index of the original bindings in the output
tuple
+ std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual>
+ lifted_bindings; // the bindings of the original function that are
lifted
+};
+
+/*! \brief Builder of the function that transforms the parameters. */
+class TransformParamsFuncBuilder : public ExprMutator {
+ public:
+ TransformParamsFuncBuilder() { builder_->BeginDataflowBlock(); }
+
+ /*! \brief Add a input parameter. */
+ void AddInput(const Var& var) { inputs_.push_back(var); }
+
+ /*! \brief Add a binding to lift. */
+ void AddBinding(const VarBinding& binding) { bindings_.push_back(binding); }
+
+ /*! \brief Mark a variable as the output of the function. */
+ void MarkOutput(const Var& output) { outputs_.insert(output); }
+
+ /*!
+ * \brief Build the function that transforms the parameters
+ * \return The created function, and a map from the variable in the original
function to the index
+ * of the element of the output tuple
+ */
+ std::pair<Function, std::unordered_map<Var, int, ObjectPtrHash,
ObjectPtrEqual>> Build() {
+ Array<StructInfo> input_sinfo;
+ Array<Expr> output_vars;
+ std::unordered_map<Var, int, ObjectPtrHash, ObjectPtrEqual>
output_to_index;
+
+ for (const auto& input : inputs_) {
+ input_sinfo.push_back(Downcast<StructInfo>(input->struct_info_.value()));
+ }
+ Var params("params", TupleStructInfo(input_sinfo));
+
+ // Helper to add a variable to the output tuple
+ // original_var: the binding variable in the original function
+ // output_var: the variable, which is a binding in the transform_params
function, that is added
+ // to the output tuple
+ auto f_add_output = [&](const Var& original_var, const Var& output_var) ->
void {
+ output_to_index[original_var] = output_vars.size();
+ output_vars.push_back(output_var);
+ };
+
+ // Create mapping from the original input variables to the TupleGetItem
from the packed
+ // parameter tuple Add the parameters that are marked as the output of the
function to the
+ // output tuple
+ for (const auto& input : inputs_) {
+ input_remap_.emplace(input.get(), TupleGetItem(params,
input_remap_.size()));
+ if (outputs_.count(input)) {
+ auto output_var = builder_->Emit(input_remap_.at(input.get()));
+ f_add_output(input, output_var);
+ }
+ }
+
+ // Re-emit the bindings that are lifted. Update the output tuple if the
binding is marked as the
+ // output.
+ for (const auto& binding : bindings_) {
+ if (outputs_.count(binding->var)) {
+ auto output_var = builder_->Emit(VisitExpr(binding->value));
+ var_remap_[binding->var->vid] = output_var;
+ f_add_output(binding->var, output_var);
+ } else {
+ VisitBinding(binding);
+ }
+ }
+
+ // Create the function.
+ Expr transformed_params = builder_->EmitOutput(Tuple(output_vars));
+ BindingBlock block = builder_->EndBlock();
+ Expr body = builder_->Normalize(SeqExpr({block}, transformed_params));
+ Function f_transform_params =
+ Function(/*params=*/{params}, /*body=*/body,
/*ret_struct_info=*/NullOpt);
+ return {f_transform_params, output_to_index};
+ }
+
+ Expr VisitExpr_(const VarNode* var) final {
+ if (auto it = input_remap_.find(var); it != input_remap_.end()) {
+ return builder_->Emit((*it).second);
+ } else {
+ return ExprMutator::VisitExpr_(var);
+ }
+ }
+
+ // The input parameters of the function.
+ Array<Var> inputs_;
+ // Remap from the original input variable to TupleGetItem from the packed
parameter tuple, which
+ // is the input of the lifted function.
+ std::unordered_map<const VarNode*, Expr> input_remap_;
+ // The bindings that are lifted.
+ Array<VarBinding> bindings_;
+ // The variables that are marked as the output of the function.
+ std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> outputs_;
+};
+
+/*!
+ * \brief Visitor that creates the plan of lifting transform params.
+ *
+ * Starting from the parameters of the function (they are the initial set of
lifted bindings), we
+ * will visit the body of the function to find the bindings that can be
lifted. A binding can be
+ * lifted if all the variables that it depends on are also lifted.
+ *
+ * When a binding cannot be lifted, all the variables that 1) it depends on,
and 2) have been
+ * lifted, will be marked as the boundary variable and will be in the output
of the lifted function.
+ */
+class LiftTransformParamsPlanner : public ExprVisitor {
+ public:
+ LiftTransformParamsInfoPlan Plan(const Function& function, int num_inputs) {
+ for (int i = num_inputs; i < static_cast<int>(function->params.size());
++i) {
+ builder_.AddInput(function->params[i]);
+ lifted_bindings_.emplace(function->params[i]);
+ }
+ VisitExpr(function->body);
+
+ const auto& [f_transform_params, output_to_index] = builder_.Build();
+ return {f_transform_params, output_to_index, std::move(lifted_bindings_)};
+ }
+
+ private:
+ void VisitBindingBlock_(const DataflowBlockNode* block) final {
+ is_in_dataflow_block_ = true;
+ ExprVisitor::VisitBindingBlock_(block);
+ is_in_dataflow_block_ = false;
+ }
+
+ void VisitBinding_(const VarBindingNode* binding) final {
+ std::vector<const VarNode*> producers;
+ bool can_lift = true;
+ if (!is_in_dataflow_block_) {
+ can_lift = false;
+ }
+
+ PostOrderVisit(binding->value, [&](const ObjectRef& obj) {
+ if (const VarNode* var = obj.as<VarNode>()) {
+ producers.push_back(var);
+ if (!lifted_bindings_.count(GetRef<Var>(var))) {
+ can_lift = false;
+ }
+ }
+ });
+ if (can_lift) {
+ lifted_bindings_.insert(binding->var);
+ builder_.AddBinding(GetRef<VarBinding>(binding));
+ } else {
+ for (const VarNode* producer : producers) {
+ if (lifted_bindings_.count(GetRef<Var>(producer))) {
+ builder_.MarkOutput(GetRef<Var>(producer));
+ }
+ }
+ }
+ }
+
+ // The bindings that are lifted
+ std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> lifted_bindings_;
+ // The builder of the function that transforms the parameters
+ TransformParamsFuncBuilder builder_;
+ // Whether we are in a dataflow block
+ bool is_in_dataflow_block_{false};
+};
+
+/*!
+ *\brief The rewriter that lifts the transform params of a function and
updates the original
+ * function.
+ */
+class TransformParamsLifter : public ExprMutator {
+ public:
+ explicit TransformParamsLifter(const IRModule& module) : ExprMutator(module)
{}
+
+ IRModule Lift() {
+ auto mod = builder_->GetContextIRModule();
+ GlobalVar gv_main = mod->GetGlobalVar("main");
+ Function func = Downcast<Function>(mod->Lookup(gv_main));
+ func = RewriteFunc(func);
+ builder_->UpdateFunction(gv_main, func);
+ return builder_->GetContextIRModule();
+ }
+
+ private:
+ Function RewriteFunc(const Function& func) {
+ const std::string attr_num_input = "num_input";
+ auto opt_num_input = func->attrs.GetAttr<Integer>(attr_num_input);
+ if (!opt_num_input.defined()) {
+ return func;
+ }
+ LiftTransformParamsPlanner planner;
+ int64_t params_begin = opt_num_input.value()->value;
+
+ // Step 1: Create the plan of lifting transform params
+ lift_plan_ = planner.Plan(func, params_begin);
+
+ // Step 2: Add the lifted function to the module
+ builder_->AddFunction(lift_plan_.f_transform_params, "transform_params");
+
+ // Step 3: Update the current function.
+
+ // Step 3.1: Update the function signature
+ Var params("params", lift_plan_.f_transform_params->ret_struct_info);
+ Array<Var> new_params;
+ for (int i = 0; i < params_begin; ++i) {
+ new_params.push_back(func->params[i]);
+ }
+ new_params.push_back(params);
+
+ // Step 3.2: Update the function body
+ for (const auto& [var, index] : lift_plan_.output_to_index) {
+ param_remap_[var] = TupleGetItem(params, index);
+ }
+ auto new_body = VisitExpr(func->body);
+
+ // Step 3.3: Remove function attributes that are not needed
+ auto new_attrs = func->attrs;
+ auto* new_attrs_node = new_attrs.CopyOnWrite();
+ new_attrs_node->dict.erase(attr_num_input);
+ if (new_attrs->dict.empty()) {
+ new_attrs = NullValue<DictAttrs>();
+ }
+
+ Function new_func(new_params, new_body, func->ret_struct_info, new_attrs);
+ return new_func;
+ }
+
+ void VisitBinding_(const VarBindingNode* binding) final {
+ if (lift_plan_.lifted_bindings.count(binding->var)) {
+ return;
+ }
+ ExprMutator::VisitBinding_(binding);
+ }
+
+ Expr VisitExpr_(const VarNode* var) final {
+ auto it = param_remap_.find(GetRef<Var>(var));
+ if (it != param_remap_.end()) {
+ return builder_->Emit(it->second);
+ }
+ return ExprMutator::VisitExpr_(var);
+ }
+
+ Expr VisitExpr_(const DataflowVarNode* var) final {
+ return VisitExpr_(static_cast<const VarNode*>(var));
+ }
+
+ // Remap the original parameters to TupleGetItem from the packed tuple of
transformed parameters.
+ std::unordered_map<Var, Expr, ObjectPtrHash, ObjectPtrEqual> param_remap_;
+ // The plan of lifting the transform params
+ LiftTransformParamsInfoPlan lift_plan_;
+};
+
+namespace transform {
+Pass LiftTransformParams() {
+ runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =
+ [=](IRModule m, PassContext pc) { return
TransformParamsLifter(m).Lift(); };
+ return CreateModulePass(pass_func, 1, "LiftTransformParams", {});
+}
+
+TVM_REGISTER_GLOBAL("relax.transform.LiftTransformParams").set_body_typed(LiftTransformParams);
+
+} // namespace transform
+} // namespace relax
+} // namespace tvm
diff --git a/tests/python/relax/test_transform_lift_transform_params.py
b/tests/python/relax/test_transform_lift_transform_params.py
new file mode 100644
index 0000000000..a1f67d41da
--- /dev/null
+++ b/tests/python/relax/test_transform_lift_transform_params.py
@@ -0,0 +1,295 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import tvm
+import tvm.testing
+from tvm import relax
+from tvm.script import relax as R, tir as T
+import numpy as np
+import tvm.topi.testing
+
+
+def test_basic():
+ @tvm.script.ir_module
+ class Before:
+ @T.prim_func
+ def transform_layout_IOHW_to_OIHW(
+ w1: T.Buffer((3, 16, 3, 3), "float32"), out: T.Buffer((16, 3, 3,
3), "float32")
+ ) -> None:
+ for ax0, ax1, ax2, ax3 in T.grid(16, 3, 3, 3):
+ with T.block("layout_transform"):
+ o, i, h, w = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
+ out[o, i, h, w] = w1[i, o, h, w]
+
+ @R.function
+ def main(
+ x: R.Tensor((1, 3, 224, 224), "float32"),
+ w1: R.Tensor((3, 16, 3, 3), "float32"),
+ w2: R.Tensor((16, 16, 3, 3), "float32"),
+ ) -> R.Tensor((1, 16, 224, 224), "float32"):
+ R.func_attr({"num_input": 1})
+ with R.dataflow():
+ w1_transformed = R.call_tir(
+ transform_layout_IOHW_to_OIHW, w1, R.Tensor((16, 3, 3, 3),
"float32")
+ )
+ conv1 = R.nn.conv2d(
+ x, w1_transformed, padding=(1, 1), data_layout="NCHW",
kernel_layout="OIHW"
+ )
+ conv2 = R.nn.conv2d(
+ conv1, w2, padding=(1, 1), data_layout="NCHW",
kernel_layout="OIHW"
+ )
+ R.output(conv2)
+ return conv2
+
+ @tvm.script.ir_module
+ class Expected:
+ @R.function
+ def main(
+ x: R.Tensor((1, 3, 224, 224), dtype="float32"),
+ params: R.Tuple(
+ R.Tensor((16, 16, 3, 3), dtype="float32"), R.Tensor((16, 3, 3,
3), dtype="float32")
+ ),
+ ) -> R.Tensor((1, 16, 224, 224), dtype="float32"):
+ with R.dataflow():
+ lv: R.Tensor((16, 3, 3, 3), dtype="float32") = params[1]
+ conv1: R.Tensor((1, 16, 224, 224), dtype="float32") =
R.nn.conv2d(
+ x,
+ lv,
+ strides=[1, 1],
+ padding=[1, 1, 1, 1],
+ dilation=[1, 1],
+ groups=1,
+ data_layout="NCHW",
+ kernel_layout="OIHW",
+ out_layout="NCHW",
+ out_dtype="void",
+ )
+ lv1: R.Tensor((16, 16, 3, 3), dtype="float32") = params[0]
+ conv2: R.Tensor((1, 16, 224, 224), dtype="float32") =
R.nn.conv2d(
+ conv1,
+ lv1,
+ strides=[1, 1],
+ padding=[1, 1, 1, 1],
+ dilation=[1, 1],
+ groups=1,
+ data_layout="NCHW",
+ kernel_layout="OIHW",
+ out_layout="NCHW",
+ out_dtype="void",
+ )
+ R.output(conv2)
+ return conv2
+
+ @T.prim_func
+ def transform_layout_IOHW_to_OIHW(
+ w1: T.Buffer((3, 16, 3, 3), "float32"), out: T.Buffer((16, 3, 3,
3), "float32")
+ ):
+ for ax0, ax1, ax2, ax3 in T.grid(16, 3, 3, 3):
+ with T.block("layout_transform"):
+ o, i, h, w = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
+ T.reads(w1[i, o, h, w])
+ T.writes(out[o, i, h, w])
+ out[o, i, h, w] = w1[i, o, h, w]
+
+ @R.function
+ def transform_params(
+ params: R.Tuple(
+ R.Tensor((3, 16, 3, 3), dtype="float32"), R.Tensor((16, 16, 3,
3), dtype="float32")
+ )
+ ) -> R.Tuple(
+ R.Tensor((16, 16, 3, 3), dtype="float32"), R.Tensor((16, 3, 3, 3),
dtype="float32")
+ ):
+ with R.dataflow():
+ lv: R.Tensor((16, 16, 3, 3), dtype="float32") = params[1]
+ lv1: R.Tensor((3, 16, 3, 3), dtype="float32") = params[0]
+ lv2 = R.call_tir(
+ transform_layout_IOHW_to_OIHW,
+ (lv1,),
+ out_sinfo=R.Tensor((16, 3, 3, 3), dtype="float32"),
+ )
+ gv: R.Tuple(
+ R.Tensor((16, 16, 3, 3), dtype="float32"),
+ R.Tensor((16, 3, 3, 3), dtype="float32"),
+ ) = (lv, lv2)
+ R.output(gv)
+ return gv
+
+ mod = Before
+ after = relax.transform.LiftTransformParams()(mod)
+ tvm.ir.assert_structural_equal(after, Expected)
+
+
+def test_tuple():
+ @tvm.script.ir_module
+ class Before:
+ @R.function
+ def main(
+ x: R.Tensor((1, 16, 224, 224), "float32"), w1: R.Tensor((16, 16,
3, 3), "float32")
+ ) -> R.Tensor((1, 16, 224, 224), "float32"):
+ R.func_attr({"num_input": 1})
+ with R.dataflow():
+ l0 = (w1,)
+ l1 = (l0,)
+ l2 = l1[0]
+ l3 = l2[0]
+ conv1 = R.nn.conv2d(x, l3, padding=(1, 1), data_layout="NCHW",
kernel_layout="OIHW")
+ conv2 = R.nn.conv2d(
+ conv1, w1, padding=(1, 1), data_layout="NCHW",
kernel_layout="OIHW"
+ )
+ R.output(conv2)
+ return conv2
+
+ @tvm.script.ir_module
+ class Expected:
+ @R.function
+ def main(
+ x: R.Tensor((1, 16, 224, 224), dtype="float32"),
+ params: R.Tuple(
+ R.Tensor((16, 16, 3, 3), dtype="float32"), R.Tensor((16, 16,
3, 3), dtype="float32")
+ ),
+ ) -> R.Tensor((1, 16, 224, 224), dtype="float32"):
+ with R.dataflow():
+ lv: R.Tensor((16, 16, 3, 3), dtype="float32") = params[1]
+ conv1: R.Tensor((1, 16, 224, 224), dtype="float32") =
R.nn.conv2d(
+ x,
+ lv,
+ strides=[1, 1],
+ padding=[1, 1, 1, 1],
+ dilation=[1, 1],
+ groups=1,
+ data_layout="NCHW",
+ kernel_layout="OIHW",
+ out_layout="NCHW",
+ out_dtype="void",
+ )
+ lv1: R.Tensor((16, 16, 3, 3), dtype="float32") = params[0]
+ conv2: R.Tensor((1, 16, 224, 224), dtype="float32") =
R.nn.conv2d(
+ conv1,
+ lv1,
+ strides=[1, 1],
+ padding=[1, 1, 1, 1],
+ dilation=[1, 1],
+ groups=1,
+ data_layout="NCHW",
+ kernel_layout="OIHW",
+ out_layout="NCHW",
+ out_dtype="void",
+ )
+ R.output(conv2)
+ return conv2
+
+ @R.function
+ def transform_params(
+ params: R.Tuple(R.Tensor((16, 16, 3, 3), dtype="float32"))
+ ) -> R.Tuple(
+ R.Tensor((16, 16, 3, 3), dtype="float32"), R.Tensor((16, 16, 3,
3), dtype="float32")
+ ):
+ with R.dataflow():
+ lv: R.Tensor((16, 16, 3, 3), dtype="float32") = params[0]
+ lv1: R.Tensor((16, 16, 3, 3), dtype="float32") = params[0]
+ l0: R.Tuple(R.Tensor((16, 16, 3, 3), dtype="float32")) = (lv1,)
+ l1: R.Tuple(R.Tuple(R.Tensor((16, 16, 3, 3),
dtype="float32"))) = (l0,)
+ l2: R.Tuple(R.Tensor((16, 16, 3, 3), dtype="float32")) = l1[0]
+ lv2: R.Tensor((16, 16, 3, 3), dtype="float32") = l2[0]
+ gv: R.Tuple(
+ R.Tensor((16, 16, 3, 3), dtype="float32"),
+ R.Tensor((16, 16, 3, 3), dtype="float32"),
+ ) = (lv, lv2)
+ R.output(gv)
+ return gv
+
+ mod = Before
+ after = relax.transform.LiftTransformParams()(mod)
+ tvm.ir.assert_structural_equal(after, Expected)
+
+
+def test_condition():
+ """Test case that the conditional statement can't be lifted"""
+
+ @tvm.script.ir_module
+ class Before:
+ @R.function
+ def main(
+ x: R.Tensor((1, 16, 224, 224), "float32"),
+ w1: R.Tensor((16, 16, 3, 3), "float32"),
+ w2: R.Tensor((16, 16, 3, 3), "float32"),
+ cond: R.Tensor((), "bool"),
+ ) -> R.Tensor((1, 16, 224, 224), "float32"):
+ R.func_attr({"num_input": 1})
+ if cond:
+ w = w1
+ else:
+ w = w2
+ with R.dataflow():
+ conv1 = R.nn.conv2d(x, w, padding=(1, 1), data_layout="NCHW",
kernel_layout="OIHW")
+ R.output(conv1)
+ return conv1
+
+ @tvm.script.ir_module
+ class Expected:
+ @R.function
+ def transform_params(
+ params: R.Tuple(
+ R.Tensor((16, 16, 3, 3), dtype="float32"),
+ R.Tensor((16, 16, 3, 3), dtype="float32"),
+ R.Tensor((), dtype="bool"),
+ )
+ ) -> R.Tuple(
+ R.Tensor((16, 16, 3, 3), dtype="float32"),
+ R.Tensor((16, 16, 3, 3), dtype="float32"),
+ R.Tensor((), dtype="bool"),
+ ):
+ with R.dataflow():
+ lv: R.Tensor((16, 16, 3, 3), dtype="float32") = params[0]
+ lv1: R.Tensor((16, 16, 3, 3), dtype="float32") = params[1]
+ lv2: R.Tensor((), dtype="bool") = params[2]
+ gv: R.Tuple(
+ R.Tensor((16, 16, 3, 3), dtype="float32"),
+ R.Tensor((16, 16, 3, 3), dtype="float32"),
+ R.Tensor((), dtype="bool"),
+ ) = (lv, lv1, lv2)
+ R.output(gv)
+ return gv
+
+ @R.function
+ def main(
+ x: R.Tensor((1, 16, 224, 224), "float32"),
+ params: R.Tuple(
+ R.Tensor((16, 16, 3, 3), dtype="float32"),
+ R.Tensor((16, 16, 3, 3), dtype="float32"),
+ R.Tensor((), dtype="bool"),
+ ),
+ ) -> R.Tensor((1, 16, 224, 224), "float32"):
+ gv: R.Tensor((), dtype="bool") = params[2]
+ if gv:
+ gv1: R.Tensor((16, 16, 3, 3), dtype="float32") = params[0]
+ w: R.Tensor((16, 16, 3, 3), dtype="float32") = gv1
+ else:
+ gv2: R.Tensor((16, 16, 3, 3), dtype="float32") = params[1]
+ w: R.Tensor((16, 16, 3, 3), dtype="float32") = gv2
+ with R.dataflow():
+ conv1 = R.nn.conv2d(x, w, padding=(1, 1), data_layout="NCHW",
kernel_layout="OIHW")
+ R.output(conv1)
+ return conv1
+
+ mod = Before
+ after = relax.transform.LiftTransformParams()(mod)
+ tvm.ir.assert_structural_equal(after, Expected)
+
+
+if __name__ == "__main__":
+ tvm.testing.main()