This is an automated email from the ASF dual-hosted git repository.

phanikumv pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 88cef4f5986 Validate dag_id and group_id in sdk (#48613)
88cef4f5986 is described below

commit 88cef4f5986cd7798c463cd4444bfbf4257b1470
Author: Tzu-ping Chung <[email protected]>
AuthorDate: Tue Apr 1 18:56:52 2025 +0800

    Validate dag_id and group_id in sdk (#48613)
---
 airflow-core/tests/unit/models/test_backfill.py    |  2 +-
 .../src/airflow/sdk/definitions/_internal/node.py  | 18 ++++++--
 task-sdk/src/airflow/sdk/definitions/dag.py        |  3 +-
 task-sdk/src/airflow/sdk/definitions/taskgroup.py  |  9 +++-
 task-sdk/tests/task_sdk/definitions/test_dag.py    | 29 ++++++++++++
 .../tests/task_sdk/definitions/test_taskgroup.py   | 53 ++++++++++++++++++++++
 6 files changed, 107 insertions(+), 7 deletions(-)

diff --git a/airflow-core/tests/unit/models/test_backfill.py b/airflow-core/tests/unit/models/test_backfill.py
index b7835738c02..a67e3d958d8 100644
--- a/airflow-core/tests/unit/models/test_backfill.py
+++ b/airflow-core/tests/unit/models/test_backfill.py
@@ -177,7 +177,7 @@ def test_reprocess_behavior(reprocess_behavior, num_in_b, exc_reasons, dag_maker
     # introduce runs for a dag different from the test dag
     # so that we can verify that queries won't pick up runs from
     # other dags with same date
-    with dag_maker(schedule="@daily", dag_id="noise dag"):
+    with dag_maker(schedule="@daily", dag_id="noise-dag"):
         PythonOperator(task_id="hi", python_callable=print)
     date = "2021-01-06"
     dr = dag_maker.create_dagrun(
diff --git a/task-sdk/src/airflow/sdk/definitions/_internal/node.py b/task-sdk/src/airflow/sdk/definitions/_internal/node.py
index 05968290d7b..1a87477185d 100644
--- a/task-sdk/src/airflow/sdk/definitions/_internal/node.py
+++ b/task-sdk/src/airflow/sdk/definitions/_internal/node.py
@@ -46,12 +46,24 @@ def validate_key(k: str, max_length: int = 250):
     """Validate value used as a key."""
     if not isinstance(k, str):
         raise TypeError(f"The key has to be a string and is {type(k)}:{k}")
-    if len(k) > max_length:
-        raise ValueError(f"The key has to be less than {max_length} characters")
+    if (length := len(k)) > max_length:
+        raise ValueError(f"The key has to be less than {max_length} characters, not {length}")
     if not KEY_REGEX.match(k):
         raise ValueError(
             f"The key {k!r} has to be made of alphanumeric characters, dashes, "
-            "dots and underscores exclusively"
+            "dots, and underscores exclusively"
+        )
+
+
+def validate_group_key(k: str, max_length: int = 200):
+    """Validate value used as a group key."""
+    if not isinstance(k, str):
+        raise TypeError(f"The key has to be a string and is {type(k)}:{k}")
+    if (length := len(k)) > max_length:
+        raise ValueError(f"The key has to be less than {max_length} characters, not {length}")
+    if not GROUP_KEY_REGEX.match(k):
+        raise ValueError(
+            f"The key {k!r} has to be made of alphanumeric characters, dashes, and underscores exclusively"
         )
 
 
diff --git a/task-sdk/src/airflow/sdk/definitions/dag.py b/task-sdk/src/airflow/sdk/definitions/dag.py
index b7249587a23..75ea3055aa1 100644
--- a/task-sdk/src/airflow/sdk/definitions/dag.py
+++ b/task-sdk/src/airflow/sdk/definitions/dag.py
@@ -53,6 +53,7 @@ from airflow.exceptions import (
 )
 from airflow.sdk.bases.operator import BaseOperator
 from airflow.sdk.definitions._internal.abstractoperator import AbstractOperator
+from airflow.sdk.definitions._internal.node import validate_key
 from airflow.sdk.definitions._internal.types import NOTSET
 from airflow.sdk.definitions.asset import AssetAll, BaseAsset
 from airflow.sdk.definitions.context import Context
@@ -375,7 +376,7 @@ class DAG:
 
     # NOTE: When updating arguments here, please also keep arguments in @dag()
     # below in sync. (Search for 'def dag(' in this file.)
-    dag_id: str = attrs.field(kw_only=False, validator=attrs.validators.instance_of(str))
+    dag_id: str = attrs.field(kw_only=False, validator=lambda i, a, v: validate_key(v))
     description: str | None = attrs.field(
         default=None,
         validator=attrs.validators.optional(attrs.validators.instance_of(str)),
diff --git a/task-sdk/src/airflow/sdk/definitions/taskgroup.py b/task-sdk/src/airflow/sdk/definitions/taskgroup.py
index 65ae3f4c2c4..ec06e4bf368 100644
--- a/task-sdk/src/airflow/sdk/definitions/taskgroup.py
+++ b/task-sdk/src/airflow/sdk/definitions/taskgroup.py
@@ -36,7 +36,7 @@ from airflow.exceptions import (
     DuplicateTaskIdFound,
     TaskAlreadyInTaskGroup,
 )
-from airflow.sdk.definitions._internal.node import DAGNode
+from airflow.sdk.definitions._internal.node import DAGNode, validate_group_key
 from airflow.utils.trigger_rule import TriggerRule
 
 if TYPE_CHECKING:
@@ -73,6 +73,11 @@ def _default_dag(instance: TaskGroup):
     return DagContext.get_current()
 
 
+# Mypy does not like a lambda for some reason. An explicit annotated function makes it happy.
+def _validate_group_id(instance, attribute, value: str) -> None:
+    validate_group_key(value)
+
+
 @attrs.define(repr=False)
 class TaskGroup(DAGNode):
     """
@@ -106,7 +111,7 @@ class TaskGroup(DAGNode):
     """
 
     _group_id: str | None = attrs.field(
-        validator=attrs.validators.optional(attrs.validators.instance_of(str)),
+        validator=attrs.validators.optional(_validate_group_id),
         # This is the default behaviour for attrs, but by specifying this it makes IDEs happier
         alias="group_id",
     )
diff --git a/task-sdk/tests/task_sdk/definitions/test_dag.py b/task-sdk/tests/task_sdk/definitions/test_dag.py
index 6c00f69653b..b97ac29c1c8 100644
--- a/task-sdk/tests/task_sdk/definitions/test_dag.py
+++ b/task-sdk/tests/task_sdk/definitions/test_dag.py
@@ -31,6 +31,35 @@ DEFAULT_DATE = datetime(2016, 1, 1, tzinfo=timezone.utc)
 
 
 class TestDag:
+    @pytest.mark.parametrize(
+        "dag_id, exc_type, exc_value",
+        [
+            pytest.param(
+                123,
+                TypeError,
+                "The key has to be a string and is <class 'int'>:123",
+                id="type",
+            ),
+            pytest.param(
+                "a" * 1000,
+                ValueError,
+                "The key has to be less than 250 characters, not 1000",
+                id="long",
+            ),
+            pytest.param(
+                "something*invalid",
+                ValueError,
+                "The key 'something*invalid' has to be made of alphanumeric characters, dashes, "
+                "dots, and underscores exclusively",
+                id="illegal",
+            ),
+        ],
+    )
+    def test_dag_id_validation(self, dag_id, exc_type, exc_value):
+        with pytest.raises(exc_type) as ctx:
+            DAG(dag_id)
+        assert str(ctx.value) == exc_value
+
     def test_dag_topological_sort_dag_without_tasks(self):
         dag = DAG("dag", schedule=None, start_date=DEFAULT_DATE, default_args={"owner": "owner1"})
 
diff --git a/task-sdk/tests/task_sdk/definitions/test_taskgroup.py b/task-sdk/tests/task_sdk/definitions/test_taskgroup.py
new file mode 100644
index 00000000000..4198ac999eb
--- /dev/null
+++ b/task-sdk/tests/task_sdk/definitions/test_taskgroup.py
@@ -0,0 +1,53 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+import pytest
+
+from airflow.sdk.definitions.taskgroup import TaskGroup
+
+
+class TestTaskGroup:
+    @pytest.mark.parametrize(
+        "group_id, exc_type, exc_value",
+        [
+            pytest.param(
+                123,
+                TypeError,
+                "The key has to be a string and is <class 'int'>:123",
+                id="type",
+            ),
+            pytest.param(
+                "a" * 1000,
+                ValueError,
+                "The key has to be less than 200 characters, not 1000",
+                id="long",
+            ),
+            pytest.param(
+                "something*invalid",
+                ValueError,
+                "The key 'something*invalid' has to be made of alphanumeric characters, dashes, "
+                "and underscores exclusively",
+                id="illegal",
+            ),
+        ],
+    )
+    def test_group_id_validation(self, group_id, exc_type, exc_value):
+        with pytest.raises(exc_type) as ctx:
+            TaskGroup(group_id)
+        assert str(ctx.value) == exc_value

Reply via email to