This is an automated email from the ASF dual-hosted git repository.
csullivan pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new c5b7afc721 [Unity] Implemented BundleModelParams transform (#15657)
c5b7afc721 is described below
commit c5b7afc721a1539c6d04d75e21ef0922d6b2a3dc
Author: Eric Lunderberg <[email protected]>
AuthorDate: Wed Sep 6 12:22:51 2023 -0700
[Unity] Implemented BundleModelParams transform (#15657)
* [Unity] Implemented BundleModelParams transform
Implemented `relax.transform.BundleModelParams`, which groups
parameters into user-provided runtime parameters, and a tuple of
compile-time model weights. This functionality was previously part of
`LiftTransformParams`, but is being separated to allow for composable
functions which each mutate the parameters.
* [Unity] Keep parameters separate in LiftTransformParams
Because parameters may be mutated at multiple points when preparing a
model (e.g. first by explicit quantizing, and then by lifted
transformation), each step that alters the parameters should retain
the same general form.
Prior to this commit, the `LiftTransformParams` pass extracted an
independent `func_transform_params` function that could be applied to
the weights, removed the `"num_input"` attribute, and bundled the
transformed model parameters into a single tuple parameter.
This commit updates `LiftTransformParams` to only perform the first
step, generating the independent `func_transform_params` function,
while the remaining steps are performed by `BundleModelParams`.
---
python/tvm/relax/transform/transform.py | 17 +++
src/relax/transform/bundle_model_params.cc | 119 +++++++++++++++++++++
src/relax/transform/lift_transform_params.cc | 106 ++++++++++--------
.../relax/test_transform_bundle_model_params.py | 104 ++++++++++++++++++
.../relax/test_transform_lift_transform_params.py | 90 ++++++++--------
5 files changed, 350 insertions(+), 86 deletions(-)
diff --git a/python/tvm/relax/transform/transform.py
b/python/tvm/relax/transform/transform.py
index aff73167e4..5c9a2ae554 100644
--- a/python/tvm/relax/transform/transform.py
+++ b/python/tvm/relax/transform/transform.py
@@ -720,6 +720,23 @@ def LiftTransformParams() -> tvm.ir.transform.Pass:
return _ffi_api.LiftTransformParams() # type: ignore
+def BundleModelParams() -> tvm.ir.transform.Pass:
+ """Bundle several model parameters into a single tuple parameter
+
+ For each function, if the function has the attribute "num_input",
+ separate its run-time parameters from its compile-time weights.
+ Run-time parameters (e.g. activations) are the first `num_input`
+ parameters, and the remainder are compile-time weights.
+
+ Returns
+ -------
+ ret : tvm.transform.Pass
+ The registered pass for bundling model parameters into a tuple.
+
+ """
+ return _ffi_api.BundleModelParams() # type: ignore
+
+
def LegalizeOps(
customize_legalize_map: Optional[Dict[str, LegalizeFunc]] = None,
enable_warning: bool = False
):
diff --git a/src/relax/transform/bundle_model_params.cc
b/src/relax/transform/bundle_model_params.cc
new file mode 100644
index 0000000000..8f6e7a1291
--- /dev/null
+++ b/src/relax/transform/bundle_model_params.cc
@@ -0,0 +1,119 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/relax/transform/bundle_model_params.cc
+ * \brief Bundle compile-time model parameters into a single tuple parameter.
+ */
+
+#include <tvm/relax/analysis.h>
+#include <tvm/relax/expr.h>
+#include <tvm/relax/expr_functor.h>
+#include <tvm/relax/transform.h>
+#include <tvm/runtime/logging.h>
+
+#include "utils.h"
+
+namespace tvm {
+namespace relax {
+
+static const auto kAttrNumInput = "num_input";
+
+class ModelParamBundler : public ExprMutator {
+ public:
+ ModelParamBundler() {}
+
+ Expr VisitExpr_(const FunctionNode* op) override {
+ Function func = GetRef<Function>(op);
+ auto opt_num_input = func->attrs.GetAttr<Integer>(kAttrNumInput);
+ if (!opt_num_input) return func;
+ auto signed_num_input = opt_num_input.value()->value;
+
+ ICHECK_GE(signed_num_input, 0);
+ ICHECK_LE(signed_num_input, func->params.size())
+ << "Function was declared to have " << signed_num_input << " runtime
inputs, "
+ << "but only has " << func->params.size() << " parameters total.";
+ size_t num_input = signed_num_input;
+
+ Array<Var> params;
+ for (size_t i = 0; i < num_input; i++) {
+ params.push_back(func->params[i]);
+ }
+
+ Array<StructInfo> param_tuple;
+ for (size_t i = num_input; i < func->params.size(); i++) {
+ param_tuple.push_back(GetStructInfo(func->params[i]));
+ }
+
+ Var var_param_tuple("model_params", TupleStructInfo(param_tuple));
+ params.push_back(var_param_tuple);
+
+ for (size_t i = num_input; i < func->params.size(); i++) {
+ var_to_expr_.Set(func->params[i], TupleGetItem(var_param_tuple, i -
num_input));
+ }
+
+ func = WithoutAttr(func, kAttrNumInput);
+ func.CopyOnWrite()->params = params;
+
+ return ExprMutator::VisitExpr_(func.get());
+ }
+
+ Expr VisitExpr_(const VarNode* op) override {
+ auto var = GetRef<Var>(op);
+ if (auto it = var_to_expr_.find(var); it != var_to_expr_.end()) {
+ return (*it).second;
+ } else {
+ return ExprMutator::VisitExpr_(op);
+ }
+ }
+
+ private:
+ Map<Var, Expr> var_to_expr_;
+};
+
+namespace transform {
+Pass BundleModelParams() {
+ runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =
[=](IRModule mod,
+
PassContext pc) {
+ IRModule updates;
+
+ ModelParamBundler mutator;
+
+ for (const auto& [gvar, func] : mod->functions) {
+ if (auto opt = func.as<relax::Function>()) {
+ auto new_func = Downcast<relax::Function>(mutator(opt.value()));
+ if (!new_func.same_as(func)) {
+ updates->Add(gvar, new_func);
+ }
+ }
+ }
+
+ if (updates->functions.size()) {
+ mod.CopyOnWrite()->Update(updates);
+ }
+ return mod;
+ };
+ return CreateModulePass(pass_func, 1, "BundleModelParams", {});
+}
+
+TVM_REGISTER_GLOBAL("relax.transform.BundleModelParams").set_body_typed(BundleModelParams);
+
+} // namespace transform
+} // namespace relax
+} // namespace tvm
diff --git a/src/relax/transform/lift_transform_params.cc
b/src/relax/transform/lift_transform_params.cc
index fb1f292776..afa1e191f4 100644
--- a/src/relax/transform/lift_transform_params.cc
+++ b/src/relax/transform/lift_transform_params.cc
@@ -218,70 +218,60 @@ class LiftTransformParamsPlanner : public ExprVisitor {
*\brief The rewriter that lifts the transform params of a function and
updates the original
* function.
*/
-class TransformParamsLifter : public ExprMutator {
+class TransformParamsLifter : ExprMutator {
public:
explicit TransformParamsLifter(const IRModule& module) : ExprMutator(module)
{}
- IRModule Lift() {
- auto mod = builder_->GetContextIRModule();
- for (const auto& [gv, base_func] : mod->functions) {
- // Skip non-Relax functions.
- const auto* func_ = base_func.as<FunctionNode>();
- if (func_ == nullptr) {
- continue;
- }
- // Skip functions that do not have the `num_input` attribute.
- Optional<Integer> opt_num_input =
func_->attrs.GetAttr<Integer>(attr_num_input_);
- if (!opt_num_input.defined()) {
- continue;
- }
- Function func = RewriteFunc(GetRef<Function>(func_),
opt_num_input.value()->value,
- gv->name_hint + "_transform_params");
- builder_->UpdateFunction(gv, func);
- }
-
- return builder_->GetContextIRModule();
+ Function VisitFunction(GlobalVar gvar, Function func) {
+ current_gvar_ = gvar;
+ auto out = Downcast<Function>(VisitExpr(std::move(func)));
+ current_gvar_ = NullOpt;
+ return out;
}
+ Map<GlobalVar, Function> GetTransformParamFunctions() const { return
transform_param_funcs_; }
+
private:
- Function RewriteFunc(const Function& func, int num_input, String
new_func_name) {
+ Expr VisitExpr_(const FunctionNode* op) override {
+ auto func = GetRef<Function>(op);
+ Optional<Integer> opt_num_input =
func->attrs.GetAttr<Integer>(attr_num_input_);
+ if (!opt_num_input) {
+ return func;
+ }
+ auto signed_num_input = opt_num_input.value()->value;
+ ICHECK_GE(signed_num_input, 0);
+ ICHECK_LE(signed_num_input, func->params.size());
+ size_t num_input = signed_num_input;
+
LiftTransformParamsPlanner planner;
// Step 1: Create the plan of lifting transform params
lift_plan_ = planner.Plan(func, num_input);
- // Step 2: Add the lifted function to the module
- // (The lifted function should be public so we add a global symbol to it)
- auto lift_func =
- WithAttr(lift_plan_.f_transform_params, tvm::attr::kGlobalSymbol,
new_func_name);
- builder_->AddFunction(lift_func, new_func_name);
+ // Step 2: Stash the lifted function to add to the module
+ transform_param_funcs_.Set(current_gvar_.value(),
lift_plan_.f_transform_params);
// Step 3: Update the current function.
// Step 3.1: Update the function signature
- Var params("params", lift_plan_.f_transform_params->ret_struct_info);
- Array<Var> new_params;
- for (int i = 0; i < num_input; ++i) {
- new_params.push_back(func->params[i]);
+ Array<StructInfo> param_fields =
+
Downcast<TupleStructInfo>(lift_plan_.f_transform_params->ret_struct_info)->fields;
+ Array<Var> new_params(func->params.begin(), func->params.begin() +
num_input);
+ for (size_t i = 0; i < param_fields.size(); i++) {
+ std::stringstream name;
+ name << "transformed_param_" << i;
+ Var param(name.str(), param_fields[i]);
+ new_params.push_back(param);
}
- new_params.push_back(params);
// Step 3.2: Update the function body
for (const auto& [var, index] : lift_plan_.output_to_index) {
- param_remap_[var] = TupleGetItem(params, index);
+ ICHECK_LT(num_input + index, new_params.size());
+ param_remap_[var] = new_params[num_input + index];
}
auto new_body = VisitWithNewScope(func->body, new_params);
- // Step 3.3: Remove function attributes that are not needed
- auto new_attrs = func->attrs;
- auto* new_attrs_node = new_attrs.CopyOnWrite();
- new_attrs_node->dict.erase(attr_num_input_);
- if (new_attrs->dict.empty()) {
- new_attrs = NullValue<DictAttrs>();
- }
-
- Function new_func(new_params, new_body, func->ret_struct_info,
func->is_pure, new_attrs);
- return new_func;
+ return Function(new_params, new_body, func->ret_struct_info,
func->is_pure, func->attrs);
}
void VisitBinding_(const VarBindingNode* binding) final {
@@ -315,12 +305,40 @@ class TransformParamsLifter : public ExprMutator {
std::unordered_map<Var, Expr, ObjectPtrHash, ObjectPtrEqual> param_remap_;
// The plan of lifting the transform params
LiftTransformParamsInfoPlan lift_plan_;
+
+ Map<GlobalVar, Function> transform_param_funcs_;
+ Optional<GlobalVar> current_gvar_;
};
namespace transform {
Pass LiftTransformParams() {
- runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =
- [=](IRModule m, PassContext pc) { return
TransformParamsLifter(m).Lift(); };
+ runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =
[=](IRModule mod,
+
PassContext pc) {
+ TransformParamsLifter mutator(mod);
+
+ IRModule updates;
+ for (const auto& [gvar, func] : mod->functions) {
+ if (auto opt = func.as<relax::Function>()) {
+ auto new_func = mutator.VisitFunction(gvar, opt.value());
+ if (!new_func.same_as(func)) {
+ updates->Add(gvar, new_func);
+ }
+ }
+ }
+ for (const auto& [gvar, transform_func] :
mutator.GetTransformParamFunctions()) {
+ String name = gvar->name_hint + "_transform_params";
+ GlobalVar new_gvar(name);
+ new_gvar->struct_info_ = transform_func->struct_info_;
+
+ updates->Add(new_gvar, WithAttr(transform_func,
tvm::attr::kGlobalSymbol, name));
+ }
+
+ if (updates->functions.size()) {
+ mod.CopyOnWrite()->Update(updates);
+ }
+
+ return mod;
+ };
return CreateModulePass(pass_func, 1, "LiftTransformParams", {});
}
diff --git a/tests/python/relax/test_transform_bundle_model_params.py
b/tests/python/relax/test_transform_bundle_model_params.py
new file mode 100644
index 0000000000..8b0a15e647
--- /dev/null
+++ b/tests/python/relax/test_transform_bundle_model_params.py
@@ -0,0 +1,104 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import tvm
+import tvm.testing
+
+from tvm import relax
+from tvm.script import relax as R, tir as T
+from tvm.script import ir as I
+import tvm.topi.testing
+
+
+def test_basic():
+ @tvm.script.ir_module
+ class Before:
+ @R.function
+ def main(
+ a: R.Tensor([16], "float32"),
+ b: R.Tensor([16], "float32"),
+ c: R.Tensor([16], "float32"),
+ ) -> R.Tensor([16], "float32"):
+ R.func_attr({"num_input": 1})
+ expr = a
+ expr = R.add(expr, b)
+ expr = R.add(expr, c)
+ return expr
+
+ @tvm.script.ir_module
+ class Expected:
+ @R.function
+ def main(
+ a: R.Tensor([16], "float32"),
+ params: R.Tuple(R.Tensor([16], "float32"), R.Tensor([16],
"float32")),
+ ) -> R.Tensor([16], "float32"):
+ expr = a
+ b = params[0]
+ expr = R.add(expr, b)
+ c = params[1]
+ expr = R.add(expr, c)
+ return expr
+
+ mod = Before
+ after = relax.transform.BundleModelParams()(mod)
+ tvm.ir.assert_structural_equal(after, Expected)
+
+
+def test_no_model_params():
+ """If all parameters are inputs, model params should be an empty tuple
+
+ This ensures that a caller does not need to check whether the
+ model has compile-time inputs, and can instead provide the output
+ of a lifted parameter transformation in all cases, even if that
+ transformation returns an empty tuple.
+ """
+
+ @tvm.script.ir_module
+ class Before:
+ @R.function
+ def main(
+ a: R.Tensor([16], "float32"),
+ b: R.Tensor([16], "float32"),
+ c: R.Tensor([16], "float32"),
+ ) -> R.Tensor([16], "float32"):
+ R.func_attr({"num_input": 3})
+ expr = a
+ expr = R.add(expr, b)
+ expr = R.add(expr, c)
+ return expr
+
+ @tvm.script.ir_module
+ class Expected:
+ @R.function
+ def main(
+ a: R.Tensor([16], "float32"),
+ b: R.Tensor([16], "float32"),
+ c: R.Tensor([16], "float32"),
+ params: R.Tuple(),
+ ) -> R.Tensor([16], "float32"):
+ expr = a
+ expr = R.add(expr, b)
+ expr = R.add(expr, c)
+ return expr
+
+ mod = Before
+ after = relax.transform.BundleModelParams()(mod)
+ tvm.ir.assert_structural_equal(after, Expected)
+
+
+if __name__ == "__main__":
+ tvm.testing.main()
diff --git a/tests/python/relax/test_transform_lift_transform_params.py
b/tests/python/relax/test_transform_lift_transform_params.py
index 2a045e9acb..c23efe655b 100644
--- a/tests/python/relax/test_transform_lift_transform_params.py
+++ b/tests/python/relax/test_transform_lift_transform_params.py
@@ -62,15 +62,15 @@ def test_basic():
@R.function
def main(
x: R.Tensor((1, 3, 224, 224), dtype="float32"),
- params: R.Tuple(
- R.Tensor((16, 16, 3, 3), dtype="float32"), R.Tensor((16, 3, 3,
3), dtype="float32")
- ),
+ param0: R.Tensor((16, 16, 3, 3), dtype="float32"),
+ param1: R.Tensor((16, 3, 3, 3), dtype="float32"),
) -> R.Tensor((1, 16, 224, 224), dtype="float32"):
+ R.func_attr({"num_input": 1})
with R.dataflow():
- lv: R.Tensor((16, 3, 3, 3), dtype="float32") = params[1]
+ param1 = param1
conv1: R.Tensor((1, 16, 224, 224), dtype="float32") =
R.nn.conv2d(
x,
- lv,
+ param1,
strides=[1, 1],
padding=[1, 1, 1, 1],
dilation=[1, 1],
@@ -80,10 +80,10 @@ def test_basic():
out_layout="NCHW",
out_dtype="void",
)
- lv1: R.Tensor((16, 16, 3, 3), dtype="float32") = params[0]
+ param0 = param0
conv2: R.Tensor((1, 16, 224, 224), dtype="float32") =
R.nn.conv2d(
conv1,
- lv1,
+ param0,
strides=[1, 1],
padding=[1, 1, 1, 1],
dilation=[1, 1],
@@ -161,12 +161,12 @@ def test_tuple():
@R.function
def main(
x: R.Tensor((1, 16, 224, 224), dtype="float32"),
- params: R.Tuple(
- R.Tensor((16, 16, 3, 3), dtype="float32"), R.Tensor((16, 16,
3, 3), dtype="float32")
- ),
+ param0: R.Tensor((16, 16, 3, 3), dtype="float32"),
+ param1: R.Tensor((16, 16, 3, 3), dtype="float32"),
) -> R.Tensor((1, 16, 224, 224), dtype="float32"):
+ R.func_attr({"num_input": 1})
with R.dataflow():
- lv: R.Tensor((16, 16, 3, 3), dtype="float32") = params[1]
+ lv: R.Tensor((16, 16, 3, 3), dtype="float32") = param1
conv1: R.Tensor((1, 16, 224, 224), dtype="float32") =
R.nn.conv2d(
x,
lv,
@@ -179,7 +179,7 @@ def test_tuple():
out_layout="NCHW",
out_dtype="void",
)
- lv1: R.Tensor((16, 16, 3, 3), dtype="float32") = params[0]
+ lv1: R.Tensor((16, 16, 3, 3), dtype="float32") = param0
conv2: R.Tensor((1, 16, 224, 224), dtype="float32") =
R.nn.conv2d(
conv1,
lv1,
@@ -271,18 +271,17 @@ def test_condition():
@R.function
def main(
x: R.Tensor((1, 16, 224, 224), "float32"),
- params: R.Tuple(
- R.Tensor((16, 16, 3, 3), dtype="float32"),
- R.Tensor((16, 16, 3, 3), dtype="float32"),
- R.Tensor((), dtype="bool"),
- ),
+ param0: R.Tensor((16, 16, 3, 3), dtype="float32"),
+ param1: R.Tensor((16, 16, 3, 3), dtype="float32"),
+ param2: R.Tensor((), dtype="bool"),
) -> R.Tensor((1, 16, 224, 224), "float32"):
- gv: R.Tensor((), dtype="bool") = params[2]
+ R.func_attr({"num_input": 1})
+ gv: R.Tensor((), dtype="bool") = param2
if gv:
- gv1: R.Tensor((16, 16, 3, 3), dtype="float32") = params[0]
+ gv1: R.Tensor((16, 16, 3, 3), dtype="float32") = param0
w: R.Tensor((16, 16, 3, 3), dtype="float32") = gv1
else:
- gv2: R.Tensor((16, 16, 3, 3), dtype="float32") = params[1]
+ gv2: R.Tensor((16, 16, 3, 3), dtype="float32") = param1
w: R.Tensor((16, 16, 3, 3), dtype="float32") = gv2
with R.dataflow():
conv1 = R.nn.conv2d(x, w, padding=(1, 1), data_layout="NCHW",
kernel_layout="OIHW")
@@ -337,10 +336,11 @@ def test_multiple_functions():
@R.function
def func1(
x: R.Tensor((256, 256), dtype="float32"),
- params: R.Tuple(R.Tensor((256, 256), dtype="float32")),
+ param0: R.Tensor((256, 256), dtype="float32"),
) -> R.Tensor((256, 256), dtype="float32"):
+ R.func_attr({"num_input": 1})
with R.dataflow():
- lv: R.Tensor((256, 256), dtype="float32") = params[0]
+ lv: R.Tensor((256, 256), dtype="float32") = param0
y: R.Tensor((256, 256), dtype="float32") = R.matmul(x, lv,
out_dtype="void")
R.output(y)
return y
@@ -359,10 +359,11 @@ def test_multiple_functions():
@R.function
def func2(
x: R.Tensor((256, 256), dtype="float32"),
- params: R.Tuple(R.Tensor((256, 128), dtype="float32")),
+ param0: R.Tensor((256, 128), dtype="float32"),
) -> R.Tensor((256, 128), dtype="float32"):
+ R.func_attr({"num_input": 1})
with R.dataflow():
- lv1: R.Tensor((256, 128), dtype="float32") = params[0]
+ lv1: R.Tensor((256, 128), dtype="float32") = param0
y: R.Tensor((256, 128), dtype="float32") = R.matmul(x, lv1,
out_dtype="void")
R.output(y)
return y
@@ -415,10 +416,11 @@ def test_stop_lifting():
@R.function
def func1(
x: R.Tensor((256, 256), dtype="float32"),
- params: R.Tuple(R.Tensor((256, 256), dtype="float32")),
+ param0: R.Tensor((256, 256), dtype="float32"),
) -> R.Tensor((256, 256), dtype="float32"):
+ R.func_attr({"num_input": 1})
with R.dataflow():
- lv: R.Tensor((256, 256), dtype="float32") = params[0]
+ lv: R.Tensor((256, 256), dtype="float32") = param0
w1_add: R.Tensor((256, 256), dtype="float32") = R.add(lv,
R.const(1, "float32"))
y: R.Tensor((256, 256), dtype="float32") = R.matmul(x, w1_add,
out_dtype="void")
R.output(y)
@@ -440,9 +442,9 @@ def test_stop_lifting():
tvm.ir.assert_structural_equal(after, Expected)
-def test_symbolic_var():
+def test_symbolic_var_1():
@tvm.script.ir_module
- class Before1:
+ class Before:
@R.function
def main(shape: R.Shape(["n"])):
R.func_attr({"num_input": 1})
@@ -452,7 +454,7 @@ def test_symbolic_var():
return shape
@I.ir_module
- class Expected1:
+ class Expected:
@R.function
def main_transform_params(params: R.Tuple) -> R.Tuple:
with R.dataflow():
@@ -461,15 +463,22 @@ def test_symbolic_var():
return gv
@R.function
- def main(shape: R.Shape(["n"]), params: R.Tuple) -> R.Shape(["n"]):
+ def main(shape: R.Shape(["n"])) -> R.Shape(["n"]):
+ R.func_attr({"num_input": 1})
n = T.int64()
with R.dataflow():
zeros: R.Tensor((n, n), dtype="float32") = R.zeros(R.shape([n,
n]), dtype="float32")
R.output()
return shape
+ mod = Before
+ after = relax.transform.LiftTransformParams()(mod)
+ tvm.ir.assert_structural_equal(after, Expected)
+
+
+def test_symbolic_var_2():
@I.ir_module
- class Before2:
+ class Before:
@T.prim_func
def zeros(var_T_full: T.handle):
T.func_attr({"tir.noalias": T.bool(True)})
@@ -484,9 +493,9 @@ def test_symbolic_var():
@R.function
def main(shape: R.Shape(["n"])) -> R.Shape(["n"]):
- n = T.int64()
R.func_attr({"num_input": 1})
- cls = Before2
+ n = T.int64()
+ cls = Before
with R.dataflow():
zeros = R.call_tir(
cls.zeros, R.tuple(), out_sinfo=R.Tensor((n, n),
dtype="float32")
@@ -495,7 +504,7 @@ def test_symbolic_var():
return shape
@I.ir_module
- class Expected2:
+ class Expected:
@T.prim_func
def zeros(var_T_full: T.handle):
T.func_attr({"tir.noalias": T.bool(True)})
@@ -517,9 +526,10 @@ def test_symbolic_var():
return gv
@R.function
- def main(shape: R.Shape(["n"]), params: R.Tuple) -> R.Shape(["n"]):
+ def main(shape: R.Shape(["n"])) -> R.Shape(["n"]):
+ R.func_attr({"num_input": 1})
n = T.int64()
- cls = Expected2
+ cls = Expected
with R.dataflow():
zeros = R.call_tir(
cls.zeros, R.tuple(), out_sinfo=R.Tensor((n, n),
dtype="float32")
@@ -527,13 +537,9 @@ def test_symbolic_var():
R.output()
return shape
- mod = Before1
- after = relax.transform.LiftTransformParams()(mod)
- tvm.ir.assert_structural_equal(after, Expected1)
-
- mod = Before2
+ mod = Before
after = relax.transform.LiftTransformParams()(mod)
- tvm.ir.assert_structural_equal(after, Expected2)
+ tvm.ir.assert_structural_equal(after, Expected)
if __name__ == "__main__":