Lee-W commented on code in PR #67150:
URL: https://github.com/apache/airflow/pull/67150#discussion_r3301151744


##########
scripts/ci/prek/check_provide_session_kwargs.py:
##########
@@ -0,0 +1,430 @@
+#!/usr/bin/env python
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# /// script
+# requires-python = ">=3.10"
+# dependencies = [
+#   "rich>=13.0.0",
+# ]
+# ///
+"""Check that no new ``@provide_session`` functions declare ``session`` 
positionally.
+
+The project convention is that any function decorated with ``@provide_session``
+must declare ``session`` as keyword-only (after a bare ``*`` in the signature),
+so callers cannot pass it positionally by accident. See
+``contributing-docs/05_pull_requests.rst#database-session-handling``.
+
+All *existing* offenders are recorded in 
``known_provide_session_positional.txt``
+next to this script as ``relative/path::N`` entries (one per file), where ``N``
+is the maximum number of ``@provide_session`` functions with a positional
+``session`` argument allowed in that file. A file whose current count exceeds
+the recorded limit is treated as a violation – move the ``session`` argument
+behind a bare ``*`` instead.
+
+Modes
+-----
+Default (files passed by prek/pre-commit):
+    Check only the supplied files; fail if any file's count exceeds the limit.
+    When a file's count has *decreased*, the allowlist entry is tightened
+    automatically and the hook exits with a non-zero code so that pre-commit
+    reports the modified allowlist – just stage
+    ``scripts/ci/prek/known_provide_session_positional.txt`` and re-run.
+
+``--all-files``:
+    Walk every ``.py`` file under the project source roots
+    (``airflow-core``, ``airflow-ctl``, ``task-sdk``, ``providers``, 
``shared``) —

Review Comment:
   `task-sdk` is a leftover here. Or should we just check whether there's
   session usage there (probably already have one)? Nonetheless, I'm good once
   the leftover part is cleaned up.



##########
scripts/ci/prek/check_provide_session_kwargs.py:
##########
@@ -0,0 +1,430 @@
+#!/usr/bin/env python
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# /// script
+# requires-python = ">=3.10"
+# dependencies = [
+#   "rich>=13.0.0",
+# ]
+# ///
+"""Check that no new ``@provide_session`` functions declare ``session`` 
positionally.
+
+The project convention is that any function decorated with ``@provide_session``
+must declare ``session`` as keyword-only (after a bare ``*`` in the signature),
+so callers cannot pass it positionally by accident. See
+``contributing-docs/05_pull_requests.rst#database-session-handling``.
+
+All *existing* offenders are recorded in 
``known_provide_session_positional.txt``
+next to this script as ``relative/path::N`` entries (one per file), where ``N``
+is the maximum number of ``@provide_session`` functions with a positional
+``session`` argument allowed in that file. A file whose current count exceeds
+the recorded limit is treated as a violation – move the ``session`` argument
+behind a bare ``*`` instead.
+
+Modes
+-----
+Default (files passed by prek/pre-commit):
+    Check only the supplied files; fail if any file's count exceeds the limit.
+    When a file's count has *decreased*, the allowlist entry is tightened
+    automatically and the hook exits with a non-zero code so that pre-commit
+    reports the modified allowlist – just stage
+    ``scripts/ci/prek/known_provide_session_positional.txt`` and re-run.
+
+``--all-files``:
+    Walk every ``.py`` file under the project source roots
+    (``airflow-core``, ``airflow-ctl``, ``task-sdk``, ``providers``, 
``shared``) —

Review Comment:
   ```suggestion
       (``airflow-core``, ``airflow-ctl``, ``providers``, ``shared``) —
   ```



##########
scripts/ci/prek/check_provide_session_kwargs.py:
##########
@@ -0,0 +1,430 @@
+#!/usr/bin/env python
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# /// script
+# requires-python = ">=3.10"
+# dependencies = [
+#   "rich>=13.0.0",
+# ]
+# ///
+"""Check that no new ``@provide_session`` functions declare ``session`` 
positionally.
+
+The project convention is that any function decorated with ``@provide_session``
+must declare ``session`` as keyword-only (after a bare ``*`` in the signature),
+so callers cannot pass it positionally by accident. See
+``contributing-docs/05_pull_requests.rst#database-session-handling``.
+
+All *existing* offenders are recorded in 
``known_provide_session_positional.txt``
+next to this script as ``relative/path::N`` entries (one per file), where ``N``
+is the maximum number of ``@provide_session`` functions with a positional
+``session`` argument allowed in that file. A file whose current count exceeds
+the recorded limit is treated as a violation – move the ``session`` argument
+behind a bare ``*`` instead.
+
+Modes
+-----
+Default (files passed by prek/pre-commit):
+    Check only the supplied files; fail if any file's count exceeds the limit.
+    When a file's count has *decreased*, the allowlist entry is tightened
+    automatically and the hook exits with a non-zero code so that pre-commit
+    reports the modified allowlist – just stage
+    ``scripts/ci/prek/known_provide_session_positional.txt`` and re-run.
+
+``--all-files``:
+    Walk every ``.py`` file under the project source roots
+    (``airflow-core``, ``airflow-ctl``, ``task-sdk``, ``providers``, 
``shared``) —
+    the same scope the pre-commit hook applies to.
+
+``--cleanup``:
+    Remove entries for files that no longer exist. Safe to run at any time;
+    does not add new entries or raise limits.
+
+``--generate``:
+    Scan the same project source roots as ``--all-files`` and *rebuild* the
+    allowlist from scratch. Intended for the initial setup or after a
+    large-scale clean-up sprint.
+"""
+
+from __future__ import annotations
+
+import argparse
+import ast
+import subprocess
+import typing
+from pathlib import Path
+
+from rich.console import Console
+from rich.panel import Panel
+
+console = Console(color_system="standard", width=200)
+
+REPO_ROOT = Path(__file__).parents[3]
+
+_PROVIDE_SESSION_DECORATOR = "provide_session"
+
+# Top-level directories scanned by ``--all-files`` / ``--generate``. Keep in 
sync with the
+# ``files:`` pattern for this hook in ``.pre-commit-config.yaml``.
+_PROJECT_SOURCE_ROOTS = ("airflow-core", "airflow-ctl", "task-sdk", 
"providers", "shared")

Review Comment:
   ```suggestion
   _PROJECT_SOURCE_ROOTS = ("airflow-core", "airflow-ctl", "providers", "shared")
   ```



