This is an automated email from the ASF dual-hosted git repository. jedcunningham pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push: new 2500dcf20d2 Move FAB session table creation to FAB provider (#47969) 2500dcf20d2 is described below commit 2500dcf20d2782d16da53ee857c0aab21bfdfbf2 Author: Jed Cunningham <66968678+jedcunning...@users.noreply.github.com> AuthorDate: Wed Mar 19 15:41:37 2025 -0600 Move FAB session table creation to FAB provider (#47969) We need to create the `session` table in the provider db manager, not in the core db utils. Co-authored-by: vincbeck <vincb...@amazon.com> --- airflow/utils/db.py | 21 ----------------- .../providers/fab/auth_manager/models/db.py | 27 ++++++++++++++++++---- .../tests/unit/fab/auth_manager/models/test_db.py | 4 +++- 3 files changed, 25 insertions(+), 27 deletions(-) diff --git a/airflow/utils/db.py b/airflow/utils/db.py index 70e6955c39b..7055a1c7601 100644 --- a/airflow/utils/db.py +++ b/airflow/utils/db.py @@ -729,34 +729,15 @@ def create_default_connections(session: Session = NEW_SESSION): ) -def _get_flask_db(sql_database_uri): - from flask import Flask - from flask_sqlalchemy import SQLAlchemy - - from airflow.providers.fab.www.session import AirflowDatabaseSessionInterface - - flask_app = Flask(__name__) - flask_app.config["SQLALCHEMY_DATABASE_URI"] = sql_database_uri - flask_app.config["SQLALCHEMY_TRACK_MODIFICATIONS"] = False - db = SQLAlchemy(flask_app) - AirflowDatabaseSessionInterface(app=flask_app, db=db, table="session", key_prefix="") - return db - - def _create_db_from_orm(session): log.info("Creating Airflow database tables from the ORM") from alembic import command from airflow.models.base import Base - def _create_flask_session_tbl(sql_database_uri): - db = _get_flask_db(sql_database_uri) - db.create_all() - with create_global_lock(session=session, lock=DBLocks.MIGRATIONS): engine = session.get_bind().engine Base.metadata.create_all(engine) - _create_flask_session_tbl(engine.url) # stamp the migration head config = _get_alembic_config() command.stamp(config, "head") @@ -1254,8 +1235,6 @@ def drop_airflow_models(connection): from airflow.models.base import Base Base.metadata.drop_all(connection) - db = _get_flask_db(connection.engine.url) - db.drop_all() # alembic adds significant import time, so we import it lazily from alembic.migration import MigrationContext diff --git a/providers/fab/src/airflow/providers/fab/auth_manager/models/db.py b/providers/fab/src/airflow/providers/fab/auth_manager/models/db.py index ce0efef55a1..5e2c5397745 100644 --- a/providers/fab/src/airflow/providers/fab/auth_manager/models/db.py +++ b/providers/fab/src/airflow/providers/fab/auth_manager/models/db.py @@ -31,6 +31,20 @@ _REVISION_HEADS_MAP: dict[str, str] = { } +def _get_flask_db(sql_database_uri): + from flask import Flask + from flask_sqlalchemy import SQLAlchemy + + from airflow.providers.fab.www.session import AirflowDatabaseSessionInterface + + flask_app = Flask(__name__) + flask_app.config["SQLALCHEMY_DATABASE_URI"] = sql_database_uri + flask_app.config["SQLALCHEMY_TRACK_MODIFICATIONS"] = False + db = SQLAlchemy(flask_app) + AirflowDatabaseSessionInterface(app=flask_app, db=db, table="session", key_prefix="") + return db + + class FABDBManager(BaseDBManager): """Manages FAB database.""" @@ -40,6 +54,10 @@ class FABDBManager(BaseDBManager): alembic_file = (PACKAGE_DIR / "alembic.ini").as_posix() supports_table_dropping = True + def _create_db_from_orm(self): + super()._create_db_from_orm() + _get_flask_db(settings.SQL_ALCHEMY_CONN).create_all() + def upgradedb(self, to_revision=None, from_revision=None, show_sql_only=False): """Upgrade the database.""" if from_revision and not show_sql_only: @@ -68,11 +86,6 @@ class FABDBManager(BaseDBManager): _offline_migration(command.upgrade, config, f"{from_revision}:{to_revision}") return # only running sql; our job is done - if not self.get_current_revision(): - # New DB; initialize and exit - self.initdb() - return - command.upgrade(config, revision=to_revision or "heads") def downgrade(self, to_revision, from_revision=None, show_sql_only=False): @@ -104,3 +117,7 @@ class FABDBManager(BaseDBManager): else: self.log.info("Applying FAB downgrade migrations.") command.downgrade(config, revision=to_revision, sql=show_sql_only) + + def drop_tables(self, connection): + super().drop_tables(connection) + _get_flask_db(settings.SQL_ALCHEMY_CONN).drop_all() diff --git a/providers/fab/tests/unit/fab/auth_manager/models/test_db.py b/providers/fab/tests/unit/fab/auth_manager/models/test_db.py index f0920ebb151..50eaf9450e9 100644 --- a/providers/fab/tests/unit/fab/auth_manager/models/test_db.py +++ b/providers/fab/tests/unit/fab/auth_manager/models/test_db.py @@ -110,10 +110,12 @@ try: @mock.patch("airflow.utils.db_manager.inspect") @mock.patch.object(FABDBManager, "metadata") - def test_drop_tables(self, mock_metadata, mock_inspect, session): + @mock.patch("airflow.providers.fab.auth_manager.models.db._get_flask_db") + def test_drop_tables(self, mock__get_flask_db, mock_metadata, mock_inspect, session): manager = FABDBManager(session) connection = mock.MagicMock() manager.drop_tables(connection) + mock__get_flask_db.return_value.drop_all.assert_called_once_with() mock_metadata.drop_all.assert_called_once_with(connection) @pytest.mark.parametrize("skip_init", [True, False])