This is an automated email from the ASF dual-hosted git repository.
phanikumv pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 88cef4f5986 Validate dag_id and group_id in sdk (#48613)
88cef4f5986 is described below
commit 88cef4f5986cd7798c463cd4444bfbf4257b1470
Author: Tzu-ping Chung <[email protected]>
AuthorDate: Tue Apr 1 18:56:52 2025 +0800
Validate dag_id and group_id in sdk (#48613)
---
airflow-core/tests/unit/models/test_backfill.py | 2 +-
.../src/airflow/sdk/definitions/_internal/node.py | 18 ++++++--
task-sdk/src/airflow/sdk/definitions/dag.py | 3 +-
task-sdk/src/airflow/sdk/definitions/taskgroup.py | 9 +++-
task-sdk/tests/task_sdk/definitions/test_dag.py | 29 ++++++++++++
.../tests/task_sdk/definitions/test_taskgroup.py | 53 ++++++++++++++++++++++
6 files changed, 107 insertions(+), 7 deletions(-)
diff --git a/airflow-core/tests/unit/models/test_backfill.py b/airflow-core/tests/unit/models/test_backfill.py
index b7835738c02..a67e3d958d8 100644
--- a/airflow-core/tests/unit/models/test_backfill.py
+++ b/airflow-core/tests/unit/models/test_backfill.py
@@ -177,7 +177,7 @@ def test_reprocess_behavior(reprocess_behavior, num_in_b, exc_reasons, dag_maker
     # introduce runs for a dag different from the test dag
     # so that we can verify that queries won't pick up runs from
     # other dags with same date
-    with dag_maker(schedule="@daily", dag_id="noise dag"):
+    with dag_maker(schedule="@daily", dag_id="noise-dag"):
         PythonOperator(task_id="hi", python_callable=print)
     date = "2021-01-06"
     dr = dag_maker.create_dagrun(
diff --git a/task-sdk/src/airflow/sdk/definitions/_internal/node.py b/task-sdk/src/airflow/sdk/definitions/_internal/node.py
index 05968290d7b..1a87477185d 100644
--- a/task-sdk/src/airflow/sdk/definitions/_internal/node.py
+++ b/task-sdk/src/airflow/sdk/definitions/_internal/node.py
@@ -46,12 +46,24 @@ def validate_key(k: str, max_length: int = 250):
     """Validate value used as a key."""
     if not isinstance(k, str):
         raise TypeError(f"The key has to be a string and is {type(k)}:{k}")
-    if len(k) > max_length:
-        raise ValueError(f"The key has to be less than {max_length} characters")
+    if (length := len(k)) > max_length:
+        raise ValueError(f"The key has to be less than {max_length} characters, not {length}")
     if not KEY_REGEX.match(k):
         raise ValueError(
             f"The key {k!r} has to be made of alphanumeric characters, dashes, "
-            "dots and underscores exclusively"
+            "dots, and underscores exclusively"
+        )
+
+
+def validate_group_key(k: str, max_length: int = 200):
+    """Validate value used as a group key."""
+    if not isinstance(k, str):
+        raise TypeError(f"The key has to be a string and is {type(k)}:{k}")
+    if (length := len(k)) > max_length:
+        raise ValueError(f"The key has to be less than {max_length} characters, not {length}")
+    if not GROUP_KEY_REGEX.match(k):
+        raise ValueError(
+            f"The key {k!r} has to be made of alphanumeric characters, dashes, and underscores exclusively"
         )
 
 
diff --git a/task-sdk/src/airflow/sdk/definitions/dag.py b/task-sdk/src/airflow/sdk/definitions/dag.py
index b7249587a23..75ea3055aa1 100644
--- a/task-sdk/src/airflow/sdk/definitions/dag.py
+++ b/task-sdk/src/airflow/sdk/definitions/dag.py
@@ -53,6 +53,7 @@ from airflow.exceptions import (
 )
 from airflow.sdk.bases.operator import BaseOperator
 from airflow.sdk.definitions._internal.abstractoperator import AbstractOperator
+from airflow.sdk.definitions._internal.node import validate_key
 from airflow.sdk.definitions._internal.types import NOTSET
 from airflow.sdk.definitions.asset import AssetAll, BaseAsset
 from airflow.sdk.definitions.context import Context
@@ -375,7 +376,7 @@ class DAG:
 
     # NOTE: When updating arguments here, please also keep arguments in @dag()
     # below in sync. (Search for 'def dag(' in this file.)
-    dag_id: str = attrs.field(kw_only=False, validator=attrs.validators.instance_of(str))
+    dag_id: str = attrs.field(kw_only=False, validator=lambda i, a, v: validate_key(v))
     description: str | None = attrs.field(
         default=None,
         validator=attrs.validators.optional(attrs.validators.instance_of(str)),
diff --git a/task-sdk/src/airflow/sdk/definitions/taskgroup.py b/task-sdk/src/airflow/sdk/definitions/taskgroup.py
index 65ae3f4c2c4..ec06e4bf368 100644
--- a/task-sdk/src/airflow/sdk/definitions/taskgroup.py
+++ b/task-sdk/src/airflow/sdk/definitions/taskgroup.py
@@ -36,7 +36,7 @@ from airflow.exceptions import (
     DuplicateTaskIdFound,
     TaskAlreadyInTaskGroup,
 )
-from airflow.sdk.definitions._internal.node import DAGNode
+from airflow.sdk.definitions._internal.node import DAGNode, validate_group_key
 from airflow.utils.trigger_rule import TriggerRule
 
 if TYPE_CHECKING:
@@ -73,6 +73,11 @@ def _default_dag(instance: TaskGroup):
     return DagContext.get_current()
 
 
+# Mypy does not like a lambda for some reason. An explicit annotated function makes it happy.
+def _validate_group_id(instance, attribute, value: str) -> None:
+    validate_group_key(value)
+
+
 @attrs.define(repr=False)
 class TaskGroup(DAGNode):
     """
@@ -106,7 +111,7 @@ class TaskGroup(DAGNode):
     """
 
     _group_id: str | None = attrs.field(
-        validator=attrs.validators.optional(attrs.validators.instance_of(str)),
+        validator=attrs.validators.optional(_validate_group_id),
         # This is the default behaviour for attrs, but by specifying this it makes IDEs happier
         alias="group_id",
     )
diff --git a/task-sdk/tests/task_sdk/definitions/test_dag.py b/task-sdk/tests/task_sdk/definitions/test_dag.py
index 6c00f69653b..b97ac29c1c8 100644
--- a/task-sdk/tests/task_sdk/definitions/test_dag.py
+++ b/task-sdk/tests/task_sdk/definitions/test_dag.py
@@ -31,6 +31,35 @@ DEFAULT_DATE = datetime(2016, 1, 1, tzinfo=timezone.utc)
 
 
 class TestDag:
+    @pytest.mark.parametrize(
+        "dag_id, exc_type, exc_value",
+        [
+            pytest.param(
+                123,
+                TypeError,
+                "The key has to be a string and is <class 'int'>:123",
+                id="type",
+            ),
+            pytest.param(
+                "a" * 1000,
+                ValueError,
+                "The key has to be less than 250 characters, not 1000",
+                id="long",
+            ),
+            pytest.param(
+                "something*invalid",
+                ValueError,
+                "The key 'something*invalid' has to be made of alphanumeric characters, dashes, "
+                "dots, and underscores exclusively",
+                id="illegal",
+            ),
+        ],
+    )
+    def test_dag_id_validation(self, dag_id, exc_type, exc_value):
+        with pytest.raises(exc_type) as ctx:
+            DAG(dag_id)
+        assert str(ctx.value) == exc_value
+
     def test_dag_topological_sort_dag_without_tasks(self):
         dag = DAG("dag", schedule=None, start_date=DEFAULT_DATE, default_args={"owner": "owner1"})
 
diff --git a/task-sdk/tests/task_sdk/definitions/test_taskgroup.py b/task-sdk/tests/task_sdk/definitions/test_taskgroup.py
new file mode 100644
index 00000000000..4198ac999eb
--- /dev/null
+++ b/task-sdk/tests/task_sdk/definitions/test_taskgroup.py
@@ -0,0 +1,53 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+import pytest
+
+from airflow.sdk.definitions.taskgroup import TaskGroup
+
+
+class TestTaskGroup:
+    @pytest.mark.parametrize(
+        "group_id, exc_type, exc_value",
+        [
+            pytest.param(
+                123,
+                TypeError,
+                "The key has to be a string and is <class 'int'>:123",
+                id="type",
+            ),
+            pytest.param(
+                "a" * 1000,
+                ValueError,
+                "The key has to be less than 200 characters, not 1000",
+                id="long",
+            ),
+            pytest.param(
+                "something*invalid",
+                ValueError,
+                "The key 'something*invalid' has to be made of alphanumeric characters, dashes, "
+                "and underscores exclusively",
+                id="illegal",
+            ),
+        ],
+    )
+    def test_group_id_validation(self, group_id, exc_type, exc_value):
+        with pytest.raises(exc_type) as ctx:
+            TaskGroup(group_id)
+        assert str(ctx.value) == exc_value