This is an automated email from the ASF dual-hosted git repository.

uranusjr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 1c4a00bb4e Try to make dataset objects totally unhashable (#42054)
1c4a00bb4e is described below

commit 1c4a00bb4e15de083aea0c2a0ffe14ea97955c70
Author: Tzu-ping Chung <[email protected]>
AuthorDate: Fri Sep 6 15:16:53 2024 +0800

    Try to make dataset objects totally unhashable (#42054)
---
 airflow/datasets/__init__.py                | 22 +++-------------------
 airflow/lineage/hook.py                     |  2 +-
 airflow/models/dag.py                       |  4 ++--
 newsfragments/42054.significant.rst         |  4 ++++
 tests/datasets/test_dataset.py              | 14 ++++----------
 tests/lineage/test_hook.py                  |  4 ++--
 tests/timetables/test_datasets_timetable.py |  2 +-
 7 files changed, 17 insertions(+), 35 deletions(-)

diff --git a/airflow/datasets/__init__.py b/airflow/datasets/__init__.py
index 55d947544c..d4305eeb04 100644
--- a/airflow/datasets/__init__.py
+++ b/airflow/datasets/__init__.py
@@ -206,20 +206,12 @@ class BaseDataset:
         raise NotImplementedError
 
 
-@attr.define()
+@attr.define(unsafe_hash=False)
 class DatasetAlias(BaseDataset):
     """A represeation of dataset alias which is used to create dataset during 
the runtime."""
 
     name: str
 
-    def __eq__(self, other: Any) -> bool:
-        if isinstance(other, DatasetAlias):
-            return self.name == other.name
-        return NotImplemented
-
-    def __hash__(self) -> int:
-        return hash(self.name)
-
     def iter_dag_dependencies(self, *, source: str, target: str) -> Iterator[DagDependency]:
         """
         Iterate a dataset alias as dag dependency.
@@ -241,7 +233,7 @@ class DatasetAliasEvent(TypedDict):
     dest_dataset_uri: str
 
 
-@attr.define()
+@attr.define(unsafe_hash=False)
 class Dataset(os.PathLike, BaseDataset):
     """A representation of data dependencies between workflows."""
 
@@ -249,21 +241,13 @@ class Dataset(os.PathLike, BaseDataset):
         converter=_sanitize_uri,
         validator=[attr.validators.min_len(1), attr.validators.max_len(3000)],
     )
-    extra: dict[str, Any] | None = None
+    extra: dict[str, Any] = attr.field(factory=dict)
 
     __version__: ClassVar[int] = 1
 
     def __fspath__(self) -> str:
         return self.uri
 
-    def __eq__(self, other: Any) -> bool:
-        if isinstance(other, self.__class__):
-            return self.uri == other.uri
-        return NotImplemented
-
-    def __hash__(self) -> int:
-        return hash(self.uri)
-
     @property
     def normalized_uri(self) -> str | None:
         """
diff --git a/airflow/lineage/hook.py b/airflow/lineage/hook.py
index 4ff35e4d9c..45227c5248 100644
--- a/airflow/lineage/hook.py
+++ b/airflow/lineage/hook.py
@@ -112,7 +112,7 @@ class HookLineageCollector(LoggingMixin):
         """
         if uri:
             # Fallback to default factory using the provided URI
-            return Dataset(uri=uri, extra=dataset_extra)
+            return Dataset(uri=uri, extra=dataset_extra or {})
 
         if not scheme:
             self.log.debug(
diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index 54ebce7392..56f7dc89d2 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -2786,12 +2786,12 @@ class DAG(LoggingMixin):
             curr_outlet_references = curr_orm_dag and curr_orm_dag.task_outlet_dataset_references
             for task in dag.tasks:
                 dataset_outlets: list[Dataset] = []
-                dataset_alias_outlets: set[DatasetAlias] = set()
+                dataset_alias_outlets: list[DatasetAlias] = []
                 for outlet in task.outlets:
                     if isinstance(outlet, Dataset):
                         dataset_outlets.append(outlet)
                     elif isinstance(outlet, DatasetAlias):
-                        dataset_alias_outlets.add(outlet)
+                        dataset_alias_outlets.append(outlet)
 
                 if not dataset_outlets:
                     if curr_outlet_references:
diff --git a/newsfragments/42054.significant.rst b/newsfragments/42054.significant.rst
new file mode 100644
index 0000000000..aebf70757f
--- /dev/null
+++ b/newsfragments/42054.significant.rst
@@ -0,0 +1,4 @@
+Dataset and DatasetAlias are no longer hashable
+
+This means they can no longer be used as dict keys or put into a set. Dataset's
+equality logic is also tweaked slightly to consider the extra dict.
diff --git a/tests/datasets/test_dataset.py b/tests/datasets/test_dataset.py
index 017c874763..940e445669 100644
--- a/tests/datasets/test_dataset.py
+++ b/tests/datasets/test_dataset.py
@@ -120,12 +120,6 @@ def test_not_equal_when_different_uri():
     assert dataset1 != dataset2
 
 
-def test_hash():
-    uri = "s3://example/dataset"
-    dataset = Dataset(uri=uri)
-    hash(dataset)
-
-
 def test_dataset_logic_operations():
     result_or = dataset1 | dataset2
     assert isinstance(result_or, DatasetAny)
@@ -187,10 +181,10 @@ def test_datasetbooleancondition_evaluate_iter():
     assert all_condition.evaluate({"s3://bucket1/data1": True, "s3://bucket2/data2": False}) is False
 
     # Testing iter_datasets indirectly through the subclasses
-    datasets_any = set(any_condition.iter_datasets())
-    datasets_all = set(all_condition.iter_datasets())
-    assert datasets_any == {("s3://bucket1/data1", dataset1), ("s3://bucket2/data2", dataset2)}
-    assert datasets_all == {("s3://bucket1/data1", dataset1), ("s3://bucket2/data2", dataset2)}
+    datasets_any = dict(any_condition.iter_datasets())
+    datasets_all = dict(all_condition.iter_datasets())
+    assert datasets_any == {"s3://bucket1/data1": dataset1, "s3://bucket2/data2": dataset2}
+    assert datasets_all == {"s3://bucket1/data1": dataset1, "s3://bucket2/data2": dataset2}
 
 
 @pytest.mark.parametrize(
diff --git a/tests/lineage/test_hook.py b/tests/lineage/test_hook.py
index 97e160e1f9..16f386b684 100644
--- a/tests/lineage/test_hook.py
+++ b/tests/lineage/test_hook.py
@@ -69,7 +69,7 @@ class TestHookLineageCollector:
         self.collector.add_input_dataset(hook, uri="test_uri")
 
         assert next(iter(self.collector._inputs.values())) == (dataset, hook)
-        mock_dataset.assert_called_once_with(uri="test_uri", extra=None)
+        mock_dataset.assert_called_once_with(uri="test_uri", extra={})
 
     def test_grouping_datasets(self):
         hook_1 = MagicMock()
@@ -96,7 +96,7 @@ class TestHookLineageCollector:
     @patch("airflow.lineage.hook.ProvidersManager")
     def test_create_dataset(self, mock_providers_manager):
         def create_dataset(arg1, arg2="default", extra=None):
-            return Dataset(uri=f"myscheme://{arg1}/{arg2}", extra=extra)
+            return Dataset(uri=f"myscheme://{arg1}/{arg2}", extra=extra or {})
 
         mock_providers_manager.return_value.dataset_factories = {"myscheme": create_dataset}
         assert self.collector.create_dataset(
diff --git a/tests/timetables/test_datasets_timetable.py b/tests/timetables/test_datasets_timetable.py
index b055f0d34d..b456b9bf5d 100644
--- a/tests/timetables/test_datasets_timetable.py
+++ b/tests/timetables/test_datasets_timetable.py
@@ -134,7 +134,7 @@ def test_serialization(dataset_timetable: DatasetOrTimeSchedule, monkeypatch: An
         "timetable": "mock_serialized_timetable",
         "dataset_condition": {
             "__type": "dataset_all",
-            "objects": [{"__type": "dataset", "uri": "test_dataset", "extra": None}],
+            "objects": [{"__type": "dataset", "uri": "test_dataset", "extra": {}}],
         },
     }
 

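For downstream code, the practical effect of this change looks roughly like
the sketch below. This is a hand-written illustration against the post-commit
airflow.datasets module, not part of the commit itself; the by_uri mapping is
one hypothetical workaround, mirroring the set-to-list switch in dag.py above.

    from airflow.datasets import Dataset

    d1 = Dataset(uri="s3://example/dataset")
    d2 = Dataset(uri="s3://example/dataset", extra={"team": "data"})

    # Equality is now attrs-generated over all fields, so a differing
    # `extra` makes two datasets with the same URI compare unequal.
    assert d1 != d2

    # Dataset and DatasetAlias are now unhashable: hash() raises
    # TypeError, so they can no longer be set members or dict keys.
    try:
        hash(d1)
    except TypeError:
        pass  # expected after this change

    # Code that previously collected datasets in a set can keep a list,
    # or key a dict by the (hashable) URI string instead; with duplicate
    # URIs the later dataset wins, as in any dict comprehension.
    by_uri: dict[str, Dataset] = {d.uri: d for d in [d1, d2]}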