This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new fff3b4bf0d [Relax][Frontend][TFLite] Support StableHLO region-based ops and multi-subgraph models (#19587)
fff3b4bf0d is described below

commit fff3b4bf0d82cded1c397b07706daf265c441ed0
Author: HoYi <[email protected]>
AuthorDate: Thu May 21 12:37:39 2026 +0800

    [Relax][Frontend][TFLite] Support StableHLO region-based ops and multi-subgraph models (#19587)
    
    ## Summary
    
    This PR adds Relax TFLite frontend support for 10 additional StableHLO
    builtin operators from #19519 item I, building on the 29 ops merged in
    PR #19536.
    
    The first 5 ops are direct single-subgraph converters: `CBRT`,
    `REMAINDER`, `DYNAMIC_UPDATE_SLICE`, `DOT_GENERAL`, and `CONVOLUTION`.
    The remaining 5 ops are region/subgraph-based: `REDUCE`,
    `REDUCE_WINDOW`, `SORT`, `SCATTER`, and `COMPOSITE`. To support these,
    the TFLite frontend is extended to accept multi-subgraph models while
    still converting only `Subgraphs(0)` into the Relax main function.
    Region subgraphs are consumed by their parent op converters as needed.
    
    Relates to #19519.
    
    ## Changes
    
    1. **Single-subgraph ops**
       - `CBRT` — sign-preserving composite expression:
         `where(x < 0, -power(-x, 1/3), power(x, 1/3))`. Float dtype only.
       - `REMAINDER` — truncating remainder via `x - y * trunc(x / y)`,
         matching StableHLO semantics (sign follows dividend). Float dtype
         only.
       - `DYNAMIC_UPDATE_SLICE` — static start indices + static shapes only,
         lowered to `R.scatter_nd` with a coordinate grid generated via
         `np.indices`. Runtime starts and out-of-bounds ranges raise
         `OpNotImplemented`.
       - `DOT_GENERAL` — canonical 2D matmul subset: no batching dims,
         `lhs_contracting=[1]`, `rhs_contracting=[0]`, lowered to
         `R.matmul`.
       - `CONVOLUTION` — canonical 2D NHWC/HWIO subset with
         `BatchGroupCount=1` and `FeatureGroupCount=1`, lowered to
         `R.nn.conv2d`. Non-canonical dimension numbers and grouped/depthwise
         conv raise `OpNotImplemented`.
    
    2. **Multi-subgraph infrastructure**
       - Loosen the `from_tflite()` assertion from
         `model.SubgraphsLength() == 1` to `model.SubgraphsLength() >= 1`.
         Only `Subgraphs(0)` is converted into the Relax main function.
       - Limit `_input_type()` to `Subgraphs(0)` inputs, preventing region
         parameters from leaking as Relax main function parameters.
       - Add the `_get_stablehlo_simple_body_op` helper for validating and
         extracting the single operator from a region body subgraph.
       - Extend the test helper `_finish_tflite_model` with an
         `extra_subgraphs` parameter for constructing multi-subgraph TFLite
         flatbuffers.
    
    3. **Region/subgraph ops**
       - `REDUCE` — single-op reducer body subgraph. Supports `ADD` → `R.sum`,
         `MAXIMUM` → `R.max`, `MINIMUM` → `R.min`, `MULTIPLY` → `R.prod`.
         Init value must match the reducer identity element.
       - `SORT` — single-op comparator body subgraph. `LT` → ascending sort,
         `GT` → descending sort via `R.sort`. Stable sort (`IsStable`) is not
         mapped and raises `OpNotImplemented`.
       - `REDUCE_WINDOW` — NHWC 4D 2D-pooling subset with `MAXIMUM` reducer
         and identity init, lowered to `R.nn.max_pool2d`. `BaseDilations`
         must be all 1.
       - `SCATTER` — single-op update computation body subgraph. Supports
         `ADD`/`MAXIMUM`/`MINIMUM`/`MULTIPLY` → `R.scatter_nd` with the
         corresponding reduction mode. Only canonical point-update semantics
         (no window dims).
       - `COMPOSITE` — inlines a decomposition subgraph through a recursive
         `OperatorConverter` with an isolated `ExprTable`, so decomposition
         tensor bindings cannot overwrite main graph bindings. Only supports
         composites without `CompositeAttributes`.
    
    4. **Not included**
       - `STABLEHLO_RESHAPE`, `STABLEHLO_TRANSPOSE`, and `STABLEHLO_SLICE` are
         left to another contributor.
       - `WHILE`, `CUSTOM_CALL`, and `RNG_BIT_GENERATOR` are deferred to
         follow-up PRs.
    
    5. **Bug fix**
       - Fixed the `DYNAMIC_UPDATE_SLICE` `scatter_nd` indices layout:
         `np.indices` returns `(rank, *update_shape)` but `scatter_nd` expects
         `(*update_shape, rank)`. Added `np.moveaxis` to move the coordinate
         axis from the first to the last position.
    
    ## Testing
    
    All tests use manually built minimal TFLite flatbuffers and compare the
    converted module against the expected one with
    `tvm.ir.assert_structural_equal`. Region/subgraph tests construct the
    smallest valid body/comparator/update subgraphs. BuiltinOptions2 ops
    construct their options via the FlatBuffers schema API.
    
    ```bash
    python -m pytest tests/python/relax/test_frontend_tflite.py -k stablehlo -q
    ```
    
    ## Result
    
    - 39 StableHLO operators registered in the Relax TFLite frontend (29
      from PR #19536 + 10 from this PR).
    - 77 StableHLO test cases covering all registered ops, including
      structural-equal tests and unsupported/error-path checks:
    
      - `REMAINDER` truncating semantics
      - `DYNAMIC_UPDATE_SLICE` with dynamic starts and out-of-bounds starts
      - `DOT_GENERAL` with non-canonical contracting dimensions
      - `CONVOLUTION` with non-canonical dimension numbers and
        `FeatureGroupCount > 1`
      - `REDUCE` with unsupported reducer and non-identity init value
      - `SORT` with unsupported comparator and stable sort
      - `REDUCE_WINDOW` with unsupported reducer and base dilation
      - `SCATTER` with unsupported reducer and update window dims
      - `COMPOSITE` with composite attributes and scope isolation
      - Multi-subgraph model with unused subgraphs
    - All 77 StableHLO tests pass.
    
    ## References
    
    - Issue #19519 item I: StableHLO operators in TFLite
    - PR #19536: First batch of 29 StableHLO ops
---
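The `CBRT` and `REMAINDER` lowerings described in the commit message are element-wise formulas, so their intended semantics can be checked against NumPy. The sketch below is a hypothetical reference, not frontend code (`cbrt_ref` and `remainder_ref` are illustrative names); it mirrors the Relax expressions in the patch and contrasts the truncating remainder with NumPy's floor-based `np.mod`.

```python
import numpy as np


def cbrt_ref(x):
    # Mirrors the lowering: where(x < 0, -power(-x, 1/3), power(x, 1/3)).
    # np.where evaluates both branches, so silence the NaN that
    # power(negative, 1/3) produces in the branch that is discarded.
    with np.errstate(invalid="ignore"):
        return np.where(x < 0, -np.power(-x, 1.0 / 3.0), np.power(x, 1.0 / 3.0))


def remainder_ref(x, y):
    # Truncating remainder: x - y * trunc(x / y); the sign follows the dividend x.
    return x - y * np.trunc(x / y)


x = np.array([-8.0, -1.0, 0.0, 27.0], dtype=np.float32)
print(cbrt_ref(x))  # approximately -2, -1, 0, 3

a = np.array([7.0, -7.0, 7.0, -7.0], dtype=np.float32)
b = np.array([3.0, 3.0, -3.0, -3.0], dtype=np.float32)
print(remainder_ref(a, b).tolist())  # [1.0, -1.0, 1.0, -1.0]
print(np.mod(a, b).tolist())         # floor semantics: [1.0, 2.0, -2.0, -1.0]
```

`np.mod` follows the sign of the divisor, so it is not a valid reference for StableHLO `remainder`; `np.fmod` is NumPy's truncating counterpart and agrees with `remainder_ref` on these inputs.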
 .../tvm/relax/frontend/tflite/tflite_frontend.py   |  631 ++++++++++-
 tests/python/relax/test_frontend_tflite.py         | 1168 +++++++++++++++++++-
 2 files changed, 1776 insertions(+), 23 deletions(-)
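The `DYNAMIC_UPDATE_SLICE` index-layout fix (Changes, item 5) can be reproduced with NumPy alone. This is a minimal sketch, not the converter itself: `dynamic_update_slice_ref` is a hypothetical name, and emulating `scatter_nd` in `"update"` mode with fancy indexing assumes ONNX-style ScatterND semantics, where the last axis of `indices` holds one full coordinate per update element. Like the converter, it only handles static, in-bounds start indices.

```python
import numpy as np


def dynamic_update_slice_ref(operand, update, starts):
    """Hypothetical NumPy reference for the static-start DYNAMIC_UPDATE_SLICE lowering."""
    # np.indices returns shape (rank, *update.shape): one coordinate plane per axis.
    indices = np.indices(update.shape, dtype=np.int64)
    for axis, start in enumerate(starts):
        indices[axis] += start
    # scatter_nd expects (*update.shape, rank), so move the coordinate axis last.
    indices = np.moveaxis(indices, 0, -1)

    # Emulate scatter_nd in "update" mode: out[indices[p]] = update[p] for every p.
    out = operand.copy()
    rank = indices.shape[-1]
    out[tuple(indices[..., k] for k in range(rank))] = update
    return out, indices


operand = np.zeros((4, 5), dtype=np.float32)
update = np.arange(6, dtype=np.float32).reshape(2, 3)
result, indices = dynamic_update_slice_ref(operand, update, starts=(1, 2))
print(indices.shape)  # (2, 3, 2)
```

Without the `np.moveaxis` call the index tensor would have shape `(2, 2, 3)`, which a ScatterND-style op would read as a 2x2 grid of length-3 coordinates: the wrong layout for a rank-2 operand.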

diff --git a/python/tvm/relax/frontend/tflite/tflite_frontend.py b/python/tvm/relax/frontend/tflite/tflite_frontend.py
index 145e953394..28b125eec0 100644
--- a/python/tvm/relax/frontend/tflite/tflite_frontend.py
+++ b/python/tvm/relax/frontend/tflite/tflite_frontend.py
@@ -244,15 +244,20 @@ class OperatorConverter:
             "STABLEHLO_ADD": functools.partial(self._convert_stablehlo_binary, relax_op=_op.add),
             "STABLEHLO_AND": self._convert_stablehlo_and,
             "STABLEHLO_BROADCAST_IN_DIM": self._convert_stablehlo_broadcast_in_dim,
+            "STABLEHLO_CBRT": self._convert_stablehlo_cbrt,
             "STABLEHLO_CLAMP": self._convert_stablehlo_clamp,
             "STABLEHLO_COMPARE": self._convert_stablehlo_compare,
+            "STABLEHLO_COMPOSITE": self._convert_stablehlo_composite,
             "STABLEHLO_CONCATENATE": self._convert_stablehlo_concatenate,
+            "STABLEHLO_CONVOLUTION": self._convert_stablehlo_convolution,
             "STABLEHLO_CONVERT": self._convert_stablehlo_convert,
             "STABLEHLO_COSINE": functools.partial(self._convert_stablehlo_unary, relax_op=_op.cos),
             "STABLEHLO_DIVIDE": functools.partial(
                 self._convert_stablehlo_binary, relax_op=_op.divide
             ),
+            "STABLEHLO_DOT_GENERAL": self._convert_stablehlo_dot_general,
             "STABLEHLO_DYNAMIC_SLICE": self._convert_stablehlo_dynamic_slice,
+            "STABLEHLO_DYNAMIC_UPDATE_SLICE": self._convert_stablehlo_dynamic_update_slice,
             "STABLEHLO_EXPONENTIAL": functools.partial(
                 self._convert_stablehlo_unary, relax_op=_op.exp
             ),
@@ -280,13 +285,18 @@ class OperatorConverter:
             "STABLEHLO_POWER": functools.partial(
                 self._convert_stablehlo_binary, relax_op=_op.power
             ),
+            "STABLEHLO_REDUCE": self._convert_stablehlo_reduce,
+            "STABLEHLO_REDUCE_WINDOW": self._convert_stablehlo_reduce_window,
+            "STABLEHLO_REMAINDER": self._convert_stablehlo_remainder,
             "STABLEHLO_RSQRT": functools.partial(self._convert_stablehlo_unary, relax_op=_op.rsqrt),
+            "STABLEHLO_SCATTER": self._convert_stablehlo_scatter,
             "STABLEHLO_SELECT": functools.partial(
                 self._convert_stablehlo_ternary, relax_op=_op.where
             ),
             "STABLEHLO_SHIFT_LEFT": functools.partial(
                 self._convert_stablehlo_binary, relax_op=_op.left_shift
             ),
+            "STABLEHLO_SORT": self._convert_stablehlo_sort,
             "STABLEHLO_SUBTRACT": functools.partial(
                 self._convert_stablehlo_binary, relax_op=_op.subtract
             ),
@@ -1483,6 +1493,413 @@ class OperatorConverter:
         result.Init(op_options.Bytes, op_options.Pos)
         return result
 
+    def _get_static_tensor_shape(self, tensor, op_name):
+        """Return a statically-known TFLite tensor shape as Python ints."""
+        try:
+            return [int(dim) for dim in self.get_tensor_shape(tensor)]
+        except (TypeError, ValueError) as err:
+            raise tvm.error.OpNotImplemented(
+                f"{op_name} requires statically-known tensor shapes"
+            ) from err
+
+    def _get_stablehlo_i64_vector(self, vector, default):
+        """Convert an optional StableHLO int64 vector field to a Python int list."""
+        if vector is None or isinstance(vector, int):
+            return list(default)
+        return [int(v) for v in vector]
+
+    def _ensure_stablehlo_float_dtype(self, expr, op_name):
+        """Return expr dtype if the StableHLO subset supports it."""
+        dtype = expr.struct_info.dtype
+        if not dtype.startswith("float"):
+            raise tvm.error.OpNotImplemented(f"{op_name} with dtype {dtype} is not supported")
+        return dtype
+
+    def _convert_stablehlo_cbrt(self, op):
+        """Convert STABLEHLO_CBRT to a sign-preserving Relax expression."""
+        input_tensors = self.get_input_tensors(op)
+        assert len(input_tensors) == 1, "input tensors length should be 1"
+        assert len(self.get_output_tensors(op)) == 1
+
+        data = self.get_tensor_expr(input_tensors[0])
+        dtype = self._ensure_stablehlo_float_dtype(data, "STABLEHLO_CBRT")
+        zero = relax.const(0, dtype)
+        exponent = relax.const(1.0 / 3.0, dtype)
+
+        is_negative = self.bb.normalize(relax.op.less(data, zero))
+        negative_base = self.bb.normalize(relax.op.negative(data))
+        negative_root = self.bb.normalize(relax.op.power(negative_base, exponent))
+        negative_result = self.bb.normalize(relax.op.negative(negative_root))
+        positive_result = self.bb.normalize(relax.op.power(data, exponent))
+        return self.bb.normalize(relax.op.where(is_negative, negative_result, positive_result))
+
+    def _convert_stablehlo_remainder(self, op):
+        """Convert STABLEHLO_REMAINDER to truncating remainder for float tensors."""
+        input_tensors = self.get_input_tensors(op)
+        assert len(input_tensors) == 2, "input tensors length should be 2"
+        assert len(self.get_output_tensors(op)) == 1
+
+        lhs = self.get_tensor_expr(input_tensors[0])
+        rhs = self.get_tensor_expr(input_tensors[1])
+        self._ensure_stablehlo_float_dtype(lhs, "STABLEHLO_REMAINDER")
+        self._ensure_stablehlo_float_dtype(rhs, "STABLEHLO_REMAINDER")
+
+        quotient = self.bb.normalize(relax.op.divide(lhs, rhs))
+        truncated = self.bb.normalize(relax.op.trunc(quotient))
+        product = self.bb.normalize(relax.op.multiply(rhs, truncated))
+        return self.bb.normalize(relax.op.subtract(lhs, product))
+
+    def _get_stablehlo_simple_body_op(self, body_subgraph_index, parent_op_name, input_count):
+        """Return the single operator from a simple StableHLO body subgraph."""
+        if body_subgraph_index <= 0 or body_subgraph_index >= self.model.SubgraphsLength():
+            raise tvm.error.OpNotImplemented(
+                f"{parent_op_name} requires a valid non-main body subgraph"
+            )
+
+        body_subgraph = self.model.Subgraphs(body_subgraph_index)
+        if (
+            body_subgraph.InputsLength() != input_count
+            or body_subgraph.OutputsLength() != 1
+            or body_subgraph.OperatorsLength() != 1
+        ):
+            raise tvm.error.OpNotImplemented(
+                f"{parent_op_name} only supports single-op body subgraphs"
+            )
+
+        return body_subgraph.Operators(0)
+
+    def _check_stablehlo_reduce_init(
+        self, init_tensor, reducer_name, parent_op_name="STABLEHLO_REDUCE"
+    ):
+        """Validate that the StableHLO reduce init value matches the Relax identity."""
+        if self.has_expr(init_tensor.tensor_idx):
+            raise tvm.error.OpNotImplemented(
+                f"{parent_op_name} with dynamic init values is not supported"
+            )
+
+        init_value = np.asarray(self.get_tensor_value(init_tensor))
+        if init_value.shape not in [(), (1,)]:
+            raise tvm.error.OpNotImplemented(f"{parent_op_name} requires scalar init values")
+
+        dtype = init_value.dtype
+        scalar = init_value.item()
+        if reducer_name == "STABLEHLO_ADD":
+            is_identity = bool(np.isclose(scalar, 0))
+        elif reducer_name == "STABLEHLO_MULTIPLY":
+            is_identity = bool(np.isclose(scalar, 1))
+        elif reducer_name == "STABLEHLO_MAXIMUM":
+            if np.issubdtype(dtype, np.floating):
+                is_identity = bool(np.isneginf(scalar))
+            elif np.issubdtype(dtype, np.integer):
+                is_identity = scalar == np.iinfo(dtype).min
+            else:
+                is_identity = False
+        elif reducer_name == "STABLEHLO_MINIMUM":
+            if np.issubdtype(dtype, np.floating):
+                is_identity = bool(np.isposinf(scalar))
+            elif np.issubdtype(dtype, np.integer):
+                is_identity = scalar == np.iinfo(dtype).max
+            else:
+                is_identity = False
+        else:
+            raise tvm.error.OpNotImplemented(
+                f"{parent_op_name} reducer {reducer_name} is not supported"
+            )
+
+        if not is_identity:
+            raise tvm.error.OpNotImplemented(
+                f"{parent_op_name} init value must match the reducer identity"
+            )
+
+    def _convert_stablehlo_reduce(self, op):
+        """Convert the single-input STABLEHLO_REDUCE subset to Relax reductions."""
+        from tflite.StablehloReduceOptions import StablehloReduceOptions
+
+        input_tensors = self.get_input_tensors(op)
+        assert len(input_tensors) == 2, "input tensors length should be 2"
+        assert len(self.get_output_tensors(op)) == 1
+
+        opts = self._get_stablehlo_options(op, StablehloReduceOptions)
+        dimensions = self._get_stablehlo_i64_vector(opts.DimensionsAsNumpy(), [])
+        body_op = self._get_stablehlo_simple_body_op(
+            int(opts.BodySubgraphIndex()), "STABLEHLO_REDUCE", 2
+        )
+        reducer_name = self.get_op_code_str(body_op)
+
+        reducers = {
+            "STABLEHLO_ADD": relax.op.sum,
+            "STABLEHLO_MAXIMUM": relax.op.max,
+            "STABLEHLO_MINIMUM": relax.op.min,
+            "STABLEHLO_MULTIPLY": relax.op.prod,
+        }
+        if reducer_name not in reducers:
+            raise tvm.error.OpNotImplemented(
+                f"STABLEHLO_REDUCE reducer {reducer_name} is not supported"
+            )
+
+        self._check_stablehlo_reduce_init(input_tensors[1], reducer_name)
+        data = self.get_tensor_expr(input_tensors[0])
+        return self.bb.normalize(reducers[reducer_name](data, axis=dimensions, keepdims=False))
+
+    def _convert_stablehlo_reduce_window(self, op):
+        """Convert the NHWC 2D max-pool STABLEHLO_REDUCE_WINDOW subset."""
+        from tflite.StablehloReduceWindowOptions import StablehloReduceWindowOptions
+
+        input_tensors = self.get_input_tensors(op)
+        assert len(input_tensors) == 2, "input tensors length should be 2"
+        assert len(self.get_output_tensors(op)) == 1
+
+        opts = self._get_stablehlo_options(op, StablehloReduceWindowOptions)
+        body_op = self._get_stablehlo_simple_body_op(
+            int(opts.BodySubgraphIndex()), "STABLEHLO_REDUCE_WINDOW", 2
+        )
+        reducer_name = self.get_op_code_str(body_op)
+        if reducer_name != "STABLEHLO_MAXIMUM":
+            raise tvm.error.OpNotImplemented(
+                "STABLEHLO_REDUCE_WINDOW only supports MAXIMUM reducer windows"
+            )
+        self._check_stablehlo_reduce_init(input_tensors[1], reducer_name, "STABLEHLO_REDUCE_WINDOW")
+
+        data_shape = self._get_static_tensor_shape(input_tensors[0], "STABLEHLO_REDUCE_WINDOW")
+        if len(data_shape) != 4:
+            raise tvm.error.OpNotImplemented("STABLEHLO_REDUCE_WINDOW only supports 4D input")
+
+        window_dimensions = self._get_stablehlo_i64_vector(opts.WindowDimensionsAsNumpy(), [])
+        window_strides = self._get_stablehlo_i64_vector(
+            opts.WindowStridesAsNumpy(), [1] * len(window_dimensions)
+        )
+        base_dilations = self._get_stablehlo_i64_vector(
+            opts.BaseDilationsAsNumpy(), [1] * len(window_dimensions)
+        )
+        window_dilations = self._get_stablehlo_i64_vector(
+            opts.WindowDilationsAsNumpy(), [1] * len(window_dimensions)
+        )
+        padding = self._get_stablehlo_i64_vector(
+            opts.PaddingAsNumpy(), [0] * (2 * len(window_dimensions))
+        )
+
+        if (
+            len(window_dimensions) != 4
+            or len(window_strides) != 4
+            or len(base_dilations) != 4
+            or len(window_dilations) != 4
+            or len(padding) != 8
+        ):
+            raise tvm.error.OpNotImplemented(
+                "STABLEHLO_REDUCE_WINDOW only supports rank-4 window attributes"
+            )
+        if window_dimensions[0] != 1 or window_dimensions[3] != 1:
+            raise tvm.error.OpNotImplemented(
+                "STABLEHLO_REDUCE_WINDOW only supports pooling over spatial dimensions"
+            )
+        if window_strides[0] != 1 or window_strides[3] != 1:
+            raise tvm.error.OpNotImplemented(
+                "STABLEHLO_REDUCE_WINDOW only supports unit batch/channel strides"
+            )
+        if base_dilations != [1, 1, 1, 1]:
+            raise tvm.error.OpNotImplemented(
+                "STABLEHLO_REDUCE_WINDOW with base dilation is not supported"
+            )
+        if padding[0] != 0 or padding[1] != 0 or padding[6] != 0 or padding[7] != 0:
+            raise tvm.error.OpNotImplemented(
+                "STABLEHLO_REDUCE_WINDOW only supports spatial padding"
+            )
+
+        data = self.get_tensor_expr(input_tensors[0])
+        return self.bb.normalize(
+            relax.op.nn.max_pool2d(
+                data,
+                pool_size=[window_dimensions[1], window_dimensions[2]],
+                strides=[window_strides[1], window_strides[2]],
+                padding=[padding[2], padding[4], padding[3], padding[5]],
+                dilation=[window_dilations[1], window_dilations[2]],
+                layout="NHWC",
+                out_layout="NHWC",
+            )
+        )
+
+    def _convert_stablehlo_scatter(self, op):
+        """Convert the canonical point-update STABLEHLO_SCATTER subset."""
+        from tflite.StablehloScatterOptions import StablehloScatterOptions
+
+        input_tensors = self.get_input_tensors(op)
+        assert len(input_tensors) == 3, "input tensors length should be 3"
+        assert len(self.get_output_tensors(op)) == 1
+
+        opts = self._get_stablehlo_options(op, StablehloScatterOptions)
+        operand_shape = self._get_static_tensor_shape(input_tensors[0], "STABLEHLO_SCATTER")
+        indices_shape = self._get_static_tensor_shape(input_tensors[1], "STABLEHLO_SCATTER")
+        updates_shape = self._get_static_tensor_shape(input_tensors[2], "STABLEHLO_SCATTER")
+        operand_rank = len(operand_shape)
+        indices_rank = len(indices_shape)
+
+        update_window_dims = self._get_stablehlo_i64_vector(opts.UpdateWindowDimsAsNumpy(), [])
+        inserted_window_dims = self._get_stablehlo_i64_vector(opts.InsertedWindowDimsAsNumpy(), [])
+        scatter_dims_to_operand_dims = self._get_stablehlo_i64_vector(
+            opts.ScatterDimsToOperandDimsAsNumpy(), []
+        )
+        index_vector_dim = int(opts.IndexVectorDim())
+
+        if indices_rank == 0 or index_vector_dim != indices_rank - 1:
+            raise tvm.error.OpNotImplemented(
+                "STABLEHLO_SCATTER only supports trailing index-vector dimensions"
+            )
+        if update_window_dims:
+            raise tvm.error.OpNotImplemented(
+                "STABLEHLO_SCATTER only supports point updates without update windows"
+            )
+        if inserted_window_dims != list(range(operand_rank)):
+            raise tvm.error.OpNotImplemented(
+                "STABLEHLO_SCATTER only supports point updates for every operand dimension"
+            )
+        if scatter_dims_to_operand_dims != list(range(operand_rank)):
+            raise tvm.error.OpNotImplemented(
+                "STABLEHLO_SCATTER only supports canonical scatter-to-operand dimensions"
+            )
+        if indices_shape[-1] != operand_rank or updates_shape != indices_shape[:-1]:
+            raise tvm.error.OpNotImplemented(
+                "STABLEHLO_SCATTER requires point update shapes to match scatter indices"
+            )
+
+        body_op = self._get_stablehlo_simple_body_op(
+            int(opts.UpdateComputationSubgraphIndex()), "STABLEHLO_SCATTER", 2
+        )
+        reducer_name = self.get_op_code_str(body_op)
+        reductions = {
+            "STABLEHLO_ADD": "add",
+            "STABLEHLO_MAXIMUM": "max",
+            "STABLEHLO_MINIMUM": "min",
+            "STABLEHLO_MULTIPLY": "mul",
+        }
+        if reducer_name not in reductions:
+            raise tvm.error.OpNotImplemented(
+                f"STABLEHLO_SCATTER reducer {reducer_name} is not supported"
+            )
+
+        operand = self.get_tensor_expr(input_tensors[0])
+        indices = self.get_tensor_expr(input_tensors[1])
+        updates = self.get_tensor_expr(input_tensors[2])
+        return self.bb.normalize(
+            relax.op.scatter_nd(operand, indices, updates, reductions[reducer_name])
+        )
+
+    def _convert_stablehlo_composite(self, op):
+        """Convert STABLEHLO_COMPOSITE by inlining a simple decomposition subgraph."""
+        from tflite.StableHLOCompositeOptions import StableHLOCompositeOptions
+
+        input_tensors = self.get_input_tensors(op)
+        output_tensors = self.get_output_tensors(op)
+        if len(output_tensors) != 1:
+            raise tvm.error.OpNotImplemented(
+                "STABLEHLO_COMPOSITE only supports single-output decompositions"
+            )
+
+        opts = self._get_stablehlo_options(op, StableHLOCompositeOptions)
+        composite_name = opts.Name()
+        composite_name = (
+            composite_name.decode("utf-8") if composite_name is not None else "<unnamed>"
+        )
+        if opts.CompositeAttributesLength() != 0:
+            raise tvm.error.OpNotImplemented(
+                f"STABLEHLO_COMPOSITE {composite_name} with composite attributes is not supported"
+            )
+
+        decomposition_subgraph_index = int(opts.DecompositionSubgraphIndex())
+        if (
+            decomposition_subgraph_index <= 0
+            or decomposition_subgraph_index >= self.model.SubgraphsLength()
+        ):
+            raise tvm.error.OpNotImplemented(
+                f"STABLEHLO_COMPOSITE {composite_name} requires a valid decomposition subgraph"
+            )
+        decomposition_subgraph = self.model.Subgraphs(decomposition_subgraph_index)
+        if decomposition_subgraph.InputsLength() != len(input_tensors):
+            raise tvm.error.OpNotImplemented(
+                f"STABLEHLO_COMPOSITE {composite_name} decomposition input count mismatch"
+            )
+        if decomposition_subgraph.OutputsLength() != 1:
+            raise tvm.error.OpNotImplemented(
+                f"STABLEHLO_COMPOSITE {composite_name} only supports single-output decompositions"
+            )
+
+        decomposition_exp_tab = ExprTable()
+        decomposition_converter = OperatorConverter(
+            self.model, decomposition_subgraph, decomposition_exp_tab, self.bb
+        )
+        for decomposition_input_idx, composite_input in zip(
+            decomposition_subgraph.InputsAsNumpy(), input_tensors
+        ):
+            decomposition_input_name = get_tensor_name(
+                decomposition_subgraph, int(decomposition_input_idx)
+            )
+            decomposition_exp_tab.set_expr(
+                decomposition_input_name,
+                self.get_tensor_expr(composite_input),
+                force_override=True,
+            )
+
+        decomposition_converter.check_unsupported_ops()
+        decomposition_converter.convert_op_to_relax()
+        decomposition_output_idx = int(decomposition_subgraph.Outputs(0))
+        decomposition_output_tensor = decomposition_converter.get_tensors(
+            [decomposition_output_idx]
+        )[0]
+        for const_expr, value in decomposition_exp_tab.params.values():
+            param_name = f"_param_{self.exp_tab.const_ctr}"
+            self.exp_tab.const_ctr += 1
+            self.exp_tab.params[param_name] = (const_expr, value)
+        return decomposition_converter.get_tensor_expr(decomposition_output_tensor)
+
+    def _convert_stablehlo_sort(self, op):
+        """Convert the single-input STABLEHLO_SORT subset to Relax sort."""
+        from tflite.StablehloCompareOptions import StablehloCompareOptions
+        from tflite.StablehloComparisonDirection import StablehloComparisonDirection
+        from tflite.StablehloComparisonType import StablehloComparisonType
+        from tflite.StablehloSortOptions import StablehloSortOptions
+
+        input_tensors = self.get_input_tensors(op)
+        output_tensors = self.get_output_tensors(op)
+        if len(input_tensors) != 1 or len(output_tensors) != 1:
+            raise tvm.error.OpNotImplemented(
+                "STABLEHLO_SORT only supports single-input single-output sort"
+            )
+
+        opts = self._get_stablehlo_options(op, StablehloSortOptions)
+        if opts.IsStable():
+            raise tvm.error.OpNotImplemented("STABLEHLO_SORT stable sort is not supported")
+
+        body_op = self._get_stablehlo_simple_body_op(
+            int(opts.ComparatorSubgraphIndex()), "STABLEHLO_SORT", 2
+        )
+        comparator_name = self.get_op_code_str(body_op)
+        if comparator_name != "STABLEHLO_COMPARE":
+            raise tvm.error.OpNotImplemented(
+                f"STABLEHLO_SORT comparator {comparator_name} is not supported"
+            )
+
+        compare_opts = self._get_stablehlo_options(body_op, StablehloCompareOptions)
+        if (
+            compare_opts.CompareType()
+            == StablehloComparisonType.STABLEHLO_COMPARISON_TYPE_FLOAT_TOTAL_ORDER
+        ):
+            raise tvm.error.OpNotImplemented(
+                "STABLEHLO_SORT with TOTALORDER comparator is not supported"
+            )
+
+        direction = compare_opts.ComparisonDirection()
+        _DIR = StablehloComparisonDirection
+        if direction == _DIR.STABLEHLO_COMPARISON_DIRECTION_LT:
+            descending = False
+        elif direction == _DIR.STABLEHLO_COMPARISON_DIRECTION_GT:
+            descending = True
+        else:
+            raise tvm.error.OpNotImplemented("STABLEHLO_SORT only supports LT or GT comparators")
+
+        data = self.get_tensor_expr(input_tensors[0])
+        return self.bb.normalize(
+            relax.op.sort(data, axis=int(opts.Dimension()), descending=descending)
+        )
+
     def _convert_stablehlo_convert(self, op):
         """Convert STABLEHLO_CONVERT to Relax (astype).
 
@@ -1719,6 +2136,189 @@ class OperatorConverter:
 
         return self.bb.normalize(relax.op.dynamic_strided_slice(operand, begin, end, strides))
 
+    def _convert_stablehlo_dynamic_update_slice(self, op):
+        """Convert STABLEHLO_DYNAMIC_UPDATE_SLICE to Relax for static starts."""
+        input_tensors = self.get_input_tensors(op)
+        # operand + update + N start-index scalars
+        assert len(input_tensors) >= 3, "input tensors length should be >= 3"
+        assert len(self.get_output_tensors(op)) == 1
+
+        operand_tensor = input_tensors[0]
+        update_tensor = input_tensors[1]
+        start_tensors = input_tensors[2:]
+
+        op_name = "STABLEHLO_DYNAMIC_UPDATE_SLICE"
+        operand_shape = self._get_static_tensor_shape(operand_tensor, op_name)
+        update_shape = self._get_static_tensor_shape(update_tensor, op_name)
+        rank = len(operand_shape)
+        if len(update_shape) != rank or len(start_tensors) != rank:
+            raise tvm.error.OpNotImplemented(
+                "STABLEHLO_DYNAMIC_UPDATE_SLICE requires operand, update, "
+                "and start-index ranks to match"
+            )
+
+        if any(self.has_expr(t.tensor_idx) for t in start_tensors):
+            raise tvm.error.OpNotImplemented(
+                "STABLEHLO_DYNAMIC_UPDATE_SLICE with dynamic start indices is not supported"
+            )
+
+        start_vals = [int(np.asarray(self.get_tensor_value(t)).item()) for t in start_tensors]
+        for start, size, dim in zip(start_vals, update_shape, operand_shape):
+            if start < 0 or start + size > dim:
+                raise tvm.error.OpNotImplemented(
+                    "STABLEHLO_DYNAMIC_UPDATE_SLICE with out-of-bounds update "
+                    "indices is not supported"
+                )
+
+        update_indices = np.indices(update_shape, dtype=np.int64)
+        for axis, start in enumerate(start_vals):
+            update_indices[axis] += start
+        update_indices = np.moveaxis(update_indices, 0, -1)
+
+        operand = self.get_tensor_expr(operand_tensor)
+        update = self.get_tensor_expr(update_tensor)
+        indices = self.bb.normalize(relax.const(update_indices, dtype="int64"))
+        return self.bb.normalize(relax.op.scatter_nd(operand, indices, update, "update"))
+
+    def _convert_stablehlo_dot_general(self, op):
+        """Convert the canonical 2D STABLEHLO_DOT_GENERAL subset to Relax matmul."""
+        from tflite.StablehloDotGeneralOptions import StablehloDotGeneralOptions
+
+        input_tensors = self.get_input_tensors(op)
+        assert len(input_tensors) == 2, "input tensors length should be 2"
+        assert len(self.get_output_tensors(op)) == 1
+
+        opts = self._get_stablehlo_options(op, StablehloDotGeneralOptions)
+        lhs_batch_dims = self._get_stablehlo_i64_vector(opts.LhsBatchingDimensionsAsNumpy(), [])
+        rhs_batch_dims = self._get_stablehlo_i64_vector(opts.RhsBatchingDimensionsAsNumpy(), [])
+        lhs_contract_dims = self._get_stablehlo_i64_vector(
+            opts.LhsContractingDimensionsAsNumpy(), []
+        )
+        rhs_contract_dims = self._get_stablehlo_i64_vector(
+            opts.RhsContractingDimensionsAsNumpy(), []
+        )
+
+        lhs_shape = self._get_static_tensor_shape(input_tensors[0], "STABLEHLO_DOT_GENERAL")
+        rhs_shape = self._get_static_tensor_shape(input_tensors[1], "STABLEHLO_DOT_GENERAL")
+        if len(lhs_shape) != 2 or len(rhs_shape) != 2:
+            raise tvm.error.OpNotImplemented("STABLEHLO_DOT_GENERAL only supports 2D matmul")
+        if lhs_batch_dims or rhs_batch_dims:
+            raise tvm.error.OpNotImplemented(
+                "STABLEHLO_DOT_GENERAL with batching dimensions is not 
supported"
+            )
+        if lhs_contract_dims != [1] or rhs_contract_dims != [0]:
+            raise tvm.error.OpNotImplemented(
+                "STABLEHLO_DOT_GENERAL only supports canonical contracting 
dimensions"
+            )
+
+        lhs = self.get_tensor_expr(input_tensors[0])
+        rhs = self.get_tensor_expr(input_tensors[1])
+        return self.bb.normalize(relax.op.matmul(lhs, rhs))
+
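With no batching dimensions and contracting dimensions `[1]` (lhs) and `[0]` (rhs), DOT_GENERAL contracts the lhs columns against the rhs rows, which is ordinary matrix multiplication; that is why the converter can emit `R.matmul`. The NumPy sketch below checks that equivalence with `np.tensordot`, the general contraction form. `dot_general_2d` is a hypothetical helper that mirrors the converter's guards.

```python
import numpy as np


def dot_general_2d(lhs, rhs, lhs_contracting, rhs_contracting):
    """Hypothetical NumPy model of the DOT_GENERAL subset accepted above."""
    if lhs.ndim != 2 or rhs.ndim != 2:
        raise NotImplementedError("only 2D matmul is supported")
    if list(lhs_contracting) != [1] or list(rhs_contracting) != [0]:
        raise NotImplementedError("only canonical contracting dimensions are supported")
    # Contract lhs axis 1 against rhs axis 0.
    return np.tensordot(lhs, rhs, axes=(list(lhs_contracting), list(rhs_contracting)))


lhs = np.arange(6, dtype=np.float32).reshape(2, 3)
rhs = np.arange(12, dtype=np.float32).reshape(3, 4)
out = dot_general_2d(lhs, rhs, [1], [0])
```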
+    def _convert_stablehlo_convolution(self, op):
+        """Convert the canonical 2D NHWC/HWIO STABLEHLO_CONVOLUTION subset."""
+        from tflite.StablehloConvolutionOptions import StablehloConvolutionOptions
+
+        input_tensors = self.get_input_tensors(op)
+        assert len(input_tensors) == 2, "input tensors length should be 2"
+        assert len(self.get_output_tensors(op)) == 1
+
+        opts = self._get_stablehlo_options(op, StablehloConvolutionOptions)
+        input_spatial_dims = self._get_stablehlo_i64_vector(
+            opts.InputSpatialDimensionsAsNumpy(), []
+        )
+        kernel_spatial_dims = self._get_stablehlo_i64_vector(
+            opts.KernelSpatialDimensionsAsNumpy(), []
+        )
+        output_spatial_dims = self._get_stablehlo_i64_vector(
+            opts.OutputSpatialDimensionsAsNumpy(), []
+        )
+        if input_spatial_dims != [1, 2]:
+            raise tvm.error.OpNotImplemented(
+                "STABLEHLO_CONVOLUTION only supports NHWC input layout"
+            )
+        if kernel_spatial_dims != [0, 1]:
+            raise tvm.error.OpNotImplemented(
+                "STABLEHLO_CONVOLUTION only supports HWIO kernel layout"
+            )
+        if output_spatial_dims != [1, 2]:
+            raise tvm.error.OpNotImplemented(
+                "STABLEHLO_CONVOLUTION only supports NHWC output layout"
+            )
+
+        if (
+            int(opts.InputBatchDimension()) != 0
+            or int(opts.InputFeatureDimension()) != 3
+            or int(opts.KernelInputFeatureDimension()) != 2
+            or int(opts.KernelOutputFeatureDimension()) != 3
+            or int(opts.OutputBatchDimension()) != 0
+            or int(opts.OutputFeatureDimension()) != 3
+        ):
+            raise tvm.error.OpNotImplemented(
+                "STABLEHLO_CONVOLUTION only supports canonical NHWC/HWIO dimension numbers"
+            )
+        if int(opts.BatchGroupCount()) != 1:
+            raise tvm.error.OpNotImplemented(
+                "STABLEHLO_CONVOLUTION with batch_group_count > 1 is not supported"
+            )
+        if int(opts.FeatureGroupCount()) != 1:
+            raise tvm.error.OpNotImplemented(
+                "STABLEHLO_CONVOLUTION with feature_group_count > 1 is not supported"
+            )
+
+        data_shape = self._get_static_tensor_shape(input_tensors[0], "STABLEHLO_CONVOLUTION")
+        kernel_shape = self._get_static_tensor_shape(input_tensors[1], "STABLEHLO_CONVOLUTION")
+        if len(data_shape) != 4 or len(kernel_shape) != 4:
+            raise tvm.error.OpNotImplemented("STABLEHLO_CONVOLUTION only supports 2D convolution")
+        if data_shape[3] != kernel_shape[2]:
+            raise tvm.error.OpNotImplemented(
+                "STABLEHLO_CONVOLUTION input channels must match kernel input channels"
+            )
+
+        window_strides = self._get_stablehlo_i64_vector(opts.WindowStridesAsNumpy(), [1, 1])
+        padding = self._get_stablehlo_i64_vector(opts.PaddingAsNumpy(), [0, 0, 0, 0])
+        lhs_dilation = self._get_stablehlo_i64_vector(opts.LhsDilationAsNumpy(), [1, 1])
+        rhs_dilation = self._get_stablehlo_i64_vector(opts.RhsDilationAsNumpy(), [1, 1])
+        window_reversal = opts.WindowReversalAsNumpy()
+        window_reversal = (
+            [False, False] if window_reversal is None else [bool(v) for v in window_reversal]
+        )
+
+        if len(window_strides) != 2 or len(rhs_dilation) != 2:
+            raise tvm.error.OpNotImplemented(
+                "STABLEHLO_CONVOLUTION only supports two spatial dimensions"
+            )
+        if lhs_dilation != [1, 1]:
+            raise tvm.error.OpNotImplemented(
+                "STABLEHLO_CONVOLUTION with lhs dilation is not supported"
+            )
+        if any(window_reversal):
+            raise tvm.error.OpNotImplemented(
+                "STABLEHLO_CONVOLUTION with window reversal is not supported"
+            )
+        if len(padding) != 4:
+            raise tvm.error.OpNotImplemented(
+                "STABLEHLO_CONVOLUTION only supports 2D low/high padding"
+            )
+
+        # StableHLO stores padding as [low_h, high_h, low_w, high_w].
+        relax_padding = [padding[0], padding[2], padding[1], padding[3]]
+        data = self.get_tensor_expr(input_tensors[0])
+        kernel = self.get_tensor_expr(input_tensors[1])
+        self._ensure_stablehlo_float_dtype(data, "STABLEHLO_CONVOLUTION")
+        self._ensure_stablehlo_float_dtype(kernel, "STABLEHLO_CONVOLUTION")
+        return self.bb.normalize(
+            relax.op.nn.conv2d(
+                data,
+                kernel,
+                strides=window_strides,
+                padding=relax_padding,
+                dilation=rhs_dilation,
+                data_layout="NHWC",
+                kernel_layout="HWIO",
+            )
+        )
+
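The padding reorder in `_convert_stablehlo_convolution` deserves a closer look: StableHLO lists a low/high pair per spatial dimension, while a 4-element Relax `conv2d` padding is read as (top, left, bottom, right). The NumPy sketch below performs the same reorder as `[padding[0], padding[2], padding[1], padding[3]]` and checks it by padding an NHWC tensor. `stablehlo_to_relax_padding` is a hypothetical helper name.

```python
import numpy as np


def stablehlo_to_relax_padding(padding):
    """Reorder StableHLO 2D padding [low_h, high_h, low_w, high_w]
    into the (top, left, bottom, right) order of a 4-element conv2d padding."""
    if len(padding) != 4:
        raise NotImplementedError("only 2D low/high padding is supported")
    low_h, high_h, low_w, high_w = padding
    return [low_h, low_w, high_h, high_w]


padding = [1, 2, 3, 4]  # low_h=1, high_h=2, low_w=3, high_w=4
top, left, bottom, right = stablehlo_to_relax_padding(padding)

# Apply the reordered padding to an NHWC tensor and inspect the spatial extents.
x = np.ones((1, 2, 2, 1), dtype=np.float32)
padded = np.pad(x, ((0, 0), (top, bottom), (left, right), (0, 0)))
```

Height grows by `low_h + high_h` and width by `low_w + high_w`, so `padded` has shape `(1, 5, 9, 1)`.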
     def _convert_stablehlo_gather(self, op):
         """Convert STABLEHLO_GATHER to Relax (take-equivalent subset only).
 
@@ -5528,19 +6128,18 @@ def _input_type(model):
     assert subgraph_count > 0
     shape_dict = {}
     dtype_dict = {}
-    for subgraph_index in range(subgraph_count):
-        subgraph = model.Subgraphs(subgraph_index)
-        inputs_count = subgraph.InputsLength()
-        # TFLite subgraphs can validly have zero inputs (e.g. constant-only RANGE models).
-        for input_index in range(inputs_count):
-            input_ = subgraph.Inputs(input_index)
-            assert subgraph.TensorsLength() > input_
-            tensor = subgraph.Tensors(input_)
-            input_shape = tuple(tensor.ShapeAsNumpy())
-            tensor_type = tensor.Type()
-            input_name = get_tensor_name(subgraph, input_)
-            shape_dict[input_name] = input_shape
-            dtype_dict[input_name] = _decode_type(tensor_type)
+    subgraph = model.Subgraphs(0)
+    inputs_count = subgraph.InputsLength()
+    # TFLite subgraphs can validly have zero inputs (e.g. constant-only RANGE models).
+    for input_index in range(inputs_count):
+        input_ = subgraph.Inputs(input_index)
+        assert subgraph.TensorsLength() > input_
+        tensor = subgraph.Tensors(input_)
+        input_shape = tuple(tensor.ShapeAsNumpy())
+        tensor_type = tensor.Type()
+        input_name = get_tensor_name(subgraph, input_)
+        shape_dict[input_name] = input_shape
+        dtype_dict[input_name] = _decode_type(tensor_type)
 
     return shape_dict, dtype_dict
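The loop over every subgraph was removed because region bodies are also stored as subgraphs, and their inputs can reuse names from the main graph. `_build_stablehlo_model_with_unused_subgraph` in the tests sets up exactly that situation. The pure-Python sketch below models subgraphs as hypothetical dicts (not the `tflite` API). It shows how merging all subgraphs lets a region body overwrite the main input shape, while reading only `Subgraphs(0)` preserves it.

```python
# Hypothetical, simplified model: each subgraph is a dict listing (name, shape) inputs.
def collect_inputs_all_subgraphs(subgraphs):
    """Pre-change behavior: merge inputs from every subgraph (later ones win)."""
    shape_dict = {}
    for subgraph in subgraphs:
        for name, shape in subgraph["inputs"]:
            shape_dict[name] = shape
    return shape_dict


def collect_inputs_main_only(subgraphs):
    """Post-change behavior: only Subgraphs(0) defines the Relax main signature."""
    return {name: shape for name, shape in subgraphs[0]["inputs"]}


subgraphs = [
    {"inputs": [("x", (2, 2)), ("y", (2, 2))]},  # index 0: main graph
    {"inputs": [("x", (4, 4)), ("y", (4, 4))]},  # index 1: region body reusing the names
]

merged = collect_inputs_all_subgraphs(subgraphs)
main_only = collect_inputs_main_only(subgraphs)
```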
 
@@ -5652,8 +6251,10 @@ def from_tflite(
     if dtype_dict is not None:
         _dtype_dict.update(dtype_dict)
 
-    # keep the same as tflite
-    assert model.SubgraphsLength() == 1, "only support one subgraph (main subgraph)"
+    # Only Subgraphs(0) is converted into Relax main. Additional subgraphs are
+    # region bodies referenced by specific TFLite ops and are consumed by those
+    # op converters as needed.
+    assert model.SubgraphsLength() >= 1, "TFLite model must contain at least one subgraph"
     subgraph = model.Subgraphs(0)
 
     # model inputs / outputs
diff --git a/tests/python/relax/test_frontend_tflite.py b/tests/python/relax/test_frontend_tflite.py
index bb2fb0bfa7..031c1553d8 100644
--- a/tests/python/relax/test_frontend_tflite.py
+++ b/tests/python/relax/test_frontend_tflite.py
@@ -3677,6 +3677,9 @@ _tfl_dilate_options = _get_tflite_schema_module("DilateOptions")
 # ── StableHLO BuiltinOptions2 schema modules ────────────────────────────
 _tfl_stablehlo_concat_opts = _get_tflite_schema_module("StablehloConcatenateOptions")
 _tfl_stablehlo_bcast_opts = _get_tflite_schema_module("StablehloBroadcastInDimOptions")
+_tfl_stablehlo_composite_opts = _get_tflite_schema_module("StableHLOCompositeOptions")
+_tfl_stablehlo_conv_opts = _get_tflite_schema_module("StablehloConvolutionOptions")
+_tfl_stablehlo_dot_opts = _get_tflite_schema_module("StablehloDotGeneralOptions")
 _tfl_stablehlo_iota_opts = _get_tflite_schema_module("StablehloIotaOptions")
 _tfl_stablehlo_compare_opts = _get_tflite_schema_module("StablehloCompareOptions")
 _tfl_stablehlo_comp_dir = _get_tflite_schema_module("StablehloComparisonDirection")
@@ -3684,6 +3687,10 @@ _tfl_stablehlo_comp_type = _get_tflite_schema_module("StablehloComparisonType")
 _tfl_stablehlo_pad_opts = _get_tflite_schema_module("StablehloPadOptions")
 _tfl_stablehlo_dyn_slice_opts = _get_tflite_schema_module("StablehloDynamicSliceOptions")
 _tfl_stablehlo_gather_opts = _get_tflite_schema_module("StablehloGatherOptions")
+_tfl_stablehlo_reduce_opts = _get_tflite_schema_module("StablehloReduceOptions")
+_tfl_stablehlo_reduce_window_opts = _get_tflite_schema_module("StablehloReduceWindowOptions")
+_tfl_stablehlo_scatter_opts = _get_tflite_schema_module("StablehloScatterOptions")
+_tfl_stablehlo_sort_opts = _get_tflite_schema_module("StablehloSortOptions")
 _tfl_dimension_metadata = _get_tflite_schema_module("DimensionMetadata")
 _tfl_fully_connected_options = _get_tflite_schema_module("FullyConnectedOptions")
 _tfl_int32_vector = _get_tflite_schema_module("Int32Vector")
@@ -3721,6 +3728,20 @@ def _tflite_int32_vector(builder, start_vector_fn, values):
     return builder.EndVector()
 
 
+def _tflite_int64_vector(builder, start_vector_fn, values):
+    start_vector_fn(builder, len(values))
+    for value in reversed(values):
+        builder.PrependInt64(value)
+    return builder.EndVector()
+
+
+def _tflite_bool_vector(builder, start_vector_fn, values):
+    start_vector_fn(builder, len(values))
+    for value in reversed(values):
+        builder.PrependBool(value)
+    return builder.EndVector()
+
+
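`_tflite_int64_vector` and `_tflite_bool_vector` (like the existing `_tflite_offset_vector`) iterate over `reversed(values)` because a FlatBuffers builder writes data back to front, so each `Prepend*` call places its element before the ones already written. The toy model below uses a `collections.deque` in place of the real `flatbuffers.Builder` to show why the reversal preserves element order. The class and function names are hypothetical.

```python
from collections import deque


class PrependOnlyBuffer:
    """Toy stand-in for a FlatBuffers builder: values can only be added at the front."""

    def __init__(self):
        self._data = deque()

    def prepend(self, value):
        self._data.appendleft(value)

    def contents(self):
        return list(self._data)


def build_vector(values, reverse_first=True):
    """Mimic the _tflite_*_vector helpers; reverse_first=False shows the bug they avoid."""
    buf = PrependOnlyBuffer()
    for value in (reversed(values) if reverse_first else values):
        buf.prepend(value)
    return buf.contents()


ordered = build_vector([1, 2, 3])
scrambled = build_vector([1, 2, 3], reverse_first=False)
```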
 def _tflite_offset_vector(builder, start_vector_fn, offsets):
     start_vector_fn(builder, len(offsets))
     for offset in reversed(offsets):
@@ -3834,12 +3855,15 @@ def _build_subgraph(builder, *, tensors, operators, inputs, outputs):
     return _tfl_subgraph.SubGraphEnd(builder)
 
 
-def _finish_tflite_model(builder, *, subgraph, operator_codes, buffers):
+def _finish_tflite_model(builder, *, subgraph, operator_codes, buffers, extra_subgraphs=None):
+    all_subgraphs = [subgraph] + (extra_subgraphs or [])
     buffers_vec = _tflite_offset_vector(builder, 
_tfl_model.ModelStartBuffersVector, buffers)
     opcodes_vec = _tflite_offset_vector(
         builder, _tfl_model.ModelStartOperatorCodesVector, operator_codes
     )
-    subgraphs_vec = _tflite_offset_vector(builder, _tfl_model.ModelStartSubgraphsVector, [subgraph])
+    subgraphs_vec = _tflite_offset_vector(
+        builder, _tfl_model.ModelStartSubgraphsVector, all_subgraphs
+    )
 
     _tfl_model.ModelStart(builder)
     _tfl_model.ModelAddBuffers(builder, buffers_vec)
@@ -3896,6 +3920,453 @@ def _build_stablehlo_model(*, builtin_name, input_count):
     )
 
 
+def _build_stablehlo_model_with_unused_subgraph():
+    """Build a StableHLO model with an unused extra subgraph."""
+    builder = flatbuffers.Builder(1024)
+    builtin_op = _get_stablehlo_builtin_operator("STABLEHLO_ADD")
+
+    main_tensors = [_build_tensor(builder, buffer_idx, [2, 2]) for buffer_idx in range(3)]
+    main_op = _build_operator(builder, 0, [0, 1], [2])
+    main_subgraph = _build_subgraph(
+        builder,
+        tensors=main_tensors,
+        operators=[main_op],
+        inputs=[0, 1],
+        outputs=[2],
+    )
+
+    # Give the unused subgraph a conflicting input tensor name and different
+    # shape. from_tflite should infer the main function input shape only from
+    # Subgraphs(0).
+    extra_tensors = [_build_tensor(builder, buffer_idx, [4, 4]) for buffer_idx in range(3, 6)]
+    extra_op = _build_operator(builder, 0, [0, 1], [2])
+    extra_subgraph = _build_subgraph(
+        builder,
+        tensors=extra_tensors,
+        operators=[extra_op],
+        inputs=[0, 1],
+        outputs=[2],
+    )
+
+    operator_codes = [_build_operator_code(builder, builtin_op)]
+    buffers = [_build_buffer(builder) for _ in range(6)]
+    return _finish_tflite_model(
+        builder,
+        subgraph=main_subgraph,
+        extra_subgraphs=[extra_subgraph],
+        operator_codes=operator_codes,
+        buffers=buffers,
+    )
+
+
+def _build_stablehlo_reduce_model(reducer_name, init_value):
+    """Build a single-input STABLEHLO_REDUCE model with a binary reducer body."""
+    builder = flatbuffers.Builder(1024)
+
+    dimensions_vec = _tflite_int64_vector(
+        builder,
+        _tfl_stablehlo_reduce_opts.StablehloReduceOptionsStartDimensionsVector,
+        [1],
+    )
+    _tfl_stablehlo_reduce_opts.StablehloReduceOptionsStart(builder)
+    _tfl_stablehlo_reduce_opts.StablehloReduceOptionsAddDimensions(builder, dimensions_vec)
+    _tfl_stablehlo_reduce_opts.StablehloReduceOptionsAddBodySubgraphIndex(builder, 1)
+    reduce_opts = _tfl_stablehlo_reduce_opts.StablehloReduceOptionsEnd(builder)
+
+    reduce_builtin = _get_stablehlo_builtin_operator("STABLEHLO_REDUCE")
+    reducer_builtin = _get_stablehlo_builtin_operator(reducer_name)
+    reduce_code = _build_operator_code(builder, reduce_builtin)
+    reducer_code = _build_operator_code(builder, reducer_builtin)
+
+    main_tensors = [
+        _build_tensor(builder, 0, [2, 3]),
+        _build_tensor(builder, 1, []),
+        _build_tensor(builder, 2, [2]),
+    ]
+    reduce_op = _build_operator(
+        builder,
+        0,
+        [0, 1],
+        [2],
+        builtin_options2_type=_tfl_builtin_options2.StablehloReduceOptions,
+        builtin_options2=reduce_opts,
+    )
+    main_subgraph = _build_subgraph(
+        builder,
+        tensors=main_tensors,
+        operators=[reduce_op],
+        inputs=[0],
+        outputs=[2],
+    )
+
+    body_tensors = [_build_tensor(builder, buffer_idx, []) for buffer_idx in range(3, 6)]
+    reducer_op = _build_operator(builder, 1, [0, 1], [2])
+    body_subgraph = _build_subgraph(
+        builder,
+        tensors=body_tensors,
+        operators=[reducer_op],
+        inputs=[0, 1],
+        outputs=[2],
+    )
+
+    buffers = [
+        _build_buffer(builder),
+        _build_buffer(builder, np.array(init_value, dtype=np.float32).tobytes()),
+        _build_buffer(builder),
+        _build_buffer(builder),
+        _build_buffer(builder),
+        _build_buffer(builder),
+    ]
+    return _finish_tflite_model(
+        builder,
+        subgraph=main_subgraph,
+        extra_subgraphs=[body_subgraph],
+        operator_codes=[reduce_code, reducer_code],
+        buffers=buffers,
+    )
+
+
+def _build_stablehlo_sort_model(comparison_direction, is_stable=False):
+    """Build a single-input STABLEHLO_SORT model with a compare body."""
+    builder = flatbuffers.Builder(1024)
+
+    _tfl_stablehlo_sort_opts.StablehloSortOptionsStart(builder)
+    _tfl_stablehlo_sort_opts.StablehloSortOptionsAddDimension(builder, 1)
+    _tfl_stablehlo_sort_opts.StablehloSortOptionsAddIsStable(builder, is_stable)
+    _tfl_stablehlo_sort_opts.StablehloSortOptionsAddComparatorSubgraphIndex(builder, 1)
+    sort_opts = _tfl_stablehlo_sort_opts.StablehloSortOptionsEnd(builder)
+
+    _tfl_stablehlo_compare_opts.StablehloCompareOptionsStart(builder)
+    _tfl_stablehlo_compare_opts.StablehloCompareOptionsAddComparisonDirection(
+        builder, comparison_direction
+    )
+    compare_opts = _tfl_stablehlo_compare_opts.StablehloCompareOptionsEnd(builder)
+
+    sort_builtin = _get_stablehlo_builtin_operator("STABLEHLO_SORT")
+    compare_builtin = _get_stablehlo_builtin_operator("STABLEHLO_COMPARE")
+    sort_code = _build_operator_code(builder, sort_builtin)
+    compare_code = _build_operator_code(builder, compare_builtin)
+
+    main_tensors = [
+        _build_tensor(builder, 0, [2, 3]),
+        _build_tensor(builder, 1, [2, 3]),
+    ]
+    sort_op = _build_operator(
+        builder,
+        0,
+        [0],
+        [1],
+        builtin_options2_type=_tfl_builtin_options2.StablehloSortOptions,
+        builtin_options2=sort_opts,
+    )
+    main_subgraph = _build_subgraph(
+        builder,
+        tensors=main_tensors,
+        operators=[sort_op],
+        inputs=[0],
+        outputs=[1],
+    )
+
+    body_tensors = [
+        _build_tensor(builder, 2, []),
+        _build_tensor(builder, 3, []),
+        _build_tensor(builder, 4, [], tensor_type=_tfl_tensor_type.BOOL),
+    ]
+    compare_op = _build_operator(
+        builder,
+        1,
+        [0, 1],
+        [2],
+        builtin_options2_type=_tfl_builtin_options2.StablehloCompareOptions,
+        builtin_options2=compare_opts,
+    )
+    body_subgraph = _build_subgraph(
+        builder,
+        tensors=body_tensors,
+        operators=[compare_op],
+        inputs=[0, 1],
+        outputs=[2],
+    )
+
+    buffers = [_build_buffer(builder) for _ in range(5)]
+    return _finish_tflite_model(
+        builder,
+        subgraph=main_subgraph,
+        extra_subgraphs=[body_subgraph],
+        operator_codes=[sort_code, compare_code],
+        buffers=buffers,
+    )
+
+
+def _build_stablehlo_reduce_window_model(
+    reducer_name="STABLEHLO_MAXIMUM",
+    init_value=-np.inf,
+    base_dilations=None,
+):
+    """Build an NHWC 2D STABLEHLO_REDUCE_WINDOW model."""
+    builder = flatbuffers.Builder(1024)
+    if base_dilations is None:
+        base_dilations = [1, 1, 1, 1]
+
+    window_dimensions_vec = _tflite_int64_vector(
+        builder,
+        _tfl_stablehlo_reduce_window_opts.StablehloReduceWindowOptionsStartWindowDimensionsVector,
+        [1, 2, 2, 1],
+    )
+    window_strides_vec = _tflite_int64_vector(
+        builder,
+        _tfl_stablehlo_reduce_window_opts.StablehloReduceWindowOptionsStartWindowStridesVector,
+        [1, 2, 2, 1],
+    )
+    base_dilations_vec = _tflite_int64_vector(
+        builder,
+        _tfl_stablehlo_reduce_window_opts.StablehloReduceWindowOptionsStartBaseDilationsVector,
+        base_dilations,
+    )
+    window_dilations_vec = _tflite_int64_vector(
+        builder,
+        _tfl_stablehlo_reduce_window_opts.StablehloReduceWindowOptionsStartWindowDilationsVector,
+        [1, 1, 1, 1],
+    )
+    padding_vec = _tflite_int64_vector(
+        builder,
+        _tfl_stablehlo_reduce_window_opts.StablehloReduceWindowOptionsStartPaddingVector,
+        [0, 0, 0, 0, 0, 0, 0, 0],
+    )
+
+    _tfl_stablehlo_reduce_window_opts.StablehloReduceWindowOptionsStart(builder)
+    _tfl_stablehlo_reduce_window_opts.StablehloReduceWindowOptionsAddWindowDimensions(
+        builder, window_dimensions_vec
+    )
+    _tfl_stablehlo_reduce_window_opts.StablehloReduceWindowOptionsAddWindowStrides(
+        builder, window_strides_vec
+    )
+    _tfl_stablehlo_reduce_window_opts.StablehloReduceWindowOptionsAddBaseDilations(
+        builder, base_dilations_vec
+    )
+    _tfl_stablehlo_reduce_window_opts.StablehloReduceWindowOptionsAddWindowDilations(
+        builder, window_dilations_vec
+    )
+    _tfl_stablehlo_reduce_window_opts.StablehloReduceWindowOptionsAddPadding(builder, padding_vec)
+    _tfl_stablehlo_reduce_window_opts.StablehloReduceWindowOptionsAddBodySubgraphIndex(builder, 1)
+    reduce_window_opts = _tfl_stablehlo_reduce_window_opts.StablehloReduceWindowOptionsEnd(builder)
+
+    reduce_window_builtin = _get_stablehlo_builtin_operator("STABLEHLO_REDUCE_WINDOW")
+    reducer_builtin = _get_stablehlo_builtin_operator(reducer_name)
+    reduce_window_code = _build_operator_code(builder, reduce_window_builtin)
+    reducer_code = _build_operator_code(builder, reducer_builtin)
+
+    main_tensors = [
+        _build_tensor(builder, 0, [1, 4, 4, 1]),
+        _build_tensor(builder, 1, []),
+        _build_tensor(builder, 2, [1, 2, 2, 1]),
+    ]
+    reduce_window_op = _build_operator(
+        builder,
+        0,
+        [0, 1],
+        [2],
+        builtin_options2_type=_tfl_builtin_options2.StablehloReduceWindowOptions,
+        builtin_options2=reduce_window_opts,
+    )
+    main_subgraph = _build_subgraph(
+        builder,
+        tensors=main_tensors,
+        operators=[reduce_window_op],
+        inputs=[0],
+        outputs=[2],
+    )
+
+    body_tensors = [_build_tensor(builder, buffer_idx, []) for buffer_idx in range(3, 6)]
+    reducer_op = _build_operator(builder, 1, [0, 1], [2])
+    body_subgraph = _build_subgraph(
+        builder,
+        tensors=body_tensors,
+        operators=[reducer_op],
+        inputs=[0, 1],
+        outputs=[2],
+    )
+
+    buffers = [
+        _build_buffer(builder),
+        _build_buffer(builder, np.array(init_value, dtype=np.float32).tobytes()),
+        _build_buffer(builder),
+        _build_buffer(builder),
+        _build_buffer(builder),
+        _build_buffer(builder),
+    ]
+    return _finish_tflite_model(
+        builder,
+        subgraph=main_subgraph,
+        extra_subgraphs=[body_subgraph],
+        operator_codes=[reduce_window_code, reducer_code],
+        buffers=buffers,
+    )
+
+
+def _build_stablehlo_scatter_model(reducer_name="STABLEHLO_ADD", update_window_dims=None):
+    """Build a canonical point-update STABLEHLO_SCATTER model."""
+    builder = flatbuffers.Builder(1024)
+    if update_window_dims is None:
+        update_window_dims = []
+
+    update_window_dims_vec = _tflite_int64_vector(
+        builder,
+        _tfl_stablehlo_scatter_opts.StablehloScatterOptionsStartUpdateWindowDimsVector,
+        update_window_dims,
+    )
+    inserted_window_dims_vec = _tflite_int64_vector(
+        builder,
+        _tfl_stablehlo_scatter_opts.StablehloScatterOptionsStartInsertedWindowDimsVector,
+        [0],
+    )
+    scatter_dims_vec = _tflite_int64_vector(
+        builder,
+        _tfl_stablehlo_scatter_opts.StablehloScatterOptionsStartScatterDimsToOperandDimsVector,
+        [0],
+    )
+
+    _tfl_stablehlo_scatter_opts.StablehloScatterOptionsStart(builder)
+    _tfl_stablehlo_scatter_opts.StablehloScatterOptionsAddUpdateWindowDims(
+        builder, update_window_dims_vec
+    )
+    _tfl_stablehlo_scatter_opts.StablehloScatterOptionsAddInsertedWindowDims(
+        builder, inserted_window_dims_vec
+    )
+    _tfl_stablehlo_scatter_opts.StablehloScatterOptionsAddScatterDimsToOperandDims(
+        builder, scatter_dims_vec
+    )
+    _tfl_stablehlo_scatter_opts.StablehloScatterOptionsAddIndexVectorDim(builder, 1)
+    _tfl_stablehlo_scatter_opts.StablehloScatterOptionsAddUpdateComputationSubgraphIndex(builder, 1)
+    scatter_opts = _tfl_stablehlo_scatter_opts.StablehloScatterOptionsEnd(builder)
+
+    scatter_builtin = _get_stablehlo_builtin_operator("STABLEHLO_SCATTER")
+    reducer_builtin = _get_stablehlo_builtin_operator(reducer_name)
+    scatter_code = _build_operator_code(builder, scatter_builtin)
+    reducer_code = _build_operator_code(builder, reducer_builtin)
+
+    main_tensors = [
+        _build_tensor(builder, 0, [4]),
+        _build_tensor(builder, 1, [2, 1], tensor_type=_tfl_tensor_type.INT32),
+        _build_tensor(builder, 2, [2]),
+        _build_tensor(builder, 3, [4]),
+    ]
+    scatter_op = _build_operator(
+        builder,
+        0,
+        [0, 1, 2],
+        [3],
+        builtin_options2_type=_tfl_builtin_options2.StablehloScatterOptions,
+        builtin_options2=scatter_opts,
+    )
+    main_subgraph = _build_subgraph(
+        builder,
+        tensors=main_tensors,
+        operators=[scatter_op],
+        inputs=[0, 1, 2],
+        outputs=[3],
+    )
+
+    body_tensors = [_build_tensor(builder, buffer_idx, []) for buffer_idx in range(4, 7)]
+    reducer_op = _build_operator(builder, 1, [0, 1], [2])
+    body_subgraph = _build_subgraph(
+        builder,
+        tensors=body_tensors,
+        operators=[reducer_op],
+        inputs=[0, 1],
+        outputs=[2],
+    )
+
+    buffers = [_build_buffer(builder) for _ in range(7)]
+    return _finish_tflite_model(
+        builder,
+        subgraph=main_subgraph,
+        extra_subgraphs=[body_subgraph],
+        operator_codes=[scatter_code, reducer_code],
+        buffers=buffers,
+    )
+
+
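The SCATTER builder above encodes a point update: a length-4 operand, `int32` indices of shape `[2, 1]` with `index_vector_dim=1`, `inserted_window_dims=[0]`, `scatter_dims_to_operand_dims=[0]`, no update window dims by default, and a reducer body (ADD by default). In that layout each index row selects one operand element, and the body combines that element with one scalar update. The NumPy sketch below illustrates the semantics with a loop and checks the result against `np.add.at`, which also accumulates duplicate indices. `scatter_add_points` is a hypothetical helper, not the TVM lowering.

```python
import numpy as np


def scatter_add_points(operand, indices, updates):
    """Naive point-update scatter with an ADD update computation:
    1-D operand, indices of shape (N, 1), one scalar update per index row."""
    result = operand.copy()
    for (i,), u in zip(indices, updates):
        # The update computation combines the current value with the update.
        result[i] = result[i] + u
    return result


operand = np.zeros(4, dtype=np.float32)
indices = np.array([[1], [1]], dtype=np.int32)  # duplicate index on purpose
updates = np.array([2.0, 3.0], dtype=np.float32)
looped = scatter_add_points(operand, indices, updates)

vectorized = operand.copy()
np.add.at(vectorized, indices[:, 0], updates)  # unbuffered, so duplicates accumulate
```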
+def _build_stablehlo_composite_model(with_attributes=False, use_main_input_after_composite=False):
+    """Build a STABLEHLO_COMPOSITE model that decomposes to STABLEHLO_NEGATE."""
+    builder = flatbuffers.Builder(1024)
+
+    name = builder.CreateString("test.negate")
+    attributes = None
+    if with_attributes:
+        _tfl_stablehlo_composite_opts.StableHLOCompositeOptionsStartCompositeAttributesVector(
+            builder, 1
+        )
+        builder.PrependUint8(1)
+        attributes = builder.EndVector()
+
+    _tfl_stablehlo_composite_opts.StableHLOCompositeOptionsStart(builder)
+    _tfl_stablehlo_composite_opts.StableHLOCompositeOptionsAddName(builder, name)
+    _tfl_stablehlo_composite_opts.StableHLOCompositeOptionsAddVersion(builder, 1)
+    _tfl_stablehlo_composite_opts.StableHLOCompositeOptionsAddDecompositionSubgraphIndex(builder, 1)
+    if attributes is not None:
+        _tfl_stablehlo_composite_opts.StableHLOCompositeOptionsAddCompositeAttributes(
+            builder, attributes
+        )
+    composite_opts = _tfl_stablehlo_composite_opts.StableHLOCompositeOptionsEnd(builder)
+
+    composite_builtin = _get_stablehlo_builtin_operator("STABLEHLO_COMPOSITE")
+    negate_builtin = _get_stablehlo_builtin_operator("STABLEHLO_NEGATE")
+    add_builtin = _get_stablehlo_builtin_operator("STABLEHLO_ADD")
+    composite_code = _build_operator_code(builder, composite_builtin)
+    negate_code = _build_operator_code(builder, negate_builtin)
+    add_code = _build_operator_code(builder, add_builtin)
+
+    main_tensors = [
+        _build_tensor(builder, 0, [2, 2]),
+        _build_tensor(builder, 1, [2, 2]),
+        _build_tensor(builder, 2, [2, 2]),
+    ]
+    composite_op = _build_operator(
+        builder,
+        0,
+        [0],
+        [1],
+        builtin_options2_type=_tfl_builtin_options2.StableHLOCompositeOptions,
+        builtin_options2=composite_opts,
+    )
+    main_ops = [composite_op]
+    main_outputs = [1]
+    if use_main_input_after_composite:
+        main_ops.append(_build_operator(builder, 2, [0, 1], [2]))
+        main_outputs = [2]
+
+    main_subgraph = _build_subgraph(
+        builder,
+        tensors=main_tensors,
+        operators=main_ops,
+        inputs=[0],
+        outputs=main_outputs,
+    )
+
+    decomposition_tensors = [
+        _build_tensor(builder, 2, [2, 2]),
+        _build_tensor(builder, 3, [2, 2]),
+    ]
+    negate_op = _build_operator(builder, 1, [0], [1])
+    decomposition_subgraph = _build_subgraph(
+        builder,
+        tensors=decomposition_tensors,
+        operators=[negate_op],
+        inputs=[0],
+        outputs=[1],
+    )
+
+    buffers = [_build_buffer(builder) for _ in range(4)]
+    return _finish_tflite_model(
+        builder,
+        subgraph=main_subgraph,
+        extra_subgraphs=[decomposition_subgraph],
+        operator_codes=[composite_code, negate_code, add_code],
+        buffers=buffers,
+    )
+
+
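The REDUCE_WINDOW builder above uses window dimensions and strides `[1, 2, 2, 1]` on a `[1, 4, 4, 1]` input, zero padding, unit window dilations, base dilations that default to 1, a MAXIMUM body, and a `-inf` init value. Under those settings the op is a 2x2, stride-2 max pool over H and W, which is what the REDUCE_WINDOW test below expects. The NumPy sketch compares a naive reduce-window loop with a reshape-based max pool; both helper names are hypothetical.

```python
import numpy as np


def reduce_window_max_nhwc(x, window, strides, init_value=-np.inf):
    """Naive REDUCE_WINDOW with a MAXIMUM body; no padding and no dilation."""
    out_shape = tuple((d - w) // s + 1 for d, w, s in zip(x.shape, window, strides))
    out = np.full(out_shape, init_value, dtype=x.dtype)
    for idx in np.ndindex(*out_shape):
        # The window for output idx starts at idx * stride along every dimension.
        sl = tuple(slice(i * s, i * s + w) for i, s, w in zip(idx, strides, window))
        out[idx] = max(init_value, x[sl].max())
    return out


def max_pool2d_nhwc(x, size=2):
    """Reference max pool with stride == size, for NHWC input whose H and W divide evenly."""
    n, h, w, c = x.shape
    return x.reshape(n, h // size, size, w // size, size, c).max(axis=(2, 4))


x = np.arange(16, dtype=np.float32).reshape(1, 4, 4, 1)
rw = reduce_window_max_nhwc(x, window=[1, 2, 2, 1], strides=[1, 2, 2, 1])
pooled = max_pool2d_nhwc(x, size=2)
```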
 def _build_stablehlo_typed_binary_model(*, builtin_name, tensor_type):
     """Build a minimal TFLite StableHLO binary model with the requested tensor type."""
     builder = flatbuffers.Builder(1024)
@@ -3972,19 +4443,302 @@ def test_stablehlo_binary(builtin_name, relax_op):
     @I.ir_module
     class Expected:
         @R.function
-        def main(
-            x: R.Tensor((2, 2), dtype="float32"),
-            y: R.Tensor((2, 2), dtype="float32"),
-        ) -> R.Tensor((2, 2), dtype="float32"):
-            R.func_attr({"num_input": 2})
+        def main(
+            x: R.Tensor((2, 2), dtype="float32"),
+            y: R.Tensor((2, 2), dtype="float32"),
+        ) -> R.Tensor((2, 2), dtype="float32"):
+            R.func_attr({"num_input": 2})
+            with R.dataflow():
+                gv: R.Tensor((2, 2), dtype="float32") = relax_op(x, y)
+                R.output(gv)
+            return gv
+
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def test_stablehlo_model_with_unused_subgraph():
+    """TFLite StableHLO import ignores unused non-main subgraphs."""
+    mod = _load_model_from_buffer(_build_stablehlo_model_with_unused_subgraph())
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor((2, 2), dtype="float32"),
+            y: R.Tensor((2, 2), dtype="float32"),
+        ) -> R.Tensor((2, 2), dtype="float32"):
+            R.func_attr({"num_input": 2})
+            with R.dataflow():
+                gv: R.Tensor((2, 2), dtype="float32") = R.add(x, y)
+                R.output(gv)
+            return gv
+
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
+@pytest.mark.parametrize(
+    "reducer_name, init_value, relax_op",
+    [
+        ("STABLEHLO_ADD", 0.0, R.sum),
+        ("STABLEHLO_MAXIMUM", -np.inf, R.max),
+        ("STABLEHLO_MINIMUM", np.inf, R.min),
+        ("STABLEHLO_MULTIPLY", 1.0, R.prod),
+    ],
+)
+def test_stablehlo_reduce(reducer_name, init_value, relax_op):
+    """TFLite StableHLO REDUCE with simple binary reducer body subgraphs."""
+    mod = _load_model_from_buffer(_build_stablehlo_reduce_model(reducer_name, init_value))
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(x: R.Tensor((2, 3), dtype="float32")) -> R.Tensor((2,), dtype="float32"):
+            R.func_attr({"num_input": 1})
+            with R.dataflow():
+                gv: R.Tensor((2,), dtype="float32") = relax_op(x, axis=[1], keepdims=False)
+                R.output(gv)
+            return gv
+
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def test_stablehlo_reduce_unsupported_reducer():
+    """TFLite StableHLO REDUCE rejects unsupported body reducer ops."""
+    buf = _build_stablehlo_reduce_model("STABLEHLO_SUBTRACT", 0.0)
+    if hasattr(tflite.Model, "Model"):
+        tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0)
+    else:
+        tflite_model = tflite.Model.GetRootAsModel(buf, 0)
+
+    with pytest.raises(tvm.error.OpNotImplemented, match="reducer"):
+        from_tflite(tflite_model)
+
+
+def test_stablehlo_reduce_non_identity_init_unsupported():
+    """TFLite StableHLO REDUCE rejects init values that Relax reductions cannot express."""
+    buf = _build_stablehlo_reduce_model("STABLEHLO_ADD", 1.0)
+    if hasattr(tflite.Model, "Model"):
+        tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0)
+    else:
+        tflite_model = tflite.Model.GetRootAsModel(buf, 0)
+
+    with pytest.raises(tvm.error.OpNotImplemented, match="init value"):
+        from_tflite(tflite_model)
+
+
+@pytest.mark.parametrize(
+    "comparison_direction, descending",
+    [
+        (
+            _tfl_stablehlo_comp_dir.StablehloComparisonDirection.STABLEHLO_COMPARISON_DIRECTION_LT,
+            False,
+        ),
+        (
+            _tfl_stablehlo_comp_dir.StablehloComparisonDirection.STABLEHLO_COMPARISON_DIRECTION_GT,
+            True,
+        ),
+    ],
+)
+def test_stablehlo_sort(comparison_direction, descending):
+    """TFLite StableHLO SORT with LT/GT scalar compare body subgraphs."""
+    mod = _load_model_from_buffer(_build_stablehlo_sort_model(comparison_direction))
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(x: R.Tensor((2, 3), dtype="float32")) -> R.Tensor((2, 3), dtype="float32"):
+            R.func_attr({"num_input": 1})
+            with R.dataflow():
+                gv: R.Tensor((2, 3), dtype="float32") = R.sort(x, axis=1, descending=descending)
+                R.output(gv)
+            return gv
+
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def test_stablehlo_sort_unsupported_comparator():
+    """TFLite StableHLO SORT rejects non-ordering comparators."""
+    _DIR = _tfl_stablehlo_comp_dir.StablehloComparisonDirection
+    buf = _build_stablehlo_sort_model(_DIR.STABLEHLO_COMPARISON_DIRECTION_EQ)
+    if hasattr(tflite.Model, "Model"):
+        tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0)
+    else:
+        tflite_model = tflite.Model.GetRootAsModel(buf, 0)
+
+    with pytest.raises(tvm.error.OpNotImplemented, match="LT or GT"):
+        from_tflite(tflite_model)
+
+
+def test_stablehlo_sort_stable_unsupported():
+    """TFLite StableHLO SORT rejects stable sort until Relax exposes that contract."""
+    _DIR = _tfl_stablehlo_comp_dir.StablehloComparisonDirection
+    buf = _build_stablehlo_sort_model(_DIR.STABLEHLO_COMPARISON_DIRECTION_LT, is_stable=True)
+    if hasattr(tflite.Model, "Model"):
+        tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0)
+    else:
+        tflite_model = tflite.Model.GetRootAsModel(buf, 0)
+
+    with pytest.raises(tvm.error.OpNotImplemented, match="stable sort"):
+        from_tflite(tflite_model)
+
+
+def test_stablehlo_reduce_window_max_pool2d():
+    """TFLite StableHLO REDUCE_WINDOW max reducer lowers to NHWC max_pool2d."""
+    mod = _load_model_from_buffer(_build_stablehlo_reduce_window_model())
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor((1, 4, 4, 1), dtype="float32"),
+        ) -> R.Tensor((1, 2, 2, 1), dtype="float32"):
+            R.func_attr({"num_input": 1})
+            with R.dataflow():
+                gv: R.Tensor((1, 2, 2, 1), dtype="float32") = R.nn.max_pool2d(
+                    x,
+                    pool_size=[2, 2],
+                    strides=[2, 2],
+                    padding=[0, 0, 0, 0],
+                    dilation=[1, 1],
+                    ceil_mode=False,
+                    layout="NHWC",
+                    out_layout="NHWC",
+                )
+                R.output(gv)
+            return gv
+
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def test_stablehlo_reduce_window_unsupported_reducer():
+    """TFLite StableHLO REDUCE_WINDOW rejects non-max reducers in the pool subset."""
+    buf = _build_stablehlo_reduce_window_model(reducer_name="STABLEHLO_ADD", init_value=0.0)
+    if hasattr(tflite.Model, "Model"):
+        tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0)
+    else:
+        tflite_model = tflite.Model.GetRootAsModel(buf, 0)
+
+    with pytest.raises(tvm.error.OpNotImplemented, match="MAXIMUM"):
+        from_tflite(tflite_model)
+
+
+def test_stablehlo_reduce_window_base_dilation_unsupported():
+    """TFLite StableHLO REDUCE_WINDOW rejects base dilation in the pool subset."""
+    buf = _build_stablehlo_reduce_window_model(base_dilations=[1, 2, 1, 1])
+    if hasattr(tflite.Model, "Model"):
+        tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0)
+    else:
+        tflite_model = tflite.Model.GetRootAsModel(buf, 0)
+
+    with pytest.raises(tvm.error.OpNotImplemented, match="base dilation"):
+        from_tflite(tflite_model)
+
+
+@pytest.mark.parametrize(
+    "reducer_name, reduction",
+    [
+        ("STABLEHLO_ADD", "add"),
+        ("STABLEHLO_MAXIMUM", "max"),
+        ("STABLEHLO_MINIMUM", "min"),
+        ("STABLEHLO_MULTIPLY", "mul"),
+    ],
+)
+def test_stablehlo_scatter(reducer_name, reduction):
+    """TFLite StableHLO SCATTER point updates lower to Relax scatter_nd."""
+    mod = _load_model_from_buffer(_build_stablehlo_scatter_model(reducer_name))
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            operand: R.Tensor((4,), dtype="float32"),
+            indices: R.Tensor((2, 1), dtype="int32"),
+            updates: R.Tensor((2,), dtype="float32"),
+        ) -> R.Tensor((4,), dtype="float32"):
+            R.func_attr({"num_input": 3})
+            with R.dataflow():
+                gv: R.Tensor((4,), dtype="float32") = R.scatter_nd(
+                    operand, indices, updates, reduction=reduction
+                )
+                R.output(gv)
+            return gv
+
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def test_stablehlo_scatter_unsupported_reducer():
+    """TFLite StableHLO SCATTER rejects unsupported update computation ops."""
+    buf = _build_stablehlo_scatter_model(reducer_name="STABLEHLO_SUBTRACT")
+    if hasattr(tflite.Model, "Model"):
+        tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0)
+    else:
+        tflite_model = tflite.Model.GetRootAsModel(buf, 0)
+
+    with pytest.raises(tvm.error.OpNotImplemented, match="reducer"):
+        from_tflite(tflite_model)
+
+
+def test_stablehlo_scatter_update_window_unsupported():
+    """TFLite StableHLO SCATTER rejects slice update windows in the point subset."""
+    buf = _build_stablehlo_scatter_model(update_window_dims=[0])
+    if hasattr(tflite.Model, "Model"):
+        tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0)
+    else:
+        tflite_model = tflite.Model.GetRootAsModel(buf, 0)
+
+    with pytest.raises(tvm.error.OpNotImplemented, match="point updates"):
+        from_tflite(tflite_model)
+
+
+def test_stablehlo_composite():
+    """TFLite StableHLO COMPOSITE inlines a simple decomposition subgraph."""
+    mod = _load_model_from_buffer(_build_stablehlo_composite_model())
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(x: R.Tensor((2, 2), dtype="float32")) -> R.Tensor((2, 2), dtype="float32"):
+            R.func_attr({"num_input": 1})
+            with R.dataflow():
+                gv: R.Tensor((2, 2), dtype="float32") = R.negative(x)
+                R.output(gv)
+            return gv
+
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def test_stablehlo_composite_does_not_overwrite_main_bindings():
+    """TFLite StableHLO COMPOSITE decomposition tensor names are scoped locally."""
+    mod = _load_model_from_buffer(
+        _build_stablehlo_composite_model(use_main_input_after_composite=True)
+    )
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(x: R.Tensor((2, 2), dtype="float32")) -> R.Tensor((2, 2), dtype="float32"):
+            R.func_attr({"num_input": 1})
             with R.dataflow():
-                gv: R.Tensor((2, 2), dtype="float32") = relax_op(x, y)
+                lv: R.Tensor((2, 2), dtype="float32") = R.negative(x)
+                gv: R.Tensor((2, 2), dtype="float32") = R.add(x, lv)
                 R.output(gv)
             return gv
 
     tvm.ir.assert_structural_equal(mod, Expected)
 
 
+def test_stablehlo_composite_attributes_unsupported():
+    """TFLite StableHLO COMPOSITE rejects attributes until they are parsed."""
+    buf = _build_stablehlo_composite_model(with_attributes=True)
+    if hasattr(tflite.Model, "Model"):
+        tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0)
+    else:
+        tflite_model = tflite.Model.GetRootAsModel(buf, 0)
+
+    with pytest.raises(tvm.error.OpNotImplemented, match="composite attributes"):
+        from_tflite(tflite_model)
+
+
 @pytest.mark.parametrize(
     "builtin_name, relax_op, dtype, tensor_type",
     [
@@ -4987,6 +5741,404 @@ def test_stablehlo_dynamic_slice_out_of_bounds_unsupported():
         from_tflite(tflite_model)
 
 
+def test_stablehlo_cbrt():
+    """TFLite StableHLO CBRT uses a sign-preserving composite expression."""
+    mod = _load_model_from_buffer(
+        _build_stablehlo_model(builtin_name="STABLEHLO_CBRT", input_count=1)
+    )
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(x: R.Tensor((2, 2), dtype="float32")) -> R.Tensor((2, 2), dtype="float32"):
+            R.func_attr({"num_input": 1})
+            with R.dataflow():
+                lv: R.Tensor((2, 2), dtype="float32") = R.negative(x)
+                lv1: R.Tensor((2, 2), dtype="float32") = R.power(lv, R.const(1.0 / 3.0, "float32"))
+                lv2: R.Tensor((2, 2), dtype="bool") = R.less(x, R.const(0, "float32"))
+                lv3: R.Tensor((2, 2), dtype="float32") = R.negative(lv1)
+                lv4: R.Tensor((2, 2), dtype="float32") = R.power(x, R.const(1.0 / 3.0, "float32"))
+                gv: R.Tensor((2, 2), dtype="float32") = R.where(lv2, lv3, lv4)
+                R.output(gv)
+            return gv
+
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def test_stablehlo_remainder():
+    """TFLite StableHLO REMAINDER uses truncating remainder semantics."""
+    mod = _load_model_from_buffer(
+        _build_stablehlo_model(builtin_name="STABLEHLO_REMAINDER", input_count=2)
+    )
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor((2, 2), dtype="float32"),
+            y: R.Tensor((2, 2), dtype="float32"),
+        ) -> R.Tensor((2, 2), dtype="float32"):
+            R.func_attr({"num_input": 2})
+            with R.dataflow():
+                lv: R.Tensor((2, 2), dtype="float32") = R.divide(x, y)
+                lv1: R.Tensor((2, 2), dtype="float32") = R.trunc(lv)
+                lv2: R.Tensor((2, 2), dtype="float32") = R.multiply(y, lv1)
+                gv: R.Tensor((2, 2), dtype="float32") = R.subtract(x, lv2)
+                R.output(gv)
+            return gv
+
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def _build_stablehlo_dynamic_update_slice_model(start_vals, dynamic_starts=False):
+    """Build a minimal STABLEHLO_DYNAMIC_UPDATE_SLICE model."""
+    builder = flatbuffers.Builder(1024)
+    builtin_op = _get_stablehlo_builtin_operator("STABLEHLO_DYNAMIC_UPDATE_SLICE")
+    op_code = _build_operator_code(builder, builtin_op)
+
+    t_operand = _build_tensor(builder, 0, [3, 4])
+    t_update = _build_tensor(builder, 1, [2, 2])
+    start_tensors = [
+        _build_tensor(builder, 2 + i, [], tensor_type=_tfl_tensor_type.INT32)
+        for i in range(len(start_vals))
+    ]
+    out_idx = 2 + len(start_vals)
+    t_out = _build_tensor(builder, out_idx, [3, 4])
+    tensors = [t_operand, t_update, *start_tensors, t_out]
+
+    op_inputs = [0, 1, *range(2, out_idx)]
+    op = _build_operator(builder, 0, op_inputs, [out_idx])
+    subgraph_inputs = op_inputs if dynamic_starts else [0, 1]
+    subgraph = _build_subgraph(
+        builder,
+        tensors=tensors,
+        operators=[op],
+        inputs=subgraph_inputs,
+        outputs=[out_idx],
+    )
+    if dynamic_starts:
+        buffers = [_build_buffer(builder) for _ in range(out_idx + 1)]
+    else:
+        start_buffers = [
+            _build_buffer(builder, np.array([start], dtype=np.int32).tobytes())
+            for start in start_vals
+        ]
+        buffers = [
+            _build_buffer(builder),
+            _build_buffer(builder),
+            *start_buffers,
+            _build_buffer(builder),
+        ]
+
+    return _finish_tflite_model(
+        builder, subgraph=subgraph, operator_codes=[op_code], buffers=buffers
+    )
+
+
+def test_stablehlo_dynamic_update_slice():
+    """TFLite StableHLO DYNAMIC_UPDATE_SLICE with static starts."""
+    mod = _load_model_from_buffer(_build_stablehlo_dynamic_update_slice_model([1, 1]))
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            operand: R.Tensor((3, 4), dtype="float32"),
+            update: R.Tensor((2, 2), dtype="float32"),
+        ) -> R.Tensor((3, 4), dtype="float32"):
+            R.func_attr({"num_input": 2})
+            with R.dataflow():
+                gv: R.Tensor((3, 4), dtype="float32") = R.scatter_nd(
+                    operand,
+                    R.const([[[1, 1], [1, 2]], [[2, 1], [2, 2]]], dtype="int64"),
+                    update,
+                    reduction="update",
+                )
+                R.output(gv)
+            return gv
+
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def test_stablehlo_dynamic_update_slice_dynamic_starts_unsupported():
+    """TFLite StableHLO DYNAMIC_UPDATE_SLICE with runtime starts is unsupported."""
+    buf = _build_stablehlo_dynamic_update_slice_model([0, 0], dynamic_starts=True)
+    if hasattr(tflite.Model, "Model"):
+        tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0)
+    else:
+        tflite_model = tflite.Model.GetRootAsModel(buf, 0)
+
+    with pytest.raises(tvm.error.OpNotImplemented, match="dynamic start"):
+        from_tflite(tflite_model)
+
+
+def test_stablehlo_dynamic_update_slice_out_of_bounds_unsupported():
+    """TFLite StableHLO DYNAMIC_UPDATE_SLICE rejects out-of-bounds updates."""
+    buf = _build_stablehlo_dynamic_update_slice_model([2, 3])
+    if hasattr(tflite.Model, "Model"):
+        tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0)
+    else:
+        tflite_model = tflite.Model.GetRootAsModel(buf, 0)
+
+    with pytest.raises(tvm.error.OpNotImplemented, match="out-of-bounds"):
+        from_tflite(tflite_model)
+
+
+def _build_stablehlo_dot_general_model(lhs_contract, rhs_contract, lhs_batch=None, rhs_batch=None):
+    """Build a minimal STABLEHLO_DOT_GENERAL model."""
+    builder = flatbuffers.Builder(1024)
+    lhs_batch = [] if lhs_batch is None else lhs_batch
+    rhs_batch = [] if rhs_batch is None else rhs_batch
+
+    lhs_batch_vec = _tflite_int64_vector(
+        builder,
+        _tfl_stablehlo_dot_opts.StablehloDotGeneralOptionsStartLhsBatchingDimensionsVector,
+        lhs_batch,
+    )
+    rhs_batch_vec = _tflite_int64_vector(
+        builder,
+        _tfl_stablehlo_dot_opts.StablehloDotGeneralOptionsStartRhsBatchingDimensionsVector,
+        rhs_batch,
+    )
+    lhs_contract_vec = _tflite_int64_vector(
+        builder,
+        _tfl_stablehlo_dot_opts.StablehloDotGeneralOptionsStartLhsContractingDimensionsVector,
+        lhs_contract,
+    )
+    rhs_contract_vec = _tflite_int64_vector(
+        builder,
+        _tfl_stablehlo_dot_opts.StablehloDotGeneralOptionsStartRhsContractingDimensionsVector,
+        rhs_contract,
+    )
+
+    _tfl_stablehlo_dot_opts.StablehloDotGeneralOptionsStart(builder)
+    _tfl_stablehlo_dot_opts.StablehloDotGeneralOptionsAddLhsBatchingDimensions(
+        builder, lhs_batch_vec
+    )
+    _tfl_stablehlo_dot_opts.StablehloDotGeneralOptionsAddRhsBatchingDimensions(
+        builder, rhs_batch_vec
+    )
+    _tfl_stablehlo_dot_opts.StablehloDotGeneralOptionsAddLhsContractingDimensions(
+        builder, lhs_contract_vec
+    )
+    _tfl_stablehlo_dot_opts.StablehloDotGeneralOptionsAddRhsContractingDimensions(
+        builder, rhs_contract_vec
+    )
+    dot_opts = _tfl_stablehlo_dot_opts.StablehloDotGeneralOptionsEnd(builder)
+
+    builtin_op = _get_stablehlo_builtin_operator("STABLEHLO_DOT_GENERAL")
+    op_code = _build_operator_code(builder, builtin_op)
+    t_lhs = _build_tensor(builder, 0, [2, 3])
+    t_rhs = _build_tensor(builder, 1, [3, 4])
+    t_out = _build_tensor(builder, 2, [2, 4])
+    op = _build_operator(
+        builder,
+        0,
+        [0, 1],
+        [2],
+        builtin_options2_type=_tfl_builtin_options2.StablehloDotGeneralOptions,
+        builtin_options2=dot_opts,
+    )
+    subgraph = _build_subgraph(
+        builder,
+        tensors=[t_lhs, t_rhs, t_out],
+        operators=[op],
+        inputs=[0, 1],
+        outputs=[2],
+    )
+    buffers = [_build_buffer(builder) for _ in range(3)]
+    return _finish_tflite_model(
+        builder, subgraph=subgraph, operator_codes=[op_code], buffers=buffers
+    )
+
+
+def test_stablehlo_dot_general():
+    """TFLite StableHLO DOT_GENERAL canonical 2D matmul."""
+    mod = _load_model_from_buffer(_build_stablehlo_dot_general_model([1], [0]))
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            lhs: R.Tensor((2, 3), dtype="float32"),
+            rhs: R.Tensor((3, 4), dtype="float32"),
+        ) -> R.Tensor((2, 4), dtype="float32"):
+            R.func_attr({"num_input": 2})
+            with R.dataflow():
+                gv: R.Tensor((2, 4), dtype="float32") = R.matmul(lhs, rhs, out_dtype="void")
+                R.output(gv)
+            return gv
+
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def test_stablehlo_dot_general_noncanonical_unsupported():
+    """TFLite StableHLO DOT_GENERAL rejects non-canonical contracting dims."""
+    buf = _build_stablehlo_dot_general_model([0], [0])
+    if hasattr(tflite.Model, "Model"):
+        tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0)
+    else:
+        tflite_model = tflite.Model.GetRootAsModel(buf, 0)
+
+    with pytest.raises(tvm.error.OpNotImplemented, match="contracting"):
+        from_tflite(tflite_model)
+
+
+def _build_stablehlo_convolution_model(feature_group_count=1, input_batch_dimension=0):
+    """Build a minimal STABLEHLO_CONVOLUTION model."""
+    builder = flatbuffers.Builder(1024)
+
+    window_strides_vec = _tflite_int64_vector(
+        builder,
+        _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsStartWindowStridesVector,
+        [1, 1],
+    )
+    padding_vec = _tflite_int64_vector(
+        builder,
+        _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsStartPaddingVector,
+        [0, 0, 0, 0],
+    )
+    lhs_dilation_vec = _tflite_int64_vector(
+        builder, _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsStartLhsDilationVector, [1, 1]
+    )
+    rhs_dilation_vec = _tflite_int64_vector(
+        builder, _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsStartRhsDilationVector, [1, 1]
+    )
+    window_reversal_vec = _tflite_bool_vector(
+        builder,
+        _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsStartWindowReversalVector,
+        [False, False],
+    )
+    input_spatial_vec = _tflite_int64_vector(
+        builder,
+        _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsStartInputSpatialDimensionsVector,
+        [1, 2],
+    )
+    kernel_spatial_vec = _tflite_int64_vector(
+        builder,
+        _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsStartKernelSpatialDimensionsVector,
+        [0, 1],
+    )
+    output_spatial_vec = _tflite_int64_vector(
+        builder,
+        _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsStartOutputSpatialDimensionsVector,
+        [1, 2],
+    )
+
+    _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsStart(builder)
+    _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsAddWindowStrides(
+        builder, window_strides_vec
+    )
+    _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsAddPadding(builder, padding_vec)
+    _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsAddLhsDilation(builder, lhs_dilation_vec)
+    _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsAddRhsDilation(builder, rhs_dilation_vec)
+    _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsAddWindowReversal(
+        builder, window_reversal_vec
+    )
+    _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsAddInputBatchDimension(
+        builder, input_batch_dimension
+    )
+    _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsAddInputFeatureDimension(builder, 3)
+    _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsAddInputSpatialDimensions(
+        builder, input_spatial_vec
+    )
+    _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsAddKernelInputFeatureDimension(builder, 2)
+    _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsAddKernelOutputFeatureDimension(builder, 3)
+    _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsAddKernelSpatialDimensions(
+        builder, kernel_spatial_vec
+    )
+    _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsAddOutputBatchDimension(builder, 0)
+    _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsAddOutputFeatureDimension(builder, 3)
+    _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsAddOutputSpatialDimensions(
+        builder, output_spatial_vec
+    )
+    _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsAddFeatureGroupCount(
+        builder, feature_group_count
+    )
+    _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsAddBatchGroupCount(builder, 1)
+    conv_opts = _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsEnd(builder)
+
+    builtin_op = _get_stablehlo_builtin_operator("STABLEHLO_CONVOLUTION")
+    op_code = _build_operator_code(builder, builtin_op)
+    t_data = _build_tensor(builder, 0, [1, 5, 5, 2])
+    t_kernel = _build_tensor(builder, 1, [3, 3, 2, 4])
+    t_out = _build_tensor(builder, 2, [1, 3, 3, 4])
+    op = _build_operator(
+        builder,
+        0,
+        [0, 1],
+        [2],
+        builtin_options2_type=_tfl_builtin_options2.StablehloConvolutionOptions,
+        builtin_options2=conv_opts,
+    )
+    subgraph = _build_subgraph(
+        builder,
+        tensors=[t_data, t_kernel, t_out],
+        operators=[op],
+        inputs=[0, 1],
+        outputs=[2],
+    )
+    buffers = [_build_buffer(builder) for _ in range(3)]
+    return _finish_tflite_model(
+        builder, subgraph=subgraph, operator_codes=[op_code], buffers=buffers
+    )
+
+
+def test_stablehlo_convolution():
+    """TFLite StableHLO CONVOLUTION canonical NHWC/HWIO 2D convolution."""
+    mod = _load_model_from_buffer(_build_stablehlo_convolution_model())
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            data: R.Tensor((1, 5, 5, 2), dtype="float32"),
+            kernel: R.Tensor((3, 3, 2, 4), dtype="float32"),
+        ) -> R.Tensor((1, 3, 3, 4), dtype="float32"):
+            R.func_attr({"num_input": 2})
+            with R.dataflow():
+                gv: R.Tensor((1, 3, 3, 4), dtype="float32") = R.nn.conv2d(
+                    data,
+                    kernel,
+                    strides=[1, 1],
+                    padding=[0, 0, 0, 0],
+                    dilation=[1, 1],
+                    groups=1,
+                    data_layout="NHWC",
+                    kernel_layout="HWIO",
+                    out_layout="NHWC",
+                    out_dtype="void",
+                )
+                R.output(gv)
+            return gv
+
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def test_stablehlo_convolution_feature_group_unsupported():
+    """TFLite StableHLO CONVOLUTION rejects grouped convolution in the first subset."""
+    buf = _build_stablehlo_convolution_model(feature_group_count=2)
+    if hasattr(tflite.Model, "Model"):
+        tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0)
+    else:
+        tflite_model = tflite.Model.GetRootAsModel(buf, 0)
+
+    with pytest.raises(tvm.error.OpNotImplemented, match="feature_group_count"):
+        from_tflite(tflite_model)
+
+
+def test_stablehlo_convolution_dimension_numbers_unsupported():
+    """TFLite StableHLO CONVOLUTION rejects non-canonical dimension numbers."""
+    buf = _build_stablehlo_convolution_model(input_batch_dimension=1)
+    if hasattr(tflite.Model, "Model"):
+        tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0)
+    else:
+        tflite_model = tflite.Model.GetRootAsModel(buf, 0)
+
+    with pytest.raises(tvm.error.OpNotImplemented, match="dimension numbers"):
+        from_tflite(tflite_model)
+
+
 def _build_csr_sparsity(
     builder,
     *,

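For reviewers who want to check the expected IR numerically, the following NumPy sketch restates three lowerings described in the commit message: the sign-preserving `CBRT` formula, the truncating `REMAINDER` formula, and the `np.indices` coordinate grid that `DYNAMIC_UPDATE_SLICE` hands to `R.scatter_nd`. It is not part of the commit, and every function name in it is hypothetical. It assumes static start indices and mirrors the converter's rejection of out-of-bounds ranges with a plain `ValueError`.

```python
import numpy as np


def cbrt_reference(x):
    """Sign-preserving cube root: where(x < 0, -power(-x, 1/3), power(x, 1/3))."""
    x = np.asarray(x, dtype=np.float32)
    third = np.float32(1.0 / 3.0)
    # Like the Relax expression, both branches are evaluated; power() of a
    # negative base with a fractional exponent gives NaN, which where() discards.
    with np.errstate(invalid="ignore"):
        return np.where(x < 0, -np.power(-x, third), np.power(x, third))


def remainder_reference(x, y):
    """Truncating remainder x - y * trunc(x / y); the result's sign follows x."""
    x = np.asarray(x, dtype=np.float32)
    y = np.asarray(y, dtype=np.float32)
    return x - y * np.trunc(x / y)


def dus_indices(start, update_shape):
    """One operand coordinate per update element, as consumed by scatter_nd."""
    grid = np.indices(update_shape, dtype=np.int64)  # shape (ndim, *update_shape)
    grid = np.moveaxis(grid, 0, -1)  # shape (*update_shape, ndim)
    return grid + np.asarray(start, dtype=np.int64)  # offset by the static start


def dynamic_update_slice_reference(operand, update, start):
    """Static-start DYNAMIC_UPDATE_SLICE built on the coordinate grid above."""
    operand = np.asarray(operand)
    update = np.asarray(update)
    # Hypothetical mirror of the converter's subset: reject ranges leaving the operand.
    for s, u, d in zip(start, update.shape, operand.shape):
        if s < 0 or s + u > d:
            raise ValueError("out-of-bounds update")
    out = operand.copy()
    coords = dus_indices(start, update.shape)
    # Equivalent of scatter_nd(..., reduction="update"): overwrite each coordinate.
    out[tuple(np.moveaxis(coords, -1, 0))] = update
    return out
```

With `start=[1, 1]` and a `(2, 2)` update, `dus_indices` yields the same `[[[1, 1], [1, 2]], [[2, 1], [2, 2]]]` constant that `test_stablehlo_dynamic_update_slice` expects, and `start=[2, 3]` on a `(3, 4)` operand is rejected, matching `test_stablehlo_dynamic_update_slice_out_of_bounds_unsupported`.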