This is an automated email from the ASF dual-hosted git repository.

utkarsharma pushed a commit to branch v2-10-test
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/v2-10-test by this push:
     new 957856fe77 allow dataset alias to add more than one dataset events (#42189) (#42247)
957856fe77 is described below

commit 957856fe77ccd3b359bff49b4fbcfa4c6fbcbbc4
Author: Utkarsh Sharma <[email protected]>
AuthorDate: Mon Sep 16 18:11:55 2024 +0530

    allow dataset alias to add more than one dataset events (#42189) (#42247)
    
    (cherry picked from commit a5d0a63d8784d7f4100a4770748c783261968e3c)
    
    Co-authored-by: Wei Lee <[email protected]>
---
 airflow/datasets/__init__.py                   |  1 +
 airflow/models/taskinstance.py                 |  6 +++---
 airflow/serialization/serialized_objects.py    | 17 +++++++++++----
 airflow/utils/context.py                       | 10 ++++-----
 airflow/utils/context.pyi                      |  4 ++--
 tests/serialization/test_serialized_objects.py | 14 ++++++-------
 tests/utils/test_context.py                    | 29 ++++++++++++++------------
 7 files changed, 45 insertions(+), 36 deletions(-)

diff --git a/airflow/datasets/__init__.py b/airflow/datasets/__init__.py
index 55d947544c..80ed083c4b 100644
--- a/airflow/datasets/__init__.py
+++ b/airflow/datasets/__init__.py
@@ -239,6 +239,7 @@ class DatasetAliasEvent(TypedDict):
 
     source_alias_name: str
     dest_dataset_uri: str
+    extra: dict[str, Any]
 
 
 @attr.define()
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index 0c34d35024..b34c71bc9f 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -3030,11 +3030,11 @@ class TaskInstance(Base, LoggingMixin):
                     session=session,
                 )
             elif isinstance(obj, DatasetAlias):
-                if dataset_alias_event := events[obj].dataset_alias_event:
+                for dataset_alias_event in events[obj].dataset_alias_events:
+                    dataset_alias_name = dataset_alias_event["source_alias_name"]
                     dataset_uri = dataset_alias_event["dest_dataset_uri"]
-                    extra = events[obj].extra
+                    extra = dataset_alias_event["extra"]
                     frozen_extra = frozenset(extra.items())
-                    dataset_alias_name = dataset_alias_event["source_alias_name"]
 
                     dataset_tuple_to_alias_names_mapping[(dataset_uri, 
frozen_extra)].add(dataset_alias_name)
 
diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py
index 84ad567918..9eb3332a6b 100644
--- a/airflow/serialization/serialized_objects.py
+++ b/airflow/serialization/serialized_objects.py
@@ -286,15 +286,24 @@ def encode_outlet_event_accessor(var: OutletEventAccessor) -> dict[str, Any]:
     raw_key = var.raw_key
     return {
         "extra": var.extra,
-        "dataset_alias_event": var.dataset_alias_event,
+        "dataset_alias_events": var.dataset_alias_events,
         "raw_key": BaseSerialization.serialize(raw_key),
     }
 
 
 def decode_outlet_event_accessor(var: dict[str, Any]) -> OutletEventAccessor:
-    raw_key = BaseSerialization.deserialize(var["raw_key"])
-    outlet_event_accessor = OutletEventAccessor(extra=var["extra"], 
raw_key=raw_key)
-    outlet_event_accessor.dataset_alias_event = var["dataset_alias_event"]
+    # This is added for compatibility. The attribute used to be dataset_alias_event and
+    # is now dataset_alias_events.
+    if dataset_alias_event := var.get("dataset_alias_event", None):
+        dataset_alias_events = [dataset_alias_event]
+    else:
+        dataset_alias_events = var.get("dataset_alias_events", [])
+
+    outlet_event_accessor = OutletEventAccessor(
+        extra=var["extra"],
+        raw_key=BaseSerialization.deserialize(var["raw_key"]),
+        dataset_alias_events=dataset_alias_events,
+    )
     return outlet_event_accessor
 
 
diff --git a/airflow/utils/context.py b/airflow/utils/context.py
index c2a0ad7052..a72885401f 100644
--- a/airflow/utils/context.py
+++ b/airflow/utils/context.py
@@ -172,7 +172,7 @@ class OutletEventAccessor:
 
     raw_key: str | Dataset | DatasetAlias
     extra: dict[str, Any] = attrs.Factory(dict)
-    dataset_alias_event: DatasetAliasEvent | None = None
+    dataset_alias_events: list[DatasetAliasEvent] = attrs.field(factory=list)
 
     def add(self, dataset: Dataset | str, extra: dict[str, Any] | None = None) 
-> None:
         """Add a DatasetEvent to an existing Dataset."""
@@ -190,12 +190,10 @@ class OutletEventAccessor:
         else:
             return
 
-        if extra:
-            self.extra = extra
-
-        self.dataset_alias_event = DatasetAliasEvent(
-            source_alias_name=dataset_alias_name, dest_dataset_uri=dataset_uri
+        event = DatasetAliasEvent(
+            source_alias_name=dataset_alias_name, 
dest_dataset_uri=dataset_uri, extra=extra or {}
         )
+        self.dataset_alias_events.append(event)
 
 
 class OutletEventAccessors(Mapping[str, OutletEventAccessor]):
diff --git a/airflow/utils/context.pyi b/airflow/utils/context.pyi
index d3546286cf..658aac5839 100644
--- a/airflow/utils/context.pyi
+++ b/airflow/utils/context.pyi
@@ -63,12 +63,12 @@ class OutletEventAccessor:
         *,
         extra: dict[str, Any],
         raw_key: str | Dataset | DatasetAlias,
-        dataset_alias_event: DatasetAliasEvent | None = None,
+        dataset_alias_events: list[DatasetAliasEvent],
     ) -> None: ...
     def add(self, dataset: Dataset | str, extra: dict[str, Any] | None = None) 
-> None: ...
     extra: dict[str, Any]
     raw_key: str | Dataset | DatasetAlias
-    dataset_alias_event: DatasetAliasEvent | None
+    dataset_alias_events: list[DatasetAliasEvent]
 
 class OutletEventAccessors(Mapping[str, OutletEventAccessor]):
     def __iter__(self) -> Iterator[str]: ...
diff --git a/tests/serialization/test_serialized_objects.py b/tests/serialization/test_serialized_objects.py
index 661ecbf5dc..82d8c16f3f 100644
--- a/tests/serialization/test_serialized_objects.py
+++ b/tests/serialization/test_serialized_objects.py
@@ -163,7 +163,7 @@ def equal_exception(a: AirflowException, b: 
AirflowException) -> bool:
 
 
 def equal_outlet_event_accessor(a: OutletEventAccessor, b: 
OutletEventAccessor) -> bool:
-    return a.raw_key == b.raw_key and a.extra == b.extra and 
a.dataset_alias_event == b.dataset_alias_event
+    return a.raw_key == b.raw_key and a.extra == b.extra and 
a.dataset_alias_events == b.dataset_alias_events
 
 
 class MockLazySelectSequence(LazySelectSequence):
@@ -240,9 +240,7 @@ class MockLazySelectSequence(LazySelectSequence):
             lambda a, b: a.get_uri() == b.get_uri(),
         ),
         (
-            OutletEventAccessor(
-                raw_key=Dataset(uri="test"), extra={"key": "value"}, 
dataset_alias_event=None
-            ),
+            OutletEventAccessor(raw_key=Dataset(uri="test"), extra={"key": 
"value"}, dataset_alias_events=[]),
             DAT.DATASET_EVENT_ACCESSOR,
             equal_outlet_event_accessor,
         ),
