This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 066bf777b8 [Relax][Frontend][TFLite] Add HASHTABLE_LOOKUP converter 
(#19654)
066bf777b8 is described below

commit 066bf777b841af6ad791b86747cb934ce8c8b09f
Author: YinHanke <[email protected]>
AuthorDate: Tue Jun 2 02:41:27 2026 +0800

    [Relax][Frontend][TFLite] Add HASHTABLE_LOOKUP converter (#19654)
    
    ## Summary
    
    Add Relax TFLite frontend support for `HASHTABLE_LOOKUP`.
    
    This PR adds a converter for `HASHTABLE_LOOKUP` in the Relax TFLite
    frontend. The implementation supports non-string value tensors and
    lowers the lookup through `bucketize`, `take`, and `where` so that
    missing keys return zero-filled values together with a `uint8` hits mask
    matching TFLite semantics for the supported cases.
    
    The PR also adds handcrafted TFLite frontend tests covering:
    - 1D float value tensors
    - 2D float value tensors
    - the current unsupported string-value case
    
    ## Testing
    
    Ran `tests/python/relax/test_frontend_tflite.py -k 'hashtable_lookup'`.
    
    Part of #19519
---
 .../tvm/relax/frontend/tflite/tflite_frontend.py   | 83 ++++++++++++++++++++
 tests/python/relax/test_frontend_tflite.py         | 89 ++++++++++++++++++++++
 2 files changed, 172 insertions(+)

diff --git a/python/tvm/relax/frontend/tflite/tflite_frontend.py 
b/python/tvm/relax/frontend/tflite/tflite_frontend.py
index 2a4455eb30..fc3d61713d 100644
--- a/python/tvm/relax/frontend/tflite/tflite_frontend.py
+++ b/python/tvm/relax/frontend/tflite/tflite_frontend.py
@@ -248,6 +248,7 @@ class OperatorConverter:
             "HASHTABLE": self.convert_hashtable,
             "HASHTABLE_FIND": self.convert_hashtable_find,
             "HASHTABLE_IMPORT": self.convert_hashtable_import,
+            "HASHTABLE_LOOKUP": self.convert_hashtable_lookup,
             "HASHTABLE_SIZE": self.convert_hashtable_size,
             "IF": self.convert_if,
             "L2_NORMALIZATION": self.convert_l2_normalization,
@@ -755,6 +756,88 @@ class OperatorConverter:
             "HASHTABLE_FIND requires TensorType.STRING support in Relax TFLite 
frontend"
         )
 
+    def convert_hashtable_lookup(self, op):
+        """Convert TFLite HASHTABLE_LOOKUP for non-string value tensors."""
+        from tflite.TensorType import TensorType
+
+        input_tensors = self.get_input_tensors(op)
+        output_tensors = self.get_output_tensors(op)
+        if len(input_tensors) != 3 or len(output_tensors) != 2:
+            raise tvm.error.OpNotImplemented(
+                "HASHTABLE_LOOKUP expects lookup, key, and value inputs with 
two outputs"
+            )
+
+        lookup_tensor, key_tensor, value_tensor = input_tensors
+        output_tensor, hits_tensor = output_tensors
+
+        if (
+            lookup_tensor.tensor.Type() != TensorType.INT32
+            or key_tensor.tensor.Type() != TensorType.INT32
+        ):
+            raise tvm.error.OpNotImplemented(
+                "HASHTABLE_LOOKUP requires int32 lookup and key tensors"
+            )
+        if self._is_tflite_string_type(value_tensor.tensor.Type()):
+            raise tvm.error.OpNotImplemented(
+                "HASHTABLE_LOOKUP with TensorType.STRING values is not 
supported"
+            )
+        if value_tensor.tensor.Type() != output_tensor.tensor.Type():
+            raise tvm.error.OpNotImplemented(
+                "HASHTABLE_LOOKUP output dtype must match the value tensor 
dtype"
+            )
+        if hits_tensor.tensor.Type() != TensorType.UINT8:
+            raise tvm.error.OpNotImplemented("HASHTABLE_LOOKUP hits output 
must be uint8")
+
+        lookup_shape = to_int_list(self.get_tensor_shape(lookup_tensor))
+        key_shape = to_int_list(self.get_tensor_shape(key_tensor))
+        value_shape = to_int_list(self.get_tensor_shape(value_tensor))
+        output_shape = to_int_list(self.get_tensor_shape(output_tensor))
+        hits_shape = to_int_list(self.get_tensor_shape(hits_tensor))
+
+        if len(lookup_shape) != 1 or len(key_shape) != 1 or len(value_shape) < 
1:
+            raise tvm.error.OpNotImplemented(
+                "HASHTABLE_LOOKUP requires rank-1 lookup/key and rank>=1 value 
tensors"
+            )
+        if key_shape[0] != value_shape[0]:
+            raise tvm.error.OpNotImplemented(
+                "HASHTABLE_LOOKUP requires key and value tensors to agree on 
row count"
+            )
+        if key_shape[0] == 0:
+            raise tvm.error.OpNotImplemented(
+                "HASHTABLE_LOOKUP requires a non-empty key/value table"
+            )
+        if output_shape != [lookup_shape[0]] + value_shape[1:]:
+            raise tvm.error.OpNotImplemented(
+                "HASHTABLE_LOOKUP output shape must match lookup count and 
value tail shape"
+            )
+        if hits_shape != [lookup_shape[0]]:
+            raise tvm.error.OpNotImplemented(
+                "HASHTABLE_LOOKUP hits output shape must match lookup count"
+            )
+
+        lookup = self.get_tensor_expr(lookup_tensor)
+        key = self.get_tensor_expr(key_tensor)
+        value = self.get_tensor_expr(value_tensor)
+
+        positions = relax.op.bucketize(lookup, key, out_int32=True, 
right=False)
+        candidate_keys = relax.op.take(key, positions, axis=0, mode="clip")
+        in_range = relax.op.less(positions, relax.const(key_shape[0], "int32"))
+        found = relax.op.logical_and(in_range, relax.op.equal(candidate_keys, 
lookup))
+
+        gathered_values = relax.op.take(value, positions, axis=0, mode="clip")
+        output_dtype = self.get_tensor_type_str(output_tensor.tensor.Type())
+        zero_values = relax.op.zeros(output_shape, output_dtype)
+
+        if len(value_shape) > 1:
+            found_values = relax.op.expand_dims(found, axis=list(range(1, 
len(value_shape))))
+            found_values = relax.op.broadcast_to(found_values, output_shape)
+        else:
+            found_values = found
+
+        output = relax.op.where(found_values, gathered_values, zero_values)
+        hits = relax.op.astype(found, "uint8")
+        return relax.Tuple([output, hits])
+
     def convert_hashtable_size(self, op):
         """Convert HASHTABLE_SIZE for a statically imported TFLite 
hashtable."""
         input_tensors = self.get_input_tensors(op)
diff --git a/tests/python/relax/test_frontend_tflite.py 
b/tests/python/relax/test_frontend_tflite.py
index 7c3e526d99..c34da605de 100644
--- a/tests/python/relax/test_frontend_tflite.py
+++ b/tests/python/relax/test_frontend_tflite.py
@@ -4053,6 +4053,18 @@ def _get_builtin_operator(builtin_name):
     return getattr(_tfl_builtin_operator, builtin_name)
 
 
+def _run_module(mod, *inputs):
+    tgt = tvm.target.Target("c")
+    ex = tvm.compile(mod, tgt)
+    vm = relax.VirtualMachine(ex, tvm.cpu())
+    vm.set_input("main", *inputs)
+    vm.invoke_stateful("main")
+    outputs = vm.get_outputs("main")
+    if hasattr(outputs, "numpy"):
+        return outputs.numpy()
+    return tuple(output.numpy() for output in outputs)
+
+
 def _build_tflite_call_model(
     call_subgraph_index=1,
     callee_inputs=None,
@@ -5844,6 +5856,36 @@ def _build_tflite_hashtable_size_uninitialized_model():
     )
 
 
+def _build_tflite_hashtable_lookup_model(*, value_shape, value_type=None):
+    """Build a model containing one HASHTABLE_LOOKUP operator."""
+    builder = flatbuffers.Builder(1024)
+
+    value_type = _tfl_tensor_type.FLOAT32 if value_type is None else value_type
+
+    lookup_tensor = _build_tensor(builder, 0, [4], 
tensor_type=_tfl_tensor_type.INT32)
+    key_tensor = _build_tensor(builder, 1, [3], 
tensor_type=_tfl_tensor_type.INT32)
+    value_tensor = _build_tensor(builder, 2, value_shape, 
tensor_type=value_type)
+    output_tensor = _build_tensor(builder, 3, [4, *value_shape[1:]], 
tensor_type=value_type)
+    hits_tensor = _build_tensor(builder, 4, [4], 
tensor_type=_tfl_tensor_type.UINT8)
+
+    hashtable_lookup = _build_operator(builder, 0, [0, 1, 2], [3, 4])
+    main_subgraph = _build_subgraph(
+        builder,
+        tensors=[lookup_tensor, key_tensor, value_tensor, output_tensor, 
hits_tensor],
+        operators=[hashtable_lookup],
+        inputs=[0, 1, 2],
+        outputs=[3, 4],
+    )
+    operator_codes = [_build_operator_code(builder, 
_get_builtin_operator("HASHTABLE_LOOKUP"))]
+    buffers = [_build_buffer(builder) for _ in range(5)]
+    return _finish_tflite_model(
+        builder,
+        subgraph=main_subgraph,
+        operator_codes=operator_codes,
+        buffers=buffers,
+    )
+
+
 def test_resource_variable_call_once_init_read():
     """Test reading a resource variable initialized by a supported CALL_ONCE 
subgraph."""
     mod = _load_model_from_buffer(_build_tflite_resource_variable_model())
@@ -5908,6 +5950,53 @@ def test_hashtable_size_uninitialized_unsupported():
         
_load_model_from_buffer(_build_tflite_hashtable_size_uninitialized_model())
 
 
+def test_hashtable_lookup_1d_value():
+    mod = 
_load_model_from_buffer(_build_tflite_hashtable_lookup_model(value_shape=[3]))
+
+    output, hits = _run_module(
+        mod,
+        np.array([1234, -292, -11, 0], dtype=np.int32),
+        np.array([-11, 0, 1234], dtype=np.int32),
+        np.array([0.0, 0.1, 0.4], dtype=np.float32),
+    )
+
+    np.testing.assert_allclose(output, np.array([0.4, 0.0, 0.0, 0.1], 
dtype=np.float32))
+    np.testing.assert_array_equal(hits, np.array([1, 0, 1, 1], dtype=np.uint8))
+
+
+def test_hashtable_lookup_2d_value():
+    mod = 
_load_model_from_buffer(_build_tflite_hashtable_lookup_model(value_shape=[3, 
2]))
+
+    output, hits = _run_module(
+        mod,
+        np.array([1234, -292, -11, 0], dtype=np.int32),
+        np.array([-11, 0, 1234], dtype=np.int32),
+        np.array([[0.0, 0.1], [1.0, 1.1], [2.0, 2.1]], dtype=np.float32),
+    )
+
+    np.testing.assert_allclose(
+        output,
+        np.array(
+            [
+                [2.0, 2.1],
+                [0.0, 0.0],
+                [0.0, 0.1],
+                [1.0, 1.1],
+            ],
+            dtype=np.float32,
+        ),
+    )
+    np.testing.assert_array_equal(hits, np.array([1, 0, 1, 1], dtype=np.uint8))
+
+
+def test_hashtable_lookup_string_value_unsupported():
+    string_type = _get_string_tensor_type()
+    with pytest.raises(ValueError, match="unknown dtype `string`"):
+        _load_model_from_buffer(
+            _build_tflite_hashtable_lookup_model(value_shape=[3], 
value_type=string_type)
+        )
+
+
 def _get_stablehlo_builtin_operator(builtin_name):
     if not hasattr(_tfl_builtin_operator, builtin_name):
         pytest.skip(f"TFLite schema does not provide 
BuiltinOperator.{builtin_name}")

Reply via email to