This is an automated email from the ASF dual-hosted git repository.

mobuchowski pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 2d682285350 feature: Add OpenLineage support for BigQueryToPostgresOperator (#55392)
2d682285350 is described below

commit 2d6822853509693977451a933e8d0b0c490e5664
Author: pawelgrochowicz <[email protected]>
AuthorDate: Fri Sep 12 14:02:58 2025 +0200

    feature: Add OpenLineage support for BigQueryToPostgresOperator (#55392)
    
    * feature: Add OpenLineage support for BigQueryToPostgresOperator
    
    * Revert "feature: Add OpenLineage support for BigQueryToPostgresOperator"
    
    This reverts commit 30ec57b68c9b9ecbe2a6a7488b315e482e7e0268.
    
    * feature: Add OpenLineage support for BigQueryToPostgresOperator
---
 .../google/cloud/transfers/bigquery_to_mssql.py    |  66 -------
 .../google/cloud/transfers/bigquery_to_mysql.py    |  59 -------
 .../google/cloud/transfers/bigquery_to_postgres.py |  27 ++-
 .../google/cloud/transfers/bigquery_to_sql.py      |  95 ++++++++++
 .../cloud/transfers/test_bigquery_to_mssql.py      |   4 +-
 .../cloud/transfers/test_bigquery_to_postgres.py   | 194 ++++++++++++++++++---
 6 files changed, 289 insertions(+), 156 deletions(-)
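
For context, a hedged usage sketch of the operator this commit instruments. The DAG id, table names, and connection ids below are illustrative assumptions, not part of the change; the point is that OpenLineage collection now happens automatically, with no extra operator parameters:

    # Hypothetical DAG sketch: BigQueryToPostgresOperator now emits OpenLineage
    # events via get_openlineage_facets_on_complete(); configuration is the
    # same as for any other BigQuery -> Postgres transfer.
    import pendulum

    from airflow import DAG
    from airflow.providers.google.cloud.transfers.bigquery_to_postgres import (
        BigQueryToPostgresOperator,
    )

    with DAG(
        dag_id="bq_to_postgres_lineage_demo",  # illustrative name
        start_date=pendulum.datetime(2025, 1, 1, tz="UTC"),
        schedule=None,
    ):
        BigQueryToPostgresOperator(
            task_id="bq_to_pg",
            dataset_table="my_dataset.my_table",  # assumed source dataset.table
            target_table_name="public.my_table",  # assumed destination table
            postgres_conn_id="postgres_default",
            gcp_conn_id="google_cloud_default",
            replace=False,
        )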

diff --git a/providers/google/src/airflow/providers/google/cloud/transfers/bigquery_to_mssql.py b/providers/google/src/airflow/providers/google/cloud/transfers/bigquery_to_mssql.py
index 0464fd79207..38fcf2f68d2 100644
--- a/providers/google/src/airflow/providers/google/cloud/transfers/bigquery_to_mssql.py
+++ b/providers/google/src/airflow/providers/google/cloud/transfers/bigquery_to_mssql.py
@@ -25,13 +25,11 @@ from functools import cached_property
 from typing import TYPE_CHECKING
 
 from airflow.exceptions import AirflowProviderDeprecationWarning
-from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook
 from airflow.providers.google.cloud.links.bigquery import BigQueryTableLink
 from airflow.providers.google.cloud.transfers.bigquery_to_sql import BigQueryToSqlBaseOperator
 from airflow.providers.microsoft.mssql.hooks.mssql import MsSqlHook
 
 if TYPE_CHECKING:
-    from airflow.providers.openlineage.extractors import OperatorLineage
     from airflow.utils.context import Context
 
 
@@ -112,67 +110,3 @@ class BigQueryToMsSqlOperator(BigQueryToSqlBaseOperator):
             project_id=project_id,
             table_id=table_id,
         )
-
-    def get_openlineage_facets_on_complete(self, task_instance) -> OperatorLineage | None:
-        from airflow.providers.common.compat.openlineage.facet import Dataset
-        from airflow.providers.google.cloud.openlineage.utils import (
-            BIGQUERY_NAMESPACE,
-            get_facets_from_bq_table_for_given_fields,
-            get_identity_column_lineage_facet,
-        )
-        from airflow.providers.openlineage.extractors import OperatorLineage
-
-        if not self.bigquery_hook:
-            self.bigquery_hook = BigQueryHook(
-                gcp_conn_id=self.gcp_conn_id,
-                location=self.location,
-                impersonation_chain=self.impersonation_chain,
-            )
-
-        try:
-            table_obj = self.bigquery_hook.get_client().get_table(self.source_project_dataset_table)
-        except Exception:
-            self.log.debug(
-                "OpenLineage: could not fetch BigQuery table %s",
-                self.source_project_dataset_table,
-                exc_info=True,
-            )
-            return OperatorLineage()
-
-        if self.selected_fields:
-            if isinstance(self.selected_fields, str):
-                bigquery_field_names = list(self.selected_fields)
-            else:
-                bigquery_field_names = self.selected_fields
-        else:
-            bigquery_field_names = [f.name for f in getattr(table_obj, "schema", [])]
-
-        input_dataset = Dataset(
-            namespace=BIGQUERY_NAMESPACE,
-            name=self.source_project_dataset_table,
-            facets=get_facets_from_bq_table_for_given_fields(table_obj, bigquery_field_names),
-        )
-
-        db_info = self.mssql_hook.get_openlineage_database_info(self.mssql_hook.get_conn())
-        default_schema = self.mssql_hook.get_openlineage_default_schema()
-        namespace = f"{db_info.scheme}://{db_info.authority}"
-
-        if self.target_table_name and "." in self.target_table_name:
-            schema_name, table_name = self.target_table_name.split(".", 1)
-        else:
-            schema_name = default_schema or ""
-            table_name = self.target_table_name or ""
-
-        if self.database:
-            output_name = f"{self.database}.{schema_name}.{table_name}"
-        else:
-            output_name = f"{schema_name}.{table_name}"
-
-        column_lineage_facet = get_identity_column_lineage_facet(
-            bigquery_field_names, input_datasets=[input_dataset]
-        )
-
-        output_facets = column_lineage_facet or {}
-        output_dataset = Dataset(namespace=namespace, name=output_name, facets=output_facets)
-
-        return OperatorLineage(inputs=[input_dataset], outputs=[output_dataset])
diff --git a/providers/google/src/airflow/providers/google/cloud/transfers/bigquery_to_mysql.py b/providers/google/src/airflow/providers/google/cloud/transfers/bigquery_to_mysql.py
index 544d33f275f..56c3dd8ccc2 100644
--- a/providers/google/src/airflow/providers/google/cloud/transfers/bigquery_to_mysql.py
+++ b/providers/google/src/airflow/providers/google/cloud/transfers/bigquery_to_mysql.py
@@ -22,16 +22,11 @@ from __future__ import annotations
 import warnings
 from collections.abc import Sequence
 from functools import cached_property
-from typing import TYPE_CHECKING
 
 from airflow.exceptions import AirflowProviderDeprecationWarning
-from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook
 from airflow.providers.google.cloud.transfers.bigquery_to_sql import BigQueryToSqlBaseOperator
 from airflow.providers.mysql.hooks.mysql import MySqlHook
 
-if TYPE_CHECKING:
-    from airflow.providers.openlineage.extractors import OperatorLineage
-
 
 class BigQueryToMySqlOperator(BigQueryToSqlBaseOperator):
     """
@@ -94,57 +89,3 @@ class BigQueryToMySqlOperator(BigQueryToSqlBaseOperator):
         project_id = self.bigquery_hook.project_id
         self.source_project_dataset_table = f"{project_id}.{self.dataset_id}.{self.table_id}"
         return super().execute(context)
-
-    def get_openlineage_facets_on_complete(self, task_instance) -> OperatorLineage | None:
-        from airflow.providers.common.compat.openlineage.facet import Dataset
-        from airflow.providers.google.cloud.openlineage.utils import (
-            BIGQUERY_NAMESPACE,
-            get_facets_from_bq_table_for_given_fields,
-            get_identity_column_lineage_facet,
-        )
-        from airflow.providers.openlineage.extractors import OperatorLineage
-
-        if not self.bigquery_hook:
-            self.bigquery_hook = BigQueryHook(
-                gcp_conn_id=self.gcp_conn_id,
-                location=self.location,
-                impersonation_chain=self.impersonation_chain,
-            )
-
-        try:
-            table_obj = self.bigquery_hook.get_client().get_table(self.source_project_dataset_table)
-        except Exception:
-            self.log.debug(
-                "OpenLineage: could not fetch BigQuery table %s",
-                self.source_project_dataset_table,
-                exc_info=True,
-            )
-            return OperatorLineage()
-
-        if self.selected_fields:
-            if isinstance(self.selected_fields, str):
-                bigquery_field_names = list(self.selected_fields)
-            else:
-                bigquery_field_names = self.selected_fields
-        else:
-            bigquery_field_names = [f.name for f in getattr(table_obj, "schema", [])]
-
-        input_dataset = Dataset(
-            namespace=BIGQUERY_NAMESPACE,
-            name=self.source_project_dataset_table,
-            facets=get_facets_from_bq_table_for_given_fields(table_obj, bigquery_field_names),
-        )
-
-        db_info = self.mysql_hook.get_openlineage_database_info(self.mysql_hook.get_conn())
-        namespace = f"{db_info.scheme}://{db_info.authority}"
-
-        output_name = f"{self.database}.{self.target_table_name}"
-
-        column_lineage_facet = get_identity_column_lineage_facet(
-            bigquery_field_names, input_datasets=[input_dataset]
-        )
-
-        output_facets = column_lineage_facet or {}
-        output_dataset = Dataset(namespace=namespace, name=output_name, facets=output_facets)
-
-        return OperatorLineage(inputs=[input_dataset], outputs=[output_dataset])
diff --git a/providers/google/src/airflow/providers/google/cloud/transfers/bigquery_to_postgres.py b/providers/google/src/airflow/providers/google/cloud/transfers/bigquery_to_postgres.py
index cb0baf3d199..c298d5bfe9b 100644
--- a/providers/google/src/airflow/providers/google/cloud/transfers/bigquery_to_postgres.py
+++ b/providers/google/src/airflow/providers/google/cloud/transfers/bigquery_to_postgres.py
@@ -19,6 +19,7 @@
 
 from __future__ import annotations
 
+from functools import cached_property
 from typing import TYPE_CHECKING
 
 from psycopg2.extensions import register_adapter
@@ -78,28 +79,36 @@ class BigQueryToPostgresOperator(BigQueryToSqlBaseOperator):
         self.postgres_conn_id = postgres_conn_id
         self.replace_index = replace_index
 
-    def get_sql_hook(self) -> PostgresHook:
+    @cached_property
+    def postgres_hook(self) -> PostgresHook:
         register_adapter(list, Json)
         register_adapter(dict, Json)
         return PostgresHook(database=self.database, postgres_conn_id=self.postgres_conn_id)
 
+    def get_sql_hook(self) -> PostgresHook:
+        return self.postgres_hook
+
     def execute(self, context: Context) -> None:
-        big_query_hook = BigQueryHook(
-            gcp_conn_id=self.gcp_conn_id,
-            location=self.location,
-            impersonation_chain=self.impersonation_chain,
-        )
+        if not self.bigquery_hook:
+            self.bigquery_hook = BigQueryHook(
+                gcp_conn_id=self.gcp_conn_id,
+                location=self.location,
+                impersonation_chain=self.impersonation_chain,
+            )
+        # Set source_project_dataset_table here, after hooks are initialized and project_id is available
+        project_id = self.bigquery_hook.project_id
+        self.source_project_dataset_table = f"{project_id}.{self.dataset_id}.{self.table_id}"
+
         self.persist_links(context)
-        sql_hook: PostgresHook = self.get_sql_hook()
         for rows in bigquery_get_data(
             self.log,
             self.dataset_id,
             self.table_id,
-            big_query_hook,
+            self.bigquery_hook,
             self.batch_size,
             self.selected_fields,
         ):
-            sql_hook.insert_rows(
+            self.postgres_hook.insert_rows(
                 table=self.target_table_name,
                 rows=rows,
                 target_fields=self.selected_fields,
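
The cached postgres_hook above registers psycopg2 type adapters so Python lists and dicts fetched from BigQuery are inserted into Postgres as JSON. A minimal standalone sketch of that mechanism, independent of Airflow:

    # Once registered, psycopg2 wraps list and dict parameters in Json and
    # renders them as quoted JSON literals (lists would otherwise be adapted
    # to Postgres arrays, and dicts would fail to adapt at all).
    from psycopg2.extensions import adapt, register_adapter
    from psycopg2.extras import Json

    register_adapter(list, Json)
    register_adapter(dict, Json)

    # adapt() now yields a JSON literal for a nested row value:
    print(adapt({"key": ["a", "b"]}).getquoted())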
diff --git a/providers/google/src/airflow/providers/google/cloud/transfers/bigquery_to_sql.py b/providers/google/src/airflow/providers/google/cloud/transfers/bigquery_to_sql.py
index 20a9f8edc30..a76f34eae67 100644
--- a/providers/google/src/airflow/providers/google/cloud/transfers/bigquery_to_sql.py
+++ b/providers/google/src/airflow/providers/google/cloud/transfers/bigquery_to_sql.py
@@ -30,6 +30,7 @@ from airflow.providers.google.version_compat import BaseOperator
 
 if TYPE_CHECKING:
     from airflow.providers.common.sql.hooks.sql import DbApiHook
+    from airflow.providers.openlineage.extractors import OperatorLineage
     from airflow.utils.context import Context
 
 
@@ -140,3 +141,97 @@ class BigQueryToSqlBaseOperator(BaseOperator):
                 replace=self.replace,
                 commit_every=self.batch_size,
             )
+
+    def get_openlineage_facets_on_complete(self, task_instance) -> OperatorLineage | None:
+        """
+        Build generic OpenLineage facets for BigQuery -> SQL transfers.
+
+        This consolidates nearly identical implementations from child
+        operators. Children still provide a concrete SQL hook via
+        ``get_sql_hook()`` and may override behavior if needed.
+        """
+        from airflow.providers.common.compat.openlineage.facet import Dataset
+        from airflow.providers.google.cloud.openlineage.utils import (
+            BIGQUERY_NAMESPACE,
+            get_facets_from_bq_table_for_given_fields,
+            get_identity_column_lineage_facet,
+        )
+        from airflow.providers.openlineage.extractors import OperatorLineage
+
+        if not self.bigquery_hook:
+            self.bigquery_hook = BigQueryHook(
+                gcp_conn_id=self.gcp_conn_id,
+                location=self.location,
+                impersonation_chain=self.impersonation_chain,
+            )
+
+        try:
+            if not getattr(self, "source_project_dataset_table", None):
+                project_id = self.bigquery_hook.project_id
+                self.source_project_dataset_table = f"{project_id}.{self.dataset_id}.{self.table_id}"
+
+            table_obj = self.bigquery_hook.get_client().get_table(self.source_project_dataset_table)
+        except Exception:
+            self.log.debug(
+                "OpenLineage: could not fetch BigQuery table %s",
+                getattr(self, "source_project_dataset_table", None),
+                exc_info=True,
+            )
+            return OperatorLineage()
+
+        if self.selected_fields:
+            if isinstance(self.selected_fields, str):
+                bigquery_field_names = list(self.selected_fields)
+            else:
+                bigquery_field_names = self.selected_fields
+        else:
+            bigquery_field_names = [f.name for f in getattr(table_obj, "schema", [])]
+
+        input_dataset = Dataset(
+            namespace=BIGQUERY_NAMESPACE,
+            name=self.source_project_dataset_table,
+            facets=get_facets_from_bq_table_for_given_fields(table_obj, bigquery_field_names),
+        )
+
+        sql_hook = self.get_sql_hook()
+        db_info = sql_hook.get_openlineage_database_info(sql_hook.get_conn())
+        if db_info is None:
+            self.log.debug("OpenLineage: could not get database info from SQL 
hook %s", type(sql_hook))
+            return OperatorLineage()
+        namespace = f"{db_info.scheme}://{db_info.authority}"
+
+        schema_name = None
+        if hasattr(sql_hook, "get_openlineage_default_schema"):
+            try:
+                schema_name = sql_hook.get_openlineage_default_schema()
+            except Exception:
+                schema_name = None
+
+        if self.target_table_name and "." in self.target_table_name:
+            schema_part, table_part = self.target_table_name.split(".", 1)
+        else:
+            schema_part = schema_name or ""
+            table_part = self.target_table_name or ""
+
+        if db_info and db_info.scheme == "mysql":
+            output_name = f"{self.database}.{table_part}" if self.database else f"{table_part}"
+        else:
+            if self.database:
+                if schema_part:
+                    output_name = f"{self.database}.{schema_part}.{table_part}"
+                else:
+                    output_name = f"{self.database}.{table_part}"
+            else:
+                if schema_part:
+                    output_name = f"{schema_part}.{table_part}"
+                else:
+                    output_name = f"{table_part}"
+
+        column_lineage_facet = get_identity_column_lineage_facet(
+            bigquery_field_names, input_datasets=[input_dataset]
+        )
+
+        output_facets = column_lineage_facet or {}
+        output_dataset = Dataset(namespace=namespace, name=output_name, facets=output_facets)
+
+        return OperatorLineage(inputs=[input_dataset], outputs=[output_dataset])
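
To make the new contract concrete, a hedged sketch of what a child operator has to provide (the class name and connection id are illustrative; the MySQL, MSSQL, and Postgres children in this commit follow the same shape):

    # Hypothetical subclass: the generic get_openlineage_facets_on_complete()
    # in the base class only requires get_sql_hook() to return a DbApiHook
    # whose get_openlineage_database_info() describes the destination; the
    # BigQuery input facets and column lineage come from the base class.
    from functools import cached_property

    from airflow.providers.google.cloud.transfers.bigquery_to_sql import (
        BigQueryToSqlBaseOperator,
    )
    from airflow.providers.postgres.hooks.postgres import PostgresHook

    class BigQueryToExampleOperator(BigQueryToSqlBaseOperator):  # illustrative
        @cached_property
        def example_hook(self) -> PostgresHook:
            return PostgresHook(postgres_conn_id="example_conn")  # assumed conn id

        def get_sql_hook(self) -> PostgresHook:
            return self.example_hook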
diff --git a/providers/google/tests/unit/google/cloud/transfers/test_bigquery_to_mssql.py b/providers/google/tests/unit/google/cloud/transfers/test_bigquery_to_mssql.py
index db26444c1fd..a9c4668b50e 100644
--- a/providers/google/tests/unit/google/cloud/transfers/test_bigquery_to_mssql.py
+++ b/providers/google/tests/unit/google/cloud/transfers/test_bigquery_to_mssql.py
@@ -106,7 +106,7 @@ class TestBigQueryToMsSqlOperator:
         )
 
     @mock.patch("airflow.providers.google.cloud.transfers.bigquery_to_mssql.MsSqlHook")
-    @mock.patch("airflow.providers.google.cloud.transfers.bigquery_to_mssql.BigQueryHook")
+    @mock.patch("airflow.providers.google.cloud.transfers.bigquery_to_sql.BigQueryHook")
     def test_get_openlineage_facets_on_complete_no_selected_fields(self, mock_bq_hook, mock_mssql_hook):
         mock_bq_client = MagicMock()
         table_obj = _make_bq_table(["id", "name", "value"])
@@ -152,7 +152,7 @@ class TestBigQueryToMsSqlOperator:
         assert set(col_lineage.fields.keys()) == {"id", "name", "value"}
 
     @mock.patch("airflow.providers.google.cloud.transfers.bigquery_to_mssql.MsSqlHook")
-    @mock.patch("airflow.providers.google.cloud.transfers.bigquery_to_mssql.BigQueryHook")
+    @mock.patch("airflow.providers.google.cloud.transfers.bigquery_to_sql.BigQueryHook")
     def test_get_openlineage_facets_on_complete_selected_fields(self, mock_bq_hook, mock_mssql_hook):
         mock_bq_client = MagicMock()
         table_obj = _make_bq_table(["id", "name", "value"])
diff --git a/providers/google/tests/unit/google/cloud/transfers/test_bigquery_to_postgres.py b/providers/google/tests/unit/google/cloud/transfers/test_bigquery_to_postgres.py
index 79003313411..41d0549d376 100644
--- a/providers/google/tests/unit/google/cloud/transfers/test_bigquery_to_postgres.py
+++ b/providers/google/tests/unit/google/cloud/transfers/test_bigquery_to_postgres.py
@@ -18,6 +18,7 @@
 from __future__ import annotations
 
 from unittest import mock
+from unittest.mock import MagicMock
 
 import pytest
 from psycopg2.extras import Json
@@ -29,11 +30,32 @@ TEST_DATASET = "test-dataset"
 TEST_TABLE_ID = "test-table-id"
 TEST_DAG_ID = "test-bigquery-operators"
 TEST_DESTINATION_TABLE = "table"
+TEST_PROJECT = "test-project"
+
+
+def _make_bq_table(schema_names: list[str]):
+    class TableObj:
+        def __init__(self, schema):
+            self.schema = []
+            for n in schema:
+                field = MagicMock()
+                field.name = n
+                self.schema.append(field)
+            self.description = "table description"
+            self.external_data_configuration = None
+            self.labels = {}
+            self.num_rows = 0
+            self.num_bytes = 0
+            self.table_type = "TABLE"
+
+    return TableObj(schema_names)
 
 
 class TestBigQueryToPostgresOperator:
-    @mock.patch("airflow.providers.google.cloud.transfers.bigquery_to_postgres.BigQueryHook")
-    def test_execute_good_request_to_bq(self, mock_hook):
+    @mock.patch("airflow.providers.google.cloud.transfers.bigquery_to_postgres.bigquery_get_data")
+    @mock.patch.object(BigQueryToPostgresOperator, "bigquery_hook", new_callable=mock.PropertyMock)
+    @mock.patch.object(BigQueryToPostgresOperator, "postgres_hook", new_callable=mock.PropertyMock)
+    def test_execute_good_request_to_bq(self, mock_pg_hook, mock_bq_hook, mock_bigquery_get_data):
         operator = BigQueryToPostgresOperator(
             task_id=TASK_ID,
             dataset_table=f"{TEST_DATASET}.{TEST_TABLE_ID}",
@@ -41,17 +63,34 @@ class TestBigQueryToPostgresOperator:
             replace=False,
         )
 
+        mock_bigquery_get_data.return_value = [[("row1", "val1")], [("row2", "val2")]]
+        mock_pg = mock.MagicMock()
+        mock_pg_hook.return_value = mock_pg
+        mock_bq = mock.MagicMock()
+        mock_bq.project_id = TEST_PROJECT
+        mock_bq_hook.return_value = mock_bq
+
         operator.execute(context=mock.MagicMock())
-        mock_hook.return_value.list_rows.assert_called_once_with(
-            dataset_id=TEST_DATASET,
-            table_id=TEST_TABLE_ID,
-            max_results=1000,
-            selected_fields=None,
-            start_index=0,
+
+        mock_bigquery_get_data.assert_called_once_with(
+            operator.log,
+            TEST_DATASET,
+            TEST_TABLE_ID,
+            mock_bq,
+            operator.batch_size,
+            operator.selected_fields,
         )
+        assert mock_pg.insert_rows.call_count == 2
 
-    @mock.patch("airflow.providers.google.cloud.transfers.bigquery_to_postgres.BigQueryHook")
-    def test_execute_good_request_to_bq__with_replace(self, mock_hook):
+    @mock.patch("airflow.providers.google.cloud.transfers.bigquery_to_postgres.bigquery_get_data")
+    @mock.patch.object(BigQueryToPostgresOperator, "bigquery_hook", new_callable=mock.PropertyMock)
+    @mock.patch.object(BigQueryToPostgresOperator, "postgres_hook", new_callable=mock.PropertyMock)
+    def test_execute_good_request_to_bq_with_replace(
+        self,
+        mock_pg_hook,
+        mock_bq_hook,
+        mock_bigquery_get_data,
+    ):
         operator = BigQueryToPostgresOperator(
             task_id=TASK_ID,
             dataset_table=f"{TEST_DATASET}.{TEST_TABLE_ID}",
@@ -61,13 +100,30 @@ class TestBigQueryToPostgresOperator:
             replace_index=["col_1"],
         )
 
+        mock_bigquery_get_data.return_value = [[("only_row", "val")]]
+        mock_pg = mock.MagicMock()
+        mock_pg_hook.return_value = mock_pg
+        mock_bq = mock.MagicMock()
+        mock_bq.project_id = TEST_PROJECT
+        mock_bq_hook.return_value = mock_bq
+
         operator.execute(context=mock.MagicMock())
-        mock_hook.return_value.list_rows.assert_called_once_with(
-            dataset_id=TEST_DATASET,
-            table_id=TEST_TABLE_ID,
-            max_results=1000,
-            selected_fields=["col_1", "col_2"],
-            start_index=0,
+
+        mock_bigquery_get_data.assert_called_once_with(
+            operator.log,
+            TEST_DATASET,
+            TEST_TABLE_ID,
+            mock_bq,
+            operator.batch_size,
+            ["col_1", "col_2"],
+        )
+        mock_pg.insert_rows.assert_called_once_with(
+            table=TEST_DESTINATION_TABLE,
+            rows=[("only_row", "val")],
+            target_fields=["col_1", "col_2"],
+            replace=True,
+            commit_every=operator.batch_size,
+            replace_index=["col_1"],
         )
 
     @pytest.mark.parametrize(
@@ -87,15 +143,113 @@ class TestBigQueryToPostgresOperator:
                 replace_index=replace_index,
             )
 
+    @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_client")
+    @mock.patch(
+        "airflow.providers.google.common.hooks.base_google.GoogleBaseHook.get_credentials_and_project_id"
+    )
     @mock.patch("airflow.providers.google.cloud.transfers.bigquery_to_postgres.register_adapter")
-    @mock.patch("airflow.providers.google.cloud.transfers.bigquery_to_postgres.BigQueryHook")
-    def test_adapters_to_json_registered(self, mock_hook, mock_register_adapter):
-        BigQueryToPostgresOperator(
+    def test_adapters_to_json_registered(self, mock_register_adapter, mock_get_creds, mock_get_client):
+        mock_get_creds.return_value = (None, TEST_PROJECT)
+        client = MagicMock()
+        client.list_rows.return_value = []
+        mock_get_client.return_value = client
+
+        operator = BigQueryToPostgresOperator(
             task_id=TASK_ID,
             dataset_table=f"{TEST_DATASET}.{TEST_TABLE_ID}",
             target_table_name=TEST_DESTINATION_TABLE,
             replace=False,
-        ).execute(context=mock.MagicMock())
+        )
+        operator.postgres_hook
 
         mock_register_adapter.assert_any_call(list, Json)
         mock_register_adapter.assert_any_call(dict, Json)
+
+    @mock.patch("airflow.providers.google.cloud.transfers.bigquery_to_postgres.PostgresHook")
+    @mock.patch("airflow.providers.google.cloud.transfers.bigquery_to_sql.BigQueryHook")
+    def test_get_openlineage_facets_on_complete_no_selected_fields(self, mock_bq_hook, mock_postgres_hook):
+        mock_bq_client = MagicMock()
+        mock_bq_client.get_table.return_value = _make_bq_table(["id", "name", "value"])
+        mock_bq_hook.get_client.return_value = mock_bq_client
+        mock_bq_hook.return_value = mock_bq_hook
+
+        db_info = MagicMock(scheme="postgres", authority="localhost:5432", 
database="postgresdb")
+        mock_postgres_hook.get_openlineage_database_info.return_value = db_info
+        mock_postgres_hook.get_openlineage_default_schema.return_value = "postgres-schema"
+        mock_postgres_hook.return_value = mock_postgres_hook
+
+        op = BigQueryToPostgresOperator(
+            task_id=TASK_ID,
+            dataset_table=f"{TEST_DATASET}.{TEST_TABLE_ID}",
+            target_table_name="destination",
+            selected_fields=None,
+            database="postgresdb",
+        )
+        op.bigquery_hook = mock_bq_hook
+        op.bigquery_hook.project_id = TEST_PROJECT
+        op.postgres_hook = mock_postgres_hook
+        context = mock.MagicMock()
+        op.execute(context=context)
+
+        result = op.get_openlineage_facets_on_complete(task_instance=MagicMock())
+        assert len(result.inputs) == 1
+        assert len(result.outputs) == 1
+
+        input_ds = result.inputs[0]
+        assert input_ds.namespace == "bigquery"
+        assert input_ds.name == f"{TEST_PROJECT}.{TEST_DATASET}.{TEST_TABLE_ID}"
+        assert "schema" in input_ds.facets
+        schema_fields = [f.name for f in input_ds.facets["schema"].fields]
+        assert set(schema_fields) == {"id", "name", "value"}
+
+        output_ds = result.outputs[0]
+        assert output_ds.namespace == "postgres://localhost:5432"
+        assert output_ds.name == "postgresdb.postgres-schema.destination"
+
+        assert "columnLineage" in output_ds.facets
+        col_lineage = output_ds.facets["columnLineage"]
+        assert set(col_lineage.fields.keys()) == {"id", "name", "value"}
+
+    @mock.patch("airflow.providers.google.cloud.transfers.bigquery_to_postgres.PostgresHook")
+    @mock.patch("airflow.providers.google.cloud.transfers.bigquery_to_sql.BigQueryHook")
+    def test_get_openlineage_facets_on_complete_selected_fields(self, mock_bq_hook, mock_postgres_hook):
+        mock_bq_client = MagicMock()
+        mock_bq_client.get_table.return_value = _make_bq_table(["id", "name", "value"])
+        mock_bq_hook.get_client.return_value = mock_bq_client
+        mock_bq_hook.return_value = mock_bq_hook
+
+        db_info = MagicMock(scheme="postgres", authority="localhost:5432", database="postgresdb")
+        mock_postgres_hook.get_openlineage_database_info.return_value = db_info
+        mock_postgres_hook.get_openlineage_default_schema.return_value = "postgres-schema"
+        mock_postgres_hook.return_value = mock_postgres_hook
+
+        op = BigQueryToPostgresOperator(
+            task_id=TASK_ID,
+            dataset_table=f"{TEST_DATASET}.{TEST_TABLE_ID}",
+            target_table_name="destination",
+            selected_fields=["id", "name"],
+            database="postgresdb",
+        )
+        op.bigquery_hook = mock_bq_hook
+        op.bigquery_hook.project_id = TEST_PROJECT
+        op.postgres_hook = mock_postgres_hook
+        context = mock.MagicMock()
+        op.execute(context=context)
+
+        result = op.get_openlineage_facets_on_complete(task_instance=MagicMock())
+        assert len(result.inputs) == 1
+        assert len(result.outputs) == 1
+
+        input_ds = result.inputs[0]
+        assert input_ds.namespace == "bigquery"
+        assert input_ds.name == f"{TEST_PROJECT}.{TEST_DATASET}.{TEST_TABLE_ID}"
+        assert "schema" in input_ds.facets
+        schema_fields = [f.name for f in input_ds.facets["schema"].fields]
+        assert set(schema_fields) == {"id", "name"}
+
+        output_ds = result.outputs[0]
+        assert output_ds.namespace == "postgres://localhost:5432"
+        assert output_ds.name == "postgresdb.postgres-schema.destination"
+        assert "columnLineage" in output_ds.facets
+        col_lineage = output_ds.facets["columnLineage"]
+        assert set(col_lineage.fields.keys()) == {"id", "name"}
