This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 99488d992d [Relax][Frontend][TFLite] Support STABLEHLO_WHILE (#19646)
99488d992d is described below

commit 99488d992de65ac9e6299548c673c3ca95ef98c2
Author: HoYi <[email protected]>
AuthorDate: Sun May 31 13:45:09 2026 +0800

    [Relax][Frontend][TFLite] Support STABLEHLO_WHILE (#19646)
    
    ## Summary
    
    This PR adds Relax TFLite frontend support for the TFLite builtin
    `STABLEHLO_WHILE` operator.
    
    `STABLEHLO_WHILE` uses StableHLO `BuiltinOptions2` to reference its
    condition and body region subgraphs. Its loop semantics otherwise match
    the existing TFLite `WHILE` importer path: loop-carried tensors are
    passed to the cond/body subgraphs, the cond subgraph returns a scalar
    bool, and the body subgraph returns the updated loop state.
    
    ## Design
    
    ### Shared While Lowering
    
    The native TFLite `WHILE` converter is refactored through a shared
    `_convert_while_like` helper. Native `WHILE` and `STABLEHLO_WHILE` now
    share the same validation and lowering path after their options are
    parsed:
    
    - native `WHILE` reads `WhileOptions` from `BuiltinOptions`
    - `STABLEHLO_WHILE` reads `StablehloWhileOptions` from `BuiltinOptions2`
    
    Both paths lower the referenced cond/body subgraphs to private Relax
    functions and emit a recursive private Relax function for the loop.
    
    ### Boundary Validation
    
    `STABLEHLO_WHILE` reuses the same guard-first checks as native `WHILE`:
    
    - loop input count must match op output count
    - cond subgraph input metadata must match loop-carried tensors
    - cond subgraph must have exactly one output
    - cond output must be a scalar bool tensor
    - body subgraph input and output metadata must match loop-carried tensors
    - referenced cond/body subgraph indices must be valid non-main subgraphs
    
    The recursive loop-function cache key now includes the generated function
    prefix. This prevents native `WHILE` and `STABLEHLO_WHILE` from
    accidentally sharing a cached loop wrapper if they reference the same
    cond/body subgraph indices.
    
    ## Operator Support
    
    | Operator | TFLite options | Relax lowering | Supported subset |
    |---|---|---|---|
    | `STABLEHLO_WHILE` | `StablehloWhileOptions.CondSubgraphIndex()`, `BodySubgraphIndex()` from `BuiltinOptions2` | recursive private Relax function | tensor loop-carried state, scalar bool cond output, matching cond/body interfaces |
    
    ## Tests
    
    The tests manually build a minimal TFLite flatbuffer containing a
    `STABLEHLO_WHILE` op and compare the imported Relax IR with
    `tvm.ir.assert_structural_equal`. Unsupported patterns use
    `pytest.raises`.
    
    | Test | Coverage |
    |---|---|
    | `test_stablehlo_while` | basic `STABLEHLO_WHILE` recursive private function lowering |
    | `test_stablehlo_while_non_bool_condition_unsupported` | cond output scalar bool guard |
    | `test_stablehlo_while_invalid_index_unsupported` | invalid cond/body subgraph index guard |
    | `test_stablehlo_while_output_count_mismatch_unsupported` | body output arity guard |
    | `test_stablehlo_while_input_metadata_mismatch_unsupported` | cond subgraph input metadata guard |
    | `test_stablehlo_while_output_metadata_mismatch_unsupported` | body subgraph output metadata guard |
    
    Local validation:
    
    ```bash
    python -m py_compile \
      python/tvm/relax/frontend/tflite/tflite_frontend.py \
      tests/python/relax/test_frontend_tflite.py
    
    python -m ruff check \
      python/tvm/relax/frontend/tflite/tflite_frontend.py \
      tests/python/relax/test_frontend_tflite.py
    
    python -m pytest \
      tests/python/relax/test_frontend_tflite.py \
      -k stablehlo_while -q
    
    python -m pytest \
      tests/python/relax/test_frontend_tflite.py \
      -k stablehlo -q
    ```
    
    Result:
    
    ```text
    py_compile: passed
    ruff check: All checks passed
    stablehlo_while tests: 6 passed
    stablehlo tests: 84 passed
    ```
    
    ## References
    
    - Issue #19519 item I: remaining StableHLO operators in TFLite
    - PR #19587: StableHLO region-based ops and multi-subgraph model support
    - PR #19616: TFLite control-flow / multi-subgraph support
---
 .../tvm/relax/frontend/tflite/tflite_frontend.py   |  64 ++++--
 tests/python/relax/test_frontend_tflite.py         | 219 +++++++++++++++++++++
 2 files changed, 264 insertions(+), 19 deletions(-)

diff --git a/python/tvm/relax/frontend/tflite/tflite_frontend.py 
b/python/tvm/relax/frontend/tflite/tflite_frontend.py
index 7046e43bbe..45cd41ce5b 100644
--- a/python/tvm/relax/frontend/tflite/tflite_frontend.py
+++ b/python/tvm/relax/frontend/tflite/tflite_frontend.py
@@ -387,6 +387,7 @@ class OperatorConverter:
                 self._convert_stablehlo_binary, relax_op=_op.subtract
             ),
             "STABLEHLO_TANH": functools.partial(self._convert_stablehlo_unary, 
relax_op=_op.tanh),
+            "STABLEHLO_WHILE": self._convert_stablehlo_while,
             "SQUEEZE": self.convert_squeeze,
             "STRIDED_SLICE": self.convert_strided_slice,
             "SUB": functools.partial(self._convert_elemwise, 
relax_op=_op.subtract),
@@ -2161,6 +2162,19 @@ class OperatorConverter:
             relax.op.sort(data, axis=int(opts.Dimension()), 
descending=descending)
         )
 
