This is an automated email from the ASF dual-hosted git repository.

mobuchowski pushed a commit to branch add-openlineage-config-spark-submit
in repository https://gitbox.apache.org/repos/asf/airflow.git
commit d9ef1b0d2cfb61850b26c7e9367eefb488c9cd87
Author: Maciej Obuchowski <[email protected]>
AuthorDate: Thu Mar 6 13:00:28 2025 +0100

    serialize http transports contained in composite transport

    Signed-off-by: Maciej Obuchowski <[email protected]>
---
 .../common/compat/openlineage/utils/spark.py       | 98 ++++++++++++++------
 .../unit/google/cloud/openlineage/test_utils.py    | 10 ++-
 .../unit/google/cloud/operators/test_dataproc.py   | 46 ++++++++--
 .../providers/openlineage/plugins/macros.py        | 19 ++++-
 .../airflow/providers/openlineage/utils/spark.py   | 90 ++++++++++++++------
 .../tests/unit/openlineage/plugins/test_macros.py  | 34 +++++++-
 .../tests/unit/openlineage/utils/test_spark.py     | 89 +++++++++++++++++++-
 7 files changed, 315 insertions(+), 71 deletions(-)

diff --git a/providers/common/compat/src/airflow/providers/common/compat/openlineage/utils/spark.py b/providers/common/compat/src/airflow/providers/common/compat/openlineage/utils/spark.py
index cbc7997dd44..c6a4313300f 100644
--- a/providers/common/compat/src/airflow/providers/common/compat/openlineage/utils/spark.py
+++ b/providers/common/compat/src/airflow/providers/common/compat/openlineage/utils/spark.py
@@ -90,8 +90,78 @@ else:
             )
             return properties
 
-        transport = get_openlineage_listener().adapter.get_or_create_openlineage_client().transport
-        if transport.kind != "http":
+        def _get_transport_information_as_spark_properties() -> dict:
+            """Retrieve transport information as Spark properties."""
+
+            def _get_transport_information(tp) -> dict:
+                props = {
+                    "type": tp.kind,
+                    "url": tp.url,
+                    "endpoint": tp.endpoint,
+                    "timeoutInMillis": str(
+                        int(tp.timeout * 1000)
+                        # convert to milliseconds, as required by Spark integration
+                    ),
+                }
+                if hasattr(tp, "compression") and tp.compression:
+                    props["compression"] = str(tp.compression)
+
+                if hasattr(tp.config.auth, "api_key") and tp.config.auth.get_bearer():
+                    props["auth.type"] = "api_key"
+                    props["auth.apiKey"] = tp.config.auth.get_bearer()
+
+                if hasattr(tp.config, "custom_headers") and tp.config.custom_headers:
+                    for key, value in tp.config.custom_headers.items():
+                        props[f"headers.{key}"] = value
+                return props
+
+            def _format_transport(props: dict, transport: dict, name: str | None):
+                for key, value in transport.items():
+                    if name:
+                        props[f"spark.openlineage.transport.transports.{name}.{key}"] = value
+                    else:
+                        props[f"spark.openlineage.transport.{key}"] = value
+                return props
+
+            transport = (
+                get_openlineage_listener().adapter.get_or_create_openlineage_client().transport
+            )
+
+            if transport.kind == "composite":
+                http_transports = {}
+                for nested_transport in transport.transports:
+                    if nested_transport.kind == "http":
+                        http_transports[nested_transport.name] = _get_transport_information(
+                            nested_transport
+                        )
+                    else:
+                        name = (
+                            nested_transport.name if hasattr(nested_transport, "name") else "no-name"
+                        )
+                        log.info(
+                            "OpenLineage transport type `%s` with name `%s` is not supported in composite transport.",
+                            nested_transport.kind,
+                            name,
+                        )
+                if len(http_transports) == 0:
+                    log.warning(
+                        "OpenLineage transport type `composite` does not contain http transport. Skipping "
+                        "injection of OpenLineage transport information into Spark properties.",
+                    )
+                    return {}
+                props = {
+                    "spark.openlineage.transport.type": "composite",
+                    "spark.openlineage.transport.continueOnFailure": str(
+                        transport.config.continue_on_failure
+                    ),
+                }
+                for name, http_transport in http_transports.items():
+                    props = _format_transport(props, http_transport, name)
+                return props
+
+            elif transport.kind == "http":
+                return _format_transport({}, _get_transport_information(transport), None)
+
             log.info(
                 "OpenLineage transport type `%s` does not support automatic "
                 "injection of OpenLineage transport information into Spark properties.",
@@ -99,29 +169,7 @@ else:
             )
             return {}
 
-        transport_properties = {
-            "spark.openlineage.transport.type": "http",
-            "spark.openlineage.transport.url": transport.url,
-            "spark.openlineage.transport.endpoint": transport.endpoint,
-            # Timeout is converted to milliseconds, as required by Spark integration,
-            "spark.openlineage.transport.timeoutInMillis": str(int(transport.timeout * 1000)),
-        }
-        if transport.compression:
-            transport_properties["spark.openlineage.transport.compression"] = str(
-                transport.compression
-            )
-
-        if hasattr(transport.config.auth, "api_key") and transport.config.auth.get_bearer():
-            transport_properties["spark.openlineage.transport.auth.type"] = "api_key"
-            transport_properties["spark.openlineage.transport.auth.apiKey"] = (
-                transport.config.auth.get_bearer()
-            )
-
-        if hasattr(transport.config, "custom_headers") and transport.config.custom_headers:
-            for key, value in transport.config.custom_headers.items():
-                transport_properties[f"spark.openlineage.transport.headers.{key}"] = value
-
-        return {**properties, **transport_properties}
+        return {**properties, **_get_transport_information_as_spark_properties()}
 
 
 __all__ = [
diff --git a/providers/google/tests/unit/google/cloud/openlineage/test_utils.py b/providers/google/tests/unit/google/cloud/openlineage/test_utils.py
index 4be23910335..21317983d62 100644
--- a/providers/google/tests/unit/google/cloud/openlineage/test_utils.py
+++ b/providers/google/tests/unit/google/cloud/openlineage/test_utils.py
@@ -106,6 +106,7 @@ EXAMPLE_TEMPLATE = {
 EXAMPLE_CONTEXT = {
     "ti": MagicMock(
         dag_id="dag_id",
+        dag_run=MagicMock(run_after=dt.datetime(2024, 11, 11), logical_date=dt.datetime(2024, 11, 11)),
         task_id="task_id",
         try_number=1,
         map_index=1,
@@ -574,7 +575,8 @@ def test_replace_dataproc_job_properties_key_error():
 def test_inject_openlineage_properties_into_dataproc_job_provider_not_accessible(mock_is_accessible):
     mock_is_accessible.return_value = False
     job = {"sparkJob": {"properties": {"existingProperty": "value"}}}
-    result = inject_openlineage_properties_into_dataproc_job(job, None, True, True)
+
+    result = inject_openlineage_properties_into_dataproc_job(job, EXAMPLE_CONTEXT, True, True)
 
     assert result == job
 
@@ -586,7 +588,7 @@ def test_inject_openlineage_properties_into_dataproc_job_unsupported_job_type(
     mock_is_accessible.return_value = True
     mock_extract_job_type.return_value = None
     job = {"unsupportedJob": {"properties": {"existingProperty": "value"}}}
-    result = inject_openlineage_properties_into_dataproc_job(job, None, True, True)
+    result = inject_openlineage_properties_into_dataproc_job(job, EXAMPLE_CONTEXT, True, True)
 
     assert result == job
 
@@ -599,7 +601,9 @@ def test_inject_openlineage_properties_into_dataproc_job_no_injection(
     mock_extract_job_type.return_value = "sparkJob"
     inject_parent_job_info = False
     job = {"sparkJob": {"properties": {"existingProperty": "value"}}}
-    result = inject_openlineage_properties_into_dataproc_job(job, None, inject_parent_job_info, False)
+    result = inject_openlineage_properties_into_dataproc_job(
+        job, EXAMPLE_CONTEXT, inject_parent_job_info, False
+    )
 
     assert result == job
 
diff --git a/providers/google/tests/unit/google/cloud/operators/test_dataproc.py b/providers/google/tests/unit/google/cloud/operators/test_dataproc.py
index 73f3e6ce848..aa7712d7d86 100644
--- a/providers/google/tests/unit/google/cloud/operators/test_dataproc.py
+++ b/providers/google/tests/unit/google/cloud/operators/test_dataproc.py
@@ -1559,9 +1559,13 @@ class TestDataprocSubmitJobOperator(DataprocJobTestBase):
         op.execute(context=self.mock_context)
         assert not mock_defer.called
 
+    @mock.patch("airflow.providers.openlineage.plugins.adapter.generate_static_uuid")
     @mock.patch("airflow.providers.google.cloud.openlineage.utils._is_openlineage_provider_accessible")
     @mock.patch(DATAPROC_PATH.format("DataprocHook"))
-    def test_execute_openlineage_parent_job_info_injection(self, mock_hook, mock_ol_accessible):
+    def test_execute_openlineage_parent_job_info_injection(
+        self, mock_hook, mock_ol_accessible, mock_static_uuid
+    ):
+        mock_static_uuid.return_value = "01931885-2800-7be7-aa8d-aaa15c337267"
         job_config = {
             "placement": {"cluster_name": CLUSTER_NAME},
             "pyspark_job": {
@@ -1620,13 +1624,15 @@ class TestDataprocSubmitJobOperator(DataprocJobTestBase):
             metadata=METADATA,
         )
 
+    @mock.patch("airflow.providers.openlineage.plugins.adapter.generate_static_uuid")
     @mock.patch("airflow.providers.openlineage.plugins.listener._openlineage_listener")
     @mock.patch("airflow.providers.google.cloud.openlineage.utils._is_openlineage_provider_accessible")
     @mock.patch(DATAPROC_PATH.format("DataprocHook"))
     def test_execute_openlineage_http_transport_info_injection(
-        self, mock_hook, mock_ol_accessible, mock_ol_listener
+        self, mock_hook, mock_ol_accessible, mock_ol_listener, mock_static_uuid
     ):
         mock_ol_accessible.return_value = True
+        mock_static_uuid.return_value = "01931885-2800-7be7-aa8d-aaa15c337267"
         mock_ol_listener.adapter.get_or_create_openlineage_client.return_value.transport = HttpTransport(
             HttpConfig.from_dict(OPENLINEAGE_HTTP_TRANSPORT_EXAMPLE_CONFIG)
         )
@@ -1673,11 +1679,15 @@ class TestDataprocSubmitJobOperator(DataprocJobTestBase):
             metadata=METADATA,
         )
 
+    @mock.patch("airflow.providers.openlineage.plugins.adapter.generate_static_uuid")
     @mock.patch("airflow.providers.openlineage.plugins.listener._openlineage_listener")
     @mock.patch("airflow.providers.google.cloud.openlineage.utils._is_openlineage_provider_accessible")
     @mock.patch(DATAPROC_PATH.format("DataprocHook"))
-    def test_execute_openlineage_all_info_injection(self, mock_hook, mock_ol_accessible, mock_ol_listener):
+    def test_execute_openlineage_all_info_injection(
+        self, mock_hook, mock_ol_accessible, mock_ol_listener, mock_static_uuid
+    ):
         mock_ol_accessible.return_value = True
+        mock_static_uuid.return_value = "01931885-2800-7be7-aa8d-aaa15c337267"
         mock_ol_listener.adapter.get_or_create_openlineage_client.return_value.transport = HttpTransport(
             HttpConfig.from_dict(OPENLINEAGE_HTTP_TRANSPORT_EXAMPLE_CONFIG)
         )
@@ -2705,10 +2715,14 @@ class TestDataprocWorkflowTemplateInstantiateInlineOperator:
         )
         mock_op.return_value.result.assert_not_called()
 
+    @mock.patch("airflow.providers.openlineage.plugins.adapter.generate_static_uuid")
    @mock.patch("airflow.providers.google.cloud.openlineage.utils._is_openlineage_provider_accessible")
     @mock.patch(DATAPROC_PATH.format("DataprocHook"))
-    def test_execute_openlineage_parent_job_info_injection(self, mock_hook, mock_ol_accessible):
+    def test_execute_openlineage_parent_job_info_injection(
+        self, mock_hook, mock_ol_accessible, mock_static_uuid
+    ):
         mock_ol_accessible.return_value = True
+        mock_static_uuid.return_value = "01931885-2800-7be7-aa8d-aaa15c337267"
         template = {
             **WORKFLOW_TEMPLATE,
             "jobs": [
@@ -2891,13 +2905,15 @@ class TestDataprocWorkflowTemplateInstantiateInlineOperator:
             metadata=METADATA,
         )
 
+    @mock.patch("airflow.providers.openlineage.plugins.adapter.generate_static_uuid")
     @mock.patch("airflow.providers.openlineage.plugins.listener._openlineage_listener")
     @mock.patch("airflow.providers.google.cloud.openlineage.utils._is_openlineage_provider_accessible")
     @mock.patch(DATAPROC_PATH.format("DataprocHook"))
     def test_execute_openlineage_transport_info_injection(
-        self, mock_hook, mock_ol_accessible, mock_ol_listener
+        self, mock_hook, mock_ol_accessible, mock_ol_listener, mock_static_uuid
     ):
         mock_ol_accessible.return_value = True
+        mock_static_uuid.return_value = "01931885-2800-7be7-aa8d-aaa15c337267"
         mock_ol_listener.adapter.get_or_create_openlineage_client.return_value.transport = HttpTransport(
             HttpConfig.from_dict(OPENLINEAGE_HTTP_TRANSPORT_EXAMPLE_CONFIG)
         )
@@ -2995,11 +3011,15 @@ class TestDataprocWorkflowTemplateInstantiateInlineOperator:
             metadata=METADATA,
         )
 
+    @mock.patch("airflow.providers.openlineage.plugins.adapter.generate_static_uuid")
     @mock.patch("airflow.providers.openlineage.plugins.listener._openlineage_listener")
     @mock.patch("airflow.providers.google.cloud.openlineage.utils._is_openlineage_provider_accessible")
     @mock.patch(DATAPROC_PATH.format("DataprocHook"))
-    def test_execute_openlineage_all_info_injection(self, mock_hook, mock_ol_accessible, mock_ol_listener):
+    def test_execute_openlineage_all_info_injection(
+        self, mock_hook, mock_ol_accessible, mock_ol_listener, mock_static_uuid
+    ):
         mock_ol_accessible.return_value = True
+        mock_static_uuid.return_value = "01931885-2800-7be7-aa8d-aaa15c337267"
         mock_ol_listener.adapter.get_or_create_openlineage_client.return_value.transport = HttpTransport(
             HttpConfig.from_dict(OPENLINEAGE_HTTP_TRANSPORT_EXAMPLE_CONFIG)
         )
@@ -3467,11 +3487,15 @@ class TestDataprocCreateBatchOperator:
             metadata=METADATA,
         )
 
+    @mock.patch("airflow.providers.openlineage.plugins.adapter.generate_static_uuid")
     @mock.patch("airflow.providers.google.cloud.openlineage.utils._is_openlineage_provider_accessible")
     @mock.patch(DATAPROC_PATH.format("Batch.to_dict"))
     @mock.patch(DATAPROC_PATH.format("DataprocHook"))
-    def test_execute_openlineage_parent_job_info_injection(self, mock_hook, to_dict_mock, mock_ol_accessible):
+    def test_execute_openlineage_parent_job_info_injection(
+        self, mock_hook, to_dict_mock, mock_ol_accessible, mock_static_uuid
+    ):
         mock_ol_accessible.return_value = True
+        mock_static_uuid.return_value = "01931885-2800-7be7-aa8d-aaa15c337267"
         expected_batch = {
             **BATCH,
             "runtime_config": {"properties": OPENLINEAGE_PARENT_JOB_EXAMPLE_SPARK_PROPERTIES},
@@ -3504,14 +3528,16 @@ class TestDataprocCreateBatchOperator:
             metadata=METADATA,
         )
 
+    @mock.patch("airflow.providers.openlineage.plugins.adapter.generate_static_uuid")
     @mock.patch("airflow.providers.openlineage.plugins.listener._openlineage_listener")
     @mock.patch("airflow.providers.google.cloud.openlineage.utils._is_openlineage_provider_accessible")
     @mock.patch(DATAPROC_PATH.format("Batch.to_dict"))
     @mock.patch(DATAPROC_PATH.format("DataprocHook"))
     def test_execute_openlineage_transport_info_injection(
-        self, mock_hook, to_dict_mock, mock_ol_accessible, mock_ol_listener
+        self, mock_hook, to_dict_mock, mock_ol_accessible, mock_ol_listener, mock_static_uuid
     ):
         mock_ol_accessible.return_value = True
+        mock_static_uuid.return_value = "01931885-2800-7be7-aa8d-aaa15c337267"
         mock_ol_listener.adapter.get_or_create_openlineage_client.return_value.transport = HttpTransport(
             HttpConfig.from_dict(OPENLINEAGE_HTTP_TRANSPORT_EXAMPLE_CONFIG)
         )
@@ -3547,14 +3573,16 @@ class TestDataprocCreateBatchOperator:
             metadata=METADATA,
         )
 
+    @mock.patch("airflow.providers.openlineage.plugins.adapter.generate_static_uuid")
     @mock.patch("airflow.providers.openlineage.plugins.listener._openlineage_listener")
     @mock.patch("airflow.providers.google.cloud.openlineage.utils._is_openlineage_provider_accessible")
     @mock.patch(DATAPROC_PATH.format("Batch.to_dict"))
     @mock.patch(DATAPROC_PATH.format("DataprocHook"))
     def test_execute_openlineage_all_info_injection(
-        self, mock_hook, to_dict_mock, mock_ol_accessible, mock_ol_listener
+        self, mock_hook, to_dict_mock, mock_ol_accessible, mock_ol_listener, mock_static_uuid
     ):
         mock_ol_accessible.return_value = True
+        mock_static_uuid.return_value = "01931885-2800-7be7-aa8d-aaa15c337267"
         mock_ol_listener.adapter.get_or_create_openlineage_client.return_value.transport = HttpTransport(
             HttpConfig.from_dict(OPENLINEAGE_HTTP_TRANSPORT_EXAMPLE_CONFIG)
         )
diff --git a/providers/openlineage/src/airflow/providers/openlineage/plugins/macros.py b/providers/openlineage/src/airflow/providers/openlineage/plugins/macros.py
index fd5194d3f32..ef5b9c0ad64 100644
--- a/providers/openlineage/src/airflow/providers/openlineage/plugins/macros.py
+++ b/providers/openlineage/src/airflow/providers/openlineage/plugins/macros.py
@@ -21,6 +21,7 @@ from typing import TYPE_CHECKING
 from airflow.providers.openlineage import conf
 from airflow.providers.openlineage.plugins.adapter import OpenLineageAdapter
 from airflow.providers.openlineage.utils.utils import get_job_name
+from airflow.providers.openlineage.version_compat import AIRFLOW_V_3_0_PLUS
 
 if TYPE_CHECKING:
     from airflow.models import TaskInstance
@@ -58,15 +59,25 @@ def lineage_run_id(task_instance: TaskInstance):
     For more information take a look at the guide:
     :ref:`howto/macros:openlineage`
     """
-    if hasattr(task_instance, "logical_date"):
-        logical_date = task_instance.logical_date
+    if AIRFLOW_V_3_0_PLUS:
+        context = task_instance.get_template_context()
+        if hasattr(task_instance, "dag_run"):
+            dag_run = task_instance.dag_run
+        elif hasattr(context, "dag_run"):
+            dag_run = context["dag_run"]
+        if hasattr(dag_run, "logical_date") and dag_run.logical_date:
+            date = dag_run.logical_date
+        else:
+            date = dag_run.run_after
+    elif hasattr(task_instance, "logical_date"):
+        date = task_instance.logical_date
     else:
-        logical_date = task_instance.execution_date
+        date = task_instance.execution_date
 
     return OpenLineageAdapter.build_task_instance_run_id(
         dag_id=task_instance.dag_id,
         task_id=task_instance.task_id,
         try_number=task_instance.try_number,
-        logical_date=logical_date,
+        logical_date=date,
         map_index=task_instance.map_index,
     )
diff --git a/providers/openlineage/src/airflow/providers/openlineage/utils/spark.py b/providers/openlineage/src/airflow/providers/openlineage/utils/spark.py
index c2991e65180..9f0fef84be0 100644
--- a/providers/openlineage/src/airflow/providers/openlineage/utils/spark.py
+++ b/providers/openlineage/src/airflow/providers/openlineage/utils/spark.py
@@ -53,35 +53,73 @@ def _get_parent_job_information_as_spark_properties(context: Context) -> dict:
 
 def _get_transport_information_as_spark_properties() -> dict:
     """Retrieve transport information as Spark properties."""
-    transport = get_openlineage_listener().adapter.get_or_create_openlineage_client().transport
-    if transport.kind != "http":
-        log.info(
-            "OpenLineage transport type `%s` does not support automatic "
-            "injection of OpenLineage transport information into Spark properties.",
-            transport.kind,
-        )
-        return {}
-
-    properties = {
-        "spark.openlineage.transport.type": transport.kind,
-        "spark.openlineage.transport.url": transport.url,
-        "spark.openlineage.transport.endpoint": transport.endpoint,
-        "spark.openlineage.transport.timeoutInMillis": str(
-            int(transport.timeout * 1000)  # convert to milliseconds, as required by Spark integration
-        ),
-    }
-    if transport.compression:
-        properties["spark.openlineage.transport.compression"] = str(transport.compression)
-    if hasattr(transport.config.auth, "api_key") and transport.config.auth.get_bearer():
-        properties["spark.openlineage.transport.auth.type"] = "api_key"
-        properties["spark.openlineage.transport.auth.apiKey"] = transport.config.auth.get_bearer()
+    def _get_transport_information(tp) -> dict:
+        properties = {
+            "type": tp.kind,
+            "url": tp.url,
+            "endpoint": tp.endpoint,
+            "timeoutInMillis": str(
+                int(tp.timeout * 1000)  # convert to milliseconds, as required by Spark integration
+            ),
+        }
+        if hasattr(tp, "compression") and tp.compression:
+            properties["compression"] = str(tp.compression)
+
+        if hasattr(tp.config.auth, "api_key") and tp.config.auth.get_bearer():
+            properties["auth.type"] = "api_key"
+            properties["auth.apiKey"] = tp.config.auth.get_bearer()
+
+        if hasattr(tp.config, "custom_headers") and tp.config.custom_headers:
+            for key, value in tp.config.custom_headers.items():
+                properties[f"headers.{key}"] = value
+        return properties
+
+    def _format_transport(props: dict, transport: dict, name: str | None):
+        for key, value in transport.items():
+            if name:
+                props[f"spark.openlineage.transport.transports.{name}.{key}"] = value
+            else:
+                props[f"spark.openlineage.transport.{key}"] = value
+        return props
 
-    if hasattr(transport.config, "custom_headers") and transport.config.custom_headers:
-        for key, value in transport.config.custom_headers.items():
-            properties[f"spark.openlineage.transport.headers.{key}"] = value
+    transport = get_openlineage_listener().adapter.get_or_create_openlineage_client().transport
 
-    return properties
+    if transport.kind == "composite":
+        http_transports = {}
+        for nested_transport in transport.transports:
+            if nested_transport.kind == "http":
+                http_transports[nested_transport.name] = _get_transport_information(nested_transport)
+            else:
+                name = nested_transport.name if hasattr(nested_transport, "name") else "no-name"
+                log.info(
+                    "OpenLineage transport type `%s` with name `%s` is not supported in composite transport.",
+                    nested_transport.kind,
+                    name,
+                )
+        if len(http_transports) == 0:
+            log.warning(
+                "OpenLineage transport type `composite` does not contain http transport. Skipping "
+                "injection of OpenLineage transport information into Spark properties.",
+            )
+            return {}
+        props = {
+            "spark.openlineage.transport.type": "composite",
+            "spark.openlineage.transport.continueOnFailure": str(transport.config.continue_on_failure),
+        }
+        for name, http_transport in http_transports.items():
+            props = _format_transport(props, http_transport, name)
+        return props
+
+    elif transport.kind == "http":
+        return _format_transport({}, _get_transport_information(transport), None)
+
+    log.info(
+        "OpenLineage transport type `%s` does not support automatic "
+        "injection of OpenLineage transport information into Spark properties.",
+        transport.kind,
+    )
+    return {}
 
 
 def _is_parent_job_information_present_in_spark_properties(properties: dict) -> bool:
diff --git a/providers/openlineage/tests/unit/openlineage/plugins/test_macros.py b/providers/openlineage/tests/unit/openlineage/plugins/test_macros.py
index 28eb8c74cd9..dbc4455fb2b 100644
--- a/providers/openlineage/tests/unit/openlineage/plugins/test_macros.py
+++ b/providers/openlineage/tests/unit/openlineage/plugins/test_macros.py
@@ -19,6 +19,8 @@ from __future__ import annotations
 from datetime import datetime, timezone
 from unittest import mock
 
+import pytest
+
 from airflow import __version__
 from airflow.providers.openlineage.conf import namespace
 from airflow.providers.openlineage.plugins.macros import (
@@ -28,6 +30,8 @@ from airflow.providers.openlineage.plugins.macros import (
     lineage_run_id,
 )
 
+from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS
+
 _DAG_NAMESPACE = namespace()
 
 if __version__.startswith("2."):
@@ -51,18 +55,42 @@ def test_lineage_job_name():
 
 
 def test_lineage_run_id():
+    date = datetime(2020, 1, 1, 1, 1, 1, 0, tzinfo=timezone.utc)
+    dag_run = mock.MagicMock(run_id="run_id")
+    dag_run.logical_date = date
+    task_instance = mock.MagicMock(
+        dag_id="dag_id",
+        task_id="task_id",
+        dag_run=dag_run,
+        logical_date=date,
+        try_number=1,
+    )
+
+    call_result1 = lineage_run_id(task_instance)
+    call_result2 = lineage_run_id(task_instance)
+
+    # random part value does not matter, it just has to be the same for the same TaskInstance
+    assert call_result1 == call_result2
+    # execution_date is used as most significant bits of UUID
+    assert call_result1.startswith("016f5e9e-c4c8-")
+
+
[email protected](not AIRFLOW_V_3_0_PLUS, reason="Test only for Airflow 3.0+")
+def test_lineage_run_after_airflow_3():
+    dag_run = mock.MagicMock(run_id="run_id")
+    dag_run.run_after = datetime(2020, 1, 1, 1, 1, 1, 0, tzinfo=timezone.utc)
+    dag_run.logical_date = None
     task_instance = mock.MagicMock(
         dag_id="dag_id",
         task_id="task_id",
-        dag_run=mock.MagicMock(run_id="run_id"),
-        logical_date=datetime(2020, 1, 1, 1, 1, 1, 0, tzinfo=timezone.utc),
+        dag_run=dag_run,
         try_number=1,
     )
 
     call_result1 = lineage_run_id(task_instance)
     call_result2 = lineage_run_id(task_instance)
 
-    # random part value does not matter, it just have to be the same for the same TaskInstance
+    # random part value does not matter, it just has to be the same for the same TaskInstance
     assert call_result1 == call_result2
     # execution_date is used as most significant bits of UUID
     assert call_result1.startswith("016f5e9e-c4c8-")
diff --git a/providers/openlineage/tests/unit/openlineage/utils/test_spark.py b/providers/openlineage/tests/unit/openlineage/utils/test_spark.py
index c4073da4699..dba2f337128 100644
--- a/providers/openlineage/tests/unit/openlineage/utils/test_spark.py
+++ b/providers/openlineage/tests/unit/openlineage/utils/test_spark.py
@@ -21,7 +21,9 @@ import datetime as dt
 from unittest.mock import MagicMock, patch
 
 import pytest
-from openlineage.client.transport import HttpConfig, HttpTransport, KafkaConfig, KafkaTransport
+from openlineage.client.transport.composite import CompositeConfig, CompositeTransport
+from openlineage.client.transport.http import HttpConfig, HttpTransport
+from openlineage.client.transport.kafka import KafkaConfig, KafkaTransport
 
 from airflow.providers.openlineage.utils.spark import (
     _get_parent_job_information_as_spark_properties,
@@ -38,10 +40,12 @@ EXAMPLE_CONTEXT = {
         task_id="task_id",
         try_number=1,
         map_index=1,
+        dag_run=MagicMock(logical_date=dt.datetime(2024, 11, 11)),
         logical_date=dt.datetime(2024, 11, 11),
     )
 }
 EXAMPLE_HTTP_TRANSPORT_CONFIG = {
+    "type": "http",
     "url": "https://some-custom.url",
     "endpoint": "/api/custom",
     "timeout": 123,
@@ -55,6 +59,17 @@ EXAMPLE_HTTP_TRANSPORT_CONFIG = {
         "apiKey": "secret_123",
     },
 }
+EXAMPLE_KAFKA_TRANSPORT_CONFIG = {
+    "type": "kafka",
+    "topic": "my_topic",
+    "config": {
+        "bootstrap.servers": "test-kafka-hm0fo:10011,another.host-uuj0l:10012",
+        "acks": "all",
+        "retries": "3",
+    },
+    "flush": True,
+    "messageKey": "some",
+}
 EXAMPLE_PARENT_JOB_SPARK_PROPERTIES = {
     "spark.openlineage.parentJobName": "dag_id.task_id",
     "spark.openlineage.parentJobNamespace": "default",
@@ -72,6 +87,20 @@ EXAMPLE_TRANSPORT_SPARK_PROPERTIES = {
     "spark.openlineage.transport.timeoutInMillis": "123000",
 }
 
+EXAMPLE_COMPOSITE_TRANSPORT_SPARK_PROPERTIES = {
+    "spark.openlineage.transport.type": "composite",
+    "spark.openlineage.transport.continueOnFailure": "True",
+    "spark.openlineage.transport.transports.http.type": "http",
+    "spark.openlineage.transport.transports.http.url": "https://some-custom.url",
+    "spark.openlineage.transport.transports.http.endpoint": "/api/custom",
+    "spark.openlineage.transport.transports.http.auth.type": "api_key",
+    "spark.openlineage.transport.transports.http.auth.apiKey": "Bearer secret_123",
+    "spark.openlineage.transport.transports.http.compression": "gzip",
+    "spark.openlineage.transport.transports.http.headers.key1": "val1",
+    "spark.openlineage.transport.transports.http.headers.key2": "val2",
+    "spark.openlineage.transport.transports.http.timeoutInMillis": "123000",
+}
+
 
 def test_get_parent_job_information_as_spark_properties():
     result = _get_parent_job_information_as_spark_properties(EXAMPLE_CONTEXT)
@@ -106,6 +135,17 @@ def test_get_transport_information_as_spark_properties_unsupported_transport_typ
     assert result == {}
 
 
+@patch("airflow.providers.openlineage.plugins.listener._openlineage_listener")
+def test_get_transport_information_as_spark_properties_composite_transport_type(mock_ol_listener):
+    mock_ol_listener.adapter.get_or_create_openlineage_client.return_value.transport = CompositeTransport(
+        CompositeConfig.from_dict(
+            {"transports": {"http": EXAMPLE_HTTP_TRANSPORT_CONFIG, "kafka": EXAMPLE_KAFKA_TRANSPORT_CONFIG}}
+        )
+    )
+    result = _get_transport_information_as_spark_properties()
+    assert result == EXAMPLE_COMPOSITE_TRANSPORT_SPARK_PROPERTIES
+
+
 @pytest.mark.parametrize(
     "properties, expected",
     [
@@ -260,3 +300,50 @@ def test_inject_transport_information_into_spark_properties(mock_ol_listener, pr
     result = inject_transport_information_into_spark_properties(properties, EXAMPLE_CONTEXT)
     expected = {**properties, **EXAMPLE_TRANSPORT_SPARK_PROPERTIES} if should_inject else properties
     assert result == expected
+
+
[email protected](
+    "properties, should_inject",
+    [
+        (
+            {"spark.openlineage.transport": "example_namespace"},
+            False,
+        ),
+        (
+            {"spark.openlineage.transport.type": "some_job_name"},
+            False,
+        ),
+        (
+            {"spark.openlineage.transport.url": "some_run_id"},
+            False,
+        ),
+        (
+            {"spark.openlineage.transportWhatever": "some_value", "some.other.property": "value"},
+            False,
+        ),
+        (
+            {"some.other.property": "value"},
+            True,
+        ),
+        (
+            {},
+            True,
+        ),
+    ],
+)
+@patch("airflow.providers.openlineage.plugins.listener._openlineage_listener")
+def test_inject_composite_transport_information_into_spark_properties(
+    mock_ol_listener, properties, should_inject
+):
+    mock_ol_listener.adapter.get_or_create_openlineage_client.return_value.transport = CompositeTransport(
+        CompositeConfig(
+            transports={
+                "http": EXAMPLE_HTTP_TRANSPORT_CONFIG,
+                "console": {"type": "console"},
+            },
+            continue_on_failure=True,
+        )
+    )
+    result = inject_transport_information_into_spark_properties(properties, EXAMPLE_CONTEXT)
+    expected = {**properties, **EXAMPLE_COMPOSITE_TRANSPORT_SPARK_PROPERTIES} if should_inject else properties
+    assert result == expected
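
For reference, a minimal sketch of the behavior this commit adds, exercised the same way the new unit tests above do it. This is illustrative only, not an excerpt from the commit: the config values are made up, and only the subset of http fields shown here is used.

# Illustrative sketch (not part of the commit): serialize a composite
# transport into spark-submit properties via the new helper.
from unittest.mock import patch

from openlineage.client.transport.composite import CompositeConfig, CompositeTransport

from airflow.providers.openlineage.utils.spark import (
    _get_transport_information_as_spark_properties,
)

composite = CompositeTransport(
    CompositeConfig(
        transports={
            # Only http members are serialized; the console transport below
            # is skipped with an info-level log message.
            "http": {
                "type": "http",
                "url": "https://some-custom.url",
                "endpoint": "/api/custom",
                "timeout": 123,
            },
            "console": {"type": "console"},
        },
        continue_on_failure=True,
    )
)

# The helper reads the transport from the OpenLineage listener's client,
# so patch the listener as the unit tests do.
with patch("airflow.providers.openlineage.plugins.listener._openlineage_listener") as listener:
    listener.adapter.get_or_create_openlineage_client.return_value.transport = composite
    props = _get_transport_information_as_spark_properties()

# props now namespaces each http member under its composite name:
# {
#     "spark.openlineage.transport.type": "composite",
#     "spark.openlineage.transport.continueOnFailure": "True",
#     "spark.openlineage.transport.transports.http.type": "http",
#     "spark.openlineage.transport.transports.http.url": "https://some-custom.url",
#     "spark.openlineage.transport.transports.http.endpoint": "/api/custom",
#     "spark.openlineage.transport.transports.http.timeoutInMillis": "123000",
# }

If the composite transport contains no http member at all, the helper logs a warning and returns an empty dict, so no transport properties are injected.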
