ashb commented on a change in pull request #15389:
URL: https://github.com/apache/airflow/pull/15389#discussion_r638663257
########## File path: airflow/jobs/triggerer_job.py ########## @@ -0,0 +1,418 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import asyncio +import importlib +import os +import signal +import sys +import threading +import time +from collections import deque +from typing import Deque, Dict, List, Optional, Set, Tuple, Type + +from airflow.jobs.base_job import BaseJob +from airflow.models.trigger import Trigger +from airflow.triggers.base import BaseTrigger, TriggerEvent +from airflow.typing_compat import TypedDict +from airflow.utils.asyncio import create_task +from airflow.utils.log.logging_mixin import LoggingMixin + + +class TriggererJob(BaseJob, LoggingMixin): + """ + TriggererJob continuously runs active triggers in asyncio, watching + for them to fire off their events and then dispatching that information + to their dependent tasks/DAGs. + + It runs as two threads: + - The main thread does DB calls/checkins + - A subthread runs all the async code + """ + + __mapper_args__ = {'polymorphic_identity': 'TriggererJob'} + + partition_ids: Optional[List[int]] = None + partition_total: Optional[int] = None + + def __init__(self, partition=None, *args, **kwargs): + # Make sure we can actually run + if not hasattr(asyncio, "create_task"): + raise RuntimeError("The triggerer/deferred operators only work on Python 3.7 and above.") + # Call superclass + super().__init__(*args, **kwargs) + # Decode partition information + self.partition_ids, self.partition_total = None, None + if partition: + self.partition_ids, self.partition_total = self.decode_partition(partition) + # Set up runner async thread + self.runner = TriggerRunner() + + def decode_partition(self, partition: str) -> Tuple[List[int], int]: + """ + Given a string-format partition specification, returns the list of + partition IDs it represents and the partition total. + """ + try: + # The partition format is "1,2,3/10" where the numbers before + # the slash are the partitions we represent, and the number + # after is the total number. Most users will just have a single + # partition number, e.g. "2/10". 
+ ids_str, total_str = partition.split("/", 1) + partition_total = int(total_str) + partition_ids = [] + for id_str in ids_str.split(","): + id_number = int(id_str) + # Bounds checking (they're 1-indexed, which might catch people out) + if id_number <= 0 or id_number > self.partition_total: + raise ValueError(f"Partition number {id_number} is impossible") + self.partition_ids.append(id_number) + except (ValueError, TypeError): + raise ValueError(f"Invalid partition specification: {partition}") + return partition_ids, partition_total + + def register_signals(self) -> None: + """Register signals that stop child processes""" + signal.signal(signal.SIGINT, self._exit_gracefully) + signal.signal(signal.SIGTERM, self._exit_gracefully) + + def _exit_gracefully(self, signum, frame) -> None: # pylint: disable=unused-argument + """Helper method to clean up processor_agent to avoid leaving orphan processes.""" + # The first time, try to exit nicely + if not self.runner.stop: + self.log.info("Exiting gracefully upon receiving signal %s", signum) + self.runner.stop = True + else: + self.log.warning("Forcing exit due to second exit signal %s", signum) + sys.exit(os.EX_SOFTWARE) + + def _execute(self) -> None: + # Display custom startup ack depending on plurality of partitions + if self.partition_ids is None: + self.log.info("Starting the triggerer") + elif len(self.partition_ids) == 1: + self.log.info( + "Starting the triggerer (partition %s of %s)", self.partition_ids[0], self.partition_total + ) + else: + self.log.info( + "Starting the triggerer (partitions %s of %s)", self.partition_ids, self.partition_total + ) + + try: + # Kick off runner thread + self.runner.start() + # Start our own DB loop in the main thread + self._run_trigger_loop() + except Exception: # pylint: disable=broad-except + self.log.exception("Exception when executing TriggererJob._run_trigger_loop") + raise + finally: + self.log.info("Waiting for triggers to clean up") + # Tell the subthread to stop and then wait for it. + # If the user interrupts/terms again, _graceful_exit will allow them + # to force-kill here. + self.runner.stop = True + self.runner.join() Review comment: Worth having a timeout on join? ########## File path: airflow/models/baseoperator.py ########## @@ -1534,6 +1535,23 @@ def inherits_from_dummy_operator(self): # of its sub-classes (which don't inherit from anything but BaseOperator). return getattr(self, '_is_dummy', False) + def defer( + self, + *, + trigger: BaseTrigger, + method_name: str, + kwargs: Optional[Dict[str, Any]] = None, + timeout: Optional[timedelta] = None, + ): + """ + Marks this Operator as being "deferred" - that is, suspending its + execution until the provided trigger fires an event. + + This is achieved by raising a special exception (OperatorDeferred) + which is caught in the main _execute_task wrapper. + """ + raise TaskDeferred(trigger=trigger, method_name=method_name, kwargs=kwargs, timeout=timeout) Review comment: Is it worth doing a check for Python <3.7 here and _failing_ the task instead? (Rather than going to deferred but not having the trigger ever be able to run?) ########## File path: airflow/models/baseoperator.py ########## @@ -1534,6 +1535,23 @@ def inherits_from_dummy_operator(self): # of its sub-classes (which don't inherit from anything but BaseOperator). 
return getattr(self, '_is_dummy', False) + def defer( + self, + *, + trigger: BaseTrigger, + method_name: str, + kwargs: Optional[Dict[str, Any]] = None, + timeout: Optional[timedelta] = None, + ): + """ + Marks this Operator as being "deferred" - that is, suspending its + execution until the provided trigger fires an event. + + This is achieved by raising a special exception (OperatorDeferred) Review comment: ```suggestion This is achieved by raising a special exception (TaskDeferred) ``` ########## File path: airflow/jobs/triggerer_job.py ########## @@ -0,0 +1,418 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import asyncio +import importlib +import os +import signal +import sys +import threading +import time +from collections import deque +from typing import Deque, Dict, List, Optional, Set, Tuple, Type + +from airflow.jobs.base_job import BaseJob +from airflow.models.trigger import Trigger +from airflow.triggers.base import BaseTrigger, TriggerEvent +from airflow.typing_compat import TypedDict +from airflow.utils.asyncio import create_task +from airflow.utils.log.logging_mixin import LoggingMixin + + +class TriggererJob(BaseJob, LoggingMixin): Review comment: ```suggestion class TriggererJob(BaseJob): ``` BaseJob already has LoggingMixin ########## File path: airflow/models/dag.py ########## @@ -1177,8 +1179,10 @@ def clear( :type dry_run: bool :param session: The sqlalchemy session to use :type session: sqlalchemy.orm.session.Session - :param get_tis: Return the sqlalchemy query for finding the TaskInstance without clearing the tasks - :type get_tis: bool + :param get_ti_keys: Return the sqlalchemy query for TaskInstance PKs without clearing the tasks + :type get_ti_keys: bool + :param get_ti_instances: Return the sqlalchemy query for TaskInstances without clearing the tasks + :type get_ti_instances: bool Review comment: These almost feel like it should be a different method, and clear should be made to call that instead. ########## File path: airflow/jobs/triggerer_job.py ########## @@ -0,0 +1,418 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. + +import asyncio +import importlib +import os +import signal +import sys +import threading +import time +from collections import deque +from typing import Deque, Dict, List, Optional, Set, Tuple, Type + +from airflow.jobs.base_job import BaseJob +from airflow.models.trigger import Trigger +from airflow.triggers.base import BaseTrigger, TriggerEvent +from airflow.typing_compat import TypedDict +from airflow.utils.asyncio import create_task +from airflow.utils.log.logging_mixin import LoggingMixin + + +class TriggererJob(BaseJob, LoggingMixin): + """ + TriggererJob continuously runs active triggers in asyncio, watching + for them to fire off their events and then dispatching that information + to their dependent tasks/DAGs. + + It runs as two threads: + - The main thread does DB calls/checkins + - A subthread runs all the async code + """ + + __mapper_args__ = {'polymorphic_identity': 'TriggererJob'} + + partition_ids: Optional[List[int]] = None + partition_total: Optional[int] = None + + def __init__(self, partition=None, *args, **kwargs): + # Make sure we can actually run + if not hasattr(asyncio, "create_task"): + raise RuntimeError("The triggerer/deferred operators only work on Python 3.7 and above.") + # Call superclass + super().__init__(*args, **kwargs) + # Decode partition information + self.partition_ids, self.partition_total = None, None + if partition: + self.partition_ids, self.partition_total = self.decode_partition(partition) + # Set up runner async thread + self.runner = TriggerRunner() + + def decode_partition(self, partition: str) -> Tuple[List[int], int]: + """ + Given a string-format partition specification, returns the list of + partition IDs it represents and the partition total. + """ + try: + # The partition format is "1,2,3/10" where the numbers before + # the slash are the partitions we represent, and the number + # after is the total number. Most users will just have a single + # partition number, e.g. "2/10". 
+ ids_str, total_str = partition.split("/", 1) + partition_total = int(total_str) + partition_ids = [] + for id_str in ids_str.split(","): + id_number = int(id_str) + # Bounds checking (they're 1-indexed, which might catch people out) + if id_number <= 0 or id_number > self.partition_total: + raise ValueError(f"Partition number {id_number} is impossible") + self.partition_ids.append(id_number) + except (ValueError, TypeError): + raise ValueError(f"Invalid partition specification: {partition}") + return partition_ids, partition_total + + def register_signals(self) -> None: + """Register signals that stop child processes""" + signal.signal(signal.SIGINT, self._exit_gracefully) + signal.signal(signal.SIGTERM, self._exit_gracefully) + + def _exit_gracefully(self, signum, frame) -> None: # pylint: disable=unused-argument + """Helper method to clean up processor_agent to avoid leaving orphan processes.""" + # The first time, try to exit nicely + if not self.runner.stop: + self.log.info("Exiting gracefully upon receiving signal %s", signum) + self.runner.stop = True + else: + self.log.warning("Forcing exit due to second exit signal %s", signum) + sys.exit(os.EX_SOFTWARE) + + def _execute(self) -> None: + # Display custom startup ack depending on plurality of partitions + if self.partition_ids is None: + self.log.info("Starting the triggerer") + elif len(self.partition_ids) == 1: + self.log.info( + "Starting the triggerer (partition %s of %s)", self.partition_ids[0], self.partition_total + ) + else: + self.log.info( + "Starting the triggerer (partitions %s of %s)", self.partition_ids, self.partition_total + ) + + try: + # Kick off runner thread + self.runner.start() + # Start our own DB loop in the main thread + self._run_trigger_loop() + except Exception: # pylint: disable=broad-except + self.log.exception("Exception when executing TriggererJob._run_trigger_loop") + raise + finally: + self.log.info("Waiting for triggers to clean up") + # Tell the subthread to stop and then wait for it. + # If the user interrupts/terms again, _graceful_exit will allow them + # to force-kill here. + self.runner.stop = True + self.runner.join() + self.log.info("Exited trigger loop") + + def _run_trigger_loop(self) -> None: + """ + The main-thread trigger loop. + + This runs synchronously and handles all database reads/writes. + """ + while not self.runner.stop: + # Clean out unused triggers + Trigger.clean_unused() + # Load/delete triggers + self.load_triggers() + # Handle events + self.handle_events() + # Handle failed triggers + self.handle_failed_triggers() + # Idle sleep + time.sleep(1) + + def load_triggers(self): + """ + Queries the database to get the triggers we're supposed to be running, + adds them to our runner, and then removes ones from it we no longer + need. + """ + requested_trigger_ids = Trigger.runnable_ids( + partition_ids=self.partition_ids, partition_total=self.partition_total + ) + self.runner.update_triggers(set(requested_trigger_ids)) + + def handle_events(self): + """ + Handles outbound events from triggers - dispatching them into the Trigger + model where they are then pushed into the relevant task instances. + """ + while self.runner.events: + # Get the event and its trigger ID + trigger_id, event = self.runner.events.popleft() + # Tell the model to wake up its tasks + Trigger.submit_event(trigger_id=trigger_id, event=event) + + def handle_failed_triggers(self): + """ + Handles "failed" triggers - ones that errored or exited before they + sent an event. 
Task Instances that depend on them need failing. + """ + while self.runner.failed_triggers: + # Tell the model to fail this trigger's deps + trigger_id = self.runner.failed_triggers.popleft() + Trigger.submit_failure(trigger_id=trigger_id) + + +class TriggerDetails(TypedDict): + """Type class for the trigger details dictionary""" + + task: asyncio.Task + name: str + events: int + + +class TriggerRunner(threading.Thread, LoggingMixin): + """ + Runtime environment for all triggers. + + Mainly runs inside its own thread, where it hands control off to an asyncio + event loop, but is also sometimes interacted with from the main thread + (where all the DB queries are done). All communication between threads is + done via Deques. + """ + + # Maps trigger IDs to their running tasks and other info + triggers: Dict[int, TriggerDetails] + + # Cache for looking up triggers by classpath + trigger_cache: Dict[str, Type[BaseTrigger]] + + # Inbound queue of new triggers + to_create: Deque[Tuple[int, BaseTrigger]] + + # Inbound queue of deleted triggers + to_delete: Deque[int] + + # Outbound queue of events + events: Deque[Tuple[int, TriggerEvent]] + + # Outbound queue of failed triggers + failed_triggers: Deque[int] + + # Should-we-stop flag + stop: bool = False + + def __init__(self): + super().__init__() + self.triggers = {} + self.trigger_cache = {} + self.to_create = deque() + self.to_delete = deque() + self.events = deque() + self.failed_triggers = deque() + + def run(self): + """Sync entrypoint - just runs arun in an async loop.""" + # Pylint complains about this with a 3.6 base, can remove with 3.7+ + asyncio.run(self.arun()) # pylint: disable=no-member + + async def arun(self): + """ + Main (asynchronous) logic loop. + + The loop in here runs trigger addition/deletion/cleanup. Actual + triggers run in their own separate coroutines. + """ + watchdog = create_task(self.block_watchdog()) + last_status = time.time() + while not self.stop: + # Run core logic + await self.create_triggers() + await self.delete_triggers() + await self.cleanup_finished_triggers() + # Sleep for a bit + await asyncio.sleep(1) + # Every minute, log status + if time.time() - last_status >= 60: + self.log.info("%i triggers currently running", len(self.triggers)) + last_status = time.time() + # Wait for watchdog to complete + await watchdog + + async def create_triggers(self): + """ + Drain the to_create queue and create all triggers that have been + requested in the DB that we don't yet have. + """ + while self.to_create: + trigger_id, trigger_instance = self.to_create.popleft() + if trigger_id not in self.triggers: + self.triggers[trigger_id] = { + "task": create_task(self.run_trigger(trigger_id, trigger_instance)), + "name": f"{trigger_instance!r} (ID {trigger_id})", + "events": 0, + } + else: + self.log.warning("Trigger %s had insertion attempted twice", trigger_id) + + async def delete_triggers(self): + """ + Drain the to_delete queue and ensure all triggers that are not in the + DB are cancelled, so the cleanup job deletes them. + """ + while self.to_delete: + trigger_id = self.to_delete.popleft() + if trigger_id in self.triggers: + # We only delete if it did not exit already + self.triggers[trigger_id]["task"].cancel() + + async def cleanup_finished_triggers(self): + """ + Go through all trigger tasks (coroutines) and clean up entries for + ones that have exited, optionally warning users if the exit was + not normal. 
+ """ + for trigger_id, details in list(self.triggers.items()): # pylint: disable=too-many-nested-blocks + if details["task"].done(): + # Check to see if it exited for good reasons + try: + result = details["task"].result() + except (asyncio.CancelledError, SystemExit, KeyboardInterrupt): + # These are "expected" exceptions and we stop processing here + # If we don't, then the system requesting a trigger be removed - + # which turns into CancelledError - results in a failure. + del self.triggers[trigger_id] + continue + except BaseException as e: + # This is potentially bad, so log it. + self.log.error("Trigger %s exited with error %s", details["name"], e) + else: + # See if they foolishly returned a TriggerEvent + if isinstance(result, TriggerEvent): + self.log.error( + "Trigger %s returned a TriggerEvent rather than yielding it", details["name"] + ) + # See if this exited without sending an event, in which case + # any task instances depending on it need to be failed + if details["events"] == 0: + self.log.error( + "Trigger %s exited without sending an event. Dependent tasks will be failed.", + details["name"], + ) + self.failed_triggers.append(trigger_id) + del self.triggers[trigger_id] + + async def block_watchdog(self): + """ + Watchdog loop that detects blocking (badly-written) triggers. + + Triggers should be well-behaved async coroutines and await whenever + they need to wait; this loop tries to run every 100ms to see if + there are badly-written triggers taking longer than that and blocking + the event loop. + + Unfortunately, we can't tell what trigger is blocking things, but + we can at least detect the top-level problem. + """ + while not self.stop: + last_run = time.monotonic() + await asyncio.sleep(0.1) + # We allow a generous amount of buffer room for now, since it might + # be a busy event loop. + time_elapsed = time.monotonic() - last_run + if time_elapsed > 0.2: + self.log.error( + "Triggerer's async thread was blocked for %.2f seconds, " + "likely by a badly-written trigger. Set PYTHONASYNCIODEBUG=1 " + "to get more information on overrunning coroutines.", + time_elapsed, + ) + + # Async trigger logic + + async def run_trigger(self, trigger_id, trigger): + """ + Wrapper which runs an actual trigger (they are async generators) + and pushes their events into our outbound event deque. + """ + self.log.info("Trigger %s starting", self.triggers[trigger_id]['name']) + try: + async for event in trigger.run(): + self.log.info("Trigger %s fired: %s", self.triggers[trigger_id]['name'], event) + self.triggers[trigger_id]["events"] += 1 + self.events.append((trigger_id, event)) + finally: + # CancelledError will get injected when we're stopped - which is + # fine, the cleanup process will understand that, but we want to + # allow triggers a chance to cleanup, either in that case or if + # they exit cleanly. + trigger.cleanup() + + # Main-thread sync API + + def update_triggers(self, requested_trigger_ids: Set[int]): + """ + Called from the main thread to request that we update what + triggers we're running. + + Works out the differences - ones to add, and ones to remove - then + adds them to the deques so the subthread can actually mutate the running + trigger set. + """ + current_trigger_ids = set(self.triggers.keys()) Review comment: This looks like it's prone to a race condition: This code is run from the main thread, and the aio thread could also be mutating this via `cleanup_finished_triggers`. 
Now I'm not familiar with Python's threading access model, but I'd _guess_ this needs a lock to protect access to this variable? ########## File path: airflow/jobs/triggerer_job.py ########## @@ -0,0 +1,418 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import asyncio +import importlib +import os +import signal +import sys +import threading +import time +from collections import deque +from typing import Deque, Dict, List, Optional, Set, Tuple, Type + +from airflow.jobs.base_job import BaseJob +from airflow.models.trigger import Trigger +from airflow.triggers.base import BaseTrigger, TriggerEvent +from airflow.typing_compat import TypedDict +from airflow.utils.asyncio import create_task +from airflow.utils.log.logging_mixin import LoggingMixin + + +class TriggererJob(BaseJob, LoggingMixin): + """ + TriggererJob continuously runs active triggers in asyncio, watching + for them to fire off their events and then dispatching that information + to their dependent tasks/DAGs. + + It runs as two threads: + - The main thread does DB calls/checkins + - A subthread runs all the async code + """ + + __mapper_args__ = {'polymorphic_identity': 'TriggererJob'} + + partition_ids: Optional[List[int]] = None + partition_total: Optional[int] = None + + def __init__(self, partition=None, *args, **kwargs): + # Make sure we can actually run + if not hasattr(asyncio, "create_task"): + raise RuntimeError("The triggerer/deferred operators only work on Python 3.7 and above.") + # Call superclass + super().__init__(*args, **kwargs) + # Decode partition information + self.partition_ids, self.partition_total = None, None + if partition: + self.partition_ids, self.partition_total = self.decode_partition(partition) + # Set up runner async thread + self.runner = TriggerRunner() + + def decode_partition(self, partition: str) -> Tuple[List[int], int]: + """ + Given a string-format partition specification, returns the list of + partition IDs it represents and the partition total. + """ + try: + # The partition format is "1,2,3/10" where the numbers before + # the slash are the partitions we represent, and the number + # after is the total number. Most users will just have a single + # partition number, e.g. "2/10". 
+ ids_str, total_str = partition.split("/", 1) + partition_total = int(total_str) + partition_ids = [] + for id_str in ids_str.split(","): + id_number = int(id_str) + # Bounds checking (they're 1-indexed, which might catch people out) + if id_number <= 0 or id_number > self.partition_total: + raise ValueError(f"Partition number {id_number} is impossible") + self.partition_ids.append(id_number) + except (ValueError, TypeError): + raise ValueError(f"Invalid partition specification: {partition}") + return partition_ids, partition_total + + def register_signals(self) -> None: + """Register signals that stop child processes""" + signal.signal(signal.SIGINT, self._exit_gracefully) + signal.signal(signal.SIGTERM, self._exit_gracefully) + + def _exit_gracefully(self, signum, frame) -> None: # pylint: disable=unused-argument + """Helper method to clean up processor_agent to avoid leaving orphan processes.""" + # The first time, try to exit nicely + if not self.runner.stop: + self.log.info("Exiting gracefully upon receiving signal %s", signum) + self.runner.stop = True + else: + self.log.warning("Forcing exit due to second exit signal %s", signum) + sys.exit(os.EX_SOFTWARE) + + def _execute(self) -> None: + # Display custom startup ack depending on plurality of partitions + if self.partition_ids is None: + self.log.info("Starting the triggerer") + elif len(self.partition_ids) == 1: + self.log.info( + "Starting the triggerer (partition %s of %s)", self.partition_ids[0], self.partition_total + ) + else: + self.log.info( + "Starting the triggerer (partitions %s of %s)", self.partition_ids, self.partition_total + ) + + try: + # Kick off runner thread + self.runner.start() + # Start our own DB loop in the main thread + self._run_trigger_loop() + except Exception: # pylint: disable=broad-except + self.log.exception("Exception when executing TriggererJob._run_trigger_loop") + raise + finally: + self.log.info("Waiting for triggers to clean up") + # Tell the subthread to stop and then wait for it. + # If the user interrupts/terms again, _graceful_exit will allow them + # to force-kill here. + self.runner.stop = True + self.runner.join() + self.log.info("Exited trigger loop") + + def _run_trigger_loop(self) -> None: + """ + The main-thread trigger loop. + + This runs synchronously and handles all database reads/writes. + """ + while not self.runner.stop: + # Clean out unused triggers + Trigger.clean_unused() + # Load/delete triggers + self.load_triggers() + # Handle events + self.handle_events() + # Handle failed triggers + self.handle_failed_triggers() + # Idle sleep + time.sleep(1) + + def load_triggers(self): + """ + Queries the database to get the triggers we're supposed to be running, + adds them to our runner, and then removes ones from it we no longer + need. + """ + requested_trigger_ids = Trigger.runnable_ids( + partition_ids=self.partition_ids, partition_total=self.partition_total + ) + self.runner.update_triggers(set(requested_trigger_ids)) + + def handle_events(self): + """ + Handles outbound events from triggers - dispatching them into the Trigger + model where they are then pushed into the relevant task instances. + """ + while self.runner.events: + # Get the event and its trigger ID + trigger_id, event = self.runner.events.popleft() + # Tell the model to wake up its tasks + Trigger.submit_event(trigger_id=trigger_id, event=event) + + def handle_failed_triggers(self): + """ + Handles "failed" triggers - ones that errored or exited before they + sent an event. 
Task Instances that depend on them need failing. + """ + while self.runner.failed_triggers: + # Tell the model to fail this trigger's deps + trigger_id = self.runner.failed_triggers.popleft() + Trigger.submit_failure(trigger_id=trigger_id) + + +class TriggerDetails(TypedDict): + """Type class for the trigger details dictionary""" + + task: asyncio.Task + name: str + events: int + + +class TriggerRunner(threading.Thread, LoggingMixin): + """ + Runtime environment for all triggers. + + Mainly runs inside its own thread, where it hands control off to an asyncio + event loop, but is also sometimes interacted with from the main thread + (where all the DB queries are done). All communication between threads is + done via Deques. + """ + + # Maps trigger IDs to their running tasks and other info + triggers: Dict[int, TriggerDetails] + + # Cache for looking up triggers by classpath + trigger_cache: Dict[str, Type[BaseTrigger]] + + # Inbound queue of new triggers + to_create: Deque[Tuple[int, BaseTrigger]] + + # Inbound queue of deleted triggers + to_delete: Deque[int] + + # Outbound queue of events + events: Deque[Tuple[int, TriggerEvent]] + + # Outbound queue of failed triggers + failed_triggers: Deque[int] + + # Should-we-stop flag + stop: bool = False + + def __init__(self): + super().__init__() + self.triggers = {} + self.trigger_cache = {} + self.to_create = deque() + self.to_delete = deque() + self.events = deque() + self.failed_triggers = deque() + + def run(self): + """Sync entrypoint - just runs arun in an async loop.""" + # Pylint complains about this with a 3.6 base, can remove with 3.7+ + asyncio.run(self.arun()) # pylint: disable=no-member + + async def arun(self): + """ + Main (asynchronous) logic loop. + + The loop in here runs trigger addition/deletion/cleanup. Actual + triggers run in their own separate coroutines. + """ + watchdog = create_task(self.block_watchdog()) + last_status = time.time() + while not self.stop: + # Run core logic + await self.create_triggers() + await self.delete_triggers() + await self.cleanup_finished_triggers() + # Sleep for a bit + await asyncio.sleep(1) + # Every minute, log status + if time.time() - last_status >= 60: + self.log.info("%i triggers currently running", len(self.triggers)) + last_status = time.time() + # Wait for watchdog to complete + await watchdog + + async def create_triggers(self): + """ + Drain the to_create queue and create all triggers that have been + requested in the DB that we don't yet have. + """ + while self.to_create: + trigger_id, trigger_instance = self.to_create.popleft() + if trigger_id not in self.triggers: + self.triggers[trigger_id] = { + "task": create_task(self.run_trigger(trigger_id, trigger_instance)), + "name": f"{trigger_instance!r} (ID {trigger_id})", + "events": 0, + } + else: + self.log.warning("Trigger %s had insertion attempted twice", trigger_id) + + async def delete_triggers(self): + """ + Drain the to_delete queue and ensure all triggers that are not in the + DB are cancelled, so the cleanup job deletes them. + """ + while self.to_delete: + trigger_id = self.to_delete.popleft() + if trigger_id in self.triggers: + # We only delete if it did not exit already + self.triggers[trigger_id]["task"].cancel() + + async def cleanup_finished_triggers(self): + """ + Go through all trigger tasks (coroutines) and clean up entries for + ones that have exited, optionally warning users if the exit was + not normal. 
+ """ + for trigger_id, details in list(self.triggers.items()): # pylint: disable=too-many-nested-blocks + if details["task"].done(): + # Check to see if it exited for good reasons + try: + result = details["task"].result() + except (asyncio.CancelledError, SystemExit, KeyboardInterrupt): + # These are "expected" exceptions and we stop processing here + # If we don't, then the system requesting a trigger be removed - + # which turns into CancelledError - results in a failure. + del self.triggers[trigger_id] + continue + except BaseException as e: + # This is potentially bad, so log it. + self.log.error("Trigger %s exited with error %s", details["name"], e) + else: + # See if they foolishly returned a TriggerEvent + if isinstance(result, TriggerEvent): + self.log.error( + "Trigger %s returned a TriggerEvent rather than yielding it", details["name"] + ) + # See if this exited without sending an event, in which case + # any task instances depending on it need to be failed + if details["events"] == 0: + self.log.error( + "Trigger %s exited without sending an event. Dependent tasks will be failed.", + details["name"], + ) + self.failed_triggers.append(trigger_id) + del self.triggers[trigger_id] + + async def block_watchdog(self): + """ + Watchdog loop that detects blocking (badly-written) triggers. + + Triggers should be well-behaved async coroutines and await whenever + they need to wait; this loop tries to run every 100ms to see if + there are badly-written triggers taking longer than that and blocking + the event loop. + + Unfortunately, we can't tell what trigger is blocking things, but + we can at least detect the top-level problem. + """ + while not self.stop: + last_run = time.monotonic() + await asyncio.sleep(0.1) + # We allow a generous amount of buffer room for now, since it might + # be a busy event loop. + time_elapsed = time.monotonic() - last_run + if time_elapsed > 0.2: + self.log.error( + "Triggerer's async thread was blocked for %.2f seconds, " + "likely by a badly-written trigger. Set PYTHONASYNCIODEBUG=1 " + "to get more information on overrunning coroutines.", + time_elapsed, + ) + + # Async trigger logic + + async def run_trigger(self, trigger_id, trigger): + """ + Wrapper which runs an actual trigger (they are async generators) + and pushes their events into our outbound event deque. + """ + self.log.info("Trigger %s starting", self.triggers[trigger_id]['name']) + try: + async for event in trigger.run(): + self.log.info("Trigger %s fired: %s", self.triggers[trigger_id]['name'], event) + self.triggers[trigger_id]["events"] += 1 + self.events.append((trigger_id, event)) + finally: + # CancelledError will get injected when we're stopped - which is + # fine, the cleanup process will understand that, but we want to + # allow triggers a chance to cleanup, either in that case or if + # they exit cleanly. + trigger.cleanup() + + # Main-thread sync API + + def update_triggers(self, requested_trigger_ids: Set[int]): + """ + Called from the main thread to request that we update what + triggers we're running. + + Works out the differences - ones to add, and ones to remove - then + adds them to the deques so the subthread can actually mutate the running + trigger set. 
+ """ + current_trigger_ids = set(self.triggers.keys()) + # Work out the two difference sets + new_trigger_ids = requested_trigger_ids.difference(current_trigger_ids) + old_trigger_ids = current_trigger_ids.difference(requested_trigger_ids) + # Bulk-fetch new trigger records + new_triggers = Trigger.bulk_fetch(new_trigger_ids) + # Add in new triggers + for new_id in new_trigger_ids: + # Check it didn't vanish in the meantime + if new_id not in new_triggers: + self.log.warning("Trigger ID %s disappeared before we could start it", new_id) + continue + # Resolve trigger record into an actual class instance + trigger_class = self.get_trigger_by_classpath(new_triggers[new_id].classpath) + self.to_create.append((new_id, trigger_class(**new_triggers[new_id].kwargs))) + # Remove old triggers + for old_id in old_trigger_ids: + self.to_delete.append(old_id) + + def get_trigger_by_classpath(self, classpath: str) -> Type[BaseTrigger]: + """ + Gets a trigger class by its classpath ("path.to.module.classname") + + Uses a cache dictionary to speed up lookups after the first time. + """ + if classpath not in self.trigger_cache: + module_name, class_name = classpath.rsplit(".", 1) + try: + module = importlib.import_module(module_name) + except ImportError: + raise ImportError( + f"Cannot import trigger module {module_name} (from trigger classpath {classpath})" + ) + try: + trigger_class = getattr(module, class_name) + except AttributeError: + raise ImportError(f"Cannot import trigger {class_name} from module {module_name}") + self.trigger_cache[classpath] = trigger_class Review comment: `from airflow.utils.module_loading import import_string` and then ```suggestion self.trigger_cache[classpath] = import_string(classpath) ``` (The exceptions would be slightly different) ########## File path: airflow/models/dag.py ########## @@ -1313,31 +1324,36 @@ def clear( ) visited_external_tis.add(ti_key) - if get_tis: - return tis + if get_ti_keys: + return result + + result_instances = session.query(TI).filter( + tuple_(TI.dag_id, TI.task_id, TI.execution_date).in_(result) + ) - tis = tis.all() + if get_ti_instances: + return result_instances if dry_run: session.expunge_all() - return tis + return result_instances # Do not use count() here, it's actually much slower than just retrieving all the rows when # tis has multiple UNION statements. Review comment: You've removed the UNIONs now haven't you? 
If so this comment is now out of date ########## File path: airflow/models/taskinstance.py ########## @@ -1328,22 +1389,58 @@ def _update_ti_state_for_sensing(self, session=None): def _execute_task(self, context, task_copy): """Executes Task (optionally with a Timeout) and pushes Xcom results""" + # If the task has been deferred and is being executed due to a trigger, + # then we need to pick the right method to come back to, otherwise + # we go for the default execute + execute_callable = task_copy.execute + if self.next_method: + execute_callable = getattr(task_copy, self.next_method) + if self.next_kwargs: + execute_callable = partial(execute_callable, **self.next_kwargs) # If a timeout is specified for the task, make it fail # if it goes beyond if task_copy.execution_timeout: try: with timeout(task_copy.execution_timeout.total_seconds()): - result = task_copy.execute(context=context) + result = execute_callable(context=context) except AirflowTaskTimeout: task_copy.on_kill() raise else: - result = task_copy.execute(context=context) + result = execute_callable(context=context) # If the task returns a result, push an XCom containing it if task_copy.do_xcom_push and result is not None: self.xcom_push(key=XCOM_RETURN_KEY, value=result) return result + @provide_session + def _defer_task(self, session, defer: TaskDeferred): + """ + Marks the task as deferred and sets up the trigger that is needed + to resume it. + """ + from airflow.models.trigger import Trigger + + # First, make the trigger entry + trigger_row = Trigger.from_object(defer.trigger) + session.add(trigger_row) + session.commit() Review comment: https://github.com/apache/airflow/blob/master/CONTRIBUTING.rst#database-session-handling > If a function accepts a session parameter it should not commit the transaction itself. Session management is up to the caller. (And provide session will commit for us automatically if it creates the session) ```suggestion session.flush() ``` ########## File path: airflow/models/trigger.py ########## @@ -0,0 +1,161 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import datetime +from typing import Any, Dict, List, Optional + +import jump +from sqlalchemy import BigInteger, Column, String, func + +from airflow.models.base import Base +from airflow.models.taskinstance import TaskInstance +from airflow.triggers.base import BaseTrigger +from airflow.utils import timezone +from airflow.utils.session import provide_session +from airflow.utils.sqlalchemy import ExtendedJSON, UtcDateTime +from airflow.utils.state import State + + +class Trigger(Base): + """ + Triggers are a workload that run in an asynchronous event loop shared with + other Triggers, and fire off events that will unpause deferred Tasks, + start linked DAGs, etc. 
+ + They are persisted into the database and then re-hydrated into a single + "triggerer" process, where they're all run at once. We model it so that + there is a many-to-one relationship between Task and Trigger, for future + deduplication logic to use. + + Rows will be evicted from the database when the triggerer detects no + active Tasks/DAGs using them. Events are not stored in the database; + when an Event is fired, the triggerer will directly push its data to the + appropriate Task/DAG. + """ + + __tablename__ = "trigger" + + id = Column(BigInteger, primary_key=True) + classpath = Column(String(1000), nullable=False) + kwargs = Column(ExtendedJSON, nullable=False) + created_date = Column(UtcDateTime, nullable=False) + + def __init__( + self, classpath: str, kwargs: Dict[str, Any], created_date: Optional[datetime.datetime] = None + ): + super().__init__() + self.classpath = classpath + self.kwargs = kwargs + self.created_date = created_date or timezone.utcnow() + + @classmethod + @provide_session + def runnable_ids( + cls, session=None, partition_ids: Optional[List[int]] = None, partition_total: Optional[int] = None + ): # pylint: disable=unused-argument + """ + Returns all "runnable" triggers IDs, optionally filtering down by partition. + + This is a pretty basic partition algorithm for now, but it does the job. + """ + # NOTE: It's possible in future that we could try and pre-calculate a + # partition entry in a large virtual ring (e.g. 4096 buckets) and store + # that in the DB for more direct querying, but for now Jump is fast + # enough of a hash to do this all locally - about 0.1s per million hashes + + # Retrieve all IDs first + trigger_ids = [row[0] for row in session.query(cls.id).all()] + + # Short-circuit for "no partitioning" + if partition_ids is None or partition_total is None: + return trigger_ids + + # Go through and map each trigger ID to a partition number, + # using a quick, consistent hash (Jump), keeping only the ones that + # match one of our partition IDs + return [x for x in trigger_ids if jump.hash(x, partition_total) + 1 in partition_ids] + + @classmethod + def from_object(cls, trigger: BaseTrigger): + """ + Alternative constructor that creates a trigger row based directly + off of a Trigger object. + """ + classpath, kwargs = trigger.serialize() + return cls(classpath=classpath, kwargs=kwargs) + + @classmethod + @provide_session + def bulk_fetch(cls, ids: List[int], session=None) -> Dict[int, "Trigger"]: + """ + Fetches all of the Triggers by ID and returns a dict mapping + ID -> Trigger instance + """ + return {obj.id: obj for obj in session.query(cls).filter(cls.id.in_(ids)).all()} + + @classmethod + @provide_session + def clean_unused(cls, session=None): + """ + Deletes all triggers that have no tasks/DAGs dependent on them + (triggers have a one-to-many relationship to both) + """ + # Update all task instances with trigger IDs that are not DEFERRED to remove them + session.query(TaskInstance).filter( + TaskInstance.state != State.DEFERRED, TaskInstance.trigger_id.isnot(None) + ).update({TaskInstance.trigger_id: None}) + # Get all triggers that have no task instances depending on them... + ids = [ + x[0] + for x in ( Review comment: ```suggestion trigger_id for (trigger_id,) in ( ``` I think is slightly easier to read. ########## File path: airflow/models/trigger.py ########## @@ -0,0 +1,161 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import datetime +from typing import Any, Dict, List, Optional + +import jump +from sqlalchemy import BigInteger, Column, String, func + +from airflow.models.base import Base +from airflow.models.taskinstance import TaskInstance +from airflow.triggers.base import BaseTrigger +from airflow.utils import timezone +from airflow.utils.session import provide_session +from airflow.utils.sqlalchemy import ExtendedJSON, UtcDateTime +from airflow.utils.state import State + + +class Trigger(Base): + """ + Triggers are a workload that run in an asynchronous event loop shared with + other Triggers, and fire off events that will unpause deferred Tasks, + start linked DAGs, etc. + + They are persisted into the database and then re-hydrated into a single + "triggerer" process, where they're all run at once. We model it so that + there is a many-to-one relationship between Task and Trigger, for future + deduplication logic to use. + + Rows will be evicted from the database when the triggerer detects no + active Tasks/DAGs using them. Events are not stored in the database; + when an Event is fired, the triggerer will directly push its data to the + appropriate Task/DAG. + """ + + __tablename__ = "trigger" + + id = Column(BigInteger, primary_key=True) + classpath = Column(String(1000), nullable=False) + kwargs = Column(ExtendedJSON, nullable=False) + created_date = Column(UtcDateTime, nullable=False) + + def __init__( + self, classpath: str, kwargs: Dict[str, Any], created_date: Optional[datetime.datetime] = None + ): + super().__init__() + self.classpath = classpath + self.kwargs = kwargs + self.created_date = created_date or timezone.utcnow() + + @classmethod + @provide_session + def runnable_ids( + cls, session=None, partition_ids: Optional[List[int]] = None, partition_total: Optional[int] = None + ): # pylint: disable=unused-argument + """ + Returns all "runnable" triggers IDs, optionally filtering down by partition. + + This is a pretty basic partition algorithm for now, but it does the job. + """ + # NOTE: It's possible in future that we could try and pre-calculate a + # partition entry in a large virtual ring (e.g. 
4096 buckets) and store + # that in the DB for more direct querying, but for now Jump is fast + # enough of a hash to do this all locally - about 0.1s per million hashes + + # Retrieve all IDs first + trigger_ids = [row[0] for row in session.query(cls.id).all()] + + # Short-circuit for "no partitioning" + if partition_ids is None or partition_total is None: + return trigger_ids + + # Go through and map each trigger ID to a partition number, + # using a quick, consistent hash (Jump), keeping only the ones that + # match one of our partition IDs + return [x for x in trigger_ids if jump.hash(x, partition_total) + 1 in partition_ids] + + @classmethod + def from_object(cls, trigger: BaseTrigger): + """ + Alternative constructor that creates a trigger row based directly + off of a Trigger object. + """ + classpath, kwargs = trigger.serialize() + return cls(classpath=classpath, kwargs=kwargs) + + @classmethod + @provide_session + def bulk_fetch(cls, ids: List[int], session=None) -> Dict[int, "Trigger"]: + """ + Fetches all of the Triggers by ID and returns a dict mapping + ID -> Trigger instance + """ + return {obj.id: obj for obj in session.query(cls).filter(cls.id.in_(ids)).all()} + + @classmethod + @provide_session + def clean_unused(cls, session=None): + """ + Deletes all triggers that have no tasks/DAGs dependent on them + (triggers have a one-to-many relationship to both) + """ + # Update all task instances with trigger IDs that are not DEFERRED to remove them + session.query(TaskInstance).filter( + TaskInstance.state != State.DEFERRED, TaskInstance.trigger_id.isnot(None) + ).update({TaskInstance.trigger_id: None}) + # Get all triggers that have no task instances depending on them... + ids = [ + x[0] + for x in ( + session.query(cls.id) + .join(TaskInstance, cls.id == TaskInstance.trigger_id, isouter=True) + .group_by(cls.id) + .having(func.count(TaskInstance.trigger_id) == 0) + ) + ] + # ...and delete them (we can't do this in one query due to MySQL) Review comment: In other places we've special cased this sort of path for MySQL https://github.com/apache/airflow/blob/9f7c67feb5f2f8d3eeb81cb5f2bf158fb76f5b9e/airflow/jobs/scheduler_job.py#L785-L859 -- is it worth doing that here too? ########## File path: airflow/models/trigger.py ########## @@ -0,0 +1,161 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+import datetime +from typing import Any, Dict, List, Optional + +import jump +from sqlalchemy import BigInteger, Column, String, func + +from airflow.models.base import Base +from airflow.models.taskinstance import TaskInstance +from airflow.triggers.base import BaseTrigger +from airflow.utils import timezone +from airflow.utils.session import provide_session +from airflow.utils.sqlalchemy import ExtendedJSON, UtcDateTime +from airflow.utils.state import State + + +class Trigger(Base): + """ + Triggers are a workload that run in an asynchronous event loop shared with + other Triggers, and fire off events that will unpause deferred Tasks, + start linked DAGs, etc. + + They are persisted into the database and then re-hydrated into a single + "triggerer" process, where they're all run at once. We model it so that + there is a many-to-one relationship between Task and Trigger, for future + deduplication logic to use. + + Rows will be evicted from the database when the triggerer detects no + active Tasks/DAGs using them. Events are not stored in the database; + when an Event is fired, the triggerer will directly push its data to the + appropriate Task/DAG. + """ + + __tablename__ = "trigger" + + id = Column(BigInteger, primary_key=True) + classpath = Column(String(1000), nullable=False) + kwargs = Column(ExtendedJSON, nullable=False) + created_date = Column(UtcDateTime, nullable=False) + + def __init__( + self, classpath: str, kwargs: Dict[str, Any], created_date: Optional[datetime.datetime] = None + ): + super().__init__() + self.classpath = classpath + self.kwargs = kwargs + self.created_date = created_date or timezone.utcnow() + + @classmethod + @provide_session + def runnable_ids( + cls, session=None, partition_ids: Optional[List[int]] = None, partition_total: Optional[int] = None + ): # pylint: disable=unused-argument + """ + Returns all "runnable" triggers IDs, optionally filtering down by partition. + + This is a pretty basic partition algorithm for now, but it does the job. + """ + # NOTE: It's possible in future that we could try and pre-calculate a + # partition entry in a large virtual ring (e.g. 4096 buckets) and store + # that in the DB for more direct querying, but for now Jump is fast + # enough of a hash to do this all locally - about 0.1s per million hashes + + # Retrieve all IDs first + trigger_ids = [row[0] for row in session.query(cls.id).all()] + + # Short-circuit for "no partitioning" + if partition_ids is None or partition_total is None: + return trigger_ids + + # Go through and map each trigger ID to a partition number, + # using a quick, consistent hash (Jump), keeping only the ones that + # match one of our partition IDs + return [x for x in trigger_ids if jump.hash(x, partition_total) + 1 in partition_ids] + + @classmethod + def from_object(cls, trigger: BaseTrigger): + """ + Alternative constructor that creates a trigger row based directly + off of a Trigger object. 
+ """ + classpath, kwargs = trigger.serialize() + return cls(classpath=classpath, kwargs=kwargs) + + @classmethod + @provide_session + def bulk_fetch(cls, ids: List[int], session=None) -> Dict[int, "Trigger"]: + """ + Fetches all of the Triggers by ID and returns a dict mapping + ID -> Trigger instance + """ + return {obj.id: obj for obj in session.query(cls).filter(cls.id.in_(ids)).all()} + + @classmethod + @provide_session + def clean_unused(cls, session=None): + """ + Deletes all triggers that have no tasks/DAGs dependent on them + (triggers have a one-to-many relationship to both) + """ + # Update all task instances with trigger IDs that are not DEFERRED to remove them + session.query(TaskInstance).filter( + TaskInstance.state != State.DEFERRED, TaskInstance.trigger_id.isnot(None) + ).update({TaskInstance.trigger_id: None}) + # Get all triggers that have no task instances depending on them... + ids = [ + x[0] + for x in ( + session.query(cls.id) + .join(TaskInstance, cls.id == TaskInstance.trigger_id, isouter=True) + .group_by(cls.id) + .having(func.count(TaskInstance.trigger_id) == 0) + ) + ] + # ...and delete them (we can't do this in one query due to MySQL) + session.query(Trigger).filter(Trigger.id.in_(ids)).delete(synchronize_session=False) + + @classmethod + @provide_session + def submit_event(cls, trigger_id, event, session=None): + """ + Takes an event from an instance of itself, and triggers all dependent + tasks to resume. + """ + for task_instance in session.query(TaskInstance).filter( + TaskInstance.trigger_id == trigger_id, TaskInstance.state == State.DEFERRED + ): + # Add the event's payload into the kwargs for the task + next_kwargs = task_instance.next_kwargs or {} + next_kwargs["event"] = event.payload + task_instance.next_kwargs = next_kwargs + # Remove ourselves as its trigger + task_instance.trigger_id = None + # Finally, mark it as scheduled so it gets re-queued + task_instance.state = State.SCHEDULED Review comment: Worth adding a `Log` row for when it gets resumed, like we had for when it gets deferred? ########## File path: airflow/triggers/base.py ########## @@ -0,0 +1,96 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import Any, AsyncIterator, Dict, Tuple + + +class BaseTrigger: + """ + Base class for all triggers. + + A trigger has two contexts it can exist in: + + - As part of a DAG declaration, where it's declared. Review comment: Is this true? What I've seen of the code so far Triggers only exist when they are created in side a task's execute function ########## File path: airflow/triggers/base.py ########## @@ -0,0 +1,96 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import Any, AsyncIterator, Dict, Tuple + + +class BaseTrigger: + """ + Base class for all triggers. + + A trigger has two contexts it can exist in: + + - As part of a DAG declaration, where it's declared. + - Actively running in a trigger worker + + We use the same class for both situations, and rely on all Trigger classes + to be able to return the (Airflow-JSON-encodable) arguments that will + let them be reinsantiated elsewhere. + """ + + def __init__(self): + pass + + def serialize(self) -> Tuple[str, Dict[str, Any]]: + """ + Returns the information needed to reconstruct this Trigger. + + The first element of the returned tuple is the class path, the second + is the keyword arguments needed to re-instantiate it. Review comment: ```suggestion :return: Tuple of (class path, keyword arguments needed to re-instantiate). ``` ########## File path: airflow/models/trigger.py ########## @@ -0,0 +1,161 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import datetime +from typing import Any, Dict, List, Optional + +import jump +from sqlalchemy import BigInteger, Column, String, func + +from airflow.models.base import Base +from airflow.models.taskinstance import TaskInstance +from airflow.triggers.base import BaseTrigger +from airflow.utils import timezone +from airflow.utils.session import provide_session +from airflow.utils.sqlalchemy import ExtendedJSON, UtcDateTime +from airflow.utils.state import State + + +class Trigger(Base): + """ + Triggers are a workload that run in an asynchronous event loop shared with + other Triggers, and fire off events that will unpause deferred Tasks, + start linked DAGs, etc. + + They are persisted into the database and then re-hydrated into a single + "triggerer" process, where they're all run at once. We model it so that + there is a many-to-one relationship between Task and Trigger, for future + deduplication logic to use. + + Rows will be evicted from the database when the triggerer detects no + active Tasks/DAGs using them. 
Events are not stored in the database; + when an Event is fired, the triggerer will directly push its data to the + appropriate Task/DAG. + """ + + __tablename__ = "trigger" + + id = Column(BigInteger, primary_key=True) + classpath = Column(String(1000), nullable=False) + kwargs = Column(ExtendedJSON, nullable=False) + created_date = Column(UtcDateTime, nullable=False) + + def __init__( + self, classpath: str, kwargs: Dict[str, Any], created_date: Optional[datetime.datetime] = None + ): + super().__init__() + self.classpath = classpath + self.kwargs = kwargs + self.created_date = created_date or timezone.utcnow() + + @classmethod + @provide_session + def runnable_ids( + cls, session=None, partition_ids: Optional[List[int]] = None, partition_total: Optional[int] = None + ): # pylint: disable=unused-argument + """ + Returns all "runnable" triggers IDs, optionally filtering down by partition. + + This is a pretty basic partition algorithm for now, but it does the job. + """ + # NOTE: It's possible in future that we could try and pre-calculate a + # partition entry in a large virtual ring (e.g. 4096 buckets) and store + # that in the DB for more direct querying, but for now Jump is fast + # enough of a hash to do this all locally - about 0.1s per million hashes + + # Retrieve all IDs first + trigger_ids = [row[0] for row in session.query(cls.id).all()] + + # Short-circuit for "no partitioning" + if partition_ids is None or partition_total is None: + return trigger_ids + + # Go through and map each trigger ID to a partition number, + # using a quick, consistent hash (Jump), keeping only the ones that + # match one of our partition IDs + return [x for x in trigger_ids if jump.hash(x, partition_total) + 1 in partition_ids] + + @classmethod + def from_object(cls, trigger: BaseTrigger): + """ + Alternative constructor that creates a trigger row based directly + off of a Trigger object. + """ + classpath, kwargs = trigger.serialize() + return cls(classpath=classpath, kwargs=kwargs) + + @classmethod + @provide_session + def bulk_fetch(cls, ids: List[int], session=None) -> Dict[int, "Trigger"]: + """ + Fetches all of the Triggers by ID and returns a dict mapping + ID -> Trigger instance + """ + return {obj.id: obj for obj in session.query(cls).filter(cls.id.in_(ids)).all()} + + @classmethod + @provide_session + def clean_unused(cls, session=None): + """ + Deletes all triggers that have no tasks/DAGs dependent on them + (triggers have a one-to-many relationship to both) + """ + # Update all task instances with trigger IDs that are not DEFERRED to remove them + session.query(TaskInstance).filter( + TaskInstance.state != State.DEFERRED, TaskInstance.trigger_id.isnot(None) + ).update({TaskInstance.trigger_id: None}) + # Get all triggers that have no task instances depending on them... + ids = [ + x[0] + for x in ( + session.query(cls.id) + .join(TaskInstance, cls.id == TaskInstance.trigger_id, isouter=True) + .group_by(cls.id) + .having(func.count(TaskInstance.trigger_id) == 0) + ) + ] + # ...and delete them (we can't do this in one query due to MySQL) + session.query(Trigger).filter(Trigger.id.in_(ids)).delete(synchronize_session=False) + + @classmethod + @provide_session + def submit_event(cls, trigger_id, event, session=None): + """ + Takes an event from an instance of itself, and triggers all dependent + tasks to resume. 
+ """ + for task_instance in session.query(TaskInstance).filter( + TaskInstance.trigger_id == trigger_id, TaskInstance.state == State.DEFERRED + ): + # Add the event's payload into the kwargs for the task + next_kwargs = task_instance.next_kwargs or {} + next_kwargs["event"] = event.payload + task_instance.next_kwargs = next_kwargs + # Remove ourselves as its trigger + task_instance.trigger_id = None + # Finally, mark it as scheduled so it gets re-queued + task_instance.state = State.SCHEDULED + + @classmethod + @provide_session + def submit_failure(cls, trigger_id, session=None): + """ + Called when a trigger has failed unexpectedly, and we need to mark + everything that depended on it as failed. + """ + session.query(TaskInstance).filter( + TaskInstance.trigger_id == trigger_id, TaskInstance.state == State.DEFERRED + ).update({TaskInstance.state: State.FAILED}) Review comment: This won't call on_failure_callbacks, nor will it cascade to downstream tasks (which should end up in state upstream_failed) so this will need to be a bit more complex here. ########## File path: airflow/triggers/temporal.py ########## @@ -0,0 +1,76 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import asyncio +import datetime +from typing import Any, Dict, Tuple + +import pytz + +from airflow.triggers.base import BaseTrigger, TriggerEvent +from airflow.utils import timezone + + +class DateTimeTrigger(BaseTrigger): + """ + A trigger that fires exactly once, at the given datetime, give or take + a few seconds. + + The provided datetime MUST be in UTC. + """ + + def __init__(self, moment: datetime.datetime): + super().__init__() + # Make sure it's in UTC + if moment.tzinfo is None: + self.moment = pytz.utc.localize(moment) + elif moment.tzinfo == pytz.utc or getattr(moment.tzinfo, "name", None) == "UTC": + self.moment = moment + else: + raise ValueError(f"The passed datetime must be in UTC, not {moment.tzinfo!r}") + + def serialize(self) -> Tuple[str, Dict[str, Any]]: + return ("airflow.triggers.temporal.DateTimeTrigger", {"moment": self.moment}) + + async def run(self): + """ + Simple time delay loop until the relevant time is met. + + We do have a two-phase delay to save some cycles, but sleeping is so + cheap anyway that it's pretty loose. 
+ """ + # Sleep an hour at a time while it's more than 2 hours away + while timezone.utcnow() - self.moment > datetime.timedelta(hours=2): + await (asyncio.sleep(3600)) + # Sleep a second at a time otherwise + while self.moment > timezone.utcnow(): + await asyncio.sleep(1) Review comment: Is there a reason we don't do this ```python await asyncio.sleep((timezone.utcnow() - self.moment).total_seconds()) ``` ########## File path: airflow/triggers/temporal.py ########## @@ -0,0 +1,76 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import asyncio +import datetime +from typing import Any, Dict, Tuple + +import pytz + +from airflow.triggers.base import BaseTrigger, TriggerEvent +from airflow.utils import timezone + + +class DateTimeTrigger(BaseTrigger): + """ + A trigger that fires exactly once, at the given datetime, give or take + a few seconds. + + The provided datetime MUST be in UTC. + """ + + def __init__(self, moment: datetime.datetime): + super().__init__() + # Make sure it's in UTC + if moment.tzinfo is None: + self.moment = pytz.utc.localize(moment) + elif moment.tzinfo == pytz.utc or getattr(moment.tzinfo, "name", None) == "UTC": + self.moment = moment + else: + raise ValueError(f"The passed datetime must be in UTC, not {moment.tzinfo!r}") + + def serialize(self) -> Tuple[str, Dict[str, Any]]: + return ("airflow.triggers.temporal.DateTimeTrigger", {"moment": self.moment}) + + async def run(self): + """ + Simple time delay loop until the relevant time is met. + + We do have a two-phase delay to save some cycles, but sleeping is so + cheap anyway that it's pretty loose. + """ + # Sleep an hour at a time while it's more than 2 hours away + while timezone.utcnow() - self.moment > datetime.timedelta(hours=2): + await (asyncio.sleep(3600)) + # Sleep a second at a time otherwise + while self.moment > timezone.utcnow(): + await asyncio.sleep(1) + # Send our single event and then we're done + yield TriggerEvent(self.moment) + + +class TimeDeltaTrigger(DateTimeTrigger): + """ + Subclass to create DateTimeTriggers based on time delays rather + than exact moments. + + While this is its own distinct class here, it will serialise to a + DateTimeTrigger class, since they're operationally the same. + """ + + def __init__(self, delta: datetime.timedelta): + DateTimeTrigger.__init__(self, moment=timezone.utcnow() + delta) Review comment: ```suggestion super().__init__(self, moment=timezone.utcnow() + delta) ``` ########## File path: docs/apache-airflow/concepts/deferring.rst ########## @@ -0,0 +1,172 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. 
The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + +Deferrable Operators & Triggers +=============================== + +Standard :doc:`Operators <operators>` and :doc:`Sensors <sensors>` take up a full *worker slot* for the entire time they are running, even if they are idle; for example, if you only have 100 worker slots available to run Tasks, and you have 100 DAGs waiting on a Sensor that's currently running but idle, then you *cannot run anything else* - even though your entire Airflow cluster is essentially idle. Review comment: This isn't strictly true for sensors in "reschedule" mode -- they will run, then stop and wait for the scheduler to send them back to a worker. ########## File path: docs/apache-airflow/concepts/deferring.rst ########## @@ -0,0 +1,172 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + +Deferrable Operators & Triggers +=============================== + +Standard :doc:`Operators <operators>` and :doc:`Sensors <sensors>` take up a full *worker slot* for the entire time they are running, even if they are idle; for example, if you only have 100 worker slots available to run Tasks, and you have 100 DAGs waiting on a Sensor that's currently running but idle, then you *cannot run anything else* - even though your entire Airflow cluster is essentially idle. + +This is where *Deferrable Operators* come in. A deferrable operator is one that is written with the ability to suspend itself and remove itself from the worker when it knows that it will have to wait, and hand off the job of resuming it to something called a *Trigger*. As a result, while it is suspended (deferred), it is not taking up a worker slot and your cluster will have a lot less resources wasted on idle Operators or Sensors. + +*Triggers* are small, asynchronous pieces of Python code designed to be run all together in a single Python process; because they are asynchronous, they are able to all co-exist efficiently. As an overview of how this process works: + +* A task instance (running operator) gets to a point where it has to wait, and defers itself with a trigger tied to the event that should resume it. It then removes itself from its current worker and frees up space. 
+* The new Trigger instance is registered inside Airflow, and picked up by one or more *triggerer* processes +* The trigger is run until it fires, at which point its source task is re-scheduled +* The task instance resumes Review comment: ```suggestion * The scheduler queues the Task Instance to resume on a worker node. ``` ########## File path: docs/apache-airflow/concepts/deferring.rst ########## @@ -0,0 +1,172 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + +Deferrable Operators & Triggers +=============================== + +Standard :doc:`Operators <operators>` and :doc:`Sensors <sensors>` take up a full *worker slot* for the entire time they are running, even if they are idle; for example, if you only have 100 worker slots available to run Tasks, and you have 100 DAGs waiting on a Sensor that's currently running but idle, then you *cannot run anything else* - even though your entire Airflow cluster is essentially idle. + +This is where *Deferrable Operators* come in. A deferrable operator is one that is written with the ability to suspend itself and remove itself from the worker when it knows that it will have to wait, and hand off the job of resuming it to something called a *Trigger*. As a result, while it is suspended (deferred), it is not taking up a worker slot and your cluster will have a lot less resources wasted on idle Operators or Sensors. + +*Triggers* are small, asynchronous pieces of Python code designed to be run all together in a single Python process; because they are asynchronous, they are able to all co-exist efficiently. As an overview of how this process works: + +* A task instance (running operator) gets to a point where it has to wait, and defers itself with a trigger tied to the event that should resume it. It then removes itself from its current worker and frees up space. +* The new Trigger instance is registered inside Airflow, and picked up by one or more *triggerer* processes +* The trigger is run until it fires, at which point its source task is re-scheduled +* The task instance resumes + +Using deferrable operators as a DAG author is almost transparent; writing them, however, takes a bit more work. + +.. note:: + + Deferrable Operators & Triggers rely on more recent ``asyncio`` features, and as a result only work + on Python 3.7 or higher. 
+ + +Using Deferrable Operators +-------------------------- + +If all you wish to do is use pre-written Deferrable Operators (such as ``TimeSensorAsync``, which comes with Airflow), then there are only two steps you need: + +* Ensure your Airflow installation is running at least one *triggerer* process, as well as the normal *scheduler* Review comment: ```suggestion * Ensure your Airflow installation is running at least one ``triggerer`` process, as well as the normal ``scheduler`` ``` I think ########## File path: docs/apache-airflow/concepts/deferring.rst ########## @@ -0,0 +1,172 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + +Deferrable Operators & Triggers +=============================== + +Standard :doc:`Operators <operators>` and :doc:`Sensors <sensors>` take up a full *worker slot* for the entire time they are running, even if they are idle; for example, if you only have 100 worker slots available to run Tasks, and you have 100 DAGs waiting on a Sensor that's currently running but idle, then you *cannot run anything else* - even though your entire Airflow cluster is essentially idle. + +This is where *Deferrable Operators* come in. A deferrable operator is one that is written with the ability to suspend itself and remove itself from the worker when it knows that it will have to wait, and hand off the job of resuming it to something called a *Trigger*. As a result, while it is suspended (deferred), it is not taking up a worker slot and your cluster will have a lot less resources wasted on idle Operators or Sensors. + +*Triggers* are small, asynchronous pieces of Python code designed to be run all together in a single Python process; because they are asynchronous, they are able to all co-exist efficiently. As an overview of how this process works: + +* A task instance (running operator) gets to a point where it has to wait, and defers itself with a trigger tied to the event that should resume it. It then removes itself from its current worker and frees up space. +* The new Trigger instance is registered inside Airflow, and picked up by one or more *triggerer* processes +* The trigger is run until it fires, at which point its source task is re-scheduled +* The task instance resumes + +Using deferrable operators as a DAG author is almost transparent; writing them, however, takes a bit more work. + +.. note:: + + Deferrable Operators & Triggers rely on more recent ``asyncio`` features, and as a result only work + on Python 3.7 or higher. 
+ + +Using Deferrable Operators +-------------------------- + +If all you wish to do is use pre-written Deferrable Operators (such as ``TimeSensorAsync``, which comes with Airflow), then there are only two steps you need: + +* Ensure your Airflow installation is running at least one *triggerer* process, as well as the normal *scheduler* +* Use deferrable operators/sensors in your DAGs + +That's it; everything else will be automatically handled for you. If you're upgrading existing DAGs, we even provide some API-compatible sensor variants (e.g. ``TimeSensorAsync`` for ``TimeSensor``) that you can swap into your DAG with no other changes required. + +Note that you cannot yet use the deferral ability from inside custom PythonOperator/TaskFlow code; it is only available to pre-built Operators at the moment. + + +Writing Deferrable Operators +---------------------------- + +Writing a deferrable operator takes a bit more work. There are some main points to consider: + +* Your Operator must defer itself based on a Trigger. If there is a Trigger in core Airflow you can use, great; otherwise, you will have to write one. +* Your Operator will be deleted and removed from its worker while deferred, and no state will persist automatically. You can persist state by asking Airflow to resume you at a certain method or pass certain kwargs, but that's it. Review comment: ```suggestion * Your Operator will be stopped and removed from its worker while deferred, and no state will persist automatically. You can persist state by asking Airflow to resume you at a certain method or pass certain kwargs, but that's it. ``` ########## File path: docs/apache-airflow/concepts/deferring.rst ########## @@ -0,0 +1,172 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + +Deferrable Operators & Triggers +=============================== + +Standard :doc:`Operators <operators>` and :doc:`Sensors <sensors>` take up a full *worker slot* for the entire time they are running, even if they are idle; for example, if you only have 100 worker slots available to run Tasks, and you have 100 DAGs waiting on a Sensor that's currently running but idle, then you *cannot run anything else* - even though your entire Airflow cluster is essentially idle. + +This is where *Deferrable Operators* come in. A deferrable operator is one that is written with the ability to suspend itself and remove itself from the worker when it knows that it will have to wait, and hand off the job of resuming it to something called a *Trigger*. As a result, while it is suspended (deferred), it is not taking up a worker slot and your cluster will have a lot less resources wasted on idle Operators or Sensors. 
+ +*Triggers* are small, asynchronous pieces of Python code designed to be run all together in a single Python process; because they are asynchronous, they are able to all co-exist efficiently. As an overview of how this process works: + +* A task instance (running operator) gets to a point where it has to wait, and defers itself with a trigger tied to the event that should resume it. It then removes itself from its current worker and frees up space. +* The new Trigger instance is registered inside Airflow, and picked up by one or more *triggerer* processes +* The trigger is run until it fires, at which point its source task is re-scheduled +* The task instance resumes + +Using deferrable operators as a DAG author is almost transparent; writing them, however, takes a bit more work. + +.. note:: + + Deferrable Operators & Triggers rely on more recent ``asyncio`` features, and as a result only work + on Python 3.7 or higher. + + +Using Deferrable Operators +-------------------------- + +If all you wish to do is use pre-written Deferrable Operators (such as ``TimeSensorAsync``, which comes with Airflow), then there are only two steps you need: + +* Ensure your Airflow installation is running at least one *triggerer* process, as well as the normal *scheduler* +* Use deferrable operators/sensors in your DAGs + +That's it; everything else will be automatically handled for you. If you're upgrading existing DAGs, we even provide some API-compatible sensor variants (e.g. ``TimeSensorAsync`` for ``TimeSensor``) that you can swap into your DAG with no other changes required. + +Note that you cannot yet use the deferral ability from inside custom PythonOperator/TaskFlow code; it is only available to pre-built Operators at the moment. Review comment: ```suggestion Note that you cannot yet use the deferral ability from inside custom PythonOperator/TaskFlow Python functions; it is only available to pre-built Operators at the moment. ``` (Cos you can use deferrable operators inside a DAG that uses TaskFlow, you just can't defer a `@task` function, right?) ########## File path: docs/apache-airflow/concepts/deferring.rst ########## @@ -0,0 +1,172 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + +Deferrable Operators & Triggers +=============================== + +Standard :doc:`Operators <operators>` and :doc:`Sensors <sensors>` take up a full *worker slot* for the entire time they are running, even if they are idle; for example, if you only have 100 worker slots available to run Tasks, and you have 100 DAGs waiting on a Sensor that's currently running but idle, then you *cannot run anything else* - even though your entire Airflow cluster is essentially idle. + +This is where *Deferrable Operators* come in.
A deferrable operator is one that is written with the ability to suspend itself and remove itself from the worker when it knows that it will have to wait, and hand off the job of resuming it to something called a *Trigger*. As a result, while it is suspended (deferred), it is not taking up a worker slot and your cluster will have a lot less resources wasted on idle Operators or Sensors. + +*Triggers* are small, asynchronous pieces of Python code designed to be run all together in a single Python process; because they are asynchronous, they are able to all co-exist efficiently. As an overview of how this process works: + +* A task instance (running operator) gets to a point where it has to wait, and defers itself with a trigger tied to the event that should resume it. It then removes itself from its current worker and frees up space. +* The new Trigger instance is registered inside Airflow, and picked up by one or more *triggerer* processes +* The trigger is run until it fires, at which point its source task is re-scheduled +* The task instance resumes + +Using deferrable operators as a DAG author is almost transparent; writing them, however, takes a bit more work. + +.. note:: + + Deferrable Operators & Triggers rely on more recent ``asyncio`` features, and as a result only work + on Python 3.7 or higher. + + +Using Deferrable Operators +-------------------------- + +If all you wish to do is use pre-written Deferrable Operators (such as ``TimeSensorAsync``, which comes with Airflow), then there are only two steps you need: + +* Ensure your Airflow installation is running at least one *triggerer* process, as well as the normal *scheduler* +* Use deferrable operators/sensors in your DAGs + +That's it; everything else will be automatically handled for you. If you're upgrading existing DAGs, we even provide some API-compatible sensor variants (e.g. ``TimeSensorAsync`` for ``TimeSensor``) that you can swap into your DAG with no other changes required. + +Note that you cannot yet use the deferral ability from inside custom PythonOperator/TaskFlow code; it is only available to pre-built Operators at the moment. + + +Writing Deferrable Operators +---------------------------- + +Writing a deferrable operator takes a bit more work. There are some main points to consider: + +* Your Operator must defer itself based on a Trigger. If there is a Trigger in core Airflow you can use, great; otherwise, you will have to write one. +* Your Operator will be deleted and removed from its worker while deferred, and no state will persist automatically. You can persist state by asking Airflow to resume you at a certain method or pass certain kwargs, but that's it. +* You can defer multiple times, and you can defer before/after your Operator does significant work, or only defer if certain conditions are met (e.g. a system does not have an immediate answer). Deferral is entirely under your control. +* Any Operator can defer; no special marking on its class is needed, and it's not limited to Sensors. + + +Triggering Deferral +~~~~~~~~~~~~~~~~~~~ + +If you want to trigger deferral, at any place in your Operator you can call ``self.defer(trigger, method_name, kwargs, timeout)``, which will raise a special exception that Airflow will catch. The arguments are: + +* ``trigger``: An instance of a Trigger that you wish to defer on. It will be serialized into the database. +* ``method_name``: The method name on your Operator you want Airflow to call when it resumes, other than ``execute``. 
+* ``kwargs``: Additional keyword arguments to pass to the method when it is called. Optional, defaults to ``{}``. +* ``timeout``: A timedelta that specifies a timeout after which this deferral will fail, and fail the task instance. Optional, defaults to ``None``, meaning no timeout. + +When you opt to defer, your Operator will *stop executing at that point and be deleted from a worker*. No state will persist, and when your Operator is resumed it will be a *brand new instance* of it. The only way you can pass state from the old instance of the Operator to the new one is via ``method_name`` and ``kwargs``. Review comment: We shouldn't use "deleted" as it could be confused with deleting the DB row. ########## File path: docs/apache-airflow/concepts/deferring.rst ########## @@ -0,0 +1,172 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + +Deferrable Operators & Triggers +=============================== + +Standard :doc:`Operators <operators>` and :doc:`Sensors <sensors>` take up a full *worker slot* for the entire time they are running, even if they are idle; for example, if you only have 100 worker slots available to run Tasks, and you have 100 DAGs waiting on a Sensor that's currently running but idle, then you *cannot run anything else* - even though your entire Airflow cluster is essentially idle. + +This is where *Deferrable Operators* come in. A deferrable operator is one that is written with the ability to suspend itself and remove itself from the worker when it knows that it will have to wait, and hand off the job of resuming it to something called a *Trigger*. As a result, while it is suspended (deferred), it is not taking up a worker slot and your cluster will have a lot less resources wasted on idle Operators or Sensors. + +*Triggers* are small, asynchronous pieces of Python code designed to be run all together in a single Python process; because they are asynchronous, they are able to all co-exist efficiently. As an overview of how this process works: + +* A task instance (running operator) gets to a point where it has to wait, and defers itself with a trigger tied to the event that should resume it. It then removes itself from its current worker and frees up space. Review comment: > It then removes itself from its current worker and frees up space. This sounds like something each operator has to perform, but that isn't the case. ```suggestion * A task instance (running operator) gets to a point where it has to wait, and defers itself with a trigger tied to the event that should resume it. The worker is then free to execute another task in that slot.
``` ########## File path: airflow/models/dag.py ########## @@ -1232,23 +1253,13 @@ def clear( ) ) - if start_date: - tis = tis.filter(TI.execution_date >= start_date) - if end_date: - tis = tis.filter(TI.execution_date <= end_date) - if only_failed: - tis = tis.filter(or_(TI.state == State.FAILED, TI.state == State.UPSTREAM_FAILED)) - if only_running: - tis = tis.filter(TI.state == State.RUNNING) - if task_ids: - tis = tis.filter(TI.task_id.in_(task_ids)) - if include_subdags: from airflow.sensors.external_task import ExternalTaskMarker # Recursively find external tasks indicated by ExternalTaskMarker - instances = tis.all() - for ti in instances: + for ti in ( + session.query(TI).filter(tuple_(TI.dag_id, TI.task_id, TI.execution_date).in_(tis)).all() Review comment: mssql doesn't like this "in tuple" approach, and we are working on adding support for mssql in https://github.com/apache/airflow/pull/9973 so this will need to change/have db-specific paths. ########## File path: airflow/models/taskinstance.py ########## @@ -282,6 +296,18 @@ class TaskInstance(Base, LoggingMixin): # pylint: disable=R0902,R0904 executor_config = Column(PickleType(pickler=dill)) external_executor_id = Column(String(ID_LEN, **COLLATION_ARGS)) + + # The trigger to resume on if we are in state DEFERRED + trigger_id = Column(BigInteger) + + # Optional timeout datetime for the trigger (past this, we'll fail) + trigger_timeout = Column(UtcDateTime) + + # The method to call next, and any extra arguments to pass to it. + # Usually used when resuming from DEFERRED. + next_method = Column(String(1000)) Review comment: ```suggestion next_method = Column(String(1000, **COLLATION_ARGS)) ``` (And similarly in the migration too I think?) ########## File path: airflow/models/taskinstance.py ########## @@ -1137,6 +1180,24 @@ def _run_raw_task( self._prepare_and_execute_task_with_callbacks(context, task) self.refresh_from_db(lock_for_update=True) self.state = State.SUCCESS + except TaskDeferred as defer: + # The task has signalled it wants to defer execution based on + # a trigger. + self._defer_task(defer=defer) + self.log.info(self.state) + self.log.info(self.next_method) Review comment: Left over debug logging? ```suggestion ``` ########## File path: airflow/triggers/temporal.py ########## @@ -0,0 +1,76 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import asyncio +import datetime +from typing import Any, Dict, Tuple + +import pytz Review comment: We've generally used pendulum for all datetime/timezone shenanigans -- could you see if it does what you need please? ########## File path: airflow/triggers/base.py ########## @@ -0,0 +1,96 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import Any, AsyncIterator, Dict, Tuple + + +class BaseTrigger: Review comment: ```suggestion class BaseTrigger(abc.ABC): ``` And then using `@abstractmethod` where appropriate? ########## File path: airflow/sensors/time_sensor.py ########## @@ -35,3 +36,25 @@ def __init__(self, *, target_time, **kwargs): def poke(self, context): self.log.info('Checking if the time (%s) has come', self.target_time) return timezone.make_naive(timezone.utcnow(), self.dag.timezone).time() > self.target_time + + +class TimeSensorAsync(BaseSensorOperator): + """ + Waits until the specified time of the day, freeing up a worker slot while + it is waiting. + + :param target_time: time after which the job succeeds + :type target_time: datetime.time + """ + + def __init__(self, *, target_time, **kwargs): + super().__init__(**kwargs) + self.target_time = target_time + + def execute(self, context): + self.log.info("Kicking off trigger deferral") Review comment: Debugging log? ```suggestion ``` ########## File path: docs/apache-airflow/concepts/deferring.rst ########## @@ -0,0 +1,172 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + +Deferrable Operators & Triggers +=============================== + +Standard :doc:`Operators <operators>` and :doc:`Sensors <sensors>` take up a full *worker slot* for the entire time they are running, even if they are idle; for example, if you only have 100 worker slots available to run Tasks, and you have 100 DAGs waiting on a Sensor that's currently running but idle, then you *cannot run anything else* - even though your entire Airflow cluster is essentially idle. + +This is where *Deferrable Operators* come in. A deferrable operator is one that is written with the ability to suspend itself and remove itself from the worker when it knows that it will have to wait, and hand off the job of resuming it to something called a *Trigger*. As a result, while it is suspended (deferred), it is not taking up a worker slot and your cluster will have a lot less resources wasted on idle Operators or Sensors. 
+ +*Triggers* are small, asynchronous pieces of Python code designed to be run all together in a single Python process; because they are asynchronous, they are able to all co-exist efficiently. As an overview of how this process works: + +* A task instance (running operator) gets to a point where it has to wait, and defers itself with a trigger tied to the event that should resume it. It then removes itself from its current worker and frees up space. +* The new Trigger instance is registered inside Airflow, and picked up by one or more *triggerer* processes +* The trigger is run until it fires, at which point its source task is re-scheduled +* The task instance resumes + +Using deferrable operators as a DAG author is almost transparent; writing them, however, takes a bit more work. + +.. note:: + + Deferrable Operators & Triggers rely on more recent ``asyncio`` features, and as a result only work + on Python 3.7 or higher. + + +Using Deferrable Operators +-------------------------- + +If all you wish to do is use pre-written Deferrable Operators (such as ``TimeSensorAsync``, which comes with Airflow), then there are only two steps you need: + +* Ensure your Airflow installation is running at least one *triggerer* process, as well as the normal *scheduler* +* Use deferrable operators/sensors in your DAGs + +That's it; everything else will be automatically handled for you. If you're upgrading existing DAGs, we even provide some API-compatible sensor variants (e.g. ``TimeSensorAsync`` for ``TimeSensor``) that you can swap into your DAG with no other changes required. + +Note that you cannot yet use the deferral ability from inside custom PythonOperator/TaskFlow code; it is only available to pre-built Operators at the moment. + + +Writing Deferrable Operators +---------------------------- + +Writing a deferrable operator takes a bit more work. There are some main points to consider: + +* Your Operator must defer itself based on a Trigger. If there is a Trigger in core Airflow you can use, great; otherwise, you will have to write one. +* Your Operator will be deleted and removed from its worker while deferred, and no state will persist automatically. You can persist state by asking Airflow to resume you at a certain method or pass certain kwargs, but that's it. +* You can defer multiple times, and you can defer before/after your Operator does significant work, or only defer if certain conditions are met (e.g. a system does not have an immediate answer). Deferral is entirely under your control. +* Any Operator can defer; no special marking on its class is needed, and it's not limited to Sensors. + + +Triggering Deferral +~~~~~~~~~~~~~~~~~~~ + +If you want to trigger deferral, at any place in your Operator you can call ``self.defer(trigger, method_name, kwargs, timeout)``, which will raise a special exception that Airflow will catch. The arguments are: + +* ``trigger``: An instance of a Trigger that you wish to defer on. It will be serialized into the database. +* ``method_name``: The method name on your Operator you want Airflow to call when it resumes, other than ``execute``. +* ``kwargs``: Additional keyword arguments to pass to the method when it is called. Optional, defaults to ``{}``. +* ``timeout``: A timedelta that specifies a timeout after which this deferral will fail, and fail the task instance. Optional, defaults to ``None``, meaning no timeout. Review comment: We need to document how this timeout and the existing execution_timeout interact. 
(I _think_ that `execution_timeout` should be a total ceiling of runtime, i.e. time since task first started, right?)
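As a purely illustrative sketch (not code from this PR -- the operator name and the `execute_complete` method below are made up), this is roughly where the two timeouts would sit under that reading: `execution_timeout` as a total ceiling set by the DAG author, and the `defer()` `timeout` only bounding a single deferral.

```python
# Hypothetical sketch only -- not code from this PR. The operator name and the
# "execute_complete" method are invented; defer(), TimeDeltaTrigger and the
# resumed-with-an-"event"-kwarg behaviour follow what the PR describes.
from datetime import timedelta

from airflow.models.baseoperator import BaseOperator
from airflow.triggers.temporal import TimeDeltaTrigger


class WaitThenResumeOperator(BaseOperator):
    """Illustrates where the per-deferral timeout and execution_timeout would apply."""

    def execute(self, context):
        # Bounds *this single deferral*: if the trigger hasn't fired within 30
        # minutes, the deferral times out and the task instance fails.
        self.defer(
            trigger=TimeDeltaTrigger(timedelta(minutes=20)),
            method_name="execute_complete",
            timeout=timedelta(minutes=30),
        )

    def execute_complete(self, context, event=None):
        # Resumed as a brand new instance; only method_name/kwargs plus the
        # trigger's event payload carry state across the deferral.
        self.log.info("Trigger fired with payload %s", event)


# A DAG author would set execution_timeout on the task as usual; under the
# "total ceiling" reading it keeps counting across time spent deferred too:
#
#   WaitThenResumeOperator(task_id="wait", execution_timeout=timedelta(hours=1))
```

Whichever semantics we settle on should be spelled out next to the `timeout` bullet in deferring.rst.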