+    def _convert_stablehlo_while(self, op):
+        """Convert STABLEHLO_WHILE to a recursive Relax private function."""
+        from tflite.StablehloWhileOptions import StablehloWhileOptions
+
+        opts = self._get_stablehlo_options(op, StablehloWhileOptions)
+        return self._convert_while_like(
+            op,
+            "STABLEHLO_WHILE",
+            int(opts.CondSubgraphIndex()),
+            int(opts.BodySubgraphIndex()),
+            "tflite_stablehlo_while",
+        )
+
     def _get_builtin_options(self, op, options_cls):
         """Parse BuiltinOptions for a TFLite builtin operator."""
         from tflite.BuiltinOptions import BuiltinOptions
@@ -2402,14 +2416,15 @@ class OperatorConverter:
         cond_func,
         body_func,
         body_subgraph,
+        function_prefix="tflite_while",
     ):
         """Lower a TFLite WHILE op into a recursive private Relax function."""
-        cache_key = (cond_subgraph_index, body_subgraph_index, loop_var_count)
+        cache_key = (function_prefix, cond_subgraph_index, 
body_subgraph_index, loop_var_count)
         lowered_while_functions = 
self.conversion_state["lowered_while_functions"]
         if cache_key in lowered_while_functions:
             return lowered_while_functions[cache_key]
 
-        loop_name = 
f"tflite_while_subgraph_{cond_subgraph_index}_{body_subgraph_index}"
+        loop_name = 
f"{function_prefix}_subgraph_{cond_subgraph_index}_{body_subgraph_index}"
         params, _ = self._get_subgraph_params(body_subgraph)
         dummy_body = self._make_tuple_or_single(params)
         module_builder = self.conversion_state["module_builder"]
@@ -2489,47 +2504,44 @@ class OperatorConverter:
         args = [self.get_tensor_expr(tensor) for tensor in input_tensors]
         return relax.Call(if_func, args)
 
