This is an automated email from the ASF dual-hosted git repository.
kaxilnaik pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 840bcf30247 Add `PydanticAIHook` to `common.ai` provider (#62546)
840bcf30247 is described below
commit 840bcf30247dd63063bbb054e3c8645f7896fb5b
Author: Kaxil Naik <[email protected]>
AuthorDate: Fri Feb 27 19:18:14 2026 +0000
Add `PydanticAIHook` to `common.ai` provider (#62546)
Adds a hook for LLM access via pydantic-ai to the common.ai provider.
The hook manages connection credentials and creates pydantic-ai Model
and Agent objects, supporting any provider (OpenAI, Anthropic, Google,
Bedrock, Ollama, vLLM, etc.).
- get_conn() returns a pydantic-ai Model configured with credentials
from the Airflow connection (api_key, base_url via provider_factory)
- create_agent() creates a pydantic-ai Agent with the hook's model
- test_connection() validates model resolution without an API call
- Connection UI fields: password (API Key), host (base URL), extra (model)
- Google Vertex/GLA providers delegate to default ADC auth
TypeVar on create_agent() lets mypy propagate the output_type
through Agent[None, OutputT] → RunResult[OutputT] → result.output,
so callers like example_pydantic_ai_hook.py don't need type: ignore.
Also fix black-docs blank line in RST code block.
- Move SQLResult inside task function so Sphinx autoapi doesn't
document Pydantic BaseModel internals (fixes RST indentation errors)
- Add Groq, Ollama, vLLM to spelling wordlist
- Change "parseable" to "valid" in test_connection docstring
- Remove separate code-block from RST (class is now in exampleinclude)
- Import BaseHook from common.compat.sdk for Airflow 2.x/3.x compat
- Import dag/task from common.compat.sdk in example DAG
- Replace AirflowException with ValueError for model validation
- Use @overload for create_agent so mypy handles the default correctly
Co-authored-by: GPK <[email protected]>
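
The model-resolution precedence described above (the hook's `model_id` parameter wins over the `model` key in the connection's extra JSON, and a `ValueError` is raised when neither is set) can be sketched in isolation. `resolve_model_name` below is a hypothetical standalone helper mirroring that behaviour for illustration, not code from this commit:

```python
from __future__ import annotations

import json


def resolve_model_name(model_id: str | None, conn_extra: str | None) -> str:
    """Mirror the hook's precedence: explicit model_id beats the connection
    extra; neither being set is an error (hypothetical helper)."""
    extra = json.loads(conn_extra) if conn_extra else {}
    model_name = model_id or extra.get("model", "")
    if not model_name:
        raise ValueError(
            "No model specified. Set model_id on the hook or 'model' in the connection's extra JSON."
        )
    return model_name
```

For example, `resolve_model_name("openai:gpt-5.3", '{"model": "anthropic:claude-opus-4-6"}')` returns `"openai:gpt-5.3"`, while passing `model_id=None` falls back to the extra's value.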
---
docs/spelling_wordlist.txt | 3 +
.../common/ai/docs/connections/pydantic_ai.rst | 114 ++++++++
providers/common/ai/docs/hooks/pydantic_ai.rst | 68 +++++
providers/common/ai/docs/index.rst | 49 +++-
providers/common/ai/provider.yaml | 26 +-
providers/common/ai/pyproject.toml | 13 +
.../providers/common/ai/example_dags/__init__.py} | 7 -
.../ai/example_dags/example_pydantic_ai_hook.py | 67 +++++
.../providers/common/ai/get_provider_info.py | 29 +-
.../airflow/providers/common/ai/hooks/__init__.py} | 7 -
.../providers/common/ai/hooks/pydantic_ai.py | 166 ++++++++++++
.../common/ai/{test_empty.py => hooks/__init__.py} | 7 -
.../tests/unit/common/ai/hooks/test_pydantic_ai.py | 298 +++++++++++++++++++++
providers/yandex/docs/index.rst | 2 +-
providers/yandex/pyproject.toml | 2 +-
15 files changed, 827 insertions(+), 31 deletions(-)
diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt
index 59f37d532e2..0cc5d0b6fd0 100644
--- a/docs/spelling_wordlist.txt
+++ b/docs/spelling_wordlist.txt
@@ -812,6 +812,7 @@ gpus
Grafana
graphviz
greenlet
+Groq
Groupalia
groupId
Groupon
@@ -1282,6 +1283,7 @@ ok
oklch
Okta
okta
+Ollama
onboarded
onboarding
OnFailure
@@ -2047,6 +2049,7 @@ virtualenv
virtualenvs
virtualized
Vite
+vLLM
vm
VolumeKmsKeyId
VolumeMount
diff --git a/providers/common/ai/docs/connections/pydantic_ai.rst b/providers/common/ai/docs/connections/pydantic_ai.rst
new file mode 100644
index 00000000000..85871549643
--- /dev/null
+++ b/providers/common/ai/docs/connections/pydantic_ai.rst
@@ -0,0 +1,114 @@
+ .. Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ .. http://www.apache.org/licenses/LICENSE-2.0
+
+ .. Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+
+.. _howto/connection:pydantic_ai:
+
+Pydantic AI Connection
+======================
+
+The `Pydantic AI <https://ai.pydantic.dev/>`__ connection type configures access
+to LLM providers via the pydantic-ai framework. A single connection type works with
+any provider that pydantic-ai supports: OpenAI, Anthropic, Google, Bedrock, Groq,
+Mistral, Ollama, vLLM, and others.
+
+Default Connection IDs
+----------------------
+
+The ``PydanticAIHook`` uses ``pydantic_ai_default`` by default.
+
+Configuring the Connection
+--------------------------
+
+API Key (Password field)
+ The API key for your LLM provider. Required for API-key-based providers
+ (OpenAI, Anthropic, Groq, Mistral). Leave empty for providers using
+ environment-based auth (Bedrock via ``AWS_PROFILE``, Vertex via
+ ``GOOGLE_APPLICATION_CREDENTIALS``).
+
+Host (optional)
+ Base URL for the provider's API. Only needed for custom endpoints:
+
+ - Ollama: ``http://localhost:11434/v1``
+ - vLLM: ``http://localhost:8000/v1``
+ - Azure OpenAI: ``https://<resource>.openai.azure.com/openai/deployments/<deployment>``
+ - Any OpenAI-compatible API: the base URL of that service
+
+Extra (JSON, optional)
+ A JSON object with additional configuration. The ``model`` key specifies
+ the default model in ``provider:model`` format:
+
+ .. code-block:: json
+
+ {"model": "openai:gpt-5.3"}
+
+ The model can also be overridden at the hook/operator level via the
+ ``model_id`` parameter.
+
+Examples
+--------
+
+**OpenAI**
+
+.. code-block:: json
+
+ {
+ "conn_type": "pydantic_ai",
+ "password": "sk-...",
+ "extra": "{\"model\": \"openai:gpt-5.3\"}"
+ }
+
+**Anthropic**
+
+.. code-block:: json
+
+ {
+ "conn_type": "pydantic_ai",
+ "password": "sk-ant-...",
+ "extra": "{\"model\": \"anthropic:claude-opus-4-6\"}"
+ }
+
+**Ollama (local)**
+
+.. code-block:: json
+
+ {
+ "conn_type": "pydantic_ai",
+ "host": "http://localhost:11434/v1",
+ "extra": "{\"model\": \"openai:llama3\"}"
+ }
+
+**AWS Bedrock**
+
+Leave password empty and configure ``AWS_PROFILE`` or IAM role in the environment:
+
+.. code-block:: json
+
+ {
+ "conn_type": "pydantic_ai",
+ "extra": "{\"model\": \"bedrock:us.anthropic.claude-opus-4-6-v1:0\"}"
+ }
+
+**Google Vertex AI**
+
+Leave password empty and configure ``GOOGLE_APPLICATION_CREDENTIALS`` in the environment:
+
+.. code-block:: json
+
+ {
+ "conn_type": "pydantic_ai",
+ "extra": "{\"model\": \"google:gemini-2.0-flash\"}"
+ }
diff --git a/providers/common/ai/docs/hooks/pydantic_ai.rst b/providers/common/ai/docs/hooks/pydantic_ai.rst
new file mode 100644
index 00000000000..2cea9740925
--- /dev/null
+++ b/providers/common/ai/docs/hooks/pydantic_ai.rst
@@ -0,0 +1,68 @@
+ .. Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ .. http://www.apache.org/licenses/LICENSE-2.0
+
+ .. Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+
+.. _howto/hook:pydantic_ai:
+
+PydanticAIHook
+==============
+
+Use :class:`~airflow.providers.common.ai.hooks.pydantic_ai.PydanticAIHook` to interact
+with LLM providers via `pydantic-ai <https://ai.pydantic.dev/>`__.
+
+The hook manages API credentials from an Airflow connection and creates pydantic-ai
+``Model`` and ``Agent`` objects. It supports any provider that pydantic-ai supports.
+
+.. seealso::
+ :ref:`Connection configuration <howto/connection:pydantic_ai>`
+
+Basic Usage
+-----------
+
+Use the hook in a ``@task`` function to call an LLM:
+
+.. exampleinclude:: /../../ai/src/airflow/providers/common/ai/example_dags/example_pydantic_ai_hook.py
+ :language: python
+ :start-after: [START howto_hook_pydantic_ai_basic]
+ :end-before: [END howto_hook_pydantic_ai_basic]
+
+Overriding the Model
+--------------------
+
+The model can be specified at three levels (highest priority first):
+
+1. ``model_id`` parameter on the hook
+2. ``model`` key in the connection's extra JSON
+3. (No default — raises an error if neither is set)
+
+.. code-block:: python
+
+ # Use model from the connection's extra JSON
+ hook = PydanticAIHook(llm_conn_id="my_llm")
+
+ # Override with a specific model
+ hook = PydanticAIHook(llm_conn_id="my_llm", model_id="anthropic:claude-opus-4-6")
+
+Structured Output
+-----------------
+
+Pydantic-ai's structured output works naturally through the hook.
+Define a Pydantic model for the expected output shape, then pass it as ``output_type``:
+
+.. exampleinclude:: /../../ai/src/airflow/providers/common/ai/example_dags/example_pydantic_ai_hook.py
+ :language: python
+ :start-after: [START howto_hook_pydantic_ai_structured_output]
+ :end-before: [END howto_hook_pydantic_ai_structured_output]
diff --git a/providers/common/ai/docs/index.rst b/providers/common/ai/docs/index.rst
index 98068f43132..19077bfe183 100644
--- a/providers/common/ai/docs/index.rst
+++ b/providers/common/ai/docs/index.rst
@@ -29,6 +29,14 @@
Changelog <changelog>
Security <security>
+.. toctree::
+ :hidden:
+ :maxdepth: 1
+ :caption: Guides
+
+ Connection types <connections/pydantic_ai>
+ Hooks <hooks/pydantic_ai>
+
.. toctree::
:hidden:
:maxdepth: 1
@@ -65,7 +73,7 @@
apache-airflow-providers-common-ai package
------------------------------------------------------
-``Common AI Provider``
+AI/LLM hooks and operators for Airflow pipelines using `pydantic-ai <https://ai.pydantic.dev/>`__.
Release: 0.0.1
@@ -88,8 +96,37 @@ Requirements
The minimum Apache Airflow version supported by this provider distribution is ``3.0.0``.
-================== ==================
-PIP package        Version required
-================== ==================
-``apache-airflow`` ``>=3.0.0``
-================== ==================
+==================== ==================
+PIP package          Version required
+==================== ==================
+``apache-airflow``   ``>=3.0.0``
+``pydantic-ai-slim`` ``>=1.14.0``
+==================== ==================
+
+Cross provider package dependencies
+-----------------------------------
+
+Those are dependencies that might be needed in order to use all the features of the package.
+You need to install the specified provider distributions in order to use them.
+
+You can install such cross-provider dependencies when installing from PyPI. For example:
+
+.. code-block:: bash
+
+ pip install apache-airflow-providers-common-ai[common.compat]
+
+
+================================================================================================================== =================
+Dependent package                                                                                                  Extra
+================================================================================================================== =================
+`apache-airflow-providers-common-compat <https://airflow.apache.org/docs/apache-airflow-providers-common-compat>`_ ``common.compat``
+================================================================================================================== =================
+
+Downloading official packages
+-----------------------------
+
+You can download officially released packages and verify their checksums and signatures from the
+`Official Apache Download site <https://downloads.apache.org/airflow/providers/>`_
+
+* `The apache-airflow-providers-common-ai 0.0.1 sdist package <https://downloads.apache.org/airflow/providers/apache_airflow_providers_common_ai-0.0.1.tar.gz>`_ (`asc <https://downloads.apache.org/airflow/providers/apache_airflow_providers_common_ai-0.0.1.tar.gz.asc>`__, `sha512 <https://downloads.apache.org/airflow/providers/apache_airflow_providers_common_ai-0.0.1.tar.gz.sha512>`__)
+* `The apache-airflow-providers-common-ai 0.0.1 wheel package <https://downloads.apache.org/airflow/providers/apache_airflow_providers_common_ai-0.0.1-py3-none-any.whl>`_ (`asc <https://downloads.apache.org/airflow/providers/apache_airflow_providers_common_ai-0.0.1-py3-none-any.whl.asc>`__, `sha512 <https://downloads.apache.org/airflow/providers/apache_airflow_providers_common_ai-0.0.1-py3-none-any.whl.sha512>`__)
diff --git a/providers/common/ai/provider.yaml b/providers/common/ai/provider.yaml
index 6c8356dafa3..247371f289b 100644
--- a/providers/common/ai/provider.yaml
+++ b/providers/common/ai/provider.yaml
@@ -19,7 +19,7 @@
package-name: apache-airflow-providers-common-ai
name: Common AI
description: |
- ``Common AI Provider``
+ AI/LLM hooks and operators for Airflow pipelines using `pydantic-ai <https://ai.pydantic.dev/>`__.
state: not-ready
lifecycle: incubation
@@ -27,3 +27,27 @@ source-date-epoch: 1770463465
# note that those versions are maintained by release manager - do not update them manually
versions:
- 0.0.1
+
+integrations:
+ - integration-name: Pydantic AI
+ external-doc-url: https://ai.pydantic.dev/
+ tags: [software]
+
+hooks:
+ - integration-name: Pydantic AI
+ python-modules:
+ - airflow.providers.common.ai.hooks.pydantic_ai
+
+connection-types:
+ - hook-class-name: airflow.providers.common.ai.hooks.pydantic_ai.PydanticAIHook
+ connection-type: pydantic_ai
+ ui-field-behaviour:
+ hidden-fields:
+ - schema
+ - port
+ - login
+ relabeling:
+ password: API Key
+ placeholders:
+ host: "https://api.openai.com/v1 (optional, for custom endpoints)"
+ extra: '{"model": "openai:gpt-5"}'
diff --git a/providers/common/ai/pyproject.toml b/providers/common/ai/pyproject.toml
index f066d693728..ae88717446a 100644
--- a/providers/common/ai/pyproject.toml
+++ b/providers/common/ai/pyproject.toml
@@ -59,6 +59,18 @@ requires-python = ">=3.10"
# After you modify the dependencies, and rebuild your Breeze CI image with ``breeze ci-image build``
dependencies = [
"apache-airflow>=3.0.0",
+ "pydantic-ai-slim>=1.14.0",
+]
+
+# The optional dependencies should be modified in place in the generated file
+# Any change in the dependencies is preserved when the file is regenerated
+[project.optional-dependencies]
+"anthropic" = ["pydantic-ai-slim[anthropic]"]
+"bedrock" = ["pydantic-ai-slim[bedrock]"]
+"google" = ["pydantic-ai-slim[google]"]
+"openai" = ["pydantic-ai-slim[openai]"]
+"common.compat" = [
+ "apache-airflow-providers-common-compat"
]
[dependency-groups]
@@ -66,6 +78,7 @@ dev = [
"apache-airflow",
"apache-airflow-task-sdk",
"apache-airflow-devel-common",
+ "apache-airflow-providers-common-compat",
# Additional devel dependencies (do not remove this line and add extra development dependencies)
]
diff --git a/providers/common/ai/tests/unit/common/ai/test_empty.py b/providers/common/ai/src/airflow/providers/common/ai/example_dags/__init__.py
similarity index 85%
copy from providers/common/ai/tests/unit/common/ai/test_empty.py
copy to providers/common/ai/src/airflow/providers/common/ai/example_dags/__init__.py
index b92357a7e5c..13a83393a91 100644
--- a/providers/common/ai/tests/unit/common/ai/test_empty.py
+++ b/providers/common/ai/src/airflow/providers/common/ai/example_dags/__init__.py
@@ -14,10 +14,3 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
-# To remove later, we need at least one test, otherwise CI fails
-from __future__ import annotations
-
-
-def test_empty():
- assert True
diff --git a/providers/common/ai/src/airflow/providers/common/ai/example_dags/example_pydantic_ai_hook.py b/providers/common/ai/src/airflow/providers/common/ai/example_dags/example_pydantic_ai_hook.py
new file mode 100644
index 00000000000..887603825ef
--- /dev/null
+++ b/providers/common/ai/src/airflow/providers/common/ai/example_dags/example_pydantic_ai_hook.py
@@ -0,0 +1,67 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Example DAG demonstrating PydanticAIHook usage."""
+
+from __future__ import annotations
+
+from pydantic import BaseModel
+
+from airflow.providers.common.ai.hooks.pydantic_ai import PydanticAIHook
+from airflow.providers.common.compat.sdk import dag, task
+
+
+# [START howto_hook_pydantic_ai_basic]
+@dag(schedule=None)
+def example_pydantic_ai_hook():
+ @task
+ def generate_summary(text: str) -> str:
+ hook = PydanticAIHook(llm_conn_id="pydantic_ai_default")
+ agent = hook.create_agent(output_type=str, instructions="Summarize concisely.")
+ result = agent.run_sync(text)
+ return result.output
+
+ generate_summary("Apache Airflow is a platform for programmatically authoring...")
+
+
+# [END howto_hook_pydantic_ai_basic]
+
+example_pydantic_ai_hook()
+
+
+# [START howto_hook_pydantic_ai_structured_output]
+@dag(schedule=None)
+def example_pydantic_ai_structured_output():
+ @task
+ def generate_sql(prompt: str) -> dict:
+ class SQLResult(BaseModel):
+ query: str
+ explanation: str
+
+ hook = PydanticAIHook(llm_conn_id="pydantic_ai_default")
+ agent = hook.create_agent(
+ output_type=SQLResult,
+ instructions="Generate a SQL query and explain it.",
+ )
+ result = agent.run_sync(prompt)
+ return result.output.model_dump()
+
+ generate_sql("Find the top 10 customers by revenue")
+
+
+# [END howto_hook_pydantic_ai_structured_output]
+
+example_pydantic_ai_structured_output()
diff --git a/providers/common/ai/src/airflow/providers/common/ai/get_provider_info.py b/providers/common/ai/src/airflow/providers/common/ai/get_provider_info.py
index 29cbc291b1c..ee2a7c03a7f 100644
--- a/providers/common/ai/src/airflow/providers/common/ai/get_provider_info.py
+++ b/providers/common/ai/src/airflow/providers/common/ai/get_provider_info.py
@@ -25,5 +25,32 @@ def get_provider_info():
return {
"package-name": "apache-airflow-providers-common-ai",
"name": "Common AI",
- "description": "``Common AI Provider``\n",
+ "description": "AI/LLM hooks and operators for Airflow pipelines using `pydantic-ai <https://ai.pydantic.dev/>`__.\n",
+ "integrations": [
+ {
+ "integration-name": "Pydantic AI",
+ "external-doc-url": "https://ai.pydantic.dev/",
+ "tags": ["software"],
+ }
+ ],
+ "hooks": [
+ {
+ "integration-name": "Pydantic AI",
+ "python-modules": ["airflow.providers.common.ai.hooks.pydantic_ai"],
+ }
+ ],
+ "connection-types": [
+ {
+ "hook-class-name": "airflow.providers.common.ai.hooks.pydantic_ai.PydanticAIHook",
+ "connection-type": "pydantic_ai",
+ "ui-field-behaviour": {
+ "hidden-fields": ["schema", "port", "login"],
+ "relabeling": {"password": "API Key"},
+ "placeholders": {
+ "host": "https://api.openai.com/v1 (optional, for custom endpoints)",
+ "extra": '{"model": "openai:gpt-5"}',
+ },
+ },
+ }
+ ],
}
diff --git a/providers/common/ai/tests/unit/common/ai/test_empty.py b/providers/common/ai/src/airflow/providers/common/ai/hooks/__init__.py
similarity index 85%
copy from providers/common/ai/tests/unit/common/ai/test_empty.py
copy to providers/common/ai/src/airflow/providers/common/ai/hooks/__init__.py
index b92357a7e5c..13a83393a91 100644
--- a/providers/common/ai/tests/unit/common/ai/test_empty.py
+++ b/providers/common/ai/src/airflow/providers/common/ai/hooks/__init__.py
@@ -14,10 +14,3 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
-# To remove later, we need at least one test, otherwise CI fails
-from __future__ import annotations
-
-
-def test_empty():
- assert True
diff --git a/providers/common/ai/src/airflow/providers/common/ai/hooks/pydantic_ai.py b/providers/common/ai/src/airflow/providers/common/ai/hooks/pydantic_ai.py
new file mode 100644
index 00000000000..f94245e40f3
--- /dev/null
+++ b/providers/common/ai/src/airflow/providers/common/ai/hooks/pydantic_ai.py
@@ -0,0 +1,166 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any, TypeVar, overload
+
+from pydantic_ai import Agent
+from pydantic_ai.models import Model, infer_model
+from pydantic_ai.providers import Provider, infer_provider, infer_provider_class
+
+from airflow.providers.common.compat.sdk import BaseHook
+
+OutputT = TypeVar("OutputT")
+
+if TYPE_CHECKING:
+ from pydantic_ai.models import KnownModelName
+
+
+class PydanticAIHook(BaseHook):
+ """
+ Hook for LLM access via pydantic-ai.
+
+ Manages connection credentials and model creation. Uses pydantic-ai's
+ model inference to support any provider (OpenAI, Anthropic, Google,
+ Bedrock, Ollama, vLLM, etc.).
+
+ Connection fields:
+ - **password**: API key (OpenAI, Anthropic, Groq, Mistral, etc.)
+ - **host**: Base URL (optional — for custom endpoints like Ollama, vLLM, Azure)
+ - **extra** JSON: ``{"model": "openai:gpt-5.3"}``
+
+ Cloud providers (Bedrock, Vertex) that use native auth chains should leave
+ password empty and configure environment-based auth (``AWS_PROFILE``,
+ ``GOOGLE_APPLICATION_CREDENTIALS``).
+
+ :param llm_conn_id: Airflow connection ID for the LLM provider.
+ :param model_id: Model identifier in ``provider:model`` format (e.g. ``"openai:gpt-5.3"``).
+ Overrides the model stored in the connection's extra field.
+ """
+
+ conn_name_attr = "llm_conn_id"
+ default_conn_name = "pydantic_ai_default"
+ conn_type = "pydantic_ai"
+ hook_name = "Pydantic AI"
+
+ def __init__(
+ self,
+ llm_conn_id: str = default_conn_name,
+ model_id: str | None = None,
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+ self.llm_conn_id = llm_conn_id
+ self.model_id = model_id
+ self._model: Model | None = None
+
+ @staticmethod
+ def get_ui_field_behaviour() -> dict[str, Any]:
+ """Return custom field behaviour for the Airflow connection form."""
+ return {
+ "hidden_fields": ["schema", "port", "login"],
+ "relabeling": {"password": "API Key"},
+ "placeholders": {
+ "host": "https://api.openai.com/v1 (optional, for custom endpoints)",
+ "extra": '{"model": "openai:gpt-5.3"}',
+ },
+ }
+
+ def get_conn(self) -> Model:
+ """
+ Return a configured pydantic-ai Model.
+
+ Reads API key from connection password, model from connection extra
+ or ``model_id`` parameter, and base_url from connection host.
+ The result is cached for the lifetime of this hook instance.
+ """
+ if self._model is not None:
+ return self._model
+
+ conn = self.get_connection(self.llm_conn_id)
+ model_name: str | KnownModelName = self.model_id or conn.extra_dejson.get("model", "")
+ if not model_name:
+ raise ValueError(
+ "No model specified. Set model_id on the hook or 'model' in the connection's extra JSON."
+ )
+ api_key = conn.password
+ base_url = conn.host or None
+
+ if not api_key and not base_url:
+ # No credentials to inject — use default provider resolution
+ # (picks up env vars like OPENAI_API_KEY, AWS_PROFILE, etc.)
+ self._model = infer_model(model_name)
+ return self._model
+
+ def _provider_factory(provider_name: str) -> Provider[Any]:
+ """
+ Create a provider with credentials from the Airflow connection.
+
+ Falls back to default provider resolution if the provider's constructor
+ doesn't accept api_key/base_url (e.g. Google Vertex, Bedrock).
+ """
+ provider_cls = infer_provider_class(provider_name)
+ kwargs: dict[str, Any] = {}
+ if api_key:
+ kwargs["api_key"] = api_key
+ if base_url:
+ kwargs["base_url"] = base_url
+ try:
+ return provider_cls(**kwargs)
+ except TypeError:
+ # Provider doesn't accept these kwargs (e.g. Google Vertex/GLA
+ # use ADC, Bedrock uses boto session). Fall back to default
+ # provider resolution which reads credentials from the environment.
+ return infer_provider(provider_name)
+
+ self._model = infer_model(model_name, provider_factory=_provider_factory)
+ return self._model
+
+ @overload
+ def create_agent(
+ self, output_type: type[OutputT], *, instructions: str, **agent_kwargs
+ ) -> Agent[None, OutputT]: ...
+
+ @overload
+ def create_agent(self, *, instructions: str, **agent_kwargs) -> Agent[None, str]: ...
+
+ def create_agent(
+ self, output_type: type[Any] = str, *, instructions: str, **agent_kwargs
+ ) -> Agent[None, Any]:
+ """
+ Create a pydantic-ai Agent configured with this hook's model.
+
+ :param output_type: The expected output type from the agent (default: ``str``).
+ :param instructions: System-level instructions for the agent.
+ :param agent_kwargs: Additional keyword arguments passed to the Agent constructor.
+ """
+ return Agent(self.get_conn(), output_type=output_type, instructions=instructions, **agent_kwargs)
+
+ def test_connection(self) -> tuple[bool, str]:
+ """
+ Test connection by resolving the model.
+
+ Validates that the model string is valid, the provider package is
+ installed, and the provider class can be instantiated. Does NOT make an
+ LLM API call — that would be expensive, flaky, and fail for reasons
+ unrelated to connectivity (quotas, billing, rate limits).
+ """
+ try:
+ self.get_conn()
+ return True, "Model resolved successfully."
+ except Exception as e:
+ return False, str(e)
diff --git a/providers/common/ai/tests/unit/common/ai/test_empty.py b/providers/common/ai/tests/unit/common/ai/hooks/__init__.py
similarity index 85%
rename from providers/common/ai/tests/unit/common/ai/test_empty.py
rename to providers/common/ai/tests/unit/common/ai/hooks/__init__.py
index b92357a7e5c..13a83393a91 100644
--- a/providers/common/ai/tests/unit/common/ai/test_empty.py
+++ b/providers/common/ai/tests/unit/common/ai/hooks/__init__.py
@@ -14,10 +14,3 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
-# To remove later, we need at least one test, otherwise CI fails
-from __future__ import annotations
-
-
-def test_empty():
- assert True
diff --git a/providers/common/ai/tests/unit/common/ai/hooks/test_pydantic_ai.py b/providers/common/ai/tests/unit/common/ai/hooks/test_pydantic_ai.py
new file mode 100644
index 00000000000..8f8dfbcad06
--- /dev/null
+++ b/providers/common/ai/tests/unit/common/ai/hooks/test_pydantic_ai.py
@@ -0,0 +1,298 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from unittest.mock import MagicMock, patch
+
+import pytest
+from pydantic_ai.models import Model
+
+from airflow.models.connection import Connection
+from airflow.providers.common.ai.hooks.pydantic_ai import PydanticAIHook
+
+
+class TestPydanticAIHookInit:
+ def test_default_conn_id(self):
+ hook = PydanticAIHook()
+ assert hook.llm_conn_id == "pydantic_ai_default"
+ assert hook.model_id is None
+
+ def test_custom_conn_id(self):
+ hook = PydanticAIHook(llm_conn_id="my_llm", model_id="openai:gpt-5.3")
+ assert hook.llm_conn_id == "my_llm"
+ assert hook.model_id == "openai:gpt-5.3"
+
+
+class TestPydanticAIHookGetConn:
+ @patch("airflow.providers.common.ai.hooks.pydantic_ai.infer_model", autospec=True)
+ @patch("airflow.providers.common.ai.hooks.pydantic_ai.infer_provider_class", autospec=True)
+ def test_get_conn_with_api_key_and_base_url(self, mock_infer_provider_class, mock_infer_model):
+ """Credentials are injected via provider_factory, not as direct kwargs."""
+ mock_model = MagicMock(spec=Model)
+ mock_infer_model.return_value = mock_model
+ mock_provider = MagicMock()
+ mock_infer_provider_class.return_value = MagicMock(return_value=mock_provider)
+
+ hook = PydanticAIHook(llm_conn_id="test_conn", model_id="openai:gpt-5.3")
+ conn = Connection(
+ conn_id="test_conn",
+ conn_type="pydantic_ai",
+ password="sk-test-key",
+ host="https://api.openai.com/v1",
+ )
+ with patch.object(hook, "get_connection", return_value=conn):
+ result = hook.get_conn()
+
+ assert result is mock_model
+ mock_infer_model.assert_called_once()
+ call_args = mock_infer_model.call_args
+ assert call_args[0][0] == "openai:gpt-5.3"
+ # provider_factory should be passed as keyword arg
+ assert "provider_factory" in call_args[1]
+
+ # Call the factory to verify it creates the provider with credentials
+ factory = call_args[1]["provider_factory"]
+ factory("openai")
+ mock_infer_provider_class.assert_called_with("openai")
+ mock_infer_provider_class.return_value.assert_called_with(
+ api_key="sk-test-key", base_url="https://api.openai.com/v1"
+ )
+
+ @patch("airflow.providers.common.ai.hooks.pydantic_ai.infer_model", autospec=True)
+ @patch("airflow.providers.common.ai.hooks.pydantic_ai.infer_provider_class", autospec=True)
+ def test_get_conn_with_model_from_extra(self, mock_infer_provider_class, mock_infer_model):
+ mock_model = MagicMock(spec=Model)
+ mock_infer_model.return_value = mock_model
+ mock_infer_provider_class.return_value = MagicMock(return_value=MagicMock())
+
+ hook = PydanticAIHook(llm_conn_id="test_conn")
+ conn = Connection(
+ conn_id="test_conn",
+ conn_type="pydantic_ai",
+ password="sk-test-key",
+ extra='{"model": "anthropic:claude-opus-4-6"}',
+ )
+ with patch.object(hook, "get_connection", return_value=conn):
+ result = hook.get_conn()
+
+ assert result is mock_model
+ assert mock_infer_model.call_args[0][0] == "anthropic:claude-opus-4-6"
+
+ @patch("airflow.providers.common.ai.hooks.pydantic_ai.infer_model", autospec=True)
+ @patch("airflow.providers.common.ai.hooks.pydantic_ai.infer_provider_class", autospec=True)
+ def test_model_id_param_overrides_extra(self, mock_infer_provider_class, mock_infer_model):
+ mock_infer_model.return_value = MagicMock(spec=Model)
+ mock_infer_provider_class.return_value = MagicMock(return_value=MagicMock())
+
+ hook = PydanticAIHook(llm_conn_id="test_conn",
model_id="openai:gpt-5.3")
+ conn = Connection(
+ conn_id="test_conn",
+ conn_type="pydantic_ai",
+ password="sk-test-key",
+ extra='{"model": "anthropic:claude-opus-4-6"}',
+ )
+ with patch.object(hook, "get_connection", return_value=conn):
+ hook.get_conn()
+
+ # model_id param takes priority over extra
+ assert mock_infer_model.call_args[0][0] == "openai:gpt-5.3"
+
+ def test_get_conn_raises_when_no_model(self):
+ hook = PydanticAIHook(llm_conn_id="test_conn")
+ conn = Connection(
+ conn_id="test_conn",
+ conn_type="pydantic_ai",
+ password="sk-test-key",
+ )
+ with patch.object(hook, "get_connection", return_value=conn):
+ with pytest.raises(ValueError, match="No model specified"):
+ hook.get_conn()
+
+ @patch("airflow.providers.common.ai.hooks.pydantic_ai.infer_model",
autospec=True)
+ def test_get_conn_without_credentials_uses_default_provider(self,
mock_infer_model):
+ """No api_key or base_url means env-based auth (Bedrock, Vertex,
etc.)."""
+ mock_model = MagicMock(spec=Model)
+ mock_infer_model.return_value = mock_model
+
+ hook = PydanticAIHook(llm_conn_id="test_conn",
model_id="bedrock:us.anthropic.claude-v2")
+ conn = Connection(
+ conn_id="test_conn",
+ conn_type="pydantic_ai",
+ )
+ with patch.object(hook, "get_connection", return_value=conn):
+ hook.get_conn()
+
+ # No provider_factory — uses default infer_provider which reads env vars
+ mock_infer_model.assert_called_once_with("bedrock:us.anthropic.claude-v2")
+
+ @patch("airflow.providers.common.ai.hooks.pydantic_ai.infer_model",
autospec=True)
+
@patch("airflow.providers.common.ai.hooks.pydantic_ai.infer_provider_class",
autospec=True)
+ def test_get_conn_with_base_url_only(self, mock_infer_provider_class,
mock_infer_model):
+ """Ollama / vLLM: base_url but no API key."""
+ mock_infer_model.return_value = MagicMock(spec=Model)
+ mock_infer_provider_class.return_value = MagicMock(return_value=MagicMock())
+
+ hook = PydanticAIHook(llm_conn_id="test_conn",
model_id="openai:llama3")
+ conn = Connection(
+ conn_id="test_conn",
+ conn_type="pydantic_ai",
+ host="http://localhost:11434/v1",
+ )
+ with patch.object(hook, "get_connection", return_value=conn):
+ hook.get_conn()
+
+ # provider_factory should be used since base_url is set
+ factory = mock_infer_model.call_args[1]["provider_factory"]
+ factory("openai")
+ mock_infer_provider_class.return_value.assert_called_with(base_url="http://localhost:11434/v1")
+
+ @patch("airflow.providers.common.ai.hooks.pydantic_ai.infer_model",
autospec=True)
+ def test_get_conn_caches_model(self, mock_infer_model):
+ """get_conn() should resolve the model once and cache it."""
+ mock_model = MagicMock(spec=Model)
+ mock_infer_model.return_value = mock_model
+
+ hook = PydanticAIHook(llm_conn_id="test_conn",
model_id="openai:gpt-5.3")
+ conn = Connection(conn_id="test_conn", conn_type="pydantic_ai")
+ with patch.object(hook, "get_connection", return_value=conn):
+ first = hook.get_conn()
+ second = hook.get_conn()
+
+ assert first is second
+ mock_infer_model.assert_called_once()
+
+ @patch("airflow.providers.common.ai.hooks.pydantic_ai.infer_provider",
autospec=True)
+
@patch("airflow.providers.common.ai.hooks.pydantic_ai.infer_provider_class",
autospec=True)
+ @patch("airflow.providers.common.ai.hooks.pydantic_ai.infer_model",
autospec=True)
+ def test_provider_factory_falls_back_on_unsupported_kwargs(
+ self, mock_infer_model, mock_infer_provider_class, mock_infer_provider
+ ):
+ """If a provider rejects api_key/base_url, fall back to default
resolution."""
+ mock_infer_model.return_value = MagicMock(spec=Model)
+ mock_fallback_provider = MagicMock()
+ mock_infer_provider.return_value = mock_fallback_provider
+ # Simulate a provider that doesn't accept api_key/base_url
+ mock_infer_provider_class.return_value = MagicMock(side_effect=TypeError("unexpected keyword"))
+
+ hook = PydanticAIHook(llm_conn_id="test_conn",
model_id="google:gemini-2.0-flash")
+ conn = Connection(
+ conn_id="test_conn",
+ conn_type="pydantic_ai",
+ password="some-key",
+ )
+ with patch.object(hook, "get_connection", return_value=conn):
+ hook.get_conn()
+
+ factory = mock_infer_model.call_args[1]["provider_factory"]
+ result = factory("google-gla")
+
+ # Should have tried provider_cls first, then fallen back to infer_provider
+ mock_infer_provider_class.return_value.assert_called_once_with(api_key="some-key")
+ mock_infer_provider.assert_called_with("google-gla")
+ assert result is mock_fallback_provider
+
+
+class TestPydanticAIHookCreateAgent:
+ @patch("airflow.providers.common.ai.hooks.pydantic_ai.infer_model",
autospec=True)
+ @patch("airflow.providers.common.ai.hooks.pydantic_ai.Agent",
autospec=True)
+ def test_create_agent_defaults(self, mock_agent_cls, mock_infer_model):
+ mock_model = MagicMock(spec=Model)
+ mock_infer_model.return_value = mock_model
+
+ hook = PydanticAIHook(llm_conn_id="test_conn",
model_id="openai:gpt-5.3")
+ conn = Connection(
+ conn_id="test_conn",
+ conn_type="pydantic_ai",
+ )
+ with patch.object(hook, "get_connection", return_value=conn):
+ hook.create_agent(instructions="You are a helpful assistant.")
+
+ mock_agent_cls.assert_called_once_with(
+ mock_model,
+ output_type=str,
+ instructions="You are a helpful assistant.",
+ )
+
+ @patch("airflow.providers.common.ai.hooks.pydantic_ai.infer_model",
autospec=True)
+ @patch("airflow.providers.common.ai.hooks.pydantic_ai.Agent",
autospec=True)
+ def test_create_agent_with_params(self, mock_agent_cls, mock_infer_model):
+ mock_model = MagicMock(spec=Model)
+ mock_infer_model.return_value = mock_model
+
+ hook = PydanticAIHook(llm_conn_id="test_conn",
model_id="openai:gpt-5.3")
+ conn = Connection(
+ conn_id="test_conn",
+ conn_type="pydantic_ai",
+ )
+ with patch.object(hook, "get_connection", return_value=conn):
+ hook.create_agent(
+ output_type=dict,
+ instructions="Be helpful.",
+ retries=3,
+ )
+
+ mock_agent_cls.assert_called_once_with(
+ mock_model,
+ output_type=dict,
+ instructions="Be helpful.",
+ retries=3,
+ )
+
+
+class TestPydanticAIHookTestConnection:
+ @patch("airflow.providers.common.ai.hooks.pydantic_ai.infer_model",
autospec=True)
+ def test_successful_connection(self, mock_infer_model):
+ mock_model = MagicMock(spec=Model)
+ mock_infer_model.return_value = mock_model
+
+ hook = PydanticAIHook(llm_conn_id="test_conn",
model_id="openai:gpt-5.3")
+ conn = Connection(
+ conn_id="test_conn",
+ conn_type="pydantic_ai",
+ )
+ with patch.object(hook, "get_connection", return_value=conn):
+ success, message = hook.test_connection()
+
+ assert success is True
+ assert message == "Model resolved successfully."
+
+ @patch("airflow.providers.common.ai.hooks.pydantic_ai.infer_model",
autospec=True)
+ def test_failed_connection(self, mock_infer_model):
+ mock_infer_model.side_effect = ValueError("Unknown provider 'badprovider'")
+
+ hook = PydanticAIHook(llm_conn_id="test_conn",
model_id="badprovider:model")
+ conn = Connection(
+ conn_id="test_conn",
+ conn_type="pydantic_ai",
+ )
+ with patch.object(hook, "get_connection", return_value=conn):
+ success, message = hook.test_connection()
+
+ assert success is False
+ assert "Unknown provider" in message
+
+ def test_failed_connection_no_model(self):
+ hook = PydanticAIHook(llm_conn_id="test_conn")
+ conn = Connection(
+ conn_id="test_conn",
+ conn_type="pydantic_ai",
+ )
+ with patch.object(hook, "get_connection", return_value=conn):
+ success, message = hook.test_connection()
+
+ assert success is False
+ assert "No model specified" in message
diff --git a/providers/yandex/docs/index.rst b/providers/yandex/docs/index.rst
index 46f256e0d45..b5f40a17ffd 100644
--- a/providers/yandex/docs/index.rst
+++ b/providers/yandex/docs/index.rst
@@ -109,7 +109,7 @@ PIP package Version required
``yandexcloud`` ``>=0.328.0; python_version < "3.13"``
``yandex-query-client`` ``>=0.1.4``
``apache-airflow-providers-common-compat`` ``>=1.13.0``
-``grpcio`` ``>=1.70.0; python_version >= "3.13"``
+``grpcio`` ``>=1.70.0``
========================================== =======================================
Cross provider package dependencies
diff --git a/providers/yandex/pyproject.toml b/providers/yandex/pyproject.toml
index 96962ca0da5..fc54c98c195 100644
--- a/providers/yandex/pyproject.toml
+++ b/providers/yandex/pyproject.toml
@@ -67,7 +67,7 @@ dependencies = [
# ERROR providers/yandex/tests/unit/yandex/operators/test_dataproc.py - RuntimeError: The grpc package installed is at version 1.68.1, but the generated code in yandex/cloud/endpoint/api_endpoint_service_pb2_grpc.py depends on grpcio>=1.70.0. Please upgrade your grpc module to grpcio>=1.70.0 or downgrade your generated code using grpcio-tools<=1.68.1.
# this dependency can be removed when yandexcloud bump min version of grpcio to 1.70
# https://github.com/yandex-cloud/python-sdk/blob/82493e32bbf1d678afbb8376632b3f5b5923fd10/pyproject.toml#L23
- 'grpcio>=1.70.0; python_version >= "3.13"',
+ "grpcio>=1.70.0",
]
[dependency-groups]