This is an automated email from the ASF dual-hosted git repository.

mobuchowski pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 48a5a0aae97 feat: automatically inject OL info into spark job in DataprocInstantiateInlineWorkflowTemplateOperator (#44697)
48a5a0aae97 is described below

commit 48a5a0aae9782d2df66f91cdc5f8cf46985026c8
Author: Kacper Muda <[email protected]>
AuthorDate: Thu Jan 2 09:14:19 2025 +0100

    feat: automatically inject OL info into spark job in DataprocInstantiateInlineWorkflowTemplateOperator (#44697)
    
    Signed-off-by: Kacper Muda <[email protected]>
---
 docs/exts/templates/openlineage.rst.jinja2         |   2 +
 .../providers/google/cloud/openlineage/utils.py    |  64 ++++++
 .../providers/google/cloud/operators/dataproc.py   |  13 ++
 .../tests/google/cloud/openlineage/test_utils.py   | 123 +++++++++++
 .../tests/google/cloud/operators/test_dataproc.py  | 236 +++++++++++++++++++++
 5 files changed, 438 insertions(+)
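
For context, a minimal usage sketch of the new opt-in flag on the operator (task id, region, project id, and the template itself are illustrative placeholders, not taken from this commit):

    from airflow.providers.google.cloud.operators.dataproc import (
        DataprocInstantiateInlineWorkflowTemplateOperator,
    )

    # Hypothetical one-step inline workflow template with a PySpark job.
    template = {
        "id": "example-workflow",
        "placement": {"cluster_selector": {"cluster_labels": {"key": "value"}}},
        "jobs": [
            {
                "step_id": "job_1",
                "pyspark_job": {"main_python_file_uri": "gs://bucket/spark_job.py"},
            }
        ],
    }

    instantiate_inline = DataprocInstantiateInlineWorkflowTemplateOperator(
        task_id="instantiate_inline_workflow",
        template=template,
        region="europe-central2",
        project_id="my-project",
        # New in this commit: opt in to automatic parent job info injection.
        # Defaults to the [openlineage] spark_inject_parent_job_info config option (False).
        openlineage_inject_parent_job_info=True,
    )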

diff --git a/docs/exts/templates/openlineage.rst.jinja2 b/docs/exts/templates/openlineage.rst.jinja2
index 217e634457c..af5798d5d51 100644
--- a/docs/exts/templates/openlineage.rst.jinja2
+++ b/docs/exts/templates/openlineage.rst.jinja2
@@ -38,6 +38,8 @@ apache-airflow-providers-google
     - Parent Job Information
 - :class:`~airflow.providers.google.cloud.operators.dataproc.DataprocCreateBatchOperator`
     - Parent Job Information
+- :class:`~airflow.providers.google.cloud.operators.dataproc.DataprocInstantiateInlineWorkflowTemplateOperator`
+    - Parent Job Information
 
 
 :class:`~airflow.providers.common.sql.operators.sql.SQLExecuteQueryOperator`
diff --git a/providers/src/airflow/providers/google/cloud/openlineage/utils.py b/providers/src/airflow/providers/google/cloud/openlineage/utils.py
index 0f3dcb5d4be..c6fbadba953 100644
--- a/providers/src/airflow/providers/google/cloud/openlineage/utils.py
+++ b/providers/src/airflow/providers/google/cloud/openlineage/utils.py
@@ -622,3 +622,67 @@ def inject_openlineage_properties_into_dataproc_batch(
 
     batch_with_ol_config = _replace_dataproc_batch_properties(batch=batch, new_properties=properties)
     return batch_with_ol_config
+
+
+def inject_openlineage_properties_into_dataproc_workflow_template(
+    template: dict, context: Context, inject_parent_job_info: bool
+) -> dict:
+    """
+    Inject OpenLineage properties into Spark jobs in a Workflow Template.
+
+    The function does not remove any configuration or modify the jobs in any other way,
+    apart from adding the desired OpenLineage properties to the Dataproc job definition if not already present.
+
+    Note:
+        Any modification to a job will be skipped if:
+            - The OpenLineage provider is not accessible.
+            - The job type is not supported.
+            - Automatic parent job information injection is disabled.
+            - Any OpenLineage properties with parent job information are already present
+              in the Spark job definition.
+
+    Args:
+        template: The original Dataproc Workflow Template definition.
+        context: The Airflow context in which the job is running.
+        inject_parent_job_info: Flag indicating whether to inject parent job information.
+
+    Returns:
+        The modified Workflow Template definition with OpenLineage properties injected, if applicable.
+    """
+    if not inject_parent_job_info:
+        log.debug("Automatic injection of OpenLineage information is disabled.")
+        return template
+
+    if not _is_openlineage_provider_accessible():
+        log.warning(
+            "Could not access OpenLineage provider for automatic OpenLineage "
+            "properties injection. No action will be performed."
+        )
+        return template
+
+    final_jobs = []
+    for single_job_definition in template["jobs"]:
+        step_id = single_job_definition["step_id"]
+        log.debug("Injecting OpenLineage properties into Workflow step: `%s`", step_id)
+
+        if (job_type := _extract_supported_job_type_from_dataproc_job(single_job_definition)) is None:
+            log.debug(
+                "Could not find a supported Dataproc job type for automatic OpenLineage "
+                "properties injection. No action will be performed.",
+            )
+            final_jobs.append(single_job_definition)
+            continue
+
+        properties = single_job_definition[job_type].get("properties", {})
+
+        properties = inject_parent_job_information_into_spark_properties(
+            properties=properties, context=context
+        )
+
+        job_with_ol_config = _replace_dataproc_job_properties(
+            job=single_job_definition, job_type=job_type, new_properties=properties
+        )
+        final_jobs.append(job_with_ol_config)
+
+    template["jobs"] = final_jobs
+    return template
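
A minimal sketch of calling the new helper directly (the template dict is illustrative, and context is assumed to be the Airflow task context, e.g. as passed to execute()):

    from airflow.providers.google.cloud.openlineage.utils import (
        inject_openlineage_properties_into_dataproc_workflow_template,
    )

    # Hypothetical template with a single supported (PySpark) step.
    template = {
        "jobs": [
            {"step_id": "job_1", "pyspark_job": {"main_python_file_uri": "gs://bucket/job.py"}},
        ],
    }
    template = inject_openlineage_properties_into_dataproc_workflow_template(
        template=template, context=context, inject_parent_job_info=True
    )
    # Each supported job now carries spark.openlineage.parent* properties,
    # unless parent job information was already present in that job's definition.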
diff --git a/providers/src/airflow/providers/google/cloud/operators/dataproc.py b/providers/src/airflow/providers/google/cloud/operators/dataproc.py
index 5e64f7d9207..1d5ced10283 100644
--- a/providers/src/airflow/providers/google/cloud/operators/dataproc.py
+++ b/providers/src/airflow/providers/google/cloud/operators/dataproc.py
@@ -57,6 +57,7 @@ from airflow.providers.google.cloud.links.dataproc import (
 from airflow.providers.google.cloud.openlineage.utils import (
     inject_openlineage_properties_into_dataproc_batch,
     inject_openlineage_properties_into_dataproc_job,
+    inject_openlineage_properties_into_dataproc_workflow_template,
 )
 from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator
 from airflow.providers.google.cloud.triggers.dataproc import (
@@ -1825,6 +1826,9 @@ class DataprocInstantiateInlineWorkflowTemplateOperator(GoogleCloudBaseOperator)
         deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
         polling_interval_seconds: int = 10,
         cancel_on_kill: bool = True,
+        openlineage_inject_parent_job_info: bool = conf.getboolean(
+            "openlineage", "spark_inject_parent_job_info", fallback=False
+        ),
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
@@ -1844,11 +1848,20 @@ class DataprocInstantiateInlineWorkflowTemplateOperator(GoogleCloudBaseOperator)
         self.polling_interval_seconds = polling_interval_seconds
         self.cancel_on_kill = cancel_on_kill
         self.operation_name: str | None = None
+        self.openlineage_inject_parent_job_info = openlineage_inject_parent_job_info
 
     def execute(self, context: Context):
         self.log.info("Instantiating Inline Template")
         hook = DataprocHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain)
         project_id = self.project_id or hook.project_id
+        if self.openlineage_inject_parent_job_info:
+            self.log.info("Automatic injection of OpenLineage information into Spark properties is enabled.")
+            self.template = inject_openlineage_properties_into_dataproc_workflow_template(
+                template=self.template,
+                context=context,
+                inject_parent_job_info=self.openlineage_inject_parent_job_info,
+            )
+
         operation = hook.instantiate_inline_workflow_template(
             template=self.template,
             project_id=project_id,
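
For reference, the parent job information injected into each supported job's Spark properties has this shape (values mirror the expectations in the tests below; the run id is a per-run UUID):

    properties = {
        "spark.openlineage.parentJobNamespace": "default",    # OpenLineage namespace
        "spark.openlineage.parentJobName": "dag_id.task_id",  # <dag_id>.<task_id>
        "spark.openlineage.parentRunId": "01931885-2800-7be7-aa8d-aaa15c337267",
    }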
diff --git a/providers/tests/google/cloud/openlineage/test_utils.py b/providers/tests/google/cloud/openlineage/test_utils.py
index 58949125f84..b5e451debe5 100644
--- a/providers/tests/google/cloud/openlineage/test_utils.py
+++ b/providers/tests/google/cloud/openlineage/test_utils.py
@@ -48,6 +48,7 @@ from airflow.providers.google.cloud.openlineage.utils import (
     get_identity_column_lineage_facet,
     inject_openlineage_properties_into_dataproc_batch,
     inject_openlineage_properties_into_dataproc_job,
+    inject_openlineage_properties_into_dataproc_workflow_template,
     merge_column_lineage_facets,
 )
 
@@ -829,3 +830,125 @@ def test_inject_openlineage_properties_into_dataproc_batch(mock_is_ol_accessible
     }
     result = inject_openlineage_properties_into_dataproc_batch(batch, context, True)
     assert result == expected_batch
+
+
+@patch("airflow.providers.google.cloud.openlineage.utils._is_openlineage_provider_accessible")
+def test_inject_openlineage_properties_into_dataproc_workflow_template_provider_not_accessible(
+    mock_is_accessible,
+):
+    mock_is_accessible.return_value = False
+    template = {"workflow": "template"}  # It does not matter what the dict is, we should return it unmodified
+    result = inject_openlineage_properties_into_dataproc_workflow_template(template, None, True)
+    assert result == template
+
+
+@patch("airflow.providers.google.cloud.openlineage.utils._is_openlineage_provider_accessible")
+@patch("airflow.providers.google.cloud.openlineage.utils._extract_supported_job_type_from_dataproc_job")
+def test_inject_openlineage_properties_into_dataproc_workflow_template_no_inject_parent_job_info(
+    mock_extract_job_type, mock_is_accessible
+):
+    mock_is_accessible.return_value = True
+    mock_extract_job_type.return_value = "sparkJob"
+    inject_parent_job_info = False
+    template = {"workflow": "template"}  # It does not matter what the dict is, we should return it unmodified
+    result = inject_openlineage_properties_into_dataproc_workflow_template(
+        template, None, inject_parent_job_info
+    )
+    assert result == template
+
+
+@patch("airflow.providers.google.cloud.openlineage.utils._is_openlineage_provider_accessible")
+def test_inject_openlineage_properties_into_dataproc_workflow_template(mock_is_ol_accessible):
+    mock_is_ol_accessible.return_value = True
+    context = {
+        "ti": MagicMock(
+            dag_id="dag_id",
+            task_id="task_id",
+            try_number=1,
+            map_index=1,
+            logical_date=dt.datetime(2024, 11, 11),
+        )
+    }
+    template = {
+        "id": "test-workflow",
+        "placement": {
+            "cluster_selector": {
+                "zone": "europe-central2-c",
+                "cluster_labels": {"key": "value"},
+            }
+        },
+        "jobs": [
+            {
+                "step_id": "job_1",
+                "pyspark_job": {
+                    "main_python_file_uri": "gs://bucket1/spark_job.py",
+                    "properties": {
+                        "spark.sql.shuffle.partitions": "1",
+                    },
+                },
+            },
+            {
+                "step_id": "job_2",
+                "pyspark_job": {
+                    "main_python_file_uri": "gs://bucket2/spark_job.py",
+                    "properties": {
+                        "spark.sql.shuffle.partitions": "1",
+                        "spark.openlineage.parentJobNamespace": "test",
+                    },
+                },
+            },
+            {
+                "step_id": "job_3",
+                "hive_job": {
+                    "main_python_file_uri": "gs://bucket3/hive_job.py",
+                    "properties": {
+                        "spark.sql.shuffle.partitions": "1",
+                    },
+                },
+            },
+        ],
+    }
+    expected_template = {
+        "id": "test-workflow",
+        "placement": {
+            "cluster_selector": {
+                "zone": "europe-central2-c",
+                "cluster_labels": {"key": "value"},
+            }
+        },
+        "jobs": [
+            {
+                "step_id": "job_1",
+                "pyspark_job": {
+                    "main_python_file_uri": "gs://bucket1/spark_job.py",
+                    "properties": {  # Injected properties
+                        "spark.sql.shuffle.partitions": "1",
+                        "spark.openlineage.parentJobName": "dag_id.task_id",
+                        "spark.openlineage.parentJobNamespace": "default",
+                        "spark.openlineage.parentRunId": "01931885-2800-7be7-aa8d-aaa15c337267",
+                    },
+                },
+            },
+            {
+                "step_id": "job_2",
+                "pyspark_job": {  # Not modified because it's already present
+                    "main_python_file_uri": "gs://bucket2/spark_job.py",
+                    "properties": {
+                        "spark.sql.shuffle.partitions": "1",
+                        "spark.openlineage.parentJobNamespace": "test",
+                    },
+                },
+            },
+            {
+                "step_id": "job_3",
+                "hive_job": {  # Not modified because it's unsupported job type
+                    "main_python_file_uri": "gs://bucket3/hive_job.py",
+                    "properties": {
+                        "spark.sql.shuffle.partitions": "1",
+                    },
+                },
+            },
+        ],
+    }
+    result = inject_openlineage_properties_into_dataproc_workflow_template(template, context, True)
+    assert result == expected_template
diff --git a/providers/tests/google/cloud/operators/test_dataproc.py b/providers/tests/google/cloud/operators/test_dataproc.py
index 5d4a9b0d79c..f79a0bdba0c 100644
--- a/providers/tests/google/cloud/operators/test_dataproc.py
+++ b/providers/tests/google/cloud/operators/test_dataproc.py
@@ -2356,6 +2356,242 @@ class TestDataprocWorkflowTemplateInstantiateInlineOperator:
         )
         mock_op.return_value.result.assert_not_called()
 
+    @mock.patch("airflow.providers.google.cloud.openlineage.utils._is_openlineage_provider_accessible")
+    @mock.patch(DATAPROC_PATH.format("DataprocHook"))
+    def test_execute_openlineage_parent_job_info_injection(self, mock_hook, mock_ol_accessible):
+        mock_ol_accessible.return_value = True
+        context = {
+            "ti": MagicMock(
+                dag_id="dag_id",
+                task_id="task_id",
+                try_number=1,
+                map_index=1,
+                logical_date=dt.datetime(2024, 11, 11),
+            )
+        }
+        template = {
+            "id": "test-workflow",
+            "placement": {
+                "cluster_selector": {
+                    "zone": "europe-central2-c",
+                    "cluster_labels": {"key": "value"},
+                }
+            },
+            "jobs": [
+                {
+                    "step_id": "job_1",
+                    "pyspark_job": {
+                        "main_python_file_uri": "gs://bucket1/spark_job.py",
+                        "properties": {
+                            "spark.sql.shuffle.partitions": "1",
+                        },
+                    },
+                },
+                {
+                    "step_id": "job_2",
+                    "pyspark_job": {
+                        "main_python_file_uri": "gs://bucket2/spark_job.py",
+                        "properties": {
+                            "spark.sql.shuffle.partitions": "1",
+                            "spark.openlineage.parentJobNamespace": "test",
+                        },
+                    },
+                },
+                {
+                    "step_id": "job_3",
+                    "hive_job": {
+                        "main_python_file_uri": "gs://bucket3/hive_job.py",
+                        "properties": {
+                            "spark.sql.shuffle.partitions": "1",
+                        },
+                    },
+                },
+            ],
+            "parameters": [
+                {
+                    "name": "ZONE",
+                    "fields": [
+                        "placement.clusterSelector.zone",
+                    ],
+                }
+            ],
+        }
+        expected_template = {
+            "id": "test-workflow",
+            "placement": {
+                "cluster_selector": {
+                    "zone": "europe-central2-c",
+                    "cluster_labels": {"key": "value"},
+                }
+            },
+            "jobs": [
+                {
+                    "step_id": "job_1",
+                    "pyspark_job": {
+                        "main_python_file_uri": "gs://bucket1/spark_job.py",
+                        "properties": {  # Injected properties
+                            "spark.sql.shuffle.partitions": "1",
+                            "spark.openlineage.parentJobName": "dag_id.task_id",
+                            "spark.openlineage.parentJobNamespace": "default",
+                            "spark.openlineage.parentRunId": "01931885-2800-7be7-aa8d-aaa15c337267",
+                        },
+                    },
+                },
+                {
+                    "step_id": "job_2",
+                    "pyspark_job": {  # Not modified because it's already present
+                        "main_python_file_uri": "gs://bucket2/spark_job.py",
+                        "properties": {
+                            "spark.sql.shuffle.partitions": "1",
+                            "spark.openlineage.parentJobNamespace": "test",
+                        },
+                    },
+                },
+                {
+                    "step_id": "job_3",
+                    "hive_job": {  # Not modified because it's unsupported job type
+                        "main_python_file_uri": "gs://bucket3/hive_job.py",
+                        "properties": {
+                            "spark.sql.shuffle.partitions": "1",
+                        },
+                    },
+                },
+            ],
+            "parameters": [
+                {
+                    "name": "ZONE",
+                    "fields": [
+                        "placement.clusterSelector.zone",
+                    ],
+                }
+            ],
+        }
+
+        op = DataprocInstantiateInlineWorkflowTemplateOperator(
+            task_id=TASK_ID,
+            template=template,
+            region=GCP_REGION,
+            project_id=GCP_PROJECT,
+            request_id=REQUEST_ID,
+            retry=RETRY,
+            timeout=TIMEOUT,
+            metadata=METADATA,
+            gcp_conn_id=GCP_CONN_ID,
+            impersonation_chain=IMPERSONATION_CHAIN,
+            openlineage_inject_parent_job_info=True,
+        )
+        op.execute(context=context)
+        mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN)
+        mock_hook.return_value.instantiate_inline_workflow_template.assert_called_once_with(
+            template=expected_template,
+            region=GCP_REGION,
+            project_id=GCP_PROJECT,
+            request_id=REQUEST_ID,
+            retry=RETRY,
+            timeout=TIMEOUT,
+            metadata=METADATA,
+        )
+
+    @mock.patch("airflow.providers.google.cloud.openlineage.utils._is_openlineage_provider_accessible")
+    @mock.patch(DATAPROC_PATH.format("DataprocHook"))
+    def test_execute_openlineage_parent_job_info_injection_skipped_by_default_unless_enabled(
+        self, mock_hook, mock_ol_accessible
+    ):
+        mock_ol_accessible.return_value = True
+
+        template = {
+            "id": "test-workflow",
+            "placement": {
+                "cluster_selector": {
+                    "zone": "europe-central2-c",
+                    "cluster_labels": {"key": "value"},
+                }
+            },
+            "jobs": [
+                {
+                    "step_id": "job_1",
+                    "pyspark_job": {
+                        "main_python_file_uri": "gs://bucket1/spark_job.py",
+                    },
+                }
+            ],
+        }
+
+        op = DataprocInstantiateInlineWorkflowTemplateOperator(
+            task_id=TASK_ID,
+            template=template,
+            region=GCP_REGION,
+            project_id=GCP_PROJECT,
+            request_id=REQUEST_ID,
+            retry=RETRY,
+            timeout=TIMEOUT,
+            metadata=METADATA,
+            gcp_conn_id=GCP_CONN_ID,
+            impersonation_chain=IMPERSONATION_CHAIN,
+            # not passing openlineage_inject_parent_job_info, should be False by default
+        )
+        op.execute(context=MagicMock())
+        mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN)
+        mock_hook.return_value.instantiate_inline_workflow_template.assert_called_once_with(
+            template=template,
+            region=GCP_REGION,
+            project_id=GCP_PROJECT,
+            request_id=REQUEST_ID,
+            retry=RETRY,
+            timeout=TIMEOUT,
+            metadata=METADATA,
+        )
+
+    @mock.patch("airflow.providers.google.cloud.openlineage.utils._is_openlineage_provider_accessible")
+    @mock.patch(DATAPROC_PATH.format("DataprocHook"))
+    def test_execute_openlineage_parent_job_info_injection_skipped_when_ol_not_accessible(
+        self, mock_hook, mock_ol_accessible
+    ):
+        mock_ol_accessible.return_value = False
+
+        template = {
+            "id": "test-workflow",
+            "placement": {
+                "cluster_selector": {
+                    "zone": "europe-central2-c",
+                    "cluster_labels": {"key": "value"},
+                }
+            },
+            "jobs": [
+                {
+                    "step_id": "job_1",
+                    "pyspark_job": {
+                        "main_python_file_uri": "gs://bucket1/spark_job.py",
+                    },
+                }
+            ],
+        }
+
+        op = DataprocInstantiateInlineWorkflowTemplateOperator(
+            task_id=TASK_ID,
+            template=template,
+            region=GCP_REGION,
+            project_id=GCP_PROJECT,
+            request_id=REQUEST_ID,
+            retry=RETRY,
+            timeout=TIMEOUT,
+            metadata=METADATA,
+            gcp_conn_id=GCP_CONN_ID,
+            impersonation_chain=IMPERSONATION_CHAIN,
+            openlineage_inject_parent_job_info=True,
+        )
+        op.execute(context=MagicMock())
+        mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN)
+        mock_hook.return_value.instantiate_inline_workflow_template.assert_called_once_with(
+            template=template,
+            region=GCP_REGION,
+            project_id=GCP_PROJECT,
+            request_id=REQUEST_ID,
+            retry=RETRY,
+            timeout=TIMEOUT,
+            metadata=METADATA,
+        )
+
 
 @pytest.mark.db_test
 @pytest.mark.need_serialized_dag
