kaxil commented on code in PR #62816:
URL: https://github.com/apache/airflow/pull/62816#discussion_r2880383612


##########
providers/common/ai/src/airflow/providers/common/ai/builders/base.py:
##########
@@ -0,0 +1,58 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any, Protocol
+
+if TYPE_CHECKING:
+    from pydantic_ai.models import KnownModelName, Model
+
+
+class ProviderBuilder(Protocol):

Review Comment:
   Do we need this abstraction? The current `get_conn()` is a straightforward 
if/else that delegates to pydantic-ai's own `infer_model()`. Adding a Protocol 
+ 3 builder classes + a dispatch loop for what's essentially a single new code 
path (Azure) feels like premature abstraction for a problem that doesn't exist 
yet. If/when we genuinely need pluggable resolution, we can introduce it then.
   
   Also — `ProviderBuilder` is declared as a `Protocol` (structural typing), 
but the concrete classes inherit from it (nominal typing). These are two 
different patterns — if you want an inheritance hierarchy, use `ABC`; if you 
want duck typing, don't inherit from the `Protocol` in the concrete classes. 
Mixing both is confusing for contributors.



##########
providers/common/ai/src/airflow/providers/common/ai/builders/azure_openai.py:
##########
@@ -0,0 +1,99 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any
+
+if TYPE_CHECKING:
+    from pydantic_ai.models import KnownModelName, Model
+
+from airflow.providers.common.ai.builders.base import ProviderBuilder
+
+
+class AzureOpenAIBuilder(ProviderBuilder):
+    """Builds a pydantic-ai Model backed by an AsyncAzureOpenAI client."""
+
+    def supports(self, extra: dict[str, Any], api_key: str | None, base_url: str | None) -> bool:
+        """Return True when the connection extra contains 'api_version' for Azure."""
+        return bool(extra.get("api_version"))
+
+    def build(
+        self,
+        model_name: str | KnownModelName,
+        extra: dict[str, Any],
+        api_key: str | None,
+        base_url: str | None,
+    ) -> Model:
+        try:
+            from openai import AsyncAzureOpenAI
+            from pydantic_ai.models.openai import OpenAIChatModel
+            from pydantic_ai.providers.openai import OpenAIProvider
+        except ImportError as exc:
+            raise ImportError(
+                "The 'openai' and 'pydantic-ai[openai]' packages are required for Azure OpenAI connections. "
+                "Install them with: pip install 'pydantic-ai[openai]'"
+            ) from exc
+
+        api_version: str | None = extra.get("api_version")
+        if not api_version:
+            raise ValueError(
+                "Connection extra must contain 'api_version' for Azure OpenAI. "
+                'Example: {"api_version": "2024-07-01-preview"}'
+            )
+
+        if not base_url:
+            raise ValueError(
+                "Connection 'host' must be set to the Azure endpoint. "
+                "Example: https://<resource>.openai.azure.com"
+            )
+
+        client_kwargs: dict[str, Any] = {
+            "api_version": api_version,
+            "azure_endpoint": base_url,
+        }
+        if api_key:
+            client_kwargs["api_key"] = api_key
+
+        azure_deployment: str | None = extra.get("azure_deployment")
+        if azure_deployment:
+            client_kwargs["azure_deployment"] = azure_deployment
+
+        azure_ad_token: str | None = extra.get("azure_ad_token")
+        if azure_ad_token:
+            client_kwargs["azure_ad_token"] = azure_ad_token
+
+        azure_ad_token_provider_path: str | None = extra.get("azure_ad_token_provider")
+        if azure_ad_token_provider_path:
+            client_kwargs["azure_ad_token_provider"] = self._import_callable(azure_ad_token_provider_path)
+
+        azure_client = AsyncAzureOpenAI(**client_kwargs)
+
+        # Strip provider prefix if present ("openai:gpt-5.2" → "gpt-5.2")
+        slug = model_name.split(":", 1)[-1] if ":" in model_name else model_name
+        return OpenAIChatModel(slug, provider=OpenAIProvider(openai_client=azure_client))
+
+    @staticmethod
+    def _import_callable(dotted_path: str) -> Any:

Review Comment:
   This calls `importlib.import_module()` on a user-provided string from 
