This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 0d017c13a4 [Relax] Replaced call_pure_packed with tensor_to_shape operator (#18616)
0d017c13a4 is described below

commit 0d017c13a48d3d2ed8ae63b0329f5064204316b3
Author: Guan-Ming (Wesley) Chiu <[email protected]>
AuthorDate: Tue Dec 30 16:02:29 2025 +0800

    [Relax] Replaced call_pure_packed with tensor_to_shape operator (#18616)
    
    ## Why
    
    Simplifying the code and addressing the purity issue mentioned in the
    TODO comment.
    
    ## How
    
    **Before**
    ```
    output_shape = bb.emit(
        call_pure_packed(
            "vm.builtin.tensor_to_shape", output_shape, sinfo_args=ShapeStructInfo(ndim=ndim)
        )
    )
    ```
    
    **After**
    ```
    output_shape = bb.emit(tensor_to_shape(output_shape))
    ```
    
    ---------
    
    Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
---
 python/tvm/relax/transform/legalize_ops/index.py            | 13 ++-----------
 .../test_transform_legalize_ops_index_linear_algebra.py     |  8 ++------
 2 files changed, 4 insertions(+), 17 deletions(-)

diff --git a/python/tvm/relax/transform/legalize_ops/index.py b/python/tvm/relax/transform/legalize_ops/index.py
index d99c1f4db6..75c17f7fa9 100644
--- a/python/tvm/relax/transform/legalize_ops/index.py
+++ b/python/tvm/relax/transform/legalize_ops/index.py
@@ -17,7 +17,7 @@
 # pylint: disable=invalid-name
 """Default legalization function for index operators."""
 from tvm import topi, tir, te
-from ...op import call_pure_packed
+from ...op import tensor_to_shape
 from ...block_builder import BlockBuilder
 from ...expr import Call, Expr
 from ...struct_info import ShapeStructInfo, PrimStructInfo
@@ -109,17 +109,8 @@ def _dynamic_strided_slice(bb: BlockBuilder, call: Call) -> Expr:
     )
 
     # 2. Convert tensor to shape and match cast with new symbolic vars
-    # Get shape length
     ndim = int(output_shape.struct_info.shape[0])
-    output_shape = bb.emit(
-        # TODO(@relax-team): Ideally, we should use the tensor_to_shape op here to
-        # address the issue with purity, but that introduces a staging issue:
-        # we need to apply DecomposeOpsForInference in that case
-        # and it's unclear when in the build it should happen
-        call_pure_packed(
-            "vm.builtin.tensor_to_shape", output_shape, sinfo_args=ShapeStructInfo(ndim=ndim)
-        )
-    )
+    output_shape = bb.emit(tensor_to_shape(output_shape))
     output_shape_vars = [tir.Var("s", "int64") for i in range(ndim)]
     bb.match_cast(output_shape, ShapeStructInfo(output_shape_vars))
 
diff --git a/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py b/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py
index efa7f4dfff..a6e53dab4d 100644
--- a/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py
+++ b/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py
@@ -669,9 +669,7 @@ def test_dynamic_strided_slice():
                 (x, begin, end, strides),
                 out_sinfo=R.Tensor((4,), dtype="int64"),
             )
-            gv1: R.Shape(ndim=4) = R.call_pure_packed(
-                "vm.builtin.tensor_to_shape", gv, sinfo_args=(R.Shape(ndim=4),)
-            )
+            gv1: R.Shape(ndim=4) = R.tensor_to_shape(gv)
             gv2: R.Shape([s, s_1, s_2, s_3]) = R.match_cast(
                 gv1, R.Shape([s, s_1, s_2, s_3])
             )
@@ -868,9 +866,7 @@ def test_dynamic_strided_slice_symbolic():
                 (x, begin, end, strides),
                 out_sinfo=R.Tensor((2,), dtype="int64"),
             )
-            gv1: R.Shape(ndim=2) = R.call_pure_packed(
-                "vm.builtin.tensor_to_shape", gv, sinfo_args=(R.Shape(ndim=2),)
-            )
+            gv1: R.Shape(ndim=2) = R.tensor_to_shape(gv)
             gv2: R.Shape([s, s_1]) = R.match_cast(gv1, R.Shape([s, s_1]))
             gv_1 = R.call_tir(
                 Expected.dynamic_strided_slice,

Reply via email to