kaxil commented on code in PR #67118:
URL: https://github.com/apache/airflow/pull/67118#discussion_r3307489711
##########
providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py:
##########
@@ -198,7 +226,122 @@ def execute(self, context: Context) -> None:
self.conf =
inject_transport_information_into_spark_properties(self.conf, context)
if self._hook is None:
self._hook = self._get_hook()
- self._hook.submit(self.application)
+ hook = self._hook
+ if hook._should_track_driver_status:
+ if self.reconnect_on_retry:
+ return self.execute_resumable(context)
+ # reconnect_on_retry=False: still submit-and-poll, just skip
task_state persistence.
+ driver_id = self.submit_job(context)
+ self.poll_until_complete(driver_id, context)
+ return self.get_job_result(driver_id, context)
+ hook.submit(self.application)
+
+ def submit_job(self, context: Context) -> str:
+ if self._hook is None:
+ self._hook = self._get_hook()
+ driver_id = self._hook.submit(self.application)
+ if not driver_id:
+ raise RuntimeError("spark-submit did not return a driver ID")
+ self.log.info("Spark driver submitted: %s", driver_id)
+ return driver_id
+
+ def get_job_status(self, external_id: JsonValue) -> str:
+ # called from submit_job which always returns a str (Spark driver IDs
are strings)
+ external_id = cast("str", external_id)
+ if self._hook is None:
+ self._hook = self._get_hook()
+ # The YARN and K8s branches below (and in is_job_active,
is_job_succeeded, poll_until_complete)
+ # are currently unreachable: execute_resumable is only called when
_should_track_driver_status
+ # is True, which requires spark:// + cluster mode. They are
scaffolding for a follow-up PR
+ # that extends ResumableJobMixin support to YARN and Kubernetes.
+ if self._hook._is_yarn:
+ # TODO: call YARN ResourceManager REST API
+ # GET http://rm:8088/ws/v1/cluster/apps/{external_id}
+ raise NotImplementedError("YARN job status not yet implemented")
+ if self._hook._is_kubernetes:
+ # TODO: call K8s pod status API
+ raise NotImplementedError("K8s job status not yet implemented")
+ scheme = self._hook._connection.get("rest_scheme", "http")
+ rest_port = self._hook._connection.get("rest_port", 6066)
+ # HA master URLs can look like spark://m1:7077,m2:7077 — try each host
in order.
+ # The master URL port (e.g. 7077) is the RPC port — not the REST API
port.
+ # Use rest-port connection extra to override spark.master.rest.port
(default 6066).
+ master_urls = self._hook._connection["master"].replace("spark://",
"").split(",")
+ last_exc: Exception = RuntimeError("No Spark masters to query")
+ for m in master_urls:
+ host = m.strip().split(":")[0]
+ url =
f"{scheme}://{host}:{rest_port}/v1/submissions/status/{external_id}"
+ try:
+ status = self._fetch_driver_status(url, external_id)
+ return status
+ except RuntimeError:
+ raise
+ except Exception as e:
+ self.log.warning("Could not reach Spark master %s: %s", host,
e)
+ last_exc = e
+ raise last_exc
+
+ @retry(stop=stop_after_attempt(3), wait=wait_fixed(1), reraise=True)
+ def _fetch_driver_status(self, url: str, external_id: str) -> str:
+ response = requests.get(url, timeout=30)
+ response.raise_for_status()
+ # "success:false" means the master does not recognise the driver ID or
is in recovery.
+ #
https://github.com/apache/spark/blob/master/core/src/main/scala/org/apache/spark/deploy/master/DriverState.scala
+ data = response.json()
+ if not data.get("success"):
+ raise RuntimeError(
+ f"Spark REST API returned failure for {external_id}:
{data.get('message', 'unknown error')}"
+ )
+ status = data["driverState"]
+ self.log.info("Driver %s status: %s", external_id, status)
+ return status
+
+ def is_job_active(self, status: str) -> bool:
+ if self._hook is None:
+ self._hook = self._get_hook()
+ status = status.upper()
+ if self._hook._is_yarn:
+ #
https://hadoop.apache.org/docs/stable/hadoop-yarn/hadoop-yarn-site/ResourceManagerRest.html
+ return status in ("NEW", "NEW_SAVING", "SUBMITTED", "ACCEPTED",
"RUNNING")
+ if self._hook._is_kubernetes:
+ return status in ("PENDING", "RUNNING")
+ # RELAUNCHING: driver is being restarted after a failure, still alive.
+ # UNKNOWN: master is in failure recovery, state is temporarily
unavailable.
+ #
https://github.com/apache/spark/blob/master/core/src/main/scala/org/apache/spark/deploy/master/DriverState.scala
+ return status in ("SUBMITTED", "RUNNING", "RELAUNCHING", "UNKNOWN")
+
+ def is_job_succeeded(self, status: str) -> bool:
+ if self._hook is None:
+ self._hook = self._get_hook()
+ status = status.upper()
+ if self._hook._is_kubernetes:
+ return status == "SUCCEEDED"
+ # standalone and YARN both use FINISHED
+ return status == "FINISHED"
+
+ def poll_until_complete(self, external_id: JsonValue, context: Context) ->
None:
+ # called from submit_job which always returns a str (Spark driver IDs
are strings)
+ external_id = cast("str", external_id)
+ if self._hook is None:
+ self._hook = self._get_hook()
+ if self._hook._is_yarn:
+ # TODO: poll YARN ResourceManager until app reaches terminal state
+ raise NotImplementedError("YARN poll not yet implemented")
+ if self._hook._is_kubernetes:
+ # TODO: poll K8s pod phase until terminal
+ raise NotImplementedError("K8s poll not yet implemented")
+ self.log.info("Polling driver %s until completion", external_id)
+ self._hook._driver_id = external_id
+ self._hook._driver_status = "SUBMITTED"
Review Comment:
Forcing `_driver_status = "SUBMITTED"` on the reconnect path discards
information. On that path the mixin only calls `poll_until_complete` after
`is_job_active` has already classified the fetched status as active (e.g.
`RUNNING` or `RELAUNCHING`). Resetting it to `SUBMITTED` here therefore makes
`_start_driver_status_tracking` waste one poll cycle rediscovering a state we
already knew.
Minor, but the line can simply be dropped -- the tracking loop tolerates the
existing value or `None` until the next poll lands.
##########
providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py:
##########
@@ -198,7 +226,122 @@ def execute(self, context: Context) -> None:
self.conf =
inject_transport_information_into_spark_properties(self.conf, context)
if self._hook is None:
self._hook = self._get_hook()
- self._hook.submit(self.application)
+ hook = self._hook
+ if hook._should_track_driver_status:
+ if self.reconnect_on_retry:
+ return self.execute_resumable(context)
+ # reconnect_on_retry=False: still submit-and-poll, just skip
task_state persistence.
+ driver_id = self.submit_job(context)
+ self.poll_until_complete(driver_id, context)
+ return self.get_job_result(driver_id, context)
+ hook.submit(self.application)
+
+ def submit_job(self, context: Context) -> str:
+ if self._hook is None:
+ self._hook = self._get_hook()
+ driver_id = self._hook.submit(self.application)
+ if not driver_id:
+ raise RuntimeError("spark-submit did not return a driver ID")
+ self.log.info("Spark driver submitted: %s", driver_id)
+ return driver_id
+
+ def get_job_status(self, external_id: JsonValue) -> str:
+ # called from submit_job which always returns a str (Spark driver IDs
are strings)
+ external_id = cast("str", external_id)
+ if self._hook is None:
+ self._hook = self._get_hook()
+ # The YARN and K8s branches below (and in is_job_active,
is_job_succeeded, poll_until_complete)
+ # are currently unreachable: execute_resumable is only called when
_should_track_driver_status
+ # is True, which requires spark:// + cluster mode. They are
scaffolding for a follow-up PR
+ # that extends ResumableJobMixin support to YARN and Kubernetes.
+ if self._hook._is_yarn:
+ # TODO: call YARN ResourceManager REST API
+ # GET http://rm:8088/ws/v1/cluster/apps/{external_id}
+ raise NotImplementedError("YARN job status not yet implemented")
+ if self._hook._is_kubernetes:
+ # TODO: call K8s pod status API
+ raise NotImplementedError("K8s job status not yet implemented")
+ scheme = self._hook._connection.get("rest_scheme", "http")
+ rest_port = self._hook._connection.get("rest_port", 6066)
+ # HA master URLs can look like spark://m1:7077,m2:7077 — try each host
in order.
+ # The master URL port (e.g. 7077) is the RPC port — not the REST API
port.
+ # Use rest-port connection extra to override spark.master.rest.port
(default 6066).
+ master_urls = self._hook._connection["master"].replace("spark://",
"").split(",")
+ last_exc: Exception = RuntimeError("No Spark masters to query")
+ for m in master_urls:
+ host = m.strip().split(":")[0]
+ url =
f"{scheme}://{host}:{rest_port}/v1/submissions/status/{external_id}"
+ try:
+ status = self._fetch_driver_status(url, external_id)
+ return status
+ except RuntimeError:
+ raise
+ except Exception as e:
+ self.log.warning("Could not reach Spark master %s: %s", host,
e)
+ last_exc = e
+ raise last_exc
+
+ @retry(stop=stop_after_attempt(3), wait=wait_fixed(1), reraise=True)
+ def _fetch_driver_status(self, url: str, external_id: str) -> str:
+ response = requests.get(url, timeout=30)
+ response.raise_for_status()
+ # "success:false" means the master does not recognise the driver ID or
is in recovery.
+ #
https://github.com/apache/spark/blob/master/core/src/main/scala/org/apache/spark/deploy/master/DriverState.scala
+ data = response.json()
+ if not data.get("success"):
+ raise RuntimeError(
+ f"Spark REST API returned failure for {external_id}:
{data.get('message', 'unknown error')}"
+ )
+ status = data["driverState"]
+ self.log.info("Driver %s status: %s", external_id, status)
+ return status
+
+ def is_job_active(self, status: str) -> bool:
+ if self._hook is None:
+ self._hook = self._get_hook()
+ status = status.upper()
+ if self._hook._is_yarn:
+ #
https://hadoop.apache.org/docs/stable/hadoop-yarn/hadoop-yarn-site/ResourceManagerRest.html
+ return status in ("NEW", "NEW_SAVING", "SUBMITTED", "ACCEPTED",
"RUNNING")
+ if self._hook._is_kubernetes:
+ return status in ("PENDING", "RUNNING")
+ # RELAUNCHING: driver is being restarted after a failure, still alive.
+ # UNKNOWN: master is in failure recovery, state is temporarily
unavailable.
+ #
https://github.com/apache/spark/blob/master/core/src/main/scala/org/apache/spark/deploy/master/DriverState.scala
+ return status in ("SUBMITTED", "RUNNING", "RELAUNCHING", "UNKNOWN")
+
+ def is_job_succeeded(self, status: str) -> bool:
+ if self._hook is None:
+ self._hook = self._get_hook()
+ status = status.upper()
+ if self._hook._is_kubernetes:
+ return status == "SUCCEEDED"
+ # standalone and YARN both use FINISHED
+ return status == "FINISHED"
+
+ def poll_until_complete(self, external_id: JsonValue, context: Context) ->
None:
+ # called from submit_job which always returns a str (Spark driver IDs
are strings)
+ external_id = cast("str", external_id)
+ if self._hook is None:
+ self._hook = self._get_hook()
+ if self._hook._is_yarn:
+ # TODO: poll YARN ResourceManager until app reaches terminal state
+ raise NotImplementedError("YARN poll not yet implemented")
+ if self._hook._is_kubernetes:
+ # TODO: poll K8s pod phase until terminal
+ raise NotImplementedError("K8s poll not yet implemented")
+ self.log.info("Polling driver %s until completion", external_id)
+ self._hook._driver_id = external_id
+ self._hook._driver_status = "SUBMITTED"
+ self._hook._start_driver_status_tracking()
+ if self._hook._driver_status != "FINISHED":
+ raise RuntimeError(f"Driver {external_id} exited with status
{self._hook._driver_status}")
+ # Run post-submit commands here instead of in the hook so they fire
after the job
+ # finishes, not immediately after spark-submit returns the driver ID.
+ self._hook._run_post_submit_commands()
Review Comment:
Following up on the `post_submit_commands` timing thread: after the
relocation, the commands only run on the `FINISHED` path. If the driver exits
`FAILED`/`ERROR`/`KILLED`, the `raise RuntimeError` on line 338 skips this line
entirely. Any exception raised by `_start_driver_status_tracking` itself skips
it as well.
The previous behaviour was unconditional via `submit()`'s `finally:` block.
The documented "Useful for cleaning up sidecars such as Istio" contract relies
on that: Istio still needs to be told to quit when the Spark job fails, not
just when it succeeds. Wrapping lines 333-338 in `try/finally`, with
`_run_post_submit_commands()` moved into the `finally:` clause, would restore
that.
##########
providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py:
##########
@@ -184,9 +207,14 @@ def __init__(
self._conn_id = conn_id
self._use_krb5ccache = use_krb5ccache
+ self.reconnect_on_retry = reconnect_on_retry
self._openlineage_inject_parent_job_info =
openlineage_inject_parent_job_info
self._openlineage_inject_transport_info =
openlineage_inject_transport_info
+ # Generic key used across all Spark deployment modes (standalone driver ID,
+ # YARN application ID, K8s driver pod name).
+ external_id_key = "spark_job_id"
Review Comment:
Class-level attributes are conventionally placed at the top of the class
body alongside `template_fields`, not between `__init__` and `execute`. Tucked
away here, it is hard to find when a subclass author goes looking for what
they need to override.
Suggest moving `external_id_key` (and the comment above it) to just before
`template_fields`, near line 113.
##########
task-sdk/src/airflow/sdk/bases/resumablemixin.py:
##########
@@ -0,0 +1,161 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any
+
+if TYPE_CHECKING:
+ from pydantic import JsonValue
+
+ from airflow.sdk.definitions.context import Context
+ from airflow.sdk.types import Logger
+
+
+class ResumableJobMixin:
+ """
+ Mixin for operators that submit one long-running job to an external system
and poll for completion.
+
+ **Purpose:** This mixin makes the synchronous operator path crash-safe. It
is not a replacement
+ for deferrable operators — deferrable remains the recommended approach for
long-running tasks when
+ a Triggerer is available and the async model fits the team. This mixin is
for teams already running
+ synchronous operators who want worker crashes to reconnect to the existing
job rather than
+ resubmitting a duplicate.
+
+ **How it works:** On the first run, after submitting the job, the external
ID (driver ID, YARN
+ application ID, etc.) is persisted to ``task_state`` before polling
starts. On retry, the mixin
+ reads that ID back and reconnects to the already-running job instead of
starting a new one.
+
+ **What it does not do:** It does not free the worker slot during polling
(use deferrable for that),
+ and it does not stream logs from the remote system (the operator controls
that separately).
+
+ Usage: call ``execute_resumable(context)`` from the operator's
``execute()`` when reconnection
+ is supported.
+
+ Subclasses must implement the methods specific to their external system.
The mixin owns
+ only ``execute_resumable()`` and the task_state read/write logic.
+
+ Example::
+
+ class MyOperator(ResumableJobMixin, BaseOperator):
+ external_id_key = "my_job_id"
+
+ def execute(self, context):
+ return self.execute_resumable(context)
+
+ def submit_job(self, context) -> JsonValue:
+ return self.hook.submit(...)
+
+ def get_job_status(self, external_id: JsonValue) -> str:
+ return self.hook.get_status(external_id)
+
+ def is_job_active(self, status: str) -> bool:
+ return status in ("RUNNING", "PENDING")
+
+ def is_job_succeeded(self, status: str) -> bool:
+ return status == "SUCCEEDED"
+
+ def poll_until_complete(self, external_id: JsonValue, context:
Context) -> None:
+ self.hook.poll(external_id)
+
+ def get_job_result(self, external_id: JsonValue, context: Context)
-> Any:
+ return None
+ """
+
+ if TYPE_CHECKING:
+ # log comes from BaseOperator (via LoggingMixin) at runtime, but mypy
cannot see
+ # that because ResumableJobMixin does not inherit from it directly.
+ log: Logger
+
+ # Key used to store and retrieve the external job ID from task_state
across retries.
+ # Renaming this on a deployed operator breaks in-flight retries — the old
key is already stored.
+ external_id_key: str = "remote_job_id"
+
+ def execute_resumable(self, context: Context) -> Any:
+ """
+ Core of the resumable execution logic. Call this from execute() when
reconnection is supported.
+
+ On initial run: submits the job, persists the external ID to
task_state, then polls.
+
+ Behaviour on retry:
+ - On retry with active job: skips submission, reconnects to the
running job.
+ - On retry with succeeded job: skips submission and polling, returns
result immediately.
+ - On retry with failed job: falls through and resubmits fresh.
+ """
+ task_state = context.get("task_state")
+
+ if task_state is not None:
+ external_id = task_state.get(self.external_id_key)
+ if external_id:
+ status = self.get_job_status(external_id)
+ if self.is_job_active(status):
+ self.log.info(
+ "Reconnecting to existing job identified by: %s
(status: %s)", external_id, status
+ )
+ return self.poll_until_complete(external_id, context)
+ if self.is_job_succeeded(status):
+ self.log.info(
+ "Job with identifier: %s already completed
successfully, skipping resubmission",
+ external_id,
+ )
+ return self.get_job_result(external_id, context)
+ self.log.info(
+ "Prior job with identifier: %s in terminal state %s,
resubmitting fresh",
+ external_id,
+ status,
+ )
+
+ external_id = self.submit_job(context)
+
+ if task_state is not None:
+ task_state.set(self.external_id_key, external_id)
Review Comment:
The two-step submit-then-persist sequence has a window in which the new driver
is already running on the cluster but `task_state` does not yet record its ID.
This happens if `task_state.set` raises, or if the worker dies between line 121
and line 124.

The window affects both submission paths:
- **Post-terminal-resubmit path:** the next retry reads the old (now-stale)
terminal ID and falls through the `Prior job ... terminal state` branch above.
It then submits a third driver -- exactly the duplicate-submission failure mode
this mixin is meant to prevent, just shifted one retry to the right.
- **First-submission path:** `task_state` holds no ID at all. The retry skips
the reconnect check and submits a second driver alongside the orphaned first
one.

Two options:
1. Document the window explicitly in the class docstring as a known
limitation (and rely on `task_state.set` being reliable in practice).
2. Write a "pending submit" marker to `task_state` before `submit_job` runs,
then overwrite it with the real ID afterwards. On retry, a pending marker means
"check the external system before assuming no job was submitted".
##########
providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py:
##########
@@ -198,7 +226,122 @@ def execute(self, context: Context) -> None:
self.conf =
inject_transport_information_into_spark_properties(self.conf, context)
if self._hook is None:
self._hook = self._get_hook()
- self._hook.submit(self.application)
+ hook = self._hook
+ if hook._should_track_driver_status:
+ if self.reconnect_on_retry:
+ return self.execute_resumable(context)
+ # reconnect_on_retry=False: still submit-and-poll, just skip
task_state persistence.
+ driver_id = self.submit_job(context)
+ self.poll_until_complete(driver_id, context)
+ return self.get_job_result(driver_id, context)
+ hook.submit(self.application)
+
+ def submit_job(self, context: Context) -> str:
+ if self._hook is None:
+ self._hook = self._get_hook()
+ driver_id = self._hook.submit(self.application)
+ if not driver_id:
+ raise RuntimeError("spark-submit did not return a driver ID")
+ self.log.info("Spark driver submitted: %s", driver_id)
+ return driver_id
+
+ def get_job_status(self, external_id: JsonValue) -> str:
+ # called from submit_job which always returns a str (Spark driver IDs
are strings)
+ external_id = cast("str", external_id)
+ if self._hook is None:
+ self._hook = self._get_hook()
+ # The YARN and K8s branches below (and in is_job_active,
is_job_succeeded, poll_until_complete)
+ # are currently unreachable: execute_resumable is only called when
_should_track_driver_status
+ # is True, which requires spark:// + cluster mode. They are
scaffolding for a follow-up PR
+ # that extends ResumableJobMixin support to YARN and Kubernetes.
+ if self._hook._is_yarn:
+ # TODO: call YARN ResourceManager REST API
+ # GET http://rm:8088/ws/v1/cluster/apps/{external_id}
+ raise NotImplementedError("YARN job status not yet implemented")
+ if self._hook._is_kubernetes:
+ # TODO: call K8s pod status API
+ raise NotImplementedError("K8s job status not yet implemented")
+ scheme = self._hook._connection.get("rest_scheme", "http")
+ rest_port = self._hook._connection.get("rest_port", 6066)
+ # HA master URLs can look like spark://m1:7077,m2:7077 — try each host
in order.
+ # The master URL port (e.g. 7077) is the RPC port — not the REST API
port.
+ # Use rest-port connection extra to override spark.master.rest.port
(default 6066).
+ master_urls = self._hook._connection["master"].replace("spark://",
"").split(",")
+ last_exc: Exception = RuntimeError("No Spark masters to query")
+ for m in master_urls:
+ host = m.strip().split(":")[0]
+ url =
f"{scheme}://{host}:{rest_port}/v1/submissions/status/{external_id}"
+ try:
+ status = self._fetch_driver_status(url, external_id)
+ return status
+ except RuntimeError:
+ raise
Review Comment:
The `except RuntimeError: raise` here re-raises immediately when
`_fetch_driver_status` raises for `success: false`. But that is exactly the
response a stale master gives during HA recovery for a driver it does not yet
know about -- the case this HA loop is supposed to handle. With masters
configured as `m1,m2,m3`, if `m1` answers `success: false`, the loop never
tries `m2` or `m3`. The task attempt fails and burns a retry.
The existing `test_get_job_status_ha_tries_next_master` only exercises
`ConnectionError`, so it does not catch this. Suggest removing the
`RuntimeError` branch so that failure falls into the same `last_exc = e` path
as the generic `Exception` handler and the loop moves on to the next master.
`last_exc` would then only be raised after all masters have been tried.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]