This is an automated email from the ASF dual-hosted git repository.
mobuchowski pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 2d682285350 feature: Add OpenLineage support for BigQueryToPostgresOperator (#55392)
2d682285350 is described below
commit 2d6822853509693977451a933e8d0b0c490e5664
Author: pawelgrochowicz <[email protected]>
AuthorDate: Fri Sep 12 14:02:58 2025 +0200
feature: Add OpenLineage support for BigQueryToPostgresOperator (#55392)
* feature: Add OpenLineage support for BigQueryToPostgresOperator
* feature: Add OpenLineage support for BigQueryToPostgresOperator
* Revert "feature: Add OpenLineage support for BigQueryToPostgresOperator"
This reverts commit 30ec57b68c9b9ecbe2a6a7488b315e482e7e0268.
* feature: Add OpenLineage support for BigQueryToPostgresOperator
* feature: Add OpenLineage support for BigQueryToPostgresOperator
* feature: Add OpenLineage support for BigQueryToPostgresOperator
* feature: Add OpenLineage support for BigQueryToPostgresOperator
* feature: Add OpenLineage support for BigQueryToPostgresOperator
---
.../google/cloud/transfers/bigquery_to_mssql.py | 66 -------
.../google/cloud/transfers/bigquery_to_mysql.py | 59 -------
.../google/cloud/transfers/bigquery_to_postgres.py | 27 ++-
.../google/cloud/transfers/bigquery_to_sql.py | 95 ++++++++++
.../cloud/transfers/test_bigquery_to_mssql.py | 4 +-
.../cloud/transfers/test_bigquery_to_postgres.py | 194 ++++++++++++++++++---
6 files changed, 289 insertions(+), 156 deletions(-)
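
For context, once this change ships, lineage collection needs no per-DAG setup: with the OpenLineage provider installed, the transfer operator reports its input/output datasets and column lineage on completion. A minimal usage sketch (the import path and operator parameters are the real ones from this diff; the DAG id, table names, and connection ids are hypothetical):

    import datetime

    from airflow import DAG
    from airflow.providers.google.cloud.transfers.bigquery_to_postgres import BigQueryToPostgresOperator

    with DAG(
        dag_id="example_bq_to_postgres",  # hypothetical DAG id
        start_date=datetime.datetime(2025, 1, 1),
        schedule=None,
    ) as dag:
        # On success, get_openlineage_facets_on_complete() reports the BigQuery
        # table as input and the Postgres table as output, with column lineage.
        BigQueryToPostgresOperator(
            task_id="bq_to_pg",
            dataset_table="my_dataset.my_table",  # hypothetical source dataset.table
            target_table_name="public.my_table",  # hypothetical target schema.table
            postgres_conn_id="postgres_default",
            gcp_conn_id="google_cloud_default",
        )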
diff --git a/providers/google/src/airflow/providers/google/cloud/transfers/bigquery_to_mssql.py b/providers/google/src/airflow/providers/google/cloud/transfers/bigquery_to_mssql.py
index 0464fd79207..38fcf2f68d2 100644
--- a/providers/google/src/airflow/providers/google/cloud/transfers/bigquery_to_mssql.py
+++ b/providers/google/src/airflow/providers/google/cloud/transfers/bigquery_to_mssql.py
@@ -25,13 +25,11 @@ from functools import cached_property
 from typing import TYPE_CHECKING
 
 from airflow.exceptions import AirflowProviderDeprecationWarning
-from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook
 from airflow.providers.google.cloud.links.bigquery import BigQueryTableLink
 from airflow.providers.google.cloud.transfers.bigquery_to_sql import BigQueryToSqlBaseOperator
 from airflow.providers.microsoft.mssql.hooks.mssql import MsSqlHook
 
 if TYPE_CHECKING:
-    from airflow.providers.openlineage.extractors import OperatorLineage
     from airflow.utils.context import Context
 
 
@@ -112,67 +110,3 @@ class BigQueryToMsSqlOperator(BigQueryToSqlBaseOperator):
             project_id=project_id,
             table_id=table_id,
         )
-
-    def get_openlineage_facets_on_complete(self, task_instance) -> OperatorLineage | None:
-        from airflow.providers.common.compat.openlineage.facet import Dataset
-        from airflow.providers.google.cloud.openlineage.utils import (
-            BIGQUERY_NAMESPACE,
-            get_facets_from_bq_table_for_given_fields,
-            get_identity_column_lineage_facet,
-        )
-        from airflow.providers.openlineage.extractors import OperatorLineage
-
-        if not self.bigquery_hook:
-            self.bigquery_hook = BigQueryHook(
-                gcp_conn_id=self.gcp_conn_id,
-                location=self.location,
-                impersonation_chain=self.impersonation_chain,
-            )
-
-        try:
-            table_obj = self.bigquery_hook.get_client().get_table(self.source_project_dataset_table)
-        except Exception:
-            self.log.debug(
-                "OpenLineage: could not fetch BigQuery table %s",
-                self.source_project_dataset_table,
-                exc_info=True,
-            )
-            return OperatorLineage()
-
-        if self.selected_fields:
-            if isinstance(self.selected_fields, str):
-                bigquery_field_names = list(self.selected_fields)
-            else:
-                bigquery_field_names = self.selected_fields
-        else:
-            bigquery_field_names = [f.name for f in getattr(table_obj, "schema", [])]
-
-        input_dataset = Dataset(
-            namespace=BIGQUERY_NAMESPACE,
-            name=self.source_project_dataset_table,
-            facets=get_facets_from_bq_table_for_given_fields(table_obj, bigquery_field_names),
-        )
-
-        db_info = self.mssql_hook.get_openlineage_database_info(self.mssql_hook.get_conn())
-        default_schema = self.mssql_hook.get_openlineage_default_schema()
-        namespace = f"{db_info.scheme}://{db_info.authority}"
-
-        if self.target_table_name and "." in self.target_table_name:
-            schema_name, table_name = self.target_table_name.split(".", 1)
-        else:
-            schema_name = default_schema or ""
-            table_name = self.target_table_name or ""
-
-        if self.database:
-            output_name = f"{self.database}.{schema_name}.{table_name}"
-        else:
-            output_name = f"{schema_name}.{table_name}"
-
-        column_lineage_facet = get_identity_column_lineage_facet(
-            bigquery_field_names, input_datasets=[input_dataset]
-        )
-
-        output_facets = column_lineage_facet or {}
-        output_dataset = Dataset(namespace=namespace, name=output_name, facets=output_facets)
-
-        return OperatorLineage(inputs=[input_dataset], outputs=[output_dataset])
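
The method deleted above (and its near-twin removed from the MySQL operator below) now lives once in BigQueryToSqlBaseOperator; a child operator only has to expose its database hook. A minimal sketch of the resulting pattern (simplified stand-in classes for illustration, not the provider's full code):

    class BaseTransferSketch:
        """Simplified stand-in for BigQueryToSqlBaseOperator."""

        def get_sql_hook(self):
            raise NotImplementedError  # each child supplies its concrete hook

        def get_openlineage_facets_on_complete(self, task_instance):
            # Shared logic: ask the child for its hook, then derive the
            # OpenLineage namespace and default schema from it generically.
            sql_hook = self.get_sql_hook()
            db_info = sql_hook.get_openlineage_database_info(sql_hook.get_conn())
            return f"{db_info.scheme}://{db_info.authority}"  # namespace piece only, simplified


    class MsSqlTransferSketch(BaseTransferSketch):
        def get_sql_hook(self):
            from airflow.providers.microsoft.mssql.hooks.mssql import MsSqlHook

            return MsSqlHook()  # default mssql_conn_id; the real operator passes its own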
diff --git a/providers/google/src/airflow/providers/google/cloud/transfers/bigquery_to_mysql.py b/providers/google/src/airflow/providers/google/cloud/transfers/bigquery_to_mysql.py
index 544d33f275f..56c3dd8ccc2 100644
--- a/providers/google/src/airflow/providers/google/cloud/transfers/bigquery_to_mysql.py
+++ b/providers/google/src/airflow/providers/google/cloud/transfers/bigquery_to_mysql.py
@@ -22,16 +22,11 @@ from __future__ import annotations
 import warnings
 from collections.abc import Sequence
 from functools import cached_property
-from typing import TYPE_CHECKING
 
 from airflow.exceptions import AirflowProviderDeprecationWarning
-from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook
 from airflow.providers.google.cloud.transfers.bigquery_to_sql import BigQueryToSqlBaseOperator
 from airflow.providers.mysql.hooks.mysql import MySqlHook
 
-if TYPE_CHECKING:
-    from airflow.providers.openlineage.extractors import OperatorLineage
-
 
 class BigQueryToMySqlOperator(BigQueryToSqlBaseOperator):
     """
@@ -94,57 +89,3 @@ class BigQueryToMySqlOperator(BigQueryToSqlBaseOperator):
         project_id = self.bigquery_hook.project_id
         self.source_project_dataset_table = f"{project_id}.{self.dataset_id}.{self.table_id}"
         return super().execute(context)
-
-    def get_openlineage_facets_on_complete(self, task_instance) -> OperatorLineage | None:
-        from airflow.providers.common.compat.openlineage.facet import Dataset
-        from airflow.providers.google.cloud.openlineage.utils import (
-            BIGQUERY_NAMESPACE,
-            get_facets_from_bq_table_for_given_fields,
-            get_identity_column_lineage_facet,
-        )
-        from airflow.providers.openlineage.extractors import OperatorLineage
-
-        if not self.bigquery_hook:
-            self.bigquery_hook = BigQueryHook(
-                gcp_conn_id=self.gcp_conn_id,
-                location=self.location,
-                impersonation_chain=self.impersonation_chain,
-            )
-
-        try:
-            table_obj = self.bigquery_hook.get_client().get_table(self.source_project_dataset_table)
-        except Exception:
-            self.log.debug(
-                "OpenLineage: could not fetch BigQuery table %s",
-                self.source_project_dataset_table,
-                exc_info=True,
-            )
-            return OperatorLineage()
-
-        if self.selected_fields:
-            if isinstance(self.selected_fields, str):
-                bigquery_field_names = list(self.selected_fields)
-            else:
-                bigquery_field_names = self.selected_fields
-        else:
-            bigquery_field_names = [f.name for f in getattr(table_obj, "schema", [])]
-
-        input_dataset = Dataset(
-            namespace=BIGQUERY_NAMESPACE,
-            name=self.source_project_dataset_table,
-            facets=get_facets_from_bq_table_for_given_fields(table_obj, bigquery_field_names),
-        )
-
-        db_info = self.mysql_hook.get_openlineage_database_info(self.mysql_hook.get_conn())
-        namespace = f"{db_info.scheme}://{db_info.authority}"
-
-        output_name = f"{self.database}.{self.target_table_name}"
-
-        column_lineage_facet = get_identity_column_lineage_facet(
-            bigquery_field_names, input_datasets=[input_dataset]
-        )
-
-        output_facets = column_lineage_facet or {}
-        output_dataset = Dataset(namespace=namespace, name=output_name, facets=output_facets)
-
-        return OperatorLineage(inputs=[input_dataset], outputs=[output_dataset])
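
Both deleted implementations ultimately returned the same structure. For reference, the shape of the object the shared base method now builds (illustrative values; the imports are the real ones used throughout this diff):

    from airflow.providers.common.compat.openlineage.facet import Dataset
    from airflow.providers.openlineage.extractors import OperatorLineage

    lineage = OperatorLineage(
        inputs=[Dataset(namespace="bigquery", name="my-project.my_dataset.my_table", facets={})],
        outputs=[Dataset(namespace="mysql://db-host:3306", name="mydb.my_table", facets={})],
    )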
diff --git a/providers/google/src/airflow/providers/google/cloud/transfers/bigquery_to_postgres.py b/providers/google/src/airflow/providers/google/cloud/transfers/bigquery_to_postgres.py
index cb0baf3d199..c298d5bfe9b 100644
--- a/providers/google/src/airflow/providers/google/cloud/transfers/bigquery_to_postgres.py
+++ b/providers/google/src/airflow/providers/google/cloud/transfers/bigquery_to_postgres.py
@@ -19,6 +19,7 @@
 
 from __future__ import annotations
 
+from functools import cached_property
 from typing import TYPE_CHECKING
 
 from psycopg2.extensions import register_adapter
@@ -78,28 +79,36 @@ class BigQueryToPostgresOperator(BigQueryToSqlBaseOperator):
         self.postgres_conn_id = postgres_conn_id
         self.replace_index = replace_index
 
-    def get_sql_hook(self) -> PostgresHook:
+    @cached_property
+    def postgres_hook(self) -> PostgresHook:
         register_adapter(list, Json)
         register_adapter(dict, Json)
         return PostgresHook(database=self.database, postgres_conn_id=self.postgres_conn_id)
 
+    def get_sql_hook(self) -> PostgresHook:
+        return self.postgres_hook
+
     def execute(self, context: Context) -> None:
-        big_query_hook = BigQueryHook(
-            gcp_conn_id=self.gcp_conn_id,
-            location=self.location,
-            impersonation_chain=self.impersonation_chain,
-        )
+        if not self.bigquery_hook:
+            self.bigquery_hook = BigQueryHook(
+                gcp_conn_id=self.gcp_conn_id,
+                location=self.location,
+                impersonation_chain=self.impersonation_chain,
+            )
+        # Set source_project_dataset_table here, after hooks are initialized and project_id is available
+        project_id = self.bigquery_hook.project_id
+        self.source_project_dataset_table = f"{project_id}.{self.dataset_id}.{self.table_id}"
+
         self.persist_links(context)
-        sql_hook: PostgresHook = self.get_sql_hook()
         for rows in bigquery_get_data(
             self.log,
             self.dataset_id,
             self.table_id,
-            big_query_hook,
+            self.bigquery_hook,
             self.batch_size,
             self.selected_fields,
         ):
-            sql_hook.insert_rows(
+            self.postgres_hook.insert_rows(
                 table=self.target_table_name,
                 rows=rows,
                 target_fields=self.selected_fields,
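
The cached_property above guarantees one PostgresHook per task instance and, as a side effect, that the psycopg2 adapters are registered before any insert runs. For context, this is what that registration buys (standalone psycopg2 sketch; the DSN and table are hypothetical):

    import psycopg2
    from psycopg2.extensions import register_adapter
    from psycopg2.extras import Json

    # Without these, passing a dict or list as a query parameter raises
    # "can't adapt type 'dict'"; with them, the value is sent as JSON.
    register_adapter(dict, Json)
    register_adapter(list, Json)

    conn = psycopg2.connect("dbname=example")  # hypothetical DSN
    with conn, conn.cursor() as cur:
        cur.execute(
            "INSERT INTO events (payload) VALUES (%s)",  # hypothetical table
            ({"key": "value", "tags": ["a", "b"]},),
        )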
diff --git a/providers/google/src/airflow/providers/google/cloud/transfers/bigquery_to_sql.py b/providers/google/src/airflow/providers/google/cloud/transfers/bigquery_to_sql.py
index 20a9f8edc30..a76f34eae67 100644
--- a/providers/google/src/airflow/providers/google/cloud/transfers/bigquery_to_sql.py
+++ b/providers/google/src/airflow/providers/google/cloud/transfers/bigquery_to_sql.py
@@ -30,6 +30,7 @@ from airflow.providers.google.version_compat import BaseOperator
 
 if TYPE_CHECKING:
     from airflow.providers.common.sql.hooks.sql import DbApiHook
+    from airflow.providers.openlineage.extractors import OperatorLineage
     from airflow.utils.context import Context
 
 
@@ -140,3 +141,97 @@ class BigQueryToSqlBaseOperator(BaseOperator):
             replace=self.replace,
             commit_every=self.batch_size,
         )
+
+    def get_openlineage_facets_on_complete(self, task_instance) -> OperatorLineage | None:
+        """
+        Build a generic OpenLineage facet for BigQuery -> SQL transfers.
+
+        This consolidates nearly identical implementations from child
+        operators. Children still provide a concrete SQL hook via
+        ``get_sql_hook()`` and may override behavior if needed.
+        """
+        from airflow.providers.common.compat.openlineage.facet import Dataset
+        from airflow.providers.google.cloud.openlineage.utils import (
+            BIGQUERY_NAMESPACE,
+            get_facets_from_bq_table_for_given_fields,
+            get_identity_column_lineage_facet,
+        )
+        from airflow.providers.openlineage.extractors import OperatorLineage
+
+        if not self.bigquery_hook:
+            self.bigquery_hook = BigQueryHook(
+                gcp_conn_id=self.gcp_conn_id,
+                location=self.location,
+                impersonation_chain=self.impersonation_chain,
+            )
+
+        try:
+            if not getattr(self, "source_project_dataset_table", None):
+                project_id = self.bigquery_hook.project_id
+                self.source_project_dataset_table = f"{project_id}.{self.dataset_id}.{self.table_id}"
+
+            table_obj = self.bigquery_hook.get_client().get_table(self.source_project_dataset_table)
+        except Exception:
+            self.log.debug(
+                "OpenLineage: could not fetch BigQuery table %s",
+                getattr(self, "source_project_dataset_table", None),
+                exc_info=True,
+            )
+            return OperatorLineage()
+
+        if self.selected_fields:
+            if isinstance(self.selected_fields, str):
+                bigquery_field_names = self.selected_fields.split(",")
+            else:
+                bigquery_field_names = self.selected_fields
+        else:
+            bigquery_field_names = [f.name for f in getattr(table_obj, "schema", [])]
+
+        input_dataset = Dataset(
+            namespace=BIGQUERY_NAMESPACE,
+            name=self.source_project_dataset_table,
+            facets=get_facets_from_bq_table_for_given_fields(table_obj, bigquery_field_names),
+        )
+
+        sql_hook = self.get_sql_hook()
+        db_info = sql_hook.get_openlineage_database_info(sql_hook.get_conn())
+        if db_info is None:
+            self.log.debug("OpenLineage: could not get database info from SQL hook %s", type(sql_hook))
+            return OperatorLineage()
+        namespace = f"{db_info.scheme}://{db_info.authority}"
+
+        schema_name = None
+        if hasattr(sql_hook, "get_openlineage_default_schema"):
+            try:
+                schema_name = sql_hook.get_openlineage_default_schema()
+            except Exception:
+                schema_name = None
+
+        if self.target_table_name and "." in self.target_table_name:
+            schema_part, table_part = self.target_table_name.split(".", 1)
+        else:
+            schema_part = schema_name or ""
+            table_part = self.target_table_name or ""
+
+        if db_info and db_info.scheme == "mysql":
+            output_name = f"{self.database}.{table_part}" if self.database else f"{table_part}"
+        else:
+            if self.database:
+                if schema_part:
+                    output_name = f"{self.database}.{schema_part}.{table_part}"
+                else:
+                    output_name = f"{self.database}.{table_part}"
+            else:
+                if schema_part:
+                    output_name = f"{schema_part}.{table_part}"
+                else:
+                    output_name = f"{table_part}"
+
+        column_lineage_facet = get_identity_column_lineage_facet(
+            bigquery_field_names, input_datasets=[input_dataset]
+        )
+
+        output_facets = column_lineage_facet or {}
+        output_dataset = Dataset(namespace=namespace, name=output_name, facets=output_facets)
+
+        return OperatorLineage(inputs=[input_dataset], outputs=[output_dataset])
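
The output-name branching added above reduces to a simple rule: MySQL datasets are named database.table (MySQL has no schema level between database and table), everything else database.schema.table with empty parts dropped. A standalone sketch of the same resolution (plain helper for illustration, not part of the provider API):

    def resolve_output_name(scheme: str, database: str | None, schema: str, table: str) -> str:
        # MySQL: no schema level, so at most database.table.
        if scheme == "mysql":
            return f"{database}.{table}" if database else table
        # Everything else: database.schema.table, skipping empty parts.
        return ".".join(p for p in (database, schema, table) if p)


    assert resolve_output_name("mysql", "mydb", "", "t") == "mydb.t"
    assert resolve_output_name("postgres", "postgresdb", "public", "t") == "postgresdb.public.t"
    assert resolve_output_name("mssql", None, "dbo", "t") == "dbo.t"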
diff --git a/providers/google/tests/unit/google/cloud/transfers/test_bigquery_to_mssql.py b/providers/google/tests/unit/google/cloud/transfers/test_bigquery_to_mssql.py
index db26444c1fd..a9c4668b50e 100644
--- a/providers/google/tests/unit/google/cloud/transfers/test_bigquery_to_mssql.py
+++ b/providers/google/tests/unit/google/cloud/transfers/test_bigquery_to_mssql.py
@@ -106,7 +106,7 @@ class TestBigQueryToMsSqlOperator:
         )
 
     @mock.patch("airflow.providers.google.cloud.transfers.bigquery_to_mssql.MsSqlHook")
-    @mock.patch("airflow.providers.google.cloud.transfers.bigquery_to_mssql.BigQueryHook")
+    @mock.patch("airflow.providers.google.cloud.transfers.bigquery_to_sql.BigQueryHook")
     def test_get_openlineage_facets_on_complete_no_selected_fields(self, mock_bq_hook, mock_mssql_hook):
         mock_bq_client = MagicMock()
         table_obj = _make_bq_table(["id", "name", "value"])
@@ -152,7 +152,7 @@ class TestBigQueryToMsSqlOperator:
         assert set(col_lineage.fields.keys()) == {"id", "name", "value"}
 
     @mock.patch("airflow.providers.google.cloud.transfers.bigquery_to_mssql.MsSqlHook")
-    @mock.patch("airflow.providers.google.cloud.transfers.bigquery_to_mssql.BigQueryHook")
+    @mock.patch("airflow.providers.google.cloud.transfers.bigquery_to_sql.BigQueryHook")
     def test_get_openlineage_facets_on_complete_selected_fields(self, mock_bq_hook, mock_mssql_hook):
         mock_bq_client = MagicMock()
         table_obj = _make_bq_table(["id", "name", "value"])
diff --git a/providers/google/tests/unit/google/cloud/transfers/test_bigquery_to_postgres.py b/providers/google/tests/unit/google/cloud/transfers/test_bigquery_to_postgres.py
index 79003313411..41d0549d376 100644
--- a/providers/google/tests/unit/google/cloud/transfers/test_bigquery_to_postgres.py
+++ b/providers/google/tests/unit/google/cloud/transfers/test_bigquery_to_postgres.py
@@ -18,6 +18,7 @@
 from __future__ import annotations
 
 from unittest import mock
+from unittest.mock import MagicMock
 
 import pytest
 from psycopg2.extras import Json
@@ -29,11 +30,32 @@ TEST_DATASET = "test-dataset"
 TEST_TABLE_ID = "test-table-id"
 TEST_DAG_ID = "test-bigquery-operators"
 TEST_DESTINATION_TABLE = "table"
+TEST_PROJECT = "test-project"
+
+
+def _make_bq_table(schema_names: list[str]):
+    class TableObj:
+        def __init__(self, schema):
+            self.schema = []
+            for n in schema:
+                field = MagicMock()
+                field.name = n
+                self.schema.append(field)
+            self.description = "table description"
+            self.external_data_configuration = None
+            self.labels = {}
+            self.num_rows = 0
+            self.num_bytes = 0
+            self.table_type = "TABLE"
+
+    return TableObj(schema_names)
 
 
 class TestBigQueryToPostgresOperator:
-    @mock.patch("airflow.providers.google.cloud.transfers.bigquery_to_postgres.BigQueryHook")
-    def test_execute_good_request_to_bq(self, mock_hook):
+    @mock.patch("airflow.providers.google.cloud.transfers.bigquery_to_postgres.bigquery_get_data")
+    @mock.patch.object(BigQueryToPostgresOperator, "bigquery_hook", new_callable=mock.PropertyMock)
+    @mock.patch.object(BigQueryToPostgresOperator, "postgres_hook", new_callable=mock.PropertyMock)
+    def test_execute_good_request_to_bq(self, mock_pg_hook, mock_bq_hook, mock_bigquery_get_data):
         operator = BigQueryToPostgresOperator(
             task_id=TASK_ID,
             dataset_table=f"{TEST_DATASET}.{TEST_TABLE_ID}",
@@ -41,17 +63,34 @@ class TestBigQueryToPostgresOperator:
             replace=False,
         )
 
+        mock_bigquery_get_data.return_value = [[("row1", "val1")], [("row2", "val2")]]
+        mock_pg = mock.MagicMock()
+        mock_pg_hook.return_value = mock_pg
+        mock_bq = mock.MagicMock()
+        mock_bq.project_id = TEST_PROJECT
+        mock_bq_hook.return_value = mock_bq
+
         operator.execute(context=mock.MagicMock())
-        mock_hook.return_value.list_rows.assert_called_once_with(
-            dataset_id=TEST_DATASET,
-            table_id=TEST_TABLE_ID,
-            max_results=1000,
-            selected_fields=None,
-            start_index=0,
+
+        mock_bigquery_get_data.assert_called_once_with(
+            operator.log,
+            TEST_DATASET,
+            TEST_TABLE_ID,
+            mock_bq,
+            operator.batch_size,
+            operator.selected_fields,
         )
+        assert mock_pg.insert_rows.call_count == 2
 
-    @mock.patch("airflow.providers.google.cloud.transfers.bigquery_to_postgres.BigQueryHook")
-    def test_execute_good_request_to_bq__with_replace(self, mock_hook):
+    @mock.patch("airflow.providers.google.cloud.transfers.bigquery_to_postgres.bigquery_get_data")
+    @mock.patch.object(BigQueryToPostgresOperator, "bigquery_hook", new_callable=mock.PropertyMock)
+    @mock.patch.object(BigQueryToPostgresOperator, "postgres_hook", new_callable=mock.PropertyMock)
+    def test_execute_good_request_to_bq_with_replace(
+        self,
+        mock_pg_hook,
+        mock_bq_hook,
+        mock_bigquery_get_data,
+    ):
         operator = BigQueryToPostgresOperator(
             task_id=TASK_ID,
             dataset_table=f"{TEST_DATASET}.{TEST_TABLE_ID}",
@@ -61,13 +100,30 @@ class TestBigQueryToPostgresOperator:
             replace_index=["col_1"],
         )
 
+        mock_bigquery_get_data.return_value = [[("only_row", "val")]]
+        mock_pg = mock.MagicMock()
+        mock_pg_hook.return_value = mock_pg
+        mock_bq = mock.MagicMock()
+        mock_bq.project_id = TEST_PROJECT
+        mock_bq_hook.return_value = mock_bq
+
         operator.execute(context=mock.MagicMock())
-        mock_hook.return_value.list_rows.assert_called_once_with(
-            dataset_id=TEST_DATASET,
-            table_id=TEST_TABLE_ID,
-            max_results=1000,
-            selected_fields=["col_1", "col_2"],
-            start_index=0,
+
+        mock_bigquery_get_data.assert_called_once_with(
+            operator.log,
+            TEST_DATASET,
+            TEST_TABLE_ID,
+            mock_bq,
+            operator.batch_size,
+            ["col_1", "col_2"],
+        )
+        mock_pg.insert_rows.assert_called_once_with(
+            table=TEST_DESTINATION_TABLE,
+            rows=[("only_row", "val")],
+            target_fields=["col_1", "col_2"],
+            replace=True,
+            commit_every=operator.batch_size,
+            replace_index=["col_1"],
         )
 
     @pytest.mark.parametrize(
@@ -87,15 +143,113 @@ class TestBigQueryToPostgresOperator:
             replace_index=replace_index,
         )
 
+    @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_client")
+    @mock.patch(
+        "airflow.providers.google.common.hooks.base_google.GoogleBaseHook.get_credentials_and_project_id"
+    )
     @mock.patch("airflow.providers.google.cloud.transfers.bigquery_to_postgres.register_adapter")
-    @mock.patch("airflow.providers.google.cloud.transfers.bigquery_to_postgres.BigQueryHook")
-    def test_adapters_to_json_registered(self, mock_hook, mock_register_adapter):
-        BigQueryToPostgresOperator(
+    def test_adapters_to_json_registered(self, mock_register_adapter, mock_get_creds, mock_get_client):
+        mock_get_creds.return_value = (None, TEST_PROJECT)
+        client = MagicMock()
+        client.list_rows.return_value = []
+        mock_get_client.return_value = client
+
+        operator = BigQueryToPostgresOperator(
             task_id=TASK_ID,
             dataset_table=f"{TEST_DATASET}.{TEST_TABLE_ID}",
             target_table_name=TEST_DESTINATION_TABLE,
             replace=False,
-        ).execute(context=mock.MagicMock())
+        )
+        operator.postgres_hook
         mock_register_adapter.assert_any_call(list, Json)
         mock_register_adapter.assert_any_call(dict, Json)
+
+    @mock.patch("airflow.providers.google.cloud.transfers.bigquery_to_postgres.PostgresHook")
+    @mock.patch("airflow.providers.google.cloud.transfers.bigquery_to_sql.BigQueryHook")
+    def test_get_openlineage_facets_on_complete_no_selected_fields(self, mock_bq_hook, mock_postgres_hook):
+        mock_bq_client = MagicMock()
+        mock_bq_client.get_table.return_value = _make_bq_table(["id", "name", "value"])
+        mock_bq_hook.get_client.return_value = mock_bq_client
+        mock_bq_hook.return_value = mock_bq_hook
+
+        db_info = MagicMock(scheme="postgres", authority="localhost:5432", database="postgresdb")
+        mock_postgres_hook.get_openlineage_database_info.return_value = db_info
+        mock_postgres_hook.get_openlineage_default_schema.return_value = "postgres-schema"
+        mock_postgres_hook.return_value = mock_postgres_hook
+
+        op = BigQueryToPostgresOperator(
+            task_id=TASK_ID,
+            dataset_table=f"{TEST_DATASET}.{TEST_TABLE_ID}",
+            target_table_name="destination",
+            selected_fields=None,
+            database="postgresdb",
+        )
+        op.bigquery_hook = mock_bq_hook
+        op.bigquery_hook.project_id = TEST_PROJECT
+        op.postgres_hook = mock_postgres_hook
+        context = mock.MagicMock()
+        op.execute(context=context)
+
+        result = op.get_openlineage_facets_on_complete(task_instance=MagicMock())
+        assert len(result.inputs) == 1
+        assert len(result.outputs) == 1
+
+        input_ds = result.inputs[0]
+        assert input_ds.namespace == "bigquery"
+        assert input_ds.name == f"{TEST_PROJECT}.{TEST_DATASET}.{TEST_TABLE_ID}"
+        assert "schema" in input_ds.facets
+        schema_fields = [f.name for f in input_ds.facets["schema"].fields]
+        assert set(schema_fields) == {"id", "name", "value"}
+
+        output_ds = result.outputs[0]
+        assert output_ds.namespace == "postgres://localhost:5432"
+        assert output_ds.name == "postgresdb.postgres-schema.destination"
+
+        assert "columnLineage" in output_ds.facets
+        col_lineage = output_ds.facets["columnLineage"]
+        assert set(col_lineage.fields.keys()) == {"id", "name", "value"}
+
+    @mock.patch("airflow.providers.google.cloud.transfers.bigquery_to_postgres.PostgresHook")
+    @mock.patch("airflow.providers.google.cloud.transfers.bigquery_to_sql.BigQueryHook")
+    def test_get_openlineage_facets_on_complete_selected_fields(self, mock_bq_hook, mock_postgres_hook):
+        mock_bq_client = MagicMock()
+        mock_bq_client.get_table.return_value = _make_bq_table(["id", "name", "value"])
+        mock_bq_hook.get_client.return_value = mock_bq_client
+        mock_bq_hook.return_value = mock_bq_hook
+
+        db_info = MagicMock(scheme="postgres", authority="localhost:5432", database="postgresdb")
+        mock_postgres_hook.get_openlineage_database_info.return_value = db_info
+        mock_postgres_hook.get_openlineage_default_schema.return_value = "postgres-schema"
+        mock_postgres_hook.return_value = mock_postgres_hook
+
+        op = BigQueryToPostgresOperator(
+            task_id=TASK_ID,
+            dataset_table=f"{TEST_DATASET}.{TEST_TABLE_ID}",
+            target_table_name="destination",
+            selected_fields=["id", "name"],
+            database="postgresdb",
+        )
+        op.bigquery_hook = mock_bq_hook
+        op.bigquery_hook.project_id = TEST_PROJECT
+        op.postgres_hook = mock_postgres_hook
+        context = mock.MagicMock()
+        op.execute(context=context)
+
+        result = op.get_openlineage_facets_on_complete(task_instance=MagicMock())
+        assert len(result.inputs) == 1
+        assert len(result.outputs) == 1
+
+        input_ds = result.inputs[0]
+        assert input_ds.namespace == "bigquery"
+        assert input_ds.name == f"{TEST_PROJECT}.{TEST_DATASET}.{TEST_TABLE_ID}"
+        assert "schema" in input_ds.facets
+        schema_fields = [f.name for f in input_ds.facets["schema"].fields]
+        assert set(schema_fields) == {"id", "name"}
+
+        output_ds = result.outputs[0]
+        assert output_ds.namespace == "postgres://localhost:5432"
+        assert output_ds.name == "postgresdb.postgres-schema.destination"
+        assert "columnLineage" in output_ds.facets
+        col_lineage = output_ds.facets["columnLineage"]
+        assert set(col_lineage.fields.keys()) == {"id", "name"}
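
The rewritten execute tests above rely on mock.patch.object(..., new_callable=mock.PropertyMock) to swap out the hook properties this commit introduced, so no real connection is ever opened. A minimal standalone illustration of that pattern (toy class, not the operator itself):

    from unittest import mock


    class Transfer:
        @property
        def hook(self):
            raise RuntimeError("would open a real connection")


    def test_hook_is_replaced():
        with mock.patch.object(Transfer, "hook", new_callable=mock.PropertyMock) as hook_prop:
            hook_prop.return_value = mock.MagicMock()
            # Accessing .hook now calls the PropertyMock and returns the stub.
            Transfer().hook.insert_rows(table="t", rows=[])
            hook_prop.return_value.insert_rows.assert_called_once_with(table="t", rows=[])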