This is an automated email from the ASF dual-hosted git repository.

csullivan pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new c5b7afc721 [Unity] Implemented BundleModelParams transform (#15657)
c5b7afc721 is described below

commit c5b7afc721a1539c6d04d75e21ef0922d6b2a3dc
Author: Eric Lunderberg <[email protected]>
AuthorDate: Wed Sep 6 12:22:51 2023 -0700

    [Unity] Implemented BundleModelParams transform (#15657)
    
    * [Unity] Implemented BundleModelParams transform
    
    Implemented `relax.transform.BundleModelParams`, which groups
    parameters into user-provided runtime parameters, and a tuple of
    compile-time model weights.  This functionality was previously part of
    `LiftTransformParams`, but is being separated to allow for composable
    functions which each mutate the parameters.
    
    * [Unity] Keep parameters separate in LiftTransformParams
    
    Because parameters may be mutated at multiple points when preparing a
    model (e.g. first by explicit quantizing, and then by lifted
    transformation), each step that alters the parameters should retain
    the same general form.
    
    Prior to this commit, the `LiftTransformParams` pass extracted an
    independent `func_transform_params` function that could be applied to
    the weights, removed the `"num_input"` attribute, and bundled the
    transformed model parameters into a single tuple parameter.
    
    This commit updates `LiftTransformParams` to only perform the first
    step, generating the independent `func_transform_params` function,
    while the remaining steps are performed by `BundleModelParams`.
---
 python/tvm/relax/transform/transform.py            |  17 +++
 src/relax/transform/bundle_model_params.cc         | 119 +++++++++++++++++++++
 src/relax/transform/lift_transform_params.cc       | 106 ++++++++++--------
 .../relax/test_transform_bundle_model_params.py    | 104 ++++++++++++++++++
 .../relax/test_transform_lift_transform_params.py  |  90 ++++++++--------
 5 files changed, 350 insertions(+), 86 deletions(-)

diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py
index aff73167e4..5c9a2ae554 100644
--- a/python/tvm/relax/transform/transform.py
+++ b/python/tvm/relax/transform/transform.py
@@ -720,6 +720,23 @@ def LiftTransformParams() -> tvm.ir.transform.Pass:
     return _ffi_api.LiftTransformParams()  # type: ignore
 
 
+def BundleModelParams() -> tvm.ir.transform.Pass:
+    """Bundle several model parameters into a single tuple parameter
+
+    For each function with the "num_input" attribute, the first
+    `num_input` parameters (run-time inputs, e.g. activations) are kept,
+    the remaining compile-time weights are replaced by a single tuple
+    parameter (empty if there are no weights), and the "num_input"
+    attribute is removed.
+
+    Returns
+    -------
+    ret : tvm.transform.Pass
+        The registered pass for bundling model parameters.
+    """
+    return _ffi_api.BundleModelParams()  # type: ignore
+
+
 def LegalizeOps(
     customize_legalize_map: Optional[Dict[str, LegalizeFunc]] = None, 
enable_warning: bool = False
 ):
diff --git a/src/relax/transform/bundle_model_params.cc b/src/relax/transform/bundle_model_params.cc
new file mode 100644
index 0000000000..8f6e7a1291
--- /dev/null
+++ b/src/relax/transform/bundle_model_params.cc
@@ -0,0 +1,119 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/relax/transform/bundle_model_params.cc
+ * \brief Bundle compile-time model parameters into a single tuple parameter.
+ */
+
+#include <tvm/relax/analysis.h>
+#include <tvm/relax/expr.h>
+#include <tvm/relax/expr_functor.h>
+#include <tvm/relax/transform.h>
+#include <tvm/runtime/logging.h>
+
+#include "utils.h"
+
+namespace tvm {
+namespace relax {
+
+static const auto kAttrNumInput = "num_input";
+
+class ModelParamBundler : public ExprMutator {
+ public:
+  ModelParamBundler() {}
+
+  Expr VisitExpr_(const FunctionNode* op) override {
+    Function func = GetRef<Function>(op);
+    auto opt_num_input = func->attrs.GetAttr<Integer>(kAttrNumInput);
+    if (!opt_num_input) return func;  // Not annotated: returned unchanged, not recursed into.
+    auto signed_num_input = opt_num_input.value()->value;
+    ICHECK_GE(signed_num_input, 0);
+    ICHECK_LE(signed_num_input, func->params.size())
+        << "Function was declared to have " << signed_num_input << " runtime inputs, "
+        << "but only has " << func->params.size() << " parameters total.";
+    size_t num_input = signed_num_input;
+
+    // The first `num_input` parameters are run-time inputs and are kept as-is.
+    Array<Var> params(func->params.begin(), func->params.begin() + num_input);
+    // The remaining parameters are compile-time weights, bundled into one tuple.
+    Array<StructInfo> param_tuple;
+    for (size_t i = num_input; i < func->params.size(); i++) {
+      param_tuple.push_back(GetStructInfo(func->params[i]));
+    }
+
+    // Appended even when there are no weights, so callers can always pass a tuple.
+    Var var_param_tuple("model_params", TupleStructInfo(param_tuple));
+    params.push_back(var_param_tuple);
+
+    // Each former weight parameter is replaced by the matching tuple element.
+    for (size_t i = num_input; i < func->params.size(); i++) {
+      var_to_expr_.Set(func->params[i], TupleGetItem(var_param_tuple, i - num_input));
+    }
+
+    // Rebuild via the constructor (rather than assigning `params` in place) so the
+    // struct info matches the new signature even if ExprMutator returns it as-is.
+    func = WithoutAttr(func, kAttrNumInput);
+    func = Function(params, func->body, func->ret_struct_info, func->is_pure, func->attrs,
+                    func->span);
+    return ExprMutator::VisitExpr_(func.get());
+  }
+
+  // Substitutes the tuple element for any bundled weight parameter.
+  Expr VisitExpr_(const VarNode* op) override {
+    if (auto it = var_to_expr_.find(GetRef<Var>(op)); it != var_to_expr_.end()) {
+      return (*it).second;
+    }
+    return ExprMutator::VisitExpr_(op);
+  }
+
+ private:
+  Map<Var, Expr> var_to_expr_;  // Weight parameter -> TupleGetItem replacement.
+};
+
+namespace transform {
+Pass BundleModelParams() {
+  runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func = [=](IRModule mod,
+                                                                            PassContext pc) {
+    IRModule updates;  // Collects only the functions that were actually rewritten.
+    // One mutator is reused for every function in the module.
+    ModelParamBundler mutator;
+    // Rewrite each Relax function; functions of other kinds are left untouched.
+    for (const auto& [gvar, func] : mod->functions) {
+      if (auto opt = func.as<relax::Function>()) {
+        auto new_func = Downcast<relax::Function>(mutator(opt.value()));
+        if (!new_func.same_as(func)) {
+          updates->Add(gvar, new_func);
+        }
+      }
+    }
+    // Only copy-on-write the module when something changed.
+    if (!updates->functions.empty()) {
+      mod.CopyOnWrite()->Update(updates);
+    }
+    return mod;
+  };
+  return CreateModulePass(pass_func, 1, "BundleModelParams", {});
+}
+
+TVM_REGISTER_GLOBAL("relax.transform.BundleModelParams").set_body_typed(BundleModelParams);
+
+}  // namespace transform
+}  // namespace relax
+}  // namespace tvm
diff --git a/src/relax/transform/lift_transform_params.cc b/src/relax/transform/lift_transform_params.cc
index fb1f292776..afa1e191f4 100644
--- a/src/relax/transform/lift_transform_params.cc
+++ b/src/relax/transform/lift_transform_params.cc
@@ -218,70 +218,60 @@ class LiftTransformParamsPlanner : public ExprVisitor {
  *\brief The rewriter that lifts the transform params of a function and 
updates the original
  * function.
  */
-class TransformParamsLifter : public ExprMutator {
+class TransformParamsLifter : ExprMutator {
  public:
   explicit TransformParamsLifter(const IRModule& module) : ExprMutator(module) 
{}
 
-  IRModule Lift() {
-    auto mod = builder_->GetContextIRModule();
-    for (const auto& [gv, base_func] : mod->functions) {
-      // Skip non-Relax functions.
-      const auto* func_ = base_func.as<FunctionNode>();
-      if (func_ == nullptr) {
-        continue;
-      }
-      // Skip functions that do not have the `num_input` attribute.
-      Optional<Integer> opt_num_input = 
func_->attrs.GetAttr<Integer>(attr_num_input_);
-      if (!opt_num_input.defined()) {
-        continue;
-      }
-      Function func = RewriteFunc(GetRef<Function>(func_), 
opt_num_input.value()->value,
-                                  gv->name_hint + "_transform_params");
-      builder_->UpdateFunction(gv, func);
-    }
-
-    return builder_->GetContextIRModule();
+  Function VisitFunction(GlobalVar gvar, Function func) {
+    current_gvar_ = gvar;
+    auto out = Downcast<Function>(VisitExpr(std::move(func)));
+    current_gvar_ = NullOpt;
+    return out;
   }
 
+  Map<GlobalVar, Function> GetTransformParamFunctions() const { return 
transform_param_funcs_; }
+
  private:
-  Function RewriteFunc(const Function& func, int num_input, String 
new_func_name) {
+  Expr VisitExpr_(const FunctionNode* op) override {
+    auto func = GetRef<Function>(op);
+    Optional<Integer> opt_num_input = 
func->attrs.GetAttr<Integer>(attr_num_input_);
+    if (!opt_num_input) {
+      return func;
+    }
+    auto signed_num_input = opt_num_input.value()->value;
+    ICHECK_GE(signed_num_input, 0);
+    ICHECK_LE(signed_num_input, func->params.size());
+    size_t num_input = signed_num_input;
+
     LiftTransformParamsPlanner planner;
 
     // Step 1: Create the plan of lifting transform params
     lift_plan_ = planner.Plan(func, num_input);
 
-    // Step 2: Add the lifted function to the module
-    // (The lifted function should be public so we add a global symbol to it)
-    auto lift_func =
-        WithAttr(lift_plan_.f_transform_params, tvm::attr::kGlobalSymbol, 
new_func_name);
-    builder_->AddFunction(lift_func, new_func_name);
+    // Step 2: Stash the lifted function to add to the module
+    transform_param_funcs_.Set(current_gvar_.value(), 
lift_plan_.f_transform_params);
 
     // Step 3: Update the current function.
 
     // Step 3.1: Update the function signature
-    Var params("params", lift_plan_.f_transform_params->ret_struct_info);
-    Array<Var> new_params;
-    for (int i = 0; i < num_input; ++i) {
-      new_params.push_back(func->params[i]);
+    Array<StructInfo> param_fields =
+        
Downcast<TupleStructInfo>(lift_plan_.f_transform_params->ret_struct_info)->fields;
+    Array<Var> new_params(func->params.begin(), func->params.begin() + 
num_input);
+    for (size_t i = 0; i < param_fields.size(); i++) {
+      std::stringstream name;
+      name << "transformed_param_" << i;
+      Var param(name.str(), param_fields[i]);
+      new_params.push_back(param);
     }
-    new_params.push_back(params);
 
     // Step 3.2: Update the function body
     for (const auto& [var, index] : lift_plan_.output_to_index) {
-      param_remap_[var] = TupleGetItem(params, index);
+      ICHECK_LT(num_input + index, new_params.size());
+      param_remap_[var] = new_params[num_input + index];
     }
     auto new_body = VisitWithNewScope(func->body, new_params);
 
-    // Step 3.3: Remove function attributes that are not needed
-    auto new_attrs = func->attrs;
-    auto* new_attrs_node = new_attrs.CopyOnWrite();
-    new_attrs_node->dict.erase(attr_num_input_);
-    if (new_attrs->dict.empty()) {
-      new_attrs = NullValue<DictAttrs>();
-    }
-
-    Function new_func(new_params, new_body, func->ret_struct_info, 
func->is_pure, new_attrs);
-    return new_func;
+    return Function(new_params, new_body, func->ret_struct_info, 
func->is_pure, func->attrs);
   }
 
   void VisitBinding_(const VarBindingNode* binding) final {
@@ -315,12 +305,40 @@ class TransformParamsLifter : public ExprMutator {
   std::unordered_map<Var, Expr, ObjectPtrHash, ObjectPtrEqual> param_remap_;
   // The plan of lifting the transform params
   LiftTransformParamsInfoPlan lift_plan_;
+
+  Map<GlobalVar, Function> transform_param_funcs_;
+  Optional<GlobalVar> current_gvar_;
 };
 
 namespace transform {
 Pass LiftTransformParams() {
-  runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =
-      [=](IRModule m, PassContext pc) { return TransformParamsLifter(m).Lift(); };
+  runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func = [=](IRModule mod,
+                                                                            PassContext pc) {
+    TransformParamsLifter mutator(mod);
+    // Rewritten functions, plus the lifted "<name>_transform_params" functions.
+    IRModule updates;
+    for (const auto& [gvar, func] : mod->functions) {
+      if (auto opt = func.as<relax::Function>()) {
+        auto new_func = mutator.VisitFunction(gvar, opt.value());
+        if (!new_func.same_as(func)) {
+          updates->Add(gvar, new_func);
+        }
+      }
+    }
+    for (const auto& [gvar, transform_func] : mutator.GetTransformParamFunctions()) {
+      String name = gvar->name_hint + "_transform_params";  // e.g. "main_transform_params"
+      GlobalVar new_gvar(name);
+      new_gvar->struct_info_ = transform_func->struct_info_;
+      // NOTE(review): a re-run would reuse `name` (num_input is kept); confirm handling.
+      updates->Add(new_gvar, WithAttr(transform_func, tvm::attr::kGlobalSymbol, name));
+    }
+    // Only copy-on-write the module when something changed.
+    if (updates->functions.size()) {
+      mod.CopyOnWrite()->Update(updates);
+    }
+
+    return mod;
+  };
   return CreateModulePass(pass_func, 1, "LiftTransformParams", {});
 }
 
diff --git a/tests/python/relax/test_transform_bundle_model_params.py b/tests/python/relax/test_transform_bundle_model_params.py
new file mode 100644
index 0000000000..8b0a15e647
--- /dev/null
+++ b/tests/python/relax/test_transform_bundle_model_params.py
@@ -0,0 +1,104 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import tvm
+import tvm.testing
+
+from tvm import relax
+from tvm.script import relax as R, tir as T
+from tvm.script import ir as I
+import tvm.topi.testing
+
+
+def test_basic():
+    @tvm.script.ir_module
+    class Before:
+        @R.function
+        def main(
+            a: R.Tensor([16], "float32"),
+            b: R.Tensor([16], "float32"),
+            c: R.Tensor([16], "float32"),
+        ) -> R.Tensor([16], "float32"):
+            R.func_attr({"num_input": 1})  # `a` is a run-time input; `b`, `c` are weights
+            expr = a
+            expr = R.add(expr, b)
+            expr = R.add(expr, c)
+            return expr
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(
+            a: R.Tensor([16], "float32"),
+            params: R.Tuple(R.Tensor([16], "float32"), R.Tensor([16], "float32")),
+        ) -> R.Tensor([16], "float32"):
+            expr = a
+            b = params[0]
+            expr = R.add(expr, b)
+            c = params[1]
+            expr = R.add(expr, c)
+            return expr
+
+    mod = Before
+    after = relax.transform.BundleModelParams()(mod)  # also strips "num_input"
+    tvm.ir.assert_structural_equal(after, Expected)
+
+
+def test_no_model_params():
+    """If all parameters are inputs, model params should be an empty tuple
+
+    This ensures that a caller does not need to check whether the
+    model has compile-time inputs, and can instead provide the output
+    of a lifted parameter transformation in all cases, even if that
+    transformation returns an empty tuple.
+    """
+
+    @tvm.script.ir_module
+    class Before:
+        @R.function
+        def main(
+            a: R.Tensor([16], "float32"),
+            b: R.Tensor([16], "float32"),
+            c: R.Tensor([16], "float32"),
+        ) -> R.Tensor([16], "float32"):
+            R.func_attr({"num_input": 3})  # every parameter is a run-time input
+            expr = a
+            expr = R.add(expr, b)
+            expr = R.add(expr, c)
+            return expr
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(
+            a: R.Tensor([16], "float32"),
+            b: R.Tensor([16], "float32"),
+            c: R.Tensor([16], "float32"),
+            params: R.Tuple(),  # appended even though there are no weights
+        ) -> R.Tensor([16], "float32"):
+            expr = a
+            expr = R.add(expr, b)
+            expr = R.add(expr, c)
+            return expr
+
+    mod = Before
+    after = relax.transform.BundleModelParams()(mod)
+    tvm.ir.assert_structural_equal(after, Expected)
+
+
+if __name__ == "__main__":
+    tvm.testing.main()  # entry point when this file is executed directly
diff --git a/tests/python/relax/test_transform_lift_transform_params.py b/tests/python/relax/test_transform_lift_transform_params.py
index 2a045e9acb..c23efe655b 100644
--- a/tests/python/relax/test_transform_lift_transform_params.py
+++ b/tests/python/relax/test_transform_lift_transform_params.py
@@ -62,15 +62,15 @@ def test_basic():
         @R.function
         def main(
             x: R.Tensor((1, 3, 224, 224), dtype="float32"),
-            params: R.Tuple(
-                R.Tensor((16, 16, 3, 3), dtype="float32"), R.Tensor((16, 3, 3, 
3), dtype="float32")
-            ),
+            param0: R.Tensor((16, 16, 3, 3), dtype="float32"),
+            param1: R.Tensor((16, 3, 3, 3), dtype="float32"),
         ) -> R.Tensor((1, 16, 224, 224), dtype="float32"):
+            R.func_attr({"num_input": 1})
             with R.dataflow():
-                lv: R.Tensor((16, 3, 3, 3), dtype="float32") = params[1]
+                param1 = param1
                 conv1: R.Tensor((1, 16, 224, 224), dtype="float32") = 
R.nn.conv2d(
                     x,
-                    lv,
+                    param1,
                     strides=[1, 1],
                     padding=[1, 1, 1, 1],
                     dilation=[1, 1],
@@ -80,10 +80,10 @@ def test_basic():
                     out_layout="NCHW",
                     out_dtype="void",
                 )
-                lv1: R.Tensor((16, 16, 3, 3), dtype="float32") = params[0]
+                param0 = param0
                 conv2: R.Tensor((1, 16, 224, 224), dtype="float32") = 
R.nn.conv2d(
                     conv1,
-                    lv1,
+                    param0,
                     strides=[1, 1],
                     padding=[1, 1, 1, 1],
                     dilation=[1, 1],
@@ -161,12 +161,12 @@ def test_tuple():
         @R.function
         def main(
             x: R.Tensor((1, 16, 224, 224), dtype="float32"),
-            params: R.Tuple(
-                R.Tensor((16, 16, 3, 3), dtype="float32"), R.Tensor((16, 16, 
3, 3), dtype="float32")
-            ),
+            param0: R.Tensor((16, 16, 3, 3), dtype="float32"),
+            param1: R.Tensor((16, 16, 3, 3), dtype="float32"),
         ) -> R.Tensor((1, 16, 224, 224), dtype="float32"):
+            R.func_attr({"num_input": 1})
             with R.dataflow():
-                lv: R.Tensor((16, 16, 3, 3), dtype="float32") = params[1]
+                lv: R.Tensor((16, 16, 3, 3), dtype="float32") = param1
                 conv1: R.Tensor((1, 16, 224, 224), dtype="float32") = 
R.nn.conv2d(
                     x,
                     lv,
@@ -179,7 +179,7 @@ def test_tuple():
                     out_layout="NCHW",
                     out_dtype="void",
                 )
-                lv1: R.Tensor((16, 16, 3, 3), dtype="float32") = params[0]
+                lv1: R.Tensor((16, 16, 3, 3), dtype="float32") = param0
                 conv2: R.Tensor((1, 16, 224, 224), dtype="float32") = 
R.nn.conv2d(
                     conv1,
                     lv1,
@@ -271,18 +271,17 @@ def test_condition():
         @R.function
         def main(
             x: R.Tensor((1, 16, 224, 224), "float32"),
-            params: R.Tuple(
-                R.Tensor((16, 16, 3, 3), dtype="float32"),
-                R.Tensor((16, 16, 3, 3), dtype="float32"),
-                R.Tensor((), dtype="bool"),
-            ),
+            param0: R.Tensor((16, 16, 3, 3), dtype="float32"),
+            param1: R.Tensor((16, 16, 3, 3), dtype="float32"),
+            param2: R.Tensor((), dtype="bool"),
         ) -> R.Tensor((1, 16, 224, 224), "float32"):
-            gv: R.Tensor((), dtype="bool") = params[2]
+            R.func_attr({"num_input": 1})
+            gv: R.Tensor((), dtype="bool") = param2
             if gv:
-                gv1: R.Tensor((16, 16, 3, 3), dtype="float32") = params[0]
+                gv1: R.Tensor((16, 16, 3, 3), dtype="float32") = param0
                 w: R.Tensor((16, 16, 3, 3), dtype="float32") = gv1
             else:
-                gv2: R.Tensor((16, 16, 3, 3), dtype="float32") = params[1]
+                gv2: R.Tensor((16, 16, 3, 3), dtype="float32") = param1
                 w: R.Tensor((16, 16, 3, 3), dtype="float32") = gv2
             with R.dataflow():
                 conv1 = R.nn.conv2d(x, w, padding=(1, 1), data_layout="NCHW", 
kernel_layout="OIHW")
@@ -337,10 +336,11 @@ def test_multiple_functions():
         @R.function
         def func1(
             x: R.Tensor((256, 256), dtype="float32"),
-            params: R.Tuple(R.Tensor((256, 256), dtype="float32")),
+            param0: R.Tensor((256, 256), dtype="float32"),
         ) -> R.Tensor((256, 256), dtype="float32"):
+            R.func_attr({"num_input": 1})
             with R.dataflow():
-                lv: R.Tensor((256, 256), dtype="float32") = params[0]
+                lv: R.Tensor((256, 256), dtype="float32") = param0
                 y: R.Tensor((256, 256), dtype="float32") = R.matmul(x, lv, 
out_dtype="void")
                 R.output(y)
             return y
@@ -359,10 +359,11 @@ def test_multiple_functions():
         @R.function
         def func2(
             x: R.Tensor((256, 256), dtype="float32"),
-            params: R.Tuple(R.Tensor((256, 128), dtype="float32")),
+            param0: R.Tensor((256, 128), dtype="float32"),
         ) -> R.Tensor((256, 128), dtype="float32"):
+            R.func_attr({"num_input": 1})
             with R.dataflow():
-                lv1: R.Tensor((256, 128), dtype="float32") = params[0]
+                lv1: R.Tensor((256, 128), dtype="float32") = param0
                 y: R.Tensor((256, 128), dtype="float32") = R.matmul(x, lv1, 
out_dtype="void")
                 R.output(y)
             return y
@@ -415,10 +416,11 @@ def test_stop_lifting():
         @R.function
         def func1(
             x: R.Tensor((256, 256), dtype="float32"),
-            params: R.Tuple(R.Tensor((256, 256), dtype="float32")),
+            param0: R.Tensor((256, 256), dtype="float32"),
         ) -> R.Tensor((256, 256), dtype="float32"):
+            R.func_attr({"num_input": 1})
             with R.dataflow():
-                lv: R.Tensor((256, 256), dtype="float32") = params[0]
+                lv: R.Tensor((256, 256), dtype="float32") = param0
                 w1_add: R.Tensor((256, 256), dtype="float32") = R.add(lv, 
R.const(1, "float32"))
                 y: R.Tensor((256, 256), dtype="float32") = R.matmul(x, w1_add, 
out_dtype="void")
                 R.output(y)
@@ -440,9 +442,9 @@ def test_stop_lifting():
     tvm.ir.assert_structural_equal(after, Expected)
 
 
-def test_symbolic_var():
+def test_symbolic_var_1():
     @tvm.script.ir_module
-    class Before1:
+    class Before:
         @R.function
         def main(shape: R.Shape(["n"])):
             R.func_attr({"num_input": 1})
@@ -452,7 +454,7 @@ def test_symbolic_var():
             return shape
 
     @I.ir_module
-    class Expected1:
+    class Expected:
         @R.function
         def main_transform_params(params: R.Tuple) -> R.Tuple:
             with R.dataflow():
@@ -461,15 +463,22 @@ def test_symbolic_var():
             return gv
 
         @R.function
-        def main(shape: R.Shape(["n"]), params: R.Tuple) -> R.Shape(["n"]):
+        def main(shape: R.Shape(["n"])) -> R.Shape(["n"]):
+            R.func_attr({"num_input": 1})
             n = T.int64()
             with R.dataflow():
                 zeros: R.Tensor((n, n), dtype="float32") = R.zeros(R.shape([n, 
n]), dtype="float32")
                 R.output()
             return shape
 
+    mod = Before
+    after = relax.transform.LiftTransformParams()(mod)
+    tvm.ir.assert_structural_equal(after, Expected)
+
+
+def test_symbolic_var_2():
     @I.ir_module
-    class Before2:
+    class Before:
         @T.prim_func
         def zeros(var_T_full: T.handle):
             T.func_attr({"tir.noalias": T.bool(True)})
@@ -484,9 +493,9 @@ def test_symbolic_var():
 
         @R.function
         def main(shape: R.Shape(["n"])) -> R.Shape(["n"]):
-            n = T.int64()
             R.func_attr({"num_input": 1})
-            cls = Before2
+            n = T.int64()
+            cls = Before
             with R.dataflow():
                 zeros = R.call_tir(
                     cls.zeros, R.tuple(), out_sinfo=R.Tensor((n, n), 
dtype="float32")
@@ -495,7 +504,7 @@ def test_symbolic_var():
             return shape
 
     @I.ir_module
-    class Expected2:
+    class Expected:
         @T.prim_func
         def zeros(var_T_full: T.handle):
             T.func_attr({"tir.noalias": T.bool(True)})
@@ -517,9 +526,10 @@ def test_symbolic_var():
             return gv
 
         @R.function
-        def main(shape: R.Shape(["n"]), params: R.Tuple) -> R.Shape(["n"]):
+        def main(shape: R.Shape(["n"])) -> R.Shape(["n"]):
+            R.func_attr({"num_input": 1})
             n = T.int64()
-            cls = Expected2
+            cls = Expected
             with R.dataflow():
                 zeros = R.call_tir(
                     cls.zeros, R.tuple(), out_sinfo=R.Tensor((n, n), 
dtype="float32")
@@ -527,13 +537,9 @@ def test_symbolic_var():
                 R.output()
             return shape
 
-    mod = Before1
-    after = relax.transform.LiftTransformParams()(mod)
-    tvm.ir.assert_structural_equal(after, Expected1)
-
-    mod = Before2
+    mod = Before
     after = relax.transform.LiftTransformParams()(mod)
-    tvm.ir.assert_structural_equal(after, Expected2)
+    tvm.ir.assert_structural_equal(after, Expected)
 
 
 if __name__ == "__main__":

Reply via email to