This is an automated email from the ASF dual-hosted git repository.
vincbeck pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new cd2ad3ccc72 AIP-82 Use `hash` instead of `repr` (#44797)
cd2ad3ccc72 is described below
commit cd2ad3ccc728e8ce30b5f17c5363776b5c6a45fc
Author: Vincent <[email protected]>
AuthorDate: Fri Dec 20 10:25:04 2024 -0500
AIP-82 Use `hash` instead of `repr` (#44797)
---
airflow/dag_processing/collection.py | 75 ++++++++++++++++++++++--------------
1 file changed, 47 insertions(+), 28 deletions(-)
diff --git a/airflow/dag_processing/collection.py
b/airflow/dag_processing/collection.py
index f3e3b8322ca..e289d50d7f1 100644
--- a/airflow/dag_processing/collection.py
+++ b/airflow/dag_processing/collection.py
@@ -27,9 +27,10 @@ This should generally only be called by internal methods
such as
from __future__ import annotations
+import json
import logging
import traceback
-from typing import TYPE_CHECKING, NamedTuple
+from typing import TYPE_CHECKING, Any, NamedTuple
from sqlalchemy import and_, delete, exists, func, select, tuple_
from sqlalchemy.exc import OperationalError
@@ -50,6 +51,7 @@ from airflow.models.dagwarning import DagWarningType
from airflow.models.errors import ParseImportError
from airflow.models.trigger import Trigger
from airflow.sdk.definitions.asset import Asset, AssetAlias
+from airflow.serialization.serialized_objects import BaseSerialization
from airflow.triggers.base import BaseTrigger
from airflow.utils.retries import MAX_DB_RETRIES, run_with_db_retries
from airflow.utils.sqlalchemy import with_row_locks
@@ -64,6 +66,7 @@ if TYPE_CHECKING:
from airflow.models.dagwarning import DagWarning
from airflow.serialization.serialized_objects import MaybeSerializedDAG
+ from airflow.triggers.base import BaseTrigger
from airflow.typing_compat import Self
log = logging.getLogger(__name__)
@@ -652,9 +655,9 @@ class AssetModelOperation(NamedTuple):
self, assets: dict[tuple[str, str], AssetModel], *, session: Session
) -> None:
# Update references from assets being used
- refs_to_add: dict[tuple[str, str], set[str]] = {}
- refs_to_remove: dict[tuple[str, str], set[str]] = {}
- triggers: dict[str, BaseTrigger] = {}
+ refs_to_add: dict[tuple[str, str], set[int]] = {}
+ refs_to_remove: dict[tuple[str, str], set[int]] = {}
+ triggers: dict[int, BaseTrigger] = {}
# Optimization: if no asset collected, skip fetching active assets
active_assets = _find_active_assets(self.assets.keys(),
session=session) if self.assets else {}
@@ -662,40 +665,40 @@ class AssetModelOperation(NamedTuple):
for name_uri, asset in self.assets.items():
# If the asset belong to a DAG not active or paused, consider
there is no watcher associated to it
asset_watchers = asset.watchers if name_uri in active_assets else
[]
- trigger_repr_to_trigger_dict: dict[str, BaseTrigger] = {
- repr(trigger): trigger for trigger in asset_watchers
+ trigger_hash_to_trigger_dict: dict[int, BaseTrigger] = {
+ self._get_base_trigger_hash(trigger): trigger for trigger in
asset_watchers
}
- triggers.update(trigger_repr_to_trigger_dict)
- trigger_repr_from_asset: set[str] =
set(trigger_repr_to_trigger_dict.keys())
+ triggers.update(trigger_hash_to_trigger_dict)
+ trigger_hash_from_asset: set[int] =
set(trigger_hash_to_trigger_dict.keys())
asset_model = assets[name_uri]
- trigger_repr_from_asset_model: set[str] = {
- BaseTrigger.repr(trigger.classpath, trigger.kwargs) for
trigger in asset_model.triggers
+ trigger_hash_from_asset_model: set[int] = {
+ self._get_trigger_hash(trigger.classpath, trigger.kwargs) for
trigger in asset_model.triggers
}
# Optimization: no diff between the DB and DAG definitions, no
update needed
- if trigger_repr_from_asset == trigger_repr_from_asset_model:
+ if trigger_hash_from_asset == trigger_hash_from_asset_model:
continue
- diff_to_add = trigger_repr_from_asset -
trigger_repr_from_asset_model
- diff_to_remove = trigger_repr_from_asset_model -
trigger_repr_from_asset
+ diff_to_add = trigger_hash_from_asset -
trigger_hash_from_asset_model
+ diff_to_remove = trigger_hash_from_asset_model -
trigger_hash_from_asset
if diff_to_add:
refs_to_add[name_uri] = diff_to_add
if diff_to_remove:
refs_to_remove[name_uri] = diff_to_remove
if refs_to_add:
- all_trigger_reprs: set[str] = {
- trigger_repr for trigger_reprs in refs_to_add.values() for
trigger_repr in trigger_reprs
+ all_trigger_hashes: set[int] = {
+ trigger_hash for trigger_hashes in refs_to_add.values() for
trigger_hash in trigger_hashes
}
all_trigger_keys: set[tuple[str, str]] = {
- self._encrypt_trigger_kwargs(triggers[trigger_repr])
- for trigger_reprs in refs_to_add.values()
- for trigger_repr in trigger_reprs
+ self._encrypt_trigger_kwargs(triggers[trigger_hash])
+ for trigger_hashes in refs_to_add.values()
+ for trigger_hash in trigger_hashes
}
- orm_triggers: dict[str, Trigger] = {
- BaseTrigger.repr(trigger.classpath, trigger.kwargs): trigger
+ orm_triggers: dict[int, Trigger] = {
+ self._get_trigger_hash(trigger.classpath, trigger.kwargs):
trigger
for trigger in session.scalars(
select(Trigger).where(
tuple_(Trigger.classpath,
Trigger.encrypted_kwargs).in_(all_trigger_keys)
@@ -707,32 +710,32 @@ class AssetModelOperation(NamedTuple):
new_trigger_models = [
trigger
for trigger in [
- Trigger.from_object(triggers[trigger_repr])
- for trigger_repr in all_trigger_reprs
- if trigger_repr not in orm_triggers
+ Trigger.from_object(triggers[trigger_hash])
+ for trigger_hash in all_trigger_hashes
+ if trigger_hash not in orm_triggers
]
]
session.add_all(new_trigger_models)
orm_triggers.update(
- (BaseTrigger.repr(trigger.classpath, trigger.kwargs), trigger)
+ (self._get_trigger_hash(trigger.classpath, trigger.kwargs),
trigger)
for trigger in new_trigger_models
)
# Add new references
- for name_uri, trigger_reprs in refs_to_add.items():
+ for name_uri, trigger_hashes in refs_to_add.items():
asset_model = assets[name_uri]
asset_model.triggers.extend(
- [orm_triggers.get(trigger_repr) for trigger_repr in
trigger_reprs]
+ [orm_triggers.get(trigger_hash) for trigger_hash in
trigger_hashes]
)
if refs_to_remove:
# Remove old references
- for name_uri, trigger_reprs in refs_to_remove.items():
+ for name_uri, trigger_hashes in refs_to_remove.items():
asset_model = assets[name_uri]
asset_model.triggers = [
trigger
for trigger in asset_model.triggers
- if BaseTrigger.repr(trigger.classpath, trigger.kwargs) not
in trigger_reprs
+ if self._get_trigger_hash(trigger.classpath,
trigger.kwargs) not in trigger_hashes
]
# Remove references from assets no longer used
@@ -747,3 +750,19 @@ class AssetModelOperation(NamedTuple):
def _encrypt_trigger_kwargs(trigger: BaseTrigger) -> tuple[str, str]:
classpath, kwargs = trigger.serialize()
return classpath, Trigger.encrypt_kwargs(kwargs)
+
+ @staticmethod
+ def _get_trigger_hash(classpath: str, kwargs: dict[str, Any]) -> int:
+ """
+ Return the hash of the trigger classpath and kwargs. This is used to
uniquely identify a trigger.
+
+ We do not want to move this logic into a `__hash__` method in
`BaseTrigger` because we do not want to
+ make the triggers hashable. The reason being, when the triggerer
retrieves the list of triggers, we do
+ not want it to dedupe them. When used to defer tasks, 2 triggers can have
the same classpath and kwargs.
+ This is not true for event-driven scheduling.
+ """
+ return hash((classpath,
json.dumps(BaseSerialization.serialize(kwargs)).encode("utf-8")))
+
+ def _get_base_trigger_hash(self, trigger: BaseTrigger) -> int:
+ classpath, kwargs = trigger.serialize()
+ return self._get_trigger_hash(classpath, kwargs)