This is an automated email from the ASF dual-hosted git repository.

amoghrajesh pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 44ac115937a Add a standard toggle for resumability to 
ResumableJobMixin (#68623)
44ac115937a is described below

commit 44ac115937a9756b8a1c15bafd819b7bd11f22f2
Author: Amogh Desai <[email protected]>
AuthorDate: Tue Jun 23 14:25:42 2026 +0530

    Add a standard toggle for resumability to ResumableJobMixin (#68623)
---
 docs/spelling_wordlist.txt                         |  1 +
 providers/apache/spark/docs/operators.rst          |  6 +--
 .../apache/spark/operators/spark_submit.py         | 49 +++++++++++-----------
 .../apache/spark/operators/test_spark_submit.py    | 41 +++++++++++-------
 task-sdk/docs/resumable-job-mixin.rst              | 23 +++++++++-
 .../src/airflow/sdk/bases/resumablejobmixin.py     | 15 +++++++
 .../tests/task_sdk/bases/test_resumablejobmixin.py | 29 +++++++++++++
 7 files changed, 122 insertions(+), 42 deletions(-)

diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt
index 133c37b413b..6a9616c7c0d 100644
--- a/docs/spelling_wordlist.txt
+++ b/docs/spelling_wordlist.txt
@@ -1370,6 +1370,7 @@ rebase
 Rebasing
 Recency
 recurse
+redeclare
 redelivery
 Redhat
 redis
diff --git a/providers/apache/spark/docs/operators.rst 
b/providers/apache/spark/docs/operators.rst
index 4d3a9a526af..d20c1da5cf7 100644
--- a/providers/apache/spark/docs/operators.rst
+++ b/providers/apache/spark/docs/operators.rst
@@ -236,7 +236,7 @@ Python Kubernetes client rather than holding 
``spark-submit`` open for the full
        conn_id="spark_k8s",
        deploy_mode="cluster",
        track_driver_via_k8s_api=True,
-       reconnect_on_retry=True,
+       durable=True,
    )
 
 **Requirements**
@@ -246,9 +246,9 @@ Python Kubernetes client rather than holding 
``spark-submit`` open for the full
   conflicts with the flag and a ``ValueError`` will be raised at task start.
 * The Airflow worker must be able to reach the Kubernetes API server and have 
permission to
   read and delete pods in the driver's namespace; otherwise pod tracking and 
cleanup will fail.
-* Set ``reconnect_on_retry=True`` (the default) to enable crash recovery: the 
driver pod name is
+* Set ``durable=True`` (the default) to enable crash recovery: the driver pod 
name is
   persisted to task state before polling begins, so a worker crash and retry 
reconnects to the
-  existing pod instead of submitting a fresh one. Set 
``reconnect_on_retry=False`` to always
+  existing pod instead of submitting a fresh one. Set ``durable=False`` to 
always
   submit a fresh driver on retry.
 * Pod completion is detected from ``pod.status.phase``. If your driver pods 
have sidecar
   containers (e.g. Istio injection enabled for the driver namespace), the pod 
phase may not
diff --git 
a/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py
 
b/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py
index 7ceb95b387a..8a2129e9673 100644
--- 
a/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py
+++ 
b/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py
@@ -17,12 +17,14 @@
 # under the License.
 from __future__ import annotations
 
+import warnings
 from collections.abc import Sequence
 from typing import TYPE_CHECKING, Any, cast
 
 import requests
 from tenacity import retry, stop_after_attempt, wait_fixed
 
+from airflow.exceptions import AirflowProviderDeprecationWarning
 from airflow.providers.apache.spark.hooks.spark_submit import 
_K8S_WAIT_APP_COMPLETION_CONF, SparkSubmitHook
 from airflow.providers.common.compat.openlineage.utils.spark import (
     inject_parent_job_information_into_spark_properties,
@@ -46,6 +48,11 @@ except ImportError:
 
         external_id_key: str = "remote_job_id"
 
+        def __init__(self, *, durable: bool = True, **kwargs: Any) -> None:
+            # Accept durable so the kwarg doesn't leak to BaseOperator; crash 
recovery is a no-op here.
+            super().__init__(**kwargs)
+            self.durable = durable
+
         def execute_resumable(self, context):
             external_id = self.submit_job(context)
             self.poll_until_complete(external_id, context)
@@ -139,6 +146,9 @@ class SparkSubmitOperator(ResumableJobMixin, BaseOperator):
         omitted, Kerberos-enabled Spark connections with both ``keytab`` and
         ``principal`` configured use ``requests-kerberos`` automatically.
         Defaults to ``None`` (no auth for non-Kerberos connections).
+    :param durable: When ``True`` (the default), the external job ID is 
persisted to the task state
+        store before polling begins so that a worker crash and retry 
reconnects to the existing job
+        instead of submitting a fresh one. Set to ``False`` to always submit a 
new job on retry.
     """
 
     # Generic key used across all Spark deployment modes (standalone driver ID,
@@ -203,7 +213,6 @@ class SparkSubmitOperator(ResumableJobMixin, BaseOperator):
         deploy_mode: str | None = None,
         use_krb5ccache: bool = False,
         post_submit_commands: list[str] | None = None,
-        reconnect_on_retry: bool = True,
         track_driver_via_k8s_api: bool = False,
         yarn_track_via_rm_api: bool = False,
         yarn_rm_auth: AuthBase | None = None,
@@ -213,8 +222,16 @@ class SparkSubmitOperator(ResumableJobMixin, BaseOperator):
         openlineage_inject_transport_info: bool = conf.getboolean(
             "openlineage", "spark_inject_transport_info", fallback=False
         ),
+        reconnect_on_retry: bool | None = None,
         **kwargs: Any,
     ) -> None:
+        if reconnect_on_retry is not None:
+            warnings.warn(
+                "reconnect_on_retry is renamed to durable.",
+                AirflowProviderDeprecationWarning,
+                stacklevel=2,
+            )
+            kwargs.setdefault("durable", reconnect_on_retry)
         super().__init__(**kwargs)
         self.application = application
         self.conf = conf
@@ -252,7 +269,6 @@ class SparkSubmitOperator(ResumableJobMixin, BaseOperator):
         self._yarn_track_via_rm_api = yarn_track_via_rm_api
         self._yarn_rm_auth = yarn_rm_auth
 
-        self.reconnect_on_retry = reconnect_on_retry
         self._track_driver_via_k8s_api = track_driver_via_k8s_api
         self._openlineage_inject_parent_job_info = 
openlineage_inject_parent_job_info
         self._openlineage_inject_transport_info = 
openlineage_inject_transport_info
@@ -272,33 +288,18 @@ class SparkSubmitOperator(ResumableJobMixin, 
BaseOperator):
         if self._track_driver_via_k8s_api:
             hook._validate_track_driver_via_k8s_api_config()
         if hook._should_track_driver_status:
-            if self.reconnect_on_retry:
-                return self.execute_resumable(context)
-            # reconnect_on_retry=False: still submit-and-poll, just skip 
task_state_store persistence.
-            driver_id = self.submit_job(context)
-            self.poll_until_complete(driver_id, context)
-            return self.get_job_result(driver_id, context)
+            return self.execute_resumable(context)
         if hook._should_track_driver_via_k8s_api():
-            if self.reconnect_on_retry:
-                return self.execute_resumable(context)
-            # reconnect_on_retry=False: still submit-and-poll, just skip 
task_state persistence.
-            driver_id = self.submit_job(context)
-            self.poll_until_complete(driver_id, context)
-            return self.get_job_result(driver_id, context)
+            return self.execute_resumable(context)
         if hook._is_yarn_cluster_mode:
-            if self.reconnect_on_retry and not hook._yarn_track_via_rm_api:
+            if self.durable and not hook._yarn_track_via_rm_api:
                 raise ValueError(
-                    "YARN cluster mode with reconnect_on_retry=True requires 
yarn_track_via_rm_api=True. "
+                    "YARN cluster mode with durable=True requires 
yarn_track_via_rm_api=True. "
                     "The RM REST API is needed to check application status on 
retry."
                 )
             if hook._yarn_track_via_rm_api:
                 hook._validate_yarn_track_via_rm_api_config()
-                if self.reconnect_on_retry:
-                    return self.execute_resumable(context)
-                # reconnect_on_retry=False: still submit-and-poll, just skip 
task_state_store persistence.
-                driver_id = self.submit_job(context)
-                self.poll_until_complete(driver_id, context)
-                return self.get_job_result(driver_id, context)
+                return self.execute_resumable(context)
         hook.submit(self.application)
 
     def submit_job(self, context: Context) -> str | None:
@@ -319,7 +320,7 @@ class SparkSubmitOperator(ResumableJobMixin, BaseOperator):
                 raise ValueError(
                     "spark.yarn.submit.waitAppCompletion=true cannot be set 
for cluster mode as it conflicts"
                     "with the need to exit spark-submit immediately to persist 
the application ID for tracking. "
-                    "Either remove the explicit conf or set 
reconnect_on_retry=False."
+                    "Either remove the explicit conf or set durable=False."
                 )
             self._hook._conf["spark.yarn.submit.waitAppCompletion"] = "false"
             self._hook.submit(self.application)
@@ -445,7 +446,7 @@ class SparkSubmitOperator(ResumableJobMixin, BaseOperator):
             # Cache only when the pod actually reached Succeeded. The 
404/vanished path
             # returns None (for cases like: pod deleted by on_kill or garbage 
collected after failure)
             # and must not be cached, otherwise a retry would see "Succeeded" 
and skip resubmission.
-            if terminal_phase == "Succeeded" and self.reconnect_on_retry:
+            if terminal_phase == "Succeeded" and self.durable:
                 if (task_state_store := context.get("task_state_store")) is 
not None:
                     task_state_store.set(self._K8S_DRIVER_STATUS_KEY, 
"Succeeded")
             return
diff --git 
a/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py 
b/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py
index 9708aab1a02..daa0fa119ca 100644
--- 
a/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py
+++ 
b/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py
@@ -18,12 +18,14 @@
 from __future__ import annotations
 
 import logging
+import warnings
 from datetime import timedelta
 from unittest import mock
 from unittest.mock import MagicMock
 
 import pytest
 
+from airflow.exceptions import AirflowProviderDeprecationWarning
 from airflow.models import DagRun, TaskInstance
 from airflow.models.dag import DAG
 from airflow.providers.apache.spark.operators.spark_submit import 
SparkSubmitOperator
@@ -590,8 +592,17 @@ class TestSparkSubmitOperatorResumable:
         operator._hook.submit.assert_called_once_with("test.jar")
         assert polled == ["driver-001"]
 
-    def test_reconnect_on_retry_false_submits_fresh_and_polls(self):
-        operator = self._make_operator(reconnect_on_retry=False)
+    def test_reconnect_on_retry_deprecated_alias(self):
+        with warnings.catch_warnings(record=True) as w:
+            warnings.simplefilter("always")
+            operator = self._make_operator(reconnect_on_retry=False)
+        assert len(w) == 1
+        assert issubclass(w[0].category, AirflowProviderDeprecationWarning)
+        assert "reconnect_on_retry" in str(w[0].message)
+        assert operator.durable is False
+
+    def test_durable_false_submits_fresh_and_polls(self):
+        operator = self._make_operator(durable=False)
         operator._hook = self._make_hook(should_track=True)
         operator._hook.submit.return_value = "driver-new"
         task_store = FakeTaskStateStore({"spark_job_id": "driver-old"})
@@ -599,7 +610,7 @@ class TestSparkSubmitOperatorResumable:
         operator.poll_until_complete = lambda external_id, context: 
polled.append(external_id)
 
         operator.execute(context={"task_state_store": task_store})
-        # reconnect_on_retry=False: ignores prior driver ID, submits fresh, 
but still polls
+        # durable=False: ignores prior driver ID, submits fresh, but still 
polls
         operator._hook.submit.assert_called_once_with("test.jar")
         assert polled == ["driver-new"]
 
@@ -863,8 +874,8 @@ class TestSparkSubmitOperatorResumable:
         
hook._kill_yarn_application.assert_called_once_with("application_1234_0001")
 
     def test_yarn_cluster_reconnect_without_rm_api_raises(self):
-        """reconnect_on_retry=True + yarn_track_via_rm_api=False must raise - 
RM API is required for resume."""
-        operator = self._make_operator(reconnect_on_retry=True)
+        """durable=True + yarn_track_via_rm_api=False must raise - RM API is 
required for resume."""
+        operator = self._make_operator(durable=True)
         hook = self._make_hook(is_yarn_cluster=True)
         hook._yarn_track_via_rm_api = False
         operator._hook = hook
@@ -873,8 +884,8 @@ class TestSparkSubmitOperatorResumable:
             operator.execute(context={})
 
     def 
test_yarn_cluster_without_rm_api_reconnect_false_falls_through_to_hook_submit(self):
-        """reconnect_on_retry=False + yarn_track_via_rm_api=False falls 
through to hook.submit() - no RM polling."""
-        operator = self._make_operator(reconnect_on_retry=False)
+        """durable=False + yarn_track_via_rm_api=False falls through to 
hook.submit() - no RM polling."""
+        operator = self._make_operator(durable=False)
         hook = self._make_hook(is_yarn_cluster=True)
         hook._yarn_track_via_rm_api = False
         operator._hook = hook
@@ -1028,6 +1039,7 @@ class TestSparkSubmitOperatorK8sTracking:
         assert hook._kubernetes_driver_pod == "spark-abc-driver"
         hook._poll_k8s_driver_via_api.assert_called_once()
 
+    @pytest.mark.skipif(not AIRFLOW_V_3_3_PLUS, reason="task_state_store 
requires Airflow 3.3+")
     def test_k8s_poll_until_complete_writes_succeeded_to_task_store(self):
         operator = self._make_operator(track_driver_via_k8s_api=True)
         hook = self._make_k8s_hook()
@@ -1039,8 +1051,9 @@ class TestSparkSubmitOperatorK8sTracking:
 
         assert task_store.get("k8s_driver_status") == "Succeeded"
 
+    @pytest.mark.skipif(not AIRFLOW_V_3_3_PLUS, reason="task_state_store 
requires Airflow 3.3+")
     def 
test_k8s_polling_does_not_write_task_store_when_reconnect_disabled(self):
-        operator = self._make_operator(track_driver_via_k8s_api=True, 
reconnect_on_retry=False)
+        operator = self._make_operator(track_driver_via_k8s_api=True, 
durable=False)
         hook = self._make_k8s_hook()
         hook._poll_k8s_driver_via_api.return_value = "Succeeded"
         operator._hook = hook
@@ -1072,9 +1085,9 @@ class TestSparkSubmitOperatorK8sTracking:
         not AIRFLOW_V_3_3_PLUS,
         reason="ResumableJobMixin reconnect requires task_state, available in 
Airflow 3.3+",
     )
-    def test_k8s_execute_persists_pod_id_when_reconnect_on_retry(self):
-        """execute() with reconnect_on_retry=True stores the pod ID in 
task_store before polling."""
-        operator = self._make_operator(track_driver_via_k8s_api=True, 
reconnect_on_retry=True)
+    def test_k8s_execute_persists_pod_id_when_durable(self):
+        """execute() with durable=True stores the pod ID in task_store before 
polling."""
+        operator = self._make_operator(track_driver_via_k8s_api=True, 
durable=True)
         hook = self._make_k8s_hook()
         hook._kubernetes_driver_pod = "spark-abc-driver"
         hook._connection = {"namespace": "mynamespace"}
@@ -1095,9 +1108,9 @@ class TestSparkSubmitOperatorK8sTracking:
         not AIRFLOW_V_3_3_PLUS,
         reason="ResumableJobMixin reconnect requires task_state, available in 
Airflow 3.3+",
     )
-    def 
test_k8s_execute_reconnect_on_retry_false_does_not_persist_pod_id(self):
-        """execute() with reconnect_on_retry=False does not write spark_job_id 
to task_store."""
-        operator = self._make_operator(track_driver_via_k8s_api=True, 
reconnect_on_retry=False)
+    def test_k8s_execute_durable_false_does_not_persist_pod_id(self):
+        """execute() with durable=False does not write spark_job_id to 
task_store."""
+        operator = self._make_operator(track_driver_via_k8s_api=True, 
durable=False)
         hook = self._make_k8s_hook()
         hook._kubernetes_driver_pod = "spark-abc-driver"
         hook._connection = {"namespace": "mynamespace"}
diff --git a/task-sdk/docs/resumable-job-mixin.rst 
b/task-sdk/docs/resumable-job-mixin.rst
index e78a2e5ce1a..345fc4c4494 100644
--- a/task-sdk/docs/resumable-job-mixin.rst
+++ b/task-sdk/docs/resumable-job-mixin.rst
@@ -120,7 +120,7 @@ Example
     from pydantic import JsonValue
 
 
-    class MyBatchOperator(BaseOperator, ResumableJobMixin):
+    class MyBatchOperator(ResumableJobMixin, BaseOperator):
 
         external_id_key = "batch_job_id"
 
@@ -145,6 +145,27 @@ Example
         def get_job_result(self, external_id: JsonValue, context):
             return None
 
+.. _sdk-resumable-job-mixin-resume-on-retry:
+
+Disabling crash recovery per task
+----------------------------------
+
+Set ``durable=False`` on a task to opt out of crash recovery for that specific 
instance.
+The operator will always submit a fresh job on retry, with no 
``task_state_store`` interaction:
+
+.. code-block:: python
+
+    run_spark = MyBatchOperator(
+        task_id="run_spark",
+        durable=False,
+    )
+
+This is useful when the external job is not idempotent and you want Airflow to 
always submit a
+clean run rather than reconnect to a prior submission.
+
+The default is ``True``. ``durable`` is owned by the mixin — operators do not 
need to
+redeclare it. ``default_args`` injection and ``.partial()`` work automatically.
+
 .. _sdk-resumable-job-mixin-external-id-key:
 
 External ID key
diff --git a/task-sdk/src/airflow/sdk/bases/resumablejobmixin.py 
b/task-sdk/src/airflow/sdk/bases/resumablejobmixin.py
index 7066d10cced..27533dbe840 100644
--- a/task-sdk/src/airflow/sdk/bases/resumablejobmixin.py
+++ b/task-sdk/src/airflow/sdk/bases/resumablejobmixin.py
@@ -21,6 +21,7 @@ from typing import TYPE_CHECKING, Any
 from opentelemetry import trace
 
 from airflow.sdk._shared.observability.metrics import stats
+from airflow.sdk.bases.operator import BaseOperatorMeta
 
 if TYPE_CHECKING:
     from pydantic import JsonValue
@@ -90,6 +91,14 @@ class ResumableJobMixin:
     # Renaming this on a deployed operator breaks in-flight retries — the old 
key is already stored.
     external_id_key: str = "remote_job_id"
 
+    # The mixin is not a BaseOperator subclass, but _apply_defaults is only 
ever called on concrete
+    # operators that are BaseOperator subclasses. That is a runtime MRO 
guarantee not visible in the static
+    # type signature here and hence we need the type ignore.
+    @BaseOperatorMeta._apply_defaults  # type: ignore[type-var]
+    def __init__(self, *, durable: bool = True, **kwargs: Any) -> None:
+        super().__init__(**kwargs)
+        self.durable = durable
+
     def execute_resumable(self, context: Context) -> Any:
         """
         Core of the resumable execution logic. Call this from execute() when 
reconnection is supported.
@@ -107,6 +116,11 @@ class ResumableJobMixin:
         Closing this window would require atomic "submit + persist", which is 
not possible across
         an external system boundary.
         """
+        if not self.durable:
+            external_id = self.submit_job(context)
+            self.poll_until_complete(external_id, context)
+            return self.get_job_result(external_id, context)
+
         stats_tags = {"operator": type(self).__name__}
         # The task is team-scoped in multi-team deployments; surface team_name 
on the
         # resumable_job metrics via the running task instance's stats tags 
(omitted when
@@ -114,6 +128,7 @@ class ResumableJobMixin:
         ti = context.get("ti")
         if ti is not None and (team_name := ti.stats_tags.get("team_name")):
             stats_tags["team_name"] = team_name
+
         reconnect_to: Any = None
         already_succeeded_id: Any = None
 
diff --git a/task-sdk/tests/task_sdk/bases/test_resumablejobmixin.py 
b/task-sdk/tests/task_sdk/bases/test_resumablejobmixin.py
index e186cc19a36..8e962583e99 100644
--- a/task-sdk/tests/task_sdk/bases/test_resumablejobmixin.py
+++ b/task-sdk/tests/task_sdk/bases/test_resumablejobmixin.py
@@ -206,6 +206,35 @@ class TestNoneExternalId:
         assert task_state._store == {}
 
 
+class TestResumeOnRetryDisabled:
+    def test_submits_and_polls_without_task_store_interaction(self):
+        op = ConcreteResumableOperator(task_id="test_task", durable=False)
+        task_store = FakeTaskState()
+        op.execute_resumable(make_context(task_store))
+
+        assert op.submitted_ids == ["job-001"]
+        assert op.polled_ids == ["job-001"]
+        assert task_store._store == {}, "task_store must not be written when 
durable=False"
+
+    def test_does_not_reconnect_when_prior_id_exists(self):
+        op = ConcreteResumableOperator(task_id="test_task", durable=False)
+        op._status_map["job-001"] = "RUNNING"
+        task_store = FakeTaskState({"test_job_id": "job-001"})
+
+        op.execute_resumable(make_context(task_store))
+
+        assert op.submitted_ids == ["job-001"], "should submit fresh even with 
a prior ID stored"
+
+    def test_returns_result(self):
+        op = ConcreteResumableOperator(task_id="test_task", durable=False)
+        result = op.execute_resumable(make_context(FakeTaskState()))
+        assert result == "result-of-job-001"
+
+    def test_default_is_true(self):
+        op = ConcreteResumableOperator(task_id="test_task")
+        assert op.durable is True
+
+
 class TestExternalIdKey:
     def test_custom_key_used_for_storage_and_retrieval(self):
         class CustomKeyOp(ConcreteResumableOperator):

Reply via email to