This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 16158e7cc3 [Unity][Relax][UX] Specify function purity in the @R.function decorator (#15109)
16158e7cc3 is described below
commit 16158e7cc301f82f0034da0229125fd6778dc8cf
Author: Steven S. Lyubomirsky <[email protected]>
AuthorDate: Sat Jun 17 10:47:32 2023 -0400
[Unity][Relax][UX] Specify function purity in the @R.function decorator (#15109)
* Specify purity via a `pure` keyword argument to the @R.function decorator
  (e.g. `@R.function(pure=False)`) instead of calling R.is_pure() or
  R.is_impure() in the function body
* Remove accidental debug prints
* Suppress the pylint unused-argument warning in the function decorator
* Remove the parser argument from find_purity_annotation, which no longer needs it
---
include/tvm/script/ir_builder/relax/ir.h | 9 +--
python/tvm/script/ir_builder/relax/ir.py | 28 ++------
python/tvm/script/parser/relax/entry.py | 33 +++++++--
python/tvm/script/parser/relax/parser.py | 24 +++----
src/script/ir_builder/relax/ir.cc | 13 +---
src/script/printer/relax/function.cc | 10 +--
src/script/printer/utils.h | 2 +-
.../relax/test_analysis_contains_impure_call.py | 9 +--
tests/python/relax/test_ast_printer.py | 3 +-
tests/python/relax/test_pipeline.py | 3 +-
tests/python/relax/test_relax_operators.py | 21 ++----
tests/python/relax/test_transform.py | 20 ++----
tests/python/relax/test_transform_lambda_lift.py | 13 ++--
.../test_transform_static_plan_block_memory.py | 6 +-
tests/python/relax/test_tvmscript_parser.py | 80 ++++++++++++++++++++--
tests/python/relax/test_tvmscript_printer_relax.py | 12 ++--
16 files changed, 159 insertions(+), 127 deletions(-)
diff --git a/include/tvm/script/ir_builder/relax/ir.h
b/include/tvm/script/ir_builder/relax/ir.h
index ca705d11dc..1cf30b4919 100644
--- a/include/tvm/script/ir_builder/relax/ir.h
+++ b/include/tvm/script/ir_builder/relax/ir.h
@@ -33,9 +33,10 @@ namespace relax {
/*!
* \brief Start a function frame.
+ * \param is_pure Whether the function is annotated as pure.
* \return The created ir_builder Function frame.
*/
-TVM_DLL FunctionFrame Function();
+TVM_DLL FunctionFrame Function(const Bool& is_pure);
/*!
* \brief Add a parameter to the last function frame.
@@ -57,12 +58,6 @@ TVM_DLL void FuncName(const String& name);
*/
TVM_DLL void FuncAttrs(Map<String, ObjectRef> attrs);
-/*!
- * \brief Specify the purity of the last function frame.
- * \param purity Whether the function is pure.
- */
-TVM_DLL void FuncIsPure(bool purity);
-
/*!
* \brief Specify the return struct info of the last function frame.
* \param ret_sinfo The return struct info.
diff --git a/python/tvm/script/ir_builder/relax/ir.py
b/python/tvm/script/ir_builder/relax/ir.py
index 73606a8924..e54e4aa07b 100644
--- a/python/tvm/script/ir_builder/relax/ir.py
+++ b/python/tvm/script/ir_builder/relax/ir.py
@@ -165,14 +165,19 @@ py_str = str
############################### Function ################################
-def function() -> frame.FunctionFrame:
+def function(is_pure: bool = True) -> frame.FunctionFrame:
"""Start a function frame.
+ Parameters
+ ----------
+ is_pure: bool
+ Whether the function is annotated as pure.
+
Returns
-------
frame: FunctionFrame
The constructed function frame.
"""
- return _ffi_api.Function() # type: ignore[attr-defined] # pylint:
disable=no-member
+ return _ffi_api.Function(is_pure) # type: ignore[attr-defined] # pylint:
disable=no-member
def arg(name: py_str, struct_info: StructInfo) -> Var:
@@ -213,23 +218,6 @@ def func_attr(attrs: Dict[py_str, tvm_Object]) -> None:
return _ffi_api.FuncAttrs(attrs) # type: ignore[attr-defined] # pylint:
disable=no-member
-def is_pure(purity: bool = True) -> None:
- """Specify the purity of the last function frame.
-
- Parameters
- ----------
- purity: bool
- The annotated purity.
- """
- return _ffi_api.FuncIsPure(purity) # type: ignore[attr-defined] # pylint:
disable=no-member
-
-
-def is_impure() -> None:
- """Specify that the last function frame is annotated as impure.
- (Syntactic sugar for R.is_pure(False))"""
- return _ffi_api.FuncIsPure(False) # type: ignore[attr-defined] # pylint:
disable=no-member
-
-
def func_ret_struct_info(ret_sinfo: StructInfo) -> None:
"""Specify the return struct info of the last function frame.
Parameters
@@ -634,8 +622,6 @@ __all__ = [
"image",
"invoke_closure",
"invoke_pure_closure",
- "is_impure",
- "is_pure",
"isfinite",
"isinf",
"isnan",
diff --git a/python/tvm/script/parser/relax/entry.py
b/python/tvm/script/parser/relax/entry.py
index 70e5173458..2711e855dd 100644
--- a/python/tvm/script/parser/relax/entry.py
+++ b/python/tvm/script/parser/relax/entry.py
@@ -42,13 +42,32 @@ FType = TypeVar("FType", bound=_Callable)
############################## R.function ##############################
-
-def function(f: FType) -> Union[Function, FType]:
- if not inspect.isfunction(f):
- raise TypeError(f"Expect a function, but got: {f}")
- if utils.is_defined_in_class(inspect.stack(), f):
- return f
- return parse(f, utils.inspect_function_capture(f))
+# this formulation allows us to support having @R.function
+# appear as a decorator by itself or to have optional arguments
+# like @R.function(pure=False)
+def function(f: Optional[FType] = None, pure: bool = True) -> Union[Function,
FType]:
+ # pylint: disable=unused-argument
+ # (pure isn't used here; the parser reads it from the decorator AST)
+
+ # need to inspect the stack first because is_defined_in_class expects the
outer class
+ # to be in a particular position in the stack
+ orig_stack = inspect.stack()
+
+ def decorator_wrapper(f):
+ if not inspect.isfunction(f):
+ raise TypeError(f"Expect a function, but got: {f}")
+ if utils.is_defined_in_class(orig_stack, f):
+ return f
+ return parse(f, utils.inspect_function_capture(f))
+
+ if f is not None:
+ # if there are no optional args given, this will directly invoke the
wrapper
+ return decorator_wrapper(f)
+ else:
+ # if an optional arg is given, it returns the wrapper function
+ # as a new decorator and applies it
+ setattr(decorator_wrapper, "dispatch_token", "relax")
+ return decorator_wrapper
setattr(function, "dispatch_token", "relax")
diff --git a/python/tvm/script/parser/relax/parser.py
b/python/tvm/script/parser/relax/parser.py
index 3dfde96714..427c56bcc8 100644
--- a/python/tvm/script/parser/relax/parser.py
+++ b/python/tvm/script/parser/relax/parser.py
@@ -178,9 +178,11 @@ def visit_function_def(self: Parser, node:
doc.FunctionDef) -> None:
local_func_var = relax.Var(node.name,
relax.FuncStructInfo(params_sinfo, ret_sinfo))
self.var_table.add(node.name, local_func_var)
+ purity = find_purity_annotation(node)
+
with self.var_table.with_frame():
with self.with_dispatch_token("relax"):
- with R.function():
+ with R.function(is_pure=purity):
R.func_name(node.name)
collect_symbolic_var_from_params(self, node)
@@ -204,20 +206,17 @@ def visit_function_def(self: Parser, node:
doc.FunctionDef) -> None:
def find_purity_annotation(node: doc.FunctionDef) -> bool:
"""
- Check if is_pure is specified in the function body.
+ Check the value of `pure` in the function decorator.
Returns the annotated purity if present, otherwise defaulting to True.
This allows for specifying the purity in the function signature.
"""
- for item in node.body:
- if (
- isinstance(item, doc.Expr)
- and isinstance(item.value, doc.Call)
- and isinstance(item.value.func, doc.Attribute)
- and item.value.func.attr == "is_pure"
- and len(item.value.args) == 1
- and isinstance(item.value.args[0], doc.Constant)
- ):
- return bool(item.value.args[0].value)
+ # look for the pure argument in the function decorator
+ for dec in node.decorator_list:
+ if not isinstance(dec, doc.Call) or dec.func.attr != "function":
+ continue
+ for keyword in dec.keywords:
+ if keyword.arg == "pure":
+ return keyword.value.value
return True
@@ -240,7 +239,6 @@ def visit_tvm_declare_function(self: Parser, node:
doc.FunctionDef) -> GlobalVar
params.append(relax.Var(arg.arg, param_sinfo))
is_pure = find_purity_annotation(node)
-
func_signature = relax.Function.create_empty(params, ret_sinfo,
is_pure=is_pure)
return I.decl_function(node.name, func_signature)
diff --git a/src/script/ir_builder/relax/ir.cc
b/src/script/ir_builder/relax/ir.cc
index 5c39bedd43..52d9f0cfe1 100644
--- a/src/script/ir_builder/relax/ir.cc
+++ b/src/script/ir_builder/relax/ir.cc
@@ -52,7 +52,7 @@ TVM_STATIC_IR_FUNCTOR(Namer, vtable)
/////////////////////////////// Function ////////////////////////////////
-FunctionFrame Function() {
+FunctionFrame Function(const Bool& is_pure) {
ObjectPtr<FunctionFrameNode> n = make_object<FunctionFrameNode>();
const IRBuilder& ir_builder = IRBuilder::Current();
Optional<tvm::IRModule> mod = NullOpt;
@@ -60,6 +60,7 @@ FunctionFrame Function() {
mod = tvm::IRModule(mod_frame.value()->functions);
}
n->block_builder = tvm::relax::BlockBuilder::Create(/*mod=*/mod);
+ n->is_pure = is_pure;
return FunctionFrame(n);
}
@@ -87,15 +88,6 @@ void FuncAttrs(Map<String, ObjectRef> attrs) {
frame->attrs = attrs;
}
-void FuncIsPure(bool purity) {
- FunctionFrame frame = FindFunctionFrame("R.is_pure");
- if (frame->is_pure.defined()) {
- LOG(FATAL) << "ValueError: Duplicate function purity annotations, previous
one is:\n"
- << frame->is_pure.value();
- }
- frame->is_pure = Bool(purity);
-}
-
void FuncRetStructInfo(const tvm::relax::StructInfo& ret_sinfo) {
FunctionFrame frame = FindFunctionFrame("R.func_ret_struct_info");
if (frame->ret_struct_info.defined()) {
@@ -132,7 +124,6 @@
TVM_REGISTER_GLOBAL("script.ir_builder.relax.Function").set_body_typed(Function)
TVM_REGISTER_GLOBAL("script.ir_builder.relax.Arg").set_body_typed(Arg);
TVM_REGISTER_GLOBAL("script.ir_builder.relax.FuncName").set_body_typed(FuncName);
TVM_REGISTER_GLOBAL("script.ir_builder.relax.FuncAttrs").set_body_typed(FuncAttrs);
-TVM_REGISTER_GLOBAL("script.ir_builder.relax.FuncIsPure").set_body_typed(FuncIsPure);
TVM_REGISTER_GLOBAL("script.ir_builder.relax.FuncRetStructInfo").set_body_typed(FuncRetStructInfo);
TVM_REGISTER_GLOBAL("script.ir_builder.relax.FuncRetValue").set_body_typed(FuncRetValue);
diff --git a/src/script/printer/relax/function.cc
b/src/script/printer/relax/function.cc
index 95169712d9..bd5d969563 100644
--- a/src/script/printer/relax/function.cc
+++ b/src/script/printer/relax/function.cc
@@ -56,16 +56,18 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
ExprStmtDoc(Relax(d, "func_attr") //
->Call({d->AsDoc<ExprDoc>(n->attrs,
n_p->Attr("attrs"))})));
}
- // Step 5. Print purity attributes (only include if it's impure)
+ // Step 5. Prepare the decorator (include purity if it's impure)
+ ExprDoc decorator = Relax(d, "function");
if (!n->is_pure) {
- (*f)->stmts.push_back(ExprStmtDoc(Relax(d, "is_impure")->Call({})));
+ Array<ExprDoc> pos_args = {};
+ decorator = std::move(decorator->Call(
+ pos_args, {"pure"}, {LiteralDoc::Boolean(false,
Optional<ObjectPath>())}));
}
// Step 6. Print body
Array<StmtDoc> body =
PrintSeqExpr(Downcast<relax::SeqExpr>(n->body), n_p->Attr("body"),
d, /*use_ret=*/true);
(*f)->stmts.insert((*f)->stmts.end(), body.begin(), body.end());
- return HeaderWrapper(
- d, FunctionDoc(func_name, params, {Relax(d, "function")}, ret_type,
(*f)->stmts));
+ return HeaderWrapper(d, FunctionDoc(func_name, params, {decorator},
ret_type, (*f)->stmts));
});
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
diff --git a/src/script/printer/utils.h b/src/script/printer/utils.h
index 2a2f469082..1a54385776 100644
--- a/src/script/printer/utils.h
+++ b/src/script/printer/utils.h
@@ -111,7 +111,7 @@ inline ExprDoc TIR(const IRDocsifier& d, const String&
attr) {
return IdDoc(d->cfg->tir_prefix)->Attr(attr);
}
-/*! \brief Creates the TIR common prefix, which is by default `T` */
+/*! \brief Creates the Relax common prefix, which is by default `R` */
inline ExprDoc Relax(const IRDocsifier& d, const String& attr) {
d->ir_usage.insert("relax");
return IdDoc(d->cfg->relax_prefix)->Attr(attr);
diff --git a/tests/python/relax/test_analysis_contains_impure_call.py
b/tests/python/relax/test_analysis_contains_impure_call.py
index bc7d663517..f82022fd94 100644
--- a/tests/python/relax/test_analysis_contains_impure_call.py
+++ b/tests/python/relax/test_analysis_contains_impure_call.py
@@ -37,9 +37,8 @@ def test_simple_pure_case():
def test_simple_impure_case():
@tvm.script.ir_module
class ImpureTest:
- @R.function
+ @R.function(pure=False)
def impure_func() -> R.Object:
- R.is_impure()
y = R.print(format="I am a message")
return y
@@ -52,9 +51,8 @@ def test_nested_function():
@R.function
def pure_with_impure_nested() -> R.Tensor((), "int32"):
# unused
- @R.function
+ @R.function(pure=False)
def impure_inner() -> R.Object:
- R.is_impure()
y = R.print(format="Another, worse, message")
return y
@@ -73,9 +71,8 @@ def test_ignoring_recursive_call():
# function has become pure
@tvm.script.ir_module
class RecursiveTest:
- @R.function
+ @R.function(pure=False)
def recursive_impure() -> R.Object:
- R.is_impure()
x = R.const(1, "int32")
y = R.add(x, x)
z = R.print(x, y, format="{} {}")
diff --git a/tests/python/relax/test_ast_printer.py
b/tests/python/relax/test_ast_printer.py
index e0ddab5c67..2a554f16e2 100644
--- a/tests/python/relax/test_ast_printer.py
+++ b/tests/python/relax/test_ast_printer.py
@@ -357,13 +357,12 @@ def test_struct_info():
def test_call_packed():
# test case from test_parser
- @R.function
+ @R.function(pure=False)
def f(
x: R.Tensor((32, "m"), "float32"),
y: R.Tensor(("m",), "float32"),
r: R.Tensor(dtype="int64"),
) -> R.Object:
- R.is_impure()
m = T.int64()
z: R.Tensor((32, m), "float32") = R.multiply(x, y)
w: R.Tensor = R.multiply(z, z)
diff --git a/tests/python/relax/test_pipeline.py
b/tests/python/relax/test_pipeline.py
index a9f0863214..8aa41490c6 100644
--- a/tests/python/relax/test_pipeline.py
+++ b/tests/python/relax/test_pipeline.py
@@ -66,14 +66,13 @@ def test_pipeline_with_kv_cache():
)
return kv_cache
- @R.function
+ @R.function(pure=False)
def main(
x: R.Tensor((1, 4), "float32"),
y: R.Tensor((1, 4), "float32"),
shape: R.Shape(["L", 4]),
kv_cache: R.Object,
):
- R.is_impure()
L = T.int64()
# computation of the current value
curr_value = R.add(x, y)
diff --git a/tests/python/relax/test_relax_operators.py
b/tests/python/relax/test_relax_operators.py
index e6c947e9ef..90608df4b6 100644
--- a/tests/python/relax/test_relax_operators.py
+++ b/tests/python/relax/test_relax_operators.py
@@ -58,9 +58,8 @@ def test_unique():
@tvm.script.ir_module
class PrintTest:
- @R.function
+ @R.function(pure=False)
def foo(x: R.Tensor((), "int32")):
- R.is_impure()
# results have to be bound, but we don't use them
# TODO: We should allow calls whose results are not bound for side
effects;
# it would be easy syntactic sugar to add.
@@ -90,40 +89,34 @@ def test_print():
@tvm.script.ir_module
class AssertOpTest:
- @R.function
+ @R.function(pure=False)
def passes(x: R.Tensor((), "int32")):
- R.is_impure()
p1 = R.assert_op(relax.const(True))
return x
- @R.function
+ @R.function(pure=False)
def pass_with_args(x: R.Tensor((), "int32")):
- R.is_impure()
p1 = R.assert_op(relax.const(True), x, format="You won't see me")
return x
- @R.function
+ @R.function(pure=False)
def simple_fail(x: R.Tensor((), "int32")):
- R.is_impure()
p1 = R.assert_op(relax.const(False))
return x
- @R.function
+ @R.function(pure=False)
def fail_with_message(x: R.Tensor((), "int32")):
- R.is_impure()
p1 = R.assert_op(relax.const(False), format="I failed...")
return x
- @R.function
+ @R.function(pure=False)
def fail_with_args(x: R.Tensor((), "int32")):
- R.is_impure()
# no format
p1 = R.assert_op(relax.const(False), [x, x])
return x
- @R.function
+ @R.function(pure=False)
def fail_with_formatted_message(x: R.Tensor((), "int32")):
- R.is_impure()
p1 = R.assert_op(relax.const(False), x, format="Number: {}")
return x
diff --git a/tests/python/relax/test_transform.py
b/tests/python/relax/test_transform.py
index 2476f6e1f3..102f80b2b0 100644
--- a/tests/python/relax/test_transform.py
+++ b/tests/python/relax/test_transform.py
@@ -145,9 +145,8 @@ def test_transform_remove_purity_checking():
res = R.invoke_pure_closure(closure, (x,), sinfo_args=R.Tensor((),
"int32"))
return res
- @R.function
+ @R.function(pure=False)
def impure_func() -> R.Object:
- R.is_impure()
y = R.print(format="I am impure!")
return y
@@ -165,13 +164,10 @@ def test_transform_remove_purity_checking():
w = nested(z)
return w
- @R.function
+ @R.function(pure=False)
def nested_impure_func() -> R.Tensor((), "int32"):
- R.is_impure()
-
- @R.function
+ @R.function(pure=False)
def nested() -> R.Object:
- R.is_impure()
x = R.print(format="Oops!")
return x
@@ -202,9 +198,8 @@ def test_transform_remove_purity_checking():
res = R.invoke_closure(closure, (x,), sinfo_args=R.Tensor((),
"int32"))
return res
- @R.function
+ @R.function(pure=False)
def impure_func() -> R.Object:
- R.is_impure()
y = R.print(format="I am impure!")
return y
@@ -223,13 +218,10 @@ def test_transform_remove_purity_checking():
w = nested(z)
return w
- @R.function
+ @R.function(pure=False)
def nested_impure_func() -> R.Tensor((), "int32"):
- R.is_impure()
-
- @R.function
+ @R.function(pure=False)
def nested() -> R.Object:
- R.is_impure()
x = R.print(format="Oops!")
return x
diff --git a/tests/python/relax/test_transform_lambda_lift.py
b/tests/python/relax/test_transform_lambda_lift.py
index 98f35a4b98..ddc274fee2 100644
--- a/tests/python/relax/test_transform_lambda_lift.py
+++ b/tests/python/relax/test_transform_lambda_lift.py
@@ -308,28 +308,23 @@ def test_no_local_func():
def test_impure_function():
@tvm.script.ir_module
class Expected:
- @R.function
+ @R.function(pure=False)
def lifted_func_0() -> R.Tuple:
- R.is_impure()
y = R.print(format="Wow!")
return y
- @R.function
+ @R.function(pure=False)
def main(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"):
- R.is_impure()
inner = Expected.lifted_func_0
gv1 = inner()
return x
@tvm.script.ir_module
class Before:
- @R.function
+ @R.function(pure=False)
def main(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"):
- R.is_impure()
-
- @R.function
+ @R.function(pure=False)
def inner() -> R.Tuple:
- R.is_impure()
y = R.print(format="Wow!")
return y
diff --git a/tests/python/relax/test_transform_static_plan_block_memory.py
b/tests/python/relax/test_transform_static_plan_block_memory.py
index ffc0a586e5..0f59278a8e 100644
--- a/tests/python/relax/test_transform_static_plan_block_memory.py
+++ b/tests/python/relax/test_transform_static_plan_block_memory.py
@@ -664,10 +664,9 @@ def test_call_func_other_than_primfunc():
def test_call_packed_external_func():
@I.ir_module
class Module:
- @R.function
+ @R.function(pure=False)
def main(x: R.Tensor((2, 3), "float32")):
# the extern func may or may not be pure, depends on what we're
calling
- R.is_impure()
alloc: R.Tensor((2, 3), dtype="float32") = R.builtin.alloc_tensor(
R.shape([2, 3]), dtype="float32", runtime_device_index=0
)
@@ -682,9 +681,8 @@ def test_call_packed_external_func():
@I.ir_module
class Expected:
- @R.function
+ @R.function(pure=False)
def main(x: R.Tensor((2, 3), dtype="float32")) -> R.Tensor((2, 3),
dtype="float32"):
- R.is_impure()
storage: R.Object = R.memory.alloc_storage(
R.shape([24]), R.prim_value(0), R.str("global"),
R.dtype("float32")
)
diff --git a/tests/python/relax/test_tvmscript_parser.py
b/tests/python/relax/test_tvmscript_parser.py
index fef13d234e..7a8bcdee26 100644
--- a/tests/python/relax/test_tvmscript_parser.py
+++ b/tests/python/relax/test_tvmscript_parser.py
@@ -1383,27 +1383,99 @@ def test_global_var_sinfo():
def test_assert_op():
@I.ir_module
class AssertOp:
- @R.function
+ @R.function(pure=False)
def main(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"):
- R.is_impure()
y = R.assert_op(R.const(False, dtype="bool"), x, format="x: {}")
return x
_check(AssertOp)
+def test_assert_outside_of_class():
+ @R.function(pure=False)
+ def func(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"):
+ y = R.assert_op(R.const(False, dtype="bool"), x, format="x: {}")
+ return x
+
+ # this just makes sure that the machinery regarding the pure attribute
parses
+ # in the case where the function is outside of a class too
+ _check(func)
+
+
+def test_impure_inner_function():
+ @R.function
+ def f(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"):
+ # we will not actually call it
+ @R.function(pure=False)
+ def g(y: R.Tensor((), "int32")) -> R.Tensor((), "int32"):
+ z = R.assert_op(R.const(False, dtype="bool"), y, format="y: {}")
+ return y
+
+ return x
+
+ assert f.is_pure
+ # definition of g
+ assert not f.body.blocks[0].bindings[0].value.is_pure
+
+ # make sure we are not incorrectly passing state for inner functions
+ _check(f)
+
+
+def test_impure_inner_function_in_class():
+ @I.ir_module
+ class ImpureInner:
+ @R.function
+ def main(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"):
+ # we will not actually call it
+ @R.function(pure=False)
+ def g(y: R.Tensor((), "int32")) -> R.Tensor((), "int32"):
+ z = R.assert_op(R.const(False, dtype="bool"), y, format="y:
{}")
+ return y
+
+ return x
+
+ assert ImpureInner["main"].is_pure
+ # definition of g
+ assert not ImpureInner["main"].body.blocks[0].bindings[0].value.is_pure
+
+ # make sure we are not incorrectly passing state for inner functions
+ _check(ImpureInner)
+
+
def test_print():
@I.ir_module
class Print:
- @R.function
+ @R.function(pure=False)
def main(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"):
- R.is_impure()
y = R.print(x, format="x: {}")
return x
_check(Print)
+def test_parse_multiple_pure_and_impure_funcs():
+ @I.ir_module
+ class Mixture:
+ @R.function(pure=False)
+ def print(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"):
+ y = R.print(x, format="x: {}")
+ return x
+
+ @R.function(pure=False)
+ def assert_func(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"):
+ y = R.assert_op(R.const(False, dtype="bool"), x, format="x: {}")
+ return x
+
+ @R.function
+ def main(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"):
+ return x
+
+ assert not Mixture["print"].is_pure
+ assert not Mixture["assert_func"].is_pure
+ assert Mixture["main"].is_pure
+ _check(Mixture)
+
+
def test_call_pure_packed():
@R.function
def foo(x: R.Tensor((32, 32), "float32")) -> R.Tensor:
diff --git a/tests/python/relax/test_tvmscript_printer_relax.py
b/tests/python/relax/test_tvmscript_printer_relax.py
index e76fe1d902..7525c63be4 100644
--- a/tests/python/relax/test_tvmscript_printer_relax.py
+++ b/tests/python/relax/test_tvmscript_printer_relax.py
@@ -531,9 +531,8 @@ class Module:
def test_assert_op():
@I.ir_module
class AssertOpMod:
- @R.function
+ @R.function(pure=False)
def main(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"):
- R.is_impure()
y = R.assert_op(R.const(False, dtype="bool"), x, format="x: {}")
return x
@@ -545,9 +544,8 @@ def test_assert_op():
@I.ir_module
class Module:
- @R.function
+ @R.function(pure=False)
def main(x: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"):
- R.is_impure()
y: R.Tuple = R.assert_op(R.const(False, "bool"), x, format=R.str("x:
{}"))
return x
""",
@@ -557,9 +555,8 @@ class Module:
def test_print():
@I.ir_module
class PrintMod:
- @R.function
+ @R.function(pure=False)
def main(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"):
- R.is_impure()
y = R.print(x, format="x: {}")
return x
@@ -571,9 +568,8 @@ def test_print():
@I.ir_module
class Module:
- @R.function
+ @R.function(pure=False)
def main(x: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"):
- R.is_impure()
y: R.Tuple = R.print(x, format=R.str("x: {}"))
return x
""",