Aharrypotter opened a new pull request, #19616:
URL: https://github.com/apache/tvm/pull/19616
## Summary
This PR adds Relax TFLite frontend support for the TFLite builtin
control-flow / multi-subgraph operator family from #19519 item F:
`CALL`, `IF`, `WHILE`, and `CALL_ONCE`.
It builds on the multi-subgraph import infrastructure merged in PR #19587.
The frontend already accepts TFLite models with extra subgraphs while converting only `Subgraphs(0)` into the Relax `main` function. This PR uses those extra subgraphs as callable or control-flow regions for the TFLite control-flow operators.

The supported subset is intentionally limited to pure tensor semantics, and unsupported patterns are guarded up front:
- `CALL` lowers a referenced TFLite subgraph to a private Relax function and
emits a direct call.
- `IF` lowers the then/else subgraphs to private Relax functions and emits a
private wrapper function containing Relax `If`.
- `WHILE` lowers the cond/body subgraphs to private Relax functions and emits a recursive private Relax function for the loop.
- `CALL_ONCE` supports the empty-init no-op subset and explicitly rejects non-empty or resource-like init patterns.

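
To make the intended semantics concrete, here is a plain-Python model of the wrapper shapes described in the list above. It is not the frontend's code: subgraphs are modelled as ordinary Python functions over tuples, and `make_if_wrapper` / `make_while_loop` are hypothetical names. The key idea is that a `WHILE` over loop-carried state behaves like a function that runs the body and then calls itself until the cond subgraph returns false.

```python
# Plain-Python model of the IF / WHILE lowering shapes (not TVM code).
# Subgraphs are modelled as functions that take and return tuples.


def make_if_wrapper(then_fn, else_fn):
    """Wrapper that selects a branch on a scalar bool condition."""

    def if_wrapper(cond, *args):
        if cond:
            return then_fn(*args)
        return else_fn(*args)

    return if_wrapper


def make_while_loop(cond_fn, body_fn):
    """Recursive function equivalent to a WHILE over loop-carried state."""

    def while_loop(*state):
        if cond_fn(*state):
            # Loop continues: run the body, then recurse on the new state.
            return while_loop(*body_fn(*state))
        # Loop exits: the current state becomes the op outputs.
        return state

    return while_loop


# Usage: count i up to 5 while accumulating acc += i.
loop = make_while_loop(lambda i, acc: i < 5, lambda i, acc: (i + 1, acc + i))
print(loop(0, 0))  # (5, 10)
```

Writing the loop as recursion is what lets a pure functional IR express `WHILE` without mutable state. In this Python model the recursion depth is bounded by the interpreter, so it is only suitable for short loops.
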
This PR does not model resource variable side effects. Those cases remain
explicitly guarded instead of being imported with incorrect pure functional
semantics.
## Design
### Shared Subgraph Lowering
The frontend now keeps shared conversion state across the main graph and
referenced subgraphs:
- `lowered_subgraphs`
- `lowered_if_functions`
- `lowered_while_functions`
- `lowering_stack`

Referenced pure tensor subgraphs are lowered through a recursive `OperatorConverter` using an isolated `ExprTable`, so subgraph tensor bindings cannot overwrite bindings from the main graph. Lowered subgraphs are cached by subgraph index and reused when the same region is referenced more than once.

Recursive ordinary `CALL` subgraphs are rejected with `OpNotImplemented`. `WHILE` instead uses a dedicated recursive wrapper function, because recursion is part of the intended Relax representation of the loop itself.
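
The caching and recursion-guard behaviour can be sketched as follows, assuming a simplified call graph that maps each referenced subgraph index to the indices it `CALL`s. `SubgraphLowering`, `call_graph`, and `lower_count` are hypothetical; only the roles of `lowered_subgraphs` and `lowering_stack` come from the description above, and `OpNotImplemented` is a local stand-in for the TVM exception.

```python
class OpNotImplemented(Exception):
    """Stand-in for tvm.error.OpNotImplemented in this sketch."""


class SubgraphLowering:
    """Hypothetical model of per-index caching plus a recursion guard for CALL."""

    def __init__(self, call_graph):
        # call_graph: subgraph index -> indices of the subgraphs it CALLs
        self.call_graph = call_graph
        self.lowered_subgraphs = {}  # subgraph index -> lowered function name
        self.lowering_stack = []  # indices currently being lowered
        self.lower_count = 0  # how many subgraphs were actually lowered

    def lower(self, index):
        if index not in self.call_graph:
            raise OpNotImplemented(f"invalid subgraph index {index}")
        if index in self.lowered_subgraphs:
            return self.lowered_subgraphs[index]  # reuse the cached function
        if index in self.lowering_stack:
            raise OpNotImplemented(f"recursive CALL of subgraph {index}")
        self.lowering_stack.append(index)
        try:
            for callee in self.call_graph[index]:  # lower callees first
                self.lower(callee)
            self.lower_count += 1
            name = f"subgraph_{index}"
            self.lowered_subgraphs[index] = name
            return name
        finally:
            self.lowering_stack.pop()


lowering = SubgraphLowering({1: [2], 2: [], 3: [3]})
print(lowering.lower(1), lowering.lower(1), lowering.lower_count)  # subgraph_1 subgraph_1 2
```

In this sketch, a subgraph that reaches itself through `CALL`, directly or via other subgraphs, hits the `lowering_stack` check instead of recursing forever, while a shared callee is lowered once and then served from the cache.
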
### Boundary Validation
The control-flow converters validate subgraph boundaries before lowering:
- referenced subgraph indices must be valid
- op input/output arity must match the referenced subgraph interface
- branch and loop tensor shape/dtype metadata must match the surrounding op
- `IF` and `WHILE` conditions must be scalar bool tensors
- `WHILE` loop-carried input/output tensors must have matching metadata

The shared `_check_subgraph_interface` helper is used by `CALL`, `IF`, and `WHILE` to keep arity and metadata checks consistent across the control-flow operators. `_require_scalar_bool_tensor` accepts both frontend `TensorWrapper` objects and raw TFLite tensors, so caller-side and referenced-subgraph condition checks use the same path.

These checks keep the first implementation conservative and make unsupported cases fail with targeted `OpNotImplemented` diagnostics.
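
A minimal sketch of the kind of checks listed above, assuming tensor metadata reduces to a shape tuple plus a dtype string and that a scalar has rank 0. `TensorMeta`, `check_subgraph_interface`, and `require_scalar_bool` are hypothetical stand-ins; the actual `_check_subgraph_interface` and `_require_scalar_bool_tensor` helpers work on TFLite tensor objects, and their signatures are not reproduced here.

```python
from dataclasses import dataclass


class OpNotImplemented(Exception):
    """Stand-in for tvm.error.OpNotImplemented in this sketch."""


@dataclass(frozen=True)
class TensorMeta:
    shape: tuple
    dtype: str


def require_scalar_bool(meta, what):
    # A condition must be a rank-0 bool tensor.
    if meta.shape != () or meta.dtype != "bool":
        raise OpNotImplemented(f"{what} must be a scalar bool tensor, got {meta}")


def check_subgraph_interface(op_inputs, op_outputs, sg_inputs, sg_outputs, op_name):
    # Arity must match exactly before metadata is compared position by position.
    if len(op_inputs) != len(sg_inputs) or len(op_outputs) != len(sg_outputs):
        raise OpNotImplemented(f"{op_name}: subgraph I/O arity mismatch")
    pairs = (("input", op_inputs, sg_inputs), ("output", op_outputs, sg_outputs))
    for kind, ours, theirs in pairs:
        for i, (a, b) in enumerate(zip(ours, theirs)):
            if a != b:
                raise OpNotImplemented(f"{op_name}: {kind} {i} mismatch: {a} vs {b}")


f32 = TensorMeta((2, 3), "float32")
check_subgraph_interface([f32], [f32], [f32], [f32], "CALL")  # passes silently
```

For `IF`, TFLite passes the first op input as the condition, so a caller would check that input with `require_scalar_bool` and compare only the remaining inputs against each branch subgraph's inputs.
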
### Tuple Outputs
TFLite `CALL`, `IF`, and `WHILE` may produce multiple output tensors. The
frontend maps those cases to Relax tuple returns:
```text
single output -> tensor expression
multi output -> Tuple(...)
op outputs -> TupleGetItem(...)
```
This keeps the single-output IR simple while covering multi-output calls,
multi-output branches, and multi-variable loop state.
## Operator Support
| Operator | TFLite options | Relax lowering | Supported subset |
|---|---|---|---|
| `CALL` | `CallOptions.Subgraph()` | private Relax function call | pure tensor subgraphs, single or multiple outputs |
| `IF` | `IfOptions.ThenSubgraphIndex()`, `ElseSubgraphIndex()` | private wrapper function containing Relax `If` | scalar bool condition, matching branch I/O metadata |
| `WHILE` | `WhileOptions.CondSubgraphIndex()`, `BodySubgraphIndex()` | recursive private Relax function | scalar bool cond output, tensor loop-carried state |
| `CALL_ONCE` | `CallOnceOptions.InitSubgraphIndex()` | no-op for empty init subgraph | empty init subgraph only |
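
The `CALL_ONCE` row amounts to a handful of early rejections. In this sketch, subgraphs are plain dicts with hypothetical `inputs`, `outputs`, and `operators` keys, and `convert_call_once` is an illustrative name rather than the frontend's converter; each rejection mirrors one of the `CALL_ONCE` guard tests in this PR, but the check order is an assumption.

```python
class OpNotImplemented(Exception):
    """Stand-in for tvm.error.OpNotImplemented in this sketch."""


def convert_call_once(op_inputs, op_outputs, init_index, subgraphs):
    """Accept only the empty-init no-op subset of CALL_ONCE."""
    if op_inputs or op_outputs:
        raise OpNotImplemented("CALL_ONCE with op inputs/outputs")
    if not 0 <= init_index < len(subgraphs):
        raise OpNotImplemented(f"invalid init subgraph index {init_index}")
    init = subgraphs[init_index]
    if init["inputs"] or init["outputs"]:
        raise OpNotImplemented("CALL_ONCE init subgraph with inputs/outputs")
    if init["operators"]:
        # Non-empty init usually means resource/variable setup: not modelled.
        raise OpNotImplemented("non-empty CALL_ONCE init subgraph")
    return None  # empty init: nothing to emit


subgraphs = [
    {"inputs": [0], "outputs": [1], "operators": ["ADD"]},  # main graph
    {"inputs": [], "outputs": [], "operators": []},  # empty init subgraph
]
print(convert_call_once([], [], 1, subgraphs))  # None
```
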
## Not Included
- Full `CALL_ONCE` resource/variable initialization semantics.
- Resource, variant, hashtable, or variable tensor support.
- TensorFlow-generated `tf.cond` / `tf.while_loop` smoke tests.
- Dynamic-shape loop-state refinements beyond the current static metadata checks.
## Tests
The tests manually build minimal TFLite flatbuffers and compare the imported Relax IR with `tvm.ir.assert_structural_equal`. Unsupported-boundary tests use `pytest.raises`.

| Test | Coverage |
|---|---|
| `test_call_subgraph` | basic `CALL` to a pure tensor subgraph |
| `test_call_subgraph_multi_output` | `CALL` tuple return and output binding |
| `test_call_subgraph_invalid_index_unsupported` | invalid `CALL` subgraph index |
| `test_call_subgraph_io_mismatch_unsupported` | `CALL` arity mismatch |
| `test_call_subgraph_output_metadata_mismatch_unsupported` | `CALL` output metadata guard |
| `test_if_subgraphs` | basic `IF` branch selection |
| `test_if_subgraphs_multi_output` | `IF` tuple branch returns |
| `test_if_subgraphs_non_bool_condition_unsupported` | `IF` condition dtype guard |
| `test_if_subgraphs_invalid_index_unsupported` | invalid then/else subgraph index |
| `test_if_subgraphs_output_count_mismatch_unsupported` | branch output count guard |
| `test_if_subgraphs_input_metadata_mismatch_unsupported` | branch input metadata guard |
| `test_if_subgraphs_output_metadata_mismatch_unsupported` | branch output metadata guard |
| `test_while_subgraphs` | basic recursive `WHILE` lowering |
| `test_while_subgraphs_repeated_cond_body_pair` | shared cond/body loop function cache |
| `test_while_subgraphs_two_loop_vars` | multi-variable loop state tuple path |
| `test_while_subgraphs_non_bool_condition_unsupported` | `WHILE` cond output dtype guard |
| `test_while_subgraphs_invalid_index_unsupported` | invalid cond/body subgraph index |
| `test_while_subgraphs_zero_loop_vars_unsupported` | zero-loop-var guard |
| `test_while_subgraphs_loop_state_metadata_mismatch_unsupported` | loop state metadata guard |
| `test_while_subgraphs_output_count_mismatch_unsupported` | body output count guard |
| `test_while_subgraphs_input_metadata_mismatch_unsupported` | cond/body input metadata guard |
| `test_while_subgraphs_output_metadata_mismatch_unsupported` | cond/body output metadata guard |
| `test_call_once_empty_init_subgraph` | empty `CALL_ONCE` no-op subset |
| `test_call_once_non_empty_init_subgraph_unsupported` | non-empty init subgraph guard |
| `test_call_once_inputs_outputs_unsupported` | `CALL_ONCE` op I/O guard |
| `test_call_once_init_subgraph_io_unsupported` | init subgraph I/O guard |
| `test_call_once_invalid_index_unsupported` | invalid init subgraph index |

Local validation:
```bash
python -m ruff format --check \
python/tvm/relax/frontend/tflite/tflite_frontend.py \
tests/python/relax/test_frontend_tflite.py
python -m ruff check \
python/tvm/relax/frontend/tflite/tflite_frontend.py \
tests/python/relax/test_frontend_tflite.py
python -m pytest \
tests/python/relax/test_frontend_tflite.py \
-k "call_subgraph or if_subgraphs or while_subgraphs or call_once" -q
python -m pytest \
tests/python/relax/test_frontend_tflite.py -q
```
Result:
```text
ruff format --check: 2 files already formatted
ruff check: All checks passed
27 passed, 434 deselected
461 passed
```
## References
- Issue #19519 item F: TFLite control-flow / multi-subgraph operators
- PR #19587: StableHLO region-based ops and multi-subgraph model support