This is an automated email from the ASF dual-hosted git repository.

mobuchowski pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 61ecbedfa22 feat: Add OpenLineage support for MsSqlHook and MSSQLToGCSOperator (#45637)
61ecbedfa22 is described below

commit 61ecbedfa22a075de42fafffd3f579145a96a84b
Author: Kacper Muda <[email protected]>
AuthorDate: Thu Jan 16 14:55:47 2025 +0100

    feat: Add OpenLineage support for MsSqlHook and MSSQLToGCSOperator (#45637)
    
    Signed-off-by: Kacper Muda <[email protected]>
---
 dev/breeze/tests/test_selective_checks.py          |  6 +-
 generated/provider_dependencies.json               |  3 +-
 .../google/cloud/transfers/mssql_to_gcs.py         | 29 +++++++++-
 .../providers/microsoft/mssql/hooks/mssql.py       | 28 ++++++++++
 .../google/cloud/transfers/test_mssql_to_gcs.py    | 65 ++++++++++++++++++++++
 5 files changed, 125 insertions(+), 6 deletions(-)
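
For context, a minimal sketch of how the operator-level part of this change is expected to be exercised, assuming an environment with a reachable MSSQL connection registered as "mssql_default" and an existing GCS bucket (all identifiers below are illustrative, not part of this commit):

    from airflow.providers.google.cloud.transfers.mssql_to_gcs import MSSQLToGCSOperator

    # Illustrative task definition; connection id, bucket and filename are placeholders.
    transfer = MSSQLToGCSOperator(
        task_id="mssql_to_gcs",
        mssql_conn_id="mssql_default",
        sql="SELECT a, b, c FROM my_db.my_table",
        bucket="my-bucket",
        filename="exports/file_{}.csv",
    )

    # With this commit the OpenLineage listener can collect lineage before execution via
    # the new get_openlineage_facets_on_start(); calling it directly, as below, requires
    # the MSSQL connection to be resolvable and is normally only done in tests or ad-hoc inspection.
    lineage = transfer.get_openlineage_facets_on_start()
    print(lineage.inputs, lineage.outputs, lineage.job_facets)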

diff --git a/dev/breeze/tests/test_selective_checks.py b/dev/breeze/tests/test_selective_checks.py
index db5d674e4e0..ae66b7d6d68 100644
--- a/dev/breeze/tests/test_selective_checks.py
+++ b/dev/breeze/tests/test_selective_checks.py
@@ -1752,7 +1752,7 @@ def test_expected_output_push(
                 "airflow/datasets/",
             ),
             {
-                "selected-providers-list-as-string": "amazon common.compat 
common.io common.sql dbt.cloud ftp google mysql openlineage postgres sftp 
snowflake trino",
+                "selected-providers-list-as-string": "amazon common.compat 
common.io common.sql dbt.cloud ftp google microsoft.mssql mysql openlineage 
postgres sftp snowflake trino",
                 "all-python-versions": "['3.9']",
                 "all-python-versions-list-as-string": "3.9",
                 "ci-image-build": "true",
@@ -1762,13 +1762,13 @@ def test_expected_output_push(
                 "skip-providers-tests": "false",
                 "test-groups": "['core', 'providers']",
                 "docs-build": "true",
-                "docs-list-as-string": "apache-airflow amazon common.compat 
common.io common.sql dbt.cloud ftp google mysql openlineage postgres sftp 
snowflake trino",
+                "docs-list-as-string": "apache-airflow amazon common.compat 
common.io common.sql dbt.cloud ftp google microsoft.mssql mysql openlineage 
postgres sftp snowflake trino",
                 "skip-pre-commits": 
"check-provider-yaml-valid,flynt,identity,lint-helm-chart,mypy-airflow,mypy-dev,mypy-docs,mypy-providers,mypy-task-sdk,"
                 "ts-compile-format-lint-ui,ts-compile-format-lint-www",
                 "run-kubernetes-tests": "false",
                 "upgrade-to-newer-dependencies": "false",
                 "core-test-types-list-as-string": "API Always CLI Core 
Operators Other Serialization WWW",
-                "providers-test-types-list-as-string": "Providers[amazon] 
Providers[common.compat,common.io,common.sql,dbt.cloud,ftp,mysql,openlineage,postgres,sftp,snowflake,trino]
 Providers[google]",
+                "providers-test-types-list-as-string": "Providers[amazon] 
Providers[common.compat,common.io,common.sql,dbt.cloud,ftp,microsoft.mssql,mysql,openlineage,postgres,sftp,snowflake,trino]
 Providers[google]",
                 "needs-mypy": "false",
                 "mypy-checks": "[]",
             },
diff --git a/generated/provider_dependencies.json b/generated/provider_dependencies.json
index 9cc7784ca8d..2724c6a73d4 100644
--- a/generated/provider_dependencies.json
+++ b/generated/provider_dependencies.json
@@ -860,7 +860,8 @@
     "devel-deps": [],
     "plugins": [],
     "cross-providers-deps": [
-      "common.sql"
+      "common.sql",
+      "openlineage"
     ],
     "excluded-python-versions": [],
     "state": "ready"
diff --git a/providers/src/airflow/providers/google/cloud/transfers/mssql_to_gcs.py b/providers/src/airflow/providers/google/cloud/transfers/mssql_to_gcs.py
index 6bcd9f38ff0..8f861aaee69 100644
--- a/providers/src/airflow/providers/google/cloud/transfers/mssql_to_gcs.py
+++ b/providers/src/airflow/providers/google/cloud/transfers/mssql_to_gcs.py
@@ -22,10 +22,15 @@ from __future__ import annotations
 import datetime
 import decimal
 from collections.abc import Sequence
+from functools import cached_property
+from typing import TYPE_CHECKING
 
 from airflow.providers.google.cloud.transfers.sql_to_gcs import BaseSQLToGCSOperator
 from airflow.providers.microsoft.mssql.hooks.mssql import MsSqlHook
 
+if TYPE_CHECKING:
+    from airflow.providers.openlineage.extractors import OperatorLineage
+
 
 class MSSQLToGCSOperator(BaseSQLToGCSOperator):
     """
@@ -75,14 +80,17 @@ class MSSQLToGCSOperator(BaseSQLToGCSOperator):
         self.mssql_conn_id = mssql_conn_id
         self.bit_fields = bit_fields or []
 
+    @cached_property
+    def db_hook(self) -> MsSqlHook:
+        return MsSqlHook(mssql_conn_id=self.mssql_conn_id)
+
     def query(self):
         """
         Query MSSQL and returns a cursor of results.
 
         :return: mssql cursor
         """
-        mssql = MsSqlHook(mssql_conn_id=self.mssql_conn_id)
-        conn = mssql.get_conn()
+        conn = self.db_hook.get_conn()
         cursor = conn.cursor()
         cursor.execute(self.sql)
         return cursor
@@ -109,3 +117,20 @@ class MSSQLToGCSOperator(BaseSQLToGCSOperator):
         if isinstance(value, (datetime.date, datetime.time)):
             return value.isoformat()
         return value
+
+    def get_openlineage_facets_on_start(self) -> OperatorLineage | None:
+        from airflow.providers.common.compat.openlineage.facet import SQLJobFacet
+        from airflow.providers.common.compat.openlineage.utils.sql import get_openlineage_facets_with_sql
+        from airflow.providers.openlineage.extractors import OperatorLineage
+
+        sql_parsing_result = get_openlineage_facets_with_sql(
+            hook=self.db_hook,
+            sql=self.sql,
+            conn_id=self.mssql_conn_id,
+            database=None,
+        )
+        gcs_output_datasets = self._get_openlineage_output_datasets()
+        if sql_parsing_result:
+            sql_parsing_result.outputs = gcs_output_datasets
+            return sql_parsing_result
+        return OperatorLineage(outputs=gcs_output_datasets, job_facets={"sql": SQLJobFacet(self.sql)})
diff --git a/providers/src/airflow/providers/microsoft/mssql/hooks/mssql.py b/providers/src/airflow/providers/microsoft/mssql/hooks/mssql.py
index 089f1ccfb7d..a29018ec0ff 100644
--- a/providers/src/airflow/providers/microsoft/mssql/hooks/mssql.py
+++ b/providers/src/airflow/providers/microsoft/mssql/hooks/mssql.py
@@ -29,6 +29,7 @@ from airflow.providers.microsoft.mssql.dialects.mssql import MsSqlDialect
 
 if TYPE_CHECKING:
     from airflow.providers.common.sql.dialects.dialect import Dialect
+    from airflow.providers.openlineage.sqlparser import DatabaseInfo
 
 
 class MsSqlHook(DbApiHook):
@@ -117,3 +118,30 @@ class MsSqlHook(DbApiHook):
 
     def get_autocommit(self, conn: PymssqlConnection):
         return conn.autocommit_state
+
+    def get_openlineage_database_info(self, connection) -> DatabaseInfo:
+        """Return MSSQL specific information for OpenLineage."""
+        from airflow.providers.openlineage.sqlparser import DatabaseInfo
+
+        return DatabaseInfo(
+            scheme=self.get_openlineage_database_dialect(connection),
+            authority=DbApiHook.get_openlineage_authority_part(connection, default_port=1433),
+            information_schema_columns=[
+                "table_schema",
+                "table_name",
+                "column_name",
+                "ordinal_position",
+                "data_type",
+                "table_catalog",
+            ],
+            database=self.schema or self.connection.schema,
+            is_information_schema_cross_db=True,
+        )
+
+    def get_openlineage_database_dialect(self, connection) -> str:
+        """Return database dialect."""
+        return "mssql"
+
+    def get_openlineage_default_schema(self) -> str | None:
+        """Return current schema."""
+        return self.get_first("SELECT SCHEMA_NAME();")[0]
diff --git a/providers/tests/google/cloud/transfers/test_mssql_to_gcs.py b/providers/tests/google/cloud/transfers/test_mssql_to_gcs.py
index 18ccb502127..c04cf207970 100644
--- a/providers/tests/google/cloud/transfers/test_mssql_to_gcs.py
+++ b/providers/tests/google/cloud/transfers/test_mssql_to_gcs.py
@@ -22,6 +22,12 @@ from unittest import mock
 
 import pytest
 
+from airflow.models import Connection
+from airflow.providers.common.compat.openlineage.facet import (
+    OutputDataset,
+    SchemaDatasetFacetFields,
+)
+from airflow.providers.common.sql.hooks.sql import DbApiHook
 from airflow.providers.google.cloud.transfers.mssql_to_gcs import MSSQLToGCSOperator
 
 TASK_ID = "test-mssql-to-gcs"
@@ -188,3 +194,62 @@ class TestMsSqlToGoogleCloudStorageOperator:
 
         # once for the file and once for the schema
         assert gcs_hook_mock.upload.call_count == 2
+
+    @pytest.mark.parametrize(
+        "connection_port, default_port, expected_port",
+        [(None, 4321, 4321), (1234, None, 1234), (1234, 4321, 1234)],
+    )
+    def test_execute_openlineage_events(self, connection_port, default_port, expected_port):
+        class DBApiHookForTests(DbApiHook):
+            conn_name_attr = "sql_default"
+            get_conn = mock.MagicMock(name="conn")
+            get_connection = mock.MagicMock()
+
+            def get_openlineage_database_info(self, connection):
+                from airflow.providers.openlineage.sqlparser import DatabaseInfo
+
+                return DatabaseInfo(
+                    scheme="sqlscheme",
+                    authority=DbApiHook.get_openlineage_authority_part(connection, default_port=default_port),
+                )
+
+        dbapi_hook = DBApiHookForTests()
+
+        class MSSQLToGCSOperatorForTest(MSSQLToGCSOperator):
+            @property
+            def db_hook(self):
+                return dbapi_hook
+
+        sql = """SELECT a,b,c from my_db.my_table"""
+        op = MSSQLToGCSOperatorForTest(task_id=TASK_ID, sql=sql, bucket="bucket", filename="dir/file{}.csv")
+        DB_SCHEMA_NAME = "PUBLIC"
+        rows = [
+            (DB_SCHEMA_NAME, "popular_orders_day_of_week", "order_day_of_week", 1, "varchar"),
+            (DB_SCHEMA_NAME, "popular_orders_day_of_week", "order_placed_on", 2, "timestamp"),
+            (DB_SCHEMA_NAME, "popular_orders_day_of_week", "orders_placed", 3, "int4"),
+        ]
+        dbapi_hook.get_connection.return_value = Connection(
+            conn_id="sql_default", conn_type="mssql", host="host", 
port=connection_port
+        )
+        dbapi_hook.get_conn.return_value.cursor.return_value.fetchall.side_effect = [rows, []]
+
+        lineage = op.get_openlineage_facets_on_start()
+        assert len(lineage.inputs) == 1
+        assert lineage.inputs[0].namespace == f"sqlscheme://host:{expected_port}"
+        assert lineage.inputs[0].name == "PUBLIC.popular_orders_day_of_week"
+        assert len(lineage.inputs[0].facets) == 1
+        assert lineage.inputs[0].facets["schema"].fields == [
+            SchemaDatasetFacetFields(name="order_day_of_week", type="varchar"),
+            SchemaDatasetFacetFields(name="order_placed_on", type="timestamp"),
+            SchemaDatasetFacetFields(name="orders_placed", type="int4"),
+        ]
+        assert lineage.outputs == [
+            OutputDataset(
+                namespace="gs://bucket",
+                name="dir",
+            )
+        ]
+
+        assert len(lineage.job_facets) == 1
+        assert lineage.job_facets["sql"].query == sql
+        assert lineage.run_facets == {}
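
As a complement to the test above, a minimal sketch of how the new MsSqlHook methods feed the OpenLineage SQL parser, assuming a live Airflow metadata database and a reachable MSSQL server behind a connection registered as "mssql_default" (the connection id is illustrative):

    from airflow.providers.microsoft.mssql.hooks.mssql import MsSqlHook

    hook = MsSqlHook(mssql_conn_id="mssql_default")
    conn = hook.get_connection(hook.mssql_conn_id)

    # DatabaseInfo drives dataset namespace/name resolution in the OpenLineage SQL parser:
    # scheme "mssql", authority "host:port" (1433 when no port is set on the connection),
    # database taken from the hook schema or the connection schema.
    db_info = hook.get_openlineage_database_info(conn)
    print(db_info.scheme, db_info.authority, db_info.database)

    # The default schema is resolved at runtime with SELECT SCHEMA_NAME(),
    # so this call needs the MSSQL server to be reachable.
    print(hook.get_openlineage_default_schema())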
