This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new b466ef5d86 [Relax][PyTorch] Enhance index_put support for multi-dimensional indices (#18488)
b466ef5d86 is described below
commit b466ef5d86235793dec8502a1892dfb459b5c914
Author: Guan-Ming (Wesley) Chiu <[email protected]>
AuthorDate: Sun Nov 23 00:31:36 2025 +0800
[Relax][PyTorch] Enhance index_put support for multi-dimensional indices (#18488)
## Related Issue
Closes https://github.com/apache/tvm/issues/18438
## Why
The current implementation breaks when handling multi-dimensional indices.
## How
- support multi-dimensional indices in index_put
- add test cases (see the short PyTorch sketch below for the indexing pattern being handled)
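
For reviewers, a minimal PyTorch-only sketch (not part of this patch; names and shapes are illustrative) of the indexing pattern involved: a fully selected ("None") dimension is equivalent to an explicit `arange` index reshaped so it broadcasts against the other index tensors.

```python
import torch

data = torch.zeros(3, 4)
cols = torch.tensor([0, 2])
values = torch.ones(3, 2)

# Slicing form: this typically lowers to index_put with a None index
# for the fully selected first dimension.
expected = data.clone()
expected[:, cols] = values

# Explicit form: an arange over dim 0, reshaped to (3, 1) so it broadcasts
# against the (2,)-shaped column index.
rows = torch.arange(3).reshape(3, 1)
result = data.clone().index_put_((rows, cols), values, accumulate=False)

assert torch.equal(result, expected)
```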
---
.../frontend/torch/base_fx_graph_translator.py | 29 ++++-
src/relax/op/tensor/manipulate.cc | 32 +++++-
.../relax/test_frontend_from_exported_program.py | 119 +++++++++++++++++++++
3 files changed, 175 insertions(+), 5 deletions(-)
diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index d2c888cdd1..5ca79344ba 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -1677,7 +1677,34 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
            raise TypeError("'accumulate' must be a boolean value, got {}".format(type(accumulate)))
        if isinstance(indices, (list, tuple)):
-            indices = relax.Tuple(indices)
+            # In PyTorch index_put, None means "select all elements" for that dimension
+            non_none_indices = [(i, idx) for i, idx in enumerate(indices) if idx is not None]
+
+            if len(non_none_indices) < len(indices):
+                data_shape = self.shape_of(tensor)
+                processed_indices = []
+
+                max_ndim = max((idx.struct_info.ndim for _, idx in non_none_indices), default=1)
+
+                for i, idx in enumerate(indices):
+                    if idx is None:
+                        # Replace None with arange for full dimension indexing
+                        arange_idx = self.block_builder.emit(
+                            relax.op.arange(
+                                relax.PrimValue(0), data_shape[i], relax.PrimValue(1), "int64"
+                            )
+                        )
+                        # Reshape to [dim_size, 1, 1, ...] for broadcasting
+                        arange_idx = self.block_builder.emit(
+                            relax.op.reshape(arange_idx, [data_shape[i]] + [1] * (max_ndim - 1))
+                        )
+                        processed_indices.append(arange_idx)
+                    else:
+                        processed_indices.append(idx)
+
+                indices = relax.Tuple(processed_indices)
+            else:
+                indices = relax.Tuple(indices)
        return self.block_builder.emit(relax.op.index_put(tensor, indices, values, accumulate))
def _index_tensor(self, node: fx.Node) -> relax.Var:
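
For anyone wanting to reproduce the import locally, a hedged usage sketch (module name and shapes are made up; it mirrors test case 6 further down) of running a model with a multi-dimensional index through the exported-program frontend:

```python
import torch
from torch.export import export
from tvm.relax.frontend.torch import from_exported_program

class PutRows(torch.nn.Module):
    def forward(self, data, cols):
        # (N, 1) row index broadcasts against the 1-D column index.
        rows = torch.arange(data.shape[0]).unsqueeze(1)
        values = torch.ones(data.shape[0], cols.shape[0], dtype=data.dtype)
        return data.index_put_((rows, cols), values, accumulate=False)

example_args = (torch.randn(32, 64), torch.randint(0, 64, (10,)))
exported_program = export(PutRows(), example_args)
mod = from_exported_program(exported_program)
mod.show()  # should contain R.index_put with a (32, 1) and a (10,) index
```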
diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc
index 79c0687cad..78244a8bc5 100644
--- a/src/relax/op/tensor/manipulate.cc
+++ b/src/relax/op/tensor/manipulate.cc
@@ -2105,12 +2105,19 @@ StructInfo InferStructInfoIndexPut(const Call& call, const BlockBuilder& ctx) {
  }
  // Validate each index tensor
+  // Index tensors can be multi-dimensional for broadcasting
+  int max_index_ndim = -1;
  for (size_t i = 0; i < indices_tensors.size(); ++i) {
    const auto& tensor_sinfo = indices_tensors[i];
-    if (!tensor_sinfo->IsUnknownNdim() && tensor_sinfo->ndim != 1) {
-      ctx->ReportFatal(Diagnostic::Error(call)
-                       << "IndexPut requires each index tensor to be 1D. "
-                       << "However, index tensor " << i << " has ndim=" << tensor_sinfo->ndim);
+    if (!tensor_sinfo->IsUnknownNdim()) {
+      if (tensor_sinfo->ndim < 1) {
+        ctx->ReportFatal(Diagnostic::Error(call)
+                         << "IndexPut requires each index tensor to have at least 1 dimension. "
+                         << "However, index tensor " << i << " has ndim=" << tensor_sinfo->ndim);
+      }
+      if (max_index_ndim < tensor_sinfo->ndim) {
+        max_index_ndim = tensor_sinfo->ndim;
+      }
    }
    if (tensor_sinfo->IsUnknownDtype()) {
      LOG(WARNING) << "Data type of index tensor " << i
@@ -2122,6 +2129,23 @@ StructInfo InferStructInfoIndexPut(const Call& call, const BlockBuilder& ctx) {
    }
  }
+  // Validate that index tensor shapes are broadcastable
+  if (max_index_ndim > 1) {
+    for (size_t i = 0; i < indices_tensors.size(); ++i) {
+      const auto& tensor_sinfo = indices_tensors[i];
+      if (!tensor_sinfo->IsUnknownNdim() && tensor_sinfo->ndim > 1) {
+        // Check that multi-dimensional indices are broadcastable
+        const auto* shape = tensor_sinfo->shape.as<ShapeExprNode>();
+        if (shape) {
+          // Verify trailing dimensions can broadcast
+          // For now, we accept any multi-dimensional index and rely on runtime validation
+          LOG(INFO) << "IndexPut: index tensor " << i << " has ndim=" << tensor_sinfo->ndim
+                    << " for broadcasting";
+        }
+      }
+    }
+  }
+
  // Check that the number of index tensors matches data dimensions
  if (!data_sinfo->IsUnknownNdim() &&
      indices_tensors.size() != static_cast<size_t>(data_sinfo->ndim)) {
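
At the Relax level, a small hand-written sketch (shapes illustrative) of a call that the relaxed struct-info check now admits; previously any index tensor with ndim other than 1 was rejected:

```python
from tvm.script import ir as I, relax as R

@I.ir_module
class IndexPutExample:
    @R.function
    def main(
        data: R.Tensor((32, 64), dtype="float32"),
        rows: R.Tensor((32, 1), dtype="int64"),  # 2-D index, allowed after this change
        cols: R.Tensor((10,), dtype="int64"),
        values: R.Tensor((32, 10), dtype="float32"),
    ) -> R.Tensor((32, 64), dtype="float32"):
        return R.index_put(data, R.tuple(rows, cols), values, accumulate=False)
```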
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index 01efb6b936..c4851973ea 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -6576,12 +6576,131 @@ def test_index_put():
                R.output(gv)
            return gv
+    # Test case 6: 2D input with multi-dimensional index (broadcasting)
+    # This tests the multi-dimensional index support with broadcasting
+    class IndexPutBroadcast1D(Module):
+        def forward(self, data, indices_1):
+            indices_0 = torch.arange(data.shape[0]).unsqueeze(1)
+            values = torch.ones(data.shape[0], len(indices_1), dtype=data.dtype)
+            return data.index_put_((indices_0, indices_1), values, accumulate=False)
+
+    example_args_broadcast1 = (
+        torch.randn(32, 64, dtype=torch.float32),
+        torch.randint(0, 64, (10,), dtype=torch.int64),
+    )
+
+    @I.ir_module
+    class ExpectedBroadcast1D:
+        @R.function
+        def main(
+            data: R.Tensor((32, 64), dtype="float32"),
+            indices_1: R.Tensor((10,), dtype="int64"),
+        ) -> R.Tuple(R.Tensor((32, 64), dtype="float32"), R.Tensor((32, 64), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((32,), dtype="int64") = R.arange(
+                    R.prim_value(0), R.prim_value(32), R.prim_value(1), dtype="int64"
+                )
+                lv1: R.Tensor((32, 1), dtype="int64") = R.expand_dims(lv, axis=[1])
+                lv2: R.Tensor((32, 10), dtype="float32") = R.full(
+                    R.shape([32, 10]), R.const(1.0, "float32"), dtype="float32"
+                )
+                lv3: R.Tensor((32, 64), dtype="float32") = R.index_put(
+                    data, R.tuple(lv1, indices_1), lv2, accumulate=False
+                )
+                gv: R.Tuple(
+                    R.Tensor((32, 64), dtype="float32"), R.Tensor((32, 64), dtype="float32")
+                ) = (lv3, lv3)
+                R.output(gv)
+            return gv
+
+    # Test case 7: 2D input with multi-dimensional index (second position)
+    class IndexPutBroadcast2D(Module):
+        def forward(self, data, indices_0):
+            indices_1 = torch.arange(data.shape[1]).unsqueeze(1)
+            values = torch.ones(len(indices_0), data.shape[1], dtype=data.dtype)
+            return data.index_put_((indices_0, indices_1), values, accumulate=False)
+
+    example_args_broadcast2 = (
+        torch.randn(32, 64, dtype=torch.float32),
+        torch.randint(0, 32, (10,), dtype=torch.int64),
+    )
+
+    @I.ir_module
+    class ExpectedBroadcast2D:
+        @R.function
+        def main(
+            data: R.Tensor((32, 64), dtype="float32"),
+            indices_0: R.Tensor((10,), dtype="int64"),
+        ) -> R.Tuple(R.Tensor((32, 64), dtype="float32"), R.Tensor((32, 64), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((64,), dtype="int64") = R.arange(
+                    R.prim_value(0), R.prim_value(64), R.prim_value(1), dtype="int64"
+                )
+                lv1: R.Tensor((64, 1), dtype="int64") = R.expand_dims(lv, axis=[1])
+                lv2: R.Tensor((10, 64), dtype="float32") = R.full(
+                    R.shape([10, 64]), R.const(1.0, "float32"), dtype="float32"
+                )
+                lv3: R.Tensor((32, 64), dtype="float32") = R.index_put(
+                    data, R.tuple(indices_0, lv1), lv2, accumulate=False
+                )
+                gv: R.Tuple(
+                    R.Tensor((32, 64), dtype="float32"), R.Tensor((32, 64), dtype="float32")
+                ) = (lv3, lv3)
+                R.output(gv)
+            return gv
+
+    # Test case 8: 3D input with mixed 1D and 2D indices
+    class IndexPutBroadcast3D(Module):
+        def forward(self, data, indices_1):
+            indices_0 = torch.arange(data.shape[0]).unsqueeze(1)
+            indices_2 = torch.arange(data.shape[2]).unsqueeze(1)
+            values = torch.ones(data.shape[0], len(indices_1), data.shape[2], dtype=data.dtype)
+            return data.index_put_((indices_0, indices_1, indices_2), values, accumulate=False)
+
+    example_args_broadcast3d = (
+        torch.randn(16, 32, 64, dtype=torch.float32),
+        torch.randint(0, 32, (10,), dtype=torch.int64),
+    )
+
+    @I.ir_module
+    class ExpectedBroadcast3D:
+        @R.function
+        def main(
+            data: R.Tensor((16, 32, 64), dtype="float32"),
+            indices_1: R.Tensor((10,), dtype="int64"),
+        ) -> R.Tuple(
+            R.Tensor((16, 32, 64), dtype="float32"), R.Tensor((16, 32, 64), dtype="float32")
+        ):
+            with R.dataflow():
+                lv: R.Tensor((16,), dtype="int64") = R.arange(
+                    R.prim_value(0), R.prim_value(16), R.prim_value(1), dtype="int64"
+                )
+                lv1: R.Tensor((16, 1), dtype="int64") = R.expand_dims(lv, axis=[1])
+                lv2: R.Tensor((64,), dtype="int64") = R.arange(
+                    R.prim_value(0), R.prim_value(64), R.prim_value(1), dtype="int64"
+                )
+                lv3: R.Tensor((64, 1), dtype="int64") = R.expand_dims(lv2, axis=[1])
+                lv4: R.Tensor((16, 10, 64), dtype="float32") = R.full(
+                    R.shape([16, 10, 64]), R.const(1.0, "float32"), dtype="float32"
+                )
+                lv5: R.Tensor((16, 32, 64), dtype="float32") = R.index_put(
+                    data, R.tuple(lv1, indices_1, lv3), lv4, accumulate=False
+                )
+                gv: R.Tuple(
+                    R.Tensor((16, 32, 64), dtype="float32"), R.Tensor((16, 32, 64), dtype="float32")
+                ) = (lv5, lv5)
+                R.output(gv)
+            return gv
+
    # Run verification for each case
    verify_model(IndexPut1D(), example_args_1d, {}, Expected1D)
    verify_model(IndexPut2D(), example_args_2d, {}, Expected2D)
    verify_model(IndexPut3D(), example_args_3d, {}, Expected3D)
    verify_model(IndexPut4D(), example_args_4d, {}, Expected4D)
    verify_model(IndexPut5D(), example_args_5d, {}, Expected5D)
+    verify_model(IndexPutBroadcast1D(), example_args_broadcast1, {}, ExpectedBroadcast1D)
+    verify_model(IndexPutBroadcast2D(), example_args_broadcast2, {}, ExpectedBroadcast2D)
+    verify_model(IndexPutBroadcast3D(), example_args_broadcast3d, {}, ExpectedBroadcast3D)
def test_flip():