This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 16158e7cc3 [Unity][Relax][UX] Specify function purity in the @R.function decorator (#15109)
16158e7cc3 is described below

commit 16158e7cc301f82f0034da0229125fd6778dc8cf
Author: Steven S. Lyubomirsky <[email protected]>
AuthorDate: Sat Jun 17 10:47:32 2023 -0400

    [Unity][Relax][UX] Specify function purity in the @R.function decorator (#15109)
    
    * Set purity as an attribute in the @R.function decorator instead of using R.is_pure() or R.is_impure()
    
    * Remove accidental debug prints
    
    * Need to override pylint unused argument warning in function decorator
    
    * Parser argument no longer needed for find_purity_annotation
---
 include/tvm/script/ir_builder/relax/ir.h           |  9 +--
 python/tvm/script/ir_builder/relax/ir.py           | 28 ++------
 python/tvm/script/parser/relax/entry.py            | 33 +++++++--
 python/tvm/script/parser/relax/parser.py           | 24 +++----
 src/script/ir_builder/relax/ir.cc                  | 13 +---
 src/script/printer/relax/function.cc               | 10 +--
 src/script/printer/utils.h                         |  2 +-
 .../relax/test_analysis_contains_impure_call.py    |  9 +--
 tests/python/relax/test_ast_printer.py             |  3 +-
 tests/python/relax/test_pipeline.py                |  3 +-
 tests/python/relax/test_relax_operators.py         | 21 ++----
 tests/python/relax/test_transform.py               | 20 ++----
 tests/python/relax/test_transform_lambda_lift.py   | 13 ++--
 .../test_transform_static_plan_block_memory.py     |  6 +-
 tests/python/relax/test_tvmscript_parser.py        | 80 ++++++++++++++++++++--
 tests/python/relax/test_tvmscript_printer_relax.py | 12 ++--
 16 files changed, 159 insertions(+), 127 deletions(-)

diff --git a/include/tvm/script/ir_builder/relax/ir.h 
b/include/tvm/script/ir_builder/relax/ir.h
index ca705d11dc..1cf30b4919 100644
--- a/include/tvm/script/ir_builder/relax/ir.h
+++ b/include/tvm/script/ir_builder/relax/ir.h
@@ -33,9 +33,10 @@ namespace relax {
 
 /*!
  * \brief Start a function frame.
+ * \param is_pure Whether the function is annotated as pure.
  * \return The created ir_builder Function frame.
  */
-TVM_DLL FunctionFrame Function();
+TVM_DLL FunctionFrame Function(const Bool& is_pure);
 
 /*!
  * \brief Add a parameter to the last function frame.
@@ -57,12 +58,6 @@ TVM_DLL void FuncName(const String& name);
  */
 TVM_DLL void FuncAttrs(Map<String, ObjectRef> attrs);
 
-/*!
- * \brief Specify the purity of the last function frame.
- * \param purity Whether the function is pure.
- */
-TVM_DLL void FuncIsPure(bool purity);
-
 /*!
  * \brief Specify the return struct info of the last function frame.
  * \param ret_sinfo The return struct info.
diff --git a/python/tvm/script/ir_builder/relax/ir.py 
b/python/tvm/script/ir_builder/relax/ir.py
index 73606a8924..e54e4aa07b 100644
--- a/python/tvm/script/ir_builder/relax/ir.py
+++ b/python/tvm/script/ir_builder/relax/ir.py
@@ -165,14 +165,19 @@ py_str = str
 ############################### Function ################################
 
 
-def function() -> frame.FunctionFrame:
+def function(is_pure: bool = True) -> frame.FunctionFrame:
     """Start a function frame.
+    Parameters
+    ----------
+    is_pure: bool
+        Whether the function is annotated as pure.
+
     Returns
     -------
     frame: FunctionFrame
         The constructed function frame.
     """
-    return _ffi_api.Function()  # type: ignore[attr-defined] # pylint: 
disable=no-member
+    return _ffi_api.Function(is_pure)  # type: ignore[attr-defined] # pylint: 
disable=no-member
 
 
 def arg(name: py_str, struct_info: StructInfo) -> Var:
@@ -213,23 +218,6 @@ def func_attr(attrs: Dict[py_str, tvm_Object]) -> None:
     return _ffi_api.FuncAttrs(attrs)  # type: ignore[attr-defined] # pylint: 
disable=no-member
 
 
-def is_pure(purity: bool = True) -> None:
-    """Specify the purity of the last function frame.
-
-    Parameters
-    ----------
-    purity: bool
-        The annotated purity.
-    """
-    return _ffi_api.FuncIsPure(purity)  # type: ignore[attr-defined] # pylint: 
disable=no-member
-
-
-def is_impure() -> None:
-    """Specify that the last function frame is annotated as impure.
-    (Syntactic sugar for R.is_pure(False))"""
-    return _ffi_api.FuncIsPure(False)  # type: ignore[attr-defined] # pylint: 
disable=no-member
-
-
 def func_ret_struct_info(ret_sinfo: StructInfo) -> None:
     """Specify the return struct info of the last function frame.
     Parameters
@@ -634,8 +622,6 @@ __all__ = [
     "image",
     "invoke_closure",
     "invoke_pure_closure",
-    "is_impure",
-    "is_pure",
     "isfinite",
     "isinf",
     "isnan",
diff --git a/python/tvm/script/parser/relax/entry.py 
b/python/tvm/script/parser/relax/entry.py
index 70e5173458..2711e855dd 100644
--- a/python/tvm/script/parser/relax/entry.py
+++ b/python/tvm/script/parser/relax/entry.py
@@ -42,13 +42,32 @@ FType = TypeVar("FType", bound=_Callable)
 
 ############################## R.function ##############################
 
-
-def function(f: FType) -> Union[Function, FType]:
-    if not inspect.isfunction(f):
-        raise TypeError(f"Expect a function, but got: {f}")
-    if utils.is_defined_in_class(inspect.stack(), f):
-        return f
-    return parse(f, utils.inspect_function_capture(f))
+# this formulation allows us to support having @R.function
+# appear as a decorator by itself or to have optional arguments
+# like @R.function(pure=False)
+def function(f: Optional[FType] = None, pure: bool = True) -> Union[Function, 
FType]:
+    # pylint: disable=unused-argument
+    # (pure isn't used here, but is used later in parsing)
+
+    # need to inspect the stack first because is_defined_in_class expects the 
outer class
+    # to be in a particular position in the stack
+    orig_stack = inspect.stack()
+
+    def decorator_wrapper(f):
+        if not inspect.isfunction(f):
+            raise TypeError(f"Expect a function, but got: {f}")
+        if utils.is_defined_in_class(orig_stack, f):
+            return f
+        return parse(f, utils.inspect_function_capture(f))
+
+    if f is not None:
+        # if there are no optional args given, this will directly invoke the 
wrapper
+        return decorator_wrapper(f)
+    else:
+        # if there is a optional arg given, it returns the wrapper function
+        # as a new decorator and applies it
+        setattr(decorator_wrapper, "dispatch_token", "relax")
+        return decorator_wrapper
 
 
 setattr(function, "dispatch_token", "relax")
diff --git a/python/tvm/script/parser/relax/parser.py 
b/python/tvm/script/parser/relax/parser.py
index 3dfde96714..427c56bcc8 100644
--- a/python/tvm/script/parser/relax/parser.py
+++ b/python/tvm/script/parser/relax/parser.py
@@ -178,9 +178,11 @@ def visit_function_def(self: Parser, node: 
doc.FunctionDef) -> None:
         local_func_var = relax.Var(node.name, 
relax.FuncStructInfo(params_sinfo, ret_sinfo))
         self.var_table.add(node.name, local_func_var)
 
+    purity = find_purity_annotation(node)
+
     with self.var_table.with_frame():
         with self.with_dispatch_token("relax"):
-            with R.function():
+            with R.function(is_pure=purity):
                 R.func_name(node.name)
                 collect_symbolic_var_from_params(self, node)
 
@@ -204,20 +206,17 @@ def visit_function_def(self: Parser, node: 
doc.FunctionDef) -> None:
 
 def find_purity_annotation(node: doc.FunctionDef) -> bool:
     """
-    Check if is_pure is specified in the function body.
+    Check the value of `pure` in the function decorator.
     Returns the annotated purity if present, otherwise defaulting to True.
     This allows for specifying the purity in the function signature.
     """
-    for item in node.body:
-        if (
-            isinstance(item, doc.Expr)
-            and isinstance(item.value, doc.Call)
-            and isinstance(item.value.func, doc.Attribute)
-            and item.value.func.attr == "is_pure"
-            and len(item.value.args) == 1
-            and isinstance(item.value.args[0], doc.Constant)
-        ):
-            return bool(item.value.args[0].value)
+    # look for the pure argument in the function decorator
+    for dec in node.decorator_list:
+        if not isinstance(dec, doc.Call) or dec.func.attr != "function":
+            continue
+        for keyword in dec.keywords:
+            if keyword.arg == "pure":
+                return keyword.value.value
     return True
 
 
@@ -240,7 +239,6 @@ def visit_tvm_declare_function(self: Parser, node: 
doc.FunctionDef) -> GlobalVar
             params.append(relax.Var(arg.arg, param_sinfo))
 
     is_pure = find_purity_annotation(node)
-
     func_signature = relax.Function.create_empty(params, ret_sinfo, 
is_pure=is_pure)
     return I.decl_function(node.name, func_signature)
 
diff --git a/src/script/ir_builder/relax/ir.cc 
b/src/script/ir_builder/relax/ir.cc
index 5c39bedd43..52d9f0cfe1 100644
--- a/src/script/ir_builder/relax/ir.cc
+++ b/src/script/ir_builder/relax/ir.cc
@@ -52,7 +52,7 @@ TVM_STATIC_IR_FUNCTOR(Namer, vtable)
 
 /////////////////////////////// Function ////////////////////////////////
 
-FunctionFrame Function() {
+FunctionFrame Function(const Bool& is_pure) {
   ObjectPtr<FunctionFrameNode> n = make_object<FunctionFrameNode>();
   const IRBuilder& ir_builder = IRBuilder::Current();
   Optional<tvm::IRModule> mod = NullOpt;
@@ -60,6 +60,7 @@ FunctionFrame Function() {
     mod = tvm::IRModule(mod_frame.value()->functions);
   }
   n->block_builder = tvm::relax::BlockBuilder::Create(/*mod=*/mod);
+  n->is_pure = is_pure;
   return FunctionFrame(n);
 }
 
@@ -87,15 +88,6 @@ void FuncAttrs(Map<String, ObjectRef> attrs) {
   frame->attrs = attrs;
 }
 
-void FuncIsPure(bool purity) {
-  FunctionFrame frame = FindFunctionFrame("R.is_pure");
-  if (frame->is_pure.defined()) {
-    LOG(FATAL) << "ValueError: Duplicate function purity annotations, previous 
one is:\n"
-               << frame->is_pure.value();
-  }
-  frame->is_pure = Bool(purity);
-}
-
 void FuncRetStructInfo(const tvm::relax::StructInfo& ret_sinfo) {
   FunctionFrame frame = FindFunctionFrame("R.func_ret_struct_info");
   if (frame->ret_struct_info.defined()) {
@@ -132,7 +124,6 @@ 
TVM_REGISTER_GLOBAL("script.ir_builder.relax.Function").set_body_typed(Function)
 TVM_REGISTER_GLOBAL("script.ir_builder.relax.Arg").set_body_typed(Arg);
 
TVM_REGISTER_GLOBAL("script.ir_builder.relax.FuncName").set_body_typed(FuncName);
 
TVM_REGISTER_GLOBAL("script.ir_builder.relax.FuncAttrs").set_body_typed(FuncAttrs);
-TVM_REGISTER_GLOBAL("script.ir_builder.relax.FuncIsPure").set_body_typed(FuncIsPure);
 
TVM_REGISTER_GLOBAL("script.ir_builder.relax.FuncRetStructInfo").set_body_typed(FuncRetStructInfo);
 
TVM_REGISTER_GLOBAL("script.ir_builder.relax.FuncRetValue").set_body_typed(FuncRetValue);
 
diff --git a/src/script/printer/relax/function.cc 
b/src/script/printer/relax/function.cc
index 95169712d9..bd5d969563 100644
--- a/src/script/printer/relax/function.cc
+++ b/src/script/printer/relax/function.cc
@@ -56,16 +56,18 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
             ExprStmtDoc(Relax(d, "func_attr")  //
                             ->Call({d->AsDoc<ExprDoc>(n->attrs, 
n_p->Attr("attrs"))})));
       }
-      // Step 5. Print purity attributes (only include if it's impure)
+      // Step 5. Prepare the decorator (include purity if it's impure)
+      ExprDoc decorator = Relax(d, "function");
       if (!n->is_pure) {
-        (*f)->stmts.push_back(ExprStmtDoc(Relax(d, "is_impure")->Call({})));
+        Array<ExprDoc> pos_args = {};
+        decorator = std::move(decorator->Call(
+            pos_args, {"pure"}, {LiteralDoc::Boolean(false, 
Optional<ObjectPath>())}));
       }
       // Step 6. Print body
       Array<StmtDoc> body =
           PrintSeqExpr(Downcast<relax::SeqExpr>(n->body), n_p->Attr("body"), 
d, /*use_ret=*/true);
       (*f)->stmts.insert((*f)->stmts.end(), body.begin(), body.end());
-      return HeaderWrapper(
-          d, FunctionDoc(func_name, params, {Relax(d, "function")}, ret_type, 
(*f)->stmts));
+      return HeaderWrapper(d, FunctionDoc(func_name, params, {decorator}, 
ret_type, (*f)->stmts));
     });
 
 TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
