This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 99488d992d [Relax][Frontend][TFLite] Support STABLEHLO_WHILE (#19646)
99488d992d is described below
commit 99488d992de65ac9e6299548c673c3ca95ef98c2
Author: HoYi <[email protected]>
AuthorDate: Sun May 31 13:45:09 2026 +0800
[Relax][Frontend][TFLite] Support STABLEHLO_WHILE (#19646)
## Summary
This PR adds Relax TFLite frontend support for the TFLite builtin
`STABLEHLO_WHILE` operator.
`STABLEHLO_WHILE` uses StableHLO `BuiltinOptions2` to reference its condition
and body region subgraphs. Its loop semantics otherwise match the existing
TFLite `WHILE` importer path: loop-carried tensors are passed to the cond/body
subgraphs, the cond subgraph returns a scalar bool, and the body subgraph
returns the updated loop state.
## Design
### Shared While Lowering
The native TFLite `WHILE` converter is refactored to use a shared
`_convert_while_like` helper. Native `WHILE` and `STABLEHLO_WHILE` now share
the same validation and lowering path after their options are parsed:
- native `WHILE` reads `WhileOptions` from `BuiltinOptions`
- `STABLEHLO_WHILE` reads `StablehloWhileOptions` from `BuiltinOptions2`
Both paths lower the referenced cond/body subgraphs to private Relax functions
and emit a recursive private Relax function for the loop.
### Boundary Validation
`STABLEHLO_WHILE` reuses the same guard-first checks as native `WHILE`:
- loop input count must match op output count
- cond subgraph input metadata must match loop-carried tensors
- cond subgraph must have exactly one output
- cond output must be a scalar bool tensor
- body subgraph input and output metadata must match loop-carried tensors
- referenced cond/body subgraph indices must be valid non-main subgraphs
The recursive loop-function cache key now includes the generated function
prefix. This prevents native `WHILE` and `STABLEHLO_WHILE` from accidentally
sharing a cached loop wrapper if they reference the same cond/body subgraph
indices.
## Operator Support
| Operator | TFLite options | Relax lowering | Supported subset |
|---|---|---|---|
| `STABLEHLO_WHILE` | `StablehloWhileOptions.CondSubgraphIndex()`, `BodySubgraphIndex()` from `BuiltinOptions2` | recursive private Relax function | tensor loop-carried state, scalar bool cond output, matching cond/body interfaces |
## Tests
The tests manually build a minimal StableHLO while TFLite flatbuffer and
compare the imported Relax IR with `tvm.ir.assert_structural_equal`.
Unsupported patterns use `pytest.raises`.
| Test | Coverage |
|---|---|
| `test_stablehlo_while` | basic `STABLEHLO_WHILE` recursive private function lowering |
| `test_stablehlo_while_non_bool_condition_unsupported` | cond output scalar bool guard |
| `test_stablehlo_while_invalid_index_unsupported` | invalid cond/body subgraph index guard |
| `test_stablehlo_while_output_count_mismatch_unsupported` | body output arity guard |
| `test_stablehlo_while_input_metadata_mismatch_unsupported` | cond subgraph input metadata guard |
| `test_stablehlo_while_output_metadata_mismatch_unsupported` | body subgraph output metadata guard |
Local validation:
```bash
python -m py_compile \
python/tvm/relax/frontend/tflite/tflite_frontend.py \
tests/python/relax/test_frontend_tflite.py
python -m ruff check \
python/tvm/relax/frontend/tflite/tflite_frontend.py \
tests/python/relax/test_frontend_tflite.py
python -m pytest \
tests/python/relax/test_frontend_tflite.py \
-k stablehlo_while -q
python -m pytest \
tests/python/relax/test_frontend_tflite.py \
-k stablehlo -q
```
Result:
```text
py_compile: passed
ruff check: All checks passed
stablehlo_while tests: 6 passed
stablehlo tests: 84 passed
```
## References
- Issue #19519 item I: remaining StableHLO operators in TFLite
- PR #19587: StableHLO region-based ops and multi-subgraph model support
- PR #19616: TFLite control-flow / multi-subgraph support
---
.../tvm/relax/frontend/tflite/tflite_frontend.py | 64 ++++--
tests/python/relax/test_frontend_tflite.py | 219 +++++++++++++++++++++
2 files changed, 264 insertions(+), 19 deletions(-)
diff --git a/python/tvm/relax/frontend/tflite/tflite_frontend.py
b/python/tvm/relax/frontend/tflite/tflite_frontend.py
index 7046e43bbe..45cd41ce5b 100644
--- a/python/tvm/relax/frontend/tflite/tflite_frontend.py
+++ b/python/tvm/relax/frontend/tflite/tflite_frontend.py
@@ -387,6 +387,7 @@ class OperatorConverter:
self._convert_stablehlo_binary, relax_op=_op.subtract
),
"STABLEHLO_TANH": functools.partial(self._convert_stablehlo_unary,
relax_op=_op.tanh),
+ "STABLEHLO_WHILE": self._convert_stablehlo_while,
"SQUEEZE": self.convert_squeeze,
"STRIDED_SLICE": self.convert_strided_slice,
"SUB": functools.partial(self._convert_elemwise,
relax_op=_op.subtract),
@@ -2161,6 +2162,19 @@ class OperatorConverter:
relax.op.sort(data, axis=int(opts.Dimension()),
descending=descending)
)
+ def _convert_stablehlo_while(self, op):
+ """Convert STABLEHLO_WHILE to a recursive Relax private function."""
+ from tflite.StablehloWhileOptions import StablehloWhileOptions
+
+ opts = self._get_stablehlo_options(op, StablehloWhileOptions)
+ return self._convert_while_like(
+ op,
+ "STABLEHLO_WHILE",
+ int(opts.CondSubgraphIndex()),
+ int(opts.BodySubgraphIndex()),
+ "tflite_stablehlo_while",
+ )
+
def _get_builtin_options(self, op, options_cls):
"""Parse BuiltinOptions for a TFLite builtin operator."""
from tflite.BuiltinOptions import BuiltinOptions
@@ -2402,14 +2416,15 @@ class OperatorConverter:
cond_func,
body_func,
body_subgraph,
+ function_prefix="tflite_while",
):
"""Lower a TFLite WHILE op into a recursive private Relax function."""
- cache_key = (cond_subgraph_index, body_subgraph_index, loop_var_count)
+ cache_key = (function_prefix, cond_subgraph_index,
body_subgraph_index, loop_var_count)
lowered_while_functions =
self.conversion_state["lowered_while_functions"]
if cache_key in lowered_while_functions:
return lowered_while_functions[cache_key]
- loop_name =
f"tflite_while_subgraph_{cond_subgraph_index}_{body_subgraph_index}"
+ loop_name =
f"{function_prefix}_subgraph_{cond_subgraph_index}_{body_subgraph_index}"
params, _ = self._get_subgraph_params(body_subgraph)
dummy_body = self._make_tuple_or_single(params)
module_builder = self.conversion_state["module_builder"]
@@ -2489,47 +2504,44 @@ class OperatorConverter:
args = [self.get_tensor_expr(tensor) for tensor in input_tensors]
return relax.Call(if_func, args)
- def convert_while(self, op):
- """Convert TFLite WHILE to a recursive Relax private function."""
- from tflite.WhileOptions import WhileOptions
-
- opts = self._get_builtin_options(op, WhileOptions)
- cond_subgraph_index = int(opts.CondSubgraphIndex())
- body_subgraph_index = int(opts.BodySubgraphIndex())
+ def _convert_while_like(
+ self, op, op_name, cond_subgraph_index, body_subgraph_index,
function_prefix
+ ):
+ """Convert a TFLite while-like operator with referenced cond/body
subgraphs."""
input_tensors = self.get_input_tensors(op)
output_tensors = self.get_output_tensors(op)
loop_var_count = len(input_tensors)
if loop_var_count == 0:
- raise tvm.error.OpNotImplemented("WHILE requires loop-carried
inputs")
+ raise tvm.error.OpNotImplemented(f"{op_name} requires loop-carried
inputs")
if len(output_tensors) != loop_var_count:
- raise tvm.error.OpNotImplemented("WHILE output count must match
input count")
+ raise tvm.error.OpNotImplemented(f"{op_name} output count must
match input count")
cond_subgraph = self._check_subgraph_interface(
cond_subgraph_index,
- "WHILE",
+ op_name,
input_tensors=input_tensors,
output_count=1,
)
body_subgraph = self._check_subgraph_interface(
body_subgraph_index,
- "WHILE",
+ op_name,
input_tensors=input_tensors,
output_tensors=input_tensors,
)
for input_tensor, output_tensor in zip(input_tensors, output_tensors):
- self._check_tensor_metadata_match(input_tensor, output_tensor,
"WHILE", "loop state")
+ self._check_tensor_metadata_match(input_tensor, output_tensor,
op_name, "loop state")
cond_output = cond_subgraph.Tensors(int(cond_subgraph.Outputs(0)))
- self._require_scalar_bool_tensor(cond_output, "WHILE")
+ self._require_scalar_bool_tensor(cond_output, op_name)
cond_func = self._lower_subgraph_to_function(
cond_subgraph_index,
- f"tflite_while_cond_subgraph_{cond_subgraph_index}",
- op_name="WHILE",
+ f"{function_prefix}_cond_subgraph_{cond_subgraph_index}",
+ op_name=op_name,
)
body_func = self._lower_subgraph_to_function(
body_subgraph_index,
- f"tflite_while_body_subgraph_{body_subgraph_index}",
- op_name="WHILE",
+ f"{function_prefix}_body_subgraph_{body_subgraph_index}",
+ op_name=op_name,
)
loop_gv = self._lower_while_to_function(
@@ -2539,11 +2551,25 @@ class OperatorConverter:
cond_func,
body_func,
body_subgraph,
+ function_prefix=function_prefix,
)
args = [self.get_tensor_expr(tensor) for tensor in input_tensors]
return relax.Call(loop_gv, args)
+ def convert_while(self, op):
+ """Convert TFLite WHILE to a recursive Relax private function."""
+ from tflite.WhileOptions import WhileOptions
+
+ opts = self._get_builtin_options(op, WhileOptions)
+ return self._convert_while_like(
+ op,
+ "WHILE",
+ int(opts.CondSubgraphIndex()),
+ int(opts.BodySubgraphIndex()),
+ "tflite_while",
+ )
+
def convert_call_once(self, op):
"""Convert TFLite CALL_ONCE for no-op and resource-variable
initialization subsets."""
from tflite.CallOnceOptions import CallOnceOptions
diff --git a/tests/python/relax/test_frontend_tflite.py
b/tests/python/relax/test_frontend_tflite.py
index 05a6c1e5e5..cc3a84e2fd 100644
--- a/tests/python/relax/test_frontend_tflite.py
+++ b/tests/python/relax/test_frontend_tflite.py
@@ -3695,6 +3695,7 @@ _tfl_stablehlo_reduce_opts =
_get_tflite_schema_module("StablehloReduceOptions")
_tfl_stablehlo_reduce_window_opts =
_get_tflite_schema_module("StablehloReduceWindowOptions")
_tfl_stablehlo_scatter_opts =
_get_tflite_schema_module("StablehloScatterOptions")
_tfl_stablehlo_sort_opts = _get_tflite_schema_module("StablehloSortOptions")
+_tfl_stablehlo_while_opts = _get_tflite_schema_module("StablehloWhileOptions")
_tfl_call_options = _get_tflite_schema_module("CallOptions")
_tfl_call_once_options = _get_tflite_schema_module("CallOnceOptions")
_tfl_dimension_metadata = _get_tflite_schema_module("DimensionMetadata")
@@ -3946,6 +3947,17 @@ def _build_while_options(builder, cond_subgraph_index,
body_subgraph_index):
return _tfl_while_options.WhileOptionsEnd(builder)
+def _build_stablehlo_while_options(builder, cond_subgraph_index,
body_subgraph_index):
+ _tfl_stablehlo_while_opts.StablehloWhileOptionsStart(builder)
+ _tfl_stablehlo_while_opts.StablehloWhileOptionsAddCondSubgraphIndex(
+ builder, cond_subgraph_index
+ )
+ _tfl_stablehlo_while_opts.StablehloWhileOptionsAddBodySubgraphIndex(
+ builder, body_subgraph_index
+ )
+ return _tfl_stablehlo_while_opts.StablehloWhileOptionsEnd(builder)
+
+
def _build_call_once_options(builder, init_subgraph_index):
_tfl_call_once_options.CallOnceOptionsStart(builder)
_tfl_call_once_options.CallOnceOptionsAddInitSubgraphIndex(builder,
init_subgraph_index)
@@ -6296,6 +6308,107 @@ def
_build_stablehlo_scatter_model(reducer_name="STABLEHLO_ADD", update_window_d
)
+def _build_stablehlo_while_model(
+ cond_subgraph_index=1,
+ body_subgraph_index=2,
+ cond_output_type=_tfl_tensor_type.BOOL,
+ cond_input_type=_tfl_tensor_type.INT32,
+ body_outputs=None,
+ body_input_type=_tfl_tensor_type.INT32,
+ body_output_type=_tfl_tensor_type.INT32,
+ main_output_type=_tfl_tensor_type.INT32,
+):
+ """Build a STABLEHLO_WHILE model incrementing an int32 scalar until i < 3
is false."""
+ builder = flatbuffers.Builder(1024)
+
+ body_outputs = [2] if body_outputs is None else body_outputs
+ while_options = _build_stablehlo_while_options(
+ builder, cond_subgraph_index, body_subgraph_index
+ )
+ _tfl_stablehlo_compare_opts.StablehloCompareOptionsStart(builder)
+ _tfl_stablehlo_compare_opts.StablehloCompareOptionsAddComparisonDirection(
+ builder,
+
_tfl_stablehlo_comp_dir.StablehloComparisonDirection.STABLEHLO_COMPARISON_DIRECTION_LT,
+ )
+ compare_opts =
_tfl_stablehlo_compare_opts.StablehloCompareOptionsEnd(builder)
+ one = np.array(1, dtype=np.int32)
+ three = np.array(3, dtype=np.int32)
+
+ main_tensors = [
+ _build_tensor(builder, 0, [], tensor_type=_tfl_tensor_type.INT32),
+ _build_tensor(builder, 3, [], tensor_type=main_output_type),
+ ]
+ main_while = _build_operator(
+ builder,
+ 0,
+ [0],
+ [1],
+ builtin_options2_type=_tfl_builtin_options2.StablehloWhileOptions,
+ builtin_options2=while_options,
+ )
+ main_subgraph = _build_subgraph(
+ builder,
+ tensors=main_tensors,
+ operators=[main_while],
+ inputs=[0],
+ outputs=[1],
+ )
+
+ cond_tensors = [
+ _build_tensor(builder, 0, [], tensor_type=cond_input_type),
+ _build_tensor(builder, 1, [], tensor_type=_tfl_tensor_type.INT32),
+ _build_tensor(builder, 3, [], tensor_type=cond_output_type),
+ ]
+ cond_compare = _build_operator(
+ builder,
+ 1,
+ [0, 1],
+ [2],
+ builtin_options2_type=_tfl_builtin_options2.StablehloCompareOptions,
+ builtin_options2=compare_opts,
+ )
+ cond_subgraph = _build_subgraph(
+ builder,
+ tensors=cond_tensors,
+ operators=[cond_compare],
+ inputs=[0],
+ outputs=[2],
+ )
+
+ body_tensors = [
+ _build_tensor(builder, 0, [], tensor_type=body_input_type),
+ _build_tensor(builder, 2, [], tensor_type=_tfl_tensor_type.INT32),
+ _build_tensor(builder, 3, [], tensor_type=body_output_type),
+ ]
+ body_add = _build_operator(builder, 2, [0, 1], [2])
+ body_subgraph = _build_subgraph(
+ builder,
+ tensors=body_tensors,
+ operators=[body_add],
+ inputs=[0],
+ outputs=body_outputs,
+ )
+
+ operator_codes = [
+ _build_operator_code(builder,
_get_stablehlo_builtin_operator("STABLEHLO_WHILE")),
+ _build_operator_code(builder,
_get_stablehlo_builtin_operator("STABLEHLO_COMPARE")),
+ _build_operator_code(builder,
_get_stablehlo_builtin_operator("STABLEHLO_ADD")),
+ ]
+ buffers = [
+ _build_buffer(builder),
+ _build_buffer(builder, three.tobytes()),
+ _build_buffer(builder, one.tobytes()),
+ _build_buffer(builder),
+ ]
+ return _finish_tflite_model(
+ builder,
+ subgraph=main_subgraph,
+ extra_subgraphs=[cond_subgraph, body_subgraph],
+ operator_codes=operator_codes,
+ buffers=buffers,
+ )
+
+
def _build_stablehlo_composite_model(with_attributes=False,
use_main_input_after_composite=False):
"""Build a STABLEHLO_COMPOSITE model that decomposes to
STABLEHLO_NEGATE."""
builder = flatbuffers.Builder(1024)
@@ -6699,6 +6812,112 @@ def test_stablehlo_scatter_update_window_unsupported():
from_tflite(tflite_model)
+def test_stablehlo_while():
+ """TFLite STABLEHLO_WHILE lowers to a recursive Relax private function."""
+ mod = _load_model_from_buffer(_build_stablehlo_while_model())
+
+ @I.ir_module
+ class Expected:
+ @R.function(private=True)
+ def tflite_stablehlo_while_cond_subgraph_1(
+ tvmgen_tensor_0: R.Tensor((), dtype="int32"),
+ ) -> R.Tensor((), dtype="bool"):
+ with R.dataflow():
+ gv: R.Tensor((), dtype="bool") = R.less(tvmgen_tensor_0,
R.const(3, "int32"))
+ R.output(gv)
+ return gv
+
+ @R.function(private=True)
+ def tflite_stablehlo_while_body_subgraph_2(
+ tvmgen_tensor_0: R.Tensor((), dtype="int32"),
+ ) -> R.Tensor((), dtype="int32"):
+ with R.dataflow():
+ gv: R.Tensor((), dtype="int32") = R.add(tvmgen_tensor_0,
R.const(1, "int32"))
+ R.output(gv)
+ return gv
+
+ @R.function(private=True)
+ def tflite_stablehlo_while_subgraph_1_2(
+ tvmgen_tensor_0: R.Tensor((), dtype="int32"),
+ ) -> R.Tensor((), dtype="int32"):
+ cls = Expected
+ while_cond: R.Tensor((), dtype="bool") =
cls.tflite_stablehlo_while_cond_subgraph_1(
+ tvmgen_tensor_0
+ )
+ if while_cond:
+ gv: R.Tensor((), dtype="int32") =
cls.tflite_stablehlo_while_body_subgraph_2(
+ tvmgen_tensor_0
+ )
+ gv1: R.Tensor((), dtype="int32") =
cls.tflite_stablehlo_while_subgraph_1_2(gv)
+ cond_result: R.Tensor((), dtype="int32") = gv1
+ else:
+ cond_result: R.Tensor((), dtype="int32") = tvmgen_tensor_0
+ return cond_result
+
+ @R.function
+ def main(
+ tvmgen_tensor_0: R.Tensor((), dtype="int32"),
+ ) -> R.Tensor((), dtype="int32"):
+ R.func_attr({"num_input": 1})
+ cls = Expected
+ with R.dataflow():
+ gv: R.Tensor((), dtype="int32") =
cls.tflite_stablehlo_while_subgraph_1_2(
+ tvmgen_tensor_0
+ )
+ R.output(gv)
+ return gv
+
+ tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def test_stablehlo_while_non_bool_condition_unsupported():
+ """STABLEHLO_WHILE rejects cond subgraphs that do not return scalar
bool."""
+ with pytest.raises(
+ tvm.error.OpNotImplemented, match="STABLEHLO_WHILE requires a scalar
bool condition"
+ ):
+ _load_model_from_buffer(
+
_build_stablehlo_while_model(cond_output_type=_tfl_tensor_type.INT32)
+ )
+
+
+def test_stablehlo_while_invalid_index_unsupported():
+ """STABLEHLO_WHILE rejects invalid cond/body subgraph indices before
lowering."""
+ with pytest.raises(
+ tvm.error.OpNotImplemented, match="STABLEHLO_WHILE requires a valid
subgraph index"
+ ):
+
_load_model_from_buffer(_build_stablehlo_while_model(cond_subgraph_index=3))
+
+
+def test_stablehlo_while_output_count_mismatch_unsupported():
+ """STABLEHLO_WHILE rejects body subgraphs whose output arity does not
match loop vars."""
+ with pytest.raises(
+ tvm.error.OpNotImplemented, match="STABLEHLO_WHILE subgraph output
count mismatch"
+ ):
+ _load_model_from_buffer(_build_stablehlo_while_model(body_outputs=[]))
+
+
+def test_stablehlo_while_input_metadata_mismatch_unsupported():
+ """STABLEHLO_WHILE rejects cond subgraph inputs whose metadata does not
match loop vars."""
+ with pytest.raises(
+ tvm.error.OpNotImplemented,
+ match="STABLEHLO_WHILE subgraph input tensor metadata mismatch",
+ ):
+ _load_model_from_buffer(
+
_build_stablehlo_while_model(cond_input_type=_tfl_tensor_type.FLOAT32)
+ )
+
+
+def test_stablehlo_while_output_metadata_mismatch_unsupported():
+ """STABLEHLO_WHILE rejects body outputs whose metadata does not match loop
vars."""
+ with pytest.raises(
+ tvm.error.OpNotImplemented,
+ match="STABLEHLO_WHILE subgraph output tensor metadata mismatch",
+ ):
+ _load_model_from_buffer(
+
_build_stablehlo_while_model(body_output_type=_tfl_tensor_type.FLOAT32)
+ )
+
+
def test_stablehlo_composite():
"""TFLite StableHLO COMPOSITE inlines a simple decomposition subgraph."""
mod = _load_model_from_buffer(_build_stablehlo_composite_model())