This is an automated email from the ASF dual-hosted git repository.
amoghdesai pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 1283cc347f2 AIP-72: Handling task retries in task SDK + execution API (#45106)
1283cc347f2 is described below
commit 1283cc347f26678e7504bbf39c36fd8ee89af6d5
Author: Amogh Desai <[email protected]>
AuthorDate: Mon Dec 30 14:03:55 2024 +0530
AIP-72: Handling task retries in task SDK + execution API (#45106)
---
.../execution_api/routes/task_instances.py | 36 +++++--
airflow/utils/state.py | 1 +
task_sdk/src/airflow/sdk/api/client.py | 1 -
.../src/airflow/sdk/api/datamodels/_generated.py | 1 +
.../src/airflow/sdk/execution_time/task_runner.py | 9 +-
task_sdk/tests/execution_time/test_supervisor.py | 2 +-
task_sdk/tests/execution_time/test_task_runner.py | 80 +++++++++++++++-
.../execution_api/routes/test_task_instances.py | 106 ++++++++++++++++++++-
8 files changed, 217 insertions(+), 19 deletions(-)
diff --git a/airflow/api_fastapi/execution_api/routes/task_instances.py b/airflow/api_fastapi/execution_api/routes/task_instances.py
index 016f5222c79..4956466ca70 100644
--- a/airflow/api_fastapi/execution_api/routes/task_instances.py
+++ b/airflow/api_fastapi/execution_api/routes/task_instances.py
@@ -44,7 +44,7 @@ from airflow.models.taskinstance import TaskInstance as TI, _update_rtif
from airflow.models.taskreschedule import TaskReschedule
from airflow.models.trigger import Trigger
from airflow.utils import timezone
-from airflow.utils.state import State
+from airflow.utils.state import State, TerminalTIState
# TODO: Add dependency on JWT token
router = AirflowRouter()
@@ -185,9 +185,13 @@ def ti_update_state(
# We only use UUID above for validation purposes
ti_id_str = str(task_instance_id)
- old = select(TI.state).where(TI.id == ti_id_str).with_for_update()
+ old = select(TI.state, TI.try_number, TI.max_tries).where(TI.id == ti_id_str).with_for_update()
try:
- (previous_state,) = session.execute(old).one()
+ (
+ previous_state,
+ try_number,
+ max_tries,
+ ) = session.execute(old).one()
except NoResultFound:
log.error("Task Instance %s not found", ti_id_str)
raise HTTPException(
@@ -205,11 +209,20 @@ def ti_update_state(
if isinstance(ti_patch_payload, TITerminalStatePayload):
query = TI.duration_expression_update(ti_patch_payload.end_date, query, session.bind)
- query = query.values(state=ti_patch_payload.state)
- if ti_patch_payload.state == State.FAILED:
- # clear the next_method and next_kwargs
- query = query.values(next_method=None, next_kwargs=None)
+ updated_state = ti_patch_payload.state
+ # if we get failed, we should attempt to retry, as it is a more
+ # normal state. Tasks with retries are more frequent than without retries.
+ if ti_patch_payload.state == TerminalTIState.FAIL_WITHOUT_RETRY:
updated_state = State.FAILED
+ elif ti_patch_payload.state == State.FAILED:
+ if _is_eligible_to_retry(previous_state, try_number, max_tries):
+ updated_state = State.UP_FOR_RETRY
+ else:
+ updated_state = State.FAILED
+ query = query.values(state=updated_state)
+ if updated_state in (State.FAILED, State.UP_FOR_RETRY):
+ # clear the next_method and next_kwargs for every failure outcome (including a retry),
+ # so a later attempt does not pick up a stale resume point left by a deferral
+ query = query.values(next_method=None, next_kwargs=None)
elif isinstance(ti_patch_payload, TIDeferredStatePayload):
# Calculate timeout if it was passed
timeout = None
@@ -359,3 +372,15 @@ def ti_put_rtif(
_update_rtif(task_instance, put_rtif_payload, session)
return {"message": "Rendered task instance fields successfully set"}
+
+
+def _is_eligible_to_retry(state: str, try_number: int, max_tries: int) -> bool:
+ """Return whether a task instance in ``state`` on ``try_number`` may be retried."""
+ if state == State.RESTARTING:
+ # If a task is cleared when running, it goes into RESTARTING state and is always
+ # eligible for retry
+ return True
+
+ # max_tries is initialised with the retries defined at task level, we do not need to explicitly ask for
+ # retries from the task SDK now, we can handle using max_tries
+ return max_tries != 0 and try_number <= max_tries
diff --git a/airflow/utils/state.py b/airflow/utils/state.py
index e4e2e9db8a5..dca2c8fc93f 100644
--- a/airflow/utils/state.py
+++ b/airflow/utils/state.py
@@ -39,6 +39,7 @@ class TerminalTIState(str, Enum):
FAILED = "failed"
SKIPPED = "skipped" # A user can raise a AirflowSkipException from a task
& it will be marked as skipped
REMOVED = "removed"
+ FAIL_WITHOUT_RETRY = "fail_without_retry"  # request: mark the TI failed outright, skipping the retry-eligibility check
def __str__(self) -> str:
return self.value
diff --git a/task_sdk/src/airflow/sdk/api/client.py
b/task_sdk/src/airflow/sdk/api/client.py
index ee4144c7f54..fd4dd6c7e6c 100644
--- a/task_sdk/src/airflow/sdk/api/client.py
+++ b/task_sdk/src/airflow/sdk/api/client.py
@@ -132,7 +132,6 @@ class TaskInstanceOperations:
"""Tell the API server that this TI has reached a terminal state."""
# TODO: handle the naming better. finish sounds wrong as "even"
deferred is essentially finishing.
body = TITerminalStatePayload(end_date=when,
state=TerminalTIState(state))
-
self.client.patch(f"task-instances/{id}/state",
content=body.model_dump_json())
def heartbeat(self, id: uuid.UUID, pid: int):
diff --git a/task_sdk/src/airflow/sdk/api/datamodels/_generated.py
b/task_sdk/src/airflow/sdk/api/datamodels/_generated.py
index 00187364c86..ff4cc588ff5 100644
--- a/task_sdk/src/airflow/sdk/api/datamodels/_generated.py
+++ b/task_sdk/src/airflow/sdk/api/datamodels/_generated.py
@@ -129,6 +129,7 @@ class TerminalTIState(str, Enum):
FAILED = "failed"
SKIPPED = "skipped"
REMOVED = "removed"
+ FAIL_WITHOUT_RETRY = "fail_without_retry"
class ValidationError(BaseModel):
diff --git a/task_sdk/src/airflow/sdk/execution_time/task_runner.py b/task_sdk/src/airflow/sdk/execution_time/task_runner.py
index e48ebc389e1..5d788cf3d73 100644
--- a/task_sdk/src/airflow/sdk/execution_time/task_runner.py
+++ b/task_sdk/src/airflow/sdk/execution_time/task_runner.py
@@ -420,7 +420,7 @@ def run(ti: RuntimeTaskInstance, log: Logger):
# TODO: Handle fail_stop here: https://github.com/apache/airflow/issues/44951
# TODO: Handle addition to Log table: https://github.com/apache/airflow/issues/44952
msg = TaskState(
- state=TerminalTIState.FAILED,
+ state=TerminalTIState.FAIL_WITHOUT_RETRY,
end_date=datetime.now(tz=timezone.utc),
)
@@ -433,16 +433,17 @@ def run(ti: RuntimeTaskInstance, log: Logger):
# updated already be another UI API. So, these exceptions should ideally never be thrown.
# If these are thrown, we should mark the TI state as failed.
msg = TaskState(
- state=TerminalTIState.FAILED,
+ state=TerminalTIState.FAIL_WITHOUT_RETRY,
end_date=datetime.now(tz=timezone.utc),
)
# TODO: Run task failure callbacks here
except SystemExit:
...
except BaseException:
- # TODO: Handle TI handle failure
- raise
-
+ # The exception is no longer re-raised, so log the traceback here or it is lost.
+ log.exception("Task failed with exception")
+ # TODO: Run task failure callbacks here
+ msg = TaskState(state=TerminalTIState.FAILED, end_date=datetime.now(tz=timezone.utc))
if msg:
SUPERVISOR_COMMS.send_request(msg=msg, log=log)
diff --git a/task_sdk/tests/execution_time/test_supervisor.py
b/task_sdk/tests/execution_time/test_supervisor.py
index 53da57cf178..9cfe456962b 100644
--- a/task_sdk/tests/execution_time/test_supervisor.py
+++ b/task_sdk/tests/execution_time/test_supervisor.py
@@ -903,6 +903,7 @@ class TestHandleRequest:
client_attr_path,
method_arg,
mock_response,
+ time_machine,  # NOTE(review): this patch adds no use of time_machine in the test body; confirm it is needed
):
"""
Test handling of different messages to the subprocess. For any new
message type, add a
@@ -915,7 +916,6 @@ class TestHandleRequest:
3. Checks that the buffer is updated with the expected response.
4. Verifies that the response is correctly decoded.
"""
-
# Mock the client method. E.g. `client.variables.get` or
`client.connections.get`
mock_client_method =
attrgetter(client_attr_path)(watched_subprocess.client)
mock_client_method.return_value = mock_response
diff --git a/task_sdk/tests/execution_time/test_task_runner.py
b/task_sdk/tests/execution_time/test_task_runner.py
index 48ab35709bb..7cbe6e649b6 100644
--- a/task_sdk/tests/execution_time/test_task_runner.py
+++ b/task_sdk/tests/execution_time/test_task_runner.py
@@ -254,6 +254,84 @@ def test_run_basic_skipped(time_machine, mocked_parse,
make_ti_context, mock_sup
)
+def test_run_raises_base_exception(time_machine, mocked_parse,
make_ti_context):
+ """Test that an exception raised by the task callable makes run() report the FAILED state."""
+ from airflow.providers.standard.operators.python import PythonOperator
+
+ task = PythonOperator(
+ task_id="zero_division_error",
+ python_callable=lambda: 1 / 0,
+ )
+
+ what = StartupDetails(
+ ti=TaskInstance(
+ id=uuid7(),
+ task_id="zero_division_error",
+ dag_id="basic_dag_base_exception",
+ run_id="c",
+ try_number=1,
+ ),
+ file="",
+ requests_fd=0,
+ ti_context=make_ti_context(),
+ )
+
+ ti = mocked_parse(what, "basic_dag_base_exception", task)
+
+ instant = timezone.datetime(2024, 12, 3, 10, 0)
+ time_machine.move_to(instant, tick=False)
+
+ with mock.patch(
+ "airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", create=True
+ ) as mock_supervisor_comms:
+ run(ti, log=mock.MagicMock())
+
+ mock_supervisor_comms.send_request.assert_called_once_with(
+ msg=TaskState(
+ state=TerminalTIState.FAILED,
+ end_date=instant,
+ ),
+ log=mock.ANY,
+ )
+
+
+def test_startup_basic_templated_dag(mocked_parse, make_ti_context):
+ """Test running a DAG with templated task."""
+ from airflow.providers.standard.operators.bash import BashOperator
+
+ task = BashOperator(
+ task_id="templated_task",
+ bash_command="echo 'Logical date is {{ logical_date }}'",
+ )
+
+ what = StartupDetails(
+ ti=TaskInstance(
+ id=uuid7(), task_id="templated_task",
dag_id="basic_templated_dag", run_id="c", try_number=1
+ ),
+ file="",
+ requests_fd=0,
+ ti_context=make_ti_context(),
+ )
+ mocked_parse(what, "basic_templated_dag", task)
+
+ with mock.patch(
+ "airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", create=True
+ ) as mock_supervisor_comms:
+ mock_supervisor_comms.get_message.return_value = what
+ startup()
+
+ mock_supervisor_comms.send_request.assert_called_once_with(
+ msg=SetRenderedFields(
+ rendered_fields={
+ "bash_command": "echo 'Logical date is {{ logical_date
}}'",
+ "cwd": None,
+ "env": None,
+ }
+ ),
+ log=mock.ANY,
+ )
+
+
@pytest.mark.parametrize(
["task_params", "expected_rendered_fields"],
[
@@ -376,7 +454,7 @@ def test_run_basic_failed(
run(ti, log=mock.MagicMock())
mock_supervisor_comms.send_request.assert_called_once_with(
- msg=TaskState(state=TerminalTIState.FAILED, end_date=instant),
log=mock.ANY
+ msg=TaskState(state=TerminalTIState.FAIL_WITHOUT_RETRY,
end_date=instant), log=mock.ANY
)
diff --git a/tests/api_fastapi/execution_api/routes/test_task_instances.py
b/tests/api_fastapi/execution_api/routes/test_task_instances.py
index b997a2e0ac8..497c5fbaf3f 100644
--- a/tests/api_fastapi/execution_api/routes/test_task_instances.py
+++ b/tests/api_fastapi/execution_api/routes/test_task_instances.py
@@ -28,7 +28,7 @@ from sqlalchemy.exc import SQLAlchemyError
from airflow.models import RenderedTaskInstanceFields, TaskReschedule, Trigger
from airflow.models.taskinstance import TaskInstance
from airflow.utils import timezone
-from airflow.utils.state import State, TaskInstanceState
+from airflow.utils.state import State, TaskInstanceState, TerminalTIState
from tests_common.test_utils.db import clear_db_runs, clear_rendered_ti_fields
@@ -234,7 +234,7 @@ class TestTIUpdateState:
with mock.patch(
"airflow.api_fastapi.common.db.common.Session.execute",
side_effect=[
- mock.Mock(one=lambda: ("running",)), # First call returns
"queued"
+ mock.Mock(one=lambda: ("running", 1, 0)), # First call returns (state, try_number, max_tries)
SQLAlchemyError("Database error"), # Second call raises an
error
],
):
@@ -340,7 +340,105 @@ class TestTIUpdateState:
assert trs[0].map_index == -1
assert trs[0].duration == 129600
- def test_ti_update_state_to_failed_table_check(self, client, session,
create_task_instance):
+ @pytest.mark.parametrize(
+ ("retries", "expected_state"),
+ [
+ (0, State.FAILED),
+ (None, State.FAILED),  # None: leave the fixture's default max_tries untouched
+ (3, State.UP_FOR_RETRY),
+ ],
+ )
+ def test_ti_update_state_to_failed_with_retries(
+ self, client, session, create_task_instance, retries, expected_state
+ ):
+ ti = create_task_instance(
+ task_id="test_ti_update_state_to_retry",
+ state=State.RUNNING,
+ )
+
+ if retries is not None:
+ ti.max_tries = retries
+ session.commit()
+
+ response = client.patch(
+ f"/execution/task-instances/{ti.id}/state",
+ json={
+ "state": TerminalTIState.FAILED,
+ "end_date": DEFAULT_END_DATE.isoformat(),
+ },
+ )
+
+ assert response.status_code == 204
+ assert response.text == ""
+
+ session.expire_all()
+
+ ti = session.get(TaskInstance, ti.id)
+ assert ti.state == expected_state
+ assert ti.next_method is None
+ assert ti.next_kwargs is None
+
+ def test_ti_update_state_when_ti_is_restarting(self, client, session,
create_task_instance):
+ ti = create_task_instance(
+ task_id="test_ti_update_state_when_ti_is_restarting",
+ state=State.RUNNING,
+ )
+ # update state to restarting
+ ti.state = State.RESTARTING
+ session.commit()
+
+ response = client.patch(
+ f"/execution/task-instances/{ti.id}/state",
+ json={
+ "state": TerminalTIState.FAILED,
+ "end_date": DEFAULT_END_DATE.isoformat(),
+ },
+ )
+
+ assert response.status_code == 204
+ assert response.text == ""
+
+ session.expire_all()
+
+ ti = session.get(TaskInstance, ti.id)
+ # restarting is always retried
+ assert ti.state == State.UP_FOR_RETRY
+ assert ti.next_method is None
+ assert ti.next_kwargs is None
+
+ def test_ti_update_state_when_ti_has_higher_tries_than_retries(
+ self, client, session, create_task_instance
+ ):
+ ti = create_task_instance(
+
task_id="test_ti_update_state_when_ti_has_higher_tries_than_retries",
+ state=State.RUNNING,
+ )
+ # two maximum tries defined, but third try going on
+ ti.max_tries = 2
+ ti.try_number = 3
+ session.commit()
+
+ response = client.patch(
+ f"/execution/task-instances/{ti.id}/state",
+ json={
+ "state": TerminalTIState.FAILED,
+ "end_date": DEFAULT_END_DATE.isoformat(),
+ },
+ )
+
+ assert response.status_code == 204
+ assert response.text == ""
+
+ session.expire_all()
+
+ ti = session.get(TaskInstance, ti.id)
+ # all retries exhausted, marking as failed
+ assert ti.state == State.FAILED
+ assert ti.next_method is None
+ assert ti.next_kwargs is None
+
+ def test_ti_update_state_to_failed_without_retry_table_check(self, client,
session, create_task_instance):
+ # we just want to fail in this test, no need to retry
ti = create_task_instance(
task_id="test_ti_update_state_to_failed_table_check",
state=State.RUNNING,
@@ -351,7 +449,7 @@ class TestTIUpdateState:
response = client.patch(
f"/execution/task-instances/{ti.id}/state",
json={
- "state": State.FAILED,
+ "state": TerminalTIState.FAIL_WITHOUT_RETRY,
"end_date": DEFAULT_END_DATE.isoformat(),
},
)