This is an automated email from the ASF dual-hosted git repository.
amoghrajesh pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new fd9241cdf0b Add `ResumableJobMixin` with `SparkSubmitOperator` as a
case study for surviving worker failures (standalone) (#67118)
fd9241cdf0b is described below
commit fd9241cdf0bb64d5b3c4619be83619db62671824
Author: Amogh Desai <[email protected]>
AuthorDate: Thu May 28 10:05:26 2026 +0530
Add `ResumableJobMixin` with `SparkSubmitOperator` as a case study for
surviving worker failures (standalone) (#67118)
---
providers/apache/spark/docs/index.rst | 2 +
providers/apache/spark/docs/operators.rst | 21 ++
providers/apache/spark/provider.yaml | 16 ++
providers/apache/spark/pyproject.toml | 2 +
.../providers/apache/spark/get_provider_info.py | 10 +
.../providers/apache/spark/hooks/spark_submit.py | 52 ++--
.../apache/spark/operators/spark_submit.py | 147 +++++++++++-
.../unit/apache/spark/hooks/test_spark_submit.py | 30 +++
.../apache/spark/operators/test_spark_submit.py | 262 ++++++++++++++++++++-
scripts/ci/prek/known_airflow_exceptions.txt | 2 +-
task-sdk/docs/api.rst | 2 +
task-sdk/src/airflow/sdk/__init__.py | 3 +
task-sdk/src/airflow/sdk/bases/resumablemixin.py | 167 +++++++++++++
.../tests/task_sdk/bases/test_resumablemixin.py | 177 ++++++++++++++
uv.lock | 4 +
15 files changed, 870 insertions(+), 27 deletions(-)
diff --git a/providers/apache/spark/docs/index.rst
b/providers/apache/spark/docs/index.rst
index eb3cadfb81b..5138bf8952d 100644
--- a/providers/apache/spark/docs/index.rst
+++ b/providers/apache/spark/docs/index.rst
@@ -104,6 +104,8 @@ PIP package Version required
``apache-airflow-providers-common-compat`` ``>=1.12.0``
``pyspark-client`` ``>=4.0.0``
``grpcio-status`` ``>=1.67.0``
+``requests`` ``>=2.32.0``
+``tenacity`` ``>=8.3.0``
========================================== ==================
Cross provider package dependencies
diff --git a/providers/apache/spark/docs/operators.rst
b/providers/apache/spark/docs/operators.rst
index 125039ebdf3..0a645542323 100644
--- a/providers/apache/spark/docs/operators.rst
+++ b/providers/apache/spark/docs/operators.rst
@@ -181,3 +181,24 @@ Reference
"""""""""
For further information, look at `Apache Spark submitting applications
<https://spark.apache.org/docs/latest/submitting-applications.html>`_.
+
+Cluster mode crash recovery (Spark standalone)
+"""""""""""""""""""""""""""""""""""""""""""""""
+
+When running in Spark standalone cluster mode (``--deploy-mode cluster``), the
Spark driver runs
+independently on the cluster. If the Airflow worker dies while the Spark job is running, the driver keeps running but
+Airflow loses track of it. Submitting a brand new job on retry would waste the compute already done, and could even
+cause conflicts if the Spark job itself is not designed to be idempotent.
+
+The ``SparkSubmitOperator`` solves this by persisting the driver ID to ``task_state`` immediately after
+submission. On retry, it reads the ID back and reconnects to the already-running driver instead of
+resubmitting. Pass ``reconnect_on_retry=False`` to always submit a fresh job (it is still polled to completion).
+
+This is the **synchronous path**: the worker holds a slot for the whole duration of polling. It acts as
+a crash-safety net for teams that run synchronous operators for log observability, due to organizational constraints,
+or because a Triggerer is not available. Teams with a Triggerer available may also consider
+deferrable operators, which free the worker slot but may add complexity.
+
+.. note::
+ Crash recovery in cluster mode requires Airflow 3.3+ (``task_state`` support). On earlier
+ versions the operator falls back to the previous behavior of always submitting a fresh job.
diff --git a/providers/apache/spark/provider.yaml
b/providers/apache/spark/provider.yaml
index 2fd094e6d75..bb91a8e3127 100644
--- a/providers/apache/spark/provider.yaml
+++ b/providers/apache/spark/provider.yaml
@@ -210,6 +210,22 @@ connection-types:
- string
- 'null'
format: password
+ rest-scheme:
+ label: REST scheme
+ description: Scheme for the Spark standalone REST API (http or https).
Default is http.
+ schema:
+ type:
+ - string
+ - 'null'
+ default: http
+ rest-port:
+ label: REST port
+ description: Port for the Spark standalone REST API
(spark.master.rest.port). Default is 6066.
+ schema:
+ type:
+ - string
+ - 'null'
+ default: '6066'
task-decorators:
- class-name: airflow.providers.apache.spark.decorators.pyspark.pyspark_task
diff --git a/providers/apache/spark/pyproject.toml
b/providers/apache/spark/pyproject.toml
index 216c5c003da..e7f4fa480d3 100644
--- a/providers/apache/spark/pyproject.toml
+++ b/providers/apache/spark/pyproject.toml
@@ -63,6 +63,8 @@ dependencies = [
"apache-airflow-providers-common-compat>=1.12.0",
"pyspark-client>=4.0.0",
"grpcio-status>=1.67.0",
+ "requests>=2.32.0",
+ "tenacity>=8.3.0",
]
# The optional dependencies should be modified in place in the generated file
diff --git
a/providers/apache/spark/src/airflow/providers/apache/spark/get_provider_info.py
b/providers/apache/spark/src/airflow/providers/apache/spark/get_provider_info.py
index b9871156257..ef09d0a6ae9 100644
---
a/providers/apache/spark/src/airflow/providers/apache/spark/get_provider_info.py
+++
b/providers/apache/spark/src/airflow/providers/apache/spark/get_provider_info.py
@@ -126,6 +126,16 @@ def get_provider_info():
"description": "Run the command `base64
<your-keytab-path>` and use its output.",
"schema": {"type": ["string", "null"], "format":
"password"},
},
+ "rest-scheme": {
+ "label": "REST scheme",
+ "description": "Scheme for the Spark standalone REST
API (http or https). Default is http.",
+ "schema": {"type": ["string", "null"], "default":
"http"},
+ },
+ "rest-port": {
+ "label": "REST port",
+ "description": "Port for the Spark standalone REST API
(spark.master.rest.port). Default is 6066.",
+ "schema": {"type": ["string", "null"], "default":
"6066"},
+ },
},
},
],
diff --git
a/providers/apache/spark/src/airflow/providers/apache/spark/hooks/spark_submit.py
b/providers/apache/spark/src/airflow/providers/apache/spark/hooks/spark_submit.py
index 62d18aac049..9aa3ddc885e 100644
---
a/providers/apache/spark/src/airflow/providers/apache/spark/hooks/spark_submit.py
+++
b/providers/apache/spark/src/airflow/providers/apache/spark/hooks/spark_submit.py
@@ -160,6 +160,18 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
description="Run the command `base64 <your-keytab-path>` and
use its output.",
validators=[Optional()],
),
+ "rest-scheme": StringField(
+ lazy_gettext("REST scheme"),
+ widget=BS3TextFieldWidget(),
+ description="Scheme for the Spark standalone REST API (http or
https). Default: http.",
+ validators=[Optional()],
+ ),
+ "rest-port": StringField(
+ lazy_gettext("REST port"),
+ widget=BS3TextFieldWidget(),
+ description="Port for the Spark standalone REST API
(spark.master.rest.port). Default: 6066.",
+ validators=[Optional()],
+ ),
}
def __init__(
@@ -258,7 +270,7 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
def _resolve_connection(self) -> dict[str, Any]:
# Build from connection master or default to yarn if not available
- conn_data = {
+ conn_data: dict[str, Any] = {
"master": "yarn",
"queue": None, # yarn queue
"deploy_mode": None,
@@ -266,6 +278,9 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
"namespace": None,
"principal": self._principal,
"keytab": self._keytab,
+ # fallback if connection lookup fails; overridden by
rest-scheme/rest-port extras below
+ "rest_scheme": "http",
+ "rest_port": 6066,
}
try:
@@ -308,6 +323,8 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
)
conn_data["spark_binary"] = self.spark_binary
conn_data["namespace"] = extra.get("namespace")
+ conn_data["rest_scheme"] = extra.get("rest-scheme") or "http"  # null/empty extra -> default
+ conn_data["rest_port"] = int(extra.get("rest-port") or 6066)  # int(None)/int("") would raise
if conn_data["principal"] is None:
conn_data["principal"] = extra.get("principal")
if conn_data["keytab"] is None:
@@ -587,7 +604,7 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
except Exception as exc:
self.log.warning("Post-submit command raised an exception: %s.
Error: %s", cmd, exc)
- def submit(self, application: str = "", **kwargs: Any) -> None:
+ def submit(self, application: str = "", **kwargs: Any) -> str | None:
"""
Remote Popen to execute the spark-submit job.
@@ -626,27 +643,18 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
f"Cannot execute: {self._mask_cmd(spark_submit_cmd)}.
Error code is: {returncode}."
)
- self.log.debug("Should track driver: %s",
self._should_track_driver_status)
-
- # We want the Airflow job to wait until the Spark driver is
finished
- if self._should_track_driver_status:
- if self._driver_id is None:
- raise AirflowException(
- "No driver id is known: something went wrong when
executing the spark submit command"
- )
-
- # We start with the SUBMITTED status as initial status
- self._driver_status = "SUBMITTED"
-
- # Start tracking the driver status (blocking function)
- self._start_driver_status_tracking()
-
- if self._driver_status != "FINISHED":
- raise AirflowException(
- f"ERROR : Driver {self._driver_id} badly exited with
status {self._driver_status}"
- )
+ if self._should_track_driver_status and self._driver_id is None:
+ raise AirflowException(
+ "No driver id is known: something went wrong when
executing the spark submit command"
+ )
finally:
- self._run_post_submit_commands()
+ # In cluster mode with driver tracking, the operator calls poll_until_complete
+ # after submit() returns, so post_submit_commands are deferred there to preserve
+ # the "runs after job finishes" contract. If no driver ID was obtained, submit() raised and the operator never polls, so run them here too.
+ if not self._should_track_driver_status or self._driver_id is None:
+ self._run_post_submit_commands()
+
+ return self._driver_id
def _process_spark_submit_log(self, itr: Iterator[Any]) -> None:
"""
diff --git
a/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py
b/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py
index 0e67fa5b50d..76b010107da 100644
---
a/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py
+++
b/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py
@@ -18,7 +18,10 @@
from __future__ import annotations
from collections.abc import Sequence
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any, cast
+
+import requests
+from tenacity import retry, stop_after_attempt, wait_fixed
from airflow.providers.apache.spark.hooks.spark_submit import SparkSubmitHook
from airflow.providers.common.compat.openlineage.utils.spark import (
@@ -27,11 +30,30 @@ from
airflow.providers.common.compat.openlineage.utils.spark import (
)
from airflow.providers.common.compat.sdk import BaseOperator, conf
+try:
+ from airflow.sdk.bases.resumablemixin import ResumableJobMixin
+except ImportError:
+ # Compat for Airflow versions without ResumableJobMixin (before 3.3, including Airflow 2).
+ # On those versions we add a stub so the operator behaves as before:
+ # always submit fresh, with no task_state persistence.
+ class ResumableJobMixin: # type: ignore[no-redef]
+ """Airflow 2 stub — no task_state, always submits fresh."""
+
+ external_id_key: str = "remote_job_id"
+
+ def execute_resumable(self, context):
+ external_id = self.submit_job(context)
+ self.poll_until_complete(external_id, context)
+ return self.get_job_result(external_id, context)
+
+
if TYPE_CHECKING:
+ from pydantic import JsonValue
+
from airflow.providers.common.compat.sdk import Context
-class SparkSubmitOperator(BaseOperator):
+class SparkSubmitOperator(ResumableJobMixin, BaseOperator):
"""
Wrap the spark-submit binary to kick off a spark-submit job; requires
"spark-submit" binary in the PATH.
@@ -88,6 +110,10 @@ class SparkSubmitOperator(BaseOperator):
Useful for cleaning up sidecars such as Istio. Failures produce a
warning but do not fail the task.
"""
+ # Generic key used across all Spark deployment modes (standalone driver ID,
+ # YARN application ID, K8s driver pod name).
+ external_id_key = "spark_job_id"
+
template_fields: Sequence[str] = (
"application",
"conf",
@@ -141,6 +167,7 @@ class SparkSubmitOperator(BaseOperator):
deploy_mode: str | None = None,
use_krb5ccache: bool = False,
post_submit_commands: list[str] | None = None,
+ reconnect_on_retry: bool = True,
openlineage_inject_parent_job_info: bool = conf.getboolean(
"openlineage", "spark_inject_parent_job_info", fallback=False
),
@@ -184,6 +211,7 @@ class SparkSubmitOperator(BaseOperator):
self._conn_id = conn_id
self._use_krb5ccache = use_krb5ccache
+ self.reconnect_on_retry = reconnect_on_retry
self._openlineage_inject_parent_job_info =
openlineage_inject_parent_job_info
self._openlineage_inject_transport_info =
openlineage_inject_transport_info
@@ -198,7 +226,120 @@ class SparkSubmitOperator(BaseOperator):
self.conf =
inject_transport_information_into_spark_properties(self.conf, context)
if self._hook is None:
self._hook = self._get_hook()
- self._hook.submit(self.application)
+ hook = self._hook
+ if hook._should_track_driver_status:
+ if self.reconnect_on_retry:
+ return self.execute_resumable(context)
+ # reconnect_on_retry=False: still submit-and-poll, just skip
task_state persistence.
+ driver_id = self.submit_job(context)
+ self.poll_until_complete(driver_id, context)
+ return self.get_job_result(driver_id, context)
+ hook.submit(self.application)
+
+ def submit_job(self, context: Context) -> str:
+ if self._hook is None:
+ self._hook = self._get_hook()
+ driver_id = self._hook.submit(self.application)
+ if not driver_id:
+ raise RuntimeError("spark-submit did not return a driver ID")
+ self.log.info("Spark driver submitted: %s", driver_id)
+ return driver_id
+
+ def get_job_status(self, external_id: JsonValue) -> str:
+ # called from submit_job which always returns a str (Spark driver IDs
are strings)
+ external_id = cast("str", external_id)
+ if self._hook is None:
+ self._hook = self._get_hook()
+ # The YARN and K8s branches below (and in is_job_active,
is_job_succeeded, poll_until_complete)
+ # are currently unreachable: execute_resumable is only called when
_should_track_driver_status
+ # is True, which requires spark:// + cluster mode. They are
scaffolding for a follow-up PR
+ # that extends ResumableJobMixin support to YARN and Kubernetes.
+ if self._hook._is_yarn:
+ # TODO: call YARN ResourceManager REST API
+ # GET http://rm:8088/ws/v1/cluster/apps/{external_id}
+ raise NotImplementedError("YARN job status not yet implemented")
+ if self._hook._is_kubernetes:
+ # TODO: call K8s pod status API
+ raise NotImplementedError("K8s job status not yet implemented")
+ scheme = self._hook._connection.get("rest_scheme", "http")
+ rest_port = self._hook._connection.get("rest_port", 6066)
+ # HA master URLs can look like spark://m1:7077,m2:7077 — try each host
in order.
+ # The master URL port (e.g. 7077) is the RPC port — not the REST API
port.
+ # Use rest-port connection extra to override spark.master.rest.port
(default 6066).
+ master_urls = self._hook._connection["master"].replace("spark://",
"").split(",")
+ last_exc: Exception = RuntimeError("No Spark masters to query")
+ for m in master_urls:
+ host = m.strip().split(":")[0]
+ url =
f"{scheme}://{host}:{rest_port}/v1/submissions/status/{external_id}"
+ try:
+ status = self._fetch_driver_status(url, external_id)
+ return status
+ except Exception as e:
+ self.log.warning("Could not reach Spark master %s: %s", host,
e)
+ last_exc = e
+ raise last_exc
+
+ @retry(stop=stop_after_attempt(3), wait=wait_fixed(1), reraise=True)
+ def _fetch_driver_status(self, url: str, external_id: str) -> str:
+ response = requests.get(url, timeout=30)
+ response.raise_for_status()
+ # "success:false" means the master does not recognise the driver ID or
is in recovery.
+ #
https://github.com/apache/spark/blob/master/core/src/main/scala/org/apache/spark/deploy/master/DriverState.scala
+ data = response.json()
+ if not data.get("success"):
+ raise RuntimeError(
+ f"Spark REST API returned failure for {external_id}:
{data.get('message', 'unknown error')}"
+ )
+ status = data["driverState"]
+ self.log.info("Driver %s status: %s", external_id, status)
+ return status
+
+ def is_job_active(self, status: str) -> bool:
+ if self._hook is None:
+ self._hook = self._get_hook()
+ status = status.upper()
+ if self._hook._is_yarn:
+ #
https://hadoop.apache.org/docs/stable/hadoop-yarn/hadoop-yarn-site/ResourceManagerRest.html
+ return status in ("NEW", "NEW_SAVING", "SUBMITTED", "ACCEPTED",
"RUNNING")
+ if self._hook._is_kubernetes:
+ return status in ("PENDING", "RUNNING")
+ # RELAUNCHING: driver is being restarted after a failure, still alive.
+ # UNKNOWN: master is in failure recovery, state is temporarily
unavailable.
+ #
https://github.com/apache/spark/blob/master/core/src/main/scala/org/apache/spark/deploy/master/DriverState.scala
+ return status in ("SUBMITTED", "RUNNING", "RELAUNCHING", "UNKNOWN")
+
+ def is_job_succeeded(self, status: str) -> bool:
+ if self._hook is None:
+ self._hook = self._get_hook()
+ status = status.upper()
+ if self._hook._is_kubernetes:
+ return status == "SUCCEEDED"
+ # standalone and YARN both use FINISHED
+ return status == "FINISHED"
+
+ def poll_until_complete(self, external_id: JsonValue, context: Context) ->
None:
+ # called from submit_job which always returns a str (Spark driver IDs
are strings)
+ external_id = cast("str", external_id)
+ if self._hook is None:
+ self._hook = self._get_hook()
+ if self._hook._is_yarn:
+ # TODO: poll YARN ResourceManager until app reaches terminal state
+ raise NotImplementedError("YARN poll not yet implemented")
+ if self._hook._is_kubernetes:
+ # TODO: poll K8s pod phase until terminal
+ raise NotImplementedError("K8s poll not yet implemented")
+ self.log.info("Polling driver %s until completion", external_id)
+ self._hook._driver_id = external_id
+ try:
+ self._hook._start_driver_status_tracking()
+ if self._hook._driver_status != "FINISHED":
+ raise RuntimeError(f"Driver {external_id} exited with status
{self._hook._driver_status}")
+ finally:
+ # post-submit commands must fire whether the job succeeded or
failed.
+ self._hook._run_post_submit_commands()
+
+ def get_job_result(self, external_id: JsonValue, context: Context) -> None:
+ return None
def on_kill(self) -> None:
if self._hook is None:
diff --git
a/providers/apache/spark/tests/unit/apache/spark/hooks/test_spark_submit.py
b/providers/apache/spark/tests/unit/apache/spark/hooks/test_spark_submit.py
index 1b2feaa33e9..c909e9f12ab 100644
--- a/providers/apache/spark/tests/unit/apache/spark/hooks/test_spark_submit.py
+++ b/providers/apache/spark/tests/unit/apache/spark/hooks/test_spark_submit.py
@@ -397,6 +397,8 @@ class TestSparkSubmitHook:
"namespace": None,
"principal": None,
"keytab": None,
+ "rest_scheme": "http",
+ "rest_port": 6066,
}
assert connection == expected_spark_connection
assert dict_cmd["--master"] == "yarn"
@@ -420,6 +422,8 @@ class TestSparkSubmitHook:
"namespace": None,
"principal": None,
"keytab": None,
+ "rest_scheme": "http",
+ "rest_port": 6066,
}
assert connection == expected_spark_connection
assert dict_cmd["--master"] == "yarn"
@@ -443,6 +447,8 @@ class TestSparkSubmitHook:
"namespace": None,
"principal": None,
"keytab": None,
+ "rest_scheme": "http",
+ "rest_port": 6066,
}
assert connection == expected_spark_connection
assert dict_cmd["--master"] == "mesos://host:5050"
@@ -465,6 +471,8 @@ class TestSparkSubmitHook:
"namespace": None,
"principal": None,
"keytab": None,
+ "rest_scheme": "http",
+ "rest_port": 6066,
}
assert connection == expected_spark_connection
assert dict_cmd["--master"] == "yarn://yarn-master"
@@ -489,6 +497,8 @@ class TestSparkSubmitHook:
"namespace": "mynamespace",
"principal": None,
"keytab": None,
+ "rest_scheme": "http",
+ "rest_port": 6066,
}
assert connection == expected_spark_connection
assert dict_cmd["--master"] == "k8s://https://k8s-master"
@@ -515,6 +525,8 @@ class TestSparkSubmitHook:
"namespace": "airflow",
"principal": None,
"keytab": None,
+ "rest_scheme": "http",
+ "rest_port": 6066,
}
assert connection == expected_spark_connection
assert dict_cmd["--master"] == "k8s://https://k8s-master"
@@ -538,6 +550,8 @@ class TestSparkSubmitHook:
"namespace": None,
"principal": None,
"keytab": None,
+ "rest_scheme": "http",
+ "rest_port": 6066,
}
assert connection == expected_spark_connection
assert cmd[0] == "spark2-submit"
@@ -559,6 +573,8 @@ class TestSparkSubmitHook:
"namespace": None,
"principal": None,
"keytab": None,
+ "rest_scheme": "http",
+ "rest_port": 6066,
}
assert connection == expected_spark_connection
assert cmd[0] == "spark3-submit"
@@ -619,6 +635,8 @@ class TestSparkSubmitHook:
"namespace": None,
"principal": None,
"keytab": None,
+ "rest_scheme": "http",
+ "rest_port": 6066,
}
assert connection == expected_spark_connection
assert cmd[0] == "spark3-submit"
@@ -641,6 +659,8 @@ class TestSparkSubmitHook:
"namespace": None,
"principal": None,
"keytab": None,
+ "rest_scheme": "http",
+ "rest_port": 6066,
}
assert connection == expected_spark_connection
assert cmd[0] == "spark-submit"
@@ -662,6 +682,8 @@ class TestSparkSubmitHook:
"namespace": None,
"principal": None,
"keytab": None,
+ "rest_scheme": "http",
+ "rest_port": 6066,
}
assert connection == expected_spark_connection
assert cmd[0] == "spark-submit"
@@ -684,6 +706,8 @@ class TestSparkSubmitHook:
"namespace": None,
"principal": "user/[email protected]",
"keytab": None,
+ "rest_scheme": "http",
+ "rest_port": 6066,
}
assert connection == expected_spark_connection
assert dict_cmd["--principal"] == "user/[email protected]"
@@ -706,6 +730,8 @@ class TestSparkSubmitHook:
"namespace": None,
"principal": "will-override",
"keytab": None,
+ "rest_scheme": "http",
+ "rest_port": 6066,
}
assert connection == expected_spark_connection
assert dict_cmd["--principal"] == "will-override"
@@ -732,6 +758,8 @@ class TestSparkSubmitHook:
"namespace": None,
"principal": None,
"keytab": "privileged_user.keytab",
+ "rest_scheme": "http",
+ "rest_port": 6066,
}
assert connection == expected_spark_connection
assert dict_cmd["--keytab"] == "privileged_user.keytab"
@@ -757,6 +785,8 @@ class TestSparkSubmitHook:
"namespace": None,
"principal": None,
"keytab": "will-override",
+ "rest_scheme": "http",
+ "rest_port": 6066,
}
assert connection == expected_spark_connection
assert dict_cmd["--keytab"] == "will-override"
diff --git
a/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py
b/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py
index 6a85ce1b920..65af1116861 100644
---
a/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py
+++
b/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py
@@ -32,7 +32,7 @@ from airflow.utils.types import DagRunType
from tests_common.test_utils.dag import sync_dag_to_db
from tests_common.test_utils.taskinstance import create_task_instance,
render_template_fields
-from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS
+from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS,
AIRFLOW_V_3_3_PLUS
DEFAULT_DATE = timezone.datetime(2017, 1, 1)
@@ -321,6 +321,7 @@ class TestSparkSubmitOperator:
openlineage_inject_transport_info=True,
**self._config,
)
+ mock_get_hook.return_value._should_track_driver_status = False
operator.execute(MagicMock())
assert operator.conf == {
@@ -387,6 +388,7 @@ class TestSparkSubmitOperator:
openlineage_inject_transport_info=True,
**self._config,
)
+ mock_get_hook.return_value._should_track_driver_status = False
operator.execute({"ti": mock_ti})
assert operator.conf == {
@@ -425,6 +427,7 @@ class TestSparkSubmitOperator:
CompositeConfig.from_dict({"transports": {"test1": {"type":
"console"}}})
)
+ mock_get_hook.return_value._should_track_driver_status = False
with caplog.at_level(logging.INFO):
operator = SparkSubmitOperator(
task_id="spark_submit_job",
@@ -456,6 +459,7 @@ class TestSparkSubmitOperator:
config=ConsoleConfig()
)
+ mock_get_hook.return_value._should_track_driver_status = False
with caplog.at_level(logging.INFO):
operator = SparkSubmitOperator(
task_id="spark_submit_job",
@@ -474,3 +478,259 @@ class TestSparkSubmitOperator:
assert operator.conf == {
"parquet.compression": "SNAPPY",
}
+
+
+class FakeTaskState:
+ """In-memory task state for tests."""
+
+ def __init__(self, stored: dict[str, str] | None = None):
+ self._store: dict[str, str] = dict(stored or {})
+
+ def get(self, key: str) -> str | None:
+ return self._store.get(key)
+
+ def set(self, key: str, value: str) -> None:
+ self._store[key] = value
+
+
[email protected](
+ not AIRFLOW_V_3_3_PLUS,
+ reason="ResumableJobMixin reconnect requires task_state, available in
Airflow 3.3+",
+)
+class TestSparkSubmitOperatorResumable:
+ def setup_method(self):
+ args = {"owner": "airflow", "start_date": DEFAULT_DATE}
+ self.dag = DAG("test_resumable_dag", schedule=None, default_args=args)
+
+ def _make_operator(self, **kwargs):
+ return SparkSubmitOperator(task_id="test", dag=self.dag,
application="test.jar", **kwargs)
+
+ def _make_hook(self, should_track=False, is_yarn=False,
is_kubernetes=False):
+ hook = MagicMock()
+ hook._should_track_driver_status = should_track
+ hook._is_yarn = is_yarn
+ hook._is_kubernetes = is_kubernetes
+ hook._connection = {"master": "spark://localhost:7077"}
+ return hook
+
+ def test_non_cluster_mode_calls_hook_submit_directly(self):
+ operator = self._make_operator()
+ operator._hook = self._make_hook(should_track=False)
+
+ operator.execute(context={})
+
+ operator._hook.submit.assert_called_once_with("test.jar")
+
+ def test_cluster_mode_first_run_persists_id_before_polling(self):
+ operator = self._make_operator()
+ operator._hook = self._make_hook(should_track=True)
+ operator._hook.submit.return_value = "driver-001"
+
+ task_state = FakeTaskState()
+ persisted_before_poll = []
+
+ def track_poll(external_id, context):
+ persisted_before_poll.append(task_state.get("spark_job_id"))
+
+ operator.poll_until_complete = track_poll
+
+ operator.execute(context={"task_state": task_state})
+
+ operator._hook.submit.assert_called_once_with("test.jar")
+ assert persisted_before_poll == ["driver-001"]
+
+ @pytest.mark.parametrize(
+ ("prior_status", "expect_submit", "expect_poll_id"),
+ [
+ ("RUNNING", False, "driver-001"),
+ ("SUBMITTED", False, "driver-001"),
+ ("FINISHED", False, None),
+ ("FAILED", True, "driver-new"),
+ ("KILLED", True, "driver-new"),
+ ],
+ )
+ def test_retry_behaviour_based_on_prior_driver_status(self, prior_status,
expect_submit, expect_poll_id):
+ operator = self._make_operator()
+ operator._hook = self._make_hook(should_track=True)
+ operator._hook.submit.return_value = "driver-new"
+ task_state = FakeTaskState({"spark_job_id": "driver-001"})
+
+ operator.get_job_status = lambda external_id: prior_status
+ polled = []
+ operator.poll_until_complete = lambda external_id, context:
polled.append(external_id)
+
+ operator.execute(context={"task_state": task_state})
+
+ if expect_submit:
+ operator._hook.submit.assert_called_once_with("test.jar")
+ else:
+ operator._hook.submit.assert_not_called()
+
+ if expect_poll_id:
+ assert polled == [expect_poll_id]
+ else:
+ assert polled == []
+
+ def test_submits_fresh_when_task_state_unavailable(self):
+ operator = self._make_operator()
+ operator._hook = self._make_hook(should_track=True)
+ operator._hook.submit.return_value = "driver-001"
+ polled = []
+ operator.poll_until_complete = lambda external_id, context:
polled.append(external_id)
+
+ # no task_state key in context
+ operator.execute(context={})
+
+ operator._hook.submit.assert_called_once_with("test.jar")
+ assert polled == ["driver-001"]
+
+ def test_reconnect_on_retry_false_submits_fresh_and_polls(self):
+ operator = self._make_operator(reconnect_on_retry=False)
+ operator._hook = self._make_hook(should_track=True)
+ operator._hook.submit.return_value = "driver-new"
+ task_state = FakeTaskState({"spark_job_id": "driver-old"})
+ polled = []
+ operator.poll_until_complete = lambda external_id, context:
polled.append(external_id)
+
+ operator.execute(context={"task_state": task_state})
+ # reconnect_on_retry=False: ignores prior driver ID, submits fresh,
but still polls
+ operator._hook.submit.assert_called_once_with("test.jar")
+ assert polled == ["driver-new"]
+
+ @pytest.mark.parametrize(
+ ("is_yarn", "is_kubernetes", "status", "expected_active",
"expected_succeeded"),
+ [
+ (False, False, "RUNNING", True, False),
+ (False, False, "SUBMITTED", True, False),
+ (False, False, "FINISHED", False, True),
+ (False, False, "FAILED", False, False),
+ (True, False, "RUNNING", True, False),
+ (True, False, "ACCEPTED", True, False),
+ (True, False, "NEW", True, False),
+ (True, False, "FINISHED", False, True),
+ (True, False, "FAILED", False, False),
+ (False, True, "Running", True, False),
+ (False, True, "Pending", True, False),
+ (False, True, "Succeeded", False, True),
+ (False, True, "Failed", False, False),
+ ],
+ )
+ def test_job_status_mappings(self, is_yarn, is_kubernetes, status,
expected_active, expected_succeeded):
+ operator = self._make_operator()
+ operator._hook = self._make_hook(is_yarn=is_yarn,
is_kubernetes=is_kubernetes)
+
+ assert operator.is_job_active(status) == expected_active
+ assert operator.is_job_succeeded(status) == expected_succeeded
+
+ @pytest.mark.parametrize(
+ ("response_json", "expected_status", "expected_error"),
+ [
+ ({"success": True, "driverState": "RUNNING"}, "RUNNING", None),
+ ({"success": False, "message": "driver not found"}, None, "driver
not found"),
+ ({"driverState": "RUNNING"}, None, "unknown error"),
+ ],
+ )
+ def test_get_job_status(self, response_json, expected_status,
expected_error):
+ operator = self._make_operator()
+ operator._hook = self._make_hook(should_track=True)
+ mock_response = MagicMock()
+ mock_response.json.return_value = response_json
+
+ with mock.patch("requests.get", return_value=mock_response):
+ if expected_error:
+ with pytest.raises(RuntimeError, match=expected_error):
+ operator.get_job_status("driver-001")
+ else:
+ assert operator.get_job_status("driver-001") == expected_status
+
+    def test_get_job_status_ha_tries_next_master(self):
+        """If the first HA master is unreachable, get_job_status falls back to the next master."""
+        operator = self._make_operator()
+        hook = self._make_hook(should_track=True)
+        # Master URL port (7077) is RPC — REST API must use 6066, not 7077
+        hook._connection = {"master": "spark://m1:7077,m2:7077"}
+        operator._hook = hook
+
+        good_response = MagicMock()
+        good_response.json.return_value = {"success": True, "driverState": "RUNNING"}
+        captured_urls = []
+
+        def side_effect(url, timeout):
+            captured_urls.append(url)
+            if "m1" in url:
+                raise ConnectionError("m1 unreachable")
+            return good_response
+
+        with mock.patch("requests.get", side_effect=side_effect):
+            assert operator.get_job_status("driver-001") == "RUNNING"
+
+        # Prove the fallback actually happened: m1 was attempted (and failed) and the
+        # answer came from m2. Without this, an implementation that only ever queried
+        # m2 would also pass.
+        assert any("m1" in url for url in captured_urls), "first master must be tried"
+        assert "m2" in captured_urls[-1], "status must be fetched from the next master"
+        assert all(":6066/" in url for url in captured_urls), "REST API must use port 6066, not the RPC port"
+
+    def test_get_job_status_ha_tries_next_master_on_success_false(self):
+        """success:false from m1 (e.g. HA recovery in progress) should fall through to m2."""
+        operator = self._make_operator()
+        hook = self._make_hook(should_track=True)
+        hook._connection = {"master": "spark://m1:7077,m2:7077"}
+        operator._hook = hook
+
+        not_found = MagicMock()
+        not_found.json.return_value = {"success": False, "message": "Driver not found"}
+        running = MagicMock()
+        running.json.return_value = {"success": True, "driverState": "RUNNING"}
+
+        def fake_get(url, timeout):
+            # Only m1 answers success:false; every other master reports RUNNING.
+            return not_found if "m1" in url else running
+
+        with mock.patch("requests.get", side_effect=fake_get):
+            status = operator.get_job_status("driver-001")
+
+        assert status == "RUNNING"
+
+    def test_get_job_status_ha_raises_when_all_masters_unreachable(self):
+        """When every HA master refuses the connection, the ConnectionError propagates."""
+        operator = self._make_operator()
+        hook = self._make_hook(should_track=True)
+        hook._connection = {"master": "spark://m1:7077,m2:7077"}
+        operator._hook = hook
+
+        all_down = mock.patch("requests.get", side_effect=ConnectionError("unreachable"))
+        with all_down, pytest.raises(ConnectionError):
+            operator.get_job_status("driver-001")
+
+    def test_get_job_status_uses_rest_scheme_from_connection(self):
+        """The connection's ``rest_scheme`` field selects the scheme of the status request URL."""
+        operator = self._make_operator()
+        hook = self._make_hook(should_track=True)
+        hook._connection = {"master": "spark://myhost:6066", "rest_scheme": "https"}
+        operator._hook = hook
+
+        status_response = MagicMock()
+        status_response.json.return_value = {"success": True, "driverState": "RUNNING"}
+        requested_urls = []
+
+        def record_request(url, timeout):
+            requested_urls.append(url)
+            return status_response
+
+        with mock.patch("requests.get", side_effect=record_request):
+            operator.get_job_status("driver-001")
+
+        assert len(requested_urls) == 1
+        first_url = requested_urls[0]
+        assert first_url.startswith("https://")
+
+    def test_poll_until_complete_runs_post_submit_on_failure(self):
+        """post_submit_commands must run even when the driver exits with a failure status."""
+        operator = self._make_operator()
+        hook = self._make_hook(should_track=True)
+        hook._connection = {"master": "spark://myhost:7077"}
+        hook._driver_status = "FAILED"
+
+        def simulate_failed_tracking():
+            hook._driver_status = "FAILED"
+
+        hook._start_driver_status_tracking = mock.MagicMock(side_effect=simulate_failed_tracking)
+        # return_value=None mirrors a real post-submit hook that returns nothing.
+        hook._run_post_submit_commands = mock.MagicMock(return_value=None)
+        operator._hook = hook
+
+        with pytest.raises(RuntimeError, match="FAILED"):
+            operator.poll_until_complete("driver-001", {})
+
+        assert hook._run_post_submit_commands.called, (
+            "_run_post_submit_commands must be called even on driver failure"
+        )
diff --git a/scripts/ci/prek/known_airflow_exceptions.txt
b/scripts/ci/prek/known_airflow_exceptions.txt
index f1ddfbd1efc..bd4570fc55f 100644
--- a/scripts/ci/prek/known_airflow_exceptions.txt
+++ b/scripts/ci/prek/known_airflow_exceptions.txt
@@ -145,7 +145,7 @@
providers/apache/livy/src/airflow/providers/apache/livy/operators/livy.py::3
providers/apache/pig/src/airflow/providers/apache/pig/hooks/pig.py::1
providers/apache/pinot/src/airflow/providers/apache/pinot/hooks/pinot.py::1
providers/apache/spark/src/airflow/providers/apache/spark/hooks/spark_sql.py::2
-providers/apache/spark/src/airflow/providers/apache/spark/hooks/spark_submit.py::11
+providers/apache/spark/src/airflow/providers/apache/spark/hooks/spark_submit.py::10
providers/arangodb/src/airflow/providers/arangodb/hooks/arangodb.py::9
providers/arangodb/src/airflow/providers/arangodb/operators/arangodb.py::1
providers/atlassian/jira/src/airflow/providers/atlassian/jira/hooks/jira.py::1
diff --git a/task-sdk/docs/api.rst b/task-sdk/docs/api.rst
index 8f3f3de9539..c2bb0a19fbc 100644
--- a/task-sdk/docs/api.rst
+++ b/task-sdk/docs/api.rst
@@ -99,6 +99,8 @@ Bases
.. autoapiclass:: airflow.sdk.SkipMixin
+.. autoapiclass:: airflow.sdk.ResumableJobMixin
+
.. autoapiclass:: airflow.sdk.BaseHook
Callbacks
diff --git a/task-sdk/src/airflow/sdk/__init__.py
b/task-sdk/src/airflow/sdk/__init__.py
index eeae86f1eb3..ab834658c0f 100644
--- a/task-sdk/src/airflow/sdk/__init__.py
+++ b/task-sdk/src/airflow/sdk/__init__.py
@@ -66,6 +66,7 @@ __all__ = [
"PartitionMapper",
"PokeReturnValue",
"ProductMapper",
+ "ResumableJobMixin",
"RetryAction",
"RetryDecision",
"RetryPolicy",
@@ -117,6 +118,7 @@ if TYPE_CHECKING:
cross_downstream,
)
from airflow.sdk.bases.operatorlink import BaseOperatorLink
+ from airflow.sdk.bases.resumablemixin import ResumableJobMixin
from airflow.sdk.bases.sensor import BaseSensorOperator, PokeReturnValue
from airflow.sdk.bases.skipmixin import SkipMixin
from airflow.sdk.bases.xcom import BaseXCom
@@ -233,6 +235,7 @@ __lazy_imports: dict[str, str] = {
"PartitionMapper": ".definitions.partition_mappers.base",
"PokeReturnValue": ".bases.sensor",
"ProductMapper": ".definitions.partition_mappers.product",
+ "ResumableJobMixin": ".bases.resumablemixin",
"RetryAction": ".definitions.retry_policy",
"RetryDecision": ".definitions.retry_policy",
"RetryPolicy": ".definitions.retry_policy",
diff --git a/task-sdk/src/airflow/sdk/bases/resumablemixin.py
b/task-sdk/src/airflow/sdk/bases/resumablemixin.py
new file mode 100644
index 00000000000..4e252d54743
--- /dev/null
+++ b/task-sdk/src/airflow/sdk/bases/resumablemixin.py
@@ -0,0 +1,167 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any
+
+if TYPE_CHECKING:
+ from pydantic import JsonValue
+
+ from airflow.sdk.definitions.context import Context
+ from airflow.sdk.types import Logger
+
+
+class ResumableJobMixin:
+    """
+    Mixin for operators that submit one long-running job to an external system and poll for completion.
+
+    **Purpose:** This mixin makes the synchronous operator path crash-safe. It is not a replacement
+    for deferrable operators — deferrable remains the recommended approach for long-running tasks when
+    a Triggerer is available and the async model fits the team. This mixin is for teams already running
+    synchronous operators who want worker crashes to reconnect to the existing job rather than
+    resubmitting a duplicate.
+
+    **How it works:** On the first run, after submitting the job, the external ID (driver ID, YARN
+    application ID, etc.) is persisted to ``task_state`` before polling starts. On retry, the mixin
+    reads that ID back and reconnects to the already-running job instead of starting a new one.
+    Whether the job was freshly submitted, reconnected to, or had already succeeded,
+    ``execute_resumable`` returns the value of ``get_job_result`` once the job is done.
+
+    **What it does not do:** It does not free the worker slot during polling (use deferrable for that),
+    and it does not stream logs from the remote system (the operator controls that separately).
+
+    Usage: call ``execute_resumable(context)`` from the operator's ``execute()`` when reconnection
+    is supported.
+
+    Subclasses must implement the methods specific to their external system (the defaults here
+    raise ``NotImplementedError``). The mixin owns only ``execute_resumable()`` and the task_state
+    read/write logic.
+
+    Example::
+
+        class MyOperator(ResumableJobMixin, BaseOperator):
+            external_id_key = "my_job_id"
+
+            def execute(self, context):
+                return self.execute_resumable(context)
+
+            def submit_job(self, context) -> JsonValue:
+                return self.hook.submit(...)
+
+            def get_job_status(self, external_id: JsonValue) -> str:
+                return self.hook.get_status(external_id)
+
+            def is_job_active(self, status: str) -> bool:
+                return status in ("RUNNING", "PENDING")
+
+            def is_job_succeeded(self, status: str) -> bool:
+                return status == "SUCCEEDED"
+
+            def poll_until_complete(self, external_id: JsonValue, context: Context) -> None:
+                self.hook.poll(external_id)
+
+            def get_job_result(self, external_id: JsonValue, context: Context) -> Any:
+                return None
+    """
+
+    if TYPE_CHECKING:
+        # log comes from BaseOperator (via LoggingMixin) at runtime, but mypy cannot see
+        # that because ResumableJobMixin does not inherit from it directly.
+        log: Logger
+
+    # Key used to store and retrieve the external job ID from task_state across retries.
+    # Renaming this on a deployed operator breaks in-flight retries — the old key is already stored.
+    external_id_key: str = "remote_job_id"
+
+    def execute_resumable(self, context: Context) -> Any:
+        """
+        Core of the resumable execution logic. Call this from execute() when reconnection is supported.
+
+        On initial run: submits the job, persists the external ID to task_state, polls until
+        completion, then returns ``get_job_result``.
+
+        Behaviour on retry:
+        - On retry with active job: skips submission, reconnects to the running job, polls it to
+          completion and returns ``get_job_result`` (the same return value as a first run).
+        - On retry with succeeded job: skips submission and polling, returns result immediately.
+        - On retry with any other status (failed, unknown, ...): falls through and resubmits fresh.
+
+        Known limitation: there is a small window between ``submit_job`` returning and
+        ``task_state.set`` completing. If the worker dies in that gap, the next retry still
+        holds the previous (terminal) ID, if any, and will resubmit a fresh job rather than
+        reconnecting. Closing this window would require atomic "submit + persist", which is not
+        possible across an external system boundary.
+
+        :param context: The task context; ``task_state`` is read from it when present.
+        :return: The value returned by ``get_job_result`` for the job that ran to completion.
+        """
+        task_state = context.get("task_state")
+
+        if task_state is not None:
+            external_id = task_state.get(self.external_id_key)
+            # A missing (or falsy) stored ID means there is no prior job to reconnect to.
+            if external_id:
+                status = self.get_job_status(external_id)
+                if self.is_job_active(status):
+                    self.log.info(
+                        "Reconnecting to existing job identified by: %s (status: %s)", external_id, status
+                    )
+                    # poll_until_complete's return value is not the job result; fetch the result
+                    # explicitly so a reconnected run returns the same value as a fresh run.
+                    self.poll_until_complete(external_id, context)
+                    return self.get_job_result(external_id, context)
+                if self.is_job_succeeded(status):
+                    self.log.info(
+                        "Job with identifier: %s already completed successfully, skipping resubmission",
+                        external_id,
+                    )
+                    return self.get_job_result(external_id, context)
+                self.log.info(
+                    "Prior job with identifier: %s in terminal state %s, resubmitting fresh",
+                    external_id,
+                    status,
+                )
+
+        external_id = self.submit_job(context)
+
+        # Persist before polling so that a worker crash during polling can reconnect on retry.
+        if task_state is not None:
+            task_state.set(self.external_id_key, external_id)
+
+        self.poll_until_complete(external_id, context)
+        return self.get_job_result(external_id, context)
+
+    def submit_job(self, context: Context) -> JsonValue:
+        """
+        Submit the job to the external system. Return its external ID.
+
+        The returned ID is persisted to ``task_state`` under ``external_id_key``.
+        """
+        raise NotImplementedError
+
+    def get_job_status(self, external_id: JsonValue) -> str:
+        """
+        Query the external system for the current job status.
+
+        Called on retry with the ID previously persisted to ``task_state``.
+        """
+        raise NotImplementedError
+
+    def is_job_active(self, status: str) -> bool:
+        """
+        Return True if the job is still running and can be reconnected to.
+
+        ``status`` is a raw string returned by the external system — not an Airflow enum.
+        Its values are backend-specific (e.g. ``"RUNNING"``, ``"Pending"``, ``"ContainerCreating"``).
+        """
+        raise NotImplementedError
+
+    def is_job_succeeded(self, status: str) -> bool:
+        """
+        Return True if the job completed successfully.
+
+        ``status`` is a raw string returned by the external system — not an Airflow enum.
+        Its values are backend-specific (e.g. ``"FINISHED"``, ``"Succeeded"``).
+        """
+        raise NotImplementedError
+
+    def poll_until_complete(self, external_id: JsonValue, context: Context) -> None:
+        """
+        Block until the job reaches a terminal state. Raise on failure.
+
+        Any return value is ignored by ``execute_resumable``; use ``get_job_result`` for results.
+        """
+        raise NotImplementedError
+
+    def get_job_result(self, external_id: JsonValue, context: Context) -> Any:
+        """Return the job result after completion. Return None if not applicable."""
+        raise NotImplementedError
diff --git a/task-sdk/tests/task_sdk/bases/test_resumablemixin.py
b/task-sdk/tests/task_sdk/bases/test_resumablemixin.py
new file mode 100644
index 00000000000..8e95e132f3a
--- /dev/null
+++ b/task-sdk/tests/task_sdk/bases/test_resumablemixin.py
@@ -0,0 +1,177 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+import pytest
+
+from airflow.sdk.bases.operator import BaseOperator
+from airflow.sdk.bases.resumablemixin import ResumableJobMixin
+
+if TYPE_CHECKING:
+ from pydantic import JsonValue
+
+
+class ConcreteResumableOperator(ResumableJobMixin, BaseOperator):
+    """Smallest concrete operator implementing every ResumableJobMixin hook, for tests."""
+
+    external_id_key = "test_job_id"
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        # Call logs: the IDs handed out by submit_job / passed to poll_until_complete, in order.
+        self.submitted_ids: list[str] = []
+        self.polled_ids: list[str] = []
+        # Knobs individual tests adjust to steer execute_resumable down a given branch.
+        self._next_id = "job-001"
+        self._status_map: dict[str, str] = {}
+        self._active_statuses = {"RUNNING", "PENDING"}
+        self._succeeded_statuses = {"SUCCEEDED"}
+
+    def submit_job(self, context) -> JsonValue:
+        new_id = self._next_id
+        self.submitted_ids.append(new_id)
+        return new_id
+
+    def get_job_status(self, external_id: JsonValue) -> str:
+        # IDs with no entry report "UNKNOWN", which the default status sets treat as
+        # neither active nor succeeded.
+        return self._status_map.get(str(external_id), "UNKNOWN")
+
+    def is_job_active(self, status: str) -> bool:
+        return status in self._active_statuses
+
+    def is_job_succeeded(self, status: str) -> bool:
+        return status in self._succeeded_statuses
+
+    def poll_until_complete(self, external_id: JsonValue, context) -> None:
+        # No real waiting here; only record that polling happened for this ID.
+        self.polled_ids.append(str(external_id))
+
+    def get_job_result(self, external_id: JsonValue, context) -> str:
+        return f"result-of-{external_id}"
+
+
+class FakeTaskState:
+    """In-memory stand-in for the object the mixin reads from ``context["task_state"]``."""
+
+    def __init__(self, stored: dict[str, str] | None = None):
+        # A None/empty seed gets a fresh dict; a non-empty seed is used (and mutated) in place.
+        self._store: dict[str, str] = stored if stored else {}
+
+    def get(self, key: str) -> str | None:
+        return self._store.get(key, None)
+
+    def set(self, key: str, value: str) -> None:
+        self._store.update({key: value})
+
+
+def make_context(task_state: FakeTaskState | None = None) -> dict:
+    """Build a minimal task context, including ``task_state`` only when one is supplied."""
+    return {} if task_state is None else {"task_state": task_state}
+
+
+class TestFirstSubmission:
+    """First attempt: task_state starts out empty."""
+
+    def test_submits_and_polls_when_no_prior_state(self):
+        op = ConcreteResumableOperator(task_id="test_task")
+
+        op.execute_resumable(make_context(FakeTaskState()))
+
+        assert op.submitted_ids == ["job-001"]
+        assert op.polled_ids == ["job-001"]
+
+    def test_persists_external_id_before_polling(self):
+        """The ID must be in task_state before poll_until_complete is called."""
+        op = ConcreteResumableOperator(task_id="test_task")
+        task_state = FakeTaskState()
+        persisted_at_poll: list[str | None] = []
+
+        def record_persisted_id(external_id, context):
+            # Snapshot what task_state holds at the moment polling begins.
+            persisted_at_poll.append(task_state.get("test_job_id"))
+
+        op.poll_until_complete = record_persisted_id
+
+        op.execute_resumable(make_context(task_state))
+
+        assert persisted_at_poll == ["job-001"], "ID must be persisted before polling starts"
+
+    def test_returns_job_result(self):
+        op = ConcreteResumableOperator(task_id="test_task")
+
+        assert op.execute_resumable(make_context(FakeTaskState())) == "result-of-job-001"
+
+
+class TestRetryWithDifferentJobStatuses:
+    """Retries where task_state already holds ``job-001`` from a previous attempt."""
+
+    @staticmethod
+    def _operator_with_prior_job(prior_status: str) -> tuple[ConcreteResumableOperator, FakeTaskState]:
+        """Return an operator reporting ``prior_status`` for ``job-001`` and a task_state holding it."""
+        op = ConcreteResumableOperator(task_id="test_task")
+        op._status_map["job-001"] = prior_status
+        return op, FakeTaskState({"test_job_id": "job-001"})
+
+    def test_skips_submission_when_job_active(self):
+        op, task_state = self._operator_with_prior_job("RUNNING")
+
+        op.execute_resumable(make_context(task_state))
+
+        assert op.submitted_ids == [], "should not resubmit when job is active"
+        assert op.polled_ids == ["job-001"]
+
+    def test_pending_status_also_skips_submission(self):
+        op, task_state = self._operator_with_prior_job("PENDING")
+
+        op.execute_resumable(make_context(task_state))
+
+        assert op.submitted_ids == []
+        assert op.polled_ids == ["job-001"]
+
+    def test_returns_result_immediately_without_polling(self):
+        op, task_state = self._operator_with_prior_job("SUCCEEDED")
+
+        result = op.execute_resumable(make_context(task_state))
+
+        assert op.submitted_ids == [], "should not resubmit"
+        assert op.polled_ids == [], "should not poll again"
+        assert result == "result-of-job-001"
+
+    @pytest.mark.parametrize("status", ["FAILED", "KILLED", "ERROR", "UNKNOWN"])
+    def test_resubmits_when_prior_job_in_terminal_failure(self, status):
+        op, task_state = self._operator_with_prior_job(status)
+        op._next_id = "job-002"
+
+        op.execute_resumable(make_context(task_state))
+
+        assert op.submitted_ids == ["job-002"], "should resubmit fresh"
+        assert op.polled_ids == ["job-002"]
+
+
+class TestExternalIdKey:
+    """Overriding ``external_id_key`` changes where the external ID is stored in task_state."""
+
+    def test_custom_key_used_for_storage_and_retrieval(self):
+        class CustomKeyOp(ConcreteResumableOperator):
+            external_id_key = "my_custom_key"
+
+        op = CustomKeyOp(task_id="test_task")
+        task_state = FakeTaskState()
+
+        op.execute_resumable(make_context(task_state))
+
+        stored_id = task_state.get("my_custom_key")
+        assert stored_id == "job-001"
diff --git a/uv.lock b/uv.lock
index 83759330bb2..a02c1fca513 100644
--- a/uv.lock
+++ b/uv.lock
@@ -3813,6 +3813,8 @@ dependencies = [
{ name = "apache-airflow-providers-common-compat" },
{ name = "grpcio-status" },
{ name = "pyspark-client" },
+ { name = "requests" },
+ { name = "tenacity" },
]
[package.optional-dependencies]
@@ -3849,6 +3851,8 @@ requires-dist = [
{ name = "grpcio-status", specifier = ">=1.67.0" },
{ name = "pyspark", marker = "extra == 'pyspark'", specifier = ">=4.0.0" },
{ name = "pyspark-client", specifier = ">=4.0.0" },
+ { name = "requests", specifier = ">=2.32.0" },
+ { name = "tenacity", specifier = ">=8.3.0" },
]
provides-extras = ["cncf-kubernetes", "openlineage", "pyspark"]