This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new daf9c202ac [Unity][IR][UX] Privacy annotation in Relax (#15140)
daf9c202ac is described below

commit daf9c202ac5ad0b0a64043b486894b139e64e4de
Author: Steven S. Lyubomirsky <[email protected]>
AuthorDate: Tue Jun 27 09:31:44 2023 -0400

    [Unity][IR][UX] Privacy annotation in Relax (#15140)
    
    This PR implements the privacy annotation proposal. Namely, the @R.function
    decorator now has an optional `private` attribute. If a function is marked
    as private, it will not have a global symbol attached to it and thus will
    not be externally accessible. By default, functions are not private, so the
    parser inserts a global symbol for them.
---
 include/tvm/script/ir_builder/relax/frame.h        |   2 +
 include/tvm/script/ir_builder/relax/ir.h           |   3 +-
 python/tvm/relax/block_builder.py                  |  12 ++
 python/tvm/relax/frontend/torch/dynamo.py          |   3 +-
 python/tvm/relax/training/setup_trainer.py         |   5 +-
 python/tvm/script/ir_builder/relax/ir.py           |   9 +-
 python/tvm/script/parser/core/parser.py            |   2 +
 python/tvm/script/parser/relax/entry.py            |   6 +-
 python/tvm/script/parser/relax/parser.py           |  27 ++--
 src/relax/training/utils.cc                        |   4 +-
 src/relax/transform/gradient.cc                    |   8 +-
 src/relax/transform/lift_transform_params.cc       |   5 +-
 src/script/ir_builder/relax/frame.cc               |   5 +
 src/script/ir_builder/relax/ir.cc                  |   3 +-
 src/script/printer/relax/function.cc               |  69 ++++++++-
 tests/python/relax/test_dataflow_pattern.py        |  12 +-
 tests/python/relax/test_training_loss.py           |  10 ++
 tests/python/relax/test_training_optimizer.py      |   8 ++
 .../relax/test_training_optimizer_numeric.py       |   2 +-
 .../relax/test_transform_attach_global_symbol.py   |   6 +-
 .../test_transform_combine_parallel_matmul.py      |  23 +--
 .../relax/test_transform_dead_code_elimination.py  |   8 +-
 tests/python/relax/test_transform_fuse_ops.py      |  52 ++++---
 .../relax/test_transform_fuse_ops_by_pattern.py    |  29 ++--
 tests/python/relax/test_transform_fuse_tir.py      |   6 +-
 tests/python/relax/test_transform_lambda_lift.py   |  14 +-
 .../test_transform_merge_composite_functions.py    |  44 +++---
 tests/python/relax/test_transform_normalize.py     |  20 +--
 .../relax/test_transform_rewrite_cuda_graph.py     |   8 +-
 tests/python/relax/test_tvmscript_parser.py        |  27 +++-
 .../relax/test_tvmscript_parser_op_datatype.py     |   2 +-
 tests/python/relax/test_tvmscript_printer_relax.py | 159 ++++++++++++++++++++-
 tests/python/relax/test_utils.py                   |   8 +-
 33 files changed, 456 insertions(+), 145 deletions(-)

diff --git a/include/tvm/script/ir_builder/relax/frame.h 
b/include/tvm/script/ir_builder/relax/frame.h
index 9a8f835e81..1ad6813889 100644
--- a/include/tvm/script/ir_builder/relax/frame.h
+++ b/include/tvm/script/ir_builder/relax/frame.h
@@ -99,6 +99,8 @@ class FunctionFrameNode : public SeqExprFrameNode {
   Optional<tvm::relax::StructInfo> ret_struct_info;
   /*! \brief Whether the function is annotated as pure */
   Optional<Bool> is_pure;
+  /*! \brief Whether the function is annotated as private */
+  Optional<Bool> is_private;
   /*! \brief The function attributes. */
   Map<String, ObjectRef> attrs;
   /*! \brief The block builder to create Relax function. */
diff --git a/include/tvm/script/ir_builder/relax/ir.h 
b/include/tvm/script/ir_builder/relax/ir.h
index 1cf30b4919..d160ad090e 100644
--- a/include/tvm/script/ir_builder/relax/ir.h
+++ b/include/tvm/script/ir_builder/relax/ir.h
@@ -34,9 +34,10 @@ namespace relax {
 /*!
  * \brief Start a function frame.
  * \param is_pure Whether the function is annotated as pure.
+ * \param is_private Whether the function is annotated as private.
  * \return The created ir_builder Function frame.
  */
-TVM_DLL FunctionFrame Function(const Bool& is_pure);
+TVM_DLL FunctionFrame Function(const Bool& is_pure, const Bool& is_private);
 
 /*!
  * \brief Add a parameter to the last function frame.
diff --git a/python/tvm/relax/block_builder.py 
b/python/tvm/relax/block_builder.py
index 80edb31efb..502073edf2 100644
--- a/python/tvm/relax/block_builder.py
+++ b/python/tvm/relax/block_builder.py
@@ -199,6 +199,7 @@ class BlockBuilder(Object):
         name: str,
         params: Optional[Union[Var, Tuple, List[Var]]] = None,
         attrs: Optional[Dict[str, Object]] = None,
+        private: bool = False,
     ) -> FunctionScope:
         """Annotate a Relax function.
 
@@ -215,6 +216,12 @@ class BlockBuilder(Object):
         attrs : Dict[str, Object], optional
             The function attrs
 
+        private : bool, optional
+            Whether the function is annotated as private.
+            If the function is private, it will not have a global symbol 
attribute.
+            If it is not private and not an inner function, then it will have
+            a global symbol attribute (mapped to the function's name).
+
         Returns
         -------
         ret: FunctionScope
@@ -233,6 +240,11 @@ class BlockBuilder(Object):
                     )
         if attrs is None:
             attrs = {}
+        # The block builder does not permit nesting functions, per the above 
comment,
+        # so no further check should be needed
+        if not private:
+            attrs["global_symbol"] = name
+
         return FunctionScope(self, name, params, attrs)
 
     def testing_scope(self, def_vars: List[tir.Var]) -> TestingScope:
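
As a hedged usage sketch of the new `private` parameter documented above
(the variable and function names are illustrative, not from this patch):

    from tvm import relax
    from tvm.script import relax as R

    bb = relax.BlockBuilder()

    # Private helper: no "global_symbol" attribute is attached.
    x = relax.Var("x", R.Tensor((10, 20), "float32"))
    with bb.function("helper", [x], private=True):
        gv = bb.emit(relax.op.add(x, x))
        bb.emit_func_output(gv)

    # Public function: the builder attaches global_symbol = "main".
    y = relax.Var("y", R.Tensor((10, 20), "float32"))
    with bb.function("main", [y]):
        gv = bb.emit(relax.op.add(y, y))
        bb.emit_func_output(gv)

    mod = bb.get()
    assert mod["main"].attrs["global_symbol"] == "main"
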
diff --git a/python/tvm/relax/frontend/torch/dynamo.py 
b/python/tvm/relax/frontend/torch/dynamo.py
index 3015f77428..abdf7b8862 100644
--- a/python/tvm/relax/frontend/torch/dynamo.py
+++ b/python/tvm/relax/frontend/torch/dynamo.py
@@ -154,7 +154,8 @@ def dynamo_capture_subgraphs(model, *params, **kwargs) -> 
tvm.IRModule:
             keep_params_as_input=keep_params_as_input,
             unwrap_unit_return_tuple=True,
         )
-        mod[f"subgraph_{len(mod.get_global_vars())}"] = mod_["main"]
+        new_name = f"subgraph_{len(mod.get_global_vars())}"
+        mod[new_name] = mod_["main"].with_attr("global_symbol", new_name)
         return graph_module.forward
 
     dynamo.reset()
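
A brief sketch of the migration pattern used here and in the transforms
below (module and names are illustrative, assumed for the example): since a
function's global symbol is no longer guaranteed to match the key it is
registered under, code that adds a function to a module under a new public
name now attaches the symbol explicitly.

    import tvm
    from tvm.script import ir as I, relax as R

    @I.ir_module
    class Mod:
        @R.function
        def main(x: R.Tensor((2,), "float32")) -> R.Tensor((2,), "float32"):
            return x

    target = tvm.IRModule()
    new_name = "subgraph_0"
    # Re-register under a new name and keep the symbol consistent with it.
    target[new_name] = Mod["main"].with_attr("global_symbol", new_name)
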
diff --git a/python/tvm/relax/training/setup_trainer.py 
b/python/tvm/relax/training/setup_trainer.py
index 81ecaf4ea5..2e20570869 100644
--- a/python/tvm/relax/training/setup_trainer.py
+++ b/python/tvm/relax/training/setup_trainer.py
@@ -198,7 +198,10 @@ class SetupTrainer:
 
         # Add optimizer function.
         self._optimizer.init(params)
-        mod[self.OPTIMIZER_FUNC] = self._optimizer.get_function()
+        # Need the global symbol to match the function's name
+        mod[self.OPTIMIZER_FUNC] = self._optimizer.get_function().with_attr(
+            "global_symbol", self.OPTIMIZER_FUNC
+        )
 
         # Module attrs
         mod = mod.with_attrs(
diff --git a/python/tvm/script/ir_builder/relax/ir.py 
b/python/tvm/script/ir_builder/relax/ir.py
index e54e4aa07b..b06d9547ac 100644
--- a/python/tvm/script/ir_builder/relax/ir.py
+++ b/python/tvm/script/ir_builder/relax/ir.py
@@ -165,19 +165,24 @@ py_str = str
 ############################### Function ################################
 
 
-def function(is_pure: bool = True) -> frame.FunctionFrame:
+def function(is_pure: bool = True, is_private: bool = False) -> 
frame.FunctionFrame:
     """Start a function frame.
     Parameters
     ----------
     is_pure: bool
         Whether the function is annotated as pure.
 
+    is_private : bool
+        Whether the function is annotated as private.
+
     Returns
     -------
     frame: FunctionFrame
         The constructed function frame.
     """
-    return _ffi_api.Function(is_pure)  # type: ignore[attr-defined] # pylint: 
disable=no-member
+    return _ffi_api.Function(  # type: ignore[attr-defined]  # pylint: 
disable=no-member
+        is_pure, is_private
+    )
 
 
 def arg(name: py_str, struct_info: StructInfo) -> Var:
diff --git a/python/tvm/script/parser/core/parser.py 
b/python/tvm/script/parser/core/parser.py
index 9275924466..69e262b1d3 100644
--- a/python/tvm/script/parser/core/parser.py
+++ b/python/tvm/script/parser/core/parser.py
@@ -240,6 +240,7 @@ class Parser(doc.NodeVisitor):
     dispatch_tokens: List[str]
     function_annotations: Optional[Dict[str, Dict[str, Any]]]
     var_table: VarTable
+    inside_function: bool  # whether we are within a function
 
     def __init__(
         self,
@@ -250,6 +251,7 @@ class Parser(doc.NodeVisitor):
         self.dispatch_tokens = ["default"]
         self.function_annotations = function_annotations
         self.var_table = VarTable()
+        self.inside_function = False
 
     def parse(self, extra_vars: Optional[Dict[str, Any]] = None) -> Any:
         """The main parse method for parser.
diff --git a/python/tvm/script/parser/relax/entry.py 
b/python/tvm/script/parser/relax/entry.py
index 2711e855dd..ff237a5600 100644
--- a/python/tvm/script/parser/relax/entry.py
+++ b/python/tvm/script/parser/relax/entry.py
@@ -45,9 +45,11 @@ FType = TypeVar("FType", bound=_Callable)
 # this formulation allows us to support having @R.function
 # appear as a decorator by itself or to have optional arguments
 # like @R.function(pure=False)
-def function(f: Optional[FType] = None, pure: bool = True) -> Union[Function, 
FType]:
+def function(
+    f: Optional[FType] = None, pure: bool = True, private: bool = False
+) -> Union[Function, FType]:
     # pylint: disable=unused-argument
-    # (pure isn't used here, but is used later in parsing)
+    # (pure and private aren't used here, but are used later in parsing)
 
     # need to inspect the stack first because is_defined_in_class expects the 
outer class
     # to be in a particular position in the stack
diff --git a/python/tvm/script/parser/relax/parser.py 
b/python/tvm/script/parser/relax/parser.py
index 427c56bcc8..863c249975 100644
--- a/python/tvm/script/parser/relax/parser.py
+++ b/python/tvm/script/parser/relax/parser.py
@@ -160,6 +160,9 @@ def collect_symbolic_var_from_params(self: Parser, node: 
doc.FunctionDef) -> Non
 
 @dispatch.register(token="relax", type_name="FunctionDef")
 def visit_function_def(self: Parser, node: doc.FunctionDef) -> None:
+    is_inner_function = self.inside_function
+    self.inside_function = True
+
     # reserve a var for local function
     func_val = self.var_table.get().get(node.name)
     if not func_val and is_recursive(node):
@@ -178,11 +181,14 @@ def visit_function_def(self: Parser, node: 
doc.FunctionDef) -> None:
         local_func_var = relax.Var(node.name, 
relax.FuncStructInfo(params_sinfo, ret_sinfo))
         self.var_table.add(node.name, local_func_var)
 
-    purity = find_purity_annotation(node)
+    purity = find_decorator_annotation(node, "pure")
+    # treat the function as private if we are inside another function
+    # or if it has a privacy annotation
+    privacy = is_inner_function or find_decorator_annotation(node, "private", 
default=False)
 
     with self.var_table.with_frame():
         with self.with_dispatch_token("relax"):
-            with R.function(is_pure=purity):
+            with R.function(is_pure=purity, is_private=privacy):
                 R.func_name(node.name)
                 collect_symbolic_var_from_params(self, node)
 
@@ -202,22 +208,22 @@ def visit_function_def(self: Parser, node: 
doc.FunctionDef) -> None:
                             self.report_error(stmt, "inline prim_func is 
disallowed in Relax IR")
 
                 self.visit_body(node.body)
+    self.inside_function = is_inner_function
 
 
-def find_purity_annotation(node: doc.FunctionDef) -> bool:
+def find_decorator_annotation(node: doc.FunctionDef, annotation: str, default: 
bool = True) -> bool:
     """
-    Check the value of `pure` in the function decorator.
-    Returns the annotated purity if present, otherwise defaulting to True.
-    This allows for specifying the purity in the function signature.
+    Check the value of the given annotation (argument name) in the function 
decorator.
+    Returns the value of the annotation if present, falling back to the 
default value otherwise.
     """
-    # look for the pure argument in the function decorator
+    # look for the named argument in the function decorator
     for dec in node.decorator_list:
         if not isinstance(dec, doc.Call) or dec.func.attr != "function":
             continue
         for keyword in dec.keywords:
-            if keyword.arg == "pure":
+            if keyword.arg == annotation:
                 return keyword.value.value
-    return True
+    return default
 
 
 @dispatch.register(token="relax", type_name="tvm_declare_function")
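
The generalized helper above reads a boolean keyword off the @R.function
decorator. A minimal standalone analogue using the stdlib `ast` module (the
real helper walks tvm.script's `doc` nodes, but the logic is the same):

    import ast

    def find_decorator_annotation(node: ast.FunctionDef, annotation: str, default: bool = True) -> bool:
        # Look for R.function(...) among the decorators and read the named keyword.
        for dec in node.decorator_list:
            if not isinstance(dec, ast.Call) or getattr(dec.func, "attr", None) != "function":
                continue
            for keyword in dec.keywords:
                if keyword.arg == annotation:
                    return keyword.value.value
        return default

    fn = ast.parse(
        "@R.function(pure=False, private=True)\n"
        "def f(x):\n"
        "    return x\n"
    ).body[0]
    assert find_decorator_annotation(fn, "pure") is False
    assert find_decorator_annotation(fn, "private", default=False) is True
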
@@ -238,7 +244,8 @@ def visit_tvm_declare_function(self: Parser, node: 
doc.FunctionDef) -> GlobalVar
             param_sinfo = eval_struct_info(self, arg.annotation, eval_str=True)
             params.append(relax.Var(arg.arg, param_sinfo))
 
-    is_pure = find_purity_annotation(node)
+    is_pure = find_decorator_annotation(node, "pure")
+
     func_signature = relax.Function.create_empty(params, ret_sinfo, 
is_pure=is_pure)
     return I.decl_function(node.name, func_signature)
 
diff --git a/src/relax/training/utils.cc b/src/relax/training/utils.cc
index 37582e3015..19faaad58b 100644
--- a/src/relax/training/utils.cc
+++ b/src/relax/training/utils.cc
@@ -48,7 +48,9 @@ class AppendLossMutator : private ExprMutator {
     Function new_loss_func = CopyWithNewVars(loss_function);
 
     AppendLossMutator mutator(mod, new_loss_func, num_backbone_outputs);
-    auto new_func_transformed = 
Downcast<Function>(mutator.VisitExpr(new_func));
+    auto new_func_transformed =
+        WithAttr(Downcast<Function>(mutator.VisitExpr(new_func)), 
tvm::attr::kGlobalSymbol,
+                 new_func_name.value_or(func_name + "_loss"));
 
     auto new_module = GetRef<IRModule>(mod.CopyOnWrite());
     auto new_var = GlobalVar(new_func_name.value_or(func_name + "_loss"));
diff --git a/src/relax/transform/gradient.cc b/src/relax/transform/gradient.cc
index 7645ae8cb6..2cda7a972d 100644
--- a/src/relax/transform/gradient.cc
+++ b/src/relax/transform/gradient.cc
@@ -333,10 +333,14 @@ class GradientMutator : private ExprMutator {
     }
 
     GradientMutator mutator(mod, require_grads_value, target_index);
-    Function new_func_transformed = 
Downcast<Function>(mutator.VisitExpr(new_func));
+
+    // make the adjoint public
+    auto new_name = func_name + "_adjoint";
+    Function new_func_transformed = 
WithAttr(Downcast<Function>(mutator.VisitExpr(new_func)),
+                                             tvm::attr::kGlobalSymbol, 
new_name);
 
     IRModule new_module = GetRef<IRModule>(mod.CopyOnWrite());
-    new_module->Add(GlobalVar(func_name + "_adjoint"), new_func_transformed);
+    new_module->Add(GlobalVar(new_name), new_func_transformed);
     return new_module;
   }
 
diff --git a/src/relax/transform/lift_transform_params.cc 
b/src/relax/transform/lift_transform_params.cc
index f7c9a4189d..fb1f292776 100644
--- a/src/relax/transform/lift_transform_params.cc
+++ b/src/relax/transform/lift_transform_params.cc
@@ -251,7 +251,10 @@ class TransformParamsLifter : public ExprMutator {
     lift_plan_ = planner.Plan(func, num_input);
 
     // Step 2: Add the lifted function to the module
-    builder_->AddFunction(lift_plan_.f_transform_params, new_func_name);
+    // (The lifted function should be public so we add a global symbol to it)
+    auto lift_func =
+        WithAttr(lift_plan_.f_transform_params, tvm::attr::kGlobalSymbol, 
new_func_name);
+    builder_->AddFunction(lift_func, new_func_name);
 
     // Step 3: Update the current function.
 
diff --git a/src/script/ir_builder/relax/frame.cc 
b/src/script/ir_builder/relax/frame.cc
index 00bbd2a551..966af809c9 100644
--- a/src/script/ir_builder/relax/frame.cc
+++ b/src/script/ir_builder/relax/frame.cc
@@ -56,6 +56,11 @@ void FunctionFrameNode::ExitWithScope() {
                              "`return` to return an Expr";
   this->block_builder->BeginScope(params);
   Expr body = 
this->block_builder->Normalize(tvm::relax::SeqExpr(binding_blocks, 
output.value()));
+  // if the function is not private, add a global symbol to its attributes
+  if (!is_private.value_or(Bool(false))->value && name.defined() &&
+      !attrs.count(tvm::attr::kGlobalSymbol)) {
+    attrs.Set(tvm::attr::kGlobalSymbol, name.value());
+  }
   auto dict_attrs = attrs.empty() ? NullValue<DictAttrs>() : DictAttrs(attrs);
   this->block_builder->EndScope();
   tvm::relax::Function func(/*params=*/params,
diff --git a/src/script/ir_builder/relax/ir.cc 
b/src/script/ir_builder/relax/ir.cc
index 52d9f0cfe1..d66e8d0598 100644
--- a/src/script/ir_builder/relax/ir.cc
+++ b/src/script/ir_builder/relax/ir.cc
@@ -52,7 +52,7 @@ TVM_STATIC_IR_FUNCTOR(Namer, vtable)
 
 /////////////////////////////// Function ////////////////////////////////
 
-FunctionFrame Function(const Bool& is_pure) {
+FunctionFrame Function(const Bool& is_pure, const Bool& is_private) {
   ObjectPtr<FunctionFrameNode> n = make_object<FunctionFrameNode>();
   const IRBuilder& ir_builder = IRBuilder::Current();
   Optional<tvm::IRModule> mod = NullOpt;
@@ -61,6 +61,7 @@ FunctionFrame Function(const Bool& is_pure) {
   }
   n->block_builder = tvm::relax::BlockBuilder::Create(/*mod=*/mod);
   n->is_pure = is_pure;
+  n->is_private = is_private;
   return FunctionFrame(n);
 }
 
diff --git a/src/script/printer/relax/function.cc 
b/src/script/printer/relax/function.cc
index bd5d969563..bc5f12309f 100644
--- a/src/script/printer/relax/function.cc
+++ b/src/script/printer/relax/function.cc
@@ -22,13 +22,36 @@ namespace tvm {
 namespace script {
 namespace printer {
 
+bool AtTopLevelFunction(const IRDocsifier& d) {
+  // fewer than 2 frames: not in a function at all
+  if (d->frames.size() < 2) {
+    return false;
+  }
+  // if the first frame is a RelaxFrame, then this is not inside a module.
+  // 2 frames => we are at a function (more than 2 => nested function)
+  if (d->frames[0]->IsInstance<RelaxFrameNode>()) {
+    return d->frames.size() == 2;
+  }
+  // otherwise the first two frames pertain to an IR module,
+  // so 3 frames => we are at a top-level function (more than 3 => nested 
function)
+  return d->frames.size() == 3;
+}
+
 TVM_REGISTER_NODE_TYPE(RelaxFrameNode);
 
 TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
     .set_dispatch<relax::Function>("", [](relax::Function n, ObjectPath n_p, 
IRDocsifier d) -> Doc {
       std::unordered_set<const tir::VarNode*> func_vars;
       With<RelaxFrame> f(d);
-      IdDoc func_name = d->Define(n, f(), FindFunctionName(d, 
n).value_or("main"));
+
+      IdDoc func_name("");
+      // if we are binding a local definition, then calling d->Define
+      // will result in a repeated definition and an incorrect displayed name
+      if (Optional<String> name = GetBindingName(d)) {
+        func_name = std::move(IdDoc(name.value()));
+      } else {
+        func_name = std::move(d->Define(n, f(), FindFunctionName(d, 
n).value_or("main")));
+      }
       (*f)->AddDispatchToken(d, "relax");
       (*f)->is_func = true;
       (*f)->func_vars = &func_vars;
@@ -52,17 +75,49 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
       (*f)->func_vars = nullptr;
       // Step 4. Print attributes
       if (n->attrs.defined() && !n->attrs->dict.empty()) {
-        (*f)->stmts.push_back(
-            ExprStmtDoc(Relax(d, "func_attr")  //
-                            ->Call({d->AsDoc<ExprDoc>(n->attrs, 
n_p->Attr("attrs"))})));
+        // If the function is a global function and has a global symbol,
+        // then don't print the global symbol (it will be implicit from not 
being private).
+        // For a function without an IR module whose global symbol
+        // doesn't match the function name, we should still print the global 
symbol attribute.
+        if (AtTopLevelFunction(d) && 
n->attrs->dict.count(tvm::attr::kGlobalSymbol) &&
+            Downcast<String>(n->attrs->dict.at(tvm::attr::kGlobalSymbol)) == 
func_name->name) {
+          Map<String, ObjectRef> new_attrs;
+          for (auto kv : n->attrs->dict) {
+            if (kv.first != tvm::attr::kGlobalSymbol) {
+              new_attrs.Set(kv.first, kv.second);
+            }
+          }
+          if (!new_attrs.empty()) {
+            (*f)->stmts.push_back(ExprStmtDoc(
+                Relax(d, "func_attr")  //
+                    ->Call({d->AsDoc<ExprDoc>(DictAttrs(new_attrs), 
n_p->Attr("attrs"))})));
+          }
+        } else {
+          (*f)->stmts.push_back(
+              ExprStmtDoc(Relax(d, "func_attr")  //
+                              ->Call({d->AsDoc<ExprDoc>(n->attrs, 
n_p->Attr("attrs"))})));
+        }
       }
       // Step 5. Prepare the decorator (include purity if it's impure)
       ExprDoc decorator = Relax(d, "function");
+      Array<ExprDoc, void> pos_args = {};
+      Array<String, void> dec_keys;
+      Array<ExprDoc, void> dec_values;
       if (!n->is_pure) {
-        Array<ExprDoc> pos_args = {};
-        decorator = std::move(decorator->Call(
-            pos_args, {"pure"}, {LiteralDoc::Boolean(false, 
Optional<ObjectPath>())}));
+        dec_keys.push_back("pure");
+        dec_values.push_back(LiteralDoc::Boolean(false, 
Optional<ObjectPath>()));
       }
+      // if the function is global or is not in a module and does not have a 
global symbol,
+      // indicate that it's private
+      if (AtTopLevelFunction(d) &&
+          (!n->attrs.defined() || 
!n->attrs->dict.count(tvm::attr::kGlobalSymbol))) {
+        dec_keys.push_back("private");
+        dec_values.push_back(LiteralDoc::Boolean(true, 
Optional<ObjectPath>()));
+      }
+      if (dec_keys.size()) {
+        decorator = std::move(decorator->Call(pos_args, dec_keys, dec_values));
+      }
+
       // Step 6. Print body
       Array<StmtDoc> body =
           PrintSeqExpr(Downcast<relax::SeqExpr>(n->body), n_p->Attr("body"), 
d, /*use_ret=*/true);
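
A sketch of the resulting round-trip behavior (function names are
illustrative): at top level, a global symbol that matches the function's
name is left implicit, and its absence is rendered as private=True.

    from tvm.script import ir as I, relax as R

    @I.ir_module
    class Mod:
        @R.function
        def public_fn(x: R.Tensor((2,), "float32")) -> R.Tensor((2,), "float32"):
            return x

        @R.function(private=True)
        def private_fn(x: R.Tensor((2,), "float32")) -> R.Tensor((2,), "float32"):
            return x

    # public_fn prints with no R.func_attr({"global_symbol": ...}) line,
    # while private_fn prints with private=True in its decorator.
    print(Mod.script())
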
diff --git a/tests/python/relax/test_dataflow_pattern.py 
b/tests/python/relax/test_dataflow_pattern.py
index cbb19c6743..ea83807bf8 100644
--- a/tests/python/relax/test_dataflow_pattern.py
+++ b/tests/python/relax/test_dataflow_pattern.py
@@ -875,7 +875,7 @@ def test_rewrite_simple():
         return R.multiply(matchings[x], R.const(2, "float32"))
 
     rewritten = rewrite_call(pattern, rewriter, main)
-    tvm.ir.assert_structural_equal(rewritten, expected1)
+    tvm.ir.assert_structural_equal(rewritten, 
expected1.with_attr("global_symbol", "main"))
 
     add1 = is_op("relax.add")(x, x)
     pattern = is_op("relax.add")(add1, add1)
@@ -884,7 +884,7 @@ def test_rewrite_simple():
         return R.multiply(matchings[x], R.const(4, "float32"))
 
     rewritten = rewrite_call(pattern, rewriter, main)
-    tvm.ir.assert_structural_equal(rewritten, expected2)
+    tvm.ir.assert_structural_equal(rewritten, 
expected2.with_attr("global_symbol", "main"))
 
     # No rewriting, return the original call node as is
     def rewriter(orig, _):
@@ -959,7 +959,7 @@ def test_rewrite_attention():
         return R.nn.attention(matchings[Q], matchings[K], matchings[V])
 
     rewritten = rewrite_call(pattern, rewriter, main)
-    tvm.ir.assert_structural_equal(rewritten, expected)
+    tvm.ir.assert_structural_equal(rewritten, 
expected.with_attr("global_symbol", "main"))
 
 
 def test_attention_qkv():
@@ -1115,7 +1115,7 @@ def test_combine_matmul_twice():
             inp_pat, Q_weight_pat, K_weight_pat, V_weight_pat, matmul1, 
matmul2, matmul3
         )
         rewritten = rewrite_bindings(ctx, rewriter, qkv_x2)
-        tvm.ir.assert_structural_equal(rewritten, expected)
+        tvm.ir.assert_structural_equal(rewritten, 
expected.with_attr("global_symbol", "qkv_x2"))
 
 
 def test_combine_matmul_emit_order():
@@ -1173,7 +1173,7 @@ def test_combine_matmul_emit_order():
             inp_pat, Q_weight_pat, K_weight_pat, V_weight_pat, matmul1, 
matmul2, matmul3
         )
         rewritten = rewrite_bindings(ctx, rewriter, main)
-        tvm.ir.assert_structural_equal(rewritten, expected)
+        tvm.ir.assert_structural_equal(rewritten, 
expected.with_attr("global_symbol", "main"))
 
         # make sure it builds
         mod = tvm.IRModule()
@@ -1272,7 +1272,7 @@ def test_combine_transposed_matmul_twice():
 
         rewritten = rewrite_bindings(ctx, rewriter, main)
         print(rewritten.script())
-        tvm.ir.assert_structural_equal(rewritten, expected)
+        tvm.ir.assert_structural_equal(rewritten, 
expected.with_attr("global_symbol", "main"))
 
         # make sure it builds
         mod = tvm.IRModule()
diff --git a/tests/python/relax/test_training_loss.py 
b/tests/python/relax/test_training_loss.py
index 0a2418aad7..0d456ceb38 100644
--- a/tests/python/relax/test_training_loss.py
+++ b/tests/python/relax/test_training_loss.py
@@ -46,6 +46,7 @@ def test_l1_loss():
     def expected(
         predictions: R.Tensor((3, 5), "float32"), targets: R.Tensor((3, 5), 
"float32")
     ) -> R.Tensor((), "float32"):
+        R.func_attr({"global_symbol": "l1_loss"})
         with R.dataflow():
             lv: R.Tensor((3, 5), "float32") = R.subtract(predictions, targets)
             lv1: R.Tensor((3, 5), "float32") = R.abs(lv)
@@ -70,6 +71,7 @@ def test_l1_loss_append():
         b: R.Tensor((2, 4), "float32"),
         targets: R.Tensor((2, 4), "float32"),
     ) -> R.Tensor((), "float32"):
+        R.func_attr({"global_symbol": "forward_loss"})
         with R.dataflow():
             lv: R.Tensor((2, 4), "float32") = R.matmul(x, w, out_dtype="")
             out: R.Tensor((2, 4), "float32") = R.add(lv, b)
@@ -93,6 +95,7 @@ def test_mse_loss():
     def expected(
         predictions: R.Tensor((3, 5), "float32"), targets: R.Tensor((3, 5), 
"float32")
     ) -> R.Tensor((), "float32"):
+        R.func_attr({"global_symbol": "mse_loss"})
         with R.dataflow():
             lv: R.Tensor((3, 5), "float32") = R.subtract(predictions, targets)
             lv1: R.Tensor((3, 5), "float32") = R.multiply(lv, lv)
@@ -117,6 +120,7 @@ def test_mse_loss_append():
         b: R.Tensor((2, 4), "float32"),
         targets: R.Tensor((2, 4), "float32"),
     ) -> R.Tensor((), "float32"):
+        R.func_attr({"global_symbol": "forward_loss"})
         with R.dataflow():
             lv: R.Tensor((2, 4), "float32") = R.matmul(x, w, out_dtype="")
             out: R.Tensor((2, 4), "float32") = R.add(lv, b)
@@ -143,6 +147,7 @@ def test_cross_entropy_loss():
         targets: R.Tensor((3,), "int64"),
         weights: R.Tensor((5,), "float32"),
     ) -> R.Tensor((), "float32"):
+        R.func_attr({"global_symbol": "cross_entropy_loss"})
         with R.dataflow():
             lv: R.Tensor((3, 5), "float32") = R.nn.log_softmax(predictions, 
axis=-1)
             gv: R.Tensor((), "float32") = R.nn.nll_loss(
@@ -165,6 +170,7 @@ def test_cross_entropy_loss_without_weights():
     def expected(
         predictions: R.Tensor((3, 5), "float32"), targets: R.Tensor((3,), 
"int64")
     ) -> R.Tensor((), "float32"):
+        R.func_attr({"global_symbol": "cross_entropy_loss"})
         with R.dataflow():
             lv: R.Tensor((3, 5), "float32") = R.nn.log_softmax(predictions, 
axis=-1)
             gv: R.Tensor((), "float32") = R.nn.nll_loss(
@@ -195,6 +201,7 @@ def test_cross_entropy_loss_append():
         targets: R.Tensor((2,), "int64"),
         weights: R.Tensor((4,), "float32"),
     ) -> R.Tensor((), "float32"):
+        R.func_attr({"global_symbol": "forward_loss"})
         with R.dataflow():
             lv: R.Tensor((2, 4), "float32") = R.matmul(x, w, out_dtype="")
             out: R.Tensor((2, 4), "float32") = R.add(lv, b)
@@ -224,6 +231,7 @@ def test_categorical_cross_entropy_loss():
         targets: R.Tensor((3, 5), "int64"),
         weights: R.Tensor((5,), "float32"),
     ) -> R.Tensor((), "float32"):
+        R.func_attr({"global_symbol": "categorical_cross_entropy_loss"})
         with R.dataflow():
             lv: R.Tensor((3, 5), "float32") = R.nn.log_softmax(predictions, 
axis=-1)
             lv: R.Tensor((), "float32") = -lv * targets.astype("float32")
@@ -245,6 +253,7 @@ def test_categorical_cross_entropy_loss_without_weights():
     def expected(
         predictions: R.Tensor((3, 5), "float32"), targets: R.Tensor((3, 5), 
"int64")
     ) -> R.Tensor((), "float32"):
+        R.func_attr({"global_symbol": "categorical_cross_entropy_loss"})
         with R.dataflow():
             lv: R.Tensor((3, 5), "float32") = R.nn.log_softmax(predictions, 
axis=-1)
             gv: R.Tensor((), "float32") = R.mean(-lv * 
targets.astype("float32"))
@@ -270,6 +279,7 @@ def test_categorical_cross_entropy_loss_with_ignore_index():
         targets: R.Tensor((3, 5), "int64"),
         weights: R.Tensor((5,), "float32"),
     ) -> R.Tensor((), "float32"):
+        R.func_attr({"global_symbol": "categorical_cross_entropy_loss"})
         with R.dataflow():
             lv: R.Tensor((3, 5), "float32") = R.nn.log_softmax(predictions, 
axis=-1)
             targets = relax.op.reshape(
diff --git a/tests/python/relax/test_training_optimizer.py 
b/tests/python/relax/test_training_optimizer.py
index b2246087c6..514422da8d 100644
--- a/tests/python/relax/test_training_optimizer.py
+++ b/tests/python/relax/test_training_optimizer.py
@@ -67,6 +67,7 @@ def test_sgd_simple():
         R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3,), "float32")),
         R.Tuple(R.Tensor((), "int64")),
     ):
+        R.func_attr({"global_symbol": "SGD"})
         # block 0
         with R.dataflow():
             num_steps: R.Tensor((), "int64") = optim_states[0]
@@ -104,6 +105,7 @@ def test_sgd_complex():
         R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3,), "float32")),
         R.Tuple(R.Tensor((), "int64")),
     ):
+        R.func_attr({"global_symbol": "SGD"})
         with R.dataflow():
             num_steps: R.Tensor((), "int64") = optim_states[0]
             num_steps_new: R.Tensor((), "int64") = R.add(num_steps, R.const(1, 
"int64"))
@@ -146,6 +148,7 @@ def test_momentum_sgd_simple():
         R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3,), "float32")),
         R.Tuple(R.Tensor((), "int64"), R.Tensor((3, 3), "float32"), 
R.Tensor((3,), "float32")),
     ):
+        R.func_attr({"global_symbol": "MomentumSGD"})
         # block 0
         with R.dataflow():
             num_steps: R.Tensor((), "int64") = optim_states[0]
@@ -195,6 +198,7 @@ def test_momentum_sgd_complex():
         R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3,), "float32")),
         R.Tuple(R.Tensor((), "int64"), R.Tensor((3, 3), "float32"), 
R.Tensor((3,), "float32")),
     ):
+        R.func_attr({"global_symbol": "MomentumSGD"})
         # block 0
         with R.dataflow():
             num_steps: R.Tensor((), "int64") = optim_states[0]
@@ -250,6 +254,7 @@ def test_momentum_sgd_nesterov():
         R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3,), "float32")),
         R.Tuple(R.Tensor((), "int64"), R.Tensor((3, 3), "float32"), 
R.Tensor((3,), "float32")),
     ):
+        R.func_attr({"global_symbol": "MomentumSGD"})
         # block 0
         with R.dataflow():
             num_steps: R.Tensor((), "int64") = optim_states[0]
@@ -321,6 +326,7 @@ def test_adam_simple():
             R.Tensor((3,), "float32"),
         ),
     ):
+        R.func_attr({"global_symbol": "Adam"})
         # block 0
         with R.dataflow():
             num_steps: R.Tensor((), "int64") = optim_states[0]
@@ -418,6 +424,7 @@ def test_adam_complex():
             R.Tensor((3,), "float32"),
         ),
     ):
+        R.func_attr({"global_symbol": "Adam"})
         # block 0
         with R.dataflow():
             num_steps: R.Tensor((), "int64") = optim_states[0]
@@ -519,6 +526,7 @@ def test_adam_float64():
             R.Tensor((3,), "float64"),
         ),
     ):
+        R.func_attr({"global_symbol": "Adam"})
         # block 0
         with R.dataflow():
             num_steps: R.Tensor((), "int64") = optim_states[0]
diff --git a/tests/python/relax/test_training_optimizer_numeric.py 
b/tests/python/relax/test_training_optimizer_numeric.py
index 3b300e8261..23db8987f1 100644
--- a/tests/python/relax/test_training_optimizer_numeric.py
+++ b/tests/python/relax/test_training_optimizer_numeric.py
@@ -69,7 +69,7 @@ def _test_optimizer(target, dev, np_func, opt_type, *args, 
**kwargs):
     x = relax.Var("x", R.Tensor((3, 3), "float32"))
     y = relax.Var("y", R.Tensor((3,), "float32"))
     opt = opt_type(*args, **kwargs).init([x, y])
-    mod = IRModule.from_expr(opt.get_function())
+    mod = IRModule.from_expr(opt.get_function().with_attr("global_symbol", 
"main"))
     tvm_func = _legalize_and_build(mod, target, dev)["main"]
 
     param_arr = [np.random.rand(3, 3).astype(np.float32), 
np.random.rand(3).astype(np.float32)]
diff --git a/tests/python/relax/test_transform_attach_global_symbol.py 
b/tests/python/relax/test_transform_attach_global_symbol.py
index 035e21609d..680df96947 100644
--- a/tests/python/relax/test_transform_attach_global_symbol.py
+++ b/tests/python/relax/test_transform_attach_global_symbol.py
@@ -43,7 +43,7 @@ def test_basic():
                         C[vi, vj] = T.float32(0)
                     C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
 
-        @R.function
+        @R.function(private=True)
         def main(
             x: R.Tensor(("m", "n"), "float32"), w: R.Tensor(("n", "k"), 
"float32")
         ) -> R.Tensor:
@@ -74,7 +74,6 @@ def test_basic():
         def main(
             x: R.Tensor(("m", "n"), "float32"), w: R.Tensor(("n", "k"), 
"float32")
         ) -> R.Tensor:
-            R.func_attr({"global_symbol": "main"})
             m, n, k = T.int64(), T.int64(), T.int64()
             gv0 = R.call_tir(Expected.tir_matmul, (x, w), R.Tensor((m, k), 
dtype="float32"))
             return gv0
@@ -94,7 +93,7 @@ def test_system_lib_prefix():
         def tir_zeros(x: T.Buffer((2), "float32")) -> None:
             x[0] = T.float32(0)
 
-        @R.function
+        @R.function(private=True)
         def main() -> R.Tensor:
             gv0 = R.call_tir(Before.tir_zeros, (), R.Tensor((2,), 
dtype="float32"))
             return gv0
@@ -110,7 +109,6 @@ def test_system_lib_prefix():
 
         @R.function
         def main() -> R.Tensor:
-            R.func_attr({"global_symbol": "main"})
             gv0 = R.call_tir(Expected.tir_zeros, (), R.Tensor((2,), 
dtype="float32"))
             return gv0
 
diff --git a/tests/python/relax/test_transform_combine_parallel_matmul.py 
b/tests/python/relax/test_transform_combine_parallel_matmul.py
index 719daaf449..97211f0dd0 100644
--- a/tests/python/relax/test_transform_combine_parallel_matmul.py
+++ b/tests/python/relax/test_transform_combine_parallel_matmul.py
@@ -97,7 +97,7 @@ def test_simple():
             R.output(lv3)
         return lv3
 
-    tvm.ir.assert_structural_equal(mod["main"], expected1)
+    tvm.ir.assert_structural_equal(mod["main"], 
expected1.with_attr("global_symbol", "main"))
 
     # Test a batched LHS case, slicing is done on the axis 2
     mod = get_parallel_matmul(3, lhs_shape=(2, 1024, 640))
@@ -121,7 +121,7 @@ def test_simple():
             R.output(lv3)
         return lv3
 
-    tvm.ir.assert_structural_equal(mod["main"], expected2)
+    tvm.ir.assert_structural_equal(mod["main"], 
expected2.with_attr("global_symbol", "main"))
 
 
 def test_bias():
@@ -151,7 +151,7 @@ def test_bias():
             R.output(lv6)
         return lv6
 
-    tvm.ir.assert_structural_equal(mod["main"], expected1)
+    tvm.ir.assert_structural_equal(mod["main"], 
expected1.with_attr("global_symbol", "main"))
 
     mod = get_parallel_matmul(3, with_bias=[True, False, True])
     mod = CombineParallelMatmul()(mod)
@@ -178,7 +178,7 @@ def test_bias():
             R.output(lv5)
         return lv5
 
-    tvm.ir.assert_structural_equal(mod["main"], expected2)
+    tvm.ir.assert_structural_equal(mod["main"], 
expected2.with_attr("global_symbol", "main"))
 
 
 def test_activation():
@@ -204,7 +204,7 @@ def test_activation():
             R.output(lv6)
         return lv6
 
-    tvm.ir.assert_structural_equal(mod["main"], expected1)
+    tvm.ir.assert_structural_equal(mod["main"], 
expected1.with_attr("global_symbol", "main"))
 
     mod = get_parallel_matmul(3, activation=["gelu", "relu", "relu"])
     mod = CombineParallelMatmul()(mod)
@@ -230,7 +230,7 @@ def test_activation():
             R.output(lv6)
         return lv6
 
-    tvm.ir.assert_structural_equal(mod["main"], expected2)
+    tvm.ir.assert_structural_equal(mod["main"], 
expected2.with_attr("global_symbol", "main"))
 
     mod = get_parallel_matmul(3, activation=["relu", None, None])
     mod = CombineParallelMatmul()(mod)
@@ -255,7 +255,7 @@ def test_activation():
             R.output(lv4)
         return lv4
 
-    tvm.ir.assert_structural_equal(mod["main"], expected3)
+    tvm.ir.assert_structural_equal(mod["main"], 
expected3.with_attr("global_symbol", "main"))
 
 
 def test_bias_activation():
@@ -286,7 +286,7 @@ def test_bias_activation():
             R.output(lv9)
         return lv9
 
-    tvm.ir.assert_structural_equal(mod["main"], expected1)
+    tvm.ir.assert_structural_equal(mod["main"], 
expected1.with_attr("global_symbol", "main"))
 
     mod = get_parallel_matmul(3, with_bias=[True, True, True], 
activation=["relu", None, "relu"])
     mod = CombineParallelMatmul()(mod)
@@ -316,7 +316,7 @@ def test_bias_activation():
             R.output(lv8)
         return lv8
 
-    tvm.ir.assert_structural_equal(mod["main"], expected2)
+    tvm.ir.assert_structural_equal(mod["main"], 
expected2.with_attr("global_symbol", "main"))
 
     mod = get_parallel_matmul(3, with_bias=[True, False, True], 
activation=["relu", None, "relu"])
     mod = CombineParallelMatmul()(mod)
@@ -345,7 +345,7 @@ def test_bias_activation():
             R.output(lv7)
         return lv7
 
-    tvm.ir.assert_structural_equal(mod["main"], expected3)
+    tvm.ir.assert_structural_equal(mod["main"], 
expected3.with_attr("global_symbol", "main"))
 
 
 def test_rhs_batched():
@@ -378,6 +378,7 @@ def test_rhs_batched():
         w2: R.Tensor((2, 640, 640), dtype="float32"),
         w3: R.Tensor((3, 4, 640, 640), dtype="float32"),
     ) -> R.Tensor:
+        R.func_attr({"global_symbol": "main"})
         with R.dataflow():
             lv = R.concat((w0, w2), axis=2)
             lv1 = R.matmul(x, lv, out_dtype="float32")
@@ -458,6 +459,7 @@ def test_multiple_combine():
         b0: R.Tensor((640,), dtype="float32"),
         b1: R.Tensor((640,), dtype="float32"),
     ) -> R.Tensor:
+        R.func_attr({"global_symbol": "main"})
         with R.dataflow():
             lv = R.concat((w0, w1, w2), axis=1)
             lv1 = R.matmul(x1, lv, out_dtype="float32")
@@ -515,6 +517,7 @@ def test_check():
         w3: R.Tensor((640, 640), dtype="float32"),
         w4: R.Tensor((640, 640), dtype="float32"),
     ) -> R.Tensor:
+        R.func_attr({"global_symbol": "main"})
         with R.dataflow():
             lv = R.concat((w0, w1, w2), axis=1)
             lv1 = R.matmul(x1, lv, out_dtype="float32")
diff --git a/tests/python/relax/test_transform_dead_code_elimination.py 
b/tests/python/relax/test_transform_dead_code_elimination.py
index 9c6e0e0567..12a3de6acb 100644
--- a/tests/python/relax/test_transform_dead_code_elimination.py
+++ b/tests/python/relax/test_transform_dead_code_elimination.py
@@ -168,7 +168,7 @@ def test_unused_relax_func():
                     vi, vj = T.axis.remap("SS", [i, j])
                     z[vi, vj] = x[vi, vj] + y[vi, vj]
 
-        @R.function
+        @R.function(private=True)
         def unused_func(x: R.Tensor((16, 16), "float32"), w: R.Tensor((16, 
16), "float32")):
             gv0 = R.add(x, w)
             return gv0
@@ -202,7 +202,7 @@ def test_unused_relax_func_custom_entry_func():
                     vi, vj = T.axis.remap("SS", [i, j])
                     z[vi, vj] = x[vi, vj] + y[vi, vj]
 
-        @R.function
+        @R.function(private=True)
         def unused_func(x: R.Tensor((16, 16), "float32"), w: R.Tensor((16, 
16), "float32")):
             gv0 = R.add(x, w)
             return gv0
@@ -239,7 +239,7 @@ def test_unused_relax_func_symbolic_shape():
                     vi, vj = T.axis.remap("SS", [i, j])
                     z[vi, vj] = x[vi, vj] + y[vi, vj]
 
-        @R.function
+        @R.function(private=True)
         def unused_func(x: R.Tensor(("m", "n"), "float32"), w: R.Tensor(("n", 
"k"), "float32")):
             gv0 = R.add(x, w)
             return gv0
@@ -310,7 +310,7 @@ def test_multiple_unused_funcs():
                     vi, vj = T.axis.remap("SS", [i, j])
                     z[vi, vj] = x[vi, vj] + y[vi, vj]
 
-        @R.function
+        @R.function(private=True)
         def unused_func2(x: R.Tensor((16, 16), "float32"), w: R.Tensor((16, 
16), "float32")):
             gv0 = R.add(x, w)
             return gv0
diff --git a/tests/python/relax/test_transform_fuse_ops.py 
b/tests/python/relax/test_transform_fuse_ops.py
index 14c3dbe713..b51f651025 100644
--- a/tests/python/relax/test_transform_fuse_ops.py
+++ b/tests/python/relax/test_transform_fuse_ops.py
@@ -48,7 +48,7 @@ def test_fuse_simple():
         x = relax.Var("x", R.Tensor([10, 20], "float32"))
         p0 = relax.Var("p0", R.Tensor((), "float32"))
 
-        with bb.function("fused_add_exp_squeeze", [x, p0], attrs={"Primitive": 
1}):
+        with bb.function("fused_add_exp_squeeze", [x, p0], attrs={"Primitive": 
1}, private=True):
             with bb.dataflow():
                 lv0 = bb.emit_te(topi.add, x, p0)
                 lv1 = bb.emit_te(topi.exp, lv0)
@@ -100,7 +100,9 @@ def test_conv2d_fuse():
         x = relax.Var("x", R.Tensor((1, 16, 64, 64), dtype))
         w = relax.Var("w", R.Tensor((16, 16, 3, 3), dtype))
         p0 = relax.Var("p0", R.Tensor((), dtype))
-        with bb.function("fused_conv2d_add1_add2", [x, w, p0], 
attrs={"Primitive": 1}):
+        with bb.function(
+            "fused_conv2d_add1_add2", [x, w, p0], attrs={"Primitive": 1}, 
private=True
+        ):
             with bb.dataflow():
                 lv0 = bb.emit_te(
                     topi.nn.conv2d,
@@ -119,7 +121,7 @@ def test_conv2d_fuse():
         x = relax.Var("x", R.Tensor((1, 16, 64, 64), dtype))
         w = relax.Var("w", R.Tensor((16, 16, 1, 1), dtype))
         y = relax.Var("y", R.Tensor((1, 16, 64, 64), dtype))
-        with bb.function("fused_conv2d1_add2", [x, w, y], attrs={"Primitive": 
1}):
+        with bb.function("fused_conv2d1_add2", [x, w, y], attrs={"Primitive": 
1}, private=True):
             with bb.dataflow():
                 lv0 = bb.emit_te(
                     topi.nn.conv2d,
@@ -196,7 +198,9 @@ def test_concatenate():
         x = relax.Var("x", R.Tensor((1, 16, 64, 64), "float32"))
         w = relax.Var("w", R.Tensor((1, 16, 32, 32), "float32"))
         p0 = relax.Var("p0", R.Tensor((), "float32"))
-        with bb.function("fused_upsampling_concatenate_add", [w, x, p0], 
attrs={"Primitive": 1}):
+        with bb.function(
+            "fused_upsampling_concatenate_add", [w, x, p0], 
attrs={"Primitive": 1}, private=True
+        ):
             with bb.dataflow():
                 lv0 = bb.emit_te(topi.nn.upsampling, w, scale_h=2.0, 
scale_w=2.0)
                 lv1 = bb.emit_te(topi.concatenate, (lv0, x), axis=1)
@@ -287,7 +291,10 @@ def test_fuse_tuple_get_elemwise():
         # Grouped function
         dense = relax.Var("dense", R.Tensor((1, 3 * dim), "float32"))
         with bb.function(
-            "fused_split_sigmoid_tanh_exp_multiply_add", [dense], 
attrs={"Primitive": 1}
+            "fused_split_sigmoid_tanh_exp_multiply_add",
+            [dense],
+            attrs={"Primitive": 1},
+            private=True,
         ):
             with bb.dataflow():
                 lv0 = bb.emit_te(topi.split, dense, indices_or_sections=3, 
axis=1)
@@ -340,7 +347,7 @@ def test_tuple_get_root():
 
         # Grouped function
         x = relax.Var("x", R.Tensor((1, 3 * dim), "float32"))
-        with bb.function("fused_split", [x], attrs={"Primitive": 1}):
+        with bb.function("fused_split", [x], attrs={"Primitive": 1}, 
private=True):
             with bb.dataflow():
                 lv0 = bb.emit_te(topi.split, x, indices_or_sections=3, axis=1)
                 gv = bb.emit_output(relax.TupleGetItem(lv0, 0))
@@ -398,6 +405,7 @@ def test_tuple_intermediate():
             "fused_squeeze_add_squeeze1_add_add_add_concatenate_squeeze2_add1",
             [x, p0, p1, p2, p3, p4],
             attrs={"Primitive": 1},
+            private=True,
         ):
             with bb.dataflow():
                 lv0 = bb.emit_te(topi.squeeze, x)
@@ -500,6 +508,7 @@ def test_tuple_consecutive():
             
"fused_add_add_add_concatenate_add1_add_add_add_concatenate_add1_add_add_add_concatenate_add1_concatenate1",
             [x, p0, p1, p2, p3, p4, p5, p6, p7, p8, p9, p10, p11],
             attrs={"Primitive": 1},
+            private=True,
         ):
             with bb.dataflow():
                 lv0 = bb.emit_te(topi.add, x, p0)
@@ -523,7 +532,7 @@ def test_tuple_consecutive():
         # Grouped function 2
         concat = relax.Var("concat", R.Tensor((1, 144, 64, 64), "float32"))
         p0 = relax.Var("p0", R.Tensor((), "float32"))
-        with bb.function("fused_pool2d_add2", [concat, p0], 
attrs={"Primitive": 1}):
+        with bb.function("fused_pool2d_add2", [concat, p0], 
attrs={"Primitive": 1}, private=True):
             with bb.dataflow():
                 lv0 = bb.emit_te(
                     topi.nn.pool2d,
@@ -609,7 +618,7 @@ def test_inception_like():
         # Grouped function 1
         x = relax.Var("x", R.Tensor((1, 16, 64, 64), "float32"))
         w = relax.Var("w", R.Tensor((16, 16, 3, 3), "float32"))
-        with bb.function("fused_conv2d_relu", [x, w], attrs={"Primitive": 1}):
+        with bb.function("fused_conv2d_relu", [x, w], attrs={"Primitive": 1}, 
private=True):
             with bb.dataflow():
                 lv0 = bb.emit_te(
                     topi.nn.conv2d,
@@ -626,7 +635,7 @@ def test_inception_like():
         # Grouped function 2
         x = relax.Var("x", R.Tensor((1, 32, 64, 64), "float32"))
         w = relax.Var("w", R.Tensor((16, 32, 3, 3), "float32"))
-        with bb.function("fused_conv2d1_relu", [x, w], attrs={"Primitive": 1}):
+        with bb.function("fused_conv2d1_relu", [x, w], attrs={"Primitive": 1}, 
private=True):
             with bb.dataflow():
                 lv0 = bb.emit_te(
                     topi.nn.conv2d,
@@ -689,7 +698,10 @@ def test_fuse_parallel_injective():
         x = relax.Var("x", R.Tensor((10, 20), "int32"))
         p0 = relax.Var("p0", R.Tensor((), "int32"))
         with bb.function(
-            "fused_add_squeeze_transpose_transpose1_left_shift", [x, p0], 
attrs={"Primitive": 1}
+            "fused_add_squeeze_transpose_transpose1_left_shift",
+            [x, p0],
+            attrs={"Primitive": 1},
+            private=True,
         ):
             with bb.dataflow():
                 lv0 = bb.emit_te(topi.add, x, p0)
@@ -734,7 +746,7 @@ def test_softmax():
 
         # Grouped function
         x = relax.Var("x", R.Tensor((16, 16), "float32"))
-        with bb.function("fused_softmax_cast", [x], attrs={"Primitive": 1}):
+        with bb.function("fused_softmax_cast", [x], attrs={"Primitive": 1}, 
private=True):
             with bb.dataflow():
                 lv0 = bb.emit_te(topi.nn.softmax, x)
                 gv = bb.emit_output(bb.call_te(topi.cast, lv0, 
dtype="float16"))
@@ -781,7 +793,7 @@ def test_multiple_relax_functions():
 
         x = relax.Var("x", R.Tensor([10, 20], "float32"))
         p0 = relax.Var("p0", R.Tensor((), "float32"))
-        with bb.function("fused_add_exp_squeeze", [x, p0], attrs={"Primitive": 
1}):
+        with bb.function("fused_add_exp_squeeze", [x, p0], attrs={"Primitive": 
1}, private=True):
             with bb.dataflow():
                 lv0 = bb.emit_te(topi.add, x, p0)
                 lv1 = bb.emit_te(topi.exp, lv0)
@@ -791,7 +803,7 @@ def test_multiple_relax_functions():
 
         x = relax.Var("x", R.Tensor([20, 10], "float32"))
         p0 = relax.Var("p0", R.Tensor((), "float32"))
-        with bb.function("fused_add1_exp1_squeeze1", [x, p0], 
attrs={"Primitive": 1}):
+        with bb.function("fused_add1_exp1_squeeze1", [x, p0], 
attrs={"Primitive": 1}, private=True):
             with bb.dataflow():
                 lv0 = bb.emit_te(topi.add, x, p0)
                 lv1 = bb.emit_te(topi.exp, lv0)
@@ -938,7 +950,7 @@ def test_layer_norm_silu():
                     T.writes(B[v_i0, v_i1, v_i2, v_i3])
                     B[v_i0, v_i1, v_i2, v_i3] = T.max(A[v_i0, v_i1, v_i2, 
v_i3], T.float32(0))
 
-        @R.function
+        @R.function(private=True)
         def fused_layer_norm_relu(x: R.Tensor((1, 512, 64, 64), 
dtype="float32"), mean: R.Tensor((64, 64), dtype="float32"), var: R.Tensor((64, 
64), dtype="float32")) -> R.Tensor((1, 512, 64, 64), dtype="float32"):
             R.func_attr({"Primitive": 1})
             cls = Expected
@@ -1080,7 +1092,7 @@ def test_multiple_paths():
                     T.writes(T_transpose[v_ax0, v_ax1])
                     T_transpose[v_ax0, v_ax1] = rxplaceholder[v_ax1, v_ax0]
 
-        @R.function
+        @R.function(private=True)
         def fused_conv2d_add_add2(inp_0: R.Tensor((2, 320, 64, 64), 
dtype="float32"), w1: R.Tensor((320, 320, 3, 3), dtype="float32"), lv28: 
R.Tensor((1, 320, 1, 1), dtype="float32"), lv35: R.Tensor((2, 320, 1, 1), 
dtype="float32")) -> R.Tensor((2, 320, 64, 64), dtype="float32"):
             R.func_attr({"Primitive": 1})
             cls = Expected
@@ -1091,7 +1103,7 @@ def test_multiple_paths():
                 R.output(gv)
             return gv
 
-        @R.function
+        @R.function(private=True)
         def fused_matmul_add1(inp_1: R.Tensor((2, 1280), dtype="float32"), 
lv31: R.Tensor((1280, 320), dtype="float32"), b2: R.Tensor((320,), 
dtype="float32")) -> R.Tensor((2, 320), dtype="float32"):
             cls = Expected
             R.func_attr({"Primitive": 1})
@@ -1225,7 +1237,7 @@ def test_dead_group():
                     T.writes(T_transpose[v_ax0, v_ax1])
                     T_transpose[v_ax0, v_ax1] = rxplaceholder[v_ax1, v_ax0]
 
-        @R.function
+        @R.function(private=True)
         def fused_matmul1_add1(inp_1: R.Tensor((1, 128), dtype="float32"), 
lv4: R.Tensor((128, 10), dtype="float32"), linear2_bias: R.Tensor((10,), 
dtype="float32")) -> R.Tensor((1, 10), dtype="float32"):
             R.func_attr({"Primitive": 1})
             cls = Expected
@@ -1267,7 +1279,7 @@ def test_symbolic_shape_aware_fuse():
 
     @I.ir_module
     class Expected:
-        @R.function
+        @R.function(private=True)
         def fused_add_exp_squeeze(
             x: R.Tensor(["n", "m"], "float32"), p0: R.Tensor([], "float32")
         ) -> R.Tensor(["n", "m"], dtype="float32"):
@@ -1305,7 +1317,7 @@ def test_symbolic_shape_aware_fuse_2():
 
     @I.ir_module
     class Expected:
-        @R.function
+        @R.function(private=True)
         def fused_full_trilu_broadcast_to(
             s: R.Shape(["n"]),
         ) -> R.Tensor([1, 1, "n", "n"], "float32"):
@@ -1353,7 +1365,7 @@ def test_shape_expr_arg():
 
     @I.ir_module
     class Expected:
-        @R.function
+        @R.function(private=True)
         def fused_full_trilu_broadcast_to(
             s: R.Shape(["n"]),
         ) -> R.Tensor([1, 1, "n", "n"], "float32"):
diff --git a/tests/python/relax/test_transform_fuse_ops_by_pattern.py 
b/tests/python/relax/test_transform_fuse_ops_by_pattern.py
index 5fb2b3332c..592132516b 100644
--- a/tests/python/relax/test_transform_fuse_ops_by_pattern.py
+++ b/tests/python/relax/test_transform_fuse_ops_by_pattern.py
@@ -60,7 +60,7 @@ class Conv2dReLU_composite_annotated:
             R.output(gv)
         return gv
 
-    @R.function
+    @R.function(private=True)
     def fused_relax_nn_conv2d_relax_nn_relu_dnnl(
         data1: R.Tensor((1, 64, 56, 56), dtype="float32"),
         weight11: R.Tensor((64, 64, 3, 3), dtype="float32"),
@@ -124,7 +124,7 @@ class Conv2dReLUx2Partitioned:
             R.output(gv)
         return gv
 
-    @R.function
+    @R.function(private=True)
     def fused_relax_nn_conv2d_relax_nn_relu(
         data1: R.Tensor((1, 64, 56, 56), dtype="float32"),
         weight11: R.Tensor((64, 64, 3, 3), dtype="float32"),
@@ -138,7 +138,7 @@ class Conv2dReLUx2Partitioned:
             R.output(gv1)
         return gv1
 
-    @R.function
+    @R.function(private=True)
     def fused_relax_nn_conv2d_relax_nn_relu1(
         conv1: R.Tensor((1, 64, 56, 56), dtype="float32"),
         weight21: R.Tensor((64, 64, 3, 3), dtype="float32"),
@@ -174,7 +174,7 @@ class Conv2dReLUx2Partitioned_only_conv2d:
             R.output(conv2d)
         return conv2d
 
-    @R.function
+    @R.function(private=True)
     def fused_relax_nn_conv2d(
         data1: R.Tensor((1, 64, 56, 56), dtype="float32"),
         weight11: R.Tensor((64, 64, 3, 3), dtype="float32"),
@@ -187,7 +187,7 @@ class Conv2dReLUx2Partitioned_only_conv2d:
             R.output(gv)
         return gv
 
-    @R.function
+    @R.function(private=True)
     def fused_relax_nn_conv2d1(
         conv11: R.Tensor((1, 64, 56, 56), dtype="float32"),
         weight21: R.Tensor((64, 64, 3, 3), dtype="float32"),
@@ -236,7 +236,7 @@ class Conv2dConv2dReLUPartitioned:
             R.output(gv)
         return gv
 
-    @R.function
+    @R.function(private=True)
     def fused_relax_nn_conv2d_relax_nn_relu(
         conv1: R.Tensor((1, 64, 56, 56), dtype="float32"),
         weight21: R.Tensor((64, 64, 3, 3), dtype="float32"),
@@ -250,7 +250,7 @@ class Conv2dConv2dReLUPartitioned:
             R.output(gv1)
         return gv1
 
-    @R.function
+    @R.function(private=True)
     def fused_relax_nn_conv2d(
         data1: R.Tensor((1, 64, 56, 56), dtype="float32"),
         weight11: R.Tensor((64, 64, 3, 3), dtype="float32"),
@@ -303,7 +303,7 @@ class BranchTupleOutputPartitioned:
             R.output(out)
         return out
 
-    @R.function
+    @R.function(private=True)
     def fused_relax_nn_conv2d_relax_nn_relu(
         data1: R.Tensor((1, 64, 56, 56), dtype="float32"),
         weight1: R.Tensor((64, 64, 3, 3), dtype="float32"),
@@ -377,7 +377,7 @@ class Conv2dx2:
 
 @tvm.script.ir_module
 class Conv2dx2_partitioned:
-    @R.function
+    @R.function(private=True)
     def fused_relax_nn_conv2d_cutlass(
         data: R.Tensor((16, 32, 32, 16), dtype="float16"),
         weight1: R.Tensor((16, 3, 3, 16), dtype="float16"),
@@ -565,7 +565,7 @@ def test_ignore_call_tir():
                     T.writes(out[i, j, k, l])
                     out[i, j, k, l] = T.max(data[i, j, k, l], T.float32(0))
 
-        @R.function
+        @R.function(private=True)
         def fused_relax_nn_conv2d(
             data: R.Tensor((1, 64, 56, 56), dtype="float32"),
             weight1: R.Tensor((64, 64, 3, 3), dtype="float32"),
@@ -617,7 +617,7 @@ def test_unused():
 
     @I.ir_module
     class Conv2dReLU_partitioned:
-        @R.function
+        @R.function(private=True)
         def fused_relax_nn_conv2d(
             data: R.Tensor((1, 64, 56, 56), dtype="float32"),
             weight1: R.Tensor((64, 64, 3, 3), dtype="float32"),
@@ -685,7 +685,7 @@ def test_bind_constants():
 
     @I.ir_module
     class Conv2dWithConstantWeight_partitioned:
-        @R.function
+        @R.function(private=True)
         def fused_relax_nn_conv2d(
             data: R.Tensor((1, 64, 56, 56), dtype="float32"),
             param_0: R.Tensor((64, 64, 3, 3), dtype="float32"),
@@ -721,6 +721,7 @@ def test_bind_constants():
 def test_split():
     @R.function
     def func(inp: R.Tensor((16, 32), "float32")):
+        R.func_attr({"global_symbol": "main"})
         with R.dataflow():
             tup = R.split(inp, [16], axis=1)
             out = R.add(tup[0], tup[1])
@@ -729,7 +730,7 @@ def test_split():
 
     @tvm.script.ir_module
     class Expected1:
-        @R.function
+        @R.function(private=True)
         def fused_relax_split(
             inp: R.Tensor((16, 32), dtype="float32")
         ) -> R.Tuple(R.Tensor((16, 16), dtype="float32"), R.Tensor((16, 16), 
dtype="float32")):
@@ -756,7 +757,7 @@ def test_split():
 
     @I.ir_module
     class Expected2:
-        @R.function
+        @R.function(private=True)
         def fused_relax_split_relax_add(
             inp: R.Tensor((16, 32), dtype="float32")
         ) -> R.Tensor((16, 16), dtype="float32"):
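
Note on test_split above: the fused functions become private, so the outer
func keeps itself reachable by pinning its symbol explicitly through
R.func_attr. A minimal sketch of that interaction, assuming the parser's
default of deriving the symbol from the function name (as the expected
scripts in this patch suggest); the function body here is arbitrary:

    from tvm.script import relax as R

    @R.function
    def func(x: R.Tensor((2, 2), "float32")) -> R.Tensor((2, 2), "float32"):
        # the explicit attribute overrides the default symbol "func"
        R.func_attr({"global_symbol": "main"})
        y = R.add(x, x)
        return y

    assert func.attrs["global_symbol"] == "main"
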
diff --git a/tests/python/relax/test_transform_fuse_tir.py b/tests/python/relax/test_transform_fuse_tir.py
index 00dc714654..f59e3f2e9e 100644
--- a/tests/python/relax/test_transform_fuse_tir.py
+++ b/tests/python/relax/test_transform_fuse_tir.py
@@ -32,7 +32,7 @@ def test_simple():
         x = relax.Var("x", R.Tensor([10, 20], "float32"))
         p0 = relax.Var("p0", R.Tensor([], "float32"))
 
-        with bb.function("fused_add_exp_squeeze", [x, p0], attrs={"Primitive": 
True}):
+        with bb.function("fused_add_exp_squeeze", [x, p0], attrs={"Primitive": 
True}, private=True):
             with bb.dataflow():
                 lv0 = bb.emit_te(topi.add, x, p0)
                 lv1 = bb.emit_te(topi.exp, lv0)
@@ -565,7 +565,7 @@ def test_multiple_relax_functions():
 
         x = relax.Var("x", R.Tensor([10, 20], "float32"))
         p0 = relax.Var("p0", R.Tensor((), "float32"))
-        with bb.function("fused_add_exp_squeeze", [x, p0], attrs={"Primitive": 
1}):
+        with bb.function("fused_add_exp_squeeze", [x, p0], attrs={"Primitive": 
1}, private=True):
             with bb.dataflow():
                 lv0 = bb.emit_te(topi.add, x, p0)
                 lv1 = bb.emit_te(topi.exp, lv0)
@@ -575,7 +575,7 @@ def test_multiple_relax_functions():
 
         x = relax.Var("x", R.Tensor([20, 10], "float32"))
         p0 = relax.Var("p0", R.Tensor((), "float32"))
-        with bb.function("fused_add1_exp1_squeeze1", [x, p0], 
attrs={"Primitive": 1}):
+        with bb.function("fused_add1_exp1_squeeze1", [x, p0], 
attrs={"Primitive": 1}, private=True):
             with bb.dataflow():
                 lv0 = bb.emit_te(topi.add, x, p0)
                 lv1 = bb.emit_te(topi.exp, lv0)
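
The fuse-TIR hunks above exercise the private keyword that this patch adds to
BlockBuilder.function. A minimal sketch of that builder-side usage, with an
arbitrary function name and body:

    from tvm import relax, topi

    x = relax.Var("x", relax.TensorStructInfo([10, 20], "float32"))
    bb = relax.BlockBuilder()
    # private=True builds the function without a "global_symbol" attribute
    with bb.function("fused_exp", [x], attrs={"Primitive": True}, private=True):
        with bb.dataflow():
            gv = bb.emit_output(bb.emit_te(topi.exp, x))
        bb.emit_func_output(gv)
    mod = bb.get()
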
diff --git a/tests/python/relax/test_transform_lambda_lift.py b/tests/python/relax/test_transform_lambda_lift.py
index ddc274fee2..d672484171 100644
--- a/tests/python/relax/test_transform_lambda_lift.py
+++ b/tests/python/relax/test_transform_lambda_lift.py
@@ -42,7 +42,7 @@ def test_basic():
     # the target IRModule
     @tvm.script.ir_module
     class Expected:
-        @R.function
+        @R.function(private=True)
         def lifted_func_0(
             x2: R.Tensor((10, 5), "float32"), y2: R.Tensor((10, 5), "float32")
         ) -> R.Tensor((10, 5), "float32"):
@@ -97,12 +97,12 @@ def test_closure():
             )
             return res
 
-        @R.function
+        @R.function(private=True)
        def lifted_func_1(x1: R.Tensor((2, 3), "float32"), c1: R.Tensor((2, 3), "float32")):
             r_1: R.Tensor((2, 3), "float32") = R.add(x1, c1)
             return r_1
 
-        @R.function
+        @R.function(private=True)
         def lifted_func_0(y: R.Tensor((2, 3), "float32")) -> R.Object:
             inner_func = R.make_closure(Expected.lifted_func_1, (y,))
             return inner_func
@@ -140,7 +140,7 @@ def test_recursive():
     # the expected IRModule
     @tvm.script.ir_module
     class Expected:
-        @R.function
+        @R.function(private=True)
         def lifted_func_0(
             i: R.Tensor((), "int32"), s: R.Tensor((2, 3), "float32"), x: 
R.Tensor((2, 3), "float32")
         ) -> R.Tensor((2, 3), "float32"):
@@ -224,14 +224,14 @@ def test_multi_func():
             gv11: R.Tensor((10, 5), "float32") = inner(x11, y11)
             return gv11
 
-        @R.function
+        @R.function(private=True)
         def lifted_func_0(
             x2: R.Tensor((10, 5), "float32"), y2: R.Tensor((10, 5), "float32")
         ) -> R.Tensor((10, 5), "float32"):
             s: R.Tensor((10, 5), "float32") = R.add(x2, y2)
             return s
 
-        @R.function
+        @R.function(private=True)
         def lifted_func_1(
             x21: R.Tensor((10, 5), "float32"), y21: R.Tensor((10, 5), 
"float32")
         ) -> R.Tensor((10, 5), "float32"):
@@ -308,7 +308,7 @@ def test_no_local_func():
 def test_impure_function():
     @tvm.script.ir_module
     class Expected:
-        @R.function(pure=False)
+        @R.function(pure=False, private=True)
         def lifted_func_0() -> R.Tuple:
             y = R.print(format="Wow!")
             return y
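
As the lambda-lifting hunks show, the decorator flags compose: a lifted impure
closure is marked both pure=False and private=True. A short sketch of the
combined form (the function name and printed text are arbitrary):

    from tvm.script import relax as R

    @R.function(pure=False, private=True)
    def lifted() -> R.Tuple:
        y = R.print(format="effect")
        return y

    # no global symbol is attached to a private function
    assert lifted.attrs is None or "global_symbol" not in lifted.attrs
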
diff --git a/tests/python/relax/test_transform_merge_composite_functions.py b/tests/python/relax/test_transform_merge_composite_functions.py
index 61df388c78..d552266131 100644
--- a/tests/python/relax/test_transform_merge_composite_functions.py
+++ b/tests/python/relax/test_transform_merge_composite_functions.py
@@ -41,7 +41,7 @@ class Conv2dReLUx2:
             R.output(gv)
         return gv
 
-    @R.function
+    @R.function(private=True)
     def fused_relax_nn_conv2d_relax_nn_relu(
         data1: R.Tensor((1, 64, 56, 56), dtype="float32"),
         weight11: R.Tensor((64, 64, 3, 3), dtype="float32"),
@@ -57,7 +57,7 @@ class Conv2dReLUx2:
             R.output(gv1)
         return gv1
 
-    @R.function
+    @R.function(private=True)
     def fused_relax_nn_conv2d_relax_nn_relu1(
         conv1: R.Tensor((1, 64, 56, 56), dtype="float32"),
         weight21: R.Tensor((64, 64, 3, 3), dtype="float32"),
@@ -164,7 +164,7 @@ class Diamond:
             R.output(gv2)
         return gv2
 
-    @R.function
+    @R.function(private=True)
     def fused_relax_nn_gelu(
         lv: R.Tensor((1, 64, 54, 54), dtype="float32")
     ) -> R.Tensor((1, 64, 54, 54), dtype="float32"):
@@ -174,7 +174,7 @@ class Diamond:
             R.output(gv)
         return gv
 
-    @R.function
+    @R.function(private=True)
     def fused_relax_nn_relu(
         lv1: R.Tensor((1, 64, 54, 54), dtype="float32")
     ) -> R.Tensor((1, 64, 54, 54), dtype="float32"):
@@ -184,7 +184,7 @@ class Diamond:
             R.output(gv1)
         return gv1
 
-    @R.function
+    @R.function(private=True)
     def fused_relax_add(
         lv5: R.Tensor((1, 64, 54, 54), dtype="float32"),
         gelu1: R.Tensor((1, 64, 54, 54), dtype="float32"),
@@ -195,7 +195,7 @@ class Diamond:
             R.output(gv3)
         return gv3
 
-    @R.function
+    @R.function(private=True)
     def fused_relax_nn_conv2d(
         data1: R.Tensor((1, 64, 56, 56), dtype="float32"),
         weight1: R.Tensor((64, 64, 3, 3), dtype="float32"),
@@ -332,7 +332,7 @@ class Diamond_cyclic_dep:
             R.output(gv2)
         return gv2
 
-    @R.function
+    @R.function(private=True)
     def fused_relax_nn_gelu(
         lv: R.Tensor((1, 64, 54, 54), dtype="float32")
     ) -> R.Tensor((1, 64, 54, 54), dtype="float32"):
@@ -342,7 +342,7 @@ class Diamond_cyclic_dep:
             R.output(gv)
         return gv
 
-    @R.function
+    @R.function(private=True)
     def fused_relax_nn_relu(
         lv1: R.Tensor((1, 64, 54, 54), dtype="float32")
     ) -> R.Tensor((1, 64, 54, 54), dtype="float32"):
@@ -352,7 +352,7 @@ class Diamond_cyclic_dep:
             R.output(gv1)
         return gv1
 
-    @R.function
+    @R.function(private=True)
     def fused_relax_add(
         lv5: R.Tensor((1, 64, 54, 54), dtype="float32"),
         gelu1: R.Tensor((1, 64, 54, 54), dtype="float32"),
@@ -363,7 +363,7 @@ class Diamond_cyclic_dep:
             R.output(gv3)
         return gv3
 
-    @R.function
+    @R.function(private=True)
     def fused_relax_nn_conv2d(
         data1: R.Tensor((1, 64, 56, 56), dtype="float32"),
         weight1: R.Tensor((64, 64, 3, 3), dtype="float32"),
@@ -509,7 +509,7 @@ class MultipleProducers:
             R.output(gv1)
         return gv1
 
-    @R.function
+    @R.function(private=True)
     def fused_relax_nn_relu(
         x11: R.Tensor((10,), dtype="float32")
     ) -> R.Tensor((10,), dtype="float32"):
@@ -519,7 +519,7 @@ class MultipleProducers:
             R.output(gv2)
         return gv2
 
-    @R.function
+    @R.function(private=True)
     def fused_relax_nn_gelu(
         x21: R.Tensor((10,), dtype="float32")
     ) -> R.Tensor((10,), dtype="float32"):
@@ -529,7 +529,7 @@ class MultipleProducers:
             R.output(gv3)
         return gv3
 
-    @R.function
+    @R.function(private=True)
     def fused_relax_add(
         lv: R.Tensor((10,), dtype="float32"), gelu1: R.Tensor((10,), 
dtype="float32")
     ) -> R.Tensor((10,), dtype="float32"):
@@ -627,7 +627,7 @@ class MultipleProducersCyclic:
             R.output(gv1)
         return gv1
 
-    @R.function
+    @R.function(private=True)
     def fused_relax_nn_relu(
         x11: R.Tensor((10,), dtype="float32")
     ) -> R.Tensor((10,), dtype="float32"):
@@ -637,7 +637,7 @@ class MultipleProducersCyclic:
             R.output(gv2)
         return gv2
 
-    @R.function
+    @R.function(private=True)
     def fused_relax_nn_gelu(
         x21: R.Tensor((10,), dtype="float32")
     ) -> R.Tensor((10,), dtype="float32"):
@@ -647,7 +647,7 @@ class MultipleProducersCyclic:
             R.output(gv3)
         return gv3
 
-    @R.function
+    @R.function(private=True)
     def fused_relax_add(
         lv: R.Tensor((10,), dtype="float32"), gelu1: R.Tensor((10,), 
dtype="float32")
     ) -> R.Tensor((10,), dtype="float32"):
@@ -759,7 +759,7 @@ class MergeCompilerRegionsExample:
             R.output(gv1)
         return gv1
 
-    @R.function
+    @R.function(private=True)
     def fused_relax_nn_relu(
         add2: R.Tensor((10,), dtype="float32")
     ) -> R.Tensor((10,), dtype="float32"):
@@ -769,7 +769,7 @@ class MergeCompilerRegionsExample:
             R.output(gv)
         return gv
 
-    @R.function
+    @R.function(private=True)
     def fused_relax_add(
         x11: R.Tensor((10,), dtype="float32"), x21: R.Tensor((10,), 
dtype="float32")
     ) -> R.Tensor((10,), dtype="float32"):
@@ -779,7 +779,7 @@ class MergeCompilerRegionsExample:
             R.output(gv2)
         return gv2
 
-    @R.function
+    @R.function(private=True)
     def fused_relax_nn_gelu(
         x31: R.Tensor((10,), dtype="float32")
     ) -> R.Tensor((10,), dtype="float32"):
@@ -924,7 +924,7 @@ class ModuleWithNonComposite:
             R.output(conv)
         return conv
 
-    @R.function
+    @R.function(private=True)
     def fused_relax_nn_conv2d(
         data1: R.Tensor((1, 64, 56, 56), dtype="float32"),
         weight1: R.Tensor((64, 64, 3, 3), dtype="float32"),
@@ -1071,7 +1071,7 @@ def test_reshape():
    # Verify that the non-CallNode input (shape in reshape) can be handled properly.
     @I.ir_module
     class Module:
-        @R.function
+        @R.function(private=True)
         def fused_relax_matmul(
             lv: R.Tensor((1, 784), dtype="float32"), lv1: R.Tensor((784, 512), 
dtype="float32")
         ) -> R.Tensor((1, 512), dtype="float32"):
@@ -1081,7 +1081,7 @@ def test_reshape():
                 R.output(gv)
             return gv
 
-        @R.function
+        @R.function(private=True)
         def fused_relax_reshape(
            inp_0: R.Tensor((1, 1, 28, 28), dtype="float32"), param_0: R.Shape([1, 784])
         ) -> R.Tensor((1, 784), dtype="float32"):
diff --git a/tests/python/relax/test_transform_normalize.py b/tests/python/relax/test_transform_normalize.py
index 874e83c7f9..a6feb0b8ab 100644
--- a/tests/python/relax/test_transform_normalize.py
+++ b/tests/python/relax/test_transform_normalize.py
@@ -44,7 +44,7 @@ def test_normalize_function():
 
     after_mod = relax.transform.Normalize()(before_mod)
 
-    @R.function
+    @R.function(private=True)
     def expected(x: R.Tensor(("m", "n"), "float16")) -> 
R.Tensor(dtype="float16", ndim=2):
         gv = R.add(x, x)
         gv1 = R.add(x, x)
@@ -86,7 +86,7 @@ def test_normalize_if():
     before_mod = tvm.IRModule.from_expr(f)
     after_mod = relax.transform.Normalize()(before_mod)
 
-    @R.function
+    @R.function(private=True)
     def expected(
         cond: R.Tensor((), "bool"), x: R.Tensor((1,), "float32")
     ) -> R.Tensor(dtype="float32", ndim=1):
@@ -151,7 +151,7 @@ def test_normalize_seq_body():
     before_mod = tvm.IRModule.from_expr(f)
     after_mod = relax.transform.Normalize()(before_mod)
 
-    @R.function
+    @R.function(private=True)
     def expected(
         x: R.Tensor((), dtype="int32"), y: R.Tensor((), dtype="int32")
     ) -> R.Tensor(ndim=0, dtype="int32"):
@@ -175,7 +175,7 @@ def test_normalize_func_body():
     before_mod = tvm.IRModule.from_expr(f)
     after_mod = relax.transform.Normalize()(before_mod)
 
-    @R.function
+    @R.function(private=True)
     def expected(
         x: R.Tensor((), dtype="int32"), y: R.Tensor((), dtype="int32")
     ) -> R.Tensor(ndim=0, dtype="int32"):
@@ -207,7 +207,7 @@ def test_normalize_if_branches():
     before_mod = tvm.IRModule.from_expr(f)
     after_mod = relax.transform.Normalize()(before_mod)
 
-    @R.function
+    @R.function(private=True)
     def expected(
         cond: R.Tensor((), dtype="bool"),
         x: R.Tensor((), dtype="int32"),
@@ -257,7 +257,7 @@ def test_normalize_if_condition():
     before_mod = tvm.IRModule.from_expr(f)
     after_mod = relax.transform.Normalize()(before_mod)
 
-    @R.function
+    @R.function(private=True)
     def expected(
         cond: R.Tensor((), "bool"), x: R.Tensor((1,), "float32")
     ) -> R.Tensor(dtype="float32", ndim=1):
@@ -341,7 +341,7 @@ def test_normalize_combine_nearby_blocks():
 
     after_mod = relax.transform.Normalize()(tvm.IRModule.from_expr(f))
 
-    @R.function
+    @R.function(private=True)
     def expected(x: R.Tensor((), "int32")):
         with R.dataflow():
             v0 = x
@@ -383,7 +383,7 @@ def test_normalize_nested_seq():
     )
     after_mod = relax.transform.Normalize()(tvm.IRModule.from_expr(f))
 
-    @R.function
+    @R.function(private=True)
     def expected():
         x = relax.const(1)
         z = relax.const(2)
@@ -434,7 +434,7 @@ def test_normalize_nested_seq_dataflow():
     )
     after_mod = relax.transform.Normalize()(tvm.IRModule.from_expr(f))
 
-    @R.function
+    @R.function(private=True)
     def expected():
         x = relax.const(1)
         q = relax.const(2)
@@ -507,7 +507,7 @@ def test_normalize_deeply_nested_seq():
     )
     after_mod = relax.transform.Normalize()(tvm.IRModule.from_expr(f))
 
-    @R.function
+    @R.function(private=True)
     def expected():
         x = relax.const(1)
         u = relax.const(2)
diff --git a/tests/python/relax/test_transform_rewrite_cuda_graph.py b/tests/python/relax/test_transform_rewrite_cuda_graph.py
index 747af6f296..4f25feb032 100644
--- a/tests/python/relax/test_transform_rewrite_cuda_graph.py
+++ b/tests/python/relax/test_transform_rewrite_cuda_graph.py
@@ -82,7 +82,7 @@ def test_rewrite_cuda_graph():
                         T.writes(compute[i0, i1])
                        compute[i0, i1] = T.exp(rxplaceholder[i0, i1], dtype="float32")
 
-        @R.function
+        @R.function(private=True)
         def cuda_graph_alloc() -> R.Tuple(R.Object, R.Object, R.Object):
             R.func_attr({"relax.force_pure": True})
            storage: R.Object = R.memory.alloc_storage(R.shape([32]), R.prim_value(0), R.str("global"), R.dtype("float32"))
@@ -91,7 +91,7 @@ def test_rewrite_cuda_graph():
            gv: R.Tuple(R.Object, R.Object, R.Object) = (storage, storage1, storage2)
             return gv
 
-        @R.function
+        @R.function(private=True)
        def cuda_graph_capture(alloc: R.Tensor((2, 4), dtype="float32"), alloc1: R.Tensor((2, 4), dtype="float32"), storage: R.Object, storage2: R.Object) -> R.Tuple(R.Tensor((2, 4), dtype="float32")):
             R.func_attr({"relax.force_pure": True})
             cls = Expected
@@ -193,7 +193,7 @@ def test_tuple():
                         T.writes(compute[i0, i1])
                         compute[i0, i1] = T.exp(rxplaceholder[i0, i1])
 
-        @R.function
+        @R.function(private=True)
         def cuda_graph_alloc() -> R.Tuple(R.Object, R.Object):
             R.func_attr({"relax.force_pure": True})
            storage: R.Object = R.memory.alloc_storage(R.shape([32]), R.prim_value(0), R.str("global"), R.dtype("float32"))
@@ -201,7 +201,7 @@ def test_tuple():
             gv: R.Tuple(R.Object, R.Object) = (storage, storage1)
             return gv
 
-        @R.function
+        @R.function(private=True)
        def cuda_graph_capture(alloc: R.Tensor((2, 4), dtype="float32"), alloc1: R.Tensor((2, 4), dtype="float32"), storage: R.Object) -> R.Tuple(R.Tensor((2, 4), dtype="float32")):
             R.func_attr({"relax.force_pure": True})
             cls = Expected
diff --git a/tests/python/relax/test_tvmscript_parser.py b/tests/python/relax/test_tvmscript_parser.py
index 7a8bcdee26..9305cdbcb1 100644
--- a/tests/python/relax/test_tvmscript_parser.py
+++ b/tests/python/relax/test_tvmscript_parser.py
@@ -203,7 +203,7 @@ def test_simple_module():
 
     x = relax.Var("x", R.Tensor((128, 128), "float32"))
     bb = relax.BlockBuilder()
-    with bb.function("foo", (x,)):
+    with bb.function("foo", (x,), {"global_symbol": "foo"}):
         out = bb.emit_te(lambda x: x + 1, x, primfunc_name_hint="tir_func")
         bb.emit_func_output(out)
 
@@ -232,7 +232,7 @@ def test_emit_te_primfunc_attrs():
 
     x = relax.Var("x", R.Tensor((128, 128), "float32"))
     bb = relax.BlockBuilder()
-    with bb.function("foo", (x,)):
+    with bb.function("foo", (x,), {"global_symbol": "foo"}):
         out = bb.emit_te(
             lambda x: x + 1,
             x,
@@ -254,7 +254,7 @@ def test_emit_te():
 
     bb = relax.BlockBuilder()
     x = relax.Var("x", relax.TensorStructInfo([10, 20], "float32"))
-    with bb.function("main", [x]):
+    with bb.function("main", [x], {"global_symbol": "main"}):
         lv1 = bb.emit_te(topi.add, x, x)
         out = bb.emit_te(topi.multiply, lv1, lv1)
         bb.emit_func_output(out)
@@ -294,7 +294,7 @@ def test_module_with_attr_and_global_info():
 
     x = relax.Var("x", R.Tensor((128, 128), "float32"))
     bb = relax.BlockBuilder()
-    with bb.function("foo", (x,)):
+    with bb.function("foo", (x,), {"global_symbol": "foo"}):
         out = bb.emit_te(lambda x: x + 1, x, primfunc_name_hint="tir_func")
         bb.emit_func_output(out)
     mod = bb.get()
@@ -834,7 +834,7 @@ def test_call_dps_packed_empty_shape():
 def test_call_tir_empty_tuple_arg():
     bb = relax.BlockBuilder()
     dummy_param = relax.Var("dummy_param", R.Tensor(()))
-    with bb.function("foo", [dummy_param]):
+    with bb.function("foo", [dummy_param], {"global_symbol": "foo"}):
        output = bb.emit_te(topi.full, shape=(16, 32), dtype="float32", fill_value=1.0)
         bb.emit_func_output(output)
 
@@ -1493,5 +1493,22 @@ def test_call_pure_packed():
     _check(foo, bb.get()["foo"])
 
 
+def test_private_function():
+    @I.ir_module
+    class Addition:
+        @R.function(private=True)
+        def main(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"):
+            y = R.add(x, x)
+            return y
+
+    x = relax.Var("x", R.Tensor((), "int32"))
+    bb = relax.BlockBuilder()
+    with bb.function("main", (x), private=True):
+        y = bb.emit(R.add(x, x))
+        bb.emit_func_output(y)
+
+    _check(Addition, bb.get())
+
+
 if __name__ == "__main__":
     tvm.testing.main()
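
Alongside test_private_function, a sketch of the attribute-level difference
between the two decorator forms, assuming the parser's default symbol matches
the function name (as the expected scripts in this patch suggest):

    from tvm.script import relax as R

    @R.function
    def public_fn(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"):
        y = R.add(x, x)
        return y

    @R.function(private=True)
    def private_fn(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"):
        y = R.add(x, x)
        return y

    # the parser attaches a global symbol only to the public function
    assert public_fn.attrs["global_symbol"] == "public_fn"
    assert private_fn.attrs is None or "global_symbol" not in private_fn.attrs
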
diff --git a/tests/python/relax/test_tvmscript_parser_op_datatype.py b/tests/python/relax/test_tvmscript_parser_op_datatype.py
index ec71e868d4..85c5faa866 100644
--- a/tests/python/relax/test_tvmscript_parser_op_datatype.py
+++ b/tests/python/relax/test_tvmscript_parser_op_datatype.py
@@ -47,7 +47,7 @@ def test_astype():
         gv = bb.emit(relax.op.astype(x, "float16"))
         bb.emit_func_output(gv)
 
-    _check(expected, bb.get()["main"])
+    _check(expected.with_attr("global_symbol", "main"), bb.get()["main"])
 
 
 if __name__ == "__main__":
diff --git a/tests/python/relax/test_tvmscript_printer_relax.py b/tests/python/relax/test_tvmscript_printer_relax.py
index 7525c63be4..c376943173 100644
--- a/tests/python/relax/test_tvmscript_printer_relax.py
+++ b/tests/python/relax/test_tvmscript_printer_relax.py
@@ -27,12 +27,15 @@ def _assert_print(obj, expected):
     if not isinstance(obj, str):
         obj = obj.script(verbose_expr=True)
     obj = obj.strip()
-    assert obj == expected.strip(), "\n" + obj
+    # compare line by line in case there is trailing whitespace in the _middle_
+    for obj_line, expected_line in zip(obj.splitlines(), expected.strip().splitlines()):
+        assert obj_line.strip() == expected_line.strip(), "\n" + obj
 
 
 def test_function():
     @R.function
     def func(a: R.Tensor((10, 10))) -> R.Tensor((10, 10)):  # type: ignore
+        R.func_attr({"some_attr": 1})
         return a
 
     _assert_print(
@@ -41,19 +44,39 @@ def test_function():
 # from tvm.script import relax as R
 
 @R.function
+def func(a: R.Tensor((10, 10))) -> R.Tensor((10, 10)):
+    R.func_attr({"some_attr": 1})
+    return a""",
+    )
+
+
+def test_lone_private_function():
+    @R.function(private=True)
+    def func(a: R.Tensor((10, 10))) -> R.Tensor((10, 10)):  # type: ignore
+        R.func_attr({"some_attr": 1})
+        return a
+
+    # name prints as main because without a global symbol, the printer cannot assume a name
+    _assert_print(
+        func,
+        """
+# from tvm.script import relax as R
+
[email protected](private=True)
 def main(a: R.Tensor((10, 10))) -> R.Tensor((10, 10)):
+    R.func_attr({"some_attr": 1})
     return a""",
     )
 
 
 def test_extern_func():
     @R.function
-    def relax_func(a: R.Tensor((10, 10))) -> R.Tensor((10, 10)):  # type: ignore
+    def func(a: R.Tensor((10, 10))) -> R.Tensor((10, 10)):  # type: ignore
         return a
 
     obj = IRModule(
         {
-            "func": relax_func,
+            "func": func,
             "my_ext": relax.ExternFunc("my_ext"),
         }
     )
@@ -73,6 +96,40 @@ class Module:
     )
 
 
+def test_nested_function():
+    @I.ir_module
+    class NestedFunction:
+        @R.function
+        def main(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"):
+            @R.function
+            def nested(y: R.Tensor((), "int32")) -> R.Tensor((), "int32"):
+                return y
+
+            z = nested(x)
+            return z
+
+    _assert_print(
+        NestedFunction,
+        """
+# from tvm.script import ir as I
+# from tvm.script import relax as R
+
[email protected]_module
+class Module:
+    @R.function
+    def main(x: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"):
+        # from tvm.script import relax as R
+
+        @R.function
+        def nested(y: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"):
+            return y
+
+        z: R.Tensor((), dtype="int32") = nested(x)
+        return z
+""",
+    )
+
+
 def test_object_struct_info():
     obj = relax.ObjectStructInfo()
     _assert_print(
@@ -576,5 +633,101 @@ class Module:
     )
 
 
+def test_private_function():
+    @I.ir_module
+    class AddMod:
+        @R.function(private=True)
+        def main(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"):
+            y: R.Tensor((), dtype="int32") = R.add(x, x)
+            return y
+
+    _assert_print(
+        AddMod,
+        """
+# from tvm.script import ir as I
+# from tvm.script import relax as R
+
[email protected]_module
+class Module:
+    @R.function(private=True)
+    def main(x: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"):
+        y: R.Tensor((), dtype="int32") = R.add(x, x)
+        return y
+""",
+    )
+
+
+def test_directly_construct_private_funcs():
+    # public
+    @R.function
+    def foo(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"):
+        y: R.Tensor((), dtype="int32") = R.add(x, x)
+        return y
+
+    # private
+    @R.function(private=True)
+    def bar(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"):
+        y: R.Tensor((), dtype="int32") = R.multiply(x, x)
+        return y
+
+    # public but there's another attribute
+    @R.function
+    def baz(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"):
+        R.func_attr({"relax.force_pure": True})
+        y: R.Tuple = R.print(format="Hi there!")
+        z: R.Tensor((), dtype="int32") = R.add(x, x)
+        return z
+
+    # private with an attribute
+    @R.function(private=True)
+    def quux(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"):
+        R.func_attr({"relax.force_pure": True})
+        y: R.Tuple = R.print(format="Lol")
+        z: R.Tensor((), dtype="int32") = R.multiply(x, x)
+        return z
+
+    obj = IRModule(
+        {
+            "foo": foo,
+            "bar": bar,
+            "baz": baz,
+            "quux": quux,
+        }
+    )
+    _assert_print(
+        obj,
+        """
+# from tvm.script import ir as I
+# from tvm.script import relax as R
+
[email protected]_module
+class Module:
+    @R.function(private=True)
+    def bar(x: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"):
+        y: R.Tensor((), dtype="int32") = R.multiply(x, x)
+        return y
+
+    @R.function
+    def baz(x: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"):
+        R.func_attr({"relax.force_pure": 1})
+        y: R.Tuple = R.print(format=R.str("Hi there!"))
+        z: R.Tensor((), dtype="int32") = R.add(x, x)
+        return z
+
+    @R.function
+    def foo(x: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"):
+        y: R.Tensor((), dtype="int32") = R.add(x, x)
+        return y
+
+    @R.function(private=True)
+    def quux(x: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"):
+        R.func_attr({"relax.force_pure": 1})
+        y: R.Tuple = R.print(format=R.str("Lol"))
+        z: R.Tensor((), dtype="int32") = R.multiply(x, x)
+        return z
+""",
+    )
+
+
 if __name__ == "__main__":
     tvm.testing.main()
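
The printer tests above imply a simple check that can be made directly; a
sketch, hedged in that the "main" fallback is a printer convention rather
than a property of the function itself:

    from tvm.script import relax as R

    @R.function(private=True)
    def f(a: R.Tensor((10, 10))) -> R.Tensor((10, 10)):
        return a

    script = f.script()
    assert "@R.function(private=True)" in script
    # with no global symbol to recover a name from, the printer prints "main"
    assert "def main(" in script
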
diff --git a/tests/python/relax/test_utils.py b/tests/python/relax/test_utils.py
index c55876a3ba..f0c4ae0bd2 100644
--- a/tests/python/relax/test_utils.py
+++ b/tests/python/relax/test_utils.py
@@ -71,7 +71,9 @@ def test_copy_with_new_vars_on_ir_module():
             gv = R.add(x, y)
             return gv
 
-    Actual["func_copied"] = relax.utils.copy_with_new_vars(Actual["func"])
+    Actual["func_copied"] = 
relax.utils.copy_with_new_vars(Actual["func"]).with_attr(
+        "global_symbol", "func_copied"
+    )
 
    # Assertion will fail if the f_copied contains the same VarNode that's used in
     # the original function, due to var mapping during structural equal.
@@ -113,7 +115,9 @@ def test_copy_with_new_vars_on_ir_module_nested_function():
             gv = R.add(x, y)
             return gv
 
-    Actual["func_copied"] = relax.utils.copy_with_new_vars(Actual["func"])
+    Actual["func_copied"] = 
relax.utils.copy_with_new_vars(Actual["func"]).with_attr(
+        "global_symbol", "func_copied"
+    )
 
     assert_structural_equal(Actual, Expected)
 
