This is an automated email from the ASF dual-hosted git repository.
mobuchowski pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 8a912f9fa0 [AIP-62] Translate AIP-60 URI to OpenLineage (#40173)
8a912f9fa0 is described below
commit 8a912f9fa00bf25763e70323f15eef5f94826495
Author: Kacper Muda <[email protected]>
AuthorDate: Tue Jul 23 11:04:14 2024 +0200
[AIP-62] Translate AIP-60 URI to OpenLineage (#40173)
* aip-62: implement translation mechanism from aip-60 to OpenLineage
Signed-off-by: Kacper Muda <[email protected]>
* aip-62: implement translation examples from aip-60 to OpenLineage
Signed-off-by: Kacper Muda <[email protected]>
---------
Signed-off-by: Kacper Muda <[email protected]>
---
airflow/datasets/__init__.py | 30 ++++++++-
airflow/provider.yaml.schema.json | 4 ++
airflow/providers/amazon/aws/datasets/s3.py | 22 +++++++
airflow/providers/amazon/provider.yaml | 3 +-
airflow/providers/common/io/datasets/file.py | 26 ++++++++
airflow/providers/common/io/provider.yaml | 3 +-
airflow/providers/openlineage/utils/utils.py | 30 +++++++++
airflow/providers_manager.py | 82 +++++++++++++++++--------
tests/datasets/test_dataset.py | 48 +++++++++++++--
tests/providers/amazon/aws/datasets/test_s3.py | 54 +++++++++++++++-
tests/providers/common/io/datasets/test_file.py | 45 +++++++++++++-
11 files changed, 312 insertions(+), 35 deletions(-)
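Before the diff, a minimal sketch of how the new pieces are intended to fit together (illustrative only: it assumes an Airflow install containing this commit plus the amazon and openlineage providers; the bucket and key are made up):

    from airflow.datasets import Dataset
    from airflow.providers.amazon.aws.hooks.s3 import S3Hook
    from airflow.providers.openlineage.utils.utils import translate_airflow_dataset

    # The provider-registered handler sanitizes the AIP-60 URI at Dataset creation time.
    dataset = Dataset(uri="s3://my-bucket/raw/events.parquet")

    # translate_airflow_dataset normalizes the URI, looks up the converter registered for
    # the "s3" scheme via ProvidersManager, and returns an OpenLineage dataset (or None
    # when no normalizer or converter is available).
    ol_dataset = translate_airflow_dataset(dataset, S3Hook())
    if ol_dataset is not None:
        # Expected, per the s3 converter below: namespace "s3://my-bucket", name "raw/events.parquet"
        print(ol_dataset.namespace, ol_dataset.name)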
diff --git a/airflow/datasets/__init__.py b/airflow/datasets/__init__.py
index 90cb52b6ad..dd27338189 100644
--- a/airflow/datasets/__init__.py
+++ b/airflow/datasets/__init__.py
@@ -56,6 +56,11 @@ def _get_uri_normalizer(scheme: str) -> Callable[[SplitResult], SplitResult] | N
return ProvidersManager().dataset_uri_handlers.get(scheme)
+def _get_normalized_scheme(uri: str) -> str:
+ parsed = urllib.parse.urlsplit(uri)
+ return parsed.scheme.lower()
+
+
def _sanitize_uri(uri: str) -> str:
"""
Sanitize a dataset URI.
@@ -72,7 +77,8 @@ def _sanitize_uri(uri: str) -> str:
parsed = urllib.parse.urlsplit(uri)
if not parsed.scheme and not parsed.netloc: # Does not look like a URI.
return uri
- normalized_scheme = parsed.scheme.lower()
+ if not (normalized_scheme := _get_normalized_scheme(uri)):
+ return uri
if normalized_scheme.startswith("x-"):
return uri
if normalized_scheme == "airflow":
@@ -231,6 +237,28 @@ class Dataset(os.PathLike, BaseDataset):
def __hash__(self) -> int:
return hash(self.uri)
+ @property
+ def normalized_uri(self) -> str | None:
+ """
+ Returns the normalized and AIP-60 compliant URI whenever possible.
+
+ If we can't retrieve the scheme from the URI, no normalizer is provided, or parsing fails,
+ it returns None.
+
+ If a normalizer for the scheme exists and parsing is successful, we return the normalizer result.
+ """
+ if not (normalized_scheme := _get_normalized_scheme(self.uri)):
+ return None
+
+ if (normalizer := _get_uri_normalizer(normalized_scheme)) is None:
+ return None
+ parsed = urllib.parse.urlsplit(self.uri)
+ try:
+ normalized_uri = normalizer(parsed)
+ return urllib.parse.urlunsplit(normalized_uri)
+ except ValueError:
+ return None
+
def as_expression(self) -> Any:
"""
Serialize the dataset into its scheduling expression.
diff --git a/airflow/provider.yaml.schema.json b/airflow/provider.yaml.schema.json
index adbca7846d..8f11833ee1 100644
--- a/airflow/provider.yaml.schema.json
+++ b/airflow/provider.yaml.schema.json
@@ -216,6 +216,10 @@
"factory": {
"type": ["string", "null"],
"description": "Dataset factory for specified URI.
Creates AIP-60 compliant Dataset."
+ },
+ "to_openlineage_converter": {
+ "type": ["string", "null"],
+ "description": "OpenLineage converter function for
specified URI schemes. Import path to a callable accepting a Dataset and
LineageContext and returning OpenLineage dataset."
}
}
}
diff --git a/airflow/providers/amazon/aws/datasets/s3.py b/airflow/providers/amazon/aws/datasets/s3.py
index 89889efe57..e6bed6dbe3 100644
--- a/airflow/providers/amazon/aws/datasets/s3.py
+++ b/airflow/providers/amazon/aws/datasets/s3.py
@@ -16,8 +16,30 @@
# under the License.
from __future__ import annotations
+from typing import TYPE_CHECKING
+
from airflow.datasets import Dataset
+from airflow.providers.amazon.aws.hooks.s3 import S3Hook
+
+if TYPE_CHECKING:
+ from urllib.parse import SplitResult
+
+ from openlineage.client.run import Dataset as OpenLineageDataset
def create_dataset(*, bucket: str, key: str, extra=None) -> Dataset:
return Dataset(uri=f"s3://{bucket}/{key}", extra=extra)
+
+
+def sanitize_uri(uri: SplitResult) -> SplitResult:
+ if not uri.netloc:
+ raise ValueError("URI format s3:// must contain a bucket name")
+ return uri
+
+
+def convert_dataset_to_openlineage(dataset: Dataset, lineage_context) -> OpenLineageDataset:
+ """Translate Dataset with valid AIP-60 uri to OpenLineage with assistance from the hook."""
+ from openlineage.client.run import Dataset as OpenLineageDataset
+
+ bucket, key = S3Hook.parse_s3_url(dataset.uri)
+ return OpenLineageDataset(namespace=f"s3://{bucket}", name=key if key else
"/")
diff --git a/airflow/providers/amazon/provider.yaml b/airflow/providers/amazon/provider.yaml
index 9dd76ac9fa..309abcc23a 100644
--- a/airflow/providers/amazon/provider.yaml
+++ b/airflow/providers/amazon/provider.yaml
@@ -561,7 +561,8 @@ sensors:
dataset-uris:
- schemes: [s3]
- handler: null
+ handler: airflow.providers.amazon.aws.datasets.s3.sanitize_uri
+ to_openlineage_converter: airflow.providers.amazon.aws.datasets.s3.convert_dataset_to_openlineage
factory: airflow.providers.amazon.aws.datasets.s3.create_dataset
filesystems:
diff --git a/airflow/providers/common/io/datasets/file.py b/airflow/providers/common/io/datasets/file.py
index 1bc4969762..aa7e8d98be 100644
--- a/airflow/providers/common/io/datasets/file.py
+++ b/airflow/providers/common/io/datasets/file.py
@@ -16,9 +16,35 @@
# under the License.
from __future__ import annotations
+import urllib.parse
+from typing import TYPE_CHECKING
+
from airflow.datasets import Dataset
+if TYPE_CHECKING:
+ from urllib.parse import SplitResult
+
+ from openlineage.client.run import Dataset as OpenLineageDataset
+
def create_dataset(*, path: str, extra=None) -> Dataset:
# We assume that we get absolute path starting with /
return Dataset(uri=f"file://{path}", extra=extra)
+
+
+def sanitize_uri(uri: SplitResult) -> SplitResult:
+ if not uri.path:
+ raise ValueError("URI format file:// must contain a non-empty path.")
+ return uri
+
+
+def convert_dataset_to_openlineage(dataset: Dataset, lineage_context) -> OpenLineageDataset:
+ """
+ Translate Dataset with valid AIP-60 uri to OpenLineage with assistance from the context.
+
+ Windows paths are not standardized and can produce unexpected behaviour.
+ """
+ from openlineage.client.run import Dataset as OpenLineageDataset
+
+ parsed = urllib.parse.urlsplit(dataset.uri)
+ return OpenLineageDataset(namespace=f"file://{parsed.netloc}",
name=parsed.path)
diff --git a/airflow/providers/common/io/provider.yaml b/airflow/providers/common/io/provider.yaml
index a45d3d7dfe..e644b3f070 100644
--- a/airflow/providers/common/io/provider.yaml
+++ b/airflow/providers/common/io/provider.yaml
@@ -53,7 +53,8 @@ xcom:
dataset-uris:
- schemes: [file]
- handler: null
+ handler: airflow.providers.common.io.datasets.file.sanitize_uri
+ to_openlineage_converter: airflow.providers.common.io.datasets.file.convert_dataset_to_openlineage
factory: airflow.providers.common.io.datasets.file.create_dataset
config:
diff --git a/airflow/providers/openlineage/utils/utils.py b/airflow/providers/openlineage/utils/utils.py
index 0689ea3977..a36f44b3d5 100644
--- a/airflow/providers/openlineage/utils/utils.py
+++ b/airflow/providers/openlineage/utils/utils.py
@@ -55,6 +55,8 @@ from airflow.utils.log.secrets_masker import Redactable, Redacted, SecretsMasker
from airflow.utils.module_loading import import_string
if TYPE_CHECKING:
+ from openlineage.client.run import Dataset as OpenLineageDataset
+
from airflow.models import DagRun, TaskInstance
@@ -635,3 +637,31 @@ def should_use_external_connection(hook) -> bool:
if not _IS_AIRFLOW_2_10_OR_HIGHER:
return hook.__class__.__name__ not in ["SnowflakeHook", "SnowflakeSqlApiHook", "RedshiftSQLHook"]
return True
+
+
+def translate_airflow_dataset(dataset: Dataset, lineage_context) -> OpenLineageDataset | None:
+ """
+ Convert a Dataset with an AIP-60 compliant URI to an OpenLineageDataset.
+
+ This function returns None if no URI normalizer is defined, no dataset converter is found, or
+ some core Airflow changes are missing and ImportError is raised.
+ """
+ try:
+ from airflow.datasets import _get_normalized_scheme
+ from airflow.providers_manager import ProvidersManager
+
+ ol_converters = ProvidersManager().dataset_to_openlineage_converters
+ normalized_uri = dataset.normalized_uri
+ except (ImportError, AttributeError):
+ return None
+
+ if normalized_uri is None:
+ return None
+
+ if not (normalized_scheme := _get_normalized_scheme(normalized_uri)):
+ return None
+
+ if (airflow_to_ol_converter := ol_converters.get(normalized_scheme)) is None:
+ return None
+
+ return airflow_to_ol_converter(Dataset(uri=normalized_uri, extra=dataset.extra), lineage_context)
diff --git a/airflow/providers_manager.py b/airflow/providers_manager.py
index f6d29a51d1..dd3e841fa1 100644
--- a/airflow/providers_manager.py
+++ b/airflow/providers_manager.py
@@ -428,6 +428,7 @@ class ProvidersManager(LoggingMixin, metaclass=Singleton):
self._fs_set: set[str] = set()
self._dataset_uri_handlers: dict[str, Callable[[SplitResult], SplitResult]] = {}
self._dataset_factories: dict[str, Callable[..., Dataset]] = {}
+ self._dataset_to_openlineage_converters: dict[str, Callable] = {}
self._taskflow_decorators: dict[str, Callable] = LazyDictWithCache()  # type: ignore[assignment]
# keeps mapping between connection_types and hook class, package they come from
self._hook_provider_dict: dict[str, HookClassProvider] = {}
@@ -525,10 +526,10 @@ class ProvidersManager(LoggingMixin, metaclass=Singleton):
self._discover_filesystems()
@provider_info_cache("dataset_uris")
- def initialize_providers_dataset_uri_handlers_and_factories(self):
- """Lazy initialization of provider dataset URI handlers."""
+ def initialize_providers_dataset_uri_resources(self):
+ """Lazy initialization of provider dataset URI handlers, factories,
converters etc."""
self.initialize_providers_list()
- self._discover_dataset_uri_handlers_and_factories()
+ self._discover_dataset_uri_resources()
@provider_info_cache("hook_lineage_writers")
@provider_info_cache("taskflow_decorators")
@@ -881,28 +882,52 @@ class ProvidersManager(LoggingMixin, metaclass=Singleton):
self._fs_set.add(fs_module_name)
self._fs_set = set(sorted(self._fs_set))
- def _discover_dataset_uri_handlers_and_factories(self) -> None:
+ def _discover_dataset_uri_resources(self) -> None:
+ """Discovers and registers dataset URI handlers, factories, and
converters for all providers."""
from airflow.datasets import normalize_noop
- for provider_package, provider in self._provider_dict.items():
- for handler_info in provider.data.get("dataset-uris", []):
- schemes = handler_info.get("schemes")
- handler_path = handler_info.get("handler")
- factory_path = handler_info.get("factory")
- if schemes is None:
- continue
-
- if handler_path is not None and (
- handler := _correctness_check(provider_package, handler_path, provider)
- ):
- pass
- else:
- handler = normalize_noop
- self._dataset_uri_handlers.update((scheme, handler) for scheme in schemes)
- if factory_path is not None and (
- factory := _correctness_check(provider_package, factory_path, provider)
- ):
- self._dataset_factories.update((scheme, factory) for scheme in schemes)
+ def _safe_register_resource(
+ provider_package_name: str,
+ schemes_list: list[str],
+ resource_path: str | None,
+ resource_registry: dict,
+ default_resource: Any = None,
+ ):
+ """
+ Register a specific resource (handler, factory, or converter) for the given schemes.
+
+ If the resolved resource (either from the path or the default) is valid, it updates
+ the resource registry with the appropriate resource for each scheme.
+ """
+ resource = (
+ _correctness_check(provider_package_name, resource_path, provider)
+ if resource_path is not None
+ else default_resource
+ )
+ if resource:
+ resource_registry.update((scheme, resource) for scheme in schemes_list)
+
+ for provider_name, provider in self._provider_dict.items():
+ for uri_info in provider.data.get("dataset-uris", []):
+ if "schemes" not in uri_info or "handler" not in uri_info:
+ continue # Both schemes and handler must be explicitly set, handler can be set to null
+ common_args = {"schemes_list": uri_info["schemes"],
"provider_package_name": provider_name}
+ _safe_register_resource(
+ resource_path=uri_info["handler"],
+ resource_registry=self._dataset_uri_handlers,
+ default_resource=normalize_noop,
+ **common_args,
+ )
+ _safe_register_resource(
+ resource_path=uri_info.get("factory"),
+ resource_registry=self._dataset_factories,
+ **common_args,
+ )
+ _safe_register_resource(
+ resource_path=uri_info.get("to_openlineage_converter"),
+ resource_registry=self._dataset_to_openlineage_converters,
+ **common_args,
+ )
def _discover_taskflow_decorators(self) -> None:
for name, info in self._provider_dict.items():
@@ -1301,14 +1326,21 @@ class ProvidersManager(LoggingMixin, metaclass=Singleton):
@property
def dataset_factories(self) -> dict[str, Callable[..., Dataset]]:
- self.initialize_providers_dataset_uri_handlers_and_factories()
+ self.initialize_providers_dataset_uri_resources()
return self._dataset_factories
@property
def dataset_uri_handlers(self) -> dict[str, Callable[[SplitResult], SplitResult]]:
- self.initialize_providers_dataset_uri_handlers_and_factories()
+ self.initialize_providers_dataset_uri_resources()
return self._dataset_uri_handlers
+ @property
+ def dataset_to_openlineage_converters(
+ self,
+ ) -> dict[str, Callable]:
+ self.initialize_providers_dataset_uri_resources()
+ return self._dataset_to_openlineage_converters
+
@property
def provider_configs(self) -> list[tuple[str, dict[str, Any]]]:
self.initialize_providers_configuration()
diff --git a/tests/datasets/test_dataset.py b/tests/datasets/test_dataset.py
index 19dbcf4b53..1b2d3d6d4c 100644
--- a/tests/datasets/test_dataset.py
+++ b/tests/datasets/test_dataset.py
@@ -31,6 +31,7 @@ from airflow.datasets import (
DatasetAll,
DatasetAny,
_DatasetAliasCondition,
+ _get_normalized_scheme,
_sanitize_uri,
)
from airflow.models.dataset import DatasetAliasModel, DatasetDagRunQueue, DatasetModel
@@ -454,31 +455,68 @@ def test_datasets_expression_error(expression: Callable[[], None], error: str) -
assert str(info.value) == error
-def mock_get_uri_normalizer(normalized_scheme):
+def test_get_normalized_scheme():
+ assert _get_normalized_scheme("http://example.com") == "http"
+ assert _get_normalized_scheme("HTTPS://example.com") == "https"
+ assert _get_normalized_scheme("ftp://example.com") == "ftp"
+ assert _get_normalized_scheme("file://") == "file"
+
+ assert _get_normalized_scheme("example.com") == ""
+ assert _get_normalized_scheme("") == ""
+ assert _get_normalized_scheme(" ") == ""
+
+
+def _mock_get_uri_normalizer_raising_error(normalized_scheme):
def normalizer(uri):
raise ValueError("Incorrect URI format")
return normalizer
-@patch("airflow.datasets._get_uri_normalizer", mock_get_uri_normalizer)
+def _mock_get_uri_normalizer_noop(normalized_scheme):
+ def normalizer(uri):
+ return uri
+
+ return normalizer
+
+
+@patch("airflow.datasets._get_uri_normalizer",
_mock_get_uri_normalizer_raising_error)
@patch("airflow.datasets.warnings.warn")
-def test__sanitize_uri_raises_warning(mock_warn):
+def test_sanitize_uri_raises_warning(mock_warn):
_sanitize_uri("postgres://localhost:5432/database.schema.table")
msg = mock_warn.call_args.args[0]
assert "The dataset URI postgres://localhost:5432/database.schema.table is
not AIP-60 compliant" in msg
assert "In Airflow 3, this will raise an exception." in msg
-@patch("airflow.datasets._get_uri_normalizer", mock_get_uri_normalizer)
+@patch("airflow.datasets._get_uri_normalizer",
_mock_get_uri_normalizer_raising_error)
@conf_vars({("core", "strict_dataset_uri_validation"): "True"})
-def test__sanitize_uri_raises_exception():
+def test_sanitize_uri_raises_exception():
with pytest.raises(ValueError) as e_info:
_sanitize_uri("postgres://localhost:5432/database.schema.table")
assert isinstance(e_info.value, ValueError)
assert str(e_info.value) == "Incorrect URI format"
+@patch("airflow.datasets._get_uri_normalizer", lambda x: None)
+def test_normalize_uri_no_normalizer_found():
+ dataset = Dataset(uri="any_uri_without_normalizer_defined")
+ assert dataset.normalized_uri is None
+
+
+@patch("airflow.datasets._get_uri_normalizer",
_mock_get_uri_normalizer_raising_error)
+def test_normalize_uri_invalid_uri():
+ dataset = Dataset(uri="any_uri_not_aip60_compliant")
+ assert dataset.normalized_uri is None
+
+
+@patch("airflow.datasets._get_uri_normalizer", _mock_get_uri_normalizer_noop)
+@patch("airflow.datasets._get_normalized_scheme", lambda x: "valid_scheme")
+def test_normalize_uri_valid_uri():
+ dataset = Dataset(uri="valid_aip60_uri")
+ assert dataset.normalized_uri == "valid_aip60_uri"
+
+
@pytest.mark.db_test
@pytest.mark.usefixtures("clear_datasets")
class Test_DatasetAliasCondition:
diff --git a/tests/providers/amazon/aws/datasets/test_s3.py b/tests/providers/amazon/aws/datasets/test_s3.py
index c7ffe25240..893d6acf67 100644
--- a/tests/providers/amazon/aws/datasets/test_s3.py
+++ b/tests/providers/amazon/aws/datasets/test_s3.py
@@ -16,8 +16,38 @@
# under the License.
from __future__ import annotations
+import urllib.parse
+
+import pytest
+
from airflow.datasets import Dataset
-from airflow.providers.amazon.aws.datasets.s3 import create_dataset
+from airflow.providers.amazon.aws.datasets.s3 import (
+ convert_dataset_to_openlineage,
+ create_dataset,
+ sanitize_uri,
+)
+from airflow.providers.amazon.aws.hooks.s3 import S3Hook
+
+
+def test_sanitize_uri():
+ uri = sanitize_uri(urllib.parse.urlsplit("s3://bucket/dir/file.txt"))
+ result = sanitize_uri(uri)
+ assert result.scheme == "s3"
+ assert result.netloc == "bucket"
+ assert result.path == "/dir/file.txt"
+
+
+def test_sanitize_uri_no_netloc():
+ with pytest.raises(ValueError):
+ sanitize_uri(urllib.parse.urlsplit("s3://"))
+
+
+def test_sanitize_uri_no_path():
+ uri = sanitize_uri(urllib.parse.urlsplit("s3://bucket"))
+ result = sanitize_uri(uri)
+ assert result.scheme == "s3"
+ assert result.netloc == "bucket"
+ assert result.path == ""
def test_create_dataset():
@@ -25,3 +55,25 @@ def test_create_dataset():
assert create_dataset(bucket="test-bucket", key="test-dir/test-path") ==
Dataset(
uri="s3://test-bucket/test-dir/test-path"
)
+
+
+def test_sanitize_uri_trailing_slash():
+ uri = sanitize_uri(urllib.parse.urlsplit("s3://bucket/"))
+ result = sanitize_uri(uri)
+ assert result.scheme == "s3"
+ assert result.netloc == "bucket"
+ assert result.path == "/"
+
+
+def test_convert_dataset_to_openlineage_valid():
+ uri = "s3://bucket/dir/file.txt"
+ ol_dataset = convert_dataset_to_openlineage(dataset=Dataset(uri=uri), lineage_context=S3Hook())
+ assert ol_dataset.namespace == "s3://bucket"
+ assert ol_dataset.name == "dir/file.txt"
+
+
[email protected]("uri", ("s3://bucket", "s3://bucket/"))
+def test_convert_dataset_to_openlineage_no_path(uri):
+ ol_dataset = convert_dataset_to_openlineage(dataset=Dataset(uri=uri), lineage_context=S3Hook())
+ assert ol_dataset.namespace == "s3://bucket"
+ assert ol_dataset.name == "/"
diff --git a/tests/providers/common/io/datasets/test_file.py b/tests/providers/common/io/datasets/test_file.py
index 43d63cb205..b2e4fddf98 100644
--- a/tests/providers/common/io/datasets/test_file.py
+++ b/tests/providers/common/io/datasets/test_file.py
@@ -16,9 +16,52 @@
# under the License.
from __future__ import annotations
+from urllib.parse import urlsplit, urlunsplit
+
+import pytest
+from openlineage.client.run import Dataset as OpenLineageDataset
+
from airflow.datasets import Dataset
-from airflow.providers.common.io.datasets.file import create_dataset
+from airflow.providers.common.io.datasets.file import (
+ convert_dataset_to_openlineage,
+ create_dataset,
+ sanitize_uri,
+)
+
+
[email protected](
+ ("uri", "expected"),
+ (
+ ("file:///valid/path/", "file:///valid/path/"),
+ ("file://C://dir/file", "file://C://dir/file"),
+ ),
+)
+def test_sanitize_uri_valid(uri, expected):
+ result = sanitize_uri(urlsplit(uri))
+ assert urlunsplit(result) == expected
+
+
[email protected]("uri", ("file://",))
+def test_sanitize_uri_invalid(uri):
+ with pytest.raises(ValueError):
+ sanitize_uri(urlsplit(uri))
def test_file_dataset():
assert create_dataset(path="/asdf/fdsa") ==
Dataset(uri="file:///asdf/fdsa")
+
+
[email protected](
+ ("uri", "ol_dataset"),
+ (
+ ("file:///valid/path", OpenLineageDataset(namespace="file://",
name="/valid/path")),
+ (
+ "file://127.0.0.1:8080/dir/file.csv",
+ OpenLineageDataset(namespace="file://127.0.0.1:8080",
name="/dir/file.csv"),
+ ),
+ ("file:///C://dir/file", OpenLineageDataset(namespace="file://",
name="/C://dir/file")),
+ ),
+)
+def test_convert_dataset_to_openlineage(uri, ol_dataset):
+ result = convert_dataset_to_openlineage(Dataset(uri=uri), None)
+ assert result == ol_dataset
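A short direct-usage sketch mirroring the tests above (assuming the common.io provider from this commit is installed; the path is illustrative):

    from airflow.datasets import Dataset
    from airflow.providers.common.io.datasets.file import convert_dataset_to_openlineage

    # The file converter does not use the lineage context, so None is acceptable here.
    ol = convert_dataset_to_openlineage(Dataset(uri="file:///valid/path"), None)
    # Per test_convert_dataset_to_openlineage: namespace "file://", name "/valid/path"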