diff --git a/src/script/printer/utils.h b/src/script/printer/utils.h
index 2a2f469082..1a54385776 100644
--- a/src/script/printer/utils.h
+++ b/src/script/printer/utils.h
@@ -111,7 +111,7 @@ inline ExprDoc TIR(const IRDocsifier& d, const String& 
attr) {
   return IdDoc(d->cfg->tir_prefix)->Attr(attr);
 }
 
-/*! \brief Creates the TIR common prefix, which is by default `T` */
+/*! \brief Creates the Relax common prefix, which is by default `R` */
 inline ExprDoc Relax(const IRDocsifier& d, const String& attr) {
   d->ir_usage.insert("relax");
   return IdDoc(d->cfg->relax_prefix)->Attr(attr);
diff --git a/tests/python/relax/test_analysis_contains_impure_call.py 
b/tests/python/relax/test_analysis_contains_impure_call.py
index bc7d663517..f82022fd94 100644
--- a/tests/python/relax/test_analysis_contains_impure_call.py
+++ b/tests/python/relax/test_analysis_contains_impure_call.py
@@ -37,9 +37,8 @@ def test_simple_pure_case():
 def test_simple_impure_case():
     @tvm.script.ir_module
     class ImpureTest:
-        @R.function
+        @R.function(pure=False)
         def impure_func() -> R.Object:
-            R.is_impure()
             y = R.print(format="I am a message")
             return y
 
@@ -52,9 +51,8 @@ def test_nested_function():
         @R.function
         def pure_with_impure_nested() -> R.Tensor((), "int32"):
             # unused
-            @R.function
+            @R.function(pure=False)
             def impure_inner() -> R.Object:
-                R.is_impure()
                 y = R.print(format="Another, worse, message")
                 return y
 
@@ -73,9 +71,8 @@ def test_ignoring_recursive_call():
     # function has become pure
     @tvm.script.ir_module
     class RecursiveTest:
-        @R.function
+        @R.function(pure=False)
         def recursive_impure() -> R.Object:
-            R.is_impure()
             x = R.const(1, "int32")
             y = R.add(x, x)
             z = R.print(x, y, format="{} {}")
diff --git a/tests/python/relax/test_ast_printer.py 
b/tests/python/relax/test_ast_printer.py
index e0ddab5c67..2a554f16e2 100644
--- a/tests/python/relax/test_ast_printer.py
+++ b/tests/python/relax/test_ast_printer.py
@@ -357,13 +357,12 @@ def test_struct_info():
 
 def test_call_packed():
     # test case from test_parser
-    @R.function
+    @R.function(pure=False)
     def f(
         x: R.Tensor((32, "m"), "float32"),
         y: R.Tensor(("m",), "float32"),
         r: R.Tensor(dtype="int64"),
     ) -> R.Object:
-        R.is_impure()
         m = T.int64()
         z: R.Tensor((32, m), "float32") = R.multiply(x, y)
         w: R.Tensor = R.multiply(z, z)
diff --git a/tests/python/relax/test_pipeline.py 
b/tests/python/relax/test_pipeline.py
index a9f0863214..8aa41490c6 100644
--- a/tests/python/relax/test_pipeline.py
+++ b/tests/python/relax/test_pipeline.py
@@ -66,14 +66,13 @@ def test_pipeline_with_kv_cache():
             )
             return kv_cache
 
