This is an automated email from the ASF dual-hosted git repository.

jedcunningham pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 2500dcf20d2 Move FAB session table creation to FAB provider (#47969)
2500dcf20d2 is described below

commit 2500dcf20d2782d16da53ee857c0aab21bfdfbf2
Author: Jed Cunningham <66968678+jedcunning...@users.noreply.github.com>
AuthorDate: Wed Mar 19 15:41:37 2025 -0600

    Move FAB session table creation to FAB provider (#47969)
    
    We need to create the `session` table in the provider db manager, not in
    the core db utils.
    
    Co-authored-by: vincbeck <vincb...@amazon.com>
---
 airflow/utils/db.py                                | 21 -----------------
 .../providers/fab/auth_manager/models/db.py        | 27 ++++++++++++++++++----
 .../tests/unit/fab/auth_manager/models/test_db.py  |  4 +++-
 3 files changed, 25 insertions(+), 27 deletions(-)

diff --git a/airflow/utils/db.py b/airflow/utils/db.py
index 70e6955c39b..7055a1c7601 100644
--- a/airflow/utils/db.py
+++ b/airflow/utils/db.py
@@ -729,34 +729,15 @@ def create_default_connections(session: Session = NEW_SESSION):
     )
 
 
-def _get_flask_db(sql_database_uri):
-    from flask import Flask
-    from flask_sqlalchemy import SQLAlchemy
-
-    from airflow.providers.fab.www.session import AirflowDatabaseSessionInterface
-
-    flask_app = Flask(__name__)
-    flask_app.config["SQLALCHEMY_DATABASE_URI"] = sql_database_uri
-    flask_app.config["SQLALCHEMY_TRACK_MODIFICATIONS"] = False
-    db = SQLAlchemy(flask_app)
-    AirflowDatabaseSessionInterface(app=flask_app, db=db, table="session", key_prefix="")
-    return db
-
-
 def _create_db_from_orm(session):
     log.info("Creating Airflow database tables from the ORM")
     from alembic import command
 
     from airflow.models.base import Base
 
-    def _create_flask_session_tbl(sql_database_uri):
-        db = _get_flask_db(sql_database_uri)
-        db.create_all()
-
     with create_global_lock(session=session, lock=DBLocks.MIGRATIONS):
         engine = session.get_bind().engine
         Base.metadata.create_all(engine)
-        _create_flask_session_tbl(engine.url)
         # stamp the migration head
         config = _get_alembic_config()
         command.stamp(config, "head")
@@ -1254,8 +1235,6 @@ def drop_airflow_models(connection):
     from airflow.models.base import Base
 
     Base.metadata.drop_all(connection)
-    db = _get_flask_db(connection.engine.url)
-    db.drop_all()
     # alembic adds significant import time, so we import it lazily
     from alembic.migration import MigrationContext
 
diff --git a/providers/fab/src/airflow/providers/fab/auth_manager/models/db.py b/providers/fab/src/airflow/providers/fab/auth_manager/models/db.py
index ce0efef55a1..5e2c5397745 100644
--- a/providers/fab/src/airflow/providers/fab/auth_manager/models/db.py
+++ b/providers/fab/src/airflow/providers/fab/auth_manager/models/db.py
@@ -31,6 +31,20 @@ _REVISION_HEADS_MAP: dict[str, str] = {
 }
 
 
+def _get_flask_db(sql_database_uri):
+    from flask import Flask
+    from flask_sqlalchemy import SQLAlchemy
+
+    from airflow.providers.fab.www.session import AirflowDatabaseSessionInterface
+
+    flask_app = Flask(__name__)
+    flask_app.config["SQLALCHEMY_DATABASE_URI"] = sql_database_uri
+    flask_app.config["SQLALCHEMY_TRACK_MODIFICATIONS"] = False
+    db = SQLAlchemy(flask_app)
+    AirflowDatabaseSessionInterface(app=flask_app, db=db, table="session", key_prefix="")
+    return db
+
+
 class FABDBManager(BaseDBManager):
     """Manages FAB database."""
 
@@ -40,6 +54,10 @@ class FABDBManager(BaseDBManager):
     alembic_file = (PACKAGE_DIR / "alembic.ini").as_posix()
     supports_table_dropping = True
 
+    def _create_db_from_orm(self):
+        super()._create_db_from_orm()
+        _get_flask_db(settings.SQL_ALCHEMY_CONN).create_all()
+
     def upgradedb(self, to_revision=None, from_revision=None, show_sql_only=False):
         """Upgrade the database."""
         if from_revision and not show_sql_only:
@@ -68,11 +86,6 @@ class FABDBManager(BaseDBManager):
             _offline_migration(command.upgrade, config, f"{from_revision}:{to_revision}")
             return  # only running sql; our job is done
 
-        if not self.get_current_revision():
-            # New DB; initialize and exit
-            self.initdb()
-            return
-
         command.upgrade(config, revision=to_revision or "heads")
 
     def downgrade(self, to_revision, from_revision=None, show_sql_only=False):
@@ -104,3 +117,7 @@ class FABDBManager(BaseDBManager):
         else:
             self.log.info("Applying FAB downgrade migrations.")
             command.downgrade(config, revision=to_revision, sql=show_sql_only)
+
+    def drop_tables(self, connection):
+        super().drop_tables(connection)
+        _get_flask_db(settings.SQL_ALCHEMY_CONN).drop_all()
diff --git a/providers/fab/tests/unit/fab/auth_manager/models/test_db.py b/providers/fab/tests/unit/fab/auth_manager/models/test_db.py
index f0920ebb151..50eaf9450e9 100644
--- a/providers/fab/tests/unit/fab/auth_manager/models/test_db.py
+++ b/providers/fab/tests/unit/fab/auth_manager/models/test_db.py
@@ -110,10 +110,12 @@ try:
 
         @mock.patch("airflow.utils.db_manager.inspect")
         @mock.patch.object(FABDBManager, "metadata")
-        def test_drop_tables(self, mock_metadata, mock_inspect, session):
+        @mock.patch("airflow.providers.fab.auth_manager.models.db._get_flask_db")
+        def test_drop_tables(self, mock__get_flask_db, mock_metadata, mock_inspect, session):
             manager = FABDBManager(session)
             connection = mock.MagicMock()
             manager.drop_tables(connection)
+            mock__get_flask_db.return_value.drop_all.assert_called_once_with()
             mock_metadata.drop_all.assert_called_once_with(connection)
 
         @pytest.mark.parametrize("skip_init", [True, False])

Reply via email to