This is an automated email from the ASF dual-hosted git repository.

sslyu pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 0fe3ed4424 [Unity] Mark result of LazyTransformParams as impure function (#15697)
0fe3ed4424 is described below

commit 0fe3ed4424907d748676fccbb0efe536d140e6bb
Author: Eric Lunderberg <[email protected]>
AuthorDate: Wed Sep 20 17:10:17 2023 -0500

    [Unity] Mark result of LazyTransformParams as impure function (#15697)
    
    * [Unity] Mark result of LazyTransformParams as impure function
    
    * Updated unit tests for correct "relax.force_pure" attributes
---
 python/tvm/relax/transform/lazy_transform_params.py        | 8 +++++++-
 tests/python/relax/test_transform_lazy_transform_params.py | 7 ++-----
 2 files changed, 9 insertions(+), 6 deletions(-)

diff --git a/python/tvm/relax/transform/lazy_transform_params.py b/python/tvm/relax/transform/lazy_transform_params.py
index 90e56c8dbb..2bdc784f32 100644
--- a/python/tvm/relax/transform/lazy_transform_params.py
+++ b/python/tvm/relax/transform/lazy_transform_params.py
@@ -158,7 +158,13 @@ class LazyTransformParamsMutator(PyExprMutator):
                 if not isinstance(sinfo, relax.TensorStructInfo):
                     params.append(relax.Var("symbolic_var_holder", sinfo))
 
-        return relax.Function(params, new_body, relax.ObjectStructInfo(), attrs=func.attrs)
+        return relax.Function(
+            params,
+            new_body,
+            relax.ObjectStructInfo(),
+            attrs=func.attrs,
+            is_pure=False,
+        ).without_attr("relax.force_pure")
 
     def visit_tuple_getitem_(self, op: relax.TupleGetItem) -> relax.Expr:
         # rewrite get item
diff --git a/tests/python/relax/test_transform_lazy_transform_params.py b/tests/python/relax/test_transform_lazy_transform_params.py
index 94f2181daf..5eac747d18 100644
--- a/tests/python/relax/test_transform_lazy_transform_params.py
+++ b/tests/python/relax/test_transform_lazy_transform_params.py
@@ -74,9 +74,8 @@ def test_lazy_transform_params():
                     T.writes(out[o, i, h, w])
                     out[o, i, h, w] = w1[i, o, h, w]
 
-        @R.function
+        @R.function(pure=False)
         def main_transform_params() -> R.Tuple:
-            R.func_attr({"relax.force_pure": True})
             cls = Expected
             lv: R.Object = R.call_packed("get_item", R.prim_value(1), sinfo_args=(R.Object,))
             gv1: R.Tensor((16, 16, 3, 3), dtype="float32") = R.match_cast(
@@ -145,10 +144,8 @@ def test_lazy_transform_params_with_symbolic_vars():
 
     @I.ir_module
     class Expected:
-        @R.function
+        @R.function(pure=False)
         def main_transform_params(slice_shape_expr: R.Shape(["slice_index"])):
-            # we expect ToNonDataflow and RemovePurityTracking to be invoked first
-            R.func_attr({"relax.force_pure": True})
             cls = Expected
 
             slice_index = T.int64()

Reply via email to