amoghrajesh commented on code in PR #67118: URL: https://github.com/apache/airflow/pull/67118#discussion_r3308530233
########## task-sdk/src/airflow/sdk/bases/resumablemixin.py: ########## @@ -0,0 +1,161 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from pydantic import JsonValue + + from airflow.sdk.definitions.context import Context + from airflow.sdk.types import Logger + + +class ResumableJobMixin: + """ + Mixin for operators that submit one long-running job to an external system and poll for completion. + + **Purpose:** This mixin makes the synchronous operator path crash-safe. It is not a replacement + for deferrable operators — deferrable remains the recommended approach for long-running tasks when + a Triggerer is available and the async model fits the team. This mixin is for teams already running + synchronous operators who want worker crashes to reconnect to the existing job rather than + resubmitting a duplicate. + + **How it works:** On the first run, after submitting the job, the external ID (driver ID, YARN + application ID, etc.) is persisted to ``task_state`` before polling starts. On retry, the mixin + reads that ID back and reconnects to the already-running job instead of starting a new one. + + **What it does not do:** It does not free the worker slot during polling (use deferrable for that), + and it does not stream logs from the remote system (the operator controls that separately). + + Usage: call ``execute_resumable(context)`` from the operator's ``execute()`` when reconnection + is supported. + + Subclasses must implement the methods specific to their external system. The mixin owns + only ``execute_resumable()`` and the task_state read/write logic. + + Example:: + + class MyOperator(ResumableJobMixin, BaseOperator): + external_id_key = "my_job_id" + + def execute(self, context): + return self.execute_resumable(context) + + def submit_job(self, context) -> JsonValue: + return self.hook.submit(...) + + def get_job_status(self, external_id: JsonValue) -> str: + return self.hook.get_status(external_id) + + def is_job_active(self, status: str) -> bool: + return status in ("RUNNING", "PENDING") + + def is_job_succeeded(self, status: str) -> bool: + return status == "SUCCEEDED" + + def poll_until_complete(self, external_id: JsonValue, context: Context) -> None: + self.hook.poll(external_id) + + def get_job_result(self, external_id: JsonValue, context: Context) -> Any: + return None + """ + + if TYPE_CHECKING: + # log comes from BaseOperator (via LoggingMixin) at runtime, but mypy cannot see + # that because ResumableJobMixin does not inherit from it directly. + log: Logger + + # Key used to store and retrieve the external job ID from task_state across retries. + # Renaming this on a deployed operator breaks in-flight retries — the old key is already stored. + external_id_key: str = "remote_job_id" + + def execute_resumable(self, context: Context) -> Any: + """ + Core of the resumable execution logic. Call this from execute() when reconnection is supported. + + On initial run: submits the job, persists the external ID to task_state, then polls. + + Behaviour on retry: + - On retry with active job: skips submission, reconnects to the running job. + - On retry with succeeded job: skips submission and polling, returns result immediately. + - On retry with failed job: falls through and resubmits fresh. + """ + task_state = context.get("task_state") + + if task_state is not None: + external_id = task_state.get(self.external_id_key) + if external_id: + status = self.get_job_status(external_id) + if self.is_job_active(status): + self.log.info( + "Reconnecting to existing job identified by: %s (status: %s)", external_id, status + ) + return self.poll_until_complete(external_id, context) + if self.is_job_succeeded(status): + self.log.info( + "Job with identifier: %s already completed successfully, skipping resubmission", + external_id, + ) + return self.get_job_result(external_id, context) + self.log.info( + "Prior job with identifier: %s in terminal state %s, resubmitting fresh", + external_id, + status, + ) + + external_id = self.submit_job(context) + + if task_state is not None: + task_state.set(self.external_id_key, external_id) Review Comment: Handled it in d844159d18 ########## providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py: ########## @@ -198,7 +226,122 @@ def execute(self, context: Context) -> None: self.conf = inject_transport_information_into_spark_properties(self.conf, context) if self._hook is None: self._hook = self._get_hook() - self._hook.submit(self.application) + hook = self._hook + if hook._should_track_driver_status: + if self.reconnect_on_retry: + return self.execute_resumable(context) + # reconnect_on_retry=False: still submit-and-poll, just skip task_state persistence. + driver_id = self.submit_job(context) + self.poll_until_complete(driver_id, context) + return self.get_job_result(driver_id, context) + hook.submit(self.application) + + def submit_job(self, context: Context) -> str: + if self._hook is None: + self._hook = self._get_hook() + driver_id = self._hook.submit(self.application) + if not driver_id: + raise RuntimeError("spark-submit did not return a driver ID") + self.log.info("Spark driver submitted: %s", driver_id) + return driver_id + + def get_job_status(self, external_id: JsonValue) -> str: + # called from submit_job which always returns a str (Spark driver IDs are strings) + external_id = cast("str", external_id) + if self._hook is None: + self._hook = self._get_hook() + # The YARN and K8s branches below (and in is_job_active, is_job_succeeded, poll_until_complete) + # are currently unreachable: execute_resumable is only called when _should_track_driver_status + # is True, which requires spark:// + cluster mode. They are scaffolding for a follow-up PR + # that extends ResumableJobMixin support to YARN and Kubernetes. + if self._hook._is_yarn: + # TODO: call YARN ResourceManager REST API + # GET http://rm:8088/ws/v1/cluster/apps/{external_id} + raise NotImplementedError("YARN job status not yet implemented") + if self._hook._is_kubernetes: + # TODO: call K8s pod status API + raise NotImplementedError("K8s job status not yet implemented") + scheme = self._hook._connection.get("rest_scheme", "http") + rest_port = self._hook._connection.get("rest_port", 6066) + # HA master URLs can look like spark://m1:7077,m2:7077 — try each host in order. + # The master URL port (e.g. 7077) is the RPC port — not the REST API port. + # Use rest-port connection extra to override spark.master.rest.port (default 6066). + master_urls = self._hook._connection["master"].replace("spark://", "").split(",") + last_exc: Exception = RuntimeError("No Spark masters to query") + for m in master_urls: + host = m.strip().split(":")[0] + url = f"{scheme}://{host}:{rest_port}/v1/submissions/status/{external_id}" + try: + status = self._fetch_driver_status(url, external_id) + return status + except RuntimeError: + raise + except Exception as e: + self.log.warning("Could not reach Spark master %s: %s", host, e) + last_exc = e + raise last_exc + + @retry(stop=stop_after_attempt(3), wait=wait_fixed(1), reraise=True) + def _fetch_driver_status(self, url: str, external_id: str) -> str: + response = requests.get(url, timeout=30) + response.raise_for_status() + # "success:false" means the master does not recognise the driver ID or is in recovery. + # https://github.com/apache/spark/blob/master/core/src/main/scala/org/apache/spark/deploy/master/DriverState.scala + data = response.json() + if not data.get("success"): + raise RuntimeError( + f"Spark REST API returned failure for {external_id}: {data.get('message', 'unknown error')}" + ) + status = data["driverState"] + self.log.info("Driver %s status: %s", external_id, status) + return status + + def is_job_active(self, status: str) -> bool: + if self._hook is None: + self._hook = self._get_hook() + status = status.upper() + if self._hook._is_yarn: + # https://hadoop.apache.org/docs/stable/hadoop-yarn/hadoop-yarn-site/ResourceManagerRest.html + return status in ("NEW", "NEW_SAVING", "SUBMITTED", "ACCEPTED", "RUNNING") + if self._hook._is_kubernetes: + return status in ("PENDING", "RUNNING") + # RELAUNCHING: driver is being restarted after a failure, still alive. + # UNKNOWN: master is in failure recovery, state is temporarily unavailable. + # https://github.com/apache/spark/blob/master/core/src/main/scala/org/apache/spark/deploy/master/DriverState.scala + return status in ("SUBMITTED", "RUNNING", "RELAUNCHING", "UNKNOWN") + + def is_job_succeeded(self, status: str) -> bool: + if self._hook is None: + self._hook = self._get_hook() + status = status.upper() + if self._hook._is_kubernetes: + return status == "SUCCEEDED" + # standalone and YARN both use FINISHED + return status == "FINISHED" + + def poll_until_complete(self, external_id: JsonValue, context: Context) -> None: + # called from submit_job which always returns a str (Spark driver IDs are strings) + external_id = cast("str", external_id) + if self._hook is None: + self._hook = self._get_hook() + if self._hook._is_yarn: + # TODO: poll YARN ResourceManager until app reaches terminal state + raise NotImplementedError("YARN poll not yet implemented") + if self._hook._is_kubernetes: + # TODO: poll K8s pod phase until terminal + raise NotImplementedError("K8s poll not yet implemented") + self.log.info("Polling driver %s until completion", external_id) + self._hook._driver_id = external_id + self._hook._driver_status = "SUBMITTED" Review Comment: Handled it in: d844159d18 ########## providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py: ########## @@ -184,9 +207,14 @@ def __init__( self._conn_id = conn_id self._use_krb5ccache = use_krb5ccache + self.reconnect_on_retry = reconnect_on_retry self._openlineage_inject_parent_job_info = openlineage_inject_parent_job_info self._openlineage_inject_transport_info = openlineage_inject_transport_info + # Generic key used across all Spark deployment modes (standalone driver ID, + # YARN application ID, K8s driver pod name). + external_id_key = "spark_job_id" Review Comment: Handled it in: d844159d18 -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