connection extras — module imports can run arbitrary code at import time. 
Connection extras are editable by any user with connection-edit permissions, 
which is a lower-privilege surface than DAG deployment.
   
   If we end up needing a custom token provider path, this should at least be 
documented as a security-sensitive field. But since pydantic-ai's 
`AzureProvider` accepts `api_key` directly, we may not need this at all for the 
initial implementation.



##########
providers/common/ai/src/airflow/providers/common/ai/builders/azure_openai.py:
##########
@@ -0,0 +1,99 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any
+
+if TYPE_CHECKING:
+    from pydantic_ai.models import KnownModelName, Model
+
+from airflow.providers.common.ai.builders.base import ProviderBuilder
+
+
+class AzureOpenAIBuilder(ProviderBuilder):
+    """Builds a pydantic-ai Model backed by an AsyncAzureOpenAI client."""
+
+    def supports(self, extra: dict[str, Any], api_key: str | None, base_url: str | None) -> bool:
+        """Return True when the connection extra contains 'api_version' for Azure."""
+        return bool(extra.get("api_version"))
+
+    def build(
+        self,
+        model_name: str | KnownModelName,
+        extra: dict[str, Any],
+        api_key: str | None,
+        base_url: str | None,
+    ) -> Model:
+        try:
+            from openai import AsyncAzureOpenAI

Review Comment:
   pydantic-ai already has native Azure support via `AzureProvider` — no need 
to drop down to the raw `openai` SDK:
   
   ```python
   from pydantic_ai.providers.azure import AzureProvider
   from pydantic_ai.models.openai import OpenAIChatModel
   
   provider = AzureProvider(
       azure_endpoint=base_url,
       api_version=api_version,
       api_key=api_key,
   )
   model = OpenAIChatModel(model_name, provider=provider)
   ```
   
   See https://ai.pydantic.dev/models/openai/#azure
   
   By constructing `AsyncAzureOpenAI` directly we're bypassing pydantic-ai's 
own provider abstraction (which may handle retries, error mapping, etc.) and 
coupling ourselves to `openai` SDK internals.
   
   Using `AzureProvider` would also make a separate builder class unnecessary — 
it could just be a few lines in `get_conn()`.



##########
providers/common/ai/src/airflow/providers/common/ai/builders/custom_endpoint.py:
##########
@@ -0,0 +1,92 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any
+
+from pydantic_ai.models import infer_model
+from pydantic_ai.providers import Provider, infer_provider, infer_provider_class
+
+if TYPE_CHECKING:
+    from pydantic_ai.models import KnownModelName, Model
+
+from airflow.providers.common.ai.builders.base import ProviderBuilder
+
+
+class CustomEndpointBuilder(ProviderBuilder):
+    """
+    Builds a pydantic-ai Model with explicitly overridden credentials or endpoints.
+
+    Used for general connections (OpenAI, Gemini, Groq, Ollama) where password or host is set.
+    """
+
+    def supports(self, extra: dict[str, Any], api_key: str | None, base_url: str | None) -> bool:
+        """Return True if api_key or base_url is explicitly provided."""
+        return bool(api_key or base_url)
+
+    def build(
+        self,
+        model_name: str | KnownModelName,
+        extra: dict[str, Any],
+        api_key: str | None,
+        base_url: str | None,
+    ) -> Model:
+        return infer_model(model_name, provider_factory=self._make_provider_factory(api_key, base_url))
+
+    @staticmethod
+    def _make_provider_factory(api_key: str | None, base_url: str | None):
+        """
+        Return a provider factory closure for non-Azure custom endpoints.
+
+        Falls back to default provider resolution when the provider's
+        constructor does not accept `api_key` / `base_url`
+        """
+
+        def _factory(provider_name: str) -> Provider[Any]:
+            provider_cls = infer_provider_class(provider_name)
+            kwargs: dict[str, Any] = {}
+            if api_key:
+                kwargs["api_key"] = api_key
+            if base_url:
+                kwargs["base_url"] = base_url
+            try:
+                return provider_cls(**kwargs)
+            except TypeError:
+                return infer_provider(provider_name)
+
+        return _factory
+
+
+class DefaultBuilder(ProviderBuilder):

Review Comment:
   `DefaultBuilder.build()` is `return infer_model(model_name)` — a single 
