This is an automated email from the ASF dual-hosted git repository.

mobuchowski pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 0212f67192 openlineage: add support for hook lineage for S3Hook (#40819)
0212f67192 is described below

commit 0212f671921fc5da15085eefbde8b0a76db40fd9
Author: Maciej Obuchowski <[email protected]>
AuthorDate: Mon Jul 22 11:00:48 2024 +0200

    openlineage: add support for hook lineage for S3Hook (#40819)
    
    Signed-off-by: Maciej Obuchowski <[email protected]>
---
 airflow/lineage/hook.py                            |  4 +-
 .../file.py => amazon/aws/datasets/__init__.py}    |  8 --
 .../datasets/file.py => amazon/aws/datasets/s3.py} |  5 +-
 airflow/providers/amazon/aws/hooks/s3.py           | 27 +++++-
 airflow/providers/amazon/provider.yaml             |  2 +
 airflow/providers/common/compat/lineage/hook.py    |  4 +-
 airflow/providers/common/compat/provider.yaml      |  1 +
 airflow/providers/common/io/datasets/file.py       |  4 +-
 airflow/providers_manager.py                       | 26 +++---
 dev/breeze/tests/test_selective_checks.py          |  8 +-
 generated/provider_dependencies.json               |  2 +
 prod_image_installed_providers.txt                 |  1 +
 tests/conftest.py                                  | 10 +++
 .../providers/amazon/aws/datasets/__init__.py      |  8 --
 .../providers/amazon/aws/datasets/test_s3.py       |  9 +-
 tests/providers/amazon/aws/hooks/test_s3.py        | 99 +++++++++++++++++++++-
 16 files changed, 171 insertions(+), 47 deletions(-)

diff --git a/airflow/lineage/hook.py b/airflow/lineage/hook.py
index 70893516bd..ee12e1624e 100644
--- a/airflow/lineage/hook.py
+++ b/airflow/lineage/hook.py
@@ -139,10 +139,10 @@ class NoOpCollector(HookLineageCollector):
     It is used when you want to disable lineage collection.
     """
 
-    def add_input_dataset(self, *_):
+    def add_input_dataset(self, *_, **__):
         pass
 
-    def add_output_dataset(self, *_):
+    def add_output_dataset(self, *_, **__):
         pass
 
     @property
diff --git a/airflow/providers/common/io/datasets/file.py b/airflow/providers/amazon/aws/datasets/__init__.py
similarity index 78%
copy from airflow/providers/common/io/datasets/file.py
copy to airflow/providers/amazon/aws/datasets/__init__.py
index 46c7499037..13a83393a9 100644
--- a/airflow/providers/common/io/datasets/file.py
+++ b/airflow/providers/amazon/aws/datasets/__init__.py
@@ -14,11 +14,3 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-from __future__ import annotations
-
-from airflow.datasets import Dataset
-
-
-def create_dataset(*, path: str) -> Dataset:
-    # We assume that we get absolute path starting with /
-    return Dataset(uri=f"file://{path}")
diff --git a/airflow/providers/common/io/datasets/file.py b/airflow/providers/amazon/aws/datasets/s3.py
similarity index 85%
copy from airflow/providers/common/io/datasets/file.py
copy to airflow/providers/amazon/aws/datasets/s3.py
index 46c7499037..89889efe57 100644
--- a/airflow/providers/common/io/datasets/file.py
+++ b/airflow/providers/amazon/aws/datasets/s3.py
@@ -19,6 +19,5 @@ from __future__ import annotations
 from airflow.datasets import Dataset
 
 
-def create_dataset(*, path: str) -> Dataset:
-    # We assume that we get absolute path starting with /
-    return Dataset(uri=f"file://{path}")
+def create_dataset(*, bucket: str, key: str, extra=None) -> Dataset:
+    return Dataset(uri=f"s3://{bucket}/{key}", extra=extra)
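A minimal usage sketch of the new factory (the bucket and key values below are illustrative, not taken from this change; comparing Datasets by uri mirrors the test added later in this commit):

    from airflow.datasets import Dataset
    from airflow.providers.amazon.aws.datasets.s3 import create_dataset

    # bucket and key are keyword-only; "extra" is optional metadata attached to the Dataset.
    ds = create_dataset(bucket="example-bucket", key="raw/2024/data.csv")
    assert ds == Dataset(uri="s3://example-bucket/raw/2024/data.csv")
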
diff --git a/airflow/providers/amazon/aws/hooks/s3.py b/airflow/providers/amazon/aws/hooks/s3.py
index 8ca93766e2..5f2c136640 100644
--- a/airflow/providers/amazon/aws/hooks/s3.py
+++ b/airflow/providers/amazon/aws/hooks/s3.py
@@ -41,6 +41,8 @@ from typing import TYPE_CHECKING, Any, AsyncIterator, Callable
 from urllib.parse import urlsplit
 from uuid import uuid4
 
+from airflow.providers.common.compat.lineage.hook import get_hook_lineage_collector
+
 if TYPE_CHECKING:
     from mypy_boto3_s3.service_resource import Bucket as S3Bucket, Object as S3ResourceObject
 
@@ -1111,6 +1113,12 @@ class S3Hook(AwsBaseHook):
 
         client = self.get_conn()
         client.upload_file(filename, bucket_name, key, ExtraArgs=extra_args, Config=self.transfer_config)
+        get_hook_lineage_collector().add_input_dataset(
+            context=self, scheme="file", dataset_kwargs={"path": filename}
+        )
+        get_hook_lineage_collector().add_output_dataset(
+            context=self, scheme="s3", dataset_kwargs={"bucket": bucket_name, "key": key}
+        )
 
     @unify_bucket_name_and_key
     @provide_bucket_name
@@ -1251,6 +1259,10 @@ class S3Hook(AwsBaseHook):
             ExtraArgs=extra_args,
             Config=self.transfer_config,
         )
+        # No input because file_obj can be anything - handle in calling function if possible
+        get_hook_lineage_collector().add_output_dataset(
+            context=self, scheme="s3", dataset_kwargs={"bucket": bucket_name, "key": key}
+        )
 
     def copy_object(
         self,
@@ -1306,6 +1318,12 @@ class S3Hook(AwsBaseHook):
         response = self.get_conn().copy_object(
             Bucket=dest_bucket_name, Key=dest_bucket_key, CopySource=copy_source, **kwargs
         )
+        get_hook_lineage_collector().add_input_dataset(
+            context=self, scheme="s3", dataset_kwargs={"bucket": source_bucket_name, "key": source_bucket_key}
+        )
+        get_hook_lineage_collector().add_output_dataset(
+            context=self, scheme="s3", dataset_kwargs={"bucket": dest_bucket_name, "key": dest_bucket_key}
+        )
         return response
 
     @provide_bucket_name
@@ -1425,6 +1443,11 @@ class S3Hook(AwsBaseHook):
 
             file_path.parent.mkdir(exist_ok=True, parents=True)
 
+            get_hook_lineage_collector().add_output_dataset(
+                context=self,
+                scheme="file",
+                dataset_kwargs={"path": file_path if file_path.is_absolute() else file_path.absolute()},
+            )
             file = open(file_path, "wb")
         else:
             file = NamedTemporaryFile(dir=local_path, prefix="airflow_tmp_", delete=False)  # type: ignore
@@ -1435,7 +1458,9 @@ class S3Hook(AwsBaseHook):
                 ExtraArgs=self.extra_args,
                 Config=self.transfer_config,
             )
-
+        get_hook_lineage_collector().add_input_dataset(
+            context=self, scheme="s3", dataset_kwargs={"bucket": bucket_name, "key": key}
+        )
         return file.name
 
     def generate_presigned_url(
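A rough sketch of how the new calls surface lineage, assuming Airflow >= 2.10 with hook lineage collection enabled (otherwise the NoOpCollector shown above swallows them); the path and bucket name are illustrative, and collected_datasets is the collector API exercised by the tests below:

    from airflow.lineage.hook import get_hook_lineage_collector
    from airflow.providers.amazon.aws.hooks.s3 import S3Hook

    # load_file() now reports the local file as an input dataset and the
    # uploaded S3 object as an output dataset on the shared collector.
    S3Hook().load_file("/tmp/data.csv", key="raw/data.csv", bucket_name="example-bucket")

    collector = get_hook_lineage_collector()
    # Each collected entry's first element is the Dataset built by the registered
    # factory, e.g. file:///tmp/data.csv as input and s3://example-bucket/raw/data.csv as output.
    print(collector.collected_datasets.inputs)
    print(collector.collected_datasets.outputs)
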
diff --git a/airflow/providers/amazon/provider.yaml b/airflow/providers/amazon/provider.yaml
index a7b4d4272f..9dd76ac9fa 100644
--- a/airflow/providers/amazon/provider.yaml
+++ b/airflow/providers/amazon/provider.yaml
@@ -91,6 +91,7 @@ dependencies:
   - apache-airflow>=2.7.0
   - apache-airflow-providers-common-sql>=1.3.1
   - apache-airflow-providers-http
+  - apache-airflow-providers-common-compat>=1.1.0
   # We should update minimum version of boto3 and here regularly to avoid `pip` backtracking with the number
   # of candidates to consider. Make sure to configure boto3 version here as well as in all the tools below
   # in the `devel-dependencies` section to be the same minimum version.
@@ -561,6 +562,7 @@ sensors:
 dataset-uris:
   - schemes: [s3]
     handler: null
+    factory: airflow.providers.amazon.aws.datasets.s3.create_dataset
 
 filesystems:
   - airflow.providers.amazon.aws.fs.s3
diff --git a/airflow/providers/common/compat/lineage/hook.py b/airflow/providers/common/compat/lineage/hook.py
index 2115c992e7..dbdbc5bf86 100644
--- a/airflow/providers/common/compat/lineage/hook.py
+++ b/airflow/providers/common/compat/lineage/hook.py
@@ -32,10 +32,10 @@ def get_hook_lineage_collector():
             It is used when you want to disable lineage collection.
             """
 
-            def add_input_dataset(self, *_):
+            def add_input_dataset(self, *_, **__):
                 pass
 
-            def add_output_dataset(self, *_):
+            def add_output_dataset(self, *_, **__):
                 pass
 
         return NoOpCollector()
diff --git a/airflow/providers/common/compat/provider.yaml b/airflow/providers/common/compat/provider.yaml
index 27e610e25f..53527f9204 100644
--- a/airflow/providers/common/compat/provider.yaml
+++ b/airflow/providers/common/compat/provider.yaml
@@ -25,6 +25,7 @@ state: ready
 source-date-epoch: 1716287191
 # note that those versions are maintained by release manager - do not update them manually
 versions:
+  - 1.1.0
   - 1.0.0
 
 dependencies:
diff --git a/airflow/providers/common/io/datasets/file.py b/airflow/providers/common/io/datasets/file.py
index 46c7499037..1bc4969762 100644
--- a/airflow/providers/common/io/datasets/file.py
+++ b/airflow/providers/common/io/datasets/file.py
@@ -19,6 +19,6 @@ from __future__ import annotations
 from airflow.datasets import Dataset
 
 
-def create_dataset(*, path: str) -> Dataset:
+def create_dataset(*, path: str, extra=None) -> Dataset:
     # We assume that we get absolute path starting with /
-    return Dataset(uri=f"file://{path}")
+    return Dataset(uri=f"file://{path}", extra=extra)
diff --git a/airflow/providers_manager.py b/airflow/providers_manager.py
index 9e9dd4d573..f6d29a51d1 100644
--- a/airflow/providers_manager.py
+++ b/airflow/providers_manager.py
@@ -886,23 +886,23 @@ class ProvidersManager(LoggingMixin, metaclass=Singleton):
 
         for provider_package, provider in self._provider_dict.items():
             for handler_info in provider.data.get("dataset-uris", []):
-                try:
-                    schemes = handler_info["schemes"]
-                    handler_path = handler_info["handler"]
-                except KeyError:
+                schemes = handler_info.get("schemes")
+                handler_path = handler_info.get("handler")
+                factory_path = handler_info.get("factory")
+                if schemes is None:
                     continue
-                if handler_path is None:
+
+                if handler_path is not None and (
+                    handler := _correctness_check(provider_package, handler_path, provider)
+                ):
+                    pass
+                else:
                     handler = normalize_noop
-                elif not (handler := _correctness_check(provider_package, handler_path, provider)):
-                    continue
                 self._dataset_uri_handlers.update((scheme, handler) for scheme in schemes)
-                factory_path = handler_info.get("factory")
-                if not (
-                    factory_path is not None
-                    and (factory := _correctness_check(provider_package, factory_path, provider))
+                if factory_path is not None and (
+                    factory := _correctness_check(provider_package, factory_path, provider)
                 ):
-                    continue
-                self._dataset_factories.update((scheme, factory) for scheme in schemes)
+                    self._dataset_factories.update((scheme, factory) for scheme in schemes)
 
     def _discover_taskflow_decorators(self) -> None:
         for name, info in self._provider_dict.items():
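A self-contained sketch of the reworked registration flow, using trivial stand-ins for normalize_noop and _correctness_check (the real helpers live in airflow/providers_manager.py) and hypothetical handler_info entries; the point is that a missing or broken handler now falls back to the no-op normalizer instead of skipping the scheme, while a factory is registered only when it resolves:

    from typing import Any, Callable, Optional

    def normalize_noop(uri):  # stand-in: the real helper returns the URI unchanged
        return uri

    def _correctness_check(pkg: str, path: str, provider: Any) -> Optional[Callable]:
        # stand-in: the real helper imports `path` and returns the object, or None on failure
        return (lambda *a, **kw: None) if "good" in path else None

    dataset_uri_handlers: dict[str, Callable] = {}
    dataset_factories: dict[str, Callable] = {}

    for handler_info in [
        {"schemes": ["s3"], "handler": None, "factory": "good.factory"},  # no handler declared
        {"schemes": ["x"], "handler": "bad.handler", "factory": "bad.factory"},  # both fail the check
    ]:
        schemes = handler_info.get("schemes")
        handler_path = handler_info.get("handler")
        factory_path = handler_info.get("factory")
        if schemes is None:
            continue
        # A declared handler is used when it imports cleanly; otherwise fall back to the no-op.
        if handler_path is not None and (handler := _correctness_check("pkg", handler_path, None)):
            pass
        else:
            handler = normalize_noop
        dataset_uri_handlers.update((scheme, handler) for scheme in schemes)
        # The factory is registered independently, and only if it resolves.
        if factory_path is not None and (factory := _correctness_check("pkg", factory_path, None)):
            dataset_factories.update((scheme, factory) for scheme in schemes)

    print(sorted(dataset_uri_handlers))  # ['s3', 'x'] - every scheme gets a handler
    print(sorted(dataset_factories))     # ['s3']      - only the resolvable factory is kept
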
diff --git a/dev/breeze/tests/test_selective_checks.py b/dev/breeze/tests/test_selective_checks.py
index 4c215ca3b0..c0c40b9be9 100644
--- a/dev/breeze/tests/test_selective_checks.py
+++ b/dev/breeze/tests/test_selective_checks.py
@@ -569,7 +569,7 @@ def assert_outputs_are_printed(expected_outputs: dict[str, str], stderr: str):
             ("airflow/providers/amazon/__init__.py",),
             {
                 "affected-providers-list-as-string": "amazon apache.hive cncf.kubernetes "
-                "common.sql exasol ftp google http imap microsoft.azure "
+                "common.compat common.sql exasol ftp google http imap microsoft.azure "
                 "mongo mysql openlineage postgres salesforce ssh teradata",
                 "all-python-versions": "['3.8']",
                 "all-python-versions-list-as-string": "3.8",
@@ -585,7 +585,7 @@ def assert_outputs_are_printed(expected_outputs: dict[str, str], stderr: str):
                 "upgrade-to-newer-dependencies": "false",
                 "run-amazon-tests": "true",
                 "parallel-test-types-list-as-string": "Always Providers[amazon] "
-                "Providers[apache.hive,cncf.kubernetes,common.sql,exasol,ftp,http,"
+                "Providers[apache.hive,cncf.kubernetes,common.compat,common.sql,exasol,ftp,http,"
                 "imap,microsoft.azure,mongo,mysql,openlineage,postgres,salesforce,ssh,teradata] Providers[google]",
                 "needs-mypy": "true",
                 "mypy-folders": "['providers']",
@@ -619,7 +619,7 @@ def assert_outputs_are_printed(expected_outputs: dict[str, str], stderr: str):
             ("airflow/providers/amazon/file.py",),
             {
                 "affected-providers-list-as-string": "amazon apache.hive cncf.kubernetes "
-                "common.sql exasol ftp google http imap microsoft.azure "
+                "common.compat common.sql exasol ftp google http imap microsoft.azure "
                 "mongo mysql openlineage postgres salesforce ssh teradata",
                 "all-python-versions": "['3.8']",
                 "all-python-versions-list-as-string": "3.8",
@@ -635,7 +635,7 @@ def assert_outputs_are_printed(expected_outputs: dict[str, str], stderr: str):
                 "run-kubernetes-tests": "false",
                 "upgrade-to-newer-dependencies": "false",
                 "parallel-test-types-list-as-string": "Always Providers[amazon] "
-                "Providers[apache.hive,cncf.kubernetes,common.sql,exasol,ftp,http,"
+                "Providers[apache.hive,cncf.kubernetes,common.compat,common.sql,exasol,ftp,http,"
                 "imap,microsoft.azure,mongo,mysql,openlineage,postgres,salesforce,ssh,teradata] Providers[google]",
                 "needs-mypy": "true",
                 "mypy-folders": "['providers']",
diff --git a/generated/provider_dependencies.json b/generated/provider_dependencies.json
index 4eabdd22be..86b9b2e15b 100644
--- a/generated/provider_dependencies.json
+++ b/generated/provider_dependencies.json
@@ -28,6 +28,7 @@
   "amazon": {
     "deps": [
       "PyAthena>=3.0.10",
+      "apache-airflow-providers-common-compat>=1.1.0",
       "apache-airflow-providers-common-sql>=1.3.1",
       "apache-airflow-providers-http",
       "apache-airflow>=2.7.0",
@@ -57,6 +58,7 @@
     "cross-providers-deps": [
       "apache.hive",
       "cncf.kubernetes",
+      "common.compat",
       "common.sql",
       "exasol",
       "ftp",
diff --git a/prod_image_installed_providers.txt b/prod_image_installed_providers.txt
index c292b7b83d..7340928738 100644
--- a/prod_image_installed_providers.txt
+++ b/prod_image_installed_providers.txt
@@ -2,6 +2,7 @@
 amazon
 celery
 cncf.kubernetes
+common.compat
 common.io
 common.sql
 docker
diff --git a/tests/conftest.py b/tests/conftest.py
index 9027391575..6cb74446dc 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1326,6 +1326,16 @@ def airflow_root_path() -> Path:
     return Path(airflow.__path__[0]).parent
 
 
[email protected]
+def hook_lineage_collector():
+    from airflow.lineage import hook
+
+    hook._hook_lineage_collector = None
+    hook._hook_lineage_collector = hook.HookLineageCollector()
+    yield hook.get_hook_lineage_collector()
+    hook._hook_lineage_collector = None
+
+
 # This constant is set to True if tests are run with Airflow installed from Packages rather than running
 # the tests within Airflow sources. While most tests in CI are run using Airflow sources, there are
 # also compatibility tests that only use `tests` package and run against installed packages of Airflow in
diff --git a/airflow/providers/common/io/datasets/file.py b/tests/providers/amazon/aws/datasets/__init__.py
similarity index 78%
copy from airflow/providers/common/io/datasets/file.py
copy to tests/providers/amazon/aws/datasets/__init__.py
index 46c7499037..13a83393a9 100644
--- a/airflow/providers/common/io/datasets/file.py
+++ b/tests/providers/amazon/aws/datasets/__init__.py
@@ -14,11 +14,3 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-from __future__ import annotations
-
-from airflow.datasets import Dataset
-
-
-def create_dataset(*, path: str) -> Dataset:
-    # We assume that we get absolute path starting with /
-    return Dataset(uri=f"file://{path}")
diff --git a/airflow/providers/common/io/datasets/file.py b/tests/providers/amazon/aws/datasets/test_s3.py
similarity index 71%
copy from airflow/providers/common/io/datasets/file.py
copy to tests/providers/amazon/aws/datasets/test_s3.py
index 46c7499037..c7ffe25240 100644
--- a/airflow/providers/common/io/datasets/file.py
+++ b/tests/providers/amazon/aws/datasets/test_s3.py
@@ -17,8 +17,11 @@
 from __future__ import annotations
 
 from airflow.datasets import Dataset
+from airflow.providers.amazon.aws.datasets.s3 import create_dataset
 
 
-def create_dataset(*, path: str) -> Dataset:
-    # We assume that we get absolute path starting with /
-    return Dataset(uri=f"file://{path}")
+def test_create_dataset():
+    assert create_dataset(bucket="test-bucket", key="test-path") == Dataset(uri="s3://test-bucket/test-path")
+    assert create_dataset(bucket="test-bucket", key="test-dir/test-path") == Dataset(
+        uri="s3://test-bucket/test-dir/test-path"
+    )
diff --git a/tests/providers/amazon/aws/hooks/test_s3.py b/tests/providers/amazon/aws/hooks/test_s3.py
index 6b10173d3c..acedf3d011 100644
--- a/tests/providers/amazon/aws/hooks/test_s3.py
+++ b/tests/providers/amazon/aws/hooks/test_s3.py
@@ -31,6 +31,7 @@ import pytest
 from botocore.exceptions import ClientError
 from moto import mock_aws
 
+from airflow.datasets import Dataset
 from airflow.exceptions import AirflowException
 from airflow.models import Connection
 from airflow.providers.amazon.aws.exceptions import S3HookUriParseFailure
@@ -41,6 +42,7 @@ from airflow.providers.amazon.aws.hooks.s3 import (
     unify_bucket_name_and_key,
 )
 from airflow.utils.timezone import datetime
+from tests.test_utils.compat import AIRFLOW_V_2_10_PLUS
 
 
 @pytest.fixture
@@ -388,6 +390,15 @@ class TestAwsS3Hook:
         resource = boto3.resource("s3").Object(s3_bucket, "my_key")
         assert resource.get()["Body"].read() == b"Cont\xc3\xa9nt"
 
+    @pytest.mark.skipif(not AIRFLOW_V_2_10_PLUS, reason="Hook lineage works in Airflow >= 2.10.0")
+    def test_load_string_exposes_lineage(self, s3_bucket, hook_lineage_collector):
+        hook = S3Hook()
+        hook.load_string("Contént", "my_key", s3_bucket)
+        assert len(hook_lineage_collector.collected_datasets.outputs) == 1
+        assert hook_lineage_collector.collected_datasets.outputs[0][0] == Dataset(
+            uri=f"s3://{s3_bucket}/my_key"
+        )
+
     def test_load_string_compress(self, s3_bucket):
         hook = S3Hook()
         hook.load_string("Contént", "my_key", s3_bucket, compression="gzip")
@@ -970,6 +981,17 @@ class TestAwsS3Hook:
         resource = boto3.resource("s3").Object(s3_bucket, "my_key")
         assert gz.decompress(resource.get()["Body"].read()) == b"Content"
 
+    @pytest.mark.skipif(not AIRFLOW_V_2_10_PLUS, reason="Hook lineage works in Airflow >= 2.10.0")
+    def test_load_file_exposes_lineage(self, s3_bucket, tmp_path, hook_lineage_collector):
+        hook = S3Hook()
+        path = tmp_path / "testfile"
+        path.write_text("Content")
+        hook.load_file(path, "my_key", s3_bucket)
+        assert len(hook_lineage_collector.collected_datasets.outputs) == 1
+        assert hook_lineage_collector.collected_datasets.outputs[0][0] == Dataset(
+            uri=f"s3://{s3_bucket}/my_key"
+        )
+
     def test_load_file_acl(self, s3_bucket, tmp_path):
         hook = S3Hook()
         path = tmp_path / "testfile"
@@ -1027,6 +1049,26 @@ class TestAwsS3Hook:
                 ACL="private",
             )
 
+    @pytest.mark.skipif(not AIRFLOW_V_2_10_PLUS, reason="Hook lineage works in Airflow >= 2.10.0")
+    @mock_aws
+    def test_copy_object_ol_instrumentation(self, s3_bucket, hook_lineage_collector):
+        mock_hook = S3Hook()
+
+        with mock.patch.object(
+            S3Hook,
+            "get_conn",
+        ):
+            mock_hook.copy_object("my_key", "my_key3", s3_bucket, s3_bucket)
+            assert len(hook_lineage_collector.collected_datasets.inputs) == 1
+            assert hook_lineage_collector.collected_datasets.inputs[0][0] == Dataset(
+                uri=f"s3://{s3_bucket}/my_key"
+            )
+
+            assert len(hook_lineage_collector.collected_datasets.outputs) == 1
+            assert hook_lineage_collector.collected_datasets.outputs[0][0] == Dataset(
+                uri=f"s3://{s3_bucket}/my_key3"
+            )
+
     @mock_aws
     def test_delete_bucket_if_bucket_exist(self, s3_bucket):
         # assert if the bucket is created
@@ -1140,6 +1182,26 @@ class TestAwsS3Hook:
 
         assert path.name == output_file
 
+    @pytest.mark.skipif(not AIRFLOW_V_2_10_PLUS, reason="Hook lineage works in Airflow >= 2.10.0")
+    @mock.patch("airflow.providers.amazon.aws.hooks.s3.NamedTemporaryFile")
+    def test_download_file_exposes_lineage(self, mock_temp_file, tmp_path, hook_lineage_collector):
+        path = tmp_path / "airflow_tmp_test_s3_hook"
+        mock_temp_file.return_value = path
+        s3_hook = S3Hook(aws_conn_id="s3_test")
+        s3_hook.check_for_key = Mock(return_value=True)
+        s3_obj = Mock()
+        s3_obj.download_fileobj = Mock(return_value=None)
+        s3_hook.get_key = Mock(return_value=s3_obj)
+        key = "test_key"
+        bucket = "test_bucket"
+
+        s3_hook.download_file(key=key, bucket_name=bucket)
+
+        assert len(hook_lineage_collector.collected_datasets.inputs) == 1
+        assert hook_lineage_collector.collected_datasets.inputs[0][0] == Dataset(
+            uri="s3://test_bucket/test_key"
+        )
+
     @mock.patch("airflow.providers.amazon.aws.hooks.s3.open")
     def test_download_file_with_preserve_name(self, mock_open, tmp_path):
         path = tmp_path / "test.log"
@@ -1152,16 +1214,51 @@ class TestAwsS3Hook:
         s3_obj.key = f"s3://{bucket}/{key}"
         s3_obj.download_fileobj = Mock(return_value=None)
         s3_hook.get_key = Mock(return_value=s3_obj)
+        local_path = os.fspath(path.parent)
         s3_hook.download_file(
             key=key,
             bucket_name=bucket,
-            local_path=os.fspath(path.parent),
+            local_path=local_path,
             preserve_file_name=True,
             use_autogenerated_subdir=False,
         )
 
         mock_open.assert_called_once_with(path, "wb")
 
+    @pytest.mark.skipif(not AIRFLOW_V_2_10_PLUS, reason="Hook lineage works in Airflow >= 2.10.0")
+    @mock.patch("airflow.providers.amazon.aws.hooks.s3.open")
+    def test_download_file_with_preserve_name_exposes_lineage(
+        self, mock_open, tmp_path, hook_lineage_collector
+    ):
+        path = tmp_path / "test.log"
+        bucket = "test_bucket"
+        key = f"test_key/{path.name}"
+
+        s3_hook = S3Hook(aws_conn_id="s3_test")
+        s3_hook.check_for_key = Mock(return_value=True)
+        s3_obj = Mock()
+        s3_obj.key = f"s3://{bucket}/{key}"
+        s3_obj.download_fileobj = Mock(return_value=None)
+        s3_hook.get_key = Mock(return_value=s3_obj)
+        local_path = os.fspath(path.parent)
+        s3_hook.download_file(
+            key=key,
+            bucket_name=bucket,
+            local_path=local_path,
+            preserve_file_name=True,
+            use_autogenerated_subdir=False,
+        )
+
+        assert len(hook_lineage_collector.collected_datasets.inputs) == 1
+        assert hook_lineage_collector.collected_datasets.inputs[0][0] == Dataset(
+            uri="s3://test_bucket/test_key/test.log"
+        )
+
+        assert len(hook_lineage_collector.collected_datasets.outputs) == 1
+        assert hook_lineage_collector.collected_datasets.outputs[0][0] == Dataset(
+            uri=f"file://{local_path}/test.log",
+        )
+
     @mock.patch("airflow.providers.amazon.aws.hooks.s3.open")
     def test_download_file_with_preserve_name_with_autogenerated_subdir(self, mock_open, tmp_path):
         path = tmp_path / "test.log"
