This is an automated email from the ASF dual-hosted git repository.
ephraimanierobi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 25537acfa2 Have consistent types between the ORM and the migration
files (#24044)
25537acfa2 is described below
commit 25537acfa28eebc82a90274840e0e6fb5c91e271
Author: Ephraim Anierobi <[email protected]>
AuthorDate: Fri Jun 24 16:51:01 2022 +0100
Have consistent types between the ORM and the migration files (#24044)
We currently don't compare column types between the ORM and the migration
files. Some columns in the migration files have different types from the same
columns in the ORM.
Here, I made an effort to match the types in the migration files with the
types in the ORM, using the migration files as the source of truth in most
cases.
I couldn't convert the MySQL VARCHAR collation in the DB (utf8_bin) to use the
one in the ORM (utf8mb3_bin). It seems it's not possible to change the collation
of an already existing column in MySQL.
---
airflow/migrations/env.py | 9 +-
.../versions/0112_2_4_0_add_dagwarning_model.py | 2 +-
.../0113_2_4_0_compare_types_between_orm_and_db.py | 253 +++++++++++++++++++++
airflow/models/connection.py | 2 +-
airflow/models/dag.py | 2 +-
airflow/models/dagcode.py | 3 +-
airflow/models/dagpickle.py | 4 +-
airflow/models/dagrun.py | 6 +-
airflow/models/dagwarning.py | 4 +-
airflow/models/log.py | 4 +-
airflow/models/renderedtifields.py | 4 +-
airflow/models/serialized_dag.py | 2 +-
airflow/models/taskfail.py | 4 +-
airflow/models/taskinstance.py | 23 +-
airflow/models/variable.py | 3 +-
airflow/models/xcom.py | 4 +-
airflow/utils/db.py | 55 +++++
airflow/utils/sqlalchemy.py | 20 +-
airflow/www/fab_security/sqla/models.py | 60 +----
docs/apache-airflow/migrations-ref.rst | 4 +-
tests/utils/test_db.py | 15 +-
21 files changed, 389 insertions(+), 94 deletions(-)
diff --git a/airflow/migrations/env.py b/airflow/migrations/env.py
index 58a8f7f4a0..039d7cf23f 100644
--- a/airflow/migrations/env.py
+++ b/airflow/migrations/env.py
@@ -21,6 +21,7 @@ from logging.config import fileConfig
from alembic import context
from airflow import models, settings
+from airflow.utils.db import compare_server_default, compare_type
def include_object(_, name, type_, *args):
@@ -51,8 +52,6 @@ target_metadata = models.base.Base.metadata
# my_important_option = config.get_main_option("my_important_option")
# ... etc.
-COMPARE_TYPE = False
-
def run_migrations_offline():
"""Run migrations in 'offline' mode.
@@ -70,7 +69,8 @@ def run_migrations_offline():
url=settings.SQL_ALCHEMY_CONN,
target_metadata=target_metadata,
literal_binds=True,
- compare_type=COMPARE_TYPE,
+ compare_type=compare_type,
+ compare_server_default=compare_server_default,
render_as_batch=True,
)
@@ -92,7 +92,8 @@ def run_migrations_online():
connection=connection,
transaction_per_migration=True,
target_metadata=target_metadata,
- compare_type=COMPARE_TYPE,
+ compare_type=compare_type,
+ compare_server_default=compare_server_default,
include_object=include_object,
render_as_batch=True,
)
diff --git a/airflow/migrations/versions/0112_2_4_0_add_dagwarning_model.py
b/airflow/migrations/versions/0112_2_4_0_add_dagwarning_model.py
index cb7d089871..3d03210e53 100644
--- a/airflow/migrations/versions/0112_2_4_0_add_dagwarning_model.py
+++ b/airflow/migrations/versions/0112_2_4_0_add_dagwarning_model.py
@@ -44,7 +44,7 @@ def upgrade():
'dag_warning',
sa.Column('dag_id', StringID(), primary_key=True),
sa.Column('warning_type', sa.String(length=50), primary_key=True),
- sa.Column('message', sa.String(1000), nullable=False),
+ sa.Column('message', sa.Text(), nullable=False),
sa.Column('timestamp', TIMESTAMP, nullable=False),
sa.ForeignKeyConstraint(
('dag_id',),
diff --git
a/airflow/migrations/versions/0113_2_4_0_compare_types_between_orm_and_db.py
b/airflow/migrations/versions/0113_2_4_0_compare_types_between_orm_and_db.py
new file mode 100644
index 0000000000..68f02dd81e
--- /dev/null
+++ b/airflow/migrations/versions/0113_2_4_0_compare_types_between_orm_and_db.py
@@ -0,0 +1,253 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""compare types between ORM and DB
+
+Revision ID: 44b7034f6bdc
+Revises: 424117c37d18
+Create Date: 2022-05-31 09:16:44.558754
+
+"""
+
+import sqlalchemy as sa
+from alembic import op
+
+from airflow.migrations.db_types import TIMESTAMP
+
+# revision identifiers, used by Alembic.
+revision = '44b7034f6bdc'
+down_revision = '424117c37d18'
+branch_labels = None
+depends_on = None
+airflow_version = '2.4.0'
+
+
+def upgrade():
+ """Apply compare types between ORM and DB"""
+ conn = op.get_bind()
+ with op.batch_alter_table('connection', schema=None) as batch_op:
+ batch_op.alter_column(
+ 'extra',
+ existing_type=sa.TEXT(),
+ type_=sa.Text(),
+ existing_nullable=True,
+ )
+ with op.batch_alter_table('log_template', schema=None) as batch_op:
+ batch_op.alter_column(
+ 'created_at', existing_type=sa.DateTime(), type_=TIMESTAMP(),
existing_nullable=False
+ )
+
+ with op.batch_alter_table('serialized_dag', schema=None) as batch_op:
+ # drop server_default
+ batch_op.alter_column(
+ 'dag_hash',
+ existing_type=sa.String(32),
+ server_default=None,
+ type_=sa.String(32),
+ existing_nullable=False,
+ )
+ with op.batch_alter_table('trigger', schema=None) as batch_op:
+ batch_op.alter_column(
+ 'created_date', existing_type=sa.DateTime(), type_=TIMESTAMP(),
existing_nullable=False
+ )
+
+ if conn.dialect.name != 'sqlite':
+ return
+ with op.batch_alter_table('serialized_dag', schema=None) as batch_op:
+ batch_op.alter_column('fileloc_hash', existing_type=sa.Integer,
type_=sa.BigInteger())
+ # Some sqlite date are not in db_types.TIMESTAMP. Convert these to
TIMESTAMP.
+ with op.batch_alter_table('dag', schema=None) as batch_op:
+ batch_op.alter_column(
+ 'last_pickled', existing_type=sa.DATETIME(), type_=TIMESTAMP(),
existing_nullable=True
+ )
+ batch_op.alter_column(
+ 'last_expired', existing_type=sa.DATETIME(), type_=TIMESTAMP(),
existing_nullable=True
+ )
+
+ with op.batch_alter_table('dag_pickle', schema=None) as batch_op:
+ batch_op.alter_column(
+ 'created_dttm', existing_type=sa.DATETIME(), type_=TIMESTAMP(),
existing_nullable=True
+ )
+
+ with op.batch_alter_table('dag_run', schema=None) as batch_op:
+ batch_op.alter_column(
+ 'execution_date', existing_type=sa.DATETIME(), type_=TIMESTAMP(),
existing_nullable=False
+ )
+ batch_op.alter_column(
+ 'start_date', existing_type=sa.DATETIME(), type_=TIMESTAMP(),
existing_nullable=True
+ )
+ batch_op.alter_column(
+ 'end_date', existing_type=sa.DATETIME(), type_=TIMESTAMP(),
existing_nullable=True
+ )
+
+ with op.batch_alter_table('import_error', schema=None) as batch_op:
+ batch_op.alter_column(
+ 'timestamp', existing_type=sa.DATETIME(), type_=TIMESTAMP(),
existing_nullable=True
+ )
+
+ with op.batch_alter_table('job', schema=None) as batch_op:
+ batch_op.alter_column(
+ 'start_date', existing_type=sa.DATETIME(), type_=TIMESTAMP(),
existing_nullable=True
+ )
+ batch_op.alter_column(
+ 'end_date', existing_type=sa.DATETIME(), type_=TIMESTAMP(),
existing_nullable=True
+ )
+ batch_op.alter_column(
+ 'latest_heartbeat', existing_type=sa.DATETIME(),
type_=TIMESTAMP(), existing_nullable=True
+ )
+
+ with op.batch_alter_table('log', schema=None) as batch_op:
+ batch_op.alter_column('dttm', existing_type=sa.DATETIME(),
type_=TIMESTAMP(), existing_nullable=True)
+ batch_op.alter_column(
+ 'execution_date', existing_type=sa.DATETIME(), type_=TIMESTAMP(),
existing_nullable=True
+ )
+
+ with op.batch_alter_table('serialized_dag', schema=None) as batch_op:
+ batch_op.alter_column(
+ 'last_updated', existing_type=sa.DATETIME(), type_=TIMESTAMP(),
existing_nullable=False
+ )
+
+ with op.batch_alter_table('sla_miss', schema=None) as batch_op:
+ batch_op.alter_column(
+ 'execution_date', existing_type=sa.DATETIME(), type_=TIMESTAMP(),
existing_nullable=False
+ )
+ batch_op.alter_column(
+ 'timestamp', existing_type=sa.DATETIME(), type_=TIMESTAMP(),
existing_nullable=True
+ )
+
+ with op.batch_alter_table('task_fail', schema=None) as batch_op:
+ batch_op.alter_column(
+ 'start_date', existing_type=sa.DATETIME(), type_=TIMESTAMP(),
existing_nullable=True
+ )
+ batch_op.alter_column(
+ 'end_date', existing_type=sa.DATETIME(), type_=TIMESTAMP(),
existing_nullable=True
+ )
+
+ with op.batch_alter_table('task_instance', schema=None) as batch_op:
+ batch_op.alter_column(
+ 'start_date', existing_type=sa.DATETIME(), type_=TIMESTAMP(),
existing_nullable=True
+ )
+ batch_op.alter_column(
+ 'end_date', existing_type=sa.DATETIME(), type_=TIMESTAMP(),
existing_nullable=True
+ )
+ batch_op.alter_column(
+ 'queued_dttm', existing_type=sa.DATETIME(), type_=TIMESTAMP(),
existing_nullable=True
+ )
+
+
+def downgrade():
+ """Unapply compare types between ORM and DB"""
+ with op.batch_alter_table('connection', schema=None) as batch_op:
+ batch_op.alter_column(
+ 'extra',
+ existing_type=sa.Text(),
+ type_=sa.TEXT(),
+ existing_nullable=True,
+ )
+ with op.batch_alter_table('log_template', schema=None) as batch_op:
+ batch_op.alter_column(
+ 'created_at', existing_type=TIMESTAMP(), type_=sa.DateTime(),
existing_nullable=False
+ )
+ with op.batch_alter_table('trigger', schema=None) as batch_op:
+ batch_op.alter_column(
+ 'created_date', existing_type=TIMESTAMP(), type_=sa.DateTime(),
existing_nullable=False
+ )
+ conn = op.get_bind()
+
+ if conn.dialect.name != 'sqlite':
+ return
+ with op.batch_alter_table('serialized_dag', schema=None) as batch_op:
+ batch_op.alter_column('fileloc_hash', existing_type=sa.BigInteger,
type_=sa.Integer())
+ # Change these column back to sa.DATETIME()
+ with op.batch_alter_table('task_instance', schema=None) as batch_op:
+ batch_op.alter_column(
+ 'queued_dttm', existing_type=TIMESTAMP(), type_=sa.DATETIME(),
existing_nullable=True
+ )
+ batch_op.alter_column(
+ 'end_date', existing_type=TIMESTAMP(), type_=sa.DATETIME(),
existing_nullable=True
+ )
+ batch_op.alter_column(
+ 'start_date', existing_type=TIMESTAMP(), type_=sa.DATETIME(),
existing_nullable=True
+ )
+
+ with op.batch_alter_table('task_fail', schema=None) as batch_op:
+ batch_op.alter_column(
+ 'end_date', existing_type=TIMESTAMP(), type_=sa.DATETIME(),
existing_nullable=True
+ )
+ batch_op.alter_column(
+ 'start_date', existing_type=TIMESTAMP(), type_=sa.DATETIME(),
existing_nullable=True
+ )
+
+ with op.batch_alter_table('sla_miss', schema=None) as batch_op:
+ batch_op.alter_column(
+ 'timestamp', existing_type=TIMESTAMP(), type_=sa.DATETIME(),
existing_nullable=True
+ )
+ batch_op.alter_column(
+ 'execution_date', existing_type=TIMESTAMP(), type_=sa.DATETIME(),
existing_nullable=False
+ )
+
+ with op.batch_alter_table('serialized_dag', schema=None) as batch_op:
+ batch_op.alter_column(
+ 'last_updated', existing_type=TIMESTAMP(), type_=sa.DATETIME(),
existing_nullable=False
+ )
+
+ with op.batch_alter_table('log', schema=None) as batch_op:
+ batch_op.alter_column(
+ 'execution_date', existing_type=TIMESTAMP(), type_=sa.DATETIME(),
existing_nullable=True
+ )
+ batch_op.alter_column('dttm', existing_type=TIMESTAMP(),
type_=sa.DATETIME(), existing_nullable=True)
+
+ with op.batch_alter_table('job', schema=None) as batch_op:
+ batch_op.alter_column(
+ 'latest_heartbeat', existing_type=TIMESTAMP(),
type_=sa.DATETIME(), existing_nullable=True
+ )
+ batch_op.alter_column(
+ 'end_date', existing_type=TIMESTAMP(), type_=sa.DATETIME(),
existing_nullable=True
+ )
+ batch_op.alter_column(
+ 'start_date', existing_type=TIMESTAMP(), type_=sa.DATETIME(),
existing_nullable=True
+ )
+
+ with op.batch_alter_table('import_error', schema=None) as batch_op:
+ batch_op.alter_column(
+ 'timestamp', existing_type=TIMESTAMP(), type_=sa.DATETIME(),
existing_nullable=True
+ )
+
+ with op.batch_alter_table('dag_run', schema=None) as batch_op:
+ batch_op.alter_column(
+ 'end_date', existing_type=TIMESTAMP(), type_=sa.DATETIME(),
existing_nullable=True
+ )
+ batch_op.alter_column(
+ 'start_date', existing_type=TIMESTAMP(), type_=sa.DATETIME(),
existing_nullable=True
+ )
+ batch_op.alter_column(
+ 'execution_date', existing_type=TIMESTAMP(), type_=sa.DATETIME(),
existing_nullable=False
+ )
+
+ with op.batch_alter_table('dag_pickle', schema=None) as batch_op:
+ batch_op.alter_column(
+ 'created_dttm', existing_type=TIMESTAMP(), type_=sa.DATETIME(),
existing_nullable=True
+ )
+
+ with op.batch_alter_table('dag', schema=None) as batch_op:
+ batch_op.alter_column(
+ 'last_expired', existing_type=TIMESTAMP(), type_=sa.DATETIME(),
existing_nullable=True
+ )
+ batch_op.alter_column(
+ 'last_pickled', existing_type=TIMESTAMP(), type_=sa.DATETIME(),
existing_nullable=True
+ )
diff --git a/airflow/models/connection.py b/airflow/models/connection.py
index b1219907f1..20059907c6 100644
--- a/airflow/models/connection.py
+++ b/airflow/models/connection.py
@@ -90,7 +90,7 @@ class Connection(Base, LoggingMixin):
id = Column(Integer(), primary_key=True)
conn_id = Column(String(ID_LEN), unique=True, nullable=False)
conn_type = Column(String(500), nullable=False)
- description = Column(Text(5000))
+ description = Column(Text().with_variant(Text(5000),
'mysql').with_variant(String(5000), 'sqlite'))
host = Column(String(500))
schema = Column(String(500))
login = Column(String(500))
diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index ebe7364d51..b16126e569 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -2739,7 +2739,7 @@ class DagModel(Base):
max_active_runs = Column(Integer, nullable=True)
has_task_concurrency_limits = Column(Boolean, nullable=False)
- has_import_errors = Column(Boolean(), default=False)
+ has_import_errors = Column(Boolean(), default=False, server_default='0')
# The logical date of the next dag run.
next_dagrun = Column(UtcDateTime)
diff --git a/airflow/models/dagcode.py b/airflow/models/dagcode.py
index 7322ba92fb..14a24823f6 100644
--- a/airflow/models/dagcode.py
+++ b/airflow/models/dagcode.py
@@ -21,6 +21,7 @@ from datetime import datetime
from typing import Iterable, List, Optional
from sqlalchemy import BigInteger, Column, String, Text
+from sqlalchemy.dialects.mysql import MEDIUMTEXT
from sqlalchemy.sql.expression import literal
from airflow.exceptions import AirflowException, DagCodeNotFound
@@ -47,7 +48,7 @@ class DagCode(Base):
fileloc = Column(String(2000), nullable=False)
# The max length of fileloc exceeds the limit of indexing.
last_updated = Column(UtcDateTime, nullable=False)
- source_code = Column(Text, nullable=False)
+ source_code = Column(Text().with_variant(MEDIUMTEXT(), 'mysql'),
nullable=False)
def __init__(self, full_filepath: str, source_code: Optional[str] = None):
self.fileloc = full_filepath
diff --git a/airflow/models/dagpickle.py b/airflow/models/dagpickle.py
index aa56ce3e58..e5e1e08f85 100644
--- a/airflow/models/dagpickle.py
+++ b/airflow/models/dagpickle.py
@@ -17,7 +17,7 @@
# under the License.
import dill
-from sqlalchemy import Column, Integer, PickleType, Text
+from sqlalchemy import BigInteger, Column, Integer, PickleType
from airflow.models.base import Base
from airflow.utils import timezone
@@ -39,7 +39,7 @@ class DagPickle(Base):
id = Column(Integer, primary_key=True)
pickle = Column(PickleType(pickler=dill))
created_dttm = Column(UtcDateTime, default=timezone.utcnow)
- pickle_hash = Column(Text)
+ pickle_hash = Column(BigInteger)
__tablename__ = "dag_pickle"
diff --git a/airflow/models/dagrun.py b/airflow/models/dagrun.py
index 05713b96c7..0255e00549 100644
--- a/airflow/models/dagrun.py
+++ b/airflow/models/dagrun.py
@@ -61,7 +61,7 @@ from airflow import settings
from airflow.callbacks.callback_requests import DagCallbackRequest
from airflow.configuration import conf as airflow_conf
from airflow.exceptions import AirflowException, TaskNotFound
-from airflow.models.base import COLLATION_ARGS, ID_LEN, Base
+from airflow.models.base import Base, StringID
from airflow.models.mappedoperator import MappedOperator
from airflow.models.taskinstance import TaskInstance as TI
from airflow.models.tasklog import LogTemplate
@@ -100,13 +100,13 @@ class DagRun(Base, LoggingMixin):
__tablename__ = "dag_run"
id = Column(Integer, primary_key=True)
- dag_id = Column(String(ID_LEN, **COLLATION_ARGS), nullable=False)
+ dag_id = Column(StringID(), nullable=False)
queued_at = Column(UtcDateTime)
execution_date = Column(UtcDateTime, default=timezone.utcnow,
nullable=False)
start_date = Column(UtcDateTime)
end_date = Column(UtcDateTime)
_state = Column('state', String(50), default=State.QUEUED)
- run_id = Column(String(ID_LEN, **COLLATION_ARGS), nullable=False)
+ run_id = Column(StringID(), nullable=False)
creating_job_id = Column(Integer)
external_trigger = Column(Boolean, default=True)
run_type = Column(String(50), nullable=False)
diff --git a/airflow/models/dagwarning.py b/airflow/models/dagwarning.py
index aa5dc5ef54..68745ed645 100644
--- a/airflow/models/dagwarning.py
+++ b/airflow/models/dagwarning.py
@@ -34,8 +34,8 @@ class DagWarning(Base):
when parsing DAG and displayed on the Webserver in a flash message.
"""
- dag_id = Column(StringID(), primary_key=True, nullable=False)
- warning_type = Column(String(50), primary_key=True, nullable=False)
+ dag_id = Column(StringID(), primary_key=True)
+ warning_type = Column(String(50), primary_key=True)
message = Column(Text, nullable=False)
timestamp = Column(UtcDateTime, nullable=False, default=timezone.utcnow)
diff --git a/airflow/models/log.py b/airflow/models/log.py
index b2a5639dcd..d3ba41a071 100644
--- a/airflow/models/log.py
+++ b/airflow/models/log.py
@@ -16,7 +16,7 @@
# specific language governing permissions and limitations
# under the License.
-from sqlalchemy import Column, Index, Integer, String, Text, text
+from sqlalchemy import Column, Index, Integer, String, Text
from airflow.models.base import Base, StringID
from airflow.utils import timezone
@@ -32,7 +32,7 @@ class Log(Base):
dttm = Column(UtcDateTime)
dag_id = Column(StringID())
task_id = Column(StringID())
- map_index = Column(Integer, server_default=text('NULL'))
+ map_index = Column(Integer)
event = Column(String(30))
execution_date = Column(UtcDateTime)
owner = Column(String(500))
diff --git a/airflow/models/renderedtifields.py
b/airflow/models/renderedtifields.py
index c7bad78b5f..f1b826c1bc 100644
--- a/airflow/models/renderedtifields.py
+++ b/airflow/models/renderedtifields.py
@@ -20,7 +20,7 @@ import os
from typing import Optional
import sqlalchemy_jsonfield
-from sqlalchemy import Column, ForeignKeyConstraint, Integer, and_, not_,
tuple_
+from sqlalchemy import Column, ForeignKeyConstraint, Integer, and_, not_,
text, tuple_
from sqlalchemy.ext.associationproxy import association_proxy
from sqlalchemy.orm import Session, relationship
@@ -41,7 +41,7 @@ class RenderedTaskInstanceFields(Base):
dag_id = Column(StringID(), primary_key=True)
task_id = Column(StringID(), primary_key=True)
run_id = Column(StringID(), primary_key=True)
- map_index = Column(Integer, primary_key=True, server_default='-1')
+ map_index = Column(Integer, primary_key=True, server_default=text('-1'))
rendered_fields = Column(sqlalchemy_jsonfield.JSONField(json=json),
nullable=False)
k8s_pod_yaml = Column(sqlalchemy_jsonfield.JSONField(json=json),
nullable=True)
diff --git a/airflow/models/serialized_dag.py b/airflow/models/serialized_dag.py
index 45424709a8..9114af987c 100644
--- a/airflow/models/serialized_dag.py
+++ b/airflow/models/serialized_dag.py
@@ -67,7 +67,7 @@ class SerializedDagModel(Base):
dag_id = Column(String(ID_LEN), primary_key=True)
fileloc = Column(String(2000), nullable=False)
# The max length of fileloc exceeds the limit of indexing.
- fileloc_hash = Column(BigInteger, nullable=False)
+ fileloc_hash = Column(BigInteger(), nullable=False)
_data = Column('data', sqlalchemy_jsonfield.JSONField(json=json),
nullable=True)
_data_compressed = Column('data_compressed', LargeBinary, nullable=True)
last_updated = Column(UtcDateTime, nullable=False)
diff --git a/airflow/models/taskfail.py b/airflow/models/taskfail.py
index b5f23d8ec5..ee53bde2a4 100644
--- a/airflow/models/taskfail.py
+++ b/airflow/models/taskfail.py
@@ -17,7 +17,7 @@
# under the License.
"""Taskfail tracks the failed run durations of each task instance"""
-from sqlalchemy import Column, ForeignKeyConstraint, Index, Integer
+from sqlalchemy import Column, ForeignKeyConstraint, Index, Integer, text
from sqlalchemy.orm import relationship
from airflow.models.base import Base, StringID
@@ -33,7 +33,7 @@ class TaskFail(Base):
task_id = Column(StringID(), nullable=False)
dag_id = Column(StringID(), nullable=False)
run_id = Column(StringID(), nullable=False)
- map_index = Column(Integer, nullable=False)
+ map_index = Column(Integer, nullable=False, server_default=text('-1'))
start_date = Column(UtcDateTime)
end_date = Column(UtcDateTime)
duration = Column(Integer)
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index 3fc05d58b8..cade4c7ad4 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -53,6 +53,7 @@ import pendulum
from jinja2 import TemplateAssertionError, UndefinedError
from sqlalchemy import (
Column,
+ DateTime,
Float,
ForeignKeyConstraint,
Index,
@@ -75,7 +76,6 @@ from sqlalchemy.orm.query import Query
from sqlalchemy.orm.session import Session
from sqlalchemy.sql.elements import BooleanClauseList
from sqlalchemy.sql.expression import ColumnOperators
-from sqlalchemy.sql.sqltypes import BigInteger
from airflow import settings
from airflow.compat.functools import cache
@@ -95,7 +95,7 @@ from airflow.exceptions import (
UnmappableXComTypePushed,
XComForMappingNotPushed,
)
-from airflow.models.base import COLLATION_ARGS, ID_LEN, Base
+from airflow.models.base import Base, StringID
from airflow.models.log import Log
from airflow.models.param import ParamsDict
from airflow.models.taskfail import TaskFail
@@ -434,9 +434,9 @@ class TaskInstance(Base, LoggingMixin):
__tablename__ = "task_instance"
- task_id = Column(String(ID_LEN, **COLLATION_ARGS), primary_key=True,
nullable=False)
- dag_id = Column(String(ID_LEN, **COLLATION_ARGS), primary_key=True,
nullable=False)
- run_id = Column(String(ID_LEN, **COLLATION_ARGS), primary_key=True,
nullable=False)
+ task_id = Column(StringID(), primary_key=True, nullable=False)
+ dag_id = Column(StringID(), primary_key=True, nullable=False)
+ run_id = Column(StringID(), primary_key=True, nullable=False)
map_index = Column(Integer, primary_key=True, nullable=False,
server_default=text("-1"))
start_date = Column(UtcDateTime)
@@ -444,12 +444,12 @@ class TaskInstance(Base, LoggingMixin):
duration = Column(Float)
state = Column(String(20))
_try_number = Column('try_number', Integer, default=0)
- max_tries = Column(Integer)
+ max_tries = Column(Integer, server_default=text("-1"))
hostname = Column(String(1000))
unixname = Column(String(1000))
job_id = Column(Integer)
pool = Column(String(256), nullable=False)
- pool_slots = Column(Integer, default=1, nullable=False,
server_default=text("1"))
+ pool_slots = Column(Integer, default=1, nullable=False)
queue = Column(String(256))
priority_weight = Column(Integer)
operator = Column(String(1000))
@@ -458,13 +458,16 @@ class TaskInstance(Base, LoggingMixin):
pid = Column(Integer)
executor_config = Column(PickleType(pickler=dill,
comparator=_executor_config_comparator))
- external_executor_id = Column(String(ID_LEN, **COLLATION_ARGS))
+ external_executor_id = Column(StringID())
# The trigger to resume on if we are in state DEFERRED
- trigger_id = Column(BigInteger)
+ trigger_id = Column(Integer)
# Optional timeout datetime for the trigger (past this, we'll fail)
- trigger_timeout = Column(UtcDateTime)
+ trigger_timeout = Column(DateTime)
+ # The trigger_timeout should be TIMESTAMP(using UtcDateTime) but for ease
of
+ # migration, we are keeping it as DateTime pending a change where expensive
+ # migration is inevitable.
# The method to call next, and any extra arguments to pass to it.
# Usually used when resuming from DEFERRED.
diff --git a/airflow/models/variable.py b/airflow/models/variable.py
index 0904ddb23e..82253eddbc 100644
--- a/airflow/models/variable.py
+++ b/airflow/models/variable.py
@@ -22,6 +22,7 @@ from typing import Any, Optional
from cryptography.fernet import InvalidToken as InvalidFernetToken
from sqlalchemy import Boolean, Column, Integer, String, Text
+from sqlalchemy.dialects.mysql import MEDIUMTEXT
from sqlalchemy.ext.declarative import declared_attr
from sqlalchemy.orm import Session, reconstructor, synonym
@@ -47,7 +48,7 @@ class Variable(Base, LoggingMixin):
id = Column(Integer, primary_key=True)
key = Column(String(ID_LEN), unique=True)
- _val = Column('val', Text)
+ _val = Column('val', Text().with_variant(MEDIUMTEXT, 'mysql'))
description = Column(Text)
is_encrypted = Column(Boolean, unique=False, default=False)
diff --git a/airflow/models/xcom.py b/airflow/models/xcom.py
index aad720bd8b..146e67269f 100644
--- a/airflow/models/xcom.py
+++ b/airflow/models/xcom.py
@@ -26,7 +26,7 @@ from functools import wraps
from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Type, Union,
cast, overload
import pendulum
-from sqlalchemy import Column, ForeignKeyConstraint, Index, Integer,
LargeBinary, String
+from sqlalchemy import Column, ForeignKeyConstraint, Index, Integer,
LargeBinary, String, text
from sqlalchemy.ext.associationproxy import association_proxy
from sqlalchemy.orm import Query, Session, reconstructor, relationship
from sqlalchemy.orm.exc import NoResultFound
@@ -57,7 +57,7 @@ class BaseXCom(Base, LoggingMixin):
dag_run_id = Column(Integer(), nullable=False, primary_key=True)
task_id = Column(String(ID_LEN, **COLLATION_ARGS), nullable=False,
primary_key=True)
- map_index = Column(Integer, primary_key=True, nullable=False,
server_default="-1")
+ map_index = Column(Integer, primary_key=True, nullable=False,
server_default=text("-1"))
key = Column(String(512, **COLLATION_ARGS), nullable=False,
primary_key=True)
# Denormalized for easier lookup.
diff --git a/airflow/utils/db.py b/airflow/utils/db.py
index bb2c2a6bae..fecd7d7846 100644
--- a/airflow/utils/db.py
+++ b/airflow/utils/db.py
@@ -1675,6 +1675,61 @@ def create_global_lock(
pass
+def compare_type(context, inspected_column, metadata_column, inspected_type,
metadata_type):
+ """
+ Compare types between ORM and DB .
+
+ return False if the metadata_type is the same as the inspected_type
+ or None to allow the default implementation to compare these
+ types. a return value of True means the two types do not
+ match and should result in a type change operation.
+ """
+ if context.dialect.name == 'mysql':
+ from sqlalchemy import String
+ from sqlalchemy.dialects import mysql
+
+ if isinstance(inspected_type, mysql.VARCHAR) and
isinstance(metadata_type, String):
+ # This is a hack to get around MySQL VARCHAR collation
+ # not being possible to change from utf8_bin to utf8mb3_bin
+ return False
+ return None
+
+
+def compare_server_default(
+ context, inspected_column, metadata_column, inspected_default,
metadata_default, rendered_metadata_default
+):
+ """
+ Compare server defaults between ORM and DB .
+
+ return True if the defaults are different, False if not, or None to allow
the default implementation
+ to compare these defaults
+
+ Comparing server_default is not accurate in MSSQL because the
+ inspected_default above != metadata_default, while in Postgres/MySQL they
are equal.
+ This is an issue with alembic
+ In SQLite: task_instance.map_index & task_reschedule.map_index
+ are not comparing accurately. Sometimes they are equal, sometimes they are
not.
+ Alembic warned that this feature has varied accuracy depending on backends.
+ See:
(https://alembic.sqlalchemy.org/en/latest/api/runtime.html#alembic.runtime.
+ environment.EnvironmentContext.configure.params.compare_server_default)
+ """
+ dialect_name = context.connection.dialect.name
+ if dialect_name in ['mssql', 'sqlite']:
+ return False
+ if (
+ dialect_name == 'mysql'
+ and metadata_column.name == 'pool_slots'
+ and metadata_column.table.name == 'task_instance'
+ ):
+ # We removed server_default value in ORM to avoid expensive migration
+ # (it was removed in postgres DB in migration head 7b2661a43ba3 ).
+ # As a side note, server default value here was only actually needed
for the migration
+ # where we added the column in the first place -- now that it exists
and all
+ # existing rows are populated with a value this server default is
never used.
+ return False
+ return None
+
+
def get_sqla_model_classes():
"""
Get all SQLAlchemy class mappers.
diff --git a/airflow/utils/sqlalchemy.py b/airflow/utils/sqlalchemy.py
index 6d751d106a..5dd47d157b 100644
--- a/airflow/utils/sqlalchemy.py
+++ b/airflow/utils/sqlalchemy.py
@@ -23,12 +23,13 @@ from typing import Any, Dict, Iterable, Tuple
import pendulum
from dateutil import relativedelta
-from sqlalchemy import and_, event, false, nullsfirst, or_, tuple_
+from sqlalchemy import TIMESTAMP, and_, event, false, nullsfirst, or_, tuple_
+from sqlalchemy.dialects import mssql, mysql
from sqlalchemy.exc import OperationalError
from sqlalchemy.orm.session import Session
from sqlalchemy.sql import ColumnElement
from sqlalchemy.sql.expression import ColumnOperators
-from sqlalchemy.types import JSON, DateTime, Text, TypeDecorator, TypeEngine,
UnicodeText
+from sqlalchemy.types import JSON, Text, TypeDecorator, TypeEngine, UnicodeText
from airflow import settings
from airflow.configuration import conf
@@ -42,21 +43,21 @@ using_mysql = conf.get_mandatory_value('database',
'sql_alchemy_conn').lower().s
class UtcDateTime(TypeDecorator):
"""
- Almost equivalent to :class:`~sqlalchemy.types.DateTime` with
+ Almost equivalent to :class:`~sqlalchemy.types.TIMESTAMP` with
``timezone=True`` option, but it differs from that by:
- Never silently take naive :class:`~datetime.datetime`, instead it
always raise :exc:`ValueError` unless time zone aware value.
- :class:`~datetime.datetime` value's :attr:`~datetime.datetime.tzinfo`
is always converted to UTC.
- - Unlike SQLAlchemy's built-in :class:`~sqlalchemy.types.DateTime`,
+ - Unlike SQLAlchemy's built-in :class:`~sqlalchemy.types.TIMESTAMP`,
it never return naive :class:`~datetime.datetime`, but time zone
aware value, even with SQLite or MySQL.
- - Always returns DateTime in UTC
+ - Always returns TIMESTAMP in UTC
"""
- impl = DateTime(timezone=True)
+ impl = TIMESTAMP(timezone=True)
def process_bind_param(self, value, dialect):
if value is not None:
@@ -92,6 +93,13 @@ class UtcDateTime(TypeDecorator):
return value
+ def load_dialect_impl(self, dialect):
+ if dialect.name == 'mssql':
+ return mssql.DATETIME2(precision=6)
+ elif dialect.name == 'mysql':
+ return mysql.TIMESTAMP(fsp=6)
+ return super().load_dialect_impl(dialect)
+
class ExtendedJSON(TypeDecorator):
"""
diff --git a/airflow/www/fab_security/sqla/models.py
b/airflow/www/fab_security/sqla/models.py
index c6eb65ead1..062cb74c4e 100644
--- a/airflow/www/fab_security/sqla/models.py
+++ b/airflow/www/fab_security/sqla/models.py
@@ -20,21 +20,11 @@ import datetime
# This product contains a modified portion of 'Flask App Builder' developed by
Daniel Vaz Gaspar.
# (https://github.com/dpgaspar/Flask-AppBuilder).
# Copyright 2013, Daniel Vaz Gaspar
-from typing import TYPE_CHECKING, Set, Tuple, Union
+from typing import TYPE_CHECKING, Set, Tuple
from flask import current_app, g
from flask_appbuilder.models.sqla import Model
-from sqlalchemy import (
- Boolean,
- Column,
- DateTime,
- ForeignKey,
- Integer,
- Sequence,
- String,
- Table,
- UniqueConstraint,
-)
+from sqlalchemy import Boolean, Column, DateTime, ForeignKey, Integer, String, Table, UniqueConstraint
from sqlalchemy.ext.declarative import declared_attr
from sqlalchemy.orm import backref, relationship
@@ -53,41 +43,11 @@ if TYPE_CHECKING:
Identity = None
-def get_sequence_or_identity(sequence_name: str) -> Union[Sequence, 'Identity']:
- """
- Depending on the engine it either returns Sequence, or Identity (in case
of MSSQL in SQLAlchemy 1.4).
- In SQLAlchemy 1.4 using sequence is not allowed for primary key columns in
MsSQL.
- Primary columns in MsSQL use IDENTITY keyword to auto increment.
- Using Sequence for those fields used to be allowed in SQLAlchemy 1.3 (and
essentially ignored
- if only name was specified).
-
- See https://docs.sqlalchemy.org/en/14/dialects/mssql.html
-
- Changed in version 1.4: Removed the ability to use a Sequence object
to modify IDENTITY
- characteristics. Sequence objects now only manipulate true T-SQL
SEQUENCE types.
-
- :param sequence_name: name of the sequence
- :return: Sequence or Identity
- """
- from airflow.settings import SQL_ALCHEMY_CONN
-
- if SQL_ALCHEMY_CONN is not None and SQL_ALCHEMY_CONN.startswith('mssql'):
- try:
- from sqlalchemy import Identity
-
- return Identity()
- except Exception:
- # Identity object is only available in SQLAlchemy 1.4.
- # For SQLAlchemy 1.3 compatibility we return original Sequence if
Identity is missing
- pass
- return Sequence(sequence_name)
-
-
class Action(Model):
"""Represents permission actions such as `can_read`."""
__tablename__ = "ab_permission"
- id = Column(Integer, get_sequence_or_identity("ab_permission_id_seq"), primary_key=True)
+ id = Column(Integer, primary_key=True)
name = Column(String(100), unique=True, nullable=False)
def __repr__(self):
@@ -98,7 +58,7 @@ class Resource(Model):
"""Represents permission object such as `User` or `Dag`."""
__tablename__ = "ab_view_menu"
- id = Column(Integer, get_sequence_or_identity("ab_view_menu_id_seq"), primary_key=True)
+ id = Column(Integer, primary_key=True)
name = Column(String(250), unique=True, nullable=False)
def __eq__(self, other):
@@ -114,7 +74,7 @@ class Resource(Model):
assoc_permission_role = Table(
"ab_permission_view_role",
Model.metadata,
- Column("id", Integer, get_sequence_or_identity("ab_permission_view_role_id_seq"), primary_key=True),
+ Column("id", Integer, primary_key=True),
Column("permission_view_id", Integer, ForeignKey("ab_permission_view.id")),
Column("role_id", Integer, ForeignKey("ab_role.id")),
UniqueConstraint("permission_view_id", "role_id"),
@@ -126,7 +86,7 @@ class Role(Model):
__tablename__ = "ab_role"
- id = Column(Integer, get_sequence_or_identity("ab_role_id_seq"), primary_key=True)
+ id = Column(Integer, primary_key=True)
name = Column(String(64), unique=True, nullable=False)
permissions = relationship("Permission", secondary=assoc_permission_role,
backref="role", lazy="joined")
@@ -139,7 +99,7 @@ class Permission(Model):
__tablename__ = "ab_permission_view"
__table_args__ = (UniqueConstraint("permission_id", "view_menu_id"),)
- id = Column(Integer, get_sequence_or_identity("ab_permission_view_id_seq"), primary_key=True)
+ id = Column(Integer, primary_key=True)
action_id = Column("permission_id", Integer,
ForeignKey("ab_permission.id"))
action = relationship(
"Action",
@@ -160,7 +120,7 @@ class Permission(Model):
assoc_user_role = Table(
"ab_user_role",
Model.metadata,
- Column("id", Integer, get_sequence_or_identity("ab_user_role_id_seq"), primary_key=True),
+ Column("id", Integer, primary_key=True),
Column("user_id", Integer, ForeignKey("ab_user.id")),
Column("role_id", Integer, ForeignKey("ab_role.id")),
UniqueConstraint("user_id", "role_id"),
@@ -171,7 +131,7 @@ class User(Model):
"""Represents an Airflow user which has roles assigned to it."""
__tablename__ = "ab_user"
- id = Column(Integer, get_sequence_or_identity("ab_user_id_seq"), primary_key=True)
+ id = Column(Integer, primary_key=True)
first_name = Column(String(64), nullable=False)
last_name = Column(String(64), nullable=False)
username = Column(String(256), unique=True, nullable=False)
@@ -264,7 +224,7 @@ class RegisterUser(Model):
"""Represents a user registration."""
__tablename__ = "ab_register_user"
- id = Column(Integer, get_sequence_or_identity("ab_register_user_id_seq"), primary_key=True)
+ id = Column(Integer, primary_key=True)
first_name = Column(String(64), nullable=False)
last_name = Column(String(64), nullable=False)
username = Column(String(256), unique=True, nullable=False)
diff --git a/docs/apache-airflow/migrations-ref.rst b/docs/apache-airflow/migrations-ref.rst
index 21beddb33c..b45e6c2de7 100644
--- a/docs/apache-airflow/migrations-ref.rst
+++ b/docs/apache-airflow/migrations-ref.rst
@@ -27,7 +27,9 @@ Here's the list of all the Database Migrations that are executed via when you ru
+---------------------------------+-------------------+-------------------+--------------------------------------------------------------+
| Revision ID | Revises ID | Airflow Version | Description |
+=================================+===================+===================+==============================================================+
-| ``424117c37d18`` (head) | ``f5fcbda3e651`` | ``2.4.0`` | Add DagWarning model |
+| ``44b7034f6bdc`` (head) | ``424117c37d18`` | ``2.4.0`` | compare types between ORM and DB |
++---------------------------------+-------------------+-------------------+--------------------------------------------------------------+
+| ``424117c37d18`` | ``f5fcbda3e651`` | ``2.4.0`` | Add DagWarning model |
+---------------------------------+-------------------+-------------------+--------------------------------------------------------------+
| ``f5fcbda3e651`` | ``3c94c427fdf6`` | ``2.3.3`` | Add indexes for CASCADE deletes on task_instance |
+---------------------------------+-------------------+-------------------+--------------------------------------------------------------+
diff --git a/tests/utils/test_db.py b/tests/utils/test_db.py
index 752f9873aa..60511f2428 100644
--- a/tests/utils/test_db.py
+++ b/tests/utils/test_db.py
@@ -34,7 +34,15 @@ from sqlalchemy import MetaData
from airflow.exceptions import AirflowException
from airflow.models import Base as airflow_base
from airflow.settings import engine
-from airflow.utils.db import check_migrations, create_default_connections, downgrade, resetdb, upgradedb
+from airflow.utils.db import (
+ check_migrations,
+ compare_server_default,
+ compare_type,
+ create_default_connections,
+ downgrade,
+ resetdb,
+ upgradedb,
+)
class TestDb:
@@ -44,7 +52,10 @@ class TestDb:
all_meta_data._add_table(table_name, table.schema, table)
# create diff between database schema and SQLAlchemy model
- mctx = MigrationContext.configure(engine.connect())
+ mctx = MigrationContext.configure(
+ engine.connect(),
+ opts={'compare_type': compare_type, 'compare_server_default': compare_server_default},
+ )
diff = compare_metadata(mctx, all_meta_data)
# known diffs to ignore
ignores = [