This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new daf9c202ac [Unity][IR][UX] Privacy annotation in Relax (#15140)
daf9c202ac is described below
commit daf9c202ac5ad0b0a64043b486894b139e64e4de
Author: Steven S. Lyubomirsky <[email protected]>
AuthorDate: Tue Jun 27 09:31:44 2023 -0400
[Unity][IR][UX] Privacy annotation in Relax (#15140)
This PR implements the privacy annotation proposal. Namely, the @R.function
decorator now accepts an optional `private` argument. If a function is marked
as private, then it will not have a global symbol attached to it and thus will
not be externally accessible. Functions nested inside another function are
always treated as private. By default, top-level functions are not private, so
the parser inserts a global symbol (set to the function's name) for them.
---
include/tvm/script/ir_builder/relax/frame.h | 2 +
include/tvm/script/ir_builder/relax/ir.h | 3 +-
python/tvm/relax/block_builder.py | 12 ++
python/tvm/relax/frontend/torch/dynamo.py | 3 +-
python/tvm/relax/training/setup_trainer.py | 5 +-
python/tvm/script/ir_builder/relax/ir.py | 9 +-
python/tvm/script/parser/core/parser.py | 2 +
python/tvm/script/parser/relax/entry.py | 6 +-
python/tvm/script/parser/relax/parser.py | 27 ++--
src/relax/training/utils.cc | 4 +-
src/relax/transform/gradient.cc | 8 +-
src/relax/transform/lift_transform_params.cc | 5 +-
src/script/ir_builder/relax/frame.cc | 5 +
src/script/ir_builder/relax/ir.cc | 3 +-
src/script/printer/relax/function.cc | 69 ++++++++-
tests/python/relax/test_dataflow_pattern.py | 12 +-
tests/python/relax/test_training_loss.py | 10 ++
tests/python/relax/test_training_optimizer.py | 8 ++
.../relax/test_training_optimizer_numeric.py | 2 +-
.../relax/test_transform_attach_global_symbol.py | 6 +-
.../test_transform_combine_parallel_matmul.py | 23 +--
.../relax/test_transform_dead_code_elimination.py | 8 +-
tests/python/relax/test_transform_fuse_ops.py | 52 ++++---
.../relax/test_transform_fuse_ops_by_pattern.py | 29 ++--
tests/python/relax/test_transform_fuse_tir.py | 6 +-
tests/python/relax/test_transform_lambda_lift.py | 14 +-
.../test_transform_merge_composite_functions.py | 44 +++---
tests/python/relax/test_transform_normalize.py | 20 +--
.../relax/test_transform_rewrite_cuda_graph.py | 8 +-
tests/python/relax/test_tvmscript_parser.py | 27 +++-
.../relax/test_tvmscript_parser_op_datatype.py | 2 +-
tests/python/relax/test_tvmscript_printer_relax.py | 159 ++++++++++++++++++++-
tests/python/relax/test_utils.py | 8 +-
33 files changed, 456 insertions(+), 145 deletions(-)
diff --git a/include/tvm/script/ir_builder/relax/frame.h
b/include/tvm/script/ir_builder/relax/frame.h
index 9a8f835e81..1ad6813889 100644
--- a/include/tvm/script/ir_builder/relax/frame.h
+++ b/include/tvm/script/ir_builder/relax/frame.h
@@ -99,6 +99,8 @@ class FunctionFrameNode : public SeqExprFrameNode {
Optional<tvm::relax::StructInfo> ret_struct_info;
/*! \brief Whether the function is annotated as pure */
Optional<Bool> is_pure;
+ /*! \brief Whether the function is annotated as private */
+ Optional<Bool> is_private;
/*! \brief The function attributes. */
Map<String, ObjectRef> attrs;
/*! \brief The block builder to create Relax function. */
diff --git a/include/tvm/script/ir_builder/relax/ir.h
b/include/tvm/script/ir_builder/relax/ir.h
index 1cf30b4919..d160ad090e 100644
--- a/include/tvm/script/ir_builder/relax/ir.h
+++ b/include/tvm/script/ir_builder/relax/ir.h
@@ -34,9 +34,10 @@ namespace relax {
/*!
* \brief Start a function frame.
* \param is_pure Whether the function is annotated as pure.
+ * \param is_private Whether the function is annotated as private.
* \return The created ir_builder Function frame.
*/
-TVM_DLL FunctionFrame Function(const Bool& is_pure);
+TVM_DLL FunctionFrame Function(const Bool& is_pure, const Bool& is_private);
/*!
* \brief Add a parameter to the last function frame.
diff --git a/python/tvm/relax/block_builder.py
b/python/tvm/relax/block_builder.py
index 80edb31efb..502073edf2 100644
--- a/python/tvm/relax/block_builder.py
+++ b/python/tvm/relax/block_builder.py
@@ -199,6 +199,7 @@ class BlockBuilder(Object):
name: str,
params: Optional[Union[Var, Tuple, List[Var]]] = None,
attrs: Optional[Dict[str, Object]] = None,
+ private: bool = False,
) -> FunctionScope:
"""Annotate a Relax function.
@@ -215,6 +216,12 @@ class BlockBuilder(Object):
attrs : Dict[str, Object], optional
The function attrs
+ private : bool, optional
+ Whether the function is annotated as private.
+ If the function is private, it will not have a global symbol
attribute.
+ If it is not private and not an inner function, then it will have
+ a global symbol attribute (mapped to the function's name)
+
Returns
-------
ret: FunctionScope
@@ -233,6 +240,11 @@ class BlockBuilder(Object):
)
if attrs is None:
attrs = {}
+ # The block builder does not permit nesting functions, per above
comment,
+ # so no further check should be needed
+ if not private:
+ attrs["global_symbol"] = name
+
return FunctionScope(self, name, params, attrs)
def testing_scope(self, def_vars: List[tir.Var]) -> TestingScope:
diff --git a/python/tvm/relax/frontend/torch/dynamo.py
b/python/tvm/relax/frontend/torch/dynamo.py
index 3015f77428..abdf7b8862 100644
--- a/python/tvm/relax/frontend/torch/dynamo.py
+++ b/python/tvm/relax/frontend/torch/dynamo.py
@@ -154,7 +154,8 @@ def dynamo_capture_subgraphs(model, *params, **kwargs) ->
tvm.IRModule:
keep_params_as_input=keep_params_as_input,
unwrap_unit_return_tuple=True,
)
- mod[f"subgraph_{len(mod.get_global_vars())}"] = mod_["main"]
+ new_name = f"subgraph_{len(mod.get_global_vars())}"
+ mod[new_name] = mod_["main"].with_attr("global_symbol", new_name)
return graph_module.forward
dynamo.reset()
diff --git a/python/tvm/relax/training/setup_trainer.py
b/python/tvm/relax/training/setup_trainer.py
index 81ecaf4ea5..2e20570869 100644
--- a/python/tvm/relax/training/setup_trainer.py
+++ b/python/tvm/relax/training/setup_trainer.py
@@ -198,7 +198,10 @@ class SetupTrainer:
# Add optimizer function.
self._optimizer.init(params)
- mod[self.OPTIMIZER_FUNC] = self._optimizer.get_function()
+ # Need the global symbol to match the function's name
+ mod[self.OPTIMIZER_FUNC] = self._optimizer.get_function().with_attr(
+ "global_symbol", self.OPTIMIZER_FUNC
+ )
# Module attrs
mod = mod.with_attrs(
diff --git a/python/tvm/script/ir_builder/relax/ir.py
b/python/tvm/script/ir_builder/relax/ir.py
index e54e4aa07b..b06d9547ac 100644
--- a/python/tvm/script/ir_builder/relax/ir.py
+++ b/python/tvm/script/ir_builder/relax/ir.py
@@ -165,19 +165,24 @@ py_str = str
############################### Function ################################
-def function(is_pure: bool = True) -> frame.FunctionFrame:
+def function(is_pure: bool = True, is_private: bool = False) ->
frame.FunctionFrame:
"""Start a function frame.
Parameters
----------
is_pure: bool
Whether the function is annotated as pure.
+ is_private : bool
+ Whether the function is annotated as private.
+
Returns
-------
frame: FunctionFrame
The constructed function frame.
"""
- return _ffi_api.Function(is_pure) # type: ignore[attr-defined] # pylint:
disable=no-member
+ return _ffi_api.Function( # type: ignore[attr-defined] # pylint:
disable=no-member
+ is_pure, is_private
+ )
def arg(name: py_str, struct_info: StructInfo) -> Var:
diff --git a/python/tvm/script/parser/core/parser.py
b/python/tvm/script/parser/core/parser.py
index 9275924466..69e262b1d3 100644
--- a/python/tvm/script/parser/core/parser.py
+++ b/python/tvm/script/parser/core/parser.py
@@ -240,6 +240,7 @@ class Parser(doc.NodeVisitor):
dispatch_tokens: List[str]
function_annotations: Optional[Dict[str, Dict[str, Any]]]
var_table: VarTable
+ inside_function: bool # whether we are within a function
def __init__(
self,
@@ -250,6 +251,7 @@ class Parser(doc.NodeVisitor):
self.dispatch_tokens = ["default"]
self.function_annotations = function_annotations
self.var_table = VarTable()
+ self.inside_function = False
def parse(self, extra_vars: Optional[Dict[str, Any]] = None) -> Any:
"""The main parse method for parser.
diff --git a/python/tvm/script/parser/relax/entry.py
b/python/tvm/script/parser/relax/entry.py
index 2711e855dd..ff237a5600 100644
--- a/python/tvm/script/parser/relax/entry.py
+++ b/python/tvm/script/parser/relax/entry.py
@@ -45,9 +45,11 @@ FType = TypeVar("FType", bound=_Callable)
# this formulation allows us to support having @R.function
# appear as a decorator by itself or to have optional arguments
# like @R.function(pure=False)
-def function(f: Optional[FType] = None, pure: bool = True) -> Union[Function,
FType]:
+def function(
+ f: Optional[FType] = None, pure: bool = True, private: bool = False
+) -> Union[Function, FType]:
# pylint: disable=unused-argument
- # (pure isn't used here, but is used later in parsing)
+ # (pure and private aren't used here, but are used later in parsing)
# need to inspect the stack first because is_defined_in_class expects the
outer class
# to be in a particular position in the stack
diff --git a/python/tvm/script/parser/relax/parser.py
b/python/tvm/script/parser/relax/parser.py
index 427c56bcc8..863c249975 100644
--- a/python/tvm/script/parser/relax/parser.py
+++ b/python/tvm/script/parser/relax/parser.py
@@ -160,6 +160,9 @@ def collect_symbolic_var_from_params(self: Parser, node:
doc.FunctionDef) -> Non
@dispatch.register(token="relax", type_name="FunctionDef")
def visit_function_def(self: Parser, node: doc.FunctionDef) -> None:
+ is_inner_function = self.inside_function
+ self.inside_function = True
+
# reserve a var for local function
func_val = self.var_table.get().get(node.name)
if not func_val and is_recursive(node):
@@ -178,11 +181,14 @@ def visit_function_def(self: Parser, node:
doc.FunctionDef) -> None:
local_func_var = relax.Var(node.name,
relax.FuncStructInfo(params_sinfo, ret_sinfo))
self.var_table.add(node.name, local_func_var)
- purity = find_purity_annotation(node)
+ purity = find_decorator_annotation(node, "pure")
+ # treat the function as private if we are inside another function
+ # or if it has a privacy annotation
+ privacy = is_inner_function or find_decorator_annotation(node, "private",
default=False)
with self.var_table.with_frame():
with self.with_dispatch_token("relax"):
- with R.function(is_pure=purity):
+ with R.function(is_pure=purity, is_private=privacy):
R.func_name(node.name)
collect_symbolic_var_from_params(self, node)
@@ -202,22 +208,22 @@ def visit_function_def(self: Parser, node:
doc.FunctionDef) -> None:
self.report_error(stmt, "inline prim_func is
disallowed in Relax IR")
self.visit_body(node.body)
+ self.inside_function = is_inner_function
-def find_purity_annotation(node: doc.FunctionDef) -> bool:
+def find_decorator_annotation(node: doc.FunctionDef, annotation: str, default:
bool = True) -> bool:
"""
- Check the value of `pure` in the function decorator.
- Returns the annotated purity if present, otherwise defaulting to True.
- This allows for specifying the purity in the function signature.
+ Check the value of given annotation (argument name) in the function
decorator.
+ Returns the value of the annotation if present, otherwise giving the
default value.
"""
- # look for the pure argument in the function decorator
+ # look for the named argument in the function decorator
for dec in node.decorator_list:
if not isinstance(dec, doc.Call) or dec.func.attr != "function":
continue
for keyword in dec.keywords:
- if keyword.arg == "pure":
+ if keyword.arg == annotation:
return keyword.value.value
- return True
+ return default
@dispatch.register(token="relax", type_name="tvm_declare_function")
@@ -238,7 +244,8 @@ def visit_tvm_declare_function(self: Parser, node:
doc.FunctionDef) -> GlobalVar
param_sinfo = eval_struct_info(self, arg.annotation, eval_str=True)
params.append(relax.Var(arg.arg, param_sinfo))
- is_pure = find_purity_annotation(node)
+ is_pure = find_decorator_annotation(node, "pure")
+
func_signature = relax.Function.create_empty(params, ret_sinfo,
is_pure=is_pure)
return I.decl_function(node.name, func_signature)
diff --git a/src/relax/training/utils.cc b/src/relax/training/utils.cc
index 37582e3015..19faaad58b 100644
--- a/src/relax/training/utils.cc
+++ b/src/relax/training/utils.cc
@@ -48,7 +48,9 @@ class AppendLossMutator : private ExprMutator {
Function new_loss_func = CopyWithNewVars(loss_function);
AppendLossMutator mutator(mod, new_loss_func, num_backbone_outputs);
- auto new_func_transformed =
Downcast<Function>(mutator.VisitExpr(new_func));
+ auto new_func_transformed =
+ WithAttr(Downcast<Function>(mutator.VisitExpr(new_func)),
tvm::attr::kGlobalSymbol,
+ new_func_name.value_or(func_name + "_loss"));
auto new_module = GetRef<IRModule>(mod.CopyOnWrite());
auto new_var = GlobalVar(new_func_name.value_or(func_name + "_loss"));
diff --git a/src/relax/transform/gradient.cc b/src/relax/transform/gradient.cc
index 7645ae8cb6..2cda7a972d 100644
--- a/src/relax/transform/gradient.cc
+++ b/src/relax/transform/gradient.cc
@@ -333,10 +333,14 @@ class GradientMutator : private ExprMutator {
}
GradientMutator mutator(mod, require_grads_value, target_index);
- Function new_func_transformed =
Downcast<Function>(mutator.VisitExpr(new_func));
+
+ // make the adjoint public
+ auto new_name = func_name + "_adjoint";
+ Function new_func_transformed =
WithAttr(Downcast<Function>(mutator.VisitExpr(new_func)),
+ tvm::attr::kGlobalSymbol,
new_name);
IRModule new_module = GetRef<IRModule>(mod.CopyOnWrite());
- new_module->Add(GlobalVar(func_name + "_adjoint"), new_func_transformed);
+ new_module->Add(GlobalVar(new_name), new_func_transformed);
return new_module;
}
diff --git a/src/relax/transform/lift_transform_params.cc
b/src/relax/transform/lift_transform_params.cc
index f7c9a4189d..fb1f292776 100644
--- a/src/relax/transform/lift_transform_params.cc
+++ b/src/relax/transform/lift_transform_params.cc
@@ -251,7 +251,10 @@ class TransformParamsLifter : public ExprMutator {
lift_plan_ = planner.Plan(func, num_input);
// Step 2: Add the lifted function to the module
- builder_->AddFunction(lift_plan_.f_transform_params, new_func_name);
+ // (The lifted function should be public so we add a global symbol to it)
+ auto lift_func =
+ WithAttr(lift_plan_.f_transform_params, tvm::attr::kGlobalSymbol,
new_func_name);
+ builder_->AddFunction(lift_func, new_func_name);
// Step 3: Update the current function.
diff --git a/src/script/ir_builder/relax/frame.cc
b/src/script/ir_builder/relax/frame.cc
index 00bbd2a551..966af809c9 100644
--- a/src/script/ir_builder/relax/frame.cc
+++ b/src/script/ir_builder/relax/frame.cc
@@ -56,6 +56,11 @@ void FunctionFrameNode::ExitWithScope() {
"`return` to return an Expr";
this->block_builder->BeginScope(params);
Expr body =
this->block_builder->Normalize(tvm::relax::SeqExpr(binding_blocks,
output.value()));
+ // if the function is not private, add a global symbol to its attributes
+ if (!is_private.value_or(Bool(false))->value && name.defined() &&
+ !attrs.count(tvm::attr::kGlobalSymbol)) {
+ attrs.Set(tvm::attr::kGlobalSymbol, name.value());
+ }
auto dict_attrs = attrs.empty() ? NullValue<DictAttrs>() : DictAttrs(attrs);
this->block_builder->EndScope();
tvm::relax::Function func(/*params=*/params,
diff --git a/src/script/ir_builder/relax/ir.cc
b/src/script/ir_builder/relax/ir.cc
index 52d9f0cfe1..d66e8d0598 100644
--- a/src/script/ir_builder/relax/ir.cc
+++ b/src/script/ir_builder/relax/ir.cc
@@ -52,7 +52,7 @@ TVM_STATIC_IR_FUNCTOR(Namer, vtable)
/////////////////////////////// Function ////////////////////////////////
-FunctionFrame Function(const Bool& is_pure) {
+FunctionFrame Function(const Bool& is_pure, const Bool& is_private) {
ObjectPtr<FunctionFrameNode> n = make_object<FunctionFrameNode>();
const IRBuilder& ir_builder = IRBuilder::Current();
Optional<tvm::IRModule> mod = NullOpt;
@@ -61,6 +61,7 @@ FunctionFrame Function(const Bool& is_pure) {
}
n->block_builder = tvm::relax::BlockBuilder::Create(/*mod=*/mod);
n->is_pure = is_pure;
+ n->is_private = is_private;
return FunctionFrame(n);
}
diff --git a/src/script/printer/relax/function.cc
b/src/script/printer/relax/function.cc
index bd5d969563..bc5f12309f 100644
--- a/src/script/printer/relax/function.cc
+++ b/src/script/printer/relax/function.cc
@@ -22,13 +22,36 @@ namespace tvm {
namespace script {
namespace printer {
+bool AtTopLevelFunction(const IRDocsifier& d) {
+ // fewer than 2 frames: not in a function at all
+ if (d->frames.size() < 2) {
+ return false;
+ }
+ // if the first frame is a RelaxFrame, then this is not inside a module.
+ // 2 frames => we are at a function (more than 2 => nested function)
+ if (d->frames[0]->IsInstance<RelaxFrameNode>()) {
+ return d->frames.size() == 2;
+ }
+ // otherwise the first two frames pertain to an IR module,
+ // so 3 frames => we are at a top-level function (more than 3 => nested
function)
+ return d->frames.size() == 3;
+}
+
TVM_REGISTER_NODE_TYPE(RelaxFrameNode);
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<relax::Function>("", [](relax::Function n, ObjectPath n_p,
IRDocsifier d) -> Doc {
std::unordered_set<const tir::VarNode*> func_vars;
With<RelaxFrame> f(d);
- IdDoc func_name = d->Define(n, f(), FindFunctionName(d,
n).value_or("main"));
+
+ IdDoc func_name("");
+ // if we are binding a local definition, then calling d->Define
+ // will result in a repeated definition and an incorrect displayed name
+ if (Optional<String> name = GetBindingName(d)) {
+ func_name = std::move(IdDoc(name.value()));
+ } else {
+ func_name = std::move(d->Define(n, f(), FindFunctionName(d,
n).value_or("main")));
+ }
(*f)->AddDispatchToken(d, "relax");
(*f)->is_func = true;
(*f)->func_vars = &func_vars;
@@ -52,17 +75,49 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
(*f)->func_vars = nullptr;
// Step 4. Print attributes
if (n->attrs.defined() && !n->attrs->dict.empty()) {
- (*f)->stmts.push_back(
- ExprStmtDoc(Relax(d, "func_attr") //
- ->Call({d->AsDoc<ExprDoc>(n->attrs,
n_p->Attr("attrs"))})));
+ // If the function is a global function and has a global symbol,
+ // then don't print the global symbol (it will be implicit from not
being private).
+ // For a function without an IR module whose global symbol
+ // doesn't match the function name, we should still print the global
symbol attribute.
+ if (AtTopLevelFunction(d) &&
n->attrs->dict.count(tvm::attr::kGlobalSymbol) &&
+ Downcast<String>(n->attrs->dict.at(tvm::attr::kGlobalSymbol)) ==
func_name->name) {
+ Map<String, ObjectRef> new_attrs;
+ for (auto kv : n->attrs->dict) {
+ if (kv.first != tvm::attr::kGlobalSymbol) {
+ new_attrs.Set(kv.first, kv.second);
+ }
+ }
+ if (!new_attrs.empty()) {
+ (*f)->stmts.push_back(ExprStmtDoc(
+ Relax(d, "func_attr") //
+ ->Call({d->AsDoc<ExprDoc>(DictAttrs(new_attrs),
n_p->Attr("attrs"))})));
+ }
+ } else {
+ (*f)->stmts.push_back(
+ ExprStmtDoc(Relax(d, "func_attr") //
+ ->Call({d->AsDoc<ExprDoc>(n->attrs,
n_p->Attr("attrs"))})));
+ }
}
// Step 5. Prepare the decorator (include purity if it's impure)
ExprDoc decorator = Relax(d, "function");
+ Array<ExprDoc, void> pos_args = {};
+ Array<String, void> dec_keys;
+ Array<ExprDoc, void> dec_values;
if (!n->is_pure) {
- Array<ExprDoc> pos_args = {};
- decorator = std::move(decorator->Call(
- pos_args, {"pure"}, {LiteralDoc::Boolean(false,
Optional<ObjectPath>())}));
+ dec_keys.push_back("pure");
+ dec_values.push_back(LiteralDoc::Boolean(false,
Optional<ObjectPath>()));
}
+ // if the function is global or is not in a module and does not have a
global symbol,
+ // indicate that it's private
+ if (AtTopLevelFunction(d) &&
+ (!n->attrs.defined() ||
!n->attrs->dict.count(tvm::attr::kGlobalSymbol))) {
+ dec_keys.push_back("private");
+ dec_values.push_back(LiteralDoc::Boolean(true,
Optional<ObjectPath>()));
+ }
+ if (dec_keys.size()) {
+ decorator = std::move(decorator->Call(pos_args, dec_keys, dec_values));
+ }
+
// Step 6. Print body
Array<StmtDoc> body =
PrintSeqExpr(Downcast<relax::SeqExpr>(n->body), n_p->Attr("body"),
d, /*use_ret=*/true);
diff --git a/tests/python/relax/test_dataflow_pattern.py
b/tests/python/relax/test_dataflow_pattern.py
index cbb19c6743..ea83807bf8 100644
--- a/tests/python/relax/test_dataflow_pattern.py
+++ b/tests/python/relax/test_dataflow_pattern.py
@@ -875,7 +875,7 @@ def test_rewrite_simple():
return R.multiply(matchings[x], R.const(2, "float32"))
rewritten = rewrite_call(pattern, rewriter, main)
- tvm.ir.assert_structural_equal(rewritten, expected1)
+ tvm.ir.assert_structural_equal(rewritten,
expected1.with_attr("global_symbol", "main"))
add1 = is_op("relax.add")(x, x)
pattern = is_op("relax.add")(add1, add1)
@@ -884,7 +884,7 @@ def test_rewrite_simple():
return R.multiply(matchings[x], R.const(4, "float32"))
rewritten = rewrite_call(pattern, rewriter, main)
- tvm.ir.assert_structural_equal(rewritten, expected2)
+ tvm.ir.assert_structural_equal(rewritten,
expected2.with_attr("global_symbol", "main"))
# No rewriting, return the original call node as is
def rewriter(orig, _):
@@ -959,7 +959,7 @@ def test_rewrite_attention():
return R.nn.attention(matchings[Q], matchings[K], matchings[V])
rewritten = rewrite_call(pattern, rewriter, main)
- tvm.ir.assert_structural_equal(rewritten, expected)
+ tvm.ir.assert_structural_equal(rewritten,
expected.with_attr("global_symbol", "main"))
def test_attention_qkv():
@@ -1115,7 +1115,7 @@ def test_combine_matmul_twice():
inp_pat, Q_weight_pat, K_weight_pat, V_weight_pat, matmul1,
matmul2, matmul3
)
rewritten = rewrite_bindings(ctx, rewriter, qkv_x2)
- tvm.ir.assert_structural_equal(rewritten, expected)
+ tvm.ir.assert_structural_equal(rewritten,
expected.with_attr("global_symbol", "qkv_x2"))
def test_combine_matmul_emit_order():
@@ -1173,7 +1173,7 @@ def test_combine_matmul_emit_order():
inp_pat, Q_weight_pat, K_weight_pat, V_weight_pat, matmul1,
matmul2, matmul3
)
rewritten = rewrite_bindings(ctx, rewriter, main)
- tvm.ir.assert_structural_equal(rewritten, expected)
+ tvm.ir.assert_structural_equal(rewritten,
expected.with_attr("global_symbol", "main"))
# make sure it builds
mod = tvm.IRModule()
@@ -1272,7 +1272,7 @@ def test_combine_transposed_matmul_twice():
rewritten = rewrite_bindings(ctx, rewriter, main)
print(rewritten.script())
- tvm.ir.assert_structural_equal(rewritten, expected)
+ tvm.ir.assert_structural_equal(rewritten,
expected.with_attr("global_symbol", "main"))
# make sure it builds
mod = tvm.IRModule()
diff --git a/tests/python/relax/test_training_loss.py
b/tests/python/relax/test_training_loss.py
index 0a2418aad7..0d456ceb38 100644
--- a/tests/python/relax/test_training_loss.py
+++ b/tests/python/relax/test_training_loss.py
@@ -46,6 +46,7 @@ def test_l1_loss():
def expected(
predictions: R.Tensor((3, 5), "float32"), targets: R.Tensor((3, 5),
"float32")
) -> R.Tensor((), "float32"):
+ R.func_attr({"global_symbol": "l1_loss"})
with R.dataflow():
lv: R.Tensor((3, 5), "float32") = R.subtract(predictions, targets)
lv1: R.Tensor((3, 5), "float32") = R.abs(lv)
@@ -70,6 +71,7 @@ def test_l1_loss_append():
b: R.Tensor((2, 4), "float32"),
targets: R.Tensor((2, 4), "float32"),
) -> R.Tensor((), "float32"):
+ R.func_attr({"global_symbol": "forward_loss"})
with R.dataflow():
lv: R.Tensor((2, 4), "float32") = R.matmul(x, w, out_dtype="")
out: R.Tensor((2, 4), "float32") = R.add(lv, b)
@@ -93,6 +95,7 @@ def test_mse_loss():
def expected(
predictions: R.Tensor((3, 5), "float32"), targets: R.Tensor((3, 5),
"float32")
) -> R.Tensor((), "float32"):
+ R.func_attr({"global_symbol": "mse_loss"})
with R.dataflow():
lv: R.Tensor((3, 5), "float32") = R.subtract(predictions, targets)
lv1: R.Tensor((3, 5), "float32") = R.multiply(lv, lv)
@@ -117,6 +120,7 @@ def test_mse_loss_append():
b: R.Tensor((2, 4), "float32"),
targets: R.Tensor((2, 4), "float32"),
) -> R.Tensor((), "float32"):
+ R.func_attr({"global_symbol": "forward_loss"})
with R.dataflow():
lv: R.Tensor((2, 4), "float32") = R.matmul(x, w, out_dtype="")
out: R.Tensor((2, 4), "float32") = R.add(lv, b)
@@ -143,6 +147,7 @@ def test_cross_entropy_loss():
targets: R.Tensor((3,), "int64"),
weights: R.Tensor((5,), "float32"),
) -> R.Tensor((), "float32"):
+ R.func_attr({"global_symbol": "cross_entropy_loss"})
with R.dataflow():
lv: R.Tensor((3, 5), "float32") = R.nn.log_softmax(predictions,
axis=-1)
gv: R.Tensor((), "float32") = R.nn.nll_loss(
@@ -165,6 +170,7 @@ def test_cross_entropy_loss_without_weights():
def expected(
predictions: R.Tensor((3, 5), "float32"), targets: R.Tensor((3,),
"int64")
) -> R.Tensor((), "float32"):
+ R.func_attr({"global_symbol": "cross_entropy_loss"})
with R.dataflow():
lv: R.Tensor((3, 5), "float32") = R.nn.log_softmax(predictions,
axis=-1)
gv: R.Tensor((), "float32") = R.nn.nll_loss(
@@ -195,6 +201,7 @@ def test_cross_entropy_loss_append():
targets: R.Tensor((2,), "int64"),
weights: R.Tensor((4,), "float32"),
) -> R.Tensor((), "float32"):
+ R.func_attr({"global_symbol": "forward_loss"})
with R.dataflow():
lv: R.Tensor((2, 4), "float32") = R.matmul(x, w, out_dtype="")
out: R.Tensor((2, 4), "float32") = R.add(lv, b)
@@ -224,6 +231,7 @@ def test_categorical_cross_entropy_loss():
targets: R.Tensor((3, 5), "int64"),
weights: R.Tensor((5,), "float32"),
) -> R.Tensor((), "float32"):
+ R.func_attr({"global_symbol": "categorical_cross_entropy_loss"})
with R.dataflow():
lv: R.Tensor((3, 5), "float32") = R.nn.log_softmax(predictions,
axis=-1)
lv: R.Tensor((), "float32") = -lv * targets.astype("float32")
@@ -245,6 +253,7 @@ def test_categorical_cross_entropy_loss_without_weights():
def expected(
predictions: R.Tensor((3, 5), "float32"), targets: R.Tensor((3, 5),
"int64")
) -> R.Tensor((), "float32"):
+ R.func_attr({"global_symbol": "categorical_cross_entropy_loss"})
with R.dataflow():
lv: R.Tensor((3, 5), "float32") = R.nn.log_softmax(predictions,
axis=-1)
gv: R.Tensor((), "float32") = R.mean(-lv *
targets.astype("float32"))
@@ -270,6 +279,7 @@ def test_categorical_cross_entropy_loss_with_ignore_index():
targets: R.Tensor((3, 5), "int64"),
weights: R.Tensor((5,), "float32"),
) -> R.Tensor((), "float32"):
+ R.func_attr({"global_symbol": "categorical_cross_entropy_loss"})
with R.dataflow():
lv: R.Tensor((3, 5), "float32") = R.nn.log_softmax(predictions,
axis=-1)
targets = relax.op.reshape(
diff --git a/tests/python/relax/test_training_optimizer.py
b/tests/python/relax/test_training_optimizer.py
index b2246087c6..514422da8d 100644
--- a/tests/python/relax/test_training_optimizer.py
+++ b/tests/python/relax/test_training_optimizer.py
@@ -67,6 +67,7 @@ def test_sgd_simple():
R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3,), "float32")),
R.Tuple(R.Tensor((), "int64")),
):
+ R.func_attr({"global_symbol": "SGD"})
# block 0
with R.dataflow():
num_steps: R.Tensor((), "int64") = optim_states[0]
@@ -104,6 +105,7 @@ def test_sgd_complex():
R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3,), "float32")),
R.Tuple(R.Tensor((), "int64")),
):
+ R.func_attr({"global_symbol": "SGD"})
with R.dataflow():
num_steps: R.Tensor((), "int64") = optim_states[0]
num_steps_new: R.Tensor((), "int64") = R.add(num_steps, R.const(1,
"int64"))
@@ -146,6 +148,7 @@ def test_momentum_sgd_simple():
R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3,), "float32")),
R.Tuple(R.Tensor((), "int64"), R.Tensor((3, 3), "float32"),
R.Tensor((3,), "float32")),
):
+ R.func_attr({"global_symbol": "MomentumSGD"})
# block 0
with R.dataflow():
num_steps: R.Tensor((), "int64") = optim_states[0]
@@ -195,6 +198,7 @@ def test_momentum_sgd_complex():
R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3,), "float32")),
R.Tuple(R.Tensor((), "int64"), R.Tensor((3, 3), "float32"),
R.Tensor((3,), "float32")),
):
+ R.func_attr({"global_symbol": "MomentumSGD"})
# block 0
with R.dataflow():
num_steps: R.Tensor((), "int64") = optim_states[0]
@@ -250,6 +254,7 @@ def test_momentum_sgd_nesterov():
R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3,), "float32")),
R.Tuple(R.Tensor((), "int64"), R.Tensor((3, 3), "float32"),
R.Tensor((3,), "float32")),
):
+ R.func_attr({"global_symbol": "MomentumSGD"})
# block 0
with R.dataflow():
num_steps: R.Tensor((), "int64") = optim_states[0]
@@ -321,6 +326,7 @@ def test_adam_simple():
R.Tensor((3,), "float32"),
),
):
+ R.func_attr({"global_symbol": "Adam"})
# block 0
with R.dataflow():
num_steps: R.Tensor((), "int64") = optim_states[0]
@@ -418,6 +424,7 @@ def test_adam_complex():
R.Tensor((3,), "float32"),
),
):
+ R.func_attr({"global_symbol": "Adam"})
# block 0
with R.dataflow():
num_steps: R.Tensor((), "int64") = optim_states[0]
@@ -519,6 +526,7 @@ def test_adam_float64():
R.Tensor((3,), "float64"),
),
):
+ R.func_attr({"global_symbol": "Adam"})
# block 0
with R.dataflow():
num_steps: R.Tensor((), "int64") = optim_states[0]
diff --git a/tests/python/relax/test_training_optimizer_numeric.py
b/tests/python/relax/test_training_optimizer_numeric.py
index 3b300e8261..23db8987f1 100644
--- a/tests/python/relax/test_training_optimizer_numeric.py
+++ b/tests/python/relax/test_training_optimizer_numeric.py
@@ -69,7 +69,7 @@ def _test_optimizer(target, dev, np_func, opt_type, *args,
**kwargs):
x = relax.Var("x", R.Tensor((3, 3), "float32"))
y = relax.Var("y", R.Tensor((3,), "float32"))
opt = opt_type(*args, **kwargs).init([x, y])
- mod = IRModule.from_expr(opt.get_function())
+ mod = IRModule.from_expr(opt.get_function().with_attr("global_symbol",
"main"))
tvm_func = _legalize_and_build(mod, target, dev)["main"]
param_arr = [np.random.rand(3, 3).astype(np.float32),
np.random.rand(3).astype(np.float32)]
diff --git a/tests/python/relax/test_transform_attach_global_symbol.py
b/tests/python/relax/test_transform_attach_global_symbol.py
index 035e21609d..680df96947 100644
--- a/tests/python/relax/test_transform_attach_global_symbol.py
+++ b/tests/python/relax/test_transform_attach_global_symbol.py
@@ -43,7 +43,7 @@ def test_basic():
C[vi, vj] = T.float32(0)
C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
- @R.function
+ @R.function(private=True)
def main(
x: R.Tensor(("m", "n"), "float32"), w: R.Tensor(("n", "k"),
"float32")
) -> R.Tensor:
@@ -74,7 +74,6 @@ def test_basic():
def main(
x: R.Tensor(("m", "n"), "float32"), w: R.Tensor(("n", "k"),
"float32")
) -> R.Tensor:
- R.func_attr({"global_symbol": "main"})
m, n, k = T.int64(), T.int64(), T.int64()
gv0 = R.call_tir(Expected.tir_matmul, (x, w), R.Tensor((m, k),
dtype="float32"))
return gv0
@@ -94,7 +93,7 @@ def test_system_lib_prefix():
def tir_zeros(x: T.Buffer((2), "float32")) -> None:
x[0] = T.float32(0)
- @R.function
+ @R.function(private=True)
def main() -> R.Tensor:
gv0 = R.call_tir(Before.tir_zeros, (), R.Tensor((2,),
dtype="float32"))
return gv0
@@ -110,7 +109,6 @@ def test_system_lib_prefix():
@R.function
def main() -> R.Tensor:
- R.func_attr({"global_symbol": "main"})
gv0 = R.call_tir(Expected.tir_zeros, (), R.Tensor((2,),
dtype="float32"))
return gv0
diff --git a/tests/python/relax/test_transform_combine_parallel_matmul.py
b/tests/python/relax/test_transform_combine_parallel_matmul.py
index 719daaf449..97211f0dd0 100644
--- a/tests/python/relax/test_transform_combine_parallel_matmul.py
+++ b/tests/python/relax/test_transform_combine_parallel_matmul.py
@@ -97,7 +97,7 @@ def test_simple():
R.output(lv3)
return lv3
- tvm.ir.assert_structural_equal(mod["main"], expected1)
+ tvm.ir.assert_structural_equal(mod["main"],
expected1.with_attr("global_symbol", "main"))
# Test a batched LHS case, slicing is done on the axis 2
mod = get_parallel_matmul(3, lhs_shape=(2, 1024, 640))
@@ -121,7 +121,7 @@ def test_simple():
R.output(lv3)
return lv3
- tvm.ir.assert_structural_equal(mod["main"], expected2)
+ tvm.ir.assert_structural_equal(mod["main"],
expected2.with_attr("global_symbol", "main"))
def test_bias():
@@ -151,7 +151,7 @@ def test_bias():
R.output(lv6)
return lv6
- tvm.ir.assert_structural_equal(mod["main"], expected1)
+ tvm.ir.assert_structural_equal(mod["main"],
expected1.with_attr("global_symbol", "main"))
mod = get_parallel_matmul(3, with_bias=[True, False, True])
mod = CombineParallelMatmul()(mod)
@@ -178,7 +178,7 @@ def test_bias():
R.output(lv5)
return lv5
- tvm.ir.assert_structural_equal(mod["main"], expected2)
+ tvm.ir.assert_structural_equal(mod["main"],
expected2.with_attr("global_symbol", "main"))
def test_activation():
@@ -204,7 +204,7 @@ def test_activation():
R.output(lv6)
return lv6
- tvm.ir.assert_structural_equal(mod["main"], expected1)
+ tvm.ir.assert_structural_equal(mod["main"],
expected1.with_attr("global_symbol", "main"))
mod = get_parallel_matmul(3, activation=["gelu", "relu", "relu"])
mod = CombineParallelMatmul()(mod)
@@ -230,7 +230,7 @@ def test_activation():
R.output(lv6)
return lv6
- tvm.ir.assert_structural_equal(mod["main"], expected2)
+ tvm.ir.assert_structural_equal(mod["main"],
expected2.with_attr("global_symbol", "main"))
mod = get_parallel_matmul(3, activation=["relu", None, None])
mod = CombineParallelMatmul()(mod)
@@ -255,7 +255,7 @@ def test_activation():
R.output(lv4)
return lv4
- tvm.ir.assert_structural_equal(mod["main"], expected3)
+ tvm.ir.assert_structural_equal(mod["main"],
expected3.with_attr("global_symbol", "main"))
def test_bias_activation():
@@ -286,7 +286,7 @@ def test_bias_activation():
R.output(lv9)
return lv9
- tvm.ir.assert_structural_equal(mod["main"], expected1)
+ tvm.ir.assert_structural_equal(mod["main"],
expected1.with_attr("global_symbol", "main"))
mod = get_parallel_matmul(3, with_bias=[True, True, True],
activation=["relu", None, "relu"])
mod = CombineParallelMatmul()(mod)
@@ -316,7 +316,7 @@ def test_bias_activation():
R.output(lv8)
return lv8
- tvm.ir.assert_structural_equal(mod["main"], expected2)
+ tvm.ir.assert_structural_equal(mod["main"],
expected2.with_attr("global_symbol", "main"))
mod = get_parallel_matmul(3, with_bias=[True, False, True],
activation=["relu", None, "relu"])
mod = CombineParallelMatmul()(mod)
@@ -345,7 +345,7 @@ def test_bias_activation():
R.output(lv7)
return lv7
- tvm.ir.assert_structural_equal(mod["main"], expected3)
+ tvm.ir.assert_structural_equal(mod["main"],
expected3.with_attr("global_symbol", "main"))
def test_rhs_batched():
@@ -378,6 +378,7 @@ def test_rhs_batched():
w2: R.Tensor((2, 640, 640), dtype="float32"),
w3: R.Tensor((3, 4, 640, 640), dtype="float32"),
) -> R.Tensor:
+ R.func_attr({"global_symbol": "main"})
with R.dataflow():
lv = R.concat((w0, w2), axis=2)
lv1 = R.matmul(x, lv, out_dtype="float32")
@@ -458,6 +459,7 @@ def test_multiple_combine():
b0: R.Tensor((640,), dtype="float32"),
b1: R.Tensor((640,), dtype="float32"),
) -> R.Tensor:
+ R.func_attr({"global_symbol": "main"})
with R.dataflow():
lv = R.concat((w0, w1, w2), axis=1)
lv1 = R.matmul(x1, lv, out_dtype="float32")
@@ -515,6 +517,7 @@ def test_check():
w3: R.Tensor((640, 640), dtype="float32"),
w4: R.Tensor((640, 640), dtype="float32"),
) -> R.Tensor:
+ R.func_attr({"global_symbol": "main"})
with R.dataflow():
lv = R.concat((w0, w1, w2), axis=1)
lv1 = R.matmul(x1, lv, out_dtype="float32")
diff --git a/tests/python/relax/test_transform_dead_code_elimination.py
b/tests/python/relax/test_transform_dead_code_elimination.py
index 9c6e0e0567..12a3de6acb 100644
--- a/tests/python/relax/test_transform_dead_code_elimination.py
+++ b/tests/python/relax/test_transform_dead_code_elimination.py
@@ -168,7 +168,7 @@ def test_unused_relax_func():
vi, vj = T.axis.remap("SS", [i, j])
z[vi, vj] = x[vi, vj] + y[vi, vj]
- @R.function
+ @R.function(private=True)
def unused_func(x: R.Tensor((16, 16), "float32"), w: R.Tensor((16,
16), "float32")):
gv0 = R.add(x, w)
return gv0
@@ -202,7 +202,7 @@ def test_unused_relax_func_custom_entry_func():
vi, vj = T.axis.remap("SS", [i, j])
z[vi, vj] = x[vi, vj] + y[vi, vj]
- @R.function
+ @R.function(private=True)
def unused_func(x: R.Tensor((16, 16), "float32"), w: R.Tensor((16,
16), "float32")):
gv0 = R.add(x, w)
return gv0
@@ -239,7 +239,7 @@ def test_unused_relax_func_symbolic_shape():
vi, vj = T.axis.remap("SS", [i, j])
z[vi, vj] = x[vi, vj] + y[vi, vj]
- @R.function
+ @R.function(private=True)
def unused_func(x: R.Tensor(("m", "n"), "float32"), w: R.Tensor(("n",
"k"), "float32")):
gv0 = R.add(x, w)
return gv0
@@ -310,7 +310,7 @@ def test_multiple_unused_funcs():
vi, vj = T.axis.remap("SS", [i, j])
z[vi, vj] = x[vi, vj] + y[vi, vj]
- @R.function
+ @R.function(private=True)
def unused_func2(x: R.Tensor((16, 16), "float32"), w: R.Tensor((16,
16), "float32")):
gv0 = R.add(x, w)
return gv0
diff --git a/tests/python/relax/test_transform_fuse_ops.py
b/tests/python/relax/test_transform_fuse_ops.py
index 14c3dbe713..b51f651025 100644
--- a/tests/python/relax/test_transform_fuse_ops.py
+++ b/tests/python/relax/test_transform_fuse_ops.py
@@ -48,7 +48,7 @@ def test_fuse_simple():
x = relax.Var("x", R.Tensor([10, 20], "float32"))
p0 = relax.Var("p0", R.Tensor((), "float32"))
- with bb.function("fused_add_exp_squeeze", [x, p0], attrs={"Primitive":
1}):
+ with bb.function("fused_add_exp_squeeze", [x, p0], attrs={"Primitive":
1}, private=True):
with bb.dataflow():
lv0 = bb.emit_te(topi.add, x, p0)
lv1 = bb.emit_te(topi.exp, lv0)
@@ -100,7 +100,9 @@ def test_conv2d_fuse():
x = relax.Var("x", R.Tensor((1, 16, 64, 64), dtype))
w = relax.Var("w", R.Tensor((16, 16, 3, 3), dtype))
p0 = relax.Var("p0", R.Tensor((), dtype))
- with bb.function("fused_conv2d_add1_add2", [x, w, p0],
attrs={"Primitive": 1}):
+ with bb.function(
+ "fused_conv2d_add1_add2", [x, w, p0], attrs={"Primitive": 1},
private=True
+ ):
with bb.dataflow():
lv0 = bb.emit_te(
topi.nn.conv2d,
@@ -119,7 +121,7 @@ def test_conv2d_fuse():
x = relax.Var("x", R.Tensor((1, 16, 64, 64), dtype))
w = relax.Var("w", R.Tensor((16, 16, 1, 1), dtype))
y = relax.Var("y", R.Tensor((1, 16, 64, 64), dtype))
- with bb.function("fused_conv2d1_add2", [x, w, y], attrs={"Primitive":
1}):
+ with bb.function("fused_conv2d1_add2", [x, w, y], attrs={"Primitive":
1}, private=True):
with bb.dataflow():
lv0 = bb.emit_te(
topi.nn.conv2d,
@@ -196,7 +198,9 @@ def test_concatenate():
x = relax.Var("x", R.Tensor((1, 16, 64, 64), "float32"))
w = relax.Var("w", R.Tensor((1, 16, 32, 32), "float32"))
p0 = relax.Var("p0", R.Tensor((), "float32"))
- with bb.function("fused_upsampling_concatenate_add", [w, x, p0],
attrs={"Primitive": 1}):
+ with bb.function(
+ "fused_upsampling_concatenate_add", [w, x, p0],
attrs={"Primitive": 1}, private=True
+ ):
with bb.dataflow():
lv0 = bb.emit_te(topi.nn.upsampling, w, scale_h=2.0,
scale_w=2.0)
lv1 = bb.emit_te(topi.concatenate, (lv0, x), axis=1)
@@ -287,7 +291,10 @@ def test_fuse_tuple_get_elemwise():
# Grouped function
dense = relax.Var("dense", R.Tensor((1, 3 * dim), "float32"))
with bb.function(
- "fused_split_sigmoid_tanh_exp_multiply_add", [dense],
attrs={"Primitive": 1}
+ "fused_split_sigmoid_tanh_exp_multiply_add",
+ [dense],
+ attrs={"Primitive": 1},
+ private=True,
):
with bb.dataflow():
lv0 = bb.emit_te(topi.split, dense, indices_or_sections=3,
axis=1)
@@ -340,7 +347,7 @@ def test_tuple_get_root():
# Grouped function
x = relax.Var("x", R.Tensor((1, 3 * dim), "float32"))
- with bb.function("fused_split", [x], attrs={"Primitive": 1}):
+ with bb.function("fused_split", [x], attrs={"Primitive": 1},
private=True):
with bb.dataflow():
lv0 = bb.emit_te(topi.split, x, indices_or_sections=3, axis=1)
gv = bb.emit_output(relax.TupleGetItem(lv0, 0))
@@ -398,6 +405,7 @@ def test_tuple_intermediate():
"fused_squeeze_add_squeeze1_add_add_add_concatenate_squeeze2_add1",
[x, p0, p1, p2, p3, p4],
attrs={"Primitive": 1},
+ private=True,
):
with bb.dataflow():
lv0 = bb.emit_te(topi.squeeze, x)
@@ -500,6 +508,7 @@ def test_tuple_consecutive():
"fused_add_add_add_concatenate_add1_add_add_add_concatenate_add1_add_add_add_concatenate_add1_concatenate1",
[x, p0, p1, p2, p3, p4, p5, p6, p7, p8, p9, p10, p11],
attrs={"Primitive": 1},
+ private=True,
):
with bb.dataflow():
lv0 = bb.emit_te(topi.add, x, p0)
@@ -523,7 +532,7 @@ def test_tuple_consecutive():
# Grouped function 2
concat = relax.Var("concat", R.Tensor((1, 144, 64, 64), "float32"))
p0 = relax.Var("p0", R.Tensor((), "float32"))
- with bb.function("fused_pool2d_add2", [concat, p0],
attrs={"Primitive": 1}):
+ with bb.function("fused_pool2d_add2", [concat, p0],
attrs={"Primitive": 1}, private=True):
with bb.dataflow():
lv0 = bb.emit_te(
topi.nn.pool2d,
@@ -609,7 +618,7 @@ def test_inception_like():
# Grouped function 1
x = relax.Var("x", R.Tensor((1, 16, 64, 64), "float32"))
w = relax.Var("w", R.Tensor((16, 16, 3, 3), "float32"))
- with bb.function("fused_conv2d_relu", [x, w], attrs={"Primitive": 1}):
+ with bb.function("fused_conv2d_relu", [x, w], attrs={"Primitive": 1},
private=True):
with bb.dataflow():
lv0 = bb.emit_te(
topi.nn.conv2d,
@@ -626,7 +635,7 @@ def test_inception_like():
# Grouped function 2
x = relax.Var("x", R.Tensor((1, 32, 64, 64), "float32"))
w = relax.Var("w", R.Tensor((16, 32, 3, 3), "float32"))
- with bb.function("fused_conv2d1_relu", [x, w], attrs={"Primitive": 1}):
+ with bb.function("fused_conv2d1_relu", [x, w], attrs={"Primitive": 1},
private=True):
with bb.dataflow():
lv0 = bb.emit_te(
topi.nn.conv2d,
@@ -689,7 +698,10 @@ def test_fuse_parallel_injective():
x = relax.Var("x", R.Tensor((10, 20), "int32"))
p0 = relax.Var("p0", R.Tensor((), "int32"))
with bb.function(
- "fused_add_squeeze_transpose_transpose1_left_shift", [x, p0],
attrs={"Primitive": 1}
+ "fused_add_squeeze_transpose_transpose1_left_shift",
+ [x, p0],
+ attrs={"Primitive": 1},
+ private=True,
):
with bb.dataflow():
lv0 = bb.emit_te(topi.add, x, p0)
@@ -734,7 +746,7 @@ def test_softmax():
# Grouped function
x = relax.Var("x", R.Tensor((16, 16), "float32"))
- with bb.function("fused_softmax_cast", [x], attrs={"Primitive": 1}):
+ with bb.function("fused_softmax_cast", [x], attrs={"Primitive": 1},
private=True):
with bb.dataflow():
lv0 = bb.emit_te(topi.nn.softmax, x)
gv = bb.emit_output(bb.call_te(topi.cast, lv0,
dtype="float16"))
@@ -781,7 +793,7 @@ def test_multiple_relax_functions():
x = relax.Var("x", R.Tensor([10, 20], "float32"))
p0 = relax.Var("p0", R.Tensor((), "float32"))
- with bb.function("fused_add_exp_squeeze", [x, p0], attrs={"Primitive":
1}):
+ with bb.function("fused_add_exp_squeeze", [x, p0], attrs={"Primitive":
1}, private=True):
with bb.dataflow():
lv0 = bb.emit_te(topi.add, x, p0)
lv1 = bb.emit_te(topi.exp, lv0)
@@ -791,7 +803,7 @@ def test_multiple_relax_functions():
x = relax.Var("x", R.Tensor([20, 10], "float32"))
p0 = relax.Var("p0", R.Tensor((), "float32"))
- with bb.function("fused_add1_exp1_squeeze1", [x, p0],
attrs={"Primitive": 1}):
+ with bb.function("fused_add1_exp1_squeeze1", [x, p0],
attrs={"Primitive": 1}, private=True):
with bb.dataflow():
lv0 = bb.emit_te(topi.add, x, p0)
lv1 = bb.emit_te(topi.exp, lv0)
@@ -938,7 +950,7 @@ def test_layer_norm_silu():
T.writes(B[v_i0, v_i1, v_i2, v_i3])
B[v_i0, v_i1, v_i2, v_i3] = T.max(A[v_i0, v_i1, v_i2,
v_i3], T.float32(0))
- @R.function
+ @R.function(private=True)
def fused_layer_norm_relu(x: R.Tensor((1, 512, 64, 64),
dtype="float32"), mean: R.Tensor((64, 64), dtype="float32"), var: R.Tensor((64,
64), dtype="float32")) -> R.Tensor((1, 512, 64, 64), dtype="float32"):
R.func_attr({"Primitive": 1})
cls = Expected
@@ -1080,7 +1092,7 @@ def test_multiple_paths():
T.writes(T_transpose[v_ax0, v_ax1])
T_transpose[v_ax0, v_ax1] = rxplaceholder[v_ax1, v_ax0]
- @R.function
+ @R.function(private=True)
def fused_conv2d_add_add2(inp_0: R.Tensor((2, 320, 64, 64),
dtype="float32"), w1: R.Tensor((320, 320, 3, 3), dtype="float32"), lv28:
R.Tensor((1, 320, 1, 1), dtype="float32"), lv35: R.Tensor((2, 320, 1, 1),
dtype="float32")) -> R.Tensor((2, 320, 64, 64), dtype="float32"):
R.func_attr({"Primitive": 1})
cls = Expected
@@ -1091,7 +1103,7 @@ def test_multiple_paths():
R.output(gv)
return gv
- @R.function
+ @R.function(private=True)
def fused_matmul_add1(inp_1: R.Tensor((2, 1280), dtype="float32"),
lv31: R.Tensor((1280, 320), dtype="float32"), b2: R.Tensor((320,),
dtype="float32")) -> R.Tensor((2, 320), dtype="float32"):
cls = Expected
R.func_attr({"Primitive": 1})
@@ -1225,7 +1237,7 @@ def test_dead_group():
T.writes(T_transpose[v_ax0, v_ax1])
T_transpose[v_ax0, v_ax1] = rxplaceholder[v_ax1, v_ax0]
- @R.function
+ @R.function(private=True)
def fused_matmul1_add1(inp_1: R.Tensor((1, 128), dtype="float32"),
lv4: R.Tensor((128, 10), dtype="float32"), linear2_bias: R.Tensor((10,),
dtype="float32")) -> R.Tensor((1, 10), dtype="float32"):
R.func_attr({"Primitive": 1})
cls = Expected
@@ -1267,7 +1279,7 @@ def test_symbolic_shape_aware_fuse():
@I.ir_module
class Expected:
- @R.function
+ @R.function(private=True)
def fused_add_exp_squeeze(
x: R.Tensor(["n", "m"], "float32"), p0: R.Tensor([], "float32")
) -> R.Tensor(["n", "m"], dtype="float32"):
@@ -1305,7 +1317,7 @@ def test_symbolic_shape_aware_fuse_2():
@I.ir_module
class Expected:
- @R.function
+ @R.function(private=True)
def fused_full_trilu_broadcast_to(
s: R.Shape(["n"]),
) -> R.Tensor([1, 1, "n", "n"], "float32"):
@@ -1353,7 +1365,7 @@ def test_shape_expr_arg():
@I.ir_module
class Expected:
- @R.function
+ @R.function(private=True)
def fused_full_trilu_broadcast_to(
s: R.Shape(["n"]),
) -> R.Tensor([1, 1, "n", "n"], "float32"):
diff --git a/tests/python/relax/test_transform_fuse_ops_by_pattern.py
b/tests/python/relax/test_transform_fuse_ops_by_pattern.py
index 5fb2b3332c..592132516b 100644
--- a/tests/python/relax/test_transform_fuse_ops_by_pattern.py
+++ b/tests/python/relax/test_transform_fuse_ops_by_pattern.py
@@ -60,7 +60,7 @@ class Conv2dReLU_composite_annotated:
R.output(gv)
return gv
- @R.function
+ @R.function(private=True)
def fused_relax_nn_conv2d_relax_nn_relu_dnnl(
data1: R.Tensor((1, 64, 56, 56), dtype="float32"),
weight11: R.Tensor((64, 64, 3, 3), dtype="float32"),
@@ -124,7 +124,7 @@ class Conv2dReLUx2Partitioned:
R.output(gv)
return gv
- @R.function
+ @R.function(private=True)
def fused_relax_nn_conv2d_relax_nn_relu(
data1: R.Tensor((1, 64, 56, 56), dtype="float32"),
weight11: R.Tensor((64, 64, 3, 3), dtype="float32"),
@@ -138,7 +138,7 @@ class Conv2dReLUx2Partitioned:
R.output(gv1)
return gv1
- @R.function
+ @R.function(private=True)
def fused_relax_nn_conv2d_relax_nn_relu1(
conv1: R.Tensor((1, 64, 56, 56), dtype="float32"),
weight21: R.Tensor((64, 64, 3, 3), dtype="float32"),
@@ -174,7 +174,7 @@ class Conv2dReLUx2Partitioned_only_conv2d:
R.output(conv2d)
return conv2d
- @R.function
+ @R.function(private=True)
def fused_relax_nn_conv2d(
data1: R.Tensor((1, 64, 56, 56), dtype="float32"),
weight11: R.Tensor((64, 64, 3, 3), dtype="float32"),
@@ -187,7 +187,7 @@ class Conv2dReLUx2Partitioned_only_conv2d:
R.output(gv)
return gv
- @R.function
+ @R.function(private=True)
def fused_relax_nn_conv2d1(
conv11: R.Tensor((1, 64, 56, 56), dtype="float32"),
weight21: R.Tensor((64, 64, 3, 3), dtype="float32"),
@@ -236,7 +236,7 @@ class Conv2dConv2dReLUPartitioned:
R.output(gv)
return gv
- @R.function
+ @R.function(private=True)
def fused_relax_nn_conv2d_relax_nn_relu(
conv1: R.Tensor((1, 64, 56, 56), dtype="float32"),
weight21: R.Tensor((64, 64, 3, 3), dtype="float32"),
@@ -250,7 +250,7 @@ class Conv2dConv2dReLUPartitioned:
R.output(gv1)
return gv1
- @R.function
+ @R.function(private=True)
def fused_relax_nn_conv2d(
data1: R.Tensor((1, 64, 56, 56), dtype="float32"),
weight11: R.Tensor((64, 64, 3, 3), dtype="float32"),
@@ -303,7 +303,7 @@ class BranchTupleOutputPartitioned:
R.output(out)
return out
- @R.function
+ @R.function(private=True)
def fused_relax_nn_conv2d_relax_nn_relu(
data1: R.Tensor((1, 64, 56, 56), dtype="float32"),
weight1: R.Tensor((64, 64, 3, 3), dtype="float32"),
@@ -377,7 +377,7 @@ class Conv2dx2:
@tvm.script.ir_module
class Conv2dx2_partitioned:
- @R.function
+ @R.function(private=True)
def fused_relax_nn_conv2d_cutlass(
data: R.Tensor((16, 32, 32, 16), dtype="float16"),
weight1: R.Tensor((16, 3, 3, 16), dtype="float16"),
@@ -565,7 +565,7 @@ def test_ignore_call_tir():
T.writes(out[i, j, k, l])
out[i, j, k, l] = T.max(data[i, j, k, l], T.float32(0))
- @R.function
+ @R.function(private=True)
def fused_relax_nn_conv2d(
data: R.Tensor((1, 64, 56, 56), dtype="float32"),
weight1: R.Tensor((64, 64, 3, 3), dtype="float32"),
@@ -617,7 +617,7 @@ def test_unused():
@I.ir_module
class Conv2dReLU_partitioned:
- @R.function
+ @R.function(private=True)
def fused_relax_nn_conv2d(
data: R.Tensor((1, 64, 56, 56), dtype="float32"),
weight1: R.Tensor((64, 64, 3, 3), dtype="float32"),
@@ -685,7 +685,7 @@ def test_bind_constants():
@I.ir_module
class Conv2dWithConstantWeight_partitioned:
- @R.function
+ @R.function(private=True)
def fused_relax_nn_conv2d(
data: R.Tensor((1, 64, 56, 56), dtype="float32"),
param_0: R.Tensor((64, 64, 3, 3), dtype="float32"),
@@ -721,6 +721,7 @@ def test_bind_constants():
def test_split():
@R.function
def func(inp: R.Tensor((16, 32), "float32")):
+ R.func_attr({"global_symbol": "main"})
with R.dataflow():
tup = R.split(inp, [16], axis=1)
out = R.add(tup[0], tup[1])
@@ -729,7 +730,7 @@ def test_split():
@tvm.script.ir_module
class Expected1:
- @R.function
+ @R.function(private=True)
def fused_relax_split(
inp: R.Tensor((16, 32), dtype="float32")
) -> R.Tuple(R.Tensor((16, 16), dtype="float32"), R.Tensor((16, 16),
dtype="float32")):
@@ -756,7 +757,7 @@ def test_split():
@I.ir_module
class Expected2:
- @R.function
+ @R.function(private=True)
def fused_relax_split_relax_add(
inp: R.Tensor((16, 32), dtype="float32")
) -> R.Tensor((16, 16), dtype="float32"):
diff --git a/tests/python/relax/test_transform_fuse_tir.py
b/tests/python/relax/test_transform_fuse_tir.py
index 00dc714654..f59e3f2e9e 100644
--- a/tests/python/relax/test_transform_fuse_tir.py
+++ b/tests/python/relax/test_transform_fuse_tir.py
@@ -32,7 +32,7 @@ def test_simple():
x = relax.Var("x", R.Tensor([10, 20], "float32"))
p0 = relax.Var("p0", R.Tensor([], "float32"))
- with bb.function("fused_add_exp_squeeze", [x, p0], attrs={"Primitive":
True}):
+ with bb.function("fused_add_exp_squeeze", [x, p0], attrs={"Primitive":
True}, private=True):
with bb.dataflow():
lv0 = bb.emit_te(topi.add, x, p0)
lv1 = bb.emit_te(topi.exp, lv0)
@@ -565,7 +565,7 @@ def test_multiple_relax_functions():
x = relax.Var("x", R.Tensor([10, 20], "float32"))
p0 = relax.Var("p0", R.Tensor((), "float32"))
- with bb.function("fused_add_exp_squeeze", [x, p0], attrs={"Primitive":
1}):
+ with bb.function("fused_add_exp_squeeze", [x, p0], attrs={"Primitive":
1}, private=True):
with bb.dataflow():
lv0 = bb.emit_te(topi.add, x, p0)
lv1 = bb.emit_te(topi.exp, lv0)
@@ -575,7 +575,7 @@ def test_multiple_relax_functions():
x = relax.Var("x", R.Tensor([20, 10], "float32"))
p0 = relax.Var("p0", R.Tensor((), "float32"))
- with bb.function("fused_add1_exp1_squeeze1", [x, p0],
attrs={"Primitive": 1}):
+ with bb.function("fused_add1_exp1_squeeze1", [x, p0],
attrs={"Primitive": 1}, private=True):
with bb.dataflow():
lv0 = bb.emit_te(topi.add, x, p0)
lv1 = bb.emit_te(topi.exp, lv0)
diff --git a/tests/python/relax/test_transform_lambda_lift.py
b/tests/python/relax/test_transform_lambda_lift.py
index ddc274fee2..d672484171 100644
--- a/tests/python/relax/test_transform_lambda_lift.py
+++ b/tests/python/relax/test_transform_lambda_lift.py
@@ -42,7 +42,7 @@ def test_basic():
# the target IRModule
@tvm.script.ir_module
class Expected:
- @R.function
+ @R.function(private=True)
def lifted_func_0(
x2: R.Tensor((10, 5), "float32"), y2: R.Tensor((10, 5), "float32")
) -> R.Tensor((10, 5), "float32"):
@@ -97,12 +97,12 @@ def test_closure():
)
return res
- @R.function
+ @R.function(private=True)
def lifted_func_1(x1: R.Tensor((2, 3), "float32"), c1: R.Tensor((2,
3), "float32")):
r_1: R.Tensor((2, 3), "float32") = R.add(x1, c1)
return r_1
- @R.function
+ @R.function(private=True)
def lifted_func_0(y: R.Tensor((2, 3), "float32")) -> R.Object:
inner_func = R.make_closure(Expected.lifted_func_1, (y,))
return inner_func
@@ -140,7 +140,7 @@ def test_recursive():
# the expected IRModule
@tvm.script.ir_module
class Expected:
- @R.function
+ @R.function(private=True)
def lifted_func_0(
i: R.Tensor((), "int32"), s: R.Tensor((2, 3), "float32"), x:
R.Tensor((2, 3), "float32")
) -> R.Tensor((2, 3), "float32"):
@@ -224,14 +224,14 @@ def test_multi_func():
gv11: R.Tensor((10, 5), "float32") = inner(x11, y11)
return gv11
- @R.function
+ @R.function(private=True)
def lifted_func_0(
x2: R.Tensor((10, 5), "float32"), y2: R.Tensor((10, 5), "float32")
) -> R.Tensor((10, 5), "float32"):
s: R.Tensor((10, 5), "float32") = R.add(x2, y2)
return s
- @R.function
+ @R.function(private=True)
def lifted_func_1(
x21: R.Tensor((10, 5), "float32"), y21: R.Tensor((10, 5),
"float32")
) -> R.Tensor((10, 5), "float32"):
@@ -308,7 +308,7 @@ def test_no_local_func():
def test_impure_function():
@tvm.script.ir_module
class Expected:
- @R.function(pure=False)
+ @R.function(pure=False, private=True)
def lifted_func_0() -> R.Tuple:
y = R.print(format="Wow!")
return y
diff --git a/tests/python/relax/test_transform_merge_composite_functions.py
b/tests/python/relax/test_transform_merge_composite_functions.py
index 61df388c78..d552266131 100644
--- a/tests/python/relax/test_transform_merge_composite_functions.py
+++ b/tests/python/relax/test_transform_merge_composite_functions.py
@@ -41,7 +41,7 @@ class Conv2dReLUx2:
R.output(gv)
return gv
- @R.function
+ @R.function(private=True)
def fused_relax_nn_conv2d_relax_nn_relu(
data1: R.Tensor((1, 64, 56, 56), dtype="float32"),
weight11: R.Tensor((64, 64, 3, 3), dtype="float32"),
@@ -57,7 +57,7 @@ class Conv2dReLUx2:
R.output(gv1)
return gv1
- @R.function
+ @R.function(private=True)
def fused_relax_nn_conv2d_relax_nn_relu1(
conv1: R.Tensor((1, 64, 56, 56), dtype="float32"),
weight21: R.Tensor((64, 64, 3, 3), dtype="float32"),
@@ -164,7 +164,7 @@ class Diamond:
R.output(gv2)
return gv2
- @R.function
+ @R.function(private=True)
def fused_relax_nn_gelu(
lv: R.Tensor((1, 64, 54, 54), dtype="float32")
) -> R.Tensor((1, 64, 54, 54), dtype="float32"):
@@ -174,7 +174,7 @@ class Diamond:
R.output(gv)
return gv
- @R.function
+ @R.function(private=True)
def fused_relax_nn_relu(
lv1: R.Tensor((1, 64, 54, 54), dtype="float32")
) -> R.Tensor((1, 64, 54, 54), dtype="float32"):
@@ -184,7 +184,7 @@ class Diamond:
R.output(gv1)
return gv1
- @R.function
+ @R.function(private=True)
def fused_relax_add(
lv5: R.Tensor((1, 64, 54, 54), dtype="float32"),
gelu1: R.Tensor((1, 64, 54, 54), dtype="float32"),
@@ -195,7 +195,7 @@ class Diamond:
R.output(gv3)
return gv3
- @R.function
+ @R.function(private=True)
def fused_relax_nn_conv2d(
data1: R.Tensor((1, 64, 56, 56), dtype="float32"),
weight1: R.Tensor((64, 64, 3, 3), dtype="float32"),
@@ -332,7 +332,7 @@ class Diamond_cyclic_dep:
R.output(gv2)
return gv2
- @R.function
+ @R.function(private=True)
def fused_relax_nn_gelu(
lv: R.Tensor((1, 64, 54, 54), dtype="float32")
) -> R.Tensor((1, 64, 54, 54), dtype="float32"):
@@ -342,7 +342,7 @@ class Diamond_cyclic_dep:
R.output(gv)
return gv
- @R.function
+ @R.function(private=True)
def fused_relax_nn_relu(
lv1: R.Tensor((1, 64, 54, 54), dtype="float32")
) -> R.Tensor((1, 64, 54, 54), dtype="float32"):
@@ -352,7 +352,7 @@ class Diamond_cyclic_dep:
R.output(gv1)
return gv1
- @R.function
+ @R.function(private=True)
def fused_relax_add(
lv5: R.Tensor((1, 64, 54, 54), dtype="float32"),
gelu1: R.Tensor((1, 64, 54, 54), dtype="float32"),
@@ -363,7 +363,7 @@ class Diamond_cyclic_dep:
R.output(gv3)
return gv3
- @R.function
+ @R.function(private=True)
def fused_relax_nn_conv2d(
data1: R.Tensor((1, 64, 56, 56), dtype="float32"),
weight1: R.Tensor((64, 64, 3, 3), dtype="float32"),
@@ -509,7 +509,7 @@ class MultipleProducers:
R.output(gv1)
return gv1
- @R.function
+ @R.function(private=True)
def fused_relax_nn_relu(
x11: R.Tensor((10,), dtype="float32")
) -> R.Tensor((10,), dtype="float32"):
@@ -519,7 +519,7 @@ class MultipleProducers:
R.output(gv2)
return gv2
- @R.function
+ @R.function(private=True)
def fused_relax_nn_gelu(
x21: R.Tensor((10,), dtype="float32")
) -> R.Tensor((10,), dtype="float32"):
@@ -529,7 +529,7 @@ class MultipleProducers:
R.output(gv3)
return gv3
- @R.function
+ @R.function(private=True)
def fused_relax_add(
lv: R.Tensor((10,), dtype="float32"), gelu1: R.Tensor((10,),
dtype="float32")
) -> R.Tensor((10,), dtype="float32"):
@@ -627,7 +627,7 @@ class MultipleProducersCyclic:
R.output(gv1)
return gv1
- @R.function
+ @R.function(private=True)
def fused_relax_nn_relu(
x11: R.Tensor((10,), dtype="float32")
) -> R.Tensor((10,), dtype="float32"):
@@ -637,7 +637,7 @@ class MultipleProducersCyclic:
R.output(gv2)
return gv2
- @R.function
+ @R.function(private=True)
def fused_relax_nn_gelu(
x21: R.Tensor((10,), dtype="float32")
) -> R.Tensor((10,), dtype="float32"):
@@ -647,7 +647,7 @@ class MultipleProducersCyclic:
R.output(gv3)
return gv3
- @R.function
+ @R.function(private=True)
def fused_relax_add(
lv: R.Tensor((10,), dtype="float32"), gelu1: R.Tensor((10,),
dtype="float32")
) -> R.Tensor((10,), dtype="float32"):
@@ -759,7 +759,7 @@ class MergeCompilerRegionsExample:
R.output(gv1)
return gv1
- @R.function
+ @R.function(private=True)
def fused_relax_nn_relu(
add2: R.Tensor((10,), dtype="float32")
) -> R.Tensor((10,), dtype="float32"):
@@ -769,7 +769,7 @@ class MergeCompilerRegionsExample:
R.output(gv)
return gv
- @R.function
+ @R.function(private=True)
def fused_relax_add(
x11: R.Tensor((10,), dtype="float32"), x21: R.Tensor((10,),
dtype="float32")
) -> R.Tensor((10,), dtype="float32"):
@@ -779,7 +779,7 @@ class MergeCompilerRegionsExample:
R.output(gv2)
return gv2
- @R.function
+ @R.function(private=True)
def fused_relax_nn_gelu(
x31: R.Tensor((10,), dtype="float32")
) -> R.Tensor((10,), dtype="float32"):
@@ -924,7 +924,7 @@ class ModuleWithNonComposite:
R.output(conv)
return conv
- @R.function
+ @R.function(private=True)
def fused_relax_nn_conv2d(
data1: R.Tensor((1, 64, 56, 56), dtype="float32"),
weight1: R.Tensor((64, 64, 3, 3), dtype="float32"),
@@ -1071,7 +1071,7 @@ def test_reshape():
# Verify that the non-CallNode input (shape in reshape) can be handled
properly.
@I.ir_module
class Module:
- @R.function
+ @R.function(private=True)
def fused_relax_matmul(
lv: R.Tensor((1, 784), dtype="float32"), lv1: R.Tensor((784, 512),
dtype="float32")
) -> R.Tensor((1, 512), dtype="float32"):
@@ -1081,7 +1081,7 @@ def test_reshape():
R.output(gv)
return gv
- @R.function
+ @R.function(private=True)
def fused_relax_reshape(
inp_0: R.Tensor((1, 1, 28, 28), dtype="float32"), param_0:
R.Shape([1, 784])
) -> R.Tensor((1, 784), dtype="float32"):
diff --git a/tests/python/relax/test_transform_normalize.py
b/tests/python/relax/test_transform_normalize.py
index 874e83c7f9..a6feb0b8ab 100644
--- a/tests/python/relax/test_transform_normalize.py
+++ b/tests/python/relax/test_transform_normalize.py
@@ -44,7 +44,7 @@ def test_normalize_function():
after_mod = relax.transform.Normalize()(before_mod)
- @R.function
+ @R.function(private=True)
def expected(x: R.Tensor(("m", "n"), "float16")) ->
R.Tensor(dtype="float16", ndim=2):
gv = R.add(x, x)
gv1 = R.add(x, x)
@@ -86,7 +86,7 @@ def test_normalize_if():
before_mod = tvm.IRModule.from_expr(f)
after_mod = relax.transform.Normalize()(before_mod)
- @R.function
+ @R.function(private=True)
def expected(
cond: R.Tensor((), "bool"), x: R.Tensor((1,), "float32")
) -> R.Tensor(dtype="float32", ndim=1):
@@ -151,7 +151,7 @@ def test_normalize_seq_body():
before_mod = tvm.IRModule.from_expr(f)
after_mod = relax.transform.Normalize()(before_mod)
- @R.function
+ @R.function(private=True)
def expected(
x: R.Tensor((), dtype="int32"), y: R.Tensor((), dtype="int32")
) -> R.Tensor(ndim=0, dtype="int32"):
@@ -175,7 +175,7 @@ def test_normalize_func_body():
before_mod = tvm.IRModule.from_expr(f)
after_mod = relax.transform.Normalize()(before_mod)
- @R.function
+ @R.function(private=True)
def expected(
x: R.Tensor((), dtype="int32"), y: R.Tensor((), dtype="int32")
) -> R.Tensor(ndim=0, dtype="int32"):
@@ -207,7 +207,7 @@ def test_normalize_if_branches():
before_mod = tvm.IRModule.from_expr(f)
after_mod = relax.transform.Normalize()(before_mod)
- @R.function
+ @R.function(private=True)
def expected(
cond: R.Tensor((), dtype="bool"),
x: R.Tensor((), dtype="int32"),
@@ -257,7 +257,7 @@ def test_normalize_if_condition():
before_mod = tvm.IRModule.from_expr(f)
after_mod = relax.transform.Normalize()(before_mod)
- @R.function
+ @R.function(private=True)
def expected(
cond: R.Tensor((), "bool"), x: R.Tensor((1,), "float32")
) -> R.Tensor(dtype="float32", ndim=1):
@@ -341,7 +341,7 @@ def test_normalize_combine_nearby_blocks():
after_mod = relax.transform.Normalize()(tvm.IRModule.from_expr(f))
- @R.function
+ @R.function(private=True)
def expected(x: R.Tensor((), "int32")):
with R.dataflow():
v0 = x
@@ -383,7 +383,7 @@ def test_normalize_nested_seq():
)
after_mod = relax.transform.Normalize()(tvm.IRModule.from_expr(f))
- @R.function
+ @R.function(private=True)
def expected():
x = relax.const(1)
z = relax.const(2)
@@ -434,7 +434,7 @@ def test_normalize_nested_seq_dataflow():
)
after_mod = relax.transform.Normalize()(tvm.IRModule.from_expr(f))
- @R.function
+ @R.function(private=True)
def expected():
x = relax.const(1)
q = relax.const(2)
@@ -507,7 +507,7 @@ def test_normalize_deeply_nested_seq():
)
after_mod = relax.transform.Normalize()(tvm.IRModule.from_expr(f))
- @R.function
+ @R.function(private=True)
def expected():
x = relax.const(1)
u = relax.const(2)
diff --git a/tests/python/relax/test_transform_rewrite_cuda_graph.py
b/tests/python/relax/test_transform_rewrite_cuda_graph.py
index 747af6f296..4f25feb032 100644
--- a/tests/python/relax/test_transform_rewrite_cuda_graph.py
+++ b/tests/python/relax/test_transform_rewrite_cuda_graph.py
@@ -82,7 +82,7 @@ def test_rewrite_cuda_graph():
T.writes(compute[i0, i1])
compute[i0, i1] = T.exp(rxplaceholder[i0, i1],
dtype="float32")
- @R.function
+ @R.function(private=True)
def cuda_graph_alloc() -> R.Tuple(R.Object, R.Object, R.Object):
R.func_attr({"relax.force_pure": True})
storage: R.Object = R.memory.alloc_storage(R.shape([32]),
R.prim_value(0), R.str("global"), R.dtype("float32"))
@@ -91,7 +91,7 @@ def test_rewrite_cuda_graph():
gv: R.Tuple(R.Object, R.Object, R.Object) = (storage, storage1,
storage2)
return gv
- @R.function
+ @R.function(private=True)
def cuda_graph_capture(alloc: R.Tensor((2, 4), dtype="float32"),
alloc1: R.Tensor((2, 4), dtype="float32"), storage: R.Object, storage2:
R.Object) -> R.Tuple(R.Tensor((2, 4), dtype="float32")):
R.func_attr({"relax.force_pure": True})
cls = Expected
@@ -193,7 +193,7 @@ def test_tuple():
T.writes(compute[i0, i1])
compute[i0, i1] = T.exp(rxplaceholder[i0, i1])
- @R.function
+ @R.function(private=True)
def cuda_graph_alloc() -> R.Tuple(R.Object, R.Object):
R.func_attr({"relax.force_pure": True})
storage: R.Object = R.memory.alloc_storage(R.shape([32]),
R.prim_value(0), R.str("global"), R.dtype("float32"))
@@ -201,7 +201,7 @@ def test_tuple():
gv: R.Tuple(R.Object, R.Object) = (storage, storage1)
return gv
- @R.function
+ @R.function(private=True)
def cuda_graph_capture(alloc: R.Tensor((2, 4), dtype="float32"),
alloc1: R.Tensor((2, 4), dtype="float32"), storage: R.Object) ->
R.Tuple(R.Tensor((2, 4), dtype="float32")):
R.func_attr({"relax.force_pure": True})
cls = Expected
diff --git a/tests/python/relax/test_tvmscript_parser.py
b/tests/python/relax/test_tvmscript_parser.py
index 7a8bcdee26..9305cdbcb1 100644
--- a/tests/python/relax/test_tvmscript_parser.py
+++ b/tests/python/relax/test_tvmscript_parser.py
@@ -203,7 +203,7 @@ def test_simple_module():
x = relax.Var("x", R.Tensor((128, 128), "float32"))
bb = relax.BlockBuilder()
- with bb.function("foo", (x,)):
+ with bb.function("foo", (x,), {"global_symbol": "foo"}):
out = bb.emit_te(lambda x: x + 1, x, primfunc_name_hint="tir_func")
bb.emit_func_output(out)
@@ -232,7 +232,7 @@ def test_emit_te_primfunc_attrs():
x = relax.Var("x", R.Tensor((128, 128), "float32"))
bb = relax.BlockBuilder()
- with bb.function("foo", (x,)):
+ with bb.function("foo", (x,), {"global_symbol": "foo"}):
out = bb.emit_te(
lambda x: x + 1,
x,
@@ -254,7 +254,7 @@ def test_emit_te():
bb = relax.BlockBuilder()
x = relax.Var("x", relax.TensorStructInfo([10, 20], "float32"))
- with bb.function("main", [x]):
+ with bb.function("main", [x], {"global_symbol": "main"}):
lv1 = bb.emit_te(topi.add, x, x)
out = bb.emit_te(topi.multiply, lv1, lv1)
bb.emit_func_output(out)
@@ -294,7 +294,7 @@ def test_module_with_attr_and_global_info():
x = relax.Var("x", R.Tensor((128, 128), "float32"))
bb = relax.BlockBuilder()
- with bb.function("foo", (x,)):
+ with bb.function("foo", (x,), {"global_symbol": "foo"}):
out = bb.emit_te(lambda x: x + 1, x, primfunc_name_hint="tir_func")
bb.emit_func_output(out)
mod = bb.get()
@@ -834,7 +834,7 @@ def test_call_dps_packed_empty_shape():
def test_call_tir_empty_tuple_arg():
bb = relax.BlockBuilder()
dummy_param = relax.Var("dummy_param", R.Tensor(()))
- with bb.function("foo", [dummy_param]):
+ with bb.function("foo", [dummy_param], {"global_symbol": "foo"}):
output = bb.emit_te(topi.full, shape=(16, 32), dtype="float32", fill_value=1.0)
bb.emit_func_output(output)
@@ -1493,5 +1493,22 @@ def test_call_pure_packed():
_check(foo, bb.get()["foo"])
+def test_private_function():
+ @I.ir_module
+ class Addition:
+ @R.function(private=True)
+ def main(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"):
+ y = R.add(x, x)
+ return y
+
+ x = relax.Var("x", R.Tensor((), "int32"))
+ bb = relax.BlockBuilder()
+ with bb.function("main", (x), private=True):
+ y = bb.emit(R.add(x, x))
+ bb.emit_func_output(y)
+
+ _check(Addition, bb.get())
+
+
if __name__ == "__main__":
tvm.testing.main()
diff --git a/tests/python/relax/test_tvmscript_parser_op_datatype.py b/tests/python/relax/test_tvmscript_parser_op_datatype.py
index ec71e868d4..85c5faa866 100644
--- a/tests/python/relax/test_tvmscript_parser_op_datatype.py
+++ b/tests/python/relax/test_tvmscript_parser_op_datatype.py
@@ -47,7 +47,7 @@ def test_astype():
gv = bb.emit(relax.op.astype(x, "float16"))
bb.emit_func_output(gv)
- _check(expected, bb.get()["main"])
+ _check(expected.with_attr("global_symbol", "main"), bb.get()["main"])
if __name__ == "__main__":
diff --git a/tests/python/relax/test_tvmscript_printer_relax.py b/tests/python/relax/test_tvmscript_printer_relax.py
index 7525c63be4..c376943173 100644
--- a/tests/python/relax/test_tvmscript_printer_relax.py
+++ b/tests/python/relax/test_tvmscript_printer_relax.py
@@ -27,12 +27,15 @@ def _assert_print(obj, expected):
if not isinstance(obj, str):
obj = obj.script(verbose_expr=True)
obj = obj.strip()
- assert obj == expected.strip(), "\n" + obj
+ # compare line by line in case there is trailing whitespace in the _middle_
+ for obj_line, expected_line in zip(obj.splitlines(), expected.strip().splitlines()):
+ assert obj_line.strip() == expected_line.strip(), "\n" + obj
def test_function():
@R.function
def func(a: R.Tensor((10, 10))) -> R.Tensor((10, 10)): # type: ignore
+ R.func_attr({"some_attr": 1})
return a
_assert_print(
@@ -41,19 +44,39 @@ def test_function():
# from tvm.script import relax as R
@R.function
+def func(a: R.Tensor((10, 10))) -> R.Tensor((10, 10)):
+ R.func_attr({"some_attr": 1})
+ return a""",
+ )
+
+
+def test_lone_private_function():
+ @R.function(private=True)
+ def func(a: R.Tensor((10, 10))) -> R.Tensor((10, 10)): # type: ignore
+ R.func_attr({"some_attr": 1})
+ return a
+
+ # name prints as main because without a global symbol, the printer cannot assume a name
+ _assert_print(
+ func,
+ """
+# from tvm.script import relax as R
+
+@R.function(private=True)
def main(a: R.Tensor((10, 10))) -> R.Tensor((10, 10)):
+ R.func_attr({"some_attr": 1})
return a""",
)
def test_extern_func():
@R.function
- def relax_func(a: R.Tensor((10, 10))) -> R.Tensor((10, 10)): # type: ignore
+ def func(a: R.Tensor((10, 10))) -> R.Tensor((10, 10)): # type: ignore
return a
obj = IRModule(
{
- "func": relax_func,
+ "func": func,
"my_ext": relax.ExternFunc("my_ext"),
}
)
@@ -73,6 +96,40 @@ class Module:
)
+def test_nested_function():
+ @I.ir_module
+ class NestedFunction:
+ @R.function
+ def main(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"):
+ @R.function
+ def nested(y: R.Tensor((), "int32")) -> R.Tensor((), "int32"):
+ return y
+
+ z = nested(x)
+ return z
+
+ _assert_print(
+ NestedFunction,
+ """
+# from tvm.script import ir as I
+# from tvm.script import relax as R
+
+@I.ir_module
+class Module:
+ @R.function
+ def main(x: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"):
+ # from tvm.script import relax as R
+
+ @R.function
+ def nested(y: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"):
+ return y
+
+ z: R.Tensor((), dtype="int32") = nested(x)
+ return z
+""",
+ )
+
+
def test_object_struct_info():
obj = relax.ObjectStructInfo()
_assert_print(
@@ -576,5 +633,101 @@ class Module:
)
+def test_private_function():
+ @I.ir_module
+ class AddMod:
+ @R.function(private=True)
+ def main(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"):
+ y: R.Tensor((), dtype="int32") = R.add(x, x)
+ return y
+
+ _assert_print(
+ AddMod,
+ """
+# from tvm.script import ir as I
+# from tvm.script import relax as R
+
+@I.ir_module
+class Module:
+ @R.function(private=True)
+ def main(x: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"):
+ y: R.Tensor((), dtype="int32") = R.add(x, x)
+ return y
+""",
+ )
+
+
+def test_directly_construct_private_funcs():
+ # public
+ @R.function
+ def foo(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"):
+ y: R.Tensor((), dtype="int32") = R.add(x, x)
+ return y
+
+ # private
+ @R.function(private=True)
+ def bar(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"):
+ y: R.Tensor((), dtype="int32") = R.multiply(x, x)
+ return y
+
+ # public but there's another attribute
+ @R.function
+ def baz(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"):
+ R.func_attr({"relax.force_pure": True})
+ y: R.Tuple = R.print(format="Hi there!")
+ z: R.Tensor((), dtype="int32") = R.add(x, x)
+ return z
+
+ # private with an attribute
+ @R.function(private=True)
+ def quux(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"):
+ R.func_attr({"relax.force_pure": True})
+ y: R.Tuple = R.print(format="Lol")
+ z: R.Tensor((), dtype="int32") = R.multiply(x, x)
+ return z
+
+ obj = IRModule(
+ {
+ "foo": foo,
+ "bar": bar,
+ "baz": baz,
+ "quux": quux,
+ }
+ )
+ _assert_print(
+ obj,
+ """
+# from tvm.script import ir as I
+# from tvm.script import relax as R
+
+@I.ir_module
+class Module:
+ @R.function(private=True)
+ def bar(x: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"):
+ y: R.Tensor((), dtype="int32") = R.multiply(x, x)
+ return y
+
+ @R.function
+ def baz(x: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"):
+ R.func_attr({"relax.force_pure": 1})
+ y: R.Tuple = R.print(format=R.str("Hi there!"))
+ z: R.Tensor((), dtype="int32") = R.add(x, x)
+ return z
+
+ @R.function
+ def foo(x: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"):
+ y: R.Tensor((), dtype="int32") = R.add(x, x)
+ return y
+
+ @R.function(private=True)
+ def quux(x: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"):
+ R.func_attr({"relax.force_pure": 1})
+ y: R.Tuple = R.print(format=R.str("Lol"))
+ z: R.Tensor((), dtype="int32") = R.multiply(x, x)
+ return z
+""",
+ )
+
+
if __name__ == "__main__":
tvm.testing.main()
diff --git a/tests/python/relax/test_utils.py b/tests/python/relax/test_utils.py
index c55876a3ba..f0c4ae0bd2 100644
--- a/tests/python/relax/test_utils.py
+++ b/tests/python/relax/test_utils.py
@@ -71,7 +71,9 @@ def test_copy_with_new_vars_on_ir_module():
gv = R.add(x, y)
return gv
- Actual["func_copied"] = relax.utils.copy_with_new_vars(Actual["func"])
+ Actual["func_copied"] = relax.utils.copy_with_new_vars(Actual["func"]).with_attr(
+ "global_symbol", "func_copied"
+ )
# Assertion will fail if the f_copied contains the same VarNode that's
used in
# the original function, due to var mapping during structural equal.
@@ -113,7 +115,9 @@ def test_copy_with_new_vars_on_ir_module_nested_function():
gv = R.add(x, y)
return gv
- Actual["func_copied"] = relax.utils.copy_with_new_vars(Actual["func"])
+ Actual["func_copied"] = relax.utils.copy_with_new_vars(Actual["func"]).with_attr(
+ "global_symbol", "func_copied"
+ )
assert_structural_equal(Actual, Expected)