This is an automated email from the ASF dual-hosted git repository.

ash pushed a commit to branch task-sdk-first-code
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit a4bd6b47e27699440df04871d200b2df600646dc
Author: Ash Berlin-Taylor <[email protected]>
AuthorDate: Thu Oct 24 12:31:43 2024 +0100

    Update pre-commit scripts.
    
    I have removed some of the checks to reflect the new code structure. Now that
    the `def dag` decorator is defined at runtime as `def dag(**kwargs)`, the
    check on argument defaults is no longer useful or relevant.
---
 scripts/ci/pre_commit/sync_init_decorator.py      | 195 ++++++++++++++--------
 task_sdk/src/airflow/sdk/definitions/dag.py       |  15 +-
 task_sdk/src/airflow/sdk/definitions/taskgroup.py |   6 +
 3 files changed, 136 insertions(+), 80 deletions(-)

diff --git a/scripts/ci/pre_commit/sync_init_decorator.py b/scripts/ci/pre_commit/sync_init_decorator.py
index fc99fc48d7b..13e80d62c6c 100755
--- a/scripts/ci/pre_commit/sync_init_decorator.py
+++ b/scripts/ci/pre_commit/sync_init_decorator.py
@@ -26,38 +26,71 @@ import pathlib
 import sys
 from typing import TYPE_CHECKING
 
-PACKAGE_ROOT = pathlib.Path(__file__).resolve().parents[3].joinpath("airflow")
-DAG_PY = PACKAGE_ROOT.joinpath("models", "dag.py")
-UTILS_TG_PY = PACKAGE_ROOT.joinpath("utils", "task_group.py")
+ROOT = pathlib.Path(__file__).resolve().parents[3]
+PACKAGE_ROOT = ROOT.joinpath("airflow")
+SDK_DEFINITIONS_PKG = ROOT.joinpath("task_sdk", "src", "airflow", "sdk", 
"definitions")
+DAG_PY = SDK_DEFINITIONS_PKG.joinpath("dag.py")
+TG_PY = SDK_DEFINITIONS_PKG.joinpath("taskgroup.py")
 DECOS_TG_PY = PACKAGE_ROOT.joinpath("decorators", "task_group.py")
 
 
