This is an automated email from the ASF dual-hosted git repository.
amoghrajesh pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 44ac115937a Add a standard toggle for resumability to
ResumableJobMixin (#68623)
44ac115937a is described below
commit 44ac115937a9756b8a1c15bafd819b7bd11f22f2
Author: Amogh Desai <[email protected]>
AuthorDate: Tue Jun 23 14:25:42 2026 +0530
Add a standard toggle for resumability to ResumableJobMixin (#68623)
---
docs/spelling_wordlist.txt | 1 +
providers/apache/spark/docs/operators.rst | 6 +--
.../apache/spark/operators/spark_submit.py | 49 +++++++++++-----------
.../apache/spark/operators/test_spark_submit.py | 41 +++++++++++-------
task-sdk/docs/resumable-job-mixin.rst | 23 +++++++++-
.../src/airflow/sdk/bases/resumablejobmixin.py | 15 +++++++
.../tests/task_sdk/bases/test_resumablejobmixin.py | 29 +++++++++++++
7 files changed, 122 insertions(+), 42 deletions(-)
diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt
index 133c37b413b..6a9616c7c0d 100644
--- a/docs/spelling_wordlist.txt
+++ b/docs/spelling_wordlist.txt
@@ -1370,6 +1370,7 @@ rebase
Rebasing
Recency
recurse
+redeclare
redelivery
Redhat
redis
diff --git a/providers/apache/spark/docs/operators.rst
b/providers/apache/spark/docs/operators.rst
index 4d3a9a526af..d20c1da5cf7 100644
--- a/providers/apache/spark/docs/operators.rst
+++ b/providers/apache/spark/docs/operators.rst
@@ -236,7 +236,7 @@ Python Kubernetes client rather than holding
``spark-submit`` open for the full
conn_id="spark_k8s",
deploy_mode="cluster",
track_driver_via_k8s_api=True,
- reconnect_on_retry=True,
+ durable=True,
)
**Requirements**
@@ -246,9 +246,9 @@ Python Kubernetes client rather than holding
``spark-submit`` open for the full
conflicts with the flag and a ``ValueError`` will be raised at task start.
* The Airflow worker must be able to reach the Kubernetes API server and have
permission to
read and delete pods in the driver's namespace; otherwise pod tracking and
cleanup will fail.
-* Set ``reconnect_on_retry=True`` (the default) to enable crash recovery: the
driver pod name is
+* Set ``durable=True`` (the default) to enable crash recovery: the driver pod
name is
persisted to task state before polling begins, so a worker crash and retry
reconnects to the
- existing pod instead of submitting a fresh one. Set
``reconnect_on_retry=False`` to always
+ existing pod instead of submitting a fresh one. Set ``durable=False`` to
always
submit a fresh driver on retry.
* Pod completion is detected from ``pod.status.phase``. If your driver pods
have sidecar
containers (e.g. Istio injection enabled for the driver namespace), the pod
phase may not
diff --git
a/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py
b/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py
index 7ceb95b387a..8a2129e9673 100644
---
a/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py
+++
b/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py
@@ -17,12 +17,14 @@
# under the License.
from __future__ import annotations
+import warnings
from collections.abc import Sequence
from typing import TYPE_CHECKING, Any, cast
import requests
from tenacity import retry, stop_after_attempt, wait_fixed
+from airflow.exceptions import AirflowProviderDeprecationWarning
from airflow.providers.apache.spark.hooks.spark_submit import
_K8S_WAIT_APP_COMPLETION_CONF, SparkSubmitHook
from airflow.providers.common.compat.openlineage.utils.spark import (
inject_parent_job_information_into_spark_properties,
@@ -46,6 +48,11 @@ except ImportError:
external_id_key: str = "remote_job_id"
+ def __init__(self, *, durable: bool = True, **kwargs: Any) -> None:
+ # Accept durable so the kwarg doesn't leak to BaseOperator; crash
recovery is a no-op here.
+ super().__init__(**kwargs)
+ self.durable = durable
+
def execute_resumable(self, context):
external_id = self.submit_job(context)
self.poll_until_complete(external_id, context)
@@ -139,6 +146,9 @@ class SparkSubmitOperator(ResumableJobMixin, BaseOperator):
omitted, Kerberos-enabled Spark connections with both ``keytab`` and
``principal`` configured use ``requests-kerberos`` automatically.
Defaults to ``None`` (no auth for non-Kerberos connections).
+ :param durable: When ``True`` (the default), the external job ID is
persisted to the task state
+ store before polling begins so that a worker crash and retry
reconnects to the existing job
+ instead of submitting a fresh one. Set to ``False`` to always submit a
new job on retry.
"""
# Generic key used across all Spark deployment modes (standalone driver ID,
@@ -203,7 +213,6 @@ class SparkSubmitOperator(ResumableJobMixin, BaseOperator):
deploy_mode: str | None = None,
use_krb5ccache: bool = False,
post_submit_commands: list[str] | None = None,
- reconnect_on_retry: bool = True,
track_driver_via_k8s_api: bool = False,
yarn_track_via_rm_api: bool = False,
yarn_rm_auth: AuthBase | None = None,
@@ -213,8 +222,16 @@ class SparkSubmitOperator(ResumableJobMixin, BaseOperator):
openlineage_inject_transport_info: bool = conf.getboolean(
"openlineage", "spark_inject_transport_info", fallback=False
),
+ reconnect_on_retry: bool | None = None,
**kwargs: Any,
) -> None:
+ if reconnect_on_retry is not None:
+ warnings.warn(
+ "reconnect_on_retry is renamed to durable.",
+ AirflowProviderDeprecationWarning,
+ stacklevel=2,
+ )
+ kwargs.setdefault("durable", reconnect_on_retry)
super().__init__(**kwargs)
self.application = application
self.conf = conf
@@ -252,7 +269,6 @@ class SparkSubmitOperator(ResumableJobMixin, BaseOperator):
self._yarn_track_via_rm_api = yarn_track_via_rm_api
self._yarn_rm_auth = yarn_rm_auth
- self.reconnect_on_retry = reconnect_on_retry
self._track_driver_via_k8s_api = track_driver_via_k8s_api
self._openlineage_inject_parent_job_info =
openlineage_inject_parent_job_info
self._openlineage_inject_transport_info =
openlineage_inject_transport_info
@@ -272,33 +288,18 @@ class SparkSubmitOperator(ResumableJobMixin,
BaseOperator):
if self._track_driver_via_k8s_api:
hook._validate_track_driver_via_k8s_api_config()
if hook._should_track_driver_status:
- if self.reconnect_on_retry:
- return self.execute_resumable(context)
- # reconnect_on_retry=False: still submit-and-poll, just skip
task_state_store persistence.
- driver_id = self.submit_job(context)
- self.poll_until_complete(driver_id, context)
- return self.get_job_result(driver_id, context)
+ return self.execute_resumable(context)
if hook._should_track_driver_via_k8s_api():
- if self.reconnect_on_retry:
- return self.execute_resumable(context)
- # reconnect_on_retry=False: still submit-and-poll, just skip
task_state persistence.
- driver_id = self.submit_job(context)
- self.poll_until_complete(driver_id, context)
- return self.get_job_result(driver_id, context)
+ return self.execute_resumable(context)
if hook._is_yarn_cluster_mode:
- if self.reconnect_on_retry and not hook._yarn_track_via_rm_api:
+ if self.durable and not hook._yarn_track_via_rm_api:
raise ValueError(
- "YARN cluster mode with reconnect_on_retry=True requires
yarn_track_via_rm_api=True. "
+ "YARN cluster mode with durable=True requires
yarn_track_via_rm_api=True. "
"The RM REST API is needed to check application status on
retry."
)
if hook._yarn_track_via_rm_api:
hook._validate_yarn_track_via_rm_api_config()
- if self.reconnect_on_retry:
- return self.execute_resumable(context)
- # reconnect_on_retry=False: still submit-and-poll, just skip
task_state_store persistence.
- driver_id = self.submit_job(context)
- self.poll_until_complete(driver_id, context)
- return self.get_job_result(driver_id, context)
+ return self.execute_resumable(context)
hook.submit(self.application)
def submit_job(self, context: Context) -> str | None:
@@ -319,7 +320,7 @@ class SparkSubmitOperator(ResumableJobMixin, BaseOperator):
raise ValueError(
"spark.yarn.submit.waitAppCompletion=true cannot be set
for cluster mode as it conflicts"
"with the need to exit spark-submit immediately to persist
the application ID for tracking. "
- "Either remove the explicit conf or set
reconnect_on_retry=False."
+ "Either remove the explicit conf or set durable=False."
)
self._hook._conf["spark.yarn.submit.waitAppCompletion"] = "false"
self._hook.submit(self.application)
@@ -445,7 +446,7 @@ class SparkSubmitOperator(ResumableJobMixin, BaseOperator):
# Cache only when the pod actually reached Succeeded. The
404/vanished path
# returns None (for cases like a pod deleted by on_kill or garbage
collected after failure)
# and must not be cached; otherwise a retry would see "Succeeded"
and skip resubmission.
- if terminal_phase == "Succeeded" and self.reconnect_on_retry:
+ if terminal_phase == "Succeeded" and self.durable:
if (task_state_store := context.get("task_state_store")) is
not None:
task_state_store.set(self._K8S_DRIVER_STATUS_KEY,
"Succeeded")
return
diff --git
a/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py
b/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py
index 9708aab1a02..daa0fa119ca 100644
---
a/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py
+++
b/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py
@@ -18,12 +18,14 @@
from __future__ import annotations
import logging
+import warnings
from datetime import timedelta
from unittest import mock
from unittest.mock import MagicMock
import pytest
+from airflow.exceptions import AirflowProviderDeprecationWarning
from airflow.models import DagRun, TaskInstance
from airflow.models.dag import DAG
from airflow.providers.apache.spark.operators.spark_submit import
SparkSubmitOperator
@@ -590,8 +592,17 @@ class TestSparkSubmitOperatorResumable:
operator._hook.submit.assert_called_once_with("test.jar")
assert polled == ["driver-001"]
- def test_reconnect_on_retry_false_submits_fresh_and_polls(self):
- operator = self._make_operator(reconnect_on_retry=False)
+ def test_reconnect_on_retry_deprecated_alias(self):
+ with warnings.catch_warnings(record=True) as w:
+ warnings.simplefilter("always")
+ operator = self._make_operator(reconnect_on_retry=False)
+ assert len(w) == 1
+ assert issubclass(w[0].category, AirflowProviderDeprecationWarning)
+ assert "reconnect_on_retry" in str(w[0].message)
+ assert operator.durable is False
+
+ def test_durable_false_submits_fresh_and_polls(self):
+ operator = self._make_operator(durable=False)
operator._hook = self._make_hook(should_track=True)
operator._hook.submit.return_value = "driver-new"
task_store = FakeTaskStateStore({"spark_job_id": "driver-old"})
@@ -599,7 +610,7 @@ class TestSparkSubmitOperatorResumable:
operator.poll_until_complete = lambda external_id, context:
polled.append(external_id)
operator.execute(context={"task_state_store": task_store})
- # reconnect_on_retry=False: ignores prior driver ID, submits fresh,
but still polls
+ # durable=False: ignores prior driver ID, submits fresh, but still
polls
operator._hook.submit.assert_called_once_with("test.jar")
assert polled == ["driver-new"]
@@ -863,8 +874,8 @@ class TestSparkSubmitOperatorResumable:
hook._kill_yarn_application.assert_called_once_with("application_1234_0001")
def test_yarn_cluster_reconnect_without_rm_api_raises(self):
- """reconnect_on_retry=True + yarn_track_via_rm_api=False must raise -
RM API is required for resume."""
- operator = self._make_operator(reconnect_on_retry=True)
+ """durable=True + yarn_track_via_rm_api=False must raise - RM API is
required for resume."""
+ operator = self._make_operator(durable=True)
hook = self._make_hook(is_yarn_cluster=True)
hook._yarn_track_via_rm_api = False
operator._hook = hook
@@ -873,8 +884,8 @@ class TestSparkSubmitOperatorResumable:
operator.execute(context={})
def
test_yarn_cluster_without_rm_api_reconnect_false_falls_through_to_hook_submit(self):
- """reconnect_on_retry=False + yarn_track_via_rm_api=False falls
through to hook.submit() - no RM polling."""
- operator = self._make_operator(reconnect_on_retry=False)
+ """durable=False + yarn_track_via_rm_api=False falls through to
hook.submit() - no RM polling."""
+ operator = self._make_operator(durable=False)
hook = self._make_hook(is_yarn_cluster=True)
hook._yarn_track_via_rm_api = False
operator._hook = hook
@@ -1028,6 +1039,7 @@ class TestSparkSubmitOperatorK8sTracking:
assert hook._kubernetes_driver_pod == "spark-abc-driver"
hook._poll_k8s_driver_via_api.assert_called_once()
+ @pytest.mark.skipif(not AIRFLOW_V_3_3_PLUS, reason="task_state_store
requires Airflow 3.3+")
def test_k8s_poll_until_complete_writes_succeeded_to_task_store(self):
operator = self._make_operator(track_driver_via_k8s_api=True)
hook = self._make_k8s_hook()
@@ -1039,8 +1051,9 @@ class TestSparkSubmitOperatorK8sTracking:
assert task_store.get("k8s_driver_status") == "Succeeded"
+ @pytest.mark.skipif(not AIRFLOW_V_3_3_PLUS, reason="task_state_store
requires Airflow 3.3+")
def
test_k8s_polling_does_not_write_task_store_when_reconnect_disabled(self):
- operator = self._make_operator(track_driver_via_k8s_api=True,
reconnect_on_retry=False)
+ operator = self._make_operator(track_driver_via_k8s_api=True,
durable=False)
hook = self._make_k8s_hook()
hook._poll_k8s_driver_via_api.return_value = "Succeeded"
operator._hook = hook
@@ -1072,9 +1085,9 @@ class TestSparkSubmitOperatorK8sTracking:
not AIRFLOW_V_3_3_PLUS,
reason="ResumableJobMixin reconnect requires task_state, available in
Airflow 3.3+",
)
- def test_k8s_execute_persists_pod_id_when_reconnect_on_retry(self):
- """execute() with reconnect_on_retry=True stores the pod ID in
task_store before polling."""
- operator = self._make_operator(track_driver_via_k8s_api=True,
reconnect_on_retry=True)
+ def test_k8s_execute_persists_pod_id_when_durable(self):
+ """execute() with durable=True stores the pod ID in task_store before
polling."""
+ operator = self._make_operator(track_driver_via_k8s_api=True,
durable=True)
hook = self._make_k8s_hook()
hook._kubernetes_driver_pod = "spark-abc-driver"
hook._connection = {"namespace": "mynamespace"}
@@ -1095,9 +1108,9 @@ class TestSparkSubmitOperatorK8sTracking:
not AIRFLOW_V_3_3_PLUS,
reason="ResumableJobMixin reconnect requires task_state, available in
Airflow 3.3+",
)
- def
test_k8s_execute_reconnect_on_retry_false_does_not_persist_pod_id(self):
- """execute() with reconnect_on_retry=False does not write spark_job_id
to task_store."""
- operator = self._make_operator(track_driver_via_k8s_api=True,
reconnect_on_retry=False)
+ def test_k8s_execute_durable_false_does_not_persist_pod_id(self):
+ """execute() with durable=False does not write spark_job_id to
task_store."""
+ operator = self._make_operator(track_driver_via_k8s_api=True,
durable=False)
hook = self._make_k8s_hook()
hook._kubernetes_driver_pod = "spark-abc-driver"
hook._connection = {"namespace": "mynamespace"}
diff --git a/task-sdk/docs/resumable-job-mixin.rst
b/task-sdk/docs/resumable-job-mixin.rst
index e78a2e5ce1a..345fc4c4494 100644
--- a/task-sdk/docs/resumable-job-mixin.rst
+++ b/task-sdk/docs/resumable-job-mixin.rst
@@ -120,7 +120,7 @@ Example
from pydantic import JsonValue
- class MyBatchOperator(BaseOperator, ResumableJobMixin):
+ class MyBatchOperator(ResumableJobMixin, BaseOperator):
external_id_key = "batch_job_id"
@@ -145,6 +145,27 @@ Example
def get_job_result(self, external_id: JsonValue, context):
return None
+.. _sdk-resumable-job-mixin-resume-on-retry:
+
+Disabling crash recovery per task
+----------------------------------
+
+Set ``durable=False`` on a task to opt out of crash recovery for that specific
instance.
+The operator will always submit a fresh job on retry, with no
``task_state_store`` interaction:
+
+.. code-block:: python
+
+ run_spark = MyBatchOperator(
+ task_id="run_spark",
+ durable=False,
+ )
+
+This is useful when the external job is not idempotent and you want Airflow to
always submit a
+clean run rather than reconnect to a prior submission.
+
+The default is ``True``. ``durable`` is owned by the mixin — operators do not
need to
+redeclare it. ``default_args`` injection and ``.partial()`` work automatically.
+
.. _sdk-resumable-job-mixin-external-id-key:
External ID key
diff --git a/task-sdk/src/airflow/sdk/bases/resumablejobmixin.py
b/task-sdk/src/airflow/sdk/bases/resumablejobmixin.py
index 7066d10cced..27533dbe840 100644
--- a/task-sdk/src/airflow/sdk/bases/resumablejobmixin.py
+++ b/task-sdk/src/airflow/sdk/bases/resumablejobmixin.py
@@ -21,6 +21,7 @@ from typing import TYPE_CHECKING, Any
from opentelemetry import trace
from airflow.sdk._shared.observability.metrics import stats
+from airflow.sdk.bases.operator import BaseOperatorMeta
if TYPE_CHECKING:
from pydantic import JsonValue
@@ -90,6 +91,14 @@ class ResumableJobMixin:
# Renaming this on a deployed operator breaks in-flight retries — the old
key is already stored.
external_id_key: str = "remote_job_id"
+ # The mixin is not a BaseOperator subclass, but _apply_defaults is only
ever called on concrete
+ # operators that are BaseOperator subclasses. That is a runtime MRO
guarantee that is not visible in the static
+ # type signature here, hence the type ignore below.
+ @BaseOperatorMeta._apply_defaults # type: ignore[type-var]
+ def __init__(self, *, durable: bool = True, **kwargs: Any) -> None:
+ super().__init__(**kwargs)
+ self.durable = durable
+
def execute_resumable(self, context: Context) -> Any:
"""
Core of the resumable execution logic. Call this from execute() when
reconnection is supported.
@@ -107,6 +116,11 @@ class ResumableJobMixin:
Closing this window would require atomic "submit + persist", which is
not possible across
an external system boundary.
"""
+ if not self.durable:
+ external_id = self.submit_job(context)
+ self.poll_until_complete(external_id, context)
+ return self.get_job_result(external_id, context)
+
stats_tags = {"operator": type(self).__name__}
# The task is team-scoped in multi-team deployments; surface team_name
on the
# resumable_job metrics via the running task instance's stats tags
(omitted when
@@ -114,6 +128,7 @@ class ResumableJobMixin:
ti = context.get("ti")
if ti is not None and (team_name := ti.stats_tags.get("team_name")):
stats_tags["team_name"] = team_name
+
reconnect_to: Any = None
already_succeeded_id: Any = None
diff --git a/task-sdk/tests/task_sdk/bases/test_resumablejobmixin.py
b/task-sdk/tests/task_sdk/bases/test_resumablejobmixin.py
index e186cc19a36..8e962583e99 100644
--- a/task-sdk/tests/task_sdk/bases/test_resumablejobmixin.py
+++ b/task-sdk/tests/task_sdk/bases/test_resumablejobmixin.py
@@ -206,6 +206,35 @@ class TestNoneExternalId:
assert task_state._store == {}
+class TestResumeOnRetryDisabled:
+ def test_submits_and_polls_without_task_store_interaction(self):
+ op = ConcreteResumableOperator(task_id="test_task", durable=False)
+ task_store = FakeTaskState()
+ op.execute_resumable(make_context(task_store))
+
+ assert op.submitted_ids == ["job-001"]
+ assert op.polled_ids == ["job-001"]
+ assert task_store._store == {}, "task_store must not be written when
durable=False"
+
+ def test_does_not_reconnect_when_prior_id_exists(self):
+ op = ConcreteResumableOperator(task_id="test_task", durable=False)
+ op._status_map["job-001"] = "RUNNING"
+ task_store = FakeTaskState({"test_job_id": "job-001"})
+
+ op.execute_resumable(make_context(task_store))
+
+ assert op.submitted_ids == ["job-001"], "should submit fresh even with
a prior ID stored"
+
+ def test_returns_result(self):
+ op = ConcreteResumableOperator(task_id="test_task", durable=False)
+ result = op.execute_resumable(make_context(FakeTaskState()))
+ assert result == "result-of-job-001"
+
+ def test_default_is_true(self):
+ op = ConcreteResumableOperator(task_id="test_task")
+ assert op.durable is True
+
+
class TestExternalIdKey:
def test_custom_key_used_for_storage_and_retrieval(self):
class CustomKeyOp(ConcreteResumableOperator):