-        @R.function
+        @R.function(pure=False)
         def main(
             x: R.Tensor((1, 4), "float32"),
             y: R.Tensor((1, 4), "float32"),
             shape: R.Shape(["L", 4]),
             kv_cache: R.Object,
         ):
-            R.is_impure()
             L = T.int64()
             # computation of the current value
             curr_value = R.add(x, y)
diff --git a/tests/python/relax/test_relax_operators.py 
b/tests/python/relax/test_relax_operators.py
index e6c947e9ef..90608df4b6 100644
--- a/tests/python/relax/test_relax_operators.py
+++ b/tests/python/relax/test_relax_operators.py
@@ -58,9 +58,8 @@ def test_unique():
 
 @tvm.script.ir_module
 class PrintTest:
-    @R.function
+    @R.function(pure=False)
     def foo(x: R.Tensor((), "int32")):
-        R.is_impure()
         # results have to be bound, but we don't use them
         # TODO: We should allow calls whose results are not bound for side 
effects;
         #       it would be easy syntactic sugar to add.
@@ -90,40 +89,34 @@ def test_print():
 
 @tvm.script.ir_module
 class AssertOpTest:
-    @R.function
+    @R.function(pure=False)
     def passes(x: R.Tensor((), "int32")):
-        R.is_impure()
         p1 = R.assert_op(relax.const(True))
         return x
 
