This is an automated email from the ASF dual-hosted git repository.
jason810496 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 3276dded172 Add prek hook to enforce keyword-only `session` on
`@provide_session` (#67150)
3276dded172 is described below
commit 3276dded17244c83c5b01439441874d890286dfc
Author: Jason(Zhe-You) Liu <[email protected]>
AuthorDate: Thu May 28 09:32:57 2026 +0800
Add prek hook to enforce keyword-only `session` on `@provide_session`
(#67150)
---
.pre-commit-config.yaml | 6 +
scripts/ci/prek/check_provide_session_kwargs.py | 427 ++++++++++++++++++
.../ci/prek/known_provide_session_positional.txt | 89 ++++
.../ci/prek/test_check_provide_session_kwargs.py | 482 +++++++++++++++++++++
4 files changed, 1004 insertions(+)
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 5f8351b30e7..42ab2035ec0 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1071,6 +1071,12 @@ repos:
language: python
pass_filenames: true
files: ^(airflow-core|airflow-ctl|task-sdk|providers|shared)/.*\.py$
+ - id: check-no-new-provide-session-positional
+ name: Check that no new @provide_session functions declare `session`
positionally
+ entry: ./scripts/ci/prek/check_provide_session_kwargs.py
+ language: python
+ pass_filenames: true
+ files:
^(airflow-core|providers)/.*\.py$|^scripts/ci/prek/known_provide_session_positional\.txt$|^scripts/ci/prek/check_provide_session_kwargs\.py$
- id: check-no-new-airflow-core-utils-modules
name: Check that no new modules are added under
airflow-core/src/airflow/utils
entry: ./scripts/ci/prek/check_no_new_airflow_core_utils_modules.py
diff --git a/scripts/ci/prek/check_provide_session_kwargs.py
b/scripts/ci/prek/check_provide_session_kwargs.py
new file mode 100755
index 00000000000..29152e13a4d
--- /dev/null
+++ b/scripts/ci/prek/check_provide_session_kwargs.py
@@ -0,0 +1,427 @@
+#!/usr/bin/env python
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# /// script
+# requires-python = ">=3.10"
+# dependencies = [
+# "rich>=13.0.0",
+# ]
+# ///
+"""Check that no new ``@provide_session`` functions declare ``session``
positionally.
+
+The project convention is that any function decorated with ``@provide_session``
+must declare ``session`` as keyword-only (after a bare ``*`` in the signature),
+so callers cannot pass it positionally by accident. See
+``contributing-docs/05_pull_requests.rst#database-session-handling``.
+
+All *existing* offenders are recorded in
``known_provide_session_positional.txt``
+next to this script as ``relative/path::N`` entries (one per file), where ``N``
+is the maximum number of ``@provide_session`` functions with a positional
+``session`` argument allowed in that file. A file whose current count exceeds
+the recorded limit is treated as a violation – move the ``session`` argument
+behind a bare ``*`` instead.
+
+Modes
+-----
+Default (files passed by prek/pre-commit):
+ Check only the supplied files; fail if any file's count exceeds the limit.
+ When a file's count has *decreased*, the allowlist entry is tightened
+ automatically and the hook exits with a non-zero code so that pre-commit
+ reports the modified allowlist – just stage
+ ``scripts/ci/prek/known_provide_session_positional.txt`` and re-run.
+
+``--all-files``:
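+ Walk every ``.py`` file under the project source roots
+ (``airflow-core``, ``providers``, ``shared``). The hook's ``files:``
+ pattern in ``.pre-commit-config.yaml`` lists only ``airflow-core`` and
+ ``providers``, so ``shared`` is covered by this mode only.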
+
+``--cleanup``:
+ Remove entries for files that no longer exist. Safe to run at any time;
+ does not add new entries or raise limits.
+
+``--generate``:
+ Scan the same project source roots as ``--all-files`` and *rebuild* the
+ allowlist from scratch. Intended for the initial setup or after a
+ large-scale clean-up sprint.
+"""
+
+from __future__ import annotations
+
+import argparse
+import ast
+import subprocess
+import typing
+from pathlib import Path
+
+from rich.console import Console
+from rich.panel import Panel
+
+# Shared rich console used for every message this hook prints.
+console = Console(color_system="standard", width=200)
+
+# Repository root: the directory three levels above ``scripts/ci/prek``'s script file parent chain
+# (``parents[3]`` of ``scripts/ci/prek/check_provide_session_kwargs.py``).
+REPO_ROOT = Path(__file__).parents[3]
+
+# Decorator name looked for, either as a bare name or as the last attribute of a dotted name.
+_PROVIDE_SESSION_DECORATOR = "provide_session"
+
+# Top-level directories scanned by ``--all-files`` / ``--generate``. Keep in sync with the
+# ``files:`` pattern for this hook in ``.pre-commit-config.yaml``.
+# NOTE(review): that pattern currently lists only ``airflow-core`` and ``providers``;
+# ``shared`` is scanned here but never passed in by the hook -- confirm which is intended.
+_PROJECT_SOURCE_ROOTS = ("airflow-core", "providers", "shared")
+
+
+def _has_provide_session_decorator(nodes: list[ast.expr]) -> bool:
+ """Whether one of ``nodes`` is a ``@provide_session`` decorator.
+
+ Accepts both bare names (``@provide_session``) and attribute access
+ (``@something.provide_session``).
+ """
+ for node in nodes:
+ if isinstance(node, ast.Name) and node.id ==
_PROVIDE_SESSION_DECORATOR:
+ return True
+ if isinstance(node, ast.Attribute) and node.attr ==
_PROVIDE_SESSION_DECORATOR:
+ return True
+ return False
+
+
+def _session_is_positional(args: ast.arguments) -> ast.arg | None:
+ """Return the ``session`` arg if it is positional (not keyword-only).
+
+ Covers both regular positional args and positional-only args (``def
f(session, /, ...)``).
+ """
+ for argument in (*args.posonlyargs, *args.args):
+ if argument.arg == "session":
+ return argument
+ return None
+
+
+def _iter_positional_session_in_provide_session(
+ path: Path,
+) -> typing.Iterator[tuple[ast.FunctionDef | ast.AsyncFunctionDef, ast.arg]]:
+ """Yield ``@provide_session`` functions in *path* whose ``session`` is
positional."""
+ try:
+ source = path.read_text(encoding="utf-8", errors="replace")
+ except OSError:
+ return
+ try:
+ tree = ast.parse(source, str(path))
+ except SyntaxError:
+ return
+ for node in ast.walk(tree):
+ if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
+ continue
+ if not _has_provide_session_decorator(node.decorator_list):
+ continue
+ argument = _session_is_positional(node.args)
+ if argument is None:
+ continue
+ yield node, argument
+
+
+def _count_violations(path: Path) -> int:
+ return sum(1 for _ in _iter_positional_session_in_provide_session(path))
+
+
+def _is_safe_relative(rel: str) -> bool:
+ """Whether ``rel`` is a plain relative path that stays inside
``REPO_ROOT``.
+
+ Rejects absolute paths and any entry that resolves outside the repo root so
+ callers can ``relative_to(REPO_ROOT)`` without fear of a ``ValueError``.
+ """
+ candidate = Path(rel)
+ if candidate.is_absolute():
+ return False
+ try:
+ (REPO_ROOT / candidate).resolve().relative_to(REPO_ROOT.resolve())
+ except ValueError:
+ return False
+ return True
+
+
+class AllowlistManager:
+ def __init__(self, allowlist_file: Path) -> None:
+ self.allowlist_file = allowlist_file
+
+ @staticmethod
+ def parse(text: str) -> dict[str, int]:
+ """Parse allowlist *text* into a ``{rel_path: count}`` mapping.
+
+ Same validation rules as :meth:`load` so we can reuse parsing for the
+ on-disk allowlist *and* for the git-tracked version fetched from
+ ``HEAD`` when guarding against entry-removal bypasses.
+ """
+ result: dict[str, int] = {}
+ for raw_line in text.splitlines():
+ if not (stripped := raw_line.strip()):
+ continue
+
+ rel_str, _, count_str = stripped.rpartition("::")
+ if not rel_str or not count_str:
+ continue
+
+ try:
+ count = int(count_str)
+ except ValueError:
+ continue
+
+ if not _is_safe_relative(rel_str):
+ console.print(
+ f"[yellow]Ignoring unsafe allowlist entry (escapes repo
root):[/yellow] {rel_str}"
+ )
+ continue
+
+ result[rel_str] = count
+
+ return result
+
+ def load(self) -> dict[str, int]:
+ if not self.allowlist_file.exists():
+ return {}
+ return self.parse(self.allowlist_file.read_text())
+
+ def save(self, counts: dict[str, int]) -> None:
+ lines = [f"{rel}::{count}" for rel, count in sorted(counts.items())]
+ self.allowlist_file.write_text("\n".join(lines) + "\n")
+
+ def generate(self) -> int:
+ roots = ", ".join(_PROJECT_SOURCE_ROOTS)
+ console.print(
+ f"Scanning project source roots ([cyan]{roots}[/cyan]) under
[cyan]{REPO_ROOT}[/cyan] "
+ "for @provide_session functions with positional session …"
+ )
+ counts: dict[str, int] = {}
+ for path in _iter_python_files():
+ n = _count_violations(path)
+ if n > 0:
+ counts[str(path.relative_to(REPO_ROOT))] = n
+
+ self.save(counts)
+ total = sum(counts.values())
+ console.print(
+ f"[green]Generated[/green]
[cyan]{self.allowlist_file.relative_to(REPO_ROOT)}[/cyan] "
+ f"with [bold]{len(counts)}[/bold] files / [bold]{total}[/bold]
offenders."
+ )
+ return 0
+
+ def cleanup(self) -> int:
+ allowlist = self.load()
+ if not allowlist:
+ console.print("[yellow]Allowlist is empty - nothing to clean
up.[/yellow]")
+ return 0
+
+ stale: list[str] = [rel for rel in allowlist if not (REPO_ROOT /
rel).exists()]
+ if stale:
+ console.print(
+ f"[yellow]Removing {len(stale)} stale entr{'y' if len(stale)
== 1 else 'ies'}:[/yellow]"
+ )
+ for s in sorted(stale):
+ console.print(f" [dim]-[/dim] {s}")
+ for s in stale:
+ del allowlist[s]
+ self.save(allowlist)
+ console.print(
+ f"\n[green]Updated[/green]
[cyan]{self.allowlist_file.relative_to(REPO_ROOT)}[/cyan]"
+ )
+ else:
+ console.print("[green]No stale entries found.[/green]")
+ return 0
+
+
+def _iter_python_files() -> list[Path]:
+ candidates: list[Path] = []
+ for top in _PROJECT_SOURCE_ROOTS:
+ candidates.extend(
+ p.resolve()
+ for p in (REPO_ROOT / top).rglob("*.py")
+ if ".tox" not in p.parts and "__pycache__" not in p.parts
+ )
+ return candidates
+
+
+def _check_provide_session_kwargs(
+ files: list[Path], allowlist: dict[str, int], manager: AllowlistManager
+) -> int:
+ allowlist_file = manager.allowlist_file.resolve()
+ if any(p.resolve() == allowlist_file for p in files) and not
allowlist_file.exists():
+ console.print(
+ Panel.fit(
+ f"Allowlist file [cyan]{allowlist_file}[/cyan] is missing.\n"
+ "It was passed to the hook but cannot be read, so the check
cannot proceed.\n"
+ "Restore it from git or regenerate it with:\n\n"
+ " [cyan]uv run
./scripts/ci/prek/check_provide_session_kwargs.py --generate[/cyan]",
+ title="[red]Check failed[/red]",
+ border_style="red",
+ )
+ )
+ return 1
+
+ violations: list[tuple[Path, int, int]] = []
+ tightened: list[tuple[str, int, int]] = []
+
+ for path in files:
+ if not path.exists() or path.suffix != ".py":
+ continue
+ actual = _count_violations(path)
+ rel = str(path.relative_to(REPO_ROOT))
+ allowed = allowlist.get(rel, 0)
+ if actual > allowed:
+ violations.append((path, actual, allowed))
+ elif actual < allowed:
+ if actual == 0:
+ del allowlist[rel]
+ else:
+ allowlist[rel] = actual
+ tightened.append((rel, allowed, actual))
+
+ if tightened:
+ manager.save(allowlist)
+ console.print(
+ f"[green]Tightened {len(tightened)} entr{'y' if len(tightened) ==
1 else 'ies'} "
+ f"in
[cyan]{manager.allowlist_file.relative_to(REPO_ROOT)}[/cyan][/green] "
+ "(stage the updated file):"
+ )
+ for rel, old, new in tightened:
+ console.print(f" [cyan]{rel}[/cyan] {old} -> {new}")
+
+ if violations:
+ console.print(
+ Panel.fit(
+ "New [bold]@provide_session[/bold] function with positional
``session`` detected.\n"
+ "Move ``session`` after a bare ``*`` in the signature so
callers must pass it by keyword:\n\n"
+ " [cyan]@provide_session\n"
+ " def foo(arg, *, session: Session = NEW_SESSION) -> None:
...[/cyan]\n\n"
+ "If this usage is intentional and pre-existing, run:\n\n"
+ " [cyan]uv run
./scripts/ci/prek/check_provide_session_kwargs.py --generate[/cyan]\n\n"
+ "to regenerate the allowlist, then commit the updated\n"
+
"[cyan]scripts/ci/prek/known_provide_session_positional.txt[/cyan].",
+ title="[red]Check failed[/red]",
+ border_style="red",
+ )
+ )
+ for path, actual, allowed in violations:
+ console.print(f" [cyan]{path.relative_to(REPO_ROOT)}[/cyan]
count={actual} (allowed={allowed})")
+ for func, argument in
_iter_positional_session_in_provide_session(path):
+ console.print(f" [dim]L{argument.lineno}[/dim] def
{func.name}(...)")
+ return 1
+
+ return 1 if tightened else 0
+
+
+def main(argv: list[str] | None = None) -> int:
+ parser = argparse.ArgumentParser(
+ description="Prevent new @provide_session functions from declaring
`session` positionally.",
+ formatter_class=argparse.RawDescriptionHelpFormatter,
+ epilog=__doc__,
+ )
+ parser.add_argument("files", nargs="*", metavar="FILE", help="Files to
check (provided by prek)")
+ parser.add_argument(
+ "--all-files",
+ action="store_true",
+ help=("Check every Python file under the project source roots
(airflow-core, providers, shared)"),
+ )
+ parser.add_argument(
+ "--cleanup",
+ action="store_true",
+ help="Remove stale entries from the allowlist and exit",
+ )
+ parser.add_argument(
+ "--generate",
+ action="store_true",
+ help="Regenerate the allowlist from the current codebase and exit",
+ )
+ args = parser.parse_args(argv)
+
+ manager = AllowlistManager(Path(__file__).parent /
"known_provide_session_positional.txt")
+
+ if args.generate:
+ return manager.generate()
+
+ if args.cleanup:
+ return manager.cleanup()
+
+ allowlist = manager.load()
+
+ if args.all_files:
+ return _check_provide_session_kwargs(_iter_python_files(), allowlist,
manager)
+
+ if not args.files:
+ console.print(
+ "[yellow]No files provided. Pass filenames or use --all-files to
scan the whole repo.[/yellow]"
+ )
+ return 0
+
+ paths = [Path(f).resolve() for f in args.files]
+ paths = _expand_for_allowlist_edits(paths, manager, allowlist)
+ return _check_provide_session_kwargs(paths, allowlist, manager)
+
+
+def _parse_tracked_allowlist(manager: AllowlistManager) -> dict[str, int]:
+ """Return the allowlist as recorded at ``HEAD`` (the git-tracked version).
+
+ Used by :func:`_expand_for_allowlist_edits` so that *removing* an entry
+ cannot silently drop coverage: the previously-listed file is still
+ re-validated against the new (post-edit) allowlist. Returns an empty
mapping
+ when git is unavailable, the file does not yet exist at ``HEAD``, or the
+ allowlist sits outside ``REPO_ROOT``.
+ """
+ try:
+ rel = manager.allowlist_file.resolve().relative_to(REPO_ROOT.resolve())
+ except ValueError:
+ return {}
+ try:
+ completed = subprocess.run(
+ ["git", "-C", str(REPO_ROOT), "show", f"HEAD:{rel.as_posix()}"],
+ capture_output=True,
+ text=True,
+ check=False,
+ )
+ except (FileNotFoundError, OSError):
+ return {}
+ if completed.returncode != 0:
+ return {}
+ return AllowlistManager.parse(completed.stdout)
+
+
+def _expand_for_allowlist_edits(
+ paths: list[Path], manager: AllowlistManager, allowlist: dict[str, int]
+) -> list[Path]:
+ """Add allowlisted files when the allowlist itself is being changed.
+
+ Without this, a contributor could raise counts in
+ ``known_provide_session_positional.txt`` and the hook would do no
validation
+ (since only the ``.txt`` file is passed), letting the loosened allowlist
+ sail through. We also union the git-tracked allowlist (from ``HEAD``) so
+ that removing an entry cannot silently bypass the check for a file that
+ still has positional ``session`` arguments.
+
+ Both sides of the allowlist-file comparison are resolved so the detection
is
+ robust to symlinks and unresolved inputs (the hook can be invoked with
either).
+ """
+ allowlist_file = manager.allowlist_file.resolve()
+ if not any(p.resolve() == allowlist_file for p in paths):
+ return paths
+
+ expanded = list(paths)
+ seen = {p.resolve() for p in paths if p.suffix == ".py"}
+ tracked = _parse_tracked_allowlist(manager)
+ for rel in {*allowlist, *tracked}:
+ candidate = (REPO_ROOT / rel).resolve()
+ if candidate.exists() and candidate not in seen:
+ seen.add(candidate)
+ expanded.append(candidate)
+ return expanded
+
+
+if __name__ == "__main__":
+ raise SystemExit(main())
diff --git a/scripts/ci/prek/known_provide_session_positional.txt
b/scripts/ci/prek/known_provide_session_positional.txt
new file mode 100644
index 00000000000..d0c84e2f6b4
--- /dev/null
+++ b/scripts/ci/prek/known_provide_session_positional.txt
@@ -0,0 +1,89 @@
+airflow-core/src/airflow/api/common/delete_dag.py::1
+airflow-core/src/airflow/api/common/mark_tasks.py::1
+airflow-core/src/airflow/callbacks/database_callback_sink.py::1
+airflow-core/src/airflow/cli/commands/dag_command.py::8
+airflow-core/src/airflow/cli/commands/jobs_command.py::1
+airflow-core/src/airflow/cli/commands/task_command.py::1
+airflow-core/src/airflow/cli/commands/team_command.py::4
+airflow-core/src/airflow/cli/commands/variable_command.py::1
+airflow-core/src/airflow/dag_processing/dagbag.py::1
+airflow-core/src/airflow/dag_processing/manager.py::4
+airflow-core/src/airflow/jobs/base_job_runner.py::2
+airflow-core/src/airflow/jobs/job.py::7
+airflow-core/src/airflow/jobs/scheduler_job_runner.py::11
+airflow-core/src/airflow/jobs/triggerer_job_runner.py::1
+airflow-core/src/airflow/models/connection.py::2
+airflow-core/src/airflow/models/dag.py::7
+airflow-core/src/airflow/models/dagcode.py::6
+airflow-core/src/airflow/models/dagrun.py::15
+airflow-core/src/airflow/models/dagwarning.py::1
+airflow-core/src/airflow/models/deadline.py::1
+airflow-core/src/airflow/models/deadline_alert.py::1
+airflow-core/src/airflow/models/pool.py::11
+airflow-core/src/airflow/models/renderedtifields.py::4
+airflow-core/src/airflow/models/revoked_token.py::2
+airflow-core/src/airflow/models/serialized_dag.py::6
+airflow-core/src/airflow/models/taskinstance.py::21
+airflow-core/src/airflow/models/taskinstancehistory.py::2
+airflow-core/src/airflow/models/team.py::1
+airflow-core/src/airflow/models/trigger.py::7
+airflow-core/src/airflow/models/variable.py::2
+airflow-core/src/airflow/secrets/metastore.py::2
+airflow-core/src/airflow/serialization/definitions/dag.py::2
+airflow-core/src/airflow/ti_deps/deps/base_ti_dep.py::2
+airflow-core/src/airflow/ti_deps/deps/dag_ti_slots_available_dep.py::1
+airflow-core/src/airflow/ti_deps/deps/dag_unpaused_dep.py::1
+airflow-core/src/airflow/ti_deps/deps/dagrun_exists_dep.py::1
+airflow-core/src/airflow/ti_deps/deps/exec_date_after_start_date_dep.py::1
+airflow-core/src/airflow/ti_deps/deps/not_in_retry_period_dep.py::1
+airflow-core/src/airflow/ti_deps/deps/pool_slots_available_dep.py::1
+airflow-core/src/airflow/ti_deps/deps/prev_dagrun_dep.py::1
+airflow-core/src/airflow/ti_deps/deps/ready_to_reschedule.py::1
+airflow-core/src/airflow/ti_deps/deps/runnable_exec_date_dep.py::1
+airflow-core/src/airflow/ti_deps/deps/task_concurrency_dep.py::1
+airflow-core/src/airflow/ti_deps/deps/task_not_running_dep.py::1
+airflow-core/src/airflow/ti_deps/deps/valid_state_dep.py::1
+airflow-core/src/airflow/utils/cli_action_loggers.py::1
+airflow-core/src/airflow/utils/db.py::7
+airflow-core/src/airflow/utils/db_cleanup.py::2
+airflow-core/src/airflow/utils/log/file_task_handler.py::1
+airflow-core/tests/unit/api_fastapi/common/test_exceptions.py::4
+airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_assets.py::19
+airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_connections.py::2
+airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_dag_run.py::1
+airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_dag_tags.py::1
+airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_dag_warning.py::1
+airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_event_logs.py::1
+airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_import_error.py::8
+airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_job.py::1
+airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_monitor.py::2
+airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_pools.py::2
+airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_variables.py::2
+airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_xcom.py::3
+airflow-core/tests/unit/api_fastapi/core_api/routes/ui/test_calendar.py::2
+airflow-core/tests/unit/api_fastapi/core_api/routes/ui/test_dags.py::1
+airflow-core/tests/unit/api_fastapi/core_api/routes/ui/test_gantt.py::1
+airflow-core/tests/unit/api_fastapi/core_api/routes/ui/test_grid.py::1
+airflow-core/tests/unit/cli/commands/test_rotate_fernet_key_command.py::2
+airflow-core/tests/unit/jobs/test_scheduler_job.py::1
+airflow-core/tests/unit/listeners/test_listeners.py::7
+airflow-core/tests/unit/models/test_taskinstance.py::4
+airflow-core/tests/unit/models/test_timestamp.py::2
+providers/amazon/src/airflow/providers/amazon/aws/triggers/emr.py::1
+providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/kubernetes_executor.py::1
+providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/template_rendering.py::1
+providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/triggers/pod.py::1
+providers/common/ai/tests/unit/common/ai/plugins/test_hitl_review.py::1
+providers/databricks/src/airflow/providers/databricks/plugins/databricks_workflow.py::2
+providers/edge3/src/airflow/providers/edge3/executors/edge_executor.py::3
+providers/edge3/src/airflow/providers/edge3/models/edge_worker.py::10
+providers/edge3/src/airflow/providers/edge3/plugins/edge_executor_plugin.py::1
+providers/edge3/src/airflow/providers/edge3/worker_api/routes/logs.py::1
+providers/fab/src/airflow/providers/fab/auth_manager/cli_commands/permissions_command.py::1
+providers/google/src/airflow/providers/google/cloud/triggers/bigquery.py::1
+providers/google/src/airflow/providers/google/cloud/triggers/dataproc.py::3
+providers/openlineage/src/airflow/providers/openlineage/utils/utils.py::1
+providers/standard/src/airflow/providers/standard/operators/trigger_dagrun.py::1
+providers/standard/src/airflow/providers/standard/sensors/external_task.py::1
+providers/standard/src/airflow/providers/standard/utils/sensor_helper.py::1
+providers/standard/tests/unit/standard/sensors/test_external_task_sensor.py::3
diff --git a/scripts/tests/ci/prek/test_check_provide_session_kwargs.py
b/scripts/tests/ci/prek/test_check_provide_session_kwargs.py
new file mode 100644
index 00000000000..78b85cd270b
--- /dev/null
+++ b/scripts/tests/ci/prek/test_check_provide_session_kwargs.py
@@ -0,0 +1,482 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import ast
+import os
+import subprocess
+import textwrap
+from pathlib import Path
+
+import pytest
+from ci.prek import check_provide_session_kwargs as hook
+from ci.prek.check_provide_session_kwargs import (
+ AllowlistManager,
+ _check_provide_session_kwargs,
+ _count_violations,
+ _expand_for_allowlist_edits,
+ _has_provide_session_decorator,
+ _iter_positional_session_in_provide_session,
+ _parse_tracked_allowlist,
+ _session_is_positional,
+)
+
+
[email protected]
+def find_violations(write_python_file):
+ """Factory fixture: write code to a temp file and return
positional-session violations."""
+
+ def _check(code: str) -> list[tuple[ast.FunctionDef |
ast.AsyncFunctionDef, ast.arg]]:
+ path = write_python_file(code)
+ return list(_iter_positional_session_in_provide_session(path))
+
+ return _check
+
+
[email protected]
+def create_fake_repo(tmp_path, monkeypatch):
+ """Create a fake repo layout and patch REPO_ROOT so paths resolve
correctly."""
+ monkeypatch.setattr(hook, "REPO_ROOT", tmp_path)
+
+ def _write(rel: str, code: str) -> Path:
+ path = tmp_path / rel
+ path.parent.mkdir(parents=True, exist_ok=True)
+ path.write_text(textwrap.dedent(code))
+ return path
+
+ return _write
+
+
[email protected]
+def create_git_repo(create_fake_repo, tmp_path):
+ """Initialise ``tmp_path`` as a git repo so ``git show HEAD:<file>`` works.
+
+ Returns a helper that commits the current working-tree contents under a
given
+ message, so tests can stage a "previous" allowlist at HEAD before mutating
it.
+ """
+ env = {
+ **os.environ,
+ "GIT_AUTHOR_NAME": "t",
+ "GIT_AUTHOR_EMAIL": "t@t",
+ "GIT_COMMITTER_NAME": "t",
+ "GIT_COMMITTER_EMAIL": "t@t",
+ }
+
+ def _run(*args: str) -> None:
+ subprocess.run(["git", "-C", str(tmp_path), *args], check=True,
env=env, capture_output=True)
+
+ _run("init", "-q", "-b", "main")
+ _run("config", "commit.gpgsign", "false")
+
+ def _commit(message: str) -> None:
+ _run("add", "-A")
+ _run("commit", "-q", "--allow-empty", "-m", message)
+
+ return _commit
+
+
+class TestHasProvideSessionDecorator:
+ def test_provide_session_name(self):
+ func = ast.parse("@provide_session\ndef foo(): pass").body[0]
+ assert _has_provide_session_decorator(func.decorator_list) is True
+
+ def test_provide_session_attribute(self):
+ func = ast.parse("@utils.provide_session\ndef foo(): pass").body[0]
+ assert _has_provide_session_decorator(func.decorator_list) is True
+
+ def test_no_decorator(self):
+ func = ast.parse("def foo(): pass").body[0]
+ assert _has_provide_session_decorator(func.decorator_list) is False
+
+ def test_unrelated_decorator(self):
+ func = ast.parse("@staticmethod\ndef foo(): pass").body[0]
+ assert _has_provide_session_decorator(func.decorator_list) is False
+
+ def test_multiple_decorators_including_provide_session(self):
+ func = ast.parse("@staticmethod\n@provide_session\ndef foo():
pass").body[0]
+ assert _has_provide_session_decorator(func.decorator_list) is True
+
+
+class TestSessionIsPositional:
+ def test_no_session_arg(self):
+ func = ast.parse("def foo(x, y): pass").body[0]
+ assert _session_is_positional(func.args) is None
+
+ def test_session_positional(self):
+ func = ast.parse("def foo(session=NEW_SESSION): pass").body[0]
+ argument = _session_is_positional(func.args)
+ assert argument is not None
+ assert argument.arg == "session"
+
+ def test_session_keyword_only(self):
+ func = ast.parse("def foo(*, session=NEW_SESSION): pass").body[0]
+ assert _session_is_positional(func.args) is None
+
+ def test_session_positional_among_other_args(self):
+ func = ast.parse("def foo(x, y, session=NEW_SESSION): pass").body[0]
+ argument = _session_is_positional(func.args)
+ assert argument is not None
+ assert argument.arg == "session"
+
+ def test_session_kwonly_after_other_positional(self):
+ func = ast.parse("def foo(x, y, *, session=NEW_SESSION): pass").body[0]
+ assert _session_is_positional(func.args) is None
+
+ def test_session_positional_only(self):
+ func = ast.parse("def foo(session, /, x): pass").body[0]
+ argument = _session_is_positional(func.args)
+ assert argument is not None
+ assert argument.arg == "session"
+
+
+# Tests for ``_iter_positional_session_in_provide_session`` and ``_count_violations``.
+# Most tests go through the ``find_violations`` fixture, which writes the snippet to a
+# temporary file and returns the violations found in it.
+class TestIterPositionalSessionInProvideSession:
+ # A keyword-only ``session`` yields no violation.
+ def test_keyword_only_session_is_clean(self, find_violations):
+ code = """\
+ @provide_session
+ def foo(*, session=NEW_SESSION):
+ pass
+ """
+ assert find_violations(code) == []
+
+ # A positional ``session`` is reported along with its function and argument nodes.
+ def test_positional_session_is_flagged(self, find_violations):
+ code = """\
+ @provide_session
+ def foo(session=NEW_SESSION):
+ pass
+ """
+ violations = find_violations(code)
+ assert len(violations) == 1
+ func, argument = violations[0]
+ assert func.name == "foo"
+ assert argument.arg == "session"
+
+ # Without the decorator, a positional ``session`` is not reported.
+ def test_no_provide_session_decorator_is_ignored(self, find_violations):
+ code = """\
+ def foo(session=NEW_SESSION):
+ pass
+ """
+ assert find_violations(code) == []
+
+ # ``async def`` functions are checked as well.
+ def test_async_function_with_positional_session_is_flagged(self,
find_violations):
+ code = """\
+ @provide_session
+ async def foo(session=NEW_SESSION):
+ pass
+ """
+ violations = find_violations(code)
+ assert len(violations) == 1
+
+ # Methods nested in a class body are found too.
+ def test_method_with_positional_session_is_flagged(self, find_violations):
+ code = """\
+ class C:
+ @provide_session
+ def foo(self, session=NEW_SESSION):
+ pass
+ """
+ violations = find_violations(code)
+ assert len(violations) == 1
+ assert violations[0][0].name == "foo"
+
+ # A dotted decorator ending in ``provide_session`` counts as the decorator.
+ def test_attribute_decorator_is_recognised(self, find_violations):
+ code = """\
+ @airflow.utils.session.provide_session
+ def foo(session=NEW_SESSION):
+ pass
+ """
+ violations = find_violations(code)
+ assert len(violations) == 1
+
+ # Two of the three decorated functions below declare ``session`` positionally.
+ def test_count_violations_multiple_in_file(self, write_python_file):
+ code = """\
+ @provide_session
+ def a(session=NEW_SESSION):
+ pass
+
+ @provide_session
+ def b(x, session=NEW_SESSION):
+ pass
+
+ @provide_session
+ def c(*, session=NEW_SESSION):
+ pass
+ """
+ path = write_python_file(code)
+ assert _count_violations(path) == 2
+
+ # An unparsable file counts as zero violations instead of raising.
+ def test_syntax_error_returns_no_violations(self, write_python_file):
+ path = write_python_file("def foo(:\n pass")
+ assert _count_violations(path) == 0
+
+ # Invalid UTF-8 bytes do not stop the file from being scanned.
+ def test_invalid_utf8_does_not_crash(self, tmp_path):
+ path = tmp_path / "invalid_utf8.py"
+ path.write_bytes(b"# bad byte: \xff\n@provide_session\ndef
foo(session=NEW_SESSION):\n pass\n")
+
+ assert _count_violations(path) == 1
+
+
+class TestAllowlistManager:
+ def test_load_missing_file_returns_empty(self, tmp_path):
+ manager = AllowlistManager(tmp_path / "missing.txt")
+ assert manager.load() == {}
+
+ def test_save_and_load_round_trip(self, tmp_path):
+ manager = AllowlistManager(tmp_path / "allowlist.txt")
+ manager.save({"b/file.py": 2, "a/file.py": 1})
+ # Sorted by key in the file
+ text = (tmp_path / "allowlist.txt").read_text()
+ assert text.splitlines() == ["a/file.py::1", "b/file.py::2"]
+ assert manager.load() == {"a/file.py": 1, "b/file.py": 2}
+
+ def test_load_skips_blank_and_malformed_lines(self, tmp_path):
+ path = tmp_path / "allowlist.txt"
+ path.write_text("\nvalid/file.py::3\nnocount\n::5\nbad::notanumber\n")
+ assert AllowlistManager(path).load() == {"valid/file.py": 3}
+
+ @pytest.mark.usefixtures("create_fake_repo")
+ def test_load_skips_unsafe_entries(self, tmp_path):
+ """Entries that escape REPO_ROOT (absolute paths or `..` segments) are
ignored."""
+ path = tmp_path / "allowlist.txt"
+
path.write_text("airflow-core/src/airflow/safe.py::1\n../escape.py::1\n/etc/passwd::1\n")
+ # `create_fake_repo` patches REPO_ROOT to tmp_path so the safety check
is meaningful.
+ assert AllowlistManager(path).load() ==
{"airflow-core/src/airflow/safe.py": 1}
+
+
+class TestCheckProvideSessionKwargs:
+    """Return codes and allowlist rewrites of ``_check_provide_session_kwargs``."""
+
+    def test_no_violations_in_clean_file(self, create_fake_repo, tmp_path):
+        """A keyword-only ``session`` with an empty allowlist returns 0."""
+        path = create_fake_repo(
+            "airflow-core/src/airflow/clean.py",
+            """\
+            @provide_session
+            def foo(*, session=NEW_SESSION):
+                pass
+            """,
+        )
+        manager = AllowlistManager(tmp_path / "allowlist.txt")
+        assert _check_provide_session_kwargs([path], {}, manager) == 0
+
+    def test_new_violation_fails(self, create_fake_repo, tmp_path):
+        """A positional ``session`` in a file absent from the allowlist returns 1."""
+        path = create_fake_repo(
+            "airflow-core/src/airflow/bad.py",
+            """\
+            @provide_session
+            def foo(session=NEW_SESSION):
+                pass
+            """,
+        )
+        manager = AllowlistManager(tmp_path / "allowlist.txt")
+        assert _check_provide_session_kwargs([path], {}, manager) == 1
+
+    def test_violation_within_allowlist_passes(self, create_fake_repo, tmp_path):
+        """One positional ``session`` is accepted when the allowlist permits one for the file."""
+        path = create_fake_repo(
+            "airflow-core/src/airflow/grandfathered.py",
+            """\
+            @provide_session
+            def foo(session=NEW_SESSION):
+                pass
+            """,
+        )
+        manager = AllowlistManager(tmp_path / "allowlist.txt")
+        allowlist = {"airflow-core/src/airflow/grandfathered.py": 1}
+        assert _check_provide_session_kwargs([path], allowlist, manager) == 0
+
+    def test_exceeding_allowlist_fails(self, create_fake_repo, tmp_path):
+        """Two positional sessions against an allowance of one return 1."""
+        path = create_fake_repo(
+            "airflow-core/src/airflow/grew.py",
+            """\
+            @provide_session
+            def a(session=NEW_SESSION):
+                pass
+
+            @provide_session
+            def b(session=NEW_SESSION):
+                pass
+            """,
+        )
+        manager = AllowlistManager(tmp_path / "allowlist.txt")
+        allowlist = {"airflow-core/src/airflow/grew.py": 1}
+        assert _check_provide_session_kwargs([path], allowlist, manager) == 1
+
+    def test_reducing_violations_tightens_allowlist(self, create_fake_repo, tmp_path):
+        """Fewer violations than allowed lowers the stored count and returns 1."""
+        path = create_fake_repo(
+            "airflow-core/src/airflow/improved.py",
+            """\
+            @provide_session
+            def foo(session=NEW_SESSION):
+                pass
+
+            @provide_session
+            def bar(*, session=NEW_SESSION):
+                pass
+            """,
+        )
+        manager = AllowlistManager(tmp_path / "allowlist.txt")
+        allowlist = {"airflow-core/src/airflow/improved.py": 2}
+        # Exit non-zero so pre-commit reports the modified allowlist
+        assert _check_provide_session_kwargs([path], allowlist, manager) == 1
+        assert manager.load() == {"airflow-core/src/airflow/improved.py": 1}
+
+    def test_fixing_all_violations_removes_entry(self, create_fake_repo, tmp_path):
+        """A fully fixed file loses its allowlist entry and the check returns 1."""
+        path = create_fake_repo(
+            "airflow-core/src/airflow/fixed.py",
+            """\
+            @provide_session
+            def foo(*, session=NEW_SESSION):
+                pass
+            """,
+        )
+        manager = AllowlistManager(tmp_path / "allowlist.txt")
+        allowlist = {"airflow-core/src/airflow/fixed.py": 1}
+        assert _check_provide_session_kwargs([path], allowlist, manager) == 1
+        assert manager.load() == {}
+
+    def test_non_python_file_is_skipped(self, create_fake_repo, tmp_path):
+        """A ``.txt`` file holding a violation-like snippet does not count as a violation."""
+        path = create_fake_repo(
+            "airflow-core/src/airflow/not_python.txt", "@provide_session\ndef foo(session=N): pass\n"
+        )
+        manager = AllowlistManager(tmp_path / "allowlist.txt")
+        assert _check_provide_session_kwargs([path], {}, manager) == 0
+
+    @pytest.mark.usefixtures("create_fake_repo")
+    def test_missing_allowlist_file_fails_loudly(self, tmp_path):
+        """Passing the allowlist path when the file is missing must fail, not silently pass."""
+        allowlist_path = tmp_path / "allowlist.txt"
+        manager = AllowlistManager(allowlist_path)
+        assert not allowlist_path.exists()
+        assert _check_provide_session_kwargs([allowlist_path.resolve()], {}, manager) == 1
+
+
+class TestExpandForAllowlistEdits:
+    """Path expansion performed by ``_expand_for_allowlist_edits`` when the allowlist file is an input."""
+
+    def test_unchanged_when_allowlist_not_in_paths(self, create_fake_repo, tmp_path):
+        """Without the allowlist file among the inputs, the path list comes back as-is."""
+        py = create_fake_repo("airflow-core/src/airflow/x.py", "pass")
+        manager = AllowlistManager(tmp_path / "allowlist.txt")
+        assert _expand_for_allowlist_edits([py], manager, {"airflow-core/src/airflow/x.py": 1}) == [py]
+
+    def test_appends_allowlisted_files_when_allowlist_edited(self, create_fake_repo, tmp_path):
+        """Passing the allowlist path adds the listed files that exist on disk."""
+        allowlist_path = tmp_path / "allowlist.txt"
+        manager = AllowlistManager(allowlist_path)
+        listed = create_fake_repo("airflow-core/src/airflow/listed.py", "pass")
+        # Pass a resolved path — matches production behavior (``main()`` resolves argv).
+        result = _expand_for_allowlist_edits(
+            [allowlist_path.resolve()],
+            manager,
+            {"airflow-core/src/airflow/listed.py": 1, "airflow-core/src/airflow/gone.py": 1},
+        )
+        assert allowlist_path.resolve() in result
+        assert listed in result
+        # File in allowlist that does not exist on disk should be ignored.
+        assert (tmp_path / "airflow-core/src/airflow/gone.py").resolve() not in result
+
+    def test_detection_robust_to_symlinked_allowlist(self, create_fake_repo, tmp_path):
+        """A symlink pointing at the allowlist file must still trigger expansion."""
+        allowlist_path = tmp_path / "allowlist.txt"
+        manager = AllowlistManager(allowlist_path)
+        listed = create_fake_repo("airflow-core/src/airflow/listed.py", "pass")
+        manager.save({"airflow-core/src/airflow/listed.py": 1})
+
+        symlink = tmp_path / "allowlist_link.txt"
+        symlink.symlink_to(allowlist_path)
+
+        # Production resolves argv before calling the helper — a symlinked path resolves
+        # to the real allowlist file and must be recognised as an allowlist edit.
+        result = _expand_for_allowlist_edits([symlink.resolve()], manager, manager.load())
+
+        assert listed in result
+
+    def test_includes_parse_tracked_allowlist_entries_when_removed(
+        self, create_fake_repo, create_git_repo, tmp_path
+    ):
+        """Removing an entry from the allowlist must still re-check the previously-listed file."""
+        rel = "airflow-core/src/airflow/dropped.py"
+        create_fake_repo(
+            rel,
+            """\
+            @provide_session
+            def foo(session=NEW_SESSION):
+                pass
+            """,
+        )
+        allowlist_path = tmp_path / "allowlist.txt"
+        manager = AllowlistManager(allowlist_path)
+        manager.save({rel: 1})
+        create_git_repo("seed allowlist at HEAD")
+
+        # Working tree: remove the entry, but the offending file still exists.
+        allowlist_path.write_text("")
+        current = manager.load()
+        assert current == {}
+
+        expanded = _expand_for_allowlist_edits([allowlist_path.resolve()], manager, current)
+        # The previously-listed file must be re-validated.
+        assert (tmp_path / rel).resolve() in expanded
+
+        # And the full check should fail because the file still has positional sessions.
+        assert _check_provide_session_kwargs(expanded, current, manager) == 1
+
+    @pytest.mark.usefixtures("create_fake_repo")
+    def test_parse_tracked_allowlist_empty_when_no_git_history(self, tmp_path):
+        """Without a git repo the git-tracked allowlist lookup returns empty and does not crash."""
+        manager = AllowlistManager(tmp_path / "allowlist.txt")
+        assert _parse_tracked_allowlist(manager) == {}
+
+    def test_re_validates_listed_files_so_loosening_cannot_bypass(self, create_fake_repo, tmp_path, capsys):
+        """Editing only the allowlist must still trigger validation of listed files."""
+        rel = "airflow-core/src/airflow/loosened.py"
+        create_fake_repo(
+            rel,
+            """\
+            @provide_session
+            def foo(session=NEW_SESSION):
+                pass
+
+            @provide_session
+            def bar(session=NEW_SESSION):
+                pass
+            """,
+        )
+        allowlist_path = tmp_path / "allowlist.txt"
+        manager = AllowlistManager(allowlist_path)
+        # Allowlist loosened to 5 although file only has 2 positional sessions.
+        allowlist = {rel: 5}
+        manager.save(allowlist)
+
+        # Only the allowlist file is "changed"; without re-validation this would return 0.
+        # Resolve the path to mirror what ``main()`` does in production.
+        paths = _expand_for_allowlist_edits([allowlist_path.resolve()], manager, allowlist)
+        rc = _check_provide_session_kwargs(paths, allowlist, manager)
+
+        # Tightened from 5 -> 2, so the hook exits non-zero to surface the modified allowlist.
+        assert rc == 1
+        assert manager.load() == {rel: 2}
+
+
+class TestCleanup:
+    def test_cleanup_removes_stale_entries(self, create_fake_repo, tmp_path):
+        """``cleanup()`` returns 0 and keeps only the entry whose file was created."""
+        create_fake_repo("airflow-core/src/airflow/keeper.py", "pass")
+        manager = AllowlistManager(tmp_path / "allowlist.txt")
+        # Only ``keeper.py`` was created above; ``gone.py`` is listed but never written.
+        seeded_allowlist = {
+            "airflow-core/src/airflow/keeper.py": 1,
+            "airflow-core/src/airflow/gone.py": 1,
+        }
+        manager.save(seeded_allowlist)
+        assert manager.cleanup() == 0
+        assert manager.load() == {"airflow-core/src/airflow/keeper.py": 1}
+
+    def test_cleanup_empty_allowlist(self, tmp_path):
+        """``cleanup()`` returns 0 when no allowlist file has been written."""
+        allowlist_path = tmp_path / "allowlist.txt"
+        manager = AllowlistManager(allowlist_path)
+        assert manager.cleanup() == 0