This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new fa66213249 [Relax][Frontend][TFLite] Support control-flow multi-subgraph operators (#19616)
fa66213249 is described below

commit fa66213249a361656fea055d80291cf0e6b2ff1a
Author: HoYi
AuthorDate: Wed May 27 11:09:45 2026 +0800

    [Relax][Frontend][TFLite] Support control-flow multi-subgraph operators (#19616)
    
    ## Summary
    
    This PR adds Relax TFLite frontend support for the TFLite builtin
    control-flow / multi-subgraph operator family from #19519 item F:
    `CALL`, `IF`, `WHILE`, and `CALL_ONCE`.
    
    It builds on the multi-subgraph import infrastructure merged in PR #19587.
    The frontend already accepts TFLite models with extra subgraphs while
    converting only `Subgraphs(0)` into the Relax `main` function. This PR uses
    those extra subgraphs as callable or control-flow regions for the TFLite
    control-flow operators.
    
    The supported subset is intentionally pure tensor and guard-first:
    
    - `CALL` lowers a referenced TFLite subgraph to a private Relax function and
      emits a direct call.
    - `IF` lowers the then/else subgraphs to private Relax functions and emits a
      private wrapper function containing a Relax `If`.
    - `WHILE` lowers the cond/body subgraphs to private Relax functions and emits
      a recursive private Relax function for the loop.
    - `CALL_ONCE` supports only the empty-init no-op subset and explicitly
      rejects non-empty or resource-like init patterns.
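
    The four lowering strategies above can be modeled as plain Python
    functions over tuples of values. This is only a sketch of the intended
    semantics, assuming each subgraph behaves as a pure function; it does not
    build Relax IR, and the names `lower_call`, `lower_if`, `lower_while`, and
    `lower_call_once` are hypothetical.

    ```python
    from typing import Callable, Sequence, Tuple

    # Hypothetical model: a subgraph is a pure function returning a tuple of
    # outputs. The real frontend emits private Relax functions instead.
    Subgraph = Callable[..., Tuple]


    def lower_call(subgraph: Subgraph) -> Subgraph:
        # CALL: a direct call to the lowered subgraph.
        def call(*args):
            return subgraph(*args)

        return call


    def lower_if(then_fn: Subgraph, else_fn: Subgraph) -> Subgraph:
        # IF: input 0 is the scalar bool condition; the remaining inputs are
        # forwarded unchanged to whichever branch is selected.
        def if_wrapper(cond, *branch_args):
            return then_fn(*branch_args) if cond else else_fn(*branch_args)

        return if_wrapper


    def lower_while(cond_fn: Subgraph, body_fn: Subgraph) -> Subgraph:
        # WHILE: recurse on the next loop-carried state until cond is false.
        def loop(*state):
            (keep_going,) = cond_fn(*state)
            if keep_going:
                return loop(*body_fn(*state))
            return state

        return loop


    def lower_call_once(init_ops: Sequence[str]) -> None:
        # CALL_ONCE: only an empty init subgraph is accepted, as a no-op.
        if init_ops:
            raise NotImplementedError("non-empty CALL_ONCE init subgraph")
        return None
    ```

    Python has no tail-call elimination, so this `lower_while` model only
    suits small iteration counts; in the imported module the loop is a
    recursive private Relax function instead.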
    
    This PR does not model resource variable side effects. Those cases remain
    explicitly guarded instead of being imported with incorrect pure functional
    semantics.
    
    ## Design
    
    ### Shared Subgraph Lowering
    
    The frontend now keeps shared conversion state across the main graph and
    referenced subgraphs:
    
    - `lowered_subgraphs`
    - `lowered_if_functions`
    - `lowered_while_functions`
    - `lowering_stack`
    - `module_builder`
    
    Referenced pure tensor subgraphs are lowered through a recursive
    `OperatorConverter` using an isolated `ExprTable`, so subgraph tensor
    bindings cannot overwrite bindings from the main graph. Lowered subgraphs
    are cached by subgraph index and reused when the same region is referenced
    more than once. Generated private functions are registered through the
    shared parent `module_builder`, so nested cases such as
    `main CALL -> subgraph A -> CALL subgraph B` keep all private functions in
    the final IRModule.
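
    A minimal sketch of the caching and shared-registration behavior
    described above. It assumes a hypothetical `ConversionState` dataclass in
    place of the `conversion_state` dict and a hypothetical `callees` map in
    place of scanning a subgraph's `CALL` ops. Callees are lowered and
    registered before their caller, which mirrors a nested converter lowering
    `subgraph B` while `subgraph A` is still being built. Recursion is not
    guarded in this sketch.

    ```python
    from dataclasses import dataclass, field
    from typing import Dict, List


    @dataclass
    class ConversionState:
        # Hypothetical stand-in for the shared conversion_state dict.
        lowered_subgraphs: Dict[int, str] = field(default_factory=dict)
        module_functions: List[str] = field(default_factory=list)


    def lower_subgraph(
        state: ConversionState, index: int, callees: Dict[int, List[int]]
    ) -> str:
        """Lower subgraph `index`, lowering the subgraphs it CALLs first."""
        if index in state.lowered_subgraphs:
            # Cache hit: the same region is reused, not lowered again.
            return state.lowered_subgraphs[index]
        for callee in callees.get(index, []):
            lower_subgraph(state, callee, callees)
        name = f"tflite_call_subgraph_{index}"
        # Every nesting level registers into the same shared module.
        state.module_functions.append(name)
        state.lowered_subgraphs[index] = name
        return name
    ```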
    
    Recursive ordinary `CALL` subgraphs are guarded with `OpNotImplemented`.
    `WHILE` uses a dedicated recursive wrapper function instead, because
    recursion is part of the intended Relax representation for the loop itself.
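
    The recursion guard can be sketched with an explicit stack of the subgraph
    indices currently being lowered, unwound in a `finally` block in the same
    way as `_lower_subgraph_to_function`. `RecursiveSubgraphError` is a
    hypothetical stand-in for `tvm.error.OpNotImplemented`, and `callees` is a
    hypothetical map from a subgraph index to the subgraphs it calls.

    ```python
    from typing import Dict, List, Set


    class RecursiveSubgraphError(Exception):
        """Hypothetical stand-in for tvm.error.OpNotImplemented."""


    def lower_with_guard(
        index: int, callees: Dict[int, List[int]], stack: List[int], done: Set[int]
    ) -> None:
        if index in done:
            return  # already lowered; reuse is not recursion
        if index in stack:
            raise RecursiveSubgraphError(f"Recursive TFLite CALL subgraph {index}")
        stack.append(index)
        try:
            for callee in callees.get(index, []):
                lower_with_guard(callee, callees, stack, done)
            done.add(index)
        finally:
            stack.pop()  # unwind even when a nested lowering raises
    ```

    Only indices still on the stack count as recursion, so a diamond in which
    two subgraphs call the same third subgraph is accepted.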
    
    ### Boundary Validation
    
    The control-flow converters validate subgraph boundaries before lowering:
    
    - referenced subgraph indices must be valid
    - op input/output arity must match the referenced subgraph interface
    - branch and loop tensor shape/dtype metadata must match the surrounding op
    - `IF` and `WHILE` conditions must be scalar bool tensors
    - `WHILE` loop-carried input/output tensors must have matching metadata
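
    These checks can be sketched over plain `(shape, dtype)` metadata tuples.
    The function below is hypothetical and simplified (it takes pre-extracted
    metadata lists rather than flatbuffer tensors), but it follows the same
    order as `_check_subgraph_interface`: index validity, then arity, then
    per-tensor metadata. `BoundaryError` stands in for
    `tvm.error.OpNotImplemented`.

    ```python
    from typing import List, Optional, Tuple

    # Hypothetical metadata record: (static shape, dtype string).
    Meta = Tuple[Tuple[int, ...], str]


    class BoundaryError(Exception):
        """Hypothetical stand-in for tvm.error.OpNotImplemented."""


    def check_subgraph_interface(
        op_name: str,
        subgraph_index: int,
        num_subgraphs: int,
        subgraph_inputs: List[Meta],
        subgraph_outputs: List[Meta],
        op_inputs: List[Meta],
        op_outputs: Optional[List[Meta]] = None,
    ) -> None:
        # Index 0 is the main graph, so a referenced region must be >= 1.
        if not 0 < subgraph_index < num_subgraphs:
            raise BoundaryError(f"{op_name} requires a valid subgraph index")
        if len(subgraph_inputs) != len(op_inputs):
            raise BoundaryError(f"{op_name} subgraph input count mismatch")
        if op_outputs is not None and len(subgraph_outputs) != len(op_outputs):
            raise BoundaryError(f"{op_name} subgraph output count mismatch")
        checks = [("input", subgraph_inputs, op_inputs)]
        if op_outputs is not None:
            checks.append(("output", subgraph_outputs, op_outputs))
        for role, actual, expected in checks:
            for actual_meta, expected_meta in zip(actual, expected):
                if actual_meta != expected_meta:
                    raise BoundaryError(f"{op_name} subgraph {role} tensor metadata mismatch")
    ```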
    
    The shared `_check_subgraph_interface` helper is used by `CALL`, `IF`, and
    `WHILE` to keep arity and metadata checks consistent across the
    control-flow operators. `_require_scalar_bool_tensor` accepts both frontend
    `TensorWrapper` objects and raw TFLite tensors, so caller-side and
    referenced-subgraph condition checks use the same path.
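
    A sketch of that single unwrap-then-check path, assuming hypothetical
    `RawTensor` and `Wrapper` dataclasses in place of a raw TFLite tensor and
    the frontend's `TensorWrapper` (which exposes the raw tensor as `.tensor`).
    `ValueError` stands in for `tvm.error.OpNotImplemented`.

    ```python
    from dataclasses import dataclass
    from typing import Tuple


    @dataclass
    class RawTensor:
        # Hypothetical stand-in for a raw TFLite tensor's static metadata.
        shape: Tuple[int, ...]
        dtype: str


    @dataclass
    class Wrapper:
        # Hypothetical stand-in for TensorWrapper; holds the raw tensor.
        tensor: RawTensor


    def require_scalar_bool(tensor, op_name: str) -> None:
        # Unwrap once so wrappers and raw subgraph tensors share one check.
        if isinstance(tensor, Wrapper):
            tensor = tensor.tensor
        if tensor.dtype != "bool" or len(tensor.shape) != 0:
            raise ValueError(f"{op_name} requires a scalar bool condition")
    ```

    A rank-1 tensor of shape `(1,)` is rejected: only a rank-0 bool tensor
    counts as a scalar condition, matching the `ShapeLength() != 0` test in
    the diff.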
    
    These checks keep the first implementation conservative and make
    unsupported cases fail with targeted `OpNotImplemented` diagnostics.
    
    ### Tuple Outputs
    
    TFLite `CALL`, `IF`, and `WHILE` may produce multiple output tensors. The
    frontend maps those cases to Relax tuple returns:
    
    ```text
    single output  -> tensor expression
    multi output   -> Tuple(...)
    op outputs     -> TupleGetItem(...)
    ```
    
    This keeps the single-output IR simple while covering multi-output calls,
    multi-output branches, and multi-variable loop state.
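
    The mapping can be illustrated with plain Python tuples standing in for
    Relax `Tuple` and `TupleGetItem`. The two functions mirror the shape of
    `_make_tuple_or_single` and `_bind_call_outputs` in the diff, but they are
    a sketch and do not construct Relax expressions.

    ```python
    from typing import Any, List, Sequence


    def make_tuple_or_single(exprs: Sequence[Any]) -> Any:
        # A single output stays a bare value; several become one tuple.
        if len(exprs) == 1:
            return exprs[0]
        return tuple(exprs)


    def bind_call_outputs(call_result: Any, output_count: int) -> List[Any]:
        # One output: the call result is the op output itself.
        if output_count == 1:
            return [call_result]
        # Several outputs: project each element, like TupleGetItem(call, i).
        return [call_result[i] for i in range(output_count)]
    ```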
    
    ## Operator Support
    
    | Operator | TFLite options | Relax lowering | Supported subset |
    |---|---|---|---|
    | `CALL` | `CallOptions.Subgraph()` | private Relax function call | pure tensor subgraphs, single or multiple outputs |
    | `IF` | `IfOptions.ThenSubgraphIndex()`, `ElseSubgraphIndex()` | private wrapper function containing Relax `If` | scalar bool condition, matching branch I/O metadata |
    | `WHILE` | `WhileOptions.CondSubgraphIndex()`, `BodySubgraphIndex()` | recursive private Relax function | scalar bool cond output, tensor loop-carried state |
    | `CALL_ONCE` | `CallOnceOptions.InitSubgraphIndex()` | no-op for empty init subgraph | empty init subgraph only |
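
    For the recursive `WHILE` row, the loop function has to refer to itself
    before it is finished. The diff handles this by registering a placeholder
    through `add_func` and later replacing it with `update_func`, keyed by the
    cond/body pair and loop-variable count so a repeated pair reuses one
    function. The sketch below models that two-phase registration with a
    plain dict as a hypothetical module registry; `register_while_loop` is a
    hypothetical name, and the loop-variable count is left out of the cache
    key for brevity.

    ```python
    from typing import Callable, Dict, Tuple

    # Hypothetical registry: function name -> callable, standing in for the
    # shared module builder's add_func/update_func.
    Registry = Dict[str, Callable[..., Tuple]]


    def register_while_loop(
        registry: Registry,
        cache: Dict[Tuple[int, int], str],
        cond_index: int,
        body_index: int,
        cond_fn: Callable[..., Tuple],
        body_fn: Callable[..., Tuple],
    ) -> str:
        key = (cond_index, body_index)
        if key in cache:
            return cache[key]  # repeated cond/body pair: reuse the loop
        name = f"tflite_while_subgraph_{cond_index}_{body_index}"
        # Phase 1: a placeholder (identity on the state) reserves the name.
        registry[name] = lambda *state: state
        cache[key] = name

        def loop(*state):
            (keep_going,) = cond_fn(*state)
            if keep_going:
                # Recurse through the registry, like a call to the loop's GlobalVar.
                return registry[name](*body_fn(*state))
            return state

        # Phase 2: replace the placeholder with the real recursive function.
        registry[name] = loop
        return name
    ```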
    
    ## Not Included
    
    - Full `CALL_ONCE` resource/variable initialization semantics.
    - Resource, variant, hashtable, or variable tensor support.
    - TensorFlow-generated `tf.cond` / `tf.while_loop` smoke tests.
    - Dynamic-shape loop-state refinements beyond the current static metadata
      checks.
    
    ## Tests
    
    The tests manually build minimal TFLite flatbuffers and compare the imported
    Relax IR with `tvm.ir.assert_structural_equal`. Unsupported-boundary tests
    use `pytest.raises`.
    
    | Test | Coverage |
    |---|---|
    | `test_call_subgraph` | basic `CALL` to a pure tensor subgraph |
    | `test_call_subgraph_multi_output` | `CALL` tuple return and output binding |
    | `test_call_subgraph_nested_call` | nested `CALL` private function registration |
    | `test_call_subgraph_invalid_index_unsupported` | invalid `CALL` subgraph index |
    | `test_call_subgraph_io_mismatch_unsupported` | `CALL` arity mismatch |
    | `test_call_subgraph_output_metadata_mismatch_unsupported` | `CALL` output metadata guard |
    | `test_if_subgraphs` | basic `IF` branch selection |
    | `test_if_subgraphs_multi_output` | `IF` tuple branch returns |
    | `test_if_subgraphs_non_bool_condition_unsupported` | `IF` condition dtype guard |
    | `test_if_subgraphs_invalid_index_unsupported` | invalid then/else subgraph index |
    | `test_if_subgraphs_output_count_mismatch_unsupported` | branch output count guard |
    | `test_if_subgraphs_input_metadata_mismatch_unsupported` | branch input metadata guard |
    | `test_if_subgraphs_output_metadata_mismatch_unsupported` | branch output metadata guard |
    | `test_while_subgraphs` | basic recursive `WHILE` lowering |
    | `test_while_subgraphs_repeated_cond_body_pair` | shared cond/body loop function cache |
    | `test_while_subgraphs_two_loop_vars` | multi-variable loop state tuple path |
    | `test_while_subgraphs_non_bool_condition_unsupported` | `WHILE` cond output dtype guard |
    | `test_while_subgraphs_invalid_index_unsupported` | invalid cond/body subgraph index |
    | `test_while_subgraphs_zero_loop_vars_unsupported` | zero-loop-var guard |
    | `test_while_subgraphs_loop_state_metadata_mismatch_unsupported` | loop state metadata guard |
    | `test_while_subgraphs_output_count_mismatch_unsupported` | body output count guard |
    | `test_while_subgraphs_input_metadata_mismatch_unsupported` | cond/body input metadata guard |
    | `test_while_subgraphs_output_metadata_mismatch_unsupported` | cond/body output metadata guard |
    | `test_call_once_empty_init_subgraph` | empty `CALL_ONCE` no-op subset |
    | `test_call_once_non_empty_init_subgraph_unsupported` | non-empty init subgraph guard |
    | `test_call_once_inputs_outputs_unsupported` | `CALL_ONCE` op I/O guard |
    | `test_call_once_init_subgraph_io_unsupported` | init subgraph I/O guard |
    | `test_call_once_invalid_index_unsupported` | invalid init subgraph index |
    
    Local validation:
    
    ```bash
    python -m ruff format --check \
      python/tvm/relax/frontend/tflite/tflite_frontend.py \
      tests/python/relax/test_frontend_tflite.py
    
    python -m ruff check \
      python/tvm/relax/frontend/tflite/tflite_frontend.py \
      tests/python/relax/test_frontend_tflite.py
    
    python -m pytest \
      tests/python/relax/test_frontend_tflite.py \
      -k "call_subgraph or if_subgraphs or while_subgraphs or call_once" -q
    
    python -m pytest \
      tests/python/relax/test_frontend_tflite.py -q
    ```
    
    Result:
    
    ```text
    ruff format --check: 2 files already formatted
    ruff check: All checks passed
    28 passed, 434 deselected
    462 passed
    ```
    
    ## References
    
    - Issue #19519 item F: TFLite control-flow / multi-subgraph operators
    - PR #19587: StableHLO region-based ops and multi-subgraph model support
---
 .../tvm/relax/frontend/tflite/tflite_frontend.py   |  434 ++++++-
 tests/python/relax/test_frontend_tflite.py         | 1352 ++++++++++++++++++++
 2 files changed, 1782 insertions(+), 4 deletions(-)

diff --git a/python/tvm/relax/frontend/tflite/tflite_frontend.py b/python/tvm/relax/frontend/tflite/tflite_frontend.py
index 979bbbb867..f395c95b6d 100644
--- a/python/tvm/relax/frontend/tflite/tflite_frontend.py
+++ b/python/tvm/relax/frontend/tflite/tflite_frontend.py
@@ -154,7 +154,7 @@ class OperatorConverter:
         }
     )
 
-    def __init__(self, model, subgraph, exp_tab, ctx):
+    def __init__(self, model, subgraph, exp_tab, ctx, conversion_state=None):
         from tflite.ActivationFunctionType import ActivationFunctionType
         from tflite.BuiltinOperator import BuiltinOperator
         from tflite.BuiltinOptions import BuiltinOptions
@@ -168,6 +168,17 @@ class OperatorConverter:
         self.prefetched_nodes = {}
         self.allow_custom_ops = False
         self.bb = ctx
+        if conversion_state is None:
+            conversion_state = {
+                "lowered_subgraphs": {},
+                "lowered_if_functions": {},
+                "lowered_while_functions": {},
+                "lowering_stack": [],
+                "module_builder": ctx,
+            }
+        else:
+            conversion_state.setdefault("module_builder", ctx)
+        self.conversion_state = conversion_state
 
         # Add more operators
         self.convert_map = {
@@ -183,6 +194,8 @@ class OperatorConverter:
             "BITCAST": self.convert_bitcast,
             "BROADCAST_TO": self.convert_broadcast_to,
             "BROADCAST_ARGS": self.convert_broadcast_args,
+            "CALL": self.convert_call,
+            "CALL_ONCE": self.convert_call_once,
             "CAST": self.convert_cast,
             "CEIL": functools.partial(self._convert_unary_elemwise, relax_op=_op.ceil),
             "CONCATENATION": self.convert_concatenation,
@@ -221,6 +234,7 @@ class OperatorConverter:
             ),
             "GELU": self.convert_gelu,
             "HARD_SWISH": self.convert_hard_swish,
+            "IF": self.convert_if,
             "L2_NORMALIZATION": self.convert_l2_normalization,
             "L2_POOL_2D": functools.partial(self.convert_pool2d, pool_type="l2"),
             "LEAKY_RELU": self.convert_leaky_relu,
@@ -375,6 +389,7 @@ class OperatorConverter:
             ),
             # "UNIDIRECTIONAL_SEQUENCE_LSTM": self.convert_unidirectional_sequence_lstm,
             "WHERE": self.convert_select,
+            "WHILE": self.convert_while,
             "ZEROS_LIKE": self.convert_zeros_like,
             "NON_MAX_SUPPRESSION_V4": self.convert_nms_v4,
             "NON_MAX_SUPPRESSION_V5": self.convert_nms_v5,
@@ -562,7 +577,7 @@ class OperatorConverter:
     def get_tensors(self, tensors_idx_list):
         """Get tensor wrapper list from given TFLite tensor index list"""
         return_list = list()
-        for tensor_idx in tensors_idx_list:
+        for tensor_idx in self._indices_or_empty(tensors_idx_list):
             if tensor_idx < 0:
                 return_list.append(TensorWrapper(tensor_idx, 0, 0))
                 continue
@@ -1888,6 +1903,417 @@ class OperatorConverter:
             relax.op.sort(data, axis=int(opts.Dimension()), descending=descending)
         )
 
+    def _get_builtin_options(self, op, options_cls):
+        """Parse BuiltinOptions for a TFLite builtin operator."""
+        from tflite.BuiltinOptions import BuiltinOptions
+
+        op_options = op.BuiltinOptions()
+        if op_options is None:
+            raise tvm.error.OpNotImplemented(f"{options_cls.__name__} is required")
+
+        options_type = getattr(BuiltinOptions, options_cls.__name__, None)
+        if options_type is not None and op.BuiltinOptionsType() != options_type:
+            raise tvm.error.OpNotImplemented(
+                f"Unexpected BuiltinOptions type: expected "
+                f"{options_cls.__name__}, got {op.BuiltinOptionsType()}"
+            )
+        result = options_cls()
+        result.Init(op_options.Bytes, op_options.Pos)
+        return result
+
+    def _get_subgraph(self, subgraph_index, op_name, allow_main=False):
+        """Return a validated TFLite subgraph by index."""
+        if subgraph_index < 0 or subgraph_index >= self.model.SubgraphsLength():
+            raise tvm.error.OpNotImplemented(f"{op_name} requires a valid subgraph index")
+        if not allow_main and subgraph_index == 0:
+            raise tvm.error.OpNotImplemented(f"{op_name} cannot target the main subgraph")
+        return self.model.Subgraphs(subgraph_index)
+
+    def _make_tuple_or_single(self, exprs):
+        """Return a single expression or Relax tuple for a list of expressions."""
+        if len(exprs) == 1:
+            return exprs[0]
+        return relax.Tuple(exprs)
+
+    def _indices_or_empty(self, indices):
+        """Return a TFLite index vector, using an empty list for absent vectors."""
+        return indices if indices is not None else []
+
+    def _check_subgraph_io(self, subgraph_index, op_name, input_count=None, output_count=None):
+        """Validate a referenced subgraph's input and output counts."""
+        subgraph = self._get_subgraph(subgraph_index, op_name)
+        if input_count is not None and subgraph.InputsLength() != input_count:
+            raise tvm.error.OpNotImplemented(f"{op_name} subgraph input count mismatch")
+        if output_count is not None and subgraph.OutputsLength() != output_count:
+            raise tvm.error.OpNotImplemented(f"{op_name} subgraph output count mismatch")
+        return subgraph
+
+    def _check_subgraph_interface(
+        self,
+        subgraph_index,
+        op_name,
+        input_tensors=None,
+        output_tensors=None,
+        input_count=None,
+        output_count=None,
+    ):
+        """Validate a referenced subgraph's arity and tensor metadata."""
+        if input_tensors is not None:
+            input_count = len(input_tensors)
+        if output_tensors is not None:
+            output_count = len(output_tensors)
+
+        subgraph = self._check_subgraph_io(
+            subgraph_index, op_name, input_count=input_count, output_count=output_count
+        )
+        if input_tensors is not None:
+            self._check_subgraph_tensor_metadata(
+                subgraph,
+                op_name,
+                "subgraph input",
+                subgraph.InputsAsNumpy(),
+                input_tensors,
+            )
+        if output_tensors is not None:
+            self._check_subgraph_tensor_metadata(
+                subgraph,
+                op_name,
+                "subgraph output",
+                subgraph.OutputsAsNumpy(),
+                output_tensors,
+            )
+        return subgraph
+
+    def _get_tensor_metadata(self, tensor):
+        """Return static shape and dtype metadata for a TFLite tensor."""
+        if isinstance(tensor, TensorWrapper):
+            tensor = tensor.tensor
+        shape = tuple(tensor.ShapeAsNumpy()) if tensor.ShapeLength() > 0 else ()
+        dtype = self.get_tensor_type_str(tensor.Type())
+        return shape, dtype
+
+    def _check_tensor_metadata_match(self, actual, expected, op_name, tensor_role):
+        """Validate that two TFLite tensors have matching static metadata."""
+        if self._get_tensor_metadata(actual) != self._get_tensor_metadata(expected):
+            raise tvm.error.OpNotImplemented(f"{op_name} {tensor_role} tensor metadata mismatch")
+
+    def _check_subgraph_tensor_metadata(
+        self, subgraph, op_name, tensor_role, subgraph_indices, expected_tensors
+    ):
+        """Validate referenced subgraph tensor metadata against caller tensors."""
+        for subgraph_index, expected_tensor in zip(
+            self._indices_or_empty(subgraph_indices), expected_tensors
+        ):
+            self._check_tensor_metadata_match(
+                subgraph.Tensors(int(subgraph_index)),
+                expected_tensor,
+                op_name,
+                tensor_role,
+            )
+
+    def _require_scalar_bool_tensor(self, tensor, op_name):
+        """Validate that a TFLite tensor is a scalar bool tensor."""
+        if isinstance(tensor, TensorWrapper):
+            tensor = tensor.tensor
+        dtype = self.get_tensor_type_str(tensor.Type())
+        if dtype != "bool" or tensor.ShapeLength() != 0:
+            raise tvm.error.OpNotImplemented(f"{op_name} requires a scalar bool condition")
+
+    def _get_subgraph_params(self, subgraph):
+        """Create Relax parameters for a TFLite subgraph."""
+        params = []
+        exp_tab = ExprTable()
+        for input_index in self._indices_or_empty(subgraph.InputsAsNumpy()):
+            tensor = subgraph.Tensors(int(input_index))
+            input_name = get_tensor_name(subgraph, int(input_index))
+            shape = tuple(tensor.ShapeAsNumpy()) if tensor.ShapeLength() > 0 else []
+            dtype = self.get_tensor_type_str(tensor.Type())
+            param = relax.Var(input_name, relax.TensorStructInfo(shape=shape, dtype=dtype))
+            exp_tab.set_expr(input_name, param)
+            params.append(param)
+        return params, exp_tab
+
+    def _get_tensor_param(self, tensor_wrapper):
+        """Create a Relax parameter from TFLite tensor metadata."""
+        name = get_tensor_name(self.subgraph, tensor_wrapper.tensor_idx)
+        shape = (
+            tuple(tensor_wrapper.tensor.ShapeAsNumpy())
+            if tensor_wrapper.tensor.ShapeLength() > 0
+            else []
+        )
+        dtype = self.get_tensor_type_str(tensor_wrapper.tensor.Type())
+        return relax.Var(name, relax.TensorStructInfo(shape=shape, dtype=dtype))
+
+    def _lower_subgraph_to_function(self, subgraph_index, function_name_hint, op_name="CALL"):
+        """Lower a TFLite subgraph into a private Relax function."""
+        lowered_subgraphs = self.conversion_state["lowered_subgraphs"]
+        if subgraph_index in lowered_subgraphs:
+            return lowered_subgraphs[subgraph_index]
+
+        lowering_stack = self.conversion_state["lowering_stack"]
+        if subgraph_index in lowering_stack:
+            raise tvm.error.OpNotImplemented(
+                f"Recursive TFLite {op_name} subgraphs are not supported"
+            )
+
+        subgraph = self._get_subgraph(subgraph_index, op_name)
+        lowering_stack.append(subgraph_index)
+        try:
+            params, subgraph_exp_tab = self._get_subgraph_params(subgraph)
+            subgraph_bb = relax.BlockBuilder()
+            with subgraph_bb.function(function_name_hint, params=params, private=True):
+                with subgraph_bb.dataflow():
+                    subgraph_converter = type(self)(
+                        self.model,
+                        subgraph,
+                        subgraph_exp_tab,
+                        subgraph_bb,
+                        self.conversion_state,
+                    )
+                    subgraph_converter.check_unsupported_ops()
+                    subgraph_converter.convert_op_to_relax()
+                    output_tensors = subgraph_converter.get_tensors(subgraph.OutputsAsNumpy())
+                    outputs = [
+                        subgraph_converter.get_tensor_expr(tensor) for tensor in output_tensors
+                    ]
+                    output = subgraph_bb.emit_output(self._make_tuple_or_single(outputs))
+                subgraph_bb.emit_func_output(output)
+
+            subgraph_mod = subgraph_bb.get()
+            module_builder = self.conversion_state["module_builder"]
+            gv = module_builder.add_func(subgraph_mod[function_name_hint], function_name_hint)
+            lowered_subgraphs[subgraph_index] = gv
+            return gv
+        finally:
+            lowering_stack.pop()
+
+    def _bind_call_outputs(self, call, output_count):
+        """Return per-output expressions from a single or tuple-valued call."""
+        if output_count == 1:
+            return [call]
+        return [call[index] for index in range(output_count)]
+
+    def _lower_if_to_function(
+        self,
+        then_subgraph_index,
+        else_subgraph_index,
+        input_tensors,
+        branch_input_count,
+        output_count,
+    ):
+        """Lower a TFLite IF op into a private Relax function."""
+        cache_key = (then_subgraph_index, else_subgraph_index, branch_input_count, output_count)
+        lowered_if_functions = self.conversion_state["lowered_if_functions"]
+        if cache_key in lowered_if_functions:
+            return lowered_if_functions[cache_key]
+
+        then_func = self._lower_subgraph_to_function(
+            then_subgraph_index,
+            f"tflite_if_then_subgraph_{then_subgraph_index}",
+            op_name="IF",
+        )
+        else_func = self._lower_subgraph_to_function(
+            else_subgraph_index,
+            f"tflite_if_else_subgraph_{else_subgraph_index}",
+            op_name="IF",
+        )
+        if_name = f"tflite_if_subgraph_{then_subgraph_index}_{else_subgraph_index}"
+        params = [self._get_tensor_param(tensor) for tensor in input_tensors]
+        cond = params[0]
+        branch_args = params[1:]
+
+        if_bb = relax.BlockBuilder()
+        with if_bb.function(if_name, params=params, private=True):
+            result = relax.If(
+                cond,
+                relax.Call(then_func, branch_args),
+                relax.Call(else_func, branch_args),
+            )
+            if_bb.emit_func_output(result)
+        if_func = if_bb.get()[if_name]
+        module_builder = self.conversion_state["module_builder"]
+        gv = module_builder.add_func(if_func, if_name)
+        lowered_if_functions[cache_key] = gv
+        return gv
+
+    def _lower_while_to_function(
+        self,
+        cond_subgraph_index,
+        body_subgraph_index,
+        loop_var_count,
+        cond_func,
+        body_func,
+        body_subgraph,
+    ):
+        """Lower a TFLite WHILE op into a recursive private Relax function."""
+        cache_key = (cond_subgraph_index, body_subgraph_index, loop_var_count)
+        lowered_while_functions = self.conversion_state["lowered_while_functions"]
+        if cache_key in lowered_while_functions:
+            return lowered_while_functions[cache_key]
+
+        loop_name = f"tflite_while_subgraph_{cond_subgraph_index}_{body_subgraph_index}"
+        params, _ = self._get_subgraph_params(body_subgraph)
+        dummy_body = self._make_tuple_or_single(params)
+        module_builder = self.conversion_state["module_builder"]
+        loop_gv = module_builder.add_func(relax.Function(params, dummy_body), loop_name)
+        lowered_while_functions[cache_key] = loop_gv
+
+        loop_bb = relax.BlockBuilder()
+        with loop_bb.function(loop_name, params=params, private=True):
+            cond = loop_bb.emit(relax.Call(cond_func, params), "while_cond")
+            next_state = relax.Call(body_func, params)
+            next_args = self._bind_call_outputs(next_state, loop_var_count)
+            true_branch = relax.Call(loop_gv, next_args)
+            false_branch = self._make_tuple_or_single(params)
+            result = relax.If(cond, true_branch, false_branch)
+            loop_bb.emit_func_output(result)
+        loop_func = loop_bb.get()[loop_name]
+        module_builder.update_func(loop_gv, loop_func)
+        return loop_gv
+
+    def convert_call(self, op):
+        """Convert TFLite CALL to a Relax private function call."""
+        from tflite.CallOptions import CallOptions
+
+        opts = self._get_builtin_options(op, CallOptions)
+        subgraph_index = int(opts.Subgraph())
+        input_tensors = self.get_input_tensors(op)
+        output_tensors = self.get_output_tensors(op)
+        self._check_subgraph_interface(
+            subgraph_index,
+            "CALL",
+            input_tensors=input_tensors,
+            output_tensors=output_tensors,
+        )
+
+        callee = self._lower_subgraph_to_function(
+            subgraph_index, f"tflite_call_subgraph_{subgraph_index}", op_name="CALL"
+        )
+        args = [self.get_tensor_expr(tensor) for tensor in input_tensors]
+        return relax.Call(callee, args)
+
+    def convert_if(self, op):
+        """Convert TFLite IF to Relax If with private branch functions."""
+        from tflite.IfOptions import IfOptions
+
+        opts = self._get_builtin_options(op, IfOptions)
+        then_subgraph_index = int(opts.ThenSubgraphIndex())
+        else_subgraph_index = int(opts.ElseSubgraphIndex())
+        input_tensors = self.get_input_tensors(op)
+        output_tensors = self.get_output_tensors(op)
+        if len(input_tensors) < 1:
+            raise tvm.error.OpNotImplemented("IF requires a condition input")
+
+        self._require_scalar_bool_tensor(input_tensors[0], "IF")
+        branch_input_count = len(input_tensors) - 1
+        output_count = len(output_tensors)
+        branch_input_tensors = input_tensors[1:]
+        self._check_subgraph_interface(
+            then_subgraph_index,
+            "IF",
+            input_tensors=branch_input_tensors,
+            output_tensors=output_tensors,
+        )
+        self._check_subgraph_interface(
+            else_subgraph_index,
+            "IF",
+            input_tensors=branch_input_tensors,
+            output_tensors=output_tensors,
+        )
+
+        if_func = self._lower_if_to_function(
+            then_subgraph_index,
+            else_subgraph_index,
+            input_tensors,
+            branch_input_count,
+            output_count,
+        )
+        args = [self.get_tensor_expr(tensor) for tensor in input_tensors]
+        return relax.Call(if_func, args)
+
+    def convert_while(self, op):
+        """Convert TFLite WHILE to a recursive Relax private function."""
+        from tflite.WhileOptions import WhileOptions
+
+        opts = self._get_builtin_options(op, WhileOptions)
+        cond_subgraph_index = int(opts.CondSubgraphIndex())
+        body_subgraph_index = int(opts.BodySubgraphIndex())
+        input_tensors = self.get_input_tensors(op)
+        output_tensors = self.get_output_tensors(op)
+        loop_var_count = len(input_tensors)
+        if loop_var_count == 0:
+            raise tvm.error.OpNotImplemented("WHILE requires loop-carried inputs")
+        if len(output_tensors) != loop_var_count:
+            raise tvm.error.OpNotImplemented("WHILE output count must match input count")
+
+        cond_subgraph = self._check_subgraph_interface(
+            cond_subgraph_index,
+            "WHILE",
+            input_tensors=input_tensors,
+            output_count=1,
+        )
+        body_subgraph = self._check_subgraph_interface(
+            body_subgraph_index,
+            "WHILE",
+            input_tensors=input_tensors,
+            output_tensors=input_tensors,
+        )
+        for input_tensor, output_tensor in zip(input_tensors, output_tensors):
+            self._check_tensor_metadata_match(input_tensor, output_tensor, "WHILE", "loop state")
+        cond_output = cond_subgraph.Tensors(int(cond_subgraph.Outputs(0)))
+        self._require_scalar_bool_tensor(cond_output, "WHILE")
+
+        cond_func = self._lower_subgraph_to_function(
+            cond_subgraph_index,
+            f"tflite_while_cond_subgraph_{cond_subgraph_index}",
+            op_name="WHILE",
+        )
+        body_func = self._lower_subgraph_to_function(
+            body_subgraph_index,
+            f"tflite_while_body_subgraph_{body_subgraph_index}",
+            op_name="WHILE",
+        )
+
+        loop_gv = self._lower_while_to_function(
+            cond_subgraph_index,
+            body_subgraph_index,
+            loop_var_count,
+            cond_func,
+            body_func,
+            body_subgraph,
+        )
+
+        args = [self.get_tensor_expr(tensor) for tensor in input_tensors]
+        return relax.Call(loop_gv, args)
+
+    def convert_call_once(self, op):
+        """Convert the no-op subset of TFLite CALL_ONCE.
+
+        Non-empty CALL_ONCE init subgraphs are used for resource initialization
+        side effects in TFLite.  The Relax TFLite frontend does not yet support
+        TFLite resource variable operators, so only the empty no-op form is safe
+        to import.
+        """
+        from tflite.CallOnceOptions import CallOnceOptions
+
+        opts = self._get_builtin_options(op, CallOnceOptions)
+        init_subgraph_index = int(opts.InitSubgraphIndex())
+        input_tensors = self.get_input_tensors(op)
+        output_tensors = self.get_output_tensors(op)
+        if len(input_tensors) != 0 or len(output_tensors) != 0:
+            raise tvm.error.OpNotImplemented("CALL_ONCE with inputs or outputs is not supported")
+
+        init_subgraph = self._get_subgraph(init_subgraph_index, "CALL_ONCE")
+        if init_subgraph.InputsLength() != 0 or init_subgraph.OutputsLength() != 0:
+            raise tvm.error.OpNotImplemented(
+                "CALL_ONCE with non-empty init subgraph I/O is not supported"
+            )
+        if init_subgraph.OperatorsLength() != 0:
+            raise tvm.error.OpNotImplemented(
+                "CALL_ONCE with non-empty init subgraphs is not supported"
+            )
+        return None
+
     def _convert_stablehlo_convert(self, op):
         """Convert STABLEHLO_CONVERT to Relax (astype).
 
@@ -6201,8 +6627,8 @@ def from_tflite(
         _dtype_dict.update(dtype_dict)
 
     # Only Subgraphs(0) is converted into Relax main. Additional subgraphs are
-    # region bodies referenced by specific TFLite ops and are consumed by those
-    # op converters as needed.
+    # region/control-flow bodies referenced by specific TFLite ops and are
+    # consumed by those op converters as needed.
     assert model.SubgraphsLength() >= 1, "TFLite model must contain at least one subgraph"
     subgraph = model.Subgraphs(0)
 
diff --git a/tests/python/relax/test_frontend_tflite.py b/tests/python/relax/test_frontend_tflite.py
index d03de3b6a9..be762d5cb4 100644
--- a/tests/python/relax/test_frontend_tflite.py
+++ b/tests/python/relax/test_frontend_tflite.py
@@ -3695,8 +3695,11 @@ _tfl_stablehlo_reduce_opts = _get_tflite_schema_module("StablehloReduceOptions")
 _tfl_stablehlo_reduce_window_opts = _get_tflite_schema_module("StablehloReduceWindowOptions")
 _tfl_stablehlo_scatter_opts = _get_tflite_schema_module("StablehloScatterOptions")
 _tfl_stablehlo_sort_opts = _get_tflite_schema_module("StablehloSortOptions")
+_tfl_call_options = _get_tflite_schema_module("CallOptions")
+_tfl_call_once_options = _get_tflite_schema_module("CallOnceOptions")
 _tfl_dimension_metadata = _get_tflite_schema_module("DimensionMetadata")
 _tfl_fully_connected_options = _get_tflite_schema_module("FullyConnectedOptions")
+_tfl_if_options = _get_tflite_schema_module("IfOptions")
 _tfl_int32_vector = _get_tflite_schema_module("Int32Vector")
 _tfl_model = _get_tflite_schema_module("Model")
 _tfl_operator = _get_tflite_schema_module("Operator")
@@ -3705,6 +3708,7 @@ _tfl_quantization_parameters = _get_tflite_schema_module("QuantizationParameters
 _tfl_sparsity_parameters = _get_tflite_schema_module("SparsityParameters")
 _tfl_subgraph = _get_tflite_schema_module("SubGraph")
 _tfl_tensor = _get_tflite_schema_module("Tensor")
+_tfl_while_options = _get_tflite_schema_module("WhileOptions")
 
 _tfl_builtin_operator = _get_tflite_schema_enum("BuiltinOperator")
 _tfl_builtin_options = _get_tflite_schema_enum("BuiltinOptions")
@@ -3909,6 +3913,32 @@ def _finish_tflite_model(builder, *, subgraph, operator_codes, buffers, extra_su
     return bytes(builder.Output())
 
 
+def _build_call_options(builder, subgraph_index):
+    _tfl_call_options.CallOptionsStart(builder)
+    _tfl_call_options.CallOptionsAddSubgraph(builder, subgraph_index)
+    return _tfl_call_options.CallOptionsEnd(builder)
+
+
+def _build_if_options(builder, then_subgraph_index, else_subgraph_index):
+    _tfl_if_options.IfOptionsStart(builder)
+    _tfl_if_options.IfOptionsAddThenSubgraphIndex(builder, then_subgraph_index)
+    _tfl_if_options.IfOptionsAddElseSubgraphIndex(builder, else_subgraph_index)
+    return _tfl_if_options.IfOptionsEnd(builder)
+
+
+def _build_while_options(builder, cond_subgraph_index, body_subgraph_index):
+    _tfl_while_options.WhileOptionsStart(builder)
+    _tfl_while_options.WhileOptionsAddCondSubgraphIndex(builder, cond_subgraph_index)
+    _tfl_while_options.WhileOptionsAddBodySubgraphIndex(builder, body_subgraph_index)
+    return _tfl_while_options.WhileOptionsEnd(builder)
+
+
+def _build_call_once_options(builder, init_subgraph_index):
+    _tfl_call_once_options.CallOnceOptionsStart(builder)
+    _tfl_call_once_options.CallOnceOptionsAddInitSubgraphIndex(builder, init_subgraph_index)
+    return _tfl_call_once_options.CallOnceOptionsEnd(builder)
+
+
 def _load_model_from_buffer(model_bytes):
     if hasattr(tflite.Model, "Model"):
         tflite_model = tflite.Model.Model.GetRootAsModel(model_bytes, 0)
@@ -3919,6 +3949,1328 @@ def _load_model_from_buffer(model_bytes):
     return mod
 
 
+def _get_builtin_operator(builtin_name):
+    if not hasattr(_tfl_builtin_operator, builtin_name):
+        pytest.skip(f"TFLite schema does not provide BuiltinOperator.{builtin_name}")
+    return getattr(_tfl_builtin_operator, builtin_name)
+
+
+def _build_tflite_call_model(
+    call_subgraph_index=1,
+    callee_inputs=None,
+    callee_outputs=None,
+    callee_output_shape=None,
+    callee_output_type=None,
+):
+    """Build a TFLite model where main CALLs a subgraph computing x + 1."""
+    builder = flatbuffers.Builder(1024)
+
+    callee_inputs = [0] if callee_inputs is None else callee_inputs
+    callee_outputs = [2] if callee_outputs is None else callee_outputs
+    callee_output_shape = [2, 2] if callee_output_shape is None else callee_output_shape
+    callee_output_type = (
+        _tfl_tensor_type.FLOAT32 if callee_output_type is None else callee_output_type
+    )
+    call_options = _build_call_options(builder, call_subgraph_index)
+    one = np.array(1.0, dtype=np.float32)
+
+    main_tensors = [
+        _build_tensor(builder, 0, [2, 2]),
+        _build_tensor(builder, 2, [2, 2]),
+    ]
+    main_call = _build_operator(
+        builder,
+        0,
+        [0],
+        [1],
+        builtin_options_type=_tfl_builtin_options.CallOptions,
+        builtin_options=call_options,
+    )
+    main_subgraph = _build_subgraph(
+        builder,
+        tensors=main_tensors,
+        operators=[main_call],
+        inputs=[0],
+        outputs=[1],
+    )
+
+    callee_tensors = [
+        _build_tensor(builder, 0, [2, 2]),
+        _build_tensor(builder, 1, []),
+        _build_tensor(builder, 2, callee_output_shape, tensor_type=callee_output_type),
+    ]
+    callee_add = _build_operator(builder, 1, [0, 1], [2])
+    callee_subgraph = _build_subgraph(
+        builder,
+        tensors=callee_tensors,
+        operators=[callee_add],
+        inputs=callee_inputs,
+        outputs=callee_outputs,
+    )
+
+    operator_codes = [
+        _build_operator_code(builder, _get_builtin_operator("CALL")),
+        _build_operator_code(builder, _get_builtin_operator("ADD")),
+    ]
+    buffers = [
+        _build_buffer(builder),
+        _build_buffer(builder, one.tobytes()),
+        _build_buffer(builder),
+    ]
+    return _finish_tflite_model(
+        builder,
+        subgraph=main_subgraph,
+        extra_subgraphs=[callee_subgraph],
+        operator_codes=operator_codes,
+        buffers=buffers,
+    )
+
+
+def test_call_subgraph():
+    """Test TFLite CALL conversion to a private Relax function."""
+    mod = _load_model_from_buffer(_build_tflite_call_model())
+
+    @I.ir_module
+    class Expected:
+        @R.function(private=True)
+        def tflite_call_subgraph_1(
+            tvmgen_tensor_0: R.Tensor((2, 2), dtype="float32"),
+        ) -> R.Tensor((2, 2), dtype="float32"):
+            with R.dataflow():
+                gv: R.Tensor((2, 2), dtype="float32") = R.add(
+                    tvmgen_tensor_0, R.const(1.0, "float32")
+                )
+                R.output(gv)
+            return gv
+
+        @R.function
+        def main(
+            tvmgen_tensor_0: R.Tensor((2, 2), dtype="float32"),
+        ) -> R.Tensor((2, 2), dtype="float32"):
+            R.func_attr({"num_input": 1})
+            cls = Expected
+            with R.dataflow():
+                gv: R.Tensor((2, 2), dtype="float32") = cls.tflite_call_subgraph_1(tvmgen_tensor_0)
+                R.output(gv)
+            return gv
+
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def _build_tflite_multi_output_call_model():
+    """Build a TFLite model where CALL returns x + 1 and x - 1."""
+    builder = flatbuffers.Builder(1024)
+
+    call_options = _build_call_options(builder, 1)
+    one = np.array(1.0, dtype=np.float32)
+
+    main_tensors = [
+        _build_tensor(builder, 0, [2, 2]),
+        _build_tensor(builder, 2, [2, 2]),
+        _build_tensor(builder, 3, [2, 2]),
+    ]
+    main_call = _build_operator(
+        builder,
+        0,
+        [0],
+        [1, 2],
+        builtin_options_type=_tfl_builtin_options.CallOptions,
+        builtin_options=call_options,
+    )
+    main_subgraph = _build_subgraph(
+        builder,
+        tensors=main_tensors,
+        operators=[main_call],
+        inputs=[0],
+        outputs=[1, 2],
+    )
+
+    callee_tensors = [
+        _build_tensor(builder, 0, [2, 2]),
+        _build_tensor(builder, 1, []),
+        _build_tensor(builder, 2, [2, 2]),
+        _build_tensor(builder, 3, [2, 2]),
+    ]
+    callee_add = _build_operator(builder, 1, [0, 1], [2])
+    callee_sub = _build_operator(builder, 2, [0, 1], [3])
+    callee_subgraph = _build_subgraph(
+        builder,
+        tensors=callee_tensors,
+        operators=[callee_add, callee_sub],
+        inputs=[0],
+        outputs=[2, 3],
+    )
+
+    operator_codes = [
+        _build_operator_code(builder, _get_builtin_operator("CALL")),
+        _build_operator_code(builder, _get_builtin_operator("ADD")),
+        _build_operator_code(builder, _get_builtin_operator("SUB")),
+    ]
+    buffers = [
+        _build_buffer(builder),
+        _build_buffer(builder, one.tobytes()),
+        _build_buffer(builder),
+        _build_buffer(builder),
+    ]
+    return _finish_tflite_model(
+        builder,
+        subgraph=main_subgraph,
+        extra_subgraphs=[callee_subgraph],
+        operator_codes=operator_codes,
+        buffers=buffers,
+    )
+
+
+def test_call_subgraph_multi_output():
+    """Test CALL tuple returns are split and rebound to TFLite output tensors."""
+    mod = _load_model_from_buffer(_build_tflite_multi_output_call_model())
+
+    @I.ir_module
+    class Expected:
+        @R.function(private=True)
+        def tflite_call_subgraph_1(
+            tvmgen_tensor_0: R.Tensor((2, 2), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((2, 2), dtype="float32"), R.Tensor((2, 2), dtype="float32")):
+            with R.dataflow():
+                gv: R.Tensor((2, 2), dtype="float32") = R.add(
+                    tvmgen_tensor_0, R.const(1.0, "float32")
+                )
+                gv1: R.Tensor((2, 2), dtype="float32") = R.subtract(
+                    tvmgen_tensor_0, R.const(1.0, "float32")
+                )
+                gv2: R.Tuple(
+                    R.Tensor((2, 2), dtype="float32"), R.Tensor((2, 2), dtype="float32")
+                ) = (gv, gv1)
+                R.output(gv2)
+            return gv2
+
+        @R.function
+        def main(
+            tvmgen_tensor_0: R.Tensor((2, 2), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((2, 2), dtype="float32"), R.Tensor((2, 2), dtype="float32")):
+            R.func_attr({"num_input": 1})
+            cls = Expected
+            with R.dataflow():
+                lv: R.Tuple(
+                    R.Tensor((2, 2), dtype="float32"), R.Tensor((2, 2), dtype="float32")
+                ) = cls.tflite_call_subgraph_1(tvmgen_tensor_0)
+                lv1: R.Tensor((2, 2), dtype="float32") = lv[0]
+                lv2: R.Tensor((2, 2), dtype="float32") = lv[1]
+                gv: R.Tuple(
+                    R.Tensor((2, 2), dtype="float32"), R.Tensor((2, 2), dtype="float32")
+                ) = (lv1, lv2)
+                R.output(gv)
+            return gv
+
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def _build_tflite_nested_call_model():
+    """Build a TFLite model where main CALLs subgraph A, which CALLs subgraph B."""
+    builder = flatbuffers.Builder(1024)
+
+    main_call_options = _build_call_options(builder, 1)
+    nested_call_options = _build_call_options(builder, 2)
+    one = np.array(1.0, dtype=np.float32)
+
+    main_tensors = [
+        _build_tensor(builder, 0, [2, 2]),
+        _build_tensor(builder, 3, [2, 2]),
+    ]
+    main_call = _build_operator(
+        builder,
+        0,
+        [0],
+        [1],
+        builtin_options_type=_tfl_builtin_options.CallOptions,
+        builtin_options=main_call_options,
+    )
+    main_subgraph = _build_subgraph(
+        builder,
+        tensors=main_tensors,
+        operators=[main_call],
+        inputs=[0],
+        outputs=[1],
+    )
+
+    caller_tensors = [
+        _build_tensor(builder, 0, [2, 2]),
+        _build_tensor(builder, 3, [2, 2]),
+    ]
+    nested_call = _build_operator(
+        builder,
+        0,
+        [0],
+        [1],
+        builtin_options_type=_tfl_builtin_options.CallOptions,
+        builtin_options=nested_call_options,
+    )
+    caller_subgraph = _build_subgraph(
+        builder,
+        tensors=caller_tensors,
+        operators=[nested_call],
+        inputs=[0],
+        outputs=[1],
+    )
+
+    callee_tensors = [
+        _build_tensor(builder, 0, [2, 2]),
+        _build_tensor(builder, 1, []),
+        _build_tensor(builder, 3, [2, 2]),
+    ]
+    callee_add = _build_operator(builder, 1, [0, 1], [2])
+    callee_subgraph = _build_subgraph(
+        builder,
+        tensors=callee_tensors,
+        operators=[callee_add],
+        inputs=[0],
+        outputs=[2],
+    )
+
+    operator_codes = [
+        _build_operator_code(builder, _get_builtin_operator("CALL")),
+        _build_operator_code(builder, _get_builtin_operator("ADD")),
+    ]
+    buffers = [
+        _build_buffer(builder),
+        _build_buffer(builder, one.tobytes()),
+        _build_buffer(builder),
+        _build_buffer(builder),
+    ]
+    return _finish_tflite_model(
+        builder,
+        subgraph=main_subgraph,
+        extra_subgraphs=[caller_subgraph, callee_subgraph],
+        operator_codes=operator_codes,
+        buffers=buffers,
+    )
+
+
+def test_call_subgraph_nested_call():
+    """Test nested CALL subgraphs register all generated private functions."""
+    mod = _load_model_from_buffer(_build_tflite_nested_call_model())
+
+    @I.ir_module
+    class Expected:
+        @R.function(private=True)
+        def tflite_call_subgraph_2(
+            tvmgen_tensor_0: R.Tensor((2, 2), dtype="float32"),
+        ) -> R.Tensor((2, 2), dtype="float32"):
+            with R.dataflow():
+                gv: R.Tensor((2, 2), dtype="float32") = R.add(
+                    tvmgen_tensor_0, R.const(1.0, "float32")
+                )
+                R.output(gv)
+            return gv
+
+        @R.function(private=True)
+        def tflite_call_subgraph_1(
+            tvmgen_tensor_0: R.Tensor((2, 2), dtype="float32"),
+        ) -> R.Tensor((2, 2), dtype="float32"):
+            cls = Expected
+            with R.dataflow():
+                gv: R.Tensor((2, 2), dtype="float32") = cls.tflite_call_subgraph_2(tvmgen_tensor_0)
+                R.output(gv)
+            return gv
+
+        @R.function
+        def main(
+            tvmgen_tensor_0: R.Tensor((2, 2), dtype="float32"),
+        ) -> R.Tensor((2, 2), dtype="float32"):
+            R.func_attr({"num_input": 1})
+            cls = Expected
+            with R.dataflow():
+                gv: R.Tensor((2, 2), dtype="float32") = cls.tflite_call_subgraph_1(tvmgen_tensor_0)
+                R.output(gv)
+            return gv
+
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def test_call_subgraph_invalid_index_unsupported():
+    """Test CALL rejects invalid subgraph indices before lowering."""
+    with pytest.raises(tvm.error.OpNotImplemented, match="CALL requires a valid subgraph index"):
+        _load_model_from_buffer(_build_tflite_call_model(call_subgraph_index=2))
+
+
+def test_call_subgraph_io_mismatch_unsupported():
+    """Test CALL rejects callees whose input arity does not match the call site."""
+    with pytest.raises(tvm.error.OpNotImplemented, match="CALL subgraph input count mismatch"):
+        _load_model_from_buffer(_build_tflite_call_model(callee_inputs=[]))
+
+
+def test_call_subgraph_output_metadata_mismatch_unsupported():
+    """Test CALL rejects callees whose output metadata does not match the call site."""
+    with pytest.raises(
+        tvm.error.OpNotImplemented, match="CALL subgraph output tensor metadata mismatch"
+    ):
+        _load_model_from_buffer(_build_tflite_call_model(callee_output_shape=[2]))
+
+
+def _build_tflite_if_model(
+    condition_type=_tfl_tensor_type.BOOL,
+    then_subgraph_index=1,
+    else_subgraph_index=2,
+    then_outputs=None,
+    else_outputs=None,
+    else_input_shape=None,
+    else_input_type=None,
+    else_output_shape=None,
+    else_output_type=None,
+):
+    """Build a TFLite model where IF selects x + 1 or x - 1."""
+    builder = flatbuffers.Builder(1024)
+
+    then_outputs = [2] if then_outputs is None else then_outputs
+    else_outputs = [2] if else_outputs is None else else_outputs
+    else_input_shape = [2, 2] if else_input_shape is None else else_input_shape
+    else_input_type = _tfl_tensor_type.FLOAT32 if else_input_type is None else else_input_type
+    else_output_shape = [2, 2] if else_output_shape is None else else_output_shape
+    else_output_type = _tfl_tensor_type.FLOAT32 if else_output_type is None else else_output_type
+    if_options = _build_if_options(builder, then_subgraph_index, else_subgraph_index)
+    one = np.array(1.0, dtype=np.float32)
+
+    main_tensors = [
+        _build_tensor(builder, 0, [], tensor_type=condition_type),
+        _build_tensor(builder, 1, [2, 2]),
+        _build_tensor(builder, 3, [2, 2]),
+    ]
+    main_if = _build_operator(
+        builder,
+        0,
+        [0, 1],
+        [2],
+        builtin_options_type=_tfl_builtin_options.IfOptions,
+        builtin_options=if_options,
+    )
+    main_subgraph = _build_subgraph(
+        builder,
+        tensors=main_tensors,
+        operators=[main_if],
+        inputs=[0, 1],
+        outputs=[2],
+    )
+
+    then_tensors = [
+        _build_tensor(builder, 1, [2, 2]),
+        _build_tensor(builder, 2, []),
+        _build_tensor(builder, 3, [2, 2]),
+    ]
+    then_add = _build_operator(builder, 1, [0, 1], [2])
+    then_subgraph = _build_subgraph(
+        builder,
+        tensors=then_tensors,
+        operators=[then_add],
+        inputs=[0],
+        outputs=then_outputs,
+    )
+
+    else_tensors = [
+        _build_tensor(builder, 1, else_input_shape, tensor_type=else_input_type),
+        _build_tensor(builder, 2, []),
+        _build_tensor(builder, 3, else_output_shape, tensor_type=else_output_type),
+    ]
+    else_sub = _build_operator(builder, 2, [0, 1], [2])
+    else_subgraph = _build_subgraph(
+        builder,
+        tensors=else_tensors,
+        operators=[else_sub],
+        inputs=[0],
+        outputs=else_outputs,
+    )
+
+    operator_codes = [
+        _build_operator_code(builder, _get_builtin_operator("IF")),
+        _build_operator_code(builder, _get_builtin_operator("ADD")),
+        _build_operator_code(builder, _get_builtin_operator("SUB")),
+    ]
+    buffers = [
+        _build_buffer(builder),
+        _build_buffer(builder),
+        _build_buffer(builder, one.tobytes()),
+        _build_buffer(builder),
+    ]
+    return _finish_tflite_model(
+        builder,
+        subgraph=main_subgraph,
+        extra_subgraphs=[then_subgraph, else_subgraph],
+        operator_codes=operator_codes,
+        buffers=buffers,
+    )
+
+
+def test_if_subgraphs():
+    """Test TFLite IF conversion to Relax If."""
+    mod = _load_model_from_buffer(_build_tflite_if_model())
+
+    @I.ir_module
+    class Expected:
+        @R.function(private=True)
+        def tflite_if_then_subgraph_1(
+            tvmgen_tensor_0: R.Tensor((2, 2), dtype="float32"),
+        ) -> R.Tensor((2, 2), dtype="float32"):
+            with R.dataflow():
+                gv: R.Tensor((2, 2), dtype="float32") = R.add(
+                    tvmgen_tensor_0, R.const(1.0, "float32")
+                )
+                R.output(gv)
+            return gv
+
+        @R.function(private=True)
+        def tflite_if_else_subgraph_2(
+            tvmgen_tensor_0: R.Tensor((2, 2), dtype="float32"),
+        ) -> R.Tensor((2, 2), dtype="float32"):
+            with R.dataflow():
+                gv: R.Tensor((2, 2), dtype="float32") = R.subtract(
+                    tvmgen_tensor_0, R.const(1.0, "float32")
+                )
+                R.output(gv)
+            return gv
+
+        @R.function(private=True)
+        def tflite_if_subgraph_1_2(
+            tvmgen_tensor_0: R.Tensor((), dtype="bool"),
+            tvmgen_tensor_1: R.Tensor((2, 2), dtype="float32"),
+        ) -> R.Tensor((2, 2), dtype="float32"):
+            cls = Expected
+            if tvmgen_tensor_0:
+                gv: R.Tensor((2, 2), dtype="float32") = cls.tflite_if_then_subgraph_1(
+                    tvmgen_tensor_1
+                )
+                cond_result: R.Tensor((2, 2), dtype="float32") = gv
+            else:
+                gv1: R.Tensor((2, 2), dtype="float32") = cls.tflite_if_else_subgraph_2(
+                    tvmgen_tensor_1
+                )
+                cond_result: R.Tensor((2, 2), dtype="float32") = gv1
+            return cond_result
+
+        @R.function
+        def main(
+            tvmgen_tensor_0: R.Tensor((), dtype="bool"),
+            tvmgen_tensor_1: R.Tensor((2, 2), dtype="float32"),
+        ) -> R.Tensor((2, 2), dtype="float32"):
+            R.func_attr({"num_input": 2})
+            cls = Expected
+            with R.dataflow():
+                gv: R.Tensor((2, 2), dtype="float32") = cls.tflite_if_subgraph_1_2(
+                    tvmgen_tensor_0, tvmgen_tensor_1
+                )
+                R.output(gv)
+            return gv
+
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def _build_tflite_multi_output_if_model():
+    """Build a TFLite model where IF returns two tensor outputs."""
+    builder = flatbuffers.Builder(1024)
+
+    if_options = _build_if_options(builder, 1, 2)
+    one = np.array(1.0, dtype=np.float32)
+
+    main_tensors = [
+        _build_tensor(builder, 0, [], tensor_type=_tfl_tensor_type.BOOL),
+        _build_tensor(builder, 1, [2, 2]),
+        _build_tensor(builder, 4, [2, 2]),
+        _build_tensor(builder, 5, [2, 2]),
+    ]
+    main_if = _build_operator(
+        builder,
+        0,
+        [0, 1],
+        [2, 3],
+        builtin_options_type=_tfl_builtin_options.IfOptions,
+        builtin_options=if_options,
+    )
+    main_subgraph = _build_subgraph(
+        builder,
+        tensors=main_tensors,
+        operators=[main_if],
+        inputs=[0, 1],
+        outputs=[2, 3],
+    )
+
+    then_tensors = [
+        _build_tensor(builder, 1, [2, 2]),
+        _build_tensor(builder, 2, []),
+        _build_tensor(builder, 3, [2, 2]),
+        _build_tensor(builder, 4, [2, 2]),
+    ]
+    then_add = _build_operator(builder, 1, [0, 1], [2])
+    then_sub = _build_operator(builder, 2, [0, 1], [3])
+    then_subgraph = _build_subgraph(
+        builder,
+        tensors=then_tensors,
+        operators=[then_add, then_sub],
+        inputs=[0],
+        outputs=[2, 3],
+    )
+
+    else_tensors = [
+        _build_tensor(builder, 1, [2, 2]),
+        _build_tensor(builder, 2, []),
+        _build_tensor(builder, 3, [2, 2]),
+        _build_tensor(builder, 4, [2, 2]),
+    ]
+    else_sub = _build_operator(builder, 2, [0, 1], [2])
+    else_add = _build_operator(builder, 1, [0, 1], [3])
+    else_subgraph = _build_subgraph(
+        builder,
+        tensors=else_tensors,
+        operators=[else_sub, else_add],
+        inputs=[0],
+        outputs=[2, 3],
+    )
+
+    operator_codes = [
+        _build_operator_code(builder, _get_builtin_operator("IF")),
+        _build_operator_code(builder, _get_builtin_operator("ADD")),
+        _build_operator_code(builder, _get_builtin_operator("SUB")),
+    ]
+    buffers = [
+        _build_buffer(builder),
+        _build_buffer(builder),
+        _build_buffer(builder, one.tobytes()),
+        _build_buffer(builder),
+        _build_buffer(builder),
+        _build_buffer(builder),
+    ]
+    return _finish_tflite_model(
+        builder,
+        subgraph=main_subgraph,
+        extra_subgraphs=[then_subgraph, else_subgraph],
+        operator_codes=operator_codes,
+        buffers=buffers,
+    )
+
+
+def test_if_subgraphs_multi_output():
+    """Test IF tuple returns are preserved through the private wrapper function."""
+    mod = _load_model_from_buffer(_build_tflite_multi_output_if_model())
+
+    @I.ir_module
+    class Expected:
+        @R.function(private=True)
+        def tflite_if_then_subgraph_1(
+            tvmgen_tensor_0: R.Tensor((2, 2), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((2, 2), dtype="float32"), R.Tensor((2, 2), dtype="float32")):
+            with R.dataflow():
+                gv: R.Tensor((2, 2), dtype="float32") = R.add(
+                    tvmgen_tensor_0, R.const(1.0, "float32")
+                )
+                gv1: R.Tensor((2, 2), dtype="float32") = R.subtract(
+                    tvmgen_tensor_0, R.const(1.0, "float32")
+                )
+                gv2: R.Tuple(
+                    R.Tensor((2, 2), dtype="float32"), R.Tensor((2, 2), dtype="float32")
+                ) = (gv, gv1)
+                R.output(gv2)
+            return gv2
+
+        @R.function(private=True)
+        def tflite_if_else_subgraph_2(
+            tvmgen_tensor_0: R.Tensor((2, 2), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((2, 2), dtype="float32"), R.Tensor((2, 2), dtype="float32")):
+            with R.dataflow():
+                gv: R.Tensor((2, 2), dtype="float32") = R.subtract(
+                    tvmgen_tensor_0, R.const(1.0, "float32")
+                )
+                gv1: R.Tensor((2, 2), dtype="float32") = R.add(
+                    tvmgen_tensor_0, R.const(1.0, "float32")
+                )
+                gv2: R.Tuple(
+                    R.Tensor((2, 2), dtype="float32"), R.Tensor((2, 2), dtype="float32")
+                ) = (gv, gv1)
+                R.output(gv2)
+            return gv2
+
+        @R.function(private=True)
+        def tflite_if_subgraph_1_2(
+            tvmgen_tensor_0: R.Tensor((), dtype="bool"),
+            tvmgen_tensor_1: R.Tensor((2, 2), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((2, 2), dtype="float32"), R.Tensor((2, 2), dtype="float32")):
+            cls = Expected
+            if tvmgen_tensor_0:
+                gv: R.Tuple(
+                    R.Tensor((2, 2), dtype="float32"), R.Tensor((2, 2), dtype="float32")
+                ) = cls.tflite_if_then_subgraph_1(tvmgen_tensor_1)
+                cond_result: R.Tuple(
+                    R.Tensor((2, 2), dtype="float32"), R.Tensor((2, 2), dtype="float32")
+                ) = gv
+            else:
+                gv1: R.Tuple(
+                    R.Tensor((2, 2), dtype="float32"), R.Tensor((2, 2), dtype="float32")
+                ) = cls.tflite_if_else_subgraph_2(tvmgen_tensor_1)
+                cond_result: R.Tuple(
+                    R.Tensor((2, 2), dtype="float32"), R.Tensor((2, 2), dtype="float32")
+                ) = gv1
+            return cond_result
+
+        @R.function
+        def main(
+            tvmgen_tensor_0: R.Tensor((), dtype="bool"),
+            tvmgen_tensor_1: R.Tensor((2, 2), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((2, 2), dtype="float32"), R.Tensor((2, 2), dtype="float32")):
+            R.func_attr({"num_input": 2})
+            cls = Expected
+            with R.dataflow():
+                lv: R.Tuple(
+                    R.Tensor((2, 2), dtype="float32"), R.Tensor((2, 2), dtype="float32")
+                ) = cls.tflite_if_subgraph_1_2(tvmgen_tensor_0, tvmgen_tensor_1)
+                lv1: R.Tensor((2, 2), dtype="float32") = lv[0]
+                lv2: R.Tensor((2, 2), dtype="float32") = lv[1]
+                gv: R.Tuple(
+                    R.Tensor((2, 2), dtype="float32"), R.Tensor((2, 2), dtype="float32")
+                ) = (lv1, lv2)
+                R.output(gv)
+            return gv
+
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def test_if_subgraphs_non_bool_condition_unsupported():
+    """Test IF rejects non-bool condition tensors."""
+    with pytest.raises(tvm.error.OpNotImplemented, match="IF requires a scalar bool condition"):
+        _load_model_from_buffer(_build_tflite_if_model(condition_type=_tfl_tensor_type.INT32))
+
+
+def test_if_subgraphs_invalid_index_unsupported():
+    """Test IF rejects invalid branch subgraph indices before lowering."""
+    with pytest.raises(tvm.error.OpNotImplemented, match="IF requires a valid subgraph index"):
+        _load_model_from_buffer(_build_tflite_if_model(then_subgraph_index=3))
+
+
+def test_if_subgraphs_output_count_mismatch_unsupported():
+    """Test IF rejects branches whose output arity does not match the call site."""
+    with pytest.raises(tvm.error.OpNotImplemented, match="IF subgraph output count mismatch"):
+        _load_model_from_buffer(_build_tflite_if_model(else_outputs=[]))
+
+
+def test_if_subgraphs_input_metadata_mismatch_unsupported():
+    """Test IF rejects branches whose input metadata does not match the call site."""
+    with pytest.raises(
+        tvm.error.OpNotImplemented, match="IF subgraph input tensor metadata mismatch"
+    ):
+        _load_model_from_buffer(_build_tflite_if_model(else_input_shape=[2]))
+
+
+def test_if_subgraphs_output_metadata_mismatch_unsupported():
+    """Test IF rejects branches whose output metadata does not match the call site."""
+    with pytest.raises(
+        tvm.error.OpNotImplemented, match="IF subgraph output tensor metadata mismatch"
+    ):
+        _load_model_from_buffer(_build_tflite_if_model(else_output_shape=[2]))
+
+
+def _build_tflite_while_model(
+    cond_subgraph_index=1,
+    body_subgraph_index=2,
+    cond_output_type=_tfl_tensor_type.BOOL,
+    cond_input_type=_tfl_tensor_type.INT32,
+    body_outputs=None,
+    body_input_type=_tfl_tensor_type.INT32,
+    body_output_type=_tfl_tensor_type.INT32,
+    main_output_type=_tfl_tensor_type.INT32,
+):
+    """Build a TFLite WHILE model incrementing an int32 scalar until i < 3 is false."""
+    builder = flatbuffers.Builder(1024)
+
+    body_outputs = [2] if body_outputs is None else body_outputs
+    while_options = _build_while_options(builder, cond_subgraph_index, body_subgraph_index)
+    one = np.array(1, dtype=np.int32)
+    three = np.array(3, dtype=np.int32)
+
+    main_tensors = [
+        _build_tensor(builder, 0, [], tensor_type=_tfl_tensor_type.INT32),
+        _build_tensor(builder, 3, [], tensor_type=main_output_type),
+    ]
+    main_while = _build_operator(
+        builder,
+        0,
+        [0],
+        [1],
+        builtin_options_type=_tfl_builtin_options.WhileOptions,
+        builtin_options=while_options,
+    )
+    main_subgraph = _build_subgraph(
+        builder,
+        tensors=main_tensors,
+        operators=[main_while],
+        inputs=[0],
+        outputs=[1],
+    )
+
+    cond_tensors = [
+        _build_tensor(builder, 0, [], tensor_type=cond_input_type),
+        _build_tensor(builder, 1, [], tensor_type=_tfl_tensor_type.INT32),
+        _build_tensor(builder, 3, [], tensor_type=cond_output_type),
+    ]
+    cond_less = _build_operator(builder, 1, [0, 1], [2])
+    cond_subgraph = _build_subgraph(
+        builder,
+        tensors=cond_tensors,
+        operators=[cond_less],
+        inputs=[0],
+        outputs=[2],
+    )
+
+    body_tensors = [
+        _build_tensor(builder, 0, [], tensor_type=body_input_type),
+        _build_tensor(builder, 2, [], tensor_type=_tfl_tensor_type.INT32),
+        _build_tensor(builder, 3, [], tensor_type=body_output_type),
+    ]
+    body_add = _build_operator(builder, 2, [0, 1], [2])
+    body_subgraph = _build_subgraph(
+        builder,
+        tensors=body_tensors,
+        operators=[body_add],
+        inputs=[0],
+        outputs=body_outputs,
+    )
+
+    operator_codes = [
+        _build_operator_code(builder, _get_builtin_operator("WHILE")),
+        _build_operator_code(builder, _get_builtin_operator("LESS")),
+        _build_operator_code(builder, _get_builtin_operator("ADD")),
+    ]
+    buffers = [
+        _build_buffer(builder),
+        _build_buffer(builder, three.tobytes()),
+        _build_buffer(builder, one.tobytes()),
+        _build_buffer(builder),
+    ]
+    return _finish_tflite_model(
+        builder,
+        subgraph=main_subgraph,
+        extra_subgraphs=[cond_subgraph, body_subgraph],
+        operator_codes=operator_codes,
+        buffers=buffers,
+    )
+
+
+def _build_tflite_repeated_while_model():
+    """Build a TFLite model where two WHILE ops share the same cond/body subgraphs."""
+    builder = flatbuffers.Builder(1024)
+
+    while_options = _build_while_options(builder, 1, 2)
+    one = np.array(1, dtype=np.int32)
+    three = np.array(3, dtype=np.int32)
+
+    main_tensors = [
+        _build_tensor(builder, 0, [], tensor_type=_tfl_tensor_type.INT32),
+        _build_tensor(builder, 3, [], tensor_type=_tfl_tensor_type.INT32),
+        _build_tensor(builder, 4, [], tensor_type=_tfl_tensor_type.INT32),
+    ]
+    main_while_0 = _build_operator(
+        builder,
+        0,
+        [0],
+        [1],
+        builtin_options_type=_tfl_builtin_options.WhileOptions,
+        builtin_options=while_options,
+    )
+    main_while_1 = _build_operator(
+        builder,
+        0,
+        [1],
+        [2],
+        builtin_options_type=_tfl_builtin_options.WhileOptions,
+        builtin_options=while_options,
+    )
+    main_subgraph = _build_subgraph(
+        builder,
+        tensors=main_tensors,
+        operators=[main_while_0, main_while_1],
+        inputs=[0],
+        outputs=[2],
+    )
+
+    cond_tensors = [
+        _build_tensor(builder, 0, [], tensor_type=_tfl_tensor_type.INT32),
+        _build_tensor(builder, 1, [], tensor_type=_tfl_tensor_type.INT32),
+        _build_tensor(builder, 3, [], tensor_type=_tfl_tensor_type.BOOL),
+    ]
+    cond_less = _build_operator(builder, 1, [0, 1], [2])
+    cond_subgraph = _build_subgraph(
+        builder,
+        tensors=cond_tensors,
+        operators=[cond_less],
+        inputs=[0],
+        outputs=[2],
+    )
+
+    body_tensors = [
+        _build_tensor(builder, 0, [], tensor_type=_tfl_tensor_type.INT32),
+        _build_tensor(builder, 2, [], tensor_type=_tfl_tensor_type.INT32),
+        _build_tensor(builder, 3, [], tensor_type=_tfl_tensor_type.INT32),
+    ]
+    body_add = _build_operator(builder, 2, [0, 1], [2])
+    body_subgraph = _build_subgraph(
+        builder,
+        tensors=body_tensors,
+        operators=[body_add],
+        inputs=[0],
+        outputs=[2],
+    )
+
+    operator_codes = [
+        _build_operator_code(builder, _get_builtin_operator("WHILE")),
+        _build_operator_code(builder, _get_builtin_operator("LESS")),
+        _build_operator_code(builder, _get_builtin_operator("ADD")),
+    ]
+    buffers = [
+        _build_buffer(builder),
+        _build_buffer(builder, three.tobytes()),
+        _build_buffer(builder, one.tobytes()),
+        _build_buffer(builder),
+        _build_buffer(builder),
+    ]
+    return _finish_tflite_model(
+        builder,
+        subgraph=main_subgraph,
+        extra_subgraphs=[cond_subgraph, body_subgraph],
+        operator_codes=operator_codes,
+        buffers=buffers,
+    )
+
+
+def _build_tflite_zero_var_while_model():
+    """Build a TFLite WHILE model with no loop-carried tensors."""
+    builder = flatbuffers.Builder(1024)
+
+    while_options = _build_while_options(builder, 1, 2)
+    main_while = _build_operator(
+        builder,
+        0,
+        [],
+        [],
+        builtin_options_type=_tfl_builtin_options.WhileOptions,
+        builtin_options=while_options,
+    )
+    main_subgraph = _build_subgraph(
+        builder,
+        tensors=[],
+        operators=[main_while],
+        inputs=[],
+        outputs=[],
+    )
+    cond_subgraph = _build_subgraph(builder, tensors=[], operators=[], inputs=[], outputs=[])
+    body_subgraph = _build_subgraph(builder, tensors=[], operators=[], inputs=[], outputs=[])
+
+    operator_codes = [_build_operator_code(builder, _get_builtin_operator("WHILE"))]
+    buffers = [_build_buffer(builder)]
+    return _finish_tflite_model(
+        builder,
+        subgraph=main_subgraph,
+        extra_subgraphs=[cond_subgraph, body_subgraph],
+        operator_codes=operator_codes,
+        buffers=buffers,
+    )
+
+
+def test_while_subgraphs():
+    """Test TFLite WHILE conversion to a recursive Relax private function."""
+    mod = _load_model_from_buffer(_build_tflite_while_model())
+
+    @I.ir_module
+    class Expected:
+        @R.function(private=True)
+        def tflite_while_cond_subgraph_1(
+            tvmgen_tensor_0: R.Tensor((), dtype="int32"),
+        ) -> R.Tensor((), dtype="bool"):
+            with R.dataflow():
+                gv: R.Tensor((), dtype="bool") = R.less(tvmgen_tensor_0, R.const(3, "int32"))
+                R.output(gv)
+            return gv
+
+        @R.function(private=True)
+        def tflite_while_body_subgraph_2(
+            tvmgen_tensor_0: R.Tensor((), dtype="int32"),
+        ) -> R.Tensor((), dtype="int32"):
+            with R.dataflow():
+                gv: R.Tensor((), dtype="int32") = R.add(tvmgen_tensor_0, R.const(1, "int32"))
+                R.output(gv)
+            return gv
+
+        @R.function(private=True)
+        def tflite_while_subgraph_1_2(
+            tvmgen_tensor_0: R.Tensor((), dtype="int32"),
+        ) -> R.Tensor((), dtype="int32"):
+            cls = Expected
+            while_cond: R.Tensor((), dtype="bool") = cls.tflite_while_cond_subgraph_1(
+                tvmgen_tensor_0
+            )
+            if while_cond:
+                gv: R.Tensor((), dtype="int32") = cls.tflite_while_body_subgraph_2(tvmgen_tensor_0)
+                gv1: R.Tensor((), dtype="int32") = cls.tflite_while_subgraph_1_2(gv)
+                cond_result: R.Tensor((), dtype="int32") = gv1
+            else:
+                cond_result: R.Tensor((), dtype="int32") = tvmgen_tensor_0
+            return cond_result
+
+        @R.function
+        def main(
+            tvmgen_tensor_0: R.Tensor((), dtype="int32"),
+        ) -> R.Tensor((), dtype="int32"):
+            R.func_attr({"num_input": 1})
+            cls = Expected
+            with R.dataflow():
+                gv: R.Tensor((), dtype="int32") = cls.tflite_while_subgraph_1_2(tvmgen_tensor_0)
+                R.output(gv)
+            return gv
+
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def test_while_subgraphs_repeated_cond_body_pair():
+    """Test repeated WHILE ops reuse the same recursive private function."""
+    mod = _load_model_from_buffer(_build_tflite_repeated_while_model())
+    names = [gv.name_hint for gv in mod.get_global_vars()]
+    assert names.count("tflite_while_subgraph_1_2") == 1
+
+
+def _build_tflite_two_var_while_model():
+    """Build a TFLite WHILE model with two int32 loop-carried scalar tensors."""
+    builder = flatbuffers.Builder(1024)
+
+    while_options = _build_while_options(builder, 1, 2)
+    one = np.array(1, dtype=np.int32)
+    three = np.array(3, dtype=np.int32)
+
+    main_tensors = [
+        _build_tensor(builder, 0, [], tensor_type=_tfl_tensor_type.INT32),
+        _build_tensor(builder, 1, [], tensor_type=_tfl_tensor_type.INT32),
+        _build_tensor(builder, 4, [], tensor_type=_tfl_tensor_type.INT32),
+        _build_tensor(builder, 5, [], tensor_type=_tfl_tensor_type.INT32),
+    ]
+    main_while = _build_operator(
+        builder,
+        0,
+        [0, 1],
+        [2, 3],
+        builtin_options_type=_tfl_builtin_options.WhileOptions,
+        builtin_options=while_options,
+    )
+    main_subgraph = _build_subgraph(
+        builder,
+        tensors=main_tensors,
+        operators=[main_while],
+        inputs=[0, 1],
+        outputs=[2, 3],
+    )
+
+    cond_tensors = [
+        _build_tensor(builder, 0, [], tensor_type=_tfl_tensor_type.INT32),
+        _build_tensor(builder, 1, [], tensor_type=_tfl_tensor_type.INT32),
+        _build_tensor(builder, 2, [], tensor_type=_tfl_tensor_type.INT32),
+        _build_tensor(builder, 4, [], tensor_type=_tfl_tensor_type.BOOL),
+    ]
+    cond_less = _build_operator(builder, 1, [0, 2], [3])
+    cond_subgraph = _build_subgraph(
+        builder,
+        tensors=cond_tensors,
+        operators=[cond_less],
+        inputs=[0, 1],
+        outputs=[3],
+    )
+
+    body_tensors = [
+        _build_tensor(builder, 0, [], tensor_type=_tfl_tensor_type.INT32),
+        _build_tensor(builder, 1, [], tensor_type=_tfl_tensor_type.INT32),
+        _build_tensor(builder, 3, [], tensor_type=_tfl_tensor_type.INT32),
+        _build_tensor(builder, 4, [], tensor_type=_tfl_tensor_type.INT32),
+        _build_tensor(builder, 5, [], tensor_type=_tfl_tensor_type.INT32),
+    ]
+    body_add_i = _build_operator(builder, 2, [0, 2], [3])
+    body_add_acc = _build_operator(builder, 2, [1, 0], [4])
+    body_subgraph = _build_subgraph(
+        builder,
+        tensors=body_tensors,
+        operators=[body_add_i, body_add_acc],
+        inputs=[0, 1],
+        outputs=[3, 4],
+    )
+
+    operator_codes = [
+        _build_operator_code(builder, _get_builtin_operator("WHILE")),
+        _build_operator_code(builder, _get_builtin_operator("LESS")),
+        _build_operator_code(builder, _get_builtin_operator("ADD")),
+    ]
+    buffers = [
+        _build_buffer(builder),
+        _build_buffer(builder),
+        _build_buffer(builder, three.tobytes()),
+        _build_buffer(builder, one.tobytes()),
+        _build_buffer(builder),
+        _build_buffer(builder),
+    ]
+    return _finish_tflite_model(
+        builder,
+        subgraph=main_subgraph,
+        extra_subgraphs=[cond_subgraph, body_subgraph],
+        operator_codes=operator_codes,
+        buffers=buffers,
+    )
+
+
+def test_while_subgraphs_two_loop_vars():
+    """Test WHILE tuple loop state with two loop-carried variables."""
+    mod = _load_model_from_buffer(_build_tflite_two_var_while_model())
+
+    @I.ir_module
+    class Expected:
+        @R.function(private=True)
+        def tflite_while_cond_subgraph_1(
+            tvmgen_tensor_0: R.Tensor((), dtype="int32"),
+            tvmgen_tensor_1: R.Tensor((), dtype="int32"),
+        ) -> R.Tensor((), dtype="bool"):
+            with R.dataflow():
+                gv: R.Tensor((), dtype="bool") = R.less(tvmgen_tensor_0, R.const(3, "int32"))
+                R.output(gv)
+            return gv
+
+        @R.function(private=True)
+        def tflite_while_body_subgraph_2(
+            tvmgen_tensor_0: R.Tensor((), dtype="int32"),
+            tvmgen_tensor_1: R.Tensor((), dtype="int32"),
+        ) -> R.Tuple(R.Tensor((), dtype="int32"), R.Tensor((), dtype="int32")):
+            with R.dataflow():
+                gv: R.Tensor((), dtype="int32") = R.add(tvmgen_tensor_0, R.const(1, "int32"))
+                gv1: R.Tensor((), dtype="int32") = R.add(tvmgen_tensor_1, tvmgen_tensor_0)
+                gv2: R.Tuple(R.Tensor((), dtype="int32"), R.Tensor((), dtype="int32")) = (
+                    gv,
+                    gv1,
+                )
+                R.output(gv2)
+            return gv2
+
+        @R.function(private=True)
+        def tflite_while_subgraph_1_2(
+            tvmgen_tensor_0: R.Tensor((), dtype="int32"),
+            tvmgen_tensor_1: R.Tensor((), dtype="int32"),
+        ) -> R.Tuple(R.Tensor((), dtype="int32"), R.Tensor((), dtype="int32")):
+            cls = Expected
+            while_cond: R.Tensor((), dtype="bool") = cls.tflite_while_cond_subgraph_1(
+                tvmgen_tensor_0, tvmgen_tensor_1
+            )
+            if while_cond:
+                gv: R.Tuple(R.Tensor((), dtype="int32"), R.Tensor((), dtype="int32")) = (
+                    cls.tflite_while_body_subgraph_2(tvmgen_tensor_0, tvmgen_tensor_1)
+                )
+                gv1: R.Tensor((), dtype="int32") = gv[0]
+                gv2: R.Tensor((), dtype="int32") = gv[1]
+                gv3: R.Tuple(R.Tensor((), dtype="int32"), R.Tensor((), dtype="int32")) = (
+                    cls.tflite_while_subgraph_1_2(gv1, gv2)
+                )
+                cond_result: R.Tuple(R.Tensor((), dtype="int32"), R.Tensor((), dtype="int32")) = gv3
+            else:
+                cond_result: R.Tuple(R.Tensor((), dtype="int32"), R.Tensor((), dtype="int32")) = (
+                    tvmgen_tensor_0,
+                    tvmgen_tensor_1,
+                )
+            return cond_result
+
+        @R.function
+        def main(
+            tvmgen_tensor_0: R.Tensor((), dtype="int32"),
+            tvmgen_tensor_1: R.Tensor((), dtype="int32"),
+        ) -> R.Tuple(R.Tensor((), dtype="int32"), R.Tensor((), dtype="int32")):
+            R.func_attr({"num_input": 2})
+            cls = Expected
+            with R.dataflow():
+                lv: R.Tuple(R.Tensor((), dtype="int32"), R.Tensor((), dtype="int32")) = (
+                    cls.tflite_while_subgraph_1_2(tvmgen_tensor_0, tvmgen_tensor_1)
+                )
+                lv1: R.Tensor((), dtype="int32") = lv[0]
+                lv2: R.Tensor((), dtype="int32") = lv[1]
+                gv: R.Tuple(R.Tensor((), dtype="int32"), R.Tensor((), dtype="int32")) = (
+                    lv1,
+                    lv2,
+                )
+                R.output(gv)
+            return gv
+
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def test_while_subgraphs_non_bool_condition_unsupported():
+    """Test WHILE rejects cond subgraphs that do not return scalar bool."""
+    with pytest.raises(tvm.error.OpNotImplemented, match="WHILE requires a scalar bool condition"):
+        _load_model_from_buffer(_build_tflite_while_model(cond_output_type=_tfl_tensor_type.INT32))
+
+
+def test_while_subgraphs_invalid_index_unsupported():
+    """Test WHILE rejects invalid cond/body subgraph indices before lowering."""
+    with pytest.raises(tvm.error.OpNotImplemented, match="WHILE requires a valid subgraph index"):
+        _load_model_from_buffer(_build_tflite_while_model(cond_subgraph_index=3))
+
+
+def test_while_subgraphs_zero_loop_vars_unsupported():
+    """Test WHILE rejects operators without loop-carried tensors."""
+    with pytest.raises(tvm.error.OpNotImplemented, match="WHILE requires loop-carried inputs"):
+        _load_model_from_buffer(_build_tflite_zero_var_while_model())
+
+
+def test_while_subgraphs_loop_state_metadata_mismatch_unsupported():
+    """Test WHILE rejects loop outputs whose metadata does not match loop inputs."""
+    with pytest.raises(
+        tvm.error.OpNotImplemented, match="WHILE loop state tensor metadata mismatch"
+    ):
+        _load_model_from_buffer(
+            _build_tflite_while_model(main_output_type=_tfl_tensor_type.FLOAT32)
+        )
+
+
+def test_while_subgraphs_output_count_mismatch_unsupported():
+    """Test WHILE rejects body subgraphs whose output arity does not match loop vars."""
+    with pytest.raises(tvm.error.OpNotImplemented, match="WHILE subgraph output count mismatch"):
+        _load_model_from_buffer(_build_tflite_while_model(body_outputs=[]))
+
+
+def test_while_subgraphs_input_metadata_mismatch_unsupported():
+    """Test WHILE rejects cond subgraph inputs whose metadata does not match loop vars."""
+    with pytest.raises(
+        tvm.error.OpNotImplemented, match="WHILE subgraph input tensor metadata mismatch"
+    ):
+        _load_model_from_buffer(_build_tflite_while_model(cond_input_type=_tfl_tensor_type.FLOAT32))
+
+
+def test_while_subgraphs_output_metadata_mismatch_unsupported():
+    """Test WHILE rejects body outputs whose metadata does not match loop vars."""
+    with pytest.raises(
+        tvm.error.OpNotImplemented, match="WHILE subgraph output tensor metadata mismatch"
+    ):
+        _load_model_from_buffer(
+            _build_tflite_while_model(body_output_type=_tfl_tensor_type.FLOAT32)
+        )
+
+
+def _build_tflite_call_once_model(
+    init_has_op=False,
+    init_subgraph_index=1,
+    call_once_inputs=None,
+    call_once_outputs=None,
+    init_inputs=None,
+    init_outputs=None,
+):
+    """Build a TFLite model with CALL_ONCE and one pass-through output."""
+    builder = flatbuffers.Builder(1024)
+
+    call_once_inputs = [] if call_once_inputs is None else call_once_inputs
+    call_once_outputs = [] if call_once_outputs is None else call_once_outputs
+    init_inputs = [] if init_inputs is None else init_inputs
+    init_outputs = [] if init_outputs is None else init_outputs
+
+    call_once_options = _build_call_once_options(builder, init_subgraph_index)
+    main_tensors = [_build_tensor(builder, 0, [2, 2])]
+    main_call_once = _build_operator(
+        builder,
+        0,
+        call_once_inputs,
+        call_once_outputs,
+        builtin_options_type=_tfl_builtin_options.CallOnceOptions,
+        builtin_options=call_once_options,
+    )
+    main_subgraph = _build_subgraph(
+        builder,
+        tensors=main_tensors,
+        operators=[main_call_once],
+        inputs=[0],
+        outputs=[0],
+    )
+
+    if init_has_op:
+        one = np.array(1.0, dtype=np.float32)
+        init_tensors = [
+            _build_tensor(builder, 0, [2, 2]),
+            _build_tensor(builder, 1, []),
+            _build_tensor(builder, 2, [2, 2]),
+        ]
+        init_op = _build_operator(builder, 1, [0, 1], [2])
+        buffers = [
+            _build_buffer(builder),
+            _build_buffer(builder, one.tobytes()),
+            _build_buffer(builder),
+        ]
+    else:
+        init_tensors = (
+            [_build_tensor(builder, 0, [2, 2])]
+            if len(init_inputs) != 0 or len(init_outputs) != 0
+            else []
+        )
+        init_op = None
+        buffers = [_build_buffer(builder)]
+
+    init_subgraph = _build_subgraph(
+        builder,
+        tensors=init_tensors,
+        operators=[] if init_op is None else [init_op],
+        inputs=init_inputs,
+        outputs=init_outputs,
+    )
+
+    operator_codes = [_build_operator_code(builder, _get_builtin_operator("CALL_ONCE"))]
+    if init_has_op:
+        operator_codes.append(_build_operator_code(builder, _get_builtin_operator("ADD")))
+    return _finish_tflite_model(
+        builder,
+        subgraph=main_subgraph,
+        extra_subgraphs=[init_subgraph],
+        operator_codes=operator_codes,
+        buffers=buffers,
+    )
+
+
+def test_call_once_empty_init_subgraph():
+    """Test the no-op CALL_ONCE subset."""
+    mod = _load_model_from_buffer(_build_tflite_call_once_model())
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            tvmgen_tensor_0: R.Tensor((2, 2), dtype="float32"),
+        ) -> R.Tensor((2, 2), dtype="float32"):
+            R.func_attr({"num_input": 1})
+            with R.dataflow():
+                gv: R.Tensor((2, 2), dtype="float32") = tvmgen_tensor_0
+                R.output(gv)
+            return gv
+
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def test_call_once_non_empty_init_subgraph_unsupported():
+    """Test CALL_ONCE rejects init subgraphs with side-effect-like bodies."""
+    with pytest.raises(tvm.error.OpNotImplemented, match="CALL_ONCE"):
+        _load_model_from_buffer(_build_tflite_call_once_model(init_has_op=True))
+
+
+def test_call_once_inputs_outputs_unsupported():
+    """Test CALL_ONCE rejects operator inputs and outputs."""
+    with pytest.raises(tvm.error.OpNotImplemented, match="CALL_ONCE with inputs or outputs"):
+        _load_model_from_buffer(
+            _build_tflite_call_once_model(call_once_inputs=[0], call_once_outputs=[0])
+        )
+
+
+def test_call_once_init_subgraph_io_unsupported():
+    """Test CALL_ONCE rejects init subgraphs with inputs or outputs."""
+    with pytest.raises(
+        tvm.error.OpNotImplemented, match="CALL_ONCE with non-empty init subgraph I/O"
+    ):
+        _load_model_from_buffer(_build_tflite_call_once_model(init_inputs=[0], init_outputs=[0]))
+
+
+def test_call_once_invalid_index_unsupported():
+    """Test CALL_ONCE rejects invalid init subgraph indices before lowering."""
+    with pytest.raises(
+        tvm.error.OpNotImplemented, match="CALL_ONCE requires a valid subgraph index"
+    ):
+        _load_model_from_buffer(_build_tflite_call_once_model(init_subgraph_index=2))
+
+
 def _get_stablehlo_builtin_operator(builtin_name):
     if not hasattr(_tfl_builtin_operator, builtin_name):
         pytest.skip(f"TFLite schema does not provide BuiltinOperator.{builtin_name}")
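
The recursive private function that `test_while_subgraphs` and `test_while_subgraphs_two_loop_vars` expect (`tflite_while_subgraph_1_2`) encodes ordinary while-loop semantics: call the cond subgraph on the loop state; if it returns true, call the body subgraph and recurse on its result; otherwise return the state unchanged. The following is a minimal plain-Python sketch of that evaluation order, using the same arithmetic as the test models (`x < 3` / `x + 1` for one loop variable, `i < 3` / `(i + 1, acc + i)` for two). `run_while` is a hypothetical helper for illustration only and is not part of the TVM frontend; it also always treats the loop state as a tuple, whereas the single-variable Relax function passes a bare tensor.

```python
from typing import Callable, Tuple

State = Tuple[int, ...]


def run_while(cond: Callable[..., bool], body: Callable[..., State], state: State) -> State:
    """Hypothetical helper mirroring the expected recursive Relax function:
    if cond(state) holds, apply body and recurse; otherwise return state as-is."""
    if cond(*state):
        return run_while(cond, body, body(*state))
    return state


# One loop-carried value, as in _build_tflite_while_model: cond x < 3, body x + 1.
single = run_while(lambda x: x < 3, lambda x: (x + 1,), (0,))

# Two loop-carried values, as in _build_tflite_two_var_while_model:
# cond i < 3 (acc is ignored), body (i + 1, acc + i).
double = run_while(lambda i, acc: i < 3, lambda i, acc: (i + 1, acc + i), (0, 0))

print(single, double)  # prints: (3,) (3, 3)
```

Because each iteration is a recursive call, this sketch is bounded by Python's recursion limit and is only meant to illustrate the control flow of the expected IR, not how a compiled module executes it.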
