This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 9e3eaf39a0 [Relax][Frontend][TFLite] Support dynamic DYNAMIC_UPDATE_SLICE starts (#19881)
9e3eaf39a0 is described below

commit 9e3eaf39a091c32fdc5a016413ca8a9e6675b59f
Author: Hongyi Wu <[email protected]>
AuthorDate: Fri Jun 26 01:30:25 2026 +0800

    [Relax][Frontend][TFLite] Support dynamic DYNAMIC_UPDATE_SLICE starts (#19881)
    
    ## Summary
    
    This PR adds Relax TFLite frontend support for runtime (dynamic) start
    indices in `STABLEHLO_DYNAMIC_UPDATE_SLICE`, addressing the
    `DYNAMIC_UPDATE_SLICE` item from #19412 section B.
    
    `_convert_stablehlo_dynamic_update_slice` (added in #19587) previously
    raised `OpNotImplemented` when the start-index scalars were runtime
    (non-constant) values, handling only compile-time-constant starts.
    Models that compute the update offset at runtime could therefore not be
    imported. This PR makes the dynamic-start path work, with StableHLO
    clamping semantics, without adding a new Relax op. The change is limited
    to this converter and its test.
    
    ## Design
    
    ### Dynamic start indices via scatter_nd
    
    The existing static path already lowers `STABLEHLO_DYNAMIC_UPDATE_SLICE`
    to `relax.op.scatter_nd`, building the scatter index grid at compile time
    with `numpy.indices`. `scatter_nd` accepts a general **runtime** `indices`
    tensor and returns the `data` (operand) shape unchanged, so the dynamic
    case needs no new op and introduces no symbolic dimensions: only the
    index grid is built in-graph instead of in NumPy.
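    
    As an illustration of the static-path index grid, the NumPy sketch below
    builds `(*update_shape, rank)` coordinates with `np.indices` and applies
    them the way a `scatter_nd` "update" reduction would. The helpers
    `static_index_grid` and `scatter_nd_update` are hypothetical names for
    this sketch, not TVM APIs, and the fancy-index assignment only emulates
    `relax.op.scatter_nd` for non-overlapping indices.
    
    ```python
    import numpy as np
    
    def scatter_nd_update(data, indices, updates):
        # Hypothetical emulation of scatter_nd(data, indices, updates, "update"):
        # the last axis of `indices` holds a full coordinate into `data`.
        out = data.copy()
        coords = tuple(np.moveaxis(indices, -1, 0))  # one index array per data axis
        out[coords] = updates
        return out
    
    def static_index_grid(starts, update_shape):
        # np.indices yields shape (rank, *update_shape); offset each axis by its start.
        grid = np.indices(update_shape, dtype="int64")
        for axis, start in enumerate(starts):
            grid[axis] += start
        # Move the coordinate axis last: (*update_shape, rank).
        return np.moveaxis(grid, 0, -1)
    
    operand = np.zeros((3, 4), dtype="float32")
    update = np.ones((2, 2), dtype="float32")
    indices = static_index_grid([1, 2], update.shape)
    result = scatter_nd_update(operand, indices, update)  # fills rows 1-2, cols 2-3
    ```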
    
    For runtime starts, the converter builds the index grid per axis `a`
    (rank is statically known from the operand/update shapes):
    
    - cast the start scalar to int64 with `relax.op.astype`;
    - clamp the start to `[0, operand_dim - update_dim]` with
      `relax.op.maximum` / `relax.op.minimum` (StableHLO clamps out-of-range
      starts rather than erroring);
    - `idx = arange(update_dim) + clamped_start`;
    - reshape `idx` to broadcast on axis `a` and `broadcast_to` the update
      shape;
    - `expand_dims` a trailing index axis.
    
    `concat` over the axes produces an int64 index tensor of shape
    `(*update_shape, rank)`, which is fed to the same
    `relax.op.scatter_nd(operand, indices, update, "update")` call that the
    static path uses.
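    
    The per-axis steps above can be mirrored in NumPy to check the shape and
    clamping of the in-graph grid. `dynamic_index_grid` is a hypothetical
    helper written for this sketch; the converter expresses the same steps
    as Relax ops (`astype`, `maximum`, `minimum`, `arange`, `reshape`,
    `broadcast_to`, `expand_dims`, `concat`) over runtime start tensors.
    
    ```python
    import numpy as np
    
    def dynamic_index_grid(starts, operand_shape, update_shape):
        # Hypothetical NumPy mirror of the in-graph index grid (a sketch).
        rank = len(update_shape)
        axis_indices = []
        for axis in range(rank):
            # Clamp the start to [0, operand_dim - update_dim]: maximum, then minimum.
            max_start = operand_shape[axis] - update_shape[axis]
            start = np.minimum(np.maximum(np.int64(starts[axis]), 0), max_start)
            idx = np.arange(update_shape[axis], dtype="int64") + start
            # Lay the values along `axis`, then broadcast over the other axes.
            bshape = [1] * rank
            bshape[axis] = update_shape[axis]
            idx = np.broadcast_to(idx.reshape(bshape), update_shape)
            # Trailing coordinate axis so the per-axis grids can be concatenated.
            axis_indices.append(np.expand_dims(idx, axis=-1))
        return np.concatenate(axis_indices, axis=-1)  # (*update_shape, rank)
    
    operand = np.zeros((3, 4), dtype="float32")
    update = np.arange(1, 5, dtype="float32").reshape(2, 2)
    indices = dynamic_index_grid([5, -1], operand.shape, update.shape)
    result = operand.copy()
    result[tuple(np.moveaxis(indices, -1, 0))] = update  # emulates scatter_nd "update"
    ```
    
    With starts `(5, -1)` on a `(3, 4)` operand and a `(2, 2)` update, both
    starts are out of range and clamp to `(1, 0)`, so the update lands in
    rows 1-2, columns 0-1.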
    
    The static (constant-start) path is unchanged, including its
    compile-time out-of-bounds rejection.
    
    ## Operator Support
    
    | Operator | TFLite inputs | Relax lowering | Supported subset |
    |---|---|---|---|
    | `STABLEHLO_DYNAMIC_UPDATE_SLICE` | `operand`, `update`, N scalar `start` indices | `relax.op.scatter_nd` with a NumPy index grid (constant starts) or an in-graph `arange` + clamp index grid (runtime starts) | static operand/update shapes; constant or runtime start indices |
    
    ## Not Included
    
    - Dynamic (non-static) operand or update shapes: the index grid is built
      from the statically known update shape, so operand/update shapes must
      be static. Runtime *start indices* are supported; runtime *tensor
      shapes* are not.
    
    ## Tests
    
    The dynamic-start test replaces the former
    `test_stablehlo_dynamic_update_slice_dynamic_starts_unsupported` test.
    It imports a model whose start indices are runtime inputs and checks it
    with `tvm.ir.assert_structural_equal` against an expected Relax module
    that builds the clamped `arange` + `broadcast_to` index grid and feeds it
    to `scatter_nd`. The static structural-equal and out-of-bounds tests are
    unchanged.
    
    | Test | Coverage |
    |---|---|
    | `test_stablehlo_dynamic_update_slice` | constant start indices, structural-equal (existing) |
    | `test_stablehlo_dynamic_update_slice_dynamic_starts` | runtime start indices, structural-equal against the expected clamp + index-grid IR |
    | `test_stablehlo_dynamic_update_slice_out_of_bounds_unsupported` | constant-start path rejects out-of-bounds updates (existing) |
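    
    A run-based check of the dynamic path (for example, executing the
    imported module on the Relax VM) needs a NumPy reference with the same
    clamping. One minimal sketch uses plain slicing;
    `dynamic_update_slice_ref` is a hypothetical helper, not taken from the
    test file.
    
    ```python
    import numpy as np
    
    def dynamic_update_slice_ref(operand, update, starts):
        """Slicing-based reference with StableHLO-style start clamping."""
        out = operand.copy()
        window = []
        for start, u_dim, o_dim in zip(starts, update.shape, operand.shape):
            # Clamp each start into [0, o_dim - u_dim] instead of raising.
            s = min(max(int(start), 0), o_dim - u_dim)
            window.append(slice(s, s + u_dim))
        out[tuple(window)] = update
        return out
    
    operand = np.zeros((3, 4), dtype="float32")
    update = np.ones((2, 2), dtype="float32")
    # Column start 7 is out of range for a width-2 update on 4 columns; it clamps to 2.
    ref = dynamic_update_slice_ref(operand, update, [0, 7])
    ```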
    
    Local validation:
    
    ```bash
    python -m ruff format --check \
      python/tvm/relax/frontend/tflite/tflite_frontend.py \
      tests/python/relax/test_frontend_tflite.py
    
    python -m ruff check \
      python/tvm/relax/frontend/tflite/tflite_frontend.py \
      tests/python/relax/test_frontend_tflite.py
    
    python -m pytest \
      tests/python/relax/test_frontend_tflite.py -k dynamic_update_slice -q
    
    python -m pytest \
      tests/python/relax/test_frontend_tflite.py -q
    ```
    
    Result:
    
    ```text
    ruff format --check: 2 files already formatted
    ruff check: All checks passed
    dynamic_update_slice tests: 3 passed, 555 deselected
    full TFLite pytest: 558 passed
    ```
    
    ## References
    
    - Issue #19412 section B: `DYNAMIC_UPDATE_SLICE`
    - PR #19587: introduced `STABLEHLO_DYNAMIC_UPDATE_SLICE` (constant
      starts) and multi-subgraph / StableHLO region support
---
 .../tvm/relax/frontend/tflite/tflite_frontend.py   | 52 +++++++++++++++++++---
 tests/python/relax/test_frontend_tflite.py         | 49 ++++++++++++++++----
 2 files changed, 87 insertions(+), 14 deletions(-)

diff --git a/python/tvm/relax/frontend/tflite/tflite_frontend.py b/python/tvm/relax/frontend/tflite/tflite_frontend.py
index f22786a4c4..e2ab3a7b27 100644
--- a/python/tvm/relax/frontend/tflite/tflite_frontend.py
+++ b/python/tvm/relax/frontend/tflite/tflite_frontend.py
@@ -3349,7 +3349,13 @@ class OperatorConverter:
         return self.bb.normalize(relax.op.dynamic_strided_slice(operand, begin, end, strides))
 
     def _convert_stablehlo_dynamic_update_slice(self, op):
-        """Convert STABLEHLO_DYNAMIC_UPDATE_SLICE to Relax for static starts."""
+        """Convert STABLEHLO_DYNAMIC_UPDATE_SLICE to Relax.
+
+        Lowers to ``relax.op.scatter_nd``. Constant start indices build the index
+        grid at compile time; runtime (dynamic) start indices build it in-graph
+        with ``arange`` + broadcast, clamping each start to
+        ``[0, operand_dim - update_dim]`` per StableHLO semantics.
+        """
         input_tensors = self.get_input_tensors(op)
         # operand + update + N start-index scalars
         assert len(input_tensors) >= 3, "input tensors length should be >= 3"
@@ -3368,11 +3374,21 @@ class OperatorConverter:
                 "STABLEHLO_DYNAMIC_UPDATE_SLICE requires operand, update, "
                 "and start-index ranks to match"
             )
+        for dim, size in zip(operand_shape, update_shape):
+            if size > dim:
+                raise tvm.error.OpNotImplemented(
+                    "STABLEHLO_DYNAMIC_UPDATE_SLICE update shape must be smaller than "
+                    "or equal to operand shape for all dimensions"
+                )
+
+        operand = self.get_tensor_expr(operand_tensor)
+        update = self.get_tensor_expr(update_tensor)
 
         if any(self.has_expr(t.tensor_idx) for t in start_tensors):
-            raise tvm.error.OpNotImplemented(
-                "STABLEHLO_DYNAMIC_UPDATE_SLICE with dynamic start indices is not supported"
+            indices = self._build_dynamic_update_slice_indices(
+                start_tensors, operand_shape, update_shape, rank
             )
+            return self.bb.normalize(relax.op.scatter_nd(operand, indices, update, "update"))
 
         start_vals = [int(np.asarray(self.get_tensor_value(t)).item()) for t in start_tensors]
         for start, size, dim in zip(start_vals, update_shape, operand_shape):
@@ -3387,11 +3403,37 @@ class OperatorConverter:
             update_indices[axis] += start
         update_indices = np.moveaxis(update_indices, 0, -1)
 
-        operand = self.get_tensor_expr(operand_tensor)
-        update = self.get_tensor_expr(update_tensor)
         indices = self.bb.normalize(relax.const(update_indices, dtype="int64"))
         return self.bb.normalize(relax.op.scatter_nd(operand, indices, update, "update"))
 
+    def _build_dynamic_update_slice_indices(self, start_tensors, operand_shape, update_shape, rank):
+        """Build the scatter_nd index grid for runtime DYNAMIC_UPDATE_SLICE starts.
+
+        Returns an int64 tensor of shape ``(*update_shape, rank)`` where axis ``a``
+        holds ``arange(update_shape[a]) + clamp(start[a], 0, operand_dim - update_dim)``,
+        broadcast over the other axes (StableHLO clamps out-of-range starts).
+        """
+        axis_indices = []
+        for axis in range(rank):
+            start_expr = self.bb.normalize(
+                relax.op.astype(self.get_tensor_expr(start_tensors[axis]), "int64")
+            )
+            max_start = operand_shape[axis] - update_shape[axis]
+            start_expr = relax.op.maximum(start_expr, relax.const(0, "int64"))
+            start_expr = relax.op.minimum(start_expr, relax.const(max_start, "int64"))
+
+            base = relax.op.arange(0, update_shape[axis], 1, "int64")
+            idx = relax.op.add(base, start_expr)
+
+            broadcast_shape = [1] * rank
+            broadcast_shape[axis] = update_shape[axis]
+            idx = self.bb.normalize(relax.op.reshape(idx, broadcast_shape))
+            idx = self.bb.normalize(relax.op.broadcast_to(idx, update_shape))
+            idx = self.bb.normalize(relax.op.expand_dims(idx, axis=-1))
+            axis_indices.append(idx)
+
+        return self.bb.normalize(relax.op.concat(axis_indices, axis=-1))
+
     def _convert_stablehlo_dot_general(self, op):
         """Convert the canonical 2D STABLEHLO_DOT_GENERAL subset to Relax matmul."""
         from tflite.StablehloDotGeneralOptions import StablehloDotGeneralOptions
diff --git a/tests/python/relax/test_frontend_tflite.py b/tests/python/relax/test_frontend_tflite.py
index 9f9d4a0e8a..c259900aef 100644
--- a/tests/python/relax/test_frontend_tflite.py
+++ b/tests/python/relax/test_frontend_tflite.py
@@ -10614,16 +10614,47 @@ def test_stablehlo_dynamic_update_slice():
     tvm.ir.assert_structural_equal(mod, Expected)
 
 
-def test_stablehlo_dynamic_update_slice_dynamic_starts_unsupported():
-    """TFLite StableHLO DYNAMIC_UPDATE_SLICE with runtime starts is unsupported."""
-    buf = _build_stablehlo_dynamic_update_slice_model([0, 0], dynamic_starts=True)
-    if hasattr(tflite.Model, "Model"):
-        tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0)
-    else:
-        tflite_model = tflite.Model.GetRootAsModel(buf, 0)
+def test_stablehlo_dynamic_update_slice_dynamic_starts():
+    """TFLite StableHLO DYNAMIC_UPDATE_SLICE with runtime starts lowers structurally."""
+    mod = _load_model_from_buffer(
+        _build_stablehlo_dynamic_update_slice_model([0, 0], dynamic_starts=True)
+    )
 
-    with pytest.raises(tvm.error.OpNotImplemented, match="dynamic start"):
-        from_tflite(tflite_model)
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            operand: R.Tensor((3, 4), dtype="float32"),
+            update: R.Tensor((2, 2), dtype="float32"),
+            s0: R.Tensor((), dtype="int32"),
+            s1: R.Tensor((), dtype="int32"),
+        ) -> R.Tensor((3, 4), dtype="float32"):
+            R.func_attr({"num_input": 4})
+            with R.dataflow():
+                lv: R.Tensor((2,), dtype="int64") = R.arange(0, 2, 1, dtype="int64")
+                lv1: R.Tensor((), dtype="int64") = R.astype(s0, dtype="int64")
+                lv2: R.Tensor((), dtype="int64") = R.maximum(lv1, R.const(0, "int64"))
+                lv3: R.Tensor((), dtype="int64") = R.minimum(lv2, R.const(1, "int64"))
+                lv4: R.Tensor((2,), dtype="int64") = R.add(lv, lv3)
+                lv5: R.Tensor((2, 1), dtype="int64") = R.reshape(lv4, (2, 1))
+                lv6: R.Tensor((2, 2), dtype="int64") = R.broadcast_to(lv5, (2, 2))
+                lv7: R.Tensor((2,), dtype="int64") = R.arange(0, 2, 1, dtype="int64")
+                lv8: R.Tensor((), dtype="int64") = R.astype(s1, dtype="int64")
+                lv9: R.Tensor((), dtype="int64") = R.maximum(lv8, R.const(0, "int64"))
+                lv10: R.Tensor((), dtype="int64") = R.minimum(lv9, R.const(2, "int64"))
+                lv11: R.Tensor((2,), dtype="int64") = R.add(lv7, lv10)
+                lv12: R.Tensor((1, 2), dtype="int64") = R.reshape(lv11, (1, 2))
+                lv13: R.Tensor((2, 2), dtype="int64") = R.broadcast_to(lv12, (2, 2))
+                lv14: R.Tensor((2, 2, 1), dtype="int64") = R.expand_dims(lv6, axis=[-1])
+                lv15: R.Tensor((2, 2, 1), dtype="int64") = R.expand_dims(lv13, axis=[-1])
+                lv16: R.Tensor((2, 2, 2), dtype="int64") = R.concat((lv14, lv15), axis=-1)
+                gv: R.Tensor((3, 4), dtype="float32") = R.scatter_nd(
+                    operand, lv16, update, reduction="update"
+                )
+                R.output(gv)
+            return gv
+
+    tvm.ir.assert_structural_equal(mod, Expected)
 
 
 def test_stablehlo_dynamic_update_slice_out_of_bounds_unsupported():
