This is an automated email from the ASF dual-hosted git repository.
ephraimanierobi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 9f167bbc34 Add FAB migration commands (#41804)
9f167bbc34 is described below
commit 9f167bbc34ba4f0f64a6edab90d436275949fc56
Author: Ephraim Anierobi <[email protected]>
AuthorDate: Thu Sep 19 08:21:32 2024 +0100
Add FAB migration commands (#41804)
* Add FAB migration commands
This PR adds `migrate`, `downgrade` and `reset` db commands to facilitate
migrating FAB DBs.
FAB upgrade is also integrated into Airflow upgrade such that if airflow
db is being upgraded to the heads, FAB migration will also upgrade to the
heads.
The migration check that determines whether migration has finished now also
verifies that the FAB migration is done.
Note that downgrading Airflow does not trigger a FAB downgrade. A FAB downgrade
has to be done with the FAB downgrade command.
* Fix failing tests
* Add tests
* add more tests
* fixup! add more tests
* fixup! fixup! add more tests
* Refactor code
* fixup! Refactor code
* set airflow_db as false
* Add placeholder migration for correct stamping
* fixup! Add placeholder migration for correct stamping
* fixup! fixup! Add placeholder migration for correct stamping
* fix test
* fixup! fix test
* fix rebase mistake
* use reserialize_dags instead of airflow_db
* add messages around placeholder migration
---
airflow/cli/commands/db_command.py | 97 ++++++++++-----
.../fab/auth_manager/cli_commands/db_command.py | 52 ++++++++
.../fab/auth_manager/cli_commands/definition.py | 61 ++++++++++
.../providers/fab/auth_manager/fab_auth_manager.py | 11 +-
airflow/providers/fab/auth_manager/models/db.py | 74 +++++++++++-
airflow/providers/fab/migrations/script.py.mako | 9 +-
.../versions/0001_1_3_0_placeholder_migration.py} | 36 ++++--
airflow/utils/db.py | 21 +++-
airflow/utils/db_manager.py | 68 +++++++++--
scripts/ci/pre_commit/version_heads_map.py | 81 +++++++------
scripts/in_container/run_migration_reference.py | 7 +-
tests/always/test_project_structure.py | 2 +
.../auth_manager/cli_commands/test_db_command.py | 134 +++++++++++++++++++++
tests/providers/fab/auth_manager/models/test_db.py | 87 +++++++++++--
tests/utils/test_db.py | 11 +-
tests/utils/test_db_manager.py | 31 ++---
16 files changed, 650 insertions(+), 132 deletions(-)
diff --git a/airflow/cli/commands/db_command.py
b/airflow/cli/commands/db_command.py
index 776dabcdea..cd7493212b 100644
--- a/airflow/cli/commands/db_command.py
+++ b/airflow/cli/commands/db_command.py
@@ -72,15 +72,19 @@ def upgradedb(args):
migratedb(args)
-def get_version_revision(version: str, recursion_limit=10) -> str | None:
+def _get_version_revision(
+ version: str, recursion_limit: int = 10, revision_heads_map: dict[str,
str] | None = None
+) -> str | None:
"""
- Recursively search for the revision of the given version.
+ Recursively search for the revision of the given version in
revision_heads_map.
- This searches REVISION_HEADS_MAP for the revision of the given version,
recursively
+ This searches given revision_heads_map for the revision of the given
version, recursively
searching for the previous version if the given version is not found.
"""
- if version in _REVISION_HEADS_MAP:
- return _REVISION_HEADS_MAP[version]
+ if revision_heads_map is None:
+ revision_heads_map = _REVISION_HEADS_MAP
+ if version in revision_heads_map:
+ return revision_heads_map[version]
try:
major, minor, patch = map(int, version.split("."))
except ValueError:
@@ -90,13 +94,19 @@ def get_version_revision(version: str, recursion_limit=10)
-> str | None:
if recursion_limit <= 0:
# Prevent infinite recursion as I can't imagine 10 successive versions
without migration
return None
- return get_version_revision(new_version, recursion_limit)
+ return _get_version_revision(new_version, recursion_limit)
-@cli_utils.action_cli(check_db=False)
-@providers_configuration_loaded
-def migratedb(args):
- """Migrates the metadata database."""
+def run_db_migrate_command(args, command, revision_heads_map: dict[str, str],
reserialize_dags: bool = True):
+ """
+ Run the db migrate command.
+
+ param args: The parsed arguments.
+ param command: The command to run.
+ param airflow_db: Whether the command is for the airflow database.
+
+ :meta private:
+ """
print(f"DB: {settings.engine.url!r}")
if args.to_revision and args.to_version:
raise SystemExit("Cannot supply both `--to-revision` and
`--to-version`.")
@@ -112,12 +122,10 @@ def migratedb(args):
from_revision = args.from_revision
elif args.from_version:
try:
- parsed_version = parse_version(args.from_version)
+ parse_version(args.from_version)
except InvalidVersion:
raise SystemExit(f"Invalid version {args.from_version!r} supplied
as `--from-version`.")
- if parsed_version < parse_version("2.0.0"):
- raise SystemExit("--from-version must be greater or equal to than
2.0.0")
- from_revision = get_version_revision(args.from_version)
+ from_revision = _get_version_revision(args.from_version,
revision_heads_map=revision_heads_map)
if not from_revision:
raise SystemExit(f"Unknown version {args.from_version!r} supplied
as `--from-version`.")
@@ -126,7 +134,7 @@ def migratedb(args):
parse_version(args.to_version)
except InvalidVersion:
raise SystemExit(f"Invalid version {args.to_version!r} supplied as
`--to-version`.")
- to_revision = get_version_revision(args.to_version)
+ to_revision = _get_version_revision(args.to_version,
revision_heads_map=revision_heads_map)
if not to_revision:
raise SystemExit(f"Unknown version {args.to_version!r} supplied as
`--to-version`.")
elif args.to_revision:
@@ -136,21 +144,30 @@ def migratedb(args):
print(f"Performing upgrade to the metadata database
{settings.engine.url!r}")
else:
print("Generating sql for upgrade -- upgrade commands will *not* be
submitted.")
-
- db.upgradedb(
- to_revision=to_revision,
- from_revision=from_revision,
- show_sql_only=args.show_sql_only,
- reserialize_dags=args.reserialize_dags,
- )
+ if reserialize_dags:
+ command(
+ to_revision=to_revision,
+ from_revision=from_revision,
+ show_sql_only=args.show_sql_only,
+ reserialize_dags=True,
+ )
+ else:
+ command(
+ to_revision=to_revision,
+ from_revision=from_revision,
+ show_sql_only=args.show_sql_only,
+ )
if not args.show_sql_only:
print("Database migrating done!")
-@cli_utils.action_cli(check_db=False)
-@providers_configuration_loaded
-def downgrade(args):
- """Downgrades the metadata database."""
+def run_db_downgrade_command(args, command, revision_heads_map: dict[str,
str]):
+ """
+ Run the db downgrade command.
+
+ param args: The parsed arguments.
+ param command: The command to run.
+ """
if args.to_revision and args.to_version:
raise SystemExit("Cannot supply both `--to-revision` and
`--to-version`.")
if args.from_version and args.from_revision:
@@ -162,14 +179,15 @@ def downgrade(args):
if not (args.to_version or args.to_revision):
raise SystemExit("Must provide either --to-revision or --to-version.")
from_revision = None
+ to_revision = None
if args.from_revision:
from_revision = args.from_revision
elif args.from_version:
- from_revision = get_version_revision(args.from_version)
+ from_revision = _get_version_revision(args.from_version,
revision_heads_map=revision_heads_map)
if not from_revision:
raise SystemExit(f"Unknown version {args.from_version!r} supplied
as `--from-version`.")
if args.to_version:
- to_revision = get_version_revision(args.to_version)
+ to_revision = _get_version_revision(args.to_version,
revision_heads_map=revision_heads_map)
if not to_revision:
raise SystemExit(f"Downgrading to version {args.to_version} is not
supported.")
elif args.to_revision:
@@ -188,13 +206,34 @@ def downgrade(args):
).upper()
== "Y"
):
- db.downgrade(to_revision=to_revision, from_revision=from_revision,
show_sql_only=args.show_sql_only)
+ command(to_revision=to_revision, from_revision=from_revision,
show_sql_only=args.show_sql_only)
if not args.show_sql_only:
print("Downgrade complete")
else:
raise SystemExit("Cancelled")
+@cli_utils.action_cli(check_db=False)
+@providers_configuration_loaded
+def migratedb(args):
+ """Migrates the metadata database."""
+ if args.from_version:
+ try:
+ parsed_version = parse_version(args.from_version)
+ except InvalidVersion:
+ raise SystemExit(f"Invalid version {args.from_version!r} supplied
as `--from-version`.")
+ if parsed_version < parse_version("2.0.0"):
+ raise SystemExit("--from-version must be greater or equal to
2.0.0")
+ run_db_migrate_command(args, db.upgradedb, _REVISION_HEADS_MAP,
reserialize_dags=True)
+
+
+@cli_utils.action_cli(check_db=False)
+@providers_configuration_loaded
+def downgrade(args):
+ """Downgrades the metadata database."""
+ run_db_downgrade_command(args, db.downgrade, _REVISION_HEADS_MAP)
+
+
@providers_configuration_loaded
def check_migrations(args):
"""Wait for all airflow migrations to complete. Used for launching airflow
in k8s."""
diff --git a/airflow/providers/fab/auth_manager/cli_commands/db_command.py
b/airflow/providers/fab/auth_manager/cli_commands/db_command.py
new file mode 100644
index 0000000000..8b41cf4216
--- /dev/null
+++ b/airflow/providers/fab/auth_manager/cli_commands/db_command.py
@@ -0,0 +1,52 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from airflow import settings
+from airflow.cli.commands.db_command import run_db_downgrade_command,
run_db_migrate_command
+from airflow.providers.fab.auth_manager.models.db import _REVISION_HEADS_MAP,
FABDBManager
+from airflow.utils import cli as cli_utils
+from airflow.utils.providers_configuration_loader import
providers_configuration_loaded
+
+
+@providers_configuration_loaded
+def resetdb(args):
+ """Reset the metadata database."""
+ print(f"DB: {settings.engine.url!r}")
+ if not (args.yes or input("This will drop existing tables if they exist.
Proceed? (y/n)").upper() == "Y"):
+ raise SystemExit("Cancelled")
+ FABDBManager(settings.Session()).resetdb(skip_init=args.skip_init)
+
+
+@cli_utils.action_cli(check_db=False)
+@providers_configuration_loaded
+def migratedb(args):
+ """Migrates the metadata database."""
+ session = settings.Session()
+ upgrade_command = FABDBManager(session).upgradedb
+ run_db_migrate_command(
+ args, upgrade_command, revision_heads_map=_REVISION_HEADS_MAP,
reserialize_dags=False
+ )
+
+
+@cli_utils.action_cli(check_db=False)
+@providers_configuration_loaded
+def downgrade(args):
+ """Downgrades the metadata database."""
+ session = settings.Session()
+ dwongrade_command = FABDBManager(session).downgrade
+ run_db_downgrade_command(args, dwongrade_command,
revision_heads_map=_REVISION_HEADS_MAP)
diff --git a/airflow/providers/fab/auth_manager/cli_commands/definition.py
b/airflow/providers/fab/auth_manager/cli_commands/definition.py
index c7be5270d5..7f8f1e84e2 100644
--- a/airflow/providers/fab/auth_manager/cli_commands/definition.py
+++ b/airflow/providers/fab/auth_manager/cli_commands/definition.py
@@ -19,8 +19,17 @@ from __future__ import annotations
import textwrap
from airflow.cli.cli_config import (
+ ARG_DB_FROM_REVISION,
+ ARG_DB_FROM_VERSION,
+ ARG_DB_REVISION__DOWNGRADE,
+ ARG_DB_REVISION__UPGRADE,
+ ARG_DB_SKIP_INIT,
+ ARG_DB_SQL_ONLY,
+ ARG_DB_VERSION__DOWNGRADE,
+ ARG_DB_VERSION__UPGRADE,
ARG_OUTPUT,
ARG_VERBOSE,
+ ARG_YES,
ActionCommand,
Arg,
lazy_load_command,
@@ -243,3 +252,55 @@ SYNC_PERM_COMMAND = ActionCommand(
func=lazy_load_command("airflow.providers.fab.auth_manager.cli_commands.sync_perm_command.sync_perm"),
args=(ARG_INCLUDE_DAGS, ARG_VERBOSE),
)
+
+DB_COMMANDS = (
+ ActionCommand(
+ name="migrate",
+ help="Migrates the FAB metadata database to the latest version",
+ description=(
+ "Migrate the schema of the FAB metadata database. "
+ "Create the database if it does not exist "
+ "To print but not execute commands, use option
``--show-sql-only``. "
+ "If using options ``--from-revision`` or ``--from-version``, you
must also use "
+ "``--show-sql-only``, because if actually *running* migrations, we
should only "
+ "migrate from the *current* Alembic revision."
+ ),
+
func=lazy_load_command("airflow.providers.fab.auth_manager.cli_commands.db_command.migratedb"),
+ args=(
+ ARG_DB_REVISION__UPGRADE,
+ ARG_DB_VERSION__UPGRADE,
+ ARG_DB_SQL_ONLY,
+ ARG_DB_FROM_REVISION,
+ ARG_DB_FROM_VERSION,
+ ARG_VERBOSE,
+ ),
+ ),
+ ActionCommand(
+ name="downgrade",
+ help="Downgrade the schema of the FAB metadata database.",
+ description=(
+ "Downgrade the schema of the FAB metadata database. "
+ "You must provide either `--to-revision` or `--to-version`. "
+ "To print but not execute commands, use option `--show-sql-only`. "
+ "If using options `--from-revision` or `--from-version`, you must
also use `--show-sql-only`, "
+ "because if actually *running* migrations, we should only migrate
from the *current* Alembic "
+ "revision."
+ ),
+
func=lazy_load_command("airflow.providers.fab.auth_manager.cli_commands.db_command.downgrade"),
+ args=(
+ ARG_DB_REVISION__DOWNGRADE,
+ ARG_DB_VERSION__DOWNGRADE,
+ ARG_DB_SQL_ONLY,
+ ARG_YES,
+ ARG_DB_FROM_REVISION,
+ ARG_DB_FROM_VERSION,
+ ARG_VERBOSE,
+ ),
+ ),
+ ActionCommand(
+ name="reset",
+ help="Burn down and rebuild the FAB metadata database",
+
func=lazy_load_command("airflow.providers.fab.auth_manager.cli_commands.db_command.resetdb"),
+ args=(ARG_YES, ARG_DB_SKIP_INIT, ARG_VERBOSE),
+ ),
+)
diff --git a/airflow/providers/fab/auth_manager/fab_auth_manager.py
b/airflow/providers/fab/auth_manager/fab_auth_manager.py
index a0ec27bc67..336437061c 100644
--- a/airflow/providers/fab/auth_manager/fab_auth_manager.py
+++ b/airflow/providers/fab/auth_manager/fab_auth_manager.py
@@ -22,11 +22,13 @@ from functools import cached_property
from pathlib import Path
from typing import TYPE_CHECKING, Container
+import packaging.version
from connexion import FlaskApi
from flask import Blueprint, url_for
from sqlalchemy import select
from sqlalchemy.orm import Session, joinedload
+from airflow import __version__ as airflow_version
from airflow.auth.managers.base_auth_manager import BaseAuthManager,
ResourceMethod
from airflow.auth.managers.models.resource_details import (
AccessView,
@@ -47,6 +49,7 @@ from airflow.configuration import conf
from airflow.exceptions import AirflowConfigException, AirflowException
from airflow.models import DagModel
from airflow.providers.fab.auth_manager.cli_commands.definition import (
+ DB_COMMANDS,
ROLES_COMMANDS,
SYNC_PERM_COMMAND,
USERS_COMMANDS,
@@ -132,7 +135,7 @@ class FabAuthManager(BaseAuthManager):
@staticmethod
def get_cli_commands() -> list[CLICommand]:
"""Vends CLI commands to be included in Airflow CLI."""
- return [
+ commands: list[CLICommand] = [
GroupCommand(
name="users",
help="Manage users",
@@ -145,6 +148,12 @@ class FabAuthManager(BaseAuthManager):
),
SYNC_PERM_COMMAND, # not in a command group
]
+ # If Airflow version is 3.0.0 or higher, add the fab-db command group
+ if packaging.version.parse(
+ packaging.version.parse(airflow_version).base_version
+ ) >= packaging.version.parse("3.0.0"):
+ commands.append(GroupCommand(name="fab-db", help="Manage FAB",
subcommands=DB_COMMANDS))
+ return commands
def get_api_endpoints(self) -> None | Blueprint:
folder = Path(__file__).parents[0].resolve() # this is
airflow/auth/managers/fab/
diff --git a/airflow/providers/fab/auth_manager/models/db.py
b/airflow/providers/fab/auth_manager/models/db.py
index a971ea29a3..f72e1fcc65 100644
--- a/airflow/providers/fab/auth_manager/models/db.py
+++ b/airflow/providers/fab/auth_manager/models/db.py
@@ -19,17 +19,89 @@ from __future__ import annotations
import os
import airflow
+from airflow import settings
+from airflow.exceptions import AirflowException
from airflow.providers.fab.auth_manager.models import metadata
+from airflow.utils.db import _offline_migration, print_happy_cat
from airflow.utils.db_manager import BaseDBManager
PACKAGE_DIR = os.path.dirname(airflow.__file__)
+_REVISION_HEADS_MAP: dict[str, str] = {
+ "1.3.0": "6709f7a774b9",
+}
+
class FABDBManager(BaseDBManager):
"""Manages FAB database."""
metadata = metadata
- version_table_name = "fab_alembic_version"
+ version_table_name = "alembic_version_fab"
migration_dir = os.path.join(PACKAGE_DIR, "providers/fab/migrations")
alembic_file = os.path.join(PACKAGE_DIR, "providers/fab/alembic.ini")
supports_table_dropping = True
+
+ def upgradedb(self, to_revision=None, from_revision=None,
show_sql_only=False):
+ """Upgrade the database."""
+ if from_revision and not show_sql_only:
+ raise AirflowException("`from_revision` only supported with
`sql_only=True`.")
+
+ # alembic adds significant import time, so we import it lazily
+ if not settings.SQL_ALCHEMY_CONN:
+ raise RuntimeError("The settings.SQL_ALCHEMY_CONN not set. This is
a critical assertion.")
+ from alembic import command
+
+ config = self.get_alembic_config()
+
+ if show_sql_only:
+ if settings.engine.dialect.name == "sqlite":
+ raise SystemExit("Offline migration not supported for SQLite.")
+ if not from_revision:
+ from_revision = self.get_current_revision()
+
+ if not to_revision:
+ script = self.get_script_object(config)
+ to_revision = script.get_current_head()
+
+ if to_revision == from_revision:
+ print_happy_cat("No migrations to apply; nothing to do.")
+ return
+ _offline_migration(command.upgrade, config,
f"{from_revision}:{to_revision}")
+ return # only running sql; our job is done
+
+ if not self.get_current_revision():
+ # New DB; initialize and exit
+ self.initdb()
+ return
+
+ command.upgrade(config, revision=to_revision or "heads")
+
+ def downgrade(self, to_revision, from_revision=None, show_sql_only=False):
+ if from_revision and not show_sql_only:
+ raise ValueError(
+ "`from_revision` can't be combined with `show_sql_only=False`.
When actually "
+ "applying a downgrade (instead of just generating sql), we
always "
+ "downgrade from current revision."
+ )
+
+ if not settings.SQL_ALCHEMY_CONN:
+ raise RuntimeError("The settings.SQL_ALCHEMY_CONN not set.")
+
+ # alembic adds significant import time, so we import it lazily
+ from alembic import command
+
+ self.log.info("Attempting downgrade of FAB migration to revision %s",
to_revision)
+ config = self.get_alembic_config()
+
+ if show_sql_only:
+ self.log.warning("Generating sql scripts for manual migration.")
+ if not from_revision:
+ from_revision = self.get_current_revision()
+ if from_revision is None:
+ self.log.info("No revision found")
+ return
+ revision_range = f"{from_revision}:{to_revision}"
+ _offline_migration(command.downgrade, config=config,
revision=revision_range)
+ else:
+ self.log.info("Applying FAB downgrade migrations.")
+ command.downgrade(config, revision=to_revision, sql=show_sql_only)
diff --git a/airflow/providers/fab/migrations/script.py.mako
b/airflow/providers/fab/migrations/script.py.mako
index 4d0928fcc0..c0193ce2b0 100644
--- a/airflow/providers/fab/migrations/script.py.mako
+++ b/airflow/providers/fab/migrations/script.py.mako
@@ -30,10 +30,11 @@ import sqlalchemy as sa
${imports if imports else ""}
# revision identifiers, used by Alembic.
-revision: str = ${repr(up_revision)}
-down_revision: Union[str, None] = ${repr(down_revision)}
-branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
-depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}
+revision = ${repr(up_revision)}
+down_revision = ${repr(down_revision)}
+branch_labels = ${repr(branch_labels)}
+depends_on = ${repr(depends_on)}
+fab_version = None
def upgrade() -> None:
diff --git a/airflow/providers/fab/auth_manager/models/db.py
b/airflow/providers/fab/migrations/versions/0001_1_3_0_placeholder_migration.py
similarity index 60%
copy from airflow/providers/fab/auth_manager/models/db.py
copy to
airflow/providers/fab/migrations/versions/0001_1_3_0_placeholder_migration.py
index a971ea29a3..685216779d 100644
--- a/airflow/providers/fab/auth_manager/models/db.py
+++
b/airflow/providers/fab/migrations/versions/0001_1_3_0_placeholder_migration.py
@@ -1,3 +1,4 @@
+#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
@@ -14,22 +15,31 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from __future__ import annotations
-import os
+"""
+placeholder migration.
+
+Revision ID: 6709f7a774b9
+Revises:
+Create Date: 2024-09-03 17:06:38.040510
+
+Note: This is a placeholder migration used to stamp the migration
+when we create the migration from the ORM. Otherwise, it will run
+without stamping the migration, leading to subsequent changes to
+the tables not being migrated.
+"""
+
+from __future__ import annotations
-import airflow
-from airflow.providers.fab.auth_manager.models import metadata
-from airflow.utils.db_manager import BaseDBManager
+# revision identifiers, used by Alembic.
+revision = "6709f7a774b9"
+down_revision = None
+branch_labels = None
+depends_on = None
+fab_version = "1.3.0"
-PACKAGE_DIR = os.path.dirname(airflow.__file__)
+def upgrade() -> None: ...
-class FABDBManager(BaseDBManager):
- """Manages FAB database."""
- metadata = metadata
- version_table_name = "fab_alembic_version"
- migration_dir = os.path.join(PACKAGE_DIR, "providers/fab/migrations")
- alembic_file = os.path.join(PACKAGE_DIR, "providers/fab/alembic.ini")
- supports_table_dropping = True
+def downgrade() -> None: ...
diff --git a/airflow/utils/db.py b/airflow/utils/db.py
index bc4c60697c..1dc7275610 100644
--- a/airflow/utils/db.py
+++ b/airflow/utils/db.py
@@ -89,7 +89,7 @@ T = TypeVar("T")
log = logging.getLogger(__name__)
-_REVISION_HEADS_MAP = {
+_REVISION_HEADS_MAP: dict[str, str] = {
"2.7.0": "405de8318b3a",
"2.8.0": "10b52ebd31f7",
"2.8.1": "88344c1d9134",
@@ -819,6 +819,7 @@ def check_migrations(timeout):
:return: None
"""
timeout = timeout or 1 # run the loop at least 1
+ external_db_manager = RunDBManager()
with _configured_alembic_environment() as env:
context = env.get_context()
source_heads = None
@@ -826,7 +827,7 @@ def check_migrations(timeout):
for ticker in range(timeout):
source_heads = set(env.script.get_heads())
db_heads = set(context.get_current_heads())
- if source_heads == db_heads:
+ if source_heads == db_heads and
external_db_manager.check_migration(settings.Session()):
return
time.sleep(1)
log.info("Waiting for migrations... %s second(s)", ticker)
@@ -1027,6 +1028,11 @@ def _check_migration_errors(session: Session =
NEW_SESSION) -> Iterable[str]:
def _offline_migration(migration_func: Callable, config, revision):
+ """
+ Run offline migration.
+
+ :meta private:
+ """
with warnings.catch_warnings():
warnings.simplefilter("ignore")
logging.disable(logging.CRITICAL)
@@ -1160,6 +1166,13 @@ def upgradedb(
os.environ["AIRFLOW__DATABASE__SQL_ALCHEMY_MAX_SIZE"] = "1"
settings.reconfigure_orm(pool_class=sqlalchemy.pool.SingletonThreadPool)
command.upgrade(config, revision=to_revision or "heads")
+ current_revision = _get_current_revision(session=session)
+ with _configured_alembic_environment() as env:
+ source_heads = env.script.get_heads()
+ if current_revision == source_heads[0]:
+ # Only run external DB upgrade migration if user upgraded to
heads
+ external_db_manager = RunDBManager()
+ external_db_manager.upgradedb(session)
finally:
if val is None:
@@ -1168,8 +1181,6 @@ def upgradedb(
os.environ["AIRFLOW__DATABASE__SQL_ALCHEMY_MAX_SIZE"] = val
settings.reconfigure_orm()
- current_revision = _get_current_revision(session=session)
-
if reserialize_dags and current_revision != previous_revision:
_reserialize_dags(session=session)
add_default_pool_if_not_exists(session=session)
@@ -1191,7 +1202,7 @@ def resetdb(session: Session = NEW_SESSION, skip_init:
bool = False):
drop_airflow_models(connection)
drop_airflow_moved_tables(connection)
external_db_manager = RunDBManager()
- external_db_manager.drop_tables(connection)
+ external_db_manager.drop_tables(session, connection)
if not skip_init:
initdb(session=session)
diff --git a/airflow/utils/db_manager.py b/airflow/utils/db_manager.py
index 78d3244cd2..2241a646ef 100644
--- a/airflow/utils/db_manager.py
+++ b/airflow/utils/db_manager.py
@@ -20,13 +20,16 @@ import os
from typing import TYPE_CHECKING
from alembic import command
+from sqlalchemy import inspect
+from airflow import settings
from airflow.configuration import conf
from airflow.exceptions import AirflowException
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.module_loading import import_string
if TYPE_CHECKING:
+ from alembic.script import ScriptDirectory
from sqlalchemy import MetaData
@@ -54,14 +57,34 @@ class BaseDBManager(LoggingMixin):
config.set_main_option("sqlalchemy.url",
settings.SQL_ALCHEMY_CONN.replace("%", "%%"))
return config
- def get_current_revision(self):
+ def get_script_object(self, config=None) -> ScriptDirectory:
+ from alembic.script import ScriptDirectory
+
+ if not config:
+ config = self.get_alembic_config()
+ return ScriptDirectory.from_config(config)
+
+ def _get_migration_ctx(self):
from alembic.migration import MigrationContext
conn = self.session.connection()
- migration_ctx = MigrationContext.configure(conn,
opts={"version_table": self.version_table_name})
+ return MigrationContext.configure(conn, opts={"version_table":
self.version_table_name})
- return migration_ctx.get_current_revision()
+ def get_current_revision(self):
+ return self._get_migration_ctx().get_current_revision()
+
+ def check_migration(self):
+ """Check migration done."""
+ script_heads = self.get_script_object().get_heads()
+ db_heads = self.get_current_revision()
+ if db_heads:
+ db_heads = {db_heads}
+ if not db_heads and not script_heads:
+ return True
+ if set(script_heads) == db_heads:
+ return True
+ return False
def _create_db_from_orm(self):
"""Create database from ORM."""
@@ -70,6 +93,22 @@ class BaseDBManager(LoggingMixin):
config = self.get_alembic_config()
command.stamp(config, "head")
+ def drop_tables(self, connection):
+ self.metadata.drop_all(connection)
+ version = self._get_migration_ctx()._version
+ if inspect(connection).has_table(version.name):
+ version.drop(connection)
+
+ def resetdb(self, skip_init=False):
+ from airflow.utils.db import DBLocks, create_global_lock
+
+ connection = settings.engine.connect()
+
+ with create_global_lock(self.session, lock=DBLocks.MIGRATIONS),
connection.begin():
+ self.drop_tables(connection)
+ if not skip_init:
+ self.initdb()
+
def initdb(self):
"""Initialize the database."""
db_exists = self.get_current_revision()
@@ -78,14 +117,14 @@ class BaseDBManager(LoggingMixin):
else:
self._create_db_from_orm()
- def upgradedb(self, to_version=None, from_version=None,
show_sql_only=False):
+ def upgradedb(self, to_revision=None, from_revision=None,
show_sql_only=False):
"""Upgrade the database."""
self.log.info("Upgrading the %s database", self.__class__.__name__)
config = self.get_alembic_config()
- command.upgrade(config, revision=to_version or "heads",
sql=show_sql_only)
+ command.upgrade(config, revision=to_revision or "heads",
sql=show_sql_only)
- def downgradedb(self, to_version, from_version=None, show_sql_only=False):
+ def downgrade(self, to_version, from_version=None, show_sql_only=False):
"""Downgrade the database."""
raise NotImplementedError
@@ -141,6 +180,14 @@ class RunDBManager(LoggingMixin):
if manager.version_table_name == "alembic_version":
raise AirflowException(f"{manager}.version_table_name cannot be
'alembic_version'")
+ def check_migration(self, session):
+ """Check the external database migration."""
+ return_value = []
+ for manager in self._managers:
+ m = manager(session)
+ return_value.append(m.check_migration)
+ return all([x() for x in return_value])
+
def initdb(self, session):
"""Initialize the external database managers."""
for manager in self._managers:
@@ -153,14 +200,15 @@ class RunDBManager(LoggingMixin):
m = manager(session)
m.upgradedb()
- def downgradedb(self, session):
+ def downgrade(self, session):
"""Downgrade the external database managers."""
for manager in self._managers:
m = manager(session)
- m.downgradedb()
+ m.downgrade()
- def drop_tables(self, connection):
+ def drop_tables(self, session, connection):
"""Drop the external database managers."""
for manager in self._managers:
if manager.supports_table_dropping:
- manager.metadata.drop_all(connection)
+ m = manager(session)
+ m.drop_tables(connection)
diff --git a/scripts/ci/pre_commit/version_heads_map.py
b/scripts/ci/pre_commit/version_heads_map.py
index 4277c46564..10a6dee2ea 100755
--- a/scripts/ci/pre_commit/version_heads_map.py
+++ b/scripts/ci/pre_commit/version_heads_map.py
@@ -23,21 +23,27 @@ import sys
from pathlib import Path
import re2
-from packaging.version import parse as parse_version
PROJECT_SOURCE_ROOT_DIR = Path(__file__).resolve().parent.parent.parent.parent
DB_FILE = PROJECT_SOURCE_ROOT_DIR / "airflow" / "utils" / "db.py"
MIGRATION_PATH = PROJECT_SOURCE_ROOT_DIR / "airflow" / "migrations" /
"versions"
+FAB_DB_FILE = PROJECT_SOURCE_ROOT_DIR / "airflow" / "providers" / "fab" /
"auth_manager" / "models" / "db.py"
+FAB_MIGRATION_PATH = PROJECT_SOURCE_ROOT_DIR / "airflow" / "providers" / "fab"
/ "migrations" / "versions"
+
sys.path.insert(0, str(Path(__file__).parent.resolve())) # make sure
common_precommit_utils is importable
-def revision_heads_map():
+def revision_heads_map(migration_path):
rh_map = {}
pattern = r'revision = "[a-fA-F0-9]+"'
- airflow_version_pattern = r'airflow_version = "\d+\.\d+\.\d+"'
- filenames = os.listdir(MIGRATION_PATH)
+ version_pattern = None
+ if migration_path == MIGRATION_PATH:
+ version_pattern = r'airflow_version = "\d+\.\d+\.\d+"'
+ elif migration_path == FAB_MIGRATION_PATH:
+ version_pattern = r'fab_version = "\d+\.\d+\.\d+"'
+ filenames = os.listdir(migration_path)
def sorting_key(filen):
prefix = filen.split("_")[0]
@@ -46,43 +52,46 @@ def revision_heads_map():
sorted_filenames = sorted(filenames, key=sorting_key)
for filename in sorted_filenames:
- if not filename.endswith(".py"):
+ if not filename.endswith(".py") or filename == "__init__.py":
continue
- with open(os.path.join(MIGRATION_PATH, filename)) as file:
+ with open(os.path.join(migration_path, filename)) as file:
content = file.read()
revision_match = re2.search(pattern, content)
- airflow_version_match = re2.search(airflow_version_pattern, content)
- if revision_match and airflow_version_match:
+ _version_match = re2.search(version_pattern, content)
+ if revision_match and _version_match:
revision = revision_match.group(0).split('"')[1]
- version = airflow_version_match.group(0).split('"')[1]
- if parse_version(version) >= parse_version("2.0.0"):
- rh_map[version] = revision
+ version = _version_match.group(0).split('"')[1]
+ rh_map[version] = revision
return rh_map
if __name__ == "__main__":
- with open(DB_FILE) as file:
- content = file.read()
-
- pattern = r"_REVISION_HEADS_MAP = {[^}]+\}"
- match = re2.search(pattern, content)
- if not match:
- print(
- f"_REVISION_HEADS_MAP not found in {DB_FILE}. If this has been removed intentionally, "
- "please update scripts/ci/pre_commit/version_heads_map.py"
- )
- sys.exit(1)
-
- existing_revision_heads_map = match.group(0)
- rh_map = revision_heads_map()
- updated_revision_heads_map = "_REVISION_HEADS_MAP = {\n"
- for k, v in rh_map.items():
- updated_revision_heads_map += f' "{k}": "{v}",\n'
- updated_revision_heads_map += "}"
- if existing_revision_heads_map != updated_revision_heads_map:
- new_content = content.replace(existing_revision_heads_map, updated_revision_heads_map)
-
- with open(DB_FILE, "w") as file:
- file.write(new_content)
- print("_REVISION_HEADS_MAP updated in db.py. Please commit the changes.")
- sys.exit(1)
+ paths = [(DB_FILE, MIGRATION_PATH), (FAB_DB_FILE, FAB_MIGRATION_PATH)]
+ for dbfile, mpath in paths:
+ with open(dbfile) as file:
+ content = file.read()
+
+ pattern = r"_REVISION_HEADS_MAP:\s*dict\[\s*str\s*,\s*str\s*\]\s*=\s*\{[^}]*\}"
+ match = re2.search(pattern, content)
+ if not match:
+ print(
+ f"_REVISION_HEADS_MAP not found in {dbfile}. If this has been removed intentionally, "
+ "please update scripts/ci/pre_commit/version_heads_map.py"
+ )
+ sys.exit(1)
+
+ existing_revision_heads_map = match.group(0)
+ rh_map = revision_heads_map(mpath)
+ updated_revision_heads_map = "_REVISION_HEADS_MAP: dict[str, str] = {\n"
+ for k, v in rh_map.items():
+ updated_revision_heads_map += f' "{k}": "{v}",\n'
+ updated_revision_heads_map += "}"
+ if updated_revision_heads_map == "_REVISION_HEADS_MAP: dict[str, str] = {\n}":
+ updated_revision_heads_map = "_REVISION_HEADS_MAP: dict[str, str] = {}"
+ if existing_revision_heads_map != updated_revision_heads_map:
+ new_content = content.replace(existing_revision_heads_map, updated_revision_heads_map)
+
+ with open(dbfile, "w") as file:
+ file.write(new_content)
+ print(f"_REVISION_HEADS_MAP updated in {dbfile}. Please commit the changes.")
+ sys.exit(1)
diff --git a/scripts/in_container/run_migration_reference.py b/scripts/in_container/run_migration_reference.py
index 204db7a5ea..fc7d3bd084 100755
--- a/scripts/in_container/run_migration_reference.py
+++ b/scripts/in_container/run_migration_reference.py
@@ -138,8 +138,7 @@ def get_revisions(app="airflow") -> Iterable[Script]:
else:
from airflow.providers.fab.auth_manager.models.db import FABDBManager
- config = FABDBManager(session="").get_alembic_config()
- script = ScriptDirectory.from_config(config)
+ script = FABDBManager(session="").get_script_object()
yield from script.walk_revisions()
@@ -234,9 +233,9 @@ def correct_mismatching_revision_nums(revisions: Iterable[Script]):
if __name__ == "__main__":
apps = ["airflow", "fab"]
for app in apps:
- console.print("[bright_blue]Updating migration reference")
+ console.print(f"[bright_blue]Updating migration reference for {app}")
revisions = list(reversed(list(get_revisions(app))))
- console.print("[bright_blue]Making sure airflow version updated")
+ console.print(f"[bright_blue]Making sure {app} version updated")
ensure_version(revisions=revisions, app=app)
console.print("[bright_blue]Making sure there's no mismatching revision numbers")
correct_mismatching_revision_nums(revisions=revisions)
diff --git a/tests/always/test_project_structure.py b/tests/always/test_project_structure.py
index 1e684888dc..a6d174e639 100644
--- a/tests/always/test_project_structure.py
+++ b/tests/always/test_project_structure.py
@@ -169,6 +169,8 @@ class TestProjectStructure:
modules_files = list(f for f in modules_files if "/_vendor/" not in f)
# Exclude __init__.py
modules_files = list(f for f in modules_files if not f.endswith("__init__.py"))
+ # Exclude versions file
+ modules_files = list(f for f in modules_files if "/versions/" not in f)
# Change airflow/ to tests/
expected_test_files = list(
f'tests/{f.partition("/")[2]}' for f in modules_files if not f.endswith("__init__.py")
diff --git a/tests/providers/fab/auth_manager/cli_commands/test_db_command.py b/tests/providers/fab/auth_manager/cli_commands/test_db_command.py
new file mode 100644
index 0000000000..030b251a55
--- /dev/null
+++ b/tests/providers/fab/auth_manager/cli_commands/test_db_command.py
@@ -0,0 +1,134 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from unittest import mock
+
+import pytest
+
+from airflow.cli import cli_parser
+
+pytestmark = [pytest.mark.db_test]
+try:
+ from airflow.providers.fab.auth_manager.cli_commands import db_command
+ from airflow.providers.fab.auth_manager.models.db import FABDBManager
+
+ class TestFABCLiDB:
+ @classmethod
+ def setup_class(cls):
+ cls.parser = cli_parser.get_parser()
+
+ @mock.patch.object(FABDBManager, "resetdb")
+ def test_cli_resetdb(self, mock_resetdb):
+ db_command.resetdb(self.parser.parse_args(["fab-db", "reset", "--yes"]))
+
+ mock_resetdb.assert_called_once_with(skip_init=False)
+
+ @mock.patch.object(FABDBManager, "resetdb")
+ def test_cli_resetdb_skip_init(self, mock_resetdb):
+ db_command.resetdb(self.parser.parse_args(["fab-db", "reset", "--yes", "--skip-init"]))
+ mock_resetdb.assert_called_once_with(skip_init=True)
+
+ @pytest.mark.parametrize(
+ "args, called_with",
+ [
+ (
+ [],
+ dict(
+ to_revision=None,
+ from_revision=None,
+ show_sql_only=False,
+ ),
+ ),
+ (
+ ["--show-sql-only"],
+ dict(
+ to_revision=None,
+ from_revision=None,
+ show_sql_only=True,
+ ),
+ ),
+ (
+ ["--to-revision", "abc"],
+ dict(
+ to_revision="abc",
+ from_revision=None,
+ show_sql_only=False,
+ ),
+ ),
+ (
+ ["--to-revision", "abc", "--show-sql-only"],
+ dict(to_revision="abc", from_revision=None, show_sql_only=True),
+ ),
+ (
+ ["--to-revision", "abc", "--from-revision", "abc123", "--show-sql-only"],
+ dict(
+ to_revision="abc",
+ from_revision="abc123",
+ show_sql_only=True,
+ ),
+ ),
+ ],
+ )
+ @mock.patch.object(FABDBManager, "upgradedb")
+ def test_cli_upgrade_success(self, mock_upgradedb, args, called_with):
+ db_command.migratedb(self.parser.parse_args(["fab-db", "migrate", *args]))
+ mock_upgradedb.assert_called_once_with(**called_with)
+
+ @pytest.mark.parametrize(
+ "args, pattern",
+ [
+ pytest.param(
+ ["--to-revision", "abc", "--to-version", "1.3.0"],
+ "Cannot supply both",
+ id="to both version and revision",
+ ),
+ pytest.param(
+ ["--from-revision", "abc", "--from-version", "1.3.0"],
+ "Cannot supply both",
+ id="from both version and revision",
+ ),
+ pytest.param(["--to-version", "1.2.0"], "Unknown version '1.2.0'", id="unknown to version"),
+ pytest.param(["--to-version", "abc"], "Invalid version 'abc'", id="invalid to version"),
+ pytest.param(
+ ["--to-revision", "abc", "--from-revision", "abc123"],
+ "used with `--show-sql-only`",
+ id="requires offline",
+ ),
+ pytest.param(
+ ["--to-revision", "abc", "--from-version", "1.3.0"],
+ "used with `--show-sql-only`",
+ id="requires offline",
+ ),
+ pytest.param(
+ ["--to-revision", "abc", "--from-version", "1.1.25", "--show-sql-only"],
+ "Unknown version '1.1.25'",
+ id="unknown from version",
+ ),
+ pytest.param(
+ ["--to-revision", "adaf", "--from-version", "abc", "--show-sql-only"],
+ "Invalid version 'abc'",
+ id="invalid from version",
+ ),
+ ],
+ )
+ @mock.patch.object(FABDBManager, "upgradedb")
+ def test_cli_migratedb_failure(self, mock_upgradedb, args, pattern):
+ with pytest.raises(SystemExit, match=pattern):
+ db_command.migratedb(self.parser.parse_args(["fab-db", "migrate", *args]))
+except (ModuleNotFoundError, ImportError):
+ pass
diff --git a/tests/providers/fab/auth_manager/models/test_db.py b/tests/providers/fab/auth_manager/models/test_db.py
index 7b0dc345c7..528e1cbf09 100644
--- a/tests/providers/fab/auth_manager/models/test_db.py
+++ b/tests/providers/fab/auth_manager/models/test_db.py
@@ -17,6 +17,8 @@
from __future__ import annotations
import os
+import re
+from unittest import mock
import pytest
from alembic.autogenerate import compare_metadata
@@ -35,30 +37,33 @@ try:
from airflow.providers.fab.auth_manager.models.db import FABDBManager
class TestFABDBManager:
- def setup_method(self, session):
+ def setup_method(self):
self.airflow_dir = os.path.dirname(airflow.__file__)
- self.db_manager = FABDBManager(session=session)
- def test_version_table_name_set(self):
- assert self.db_manager.version_table_name == "fab_alembic_version"
+ def test_version_table_name_set(self, session):
+ assert FABDBManager(session=session).version_table_name == "alembic_version_fab"
- def test_migration_dir_set(self):
- assert self.db_manager.migration_dir == f"{self.airflow_dir}/providers/fab/migrations"
+ def test_migration_dir_set(self, session):
+ assert (
+ FABDBManager(session=session).migration_dir == f"{self.airflow_dir}/providers/fab/migrations"
+ )
- def test_alembic_file_set(self):
- assert self.db_manager.alembic_file == f"{self.airflow_dir}/providers/fab/alembic.ini"
+ def test_alembic_file_set(self, session):
+ assert (
+ FABDBManager(session=session).alembic_file == f"{self.airflow_dir}/providers/fab/alembic.ini"
+ )
- def test_supports_table_dropping_set(self):
- assert self.db_manager.supports_table_dropping is True
+ def test_supports_table_dropping_set(self, session):
+ assert FABDBManager(session=session).supports_table_dropping is True
- def test_database_schema_and_sqlalchemy_model_are_in_sync(self):
+ def test_database_schema_and_sqlalchemy_model_are_in_sync(self, session):
def include_object(_, name, type_, *args):
- if type_ == "table" and name not in self.db_manager.metadata.tables:
+ if type_ == "table" and name not in FABDBManager(session=session).metadata.tables:
return False
return True
all_meta_data = MetaData()
- for table_name, table in self.db_manager.metadata.tables.items():
+ for table_name, table in FABDBManager(session=session).metadata.tables.items():
all_meta_data._add_table(table_name, table.schema, table)
# create diff between database schema and SQLAlchemy model
mctx = MigrationContext.configure(
@@ -72,5 +77,61 @@ try:
diff = compare_metadata(mctx, all_meta_data)
assert not diff, "Database schema and SQLAlchemy model are not in sync: " + str(diff)
+
+ @mock.patch("airflow.providers.fab.auth_manager.models.db._offline_migration")
+ def test_downgrade_sql_no_from(self, mock_om, session, caplog):
+ FABDBManager(session=session).downgrade(to_revision="abc", show_sql_only=True, from_revision=None)
+ actual = mock_om.call_args.kwargs["revision"]
+ assert re.match(r"[a-z0-9]+:abc", actual) is not None
+
+ @mock.patch("airflow.providers.fab.auth_manager.models.db._offline_migration")
+ def test_downgrade_sql_with_from(self, mock_om, session):
+ FABDBManager(session=session).downgrade(
+ to_revision="abc", show_sql_only=True, from_revision="123"
+ )
+ actual = mock_om.call_args.kwargs["revision"]
+ assert actual == "123:abc"
+
+ @mock.patch("alembic.command.downgrade")
+ def test_downgrade_invalid_combo(self, mock_om, session):
+ """can't combine `sql=False` and `from_revision`"""
+ with pytest.raises(ValueError, match="can't be combined"):
+ FABDBManager(session=session).downgrade(to_revision="abc", from_revision="123")
+
+ @mock.patch("alembic.command.downgrade")
+ def test_downgrade_with_from(self, mock_om, session):
+ FABDBManager(session=session).downgrade(to_revision="abc")
+ actual = mock_om.call_args.kwargs["revision"]
+ assert actual == "abc"
+
+ @mock.patch.object(FABDBManager, "get_current_revision")
+ def test_sqlite_offline_upgrade_raises_with_revision(self, mock_gcr, session):
+ with mock.patch(
+ "airflow.providers.fab.auth_manager.models.db.settings.engine.dialect"
+ ) as dialect:
+ dialect.name = "sqlite"
+ with pytest.raises(SystemExit, match="Offline migration not supported for SQLite"):
+ FABDBManager(session).upgradedb(from_revision=None, to_revision=None, show_sql_only=True)
+
+ @mock.patch("airflow.utils.db_manager.inspect")
+ @mock.patch.object(FABDBManager, "metadata")
+ def test_drop_tables(self, mock_metadata, mock_inspect, session):
+ manager = FABDBManager(session)
+ connection = mock.MagicMock()
+ manager.drop_tables(connection)
+ mock_metadata.drop_all.assert_called_once_with(connection)
+
+ @pytest.mark.parametrize("skip_init", [True, False])
+ @mock.patch.object(FABDBManager, "drop_tables")
+ @mock.patch.object(FABDBManager, "initdb")
+ @mock.patch("airflow.utils.db.create_global_lock", new=mock.MagicMock)
+ def test_resetdb(self, mock_initdb, mock_drop_tables, session, skip_init):
+ manager = FABDBManager(session)
+ manager.resetdb(skip_init=skip_init)
+ mock_drop_tables.assert_called_once()
+ if skip_init:
+ mock_initdb.assert_not_called()
+ else:
+ mock_initdb.assert_called_once()
except ModuleNotFoundError:
pass
diff --git a/tests/utils/test_db.py b/tests/utils/test_db.py
index 2308ed3ca8..2a197c2e6c 100644
--- a/tests/utils/test_db.py
+++ b/tests/utils/test_db.py
@@ -49,6 +49,7 @@ from airflow.utils.db import (
upgradedb,
)
from airflow.utils.db_manager import RunDBManager
+from tests.test_utils.config import conf_vars
pytestmark = [pytest.mark.db_test, pytest.mark.skip_if_database_isolation_mode]
@@ -92,7 +93,7 @@ class TestDb:
# sqlite sequence is used for autoincrementing columns created with `sqlite_autoincrement` option
lambda t: (t[0] == "remove_table" and t[1].name == "sqlite_sequence"),
# fab version table
- lambda t: (t[0] == "remove_table" and t[1].name == "fab_alembic_version"),
+ lambda t: (t[0] == "remove_table" and t[1].name == "alembic_version_fab"),
]
for ignore in ignores:
@@ -131,7 +132,8 @@ class TestDb:
@mock.patch("alembic.command")
def test_upgradedb(self, mock_alembic_command):
upgradedb()
- mock_alembic_command.upgrade.assert_called_once_with(mock.ANY, revision="heads")
+ mock_alembic_command.upgrade.assert_called_with(mock.ANY, revision="heads")
+ assert mock_alembic_command.upgrade.call_count == 2
@pytest.mark.parametrize(
"from_revision, to_revision",
@@ -200,6 +202,10 @@ class TestDb:
assert actual == "abc"
@pytest.mark.parametrize("skip_init", [False, True])
+ @conf_vars(
+ {("database", "external_db_managers"): "airflow.providers.fab.auth_manager.models.db.FABDBManager"}
+ )
+ @mock.patch("airflow.providers.fab.auth_manager.models.db.FABDBManager")
@mock.patch("airflow.utils.db.create_global_lock", new=MagicMock)
@mock.patch("airflow.utils.db.drop_airflow_models")
@mock.patch("airflow.utils.db.drop_airflow_moved_tables")
@@ -211,6 +217,7 @@ class TestDb:
mock_init,
mock_drop_moved,
mock_drop_airflow,
+ mock_fabdb_manager,
skip_init,
):
session_mock = MagicMock()
diff --git a/tests/utils/test_db_manager.py b/tests/utils/test_db_manager.py
index d6fc91f939..1c8a6c6c7d 100644
--- a/tests/utils/test_db_manager.py
+++ b/tests/utils/test_db_manager.py
@@ -23,7 +23,7 @@ from sqlalchemy import Table
from airflow.exceptions import AirflowException
from airflow.models import Base
-from airflow.utils.db import downgrade, initdb, upgradedb
+from airflow.utils.db import downgrade, initdb
from airflow.utils.db_manager import BaseDBManager, RunDBManager
from tests.test_utils.config import conf_vars
@@ -57,29 +57,26 @@ class TestRunDBManager:
run_db_manager.validate()
metadata._remove_table("dag_run", None)
- @mock.patch.object(RunDBManager, "downgradedb")
+ @mock.patch.object(RunDBManager, "downgrade")
@mock.patch.object(RunDBManager, "upgradedb")
@mock.patch.object(RunDBManager, "initdb")
def test_init_db_calls_rundbmanager(self, mock_initdb, mock_upgrade_db, mock_downgrade_db, session):
initdb(session=session)
mock_initdb.assert_called()
mock_initdb.assert_called_once_with(session)
- mock_upgrade_db.assert_not_called()
mock_downgrade_db.assert_not_called()
- @mock.patch.object(RunDBManager, "downgradedb")
+ @mock.patch.object(RunDBManager, "downgrade")
@mock.patch.object(RunDBManager, "upgradedb")
@mock.patch.object(RunDBManager, "initdb")
@mock.patch("alembic.command")
- def test_upgradedb_or_downgrade_dont_call_rundbmanager(
+ def test_downgrade_dont_call_rundbmanager(
self, mock_alembic_command, mock_initdb, mock_upgrade_db, mock_downgrade_db, session
):
- upgradedb(session=session)
- mock_alembic_command.upgrade.assert_called_once_with(mock.ANY, revision="heads")
downgrade(to_revision="base")
mock_alembic_command.downgrade.assert_called_once_with(mock.ANY, revision="base", sql=False)
- mock_initdb.assert_not_called()
mock_upgrade_db.assert_not_called()
+ mock_initdb.assert_not_called()
mock_downgrade_db.assert_not_called()
@conf_vars(
@@ -96,12 +93,12 @@ class TestRunDBManager:
# upgradedb
ext_db.upgradedb(session=session)
fabdb_manager.upgradedb.assert_called_once()
- # downgradedb
- ext_db.downgradedb(session=session)
- mock_fabdb_manager.return_value.downgradedb.assert_called_once()
+ # downgrade
+ ext_db.downgrade(session=session)
+ mock_fabdb_manager.return_value.downgrade.assert_called_once()
connection = mock.MagicMock()
- ext_db.drop_tables(connection)
- mock_fabdb_manager.metadata.drop_all.assert_called_once_with(connection)
+ ext_db.drop_tables(session, connection)
+ mock_fabdb_manager.return_value.drop_tables.assert_called_once_with(connection)
class MockDBManager(BaseDBManager):
@@ -127,8 +124,14 @@ class TestBaseDBManager:
@mock.patch.object(BaseDBManager, "get_alembic_config")
@mock.patch("alembic.command.upgrade")
- def test_upgradedb(self, mock_alembic_cmd, mock_alembic_config, session, caplog):
+ def test_upgrade(self, mock_alembic_cmd, mock_alembic_config, session, caplog):
manager = MockDBManager(session)
manager.upgradedb()
mock_alembic_cmd.assert_called_once()
assert "Upgrading the MockDBManager database" in caplog.text
+
+ @mock.patch.object(BaseDBManager, "get_script_object")
+ @mock.patch.object(BaseDBManager, "get_current_revision")
+ def test_check_migration(self, mock_script_obj, mock_current_revision, session):
+ manager = MockDBManager(session)
+ manager.check_migration() # just ensure this can be called