This is an automated email from the ASF dual-hosted git repository.
uranusjr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 56a3987c3f Remove old pre-migration data integrity checks (#41850)
56a3987c3f is described below
commit 56a3987c3fca3a574b8f1c7faa8f08de2becee60
Author: Tzu-ping Chung <[email protected]>
AuthorDate: Thu Aug 29 17:20:41 2024 +0800
Remove old pre-migration data integrity checks (#41850)
---
airflow/utils/db.py | 495 +------------------------------------------------
tests/utils/test_db.py | 84 +--------
2 files changed, 3 insertions(+), 576 deletions(-)
diff --git a/airflow/utils/db.py b/airflow/utils/db.py
index 5e15e59d3a..0719fe3b9f 100644
--- a/airflow/utils/db.py
+++ b/airflow/utils/db.py
@@ -27,7 +27,6 @@ import os
import sys
import time
import warnings
-from dataclasses import dataclass
from tempfile import gettempdir
from typing import (
TYPE_CHECKING,
@@ -45,8 +44,6 @@ from typing import (
import attrs
from sqlalchemy import (
Table,
- and_,
- column,
delete,
exc,
func,
@@ -54,9 +51,7 @@ from sqlalchemy import (
literal,
or_,
select,
- table,
text,
- tuple_,
)
import airflow
@@ -75,7 +70,7 @@ if TYPE_CHECKING:
from alembic.runtime.environment import EnvironmentContext
from alembic.script import ScriptDirectory
from sqlalchemy.engine import Row
- from sqlalchemy.orm import Query, Session
+ from sqlalchemy.orm import Session
from sqlalchemy.sql.elements import ClauseElement, TextClause
from sqlalchemy.sql.selectable import Select
@@ -1001,60 +996,6 @@ def synchronize_log_template(*, session: Session = NEW_SESSION) -> None:
session.add(LogTemplate(filename=filename, elasticsearch_id=elasticsearch_id))
-def check_conn_id_duplicates(session: Session) -> Iterable[str]:
- """
- Check unique conn_id in connection table.
-
- :param session: SQLAlchemy session
- """
- from airflow.models.connection import Connection
-
- try:
- dups = session.scalars(
- select(Connection.conn_id).group_by(Connection.conn_id).having(func.count() > 1)
- ).all()
- except (exc.OperationalError, exc.ProgrammingError):
- # fallback if tables haven't been created yet
- session.rollback()
- return
- if dups:
- yield (
- "Seems you have non unique conn_id in connection table.\n"
- "You have to manage those duplicate connections "
- "before upgrading the database.\n"
- f"Duplicated conn_id: {dups}"
- )
-
-
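The removed checks above and below share one duplicate-detection shape: GROUP BY the candidate key and keep groups with COUNT(*) > 1. A minimal standalone sketch of the pattern, using an illustrative table rather than Airflow's models:

    from sqlalchemy import column, func, select, table

    conn = table("connection", column("conn_id"))  # illustrative stand-in

    def find_duplicate_values(session, tbl, col):
        # One result row per value that appears more than once in tbl.
        stmt = select(col).select_from(tbl).group_by(col).having(func.count() > 1)
        return session.scalars(stmt).all()

    # e.g. find_duplicate_values(session, conn, conn.c.conn_id)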
-def check_username_duplicates(session: Session) -> Iterable[str]:
- """
- Check unique username in User & RegisterUser table.
-
- :param session: SQLAlchemy session
- :rtype: Iterable[str]
- """
- from airflow.providers.fab.auth_manager.models import RegisterUser, User
-
- for model in [User, RegisterUser]:
- dups = []
- try:
- dups = session.execute(
- select(model.username) # type: ignore[attr-defined]
- .group_by(model.username) # type: ignore[attr-defined]
- .having(func.count() > 1)
- ).all()
- except (exc.OperationalError, exc.ProgrammingError):
- # fallback if tables haven't been created yet
- session.rollback()
- if dups:
- yield (
- f"Seems you have mixed case usernames in
{model.__table__.name} table.\n" # type: ignore
- "You have to rename or delete those mixed case usernames "
- "before upgrading the database.\n"
- f"usernames with mixed cases: {[dup.username for dup in dups]}"
- )
-
-
def reflect_tables(tables: list[MappedClassProtocol | str] | None, session):
"""
When running checks prior to upgrades, we use reflection to determine
current state of the database.
@@ -1079,442 +1020,10 @@ def reflect_tables(tables: list[MappedClassProtocol | str] | None, session):
return metadata
-def check_table_for_duplicates(
- *, session: Session, table_name: str, uniqueness: list[str], version: str
-) -> Iterable[str]:
- """
- Check table for duplicates, given a list of columns which define the uniqueness of the table.
-
- Usage example:
-
- .. code-block:: python
-
- def check_task_fail_for_duplicates(session):
- from airflow.models.taskfail import TaskFail
-
- metadata = reflect_tables([TaskFail], session)
- task_fail = metadata.tables.get(TaskFail.__tablename__)  # type: ignore
- if task_fail is None: # table not there
- return
- if "run_id" in task_fail.columns: # upgrade already applied
- return
- yield from check_table_for_duplicates(
- table_name=task_fail.name,
- uniqueness=["dag_id", "task_id", "execution_date"],
- session=session,
- version="2.3",
- )
-
- :param table_name: table name to check
- :param uniqueness: uniqueness constraint to evaluate against
- :param session: SQLAlchemy session
- """
- minimal_table_obj = table(table_name, *(column(x) for x in uniqueness))
- try:
- subquery = session.execute(
- select(minimal_table_obj, func.count().label("dupe_count"))
- .group_by(*(text(x) for x in uniqueness))
- .having(func.count() > text("1"))
- .subquery()
- )
- dupe_count = session.scalar(select(func.sum(subquery.c.dupe_count)))
- if not dupe_count:
- # there are no duplicates; nothing to do.
- return
-
- log.warning("Found %s duplicates in table %s. Will attempt to move
them.", dupe_count, table_name)
-
- metadata = reflect_tables(tables=[table_name], session=session)
- if table_name not in metadata.tables:
- yield f"Table {table_name} does not exist in the database."
-
- # We can't use the model here since it may differ from the db state because
- # this function is run prior to migration. Use the reflected table instead.
- table_obj = metadata.tables[table_name]
-
- _move_duplicate_data_to_new_table(
- session=session,
- source_table=table_obj,
- subquery=subquery,
- uniqueness=uniqueness,
- target_table_name=_format_airflow_moved_table_name(table_name, version, "duplicates"),
- )
- except (exc.OperationalError, exc.ProgrammingError):
- # fallback if `table_name` hasn't been created yet
- session.rollback()
-
-
-def check_conn_type_null(session: Session) -> Iterable[str]:
- """
- Check nullable conn_type column in Connection table.
-
- :param session: SQLAlchemy session
- """
- from airflow.models.connection import Connection
-
- try:
- n_nulls = session.scalars(select(Connection.conn_id).where(Connection.conn_type.is_(None))).all()
- except (exc.OperationalError, exc.ProgrammingError, exc.InternalError):
- # fallback if tables haven't been created yet
- session.rollback()
- return
-
- if n_nulls:
- yield (
- "The conn_type column in the connection "
- "table must contain content.\n"
- "Make sure you don't have null "
- "in the conn_type column.\n"
- f"Null conn_type conn_id: {n_nulls}"
- )
-
-
-def _format_dangling_error(source_table, target_table, invalid_count, reason):
- noun = "row" if invalid_count == 1 else "rows"
- return (
- f"The {source_table} table has {invalid_count} {noun} {reason}, which "
- f"is invalid. We could not move them out of the way because the "
- f"{target_table} table already exists in your database. Please either "
- f"drop the {target_table} table, or manually delete the invalid rows "
- f"from the {source_table} table."
- )
-
-
-def check_run_id_null(session: Session) -> Iterable[str]:
- from airflow.models.dagrun import DagRun
-
- metadata = reflect_tables([DagRun], session)
-
- # We can't use the model here since it may differ from the db state because
- # this function is run prior to migration. Use the reflected table instead.
- dagrun_table = metadata.tables.get(DagRun.__tablename__)
- if dagrun_table is None:
- return
-
- invalid_dagrun_filter = or_(
- dagrun_table.c.dag_id.is_(None),
- dagrun_table.c.run_id.is_(None),
- dagrun_table.c.execution_date.is_(None),
- )
- invalid_dagrun_count = session.scalar(select(func.count(dagrun_table.c.id)).where(invalid_dagrun_filter))
- if invalid_dagrun_count > 0:
- dagrun_dangling_table_name = _format_airflow_moved_table_name(dagrun_table.name, "2.2", "dangling")
- if dagrun_dangling_table_name in inspect(session.get_bind()).get_table_names():
- yield _format_dangling_error(
- source_table=dagrun_table.name,
- target_table=dagrun_dangling_table_name,
- invalid_count=invalid_dagrun_count,
- reason="with a NULL dag_id, run_id, or execution_date",
- )
- return
-
- bind = session.get_bind()
- dialect_name = bind.dialect.name
- _create_table_as(
- dialect_name=dialect_name,
- source_query=dagrun_table.select(invalid_dagrun_filter),
- target_table_name=dagrun_dangling_table_name,
- source_table_name=dagrun_table.name,
- session=session,
- )
- delete = dagrun_table.delete().where(invalid_dagrun_filter)
- session.execute(delete)
-
-
-def _create_table_as(
- *,
- session,
- dialect_name: str,
- source_query: Query,
- target_table_name: str,
- source_table_name: str,
-):
- """
- Create a new table with rows from query.
-
- We have to handle CTAS differently for different dialects.
- """
- if dialect_name == "mysql":
- # MySQL with replication needs this split in to two queries, so just
do it for all MySQL
- # ERROR 1786 (HY000): Statement violates GTID consistency: CREATE
TABLE ... SELECT.
- session.execute(text(f"CREATE TABLE {target_table_name} LIKE
{source_table_name}"))
- session.execute(
- text(
- f"INSERT INTO {target_table_name}
{source_query.selectable.compile(bind=session.get_bind())}"
- )
- )
- else:
- # Postgres and SQLite both support the same "CREATE TABLE a AS SELECT
..." syntax
- select_table = source_query.selectable.compile(bind=session.get_bind())
- session.execute(text(f"CREATE TABLE {target_table_name} AS
{select_table}"))
-
-
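For reference, the dialect split in the removed helper reduces to two SQL shapes; a hedged sketch with placeholder names src/dst:

    from sqlalchemy import text

    def create_table_as_sketch(session, dialect_name: str, select_sql: str):
        if dialect_name == "mysql":
            # GTID-replicated MySQL rejects single-statement CTAS, so copy the
            # structure first, then insert the rows as a separate statement.
            session.execute(text("CREATE TABLE dst LIKE src"))
            session.execute(text(f"INSERT INTO dst {select_sql}"))
        else:
            # Postgres and SQLite accept CREATE TABLE ... AS SELECT directly.
            session.execute(text(f"CREATE TABLE dst AS {select_sql}"))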
-def _move_dangling_data_to_new_table(
- session, source_table: Table, source_query: Query, target_table_name: str
-):
- bind = session.get_bind()
- dialect_name = bind.dialect.name
-
- # First: Create moved rows from new table
- log.debug("running CTAS for table %s", target_table_name)
- _create_table_as(
- dialect_name=dialect_name,
- source_query=source_query,
- target_table_name=target_table_name,
- source_table_name=source_table.name,
- session=session,
- )
- session.commit()
-
- target_table = source_table.to_metadata(source_table.metadata, name=target_table_name)
- log.debug("checking whether rows were moved for table %s", target_table_name)
- moved_rows_exist_query = select(1).select_from(target_table).limit(1)
- first_moved_row = session.execute(moved_rows_exist_query).all()
- session.commit()
-
- if not first_moved_row:
- log.debug("no rows moved; dropping %s", target_table_name)
- # no bad rows were found; drop moved rows table.
- target_table.drop(bind=session.get_bind(), checkfirst=True)
- else:
- log.debug("rows moved; purging from %s", source_table.name)
- if dialect_name == "sqlite":
- pk_cols = source_table.primary_key.columns
-
- delete = source_table.delete().where(
- tuple_(*pk_cols).in_(session.select(*target_table.primary_key.columns).subquery())
- )
- else:
- delete = source_table.delete().where(
- and_(col == target_table.c[col.name] for col in source_table.primary_key.columns)
- )
- log.debug(delete.compile())
- session.execute(delete)
- session.commit()
-
- log.debug("exiting move function")
-
-
-def _dangling_against_dag_run(session, source_table, dag_run):
- """Given a source table, we generate a subquery that will return 1 for
every row that has a dagrun."""
- source_to_dag_run_join_cond = and_(
- source_table.c.dag_id == dag_run.c.dag_id,
- source_table.c.execution_date == dag_run.c.execution_date,
- )
-
- return (
- select(*(c.label(c.name) for c in source_table.c))
- .join(dag_run, source_to_dag_run_join_cond, isouter=True)
- .where(dag_run.c.dag_id.is_(None))
- )
-
-
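The join above is a standard anti-join: LEFT OUTER JOIN the referenced table and keep only the rows where the join found nothing. A self-contained sketch with hypothetical parent/child tables:

    from sqlalchemy import and_, column, select, table

    child = table("child", column("dag_id"), column("execution_date"))
    parent = table("parent", column("dag_id"), column("execution_date"))

    # Child rows with no matching parent row survive the IS NULL filter.
    dangling = (
        select(*(c.label(c.name) for c in child.columns))
        .join(
            parent,
            and_(
                child.c.dag_id == parent.c.dag_id,
                child.c.execution_date == parent.c.execution_date,
            ),
            isouter=True,
        )
        .where(parent.c.dag_id.is_(None))
    )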
-def _dangling_against_task_instance(session, source_table, dag_run, task_instance):
- """
- Given a source table, generate a query that returns every row lacking a valid task instance.
-
- This is used to identify rows that need to be removed from tables prior to adding a TI fk.
-
- Since this check is applied prior to running the migrations, we have to use different
- query logic depending on which revision the database is at.
-
- """
- if "run_id" not in task_instance.c:
- # db is < 2.2.0
- dr_join_cond = and_(
- source_table.c.dag_id == dag_run.c.dag_id,
- source_table.c.execution_date == dag_run.c.execution_date,
- )
- ti_join_cond = and_(
- dag_run.c.dag_id == task_instance.c.dag_id,
- dag_run.c.execution_date == task_instance.c.execution_date,
- source_table.c.task_id == task_instance.c.task_id,
- )
- else:
- # db is 2.2.0 <= version < 2.3.0
- dr_join_cond = and_(
- source_table.c.dag_id == dag_run.c.dag_id,
- source_table.c.execution_date == dag_run.c.execution_date,
- )
- ti_join_cond = and_(
- dag_run.c.dag_id == task_instance.c.dag_id,
- dag_run.c.run_id == task_instance.c.run_id,
- source_table.c.task_id == task_instance.c.task_id,
- )
-
- return (
- select(*(c.label(c.name) for c in source_table.c))
- .outerjoin(dag_run, dr_join_cond)
- .outerjoin(task_instance, ti_join_cond)
- .where(or_(task_instance.c.dag_id.is_(None), dag_run.c.dag_id.is_(None)))
- )
-
-
-def _move_duplicate_data_to_new_table(
- session, source_table: Table, subquery: Query, uniqueness: list[str], target_table_name: str
-):
- """
- When adding a uniqueness constraint we should first ensure that there are no duplicate rows.
-
- This function accepts a subquery that should return one record for each row with duplicates (e.g.
- a group by with having count(*) > 1). We select from ``source_table`` getting all rows matching the
- subquery result and store in ``target_table_name``. Then to purge the duplicates from the source table,
- we do a DELETE FROM with a join to the target table (which now contains the dupes).
-
- :param session: sqlalchemy session for metadata db
- :param source_table: table to purge dupes from
- :param subquery: the subquery that returns the duplicate rows
- :param uniqueness: the string list of columns used to define the uniqueness for the table. Used in
- building the DELETE FROM join condition.
- :param target_table_name: name of the table in which to park the duplicate rows
- """
- bind = session.get_bind()
- dialect_name = bind.dialect.name
-
- query = (
- select(*(source_table.c[x.name].label(str(x.name)) for x in source_table.columns))
- .select_from(source_table)
- .join(subquery, and_(*(source_table.c[x] == subquery.c[x] for x in uniqueness)))
- )
-
- _create_table_as(
- session=session,
- dialect_name=dialect_name,
- source_query=query,
- target_table_name=target_table_name,
- source_table_name=source_table.name,
- )
-
- # we must ensure that the CTAS table is created prior to the DELETE step since we have to join to it
- session.commit()
-
- metadata = reflect_tables([target_table_name], session)
- target_table = metadata.tables[target_table_name]
- where_clause = and_(*(source_table.c[x] == target_table.c[x] for x in uniqueness))
-
- if dialect_name == "sqlite":
- subq = query.selectable.with_only_columns([text(f"{source_table}.ROWID")])
- delete = source_table.delete().where(column("ROWID").in_(subq))
- else:
- delete = source_table.delete(where_clause)
-
- session.execute(delete)
-
-
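Condensed, the flow above is: find duplicated key tuples, park the matching rows in a side table, then delete them from the source with a join. A sketch over hypothetical src/dst table objects and key names (the real helper copies via the CTAS helper above and special-cases SQLite's ROWID):

    from sqlalchemy import and_, func, select

    def duplicate_move_statements(src, dst, keys):
        # 1. One subquery row per duplicated key tuple.
        dupes = (
            select(*(src.c[k] for k in keys))
            .group_by(*(src.c[k] for k in keys))
            .having(func.count() > 1)
            .subquery()
        )
        # 2. Every offending row, destined for dst via CTAS.
        offenders = select(src).join(dupes, and_(*(src.c[k] == dupes.c[k] for k in keys)))
        # 3. Delete-with-join purging the parked rows from src.
        purge = src.delete().where(and_(*(src.c[k] == dst.c[k] for k in keys)))
        return offenders, purge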
-def check_bad_references(session: Session) -> Iterable[str]:
- """
- Go through each table and look for records that can't be mapped to a dag run.
-
- When we find such "dangling" rows we back them up in a special table and delete them
- from the main table.
-
- Starting in Airflow 2.2, we began a process of replacing `execution_date` with `run_id` in many tables.
- """
- from airflow.models.dagrun import DagRun
- from airflow.models.renderedtifields import RenderedTaskInstanceFields
- from airflow.models.taskfail import TaskFail
- from airflow.models.taskinstance import TaskInstance
- from airflow.models.taskreschedule import TaskReschedule
- from airflow.models.xcom import XCom
-
- @dataclass
- class BadReferenceConfig:
- """
- Bad reference config class.
-
- :param bad_rows_func: function that returns subquery which determines whether bad rows exist
- :param join_tables: table objects referenced in subquery
- :param ref_table: information-only identifier for categorizing the missing ref
- """
-
- bad_rows_func: Callable
- join_tables: list[str]
- ref_table: str
-
- missing_dag_run_config = BadReferenceConfig(
- bad_rows_func=_dangling_against_dag_run,
- join_tables=["dag_run"],
- ref_table="dag_run",
- )
-
- missing_ti_config = BadReferenceConfig(
- bad_rows_func=_dangling_against_task_instance,
- join_tables=["dag_run", "task_instance"],
- ref_table="task_instance",
- )
-
- models_list: list[tuple[MappedClassProtocol, str, BadReferenceConfig]] = [
- (TaskInstance, "2.2", missing_dag_run_config),
- (TaskReschedule, "2.2", missing_ti_config),
- (RenderedTaskInstanceFields, "2.3", missing_ti_config),
- (TaskFail, "2.3", missing_ti_config),
- (XCom, "2.3", missing_ti_config),
- ]
- metadata = reflect_tables([*(x[0] for x in models_list), DagRun, TaskInstance], session)
-
- if (
- not metadata.tables
- or metadata.tables.get(DagRun.__tablename__) is None
- or metadata.tables.get(TaskInstance.__tablename__) is None
- ):
- # Key table doesn't exist -- likely empty DB.
- return
-
- existing_table_names = set(inspect(session.get_bind()).get_table_names())
- errored = False
-
- for model, change_version, bad_ref_cfg in models_list:
- log.debug("checking model %s", model.__tablename__)
- # We can't use the model here since it may differ from the db state because
- # this function is run prior to migration. Use the reflected table instead.
- source_table = metadata.tables.get(model.__tablename__) # type: ignore
- if source_table is None:
- continue
-
- # Migration already applied, don't check again.
- if "run_id" in source_table.columns:
- continue
-
- func_kwargs = {x: metadata.tables[x] for x in bad_ref_cfg.join_tables}
- bad_rows_query = bad_ref_cfg.bad_rows_func(session, source_table, **func_kwargs)
-
- dangling_table_name = _format_airflow_moved_table_name(source_table.name, change_version, "dangling")
- if dangling_table_name in existing_table_names:
- invalid_row_count = get_query_count(bad_rows_query, session=session)
- if invalid_row_count:
- yield _format_dangling_error(
- source_table=source_table.name,
- target_table=dangling_table_name,
- invalid_count=invalid_row_count,
- reason=f"without a corresponding {bad_ref_cfg.ref_table}
row",
- )
- errored = True
- continue
-
- log.debug("moving data for table %s", source_table.name)
- _move_dangling_data_to_new_table(
- session,
- source_table,
- bad_rows_query,
- dangling_table_name,
- )
-
- if errored:
- session.rollback()
- else:
- session.commit()
-
-
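The moved-data tables these checks looked for follow a versioned naming scheme; a sketch of the observable format (inferred from the test removed below; the formatting helper itself stays in db.py):

    def moved_table_name(source_table: str, version: str, category: str) -> str:
        # e.g. ("task_instance", "2.2", "dangling") -> "_airflow_moved__2_2__dangling__task_instance"
        return "__".join(["_airflow_moved", version.replace(".", "_"), category, source_table])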
@provide_session
def _check_migration_errors(session: Session = NEW_SESSION) -> Iterable[str]:
""":session: session of the sqlalchemy."""
- check_functions: tuple[Callable[..., Iterable[str]], ...] = (
- check_conn_id_duplicates,
- check_conn_type_null,
- check_run_id_null,
- check_bad_references,
- check_username_duplicates,
- )
+ check_functions: Iterable[Callable[..., Iterable[str]]] = ()
for check_fn in check_functions:
log.debug("running check function %s", check_fn.__name__)
yield from check_fn(session=session)
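The loop keeps the original check-function contract, so later checks can plug back in unchanged; a hedged sketch of that contract with a hypothetical check:

    from collections.abc import Iterable
    from sqlalchemy.orm import Session

    def check_something(session: Session) -> Iterable[str]:
        # Query the database and yield one human-readable message per problem
        # that must be fixed before migrating; yield nothing when all is well.
        if False:  # placeholder condition for the sketch
            yield "describe the problem and how to resolve it"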
diff --git a/tests/utils/test_db.py b/tests/utils/test_db.py
index 4d32d132b2..4d36d98a13 100644
--- a/tests/utils/test_db.py
+++ b/tests/utils/test_db.py
@@ -31,14 +31,12 @@ from alembic.config import Config
from alembic.migration import MigrationContext
from alembic.runtime.environment import EnvironmentContext
from alembic.script import ScriptDirectory
-from sqlalchemy import MetaData, Table
-from sqlalchemy.sql import Select
+from sqlalchemy import MetaData
from airflow.models import Base as airflow_base
from airflow.settings import engine
from airflow.utils.db import (
_get_alembic_config,
- check_bad_references,
check_migrations,
compare_server_default,
compare_type,
@@ -51,7 +49,6 @@ from airflow.utils.db import (
upgradedb,
)
from airflow.utils.db_manager import RunDBManager
-from airflow.utils.session import NEW_SESSION
pytestmark = [pytest.mark.db_test, pytest.mark.skip_if_database_isolation_mode]
@@ -238,82 +235,3 @@ class TestDb:
import airflow
assert config.config_file_name == os.path.join(os.path.dirname(airflow.__file__), "alembic.ini")
-
- @mock.patch("airflow.utils.db._move_dangling_data_to_new_table")
- @mock.patch("airflow.utils.db.get_query_count")
- @mock.patch("airflow.utils.db._dangling_against_task_instance")
- @mock.patch("airflow.utils.db._dangling_against_dag_run")
- @mock.patch("airflow.utils.db.reflect_tables")
- @mock.patch("airflow.utils.db.inspect")
- def test_check_bad_references(
- self,
- mock_inspect: MagicMock,
- mock_reflect_tables: MagicMock,
- mock_dangling_against_dag_run: MagicMock,
- mock_dangling_against_task_instance: MagicMock,
- mock_get_query_count: MagicMock,
- mock_move_dangling_data_to_new_table: MagicMock,
- ):
- from airflow.models.dagrun import DagRun
- from airflow.models.renderedtifields import RenderedTaskInstanceFields
- from airflow.models.taskfail import TaskFail
- from airflow.models.taskinstance import TaskInstance
- from airflow.models.taskreschedule import TaskReschedule
- from airflow.models.xcom import XCom
-
- mock_session = MagicMock(spec=NEW_SESSION)
- mock_bind = MagicMock()
- mock_session.get_bind.return_value = mock_bind
- task_instance_table = MagicMock(spec=Table)
- task_instance_table.name = TaskInstance.__tablename__
- dag_run_table = MagicMock(spec=Table)
- task_fail_table = MagicMock(spec=Table)
- task_fail_table.name = TaskFail.__tablename__
-
- mock_reflect_tables.return_value = MagicMock(
- tables={
- DagRun.__tablename__: dag_run_table,
- TaskInstance.__tablename__: task_instance_table,
- TaskFail.__tablename__: task_fail_table,
- }
- )
-
- # Simulate that there is a moved `task_instance` table from the
- # previous run, but no moved `task_fail` table
- dangling_task_instance_table_name = f"_airflow_moved__2_2__dangling__{task_instance_table.name}"
- dangling_task_fail_table_name = f"_airflow_moved__2_3__dangling__{task_fail_table.name}"
- mock_get_table_names = MagicMock(
- return_value=[
- TaskInstance.__tablename__,
- DagRun.__tablename__,
- TaskFail.__tablename__,
- dangling_task_instance_table_name,
- ]
- )
- mock_inspect.return_value = MagicMock(
- get_table_names=mock_get_table_names,
- )
- mock_select = MagicMock(spec=Select)
- mock_dangling_against_dag_run.return_value = mock_select
- mock_dangling_against_task_instance.return_value = mock_select
- mock_get_query_count.return_value = 1
-
- # Should return a single error related to the dangling `task_instance` table
- errs = list(check_bad_references(session=mock_session))
- assert len(errs) == 1
- assert dangling_task_instance_table_name in errs[0]
-
- mock_reflect_tables.assert_called_once_with(
- [TaskInstance, TaskReschedule, RenderedTaskInstanceFields, TaskFail, XCom, DagRun, TaskInstance],
- mock_session,
- )
- mock_inspect.assert_called_once_with(mock_bind)
- mock_get_table_names.assert_called_once()
- mock_dangling_against_dag_run.assert_called_once_with(
- mock_session, task_instance_table, dag_run=dag_run_table
- )
- mock_get_query_count.assert_called_once_with(mock_select, session=mock_session)
- mock_move_dangling_data_to_new_table.assert_called_once_with(
- mock_session, task_fail_table, mock_select, dangling_task_fail_table_name
- )
- mock_session.rollback.assert_called_once()