This is an automated email from the ASF dual-hosted git repository.

vikramkoka pushed a commit to branch aip99-langchain
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/aip99-langchain by this push:
     new a34521eab7a Add LangChain hook to common.ai provider
a34521eab7a is described below

commit a34521eab7ab6554aaa179372bc320ad4de1f0c9
Author: Vikram Koka <[email protected]>
AuthorDate: Tue May 19 16:40:56 2026 +0100

    Add LangChain hook to common.ai provider
    
    - Adds LangChainHook to bridge Airflow connections to LangChain model constructors (ChatOpenAI, OpenAIEmbeddings),
      using constructor injection for credentials
      - Reuses the existing pydanticai connection type so users configure one connection for PydanticAI, LlamaIndex, and
      LangChain
      - Follows the same pattern as LlamaIndexHook: _resolve_connection_kwargs() extracts api_key and base_url from the
      Airflow connection and passes them directly to LangChain constructors
      - Adds langchain optional dependency extra (langchain>=1.0.0, langchain-openai>=0.3.0)
    
      What's included
    
      - hooks/langchain.py: LangChainHook(BaseHook) with get_chat_model() and get_embedding_model()
      - tests/unit/common/ai/hooks/test_langchain.py: full test coverage (init, connection resolution, chat model,
      embedding model)
      - docs/hooks/langchain.rst — hook documentation with usage examples
      - provider.yaml — LangChain integration and hook registration
      - pyproject.toml — langchain optional dependency extra
    
      Design decisions
    
      - BaseHook, not BaseAIHook: BaseAIHook is still in development. Will migrate in a follow-up PR once it ships.
      - Constructor injection: credentials passed as api_key=/base_url= kwargs to LangChain constructors. No environment
      variable mutation. Matches the LlamaIndexHook pattern.
      - Shared connection type: reuses pydanticai connection type rather than introducing a new one. One connection works
      across all three frameworks.
      - No @task.langchain yet: consistent with LlamaIndex (no @task.llamaindex). Deferred to the BaseAIHook migration PR.
---
 providers/common/ai/docs/hooks/langchain.rst       | 111 ++++++++++++
 providers/common/ai/provider.yaml                  |   6 +
 providers/common/ai/pyproject.toml                 |   4 +
 .../airflow/providers/common/ai/hooks/langchain.py |  97 +++++++++++
 .../tests/unit/common/ai/hooks/test_langchain.py   | 189 +++++++++++++++++++++
 5 files changed, 407 insertions(+)
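
The constructor-injection approach described in the commit message can be sketched without Airflow or LangChain installed. This is only an illustration under stated assumptions: `FakeConnection`, `FakeChatModel`, and the free function `resolve_connection_kwargs` are hypothetical stand-ins, not the real Airflow connection class, `ChatOpenAI`, or the hook method. The resolution logic mirrors `_resolve_connection_kwargs()` from the diff.

```python
from __future__ import annotations

import os
from dataclasses import dataclass
from typing import Any


@dataclass
class FakeConnection:
    """Hypothetical stand-in for an Airflow connection (not the real class)."""

    password: str | None = None
    host: str | None = None


class FakeChatModel:
    """Hypothetical stand-in for a model constructor such as ChatOpenAI."""

    def __init__(self, model: str, **kwargs: Any) -> None:
        self.model = model
        self.kwargs = kwargs


def resolve_connection_kwargs(conn: FakeConnection) -> dict[str, Any]:
    """Map connection fields to constructor kwargs, skipping empty fields."""
    kwargs: dict[str, Any] = {}
    if conn.password:
        kwargs["api_key"] = conn.password
    if conn.host:
        kwargs["base_url"] = conn.host
    return kwargs


# Snapshot the environment to show that nothing is written to it.
env_before = dict(os.environ)

conn = FakeConnection(password="sk-example", host="")
llm = FakeChatModel(model="gpt-4o", **resolve_connection_kwargs(conn))

# Credentials reach the model only through constructor kwargs.
assert llm.kwargs == {"api_key": "sk-example"}
assert dict(os.environ) == env_before
```

Because empty fields are skipped rather than passed as empty strings, a connection without a host leaves `base_url` unset, so the model constructor is free to apply its own default.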

diff --git a/providers/common/ai/docs/hooks/langchain.rst b/providers/common/ai/docs/hooks/langchain.rst
new file mode 100644
index 00000000000..758c8b36b81
--- /dev/null
+++ b/providers/common/ai/docs/hooks/langchain.rst
@@ -0,0 +1,111 @@
+ .. Licensed to the Apache Software Foundation (ASF) under one
+    or more contributor license agreements.  See the NOTICE file
+    distributed with this work for additional information
+    regarding copyright ownership.  The ASF licenses this file
+    to you under the Apache License, Version 2.0 (the
+    "License"); you may not use this file except in compliance
+    with the License.  You may obtain a copy of the License at
+
+ ..   http://www.apache.org/licenses/LICENSE-2.0
+
+ .. Unless required by applicable law or agreed to in writing,
+    software distributed under the License is distributed on an
+    "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+    KIND, either express or implied.  See the License for the
+    specific language governing permissions and limitations
+    under the License.
+
+.. _howto/hook:langchain:
+
+``LangChainHook``
+=================
+
+Use :class:`~airflow.providers.common.ai.hooks.langchain.LangChainHook`
+to bridge Airflow connections to LangChain model constructors.  The hook
+extracts credentials from an Airflow connection and returns configured
+LangChain model objects (``ChatOpenAI``, ``OpenAIEmbeddings``).
+
+The hook reuses the ``pydanticai`` connection type, so users configure a
+single connection for PydanticAI operators, LlamaIndex operators, and
+LangChain tasks.
+
+.. seealso::
+    :ref:`Connection configuration <howto/connection:pydanticai>`
+
+Basic Usage
+-----------
+
+Use the hook in a ``@task`` function to get a configured chat model:
+
+.. code-block:: python
+
+    from airflow.providers.common.ai.hooks.langchain import LangChainHook
+
+    @task
+    def run_chain(query: str) -> str:
+        hook = LangChainHook(llm_conn_id="pydanticai_default", llm_model="gpt-4o")
+        llm = hook.get_chat_model()
+
+        from langchain_core.prompts import ChatPromptTemplate
+        from langchain_core.output_parsers import StrOutputParser
+
+        prompt = ChatPromptTemplate.from_template("Summarize: {query}")
+        chain = prompt | llm | StrOutputParser()
+        return chain.invoke({"query": query})
+
+Embedding Models
+----------------
+
+Use :meth:`~LangChainHook.get_embedding_model` for embeddings.  A
+separate ``embed_conn_id`` can be used when embedding and chat models
+use different API keys:
+
+.. code-block:: python
+
+    hook = LangChainHook(
+        llm_conn_id="chat_conn",
+        embed_conn_id="embed_conn",
+        embed_model="text-embedding-3-large",
+        llm_model="gpt-4o",
+    )
+    embeddings = hook.get_embedding_model()
+    chat_model = hook.get_chat_model()
+
+Connection Configuration
+------------------------
+
+The hook reads credentials from the Airflow connection:
+
+- **password** -- API key (passed as ``api_key`` to model constructors)
+- **host** -- Base URL (passed as ``base_url``; optional, for custom
+  endpoints or Ollama)
+
+Parameters
+----------
+
+.. list-table::
+   :header-rows: 1
+   :widths: 25 15 60
+
+   * - Parameter
+     - Default
+     - Description
+   * - ``llm_conn_id``
+     - ``pydanticai_default``
+     - Airflow connection ID for the LLM provider.
+   * - ``embed_conn_id``
+     - ``None`` (falls back to ``llm_conn_id``)
+     - Separate connection for embeddings.
+   * - ``embed_model``
+     - ``text-embedding-3-small``
+     - Embedding model name.
+   * - ``llm_model``
+     - ``None``
+     - Chat model name.  Required for ``get_chat_model()``.
+
+Dependencies
+------------
+
+Install the ``langchain`` extra to use this hook::
+
+    pip install apache-airflow-providers-common-ai[langchain]
diff --git a/providers/common/ai/provider.yaml b/providers/common/ai/provider.yaml
index 2a13392ea99..cc716d54a47 100644
--- a/providers/common/ai/provider.yaml
+++ b/providers/common/ai/provider.yaml
@@ -48,6 +48,9 @@ integrations:
   - integration-name: MCP Server
     external-doc-url: https://modelcontextprotocol.io/
     tags: [ai]
+  - integration-name: LangChain
+    external-doc-url: https://python.langchain.com/
+    tags: [ai]
 
 hooks:
   - integration-name: Pydantic AI
@@ -56,6 +59,9 @@ hooks:
   - integration-name: MCP Server
     python-modules:
       - airflow.providers.common.ai.hooks.mcp
+  - integration-name: LangChain
+    python-modules:
+      - airflow.providers.common.ai.hooks.langchain
 
 plugins:
   - name: hitl_review
diff --git a/providers/common/ai/pyproject.toml b/providers/common/ai/pyproject.toml
index 57ba93f7461..f44fe09fbd4 100644
--- a/providers/common/ai/pyproject.toml
+++ b/providers/common/ai/pyproject.toml
@@ -95,6 +95,10 @@ dependencies = [
 "common.sql" = [
     "apache-airflow-providers-common-sql"
 ]
+"langchain" = [
+    "langchain>=1.0.0",
+    "langchain-openai>=0.3.0",
+]
 
 [dependency-groups]
 dev = [
diff --git a/providers/common/ai/src/airflow/providers/common/ai/hooks/langchain.py b/providers/common/ai/src/airflow/providers/common/ai/hooks/langchain.py
new file mode 100644
index 00000000000..a92db1228f8
--- /dev/null
+++ b/providers/common/ai/src/airflow/providers/common/ai/hooks/langchain.py
@@ -0,0 +1,97 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Hook for LangChain integration with Airflow connections."""
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any
+
+from airflow.providers.common.compat.sdk import BaseHook
+
+if TYPE_CHECKING:
+    from langchain_core.embeddings import Embeddings
+    from langchain_core.language_models.chat_models import BaseChatModel
+
+
+class LangChainHook(BaseHook):
+    """
+    Bridge Airflow connections to LangChain model constructors.
+
+    Reuses the ``pydanticai`` connection type so users configure a single
+    connection for both pydantic-ai operators and LangChain tasks.
+
+    :param llm_conn_id: Airflow connection ID for the LLM/embedding provider.
+    :param embed_conn_id: Separate connection for embeddings. Defaults to
+        ``llm_conn_id`` when not provided.
+    :param embed_model: Embedding model name (e.g. ``text-embedding-3-small``).
+    :param llm_model: Chat model name (e.g. ``gpt-4o``). Only needed when
+        using :meth:`get_chat_model`.
+    """
+
+    conn_name_attr = "llm_conn_id"
+    default_conn_name = "pydanticai_default"
+    conn_type = "pydanticai"
+    hook_name = "LangChain"
+
+    def __init__(
+        self,
+        llm_conn_id: str = "pydanticai_default",
+        embed_conn_id: str | None = None,
+        embed_model: str = "text-embedding-3-small",
+        llm_model: str | None = None,
+        **kwargs: Any,
+    ) -> None:
+        super().__init__(**kwargs)
+        self.llm_conn_id = llm_conn_id
+        self.embed_conn_id = embed_conn_id or llm_conn_id
+        self.embed_model = embed_model
+        self.llm_model = llm_model
+
+    def _resolve_connection_kwargs(self, conn_id: str) -> dict[str, Any]:
+        """Extract API key and base URL from an Airflow connection."""
+        conn = self.get_connection(conn_id)
+        kwargs: dict[str, Any] = {}
+        if conn.password:
+            kwargs["api_key"] = conn.password
+        if conn.host:
+            kwargs["base_url"] = conn.host
+        return kwargs
+
+    def get_chat_model(self) -> BaseChatModel:
+        """
+        Return a LangChain chat model configured from the Airflow connection.
+
+        Requires ``llm_model`` to be set on the hook.
+        """
+        if not self.llm_model:
+            raise ValueError("llm_model must be set to use get_chat_model()")
+
+        from langchain_openai import ChatOpenAI
+
+        conn_kwargs = self._resolve_connection_kwargs(self.llm_conn_id)
+        return ChatOpenAI(model=self.llm_model, **conn_kwargs)
+
+    def get_embedding_model(self) -> Embeddings:
+        """
+        Return a LangChain embedding model configured from the Airflow connection.
+
+        Uses ``embed_conn_id`` (falls back to ``llm_conn_id``) for credentials.
+        """
+        from langchain_openai import OpenAIEmbeddings
+
+        conn_kwargs = self._resolve_connection_kwargs(self.embed_conn_id)
+        return OpenAIEmbeddings(model=self.embed_model, **conn_kwargs)
diff --git a/providers/common/ai/tests/unit/common/ai/hooks/test_langchain.py b/providers/common/ai/tests/unit/common/ai/hooks/test_langchain.py
new file mode 100644
index 00000000000..8ebd9b3f1b5
--- /dev/null
+++ b/providers/common/ai/tests/unit/common/ai/hooks/test_langchain.py
@@ -0,0 +1,189 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from airflow.providers.common.ai.hooks.langchain import LangChainHook
+
+
+class TestLangChainHookInit:
+    def test_default_params(self):
+        hook = LangChainHook()
+        assert hook.llm_conn_id == "pydanticai_default"
+        assert hook.embed_conn_id == "pydanticai_default"
+        assert hook.embed_model == "text-embedding-3-small"
+        assert hook.llm_model is None
+
+    def test_separate_embed_conn_id(self):
+        hook = LangChainHook(llm_conn_id="llm_conn", embed_conn_id="embed_conn")
+        assert hook.llm_conn_id == "llm_conn"
+        assert hook.embed_conn_id == "embed_conn"
+
+    def test_embed_conn_defaults_to_llm_conn(self):
+        hook = LangChainHook(llm_conn_id="my_conn")
+        assert hook.embed_conn_id == "my_conn"
+
+    def test_conn_type_is_pydanticai(self):
+        assert LangChainHook.conn_type == "pydanticai"
+        assert LangChainHook.default_conn_name == "pydanticai_default"
+
+
+class TestResolveConnectionKwargs:
+    @patch.object(LangChainHook, "get_connection")
+    def test_extracts_password_as_api_key(self, mock_get_conn):
+        mock_conn = MagicMock()
+        mock_conn.password = "sk-test-key"
+        mock_conn.host = ""
+        mock_get_conn.return_value = mock_conn
+
+        hook = LangChainHook()
+        result = hook._resolve_connection_kwargs("test_conn")
+
+        assert result == {"api_key": "sk-test-key"}
+
+    @patch.object(LangChainHook, "get_connection")
+    def test_extracts_host_as_base_url(self, mock_get_conn):
+        mock_conn = MagicMock()
+        mock_conn.password = ""
+        mock_conn.host = "https://custom.api.com"
+        mock_get_conn.return_value = mock_conn
+
+        hook = LangChainHook()
+        result = hook._resolve_connection_kwargs("test_conn")
+
+        assert result == {"base_url": "https://custom.api.com"}
+
+    @patch.object(LangChainHook, "get_connection")
+    def test_both_password_and_host(self, mock_get_conn):
+        mock_conn = MagicMock()
+        mock_conn.password = "sk-key"
+        mock_conn.host = "https://api.example.com"
+        mock_get_conn.return_value = mock_conn
+
+        hook = LangChainHook()
+        result = hook._resolve_connection_kwargs("test_conn")
+
+        assert result == {"api_key": "sk-key", "base_url": "https://api.example.com"}
+
+    @patch.object(LangChainHook, "get_connection")
+    def test_empty_fields_return_empty_dict(self, mock_get_conn):
+        mock_conn = MagicMock()
+        mock_conn.password = ""
+        mock_conn.host = ""
+        mock_get_conn.return_value = mock_conn
+
+        hook = LangChainHook()
+        result = hook._resolve_connection_kwargs("test_conn")
+
+        assert result == {}
+
+
+def _make_mock_chat_openai_module():
+    mock_module = MagicMock()
+    mock_cls = MagicMock()
+    mock_module.ChatOpenAI = mock_cls
+    return mock_module, mock_cls
+
+
+def _make_mock_openai_embeddings_module():
+    mock_module = MagicMock()
+    mock_cls = MagicMock()
+    mock_module.OpenAIEmbeddings = mock_cls
+    return mock_module, mock_cls
+
+
+class TestGetChatModel:
+    def test_raises_without_llm_model(self):
+        hook = LangChainHook()
+        with pytest.raises(ValueError, match="llm_model must be set"):
+            hook.get_chat_model()
+
+    @patch.object(LangChainHook, "get_connection")
+    def test_returns_chat_openai(self, mock_get_conn):
+        mock_conn = MagicMock()
+        mock_conn.password = "sk-test"
+        mock_conn.host = ""
+        mock_get_conn.return_value = mock_conn
+
+        mock_module, mock_cls = _make_mock_chat_openai_module()
+
+        hook = LangChainHook(llm_model="gpt-4o")
+        with patch.dict("sys.modules", {"langchain_openai": mock_module}):
+            result = hook.get_chat_model()
+
+        mock_cls.assert_called_once_with(model="gpt-4o", api_key="sk-test")
+        assert result == mock_cls.return_value
+
+    @patch.object(LangChainHook, "get_connection")
+    def test_passes_base_url(self, mock_get_conn):
+        mock_conn = MagicMock()
+        mock_conn.password = "sk-test"
+        mock_conn.host = "https://custom.api.com"
+        mock_get_conn.return_value = mock_conn
+
+        mock_module, mock_cls = _make_mock_chat_openai_module()
+
+        hook = LangChainHook(llm_model="gpt-4o")
+        with patch.dict("sys.modules", {"langchain_openai": mock_module}):
+            hook.get_chat_model()
+
+        mock_cls.assert_called_once_with(
+            model="gpt-4o", api_key="sk-test", base_url="https://custom.api.com"
+        )
+
+
+class TestGetEmbeddingModel:
+    @patch.object(LangChainHook, "get_connection")
+    def test_returns_openai_embeddings(self, mock_get_conn):
+        mock_conn = MagicMock()
+        mock_conn.password = "sk-test"
+        mock_conn.host = ""
+        mock_get_conn.return_value = mock_conn
+
+        mock_module, mock_cls = _make_mock_openai_embeddings_module()
+
+        hook = LangChainHook(embed_model="text-embedding-3-large")
+        with patch.dict("sys.modules", {"langchain_openai": mock_module}):
+            result = hook.get_embedding_model()
+
+        mock_cls.assert_called_once_with(model="text-embedding-3-large", api_key="sk-test")
+        assert result == mock_cls.return_value
+
+    @patch.object(LangChainHook, "get_connection")
+    def test_uses_embed_conn_id(self, mock_get_conn):
+        mock_conn_llm = MagicMock()
+        mock_conn_llm.password = "sk-llm"
+        mock_conn_llm.host = ""
+
+        mock_conn_embed = MagicMock()
+        mock_conn_embed.password = "sk-embed"
+        mock_conn_embed.host = ""
+
+        mock_get_conn.side_effect = lambda conn_id: (
+            mock_conn_embed if conn_id == "embed_conn" else mock_conn_llm
+        )
+
+        mock_module, mock_cls = _make_mock_openai_embeddings_module()
+
+        hook = LangChainHook(llm_conn_id="llm_conn", embed_conn_id="embed_conn")
+        with patch.dict("sys.modules", {"langchain_openai": mock_module}):
+            hook.get_embedding_model()
+
+        mock_cls.assert_called_once_with(model="text-embedding-3-small", api_key="sk-embed")
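
The connection-ID fallback exercised by `test_embed_conn_defaults_to_llm_conn` and `test_uses_embed_conn_id` can be sketched in isolation. `SketchHook` below is a hypothetical stand-in that copies only the two connection-ID assignments from `LangChainHook.__init__`. It is not the real hook and does not touch Airflow.

```python
from __future__ import annotations


class SketchHook:
    """Hypothetical stand-in mirroring only the hook's connection-ID handling."""

    def __init__(
        self,
        llm_conn_id: str = "pydanticai_default",
        embed_conn_id: str | None = None,
    ) -> None:
        self.llm_conn_id = llm_conn_id
        # `or` treats None and "" alike, so both fall back to llm_conn_id.
        self.embed_conn_id = embed_conn_id or llm_conn_id


# No embed_conn_id given: embeddings reuse the chat connection.
shared = SketchHook(llm_conn_id="my_conn")
assert shared.embed_conn_id == "my_conn"

# Explicit embed_conn_id: embeddings use their own connection.
split = SketchHook(llm_conn_id="llm_conn", embed_conn_id="embed_conn")
assert split.embed_conn_id == "embed_conn"

# An empty string is falsy, so it also falls back to the chat connection.
empty = SketchHook(llm_conn_id="llm_conn", embed_conn_id="")
assert empty.embed_conn_id == "llm_conn"
```

One consequence of using `or` rather than an `is None` check is that an empty `embed_conn_id` silently resolves to the chat connection instead of raising. The existing tests do not appear to cover the empty-string case.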
