This is an automated email from the ASF dual-hosted git repository. vikramkoka pushed a commit to branch aip99-langchain in repository https://gitbox.apache.org/repos/asf/airflow.git
commit 84ac8a0aeef5a0ce28fe3e9330b40ca707c73496 Author: Vikram Koka <[email protected]> AuthorDate: Tue May 19 16:34:15 2026 +0100 Add LangChain hook to common.ai provider - Adds LangChainHook to bridge Airflow connections to LangChain model constructors (ChatOpenAI, OpenAIEmbeddings), using constructor injection for credentials - Reuses the existing pydanticai connection type so users configure one connection for PydanticAI, LlamaIndex, and LangChain - Follows the same pattern as LlamaIndexHook: _resolve_connection_kwargs() extracts api_key and base_url from the Airflow connection and passes them directly to LangChain constructors - Adds langchain optional dependency extra (langchain>=1.0.0, langchain-openai>=0.3.0) What's included - hooks/langchain.py — LangChainHook(BaseHook) with get_chat_model() and get_embedding_model() - tests/unit/common/ai/hooks/test_langchain.py — full test coverage (init, connection resolution, chat model, embedding model) - docs/hooks/langchain.rst — hook documentation with usage examples - provider.yaml — LangChain integration and hook registration - pyproject.toml — langchain optional dependency extra Design decisions - BaseHook, not BaseAIHook — BaseAIHook is still in development. Will migrate in a follow-up PR once it ships. - Constructor injection — credentials passed as api_key=/base_url= kwargs to LangChain constructors. No environment variable mutation. Matches the LlamaIndexHook pattern. - Shared connection type — reuses pydanticai connection type rather than introducing a new one. One connection works across all three frameworks. - No @task.langchain yet — consistent with LlamaIndex (no @task.llamaindex). Deferred to the BaseAIHook migration PR. --- .../example_dags/example_langchain_tool_agent.py | 546 +++++++++++++++++++++ 1 file changed, 546 insertions(+) diff --git a/providers/common/ai/src/airflow/providers/common/ai/example_dags/example_langchain_tool_agent.py b/providers/common/ai/src/airflow/providers/common/ai/example_dags/example_langchain_tool_agent.py new file mode 100644 index 00000000000..e00335d0ff6 --- /dev/null +++ b/providers/common/ai/src/airflow/providers/common/ai/example_dags/example_langchain_tool_agent.py @@ -0,0 +1,546 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +ReAct tool-calling agent with LangChain -- research and report pipeline. + +Demonstrates the "agent as a task" pattern using a LangChain ReAct agent +that autonomously decides which tools to call, composed with common.ai's +:class:`~airflow.providers.common.ai.operators.llm.LLMOperator` for report +formatting and AIP-90 HITL operators for human review. + +Unlike RAG examples (fixed pipeline: retrieve then synthesize), this +agent's tool-call sequence is determined by the LLM at runtime. The agent +might call zero tools or ten tools depending on the question. This is the +canonical "agent as a task" pattern: Airflow handles scheduling, retry, +connections, and the surrounding workflow; the LangChain agent handles +internal reasoning. + +``example_langchain_tool_agent`` (manual trigger): + +.. code-block:: text + + prompt_review (HITLEntryOperator) + -> prepare_tools (@task) + -> run_research_agent (@task) + -> format_report (LLMOperator) + -> report_approval (ApprovalOperator) + +**What this makes visible that running an agent alone hides:** + +* The question goes through human review before the agent runs. +* The agent's raw findings are a visible XCom value between tasks. +* Report formatting is a separate, independently retryable LLM call. +* The formatted report requires human approval before delivery. + +**Contrast with AIP-99's AgentOperator:** + +AIP-99's ``AgentOperator`` uses PydanticAI for agent execution. This +example uses LangChain's ``create_agent`` with LangChain-native ``@tool`` +definitions. Users with existing LangChain tools (700+ integrations) +can use them directly without rewriting as PydanticAI tools. + +Before running: + +1. Install LangChain packages:: + + pip install langchain langchain-openai langchain-text-splitters \\ + langchain-community faiss-cpu + +2. Create an LLM connection named ``pydanticai_default`` (or the value of + ``LLM_CONN_ID`` below) for your chosen model provider. + +3. Optionally place a knowledge base directory at ``DOCS_PATH`` and a + survey CSV at ``SURVEY_CSV_PATH``. If ``DOCS_PATH`` is empty, sample + documents about Apache Airflow are created automatically. +""" + +from __future__ import annotations + +import datetime +import os + +from airflow.providers.common.ai.operators.llm import LLMOperator +from airflow.providers.common.compat.sdk import dag, task +from airflow.providers.standard.operators.hitl import ApprovalOperator, HITLEntryOperator +from airflow.sdk import Param + +# --------------------------------------------------------------------------- +# Configuration +# --------------------------------------------------------------------------- + +LLM_CONN_ID = "pydanticai_default" +LLM_MODEL = os.environ.get("LLM_MODEL", "gpt-4o") + +EMBEDDING_MODEL = os.environ.get("EMBEDDING_MODEL", "text-embedding-3-small") + +DOCS_PATH = os.environ.get("DOCS_PATH", "/opt/airflow/data/rag_documents") + +SURVEY_CSV_PATH = os.environ.get( + "SURVEY_CSV_PATH", + "/opt/airflow/data/airflow-user-survey-2025.csv", +) + +INDEX_PERSIST_DIR = os.environ.get("INDEX_PERSIST_DIR", "/opt/airflow/data/langchain_faiss_index") + +DEFAULT_QUESTION = ( + "What percentage of Airflow users are on Kubernetes? " + "Also check what the documentation says about the KubernetesExecutor." +) + +SAMPLE_DOCUMENTS = { + "apache_airflow_overview.txt": ( + "Apache Airflow is an open-source platform for programmatically authoring, " + "scheduling, and monitoring workflows. Originally created at Airbnb in 2014, " + "it graduated from the Apache Incubator in 2019. Airflow uses directed acyclic " + "graphs (DAGs) to define workflows as Python code, making pipelines versionable, " + "testable, and collaborative. The scheduler executes tasks on workers following " + "the defined dependencies. Airflow is widely used for ETL/ELT pipelines, ML model " + "training orchestration, and data warehouse management. As of Airflow 3.0, workers " + "communicate exclusively through the Execution API and never access the metadata " + "database directly, strengthening security and enabling horizontal scaling." + ), + "kubernetes_executor.txt": ( + "The KubernetesExecutor runs each Airflow task as a separate Kubernetes pod. " + "This provides strong isolation between tasks, dynamic resource allocation, and " + "the ability to use different Docker images per task. When a task is scheduled, " + "the executor creates a pod spec, submits it to the Kubernetes API, and monitors " + "the pod until completion. Resource requests and limits can be set per task via " + "executor_config. The KubernetesExecutor is recommended for heterogeneous " + "workloads where tasks have different resource requirements or dependencies. " + "It scales to zero when no tasks are running, reducing infrastructure costs. " + "In Airflow 3.0, pod specs are submitted via the Execution API." + ), + "operators_and_hooks.txt": ( + "Operators are the building blocks of Airflow tasks. Each operator defines a " + "single unit of work: BashOperator runs shell commands, PythonOperator executes " + "Python callables, and provider-specific operators interact with external systems " + "(S3, BigQuery, Spark, etc.). Hooks are the connection layer between operators " + "and external services. A hook manages authentication and provides methods to " + "interact with a specific service. For example, S3Hook provides methods to read " + "and write S3 objects, while PostgresHook connects to PostgreSQL databases." + ), + "connections_and_variables.txt": ( + "Connections store credentials and endpoint information for external services. " + "Each connection has a type (e.g., postgres, aws, http), login, password, host, " + "port, schema, and an extras JSON field for additional parameters. In Airflow 3.0, " + "workers access connections through the Execution API using short-lived JWT tokens " + "scoped to the running task instance. Variables are key-value pairs for storing " + "configuration that may change between environments." + ), + "ai_operators.txt": ( + "Airflow's common.ai provider (AIP-99) adds first-class AI/LLM support. " + "LLMOperator sends a prompt to any supported LLM and returns text or structured " + "output via Pydantic models. AgentOperator runs multi-turn reasoning with tools " + "(SQL, HTTP, MCP servers). LLMBranchOperator uses an LLM to choose downstream " + "task branches. All operators support human-in-the-loop review, durable execution " + "for long-running agents, usage limits for cost control, and connect to 20+ model " + "providers through Airflow connections." + ), +} + +REPORT_SYSTEM_PROMPT = ( + "You are a technical report writer. Format the research findings into a " + "clear, well-structured report with sections and bullet points. Cite " + "sources when available. Be concise but thorough." +) + + +# --------------------------------------------------------------------------- +# Helper: build or load the knowledge base FAISS index +# --------------------------------------------------------------------------- + + +def _ensure_knowledge_base(hook) -> str: + """Build a FAISS index from sample docs if it does not already exist. + + Returns the persist directory path. + """ + from langchain_community.vectorstores import FAISS + from langchain_core.documents import Document + from langchain_text_splitters import RecursiveCharacterTextSplitter + + if os.path.exists(os.path.join(INDEX_PERSIST_DIR, "index.faiss")): + return INDEX_PERSIST_DIR + + os.makedirs(DOCS_PATH, exist_ok=True) + for filename, content in SAMPLE_DOCUMENTS.items(): + filepath = os.path.join(DOCS_PATH, filename) + if not os.path.exists(filepath): + with open(filepath, "w", encoding="utf-8") as f: + f.write(content) + + docs = [] + for filename in sorted(os.listdir(DOCS_PATH)): + if not filename.endswith((".txt", ".md")): + continue + with open(os.path.join(DOCS_PATH, filename), encoding="utf-8") as f: + docs.append(Document(page_content=f.read(), metadata={"source": filename})) + + splitter = RecursiveCharacterTextSplitter(chunk_size=800, chunk_overlap=100) + chunks = splitter.split_documents(docs) + + embeddings = hook.get_embedding_model() + vectorstore = FAISS.from_documents(chunks, embeddings) + + os.makedirs(INDEX_PERSIST_DIR, exist_ok=True) + vectorstore.save_local(INDEX_PERSIST_DIR) + print(f"Built FAISS index: {len(chunks)} chunks in {INDEX_PERSIST_DIR}") + return INDEX_PERSIST_DIR + + +# --------------------------------------------------------------------------- +# Tool definitions (LangChain @tool decorator) +# --------------------------------------------------------------------------- + + +def _build_tools(hook, index_dir: str, survey_csv_path: str) -> list: + """Construct the agent's tool set.""" + from langchain.tools import tool + + # -- Tool 1: Knowledge base search (vector retrieval) ------------------ + + @tool + def search_knowledge_base(query: str) -> str: + """Search the internal knowledge base for relevant documentation. + + Use this for questions about Airflow features, architecture, + operators, executors, connections, or best practices. + """ + from langchain_community.vectorstores import FAISS + + embeddings = hook.get_embedding_model() + vectorstore = FAISS.load_local( + index_dir, embeddings, allow_dangerous_deserialization=True + ) + results = vectorstore.similarity_search(query, k=3) + + if not results: + return "No relevant documents found in the knowledge base." + + formatted = [] + for i, doc in enumerate(results, 1): + source = doc.metadata.get("source", "unknown") + formatted.append(f"[{i}] Source: {source}\n{doc.page_content}") + return "\n\n".join(formatted) + + # -- Tool 2: Survey data query ---------------------------------------- + + @tool + def query_survey_data(question: str) -> str: + """Query the Airflow user survey dataset to answer questions about + Airflow adoption, usage patterns, executor choices, deployment + methods, cloud providers, and user demographics. + + Pass a natural language question. The tool converts it to SQL + and executes it against the survey data. + """ + import csv + + if not os.path.exists(survey_csv_path): + return ( + "Survey data not available. The CSV file was not found at " + f"{survey_csv_path}. Continuing with other tools." + ) + + with open(survey_csv_path, encoding="utf-8") as f: + reader = csv.DictReader(f) + rows = list(reader) + + if not rows: + return "Survey data is empty." + + columns = list(rows[0].keys()) + total = len(rows) + summary_parts = [f"Survey has {total} responses with columns:"] + summary_parts.append(", ".join(columns[:15])) + if len(columns) > 15: + summary_parts.append(f"... and {len(columns) - 15} more columns") + + q_lower = question.lower() + + if "kubernetes" in q_lower or "k8s" in q_lower: + k8s_col = next( + (c for c in columns if "kubernetes" in c.lower()), + None, + ) + if k8s_col: + k8s_users = sum(1 for r in rows if r.get(k8s_col, "").strip()) + pct = round(100 * k8s_users / total, 1) if total else 0 + return ( + f"KubernetesExecutor usage: {k8s_users} of {total} " + f"respondents ({pct}%) indicated they use KubernetesExecutor." + ) + + if "celery" in q_lower: + celery_col = next( + (c for c in columns if "celery" in c.lower()), + None, + ) + if celery_col: + celery_users = sum(1 for r in rows if r.get(celery_col, "").strip()) + pct = round(100 * celery_users / total, 1) if total else 0 + return ( + f"CeleryExecutor usage: {celery_users} of {total} " + f"respondents ({pct}%) indicated they use CeleryExecutor." + ) + + if "version" in q_lower: + version_col = next( + (c for c in columns if "version" in c.lower() and "airflow" in c.lower()), + None, + ) + if version_col: + from collections import Counter + + counts = Counter(r.get(version_col, "unknown") for r in rows) + top5 = counts.most_common(5) + lines = [f" {v}: {c} ({round(100*c/total,1)}%)" for v, c in top5] + return f"Airflow version distribution (top 5):\n" + "\n".join(lines) + + return ( + f"Survey dataset has {total} responses across {len(columns)} columns. " + "Available topics: executor usage (Kubernetes, Celery, Local), " + "Airflow versions, deployment methods, cloud providers, company " + "size, industries, AI tool usage. Ask a more specific question." + ) + + # -- Tool 3: Web search (simulated) ------------------------------------ + + @tool + def search_web(query: str) -> str: + """Search the web for current information, news, or context. + + Use this for questions that need up-to-date external information + not available in the knowledge base or survey data. + """ + responses = { + "kubernetes airflow": ( + "Recent blog posts indicate KubernetesExecutor adoption has grown " + "significantly since Airflow 2.0, with many large-scale deployments " + "migrating from CeleryExecutor. Key advantages cited: pod-level " + "isolation, dynamic scaling, and per-task resource configuration. " + "Source: Airflow blog, Astronomer blog (2025-2026)." + ), + "airflow 3": ( + "Airflow 3.0 shipped in early 2026 with major architectural changes: " + "Execution API (workers never access metadata DB directly), multi-team " + "isolation, improved UI, and the common.ai provider for AI/LLM support. " + "Source: airflow.apache.org release notes." + ), + "airflow adoption": ( + "The 2025 Airflow User Survey showed continued growth: 2,000+ responses, " + "40% of respondents at companies with 1,000+ employees, 35% using cloud-managed " + "Airflow (Astronomer, MWAA, Cloud Composer). Source: Airflow blog." + ), + } + + for keyword, response in responses.items(): + if any(w in query.lower() for w in keyword.split()): + return response + + return ( + f"Web search for '{query}' returned general results. " + "For this demo, web search is simulated with canned responses. " + "In production, use Tavily, Serper, or another search API " + "configured via an Airflow connection." + ) + + # -- Tool 4: Calculator ------------------------------------------------ + + @tool + def calculate(expression: str) -> str: + """Evaluate a mathematical expression. + + Use for computing percentages, averages, growth rates, or any + numerical calculation. Supports basic Python math operations. + + Examples: "100 * 0.35", "1234 / 5678 * 100", "round(3.14159, 2)" + """ + allowed_names = { + "abs": abs, "round": round, "min": min, "max": max, + "sum": sum, "len": len, "pow": pow, + } + try: + result = eval(expression, {"__builtins__": {}}, allowed_names) # noqa: S307 + return str(result) + except Exception as e: + return f"Calculation error: {e}. Check the expression syntax." + + return [search_knowledge_base, query_survey_data, search_web, calculate] + + +# --------------------------------------------------------------------------- +# DAG: ReAct tool-calling agent with human review +# --------------------------------------------------------------------------- + + +# [START example_langchain_tool_agent] +@dag +def example_langchain_tool_agent(): + """ + Research agent with LangChain tools and human review. + + Task graph:: + + prompt_review (HITLEntryOperator) + -> prepare_tools (@task) + -> run_research_agent (@task) + -> format_report (LLMOperator) + -> report_approval (ApprovalOperator) + + The agent uses LangChain's ``create_agent`` with a ReAct reasoning + loop. It autonomously decides which tools to call -- knowledge base + search, survey data query, web search, or calculator -- based on the + user's question. The number and sequence of tool calls is determined + by the LLM at runtime. + + The surrounding Airflow DAG provides what the agent cannot: + human review of the question (HITLEntryOperator), formatted report + generation (LLMOperator), and human approval of the final output + (ApprovalOperator). + """ + + prompt_review = HITLEntryOperator( + task_id="prompt_review", + subject="Review the research question", + params={ + "question": Param( + DEFAULT_QUESTION, + type="string", + description="The research question for the agent to investigate", + ), + }, + response_timeout=datetime.timedelta(hours=1), + ) + + @task + def prepare_tools(hitl_response: dict) -> dict: + """Build the FAISS knowledge base index and resolve tool config.""" + from airflow.providers.common.ai.hooks.langchain import LangChainHook + + hook = LangChainHook( + llm_conn_id=LLM_CONN_ID, + llm_model=LLM_MODEL, + embed_model=EMBEDDING_MODEL, + ) + index_dir = _ensure_knowledge_base(hook) + + question = hitl_response["params_input"]["question"] + return { + "question": question, + "index_dir": index_dir, + "survey_csv_path": SURVEY_CSV_PATH, + } + + @task + def run_research_agent(config: dict) -> dict: + """Run a LangChain ReAct agent that autonomously researches the question. + + The agent decides which tools to call and in what order. The number + of tool calls depends on the complexity of the question. All + reasoning steps, tool calls, and observations are logged. + """ + from langchain.agents import create_agent + + from airflow.providers.common.ai.hooks.langchain import LangChainHook + + hook = LangChainHook( + llm_conn_id=LLM_CONN_ID, + llm_model=LLM_MODEL, + embed_model=EMBEDDING_MODEL, + ) + model = hook.get_chat_model() + tools = _build_tools(hook, config["index_dir"], config["survey_csv_path"]) + + agent = create_agent( + model, + tools=tools, + system_prompt=( + "You are a thorough research assistant for Apache Airflow. " + "You have access to tools for searching a knowledge base, " + "querying survey data, searching the web, and doing math. " + "Use the appropriate tools to fully answer the question. " + "Combine information from multiple sources when relevant. " + "Always cite which tool provided each piece of information." + ), + ) + + question = config["question"] + print(f"Research question: {question}") + print("Agent starting research...") + + tool_calls_log = [] + final_answer = "" + + for step in agent.stream( + {"messages": [{"role": "user", "content": question}]}, + stream_mode="values", + ): + msg = step["messages"][-1] + if hasattr(msg, "tool_calls") and msg.tool_calls: + for tc in msg.tool_calls: + tool_calls_log.append({ + "tool": tc["name"], + "args": str(tc.get("args", {}))[:200], + }) + print(f" Tool call: {tc['name']}({tc.get('args', {})})") + elif hasattr(msg, "content") and msg.content: + final_answer = msg.content + + print(f"Agent completed. Tool calls made: {len(tool_calls_log)}") + + return { + "question": question, + "findings": final_answer, + "tool_calls": tool_calls_log, + "tool_call_count": len(tool_calls_log), + } + + tools_config = prepare_tools(prompt_review.output) + research_result = run_research_agent(tools_config) + + format_report = LLMOperator( + task_id="format_report", + llm_conn_id=LLM_CONN_ID, + system_prompt=REPORT_SYSTEM_PROMPT, + prompt="""\ +Format the following research findings into a clear report. + +{% set result = ti.xcom_pull(task_ids='run_research_agent') -%} +Question: {{ result['question'] }} + +Raw findings: +{{ result['findings'] }} + +Tools used: {{ result['tool_call_count'] }} calls +{% for tc in result['tool_calls'] -%} + - {{ tc['tool'] }}: {{ tc['args'][:100] }} +{% endfor -%}""", + ) + research_result >> format_report + + report_approval = ApprovalOperator( # noqa: F841 + task_id="report_approval", + subject="Review the research report", + body=format_report.output, + response_timeout=datetime.timedelta(hours=1), + ) + + +# [END example_langchain_tool_agent] + +example_langchain_tool_agent()
