This is an automated email from the ASF dual-hosted git repository.

vincbeck pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new cd2ad3ccc72 AIP-82 Use `hash` instead of `repr` (#44797)
cd2ad3ccc72 is described below

commit cd2ad3ccc728e8ce30b5f17c5363776b5c6a45fc
Author: Vincent <[email protected]>
AuthorDate: Fri Dec 20 10:25:04 2024 -0500

    AIP-82 Use `hash` instead of `repr` (#44797)
---
 airflow/dag_processing/collection.py | 75 ++++++++++++++++++++++--------------
 1 file changed, 47 insertions(+), 28 deletions(-)

diff --git a/airflow/dag_processing/collection.py 
b/airflow/dag_processing/collection.py
index f3e3b8322ca..e289d50d7f1 100644
--- a/airflow/dag_processing/collection.py
+++ b/airflow/dag_processing/collection.py
@@ -27,9 +27,10 @@ This should generally only be called by internal methods 
such as
 
 from __future__ import annotations
 
+import json
 import logging
 import traceback
-from typing import TYPE_CHECKING, NamedTuple
+from typing import TYPE_CHECKING, Any, NamedTuple
 
 from sqlalchemy import and_, delete, exists, func, select, tuple_
 from sqlalchemy.exc import OperationalError
@@ -50,6 +51,7 @@ from airflow.models.dagwarning import DagWarningType
 from airflow.models.errors import ParseImportError
 from airflow.models.trigger import Trigger
 from airflow.sdk.definitions.asset import Asset, AssetAlias
+from airflow.serialization.serialized_objects import BaseSerialization
 from airflow.triggers.base import BaseTrigger
 from airflow.utils.retries import MAX_DB_RETRIES, run_with_db_retries
 from airflow.utils.sqlalchemy import with_row_locks
@@ -64,6 +66,7 @@ if TYPE_CHECKING:
 
     from airflow.models.dagwarning import DagWarning
     from airflow.serialization.serialized_objects import MaybeSerializedDAG
+    from airflow.triggers.base import BaseTrigger
     from airflow.typing_compat import Self
 
 log = logging.getLogger(__name__)
@@ -652,9 +655,9 @@ class AssetModelOperation(NamedTuple):
         self, assets: dict[tuple[str, str], AssetModel], *, session: Session
     ) -> None:
         # Update references from assets being used
-        refs_to_add: dict[tuple[str, str], set[str]] = {}
-        refs_to_remove: dict[tuple[str, str], set[str]] = {}
-        triggers: dict[str, BaseTrigger] = {}
+        refs_to_add: dict[tuple[str, str], set[int]] = {}
+        refs_to_remove: dict[tuple[str, str], set[int]] = {}
+        triggers: dict[int, BaseTrigger] = {}
 
         # Optimization: if no asset collected, skip fetching active assets
         active_assets = _find_active_assets(self.assets.keys(), 
session=session) if self.assets else {}
@@ -662,40 +665,40 @@ class AssetModelOperation(NamedTuple):
         for name_uri, asset in self.assets.items():
             # If the asset belong to a DAG not active or paused, consider 
there is no watcher associated to it
             asset_watchers = asset.watchers if name_uri in active_assets else 
[]
-            trigger_repr_to_trigger_dict: dict[str, BaseTrigger] = {
-                repr(trigger): trigger for trigger in asset_watchers
+            trigger_hash_to_trigger_dict: dict[int, BaseTrigger] = {
+                self._get_base_trigger_hash(trigger): trigger for trigger in 
asset_watchers
             }
-            triggers.update(trigger_repr_to_trigger_dict)
-            trigger_repr_from_asset: set[str] = 
set(trigger_repr_to_trigger_dict.keys())
+            triggers.update(trigger_hash_to_trigger_dict)
+            trigger_hash_from_asset: set[int] = 
set(trigger_hash_to_trigger_dict.keys())
 
             asset_model = assets[name_uri]
-            trigger_repr_from_asset_model: set[str] = {
-                BaseTrigger.repr(trigger.classpath, trigger.kwargs) for 
trigger in asset_model.triggers
+            trigger_hash_from_asset_model: set[int] = {
+                self._get_trigger_hash(trigger.classpath, trigger.kwargs) for 
trigger in asset_model.triggers
             }
 
             # Optimization: no diff between the DB and DAG definitions, no 
update needed
-            if trigger_repr_from_asset == trigger_repr_from_asset_model:
+            if trigger_hash_from_asset == trigger_hash_from_asset_model:
                 continue
 
-            diff_to_add = trigger_repr_from_asset - 
trigger_repr_from_asset_model
-            diff_to_remove = trigger_repr_from_asset_model - 
trigger_repr_from_asset
+            diff_to_add = trigger_hash_from_asset - 
trigger_hash_from_asset_model
+            diff_to_remove = trigger_hash_from_asset_model - 
trigger_hash_from_asset
             if diff_to_add:
                 refs_to_add[name_uri] = diff_to_add
             if diff_to_remove:
                 refs_to_remove[name_uri] = diff_to_remove
 
         if refs_to_add:
-            all_trigger_reprs: set[str] = {
-                trigger_repr for trigger_reprs in refs_to_add.values() for 
trigger_repr in trigger_reprs
+            all_trigger_hashes: set[int] = {
+                trigger_hash for trigger_hashes in refs_to_add.values() for 
trigger_hash in trigger_hashes
             }
 
             all_trigger_keys: set[tuple[str, str]] = {
-                self._encrypt_trigger_kwargs(triggers[trigger_repr])
-                for trigger_reprs in refs_to_add.values()
-                for trigger_repr in trigger_reprs
+                self._encrypt_trigger_kwargs(triggers[trigger_hash])
+                for trigger_hashes in refs_to_add.values()
+                for trigger_hash in trigger_hashes
             }
-            orm_triggers: dict[str, Trigger] = {
-                BaseTrigger.repr(trigger.classpath, trigger.kwargs): trigger
+            orm_triggers: dict[int, Trigger] = {
+                self._get_trigger_hash(trigger.classpath, trigger.kwargs): 
trigger
                 for trigger in session.scalars(
                     select(Trigger).where(
                         tuple_(Trigger.classpath, 
Trigger.encrypted_kwargs).in_(all_trigger_keys)
@@ -707,32 +710,32 @@ class AssetModelOperation(NamedTuple):
             new_trigger_models = [
                 trigger
                 for trigger in [
-                    Trigger.from_object(triggers[trigger_repr])
-                    for trigger_repr in all_trigger_reprs
-                    if trigger_repr not in orm_triggers
+                    Trigger.from_object(triggers[trigger_hash])
+                    for trigger_hash in all_trigger_hashes
+                    if trigger_hash not in orm_triggers
                 ]
             ]
             session.add_all(new_trigger_models)
             orm_triggers.update(
-                (BaseTrigger.repr(trigger.classpath, trigger.kwargs), trigger)
+                (self._get_trigger_hash(trigger.classpath, trigger.kwargs), 
trigger)
                 for trigger in new_trigger_models
             )
 
             # Add new references
-            for name_uri, trigger_reprs in refs_to_add.items():
+            for name_uri, trigger_hashes in refs_to_add.items():
                 asset_model = assets[name_uri]
                 asset_model.triggers.extend(
-                    [orm_triggers.get(trigger_repr) for trigger_repr in 
trigger_reprs]
+                    [orm_triggers.get(trigger_hash) for trigger_hash in 
trigger_hashes]
                 )
 
         if refs_to_remove:
             # Remove old references
-            for name_uri, trigger_reprs in refs_to_remove.items():
+            for name_uri, trigger_hashes in refs_to_remove.items():
                 asset_model = assets[name_uri]
                 asset_model.triggers = [
                     trigger
                     for trigger in asset_model.triggers
-                    if BaseTrigger.repr(trigger.classpath, trigger.kwargs) not 
in trigger_reprs
+                    if self._get_trigger_hash(trigger.classpath, 
trigger.kwargs) not in trigger_hashes
                 ]
 
         # Remove references from assets no longer used
@@ -747,3 +750,19 @@ class AssetModelOperation(NamedTuple):
     def _encrypt_trigger_kwargs(trigger: BaseTrigger) -> tuple[str, str]:
         classpath, kwargs = trigger.serialize()
         return classpath, Trigger.encrypt_kwargs(kwargs)
+
+    @staticmethod
+    def _get_trigger_hash(classpath: str, kwargs: dict[str, Any]) -> int:
+        """
+        Return the hash of the trigger classpath and kwargs. This is used to 
uniquely identify a trigger.
+
+        We do not want to move this logic into a `__hash__` method in `BaseTrigger` because we do not want to
+        make the triggers hashable. The reason being, when the triggerer retrieves the list of triggers, we do
+        not want it to dedupe them. When used to defer tasks, 2 triggers can have the same classpath and kwargs.
+        This is not true for event-driven scheduling.
+        """
+        return hash((classpath, 
json.dumps(BaseSerialization.serialize(kwargs)).encode("utf-8")))
+
+    def _get_base_trigger_hash(self, trigger: BaseTrigger) -> int:
+        classpath, kwargs = trigger.serialize()
+        return self._get_trigger_hash(classpath, kwargs)

Reply via email to