This is an automated email from the ASF dual-hosted git repository.
lunderberg pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new b91d4e55b3 [TVMScript] Produce empty DictAttrs when R.func_attrs is absent (#16844)
b91d4e55b3 is described below
commit b91d4e55b3f66a10508b4b492378173be75ba1a5
Author: Eric Lunderberg <[email protected]>
AuthorDate: Fri Apr 5 07:21:59 2024 -0500
[TVMScript] Produce empty DictAttrs when R.func_attrs is absent (#16844)
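A follow-up to https://github.com/apache/tvm/pull/16745. For Relax
functions produced in TVMScript, when `R.func_attrs` was not present,
the function's `attrs` defaulted to `None` instead of an empty
`DictAttrs`. The TVMScript builder now always passes a `DictAttrs`,
and the `relax::Function` and `tir::PrimFunc` constructors replace an
undefined `attrs` argument with an empty `DictAttrs`.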
---
src/relax/ir/expr.cc | 4 ++++
src/script/ir_builder/relax/frame.cc | 3 +--
src/tir/ir/function.cc | 4 ++++
tests/python/relax/test_tvmscript_parser.py | 22 ++++++++++++++++++++++
4 files changed, 31 insertions(+), 2 deletions(-)
diff --git a/src/relax/ir/expr.cc b/src/relax/ir/expr.cc
index b709039e8c..1b5551e509 100644
--- a/src/relax/ir/expr.cc
+++ b/src/relax/ir/expr.cc
@@ -493,6 +493,10 @@ TVM_REGISTER_NODE_TYPE(FunctionNode);
Function::Function(Array<Var> params, Expr body, Optional<StructInfo>
ret_struct_info, bool is_pure,
DictAttrs attrs, Span span) {
+ if (!attrs.defined()) {
+ attrs = DictAttrs();
+ }
+
// Set the function type.
// For function, we take a conservative approach and require the function
type
// to be known at construction time.
diff --git a/src/script/ir_builder/relax/frame.cc b/src/script/ir_builder/relax/frame.cc
index b95db57a88..792331dda4 100644
--- a/src/script/ir_builder/relax/frame.cc
+++ b/src/script/ir_builder/relax/frame.cc
@@ -61,13 +61,12 @@ void FunctionFrameNode::ExitWithScope() {
!attrs.count(tvm::attr::kGlobalSymbol)) {
attrs.Set(tvm::attr::kGlobalSymbol, name.value());
}
- auto dict_attrs = attrs.empty() ? NullValue<DictAttrs>() : DictAttrs(attrs);
this->block_builder->EndScope();
tvm::relax::Function func(/*params=*/params,
/*body=*/body,
/*ret_struct_info=*/ret_struct_info,
/*is_pure=*/is_pure.value_or(Bool(true))->value,
- /*attrs=*/dict_attrs);
+ /*attrs=*/DictAttrs(attrs));
// Step 2: Update IRModule.
if (builder->frames.empty()) {
// Case 0. No outer frame, return function directly
diff --git a/src/tir/ir/function.cc b/src/tir/ir/function.cc
index 8a3d2d6947..14dd0eadb6 100644
--- a/src/tir/ir/function.cc
+++ b/src/tir/ir/function.cc
@@ -70,6 +70,10 @@ relax::StructInfo InferStructInfo(const PrimFunc& prim_func) {
// Get the function type of a PrimFunc
PrimFunc::PrimFunc(Array<tir::Var> params, Stmt body, Type ret_type,
Map<tir::Var, Buffer> buffer_map, DictAttrs attrs, Span
span) {
+ if (!attrs.defined()) {
+ attrs = DictAttrs();
+ }
+
// Assume void-return type for now
// TODO(tvm-team) consider type deduction from body.
if (!ret_type.defined()) {
diff --git a/tests/python/relax/test_tvmscript_parser.py b/tests/python/relax/test_tvmscript_parser.py
index c8db26c81b..e692768a12 100644
--- a/tests/python/relax/test_tvmscript_parser.py
+++ b/tests/python/relax/test_tvmscript_parser.py
@@ -2271,5 +2271,27 @@ def test_define_relax_function_using_global_var():
tvm.ir.assert_structural_equal(DefinedAllAtOnce, MainDefinedLater)
+def test_function_attributes_are_defined():
+    """func.attrs is a DictAttrs, never None, even without R.func_attrs"""
+
+    @I.ir_module
+    class Module:
+        @R.function
+        def main(x: R.Tensor, shape: R.Shape(["m", "n"])):
+            output = Module.subroutine(x, shape)
+            return output
+
+        @R.function
+        def subroutine(x: R.Tensor, _: R.Shape(["m", "n"])) -> R.Tensor(["m", "n"]):
+            q = x
+            m, n = T.int64(), T.int64()
+            z = R.match_cast(q, R.Tensor((m, n)))
+            w = z
+            return w
+
+    for _, func in Module.functions.items():
+        assert isinstance(func.attrs, tvm.ir.DictAttrs)
+
+
if __name__ == "__main__":
tvm.testing.main()