This is an automated email from the ASF dual-hosted git repository.

eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new b73366799d openlineage, gcs: add openlineage methods for GcsToGcsOperator (#31350)
b73366799d is described below

commit b73366799d98195a5ccc49a2008932186c4763b5
Author: Maciej Obuchowski <[email protected]>
AuthorDate: Thu Jul 27 07:53:43 2023 +0200

    openlineage, gcs: add openlineage methods for GcsToGcsOperator (#31350)
    
    Signed-off-by: Maciej Obuchowski <[email protected]>
---
 .../providers/google/cloud/transfers/gcs_to_gcs.py | 29 +++++++++
 airflow/providers/openlineage/extractors/base.py   | 18 +++---
 dev/breeze/tests/test_selective_checks.py          | 12 ++--
 generated/provider_dependencies.json               |  1 +
 .../google/cloud/transfers/test_gcs_to_gcs.py      | 73 ++++++++++++++++++++++
 5 files changed, 120 insertions(+), 13 deletions(-)

diff --git a/airflow/providers/google/cloud/transfers/gcs_to_gcs.py b/airflow/providers/google/cloud/transfers/gcs_to_gcs.py
index 17d1638559..cf4cf3c0ce 100644
--- a/airflow/providers/google/cloud/transfers/gcs_to_gcs.py
+++ b/airflow/providers/google/cloud/transfers/gcs_to_gcs.py
@@ -233,6 +233,8 @@ class GCSToGCSOperator(BaseOperator):
         self.source_object_required = source_object_required
         self.exact_match = exact_match
         self.match_glob = match_glob
+        self.resolved_source_objects: set[str] = set()
+        self.resolved_target_objects: set[str] = set()
 
     def execute(self, context: Context):
 
@@ -540,7 +542,34 @@ class GCSToGCSOperator(BaseOperator):
             destination_object,
         )
 
+        self.resolved_source_objects.add(source_object)
+        if not destination_object:
+            self.resolved_target_objects.add(source_object)
+        else:
+            self.resolved_target_objects.add(destination_object)
+
         hook.rewrite(self.source_bucket, source_object, self.destination_bucket, destination_object)
 
         if self.move_object:
             hook.delete(self.source_bucket, source_object)
+
+    def get_openlineage_facets_on_complete(self, task_instance):
+        """
+        Implementing _on_complete because the execute method does preprocessing on internals.
+        This means we won't have to normalize self.source_object and self.source_objects,
+        destination bucket and so on.
+        """
+        from openlineage.client.run import Dataset
+
+        from airflow.providers.openlineage.extractors import OperatorLineage
+
+        return OperatorLineage(
+            inputs=[
+                Dataset(namespace=f"gs://{self.source_bucket}", name=source)
+                for source in sorted(self.resolved_source_objects)
+            ],
+            outputs=[
+                Dataset(namespace=f"gs://{self.destination_bucket}", 
name=target)
+                for target in sorted(self.resolved_target_objects)
+            ],
+        )
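For reference, a minimal sketch (not part of the patch) of what the new method
returns after a simple single-object copy, mirroring the unit tests added
further below; the bucket and object names are illustrative only:

    from unittest import mock

    from airflow.providers.google.cloud.transfers.gcs_to_gcs import GCSToGCSOperator

    # GCSHook is mocked out, as in the unit tests, so no real GCS calls are made.
    with mock.patch("airflow.providers.google.cloud.transfers.gcs_to_gcs.GCSHook"):
        op = GCSToGCSOperator(
            task_id="copy",
            source_bucket="src-bucket",
            source_object="data/file.csv",
            destination_bucket="dst-bucket",
        )
        op.execute(None)

    lineage = op.get_openlineage_facets_on_complete(None)
    # lineage.inputs  -> [Dataset(namespace="gs://src-bucket", name="data/file.csv")]
    # lineage.outputs -> [Dataset(namespace="gs://dst-bucket", name="data/file.csv")]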
diff --git a/airflow/providers/openlineage/extractors/base.py b/airflow/providers/openlineage/extractors/base.py
index 51c9281e56..95d8fa6f28 100644
--- a/airflow/providers/openlineage/extractors/base.py
+++ b/airflow/providers/openlineage/extractors/base.py
@@ -83,6 +83,7 @@ class DefaultExtractor(BaseExtractor):
         return []
 
     def extract(self) -> OperatorLineage | None:
+        # OpenLineage methods are optional - if there's no method, return None
         try:
             return self._get_openlineage_facets(self.operator.get_openlineage_facets_on_start)  # type: ignore
         except AttributeError:
@@ -100,7 +101,15 @@ class DefaultExtractor(BaseExtractor):
 
     def _get_openlineage_facets(self, get_facets_method, *args) -> OperatorLineage | None:
         try:
-            facets = get_facets_method(*args)
+            facets: OperatorLineage = get_facets_method(*args)
+            # "rewrite" OperatorLineage to safeguard against different version 
of the same class
+            # that was existing in openlineage-airflow package outside of 
Airflow repo
+            return OperatorLineage(
+                inputs=facets.inputs,
+                outputs=facets.outputs,
+                run_facets=facets.run_facets,
+                job_facets=facets.job_facets,
+            )
         except ImportError:
             self.log.exception(
                 "OpenLineage provider method failed to import OpenLineage 
integration. "
@@ -108,11 +117,4 @@ class DefaultExtractor(BaseExtractor):
             )
         except Exception:
             self.log.exception("OpenLineage provider method failed to extract 
data from provider. ")
-        else:
-            return OperatorLineage(
-                inputs=facets.inputs,
-                outputs=facets.outputs,
-                run_facets=facets.run_facets,
-                job_facets=facets.job_facets,
-            )
         return None
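The re-instantiation in _get_openlineage_facets above is defensive: if an
operator returns an OperatorLineage built from a same-named class shipped in
the external openlineage-airflow distribution, rebuilding it field by field
normalizes the value to the provider's own class. A rough sketch of the idea
(ForeignOperatorLineage is a made-up stand-in for such an outside class):

    from dataclasses import dataclass, field

    from openlineage.client.run import Dataset

    from airflow.providers.openlineage.extractors import OperatorLineage

    @dataclass
    class ForeignOperatorLineage:  # hypothetical: same shape, different class
        inputs: list = field(default_factory=list)
        outputs: list = field(default_factory=list)
        run_facets: dict = field(default_factory=dict)
        job_facets: dict = field(default_factory=dict)

    foreign = ForeignOperatorLineage(inputs=[Dataset(namespace="gs://bucket", name="obj")])
    normalized = OperatorLineage(
        inputs=foreign.inputs,
        outputs=foreign.outputs,
        run_facets=foreign.run_facets,
        job_facets=foreign.job_facets,
    )
    assert isinstance(normalized, OperatorLineage)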
diff --git a/dev/breeze/tests/test_selective_checks.py b/dev/breeze/tests/test_selective_checks.py
index ff051b57ad..60ad9015c2 100644
--- a/dev/breeze/tests/test_selective_checks.py
+++ b/dev/breeze/tests/test_selective_checks.py
@@ -539,7 +539,7 @@ def test_expected_output_full_tests_needed(
             {
                 "affected-providers-list-as-string": "amazon apache.beam 
apache.cassandra cncf.kubernetes "
                 "common.sql facebook google hashicorp microsoft.azure 
microsoft.mssql "
-                "mysql oracle postgres presto salesforce sftp ssh trino",
+                "mysql openlineage oracle postgres presto salesforce sftp ssh 
trino",
                 "all-python-versions": "['3.8']",
                 "all-python-versions-list-as-string": "3.8",
                 "needs-helm-tests": "false",
@@ -564,8 +564,8 @@ def test_expected_output_full_tests_needed(
             {
                 "affected-providers-list-as-string": "amazon apache.beam 
apache.cassandra "
                 "cncf.kubernetes common.sql facebook google "
-                "hashicorp microsoft.azure microsoft.mssql mysql oracle 
postgres presto "
-                "salesforce sftp ssh trino",
+                "hashicorp microsoft.azure microsoft.mssql mysql openlineage 
oracle postgres "
+                "presto salesforce sftp ssh trino",
                 "all-python-versions": "['3.8']",
                 "all-python-versions-list-as-string": "3.8",
                 "image-build": "true",
@@ -666,7 +666,7 @@ def test_expected_output_pull_request_v2_3(
                 "affected-providers-list-as-string": "amazon apache.beam 
apache.cassandra "
                 "cncf.kubernetes common.sql "
                 "facebook google hashicorp microsoft.azure microsoft.mssql 
mysql "
-                "oracle postgres presto salesforce sftp ssh trino",
+                "openlineage oracle postgres presto salesforce sftp ssh trino",
                 "all-python-versions": "['3.8']",
                 "all-python-versions-list-as-string": "3.8",
                 "image-build": "true",
@@ -685,6 +685,7 @@ def test_expected_output_pull_request_v2_3(
                 "--package-filter apache-airflow-providers-microsoft-azure "
                 "--package-filter apache-airflow-providers-microsoft-mssql "
                 "--package-filter apache-airflow-providers-mysql "
+                "--package-filter apache-airflow-providers-openlineage "
                 "--package-filter apache-airflow-providers-oracle "
                 "--package-filter apache-airflow-providers-postgres "
                 "--package-filter apache-airflow-providers-presto "
@@ -697,7 +698,7 @@ def test_expected_output_pull_request_v2_3(
                 "skip-provider-tests": "false",
                 "parallel-test-types-list-as-string": "Providers[amazon] 
Always CLI "
                 
"Providers[apache.beam,apache.cassandra,cncf.kubernetes,common.sql,facebook,"
-                
"hashicorp,microsoft.azure,microsoft.mssql,mysql,oracle,postgres,presto,"
+                
"hashicorp,microsoft.azure,microsoft.mssql,mysql,openlineage,oracle,postgres,presto,"
                 "salesforce,sftp,ssh,trino] Providers[google]",
             },
             id="CLI tests and Google-related provider tests should run if 
cli/chart files changed",
@@ -965,6 +966,7 @@ def test_upgrade_to_newer_dependencies(files: tuple[str, ...], expected_outputs:
                 "--package-filter apache-airflow-providers-microsoft-azure "
                 "--package-filter apache-airflow-providers-microsoft-mssql "
                 "--package-filter apache-airflow-providers-mysql "
+                "--package-filter apache-airflow-providers-openlineage "
                 "--package-filter apache-airflow-providers-oracle "
                 "--package-filter apache-airflow-providers-postgres "
                 "--package-filter apache-airflow-providers-presto "
diff --git a/generated/provider_dependencies.json b/generated/provider_dependencies.json
index 94fc01793a..f9b7722f9b 100644
--- a/generated/provider_dependencies.json
+++ b/generated/provider_dependencies.json
@@ -467,6 +467,7 @@
       "microsoft.azure",
       "microsoft.mssql",
       "mysql",
+      "openlineage",
       "oracle",
       "postgres",
       "presto",
diff --git a/tests/providers/google/cloud/transfers/test_gcs_to_gcs.py b/tests/providers/google/cloud/transfers/test_gcs_to_gcs.py
index d29a505ba3..cf525235b1 100644
--- a/tests/providers/google/cloud/transfers/test_gcs_to_gcs.py
+++ b/tests/providers/google/cloud/transfers/test_gcs_to_gcs.py
@@ -21,6 +21,7 @@ from datetime import datetime
 from unittest import mock
 
 import pytest
+from openlineage.client.run import Dataset
 
 from airflow.exceptions import AirflowException
 from airflow.providers.google.cloud.transfers.gcs_to_gcs import WILDCARD, GCSToGCSOperator
@@ -827,3 +828,75 @@ class TestGoogleCloudStorageToCloudStorageOperator:
             for src, dst in zip(expected_source_objects, expected_destination_objects)
         ]
         mock_hook.return_value.rewrite.assert_has_calls(mock_calls)
+
+    @mock.patch("airflow.providers.google.cloud.transfers.gcs_to_gcs.GCSHook")
+    def test_execute_simple_reports_openlineage(self, mock_hook):
+        operator = GCSToGCSOperator(
+            task_id=TASK_ID,
+            source_bucket=TEST_BUCKET,
+            source_object=SOURCE_OBJECTS_SINGLE_FILE[0],
+            destination_bucket=DESTINATION_BUCKET,
+        )
+
+        operator.execute(None)
+
+        lineage = operator.get_openlineage_facets_on_complete(None)
+        assert len(lineage.inputs) == 1
+        assert len(lineage.outputs) == 1
+        assert lineage.inputs[0] == Dataset(
+            namespace=f"gs://{TEST_BUCKET}", name=SOURCE_OBJECTS_SINGLE_FILE[0]
+        )
+        assert lineage.outputs[0] == Dataset(
+            namespace=f"gs://{DESTINATION_BUCKET}", name=SOURCE_OBJECTS_SINGLE_FILE[0]
+        )
+
+    @mock.patch("airflow.providers.google.cloud.transfers.gcs_to_gcs.GCSHook")
+    def test_execute_multiple_reports_openlineage(self, mock_hook):
+        operator = GCSToGCSOperator(
+            task_id=TASK_ID,
+            source_bucket=TEST_BUCKET,
+            source_objects=SOURCE_OBJECTS_LIST,
+            destination_bucket=DESTINATION_BUCKET,
+            destination_object=DESTINATION_OBJECT,
+        )
+
+        operator.execute(None)
+
+        lineage = operator.get_openlineage_facets_on_complete(None)
+        assert len(lineage.inputs) == 3
+        assert len(lineage.outputs) == 1
+        assert lineage.inputs == [
+            Dataset(namespace=f"gs://{TEST_BUCKET}", 
name=SOURCE_OBJECTS_LIST[0]),
+            Dataset(namespace=f"gs://{TEST_BUCKET}", 
name=SOURCE_OBJECTS_LIST[1]),
+            Dataset(namespace=f"gs://{TEST_BUCKET}", 
name=SOURCE_OBJECTS_LIST[2]),
+        ]
+        assert lineage.outputs[0] == 
Dataset(namespace=f"gs://{DESTINATION_BUCKET}", name=DESTINATION_OBJECT)
+
+    @mock.patch("airflow.providers.google.cloud.transfers.gcs_to_gcs.GCSHook")
+    def test_execute_wildcard_reports_openlineage(self, mock_hook):
+        mock_hook.return_value.list.return_value = [
+            "test_object1.txt",
+            "test_object2.txt",
+        ]
+
+        operator = GCSToGCSOperator(
+            task_id=TASK_ID,
+            source_bucket=TEST_BUCKET,
+            source_object=SOURCE_OBJECT_WILDCARD_SUFFIX,
+            destination_bucket=DESTINATION_BUCKET,
+            destination_object=DESTINATION_OBJECT,
+        )
+
+        operator.execute(None)
+
+        lineage = operator.get_openlineage_facets_on_complete(None)
+        assert len(lineage.inputs) == 2
+        assert len(lineage.outputs) == 2
+        assert lineage.inputs == [
+            Dataset(namespace=f"gs://{TEST_BUCKET}", name="test_object1.txt"),
+            Dataset(namespace=f"gs://{TEST_BUCKET}", name="test_object2.txt"),
+        ]
+        assert lineage.outputs == [
+            Dataset(namespace=f"gs://{DESTINATION_BUCKET}", 
name="foo/bar/1.txt"),
+            Dataset(namespace=f"gs://{DESTINATION_BUCKET}", 
name="foo/bar/2.txt"),
+        ]

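The new tests can be run on their own with pytest, for example (the exact
invocation may differ depending on the local breeze/virtualenv setup):

    pytest tests/providers/google/cloud/transfers/test_gcs_to_gcs.py -k openlineage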