This is an automated email from the ASF dual-hosted git repository.
kaxil pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new eee0b99826e Fix `LlamaIndexEmbeddingOperator` returning `vector=None`
for every chunk (#68491)
eee0b99826e is described below
commit eee0b99826e8857e0a339245454478217d6bcca3
Author: Kaxil Naik <[email protected]>
AuthorDate: Mon Jun 15 23:15:01 2026 +0100
Fix `LlamaIndexEmbeddingOperator` returning `vector=None` for every chunk
(#68491)
VectorStoreIndex attaches embeddings to model_copy() copies of the
nodes it is given, never to the originals, so reading node.embedding
after index construction always returns None. Embed the original
nodes explicitly, using the same content VectorStoreIndex embeds
(MetadataMode.EMBED), and build the index only when persist_dir is
set; embed_nodes() inside the index skips nodes that already have an
embedding, so persisting does not call the embedding API again.
---
.../ai/docs/operators/llamaindex_embedding.rst | 12 +-
.../common/ai/operators/llamaindex_embedding.py | 47 +++++---
.../ai/operators/test_llamaindex_embedding.py | 130 ++++++++++++++++++---
3 files changed, 151 insertions(+), 38 deletions(-)
diff --git a/providers/common/ai/docs/operators/llamaindex_embedding.rst
b/providers/common/ai/docs/operators/llamaindex_embedding.rst
index 99125ac74bd..894045b684c 100644
--- a/providers/common/ai/docs/operators/llamaindex_embedding.rst
+++ b/providers/common/ai/docs/operators/llamaindex_embedding.rst
@@ -25,10 +25,10 @@ LlamaIndex. Designed to feed the output of
:class:`~airflow.providers.common.ai.operators.document_loader.DocumentLoaderOperator`
into vector storage (pgvector, Pinecone, Weaviate, ...).
-The operator passes the embedding model **directly** to
-``VectorStoreIndex(..., embed_model=...)`` -- it does not mutate
-LlamaIndex's global ``Settings`` singleton, so concurrent tasks in the same
-worker process don't race on shared model state.
+The operator calls the embedding model **directly** (and passes it to
+``VectorStoreIndex(..., embed_model=...)`` when persisting) -- it does not
+mutate LlamaIndex's global ``Settings`` singleton, so concurrent tasks in the
+same worker process don't race on shared model state.
Basic usage
-----------
@@ -117,3 +117,7 @@ Returns a dict with::
...
],
}
+
+``vector`` is computed over the chunk's metadata-enriched content
+(LlamaIndex's ``MetadataMode.EMBED``, the same content ``VectorStoreIndex``
+embeds), while ``text`` is the raw chunk text without metadata.
diff --git
a/providers/common/ai/src/airflow/providers/common/ai/operators/llamaindex_embedding.py
b/providers/common/ai/src/airflow/providers/common/ai/operators/llamaindex_embedding.py
index d85e6921002..34cd441ef05 100644
---
a/providers/common/ai/src/airflow/providers/common/ai/operators/llamaindex_embedding.py
+++
b/providers/common/ai/src/airflow/providers/common/ai/operators/llamaindex_embedding.py
@@ -44,10 +44,10 @@ class LlamaIndexEmbeddingOperator(BaseOperator):
``list[dict]`` with ``text`` and ``metadata`` keys; output includes the
embedding vectors ready for downstream storage ingest.
- The operator passes the embedding model **directly** to
- ``VectorStoreIndex(..., embed_model=...)`` -- it does not mutate
- LlamaIndex's global ``Settings`` singleton, so concurrent tasks in the
- same worker don't race on shared state.
+ The operator calls the embedding model **directly** (and passes it to
+ ``VectorStoreIndex(..., embed_model=...)`` when persisting) -- it does
+ not mutate LlamaIndex's global ``Settings`` singleton, so concurrent
+ tasks in the same worker don't race on shared state.
:param documents: List of dicts with ``text`` and ``metadata`` keys,
typically from ``DocumentLoaderOperator`` or a ``@task``. Templated,
@@ -114,6 +114,7 @@ class LlamaIndexEmbeddingOperator(BaseOperator):
try:
from llama_index.core import Document, VectorStoreIndex
from llama_index.core.node_parser import SentenceSplitter
+ from llama_index.core.schema import MetadataMode
except ImportError as e:
raise AirflowOptionalProviderFeatureException(e)
@@ -125,19 +126,32 @@ class LlamaIndexEmbeddingOperator(BaseOperator):
nodes = splitter.get_nodes_from_documents(llama_docs)
self.log.info("Split %d documents into %d chunks", len(llama_docs),
len(nodes))
- # ``VectorStoreIndex(...)`` populates each node's ``.embedding`` as a
- # side effect of building the index; capture the index so the
- # variable isn't discarded.
- index = VectorStoreIndex(nodes, embed_model=embed_model,
show_progress=False)
+ # ``VectorStoreIndex(...)`` never sets ``.embedding`` on the nodes it
+ # is given -- ``_get_node_with_embedding()`` attaches embeddings to
+ # ``model_copy()`` copies, so reading ``node.embedding`` afterwards
+ # always returns ``None`` (apache/airflow#68416). Embed the original
+ # nodes explicitly.
+ # ``MetadataMode.EMBED`` matches what ``embed_nodes()`` inside the
+ # index embeds (includes metadata, respects
+ # ``excluded_embed_metadata_keys``).
+ texts = [node.get_content(metadata_mode=MetadataMode.EMBED) for node
in nodes]
+ vectors = embed_model.get_text_embedding_batch(texts,
show_progress=False)
+ for node, vector in zip(nodes, vectors, strict=True):
+ node.embedding = vector
if self.persist_dir:
+ # The index is only needed for persistence. ``embed_nodes()``
+ # inside ``VectorStoreIndex`` skips nodes whose ``.embedding`` is
+ # already set, so this reuses the vectors above instead of
+ # re-calling the embedding API.
+ index = VectorStoreIndex(nodes, embed_model=embed_model,
show_progress=False)
self._persist(index, self.persist_dir)
# ``SentenceSplitter`` always returns ``TextNode`` instances, but the
# base ``get_nodes_from_documents`` signature is typed as
# ``list[BaseNode]`` (which has no ``.text``). Cast so mypy doesn't
- # flag the ``.text`` access; ``node.embedding`` is populated by
- # ``VectorStoreIndex`` for every node above.
+ # flag the ``.text`` access; ``node.embedding`` is populated by the
+ # pre-embed step above for every node.
text_nodes = cast("list[TextNode]", nodes)
chunks = [
{
@@ -164,9 +178,9 @@ class LlamaIndexEmbeddingOperator(BaseOperator):
* ``None`` or ``str`` -- build an ``OpenAIEmbedding`` via
``LlamaIndexHook`` (the framework's documented ``default``
behaviour).
- * Has ``get_text_embedding`` / ``_get_query_embedding`` -- treat as
- a pre-built ``BaseEmbedding`` (duck-typed to avoid forcing a
- ``llama_index`` import here).
+ * Has ``get_text_embedding_batch`` / ``_get_query_embedding`` --
+ treat as a pre-built ``BaseEmbedding`` (duck-typed to avoid
+ forcing a ``llama_index`` import here).
* Anything else -- ``TypeError`` with a clear pointer.
"""
if self.embed_model is None or isinstance(self.embed_model, str):
@@ -179,10 +193,11 @@ class LlamaIndexEmbeddingOperator(BaseOperator):
).get_embedding_model()
# ``BaseEmbedding`` always exposes these two methods (see
- # ``llama_index.core.base.embeddings.base``). Duck-typing avoids
- # importing ``llama_index`` here and also catches the case where an
+ # ``llama_index.core.base.embeddings.base``); ``execute`` calls
+ # ``get_text_embedding_batch``. Duck-typing avoids importing
+ # ``llama_index`` here and also catches the case where an
# unresolved ``XComArg`` slips through.
- if hasattr(self.embed_model, "get_text_embedding") and hasattr(
+ if hasattr(self.embed_model, "get_text_embedding_batch") and hasattr(
self.embed_model, "_get_query_embedding"
):
return self.embed_model
diff --git
a/providers/common/ai/tests/unit/common/ai/operators/test_llamaindex_embedding.py
b/providers/common/ai/tests/unit/common/ai/operators/test_llamaindex_embedding.py
index 43b44f87c9f..d3e6b2acfca 100644
---
a/providers/common/ai/tests/unit/common/ai/operators/test_llamaindex_embedding.py
+++
b/providers/common/ai/tests/unit/common/ai/operators/test_llamaindex_embedding.py
@@ -20,6 +20,11 @@ from unittest.mock import MagicMock, patch
import pytest
+pytest.importorskip("llama_index.core")
+
+from llama_index.core import MockEmbedding
+from llama_index.core.schema import MetadataMode
+
from airflow.providers.common.ai.operators.llamaindex_embedding import
LlamaIndexEmbeddingOperator
@@ -38,17 +43,23 @@ def _li(monkeypatch):
return {"VectorStoreIndex": VectorStoreIndex, "SentenceSplitter":
SentenceSplitter}
-def _node(text: str = "chunk text", metadata: dict | None = None, vector=None):
+def _node(text: str = "chunk text", metadata: dict | None = None):
node = MagicMock()
node.text = text
node.metadata = metadata or {}
- node.embedding = vector
+ node.embedding = None
+ node.get_content.return_value = text
return node
-def _byo_embedding():
- """Return a duck-typed ``BaseEmbedding`` stand-in (has the two methods the
operator checks)."""
- return MagicMock(name="MyBaseEmbedding", spec=["get_text_embedding",
"_get_query_embedding"])
+def _byo_embedding(vectors: list[list[float]] | None = None):
+ """Return a duck-typed ``BaseEmbedding`` stand-in (has the methods the
operator checks and calls)."""
+ embedding = MagicMock(
+ name="MyBaseEmbedding",
+ spec=["get_text_embedding_batch", "_get_query_embedding"],
+ )
+ embedding.get_text_embedding_batch.return_value = [[0.0]] if vectors is
None else vectors
+ return embedding
class TestEmbeddingOperatorInit:
@@ -72,8 +83,9 @@ class TestEmbeddingOperatorExecute:
def test_string_embed_model_goes_through_hook(self, mock_get_embed, _li):
# `embed_model` as a string -> hook builds OpenAIEmbedding.
_li["SentenceSplitter"].return_value.get_nodes_from_documents.return_value = [
- _node(text="chunk a", vector=[0.1, 0.2]),
+ _node(text="chunk a"),
]
+ mock_get_embed.return_value = _byo_embedding(vectors=[[0.1, 0.2]])
op = LlamaIndexEmbeddingOperator(
task_id="test",
@@ -93,6 +105,7 @@ class TestEmbeddingOperatorExecute:
def test_string_embed_model_forwards_embed_conn_id(self, mock_hook_cls,
_li):
# ``embed_conn_id`` overrides ``llm_conn_id`` for the embedding API.
_li["SentenceSplitter"].return_value.get_nodes_from_documents.return_value =
[_node()]
+ mock_hook_cls.return_value.get_embedding_model.return_value =
_byo_embedding()
op = LlamaIndexEmbeddingOperator(
task_id="test",
@@ -110,8 +123,9 @@ class TestEmbeddingOperatorExecute:
)
def test_byo_embed_model_bypasses_hook(self, _li):
- # `embed_model` is a non-string instance -> hook is bypassed.
- byo = _byo_embedding()
+ # `embed_model` is a non-string instance -> hook is bypassed and the
+ # user's instance does the embedding.
+ byo = _byo_embedding(vectors=[[0.5]])
_li["SentenceSplitter"].return_value.get_nodes_from_documents.return_value =
[_node()]
op = LlamaIndexEmbeddingOperator(
@@ -119,12 +133,10 @@ class TestEmbeddingOperatorExecute:
documents=[{"text": "doc"}],
embed_model=byo,
)
- op.execute(context=MagicMock())
+ result = op.execute(context=MagicMock())
- # VectorStoreIndex called with the user's instance, not anything else.
- _li["VectorStoreIndex"].assert_called_once()
- kwargs = _li["VectorStoreIndex"].call_args.kwargs
- assert kwargs["embed_model"] is byo
+ byo.get_text_embedding_batch.assert_called_once()
+ assert result["chunks"][0]["vector"] == [0.5]
def test_invalid_embed_model_raises_typeerror(self, _li):
# An object that's neither None/str nor duck-types as BaseEmbedding
@@ -140,17 +152,16 @@ class TestEmbeddingOperatorExecute:
with pytest.raises(TypeError, match="embed_model must be"):
op.execute(context=MagicMock())
-
@patch("airflow.providers.common.ai.hooks.llamaindex.LlamaIndexHook.get_embedding_model")
- def test_chunks_carry_text_metadata_vector(self, mock_get_embed, _li):
+ def test_chunks_carry_text_metadata_vector(self, _li):
_li["SentenceSplitter"].return_value.get_nodes_from_documents.return_value = [
- _node(text="x", metadata={"k": "v"}, vector=[1.0, 2.0]),
- _node(text="y", metadata={"k": "v2"}, vector=[3.0, 4.0]),
+ _node(text="x", metadata={"k": "v"}),
+ _node(text="y", metadata={"k": "v2"}),
]
op = LlamaIndexEmbeddingOperator(
task_id="test",
documents=[{"text": "doc"}],
- embed_model="text-embedding-3-small",
+ embed_model=_byo_embedding(vectors=[[1.0, 2.0], [3.0, 4.0]]),
)
result = op.execute(context=MagicMock())
@@ -159,8 +170,84 @@ class TestEmbeddingOperatorExecute:
{"text": "y", "metadata": {"k": "v2"}, "vector": [3.0, 4.0]},
]
+ def test_nodes_embedded_with_embed_metadata_mode(self, _li):
+ # llama-index's own ``embed_nodes()`` embeds
+ # ``node.get_content(metadata_mode=MetadataMode.EMBED)`` (includes
+ # metadata, respects ``excluded_embed_metadata_keys``). The pre-embed
+ # step must match, or the vectors silently change semantics.
+ node = _node(text="chunk a")
+
_li["SentenceSplitter"].return_value.get_nodes_from_documents.return_value =
[node]
+ byo = _byo_embedding()
+
+ op = LlamaIndexEmbeddingOperator(
+ task_id="test",
+ documents=[{"text": "doc"}],
+ embed_model=byo,
+ )
+ op.execute(context=MagicMock())
+
+
node.get_content.assert_called_once_with(metadata_mode=MetadataMode.EMBED)
+ byo.get_text_embedding_batch.assert_called_once()
+ assert byo.get_text_embedding_batch.call_args.args[0] == ["chunk a"]
+
+ def test_index_only_built_when_persisting(self, _li):
+ # Without ``persist_dir`` the index would be built and immediately
+ # discarded; the vectors come from the pre-embed step.
+
_li["SentenceSplitter"].return_value.get_nodes_from_documents.return_value =
[_node()]
+
+ op = LlamaIndexEmbeddingOperator(
+ task_id="test",
+ documents=[{"text": "doc"}],
+ embed_model=_byo_embedding(),
+ )
+ op.execute(context=MagicMock())
+
+ _li["VectorStoreIndex"].assert_not_called()
+
+ def test_vectors_populated_with_real_llama_index(self):
+ # Regression test for #68416: ``VectorStoreIndex`` attaches embeddings
+ # to node *copies* (``model_copy()`` in ``_get_node_with_embedding``),
+ # so reading ``node.embedding`` after index construction returns
+ # ``None``. Run the real llama-index code path with its offline
+ # ``MockEmbedding`` -- no mocks on the operator's internals.
+ op = LlamaIndexEmbeddingOperator(
+ task_id="test",
+ documents=[{"text": "hello world", "metadata": {"src": "a"}}],
+ embed_model=MockEmbedding(embed_dim=8),
+ )
+ result = op.execute(context=MagicMock())
+
+ assert result["chunk_count"] >= 1
+ assert all(chunk["vector"] is not None for chunk in result["chunks"])
+ assert all(len(chunk["vector"]) == 8 for chunk in result["chunks"])
+
class TestEmbeddingOperatorPersist:
+ def test_persist_path_embeds_each_chunk_once_with_real_llama_index(self,
tmp_path):
+ # ``embed_nodes()`` inside ``VectorStoreIndex`` must skip the
+ # pre-embedded nodes -- if it re-embeds, every chunk pays the
+ # embedding API twice. Runs the real llama-index persist path.
+ embedded_texts: list[str] = []
+
+ class CountingMockEmbedding(MockEmbedding):
+ def _get_text_embedding(self, text: str) -> list[float]:
+ embedded_texts.append(text)
+ return super()._get_text_embedding(text)
+
+ persist_dir = tmp_path / "idx"
+ op = LlamaIndexEmbeddingOperator(
+ task_id="test",
+ documents=[{"text": "hello world", "metadata": {"src": "a"}}],
+ embed_model=CountingMockEmbedding(embed_dim=8),
+ persist_dir=str(persist_dir),
+ )
+ result = op.execute(context=MagicMock())
+
+ assert result["chunk_count"] == 1
+ assert len(embedded_texts) == result["chunk_count"]
+ assert all(chunk["vector"] is not None for chunk in result["chunks"])
+ assert any(persist_dir.iterdir())
+
@patch("os.makedirs")
@patch("airflow.providers.common.ai.hooks.llamaindex.LlamaIndexHook.get_embedding_model")
def test_local_persist_dir_calls_makedirs_and_storage_persist(
@@ -168,6 +255,7 @@ class TestEmbeddingOperatorPersist:
):
node = _node()
_li["SentenceSplitter"].return_value.get_nodes_from_documents.return_value =
[node]
+ mock_get_embed.return_value = _byo_embedding(vectors=[[0.1]])
index = _li["VectorStoreIndex"].return_value
op = LlamaIndexEmbeddingOperator(
@@ -180,6 +268,11 @@ class TestEmbeddingOperatorPersist:
mock_makedirs.assert_called_once_with(str(tmp_path / "idx"),
exist_ok=True)
index.storage_context.persist.assert_called_once_with(persist_dir=str(tmp_path
/ "idx"))
+ # Nodes are already embedded when handed to the index (the
+ # no-double-embed behavior itself is pinned by
+ # ``test_persist_path_embeds_each_chunk_once_with_real_llama_index``).
+ nodes_arg = _li["VectorStoreIndex"].call_args.args[0]
+ assert nodes_arg[0].embedding == [0.1]
@patch("airflow.sdk.ObjectStoragePath")
@patch("airflow.providers.common.ai.hooks.llamaindex.LlamaIndexHook.get_embedding_model")
@@ -195,6 +288,7 @@ class TestEmbeddingOperatorPersist:
node = _node()
_li["SentenceSplitter"].return_value.get_nodes_from_documents.return_value =
[node]
+ mock_get_embed.return_value = _byo_embedding()
index = _li["VectorStoreIndex"].return_value
op = LlamaIndexEmbeddingOperator(