andrewgodwin commented on a change in pull request #15389:
URL: https://github.com/apache/airflow/pull/15389#discussion_r638965384
##########
File path: airflow/jobs/triggerer_job.py
##########
@@ -0,0 +1,418 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import asyncio
+import importlib
+import os
+import signal
+import sys
+import threading
+import time
+from collections import deque
+from typing import Deque, Dict, List, Optional, Set, Tuple, Type
+
+from airflow.jobs.base_job import BaseJob
+from airflow.models.trigger import Trigger
+from airflow.triggers.base import BaseTrigger, TriggerEvent
+from airflow.typing_compat import TypedDict
+from airflow.utils.asyncio import create_task
+from airflow.utils.log.logging_mixin import LoggingMixin
+
+
+class TriggererJob(BaseJob, LoggingMixin):
+    """
+    TriggererJob continuously runs active triggers in asyncio, watching
+    for them to fire off their events and then dispatching that information
+    to their dependent tasks/DAGs.
+
+    It runs as two threads:
+     - The main thread does DB calls/checkins
+     - A subthread runs all the async code
+    """
+
+    __mapper_args__ = {'polymorphic_identity': 'TriggererJob'}
+
+    partition_ids: Optional[List[int]] = None
+    partition_total: Optional[int] = None
+
+    def __init__(self, partition=None, *args, **kwargs):
+        # Make sure we can actually run
+        if not hasattr(asyncio, "create_task"):
+            raise RuntimeError("The triggerer/deferred operators only work on Python 3.7 and above.")
+        # Call superclass
+        super().__init__(*args, **kwargs)
+        # Decode partition information
+        self.partition_ids, self.partition_total = None, None
+        if partition:
+            self.partition_ids, self.partition_total = self.decode_partition(partition)
+        # Set up runner async thread
+        self.runner = TriggerRunner()
+
+    def decode_partition(self, partition: str) -> Tuple[List[int], int]:
+        """
+        Given a string-format partition specification, returns the list of
+        partition IDs it represents and the partition total.
+        """
+        try:
+            # The partition format is "1,2,3/10" where the numbers before
+            # the slash are the partitions we represent, and the number
+            # after is the total number. Most users will just have a single
+            # partition number, e.g. "2/10".
+            ids_str, total_str = partition.split("/", 1)
+            partition_total = int(total_str)
+            partition_ids = []
+            for id_str in ids_str.split(","):
+                id_number = int(id_str)
+                # Bounds checking (they're 1-indexed, which might catch people out)
+                if id_number <= 0 or id_number > partition_total:
+                    raise ValueError(f"Partition number {id_number} is impossible")
+                partition_ids.append(id_number)
+        except (ValueError, TypeError):
+            raise ValueError(f"Invalid partition specification: {partition}")
+        return partition_ids, partition_total
+
+    def register_signals(self) -> None:
+        """Register signals that stop child processes"""
+        signal.signal(signal.SIGINT, self._exit_gracefully)
+        signal.signal(signal.SIGTERM, self._exit_gracefully)
+
+    def _exit_gracefully(self, signum, frame) -> None:  # pylint: disable=unused-argument
+        """Helper method to clean up processor_agent to avoid leaving orphan processes."""
+        # The first time, try to exit nicely
+        if not self.runner.stop:
+            self.log.info("Exiting gracefully upon receiving signal %s", signum)
+            self.runner.stop = True
+        else:
+            self.log.warning("Forcing exit due to second exit signal %s", signum)
+            sys.exit(os.EX_SOFTWARE)
+
+    def _execute(self) -> None:
+        # Display custom startup ack depending on plurality of partitions
+        if self.partition_ids is None:
+            self.log.info("Starting the triggerer")
+        elif len(self.partition_ids) == 1:
+            self.log.info(
+                "Starting the triggerer (partition %s of %s)", self.partition_ids[0], self.partition_total
+            )
+        else:
+            self.log.info(
+                "Starting the triggerer (partitions %s of %s)", self.partition_ids, self.partition_total
+            )
+
+        try:
+            # Kick off runner thread
+            self.runner.start()
+            # Start our own DB loop in the main thread
+            self._run_trigger_loop()
+        except Exception:  # pylint: disable=broad-except
+            self.log.exception("Exception when executing TriggererJob._run_trigger_loop")
+            raise
+        finally:
+            self.log.info("Waiting for triggers to clean up")
+            # Tell the subthread to stop and then wait for it.
+            # If the user interrupts/terms again, _exit_gracefully will allow them
+            # to force-kill here.
+            self.runner.stop = True
+            self.runner.join()

Review comment:
   I'll put a long one on as a fallback in case SIGINT is sent non-interactively.

-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]
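For context, here is a minimal standalone sketch of the fallback the review comment describes: signal the worker thread to stop, then `join` with a long timeout so a non-interactive SIGINT/SIGTERM (where no second signal will ever arrive) cannot block the main thread forever. `StoppableRunner`, the 30-second value, and the final log line are illustrative assumptions, not the PR's actual `TriggerRunner` code.

```python
import threading
import time


class StoppableRunner(threading.Thread):
    """Toy stand-in for a worker thread that loops until `stop` is set."""

    def __init__(self):
        super().__init__()
        self.stop = False

    def run(self):
        while not self.stop:
            time.sleep(0.1)  # pretend to do async work


runner = StoppableRunner()
runner.start()

# Ask the thread to stop, then wait for it -- but bound the wait with a
# long fallback timeout so a non-interactive SIGINT/SIGTERM cannot hang
# the shutdown path indefinitely.
runner.stop = True
runner.join(timeout=30)  # timeout value is an assumption, not from the PR
if runner.is_alive():
    print("Runner did not shut down within the timeout; exiting anyway")
```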
