This is an automated email from the ASF dual-hosted git repository.
dstandish pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new c75774d Add `db clean` CLI command for purging old data (#20838)
c75774d is described below
commit c75774d3a31efe749f55ba16e782737df9f53af4
Author: Daniel Standish <[email protected]>
AuthorDate: Fri Feb 25 21:55:21 2022 -0800
Add `db clean` CLI command for purging old data (#20838)
CLI command to delete old rows from the Airflow metadata database.
Notes:
* Must supply "purge before date".
* Can optionally provide table list.
* Dry run will only print the number of rows meeting criteria.
* If not a dry run, the user must confirm before deletion proceeds (see the example invocation below).
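An illustrative invocation, based only on the arguments added in this commit (`--clean-before-timestamp`, `--tables`, `--dry-run`, `-y`); table names and timestamp are examples taken from the diff, and the exact console output may differ:

    # preview how many rows would be purged from two tables (no deletion)
    airflow db clean --clean-before-timestamp '2022-01-01 00:00:00+01:00' --tables 'log,xcom' --dry-run

    # perform the purge; prompts for confirmation unless -y/--yes is passed
    airflow db clean --clean-before-timestamp '2022-01-01 00:00:00+01:00' --tables 'log,xcom'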
---
airflow/cli/cli_parser.py | 45 ++++-
airflow/cli/commands/db_command.py | 17 ++
airflow/utils/db_cleanup.py | 315 ++++++++++++++++++++++++++++++++++
docs/spelling_wordlist.txt | 2 +
tests/cli/commands/test_db_command.py | 148 ++++++++++++++++
tests/utils/test_db_cleanup.py | 265 ++++++++++++++++++++++++++++
6 files changed, 791 insertions(+), 1 deletion(-)
diff --git a/airflow/cli/cli_parser.py b/airflow/cli/cli_parser.py
index c91ae7a..cb0bcd2 100644
--- a/airflow/cli/cli_parser.py
+++ b/airflow/cli/cli_parser.py
@@ -26,6 +26,8 @@ from argparse import Action, ArgumentError, RawTextHelpFormatter
from functools import lru_cache
from typing import Callable, Dict, Iterable, List, NamedTuple, Optional, Union
+import lazy_object_proxy
+
from airflow import PY37, settings
from airflow.cli.commands.legacy_commands import check_legacy_command
from airflow.configuration import conf
@@ -162,6 +164,11 @@ def positive_int(*, allow_zero):
return _check
+def string_list_type(val):
+    """Parses comma-separated list and returns list of string (strips whitespace)"""
+ return [x.strip() for x in val.split(',')]
+
+
# Shared
ARG_DAG_ID = Arg(("dag_id",), help="The id of the dag")
ARG_TASK_ID = Arg(("task_id",), help="The id of the task")
@@ -205,7 +212,7 @@ ARG_STDERR = Arg(("--stderr",), help="Redirect stderr to this file")
ARG_STDOUT = Arg(("--stdout",), help="Redirect stdout to this file")
ARG_LOG_FILE = Arg(("-l", "--log-file"), help="Location of the log file")
ARG_YES = Arg(
-    ("-y", "--yes"), help="Do not prompt to confirm reset. Use with care!", action="store_true", default=False
+    ("-y", "--yes"), help="Do not prompt to confirm. Use with care!", action="store_true", default=False
)
ARG_OUTPUT = Arg(
(
@@ -398,6 +405,30 @@ ARG_RUN_ID = Arg(("-r", "--run-id"), help="Helps to identify this run")
ARG_CONF = Arg(('-c', '--conf'), help="JSON string that gets pickled into the DagRun's conf attribute")
ARG_EXEC_DATE = Arg(("-e", "--exec-date"), help="The execution date of the DAG", type=parsedate)
+# db
+ARG_DB_TABLES = Arg(
+ ("-t", "--tables"),
+ help=lazy_object_proxy.Proxy(
+        lambda: f"Table names to perform maintenance on (use comma-separated list).\n"
+        f"Options: {import_string('airflow.cli.commands.db_command.all_tables')}"
+ ),
+ type=string_list_type,
+)
+ARG_DB_CLEANUP_TIMESTAMP = Arg(
+ ("--clean-before-timestamp",),
+ help="The date or timestamp before which data should be purged.\n"
+    "If no timezone info is supplied then dates are assumed to be in airflow default timezone.\n"
+ "Example: '2022-01-01 00:00:00+01:00'",
+ type=parsedate,
+ required=True,
+)
+ARG_DB_DRY_RUN = Arg(
+ ("--dry-run",),
+ help="Perform a dry run",
+ action="store_true",
+)
+
+
# pool
ARG_POOL_NAME = Arg(("pool",), metavar='NAME', help="Pool name")
ARG_POOL_SLOTS = Arg(("slots",), type=int, help="Pool slots")
@@ -1308,6 +1339,18 @@ DB_COMMANDS = (
func=lazy_load_command('airflow.cli.commands.db_command.check'),
args=(),
),
+ ActionCommand(
+ name='clean',
+ help="Purge old records in metastore tables",
+        func=lazy_load_command('airflow.cli.commands.db_command.cleanup_tables'),
+ args=(
+ ARG_DB_TABLES,
+ ARG_DB_DRY_RUN,
+ ARG_DB_CLEANUP_TIMESTAMP,
+ ARG_VERBOSE,
+ ARG_YES,
+ ),
+ ),
)
CONNECTIONS_COMMANDS = (
ActionCommand(
diff --git a/airflow/cli/commands/db_command.py b/airflow/cli/commands/db_command.py
index 09fe221..02811d0 100644
--- a/airflow/cli/commands/db_command.py
+++ b/airflow/cli/commands/db_command.py
@@ -22,6 +22,7 @@ from tempfile import NamedTemporaryFile
from airflow import settings
from airflow.exceptions import AirflowException
from airflow.utils import cli as cli_utils, db
+from airflow.utils.db_cleanup import config_dict, run_cleanup
from airflow.utils.process_utils import execute_interactive
@@ -101,3 +102,19 @@ def shell(args):
def check(_):
"""Runs a check command that checks if db is available."""
db.check()
+
+
+# lazily imported by CLI parser for `help` command
+all_tables = sorted(config_dict)
+
+
+@cli_utils.action_cli(check_db=False)
+def cleanup_tables(args):
+ """Purges old records in metadata database"""
+ run_cleanup(
+ table_names=args.tables,
+ dry_run=args.dry_run,
+ clean_before_timestamp=args.clean_before_timestamp,
+ verbose=args.verbose,
+ confirm=not args.yes,
+ )
diff --git a/airflow/utils/db_cleanup.py b/airflow/utils/db_cleanup.py
new file mode 100644
index 0000000..d1449bd
--- /dev/null
+++ b/airflow/utils/db_cleanup.py
@@ -0,0 +1,315 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""
+This module took inspiration from the community maintenance dag
+(https://github.com/teamclairvoyant/airflow-maintenance-dags/blob/4e5c7682a808082561d60cbc9cafaa477b0d8c65/db-cleanup/airflow-db-cleanup.py).
+"""
+
+import logging
+from contextlib import AbstractContextManager
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
+
+from pendulum import DateTime
+from sqlalchemy import and_, false, func
+from sqlalchemy.exc import OperationalError
+
+from airflow.cli.simple_table import AirflowConsole
+from airflow.jobs.base_job import BaseJob
+from airflow.models import (
+ Base,
+ DagModel,
+ DagRun,
+ ImportError,
+ Log,
+ RenderedTaskInstanceFields,
+ SensorInstance,
+ SlaMiss,
+ TaskFail,
+ TaskInstance,
+ TaskReschedule,
+ XCom,
+)
+from airflow.utils import timezone
+from airflow.utils.session import NEW_SESSION, provide_session
+
+if TYPE_CHECKING:
+ from sqlalchemy.orm import Query, Session
+ from sqlalchemy.orm.attributes import InstrumentedAttribute
+ from sqlalchemy.sql.schema import Column
+
+
+@dataclass
+class _TableConfig:
+ """
+ Config class for performing cleanup on a table
+
+ :param orm_model: the table
+ :param recency_column: date column to filter by
+    :param keep_last: whether the last record should be kept even if it's older than clean_before_timestamp
+    :param keep_last_filters: the "keep last" functionality will preserve the most recent record
+        in the table. to ignore certain records even if they are the latest in the table, you can
+ supply additional filters here (e.g. externally triggered dag runs)
+    :param keep_last_group_by: if keeping the last record, can keep the last record for each group
+    :param warn_if_missing: If True, then we'll suppress "table missing" exception and log a warning.
+ If False then the exception will go uncaught.
+ """
+
+ orm_model: Base
+ recency_column: Union["Column", "InstrumentedAttribute"]
+ keep_last: bool = False
+ keep_last_filters: Optional[Any] = None
+ keep_last_group_by: Optional[Any] = None
+ warn_if_missing: bool = False
+
+ def __lt__(self, other):
+ return self.orm_model.__tablename__ < other.orm_model.__tablename__
+
+ @property
+ def readable_config(self):
+ return dict(
+ table=self.orm_model.__tablename__,
+ recency_column=str(self.recency_column),
+ keep_last=self.keep_last,
+            keep_last_filters=[str(x) for x in self.keep_last_filters] if self.keep_last_filters else None,
+ keep_last_group_by=str(self.keep_last_group_by),
+ warn_if_missing=str(self.warn_if_missing),
+ )
+
+
+config_list: List[_TableConfig] = [
+ _TableConfig(orm_model=BaseJob, recency_column=BaseJob.latest_heartbeat),
+ _TableConfig(orm_model=DagModel, recency_column=DagModel.last_parsed_time),
+ _TableConfig(
+ orm_model=DagRun,
+ recency_column=DagRun.start_date,
+ keep_last=True,
+ keep_last_filters=[DagRun.external_trigger == false()],
+ keep_last_group_by=DagRun.dag_id,
+ ),
+ _TableConfig(orm_model=ImportError, recency_column=ImportError.timestamp),
+ _TableConfig(orm_model=Log, recency_column=Log.dttm),
+ _TableConfig(
+        orm_model=RenderedTaskInstanceFields, recency_column=RenderedTaskInstanceFields.execution_date
+ ),
+ _TableConfig(
+ orm_model=SensorInstance, recency_column=SensorInstance.updated_at
+ ), # TODO: add FK to task instance / dag so we can remove here
+ _TableConfig(orm_model=SlaMiss, recency_column=SlaMiss.timestamp),
+ _TableConfig(orm_model=TaskFail, recency_column=TaskFail.start_date),
+    _TableConfig(orm_model=TaskInstance, recency_column=TaskInstance.start_date),
+    _TableConfig(orm_model=TaskReschedule, recency_column=TaskReschedule.start_date),
+ _TableConfig(orm_model=XCom, recency_column=XCom.timestamp),
+]
+try:
+ from celery.backends.database.models import Task, TaskSet
+
+ config_list.extend(
+ [
+            _TableConfig(orm_model=Task, recency_column=Task.date_done, warn_if_missing=True),
+            _TableConfig(orm_model=TaskSet, recency_column=TaskSet.date_done, warn_if_missing=True),
+ ]
+ )
+except ImportError:
+ pass
+
+config_dict: Dict[str, _TableConfig] = {x.orm_model.__tablename__: x for x in sorted(config_list)}
+
+
+def _print_entities(*, query: "Query", print_rows=False):
+ num_entities = query.count()
+ print(f"Found {num_entities} rows meeting deletion criteria.")
+ if not print_rows:
+ return
+ max_rows_to_print = 100
+ if num_entities > 0:
+ print(f"Printing first {max_rows_to_print} rows.")
+ logger.debug("print entities query: %s", query)
+ for entry in query.limit(max_rows_to_print):
+ print(entry.__dict__)
+
+
+def _do_delete(*, query, session):
+ print("Performing Delete...")
+ # using bulk delete
+ query.delete(synchronize_session=False)
+ session.commit()
+ print("Finished Performing Delete")
+
+
+def _subquery_keep_last(*, recency_column, keep_last_filters, keep_last_group_by, session):
+ subquery = session.query(func.max(recency_column))
+
+ if keep_last_filters is not None:
+ for entry in keep_last_filters:
+ subquery = subquery.filter(entry)
+
+ if keep_last_group_by is not None:
+ subquery = subquery.group_by(keep_last_group_by)
+
+    # We nest this subquery to work around a MySQL "table specified twice" issue
+ # See https://github.com/teamclairvoyant/airflow-maintenance-dags/issues/41
+    # and https://github.com/teamclairvoyant/airflow-maintenance-dags/pull/57/files.
+ subquery = subquery.from_self()
+ return subquery
+
+
+def _build_query(
+ *,
+ orm_model,
+ recency_column,
+ keep_last,
+ keep_last_filters,
+ keep_last_group_by,
+ clean_before_timestamp,
+ session,
+ **kwargs,
+):
+ query = session.query(orm_model)
+ conditions = [recency_column < clean_before_timestamp]
+ if keep_last:
+ subquery = _subquery_keep_last(
+ recency_column=recency_column,
+ keep_last_filters=keep_last_filters,
+ keep_last_group_by=keep_last_group_by,
+ session=session,
+ )
+ conditions.append(recency_column.notin_(subquery))
+ query = query.filter(and_(*conditions))
+ return query
+
+
+logger = logging.getLogger(__file__)
+
+
+def _cleanup_table(
+ *,
+ orm_model,
+ recency_column,
+ keep_last,
+ keep_last_filters,
+ keep_last_group_by,
+ clean_before_timestamp,
+ dry_run=True,
+ verbose=False,
+ session=None,
+ **kwargs,
+):
+ print()
+ if dry_run:
+ print(f"Performing dry run for table {orm_model.__tablename__!r}")
+ query = _build_query(
+ orm_model=orm_model,
+ recency_column=recency_column,
+ keep_last=keep_last,
+ keep_last_filters=keep_last_filters,
+ keep_last_group_by=keep_last_group_by,
+ clean_before_timestamp=clean_before_timestamp,
+ session=session,
+ )
+
+ _print_entities(query=query, print_rows=False)
+
+ if not dry_run:
+ _do_delete(query=query, session=session)
+ session.commit()
+
+
+def _confirm_delete(*, date: DateTime, tables: List[str]):
+ for_tables = f" for tables {tables!r}" if tables else ''
+ question = (
+        f"You have requested that we purge all data prior to {date}{for_tables}.\n"
+        f"This is irreversible. Consider backing up the tables first and / or doing a dry run "
+ f"with option --dry-run.\n"
+ f"Enter 'delete rows' (without quotes) to proceed."
+ )
+ print(question)
+ answer = input().strip()
+ if not answer == 'delete rows':
+ raise SystemExit("User did not confirm; exiting.")
+
+
+def _print_config(*, configs: Dict[str, _TableConfig]):
+ data = [x.readable_config for x in configs.values()]
+ AirflowConsole().print_as_table(data=data)
+
+
+class _warn_if_missing(AbstractContextManager):
+ def __init__(self, table, suppress):
+ self.table = table
+ self.suppress = suppress
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, exctype, excinst, exctb):
+        caught_error = exctype is not None and issubclass(exctype, OperationalError)
+ if caught_error:
+ logger.warning("Table %r not found. Skipping.", self.table)
+ return caught_error
+
+
+@provide_session
+def run_cleanup(
+ *,
+ clean_before_timestamp: DateTime,
+ table_names: Optional[List[str]] = None,
+ dry_run: bool = False,
+ verbose: bool = False,
+ confirm: bool = True,
+ session: 'Session' = NEW_SESSION,
+):
+ """
+ Purges old records in airflow metadata database.
+
+    The last non-externally-triggered dag run will always be kept in order to ensure
+ continuity of scheduled dag runs.
+
+    Where there are foreign key relationships, deletes will cascade, so that for
+ example if you clean up old dag runs, the associated task instances will
+ be deleted.
+
+    :param clean_before_timestamp: The timestamp before which data should be purged
+    :param table_names: Optional. List of table names to perform maintenance on. If list not provided,
+ will perform maintenance on all tables.
+ :param dry_run: If true, print rows meeting deletion criteria
+ :param verbose: If true, may provide more detailed output.
+ :param confirm: Require user input to confirm before processing deletions.
+ :param session: Session representing connection to the metadata database.
+ """
+ clean_before_timestamp = timezone.coerce_datetime(clean_before_timestamp)
+    effective_table_names = table_names if table_names else list(config_dict.keys())
+    effective_config_dict = {k: v for k, v in config_dict.items() if k in effective_table_names}
+ if dry_run:
+ print('Performing dry run for db cleanup.')
+ print(
+ f"Data prior to {clean_before_timestamp} would be purged "
+ f"from tables {effective_table_names} with the following config:\n"
+ )
+ _print_config(configs=effective_config_dict)
+ if not dry_run and confirm:
+        _confirm_delete(date=clean_before_timestamp, tables=list(effective_config_dict.keys()))
+ for table_name, table_config in effective_config_dict.items():
+ with _warn_if_missing(table_name, table_config.warn_if_missing):
+ _cleanup_table(
+ clean_before_timestamp=clean_before_timestamp,
+ dry_run=dry_run,
+ verbose=verbose,
+ **table_config.__dict__,
+ session=session,
+ )
diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt
index 68ecca0..efae25a 100644
--- a/docs/spelling_wordlist.txt
+++ b/docs/spelling_wordlist.txt
@@ -1375,7 +1375,9 @@ tagValue
task_group
taskflow
taskinstance
+taskmeta
taskmixin
+tasksetmeta
tblproperties
tcp
teardown
diff --git a/tests/cli/commands/test_db_command.py b/tests/cli/commands/test_db_command.py
index e94e63a..09cb14a 100644
--- a/tests/cli/commands/test_db_command.py
+++ b/tests/cli/commands/test_db_command.py
@@ -17,7 +17,9 @@
import unittest
from unittest import mock
+from unittest.mock import patch
+import pendulum
import pytest
from sqlalchemy.engine.url import make_url
@@ -134,3 +136,149 @@ class TestCliDb(unittest.TestCase):
def test_cli_shell_invalid(self):
        with pytest.raises(AirflowException, match=r"Unknown driver: invalid\+psycopg2"):
db_command.shell(self.parser.parse_args(['db', 'shell']))
+
+
+class TestCLIDBClean:
+ @classmethod
+ def setup_class(cls):
+ cls.parser = cli_parser.get_parser()
+
+    @pytest.mark.parametrize('timezone', ['UTC', 'Europe/Berlin', 'America/Los_Angeles'])
+ @patch('airflow.cli.commands.db_command.run_cleanup')
+ def test_date_timezone_omitted(self, run_cleanup_mock, timezone):
+ """
+ When timezone omitted we should always expect that the timestamp is
+ coerced to tz-aware with default timezone
+ """
+ timestamp = '2021-01-01 00:00:00'
+        with patch('airflow.utils.timezone.TIMEZONE', pendulum.timezone(timezone)):
+            args = self.parser.parse_args(['db', 'clean', '--clean-before-timestamp', f"{timestamp}", '-y'])
+ db_command.cleanup_tables(args)
+ run_cleanup_mock.assert_called_once_with(
+ table_names=None,
+ dry_run=False,
+ clean_before_timestamp=pendulum.parse(timestamp, tz=timezone),
+ verbose=False,
+ confirm=False,
+ )
+
+    @pytest.mark.parametrize('timezone', ['UTC', 'Europe/Berlin', 'America/Los_Angeles'])
+ @patch('airflow.cli.commands.db_command.run_cleanup')
+ def test_date_timezone_supplied(self, run_cleanup_mock, timezone):
+ """
+        When tz included in the string then default timezone should not be used.
+ """
+ timestamp = '2021-01-01 00:00:00+03:00'
+        with patch('airflow.utils.timezone.TIMEZONE', pendulum.timezone(timezone)):
+            args = self.parser.parse_args(['db', 'clean', '--clean-before-timestamp', f"{timestamp}", '-y'])
+ db_command.cleanup_tables(args)
+
+ run_cleanup_mock.assert_called_once_with(
+ table_names=None,
+ dry_run=False,
+ clean_before_timestamp=pendulum.parse(timestamp),
+ verbose=False,
+ confirm=False,
+ )
+
+    @pytest.mark.parametrize('confirm_arg, expected', [(['-y'], False), ([], True)])
+ @patch('airflow.cli.commands.db_command.run_cleanup')
+ def test_confirm(self, run_cleanup_mock, confirm_arg, expected):
+ """
+        When tz included in the string then default timezone should not be used.
+ """
+ args = self.parser.parse_args(
+ [
+ 'db',
+ 'clean',
+ '--clean-before-timestamp',
+ '2021-01-01',
+ *confirm_arg,
+ ]
+ )
+ db_command.cleanup_tables(args)
+
+ run_cleanup_mock.assert_called_once_with(
+ table_names=None,
+ dry_run=False,
+ clean_before_timestamp=pendulum.parse('2021-01-01 00:00:00Z'),
+ verbose=False,
+ confirm=expected,
+ )
+
+    @pytest.mark.parametrize('dry_run_arg, expected', [(['--dry-run'], True), ([], False)])
+ @patch('airflow.cli.commands.db_command.run_cleanup')
+ def test_dry_run(self, run_cleanup_mock, dry_run_arg, expected):
+ """
+        When tz included in the string then default timezone should not be used.
+ """
+ args = self.parser.parse_args(
+ [
+ 'db',
+ 'clean',
+ '--clean-before-timestamp',
+ '2021-01-01',
+ *dry_run_arg,
+ ]
+ )
+ db_command.cleanup_tables(args)
+
+ run_cleanup_mock.assert_called_once_with(
+ table_names=None,
+ dry_run=expected,
+ clean_before_timestamp=pendulum.parse('2021-01-01 00:00:00Z'),
+ verbose=False,
+ confirm=True,
+ )
+
+ @pytest.mark.parametrize(
+        'extra_args, expected', [(['--tables', 'hello, goodbye'], ['hello', 'goodbye']), ([], None)]
+ )
+ @patch('airflow.cli.commands.db_command.run_cleanup')
+ def test_tables(self, run_cleanup_mock, extra_args, expected):
+ """
+        When tz included in the string then default timezone should not be used.
+ """
+ args = self.parser.parse_args(
+ [
+ 'db',
+ 'clean',
+ '--clean-before-timestamp',
+ '2021-01-01',
+ *extra_args,
+ ]
+ )
+ db_command.cleanup_tables(args)
+
+ run_cleanup_mock.assert_called_once_with(
+ table_names=expected,
+ dry_run=False,
+ clean_before_timestamp=pendulum.parse('2021-01-01 00:00:00Z'),
+ verbose=False,
+ confirm=True,
+ )
+
+    @pytest.mark.parametrize('extra_args, expected', [(['--verbose'], True), ([], False)])
+ @patch('airflow.cli.commands.db_command.run_cleanup')
+ def test_verbose(self, run_cleanup_mock, extra_args, expected):
+ """
+        When tz included in the string then default timezone should not be used.
+ """
+ args = self.parser.parse_args(
+ [
+ 'db',
+ 'clean',
+ '--clean-before-timestamp',
+ '2021-01-01',
+ *extra_args,
+ ]
+ )
+ db_command.cleanup_tables(args)
+
+ run_cleanup_mock.assert_called_once_with(
+ table_names=None,
+ dry_run=False,
+ clean_before_timestamp=pendulum.parse('2021-01-01 00:00:00Z'),
+ verbose=expected,
+ confirm=True,
+ )
diff --git a/tests/utils/test_db_cleanup.py b/tests/utils/test_db_cleanup.py
new file mode 100644
index 0000000..8ef80eb
--- /dev/null
+++ b/tests/utils/test_db_cleanup.py
@@ -0,0 +1,265 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from contextlib import suppress
+from importlib import import_module
+from pathlib import Path
+from unittest.mock import MagicMock, patch
+from uuid import uuid4
+
+import pendulum
+import pytest
+from pytest import param
+from sqlalchemy.ext.declarative import DeclarativeMeta
+
+from airflow.models import DagModel, DagRun, TaskInstance
+from airflow.operators.python import PythonOperator
+from airflow.utils.db_cleanup import _build_query, _cleanup_table, config_dict, run_cleanup
+from airflow.utils.session import create_session
+from tests.test_utils.db import clear_db_dags, clear_db_runs
+
+
[email protected](autouse=True)
+def clean_database():
+ """Fixture that cleans the database before and after every test."""
+ clear_db_runs()
+ clear_db_dags()
+ yield # Test runs here
+ clear_db_dags()
+ clear_db_runs()
+
+
+class TestDBCleanup:
+ @pytest.mark.parametrize(
+ 'kwargs, called',
+ [
+ param(dict(confirm=True), True, id='true'),
+ param(dict(), True, id='not supplied'),
+ param(dict(confirm=False), False, id='false'),
+ ],
+ )
+ @patch('airflow.utils.db_cleanup._cleanup_table', new=MagicMock())
+ @patch('airflow.utils.db_cleanup._confirm_delete')
+ def test_run_cleanup_confirm(self, confirm_delete_mock, kwargs, called):
+ """test that delete confirmation input is called when appropriate"""
+ run_cleanup(
+ clean_before_timestamp=None,
+ table_names=None,
+ dry_run=None,
+ verbose=None,
+ **kwargs,
+ )
+ if called:
+ confirm_delete_mock.assert_called()
+ else:
+ confirm_delete_mock.assert_not_called()
+
+ @pytest.mark.parametrize(
+ 'table_names',
+ [
+ ['xcom', 'log'],
+ None,
+ ],
+ )
+ @patch('airflow.utils.db_cleanup._cleanup_table')
+ @patch('airflow.utils.db_cleanup._confirm_delete', new=MagicMock())
+ def test_run_cleanup_tables(self, clean_table_mock, table_names):
+ """
+ ``_cleanup_table`` should be called for each table in subset if one
+ is provided else should be called for all tables.
+ """
+ base_kwargs = dict(
+ clean_before_timestamp=None,
+ dry_run=None,
+ verbose=None,
+ )
+ run_cleanup(**base_kwargs, table_names=table_names)
+        assert clean_table_mock.call_count == len(table_names) if table_names else len(config_dict)
+
+ @pytest.mark.parametrize(
+ 'dry_run',
+ [None, True, False],
+ )
+ @patch('airflow.utils.db_cleanup._build_query', MagicMock())
+ @patch('airflow.utils.db_cleanup._print_entities', MagicMock())
+ @patch('airflow.utils.db_cleanup._do_delete')
+ @patch('airflow.utils.db_cleanup._confirm_delete', MagicMock())
+ def test_run_cleanup_dry_run(self, do_delete, dry_run):
+ """Delete should only be called when not dry_run"""
+ base_kwargs = dict(
+ clean_before_timestamp=None,
+ dry_run=dry_run,
+ verbose=None,
+ )
+ run_cleanup(
+ **base_kwargs,
+ )
+ if dry_run:
+ do_delete.assert_not_called()
+ else:
+ do_delete.assert_called()
+
+ @pytest.mark.parametrize(
+ 'table_name, date_add_kwargs, expected_to_delete, external_trigger',
+ [
+ param('task_instance', dict(days=0), 0, False, id='beginning'),
+ param('task_instance', dict(days=4), 4, False, id='middle'),
+ param('task_instance', dict(days=9), 9, False, id='end_exactly'),
+            param('task_instance', dict(days=9, microseconds=1), 10, False, id='beyond_end'),
+            param('dag_run', dict(days=9, microseconds=1), 9, False, id='beyond_end_dr'),
+            param('dag_run', dict(days=9, microseconds=1), 10, True, id='beyond_end_dr_external'),
+ ],
+ )
+    def test__build_query(self, table_name, date_add_kwargs, expected_to_delete, external_trigger):
+ """
+        Verify that ``_build_query`` produces a query that would delete the right
+        task instance records depending on the value of ``clean_before_timestamp``.
+
+ DagRun is a special case where we always keep the last dag run even if
+ the ``clean_before_timestamp`` is in the future, except for
+        externally-triggered dag runs. That is, only the last non-externally-triggered
+ dag run is kept.
+
+ """
+        base_date = pendulum.DateTime(2022, 1, 1, tzinfo=pendulum.timezone('America/Los_Angeles'))
+ create_tis(
+ base_date=base_date,
+ num_tis=10,
+ external_trigger=external_trigger,
+ )
+ with create_session() as session:
+ clean_before_date = base_date.add(**date_add_kwargs)
+ query = _build_query(
+ **config_dict[table_name].__dict__,
+ clean_before_timestamp=clean_before_date,
+ session=session,
+ )
+ assert len(query.all()) == expected_to_delete
+
+ @pytest.mark.parametrize(
+ 'table_name, date_add_kwargs, expected_to_delete, external_trigger',
+ [
+ param('task_instance', dict(days=0), 0, False, id='beginning'),
+ param('task_instance', dict(days=4), 4, False, id='middle'),
+ param('task_instance', dict(days=9), 9, False, id='end_exactly'),
+            param('task_instance', dict(days=9, microseconds=1), 10, False, id='beyond_end'),
+            param('dag_run', dict(days=9, microseconds=1), 9, False, id='beyond_end_dr'),
+            param('dag_run', dict(days=9, microseconds=1), 10, True, id='beyond_end_dr_external'),
+ ],
+ )
+    def test__cleanup_table(self, table_name, date_add_kwargs, expected_to_delete, external_trigger):
+ """
+ Verify that _cleanup_table actually deletes the rows it should.
+
+        TaskInstance represents the "normal" case. DagRun is the odd case where we want
+        to keep the last non-externally-triggered DagRun record even if it should be
+ deleted according to the provided timestamp.
+
+        We also verify that the "on delete cascade" behavior is as expected. Some tables
+        have foreign keys defined so for example if we delete a dag run, all its associated
+        task instances should be purged as well. But if we delete task instances the
+ associated dag runs should remain.
+
+ """
+        base_date = pendulum.DateTime(2022, 1, 1, tzinfo=pendulum.timezone('America/Los_Angeles'))
+ num_tis = 10
+ create_tis(
+ base_date=base_date,
+ num_tis=num_tis,
+ external_trigger=external_trigger,
+ )
+ with create_session() as session:
+ clean_before_date = base_date.add(**date_add_kwargs)
+ _cleanup_table(
+ **config_dict[table_name].__dict__,
+ clean_before_timestamp=clean_before_date,
+ dry_run=False,
+ session=session,
+ )
+ model = config_dict[table_name].orm_model
+ expected_remaining = num_tis - expected_to_delete
+ assert len(session.query(model).all()) == expected_remaining
+ if model == TaskInstance:
+ assert len(session.query(DagRun).all()) == num_tis
+ elif model == DagRun:
+                assert len(session.query(TaskInstance).all()) == expected_remaining
+ else:
+ raise Exception("unexpected")
+
+ def test_no_models_missing(self):
+ """
+        1. Verify that for all tables in `airflow.models`, we either have them enabled in db cleanup,
+ or documented in the exclusion list in this test.
+        2. Verify that no table is enabled for db cleanup and also in exclusion list.
+ """
+ import pkgutil
+
+ proj_root = Path(__file__).parent.parent.parent
+ mods = list(
+            f"airflow.models.{name}" for _, name, _ in pkgutil.iter_modules([proj_root / 'airflow/models'])
+ )
+
+ all_models = {}
+ for mod_name in mods:
+ mod = import_module(mod_name)
+
+ for table_name, class_ in mod.__dict__.items():
+ if isinstance(class_, DeclarativeMeta):
+ with suppress(AttributeError):
+ all_models.update({class_.__tablename__: class_})
+ exclusion_list = {
+ 'variable', # leave alone
+ 'trigger', # self-maintaining
+            'task_map',  # TODO: add datetime column to TaskMap so we can include it here
+ 'serialized_dag', # handled through FK to Dag
+            'log_template',  # not a significant source of data; age not indicative of staleness
+            'dag_tag',  # not a significant source of data; age not indicative of staleness,
+ 'dag_pickle', # unsure of consequences
+ 'dag_code', # self-maintaining
+ 'connection', # leave alone
+ 'slot_pool', # leave alone
+ }
+
+ from airflow.utils.db_cleanup import config_dict
+
+ print(f"all_models={set(all_models)}")
+ print(f"excl+conf={exclusion_list.union(config_dict)}")
+ assert set(all_models) - exclusion_list.union(config_dict) == set()
+ assert exclusion_list.isdisjoint(config_dict)
+
+
+def create_tis(base_date, num_tis, external_trigger=False):
+ with create_session() as session:
+ dag = DagModel(dag_id=f'test-dag_{uuid4()}')
+ session.add(dag)
+ for num in range(num_tis):
+ start_date = base_date.add(days=num)
+ dag_run = DagRun(
+ dag.dag_id,
+ run_id=f'abc_{num}',
+ run_type='none',
+ start_date=start_date,
+ external_trigger=external_trigger,
+ )
+ ti = TaskInstance(
+                PythonOperator(task_id='dummy-task', python_callable=print), run_id=dag_run.run_id
+ )
+ ti.dag_id = dag.dag_id
+ ti.start_date = start_date
+ session.add(dag_run)
+ session.add(ti)
+ session.commit()