This is an automated email from the ASF dual-hosted git repository. ash pushed a commit to branch task-sdk-first-code in repository https://gitbox.apache.org/repos/asf/airflow.git
commit a4bd6b47e27699440df04871d200b2df600646dc Author: Ash Berlin-Taylor <[email protected]> AuthorDate: Thu Oct 24 12:31:43 2024 +0100 Update pre-commit scripts. I have removed some of the checks to reflect the new code structure. Now that the `def dag` decorator is defined at runtime as `def dag(**kwargs)` it means the default check is not useful/relevant anymore. --- scripts/ci/pre_commit/sync_init_decorator.py | 195 ++++++++++++++-------- task_sdk/src/airflow/sdk/definitions/dag.py | 15 +- task_sdk/src/airflow/sdk/definitions/taskgroup.py | 6 + 3 files changed, 136 insertions(+), 80 deletions(-) diff --git a/scripts/ci/pre_commit/sync_init_decorator.py b/scripts/ci/pre_commit/sync_init_decorator.py index fc99fc48d7b..13e80d62c6c 100755 --- a/scripts/ci/pre_commit/sync_init_decorator.py +++ b/scripts/ci/pre_commit/sync_init_decorator.py @@ -26,38 +26,71 @@ import pathlib import sys from typing import TYPE_CHECKING -PACKAGE_ROOT = pathlib.Path(__file__).resolve().parents[3].joinpath("airflow") -DAG_PY = PACKAGE_ROOT.joinpath("models", "dag.py") -UTILS_TG_PY = PACKAGE_ROOT.joinpath("utils", "task_group.py") +ROOT = pathlib.Path(__file__).resolve().parents[3] +PACKAGE_ROOT = ROOT.joinpath("airflow") +SDK_DEFINITIONS_PKG = ROOT.joinpath("task_sdk", "src", "airflow", "sdk", "definitions") +DAG_PY = SDK_DEFINITIONS_PKG.joinpath("dag.py") +TG_PY = SDK_DEFINITIONS_PKG.joinpath("taskgroup.py") DECOS_TG_PY = PACKAGE_ROOT.joinpath("decorators", "task_group.py") -def _find_dag_init(mod: ast.Module) -> ast.FunctionDef: - """Find definition of the ``DAG`` class's ``__init__``.""" - dag_class = next(n for n in ast.iter_child_nodes(mod) if isinstance(n, ast.ClassDef) and n.name == "DAG") - return next( - node - for node in ast.iter_child_nodes(dag_class) - if isinstance(node, ast.FunctionDef) and node.name == "__init__" +def _name(node: ast.expr) -> str: + if not isinstance(node, ast.Name): + raise TypeError("node was not an ast.Name node") + return node.id + + +def 
_find_cls_attrs( + mod: ast.Module, class_name: str, ignore: list[str] | None = None +) -> collections.abc.Iterable[ast.AnnAssign]: + """Find the type-annotated/attrs properties in the body of the specified class.""" + dag_class = next( + n for n in ast.iter_child_nodes(mod) if isinstance(n, ast.ClassDef) and n.name == class_name ) + ignore = ignore or [] + + for node in ast.iter_child_nodes(dag_class): + if not isinstance(node, ast.AnnAssign) or not node.annotation: + continue + + # ClassVar[Any] + if isinstance(node.annotation, ast.Subscript) and _name(node.annotation.value) == "ClassVar": + continue + + # Skip private attrs fields, ones with `attrs.field(init=False)` kwargs + if isinstance(node.value, ast.Call): + # Lazy coding: since init=True is the default, we're just looking for the presence of the init + # arg name + if TYPE_CHECKING: + assert isinstance(node.value.func, ast.Attribute) + if ( + node.value.func.attr == "field" + and _name(node.value.func.value) == "attrs" + and any(arg.arg == "init" for arg in node.value.keywords) + ): + continue + if _name(node.target) in ignore: + continue + + # Attrs treats `_group_id: str` as `group_id` arg to __init__ + if _name(node.target).startswith("_"): + node.target.id = node.target.id[1:] # type: ignore[union-attr] + yield node + def _find_dag_deco(mod: ast.Module) -> ast.FunctionDef: """Find definition of the ``@dag`` decorator.""" - return next(n for n in ast.iter_child_nodes(mod) if isinstance(n, ast.FunctionDef) and n.name == "dag") - - -def _find_tg_init(mod: ast.Module) -> ast.FunctionDef: - """Find definition of the ``TaskGroup`` class's ``__init__``.""" - task_group_class = next( + # We now define the signature in a type checking block, the runtime impl uses **kwargs + type_checking_blocks = ( node for node in ast.iter_child_nodes(mod) - if isinstance(node, ast.ClassDef) and node.name == "TaskGroup" + if isinstance(node, ast.If) and node.test.id == "TYPE_CHECKING" # type: ignore[attr-defined] ) return 
next( - node - for node in ast.iter_child_nodes(task_group_class) - if isinstance(node, ast.FunctionDef) and node.name == "__init__" + n + for n in itertools.chain.from_iterable(map(ast.iter_child_nodes, type_checking_blocks)) + if isinstance(n, ast.FunctionDef) and n.name == "dag" ) @@ -74,43 +107,72 @@ def _find_tg_deco(mod: ast.Module) -> ast.FunctionDef: ) +# Hard-code some specific examples of allowable decorate type annotation -> class type annotation mappings +# where they don't match exactly + + +def _expr_to_ast_dump(expr: str) -> str: + return ast.dump(ast.parse(expr).body[0].value) # type: ignore[attr-defined] + + +ALLOWABLE_TYPE_ANNOTATIONS = { + _expr_to_ast_dump("Collection[str] | None"): _expr_to_ast_dump("MutableSet[str]") +} + + def _match_arguments( - init_def: tuple[str, list[ast.arg]], + init_def: tuple[str, list[ast.AnnAssign]], deco_def: tuple[str, list[ast.arg]], ) -> collections.abc.Iterator[str]: init_name, init_args = init_def deco_name, deco_args = deco_def + init_args.sort(key=lambda a: _name(a.target)) + deco_args.sort(key=lambda a: a.arg) for i, (ini, dec) in enumerate(itertools.zip_longest(init_args, deco_args, fillvalue=None)): if ini is None and dec is not None: yield f"Argument present in @{deco_name} but missing from {init_name}: {dec.arg}" return if dec is None and ini is not None: - yield f"Argument present in {init_name} but missing from @{deco_name}: {ini.arg}" + yield f"Argument present in {init_name} but missing from @{deco_name}: {_name(ini.target)}" return + if TYPE_CHECKING: + # Mypy can't work out that zip_longest means one of ini or dec must be non None + assert ini is not None + + if not isinstance(ini.target, ast.Name): + raise RuntimeError(f"Don't know how to examine {ast.unparse(ini)!r}") + attr_name = _name(ini.target) if TYPE_CHECKING: assert ini is not None and dec is not None # Because None is only possible as fillvalue. 
- if ini.arg != dec.arg: - yield f"Argument {i + 1} mismatch: {init_name} has {ini.arg} but @{deco_name} has {dec.arg}" + if attr_name != dec.arg: + yield f"Argument {i + 1} mismatch: {init_name} has {attr_name} but @{deco_name} has {dec.arg}" return if getattr(ini, "type_comment", None): # 3.8+ - yield f"Do not use type comments on {init_name} argument: {ini.arg}" + yield f"Do not use type comments on {init_name} argument: {ini}" if getattr(dec, "type_comment", None): # 3.8+ yield f"Do not use type comments on @{deco_name} argument: {dec.arg}" # Poorly implemented node equality check. - if ini.annotation and dec.annotation and ast.dump(ini.annotation) != ast.dump(dec.annotation): - yield ( - f"Type annotations differ on argument {ini.arg} between {init_name} and @{deco_name}: " - f"{ast.unparse(ini.annotation)} != {ast.unparse(dec.annotation)}" - ) - else: - if not ini.annotation: - yield f"Type annotation missing on {init_name} argument: {ini.arg}" - if not dec.annotation: - yield f"Type annotation missing on @{deco_name} argument: {ini.arg}" + if ini.annotation and dec.annotation: + ini_anno = ast.dump(ini.annotation) + dec_anno = ast.dump(dec.annotation) + if ( + ini_anno != dec_anno + # The decorator can have `| None` type in addaition to the base attribute + and dec_anno != f"BinOp(left={ini_anno}, op=BitOr(), right=Constant(value=None))" + and ALLOWABLE_TYPE_ANNOTATIONS.get(dec_anno) != ini_anno + ): + yield ( + f"Type annotations differ on argument {attr_name!r} between {init_name} and @{deco_name}: " + f"{ast.unparse(ini.annotation)} != {ast.unparse(dec.annotation)}" + ) + elif not ini.annotation: + yield f"Type annotation missing on {init_name} argument: {attr_name}" + elif not dec.annotation: + yield f"Type annotation missing on @{deco_name} argument: {attr_name}" def _match_defaults( @@ -130,47 +192,38 @@ def _match_defaults( def check_dag_init_decorator_arguments() -> int: dag_mod = ast.parse(DAG_PY.read_text("utf-8"), str(DAG_PY)) - - utils_tg = 
ast.parse(UTILS_TG_PY.read_text("utf-8"), str(UTILS_TG_PY)) + tg_mod = ast.parse(TG_PY.read_text("utf-8"), str(TG_PY)) decos_tg = ast.parse(DECOS_TG_PY.read_text("utf-8"), str(DECOS_TG_PY)) items_to_check = [ - ("DAG", _find_dag_init(dag_mod), "dag", _find_dag_deco(dag_mod), "dag_id", ""), - ("TaskGroup", _find_tg_init(utils_tg), "task_group", _find_tg_deco(decos_tg), "group_id", None), + ( + "DAG", + list(_find_cls_attrs(dag_mod, "DAG", ignore=["full_filepath", "task_group"])), + "dag", + _find_dag_deco(dag_mod), + "dag_id", + "", + ), + ( + "TaskGroup", + list(_find_cls_attrs(tg_mod, "TaskGroup")), + "_task_group", + _find_tg_deco(decos_tg), + "group_id", + None, + ), ] - for init_name, init, deco_name, deco, id_arg, id_default in items_to_check: - if getattr(init.args, "posonlyargs", None) or getattr(deco.args, "posonlyargs", None): - print(f"{init_name} and @{deco_name} should not declare positional-only arguments") - return -1 - if init.args.vararg or init.args.kwarg or deco.args.vararg or deco.args.kwarg: - print(f"{init_name} and @{deco_name} should not declare *args and **kwargs") - return -1 - - # Feel free to change this and make some of the arguments keyword-only! 
- if init.args.kwonlyargs or deco.args.kwonlyargs: - print(f"{init_name}() and @{deco_name}() should not declare keyword-only arguments") - return -2 - if init.args.kw_defaults or deco.args.kw_defaults: - print(f"{init_name}() and @{deco_name}() should not declare keyword-only arguments") - return -2 - - init_arg_names = [a.arg for a in init.args.args] + for init_name, cls_attrs, deco_name, deco, id_arg, id_default in items_to_check: deco_arg_names = [a.arg for a in deco.args.args] - if init_arg_names[0] != "self": - print(f"First argument in {init_name} must be 'self'") - return -3 - if init_arg_names[1] != id_arg: - print(f"Second argument in {init_name} must be {id_arg!r}") + if _name(cls_attrs[0].target) != id_arg: + print(f"First attribute in {init_name} must be {id_arg!r} (got {cls_attrs[0]!r})") return -3 if deco_arg_names[0] != id_arg: - print(f"First argument in @{deco_name} must be {id_arg!r}") + print(f"First argument in @{deco_name} must be {id_arg!r} (got {deco_arg_names[0]!r})") return -3 - if len(init.args.defaults) != len(init_arg_names) - 2: - print(f"All arguments on {init_name} except self and {id_arg} must have defaults") - return -4 if len(deco.args.defaults) != len(deco_arg_names): print(f"All arguments on @{deco_name} must have defaults") return -4 @@ -178,13 +231,11 @@ def check_dag_init_decorator_arguments() -> int: print(f"Default {id_arg} on @{deco_name} must be {id_default!r}") return -4 - for init_name, init, deco_name, deco, _, _ in items_to_check: - errors = list(_match_arguments((init_name, init.args.args[1:]), (deco_name, deco.args.args))) - if errors: - break - init_defaults_def = (init_name, init.args.defaults) - deco_defaults_def = (deco_name, deco.args.defaults[1:]) - errors = list(_match_defaults(deco_arg_names, init_defaults_def, deco_defaults_def)) + errors = [] + for init_name, cls_attrs, deco_name, deco, _, _ in items_to_check: + errors = list( + _match_arguments((init_name, cls_attrs), (deco_name, deco.args.args + 
deco.args.kwonlyargs)) + ) if errors: break diff --git a/task_sdk/src/airflow/sdk/definitions/dag.py b/task_sdk/src/airflow/sdk/definitions/dag.py index 71dbdc6f9eb..e784639f8c9 100644 --- a/task_sdk/src/airflow/sdk/definitions/dag.py +++ b/task_sdk/src/airflow/sdk/definitions/dag.py @@ -340,7 +340,6 @@ class DAG: template_undefined: type[jinja2.StrictUndefined] = jinja2.StrictUndefined user_defined_macros: dict | None = None user_defined_filters: dict | None = None - concurrency: int | None = None max_active_tasks: int = attrs.field(default=16, validator=attrs.validators.instance_of(int)) max_active_runs: int = attrs.field(default=16, validator=attrs.validators.instance_of(int)) max_consecutive_failed_dag_runs: int = attrs.field( @@ -360,7 +359,7 @@ class DAG: default=None, converter=attrs.Converter(_convert_params, takes_self=True), # type: ignore[misc, call-overload] ) - access_control: dict | None = attrs.field( + access_control: dict[str, dict[str, Collection[str]]] | dict[str, Collection[str]] | None = attrs.field( default=None, converter=attrs.Converter(_convert_access_control, takes_self=True) ) is_paused_upon_creation: bool | None = None @@ -935,7 +934,7 @@ class DAG: "_log", "task_dict", "template_searchpath", - "sla_miss_callback", + # "sla_miss_callback", "on_success_callback", "on_failure_callback", "template_undefined", @@ -997,17 +996,17 @@ if TYPE_CHECKING: template_undefined: type[jinja2.StrictUndefined] = jinja2.StrictUndefined, user_defined_macros: dict | None = None, user_defined_filters: dict | None = None, - default_args: dict | None = None, + default_args: dict[str, Any] | None = None, max_active_tasks: int = ..., max_active_runs: int = ..., max_consecutive_failed_dag_runs: int = ..., dagrun_timeout: timedelta | None = None, - sla_miss_callback: Any = None, + # sla_miss_callback: Any = None, catchup: bool = ..., - on_success_callback: None | DagStateChangeCallback | list[DagStateChangeCallback] = None, - on_failure_callback: None | 
DagStateChangeCallback | list[DagStateChangeCallback] = None, + # on_success_callback: None | DagStateChangeCallback | list[DagStateChangeCallback] = None, + # on_failure_callback: None | DagStateChangeCallback | list[DagStateChangeCallback] = None, doc_md: str | None = None, - params: abc.MutableMapping | None = None, + params: ParamsDict | None = None, access_control: dict[str, dict[str, Collection[str]]] | dict[str, Collection[str]] | None = None, is_paused_upon_creation: bool | None = None, jinja_environment_kwargs: dict | None = None, diff --git a/task_sdk/src/airflow/sdk/definitions/taskgroup.py b/task_sdk/src/airflow/sdk/definitions/taskgroup.py index 72cd9ed3bc7..e417eab2760 100644 --- a/task_sdk/src/airflow/sdk/definitions/taskgroup.py +++ b/task_sdk/src/airflow/sdk/definitions/taskgroup.py @@ -112,6 +112,8 @@ class TaskGroup(DAGNode): ui_color: str = "CornflowerBlue" ui_fgcolor: str = "#000" + add_suffix_on_collision: bool = False + @parent_group.default def _default_parent_group(self): from airflow.sdk.definitions.contextmanager import TaskGroupContext @@ -132,6 +134,10 @@ class TaskGroup(DAGNode): raise RuntimeError("TaskGroup can only be used inside a dag") def __attrs_post_init__(self): + # TODO: If attrs supported init only args we could use that here + # https://github.com/python-attrs/attrs/issues/342 + self._check_for_group_id_collisions(self.add_suffix_on_collision) + if self.parent_group: object.__setattr__(self, "used_group_ids", self.parent_group.used_group_ids) self.parent_group.add(self)
