This is an automated email from the ASF dual-hosted git repository.
eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 4bc7ffb86df feat: allow to set task/dag labels for
`DataprocCreateBatchOperator` (#46781)
4bc7ffb86df is described below
commit 4bc7ffb86df0515bee5703ffe4d2be97412ddb5d
Author: Nikita Trush <[email protected]>
AuthorDate: Sun Mar 9 08:28:59 2025 -0700
feat: allow to set task/dag labels for `DataprocCreateBatchOperator`
(#46781)
* Set task/dag labels for Dataproc Batch
* Adding airflow-dag-display-name label support
---
.../providers/google/cloud/operators/dataproc.py | 27 ++++++
.../unit/google/cloud/operators/test_dataproc.py | 108 ++++++++++++++++++++-
2 files changed, 134 insertions(+), 1 deletion(-)
diff --git
a/providers/google/src/airflow/providers/google/cloud/operators/dataproc.py
b/providers/google/src/airflow/providers/google/cloud/operators/dataproc.py
index eff64faff87..4486ad8e762 100644
--- a/providers/google/src/airflow/providers/google/cloud/operators/dataproc.py
+++ b/providers/google/src/airflow/providers/google/cloud/operators/dataproc.py
@@ -2530,6 +2530,8 @@ class
DataprocCreateBatchOperator(GoogleCloudBaseOperator):
self.log.info("Automatic injection of OpenLineage information into
Spark properties is enabled.")
self._inject_openlineage_properties_into_dataproc_batch(context)
+ self.__update_batch_labels()
+
try:
self.operation = self.hook.create_batch(
region=self.region,
@@ -2710,6 +2712,31 @@ class
DataprocCreateBatchOperator(GoogleCloudBaseOperator):
exc_info=e,
)
+ def __update_batch_labels(self):
+ dag_id = re.sub(r"[.\s]", "_", self.dag_id.lower())
+ task_id = re.sub(r"[.\s]", "_", self.task_id.lower())
+
+ labels_regex = re.compile(r"^[a-z][\w-]{0,63}$")
+ if not labels_regex.match(dag_id) or not labels_regex.match(task_id):
+ return
+
+ labels_limit = 32
+ new_labels = {"airflow-dag-id": dag_id, "airflow-task-id": task_id}
+
+ if self._dag:
+ dag_display_name = re.sub(r"[.\s]", "_",
self._dag.dag_display_name.lower())
+ if labels_regex.match(dag_id):
+ new_labels["airflow-dag-display-name"] = dag_display_name
+
+ if isinstance(self.batch, Batch):
+ if len(self.batch.labels) + len(new_labels) <= labels_limit:
+ self.batch.labels.update(new_labels)
+ elif "labels" not in self.batch:
+ self.batch["labels"] = new_labels
+ elif isinstance(self.batch.get("labels"), dict):
+ if len(self.batch["labels"]) + len(new_labels) <= labels_limit:
+ self.batch["labels"].update(new_labels)
+
class DataprocDeleteBatchOperator(GoogleCloudBaseOperator):
"""
diff --git
a/providers/google/tests/unit/google/cloud/operators/test_dataproc.py
b/providers/google/tests/unit/google/cloud/operators/test_dataproc.py
index f15de603053..7eacfcf4ca4 100644
--- a/providers/google/tests/unit/google/cloud/operators/test_dataproc.py
+++ b/providers/google/tests/unit/google/cloud/operators/test_dataproc.py
@@ -18,8 +18,9 @@ from __future__ import annotations
import datetime as dt
import inspect
+from copy import deepcopy
from unittest import mock
-from unittest.mock import MagicMock, Mock, call
+from unittest.mock import ANY, MagicMock, Mock, call
import pytest
from google.api_core.exceptions import AlreadyExists, NotFound
@@ -3775,6 +3776,111 @@ class TestDataprocCreateBatchOperator:
metadata=METADATA,
)
+ # Helper: assert that the mocked hook's create_batch was called exactly once
+ # with batch == expected_batch; every other keyword argument is matched with ANY.
+ @staticmethod
+ def __assert_batch_create(mock_hook, expected_batch):
+ mock_hook.return_value.create_batch.assert_called_once_with(
+ region=ANY,
+ project_id=ANY,
+ batch=expected_batch,
+ batch_id=ANY,
+ request_id=ANY,
+ retry=ANY,
+ timeout=ANY,
+ metadata=ANY,
+ )
+
+ @mock.patch(DATAPROC_PATH.format("Batch.to_dict"))
+ @mock.patch(DATAPROC_PATH.format("DataprocHook"))
+ def test_create_batch_asdict_labels_updated(self, mock_hook, to_dict_mock):
+ # A dict batch is expected to reach create_batch with the three Airflow
+ # labels (dag id, dag display name, task id) under its "labels" key.
+ expected_labels = {
+ "airflow-dag-id": "test_dag",
+ "airflow-dag-display-name": "test_dag",
+ "airflow-task-id": "test-task",
+ }
+
+ # Expected payload: the BATCH entries plus the labels above.
+ expected_batch = {
+ **BATCH,
+ "labels": expected_labels,
+ }
+
+ DataprocCreateBatchOperator(
+ task_id="test-task",
+ dag=DAG(dag_id="test_dag"),
+ batch=BATCH,
+ region=GCP_REGION,
+ ).execute(context=EXAMPLE_CONTEXT)
+
+ TestDataprocCreateBatchOperator.__assert_batch_create(mock_hook,
expected_batch)
+
+ @mock.patch(DATAPROC_PATH.format("Batch.to_dict"))
+ @mock.patch(DATAPROC_PATH.format("DataprocHook"))
+ def test_create_batch_asdict_labels_uppercase_transformed(self, mock_hook,
to_dict_mock):
+ # Upper-case characters in the task id ("test-TASK") and dag id ("Test_dag")
+ # are expected to appear lower-cased in the label values.
+ expected_labels = {
+ "airflow-dag-id": "test_dag",
+ "airflow-dag-display-name": "test_dag",
+ "airflow-task-id": "test-task",
+ }
+
+ expected_batch = {
+ **BATCH,
+ "labels": expected_labels,
+ }
+
+ DataprocCreateBatchOperator(
+ task_id="test-TASK",
+ dag=DAG(dag_id="Test_dag"),
+ batch=BATCH,
+ region=GCP_REGION,
+ ).execute(context=EXAMPLE_CONTEXT)
+
+ TestDataprocCreateBatchOperator.__assert_batch_create(mock_hook,
expected_batch)
+
+ @mock.patch(DATAPROC_PATH.format("Batch.to_dict"))
+ @mock.patch(DATAPROC_PATH.format("DataprocHook"))
+ def test_create_batch_invalid_taskid_labels_ignored(self, mock_hook,
to_dict_mock):
+ # A task id starting with "." is used as an id that cannot become a label
+ # value; the batch is expected to be submitted without added labels.
+ # NOTE(review): BATCH is both the operator input and the expected value. If
+ # the operator keeps a reference to the dict and mutates it in place, this
+ # assertion compares the object with itself -- consider deepcopy(BATCH).
+ DataprocCreateBatchOperator(
+ task_id=".task-id",
+ dag=DAG(dag_id="test-dag"),
+ batch=BATCH,
+ region=GCP_REGION,
+ ).execute(context=EXAMPLE_CONTEXT)
+
+ TestDataprocCreateBatchOperator.__assert_batch_create(mock_hook, BATCH)
+
+ @mock.patch(DATAPROC_PATH.format("Batch.to_dict"))
+ @mock.patch(DATAPROC_PATH.format("DataprocHook"))
+ def test_create_batch_long_taskid_labels_ignored(self, mock_hook,
to_dict_mock):
+ # A 65-character task id is used as an over-long id; the batch is expected
+ # to be submitted without added labels.
+ # NOTE(review): BATCH is both the operator input and the expected value. If
+ # the operator keeps a reference to the dict and mutates it in place, this
+ # assertion compares the object with itself -- consider deepcopy(BATCH).
+ DataprocCreateBatchOperator(
+ task_id="a" * 65,
+ dag=DAG(dag_id="test-dag"),
+ batch=BATCH,
+ region=GCP_REGION,
+ ).execute(context=EXAMPLE_CONTEXT)
+
+ TestDataprocCreateBatchOperator.__assert_batch_create(mock_hook, BATCH)
+
+ @mock.patch(DATAPROC_PATH.format("Batch.to_dict"))
+ @mock.patch(DATAPROC_PATH.format("DataprocHook"))
+ def test_create_batch_asobj_labels_updated(self, mock_hook, to_dict_mock):
+ # Batch passed as a Batch object that already carries one label ("foo");
+ # the Airflow labels are expected to be added alongside it.
+ batch = Batch(name="test")
+ batch.labels["foo"] = "bar"
+ dag = DAG(dag_id="test_dag")
+
+ expected_labels = {
+ "airflow-dag-id": "test_dag",
+ "airflow-dag-display-name": "test_dag",
+ "airflow-task-id": "test-task",
+ }
+
+ # Build the expectation from a deep copy so it is independent of the
+ # batch object handed to the operator.
+ expected_batch = deepcopy(batch)
+ expected_batch.labels.update(expected_labels)
+
+ DataprocCreateBatchOperator(task_id="test-task", batch=batch,
region=GCP_REGION, dag=dag).execute(
+ context=EXAMPLE_CONTEXT
+ )
+
+ TestDataprocCreateBatchOperator.__assert_batch_create(mock_hook,
expected_batch)
+
class TestDataprocDeleteBatchOperator:
@mock.patch(DATAPROC_PATH.format("DataprocHook"))