@@ -250,15 +248,15 @@ class MockLazySelectSequence(LazySelectSequence):
             OutletEventAccessor(
                 raw_key=DatasetAlias(name="test_alias"),
                 extra={"key": "value"},
-                dataset_alias_event=DatasetAliasEvent(
-                    source_alias_name="test_alias", dest_dataset_uri="test_uri"
-                ),
+                dataset_alias_events=[
+                    DatasetAliasEvent(source_alias_name="test_alias", 
dest_dataset_uri="test_uri", extra={})
+                ],
             ),
             DAT.DATASET_EVENT_ACCESSOR,
             equal_outlet_event_accessor,
         ),
         (
-            OutletEventAccessor(raw_key="test", extra={"key": "value"}),
+            OutletEventAccessor(raw_key="test", extra={"key": "value"}, 
dataset_alias_events=[]),
             DAT.DATASET_EVENT_ACCESSOR,
             equal_outlet_event_accessor,
         ),
diff --git a/tests/utils/test_context.py b/tests/utils/test_context.py
index 1237be2f8d..0f4f80f365 100644
--- a/tests/utils/test_context.py
+++ b/tests/utils/test_context.py
@@ -27,41 +27,44 @@ from airflow.utils.context import OutletEventAccessor, 
OutletEventAccessors
 
 class TestOutletEventAccessor:
     @pytest.mark.parametrize(
-        "raw_key, dataset_alias_event",
+        "raw_key, dataset_alias_events",
         (
             (
                 DatasetAlias("test_alias"),
-                DatasetAliasEvent(source_alias_name="test_alias", 
dest_dataset_uri="test_uri"),
+                [DatasetAliasEvent(source_alias_name="test_alias", 
dest_dataset_uri="test_uri", extra={})],
             ),
-            (Dataset("test_uri"), None),
+            (Dataset("test_uri"), []),
         ),
     )
