This is an automated email from the ASF dual-hosted git repository.
kaxil pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new c0911ae7571 Add `LLMRetryPolicy` to common-ai provider (#65451)
c0911ae7571 is described below
commit c0911ae7571297e45ea7f36f827a4ad5fbccb2d8
Author: Kaxil Naik <[email protected]>
AuthorDate: Mon May 18 18:27:10 2026 +0100
Add `LLMRetryPolicy` to common-ai provider (#65451)

Uses PydanticAIHook to call any LLM for error classification with
structured output. Timeout via pydantic-ai ModelSettings (default 30s).
Falls back to declarative fallback_rules when LLM call fails.
Gated for Airflow 3.3+. RST docs, example DAG, 12 tests.

Add a Custom Instructions section showing how to override the default
classifier prompt with domain-specific guidance. Uses Snowflake as the
example since its transient errors (Warehouse suspended, JWT expired,
Statement queued) need backend-specific knowledge to classify correctly.

* Fix LLMRetryPolicy CI: mypy, ruff, spellcheck, compat tests

  - Type-annotate llm_policy as Optional in the example DAG so the
    fallback branch satisfies mypy on Airflow versions without RetryPolicy.
  - Wrap pydanticai* and claude-haiku* connection examples in
    double-backticks so RST autoapi treats them as code (skips spellcheck).
  - Gate test_retry.py with pytest.importorskip so compat-3.2.1 CI on
    older Airflow versions skips the module instead of erroring on import.
  - Apply ruff and ruff-format auto-fixes.

* Fix LLMRetryPolicy CI: license header, sphinx duplicate llm_policy

  - Add Apache license header to the empty policies test __init__.py
    (insert-license + end-of-file-fixer were failing on it).
  - Move the @dag definition inside the try block in the example DAG so
    llm_policy is only assigned in one branch. Sphinx autoapi was
    treating the upfront None declaration plus the in-try LLMRetryPolicy
    assignment as two separate object descriptions, failing the docs
    build with a duplicate-object warning (treated as error). Without
    the upfront declaration, mypy is happy because the variable only
    exists when the import succeeds.
---
providers/common/ai/docs/index.rst | 1 +
providers/common/ai/docs/retry_policies.rst | 170 ++++++++++++++++++
.../ai/example_dags/example_llm_retry_policy.py | 72 ++++++++
.../providers/common/ai/policies/__init__.py | 16 ++
.../airflow/providers/common/ai/policies/retry.py | 183 +++++++++++++++++++
.../ai/tests/unit/common/ai/policies/__init__.py | 16 ++
.../ai/tests/unit/common/ai/policies/test_retry.py | 197 +++++++++++++++++++++
7 files changed, 655 insertions(+)
diff --git a/providers/common/ai/docs/index.rst b/providers/common/ai/docs/index.rst
index e96ba4cfd27..a5cd4196f7a 100644
--- a/providers/common/ai/docs/index.rst
+++ b/providers/common/ai/docs/index.rst
@@ -39,6 +39,7 @@
     Hooks <hooks/pydantic_ai>
     Toolsets <toolsets>
     Operators <operators/index>
+    Retry Policies <retry_policies>
     HITL Review <hitl_review>
 
 .. toctree::
diff --git a/providers/common/ai/docs/retry_policies.rst b/providers/common/ai/docs/retry_policies.rst
new file mode 100644
index 00000000000..036bc8e5908
--- /dev/null
+++ b/providers/common/ai/docs/retry_policies.rst
@@ -0,0 +1,170 @@
+ .. Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ .. http://www.apache.org/licenses/LICENSE-2.0
+
+ .. Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+
+LLM Retry Policies
+===================
+
+.. versionadded:: 3.3.0
+
+The ``LLMRetryPolicy`` uses an LLM to classify task errors and make intelligent
+retry decisions. It works with any LLM provider supported by pydantic-ai
+(OpenAI, Anthropic, Bedrock, Vertex, Ollama, etc.).
+
+For the core retry policy concepts, see :doc:`apache-airflow:core-concepts/tasks`.
+
+Setup
+-----
+
+1. Install the provider with the LLM backend you need:
+
+   .. code-block:: bash
+
+       pip install 'apache-airflow-providers-common-ai[anthropic]'
+
+2. Create a connection (``Admin > Connections``):
+
+   - **Connection Id**: ``pydanticai_default``
+   - **Connection Type**: ``Pydantic AI``
+   - **Password**: Your API key
+   - **Extra**: ``{"model": "anthropic:claude-haiku-4-5-20251001"}``
+
+Usage
+-----
+
+.. code-block:: python
+
+    from airflow.providers.common.ai.policies.retry import LLMRetryPolicy
+    from airflow.sdk.definitions.retry_policy import RetryAction, RetryRule
+    from datetime import timedelta
+
+    llm_policy = LLMRetryPolicy(
+        llm_conn_id="pydanticai_default",
+        timeout=30.0,  # max seconds to wait for LLM response
+        fallback_rules=[  # used when LLM call fails
+            RetryRule(exception=ConnectionError, action=RetryAction.RETRY, retry_delay=timedelta(seconds=10)),
+            RetryRule(exception=PermissionError, action=RetryAction.FAIL),
+        ],
+    )
+
+
+    @task(retries=5, retry_policy=llm_policy)
+    def call_external_api(): ...
+
+How it works
+------------
+
+When a task fails, ``LLMRetryPolicy``:
+
+1. Sends the exception type and message to the configured LLM
+2. Receives a classification into a category (``rate_limit``, ``auth``, ``network``,
+   ``data``, ``resource``, ``transient``, ``permanent``)
+3. Returns RETRY (with a suggested delay) or FAIL, based on the classification
+4. Logs the classification reasoning in the task logs
+
+If the LLM call fails (provider down, timeout, bad credentials), the policy
+falls back to ``fallback_rules`` if configured, or to the task's standard
+retry behaviour.
+
+Custom instructions
+-------------------
+
+The default classifier handles generic categories. For domain-specific
+behaviour, override ``instructions`` to inject your own taxonomy. The LLM still
+returns an :class:`~airflow.providers.common.ai.policies.retry.ErrorClassification`
+(``category``, ``should_retry``, ``suggested_delay_seconds``, ``reasoning``)
+-- only the prompt changes.
+
+.. code-block:: python
+
+    SNOWFLAKE_INSTRUCTIONS = (
+        "You are an error classifier for Snowflake-backed data pipelines. "
+        "Classify the error into one of: rate_limit, auth, network, data, "
+        "transient, permanent.\n\n"
+        "Snowflake-specific guidance:\n"
+        "- 'Statement queued' or 'concurrency limit' -> rate_limit, retry after 120s\n"
+        "- 'JWT token expired' -> transient (token rotates), retry after 30s\n"
+        "- 'Authentication token has expired' AFTER multiple retries -> auth, do NOT retry\n"
+        "- 'Column does not exist' -> data, do NOT retry (schema drift needs human fix)\n"
+        "- 'Warehouse suspended' -> transient, retry after 30s (auto-resume)\n\n"
+        "Set suggested_delay_seconds based on the error type. "
+        "Set 0 for errors that should not retry."
+    )
+
+    snowflake_policy = LLMRetryPolicy(
+        llm_conn_id="pydanticai_default",
+        instructions=SNOWFLAKE_INSTRUCTIONS,
+        fallback_rules=[
+            RetryRule(
+                exception=ConnectionError,
+                action=RetryAction.RETRY,
+                retry_delay=timedelta(seconds=30),
+            ),
+        ],
+    )
+
+
+    @task(retries=5, retry_policy=snowflake_policy)
+    def query_snowflake(): ...
+
+When writing custom instructions:
+
+- The LLM must return the same ``ErrorClassification`` schema (``category``,
+  ``should_retry``, ``suggested_delay_seconds``, ``reasoning``). Mention the
+  fields explicitly so the model fills them.
+- Be concrete with examples (``"'Warehouse suspended' -> transient"``) rather
+  than vague rules ("treat warehouse issues as recoverable").
+- ``retry_reason`` is truncated to 500 chars in the audit log -- keep
+  ``reasoning`` outputs concise.
+
+Parameters
+----------
+
+.. list-table::
+   :header-rows: 1
+   :widths: 20 15 65
+
+   * - Parameter
+     - Default
+     - Description
+   * - ``llm_conn_id``
+     - (required)
+     - Airflow connection ID for the LLM provider.
+   * - ``model_id``
+     - None
+     - Override the model from the connection (e.g., ``"openai:gpt-4o-mini"``).
+   * - ``instructions``
+     - (built-in)
+     - Custom system prompt for error classification.
+   * - ``fallback_rules``
+     - None
+     - List of ``RetryRule`` objects used when the LLM call fails.
+   * - ``timeout``
+     - 30.0
+     - Max seconds to wait for the LLM response before falling back.
+
+Local LLM support
+-----------------
+
+For environments where exception data must not leave the infrastructure, point
+to a local model via Ollama or vLLM:
+
+.. code-block:: python
+
+    LLMRetryPolicy(
+        llm_conn_id="ollama_local",  # host=http://localhost:11434
+        model_id="ollama:llama3.2",
+    )
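The steps listed under "How it works" in the new docs page amount to a small mapping from the structured LLM output to a retry decision. Below is a minimal, stdlib-only sketch of that mapping. `Classification`, `Decision` and `to_decision` are hypothetical stand-ins introduced for illustration; they are not the `ErrorClassification` or `RetryDecision` classes from the commit.

```python
from dataclasses import dataclass
from datetime import timedelta
from typing import Optional


@dataclass
class Classification:
    """Hypothetical stand-in mirroring the four ErrorClassification fields."""

    category: str
    should_retry: bool
    reasoning: str
    suggested_delay_seconds: int = 0


@dataclass
class Decision:
    """Hypothetical stand-in for a retry decision; not the Airflow class."""

    action: str  # "retry" or "fail"
    retry_delay: Optional[timedelta] = None
    reason: str = ""


def to_decision(c: Classification) -> Decision:
    """Map a classification to a decision."""
    reason = f"{c.category}: {c.reasoning}"
    if not c.should_retry:
        return Decision(action="fail", reason=reason)
    # A zero or negative suggestion means "no override": leave the delay unset
    # so the task's own retry_delay applies.
    delay = (
        timedelta(seconds=c.suggested_delay_seconds)
        if c.suggested_delay_seconds > 0
        else None
    )
    return Decision(action="retry", retry_delay=delay, reason=reason)
```

Guarding on `> 0` rather than `!= 0` keeps a malformed negative suggestion from turning into a negative `timedelta`.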
diff --git a/providers/common/ai/src/airflow/providers/common/ai/example_dags/example_llm_retry_policy.py b/providers/common/ai/src/airflow/providers/common/ai/example_dags/example_llm_retry_policy.py
new file mode 100644
index 00000000000..bdd2528dd28
--- /dev/null
+++ b/providers/common/ai/src/airflow/providers/common/ai/example_dags/example_llm_retry_policy.py
@@ -0,0 +1,72 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Example DAG demonstrating LLM-powered retry policies.
+
+Uses an LLM (via PydanticAIHook) to classify errors and decide whether
+to retry, fail immediately, or retry with a custom delay.
+
+Prerequisites:
+    - Connection ``pydanticai_default`` with ``conn_type='pydanticai'``,
+      ``password=<API key>``, ``extra='{"model": "anthropic:claude-haiku-4-5-20251001"}'``
+    - ``pip install apache-airflow-providers-common-ai[anthropic]``
+"""
+
+from __future__ import annotations
+
+from datetime import timedelta
+
+from airflow.providers.common.compat.sdk import dag, task
+
+try:
+    from airflow.providers.common.ai.policies.retry import LLMRetryPolicy
+    from airflow.sdk.definitions.retry_policy import RetryAction, RetryRule
+
+    llm_policy = LLMRetryPolicy(
+        llm_conn_id="pydanticai_default",
+        timeout=30.0,
+        fallback_rules=[
+            RetryRule(exception=ConnectionError, action=RetryAction.RETRY, retry_delay=timedelta(seconds=10)),
+            RetryRule(exception=PermissionError, action=RetryAction.FAIL),
+        ],
+    )
+
+    @dag(catchup=False, tags=["example", "retry_policy", "llm"])
+    def example_llm_retry_policy():
+        @task(retries=3, retry_delay=timedelta(minutes=1), retry_policy=llm_policy)
+        def task_auth_error():
+            """LLM should classify as auth -> FAIL immediately."""
+            raise PermissionError("403 Forbidden: API key expired for service account [email protected]")
+
+        @task(retries=3, retry_delay=timedelta(minutes=1), retry_policy=llm_policy)
+        def task_rate_limit():
+            """LLM should classify as rate_limit -> RETRY with ~60s delay."""
+            raise RuntimeError("429 Too Many Requests: Rate limit exceeded. Retry after 60 seconds.")
+
+        @task(retries=3, retry_delay=timedelta(minutes=1), retry_policy=llm_policy)
+        def task_data_error():
+            """LLM should classify as data -> FAIL immediately."""
+            raise ValueError("Column 'user_id' expected type INT but got STRING in row 42.")
+
+        task_auth_error()
+        task_rate_limit()
+        task_data_error()
+
+    example_llm_retry_policy()
+except ImportError:
+    # RetryPolicy requires Airflow 3.3+; example DAG is skipped on older versions.
+    pass
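The example DAG above gates on an optional module by wrapping everything in `try`/`except ImportError`. For comparison, here is a hedged sketch of probing for a module before importing it, using only the standard library. `has_module` is a hypothetical helper, not part of Airflow, and whether this shape would satisfy the mypy and Sphinx autoapi constraints described in the commit message has not been checked here.

```python
import importlib.util


def has_module(dotted_name: str) -> bool:
    """Return True if ``dotted_name`` can be located without importing it fully."""
    try:
        return importlib.util.find_spec(dotted_name) is not None
    except (ImportError, ValueError):
        # find_spec imports parent packages first; a missing parent raises
        # ModuleNotFoundError (an ImportError subclass). ValueError covers
        # already-imported modules whose __spec__ is None.
        return False
```

A caller could then branch on `has_module("airflow.sdk.definitions.retry_policy")` instead of catching the import failure after the fact.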
diff --git a/providers/common/ai/src/airflow/providers/common/ai/policies/__init__.py b/providers/common/ai/src/airflow/providers/common/ai/policies/__init__.py
new file mode 100644
index 00000000000..13a83393a91
--- /dev/null
+++ b/providers/common/ai/src/airflow/providers/common/ai/policies/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/providers/common/ai/src/airflow/providers/common/ai/policies/retry.py b/providers/common/ai/src/airflow/providers/common/ai/policies/retry.py
new file mode 100644
index 00000000000..f92e4e0d64f
--- /dev/null
+++ b/providers/common/ai/src/airflow/providers/common/ai/policies/retry.py
@@ -0,0 +1,183 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+LLM-powered retry policy using pydantic-ai for error classification.
+
+Requires Airflow 3.3+ (RetryPolicy was added in AIP-105).
+"""
+
+from __future__ import annotations
+
+import logging
+from datetime import timedelta
+from typing import TYPE_CHECKING
+
+from pydantic import BaseModel
+
+try:
+    from airflow.sdk.definitions.retry_policy import (
+        ExceptionRetryPolicy,
+        RetryDecision,
+        RetryPolicy,
+    )
+except ImportError:
+    raise ImportError(
+        "LLMRetryPolicy requires Airflow 3.3+ which includes RetryPolicy support. "
+        "Please upgrade apache-airflow-core."
+    ) from None
+
+if TYPE_CHECKING:
+    from airflow.sdk.definitions.context import Context
+    from airflow.sdk.definitions.retry_policy import RetryRule
+
+log = logging.getLogger(__name__)
+
+__all__ = ["ErrorClassification", "LLMRetryPolicy"]
+
+DEFAULT_INSTRUCTIONS = (
+    "You are an error classifier for a data pipeline system. "
+    "Given an error message from a failed task, classify it into one of these categories:\n\n"
+    "- rate_limit: API throttling or quota exceeded. Should retry after a delay.\n"
+    "- auth: Credentials invalid, expired, or missing permissions. Should NOT retry.\n"
+    "- network: Transient connectivity issue. Should retry quickly.\n"
+    "- data: Schema validation, type mismatch, or bad input data. Should NOT retry.\n"
+    "- resource: Resource not found or unavailable (e.g., missing table, bucket). Should NOT retry.\n"
+    "- transient: Temporary issue likely to resolve on its own. Should retry.\n"
+    "- permanent: Problem that won't resolve without code or config changes. Should NOT retry.\n\n"
+    "Set suggested_delay_seconds based on the error type: "
+    "60 for rate limits, 10 for network, 30 for transient. "
+    "Set 0 for errors that should not retry."
+)
+
+
+class ErrorClassification(BaseModel):
+    """Structured LLM output for error classification."""
+
+    category: str
+    """One of: rate_limit, auth, network, data, resource, transient, permanent."""
+    should_retry: bool
+    """Whether the operation should be retried."""
+    suggested_delay_seconds: int = 0
+    """How long to wait before retrying (0 if should_retry is False)."""
+    reasoning: str
+    """Brief explanation of the classification decision."""
+
+
+class LLMRetryPolicy(RetryPolicy):
+    """
+    Retry policy that uses an LLM to classify errors and decide retry behaviour.
+
+    Uses :class:`~airflow.providers.common.ai.hooks.pydantic_ai.PydanticAIHook`
+    to call any configured LLM provider (OpenAI, Anthropic, Bedrock, Vertex,
+    Ollama, etc.) for error classification with structured output.
+
+    When the LLM call itself fails, the policy falls back to ``fallback_rules``
+    (if provided) or returns DEFAULT to use the task's standard retry logic.
+
+    :param llm_conn_id: Airflow connection ID for the LLM provider.
+    :param model_id: Model identifier override (e.g. ``"openai:gpt-4o-mini"``
+        for cost efficiency). If not set, uses the model from the connection.
+    :param instructions: Custom system prompt for classification.
+        Defaults to a general-purpose error classifier.
+    :param fallback_rules: Optional list of
+        :class:`~airflow.sdk.definitions.retry_policy.RetryRule` applied when the
+        LLM call fails. Provides a deterministic safety net.
+    :param timeout: Maximum seconds to wait for the LLM response before
+        falling back. Defaults to 30s. The LLM provider's own timeout
+        (e.g. 600s for Anthropic) is much longer; this keeps the retry
+        decision path fast even when the provider is degraded.
+    """
+
+    def __init__(
+        self,
+        llm_conn_id: str,
+        model_id: str | None = None,
+        instructions: str | None = None,
+        fallback_rules: list[RetryRule] | None = None,
+        timeout: float = 30.0,
+    ) -> None:
+        self.llm_conn_id = llm_conn_id
+        self.model_id = model_id
+        self.instructions = instructions or DEFAULT_INSTRUCTIONS
+        self.fallback_rules = fallback_rules
+        self.timeout = timeout
+
+    def evaluate(
+        self,
+        exception: BaseException,
+        try_number: int,
+        max_tries: int,
+        context: Context | None = None,
+    ) -> RetryDecision:
+        try:
+            return self._classify(exception, try_number, max_tries)
+        except Exception:
+            log.exception("LLM retry classification failed, using fallback")
+            if self.fallback_rules:
+                return ExceptionRetryPolicy(rules=self.fallback_rules).evaluate(
+                    exception, try_number, max_tries, context
+                )
+            return RetryDecision.default()
+
+    def _classify(
+        self,
+        exception: BaseException,
+        try_number: int,
+        max_tries: int,
+    ) -> RetryDecision:
+        from airflow.providers.common.ai.hooks.pydantic_ai import PydanticAIHook
+
+        hook = PydanticAIHook(llm_conn_id=self.llm_conn_id, model_id=self.model_id)
+        agent = hook.create_agent(
+            output_type=ErrorClassification,
+            instructions=self.instructions,
+        )
+
+        prompt = (
+            f"Classify this error from a data pipeline task "
+            f"(attempt {try_number} of {max_tries}):\n\n"
+            f"{type(exception).__name__}: {exception}"
+        )
+
+        from pydantic_ai.settings import ModelSettings
+
+        result = agent.run_sync(
+            prompt,
+            model_settings=ModelSettings(timeout=self.timeout),
+        )
+        classification = result.output
+
+        log.info(
+            "LLM error classification: category=%s, should_retry=%s, delay=%ds, reasoning=%s",
+            classification.category,
+            classification.should_retry,
+            classification.suggested_delay_seconds,
+            classification.reasoning,
+        )
+
+        if not classification.should_retry:
+            return RetryDecision.fail(reason=f"{classification.category}: {classification.reasoning}")
+
+        delay = (
+            timedelta(seconds=classification.suggested_delay_seconds)
+            if classification.suggested_delay_seconds > 0
+            else None
+        )
+        return RetryDecision.retry(
+            delay=delay,
+            reason=f"{classification.category}: {classification.reasoning}",
+        )
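The control flow in `evaluate` above (try the classifier, then the fallback rules, then the default) can be illustrated without Airflow installed. The sketch below is a simplified, stdlib-only model: the `Rule` tuple and the string actions are hypothetical, and the first-match `isinstance` lookup is an assumption about how `ExceptionRetryPolicy` resolves rules. The tests in this commit are consistent with that assumption but do not pin it down.

```python
from typing import Callable, List, Optional, Tuple, Type

# (exception type, action) pairs: a hypothetical simplification of RetryRule.
Rule = Tuple[Type[BaseException], str]


def evaluate(
    exc: BaseException,
    classify: Callable[[BaseException], str],
    fallback_rules: Optional[List[Rule]] = None,
) -> str:
    """Return the classifier's action, or fall back if the classifier fails."""
    try:
        return classify(exc)
    except Exception:
        # The classifier itself failed (timeout, bad credentials, provider down).
        for exc_type, action in fallback_rules or []:
            if isinstance(exc, exc_type):  # assumed first-match semantics
                return action
        # No rules, or none matched: defer to the task's standard retry logic.
        return "default"
```

Catching a broad `Exception` around the classifier is deliberate in this sketch: any failure on the LLM path should degrade to the deterministic rules rather than fail the retry decision itself.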
diff --git a/providers/common/ai/tests/unit/common/ai/policies/__init__.py b/providers/common/ai/tests/unit/common/ai/policies/__init__.py
new file mode 100644
index 00000000000..13a83393a91
--- /dev/null
+++ b/providers/common/ai/tests/unit/common/ai/policies/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/providers/common/ai/tests/unit/common/ai/policies/test_retry.py b/providers/common/ai/tests/unit/common/ai/policies/test_retry.py
new file mode 100644
index 00000000000..6f9d976d6f1
--- /dev/null
+++ b/providers/common/ai/tests/unit/common/ai/policies/test_retry.py
@@ -0,0 +1,197 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from datetime import timedelta
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+# LLMRetryPolicy depends on the RetryPolicy ABC introduced in Airflow 3.3 (AIP-105).
+# Skip the entire test module on older Airflow versions tested in compat CI.
+pytest.importorskip("airflow.sdk.definitions.retry_policy", reason="RetryPolicy requires Airflow 3.3+")
+
+from airflow.providers.common.ai.policies.retry import (
+    ErrorClassification,
+    LLMRetryPolicy,
+)
+from airflow.sdk.definitions.retry_policy import RetryAction, RetryRule
+
+
+def _make_mock_agent(category, should_retry, delay=0, reasoning="test"):
+    """Create a mock agent that returns a canned ErrorClassification."""
+    mock_result = MagicMock()
+    mock_result.output = ErrorClassification(
+        category=category,
+        should_retry=should_retry,
+        suggested_delay_seconds=delay,
+        reasoning=reasoning,
+    )
+    mock_agent = MagicMock()
+    mock_agent.run_sync.return_value = mock_result
+    return mock_agent
+
+
+class TestLLMClassifyDecisions:
+    """Test that _classify maps LLM classification to correct RetryDecisions."""
+
+    @patch("airflow.providers.common.ai.hooks.pydantic_ai.PydanticAIHook", autospec=True)
+    def test_auth_error_returns_fail(self, mock_hook_cls):
+        mock_hook_cls.return_value.create_agent.return_value = _make_mock_agent(
+            "auth", should_retry=False, reasoning="API key expired"
+        )
+        policy = LLMRetryPolicy(llm_conn_id="test")
+        decision = policy.evaluate(PermissionError("403"), try_number=1, max_tries=3)
+
+        assert decision.action == RetryAction.FAIL
+        assert "auth" in decision.reason
+        assert "API key expired" in decision.reason
+
+    @patch("airflow.providers.common.ai.hooks.pydantic_ai.PydanticAIHook", autospec=True)
+    def test_rate_limit_returns_retry_with_delay(self, mock_hook_cls):
+        mock_hook_cls.return_value.create_agent.return_value = _make_mock_agent(
+            "rate_limit", should_retry=True, delay=60, reasoning="429"
+        )
+        policy = LLMRetryPolicy(llm_conn_id="test")
+        decision = policy.evaluate(RuntimeError("429"), try_number=1, max_tries=3)
+
+        assert decision.action == RetryAction.RETRY
+        assert decision.retry_delay == timedelta(seconds=60)
+
+    @patch("airflow.providers.common.ai.hooks.pydantic_ai.PydanticAIHook", autospec=True)
+    def test_transient_retry_with_zero_delay_uses_default(self, mock_hook_cls):
+        """suggested_delay_seconds=0 means use the task's default delay, not override."""
+        mock_hook_cls.return_value.create_agent.return_value = _make_mock_agent(
+            "transient", should_retry=True, delay=0
+        )
+        policy = LLMRetryPolicy(llm_conn_id="test")
+        decision = policy.evaluate(RuntimeError("glitch"), try_number=1, max_tries=3)
+
+        assert decision.action == RetryAction.RETRY
+        assert decision.retry_delay is None  # None = use task's default
+
+    @patch("airflow.providers.common.ai.hooks.pydantic_ai.PydanticAIHook", autospec=True)
+    def test_negative_delay_treated_as_no_override(self, mock_hook_cls):
+        """Negative delay from LLM should not produce a negative timedelta."""
+        mock_hook_cls.return_value.create_agent.return_value = _make_mock_agent(
+            "transient", should_retry=True, delay=-5
+        )
+        policy = LLMRetryPolicy(llm_conn_id="test")
+        decision = policy.evaluate(RuntimeError("x"), try_number=1, max_tries=3)
+
+        assert decision.action == RetryAction.RETRY
+        assert decision.retry_delay is None
+
+    @patch("airflow.providers.common.ai.hooks.pydantic_ai.PydanticAIHook", autospec=True)
+    def test_prompt_includes_exception_type_and_message(self, mock_hook_cls):
+        mock_agent = _make_mock_agent("data", should_retry=False)
+        mock_hook_cls.return_value.create_agent.return_value = mock_agent
+
+        policy = LLMRetryPolicy(llm_conn_id="test")
+        policy.evaluate(ValueError("bad column type"), try_number=2, max_tries=5)
+
+        prompt = mock_agent.run_sync.call_args[0][0]
+        assert "ValueError: bad column type" in prompt
+        assert "attempt 2 of 5" in prompt
+
+    @patch("airflow.providers.common.ai.hooks.pydantic_ai.PydanticAIHook", autospec=True)
+    def test_custom_instructions_forwarded_to_agent(self, mock_hook_cls):
+        mock_hook_cls.return_value.create_agent.return_value = _make_mock_agent("x", False)
+
+        policy = LLMRetryPolicy(llm_conn_id="test", instructions="My custom prompt")
+        policy.evaluate(ValueError("x"), try_number=1, max_tries=3)
+
+        mock_hook_cls.return_value.create_agent.assert_called_once_with(
+            output_type=ErrorClassification,
+            instructions="My custom prompt",
+        )
+
+    @patch("airflow.providers.common.ai.hooks.pydantic_ai.PydanticAIHook", autospec=True)
+    def test_timeout_passed_via_model_settings(self, mock_hook_cls):
+        mock_agent = _make_mock_agent("auth", False)
+        mock_hook_cls.return_value.create_agent.return_value = mock_agent
+
+        policy = LLMRetryPolicy(llm_conn_id="test", timeout=15.0)
+        policy.evaluate(ValueError("x"), try_number=1, max_tries=3)
+
+        model_settings = mock_agent.run_sync.call_args.kwargs["model_settings"]
+        assert model_settings["timeout"] == 15.0
+
+
+class TestLLMFallbackBehaviour:
+    """Test fallback when the LLM call itself fails."""
+
+    def test_falls_back_to_rules_when_connection_missing(self):
+        policy = LLMRetryPolicy(
+            llm_conn_id="nonexistent",
+            fallback_rules=[
+                RetryRule(
+                    exception=ConnectionError, action=RetryAction.RETRY, retry_delay=timedelta(seconds=10)
+                ),
+                RetryRule(exception=PermissionError, action=RetryAction.FAIL, reason="auth fallback"),
+            ],
+        )
+        d = policy.evaluate(ConnectionError("refused"), try_number=1, max_tries=3)
+        assert d.action == RetryAction.RETRY
+        assert d.retry_delay == timedelta(seconds=10)
+
+        d = policy.evaluate(PermissionError("denied"), try_number=1, max_tries=3)
+        assert d.action == RetryAction.FAIL
+
+    def test_falls_back_to_default_when_no_rules(self):
+        policy = LLMRetryPolicy(llm_conn_id="nonexistent")
+        d = policy.evaluate(ValueError("bad"), try_number=1, max_tries=3)
+        assert d.action == RetryAction.DEFAULT
+
+    def test_fallback_rules_no_match_returns_default(self):
+        """When fallback rules exist but none match, DEFAULT is returned."""
+        policy = LLMRetryPolicy(
+            llm_conn_id="nonexistent",
+            fallback_rules=[
+                RetryRule(exception=PermissionError, action=RetryAction.FAIL),
+            ],
+        )
+        # ValueError doesn't match the PermissionError rule
+        d = policy.evaluate(ValueError("bad"), try_number=1, max_tries=3)
+        assert d.action == RetryAction.DEFAULT
+
+    @patch("airflow.providers.common.ai.hooks.pydantic_ai.PydanticAIHook", autospec=True)
+    def test_agent_run_sync_failure_triggers_fallback(self, mock_hook_cls):
+        """Failure during run_sync (not hook creation) still triggers fallback."""
+        mock_agent = MagicMock()
+        mock_agent.run_sync.side_effect = RuntimeError("network error mid-call")
+        mock_hook_cls.return_value.create_agent.return_value = mock_agent
+
+        policy = LLMRetryPolicy(
+            llm_conn_id="test",
+            fallback_rules=[RetryRule(exception=ValueError, action=RetryAction.FAIL, reason="fallback")],
+        )
+        d = policy.evaluate(ValueError("x"), try_number=1, max_tries=3)
+        assert d.action == RetryAction.FAIL
+        assert d.reason == "fallback"
+
+    @patch("airflow.providers.common.ai.hooks.pydantic_ai.PydanticAIHook", autospec=True)
+    def test_hook_creation_failure_triggers_fallback(self, mock_hook_cls):
+        """Failure during hook.create_agent still triggers fallback."""
+        mock_hook_cls.return_value.create_agent.side_effect = RuntimeError("unexpected")
+
+        policy = LLMRetryPolicy(
+            llm_conn_id="test",
+            fallback_rules=[RetryRule(exception=ValueError, action=RetryAction.FAIL, reason="caught")],
+        )
+        d = policy.evaluate(ValueError("x"), try_number=1, max_tries=3)
+        assert d.action == RetryAction.FAIL
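Several tests above inspect `run_sync.call_args` to check what was sent to the mocked agent. The standalone sketch below shows that inspection pattern with `unittest.mock` alone. `classify_with` is a hypothetical caller written for this illustration, and a plain dict stands in for `ModelSettings`.

```python
from unittest.mock import MagicMock


def classify_with(agent, exception, timeout):
    """Hypothetical caller: build a prompt and pass the timeout as a keyword."""
    prompt = f"{type(exception).__name__}: {exception}"
    return agent.run_sync(prompt, model_settings={"timeout": timeout}).output


# A mock agent whose run_sync() result carries a canned ``output`` value.
agent = MagicMock()
agent.run_sync.return_value.output = "auth"

result = classify_with(agent, ValueError("bad column type"), timeout=15.0)

# call_args[0] holds the positional arguments of the most recent call;
# call_args.kwargs holds its keyword arguments (Python 3.8+).
prompt = agent.run_sync.call_args[0][0]
settings = agent.run_sync.call_args.kwargs["model_settings"]
```

Asserting on the recorded arguments rather than on the mock's return value verifies what the code under test sent, which is the part a canned response cannot cover.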