This is an automated email from the ASF dual-hosted git repository.

dstandish pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new c75774d  Add `db clean` CLI command for purging old data (#20838)
c75774d is described below

commit c75774d3a31efe749f55ba16e782737df9f53af4
Author: Daniel Standish <[email protected]>
AuthorDate: Fri Feb 25 21:55:21 2022 -0800

    Add `db clean` CLI command for purging old data (#20838)
    
    CLI command to delete old rows from airflow metadata database.
    Notes:
    * Must supply a "clean before" timestamp (`--clean-before-timestamp`).
    * Can optionally provide table list.
    * Dry run will only print the number of rows meeting criteria.
    * If not dry run, will require the user to confirm before deleting.
---
 airflow/cli/cli_parser.py             |  45 ++++-
 airflow/cli/commands/db_command.py    |  17 ++
 airflow/utils/db_cleanup.py           | 315 ++++++++++++++++++++++++++++++++++
 docs/spelling_wordlist.txt            |   2 +
 tests/cli/commands/test_db_command.py | 148 ++++++++++++++++
 tests/utils/test_db_cleanup.py        | 265 ++++++++++++++++++++++++++++
 6 files changed, 791 insertions(+), 1 deletion(-)

diff --git a/airflow/cli/cli_parser.py b/airflow/cli/cli_parser.py
index c91ae7a..cb0bcd2 100644
--- a/airflow/cli/cli_parser.py
+++ b/airflow/cli/cli_parser.py
@@ -26,6 +26,8 @@ from argparse import Action, ArgumentError, RawTextHelpFormatter
 from functools import lru_cache
 from typing import Callable, Dict, Iterable, List, NamedTuple, Optional, Union
 
+import lazy_object_proxy
+
 from airflow import PY37, settings
 from airflow.cli.commands.legacy_commands import check_legacy_command
 from airflow.configuration import conf
@@ -162,6 +164,11 @@ def positive_int(*, allow_zero):
     return _check
 
 
+def string_list_type(val):
+    """
+    Parse a comma-separated string into a list of strings, stripping whitespace around each item.
+
+    Empty items are kept as empty strings (e.g. ``"a,,b"`` -> ``['a', '', 'b']``).
+    """
+    return [x.strip() for x in val.split(',')]
+
+
 # Shared
 ARG_DAG_ID = Arg(("dag_id",), help="The id of the dag")
 ARG_TASK_ID = Arg(("task_id",), help="The id of the task")
@@ -205,7 +212,7 @@ ARG_STDERR = Arg(("--stderr",), help="Redirect stderr to this file")
 ARG_STDOUT = Arg(("--stdout",), help="Redirect stdout to this file")
 ARG_LOG_FILE = Arg(("-l", "--log-file"), help="Location of the log file")
 ARG_YES = Arg(
-    ("-y", "--yes"), help="Do not prompt to confirm reset. Use with care!", action="store_true", default=False
+    ("-y", "--yes"), help="Do not prompt to confirm. Use with care!", action="store_true", default=False
 )
 ARG_OUTPUT = Arg(
     (
@@ -398,6 +405,30 @@ ARG_RUN_ID = Arg(("-r", "--run-id"), help="Helps to identify this run")
 ARG_CONF = Arg(('-c', '--conf'), help="JSON string that gets pickled into the DagRun's conf attribute")
 ARG_EXEC_DATE = Arg(("-e", "--exec-date"), help="The execution date of the DAG", type=parsedate)
 
+# db
+ARG_DB_TABLES = Arg(
+    ("-t", "--tables"),
+    # The proxy defers importing db_command (and the ORM models it pulls in) until the
+    # help text is actually rendered.
+    help=lazy_object_proxy.Proxy(
+        lambda: f"Table names to perform maintenance on (use comma-separated list).\n"
+        f"Options: {import_string('airflow.cli.commands.db_command.all_tables')}"
+    ),
+    type=string_list_type,
+)
+ARG_DB_CLEANUP_TIMESTAMP = Arg(
+    ("--clean-before-timestamp",),
+    help="The date or timestamp before which data should be purged.\n"
+    "If no timezone info is supplied then dates are assumed to be in airflow default timezone.\n"
+    "Example: '2022-01-01 00:00:00+01:00'",
+    type=parsedate,
+    required=True,
+)
+ARG_DB_DRY_RUN = Arg(
+    ("--dry-run",),
+    help="Perform a dry run",
+    action="store_true",
+)
+
+
 # pool
 ARG_POOL_NAME = Arg(("pool",), metavar='NAME', help="Pool name")
 ARG_POOL_SLOTS = Arg(("slots",), type=int, help="Pool slots")
@@ -1308,6 +1339,18 @@ DB_COMMANDS = (
         func=lazy_load_command('airflow.cli.commands.db_command.check'),
         args=(),
     ),
+    # `airflow db clean`: purge old rows; implemented by db_command.cleanup_tables.
+    ActionCommand(
+        name='clean',
+        help="Purge old records in metastore tables",
+        func=lazy_load_command('airflow.cli.commands.db_command.cleanup_tables'),
+        args=(
+            ARG_DB_TABLES,
+            ARG_DB_DRY_RUN,
+            ARG_DB_CLEANUP_TIMESTAMP,
+            ARG_VERBOSE,
+            ARG_YES,
+        ),
+    ),
 )
 CONNECTIONS_COMMANDS = (
     ActionCommand(
diff --git a/airflow/cli/commands/db_command.py b/airflow/cli/commands/db_command.py
index 09fe221..02811d0 100644
--- a/airflow/cli/commands/db_command.py
+++ b/airflow/cli/commands/db_command.py
@@ -22,6 +22,7 @@ from tempfile import NamedTemporaryFile
 from airflow import settings
 from airflow.exceptions import AirflowException
 from airflow.utils import cli as cli_utils, db
+from airflow.utils.db_cleanup import config_dict, run_cleanup
 from airflow.utils.process_utils import execute_interactive
 
 
@@ -101,3 +102,19 @@ def shell(args):
 def check(_):
     """Runs a check command that checks if db is available."""
     db.check()
+
+
+# Sorted table names; imported lazily by the CLI parser to render the ``--tables`` help.
+all_tables = sorted(config_dict.keys())
+
+
+@cli_utils.action_cli(check_db=False)
+def cleanup_tables(args):
+    """Purge old metadata-database records as requested by the ``db clean`` CLI arguments."""
+    # ``--yes`` means "don't ask", so confirmation is required only when it is absent.
+    require_confirmation = not args.yes
+    run_cleanup(
+        clean_before_timestamp=args.clean_before_timestamp,
+        table_names=args.tables,
+        dry_run=args.dry_run,
+        verbose=args.verbose,
+        confirm=require_confirmation,
+    )
diff --git a/airflow/utils/db_cleanup.py b/airflow/utils/db_cleanup.py
new file mode 100644
index 0000000..d1449bd
--- /dev/null
+++ b/airflow/utils/db_cleanup.py
@@ -0,0 +1,315 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""
+This module took inspiration from the community maintenance dag
+(https://github.com/teamclairvoyant/airflow-maintenance-dags/blob/4e5c7682a808082561d60cbc9cafaa477b0d8c65/db-cleanup/airflow-db-cleanup.py).
+"""
+
+import builtins
+import logging
+from contextlib import AbstractContextManager
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
+
+from pendulum import DateTime
+from sqlalchemy import and_, false, func
+from sqlalchemy.exc import OperationalError
+
+from airflow.cli.simple_table import AirflowConsole
+from airflow.jobs.base_job import BaseJob
+from airflow.models import (
+    Base,
+    DagModel,
+    DagRun,
+    ImportError,
+    Log,
+    RenderedTaskInstanceFields,
+    SensorInstance,
+    SlaMiss,
+    TaskFail,
+    TaskInstance,
+    TaskReschedule,
+    XCom,
+)
+from airflow.utils import timezone
+from airflow.utils.session import NEW_SESSION, provide_session
+
+if TYPE_CHECKING:
+    from sqlalchemy.orm import Query, Session
+    from sqlalchemy.orm.attributes import InstrumentedAttribute
+    from sqlalchemy.sql.schema import Column
+
+
+@dataclass
+class _TableConfig:
+    """
+    Config class for performing cleanup on a table
+
+    :param orm_model: the table
+    :param recency_column: date column to filter by
+    :param keep_last: whether the last record should be kept even if it's older than clean_before_timestamp
+    :param keep_last_filters: the "keep last" functionality will preserve the most recent record
+        in the table.  These filters restrict which records are eligible to be kept as the
+        "last" one; records they exclude (e.g. externally triggered dag runs) may be deleted
+        even if they are the latest in the table.
+    :param keep_last_group_by: if keeping the last record, can keep the last record for each group
+    :param warn_if_missing: If True, then we'll suppress "table missing" exception and log a warning.
+        If False then the exception will go uncaught.
+    """
+
+    orm_model: Base
+    recency_column: Union["Column", "InstrumentedAttribute"]
+    keep_last: bool = False
+    keep_last_filters: Optional[Any] = None
+    keep_last_group_by: Optional[Any] = None
+    warn_if_missing: bool = False
+
+    def __lt__(self, other):
+        # Order configs by table name so ``sorted(config_list)`` is deterministic.
+        return self.orm_model.__tablename__ < other.orm_model.__tablename__
+
+    @property
+    def readable_config(self):
+        """Return a printable (string-valued where needed) summary of this config."""
+        return dict(
+            table=self.orm_model.__tablename__,
+            recency_column=str(self.recency_column),
+            keep_last=self.keep_last,
+            keep_last_filters=[str(x) for x in self.keep_last_filters] if self.keep_last_filters else None,
+            keep_last_group_by=str(self.keep_last_group_by),
+            warn_if_missing=str(self.warn_if_missing),
+        )
+
+
+config_list: List[_TableConfig] = [
+    _TableConfig(orm_model=BaseJob, recency_column=BaseJob.latest_heartbeat),
+    _TableConfig(orm_model=DagModel, recency_column=DagModel.last_parsed_time),
+    # Keep the newest non-externally-triggered run per DAG so scheduling continuity is preserved.
+    _TableConfig(
+        orm_model=DagRun,
+        recency_column=DagRun.start_date,
+        keep_last=True,
+        keep_last_filters=[DagRun.external_trigger == false()],
+        keep_last_group_by=DagRun.dag_id,
+    ),
+    _TableConfig(orm_model=ImportError, recency_column=ImportError.timestamp),
+    _TableConfig(orm_model=Log, recency_column=Log.dttm),
+    _TableConfig(
+        orm_model=RenderedTaskInstanceFields, recency_column=RenderedTaskInstanceFields.execution_date
+    ),
+    _TableConfig(
+        orm_model=SensorInstance, recency_column=SensorInstance.updated_at
+    ),  # TODO: add FK to task instance / dag so we can remove here
+    _TableConfig(orm_model=SlaMiss, recency_column=SlaMiss.timestamp),
+    _TableConfig(orm_model=TaskFail, recency_column=TaskFail.start_date),
+    _TableConfig(orm_model=TaskInstance, recency_column=TaskInstance.start_date),
+    _TableConfig(orm_model=TaskReschedule, recency_column=TaskReschedule.start_date),
+    _TableConfig(orm_model=XCom, recency_column=XCom.timestamp),
+]
+
+# The celery result-backend tables are only configured when celery is importable; they may
+# not exist in the metadata database, hence ``warn_if_missing=True``.
+try:
+    from celery.backends.database.models import Task, TaskSet
+
+    config_list.extend(
+        [
+            _TableConfig(orm_model=Task, recency_column=Task.date_done, warn_if_missing=True),
+            _TableConfig(orm_model=TaskSet, recency_column=TaskSet.date_done, warn_if_missing=True),
+        ]
+    )
+except builtins.ImportError:
+    # ``ImportError`` in this module's namespace is the ``airflow.models.ImportError`` ORM
+    # model imported above, which shadows the builtin exception. Catching the model class
+    # would raise TypeError instead of swallowing the failed celery import.
+    pass
+
+config_dict: Dict[str, _TableConfig] = {x.orm_model.__tablename__: x for x in sorted(config_list)}
+
+
+def _print_entities(*, query: "Query", print_rows=False):
+    """
+    Print how many rows ``query`` matches.
+
+    When ``print_rows`` is true, also print up to the first 100 matching rows.
+    """
+    row_count = query.count()
+    print(f"Found {row_count} rows meeting deletion criteria.")
+    if print_rows:
+        sample_size = 100
+        if row_count:
+            print(f"Printing first {sample_size} rows.")
+        logger.debug("print entities query: %s", query)
+        for row in query.limit(sample_size):
+            print(row.__dict__)
+
+
+def _do_delete(*, query, session):
+    """Issue one bulk DELETE for the rows matched by ``query``, then commit ``session``."""
+    print("Performing Delete...")
+    # ``synchronize_session=False``: delete in a single statement without reconciling
+    # objects that may already be loaded into ``session``.
+    query.delete(synchronize_session=False)
+    session.commit()
+    print("Finished Performing Delete")
+
+
+def _subquery_keep_last(*, recency_column, keep_last_filters, keep_last_group_by, session):
+    """
+    Build a subquery yielding the newest ``recency_column`` value(s) that must survive cleanup.
+
+    :param recency_column: column whose maximum identifies the "last" record
+    :param keep_last_filters: optional iterable of filter clauses restricting which rows are
+        considered when computing the maximum
+    :param keep_last_group_by: optional column; when given, one maximum per group is produced
+    :param session: session used to build the query
+    """
+    subquery = session.query(func.max(recency_column))
+
+    if keep_last_filters is not None:
+        # Multiple criteria passed to ``filter`` are combined with AND, same as chaining.
+        subquery = subquery.filter(*keep_last_filters)
+
+    if keep_last_group_by is not None:
+        subquery = subquery.group_by(keep_last_group_by)
+
+    # We nest this subquery to work around a MySQL "table specified twice" issue
+    # See https://github.com/teamclairvoyant/airflow-maintenance-dags/issues/41
+    # and https://github.com/teamclairvoyant/airflow-maintenance-dags/pull/57/files.
+    return subquery.from_self()
+
+
+def _build_query(
+    *,
+    orm_model,
+    recency_column,
+    keep_last,
+    keep_last_filters,
+    keep_last_group_by,
+    clean_before_timestamp,
+    session,
+    **kwargs,
+):
+    """
+    Build a query selecting rows of ``orm_model`` older than ``clean_before_timestamp``.
+
+    With ``keep_last``, rows whose ``recency_column`` value is among the maxima returned by
+    ``_subquery_keep_last`` are excluded so that they survive the cleanup.  Remaining
+    ``_TableConfig`` fields passed as keyword arguments are accepted and ignored.
+    """
+    older_than_cutoff = recency_column < clean_before_timestamp
+    if not keep_last:
+        return session.query(orm_model).filter(and_(older_than_cutoff))
+    newest_values = _subquery_keep_last(
+        recency_column=recency_column,
+        keep_last_filters=keep_last_filters,
+        keep_last_group_by=keep_last_group_by,
+        session=session,
+    )
+    return session.query(orm_model).filter(and_(older_than_cutoff, recency_column.notin_(newest_values)))
+
+
+# Named after the module (not its file path) so it fits the standard dotted logger hierarchy.
+logger = logging.getLogger(__name__)
+
+
+def _cleanup_table(
+    *,
+    orm_model,
+    recency_column,
+    keep_last,
+    keep_last_filters,
+    keep_last_group_by,
+    clean_before_timestamp,
+    dry_run=True,
+    verbose=False,
+    session=None,
+    **kwargs,
+):
+    """
+    Report the rows of one table that are older than the cutoff and, unless ``dry_run``, delete them.
+
+    Takes the fields of a ``_TableConfig`` as keyword arguments; fields not used here
+    (e.g. ``warn_if_missing``) are absorbed by ``**kwargs``.
+
+    :param clean_before_timestamp: rows whose ``recency_column`` is earlier than this are targeted
+    :param dry_run: if True (the default), only print how many rows match
+    :param verbose: accepted for interface compatibility; not used here
+    :param session: session used for the query; ``_do_delete`` commits it after deleting
+    """
+    print()
+    if dry_run:
+        print(f"Performing dry run for table {orm_model.__tablename__!r}")
+    query = _build_query(
+        orm_model=orm_model,
+        recency_column=recency_column,
+        keep_last=keep_last,
+        keep_last_filters=keep_last_filters,
+        keep_last_group_by=keep_last_group_by,
+        clean_before_timestamp=clean_before_timestamp,
+        session=session,
+    )
+
+    _print_entities(query=query, print_rows=False)
+
+    if not dry_run:
+        # ``_do_delete`` already commits the session; a second commit would be a no-op.
+        _do_delete(query=query, session=session)
+
+
+def _confirm_delete(*, date: DateTime, tables: List[str]):
+    """
+    Ask the user on stdin to confirm an irreversible purge.
+
+    :param date: rows older than this timestamp will be purged
+    :param tables: names of the tables that will be cleaned (omitted from the prompt if empty)
+    :raises SystemExit: if the stripped answer is not exactly ``delete rows``
+    """
+    for_tables = f" for tables {tables!r}" if tables else ''
+    question = (
+        f"You have requested that we purge all data prior to {date}{for_tables}.\n"
+        f"This is irreversible.  Consider backing up the tables first and / or doing a dry run "
+        f"with option --dry-run.\n"
+        f"Enter 'delete rows' (without quotes) to proceed."
+    )
+    print(question)
+    answer = input().strip()
+    if answer != 'delete rows':
+        raise SystemExit("User did not confirm; exiting.")
+
+
+def _print_config(*, configs: Dict[str, _TableConfig]):
+    """Render each config's ``readable_config`` summary as one row of a console table."""
+    rows = [table_config.readable_config for table_config in configs.values()]
+    AirflowConsole().print_as_table(data=rows)
+
+
+class _warn_if_missing(AbstractContextManager):
+    """
+    Context manager that optionally swallows ``OperationalError`` raised while cleaning a table.
+
+    When ``suppress`` is true, an ``OperationalError`` (treated as "table not found") is logged
+    as a warning and suppressed.  Otherwise it propagates, matching the documented meaning of
+    ``warn_if_missing=False`` on ``_TableConfig``.
+
+    :param table: table name used in the warning message
+    :param suppress: whether to swallow ``OperationalError``
+    """
+
+    def __init__(self, table, suppress):
+        self.table = table
+        self.suppress = suppress
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exctype, excinst, exctb):
+        caught_error = exctype is not None and issubclass(exctype, OperationalError)
+        # Only swallow the error for tables configured with ``warn_if_missing=True``.
+        if caught_error and self.suppress:
+            logger.warning("Table %r not found.  Skipping.", self.table)
+            return True
+        return False
+
+
+@provide_session
+def run_cleanup(
+    *,
+    clean_before_timestamp: DateTime,
+    table_names: Optional[List[str]] = None,
+    dry_run: bool = False,
+    verbose: bool = False,
+    confirm: bool = True,
+    session: 'Session' = NEW_SESSION,
+):
+    """
+    Purges old records in airflow metadata database.
+
+    The last non-externally-triggered dag run will always be kept in order to ensure
+    continuity of scheduled dag runs.
+
+    Where there are foreign key relationships, deletes will cascade, so that for
+    example if you clean up old dag runs, the associated task instances will
+    be deleted.
+
+    :param clean_before_timestamp: The timestamp before which data should be purged
+    :param table_names: Optional. List of table names to perform maintenance on.  If list not provided,
+        will perform maintenance on all tables.  Names not configured for cleanup are ignored
+        with a warning.
+    :param dry_run: If true, only report how many rows meet the deletion criteria; nothing is deleted.
+    :param verbose: If true, may provide more detailed output (currently not used by the per-table cleanup).
+    :param confirm: Require user input to confirm before processing deletions.
+    :param session: Session representing connection to the metadata database.
+    """
+    clean_before_timestamp = timezone.coerce_datetime(clean_before_timestamp)
+    effective_table_names = table_names if table_names else list(config_dict.keys())
+    unknown_table_names = [name for name in effective_table_names if name not in config_dict]
+    if unknown_table_names:
+        # Surface names that would otherwise be skipped silently (e.g. typos in ``--tables``).
+        logger.warning(
+            "Ignoring tables not configured for cleanup: %s. Valid options: %s",
+            unknown_table_names,
+            sorted(config_dict),
+        )
+    effective_config_dict = {k: v for k, v in config_dict.items() if k in effective_table_names}
+    if dry_run:
+        print('Performing dry run for db cleanup.')
+        print(
+            f"Data prior to {clean_before_timestamp} would be purged "
+            f"from tables {effective_table_names} with the following config:\n"
+        )
+        _print_config(configs=effective_config_dict)
+    if not dry_run and confirm:
+        _confirm_delete(date=clean_before_timestamp, tables=list(effective_config_dict.keys()))
+    for table_name, table_config in effective_config_dict.items():
+        with _warn_if_missing(table_name, table_config.warn_if_missing):
+            _cleanup_table(
+                clean_before_timestamp=clean_before_timestamp,
+                dry_run=dry_run,
+                verbose=verbose,
+                **table_config.__dict__,
+                session=session,
+            )
diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt
index 68ecca0..efae25a 100644
--- a/docs/spelling_wordlist.txt
+++ b/docs/spelling_wordlist.txt
@@ -1375,7 +1375,9 @@ tagValue
 task_group
 taskflow
 taskinstance
+taskmeta
 taskmixin
+tasksetmeta
 tblproperties
 tcp
 teardown
diff --git a/tests/cli/commands/test_db_command.py b/tests/cli/commands/test_db_command.py
index e94e63a..09cb14a 100644
--- a/tests/cli/commands/test_db_command.py
+++ b/tests/cli/commands/test_db_command.py
@@ -17,7 +17,9 @@
 
 import unittest
 from unittest import mock
+from unittest.mock import patch
 
+import pendulum
 import pytest
 from sqlalchemy.engine.url import make_url
 
@@ -134,3 +136,149 @@ class TestCliDb(unittest.TestCase):
     def test_cli_shell_invalid(self):
         with pytest.raises(AirflowException, match=r"Unknown driver: 
invalid\+psycopg2"):
             db_command.shell(self.parser.parse_args(['db', 'shell']))
+
+
+class TestCLIDBClean:
+    """Argument parsing/forwarding tests for ``airflow db clean``; ``run_cleanup`` is mocked."""
+
+    @classmethod
+    def setup_class(cls):
+        cls.parser = cli_parser.get_parser()
+
+    @pytest.mark.parametrize('timezone', ['UTC', 'Europe/Berlin', 'America/Los_Angeles'])
+    @patch('airflow.cli.commands.db_command.run_cleanup')
+    def test_date_timezone_omitted(self, run_cleanup_mock, timezone):
+        """
+        When timezone omitted we should always expect that the timestamp is
+        coerced to tz-aware with default timezone
+        """
+        timestamp = '2021-01-01 00:00:00'
+        with patch('airflow.utils.timezone.TIMEZONE', pendulum.timezone(timezone)):
+            args = self.parser.parse_args(['db', 'clean', '--clean-before-timestamp', timestamp, '-y'])
+            db_command.cleanup_tables(args)
+        run_cleanup_mock.assert_called_once_with(
+            table_names=None,
+            dry_run=False,
+            clean_before_timestamp=pendulum.parse(timestamp, tz=timezone),
+            verbose=False,
+            confirm=False,
+        )
+
+    @pytest.mark.parametrize('timezone', ['UTC', 'Europe/Berlin', 'America/Los_Angeles'])
+    @patch('airflow.cli.commands.db_command.run_cleanup')
+    def test_date_timezone_supplied(self, run_cleanup_mock, timezone):
+        """
+        When tz included in the string then default timezone should not be used.
+        """
+        timestamp = '2021-01-01 00:00:00+03:00'
+        with patch('airflow.utils.timezone.TIMEZONE', pendulum.timezone(timezone)):
+            args = self.parser.parse_args(['db', 'clean', '--clean-before-timestamp', timestamp, '-y'])
+            db_command.cleanup_tables(args)
+
+        run_cleanup_mock.assert_called_once_with(
+            table_names=None,
+            dry_run=False,
+            clean_before_timestamp=pendulum.parse(timestamp),
+            verbose=False,
+            confirm=False,
+        )
+
+    @pytest.mark.parametrize('confirm_arg, expected', [(['-y'], False), ([], True)])
+    @patch('airflow.cli.commands.db_command.run_cleanup')
+    def test_confirm(self, run_cleanup_mock, confirm_arg, expected):
+        """
+        ``-y`` should disable the interactive confirmation (``confirm=False``);
+        without it, confirmation is required (``confirm=True``).
+        """
+        args = self.parser.parse_args(
+            [
+                'db',
+                'clean',
+                '--clean-before-timestamp',
+                '2021-01-01',
+                *confirm_arg,
+            ]
+        )
+        db_command.cleanup_tables(args)
+
+        # NOTE(review): the expected timestamp assumes the test config's default timezone is UTC.
+        run_cleanup_mock.assert_called_once_with(
+            table_names=None,
+            dry_run=False,
+            clean_before_timestamp=pendulum.parse('2021-01-01 00:00:00Z'),
+            verbose=False,
+            confirm=expected,
+        )
+
+    @pytest.mark.parametrize('dry_run_arg, expected', [(['--dry-run'], True), ([], False)])
+    @patch('airflow.cli.commands.db_command.run_cleanup')
+    def test_dry_run(self, run_cleanup_mock, dry_run_arg, expected):
+        """
+        ``--dry-run`` should be forwarded to ``run_cleanup`` as ``dry_run=True``;
+        without it, ``dry_run`` is False.
+        """
+        args = self.parser.parse_args(
+            [
+                'db',
+                'clean',
+                '--clean-before-timestamp',
+                '2021-01-01',
+                *dry_run_arg,
+            ]
+        )
+        db_command.cleanup_tables(args)
+
+        run_cleanup_mock.assert_called_once_with(
+            table_names=None,
+            dry_run=expected,
+            clean_before_timestamp=pendulum.parse('2021-01-01 00:00:00Z'),
+            verbose=False,
+            confirm=True,
+        )
+
+    @pytest.mark.parametrize(
+        'extra_args, expected', [(['--tables', 'hello, goodbye'], ['hello', 'goodbye']), ([], None)]
+    )
+    @patch('airflow.cli.commands.db_command.run_cleanup')
+    def test_tables(self, run_cleanup_mock, extra_args, expected):
+        """
+        ``--tables`` should be parsed as a comma-separated list with whitespace stripped;
+        when omitted, ``table_names`` is None.
+        """
+        args = self.parser.parse_args(
+            [
+                'db',
+                'clean',
+                '--clean-before-timestamp',
+                '2021-01-01',
+                *extra_args,
+            ]
+        )
+        db_command.cleanup_tables(args)
+
+        run_cleanup_mock.assert_called_once_with(
+            table_names=expected,
+            dry_run=False,
+            clean_before_timestamp=pendulum.parse('2021-01-01 00:00:00Z'),
+            verbose=False,
+            confirm=True,
+        )
+
+    @pytest.mark.parametrize('extra_args, expected', [(['--verbose'], True), ([], False)])
+    @patch('airflow.cli.commands.db_command.run_cleanup')
+    def test_verbose(self, run_cleanup_mock, extra_args, expected):
+        """
+        ``--verbose`` should be forwarded to ``run_cleanup`` as ``verbose=True``;
+        without it, ``verbose`` is False.
+        """
+        args = self.parser.parse_args(
+            [
+                'db',
+                'clean',
+                '--clean-before-timestamp',
+                '2021-01-01',
+                *extra_args,
+            ]
+        )
+        db_command.cleanup_tables(args)
+
+        run_cleanup_mock.assert_called_once_with(
+            table_names=None,
+            dry_run=False,
+            clean_before_timestamp=pendulum.parse('2021-01-01 00:00:00Z'),
+            verbose=expected,
+            confirm=True,
+        )
diff --git a/tests/utils/test_db_cleanup.py b/tests/utils/test_db_cleanup.py
new file mode 100644
index 0000000..8ef80eb
--- /dev/null
+++ b/tests/utils/test_db_cleanup.py
@@ -0,0 +1,265 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from contextlib import suppress
+from importlib import import_module
+from pathlib import Path
+from unittest.mock import MagicMock, patch
+from uuid import uuid4
+
+import pendulum
+import pytest
+from pytest import param
+from sqlalchemy.ext.declarative import DeclarativeMeta
+
+from airflow.models import DagModel, DagRun, TaskInstance
+from airflow.operators.python import PythonOperator
+from airflow.utils.db_cleanup import _build_query, _cleanup_table, 
config_dict, run_cleanup
+from airflow.utils.session import create_session
+from tests.test_utils.db import clear_db_dags, clear_db_runs
+
+
+@pytest.fixture(autouse=True)
+def clean_database():
+    """Fixture that clears dag and dag run data before and after every test."""
+    clear_db_runs()
+    clear_db_dags()
+    yield  # Test runs here
+    clear_db_dags()
+    clear_db_runs()
+
+
+class TestDBCleanup:
+    @pytest.mark.parametrize(
+        'kwargs, called',
+        [
+            param(dict(confirm=True), True, id='true'),
+            param(dict(), True, id='not supplied'),
+            param(dict(confirm=False), False, id='false'),
+        ],
+    )
+    @patch('airflow.utils.db_cleanup._cleanup_table', new=MagicMock())
+    @patch('airflow.utils.db_cleanup._confirm_delete')
+    def test_run_cleanup_confirm(self, confirm_delete_mock, kwargs, called):
+        """test that delete confirmation input is called when appropriate"""
+        run_cleanup(
+            clean_before_timestamp=None,
+            table_names=None,
+            dry_run=None,
+            verbose=None,
+            **kwargs,
+        )
+        if called:
+            confirm_delete_mock.assert_called()
+        else:
+            confirm_delete_mock.assert_not_called()
+
+    @pytest.mark.parametrize(
+        'table_names',
+        [
+            ['xcom', 'log'],
+            None,
+        ],
+    )
+    @patch('airflow.utils.db_cleanup._cleanup_table')
+    @patch('airflow.utils.db_cleanup._confirm_delete', new=MagicMock())
+    def test_run_cleanup_tables(self, clean_table_mock, table_names):
+        """
+        ``_cleanup_table`` should be called for each table in subset if one
+        is provided else should be called for all tables.
+        """
+        base_kwargs = dict(
+            clean_before_timestamp=None,
+            dry_run=None,
+            verbose=None,
+        )
+        run_cleanup(**base_kwargs, table_names=table_names)
+        # Compute the expectation separately: in ``assert a == b if c else d`` the conditional
+        # applies to the whole assertion, so the ``None`` case would only assert the (truthy)
+        # ``len(config_dict)`` and never check the call count.
+        expected_call_count = len(table_names) if table_names else len(config_dict)
+        assert clean_table_mock.call_count == expected_call_count
+
+    @pytest.mark.parametrize(
+        'dry_run',
+        [None, True, False],
+    )
+    @patch('airflow.utils.db_cleanup._build_query', MagicMock())
+    @patch('airflow.utils.db_cleanup._print_entities', MagicMock())
+    @patch('airflow.utils.db_cleanup._do_delete')
+    @patch('airflow.utils.db_cleanup._confirm_delete', MagicMock())
+    def test_run_cleanup_dry_run(self, do_delete, dry_run):
+        """Delete should only be called when not dry_run"""
+        base_kwargs = dict(
+            clean_before_timestamp=None,
+            dry_run=dry_run,
+            verbose=None,
+        )
+        run_cleanup(
+            **base_kwargs,
+        )
+        if dry_run:
+            do_delete.assert_not_called()
+        else:
+            do_delete.assert_called()
+
+    @pytest.mark.parametrize(
+        'table_name, date_add_kwargs, expected_to_delete, external_trigger',
+        [
+            param('task_instance', dict(days=0), 0, False, id='beginning'),
+            param('task_instance', dict(days=4), 4, False, id='middle'),
+            param('task_instance', dict(days=9), 9, False, id='end_exactly'),
+            param('task_instance', dict(days=9, microseconds=1), 10, False, id='beyond_end'),
+            param('dag_run', dict(days=9, microseconds=1), 9, False, id='beyond_end_dr'),
+            param('dag_run', dict(days=9, microseconds=1), 10, True, id='beyond_end_dr_external'),
+        ],
+    )
+    def test__build_query(self, table_name, date_add_kwargs, expected_to_delete, external_trigger):
+        """
+        Verify that ``_build_query`` produces a query that would delete the right
+        records depending on the value of ``clean_before_timestamp``.
+
+        DagRun is a special case where we always keep the last dag run even if
+        the ``clean_before_timestamp`` is in the future, except for
+        externally-triggered dag runs. That is, only the last non-externally-triggered
+        dag run is kept.
+
+        """
+        base_date = pendulum.DateTime(2022, 1, 1, tzinfo=pendulum.timezone('America/Los_Angeles'))
+        create_tis(
+            base_date=base_date,
+            num_tis=10,
+            external_trigger=external_trigger,
+        )
+        with create_session() as session:
+            clean_before_date = base_date.add(**date_add_kwargs)
+            query = _build_query(
+                **config_dict[table_name].__dict__,
+                clean_before_timestamp=clean_before_date,
+                session=session,
+            )
+            assert len(query.all()) == expected_to_delete
+
+    @pytest.mark.parametrize(
+        'table_name, date_add_kwargs, expected_to_delete, external_trigger',
+        [
+            param('task_instance', dict(days=0), 0, False, id='beginning'),
+            param('task_instance', dict(days=4), 4, False, id='middle'),
+            param('task_instance', dict(days=9), 9, False, id='end_exactly'),
+            param('task_instance', dict(days=9, microseconds=1), 10, False, id='beyond_end'),
+            param('dag_run', dict(days=9, microseconds=1), 9, False, id='beyond_end_dr'),
+            param('dag_run', dict(days=9, microseconds=1), 10, True, id='beyond_end_dr_external'),
+        ],
+    )
+    def test__cleanup_table(self, table_name, date_add_kwargs, expected_to_delete, external_trigger):
+        """
+        Verify that _cleanup_table actually deletes the rows it should.
+
+        TaskInstance represents the "normal" case.  DagRun is the odd case where we want
+        to keep the last non-externally-triggered DagRun record even if it should be
+        deleted according to the provided timestamp.
+
+        We also verify that the "on delete cascade" behavior is as expected.  Some tables
+        have foreign keys defined so for example if we delete a dag run, all its associated
+        task instances should be purged as well.  But if we delete task instances the
+        associated dag runs should remain.
+
+        """
+        base_date = pendulum.DateTime(2022, 1, 1, tzinfo=pendulum.timezone('America/Los_Angeles'))
+        num_tis = 10
+        create_tis(
+            base_date=base_date,
+            num_tis=num_tis,
+            external_trigger=external_trigger,
+        )
+        with create_session() as session:
+            clean_before_date = base_date.add(**date_add_kwargs)
+            _cleanup_table(
+                **config_dict[table_name].__dict__,
+                clean_before_timestamp=clean_before_date,
+                dry_run=False,
+                session=session,
+            )
+            model = config_dict[table_name].orm_model
+            expected_remaining = num_tis - expected_to_delete
+            assert len(session.query(model).all()) == expected_remaining
+            if model == TaskInstance:
+                assert len(session.query(DagRun).all()) == num_tis
+            elif model == DagRun:
+                assert len(session.query(TaskInstance).all()) == expected_remaining
+            else:
+                raise Exception("unexpected")
+
+    def test_no_models_missing(self):
+        """
+        1. Verify that for all tables in `airflow.models`, we either have them enabled in db cleanup,
+        or documented in the exclusion list in this test.
+        2. Verify that no table is enabled for db cleanup and also in exclusion list.
+        """
+        import pkgutil
+
+        proj_root = Path(__file__).parent.parent.parent
+        mods = [f"airflow.models.{name}" for _, name, _ in pkgutil.iter_modules([proj_root / 'airflow/models'])]
+
+        all_models = {}
+        for mod_name in mods:
+            mod = import_module(mod_name)
+
+            for class_ in mod.__dict__.values():
+                if isinstance(class_, DeclarativeMeta):
+                    # Declarative bases without a table (no ``__tablename__``) are skipped.
+                    with suppress(AttributeError):
+                        all_models.update({class_.__tablename__: class_})
+        exclusion_list = {
+            'variable',  # leave alone
+            'trigger',  # self-maintaining
+            'task_map',  # TODO: add datetime column to TaskMap so we can include it here
+            'serialized_dag',  # handled through FK to Dag
+            'log_template',  # not a significant source of data; age not indicative of staleness
+            'dag_tag',  # not a significant source of data; age not indicative of staleness
+            'dag_pickle',  # unsure of consequences
+            'dag_code',  # self-maintaining
+            'connection',  # leave alone
+            'slot_pool',  # leave alone
+        }
+
+        print(f"all_models={set(all_models)}")
+        print(f"excl+conf={exclusion_list.union(config_dict)}")
+        assert set(all_models) - exclusion_list.union(config_dict) == set()
+        assert exclusion_list.isdisjoint(config_dict)
+
+
+def create_tis(base_date, num_tis, external_trigger=False):
+    """
+    Create one uniquely named DagModel with ``num_tis`` DagRuns and one TaskInstance per run.
+
+    Run ``n`` (0-based) and its task instance both start at ``base_date + n days``.
+    """
+    with create_session() as session:
+        dag = DagModel(dag_id=f'test-dag_{uuid4()}')
+        session.add(dag)
+        for num in range(num_tis):
+            start_date = base_date.add(days=num)
+            dag_run = DagRun(
+                dag.dag_id,
+                run_id=f'abc_{num}',
+                run_type='none',
+                start_date=start_date,
+                external_trigger=external_trigger,
+            )
+            ti = TaskInstance(
+                PythonOperator(task_id='dummy-task', python_callable=print), run_id=dag_run.run_id
+            )
+            # The operator is not attached to a DAG, so set the TI's dag_id explicitly.
+            ti.dag_id = dag.dag_id
+            ti.start_date = start_date
+            session.add(dag_run)
+            session.add(ti)
+        session.commit()

Reply via email to