This is an automated email from the ASF dual-hosted git repository.
kaxilnaik pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 7f6457f1f47 Add Azure OpenAI & Bedrock support in ``common.ai`` provider (#62816)
7f6457f1f47 is described below
commit 7f6457f1f4796b1bcaa0cd85388b6fa896ca4dbb
Author: Gökhan Çetin <[email protected]>
AuthorDate: Wed Mar 11 22:58:10 2026 +0300
Add Azure OpenAI & Bedrock support in ``common.ai`` provider (#62816)
---
.../ai/docs/operators/llm_schema_compare.rst | 2 +-
providers/common/ai/docs/toolsets.rst | 4 +-
providers/common/ai/provider.yaml | 170 +++++++++
.../providers/common/ai/decorators/agent.py | 2 +-
.../common/ai/example_dags/example_agent.py | 10 +-
.../ai/example_dags/example_llm_schema_compare.py | 10 +-
.../common/ai/example_dags/example_mcp.py | 6 +-
.../providers/common/ai/get_provider_info.py | 128 +++++++
.../providers/common/ai/hooks/pydantic_ai.py | 373 +++++++++++++++---
.../airflow/providers/common/ai/operators/agent.py | 5 +-
.../airflow/providers/common/ai/operators/llm.py | 14 +-
.../tests/unit/common/ai/decorators/test_agent.py | 10 +-
.../ai/tests/unit/common/ai/decorators/test_llm.py | 4 +-
.../unit/common/ai/decorators/test_llm_branch.py | 4 +-
.../ai/decorators/test_llm_schema_compare.py | 8 +-
.../unit/common/ai/decorators/test_llm_sql.py | 4 +-
.../tests/unit/common/ai/hooks/test_pydantic_ai.py | 418 +++++++++++++++++++--
.../tests/unit/common/ai/operators/test_agent.py | 36 +-
.../ai/tests/unit/common/ai/operators/test_llm.py | 24 +-
.../unit/common/ai/operators/test_llm_branch.py | 12 +-
.../tests/unit/common/ai/operators/test_llm_sql.py | 36 +-
21 files changed, 1109 insertions(+), 171 deletions(-)
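For context, a minimal sketch of how the new connection types are meant to be wired up once this lands. The connection ID ``bedrock_llm`` and the DAG wiring are illustrative only, not part of the commit; the field names follow the provider.yaml and hook additions below:

    import json
    import os

    from airflow.providers.common.ai.operators.agent import AgentOperator

    # Hypothetical Bedrock connection supplied as a JSON env-var connection.
    # Omitting the IAM key fields falls back to the instance-role /
    # environment credential chain, per the new Bedrock hook's docstring.
    os.environ["AIRFLOW_CONN_BEDROCK_LLM"] = json.dumps(
        {
            "conn_type": "pydanticai-bedrock",
            "extra": {
                "model": "bedrock:us.anthropic.claude-opus-4-5",
                "region_name": "us-east-1",
            },
        }
    )

    analyst = AgentOperator(
        task_id="analyst",
        prompt="What are the top 5 customers by order count?",
        llm_conn_id="bedrock_llm",
        system_prompt="You are a data analyst.",
    )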
diff --git a/providers/common/ai/docs/operators/llm_schema_compare.rst b/providers/common/ai/docs/operators/llm_schema_compare.rst
index f6767e04dbe..ebdf3f48994 100644
--- a/providers/common/ai/docs/operators/llm_schema_compare.rst
+++ b/providers/common/ai/docs/operators/llm_schema_compare.rst
@@ -91,7 +91,7 @@ on top, concatenate them:
LLMSchemaCompareOperator(
task_id="compare_with_custom_rules",
prompt="Compare schemas and flag breaking changes",
- llm_conn_id="pydantic_ai_default",
+ llm_conn_id="pydanticai_default",
db_conn_ids=["postgres_source", "snowflake_target"],
table_names=["customers"],
system_prompt=DEFAULT_SYSTEM_PROMPT
diff --git a/providers/common/ai/docs/toolsets.rst b/providers/common/ai/docs/toolsets.rst
index de708ac9c4c..8d295ac361c 100644
--- a/providers/common/ai/docs/toolsets.rst
+++ b/providers/common/ai/docs/toolsets.rst
@@ -256,7 +256,7 @@ Using Multiple MCP Servers
AgentOperator(
task_id="multi_mcp",
prompt="Get the weather in London and run a calculation",
- llm_conn_id="pydantic_ai_default",
+ llm_conn_id="pydanticai_default",
toolsets=[
MCPToolset(mcp_conn_id="weather_mcp", tool_prefix="weather"),
MCPToolset(mcp_conn_id="code_runner_mcp", tool_prefix="code"),
@@ -276,7 +276,7 @@ server instances directly — no Airflow connection needed:
AgentOperator(
task_id="direct_mcp",
prompt="What tools are available?",
- llm_conn_id="pydantic_ai_default",
+ llm_conn_id="pydanticai_default",
toolsets=[
MCPServerStreamableHTTP("http://localhost:3001/mcp"),
MCPServerStdio("uvx", args=["mcp-run-python"]),
diff --git a/providers/common/ai/provider.yaml b/providers/common/ai/provider.yaml
index 4a406a66470..43a98af32a3 100644
--- a/providers/common/ai/provider.yaml
+++ b/providers/common/ai/provider.yaml
@@ -79,6 +79,176 @@ connection-types:
type:
- string
- 'null'
+ - hook-class-name: airflow.providers.common.ai.hooks.pydantic_ai.PydanticAIAzureHook
+ connection-type: pydanticai-azure
+ ui-field-behaviour:
+ hidden-fields:
+ - schema
+ - port
+ - login
+ relabeling:
+ password: API Key
+ host: Azure Endpoint
+ placeholders:
+ host: "https://<resource>.openai.azure.com"
+ conn-fields:
+ model:
+ label: Model
+ description: "Azure model identifier (e.g. azure:gpt-4o)"
+ schema:
+ type:
+ - string
+ - 'null'
+ api_version:
+ label: API Version
+ description: "Azure OpenAI API version (e.g. 2024-07-01-preview).
Falls back to OPENAI_API_VERSION."
+ schema:
+ type:
+ - string
+ - 'null'
+ - hook-class-name: airflow.providers.common.ai.hooks.pydantic_ai.PydanticAIBedrockHook
+ connection-type: pydanticai-bedrock
+ ui-field-behaviour:
+ hidden-fields:
+ - schema
+ - port
+ - login
+ - host
+ - password
+ relabeling: {}
+ placeholders: {}
+ conn-fields:
+ model:
+ label: Model
+ description: "Bedrock model identifier (e.g.
bedrock:us.anthropic.claude-opus-4-5)"
+ schema:
+ type:
+ - string
+ - 'null'
+ region_name:
+ label: AWS Region
+ description: "AWS region (e.g. us-east-1). Falls back to
AWS_DEFAULT_REGION env var."
+ schema:
+ type:
+ - string
+ - 'null'
+ aws_access_key_id:
+ label: AWS Access Key ID
+ description: "IAM access key. Leave empty to use instance role /
environment credential chain."
+ schema:
+ type:
+ - string
+ - 'null'
+ aws_secret_access_key:
+ label: AWS Secret Access Key
+ description: "IAM secret key."
+ schema:
+ type:
+ - string
+ - 'null'
+ aws_session_token:
+ label: AWS Session Token
+ description: "Temporary session token (optional)."
+ schema:
+ type:
+ - string
+ - 'null'
+ profile_name:
+ label: AWS Profile Name
+ description: "Named AWS credentials profile (optional)."
+ schema:
+ type:
+ - string
+ - 'null'
+ api_key:
+ label: Bearer Token
+ description: "AWS bearer token (alt. to IAM key/secret). Falls back to
AWS_BEARER_TOKEN_BEDROCK."
+ schema:
+ type:
+ - string
+ - 'null'
+ base_url:
+ label: Custom Endpoint URL
+ description: "Override the Bedrock runtime endpoint URL (optional)."
+ schema:
+ type:
+ - string
+ - 'null'
+ aws_read_timeout:
+ label: Read Timeout (s)
+ description: "boto3 read timeout in seconds (float, optional)."
+ schema:
+ type:
+ - number
+ - 'null'
+ aws_connect_timeout:
+ label: Connect Timeout (s)
+ description: "boto3 connect timeout in seconds (float, optional)."
+ schema:
+ type:
+ - number
+ - 'null'
+ - hook-class-name: airflow.providers.common.ai.hooks.pydantic_ai.PydanticAIVertexHook
+ connection-type: pydanticai-vertex
+ ui-field-behaviour:
+ hidden-fields:
+ - schema
+ - port
+ - login
+ - host
+ - password
+ relabeling: {}
+ placeholders: {}
+ conn-fields:
+ model:
+ label: Model
+ description: "Google model identifier (e.g.
google-vertex:gemini-2.0-flash)"
+ schema:
+ type:
+ - string
+ - 'null'
+ project:
+ label: GCP Project
+ description: "Google Cloud project ID. Falls back to
GOOGLE_CLOUD_PROJECT env var."
+ schema:
+ type:
+ - string
+ - 'null'
+ location:
+ label: Location / Region
+ description: "Vertex AI region (e.g. us-central1). Falls back to
GOOGLE_CLOUD_LOCATION env var."
+ schema:
+ type:
+ - string
+ - 'null'
+ vertexai:
+ label: Force Vertex AI Mode
+ description: "Force Vertex AI mode. Auto-detected when
project/location/credentials are set."
+ schema:
+ type:
+ - boolean
+ - 'null'
+ api_key:
+ label: API Key
+ description: "Google API key for Gen Language API or Vertex AI. Falls
back to GOOGLE_API_KEY."
+ schema:
+ type:
+ - string
+ - 'null'
+ service_account_info:
+ label: Service Account Info
+ description: "Service account key as inline dict (JSON with type,
project_id, private_key, etc.)."
+ schema:
+ type:
+ - object
+ - 'null'
+ base_url:
+ label: Custom Endpoint URL
+ description: "Override the Google API base URL (optional)."
+ schema:
+ type:
+ - string
+ - 'null'
- hook-class-name: airflow.providers.common.ai.hooks.mcp.MCPHook
connection-type: mcp
ui-field-behaviour:
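As a usage note, a hypothetical ``pydanticai-azure`` connection created programmatically would populate the fields declared above (values are placeholders; the extra field is JSON-encoded the same way the new unit tests do):

    import json

    from airflow.models.connection import Connection

    azure_conn = Connection(
        conn_id="pydanticai_azure_default",
        conn_type="pydanticai-azure",
        host="https://my-resource.openai.azure.com",  # relabeled "Azure Endpoint"
        password="my-azure-api-key",                  # relabeled "API Key"
        extra=json.dumps({"model": "azure:gpt-4o", "api_version": "2024-07-01-preview"}),
    )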
diff --git a/providers/common/ai/src/airflow/providers/common/ai/decorators/agent.py b/providers/common/ai/src/airflow/providers/common/ai/decorators/agent.py
index 40c55f630c0..379c8cb65fa 100644
--- a/providers/common/ai/src/airflow/providers/common/ai/decorators/agent.py
+++ b/providers/common/ai/src/airflow/providers/common/ai/decorators/agent.py
@@ -107,7 +107,7 @@ def agent_task(
Usage::
@task.agent(
- llm_conn_id="pydantic_ai_default",
+ llm_conn_id="pydanticai_default",
system_prompt="You are a data analyst.",
toolsets=[SQLToolset(db_conn_id="postgres_default")],
)
diff --git a/providers/common/ai/src/airflow/providers/common/ai/example_dags/example_agent.py b/providers/common/ai/src/airflow/providers/common/ai/example_dags/example_agent.py
index 3da0ef0ffaa..d36461e5a0c 100644
--- a/providers/common/ai/src/airflow/providers/common/ai/example_dags/example_agent.py
+++ b/providers/common/ai/src/airflow/providers/common/ai/example_dags/example_agent.py
@@ -36,7 +36,7 @@ def example_agent_operator_sql():
AgentOperator(
task_id="analyst",
prompt="What are the top 5 customers by order count?",
- llm_conn_id="pydantic_ai_default",
+ llm_conn_id="pydanticai_default",
system_prompt=(
"You are a SQL analyst. Use the available tools to explore "
"the schema and answer the question with data."
@@ -71,7 +71,7 @@ def example_agent_operator_hook():
AgentOperator(
task_id="api_explorer",
prompt="What endpoints are available and what does /status return?",
- llm_conn_id="pydantic_ai_default",
+ llm_conn_id="pydanticai_default",
system_prompt="You are an API explorer. Use the tools to discover and
call endpoints.",
toolsets=[
HookToolset(
@@ -97,7 +97,7 @@ example_agent_operator_hook()
@dag
def example_agent_decorator():
@task.agent(
- llm_conn_id="pydantic_ai_default",
+ llm_conn_id="pydanticai_default",
system_prompt="You are a data analyst. Use tools to answer questions.",
toolsets=[
SQLToolset(
@@ -133,7 +133,7 @@ def example_agent_structured_output():
row_count: int
@task.agent(
- llm_conn_id="pydantic_ai_default",
+ llm_conn_id="pydanticai_default",
system_prompt="You are a data analyst. Return structured results.",
output_type=Analysis,
toolsets=[SQLToolset(db_conn_id="postgres_default")],
@@ -158,7 +158,7 @@ example_agent_structured_output()
@dag
def example_agent_chain():
@task.agent(
- llm_conn_id="pydantic_ai_default",
+ llm_conn_id="pydanticai_default",
system_prompt="You are a SQL analyst.",
toolsets=[SQLToolset(db_conn_id="postgres_default",
allowed_tables=["orders"])],
)
diff --git a/providers/common/ai/src/airflow/providers/common/ai/example_dags/example_llm_schema_compare.py b/providers/common/ai/src/airflow/providers/common/ai/example_dags/example_llm_schema_compare.py
index 0e6d306f7bc..88d3891ead1 100644
--- a/providers/common/ai/src/airflow/providers/common/ai/example_dags/example_llm_schema_compare.py
+++ b/providers/common/ai/src/airflow/providers/common/ai/example_dags/example_llm_schema_compare.py
@@ -29,7 +29,7 @@ def example_llm_schema_compare_basic():
LLMSchemaCompareOperator(
task_id="detect_schema_drift",
prompt="Identify schema mismatches that would break data loading
between systems",
- llm_conn_id="pydantic_ai_default",
+ llm_conn_id="pydanticai_default",
db_conn_ids=["postgres_default", "snowflake_default"],
table_names=["customers"],
)
@@ -49,7 +49,7 @@ def example_llm_schema_compare_full_context():
"Compare schemas and generate a migration plan. "
"Flag any differences that would break nightly ETL loads."
),
- llm_conn_id="pydantic_ai_default",
+ llm_conn_id="pydanticai_default",
db_conn_ids=["postgres_source", "snowflake_target"],
table_names=["customers", "orders"],
context_strategy="full",
@@ -74,7 +74,7 @@ def example_llm_schema_compare_with_object_storage():
LLMSchemaCompareOperator(
task_id="compare_s3_vs_db",
prompt="Compare S3 Parquet schema against the Postgres table and flag
breaking changes",
- llm_conn_id="pydantic_ai_default",
+ llm_conn_id="pydanticai_default",
db_conn_ids=["postgres_default"],
table_names=["customers"],
data_sources=[s3_source],
@@ -90,7 +90,7 @@ example_llm_schema_compare_with_object_storage()
@dag
def example_llm_schema_compare_decorator():
@task.llm_schema_compare(
- llm_conn_id="pydantic_ai_default",
+ llm_conn_id="pydanticai_default",
db_conn_ids=["postgres_source", "snowflake_target"],
table_names=["customers"],
)
@@ -109,7 +109,7 @@ example_llm_schema_compare_decorator()
@dag
def example_llm_schema_compare_conditional():
@task.llm_schema_compare(
- llm_conn_id="pydantic_ai_default",
+ llm_conn_id="pydanticai_default",
db_conn_ids=["postgres_source", "snowflake_target"],
table_names=["customers"],
context_strategy="full",
diff --git a/providers/common/ai/src/airflow/providers/common/ai/example_dags/example_mcp.py b/providers/common/ai/src/airflow/providers/common/ai/example_dags/example_mcp.py
index adc3dc85b94..320fe8428ce 100644
--- a/providers/common/ai/src/airflow/providers/common/ai/example_dags/example_mcp.py
+++ b/providers/common/ai/src/airflow/providers/common/ai/example_dags/example_mcp.py
@@ -34,7 +34,7 @@ def example_mcp_toolset():
AgentOperator(
task_id="mcp_agent",
prompt="What tools are available? Run the hello tool.",
- llm_conn_id="pydantic_ai_default",
+ llm_conn_id="pydanticai_default",
system_prompt="You are a helpful assistant with access to MCP tools.",
toolsets=[
MCPToolset(mcp_conn_id="my_mcp_server"),
@@ -59,7 +59,7 @@ def example_mcp_multiple_servers():
AgentOperator(
task_id="multi_mcp_agent",
prompt="Get the weather in London and run a Python calculation: 2**10",
- llm_conn_id="pydantic_ai_default",
+ llm_conn_id="pydanticai_default",
system_prompt="You have access to weather and code execution tools.",
toolsets=[
MCPToolset(mcp_conn_id="weather_mcp", tool_prefix="weather"),
@@ -84,7 +84,7 @@ example_mcp_multiple_servers()
# AgentOperator(
# task_id="direct_mcp",
# prompt="What tools are available?",
-# llm_conn_id="pydantic_ai_default",
+# llm_conn_id="pydanticai_default",
# toolsets=[
# MCPServerStreamableHTTP("http://localhost:3001/mcp"),
# MCPServerStdio("uvx", args=["mcp-run-python"]),
diff --git a/providers/common/ai/src/airflow/providers/common/ai/get_provider_info.py b/providers/common/ai/src/airflow/providers/common/ai/get_provider_info.py
index 0ace3488ef1..90ae393d64d 100644
--- a/providers/common/ai/src/airflow/providers/common/ai/get_provider_info.py
+++ b/providers/common/ai/src/airflow/providers/common/ai/get_provider_info.py
@@ -80,6 +80,134 @@ def get_provider_info():
}
},
},
+ {
+ "hook-class-name":
"airflow.providers.common.ai.hooks.pydantic_ai.PydanticAIAzureHook",
+ "connection-type": "pydanticai-azure",
+ "ui-field-behaviour": {
+ "hidden-fields": ["schema", "port", "login"],
+ "relabeling": {"password": "API Key", "host": "Azure
Endpoint"},
+ "placeholders": {"host":
"https://<resource>.openai.azure.com"},
+ },
+ "conn-fields": {
+ "model": {
+ "label": "Model",
+ "description": "Azure model identifier (e.g.
azure:gpt-4o)",
+ "schema": {"type": ["string", "null"]},
+ },
+ "api_version": {
+ "label": "API Version",
+ "description": "Azure OpenAI API version (e.g.
2024-07-01-preview). Falls back to OPENAI_API_VERSION.",
+ "schema": {"type": ["string", "null"]},
+ },
+ },
+ },
+ {
+ "hook-class-name":
"airflow.providers.common.ai.hooks.pydantic_ai.PydanticAIBedrockHook",
+ "connection-type": "pydanticai-bedrock",
+ "ui-field-behaviour": {
+ "hidden-fields": ["schema", "port", "login", "host",
"password"],
+ "relabeling": {},
+ "placeholders": {},
+ },
+ "conn-fields": {
+ "model": {
+ "label": "Model",
+ "description": "Bedrock model identifier (e.g.
bedrock:us.anthropic.claude-opus-4-5)",
+ "schema": {"type": ["string", "null"]},
+ },
+ "region_name": {
+ "label": "AWS Region",
+ "description": "AWS region (e.g. us-east-1). Falls
back to AWS_DEFAULT_REGION env var.",
+ "schema": {"type": ["string", "null"]},
+ },
+ "aws_access_key_id": {
+ "label": "AWS Access Key ID",
+ "description": "IAM access key. Leave empty to use
instance role / environment credential chain.",
+ "schema": {"type": ["string", "null"]},
+ },
+ "aws_secret_access_key": {
+ "label": "AWS Secret Access Key",
+ "description": "IAM secret key.",
+ "schema": {"type": ["string", "null"]},
+ },
+ "aws_session_token": {
+ "label": "AWS Session Token",
+ "description": "Temporary session token (optional).",
+ "schema": {"type": ["string", "null"]},
+ },
+ "profile_name": {
+ "label": "AWS Profile Name",
+ "description": "Named AWS credentials profile
(optional).",
+ "schema": {"type": ["string", "null"]},
+ },
+ "api_key": {
+ "label": "Bearer Token",
+ "description": "AWS bearer token (alt. to IAM
key/secret). Falls back to AWS_BEARER_TOKEN_BEDROCK.",
+ "schema": {"type": ["string", "null"]},
+ },
+ "base_url": {
+ "label": "Custom Endpoint URL",
+ "description": "Override the Bedrock runtime endpoint
URL (optional).",
+ "schema": {"type": ["string", "null"]},
+ },
+ "aws_read_timeout": {
+ "label": "Read Timeout (s)",
+ "description": "boto3 read timeout in seconds (float,
optional).",
+ "schema": {"type": ["number", "null"]},
+ },
+ "aws_connect_timeout": {
+ "label": "Connect Timeout (s)",
+ "description": "boto3 connect timeout in seconds
(float, optional).",
+ "schema": {"type": ["number", "null"]},
+ },
+ },
+ },
+ {
+ "hook-class-name":
"airflow.providers.common.ai.hooks.pydantic_ai.PydanticAIVertexHook",
+ "connection-type": "pydanticai-vertex",
+ "ui-field-behaviour": {
+ "hidden-fields": ["schema", "port", "login", "host",
"password"],
+ "relabeling": {},
+ "placeholders": {},
+ },
+ "conn-fields": {
+ "model": {
+ "label": "Model",
+ "description": "Google model identifier (e.g.
google-vertex:gemini-2.0-flash)",
+ "schema": {"type": ["string", "null"]},
+ },
+ "project": {
+ "label": "GCP Project",
+ "description": "Google Cloud project ID. Falls back to
GOOGLE_CLOUD_PROJECT env var.",
+ "schema": {"type": ["string", "null"]},
+ },
+ "location": {
+ "label": "Location / Region",
+ "description": "Vertex AI region (e.g. us-central1).
Falls back to GOOGLE_CLOUD_LOCATION env var.",
+ "schema": {"type": ["string", "null"]},
+ },
+ "vertexai": {
+ "label": "Force Vertex AI Mode",
+ "description": "Force Vertex AI mode. Auto-detected
when project/location/credentials are set.",
+ "schema": {"type": ["boolean", "null"]},
+ },
+ "api_key": {
+ "label": "API Key",
+ "description": "Google API key for Gen Language API or
Vertex AI. Falls back to GOOGLE_API_KEY.",
+ "schema": {"type": ["string", "null"]},
+ },
+ "service_account_info": {
+ "label": "Service Account Info",
+ "description": "Service account key as inline dict
(JSON with type, project_id, private_key, etc.).",
+ "schema": {"type": ["object", "null"]},
+ },
+ "base_url": {
+ "label": "Custom Endpoint URL",
+ "description": "Override the Google API base URL
(optional).",
+ "schema": {"type": ["string", "null"]},
+ },
+ },
+ },
{
"hook-class-name":
"airflow.providers.common.ai.hooks.mcp.MCPHook",
"connection-type": "mcp",
diff --git a/providers/common/ai/src/airflow/providers/common/ai/hooks/pydantic_ai.py b/providers/common/ai/src/airflow/providers/common/ai/hooks/pydantic_ai.py
index b05f0421e9a..eb4db2d1ec5 100644
--- a/providers/common/ai/src/airflow/providers/common/ai/hooks/pydantic_ai.py
+++ b/providers/common/ai/src/airflow/providers/common/ai/hooks/pydantic_ai.py
@@ -19,33 +19,32 @@ from __future__ import annotations
from typing import TYPE_CHECKING, Any, TypeVar, overload
from pydantic_ai import Agent
-from pydantic_ai.models import Model, infer_model
-from pydantic_ai.providers import Provider, infer_provider, infer_provider_class
+from pydantic_ai.models import infer_model
+from pydantic_ai.providers import infer_provider, infer_provider_class
from airflow.providers.common.compat.sdk import BaseHook
OutputT = TypeVar("OutputT")
if TYPE_CHECKING:
- from pydantic_ai.models import KnownModelName
+ from pydantic_ai.models import KnownModelName, Model
class PydanticAIHook(BaseHook):
"""
Hook for LLM access via pydantic-ai.
- Manages connection credentials and model creation. Uses pydantic-ai's
- model inference to support any provider (OpenAI, Anthropic, Google,
- Bedrock, Ollama, vLLM, etc.).
+ Covers providers that use a standard ``api_key`` + optional ``base_url``
+ (OpenAI, Anthropic, Groq, Mistral, DeepSeek, Ollama, vLLM, …).
- Connection fields:
- - **Model** (conn-field): Model in ``provider:model`` format (e.g. ``"anthropic:claude-sonnet-4-20250514"``)
- - **password**: API key (OpenAI, Anthropic, Groq, Mistral, etc.)
- - **host**: Base URL (optional — for custom endpoints like Ollama, vLLM, Azure)
+ For cloud providers with non-standard auth use the dedicated subclasses:
+ :class:`PydanticAIAzureHook`, :class:`PydanticAIBedrockHook`,
+ :class:`PydanticAIVertexHook`.
- Cloud providers (Bedrock, Vertex) that use native auth chains should leave
- password empty and configure environment-based auth (``AWS_PROFILE``,
- ``GOOGLE_APPLICATION_CREDENTIALS``).
+ Connection fields:
+ - **password**: API key
+ - **host**: Base URL (optional, e.g. ``https://api.openai.com/v1``)
+ - **extra** JSON: ``{"model": "openai:gpt-5.3"}``
:param llm_conn_id: Airflow connection ID for the LLM provider.
:param model_id: Model identifier in ``provider:model`` format (e.g. ``"openai:gpt-5.3"``).
@@ -59,12 +58,16 @@ class PydanticAIHook(BaseHook):
def __init__(
self,
- llm_conn_id: str = default_conn_name,
+ llm_conn_id: str | None = None,
model_id: str | None = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
- self.llm_conn_id = llm_conn_id
+ # Resolve at runtime so each subclass uses its own default_conn_name.
+ # A bare `llm_conn_id: str = default_conn_name` would bind the *base*
+ # class value for all subclasses because Python evaluates default
+ # argument values at class-definition time.
+ self.llm_conn_id = llm_conn_id if llm_conn_id is not None else self.default_conn_name
self.model_id = model_id
self._model: Model | None = None
@@ -75,62 +78,97 @@ class PydanticAIHook(BaseHook):
"hidden_fields": ["schema", "port", "login"],
"relabeling": {"password": "API Key"},
"placeholders": {
- "host": "https://api.openai.com/v1 (optional, for custom
endpoints)",
+ "host": "https://api.openai.com/v1 (optional, for custom
endpoints / Ollama)",
+ "extra": '{"model": "openai:gpt-5.3"}',
},
}
+ # ------------------------------------------------------------------
+ # Core connection / agent API
+ # ------------------------------------------------------------------
+
+ def _get_provider_kwargs(
+ self,
+ api_key: str | None,
+ base_url: str | None,
+ extra: dict[str, Any],
+ ) -> dict[str, Any]:
+ """
+ Return the kwargs to pass to the provider constructor.
+
+ Subclasses override this method to map their connection fields to the
+ parameters expected by their specific provider class. The base
+ implementation handles the common ``api_key`` / ``base_url`` pattern
+ used by OpenAI, Anthropic, Groq, Mistral, Ollama, and most other
+ providers.
+
+ :param api_key: Value of ``conn.password``.
+ :param base_url: Value of ``conn.host``.
+ :param extra: Deserialized ``conn.extra`` JSON.
+ :return: Kwargs forwarded to ``provider_cls(**kwargs)``. Empty dict
+ signals that no explicit credentials are available and the hook
+ should fall back to environment-variable–based auth.
+ """
+ kwargs: dict[str, Any] = {}
+ if api_key:
+ kwargs["api_key"] = api_key
+ if base_url:
+ kwargs["base_url"] = base_url
+ return kwargs
+
def get_conn(self) -> Model:
"""
- Return a configured pydantic-ai Model.
+ Return a configured pydantic-ai ``Model``.
- Reads API key from connection password, base_url from connection host,
- and model from (in priority order):
+ Resolution order:
- 1. ``model_id`` parameter on the hook
- 2. ``extra["model"]`` on the connection (set by the "Model" conn-field
in the UI)
+ 1. **Explicit credentials** — when :meth:`_get_provider_kwargs` returns
+ a non-empty dict the provider class is instantiated with those kwargs
+ and wrapped in a ``provider_factory``.
+ 2. **Default resolution** — delegates to pydantic-ai ``infer_model``
+ which reads standard env vars (``OPENAI_API_KEY``, ``AWS_PROFILE``, …).
- The result is cached for the lifetime of this hook instance.
+ The resolved model is cached for the lifetime of this hook instance.
"""
if self._model is not None:
return self._model
conn = self.get_connection(self.llm_conn_id)
- model_name: str | KnownModelName = self.model_id or conn.extra_dejson.get("model", "")
+
+ extra: dict[str, Any] = conn.extra_dejson
+ model_name: str | KnownModelName = self.model_id or extra.get("model",
"")
if not model_name:
raise ValueError(
"No model specified. Set model_id on the hook or the Model
field on the connection."
)
- api_key = conn.password
- base_url = conn.host or None
- if not api_key and not base_url:
- # No credentials to inject — use default provider resolution
- # (picks up env vars like OPENAI_API_KEY, AWS_PROFILE, etc.)
- self._model = infer_model(model_name)
+ api_key: str | None = conn.password or None
+ base_url: str | None = conn.host or None
+
+ provider_kwargs = self._get_provider_kwargs(api_key, base_url, extra)
+ if provider_kwargs:
+ _kwargs = provider_kwargs # capture for closure
+ self.log.info(
+ "Using explicit credentials for provider with model '%s': %s",
+ model_name,
+ list(provider_kwargs),
+ )
+
+ def _provider_factory(pname: str) -> Any:
+ try:
+ return infer_provider_class(pname)(**_kwargs)
+ except TypeError:
+ self.log.warning(
+ "Provider '%s' rejected kwargs %s; falling back to
env-var auth",
+ pname,
+ list(_kwargs),
+ )
+ return infer_provider(pname)
+
+ self._model = infer_model(model_name, provider_factory=_provider_factory)
return self._model
- def _provider_factory(provider_name: str) -> Provider[Any]:
- """
- Create a provider with credentials from the Airflow connection.
-
- Falls back to default provider resolution if the provider's constructor
- doesn't accept api_key/base_url (e.g. Google Vertex, Bedrock).
- """
- provider_cls = infer_provider_class(provider_name)
- kwargs: dict[str, Any] = {}
- if api_key:
- kwargs["api_key"] = api_key
- if base_url:
- kwargs["base_url"] = base_url
- try:
- return provider_cls(**kwargs)
- except TypeError:
- # Provider doesn't accept these kwargs (e.g. Google Vertex/GLA
- # use ADC, Bedrock uses boto session). Fall back to default
- # provider resolution which reads credentials from the environment.
- return infer_provider(provider_name)
-
- self._model = infer_model(model_name, provider_factory=_provider_factory)
+ self._model = infer_model(model_name)
return self._model
@overload
@@ -157,13 +195,238 @@ class PydanticAIHook(BaseHook):
"""
Test connection by resolving the model.
- Validates that the model string is valid, the provider package is
- installed, and the provider class can be instantiated. Does NOT make an
- LLM API call — that would be expensive, flaky, and fail for reasons
- unrelated to connectivity (quotas, billing, rate limits).
+ Validates that the model string is valid and the provider class can be
+ instantiated with the supplied credentials. Does NOT make an LLM API
+ call — that would be expensive and fail for reasons unrelated to
+ connectivity (quotas, billing, rate limits).
"""
try:
self.get_conn()
return True, "Model resolved successfully."
except Exception as e:
return False, str(e)
+
+
+class PydanticAIAzureHook(PydanticAIHook):
+ """
+ Hook for Azure OpenAI via pydantic-ai.
+
+ Connection fields:
+ - **password**: Azure API key
+ - **host**: Azure endpoint (e.g. ``https://<resource>.openai.azure.com``)
+ - **extra** JSON::
+
+ {"model": "azure:gpt-4o", "api_version": "2024-07-01-preview"}
+
+ :param llm_conn_id: Airflow connection ID.
+ :param model_id: Model identifier, e.g. ``"azure:gpt-4o"``.
+ """
+
+ conn_type = "pydanticai-azure"
+ default_conn_name = "pydanticai_azure_default"
+ hook_name = "Pydantic AI (Azure OpenAI)"
+
+ @staticmethod
+ def get_ui_field_behaviour() -> dict[str, Any]:
+ """Return custom field behaviour for the Airflow connection form."""
+ return {
+ "hidden_fields": ["schema", "port", "login"],
+ "relabeling": {"password": "API Key", "host": "Azure Endpoint"},
+ "placeholders": {
+ "host": "https://<resource>.openai.azure.com",
+ "extra": '{"model": "azure:gpt-4o", "api_version":
"2024-07-01-preview"}',
+ },
+ }
+
+ def _get_provider_kwargs(
+ self,
+ api_key: str | None,
+ base_url: str | None,
+ extra: dict[str, Any],
+ ) -> dict[str, Any]:
+ kwargs: dict[str, Any] = {}
+ if api_key:
+ kwargs["api_key"] = api_key
+ if base_url:
+ kwargs["azure_endpoint"] = base_url
+ if extra.get("api_version"):
+ kwargs["api_version"] = extra["api_version"]
+ return kwargs
+
+
+class PydanticAIBedrockHook(PydanticAIHook):
+ """
+ Hook for AWS Bedrock via pydantic-ai.
+
+ Credentials are resolved in order:
+
+ 1. IAM keys from ``extra`` (``aws_access_key_id`` + ``aws_secret_access_key``,
+ optionally ``aws_session_token``).
+ 2. Bearer token in ``extra`` (``api_key``, maps to env ``AWS_BEARER_TOKEN_BEDROCK``).
+ 3. Environment-variable / instance-role chain (``AWS_PROFILE``, IAM role, …)
+ when no explicit keys are provided.
+
+ Connection fields:
+ - **extra** JSON::
+
+ {
+ "model": "bedrock:us.anthropic.claude-opus-4-5",
+ "region_name": "us-east-1",
+ "aws_access_key_id": "AKIA...",
+ "aws_secret_access_key": "...",
+ "aws_session_token": "...",
+ "profile_name": "my-aws-profile",
+ "api_key": "bearer-token",
+ "base_url": "https://custom-bedrock-endpoint",
+ "aws_read_timeout": 60.0,
+ "aws_connect_timeout": 10.0
+ }
+
+ Leave ``aws_access_key_id`` / ``aws_secret_access_key`` and ``api_key``
+ empty to use the default AWS credential chain.
+
+ :param llm_conn_id: Airflow connection ID.
+ :param model_id: Model identifier, e.g. ``"bedrock:us.anthropic.claude-opus-4-5"``.
+ """
+
+ conn_type = "pydanticai-bedrock"
+ default_conn_name = "pydanticai_bedrock_default"
+ hook_name = "Pydantic AI (AWS Bedrock)"
+
+ @staticmethod
+ def get_ui_field_behaviour() -> dict[str, Any]:
+ """Return custom field behaviour for the Airflow connection form."""
+ return {
+ "hidden_fields": ["schema", "port", "login", "host", "password"],
+ "relabeling": {},
+ "placeholders": {
+ "extra": (
+ '{"model": "bedrock:us.anthropic.claude-opus-4-5", '
+ '"region_name": "us-east-1"}'
+ " — leave aws_access_key_id empty for IAM role / env-var
auth"
+ ),
+ },
+ }
+
+ def _get_provider_kwargs(
+ self,
+ api_key: str | None,
+ base_url: str | None,
+ extra: dict[str, Any],
+ ) -> dict[str, Any]:
+ """
+ Return kwargs for ``BedrockProvider``.
+
+ .. note::
+ The ``api_key`` and ``base_url`` parameters (sourced from
+ ``conn.password`` and ``conn.host``) are intentionally ignored here.
+ Bedrock connections hide those fields in the UI; all config is
+ stored in ``extra`` instead. The ``api_key`` and ``base_url``
+ keys below refer to *extra* fields, not the method parameters.
+ """
+ _str_keys = (
+ "aws_access_key_id",
+ "aws_secret_access_key",
+ "aws_session_token",
+ "region_name",
+ "profile_name",
+ # Bearer-token auth (alternative to IAM key/secret).
+ # Maps to AWS_BEARER_TOKEN_BEDROCK env var.
+ "api_key",
+ # Custom Bedrock runtime endpoint.
+ "base_url",
+ )
+ kwargs: dict[str, Any] = {k: extra[k] for k in _str_keys if extra.get(k)}
+ # BedrockProvider expects float for timeout values; JSON integers must be coerced.
+ for _timeout_key in ("aws_read_timeout", "aws_connect_timeout"):
+ if extra.get(_timeout_key):
+ kwargs[_timeout_key] = float(extra[_timeout_key])
+ return kwargs
+
+
+class PydanticAIVertexHook(PydanticAIHook):
+ """
+ Hook for Google Vertex AI (or Generative Language API) via pydantic-ai.
+
+ Credentials are resolved in order:
+
+ 1. ``service_account_info`` (JSON object) in ``extra``
+ — loaded into a ``google.auth.credentials.Credentials``
+ object and passed as ``credentials`` to ``GoogleProvider``.
+ 2. ``api_key`` in ``extra`` — for Generative Language API (non-Vertex) or
+ Vertex API-key auth.
+ 3. Application Default Credentials (``GOOGLE_APPLICATION_CREDENTIALS``,
+ ``gcloud auth application-default login``, Workload Identity, …) when
+ no explicit credentials are provided.
+
+ Connection fields:
+ - **extra** JSON::
+
+ {
+ "model": "google-vertex:gemini-2.0-flash",
+ "project": "my-gcp-project",
+ "location": "us-central1",
+ "service_account_info": {...},
+ "vertexai": true,
+ }
+
+ Use ``"service_account_info"`` to embed the service-account JSON
directly
+ (as an object, not a string path).
+
+ Set ``"vertexai": true`` to force Vertex AI mode when only ``api_key``
is
+ provided. Omit ``vertexai`` for the Generative Language API (GLA).
+
+ :param llm_conn_id: Airflow connection ID.
+ :param model_id: Model identifier, e.g. ``"google-vertex:gemini-2.0-flash"``.
+ """
+
+ conn_type = "pydanticai-vertex"
+ default_conn_name = "pydanticai_vertex_default"
+ hook_name = "Pydantic AI (Google Vertex AI)"
+
+ @staticmethod
+ def get_ui_field_behaviour() -> dict[str, Any]:
+ """Return custom field behaviour for the Airflow connection form."""
+ return {
+ "hidden_fields": ["schema", "port", "login", "host", "password"],
+ "relabeling": {},
+ "placeholders": {
+ "extra": (
+ '{"model": "google-vertex:gemini-2.0-flash", '
+ '"project": "my-project", "location": "us-central1",
"vertexai": true}'
+ " — add service_account_info (object) for SA auth;"
+ " omit both to use Application Default Credentials"
+ ),
+ },
+ }
+
+ def _get_provider_kwargs(
+ self,
+ api_key: str | None,
+ base_url: str | None,
+ extra: dict[str, Any],
+ ) -> dict[str, Any]:
+ sa_info = extra.get("service_account_info")
+ kwargs: dict[str, Any] = {}
+
+ # Direct GoogleProvider scalar kwargs.
+ for _key in ("api_key", "project", "location", "base_url"):
+ if extra.get(_key):
+ kwargs[_key] = extra[_key]
+
+ # Optional vertexai bool flag (force Vertex AI mode for API-key auth).
+ _vertexai = extra.get("vertexai")
+ if _vertexai is not None:
+ kwargs["vertexai"] = bool(_vertexai)
+
+ # Service-account credentials — loaded lazily to avoid importing
+ # google-auth on non-Vertex code paths (optional heavy dependency).
+ if sa_info:
+ from google.oauth2 import service_account # lazy: optional dep
+
+ kwargs["credentials"] =
service_account.Credentials.from_service_account_info(
+ sa_info,
+ scopes=["https://www.googleapis.com/auth/cloud-platform"],
+ )
+
+ return kwargs
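The ``_get_provider_kwargs`` override is the whole extension surface introduced above; a hypothetical hook for some further provider (the class name and the ``organization`` extra field are invented for illustration, not part of this commit) would only need:

    from __future__ import annotations

    from typing import Any

    from airflow.providers.common.ai.hooks.pydantic_ai import PydanticAIHook


    class PydanticAIExampleHook(PydanticAIHook):
        """Hypothetical provider-specific hook following the pattern above."""

        conn_type = "pydanticai-example"
        default_conn_name = "pydanticai_example_default"
        hook_name = "Pydantic AI (Example)"

        def _get_provider_kwargs(
            self,
            api_key: str | None,
            base_url: str | None,
            extra: dict[str, Any],
        ) -> dict[str, Any]:
            # Returning {} keeps the env-var fallback path in get_conn() intact.
            kwargs: dict[str, Any] = {}
            if api_key:
                kwargs["api_key"] = api_key
            if extra.get("organization"):  # invented extra field
                kwargs["organization"] = extra["organization"]
            return kwargs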
diff --git a/providers/common/ai/src/airflow/providers/common/ai/operators/agent.py b/providers/common/ai/src/airflow/providers/common/ai/operators/agent.py
index d9c028b4b57..da8ecafc1c5 100644
--- a/providers/common/ai/src/airflow/providers/common/ai/operators/agent.py
+++ b/providers/common/ai/src/airflow/providers/common/ai/operators/agent.py
@@ -173,7 +173,10 @@ class AgentOperator(BaseOperator, HITLReviewMixin):
@cached_property
def llm_hook(self) -> PydanticAIHook:
"""Return PydanticAIHook for the configured LLM connection."""
- return PydanticAIHook(llm_conn_id=self.llm_conn_id, model_id=self.model_id)
+ hook_params = {
+ "model_id": self.model_id,
+ }
+ return PydanticAIHook.get_hook(self.llm_conn_id, hook_params=hook_params)
def _build_agent(self) -> Agent[None, Any]:
"""Build and return a pydantic-ai Agent from the operator's config."""
diff --git a/providers/common/ai/src/airflow/providers/common/ai/operators/llm.py b/providers/common/ai/src/airflow/providers/common/ai/operators/llm.py
index b5a2c80bd11..73000eb50a3 100644
--- a/providers/common/ai/src/airflow/providers/common/ai/operators/llm.py
+++ b/providers/common/ai/src/airflow/providers/common/ai/operators/llm.py
@@ -101,8 +101,18 @@ class LLMOperator(BaseOperator, LLMApprovalMixin):
@cached_property
def llm_hook(self) -> PydanticAIHook:
- """Return PydanticAIHook for the configured LLM connection."""
- return PydanticAIHook(llm_conn_id=self.llm_conn_id, model_id=self.model_id)
+ """
+ Return the correct PydanticAIHook subclass for the configured connection.
+
+ Delegates to :meth:`~PydanticAIHook.get_hook` which looks up
+ the connection's ``conn_type`` and instantiates the matching subclass
+ (e.g. :class:`~airflow.providers.common.ai.hooks.pydantic_ai.PydanticAIAzureHook`
+ for ``pydanticai-azure`` connections).
+ """
+ hook_params = {
+ "model_id": self.model_id,
+ }
+ return PydanticAIHook.get_hook(self.llm_conn_id, hook_params=hook_params)
def execute(self, context: Context) -> Any:
agent: Agent[None, Any] = self.llm_hook.create_agent(
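With both operators delegating to ``PydanticAIHook.get_hook``, the same call site yields a different hook subclass depending on the stored connection's ``conn_type``. A quick sketch of that dispatch, using the call pattern from the diff above (the connection IDs are hypothetical):

    from airflow.providers.common.ai.hooks.pydantic_ai import PydanticAIHook

    # "azure_llm" is assumed to be a pydanticai-azure connection,
    # "bedrock_llm" a pydanticai-bedrock one.
    azure_hook = PydanticAIHook.get_hook("azure_llm", hook_params={"model_id": "azure:gpt-4o"})
    bedrock_hook = PydanticAIHook.get_hook("bedrock_llm", hook_params={"model_id": None})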
diff --git a/providers/common/ai/tests/unit/common/ai/decorators/test_agent.py b/providers/common/ai/tests/unit/common/ai/decorators/test_agent.py
index 52ed82c53cd..f5766325140 100644
--- a/providers/common/ai/tests/unit/common/ai/decorators/test_agent.py
+++ b/providers/common/ai/tests/unit/common/ai/decorators/test_agent.py
@@ -46,7 +46,7 @@ class TestAgentDecoratedOperator:
"""The callable's return value becomes the agent prompt."""
mock_agent = MagicMock(spec=["run_sync"])
mock_agent.run_sync.return_value = _make_mock_run_result("The top
customer is Acme Corp.")
- mock_hook_cls.return_value.create_agent.return_value = mock_agent
+ mock_hook_cls.get_hook.return_value.create_agent.return_value = mock_agent
def my_prompt():
return "Who is our top customer?"
@@ -78,7 +78,7 @@ class TestAgentDecoratedOperator:
"""op_kwargs are resolved by the callable to build the prompt."""
mock_agent = MagicMock(spec=["run_sync"])
mock_agent.run_sync.return_value = _make_mock_run_result("done")
- mock_hook_cls.return_value.create_agent.return_value = mock_agent
+ mock_hook_cls.get_hook.return_value.create_agent.return_value = mock_agent
def my_prompt(topic):
return f"Analyze {topic}"
@@ -99,7 +99,7 @@ class TestAgentDecoratedOperator:
"""Toolsets passed to the decorator are forwarded to the agent."""
mock_agent = MagicMock(spec=["run_sync"])
mock_agent.run_sync.return_value = _make_mock_run_result("result")
- mock_hook_cls.return_value.create_agent.return_value = mock_agent
+ mock_hook_cls.get_hook.return_value.create_agent.return_value = mock_agent
mock_toolset = MagicMock()
@@ -111,7 +111,7 @@ class TestAgentDecoratedOperator:
)
op.execute(context={})
- create_call = mock_hook_cls.return_value.create_agent.call_args
+ create_call = mock_hook_cls.get_hook.return_value.create_agent.call_args
passed_toolsets = create_call[1]["toolsets"]
assert len(passed_toolsets) == 1
assert isinstance(passed_toolsets[0], LoggingToolset)
@@ -126,7 +126,7 @@ class TestAgentDecoratedOperator:
mock_agent = MagicMock(spec=["run_sync"])
mock_agent.run_sync.return_value = _make_mock_run_result(Summary(text="Great results"))
- mock_hook_cls.return_value.create_agent.return_value = mock_agent
+ mock_hook_cls.get_hook.return_value.create_agent.return_value = mock_agent
op = _AgentDecoratedOperator(
task_id="test",
diff --git a/providers/common/ai/tests/unit/common/ai/decorators/test_llm.py b/providers/common/ai/tests/unit/common/ai/decorators/test_llm.py
index ff3240529ff..21f663e177a 100644
--- a/providers/common/ai/tests/unit/common/ai/decorators/test_llm.py
+++ b/providers/common/ai/tests/unit/common/ai/decorators/test_llm.py
@@ -44,7 +44,7 @@ class TestLLMDecoratedOperator:
"""The callable's return value becomes the LLM prompt."""
mock_agent = MagicMock(spec=["run_sync"])
mock_agent.run_sync.return_value = _make_mock_run_result("This is a
summary.")
- mock_hook_cls.return_value.create_agent.return_value = mock_agent
+ mock_hook_cls.get_hook.return_value.create_agent.return_value = mock_agent
def my_prompt():
return "Summarize this text"
@@ -76,7 +76,7 @@ class TestLLMDecoratedOperator:
"""op_kwargs are resolved by the callable to build the prompt."""
mock_agent = MagicMock(spec=["run_sync"])
mock_agent.run_sync.return_value = _make_mock_run_result("done")
- mock_hook_cls.return_value.create_agent.return_value = mock_agent
+ mock_hook_cls.get_hook.return_value.create_agent.return_value = mock_agent
def my_prompt(topic):
return f"Summarize {topic}"
diff --git a/providers/common/ai/tests/unit/common/ai/decorators/test_llm_branch.py b/providers/common/ai/tests/unit/common/ai/decorators/test_llm_branch.py
index 8daf2799b5f..6620e505db1 100644
--- a/providers/common/ai/tests/unit/common/ai/decorators/test_llm_branch.py
+++ b/providers/common/ai/tests/unit/common/ai/decorators/test_llm_branch.py
@@ -49,7 +49,7 @@ class TestLLMBranchDecoratedOperator:
mock_agent = MagicMock(spec=["run_sync"])
mock_agent.run_sync.return_value = _make_mock_run_result(downstream_enum.positive)
- mock_hook_cls.return_value.create_agent.return_value = mock_agent
+ mock_hook_cls.get_hook.return_value.create_agent.return_value = mock_agent
mock_do_branch.return_value = "positive"
def my_prompt():
@@ -92,7 +92,7 @@ class TestLLMBranchDecoratedOperator:
mock_agent = MagicMock(spec=["run_sync"])
mock_agent.run_sync.return_value = _make_mock_run_result(downstream_enum.task_a)
- mock_hook_cls.return_value.create_agent.return_value = mock_agent
+ mock_hook_cls.get_hook.return_value.create_agent.return_value = mock_agent
def my_prompt(ticket_type):
return f"Route this {ticket_type} ticket"
diff --git a/providers/common/ai/tests/unit/common/ai/decorators/test_llm_schema_compare.py b/providers/common/ai/tests/unit/common/ai/decorators/test_llm_schema_compare.py
index 7722ea4e450..c271c94be00 100644
--- a/providers/common/ai/tests/unit/common/ai/decorators/test_llm_schema_compare.py
+++ b/providers/common/ai/tests/unit/common/ai/decorators/test_llm_schema_compare.py
@@ -61,7 +61,9 @@ class TestLLMSchemaCompareDecoratedOperator:
@patch.object(LLMSchemaCompareOperator, "_build_schema_context",
return_value="mocked schema")
def test_execute_calls_callable_and_uses_result_as_prompt(self,
mock_build_ctx, mock_hook_cls):
"""The user's callable return value becomes the LLM prompt."""
- mock_hook_cls.return_value.create_agent.return_value = _make_mock_agent(_make_compare_result())
+ mock_hook_cls.get_hook.return_value.create_agent.return_value = _make_mock_agent(
+ _make_compare_result()
+ )
def my_prompt_fn():
return "Compare schemas and flag breaking changes"
@@ -99,7 +101,9 @@ class TestLLMSchemaCompareDecoratedOperator:
@patch.object(LLMSchemaCompareOperator, "_build_schema_context",
return_value="mocked schema")
def test_execute_merges_op_kwargs_into_callable(self, mock_build_ctx,
mock_hook_cls):
"""op_kwargs are resolved by the callable to build the prompt."""
- mock_hook_cls.return_value.create_agent.return_value = _make_mock_agent(_make_compare_result())
+ mock_hook_cls.get_hook.return_value.create_agent.return_value = _make_mock_agent(
+ _make_compare_result()
+ )
def my_prompt_fn(target_env):
return f"Compare schemas for {target_env} environment"
diff --git a/providers/common/ai/tests/unit/common/ai/decorators/test_llm_sql.py b/providers/common/ai/tests/unit/common/ai/decorators/test_llm_sql.py
index a1b7e3e3f36..0f3f943926a 100644
--- a/providers/common/ai/tests/unit/common/ai/decorators/test_llm_sql.py
+++ b/providers/common/ai/tests/unit/common/ai/decorators/test_llm_sql.py
@@ -44,7 +44,7 @@ class TestLLMSQLDecoratedOperator:
"""The user's callable return value becomes the LLM prompt."""
mock_agent = MagicMock(spec=["run_sync"])
mock_agent.run_sync.return_value = _make_mock_run_result("SELECT 1")
- mock_hook_cls.return_value.create_agent.return_value = mock_agent
+ mock_hook_cls.get_hook.return_value.create_agent.return_value = mock_agent
def my_prompt_fn():
return "Get all users"
@@ -76,7 +76,7 @@ class TestLLMSQLDecoratedOperator:
"""op_kwargs are resolved by the callable to build the prompt."""
mock_agent = MagicMock(spec=["run_sync"])
mock_agent.run_sync.return_value = _make_mock_run_result("SELECT 1")
- mock_hook_cls.return_value.create_agent.return_value = mock_agent
+ mock_hook_cls.get_hook.return_value.create_agent.return_value = mock_agent
def my_prompt_fn(table_name):
return f"Get all rows from {table_name}"
diff --git a/providers/common/ai/tests/unit/common/ai/hooks/test_pydantic_ai.py b/providers/common/ai/tests/unit/common/ai/hooks/test_pydantic_ai.py
index 5a2ecb2b1fa..993824e1c14 100644
--- a/providers/common/ai/tests/unit/common/ai/hooks/test_pydantic_ai.py
+++ b/providers/common/ai/tests/unit/common/ai/hooks/test_pydantic_ai.py
@@ -16,13 +16,20 @@
# under the License.
from __future__ import annotations
+import json
+import sys
from unittest.mock import MagicMock, patch
import pytest
from pydantic_ai.models import Model
from airflow.models.connection import Connection
-from airflow.providers.common.ai.hooks.pydantic_ai import PydanticAIHook
+from airflow.providers.common.ai.hooks.pydantic_ai import (
+ PydanticAIAzureHook,
+ PydanticAIBedrockHook,
+ PydanticAIHook,
+ PydanticAIVertexHook,
+)
class TestPydanticAIHookInit:
@@ -36,6 +43,19 @@ class TestPydanticAIHookInit:
assert hook.llm_conn_id == "my_llm"
assert hook.model_id == "openai:gpt-5.3"
+ def test_azure_hook_uses_own_default_conn_name(self):
+ """Subclass default_conn_name is used, not the base class value."""
+ hook = PydanticAIAzureHook()
+ assert hook.llm_conn_id == "pydanticai_azure_default"
+
+ def test_bedrock_hook_uses_own_default_conn_name(self):
+ hook = PydanticAIBedrockHook()
+ assert hook.llm_conn_id == "pydanticai_bedrock_default"
+
+ def test_vertex_hook_uses_own_default_conn_name(self):
+ hook = PydanticAIVertexHook()
+ assert hook.llm_conn_id == "pydanticai_vertex_default"
+
class TestPydanticAIHookGetConn:
@patch("airflow.providers.common.ai.hooks.pydantic_ai.infer_model",
autospec=True)
@@ -175,36 +195,6 @@ class TestPydanticAIHookGetConn:
assert first is second
mock_infer_model.assert_called_once()
- @patch("airflow.providers.common.ai.hooks.pydantic_ai.infer_provider",
autospec=True)
-
@patch("airflow.providers.common.ai.hooks.pydantic_ai.infer_provider_class",
autospec=True)
- @patch("airflow.providers.common.ai.hooks.pydantic_ai.infer_model",
autospec=True)
- def test_provider_factory_falls_back_on_unsupported_kwargs(
- self, mock_infer_model, mock_infer_provider_class, mock_infer_provider
- ):
- """If a provider rejects api_key/base_url, fall back to default
resolution."""
- mock_infer_model.return_value = MagicMock(spec=Model)
- mock_fallback_provider = MagicMock()
- mock_infer_provider.return_value = mock_fallback_provider
- # Simulate a provider that doesn't accept api_key/base_url
- mock_infer_provider_class.return_value = MagicMock(side_effect=TypeError("unexpected keyword"))
-
- hook = PydanticAIHook(llm_conn_id="test_conn",
model_id="google:gemini-2.0-flash")
- conn = Connection(
- conn_id="test_conn",
- conn_type="pydanticai",
- password="some-key",
- )
- with patch.object(hook, "get_connection", return_value=conn):
- hook.get_conn()
-
- factory = mock_infer_model.call_args[1]["provider_factory"]
- result = factory("google-gla")
-
- # Should have tried provider_cls first, then fallen back to infer_provider
- mock_infer_provider_class.return_value.assert_called_once_with(api_key="some-key")
- mock_infer_provider.assert_called_with("google-gla")
- assert result is mock_fallback_provider
-
class TestPydanticAIHookCreateAgent:
@patch("airflow.providers.common.ai.hooks.pydantic_ai.infer_model",
autospec=True)
@@ -296,3 +286,369 @@ class TestPydanticAIHookTestConnection:
assert success is False
assert "No model specified" in message
+
+
+# ---------------------------------------------------------------------------
+# Subclass hook tests
+# ---------------------------------------------------------------------------
+
+
+class TestPydanticAIAzureHook:
+ """Tests for PydanticAIAzureHook."""
+
+ def test_conn_type(self):
+ assert PydanticAIAzureHook.conn_type == "pydanticai-azure"
+
+ def test_hook_name(self):
+ assert "Azure" in PydanticAIAzureHook.hook_name
+
+ def test_ui_field_behaviour_relabels_host(self):
+ behaviour = PydanticAIAzureHook.get_ui_field_behaviour()
+ assert behaviour["relabeling"].get("host") == "Azure Endpoint"
+
+ def test_get_provider_kwargs_maps_azure_endpoint(self):
+ hook = PydanticAIAzureHook.__new__(PydanticAIAzureHook)
+ result = hook._get_provider_kwargs(
+ "my-key",
+ "https://myresource.openai.azure.com",
+ {"model": "azure:gpt-4o", "api_version": "2024-07-01-preview"},
+ )
+ assert result["azure_endpoint"] ==
"https://myresource.openai.azure.com"
+ assert result["api_key"] == "my-key"
+ assert result["api_version"] == "2024-07-01-preview"
+ assert "base_url" not in result
+
+ def test_get_provider_kwargs_omits_none_api_key(self):
+ hook = PydanticAIAzureHook.__new__(PydanticAIAzureHook)
+ result = hook._get_provider_kwargs(
+ None,
+ "https://myresource.openai.azure.com",
+ {"model": "azure:gpt-4o", "api_version": "2024-07-01-preview"},
+ )
+ assert "api_key" not in result
+ assert result["azure_endpoint"] ==
"https://myresource.openai.azure.com"
+
+ def test_get_provider_kwargs_omits_azure_endpoint_when_no_host(self):
+ hook = PydanticAIAzureHook.__new__(PydanticAIAzureHook)
+ result = hook._get_provider_kwargs(
+ "my-key",
+ None,
+ {"model": "azure:gpt-4o", "api_version": "2024-07-01-preview"},
+ )
+ assert "azure_endpoint" not in result
+ assert result["api_key"] == "my-key"
+
+ def test_get_provider_kwargs_empty_without_api_version(self):
+ hook = PydanticAIAzureHook.__new__(PydanticAIAzureHook)
+ result = hook._get_provider_kwargs(
+ "my-key",
+ "https://myresource.openai.azure.com",
+ {"model": "azure:gpt-4o"},
+ )
+ # api_version should not appear if not in extra
+ assert "api_version" not in result
+
+ @patch("airflow.providers.common.ai.hooks.pydantic_ai.infer_model",
autospec=True)
+
@patch("airflow.providers.common.ai.hooks.pydantic_ai.infer_provider_class",
autospec=True)
+ def test_get_conn_uses_azure_endpoint(self, mock_infer_provider_class, mock_infer_model):
+ mock_infer_model.return_value = MagicMock(spec=Model)
+ mock_provider_cls = MagicMock(return_value=MagicMock())
+ mock_infer_provider_class.return_value = mock_provider_cls
+
+ hook = PydanticAIAzureHook(llm_conn_id="azure_test")
+ conn = Connection(
+ conn_id="azure_test",
+ conn_type="pydanticai-azure",
+ password="azure-key",
+ host="https://myresource.openai.azure.com",
+ extra=json.dumps({"model": "azure:gpt-4o", "api_version":
"2024-07-01-preview"}),
+ )
+ with patch.object(hook, "get_connection", return_value=conn):
+ hook.get_conn()
+
+ factory = mock_infer_model.call_args[1]["provider_factory"]
+ factory("azure")
+ mock_provider_cls.assert_called_with(
+ api_key="azure-key",
+ azure_endpoint="https://myresource.openai.azure.com",
+ api_version="2024-07-01-preview",
+ )
+
+ @patch("airflow.providers.common.ai.hooks.pydantic_ai.infer_model",
autospec=True)
+ def test_get_conn_falls_back_to_env_auth_when_no_kwargs(self,
mock_infer_model):
+ """No host + no password → env-var auth path (empty
_get_provider_kwargs)."""
+ mock_infer_model.return_value = MagicMock(spec=Model)
+ hook = PydanticAIAzureHook(llm_conn_id="azure_test")
+ conn = Connection(
+ conn_id="azure_test",
+ conn_type="pydanticai-azure",
+ extra=json.dumps({"model": "azure:gpt-4o"}),
+ )
+ with patch.object(hook, "get_connection", return_value=conn):
+ hook.get_conn()
+
+ mock_infer_model.assert_called_once_with("azure:gpt-4o")
+
+
+class TestPydanticAIBedrockHook:
+ """Tests for PydanticAIBedrockHook."""
+
+ def test_conn_type(self):
+ assert PydanticAIBedrockHook.conn_type == "pydanticai-bedrock"
+
+ def test_hook_name(self):
+ assert "Bedrock" in PydanticAIBedrockHook.hook_name
+
+ def test_ui_hides_host_and_password(self):
+ behaviour = PydanticAIBedrockHook.get_ui_field_behaviour()
+ assert "host" in behaviour["hidden_fields"]
+ assert "password" in behaviour["hidden_fields"]
+
+ def test_get_provider_kwargs_maps_bedrock_fields(self):
+ hook = PydanticAIBedrockHook.__new__(PydanticAIBedrockHook)
+ result = hook._get_provider_kwargs(
+ None,
+ None,
+ {
+ "model": "bedrock:us.anthropic.claude-opus-4-5",
+ "region_name": "us-east-1",
+ "aws_access_key_id": "AKIA123",
+ "aws_secret_access_key": "secret",
+ },
+ )
+ assert result["region_name"] == "us-east-1"
+ assert result["aws_access_key_id"] == "AKIA123"
+ assert result["aws_secret_access_key"] == "secret"
+ assert "model" not in result
+ assert "api_key" not in result
+
+ def test_get_provider_kwargs_returns_empty_for_env_auth(self):
+ """When no keys are in extra, return {} so env-auth path is taken."""
+ hook = PydanticAIBedrockHook.__new__(PydanticAIBedrockHook)
+ result = hook._get_provider_kwargs(None, None, {"model":
"bedrock:us.anthropic.claude-opus-4-5"})
+ assert result == {}
+
+ @patch("airflow.providers.common.ai.hooks.pydantic_ai.infer_model",
autospec=True)
+ def test_get_conn_falls_back_to_env_auth(self, mock_infer_model):
+ mock_infer_model.return_value = MagicMock(spec=Model)
+ hook = PydanticAIBedrockHook(llm_conn_id="bedrock_test")
+ conn = Connection(
+ conn_id="bedrock_test",
+ conn_type="pydanticai-bedrock",
+ extra=json.dumps({"model":
"bedrock:us.anthropic.claude-opus-4-5"}),
+ )
+ with patch.object(hook, "get_connection", return_value=conn):
+ hook.get_conn()
+
+ mock_infer_model.assert_called_once_with("bedrock:us.anthropic.claude-opus-4-5")
+
+ @patch("airflow.providers.common.ai.hooks.pydantic_ai.infer_model",
autospec=True)
+
@patch("airflow.providers.common.ai.hooks.pydantic_ai.infer_provider_class",
autospec=True)
+ def test_get_conn_uses_explicit_keys(self, mock_infer_provider_class,
mock_infer_model):
+ mock_infer_model.return_value = MagicMock(spec=Model)
+ mock_provider_cls = MagicMock(return_value=MagicMock())
+ mock_infer_provider_class.return_value = mock_provider_cls
+
+ hook = PydanticAIBedrockHook(llm_conn_id="bedrock_test")
+ conn = Connection(
+ conn_id="bedrock_test",
+ conn_type="pydanticai-bedrock",
+ extra=json.dumps(
+ {
+ "model": "bedrock:us.anthropic.claude-opus-4-5",
+ "region_name": "eu-west-1",
+ "aws_access_key_id": "AKIA123",
+ "aws_secret_access_key": "secret",
+ }
+ ),
+ )
+ with patch.object(hook, "get_connection", return_value=conn):
+ hook.get_conn()
+
+ factory = mock_infer_model.call_args[1]["provider_factory"]
+ factory("bedrock")
+ mock_provider_cls.assert_called_with(
+ region_name="eu-west-1",
+ aws_access_key_id="AKIA123",
+ aws_secret_access_key="secret",
+ )
+
+ def test_get_provider_kwargs_bearer_token(self):
+ """api_key in extra maps to BedrockProvider's bearer-token param."""
+ hook = PydanticAIBedrockHook.__new__(PydanticAIBedrockHook)
+ result = hook._get_provider_kwargs(
+ None,
+ None,
+ {
+ "model": "bedrock:us.anthropic.claude-opus-4-5",
+ "api_key": "bearer-token-value",
+ "region_name": "us-east-1",
+ },
+ )
+ assert result["api_key"] == "bearer-token-value"
+ assert result["region_name"] == "us-east-1"
+ assert "aws_access_key_id" not in result
+
+ def test_get_provider_kwargs_base_url(self):
+ """base_url in extra is forwarded to BedrockProvider."""
+ hook = PydanticAIBedrockHook.__new__(PydanticAIBedrockHook)
+ result = hook._get_provider_kwargs(
+ None,
+ None,
+ {
+ "model": "bedrock:us.anthropic.claude-opus-4-5",
+ "base_url": "https://custom-bedrock.example.com",
+ },
+ )
+ assert result["base_url"] == "https://custom-bedrock.example.com"
+
+ def test_get_provider_kwargs_float_timeouts(self):
+ """Timeout values are coerced to float (JSON delivers them as int)."""
+ hook = PydanticAIBedrockHook.__new__(PydanticAIBedrockHook)
+ result = hook._get_provider_kwargs(
+ None,
+ None,
+ {
+ "model": "bedrock:us.anthropic.claude-opus-4-5",
+ "aws_read_timeout": 60, # int from JSON
+ "aws_connect_timeout": 10.5, # float already
+ },
+ )
+ assert result["aws_read_timeout"] == 60.0
+ assert isinstance(result["aws_read_timeout"], float)
+ assert result["aws_connect_timeout"] == 10.5
+ assert isinstance(result["aws_connect_timeout"], float)
+
+
+class TestPydanticAIVertexHook:
+ """Tests for PydanticAIVertexHook."""
+
+ def test_conn_type(self):
+ assert PydanticAIVertexHook.conn_type == "pydanticai-vertex"
+
+ def test_hook_name(self):
+ assert "Vertex" in PydanticAIVertexHook.hook_name
+
+ def test_ui_hides_host_and_password(self):
+ behaviour = PydanticAIVertexHook.get_ui_field_behaviour()
+ assert "host" in behaviour["hidden_fields"]
+ assert "password" in behaviour["hidden_fields"]
+
+ def test_get_provider_kwargs_maps_vertex_fields(self):
+ """project and location are passed directly; api_key absent when not
in extra."""
+ hook = PydanticAIVertexHook.__new__(PydanticAIVertexHook)
+ result = hook._get_provider_kwargs(
+ None,
+ None,
+ {
+ "model": "google-vertex:gemini-2.0-flash",
+ "project": "my-project",
+ "location": "us-central1",
+ },
+ )
+ assert result["project"] == "my-project"
+ assert result["location"] == "us-central1"
+ assert "model" not in result
+ assert "api_key" not in result
+ assert "project_id" not in result
+
+ def test_get_provider_kwargs_api_key_gla_mode(self):
+ """api_key in extra is forwarded for Generative Language API mode."""
+ hook = PydanticAIVertexHook.__new__(PydanticAIVertexHook)
+ result = hook._get_provider_kwargs(
+ None,
+ None,
+ {"model": "google-gla:gemini-2.0-flash", "api_key": "gla-key"},
+ )
+ assert result["api_key"] == "gla-key"
+
+ def test_get_provider_kwargs_vertexai_flag(self):
+ """vertexai bool is forwarded and coerced to bool."""
+ hook = PydanticAIVertexHook.__new__(PydanticAIVertexHook)
+ result = hook._get_provider_kwargs(
+ None,
+ None,
+ {"model": "google-vertex:gemini-2.0-flash", "api_key": "key",
"vertexai": True},
+ )
+ assert result["vertexai"] is True
+
+ def test_get_provider_kwargs_service_account_info_loads_credentials(self):
+ """service_account_info dict is loaded into a Credentials object."""
+ mock_sa = MagicMock()
+ mock_creds = MagicMock()
+ mock_sa.Credentials.from_service_account_info.return_value = mock_creds
+
+ mock_google_oauth2 = MagicMock()
+ mock_google_oauth2.service_account = mock_sa
+
+ sa_info_dict = {"type": "service_account", "project_id": "my-project", "private_key": "..."}
+ hook = PydanticAIVertexHook.__new__(PydanticAIVertexHook)
+ with patch.dict(
+ sys.modules,
+ {
+ "google": MagicMock(),
+ "google.oauth2": mock_google_oauth2,
+ "google.oauth2.service_account": mock_sa,
+ },
+ ):
+ result = hook._get_provider_kwargs(
+ None,
+ None,
+ {
+ "model": "google-vertex:gemini-2.0-flash",
+ "service_account_info": sa_info_dict,
+ },
+ )
+
+ mock_sa.Credentials.from_service_account_info.assert_called_once_with(
+ sa_info_dict,
+ scopes=["https://www.googleapis.com/auth/cloud-platform"],
+ )
+ assert result["credentials"] is mock_creds
+ assert "service_account_info" not in result
+
+ def test_get_provider_kwargs_returns_empty_for_adc(self):
+ """When no keys are in extra, return {} so ADC path is taken."""
+ hook = PydanticAIVertexHook.__new__(PydanticAIVertexHook)
+ result = hook._get_provider_kwargs(None, None, {"model": "google-vertex:gemini-2.0-flash"})
+ assert result == {}
+
+ @patch("airflow.providers.common.ai.hooks.pydantic_ai.infer_model",
autospec=True)
+ def test_get_conn_falls_back_to_adc(self, mock_infer_model):
+ mock_infer_model.return_value = MagicMock(spec=Model)
+ hook = PydanticAIVertexHook(llm_conn_id="vertex_test")
+ conn = Connection(
+ conn_id="vertex_test",
+ conn_type="pydanticai-vertex",
+ extra=json.dumps({"model": "google-vertex:gemini-2.0-flash"}),
+ )
+ with patch.object(hook, "get_connection", return_value=conn):
+ hook.get_conn()
+
+ mock_infer_model.assert_called_once_with("google-vertex:gemini-2.0-flash")
+
+ @patch("airflow.providers.common.ai.hooks.pydantic_ai.infer_model",
autospec=True)
+
@patch("airflow.providers.common.ai.hooks.pydantic_ai.infer_provider_class",
autospec=True)
+ def test_get_conn_uses_explicit_project(self, mock_infer_provider_class, mock_infer_model):
+ mock_infer_model.return_value = MagicMock(spec=Model)
+ mock_provider_cls = MagicMock(return_value=MagicMock())
+ mock_infer_provider_class.return_value = mock_provider_cls
+
+ hook = PydanticAIVertexHook(llm_conn_id="vertex_test")
+ conn = Connection(
+ conn_id="vertex_test",
+ conn_type="pydanticai-vertex",
+ extra=json.dumps(
+ {
+ "model": "google-vertex:gemini-2.0-flash",
+ "project": "my-project",
+ "location": "europe-west4",
+ }
+ ),
+ )
+ with patch.object(hook, "get_connection", return_value=conn):
+ hook.get_conn()
+
+ factory = mock_infer_model.call_args[1]["provider_factory"]
+ factory("google-vertex")
+ mock_provider_cls.assert_called_with(project="my-project", location="europe-west4")
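
For reference, the connection extras these hook tests exercise look roughly like the sketch below. The Bedrock ``conn_type`` string is an assumption made by analogy with the ``pydanticai-vertex`` type pinned above; the values simply mirror the test fixtures.

import json

from airflow.models import Connection

# Bedrock: static credentials plus an optional bearer token ("api_key"); the hook
# coerces aws_read_timeout/aws_connect_timeout to float before they reach BedrockProvider.
bedrock_conn = Connection(
    conn_id="bedrock_example",
    conn_type="pydanticai-bedrock",  # assumed; only the Vertex conn_type is asserted above
    extra=json.dumps(
        {
            "model": "bedrock:us.anthropic.claude-opus-4-5",
            "region_name": "eu-west-1",
            "aws_access_key_id": "AKIA123",
            "aws_secret_access_key": "secret",
            "aws_read_timeout": 60,
            "aws_connect_timeout": 10.5,
        }
    ),
)

# Vertex: project/location for Vertex AI, api_key for the Generative Language API,
# or nothing beyond "model" to fall back to Application Default Credentials (ADC).
vertex_conn = Connection(
    conn_id="vertex_example",
    conn_type="pydanticai-vertex",
    extra=json.dumps(
        {
            "model": "google-vertex:gemini-2.0-flash",
            "project": "my-project",
            "location": "europe-west4",
        }
    ),
)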
diff --git a/providers/common/ai/tests/unit/common/ai/operators/test_agent.py
b/providers/common/ai/tests/unit/common/ai/operators/test_agent.py
index 2125b3301f8..13410b35c55 100644
--- a/providers/common/ai/tests/unit/common/ai/operators/test_agent.py
+++ b/providers/common/ai/tests/unit/common/ai/operators/test_agent.py
@@ -82,7 +82,7 @@ class TestAgentOperatorExecute:
@patch("airflow.providers.common.ai.operators.agent.PydanticAIHook",
autospec=True)
def test_execute_creates_agent_from_hook(self, mock_hook_cls):
mock_agent = _make_mock_agent("The answer is 42.")
- mock_hook_cls.return_value.create_agent.return_value = mock_agent
+ mock_hook_cls.get_hook.return_value.create_agent.return_value =
mock_agent
op = AgentOperator(
task_id="test",
@@ -93,8 +93,8 @@ class TestAgentOperatorExecute:
result = op.execute(context=MagicMock())
assert result == "The answer is 42."
- mock_hook_cls.assert_called_once_with(llm_conn_id="my_llm",
model_id=None)
- mock_hook_cls.return_value.create_agent.assert_called_once_with(
+ mock_hook_cls.get_hook.assert_called_once_with("my_llm",
hook_params={"model_id": None})
+
mock_hook_cls.get_hook.return_value.create_agent.assert_called_once_with(
output_type=str, instructions="You are helpful."
)
mock_agent.run_sync.assert_called_once_with("What is the answer?")
@@ -102,7 +102,7 @@ class TestAgentOperatorExecute:
@patch("airflow.providers.common.ai.operators.agent.PydanticAIHook",
autospec=True)
def test_execute_passes_toolsets_in_agent_kwargs(self, mock_hook_cls):
"""Toolsets are passed through to the agent constructor."""
- mock_hook_cls.return_value.create_agent.return_value =
_make_mock_agent("done")
+ mock_hook_cls.get_hook.return_value.create_agent.return_value =
_make_mock_agent("done")
mock_toolset = MagicMock()
op = AgentOperator(
@@ -113,7 +113,7 @@ class TestAgentOperatorExecute:
)
op.execute(context=MagicMock())
- create_call = mock_hook_cls.return_value.create_agent.call_args
+ create_call =
mock_hook_cls.get_hook.return_value.create_agent.call_args
passed_toolsets = create_call[1]["toolsets"]
assert len(passed_toolsets) == 1
assert isinstance(passed_toolsets[0], LoggingToolset)
@@ -122,7 +122,7 @@ class TestAgentOperatorExecute:
@patch("airflow.providers.common.ai.operators.agent.PydanticAIHook",
autospec=True)
def test_enable_tool_logging_false_skips_wrapping(self, mock_hook_cls):
"""enable_tool_logging=False passes toolsets through unwrapped."""
- mock_hook_cls.return_value.create_agent.return_value =
_make_mock_agent("done")
+ mock_hook_cls.get_hook.return_value.create_agent.return_value =
_make_mock_agent("done")
mock_toolset = MagicMock()
op = AgentOperator(
@@ -134,13 +134,13 @@ class TestAgentOperatorExecute:
)
op.execute(context=MagicMock())
- create_call = mock_hook_cls.return_value.create_agent.call_args
+ create_call =
mock_hook_cls.get_hook.return_value.create_agent.call_args
assert create_call[1]["toolsets"] == [mock_toolset]
@patch("airflow.providers.common.ai.operators.agent.PydanticAIHook",
autospec=True)
def test_execute_passes_agent_params(self, mock_hook_cls):
"""agent_params are unpacked into create_agent."""
- mock_hook_cls.return_value.create_agent.return_value =
_make_mock_agent("ok")
+ mock_hook_cls.get_hook.return_value.create_agent.return_value =
_make_mock_agent("ok")
op = AgentOperator(
task_id="test",
@@ -150,7 +150,7 @@ class TestAgentOperatorExecute:
)
op.execute(context=MagicMock())
- create_call = mock_hook_cls.return_value.create_agent.call_args
+ create_call =
mock_hook_cls.get_hook.return_value.create_agent.call_args
assert create_call[1]["retries"] == 3
assert create_call[1]["model_settings"] == {"temperature": 0}
@@ -162,7 +162,7 @@ class TestAgentOperatorExecute:
text: str
score: float
- mock_hook_cls.return_value.create_agent.return_value =
_make_mock_agent(
+ mock_hook_cls.get_hook.return_value.create_agent.return_value =
_make_mock_agent(
Summary(text="Great", score=0.95)
)
@@ -179,7 +179,7 @@ class TestAgentOperatorExecute:
@patch("airflow.providers.common.ai.operators.agent.PydanticAIHook",
autospec=True)
def test_execute_with_model_id(self, mock_hook_cls):
"""model_id is passed to PydanticAIHook."""
- mock_hook_cls.return_value.create_agent.return_value =
_make_mock_agent("ok")
+ mock_hook_cls.get_hook.return_value.create_agent.return_value =
_make_mock_agent("ok")
op = AgentOperator(
task_id="test",
@@ -189,7 +189,7 @@ class TestAgentOperatorExecute:
)
op.execute(context=MagicMock())
- mock_hook_cls.assert_called_once_with(llm_conn_id="my_llm",
model_id="openai:gpt-5")
+ mock_hook_cls.get_hook.assert_called_once_with("my_llm",
hook_params={"model_id": "openai:gpt-5"})
@pytest.mark.skipif(
not AIRFLOW_V_3_1_PLUS, reason="Human in the loop is only compatible
with Airflow >= 3.1.0"
@@ -203,7 +203,7 @@ class TestAgentOperatorExecute:
mock_result.all_messages.return_value = msg_history
mock_agent = MagicMock(spec=["run_sync"])
mock_agent.run_sync.return_value = mock_result
- mock_hook_cls.return_value.create_agent.return_value = mock_agent
+ mock_hook_cls.get_hook.return_value.create_agent.return_value =
mock_agent
mock_run_hitl.return_value = "Approved output"
op = AgentOperator(
@@ -234,7 +234,7 @@ class TestAgentOperatorExecute:
mock_result = _make_mock_run_result(Summary(text="Approved summary",
score=0.9))
mock_agent = MagicMock(spec=["run_sync"])
mock_agent.run_sync.return_value = mock_result
- mock_hook_cls.return_value.create_agent.return_value = mock_agent
+ mock_hook_cls.get_hook.return_value.create_agent.return_value =
mock_agent
# run_hitl_review returns JSON string (as stored in
session.current_output)
mock_run_hitl.return_value = '{"text": "Approved summary", "score":
0.9}'
@@ -261,7 +261,7 @@ class TestAgentOperatorExecute:
mock_result = _make_mock_run_result("Initial output")
mock_agent = MagicMock(spec=["run_sync"])
mock_agent.run_sync.return_value = mock_result
- mock_hook_cls.return_value.create_agent.return_value = mock_agent
+ mock_hook_cls.get_hook.return_value.create_agent.return_value =
mock_agent
mock_run_hitl.return_value = "Approved output"
op = AgentOperator(
@@ -289,7 +289,7 @@ class TestAgentOperatorExecute:
mock_result = _make_mock_run_result("Initial output")
mock_agent = MagicMock(spec=["run_sync"])
mock_agent.run_sync.return_value = mock_result
- mock_hook_cls.return_value.create_agent.return_value = mock_agent
+ mock_hook_cls.get_hook.return_value.create_agent.return_value =
mock_agent
mock_run_hitl.side_effect = HITLMaxIterationsError("Task exceeded max
iterations.")
op = AgentOperator(
@@ -361,7 +361,7 @@ class TestAgentOperatorRegenerateWithFeedback:
mock_result.all_messages.return_value = msg_history + [MagicMock()]
mock_agent = MagicMock(spec=["run_sync"])
mock_agent.run_sync.return_value = mock_result
- mock_hook_cls.return_value.create_agent.return_value = mock_agent
+ mock_hook_cls.get_hook.return_value.create_agent.return_value =
mock_agent
op = AgentOperator(
task_id="test",
@@ -391,7 +391,7 @@ class TestAgentOperatorRegenerateWithFeedback:
mock_result.all_messages.return_value = []
mock_agent = MagicMock(spec=["run_sync"])
mock_agent.run_sync.return_value = mock_result
- mock_hook_cls.return_value.create_agent.return_value = mock_agent
+ mock_hook_cls.get_hook.return_value.create_agent.return_value =
mock_agent
op = AgentOperator(
task_id="test",
diff --git a/providers/common/ai/tests/unit/common/ai/operators/test_llm.py
b/providers/common/ai/tests/unit/common/ai/operators/test_llm.py
index d596bb39b87..3a686d014f0 100644
--- a/providers/common/ai/tests/unit/common/ai/operators/test_llm.py
+++ b/providers/common/ai/tests/unit/common/ai/operators/test_llm.py
@@ -53,15 +53,17 @@ class TestLLMOperator:
"""Default output_type=str returns the LLM string directly."""
mock_agent = MagicMock(spec=["run_sync"])
mock_agent.run_sync.return_value = _make_mock_run_result("Paris is the
capital of France.")
- mock_hook_cls.return_value.create_agent.return_value = mock_agent
+ mock_hook_cls.get_hook.return_value.create_agent.return_value =
mock_agent
op = LLMOperator(task_id="test", prompt="What is the capital of
France?", llm_conn_id="my_llm")
result = op.execute(context=MagicMock())
assert result == "Paris is the capital of France."
mock_agent.run_sync.assert_called_once_with("What is the capital of
France?")
- mock_hook_cls.return_value.create_agent.assert_called_once_with(output_type=str, instructions="")
- mock_hook_cls.assert_called_once_with(llm_conn_id="my_llm", model_id=None)
+ mock_hook_cls.get_hook.return_value.create_agent.assert_called_once_with(
+ output_type=str, instructions=""
+ )
+ mock_hook_cls.get_hook.assert_called_once_with("my_llm", hook_params={"model_id": None})
@patch("airflow.providers.common.ai.operators.llm.PydanticAIHook",
autospec=True)
def test_execute_structured_output_with_all_params(self, mock_hook_cls):
@@ -72,7 +74,7 @@ class TestLLMOperator:
mock_agent = MagicMock(spec=["run_sync"])
mock_agent.run_sync.return_value =
_make_mock_run_result(Entities(names=["Alice", "Bob"]))
- mock_hook_cls.return_value.create_agent.return_value = mock_agent
+ mock_hook_cls.get_hook.return_value.create_agent.return_value =
mock_agent
op = LLMOperator(
task_id="test",
@@ -86,8 +88,8 @@ class TestLLMOperator:
result = op.execute(context=MagicMock())
assert result == {"names": ["Alice", "Bob"]}
- mock_hook_cls.assert_called_once_with(llm_conn_id="my_llm",
model_id="openai:gpt-5")
- mock_hook_cls.return_value.create_agent.assert_called_once_with(
+ mock_hook_cls.get_hook.assert_called_once_with("my_llm",
hook_params={"model_id": "openai:gpt-5"})
+
mock_hook_cls.get_hook.return_value.create_agent.assert_called_once_with(
output_type=Entities,
instructions="You are an extractor.",
retries=3,
@@ -126,7 +128,7 @@ class TestLLMOperatorApproval:
mock_agent = MagicMock(spec=["run_sync"])
mock_agent.run_sync.return_value = _make_mock_run_result("LLM
response")
- mock_hook_cls.return_value.create_agent.return_value = mock_agent
+ mock_hook_cls.get_hook.return_value.create_agent.return_value =
mock_agent
op = LLMOperator(
task_id="approval_test",
@@ -152,7 +154,7 @@ class TestLLMOperatorApproval:
mock_agent = MagicMock(spec=["run_sync"])
mock_agent.run_sync.return_value = _make_mock_run_result("draft
output")
- mock_hook_cls.return_value.create_agent.return_value = mock_agent
+ mock_hook_cls.get_hook.return_value.create_agent.return_value =
mock_agent
op = LLMOperator(
task_id="mod_test",
@@ -178,7 +180,7 @@ class TestLLMOperatorApproval:
mock_agent = MagicMock(spec=["run_sync"])
mock_agent.run_sync.return_value = _make_mock_run_result("output")
- mock_hook_cls.return_value.create_agent.return_value = mock_agent
+ mock_hook_cls.get_hook.return_value.create_agent.return_value =
mock_agent
timeout = timedelta(hours=1)
op = LLMOperator(
@@ -207,7 +209,7 @@ class TestLLMOperatorApproval:
mock_agent = MagicMock(spec=["run_sync"])
mock_agent.run_sync.return_value =
_make_mock_run_result(Summary(text="hello"))
- mock_hook_cls.return_value.create_agent.return_value = mock_agent
+ mock_hook_cls.get_hook.return_value.create_agent.return_value =
mock_agent
op = LLMOperator(
task_id="struct_test",
@@ -228,7 +230,7 @@ class TestLLMOperatorApproval:
"""When require_approval=False, execute() returns output directly."""
mock_agent = MagicMock(spec=["run_sync"])
mock_agent.run_sync.return_value = _make_mock_run_result("plain
output")
- mock_hook_cls.return_value.create_agent.return_value = mock_agent
+ mock_hook_cls.get_hook.return_value.create_agent.return_value =
mock_agent
op = LLMOperator(task_id="no_approval", prompt="p",
llm_conn_id="my_llm", require_approval=False)
result = op.execute(context={})
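
Taken together, the structured-output test above corresponds to usage along these lines. The operator kwargs follow the call assertions and docstrings in this file and in the agent tests; the exact signature may differ, and the DAG wiring is omitted:

from pydantic import BaseModel

from airflow.providers.common.ai.operators.llm import LLMOperator


class Entities(BaseModel):
    names: list[str]


extract = LLMOperator(
    task_id="extract_entities",
    prompt="List the people mentioned in the text.",
    llm_conn_id="my_llm",
    model_id="openai:gpt-5",
    system_prompt="You are an extractor.",
    output_type=Entities,
    agent_params={"retries": 3},  # extra kwargs unpacked into create_agent, per the agent tests
)
# execute() returns the validated model as a plain dict, e.g. {"names": ["Alice", "Bob"]}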
diff --git
a/providers/common/ai/tests/unit/common/ai/operators/test_llm_branch.py
b/providers/common/ai/tests/unit/common/ai/operators/test_llm_branch.py
index ffd475f71b8..78e6c231080 100644
--- a/providers/common/ai/tests/unit/common/ai/operators/test_llm_branch.py
+++ b/providers/common/ai/tests/unit/common/ai/operators/test_llm_branch.py
@@ -64,7 +64,7 @@ class TestLLMBranchOperator:
mock_agent = MagicMock(spec=["run_sync"])
mock_agent.run_sync.return_value =
_make_mock_run_result(downstream_enum.task_a)
- mock_hook_cls.return_value.create_agent.return_value = mock_agent
+ mock_hook_cls.get_hook.return_value.create_agent.return_value =
mock_agent
mock_do_branch.return_value = "task_a"
op = LLMBranchOperator(
@@ -93,7 +93,7 @@ class TestLLMBranchOperator:
mock_agent.run_sync.return_value = _make_mock_run_result(
[downstream_enum.task_a, downstream_enum.task_c]
)
- mock_hook_cls.return_value.create_agent.return_value = mock_agent
+ mock_hook_cls.get_hook.return_value.create_agent.return_value =
mock_agent
mock_do_branch.return_value = ["task_a", "task_c"]
op = LLMBranchOperator(
@@ -118,7 +118,7 @@ class TestLLMBranchOperator:
mock_agent = MagicMock(spec=["run_sync"])
mock_agent.run_sync.return_value =
_make_mock_run_result(downstream_enum.task_a)
- mock_hook_cls.return_value.create_agent.return_value = mock_agent
+ mock_hook_cls.get_hook.return_value.create_agent.return_value =
mock_agent
op = LLMBranchOperator(
task_id="test",
@@ -130,7 +130,7 @@ class TestLLMBranchOperator:
op.execute(MagicMock())
- call_kwargs = mock_hook_cls.return_value.create_agent.call_args
+ call_kwargs =
mock_hook_cls.get_hook.return_value.create_agent.call_args
assert call_kwargs.kwargs["instructions"] == "Route tickets to the
right team."
@patch.object(LLMBranchOperator, "do_branch")
@@ -143,7 +143,7 @@ class TestLLMBranchOperator:
mock_agent = MagicMock(spec=["run_sync"])
mock_agent.run_sync.return_value =
_make_mock_run_result(downstream_enum.billing)
- mock_hook_cls.return_value.create_agent.return_value = mock_agent
+ mock_hook_cls.get_hook.return_value.create_agent.return_value =
mock_agent
op = LLMBranchOperator(
task_id="test",
@@ -154,7 +154,7 @@ class TestLLMBranchOperator:
op.execute(MagicMock())
- output_type =
mock_hook_cls.return_value.create_agent.call_args.kwargs["output_type"]
+ output_type =
mock_hook_cls.get_hook.return_value.create_agent.call_args.kwargs["output_type"]
assert {m.value for m in output_type} == {"billing", "auth", "general"}
def test_execute_raises_on_no_downstream_tasks(self):
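
The ``output_type`` assertion above boils down to building an Enum from the downstream task ids, so the model can only answer with a valid branch. A minimal sketch of that derivation, using the task ids from the test:

from enum import Enum

downstream_task_ids = ["billing", "auth", "general"]
# One member per downstream task; member values are the task_ids themselves.
DownstreamTask = Enum("DownstreamTask", {t: t for t in downstream_task_ids})

assert {m.value for m in DownstreamTask} == {"billing", "auth", "general"}
# The agent's chosen member value(s) are then handed to do_branch(): a single
# task_id, or a list of task_ids when multiple branches may be followed.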
diff --git a/providers/common/ai/tests/unit/common/ai/operators/test_llm_sql.py
b/providers/common/ai/tests/unit/common/ai/operators/test_llm_sql.py
index 57bdcb088f0..4ff3158a356 100644
--- a/providers/common/ai/tests/unit/common/ai/operators/test_llm_sql.py
+++ b/providers/common/ai/tests/unit/common/ai/operators/test_llm_sql.py
@@ -100,7 +100,7 @@ class TestLLMSQLQueryOperator:
def test_execute_with_schema_context(self, mock_hook_cls):
"""Operator uses schema_context and returns generated SQL."""
mock_agent = _make_mock_agent("SELECT id, name FROM users WHERE active
= true")
- mock_hook_cls.return_value.create_agent.return_value = mock_agent
+ mock_hook_cls.get_hook.return_value.create_agent.return_value =
mock_agent
op = LLMSQLQueryOperator(
task_id="test",
@@ -116,7 +116,7 @@ class TestLLMSQLQueryOperator:
@patch("airflow.providers.common.ai.operators.llm.PydanticAIHook",
autospec=True)
def test_execute_validation_blocks_unsafe_sql(self, mock_hook_cls):
"""Validation catches unsafe SQL generated by the LLM."""
- mock_hook_cls.return_value.create_agent.return_value =
_make_mock_agent("DROP TABLE users")
+ mock_hook_cls.get_hook.return_value.create_agent.return_value =
_make_mock_agent("DROP TABLE users")
op = LLMSQLQueryOperator(task_id="test", prompt="Delete everything",
llm_conn_id="my_llm")
@@ -126,7 +126,7 @@ class TestLLMSQLQueryOperator:
@patch("airflow.providers.common.ai.operators.llm.PydanticAIHook",
autospec=True)
def test_execute_validation_disabled(self, mock_hook_cls):
"""When validate_sql=False, unsafe SQL is returned without checks."""
- mock_hook_cls.return_value.create_agent.return_value =
_make_mock_agent("DROP TABLE users")
+ mock_hook_cls.get_hook.return_value.create_agent.return_value =
_make_mock_agent("DROP TABLE users")
op = LLMSQLQueryOperator(task_id="test", prompt="Drop it",
llm_conn_id="my_llm", validate_sql=False)
result = op.execute(context=MagicMock())
@@ -136,7 +136,7 @@ class TestLLMSQLQueryOperator:
@patch("airflow.providers.common.ai.operators.llm.PydanticAIHook",
autospec=True)
def test_execute_passes_agent_params(self, mock_hook_cls):
"""agent_params inherited from LLMOperator are unpacked into
create_agent."""
- mock_hook_cls.return_value.create_agent.return_value =
_make_mock_agent("SELECT 1")
+ mock_hook_cls.get_hook.return_value.create_agent.return_value =
_make_mock_agent("SELECT 1")
op = LLMSQLQueryOperator(
task_id="test",
@@ -146,14 +146,14 @@ class TestLLMSQLQueryOperator:
)
op.execute(context=MagicMock())
- create_agent_call = mock_hook_cls.return_value.create_agent.call_args
+ create_agent_call =
mock_hook_cls.get_hook.return_value.create_agent.call_args
assert create_agent_call[1]["retries"] == 3
assert create_agent_call[1]["model_settings"] == {"temperature": 0}
@patch("airflow.providers.common.ai.operators.llm.PydanticAIHook",
autospec=True)
def test_system_prompt_appended_to_sql_instructions(self, mock_hook_cls):
"""User-provided system_prompt is appended to built-in SQL safety
prompt."""
- mock_hook_cls.return_value.create_agent.return_value =
_make_mock_agent("SELECT 1")
+ mock_hook_cls.get_hook.return_value.create_agent.return_value =
_make_mock_agent("SELECT 1")
op = LLMSQLQueryOperator(
task_id="test",
@@ -163,7 +163,7 @@ class TestLLMSQLQueryOperator:
)
op.execute(context=MagicMock())
- instructions =
mock_hook_cls.return_value.create_agent.call_args[1]["instructions"]
+ instructions =
mock_hook_cls.get_hook.return_value.create_agent.call_args[1]["instructions"]
assert "Always use LEFT JOINs." in instructions
# Built-in SQL safety prompt should still be present
assert "Generate only SELECT queries" in instructions
@@ -175,7 +175,7 @@ class TestLLMSQLQueryOperatorSchemaIntrospection:
def test_introspect_schemas_via_db_hook(self, mock_hook_cls):
"""db_conn_id + table_names triggers schema introspection."""
mock_agent = _make_mock_agent("SELECT id FROM users")
- mock_hook_cls.return_value.create_agent.return_value = mock_agent
+ mock_hook_cls.get_hook.return_value.create_agent.return_value =
mock_agent
mock_db_hook = MagicMock(spec=["get_table_schema", "dialect_name"])
mock_db_hook.get_table_schema.return_value = [
@@ -199,7 +199,7 @@ class TestLLMSQLQueryOperatorSchemaIntrospection:
mock_db_hook.get_table_schema.assert_called_once_with("users")
# Verify the system prompt contains the schema info
- instructions =
mock_hook_cls.return_value.create_agent.call_args[1]["instructions"]
+ instructions =
mock_hook_cls.get_hook.return_value.create_agent.call_args[1]["instructions"]
assert "users" in instructions
assert "id INTEGER" in instructions
@@ -362,7 +362,7 @@ class TestLLMSQLQueryOperatorSchemaIntrospection:
mock_engine.get_schema.return_value = "event: TEXT\nts: TIMESTAMP"
mock_agent = _make_mock_agent("SELECT u.id, e.event FROM users u JOIN
events e ON u.id = e.user_id")
- mock_hook_cls.return_value.create_agent.return_value = mock_agent
+ mock_hook_cls.get_hook.return_value.create_agent.return_value =
mock_agent
ds_config = DataSourceConfig(
conn_id="aws_default",
@@ -390,7 +390,7 @@ class TestLLMSQLQueryOperatorSchemaIntrospection:
result = op.execute(context=MagicMock())
assert "SELECT" in result
- instructions =
mock_hook_cls.return_value.create_agent.call_args[1]["instructions"]
+ instructions =
mock_hook_cls.get_hook.return_value.create_agent.call_args[1]["instructions"]
assert "users" in instructions
assert "events" in instructions
assert "event: TEXT\nts: TIMESTAMP" in instructions
@@ -464,7 +464,7 @@ class TestLLMSQLQueryOperatorApproval:
def test_execute_with_approval_defers(self, mock_hook_cls, mock_upsert,
mock_trigger_cls):
"""When require_approval=True, execute() defers after generating and
validating SQL."""
- mock_hook_cls.return_value.create_agent.return_value =
_make_mock_agent(
+ mock_hook_cls.get_hook.return_value.create_agent.return_value =
_make_mock_agent(
"SELECT id FROM users WHERE active"
)
@@ -491,7 +491,7 @@ class TestLLMSQLQueryOperatorApproval:
self, mock_hook_cls, mock_upsert, mock_trigger_cls
):
"""SQL validation runs before defer_for_approval; unsafe SQL is
blocked."""
- mock_hook_cls.return_value.create_agent.return_value =
_make_mock_agent("DROP TABLE users")
+ mock_hook_cls.get_hook.return_value.create_agent.return_value =
_make_mock_agent("DROP TABLE users")
op = LLMSQLQueryOperator(
task_id="sql_unsafe",
@@ -512,7 +512,7 @@ class TestLLMSQLQueryOperatorApproval:
def test_execute_with_approval_and_modifications(self, mock_hook_cls,
mock_upsert, mock_trigger_cls):
"""allow_modifications=True passes editable params."""
- mock_hook_cls.return_value.create_agent.return_value =
_make_mock_agent("SELECT 1")
+ mock_hook_cls.get_hook.return_value.create_agent.return_value =
_make_mock_agent("SELECT 1")
op = LLMSQLQueryOperator(
task_id="sql_mod",
@@ -535,7 +535,7 @@ class TestLLMSQLQueryOperatorApproval:
def test_execute_with_approval_and_timeout(self, mock_hook_cls,
mock_upsert, mock_trigger_cls):
"""approval_timeout is propagated to the trigger."""
- mock_hook_cls.return_value.create_agent.return_value =
_make_mock_agent("SELECT 1")
+ mock_hook_cls.get_hook.return_value.create_agent.return_value =
_make_mock_agent("SELECT 1")
timeout = timedelta(minutes=30)
op = LLMSQLQueryOperator(
@@ -555,7 +555,7 @@ class TestLLMSQLQueryOperatorApproval:
@patch("airflow.providers.common.ai.operators.llm.PydanticAIHook",
autospec=True)
def test_execute_without_approval_returns_sql(self, mock_hook_cls):
"""When require_approval=False, execute() returns the SQL directly."""
- mock_hook_cls.return_value.create_agent.return_value =
_make_mock_agent("SELECT 1")
+ mock_hook_cls.get_hook.return_value.create_agent.return_value =
_make_mock_agent("SELECT 1")
op = LLMSQLQueryOperator(
task_id="no_approval",
@@ -573,7 +573,9 @@ class TestLLMSQLQueryOperatorApproval:
def test_execute_strips_code_fences_before_deferring(self, mock_hook_cls,
mock_upsert, mock_trigger_cls):
"""Markdown code fences are stripped from LLM output before
deferring."""
- mock_hook_cls.return_value.create_agent.return_value =
_make_mock_agent("```sql\nSELECT 1\n```")
+ mock_hook_cls.get_hook.return_value.create_agent.return_value =
_make_mock_agent(
+ "```sql\nSELECT 1\n```"
+ )
op = LLMSQLQueryOperator(
task_id="strip_test",