nailo2c commented on code in PR #62331:
URL: https://github.com/apache/airflow/pull/62331#discussion_r2913550470
##########
providers/google/src/airflow/providers/google/cloud/triggers/dataproc.py:
##########
@@ -230,6 +230,155 @@ async def run(self):
raise e
+class DataprocSubmitJobDirectTrigger(DataprocBaseTrigger):
+ """
+ Trigger that submits a Dataproc job and polls for its completion.
+
+ Used for direct-to-triggerer functionality where job submission and polling
+ are handled entirely by the triggerer without requiring a worker.
+
+ :param job: The job resource dict to submit.
+ :param project_id: Google Cloud Project where the job is running.
+ :param region: The Cloud Dataproc region in which to handle the request.
+    :param gcp_conn_id: The connection ID used to connect to Google Cloud Platform.
+    :param impersonation_chain: Optional service account to impersonate using short-term credentials.
+    :param polling_interval_seconds: Polling period in seconds to check for the status.
+    :param cancel_on_kill: Flag indicating whether to cancel the job when on_kill is called.
+ :param request_id: Optional unique id used to identify the request.
+ """
+
+ def __init__(
+ self,
+ job: dict,
+ request_id: str | None = None,
+ **kwargs,
+ ):
+ self.job = job
+ self.request_id = request_id
+ self.job_id: str | None = None
+ super().__init__(**kwargs)
+
+ def serialize(self):
+ return (
+
"airflow.providers.google.cloud.triggers.dataproc.DataprocSubmitJobDirectTrigger",
+ {
+ "job": self.job,
+ "request_id": self.request_id,
+ "project_id": self.project_id,
+ "region": self.region,
+ "gcp_conn_id": self.gcp_conn_id,
+ "impersonation_chain": self.impersonation_chain,
+ "polling_interval_seconds": self.polling_interval_seconds,
+ "cancel_on_kill": self.cancel_on_kill,
+ },
+ )
+
+ if not AIRFLOW_V_3_0_PLUS:
+
+ @provide_session
+ def get_task_instance(self, session: Session) -> TaskInstance:
+ """
+ Get the task instance for the current task.
+
+ :param session: Sqlalchemy session
+ """
+ task_instance = session.scalar(
+ select(TaskInstance).where(
+ TaskInstance.dag_id == self.task_instance.dag_id,
+ TaskInstance.task_id == self.task_instance.task_id,
+ TaskInstance.run_id == self.task_instance.run_id,
+ TaskInstance.map_index == self.task_instance.map_index,
+ )
+ )
+ if task_instance is None:
+ raise AirflowException(
Review Comment:
nit: use Python's standard exceptions instead of AirflowException.
https://github.com/apache/airflow/blob/main/contributing-docs/05_pull_requests.rst#dont-raise-airflowexception-directly
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]