This is an automated email from the ASF dual-hosted git repository.
mobuchowski pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 61ecbedfa22 feat: Add OpenLineage support for MsSqlHook and MSSQLToGCSOperator (#45637)
61ecbedfa22 is described below
commit 61ecbedfa22a075de42fafffd3f579145a96a84b
Author: Kacper Muda <[email protected]>
AuthorDate: Thu Jan 16 14:55:47 2025 +0100
feat: Add OpenLineage support for MsSqlHook and MSSQLToGCSOperator (#45637)
Signed-off-by: Kacper Muda <[email protected]>
---
dev/breeze/tests/test_selective_checks.py | 6 +-
generated/provider_dependencies.json | 3 +-
.../google/cloud/transfers/mssql_to_gcs.py | 29 +++++++++-
.../providers/microsoft/mssql/hooks/mssql.py | 28 ++++++++++
.../google/cloud/transfers/test_mssql_to_gcs.py | 65 ++++++++++++++++++++++
5 files changed, 125 insertions(+), 6 deletions(-)
diff --git a/dev/breeze/tests/test_selective_checks.py b/dev/breeze/tests/test_selective_checks.py
index db5d674e4e0..ae66b7d6d68 100644
--- a/dev/breeze/tests/test_selective_checks.py
+++ b/dev/breeze/tests/test_selective_checks.py
@@ -1752,7 +1752,7 @@ def test_expected_output_push(
"airflow/datasets/",
),
{
- "selected-providers-list-as-string": "amazon common.compat common.io common.sql dbt.cloud ftp google mysql openlineage postgres sftp snowflake trino",
+ "selected-providers-list-as-string": "amazon common.compat common.io common.sql dbt.cloud ftp google microsoft.mssql mysql openlineage postgres sftp snowflake trino",
"all-python-versions": "['3.9']",
"all-python-versions-list-as-string": "3.9",
"ci-image-build": "true",
@@ -1762,13 +1762,13 @@ def test_expected_output_push(
"skip-providers-tests": "false",
"test-groups": "['core', 'providers']",
"docs-build": "true",
- "docs-list-as-string": "apache-airflow amazon common.compat common.io common.sql dbt.cloud ftp google mysql openlineage postgres sftp snowflake trino",
+ "docs-list-as-string": "apache-airflow amazon common.compat common.io common.sql dbt.cloud ftp google microsoft.mssql mysql openlineage postgres sftp snowflake trino",
"skip-pre-commits":
"check-provider-yaml-valid,flynt,identity,lint-helm-chart,mypy-airflow,mypy-dev,mypy-docs,mypy-providers,mypy-task-sdk,"
"ts-compile-format-lint-ui,ts-compile-format-lint-www",
"run-kubernetes-tests": "false",
"upgrade-to-newer-dependencies": "false",
"core-test-types-list-as-string": "API Always CLI Core
Operators Other Serialization WWW",
- "providers-test-types-list-as-string": "Providers[amazon] Providers[common.compat,common.io,common.sql,dbt.cloud,ftp,mysql,openlineage,postgres,sftp,snowflake,trino] Providers[google]",
+ "providers-test-types-list-as-string": "Providers[amazon] Providers[common.compat,common.io,common.sql,dbt.cloud,ftp,microsoft.mssql,mysql,openlineage,postgres,sftp,snowflake,trino] Providers[google]",
"needs-mypy": "false",
"mypy-checks": "[]",
},
diff --git a/generated/provider_dependencies.json b/generated/provider_dependencies.json
index 9cc7784ca8d..2724c6a73d4 100644
--- a/generated/provider_dependencies.json
+++ b/generated/provider_dependencies.json
@@ -860,7 +860,8 @@
"devel-deps": [],
"plugins": [],
"cross-providers-deps": [
- "common.sql"
+ "common.sql",
+ "openlineage"
],
"excluded-python-versions": [],
"state": "ready"
diff --git a/providers/src/airflow/providers/google/cloud/transfers/mssql_to_gcs.py b/providers/src/airflow/providers/google/cloud/transfers/mssql_to_gcs.py
index 6bcd9f38ff0..8f861aaee69 100644
--- a/providers/src/airflow/providers/google/cloud/transfers/mssql_to_gcs.py
+++ b/providers/src/airflow/providers/google/cloud/transfers/mssql_to_gcs.py
@@ -22,10 +22,15 @@ from __future__ import annotations
import datetime
import decimal
from collections.abc import Sequence
+from functools import cached_property
+from typing import TYPE_CHECKING
from airflow.providers.google.cloud.transfers.sql_to_gcs import BaseSQLToGCSOperator
from airflow.providers.microsoft.mssql.hooks.mssql import MsSqlHook
+if TYPE_CHECKING:
+ from airflow.providers.openlineage.extractors import OperatorLineage
+
class MSSQLToGCSOperator(BaseSQLToGCSOperator):
"""
@@ -75,14 +80,17 @@ class MSSQLToGCSOperator(BaseSQLToGCSOperator):
self.mssql_conn_id = mssql_conn_id
self.bit_fields = bit_fields or []
+ @cached_property
+ def db_hook(self) -> MsSqlHook:
+ return MsSqlHook(mssql_conn_id=self.mssql_conn_id)
+
def query(self):
"""
Query MSSQL and returns a cursor of results.
:return: mssql cursor
"""
- mssql = MsSqlHook(mssql_conn_id=self.mssql_conn_id)
- conn = mssql.get_conn()
+ conn = self.db_hook.get_conn()
cursor = conn.cursor()
cursor.execute(self.sql)
return cursor
@@ -109,3 +117,20 @@ class MSSQLToGCSOperator(BaseSQLToGCSOperator):
if isinstance(value, (datetime.date, datetime.time)):
return value.isoformat()
return value
+
+ def get_openlineage_facets_on_start(self) -> OperatorLineage | None:
+ from airflow.providers.common.compat.openlineage.facet import SQLJobFacet
+ from airflow.providers.common.compat.openlineage.utils.sql import get_openlineage_facets_with_sql
+ from airflow.providers.openlineage.extractors import OperatorLineage
+
+ sql_parsing_result = get_openlineage_facets_with_sql(
+ hook=self.db_hook,
+ sql=self.sql,
+ conn_id=self.mssql_conn_id,
+ database=None,
+ )
+ gcs_output_datasets = self._get_openlineage_output_datasets()
+ if sql_parsing_result:
+ sql_parsing_result.outputs = gcs_output_datasets
+ return sql_parsing_result
+ return OperatorLineage(outputs=gcs_output_datasets, job_facets={"sql": SQLJobFacet(self.sql)})
diff --git a/providers/src/airflow/providers/microsoft/mssql/hooks/mssql.py b/providers/src/airflow/providers/microsoft/mssql/hooks/mssql.py
index 089f1ccfb7d..a29018ec0ff 100644
--- a/providers/src/airflow/providers/microsoft/mssql/hooks/mssql.py
+++ b/providers/src/airflow/providers/microsoft/mssql/hooks/mssql.py
@@ -29,6 +29,7 @@ from airflow.providers.microsoft.mssql.dialects.mssql import MsSqlDialect
if TYPE_CHECKING:
from airflow.providers.common.sql.dialects.dialect import Dialect
+ from airflow.providers.openlineage.sqlparser import DatabaseInfo
class MsSqlHook(DbApiHook):
@@ -117,3 +118,30 @@ class MsSqlHook(DbApiHook):
def get_autocommit(self, conn: PymssqlConnection):
return conn.autocommit_state
+
+ def get_openlineage_database_info(self, connection) -> DatabaseInfo:
+ """Return MSSQL specific information for OpenLineage."""
+ from airflow.providers.openlineage.sqlparser import DatabaseInfo
+
+ return DatabaseInfo(
+ scheme=self.get_openlineage_database_dialect(connection),
+ authority=DbApiHook.get_openlineage_authority_part(connection, default_port=1433),
+ information_schema_columns=[
+ "table_schema",
+ "table_name",
+ "column_name",
+ "ordinal_position",
+ "data_type",
+ "table_catalog",
+ ],
+ database=self.schema or self.connection.schema,
+ is_information_schema_cross_db=True,
+ )
+
+ def get_openlineage_database_dialect(self, connection) -> str:
+ """Return database dialect."""
+ return "mssql"
+
+ def get_openlineage_default_schema(self) -> str | None:
+ """Return current schema."""
+ return self.get_first("SELECT SCHEMA_NAME();")[0]
diff --git a/providers/tests/google/cloud/transfers/test_mssql_to_gcs.py b/providers/tests/google/cloud/transfers/test_mssql_to_gcs.py
index 18ccb502127..c04cf207970 100644
--- a/providers/tests/google/cloud/transfers/test_mssql_to_gcs.py
+++ b/providers/tests/google/cloud/transfers/test_mssql_to_gcs.py
@@ -22,6 +22,12 @@ from unittest import mock
import pytest
+from airflow.models import Connection
+from airflow.providers.common.compat.openlineage.facet import (
+ OutputDataset,
+ SchemaDatasetFacetFields,
+)
+from airflow.providers.common.sql.hooks.sql import DbApiHook
from airflow.providers.google.cloud.transfers.mssql_to_gcs import MSSQLToGCSOperator
TASK_ID = "test-mssql-to-gcs"
@@ -188,3 +194,62 @@ class TestMsSqlToGoogleCloudStorageOperator:
# once for the file and once for the schema
assert gcs_hook_mock.upload.call_count == 2
+
+ @pytest.mark.parametrize(
+ "connection_port, default_port, expected_port",
+ [(None, 4321, 4321), (1234, None, 1234), (1234, 4321, 1234)],
+ )
+ def test_execute_openlineage_events(self, connection_port, default_port, expected_port):
+ class DBApiHookForTests(DbApiHook):
+ conn_name_attr = "sql_default"
+ get_conn = mock.MagicMock(name="conn")
+ get_connection = mock.MagicMock()
+
+ def get_openlineage_database_info(self, connection):
+ from airflow.providers.openlineage.sqlparser import DatabaseInfo
+
+ return DatabaseInfo(
+ scheme="sqlscheme",
+ authority=DbApiHook.get_openlineage_authority_part(connection, default_port=default_port),
+ )
+
+ dbapi_hook = DBApiHookForTests()
+
+ class MSSQLToGCSOperatorForTest(MSSQLToGCSOperator):
+ @property
+ def db_hook(self):
+ return dbapi_hook
+
+ sql = """SELECT a,b,c from my_db.my_table"""
+ op = MSSQLToGCSOperatorForTest(task_id=TASK_ID, sql=sql, bucket="bucket", filename="dir/file{}.csv")
+ DB_SCHEMA_NAME = "PUBLIC"
+ rows = [
+ (DB_SCHEMA_NAME, "popular_orders_day_of_week", "order_day_of_week", 1, "varchar"),
+ (DB_SCHEMA_NAME, "popular_orders_day_of_week", "order_placed_on", 2, "timestamp"),
+ (DB_SCHEMA_NAME, "popular_orders_day_of_week", "orders_placed", 3, "int4"),
+ ]
+ dbapi_hook.get_connection.return_value = Connection(
+ conn_id="sql_default", conn_type="mssql", host="host", port=connection_port
+ )
+ dbapi_hook.get_conn.return_value.cursor.return_value.fetchall.side_effect = [rows, []]
+
+ lineage = op.get_openlineage_facets_on_start()
+ assert len(lineage.inputs) == 1
+ assert lineage.inputs[0].namespace == f"sqlscheme://host:{expected_port}"
+ assert lineage.inputs[0].name == "PUBLIC.popular_orders_day_of_week"
+ assert len(lineage.inputs[0].facets) == 1
+ assert lineage.inputs[0].facets["schema"].fields == [
+ SchemaDatasetFacetFields(name="order_day_of_week", type="varchar"),
+ SchemaDatasetFacetFields(name="order_placed_on", type="timestamp"),
+ SchemaDatasetFacetFields(name="orders_placed", type="int4"),
+ ]
+ assert lineage.outputs == [
+ OutputDataset(
+ namespace="gs://bucket",
+ name="dir",
+ )
+ ]
+
+ assert len(lineage.job_facets) == 1
+ assert lineage.job_facets["sql"].query == sql
+ assert lineage.run_facets == {}