-    def convert_while(self, op):
-        """Convert TFLite WHILE to a recursive Relax private function."""
-        from tflite.WhileOptions import WhileOptions
-
-        opts = self._get_builtin_options(op, WhileOptions)
-        cond_subgraph_index = int(opts.CondSubgraphIndex())
-        body_subgraph_index = int(opts.BodySubgraphIndex())
+    def _convert_while_like(
+        self, op, op_name, cond_subgraph_index, body_subgraph_index, 
function_prefix
+    ):
+        """Convert a TFLite while-like operator with referenced cond/body 
subgraphs."""
         input_tensors = self.get_input_tensors(op)
         output_tensors = self.get_output_tensors(op)
         loop_var_count = len(input_tensors)
         if loop_var_count == 0:
-            raise tvm.error.OpNotImplemented("WHILE requires loop-carried 
inputs")
+            raise tvm.error.OpNotImplemented(f"{op_name} requires loop-carried 
inputs")
         if len(output_tensors) != loop_var_count:
-            raise tvm.error.OpNotImplemented("WHILE output count must match 
input count")
+            raise tvm.error.OpNotImplemented(f"{op_name} output count must 
match input count")
 
         cond_subgraph = self._check_subgraph_interface(
             cond_subgraph_index,
-            "WHILE",
+            op_name,
             input_tensors=input_tensors,
             output_count=1,
         )
         body_subgraph = self._check_subgraph_interface(
             body_subgraph_index,
-            "WHILE",
+            op_name,
             input_tensors=input_tensors,
             output_tensors=input_tensors,
         )
         for input_tensor, output_tensor in zip(input_tensors, output_tensors):
-            self._check_tensor_metadata_match(input_tensor, output_tensor, 
"WHILE", "loop state")
+            self._check_tensor_metadata_match(input_tensor, output_tensor, 
op_name, "loop state")
         cond_output = cond_subgraph.Tensors(int(cond_subgraph.Outputs(0)))
-        self._require_scalar_bool_tensor(cond_output, "WHILE")
+        self._require_scalar_bool_tensor(cond_output, op_name)
 
         cond_func = self._lower_subgraph_to_function(
             cond_subgraph_index,
-            f"tflite_while_cond_subgraph_{cond_subgraph_index}",
-            op_name="WHILE",
+            f"{function_prefix}_cond_subgraph_{cond_subgraph_index}",
+            op_name=op_name,
         )
         body_func = self._lower_subgraph_to_function(
             body_subgraph_index,
-            f"tflite_while_body_subgraph_{body_subgraph_index}",
-            op_name="WHILE",
+            f"{function_prefix}_body_subgraph_{body_subgraph_index}",
+            op_name=op_name,
         )
 
         loop_gv = self._lower_while_to_function(
@@ -2539,11 +2551,25 @@ class OperatorConverter:
             cond_func,
             body_func,
             body_subgraph,
+            function_prefix=function_prefix,
         )
 
         args = [self.get_tensor_expr(tensor) for tensor in input_tensors]
         return relax.Call(loop_gv, args)
 
+    def convert_while(self, op):
+        """Convert TFLite WHILE to a recursive Relax private function."""
+        from tflite.WhileOptions import WhileOptions
+
+        opts = self._get_builtin_options(op, WhileOptions)
+        return self._convert_while_like(
+            op,
+            "WHILE",
+            int(opts.CondSubgraphIndex()),
+            int(opts.BodySubgraphIndex()),
+            "tflite_while",
+        )
+
     def convert_call_once(self, op):
         """Convert TFLite CALL_ONCE for no-op and resource-variable 
initialization subsets."""
         from tflite.CallOnceOptions import CallOnceOptions
diff --git a/tests/python/relax/test_frontend_tflite.py 
b/tests/python/relax/test_frontend_tflite.py
index 05a6c1e5e5..cc3a84e2fd 100644
--- a/tests/python/relax/test_frontend_tflite.py
+++ b/tests/python/relax/test_frontend_tflite.py
@@ -3695,6 +3695,7 @@ _tfl_stablehlo_reduce_opts = 
_get_tflite_schema_module("StablehloReduceOptions")
 _tfl_stablehlo_reduce_window_opts = 
_get_tflite_schema_module("StablehloReduceWindowOptions")
 _tfl_stablehlo_scatter_opts = 
_get_tflite_schema_module("StablehloScatterOptions")
 _tfl_stablehlo_sort_opts = _get_tflite_schema_module("StablehloSortOptions")
+_tfl_stablehlo_while_opts = _get_tflite_schema_module("StablehloWhileOptions")
 _tfl_call_options = _get_tflite_schema_module("CallOptions")
 _tfl_call_once_options = _get_tflite_schema_module("CallOnceOptions")
 _tfl_dimension_metadata = _get_tflite_schema_module("DimensionMetadata")
@@ -3946,6 +3947,17 @@ def _build_while_options(builder, cond_subgraph_index, 
body_subgraph_index):
     return _tfl_while_options.WhileOptionsEnd(builder)
 
 
+def _build_stablehlo_while_options(builder, cond_subgraph_index, 
body_subgraph_index):
+    _tfl_stablehlo_while_opts.StablehloWhileOptionsStart(builder)
+    _tfl_stablehlo_while_opts.StablehloWhileOptionsAddCondSubgraphIndex(
+        builder, cond_subgraph_index
+    )
+    _tfl_stablehlo_while_opts.StablehloWhileOptionsAddBodySubgraphIndex(
+        builder, body_subgraph_index
+    )
+    return _tfl_stablehlo_while_opts.StablehloWhileOptionsEnd(builder)
+
+
 def _build_call_once_options(builder, init_subgraph_index):
     _tfl_call_once_options.CallOnceOptionsStart(builder)
     _tfl_call_once_options.CallOnceOptionsAddInitSubgraphIndex(builder, 
init_subgraph_index)
@@ -6296,6 +6308,107 @@ def 
_build_stablehlo_scatter_model(reducer_name="STABLEHLO_ADD", update_window_d
     )
 
 
+def _build_stablehlo_while_model(
+    cond_subgraph_index=1,
+    body_subgraph_index=2,
+    cond_output_type=_tfl_tensor_type.BOOL,
+    cond_input_type=_tfl_tensor_type.INT32,
+    body_outputs=None,
+    body_input_type=_tfl_tensor_type.INT32,
+    body_output_type=_tfl_tensor_type.INT32,
+    main_output_type=_tfl_tensor_type.INT32,
+):
+    """Build a STABLEHLO_WHILE model incrementing an int32 scalar until i < 3 
is false."""
+    builder = flatbuffers.Builder(1024)
+
+    body_outputs = [2] if body_outputs is None else body_outputs
+    while_options = _build_stablehlo_while_options(
+        builder, cond_subgraph_index, body_subgraph_index
+    )
+    _tfl_stablehlo_compare_opts.StablehloCompareOptionsStart(builder)
+    _tfl_stablehlo_compare_opts.StablehloCompareOptionsAddComparisonDirection(
+        builder,
+        
_tfl_stablehlo_comp_dir.StablehloComparisonDirection.STABLEHLO_COMPARISON_DIRECTION_LT,
+    )
+    compare_opts = 
_tfl_stablehlo_compare_opts.StablehloCompareOptionsEnd(builder)
+    one = np.array(1, dtype=np.int32)
+    three = np.array(3, dtype=np.int32)
+
+    main_tensors = [
+        _build_tensor(builder, 0, [], tensor_type=_tfl_tensor_type.INT32),
+        _build_tensor(builder, 3, [], tensor_type=main_output_type),
+    ]
+    main_while = _build_operator(
+        builder,
+        0,
+        [0],
+        [1],
+        builtin_options2_type=_tfl_builtin_options2.StablehloWhileOptions,
+        builtin_options2=while_options,
+    )
+    main_subgraph = _build_subgraph(
+        builder,
+        tensors=main_tensors,
+        operators=[main_while],
+        inputs=[0],
+        outputs=[1],
+    )
+
+    cond_tensors = [
+        _build_tensor(builder, 0, [], tensor_type=cond_input_type),
+        _build_tensor(builder, 1, [], tensor_type=_tfl_tensor_type.INT32),
+        _build_tensor(builder, 3, [], tensor_type=cond_output_type),
+    ]
+    cond_compare = _build_operator(
+        builder,
+        1,
+        [0, 1],
+        [2],
+        builtin_options2_type=_tfl_builtin_options2.StablehloCompareOptions,
+        builtin_options2=compare_opts,
+    )
+    cond_subgraph = _build_subgraph(
+        builder,
+        tensors=cond_tensors,
+        operators=[cond_compare],
+        inputs=[0],
+        outputs=[2],
+    )
+
+    body_tensors = [
+        _build_tensor(builder, 0, [], tensor_type=body_input_type),
+        _build_tensor(builder, 2, [], tensor_type=_tfl_tensor_type.INT32),
+        _build_tensor(builder, 3, [], tensor_type=body_output_type),
+    ]
+    body_add = _build_operator(builder, 2, [0, 1], [2])
+    body_subgraph = _build_subgraph(
+        builder,
+        tensors=body_tensors,
+        operators=[body_add],
+        inputs=[0],
+        outputs=body_outputs,
+    )
+
+    operator_codes = [
+        _build_operator_code(builder, 
_get_stablehlo_builtin_operator("STABLEHLO_WHILE")),
+        _build_operator_code(builder, 
_get_stablehlo_builtin_operator("STABLEHLO_COMPARE")),
+        _build_operator_code(builder, 
_get_stablehlo_builtin_operator("STABLEHLO_ADD")),
+    ]
+    buffers = [
+        _build_buffer(builder),
+        _build_buffer(builder, three.tobytes()),
+        _build_buffer(builder, one.tobytes()),
+        _build_buffer(builder),
+    ]
+    return _finish_tflite_model(
+        builder,
+        subgraph=main_subgraph,
+        extra_subgraphs=[cond_subgraph, body_subgraph],
+        operator_codes=operator_codes,
+        buffers=buffers,
+    )
+
+
 def _build_stablehlo_composite_model(with_attributes=False, 
use_main_input_after_composite=False):
     """Build a STABLEHLO_COMPOSITE model that decomposes to 
STABLEHLO_NEGATE."""
     builder = flatbuffers.Builder(1024)
@@ -6699,6 +6812,112 @@ def test_stablehlo_scatter_update_window_unsupported():
         from_tflite(tflite_model)
 
 
+def test_stablehlo_while():
+    """TFLite STABLEHLO_WHILE lowers to a recursive Relax private function."""
+    mod = _load_model_from_buffer(_build_stablehlo_while_model())
+
+    @I.ir_module
+    class Expected:
+        @R.function(private=True)
+        def tflite_stablehlo_while_cond_subgraph_1(
+            tvmgen_tensor_0: R.Tensor((), dtype="int32"),
+        ) -> R.Tensor((), dtype="bool"):
+            with R.dataflow():
+                gv: R.Tensor((), dtype="bool") = R.less(tvmgen_tensor_0, 
R.const(3, "int32"))
+                R.output(gv)
+            return gv
+
+        @R.function(private=True)
+        def tflite_stablehlo_while_body_subgraph_2(
+            tvmgen_tensor_0: R.Tensor((), dtype="int32"),
+        ) -> R.Tensor((), dtype="int32"):
+            with R.dataflow():
+                gv: R.Tensor((), dtype="int32") = R.add(tvmgen_tensor_0, 
R.const(1, "int32"))
+                R.output(gv)
+            return gv
+
+        @R.function(private=True)
+        def tflite_stablehlo_while_subgraph_1_2(
+            tvmgen_tensor_0: R.Tensor((), dtype="int32"),
+        ) -> R.Tensor((), dtype="int32"):
+            cls = Expected
+            while_cond: R.Tensor((), dtype="bool") = 
cls.tflite_stablehlo_while_cond_subgraph_1(
+                tvmgen_tensor_0
+            )
+            if while_cond:
+                gv: R.Tensor((), dtype="int32") = 
cls.tflite_stablehlo_while_body_subgraph_2(
+                    tvmgen_tensor_0
+                )
+                gv1: R.Tensor((), dtype="int32") = 
cls.tflite_stablehlo_while_subgraph_1_2(gv)
+                cond_result: R.Tensor((), dtype="int32") = gv1
+            else:
+                cond_result: R.Tensor((), dtype="int32") = tvmgen_tensor_0
+            return cond_result
+
+        @R.function
+        def main(
+            tvmgen_tensor_0: R.Tensor((), dtype="int32"),
+        ) -> R.Tensor((), dtype="int32"):
+            R.func_attr({"num_input": 1})
+            cls = Expected
+            with R.dataflow():
+                gv: R.Tensor((), dtype="int32") = 
cls.tflite_stablehlo_while_subgraph_1_2(
+                    tvmgen_tensor_0
+                )
+                R.output(gv)
+            return gv
+
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def test_stablehlo_while_non_bool_condition_unsupported():
+    """STABLEHLO_WHILE rejects cond subgraphs that do not return scalar 
bool."""
+    with pytest.raises(
+        tvm.error.OpNotImplemented, match="STABLEHLO_WHILE requires a scalar 
bool condition"
+    ):
+        _load_model_from_buffer(
+            
_build_stablehlo_while_model(cond_output_type=_tfl_tensor_type.INT32)
+        )
+
+
+def test_stablehlo_while_invalid_index_unsupported():
+    """STABLEHLO_WHILE rejects invalid cond/body subgraph indices before 
lowering."""
+    with pytest.raises(
+        tvm.error.OpNotImplemented, match="STABLEHLO_WHILE requires a valid 
subgraph index"
+    ):
+        
_load_model_from_buffer(_build_stablehlo_while_model(cond_subgraph_index=3))
+
+
+def test_stablehlo_while_output_count_mismatch_unsupported():
+    """STABLEHLO_WHILE rejects body subgraphs whose output arity does not 
match loop vars."""
+    with pytest.raises(
+        tvm.error.OpNotImplemented, match="STABLEHLO_WHILE subgraph output 
count mismatch"
+    ):
+        _load_model_from_buffer(_build_stablehlo_while_model(body_outputs=[]))
+
+
+def test_stablehlo_while_input_metadata_mismatch_unsupported():
+    """STABLEHLO_WHILE rejects cond subgraph inputs whose metadata does not 
match loop vars."""
+    with pytest.raises(
+        tvm.error.OpNotImplemented,
+        match="STABLEHLO_WHILE subgraph input tensor metadata mismatch",
+    ):
+        _load_model_from_buffer(
+            
_build_stablehlo_while_model(cond_input_type=_tfl_tensor_type.FLOAT32)
+        )
+
+
+def test_stablehlo_while_output_metadata_mismatch_unsupported():
+    """STABLEHLO_WHILE rejects body outputs whose metadata does not match loop 
vars."""
+    with pytest.raises(
+        tvm.error.OpNotImplemented,
+        match="STABLEHLO_WHILE subgraph output tensor metadata mismatch",
+    ):
+        _load_model_from_buffer(
+            
_build_stablehlo_while_model(body_output_type=_tfl_tensor_type.FLOAT32)
+        )
+
+
 def test_stablehlo_composite():
     """TFLite StableHLO COMPOSITE inlines a simple decomposition subgraph."""
     mod = _load_model_from_buffer(_build_stablehlo_composite_model())

Reply via email to