##########
scripts/ci/prek/check_provide_session_kwargs.py:
##########
@@ -0,0 +1,430 @@
+#!/usr/bin/env python
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# /// script
+# requires-python = ">=3.10"
+# dependencies = [
+#   "rich>=13.0.0",
+# ]
+# ///
+"""Check that no new ``@provide_session`` functions declare ``session`` 
positionally.
+
+The project convention is that any function decorated with ``@provide_session``
+must declare ``session`` as keyword-only (after a bare ``*`` in the signature),
+so callers cannot pass it positionally by accident. See
+``contributing-docs/05_pull_requests.rst#database-session-handling``.
+
+All *existing* offenders are recorded in 
``known_provide_session_positional.txt``
+next to this script as ``relative/path::N`` entries (one per file), where ``N``
+is the maximum number of ``@provide_session`` functions with a positional
+``session`` argument allowed in that file. A file whose current count exceeds
+the recorded limit is treated as a violation – move the ``session`` argument
+behind a bare ``*`` instead.
+
+Modes
+-----
+Default (files passed by prek/pre-commit):
+    Check only the supplied files; fail if any file's count exceeds the limit.
+    When a file's count has *decreased*, the allowlist entry is tightened
+    automatically and the hook exits with a non-zero code so that pre-commit
+    reports the modified allowlist – just stage
+    ``scripts/ci/prek/known_provide_session_positional.txt`` and re-run.
+
+``--all-files``:
+    Walk every ``.py`` file under the project source roots
+    (``airflow-core``, ``airflow-ctl``, ``task-sdk``, ``providers``, 
``shared``) —
+    the same scope the pre-commit hook applies to.
+
+``--cleanup``:
+    Remove entries for files that no longer exist. Safe to run at any time;
+    does not add new entries or raise limits.
+
+``--generate``:
+    Scan the same project source roots as ``--all-files`` and *rebuild* the
+    allowlist from scratch. Intended for the initial setup or after a
+    large-scale clean-up sprint.
+"""
+
+from __future__ import annotations
+
+import argparse
+import ast
+import subprocess
+import typing
+from pathlib import Path
+
+from rich.console import Console
+from rich.panel import Panel
+
+console = Console(color_system="standard", width=200)
+
+REPO_ROOT = Path(__file__).parents[3]
+
+_PROVIDE_SESSION_DECORATOR = "provide_session"
+
+# Top-level directories scanned by ``--all-files`` / ``--generate``. Keep in 
sync with the
+# ``files:`` pattern for this hook in ``.pre-commit-config.yaml``.
+_PROJECT_SOURCE_ROOTS = ("airflow-core", "airflow-ctl", "task-sdk", 
"providers", "shared")
+
+
+def _has_provide_session_decorator(nodes: list[ast.expr]) -> bool:
+    """Whether one of ``nodes`` is a ``@provide_session`` decorator.
+
+    Accepts both bare names (``@provide_session``) and attribute access
+    (``@something.provide_session``).
+    """
+    for node in nodes:
+        if isinstance(node, ast.Name) and node.id == 
_PROVIDE_SESSION_DECORATOR:
+            return True
+        if isinstance(node, ast.Attribute) and node.attr == 
_PROVIDE_SESSION_DECORATOR:
+            return True
+    return False
+
+
+def _session_is_positional(args: ast.arguments) -> ast.arg | None:
+    """Return the ``session`` arg if it is positional (not keyword-only).
+
+    Covers both regular positional args and positional-only args (``def 
f(session, /, ...)``).
+    """
+    for argument in (*args.posonlyargs, *args.args):
+        if argument.arg == "session":
+            return argument
+    return None
+
+
+def _iter_positional_session_in_provide_session(
+    path: Path,
+) -> typing.Iterator[tuple[ast.FunctionDef | ast.AsyncFunctionDef, ast.arg]]:
+    """Yield ``@provide_session`` functions in *path* whose ``session`` is 
positional."""
+    try:
+        source = path.read_text(encoding="utf-8", errors="replace")
+    except OSError:
+        return
+    try:
+        tree = ast.parse(source, str(path))
+    except SyntaxError:
+        return
+    for node in ast.walk(tree):
+        if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
+            continue
+        if not _has_provide_session_decorator(node.decorator_list):
+            continue
+        argument = _session_is_positional(node.args)
+        if argument is None:
+            continue
+        yield node, argument
+
+
+def _count_violations(path: Path) -> int:
+    return sum(1 for _ in _iter_positional_session_in_provide_session(path))
+
+
+def _is_safe_relative(rel: str) -> bool:
+    """Whether ``rel`` is a plain relative path that stays inside 
``REPO_ROOT``.
+
+    Rejects absolute paths and any entry that resolves outside the repo root so
+    callers can ``relative_to(REPO_ROOT)`` without fear of a ``ValueError``.
+    """
+    candidate = Path(rel)
+    if candidate.is_absolute():
+        return False
+    try:
+        (REPO_ROOT / candidate).resolve().relative_to(REPO_ROOT.resolve())
+    except ValueError:
+        return False
+    return True
+
+
+class AllowlistManager:
+    def __init__(self, allowlist_file: Path) -> None:
+        self.allowlist_file = allowlist_file
+
+    @staticmethod
+    def parse(text: str) -> dict[str, int]:
+        """Parse allowlist *text* into a ``{rel_path: count}`` mapping.
+
+        Same validation rules as :meth:`load` so we can reuse parsing for the
+        on-disk allowlist *and* for the previous version fetched from git when
+        guarding against entry-removal bypasses.
+        """
+        result: dict[str, int] = {}
+        for raw_line in text.splitlines():
+            if not (stripped := raw_line.strip()):
+                continue
+
+            rel_str, _, count_str = stripped.rpartition("::")
+            if not rel_str or not count_str:
+                continue
+
+            try:
+                count = int(count_str)
+            except ValueError:
+                continue
+
+            if not _is_safe_relative(rel_str):
+                console.print(
+                    f"[yellow]Ignoring unsafe allowlist entry (escapes repo 
root):[/yellow] {rel_str}"
+                )
+                continue
+
+            result[rel_str] = count
+
+        return result
+
+    def load(self) -> dict[str, int]:
+        if not self.allowlist_file.exists():
+            return {}
+        return self.parse(self.allowlist_file.read_text())
+
+    def save(self, counts: dict[str, int]) -> None:
+        lines = [f"{rel}::{count}" for rel, count in sorted(counts.items())]
+        self.allowlist_file.write_text("\n".join(lines) + "\n")
+
+    def generate(self) -> int:
+        roots = ", ".join(_PROJECT_SOURCE_ROOTS)
+        console.print(
+            f"Scanning project source roots ([cyan]{roots}[/cyan]) under 
[cyan]{REPO_ROOT}[/cyan] "
+            "for @provide_session functions with positional session …"
+        )
+        counts: dict[str, int] = {}
+        for path in _iter_python_files():
+            n = _count_violations(path)
+            if n > 0:
+                counts[str(path.relative_to(REPO_ROOT))] = n
+
+        self.save(counts)
+        total = sum(counts.values())
+        console.print(
+            f"[green]Generated[/green] 
[cyan]{self.allowlist_file.relative_to(REPO_ROOT)}[/cyan] "
+            f"with [bold]{len(counts)}[/bold] files / [bold]{total}[/bold] 
offenders."
+        )
+        return 0
+
+    def cleanup(self) -> int:
+        allowlist = self.load()
+        if not allowlist:
+            console.print("[yellow]Allowlist is empty - nothing to clean 
up.[/yellow]")
+            return 0
+
+        stale: list[str] = [rel for rel in allowlist if not (REPO_ROOT / 
rel).exists()]
+        if stale:
+            console.print(
+                f"[yellow]Removing {len(stale)} stale entr{'y' if len(stale) 
== 1 else 'ies'}:[/yellow]"
+            )
+            for s in sorted(stale):
+                console.print(f"  [dim]-[/dim] {s}")
+            for s in stale:
+                del allowlist[s]
+            self.save(allowlist)
+            console.print(
+                f"\n[green]Updated[/green] 
[cyan]{self.allowlist_file.relative_to(REPO_ROOT)}[/cyan]"
+            )
+        else:
+            console.print("[green]No stale entries found.[/green]")
+        return 0
+
+
+def _iter_python_files() -> list[Path]:
+    candidates: list[Path] = []
+    for top in _PROJECT_SOURCE_ROOTS:
+        candidates.extend(
+            p.resolve()
+            for p in (REPO_ROOT / top).rglob("*.py")
+            if ".tox" not in p.parts and "__pycache__" not in p.parts
+        )
+    return candidates
+
+
+def _check_provide_session_kwargs(
+    files: list[Path], allowlist: dict[str, int], manager: AllowlistManager
+) -> int:
+    allowlist_file = manager.allowlist_file.resolve()
+    if any(p.resolve() == allowlist_file for p in files) and not 
allowlist_file.exists():
+        console.print(
+            Panel.fit(
+                f"Allowlist file [cyan]{allowlist_file}[/cyan] is missing.\n"
+                "It was passed to the hook but cannot be read, so the check 
cannot proceed.\n"
+                "Restore it from git or regenerate it with:\n\n"
+                "  [cyan]uv run 
./scripts/ci/prek/check_provide_session_kwargs.py --generate[/cyan]",
+                title="[red]Check failed[/red]",
+                border_style="red",
+            )
+        )
+        return 1

