This is an automated email from the ASF dual-hosted git repository.

weilee pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 6163b396911 Add Dataset, Model asset subclasses (#43142)
6163b396911 is described below

commit 6163b396911bd7046c26396b4797263dc5d376d0
Author: Wei Lee <[email protected]>
AuthorDate: Wed Oct 23 14:10:18 2024 +0800

    Add Dataset, Model asset subclasses (#43142)
    
    * feat(assets): add dataset subclass
    * feat(assets): add model subclass
    * feat(assets): make group a default instead of overwriting user input
    * feat(assets): allow "airflow.Dataset" and "airflow.datasets.Dataset", "airflow.datasets.DatasetAlias" import for backward compat
---
 airflow/__init__.py          |  5 +++-
 airflow/assets/__init__.py   | 19 +++++++++++--
 airflow/datasets/__init__.py | 45 ++++++++++++++++++++++++++++++
 tests/assets/test_asset.py   | 65 ++++++++++++++++++++++++++++++++++++++++++++
 4 files changed, 130 insertions(+), 4 deletions(-)
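
A minimal usage sketch of the new subclasses (not part of the commit; based on the behaviour exercised by the tests added below, assuming an Airflow build that includes this change):

    from airflow.assets import Dataset, Model

    # name-only: uri falls back to the name, group defaults to the subclass's asset_type
    ds = Dataset(name="foobar")
    assert (ds.name, ds.uri, ds.group) == ("foobar", "foobar", "dataset")

    # uri-only: name falls back to the uri, group again defaults per subclass
    ml = Model(uri="s3://bucket/key/path")
    assert (ml.name, ml.uri, ml.group) == ("s3://bucket/key/path", "s3://bucket/key/path", "model")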

diff --git a/airflow/__init__.py b/airflow/__init__.py
index 287aa499faa..33005ff5e91 100644
--- a/airflow/__init__.py
+++ b/airflow/__init__.py
@@ -65,6 +65,8 @@ __all__ = [
     "DAG",
     "Asset",
     "XComArg",
+    # TODO: Remove this module in Airflow 3.2
+    "Dataset",
 ]
 
 # Perform side-effects unless someone has explicitly opted out before import
@@ -83,12 +85,13 @@ __lazy_imports: dict[str, tuple[str, str, bool]] = {
     "version": (".version", "", False),
     # Deprecated lazy imports
     "AirflowException": (".exceptions", "AirflowException", True),
+    "Dataset": (".assets", "Dataset", True),
 }
 if TYPE_CHECKING:
     # These objects are imported by PEP-562, however, static analyzers and IDE's
     # have no idea about typing of these objects.
     # Add it under TYPE_CHECKING block should help with it.
-    from airflow.models.asset import Asset
+    from airflow.assets import Asset, Dataset
     from airflow.models.dag import DAG
     from airflow.models.xcom_arg import XComArg
 
diff --git a/airflow/assets/__init__.py b/airflow/assets/__init__.py
index dcc5484656e..58256929948 100644
--- a/airflow/assets/__init__.py
+++ b/airflow/assets/__init__.py
@@ -39,7 +39,7 @@ if TYPE_CHECKING:
 
 from airflow.configuration import conf
 
-__all__ = ["Asset", "AssetAll", "AssetAny"]
+__all__ = ["Asset", "AssetAll", "AssetAny", "Dataset"]
 
 
 log = logging.getLogger(__name__)
@@ -275,13 +275,14 @@ def _set_extra_default(extra: dict | None) -> dict:
 
 @attr.define(init=False, unsafe_hash=False)
 class Asset(os.PathLike, BaseAsset):
-    """A representation of data dependencies between workflows."""
+    """A representation of data asset dependencies between workflows."""
 
     name: str
     uri: str
     group: str
     extra: dict[str, Any]
 
+    asset_type: ClassVar[str] = ""
     __version__: ClassVar[int] = 1
 
     @overload
@@ -313,7 +314,7 @@ class Asset(os.PathLike, BaseAsset):
         fields = attr.fields_dict(Asset)
         self.name = _validate_non_empty_identifier(self, fields["name"], name)
         self.uri = _sanitize_uri(_validate_non_empty_identifier(self, fields["uri"], uri))
-        self.group = _validate_identifier(self, fields["group"], group)
+        self.group = _validate_identifier(self, fields["group"], group) if group else self.asset_type
         self.extra = _set_extra_default(extra)
 
     def __fspath__(self) -> str:
@@ -372,6 +373,18 @@ class Asset(os.PathLike, BaseAsset):
         )
 
 
+class Dataset(Asset):
+    """A representation of dataset dependencies between workflows."""
+
+    asset_type: ClassVar[str] = "dataset"
+
+
+class Model(Asset):
+    """A representation of model dependencies between workflows."""
+
+    asset_type: ClassVar[str] = "model"
+
+
 class _AssetBooleanCondition(BaseAsset):
     """Base class for asset boolean logic."""
 
diff --git a/airflow/datasets/__init__.py b/airflow/datasets/__init__.py
new file mode 100644
index 00000000000..34729e43780
--- /dev/null
+++ b/airflow/datasets/__init__.py
@@ -0,0 +1,45 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# We do not use "from __future__ import annotations" here because it is not 
supported
+# by Pycharm when we want to make sure all imports in airflow work from 
namespace packages
+# Adding it automatically is excluded in pyproject.toml via I002 ruff rule 
exclusion
+
+# Make `airflow` a namespace package, supporting installing
+# airflow.providers.* in different locations (i.e. one in site, and one in user
+# lib.)  This is required by some IDEs to resolve the import paths.
+from __future__ import annotations
+
+import warnings
+
+from airflow.assets import AssetAlias as DatasetAlias, Dataset
+
+# TODO: Remove this module in Airflow 3.2
+
+warnings.warn(
+    "Import from the airflow.dataset module is deprecated and "
+    "will be removed in the Airflow 3.2. Please import it from 
'airflow.assets'.",
+    DeprecationWarning,
+    stacklevel=2,
+)
+
+
+__all__ = [
+    "Dataset",
+    "DatasetAlias",
+]
diff --git a/tests/assets/test_asset.py b/tests/assets/test_asset.py
index 9243fd59857..f1a0ed13bfa 100644
--- a/tests/assets/test_asset.py
+++ b/tests/assets/test_asset.py
@@ -31,6 +31,8 @@ from airflow.assets import (
     AssetAll,
     AssetAny,
     BaseAsset,
+    Dataset,
+    Model,
     _AssetAliasCondition,
     _get_normalized_scheme,
     _sanitize_uri,
@@ -612,3 +614,66 @@ class Test_AssetAliasCondition:
 
         cond = _AssetAliasCondition(resolved_asset_alias_2.name)
         assert cond.evaluate({asset_1.uri: True}) is True
+
+
+class TestAssetSubclasses:
+    @pytest.mark.parametrize("subcls, group", ((Model, "model"), (Dataset, 
"dataset")))
+    def test_only_name(self, subcls, group):
+        obj = subcls(name="foobar")
+        assert obj.name == "foobar"
+        assert obj.uri == "foobar"
+        assert obj.group == group
+
+    @pytest.mark.parametrize("subcls, group", ((Model, "model"), (Dataset, 
"dataset")))
+    def test_only_uri(self, subcls, group):
+        obj = subcls(uri="s3://bucket/key/path")
+        assert obj.name == "s3://bucket/key/path"
+        assert obj.uri == "s3://bucket/key/path"
+        assert obj.group == group
+
+    @pytest.mark.parametrize("subcls, group", ((Model, "model"), (Dataset, 
"dataset")))
+    def test_both_name_and_uri(self, subcls, group):
+        obj = subcls("foobar", "s3://bucket/key/path")
+        assert obj.name == "foobar"
+        assert obj.uri == "s3://bucket/key/path"
+        assert obj.group == group
+
+    @pytest.mark.parametrize("arg", ["foobar", "s3://bucket/key/path"])
+    @pytest.mark.parametrize("subcls, group", ((Model, "model"), (Dataset, 
"dataset")))
+    def test_only_posarg(self, subcls, group, arg):
+        obj = subcls(arg)
+        assert obj.name == arg
+        assert obj.uri == arg
+        assert obj.group == group
+
+
+@pytest.mark.parametrize(
+    "module_path, attr_name, warning_message",
+    (
+        (
+            "airflow",
+            "Dataset",
+            (
+                "Import 'Dataset' directly from the airflow module is 
deprecated and will be removed in the future. "
+                "Please import it from 'airflow.assets.Dataset'."
+            ),
+        ),
+        (
+            "airflow.datasets",
+            "Dataset",
+            (
+                "Import from the airflow.dataset module is deprecated and "
+                "will be removed in the Airflow 3.2. Please import it from 
'airflow.assets'."
+            ),
+        ),
+    ),
+)
+def test_backward_compat_import_before_airflow_3_2(module_path, attr_name, warning_message):
+    with pytest.warns() as record:
+        import importlib
+
+        mod = importlib.import_module(module_path, __name__)
+        getattr(mod, attr_name)
+
+    assert record[0].category is DeprecationWarning
+    assert str(record[0].message) == warning_message
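
For reference, a minimal sketch (not part of the commit) of the backward-compat import path added above. The deprecation warning is raised when the airflow.datasets shim module is first imported, so this assumes a fresh interpreter where it has not been imported yet:

    import warnings

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        # deprecated shim kept for backward compatibility until Airflow 3.2
        from airflow.datasets import Dataset, DatasetAlias  # noqa: F401

    assert any(issubclass(w.category, DeprecationWarning) for w in caught)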
