This is an automated email from the ASF dual-hosted git repository.

jason810496 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 3276dded172 Add prek hook to enforce keyword-only `session` on `@provide_session` (#67150)
3276dded172 is described below

commit 3276dded17244c83c5b01439441874d890286dfc
Author: Jason(Zhe-You) Liu <[email protected]>
AuthorDate: Thu May 28 09:32:57 2026 +0800

    Add prek hook to enforce keyword-only `session` on `@provide_session` (#67150)
---
 .pre-commit-config.yaml                            |   6 +
 scripts/ci/prek/check_provide_session_kwargs.py    | 427 ++++++++++++++++++
 .../ci/prek/known_provide_session_positional.txt   |  89 ++++
 .../ci/prek/test_check_provide_session_kwargs.py   | 482 +++++++++++++++++++++
 4 files changed, 1004 insertions(+)

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 5f8351b30e7..42ab2035ec0 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1071,6 +1071,12 @@ repos:
         language: python
         pass_filenames: true
         files: ^(airflow-core|airflow-ctl|task-sdk|providers|shared)/.*\.py$
+      - id: check-no-new-provide-session-positional
+        name: Check that no new @provide_session functions declare `session` positionally
+        entry: ./scripts/ci/prek/check_provide_session_kwargs.py
+        language: python
+        pass_filenames: true
+        files: ^(airflow-core|providers)/.*\.py$|^scripts/ci/prek/known_provide_session_positional\.txt$|^scripts/ci/prek/check_provide_session_kwargs\.py$
       - id: check-no-new-airflow-core-utils-modules
         name: Check that no new modules are added under airflow-core/src/airflow/utils
         entry: ./scripts/ci/prek/check_no_new_airflow_core_utils_modules.py
diff --git a/scripts/ci/prek/check_provide_session_kwargs.py b/scripts/ci/prek/check_provide_session_kwargs.py
new file mode 100755
index 00000000000..29152e13a4d
--- /dev/null
+++ b/scripts/ci/prek/check_provide_session_kwargs.py
@@ -0,0 +1,427 @@
+#!/usr/bin/env python
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# /// script
+# requires-python = ">=3.10"
+# dependencies = [
+#   "rich>=13.0.0",
+# ]
+# ///
+"""Check that no new ``@provide_session`` functions declare ``session`` positionally.
+
+The project convention is that any function decorated with ``@provide_session``
+must declare ``session`` as keyword-only (after a bare ``*`` in the signature),
+so callers cannot pass it positionally by accident. See
+``contributing-docs/05_pull_requests.rst#database-session-handling``.
+
+All *existing* offenders are recorded in ``known_provide_session_positional.txt``
+next to this script as ``relative/path::N`` entries (one per file), where ``N``
+is the maximum number of ``@provide_session`` functions with a positional
+``session`` argument allowed in that file. A file whose current count exceeds
+the recorded limit is treated as a violation – move the ``session`` argument
+behind a bare ``*`` instead.
+
+Modes
+-----
+Default (files passed by prek/pre-commit):
+    Check only the supplied files; fail if any file's count exceeds the limit.
+    When a file's count has *decreased*, the allowlist entry is tightened
+    automatically and the hook exits with a non-zero code so that pre-commit
+    reports the modified allowlist – just stage
+    ``scripts/ci/prek/known_provide_session_positional.txt`` and re-run.
+
+``--all-files``:
+    Walk every ``.py`` file under the project source roots
+    (``airflow-core``, ``providers``, ``shared``) —
+    the same scope the pre-commit hook applies to.
+
+``--cleanup``:
+    Remove entries for files that no longer exist. Safe to run at any time;
+    does not add new entries or raise limits.
+
+``--generate``:
+    Scan the same project source roots as ``--all-files`` and *rebuild* the
+    allowlist from scratch. Intended for the initial setup or after a
+    large-scale clean-up sprint.
+"""
+
+from __future__ import annotations
+
+import argparse
+import ast
+import subprocess
+import typing
+from pathlib import Path
+
+from rich.console import Console
+from rich.panel import Panel
+
+console = Console(color_system="standard", width=200)
+
+REPO_ROOT = Path(__file__).parents[3]
+
+_PROVIDE_SESSION_DECORATOR = "provide_session"
+
+# Top-level directories scanned by ``--all-files`` / ``--generate``. Keep in sync with the
+# ``files:`` pattern for this hook in ``.pre-commit-config.yaml``.
+_PROJECT_SOURCE_ROOTS = ("airflow-core", "providers", "shared")
+
+
+def _has_provide_session_decorator(nodes: list[ast.expr]) -> bool:
+    """Whether one of ``nodes`` is a ``@provide_session`` decorator.
+
+    Accepts both bare names (``@provide_session``) and attribute access
+    (``@something.provide_session``).
+    """
+    for node in nodes:
+        if isinstance(node, ast.Name) and node.id == _PROVIDE_SESSION_DECORATOR:
+            return True
+        if isinstance(node, ast.Attribute) and node.attr == _PROVIDE_SESSION_DECORATOR:
+            return True
+    return False
+
+
+def _session_is_positional(args: ast.arguments) -> ast.arg | None:
+    """Return the ``session`` arg if it is positional (not keyword-only).
+
+    Covers both regular positional args and positional-only args (``def f(session, /, ...)``).
+    """
+    for argument in (*args.posonlyargs, *args.args):
+        if argument.arg == "session":
+            return argument
+    return None
+
+
+def _iter_positional_session_in_provide_session(
+    path: Path,
+) -> typing.Iterator[tuple[ast.FunctionDef | ast.AsyncFunctionDef, ast.arg]]:
+    """Yield ``@provide_session`` functions in *path* whose ``session`` is positional."""
+    try:
+        source = path.read_text(encoding="utf-8", errors="replace")
+    except OSError:
+        return
+    try:
+        tree = ast.parse(source, str(path))
+    except SyntaxError:
+        return
+    for node in ast.walk(tree):
+        if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
+            continue
+        if not _has_provide_session_decorator(node.decorator_list):
+            continue
+        argument = _session_is_positional(node.args)
+        if argument is None:
+            continue
+        yield node, argument
+
+
+def _count_violations(path: Path) -> int:
+    return sum(1 for _ in _iter_positional_session_in_provide_session(path))
+
+
+def _is_safe_relative(rel: str) -> bool:
+    """Whether ``rel`` is a plain relative path that stays inside ``REPO_ROOT``.
+
+    Rejects absolute paths and any entry that resolves outside the repo root so
+    callers can ``relative_to(REPO_ROOT)`` without fear of a ``ValueError``.
+    """
+    candidate = Path(rel)
+    if candidate.is_absolute():
+        return False
+    try:
+        (REPO_ROOT / candidate).resolve().relative_to(REPO_ROOT.resolve())
+    except ValueError:
+        return False
+    return True
+
+
+class AllowlistManager:
+    def __init__(self, allowlist_file: Path) -> None:
+        self.allowlist_file = allowlist_file
+
+    @staticmethod
+    def parse(text: str) -> dict[str, int]:
+        """Parse allowlist *text* into a ``{rel_path: count}`` mapping.
+
+        Same validation rules as :meth:`load` so we can reuse parsing for the
+        on-disk allowlist *and* for the git-tracked version fetched from
+        ``HEAD`` when guarding against entry-removal bypasses.
+        """
+        result: dict[str, int] = {}
+        for raw_line in text.splitlines():
+            if not (stripped := raw_line.strip()):
+                continue
+
+            rel_str, _, count_str = stripped.rpartition("::")
+            if not rel_str or not count_str:
+                continue
+
+            try:
+                count = int(count_str)
+            except ValueError:
+                continue
+
+            if not _is_safe_relative(rel_str):
+                console.print(
+                    f"[yellow]Ignoring unsafe allowlist entry (escapes repo root):[/yellow] {rel_str}"
+                )
+                continue
+
+            result[rel_str] = count
+
+        return result
+
+    def load(self) -> dict[str, int]:
+        if not self.allowlist_file.exists():
+            return {}
+        return self.parse(self.allowlist_file.read_text())
+
+    def save(self, counts: dict[str, int]) -> None:
+        lines = [f"{rel}::{count}" for rel, count in sorted(counts.items())]
+        self.allowlist_file.write_text("\n".join(lines) + "\n")
+
+    def generate(self) -> int:
+        roots = ", ".join(_PROJECT_SOURCE_ROOTS)
+        console.print(
+            f"Scanning project source roots ([cyan]{roots}[/cyan]) under [cyan]{REPO_ROOT}[/cyan] "
+            "for @provide_session functions with positional session …"
+        )
+        counts: dict[str, int] = {}
+        for path in _iter_python_files():
+            n = _count_violations(path)
+            if n > 0:
+                counts[str(path.relative_to(REPO_ROOT))] = n
+
+        self.save(counts)
+        total = sum(counts.values())
+        console.print(
+            f"[green]Generated[/green] [cyan]{self.allowlist_file.relative_to(REPO_ROOT)}[/cyan] "
+            f"with [bold]{len(counts)}[/bold] files / [bold]{total}[/bold] offenders."
+        )
+        return 0
+
+    def cleanup(self) -> int:
+        allowlist = self.load()
+        if not allowlist:
+            console.print("[yellow]Allowlist is empty - nothing to clean up.[/yellow]")
+            return 0
+
+        stale: list[str] = [rel for rel in allowlist if not (REPO_ROOT / rel).exists()]
+        if stale:
+            console.print(
+                f"[yellow]Removing {len(stale)} stale entr{'y' if len(stale) == 1 else 'ies'}:[/yellow]"
+            )
+            for s in sorted(stale):
+                console.print(f"  [dim]-[/dim] {s}")
+            for s in stale:
+                del allowlist[s]
+            self.save(allowlist)
+            console.print(
+                f"\n[green]Updated[/green] [cyan]{self.allowlist_file.relative_to(REPO_ROOT)}[/cyan]"
+            )
+        else:
+            console.print("[green]No stale entries found.[/green]")
+        return 0
+
+
+def _iter_python_files() -> list[Path]:
+    candidates: list[Path] = []
+    for top in _PROJECT_SOURCE_ROOTS:
+        candidates.extend(
+            p.resolve()
+            for p in (REPO_ROOT / top).rglob("*.py")
+            if ".tox" not in p.parts and "__pycache__" not in p.parts
+        )
+    return candidates
+
+
+def _check_provide_session_kwargs(
+    files: list[Path], allowlist: dict[str, int], manager: AllowlistManager
+) -> int:
+    allowlist_file = manager.allowlist_file.resolve()
+    if any(p.resolve() == allowlist_file for p in files) and not allowlist_file.exists():
+        console.print(
+            Panel.fit(
+                f"Allowlist file [cyan]{allowlist_file}[/cyan] is missing.\n"
+                "It was passed to the hook but cannot be read, so the check cannot proceed.\n"
+                "Restore it from git or regenerate it with:\n\n"
+                "  [cyan]uv run ./scripts/ci/prek/check_provide_session_kwargs.py --generate[/cyan]",
+                title="[red]Check failed[/red]",
+                border_style="red",
+            )
+        )
+        return 1
+
+    violations: list[tuple[Path, int, int]] = []
+    tightened: list[tuple[str, int, int]] = []
+
+    for path in files:
+        if not path.exists() or path.suffix != ".py":
+            continue
+        actual = _count_violations(path)
+        rel = str(path.relative_to(REPO_ROOT))
+        allowed = allowlist.get(rel, 0)
+        if actual > allowed:
+            violations.append((path, actual, allowed))
+        elif actual < allowed:
+            if actual == 0:
+                del allowlist[rel]
+            else:
+                allowlist[rel] = actual
+            tightened.append((rel, allowed, actual))
+
+    if tightened:
+        manager.save(allowlist)
+        console.print(
+            f"[green]Tightened {len(tightened)} entr{'y' if len(tightened) == 1 else 'ies'} "
+            f"in [cyan]{manager.allowlist_file.relative_to(REPO_ROOT)}[/cyan][/green] "
+            "(stage the updated file):"
+        )
+        for rel, old, new in tightened:
+            console.print(f"  [cyan]{rel}[/cyan]  {old} -> {new}")
+
+    if violations:
+        console.print(
+            Panel.fit(
+                "New [bold]@provide_session[/bold] function with positional ``session`` detected.\n"
+                "Move ``session`` after a bare ``*`` in the signature so callers must pass it by keyword:\n\n"
+                "  [cyan]@provide_session\n"
+                "  def foo(arg, *, session: Session = NEW_SESSION) -> None: ...[/cyan]\n\n"
+                "If this usage is intentional and pre-existing, run:\n\n"
+                "  [cyan]uv run ./scripts/ci/prek/check_provide_session_kwargs.py --generate[/cyan]\n\n"
+                "to regenerate the allowlist, then commit the updated\n"
+                "[cyan]scripts/ci/prek/known_provide_session_positional.txt[/cyan].",
+                title="[red]Check failed[/red]",
+                border_style="red",
+            )
+        )
+        for path, actual, allowed in violations:
+            console.print(f"  [cyan]{path.relative_to(REPO_ROOT)}[/cyan]  count={actual} (allowed={allowed})")
+            for func, argument in _iter_positional_session_in_provide_session(path):
+                console.print(f"      [dim]L{argument.lineno}[/dim] def {func.name}(...)")
+        return 1
+
+    return 1 if tightened else 0
+
+
+def main(argv: list[str] | None = None) -> int:
+    parser = argparse.ArgumentParser(
+        description="Prevent new @provide_session functions from declaring `session` positionally.",
+        formatter_class=argparse.RawDescriptionHelpFormatter,
+        epilog=__doc__,
+    )
+    parser.add_argument("files", nargs="*", metavar="FILE", help="Files to check (provided by prek)")
+    parser.add_argument(
+        "--all-files",
+        action="store_true",
+        help=("Check every Python file under the project source roots (airflow-core, providers, shared)"),
+    )
+    parser.add_argument(
+        "--cleanup",
+        action="store_true",
+        help="Remove stale entries from the allowlist and exit",
+    )
+    parser.add_argument(
+        "--generate",
+        action="store_true",
+        help="Regenerate the allowlist from the current codebase and exit",
+    )
+    args = parser.parse_args(argv)
+
+    manager = AllowlistManager(Path(__file__).parent / "known_provide_session_positional.txt")
+
+    if args.generate:
+        return manager.generate()
+
+    if args.cleanup:
+        return manager.cleanup()
+
+    allowlist = manager.load()
+
+    if args.all_files:
+        return _check_provide_session_kwargs(_iter_python_files(), allowlist, manager)
+
+    if not args.files:
+        console.print(
+            "[yellow]No files provided. Pass filenames or use --all-files to scan the whole repo.[/yellow]"
+        )
+        return 0
+
+    paths = [Path(f).resolve() for f in args.files]
+    paths = _expand_for_allowlist_edits(paths, manager, allowlist)
+    return _check_provide_session_kwargs(paths, allowlist, manager)
+
+
+def _parse_tracked_allowlist(manager: AllowlistManager) -> dict[str, int]:
+    """Return the allowlist as recorded at ``HEAD`` (the git-tracked version).
+
+    Used by :func:`_expand_for_allowlist_edits` so that *removing* an entry
+    cannot silently drop coverage: the previously-listed file is still
+    re-validated against the new (post-edit) allowlist. Returns an empty mapping
+    when git is unavailable, the file does not yet exist at ``HEAD``, or the
+    allowlist sits outside ``REPO_ROOT``.
+    """
+    try:
+        rel = manager.allowlist_file.resolve().relative_to(REPO_ROOT.resolve())
+    except ValueError:
+        return {}
+    try:
+        completed = subprocess.run(
+            ["git", "-C", str(REPO_ROOT), "show", f"HEAD:{rel.as_posix()}"],
+            capture_output=True,
+            text=True,
+            check=False,
+        )
+    except (FileNotFoundError, OSError):
+        return {}
+    if completed.returncode != 0:
+        return {}
+    return AllowlistManager.parse(completed.stdout)
+
+
+def _expand_for_allowlist_edits(
+    paths: list[Path], manager: AllowlistManager, allowlist: dict[str, int]
+) -> list[Path]:
+    """Add allowlisted files when the allowlist itself is being changed.
+
+    Without this, a contributor could raise counts in
+    ``known_provide_session_positional.txt`` and the hook would do no validation
+    (since only the ``.txt`` file is passed), letting the loosened allowlist
+    sail through. We also union the git-tracked allowlist (from ``HEAD``) so
+    that removing an entry cannot silently bypass the check for a file that
+    still has positional ``session`` arguments.
+
+    Both sides of the allowlist-file comparison are resolved so the detection is
+    robust to symlinks and unresolved inputs (the hook can be invoked with either).
+    """
+    allowlist_file = manager.allowlist_file.resolve()
+    if not any(p.resolve() == allowlist_file for p in paths):
+        return paths
+
+    expanded = list(paths)
+    seen = {p.resolve() for p in paths if p.suffix == ".py"}
+    tracked = _parse_tracked_allowlist(manager)
+    for rel in {*allowlist, *tracked}:
+        candidate = (REPO_ROOT / rel).resolve()
+        if candidate.exists() and candidate not in seen:
+            seen.add(candidate)
+            expanded.append(candidate)
+    return expanded
+
+
+if __name__ == "__main__":
+    raise SystemExit(main())
diff --git a/scripts/ci/prek/known_provide_session_positional.txt b/scripts/ci/prek/known_provide_session_positional.txt
new file mode 100644
index 00000000000..d0c84e2f6b4
--- /dev/null
+++ b/scripts/ci/prek/known_provide_session_positional.txt
@@ -0,0 +1,89 @@
+airflow-core/src/airflow/api/common/delete_dag.py::1
+airflow-core/src/airflow/api/common/mark_tasks.py::1
+airflow-core/src/airflow/callbacks/database_callback_sink.py::1
+airflow-core/src/airflow/cli/commands/dag_command.py::8
+airflow-core/src/airflow/cli/commands/jobs_command.py::1
+airflow-core/src/airflow/cli/commands/task_command.py::1
+airflow-core/src/airflow/cli/commands/team_command.py::4
+airflow-core/src/airflow/cli/commands/variable_command.py::1
+airflow-core/src/airflow/dag_processing/dagbag.py::1
+airflow-core/src/airflow/dag_processing/manager.py::4
+airflow-core/src/airflow/jobs/base_job_runner.py::2
+airflow-core/src/airflow/jobs/job.py::7
+airflow-core/src/airflow/jobs/scheduler_job_runner.py::11
+airflow-core/src/airflow/jobs/triggerer_job_runner.py::1
+airflow-core/src/airflow/models/connection.py::2
+airflow-core/src/airflow/models/dag.py::7
+airflow-core/src/airflow/models/dagcode.py::6
+airflow-core/src/airflow/models/dagrun.py::15
+airflow-core/src/airflow/models/dagwarning.py::1
+airflow-core/src/airflow/models/deadline.py::1
+airflow-core/src/airflow/models/deadline_alert.py::1
+airflow-core/src/airflow/models/pool.py::11
+airflow-core/src/airflow/models/renderedtifields.py::4
+airflow-core/src/airflow/models/revoked_token.py::2
+airflow-core/src/airflow/models/serialized_dag.py::6
+airflow-core/src/airflow/models/taskinstance.py::21
+airflow-core/src/airflow/models/taskinstancehistory.py::2
+airflow-core/src/airflow/models/team.py::1
+airflow-core/src/airflow/models/trigger.py::7
+airflow-core/src/airflow/models/variable.py::2
+airflow-core/src/airflow/secrets/metastore.py::2
+airflow-core/src/airflow/serialization/definitions/dag.py::2
+airflow-core/src/airflow/ti_deps/deps/base_ti_dep.py::2
+airflow-core/src/airflow/ti_deps/deps/dag_ti_slots_available_dep.py::1
+airflow-core/src/airflow/ti_deps/deps/dag_unpaused_dep.py::1
+airflow-core/src/airflow/ti_deps/deps/dagrun_exists_dep.py::1
+airflow-core/src/airflow/ti_deps/deps/exec_date_after_start_date_dep.py::1
+airflow-core/src/airflow/ti_deps/deps/not_in_retry_period_dep.py::1
+airflow-core/src/airflow/ti_deps/deps/pool_slots_available_dep.py::1
+airflow-core/src/airflow/ti_deps/deps/prev_dagrun_dep.py::1
+airflow-core/src/airflow/ti_deps/deps/ready_to_reschedule.py::1
+airflow-core/src/airflow/ti_deps/deps/runnable_exec_date_dep.py::1
+airflow-core/src/airflow/ti_deps/deps/task_concurrency_dep.py::1
+airflow-core/src/airflow/ti_deps/deps/task_not_running_dep.py::1
+airflow-core/src/airflow/ti_deps/deps/valid_state_dep.py::1
+airflow-core/src/airflow/utils/cli_action_loggers.py::1
+airflow-core/src/airflow/utils/db.py::7
+airflow-core/src/airflow/utils/db_cleanup.py::2
+airflow-core/src/airflow/utils/log/file_task_handler.py::1
+airflow-core/tests/unit/api_fastapi/common/test_exceptions.py::4
+airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_assets.py::19
+airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_connections.py::2
+airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_dag_run.py::1
+airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_dag_tags.py::1
+airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_dag_warning.py::1
+airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_event_logs.py::1
+airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_import_error.py::8
+airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_job.py::1
+airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_monitor.py::2
+airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_pools.py::2
+airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_variables.py::2
+airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_xcom.py::3
+airflow-core/tests/unit/api_fastapi/core_api/routes/ui/test_calendar.py::2
+airflow-core/tests/unit/api_fastapi/core_api/routes/ui/test_dags.py::1
+airflow-core/tests/unit/api_fastapi/core_api/routes/ui/test_gantt.py::1
+airflow-core/tests/unit/api_fastapi/core_api/routes/ui/test_grid.py::1
+airflow-core/tests/unit/cli/commands/test_rotate_fernet_key_command.py::2
+airflow-core/tests/unit/jobs/test_scheduler_job.py::1
+airflow-core/tests/unit/listeners/test_listeners.py::7
+airflow-core/tests/unit/models/test_taskinstance.py::4
+airflow-core/tests/unit/models/test_timestamp.py::2
+providers/amazon/src/airflow/providers/amazon/aws/triggers/emr.py::1
+providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/kubernetes_executor.py::1
+providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/template_rendering.py::1
+providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/triggers/pod.py::1
+providers/common/ai/tests/unit/common/ai/plugins/test_hitl_review.py::1
+providers/databricks/src/airflow/providers/databricks/plugins/databricks_workflow.py::2
+providers/edge3/src/airflow/providers/edge3/executors/edge_executor.py::3
+providers/edge3/src/airflow/providers/edge3/models/edge_worker.py::10
+providers/edge3/src/airflow/providers/edge3/plugins/edge_executor_plugin.py::1
+providers/edge3/src/airflow/providers/edge3/worker_api/routes/logs.py::1
+providers/fab/src/airflow/providers/fab/auth_manager/cli_commands/permissions_command.py::1
+providers/google/src/airflow/providers/google/cloud/triggers/bigquery.py::1
+providers/google/src/airflow/providers/google/cloud/triggers/dataproc.py::3
+providers/openlineage/src/airflow/providers/openlineage/utils/utils.py::1
+providers/standard/src/airflow/providers/standard/operators/trigger_dagrun.py::1
+providers/standard/src/airflow/providers/standard/sensors/external_task.py::1
+providers/standard/src/airflow/providers/standard/utils/sensor_helper.py::1
+providers/standard/tests/unit/standard/sensors/test_external_task_sensor.py::3
diff --git a/scripts/tests/ci/prek/test_check_provide_session_kwargs.py b/scripts/tests/ci/prek/test_check_provide_session_kwargs.py
new file mode 100644
index 00000000000..78b85cd270b
--- /dev/null
+++ b/scripts/tests/ci/prek/test_check_provide_session_kwargs.py
@@ -0,0 +1,482 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import ast
+import os
+import subprocess
+import textwrap
+from pathlib import Path
+
+import pytest
+from ci.prek import check_provide_session_kwargs as hook
+from ci.prek.check_provide_session_kwargs import (
+    AllowlistManager,
+    _check_provide_session_kwargs,
+    _count_violations,
+    _expand_for_allowlist_edits,
+    _has_provide_session_decorator,
+    _iter_positional_session_in_provide_session,
+    _parse_tracked_allowlist,
+    _session_is_positional,
+)
+
+
+@pytest.fixture
+def find_violations(write_python_file):
+    """Factory fixture: write code to a temp file and return positional-session violations."""
+
+    def _check(code: str) -> list[tuple[ast.FunctionDef | ast.AsyncFunctionDef, ast.arg]]:
+        path = write_python_file(code)
+        return list(_iter_positional_session_in_provide_session(path))
+
+    return _check
+
+
+@pytest.fixture
+def create_fake_repo(tmp_path, monkeypatch):
+    """Create a fake repo layout and patch REPO_ROOT so paths resolve correctly."""
+    monkeypatch.setattr(hook, "REPO_ROOT", tmp_path)
+
+    def _write(rel: str, code: str) -> Path:
+        path = tmp_path / rel
+        path.parent.mkdir(parents=True, exist_ok=True)
+        path.write_text(textwrap.dedent(code))
+        return path
+
+    return _write
+
+
+@pytest.fixture
+def create_git_repo(create_fake_repo, tmp_path):
+    """Initialise ``tmp_path`` as a git repo so ``git show HEAD:<file>`` works.
+
+    Returns a helper that commits the current working-tree contents under a given
+    message, so tests can stage a "previous" allowlist at HEAD before mutating it.
+    """
+    env = {
+        **os.environ,
+        "GIT_AUTHOR_NAME": "t",
+        "GIT_AUTHOR_EMAIL": "t@t",
+        "GIT_COMMITTER_NAME": "t",
+        "GIT_COMMITTER_EMAIL": "t@t",
+    }
+
+    def _run(*args: str) -> None:
+        subprocess.run(["git", "-C", str(tmp_path), *args], check=True, env=env, capture_output=True)
+
+    _run("init", "-q", "-b", "main")
+    _run("config", "commit.gpgsign", "false")
+
+    def _commit(message: str) -> None:
+        _run("add", "-A")
+        _run("commit", "-q", "--allow-empty", "-m", message)
+
+    return _commit
+
+
+class TestHasProvideSessionDecorator:
+    def test_provide_session_name(self):
+        func = ast.parse("@provide_session\ndef foo(): pass").body[0]
+        assert _has_provide_session_decorator(func.decorator_list) is True
+
+    def test_provide_session_attribute(self):
+        func = ast.parse("@utils.provide_session\ndef foo(): pass").body[0]
+        assert _has_provide_session_decorator(func.decorator_list) is True
+
+    def test_no_decorator(self):
+        func = ast.parse("def foo(): pass").body[0]
+        assert _has_provide_session_decorator(func.decorator_list) is False
+
+    def test_unrelated_decorator(self):
+        func = ast.parse("@staticmethod\ndef foo(): pass").body[0]
+        assert _has_provide_session_decorator(func.decorator_list) is False
+
+    def test_multiple_decorators_including_provide_session(self):
+        func = ast.parse("@staticmethod\n@provide_session\ndef foo(): pass").body[0]
+        assert _has_provide_session_decorator(func.decorator_list) is True
+
+
+class TestSessionIsPositional:
+    def test_no_session_arg(self):
+        func = ast.parse("def foo(x, y): pass").body[0]
+        assert _session_is_positional(func.args) is None
+
+    def test_session_positional(self):
+        func = ast.parse("def foo(session=NEW_SESSION): pass").body[0]
+        argument = _session_is_positional(func.args)
+        assert argument is not None
+        assert argument.arg == "session"
+
+    def test_session_keyword_only(self):
+        func = ast.parse("def foo(*, session=NEW_SESSION): pass").body[0]
+        assert _session_is_positional(func.args) is None
+
+    def test_session_positional_among_other_args(self):
+        func = ast.parse("def foo(x, y, session=NEW_SESSION): pass").body[0]
+        argument = _session_is_positional(func.args)
+        assert argument is not None
+        assert argument.arg == "session"
+
+    def test_session_kwonly_after_other_positional(self):
+        func = ast.parse("def foo(x, y, *, session=NEW_SESSION): pass").body[0]
+        assert _session_is_positional(func.args) is None
+
+    def test_session_positional_only(self):
+        func = ast.parse("def foo(session, /, x): pass").body[0]
+        argument = _session_is_positional(func.args)
+        assert argument is not None
+        assert argument.arg == "session"
+
+
+class TestIterPositionalSessionInProvideSession:
+    def test_keyword_only_session_is_clean(self, find_violations):
+        code = """\
+        @provide_session
+        def foo(*, session=NEW_SESSION):
+            pass
+        """
+        assert find_violations(code) == []
+
+    def test_positional_session_is_flagged(self, find_violations):
+        code = """\
+        @provide_session
+        def foo(session=NEW_SESSION):
+            pass
+        """
+        violations = find_violations(code)
+        assert len(violations) == 1
+        func, argument = violations[0]
+        assert func.name == "foo"
+        assert argument.arg == "session"
+
+    def test_no_provide_session_decorator_is_ignored(self, find_violations):
+        code = """\
+        def foo(session=NEW_SESSION):
+            pass
+        """
+        assert find_violations(code) == []
+
+    def test_async_function_with_positional_session_is_flagged(self, find_violations):
+        code = """\
+        @provide_session
+        async def foo(session=NEW_SESSION):
+            pass
+        """
+        violations = find_violations(code)
+        assert len(violations) == 1
+
+    def test_method_with_positional_session_is_flagged(self, find_violations):
+        code = """\
+        class C:
+            @provide_session
+            def foo(self, session=NEW_SESSION):
+                pass
+        """
+        violations = find_violations(code)
+        assert len(violations) == 1
+        assert violations[0][0].name == "foo"
+
+    def test_attribute_decorator_is_recognised(self, find_violations):
+        code = """\
+        @airflow.utils.session.provide_session
+        def foo(session=NEW_SESSION):
+            pass
+        """
+        violations = find_violations(code)
+        assert len(violations) == 1
+
+    def test_count_violations_multiple_in_file(self, write_python_file):
+        code = """\
+        @provide_session
+        def a(session=NEW_SESSION):
+            pass
+
+        @provide_session
+        def b(x, session=NEW_SESSION):
+            pass
+
+        @provide_session
+        def c(*, session=NEW_SESSION):
+            pass
+        """
+        path = write_python_file(code)
+        assert _count_violations(path) == 2
+
+    def test_syntax_error_returns_no_violations(self, write_python_file):
+        path = write_python_file("def foo(:\n    pass")
+        assert _count_violations(path) == 0
+
+    def test_invalid_utf8_does_not_crash(self, tmp_path):
+        path = tmp_path / "invalid_utf8.py"
+        path.write_bytes(b"# bad byte: \xff\n@provide_session\ndef 
foo(session=NEW_SESSION):\n    pass\n")
+
+        assert _count_violations(path) == 1
+
+
+class TestAllowlistManager:
+    """Loading and saving the allowlist file through ``AllowlistManager``."""
+
+    def test_load_missing_file_returns_empty(self, tmp_path):
+        """Loading a path that does not exist gives an empty mapping."""
+        manager = AllowlistManager(tmp_path / "missing.txt")
+        assert manager.load() == {}
+
+    def test_save_and_load_round_trip(self, tmp_path):
+        """Entries are written one per line as ``path::count`` and read back unchanged."""
+        manager = AllowlistManager(tmp_path / "allowlist.txt")
+        manager.save({"b/file.py": 2, "a/file.py": 1})
+        # Sorted by key in the file
+        text = (tmp_path / "allowlist.txt").read_text()
+        assert text.splitlines() == ["a/file.py::1", "b/file.py::2"]
+        assert manager.load() == {"a/file.py": 1, "b/file.py": 2}
+
+    def test_load_skips_blank_and_malformed_lines(self, tmp_path):
+        """Only the well-formed ``valid/file.py::3`` line survives loading."""
+        path = tmp_path / "allowlist.txt"
+        # Blank line, missing "::count", empty path and non-integer count are all dropped.
+        path.write_text("\nvalid/file.py::3\nnocount\n::5\nbad::notanumber\n")
+        assert AllowlistManager(path).load() == {"valid/file.py": 3}
+
+    @pytest.mark.usefixtures("create_fake_repo")
+    def test_load_skips_unsafe_entries(self, tmp_path):
+        """Entries that escape REPO_ROOT (absolute paths or `..` segments) are ignored."""
+        path = tmp_path / "allowlist.txt"
+        path.write_text("airflow-core/src/airflow/safe.py::1\n../escape.py::1\n/etc/passwd::1\n")
+        # `create_fake_repo` patches REPO_ROOT to tmp_path so the safety check is meaningful.
+        assert AllowlistManager(path).load() == {"airflow-core/src/airflow/safe.py": 1}
+
+
+class TestCheckProvideSessionKwargs:
+    """Return codes and allowlist rewrites of ``_check_provide_session_kwargs``.
+
+    Each test creates one file with the ``create_fake_repo`` fixture and passes it, an
+    in-memory allowlist and an ``AllowlistManager`` rooted in ``tmp_path`` to the check.
+    """
+
+    def test_no_violations_in_clean_file(self, create_fake_repo, tmp_path):
+        """A file with only keyword-only ``session`` returns 0."""
+        path = create_fake_repo(
+            "airflow-core/src/airflow/clean.py",
+            """\
+            @provide_session
+            def foo(*, session=NEW_SESSION):
+                pass
+            """,
+        )
+        manager = AllowlistManager(tmp_path / "allowlist.txt")
+        assert _check_provide_session_kwargs([path], {}, manager) == 0
+
+    def test_new_violation_fails(self, create_fake_repo, tmp_path):
+        """A positional ``session`` in a file absent from the allowlist returns 1."""
+        path = create_fake_repo(
+            "airflow-core/src/airflow/bad.py",
+            """\
+            @provide_session
+            def foo(session=NEW_SESSION):
+                pass
+            """,
+        )
+        manager = AllowlistManager(tmp_path / "allowlist.txt")
+        assert _check_provide_session_kwargs([path], {}, manager) == 1
+
+    def test_violation_within_allowlist_passes(self, create_fake_repo, tmp_path):
+        """A violation count equal to the allowlisted count returns 0."""
+        path = create_fake_repo(
+            "airflow-core/src/airflow/grandfathered.py",
+            """\
+            @provide_session
+            def foo(session=NEW_SESSION):
+                pass
+            """,
+        )
+        manager = AllowlistManager(tmp_path / "allowlist.txt")
+        allowlist = {"airflow-core/src/airflow/grandfathered.py": 1}
+        assert _check_provide_session_kwargs([path], allowlist, manager) == 0
+
+    def test_exceeding_allowlist_fails(self, create_fake_repo, tmp_path):
+        """Two violations against an allowlisted count of one returns 1."""
+        path = create_fake_repo(
+            "airflow-core/src/airflow/grew.py",
+            """\
+            @provide_session
+            def a(session=NEW_SESSION):
+                pass
+
+            @provide_session
+            def b(session=NEW_SESSION):
+                pass
+            """,
+        )
+        manager = AllowlistManager(tmp_path / "allowlist.txt")
+        allowlist = {"airflow-core/src/airflow/grew.py": 1}
+        assert _check_provide_session_kwargs([path], allowlist, manager) == 1
+
+    def test_reducing_violations_tightens_allowlist(self, create_fake_repo, tmp_path):
+        """Fewer violations than allowlisted lowers the stored count and returns 1."""
+        path = create_fake_repo(
+            "airflow-core/src/airflow/improved.py",
+            """\
+            @provide_session
+            def foo(session=NEW_SESSION):
+                pass
+
+            @provide_session
+            def bar(*, session=NEW_SESSION):
+                pass
+            """,
+        )
+        manager = AllowlistManager(tmp_path / "allowlist.txt")
+        allowlist = {"airflow-core/src/airflow/improved.py": 2}
+        # Exit non-zero so pre-commit reports the modified allowlist
+        assert _check_provide_session_kwargs([path], allowlist, manager) == 1
+        assert manager.load() == {"airflow-core/src/airflow/improved.py": 1}
+
+    def test_fixing_all_violations_removes_entry(self, create_fake_repo, tmp_path):
+        """A file with no violations left is dropped from the stored allowlist; returns 1."""
+        path = create_fake_repo(
+            "airflow-core/src/airflow/fixed.py",
+            """\
+            @provide_session
+            def foo(*, session=NEW_SESSION):
+                pass
+            """,
+        )
+        manager = AllowlistManager(tmp_path / "allowlist.txt")
+        allowlist = {"airflow-core/src/airflow/fixed.py": 1}
+        assert _check_provide_session_kwargs([path], allowlist, manager) == 1
+        assert manager.load() == {}
+
+    def test_non_python_file_is_skipped(self, create_fake_repo, tmp_path):
+        """A ``.txt`` file returns 0 even though its content would be a violation."""
+        path = create_fake_repo(
+            "airflow-core/src/airflow/not_python.txt", "@provide_session\ndef foo(session=N): pass\n"
+        )
+        manager = AllowlistManager(tmp_path / "allowlist.txt")
+        assert _check_provide_session_kwargs([path], {}, manager) == 0
+
+    @pytest.mark.usefixtures("create_fake_repo")
+    def test_missing_allowlist_file_fails_loudly(self, tmp_path):
+        """Passing the allowlist path when the file is missing must fail, not silently pass."""
+        allowlist_path = tmp_path / "allowlist.txt"
+        manager = AllowlistManager(allowlist_path)
+        assert not allowlist_path.exists()
+        assert _check_provide_session_kwargs([allowlist_path.resolve()], {}, manager) == 1
+
+
+class TestExpandForAllowlistEdits:
+    """Path expansion done by ``_expand_for_allowlist_edits`` when the allowlist file is passed.
+
+    Also covers ``_parse_tracked_allowlist`` and the combined expand-then-check flow.
+    """
+
+    def test_unchanged_when_allowlist_not_in_paths(self, create_fake_repo, tmp_path):
+        """Paths are returned unchanged when the allowlist file is not among them."""
+        py = create_fake_repo("airflow-core/src/airflow/x.py", "pass")
+        manager = AllowlistManager(tmp_path / "allowlist.txt")
+        assert _expand_for_allowlist_edits([py], manager, {"airflow-core/src/airflow/x.py": 1}) == [py]
+
+    def test_appends_allowlisted_files_when_allowlist_edited(self, create_fake_repo, tmp_path):
+        """Passing the allowlist file adds the listed files that exist on disk."""
+        allowlist_path = tmp_path / "allowlist.txt"
+        manager = AllowlistManager(allowlist_path)
+        listed = create_fake_repo("airflow-core/src/airflow/listed.py", "pass")
+        # Pass a resolved path — matches production behavior (``main()`` resolves argv).
+        result = _expand_for_allowlist_edits(
+            [allowlist_path.resolve()],
+            manager,
+            {"airflow-core/src/airflow/listed.py": 1, "airflow-core/src/airflow/gone.py": 1},
+        )
+        assert allowlist_path.resolve() in result
+        assert listed in result
+        # File in allowlist that does not exist on disk should be ignored.
+        assert (tmp_path / "airflow-core/src/airflow/gone.py").resolve() not in result
+
+    def test_detection_robust_to_symlinked_allowlist(self, create_fake_repo, tmp_path):
+        """A symlink pointing at the allowlist file must still trigger expansion."""
+        allowlist_path = tmp_path / "allowlist.txt"
+        manager = AllowlistManager(allowlist_path)
+        listed = create_fake_repo("airflow-core/src/airflow/listed.py", "pass")
+        manager.save({"airflow-core/src/airflow/listed.py": 1})
+
+        symlink = tmp_path / "allowlist_link.txt"
+        symlink.symlink_to(allowlist_path)
+
+        # Production resolves argv before calling the helper — a symlinked path resolves
+        # to the real allowlist file and must be recognised as an allowlist edit.
+        result = _expand_for_allowlist_edits([symlink.resolve()], manager, manager.load())
+
+        assert listed in result
+
+    def test_includes_parse_tracked_allowlist_entries_when_removed(
+        self, create_fake_repo, create_git_repo, tmp_path
+    ):
+        """Removing an entry from the allowlist must still re-check the previously-listed file."""
+        rel = "airflow-core/src/airflow/dropped.py"
+        create_fake_repo(
+            rel,
+            """\
+            @provide_session
+            def foo(session=NEW_SESSION):
+                pass
+            """,
+        )
+        allowlist_path = tmp_path / "allowlist.txt"
+        manager = AllowlistManager(allowlist_path)
+        manager.save({rel: 1})
+        create_git_repo("seed allowlist at HEAD")
+
+        # Working tree: remove the entry, but the offending file still exists.
+        allowlist_path.write_text("")
+        current = manager.load()
+        assert current == {}
+
+        expanded = _expand_for_allowlist_edits([allowlist_path.resolve()], manager, current)
+        # The previously-listed file must be re-validated.
+        assert (tmp_path / rel).resolve() in expanded
+
+        # And the full check should fail because the file still has positional sessions.
+        assert _check_provide_session_kwargs(expanded, current, manager) == 1
+
+    @pytest.mark.usefixtures("create_fake_repo")
+    def test_parse_tracked_allowlist_empty_when_no_git_history(self, tmp_path):
+        """Without a git repo the git-tracked allowlist lookup returns empty and does not crash."""
+        manager = AllowlistManager(tmp_path / "allowlist.txt")
+        assert _parse_tracked_allowlist(manager) == {}
+
+    def test_re_validates_listed_files_so_loosening_cannot_bypass(self, create_fake_repo, tmp_path, capsys):
+        """Editing only the allowlist must still trigger validation of listed files."""
+        rel = "airflow-core/src/airflow/loosened.py"
+        create_fake_repo(
+            rel,
+            """\
+            @provide_session
+            def foo(session=NEW_SESSION):
+                pass
+
+            @provide_session
+            def bar(session=NEW_SESSION):
+                pass
+            """,
+        )
+        allowlist_path = tmp_path / "allowlist.txt"
+        manager = AllowlistManager(allowlist_path)
+        # Allowlist loosened to 5 although file only has 2 positional sessions.
+        allowlist = {rel: 5}
+        manager.save(allowlist)
+
+        # Only the allowlist file is "changed"; without re-validation this would return 0.
+        # Resolve the path to mirror what ``main()`` does in production.
+        paths = _expand_for_allowlist_edits([allowlist_path.resolve()], manager, allowlist)
+        rc = _check_provide_session_kwargs(paths, allowlist, manager)
+
+        # Tightened from 5 -> 2, so the hook exits non-zero to surface the modified allowlist.
+        assert rc == 1
+        assert manager.load() == {rel: 2}
+
+
+class TestCleanup:
+    """Tests for ``AllowlistManager.cleanup``."""
+
+    def test_cleanup_removes_stale_entries(self, create_fake_repo, tmp_path):
+        """After cleanup only the entry whose file was created is still listed."""
+        kept = "airflow-core/src/airflow/keeper.py"
+        stale = "airflow-core/src/airflow/gone.py"
+        create_fake_repo(kept, "pass")
+        manager = AllowlistManager(tmp_path / "allowlist.txt")
+        manager.save({kept: 1, stale: 1})
+        assert manager.cleanup() == 0
+        assert manager.load() == {kept: 1}
+
+    def test_cleanup_empty_allowlist(self, tmp_path):
+        """``cleanup`` returns 0 when the allowlist file has never been written."""
+        allowlist_file = tmp_path / "allowlist.txt"
+        assert AllowlistManager(allowlist_file).cleanup() == 0


Reply via email to