This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 14e41c681a [Relax][ONNX] Support dynamic repeats for Tile (#18878)
14e41c681a is described below

commit 14e41c681ac7e65af7e1118f86c55af7f2834043
Author: YinHanke <[email protected]>
AuthorDate: Fri Mar 6 12:38:32 2026 +0800

    [Relax][ONNX] Support dynamic repeats for Tile (#18878)
    
    ## Summary
    
    Support dynamic `repeats` for ONNX Tile in the Relax frontend.
    
    ## Changes
    
    - add a dynamic Tile conversion path for ONNX when `repeats` is a graph input
    - expose `topi.dyn_tile` to the Python/packed TOPI interface
    - add frontend tests for dynamic `repeats`
    
    ## Validation
    
    - `tests/python/relax/test_frontend_onnx.py -k test_tile_dynamic_repeats -q`
    - a local end-to-end repro matches ONNX Runtime output
    
    ## Issue
    Fixes #18752
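    
    ## Example
    
    A minimal sketch of the case this enables (illustrative only; the graph is
    built with `onnx.helper` and imported through the Relax `from_onnx` entry
    point):
    
    ```python
    from onnx import TensorProto, helper
    from tvm.relax.frontend.onnx import from_onnx
    
    # `repeats` is a graph input, so its values are unknown at import time.
    node = helper.make_node("Tile", ["x", "repeats"], ["y"])
    graph = helper.make_graph(
        [node],
        "dyn_tile_example",
        inputs=[
            helper.make_tensor_value_info("x", TensorProto.FLOAT, [2, 3]),
            helper.make_tensor_value_info("repeats", TensorProto.INT64, [2]),
        ],
        outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, ["?", "?"])],
    )
    mod = from_onnx(helper.make_model(graph))  # previously raised ValueError
    ```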
---
 python/tvm/relax/frontend/onnx/onnx_frontend.py | 49 +++++++++++++++++++++++--
 python/tvm/topi/transform.py                    | 22 +++++++++++
 src/topi/transform.cc                           |  5 +++
 tests/python/relax/test_frontend_onnx.py        | 31 ++++++++++++++++
 4 files changed, 104 insertions(+), 3 deletions(-)

diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index b3c2d06eab..3dc575ae77 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -1974,14 +1974,57 @@ class Pad(OnnxOpConverter):
 class Tile(OnnxOpConverter):
     """Converts an onnx Tile node into an equivalent Relax expression."""
 
+    @staticmethod
+    def _tensor_length(expr):
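+        """Return the static length of a 1-D tensor, or None if unknown."""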
+        shape = expr.struct_info.shape
+        if not isinstance(shape, relax.ShapeExpr):
+            return None
+
+        length = shape.values[0]
+        if not isinstance(length, tir.IntImm):
+            return None
+        return length.value
+
     @classmethod
     def _impl_v13(cls, bb, inputs, attr, params):
         reps = get_constant(inputs[1], params)
         if isinstance(reps, relax.Constant):
             reps = reps.data.numpy().tolist()
-        else:
-            raise ValueError("Dynamic reps for Tile are supported yet.")
-        return bb.emit_te(topi.tile, inputs[0], reps)
+            return bb.emit_te(topi.tile, inputs[0], reps)
+
+        data = inputs[0]
+        data_ndim = data.struct_info.ndim
+        reps_len = cls._tensor_length(reps)
+        if data_ndim == -1 or reps_len is None:
+            raise ValueError("Dynamic Tile requires known input rank and repeats length.")
+
+        if reps.struct_info.dtype != "int64":
+            reps = bb.normalize(relax.op.astype(reps, "int64"))
+
+        data_shape = bb.normalize(relax.op.shape_of(data))
+        data_shape_tensor = bb.normalize(relax.op.shape_to_tensor(data_shape))
+        output_shape_tensor = reps
+
+        if data_ndim > reps_len:
+            reps_prefix = relax.const(_np.ones((data_ndim - reps_len,), dtype="int64"), "int64")
+            output_shape_tensor = bb.normalize(
+                relax.op.concat([reps_prefix, output_shape_tensor], axis=0)
+            )
+        elif reps_len > data_ndim:
+            data_prefix = relax.const(_np.ones((reps_len - data_ndim,), dtype="int64"), "int64")
+            data_shape_tensor = bb.normalize(
+                relax.op.concat([data_prefix, data_shape_tensor], axis=0)
+            )
+
+        output_shape_tensor = bb.normalize(
+            relax.op.multiply(output_shape_tensor, data_shape_tensor)
+        )
+        output_shape = bb.normalize(relax.op.tensor_to_shape(output_shape_tensor))
+        output_shape_vars = [
+            tir.Var(f"tile_dim_{i}", "int64") for i in range(max(data_ndim, reps_len))
+        ]
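+        # match_cast binds the symbolic dims so they can be used as the TE output shape.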
+        bb.match_cast(output_shape, relax.ShapeStructInfo(output_shape_vars))
+        return bb.emit_te(topi.dyn_tile, data, output_shape_vars, reps_len)
 
 
 class Expand(OnnxOpConverter):
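
For reference, the output-shape computation the converter above performs,
mirrored in plain NumPy (an illustrative sketch; `tile_output_shape` is a
hypothetical name, not part of this change):

    import numpy as np

    def tile_output_shape(data_shape, repeats):
        # Left-pad the shorter of the two with ones, then multiply
        # elementwise, the same rule np.tile applies.
        ndim = max(len(data_shape), len(repeats))
        data_shape = [1] * (ndim - len(data_shape)) + list(data_shape)
        repeats = [1] * (ndim - len(repeats)) + list(repeats)
        return [d * r for d, r in zip(data_shape, repeats)]

    assert tile_output_shape((2, 3), [2, 2]) == [4, 6]
    assert np.tile(np.ones((2, 3)), [2, 2]).shape == (4, 6)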
diff --git a/python/tvm/topi/transform.py b/python/tvm/topi/transform.py
index bc187e3f26..a3e7366446 100644
--- a/python/tvm/topi/transform.py
+++ b/python/tvm/topi/transform.py
@@ -657,6 +657,28 @@ def tile(a, reps):
     return cpp.tile(a, reps)
 
 
+def dyn_tile(a, new_shape, rdim):
+    """Repeats the whole array multiple times with dynamic output shape.
+
+    Parameters
+    ----------
+    a : tvm.te.Tensor
+        The tensor to be tiled.
+
+    new_shape : tuple of PrimExpr
+        The output shape after tiling.
+
+    rdim : int
+        The length of the repeats vector.
+
+    Returns
+    -------
+    ret : tvm.te.Tensor
+        The tiled result tensor.
+    """
+
+    return cpp.dyn_tile(a, new_shape, rdim)
+
+
 def layout_transform(array, src_layout, dst_layout, schedule_rule="None"):
     """Transform the layout according to src_layout and dst_layout
 
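
A usage sketch for the new wrapper (assumes this commit is applied; in the
frontend above, the symbolic output dims come from a match_cast):

    from tvm import te, tir, topi

    a = te.placeholder((2, 3), name="a", dtype="float32")
    d0 = tir.Var("d0", "int64")
    d1 = tir.Var("d1", "int64")
    # new_shape is symbolic; rdim is the length of the repeats vector.
    out = topi.dyn_tile(a, (d0, d1), 2)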
diff --git a/src/topi/transform.cc b/src/topi/transform.cc
index 5e2ffd4cbd..09f9a9be5e 100644
--- a/src/topi/transform.cc
+++ b/src/topi/transform.cc
@@ -156,6 +156,11 @@ TVM_FFI_STATIC_INIT_BLOCK() {
                   [](ffi::PackedArgs args, ffi::Any* rv) {
                    *rv = tile(args[0].cast<te::Tensor>(), args[1].cast<ffi::Array<Integer>>());
                   })
+      .def_packed("topi.dyn_tile",
+                  [](ffi::PackedArgs args, ffi::Any* rv) {
+                    *rv = dyn_tile(args[0].cast<te::Tensor>(), args[1].cast<ffi::Array<PrimExpr>>(),
+                                   args[2].cast<int>());
+                  })
       .def_packed("topi.gather",
                   [](ffi::PackedArgs args, ffi::Any* rv) {
                    *rv = gather(args[0].cast<te::Tensor>(), args[1].cast<int>(),
diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py
index 12f9f0f353..ecbc6c9e8a 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -2700,6 +2700,37 @@ def test_tile(dynamic):
     verify_tile(x.shape, repeats, z_array.shape)
 
 
+@pytest.mark.parametrize("dynamic_input", [True, False])
+@pytest.mark.parametrize(
+    "in_shape,repeats",
+    [
+        ((2, 3), np.array([2, 2], dtype=np.int64)),
+        ((2, 3, 4), np.array([2, 2, 1], dtype=np.int64)),
+        ((2, 3, 4, 5), np.array([1, 2, 1, 2], dtype=np.int64)),
+    ],
+)
+def test_tile_dynamic_repeats(dynamic_input, in_shape, repeats):
+    x = np.random.rand(*in_shape).astype(np.float32)
+    out_shape = np.tile(x, repeats).shape
+
+    input_shape = ["?" for _ in in_shape] if dynamic_input else list(x.shape)
+    output_shape = ["?" for _ in out_shape] if dynamic_input else list(out_shape)
+
+    node = helper.make_node("Tile", inputs=["input", "repeats"], outputs=["out"])
+    graph = helper.make_graph(
+        [node],
+        "tile_dynamic_repeats_test",
+        inputs=[
+            helper.make_tensor_value_info("input", TensorProto.FLOAT, 
input_shape),
+            helper.make_tensor_value_info("repeats", TensorProto.INT64, 
[len(repeats)]),
+        ],
+        outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, output_shape)],
+    )
+    model = helper.make_model(graph, producer_name="tile_dynamic_repeats_test")
+
+    check_correctness(model, inputs={"input": x, "repeats": repeats}, opset=13)
+
+
 def _generate_roi_cases():
     # Base case when with_roi is False
     roi_list = [
