This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 14e41c681a [Relax][ONNX] Support dynamic repeats for Tile (#18878)
14e41c681a is described below
commit 14e41c681ac7e65af7e1118f86c55af7f2834043
Author: YinHanke <[email protected]>
AuthorDate: Fri Mar 6 12:38:32 2026 +0800
[Relax][ONNX] Support dynamic repeats for Tile (#18878)
## Summary
Support dynamic `repeats` for ONNX Tile in the Relax frontend.
## Changes
- add a dynamic Tile conversion path for ONNX when `repeats` is a graph input
- expose `topi.dyn_tile` to the Python/packed TOPI interface
- add frontend tests for dynamic `repeats`
## Validation
- `tests/python/relax/test_frontend_onnx.py -k test_tile_dynamic_repeats -q`
- local end-to-end repro matches ONNX Runtime
## Issue
Fixes #18752
---
python/tvm/relax/frontend/onnx/onnx_frontend.py | 49 +++++++++++++++++++++++--
python/tvm/topi/transform.py | 22 +++++++++++
src/topi/transform.cc | 5 +++
tests/python/relax/test_frontend_onnx.py | 31 ++++++++++++++++
4 files changed, 104 insertions(+), 3 deletions(-)
diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index b3c2d06eab..3dc575ae77 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -1974,14 +1974,57 @@ class Pad(OnnxOpConverter):
class Tile(OnnxOpConverter):
"""Converts an onnx Tile node into an equivalent Relax expression."""
+ @staticmethod
+ def _tensor_length(expr):
+ shape = expr.struct_info.shape
+ if not isinstance(shape, relax.ShapeExpr):
+ return None
+
+ length = shape.values[0]
+ if not isinstance(length, tir.IntImm):
+ return None
+ return length.value
+
@classmethod
def _impl_v13(cls, bb, inputs, attr, params):
reps = get_constant(inputs[1], params)
if isinstance(reps, relax.Constant):
reps = reps.data.numpy().tolist()
- else:
- raise ValueError("Dynamic reps for Tile are supported yet.")
- return bb.emit_te(topi.tile, inputs[0], reps)
+ return bb.emit_te(topi.tile, inputs[0], reps)
+
+ data = inputs[0]
+ data_ndim = data.struct_info.ndim
+ reps_len = cls._tensor_length(reps)
+ if data_ndim == -1 or reps_len is None:
+ raise ValueError("Dynamic Tile requires known input rank and repeats length.")
+
+ if reps.struct_info.dtype != "int64":
+ reps = bb.normalize(relax.op.astype(reps, "int64"))
+
+ data_shape = bb.normalize(relax.op.shape_of(data))
+ data_shape_tensor = bb.normalize(relax.op.shape_to_tensor(data_shape))
+ output_shape_tensor = reps
+
+ if data_ndim > reps_len:
+ reps_prefix = relax.const(_np.ones((data_ndim - reps_len,), dtype="int64"), "int64")
+ output_shape_tensor = bb.normalize(
+ relax.op.concat([reps_prefix, output_shape_tensor], axis=0)
+ )
+ elif reps_len > data_ndim:
+ data_prefix = relax.const(_np.ones((reps_len - data_ndim,), dtype="int64"), "int64")
+ data_shape_tensor = bb.normalize(
+ relax.op.concat([data_prefix, data_shape_tensor], axis=0)
+ )
+
+ output_shape_tensor = bb.normalize(
+ relax.op.multiply(output_shape_tensor, data_shape_tensor)
+ )
+ output_shape = bb.normalize(relax.op.tensor_to_shape(output_shape_tensor))
+ output_shape_vars = [
+ tir.Var(f"tile_dim_{i}", "int64") for i in range(max(data_ndim, reps_len))
+ ]
+ bb.match_cast(output_shape, relax.ShapeStructInfo(output_shape_vars))
+ return bb.emit_te(topi.dyn_tile, data, output_shape_vars, reps_len)
class Expand(OnnxOpConverter):
diff --git a/python/tvm/topi/transform.py b/python/tvm/topi/transform.py
index bc187e3f26..a3e7366446 100644
--- a/python/tvm/topi/transform.py
+++ b/python/tvm/topi/transform.py
@@ -657,6 +657,28 @@ def tile(a, reps):
return cpp.tile(a, reps)
+def dyn_tile(a, new_shape, rdim):
+ """Repeats the whole array multiple times with dynamic output shape.
+
+ Parameters
+ ----------
+ a : tvm.te.Tensor
+ The tensor to be tiled.
+
+ new_shape : tuple of PrimExpr
+ The output shape after tiling.
+
+ rdim : int
+ The rank of the repeats input.
+
+ Returns
+ -------
+ ret : tvm.te.Tensor
+ """
+
+ return cpp.dyn_tile(a, new_shape, rdim)
+
+
def layout_transform(array, src_layout, dst_layout, schedule_rule="None"):
"""Transform the layout according to src_layout and dst_layout
diff --git a/src/topi/transform.cc b/src/topi/transform.cc
index 5e2ffd4cbd..09f9a9be5e 100644
--- a/src/topi/transform.cc
+++ b/src/topi/transform.cc
@@ -156,6 +156,11 @@ TVM_FFI_STATIC_INIT_BLOCK() {
[](ffi::PackedArgs args, ffi::Any* rv) {
*rv = tile(args[0].cast<te::Tensor>(),
args[1].cast<ffi::Array<Integer>>());
})
+ .def_packed("topi.dyn_tile",
+ [](ffi::PackedArgs args, ffi::Any* rv) {
+ *rv = dyn_tile(args[0].cast<te::Tensor>(), args[1].cast<ffi::Array<PrimExpr>>(),
+ args[2].cast<int>());
+ })
.def_packed("topi.gather",
[](ffi::PackedArgs args, ffi::Any* rv) {
*rv = gather(args[0].cast<te::Tensor>(),
args[1].cast<int>(),
diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py
index 12f9f0f353..ecbc6c9e8a 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -2700,6 +2700,37 @@ def test_tile(dynamic):
verify_tile(x.shape, repeats, z_array.shape)
+@pytest.mark.parametrize("dynamic_input", [True, False])
+@pytest.mark.parametrize(
+ "in_shape,repeats",
+ [
+ ((2, 3), np.array([2, 2], dtype=np.int64)),
+ ((2, 3, 4), np.array([2, 2, 1], dtype=np.int64)),
+ ((2, 3, 4, 5), np.array([1, 2, 1, 2], dtype=np.int64)),
+ ],
+)
+def test_tile_dynamic_repeats(dynamic_input, in_shape, repeats):
+ x = np.random.rand(*in_shape).astype(np.float32)
+ out_shape = np.tile(x, repeats).shape
+
+ input_shape = ["?" for _ in in_shape] if dynamic_input else list(x.shape)
+ output_shape = ["?" for _ in out_shape] if dynamic_input else list(out_shape)
+
+ node = helper.make_node("Tile", inputs=["input", "repeats"], outputs=["out"])
+ graph = helper.make_graph(
+ [node],
+ "tile_dynamic_repeats_test",
+ inputs=[
+ helper.make_tensor_value_info("input", TensorProto.FLOAT, input_shape),
+ helper.make_tensor_value_info("repeats", TensorProto.INT64, [len(repeats)]),
+ ],
+ outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, output_shape)],
+ )
+ model = helper.make_model(graph, producer_name="tile_dynamic_repeats_test")
+
+ check_correctness(model, inputs={"input": x, "repeats": repeats}, opset=13)
+
+
def _generate_roi_cases():
# Base case when with_roi is False
roi_list = [