This is an automated email from the ASF dual-hosted git repository.

mobuchowski pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 98a999034c Add OpenLineage support for MySQL. (#31609)
98a999034c is described below

commit 98a999034cab576b001340c9274ed293dcfce2cd
Author: JDarDagran <[email protected]>
AuthorDate: Fri Jul 21 18:35:22 2023 +0200

    Add OpenLineage support for MySQL. (#31609)
    
    Signed-off-by: Jakub Dardzinski <[email protected]>
---
 airflow/providers/mysql/hooks/mysql.py        | 25 +++++++++
 generated/provider_dependencies.json          |  1 +
 tests/providers/mysql/operators/test_mysql.py | 77 +++++++++++++++++++++++++++
 3 files changed, 103 insertions(+)

diff --git a/airflow/providers/mysql/hooks/mysql.py b/airflow/providers/mysql/hooks/mysql.py
index e105d8f96a..e5023aac1a 100644
--- a/airflow/providers/mysql/hooks/mysql.py
+++ b/airflow/providers/mysql/hooks/mysql.py
@@ -298,3 +298,28 @@ class MySqlHook(DbApiHook):
         cursor.close()
         conn.commit()
         conn.close()  # type: ignore[misc]
+
+    def get_openlineage_database_info(self, connection):
+        """Returns MySQL specific information for OpenLineage."""
+        from airflow.providers.openlineage.sqlparser import DatabaseInfo
+
+        return DatabaseInfo(  # lazy import above, presumably so openlineage stays an optional dependency
+            scheme=self.get_openlineage_database_dialect(connection),
+            authority=DbApiHook.get_openlineage_authority_part(connection),  # presumably "host:port"
+            information_schema_columns=[  # presumably selected from information_schema, in this order
+                "table_schema",
+                "table_name",
+                "column_name",
+                "ordinal_position",
+                "column_type",
+            ],
+            normalize_name_method=lambda name: name.upper(),  # NOTE(review): confirm upper-casing suits MySQL
+        )
+
+    def get_openlineage_database_dialect(self, _):  # the connection argument is unused
+        """Returns database dialect."""
+        return "mysql"  # fixed dialect name, independent of the connection
+
+    def get_openlineage_default_schema(self):
+        """MySQL has no concept of schema."""
+        return None  # NOTE(review): in MySQL "schema" is a synonym for database; confirm None is intended
diff --git a/generated/provider_dependencies.json b/generated/provider_dependencies.json
index abf1088cd4..7a160c96dc 100644
--- a/generated/provider_dependencies.json
+++ b/generated/provider_dependencies.json
@@ -602,6 +602,7 @@
     "cross-providers-deps": [
       "amazon",
       "common.sql",
+      "openlineage",
       "presto",
       "trino",
       "vertica"
diff --git a/tests/providers/mysql/operators/test_mysql.py b/tests/providers/mysql/operators/test_mysql.py
index 379459985a..9202767f9d 100644
--- a/tests/providers/mysql/operators/test_mysql.py
+++ b/tests/providers/mysql/operators/test_mysql.py
@@ -20,9 +20,13 @@ from __future__ import annotations
 import os
 from contextlib import closing
 from tempfile import NamedTemporaryFile
+from unittest.mock import MagicMock
 
 import pytest
+from openlineage.client.facet import SchemaDatasetFacet, SchemaField, 
SqlJobFacet
+from openlineage.client.run import Dataset
 
+from airflow.models.connection import Connection
 from airflow.models.dag import DAG
 from airflow.providers.mysql.hooks.mysql import MySqlHook
 from airflow.providers.mysql.operators.mysql import MySqlOperator
@@ -111,3 +115,76 @@ class TestMySql:
 
         assert isinstance(task.parameters, dict)
         assert task.parameters["foo"] == "{{ ds }}"
+
+    @pytest.mark.parametrize("client", ["mysqlclient", 
"mysql-connector-python"])
+    def test_mysql_operator_openlineage(self, client):
+        with MySqlContext(client):
+            sql = """
+            CREATE TABLE IF NOT EXISTS test_airflow (
+                dummy VARCHAR(50)
+            );
+            """
+            op = MySqlOperator(task_id="basic_mysql", sql=sql, dag=self.dag)
+
+            lineage = op.get_openlineage_facets_on_start()
+            assert len(lineage.inputs) == 0
+            assert len(lineage.outputs) == 0
+            op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, 
ignore_ti_state=True)
+
+            # OpenLineage provider runs same method on complete by default
+            lineage_on_complete = op.get_openlineage_facets_on_start()
+            assert len(lineage_on_complete.inputs) == 0
+            assert len(lineage_on_complete.outputs) == 1
+
+
+def test_execute_openlineage_events():
+    class MySqlHookForTests(MySqlHook):
+        conn_name_attr = "sql_default"
+        get_conn = MagicMock(name="conn")
+        get_connection = MagicMock()
+
+    dbapi_hook = MySqlHookForTests()
+
+    class SQLExecuteQueryOperatorForTest(MySqlOperator):
+        def get_db_hook(self):
+            return dbapi_hook
+
+    sql = """CREATE TABLE IF NOT EXISTS popular_orders_day_of_week (
+        order_day_of_week VARCHAR(64) NOT NULL,
+        order_placed_on   TIMESTAMP NOT NULL,
+        orders_placed     INTEGER NOT NULL
+    );
+FORGOT TO COMMENT"""
+    op = SQLExecuteQueryOperatorForTest(task_id="mysql-operator", sql=sql)
+    DB_SCHEMA_NAME = "PUBLIC"
+    rows = [
+        (DB_SCHEMA_NAME, "popular_orders_day_of_week", "order_day_of_week", 1, 
"varchar"),
+        (DB_SCHEMA_NAME, "popular_orders_day_of_week", "order_placed_on", 2, 
"timestamp"),
+        (DB_SCHEMA_NAME, "popular_orders_day_of_week", "orders_placed", 3, 
"int4"),
+    ]
+    dbapi_hook.get_connection.return_value = Connection(
+        conn_id="mysql_default", conn_type="mysql", host="host", port=1234
+    )
+    dbapi_hook.get_conn.return_value.cursor.return_value.fetchall.side_effect 
= [rows, []]
+
+    lineage = op.get_openlineage_facets_on_start()
+    assert len(lineage.inputs) == 0
+    assert lineage.outputs == [
+        Dataset(
+            namespace="mysql://host:1234",
+            name="PUBLIC.popular_orders_day_of_week",
+            facets={
+                "schema": SchemaDatasetFacet(
+                    fields=[
+                        SchemaField(name="order_day_of_week", type="varchar"),
+                        SchemaField(name="order_placed_on", type="timestamp"),
+                        SchemaField(name="orders_placed", type="int4"),
+                    ]
+                )
+            },
+        )
+    ]
+
+    assert lineage.job_facets == {"sql": SqlJobFacet(query=sql)}
+
+    assert lineage.run_facets["extractionError"].failedTasks == 1

Reply via email to