This is an automated email from the ASF dual-hosted git repository.
beto pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/superset.git
The following commit(s) were added to refs/heads/master by this push:
new d535f3fe56 fix: make catalog migration lenient (#29549)
d535f3fe56 is described below
commit d535f3fe56bc9d3b8400ef806119121c7cc0af31
Author: Beto Dealmeida <[email protected]>
AuthorDate: Thu Jul 11 15:10:02 2024 -0400
fix: make catalog migration lenient (#29549)
---
superset/migrations/shared/catalogs.py | 117 ++++++++++++-------
.../unit_tests/migrations/shared/catalogs_test.py | 125 +++++++++++++++++++++
2 files changed, 204 insertions(+), 38 deletions(-)
diff --git a/superset/migrations/shared/catalogs.py
b/superset/migrations/shared/catalogs.py
index 6c03faec46..b4c8658cbe 100644
--- a/superset/migrations/shared/catalogs.py
+++ b/superset/migrations/shared/catalogs.py
@@ -23,6 +23,7 @@ from typing import Any, Type
import sqlalchemy as sa
from alembic import op
from sqlalchemy.ext.declarative import declarative_base
+from sqlalchemy.orm import Session
from superset import db, security_manager
from superset.daos.database import DatabaseDAO
@@ -86,6 +87,24 @@ class Slice(Base):
schema_perm = sa.Column(sa.String(1000))
def get_schemas(database_name: str) -> list[str]:
    """
    Read all known schemas of a database from the existing schema permissions.

    This is the fallback used when the analytical database cannot be reached
    during the migration. Schema permissions are named either
    ``[database].[schema]`` or ``[database].[catalog].[schema]``; only the
    schema part is returned.

    :param database_name: The Superset name of the database.
    :returns: Sorted, de-duplicated list of schema names.
    """
    # The database name is passed as a bound parameter rather than being
    # interpolated into the SQL string.
    query = sa.text(
        """
SELECT
    avm.name
FROM ab_view_menu avm
JOIN ab_permission_view apv ON avm.id = apv.view_menu_id
JOIN ab_permission ap ON apv.permission_id = ap.id
WHERE
    avm.name LIKE :pattern AND
    ap.name = 'schema_access'
        """
    )
    prefix = f"[{database_name}]."

    # ``op.execute`` returns ``None`` in Alembic, so the query has to run on the
    # migration connection itself to get rows back.
    rows = op.get_bind().execute(query, {"pattern": f"{prefix}%"})

    schemas = set()
    for row in rows:
        name = row[0]
        # ``_`` and ``%`` in the database name are LIKE wildcards, so the LIKE
        # can return permissions of other databases; re-check exactly here.
        if not name.startswith(prefix):
            continue
        # [PostgreSQL].[postgres].[public] => public. Splitting on "].[" instead
        # of "." keeps schema names that contain dots intact.
        schemas.add(name[len(prefix) :][1:-1].split("].[")[-1])

    return sorted(schemas)
+
+
def upgrade_catalog_perms(engines: set[str] | None = None) -> None:
"""
Update models when catalogs are introduced in a DB engine spec.
@@ -116,25 +135,7 @@ def upgrade_catalog_perms(engines: set[str] | None = None)
-> None:
)
add_pvms(session, {perm: ("catalog_access",)})
- # update schema_perms
- ssh_tunnel = DatabaseDAO.get_ssh_tunnel(database.id)
- for schema in database.get_all_schema_names(
- catalog=catalog,
- cache=False,
- ssh_tunnel=ssh_tunnel,
- ):
- perm = security_manager.get_schema_perm(
- database.database_name,
- None,
- schema,
- )
- existing_pvm =
session.query(ViewMenu).filter_by(name=perm).one_or_none()
- if existing_pvm:
- existing_pvm.name = security_manager.get_schema_perm(
- database.database_name,
- catalog,
- schema,
- )
+ upgrade_schema_perms(database, catalog, session)
# update existing models
models = [
@@ -166,6 +167,35 @@ def upgrade_catalog_perms(engines: set[str] | None = None)
-> None:
session.commit()
def upgrade_schema_perms(database: Database, catalog: str, session: Session) -> None:
    """
    Rename existing schema permissions to include the catalog.

    The schemas are listed from the analytical database. If that fails (for
    example because the database is offline while the migration runs), the
    schemas are inferred from the existing schema permissions instead, so the
    migration does not abort.

    :param database: The Superset database whose schema permissions are renamed.
    :param catalog: The catalog added to each schema permission name.
    :param session: Session used to load and update the ``ViewMenu`` rows.
    """
    import logging  # pylint: disable=import-outside-toplevel

    ssh_tunnel = DatabaseDAO.get_ssh_tunnel(database.id)
    try:
        schemas = database.get_all_schema_names(
            catalog=catalog,
            cache=False,
            ssh_tunnel=ssh_tunnel,
        )
    except Exception:  # pylint: disable=broad-except
        # Best effort: don't fail the migration, but leave a trace of why the
        # permission-based fallback was used.
        logging.getLogger(__name__).warning(
            "Unable to list schemas for database %s; falling back to the "
            "schemas found in the existing permissions",
            database.database_name,
            exc_info=True,
        )
        schemas = get_schemas(database.database_name)

    for schema in schemas:
        perm = security_manager.get_schema_perm(
            database.database_name,
            None,
            schema,
        )
        existing_pvm = session.query(ViewMenu).filter_by(name=perm).one_or_none()
        if existing_pvm:
            existing_pvm.name = security_manager.get_schema_perm(
                database.database_name,
                catalog,
                schema,
            )
+
+
def downgrade_catalog_perms(engines: set[str] | None = None) -> None:
"""
Reverse the process of `upgrade_catalog_perms`.
@@ -183,25 +213,7 @@ def downgrade_catalog_perms(engines: set[str] | None =
None) -> None:
if catalog is None:
continue
- # update schema_perms
- ssh_tunnel = DatabaseDAO.get_ssh_tunnel(database.id)
- for schema in database.get_all_schema_names(
- catalog=catalog,
- cache=False,
- ssh_tunnel=ssh_tunnel,
- ):
- perm = security_manager.get_schema_perm(
- database.database_name,
- catalog,
- schema,
- )
- existing_pvm =
session.query(ViewMenu).filter_by(name=perm).one_or_none()
- if existing_pvm:
- existing_pvm.name = security_manager.get_schema_perm(
- database.database_name,
- None,
- schema,
- )
+ downgrade_schema_perms(database, catalog, session)
# update existing models
models = [
@@ -231,3 +243,32 @@ def downgrade_catalog_perms(engines: set[str] | None =
None) -> None:
chart.schema_perm = schema_perm
session.commit()
+
+
def downgrade_schema_perms(database: Database, catalog: str, session: Session) -> None:
    """
    Rename existing schema permissions to omit the catalog.

    The schemas are listed from the analytical database. If that fails (for
    example because the database is offline while the migration runs), the
    schemas are inferred from the existing schema permissions instead, so the
    migration does not abort.

    :param database: The Superset database whose schema permissions are renamed.
    :param catalog: The catalog removed from each schema permission name.
    :param session: Session used to load and update the ``ViewMenu`` rows.
    """
    import logging  # pylint: disable=import-outside-toplevel

    ssh_tunnel = DatabaseDAO.get_ssh_tunnel(database.id)
    try:
        schemas = database.get_all_schema_names(
            catalog=catalog,
            cache=False,
            ssh_tunnel=ssh_tunnel,
        )
    except Exception:  # pylint: disable=broad-except
        # Best effort: don't fail the migration, but leave a trace of why the
        # permission-based fallback was used.
        logging.getLogger(__name__).warning(
            "Unable to list schemas for database %s; falling back to the "
            "schemas found in the existing permissions",
            database.database_name,
            exc_info=True,
        )
        schemas = get_schemas(database.database_name)

    for schema in schemas:
        perm = security_manager.get_schema_perm(
            database.database_name,
            catalog,
            schema,
        )
        existing_pvm = session.query(ViewMenu).filter_by(name=perm).one_or_none()
        if existing_pvm:
            existing_pvm.name = security_manager.get_schema_perm(
                database.database_name,
                None,
                schema,
            )
diff --git a/tests/unit_tests/migrations/shared/catalogs_test.py
b/tests/unit_tests/migrations/shared/catalogs_test.py
index ca715bec94..78ef522217 100644
--- a/tests/unit_tests/migrations/shared/catalogs_test.py
+++ b/tests/unit_tests/migrations/shared/catalogs_test.py
@@ -143,3 +143,128 @@ def test_upgrade_catalog_perms(mocker: MockerFixture,
session: Session) -> None:
("[my_db].[public]",),
("[my_db].[db]",),
]
+
+
def test_upgrade_catalog_perms_graceful(
    mocker: MockerFixture,
    session: Session,
) -> None:
    """
    Test the `upgrade_catalog_perms` function when it fails to connect to the DB.

    During the migration we try to connect to the analytical database to get the
    list of schemas. This should fail gracefully and not raise an exception, since
    the database could be offline, and the permissions can be generated later when
    the admin enables catalog browsing on the database (permissions are always
    synced on a DB update, see `UpdateDatabaseCommand`).

    The round trip through `downgrade_catalog_perms` is checked as well.
    """
    from superset.connectors.sqla.models import SqlaTable
    from superset.models.core import Database
    from superset.models.slice import Slice
    from superset.models.sql_lab import Query, SavedQuery, TableSchema, TabState

    engine = session.get_bind()
    Database.metadata.create_all(engine)

    # Route Alembic's `op` to the test session so the SQL issued by the
    # migration helpers runs against the in-memory database.
    mocker.patch("superset.migrations.shared.catalogs.op", session)
    db = mocker.patch("superset.migrations.shared.catalogs.db")
    db.Session.return_value = session

    # Simulate an unreachable analytical database.
    mocker.patch.object(
        Database,
        "get_all_schema_names",
        side_effect=Exception("Failed to connect to the database"),
    )

    database = Database(
        database_name="my_db",
        sqlalchemy_uri="postgresql://localhost/db",
    )
    dataset = SqlaTable(
        table_name="my_table",
        database=database,
        catalog=None,
        schema="public",
        schema_perm="[my_db].[public]",
    )
    session.add(dataset)
    session.commit()

    chart = Slice(
        slice_name="my_chart",
        datasource_type="table",
        datasource_id=dataset.id,
    )
    query = Query(
        client_id="foo",
        database=database,
        catalog=None,
        schema="public",
    )
    saved_query = SavedQuery(
        database=database,
        sql="SELECT * FROM public.t",
        catalog=None,
        schema="public",
    )
    tab_state = TabState(
        database=database,
        catalog=None,
        schema="public",
    )
    table_schema = TableSchema(
        database=database,
        catalog=None,
        schema="public",
    )
    session.add_all([chart, query, saved_query, tab_state, table_schema])
    session.commit()

    # before migration
    assert dataset.catalog is None
    assert query.catalog is None
    assert saved_query.catalog is None
    assert tab_state.catalog is None
    assert table_schema.catalog is None
    assert dataset.schema_perm == "[my_db].[public]"
    assert chart.schema_perm == "[my_db].[public]"
    assert session.query(ViewMenu.name).all() == [
        ("[my_db].(id:1)",),
        ("[my_db].[my_table](id:1)",),
        ("[my_db].[public]",),
    ]

    upgrade_catalog_perms()

    # after migration
    assert dataset.catalog == "db"
    assert query.catalog == "db"
    assert saved_query.catalog == "db"
    assert tab_state.catalog == "db"
    assert table_schema.catalog == "db"
    assert dataset.schema_perm == "[my_db].[db].[public]"
    assert chart.schema_perm == "[my_db].[db].[public]"
    assert session.query(ViewMenu.name).all() == [
        ("[my_db].(id:1)",),
        ("[my_db].[my_table](id:1)",),
        ("[my_db].[db].[public]",),
        ("[my_db].[db]",),
    ]

    downgrade_catalog_perms()

    # revert
    assert dataset.catalog is None
    assert query.catalog is None
    assert saved_query.catalog is None
    assert tab_state.catalog is None
    assert table_schema.catalog is None
    assert dataset.schema_perm == "[my_db].[public]"
    assert chart.schema_perm == "[my_db].[public]"
    assert session.query(ViewMenu.name).all() == [
        ("[my_db].(id:1)",),
        ("[my_db].[my_table](id:1)",),
        ("[my_db].[public]",),
        ("[my_db].[db]",),
    ]