-    @R.function
+    @R.function(pure=False)
     def pass_with_args(x: R.Tensor((), "int32")):
-        R.is_impure()
         p1 = R.assert_op(relax.const(True), x, format="You won't see me")
         return x
 
-    @R.function
+    @R.function(pure=False)
     def simple_fail(x: R.Tensor((), "int32")):
-        R.is_impure()
         p1 = R.assert_op(relax.const(False))
         return x
 
-    @R.function
+    @R.function(pure=False)
     def fail_with_message(x: R.Tensor((), "int32")):
-        R.is_impure()
         p1 = R.assert_op(relax.const(False), format="I failed...")
         return x
 
-    @R.function
+    @R.function(pure=False)
     def fail_with_args(x: R.Tensor((), "int32")):
-        R.is_impure()
         # no format
         p1 = R.assert_op(relax.const(False), [x, x])
         return x
 
-    @R.function
+    @R.function(pure=False)
     def fail_with_formatted_message(x: R.Tensor((), "int32")):
-        R.is_impure()
         p1 = R.assert_op(relax.const(False), x, format="Number: {}")
         return x
 
diff --git a/tests/python/relax/test_transform.py 
b/tests/python/relax/test_transform.py
index 2476f6e1f3..102f80b2b0 100644
--- a/tests/python/relax/test_transform.py
+++ b/tests/python/relax/test_transform.py
@@ -145,9 +145,8 @@ def test_transform_remove_purity_checking():
             res = R.invoke_pure_closure(closure, (x,), sinfo_args=R.Tensor((), 
"int32"))
             return res
 