call. The existing hook does this in two lines: `if not api_key and not 
base_url: return infer_model(model_name)`. Does this one-liner need its own 
class with Protocol conformance and a `supports()` method?



##########
providers/common/ai/src/airflow/providers/common/ai/hooks/pydantic_ai.py:
##########
@@ -75,60 +89,62 @@ def get_ui_field_behaviour() -> dict[str, Any]:
             "hidden_fields": ["schema", "port", "login"],
             "relabeling": {"password": "API Key"},
             "placeholders": {
-                "host": "https://api.openai.com/v1 (optional, for custom endpoints)",
-                "extra": '{"model": "openai:gpt-5.3"}',
+                "host": (
+                    "https://api.openai.com/v1  — or Azure endpoint: https://<resource>.openai.azure.com"
+                ),
+                "extra": ('{"model": "openai:gpt-5.3"}  — Azure: also add "api_version", "azure_deployment"'),
             },
         }
 
+    # ------------------------------------------------------------------
+    # Core connection / agent API
+    # ------------------------------------------------------------------
+
     def get_conn(self) -> Model:
         """
         Return a configured pydantic-ai Model.
 
-        Reads API key from connection password, model from connection extra
-        or ``model_id`` parameter, and base_url from connection host.
-        The result is cached for the lifetime of this hook instance.
+        Resolution is delegated to builders based on connection characteristics:
+
+        1. **Azure OpenAI** — when ``api_version`` is present in connection extra.
+        2. **Custom endpoint** — when ``password`` or ``host`` are set.
+        3. **Default resolution** — delegates to pydantic-ai ``infer_model``.
+
+        The resolved model is cached for the lifetime of this hook instance.
         """
         if self._model is not None:
             return self._model
 
+        from airflow.providers.common.ai.builders import (
+            AzureOpenAIBuilder,
+            CustomEndpointBuilder,
+            DefaultBuilder,
+            ProviderBuilder,
+        )
+
         conn = self.get_connection(self.llm_conn_id)
-        model_name: str | KnownModelName = self.model_id or conn.extra_dejson.get("model", "")
+        extra: dict[str, Any] = conn.extra_dejson
+        model_name: str | KnownModelName = self.model_id or extra.get("model", "")
         if not model_name:
             raise ValueError(
                 "No model specified. Set model_id on the hook or 'model' in the connection's extra JSON."
             )
-        api_key = conn.password
-        base_url = conn.host or None
 
-        if not api_key and not base_url:
-            # No credentials to inject — use default provider resolution
-            # (picks up env vars like OPENAI_API_KEY, AWS_PROFILE, etc.)
-            self._model = infer_model(model_name)
-            return self._model
+        api_key: str | None = conn.password or None
+        base_url: str | None = conn.host or None
+
+        builders: list[ProviderBuilder] = [
+            AzureOpenAIBuilder(),
+            CustomEndpointBuilder(),
+            DefaultBuilder(),
+        ]
+
+        for builder in builders:
+            if builder.supports(extra, api_key, base_url):
+                self._model = builder.build(model_name, extra, api_key, base_url)
+                return self._model
 
-        def _provider_factory(provider_name: str) -> Provider[Any]:
-            """
-            Create a provider with credentials from the Airflow connection.
-
-            Falls back to default provider resolution if the provider's constructor
-            doesn't accept api_key/base_url (e.g. Google Vertex, Bedrock).
-            """
-            provider_cls = infer_provider_class(provider_name)
-            kwargs: dict[str, Any] = {}
-            if api_key:
-                kwargs["api_key"] = api_key
-            if base_url:
-                kwargs["base_url"] = base_url
-            try:
-                return provider_cls(**kwargs)
-            except TypeError:
-                # Provider doesn't accept these kwargs (e.g. Google Vertex/GLA
-                # use ADC, Bedrock uses boto session). Fall back to default
-                # provider resolution which reads credentials from the environment.
-                return infer_provider(provider_name)
-
-        self._model = infer_model(model_name, provider_factory=_provider_factory)
-        return self._model
+        raise RuntimeError("No suitable ProviderBuilder found to construct the model.")

Review Comment:
   This line is unreachable — `DefaultBuilder.supports()` always returns 
`True`, so the loop will always match on the third iteration. Dead code that 
implies the loop might not match, which could mislead future readers.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to