This is an automated email from the ASF dual-hosted git repository.
rahulvats pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new e2289f9f90a Remove state dependency from airflow core in sdk (#55292)
e2289f9f90a is described below
commit e2289f9f90a54fcb926005bab494324f95ae1691
Author: Ankit Chaurasia <[email protected]>
AuthorDate: Wed Oct 22 11:50:10 2025 +0545
Remove state dependency from airflow core in sdk (#55292)
Remove state dependency from airflow core in sdk
---
task-sdk/src/airflow/sdk/definitions/dag.py | 31 +++++++++++--------
task-sdk/tests/task_sdk/api/test_client.py | 2 +-
task-sdk/tests/task_sdk/bases/test_sensor.py | 45 ++++++++++++++--------------
3 files changed, 42 insertions(+), 36 deletions(-)
diff --git a/task-sdk/src/airflow/sdk/definitions/dag.py b/task-sdk/src/airflow/sdk/definitions/dag.py
index 599f6dd81c0..30e37b896dc 100644
--- a/task-sdk/src/airflow/sdk/definitions/dag.py
+++ b/task-sdk/src/airflow/sdk/definitions/dag.py
@@ -43,7 +43,7 @@ from airflow.exceptions import (
RemovedInAirflow4Warning,
TaskNotFound,
)
-from airflow.sdk import TriggerRule
+from airflow.sdk import TaskInstanceState, TriggerRule
from airflow.sdk.bases.operator import BaseOperator
from airflow.sdk.definitions._internal.node import validate_key
from airflow.sdk.definitions._internal.types import NOTSET, ArgNotSet
@@ -85,6 +85,15 @@ __all__ = [
"dag",
]
+FINISHED_STATES = frozenset(
+ [
+ TaskInstanceState.SUCCESS,
+ TaskInstanceState.FAILED,
+ TaskInstanceState.SKIPPED,
+ TaskInstanceState.UPSTREAM_FAILED,
+ TaskInstanceState.REMOVED,
+ ]
+)
DagStateChangeCallback = Callable[[Context], None]
ScheduleInterval = None | str | timedelta | relativedelta
@@ -1166,10 +1175,9 @@ class DAG:
from airflow import settings
from airflow.configuration import secrets_backend_list
from airflow.models.dagrun import DagRun, get_or_create_dagrun
- from airflow.sdk import DagRunState, TaskInstanceState, timezone
+ from airflow.sdk import DagRunState, timezone
from airflow.secrets.local_filesystem import LocalFilesystemBackend
from airflow.serialization.serialized_objects import SerializedDAG
- from airflow.utils.state import State
from airflow.utils.types import DagRunTriggeredByType, DagRunType
exit_stack = ExitStack()
@@ -1291,7 +1299,7 @@ class DAG:
# triggerer may mark tasks scheduled so we read from DB
all_tis = set(dr.get_task_instances(session=session))
scheduled_tis = {x for x in all_tis if x.state == TaskInstanceState.SCHEDULED}
- ids_unrunnable = {x for x in all_tis if x.state not in State.finished} - scheduled_tis
+ ids_unrunnable = {x for x in all_tis if x.state not in FINISHED_STATES} - scheduled_tis
if not scheduled_tis and ids_unrunnable:
log.warning("No tasks to run. unrunnable tasks: %s", ids_unrunnable)
time.sleep(1)
@@ -1331,7 +1339,7 @@ class DAG:
# Run the task locally
try:
if mark_success:
- ti.set_state(State.SUCCESS)
+ ti.set_state(TaskInstanceState.SUCCESS)
log.info("[DAG TEST] Marking success for %s on %s", task, ti.logical_date)
else:
_run_task(ti=ti, task=task, run_triggerer=True)
@@ -1363,7 +1371,6 @@ def _run_task(
possible. This function is only meant for the `dag.test` function as a
helper function.
"""
from airflow.sdk.module_loading import import_string
- from airflow.utils.state import State
taskrun_result: TaskRunResult | None
log.info("[DAG TEST] starting task_id=%s map_index=%s", ti.task_id, ti.map_index)
@@ -1378,7 +1385,7 @@ def _run_task(
# The API Server expects the task instance to be in QUEUED state before
# it is run.
- ti.set_state(State.QUEUED)
+ ti.set_state(TaskInstanceState.QUEUED)
task_sdk_ti = TaskInstanceSDK(
id=ti.id,
task_id=ti.task_id,
@@ -1394,12 +1401,12 @@ def _run_task(
ti.set_state(taskrun_result.ti.state)
ti.task = create_scheduler_operator(taskrun_result.ti.task)
- if ti.state == State.DEFERRED and isinstance(msg, DeferTask) and run_triggerer:
+ if ti.state == TaskInstanceState.DEFERRED and isinstance(msg, DeferTask) and run_triggerer:
from airflow.utils.session import create_session
# API Server expects the task instance to be in QUEUED state before
# resuming from deferral.
- ti.set_state(State.QUEUED)
+ ti.set_state(TaskInstanceState.QUEUED)
log.info("[DAG TEST] running trigger in line")
trigger = import_string(msg.classpath)(**msg.trigger_kwargs)
@@ -1410,15 +1417,15 @@ def _run_task(
# Set the state to SCHEDULED so that the task can be resumed.
with create_session() as session:
- ti.state = State.SCHEDULED
+ ti.state = TaskInstanceState.SCHEDULED
session.add(ti)
continue
break
except Exception:
log.exception("[DAG TEST] Error running task %s", ti)
- if ti.state not in State.finished:
- ti.set_state(State.FAILED)
+ if ti.state not in FINISHED_STATES:
+ ti.set_state(TaskInstanceState.FAILED)
taskrun_result = None
break
raise
diff --git a/task-sdk/tests/task_sdk/api/test_client.py b/task-sdk/tests/task_sdk/api/test_client.py
index 32a709663ac..e0f3866cc55 100644
--- a/task-sdk/tests/task_sdk/api/test_client.py
+++ b/task-sdk/tests/task_sdk/api/test_client.py
@@ -39,6 +39,7 @@ from airflow.sdk.api.datamodels._generated import (
DagRunStateResponse,
HITLDetailResponse,
HITLUser,
+ TerminalTIState,
VariableResponse,
XComResponse,
)
@@ -52,7 +53,6 @@ from airflow.sdk.execution_time.comms import (
RescheduleTask,
TaskRescheduleStartDate,
)
-from airflow.utils.state import TerminalTIState
if TYPE_CHECKING:
from time_machine import TimeMachineFixture
diff --git a/task-sdk/tests/task_sdk/bases/test_sensor.py b/task-sdk/tests/task_sdk/bases/test_sensor.py
index 2a3d61f0b52..214fef860f4 100644
--- a/task-sdk/tests/task_sdk/bases/test_sensor.py
+++ b/task-sdk/tests/task_sdk/bases/test_sensor.py
@@ -34,12 +34,11 @@ from airflow.exceptions import (
)
from airflow.models.trigger import TriggerFailureReason
from airflow.providers.standard.operators.empty import EmptyOperator
-from airflow.sdk import timezone
+from airflow.sdk import TaskInstanceState, timezone
from airflow.sdk.bases.sensor import BaseSensorOperator, PokeReturnValue, poke_mode_only
from airflow.sdk.definitions.dag import DAG
from airflow.sdk.execution_time.comms import RescheduleTask, TaskRescheduleStartDate
from airflow.sdk.timezone import datetime
-from airflow.utils.state import State
if TYPE_CHECKING:
from airflow.sdk.definitions.context import Context
@@ -172,7 +171,7 @@ class TestBaseSensor:
state, msg, _ = run_task(task=sensor)
- assert state == State.UP_FOR_RESCHEDULE
+ assert state == TaskInstanceState.UP_FOR_RESCHEDULE
assert msg.reschedule_date == date1 +
timedelta(seconds=sensor.poke_interval)
# second poke returns False and task is re-scheduled
@@ -180,14 +179,14 @@ class TestBaseSensor:
date2 = date1 + timedelta(seconds=sensor.poke_interval)
state, msg, _ = run_task(task=sensor)
- assert state == State.UP_FOR_RESCHEDULE
+ assert state == TaskInstanceState.UP_FOR_RESCHEDULE
assert msg.reschedule_date == date2 +
timedelta(seconds=sensor.poke_interval)
# third poke returns True and task succeeds
time_machine.coordinates.shift(sensor.poke_interval)
state, _, _ = run_task(task=sensor)
- assert state == State.SUCCESS
+ assert state == TaskInstanceState.SUCCESS
def test_fail_with_reschedule(self, run_task, make_sensor, time_machine,
mock_supervisor_comms):
sensor = make_sensor(return_value=False, poke_interval=10, timeout=5,
mode="reschedule")
@@ -198,7 +197,7 @@ class TestBaseSensor:
state, msg, _ = run_task(task=sensor)
- assert state == State.UP_FOR_RESCHEDULE
+ assert state == TaskInstanceState.UP_FOR_RESCHEDULE
assert msg.reschedule_date == date1 +
timedelta(seconds=sensor.poke_interval)
# second poke returns False, timeout occurs
@@ -208,7 +207,7 @@ class TestBaseSensor:
mock_supervisor_comms.send.return_value =
TaskRescheduleStartDate(start_date=date1)
state, msg, error = run_task(task=sensor,
context_update={"task_reschedule_count": 1})
- assert state == State.FAILED
+ assert state == TaskInstanceState.FAILED
assert isinstance(error, AirflowSensorTimeout)
def test_soft_fail_with_reschedule(self, run_task, make_sensor,
time_machine, mock_supervisor_comms):
@@ -221,7 +220,7 @@ class TestBaseSensor:
time_machine.move_to(date1, tick=False)
state, msg, _ = run_task(task=sensor)
- assert state == State.UP_FOR_RESCHEDULE
+ assert state == TaskInstanceState.UP_FOR_RESCHEDULE
# second poke returns False, timeout occurs
time_machine.coordinates.shift(sensor.poke_interval)
@@ -229,7 +228,7 @@ class TestBaseSensor:
# Mocking values from DB/API-server
mock_supervisor_comms.send.return_value =
TaskRescheduleStartDate(start_date=date1)
state, msg, _ = run_task(task=sensor,
context_update={"task_reschedule_count": 1})
- assert state == State.SKIPPED
+ assert state == TaskInstanceState.SKIPPED
def test_ok_with_reschedule_and_exponential_backoff(
self, run_task, make_sensor, time_machine, mock_supervisor_comms
@@ -261,7 +260,7 @@ class TestBaseSensor:
curr_date = curr_date + timedelta(seconds=new_interval)
time_machine.coordinates.shift(new_interval)
state, msg, _ = run_task(sensor,
context_update={"task_reschedule_count": _poke_count})
- assert state == State.UP_FOR_RESCHEDULE
+ assert state == TaskInstanceState.UP_FOR_RESCHEDULE
old_interval = new_interval
new_interval = sensor._get_next_poke_interval(task_start_date,
run_duration, _poke_count)
assert old_interval < new_interval # actual test
@@ -272,7 +271,7 @@ class TestBaseSensor:
time_machine.coordinates.shift(new_interval)
state, msg, _ = run_task(sensor,
context_update={"task_reschedule_count": false_count + 1})
- assert state == State.SUCCESS
+ assert state == TaskInstanceState.SUCCESS
def test_invalid_mode(self):
with pytest.raises(AirflowException):
@@ -291,7 +290,7 @@ class TestBaseSensor:
with time_machine.travel(date1, tick=False):
state, msg, error = run_task(sensor)
- assert state == State.UP_FOR_RESCHEDULE
+ assert state == TaskInstanceState.UP_FOR_RESCHEDULE
assert isinstance(msg, RescheduleTask)
assert msg.reschedule_date == date2
@@ -299,14 +298,14 @@ class TestBaseSensor:
with time_machine.travel(date2, tick=False):
state, msg, error = run_task(sensor)
- assert state == State.UP_FOR_RESCHEDULE
+ assert state == TaskInstanceState.UP_FOR_RESCHEDULE
assert isinstance(msg, RescheduleTask)
assert msg.reschedule_date == date3
# third poke returns True and task succeeds
with time_machine.travel(date3, tick=False):
state, _, _ = run_task(sensor)
- assert state == State.SUCCESS
+ assert state == TaskInstanceState.SUCCESS
def test_sensor_with_invalid_poke_interval(self):
negative_poke_interval = -10
@@ -523,36 +522,36 @@ class TestBaseSensor:
context_update={"task_reschedule_count":
test_state["task_reschedule_count"]},
)
- if state == State.UP_FOR_RESCHEDULE:
+ if state == TaskInstanceState.UP_FOR_RESCHEDULE:
test_state["task_reschedule_count"] += 1
# Only set first_reschedule_date on the first successful
reschedule
if test_state["first_reschedule_date"] is None:
test_state["first_reschedule_date"] =
test_state["current_time"]
- elif state == State.UP_FOR_RETRY:
+ elif state == TaskInstanceState.UP_FOR_RETRY:
test_state["try_number"] += 1
return state, msg, error
# Phase 1: Initial execution until failure
# First poke - should reschedule
state, _, _ = _run_task()
- assert state == State.UP_FOR_RESCHEDULE
+ assert state == TaskInstanceState.UP_FOR_RESCHEDULE
# Second poke - should raise RuntimeError and retry
test_state["current_time"] += timedelta(seconds=sensor.poke_interval)
state, _, error = _run_task()
- assert state == State.UP_FOR_RETRY
+ assert state == TaskInstanceState.UP_FOR_RETRY
assert isinstance(error, RuntimeError)
# Third poke - should reschedule again
test_state["current_time"] += sensor.retry_delay + timedelta(seconds=1)
state, _, _ = _run_task()
- assert state == State.UP_FOR_RESCHEDULE
+ assert state == TaskInstanceState.UP_FOR_RESCHEDULE
# Fourth poke - should timeout
test_state["current_time"] += timedelta(seconds=sensor.poke_interval)
state, _, error = _run_task()
assert isinstance(error, AirflowSensorTimeout)
- assert state == State.FAILED
+ assert state == TaskInstanceState.FAILED
# Phase 2: After clearing the failed sensor
# Reset supervisor comms to return None, simulating a fresh start
after clearing
@@ -564,13 +563,13 @@ class TestBaseSensor:
for _ in range(3):
test_state["current_time"] +=
timedelta(seconds=sensor.poke_interval)
state, _, _ = _run_task()
- assert state == State.UP_FOR_RESCHEDULE
+ assert state == TaskInstanceState.UP_FOR_RESCHEDULE
# Final poke - should timeout
test_state["current_time"] += timedelta(seconds=sensor.poke_interval)
state, _, error = _run_task()
assert isinstance(error, AirflowSensorTimeout)
- assert state == State.FAILED
+ assert state == TaskInstanceState.FAILED
def test_sensor_with_xcom(self, make_sensor):
xcom_value = "TestValue"
@@ -615,7 +614,7 @@ class TestBaseSensor:
state, _, error = run_task(task=task,
dag_id=f"test_sensor_timeout_{mode}_{retries}")
assert isinstance(error, AirflowSensorTimeout)
- assert state == State.FAILED
+ assert state == TaskInstanceState.FAILED
@poke_mode_only