-        @R.function
+        @R.function(pure=False)
         def impure_func() -> R.Object:
-            R.is_impure()
             y = R.print(format="I am impure!")
             return y
 
@@ -165,13 +164,10 @@ def test_transform_remove_purity_checking():
             w = nested(z)
             return w
 
-        @R.function
+        @R.function(pure=False)
         def nested_impure_func() -> R.Tensor((), "int32"):
-            R.is_impure()
-
-            @R.function
+            @R.function(pure=False)
             def nested() -> R.Object:
-                R.is_impure()
                 x = R.print(format="Oops!")
                 return x
 
@@ -202,9 +198,8 @@ def test_transform_remove_purity_checking():
             res = R.invoke_closure(closure, (x,), sinfo_args=R.Tensor((), 
"int32"))
             return res
 
-        @R.function
+        @R.function(pure=False)
         def impure_func() -> R.Object:
-            R.is_impure()
             y = R.print(format="I am impure!")
             return y
 
@@ -223,13 +218,10 @@ def test_transform_remove_purity_checking():
             w = nested(z)
             return w
 
-        @R.function
+        @R.function(pure=False)
         def nested_impure_func() -> R.Tensor((), "int32"):
-            R.is_impure()
-
-            @R.function
+            @R.function(pure=False)
             def nested() -> R.Object:
