This is an automated email from the ASF dual-hosted git repository.

mobuchowski pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new ce2841bf6a Add default port to Openlineage authority method. (#32828)
ce2841bf6a is described below

commit ce2841bf6ab609f31cb04aea9a39473de281bf24
Author: JDarDagran <[email protected]>
AuthorDate: Tue Jul 25 12:44:22 2023 +0200

    Add default port to Openlineage authority method. (#32828)
    
    Signed-off-by: Jakub Dardzinski <[email protected]>
---
 airflow/providers/common/sql/hooks/sql.py                | 10 ++++++++--
 airflow/providers/mysql/hooks/mysql.py                   |  2 +-
 tests/providers/common/sql/operators/test_sql_execute.py | 13 +++++++++----
 tests/providers/mysql/operators/test_mysql.py            |  7 ++++---
 4 files changed, 22 insertions(+), 10 deletions(-)

diff --git a/airflow/providers/common/sql/hooks/sql.py 
b/airflow/providers/common/sql/hooks/sql.py
index d444a95287..ade3c6a95d 100644
--- a/airflow/providers/common/sql/hooks/sql.py
+++ b/airflow/providers/common/sql/hooks/sql.py
@@ -598,13 +598,19 @@ class DbApiHook(BaseForDbApiHook):
         """
 
     @staticmethod
-    def get_openlineage_authority_part(connection) -> str:
+    def get_openlineage_authority_part(connection, default_port: int | None = None) -> str:
         """
         This method serves as common method for several hooks to get authority 
part from Airflow Connection.
 
         The authority represents the hostname and port of the connection
         and conforms OpenLineage naming convention for a number of databases 
(e.g. MySQL, Postgres, Trino).
+
+        :param default_port: (optional) used if no port parsed from connection 
URI
         """
         parsed = urlparse(connection.get_uri())
-        authority = f"{parsed.hostname}:{parsed.port}"
+        port = parsed.port or default_port
+        if port:
+            authority = f"{parsed.hostname}:{port}"
+        else:
+            authority = parsed.hostname
         return authority
diff --git a/airflow/providers/mysql/hooks/mysql.py 
b/airflow/providers/mysql/hooks/mysql.py
index e5023aac1a..fa011ed35b 100644
--- a/airflow/providers/mysql/hooks/mysql.py
+++ b/airflow/providers/mysql/hooks/mysql.py
@@ -305,7 +305,7 @@ class MySqlHook(DbApiHook):
 
         return DatabaseInfo(
             scheme=self.get_openlineage_database_dialect(connection),
-            authority=DbApiHook.get_openlineage_authority_part(connection),
+            authority=DbApiHook.get_openlineage_authority_part(connection, default_port=3306),
             information_schema_columns=[
                 "table_schema",
                 "table_name",
diff --git a/tests/providers/common/sql/operators/test_sql_execute.py 
b/tests/providers/common/sql/operators/test_sql_execute.py
index ddd1372485..5dc2e6e30c 100644
--- a/tests/providers/common/sql/operators/test_sql_execute.py
+++ b/tests/providers/common/sql/operators/test_sql_execute.py
@@ -281,7 +281,11 @@ def test_exec_success_with_process_output(
     )
 
 
-def test_execute_openlineage_events():
+@pytest.mark.parametrize(
+    "connection_port, default_port, expected_port",
+    [(None, 4321, 4321), (1234, None, 1234), (1234, 4321, 1234)],
+)
+def test_execute_openlineage_events(connection_port, default_port, expected_port):
     class DBApiHookForTests(DbApiHook):
         conn_name_attr = "sql_default"
         get_conn = MagicMock(name="conn")
@@ -291,7 +295,8 @@ def test_execute_openlineage_events():
             from airflow.providers.openlineage.sqlparser import DatabaseInfo
 
             return DatabaseInfo(
-                scheme="sqlscheme", 
authority=DbApiHook.get_openlineage_authority_part(connection)
+                scheme="sqlscheme",
+                authority=DbApiHook.get_openlineage_authority_part(connection, 
default_port=default_port),
             )
 
         def get_openlineage_database_specific_lineage(self, task_instance):
@@ -317,7 +322,7 @@ FORGOT TO COMMENT"""
         (DB_SCHEMA_NAME, "popular_orders_day_of_week", "orders_placed", 3, 
"int4"),
     ]
     dbapi_hook.get_connection.return_value = Connection(
-        conn_id="sql_default", conn_type="postgresql", host="host", port=1234
+        conn_id="sql_default", conn_type="postgresql", host="host", 
port=connection_port
     )
     dbapi_hook.get_conn.return_value.cursor.return_value.fetchall.side_effect 
= [rows, []]
 
@@ -325,7 +330,7 @@ FORGOT TO COMMENT"""
     assert len(lineage.inputs) == 0
     assert lineage.outputs == [
         Dataset(
-            namespace="sqlscheme://host:1234",
+            namespace=f"sqlscheme://host:{expected_port}",
             name="PUBLIC.popular_orders_day_of_week",
             facets={
                 "schema": SchemaDatasetFacet(
diff --git a/tests/providers/mysql/operators/test_mysql.py 
b/tests/providers/mysql/operators/test_mysql.py
index 9202767f9d..44aed62917 100644
--- a/tests/providers/mysql/operators/test_mysql.py
+++ b/tests/providers/mysql/operators/test_mysql.py
@@ -137,7 +137,8 @@ class TestMySql:
             assert len(lineage_on_complete.outputs) == 1
 
 
-def test_execute_openlineage_events():
+@pytest.mark.parametrize("connection_port", [None, 1234])
+def test_execute_openlineage_events(connection_port):
     class MySqlHookForTests(MySqlHook):
         conn_name_attr = "sql_default"
         get_conn = MagicMock(name="conn")
@@ -163,7 +164,7 @@ FORGOT TO COMMENT"""
         (DB_SCHEMA_NAME, "popular_orders_day_of_week", "orders_placed", 3, 
"int4"),
     ]
     dbapi_hook.get_connection.return_value = Connection(
-        conn_id="mysql_default", conn_type="mysql", host="host", port=1234
+        conn_id="mysql_default", conn_type="mysql", host="host", 
port=connection_port
     )
     dbapi_hook.get_conn.return_value.cursor.return_value.fetchall.side_effect 
= [rows, []]
 
@@ -171,7 +172,7 @@ FORGOT TO COMMENT"""
     assert len(lineage.inputs) == 0
     assert lineage.outputs == [
         Dataset(
-            namespace="mysql://host:1234",
+            namespace=f"mysql://host:{connection_port or 3306}",
             name="PUBLIC.popular_orders_day_of_week",
             facets={
                 "schema": SchemaDatasetFacet(

Reply via email to