amoghrajesh commented on code in PR #67118:
URL: https://github.com/apache/airflow/pull/67118#discussion_r3308529578
##########
providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py:
##########
@@ -198,7 +226,122 @@ def execute(self, context: Context) -> None:
self.conf =
inject_transport_information_into_spark_properties(self.conf, context)
if self._hook is None:
self._hook = self._get_hook()
- self._hook.submit(self.application)
+ hook = self._hook
+ if hook._should_track_driver_status:
+ if self.reconnect_on_retry:
+ return self.execute_resumable(context)
+ # reconnect_on_retry=False: still submit-and-poll, just skip
task_state persistence.
+ driver_id = self.submit_job(context)
+ self.poll_until_complete(driver_id, context)
+ return self.get_job_result(driver_id, context)
+ hook.submit(self.application)
+
+ def submit_job(self, context: Context) -> str:
+ if self._hook is None:
+ self._hook = self._get_hook()
+ driver_id = self._hook.submit(self.application)
+ if not driver_id:
+ raise RuntimeError("spark-submit did not return a driver ID")
+ self.log.info("Spark driver submitted: %s", driver_id)
+ return driver_id
+
+ def get_job_status(self, external_id: JsonValue) -> str:
+ # called from submit_job which always returns a str (Spark driver IDs
are strings)
+ external_id = cast("str", external_id)
+ if self._hook is None:
+ self._hook = self._get_hook()
+ # The YARN and K8s branches below (and in is_job_active,
is_job_succeeded, poll_until_complete)
+ # are currently unreachable: execute_resumable is only called when
_should_track_driver_status
+ # is True, which requires spark:// + cluster mode. They are
scaffolding for a follow-up PR
+ # that extends ResumableJobMixin support to YARN and Kubernetes.
+ if self._hook._is_yarn:
+ # TODO: call YARN ResourceManager REST API
+ # GET http://rm:8088/ws/v1/cluster/apps/{external_id}
+ raise NotImplementedError("YARN job status not yet implemented")
+ if self._hook._is_kubernetes:
+ # TODO: call K8s pod status API
+ raise NotImplementedError("K8s job status not yet implemented")
+ scheme = self._hook._connection.get("rest_scheme", "http")
+ rest_port = self._hook._connection.get("rest_port", 6066)
+ # HA master URLs can look like spark://m1:7077,m2:7077 — try each host
in order.
+ # The master URL port (e.g. 7077) is the RPC port — not the REST API
port.
+ # Use rest-port connection extra to override spark.master.rest.port
(default 6066).
+ master_urls = self._hook._connection["master"].replace("spark://",
"").split(",")
+ last_exc: Exception = RuntimeError("No Spark masters to query")
+ for m in master_urls:
+ host = m.strip().split(":")[0]
+ url =
f"{scheme}://{host}:{rest_port}/v1/submissions/status/{external_id}"
+ try:
+ status = self._fetch_driver_status(url, external_id)
+ return status
+ except RuntimeError:
+ raise
Review Comment:
Handled in d844159d18
##########
providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py:
##########
@@ -198,7 +226,122 @@ def execute(self, context: Context) -> None:
self.conf =
inject_transport_information_into_spark_properties(self.conf, context)
if self._hook is None:
self._hook = self._get_hook()
- self._hook.submit(self.application)
+ hook = self._hook
+ if hook._should_track_driver_status:
+ if self.reconnect_on_retry:
+ return self.execute_resumable(context)
+ # reconnect_on_retry=False: still submit-and-poll, just skip
task_state persistence.
+ driver_id = self.submit_job(context)
+ self.poll_until_complete(driver_id, context)
+ return self.get_job_result(driver_id, context)
+ hook.submit(self.application)
+
+ def submit_job(self, context: Context) -> str:
+ if self._hook is None:
+ self._hook = self._get_hook()
+ driver_id = self._hook.submit(self.application)
+ if not driver_id:
+ raise RuntimeError("spark-submit did not return a driver ID")
+ self.log.info("Spark driver submitted: %s", driver_id)
+ return driver_id
+
+ def get_job_status(self, external_id: JsonValue) -> str:
+ # called from submit_job which always returns a str (Spark driver IDs
are strings)
+ external_id = cast("str", external_id)
+ if self._hook is None:
+ self._hook = self._get_hook()
+ # The YARN and K8s branches below (and in is_job_active,
is_job_succeeded, poll_until_complete)
+ # are currently unreachable: execute_resumable is only called when
_should_track_driver_status
+ # is True, which requires spark:// + cluster mode. They are
scaffolding for a follow-up PR
+ # that extends ResumableJobMixin support to YARN and Kubernetes.
+ if self._hook._is_yarn:
+ # TODO: call YARN ResourceManager REST API
+ # GET http://rm:8088/ws/v1/cluster/apps/{external_id}
+ raise NotImplementedError("YARN job status not yet implemented")
+ if self._hook._is_kubernetes:
+ # TODO: call K8s pod status API
+ raise NotImplementedError("K8s job status not yet implemented")
+ scheme = self._hook._connection.get("rest_scheme", "http")
+ rest_port = self._hook._connection.get("rest_port", 6066)
+ # HA master URLs can look like spark://m1:7077,m2:7077 — try each host
in order.
+ # The master URL port (e.g. 7077) is the RPC port — not the REST API
port.
+ # Use rest-port connection extra to override spark.master.rest.port
(default 6066).
+ master_urls = self._hook._connection["master"].replace("spark://",
"").split(",")
+ last_exc: Exception = RuntimeError("No Spark masters to query")
+ for m in master_urls:
+ host = m.strip().split(":")[0]
+ url =
f"{scheme}://{host}:{rest_port}/v1/submissions/status/{external_id}"
+ try:
+ status = self._fetch_driver_status(url, external_id)
+ return status
+ except RuntimeError:
+ raise
+ except Exception as e:
+ self.log.warning("Could not reach Spark master %s: %s", host,
e)
+ last_exc = e
+ raise last_exc
+
+ @retry(stop=stop_after_attempt(3), wait=wait_fixed(1), reraise=True)
+ def _fetch_driver_status(self, url: str, external_id: str) -> str:
+ response = requests.get(url, timeout=30)
+ response.raise_for_status()
+ # "success:false" means the master does not recognise the driver ID or
is in recovery.
+ #
https://github.com/apache/spark/blob/master/core/src/main/scala/org/apache/spark/deploy/master/DriverState.scala
+ data = response.json()
+ if not data.get("success"):
+ raise RuntimeError(
+ f"Spark REST API returned failure for {external_id}:
{data.get('message', 'unknown error')}"
+ )
+ status = data["driverState"]
+ self.log.info("Driver %s status: %s", external_id, status)
+ return status
+
+ def is_job_active(self, status: str) -> bool:
+ if self._hook is None:
+ self._hook = self._get_hook()
+ status = status.upper()
+ if self._hook._is_yarn:
+ #
https://hadoop.apache.org/docs/stable/hadoop-yarn/hadoop-yarn-site/ResourceManagerRest.html
+ return status in ("NEW", "NEW_SAVING", "SUBMITTED", "ACCEPTED",
"RUNNING")
+ if self._hook._is_kubernetes:
+ return status in ("PENDING", "RUNNING")
+ # RELAUNCHING: driver is being restarted after a failure, still alive.
+ # UNKNOWN: master is in failure recovery, state is temporarily
unavailable.
+ #
https://github.com/apache/spark/blob/master/core/src/main/scala/org/apache/spark/deploy/master/DriverState.scala
+ return status in ("SUBMITTED", "RUNNING", "RELAUNCHING", "UNKNOWN")
+
+ def is_job_succeeded(self, status: str) -> bool:
+ if self._hook is None:
+ self._hook = self._get_hook()
+ status = status.upper()
+ if self._hook._is_kubernetes:
+ return status == "SUCCEEDED"
+ # standalone and YARN both use FINISHED
+ return status == "FINISHED"
+
+ def poll_until_complete(self, external_id: JsonValue, context: Context) ->
None:
+ # called from submit_job which always returns a str (Spark driver IDs
are strings)
+ external_id = cast("str", external_id)
+ if self._hook is None:
+ self._hook = self._get_hook()
+ if self._hook._is_yarn:
+ # TODO: poll YARN ResourceManager until app reaches terminal state
+ raise NotImplementedError("YARN poll not yet implemented")
+ if self._hook._is_kubernetes:
+ # TODO: poll K8s pod phase until terminal
+ raise NotImplementedError("K8s poll not yet implemented")
+ self.log.info("Polling driver %s until completion", external_id)
+ self._hook._driver_id = external_id
+ self._hook._driver_status = "SUBMITTED"
+ self._hook._start_driver_status_tracking()
+ if self._hook._driver_status != "FINISHED":
+ raise RuntimeError(f"Driver {external_id} exited with status
{self._hook._driver_status}")
+ # Run post-submit commands here instead of in the hook so they fire
after the job
+ # finishes, not immediately after spark-submit returns the driver ID.
+ self._hook._run_post_submit_commands()
Review Comment:
Handled in d844159d18
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]