-                R.is_impure()
                 x = R.print(format="Oops!")
                 return x
 
diff --git a/tests/python/relax/test_transform_lambda_lift.py 
b/tests/python/relax/test_transform_lambda_lift.py
index 98f35a4b98..ddc274fee2 100644
--- a/tests/python/relax/test_transform_lambda_lift.py
+++ b/tests/python/relax/test_transform_lambda_lift.py
@@ -308,28 +308,23 @@ def test_no_local_func():
 def test_impure_function():
     @tvm.script.ir_module
     class Expected:
-        @R.function
+        @R.function(pure=False)
         def lifted_func_0() -> R.Tuple:
-            R.is_impure()
             y = R.print(format="Wow!")
             return y
 
-        @R.function
+        @R.function(pure=False)
         def main(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"):
-            R.is_impure()
             inner = Expected.lifted_func_0
             gv1 = inner()
             return x
 
     @tvm.script.ir_module
     class Before:
-        @R.function
+        @R.function(pure=False)
         def main(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"):
-            R.is_impure()
-
-            @R.function
+            @R.function(pure=False)
             def inner() -> R.Tuple:
-                R.is_impure()
                 y = R.print(format="Wow!")
                 return y
 
diff --git a/tests/python/relax/test_transform_static_plan_block_memory.py 
b/tests/python/relax/test_transform_static_plan_block_memory.py
index ffc0a586e5..0f59278a8e 100644
--- a/tests/python/relax/test_transform_static_plan_block_memory.py
+++ b/tests/python/relax/test_transform_static_plan_block_memory.py
@@ -664,10 +664,9 @@ def test_call_func_other_than_primfunc():
 def test_call_packed_external_func():
     @I.ir_module
     class Module:
-        @R.function
+        @R.function(pure=False)
         def main(x: R.Tensor((2, 3), "float32")):
             # the extern func may or may not be pure, depends on what we're 
calling
-            R.is_impure()
             alloc: R.Tensor((2, 3), dtype="float32") = R.builtin.alloc_tensor(
                 R.shape([2, 3]), dtype="float32", runtime_device_index=0
             )
@@ -682,9 +681,8 @@ def test_call_packed_external_func():
 
     @I.ir_module
     class Expected:
-        @R.function
+        @R.function(pure=False)
         def main(x: R.Tensor((2, 3), dtype="float32")) -> R.Tensor((2, 3), 
dtype="float32"):
-            R.is_impure()
             storage: R.Object = R.memory.alloc_storage(
                 R.shape([24]), R.prim_value(0), R.str("global"), 
R.dtype("float32")
             )
diff --git a/tests/python/relax/test_tvmscript_parser.py 
b/tests/python/relax/test_tvmscript_parser.py
index fef13d234e..7a8bcdee26 100644
--- a/tests/python/relax/test_tvmscript_parser.py
+++ b/tests/python/relax/test_tvmscript_parser.py
@@ -1383,27 +1383,99 @@ def test_global_var_sinfo():
 def test_assert_op():
     @I.ir_module
     class AssertOp:
-        @R.function
+        @R.function(pure=False)
         def main(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"):
-            R.is_impure()
             y = R.assert_op(R.const(False, dtype="bool"), x, format="x: {}")
             return x
 
     _check(AssertOp)
 
 
+def test_assert_outside_of_class():
+    @R.function(pure=False)
+    def func(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"):
+        y = R.assert_op(R.const(False, dtype="bool"), x, format="x: {}")
+        return x
+
+    # this just makes sure that the machinery regarding the pure attribute 
parses
+    # in the case where the function is outside of a class too
+    _check(func)
+
+
+def test_impure_inner_function():
+    @R.function
+    def f(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"):
+        # we will not actually call it
+        @R.function(pure=False)
+        def g(y: R.Tensor((), "int32")) -> R.Tensor((), "int32"):
+            z = R.assert_op(R.const(False, dtype="bool"), y, format="y: {}")
+            return y
+
+        return x
+
+    assert f.is_pure
+    # definition of g
+    assert not f.body.blocks[0].bindings[0].value.is_pure
+
+    # make sure we are not incorrectly passing state for inner functions
+    _check(f)
+
+
+def test_impure_inner_function_in_class():
+    @I.ir_module
+    class ImpureInner:
+        @R.function
+        def main(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"):
+            # we will not actually call it
+            @R.function(pure=False)
+            def g(y: R.Tensor((), "int32")) -> R.Tensor((), "int32"):
+                z = R.assert_op(R.const(False, dtype="bool"), y, format="y: 
{}")
+                return y
+
+            return x
+
+    assert ImpureInner["main"].is_pure
+    # definition of g
+    assert not ImpureInner["main"].body.blocks[0].bindings[0].value.is_pure
+
+    # make sure we are not incorrectly passing state for inner functions
+    _check(ImpureInner)
+
+
 def test_print():
     @I.ir_module
     class Print:
-        @R.function
+        @R.function(pure=False)
         def main(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"):
-            R.is_impure()
             y = R.print(x, format="x: {}")
             return x
 
     _check(Print)
 
 
+def test_parse_multiple_pure_and_impure_funcs():
+    @I.ir_module
+    class Mixture:
+        @R.function(pure=False)
+        def print(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"):
+            y = R.print(x, format="x: {}")
+            return x
+
+        @R.function(pure=False)
+        def assert_func(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"):
+            y = R.assert_op(R.const(False, dtype="bool"), x, format="x: {}")
+            return x
+
+        @R.function
+        def main(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"):
+            return x
+
+    assert not Mixture["print"].is_pure
+    assert not Mixture["assert_func"].is_pure
+    assert Mixture["main"].is_pure
+    _check(Mixture)
+
+
 def test_call_pure_packed():
     @R.function
     def foo(x: R.Tensor((32, 32), "float32")) -> R.Tensor:
diff --git a/tests/python/relax/test_tvmscript_printer_relax.py 
b/tests/python/relax/test_tvmscript_printer_relax.py
index e76fe1d902..7525c63be4 100644
--- a/tests/python/relax/test_tvmscript_printer_relax.py
+++ b/tests/python/relax/test_tvmscript_printer_relax.py
@@ -531,9 +531,8 @@ class Module:
 def test_assert_op():
     @I.ir_module
     class AssertOpMod:
-        @R.function
+        @R.function(pure=False)
         def main(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"):
-            R.is_impure()
             y = R.assert_op(R.const(False, dtype="bool"), x, format="x: {}")
             return x
 
@@ -545,9 +544,8 @@ def test_assert_op():
 
 @I.ir_module
 class Module:
-    @R.function
+    @R.function(pure=False)
     def main(x: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"):
-        R.is_impure()
         y: R.Tuple = R.assert_op(R.const(False, "bool"), x, format=R.str("x: 
{}"))
         return x
 """,
@@ -557,9 +555,8 @@ class Module:
 def test_print():
     @I.ir_module
     class PrintMod:
-        @R.function
+        @R.function(pure=False)
         def main(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"):
-            R.is_impure()
             y = R.print(x, format="x: {}")
             return x
 
@@ -571,9 +568,8 @@ def test_print():
 
 @I.ir_module
 class Module:
-    @R.function
+    @R.function(pure=False)
     def main(x: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"):
-        R.is_impure()
         y: R.Tuple = R.print(x, format=R.str("x: {}"))
         return x
 """,

Reply via email to