This is an automated email from the ASF dual-hosted git repository.

amoghdesai pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 1283cc347f2 AIP-72: Handling task retries in task SDK + execution API 
(#45106)
1283cc347f2 is described below

commit 1283cc347f26678e7504bbf39c36fd8ee89af6d5
Author: Amogh Desai <[email protected]>
AuthorDate: Mon Dec 30 14:03:55 2024 +0530

    AIP-72: Handling task retries in task SDK + execution API (#45106)
---
 .../execution_api/routes/task_instances.py         |  36 +++++--
 airflow/utils/state.py                             |   1 +
 task_sdk/src/airflow/sdk/api/client.py             |   1 -
 .../src/airflow/sdk/api/datamodels/_generated.py   |   1 +
 .../src/airflow/sdk/execution_time/task_runner.py  |   9 +-
 task_sdk/tests/execution_time/test_supervisor.py   |   2 +-
 task_sdk/tests/execution_time/test_task_runner.py  |  80 +++++++++++++++-
 .../execution_api/routes/test_task_instances.py    | 106 ++++++++++++++++++++-
 8 files changed, 217 insertions(+), 19 deletions(-)

diff --git a/airflow/api_fastapi/execution_api/routes/task_instances.py 
b/airflow/api_fastapi/execution_api/routes/task_instances.py
index 016f5222c79..4956466ca70 100644
--- a/airflow/api_fastapi/execution_api/routes/task_instances.py
+++ b/airflow/api_fastapi/execution_api/routes/task_instances.py
@@ -44,7 +44,7 @@ from airflow.models.taskinstance import TaskInstance as TI, 
_update_rtif
 from airflow.models.taskreschedule import TaskReschedule
 from airflow.models.trigger import Trigger
 from airflow.utils import timezone
-from airflow.utils.state import State
+from airflow.utils.state import State, TerminalTIState
 
 # TODO: Add dependency on JWT token
 router = AirflowRouter()
@@ -185,9 +185,13 @@ def ti_update_state(
     # We only use UUID above for validation purposes
     ti_id_str = str(task_instance_id)
 
-    old = select(TI.state).where(TI.id == ti_id_str).with_for_update()
+    old = select(TI.state, TI.try_number, TI.max_tries).where(TI.id == 
ti_id_str).with_for_update()
     try:
-        (previous_state,) = session.execute(old).one()
+        (
+            previous_state,
+            try_number,
+            max_tries,
+        ) = session.execute(old).one()
     except NoResultFound:
         log.error("Task Instance %s not found", ti_id_str)
         raise HTTPException(
@@ -205,11 +209,17 @@ def ti_update_state(
 
     if isinstance(ti_patch_payload, TITerminalStatePayload):
         query = TI.duration_expression_update(ti_patch_payload.end_date, 
query, session.bind)
-        query = query.values(state=ti_patch_payload.state)
-        if ti_patch_payload.state == State.FAILED:
-            # clear the next_method and next_kwargs
-            query = query.values(next_method=None, next_kwargs=None)
+        updated_state = ti_patch_payload.state
+        # If we receive FAILED, first check whether a retry should happen instead,
+        # since tasks configured with retries are more common than tasks without 
them.
+        if ti_patch_payload.state == TerminalTIState.FAIL_WITHOUT_RETRY:
             updated_state = State.FAILED
+        elif ti_patch_payload.state == State.FAILED:
+            if _is_eligible_to_retry(previous_state, try_number, max_tries):
+                updated_state = State.UP_FOR_RETRY
+            else:
+                updated_state = State.FAILED
+        query = query.values(state=updated_state)
     elif isinstance(ti_patch_payload, TIDeferredStatePayload):
         # Calculate timeout if it was passed
         timeout = None
@@ -359,3 +369,15 @@ def ti_put_rtif(
     _update_rtif(task_instance, put_rtif_payload, session)
 
     return {"message": "Rendered task instance fields successfully set"}
+
+
+def _is_eligible_to_retry(state: str, try_number: int, max_tries: int) -> bool:
+    """Whether the task instance is eligible for retry."""
+    if state == State.RESTARTING:
+        # If a task is cleared when running, it goes into RESTARTING state and 
is always
+        # eligible for retry
+        return True
+
+    # max_tries is initialised with the retries defined at task level, we do 
not need to explicitly ask for
+    # retries from the task SDK now, we can handle using max_tries
+    return max_tries != 0 and try_number <= max_tries
diff --git a/airflow/utils/state.py b/airflow/utils/state.py
index e4e2e9db8a5..dca2c8fc93f 100644
--- a/airflow/utils/state.py
+++ b/airflow/utils/state.py
@@ -39,6 +39,7 @@ class TerminalTIState(str, Enum):
     FAILED = "failed"
     SKIPPED = "skipped"  # A user can raise a AirflowSkipException from a task 
& it will be marked as skipped
     REMOVED = "removed"
+    FAIL_WITHOUT_RETRY = "fail_without_retry"
 
     def __str__(self) -> str:
         return self.value
diff --git a/task_sdk/src/airflow/sdk/api/client.py 
b/task_sdk/src/airflow/sdk/api/client.py
index ee4144c7f54..fd4dd6c7e6c 100644
--- a/task_sdk/src/airflow/sdk/api/client.py
+++ b/task_sdk/src/airflow/sdk/api/client.py
@@ -132,7 +132,6 @@ class TaskInstanceOperations:
         """Tell the API server that this TI has reached a terminal state."""
         # TODO: handle the naming better. finish sounds wrong as "even" 
deferred is essentially finishing.
         body = TITerminalStatePayload(end_date=when, 
state=TerminalTIState(state))
-
         self.client.patch(f"task-instances/{id}/state", 
content=body.model_dump_json())
 
     def heartbeat(self, id: uuid.UUID, pid: int):
diff --git a/task_sdk/src/airflow/sdk/api/datamodels/_generated.py 
b/task_sdk/src/airflow/sdk/api/datamodels/_generated.py
index 00187364c86..ff4cc588ff5 100644
--- a/task_sdk/src/airflow/sdk/api/datamodels/_generated.py
+++ b/task_sdk/src/airflow/sdk/api/datamodels/_generated.py
@@ -129,6 +129,7 @@ class TerminalTIState(str, Enum):
     FAILED = "failed"
     SKIPPED = "skipped"
     REMOVED = "removed"
+    FAIL_WITHOUT_RETRY = "fail_without_retry"
 
 
 class ValidationError(BaseModel):
diff --git a/task_sdk/src/airflow/sdk/execution_time/task_runner.py 
b/task_sdk/src/airflow/sdk/execution_time/task_runner.py
index e48ebc389e1..5d788cf3d73 100644
--- a/task_sdk/src/airflow/sdk/execution_time/task_runner.py
+++ b/task_sdk/src/airflow/sdk/execution_time/task_runner.py
@@ -420,7 +420,7 @@ def run(ti: RuntimeTaskInstance, log: Logger):
         # TODO: Handle fail_stop here: 
https://github.com/apache/airflow/issues/44951
         # TODO: Handle addition to Log table: 
https://github.com/apache/airflow/issues/44952
         msg = TaskState(
-            state=TerminalTIState.FAILED,
+            state=TerminalTIState.FAIL_WITHOUT_RETRY,
             end_date=datetime.now(tz=timezone.utc),
         )
 
@@ -433,16 +433,15 @@ def run(ti: RuntimeTaskInstance, log: Logger):
         # updated already be another UI API. So, these exceptions should 
ideally never be thrown.
         # If these are thrown, we should mark the TI state as failed.
         msg = TaskState(
-            state=TerminalTIState.FAILED,
+            state=TerminalTIState.FAIL_WITHOUT_RETRY,
             end_date=datetime.now(tz=timezone.utc),
         )
         # TODO: Run task failure callbacks here
     except SystemExit:
         ...
     except BaseException:
-        # TODO: Handle TI handle failure
-        raise
-
+        # TODO: Run task failure callbacks here
+        msg = TaskState(state=TerminalTIState.FAILED, 
end_date=datetime.now(tz=timezone.utc))
     if msg:
         SUPERVISOR_COMMS.send_request(msg=msg, log=log)
 
diff --git a/task_sdk/tests/execution_time/test_supervisor.py 
b/task_sdk/tests/execution_time/test_supervisor.py
index 53da57cf178..9cfe456962b 100644
--- a/task_sdk/tests/execution_time/test_supervisor.py
+++ b/task_sdk/tests/execution_time/test_supervisor.py
@@ -903,6 +903,7 @@ class TestHandleRequest:
         client_attr_path,
         method_arg,
         mock_response,
+        time_machine,
     ):
         """
         Test handling of different messages to the subprocess. For any new 
message type, add a
@@ -915,7 +916,6 @@ class TestHandleRequest:
             3. Checks that the buffer is updated with the expected response.
             4. Verifies that the response is correctly decoded.
         """
-
         # Mock the client method. E.g. `client.variables.get` or 
`client.connections.get`
         mock_client_method = 
attrgetter(client_attr_path)(watched_subprocess.client)
         mock_client_method.return_value = mock_response
diff --git a/task_sdk/tests/execution_time/test_task_runner.py 
b/task_sdk/tests/execution_time/test_task_runner.py
index 48ab35709bb..7cbe6e649b6 100644
--- a/task_sdk/tests/execution_time/test_task_runner.py
+++ b/task_sdk/tests/execution_time/test_task_runner.py
@@ -254,6 +254,84 @@ def test_run_basic_skipped(time_machine, mocked_parse, 
make_ti_context, mock_sup
     )
 
 
+def test_run_raises_base_exception(time_machine, mocked_parse, 
make_ti_context):
+    """Test running a basic task that raises a base exception, which should 
send the FAILED terminal state."""
+    from airflow.providers.standard.operators.python import PythonOperator
+
+    task = PythonOperator(
+        task_id="zero_division_error",
+        python_callable=lambda: 1 / 0,
+    )
+
+    what = StartupDetails(
+        ti=TaskInstance(
+            id=uuid7(),
+            task_id="zero_division_error",
+            dag_id="basic_dag_base_exception",
+            run_id="c",
+            try_number=1,
+        ),
+        file="",
+        requests_fd=0,
+        ti_context=make_ti_context(),
+    )
+
+    ti = mocked_parse(what, "basic_dag_base_exception", task)
+
+    instant = timezone.datetime(2024, 12, 3, 10, 0)
+    time_machine.move_to(instant, tick=False)
+
+    with mock.patch(
+        "airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", create=True
+    ) as mock_supervisor_comms:
+        run(ti, log=mock.MagicMock())
+
+        mock_supervisor_comms.send_request.assert_called_once_with(
+            msg=TaskState(
+                state=TerminalTIState.FAILED,
+                end_date=instant,
+            ),
+            log=mock.ANY,
+        )
+
+
+def test_startup_basic_templated_dag(mocked_parse, make_ti_context):
+    """Test running a DAG with templated task."""
+    from airflow.providers.standard.operators.bash import BashOperator
+
+    task = BashOperator(
+        task_id="templated_task",
+        bash_command="echo 'Logical date is {{ logical_date }}'",
+    )
+
+    what = StartupDetails(
+        ti=TaskInstance(
+            id=uuid7(), task_id="templated_task", 
dag_id="basic_templated_dag", run_id="c", try_number=1
+        ),
+        file="",
+        requests_fd=0,
+        ti_context=make_ti_context(),
+    )
+    mocked_parse(what, "basic_templated_dag", task)
+
+    with mock.patch(
+        "airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", create=True
+    ) as mock_supervisor_comms:
+        mock_supervisor_comms.get_message.return_value = what
+        startup()
+
+        mock_supervisor_comms.send_request.assert_called_once_with(
+            msg=SetRenderedFields(
+                rendered_fields={
+                    "bash_command": "echo 'Logical date is {{ logical_date 
}}'",
+                    "cwd": None,
+                    "env": None,
+                }
+            ),
+            log=mock.ANY,
+        )
+
+
 @pytest.mark.parametrize(
     ["task_params", "expected_rendered_fields"],
     [
@@ -376,7 +454,7 @@ def test_run_basic_failed(
     run(ti, log=mock.MagicMock())
 
     mock_supervisor_comms.send_request.assert_called_once_with(
-        msg=TaskState(state=TerminalTIState.FAILED, end_date=instant), 
log=mock.ANY
+        msg=TaskState(state=TerminalTIState.FAIL_WITHOUT_RETRY, 
end_date=instant), log=mock.ANY
     )
 
 
diff --git a/tests/api_fastapi/execution_api/routes/test_task_instances.py 
b/tests/api_fastapi/execution_api/routes/test_task_instances.py
index b997a2e0ac8..497c5fbaf3f 100644
--- a/tests/api_fastapi/execution_api/routes/test_task_instances.py
+++ b/tests/api_fastapi/execution_api/routes/test_task_instances.py
@@ -28,7 +28,7 @@ from sqlalchemy.exc import SQLAlchemyError
 from airflow.models import RenderedTaskInstanceFields, TaskReschedule, Trigger
 from airflow.models.taskinstance import TaskInstance
 from airflow.utils import timezone
-from airflow.utils.state import State, TaskInstanceState
+from airflow.utils.state import State, TaskInstanceState, TerminalTIState
 
 from tests_common.test_utils.db import clear_db_runs, clear_rendered_ti_fields
 
@@ -234,7 +234,7 @@ class TestTIUpdateState:
         with mock.patch(
             "airflow.api_fastapi.common.db.common.Session.execute",
             side_effect=[
-                mock.Mock(one=lambda: ("running",)),  # First call returns 
"queued"
+                mock.Mock(one=lambda: ("running", 1, 0)),  # First call 
returns "running"
                 SQLAlchemyError("Database error"),  # Second call raises an 
error
             ],
         ):
@@ -340,7 +340,105 @@ class TestTIUpdateState:
         assert trs[0].map_index == -1
         assert trs[0].duration == 129600
 
-    def test_ti_update_state_to_failed_table_check(self, client, session, 
create_task_instance):
+    @pytest.mark.parametrize(
+        ("retries", "expected_state"),
+        [
+            (0, State.FAILED),
+            (None, State.FAILED),
+            (3, State.UP_FOR_RETRY),
+        ],
+    )
+    def test_ti_update_state_to_failed_with_retries(
+        self, client, session, create_task_instance, retries, expected_state
+    ):
+        ti = create_task_instance(
+            task_id="test_ti_update_state_to_retry",
+            state=State.RUNNING,
+        )
+
+        if retries is not None:
+            ti.max_tries = retries
+        session.commit()
+
+        response = client.patch(
+            f"/execution/task-instances/{ti.id}/state",
+            json={
+                "state": TerminalTIState.FAILED,
+                "end_date": DEFAULT_END_DATE.isoformat(),
+            },
+        )
+
+        assert response.status_code == 204
+        assert response.text == ""
+
+        session.expire_all()
+
+        ti = session.get(TaskInstance, ti.id)
+        assert ti.state == expected_state
+        assert ti.next_method is None
+        assert ti.next_kwargs is None
+
+    def test_ti_update_state_when_ti_is_restarting(self, client, session, 
create_task_instance):
+        ti = create_task_instance(
+            task_id="test_ti_update_state_when_ti_is_restarting",
+            state=State.RUNNING,
+        )
+        # update state to restarting
+        ti.state = State.RESTARTING
+        session.commit()
+
+        response = client.patch(
+            f"/execution/task-instances/{ti.id}/state",
+            json={
+                "state": TerminalTIState.FAILED,
+                "end_date": DEFAULT_END_DATE.isoformat(),
+            },
+        )
+
+        assert response.status_code == 204
+        assert response.text == ""
+
+        session.expire_all()
+
+        ti = session.get(TaskInstance, ti.id)
+        # restarting is always retried
+        assert ti.state == State.UP_FOR_RETRY
+        assert ti.next_method is None
+        assert ti.next_kwargs is None
+
+    def test_ti_update_state_when_ti_has_higher_tries_than_retries(
+        self, client, session, create_task_instance
+    ):
+        ti = create_task_instance(
+            
task_id="test_ti_update_state_when_ti_has_higher_tries_than_retries",
+            state=State.RUNNING,
+        )
+        # two maximum tries defined, but third try going on
+        ti.max_tries = 2
+        ti.try_number = 3
+        session.commit()
+
+        response = client.patch(
+            f"/execution/task-instances/{ti.id}/state",
+            json={
+                "state": TerminalTIState.FAILED,
+                "end_date": DEFAULT_END_DATE.isoformat(),
+            },
+        )
+
+        assert response.status_code == 204
+        assert response.text == ""
+
+        session.expire_all()
+
+        ti = session.get(TaskInstance, ti.id)
+        # all retries exhausted, marking as failed
+        assert ti.state == State.FAILED
+        assert ti.next_method is None
+        assert ti.next_kwargs is None
+
+    def test_ti_update_state_to_failed_without_retry_table_check(self, client, 
session, create_task_instance):
+        # we just want to fail in this test, no need to retry
         ti = create_task_instance(
             task_id="test_ti_update_state_to_failed_table_check",
             state=State.RUNNING,
@@ -351,7 +449,7 @@ class TestTIUpdateState:
         response = client.patch(
             f"/execution/task-instances/{ti.id}/state",
             json={
-                "state": State.FAILED,
+                "state": TerminalTIState.FAIL_WITHOUT_RETRY,
                 "end_date": DEFAULT_END_DATE.isoformat(),
             },
         )

Reply via email to