This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 23a0ea8d8b [Relax][Frontend][TFLite] Support STABLEHLO_RNG_BIT_GENERATOR (#19651)
23a0ea8d8b is described below

commit 23a0ea8d8bd8c1d2538408f37dcdd13f55940684
Author: HoYi <[email protected]>
AuthorDate: Tue Jun 2 02:50:40 2026 +0800

    [Relax][Frontend][TFLite] Support STABLEHLO_RNG_BIT_GENERATOR (#19651)
    
    ## Summary
    
    This PR adds Relax TFLite frontend support for the TFLite builtin
    `STABLEHLO_RNG_BIT_GENERATOR` operator.
    
    Unlike most StableHLO builtins, the TFLite runtime
    (`tensorflow/lite/kernels/rng_bit_generator.cc`) implements this op as a
    real, deterministic counter-based PRNG, so the importer must reproduce it
    bit-exactly rather than map it to an existing op:
    
    - one uint64 1-D `initial_state` input and two outputs — the uint64
      `output_state` and the random-bit `output` (int32 / int64 / uint32 /
      uint64);
    - `algorithm` in `{DEFAULT, PHILOX, THREEFRY}`, where `DEFAULT` resolves
      to `PHILOX`;
    - Random123 Threefry2x32 (20 rounds) and Philox4x32 (10 rounds) with the
      fixed constants from `rng_util.cc`;
    - state-length constraints: `THREEFRY` requires `u64[2]`;
      `PHILOX`/`DEFAULT` require `u64[2]` or `u64[3]`.
    
    ## Design
    
    TVM/Relax has no matching RNG primitive, so the converter generates a TIR
    kernel that mirrors the runtime and emits it through `relax.call_tir` with
    two outputs. The kernel:
    
    - reinterprets the uint64 state as uint32 words and advances a 64-bit
      block counter (`final counter = initial_state[1] + num_blocks`);
    - runs the selected algorithm per block with all round state materialized
      into local buffers, which keeps the generated IR linear instead of an
      exponentially nested expression tree;
    - packs the produced uint32 words back into the output dtype and writes
      the updated state (key unchanged, counter advanced, Philox `u64[3]` tail
      passed through) — the only state behaviour the runtime relies on.
    
    The kernel is an `s_tir` PrimFunc wrapped in a single opaque structured
    block so that it remains a well-formed block-structured function for the
    Relax pipeline (e.g. `HasReshapePattern`). `get_tensor_type_str` and the
    input `_decode_type` map are extended with uint32/uint64 so that the
    uint64 state imports correctly.
    
    Unsupported inputs raise a precise `OpNotImplemented` (wrong number of
    inputs/outputs, non-uint64 or non-1-D initial state, non-uint64 or
    mismatched-shape output state, unsupported output dtype, unknown
    algorithm, per-algorithm state-length violations).
    
    ## Operator Support
    
    | Operator | TFLite options | Relax lowering | Supported subset |
    |---|---|---|---|
    | `STABLEHLO_RNG_BIT_GENERATOR` | `StablehloRngBitGeneratorOptions.Algorithm()` from `BuiltinOptions2` | `call_tir` to a generated bit-exact TIR kernel | THREEFRY (`u64[2]`) and PHILOX/DEFAULT (`u64[2]`/`u64[3]`); int32/int64/uint32/uint64 output |
    
    ## Tests
    
    Tests build minimal RNG flatbuffers, compile, and execute them, comparing
    the output and updated state against the verbatim expected vectors from
    the TFLite runtime kernel test (`rng_bit_generator_test.cc`).
    
    | Test | Coverage |
    |---|---|
    | `test_stablehlo_rng_bit_generator_threefry` | THREEFRY bit-exact, all 4 output dtypes |
    | `test_stablehlo_rng_bit_generator_philox` | PHILOX bit-exact, all 4 output dtypes |
    | `test_stablehlo_rng_bit_generator_default_matches_philox` | DEFAULT resolves to PHILOX |
    | `test_stablehlo_rng_bit_generator_deterministic` | run-to-run bit-identical output |
    | `test_stablehlo_rng_bit_generator_constant_state` | constant (non-input) uint64 state, THREEFRY bit-exact |
    | `test_stablehlo_rng_bit_generator_unsupported_output_dtype` | output dtype guard |
    | `test_stablehlo_rng_bit_generator_threefry_invalid_state_unsupported` | THREEFRY `u64[2]` state guard |
    | `test_stablehlo_rng_bit_generator_non_uint64_state_unsupported` | uint64 state guard |
    
    Local validation:
    
    ```bash
    python -m ruff check \
      python/tvm/relax/frontend/tflite/tflite_frontend.py \
      tests/python/relax/test_frontend_tflite.py
    
    python -m pytest \
      tests/python/relax/test_frontend_tflite.py \
      -k rng_bit_generator -q
    
    python -m pytest \
      tests/python/relax/test_frontend_tflite.py \
      -k stablehlo -q
    ```
    
    Result:
    
    ```text
    ruff check: All checks passed
    rng_bit_generator tests: 13 passed
    stablehlo tests: 96 passed
    ```
    
    ## References
    
    - Issue #19519 item I: remaining StableHLO operators in TFLite
    - `tensorflow/lite/kernels/rng_bit_generator.cc`, `rng_util.cc`,
      `rng_bit_generator_test.cc`
---
 .../tvm/relax/frontend/tflite/tflite_frontend.py   | 231 ++++++++++++++++++++
 tests/python/relax/test_frontend_tflite.py         | 236 +++++++++++++++++++++
 2 files changed, 467 insertions(+)

diff --git a/python/tvm/relax/frontend/tflite/tflite_frontend.py 
b/python/tvm/relax/frontend/tflite/tflite_frontend.py
index fc3d61713d..bf90895cfc 100644
--- a/python/tvm/relax/frontend/tflite/tflite_frontend.py
+++ b/python/tvm/relax/frontend/tflite/tflite_frontend.py
@@ -376,6 +376,7 @@ class OperatorConverter:
             "STABLEHLO_REDUCE": self._convert_stablehlo_reduce,
             "STABLEHLO_REDUCE_WINDOW": self._convert_stablehlo_reduce_window,
             "STABLEHLO_REMAINDER": self._convert_stablehlo_remainder,
+            "STABLEHLO_RNG_BIT_GENERATOR": 
self._convert_stablehlo_rng_bit_generator,
             "STABLEHLO_RSQRT": 
functools.partial(self._convert_stablehlo_unary, relax_op=_op.rsqrt),
             "STABLEHLO_SCATTER": self._convert_stablehlo_scatter,
             "STABLEHLO_SELECT": functools.partial(
@@ -1001,6 +1002,8 @@ class OperatorConverter:
             TensorType.FLOAT32: np.float32,
             TensorType.INT32: np.int32,
             TensorType.INT64: np.int64,
+            TensorType.UINT32: np.uint32,
+            TensorType.UINT64: np.uint64,
             TensorType.BOOL: np.bool_,
         }[tensor_wrapper.tensor.Type()]
 
@@ -1041,6 +1044,10 @@ class OperatorConverter:
             return "int32"
         if tensor_type == TensorType.INT64:
             return "int64"
+        if tensor_type == TensorType.UINT32:
+            return "uint32"
+        if tensor_type == TensorType.UINT64:
+            return "uint64"
         if tensor_type == TensorType.BOOL:
             return "bool"
         raise NotImplementedError(f"Tensor type {tensor_type!s} is currently 
not supported")
@@ -2289,6 +2296,72 @@ class OperatorConverter:
         target = call_target_name or "<empty>"
         raise tvm.error.OpNotImplemented(f"STABLEHLO_CUSTOM_CALL target 
{target} is not supported")
 
+    def _convert_stablehlo_rng_bit_generator(self, op):
+        """Convert STABLEHLO_RNG_BIT_GENERATOR to a bit-exact call_tir 
kernel."""
+        from tflite.RngAlgorithm import RngAlgorithm
+        from tflite.StablehloRngBitGeneratorOptions import 
StablehloRngBitGeneratorOptions
+
+        op_name = "STABLEHLO_RNG_BIT_GENERATOR"
+        input_tensors = self.get_input_tensors(op)
+        output_tensors = self.get_output_tensors(op)
+        if len(input_tensors) != 1 or len(output_tensors) != 2:
+            raise tvm.error.OpNotImplemented(f"{op_name} expects one input and 
two outputs")
+
+        opts = self._get_stablehlo_options(op, StablehloRngBitGeneratorOptions)
+        algorithm_enum = opts.Algorithm()
+        # DEFAULT resolves to PHILOX in the TFLite runtime kernel.
+        if algorithm_enum == RngAlgorithm.THREEFRY:
+            algorithm = "threefry"
+        elif algorithm_enum in (RngAlgorithm.PHILOX, RngAlgorithm.DEFAULT):
+            algorithm = "philox"
+        else:
+            raise tvm.error.OpNotImplemented(
+                f"{op_name} algorithm {algorithm_enum} is not supported"
+            )
+
+        state_tensor = input_tensors[0]
+        if self.get_tensor_type_str(state_tensor.tensor.Type()) != "uint64":
+            raise tvm.error.OpNotImplemented(f"{op_name} requires a uint64 
initial state")
+        state_shape = self._get_static_tensor_shape(state_tensor, op_name)
+        if len(state_shape) != 1:
+            raise tvm.error.OpNotImplemented(f"{op_name} requires a 1-D 
initial state")
+        state_len = int(state_shape[0])
+        # State-length constraints mirror the TFLite runtime kernel.
+        if algorithm == "threefry" and state_len != 2:
+            raise tvm.error.OpNotImplemented(f"{op_name} THREEFRY requires a 
u64[2] state")
+        if algorithm == "philox" and state_len not in (2, 3):
+            raise tvm.error.OpNotImplemented(f"{op_name} PHILOX requires a 
u64[2] or u64[3] state")
+
+        out_state_tensor, out_tensor = output_tensors
+        if self.get_tensor_type_str(out_state_tensor.tensor.Type()) != 
"uint64":
+            raise tvm.error.OpNotImplemented(f"{op_name} output state must be 
uint64")
+        out_state_shape = self._get_static_tensor_shape(out_state_tensor, 
op_name)
+        if list(out_state_shape) != list(state_shape):
+            raise tvm.error.OpNotImplemented(
+                f"{op_name} output state shape must match the initial state"
+            )
+        out_dtype = self.get_tensor_type_str(out_tensor.tensor.Type())
+        if out_dtype not in ("int32", "int64", "uint32", "uint64"):
+            raise tvm.error.OpNotImplemented(f"{op_name} output dtype 
{out_dtype} is not supported")
+        out_shape = tuple(self._get_static_tensor_shape(out_tensor, op_name))
+
+        prim_func = _build_stablehlo_rng_bit_generator_primfunc(
+            algorithm, state_len, out_dtype, out_shape
+        )
+        module_builder = self.conversion_state["module_builder"]
+        func_name = 
f"tflite_stablehlo_rng_{algorithm}_{out_state_tensor.tensor_idx}"
+        gv = module_builder.add_func(prim_func, func_name)
+        state_expr = self.get_tensor_expr(state_tensor)
+        call = relax.call_tir(
+            gv,
+            [state_expr],
+            [
+                relax.TensorStructInfo(tuple(state_shape), "uint64"),
+                relax.TensorStructInfo(out_shape, out_dtype),
+            ],
+        )
+        return self.bb.normalize(call)
+
     def _convert_stablehlo_while(self, op):
         """Convert STABLEHLO_WHILE to a recursive Relax private function."""
         from tflite.StablehloWhileOptions import StablehloWhileOptions
@@ -7430,6 +7503,162 @@ class OperatorConverter:
         )
 
 
+# Constants for the Random123 counter-based PRNGs used by 
STABLEHLO_RNG_BIT_GENERATOR,
+# matching tensorflow/lite/kernels/rng_util.cc.
+_STABLEHLO_RNG_THREEFRY_PARITY = 0x1BD11BDA
+_STABLEHLO_RNG_PHILOX_MUL_A = 0xD2511F53
+_STABLEHLO_RNG_PHILOX_MUL_B = 0xCD9E8D57
+_STABLEHLO_RNG_PHILOX_WEYL_A = 0x9E3779B9
+_STABLEHLO_RNG_PHILOX_WEYL_B = 0xBB67AE85
+
+
+def _build_stablehlo_rng_bit_generator_primfunc(algorithm, state_len, 
out_dtype, out_shape):
+    """Build a bit-exact TIR kernel for STABLEHLO_RNG_BIT_GENERATOR.
+
+    Mirrors the TFLite runtime kernel 
(tensorflow/lite/kernels/rng_bit_generator.cc),
+    implementing the Random123 Threefry2x32 (20 rounds) and Philox4x32 (10 
rounds)
+    counter-based PRNGs. The kernel reinterprets the uint64 state as uint32 
words,
+    advances a 64-bit block counter, and packs the generated words into the 
output
+    tensor. The updated state keeps the key unchanged and only advances the 
counter,
+    which is the only behaviour the runtime relies on.
+    """
+    from tvm.script.parser import tirx as T
+
+    total = 1
+    for dim in out_shape:
+        total *= int(dim)
+    is_64bit = out_dtype in ("int64", "uint64")
+    block_words = 2 if algorithm == "threefry" else 4
+    out_word_count = total * (2 if is_64bit else 1)
+    num_blocks = (out_word_count + block_words - 1) // block_words
+    writes_per_block = block_words // (2 if is_64bit else 1)
+    parity = _STABLEHLO_RNG_THREEFRY_PARITY
+    mul_a, mul_b = _STABLEHLO_RNG_PHILOX_MUL_A, _STABLEHLO_RNG_PHILOX_MUL_B
+    weyl_a, weyl_b = _STABLEHLO_RNG_PHILOX_WEYL_A, _STABLEHLO_RNG_PHILOX_WEYL_B
+
+    def _u32(value):
+        return T.Cast("uint32", value)
+
+    def _u64(value):
+        return T.Cast("uint64", value)
+
+    def _store_value(words, write_index):
+        # Pack the generated uint32 words into one output element, 
reinterpreting
+        # the bit pattern into the (possibly signed) output dtype.
+        if is_64bit:
+            low = _u64(words[2 * write_index])
+            high = _u64(words[2 * write_index + 1])
+            return T.reinterpret(out_dtype, low | (high << T.uint64(32)))
+        return T.reinterpret(out_dtype, words[write_index])
+
+    if algorithm == "threefry":
+
+        @T.prim_func(private=True, s_tir=True)
+        def kernel(
+            initial_state: T.Buffer((state_len,), "uint64"),
+            output_state: T.Buffer((state_len,), "uint64"),
+            output: T.Buffer(out_shape, out_dtype),
+        ):
+            # A single opaque structured block keeps the imperative kernel as a
+            # well-formed block-structured PrimFunc, as required by the Relax
+            # pipeline (e.g. HasReshapePattern).
+            with T.sblock("rng_bit_generator"):
+                state_key = initial_state[0]
+                state_counter = initial_state[1]
+                key_0 = _u32(state_key & T.uint64(0xFFFFFFFF))
+                key_1 = _u32(state_key >> T.uint64(32))
+                output_state[0] = state_key
+                output_state[1] = state_counter + T.uint64(num_blocks)
+                out_flat = T.decl_buffer((total,), out_dtype, data=output.data)
+                keys = T.decl_buffer((3,), "uint32", scope="local")
+                rotations = T.decl_buffer((8,), "uint32", scope="local")
+                ctr = T.decl_buffer((2,), "uint32", scope="local")
+                keys[0] = key_0
+                keys[1] = key_1
+                keys[2] = key_0 ^ key_1 ^ T.uint32(parity)
+                rotations[0] = T.uint32(13)
+                rotations[1] = T.uint32(15)
+                rotations[2] = T.uint32(26)
+                rotations[3] = T.uint32(6)
+                rotations[4] = T.uint32(17)
+                rotations[5] = T.uint32(29)
+                rotations[6] = T.uint32(16)
+                rotations[7] = T.uint32(24)
+                for block in T.serial(num_blocks):
+                    counter = state_counter + _u64(block)
+                    ctr[0] = _u32(counter & T.uint64(0xFFFFFFFF)) + key_0
+                    ctr[1] = _u32(counter >> T.uint64(32)) + key_1
+                    for group in T.serial(5):
+                        for step in T.serial(4):
+                            rot = rotations[(group * 4 + step) % 8]
+                            ctr[0] = ctr[0] + ctr[1]
+                            ctr[1] = (ctr[1] << rot) | (ctr[1] >> 
(T.uint32(32) - rot))
+                            ctr[1] = ctr[1] ^ ctr[0]
+                        ctr[0] = ctr[0] + keys[(group + 1) % 3]
+                        ctr[1] = ctr[1] + keys[(group + 2) % 3] + _u32(group + 
1)
+                    for write_index in T.serial(writes_per_block):
+                        element = block * writes_per_block + write_index
+                        if element < total:
+                            out_flat[element] = _store_value(ctr, write_index)
+
+        return kernel
+
+    @T.prim_func(private=True, s_tir=True)
+    def kernel(
+        initial_state: T.Buffer((state_len,), "uint64"),
+        output_state: T.Buffer((state_len,), "uint64"),
+        output: T.Buffer(out_shape, out_dtype),
+    ):
+        with T.sblock("rng_bit_generator"):
+            state_key = initial_state[0]
+            state_counter = initial_state[1]
+            key_0 = _u32(state_key & T.uint64(0xFFFFFFFF))
+            key_1 = _u32(state_key >> T.uint64(32))
+            output_state[0] = state_key
+            output_state[1] = state_counter + T.uint64(num_blocks)
+            out_flat = T.decl_buffer((total,), out_dtype, data=output.data)
+            ctr = T.decl_buffer((4,), "uint32", scope="local")
+            keys = T.decl_buffer((2,), "uint32", scope="local")
+            high_ctr = T.decl_buffer((2,), "uint32", scope="local")
+            if state_len == 3:
+                # PHILOX u64[3]: the third state word feeds the high counter 
and
+                # is passed through to the output state unchanged.
+                high_state = initial_state[2]
+                output_state[2] = high_state
+                high_ctr[0] = _u32(high_state & T.uint64(0xFFFFFFFF))
+                high_ctr[1] = _u32(high_state >> T.uint64(32))
+            else:
+                high_ctr[0] = key_0
+                high_ctr[1] = key_1
+            for block in T.serial(num_blocks):
+                counter = state_counter + _u64(block)
+                ctr[0] = _u32(counter & T.uint64(0xFFFFFFFF))
+                ctr[1] = _u32(counter >> T.uint64(32))
+                ctr[2] = high_ctr[0]
+                ctr[3] = high_ctr[1]
+                keys[0] = key_0
+                keys[1] = key_1
+                for _round in T.serial(10):
+                    prod_0 = T.uint64(mul_a) * _u64(ctr[0])
+                    prod_1 = T.uint64(mul_b) * _u64(ctr[2])
+                    new_0 = _u32(prod_1 >> T.uint64(32)) ^ ctr[1] ^ keys[0]
+                    new_1 = _u32(prod_1 & T.uint64(0xFFFFFFFF))
+                    new_2 = _u32(prod_0 >> T.uint64(32)) ^ ctr[3] ^ keys[1]
+                    new_3 = _u32(prod_0 & T.uint64(0xFFFFFFFF))
+                    ctr[0] = new_0
+                    ctr[1] = new_1
+                    ctr[2] = new_2
+                    ctr[3] = new_3
+                    keys[0] = keys[0] + T.uint32(weyl_a)
+                    keys[1] = keys[1] + T.uint32(weyl_b)
+                for write_index in T.serial(writes_per_block):
+                    element = block * writes_per_block + write_index
+                    if element < total:
+                        out_flat[element] = _store_value(ctr, write_index)
+
+    return kernel
+
+
 # pylint: disable=no-else-return
 def prepare_dense_matrix_from_sparse(sparse_tensor, sparse_tensor_value, 
sparse_tensor_type):
     """Prepare sparse indices and dense matrix from TFLite sparse 
parameters."""
@@ -7676,6 +7905,8 @@ def _decode_type(n):
         7: "int16",
         8: "complex64",
         9: "int8",
+        12: "uint64",
+        15: "uint32",
     }
     return _tflite_m[n]
 
diff --git a/tests/python/relax/test_frontend_tflite.py 
b/tests/python/relax/test_frontend_tflite.py
index c34da605de..e4866d7096 100644
--- a/tests/python/relax/test_frontend_tflite.py
+++ b/tests/python/relax/test_frontend_tflite.py
@@ -3697,6 +3697,7 @@ _tfl_stablehlo_reduce_window_opts = 
_get_tflite_schema_module("StablehloReduceWi
 _tfl_stablehlo_scatter_opts = 
_get_tflite_schema_module("StablehloScatterOptions")
 _tfl_stablehlo_sort_opts = _get_tflite_schema_module("StablehloSortOptions")
 _tfl_stablehlo_while_opts = _get_tflite_schema_module("StablehloWhileOptions")
+_tfl_stablehlo_rng_opts = 
_get_tflite_schema_module("StablehloRngBitGeneratorOptions")
 _tfl_call_options = _get_tflite_schema_module("CallOptions")
 _tfl_call_once_options = _get_tflite_schema_module("CallOnceOptions")
 _tfl_dimension_metadata = _get_tflite_schema_module("DimensionMetadata")
@@ -3721,6 +3722,7 @@ _tfl_fc_weights_format = 
_get_tflite_schema_enum("FullyConnectedOptionsWeightsFo
 _tfl_padding = _get_tflite_schema_enum("Padding")
 _tfl_sparse_index_vector = _get_tflite_schema_enum("SparseIndexVector")
 _tfl_tensor_type = _get_tflite_schema_enum("TensorType")
+_tfl_rng_algorithm = _get_tflite_schema_enum("RngAlgorithm")
 
 _tfl_lstm_options = _get_tflite_schema_module("LSTMOptions")
 _tfl_sequence_rnn_options = _get_tflite_schema_module("SequenceRNNOptions")
@@ -7015,6 +7017,240 @@ def 
test_stablehlo_options_missing_payload_unsupported():
         _load_model_from_buffer(buf)
 
 
+def _build_stablehlo_rng_model(algorithm, state_len, out_shape, 
out_tensor_type, const_state=None):
+    """Build a STABLEHLO_RNG_BIT_GENERATOR model.
+
+    When ``const_state`` is provided, the uint64 initial state is embedded as a
+    constant tensor (no graph input); otherwise it is a graph input.
+    """
+    builder = flatbuffers.Builder(1024)
+
+    _tfl_stablehlo_rng_opts.StablehloRngBitGeneratorOptionsStart(builder)
+    
_tfl_stablehlo_rng_opts.StablehloRngBitGeneratorOptionsAddAlgorithm(builder, 
algorithm)
+    rng_opts = 
_tfl_stablehlo_rng_opts.StablehloRngBitGeneratorOptionsEnd(builder)
+
+    rng_builtin = 
_get_stablehlo_builtin_operator("STABLEHLO_RNG_BIT_GENERATOR")
+    rng_code = _build_operator_code(builder, rng_builtin)
+
+    main_tensors = [
+        _build_tensor(builder, 0, [state_len], 
tensor_type=_tfl_tensor_type.UINT64),
+        _build_tensor(builder, 1, [state_len], 
tensor_type=_tfl_tensor_type.UINT64),
+        _build_tensor(builder, 2, list(out_shape), 
tensor_type=out_tensor_type),
+    ]
+    rng_op = _build_operator(
+        builder,
+        0,
+        [0],
+        [1, 2],
+        
builtin_options2_type=_tfl_builtin_options2.StablehloRngBitGeneratorOptions,
+        builtin_options2=rng_opts,
+    )
+    main_subgraph = _build_subgraph(
+        builder,
+        tensors=main_tensors,
+        operators=[rng_op],
+        inputs=[] if const_state is not None else [0],
+        outputs=[1, 2],
+    )
+
+    state_data = None
+    if const_state is not None:
+        state_data = np.array(const_state, dtype="uint64").tobytes()
+    buffers = [
+        _build_buffer(builder, data=state_data),
+        _build_buffer(builder),
+        _build_buffer(builder),
+    ]
+    return _finish_tflite_model(
+        builder,
+        subgraph=main_subgraph,
+        operator_codes=[rng_code],
+        buffers=buffers,
+    )
+
+
+def _run_stablehlo_rng_model(algorithm, state_len, out_shape, out_tensor_type, 
init_state):
+    """Import, compile, and execute an RNG model, returning (output_state, 
output)."""
+    buf = _build_stablehlo_rng_model(algorithm, state_len, out_shape, 
out_tensor_type)
+    mod = _load_model_from_buffer(buf)
+    ex = tvm.compile(mod, tvm.target.Target("llvm"))
+    vm = relax.VirtualMachine(ex, tvm.cpu())
+    result = vm["main"](tvm.runtime.tensor(np.array(init_state, 
dtype="uint64")))
+    return result[0].numpy(), result[1].numpy()
+
+
+# Expected vectors are taken verbatim from the TFLite runtime kernel test
+# (tensorflow/lite/kernels/rng_bit_generator_test.cc), guaranteeing bit-exact 
parity.
+_RNG_THREEFRY_EXPECTED = {
+    "int32": [43444564, -2144348869, -315321645, -549236733, 1672743891, 
-54463903],
+    "uint32": [43444564, 2150618427, 3979645651, 3745730563, 1672743891, 
4240503393],
+    "int64": [
+        -9209908263526143660,
+        -2358953802017238317,
+        -233920680524772397,
+        2658481902456610144,
+        -2022031683723149139,
+        -2324041912354448873,
+    ],
+    "uint64": [
+        9236835810183407956,
+        16087790271692313299,
+        18212823393184779219,
+        2658481902456610144,
+        16424712389986402477,
+        16122702161355102743,
+    ],
+}
+_RNG_THREEFRY_STATE = {"int32": [1, 5], "uint32": [1, 5], "int64": [1, 8], 
"uint64": [1, 8]}
+_RNG_PHILOX_EXPECTED = {
+    "int32": [-263854262, 1366700262, 495645701, -1243243882, 89414891, 
1917262711],
+    "uint32": [4031113034, 1366700262, 495645701, 3051723414, 89414891, 
1917262711],
+    "int64": [
+        5869932932755744586,
+        -5339691813646437371,
+        8234580641674714347,
+        2641225993340350124,
+        1962472297844690804,
+        -3580856229565614135,
+    ],
+    "uint64": [
+        5869932932755744586,
+        13107052260063114245,
+        8234580641674714347,
+        2641225993340350124,
+        1962472297844690804,
+        14865887844143937481,
+    ],
+}
+_RNG_PHILOX_STATE = {
+    "int32": [1, 4, 3],
+    "uint32": [1, 4, 3],
+    "int64": [1, 5, 3],
+    "uint64": [1, 5, 3],
+}
+
+
[email protected](
+    "out_dtype,out_tensor_type",
+    [
+        ("int32", _tfl_tensor_type.INT32),
+        ("uint32", _tfl_tensor_type.UINT32),
+        ("int64", _tfl_tensor_type.INT64),
+        ("uint64", _tfl_tensor_type.UINT64),
+    ],
+)
+def test_stablehlo_rng_bit_generator_threefry(out_dtype, out_tensor_type):
+    """TFLite STABLEHLO_RNG_BIT_GENERATOR THREEFRY matches the runtime kernel 
bit-exactly."""
+    state, output = _run_stablehlo_rng_model(
+        _tfl_rng_algorithm.THREEFRY, 2, [2, 3], out_tensor_type, [1, 2]
+    )
+    assert output.flatten().tolist() == _RNG_THREEFRY_EXPECTED[out_dtype]
+    assert state.tolist() == _RNG_THREEFRY_STATE[out_dtype]
+
+
[email protected](
+    "out_dtype,out_tensor_type",
+    [
+        ("int32", _tfl_tensor_type.INT32),
+        ("uint32", _tfl_tensor_type.UINT32),
+        ("int64", _tfl_tensor_type.INT64),
+        ("uint64", _tfl_tensor_type.UINT64),
+    ],
+)
+def test_stablehlo_rng_bit_generator_philox(out_dtype, out_tensor_type):
+    """TFLite STABLEHLO_RNG_BIT_GENERATOR PHILOX matches the runtime kernel 
bit-exactly."""
+    state, output = _run_stablehlo_rng_model(
+        _tfl_rng_algorithm.PHILOX, 3, [2, 3], out_tensor_type, [1, 2, 3]
+    )
+    assert output.flatten().tolist() == _RNG_PHILOX_EXPECTED[out_dtype]
+    assert state.tolist() == _RNG_PHILOX_STATE[out_dtype]
+
+
+def test_stablehlo_rng_bit_generator_default_matches_philox():
+    """TFLite STABLEHLO_RNG_BIT_GENERATOR DEFAULT resolves to the PHILOX 
algorithm."""
+    state, output = _run_stablehlo_rng_model(
+        _tfl_rng_algorithm.DEFAULT, 3, [2, 3], _tfl_tensor_type.INT32, [1, 2, 
3]
+    )
+    assert output.flatten().tolist() == _RNG_PHILOX_EXPECTED["int32"]
+    assert state.tolist() == _RNG_PHILOX_STATE["int32"]
+
+
+def test_stablehlo_rng_bit_generator_deterministic():
+    """Re-running the imported RNG kernel yields identical bit-exact output."""
+    buf = _build_stablehlo_rng_model(_tfl_rng_algorithm.PHILOX, 3, [3, 3], 
_tfl_tensor_type.INT32)
+    mod = _load_model_from_buffer(buf)
+    ex = tvm.compile(mod, tvm.target.Target("llvm"))
+    vm = relax.VirtualMachine(ex, tvm.cpu())
+    init = tvm.runtime.tensor(np.array([7, 8, 9], dtype="uint64"))
+    first = vm["main"](init)
+    second = vm["main"](init)
+    np.testing.assert_equal(first[1].numpy(), second[1].numpy())
+    np.testing.assert_equal(first[0].numpy(), second[0].numpy())
+
+
+def test_stablehlo_rng_bit_generator_constant_state():
+    """A constant uint64 initial state imports and stays bit-exact (no graph 
input)."""
+    buf = _build_stablehlo_rng_model(
+        _tfl_rng_algorithm.THREEFRY, 2, [2, 3], _tfl_tensor_type.INT32, 
const_state=[1, 2]
+    )
+    mod = _load_model_from_buffer(buf)
+    assert len(mod["main"].params) == 0
+    ex = tvm.compile(mod, tvm.target.Target("llvm"))
+    vm = relax.VirtualMachine(ex, tvm.cpu())
+    result = vm["main"]()
+    assert result[1].numpy().flatten().tolist() == 
_RNG_THREEFRY_EXPECTED["int32"]
+    assert result[0].numpy().tolist() == _RNG_THREEFRY_STATE["int32"]
+
+
+def test_stablehlo_rng_bit_generator_unsupported_output_dtype():
+    """TFLite STABLEHLO_RNG_BIT_GENERATOR rejects non-integer output dtypes."""
+    buf = _build_stablehlo_rng_model(_tfl_rng_algorithm.PHILOX, 3, [2, 3], 
_tfl_tensor_type.FLOAT32)
+    with pytest.raises(tvm.error.OpNotImplemented, match="output dtype float32 
is not supported"):
+        _load_model_from_buffer(buf)
+
+
+def test_stablehlo_rng_bit_generator_threefry_invalid_state_unsupported():
+    """TFLite STABLEHLO_RNG_BIT_GENERATOR rejects a u64[3] state for 
THREEFRY."""
+    buf = _build_stablehlo_rng_model(_tfl_rng_algorithm.THREEFRY, 3, [2, 3], 
_tfl_tensor_type.INT32)
+    with pytest.raises(tvm.error.OpNotImplemented, match="THREEFRY requires a 
u64.2. state"):
+        _load_model_from_buffer(buf)
+
+
+def test_stablehlo_rng_bit_generator_non_uint64_state_unsupported():
+    """TFLite STABLEHLO_RNG_BIT_GENERATOR rejects a non-uint64 initial 
state."""
+    builder = flatbuffers.Builder(1024)
+    _tfl_stablehlo_rng_opts.StablehloRngBitGeneratorOptionsStart(builder)
+    _tfl_stablehlo_rng_opts.StablehloRngBitGeneratorOptionsAddAlgorithm(
+        builder, _tfl_rng_algorithm.PHILOX
+    )
+    rng_opts = 
_tfl_stablehlo_rng_opts.StablehloRngBitGeneratorOptionsEnd(builder)
+    rng_code = _build_operator_code(
+        builder, _get_stablehlo_builtin_operator("STABLEHLO_RNG_BIT_GENERATOR")
+    )
+    tensors = [
+        _build_tensor(builder, 0, [2], tensor_type=_tfl_tensor_type.INT64),
+        _build_tensor(builder, 1, [2], tensor_type=_tfl_tensor_type.INT64),
+        _build_tensor(builder, 2, [2, 3], tensor_type=_tfl_tensor_type.INT32),
+    ]
+    rng_op = _build_operator(
+        builder,
+        0,
+        [0],
+        [1, 2],
+        
builtin_options2_type=_tfl_builtin_options2.StablehloRngBitGeneratorOptions,
+        builtin_options2=rng_opts,
+    )
+    subgraph = _build_subgraph(
+        builder, tensors=tensors, operators=[rng_op], inputs=[0], outputs=[1, 
2]
+    )
+    buffers = [_build_buffer(builder) for _ in range(3)]
+    buf = _finish_tflite_model(
+        builder, subgraph=subgraph, operator_codes=[rng_code], buffers=buffers
+    )
+    with pytest.raises(tvm.error.OpNotImplemented, match="requires a uint64 
initial state"):
+        _load_model_from_buffer(buf)
+
+
 def test_stablehlo_while():
     """TFLite STABLEHLO_WHILE lowers to a recursive Relax private function."""
     mod = _load_model_from_buffer(_build_stablehlo_while_model())

Reply via email to