This is an automated email from the ASF dual-hosted git repository.

jedcunningham pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 5b0b830eed Add better typing in the db clean utils (#42341)
5b0b830eed is described below

commit 5b0b830eed5e15aef7541ff06d0f8160f978b30a
Author: Jed Cunningham <[email protected]>
AuthorDate: Thu Sep 19 08:32:12 2024 -0600

    Add better typing in the db clean utils (#42341)
---
 airflow/utils/db_cleanup.py | 58 ++++++++++++++++++++++++---------------------
 1 file changed, 31 insertions(+), 27 deletions(-)

diff --git a/airflow/utils/db_cleanup.py b/airflow/utils/db_cleanup.py
index f053b32070..f98ef40c0a 100644
--- a/airflow/utils/db_cleanup.py
+++ b/airflow/utils/db_cleanup.py
@@ -129,7 +129,7 @@ if conf.get("webserver", "session_backend") == "database":
 config_dict: dict[str, _TableConfig] = {x.orm_model.name: x for x in sorted(config_list)}
 
 
-def _check_for_rows(*, query: Query, print_rows=False):
+def _check_for_rows(*, query: Query, print_rows: bool = False) -> int:
     num_entities = query.count()
     print(f"Found {num_entities} rows meeting deletion criteria.")
     if print_rows:
@@ -142,7 +142,7 @@ def _check_for_rows(*, query: Query, print_rows=False):
     return num_entities
 
 
-def _dump_table_to_file(*, target_table, file_path, export_format, session):
+def _dump_table_to_file(*, target_table: str, file_path: str, export_format: str, session: Session) -> None:
     if export_format == "csv":
         with open(file_path, "w") as f:
             csv_writer = csv.writer(f)
@@ -153,7 +153,7 @@ def _dump_table_to_file(*, target_table, file_path, export_format, session):
         raise AirflowException(f"Export format {export_format} is not supported.")
 
 
-def _do_delete(*, query, orm_model, skip_archive, session):
+def _do_delete(*, query: Query, orm_model: Base, skip_archive: bool, session: Session) -> None:
     import re2
 
     print("Performing Delete...")
@@ -203,7 +203,9 @@ def _do_delete(*, query, orm_model, skip_archive, session):
     print("Finished Performing Delete")
 
 
-def _subquery_keep_last(*, recency_column, keep_last_filters, group_by_columns, max_date_colname, session):
+def _subquery_keep_last(
+    *, recency_column, keep_last_filters, group_by_columns, max_date_colname, session: Session
+) -> Query:
     subquery = select(*group_by_columns, func.max(recency_column).label(max_date_colname))
 
     if keep_last_filters is not None:
@@ -238,10 +240,10 @@ def _build_query(
     keep_last,
     keep_last_filters,
     keep_last_group_by,
-    clean_before_timestamp,
-    session,
+    clean_before_timestamp: DateTime,
+    session: Session,
     **kwargs,
-):
+) -> Query:
     base_table_alias = "base"
     base_table = aliased(orm_model, name=base_table_alias)
     query = session.query(base_table).with_entities(text(f"{base_table_alias}.*"))
@@ -276,13 +278,13 @@ def _cleanup_table(
     keep_last,
     keep_last_filters,
     keep_last_group_by,
-    clean_before_timestamp,
-    dry_run=True,
-    verbose=False,
-    skip_archive=False,
-    session,
+    clean_before_timestamp: DateTime,
+    dry_run: bool = True,
+    verbose: bool = False,
+    skip_archive: bool = False,
+    session: Session,
     **kwargs,
-):
+) -> None:
     print()
     if dry_run:
         print(f"Performing dry run for table {orm_model.name}")
@@ -305,7 +307,7 @@ def _cleanup_table(
     session.commit()
 
 
-def _confirm_delete(*, date: DateTime, tables: list[str]):
+def _confirm_delete(*, date: DateTime, tables: list[str]) -> None:
     for_tables = f" for tables {tables!r}" if tables else ""
     question = (
         f"You have requested that we purge all data prior to {date}{for_tables}.\n"
@@ -319,7 +321,7 @@ def _confirm_delete(*, date: DateTime, tables: list[str]):
         raise SystemExit("User did not confirm; exiting.")
 
 
-def _confirm_drop_archives(*, tables: list[str]):
+def _confirm_drop_archives(*, tables: list[str]) -> None:
     # if length of tables is greater than 3, show the total count
     if len(tables) > 3:
         text_ = f"{len(tables)} archived tables prefixed with {ARCHIVE_TABLE_PREFIX}"
@@ -341,13 +343,13 @@ def _confirm_drop_archives(*, tables: list[str]):
         raise SystemExit("User did not confirm; exiting.")
 
 
-def _print_config(*, configs: dict[str, _TableConfig]):
+def _print_config(*, configs: dict[str, _TableConfig]) -> None:
     data = [x.readable_config for x in configs.values()]
     AirflowConsole().print_as_table(data=data)
 
 
 @contextmanager
-def _suppress_with_logging(table, session):
+def _suppress_with_logging(table: str, session: Session):
     """
     Suppresses errors but logs them.
 
@@ -363,7 +365,7 @@ def _suppress_with_logging(table, session):
             session.rollback()
 
 
-def _effective_table_names(*, table_names: list[str] | None):
+def _effective_table_names(*, table_names: list[str] | None) -> tuple[set[str], dict[str, _TableConfig]]:
     desired_table_names = set(table_names or config_dict)
     effective_config_dict = {k: v for k, v in config_dict.items() if k in desired_table_names}
     effective_table_names = set(effective_config_dict)
@@ -377,7 +379,7 @@ def _effective_table_names(*, table_names: list[str] | None):
     return effective_table_names, effective_config_dict
 
 
-def _get_archived_table_names(table_names, session):
+def _get_archived_table_names(table_names: list[str] | None, session: Session) -> list[str]:
     inspector = inspect(session.bind)
     db_table_names = [x for x in inspector.get_table_names() if x.startswith(ARCHIVE_TABLE_PREFIX)]
     effective_table_names, _ = _effective_table_names(table_names=table_names)
@@ -400,7 +402,7 @@ def run_cleanup(
     confirm: bool = True,
     skip_archive: bool = False,
     session: Session = NEW_SESSION,
-):
+) -> None:
     """
     Purges old records in airflow metadata database.
 
@@ -450,13 +452,13 @@ def run_cleanup(
 
 @provide_session
 def export_archived_records(
-    export_format,
-    output_path,
-    table_names=None,
-    drop_archives=False,
-    needs_confirm=True,
+    export_format: str,
+    output_path: str,
+    table_names: list[str] | None = None,
+    drop_archives: bool = False,
+    needs_confirm: bool = True,
     session: Session = NEW_SESSION,
-):
+) -> None:
     """Export archived data to the given output path in the given format."""
     archived_table_names = _get_archived_table_names(table_names, session)
     # If user chose to drop archives, check there are archive tables that exists
@@ -482,7 +484,9 @@ def export_archived_records(
 
 
 @provide_session
-def drop_archived_tables(table_names, needs_confirm, session):
+def drop_archived_tables(
+    table_names: list[str] | None, needs_confirm: bool, session: Session = NEW_SESSION
+) -> None:
     """Drop archived tables."""
     archived_table_names = _get_archived_table_names(table_names, session)
     if needs_confirm and archived_table_names:

Reply via email to