This is an automated email from the ASF dual-hosted git repository.

beto pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/superset.git


The following commit(s) were added to refs/heads/master by this push:
     new d535f3fe56 fix: make catalog migration lenient (#29549)
d535f3fe56 is described below

commit d535f3fe56bc9d3b8400ef806119121c7cc0af31
Author: Beto Dealmeida <[email protected]>
AuthorDate: Thu Jul 11 15:10:02 2024 -0400

    fix: make catalog migration lenient (#29549)
---
 superset/migrations/shared/catalogs.py             | 117 ++++++++++++-------
 .../unit_tests/migrations/shared/catalogs_test.py  | 125 +++++++++++++++++++++
 2 files changed, 204 insertions(+), 38 deletions(-)

diff --git a/superset/migrations/shared/catalogs.py b/superset/migrations/shared/catalogs.py
index 6c03faec46..b4c8658cbe 100644
--- a/superset/migrations/shared/catalogs.py
+++ b/superset/migrations/shared/catalogs.py
@@ -23,6 +23,7 @@ from typing import Any, Type
 import sqlalchemy as sa
 from alembic import op
 from sqlalchemy.ext.declarative import declarative_base
+from sqlalchemy.orm import Session
 
 from superset import db, security_manager
 from superset.daos.database import DatabaseDAO
@@ -86,6 +87,24 @@ class Slice(Base):
     schema_perm = sa.Column(sa.String(1000))
 
 
+def get_schemas(database_name: str) -> list[str]:
+    """
+    Read all known schemas from the existing schema permissions.
+
+    Used as a fallback when the analytical database can't be reached during the
+    migration. Schema permission names look like `[database].[schema]` or
+    `[database].[catalog].[schema]`; the schema is the last bracketed component.
+
+    Rows are read through `op.get_bind()` because Alembic's `op.execute` does not
+    return a result, and the database name is passed as a bound parameter so
+    names containing quotes cannot break the SQL.
+    """
+    query = sa.text(
+        """
+SELECT
+    avm.name
+FROM ab_view_menu avm
+JOIN ab_permission_view apv ON avm.id = apv.view_menu_id
+JOIN ab_permission ap ON apv.permission_id = ap.id
+WHERE
+    avm.name LIKE :pattern AND
+    ap.name = 'schema_access'
+        """
+    )
+    prefix = f"[{database_name}]."
+    rows = op.get_bind().execute(query, {"pattern": f"{prefix}%"})
+
+    schemas: set[str] = set()
+    for (name,) in rows:
+        # `_` and `%` in the database name act as LIKE wildcards, so re-check the
+        # prefix exactly to avoid picking up permissions of other databases
+        if not name.startswith(prefix):
+            continue
+        # [PostgreSQL].[postgres].[public] => public; split on "].[" rather than
+        # "." so schema names that contain dots are preserved intact
+        schemas.add(name[1:-1].split("].[")[-1])
+
+    return sorted(schemas)
+
+
 def upgrade_catalog_perms(engines: set[str] | None = None) -> None:
     """
     Update models when catalogs are introduced in a DB engine spec.
@@ -116,25 +135,7 @@ def upgrade_catalog_perms(engines: set[str] | None = None) -> None:
         )
         add_pvms(session, {perm: ("catalog_access",)})
 
-        # update schema_perms
-        ssh_tunnel = DatabaseDAO.get_ssh_tunnel(database.id)
-        for schema in database.get_all_schema_names(
-            catalog=catalog,
-            cache=False,
-            ssh_tunnel=ssh_tunnel,
-        ):
-            perm = security_manager.get_schema_perm(
-                database.database_name,
-                None,
-                schema,
-            )
-            existing_pvm = 
session.query(ViewMenu).filter_by(name=perm).one_or_none()
-            if existing_pvm:
-                existing_pvm.name = security_manager.get_schema_perm(
-                    database.database_name,
-                    catalog,
-                    schema,
-                )
+        upgrade_schema_perms(database, catalog, session)
 
         # update existing models
         models = [
@@ -166,6 +167,35 @@ def upgrade_catalog_perms(engines: set[str] | None = None) -> None:
     session.commit()
 
 
+def upgrade_schema_perms(
+    database: Database,
+    catalog: str,
+    session: Session,
+) -> None:
+    """
+    Rename the schema permissions of `database` so they include `catalog`.
+
+    Schemas are listed from the live database when possible. If that call
+    raises for any reason, the schema names are recovered from the existing
+    `schema_access` permissions via `get_schemas` instead.
+    """
+    db_name = database.database_name
+    tunnel = DatabaseDAO.get_ssh_tunnel(database.id)
+    try:
+        schema_names = database.get_all_schema_names(
+            catalog=catalog,
+            cache=False,
+            ssh_tunnel=tunnel,
+        )
+    except Exception:  # pylint: disable=broad-except
+        # the analytical database may be unreachable; use known permissions
+        schema_names = get_schemas(db_name)
+
+    for schema_name in schema_names:
+        old_name = security_manager.get_schema_perm(db_name, None, schema_name)
+        view_menu = session.query(ViewMenu).filter_by(name=old_name).one_or_none()
+        if not view_menu:
+            continue
+        view_menu.name = security_manager.get_schema_perm(
+            db_name,
+            catalog,
+            schema_name,
+        )
+
+
 def downgrade_catalog_perms(engines: set[str] | None = None) -> None:
     """
     Reverse the process of `upgrade_catalog_perms`.
@@ -183,25 +213,7 @@ def downgrade_catalog_perms(engines: set[str] | None = None) -> None:
         if catalog is None:
             continue
 
-        # update schema_perms
-        ssh_tunnel = DatabaseDAO.get_ssh_tunnel(database.id)
-        for schema in database.get_all_schema_names(
-            catalog=catalog,
-            cache=False,
-            ssh_tunnel=ssh_tunnel,
-        ):
-            perm = security_manager.get_schema_perm(
-                database.database_name,
-                catalog,
-                schema,
-            )
-            existing_pvm = 
session.query(ViewMenu).filter_by(name=perm).one_or_none()
-            if existing_pvm:
-                existing_pvm.name = security_manager.get_schema_perm(
-                    database.database_name,
-                    None,
-                    schema,
-                )
+        downgrade_schema_perms(database, catalog, session)
 
         # update existing models
         models = [
@@ -231,3 +243,32 @@ def downgrade_catalog_perms(engines: set[str] | None = None) -> None:
                 chart.schema_perm = schema_perm
 
     session.commit()
+
+
+def downgrade_schema_perms(
+    database: Database,
+    catalog: str,
+    session: Session,
+) -> None:
+    """
+    Rename the schema permissions of `database` so they no longer include `catalog`.
+
+    Schemas are listed from the live database when possible. If that call
+    raises for any reason, the schema names are recovered from the existing
+    `schema_access` permissions via `get_schemas` instead.
+    """
+    db_name = database.database_name
+    tunnel = DatabaseDAO.get_ssh_tunnel(database.id)
+    try:
+        schema_names = database.get_all_schema_names(
+            catalog=catalog,
+            cache=False,
+            ssh_tunnel=tunnel,
+        )
+    except Exception:  # pylint: disable=broad-except
+        # the analytical database may be unreachable; use known permissions
+        schema_names = get_schemas(db_name)
+
+    for schema_name in schema_names:
+        old_name = security_manager.get_schema_perm(db_name, catalog, schema_name)
+        view_menu = session.query(ViewMenu).filter_by(name=old_name).one_or_none()
+        if not view_menu:
+            continue
+        view_menu.name = security_manager.get_schema_perm(
+            db_name,
+            None,
+            schema_name,
+        )
diff --git a/tests/unit_tests/migrations/shared/catalogs_test.py b/tests/unit_tests/migrations/shared/catalogs_test.py
index ca715bec94..78ef522217 100644
--- a/tests/unit_tests/migrations/shared/catalogs_test.py
+++ b/tests/unit_tests/migrations/shared/catalogs_test.py
@@ -143,3 +143,128 @@ def test_upgrade_catalog_perms(mocker: MockerFixture, session: Session) -> None:
         ("[my_db].[public]",),
         ("[my_db].[db]",),
     ]
+
+
+def test_upgrade_catalog_perms_graceful(
+    mocker: MockerFixture,
+    session: Session,
+) -> None:
+    """
+    Test the `upgrade_catalog_perms` function when it fails to connect to the DB.
+
+    During the migration we try to connect to the analytical database to get the
+    list of schemas. This should fail gracefully and not raise an exception, since
+    the database could be offline, and the permissions can be generated later when
+    the admin enables catalog browsing on the database (permissions are always
+    synced on a DB update, see `UpdateDatabaseCommand`).
+
+    When the connection fails, the schemas are recovered from the existing
+    `schema_access` permissions, so the schema permission must still be renamed.
+    """
+    from superset.connectors.sqla.models import SqlaTable
+    from superset.models.core import Database
+    from superset.models.slice import Slice
+    from superset.models.sql_lab import Query, SavedQuery, TableSchema, TabState
+
+    engine = session.get_bind()
+    Database.metadata.create_all(engine)
+
+    db = mocker.patch("superset.migrations.shared.catalogs.db")
+    db.Session.return_value = session
+    # the fallback in `get_schemas` queries through `op`; point it at the test
+    # session so that query runs against the test database
+    mocker.patch("superset.migrations.shared.catalogs.op", session)
+
+    mocker.patch.object(
+        Database,
+        "get_all_schema_names",
+        side_effect=Exception("Failed to connect to the database"),
+    )
+
+    database = Database(
+        database_name="my_db",
+        sqlalchemy_uri="postgresql://localhost/db",
+    )
+    dataset = SqlaTable(
+        table_name="my_table",
+        database=database,
+        catalog=None,
+        schema="public",
+        schema_perm="[my_db].[public]",
+    )
+    session.add(dataset)
+    session.commit()
+
+    chart = Slice(
+        slice_name="my_chart",
+        datasource_type="table",
+        datasource_id=dataset.id,
+    )
+    query = Query(
+        client_id="foo",
+        database=database,
+        catalog=None,
+        schema="public",
+    )
+    saved_query = SavedQuery(
+        database=database,
+        sql="SELECT * FROM public.t",
+        catalog=None,
+        schema="public",
+    )
+    tab_state = TabState(
+        database=database,
+        catalog=None,
+        schema="public",
+    )
+    table_schema = TableSchema(
+        database=database,
+        catalog=None,
+        schema="public",
+    )
+    session.add_all([chart, query, saved_query, tab_state, table_schema])
+    session.commit()
+
+    # before migration
+    assert dataset.catalog is None
+    assert query.catalog is None
+    assert saved_query.catalog is None
+    assert tab_state.catalog is None
+    assert table_schema.catalog is None
+    assert dataset.schema_perm == "[my_db].[public]"
+    assert chart.schema_perm == "[my_db].[public]"
+    assert session.query(ViewMenu.name).all() == [
+        ("[my_db].(id:1)",),
+        ("[my_db].[my_table](id:1)",),
+        ("[my_db].[public]",),
+    ]
+
+    upgrade_catalog_perms()
+
+    # after migration
+    assert dataset.catalog == "db"
+    assert query.catalog == "db"
+    assert saved_query.catalog == "db"
+    assert tab_state.catalog == "db"
+    assert table_schema.catalog == "db"
+    assert dataset.schema_perm == "[my_db].[db].[public]"
+    assert chart.schema_perm == "[my_db].[db].[public]"
+    assert session.query(ViewMenu.name).all() == [
+        ("[my_db].(id:1)",),
+        ("[my_db].[my_table](id:1)",),
+        ("[my_db].[db].[public]",),
+        ("[my_db].[db]",),
+    ]
+
+    downgrade_catalog_perms()
+
+    # revert
+    assert dataset.catalog is None
+    assert query.catalog is None
+    assert saved_query.catalog is None
+    assert tab_state.catalog is None
+    assert table_schema.catalog is None
+    assert dataset.schema_perm == "[my_db].[public]"
+    assert chart.schema_perm == "[my_db].[public]"
+    assert session.query(ViewMenu.name).all() == [
+        ("[my_db].(id:1)",),
+        ("[my_db].[my_table](id:1)",),
+        ("[my_db].[public]",),
+        ("[my_db].[db]",),
+    ]

Reply via email to