This is an automated email from the ASF dual-hosted git repository.

eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 4bc7ffb86df feat: allow to set task/dag labels for 
`DataprocCreateBatchOperator` (#46781)
4bc7ffb86df is described below

commit 4bc7ffb86df0515bee5703ffe4d2be97412ddb5d
Author: Nikita Trush <[email protected]>
AuthorDate: Sun Mar 9 08:28:59 2025 -0700

    feat: allow to set task/dag labels for `DataprocCreateBatchOperator` 
(#46781)
    
    * Set task/dag labels for Dataproc Batch
    
    * Adding airflow-dag-display-name label support
---
 .../providers/google/cloud/operators/dataproc.py   |  27 ++++++
 .../unit/google/cloud/operators/test_dataproc.py   | 108 ++++++++++++++++++++-
 2 files changed, 134 insertions(+), 1 deletion(-)

diff --git 
a/providers/google/src/airflow/providers/google/cloud/operators/dataproc.py 
b/providers/google/src/airflow/providers/google/cloud/operators/dataproc.py
index eff64faff87..4486ad8e762 100644
--- a/providers/google/src/airflow/providers/google/cloud/operators/dataproc.py
+++ b/providers/google/src/airflow/providers/google/cloud/operators/dataproc.py
@@ -2530,6 +2530,8 @@ class 
DataprocCreateBatchOperator(GoogleCloudBaseOperator):
             self.log.info("Automatic injection of OpenLineage information into 
Spark properties is enabled.")
             self._inject_openlineage_properties_into_dataproc_batch(context)
 
+        self.__update_batch_labels()
+
         try:
             self.operation = self.hook.create_batch(
                 region=self.region,
@@ -2710,6 +2712,31 @@ class 
DataprocCreateBatchOperator(GoogleCloudBaseOperator):
                 exc_info=e,
             )
 
+    def __update_batch_labels(self):
+        """Add Airflow DAG/task labels to ``self.batch`` when they are valid label values."""
+        dag_id = re.sub(r"[.\s]", "_", self.dag_id.lower())
+        task_id = re.sub(r"[.\s]", "_", self.task_id.lower())
+        # A label value holds at most 63 characters: one leading letter plus up to 62 more.
+        labels_regex = re.compile(r"^[a-z][\w-]{0,62}$")
+        if not labels_regex.match(dag_id) or not labels_regex.match(task_id):
+            return
+        # Adding labels is skipped entirely when it would push the batch past this count.
+        labels_limit = 32
+        new_labels = {"airflow-dag-id": dag_id, "airflow-task-id": task_id}
+        if self._dag:
+            dag_display_name = re.sub(r"[.\s]", "_", self._dag.dag_display_name.lower())
+            # Validate the display name itself: it is free text and may differ from dag_id.
+            if labels_regex.match(dag_display_name):
+                new_labels["airflow-dag-display-name"] = dag_display_name
+
+        if isinstance(self.batch, Batch):
+            if len(self.batch.labels) + len(new_labels) <= labels_limit:
+                self.batch.labels.update(new_labels)
+        elif "labels" not in self.batch:
+            self.batch["labels"] = new_labels
+        elif isinstance(self.batch.get("labels"), dict):
+            if len(self.batch["labels"]) + len(new_labels) <= labels_limit:
+                self.batch["labels"].update(new_labels)
+
 
 class DataprocDeleteBatchOperator(GoogleCloudBaseOperator):
     """
diff --git 
a/providers/google/tests/unit/google/cloud/operators/test_dataproc.py 
b/providers/google/tests/unit/google/cloud/operators/test_dataproc.py
index f15de603053..7eacfcf4ca4 100644
--- a/providers/google/tests/unit/google/cloud/operators/test_dataproc.py
+++ b/providers/google/tests/unit/google/cloud/operators/test_dataproc.py
@@ -18,8 +18,9 @@ from __future__ import annotations
 
 import datetime as dt
 import inspect
+from copy import deepcopy
 from unittest import mock
-from unittest.mock import MagicMock, Mock, call
+from unittest.mock import ANY, MagicMock, Mock, call
 
 import pytest
 from google.api_core.exceptions import AlreadyExists, NotFound
@@ -3775,6 +3776,111 @@ class TestDataprocCreateBatchOperator:
             metadata=METADATA,
         )
 
+    @staticmethod
+    def __assert_batch_create(mock_hook, expected_batch):
+        mock_hook.return_value.create_batch.assert_called_once_with(
+            region=ANY,
+            project_id=ANY,
+            batch=expected_batch,  # the only argument pinned here; every other kwarg matches ANY
+            batch_id=ANY,
+            request_id=ANY,
+            retry=ANY,
+            timeout=ANY,
+            metadata=ANY,
+        )
+
+    @mock.patch(DATAPROC_PATH.format("Batch.to_dict"))
+    @mock.patch(DATAPROC_PATH.format("DataprocHook"))
+    def test_create_batch_asdict_labels_updated(self, mock_hook, to_dict_mock):
+        """A dict batch is submitted with the dag-id, dag-display-name and task-id labels."""
+        expected_labels = {
+            "airflow-dag-id": "test_dag",
+            "airflow-dag-display-name": "test_dag",
+            "airflow-task-id": "test-task",
+        }
+        expected_batch = {
+            **BATCH,
+            "labels": expected_labels,
+        }
+        # Pass a copy: labels are written into the batch in place and BATCH is shared by other tests.
+        DataprocCreateBatchOperator(
+            task_id="test-task",
+            dag=DAG(dag_id="test_dag"),
+            batch=deepcopy(BATCH),
+            region=GCP_REGION,
+        ).execute(context=EXAMPLE_CONTEXT)
+
+        TestDataprocCreateBatchOperator.__assert_batch_create(mock_hook, expected_batch)
+
+    @mock.patch(DATAPROC_PATH.format("Batch.to_dict"))
+    @mock.patch(DATAPROC_PATH.format("DataprocHook"))
+    def test_create_batch_asdict_labels_uppercase_transformed(self, mock_hook, to_dict_mock):
+        """Upper-case characters in dag_id and task_id are lower-cased in the resulting labels."""
+        expected_labels = {
+            "airflow-dag-id": "test_dag",
+            "airflow-dag-display-name": "test_dag",
+            "airflow-task-id": "test-task",
+        }
+        expected_batch = {
+            **BATCH,
+            "labels": expected_labels,
+        }
+        # Pass a copy: labels are written into the batch in place and BATCH is shared by other tests.
+        DataprocCreateBatchOperator(
+            task_id="test-TASK",
+            dag=DAG(dag_id="Test_dag"),
+            batch=deepcopy(BATCH),
+            region=GCP_REGION,
+        ).execute(context=EXAMPLE_CONTEXT)
+
+        TestDataprocCreateBatchOperator.__assert_batch_create(mock_hook, expected_batch)
+
+    @mock.patch(DATAPROC_PATH.format("Batch.to_dict"))
+    @mock.patch(DATAPROC_PATH.format("DataprocHook"))
+    def test_create_batch_invalid_taskid_labels_ignored(self, mock_hook, to_dict_mock):
+        """A task_id that is not label-safe after normalisation must leave the batch without new labels."""
+        DataprocCreateBatchOperator(
+            task_id=".task-id",
+            dag=DAG(dag_id="test-dag"),
+            batch=deepcopy(BATCH),  # a copy, so equality with BATCH is not trivially true
+            region=GCP_REGION,
+        ).execute(context=EXAMPLE_CONTEXT)
+        TestDataprocCreateBatchOperator.__assert_batch_create(mock_hook, BATCH)
+
+    @mock.patch(DATAPROC_PATH.format("Batch.to_dict"))
+    @mock.patch(DATAPROC_PATH.format("DataprocHook"))
+    def test_create_batch_long_taskid_labels_ignored(self, mock_hook, to_dict_mock):
+        """A task_id exceeding the maximum label length must leave the batch without new labels."""
+        DataprocCreateBatchOperator(
+            task_id="a" * 65,
+            dag=DAG(dag_id="test-dag"),
+            batch=deepcopy(BATCH),  # a copy, so equality with BATCH is not trivially true
+            region=GCP_REGION,
+        ).execute(context=EXAMPLE_CONTEXT)
+        TestDataprocCreateBatchOperator.__assert_batch_create(mock_hook, BATCH)
+
+    @mock.patch(DATAPROC_PATH.format("Batch.to_dict"))
+    @mock.patch(DATAPROC_PATH.format("DataprocHook"))
+    def test_create_batch_asobj_labels_updated(self, mock_hook, to_dict_mock):
+        """A ``Batch`` object that already has labels keeps them and gains the Airflow labels."""
+        batch = Batch(name="test")
+        batch.labels["foo"] = "bar"
+        dag = DAG(dag_id="test_dag")
+        expected_labels = {
+            "airflow-dag-id": "test_dag",
+            "airflow-dag-display-name": "test_dag",
+            "airflow-task-id": "test-task",
+        }
+        # Build the expectation from a copy before execute(); the operator updates batch.labels in place.
+        expected_batch = deepcopy(batch)
+        expected_batch.labels.update(expected_labels)
+
+        DataprocCreateBatchOperator(task_id="test-task", batch=batch, 
region=GCP_REGION, dag=dag).execute(
+            context=EXAMPLE_CONTEXT
+        )
+
+        TestDataprocCreateBatchOperator.__assert_batch_create(mock_hook, 
expected_batch)
+
 
 class TestDataprocDeleteBatchOperator:
     @mock.patch(DATAPROC_PATH.format("DataprocHook"))

Reply via email to