This is an automated email from the ASF dual-hosted git repository.

kaxilnaik pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 7f6457f1f47 Add Azure OpenAI  & Bedrock support in ``common.ai`` provider (#62816)
7f6457f1f47 is described below

commit 7f6457f1f4796b1bcaa0cd85388b6fa896ca4dbb
Author: Gökhan Çetin <[email protected]>
AuthorDate: Wed Mar 11 22:58:10 2026 +0300

    Add Azure OpenAI  & Bedrock support in ``common.ai`` provider (#62816)
---
 .../ai/docs/operators/llm_schema_compare.rst       |   2 +-
 providers/common/ai/docs/toolsets.rst              |   4 +-
 providers/common/ai/provider.yaml                  | 170 +++++++++
 .../providers/common/ai/decorators/agent.py        |   2 +-
 .../common/ai/example_dags/example_agent.py        |  10 +-
 .../ai/example_dags/example_llm_schema_compare.py  |  10 +-
 .../common/ai/example_dags/example_mcp.py          |   6 +-
 .../providers/common/ai/get_provider_info.py       | 128 +++++++
 .../providers/common/ai/hooks/pydantic_ai.py       | 373 +++++++++++++++---
 .../airflow/providers/common/ai/operators/agent.py |   5 +-
 .../airflow/providers/common/ai/operators/llm.py   |  14 +-
 .../tests/unit/common/ai/decorators/test_agent.py  |  10 +-
 .../ai/tests/unit/common/ai/decorators/test_llm.py |   4 +-
 .../unit/common/ai/decorators/test_llm_branch.py   |   4 +-
 .../ai/decorators/test_llm_schema_compare.py       |   8 +-
 .../unit/common/ai/decorators/test_llm_sql.py      |   4 +-
 .../tests/unit/common/ai/hooks/test_pydantic_ai.py | 418 +++++++++++++++++++--
 .../tests/unit/common/ai/operators/test_agent.py   |  36 +-
 .../ai/tests/unit/common/ai/operators/test_llm.py  |  24 +-
 .../unit/common/ai/operators/test_llm_branch.py    |  12 +-
 .../tests/unit/common/ai/operators/test_llm_sql.py |  36 +-
 21 files changed, 1109 insertions(+), 171 deletions(-)

diff --git a/providers/common/ai/docs/operators/llm_schema_compare.rst b/providers/common/ai/docs/operators/llm_schema_compare.rst
index f6767e04dbe..ebdf3f48994 100644
--- a/providers/common/ai/docs/operators/llm_schema_compare.rst
+++ b/providers/common/ai/docs/operators/llm_schema_compare.rst
@@ -91,7 +91,7 @@ on top, concatenate them:
     LLMSchemaCompareOperator(
         task_id="compare_with_custom_rules",
         prompt="Compare schemas and flag breaking changes",
-        llm_conn_id="pydantic_ai_default",
+        llm_conn_id="pydanticai_default",
         db_conn_ids=["postgres_source", "snowflake_target"],
         table_names=["customers"],
         system_prompt=DEFAULT_SYSTEM_PROMPT
diff --git a/providers/common/ai/docs/toolsets.rst b/providers/common/ai/docs/toolsets.rst
index de708ac9c4c..8d295ac361c 100644
--- a/providers/common/ai/docs/toolsets.rst
+++ b/providers/common/ai/docs/toolsets.rst
@@ -256,7 +256,7 @@ Using Multiple MCP Servers
     AgentOperator(
         task_id="multi_mcp",
         prompt="Get the weather in London and run a calculation",
-        llm_conn_id="pydantic_ai_default",
+        llm_conn_id="pydanticai_default",
         toolsets=[
             MCPToolset(mcp_conn_id="weather_mcp", tool_prefix="weather"),
             MCPToolset(mcp_conn_id="code_runner_mcp", tool_prefix="code"),
@@ -276,7 +276,7 @@ server instances directly — no Airflow connection needed:
     AgentOperator(
         task_id="direct_mcp",
         prompt="What tools are available?",
-        llm_conn_id="pydantic_ai_default",
+        llm_conn_id="pydanticai_default",
         toolsets=[
             MCPServerStreamableHTTP("http://localhost:3001/mcp"),
             MCPServerStdio("uvx", args=["mcp-run-python"]),
diff --git a/providers/common/ai/provider.yaml b/providers/common/ai/provider.yaml
index 4a406a66470..43a98af32a3 100644
--- a/providers/common/ai/provider.yaml
+++ b/providers/common/ai/provider.yaml
@@ -79,6 +79,176 @@ connection-types:
           type:
             - string
             - 'null'
+  - hook-class-name: airflow.providers.common.ai.hooks.pydantic_ai.PydanticAIAzureHook
+    connection-type: pydanticai-azure
+    ui-field-behaviour:
+      hidden-fields:
+        - schema
+        - port
+        - login
+      relabeling:
+        password: API Key
+        host: Azure Endpoint
+      placeholders:
+        host: "https://<resource>.openai.azure.com"
+    conn-fields:
+      model:
+        label: Model
+        description: "Azure model identifier (e.g. azure:gpt-4o)"
+        schema:
+          type:
+            - string
+            - 'null'
+      api_version:
+        label: API Version
+        description: "Azure OpenAI API version (e.g. 2024-07-01-preview). Falls back to OPENAI_API_VERSION."
+        schema:
+          type:
+            - string
+            - 'null'
+  - hook-class-name: airflow.providers.common.ai.hooks.pydantic_ai.PydanticAIBedrockHook
+    connection-type: pydanticai-bedrock
+    ui-field-behaviour:
+      hidden-fields:
+        - schema
+        - port
+        - login
+        - host
+        - password
+      relabeling: {}
+      placeholders: {}
+    conn-fields:
+      model:
+        label: Model
+        description: "Bedrock model identifier (e.g. bedrock:us.anthropic.claude-opus-4-5)"
+        schema:
+          type:
+            - string
+            - 'null'
+      region_name:
+        label: AWS Region
+        description: "AWS region (e.g. us-east-1). Falls back to AWS_DEFAULT_REGION env var."
+        schema:
+          type:
+            - string
+            - 'null'
+      aws_access_key_id:
+        label: AWS Access Key ID
+        description: "IAM access key. Leave empty to use instance role / environment credential chain."
+        schema:
+          type:
+            - string
+            - 'null'
+      aws_secret_access_key:
+        label: AWS Secret Access Key
+        description: "IAM secret key."
+        schema:
+          type:
+            - string
+            - 'null'
+      aws_session_token:
+        label: AWS Session Token
+        description: "Temporary session token (optional)."
+        schema:
+          type:
+            - string
+            - 'null'
+      profile_name:
+        label: AWS Profile Name
+        description: "Named AWS credentials profile (optional)."
+        schema:
+          type:
+            - string
+            - 'null'
+      api_key:
+        label: Bearer Token
+        description: "AWS bearer token (alt. to IAM key/secret). Falls back to AWS_BEARER_TOKEN_BEDROCK."
+        schema:
+          type:
+            - string
+            - 'null'
+      base_url:
+        label: Custom Endpoint URL
+        description: "Override the Bedrock runtime endpoint URL (optional)."
+        schema:
+          type:
+            - string
+            - 'null'
+      aws_read_timeout:
+        label: Read Timeout (s)
+        description: "boto3 read timeout in seconds (float, optional)."
+        schema:
+          type:
+            - number
+            - 'null'
+      aws_connect_timeout:
+        label: Connect Timeout (s)
+        description: "boto3 connect timeout in seconds (float, optional)."
+        schema:
+          type:
+            - number
+            - 'null'
+  - hook-class-name: airflow.providers.common.ai.hooks.pydantic_ai.PydanticAIVertexHook
+    connection-type: pydanticai-vertex
+    ui-field-behaviour:
+      hidden-fields:
+        - schema
+        - port
+        - login
+        - host
+        - password
+      relabeling: {}
+      placeholders: {}
+    conn-fields:
+      model:
+        label: Model
+        description: "Google model identifier (e.g. google-vertex:gemini-2.0-flash)"
+        schema:
+          type:
+            - string
+            - 'null'
+      project:
+        label: GCP Project
+        description: "Google Cloud project ID. Falls back to GOOGLE_CLOUD_PROJECT env var."
+        schema:
+          type:
+            - string
+            - 'null'
+      location:
+        label: Location / Region
+        description: "Vertex AI region (e.g. us-central1). Falls back to GOOGLE_CLOUD_LOCATION env var."
+        schema:
+          type:
+            - string
+            - 'null'
+      vertexai:
+        label: Force Vertex AI Mode
+        description: "Force Vertex AI mode. Auto-detected when project/location/credentials are set."
+        schema:
+          type:
+            - boolean
+            - 'null'
+      api_key:
+        label: API Key
+        description: "Google API key for Gen Language API or Vertex AI. Falls back to GOOGLE_API_KEY."
+        schema:
+          type:
+            - string
+            - 'null'
+      service_account_info:
+        label: Service Account Info
+        description: "Service account key as inline dict (JSON with type, project_id, private_key, etc.)."
+        schema:
+          type:
+            - object
+            - 'null'
+      base_url:
+        label: Custom Endpoint URL
+        description: "Override the Google API base URL (optional)."
+        schema:
+          type:
+            - string
+            - 'null'
   - hook-class-name: airflow.providers.common.ai.hooks.mcp.MCPHook
     connection-type: mcp
     ui-field-behaviour:
diff --git a/providers/common/ai/src/airflow/providers/common/ai/decorators/agent.py b/providers/common/ai/src/airflow/providers/common/ai/decorators/agent.py
index 40c55f630c0..379c8cb65fa 100644
--- a/providers/common/ai/src/airflow/providers/common/ai/decorators/agent.py
+++ b/providers/common/ai/src/airflow/providers/common/ai/decorators/agent.py
@@ -107,7 +107,7 @@ def agent_task(
     Usage::
 
         @task.agent(
-            llm_conn_id="pydantic_ai_default",
+            llm_conn_id="pydanticai_default",
             system_prompt="You are a data analyst.",
             toolsets=[SQLToolset(db_conn_id="postgres_default")],
         )
diff --git a/providers/common/ai/src/airflow/providers/common/ai/example_dags/example_agent.py b/providers/common/ai/src/airflow/providers/common/ai/example_dags/example_agent.py
index 3da0ef0ffaa..d36461e5a0c 100644
--- a/providers/common/ai/src/airflow/providers/common/ai/example_dags/example_agent.py
+++ b/providers/common/ai/src/airflow/providers/common/ai/example_dags/example_agent.py
@@ -36,7 +36,7 @@ def example_agent_operator_sql():
     AgentOperator(
         task_id="analyst",
         prompt="What are the top 5 customers by order count?",
-        llm_conn_id="pydantic_ai_default",
+        llm_conn_id="pydanticai_default",
         system_prompt=(
             "You are a SQL analyst. Use the available tools to explore "
             "the schema and answer the question with data."
@@ -71,7 +71,7 @@ def example_agent_operator_hook():
     AgentOperator(
         task_id="api_explorer",
         prompt="What endpoints are available and what does /status return?",
-        llm_conn_id="pydantic_ai_default",
+        llm_conn_id="pydanticai_default",
         system_prompt="You are an API explorer. Use the tools to discover and call endpoints.",
         toolsets=[
             HookToolset(
@@ -97,7 +97,7 @@ example_agent_operator_hook()
 @dag
 def example_agent_decorator():
     @task.agent(
-        llm_conn_id="pydantic_ai_default",
+        llm_conn_id="pydanticai_default",
         system_prompt="You are a data analyst. Use tools to answer questions.",
         toolsets=[
             SQLToolset(
@@ -133,7 +133,7 @@ def example_agent_structured_output():
         row_count: int
 
     @task.agent(
-        llm_conn_id="pydantic_ai_default",
+        llm_conn_id="pydanticai_default",
         system_prompt="You are a data analyst. Return structured results.",
         output_type=Analysis,
         toolsets=[SQLToolset(db_conn_id="postgres_default")],
@@ -158,7 +158,7 @@ example_agent_structured_output()
 @dag
 def example_agent_chain():
     @task.agent(
-        llm_conn_id="pydantic_ai_default",
+        llm_conn_id="pydanticai_default",
         system_prompt="You are a SQL analyst.",
         toolsets=[SQLToolset(db_conn_id="postgres_default", allowed_tables=["orders"])],
     )
diff --git a/providers/common/ai/src/airflow/providers/common/ai/example_dags/example_llm_schema_compare.py b/providers/common/ai/src/airflow/providers/common/ai/example_dags/example_llm_schema_compare.py
index 0e6d306f7bc..88d3891ead1 100644
--- a/providers/common/ai/src/airflow/providers/common/ai/example_dags/example_llm_schema_compare.py
+++ b/providers/common/ai/src/airflow/providers/common/ai/example_dags/example_llm_schema_compare.py
@@ -29,7 +29,7 @@ def example_llm_schema_compare_basic():
     LLMSchemaCompareOperator(
         task_id="detect_schema_drift",
         prompt="Identify schema mismatches that would break data loading between systems",
-        llm_conn_id="pydantic_ai_default",
+        llm_conn_id="pydanticai_default",
         db_conn_ids=["postgres_default", "snowflake_default"],
         table_names=["customers"],
     )
@@ -49,7 +49,7 @@ def example_llm_schema_compare_full_context():
             "Compare schemas and generate a migration plan. "
             "Flag any differences that would break nightly ETL loads."
         ),
-        llm_conn_id="pydantic_ai_default",
+        llm_conn_id="pydanticai_default",
         db_conn_ids=["postgres_source", "snowflake_target"],
         table_names=["customers", "orders"],
         context_strategy="full",
@@ -74,7 +74,7 @@ def example_llm_schema_compare_with_object_storage():
     LLMSchemaCompareOperator(
         task_id="compare_s3_vs_db",
         prompt="Compare S3 Parquet schema against the Postgres table and flag breaking changes",
-        llm_conn_id="pydantic_ai_default",
+        llm_conn_id="pydanticai_default",
         db_conn_ids=["postgres_default"],
         table_names=["customers"],
         data_sources=[s3_source],
@@ -90,7 +90,7 @@ example_llm_schema_compare_with_object_storage()
 @dag
 def example_llm_schema_compare_decorator():
     @task.llm_schema_compare(
-        llm_conn_id="pydantic_ai_default",
+        llm_conn_id="pydanticai_default",
         db_conn_ids=["postgres_source", "snowflake_target"],
         table_names=["customers"],
     )
@@ -109,7 +109,7 @@ example_llm_schema_compare_decorator()
 @dag
 def example_llm_schema_compare_conditional():
     @task.llm_schema_compare(
-        llm_conn_id="pydantic_ai_default",
+        llm_conn_id="pydanticai_default",
         db_conn_ids=["postgres_source", "snowflake_target"],
         table_names=["customers"],
         context_strategy="full",
diff --git a/providers/common/ai/src/airflow/providers/common/ai/example_dags/example_mcp.py b/providers/common/ai/src/airflow/providers/common/ai/example_dags/example_mcp.py
index adc3dc85b94..320fe8428ce 100644
--- a/providers/common/ai/src/airflow/providers/common/ai/example_dags/example_mcp.py
+++ b/providers/common/ai/src/airflow/providers/common/ai/example_dags/example_mcp.py
@@ -34,7 +34,7 @@ def example_mcp_toolset():
     AgentOperator(
         task_id="mcp_agent",
         prompt="What tools are available? Run the hello tool.",
-        llm_conn_id="pydantic_ai_default",
+        llm_conn_id="pydanticai_default",
         system_prompt="You are a helpful assistant with access to MCP tools.",
         toolsets=[
             MCPToolset(mcp_conn_id="my_mcp_server"),
@@ -59,7 +59,7 @@ def example_mcp_multiple_servers():
     AgentOperator(
         task_id="multi_mcp_agent",
         prompt="Get the weather in London and run a Python calculation: 2**10",
-        llm_conn_id="pydantic_ai_default",
+        llm_conn_id="pydanticai_default",
         system_prompt="You have access to weather and code execution tools.",
         toolsets=[
             MCPToolset(mcp_conn_id="weather_mcp", tool_prefix="weather"),
@@ -84,7 +84,7 @@ example_mcp_multiple_servers()
 #   AgentOperator(
 #       task_id="direct_mcp",
 #       prompt="What tools are available?",
-#       llm_conn_id="pydantic_ai_default",
+#       llm_conn_id="pydanticai_default",
 #       toolsets=[
 #           MCPServerStreamableHTTP("http://localhost:3001/mcp"),
 #           MCPServerStdio("uvx", args=["mcp-run-python"]),
diff --git a/providers/common/ai/src/airflow/providers/common/ai/get_provider_info.py b/providers/common/ai/src/airflow/providers/common/ai/get_provider_info.py
index 0ace3488ef1..90ae393d64d 100644
--- a/providers/common/ai/src/airflow/providers/common/ai/get_provider_info.py
+++ b/providers/common/ai/src/airflow/providers/common/ai/get_provider_info.py
@@ -80,6 +80,134 @@ def get_provider_info():
                     }
                 },
             },
+            {
+                "hook-class-name": "airflow.providers.common.ai.hooks.pydantic_ai.PydanticAIAzureHook",
+                "connection-type": "pydanticai-azure",
+                "ui-field-behaviour": {
+                    "hidden-fields": ["schema", "port", "login"],
+                    "relabeling": {"password": "API Key", "host": "Azure Endpoint"},
+                    "placeholders": {"host": "https://<resource>.openai.azure.com"},
+                },
+                "conn-fields": {
+                    "model": {
+                        "label": "Model",
+                        "description": "Azure model identifier (e.g. azure:gpt-4o)",
+                        "schema": {"type": ["string", "null"]},
+                    },
+                    "api_version": {
+                        "label": "API Version",
+                        "description": "Azure OpenAI API version (e.g. 2024-07-01-preview). Falls back to OPENAI_API_VERSION.",
+                        "schema": {"type": ["string", "null"]},
+                    },
+                },
+            },
+            {
+                "hook-class-name": "airflow.providers.common.ai.hooks.pydantic_ai.PydanticAIBedrockHook",
+                "connection-type": "pydanticai-bedrock",
+                "ui-field-behaviour": {
+                    "hidden-fields": ["schema", "port", "login", "host", "password"],
+                    "relabeling": {},
+                    "placeholders": {},
+                },
+                "conn-fields": {
+                    "model": {
+                        "label": "Model",
+                        "description": "Bedrock model identifier (e.g. bedrock:us.anthropic.claude-opus-4-5)",
+                        "schema": {"type": ["string", "null"]},
+                    },
+                    "region_name": {
+                        "label": "AWS Region",
+                        "description": "AWS region (e.g. us-east-1). Falls back to AWS_DEFAULT_REGION env var.",
+                        "schema": {"type": ["string", "null"]},
+                    },
+                    "aws_access_key_id": {
+                        "label": "AWS Access Key ID",
+                        "description": "IAM access key. Leave empty to use instance role / environment credential chain.",
+                        "schema": {"type": ["string", "null"]},
+                    },
+                    "aws_secret_access_key": {
+                        "label": "AWS Secret Access Key",
+                        "description": "IAM secret key.",
+                        "schema": {"type": ["string", "null"]},
+                    },
+                    "aws_session_token": {
+                        "label": "AWS Session Token",
+                        "description": "Temporary session token (optional).",
+                        "schema": {"type": ["string", "null"]},
+                    },
+                    "profile_name": {
+                        "label": "AWS Profile Name",
+                        "description": "Named AWS credentials profile (optional).",
+                        "schema": {"type": ["string", "null"]},
+                    },
+                    "api_key": {
+                        "label": "Bearer Token",
+                        "description": "AWS bearer token (alt. to IAM key/secret). Falls back to AWS_BEARER_TOKEN_BEDROCK.",
+                        "schema": {"type": ["string", "null"]},
+                    },
+                    "base_url": {
+                        "label": "Custom Endpoint URL",
+                        "description": "Override the Bedrock runtime endpoint URL (optional).",
+                        "schema": {"type": ["string", "null"]},
+                    },
+                    "aws_read_timeout": {
+                        "label": "Read Timeout (s)",
+                        "description": "boto3 read timeout in seconds (float, optional).",
+                        "schema": {"type": ["number", "null"]},
+                    },
+                    "aws_connect_timeout": {
+                        "label": "Connect Timeout (s)",
+                        "description": "boto3 connect timeout in seconds (float, optional).",
+                        "schema": {"type": ["number", "null"]},
+                    },
+                },
+            },
+            {
+                "hook-class-name": "airflow.providers.common.ai.hooks.pydantic_ai.PydanticAIVertexHook",
+                "connection-type": "pydanticai-vertex",
+                "ui-field-behaviour": {
+                    "hidden-fields": ["schema", "port", "login", "host", "password"],
+                    "relabeling": {},
+                    "placeholders": {},
+                },
+                "conn-fields": {
+                    "model": {
+                        "label": "Model",
+                        "description": "Google model identifier (e.g. google-vertex:gemini-2.0-flash)",
+                        "schema": {"type": ["string", "null"]},
+                    },
+                    "project": {
+                        "label": "GCP Project",
+                        "description": "Google Cloud project ID. Falls back to GOOGLE_CLOUD_PROJECT env var.",
+                        "schema": {"type": ["string", "null"]},
+                    },
+                    "location": {
+                        "label": "Location / Region",
+                        "description": "Vertex AI region (e.g. us-central1). Falls back to GOOGLE_CLOUD_LOCATION env var.",
+                        "schema": {"type": ["string", "null"]},
+                    },
+                    "vertexai": {
+                        "label": "Force Vertex AI Mode",
+                        "description": "Force Vertex AI mode. Auto-detected when project/location/credentials are set.",
+                        "schema": {"type": ["boolean", "null"]},
+                    },
+                    "api_key": {
+                        "label": "API Key",
+                        "description": "Google API key for Gen Language API or Vertex AI. Falls back to GOOGLE_API_KEY.",
+                        "schema": {"type": ["string", "null"]},
+                    },
+                    "service_account_info": {
+                        "label": "Service Account Info",
+                        "description": "Service account key as inline dict (JSON with type, project_id, private_key, etc.).",
+                        "schema": {"type": ["object", "null"]},
+                    },
+                    "base_url": {
+                        "label": "Custom Endpoint URL",
+                        "description": "Override the Google API base URL (optional).",
+                        "schema": {"type": ["string", "null"]},
+                    },
+                },
+            },
             {
                 "hook-class-name": 
"airflow.providers.common.ai.hooks.mcp.MCPHook",
                 "connection-type": "mcp",
diff --git a/providers/common/ai/src/airflow/providers/common/ai/hooks/pydantic_ai.py b/providers/common/ai/src/airflow/providers/common/ai/hooks/pydantic_ai.py
index b05f0421e9a..eb4db2d1ec5 100644
--- a/providers/common/ai/src/airflow/providers/common/ai/hooks/pydantic_ai.py
+++ b/providers/common/ai/src/airflow/providers/common/ai/hooks/pydantic_ai.py
@@ -19,33 +19,32 @@ from __future__ import annotations
 from typing import TYPE_CHECKING, Any, TypeVar, overload
 
 from pydantic_ai import Agent
-from pydantic_ai.models import Model, infer_model
-from pydantic_ai.providers import Provider, infer_provider, infer_provider_class
+from pydantic_ai.models import infer_model
+from pydantic_ai.providers import infer_provider, infer_provider_class
 
 from airflow.providers.common.compat.sdk import BaseHook
 
 OutputT = TypeVar("OutputT")
 
 if TYPE_CHECKING:
-    from pydantic_ai.models import KnownModelName
+    from pydantic_ai.models import KnownModelName, Model
 
 
 class PydanticAIHook(BaseHook):
     """
     Hook for LLM access via pydantic-ai.
 
-    Manages connection credentials and model creation. Uses pydantic-ai's
-    model inference to support any provider (OpenAI, Anthropic, Google,
-    Bedrock, Ollama, vLLM, etc.).
+    Covers providers that use a standard ``api_key`` + optional ``base_url``
+    (OpenAI, Anthropic, Groq, Mistral, DeepSeek, Ollama, vLLM, …).
 
-    Connection fields:
-        - **Model** (conn-field): Model in ``provider:model`` format (e.g. ``"anthropic:claude-sonnet-4-20250514"``)
-        - **password**: API key (OpenAI, Anthropic, Groq, Mistral, etc.)
-        - **host**: Base URL (optional — for custom endpoints like Ollama, vLLM, Azure)
+    For cloud providers with non-standard auth use the dedicated subclasses:
+    :class:`PydanticAIAzureHook`, :class:`PydanticAIBedrockHook`,
+    :class:`PydanticAIVertexHook`.
 
-    Cloud providers (Bedrock, Vertex) that use native auth chains should leave
-    password empty and configure environment-based auth (``AWS_PROFILE``,
-    ``GOOGLE_APPLICATION_CREDENTIALS``).
+    Connection fields:
+        - **password**: API key
+        - **host**: Base URL (optional, e.g. ``https://api.openai.com/v1``)
+        - **extra** JSON: ``{"model": "openai:gpt-5.3"}``
 
     :param llm_conn_id: Airflow connection ID for the LLM provider.
     :param model_id: Model identifier in ``provider:model`` format (e.g. ``"openai:gpt-5.3"``).
@@ -59,12 +58,16 @@ class PydanticAIHook(BaseHook):
 
     def __init__(
         self,
-        llm_conn_id: str = default_conn_name,
+        llm_conn_id: str | None = None,
         model_id: str | None = None,
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
-        self.llm_conn_id = llm_conn_id
+        # Resolve at runtime so each subclass uses its own default_conn_name.
+        # A bare `llm_conn_id: str = default_conn_name` would bind the *base*
+        # class value for all subclasses because Python evaluates default
+        # argument values at class-definition time.
+        self.llm_conn_id = llm_conn_id if llm_conn_id is not None else self.default_conn_name
         self.model_id = model_id
         self._model: Model | None = None
 
@@ -75,62 +78,97 @@ class PydanticAIHook(BaseHook):
             "hidden_fields": ["schema", "port", "login"],
             "relabeling": {"password": "API Key"},
             "placeholders": {
-                "host": "https://api.openai.com/v1 (optional, for custom endpoints)",
+                "host": "https://api.openai.com/v1  (optional, for custom endpoints / Ollama)",
+                "extra": '{"model": "openai:gpt-5.3"}',
             },
         }
 
+    # ------------------------------------------------------------------
+    # Core connection / agent API
+    # ------------------------------------------------------------------
+
+    def _get_provider_kwargs(
+        self,
+        api_key: str | None,
+        base_url: str | None,
+        extra: dict[str, Any],
+    ) -> dict[str, Any]:
+        """
+        Return the kwargs to pass to the provider constructor.
+
+        Subclasses override this method to map their connection fields to the
+        parameters expected by their specific provider class.  The base
+        implementation handles the common ``api_key`` / ``base_url`` pattern
+        used by OpenAI, Anthropic, Groq, Mistral, Ollama, and most other
+        providers.
+
+        :param api_key: Value of ``conn.password``.
+        :param base_url: Value of ``conn.host``.
+        :param extra: Deserialized ``conn.extra`` JSON.
+        :return: Kwargs forwarded to ``provider_cls(**kwargs)``.  Empty dict
+            signals that no explicit credentials are available and the hook
+            should fall back to environment-variable–based auth.
+        """
+        kwargs: dict[str, Any] = {}
+        if api_key:
+            kwargs["api_key"] = api_key
+        if base_url:
+            kwargs["base_url"] = base_url
+        return kwargs
+
     def get_conn(self) -> Model:
         """
-        Return a configured pydantic-ai Model.
+        Return a configured pydantic-ai ``Model``.
 
-        Reads API key from connection password, base_url from connection host,
-        and model from (in priority order):
+        Resolution order:
 
-        1. ``model_id`` parameter on the hook
-        2. ``extra["model"]`` on the connection (set by the "Model" conn-field in the UI)
+        1. **Explicit credentials** — when :meth:`_get_provider_kwargs` returns
+           a non-empty dict the provider class is instantiated with those kwargs
+           and wrapped in a ``provider_factory``.
+        2. **Default resolution** — delegates to pydantic-ai ``infer_model``
+           which reads standard env vars (``OPENAI_API_KEY``, ``AWS_PROFILE``, …).
 
-        The result is cached for the lifetime of this hook instance.
+        The resolved model is cached for the lifetime of this hook instance.
         """
         if self._model is not None:
             return self._model
 
         conn = self.get_connection(self.llm_conn_id)
-        model_name: str | KnownModelName = self.model_id or conn.extra_dejson.get("model", "")
+
+        extra: dict[str, Any] = conn.extra_dejson
+        model_name: str | KnownModelName = self.model_id or extra.get("model", "")
         if not model_name:
             raise ValueError(
                 "No model specified. Set model_id on the hook or the Model 
field on the connection."
             )
-        api_key = conn.password
-        base_url = conn.host or None
 
-        if not api_key and not base_url:
-            # No credentials to inject — use default provider resolution
-            # (picks up env vars like OPENAI_API_KEY, AWS_PROFILE, etc.)
-            self._model = infer_model(model_name)
+        api_key: str | None = conn.password or None
+        base_url: str | None = conn.host or None
+
+        provider_kwargs = self._get_provider_kwargs(api_key, base_url, extra)
+        if provider_kwargs:
+            _kwargs = provider_kwargs  # capture for closure
+            self.log.info(
+                "Using explicit credentials for provider with model '%s': %s",
+                model_name,
+                list(provider_kwargs),
+            )
+
+            def _provider_factory(pname: str) -> Any:
+                try:
+                    return infer_provider_class(pname)(**_kwargs)
+                except TypeError:
+                    self.log.warning(
+                        "Provider '%s' rejected kwargs %s; falling back to env-var auth",
+                        pname,
+                        list(_kwargs),
+                    )
+                    return infer_provider(pname)
+
+            self._model = infer_model(model_name, provider_factory=_provider_factory)
             return self._model
 
-        def _provider_factory(provider_name: str) -> Provider[Any]:
-            """
-            Create a provider with credentials from the Airflow connection.
-
-            Falls back to default provider resolution if the provider's constructor
-            doesn't accept api_key/base_url (e.g. Google Vertex, Bedrock).
-            """
-            provider_cls = infer_provider_class(provider_name)
-            kwargs: dict[str, Any] = {}
-            if api_key:
-                kwargs["api_key"] = api_key
-            if base_url:
-                kwargs["base_url"] = base_url
-            try:
-                return provider_cls(**kwargs)
-            except TypeError:
-                # Provider doesn't accept these kwargs (e.g. Google Vertex/GLA
-                # use ADC, Bedrock uses boto session). Fall back to default
-                # provider resolution which reads credentials from the environment.
-                return infer_provider(provider_name)
-
-        self._model = infer_model(model_name, provider_factory=_provider_factory)
+        self._model = infer_model(model_name)
         return self._model
 
     @overload
@@ -157,13 +195,238 @@ class PydanticAIHook(BaseHook):
         """
         Test connection by resolving the model.
 
-        Validates that the model string is valid, the provider package is
-        installed, and the provider class can be instantiated. Does NOT make an
-        LLM API call — that would be expensive, flaky, and fail for reasons
-        unrelated to connectivity (quotas, billing, rate limits).
+        Validates that the model string is valid and the provider class can be
+        instantiated with the supplied credentials.  Does NOT make an LLM API
+        call — that would be expensive and fail for reasons unrelated to
+        connectivity (quotas, billing, rate limits).
         """
         try:
             self.get_conn()
             return True, "Model resolved successfully."
         except Exception as e:
             return False, str(e)
+
+
+class PydanticAIAzureHook(PydanticAIHook):
+    """
+    Hook for Azure OpenAI via pydantic-ai.
+
+    Connection fields:
+        - **password**: Azure API key
+        - **host**: Azure endpoint (e.g. ``https://<resource>.openai.azure.com``)
+        - **extra** JSON::
+
+            {"model": "azure:gpt-4o", "api_version": "2024-07-01-preview"}
+
+    :param llm_conn_id: Airflow connection ID.
+    :param model_id: Model identifier, e.g. ``"azure:gpt-4o"``.
+    """
+
+    conn_type = "pydanticai-azure"
+    default_conn_name = "pydanticai_azure_default"
+    hook_name = "Pydantic AI (Azure OpenAI)"
+
+    @staticmethod
+    def get_ui_field_behaviour() -> dict[str, Any]:
+        """Return custom field behaviour for the Airflow connection form."""
+        return {
+            "hidden_fields": ["schema", "port", "login"],
+            "relabeling": {"password": "API Key", "host": "Azure Endpoint"},
+            "placeholders": {
+                "host": "https://<resource>.openai.azure.com",
+                "extra": '{"model": "azure:gpt-4o", "api_version": "2024-07-01-preview"}',
+            },
+        }
+
+    def _get_provider_kwargs(
+        self,
+        api_key: str | None,
+        base_url: str | None,
+        extra: dict[str, Any],
+    ) -> dict[str, Any]:
+        kwargs: dict[str, Any] = {}
+        if api_key:
+            kwargs["api_key"] = api_key
+        if base_url:
+            kwargs["azure_endpoint"] = base_url
+        if extra.get("api_version"):
+            kwargs["api_version"] = extra["api_version"]
+        return kwargs
+
+
+class PydanticAIBedrockHook(PydanticAIHook):
+    """
+    Hook for AWS Bedrock via pydantic-ai.
+
+    Credentials are resolved in order:
+
+    1. IAM keys from ``extra`` (``aws_access_key_id`` + ``aws_secret_access_key``,
+       optionally ``aws_session_token``).
+    2. Bearer token in ``extra`` (``api_key``, maps to env ``AWS_BEARER_TOKEN_BEDROCK``).
+    3. Environment-variable / instance-role chain (``AWS_PROFILE``, IAM role, …)
+       when no explicit keys are provided.
+
+    Connection fields:
+        - **extra** JSON::
+
+            {
+              "model": "bedrock:us.anthropic.claude-opus-4-5",
+              "region_name": "us-east-1",
+              "aws_access_key_id": "AKIA...",
+              "aws_secret_access_key": "...",
+              "aws_session_token": "...",
+              "profile_name": "my-aws-profile",
+              "api_key": "bearer-token",
+              "base_url": "https://custom-bedrock-endpoint",
+              "aws_read_timeout": 60.0,
+              "aws_connect_timeout": 10.0
+            }
+
+          Leave ``aws_access_key_id`` / ``aws_secret_access_key`` and ``api_key``
+          empty to use the default AWS credential chain.
+
+    :param llm_conn_id: Airflow connection ID.
+    :param model_id: Model identifier, e.g. ``"bedrock:us.anthropic.claude-opus-4-5"``.
+    """
+
+    conn_type = "pydanticai-bedrock"
+    default_conn_name = "pydanticai_bedrock_default"
+    hook_name = "Pydantic AI (AWS Bedrock)"
+
+    @staticmethod
+    def get_ui_field_behaviour() -> dict[str, Any]:
+        """Return custom field behaviour for the Airflow connection form."""
+        return {
+            "hidden_fields": ["schema", "port", "login", "host", "password"],
+            "relabeling": {},
+            "placeholders": {
+                "extra": (
+                    '{"model": "bedrock:us.anthropic.claude-opus-4-5", '
+                    '"region_name": "us-east-1"}'
+                    "  — leave aws_access_key_id empty for IAM role / env-var auth"
+                ),
+            },
+        }
+
+    def _get_provider_kwargs(
+        self,
+        api_key: str | None,
+        base_url: str | None,
+        extra: dict[str, Any],
+    ) -> dict[str, Any]:
+        """
+        Return kwargs for ``BedrockProvider``.
+
+        .. note::
+            The ``api_key`` and ``base_url`` parameters (sourced from
+            ``conn.password`` and ``conn.host``) are intentionally ignored here.
+            Bedrock connections hide those fields in the UI; all config is
+            stored in ``extra`` instead.  The ``api_key`` and ``base_url``
+            keys below refer to *extra* fields, not the method parameters.
+        """
+        _str_keys = (
+            "aws_access_key_id",
+            "aws_secret_access_key",
+            "aws_session_token",
+            "region_name",
+            "profile_name",
+            # Bearer-token auth (alternative to IAM key/secret).
+            # Maps to AWS_BEARER_TOKEN_BEDROCK env var.
+            "api_key",
+            # Custom Bedrock runtime endpoint.
+            "base_url",
+        )
+        kwargs: dict[str, Any] = {k: extra[k] for k in _str_keys if extra.get(k)}
+        # BedrockProvider expects float for timeout values; JSON integers must be coerced.
+        for _timeout_key in ("aws_read_timeout", "aws_connect_timeout"):
+            if extra.get(_timeout_key):
+                kwargs[_timeout_key] = float(extra[_timeout_key])
+        return kwargs
+
+
+class PydanticAIVertexHook(PydanticAIHook):
+    """
+    Hook for Google Vertex AI (or Generative Language API) via pydantic-ai.
+
+    Credentials are resolved in order:
+
+    1. ``service_account_info`` (JSON object) in ``extra``
+       — loaded into a ``google.auth.credentials.Credentials``
+       object and passed as ``credentials`` to ``GoogleProvider``.
+    2. ``api_key`` in ``extra`` — for Generative Language API (non-Vertex) or
+       Vertex API-key auth.
+    3. Application Default Credentials (``GOOGLE_APPLICATION_CREDENTIALS``,
+       ``gcloud auth application-default login``, Workload Identity, …) when
+       no explicit credentials are provided.
+
+    Connection fields:
+        - **extra** JSON::
+
+            {
+                "model": "google-vertex:gemini-2.0-flash",
+                "project": "my-gcp-project",
+                "location": "us-central1",
+                "service_account_info": {...},
+                "vertexai": true
+            }
+
+        Use ``"service_account_info"`` to embed the service-account JSON directly
+        (as an object, not a string path).
+
+        Set ``"vertexai": true`` to force Vertex AI mode when only ``api_key`` is
+        provided.  Omit ``vertexai`` for the Generative Language API (GLA).
+
+    :param llm_conn_id: Airflow connection ID.
+    :param model_id: Model identifier, e.g. ``"google-vertex:gemini-2.0-flash"``.
+    """
+
+    conn_type = "pydanticai-vertex"
+    default_conn_name = "pydanticai_vertex_default"
+    hook_name = "Pydantic AI (Google Vertex AI)"
+
+    @staticmethod
+    def get_ui_field_behaviour() -> dict[str, Any]:
+        """Return custom field behaviour for the Airflow connection form."""
+        return {
+            "hidden_fields": ["schema", "port", "login", "host", "password"],
+            "relabeling": {},
+            "placeholders": {
+                "extra": (
+                    '{"model": "google-vertex:gemini-2.0-flash", '
+                    '"project": "my-project", "location": "us-central1", "vertexai": true}'
+                    "  — add service_account_info (object) for SA auth;"
+                    " omit both to use Application Default Credentials"
+                ),
+            },
+        }
+
+    def _get_provider_kwargs(
+        self,
+        api_key: str | None,
+        base_url: str | None,
+        extra: dict[str, Any],
+    ) -> dict[str, Any]:
+        sa_info = extra.get("service_account_info")
+        kwargs: dict[str, Any] = {}
+
+        # Direct GoogleProvider scalar kwargs.
+        for _key in ("api_key", "project", "location", "base_url"):
+            if extra.get(_key):
+                kwargs[_key] = extra[_key]
+
+        # vertexai forces Vertex AI mode; parse strings (URI extras) so "false" is not truthy.
+        flag = extra.get("vertexai")
+        if flag is not None:
+            kwargs["vertexai"] = flag.lower() in ("true", "1") if isinstance(flag, str) else bool(flag)
+
+        # Service-account credentials — loaded lazily to avoid importing
+        # google-auth on non-Vertex code paths (optional heavy dependency).
+        if sa_info:
+            from google.oauth2 import service_account  # lazy: optional dep
+
+            kwargs["credentials"] = service_account.Credentials.from_service_account_info(
+                sa_info,
+                scopes=["https://www.googleapis.com/auth/cloud-platform"],
+            )
+
+        return kwargs
diff --git 
a/providers/common/ai/src/airflow/providers/common/ai/operators/agent.py 
b/providers/common/ai/src/airflow/providers/common/ai/operators/agent.py
index d9c028b4b57..da8ecafc1c5 100644
--- a/providers/common/ai/src/airflow/providers/common/ai/operators/agent.py
+++ b/providers/common/ai/src/airflow/providers/common/ai/operators/agent.py
@@ -173,7 +173,10 @@ class AgentOperator(BaseOperator, HITLReviewMixin):
     @cached_property
     def llm_hook(self) -> PydanticAIHook:
-        """Return PydanticAIHook for the configured LLM connection."""
-        return PydanticAIHook(llm_conn_id=self.llm_conn_id, model_id=self.model_id)
+        """Return the PydanticAIHook subclass matching the connection's ``conn_type``."""
+        hook_params = {
+            "model_id": self.model_id,
+        }
+        return PydanticAIHook.get_hook(self.llm_conn_id, hook_params=hook_params)
 
     def _build_agent(self) -> Agent[None, Any]:
         """Build and return a pydantic-ai Agent from the operator's config."""
diff --git 
a/providers/common/ai/src/airflow/providers/common/ai/operators/llm.py 
b/providers/common/ai/src/airflow/providers/common/ai/operators/llm.py
index b5a2c80bd11..73000eb50a3 100644
--- a/providers/common/ai/src/airflow/providers/common/ai/operators/llm.py
+++ b/providers/common/ai/src/airflow/providers/common/ai/operators/llm.py
@@ -101,8 +101,18 @@ class LLMOperator(BaseOperator, LLMApprovalMixin):
 
     @cached_property
     def llm_hook(self) -> PydanticAIHook:
-        """Return PydanticAIHook for the configured LLM connection."""
-        return PydanticAIHook(llm_conn_id=self.llm_conn_id, model_id=self.model_id)
+        """
+        Return the correct PydanticAIHook subclass for the configured connection.
+
+        Delegates to :meth:`~PydanticAIHook.get_hook` which looks up
+        the connection's ``conn_type`` and instantiates the matching subclass
+        (e.g. :class:`~airflow.providers.common.ai.hooks.pydantic_ai.PydanticAIAzureHook`
+        for ``pydanticai-azure`` connections).
+        """
+        hook_params = {
+            "model_id": self.model_id,
+        }
+        return PydanticAIHook.get_hook(self.llm_conn_id, hook_params=hook_params)
 
     def execute(self, context: Context) -> Any:
         agent: Agent[None, Any] = self.llm_hook.create_agent(
diff --git a/providers/common/ai/tests/unit/common/ai/decorators/test_agent.py 
b/providers/common/ai/tests/unit/common/ai/decorators/test_agent.py
index 52ed82c53cd..f5766325140 100644
--- a/providers/common/ai/tests/unit/common/ai/decorators/test_agent.py
+++ b/providers/common/ai/tests/unit/common/ai/decorators/test_agent.py
@@ -46,7 +46,7 @@ class TestAgentDecoratedOperator:
         """The callable's return value becomes the agent prompt."""
         mock_agent = MagicMock(spec=["run_sync"])
         mock_agent.run_sync.return_value = _make_mock_run_result("The top 
customer is Acme Corp.")
-        mock_hook_cls.return_value.create_agent.return_value = mock_agent
+        mock_hook_cls.get_hook.return_value.create_agent.return_value = 
mock_agent
 
         def my_prompt():
             return "Who is our top customer?"
@@ -78,7 +78,7 @@ class TestAgentDecoratedOperator:
         """op_kwargs are resolved by the callable to build the prompt."""
         mock_agent = MagicMock(spec=["run_sync"])
         mock_agent.run_sync.return_value = _make_mock_run_result("done")
-        mock_hook_cls.return_value.create_agent.return_value = mock_agent
+        mock_hook_cls.get_hook.return_value.create_agent.return_value = 
mock_agent
 
         def my_prompt(topic):
             return f"Analyze {topic}"
@@ -99,7 +99,7 @@ class TestAgentDecoratedOperator:
         """Toolsets passed to the decorator are forwarded to the agent."""
         mock_agent = MagicMock(spec=["run_sync"])
         mock_agent.run_sync.return_value = _make_mock_run_result("result")
-        mock_hook_cls.return_value.create_agent.return_value = mock_agent
+        mock_hook_cls.get_hook.return_value.create_agent.return_value = 
mock_agent
 
         mock_toolset = MagicMock()
 
@@ -111,7 +111,7 @@ class TestAgentDecoratedOperator:
         )
         op.execute(context={})
 
-        create_call = mock_hook_cls.return_value.create_agent.call_args
+        create_call = 
mock_hook_cls.get_hook.return_value.create_agent.call_args
         passed_toolsets = create_call[1]["toolsets"]
         assert len(passed_toolsets) == 1
         assert isinstance(passed_toolsets[0], LoggingToolset)
@@ -126,7 +126,7 @@ class TestAgentDecoratedOperator:
 
         mock_agent = MagicMock(spec=["run_sync"])
         mock_agent.run_sync.return_value = 
_make_mock_run_result(Summary(text="Great results"))
-        mock_hook_cls.return_value.create_agent.return_value = mock_agent
+        mock_hook_cls.get_hook.return_value.create_agent.return_value = 
mock_agent
 
         op = _AgentDecoratedOperator(
             task_id="test",
diff --git a/providers/common/ai/tests/unit/common/ai/decorators/test_llm.py 
b/providers/common/ai/tests/unit/common/ai/decorators/test_llm.py
index ff3240529ff..21f663e177a 100644
--- a/providers/common/ai/tests/unit/common/ai/decorators/test_llm.py
+++ b/providers/common/ai/tests/unit/common/ai/decorators/test_llm.py
@@ -44,7 +44,7 @@ class TestLLMDecoratedOperator:
         """The callable's return value becomes the LLM prompt."""
         mock_agent = MagicMock(spec=["run_sync"])
         mock_agent.run_sync.return_value = _make_mock_run_result("This is a 
summary.")
-        mock_hook_cls.return_value.create_agent.return_value = mock_agent
+        mock_hook_cls.get_hook.return_value.create_agent.return_value = 
mock_agent
 
         def my_prompt():
             return "Summarize this text"
@@ -76,7 +76,7 @@ class TestLLMDecoratedOperator:
         """op_kwargs are resolved by the callable to build the prompt."""
         mock_agent = MagicMock(spec=["run_sync"])
         mock_agent.run_sync.return_value = _make_mock_run_result("done")
-        mock_hook_cls.return_value.create_agent.return_value = mock_agent
+        mock_hook_cls.get_hook.return_value.create_agent.return_value = 
mock_agent
 
         def my_prompt(topic):
             return f"Summarize {topic}"
diff --git 
a/providers/common/ai/tests/unit/common/ai/decorators/test_llm_branch.py 
b/providers/common/ai/tests/unit/common/ai/decorators/test_llm_branch.py
index 8daf2799b5f..6620e505db1 100644
--- a/providers/common/ai/tests/unit/common/ai/decorators/test_llm_branch.py
+++ b/providers/common/ai/tests/unit/common/ai/decorators/test_llm_branch.py
@@ -49,7 +49,7 @@ class TestLLMBranchDecoratedOperator:
 
         mock_agent = MagicMock(spec=["run_sync"])
         mock_agent.run_sync.return_value = 
_make_mock_run_result(downstream_enum.positive)
-        mock_hook_cls.return_value.create_agent.return_value = mock_agent
+        mock_hook_cls.get_hook.return_value.create_agent.return_value = 
mock_agent
         mock_do_branch.return_value = "positive"
 
         def my_prompt():
@@ -92,7 +92,7 @@ class TestLLMBranchDecoratedOperator:
 
         mock_agent = MagicMock(spec=["run_sync"])
         mock_agent.run_sync.return_value = 
_make_mock_run_result(downstream_enum.task_a)
-        mock_hook_cls.return_value.create_agent.return_value = mock_agent
+        mock_hook_cls.get_hook.return_value.create_agent.return_value = 
mock_agent
 
         def my_prompt(ticket_type):
             return f"Route this {ticket_type} ticket"
diff --git 
a/providers/common/ai/tests/unit/common/ai/decorators/test_llm_schema_compare.py
 
b/providers/common/ai/tests/unit/common/ai/decorators/test_llm_schema_compare.py
index 7722ea4e450..c271c94be00 100644
--- 
a/providers/common/ai/tests/unit/common/ai/decorators/test_llm_schema_compare.py
+++ 
b/providers/common/ai/tests/unit/common/ai/decorators/test_llm_schema_compare.py
@@ -61,7 +61,9 @@ class TestLLMSchemaCompareDecoratedOperator:
     @patch.object(LLMSchemaCompareOperator, "_build_schema_context", 
return_value="mocked schema")
     def test_execute_calls_callable_and_uses_result_as_prompt(self, 
mock_build_ctx, mock_hook_cls):
         """The user's callable return value becomes the LLM prompt."""
-        mock_hook_cls.return_value.create_agent.return_value = 
_make_mock_agent(_make_compare_result())
+        mock_hook_cls.get_hook.return_value.create_agent.return_value = 
_make_mock_agent(
+            _make_compare_result()
+        )
 
         def my_prompt_fn():
             return "Compare schemas and flag breaking changes"
@@ -99,7 +101,9 @@ class TestLLMSchemaCompareDecoratedOperator:
     @patch.object(LLMSchemaCompareOperator, "_build_schema_context", 
return_value="mocked schema")
     def test_execute_merges_op_kwargs_into_callable(self, mock_build_ctx, 
mock_hook_cls):
         """op_kwargs are resolved by the callable to build the prompt."""
-        mock_hook_cls.return_value.create_agent.return_value = 
_make_mock_agent(_make_compare_result())
+        mock_hook_cls.get_hook.return_value.create_agent.return_value = 
_make_mock_agent(
+            _make_compare_result()
+        )
 
         def my_prompt_fn(target_env):
             return f"Compare schemas for {target_env} environment"
diff --git 
a/providers/common/ai/tests/unit/common/ai/decorators/test_llm_sql.py 
b/providers/common/ai/tests/unit/common/ai/decorators/test_llm_sql.py
index a1b7e3e3f36..0f3f943926a 100644
--- a/providers/common/ai/tests/unit/common/ai/decorators/test_llm_sql.py
+++ b/providers/common/ai/tests/unit/common/ai/decorators/test_llm_sql.py
@@ -44,7 +44,7 @@ class TestLLMSQLDecoratedOperator:
         """The user's callable return value becomes the LLM prompt."""
         mock_agent = MagicMock(spec=["run_sync"])
         mock_agent.run_sync.return_value = _make_mock_run_result("SELECT 1")
-        mock_hook_cls.return_value.create_agent.return_value = mock_agent
+        mock_hook_cls.get_hook.return_value.create_agent.return_value = 
mock_agent
 
         def my_prompt_fn():
             return "Get all users"
@@ -76,7 +76,7 @@ class TestLLMSQLDecoratedOperator:
         """op_kwargs are resolved by the callable to build the prompt."""
         mock_agent = MagicMock(spec=["run_sync"])
         mock_agent.run_sync.return_value = _make_mock_run_result("SELECT 1")
-        mock_hook_cls.return_value.create_agent.return_value = mock_agent
+        mock_hook_cls.get_hook.return_value.create_agent.return_value = 
mock_agent
 
         def my_prompt_fn(table_name):
             return f"Get all rows from {table_name}"
diff --git a/providers/common/ai/tests/unit/common/ai/hooks/test_pydantic_ai.py 
b/providers/common/ai/tests/unit/common/ai/hooks/test_pydantic_ai.py
index 5a2ecb2b1fa..993824e1c14 100644
--- a/providers/common/ai/tests/unit/common/ai/hooks/test_pydantic_ai.py
+++ b/providers/common/ai/tests/unit/common/ai/hooks/test_pydantic_ai.py
@@ -16,13 +16,20 @@
 # under the License.
 from __future__ import annotations
 
+import json
+import sys
 from unittest.mock import MagicMock, patch
 
 import pytest
 from pydantic_ai.models import Model
 
 from airflow.models.connection import Connection
-from airflow.providers.common.ai.hooks.pydantic_ai import PydanticAIHook
+from airflow.providers.common.ai.hooks.pydantic_ai import (
+    PydanticAIAzureHook,
+    PydanticAIBedrockHook,
+    PydanticAIHook,
+    PydanticAIVertexHook,
+)
 
 
 class TestPydanticAIHookInit:
@@ -36,6 +43,19 @@ class TestPydanticAIHookInit:
         assert hook.llm_conn_id == "my_llm"
         assert hook.model_id == "openai:gpt-5.3"
 
+    def test_azure_hook_uses_own_default_conn_name(self):
+        """Subclass default_conn_name is used, not the base class value."""
+        hook = PydanticAIAzureHook()
+        assert hook.llm_conn_id == "pydanticai_azure_default"
+
+    def test_bedrock_hook_uses_own_default_conn_name(self):
+        hook = PydanticAIBedrockHook()
+        assert hook.llm_conn_id == "pydanticai_bedrock_default"
+
+    def test_vertex_hook_uses_own_default_conn_name(self):
+        hook = PydanticAIVertexHook()
+        assert hook.llm_conn_id == "pydanticai_vertex_default"
+
 
 class TestPydanticAIHookGetConn:
     @patch("airflow.providers.common.ai.hooks.pydantic_ai.infer_model", 
autospec=True)
@@ -175,36 +195,6 @@ class TestPydanticAIHookGetConn:
         assert first is second
         mock_infer_model.assert_called_once()
 
-    @patch("airflow.providers.common.ai.hooks.pydantic_ai.infer_provider", 
autospec=True)
-    
@patch("airflow.providers.common.ai.hooks.pydantic_ai.infer_provider_class", 
autospec=True)
-    @patch("airflow.providers.common.ai.hooks.pydantic_ai.infer_model", 
autospec=True)
-    def test_provider_factory_falls_back_on_unsupported_kwargs(
-        self, mock_infer_model, mock_infer_provider_class, mock_infer_provider
-    ):
-        """If a provider rejects api_key/base_url, fall back to default 
resolution."""
-        mock_infer_model.return_value = MagicMock(spec=Model)
-        mock_fallback_provider = MagicMock()
-        mock_infer_provider.return_value = mock_fallback_provider
-        # Simulate a provider that doesn't accept api_key/base_url
-        mock_infer_provider_class.return_value = 
MagicMock(side_effect=TypeError("unexpected keyword"))
-
-        hook = PydanticAIHook(llm_conn_id="test_conn", 
model_id="google:gemini-2.0-flash")
-        conn = Connection(
-            conn_id="test_conn",
-            conn_type="pydanticai",
-            password="some-key",
-        )
-        with patch.object(hook, "get_connection", return_value=conn):
-            hook.get_conn()
-
-        factory = mock_infer_model.call_args[1]["provider_factory"]
-        result = factory("google-gla")
-
-        # Should have tried provider_cls first, then fallen back to 
infer_provider
-        
mock_infer_provider_class.return_value.assert_called_once_with(api_key="some-key")
-        mock_infer_provider.assert_called_with("google-gla")
-        assert result is mock_fallback_provider
-
 
 class TestPydanticAIHookCreateAgent:
     @patch("airflow.providers.common.ai.hooks.pydantic_ai.infer_model", 
autospec=True)
@@ -296,3 +286,369 @@ class TestPydanticAIHookTestConnection:
 
         assert success is False
         assert "No model specified" in message
+
+
+# ---------------------------------------------------------------------------
+# Subclass hook tests
+# ---------------------------------------------------------------------------
+
+
+class TestPydanticAIAzureHook:
+    """Tests for PydanticAIAzureHook."""
+
+    def test_conn_type(self):
+        assert PydanticAIAzureHook.conn_type == "pydanticai-azure"
+
+    def test_hook_name(self):
+        assert "Azure" in PydanticAIAzureHook.hook_name
+
+    def test_ui_field_behaviour_relabels_host(self):
+        behaviour = PydanticAIAzureHook.get_ui_field_behaviour()
+        assert behaviour["relabeling"].get("host") == "Azure Endpoint"
+
+    def test_get_provider_kwargs_maps_azure_endpoint(self):
+        hook = PydanticAIAzureHook.__new__(PydanticAIAzureHook)
+        result = hook._get_provider_kwargs(
+            "my-key",
+            "https://myresource.openai.azure.com",
+            {"model": "azure:gpt-4o", "api_version": "2024-07-01-preview"},
+        )
+        assert result["azure_endpoint"] == "https://myresource.openai.azure.com"
+        assert result["api_key"] == "my-key"
+        assert result["api_version"] == "2024-07-01-preview"
+        assert "base_url" not in result
+
+    def test_get_provider_kwargs_omits_none_api_key(self):
+        hook = PydanticAIAzureHook.__new__(PydanticAIAzureHook)
+        result = hook._get_provider_kwargs(
+            None,
+            "https://myresource.openai.azure.com",
+            {"model": "azure:gpt-4o", "api_version": "2024-07-01-preview"},
+        )
+        assert "api_key" not in result
+        assert result["azure_endpoint"] == "https://myresource.openai.azure.com"
+
+    def test_get_provider_kwargs_omits_azure_endpoint_when_no_host(self):
+        hook = PydanticAIAzureHook.__new__(PydanticAIAzureHook)
+        result = hook._get_provider_kwargs(
+            "my-key",
+            None,
+            {"model": "azure:gpt-4o", "api_version": "2024-07-01-preview"},
+        )
+        assert "azure_endpoint" not in result
+        assert result["api_key"] == "my-key"
+
+    def test_get_provider_kwargs_omits_api_version_when_absent(self):
+        hook = PydanticAIAzureHook.__new__(PydanticAIAzureHook)
+        result = hook._get_provider_kwargs(
+            "my-key",
+            "https://myresource.openai.azure.com",
+            {"model": "azure:gpt-4o"},
+        )
+        # api_version should not appear if not in extra
+        assert "api_version" not in result
+
+    @patch("airflow.providers.common.ai.hooks.pydantic_ai.infer_model", autospec=True)
+    @patch("airflow.providers.common.ai.hooks.pydantic_ai.infer_provider_class", autospec=True)
+    def test_get_conn_uses_azure_endpoint(self, mock_infer_provider_class, mock_infer_model):
+        mock_infer_model.return_value = MagicMock(spec=Model)
+        mock_provider_cls = MagicMock(return_value=MagicMock())
+        mock_infer_provider_class.return_value = mock_provider_cls
+
+        hook = PydanticAIAzureHook(llm_conn_id="azure_test")
+        conn = Connection(
+            conn_id="azure_test",
+            conn_type="pydanticai-azure",
+            password="azure-key",
+            host="https://myresource.openai.azure.com",
+            extra=json.dumps({"model": "azure:gpt-4o", "api_version": "2024-07-01-preview"}),
+        )
+        with patch.object(hook, "get_connection", return_value=conn):
+            hook.get_conn()
+
+        factory = mock_infer_model.call_args[1]["provider_factory"]
+        factory("azure")
+        mock_provider_cls.assert_called_with(
+            api_key="azure-key",
+            azure_endpoint="https://myresource.openai.azure.com",
+            api_version="2024-07-01-preview",
+        )
+
+    @patch("airflow.providers.common.ai.hooks.pydantic_ai.infer_model", autospec=True)
+    def test_get_conn_falls_back_to_env_auth_when_no_kwargs(self, mock_infer_model):
+        """No host + no password → env-var auth path (empty _get_provider_kwargs)."""
+        mock_infer_model.return_value = MagicMock(spec=Model)
+        hook = PydanticAIAzureHook(llm_conn_id="azure_test")
+        conn = Connection(
+            conn_id="azure_test",
+            conn_type="pydanticai-azure",
+            extra=json.dumps({"model": "azure:gpt-4o"}),
+        )
+        with patch.object(hook, "get_connection", return_value=conn):
+            hook.get_conn()
+
+        mock_infer_model.assert_called_once_with("azure:gpt-4o")
+
+
+class TestPydanticAIBedrockHook:
+    """Tests for PydanticAIBedrockHook."""
+
+    def test_conn_type(self):
+        assert PydanticAIBedrockHook.conn_type == "pydanticai-bedrock"
+
+    def test_hook_name(self):
+        assert "Bedrock" in PydanticAIBedrockHook.hook_name
+
+    def test_ui_hides_host_and_password(self):
+        behaviour = PydanticAIBedrockHook.get_ui_field_behaviour()
+        assert "host" in behaviour["hidden_fields"]
+        assert "password" in behaviour["hidden_fields"]
+
+    def test_get_provider_kwargs_maps_bedrock_fields(self):
+        hook = PydanticAIBedrockHook.__new__(PydanticAIBedrockHook)
+        result = hook._get_provider_kwargs(
+            None,
+            None,
+            {
+                "model": "bedrock:us.anthropic.claude-opus-4-5",
+                "region_name": "us-east-1",
+                "aws_access_key_id": "AKIA123",
+                "aws_secret_access_key": "secret",
+            },
+        )
+        assert result["region_name"] == "us-east-1"
+        assert result["aws_access_key_id"] == "AKIA123"
+        assert result["aws_secret_access_key"] == "secret"
+        assert "model" not in result
+        assert "api_key" not in result
+
+    def test_get_provider_kwargs_returns_empty_for_env_auth(self):
+        """When no keys are in extra, return {} so env-auth path is taken."""
+        hook = PydanticAIBedrockHook.__new__(PydanticAIBedrockHook)
+        result = hook._get_provider_kwargs(None, None, {"model": 
"bedrock:us.anthropic.claude-opus-4-5"})
+        assert result == {}
+
+    @patch("airflow.providers.common.ai.hooks.pydantic_ai.infer_model", 
autospec=True)
+    def test_get_conn_falls_back_to_env_auth(self, mock_infer_model):
+        mock_infer_model.return_value = MagicMock(spec=Model)
+        hook = PydanticAIBedrockHook(llm_conn_id="bedrock_test")
+        conn = Connection(
+            conn_id="bedrock_test",
+            conn_type="pydanticai-bedrock",
+            extra=json.dumps({"model": 
"bedrock:us.anthropic.claude-opus-4-5"}),
+        )
+        with patch.object(hook, "get_connection", return_value=conn):
+            hook.get_conn()
+
+        
mock_infer_model.assert_called_once_with("bedrock:us.anthropic.claude-opus-4-5")
+
+    @patch("airflow.providers.common.ai.hooks.pydantic_ai.infer_model", 
autospec=True)
+    
@patch("airflow.providers.common.ai.hooks.pydantic_ai.infer_provider_class", 
autospec=True)
+    def test_get_conn_uses_explicit_keys(self, mock_infer_provider_class, 
mock_infer_model):
+        mock_infer_model.return_value = MagicMock(spec=Model)
+        mock_provider_cls = MagicMock(return_value=MagicMock())
+        mock_infer_provider_class.return_value = mock_provider_cls
+
+        hook = PydanticAIBedrockHook(llm_conn_id="bedrock_test")
+        conn = Connection(
+            conn_id="bedrock_test",
+            conn_type="pydanticai-bedrock",
+            extra=json.dumps(
+                {
+                    "model": "bedrock:us.anthropic.claude-opus-4-5",
+                    "region_name": "eu-west-1",
+                    "aws_access_key_id": "AKIA123",
+                    "aws_secret_access_key": "secret",
+                }
+            ),
+        )
+        with patch.object(hook, "get_connection", return_value=conn):
+            hook.get_conn()
+
+        factory = mock_infer_model.call_args[1]["provider_factory"]
+        factory("bedrock")
+        mock_provider_cls.assert_called_with(
+            region_name="eu-west-1",
+            aws_access_key_id="AKIA123",
+            aws_secret_access_key="secret",
+        )
+
+    def test_get_provider_kwargs_bearer_token(self):
+        """api_key in extra maps to BedrockProvider's bearer-token param."""
+        hook = PydanticAIBedrockHook.__new__(PydanticAIBedrockHook)
+        result = hook._get_provider_kwargs(
+            None,
+            None,
+            {
+                "model": "bedrock:us.anthropic.claude-opus-4-5",
+                "api_key": "bearer-token-value",
+                "region_name": "us-east-1",
+            },
+        )
+        assert result["api_key"] == "bearer-token-value"
+        assert result["region_name"] == "us-east-1"
+        assert "aws_access_key_id" not in result
+
+    def test_get_provider_kwargs_base_url(self):
+        """base_url in extra is forwarded to BedrockProvider."""
+        hook = PydanticAIBedrockHook.__new__(PydanticAIBedrockHook)
+        result = hook._get_provider_kwargs(
+            None,
+            None,
+            {
+                "model": "bedrock:us.anthropic.claude-opus-4-5",
+                "base_url": "https://custom-bedrock.example.com",
+            },
+        )
+        assert result["base_url"] == "https://custom-bedrock.example.com"
+
+    def test_get_provider_kwargs_float_timeouts(self):
+        """Timeout values are coerced to float (JSON delivers them as int)."""
+        hook = PydanticAIBedrockHook.__new__(PydanticAIBedrockHook)
+        result = hook._get_provider_kwargs(
+            None,
+            None,
+            {
+                "model": "bedrock:us.anthropic.claude-opus-4-5",
+                "aws_read_timeout": 60,  # int from JSON
+                "aws_connect_timeout": 10.5,  # float already
+            },
+        )
+        assert result["aws_read_timeout"] == 60.0
+        assert isinstance(result["aws_read_timeout"], float)
+        assert result["aws_connect_timeout"] == 10.5
+        assert isinstance(result["aws_connect_timeout"], float)
+
+
+class TestPydanticAIVertexHook:
+    """Tests for PydanticAIVertexHook."""
+
+    def test_conn_type(self):
+        assert PydanticAIVertexHook.conn_type == "pydanticai-vertex"
+
+    def test_hook_name(self):
+        assert "Vertex" in PydanticAIVertexHook.hook_name
+
+    def test_ui_hides_host_and_password(self):
+        behaviour = PydanticAIVertexHook.get_ui_field_behaviour()
+        assert "host" in behaviour["hidden_fields"]
+        assert "password" in behaviour["hidden_fields"]
+
+    def test_get_provider_kwargs_maps_vertex_fields(self):
+        """project and location are passed directly; api_key absent when not 
in extra."""
+        hook = PydanticAIVertexHook.__new__(PydanticAIVertexHook)
+        result = hook._get_provider_kwargs(
+            None,
+            None,
+            {
+                "model": "google-vertex:gemini-2.0-flash",
+                "project": "my-project",
+                "location": "us-central1",
+            },
+        )
+        assert result["project"] == "my-project"
+        assert result["location"] == "us-central1"
+        assert "model" not in result
+        assert "api_key" not in result
+        assert "project_id" not in result
+
+    def test_get_provider_kwargs_api_key_gla_mode(self):
+        """api_key in extra is forwarded for Generative Language API mode."""
+        hook = PydanticAIVertexHook.__new__(PydanticAIVertexHook)
+        result = hook._get_provider_kwargs(
+            None,
+            None,
+            {"model": "google-gla:gemini-2.0-flash", "api_key": "gla-key"},
+        )
+        assert result["api_key"] == "gla-key"
+
+    def test_get_provider_kwargs_vertexai_flag(self):
+        """vertexai bool is forwarded and coerced to bool."""
+        hook = PydanticAIVertexHook.__new__(PydanticAIVertexHook)
+        result = hook._get_provider_kwargs(
+            None,
+            None,
+            {"model": "google-vertex:gemini-2.0-flash", "api_key": "key", 
"vertexai": True},
+        )
+        assert result["vertexai"] is True
+
+    def test_get_provider_kwargs_service_account_info_loads_credentials(self):
+        """service_account_info dict is loaded into a Credentials object."""
+        mock_sa = MagicMock()
+        mock_creds = MagicMock()
+        mock_sa.Credentials.from_service_account_info.return_value = mock_creds
+
+        mock_google_oauth2 = MagicMock()
+        mock_google_oauth2.service_account = mock_sa
+
+        sa_info_dict = {"type": "service_account", "project_id": "my-project", 
"private_key": "..."}
+        hook = PydanticAIVertexHook.__new__(PydanticAIVertexHook)
+        with patch.dict(
+            sys.modules,
+            {
+                "google": MagicMock(),
+                "google.oauth2": mock_google_oauth2,
+                "google.oauth2.service_account": mock_sa,
+            },
+        ):
+            result = hook._get_provider_kwargs(
+                None,
+                None,
+                {
+                    "model": "google-vertex:gemini-2.0-flash",
+                    "service_account_info": sa_info_dict,
+                },
+            )
+
+        mock_sa.Credentials.from_service_account_info.assert_called_once_with(
+            sa_info_dict,
+            scopes=["https://www.googleapis.com/auth/cloud-platform"],
+        )
+        assert result["credentials"] is mock_creds
+        assert "service_account_info" not in result
+
+    def test_get_provider_kwargs_returns_empty_for_adc(self):
+        """When no keys are in extra, return {} so ADC path is taken."""
+        hook = PydanticAIVertexHook.__new__(PydanticAIVertexHook)
+        result = hook._get_provider_kwargs(None, None, {"model": 
"google-vertex:gemini-2.0-flash"})
+        assert result == {}
+
+    @patch("airflow.providers.common.ai.hooks.pydantic_ai.infer_model", 
autospec=True)
+    def test_get_conn_falls_back_to_adc(self, mock_infer_model):
+        mock_infer_model.return_value = MagicMock(spec=Model)
+        hook = PydanticAIVertexHook(llm_conn_id="vertex_test")
+        conn = Connection(
+            conn_id="vertex_test",
+            conn_type="pydanticai-vertex",
+            extra=json.dumps({"model": "google-vertex:gemini-2.0-flash"}),
+        )
+        with patch.object(hook, "get_connection", return_value=conn):
+            hook.get_conn()
+
+        
mock_infer_model.assert_called_once_with("google-vertex:gemini-2.0-flash")
+
+    @patch("airflow.providers.common.ai.hooks.pydantic_ai.infer_model", 
autospec=True)
+    
@patch("airflow.providers.common.ai.hooks.pydantic_ai.infer_provider_class", 
autospec=True)
+    def test_get_conn_uses_explicit_project(self, mock_infer_provider_class, 
mock_infer_model):
+        mock_infer_model.return_value = MagicMock(spec=Model)
+        mock_provider_cls = MagicMock(return_value=MagicMock())
+        mock_infer_provider_class.return_value = mock_provider_cls
+
+        hook = PydanticAIVertexHook(llm_conn_id="vertex_test")
+        conn = Connection(
+            conn_id="vertex_test",
+            conn_type="pydanticai-vertex",
+            extra=json.dumps(
+                {
+                    "model": "google-vertex:gemini-2.0-flash",
+                    "project": "my-project",
+                    "location": "europe-west4",
+                }
+            ),
+        )
+        with patch.object(hook, "get_connection", return_value=conn):
+            hook.get_conn()
+
+        factory = mock_infer_model.call_args[1]["provider_factory"]
+        factory("google-vertex")
+        mock_provider_cls.assert_called_with(project="my-project", 
location="europe-west4")
diff --git a/providers/common/ai/tests/unit/common/ai/operators/test_agent.py 
b/providers/common/ai/tests/unit/common/ai/operators/test_agent.py
index 2125b3301f8..13410b35c55 100644
--- a/providers/common/ai/tests/unit/common/ai/operators/test_agent.py
+++ b/providers/common/ai/tests/unit/common/ai/operators/test_agent.py
@@ -82,7 +82,7 @@ class TestAgentOperatorExecute:
     @patch("airflow.providers.common.ai.operators.agent.PydanticAIHook", 
autospec=True)
     def test_execute_creates_agent_from_hook(self, mock_hook_cls):
         mock_agent = _make_mock_agent("The answer is 42.")
-        mock_hook_cls.return_value.create_agent.return_value = mock_agent
+        mock_hook_cls.get_hook.return_value.create_agent.return_value = 
mock_agent
 
         op = AgentOperator(
             task_id="test",
@@ -93,8 +93,8 @@ class TestAgentOperatorExecute:
         result = op.execute(context=MagicMock())
 
         assert result == "The answer is 42."
-        mock_hook_cls.assert_called_once_with(llm_conn_id="my_llm", 
model_id=None)
-        mock_hook_cls.return_value.create_agent.assert_called_once_with(
+        mock_hook_cls.get_hook.assert_called_once_with("my_llm", 
hook_params={"model_id": None})
+        
mock_hook_cls.get_hook.return_value.create_agent.assert_called_once_with(
             output_type=str, instructions="You are helpful."
         )
         mock_agent.run_sync.assert_called_once_with("What is the answer?")
@@ -102,7 +102,7 @@ class TestAgentOperatorExecute:
     @patch("airflow.providers.common.ai.operators.agent.PydanticAIHook", 
autospec=True)
     def test_execute_passes_toolsets_in_agent_kwargs(self, mock_hook_cls):
         """Toolsets are passed through to the agent constructor."""
-        mock_hook_cls.return_value.create_agent.return_value = 
_make_mock_agent("done")
+        mock_hook_cls.get_hook.return_value.create_agent.return_value = 
_make_mock_agent("done")
 
         mock_toolset = MagicMock()
         op = AgentOperator(
@@ -113,7 +113,7 @@ class TestAgentOperatorExecute:
         )
         op.execute(context=MagicMock())
 
-        create_call = mock_hook_cls.return_value.create_agent.call_args
+        create_call = 
mock_hook_cls.get_hook.return_value.create_agent.call_args
         passed_toolsets = create_call[1]["toolsets"]
         assert len(passed_toolsets) == 1
         assert isinstance(passed_toolsets[0], LoggingToolset)
@@ -122,7 +122,7 @@ class TestAgentOperatorExecute:
     @patch("airflow.providers.common.ai.operators.agent.PydanticAIHook", 
autospec=True)
     def test_enable_tool_logging_false_skips_wrapping(self, mock_hook_cls):
         """enable_tool_logging=False passes toolsets through unwrapped."""
-        mock_hook_cls.return_value.create_agent.return_value = 
_make_mock_agent("done")
+        mock_hook_cls.get_hook.return_value.create_agent.return_value = 
_make_mock_agent("done")
 
         mock_toolset = MagicMock()
         op = AgentOperator(
@@ -134,13 +134,13 @@ class TestAgentOperatorExecute:
         )
         op.execute(context=MagicMock())
 
-        create_call = mock_hook_cls.return_value.create_agent.call_args
+        create_call = 
mock_hook_cls.get_hook.return_value.create_agent.call_args
         assert create_call[1]["toolsets"] == [mock_toolset]
 
     @patch("airflow.providers.common.ai.operators.agent.PydanticAIHook", 
autospec=True)
     def test_execute_passes_agent_params(self, mock_hook_cls):
         """agent_params are unpacked into create_agent."""
-        mock_hook_cls.return_value.create_agent.return_value = 
_make_mock_agent("ok")
+        mock_hook_cls.get_hook.return_value.create_agent.return_value = 
_make_mock_agent("ok")
 
         op = AgentOperator(
             task_id="test",
@@ -150,7 +150,7 @@ class TestAgentOperatorExecute:
         )
         op.execute(context=MagicMock())
 
-        create_call = mock_hook_cls.return_value.create_agent.call_args
+        create_call = 
mock_hook_cls.get_hook.return_value.create_agent.call_args
         assert create_call[1]["retries"] == 3
         assert create_call[1]["model_settings"] == {"temperature": 0}
 
@@ -162,7 +162,7 @@ class TestAgentOperatorExecute:
             text: str
             score: float
 
-        mock_hook_cls.return_value.create_agent.return_value = 
_make_mock_agent(
+        mock_hook_cls.get_hook.return_value.create_agent.return_value = 
_make_mock_agent(
             Summary(text="Great", score=0.95)
         )
 
@@ -179,7 +179,7 @@ class TestAgentOperatorExecute:
     @patch("airflow.providers.common.ai.operators.agent.PydanticAIHook", 
autospec=True)
     def test_execute_with_model_id(self, mock_hook_cls):
         """model_id is passed to PydanticAIHook."""
-        mock_hook_cls.return_value.create_agent.return_value = 
_make_mock_agent("ok")
+        mock_hook_cls.get_hook.return_value.create_agent.return_value = 
_make_mock_agent("ok")
 
         op = AgentOperator(
             task_id="test",
@@ -189,7 +189,7 @@ class TestAgentOperatorExecute:
         )
         op.execute(context=MagicMock())
 
-        mock_hook_cls.assert_called_once_with(llm_conn_id="my_llm", 
model_id="openai:gpt-5")
+        mock_hook_cls.get_hook.assert_called_once_with("my_llm", 
hook_params={"model_id": "openai:gpt-5"})
 
     @pytest.mark.skipif(
         not AIRFLOW_V_3_1_PLUS, reason="Human in the loop is only compatible 
with Airflow >= 3.1.0"
@@ -203,7 +203,7 @@ class TestAgentOperatorExecute:
         mock_result.all_messages.return_value = msg_history
         mock_agent = MagicMock(spec=["run_sync"])
         mock_agent.run_sync.return_value = mock_result
-        mock_hook_cls.return_value.create_agent.return_value = mock_agent
+        mock_hook_cls.get_hook.return_value.create_agent.return_value = 
mock_agent
         mock_run_hitl.return_value = "Approved output"
 
         op = AgentOperator(
@@ -234,7 +234,7 @@ class TestAgentOperatorExecute:
         mock_result = _make_mock_run_result(Summary(text="Approved summary", 
score=0.9))
         mock_agent = MagicMock(spec=["run_sync"])
         mock_agent.run_sync.return_value = mock_result
-        mock_hook_cls.return_value.create_agent.return_value = mock_agent
+        mock_hook_cls.get_hook.return_value.create_agent.return_value = 
mock_agent
         # run_hitl_review returns JSON string (as stored in 
session.current_output)
         mock_run_hitl.return_value = '{"text": "Approved summary", "score": 
0.9}'
 
@@ -261,7 +261,7 @@ class TestAgentOperatorExecute:
         mock_result = _make_mock_run_result("Initial output")
         mock_agent = MagicMock(spec=["run_sync"])
         mock_agent.run_sync.return_value = mock_result
-        mock_hook_cls.return_value.create_agent.return_value = mock_agent
+        mock_hook_cls.get_hook.return_value.create_agent.return_value = 
mock_agent
         mock_run_hitl.return_value = "Approved output"
 
         op = AgentOperator(
@@ -289,7 +289,7 @@ class TestAgentOperatorExecute:
         mock_result = _make_mock_run_result("Initial output")
         mock_agent = MagicMock(spec=["run_sync"])
         mock_agent.run_sync.return_value = mock_result
-        mock_hook_cls.return_value.create_agent.return_value = mock_agent
+        mock_hook_cls.get_hook.return_value.create_agent.return_value = 
mock_agent
         mock_run_hitl.side_effect = HITLMaxIterationsError("Task exceeded max 
iterations.")
 
         op = AgentOperator(
@@ -361,7 +361,7 @@ class TestAgentOperatorRegenerateWithFeedback:
         mock_result.all_messages.return_value = msg_history + [MagicMock()]
         mock_agent = MagicMock(spec=["run_sync"])
         mock_agent.run_sync.return_value = mock_result
-        mock_hook_cls.return_value.create_agent.return_value = mock_agent
+        mock_hook_cls.get_hook.return_value.create_agent.return_value = 
mock_agent
 
         op = AgentOperator(
             task_id="test",
@@ -391,7 +391,7 @@ class TestAgentOperatorRegenerateWithFeedback:
         mock_result.all_messages.return_value = []
         mock_agent = MagicMock(spec=["run_sync"])
         mock_agent.run_sync.return_value = mock_result
-        mock_hook_cls.return_value.create_agent.return_value = mock_agent
+        mock_hook_cls.get_hook.return_value.create_agent.return_value = 
mock_agent
 
         op = AgentOperator(
             task_id="test",
diff --git a/providers/common/ai/tests/unit/common/ai/operators/test_llm.py 
b/providers/common/ai/tests/unit/common/ai/operators/test_llm.py
index d596bb39b87..3a686d014f0 100644
--- a/providers/common/ai/tests/unit/common/ai/operators/test_llm.py
+++ b/providers/common/ai/tests/unit/common/ai/operators/test_llm.py
@@ -53,15 +53,17 @@ class TestLLMOperator:
         """Default output_type=str returns the LLM string directly."""
         mock_agent = MagicMock(spec=["run_sync"])
         mock_agent.run_sync.return_value = _make_mock_run_result("Paris is the 
capital of France.")
-        mock_hook_cls.return_value.create_agent.return_value = mock_agent
+        mock_hook_cls.get_hook.return_value.create_agent.return_value = 
mock_agent
 
         op = LLMOperator(task_id="test", prompt="What is the capital of 
France?", llm_conn_id="my_llm")
         result = op.execute(context=MagicMock())
 
         assert result == "Paris is the capital of France."
         mock_agent.run_sync.assert_called_once_with("What is the capital of 
France?")
-        
mock_hook_cls.return_value.create_agent.assert_called_once_with(output_type=str,
 instructions="")
-        mock_hook_cls.assert_called_once_with(llm_conn_id="my_llm", 
model_id=None)
+        
mock_hook_cls.get_hook.return_value.create_agent.assert_called_once_with(
+            output_type=str, instructions=""
+        )
+        mock_hook_cls.get_hook.assert_called_once_with("my_llm", 
hook_params={"model_id": None})
 
     @patch("airflow.providers.common.ai.operators.llm.PydanticAIHook", 
autospec=True)
     def test_execute_structured_output_with_all_params(self, mock_hook_cls):
@@ -72,7 +74,7 @@ class TestLLMOperator:
 
         mock_agent = MagicMock(spec=["run_sync"])
         mock_agent.run_sync.return_value = 
_make_mock_run_result(Entities(names=["Alice", "Bob"]))
-        mock_hook_cls.return_value.create_agent.return_value = mock_agent
+        mock_hook_cls.get_hook.return_value.create_agent.return_value = 
mock_agent
 
         op = LLMOperator(
             task_id="test",
@@ -86,8 +88,8 @@ class TestLLMOperator:
         result = op.execute(context=MagicMock())
 
         assert result == {"names": ["Alice", "Bob"]}
-        mock_hook_cls.assert_called_once_with(llm_conn_id="my_llm", 
model_id="openai:gpt-5")
-        mock_hook_cls.return_value.create_agent.assert_called_once_with(
+        mock_hook_cls.get_hook.assert_called_once_with("my_llm", 
hook_params={"model_id": "openai:gpt-5"})
+        
mock_hook_cls.get_hook.return_value.create_agent.assert_called_once_with(
             output_type=Entities,
             instructions="You are an extractor.",
             retries=3,
@@ -126,7 +128,7 @@ class TestLLMOperatorApproval:
 
         mock_agent = MagicMock(spec=["run_sync"])
         mock_agent.run_sync.return_value = _make_mock_run_result("LLM 
response")
-        mock_hook_cls.return_value.create_agent.return_value = mock_agent
+        mock_hook_cls.get_hook.return_value.create_agent.return_value = 
mock_agent
 
         op = LLMOperator(
             task_id="approval_test",
@@ -152,7 +154,7 @@ class TestLLMOperatorApproval:
 
         mock_agent = MagicMock(spec=["run_sync"])
         mock_agent.run_sync.return_value = _make_mock_run_result("draft 
output")
-        mock_hook_cls.return_value.create_agent.return_value = mock_agent
+        mock_hook_cls.get_hook.return_value.create_agent.return_value = 
mock_agent
 
         op = LLMOperator(
             task_id="mod_test",
@@ -178,7 +180,7 @@ class TestLLMOperatorApproval:
 
         mock_agent = MagicMock(spec=["run_sync"])
         mock_agent.run_sync.return_value = _make_mock_run_result("output")
-        mock_hook_cls.return_value.create_agent.return_value = mock_agent
+        mock_hook_cls.get_hook.return_value.create_agent.return_value = 
mock_agent
 
         timeout = timedelta(hours=1)
         op = LLMOperator(
@@ -207,7 +209,7 @@ class TestLLMOperatorApproval:
 
         mock_agent = MagicMock(spec=["run_sync"])
         mock_agent.run_sync.return_value = 
_make_mock_run_result(Summary(text="hello"))
-        mock_hook_cls.return_value.create_agent.return_value = mock_agent
+        mock_hook_cls.get_hook.return_value.create_agent.return_value = 
mock_agent
 
         op = LLMOperator(
             task_id="struct_test",
@@ -228,7 +230,7 @@ class TestLLMOperatorApproval:
         """When require_approval=False, execute() returns output directly."""
         mock_agent = MagicMock(spec=["run_sync"])
         mock_agent.run_sync.return_value = _make_mock_run_result("plain 
output")
-        mock_hook_cls.return_value.create_agent.return_value = mock_agent
+        mock_hook_cls.get_hook.return_value.create_agent.return_value = 
mock_agent
 
         op = LLMOperator(task_id="no_approval", prompt="p", 
llm_conn_id="my_llm", require_approval=False)
         result = op.execute(context={})
diff --git 
a/providers/common/ai/tests/unit/common/ai/operators/test_llm_branch.py 
b/providers/common/ai/tests/unit/common/ai/operators/test_llm_branch.py
index ffd475f71b8..78e6c231080 100644
--- a/providers/common/ai/tests/unit/common/ai/operators/test_llm_branch.py
+++ b/providers/common/ai/tests/unit/common/ai/operators/test_llm_branch.py
@@ -64,7 +64,7 @@ class TestLLMBranchOperator:
 
         mock_agent = MagicMock(spec=["run_sync"])
         mock_agent.run_sync.return_value = 
_make_mock_run_result(downstream_enum.task_a)
-        mock_hook_cls.return_value.create_agent.return_value = mock_agent
+        mock_hook_cls.get_hook.return_value.create_agent.return_value = 
mock_agent
         mock_do_branch.return_value = "task_a"
 
         op = LLMBranchOperator(
@@ -93,7 +93,7 @@ class TestLLMBranchOperator:
         mock_agent.run_sync.return_value = _make_mock_run_result(
             [downstream_enum.task_a, downstream_enum.task_c]
         )
-        mock_hook_cls.return_value.create_agent.return_value = mock_agent
+        mock_hook_cls.get_hook.return_value.create_agent.return_value = 
mock_agent
         mock_do_branch.return_value = ["task_a", "task_c"]
 
         op = LLMBranchOperator(
@@ -118,7 +118,7 @@ class TestLLMBranchOperator:
 
         mock_agent = MagicMock(spec=["run_sync"])
         mock_agent.run_sync.return_value = 
_make_mock_run_result(downstream_enum.task_a)
-        mock_hook_cls.return_value.create_agent.return_value = mock_agent
+        mock_hook_cls.get_hook.return_value.create_agent.return_value = 
mock_agent
 
         op = LLMBranchOperator(
             task_id="test",
@@ -130,7 +130,7 @@ class TestLLMBranchOperator:
 
         op.execute(MagicMock())
 
-        call_kwargs = mock_hook_cls.return_value.create_agent.call_args
+        call_kwargs = 
mock_hook_cls.get_hook.return_value.create_agent.call_args
         assert call_kwargs.kwargs["instructions"] == "Route tickets to the 
right team."
 
     @patch.object(LLMBranchOperator, "do_branch")
@@ -143,7 +143,7 @@ class TestLLMBranchOperator:
 
         mock_agent = MagicMock(spec=["run_sync"])
         mock_agent.run_sync.return_value = 
_make_mock_run_result(downstream_enum.billing)
-        mock_hook_cls.return_value.create_agent.return_value = mock_agent
+        mock_hook_cls.get_hook.return_value.create_agent.return_value = 
mock_agent
 
         op = LLMBranchOperator(
             task_id="test",
@@ -154,7 +154,7 @@ class TestLLMBranchOperator:
 
         op.execute(MagicMock())
 
-        output_type = 
mock_hook_cls.return_value.create_agent.call_args.kwargs["output_type"]
+        output_type = 
mock_hook_cls.get_hook.return_value.create_agent.call_args.kwargs["output_type"]
         assert {m.value for m in output_type} == {"billing", "auth", "general"}
 
     def test_execute_raises_on_no_downstream_tasks(self):
diff --git a/providers/common/ai/tests/unit/common/ai/operators/test_llm_sql.py 
b/providers/common/ai/tests/unit/common/ai/operators/test_llm_sql.py
index 57bdcb088f0..4ff3158a356 100644
--- a/providers/common/ai/tests/unit/common/ai/operators/test_llm_sql.py
+++ b/providers/common/ai/tests/unit/common/ai/operators/test_llm_sql.py
@@ -100,7 +100,7 @@ class TestLLMSQLQueryOperator:
     def test_execute_with_schema_context(self, mock_hook_cls):
         """Operator uses schema_context and returns generated SQL."""
         mock_agent = _make_mock_agent("SELECT id, name FROM users WHERE active 
= true")
-        mock_hook_cls.return_value.create_agent.return_value = mock_agent
+        mock_hook_cls.get_hook.return_value.create_agent.return_value = 
mock_agent
 
         op = LLMSQLQueryOperator(
             task_id="test",
@@ -116,7 +116,7 @@ class TestLLMSQLQueryOperator:
     @patch("airflow.providers.common.ai.operators.llm.PydanticAIHook", 
autospec=True)
     def test_execute_validation_blocks_unsafe_sql(self, mock_hook_cls):
         """Validation catches unsafe SQL generated by the LLM."""
-        mock_hook_cls.return_value.create_agent.return_value = 
_make_mock_agent("DROP TABLE users")
+        mock_hook_cls.get_hook.return_value.create_agent.return_value = 
_make_mock_agent("DROP TABLE users")
 
         op = LLMSQLQueryOperator(task_id="test", prompt="Delete everything", 
llm_conn_id="my_llm")
 
@@ -126,7 +126,7 @@ class TestLLMSQLQueryOperator:
     @patch("airflow.providers.common.ai.operators.llm.PydanticAIHook", 
autospec=True)
     def test_execute_validation_disabled(self, mock_hook_cls):
         """When validate_sql=False, unsafe SQL is returned without checks."""
-        mock_hook_cls.return_value.create_agent.return_value = 
_make_mock_agent("DROP TABLE users")
+        mock_hook_cls.get_hook.return_value.create_agent.return_value = 
_make_mock_agent("DROP TABLE users")
 
         op = LLMSQLQueryOperator(task_id="test", prompt="Drop it", 
llm_conn_id="my_llm", validate_sql=False)
         result = op.execute(context=MagicMock())
@@ -136,7 +136,7 @@ class TestLLMSQLQueryOperator:
     @patch("airflow.providers.common.ai.operators.llm.PydanticAIHook", 
autospec=True)
     def test_execute_passes_agent_params(self, mock_hook_cls):
         """agent_params inherited from LLMOperator are unpacked into 
create_agent."""
-        mock_hook_cls.return_value.create_agent.return_value = 
_make_mock_agent("SELECT 1")
+        mock_hook_cls.get_hook.return_value.create_agent.return_value = 
_make_mock_agent("SELECT 1")
 
         op = LLMSQLQueryOperator(
             task_id="test",
@@ -146,14 +146,14 @@ class TestLLMSQLQueryOperator:
         )
         op.execute(context=MagicMock())
 
-        create_agent_call = mock_hook_cls.return_value.create_agent.call_args
+        create_agent_call = 
mock_hook_cls.get_hook.return_value.create_agent.call_args
         assert create_agent_call[1]["retries"] == 3
         assert create_agent_call[1]["model_settings"] == {"temperature": 0}
 
     @patch("airflow.providers.common.ai.operators.llm.PydanticAIHook", 
autospec=True)
     def test_system_prompt_appended_to_sql_instructions(self, mock_hook_cls):
         """User-provided system_prompt is appended to built-in SQL safety 
prompt."""
-        mock_hook_cls.return_value.create_agent.return_value = 
_make_mock_agent("SELECT 1")
+        mock_hook_cls.get_hook.return_value.create_agent.return_value = 
_make_mock_agent("SELECT 1")
 
         op = LLMSQLQueryOperator(
             task_id="test",
@@ -163,7 +163,7 @@ class TestLLMSQLQueryOperator:
         )
         op.execute(context=MagicMock())
 
-        instructions = 
mock_hook_cls.return_value.create_agent.call_args[1]["instructions"]
+        instructions = 
mock_hook_cls.get_hook.return_value.create_agent.call_args[1]["instructions"]
         assert "Always use LEFT JOINs." in instructions
         # Built-in SQL safety prompt should still be present
         assert "Generate only SELECT queries" in instructions
@@ -175,7 +175,7 @@ class TestLLMSQLQueryOperatorSchemaIntrospection:
     def test_introspect_schemas_via_db_hook(self, mock_hook_cls):
         """db_conn_id + table_names triggers schema introspection."""
         mock_agent = _make_mock_agent("SELECT id FROM users")
-        mock_hook_cls.return_value.create_agent.return_value = mock_agent
+        mock_hook_cls.get_hook.return_value.create_agent.return_value = 
mock_agent
 
         mock_db_hook = MagicMock(spec=["get_table_schema", "dialect_name"])
         mock_db_hook.get_table_schema.return_value = [
@@ -199,7 +199,7 @@ class TestLLMSQLQueryOperatorSchemaIntrospection:
         mock_db_hook.get_table_schema.assert_called_once_with("users")
 
         # Verify the system prompt contains the schema info
-        instructions = 
mock_hook_cls.return_value.create_agent.call_args[1]["instructions"]
+        instructions = 
mock_hook_cls.get_hook.return_value.create_agent.call_args[1]["instructions"]
         assert "users" in instructions
         assert "id INTEGER" in instructions
 
@@ -362,7 +362,7 @@ class TestLLMSQLQueryOperatorSchemaIntrospection:
         mock_engine.get_schema.return_value = "event: TEXT\nts: TIMESTAMP"
 
         mock_agent = _make_mock_agent("SELECT u.id, e.event FROM users u JOIN 
events e ON u.id = e.user_id")
-        mock_hook_cls.return_value.create_agent.return_value = mock_agent
+        mock_hook_cls.get_hook.return_value.create_agent.return_value = 
mock_agent
 
         ds_config = DataSourceConfig(
             conn_id="aws_default",
@@ -390,7 +390,7 @@ class TestLLMSQLQueryOperatorSchemaIntrospection:
             result = op.execute(context=MagicMock())
 
         assert "SELECT" in result
-        instructions = 
mock_hook_cls.return_value.create_agent.call_args[1]["instructions"]
+        instructions = 
mock_hook_cls.get_hook.return_value.create_agent.call_args[1]["instructions"]
         assert "users" in instructions
         assert "events" in instructions
         assert "event: TEXT\nts: TIMESTAMP" in instructions
@@ -464,7 +464,7 @@ class TestLLMSQLQueryOperatorApproval:
     def test_execute_with_approval_defers(self, mock_hook_cls, mock_upsert, 
mock_trigger_cls):
         """When require_approval=True, execute() defers after generating and 
validating SQL."""
 
-        mock_hook_cls.return_value.create_agent.return_value = 
_make_mock_agent(
+        mock_hook_cls.get_hook.return_value.create_agent.return_value = 
_make_mock_agent(
             "SELECT id FROM users WHERE active"
         )
 
@@ -491,7 +491,7 @@ class TestLLMSQLQueryOperatorApproval:
         self, mock_hook_cls, mock_upsert, mock_trigger_cls
     ):
         """SQL validation runs before defer_for_approval; unsafe SQL is 
blocked."""
-        mock_hook_cls.return_value.create_agent.return_value = 
_make_mock_agent("DROP TABLE users")
+        mock_hook_cls.get_hook.return_value.create_agent.return_value = 
_make_mock_agent("DROP TABLE users")
 
         op = LLMSQLQueryOperator(
             task_id="sql_unsafe",
@@ -512,7 +512,7 @@ class TestLLMSQLQueryOperatorApproval:
     def test_execute_with_approval_and_modifications(self, mock_hook_cls, 
mock_upsert, mock_trigger_cls):
         """allow_modifications=True passes editable params."""
 
-        mock_hook_cls.return_value.create_agent.return_value = 
_make_mock_agent("SELECT 1")
+        mock_hook_cls.get_hook.return_value.create_agent.return_value = 
_make_mock_agent("SELECT 1")
 
         op = LLMSQLQueryOperator(
             task_id="sql_mod",
@@ -535,7 +535,7 @@ class TestLLMSQLQueryOperatorApproval:
     def test_execute_with_approval_and_timeout(self, mock_hook_cls, 
mock_upsert, mock_trigger_cls):
         """approval_timeout is propagated to the trigger."""
 
-        mock_hook_cls.return_value.create_agent.return_value = 
_make_mock_agent("SELECT 1")
+        mock_hook_cls.get_hook.return_value.create_agent.return_value = 
_make_mock_agent("SELECT 1")
         timeout = timedelta(minutes=30)
 
         op = LLMSQLQueryOperator(
@@ -555,7 +555,7 @@ class TestLLMSQLQueryOperatorApproval:
     @patch("airflow.providers.common.ai.operators.llm.PydanticAIHook", 
autospec=True)
     def test_execute_without_approval_returns_sql(self, mock_hook_cls):
         """When require_approval=False, execute() returns the SQL directly."""
-        mock_hook_cls.return_value.create_agent.return_value = 
_make_mock_agent("SELECT 1")
+        mock_hook_cls.get_hook.return_value.create_agent.return_value = 
_make_mock_agent("SELECT 1")
 
         op = LLMSQLQueryOperator(
             task_id="no_approval",
@@ -573,7 +573,9 @@ class TestLLMSQLQueryOperatorApproval:
     def test_execute_strips_code_fences_before_deferring(self, mock_hook_cls, 
mock_upsert, mock_trigger_cls):
         """Markdown code fences are stripped from LLM output before 
deferring."""
 
-        mock_hook_cls.return_value.create_agent.return_value = 
_make_mock_agent("```sql\nSELECT 1\n```")
+        mock_hook_cls.get_hook.return_value.create_agent.return_value = 
_make_mock_agent(
+            "```sql\nSELECT 1\n```"
+        )
 
         op = LLMSQLQueryOperator(
             task_id="strip_test",


Reply via email to