This is an automated email from the ASF dual-hosted git repository.
ash pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new e6c56c4 Ensure that dag_id, run_id and execution_date are non-null on
DagRun (#18804)
e6c56c4 is described below
commit e6c56c4ae475605636f4a1b5ab3884383884a8cf
Author: Ash Berlin-Taylor <[email protected]>
AuthorDate: Fri Oct 8 01:32:20 2021 +0100
Ensure that dag_id, run_id and execution_date are non-null on DagRun
(#18804)
These _should_ be non-nullable, and are always created as such. Without
this it was possible that someone had manually edited them, which caused
problems with the TaskInstance FK migration not applying correctly.
Co-authored-by: Jed Cunningham
<[email protected]>
---
.../7b2661a43ba3_taskinstance_keyed_to_dagrun.py | 147 +++++++++++++++++----
airflow/models/dagrun.py | 6 +-
airflow/models/taskinstance.py | 6 +-
airflow/utils/db.py | 32 +++++
tests/api_connexion/schemas/test_dag_run_schema.py | 9 +-
tests/models/test_dagrun.py | 3 +
6 files changed, 168 insertions(+), 35 deletions(-)
diff --git
a/airflow/migrations/versions/7b2661a43ba3_taskinstance_keyed_to_dagrun.py
b/airflow/migrations/versions/7b2661a43ba3_taskinstance_keyed_to_dagrun.py
index 8c62101..059144e 100644
--- a/airflow/migrations/versions/7b2661a43ba3_taskinstance_keyed_to_dagrun.py
+++ b/airflow/migrations/versions/7b2661a43ba3_taskinstance_keyed_to_dagrun.py
@@ -41,10 +41,17 @@ branch_labels = None
depends_on = None
-def _mssql_datetime():
- from sqlalchemy.dialects import mssql
+def _datetime_type(dialect_name):
+ if dialect_name == "mssql":
+ from sqlalchemy.dialects import mssql
+
+ return mssql.DATETIME2(precision=6)
+ elif dialect_name == "mysql":
+ from sqlalchemy.dialects import mysql
- return mssql.DATETIME2(precision=6)
+ return mysql.DATETIME(fsp=6)
+
+ return sa.TIMESTAMP(timezone=True)
# Just Enough Table to run the conditions for update.
@@ -101,21 +108,30 @@ def upgrade():
"""Apply TaskInstance keyed to DagRun"""
conn = op.get_bind()
dialect_name = conn.dialect.name
+ dt_type = _datetime_type(dialect_name)
- run_id_col_type = sa.String(length=ID_LEN, **COLLATION_ARGS)
+ string_id_col_type = sa.String(length=ID_LEN, **COLLATION_ARGS)
if dialect_name == 'sqlite':
naming_convention = {
"uq": "%(table_name)s_%(column_0_N_name)s_key",
}
- with op.batch_alter_table('dag_run',
naming_convention=naming_convention, recreate="always"):
- # The naming_convention force the previously un-named UNIQUE
constraints to have the right name --
- # but we still need to enter the context manager to trigger it
- pass
+ # The naming_convention force the previously un-named UNIQUE
constraints to have the right name
+ with op.batch_alter_table(
+ 'dag_run', naming_convention=naming_convention, recreate="always"
+ ) as batch_op:
+ batch_op.alter_column('dag_id', existing_type=string_id_col_type,
nullable=False)
+ batch_op.alter_column('run_id', existing_type=string_id_col_type,
nullable=False)
+ batch_op.alter_column('execution_date', existing_type=dt_type,
nullable=False)
elif dialect_name == 'mysql':
with op.batch_alter_table('dag_run') as batch_op:
- batch_op.alter_column('dag_id',
existing_type=sa.String(length=ID_LEN), type_=run_id_col_type)
- batch_op.alter_column('run_id',
existing_type=sa.String(length=ID_LEN), type_=run_id_col_type)
+ batch_op.alter_column(
+ 'dag_id', existing_type=sa.String(length=ID_LEN),
type_=string_id_col_type, nullable=False
+ )
+ batch_op.alter_column(
+ 'run_id', existing_type=sa.String(length=ID_LEN),
type_=string_id_col_type, nullable=False
+ )
+ batch_op.alter_column('execution_date', existing_type=dt_type,
nullable=False)
batch_op.drop_constraint('dag_id', 'unique')
batch_op.drop_constraint('dag_id_2', 'unique')
batch_op.create_unique_constraint(
@@ -124,16 +140,47 @@ def upgrade():
batch_op.create_unique_constraint('dag_run_dag_id_run_id_key',
['dag_id', 'run_id'])
elif dialect_name == 'mssql':
- # _Somehow_ mssql was missing these constraints entirely!
with op.batch_alter_table('dag_run') as batch_op:
+ batch_op.drop_index('idx_not_null_dag_id_execution_date')
+ batch_op.drop_index('idx_not_null_dag_id_run_id')
+
+ batch_op.drop_index('dag_id_state')
+ batch_op.drop_index('idx_dag_run_dag_id')
+ batch_op.drop_index('idx_dag_run_running_dags')
+ batch_op.drop_index('idx_dag_run_queued_dags')
+
+ batch_op.alter_column('dag_id', existing_type=string_id_col_type,
nullable=False)
+ batch_op.alter_column('execution_date', existing_type=dt_type,
nullable=False)
+ batch_op.alter_column('run_id', existing_type=string_id_col_type,
nullable=False)
+
+ # _Somehow_ mssql was missing these constraints entirely
batch_op.create_unique_constraint(
'dag_run_dag_id_execution_date_key', ['dag_id',
'execution_date']
)
batch_op.create_unique_constraint('dag_run_dag_id_run_id_key',
['dag_id', 'run_id'])
+ batch_op.create_index('dag_id_state', ['dag_id', 'state'],
unique=False)
+ batch_op.create_index('idx_dag_run_dag_id', ['dag_id'])
+ batch_op.create_index(
+ 'idx_dag_run_running_dags',
+ ["state", "dag_id"],
+ mssql_where=sa.text("state='running'"),
+ )
+ batch_op.create_index(
+ 'idx_dag_run_queued_dags',
+ ["state", "dag_id"],
+ mssql_where=sa.text("state='queued'"),
+ )
+ else:
+ # Make sure DagRun id columns are non-nullable
+ with op.batch_alter_table('dag_run', schema=None) as batch_op:
+ batch_op.alter_column('dag_id', existing_type=string_id_col_type,
nullable=False)
+ batch_op.alter_column('execution_date', existing_type=dt_type,
nullable=False)
+ batch_op.alter_column('run_id', existing_type=string_id_col_type,
nullable=False)
+
# First create column nullable
- op.add_column('task_instance', sa.Column('run_id', type_=run_id_col_type,
nullable=True))
- op.add_column('task_reschedule', sa.Column('run_id',
type_=run_id_col_type, nullable=True))
+ op.add_column('task_instance', sa.Column('run_id',
type_=string_id_col_type, nullable=True))
+ op.add_column('task_reschedule', sa.Column('run_id',
type_=string_id_col_type, nullable=True))
# Then update the new column by selecting the right value from DagRun
update_query = _multi_table_update(dialect_name, task_instance,
task_instance.c.run_id)
@@ -147,7 +194,9 @@ def upgrade():
op.execute(update_query)
with op.batch_alter_table('task_reschedule', schema=None) as batch_op:
- batch_op.alter_column('run_id', existing_type=run_id_col_type,
existing_nullable=True, nullable=False)
+ batch_op.alter_column(
+ 'run_id', existing_type=string_id_col_type,
existing_nullable=True, nullable=False
+ )
batch_op.drop_constraint('task_reschedule_dag_task_date_fkey',
'foreignkey')
if dialect_name == "mysql":
@@ -157,7 +206,14 @@ def upgrade():
with op.batch_alter_table('task_instance', schema=None) as batch_op:
# Then make it non-nullable
- batch_op.alter_column('run_id', existing_type=run_id_col_type,
existing_nullable=True, nullable=False)
+ batch_op.alter_column(
+ 'run_id', existing_type=string_id_col_type,
existing_nullable=True, nullable=False
+ )
+
+ batch_op.alter_column(
+ 'dag_id', existing_type=string_id_col_type,
existing_nullable=True, nullable=False
+ )
+ batch_op.alter_column('execution_date', existing_type=dt_type,
existing_nullable=True, nullable=False)
# TODO: Is this right for non-postgres?
if dialect_name == 'mssql':
@@ -212,14 +268,11 @@ def upgrade():
def downgrade():
"""Unapply TaskInstance keyed to DagRun"""
dialect_name = op.get_bind().dialect.name
+ dt_type = _datetime_type(dialect_name)
+ string_id_col_type = sa.String(length=ID_LEN, **COLLATION_ARGS)
- if dialect_name == "mssql":
- col_type = _mssql_datetime()
- else:
- col_type = sa.TIMESTAMP(timezone=True)
-
- op.add_column('task_instance', sa.Column('execution_date', col_type,
nullable=True))
- op.add_column('task_reschedule', sa.Column('execution_date', col_type,
nullable=True))
+ op.add_column('task_instance', sa.Column('execution_date', dt_type,
nullable=True))
+ op.add_column('task_reschedule', sa.Column('execution_date', dt_type,
nullable=True))
update_query = _multi_table_update(dialect_name, task_instance,
task_instance.c.execution_date)
op.execute(update_query)
@@ -228,9 +281,7 @@ def downgrade():
op.execute(update_query)
with op.batch_alter_table('task_reschedule', schema=None) as batch_op:
- batch_op.alter_column(
- 'execution_date', existing_type=col_type, existing_nullable=True,
nullable=False
- )
+ batch_op.alter_column('execution_date', existing_type=dt_type,
existing_nullable=True, nullable=False)
# Can't drop PK index while there is a FK referencing it
batch_op.drop_constraint('task_reschedule_ti_fkey')
@@ -238,8 +289,9 @@ def downgrade():
batch_op.drop_index('idx_task_reschedule_dag_task_run')
with op.batch_alter_table('task_instance', schema=None) as batch_op:
+ batch_op.alter_column('execution_date', existing_type=dt_type,
existing_nullable=True, nullable=False)
batch_op.alter_column(
- 'execution_date', existing_type=col_type, existing_nullable=True,
nullable=False
+ 'dag_id', existing_type=string_id_col_type,
existing_nullable=True, nullable=True
)
batch_op.drop_constraint('task_instance_pkey', type_='primary')
@@ -269,6 +321,49 @@ def downgrade():
ondelete='CASCADE',
)
+ if dialect_name == "mssql":
+
+ with op.batch_alter_table('dag_run', schema=None) as batch_op:
+ batch_op.drop_constraint('dag_run_dag_id_execution_date_key',
'unique')
+ batch_op.drop_constraint('dag_run_dag_id_run_id_key', 'unique')
+ batch_op.drop_index('dag_id_state')
+ batch_op.drop_index('idx_dag_run_running_dags')
+ batch_op.drop_index('idx_dag_run_queued_dags')
+
+ batch_op.alter_column('dag_id', existing_type=string_id_col_type,
nullable=True)
+ batch_op.alter_column('execution_date', existing_type=dt_type,
nullable=True)
+ batch_op.alter_column('run_id', existing_type=string_id_col_type,
nullable=True)
+
+ batch_op.create_index('dag_id_state', ['dag_id', 'state'],
unique=False)
+ batch_op.create_index('idx_dag_run_dag_id', ['dag_id'])
+ batch_op.create_index(
+ 'idx_dag_run_running_dags',
+ ["state", "dag_id"],
+ mssql_where=sa.text("state='running'"),
+ )
+ batch_op.create_index(
+ 'idx_dag_run_queued_dags',
+ ["state", "dag_id"],
+ mssql_where=sa.text("state='queued'"),
+ )
+ op.execute(
+ """CREATE UNIQUE NONCLUSTERED INDEX
idx_not_null_dag_id_execution_date
+ ON dag_run(dag_id,execution_date)
+ WHERE dag_id IS NOT NULL and execution_date is not null"""
+ )
+ op.execute(
+ """CREATE UNIQUE NONCLUSTERED INDEX idx_not_null_dag_id_run_id
+ ON dag_run(dag_id,run_id)
+ WHERE dag_id IS NOT NULL and run_id is not null"""
+ )
+ else:
+ with op.batch_alter_table('dag_run', schema=None) as batch_op:
+ batch_op.drop_index('dag_id_state', table_name='dag_run')
+ batch_op.alter_column('run_id',
existing_type=sa.VARCHAR(length=250), nullable=True)
+ batch_op.alter_column('execution_date', existing_type=dt_type,
nullable=True)
+ batch_op.alter_column('dag_id',
existing_type=sa.VARCHAR(length=250), nullable=True)
+ batch_op.create_index('dag_id_state', 'dag_run', ['dag_id',
'state'], unique=False)
+
def _multi_table_update(dialect_name, target, column):
condition = dag_run.c.dag_id == target.c.dag_id
diff --git a/airflow/models/dagrun.py b/airflow/models/dagrun.py
index 3d50b05..2b651c5 100644
--- a/airflow/models/dagrun.py
+++ b/airflow/models/dagrun.py
@@ -78,13 +78,13 @@ class DagRun(Base, LoggingMixin):
__NO_VALUE = object()
id = Column(Integer, primary_key=True)
- dag_id = Column(String(ID_LEN, **COLLATION_ARGS))
+ dag_id = Column(String(ID_LEN, **COLLATION_ARGS), nullable=False)
queued_at = Column(UtcDateTime)
- execution_date = Column(UtcDateTime, default=timezone.utcnow)
+ execution_date = Column(UtcDateTime, default=timezone.utcnow,
nullable=False)
start_date = Column(UtcDateTime)
end_date = Column(UtcDateTime)
_state = Column('state', String(50), default=State.QUEUED)
- run_id = Column(String(ID_LEN, **COLLATION_ARGS))
+ run_id = Column(String(ID_LEN, **COLLATION_ARGS), nullable=False)
creating_job_id = Column(Integer)
external_trigger = Column(Boolean, default=True)
run_type = Column(String(50), nullable=False)
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index 8453a1e..b178ff0 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -324,9 +324,9 @@ class TaskInstance(Base, LoggingMixin):
__tablename__ = "task_instance"
- task_id = Column(String(ID_LEN, **COLLATION_ARGS), primary_key=True)
- dag_id = Column(String(ID_LEN, **COLLATION_ARGS), primary_key=True)
- run_id = Column(String(ID_LEN, **COLLATION_ARGS), primary_key=True)
+ task_id = Column(String(ID_LEN, **COLLATION_ARGS), primary_key=True,
nullable=False)
+ dag_id = Column(String(ID_LEN, **COLLATION_ARGS), primary_key=True,
nullable=False)
+ run_id = Column(String(ID_LEN, **COLLATION_ARGS), primary_key=True,
nullable=False)
start_date = Column(UtcDateTime)
end_date = Column(UtcDateTime)
duration = Column(Float)
diff --git a/airflow/utils/db.py b/airflow/utils/db.py
index db249fe..13dd401 100644
--- a/airflow/utils/db.py
+++ b/airflow/utils/db.py
@@ -697,6 +697,37 @@ def check_conn_type_null(session=None) -> Iterable[str]:
)
+def check_run_id_null(session) -> Iterable[str]:
+ import sqlalchemy.schema
+
+ metadata = sqlalchemy.schema.MetaData(session.bind)
+ try:
+ metadata.reflect(only=["dag_run"])
+ except exc.InvalidRequestError:
+ # Table doesn't exist -- empty db
+ return
+
+ dag_run = metadata.tables["dag_run"]
+
+ for colname in ('run_id', 'dag_id', 'execution_date'):
+
+ col = dag_run.columns.get(colname)
+ if col is None:
+ continue
+
+ if not col.nullable:
+ continue
+
+ num = session.query(dag_run).filter(col.is_(None)).count()
+ if num > 0:
+ yield (
+ f'The {dag_run.name} table has {num} row{"s" if num != 1 else
""} with a NULL value in '
+ f'{col.name!r}. You must manually correct this problem
(possibly by deleting the problem '
+ 'rows).'
+ )
+ session.rollback()
+
+
def check_task_tables_without_matching_dagruns(session) -> Iterable[str]:
from itertools import chain
@@ -762,6 +793,7 @@ def _check_migration_errors(session=None) -> Iterable[str]:
for check_fn in (
check_conn_id_duplicates,
check_conn_type_null,
+ check_run_id_null,
check_task_tables_without_matching_dagruns,
):
yield from check_fn(session)
diff --git a/tests/api_connexion/schemas/test_dag_run_schema.py
b/tests/api_connexion/schemas/test_dag_run_schema.py
index ba5acae..6f42ec0 100644
--- a/tests/api_connexion/schemas/test_dag_run_schema.py
+++ b/tests/api_connexion/schemas/test_dag_run_schema.py
@@ -51,6 +51,7 @@ class TestDAGRunSchema(TestDAGRunBase):
@provide_session
def test_serialize(self, session):
dagrun_model = DagRun(
+ dag_id="my-dag-run",
run_id="my-dag-run",
state='running',
run_type=DagRunType.MANUAL.value,
@@ -64,7 +65,7 @@ class TestDAGRunSchema(TestDAGRunBase):
deserialized_dagrun = dagrun_schema.dump(dagrun_model)
assert deserialized_dagrun == {
- "dag_id": None,
+ "dag_id": "my-dag-run",
"dag_run_id": "my-dag-run",
"end_date": None,
"state": "running",
@@ -128,6 +129,7 @@ class TestDagRunCollection(TestDAGRunBase):
@provide_session
def test_serialize(self, session):
dagrun_model_1 = DagRun(
+ dag_id="my-dag-run",
run_id="my-dag-run",
state='running',
execution_date=timezone.parse(self.default_time),
@@ -136,6 +138,7 @@ class TestDagRunCollection(TestDAGRunBase):
conf='{"start": "stop"}',
)
dagrun_model_2 = DagRun(
+ dag_id="my-dag-run",
run_id="my-dag-run-2",
state='running',
execution_date=timezone.parse(self.second_time),
@@ -150,7 +153,7 @@ class TestDagRunCollection(TestDAGRunBase):
assert deserialized_dagruns == {
"dag_runs": [
{
- "dag_id": None,
+ "dag_id": "my-dag-run",
"dag_run_id": "my-dag-run",
"end_date": None,
"execution_date": self.default_time,
@@ -161,7 +164,7 @@ class TestDagRunCollection(TestDAGRunBase):
"conf": {"start": "stop"},
},
{
- "dag_id": None,
+ "dag_id": "my-dag-run",
"dag_run_id": "my-dag-run-2",
"end_date": None,
"state": "running",
diff --git a/tests/models/test_dagrun.py b/tests/models/test_dagrun.py
index 3622603..c4ef287 100644
--- a/tests/models/test_dagrun.py
+++ b/tests/models/test_dagrun.py
@@ -114,6 +114,7 @@ class TestDagRun(unittest.TestCase):
dag_id1 = "test_dagrun_find_externally_triggered"
dag_run = models.DagRun(
dag_id=dag_id1,
+ run_id=dag_id1,
run_type=DagRunType.MANUAL,
execution_date=now,
start_date=now,
@@ -125,6 +126,7 @@ class TestDagRun(unittest.TestCase):
dag_id2 = "test_dagrun_find_not_externally_triggered"
dag_run = models.DagRun(
dag_id=dag_id2,
+ run_id=dag_id2,
run_type=DagRunType.MANUAL,
execution_date=now,
start_date=now,
@@ -532,6 +534,7 @@ class TestDagRun(unittest.TestCase):
# don't want
dag_run = models.DagRun(
dag_id=dag.dag_id,
+ run_id="test_get_task_instance_on_empty_dagrun",
run_type=DagRunType.MANUAL,
execution_date=now,
start_date=now,