kaxil commented on code in PR #62816:
URL: https://github.com/apache/airflow/pull/62816#discussion_r2892283022
##########
providers/common/ai/src/airflow/providers/common/ai/hooks/pydantic_ai.py:
##########
@@ -75,59 +74,104 @@ def get_ui_field_behaviour() -> dict[str, Any]:
"hidden_fields": ["schema", "port", "login"],
"relabeling": {"password": "API Key"},
"placeholders": {
- "host": "https://api.openai.com/v1 (optional, for custom
endpoints)",
+ "host": "https://api.openai.com/v1 (optional, for custom
endpoints / Ollama)",
"extra": '{"model": "openai:gpt-5.3"}',
},
}
+ # ------------------------------------------------------------------
+ # Core connection / agent API
+ # ------------------------------------------------------------------
+
+ def _get_provider_kwargs(
+ self,
+ api_key: str | None,
+ base_url: str | None,
+ extra: dict[str, Any],
+ ) -> dict[str, Any]:
+ """
+ Return the kwargs to pass to the provider constructor.
+
+ Subclasses override this method to map their connection fields to the
+ parameters expected by their specific provider class. The base
+ implementation handles the common ``api_key`` / ``base_url`` pattern
+ used by OpenAI, Anthropic, Groq, Mistral, Ollama, and most other
+ providers.
+
+ :param api_key: Value of ``conn.password``.
+ :param base_url: Value of ``conn.host``.
+ :param extra: Deserialized ``conn.extra`` JSON.
+ :return: Kwargs forwarded to ``provider_cls(**kwargs)``. Empty dict
+ signals that no explicit credentials are available and the hook
+ should fall back to environment-variable–based auth.
+ """
+ kwargs: dict[str, Any] = {}
+ if api_key:
+ kwargs["api_key"] = api_key
+ if base_url:
+ kwargs["base_url"] = base_url
+ return kwargs
+
def get_conn(self) -> Model:
"""
- Return a configured pydantic-ai Model.
+ Return a configured pydantic-ai ``Model``.
- Reads API key from connection password, model from connection extra
- or ``model_id`` parameter, and base_url from connection host.
- The result is cached for the lifetime of this hook instance.
+ Resolution order:
+
+ 1. **Explicit credentials** — when :meth:`_get_provider_kwargs` returns
+ a non-empty dict the provider class is instantiated with those
kwargs
+ and wrapped in a ``provider_factory``.
+ 2. **Default resolution** — delegates to pydantic-ai ``infer_model``
+ which reads standard env vars (``OPENAI_API_KEY``, ``AWS_PROFILE``,
…).
+
+ The resolved model is cached for the lifetime of this hook instance.
"""
if self._model is not None:
return self._model
conn = self.get_connection(self.llm_conn_id)
- model_name: str | KnownModelName = self.model_id or
conn.extra_dejson.get("model", "")
+
+ extra: dict[str, Any] = conn.extra_dejson
+ model_name: str | KnownModelName = self.model_id or extra.get("model",
"")
if not model_name:
raise ValueError(
"No model specified. Set model_id on the hook or 'model' in
the connection's extra JSON."
)
- api_key = conn.password
- base_url = conn.host or None
- if not api_key and not base_url:
- # No credentials to inject — use default provider resolution
- # (picks up env vars like OPENAI_API_KEY, AWS_PROFILE, etc.)
- self._model = infer_model(model_name)
+ api_key: str | None = conn.password or None
+ base_url: str | None = conn.host or None
+
+ # Auto-dispatch: if using base hook with a subclass connection type,
borrow
+ # that subclass's _get_provider_kwargs so the correct field mapping is
used.
+ if type(self) is PydanticAIHook:
+ hook_cls = _CONN_TYPE_TO_HOOK.get(conn.conn_type or "",
PydanticAIHook)
+ provider_kwargs = hook_cls._get_provider_kwargs(self, api_key,
base_url, extra)
+ else:
+ provider_kwargs = self._get_provider_kwargs(api_key, base_url,
extra)
+ if provider_kwargs:
+ _kwargs = provider_kwargs # capture for closure
+ self.log.info(
+ "Using explicit credentials for provider with model '%s': %s",
+ model_name,
+ list(provider_kwargs),
+ )
+
+ def _provider_factory(pname: str) -> Any:
+ try:
+ return infer_provider_class(pname)(**_kwargs)
+ except TypeError:
+ self.log.warning(
+ "Provider '%s' does not accept the supplied kwargs %s;
"
Review Comment:
`TypeError` here catches both "provider doesn't accept these kwargs" and
genuine bugs (wrong types, missing required args). The warning only logs kwarg
*names*, not the actual error. This makes debugging auth failures harder — a
user who passes the wrong kwarg name gets silently downgraded to env-var auth
with a vague warning.
Suggestion: log the exception too:
```python
except TypeError as exc:
self.log.warning(
"Provider '%s' does not accept kwargs %s (%s); "
"falling back to env-var auth.",
pname, list(_kwargs), exc,
)
```
##########
providers/common/ai/src/airflow/providers/common/ai/hooks/pydantic_ai.py:
##########
@@ -154,13 +198,213 @@ def test_connection(self) -> tuple[bool, str]:
"""
Test connection by resolving the model.
- Validates that the model string is valid, the provider package is
- installed, and the provider class can be instantiated. Does NOT make an
- LLM API call — that would be expensive, flaky, and fail for reasons
- unrelated to connectivity (quotas, billing, rate limits).
+ Validates that the model string is valid and the provider class can be
+ instantiated with the supplied credentials. Does NOT make an LLM API
+ call — that would be expensive and fail for reasons unrelated to
+ connectivity (quotas, billing, rate limits).
"""
try:
self.get_conn()
return True, "Model resolved successfully."
except Exception as e:
return False, str(e)
+
+ @classmethod
+ def for_connection(cls, conn_id: str, model_id: str | None = None) ->
PydanticAIHook:
+ """
+ Return the correct :class:`PydanticAIHook` subclass for *conn_id*.
+
+ Looks up the connection's ``conn_type`` in the registered hook map and
+ instantiates the matching subclass. Falls back to
+ :class:`PydanticAIHook` for unknown types.
+
+ :param conn_id: Airflow connection ID.
+ :param model_id: Optional model override forwarded to the hook.
+ """
+ conn = cls.get_connection(conn_id)
+ hook_cls = _CONN_TYPE_TO_HOOK.get(conn.conn_type or "", cls)
+ return hook_cls(llm_conn_id=conn_id, model_id=model_id)
+
+
+class PydanticAIAzureHook(PydanticAIHook):
+ """
+ Hook for Azure OpenAI via pydantic-ai.
+
+ Connection fields:
+ - **password**: Azure API key
+ - **host**: Azure endpoint (e.g.
``https://<resource>.openai.azure.com``)
+ - **extra** JSON::
+
+ {"model": "azure:gpt-4o", "api_version": "2024-07-01-preview"}
+
+ :param llm_conn_id: Airflow connection ID.
+ :param model_id: Model identifier, e.g. ``"azure:gpt-4o"``.
+ """
+
+ conn_name_attr = "llm_conn_id"
+ default_conn_name = "pydantic_ai_azure_default"
+ conn_type = "pydanticai_azure"
+ hook_name = "Pydantic AI (Azure OpenAI)"
+
+ @staticmethod
+ def get_ui_field_behaviour() -> dict[str, Any]:
+ """Return custom field behaviour for the Airflow connection form."""
+ return {
+ "hidden_fields": ["schema", "port", "login"],
+ "relabeling": {"password": "API Key", "host": "Azure Endpoint"},
+ "placeholders": {
+ "host": "https://<resource>.openai.azure.com",
+ "extra": '{"model": "azure:gpt-4o", "api_version":
"2024-07-01-preview"}',
+ },
+ }
+
+ def _get_provider_kwargs(
+ self,
+ api_key: str | None,
+ base_url: str | None,
+ extra: dict[str, Any],
+ ) -> dict[str, Any]:
+ kwargs: dict[str, Any] = {}
+ if api_key:
+ kwargs["api_key"] = api_key
+ if base_url:
+ kwargs["azure_endpoint"] = base_url
+ if extra.get("api_version"):
+ kwargs["api_version"] = extra["api_version"]
+
+ self.log.info(
+ "Using explicit credentials for Azure provider with model '%s':
%s",
+ extra.get("model", ""),
Review Comment:
This `self.log.info(...)` duplicates the log in `get_conn()`, which already logs
whenever `provider_kwargs` is non-empty. As written, Azure connections produce two
"Using explicit credentials..." log lines while Bedrock and Vertex produce one.
Remove this call and let the base class handle logging uniformly.
##########
providers/common/ai/src/airflow/providers/common/ai/hooks/pydantic_ai.py:
##########
@@ -154,13 +198,213 @@ def test_connection(self) -> tuple[bool, str]:
"""
Test connection by resolving the model.
- Validates that the model string is valid, the provider package is
- installed, and the provider class can be instantiated. Does NOT make an
- LLM API call — that would be expensive, flaky, and fail for reasons
- unrelated to connectivity (quotas, billing, rate limits).
+ Validates that the model string is valid and the provider class can be
+ instantiated with the supplied credentials. Does NOT make an LLM API
+ call — that would be expensive and fail for reasons unrelated to
+ connectivity (quotas, billing, rate limits).
"""
try:
self.get_conn()
return True, "Model resolved successfully."
except Exception as e:
return False, str(e)
+
+ @classmethod
+ def for_connection(cls, conn_id: str, model_id: str | None = None) ->
PydanticAIHook:
+ """
+ Return the correct :class:`PydanticAIHook` subclass for *conn_id*.
+
+ Looks up the connection's ``conn_type`` in the registered hook map and
+ instantiates the matching subclass. Falls back to
+ :class:`PydanticAIHook` for unknown types.
+
+ :param conn_id: Airflow connection ID.
+ :param model_id: Optional model override forwarded to the hook.
+ """
+ conn = cls.get_connection(conn_id)
+ hook_cls = _CONN_TYPE_TO_HOOK.get(conn.conn_type or "", cls)
+ return hook_cls(llm_conn_id=conn_id, model_id=model_id)
+
+
+class PydanticAIAzureHook(PydanticAIHook):
+ """
+ Hook for Azure OpenAI via pydantic-ai.
+
+ Connection fields:
+ - **password**: Azure API key
+ - **host**: Azure endpoint (e.g.
``https://<resource>.openai.azure.com``)
+ - **extra** JSON::
+
+ {"model": "azure:gpt-4o", "api_version": "2024-07-01-preview"}
+
+ :param llm_conn_id: Airflow connection ID.
+ :param model_id: Model identifier, e.g. ``"azure:gpt-4o"``.
+ """
+
+ conn_name_attr = "llm_conn_id"
+ default_conn_name = "pydantic_ai_azure_default"
+ conn_type = "pydanticai_azure"
+ hook_name = "Pydantic AI (Azure OpenAI)"
Review Comment:
Python evaluates default argument values at function/method definition time.
`PydanticAIHook.__init__` binds `llm_conn_id`'s default to
`"pydanticai_default"` when the class is created. Subclasses inherit this
`__init__` with the already-resolved default, so:
```python
hook = PydanticAIAzureHook() # llm_conn_id="pydanticai_default", not
"pydantic_ai_azure_default"
```
The `default_conn_name` attribute on the subclass is never consulted during
`__init__`. Same issue for Bedrock and Vertex subclasses.
##########
providers/common/ai/src/airflow/providers/common/ai/hooks/pydantic_ai.py:
##########
@@ -75,59 +74,104 @@ def get_ui_field_behaviour() -> dict[str, Any]:
"hidden_fields": ["schema", "port", "login"],
"relabeling": {"password": "API Key"},
"placeholders": {
- "host": "https://api.openai.com/v1 (optional, for custom
endpoints)",
+ "host": "https://api.openai.com/v1 (optional, for custom
endpoints / Ollama)",
"extra": '{"model": "openai:gpt-5.3"}',
},
}
+ # ------------------------------------------------------------------
+ # Core connection / agent API
+ # ------------------------------------------------------------------
+
+ def _get_provider_kwargs(
+ self,
+ api_key: str | None,
+ base_url: str | None,
+ extra: dict[str, Any],
+ ) -> dict[str, Any]:
+ """
+ Return the kwargs to pass to the provider constructor.
+
+ Subclasses override this method to map their connection fields to the
+ parameters expected by their specific provider class. The base
+ implementation handles the common ``api_key`` / ``base_url`` pattern
+ used by OpenAI, Anthropic, Groq, Mistral, Ollama, and most other
+ providers.
+
+ :param api_key: Value of ``conn.password``.
+ :param base_url: Value of ``conn.host``.
+ :param extra: Deserialized ``conn.extra`` JSON.
+ :return: Kwargs forwarded to ``provider_cls(**kwargs)``. Empty dict
+ signals that no explicit credentials are available and the hook
+ should fall back to environment-variable–based auth.
+ """
+ kwargs: dict[str, Any] = {}
+ if api_key:
+ kwargs["api_key"] = api_key
+ if base_url:
+ kwargs["base_url"] = base_url
+ return kwargs
+
def get_conn(self) -> Model:
"""
- Return a configured pydantic-ai Model.
+ Return a configured pydantic-ai ``Model``.
- Reads API key from connection password, model from connection extra
- or ``model_id`` parameter, and base_url from connection host.
- The result is cached for the lifetime of this hook instance.
+ Resolution order:
+
+ 1. **Explicit credentials** — when :meth:`_get_provider_kwargs` returns
+ a non-empty dict the provider class is instantiated with those
kwargs
+ and wrapped in a ``provider_factory``.
+ 2. **Default resolution** — delegates to pydantic-ai ``infer_model``
+ which reads standard env vars (``OPENAI_API_KEY``, ``AWS_PROFILE``,
…).
+
+ The resolved model is cached for the lifetime of this hook instance.
"""
if self._model is not None:
return self._model
conn = self.get_connection(self.llm_conn_id)
- model_name: str | KnownModelName = self.model_id or
conn.extra_dejson.get("model", "")
+
+ extra: dict[str, Any] = conn.extra_dejson
+ model_name: str | KnownModelName = self.model_id or extra.get("model",
"")
if not model_name:
raise ValueError(
"No model specified. Set model_id on the hook or 'model' in
the connection's extra JSON."
)
- api_key = conn.password
- base_url = conn.host or None
- if not api_key and not base_url:
- # No credentials to inject — use default provider resolution
- # (picks up env vars like OPENAI_API_KEY, AWS_PROFILE, etc.)
- self._model = infer_model(model_name)
+ api_key: str | None = conn.password or None
+ base_url: str | None = conn.host or None
+
+ # Auto-dispatch: if using base hook with a subclass connection type,
borrow
+ # that subclass's _get_provider_kwargs so the correct field mapping is
used.
+ if type(self) is PydanticAIHook:
+ hook_cls = _CONN_TYPE_TO_HOOK.get(conn.conn_type or "",
PydanticAIHook)
+ provider_kwargs = hook_cls._get_provider_kwargs(self, api_key,
base_url, extra)
+ else:
+ provider_kwargs = self._get_provider_kwargs(api_key, base_url,
extra)
Review Comment:
This calls a subclass's unbound method with a `PydanticAIHook` instance as
`self`:
```python
hook_cls._get_provider_kwargs(self, api_key, base_url, extra)
```
It works today because `_get_provider_kwargs` only touches `self.log`
(inherited from BaseHook). But if any subclass later accesses subclass-specific
attributes, this will raise an `AttributeError` at runtime — and nothing in the
type system or tests guards against that.
A cleaner alternative: have the operator use `for_connection()` instead of
constructing `PydanticAIHook` directly. That returns the right subclass, and no
auto-dispatch is needed:
```python
@cached_property
def llm_hook(self) -> PydanticAIHook:
return PydanticAIHook.for_connection(self.llm_conn_id,
model_id=self.model_id)
```
This eliminates the `type(self) is PydanticAIHook` check, the
`_CONN_TYPE_TO_HOOK` registry, and the unbound method call.
##########
providers/common/ai/src/airflow/providers/common/ai/get_provider_info.py:
##########
@@ -60,7 +60,41 @@ def get_provider_info():
"extra": '{"model": "openai:gpt-5"}',
},
},
- }
+ },
Review Comment:
This file is auto-generated from `provider.yaml` — manual edits here will be
overwritten by `prek run update-providers-build-files`. The new connection
types need to be added to `provider.yaml` instead.
##########
providers/common/ai/src/airflow/providers/common/ai/hooks/pydantic_ai.py:
##########
@@ -154,13 +198,213 @@ def test_connection(self) -> tuple[bool, str]:
"""
Test connection by resolving the model.
- Validates that the model string is valid, the provider package is
- installed, and the provider class can be instantiated. Does NOT make an
- LLM API call — that would be expensive, flaky, and fail for reasons
- unrelated to connectivity (quotas, billing, rate limits).
+ Validates that the model string is valid and the provider class can be
+ instantiated with the supplied credentials. Does NOT make an LLM API
+ call — that would be expensive and fail for reasons unrelated to
+ connectivity (quotas, billing, rate limits).
"""
try:
self.get_conn()
return True, "Model resolved successfully."
except Exception as e:
return False, str(e)
+
+ @classmethod
+ def for_connection(cls, conn_id: str, model_id: str | None = None) ->
PydanticAIHook:
+ """
+ Return the correct :class:`PydanticAIHook` subclass for *conn_id*.
+
+ Looks up the connection's ``conn_type`` in the registered hook map and
+ instantiates the matching subclass. Falls back to
+ :class:`PydanticAIHook` for unknown types.
+
+ :param conn_id: Airflow connection ID.
+ :param model_id: Optional model override forwarded to the hook.
+ """
+ conn = cls.get_connection(conn_id)
+ hook_cls = _CONN_TYPE_TO_HOOK.get(conn.conn_type or "", cls)
+ return hook_cls(llm_conn_id=conn_id, model_id=model_id)
+
+
+class PydanticAIAzureHook(PydanticAIHook):
+ """
+ Hook for Azure OpenAI via pydantic-ai.
+
+ Connection fields:
+ - **password**: Azure API key
+ - **host**: Azure endpoint (e.g.
``https://<resource>.openai.azure.com``)
+ - **extra** JSON::
+
+ {"model": "azure:gpt-4o", "api_version": "2024-07-01-preview"}
+
+ :param llm_conn_id: Airflow connection ID.
+ :param model_id: Model identifier, e.g. ``"azure:gpt-4o"``.
+ """
+
+ conn_name_attr = "llm_conn_id"
+ default_conn_name = "pydantic_ai_azure_default"
+ conn_type = "pydanticai_azure"
+ hook_name = "Pydantic AI (Azure OpenAI)"
+
+ @staticmethod
+ def get_ui_field_behaviour() -> dict[str, Any]:
+ """Return custom field behaviour for the Airflow connection form."""
+ return {
+ "hidden_fields": ["schema", "port", "login"],
+ "relabeling": {"password": "API Key", "host": "Azure Endpoint"},
+ "placeholders": {
+ "host": "https://<resource>.openai.azure.com",
+ "extra": '{"model": "azure:gpt-4o", "api_version":
"2024-07-01-preview"}',
+ },
+ }
+
+ def _get_provider_kwargs(
+ self,
+ api_key: str | None,
+ base_url: str | None,
+ extra: dict[str, Any],
+ ) -> dict[str, Any]:
+ kwargs: dict[str, Any] = {}
+ if api_key:
+ kwargs["api_key"] = api_key
+ if base_url:
+ kwargs["azure_endpoint"] = base_url
+ if extra.get("api_version"):
+ kwargs["api_version"] = extra["api_version"]
+
+ self.log.info(
+ "Using explicit credentials for Azure provider with model '%s':
%s",
+ extra.get("model", ""),
+ list(kwargs),
+ )
+ return kwargs
+
+
+class PydanticAIBedrockHook(PydanticAIHook):
+ """
+ Hook for AWS Bedrock via pydantic-ai.
+
+ Credentials are resolved in order:
+
+ 1. Explicit keys in ``extra`` (``aws_access_key_id``,
+ ``aws_secret_access_key``, ``aws_session_token``).
+ 2. Environment-variable / instance-role chain (``AWS_PROFILE``,
+ IAM role, …) when no explicit keys are provided.
+
+ Connection fields:
+ - **extra** JSON::
+
+ {
+ "model": "bedrock:us.anthropic.claude-opus-4-5",
+ "region_name": "us-east-1",
+ "aws_access_key_id": "AKIA...",
+ "aws_secret_access_key": "..."
+ }
+
+ Leave ``aws_access_key_id`` / ``aws_secret_access_key`` empty to use
+ the default AWS credential chain.
+
+ :param llm_conn_id: Airflow connection ID.
+ :param model_id: Model identifier, e.g.
``"bedrock:us.anthropic.claude-opus-4-5"``.
+ """
+
+ conn_name_attr = "llm_conn_id"
+ default_conn_name = "pydantic_ai_bedrock_default"
+ conn_type = "pydanticai_bedrock"
+ hook_name = "Pydantic AI (AWS Bedrock)"
+
+ @staticmethod
+ def get_ui_field_behaviour() -> dict[str, Any]:
+ """Return custom field behaviour for the Airflow connection form."""
+ return {
+ "hidden_fields": ["schema", "port", "login", "host", "password"],
+ "relabeling": {},
+ "placeholders": {
+ "extra": (
+ '{"model": "bedrock:us.anthropic.claude-opus-4-5", '
+ '"region_name": "us-east-1"}'
+ " — leave aws_access_key_id empty for IAM role / env-var
auth"
+ ),
+ },
+ }
+
+ def _get_provider_kwargs(
+ self,
+ api_key: str | None,
+ base_url: str | None,
+ extra: dict[str, Any],
+ ) -> dict[str, Any]:
+ _bedrock_keys = (
+ "aws_access_key_id",
+ "aws_secret_access_key",
+ "aws_session_token",
+ "region_name",
+ "profile_name",
+ )
+ return {k: extra[k] for k in _bedrock_keys if extra.get(k) is not None}
+
+
+class PydanticAIVertexHook(PydanticAIHook):
+ """
+ Hook for Google Vertex AI via pydantic-ai.
+
+ Credentials are resolved in order:
+
+ 1. ``service_account_file`` / ``service_account_info`` in ``extra``.
+ 2. Application Default Credentials (``GOOGLE_APPLICATION_CREDENTIALS``,
+ ``gcloud auth application-default login``, …) when no explicit keys
+ are provided.
+
+ Connection fields:
+ - **extra** JSON::
+
+ {
+ "model": "google-vertex:gemini-2.0-flash",
+ "project_id": "my-gcp-project",
+ "location": "us-central1",
+ "service_account_file": "/path/to/sa.json",
+ }
+
+ :param llm_conn_id: Airflow connection ID.
+ :param model_id: Model identifier, e.g.
``"google-vertex:gemini-2.0-flash"``.
+ """
+
+ conn_name_attr = "llm_conn_id"
+ default_conn_name = "pydantic_ai_vertex_default"
+ conn_type = "pydanticai_vertex"
+ hook_name = "Pydantic AI (Google Vertex AI)"
+
+ @staticmethod
+ def get_ui_field_behaviour() -> dict[str, Any]:
+ """Return custom field behaviour for the Airflow connection form."""
+ return {
+ "hidden_fields": ["schema", "port", "login", "host", "password"],
+ "relabeling": {},
+ "placeholders": {
+ "extra": (
+ '{"model": "google-vertex:gemini-2.0-flash", '
+ '"project_id": "my-project", "location": "us-central1"}'
+ " — leave service_account_file empty for ADC auth"
+ ),
+ },
+ }
+
+ def _get_provider_kwargs(
+ self,
+ api_key: str | None,
+ base_url: str | None,
+ extra: dict[str, Any],
+ ) -> dict[str, Any]:
+ _vertex_keys = ("project_id", "location", "service_account_file",
"service_account_info")
+ return {k: extra[k] for k in _vertex_keys if extra.get(k) is not None}
+
Review Comment:
These kwarg names don't match what pydantic-ai's `GoogleProvider` accepts
(see [docs](https://ai.pydantic.dev/models/google/)):
| Connection extra | pydantic-ai param | Match? |
|---|---|---|
| `project_id` | `project` | **No** |
| `location` | `location` | Yes |
| `service_account_file` | `credentials` (object) | **No** |
| `service_account_info` | `credentials` (object) | **No** |
`project_id` needs to be mapped to `project`.
`service_account_file`/`service_account_info` need to be loaded into a
`google.auth.credentials.Credentials` object before passing — the provider
doesn't accept file paths directly.
Because of the `TypeError` catch in `_provider_factory`, these mismatches
fail silently — the credentials get ignored and it falls back to ADC. Users who
explicitly configure credentials in the connection would have no idea they're
not being used.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]