This is an automated email from the ASF dual-hosted git repository.

ephraimanierobi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 9f167bbc34 Add FAB migration commands (#41804)
9f167bbc34 is described below

commit 9f167bbc34ba4f0f64a6edab90d436275949fc56
Author: Ephraim Anierobi <[email protected]>
AuthorDate: Thu Sep 19 08:21:32 2024 +0100

    Add FAB migration commands (#41804)
    
    * Add FAB migration commands
    
    This PR adds `migrate`, `downgrade` and `reset` db commands to facilitate
    migrating FAB DBs.
    
    The FAB upgrade is also integrated into the Airflow upgrade such that if
    the airflow db is being upgraded to the heads, the FAB migration will also
    be upgraded to the heads.
    
    Migration checks to determine if migration has finished now include
    checking that the FAB migration is also done.
    
    Note that downgrading Airflow does not trigger a FAB downgrade. The FAB
    downgrade has to be done with the FAB downgrade command.
    
    * Fix failing tests
    
    * Add tests
    
    * add more tests
    
    * fixup! add more tests
    
    * fixup! fixup! add more tests
    
    * Refactor code
    
    * fixup! Refactor code
    
    * set airflow_db as false
    
    * Add placeholder migration for correct stamping
    
    * fixup! Add placeholder migration for correct stamping
    
    * fixup! fixup! Add placeholder migration for correct stamping
    
    * fix test
    
    * fixup! fix test
    
    * fix rebase mistake
    
    * use reserialize_dags instead of airflow_db
    
    * add messages around placeholder migration
---
 airflow/cli/commands/db_command.py                 |  97 ++++++++++-----
 .../fab/auth_manager/cli_commands/db_command.py    |  52 ++++++++
 .../fab/auth_manager/cli_commands/definition.py    |  61 ++++++++++
 .../providers/fab/auth_manager/fab_auth_manager.py |  11 +-
 airflow/providers/fab/auth_manager/models/db.py    |  74 +++++++++++-
 airflow/providers/fab/migrations/script.py.mako    |   9 +-
 .../versions/0001_1_3_0_placeholder_migration.py}  |  36 ++++--
 airflow/utils/db.py                                |  21 +++-
 airflow/utils/db_manager.py                        |  68 +++++++++--
 scripts/ci/pre_commit/version_heads_map.py         |  81 +++++++------
 scripts/in_container/run_migration_reference.py    |   7 +-
 tests/always/test_project_structure.py             |   2 +
 .../auth_manager/cli_commands/test_db_command.py   | 134 +++++++++++++++++++++
 tests/providers/fab/auth_manager/models/test_db.py |  87 +++++++++++--
 tests/utils/test_db.py                             |  11 +-
 tests/utils/test_db_manager.py                     |  31 ++---
 16 files changed, 650 insertions(+), 132 deletions(-)

diff --git a/airflow/cli/commands/db_command.py 
b/airflow/cli/commands/db_command.py
index 776dabcdea..cd7493212b 100644
--- a/airflow/cli/commands/db_command.py
+++ b/airflow/cli/commands/db_command.py
@@ -72,15 +72,19 @@ def upgradedb(args):
     migratedb(args)
 
 
-def get_version_revision(version: str, recursion_limit=10) -> str | None:
+def _get_version_revision(
+    version: str, recursion_limit: int = 10, revision_heads_map: dict[str, 
str] | None = None
+) -> str | None:
     """
-    Recursively search for the revision of the given version.
+    Recursively search for the revision of the given version in 
revision_heads_map.
 
-    This searches REVISION_HEADS_MAP for the revision of the given version, 
recursively
+    This searches given revision_heads_map for the revision of the given 
version, recursively
     searching for the previous version if the given version is not found.
     """
-    if version in _REVISION_HEADS_MAP:
-        return _REVISION_HEADS_MAP[version]
+    if revision_heads_map is None:
+        revision_heads_map = _REVISION_HEADS_MAP
+    if version in revision_heads_map:
+        return revision_heads_map[version]
     try:
         major, minor, patch = map(int, version.split("."))
     except ValueError:
@@ -90,13 +94,19 @@ def get_version_revision(version: str, recursion_limit=10) 
-> str | None:
     if recursion_limit <= 0:
         # Prevent infinite recursion as I can't imagine 10 successive versions 
without migration
         return None
-    return get_version_revision(new_version, recursion_limit)
+    return _get_version_revision(new_version, recursion_limit)
 
 
-@cli_utils.action_cli(check_db=False)
-@providers_configuration_loaded
-def migratedb(args):
-    """Migrates the metadata database."""
+def run_db_migrate_command(args, command, revision_heads_map: dict[str, str], 
reserialize_dags: bool = True):
+    """
+    Run the db migrate command.
+
+    param args: The parsed arguments.
+    param command: The command to run.
+    param airflow_db: Whether the command is for the airflow database.
+
+    :meta private:
+    """
     print(f"DB: {settings.engine.url!r}")
     if args.to_revision and args.to_version:
         raise SystemExit("Cannot supply both `--to-revision` and 
`--to-version`.")
@@ -112,12 +122,10 @@ def migratedb(args):
         from_revision = args.from_revision
     elif args.from_version:
         try:
-            parsed_version = parse_version(args.from_version)
+            parse_version(args.from_version)
         except InvalidVersion:
             raise SystemExit(f"Invalid version {args.from_version!r} supplied 
as `--from-version`.")
-        if parsed_version < parse_version("2.0.0"):
-            raise SystemExit("--from-version must be greater or equal to than 
2.0.0")
-        from_revision = get_version_revision(args.from_version)
+        from_revision = _get_version_revision(args.from_version, 
revision_heads_map=revision_heads_map)
         if not from_revision:
             raise SystemExit(f"Unknown version {args.from_version!r} supplied 
as `--from-version`.")
 
@@ -126,7 +134,7 @@ def migratedb(args):
             parse_version(args.to_version)
         except InvalidVersion:
             raise SystemExit(f"Invalid version {args.to_version!r} supplied as 
`--to-version`.")
-        to_revision = get_version_revision(args.to_version)
+        to_revision = _get_version_revision(args.to_version, 
revision_heads_map=revision_heads_map)
         if not to_revision:
             raise SystemExit(f"Unknown version {args.to_version!r} supplied as 
`--to-version`.")
     elif args.to_revision:
@@ -136,21 +144,30 @@ def migratedb(args):
         print(f"Performing upgrade to the metadata database 
{settings.engine.url!r}")
     else:
         print("Generating sql for upgrade -- upgrade commands will *not* be 
submitted.")
-
-    db.upgradedb(
-        to_revision=to_revision,
-        from_revision=from_revision,
-        show_sql_only=args.show_sql_only,
-        reserialize_dags=args.reserialize_dags,
-    )
+    if reserialize_dags:
+        command(
+            to_revision=to_revision,
+            from_revision=from_revision,
+            show_sql_only=args.show_sql_only,
+            reserialize_dags=True,
+        )
+    else:
+        command(
+            to_revision=to_revision,
+            from_revision=from_revision,
+            show_sql_only=args.show_sql_only,
+        )
     if not args.show_sql_only:
         print("Database migrating done!")
 
 
-@cli_utils.action_cli(check_db=False)
-@providers_configuration_loaded
-def downgrade(args):
-    """Downgrades the metadata database."""
+def run_db_downgrade_command(args, command, revision_heads_map: dict[str, 
str]):
+    """
+    Run the db downgrade command.
+
+    param args: The parsed arguments.
+    param command: The command to run.
+    """
     if args.to_revision and args.to_version:
         raise SystemExit("Cannot supply both `--to-revision` and 
`--to-version`.")
     if args.from_version and args.from_revision:
@@ -162,14 +179,15 @@ def downgrade(args):
     if not (args.to_version or args.to_revision):
         raise SystemExit("Must provide either --to-revision or --to-version.")
     from_revision = None
+    to_revision = None
     if args.from_revision:
         from_revision = args.from_revision
     elif args.from_version:
-        from_revision = get_version_revision(args.from_version)
+        from_revision = _get_version_revision(args.from_version, 
revision_heads_map=revision_heads_map)
         if not from_revision:
             raise SystemExit(f"Unknown version {args.from_version!r} supplied 
as `--from-version`.")
     if args.to_version:
-        to_revision = get_version_revision(args.to_version)
+        to_revision = _get_version_revision(args.to_version, 
revision_heads_map=revision_heads_map)
         if not to_revision:
             raise SystemExit(f"Downgrading to version {args.to_version} is not 
supported.")
     elif args.to_revision:
@@ -188,13 +206,34 @@ def downgrade(args):
         ).upper()
         == "Y"
     ):
-        db.downgrade(to_revision=to_revision, from_revision=from_revision, 
show_sql_only=args.show_sql_only)
+        command(to_revision=to_revision, from_revision=from_revision, 
show_sql_only=args.show_sql_only)
         if not args.show_sql_only:
             print("Downgrade complete")
     else:
         raise SystemExit("Cancelled")
 
 
+@cli_utils.action_cli(check_db=False)
+@providers_configuration_loaded
+def migratedb(args):
+    """Migrates the metadata database."""
+    if args.from_version:
+        try:
+            parsed_version = parse_version(args.from_version)
+        except InvalidVersion:
+            raise SystemExit(f"Invalid version {args.from_version!r} supplied 
as `--from-version`.")
+        if parsed_version < parse_version("2.0.0"):
+            raise SystemExit("--from-version must be greater or equal to 
2.0.0")
+    run_db_migrate_command(args, db.upgradedb, _REVISION_HEADS_MAP, 
reserialize_dags=True)
+
+
+@cli_utils.action_cli(check_db=False)
+@providers_configuration_loaded
+def downgrade(args):
+    """Downgrades the metadata database."""
+    run_db_downgrade_command(args, db.downgrade, _REVISION_HEADS_MAP)
+
+
 @providers_configuration_loaded
 def check_migrations(args):
     """Wait for all airflow migrations to complete. Used for launching airflow 
in k8s."""
diff --git a/airflow/providers/fab/auth_manager/cli_commands/db_command.py 
b/airflow/providers/fab/auth_manager/cli_commands/db_command.py
new file mode 100644
index 0000000000..8b41cf4216
--- /dev/null
+++ b/airflow/providers/fab/auth_manager/cli_commands/db_command.py
@@ -0,0 +1,52 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from airflow import settings
+from airflow.cli.commands.db_command import run_db_downgrade_command, 
run_db_migrate_command
+from airflow.providers.fab.auth_manager.models.db import _REVISION_HEADS_MAP, 
FABDBManager
+from airflow.utils import cli as cli_utils
+from airflow.utils.providers_configuration_loader import 
providers_configuration_loaded
+
+
+@providers_configuration_loaded
+def resetdb(args):
+    """Reset the metadata database."""
+    print(f"DB: {settings.engine.url!r}")
+    if not (args.yes or input("This will drop existing tables if they exist. 
Proceed? (y/n)").upper() == "Y"):
+        raise SystemExit("Cancelled")
+    FABDBManager(settings.Session()).resetdb(skip_init=args.skip_init)
+
+
+@cli_utils.action_cli(check_db=False)
+@providers_configuration_loaded
+def migratedb(args):
+    """Migrates the metadata database."""
+    session = settings.Session()
+    upgrade_command = FABDBManager(session).upgradedb
+    run_db_migrate_command(
+        args, upgrade_command, revision_heads_map=_REVISION_HEADS_MAP, 
reserialize_dags=False
+    )
+
+
+@cli_utils.action_cli(check_db=False)
+@providers_configuration_loaded
+def downgrade(args):
+    """Downgrades the metadata database."""
+    session = settings.Session()
+    dwongrade_command = FABDBManager(session).downgrade
+    run_db_downgrade_command(args, dwongrade_command, 
revision_heads_map=_REVISION_HEADS_MAP)
diff --git a/airflow/providers/fab/auth_manager/cli_commands/definition.py 
b/airflow/providers/fab/auth_manager/cli_commands/definition.py
index c7be5270d5..7f8f1e84e2 100644
--- a/airflow/providers/fab/auth_manager/cli_commands/definition.py
+++ b/airflow/providers/fab/auth_manager/cli_commands/definition.py
@@ -19,8 +19,17 @@ from __future__ import annotations
 import textwrap
 
 from airflow.cli.cli_config import (
+    ARG_DB_FROM_REVISION,
+    ARG_DB_FROM_VERSION,
+    ARG_DB_REVISION__DOWNGRADE,
+    ARG_DB_REVISION__UPGRADE,
+    ARG_DB_SKIP_INIT,
+    ARG_DB_SQL_ONLY,
+    ARG_DB_VERSION__DOWNGRADE,
+    ARG_DB_VERSION__UPGRADE,
     ARG_OUTPUT,
     ARG_VERBOSE,
+    ARG_YES,
     ActionCommand,
     Arg,
     lazy_load_command,
@@ -243,3 +252,55 @@ SYNC_PERM_COMMAND = ActionCommand(
     
func=lazy_load_command("airflow.providers.fab.auth_manager.cli_commands.sync_perm_command.sync_perm"),
     args=(ARG_INCLUDE_DAGS, ARG_VERBOSE),
 )
+
+DB_COMMANDS = (
+    ActionCommand(
+        name="migrate",
+        help="Migrates the FAB metadata database to the latest version",
+        description=(
+            "Migrate the schema of the FAB metadata database. "
+            "Create the database if it does not exist "
+            "To print but not execute commands, use option 
``--show-sql-only``. "
+            "If using options ``--from-revision`` or ``--from-version``, you 
must also use "
+            "``--show-sql-only``, because if actually *running* migrations, we 
should only "
+            "migrate from the *current* Alembic revision."
+        ),
+        
func=lazy_load_command("airflow.providers.fab.auth_manager.cli_commands.db_command.migratedb"),
+        args=(
+            ARG_DB_REVISION__UPGRADE,
+            ARG_DB_VERSION__UPGRADE,
+            ARG_DB_SQL_ONLY,
+            ARG_DB_FROM_REVISION,
+            ARG_DB_FROM_VERSION,
+            ARG_VERBOSE,
+        ),
+    ),
+    ActionCommand(
+        name="downgrade",
+        help="Downgrade the schema of the FAB metadata database.",
+        description=(
+            "Downgrade the schema of the FAB metadata database. "
+            "You must provide either `--to-revision` or `--to-version`. "
+            "To print but not execute commands, use option `--show-sql-only`. "
+            "If using options `--from-revision` or `--from-version`, you must 
also use `--show-sql-only`, "
+            "because if actually *running* migrations, we should only migrate 
from the *current* Alembic "
+            "revision."
+        ),
+        
func=lazy_load_command("airflow.providers.fab.auth_manager.cli_commands.db_command.downgrade"),
+        args=(
+            ARG_DB_REVISION__DOWNGRADE,
+            ARG_DB_VERSION__DOWNGRADE,
+            ARG_DB_SQL_ONLY,
+            ARG_YES,
+            ARG_DB_FROM_REVISION,
+            ARG_DB_FROM_VERSION,
+            ARG_VERBOSE,
+        ),
+    ),
+    ActionCommand(
+        name="reset",
+        help="Burn down and rebuild the FAB metadata database",
+        
func=lazy_load_command("airflow.providers.fab.auth_manager.cli_commands.db_command.resetdb"),
+        args=(ARG_YES, ARG_DB_SKIP_INIT, ARG_VERBOSE),
+    ),
+)
diff --git a/airflow/providers/fab/auth_manager/fab_auth_manager.py 
b/airflow/providers/fab/auth_manager/fab_auth_manager.py
index a0ec27bc67..336437061c 100644
--- a/airflow/providers/fab/auth_manager/fab_auth_manager.py
+++ b/airflow/providers/fab/auth_manager/fab_auth_manager.py
@@ -22,11 +22,13 @@ from functools import cached_property
 from pathlib import Path
 from typing import TYPE_CHECKING, Container
 
+import packaging.version
 from connexion import FlaskApi
 from flask import Blueprint, url_for
 from sqlalchemy import select
 from sqlalchemy.orm import Session, joinedload
 
+from airflow import __version__ as airflow_version
 from airflow.auth.managers.base_auth_manager import BaseAuthManager, 
ResourceMethod
 from airflow.auth.managers.models.resource_details import (
     AccessView,
@@ -47,6 +49,7 @@ from airflow.configuration import conf
 from airflow.exceptions import AirflowConfigException, AirflowException
 from airflow.models import DagModel
 from airflow.providers.fab.auth_manager.cli_commands.definition import (
+    DB_COMMANDS,
     ROLES_COMMANDS,
     SYNC_PERM_COMMAND,
     USERS_COMMANDS,
@@ -132,7 +135,7 @@ class FabAuthManager(BaseAuthManager):
     @staticmethod
     def get_cli_commands() -> list[CLICommand]:
         """Vends CLI commands to be included in Airflow CLI."""
-        return [
+        commands: list[CLICommand] = [
             GroupCommand(
                 name="users",
                 help="Manage users",
@@ -145,6 +148,12 @@ class FabAuthManager(BaseAuthManager):
             ),
             SYNC_PERM_COMMAND,  # not in a command group
         ]
+        # If Airflow version is 3.0.0 or higher, add the fab-db command group
+        if packaging.version.parse(
+            packaging.version.parse(airflow_version).base_version
+        ) >= packaging.version.parse("3.0.0"):
+            commands.append(GroupCommand(name="fab-db", help="Manage FAB", 
subcommands=DB_COMMANDS))
+        return commands
 
     def get_api_endpoints(self) -> None | Blueprint:
         folder = Path(__file__).parents[0].resolve()  # this is 
airflow/auth/managers/fab/
diff --git a/airflow/providers/fab/auth_manager/models/db.py 
b/airflow/providers/fab/auth_manager/models/db.py
index a971ea29a3..f72e1fcc65 100644
--- a/airflow/providers/fab/auth_manager/models/db.py
+++ b/airflow/providers/fab/auth_manager/models/db.py
@@ -19,17 +19,89 @@ from __future__ import annotations
 import os
 
 import airflow
+from airflow import settings
+from airflow.exceptions import AirflowException
 from airflow.providers.fab.auth_manager.models import metadata
+from airflow.utils.db import _offline_migration, print_happy_cat
 from airflow.utils.db_manager import BaseDBManager
 
 PACKAGE_DIR = os.path.dirname(airflow.__file__)
 
+_REVISION_HEADS_MAP: dict[str, str] = {
+    "1.3.0": "6709f7a774b9",
+}
+
 
 class FABDBManager(BaseDBManager):
     """Manages FAB database."""
 
     metadata = metadata
-    version_table_name = "fab_alembic_version"
+    version_table_name = "alembic_version_fab"
     migration_dir = os.path.join(PACKAGE_DIR, "providers/fab/migrations")
     alembic_file = os.path.join(PACKAGE_DIR, "providers/fab/alembic.ini")
     supports_table_dropping = True
+
+    def upgradedb(self, to_revision=None, from_revision=None, 
show_sql_only=False):
+        """Upgrade the database."""
+        if from_revision and not show_sql_only:
+            raise AirflowException("`from_revision` only supported with 
`sql_only=True`.")
+
+        # alembic adds significant import time, so we import it lazily
+        if not settings.SQL_ALCHEMY_CONN:
+            raise RuntimeError("The settings.SQL_ALCHEMY_CONN not set. This is 
a critical assertion.")
+        from alembic import command
+
+        config = self.get_alembic_config()
+
+        if show_sql_only:
+            if settings.engine.dialect.name == "sqlite":
+                raise SystemExit("Offline migration not supported for SQLite.")
+            if not from_revision:
+                from_revision = self.get_current_revision()
+
+            if not to_revision:
+                script = self.get_script_object(config)
+                to_revision = script.get_current_head()
+
+            if to_revision == from_revision:
+                print_happy_cat("No migrations to apply; nothing to do.")
+                return
+            _offline_migration(command.upgrade, config, 
f"{from_revision}:{to_revision}")
+            return  # only running sql; our job is done
+
+        if not self.get_current_revision():
+            # New DB; initialize and exit
+            self.initdb()
+            return
+
+        command.upgrade(config, revision=to_revision or "heads")
+
+    def downgrade(self, to_revision, from_revision=None, show_sql_only=False):
+        if from_revision and not show_sql_only:
+            raise ValueError(
+                "`from_revision` can't be combined with `show_sql_only=False`. 
When actually "
+                "applying a downgrade (instead of just generating sql), we 
always "
+                "downgrade from current revision."
+            )
+
+        if not settings.SQL_ALCHEMY_CONN:
+            raise RuntimeError("The settings.SQL_ALCHEMY_CONN not set.")
+
+        # alembic adds significant import time, so we import it lazily
+        from alembic import command
+
+        self.log.info("Attempting downgrade of FAB migration to revision %s", 
to_revision)
+        config = self.get_alembic_config()
+
+        if show_sql_only:
+            self.log.warning("Generating sql scripts for manual migration.")
+            if not from_revision:
+                from_revision = self.get_current_revision()
+            if from_revision is None:
+                self.log.info("No revision found")
+                return
+            revision_range = f"{from_revision}:{to_revision}"
+            _offline_migration(command.downgrade, config=config, 
revision=revision_range)
+        else:
+            self.log.info("Applying FAB downgrade migrations.")
+            command.downgrade(config, revision=to_revision, sql=show_sql_only)
diff --git a/airflow/providers/fab/migrations/script.py.mako 
b/airflow/providers/fab/migrations/script.py.mako
index 4d0928fcc0..c0193ce2b0 100644
--- a/airflow/providers/fab/migrations/script.py.mako
+++ b/airflow/providers/fab/migrations/script.py.mako
@@ -30,10 +30,11 @@ import sqlalchemy as sa
 ${imports if imports else ""}
 
 # revision identifiers, used by Alembic.
-revision: str = ${repr(up_revision)}
-down_revision: Union[str, None] = ${repr(down_revision)}
-branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
-depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}
+revision = ${repr(up_revision)}
+down_revision = ${repr(down_revision)}
+branch_labels = ${repr(branch_labels)}
+depends_on = ${repr(depends_on)}
+fab_version = None
 
 
 def upgrade() -> None:
diff --git a/airflow/providers/fab/auth_manager/models/db.py 
b/airflow/providers/fab/migrations/versions/0001_1_3_0_placeholder_migration.py
similarity index 60%
copy from airflow/providers/fab/auth_manager/models/db.py
copy to 
airflow/providers/fab/migrations/versions/0001_1_3_0_placeholder_migration.py
index a971ea29a3..685216779d 100644
--- a/airflow/providers/fab/auth_manager/models/db.py
+++ 
b/airflow/providers/fab/migrations/versions/0001_1_3_0_placeholder_migration.py
@@ -1,3 +1,4 @@
+#
 # Licensed to the Apache Software Foundation (ASF) under one
 # or more contributor license agreements.  See the NOTICE file
 # distributed with this work for additional information
@@ -14,22 +15,31 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-from __future__ import annotations
 
-import os
+"""
+placeholder migration.
+
+Revision ID: 6709f7a774b9
+Revises:
+Create Date: 2024-09-03 17:06:38.040510
+
+Note: This is a placeholder migration used to stamp the migration
+when we create the migration from the ORM. Otherwise, it will run
+without stamping the migration, leading to subsequent changes to
+the tables not being migrated.
+"""
+
+from __future__ import annotations
 
-import airflow
-from airflow.providers.fab.auth_manager.models import metadata
-from airflow.utils.db_manager import BaseDBManager
+# revision identifiers, used by Alembic.
+revision = "6709f7a774b9"
+down_revision = None
+branch_labels = None
+depends_on = None
+fab_version = "1.3.0"
 
-PACKAGE_DIR = os.path.dirname(airflow.__file__)
 
+def upgrade() -> None: ...
 
-class FABDBManager(BaseDBManager):
-    """Manages FAB database."""
 
-    metadata = metadata
-    version_table_name = "fab_alembic_version"
-    migration_dir = os.path.join(PACKAGE_DIR, "providers/fab/migrations")
-    alembic_file = os.path.join(PACKAGE_DIR, "providers/fab/alembic.ini")
-    supports_table_dropping = True
+def downgrade() -> None: ...
diff --git a/airflow/utils/db.py b/airflow/utils/db.py
index bc4c60697c..1dc7275610 100644
--- a/airflow/utils/db.py
+++ b/airflow/utils/db.py
@@ -89,7 +89,7 @@ T = TypeVar("T")
 
 log = logging.getLogger(__name__)
 
-_REVISION_HEADS_MAP = {
+_REVISION_HEADS_MAP: dict[str, str] = {
     "2.7.0": "405de8318b3a",
     "2.8.0": "10b52ebd31f7",
     "2.8.1": "88344c1d9134",
@@ -819,6 +819,7 @@ def check_migrations(timeout):
     :return: None
     """
     timeout = timeout or 1  # run the loop at least 1
+    external_db_manager = RunDBManager()
     with _configured_alembic_environment() as env:
         context = env.get_context()
         source_heads = None
@@ -826,7 +827,7 @@ def check_migrations(timeout):
         for ticker in range(timeout):
             source_heads = set(env.script.get_heads())
             db_heads = set(context.get_current_heads())
-            if source_heads == db_heads:
+            if source_heads == db_heads and 
external_db_manager.check_migration(settings.Session()):
                 return
             time.sleep(1)
             log.info("Waiting for migrations... %s second(s)", ticker)
@@ -1027,6 +1028,11 @@ def _check_migration_errors(session: Session = 
NEW_SESSION) -> Iterable[str]:
 
 
 def _offline_migration(migration_func: Callable, config, revision):
+    """
+    Run offline migration.
+
+    :meta private:
+    """
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
         logging.disable(logging.CRITICAL)
@@ -1160,6 +1166,13 @@ def upgradedb(
             os.environ["AIRFLOW__DATABASE__SQL_ALCHEMY_MAX_SIZE"] = "1"
             
settings.reconfigure_orm(pool_class=sqlalchemy.pool.SingletonThreadPool)
             command.upgrade(config, revision=to_revision or "heads")
+            current_revision = _get_current_revision(session=session)
+            with _configured_alembic_environment() as env:
+                source_heads = env.script.get_heads()
+            if current_revision == source_heads[0]:
+                # Only run external DB upgrade migration if user upgraded to 
heads
+                external_db_manager = RunDBManager()
+                external_db_manager.upgradedb(session)
 
         finally:
             if val is None:
@@ -1168,8 +1181,6 @@ def upgradedb(
                 os.environ["AIRFLOW__DATABASE__SQL_ALCHEMY_MAX_SIZE"] = val
             settings.reconfigure_orm()
 
-        current_revision = _get_current_revision(session=session)
-
         if reserialize_dags and current_revision != previous_revision:
             _reserialize_dags(session=session)
         add_default_pool_if_not_exists(session=session)
@@ -1191,7 +1202,7 @@ def resetdb(session: Session = NEW_SESSION, skip_init: 
bool = False):
         drop_airflow_models(connection)
         drop_airflow_moved_tables(connection)
         external_db_manager = RunDBManager()
-        external_db_manager.drop_tables(connection)
+        external_db_manager.drop_tables(session, connection)
 
     if not skip_init:
         initdb(session=session)
diff --git a/airflow/utils/db_manager.py b/airflow/utils/db_manager.py
index 78d3244cd2..2241a646ef 100644
--- a/airflow/utils/db_manager.py
+++ b/airflow/utils/db_manager.py
@@ -20,13 +20,16 @@ import os
 from typing import TYPE_CHECKING
 
 from alembic import command
+from sqlalchemy import inspect
 
+from airflow import settings
 from airflow.configuration import conf
 from airflow.exceptions import AirflowException
 from airflow.utils.log.logging_mixin import LoggingMixin
 from airflow.utils.module_loading import import_string
 
 if TYPE_CHECKING:
+    from alembic.script import ScriptDirectory
     from sqlalchemy import MetaData
 
 
@@ -54,14 +57,34 @@ class BaseDBManager(LoggingMixin):
         config.set_main_option("sqlalchemy.url", 
settings.SQL_ALCHEMY_CONN.replace("%", "%%"))
         return config
 
-    def get_current_revision(self):
+    def get_script_object(self, config=None) -> ScriptDirectory:
+        from alembic.script import ScriptDirectory
+
+        if not config:
+            config = self.get_alembic_config()
+        return ScriptDirectory.from_config(config)
+
+    def _get_migration_ctx(self):
         from alembic.migration import MigrationContext
 
         conn = self.session.connection()
 
-        migration_ctx = MigrationContext.configure(conn, 
opts={"version_table": self.version_table_name})
+        return MigrationContext.configure(conn, opts={"version_table": 
self.version_table_name})
 
-        return migration_ctx.get_current_revision()
+    def get_current_revision(self):
+        return self._get_migration_ctx().get_current_revision()
+
+    def check_migration(self):
+        """Check migration done."""
+        script_heads = self.get_script_object().get_heads()
+        db_heads = self.get_current_revision()
+        if db_heads:
+            db_heads = {db_heads}
+        if not db_heads and not script_heads:
+            return True
+        if set(script_heads) == db_heads:
+            return True
+        return False
 
     def _create_db_from_orm(self):
         """Create database from ORM."""
@@ -70,6 +93,22 @@ class BaseDBManager(LoggingMixin):
         config = self.get_alembic_config()
         command.stamp(config, "head")
 
+    def drop_tables(self, connection):
+        self.metadata.drop_all(connection)
+        version = self._get_migration_ctx()._version
+        if inspect(connection).has_table(version.name):
+            version.drop(connection)
+
+    def resetdb(self, skip_init=False):
+        from airflow.utils.db import DBLocks, create_global_lock
+
+        connection = settings.engine.connect()
+
+        with create_global_lock(self.session, lock=DBLocks.MIGRATIONS), 
connection.begin():
+            self.drop_tables(connection)
+        if not skip_init:
+            self.initdb()
+
     def initdb(self):
         """Initialize the database."""
         db_exists = self.get_current_revision()
@@ -78,14 +117,14 @@ class BaseDBManager(LoggingMixin):
         else:
             self._create_db_from_orm()
 
-    def upgradedb(self, to_version=None, from_version=None, 
show_sql_only=False):
+    def upgradedb(self, to_revision=None, from_revision=None, 
show_sql_only=False):
         """Upgrade the database."""
         self.log.info("Upgrading the %s database", self.__class__.__name__)
 
         config = self.get_alembic_config()
-        command.upgrade(config, revision=to_version or "heads", 
sql=show_sql_only)
+        command.upgrade(config, revision=to_revision or "heads", 
sql=show_sql_only)
 
-    def downgradedb(self, to_version, from_version=None, show_sql_only=False):
+    def downgrade(self, to_version, from_version=None, show_sql_only=False):
         """Downgrade the database."""
         raise NotImplementedError
 
@@ -141,6 +180,14 @@ class RunDBManager(LoggingMixin):
         if manager.version_table_name == "alembic_version":
             raise AirflowException(f"{manager}.version_table_name cannot be 
'alembic_version'")
 
+    def check_migration(self, session):
+        """Check the external database migration."""
+        return_value = []
+        for manager in self._managers:
+            m = manager(session)
+            return_value.append(m.check_migration)
+        return all([x() for x in return_value])
+
     def initdb(self, session):
         """Initialize the external database managers."""
         for manager in self._managers:
@@ -153,14 +200,15 @@ class RunDBManager(LoggingMixin):
             m = manager(session)
             m.upgradedb()
 
-    def downgradedb(self, session):
+    def downgrade(self, session):
         """Downgrade the external database managers."""
         for manager in self._managers:
             m = manager(session)
-            m.downgradedb()
+            m.downgrade()
 
-    def drop_tables(self, connection):
+    def drop_tables(self, session, connection):
         """Drop the external database managers."""
         for manager in self._managers:
             if manager.supports_table_dropping:
-                manager.metadata.drop_all(connection)
+                m = manager(session)
+                m.drop_tables(connection)
diff --git a/scripts/ci/pre_commit/version_heads_map.py 
b/scripts/ci/pre_commit/version_heads_map.py
index 4277c46564..10a6dee2ea 100755
--- a/scripts/ci/pre_commit/version_heads_map.py
+++ b/scripts/ci/pre_commit/version_heads_map.py
@@ -23,21 +23,27 @@ import sys
 from pathlib import Path
 
 import re2
-from packaging.version import parse as parse_version
 
 PROJECT_SOURCE_ROOT_DIR = Path(__file__).resolve().parent.parent.parent.parent
 
 DB_FILE = PROJECT_SOURCE_ROOT_DIR / "airflow" / "utils" / "db.py"
 MIGRATION_PATH = PROJECT_SOURCE_ROOT_DIR / "airflow" / "migrations" / 
"versions"
 
+FAB_DB_FILE = PROJECT_SOURCE_ROOT_DIR / "airflow" / "providers" / "fab" / 
"auth_manager" / "models" / "db.py"
+FAB_MIGRATION_PATH = PROJECT_SOURCE_ROOT_DIR / "airflow" / "providers" / "fab" 
/ "migrations" / "versions"
+
 sys.path.insert(0, str(Path(__file__).parent.resolve()))  # make sure 
common_precommit_utils is importable
 
 
-def revision_heads_map():
+def revision_heads_map(migration_path):
     rh_map = {}
     pattern = r'revision = "[a-fA-F0-9]+"'
-    airflow_version_pattern = r'airflow_version = "\d+\.\d+\.\d+"'
-    filenames = os.listdir(MIGRATION_PATH)
+    version_pattern = None
+    if migration_path == MIGRATION_PATH:
+        version_pattern = r'airflow_version = "\d+\.\d+\.\d+"'
+    elif migration_path == FAB_MIGRATION_PATH:
+        version_pattern = r'fab_version = "\d+\.\d+\.\d+"'
+    filenames = os.listdir(migration_path)
 
     def sorting_key(filen):
         prefix = filen.split("_")[0]
@@ -46,43 +52,46 @@ def revision_heads_map():
     sorted_filenames = sorted(filenames, key=sorting_key)
 
     for filename in sorted_filenames:
-        if not filename.endswith(".py"):
+        if not filename.endswith(".py") or filename == "__init__.py":
             continue
-        with open(os.path.join(MIGRATION_PATH, filename)) as file:
+        with open(os.path.join(migration_path, filename)) as file:
             content = file.read()
             revision_match = re2.search(pattern, content)
-            airflow_version_match = re2.search(airflow_version_pattern, 
content)
-            if revision_match and airflow_version_match:
+            _version_match = re2.search(version_pattern, content)
+            if revision_match and _version_match:
                 revision = revision_match.group(0).split('"')[1]
-                version = airflow_version_match.group(0).split('"')[1]
-                if parse_version(version) >= parse_version("2.0.0"):
-                    rh_map[version] = revision
+                version = _version_match.group(0).split('"')[1]
+                rh_map[version] = revision
     return rh_map
 
 
 if __name__ == "__main__":
-    with open(DB_FILE) as file:
-        content = file.read()
-
-    pattern = r"_REVISION_HEADS_MAP = {[^}]+\}"
-    match = re2.search(pattern, content)
-    if not match:
-        print(
-            f"_REVISION_HEADS_MAP not found in {DB_FILE}. If this has been 
removed intentionally, "
-            "please update scripts/ci/pre_commit/version_heads_map.py"
-        )
-        sys.exit(1)
-
-    existing_revision_heads_map = match.group(0)
-    rh_map = revision_heads_map()
-    updated_revision_heads_map = "_REVISION_HEADS_MAP = {\n"
-    for k, v in rh_map.items():
-        updated_revision_heads_map += f'    "{k}": "{v}",\n'
-    updated_revision_heads_map += "}"
-    if existing_revision_heads_map != updated_revision_heads_map:
-        new_content = content.replace(existing_revision_heads_map, 
updated_revision_heads_map)
-
-        with open(DB_FILE, "w") as file:
-            file.write(new_content)
-        print("_REVISION_HEADS_MAP updated in db.py. Please commit the 
changes.")
-        sys.exit(1)
+    paths = [(DB_FILE, MIGRATION_PATH), (FAB_DB_FILE, FAB_MIGRATION_PATH)]
+    for dbfile, mpath in paths:
+        with open(dbfile) as file:
+            content = file.read()
+
+        pattern = 
r"_REVISION_HEADS_MAP:\s*dict\[\s*str\s*,\s*str\s*\]\s*=\s*\{[^}]*\}"
+        match = re2.search(pattern, content)
+        if not match:
+            print(
+                f"_REVISION_HEADS_MAP not found in {dbfile}. If this has been 
removed intentionally, "
+                "please update scripts/ci/pre_commit/version_heads_map.py"
+            )
+            sys.exit(1)
+
+        existing_revision_heads_map = match.group(0)
+        rh_map = revision_heads_map(mpath)
+        updated_revision_heads_map = "_REVISION_HEADS_MAP: dict[str, str] = 
{\n"
+        for k, v in rh_map.items():
+            updated_revision_heads_map += f'    "{k}": "{v}",\n'
+        updated_revision_heads_map += "}"
+        if updated_revision_heads_map == "_REVISION_HEADS_MAP: dict[str, str] 
= {\n}":
+            updated_revision_heads_map = "_REVISION_HEADS_MAP: dict[str, str] 
= {}"
+        if existing_revision_heads_map != updated_revision_heads_map:
+            new_content = content.replace(existing_revision_heads_map, 
updated_revision_heads_map)
+
+            with open(dbfile, "w") as file:
+                file.write(new_content)
+            print(f"_REVISION_HEADS_MAP updated in {dbfile}. Please commit the 
changes.")
+            sys.exit(1)
diff --git a/scripts/in_container/run_migration_reference.py 
b/scripts/in_container/run_migration_reference.py
index 204db7a5ea..fc7d3bd084 100755
--- a/scripts/in_container/run_migration_reference.py
+++ b/scripts/in_container/run_migration_reference.py
@@ -138,8 +138,7 @@ def get_revisions(app="airflow") -> Iterable[Script]:
     else:
         from airflow.providers.fab.auth_manager.models.db import FABDBManager
 
-        config = FABDBManager(session="").get_alembic_config()
-        script = ScriptDirectory.from_config(config)
+        script = FABDBManager(session="").get_script_object()
         yield from script.walk_revisions()
 
 
@@ -234,9 +233,9 @@ def correct_mismatching_revision_nums(revisions: 
Iterable[Script]):
 if __name__ == "__main__":
     apps = ["airflow", "fab"]
     for app in apps:
-        console.print("[bright_blue]Updating migration reference")
+        console.print(f"[bright_blue]Updating migration reference for {app}")
         revisions = list(reversed(list(get_revisions(app))))
-        console.print("[bright_blue]Making sure airflow version updated")
+        console.print(f"[bright_blue]Making sure {app} version updated")
         ensure_version(revisions=revisions, app=app)
         console.print("[bright_blue]Making sure there's no mismatching 
revision numbers")
         correct_mismatching_revision_nums(revisions=revisions)
diff --git a/tests/always/test_project_structure.py 
b/tests/always/test_project_structure.py
index 1e684888dc..a6d174e639 100644
--- a/tests/always/test_project_structure.py
+++ b/tests/always/test_project_structure.py
@@ -169,6 +169,8 @@ class TestProjectStructure:
         modules_files = list(f for f in modules_files if "/_vendor/" not in f)
         # Exclude __init__.py
         modules_files = list(f for f in modules_files if not 
f.endswith("__init__.py"))
+        # Exclude versions file
+        modules_files = list(f for f in modules_files if "/versions/" not in f)
         # Change airflow/ to tests/
         expected_test_files = list(
             f'tests/{f.partition("/")[2]}' for f in modules_files if not 
f.endswith("__init__.py")
diff --git a/tests/providers/fab/auth_manager/cli_commands/test_db_command.py 
b/tests/providers/fab/auth_manager/cli_commands/test_db_command.py
new file mode 100644
index 0000000000..030b251a55
--- /dev/null
+++ b/tests/providers/fab/auth_manager/cli_commands/test_db_command.py
@@ -0,0 +1,134 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from unittest import mock
+
+import pytest
+
+from airflow.cli import cli_parser
+
# Every test in this module requires database access.
pytestmark = [pytest.mark.db_test]
try:
    from airflow.providers.fab.auth_manager.cli_commands import db_command
    from airflow.providers.fab.auth_manager.models.db import FABDBManager

    class TestFABCLiDB:
        """CLI-level tests for the ``fab-db`` command group."""

        @classmethod
        def setup_class(cls):
            # Build the argparse parser once for all tests in the class.
            cls.parser = cli_parser.get_parser()

        @mock.patch.object(FABDBManager, "resetdb")
        def test_cli_resetdb(self, mock_resetdb):
            """``fab-db reset --yes`` delegates to FABDBManager.resetdb."""
            db_command.resetdb(self.parser.parse_args(["fab-db", "reset", "--yes"]))

            mock_resetdb.assert_called_once_with(skip_init=False)

        @mock.patch.object(FABDBManager, "resetdb")
        def test_cli_resetdb_skip_init(self, mock_resetdb):
            """``--skip-init`` is forwarded to resetdb."""
            db_command.resetdb(self.parser.parse_args(["fab-db", "reset", "--yes", "--skip-init"]))
            mock_resetdb.assert_called_once_with(skip_init=True)

        @pytest.mark.parametrize(
            "args, called_with",
            [
                (
                    [],
                    dict(
                        to_revision=None,
                        from_revision=None,
                        show_sql_only=False,
                    ),
                ),
                (
                    ["--show-sql-only"],
                    dict(
                        to_revision=None,
                        from_revision=None,
                        show_sql_only=True,
                    ),
                ),
                (
                    ["--to-revision", "abc"],
                    dict(
                        to_revision="abc",
                        from_revision=None,
                        show_sql_only=False,
                    ),
                ),
                (
                    ["--to-revision", "abc", "--show-sql-only"],
                    dict(to_revision="abc", from_revision=None, show_sql_only=True),
                ),
                (
                    ["--to-revision", "abc", "--from-revision", "abc123", "--show-sql-only"],
                    dict(
                        to_revision="abc",
                        from_revision="abc123",
                        show_sql_only=True,
                    ),
                ),
            ],
        )
        @mock.patch.object(FABDBManager, "upgradedb")
        def test_cli_upgrade_success(self, mock_upgradedb, args, called_with):
            """Valid ``fab-db migrate`` invocations map CLI flags onto upgradedb kwargs."""
            db_command.migratedb(self.parser.parse_args(["fab-db", "migrate", *args]))
            mock_upgradedb.assert_called_once_with(**called_with)

        @pytest.mark.parametrize(
            "args, pattern",
            [
                pytest.param(
                    ["--to-revision", "abc", "--to-version", "1.3.0"],
                    "Cannot supply both",
                    id="to both version and revision",
                ),
                pytest.param(
                    ["--from-revision", "abc", "--from-version", "1.3.0"],
                    "Cannot supply both",
                    id="from both version and revision",
                ),
                pytest.param(["--to-version", "1.2.0"], "Unknown version '1.2.0'", id="unknown to version"),
                pytest.param(["--to-version", "abc"], "Invalid version 'abc'", id="invalid to version"),
                pytest.param(
                    ["--to-revision", "abc", "--from-revision", "abc123"],
                    "used with `--show-sql-only`",
                    id="requires offline",
                ),
                pytest.param(
                    ["--to-revision", "abc", "--from-version", "1.3.0"],
                    "used with `--show-sql-only`",
                    id="requires offline",
                ),
                pytest.param(
                    ["--to-revision", "abc", "--from-version", "1.1.25", "--show-sql-only"],
                    "Unknown version '1.1.25'",
                    id="unknown from version",
                ),
                pytest.param(
                    ["--to-revision", "adaf", "--from-version", "abc", "--show-sql-only"],
                    "Invalid version 'abc'",
                    id="invalid from version",
                ),
            ],
        )
        @mock.patch.object(FABDBManager, "upgradedb")
        def test_cli_migratedb_failure(self, mock_upgradedb, args, pattern):
            """Invalid flag combinations exit with a helpful error message."""
            with pytest.raises(SystemExit, match=pattern):
                db_command.migratedb(self.parser.parse_args(["fab-db", "migrate", *args]))
except (ModuleNotFoundError, ImportError):
    # The FAB provider may not be installed in this environment; skip the module.
    pass
diff --git a/tests/providers/fab/auth_manager/models/test_db.py 
b/tests/providers/fab/auth_manager/models/test_db.py
index 7b0dc345c7..528e1cbf09 100644
--- a/tests/providers/fab/auth_manager/models/test_db.py
+++ b/tests/providers/fab/auth_manager/models/test_db.py
@@ -17,6 +17,8 @@
 from __future__ import annotations
 
 import os
+import re
+from unittest import mock
 
 import pytest
 from alembic.autogenerate import compare_metadata
@@ -35,30 +37,33 @@ try:
     from airflow.providers.fab.auth_manager.models.db import FABDBManager
 
     class TestFABDBManager:
        def setup_method(self):
            # Resolve the installed airflow package directory for path assertions below.
            self.airflow_dir = os.path.dirname(airflow.__file__)
 
        def test_version_table_name_set(self, session):
            # FAB migrations track their state in a dedicated version table,
            # distinct from Airflow's own "alembic_version".
            assert FABDBManager(session=session).version_table_name == "alembic_version_fab"
 
        def test_migration_dir_set(self, session):
            # Migration scripts live inside the FAB provider package.
            assert (
                FABDBManager(session=session).migration_dir == f"{self.airflow_dir}/providers/fab/migrations"
            )
 
        def test_alembic_file_set(self, session):
            # The provider ships its own alembic.ini.
            assert (
                FABDBManager(session=session).alembic_file == f"{self.airflow_dir}/providers/fab/alembic.ini"
            )
 
        def test_supports_table_dropping_set(self, session):
            # resetdb relies on this flag to know the manager's tables can be dropped.
            assert FABDBManager(session=session).supports_table_dropping is True
 
-        def test_database_schema_and_sqlalchemy_model_are_in_sync(self):
+        def test_database_schema_and_sqlalchemy_model_are_in_sync(self, 
session):
             def include_object(_, name, type_, *args):
-                if type_ == "table" and name not in 
self.db_manager.metadata.tables:
+                if type_ == "table" and name not in 
FABDBManager(session=session).metadata.tables:
                     return False
                 return True
 
             all_meta_data = MetaData()
-            for table_name, table in self.db_manager.metadata.tables.items():
+            for table_name, table in 
FABDBManager(session=session).metadata.tables.items():
                 all_meta_data._add_table(table_name, table.schema, table)
             # create diff between database schema and SQLAlchemy model
             mctx = MigrationContext.configure(
@@ -72,5 +77,61 @@ try:
             diff = compare_metadata(mctx, all_meta_data)
 
             assert not diff, "Database schema and SQLAlchemy model are not in 
sync: " + str(diff)
+
        @mock.patch("airflow.providers.fab.auth_manager.models.db._offline_migration")
        def test_downgrade_sql_no_from(self, mock_om, session, caplog):
            # Without from_revision, the current revision becomes the range
            # start, so the offline revision string is "<current>:abc".
            # NOTE(review): caplog fixture appears unused here — confirm.
            FABDBManager(session=session).downgrade(to_revision="abc", show_sql_only=True, from_revision=None)
            actual = mock_om.call_args.kwargs["revision"]
            assert re.match(r"[a-z0-9]+:abc", actual) is not None
+
        @mock.patch("airflow.providers.fab.auth_manager.models.db._offline_migration")
        def test_downgrade_sql_with_from(self, mock_om, session):
            # An explicit from_revision is passed straight through as "from:to".
            FABDBManager(session=session).downgrade(
                to_revision="abc", show_sql_only=True, from_revision="123"
            )
            actual = mock_om.call_args.kwargs["revision"]
            assert actual == "123:abc"
+
+        @mock.patch("alembic.command.downgrade")
+        def test_downgrade_invalid_combo(self, mock_om, session):
+            """can't combine `sql=False` and `from_revision`"""
+            with pytest.raises(ValueError, match="can't be combined"):
+                FABDBManager(session=session).downgrade(to_revision="abc", 
from_revision="123")
+
        @mock.patch("alembic.command.downgrade")
        def test_downgrade_with_from(self, mock_om, session):
            # NOTE(review): despite the name, no from_revision is supplied here;
            # this covers a plain online downgrade to a single revision.
            FABDBManager(session=session).downgrade(to_revision="abc")
            actual = mock_om.call_args.kwargs["revision"]
            assert actual == "abc"
+
        @mock.patch.object(FABDBManager, "get_current_revision")
        def test_sqlite_offline_upgrade_raises_with_revision(self, mock_gcr, session):
            # Offline (SQL-only) migration is rejected on SQLite; force the
            # dialect name and expect the CLI-style SystemExit.
            with mock.patch(
                "airflow.providers.fab.auth_manager.models.db.settings.engine.dialect"
            ) as dialect:
                dialect.name = "sqlite"
                with pytest.raises(SystemExit, match="Offline migration not supported for SQLite"):
                    FABDBManager(session).upgradedb(from_revision=None, to_revision=None, show_sql_only=True)
+
        @mock.patch("airflow.utils.db_manager.inspect")
        @mock.patch.object(FABDBManager, "metadata")
        def test_drop_tables(self, mock_metadata, mock_inspect, session):
            # drop_tables must delegate to metadata.drop_all with the connection.
            manager = FABDBManager(session)
            connection = mock.MagicMock()
            manager.drop_tables(connection)
            mock_metadata.drop_all.assert_called_once_with(connection)
+
        @pytest.mark.parametrize("skip_init", [True, False])
        @mock.patch.object(FABDBManager, "drop_tables")
        @mock.patch.object(FABDBManager, "initdb")
        @mock.patch("airflow.utils.db.create_global_lock", new=mock.MagicMock)
        def test_resetdb(self, mock_initdb, mock_drop_tables, session, skip_init):
            # resetdb always drops tables; initdb runs only when skip_init is False.
            manager = FABDBManager(session)
            manager.resetdb(skip_init=skip_init)
            mock_drop_tables.assert_called_once()
            if skip_init:
                mock_initdb.assert_not_called()
            else:
                mock_initdb.assert_called_once()
 except ModuleNotFoundError:
     pass
diff --git a/tests/utils/test_db.py b/tests/utils/test_db.py
index 2308ed3ca8..2a197c2e6c 100644
--- a/tests/utils/test_db.py
+++ b/tests/utils/test_db.py
@@ -49,6 +49,7 @@ from airflow.utils.db import (
     upgradedb,
 )
 from airflow.utils.db_manager import RunDBManager
+from tests.test_utils.config import conf_vars
 
 pytestmark = [pytest.mark.db_test, pytest.mark.skip_if_database_isolation_mode]
 
@@ -92,7 +93,7 @@ class TestDb:
             # sqlite sequence is used for autoincrementing columns created 
with `sqlite_autoincrement` option
             lambda t: (t[0] == "remove_table" and t[1].name == 
"sqlite_sequence"),
             # fab version table
-            lambda t: (t[0] == "remove_table" and t[1].name == 
"fab_alembic_version"),
+            lambda t: (t[0] == "remove_table" and t[1].name == 
"alembic_version_fab"),
         ]
 
         for ignore in ignores:
@@ -131,7 +132,8 @@ class TestDb:
    @mock.patch("alembic.command")
    def test_upgradedb(self, mock_alembic_command):
        # Airflow's upgrade now also upgrades external (FAB) DB managers,
        # hence two alembic upgrade calls instead of one.
        upgradedb()
        mock_alembic_command.upgrade.assert_called_with(mock.ANY, revision="heads")
        assert mock_alembic_command.upgrade.call_count == 2
 
     @pytest.mark.parametrize(
         "from_revision, to_revision",
@@ -200,6 +202,10 @@ class TestDb:
         assert actual == "abc"
 
     @pytest.mark.parametrize("skip_init", [False, True])
+    @conf_vars(
+        {("database", "external_db_managers"): 
"airflow.providers.fab.auth_manager.models.db.FABDBManager"}
+    )
+    @mock.patch("airflow.providers.fab.auth_manager.models.db.FABDBManager")
     @mock.patch("airflow.utils.db.create_global_lock", new=MagicMock)
     @mock.patch("airflow.utils.db.drop_airflow_models")
     @mock.patch("airflow.utils.db.drop_airflow_moved_tables")
@@ -211,6 +217,7 @@ class TestDb:
         mock_init,
         mock_drop_moved,
         mock_drop_airflow,
+        mock_fabdb_manager,
         skip_init,
     ):
         session_mock = MagicMock()
diff --git a/tests/utils/test_db_manager.py b/tests/utils/test_db_manager.py
index d6fc91f939..1c8a6c6c7d 100644
--- a/tests/utils/test_db_manager.py
+++ b/tests/utils/test_db_manager.py
@@ -23,7 +23,7 @@ from sqlalchemy import Table
 
 from airflow.exceptions import AirflowException
 from airflow.models import Base
-from airflow.utils.db import downgrade, initdb, upgradedb
+from airflow.utils.db import downgrade, initdb
 from airflow.utils.db_manager import BaseDBManager, RunDBManager
 from tests.test_utils.config import conf_vars
 
@@ -57,29 +57,26 @@ class TestRunDBManager:
             run_db_manager.validate()
         metadata._remove_table("dag_run", None)
 
-    @mock.patch.object(RunDBManager, "downgradedb")
+    @mock.patch.object(RunDBManager, "downgrade")
     @mock.patch.object(RunDBManager, "upgradedb")
     @mock.patch.object(RunDBManager, "initdb")
     def test_init_db_calls_rundbmanager(self, mock_initdb, mock_upgrade_db, 
mock_downgrade_db, session):
         initdb(session=session)
         mock_initdb.assert_called()
         mock_initdb.assert_called_once_with(session)
-        mock_upgrade_db.assert_not_called()
         mock_downgrade_db.assert_not_called()
 
    @mock.patch.object(RunDBManager, "downgrade")
    @mock.patch.object(RunDBManager, "upgradedb")
    @mock.patch.object(RunDBManager, "initdb")
    @mock.patch("alembic.command")
    def test_downgrade_dont_call_rundbmanager(
        self, mock_alembic_command, mock_initdb, mock_upgrade_db, mock_downgrade_db, session
    ):
        # Downgrading Airflow's own DB must not touch the external managers;
        # per the commit message, FAB downgrade only happens via its own command.
        downgrade(to_revision="base")
        mock_alembic_command.downgrade.assert_called_once_with(mock.ANY, revision="base", sql=False)
        mock_upgrade_db.assert_not_called()
        mock_initdb.assert_not_called()
        mock_downgrade_db.assert_not_called()
 
     @conf_vars(
@@ -96,12 +93,12 @@ class TestRunDBManager:
         # upgradedb
         ext_db.upgradedb(session=session)
         fabdb_manager.upgradedb.assert_called_once()
-        # downgradedb
-        ext_db.downgradedb(session=session)
-        mock_fabdb_manager.return_value.downgradedb.assert_called_once()
+        # downgrade
+        ext_db.downgrade(session=session)
+        mock_fabdb_manager.return_value.downgrade.assert_called_once()
         connection = mock.MagicMock()
-        ext_db.drop_tables(connection)
-        
mock_fabdb_manager.metadata.drop_all.assert_called_once_with(connection)
+        ext_db.drop_tables(session, connection)
+        
mock_fabdb_manager.return_value.drop_tables.assert_called_once_with(connection)
 
 
 class MockDBManager(BaseDBManager):
@@ -127,8 +124,14 @@ class TestBaseDBManager:
 
    @mock.patch.object(BaseDBManager, "get_alembic_config")
    @mock.patch("alembic.command.upgrade")
    def test_upgrade(self, mock_alembic_cmd, mock_alembic_config, session, caplog):
        # upgradedb should invoke alembic exactly once and log which manager ran.
        manager = MockDBManager(session)
        manager.upgradedb()
        mock_alembic_cmd.assert_called_once()
        assert "Upgrading the MockDBManager database" in caplog.text
+
+    @mock.patch.object(BaseDBManager, "get_script_object")
+    @mock.patch.object(BaseDBManager, "get_current_revision")
+    def test_check_migration(self, mock_script_obj, mock_current_revision, 
session):
+        manager = MockDBManager(session)
+        manager.check_migration()  # just ensure this can be called

Reply via email to