This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new ec0026e0bc [Relax][PyTorch] Fix index_put with broadcast indices (#18533)
ec0026e0bc is described below

commit ec0026e0bc8b7904b29e167e39b252c7e2794d4a
Author: Guan-Ming (Wesley) Chiu <[email protected]>
AuthorDate: Tue Dec 2 04:21:32 2025 +0800

    [Relax][PyTorch] Fix index_put with broadcast indices (#18533)
    
    ## Related Issue
    
    closes https://github.com/apache/tvm/issues/18355
    
    ## Why
    
    Converting PyTorch operations like `M[:, rows, cols] = x` failed because:
    1. The TOPI index_put implementation called `len()` on TVM Tensor objects
    (unsupported)
    2. Index tensors with different shapes (e.g., (2,) and (10,)) couldn't
    broadcast together (see the NumPy sketch below)
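
    For illustration only (plain NumPy, not the TVM implementation): the fix
    relies on right-aligned NumPy broadcasting, where a batch index of shape
    (2, 1) combines with a 1-D index of shape (10,) into (2, 10), while raw
    shapes (2,) and (10,) are incompatible.

    ```python
    import numpy as np

    batch = np.arange(2).reshape(2, 1)  # shape (2, 1)
    rows = np.arange(10)                # shape (10,)

    # Right-aligned broadcasting: (2, 1) with (10,) -> (2, 10)
    print(np.broadcast_shapes(batch.shape, rows.shape))

    # Raw shapes (2,) and (10,) cannot be broadcast together
    try:
        np.broadcast_shapes((2,), (10,))
    except ValueError as err:
        print(err)
    ```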
    
    ## How
    
    - Added broadcasting support following NumPy rules to handle
    multi-dimensional index tensors
    - Added tests for the batched indexing pattern `M[:, rows, cols] = x`
    (see the PyTorch sketch below)
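
    The newly covered pattern, as a plain eager-mode PyTorch sketch mirroring
    the added IndexPutBatchedWithNone test case:

    ```python
    import torch

    x = torch.randn(2, 10)
    M = torch.zeros(2, 11, 11)
    rows = torch.arange(10)  # shape (10,)
    cols = rows + 1          # shape (10,)

    # The leading `:` slice becomes an implicit batch index of shape (2, 1)
    # that broadcasts with rows/cols of shape (10,) to select a (2, 10) view
    M[:, rows, cols] = x
    ```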
---
 .../frontend/torch/base_fx_graph_translator.py     |  3 +-
 python/tvm/relax/op/manipulate.py                  |  2 +-
 python/tvm/topi/index_put.py                       | 68 ++++++++++++++++++----
 .../relax/test_frontend_from_exported_program.py   | 49 ++++++++++++++++
 4 files changed, 108 insertions(+), 14 deletions(-)

diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index e9a9cdd939..7ebb95c136 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -1812,8 +1812,9 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
                             )
                         )
                         # Reshape to [dim_size, 1, 1, ...] for broadcasting
+                        # Add an extra dimension so it broadcasts with other indices
                         arange_idx = self.block_builder.emit(
-                            relax.op.reshape(arange_idx, [data_shape[i]] + [1] * (max_ndim - 1))
+                            relax.op.reshape(arange_idx, [data_shape[i]] + [1] * max_ndim)
                         )
                         processed_indices.append(arange_idx)
                     else:
diff --git a/python/tvm/relax/op/manipulate.py b/python/tvm/relax/op/manipulate.py
index bb134f1148..ee486b0ab6 100644
--- a/python/tvm/relax/op/manipulate.py
+++ b/python/tvm/relax/op/manipulate.py
@@ -642,7 +642,7 @@ def index_put(
             [0.0, 3.0, 0.0],
         ]
     """
-    if not isinstance(indices, (list, tuple)):
+    if isinstance(indices, (list, tuple)):
         indices = RxTuple(indices)
     return _ffi_api.index_put(data, indices, values, accumulate)  # type: ignore
 
diff --git a/python/tvm/topi/index_put.py b/python/tvm/topi/index_put.py
index f51c6718ab..52406d402c 100644
--- a/python/tvm/topi/index_put.py
+++ b/python/tvm/topi/index_put.py
@@ -1,6 +1,6 @@
 # Licensed to the Apache Software Foundation (ASF) under one
-# or more contrir_builderutor license agreements.  See the NOTICE file
-# distrir_builderuted with this work for additional information
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
 # regarding copyright ownership.  The ASF licenses this file
 # to you under the Apache License, Version 2.0 (the
 # "License"); you may not use this file except in compliance
@@ -9,7 +9,7 @@
 #   http://www.apache.org/licenses/LICENSE-2.0
 #
 # Unless required by applicable law or agreed to in writing,
-# software distrir_builderuted under the License is distrir_builderuted on an
+# software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
@@ -29,7 +29,8 @@ def index_put(data, indices, values, accumulate=False):
         The source array to be modified.
 
     indices : Tuple[tvm.te.Tensor]
-        Tuple of 1D index tensors (one for each dimension) specifying positions.
+        Tuple of index tensors (can be multi-dimensional) specifying positions.
+        Index tensors are broadcast together following NumPy broadcasting rules.
 
     values : tvm.te.Tensor
         The values to place at the specified indices.
@@ -60,11 +61,28 @@ def index_put(data, indices, values, accumulate=False):
     for dim in shape:
         full_range *= dim
 
-    # Check all indices have same length
-    index_len = len(indices[0])
-    for idx in indices[1:]:
-        if not utils.equal_const_int(len(idx), index_len):
-            raise ValueError("All index tensors must have same length")
+    index_shapes = [idx.shape for idx in indices]
+    broadcast_ndim = max(len(s) for s in index_shapes)
+    broadcast_shape = []
+
+    for i in range(broadcast_ndim):
+        max_dim = 1
+        for idx_shape in index_shapes:
+            # Right-align shapes
+            dim_idx = len(idx_shape) - broadcast_ndim + i
+            if dim_idx >= 0:
+                dim_size = idx_shape[dim_idx]
+                if not utils.equal_const_int(dim_size, 1):
+                    if utils.equal_const_int(max_dim, 1):
+                        max_dim = dim_size
+                    elif not utils.equal_const_int(dim_size, max_dim):
+                        raise ValueError(f"Cannot broadcast index shapes: {index_shapes}")
+        broadcast_shape.append(max_dim)
+
+    # Compute total number of elements after broadcasting
+    index_len = 1
+    for dim in broadcast_shape:
+        index_len *= dim
 
     def gen_ir(data_ptr, index_ptrs, values_ptr, out_ptr, reduce_func):
         ir_builder = tir.ir_builder.create()
@@ -78,12 +96,38 @@ def index_put(data, indices, values, accumulate=False):
             out[i] = data[i]
 
         with ir_builder.for_range(0, index_len, "k", kind="parallel") as k:
-            # Calculate multi-dimensional index
+            # Decompose k into multi-dimensional broadcast index
+            k_temp = k
+            broadcast_indices = []
+            for i in range(broadcast_ndim - 1, -1, -1):
+                broadcast_indices.insert(0, k_temp % broadcast_shape[i])
+                k_temp = k_temp // broadcast_shape[i]
+
             flat_index = 0
             stride = 1
             for dim in range(len(shape) - 1, -1, -1):
-                # Get index and shift to positive if needed
-                idx_val = indices[dim][k]
+                # Get the index for this dimension using broadcasting
+                idx_shape = index_shapes[dim]
+                idx_ndim = len(idx_shape)
+
+                # Compute the linear index into this index tensor
+                idx_offset = 0
+                idx_stride = 1
+                for i in range(broadcast_ndim - 1, -1, -1):
+                    # Right-align the index shape with broadcast shape
+                    dim_idx = idx_ndim - broadcast_ndim + i
+                    if dim_idx >= 0:
+                        dim_size = idx_shape[dim_idx]
+                        # Use broadcasting: if size is 1, use index 0
+                        # otherwise use broadcast_indices[i]
+                        if utils.equal_const_int(dim_size, 1):
+                            idx_in_dim = 0
+                        else:
+                            idx_in_dim = broadcast_indices[i]
+                        idx_offset += idx_in_dim * idx_stride
+                        idx_stride *= dim_size
+
+                idx_val = indices[dim][idx_offset]
                 shifted_idx = idx_val + (idx_val < 0) * shape[dim]
                 flat_index += shifted_idx * stride
                 stride *= shape[dim]
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index 0658dbfaf3..010bd026a8 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -7133,6 +7133,54 @@ def test_index_put():
                 R.output(gv)
             return gv
 
+    # Test case 9: batched indexing with slice (e.g., M[:, rows, cols] = x)
+    class IndexPutBatchedWithNone(Module):
+        def forward(self, x):
+            B = x.size(0)
+            M = torch.zeros(B, 11, 11)
+            rows = torch.arange(10)
+            cols = rows + 1
+            M[:, rows, cols] = x  # Batched index assignment
+            return M
+
+    example_args_batched_none = (torch.randn(2, 10, dtype=torch.float32),)
+
+    @I.ir_module
+    class ExpectedBatchedWithNone:
+        @R.function
+        def main(
+            x: R.Tensor((2, 10), dtype="float32")
+        ) -> R.Tuple(R.Tensor((2, 11, 11), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((2, 11, 11), dtype="float32") = R.full(
+                    R.shape([2, 11, 11]), R.const(0.0, "float32"), dtype="float32"
+                )
+                lv1: R.Tensor((10,), dtype="int64") = R.arange(
+                    R.prim_value(0), R.prim_value(10), R.prim_value(1), dtype="int64"
+                )
+                lv2: R.Tensor((10,), dtype="int64") = R.add(lv1, R.const(1, "int64"))
+                lv3: R.Tensor((2, 11, 11), dtype="float32") = R.strided_slice(
+                    lv,
+                    (R.prim_value(0),),
+                    (R.prim_value(0),),
+                    (R.prim_value(9223372036854775807),),
+                    (R.prim_value(1),),
+                    assume_inbound=False,
+                )
+                lv4: R.Tensor((2,), dtype="int64") = R.arange(
+                    R.prim_value(0), R.prim_value(2), R.prim_value(1), dtype="int64"
+                )
+                lv5: R.Tensor((2, 1), dtype="int64") = R.reshape(lv4, R.shape([2, 1]))
+                lv6: R.Tensor((2, 11, 11), dtype="float32") = R.index_put(
+                    lv3, (lv5, lv1, lv2), x, accumulate=False
+                )
+                lv7: R.Tensor((2, 11, 11), dtype="float32") = R.slice_scatter(
+                    lv, lv6, R.prim_value(0), R.prim_value(2), R.prim_value(1), axis=0
+                )
+                gv: R.Tuple(R.Tensor((2, 11, 11), dtype="float32")) = (lv7,)
+                R.output(gv)
+            return gv
+
     # Run verification for each case
     verify_model(IndexPut1D(), example_args_1d, {}, Expected1D)
     verify_model(IndexPut2D(), example_args_2d, {}, Expected2D)
@@ -7142,6 +7190,7 @@ def test_index_put():
     verify_model(IndexPutBroadcast1D(), example_args_broadcast1, {}, ExpectedBroadcast1D)
     verify_model(IndexPutBroadcast2D(), example_args_broadcast2, {}, ExpectedBroadcast2D)
     verify_model(IndexPutBroadcast3D(), example_args_broadcast3d, {}, ExpectedBroadcast3D)
+    verify_model(IndexPutBatchedWithNone(), example_args_batched_none, {}, ExpectedBatchedWithNone)
 
 
 def test_flip():
