This is an automated email from the ASF dual-hosted git repository.

mobuchowski pushed a commit to branch add-openlineage-config-spark-submit
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit d9ef1b0d2cfb61850b26c7e9367eefb488c9cd87
Author: Maciej Obuchowski <[email protected]>
AuthorDate: Thu Mar 6 13:00:28 2025 +0100

    serialize http transports contained in composite transport
    
    Signed-off-by: Maciej Obuchowski <[email protected]>
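
    For illustration, a minimal sketch based on the test fixtures in this
    patch (names and values mirror the tests; the console transport is
    hypothetical here): a composite transport is serialized by flattening
    each nested http transport into
    spark.openlineage.transport.transports.<name>.<key> properties, while
    non-http nested transports are skipped with a log message.

        from openlineage.client.transport.composite import CompositeConfig, CompositeTransport

        # A composite config holding one http and one console transport;
        # only the http transport can be injected into Spark properties.
        transport = CompositeTransport(
            CompositeConfig.from_dict(
                {
                    "transports": {
                        "http": {"type": "http", "url": "https://some-custom.url", "endpoint": "/api/custom", "timeout": 123},
                        "console": {"type": "console"},
                    }
                }
            )
        )

        # _get_transport_information_as_spark_properties() then yields roughly:
        # {
        #     "spark.openlineage.transport.type": "composite",
        #     "spark.openlineage.transport.continueOnFailure": "True",
        #     "spark.openlineage.transport.transports.http.type": "http",
        #     "spark.openlineage.transport.transports.http.url": "https://some-custom.url",
        #     "spark.openlineage.transport.transports.http.endpoint": "/api/custom",
        #     "spark.openlineage.transport.transports.http.timeoutInMillis": "123000",
        # }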
---
 .../common/compat/openlineage/utils/spark.py       | 98 ++++++++++++++++------
 .../unit/google/cloud/openlineage/test_utils.py    | 10 ++-
 .../unit/google/cloud/operators/test_dataproc.py   | 46 ++++++++--
 .../providers/openlineage/plugins/macros.py        | 19 ++++-
 .../airflow/providers/openlineage/utils/spark.py   | 90 ++++++++++++++------
 .../tests/unit/openlineage/plugins/test_macros.py  | 34 +++++++-
 .../tests/unit/openlineage/utils/test_spark.py     | 89 +++++++++++++++++++-
 7 files changed, 315 insertions(+), 71 deletions(-)

diff --git a/providers/common/compat/src/airflow/providers/common/compat/openlineage/utils/spark.py b/providers/common/compat/src/airflow/providers/common/compat/openlineage/utils/spark.py
index cbc7997dd44..c6a4313300f 100644
--- a/providers/common/compat/src/airflow/providers/common/compat/openlineage/utils/spark.py
+++ b/providers/common/compat/src/airflow/providers/common/compat/openlineage/utils/spark.py
@@ -90,8 +90,78 @@ else:
                     )
                     return properties
 
-                transport = get_openlineage_listener().adapter.get_or_create_openlineage_client().transport
-                if transport.kind != "http":
+                def _get_transport_information_as_spark_properties() -> dict:
+                    """Retrieve transport information as Spark properties."""
+
+                    def _get_transport_information(tp) -> dict:
+                        props = {
+                            "type": tp.kind,
+                            "url": tp.url,
+                            "endpoint": tp.endpoint,
+                            "timeoutInMillis": str(
+                                int(tp.timeout * 1000)
+                                # convert to milliseconds, as required by Spark integration
+                            ),
+                        }
+                        if hasattr(tp, "compression") and tp.compression:
+                            props["compression"] = str(tp.compression)
+
+                        if hasattr(tp.config.auth, "api_key") and tp.config.auth.get_bearer():
+                            props["auth.type"] = "api_key"
+                            props["auth.apiKey"] = tp.config.auth.get_bearer()
+
+                        if hasattr(tp.config, "custom_headers") and tp.config.custom_headers:
+                            for key, value in tp.config.custom_headers.items():
+                                props[f"headers.{key}"] = value
+                        return props
+
+                    def _format_transport(props: dict, transport: dict, name: str | None):
+                        for key, value in transport.items():
+                            if name:
+                                props[f"spark.openlineage.transport.transports.{name}.{key}"] = value
+                            else:
+                                props[f"spark.openlineage.transport.{key}"] = value
+                        return props
+
+                    transport = (
+                        get_openlineage_listener().adapter.get_or_create_openlineage_client().transport
+                    )
+
+                    if transport.kind == "composite":
+                        http_transports = {}
+                        for nested_transport in transport.transports:
+                            if nested_transport.kind == "http":
+                                http_transports[nested_transport.name] = _get_transport_information(
+                                    nested_transport
+                                )
+                            else:
+                                name = (
+                                    nested_transport.name if hasattr(nested_transport, "name") else "no-name"
+                                )
+                                log.info(
+                                    "OpenLineage transport type `%s` with name `%s` is not supported in composite transport.",
+                                    nested_transport.kind,
+                                    name,
+                                )
+                        if len(http_transports) == 0:
+                            log.warning(
+                                "OpenLineage transport type `composite` does not contain http transport. Skipping "
+                                "injection of OpenLineage transport information into Spark properties.",
+                            )
+                            return {}
+                        props = {
+                            "spark.openlineage.transport.type": "composite",
+                            "spark.openlineage.transport.continueOnFailure": str(
+                                transport.config.continue_on_failure
+                            ),
+                        }
+                        for name, http_transport in http_transports.items():
+                            props = _format_transport(props, http_transport, name)
+                        return props
+
+                    elif transport.kind == "http":
+                        return _format_transport({}, _get_transport_information(transport), None)
+
                     log.info(
                         "OpenLineage transport type `%s` does not support 
automatic "
                         "injection of OpenLineage transport information into 
Spark properties.",
@@ -99,29 +169,7 @@ else:
                     )
                     return {}
 
-                transport_properties = {
-                    "spark.openlineage.transport.type": "http",
-                    "spark.openlineage.transport.url": transport.url,
-                    "spark.openlineage.transport.endpoint": transport.endpoint,
-                    # Timeout is converted to milliseconds, as required by Spark integration,
-                    "spark.openlineage.transport.timeoutInMillis": str(int(transport.timeout * 1000)),
-                }
-                if transport.compression:
-                    transport_properties["spark.openlineage.transport.compression"] = str(
-                        transport.compression
-                    )
-
-                if hasattr(transport.config.auth, "api_key") and transport.config.auth.get_bearer():
-                    transport_properties["spark.openlineage.transport.auth.type"] = "api_key"
-                    transport_properties["spark.openlineage.transport.auth.apiKey"] = (
-                        transport.config.auth.get_bearer()
-                    )
-
-                if hasattr(transport.config, "custom_headers") and transport.config.custom_headers:
-                    for key, value in transport.config.custom_headers.items():
-                        transport_properties[f"spark.openlineage.transport.headers.{key}"] = value
-
-                return {**properties, **transport_properties}
+                return {**properties, **_get_transport_information_as_spark_properties()}
 
 
 __all__ = [
diff --git a/providers/google/tests/unit/google/cloud/openlineage/test_utils.py b/providers/google/tests/unit/google/cloud/openlineage/test_utils.py
index 4be23910335..21317983d62 100644
--- a/providers/google/tests/unit/google/cloud/openlineage/test_utils.py
+++ b/providers/google/tests/unit/google/cloud/openlineage/test_utils.py
@@ -106,6 +106,7 @@ EXAMPLE_TEMPLATE = {
 EXAMPLE_CONTEXT = {
     "ti": MagicMock(
         dag_id="dag_id",
+        dag_run=MagicMock(run_after=dt.datetime(2024, 11, 11), logical_date=dt.datetime(2024, 11, 11)),
         task_id="task_id",
         try_number=1,
         map_index=1,
@@ -574,7 +575,8 @@ def test_replace_dataproc_job_properties_key_error():
 def test_inject_openlineage_properties_into_dataproc_job_provider_not_accessible(mock_is_accessible):
     mock_is_accessible.return_value = False
     job = {"sparkJob": {"properties": {"existingProperty": "value"}}}
-    result = inject_openlineage_properties_into_dataproc_job(job, None, True, True)
+
+    result = inject_openlineage_properties_into_dataproc_job(job, EXAMPLE_CONTEXT, True, True)
     assert result == job
 
 
@@ -586,7 +588,7 @@ def test_inject_openlineage_properties_into_dataproc_job_unsupported_job_type(
     mock_is_accessible.return_value = True
     mock_extract_job_type.return_value = None
     job = {"unsupportedJob": {"properties": {"existingProperty": "value"}}}
-    result = inject_openlineage_properties_into_dataproc_job(job, None, True, True)
+    result = inject_openlineage_properties_into_dataproc_job(job, EXAMPLE_CONTEXT, True, True)
     assert result == job
 
 
@@ -599,7 +601,9 @@ def test_inject_openlineage_properties_into_dataproc_job_no_injection(
     mock_extract_job_type.return_value = "sparkJob"
     inject_parent_job_info = False
     job = {"sparkJob": {"properties": {"existingProperty": "value"}}}
-    result = inject_openlineage_properties_into_dataproc_job(job, None, inject_parent_job_info, False)
+    result = inject_openlineage_properties_into_dataproc_job(
+        job, EXAMPLE_CONTEXT, inject_parent_job_info, False
+    )
     assert result == job
 
 
diff --git a/providers/google/tests/unit/google/cloud/operators/test_dataproc.py b/providers/google/tests/unit/google/cloud/operators/test_dataproc.py
index 73f3e6ce848..aa7712d7d86 100644
--- a/providers/google/tests/unit/google/cloud/operators/test_dataproc.py
+++ b/providers/google/tests/unit/google/cloud/operators/test_dataproc.py
@@ -1559,9 +1559,13 @@ class TestDataprocSubmitJobOperator(DataprocJobTestBase):
         op.execute(context=self.mock_context)
         assert not mock_defer.called
 
+    @mock.patch("airflow.providers.openlineage.plugins.adapter.generate_static_uuid")
     @mock.patch("airflow.providers.google.cloud.openlineage.utils._is_openlineage_provider_accessible")
     @mock.patch(DATAPROC_PATH.format("DataprocHook"))
-    def test_execute_openlineage_parent_job_info_injection(self, mock_hook, mock_ol_accessible):
+    def test_execute_openlineage_parent_job_info_injection(
+        self, mock_hook, mock_ol_accessible, mock_static_uuid
+    ):
+        mock_static_uuid.return_value = "01931885-2800-7be7-aa8d-aaa15c337267"
         job_config = {
             "placement": {"cluster_name": CLUSTER_NAME},
             "pyspark_job": {
@@ -1620,13 +1624,15 @@ class TestDataprocSubmitJobOperator(DataprocJobTestBase):
             metadata=METADATA,
         )
 
+    @mock.patch("airflow.providers.openlineage.plugins.adapter.generate_static_uuid")
     @mock.patch("airflow.providers.openlineage.plugins.listener._openlineage_listener")
     @mock.patch("airflow.providers.google.cloud.openlineage.utils._is_openlineage_provider_accessible")
     @mock.patch(DATAPROC_PATH.format("DataprocHook"))
     def test_execute_openlineage_http_transport_info_injection(
-        self, mock_hook, mock_ol_accessible, mock_ol_listener
+        self, mock_hook, mock_ol_accessible, mock_ol_listener, mock_static_uuid
     ):
         mock_ol_accessible.return_value = True
+        mock_static_uuid.return_value = "01931885-2800-7be7-aa8d-aaa15c337267"
         mock_ol_listener.adapter.get_or_create_openlineage_client.return_value.transport = HttpTransport(
             HttpConfig.from_dict(OPENLINEAGE_HTTP_TRANSPORT_EXAMPLE_CONFIG)
         )
@@ -1673,11 +1679,15 @@ class TestDataprocSubmitJobOperator(DataprocJobTestBase):
             metadata=METADATA,
         )
 
+    @mock.patch("airflow.providers.openlineage.plugins.adapter.generate_static_uuid")
     @mock.patch("airflow.providers.openlineage.plugins.listener._openlineage_listener")
     @mock.patch("airflow.providers.google.cloud.openlineage.utils._is_openlineage_provider_accessible")
     @mock.patch(DATAPROC_PATH.format("DataprocHook"))
-    def test_execute_openlineage_all_info_injection(self, mock_hook, mock_ol_accessible, mock_ol_listener):
+    def test_execute_openlineage_all_info_injection(
+        self, mock_hook, mock_ol_accessible, mock_ol_listener, mock_static_uuid
+    ):
         mock_ol_accessible.return_value = True
+        mock_static_uuid.return_value = "01931885-2800-7be7-aa8d-aaa15c337267"
         mock_ol_listener.adapter.get_or_create_openlineage_client.return_value.transport = HttpTransport(
             HttpConfig.from_dict(OPENLINEAGE_HTTP_TRANSPORT_EXAMPLE_CONFIG)
         )
@@ -2705,10 +2715,14 @@ class TestDataprocWorkflowTemplateInstantiateInlineOperator:
         )
         mock_op.return_value.result.assert_not_called()
 
+    @mock.patch("airflow.providers.openlineage.plugins.adapter.generate_static_uuid")
     @mock.patch("airflow.providers.google.cloud.openlineage.utils._is_openlineage_provider_accessible")
     @mock.patch(DATAPROC_PATH.format("DataprocHook"))
-    def test_execute_openlineage_parent_job_info_injection(self, mock_hook, mock_ol_accessible):
+    def test_execute_openlineage_parent_job_info_injection(
+        self, mock_hook, mock_ol_accessible, mock_static_uuid
+    ):
         mock_ol_accessible.return_value = True
+        mock_static_uuid.return_value = "01931885-2800-7be7-aa8d-aaa15c337267"
         template = {
             **WORKFLOW_TEMPLATE,
             "jobs": [
@@ -2891,13 +2905,15 @@ class TestDataprocWorkflowTemplateInstantiateInlineOperator:
             metadata=METADATA,
         )
 
+    @mock.patch("airflow.providers.openlineage.plugins.adapter.generate_static_uuid")
     @mock.patch("airflow.providers.openlineage.plugins.listener._openlineage_listener")
     @mock.patch("airflow.providers.google.cloud.openlineage.utils._is_openlineage_provider_accessible")
     @mock.patch(DATAPROC_PATH.format("DataprocHook"))
     def test_execute_openlineage_transport_info_injection(
-        self, mock_hook, mock_ol_accessible, mock_ol_listener
+        self, mock_hook, mock_ol_accessible, mock_ol_listener, mock_static_uuid
     ):
         mock_ol_accessible.return_value = True
+        mock_static_uuid.return_value = "01931885-2800-7be7-aa8d-aaa15c337267"
         mock_ol_listener.adapter.get_or_create_openlineage_client.return_value.transport = HttpTransport(
             HttpConfig.from_dict(OPENLINEAGE_HTTP_TRANSPORT_EXAMPLE_CONFIG)
         )
@@ -2995,11 +3011,15 @@ class TestDataprocWorkflowTemplateInstantiateInlineOperator:
             metadata=METADATA,
         )
 
+    @mock.patch("airflow.providers.openlineage.plugins.adapter.generate_static_uuid")
     @mock.patch("airflow.providers.openlineage.plugins.listener._openlineage_listener")
     @mock.patch("airflow.providers.google.cloud.openlineage.utils._is_openlineage_provider_accessible")
     @mock.patch(DATAPROC_PATH.format("DataprocHook"))
-    def test_execute_openlineage_all_info_injection(self, mock_hook, mock_ol_accessible, mock_ol_listener):
+    def test_execute_openlineage_all_info_injection(
+        self, mock_hook, mock_ol_accessible, mock_ol_listener, mock_static_uuid
+    ):
         mock_ol_accessible.return_value = True
+        mock_static_uuid.return_value = "01931885-2800-7be7-aa8d-aaa15c337267"
         mock_ol_listener.adapter.get_or_create_openlineage_client.return_value.transport = HttpTransport(
             HttpConfig.from_dict(OPENLINEAGE_HTTP_TRANSPORT_EXAMPLE_CONFIG)
         )
@@ -3467,11 +3487,15 @@ class TestDataprocCreateBatchOperator:
             metadata=METADATA,
         )
 
+    @mock.patch("airflow.providers.openlineage.plugins.adapter.generate_static_uuid")
     @mock.patch("airflow.providers.google.cloud.openlineage.utils._is_openlineage_provider_accessible")
     @mock.patch(DATAPROC_PATH.format("Batch.to_dict"))
     @mock.patch(DATAPROC_PATH.format("DataprocHook"))
-    def test_execute_openlineage_parent_job_info_injection(self, mock_hook, to_dict_mock, mock_ol_accessible):
+    def test_execute_openlineage_parent_job_info_injection(
+        self, mock_hook, to_dict_mock, mock_ol_accessible, mock_static_uuid
+    ):
         mock_ol_accessible.return_value = True
+        mock_static_uuid.return_value = "01931885-2800-7be7-aa8d-aaa15c337267"
         expected_batch = {
             **BATCH,
             "runtime_config": {"properties": 
OPENLINEAGE_PARENT_JOB_EXAMPLE_SPARK_PROPERTIES},
@@ -3504,14 +3528,16 @@ class TestDataprocCreateBatchOperator:
             metadata=METADATA,
         )
 
+    @mock.patch("airflow.providers.openlineage.plugins.adapter.generate_static_uuid")
     @mock.patch("airflow.providers.openlineage.plugins.listener._openlineage_listener")
     @mock.patch("airflow.providers.google.cloud.openlineage.utils._is_openlineage_provider_accessible")
     @mock.patch(DATAPROC_PATH.format("Batch.to_dict"))
     @mock.patch(DATAPROC_PATH.format("DataprocHook"))
     def test_execute_openlineage_transport_info_injection(
-        self, mock_hook, to_dict_mock, mock_ol_accessible, mock_ol_listener
+        self, mock_hook, to_dict_mock, mock_ol_accessible, mock_ol_listener, mock_static_uuid
     ):
         mock_ol_accessible.return_value = True
+        mock_static_uuid.return_value = "01931885-2800-7be7-aa8d-aaa15c337267"
         mock_ol_listener.adapter.get_or_create_openlineage_client.return_value.transport = HttpTransport(
             HttpConfig.from_dict(OPENLINEAGE_HTTP_TRANSPORT_EXAMPLE_CONFIG)
         )
@@ -3547,14 +3573,16 @@ class TestDataprocCreateBatchOperator:
             metadata=METADATA,
         )
 
+    @mock.patch("airflow.providers.openlineage.plugins.adapter.generate_static_uuid")
     @mock.patch("airflow.providers.openlineage.plugins.listener._openlineage_listener")
     @mock.patch("airflow.providers.google.cloud.openlineage.utils._is_openlineage_provider_accessible")
     @mock.patch(DATAPROC_PATH.format("Batch.to_dict"))
     @mock.patch(DATAPROC_PATH.format("DataprocHook"))
     def test_execute_openlineage_all_info_injection(
-        self, mock_hook, to_dict_mock, mock_ol_accessible, mock_ol_listener
+        self, mock_hook, to_dict_mock, mock_ol_accessible, mock_ol_listener, mock_static_uuid
     ):
         mock_ol_accessible.return_value = True
+        mock_static_uuid.return_value = "01931885-2800-7be7-aa8d-aaa15c337267"
         mock_ol_listener.adapter.get_or_create_openlineage_client.return_value.transport = HttpTransport(
             HttpConfig.from_dict(OPENLINEAGE_HTTP_TRANSPORT_EXAMPLE_CONFIG)
         )
diff --git a/providers/openlineage/src/airflow/providers/openlineage/plugins/macros.py b/providers/openlineage/src/airflow/providers/openlineage/plugins/macros.py
index fd5194d3f32..ef5b9c0ad64 100644
--- a/providers/openlineage/src/airflow/providers/openlineage/plugins/macros.py
+++ b/providers/openlineage/src/airflow/providers/openlineage/plugins/macros.py
@@ -21,6 +21,7 @@ from typing import TYPE_CHECKING
 from airflow.providers.openlineage import conf
 from airflow.providers.openlineage.plugins.adapter import OpenLineageAdapter
 from airflow.providers.openlineage.utils.utils import get_job_name
+from airflow.providers.openlineage.version_compat import AIRFLOW_V_3_0_PLUS
 
 if TYPE_CHECKING:
     from airflow.models import TaskInstance
@@ -58,15 +59,25 @@ def lineage_run_id(task_instance: TaskInstance):
         For more information take a look at the guide:
         :ref:`howto/macros:openlineage`
     """
-    if hasattr(task_instance, "logical_date"):
-        logical_date = task_instance.logical_date
+    if AIRFLOW_V_3_0_PLUS:
+        context = task_instance.get_template_context()
+        if hasattr(task_instance, "dag_run"):
+            dag_run = task_instance.dag_run
+        elif hasattr(context, "dag_run"):
+            dag_run = context["dag_run"]
+        if hasattr(dag_run, "logical_date") and dag_run.logical_date:
+            date = dag_run.logical_date
+        else:
+            date = dag_run.run_after
+    elif hasattr(task_instance, "logical_date"):
+        date = task_instance.logical_date
     else:
-        logical_date = task_instance.execution_date
+        date = task_instance.execution_date
     return OpenLineageAdapter.build_task_instance_run_id(
         dag_id=task_instance.dag_id,
         task_id=task_instance.task_id,
         try_number=task_instance.try_number,
-        logical_date=logical_date,
+        logical_date=date,
         map_index=task_instance.map_index,
     )
 
diff --git a/providers/openlineage/src/airflow/providers/openlineage/utils/spark.py b/providers/openlineage/src/airflow/providers/openlineage/utils/spark.py
index c2991e65180..9f0fef84be0 100644
--- a/providers/openlineage/src/airflow/providers/openlineage/utils/spark.py
+++ b/providers/openlineage/src/airflow/providers/openlineage/utils/spark.py
@@ -53,35 +53,73 @@ def _get_parent_job_information_as_spark_properties(context: Context) -> dict:
 
 def _get_transport_information_as_spark_properties() -> dict:
     """Retrieve transport information as Spark properties."""
-    transport = get_openlineage_listener().adapter.get_or_create_openlineage_client().transport
-    if transport.kind != "http":
-        log.info(
-            "OpenLineage transport type `%s` does not support automatic "
-            "injection of OpenLineage transport information into Spark properties.",
-            transport.kind,
-        )
-        return {}
-
-    properties = {
-        "spark.openlineage.transport.type": transport.kind,
-        "spark.openlineage.transport.url": transport.url,
-        "spark.openlineage.transport.endpoint": transport.endpoint,
-        "spark.openlineage.transport.timeoutInMillis": str(
-            int(transport.timeout * 1000)  # convert to milliseconds, as required by Spark integration
-        ),
-    }
-    if transport.compression:
-        properties["spark.openlineage.transport.compression"] = str(transport.compression)
 
-    if hasattr(transport.config.auth, "api_key") and transport.config.auth.get_bearer():
-        properties["spark.openlineage.transport.auth.type"] = "api_key"
-        properties["spark.openlineage.transport.auth.apiKey"] = transport.config.auth.get_bearer()
+    def _get_transport_information(tp) -> dict:
+        properties = {
+            "type": tp.kind,
+            "url": tp.url,
+            "endpoint": tp.endpoint,
+            "timeoutInMillis": str(
+                int(tp.timeout * 1000)  # convert to milliseconds, as required by Spark integration
+            ),
+        }
+        if hasattr(tp, "compression") and tp.compression:
+            properties["compression"] = str(tp.compression)
+
+        if hasattr(tp.config.auth, "api_key") and tp.config.auth.get_bearer():
+            properties["auth.type"] = "api_key"
+            properties["auth.apiKey"] = tp.config.auth.get_bearer()
+
+        if hasattr(tp.config, "custom_headers") and tp.config.custom_headers:
+            for key, value in tp.config.custom_headers.items():
+                properties[f"headers.{key}"] = value
+        return properties
+
+    def _format_transport(props: dict, transport: dict, name: str | None):
+        for key, value in transport.items():
+            if name:
+                props[f"spark.openlineage.transport.transports.{name}.{key}"] = value
+            else:
+                props[f"spark.openlineage.transport.{key}"] = value
+        return props
 
-    if hasattr(transport.config, "custom_headers") and transport.config.custom_headers:
-        for key, value in transport.config.custom_headers.items():
-            properties[f"spark.openlineage.transport.headers.{key}"] = value
+    transport = get_openlineage_listener().adapter.get_or_create_openlineage_client().transport
 
-    return properties
+    if transport.kind == "composite":
+        http_transports = {}
+        for nested_transport in transport.transports:
+            if nested_transport.kind == "http":
+                http_transports[nested_transport.name] = _get_transport_information(nested_transport)
+            else:
+                name = nested_transport.name if hasattr(nested_transport, "name") else "no-name"
+                log.info(
+                    "OpenLineage transport type `%s` with name `%s` is not supported in composite transport.",
+                    nested_transport.kind,
+                    name,
+                )
+        if len(http_transports) == 0:
+            log.warning(
+                "OpenLineage transport type `composite` does not contain http transport. Skipping "
+                "injection of OpenLineage transport information into Spark properties.",
+            )
+            return {}
+        props = {
+            "spark.openlineage.transport.type": "composite",
+            "spark.openlineage.transport.continueOnFailure": str(transport.config.continue_on_failure),
+        }
+        for name, http_transport in http_transports.items():
+            props = _format_transport(props, http_transport, name)
+        return props
+
+    elif transport.kind == "http":
+        return _format_transport({}, _get_transport_information(transport), None)
+
+    log.info(
+        "OpenLineage transport type `%s` does not support automatic "
+        "injection of OpenLineage transport information into Spark properties.",
+        transport.kind,
+    )
+    return {}
 
 
 def _is_parent_job_information_present_in_spark_properties(properties: dict) -> bool:
diff --git a/providers/openlineage/tests/unit/openlineage/plugins/test_macros.py b/providers/openlineage/tests/unit/openlineage/plugins/test_macros.py
index 28eb8c74cd9..dbc4455fb2b 100644
--- a/providers/openlineage/tests/unit/openlineage/plugins/test_macros.py
+++ b/providers/openlineage/tests/unit/openlineage/plugins/test_macros.py
@@ -19,6 +19,8 @@ from __future__ import annotations
 from datetime import datetime, timezone
 from unittest import mock
 
+import pytest
+
 from airflow import __version__
 from airflow.providers.openlineage.conf import namespace
 from airflow.providers.openlineage.plugins.macros import (
@@ -28,6 +30,8 @@ from airflow.providers.openlineage.plugins.macros import (
     lineage_run_id,
 )
 
+from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS
+
 _DAG_NAMESPACE = namespace()
 
 if __version__.startswith("2."):
@@ -51,18 +55,42 @@ def test_lineage_job_name():
 
 
 def test_lineage_run_id():
+    date = datetime(2020, 1, 1, 1, 1, 1, 0, tzinfo=timezone.utc)
+    dag_run = mock.MagicMock(run_id="run_id")
+    dag_run.logical_date = date
+    task_instance = mock.MagicMock(
+        dag_id="dag_id",
+        task_id="task_id",
+        dag_run=dag_run,
+        logical_date=date,
+        try_number=1,
+    )
+
+    call_result1 = lineage_run_id(task_instance)
+    call_result2 = lineage_run_id(task_instance)
+
+    # random part value does not matter, it just has to be the same for the same TaskInstance
+    assert call_result1 == call_result2
+    # execution_date is used as most significant bits of UUID
+    assert call_result1.startswith("016f5e9e-c4c8-")
+
+
[email protected](not AIRFLOW_V_3_0_PLUS, reason="Test only for Airflow 
3.0+")
+def test_lineage_run_after_airflow_3():
+    dag_run = mock.MagicMock(run_id="run_id")
+    dag_run.run_after = datetime(2020, 1, 1, 1, 1, 1, 0, tzinfo=timezone.utc)
+    dag_run.logical_date = None
     task_instance = mock.MagicMock(
         dag_id="dag_id",
         task_id="task_id",
-        dag_run=mock.MagicMock(run_id="run_id"),
-        logical_date=datetime(2020, 1, 1, 1, 1, 1, 0, tzinfo=timezone.utc),
+        dag_run=dag_run,
         try_number=1,
     )
 
     call_result1 = lineage_run_id(task_instance)
     call_result2 = lineage_run_id(task_instance)
 
-    # random part value does not matter, it just have to be the same for the same TaskInstance
+    # random part value does not matter, it just has to be the same for the same TaskInstance
     assert call_result1 == call_result2
     # execution_date is used as most significant bits of UUID
     assert call_result1.startswith("016f5e9e-c4c8-")
diff --git a/providers/openlineage/tests/unit/openlineage/utils/test_spark.py b/providers/openlineage/tests/unit/openlineage/utils/test_spark.py
index c4073da4699..dba2f337128 100644
--- a/providers/openlineage/tests/unit/openlineage/utils/test_spark.py
+++ b/providers/openlineage/tests/unit/openlineage/utils/test_spark.py
@@ -21,7 +21,9 @@ import datetime as dt
 from unittest.mock import MagicMock, patch
 
 import pytest
-from openlineage.client.transport import HttpConfig, HttpTransport, KafkaConfig, KafkaTransport
+from openlineage.client.transport.composite import CompositeConfig, CompositeTransport
+from openlineage.client.transport.http import HttpConfig, HttpTransport
+from openlineage.client.transport.kafka import KafkaConfig, KafkaTransport
 
 from airflow.providers.openlineage.utils.spark import (
     _get_parent_job_information_as_spark_properties,
@@ -38,10 +40,12 @@ EXAMPLE_CONTEXT = {
         task_id="task_id",
         try_number=1,
         map_index=1,
+        dag_run=MagicMock(logical_date=dt.datetime(2024, 11, 11)),
         logical_date=dt.datetime(2024, 11, 11),
     )
 }
 EXAMPLE_HTTP_TRANSPORT_CONFIG = {
+    "type": "http",
     "url": "https://some-custom.url";,
     "endpoint": "/api/custom",
     "timeout": 123,
@@ -55,6 +59,17 @@ EXAMPLE_HTTP_TRANSPORT_CONFIG = {
         "apiKey": "secret_123",
     },
 }
+EXAMPLE_KAFKA_TRANSPORT_CONFIG = {
+    "type": "kafka",
+    "topic": "my_topic",
+    "config": {
+        "bootstrap.servers": "test-kafka-hm0fo:10011,another.host-uuj0l:10012",
+        "acks": "all",
+        "retries": "3",
+    },
+    "flush": True,
+    "messageKey": "some",
+}
 EXAMPLE_PARENT_JOB_SPARK_PROPERTIES = {
     "spark.openlineage.parentJobName": "dag_id.task_id",
     "spark.openlineage.parentJobNamespace": "default",
@@ -72,6 +87,20 @@ EXAMPLE_TRANSPORT_SPARK_PROPERTIES = {
     "spark.openlineage.transport.timeoutInMillis": "123000",
 }
 
+EXAMPLE_COMPOSITE_TRANSPORT_SPARK_PROPERTIES = {
+    "spark.openlineage.transport.type": "composite",
+    "spark.openlineage.transport.continueOnFailure": "True",
+    "spark.openlineage.transport.transports.http.type": "http",
+    "spark.openlineage.transport.transports.http.url": "https://some-custom.url",
+    "spark.openlineage.transport.transports.http.endpoint": "/api/custom",
+    "spark.openlineage.transport.transports.http.auth.type": "api_key",
+    "spark.openlineage.transport.transports.http.auth.apiKey": "Bearer secret_123",
+    "spark.openlineage.transport.transports.http.compression": "gzip",
+    "spark.openlineage.transport.transports.http.headers.key1": "val1",
+    "spark.openlineage.transport.transports.http.headers.key2": "val2",
+    "spark.openlineage.transport.transports.http.timeoutInMillis": "123000",
+}
+
 
 def test_get_parent_job_information_as_spark_properties():
     result = _get_parent_job_information_as_spark_properties(EXAMPLE_CONTEXT)
@@ -106,6 +135,17 @@ def test_get_transport_information_as_spark_properties_unsupported_transport_typ
     assert result == {}
 
 
+@patch("airflow.providers.openlineage.plugins.listener._openlineage_listener")
+def test_get_transport_information_as_spark_properties_composite_transport_type(mock_ol_listener):
+    mock_ol_listener.adapter.get_or_create_openlineage_client.return_value.transport = CompositeTransport(
+        CompositeConfig.from_dict(
+            {"transports": {"http": EXAMPLE_HTTP_TRANSPORT_CONFIG, "kafka": EXAMPLE_KAFKA_TRANSPORT_CONFIG}}
+        )
+    )
+    result = _get_transport_information_as_spark_properties()
+    assert result == EXAMPLE_COMPOSITE_TRANSPORT_SPARK_PROPERTIES
+
+
 @pytest.mark.parametrize(
     "properties, expected",
     [
@@ -260,3 +300,50 @@ def test_inject_transport_information_into_spark_properties(mock_ol_listener, pr
     result = inject_transport_information_into_spark_properties(properties, EXAMPLE_CONTEXT)
     expected = {**properties, **EXAMPLE_TRANSPORT_SPARK_PROPERTIES} if should_inject else properties
     assert result == expected
+
+
[email protected](
+    "properties, should_inject",
+    [
+        (
+            {"spark.openlineage.transport": "example_namespace"},
+            False,
+        ),
+        (
+            {"spark.openlineage.transport.type": "some_job_name"},
+            False,
+        ),
+        (
+            {"spark.openlineage.transport.url": "some_run_id"},
+            False,
+        ),
+        (
+            {"spark.openlineage.transportWhatever": "some_value", "some.other.property": "value"},
+            False,
+        ),
+        (
+            {"some.other.property": "value"},
+            True,
+        ),
+        (
+            {},
+            True,
+        ),
+    ],
+)
+@patch("airflow.providers.openlineage.plugins.listener._openlineage_listener")
+def test_inject_composite_transport_information_into_spark_properties(
+    mock_ol_listener, properties, should_inject
+):
+    mock_ol_listener.adapter.get_or_create_openlineage_client.return_value.transport = CompositeTransport(
+        CompositeConfig(
+            transports={
+                "http": EXAMPLE_HTTP_TRANSPORT_CONFIG,
+                "console": {"type": "console"},
+            },
+            continue_on_failure=True,
+        )
+    )
+    result = inject_transport_information_into_spark_properties(properties, EXAMPLE_CONTEXT)
+    expected = {**properties, **EXAMPLE_COMPOSITE_TRANSPORT_SPARK_PROPERTIES} if should_inject else properties
+    assert result == expected
