This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 4a688ddcbc [Relax][Frontend][TFLite] Add EMBEDDING_LOOKUP_SPARSE converter (#19652)
4a688ddcbc is described below

commit 4a688ddcbcb6c51d52b5458ee00b70925785155f
Author: YinHanke <[email protected]>
AuthorDate: Wed Jun 3 01:50:30 2026 +0800

    [Relax][Frontend][TFLite] Add EMBEDDING_LOOKUP_SPARSE converter (#19652)
    
    ## Summary
    
    Add Relax TFLite frontend support for `EMBEDDING_LOOKUP_SPARSE`.
    
    This PR adds a converter for `EMBEDDING_LOOKUP_SPARSE` in the Relax
    TFLite frontend. The implementation supports the `SUM`, `MEAN`, and
    `SQRTN` combiners and handles higher-rank sparse indices. The sparse
    aggregation is lowered through `scatter_nd` to match TFLite operator
    semantics for the supported cases.
    
    The PR also adds handcrafted TFLite frontend tests covering:
    - `SUM`
    - `MEAN`
    - `SQRTN`
    - a 3D indices case
    
    ## Testing
    
    Ran `tests/python/relax/test_frontend_tflite.py -k
    'embedding_lookup_sparse'`.
    
    Part of #19519
    
    ---------
    
    Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
---
 .../tvm/relax/frontend/tflite/tflite_frontend.py   | 118 ++++++++++++
 tests/python/relax/test_frontend_tflite.py         | 213 +++++++++++++++++++++
 2 files changed, 331 insertions(+)

diff --git a/python/tvm/relax/frontend/tflite/tflite_frontend.py b/python/tvm/relax/frontend/tflite/tflite_frontend.py
index bf90895cfc..67d57e5866 100644
--- a/python/tvm/relax/frontend/tflite/tflite_frontend.py
+++ b/python/tvm/relax/frontend/tflite/tflite_frontend.py
@@ -224,6 +224,7 @@ class OperatorConverter:
             "DIV": functools.partial(self._convert_elemwise, relax_op=_op.divide),
             "ELU": self.convert_elu,
             "EMBEDDING_LOOKUP": self.convert_embedding_lookup,
+            "EMBEDDING_LOOKUP_SPARSE": self.convert_embedding_lookup_sparse,
             "EQUAL": functools.partial(
                 self._convert_elemwise, relax_op=_op.equal, comparison_op=True
             ),
@@ -6339,6 +6340,123 @@ class OperatorConverter:
             indices = self.get_tensor_expr(indices_tensor)
         return relax.op.take(params, indices, axis=0)
 
+    def convert_embedding_lookup_sparse(self, op):
+        """Convert TFLite EMBEDDING_LOOKUP_SPARSE."""
+        from tflite.CombinerType import CombinerType
+        from tflite.EmbeddingLookupSparseOptions import EmbeddingLookupSparseOptions
+        from tflite.TensorType import TensorType
+
+        input_tensors = self.get_input_tensors(op)
+        assert len(input_tensors) == 5, "EMBEDDING_LOOKUP_SPARSE should have 5 input tensors"
+        output_tensors = self.get_output_tensors(op)
+        assert len(output_tensors) == 1, "EMBEDDING_LOOKUP_SPARSE should have 1 output tensor"
+
+        ids_tensor, indices_tensor, dense_shape_tensor, weights_tensor, params_tensor = (
+            input_tensors
+        )
+        output_tensor = output_tensors[0]
+
+        for tensor in input_tensors:
+            assert not tensor.qnn_params, "Quantized input is not expected."
+
+        assert ids_tensor.tensor.Type() == TensorType.INT32
+        assert indices_tensor.tensor.Type() == TensorType.INT32
+        assert dense_shape_tensor.tensor.Type() == TensorType.INT32
+        assert weights_tensor.tensor.Type() == TensorType.FLOAT32
+        assert params_tensor.tensor.Type() == TensorType.FLOAT32
+        assert output_tensor.tensor.Type() == TensorType.FLOAT32
+
+        ids_shape = to_int_list(self.get_tensor_shape(ids_tensor))
+        indices_shape = to_int_list(self.get_tensor_shape(indices_tensor))
+        dense_shape_shape = to_int_list(self.get_tensor_shape(dense_shape_tensor))
+        weights_shape = to_int_list(self.get_tensor_shape(weights_tensor))
+        params_shape = to_int_list(self.get_tensor_shape(params_tensor))
+
+        assert len(ids_shape) == 1, "EMBEDDING_LOOKUP_SPARSE ids must be rank 1"
+        assert len(indices_shape) == 2, "EMBEDDING_LOOKUP_SPARSE indices must be rank 2"
+        assert len(dense_shape_shape) == 1, "EMBEDDING_LOOKUP_SPARSE dense_shape must be rank 1"
+        assert len(weights_shape) == 1, "EMBEDDING_LOOKUP_SPARSE weights must be rank 1"
+        assert len(params_shape) >= 2, "EMBEDDING_LOOKUP_SPARSE params must be rank >= 2"
+        assert indices_shape[0] == ids_shape[0], (
+            "EMBEDDING_LOOKUP_SPARSE ids and indices must agree on lookup count"
+        )
+        assert weights_shape[0] == ids_shape[0], (
+            "EMBEDDING_LOOKUP_SPARSE ids and weights must agree on lookup count"
+        )
+
+        if self.has_expr(dense_shape_tensor.tensor_idx):
+            raise tvm.error.OpNotImplemented(
+                "TFLite EMBEDDING_LOOKUP_SPARSE with runtime dense_shape is not supported."
+            )
+
+        dense_shape = to_int_list(self.get_tensor_value(dense_shape_tensor))
+        lookup_rank = indices_shape[1]
+        assert len(dense_shape) == lookup_rank, (
+            "EMBEDDING_LOOKUP_SPARSE dense_shape length must match indices width"
+        )
+        assert lookup_rank >= 1, "EMBEDDING_LOOKUP_SPARSE indices width must be positive"
+        if not self.has_expr(ids_tensor.tensor_idx):
+            ids_value = self.get_tensor_value(ids_tensor)
+            if np.any(ids_value < 0):
+                raise tvm.error.OpNotImplemented(
+                    "TFLite EMBEDDING_LOOKUP_SPARSE with negative ids is not supported."
+                )
+
+        params = self.get_tensor_expr(params_tensor)
+        ids = self.get_tensor_expr(ids_tensor)
+        weights = self.get_tensor_expr(weights_tensor)
+        indices = self.get_tensor_expr(indices_tensor)
+
+        ids = relax.op.astype(ids, "int32")
+        lookup = relax.op.take(params, ids, axis=0)
+
+        embedding_tail_shape = params_shape[1:]
+        output_prefix_shape = dense_shape[:-1]
+        output_shape = output_prefix_shape + embedding_tail_shape
+
+        # Aggregation buckets are defined by every sparse index dimension except the last one.
+        bucket_indices = relax.op.strided_slice(indices, axes=[1], begin=[0], end=[lookup_rank - 1])
+
+        weight_expand_shape = [ids_shape[0]] + [1] * len(embedding_tail_shape)
+        weighted_lookup = relax.op.multiply(lookup, relax.op.reshape(weights, weight_expand_shape))
+
+        value_base = relax.const(np.zeros(output_shape, dtype=np.float32), "float32")
+        summed_lookup = relax.op.scatter_nd(value_base, bucket_indices, weighted_lookup, "add")
+
+        op_options = op.BuiltinOptions()
+        sparse_options = EmbeddingLookupSparseOptions()
+        sparse_options.Init(op_options.Bytes, op_options.Pos)
+        combiner = sparse_options.Combiner()
+        if combiner == CombinerType.SUM:
+            return summed_lookup
+
+        count_shape = output_prefix_shape
+        count_base = relax.const(np.zeros(count_shape, dtype=np.float32), "float32")
+        bucket_count_updates = relax.const(np.ones(ids_shape, dtype=np.float32), "float32")
+        bucket_counts = relax.op.scatter_nd(count_base, bucket_indices, bucket_count_updates, "add")
+        if combiner == CombinerType.MEAN:
+            denominator_updates = weights
+        elif combiner == CombinerType.SQRTN:
+            denominator_updates = relax.op.multiply(weights, weights)
+        else:
+            raise tvm.error.OpNotImplemented(
+                f"Unsupported TFLite EMBEDDING_LOOKUP_SPARSE combiner value {combiner}"
+            )
+
+        denominator = relax.op.scatter_nd(count_base, bucket_indices, denominator_updates, "add")
+        if combiner == CombinerType.SQRTN:
+            denominator = relax.op.sqrt(denominator)
+
+        broadcast_shape = count_shape + [1] * len(embedding_tail_shape)
+        denominator = relax.op.reshape(denominator, broadcast_shape)
+        denominator = relax.op.broadcast_to(denominator, output_shape)
+        normalized = relax.op.divide(summed_lookup, denominator)
+        bucket_counts = relax.op.reshape(bucket_counts, broadcast_shape)
+        bucket_counts = relax.op.broadcast_to(bucket_counts, output_shape)
+        return relax.op.where(
+            relax.op.greater(bucket_counts, relax.const(0.0, "float32")), normalized, value_base
+        )
+
     def convert_batch_matmul(self, op):
         """batch_matmul implementation."""
 
diff --git a/tests/python/relax/test_frontend_tflite.py b/tests/python/relax/test_frontend_tflite.py
index e4866d7096..e4483b9d41 100644
--- a/tests/python/relax/test_frontend_tflite.py
+++ b/tests/python/relax/test_frontend_tflite.py
@@ -4039,6 +4039,17 @@ def _build_hashtable_options(
     return hashtable_options.HashtableOptionsEnd(builder)
 
 
+def _build_embedding_lookup_sparse_options(builder, combiner):
+    try:
+        sparse_options = _get_tflite_schema_module("EmbeddingLookupSparseOptions")
+    except ModuleNotFoundError:
+        pytest.skip("TFLite schema does not provide EmbeddingLookupSparseOptions")
+
+    sparse_options.EmbeddingLookupSparseOptionsStart(builder)
+    sparse_options.EmbeddingLookupSparseOptionsAddCombiner(builder, combiner)
+    return sparse_options.EmbeddingLookupSparseOptionsEnd(builder)
+
+
 def _load_model_from_buffer(model_bytes):
     if hasattr(tflite.Model, "Model"):
         tflite_model = tflite.Model.Model.GetRootAsModel(model_bytes, 0)
@@ -4067,6 +4078,10 @@ def _run_module(mod, *inputs):
     return tuple(output.numpy() for output in outputs)
 
 
+def _run_no_input_module(mod):
+    return _run_module(mod)
+
+
 def _build_tflite_call_model(
     call_subgraph_index=1,
     callee_inputs=None,
@@ -5858,6 +5873,88 @@ def _build_tflite_hashtable_size_uninitialized_model():
     )
 
 
+def _build_tflite_embedding_lookup_sparse_model(
+    combiner, indices_data, dense_shape_data, weights_data=None
+):
+    builder = flatbuffers.Builder(4096)
+
+    ids_data = np.array([1, 3, 0], dtype=np.int32)
+    indices_data = np.array(indices_data, dtype=np.int32)
+    dense_shape_data = np.array(dense_shape_data, dtype=np.int32)
+    weights_data = (
+        np.array([1.0, 2.0, 4.0], dtype=np.float32)
+        if weights_data is None
+        else np.array(weights_data, dtype=np.float32)
+    )
+    params_data = np.array(
+        [
+            [[0.00, 0.01], [0.10, 0.11], [0.20, 0.21]],
+            [[1.00, 1.01], [1.10, 1.11], [1.20, 1.21]],
+            [[2.00, 2.01], [2.10, 2.11], [2.20, 2.21]],
+            [[3.00, 3.01], [3.10, 3.11], [3.20, 3.21]],
+        ],
+        dtype=np.float32,
+    )
+
+    output_shape = dense_shape_data[:-1].tolist() + list(params_data.shape[1:])
+    sparse_options = _build_embedding_lookup_sparse_options(builder, combiner)
+
+    ids_tensor = _build_tensor(builder, 0, list(ids_data.shape), tensor_type=_tfl_tensor_type.INT32)
+    indices_tensor = _build_tensor(
+        builder, 1, list(indices_data.shape), tensor_type=_tfl_tensor_type.INT32
+    )
+    dense_shape_tensor = _build_tensor(
+        builder, 2, list(dense_shape_data.shape), tensor_type=_tfl_tensor_type.INT32
+    )
+    weights_tensor = _build_tensor(
+        builder, 3, list(weights_data.shape), tensor_type=_tfl_tensor_type.FLOAT32
+    )
+    params_tensor = _build_tensor(
+        builder, 4, list(params_data.shape), tensor_type=_tfl_tensor_type.FLOAT32
+    )
+    output_tensor = _build_tensor(builder, 5, output_shape, tensor_type=_tfl_tensor_type.FLOAT32)
+
+    sparse_op = _build_operator(
+        builder,
+        0,
+        [0, 1, 2, 3, 4],
+        [5],
+        builtin_options_type=_get_builtin_options_type("EmbeddingLookupSparseOptions"),
+        builtin_options=sparse_options,
+    )
+    subgraph = _build_subgraph(
+        builder,
+        tensors=[
+            ids_tensor,
+            indices_tensor,
+            dense_shape_tensor,
+            weights_tensor,
+            params_tensor,
+            output_tensor,
+        ],
+        operators=[sparse_op],
+        inputs=[],
+        outputs=[5],
+    )
+    operator_codes = [
+        _build_operator_code(builder, _get_builtin_operator("EMBEDDING_LOOKUP_SPARSE"))
+    ]
+    buffers = [
+        _build_buffer(builder, ids_data.tobytes()),
+        _build_buffer(builder, indices_data.tobytes()),
+        _build_buffer(builder, dense_shape_data.tobytes()),
+        _build_buffer(builder, weights_data.tobytes()),
+        _build_buffer(builder, params_data.tobytes()),
+        _build_buffer(builder),
+    ]
+    return _finish_tflite_model(
+        builder,
+        subgraph=subgraph,
+        operator_codes=operator_codes,
+        buffers=buffers,
+    )
+
+
 def _build_tflite_hashtable_lookup_model(*, value_shape, value_type=None):
     """Build a model containing one HASHTABLE_LOOKUP operator."""
     builder = flatbuffers.Builder(1024)
@@ -5952,6 +6049,122 @@ def test_hashtable_size_uninitialized_unsupported():
         _load_model_from_buffer(_build_tflite_hashtable_size_uninitialized_model())
 
 
+def test_embedding_lookup_sparse_sum():
+    from tflite.CombinerType import CombinerType
+
+    mod = _load_model_from_buffer(
+        _build_tflite_embedding_lookup_sparse_model(
+            CombinerType.SUM,
+            indices_data=[[0, 0], [2, 0], [2, 1]],
+            dense_shape_data=[3, 2],
+        )
+    )
+
+    out = _run_no_input_module(mod)
+    expected = np.array(
+        [
+            [[1.00, 1.01], [1.10, 1.11], [1.20, 1.21]],
+            [[0.00, 0.00], [0.00, 0.00], [0.00, 0.00]],
+            [[6.00, 6.06], [6.60, 6.66], [7.20, 7.26]],
+        ],
+        dtype=np.float32,
+    )
+    np.testing.assert_allclose(out, expected, rtol=1e-5, atol=1e-5)
+
+
+def test_embedding_lookup_sparse_mean():
+    from tflite.CombinerType import CombinerType
+
+    mod = _load_model_from_buffer(
+        _build_tflite_embedding_lookup_sparse_model(
+            CombinerType.MEAN,
+            indices_data=[[0, 0], [2, 0], [2, 1]],
+            dense_shape_data=[3, 2],
+        )
+    )
+
+    out = _run_no_input_module(mod)
+    expected = np.array(
+        [
+            [[1.00, 1.01], [1.10, 1.11], [1.20, 1.21]],
+            [[0.00, 0.00], [0.00, 0.00], [0.00, 0.00]],
+            [[1.00, 1.01], [1.10, 1.11], [1.20, 1.21]],
+        ],
+        dtype=np.float32,
+    )
+    np.testing.assert_allclose(out, expected, rtol=1e-5, atol=1e-5)
+
+
+def test_embedding_lookup_sparse_mean_negative_weights():
+    from tflite.CombinerType import CombinerType
+
+    mod = _load_model_from_buffer(
+        _build_tflite_embedding_lookup_sparse_model(
+            CombinerType.MEAN,
+            indices_data=[[0, 0], [0, 1], [2, 0]],
+            dense_shape_data=[3, 2],
+            weights_data=[1.0, -2.0, 0.0],
+        )
+    )
+
+    (output,) = (_run_no_input_module(mod),)
+    expected = np.array(
+        [
+            [[5.0, 5.01], [5.1, 5.11], [5.2, 5.21]],
+            [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]],
+            [[np.nan, np.nan], [np.nan, np.nan], [np.nan, np.nan]],
+        ],
+        dtype=np.float32,
+    )
+    np.testing.assert_allclose(output, expected, rtol=1e-5, atol=1e-5, equal_nan=True)
+
+
+def test_embedding_lookup_sparse_sqrtn():
+    from tflite.CombinerType import CombinerType
+
+    mod = _load_model_from_buffer(
+        _build_tflite_embedding_lookup_sparse_model(
+            CombinerType.SQRTN,
+            indices_data=[[0, 0], [2, 0], [2, 1]],
+            dense_shape_data=[3, 2],
+        )
+    )
+
+    out = _run_no_input_module(mod)
+    scale = np.sqrt(20.0).astype("float32")
+    expected = np.array(
+        [
+            [[1.00, 1.01], [1.10, 1.11], [1.20, 1.21]],
+            [[0.00, 0.00], [0.00, 0.00], [0.00, 0.00]],
+            [
+                [6.00 / scale, 6.06 / scale],
+                [6.60 / scale, 6.66 / scale],
+                [7.20 / scale, 7.26 / scale],
+            ],
+        ],
+        dtype=np.float32,
+    )
+    np.testing.assert_allclose(out, expected, rtol=1e-5, atol=1e-5)
+
+
+def test_embedding_lookup_sparse_indices_3d():
+    from tflite.CombinerType import CombinerType
+
+    mod = _load_model_from_buffer(
+        _build_tflite_embedding_lookup_sparse_model(
+            CombinerType.SUM,
+            indices_data=[[0, 0, 0], [2, 0, 0], [2, 0, 1]],
+            dense_shape_data=[3, 2, 2],
+        )
+    )
+
+    out = _run_no_input_module(mod)
+    expected = np.zeros((3, 2, 3, 2), dtype=np.float32)
+    expected[0, 0] = np.array([[1.00, 1.01], [1.10, 1.11], [1.20, 1.21]], dtype=np.float32)
+    expected[2, 0] = np.array([[6.00, 6.06], [6.60, 6.66], [7.20, 7.26]], dtype=np.float32)
+    np.testing.assert_allclose(out, expected, rtol=1e-5, atol=1e-5)
+
+
 def test_hashtable_lookup_1d_value():
     mod = _load_model_from_buffer(_build_tflite_hashtable_lookup_model(value_shape=[3]))
 

Reply via email to