This is an automated email from the ASF dual-hosted git repository. vikramkoka pushed a commit to branch aip99-llamaindex in repository https://gitbox.apache.org/repos/asf/airflow.git
commit 57fc8336074017168f180dc448f210ccc2753140 Author: Vikram Koka <[email protected]> AuthorDate: Mon May 18 16:17:37 2026 +0100 Add LlamaIndex operators to common.ai provider - Adds LlamaIndexHook to bridge Airflow connections to LlamaIndex's Settings singleton. Reuses the pydanticai connection type, supports separate embedding and LLM connections. - Adds EmbeddingOperator to chunk documents and produce embedding vectors via LlamaIndex's SentenceSplitter. Input is list[dict(text, metadata)] (same shape as DocumentLoaderOperator output), output includes chunks with vectors ready for downstream vector store ingest operators (pgvector, Pinecone, Weaviate). - Adds RetrievalOperator to load a persisted LlamaIndex index and perform similarity search. Output is scored chunks ready for synthesis via LLMOperator. Design notes All LlamaIndex imports are lazy (inside execute() / method bodies), so modules parse without llama-index installed. The hook currently hardcodes OpenAI embedding/LLM providers; a follow-up PR will refactor to use BaseAIHook for provider-agnostic model resolution when it lands. 
What's included ┌─────────────────────────────────────────┬──────────────────────────────────────────┐ │ File │ Purpose │ ├─────────────────────────────────────────┼──────────────────────────────────────────┤ │ hooks/llamaindex.py │ Hook (~110 lines) │ ├─────────────────────────────────────────┼──────────────────────────────────────────┤ │ operators/llamaindex_embedding.py │ EmbeddingOperator (~110 lines) │ ├─────────────────────────────────────────┼──────────────────────────────────────────┤ │ operators/llamaindex_retrieval.py │ RetrievalOperator (~90 lines) │ ├─────────────────────────────────────────┼──────────────────────────────────────────┤ │ tests/.../test_llamaindex.py │ 12 hook tests │ ├─────────────────────────────────────────┼──────────────────────────────────────────┤ │ tests/.../test_llamaindex_embedding.py │ 10 operator tests │ ├─────────────────────────────────────────┼──────────────────────────────────────────┤ │ tests/.../test_llamaindex_retrieval.py │ 8 operator tests │ ├─────────────────────────────────────────┼──────────────────────────────────────────┤ │ docs/hooks/llamaindex.rst │ Hook docs │ ├─────────────────────────────────────────┼──────────────────────────────────────────┤ │ docs/operators/llamaindex_embedding.rst │ EmbeddingOperator docs │ ├─────────────────────────────────────────┼──────────────────────────────────────────┤ │ docs/operators/llamaindex_retrieval.rst │ RetrievalOperator docs │ ├─────────────────────────────────────────┼──────────────────────────────────────────┤ │ provider.yaml │ Integration, hook, operator registration │ ├─────────────────────────────────────────┼──────────────────────────────────────────┤ │ docs/index.rst │ LlamaIndex Hook in Guides toctree │ ├─────────────────────────────────────────┼──────────────────────────────────────────┤ │ docs/operators/index.rst │ Chooser table rows │ └─────────────────────────────────────────┴──────────────────────────────────────────┘ Test plan - uv run --project 
providers/common/ai pytest providers/common/ai/tests/unit/common/ai/hooks/test_llamaindex.py -xvs (12 tests) - uv run --project providers/common/ai pytest providers/common/ai/tests/unit/common/ai/operators/test_llamaindex_embedding.py providers/common/ai/tests/unit/common/ai/operators/test_llamaindex_retrieval.py -xvs (18 tests) - Hook: init defaults, separate embed_conn_id, connection kwargs extraction, embedding model, LLM, Settings configuration - EmbeddingOperator: output shape, chunking, index persistence, vector inclusion/omission, splitter params - RetrievalOperator: output shape, chunk keys, top_k forwarding, multiple results, storage context --- Was generative AI tooling used to co-author this PR? - Yes — Claude Code (Opus 4.6) Generated-by: Claude Code (Opus 4.6) following https://github.com/apache/airflow/blob/main/contributing-docs/05_pull_requests.rst#gen-ai-assisted-contributions --- providers/common/ai/docs/hooks/llamaindex.rst | 87 +++++++++ providers/common/ai/docs/index.rst | 1 + providers/common/ai/docs/operators/index.rst | 8 +- .../ai/docs/operators/llamaindex_embedding.rst | 135 ++++++++++++++ .../ai/docs/operators/llamaindex_retrieval.rst | 108 +++++++++++ providers/common/ai/provider.yaml | 11 ++ .../providers/common/ai/hooks/llamaindex.py | 110 +++++++++++ .../common/ai/operators/llamaindex_embedding.py | 109 +++++++++++ .../common/ai/operators/llamaindex_retrieval.py | 96 ++++++++++ .../tests/unit/common/ai/hooks/test_llamaindex.py | 196 ++++++++++++++++++++ .../ai/operators/test_llamaindex_embedding.py | 202 +++++++++++++++++++++ .../ai/operators/test_llamaindex_retrieval.py | 199 ++++++++++++++++++++ 12 files changed, 1261 insertions(+), 1 deletion(-) diff --git a/providers/common/ai/docs/hooks/llamaindex.rst b/providers/common/ai/docs/hooks/llamaindex.rst new file mode 100644 index 00000000000..ff942a2f65b --- /dev/null +++ b/providers/common/ai/docs/hooks/llamaindex.rst @@ -0,0 +1,87 @@ + .. 
Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + +.. _howto/hook:llamaindex: + +``LlamaIndexHook`` +================== + +Use :class:`~airflow.providers.common.ai.hooks.llamaindex.LlamaIndexHook` to +bridge Airflow connections to LlamaIndex's ``Settings`` singleton. The hook +reuses the ``pydanticai`` connection type, so users configure a single +connection for both pydantic-ai operators and LlamaIndex operators. + +.. seealso:: + :ref:`Connection configuration <howto/connection:pydanticai>` + +What It Does +------------ + +The hook resolves API keys and base URLs from Airflow connections and uses +them to configure LlamaIndex's embedding models, LLMs, and global settings. +This eliminates manual ``Settings.embed_model = ...`` boilerplate in every +task that uses LlamaIndex. + +Configuration +------------- + +``LlamaIndexHook`` reuses the ``pydanticai`` connection type. Set the API key +in the **Password** field and optionally a custom endpoint in the **Host** +field. + +Separate Embedding and LLM Connections +-------------------------------------- + +RAG pipelines often use different providers for embeddings and chat. The hook +supports an optional ``embed_conn_id`` parameter that defaults to the main +``llm_conn_id``: + +.. 
code-block:: python + + from airflow.providers.common.ai.hooks.llamaindex import LlamaIndexHook + + hook = LlamaIndexHook( + llm_conn_id="openai_default", + embed_conn_id="embedding_provider", + embed_model="text-embedding-3-large", + llm_model="gpt-4o", + ) + hook.configure_settings() + +Parameters +---------- + +.. list-table:: + :header-rows: 1 + :widths: 25 15 60 + + * - Parameter + - Default + - Description + * - ``llm_conn_id`` + - ``pydanticai_default`` + - Airflow connection ID for the LLM/embedding provider. + * - ``embed_conn_id`` + - Same as ``llm_conn_id`` + - Separate connection for embeddings (optional). + * - ``embed_model`` + - ``text-embedding-3-small`` + - Embedding model name. + * - ``llm_model`` + - ``None`` + - LLM model name. Required for ``get_llm()`` and ``configure_settings()`` + LLM setup. diff --git a/providers/common/ai/docs/index.rst b/providers/common/ai/docs/index.rst index e96ba4cfd27..a842fdf358f 100644 --- a/providers/common/ai/docs/index.rst +++ b/providers/common/ai/docs/index.rst @@ -37,6 +37,7 @@ Connection types <connections/pydantic_ai> MCP connection <connections/mcp> Hooks <hooks/pydantic_ai> + LlamaIndex Hook <hooks/llamaindex> Toolsets <toolsets> Operators <operators/index> HITL Review <hitl_review> diff --git a/providers/common/ai/docs/operators/index.rst b/providers/common/ai/docs/operators/index.rst index 89ba5d15e6c..64f61e84c28 100644 --- a/providers/common/ai/docs/operators/index.rst +++ b/providers/common/ai/docs/operators/index.rst @@ -21,7 +21,7 @@ Common AI Operators Choosing the right operator --------------------------- -The common-ai provider ships five operators (and matching ``@task`` decorators). Use this table +The common-ai provider ships several operators (and matching ``@task`` decorators). Use this table to pick the one that fits your use case: .. list-table:: @@ -46,6 +46,12 @@ to pick the one that fits your use case: * - Multi-turn reasoning with tools (DB queries, API calls, etc.) 
- :class:`~airflow.providers.common.ai.operators.agent.AgentOperator` - ``@task.agent`` + * - Chunk documents and produce embedding vectors + - :class:`~airflow.providers.common.ai.operators.llamaindex_embedding.EmbeddingOperator` + - — + * - Retrieve relevant chunks from a vector index + - :class:`~airflow.providers.common.ai.operators.llamaindex_retrieval.RetrievalOperator` + - — **LLMOperator / @task.llm** — stateless, single-turn calls. Use this for classification, summarization, extraction, or any prompt that produces one response. Supports structured output diff --git a/providers/common/ai/docs/operators/llamaindex_embedding.rst b/providers/common/ai/docs/operators/llamaindex_embedding.rst new file mode 100644 index 00000000000..2a32d056dc2 --- /dev/null +++ b/providers/common/ai/docs/operators/llamaindex_embedding.rst @@ -0,0 +1,135 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + +.. _howto/operator:llamaindex_embedding: + +``EmbeddingOperator`` +===================== + +Use :class:`~airflow.providers.common.ai.operators.llamaindex_embedding.EmbeddingOperator` +to chunk documents and produce embedding vectors using LlamaIndex. 
This operator +bridges document loading (Airflow provider hooks returning text) and vector +storage (pgvector, Pinecone, Weaviate ingest operators). + +Basic Usage +----------- + +Provide a list of documents with ``text`` and ``metadata`` keys. The operator +chunks the documents, embeds them, and returns the results: + +.. code-block:: python + + from airflow.providers.common.ai.operators.llamaindex_embedding import EmbeddingOperator + + embed = EmbeddingOperator( + task_id="embed_docs", + documents=[ + {"text": "Airflow is a workflow orchestration platform.", "metadata": {"source": "docs"}}, + {"text": "LlamaIndex is a data framework for LLM applications.", "metadata": {"source": "docs"}}, + ], + llm_conn_id="openai_default", + ) + +Connection Configuration +------------------------ + +The operator uses :class:`~airflow.providers.common.ai.hooks.llamaindex.LlamaIndexHook` +internally. Configure your embedding API credentials via the ``pydanticai`` +connection type. + +.. seealso:: + :ref:`Connection configuration <howto/connection:pydanticai>` + +Chunking Parameters +------------------- + +Control how documents are split into chunks before embedding: + +.. code-block:: python + + embed = EmbeddingOperator( + task_id="embed_docs", + documents=documents, + llm_conn_id="openai_default", + chunk_size=256, + chunk_overlap=25, + ) + +Index Persistence +----------------- + +Set ``persist_dir`` to save the LlamaIndex index for later retrieval via +:class:`~airflow.providers.common.ai.operators.llamaindex_retrieval.RetrievalOperator`: + +.. code-block:: python + + embed = EmbeddingOperator( + task_id="embed_docs", + documents=documents, + llm_conn_id="openai_default", + persist_dir="/opt/airflow/data/my_index", + ) + +Output Shape +------------ + +The operator returns a dict: + +.. 
code-block:: python + + { + "document_count": 2, + "chunk_count": 5, + "persist_dir": "/opt/airflow/data/my_index", + "chunks": [ + {"text": "chunk text", "metadata": {"source": "docs"}, "vector": [0.1, ...]}, + ... + ], + } + +Each chunk includes ``text``, ``metadata``, and optionally ``vector`` (the +embedding array). The ``chunks`` list is ready for downstream consumption by +vector store ingest operators. + +Parameters +---------- + +.. list-table:: + :header-rows: 1 + :widths: 25 15 60 + + * - Parameter + - Default + - Description + * - ``documents`` + - (required) + - List of dicts with ``text`` and ``metadata`` keys. + * - ``llm_conn_id`` + - ``pydanticai_default`` + - Airflow connection ID for the embedding API. + * - ``embed_model`` + - ``text-embedding-3-small`` + - Embedding model name. + * - ``chunk_size`` + - ``512`` + - Chunk size for the sentence splitter. + * - ``chunk_overlap`` + - ``50`` + - Overlap between chunks. + * - ``persist_dir`` + - ``None`` + - Directory path to persist the index. diff --git a/providers/common/ai/docs/operators/llamaindex_retrieval.rst b/providers/common/ai/docs/operators/llamaindex_retrieval.rst new file mode 100644 index 00000000000..ff238744cf1 --- /dev/null +++ b/providers/common/ai/docs/operators/llamaindex_retrieval.rst @@ -0,0 +1,108 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. 
See the License for the + specific language governing permissions and limitations + under the License. + +.. _howto/operator:llamaindex_retrieval: + +``RetrievalOperator`` +===================== + +Use :class:`~airflow.providers.common.ai.operators.llamaindex_retrieval.RetrievalOperator` +to retrieve relevant document chunks from a persisted LlamaIndex index. The +operator performs similarity search against the provided query and returns +results ready for downstream synthesis via ``LLMOperator``. + +Basic Usage +----------- + +Provide a query string and the path to a previously persisted index: + +.. code-block:: python + + from airflow.providers.common.ai.operators.llamaindex_retrieval import RetrievalOperator + + retrieve = RetrievalOperator( + task_id="retrieve_context", + query="What are Airflow's key features?", + index_persist_dir="/opt/airflow/data/my_index", + llm_conn_id="openai_default", + ) + +Query Templating +---------------- + +The ``query`` field supports Jinja templating, so it can be set dynamically +from upstream task output or Dag run configuration: + +.. code-block:: python + + retrieve = RetrievalOperator( + task_id="retrieve_context", + query="{{ dag_run.conf['question'] }}", + index_persist_dir="/opt/airflow/data/my_index", + llm_conn_id="openai_default", + top_k=10, + ) + +Output Shape +------------ + +The operator returns a dict: + +.. code-block:: python + + { + "question": "What are Airflow's key features?", + "chunks": [ + { + "text": "Airflow provides ...", + "score": 0.95, + "metadata": {"source": "overview.txt"}, + "source": "node-abc123", + }, + ... + ], + } + +Each chunk includes ``text``, ``score`` (similarity), ``metadata``, and +``source`` (the LlamaIndex node ID). This output pairs naturally with +``LLMOperator`` for RAG synthesis using Jinja templates. + +Parameters +---------- + +.. 
list-table:: + :header-rows: 1 + :widths: 25 15 60 + + * - Parameter + - Default + - Description + * - ``query`` + - (required) + - The query string to search for. Supports Jinja templating. + * - ``index_persist_dir`` + - (required) + - Path to the persisted LlamaIndex index directory. + * - ``llm_conn_id`` + - ``pydanticai_default`` + - Airflow connection ID for the embedding API. + * - ``embed_model`` + - ``text-embedding-3-small`` + - Embedding model name. + * - ``top_k`` + - ``5`` + - Number of top results to retrieve. diff --git a/providers/common/ai/provider.yaml b/providers/common/ai/provider.yaml index 2a13392ea99..5a148718345 100644 --- a/providers/common/ai/provider.yaml +++ b/providers/common/ai/provider.yaml @@ -48,6 +48,12 @@ integrations: - integration-name: MCP Server external-doc-url: https://modelcontextprotocol.io/ tags: [ai] + - integration-name: LlamaIndex + external-doc-url: https://docs.llamaindex.ai/ + how-to-guide: + - /docs/apache-airflow-providers-common-ai/operators/llamaindex_embedding.rst + - /docs/apache-airflow-providers-common-ai/operators/llamaindex_retrieval.rst + tags: [ai] hooks: - integration-name: Pydantic AI @@ -56,6 +62,9 @@ hooks: - integration-name: MCP Server python-modules: - airflow.providers.common.ai.hooks.mcp + - integration-name: LlamaIndex + python-modules: + - airflow.providers.common.ai.hooks.llamaindex plugins: - name: hitl_review @@ -323,6 +332,8 @@ operators: - airflow.providers.common.ai.operators.llm_branch - airflow.providers.common.ai.operators.llm_sql - airflow.providers.common.ai.operators.llm_schema_compare + - airflow.providers.common.ai.operators.llamaindex_embedding + - airflow.providers.common.ai.operators.llamaindex_retrieval task-decorators: - class-name: airflow.providers.common.ai.decorators.agent.agent_task diff --git a/providers/common/ai/src/airflow/providers/common/ai/hooks/llamaindex.py b/providers/common/ai/src/airflow/providers/common/ai/hooks/llamaindex.py new file mode 100644 index 
00000000000..7c3272c4cdb --- /dev/null +++ b/providers/common/ai/src/airflow/providers/common/ai/hooks/llamaindex.py @@ -0,0 +1,110 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Hook for LlamaIndex integration with Airflow connections.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +from airflow.providers.common.compat.sdk import BaseHook + +if TYPE_CHECKING: + from llama_index.core.base.embeddings.base import BaseEmbedding + from llama_index.core.llms.llm import LLM + + +class LlamaIndexHook(BaseHook): + """ + Bridge Airflow connections to LlamaIndex's Settings singleton. + + Reuses the ``pydanticai`` connection type so users configure a single + connection for both pydantic-ai operators and LlamaIndex operators. + + :param llm_conn_id: Airflow connection ID for the LLM/embedding provider. + :param embed_conn_id: Separate connection for embeddings. Defaults to + ``llm_conn_id`` when not provided. + :param embed_model: Embedding model name (e.g. ``text-embedding-3-small``). + :param llm_model: LLM model name (e.g. ``gpt-4o``). Only needed when + configuring ``Settings.llm``. 
+ """ + + conn_name_attr = "llm_conn_id" + default_conn_name = "pydanticai_default" + conn_type = "pydanticai" + hook_name = "LlamaIndex" + + def __init__( + self, + llm_conn_id: str = "pydanticai_default", + embed_conn_id: str | None = None, + embed_model: str = "text-embedding-3-small", + llm_model: str | None = None, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + self.llm_conn_id = llm_conn_id + self.embed_conn_id = embed_conn_id or llm_conn_id + self.embed_model = embed_model + self.llm_model = llm_model + + def _resolve_connection_kwargs(self, conn_id: str) -> dict[str, Any]: + """Extract API key and base URL from an Airflow connection.""" + conn = self.get_connection(conn_id) + kwargs: dict[str, Any] = {} + if conn.password: + kwargs["api_key"] = conn.password + if conn.host: + kwargs["api_base"] = conn.host + return kwargs + + def get_embedding_model(self) -> BaseEmbedding: + """ + Return a LlamaIndex embedding model configured from the Airflow connection. + + Uses ``embed_conn_id`` (falls back to ``llm_conn_id``) for credentials. + """ + from llama_index.embeddings.openai import OpenAIEmbedding + + conn_kwargs = self._resolve_connection_kwargs(self.embed_conn_id) + return OpenAIEmbedding(model=self.embed_model, **conn_kwargs) + + def get_llm(self) -> LLM: + """ + Return a LlamaIndex LLM configured from the Airflow connection. + + Requires ``llm_model`` to be set on the hook. + """ + if not self.llm_model: + raise ValueError("llm_model must be set to use get_llm()") + + from llama_index.llms.openai import OpenAI + + conn_kwargs = self._resolve_connection_kwargs(self.llm_conn_id) + return OpenAI(model=self.llm_model, **conn_kwargs) + + def configure_settings(self) -> None: + """ + Configure LlamaIndex's global Settings with models from Airflow connections. + + Sets ``Settings.embed_model`` always, and ``Settings.llm`` when + ``llm_model`` is provided. 
+ """ + from llama_index.core import Settings + + Settings.embed_model = self.get_embedding_model() + if self.llm_model: + Settings.llm = self.get_llm() diff --git a/providers/common/ai/src/airflow/providers/common/ai/operators/llamaindex_embedding.py b/providers/common/ai/src/airflow/providers/common/ai/operators/llamaindex_embedding.py new file mode 100644 index 00000000000..acbbc46dabb --- /dev/null +++ b/providers/common/ai/src/airflow/providers/common/ai/operators/llamaindex_embedding.py @@ -0,0 +1,109 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Operator for document chunking and embedding via LlamaIndex.""" + +from __future__ import annotations + +import os +from collections.abc import Sequence +from typing import TYPE_CHECKING, Any + +from airflow.providers.common.compat.sdk import BaseOperator + +if TYPE_CHECKING: + from airflow.sdk import Context + + +class EmbeddingOperator(BaseOperator): + """ + Chunk documents and produce embedding vectors using LlamaIndex. + + Bridges document loading (Airflow provider hooks returning text) and + vector storage (pgvector, Pinecone, Weaviate ingest operators). 
Input + is ``list[dict]`` with ``text`` and ``metadata`` keys; output includes + the embedding vectors ready for downstream storage. + + :param documents: List of dicts with ``text`` and ``metadata`` keys, + typically from ``DocumentLoaderOperator`` or a ``@task``. + :param llm_conn_id: Airflow connection ID for the embedding API. + :param embed_model: Embedding model name (default: ``text-embedding-3-small``). + :param chunk_size: Chunk size for the sentence splitter (default: 512). + :param chunk_overlap: Overlap between chunks (default: 50). + :param persist_dir: Optional directory path to persist the LlamaIndex + index for later retrieval. + """ + + template_fields: Sequence[str] = ("documents", "llm_conn_id", "persist_dir") + + def __init__( + self, + *, + documents: list[dict[str, Any]], + llm_conn_id: str = "pydanticai_default", + embed_model: str = "text-embedding-3-small", + chunk_size: int = 512, + chunk_overlap: int = 50, + persist_dir: str | None = None, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + self.documents = documents + self.llm_conn_id = llm_conn_id + self.embed_model = embed_model + self.chunk_size = chunk_size + self.chunk_overlap = chunk_overlap + self.persist_dir = persist_dir + + def execute(self, context: Context) -> dict[str, Any]: + from llama_index.core import Document, StorageContext, VectorStoreIndex + from llama_index.core.node_parser import SentenceSplitter + + from airflow.providers.common.ai.hooks.llamaindex import LlamaIndexHook + + hook = LlamaIndexHook(llm_conn_id=self.llm_conn_id, embed_model=self.embed_model) + hook.configure_settings() + + llama_docs = [Document(text=doc["text"], metadata=doc.get("metadata", {})) for doc in self.documents] + + splitter = SentenceSplitter(chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap) + nodes = splitter.get_nodes_from_documents(llama_docs) + self.log.info("Split %d documents into %d chunks", len(llama_docs), len(nodes)) + + storage_context = 
StorageContext.from_defaults() + VectorStoreIndex(nodes, storage_context=storage_context, show_progress=False) + + if self.persist_dir: + os.makedirs(self.persist_dir, exist_ok=True) + storage_context.persist(persist_dir=self.persist_dir) + self.log.info("Index persisted to %s", self.persist_dir) + + chunks = [] + for node in nodes: + chunk: dict[str, Any] = { + "text": node.text, + "metadata": node.metadata, + } + if node.embedding: + chunk["vector"] = node.embedding + chunks.append(chunk) + + return { + "document_count": len(llama_docs), + "chunk_count": len(nodes), + "persist_dir": self.persist_dir, + "chunks": chunks, + } diff --git a/providers/common/ai/src/airflow/providers/common/ai/operators/llamaindex_retrieval.py b/providers/common/ai/src/airflow/providers/common/ai/operators/llamaindex_retrieval.py new file mode 100644 index 00000000000..6089f7a4c62 --- /dev/null +++ b/providers/common/ai/src/airflow/providers/common/ai/operators/llamaindex_retrieval.py @@ -0,0 +1,96 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+"""Operator for semantic retrieval via a persisted LlamaIndex index.""" + +from __future__ import annotations + +from collections.abc import Sequence +from typing import TYPE_CHECKING, Any + +from airflow.providers.common.compat.sdk import BaseOperator + +if TYPE_CHECKING: + from airflow.sdk import Context + + +class RetrievalOperator(BaseOperator): + """ + Retrieve relevant document chunks from a persisted LlamaIndex index. + + Loads a previously persisted vector store index and performs similarity + search against the provided query. The output is a list of chunks with + text, score, metadata, and source information ready for downstream + synthesis via ``LLMOperator``. + + :param query: The query string to search for. Supports Jinja templating. + :param index_persist_dir: Path to the persisted LlamaIndex index directory. + :param llm_conn_id: Airflow connection ID for the embedding API + (needed to embed the query vector). + :param embed_model: Embedding model name (default: ``text-embedding-3-small``). + :param top_k: Number of top results to retrieve (default: 5). 
+ """ + + template_fields: Sequence[str] = ("query", "index_persist_dir", "llm_conn_id") + + def __init__( + self, + *, + query: str, + index_persist_dir: str, + llm_conn_id: str = "pydanticai_default", + embed_model: str = "text-embedding-3-small", + top_k: int = 5, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + self.query = query + self.index_persist_dir = index_persist_dir + self.llm_conn_id = llm_conn_id + self.embed_model = embed_model + self.top_k = top_k + + def execute(self, context: Context) -> dict[str, Any]: + from llama_index.core import StorageContext, load_index_from_storage + + from airflow.providers.common.ai.hooks.llamaindex import LlamaIndexHook + + hook = LlamaIndexHook(llm_conn_id=self.llm_conn_id, embed_model=self.embed_model) + hook.configure_settings() + + storage_context = StorageContext.from_defaults(persist_dir=self.index_persist_dir) + index = load_index_from_storage(storage_context) + + retriever = index.as_retriever(similarity_top_k=self.top_k) + results = retriever.retrieve(self.query) + self.log.info("Retrieved %d chunks for query: %s", len(results), self.query[:100]) + + chunks = [] + for node_with_score in results: + node = node_with_score.node + chunks.append( + { + "text": node.get_content(), + "score": node_with_score.score, + "metadata": node.metadata, + "source": node.node_id, + } + ) + + return { + "question": self.query, + "chunks": chunks, + } diff --git a/providers/common/ai/tests/unit/common/ai/hooks/test_llamaindex.py b/providers/common/ai/tests/unit/common/ai/hooks/test_llamaindex.py new file mode 100644 index 00000000000..3b119e3e5b4 --- /dev/null +++ b/providers/common/ai/tests/unit/common/ai/hooks/test_llamaindex.py @@ -0,0 +1,196 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
# The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Unit tests for ``LlamaIndexHook``: defaults, connection kwargs, model factories, Settings wiring."""

from __future__ import annotations

from unittest.mock import MagicMock, patch

import pytest

from airflow.providers.common.ai.hooks.llamaindex import LlamaIndexHook


def _mock_connection(password="", host=""):
    """Return a mock connection object with ``password`` and ``host`` preset."""
    return MagicMock(password=password, host=host)


def _make_mock_openai_embedding_module():
    """Return a fake ``llama_index.embeddings.openai`` module and its ``OpenAIEmbedding`` mock."""
    embedding_cls = MagicMock()
    return MagicMock(OpenAIEmbedding=embedding_cls), embedding_cls


def _make_mock_openai_llm_module():
    """Return a fake ``llama_index.llms.openai`` module and its ``OpenAI`` mock."""
    llm_cls = MagicMock()
    return MagicMock(OpenAI=llm_cls), llm_cls


class TestLlamaIndexHookInit:
    """Attribute values stored by the ``LlamaIndexHook`` constructor."""

    def test_default_params(self):
        """A hook built without arguments exposes the default connection ids and models."""
        hook = LlamaIndexHook()
        expected_attrs = {
            "llm_conn_id": "pydanticai_default",
            "embed_conn_id": "pydanticai_default",
            "embed_model": "text-embedding-3-small",
        }
        for attr_name, expected_value in expected_attrs.items():
            assert getattr(hook, attr_name) == expected_value
        assert hook.llm_model is None

    def test_separate_embed_conn_id(self):
        """Distinct LLM and embedding connection ids are both kept as given."""
        hook = LlamaIndexHook(llm_conn_id="llm_conn", embed_conn_id="embed_conn")
        assert (hook.llm_conn_id, hook.embed_conn_id) == ("llm_conn", "embed_conn")

    def test_embed_conn_defaults_to_llm_conn(self):
        """Without ``embed_conn_id`` the embedding connection id equals ``llm_conn_id``."""
        assert LlamaIndexHook(llm_conn_id="my_conn").embed_conn_id == "my_conn"


class TestResolveConnectionKwargs:
    """Mapping of connection fields to keyword arguments by ``_resolve_connection_kwargs``."""

    @patch.object(LlamaIndexHook, "get_connection")
    def test_extracts_password_as_api_key(self, mock_get_conn):
        """A password with an empty host produces only ``api_key``."""
        mock_get_conn.return_value = _mock_connection(password="sk-test-key")

        resolved = LlamaIndexHook()._resolve_connection_kwargs("test_conn")

        assert resolved == {"api_key": "sk-test-key"}

    @patch.object(LlamaIndexHook, "get_connection")
    def test_extracts_host_as_api_base(self, mock_get_conn):
        """A host with an empty password produces only ``api_base``."""
        mock_get_conn.return_value = _mock_connection(host="https://custom.api.com")

        resolved = LlamaIndexHook()._resolve_connection_kwargs("test_conn")

        assert resolved == {"api_base": "https://custom.api.com"}

    @patch.object(LlamaIndexHook, "get_connection")
    def test_both_password_and_host(self, mock_get_conn):
        """Password and host together produce both ``api_key`` and ``api_base``."""
        mock_get_conn.return_value = _mock_connection(password="sk-key", host="https://api.example.com")

        resolved = LlamaIndexHook()._resolve_connection_kwargs("test_conn")

        assert resolved == {"api_key": "sk-key", "api_base": "https://api.example.com"}

    @patch.object(LlamaIndexHook, "get_connection")
    def test_empty_fields_return_empty_dict(self, mock_get_conn):
        """Empty password and host produce an empty dict."""
        mock_get_conn.return_value = _mock_connection()

        resolved = LlamaIndexHook()._resolve_connection_kwargs("test_conn")

        assert resolved == {}


class TestGetEmbeddingModel:
    """``get_embedding_model`` with the OpenAI embedding module replaced in ``sys.modules``."""

    @patch.object(LlamaIndexHook, "get_connection")
    def test_returns_openai_embedding(self, mock_get_conn):
        """The embedding class receives the model name and API key; its instance is returned."""
        mock_get_conn.return_value = _mock_connection(password="sk-test")
        embedding_module, embedding_cls = _make_mock_openai_embedding_module()
        hook = LlamaIndexHook(embed_model="text-embedding-3-large")

        with patch.dict("sys.modules", {"llama_index.embeddings.openai": embedding_module}):
            produced = hook.get_embedding_model()

        embedding_cls.assert_called_once_with(model="text-embedding-3-large", api_key="sk-test")
        assert produced == embedding_cls.return_value


class TestGetLLM:
    """``get_llm`` with and without an ``llm_model``."""

    def test_raises_without_llm_model(self):
        """A hook without ``llm_model`` raises ``ValueError`` from ``get_llm``."""
        with pytest.raises(ValueError, match="llm_model must be set"):
            LlamaIndexHook().get_llm()

    @patch.object(LlamaIndexHook, "get_connection")
    def test_returns_openai_llm(self, mock_get_conn):
        """The LLM class receives the model name and API key; its instance is returned."""
        mock_get_conn.return_value = _mock_connection(password="sk-test")
        llm_module, llm_cls = _make_mock_openai_llm_module()
        hook = LlamaIndexHook(llm_model="gpt-4o")

        with patch.dict("sys.modules", {"llama_index.llms.openai": llm_module}):
            produced = hook.get_llm()

        llm_cls.assert_called_once_with(model="gpt-4o", api_key="sk-test")
        assert produced == llm_cls.return_value


class TestConfigureSettings:
    """``configure_settings`` assignments onto the mocked ``llama_index.core.Settings``."""

    @patch.object(LlamaIndexHook, "get_connection")
    def test_sets_embed_model(self, mock_get_conn):
        """``Settings.embed_model`` ends up equal to the embedding class instance."""
        mock_get_conn.return_value = _mock_connection(password="sk-test")
        embedding_module, embedding_cls = _make_mock_openai_embedding_module()
        core_module = MagicMock()
        fake_modules = {
            "llama_index.embeddings.openai": embedding_module,
            "llama_index": MagicMock(),
            "llama_index.core": core_module,
        }
        hook = LlamaIndexHook()

        with patch.dict("sys.modules", fake_modules):
            hook.configure_settings()

        assert core_module.Settings.embed_model == embedding_cls.return_value

    @patch.object(LlamaIndexHook, "get_connection")
    def test_sets_llm_when_model_provided(self, mock_get_conn):
        """With ``llm_model`` given, ``Settings.llm`` ends up equal to the LLM class instance."""
        mock_get_conn.return_value = _mock_connection(password="sk-test")
        embedding_module, _ = _make_mock_openai_embedding_module()
        llm_module, llm_cls = _make_mock_openai_llm_module()
        core_module = MagicMock()
        fake_modules = {
            "llama_index.embeddings.openai": embedding_module,
            "llama_index.llms.openai": llm_module,
            "llama_index": MagicMock(),
            "llama_index.core": core_module,
        }
        hook = LlamaIndexHook(llm_model="gpt-4o")

        with patch.dict("sys.modules", fake_modules):
            hook.configure_settings()

        assert core_module.Settings.llm == llm_cls.return_value


# ---------------------------------------------------------------------------
# Patch metadata introducing the next file of this commit (kept verbatim, commented out):
#   diff --git a/providers/common/ai/tests/unit/common/ai/operators/test_llamaindex_embedding.py b/providers/common/ai/tests/unit/common/ai/operators/test_llamaindex_embedding.py
#   new file mode 100644
#   index 00000000000..ee3e8c51a56
#   --- /dev/null
#   +++ b/providers/common/ai/tests/unit/common/ai/operators/test_llamaindex_embedding.py
#   @@ -0,0 +1,202 @@
# ---------------------------------------------------------------------------
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Unit tests for ``EmbeddingOperator`` with every ``llama_index`` module mocked out."""

from __future__ import annotations

from unittest.mock import MagicMock, patch

from airflow.providers.common.ai.operators.llamaindex_embedding import EmbeddingOperator

# Patched where it is defined; the tests rely on the patch being visible when ``execute`` runs.
_HOOK_TARGET = "airflow.providers.common.ai.hooks.llamaindex.LlamaIndexHook"


def _make_mock_node(text="chunk text", metadata=None, embedding=None):
    """Return a mock node carrying ``text``, ``metadata`` (``{}`` when falsy) and ``embedding``."""
    return MagicMock(text=text, metadata=metadata or {}, embedding=embedding)


def _make_mock_llamaindex_modules(nodes=None):
    """
    Create mock llama_index modules for sys.modules injection.

    :param nodes: Nodes the mocked splitter returns; a single default mock node when ``None``.
    :return: ``(modules_dict, mock_core, mock_splitter)``.
    """

    def _fake_document(text, metadata):
        # Mirrors the two fields the tests pass through ``Document``.
        return MagicMock(text=text, metadata=metadata)

    splitter_output = [_make_mock_node()] if nodes is None else nodes

    core_module = MagicMock()
    core_module.Document = MagicMock(side_effect=_fake_document)
    core_module.StorageContext.from_defaults.return_value = MagicMock()
    core_module.VectorStoreIndex = MagicMock()

    splitter = MagicMock()
    splitter.get_nodes_from_documents.return_value = splitter_output
    node_parser_module = MagicMock()
    node_parser_module.SentenceSplitter.return_value = splitter

    fake_modules = {
        "llama_index": MagicMock(),
        "llama_index.core": core_module,
        "llama_index.core.node_parser": node_parser_module,
        "llama_index.embeddings": MagicMock(),
        "llama_index.embeddings.openai": MagicMock(),
    }
    return fake_modules, core_module, splitter


def _build_operator(documents, **overrides):
    """Construct an ``EmbeddingOperator`` with the arguments shared by these tests."""
    params = {"task_id": "test", "documents": documents, "llm_conn_id": "my_conn"}
    params.update(overrides)
    return EmbeddingOperator(**params)


def _execute(operator, fake_modules):
    """Run ``operator.execute`` with ``fake_modules`` injected into ``sys.modules``."""
    with patch.dict("sys.modules", fake_modules):
        return operator.execute(context=MagicMock())


class TestEmbeddingOperator:
    """Output shape, persistence, and argument forwarding of ``EmbeddingOperator.execute``."""

    def test_template_fields(self):
        """The templated fields are exactly documents, llm_conn_id and persist_dir."""
        assert set(EmbeddingOperator.template_fields) == {"documents", "llm_conn_id", "persist_dir"}

    @patch(_HOOK_TARGET, autospec=True)
    def test_execute_returns_expected_shape(self, mock_hook_cls):
        """The result has the four expected keys and counts one document and one chunk."""
        documents = [{"text": "Hello world", "metadata": {"source": "test"}}]
        split_nodes = [_make_mock_node(text="Hello world", metadata={"source": "test"})]
        fake_modules, _, _ = _make_mock_llamaindex_modules(split_nodes)

        result = _execute(_build_operator(documents), fake_modules)

        for key in ("document_count", "chunk_count", "persist_dir", "chunks"):
            assert key in result
        assert result["document_count"] == 1
        assert result["chunk_count"] == 1

    @patch(_HOOK_TARGET, autospec=True)
    def test_chunking_node_count(self, mock_hook_cls):
        """Five nodes from the splitter give ``chunk_count`` 5 and five chunk entries."""
        documents = [{"text": "A long document " * 100, "metadata": {}}]
        split_nodes = [_make_mock_node(text=f"chunk {i}") for i in range(5)]
        fake_modules, _, _ = _make_mock_llamaindex_modules(split_nodes)

        result = _execute(_build_operator(documents), fake_modules)

        assert result["chunk_count"] == 5
        assert len(result["chunks"]) == 5

    @patch(_HOOK_TARGET, autospec=True)
    def test_persist_dir_creates_and_persists(self, mock_hook_cls, tmp_path):
        """With ``persist_dir`` set, the storage context is persisted once to that directory."""
        target_dir = str(tmp_path / "index_storage")
        fake_modules, core_module, _ = _make_mock_llamaindex_modules()
        storage_context = core_module.StorageContext.from_defaults.return_value
        operator = _build_operator([{"text": "test", "metadata": {}}], persist_dir=target_dir)

        _execute(operator, fake_modules)

        storage_context.persist.assert_called_once_with(persist_dir=target_dir)

    @patch(_HOOK_TARGET, autospec=True)
    def test_no_persist_when_none(self, mock_hook_cls):
        """Without ``persist_dir`` the storage context is never persisted."""
        fake_modules, core_module, _ = _make_mock_llamaindex_modules()
        storage_context = core_module.StorageContext.from_defaults.return_value

        _execute(_build_operator([{"text": "test", "metadata": {}}]), fake_modules)

        storage_context.persist.assert_not_called()

    @patch(_HOOK_TARGET, autospec=True)
    def test_chunks_have_text_and_metadata(self, mock_hook_cls):
        """Each chunk carries the node's text and metadata."""
        documents = [{"text": "test", "metadata": {"src": "a"}}]
        split_nodes = [_make_mock_node(text="chunk1", metadata={"src": "a"})]
        fake_modules, _, _ = _make_mock_llamaindex_modules(split_nodes)

        first_chunk = _execute(_build_operator(documents), fake_modules)["chunks"][0]

        assert "text" in first_chunk
        assert "metadata" in first_chunk
        assert first_chunk["text"] == "chunk1"
        assert first_chunk["metadata"] == {"src": "a"}

    @patch(_HOOK_TARGET, autospec=True)
    def test_chunks_include_vector_when_present(self, mock_hook_cls):
        """A node with an embedding yields a chunk whose ``vector`` is that embedding."""
        split_nodes = [_make_mock_node(text="chunk1", embedding=[0.1, 0.2, 0.3])]
        fake_modules, _, _ = _make_mock_llamaindex_modules(split_nodes)

        result = _execute(_build_operator([{"text": "test", "metadata": {}}]), fake_modules)

        assert result["chunks"][0]["vector"] == [0.1, 0.2, 0.3]

    @patch(_HOOK_TARGET, autospec=True)
    def test_chunks_omit_vector_when_not_present(self, mock_hook_cls):
        """A node whose embedding is ``None`` yields a chunk without a ``vector`` key."""
        split_nodes = [_make_mock_node(text="chunk1", embedding=None)]
        fake_modules, _, _ = _make_mock_llamaindex_modules(split_nodes)

        result = _execute(_build_operator([{"text": "test", "metadata": {}}]), fake_modules)

        assert "vector" not in result["chunks"][0]

    @patch(_HOOK_TARGET, autospec=True)
    def test_hook_configured_with_params(self, mock_hook_cls):
        """The hook is built once from the operator's connection id and embed model, then configured."""
        fake_modules, _, _ = _make_mock_llamaindex_modules()
        operator = _build_operator(
            [{"text": "test", "metadata": {}}],
            llm_conn_id="custom_conn",
            embed_model="text-embedding-ada-002",
        )

        _execute(operator, fake_modules)

        mock_hook_cls.assert_called_once_with(llm_conn_id="custom_conn", embed_model="text-embedding-ada-002")
        mock_hook_cls.return_value.configure_settings.assert_called_once()

    @patch(_HOOK_TARGET, autospec=True)
    def test_splitter_params_forwarded(self, mock_hook_cls):
        """``chunk_size`` and ``chunk_overlap`` reach ``SentenceSplitter`` unchanged."""
        fake_modules, _, _ = _make_mock_llamaindex_modules()
        node_parser_module = fake_modules["llama_index.core.node_parser"]
        operator = _build_operator(
            [{"text": "test", "metadata": {}}],
            chunk_size=256,
            chunk_overlap=25,
        )

        _execute(operator, fake_modules)

        node_parser_module.SentenceSplitter.assert_called_once_with(chunk_size=256, chunk_overlap=25)


# ---------------------------------------------------------------------------
# Patch metadata introducing the next file of this commit (kept verbatim, commented out):
#   diff --git a/providers/common/ai/tests/unit/common/ai/operators/test_llamaindex_retrieval.py b/providers/common/ai/tests/unit/common/ai/operators/test_llamaindex_retrieval.py
#   new file mode 100644
#   index 00000000000..0c85e86c214
#   --- /dev/null
#   +++ b/providers/common/ai/tests/unit/common/ai/operators/test_llamaindex_retrieval.py
#   @@ -0,0 +1,199 @@
# ---------------------------------------------------------------------------
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Unit tests for ``RetrievalOperator`` with every ``llama_index`` module mocked out."""

from __future__ import annotations

from unittest.mock import MagicMock, patch

from airflow.providers.common.ai.operators.llamaindex_retrieval import RetrievalOperator

# Patched where it is defined; the tests rely on the patch being visible when ``execute`` runs.
_HOOK_TARGET = "airflow.providers.common.ai.hooks.llamaindex.LlamaIndexHook"


def _make_mock_node_with_score(text="chunk text", score=0.9, metadata=None, node_id="node-1"):
    """Return a mock scored result whose ``node`` exposes content, metadata and ``node_id``."""
    inner_node = MagicMock(metadata=metadata or {}, node_id=node_id)
    inner_node.get_content.return_value = text
    return MagicMock(node=inner_node, score=score)


def _make_mock_llamaindex_modules(retrieval_results=None):
    """
    Create mock llama_index modules for sys.modules injection.

    :param retrieval_results: What the mocked retriever returns; one default scored node when ``None``.
    :return: ``(modules_dict, mock_core, mock_index, mock_retriever)``.
    """
    hits = [_make_mock_node_with_score()] if retrieval_results is None else retrieval_results

    # Wire the chain load_index_from_storage -> index.as_retriever -> retriever.retrieve.
    retriever = MagicMock()
    retriever.retrieve.return_value = hits
    index = MagicMock()
    index.as_retriever.return_value = retriever
    core_module = MagicMock()
    core_module.load_index_from_storage.return_value = index

    fake_modules = {
        "llama_index": MagicMock(),
        "llama_index.core": core_module,
        "llama_index.embeddings": MagicMock(),
        "llama_index.embeddings.openai": MagicMock(),
    }
    return fake_modules, core_module, index, retriever


def _build_operator(**overrides):
    """Construct a ``RetrievalOperator`` with the arguments shared by these tests."""
    params = {
        "task_id": "test",
        "query": "test",
        "index_persist_dir": "/tmp/index",
        "llm_conn_id": "my_conn",
    }
    params.update(overrides)
    return RetrievalOperator(**params)


def _execute(operator, fake_modules):
    """Run ``operator.execute`` with ``fake_modules`` injected into ``sys.modules``."""
    with patch.dict("sys.modules", fake_modules):
        return operator.execute(context=MagicMock())


class TestRetrievalOperator:
    """Output shape and argument forwarding of ``RetrievalOperator.execute``."""

    def test_template_fields(self):
        """The templated fields are exactly query, index_persist_dir and llm_conn_id."""
        assert set(RetrievalOperator.template_fields) == {"query", "index_persist_dir", "llm_conn_id"}

    @patch(_HOOK_TARGET, autospec=True)
    def test_execute_returns_expected_shape(self, mock_hook_cls):
        """The result holds the query under ``question`` and one entry under ``chunks``."""
        hits = [_make_mock_node_with_score(text="relevant chunk", score=0.95)]
        fake_modules, _, _, _ = _make_mock_llamaindex_modules(hits)

        result = _execute(_build_operator(query="What is Airflow?"), fake_modules)

        assert "question" in result
        assert "chunks" in result
        assert result["question"] == "What is Airflow?"
        assert len(result["chunks"]) == 1

    @patch(_HOOK_TARGET, autospec=True)
    def test_chunks_have_required_keys(self, mock_hook_cls):
        """A chunk carries the node's text, score, metadata, and its ``node_id`` as ``source``."""
        hits = [
            _make_mock_node_with_score(
                text="chunk text", score=0.8, metadata={"file": "doc.txt"}, node_id="abc-123"
            )
        ]
        fake_modules, _, _, _ = _make_mock_llamaindex_modules(hits)

        first_chunk = _execute(_build_operator(query="test query"), fake_modules)["chunks"][0]

        assert first_chunk["text"] == "chunk text"
        assert first_chunk["score"] == 0.8
        assert first_chunk["metadata"] == {"file": "doc.txt"}
        assert first_chunk["source"] == "abc-123"

    @patch(_HOOK_TARGET, autospec=True)
    def test_top_k_forwarded_to_retriever(self, mock_hook_cls):
        """``top_k`` is passed to ``index.as_retriever`` as ``similarity_top_k``."""
        fake_modules, _, index, _ = _make_mock_llamaindex_modules([])

        _execute(_build_operator(top_k=10), fake_modules)

        index.as_retriever.assert_called_once_with(similarity_top_k=10)

    @patch(_HOOK_TARGET, autospec=True)
    def test_query_value_in_output(self, mock_hook_cls):
        """With no retrieval hits the result echoes the query and has an empty chunk list."""
        fake_modules, _, _, _ = _make_mock_llamaindex_modules([])

        result = _execute(_build_operator(query="How does Airflow scheduling work?"), fake_modules)

        assert result["question"] == "How does Airflow scheduling work?"
        assert result["chunks"] == []

    @patch(_HOOK_TARGET, autospec=True)
    def test_multiple_results_returned(self, mock_hook_cls):
        """Three hits produce three chunks in the order the retriever returned them."""
        hits = [
            _make_mock_node_with_score(text=f"chunk {i}", score=0.9 - i * 0.1, node_id=f"node-{i}")
            for i in range(3)
        ]
        fake_modules, _, _, _ = _make_mock_llamaindex_modules(hits)

        chunks = _execute(_build_operator(), fake_modules)["chunks"]

        assert len(chunks) == 3
        assert chunks[0]["text"] == "chunk 0"
        assert chunks[2]["text"] == "chunk 2"

    @patch(_HOOK_TARGET, autospec=True)
    def test_hook_configured_with_params(self, mock_hook_cls):
        """The hook is built once from the operator's connection id and embed model, then configured."""
        fake_modules, _, _, _ = _make_mock_llamaindex_modules([])
        operator = _build_operator(llm_conn_id="custom_conn", embed_model="text-embedding-ada-002")

        _execute(operator, fake_modules)

        mock_hook_cls.assert_called_once_with(llm_conn_id="custom_conn", embed_model="text-embedding-ada-002")
        mock_hook_cls.return_value.configure_settings.assert_called_once()

    @patch(_HOOK_TARGET, autospec=True)
    def test_persist_dir_passed_to_storage_context(self, mock_hook_cls):
        """``index_persist_dir`` is passed to ``StorageContext.from_defaults`` as ``persist_dir``."""
        fake_modules, core_module, _, _ = _make_mock_llamaindex_modules([])

        _execute(_build_operator(index_persist_dir="/data/my_index"), fake_modules)

        core_module.StorageContext.from_defaults.assert_called_once_with(persist_dir="/data/my_index")
