This is an automated email from the ASF dual-hosted git repository.

jedcunningham pushed a commit to branch v2-2-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 327b83967881b5c5a8973b8d2dbb1e8c39c051d6
Author: Tzu-ping Chung <[email protected]>
AuthorDate: Fri Oct 15 01:09:54 2021 +0800

    Try to move "dangling" rows in upgradedb (#18953)
    
    (cherry picked from commit f967ca91058b4296edb507c7826282050188b501)
---
 airflow/settings.py                     |   3 +
 airflow/utils/db.py                     | 164 ++++++++++++++++++++------------
 airflow/www/templates/airflow/dags.html |   8 ++
 airflow/www/views.py                    |  13 ++-
 tests/www/views/test_views_base.py      |   2 +-
 5 files changed, 128 insertions(+), 62 deletions(-)

diff --git a/airflow/settings.py b/airflow/settings.py
index fc2c6bb..9a456ba 100644
--- a/airflow/settings.py
+++ b/airflow/settings.py
@@ -563,3 +563,6 @@ MASK_SECRETS_IN_LOGS = False
 #
 # DASHBOARD_UIALERTS: List["UIAlert"]
 DASHBOARD_UIALERTS = []
+
+# Prefix used to identify tables holding data moved during migration.
+AIRFLOW_MOVED_TABLE_PREFIX = "_airflow_moved"
diff --git a/airflow/utils/db.py b/airflow/utils/db.py
index 13dd401..17522d3 100644
--- a/airflow/utils/db.py
+++ b/airflow/utils/db.py
@@ -20,7 +20,7 @@ import os
 import time
 from typing import Iterable
 
-from sqlalchemy import Table, exc, func
+from sqlalchemy import Table, exc, func, inspect, or_, text
 
 from airflow import settings
 from airflow.configuration import conf
@@ -51,15 +51,15 @@ from airflow.models import (  # noqa: F401
 from airflow.models.serialized_dag import SerializedDagModel  # noqa: F401
 
 # TODO: remove create_session once we decide to break backward compatibility
-from airflow.utils.session import (  # noqa: F401 # pylint: disable=unused-import
-    create_global_lock,
-    create_session,
-    provide_session,
-)
+from airflow.utils.session import create_global_lock, create_session, provide_session  # noqa: F401
 
 log = logging.getLogger(__name__)
 
 
+def _format_airflow_moved_table_name(source_table, version):
+    return "__".join([settings.AIRFLOW_MOVED_TABLE_PREFIX, version.replace(".", "_"), source_table])
+
+
 @provide_session
 def merge_conn(conn, session=None):
     """Add new Connection."""
@@ -697,47 +697,77 @@ def check_conn_type_null(session=None) -> Iterable[str]:
         )
 
 
+def _format_dangling_error(source_table, target_table, invalid_count, reason):
+    noun = "row" if invalid_count == 1 else "rows"
+    return (
+        f"The {source_table} table has {invalid_count} {noun} {reason}, which "
+        f"is invalid. We could not move them out of the way because the "
+        f"{target_table} table already exists in your database. Please either "
+        f"drop the {target_table} table, or manually delete the invalid rows "
+        f"from the {source_table} table."
+    )
+
+
+def _move_dangling_run_data_to_new_table(session, source_table, target_table):
+    where_clause = "where dag_id is null or run_id is null or execution_date is null"
+    session.execute(text(f"create table {target_table} as select * from {source_table} {where_clause}"))
+    session.execute(text(f"delete from {source_table} {where_clause}"))
+
+
 def check_run_id_null(session) -> Iterable[str]:
     import sqlalchemy.schema
 
     metadata = sqlalchemy.schema.MetaData(session.bind)
     try:
-        metadata.reflect(only=["dag_run"])
+        metadata.reflect(only=[DagRun.__tablename__])
     except exc.InvalidRequestError:
         # Table doesn't exist -- empty db
         return
 
-    dag_run = metadata.tables["dag_run"]
-
-    for colname in ('run_id', 'dag_id', 'execution_date'):
-
-        col = dag_run.columns.get(colname)
-        if col is None:
-            continue
-
-        if not col.nullable:
-            continue
-
-        num = session.query(dag_run).filter(col.is_(None)).count()
-        if num > 0:
-            yield (
-                f'The {dag_run.name} table has {num} row{"s" if num != 1 else ""} with a NULL value in '
-                f'{col.name!r}. You must manually correct this problem (possibly by deleting the problem '
-                'rows).'
+    # We can't use the model here since it may differ from the db state due to
+    # this function is run prior to migration. Use the reflected table instead.
+    dagrun_table = metadata.tables[DagRun.__tablename__]
+
+    invalid_dagrun_filter = or_(
+        dagrun_table.c.dag_id.is_(None),
+        dagrun_table.c.run_id.is_(None),
+        dagrun_table.c.execution_date.is_(None),
+    )
+    invalid_dagrun_count = session.query(dagrun_table.c.id).filter(invalid_dagrun_filter).count()
+    if invalid_dagrun_count > 0:
+        dagrun_dangling_table_name = _format_airflow_moved_table_name(dagrun_table.name, "2.2")
+        if dagrun_dangling_table_name in inspect(session.get_bind()).get_table_names():
+            yield _format_dangling_error(
+                source_table=dagrun_table.name,
+                target_table=dagrun_dangling_table_name,
+                invalid_count=invalid_dagrun_count,
+                reason="with a NULL dag_id, run_id, or execution_date",
             )
-    session.rollback()
+            return
+        _move_dangling_run_data_to_new_table(session, dagrun_table.name, dagrun_dangling_table_name)
+
+
+def _move_dangling_task_data_to_new_table(session, source_table, target_table):
+    where_clause = f"""
+        where (task_id, dag_id, execution_date) IN (
+            select source.task_id, source.dag_id, source.execution_date
+            from {source_table} as source
+            left join dag_run as dr
+            on (source.dag_id = dr.dag_id and source.execution_date = dr.execution_date)
+            where dr.id is null
+        )
+    """
+    session.execute(text(f"create table {target_table} as select * from {source_table} {where_clause}"))
+    session.execute(text(f"delete from {source_table} {where_clause}"))
 
 
 def check_task_tables_without_matching_dagruns(session) -> Iterable[str]:
-    from itertools import chain
-
     import sqlalchemy.schema
     from sqlalchemy import and_, outerjoin
 
     metadata = sqlalchemy.schema.MetaData(session.bind)
-    models_to_dagrun = [TaskInstance, TaskFail]
-    models_to_ti = []
-    for model in models_to_dagrun + models_to_ti + [DagRun]:
+    models_to_dagrun = [TaskInstance, TaskReschedule]
+    for model in models_to_dagrun + [DagRun]:
         try:
             metadata.reflect(only=[model.__tablename__])
         except exc.InvalidRequestError:
@@ -745,43 +775,57 @@ def check_task_tables_without_matching_dagruns(session) -> Iterable[str]:
             # version
             pass
 
+    # Key table doesn't exist -- likely empty DB.
     if DagRun.__tablename__ not in metadata or TaskInstance.__tablename__ not in metadata:
-        # Key table doesn't exist -- likely empty DB
-        session.rollback()
         return
 
-    for (model, target) in chain(
-        ((m, metadata.tables[DagRun.__tablename__]) for m in models_to_dagrun),
-        ((m, metadata.tables[TaskInstance.__tablename__]) for m in models_to_ti),
-    ):
-        table = metadata.tables.get(model.__tablename__)
-        if table is None:
+    # We can't use the model here since it may differ from the db state due to
+    # this function is run prior to migration. Use the reflected table instead.
+    dagrun_table = metadata.tables[DagRun.__tablename__]
+
+    existing_table_names = set(inspect(session.get_bind()).get_table_names())
+    errored = False
+
+    for model in models_to_dagrun:
+        # We can't use the model here since it may differ from the db state due to
+        # this function is run prior to migration. Use the reflected table instead.
+        source_table = metadata.tables.get(model.__tablename__)
+        if source_table is None:
             continue
-        if 'run_id' in table.columns:
-            # Migration already applied, don't check again
+
+        # Migration already applied, don't check again.
+        if "run_id" in source_table.columns:
             continue
 
-        # We can't use the model here (as that would have the associationproxy, we instead need to use the
-        # _reflected_ table)
-        join_cond = and_(table.c.dag_id == target.c.dag_id, table.c.execution_date == target.c.execution_date)
-        if "task_id" in target.columns:
-            join_cond = and_(join_cond, table.c.task_id == target.c.task_id)
-
-        query = (
-            session.query(table.c.dag_id, table.c.task_id, table.c.execution_date)
-            .select_from(outerjoin(table, target, join_cond))
-            .filter(target.c.dag_id.is_(None))
-        )  # type: ignore
-
-        num = query.count()
-
-        if num > 0:
-            yield (
-                f'The {table.name} table has {num} row{"s" if num != 1 else ""} without a '
-                f'corresponding {target.name} row. You must manually correct this problem '
-                '(possibly by deleting the problem rows).'
+        source_to_dag_run_join_cond = and_(
+            source_table.c.dag_id == dagrun_table.c.dag_id,
+            source_table.c.execution_date == dagrun_table.c.execution_date,
+        )
+        invalid_row_count = (
+            session.query(source_table.c.dag_id, source_table.c.task_id, source_table.c.execution_date)
+            .select_from(outerjoin(source_table, dagrun_table, source_to_dag_run_join_cond))
+            .filter(dagrun_table.c.dag_id.is_(None))
+            .count()
+        )
+        if invalid_row_count <= 0:
+            continue
+
+        dangling_table_name = _format_airflow_moved_table_name(source_table.name, "2.2")
+        if dangling_table_name in existing_table_names:
+            yield _format_dangling_error(
+                source_table=source_table.name,
+                target_table=dangling_table_name,
+                invalid_count=invalid_row_count,
+                reason=f"without a corresponding {dagrun_table.name} row",
             )
-    session.rollback()
+            errored = True
+            continue
+        _move_dangling_task_data_to_new_table(session, source_table.name, dangling_table_name)
+
+    if errored:
+        session.rollback()
+    else:
+        session.commit()
 
 
 @provide_session
diff --git a/airflow/www/templates/airflow/dags.html b/airflow/www/templates/airflow/dags.html
index 1bc56e5..5d5d140 100644
--- a/airflow/www/templates/airflow/dags.html
+++ b/airflow/www/templates/airflow/dags.html
@@ -51,6 +51,14 @@
   {% for m in dashboard_alerts %}
     {{ message(m.message, m.category) }}
   {% endfor %}
+  {% for original_table_name, moved_table_name in migration_moved_data_alerts %}
+    {% call message(category='error', dismissable=false) %}
+      Airflow found incompatible data in the <code>{{ original_table_name }}</code> table in the
+      metadatabase, and has moved them to <code>{{ moved_table_name }}</code> during the database migration
+      to upgrade. Please inspect the moved data to decide whether you need to keep them, and manually drop
+      the <code>{{ moved_table_name }}</code> table to dismiss this warning.
+    {% endcall %}
+  {% endfor %}
   {{ super() }}
   {% if sqlite_warning | default(true) %}
     {% call message(category='warning', dismissable=false)  %}
diff --git a/airflow/www/views.py b/airflow/www/views.py
index 2dca5fd..d3e77ff 100644
--- a/airflow/www/views.py
+++ b/airflow/www/views.py
@@ -82,7 +82,7 @@ from pendulum.datetime import DateTime
 from pendulum.parsing.exceptions import ParserError
 from pygments import highlight, lexers
 from pygments.formatters import HtmlFormatter
-from sqlalchemy import Date, and_, desc, func, union_all
+from sqlalchemy import Date, and_, desc, func, inspect, union_all
 from sqlalchemy.exc import IntegrityError
 from sqlalchemy.orm import joinedload
 from wtforms import SelectField, validators
@@ -692,10 +692,21 @@ class Airflow(AirflowBaseView):
             fm for fm in settings.DASHBOARD_UIALERTS if fm.should_show(current_app.appbuilder.sm)
         ]
 
+        def _iter_parsed_moved_data_table_names():
+            for table_name in inspect(session.get_bind()).get_table_names():
+                segments = table_name.split("__", 2)
+                if len(segments) < 3:
+                    continue
+                if segments[0] != settings.AIRFLOW_MOVED_TABLE_PREFIX:
+                    continue
+                # Second segment is a version marker that we don't need to show.
+                yield segments[2], table_name
+
         return self.render_template(
             'airflow/dags.html',
             dags=dags,
             dashboard_alerts=dashboard_alerts,
+            migration_moved_data_alerts=sorted(set(_iter_parsed_moved_data_table_names())),
             current_page=current_page,
             search_query=arg_search_query if arg_search_query else '',
             page_title=page_title,
diff --git a/tests/www/views/test_views_base.py b/tests/www/views/test_views_base.py
index e3be4b2..5254be1 100644
--- a/tests/www/views/test_views_base.py
+++ b/tests/www/views/test_views_base.py
@@ -30,7 +30,7 @@ from tests.test_utils.www import check_content_in_response, check_content_not_in
 
 
 def test_index(admin_client):
-    with assert_queries_count(48):
+    with assert_queries_count(49):
         resp = admin_client.get('/', follow_redirects=True)
     check_content_in_response('DAGs', resp)
 

Reply via email to