This is an automated email from the ASF dual-hosted git repository.

amoghrajesh pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new fd9241cdf0b Add `ResumableJobMixin` with `SparkSubmitOperator` as a 
case study for surviving worker failures (standalone) (#67118)
fd9241cdf0b is described below

commit fd9241cdf0bb64d5b3c4619be83619db62671824
Author: Amogh Desai <[email protected]>
AuthorDate: Thu May 28 10:05:26 2026 +0530

    Add `ResumableJobMixin` with `SparkSubmitOperator` as a case study for 
surviving worker failures (standalone) (#67118)
---
 providers/apache/spark/docs/index.rst              |   2 +
 providers/apache/spark/docs/operators.rst          |  21 ++
 providers/apache/spark/provider.yaml               |  16 ++
 providers/apache/spark/pyproject.toml              |   2 +
 .../providers/apache/spark/get_provider_info.py    |  10 +
 .../providers/apache/spark/hooks/spark_submit.py   |  52 ++--
 .../apache/spark/operators/spark_submit.py         | 147 +++++++++++-
 .../unit/apache/spark/hooks/test_spark_submit.py   |  30 +++
 .../apache/spark/operators/test_spark_submit.py    | 262 ++++++++++++++++++++-
 scripts/ci/prek/known_airflow_exceptions.txt       |   2 +-
 task-sdk/docs/api.rst                              |   2 +
 task-sdk/src/airflow/sdk/__init__.py               |   3 +
 task-sdk/src/airflow/sdk/bases/resumablemixin.py   | 167 +++++++++++++
 .../tests/task_sdk/bases/test_resumablemixin.py    | 177 ++++++++++++++
 uv.lock                                            |   4 +
 15 files changed, 870 insertions(+), 27 deletions(-)

diff --git a/providers/apache/spark/docs/index.rst 
b/providers/apache/spark/docs/index.rst
index eb3cadfb81b..5138bf8952d 100644
--- a/providers/apache/spark/docs/index.rst
+++ b/providers/apache/spark/docs/index.rst
@@ -104,6 +104,8 @@ PIP package                                 Version required
 ``apache-airflow-providers-common-compat``  ``>=1.12.0``
 ``pyspark-client``                          ``>=4.0.0``
 ``grpcio-status``                           ``>=1.67.0``
+``requests``                                ``>=2.32.0``
+``tenacity``                                ``>=8.3.0``
 ==========================================  ==================
 
 Cross provider package dependencies
diff --git a/providers/apache/spark/docs/operators.rst 
b/providers/apache/spark/docs/operators.rst
index 125039ebdf3..0a645542323 100644
--- a/providers/apache/spark/docs/operators.rst
+++ b/providers/apache/spark/docs/operators.rst
@@ -181,3 +181,24 @@ Reference
 """""""""
 
 For further information, look at `Apache Spark submitting applications 
<https://spark.apache.org/docs/latest/submitting-applications.html>`_.
+
+Cluster mode crash recovery (Spark standalone)
+"""""""""""""""""""""""""""""""""""""""""""""""
+
+When running in Spark standalone cluster mode (``--deploy-mode cluster``), the Spark driver runs
+independently on the cluster. If the Airflow worker dies while the Spark job is running, the driver keeps running but
+Airflow loses track of it. Submitting a brand-new job on retry would then waste the compute already done, or even
+cause conflicts if the Spark job itself is not designed to be idempotent.
+
+The ``SparkSubmitOperator`` solves this by persisting the driver ID to ``task_state`` immediately after
+submission. On retry, it reads the ID back and reconnects to the 
already-running driver instead of
+resubmitting.
+
+This is the **synchronous path**: the worker holds a slot for the whole duration of polling. It is intended as
+a crash-safety net for teams that run synchronous operators for log observability or organizational constraints,
+or because no Triggerer is available. Teams with a Triggerer available may also consider
+deferrable operators, which free the worker slot but may add complexity.
+
+.. note::
+    Crash recovery in cluster mode requires Airflow 3.3+ (``task_state`` support). On earlier
+    versions the operator falls back to the previous behavior of always submitting a fresh job.
diff --git a/providers/apache/spark/provider.yaml 
b/providers/apache/spark/provider.yaml
index 2fd094e6d75..bb91a8e3127 100644
--- a/providers/apache/spark/provider.yaml
+++ b/providers/apache/spark/provider.yaml
@@ -210,6 +210,22 @@ connection-types:
             - string
             - 'null'
           format: password
+      rest-scheme:
+        label: REST scheme
+        description: Scheme for the Spark standalone REST API (http or https). 
Default is http.
+        schema:
+          type:
+            - string
+            - 'null'
+          default: http
+      rest-port:
+        label: REST port
+        description: Port for the Spark standalone REST API 
(spark.master.rest.port). Default is 6066.
+        schema:
+          type:
+            - string
+            - 'null'
+          default: '6066'
 
 task-decorators:
   - class-name: airflow.providers.apache.spark.decorators.pyspark.pyspark_task
diff --git a/providers/apache/spark/pyproject.toml 
b/providers/apache/spark/pyproject.toml
index 216c5c003da..e7f4fa480d3 100644
--- a/providers/apache/spark/pyproject.toml
+++ b/providers/apache/spark/pyproject.toml
@@ -63,6 +63,8 @@ dependencies = [
     "apache-airflow-providers-common-compat>=1.12.0",
     "pyspark-client>=4.0.0",
     "grpcio-status>=1.67.0",
+    "requests>=2.32.0",
+    "tenacity>=8.3.0",
 ]
 
 # The optional dependencies should be modified in place in the generated file
diff --git 
a/providers/apache/spark/src/airflow/providers/apache/spark/get_provider_info.py
 
b/providers/apache/spark/src/airflow/providers/apache/spark/get_provider_info.py
index b9871156257..ef09d0a6ae9 100644
--- 
a/providers/apache/spark/src/airflow/providers/apache/spark/get_provider_info.py
+++ 
b/providers/apache/spark/src/airflow/providers/apache/spark/get_provider_info.py
@@ -126,6 +126,16 @@ def get_provider_info():
                         "description": "Run the command `base64 
<your-keytab-path>` and use its output.",
                         "schema": {"type": ["string", "null"], "format": 
"password"},
                     },
+                    "rest-scheme": {
+                        "label": "REST scheme",
+                        "description": "Scheme for the Spark standalone REST 
API (http or https). Default is http.",
+                        "schema": {"type": ["string", "null"], "default": 
"http"},
+                    },
+                    "rest-port": {
+                        "label": "REST port",
+                        "description": "Port for the Spark standalone REST API 
(spark.master.rest.port). Default is 6066.",
+                        "schema": {"type": ["string", "null"], "default": 
"6066"},
+                    },
                 },
             },
         ],
diff --git 
a/providers/apache/spark/src/airflow/providers/apache/spark/hooks/spark_submit.py
 
b/providers/apache/spark/src/airflow/providers/apache/spark/hooks/spark_submit.py
index 62d18aac049..9aa3ddc885e 100644
--- 
a/providers/apache/spark/src/airflow/providers/apache/spark/hooks/spark_submit.py
+++ 
b/providers/apache/spark/src/airflow/providers/apache/spark/hooks/spark_submit.py
@@ -160,6 +160,18 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
                 description="Run the command `base64 <your-keytab-path>` and 
use its output.",
                 validators=[Optional()],
             ),
+            "rest-scheme": StringField(
+                lazy_gettext("REST scheme"),
+                widget=BS3TextFieldWidget(),
+                description="Scheme for the Spark standalone REST API (http or 
https). Default: http.",
+                validators=[Optional()],
+            ),
+            "rest-port": StringField(
+                lazy_gettext("REST port"),
+                widget=BS3TextFieldWidget(),
+                description="Port for the Spark standalone REST API 
(spark.master.rest.port). Default: 6066.",
+                validators=[Optional()],
+            ),
         }
 
     def __init__(
@@ -258,7 +270,7 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
 
     def _resolve_connection(self) -> dict[str, Any]:
         # Build from connection master or default to yarn if not available
-        conn_data = {
+        conn_data: dict[str, Any] = {
             "master": "yarn",
             "queue": None,  # yarn queue
             "deploy_mode": None,
@@ -266,6 +278,9 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
             "namespace": None,
             "principal": self._principal,
             "keytab": self._keytab,
+            # fallback if connection lookup fails; overridden by 
rest-scheme/rest-port extras below
+            "rest_scheme": "http",
+            "rest_port": 6066,
         }
 
         try:
@@ -308,6 +323,8 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
                 )
             conn_data["spark_binary"] = self.spark_binary
             conn_data["namespace"] = extra.get("namespace")
+            conn_data["rest_scheme"] = extra.get("rest-scheme") or "http"
+            conn_data["rest_port"] = int(extra.get("rest-port") or 6066)
             if conn_data["principal"] is None:
                 conn_data["principal"] = extra.get("principal")
             if conn_data["keytab"] is None:
@@ -587,7 +604,7 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
             except Exception as exc:
                 self.log.warning("Post-submit command raised an exception: %s. 
Error: %s", cmd, exc)
 
-    def submit(self, application: str = "", **kwargs: Any) -> None:
+    def submit(self, application: str = "", **kwargs: Any) -> str | None:
         """
         Remote Popen to execute the spark-submit job.
 
@@ -626,27 +643,24 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
                     f"Cannot execute: {self._mask_cmd(spark_submit_cmd)}. Error code is: {returncode}."
                 )
 
-            self.log.debug("Should track driver: %s", self._should_track_driver_status)
-
-            # We want the Airflow job to wait until the Spark driver is finished
-            if self._should_track_driver_status:
-                if self._driver_id is None:
-                    raise AirflowException(
-                        "No driver id is known: something went wrong when executing the spark submit command"
-                    )
-
-                # We start with the SUBMITTED status as initial status
-                self._driver_status = "SUBMITTED"
-
-                # Start tracking the driver status (blocking function)
-                self._start_driver_status_tracking()
-
-                if self._driver_status != "FINISHED":
-                    raise AirflowException(
-                        f"ERROR : Driver {self._driver_id} badly exited with status {self._driver_status}"
-                    )
+            if self._should_track_driver_status and self._driver_id is None:
+                raise AirflowException(
+                    "No driver id is known: something went wrong when executing the spark submit command"
+                )
+        except BaseException:
+            # Submission failed, so the operator never reaches poll_until_complete. Run
+            # post_submit_commands here so cleanup still happens in driver-tracking mode.
+            if self._should_track_driver_status:
+                self._run_post_submit_commands()
+            raise
         finally:
-            self._run_post_submit_commands()
+            # In cluster mode with driver tracking, the operator calls poll_until_complete
+            # after submit() returns successfully, so post_submit_commands are deferred there
+            # to preserve the "runs after job finishes" contract. In all other modes, run them here.
+            if not self._should_track_driver_status:
+                self._run_post_submit_commands()
+
+        return self._driver_id
 
     def _process_spark_submit_log(self, itr: Iterator[Any]) -> None:
         """
diff --git 
a/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py
 
b/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py
index 0e67fa5b50d..76b010107da 100644
--- 
a/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py
+++ 
b/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py
@@ -18,7 +18,10 @@
 from __future__ import annotations
 
 from collections.abc import Sequence
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any, cast
+
+import requests
+from tenacity import retry, stop_after_attempt, wait_fixed
 
 from airflow.providers.apache.spark.hooks.spark_submit import SparkSubmitHook
 from airflow.providers.common.compat.openlineage.utils.spark import (
@@ -27,11 +30,30 @@ from 
airflow.providers.common.compat.openlineage.utils.spark import (
 )
 from airflow.providers.common.compat.sdk import BaseOperator, conf
 
+try:
+    from airflow.sdk.bases.resumablemixin import ResumableJobMixin
+except ImportError:
+    # Airflow 2 compat.
+    # ResumableJobMixin does not exist in Airflow 2, so we need to add a stub 
to make it
+    # behave as before
+    class ResumableJobMixin:  # type: ignore[no-redef]
+        """Airflow 2 stub — no task_state, always submits fresh."""
+
+        external_id_key: str = "remote_job_id"
+
+        def execute_resumable(self, context):
+            external_id = self.submit_job(context)
+            self.poll_until_complete(external_id, context)
+            return self.get_job_result(external_id, context)
+
+
 if TYPE_CHECKING:
+    from pydantic import JsonValue
+
     from airflow.providers.common.compat.sdk import Context
 
 
-class SparkSubmitOperator(BaseOperator):
+class SparkSubmitOperator(ResumableJobMixin, BaseOperator):
     """
     Wrap the spark-submit binary to kick off a spark-submit job; requires 
"spark-submit" binary in the PATH.
 
@@ -88,6 +110,10 @@ class SparkSubmitOperator(BaseOperator):
         Useful for cleaning up sidecars such as Istio. Failures produce a 
warning but do not fail the task.
     """
 
+    # Generic key used across all Spark deployment modes (standalone driver ID,
+    # YARN application ID, K8s driver pod name).
+    external_id_key = "spark_job_id"
+
     template_fields: Sequence[str] = (
         "application",
         "conf",
@@ -141,6 +167,7 @@ class SparkSubmitOperator(BaseOperator):
         deploy_mode: str | None = None,
         use_krb5ccache: bool = False,
         post_submit_commands: list[str] | None = None,
+        reconnect_on_retry: bool = True,
         openlineage_inject_parent_job_info: bool = conf.getboolean(
             "openlineage", "spark_inject_parent_job_info", fallback=False
         ),
@@ -184,6 +211,7 @@ class SparkSubmitOperator(BaseOperator):
         self._conn_id = conn_id
         self._use_krb5ccache = use_krb5ccache
 
+        self.reconnect_on_retry = reconnect_on_retry
         self._openlineage_inject_parent_job_info = 
openlineage_inject_parent_job_info
         self._openlineage_inject_transport_info = 
openlineage_inject_transport_info
 
@@ -198,7 +226,120 @@ class SparkSubmitOperator(BaseOperator):
             self.conf = 
inject_transport_information_into_spark_properties(self.conf, context)
         if self._hook is None:
             self._hook = self._get_hook()
-        self._hook.submit(self.application)
+        hook = self._hook
+        if hook._should_track_driver_status:
+            if self.reconnect_on_retry:
+                return self.execute_resumable(context)
+            # reconnect_on_retry=False: still submit-and-poll, just skip 
task_state persistence.
+            driver_id = self.submit_job(context)
+            self.poll_until_complete(driver_id, context)
+            return self.get_job_result(driver_id, context)
+        hook.submit(self.application)
+
+    def submit_job(self, context: Context) -> str:
+        if self._hook is None:
+            self._hook = self._get_hook()
+        driver_id = self._hook.submit(self.application)
+        if not driver_id:
+            raise RuntimeError("spark-submit did not return a driver ID")
+        self.log.info("Spark driver submitted: %s", driver_id)
+        return driver_id
+
+    def get_job_status(self, external_id: JsonValue) -> str:
+        # external_id is the driver ID from submit_job or task_state (Spark driver IDs are strings)
+        external_id = cast("str", external_id)
+        if self._hook is None:
+            self._hook = self._get_hook()
+        # The YARN and K8s branches below (and in is_job_active, 
is_job_succeeded, poll_until_complete)
+        # are currently unreachable: execute_resumable is only called when 
_should_track_driver_status
+        # is True, which requires spark:// + cluster mode. They are 
scaffolding for a follow-up PR
+        # that extends ResumableJobMixin support to YARN and Kubernetes.
+        if self._hook._is_yarn:
+            # TODO: call YARN ResourceManager REST API
+            # GET http://rm:8088/ws/v1/cluster/apps/{external_id}
+            raise NotImplementedError("YARN job status not yet implemented")
+        if self._hook._is_kubernetes:
+            # TODO: call K8s pod status API
+            raise NotImplementedError("K8s job status not yet implemented")
+        scheme = self._hook._connection.get("rest_scheme", "http")
+        rest_port = self._hook._connection.get("rest_port", 6066)
+        # HA master URLs can look like spark://m1:7077,m2:7077 — try each host 
in order.
+        # The master URL port (e.g. 7077) is the RPC port — not the REST API 
port.
+        # Use rest-port connection extra to override spark.master.rest.port 
(default 6066).
+        master_urls = self._hook._connection["master"].replace("spark://", 
"").split(",")
+        last_exc: Exception = RuntimeError("No Spark masters to query")
+        for m in master_urls:
+            host = m.strip().split(":")[0]
+            url = 
f"{scheme}://{host}:{rest_port}/v1/submissions/status/{external_id}"
+            try:
+                status = self._fetch_driver_status(url, external_id)
+                return status
+            except Exception as e:
+                self.log.warning("Could not reach Spark master %s: %s", host, 
e)
+                last_exc = e
+        raise last_exc
+
+    @retry(stop=stop_after_attempt(3), wait=wait_fixed(1), reraise=True)
+    def _fetch_driver_status(self, url: str, external_id: str) -> str:
+        response = requests.get(url, timeout=30)
+        response.raise_for_status()
+        # "success:false" means the master does not recognise the driver ID or 
is in recovery.
+        # 
https://github.com/apache/spark/blob/master/core/src/main/scala/org/apache/spark/deploy/master/DriverState.scala
+        data = response.json()
+        if not data.get("success"):
+            raise RuntimeError(
+                f"Spark REST API returned failure for {external_id}: 
{data.get('message', 'unknown error')}"
+            )
+        status = data["driverState"]
+        self.log.info("Driver %s status: %s", external_id, status)
+        return status
+
+    def is_job_active(self, status: str) -> bool:
+        if self._hook is None:
+            self._hook = self._get_hook()
+        status = status.upper()
+        if self._hook._is_yarn:
+            # 
https://hadoop.apache.org/docs/stable/hadoop-yarn/hadoop-yarn-site/ResourceManagerRest.html
+            return status in ("NEW", "NEW_SAVING", "SUBMITTED", "ACCEPTED", 
"RUNNING")
+        if self._hook._is_kubernetes:
+            return status in ("PENDING", "RUNNING")
+        # RELAUNCHING: driver is being restarted after a failure, still alive.
+        # UNKNOWN: master is in failure recovery, state is temporarily 
unavailable.
+        # 
https://github.com/apache/spark/blob/master/core/src/main/scala/org/apache/spark/deploy/master/DriverState.scala
+        return status in ("SUBMITTED", "RUNNING", "RELAUNCHING", "UNKNOWN")
+
+    def is_job_succeeded(self, status: str) -> bool:
+        if self._hook is None:
+            self._hook = self._get_hook()
+        status = status.upper()
+        if self._hook._is_kubernetes:
+            return status == "SUCCEEDED"
+        # standalone and YARN both use FINISHED
+        return status == "FINISHED"
+
+    def poll_until_complete(self, external_id: JsonValue, context: Context) -> 
None:
+        # external_id is the driver ID from submit_job or task_state (Spark driver IDs are strings)
+        external_id = cast("str", external_id)
+        if self._hook is None:
+            self._hook = self._get_hook()
+        if self._hook._is_yarn:
+            # TODO: poll YARN ResourceManager until app reaches terminal state
+            raise NotImplementedError("YARN poll not yet implemented")
+        if self._hook._is_kubernetes:
+            # TODO: poll K8s pod phase until terminal
+            raise NotImplementedError("K8s poll not yet implemented")
+        self.log.info("Polling driver %s until completion", external_id)
+        self._hook._driver_id = external_id
+        try:
+            self._hook._start_driver_status_tracking()
+            if self._hook._driver_status != "FINISHED":
+                raise RuntimeError(f"Driver {external_id} exited with status 
{self._hook._driver_status}")
+        finally:
+            # post-submit commands must fire whether the job succeeded or 
failed.
+            self._hook._run_post_submit_commands()
+
+    def get_job_result(self, external_id: JsonValue, context: Context) -> None:
+        return None
 
     def on_kill(self) -> None:
         if self._hook is None:
diff --git 
a/providers/apache/spark/tests/unit/apache/spark/hooks/test_spark_submit.py 
b/providers/apache/spark/tests/unit/apache/spark/hooks/test_spark_submit.py
index 1b2feaa33e9..c909e9f12ab 100644
--- a/providers/apache/spark/tests/unit/apache/spark/hooks/test_spark_submit.py
+++ b/providers/apache/spark/tests/unit/apache/spark/hooks/test_spark_submit.py
@@ -397,6 +397,8 @@ class TestSparkSubmitHook:
             "namespace": None,
             "principal": None,
             "keytab": None,
+            "rest_scheme": "http",
+            "rest_port": 6066,
         }
         assert connection == expected_spark_connection
         assert dict_cmd["--master"] == "yarn"
@@ -420,6 +422,8 @@ class TestSparkSubmitHook:
             "namespace": None,
             "principal": None,
             "keytab": None,
+            "rest_scheme": "http",
+            "rest_port": 6066,
         }
         assert connection == expected_spark_connection
         assert dict_cmd["--master"] == "yarn"
@@ -443,6 +447,8 @@ class TestSparkSubmitHook:
             "namespace": None,
             "principal": None,
             "keytab": None,
+            "rest_scheme": "http",
+            "rest_port": 6066,
         }
         assert connection == expected_spark_connection
         assert dict_cmd["--master"] == "mesos://host:5050"
@@ -465,6 +471,8 @@ class TestSparkSubmitHook:
             "namespace": None,
             "principal": None,
             "keytab": None,
+            "rest_scheme": "http",
+            "rest_port": 6066,
         }
         assert connection == expected_spark_connection
         assert dict_cmd["--master"] == "yarn://yarn-master"
@@ -489,6 +497,8 @@ class TestSparkSubmitHook:
             "namespace": "mynamespace",
             "principal": None,
             "keytab": None,
+            "rest_scheme": "http",
+            "rest_port": 6066,
         }
         assert connection == expected_spark_connection
         assert dict_cmd["--master"] == "k8s://https://k8s-master";
@@ -515,6 +525,8 @@ class TestSparkSubmitHook:
             "namespace": "airflow",
             "principal": None,
             "keytab": None,
+            "rest_scheme": "http",
+            "rest_port": 6066,
         }
         assert connection == expected_spark_connection
         assert dict_cmd["--master"] == "k8s://https://k8s-master";
@@ -538,6 +550,8 @@ class TestSparkSubmitHook:
             "namespace": None,
             "principal": None,
             "keytab": None,
+            "rest_scheme": "http",
+            "rest_port": 6066,
         }
         assert connection == expected_spark_connection
         assert cmd[0] == "spark2-submit"
@@ -559,6 +573,8 @@ class TestSparkSubmitHook:
             "namespace": None,
             "principal": None,
             "keytab": None,
+            "rest_scheme": "http",
+            "rest_port": 6066,
         }
         assert connection == expected_spark_connection
         assert cmd[0] == "spark3-submit"
@@ -619,6 +635,8 @@ class TestSparkSubmitHook:
             "namespace": None,
             "principal": None,
             "keytab": None,
+            "rest_scheme": "http",
+            "rest_port": 6066,
         }
         assert connection == expected_spark_connection
         assert cmd[0] == "spark3-submit"
@@ -641,6 +659,8 @@ class TestSparkSubmitHook:
             "namespace": None,
             "principal": None,
             "keytab": None,
+            "rest_scheme": "http",
+            "rest_port": 6066,
         }
         assert connection == expected_spark_connection
         assert cmd[0] == "spark-submit"
@@ -662,6 +682,8 @@ class TestSparkSubmitHook:
             "namespace": None,
             "principal": None,
             "keytab": None,
+            "rest_scheme": "http",
+            "rest_port": 6066,
         }
         assert connection == expected_spark_connection
         assert cmd[0] == "spark-submit"
@@ -684,6 +706,8 @@ class TestSparkSubmitHook:
             "namespace": None,
             "principal": "user/[email protected]",
             "keytab": None,
+            "rest_scheme": "http",
+            "rest_port": 6066,
         }
         assert connection == expected_spark_connection
         assert dict_cmd["--principal"] == "user/[email protected]"
@@ -706,6 +730,8 @@ class TestSparkSubmitHook:
             "namespace": None,
             "principal": "will-override",
             "keytab": None,
+            "rest_scheme": "http",
+            "rest_port": 6066,
         }
         assert connection == expected_spark_connection
         assert dict_cmd["--principal"] == "will-override"
@@ -732,6 +758,8 @@ class TestSparkSubmitHook:
             "namespace": None,
             "principal": None,
             "keytab": "privileged_user.keytab",
+            "rest_scheme": "http",
+            "rest_port": 6066,
         }
         assert connection == expected_spark_connection
         assert dict_cmd["--keytab"] == "privileged_user.keytab"
@@ -757,6 +785,8 @@ class TestSparkSubmitHook:
             "namespace": None,
             "principal": None,
             "keytab": "will-override",
+            "rest_scheme": "http",
+            "rest_port": 6066,
         }
         assert connection == expected_spark_connection
         assert dict_cmd["--keytab"] == "will-override"
diff --git 
a/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py 
b/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py
index 6a85ce1b920..65af1116861 100644
--- 
a/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py
+++ 
b/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py
@@ -32,7 +32,7 @@ from airflow.utils.types import DagRunType
 
 from tests_common.test_utils.dag import sync_dag_to_db
 from tests_common.test_utils.taskinstance import create_task_instance, 
render_template_fields
-from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS
+from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS, 
AIRFLOW_V_3_3_PLUS
 
 DEFAULT_DATE = timezone.datetime(2017, 1, 1)
 
@@ -321,6 +321,7 @@ class TestSparkSubmitOperator:
             openlineage_inject_transport_info=True,
             **self._config,
         )
+        mock_get_hook.return_value._should_track_driver_status = False
         operator.execute(MagicMock())
 
         assert operator.conf == {
@@ -387,6 +388,7 @@ class TestSparkSubmitOperator:
             openlineage_inject_transport_info=True,
             **self._config,
         )
+        mock_get_hook.return_value._should_track_driver_status = False
         operator.execute({"ti": mock_ti})
 
         assert operator.conf == {
@@ -425,6 +427,7 @@ class TestSparkSubmitOperator:
             CompositeConfig.from_dict({"transports": {"test1": {"type": 
"console"}}})
         )
 
+        mock_get_hook.return_value._should_track_driver_status = False
         with caplog.at_level(logging.INFO):
             operator = SparkSubmitOperator(
                 task_id="spark_submit_job",
@@ -456,6 +459,7 @@ class TestSparkSubmitOperator:
             config=ConsoleConfig()
         )
 
+        mock_get_hook.return_value._should_track_driver_status = False
         with caplog.at_level(logging.INFO):
             operator = SparkSubmitOperator(
                 task_id="spark_submit_job",
@@ -474,3 +478,259 @@ class TestSparkSubmitOperator:
         assert operator.conf == {
             "parquet.compression": "SNAPPY",
         }
+
+
+class FakeTaskState:
+    """In-memory task state for tests."""
+
+    def __init__(self, stored: dict[str, str] | None = None):
+        self._store: dict[str, str] = dict(stored or {})
+
+    def get(self, key: str) -> str | None:
+        return self._store.get(key)
+
+    def set(self, key: str, value: str) -> None:
+        self._store[key] = value
+
+
[email protected](
+    not AIRFLOW_V_3_3_PLUS,
+    reason="ResumableJobMixin reconnect requires task_state, available in 
Airflow 3.3+",
+)
+class TestSparkSubmitOperatorResumable:
+    def setup_method(self):
+        args = {"owner": "airflow", "start_date": DEFAULT_DATE}
+        self.dag = DAG("test_resumable_dag", schedule=None, default_args=args)
+
+    def _make_operator(self, **kwargs):
+        return SparkSubmitOperator(task_id="test", dag=self.dag, 
application="test.jar", **kwargs)
+
+    def _make_hook(self, should_track=False, is_yarn=False, 
is_kubernetes=False):
+        hook = MagicMock()
+        hook._should_track_driver_status = should_track
+        hook._is_yarn = is_yarn
+        hook._is_kubernetes = is_kubernetes
+        hook._connection = {"master": "spark://localhost:7077"}
+        return hook
+
+    def test_non_cluster_mode_calls_hook_submit_directly(self):
+        operator = self._make_operator()
+        operator._hook = self._make_hook(should_track=False)
+
+        operator.execute(context={})
+
+        operator._hook.submit.assert_called_once_with("test.jar")
+
+    def test_cluster_mode_first_run_persists_id_before_polling(self):
+        operator = self._make_operator()
+        operator._hook = self._make_hook(should_track=True)
+        operator._hook.submit.return_value = "driver-001"
+
+        task_state = FakeTaskState()
+        persisted_before_poll = []
+
+        def track_poll(external_id, context):
+            persisted_before_poll.append(task_state.get("spark_job_id"))
+
+        operator.poll_until_complete = track_poll
+
+        operator.execute(context={"task_state": task_state})
+
+        operator._hook.submit.assert_called_once_with("test.jar")
+        assert persisted_before_poll == ["driver-001"]
+
+    @pytest.mark.parametrize(
+        ("prior_status", "expect_submit", "expect_poll_id"),
+        [
+            ("RUNNING", False, "driver-001"),
+            ("SUBMITTED", False, "driver-001"),
+            ("FINISHED", False, None),
+            ("FAILED", True, "driver-new"),
+            ("KILLED", True, "driver-new"),
+        ],
+    )
+    def test_retry_behaviour_based_on_prior_driver_status(self, prior_status, 
expect_submit, expect_poll_id):
+        operator = self._make_operator()
+        operator._hook = self._make_hook(should_track=True)
+        operator._hook.submit.return_value = "driver-new"
+        task_state = FakeTaskState({"spark_job_id": "driver-001"})
+
+        operator.get_job_status = lambda external_id: prior_status
+        polled = []
+        operator.poll_until_complete = lambda external_id, context: 
polled.append(external_id)
+
+        operator.execute(context={"task_state": task_state})
+
+        if expect_submit:
+            operator._hook.submit.assert_called_once_with("test.jar")
+        else:
+            operator._hook.submit.assert_not_called()
+
+        if expect_poll_id:
+            assert polled == [expect_poll_id]
+        else:
+            assert polled == []
+
+    def test_submits_fresh_when_task_state_unavailable(self):
+        operator = self._make_operator()
+        operator._hook = self._make_hook(should_track=True)
+        operator._hook.submit.return_value = "driver-001"
+        polled = []
+        operator.poll_until_complete = lambda external_id, context: 
polled.append(external_id)
+
+        # no task_state key in context
+        operator.execute(context={})
+
+        operator._hook.submit.assert_called_once_with("test.jar")
+        assert polled == ["driver-001"]
+
+    def test_reconnect_on_retry_false_submits_fresh_and_polls(self):
+        operator = self._make_operator(reconnect_on_retry=False)
+        operator._hook = self._make_hook(should_track=True)
+        operator._hook.submit.return_value = "driver-new"
+        task_state = FakeTaskState({"spark_job_id": "driver-old"})
+        polled = []
+        operator.poll_until_complete = lambda external_id, context: 
polled.append(external_id)
+
+        operator.execute(context={"task_state": task_state})
+        # reconnect_on_retry=False: ignores prior driver ID, submits fresh, 
but still polls
+        operator._hook.submit.assert_called_once_with("test.jar")
+        assert polled == ["driver-new"]
+
+    @pytest.mark.parametrize(
+        ("is_yarn", "is_kubernetes", "status", "expected_active", 
"expected_succeeded"),
+        [
+            (False, False, "RUNNING", True, False),
+            (False, False, "SUBMITTED", True, False),
+            (False, False, "FINISHED", False, True),
+            (False, False, "FAILED", False, False),
+            (True, False, "RUNNING", True, False),
+            (True, False, "ACCEPTED", True, False),
+            (True, False, "NEW", True, False),
+            (True, False, "FINISHED", False, True),
+            (True, False, "FAILED", False, False),
+            (False, True, "Running", True, False),
+            (False, True, "Pending", True, False),
+            (False, True, "Succeeded", False, True),
+            (False, True, "Failed", False, False),
+        ],
+    )
+    def test_job_status_mappings(self, is_yarn, is_kubernetes, status, 
expected_active, expected_succeeded):
+        operator = self._make_operator()
+        operator._hook = self._make_hook(is_yarn=is_yarn, 
is_kubernetes=is_kubernetes)
+
+        assert operator.is_job_active(status) == expected_active
+        assert operator.is_job_succeeded(status) == expected_succeeded
+
+    @pytest.mark.parametrize(
+        ("response_json", "expected_status", "expected_error"),
+        [
+            ({"success": True, "driverState": "RUNNING"}, "RUNNING", None),
+            ({"success": False, "message": "driver not found"}, None, "driver 
not found"),
+            ({"driverState": "RUNNING"}, None, "unknown error"),
+        ],
+    )
+    def test_get_job_status(self, response_json, expected_status, 
expected_error):
+        operator = self._make_operator()
+        operator._hook = self._make_hook(should_track=True)
+        mock_response = MagicMock()
+        mock_response.json.return_value = response_json
+
+        with mock.patch("requests.get", return_value=mock_response):
+            if expected_error:
+                with pytest.raises(RuntimeError, match=expected_error):
+                    operator.get_job_status("driver-001")
+            else:
+                assert operator.get_job_status("driver-001") == expected_status
+
+    def test_get_job_status_ha_tries_next_master(self):
+        operator = self._make_operator()
+        hook = self._make_hook(should_track=True)
+        # Master URL port (7077) is RPC — REST API must use 6066, not 7077
+        hook._connection = {"master": "spark://m1:7077,m2:7077"}
+        operator._hook = hook
+
+        good_response = MagicMock()
+        good_response.json.return_value = {"success": True, "driverState": 
"RUNNING"}
+        captured_urls = []
+
+        def side_effect(url, timeout):
+            captured_urls.append(url)
+            if "m1" in url:
+                raise ConnectionError("m1 unreachable")
+            return good_response
+
+        with mock.patch("requests.get", side_effect=side_effect):
+            assert operator.get_job_status("driver-001") == "RUNNING"
+
+        assert all(":6066/" in url for url in captured_urls), "REST API must 
use port 6066, not the RPC port"
+
+    def test_get_job_status_ha_tries_next_master_on_success_false(self):
+        """success:false from m1 (e.g. HA recovery in progress) should fall 
through to m2."""
+        operator = self._make_operator()
+        hook = self._make_hook(should_track=True)
+        hook._connection = {"master": "spark://m1:7077,m2:7077"}
+        operator._hook = hook
+
+        bad_response = MagicMock()
+        bad_response.json.return_value = {"success": False, "message": "Driver 
not found"}
+        good_response = MagicMock()
+        good_response.json.return_value = {"success": True, "driverState": 
"RUNNING"}
+
+        def side_effect(url, timeout):
+            if "m1" in url:
+                return bad_response
+            return good_response
+
+        with mock.patch("requests.get", side_effect=side_effect):
+            assert operator.get_job_status("driver-001") == "RUNNING"
+
+    def test_get_job_status_ha_raises_when_all_masters_unreachable(self):
+        operator = self._make_operator()
+        hook = self._make_hook(should_track=True)
+        hook._connection = {"master": "spark://m1:7077,m2:7077"}
+        operator._hook = hook
+
+        with mock.patch("requests.get", 
side_effect=ConnectionError("unreachable")):
+            with pytest.raises(ConnectionError):
+                operator.get_job_status("driver-001")
+
+    def test_get_job_status_uses_rest_scheme_from_connection(self):
+        operator = self._make_operator()
+        hook = self._make_hook(should_track=True)
+        hook._connection = {"master": "spark://myhost:6066", "rest_scheme": 
"https"}
+        operator._hook = hook
+
+        mock_response = MagicMock()
+        mock_response.json.return_value = {"success": True, "driverState": 
"RUNNING"}
+        captured_urls = []
+
+        def capture(url, timeout):
+            captured_urls.append(url)
+            return mock_response
+
+        with mock.patch("requests.get", side_effect=capture):
+            operator.get_job_status("driver-001")
+
+        assert len(captured_urls) == 1
+        assert captured_urls[0].startswith("https://";)
+
+    def test_poll_until_complete_runs_post_submit_on_failure(self):
+        """post_submit_commands must run even when the driver exits with a 
failure status."""
+        operator = self._make_operator()
+        hook = self._make_hook(should_track=True)
+        hook._connection = {"master": "spark://myhost:7077"}
+        hook._driver_status = "FAILED"
+
+        def simulate_failed_tracking():
+            hook._driver_status = "FAILED"
+
+        hook._start_driver_status_tracking = 
mock.MagicMock(side_effect=simulate_failed_tracking)
+        post_submit_called = []
+        hook._run_post_submit_commands = mock.MagicMock(side_effect=lambda: 
post_submit_called.append(True))
+        operator._hook = hook
+
+        with pytest.raises(RuntimeError, match="FAILED"):
+            operator.poll_until_complete("driver-001", {})
+
+        assert post_submit_called, "_run_post_submit_commands must be called 
even on driver failure"
diff --git a/scripts/ci/prek/known_airflow_exceptions.txt 
b/scripts/ci/prek/known_airflow_exceptions.txt
index f1ddfbd1efc..bd4570fc55f 100644
--- a/scripts/ci/prek/known_airflow_exceptions.txt
+++ b/scripts/ci/prek/known_airflow_exceptions.txt
@@ -145,7 +145,7 @@ 
providers/apache/livy/src/airflow/providers/apache/livy/operators/livy.py::3
 providers/apache/pig/src/airflow/providers/apache/pig/hooks/pig.py::1
 providers/apache/pinot/src/airflow/providers/apache/pinot/hooks/pinot.py::1
 providers/apache/spark/src/airflow/providers/apache/spark/hooks/spark_sql.py::2
-providers/apache/spark/src/airflow/providers/apache/spark/hooks/spark_submit.py::11
+providers/apache/spark/src/airflow/providers/apache/spark/hooks/spark_submit.py::10
 providers/arangodb/src/airflow/providers/arangodb/hooks/arangodb.py::9
 providers/arangodb/src/airflow/providers/arangodb/operators/arangodb.py::1
 providers/atlassian/jira/src/airflow/providers/atlassian/jira/hooks/jira.py::1
diff --git a/task-sdk/docs/api.rst b/task-sdk/docs/api.rst
index 8f3f3de9539..c2bb0a19fbc 100644
--- a/task-sdk/docs/api.rst
+++ b/task-sdk/docs/api.rst
@@ -99,6 +99,8 @@ Bases
 
 .. autoapiclass:: airflow.sdk.SkipMixin
 
+.. autoapiclass:: airflow.sdk.ResumableJobMixin
+
 .. autoapiclass:: airflow.sdk.BaseHook
 
 Callbacks
diff --git a/task-sdk/src/airflow/sdk/__init__.py 
b/task-sdk/src/airflow/sdk/__init__.py
index eeae86f1eb3..ab834658c0f 100644
--- a/task-sdk/src/airflow/sdk/__init__.py
+++ b/task-sdk/src/airflow/sdk/__init__.py
@@ -66,6 +66,7 @@ __all__ = [
     "PartitionMapper",
     "PokeReturnValue",
     "ProductMapper",
+    "ResumableJobMixin",
     "RetryAction",
     "RetryDecision",
     "RetryPolicy",
@@ -117,6 +118,7 @@ if TYPE_CHECKING:
         cross_downstream,
     )
     from airflow.sdk.bases.operatorlink import BaseOperatorLink
+    from airflow.sdk.bases.resumablemixin import ResumableJobMixin
     from airflow.sdk.bases.sensor import BaseSensorOperator, PokeReturnValue
     from airflow.sdk.bases.skipmixin import SkipMixin
     from airflow.sdk.bases.xcom import BaseXCom
@@ -233,6 +235,7 @@ __lazy_imports: dict[str, str] = {
     "PartitionMapper": ".definitions.partition_mappers.base",
     "PokeReturnValue": ".bases.sensor",
     "ProductMapper": ".definitions.partition_mappers.product",
+    "ResumableJobMixin": ".bases.resumablemixin",
     "RetryAction": ".definitions.retry_policy",
     "RetryDecision": ".definitions.retry_policy",
     "RetryPolicy": ".definitions.retry_policy",
diff --git a/task-sdk/src/airflow/sdk/bases/resumablemixin.py 
b/task-sdk/src/airflow/sdk/bases/resumablemixin.py
new file mode 100644
index 00000000000..4e252d54743
--- /dev/null
+++ b/task-sdk/src/airflow/sdk/bases/resumablemixin.py
@@ -0,0 +1,167 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any
+
+if TYPE_CHECKING:
+    from pydantic import JsonValue
+
+    from airflow.sdk.definitions.context import Context
+    from airflow.sdk.types import Logger
+
+
+class ResumableJobMixin:
+    """
+    Mixin for operators that submit one long-running job to an external system 
and poll for completion.
+
+    **Purpose:** This mixin makes the synchronous operator path crash-safe. It 
is not a replacement
+    for deferrable operators — deferrable remains the recommended approach for 
long-running tasks when
+    a Triggerer is available and the async model fits the team. This mixin is 
for teams already running
+    synchronous operators who want worker crashes to reconnect to the existing 
job rather than
+    resubmitting a duplicate.
+
+    **How it works:** On the first run, after submitting the job, the external 
ID (driver ID, YARN
+    application ID, etc.) is persisted to ``task_state`` before polling 
starts. On retry, the mixin
+    reads that ID back and reconnects to the already-running job instead of 
starting a new one.
+
+    **What it does not do:** It does not free the worker slot during polling 
(use deferrable for that),
+    and it does not stream logs from the remote system (the operator controls 
that separately).
+
+    Usage: call ``execute_resumable(context)`` from the operator's 
``execute()`` when reconnection
+    is supported.
+
+    Subclasses must implement the methods specific to their external system. 
The mixin owns
+    only ``execute_resumable()`` and the task_state read/write logic.
+
+    Example::
+
+        class MyOperator(ResumableJobMixin, BaseOperator):
+            external_id_key = "my_job_id"
+
+            def execute(self, context):
+                return self.execute_resumable(context)
+
+            def submit_job(self, context) -> JsonValue:
+                return self.hook.submit(...)
+
+            def get_job_status(self, external_id: JsonValue) -> str:
+                return self.hook.get_status(external_id)
+
+            def is_job_active(self, status: str) -> bool:
+                return status in ("RUNNING", "PENDING")
+
+            def is_job_succeeded(self, status: str) -> bool:
+                return status == "SUCCEEDED"
+
+            def poll_until_complete(self, external_id: JsonValue, context: 
Context) -> None:
+                self.hook.poll(external_id)
+
+            def get_job_result(self, external_id: JsonValue, context: Context) 
-> Any:
+                return None
+    """
+
+    if TYPE_CHECKING:
+        # log comes from BaseOperator (via LoggingMixin) at runtime, but mypy 
cannot see
+        # that because ResumableJobMixin does not inherit from it directly.
+        log: Logger
+
+    # Key used to store and retrieve the external job ID from task_state 
across retries.
+    # Renaming this on a deployed operator breaks in-flight retries — the old 
key is already stored.
+    external_id_key: str = "remote_job_id"
+
+    def execute_resumable(self, context: Context) -> Any:
+        """
+        Core of the resumable execution logic. Call this from execute() when 
reconnection is supported.
+
+        On initial run: submits the job, persists the external ID to 
task_state, then polls.
+
+        Behaviour on retry:
+        - On retry with active job: skips submission, reconnects to the 
running job.
+        - On retry with succeeded job: skips submission and polling, returns 
result immediately.
+        - On retry with failed job: falls through and resubmits fresh.
+
+        Known limitation: there is a small window between ``submit_job`` 
returning and
+        ``task_state.set`` completing. If the worker dies in that gap, the 
next retry still
+        holds the previous (terminal) ID and will resubmit a fresh job rather 
than reconnecting.
+        Closing this window would require atomic "submit + persist", which is 
not possible across
+        an external system boundary.
+        """
+        task_state = context.get("task_state")
+
+        if task_state is not None:
+            external_id = task_state.get(self.external_id_key)
+            if external_id:
+                status = self.get_job_status(external_id)
+                if self.is_job_active(status):
+                    self.log.info(
+                        "Reconnecting to existing job identified by: %s 
(status: %s)", external_id, status
+                    )
+                    return self.poll_until_complete(external_id, context)
+                if self.is_job_succeeded(status):
+                    self.log.info(
+                        "Job with identifier: %s already completed 
successfully, skipping resubmission",
+                        external_id,
+                    )
+                    return self.get_job_result(external_id, context)
+                self.log.info(
+                    "Prior job with identifier: %s in terminal state %s, 
resubmitting fresh",
+                    external_id,
+                    status,
+                )
+
+        external_id = self.submit_job(context)
+
+        if task_state is not None:
+            task_state.set(self.external_id_key, external_id)
+
+        self.poll_until_complete(external_id, context)
+        return self.get_job_result(external_id, context)
+
+    def submit_job(self, context: Context) -> JsonValue:
+        """Submit the job to the external system. Return its external ID."""
+        raise NotImplementedError
+
+    def get_job_status(self, external_id: JsonValue) -> str:
+        """Query the external system for the current job status."""
+        raise NotImplementedError
+
+    def is_job_active(self, status: str) -> bool:
+        """
+        Return True if the job is still running and can be reconnected to.
+
+        ``status`` is a raw string returned by the external system — not an 
Airflow enum.
+        Its values are backend-specific (e.g. ``"RUNNING"``, ``"Pending"``, 
``"ContainerCreating"``).
+        """
+        raise NotImplementedError
+
+    def is_job_succeeded(self, status: str) -> bool:
+        """
+        Return True if the job completed successfully.
+
+        ``status`` is a raw string returned by the external system — not an 
Airflow enum.
+        Its values are backend-specific (e.g. ``"FINISHED"``, ``"Succeeded"``).
+        """
+        raise NotImplementedError
+
+    def poll_until_complete(self, external_id: JsonValue, context: Context) -> 
None:
+        """Block until the job reaches a terminal state. Raise on failure."""
+        raise NotImplementedError
+
+    def get_job_result(self, external_id: JsonValue, context: Context) -> Any:
+        """Return the job result after completion. Return None if not 
applicable."""
+        raise NotImplementedError
diff --git a/task-sdk/tests/task_sdk/bases/test_resumablemixin.py 
b/task-sdk/tests/task_sdk/bases/test_resumablemixin.py
new file mode 100644
index 00000000000..8e95e132f3a
--- /dev/null
+++ b/task-sdk/tests/task_sdk/bases/test_resumablemixin.py
@@ -0,0 +1,177 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+import pytest
+
+from airflow.sdk.bases.operator import BaseOperator
+from airflow.sdk.bases.resumablemixin import ResumableJobMixin
+
+if TYPE_CHECKING:
+    from pydantic import JsonValue
+
+
+class ConcreteResumableOperator(ResumableJobMixin, BaseOperator):
+    """Minimal concrete implementation for testing the mixin."""
+
+    external_id_key = "test_job_id"
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        self.submitted_ids: list[str] = []
+        self.polled_ids: list[str] = []
+        self._next_id = "job-001"
+        self._status_map: dict[str, str] = {}
+        self._active_statuses = {"RUNNING", "PENDING"}
+        self._succeeded_statuses = {"SUCCEEDED"}
+
+    def submit_job(self, context) -> JsonValue:
+        self.submitted_ids.append(self._next_id)
+        return self._next_id
+
+    def get_job_status(self, external_id: JsonValue) -> str:
+        return self._status_map.get(str(external_id), "UNKNOWN")
+
+    def is_job_active(self, status: str) -> bool:
+        return status in self._active_statuses
+
+    def is_job_succeeded(self, status: str) -> bool:
+        return status in self._succeeded_statuses
+
+    def poll_until_complete(self, external_id: JsonValue, context) -> None:
+        self.polled_ids.append(str(external_id))
+
+    def get_job_result(self, external_id: JsonValue, context) -> str:
+        return f"result-of-{external_id}"
+
+
+class FakeTaskState:
+    def __init__(self, stored: dict[str, str] | None = None):
+        self._store: dict[str, str] = stored or {}
+
+    def get(self, key: str) -> str | None:
+        return self._store.get(key)
+
+    def set(self, key: str, value: str) -> None:
+        self._store[key] = value
+
+
+def make_context(task_state: FakeTaskState | None = None) -> dict:
+    ctx: dict = {}
+    if task_state is not None:
+        ctx["task_state"] = task_state
+    return ctx
+
+
+class TestFirstSubmission:
+    def test_submits_and_polls_when_no_prior_state(self):
+        op = ConcreteResumableOperator(task_id="test_task")
+        task_state = FakeTaskState()
+        ctx = make_context(task_state)
+
+        op.execute_resumable(ctx)
+
+        assert op.submitted_ids == ["job-001"]
+        assert op.polled_ids == ["job-001"]
+
+    def test_persists_external_id_before_polling(self):
+        """The ID must be in task_state before poll_until_complete is 
called."""
+        op = ConcreteResumableOperator(task_id="test_task")
+        task_state = FakeTaskState()
+        persisted_at_poll: list[str | None] = []
+
+        original_set = task_state.set
+
+        def set_and_track(key, value):
+            original_set(key, value)
+
+        def poll_side_effect(external_id, context):
+            persisted_at_poll.append(task_state.get("test_job_id"))
+
+        task_state.set = set_and_track
+        op.poll_until_complete = poll_side_effect
+
+        op.execute_resumable(make_context(task_state))
+
+        assert persisted_at_poll == ["job-001"], "ID must be persisted before 
polling starts"
+
+    def test_returns_job_result(self):
+        op = ConcreteResumableOperator(task_id="test_task")
+        result = op.execute_resumable(make_context(FakeTaskState()))
+
+        assert result == "result-of-job-001"
+
+
+class TestRetryWithDifferentJobStatuses:
+    def test_skips_submission_when_job_active(self):
+        op = ConcreteResumableOperator(task_id="test_task")
+        op._status_map["job-001"] = "RUNNING"
+        task_state = FakeTaskState({"test_job_id": "job-001"})
+        ctx = make_context(task_state)
+
+        op.execute_resumable(ctx)
+
+        assert op.submitted_ids == [], "should not resubmit when job is active"
+        assert op.polled_ids == ["job-001"]
+
+    def test_pending_status_also_skips_submission(self):
+        op = ConcreteResumableOperator(task_id="test_task")
+        op._status_map["job-001"] = "PENDING"
+        task_state = FakeTaskState({"test_job_id": "job-001"})
+
+        op.execute_resumable(make_context(task_state))
+
+        assert op.submitted_ids == []
+        assert op.polled_ids == ["job-001"]
+
+    def test_returns_result_immediately_without_polling(self):
+        op = ConcreteResumableOperator(task_id="test_task")
+        op._status_map["job-001"] = "SUCCEEDED"
+        task_state = FakeTaskState({"test_job_id": "job-001"})
+
+        result = op.execute_resumable(make_context(task_state))
+
+        assert op.submitted_ids == [], "should not resubmit"
+        assert op.polled_ids == [], "should not poll again"
+        assert result == "result-of-job-001"
+
+    @pytest.mark.parametrize("status", ["FAILED", "KILLED", "ERROR", 
"UNKNOWN"])
+    def test_resubmits_when_prior_job_in_terminal_failure(self, status):
+        op = ConcreteResumableOperator(task_id="test_task")
+        op._status_map["job-001"] = status
+        op._next_id = "job-002"
+        task_state = FakeTaskState({"test_job_id": "job-001"})
+
+        op.execute_resumable(make_context(task_state))
+
+        assert op.submitted_ids == ["job-002"], "should resubmit fresh"
+        assert op.polled_ids == ["job-002"]
+
+
+class TestExternalIdKey:
+    def test_custom_key_used_for_storage_and_retrieval(self):
+        class CustomKeyOp(ConcreteResumableOperator):
+            external_id_key = "my_custom_key"
+
+        op = CustomKeyOp(task_id="test_task")
+        task_state = FakeTaskState()
+
+        op.execute_resumable(make_context(task_state))
+
+        assert task_state.get("my_custom_key") == "job-001"
diff --git a/uv.lock b/uv.lock
index 83759330bb2..a02c1fca513 100644
--- a/uv.lock
+++ b/uv.lock
@@ -3813,6 +3813,8 @@ dependencies = [
     { name = "apache-airflow-providers-common-compat" },
     { name = "grpcio-status" },
     { name = "pyspark-client" },
+    { name = "requests" },
+    { name = "tenacity" },
 ]
 
 [package.optional-dependencies]
@@ -3849,6 +3851,8 @@ requires-dist = [
     { name = "grpcio-status", specifier = ">=1.67.0" },
     { name = "pyspark", marker = "extra == 'pyspark'", specifier = ">=4.0.0" },
     { name = "pyspark-client", specifier = ">=4.0.0" },
+    { name = "requests", specifier = ">=2.32.0" },
+    { name = "tenacity", specifier = ">=8.3.0" },
 ]
 provides-extras = ["cncf-kubernetes", "openlineage", "pyspark"]
 

Reply via email to