This is an automated email from the ASF dual-hosted git repository.

rahulvats pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new e2289f9f90a Remove state dependency from airflow core in sdk (#55292)
e2289f9f90a is described below

commit e2289f9f90a54fcb926005bab494324f95ae1691
Author: Ankit Chaurasia <[email protected]>
AuthorDate: Wed Oct 22 11:50:10 2025 +0545

    Remove state dependency from airflow core in sdk (#55292)
    
    Remove state dependency from airflow core in sdk
---
 task-sdk/src/airflow/sdk/definitions/dag.py  | 31 +++++++++++--------
 task-sdk/tests/task_sdk/api/test_client.py   |  2 +-
 task-sdk/tests/task_sdk/bases/test_sensor.py | 45 ++++++++++++++--------------
 3 files changed, 42 insertions(+), 36 deletions(-)

diff --git a/task-sdk/src/airflow/sdk/definitions/dag.py b/task-sdk/src/airflow/sdk/definitions/dag.py
index 599f6dd81c0..30e37b896dc 100644
--- a/task-sdk/src/airflow/sdk/definitions/dag.py
+++ b/task-sdk/src/airflow/sdk/definitions/dag.py
@@ -43,7 +43,7 @@ from airflow.exceptions import (
     RemovedInAirflow4Warning,
     TaskNotFound,
 )
-from airflow.sdk import TriggerRule
+from airflow.sdk import TaskInstanceState, TriggerRule
 from airflow.sdk.bases.operator import BaseOperator
 from airflow.sdk.definitions._internal.node import validate_key
 from airflow.sdk.definitions._internal.types import NOTSET, ArgNotSet
@@ -85,6 +85,15 @@ __all__ = [
     "dag",
 ]
 
+FINISHED_STATES = frozenset(
+    [
+        TaskInstanceState.SUCCESS,
+        TaskInstanceState.FAILED,
+        TaskInstanceState.SKIPPED,
+        TaskInstanceState.UPSTREAM_FAILED,
+        TaskInstanceState.REMOVED,
+    ]
+)
 
 DagStateChangeCallback = Callable[[Context], None]
 ScheduleInterval = None | str | timedelta | relativedelta
@@ -1166,10 +1175,9 @@ class DAG:
         from airflow import settings
         from airflow.configuration import secrets_backend_list
         from airflow.models.dagrun import DagRun, get_or_create_dagrun
-        from airflow.sdk import DagRunState, TaskInstanceState, timezone
+        from airflow.sdk import DagRunState, timezone
         from airflow.secrets.local_filesystem import LocalFilesystemBackend
         from airflow.serialization.serialized_objects import SerializedDAG
-        from airflow.utils.state import State
         from airflow.utils.types import DagRunTriggeredByType, DagRunType
 
         exit_stack = ExitStack()
@@ -1291,7 +1299,7 @@ class DAG:
                 # triggerer may mark tasks scheduled so we read from DB
                 all_tis = set(dr.get_task_instances(session=session))
                 scheduled_tis = {x for x in all_tis if x.state == TaskInstanceState.SCHEDULED}
-                ids_unrunnable = {x for x in all_tis if x.state not in State.finished} - scheduled_tis
+                ids_unrunnable = {x for x in all_tis if x.state not in FINISHED_STATES} - scheduled_tis
                 if not scheduled_tis and ids_unrunnable:
                     log.warning("No tasks to run. unrunnable tasks: %s", ids_unrunnable)
                     time.sleep(1)
@@ -1331,7 +1339,7 @@ class DAG:
                         # Run the task locally
                         try:
                             if mark_success:
-                                ti.set_state(State.SUCCESS)
+                                ti.set_state(TaskInstanceState.SUCCESS)
                                 log.info("[DAG TEST] Marking success for %s on %s", task, ti.logical_date)
                             else:
                                 _run_task(ti=ti, task=task, run_triggerer=True)
@@ -1363,7 +1371,6 @@ def _run_task(
     possible.  This function is only meant for the `dag.test` function as a helper function.
     """
     from airflow.sdk.module_loading import import_string
-    from airflow.utils.state import State
 
     taskrun_result: TaskRunResult | None
     log.info("[DAG TEST] starting task_id=%s map_index=%s", ti.task_id, ti.map_index)
@@ -1378,7 +1385,7 @@ def _run_task(
 
             # The API Server expects the task instance to be in QUEUED state before
             # it is run.
-            ti.set_state(State.QUEUED)
+            ti.set_state(TaskInstanceState.QUEUED)
             task_sdk_ti = TaskInstanceSDK(
                 id=ti.id,
                 task_id=ti.task_id,
@@ -1394,12 +1401,12 @@ def _run_task(
             ti.set_state(taskrun_result.ti.state)
             ti.task = create_scheduler_operator(taskrun_result.ti.task)
 
-            if ti.state == State.DEFERRED and isinstance(msg, DeferTask) and run_triggerer:
+            if ti.state == TaskInstanceState.DEFERRED and isinstance(msg, DeferTask) and run_triggerer:
                 from airflow.utils.session import create_session
 
                 # API Server expects the task instance to be in QUEUED state before
                 # resuming from deferral.
-                ti.set_state(State.QUEUED)
+                ti.set_state(TaskInstanceState.QUEUED)
 
                 log.info("[DAG TEST] running trigger in line")
                 trigger = import_string(msg.classpath)(**msg.trigger_kwargs)
@@ -1410,15 +1417,15 @@ def _run_task(
 
                 # Set the state to SCHEDULED so that the task can be resumed.
                 with create_session() as session:
-                    ti.state = State.SCHEDULED
+                    ti.state = TaskInstanceState.SCHEDULED
                     session.add(ti)
                 continue
 
             break
         except Exception:
             log.exception("[DAG TEST] Error running task %s", ti)
-            if ti.state not in State.finished:
-                ti.set_state(State.FAILED)
+            if ti.state not in FINISHED_STATES:
+                ti.set_state(TaskInstanceState.FAILED)
                 taskrun_result = None
                 break
             raise
diff --git a/task-sdk/tests/task_sdk/api/test_client.py b/task-sdk/tests/task_sdk/api/test_client.py
index 32a709663ac..e0f3866cc55 100644
--- a/task-sdk/tests/task_sdk/api/test_client.py
+++ b/task-sdk/tests/task_sdk/api/test_client.py
@@ -39,6 +39,7 @@ from airflow.sdk.api.datamodels._generated import (
     DagRunStateResponse,
     HITLDetailResponse,
     HITLUser,
+    TerminalTIState,
     VariableResponse,
     XComResponse,
 )
@@ -52,7 +53,6 @@ from airflow.sdk.execution_time.comms import (
     RescheduleTask,
     TaskRescheduleStartDate,
 )
-from airflow.utils.state import TerminalTIState
 
 if TYPE_CHECKING:
     from time_machine import TimeMachineFixture
diff --git a/task-sdk/tests/task_sdk/bases/test_sensor.py b/task-sdk/tests/task_sdk/bases/test_sensor.py
index 2a3d61f0b52..214fef860f4 100644
--- a/task-sdk/tests/task_sdk/bases/test_sensor.py
+++ b/task-sdk/tests/task_sdk/bases/test_sensor.py
@@ -34,12 +34,11 @@ from airflow.exceptions import (
 )
 from airflow.models.trigger import TriggerFailureReason
 from airflow.providers.standard.operators.empty import EmptyOperator
-from airflow.sdk import timezone
+from airflow.sdk import TaskInstanceState, timezone
 from airflow.sdk.bases.sensor import BaseSensorOperator, PokeReturnValue, poke_mode_only
 from airflow.sdk.definitions.dag import DAG
 from airflow.sdk.execution_time.comms import RescheduleTask, TaskRescheduleStartDate
 from airflow.sdk.timezone import datetime
-from airflow.utils.state import State
 
 if TYPE_CHECKING:
     from airflow.sdk.definitions.context import Context
@@ -172,7 +171,7 @@ class TestBaseSensor:
 
         state, msg, _ = run_task(task=sensor)
 
-        assert state == State.UP_FOR_RESCHEDULE
+        assert state == TaskInstanceState.UP_FOR_RESCHEDULE
         assert msg.reschedule_date == date1 + timedelta(seconds=sensor.poke_interval)
 
         # second poke returns False and task is re-scheduled
@@ -180,14 +179,14 @@ class TestBaseSensor:
         date2 = date1 + timedelta(seconds=sensor.poke_interval)
         state, msg, _ = run_task(task=sensor)
 
-        assert state == State.UP_FOR_RESCHEDULE
+        assert state == TaskInstanceState.UP_FOR_RESCHEDULE
         assert msg.reschedule_date == date2 + timedelta(seconds=sensor.poke_interval)
 
         # third poke returns True and task succeeds
         time_machine.coordinates.shift(sensor.poke_interval)
         state, _, _ = run_task(task=sensor)
 
-        assert state == State.SUCCESS
+        assert state == TaskInstanceState.SUCCESS
 
     def test_fail_with_reschedule(self, run_task, make_sensor, time_machine, mock_supervisor_comms):
         sensor = make_sensor(return_value=False, poke_interval=10, timeout=5, mode="reschedule")
@@ -198,7 +197,7 @@ class TestBaseSensor:
 
         state, msg, _ = run_task(task=sensor)
 
-        assert state == State.UP_FOR_RESCHEDULE
+        assert state == TaskInstanceState.UP_FOR_RESCHEDULE
         assert msg.reschedule_date == date1 + timedelta(seconds=sensor.poke_interval)
 
         # second poke returns False, timeout occurs
@@ -208,7 +207,7 @@ class TestBaseSensor:
         mock_supervisor_comms.send.return_value = TaskRescheduleStartDate(start_date=date1)
         state, msg, error = run_task(task=sensor, context_update={"task_reschedule_count": 1})
 
-        assert state == State.FAILED
+        assert state == TaskInstanceState.FAILED
         assert isinstance(error, AirflowSensorTimeout)
 
     def test_soft_fail_with_reschedule(self, run_task, make_sensor, time_machine, mock_supervisor_comms):
@@ -221,7 +220,7 @@ class TestBaseSensor:
         time_machine.move_to(date1, tick=False)
 
         state, msg, _ = run_task(task=sensor)
-        assert state == State.UP_FOR_RESCHEDULE
+        assert state == TaskInstanceState.UP_FOR_RESCHEDULE
 
         # second poke returns False, timeout occurs
         time_machine.coordinates.shift(sensor.poke_interval)
@@ -229,7 +228,7 @@ class TestBaseSensor:
         # Mocking values from DB/API-server
         mock_supervisor_comms.send.return_value = TaskRescheduleStartDate(start_date=date1)
         state, msg, _ = run_task(task=sensor, context_update={"task_reschedule_count": 1})
-        assert state == State.SKIPPED
+        assert state == TaskInstanceState.SKIPPED
 
     def test_ok_with_reschedule_and_exponential_backoff(
         self, run_task, make_sensor, time_machine, mock_supervisor_comms
@@ -261,7 +260,7 @@ class TestBaseSensor:
             curr_date = curr_date + timedelta(seconds=new_interval)
             time_machine.coordinates.shift(new_interval)
             state, msg, _ = run_task(sensor, context_update={"task_reschedule_count": _poke_count})
-            assert state == State.UP_FOR_RESCHEDULE
+            assert state == TaskInstanceState.UP_FOR_RESCHEDULE
             old_interval = new_interval
             new_interval = sensor._get_next_poke_interval(task_start_date, run_duration, _poke_count)
             assert old_interval < new_interval  # actual test
@@ -272,7 +271,7 @@ class TestBaseSensor:
         time_machine.coordinates.shift(new_interval)
 
         state, msg, _ = run_task(sensor, context_update={"task_reschedule_count": false_count + 1})
-        assert state == State.SUCCESS
+        assert state == TaskInstanceState.SUCCESS
 
     def test_invalid_mode(self):
         with pytest.raises(AirflowException):
@@ -291,7 +290,7 @@ class TestBaseSensor:
         with time_machine.travel(date1, tick=False):
             state, msg, error = run_task(sensor)
 
-        assert state == State.UP_FOR_RESCHEDULE
+        assert state == TaskInstanceState.UP_FOR_RESCHEDULE
         assert isinstance(msg, RescheduleTask)
         assert msg.reschedule_date == date2
 
@@ -299,14 +298,14 @@ class TestBaseSensor:
         with time_machine.travel(date2, tick=False):
             state, msg, error = run_task(sensor)
 
-        assert state == State.UP_FOR_RESCHEDULE
+        assert state == TaskInstanceState.UP_FOR_RESCHEDULE
         assert isinstance(msg, RescheduleTask)
         assert msg.reschedule_date == date3
 
         # third poke returns True and task succeeds
         with time_machine.travel(date3, tick=False):
             state, _, _ = run_task(sensor)
-        assert state == State.SUCCESS
+        assert state == TaskInstanceState.SUCCESS
 
     def test_sensor_with_invalid_poke_interval(self):
         negative_poke_interval = -10
@@ -523,36 +522,36 @@ class TestBaseSensor:
                 context_update={"task_reschedule_count": test_state["task_reschedule_count"]},
             )
 
-            if state == State.UP_FOR_RESCHEDULE:
+            if state == TaskInstanceState.UP_FOR_RESCHEDULE:
                 test_state["task_reschedule_count"] += 1
                 # Only set first_reschedule_date on the first successful reschedule
                 if test_state["first_reschedule_date"] is None:
                     test_state["first_reschedule_date"] = test_state["current_time"]
-            elif state == State.UP_FOR_RETRY:
+            elif state == TaskInstanceState.UP_FOR_RETRY:
                 test_state["try_number"] += 1
             return state, msg, error
 
         # Phase 1: Initial execution until failure
         # First poke - should reschedule
         state, _, _ = _run_task()
-        assert state == State.UP_FOR_RESCHEDULE
+        assert state == TaskInstanceState.UP_FOR_RESCHEDULE
 
         # Second poke - should raise RuntimeError and retry
         test_state["current_time"] += timedelta(seconds=sensor.poke_interval)
         state, _, error = _run_task()
-        assert state == State.UP_FOR_RETRY
+        assert state == TaskInstanceState.UP_FOR_RETRY
         assert isinstance(error, RuntimeError)
 
         # Third poke - should reschedule again
         test_state["current_time"] += sensor.retry_delay + timedelta(seconds=1)
         state, _, _ = _run_task()
-        assert state == State.UP_FOR_RESCHEDULE
+        assert state == TaskInstanceState.UP_FOR_RESCHEDULE
 
         # Fourth poke - should timeout
         test_state["current_time"] += timedelta(seconds=sensor.poke_interval)
         state, _, error = _run_task()
         assert isinstance(error, AirflowSensorTimeout)
-        assert state == State.FAILED
+        assert state == TaskInstanceState.FAILED
 
         # Phase 2: After clearing the failed sensor
         # Reset supervisor comms to return None, simulating a fresh start after clearing
@@ -564,13 +563,13 @@ class TestBaseSensor:
         for _ in range(3):
             test_state["current_time"] += timedelta(seconds=sensor.poke_interval)
             state, _, _ = _run_task()
-            assert state == State.UP_FOR_RESCHEDULE
+            assert state == TaskInstanceState.UP_FOR_RESCHEDULE
 
         # Final poke - should timeout
         test_state["current_time"] += timedelta(seconds=sensor.poke_interval)
         state, _, error = _run_task()
         assert isinstance(error, AirflowSensorTimeout)
-        assert state == State.FAILED
+        assert state == TaskInstanceState.FAILED
 
     def test_sensor_with_xcom(self, make_sensor):
         xcom_value = "TestValue"
@@ -615,7 +614,7 @@ class TestBaseSensor:
         state, _, error = run_task(task=task, dag_id=f"test_sensor_timeout_{mode}_{retries}")
 
         assert isinstance(error, AirflowSensorTimeout)
-        assert state == State.FAILED
+        assert state == TaskInstanceState.FAILED
 
 
 @poke_mode_only

Reply via email to