uranusjr commented on code in PR #41850:
URL: https://github.com/apache/airflow/pull/41850#discussion_r1735849100


##########
airflow/utils/db.py:
##########
@@ -1079,442 +1020,10 @@ def reflect_tables(tables: list[MappedClassProtocol | str] | None, session):
     return metadata
 
 
-def check_table_for_duplicates(
-    *, session: Session, table_name: str, uniqueness: list[str], version: str
-) -> Iterable[str]:
-    """
-    Check table for duplicates, given a list of columns which define the 
uniqueness of the table.
-
-    Usage example:
-
-    .. code-block:: python
-
-        def check_task_fail_for_duplicates(session):
-            from airflow.models.taskfail import TaskFail
-
-            metadata = reflect_tables([TaskFail], session)
-            task_fail = metadata.tables.get(TaskFail.__tablename__)  # type: 
ignore
-            if task_fail is None:  # table not there
-                return
-            if "run_id" in task_fail.columns:  # upgrade already applied
-                return
-            yield from check_table_for_duplicates(
-                table_name=task_fail.name,
-                uniqueness=["dag_id", "task_id", "execution_date"],
-                session=session,
-                version="2.3",
-            )
-
-    :param table_name: table name to check
-    :param uniqueness: uniqueness constraint to evaluate against
-    :param session:  session of the sqlalchemy
-    """
-    minimal_table_obj = table(table_name, *(column(x) for x in uniqueness))
-    try:
-        subquery = session.execute(
-            select(minimal_table_obj, func.count().label("dupe_count"))
-            .group_by(*(text(x) for x in uniqueness))
-            .having(func.count() > text("1"))
-            .subquery()
-        )
-        dupe_count = session.scalar(select(func.sum(subquery.c.dupe_count)))
-        if not dupe_count:
-            # there are no duplicates; nothing to do.
-            return
-
-        log.warning("Found %s duplicates in table %s.  Will attempt to move 
them.", dupe_count, table_name)
-
-        metadata = reflect_tables(tables=[table_name], session=session)
-        if table_name not in metadata.tables:
-            yield f"Table {table_name} does not exist in the database."
-
-        # We can't use the model here since it may differ from the db state 
due to
-        # this function is run prior to migration. Use the reflected table 
instead.
-        table_obj = metadata.tables[table_name]
-
-        _move_duplicate_data_to_new_table(
-            session=session,
-            source_table=table_obj,
-            subquery=subquery,
-            uniqueness=uniqueness,
-            target_table_name=_format_airflow_moved_table_name(table_name, 
version, "duplicates"),
-        )
-    except (exc.OperationalError, exc.ProgrammingError):
-        # fallback if `table_name` hasn't been created yet
-        session.rollback()
-
-
-def check_conn_type_null(session: Session) -> Iterable[str]:
-    """
-    Check nullable conn_type column in Connection table.
-
-    :param session:  session of the sqlalchemy
-    """
-    from airflow.models.connection import Connection
-
-    try:
-        n_nulls = 
session.scalars(select(Connection.conn_id).where(Connection.conn_type.is_(None))).all()
-    except (exc.OperationalError, exc.ProgrammingError, exc.InternalError):
-        # fallback if tables hasn't been created yet
-        session.rollback()
-        return
-
-    if n_nulls:
-        yield (
-            "The conn_type column in the connection "
-            "table must contain content.\n"
-            "Make sure you don't have null "
-            "in the conn_type column.\n"
-            f"Null conn_type conn_id: {n_nulls}"
-        )
-
-
-def _format_dangling_error(source_table, target_table, invalid_count, reason):
-    noun = "row" if invalid_count == 1 else "rows"
-    return (
-        f"The {source_table} table has {invalid_count} {noun} {reason}, which "
-        f"is invalid. We could not move them out of the way because the "
-        f"{target_table} table already exists in your database. Please either "
-        f"drop the {target_table} table, or manually delete the invalid rows "
-        f"from the {source_table} table."
-    )
-
-
-def check_run_id_null(session: Session) -> Iterable[str]:
-    from airflow.models.dagrun import DagRun
-
-    metadata = reflect_tables([DagRun], session)
-
-    # We can't use the model here since it may differ from the db state due to
-    # this function is run prior to migration. Use the reflected table instead.
-    dagrun_table = metadata.tables.get(DagRun.__tablename__)
-    if dagrun_table is None:
-        return
-
-    invalid_dagrun_filter = or_(
-        dagrun_table.c.dag_id.is_(None),
-        dagrun_table.c.run_id.is_(None),
-        dagrun_table.c.execution_date.is_(None),
-    )
-    invalid_dagrun_count = 
session.scalar(select(func.count(dagrun_table.c.id)).where(invalid_dagrun_filter))
-    if invalid_dagrun_count > 0:
-        dagrun_dangling_table_name = 
_format_airflow_moved_table_name(dagrun_table.name, "2.2", "dangling")
-        if dagrun_dangling_table_name in 
inspect(session.get_bind()).get_table_names():
-            yield _format_dangling_error(
-                source_table=dagrun_table.name,
-                target_table=dagrun_dangling_table_name,
-                invalid_count=invalid_dagrun_count,
-                reason="with a NULL dag_id, run_id, or execution_date",
-            )
-            return
-
-        bind = session.get_bind()
-        dialect_name = bind.dialect.name
-        _create_table_as(
-            dialect_name=dialect_name,
-            source_query=dagrun_table.select(invalid_dagrun_filter),
-            target_table_name=dagrun_dangling_table_name,
-            source_table_name=dagrun_table.name,
-            session=session,
-        )
-        delete = dagrun_table.delete().where(invalid_dagrun_filter)
-        session.execute(delete)
-
-
-def _create_table_as(
-    *,
-    session,
-    dialect_name: str,
-    source_query: Query,
-    target_table_name: str,
-    source_table_name: str,
-):
-    """
-    Create a new table with rows from query.
-
-    We have to handle CTAS differently for different dialects.
-    """
-    if dialect_name == "mysql":
-        # MySQL with replication needs this split in to two queries, so just 
do it for all MySQL
-        # ERROR 1786 (HY000): Statement violates GTID consistency: CREATE 
TABLE ... SELECT.
-        session.execute(text(f"CREATE TABLE {target_table_name} LIKE 
{source_table_name}"))
-        session.execute(
-            text(
-                f"INSERT INTO {target_table_name} 
{source_query.selectable.compile(bind=session.get_bind())}"
-            )
-        )
-    else:
-        # Postgres and SQLite both support the same "CREATE TABLE a AS SELECT 
..." syntax
-        select_table = source_query.selectable.compile(bind=session.get_bind())
-        session.execute(text(f"CREATE TABLE {target_table_name} AS 
{select_table}"))
-
-
-def _move_dangling_data_to_new_table(
-    session, source_table: Table, source_query: Query, target_table_name: str
-):
-    bind = session.get_bind()
-    dialect_name = bind.dialect.name
-
-    # First: Create moved rows from new table
-    log.debug("running CTAS for table %s", target_table_name)
-    _create_table_as(
-        dialect_name=dialect_name,
-        source_query=source_query,
-        target_table_name=target_table_name,
-        source_table_name=source_table.name,
-        session=session,
-    )
-    session.commit()
-
-    target_table = source_table.to_metadata(source_table.metadata, 
name=target_table_name)
-    log.debug("checking whether rows were moved for table %s", 
target_table_name)
-    moved_rows_exist_query = select(1).select_from(target_table).limit(1)
-    first_moved_row = session.execute(moved_rows_exist_query).all()
-    session.commit()
-
-    if not first_moved_row:
-        log.debug("no rows moved; dropping %s", target_table_name)
-        # no bad rows were found; drop moved rows table.
-        target_table.drop(bind=session.get_bind(), checkfirst=True)
-    else:
-        log.debug("rows moved; purging from %s", source_table.name)
-        if dialect_name == "sqlite":
-            pk_cols = source_table.primary_key.columns
-
-            delete = source_table.delete().where(
-                
tuple_(*pk_cols).in_(session.select(*target_table.primary_key.columns).subquery())
-            )
-        else:
-            delete = source_table.delete().where(
-                and_(col == target_table.c[col.name] for col in 
source_table.primary_key.columns)
-            )
-        log.debug(delete.compile())
-        session.execute(delete)
-    session.commit()
-
-    log.debug("exiting move function")
-
-
-def _dangling_against_dag_run(session, source_table, dag_run):
-    """Given a source table, we generate a subquery that will return 1 for 
every row that has a dagrun."""
-    source_to_dag_run_join_cond = and_(
-        source_table.c.dag_id == dag_run.c.dag_id,
-        source_table.c.execution_date == dag_run.c.execution_date,
-    )
-
-    return (
-        select(*(c.label(c.name) for c in source_table.c))
-        .join(dag_run, source_to_dag_run_join_cond, isouter=True)
-        .where(dag_run.c.dag_id.is_(None))
-    )
-
-
-def _dangling_against_task_instance(session, source_table, dag_run, 
task_instance):
-    """
-    Given a source table, generate a subquery that will return 1 for every row 
that has a valid task instance.
-
-    This is used to identify rows that need to be removed from tables prior to 
adding a TI fk.
-
-    Since this check is applied prior to running the migrations, we have to 
use different
-    query logic depending on which revision the database is at.
-
-    """
-    if "run_id" not in task_instance.c:
-        # db is < 2.2.0
-        dr_join_cond = and_(
-            source_table.c.dag_id == dag_run.c.dag_id,
-            source_table.c.execution_date == dag_run.c.execution_date,
-        )
-        ti_join_cond = and_(
-            dag_run.c.dag_id == task_instance.c.dag_id,
-            dag_run.c.execution_date == task_instance.c.execution_date,
-            source_table.c.task_id == task_instance.c.task_id,
-        )
-    else:
-        # db is 2.2.0 <= version < 2.3.0
-        dr_join_cond = and_(
-            source_table.c.dag_id == dag_run.c.dag_id,
-            source_table.c.execution_date == dag_run.c.execution_date,
-        )
-        ti_join_cond = and_(
-            dag_run.c.dag_id == task_instance.c.dag_id,
-            dag_run.c.run_id == task_instance.c.run_id,
-            source_table.c.task_id == task_instance.c.task_id,
-        )
-
-    return (
-        select(*(c.label(c.name) for c in source_table.c))
-        .outerjoin(dag_run, dr_join_cond)
-        .outerjoin(task_instance, ti_join_cond)
-        .where(or_(task_instance.c.dag_id.is_(None), 
dag_run.c.dag_id.is_(None)))
-    )
-
-
-def _move_duplicate_data_to_new_table(
-    session, source_table: Table, subquery: Query, uniqueness: list[str], 
target_table_name: str
-):
-    """
-    When adding a uniqueness constraint we first should ensure that there are 
no duplicate rows.
-
-    This function accepts a subquery that should return one record for each 
row with duplicates (e.g.
-    a group by with having count(*) > 1).  We select from ``source_table`` 
getting all rows matching the
-    subquery result and store in ``target_table_name``.  Then to purge the 
duplicates from the source table,
-    we do a DELETE FROM with a join to the target table (which now contains 
the dupes).
-
-    :param session: sqlalchemy session for metadata db
-    :param source_table: table to purge dupes from
-    :param subquery: the subquery that returns the duplicate rows
-    :param uniqueness: the string list of columns used to define the 
uniqueness for the table. used in
-        building the DELETE FROM join condition.
-    :param target_table_name: name of the table in which to park the duplicate 
rows
-    """
-    bind = session.get_bind()
-    dialect_name = bind.dialect.name
-
-    query = (
-        select(*(source_table.c[x.name].label(str(x.name)) for x in 
source_table.columns))
-        .select_from(source_table)
-        .join(subquery, and_(*(source_table.c[x] == subquery.c[x] for x in 
uniqueness)))
-    )
-
-    _create_table_as(
-        session=session,
-        dialect_name=dialect_name,
-        source_query=query,
-        target_table_name=target_table_name,
-        source_table_name=source_table.name,
-    )
-
-    # we must ensure that the CTAS table is created prior to the DELETE step 
since we have to join to it
-    session.commit()
-
-    metadata = reflect_tables([target_table_name], session)
-    target_table = metadata.tables[target_table_name]
-    where_clause = and_(*(source_table.c[x] == target_table.c[x] for x in 
uniqueness))
-
-    if dialect_name == "sqlite":
-        subq = 
query.selectable.with_only_columns([text(f"{source_table}.ROWID")])
-        delete = source_table.delete().where(column("ROWID").in_(subq))
-    else:
-        delete = source_table.delete(where_clause)
-
-    session.execute(delete)
-
-
-def check_bad_references(session: Session) -> Iterable[str]:
-    """
-    Go through each table and look for records that can't be mapped to a dag 
run.
-
-    When we find such "dangling" rows we back them up in a special table and 
delete them
-    from the main table.
-
-    Starting in Airflow 2.2, we began a process of replacing `execution_date` 
with `run_id` in many tables.
-    """
-    from airflow.models.dagrun import DagRun
-    from airflow.models.renderedtifields import RenderedTaskInstanceFields
-    from airflow.models.taskfail import TaskFail
-    from airflow.models.taskinstance import TaskInstance
-    from airflow.models.taskreschedule import TaskReschedule
-    from airflow.models.xcom import XCom
-
-    @dataclass
-    class BadReferenceConfig:
-        """
-        Bad reference config class.
-
-        :param bad_rows_func: function that returns subquery which determines 
whether bad rows exist
-        :param join_tables: table objects referenced in subquery
-        :param ref_table: information-only identifier for categorizing the 
missing ref
-        """
-
-        bad_rows_func: Callable
-        join_tables: list[str]
-        ref_table: str
-
-    missing_dag_run_config = BadReferenceConfig(
-        bad_rows_func=_dangling_against_dag_run,
-        join_tables=["dag_run"],
-        ref_table="dag_run",
-    )
-
-    missing_ti_config = BadReferenceConfig(
-        bad_rows_func=_dangling_against_task_instance,
-        join_tables=["dag_run", "task_instance"],
-        ref_table="task_instance",
-    )
-
-    models_list: list[tuple[MappedClassProtocol, str, BadReferenceConfig]] = [
-        (TaskInstance, "2.2", missing_dag_run_config),
-        (TaskReschedule, "2.2", missing_ti_config),
-        (RenderedTaskInstanceFields, "2.3", missing_ti_config),
-        (TaskFail, "2.3", missing_ti_config),
-        (XCom, "2.3", missing_ti_config),
-    ]
-    metadata = reflect_tables([*(x[0] for x in models_list), DagRun, 
TaskInstance], session)
-
-    if (
-        not metadata.tables
-        or metadata.tables.get(DagRun.__tablename__) is None
-        or metadata.tables.get(TaskInstance.__tablename__) is None
-    ):
-        # Key table doesn't exist -- likely empty DB.
-        return
-
-    existing_table_names = set(inspect(session.get_bind()).get_table_names())
-    errored = False
-
-    for model, change_version, bad_ref_cfg in models_list:
-        log.debug("checking model %s", model.__tablename__)
-        # We can't use the model here since it may differ from the db state 
due to
-        # this function is run prior to migration. Use the reflected table 
instead.
-        source_table = metadata.tables.get(model.__tablename__)  # type: ignore
-        if source_table is None:
-            continue
-
-        # Migration already applied, don't check again.
-        if "run_id" in source_table.columns:
-            continue
-
-        func_kwargs = {x: metadata.tables[x] for x in bad_ref_cfg.join_tables}
-        bad_rows_query = bad_ref_cfg.bad_rows_func(session, source_table, 
**func_kwargs)
-
-        dangling_table_name = 
_format_airflow_moved_table_name(source_table.name, change_version, "dangling")
-        if dangling_table_name in existing_table_names:
-            invalid_row_count = get_query_count(bad_rows_query, 
session=session)
-            if invalid_row_count:
-                yield _format_dangling_error(
-                    source_table=source_table.name,
-                    target_table=dangling_table_name,
-                    invalid_count=invalid_row_count,
-                    reason=f"without a corresponding {bad_ref_cfg.ref_table} 
row",
-                )
-                errored = True
-            continue
-
-        log.debug("moving data for table %s", source_table.name)
-        _move_dangling_data_to_new_table(
-            session,
-            source_table,
-            bad_rows_query,
-            dangling_table_name,
-        )
-
-    if errored:
-        session.rollback()
-    else:
-        session.commit()
-
-
 @provide_session
 def _check_migration_errors(session: Session = NEW_SESSION) -> Iterable[str]:
     """:session: session of the sqlalchemy."""
-    check_functions: tuple[Callable[..., Iterable[str]], ...] = (
-        check_conn_id_duplicates,
-        check_conn_type_null,
-        check_run_id_null,
-        check_bad_references,
-        check_username_duplicates,
-    )
+    check_functions: Iterable[Callable[..., Iterable[str]]] = ()

Review Comment:
   I was thinking it would be needed in the future anyway and it’s easier to 
just keep this here.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to