-    def test_add(self, raw_key, dataset_alias_event):
+    def test_add(self, raw_key, dataset_alias_events):
         outlet_event_accessor = OutletEventAccessor(raw_key=raw_key, extra={})
         outlet_event_accessor.add(Dataset("test_uri"))
-        assert outlet_event_accessor.dataset_alias_event == dataset_alias_event
+        assert outlet_event_accessor.dataset_alias_events == 
dataset_alias_events
 
     @pytest.mark.db_test
     @pytest.mark.parametrize(
-        "raw_key, dataset_alias_event",
+        "raw_key, dataset_alias_events",
         (
             (
                 DatasetAlias("test_alias"),
-                DatasetAliasEvent(source_alias_name="test_alias", 
dest_dataset_uri="test_uri"),
+                [DatasetAliasEvent(source_alias_name="test_alias", 
dest_dataset_uri="test_uri", extra={})],
             ),
-            ("test_alias", DatasetAliasEvent(source_alias_name="test_alias", 
dest_dataset_uri="test_uri")),
-            (Dataset("test_uri"), None),
+            (
+                "test_alias",
+                [DatasetAliasEvent(source_alias_name="test_alias", 
dest_dataset_uri="test_uri", extra={})],
+            ),
+            (Dataset("test_uri"), []),
         ),
     )
-    def test_add_with_db(self, raw_key, dataset_alias_event, session):
+    def test_add_with_db(self, raw_key, dataset_alias_events, session):
         dsm = DatasetModel(uri="test_uri")
         dsam = DatasetAliasModel(name="test_alias")
         session.add_all([dsm, dsam])
         session.flush()
 
-        outlet_event_accessor = OutletEventAccessor(raw_key=raw_key, extra={})
-        outlet_event_accessor.add("test_uri")
-        assert outlet_event_accessor.dataset_alias_event == dataset_alias_event
+        outlet_event_accessor = OutletEventAccessor(raw_key=raw_key, 
extra={"not": ""})
+        outlet_event_accessor.add("test_uri", extra={})
+        assert outlet_event_accessor.dataset_alias_events == 
dataset_alias_events
 
 
 class TestOutletEventAccessors:

Reply via email to