cetingokhan commented on code in PR #62816:
URL: https://github.com/apache/airflow/pull/62816#discussion_r2901762480


##########
providers/common/ai/src/airflow/providers/common/ai/hooks/pydantic_ai.py:
##########
@@ -75,62 +78,89 @@ def get_ui_field_behaviour() -> dict[str, Any]:
             "hidden_fields": ["schema", "port", "login"],
             "relabeling": {"password": "API Key"},
             "placeholders": {
-                "host": "https://api.openai.com/v1 (optional, for custom 
endpoints)",
+                "host": "https://api.openai.com/v1  (optional, for custom 
endpoints / Ollama)",
+                "extra": '{"model": "openai:gpt-5.3"}',
             },
         }
 
+    # ------------------------------------------------------------------
+    # Core connection / agent API
+    # ------------------------------------------------------------------
+
+    def _get_provider_kwargs(
+        self,
+        api_key: str | None,
+        base_url: str | None,
+        extra: dict[str, Any],
+    ) -> dict[str, Any]:
+        """
+        Return the kwargs to pass to the provider constructor.
+
+        Subclasses override this method to map their connection fields to the
+        parameters expected by their specific provider class.  The base
+        implementation handles the common ``api_key`` / ``base_url`` pattern
+        used by OpenAI, Anthropic, Groq, Mistral, Ollama, and most other
+        providers.
+
+        :param api_key: Value of ``conn.password``.
+        :param base_url: Value of ``conn.host``.
+        :param extra: Deserialized ``conn.extra`` JSON.
+        :return: Kwargs forwarded to ``provider_cls(**kwargs)``.  Empty dict
+            signals that no explicit credentials are available and the hook
+            should fall back to environment-variable–based auth.
+        """
+        kwargs: dict[str, Any] = {}
+        if api_key:
+            kwargs["api_key"] = api_key
+        if base_url:
+            kwargs["base_url"] = base_url
+        return kwargs
+
     def get_conn(self) -> Model:
         """
-        Return a configured pydantic-ai Model.
+        Return a configured pydantic-ai ``Model``.
 
-        Reads API key from connection password, base_url from connection host,
-        and model from (in priority order):
+        Resolution order:
 
-        1. ``model_id`` parameter on the hook
-        2. ``extra["model"]`` on the connection (set by the "Model" conn-field 
in the UI)
+        1. **Explicit credentials** — when :meth:`_get_provider_kwargs` returns
+           a non-empty dict the provider class is instantiated with those 
kwargs
+           and wrapped in a ``provider_factory``.
+        2. **Default resolution** — delegates to pydantic-ai ``infer_model``
+           which reads standard env vars (``OPENAI_API_KEY``, ``AWS_PROFILE``, 
…).
 
-        The result is cached for the lifetime of this hook instance.
+        The resolved model is cached for the lifetime of this hook instance.
         """
         if self._model is not None:
             return self._model
 
         conn = self.get_connection(self.llm_conn_id)
-        model_name: str | KnownModelName = self.model_id or 
conn.extra_dejson.get("model", "")
+
+        extra: dict[str, Any] = conn.extra_dejson
+        model_name: str | KnownModelName = self.model_id or extra.get("model", 
"")
         if not model_name:
             raise ValueError(
                 "No model specified. Set model_id on the hook or the Model 
field on the connection."
             )
-        api_key = conn.password
-        base_url = conn.host or None
 
-        if not api_key and not base_url:
-            # No credentials to inject — use default provider resolution
-            # (picks up env vars like OPENAI_API_KEY, AWS_PROFILE, etc.)
-            self._model = infer_model(model_name)
+        api_key: str | None = conn.password or None
+        base_url: str | None = conn.host or None
+
+        provider_kwargs = self._get_provider_kwargs(api_key, base_url, extra)
+        if provider_kwargs:
+            _kwargs = provider_kwargs  # capture for closure
+            self.log.info(
+                "Using explicit credentials for provider with model '%s': %s",
+                model_name,
+                list(provider_kwargs),
+            )
+
+            def _provider_factory(pname: str) -> Any:
+                return infer_provider_class(pname)(**_kwargs)
+
+            self._model = infer_model(model_name, 
provider_factory=_provider_factory)
             return self._model

Review Comment:
   
   ```
   
              def _provider_factory(pname: str) -> Any:
                   try:
                       return infer_provider_class(pname)(**_kwargs)
                   except TypeError:
                       return infer_provider(pname)
   ```
   I've added the try/except TypeError fallback and the infer_provider import 
back as suggested.
   
   Regarding the warning log: I haven't added it yet, because we previously 
discussed keeping the logs clean and removed it in an earlier round. However, 
I'm a bit hesitant about a completely silent fallback.
   
   Don't you think it would be more helpful for the user to know that their 
connection credentials were rejected and that the hook is falling back to 
environment variables? If the fallback is silent, they might struggle to 
understand why their Airflow connection settings aren't being applied.
   
   What do you think—should we keep it silent or add a simple warning log?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to