Review Comment:
   Returning 1 or 0 from a function like this isn't ideal 🤔 — it's a bit confusing.
   It works "ok" for a command-line tool, so I won't block on it.



##########
scripts/tests/ci/prek/test_check_provide_session_kwargs.py:
##########
@@ -0,0 +1,477 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import ast
+import os
+import subprocess
+import textwrap
+from pathlib import Path
+
+import pytest
+from ci.prek import check_provide_session_kwargs as hook
+from ci.prek.check_provide_session_kwargs import (
+    AllowlistManager,
+    _check_provide_session_kwargs,
+    _count_violations,
+    _expand_for_allowlist_edits,
+    _has_provide_session_decorator,
+    _iter_positional_session_in_provide_session,
+    _previous_allowlist,
+    _session_is_positional,
+)
+
+
+@pytest.fixture
+def find_violations(write_python_file):
+    """Factory fixture: write code to a temp file and return 
positional-session violations."""
+
+    def _check(code: str) -> list[tuple[ast.FunctionDef | 
ast.AsyncFunctionDef, ast.arg]]:
+        path = write_python_file(code)
+        return list(_iter_positional_session_in_provide_session(path))
+
+    return _check
+
+
+@pytest.fixture
+def fake_repo(tmp_path, monkeypatch):
+    """Create a fake repo layout and patch REPO_ROOT so paths resolve 
correctly."""
+    monkeypatch.setattr(hook, "REPO_ROOT", tmp_path)
+
+    def _write(rel: str, code: str) -> Path:
+        path = tmp_path / rel
+        path.parent.mkdir(parents=True, exist_ok=True)
+        path.write_text(textwrap.dedent(code))
+        return path
+
+    return _write
+
+
+@pytest.fixture
+def git_repo(fake_repo, tmp_path):
+    """Initialise ``tmp_path`` as a git repo so ``git show HEAD:<file>`` works.
+
+    Returns a helper that commits the current working-tree contents under a 
given
+    message, so tests can stage a "previous" allowlist at HEAD before mutating 
it.
+    """
+    env = {
+        **os.environ,
+        "GIT_AUTHOR_NAME": "t",
+        "GIT_AUTHOR_EMAIL": "t@t",
+        "GIT_COMMITTER_NAME": "t",
+        "GIT_COMMITTER_EMAIL": "t@t",
+    }
+
+    def _run(*args: str) -> None:
+        subprocess.run(["git", "-C", str(tmp_path), *args], check=True, 
env=env, capture_output=True)
+
+    _run("init", "-q", "-b", "main")
+    _run("config", "commit.gpgsign", "false")
+
+    def _commit(message: str) -> None:
+        _run("add", "-A")
+        _run("commit", "-q", "--allow-empty", "-m", message)
+
+    return _commit
+
+
+class TestHasProvideSessionDecorator:
+    def test_provide_session_name(self):
+        func = ast.parse("@provide_session\ndef foo(): pass").body[0]
+        assert _has_provide_session_decorator(func.decorator_list) is True
+
+    def test_provide_session_attribute(self):
+        func = ast.parse("@utils.provide_session\ndef foo(): pass").body[0]
+        assert _has_provide_session_decorator(func.decorator_list) is True
+
+    def test_no_decorator(self):
+        func = ast.parse("def foo(): pass").body[0]
+        assert _has_provide_session_decorator(func.decorator_list) is False
+
+    def test_unrelated_decorator(self):
+        func = ast.parse("@staticmethod\ndef foo(): pass").body[0]
+        assert _has_provide_session_decorator(func.decorator_list) is False
+
+    def test_multiple_decorators_including_provide_session(self):
+        func = ast.parse("@staticmethod\n@provide_session\ndef foo(): 
pass").body[0]
+        assert _has_provide_session_decorator(func.decorator_list) is True
+
+
+class TestSessionIsPositional:
+    def test_no_session_arg(self):
+        func = ast.parse("def foo(x, y): pass").body[0]
+        assert _session_is_positional(func.args) is None
+
+    def test_session_positional(self):
+        func = ast.parse("def foo(session=NEW_SESSION): pass").body[0]
+        argument = _session_is_positional(func.args)
+        assert argument is not None
+        assert argument.arg == "session"
+
+    def test_session_keyword_only(self):
+        func = ast.parse("def foo(*, session=NEW_SESSION): pass").body[0]
+        assert _session_is_positional(func.args) is None
+
+    def test_session_positional_among_other_args(self):
+        func = ast.parse("def foo(x, y, session=NEW_SESSION): pass").body[0]
+        argument = _session_is_positional(func.args)
+        assert argument is not None
+        assert argument.arg == "session"
+
+    def test_session_kwonly_after_other_positional(self):
+        func = ast.parse("def foo(x, y, *, session=NEW_SESSION): pass").body[0]
+        assert _session_is_positional(func.args) is None
+
+    def test_session_positional_only(self):
+        func = ast.parse("def foo(session, /, x): pass").body[0]
+        argument = _session_is_positional(func.args)
+        assert argument is not None
+        assert argument.arg == "session"
+
+
+class TestIterPositionalSessionInProvideSession:
+    def test_keyword_only_session_is_clean(self, find_violations):
+        code = """\
+        @provide_session
+        def foo(*, session=NEW_SESSION):
+            pass
+        """
+        assert find_violations(code) == []
+
+    def test_positional_session_is_flagged(self, find_violations):
+        code = """\
+        @provide_session
+        def foo(session=NEW_SESSION):
+            pass
+        """
+        violations = find_violations(code)
+        assert len(violations) == 1
+        func, argument = violations[0]
+        assert func.name == "foo"
+        assert argument.arg == "session"
+
+    def test_no_provide_session_decorator_is_ignored(self, find_violations):
+        code = """\
+        def foo(session=NEW_SESSION):
+            pass
+        """
+        assert find_violations(code) == []
+
+    def test_async_function_with_positional_session_is_flagged(self, 
find_violations):
+        code = """\
+        @provide_session
+        async def foo(session=NEW_SESSION):
+            pass
+        """
+        violations = find_violations(code)
+        assert len(violations) == 1
+
+    def test_method_with_positional_session_is_flagged(self, find_violations):
+        code = """\
+        class C:
+            @provide_session
+            def foo(self, session=NEW_SESSION):
+                pass
+        """
+        violations = find_violations(code)
+        assert len(violations) == 1
+        assert violations[0][0].name == "foo"
+
+    def test_attribute_decorator_is_recognised(self, find_violations):
+        code = """\
+        @airflow.utils.session.provide_session
+        def foo(session=NEW_SESSION):
+            pass
+        """
+        violations = find_violations(code)
+        assert len(violations) == 1
+
+    def test_count_violations_multiple_in_file(self, write_python_file):
+        code = """\
+        @provide_session
+        def a(session=NEW_SESSION):
+            pass
+
+        @provide_session
+        def b(x, session=NEW_SESSION):
+            pass
+
+        @provide_session
+        def c(*, session=NEW_SESSION):
+            pass
+        """
+        path = write_python_file(code)
+        assert _count_violations(path) == 2
+
+    def test_syntax_error_returns_no_violations(self, write_python_file):
+        path = write_python_file("def foo(:\n    pass")
+        assert _count_violations(path) == 0
+
+    def test_invalid_utf8_does_not_crash(self, tmp_path):
+        path = tmp_path / "invalid_utf8.py"
+        path.write_bytes(b"# bad byte: \xff\n@provide_session\ndef 
foo(session=NEW_SESSION):\n    pass\n")
+
+        assert _count_violations(path) == 1
+
+
+class TestAllowlistManager:
+    def test_load_missing_file_returns_empty(self, tmp_path):
+        manager = AllowlistManager(tmp_path / "missing.txt")
+        assert manager.load() == {}
+
+    def test_save_and_load_round_trip(self, tmp_path):
+        manager = AllowlistManager(tmp_path / "allowlist.txt")
+        manager.save({"b/file.py": 2, "a/file.py": 1})
+        # Sorted by key in the file
+        text = (tmp_path / "allowlist.txt").read_text()
+        assert text.splitlines() == ["a/file.py::1", "b/file.py::2"]
+        assert manager.load() == {"a/file.py": 1, "b/file.py": 2}
+
+    def test_load_skips_blank_and_malformed_lines(self, tmp_path):
+        path = tmp_path / "allowlist.txt"
+        path.write_text("\nvalid/file.py::3\nnocount\n::5\nbad::notanumber\n")
+        assert AllowlistManager(path).load() == {"valid/file.py": 3}
+
+    def test_load_skips_unsafe_entries(self, fake_repo, tmp_path):

Review Comment:
   ```suggestion
       @pytest.mark.usefixtures("fake_repo")
       def test_load_skips_unsafe_entries(self, tmp_path):
   ```



##########
scripts/ci/prek/check_provide_session_kwargs.py:
##########
@@ -0,0 +1,430 @@
+#!/usr/bin/env python
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# /// script
+# requires-python = ">=3.10"
+# dependencies = [
+#   "rich>=13.0.0",
+# ]
+# ///
+"""Check that no new ``@provide_session`` functions declare ``session`` 
positionally.
+
+The project convention is that any function decorated with ``@provide_session``
+must declare ``session`` as keyword-only (after a bare ``*`` in the signature),
+so callers cannot pass it positionally by accident. See
+``contributing-docs/05_pull_requests.rst#database-session-handling``.
+
+All *existing* offenders are recorded in 
``known_provide_session_positional.txt``
+next to this script as ``relative/path::N`` entries (one per file), where ``N``
+is the maximum number of ``@provide_session`` functions with a positional
+``session`` argument allowed in that file. A file whose current count exceeds
+the recorded limit is treated as a violation – move the ``session`` argument
+behind a bare ``*`` instead.
+
+Modes
+-----
+Default (files passed by prek/pre-commit):
+    Check only the supplied files; fail if any file's count exceeds the limit.
+    When a file's count has *decreased*, the allowlist entry is tightened
+    automatically and the hook exits with a non-zero code so that pre-commit
+    reports the modified allowlist – just stage
+    ``scripts/ci/prek/known_provide_session_positional.txt`` and re-run.
+
+``--all-files``:
+    Walk every ``.py`` file under the project source roots
+    (``airflow-core``, ``airflow-ctl``, ``task-sdk``, ``providers``, 
``shared``) —
+    the same scope the pre-commit hook applies to.
+
+``--cleanup``:
+    Remove entries for files that no longer exist. Safe to run at any time;
+    does not add new entries or raise limits.
+
+``--generate``:
+    Scan the same project source roots as ``--all-files`` and *rebuild* the
+    allowlist from scratch. Intended for the initial setup or after a
+    large-scale clean-up sprint.
+"""
+
+from __future__ import annotations
+
+import argparse
+import ast
+import subprocess
+import typing
+from pathlib import Path
+
+from rich.console import Console
+from rich.panel import Panel
+
+console = Console(color_system="standard", width=200)
+
+REPO_ROOT = Path(__file__).parents[3]
+
+_PROVIDE_SESSION_DECORATOR = "provide_session"
+
+# Top-level directories scanned by ``--all-files`` / ``--generate``. Keep in 
sync with the
+# ``files:`` pattern for this hook in ``.pre-commit-config.yaml``.
+_PROJECT_SOURCE_ROOTS = ("airflow-core", "airflow-ctl", "task-sdk", 
"providers", "shared")
+
+
+def _has_provide_session_decorator(nodes: list[ast.expr]) -> bool:
+    """Whether one of ``nodes`` is a ``@provide_session`` decorator.
+
+    Accepts both bare names (``@provide_session``) and attribute access
+    (``@something.provide_session``).
+    """
+    for node in nodes:
+        if isinstance(node, ast.Name) and node.id == 
_PROVIDE_SESSION_DECORATOR:
+            return True
+        if isinstance(node, ast.Attribute) and node.attr == 
_PROVIDE_SESSION_DECORATOR:
+            return True
+    return False
+
+
+def _session_is_positional(args: ast.arguments) -> ast.arg | None:
+    """Return the ``session`` arg if it is positional (not keyword-only).
+
+    Covers both regular positional args and positional-only args (``def 
f(session, /, ...)``).
+    """
+    for argument in (*args.posonlyargs, *args.args):
+        if argument.arg == "session":
+            return argument
+    return None
+
+
+def _iter_positional_session_in_provide_session(
+    path: Path,
+) -> typing.Iterator[tuple[ast.FunctionDef | ast.AsyncFunctionDef, ast.arg]]:
+    """Yield ``@provide_session`` functions in *path* whose ``session`` is 
positional."""
+    try:
+        source = path.read_text(encoding="utf-8", errors="replace")
+    except OSError:
+        return
+    try:
+        tree = ast.parse(source, str(path))
+    except SyntaxError:
+        return
+    for node in ast.walk(tree):
+        if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
+            continue
+        if not _has_provide_session_decorator(node.decorator_list):
+            continue
+        argument = _session_is_positional(node.args)
+        if argument is None:
+            continue
+        yield node, argument
+
+
+def _count_violations(path: Path) -> int:
+    return sum(1 for _ in _iter_positional_session_in_provide_session(path))
+
+
+def _is_safe_relative(rel: str) -> bool:
+    """Whether ``rel`` is a plain relative path that stays inside 
``REPO_ROOT``.
+
+    Rejects absolute paths and any entry that resolves outside the repo root so
+    callers can ``relative_to(REPO_ROOT)`` without fear of a ``ValueError``.
+    """
+    candidate = Path(rel)
+    if candidate.is_absolute():
+        return False
+    try:
+        (REPO_ROOT / candidate).resolve().relative_to(REPO_ROOT.resolve())
+    except ValueError:
+        return False
+    return True
+
+
+class AllowlistManager:
+    def __init__(self, allowlist_file: Path) -> None:
+        self.allowlist_file = allowlist_file
+
+    @staticmethod
+    def parse(text: str) -> dict[str, int]:
+        """Parse allowlist *text* into a ``{rel_path: count}`` mapping.
+
+        Same validation rules as :meth:`load` so we can reuse parsing for the
+        on-disk allowlist *and* for the previous version fetched from git when
+        guarding against entry-removal bypasses.
+        """
+        result: dict[str, int] = {}
+        for raw_line in text.splitlines():
+            if not (stripped := raw_line.strip()):
+                continue
+
+            rel_str, _, count_str = stripped.rpartition("::")
+            if not rel_str or not count_str:
+                continue
+
+            try:
+                count = int(count_str)
+            except ValueError:
+                continue
+
+            if not _is_safe_relative(rel_str):
+                console.print(
+                    f"[yellow]Ignoring unsafe allowlist entry (escapes repo 
root):[/yellow] {rel_str}"
+                )
+                continue
+
+            result[rel_str] = count
+
+        return result
+
+    def load(self) -> dict[str, int]:
+        if not self.allowlist_file.exists():
+            return {}
+        return self.parse(self.allowlist_file.read_text())
+
+    def save(self, counts: dict[str, int]) -> None:
+        lines = [f"{rel}::{count}" for rel, count in sorted(counts.items())]
+        self.allowlist_file.write_text("\n".join(lines) + "\n")
+
+    def generate(self) -> int:
+        roots = ", ".join(_PROJECT_SOURCE_ROOTS)
+        console.print(
+            f"Scanning project source roots ([cyan]{roots}[/cyan]) under 
[cyan]{REPO_ROOT}[/cyan] "
+            "for @provide_session functions with positional session …"
+        )
+        counts: dict[str, int] = {}
+        for path in _iter_python_files():
+            n = _count_violations(path)
+            if n > 0:
+                counts[str(path.relative_to(REPO_ROOT))] = n
+
+        self.save(counts)
+        total = sum(counts.values())
+        console.print(
+            f"[green]Generated[/green] 
[cyan]{self.allowlist_file.relative_to(REPO_ROOT)}[/cyan] "
+            f"with [bold]{len(counts)}[/bold] files / [bold]{total}[/bold] 
offenders."
+        )
+        return 0
+
+    def cleanup(self) -> int:
+        allowlist = self.load()
+        if not allowlist:
+            console.print("[yellow]Allowlist is empty - nothing to clean 
up.[/yellow]")
+            return 0
+
+        stale: list[str] = [rel for rel in allowlist if not (REPO_ROOT / 
rel).exists()]
+        if stale:
+            console.print(
+                f"[yellow]Removing {len(stale)} stale entr{'y' if len(stale) 
== 1 else 'ies'}:[/yellow]"
+            )
+            for s in sorted(stale):
+                console.print(f"  [dim]-[/dim] {s}")
+            for s in stale:
+                del allowlist[s]
+            self.save(allowlist)
+            console.print(
+                f"\n[green]Updated[/green] 
[cyan]{self.allowlist_file.relative_to(REPO_ROOT)}[/cyan]"
+            )
+        else:
+            console.print("[green]No stale entries found.[/green]")
+        return 0
+
+
+def _iter_python_files() -> list[Path]:
+    candidates: list[Path] = []
+    for top in _PROJECT_SOURCE_ROOTS:
+        candidates.extend(
+            p.resolve()
+            for p in (REPO_ROOT / top).rglob("*.py")
+            if ".tox" not in p.parts and "__pycache__" not in p.parts
+        )
+    return candidates
+
+
+def _check_provide_session_kwargs(
+    files: list[Path], allowlist: dict[str, int], manager: AllowlistManager
+) -> int:
+    allowlist_file = manager.allowlist_file.resolve()
+    if any(p.resolve() == allowlist_file for p in files) and not 
allowlist_file.exists():
+        console.print(
+            Panel.fit(
+                f"Allowlist file [cyan]{allowlist_file}[/cyan] is missing.\n"
+                "It was passed to the hook but cannot be read, so the check 
cannot proceed.\n"
+                "Restore it from git or regenerate it with:\n\n"
+                "  [cyan]uv run 
./scripts/ci/prek/check_provide_session_kwargs.py --generate[/cyan]",
+                title="[red]Check failed[/red]",
+                border_style="red",
+            )
+        )
+        return 1
+
+    violations: list[tuple[Path, int, int]] = []
+    tightened: list[tuple[str, int, int]] = []
+
+    for path in files:
+        if not path.exists() or path.suffix != ".py":
+            continue
+        actual = _count_violations(path)
+        rel = str(path.relative_to(REPO_ROOT))
+        allowed = allowlist.get(rel, 0)
+        if actual > allowed:
+            violations.append((path, actual, allowed))
+        elif actual < allowed:
+            if actual == 0:
+                del allowlist[rel]
+            else:
+                allowlist[rel] = actual
+            tightened.append((rel, allowed, actual))
+
+    if tightened:
+        manager.save(allowlist)
+        console.print(
+            f"[green]Tightened {len(tightened)} entr{'y' if len(tightened) == 
1 else 'ies'} "
+            f"in 
[cyan]{manager.allowlist_file.relative_to(REPO_ROOT)}[/cyan][/green] "
+            "(stage the updated file):"
+        )
+        for rel, old, new in tightened:
+            console.print(f"  [cyan]{rel}[/cyan]  {old} -> {new}")
+
+    if violations:
+        console.print(
+            Panel.fit(
+                "New [bold]@provide_session[/bold] function with positional 
``session`` detected.\n"
+                "Move ``session`` after a bare ``*`` in the signature so 
callers must pass it by keyword:\n\n"
+                "  [cyan]@provide_session\n"
+                "  def foo(arg, *, session: Session = NEW_SESSION) -> None: 
...[/cyan]\n\n"
+                "If this usage is intentional and pre-existing, run:\n\n"
+                "  [cyan]uv run 
./scripts/ci/prek/check_provide_session_kwargs.py --generate[/cyan]\n\n"
+                "to regenerate the allowlist, then commit the updated\n"
+                "[cyan]scripts/ci/prek/known_provide_session_positional.txt[/cyan].",
+                title="[red]Check failed[/red]",
+                border_style="red",
+            )
+        )
+        for path, actual, allowed in violations:
+            console.print(f"  [cyan]{path.relative_to(REPO_ROOT)}[/cyan]  
count={actual} (allowed={allowed})")
+            for func, argument in 
_iter_positional_session_in_provide_session(path):
+                console.print(f"      [dim]L{argument.lineno}[/dim] def 
{func.name}(...)")
+        return 1
+
+    return 1 if tightened else 0
+
+
+def main(argv: list[str] | None = None) -> int:
+    parser = argparse.ArgumentParser(
+        description="Prevent new @provide_session functions from declaring 
`session` positionally.",
+        formatter_class=argparse.RawDescriptionHelpFormatter,
+        epilog=__doc__,
+    )
+    parser.add_argument("files", nargs="*", metavar="FILE", help="Files to 
check (provided by prek)")
+    parser.add_argument(
+        "--all-files",
+        action="store_true",
+        help=(
+            "Check every Python file under the project source roots "
+            "(airflow-core, airflow-ctl, task-sdk, providers, shared)"
+        ),
+    )
+    parser.add_argument(
+        "--cleanup",
+        action="store_true",
+        help="Remove stale entries from the allowlist and exit",
+    )
+    parser.add_argument(
+        "--generate",
+        action="store_true",
+        help="Regenerate the allowlist from the current codebase and exit",
+    )
+    args = parser.parse_args(argv)
+
+    manager = AllowlistManager(Path(__file__).parent / 
"known_provide_session_positional.txt")
+
+    if args.generate:
+        return manager.generate()
+
+    if args.cleanup:
+        return manager.cleanup()
+
+    allowlist = manager.load()
+
+    if args.all_files:
+        return _check_provide_session_kwargs(_iter_python_files(), allowlist, 
manager)
+
+    if not args.files:
+        console.print(
+            "[yellow]No files provided. Pass filenames or use --all-files to 
scan the whole repo.[/yellow]"
+        )
+        return 0
+
+    paths = [Path(f).resolve() for f in args.files]
+    paths = _expand_for_allowlist_edits(paths, manager, allowlist)
+    return _check_provide_session_kwargs(paths, allowlist, manager)
+
+
+def _previous_allowlist(manager: AllowlistManager) -> dict[str, int]:

Review Comment:
   ```suggestion
   def _parse_tracked_allowlist(manager: AllowlistManager) -> dict[str, int]:
   ```
   
   Let's make it a verb. Instead of "previous", I think we can use the git
   state "tracked" as part of the function name.



##########
scripts/tests/ci/prek/test_check_provide_session_kwargs.py:
##########
@@ -0,0 +1,477 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import ast
+import os
+import subprocess
+import textwrap
+from pathlib import Path
+
+import pytest
+from ci.prek import check_provide_session_kwargs as hook
+from ci.prek.check_provide_session_kwargs import (
+    AllowlistManager,
+    _check_provide_session_kwargs,
+    _count_violations,
+    _expand_for_allowlist_edits,
+    _has_provide_session_decorator,
+    _iter_positional_session_in_provide_session,
+    _previous_allowlist,
+    _session_is_positional,
+)
+
+
+@pytest.fixture
+def find_violations(write_python_file):
+    """Factory fixture: write code to a temp file and return 
positional-session violations."""
+
+    def _check(code: str) -> list[tuple[ast.FunctionDef | 
ast.AsyncFunctionDef, ast.arg]]:
+        path = write_python_file(code)
+        return list(_iter_positional_session_in_provide_session(path))
+
+    return _check
+
+
+@pytest.fixture
+def fake_repo(tmp_path, monkeypatch):

Review Comment:
   ```suggestion
   def create_fake_repo(tmp_path, monkeypatch):
   ```



##########
scripts/tests/ci/prek/test_check_provide_session_kwargs.py:
##########
@@ -0,0 +1,477 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import ast
+import os
+import subprocess
+import textwrap
+from pathlib import Path
+
+import pytest
+from ci.prek import check_provide_session_kwargs as hook
+from ci.prek.check_provide_session_kwargs import (
+    AllowlistManager,
+    _check_provide_session_kwargs,
+    _count_violations,
+    _expand_for_allowlist_edits,
+    _has_provide_session_decorator,
+    _iter_positional_session_in_provide_session,
+    _previous_allowlist,
+    _session_is_positional,
+)
+
+
+@pytest.fixture
+def find_violations(write_python_file):
+    """Factory fixture: write code to a temp file and return 
positional-session violations."""
+
+    def _check(code: str) -> list[tuple[ast.FunctionDef | 
ast.AsyncFunctionDef, ast.arg]]:
+        path = write_python_file(code)
+        return list(_iter_positional_session_in_provide_session(path))
+
+    return _check
+
+
+@pytest.fixture
+def fake_repo(tmp_path, monkeypatch):
+    """Create a fake repo layout and patch REPO_ROOT so paths resolve 
correctly."""
+    monkeypatch.setattr(hook, "REPO_ROOT", tmp_path)
+
+    def _write(rel: str, code: str) -> Path:
+        path = tmp_path / rel
+        path.parent.mkdir(parents=True, exist_ok=True)
+        path.write_text(textwrap.dedent(code))
+        return path
+
+    return _write
+
+
+@pytest.fixture
+def git_repo(fake_repo, tmp_path):
+    """Initialise ``tmp_path`` as a git repo so ``git show HEAD:<file>`` works.
+
+    Returns a helper that commits the current working-tree contents under a 
given
+    message, so tests can stage a "previous" allowlist at HEAD before mutating 
it.
+    """
+    env = {
+        **os.environ,
+        "GIT_AUTHOR_NAME": "t",
+        "GIT_AUTHOR_EMAIL": "t@t",
+        "GIT_COMMITTER_NAME": "t",
+        "GIT_COMMITTER_EMAIL": "t@t",
+    }
+
+    def _run(*args: str) -> None:
+        subprocess.run(["git", "-C", str(tmp_path), *args], check=True, 
env=env, capture_output=True)
+
+    _run("init", "-q", "-b", "main")
+    _run("config", "commit.gpgsign", "false")
+
+    def _commit(message: str) -> None:
+        _run("add", "-A")
+        _run("commit", "-q", "--allow-empty", "-m", message)
+
+    return _commit
+
+
+class TestHasProvideSessionDecorator:
+    def test_provide_session_name(self):
+        func = ast.parse("@provide_session\ndef foo(): pass").body[0]
+        assert _has_provide_session_decorator(func.decorator_list) is True
+
+    def test_provide_session_attribute(self):
+        func = ast.parse("@utils.provide_session\ndef foo(): pass").body[0]
+        assert _has_provide_session_decorator(func.decorator_list) is True
+
+    def test_no_decorator(self):
+        func = ast.parse("def foo(): pass").body[0]
+        assert _has_provide_session_decorator(func.decorator_list) is False
+
+    def test_unrelated_decorator(self):
+        func = ast.parse("@staticmethod\ndef foo(): pass").body[0]
+        assert _has_provide_session_decorator(func.decorator_list) is False
+
+    def test_multiple_decorators_including_provide_session(self):
+        func = ast.parse("@staticmethod\n@provide_session\ndef foo(): pass").body[0]
+        assert _has_provide_session_decorator(func.decorator_list) is True
+
+
+class TestSessionIsPositional:
+    def test_no_session_arg(self):
+        func = ast.parse("def foo(x, y): pass").body[0]
+        assert _session_is_positional(func.args) is None
+
+    def test_session_positional(self):
+        func = ast.parse("def foo(session=NEW_SESSION): pass").body[0]
+        argument = _session_is_positional(func.args)
+        assert argument is not None
+        assert argument.arg == "session"
+
+    def test_session_keyword_only(self):
+        func = ast.parse("def foo(*, session=NEW_SESSION): pass").body[0]
+        assert _session_is_positional(func.args) is None
+
+    def test_session_positional_among_other_args(self):
+        func = ast.parse("def foo(x, y, session=NEW_SESSION): pass").body[0]
+        argument = _session_is_positional(func.args)
+        assert argument is not None
+        assert argument.arg == "session"
+
+    def test_session_kwonly_after_other_positional(self):
+        func = ast.parse("def foo(x, y, *, session=NEW_SESSION): pass").body[0]
+        assert _session_is_positional(func.args) is None
+
+    def test_session_positional_only(self):
+        func = ast.parse("def foo(session, /, x): pass").body[0]
+        argument = _session_is_positional(func.args)
+        assert argument is not None
+        assert argument.arg == "session"
+
+
+class TestIterPositionalSessionInProvideSession:
+    def test_keyword_only_session_is_clean(self, find_violations):
+        code = """\
+        @provide_session
+        def foo(*, session=NEW_SESSION):
+            pass
+        """
+        assert find_violations(code) == []
+
+    def test_positional_session_is_flagged(self, find_violations):
+        code = """\
+        @provide_session
+        def foo(session=NEW_SESSION):
+            pass
+        """
+        violations = find_violations(code)
+        assert len(violations) == 1
+        func, argument = violations[0]
+        assert func.name == "foo"
+        assert argument.arg == "session"
+
+    def test_no_provide_session_decorator_is_ignored(self, find_violations):
+        code = """\
+        def foo(session=NEW_SESSION):
+            pass
+        """
+        assert find_violations(code) == []
+
+    def test_async_function_with_positional_session_is_flagged(self, 
find_violations):
+        code = """\
+        @provide_session
+        async def foo(session=NEW_SESSION):
+            pass
+        """
+        violations = find_violations(code)
+        assert len(violations) == 1
+
+    def test_method_with_positional_session_is_flagged(self, find_violations):
+        code = """\
+        class C:
+            @provide_session
+            def foo(self, session=NEW_SESSION):
+                pass
+        """
+        violations = find_violations(code)
+        assert len(violations) == 1
+        assert violations[0][0].name == "foo"
+
+    def test_attribute_decorator_is_recognised(self, find_violations):
+        code = """\
+        @airflow.utils.session.provide_session
+        def foo(session=NEW_SESSION):
+            pass
+        """
+        violations = find_violations(code)
+        assert len(violations) == 1
+
+    def test_count_violations_multiple_in_file(self, write_python_file):
+        code = """\
+        @provide_session
+        def a(session=NEW_SESSION):
+            pass
+
+        @provide_session
+        def b(x, session=NEW_SESSION):
+            pass
+
+        @provide_session
+        def c(*, session=NEW_SESSION):
+            pass
+        """
+        path = write_python_file(code)
+        assert _count_violations(path) == 2
+
+    def test_syntax_error_returns_no_violations(self, write_python_file):
+        path = write_python_file("def foo(:\n    pass")
+        assert _count_violations(path) == 0
+
+    def test_invalid_utf8_does_not_crash(self, tmp_path):
+        path = tmp_path / "invalid_utf8.py"
+        path.write_bytes(b"# bad byte: \xff\n@provide_session\ndef foo(session=NEW_SESSION):\n    pass\n")
+
+        assert _count_violations(path) == 1
+
+
+class TestAllowlistManager:
+    def test_load_missing_file_returns_empty(self, tmp_path):
+        manager = AllowlistManager(tmp_path / "missing.txt")
+        assert manager.load() == {}
+
+    def test_save_and_load_round_trip(self, tmp_path):
+        manager = AllowlistManager(tmp_path / "allowlist.txt")
+        manager.save({"b/file.py": 2, "a/file.py": 1})
+        # Sorted by key in the file
+        text = (tmp_path / "allowlist.txt").read_text()
+        assert text.splitlines() == ["a/file.py::1", "b/file.py::2"]
+        assert manager.load() == {"a/file.py": 1, "b/file.py": 2}
+
+    def test_load_skips_blank_and_malformed_lines(self, tmp_path):
+        path = tmp_path / "allowlist.txt"
+        path.write_text("\nvalid/file.py::3\nnocount\n::5\nbad::notanumber\n")
+        assert AllowlistManager(path).load() == {"valid/file.py": 3}
+
+    def test_load_skips_unsafe_entries(self, fake_repo, tmp_path):
+        """Entries that escape REPO_ROOT (absolute paths or `..` segments) are 
ignored."""
+        path = tmp_path / "allowlist.txt"
+        path.write_text("airflow-core/src/airflow/safe.py::1\n../escape.py::1\n/etc/passwd::1\n")
+        # `fake_repo` patches REPO_ROOT to tmp_path so the safety check is 
meaningful.
+        assert AllowlistManager(path).load() == 
{"airflow-core/src/airflow/safe.py": 1}
+
+
+class TestCheckProvideSessionKwargs:
+    def test_no_violations_in_clean_file(self, fake_repo, tmp_path):
+        path = fake_repo(
+            "airflow-core/src/airflow/clean.py",
+            """\
+            @provide_session
+            def foo(*, session=NEW_SESSION):
+                pass
+            """,
+        )
+        manager = AllowlistManager(tmp_path / "allowlist.txt")
+        assert _check_provide_session_kwargs([path], {}, manager) == 0
+
+    def test_new_violation_fails(self, fake_repo, tmp_path):
+        path = fake_repo(
+            "airflow-core/src/airflow/bad.py",
+            """\
+            @provide_session
+            def foo(session=NEW_SESSION):
+                pass
+            """,
+        )
+        manager = AllowlistManager(tmp_path / "allowlist.txt")
+        assert _check_provide_session_kwargs([path], {}, manager) == 1
+
+    def test_violation_within_allowlist_passes(self, fake_repo, tmp_path):
+        path = fake_repo(
+            "airflow-core/src/airflow/grandfathered.py",
+            """\
+            @provide_session
+            def foo(session=NEW_SESSION):
+                pass
+            """,
+        )
+        manager = AllowlistManager(tmp_path / "allowlist.txt")
+        allowlist = {"airflow-core/src/airflow/grandfathered.py": 1}
+        assert _check_provide_session_kwargs([path], allowlist, manager) == 0
+
+    def test_exceeding_allowlist_fails(self, fake_repo, tmp_path):
+        path = fake_repo(
+            "airflow-core/src/airflow/grew.py",
+            """\
+            @provide_session
+            def a(session=NEW_SESSION):
+                pass
+
+            @provide_session
+            def b(session=NEW_SESSION):
+                pass
+            """,
+        )
+        manager = AllowlistManager(tmp_path / "allowlist.txt")
+        allowlist = {"airflow-core/src/airflow/grew.py": 1}
+        assert _check_provide_session_kwargs([path], allowlist, manager) == 1
+
+    def test_reducing_violations_tightens_allowlist(self, fake_repo, tmp_path):
+        path = fake_repo(
+            "airflow-core/src/airflow/improved.py",
+            """\
+            @provide_session
+            def foo(session=NEW_SESSION):
+                pass
+
+            @provide_session
+            def bar(*, session=NEW_SESSION):
+                pass
+            """,
+        )
+        manager = AllowlistManager(tmp_path / "allowlist.txt")
+        allowlist = {"airflow-core/src/airflow/improved.py": 2}
+        # Exit non-zero so pre-commit reports the modified allowlist
+        assert _check_provide_session_kwargs([path], allowlist, manager) == 1
+        assert manager.load() == {"airflow-core/src/airflow/improved.py": 1}
+
+    def test_fixing_all_violations_removes_entry(self, fake_repo, tmp_path):
+        path = fake_repo(
+            "airflow-core/src/airflow/fixed.py",
+            """\
+            @provide_session
+            def foo(*, session=NEW_SESSION):
+                pass
+            """,
+        )
+        manager = AllowlistManager(tmp_path / "allowlist.txt")
+        allowlist = {"airflow-core/src/airflow/fixed.py": 1}
+        assert _check_provide_session_kwargs([path], allowlist, manager) == 1
+        assert manager.load() == {}
+
+    def test_non_python_file_is_skipped(self, fake_repo, tmp_path):
+        path = fake_repo(
+            "airflow-core/src/airflow/not_python.txt", "@provide_session\ndef 
foo(session=N): pass\n"
+        )
+        manager = AllowlistManager(tmp_path / "allowlist.txt")
+        assert _check_provide_session_kwargs([path], {}, manager) == 0
+
+    def test_missing_allowlist_file_fails_loudly(self, fake_repo, tmp_path):

Review Comment:
   ```suggestion
       @pytest.mark.usefixtures("fake_repo")
       def test_missing_allowlist_file_fails_loudly(self, tmp_path):
   ```



##########
scripts/tests/ci/prek/test_check_provide_session_kwargs.py:
##########
@@ -0,0 +1,477 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import ast
+import os
+import subprocess
+import textwrap
+from pathlib import Path
+
+import pytest
+from ci.prek import check_provide_session_kwargs as hook
+from ci.prek.check_provide_session_kwargs import (
+    AllowlistManager,
+    _check_provide_session_kwargs,
+    _count_violations,
+    _expand_for_allowlist_edits,
+    _has_provide_session_decorator,
+    _iter_positional_session_in_provide_session,
+    _previous_allowlist,
+    _session_is_positional,
+)
+
+
+@pytest.fixture
+def find_violations(write_python_file):
+    """Factory fixture: write code to a temp file and return positional-session violations."""
+
+    def _check(code: str) -> list[tuple[ast.FunctionDef | ast.AsyncFunctionDef, ast.arg]]:
+        path = write_python_file(code)
+        return list(_iter_positional_session_in_provide_session(path))
+
+    return _check
+
+
+@pytest.fixture
+def fake_repo(tmp_path, monkeypatch):
+    """Create a fake repo layout and patch REPO_ROOT so paths resolve correctly."""
+    monkeypatch.setattr(hook, "REPO_ROOT", tmp_path)
+
+    def _write(rel: str, code: str) -> Path:
+        path = tmp_path / rel
+        path.parent.mkdir(parents=True, exist_ok=True)
+        path.write_text(textwrap.dedent(code))
+        return path
+
+    return _write
+
+
+@pytest.fixture
+def git_repo(fake_repo, tmp_path):

Review Comment:
   ```suggestion
   def create_git_repo(fake_repo, tmp_path):
   ```
   



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to