This is an automated email from the ASF dual-hosted git repository.
sslyu pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 0fe3ed4424 [Unity] Mark result of LazyTransformParams as impure function (#15697)
0fe3ed4424 is described below
commit 0fe3ed4424907d748676fccbb0efe536d140e6bb
Author: Eric Lunderberg <[email protected]>
AuthorDate: Wed Sep 20 17:10:17 2023 -0500
[Unity] Mark result of LazyTransformParams as impure function (#15697)
* [Unity] Mark result of LazyTransformParams as impure function
* Updated unit tests for correct "relax.force_pure" attributes
---
python/tvm/relax/transform/lazy_transform_params.py | 8 +++++++-
tests/python/relax/test_transform_lazy_transform_params.py | 7 ++-----
2 files changed, 9 insertions(+), 6 deletions(-)
diff --git a/python/tvm/relax/transform/lazy_transform_params.py b/python/tvm/relax/transform/lazy_transform_params.py
index 90e56c8dbb..2bdc784f32 100644
--- a/python/tvm/relax/transform/lazy_transform_params.py
+++ b/python/tvm/relax/transform/lazy_transform_params.py
@@ -158,7 +158,13 @@ class LazyTransformParamsMutator(PyExprMutator):
if not isinstance(sinfo, relax.TensorStructInfo):
params.append(relax.Var("symbolic_var_holder", sinfo))
- return relax.Function(params, new_body, relax.ObjectStructInfo(), attrs=func.attrs)
+ return relax.Function(
+ params,
+ new_body,
+ relax.ObjectStructInfo(),
+ attrs=func.attrs,
+ is_pure=False,
+ ).without_attr("relax.force_pure")
def visit_tuple_getitem_(self, op: relax.TupleGetItem) -> relax.Expr:
# rewrite get item
diff --git a/tests/python/relax/test_transform_lazy_transform_params.py b/tests/python/relax/test_transform_lazy_transform_params.py
index 94f2181daf..5eac747d18 100644
--- a/tests/python/relax/test_transform_lazy_transform_params.py
+++ b/tests/python/relax/test_transform_lazy_transform_params.py
@@ -74,9 +74,8 @@ def test_lazy_transform_params():
T.writes(out[o, i, h, w])
out[o, i, h, w] = w1[i, o, h, w]
- @R.function
+ @R.function(pure=False)
def main_transform_params() -> R.Tuple:
- R.func_attr({"relax.force_pure": True})
cls = Expected
lv: R.Object = R.call_packed("get_item", R.prim_value(1), sinfo_args=(R.Object,))
gv1: R.Tensor((16, 16, 3, 3), dtype="float32") = R.match_cast(
@@ -145,10 +144,8 @@ def test_lazy_transform_params_with_symbolic_vars():
@I.ir_module
class Expected:
- @R.function
+ @R.function(pure=False)
def main_transform_params(slice_shape_expr: R.Shape(["slice_index"])):
- # we expect ToNonDataflow and RemovePurityTracking to be invoked first
- R.func_attr({"relax.force_pure": True})
cls = Expected
slice_index = T.int64()