This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new b466ef5d86 [Relax][PyTorch] Enhance index_put support for multi-dimensional indices (#18488)
b466ef5d86 is described below

commit b466ef5d86235793dec8502a1892dfb459b5c914
Author: Guan-Ming (Wesley) Chiu <[email protected]>
AuthorDate: Sun Nov 23 00:31:36 2025 +0800

    [Relax][PyTorch] Enhance index_put support for multi-dimensional indices (#18488)
    
    ## Related Issue
    
    close https://github.com/apache/tvm/issues/18438
    
    ## Why
    
    The current implementation breaks when handling multi-dimensional indices.
    
    ## How
    
    - Support multi-dimensional indices in index_put (see the sketch below)
    - Add test cases
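
    A minimal PyTorch sketch of the indexing pattern this change enables,
    mirroring the new IndexPutBroadcast1D test case (shapes are illustrative):

        import torch

        data = torch.zeros(32, 64)
        # A (32, 1) row index broadcasts against the 1D column index below.
        rows = torch.arange(data.shape[0]).unsqueeze(1)
        cols = torch.randint(0, 64, (10,))
        # One value per broadcast position: (32, 1) x (10,) -> (32, 10).
        values = torch.ones(data.shape[0], cols.shape[0])
        data.index_put_((rows, cols), values, accumulate=False)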
---
 .../frontend/torch/base_fx_graph_translator.py     |  29 ++++-
 src/relax/op/tensor/manipulate.cc                  |  32 +++++-
 .../relax/test_frontend_from_exported_program.py   | 119 +++++++++++++++++++++
 3 files changed, 175 insertions(+), 5 deletions(-)

diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index d2c888cdd1..5ca79344ba 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -1677,7 +1677,34 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
             raise TypeError("'accumulate' must be a boolean value, got {}".format(type(accumulate)))
 
         if isinstance(indices, (list, tuple)):
-            indices = relax.Tuple(indices)
+            # In PyTorch index_put, None means "select all elements" for that dimension
+            non_none_indices = [(i, idx) for i, idx in enumerate(indices) if idx is not None]
+
+            if len(non_none_indices) < len(indices):
+                data_shape = self.shape_of(tensor)
+                processed_indices = []
+
+                max_ndim = max((idx.struct_info.ndim for _, idx in non_none_indices), default=1)
+
+                for i, idx in enumerate(indices):
+                    if idx is None:
+                        # Replace None with arange for full dimension indexing
+                        arange_idx = self.block_builder.emit(
+                            relax.op.arange(
+                                relax.PrimValue(0), data_shape[i], relax.PrimValue(1), "int64"
+                            )
+                        )
+                        # Reshape to [dim_size, 1, 1, ...] for broadcasting
+                        arange_idx = self.block_builder.emit(
+                            relax.op.reshape(arange_idx, [data_shape[i]] + [1] * (max_ndim - 1))
+                        )
+                        processed_indices.append(arange_idx)
+                    else:
+                        processed_indices.append(idx)
+
+                indices = relax.Tuple(processed_indices)
+            else:
+                indices = relax.Tuple(indices)
         return self.block_builder.emit(relax.op.index_put(tensor, indices, values, accumulate))
 
     def _index_tensor(self, node: fx.Node) -> relax.Var:
diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc
index 79c0687cad..78244a8bc5 100644
--- a/src/relax/op/tensor/manipulate.cc
+++ b/src/relax/op/tensor/manipulate.cc
@@ -2105,12 +2105,19 @@ StructInfo InferStructInfoIndexPut(const Call& call, const BlockBuilder& ctx) {
   }
 
   // Validate each index tensor
+  // Index tensors can be multi-dimensional for broadcasting
+  int max_index_ndim = -1;
   for (size_t i = 0; i < indices_tensors.size(); ++i) {
     const auto& tensor_sinfo = indices_tensors[i];
-    if (!tensor_sinfo->IsUnknownNdim() && tensor_sinfo->ndim != 1) {
-      ctx->ReportFatal(Diagnostic::Error(call)
-                       << "IndexPut requires each index tensor to be 1D. "
-                       << "However, index tensor " << i << " has ndim=" << tensor_sinfo->ndim);
+    if (!tensor_sinfo->IsUnknownNdim()) {
+      if (tensor_sinfo->ndim < 1) {
+        ctx->ReportFatal(Diagnostic::Error(call)
+                         << "IndexPut requires each index tensor to have at least 1 dimension. "
+                         << "However, index tensor " << i << " has ndim=" << tensor_sinfo->ndim);
+      }
+      if (max_index_ndim < tensor_sinfo->ndim) {
+        max_index_ndim = tensor_sinfo->ndim;
+      }
     }
     if (tensor_sinfo->IsUnknownDtype()) {
       LOG(WARNING) << "Data type of index tensor " << i
@@ -2122,6 +2129,23 @@ StructInfo InferStructInfoIndexPut(const Call& call, const BlockBuilder& ctx) {
     }
   }
 
+  // Validate that index tensor shapes are broadcastable
+  if (max_index_ndim > 1) {
+    for (size_t i = 0; i < indices_tensors.size(); ++i) {
+      const auto& tensor_sinfo = indices_tensors[i];
+      if (!tensor_sinfo->IsUnknownNdim() && tensor_sinfo->ndim > 1) {
+        // Check that multi-dimensional indices are broadcastable
+        const auto* shape = tensor_sinfo->shape.as<ShapeExprNode>();
+        if (shape) {
+          // Verify trailing dimensions can broadcast
+          // For now, we accept any multi-dimensional index and rely on runtime validation
+          LOG(INFO) << "IndexPut: index tensor " << i << " has ndim=" << tensor_sinfo->ndim
+                    << " for broadcasting";
+        }
+      }
+    }
+  }
+
   // Check that the number of index tensors matches data dimensions
   if (!data_sinfo->IsUnknownNdim() &&
       indices_tensors.size() != static_cast<size_t>(data_sinfo->ndim)) {
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index 01efb6b936..c4851973ea 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -6576,12 +6576,131 @@ def test_index_put():
                 R.output(gv)
             return gv
 
+    # Test case 6: 2D input with multi-dimensional index (broadcasting)
+    # This tests the multi-dimensional index support with broadcasting
+    class IndexPutBroadcast1D(Module):
+        def forward(self, data, indices_1):
+            indices_0 = torch.arange(data.shape[0]).unsqueeze(1)
+            values = torch.ones(data.shape[0], len(indices_1), dtype=data.dtype)
+            return data.index_put_((indices_0, indices_1), values, accumulate=False)
+
+    example_args_broadcast1 = (
+        torch.randn(32, 64, dtype=torch.float32),
+        torch.randint(0, 64, (10,), dtype=torch.int64),
+    )
+
+    @I.ir_module
+    class ExpectedBroadcast1D:
+        @R.function
+        def main(
+            data: R.Tensor((32, 64), dtype="float32"),
+            indices_1: R.Tensor((10,), dtype="int64"),
+        ) -> R.Tuple(R.Tensor((32, 64), dtype="float32"), R.Tensor((32, 64), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((32,), dtype="int64") = R.arange(
+                    R.prim_value(0), R.prim_value(32), R.prim_value(1), dtype="int64"
+                )
+                lv1: R.Tensor((32, 1), dtype="int64") = R.expand_dims(lv, axis=[1])
+                lv2: R.Tensor((32, 10), dtype="float32") = R.full(
+                    R.shape([32, 10]), R.const(1.0, "float32"), dtype="float32"
+                )
+                lv3: R.Tensor((32, 64), dtype="float32") = R.index_put(
+                    data, R.tuple(lv1, indices_1), lv2, accumulate=False
+                )
+                gv: R.Tuple(
+                    R.Tensor((32, 64), dtype="float32"), R.Tensor((32, 64), dtype="float32")
+                ) = (lv3, lv3)
+                R.output(gv)
+            return gv
+
+    # Test case 7: 2D input with multi-dimensional index (second position)
+    class IndexPutBroadcast2D(Module):
+        def forward(self, data, indices_0):
+            indices_1 = torch.arange(data.shape[1]).unsqueeze(1)
+            values = torch.ones(len(indices_0), data.shape[1], dtype=data.dtype)
+            return data.index_put_((indices_0, indices_1), values, accumulate=False)
+
+    example_args_broadcast2 = (
+        torch.randn(32, 64, dtype=torch.float32),
+        torch.randint(0, 32, (10,), dtype=torch.int64),
+    )
+
+    @I.ir_module
+    class ExpectedBroadcast2D:
+        @R.function
+        def main(
+            data: R.Tensor((32, 64), dtype="float32"),
+            indices_0: R.Tensor((10,), dtype="int64"),
+        ) -> R.Tuple(R.Tensor((32, 64), dtype="float32"), R.Tensor((32, 64), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((64,), dtype="int64") = R.arange(
+                    R.prim_value(0), R.prim_value(64), R.prim_value(1), dtype="int64"
+                )
+                lv1: R.Tensor((64, 1), dtype="int64") = R.expand_dims(lv, axis=[1])
+                lv2: R.Tensor((10, 64), dtype="float32") = R.full(
+                    R.shape([10, 64]), R.const(1.0, "float32"), dtype="float32"
+                )
+                lv3: R.Tensor((32, 64), dtype="float32") = R.index_put(
+                    data, R.tuple(indices_0, lv1), lv2, accumulate=False
+                )
+                gv: R.Tuple(
+                    R.Tensor((32, 64), dtype="float32"), R.Tensor((32, 64), dtype="float32")
+                ) = (lv3, lv3)
+                R.output(gv)
+            return gv
+
+    # Test case 8: 3D input with mixed 1D and 2D indices
+    class IndexPutBroadcast3D(Module):
+        def forward(self, data, indices_1):
+            indices_0 = torch.arange(data.shape[0]).unsqueeze(1)
+            indices_2 = torch.arange(data.shape[2]).unsqueeze(1)
+            values = torch.ones(data.shape[0], len(indices_1), data.shape[2], dtype=data.dtype)
+            return data.index_put_((indices_0, indices_1, indices_2), values, accumulate=False)
+
+    example_args_broadcast3d = (
+        torch.randn(16, 32, 64, dtype=torch.float32),
+        torch.randint(0, 32, (10,), dtype=torch.int64),
+    )
+
+    @I.ir_module
+    class ExpectedBroadcast3D:
+        @R.function
+        def main(
+            data: R.Tensor((16, 32, 64), dtype="float32"),
+            indices_1: R.Tensor((10,), dtype="int64"),
+        ) -> R.Tuple(
+            R.Tensor((16, 32, 64), dtype="float32"), R.Tensor((16, 32, 64), dtype="float32")
+        ):
+            with R.dataflow():
+                lv: R.Tensor((16,), dtype="int64") = R.arange(
+                    R.prim_value(0), R.prim_value(16), R.prim_value(1), dtype="int64"
+                )
+                lv1: R.Tensor((16, 1), dtype="int64") = R.expand_dims(lv, axis=[1])
+                lv2: R.Tensor((64,), dtype="int64") = R.arange(
+                    R.prim_value(0), R.prim_value(64), R.prim_value(1), dtype="int64"
+                )
+                lv3: R.Tensor((64, 1), dtype="int64") = R.expand_dims(lv2, axis=[1])
+                lv4: R.Tensor((16, 10, 64), dtype="float32") = R.full(
+                    R.shape([16, 10, 64]), R.const(1.0, "float32"), dtype="float32"
+                )
+                lv5: R.Tensor((16, 32, 64), dtype="float32") = R.index_put(
+                    data, R.tuple(lv1, indices_1, lv3), lv4, accumulate=False
+                )
+                gv: R.Tuple(
+                    R.Tensor((16, 32, 64), dtype="float32"), R.Tensor((16, 32, 64), dtype="float32")
+                ) = (lv5, lv5)
+                R.output(gv)
+            return gv
+
     # Run verification for each case
     verify_model(IndexPut1D(), example_args_1d, {}, Expected1D)
     verify_model(IndexPut2D(), example_args_2d, {}, Expected2D)
     verify_model(IndexPut3D(), example_args_3d, {}, Expected3D)
     verify_model(IndexPut4D(), example_args_4d, {}, Expected4D)
     verify_model(IndexPut5D(), example_args_5d, {}, Expected5D)
+    verify_model(IndexPutBroadcast1D(), example_args_broadcast1, {}, ExpectedBroadcast1D)
+    verify_model(IndexPutBroadcast2D(), example_args_broadcast2, {}, ExpectedBroadcast2D)
+    verify_model(IndexPutBroadcast3D(), example_args_broadcast3d, {}, ExpectedBroadcast3D)
 
 
 def test_flip():