-def _find_dag_init(mod: ast.Module) -> ast.FunctionDef:
-    """Find definition of the ``DAG`` class's ``__init__``."""
-    dag_class = next(n for n in ast.iter_child_nodes(mod) if isinstance(n, 
ast.ClassDef) and n.name == "DAG")
-    return next(
-        node
-        for node in ast.iter_child_nodes(dag_class)
-        if isinstance(node, ast.FunctionDef) and node.name == "__init__"
+def _name(node: ast.expr) -> str:
+    if not isinstance(node, ast.Name):
+        raise TypeError("node was not an ast.Name node")
+    return node.id
+
+
+def _find_cls_attrs(
+    mod: ast.Module, class_name: str, ignore: list[str] | None = None
+) -> collections.abc.Iterable[ast.AnnAssign]:
+    """Find the type-annotated/attrs properties in the body of the specified 
class."""
+    dag_class = next(
+        n for n in ast.iter_child_nodes(mod) if isinstance(n, ast.ClassDef) 
and n.name == class_name
     )
 
+    ignore = ignore or []
+
+    for node in ast.iter_child_nodes(dag_class):
+        if not isinstance(node, ast.AnnAssign) or not node.annotation:
+            continue
+
+        # ClassVar[Any]
+        if isinstance(node.annotation, ast.Subscript) and 
_name(node.annotation.value) == "ClassVar":
+            continue
+
+        # Skip private attrs fields, ones with `attrs.field(init=False)` kwargs
+        if isinstance(node.value, ast.Call):
+            # Lazy coding: since init=True is the default, we're just looking 
for the presence of the init
+            # arg name
+            if TYPE_CHECKING:
+                assert isinstance(node.value.func, ast.Attribute)
+            if (
+                node.value.func.attr == "field"
+                and _name(node.value.func.value) == "attrs"
+                and any(arg.arg == "init" for arg in node.value.keywords)
+            ):
+                continue
+        if _name(node.target) in ignore:
+            continue
+
+        # Attrs treats `_group_id: str` as `group_id` arg to __init__
+        if _name(node.target).startswith("_"):
+            node.target.id = node.target.id[1:]  # type: ignore[union-attr]
+        yield node
+
 
 def _find_dag_deco(mod: ast.Module) -> ast.FunctionDef:
     """Find definition of the ``@dag`` decorator."""
-    return next(n for n in ast.iter_child_nodes(mod) if isinstance(n, 
ast.FunctionDef) and n.name == "dag")
-
-
-def _find_tg_init(mod: ast.Module) -> ast.FunctionDef:
-    """Find definition of the ``TaskGroup`` class's ``__init__``."""
-    task_group_class = next(
+    # We now define the signature in a type checking block, the runtime impl 
uses **kwargs
+    type_checking_blocks = (
         node
         for node in ast.iter_child_nodes(mod)
-        if isinstance(node, ast.ClassDef) and node.name == "TaskGroup"
+        if isinstance(node, ast.If) and node.test.id == "TYPE_CHECKING"  # 
type: ignore[attr-defined]
     )
     return next(
-        node
-        for node in ast.iter_child_nodes(task_group_class)
-        if isinstance(node, ast.FunctionDef) and node.name == "__init__"
+        n
+        for n in itertools.chain.from_iterable(map(ast.iter_child_nodes, 
type_checking_blocks))
+        if isinstance(n, ast.FunctionDef) and n.name == "dag"
     )
 
 
@@ -74,43 +107,72 @@ def _find_tg_deco(mod: ast.Module) -> ast.FunctionDef:
     )
 
 
+# Hard-code some specific examples of allowable decorate type annotation -> 
class type annotation mappings
+# where they don't match exactly
+
+
+def _expr_to_ast_dump(expr: str) -> str:
+    return ast.dump(ast.parse(expr).body[0].value)  # type: 
ignore[attr-defined]
+
+
+ALLOWABLE_TYPE_ANNOTATIONS = {
+    _expr_to_ast_dump("Collection[str] | None"): 
_expr_to_ast_dump("MutableSet[str]")
+}
+
+
 def _match_arguments(
-    init_def: tuple[str, list[ast.arg]],
+    init_def: tuple[str, list[ast.AnnAssign]],
     deco_def: tuple[str, list[ast.arg]],
 ) -> collections.abc.Iterator[str]:
     init_name, init_args = init_def
     deco_name, deco_args = deco_def
+    init_args.sort(key=lambda a: _name(a.target))
+    deco_args.sort(key=lambda a: a.arg)
     for i, (ini, dec) in enumerate(itertools.zip_longest(init_args, deco_args, 
fillvalue=None)):
         if ini is None and dec is not None:
             yield f"Argument present in @{deco_name} but missing from 
{init_name}: {dec.arg}"
             return
         if dec is None and ini is not None:
-            yield f"Argument present in {init_name} but missing from 
@{deco_name}: {ini.arg}"
+            yield f"Argument present in {init_name} but missing from 
@{deco_name}: {_name(ini.target)}"
             return
+        if TYPE_CHECKING:
+            # Mypy can't work out that zip_longest means one of ini or dec 
must be non None
+            assert ini is not None
+
+        if not isinstance(ini.target, ast.Name):
+            raise RuntimeError(f"Don't know how to examine 
{ast.unparse(ini)!r}")
+        attr_name = _name(ini.target)
 
         if TYPE_CHECKING:
             assert ini is not None and dec is not None  # Because None is only 
possible as fillvalue.
 
-        if ini.arg != dec.arg:
-            yield f"Argument {i + 1} mismatch: {init_name} has {ini.arg} but 
@{deco_name} has {dec.arg}"
+        if attr_name != dec.arg:
+            yield f"Argument {i + 1} mismatch: {init_name} has {attr_name} but 
@{deco_name} has {dec.arg}"
             return
 
         if getattr(ini, "type_comment", None):  # 3.8+
-            yield f"Do not use type comments on {init_name} argument: 
{ini.arg}"
+            yield f"Do not use type comments on {init_name} argument: {ini}"
         if getattr(dec, "type_comment", None):  # 3.8+
             yield f"Do not use type comments on @{deco_name} argument: 
{dec.arg}"
 
         # Poorly implemented node equality check.
-        if ini.annotation and dec.annotation and ast.dump(ini.annotation) != 
ast.dump(dec.annotation):
-            yield (
-                f"Type annotations differ on argument {ini.arg} between 
{init_name} and @{deco_name}: "
-                f"{ast.unparse(ini.annotation)} != 
{ast.unparse(dec.annotation)}"
-            )
-        else:
-            if not ini.annotation:
-                yield f"Type annotation missing on {init_name} argument: 
{ini.arg}"
-            if not dec.annotation:
-                yield f"Type annotation missing on @{deco_name} argument: 
{ini.arg}"
+        if ini.annotation and dec.annotation:
+            ini_anno = ast.dump(ini.annotation)
+            dec_anno = ast.dump(dec.annotation)
+            if (
+                ini_anno != dec_anno
+                # The decorator can have `| None` type in addaition to the 
base attribute
+                and dec_anno != f"BinOp(left={ini_anno}, op=BitOr(), 
right=Constant(value=None))"
+                and ALLOWABLE_TYPE_ANNOTATIONS.get(dec_anno) != ini_anno
+            ):
+                yield (
+                    f"Type annotations differ on argument {attr_name!r} 
between {init_name} and @{deco_name}: "
+                    f"{ast.unparse(ini.annotation)} != 
{ast.unparse(dec.annotation)}"
+                )
+        elif not ini.annotation:
+            yield f"Type annotation missing on {init_name} argument: 
{attr_name}"
+        elif not dec.annotation:
+            yield f"Type annotation missing on @{deco_name} argument: 
{attr_name}"
 
 
 def _match_defaults(
@@ -130,47 +192,38 @@ def _match_defaults(
 
 def check_dag_init_decorator_arguments() -> int:
     dag_mod = ast.parse(DAG_PY.read_text("utf-8"), str(DAG_PY))
-
-    utils_tg = ast.parse(UTILS_TG_PY.read_text("utf-8"), str(UTILS_TG_PY))
+    tg_mod = ast.parse(TG_PY.read_text("utf-8"), str(TG_PY))
     decos_tg = ast.parse(DECOS_TG_PY.read_text("utf-8"), str(DECOS_TG_PY))
 
     items_to_check = [
-        ("DAG", _find_dag_init(dag_mod), "dag", _find_dag_deco(dag_mod), 
"dag_id", ""),
-        ("TaskGroup", _find_tg_init(utils_tg), "task_group", 
_find_tg_deco(decos_tg), "group_id", None),
+        (
+            "DAG",
+            list(_find_cls_attrs(dag_mod, "DAG", ignore=["full_filepath", 
"task_group"])),
+            "dag",
+            _find_dag_deco(dag_mod),
+            "dag_id",
+            "",
+        ),
+        (
+            "TaskGroup",
+            list(_find_cls_attrs(tg_mod, "TaskGroup")),
+            "_task_group",
+            _find_tg_deco(decos_tg),
+            "group_id",
+            None,
+        ),
     ]
 
-    for init_name, init, deco_name, deco, id_arg, id_default in items_to_check:
-        if getattr(init.args, "posonlyargs", None) or getattr(deco.args, 
"posonlyargs", None):
-            print(f"{init_name} and @{deco_name} should not declare 
positional-only arguments")
-            return -1
-        if init.args.vararg or init.args.kwarg or deco.args.vararg or 
deco.args.kwarg:
-            print(f"{init_name} and @{deco_name} should not declare *args and 
**kwargs")
-            return -1
-
-        # Feel free to change this and make some of the arguments keyword-only!
-        if init.args.kwonlyargs or deco.args.kwonlyargs:
-            print(f"{init_name}() and @{deco_name}() should not declare 
keyword-only arguments")
-            return -2
-        if init.args.kw_defaults or deco.args.kw_defaults:
-            print(f"{init_name}() and @{deco_name}() should not declare 
keyword-only arguments")
-            return -2
-
-        init_arg_names = [a.arg for a in init.args.args]
+    for init_name, cls_attrs, deco_name, deco, id_arg, id_default in 
items_to_check:
         deco_arg_names = [a.arg for a in deco.args.args]
 
-        if init_arg_names[0] != "self":
-            print(f"First argument in {init_name} must be 'self'")
-            return -3
-        if init_arg_names[1] != id_arg:
-            print(f"Second argument in {init_name} must be {id_arg!r}")
+        if _name(cls_attrs[0].target) != id_arg:
+            print(f"First attribute in {init_name} must be {id_arg!r} (got 
{cls_attrs[0]!r})")
             return -3
         if deco_arg_names[0] != id_arg:
-            print(f"First argument in @{deco_name} must be {id_arg!r}")
+            print(f"First argument in @{deco_name} must be {id_arg!r} (got 
{deco_arg_names[0]!r})")
             return -3
 
-        if len(init.args.defaults) != len(init_arg_names) - 2:
-            print(f"All arguments on {init_name} except self and {id_arg} must 
have defaults")
-            return -4
         if len(deco.args.defaults) != len(deco_arg_names):
             print(f"All arguments on @{deco_name} must have defaults")
             return -4
@@ -178,13 +231,11 @@ def check_dag_init_decorator_arguments() -> int:
             print(f"Default {id_arg} on @{deco_name} must be {id_default!r}")
             return -4
 
-    for init_name, init, deco_name, deco, _, _ in items_to_check:
-        errors = list(_match_arguments((init_name, init.args.args[1:]), 
(deco_name, deco.args.args)))
-        if errors:
-            break
-        init_defaults_def = (init_name, init.args.defaults)
-        deco_defaults_def = (deco_name, deco.args.defaults[1:])
-        errors = list(_match_defaults(deco_arg_names, init_defaults_def, 
deco_defaults_def))
+    errors = []
+    for init_name, cls_attrs, deco_name, deco, _, _ in items_to_check:
+        errors = list(
+            _match_arguments((init_name, cls_attrs), (deco_name, 
deco.args.args + deco.args.kwonlyargs))
+        )
         if errors:
             break
 
diff --git a/task_sdk/src/airflow/sdk/definitions/dag.py b/task_sdk/src/airflow/sdk/definitions/dag.py
index 71dbdc6f9eb..e784639f8c9 100644
--- a/task_sdk/src/airflow/sdk/definitions/dag.py
+++ b/task_sdk/src/airflow/sdk/definitions/dag.py
@@ -340,7 +340,6 @@ class DAG:
     template_undefined: type[jinja2.StrictUndefined] = jinja2.StrictUndefined
     user_defined_macros: dict | None = None
     user_defined_filters: dict | None = None
-    concurrency: int | None = None
     max_active_tasks: int = attrs.field(default=16, 
validator=attrs.validators.instance_of(int))
     max_active_runs: int = attrs.field(default=16, 
validator=attrs.validators.instance_of(int))
     max_consecutive_failed_dag_runs: int = attrs.field(
@@ -360,7 +359,7 @@ class DAG:
         default=None,
         converter=attrs.Converter(_convert_params, takes_self=True),  # type: 
ignore[misc, call-overload]
     )
-    access_control: dict | None = attrs.field(
+    access_control: dict[str, dict[str, Collection[str]]] | dict[str, 
Collection[str]] | None = attrs.field(
         default=None, converter=attrs.Converter(_convert_access_control, 
takes_self=True)
     )
     is_paused_upon_creation: bool | None = None
@@ -935,7 +934,7 @@ class DAG:
                 "_log",
                 "task_dict",
                 "template_searchpath",
-                "sla_miss_callback",
+                # "sla_miss_callback",
                 "on_success_callback",
                 "on_failure_callback",
                 "template_undefined",
@@ -997,17 +996,17 @@ if TYPE_CHECKING:
         template_undefined: type[jinja2.StrictUndefined] = 
jinja2.StrictUndefined,
         user_defined_macros: dict | None = None,
         user_defined_filters: dict | None = None,
-        default_args: dict | None = None,
+        default_args: dict[str, Any] | None = None,
         max_active_tasks: int = ...,
         max_active_runs: int = ...,
         max_consecutive_failed_dag_runs: int = ...,
         dagrun_timeout: timedelta | None = None,
-        sla_miss_callback: Any = None,
+        # sla_miss_callback: Any = None,
         catchup: bool = ...,
-        on_success_callback: None | DagStateChangeCallback | 
list[DagStateChangeCallback] = None,
-        on_failure_callback: None | DagStateChangeCallback | 
list[DagStateChangeCallback] = None,
+        # on_success_callback: None | DagStateChangeCallback | 
list[DagStateChangeCallback] = None,
+        # on_failure_callback: None | DagStateChangeCallback | 
list[DagStateChangeCallback] = None,
         doc_md: str | None = None,
-        params: abc.MutableMapping | None = None,
+        params: ParamsDict | None = None,
         access_control: dict[str, dict[str, Collection[str]]] | dict[str, 
Collection[str]] | None = None,
         is_paused_upon_creation: bool | None = None,
         jinja_environment_kwargs: dict | None = None,
diff --git a/task_sdk/src/airflow/sdk/definitions/taskgroup.py b/task_sdk/src/airflow/sdk/definitions/taskgroup.py
index 72cd9ed3bc7..e417eab2760 100644
--- a/task_sdk/src/airflow/sdk/definitions/taskgroup.py
+++ b/task_sdk/src/airflow/sdk/definitions/taskgroup.py
@@ -112,6 +112,8 @@ class TaskGroup(DAGNode):
     ui_color: str = "CornflowerBlue"
     ui_fgcolor: str = "#000"
 
+    add_suffix_on_collision: bool = False
+
     @parent_group.default
     def _default_parent_group(self):
         from airflow.sdk.definitions.contextmanager import TaskGroupContext
@@ -132,6 +134,10 @@ class TaskGroup(DAGNode):
             raise RuntimeError("TaskGroup can only be used inside a dag")
 
     def __attrs_post_init__(self):
+        # TODO: If attrs supported init only args we could use that here
+        # https://github.com/python-attrs/attrs/issues/342
+        self._check_for_group_id_collisions(self.add_suffix_on_collision)
+
         if self.parent_group:
             object.__setattr__(self, "used_group_ids", 
self.parent_group.used_group_ids)
             self.parent_group.add(self)

Reply via email to