cetingokhan commented on code in PR #62816:
URL: https://github.com/apache/airflow/pull/62816#discussion_r2900085520
##########
providers/common/ai/src/airflow/providers/common/ai/hooks/pydantic_ai.py:
##########
@@ -157,13 +194,266 @@ def test_connection(self) -> tuple[bool, str]:
"""
Test connection by resolving the model.
- Validates that the model string is valid, the provider package is
- installed, and the provider class can be instantiated. Does NOT make an
- LLM API call — that would be expensive, flaky, and fail for reasons
- unrelated to connectivity (quotas, billing, rate limits).
+ Validates that the model string is valid and the provider class can be
+ instantiated with the supplied credentials. Does NOT make an LLM API
+ call — that would be expensive and fail for reasons unrelated to
+ connectivity (quotas, billing, rate limits).
"""
try:
self.get_conn()
return True, "Model resolved successfully."
except Exception as e:
return False, str(e)
+
+ @classmethod
+ def for_connection(cls, conn_id: str, model_id: str | None = None) ->
PydanticAIHook:
+ """
+ Return the correct :class:`PydanticAIHook` subclass for *conn_id*.
+
+ Looks up the connection's ``conn_type`` in the registered hook map and
+ instantiates the matching subclass. Falls back to
+ :class:`PydanticAIHook` for unknown types.
+
+ :param conn_id: Airflow connection ID.
+ :param model_id: Optional model override forwarded to the hook.
+ """
+ conn = cls.get_connection(conn_id)
+ hook_cls = _CONN_TYPE_TO_HOOK.get(conn.conn_type or "", cls)
+ return hook_cls(llm_conn_id=conn_id, model_id=model_id)
+
+
+class PydanticAIAzureHook(PydanticAIHook):
+ """
+ Hook for Azure OpenAI via pydantic-ai.
+
+ Connection fields:
+ - **password**: Azure API key
+ - **host**: Azure endpoint (e.g. ``https://<resource>.openai.azure.com``)
+ - **extra** JSON::
+
+ {"model": "azure:gpt-4o", "api_version": "2024-07-01-preview"}
+
+ :param llm_conn_id: Airflow connection ID.
+ :param model_id: Model identifier, e.g. ``"azure:gpt-4o"``.
+ """
+
+ conn_type = "pydanticai_azure"
+ hook_name = "Pydantic AI (Azure OpenAI)"
+
+ @staticmethod
+ def get_ui_field_behaviour() -> dict[str, Any]:
+ """Return custom field behaviour for the Airflow connection form."""
+ return {
+ "hidden_fields": ["schema", "port", "login"],
+ "relabeling": {"password": "API Key", "host": "Azure Endpoint"},
+ "placeholders": {
+ "host": "https://<resource>.openai.azure.com",
+ "extra": '{"model": "azure:gpt-4o", "api_version":
"2024-07-01-preview"}',
+ },
+ }
+
+ def _get_provider_kwargs(
+ self,
+ api_key: str | None,
+ base_url: str | None,
+ extra: dict[str, Any],
+ ) -> dict[str, Any]:
+ kwargs: dict[str, Any] = {}
+ if api_key:
+ kwargs["api_key"] = api_key
+ if base_url:
+ kwargs["azure_endpoint"] = base_url
+ if extra.get("api_version"):
+ kwargs["api_version"] = extra["api_version"]
+ return kwargs
+
+
+class PydanticAIBedrockHook(PydanticAIHook):
+ """
+ Hook for AWS Bedrock via pydantic-ai.
+
+ Credentials are resolved in order:
+
+ 1. IAM keys from ``extra`` (``aws_access_key_id`` + ``aws_secret_access_key``,
+ optionally ``aws_session_token``).
+ 2. Bearer token in ``extra`` (``api_key``, maps to env ``AWS_BEARER_TOKEN_BEDROCK``).
+ 3. Environment-variable / instance-role chain (``AWS_PROFILE``, IAM role, …)
+ when no explicit keys are provided.
+
+ Connection fields:
+ - **extra** JSON::
+
+ {
+ "model": "bedrock:us.anthropic.claude-opus-4-5",
+ "region_name": "us-east-1",
+ "aws_access_key_id": "AKIA...",
+ "aws_secret_access_key": "...",
+ "aws_session_token": "...",
+ "profile_name": "my-aws-profile",
+ "api_key": "bearer-token",
+ "base_url": "https://custom-bedrock-endpoint",
+ "aws_read_timeout": 60.0,
+ "aws_connect_timeout": 10.0
+ }
+
+ Leave ``aws_access_key_id`` / ``aws_secret_access_key`` and ``api_key``
+ empty to use the default AWS credential chain.
+
+ :param llm_conn_id: Airflow connection ID.
+ :param model_id: Model identifier, e.g. ``"bedrock:us.anthropic.claude-opus-4-5"``.
+ """
+
+ conn_type = "pydanticai_bedrock"
+ hook_name = "Pydantic AI (AWS Bedrock)"
+
+ @staticmethod
+ def get_ui_field_behaviour() -> dict[str, Any]:
+ """Return custom field behaviour for the Airflow connection form."""
+ return {
+ "hidden_fields": ["schema", "port", "login", "host", "password"],
+ "relabeling": {},
+ "placeholders": {
+ "extra": (
+ '{"model": "bedrock:us.anthropic.claude-opus-4-5", '
+ '"region_name": "us-east-1"}'
+ " — leave aws_access_key_id empty for IAM role / env-var
auth"
+ ),
+ },
+ }
+
+ def _get_provider_kwargs(
Review Comment:
Added the comment you pointed out in option 1.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]