This is an automated email from the ASF dual-hosted git repository.

beto pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/superset.git


The following commit(s) were added to refs/heads/master by this push:
     new f29e1e4c29 feat: catalog support for Databricks native (#28394)
f29e1e4c29 is described below

commit f29e1e4c29a46f7d607cfa59adb8bb21d107091c
Author: Beto Dealmeida <[email protected]>
AuthorDate: Thu May 9 17:41:15 2024 -0400

    feat: catalog support for Databricks native (#28394)
---
 .../src/components/DatabaseSelector/index.tsx      |  18 ++-
 superset/db_engine_specs/databricks.py             |  65 +++++++-
 superset/migrations/shared/catalogs.py             | 163 ++++++++++++++++++---
 ...0-52_58d051681a3b_add_catalog_perm_to_tables.py |   8 +-
 ...3_4081be5b6b74_enable_catalog_in_databricks.py} |  33 ++---
 .../utils/pandas_postprocessing/contribution.py    |   1 +
 tests/integration_tests/datasets/api_tests.py      |   2 +-
 .../unit_tests/db_engine_specs/test_databricks.py  |  19 +++
 tests/unit_tests/migrations/shared/__init__.py     |  16 ++
 .../unit_tests/migrations/shared/catalogs_test.py  | 145 ++++++++++++++++++
 10 files changed, 412 insertions(+), 58 deletions(-)

diff --git a/superset-frontend/src/components/DatabaseSelector/index.tsx b/superset-frontend/src/components/DatabaseSelector/index.tsx
index 6eb1340d5b..23767ba9f7 100644
--- a/superset-frontend/src/components/DatabaseSelector/index.tsx
+++ b/superset-frontend/src/components/DatabaseSelector/index.tsx
@@ -143,7 +143,7 @@ export default function DatabaseSelector({
   const showCatalogSelector = !!db?.allow_multi_catalog;
   const [currentDb, setCurrentDb] = useState<DatabaseValue | undefined>();
   const [currentCatalog, setCurrentCatalog] = useState<
-    CatalogOption | undefined
+    CatalogOption | null | undefined
   >(catalog ? { label: catalog, value: catalog, title: catalog } : undefined);
   const catalogRef = useRef(catalog);
   catalogRef.current = catalog;
@@ -265,7 +265,7 @@ export default function DatabaseSelector({
 
   const schemaOptions = schemaData || EMPTY_SCHEMA_OPTIONS;
 
-  function changeCatalog(catalog: CatalogOption | undefined) {
+  function changeCatalog(catalog: CatalogOption | null | undefined) {
     setCurrentCatalog(catalog);
     setCurrentSchema(undefined);
     if (onCatalogChange && catalog?.value !== catalogRef.current) {
@@ -280,7 +280,9 @@ export default function DatabaseSelector({
   } = useCatalogs({
     dbId: currentDb?.value,
     onSuccess: (catalogs, isFetched) => {
-      if (catalogs.length === 1) {
+      if (!showCatalogSelector) {
+        changeCatalog(null);
+      } else if (catalogs.length === 1) {
         changeCatalog(catalogs[0]);
       } else if (
         !catalogs.find(
@@ -290,11 +292,15 @@ export default function DatabaseSelector({
         changeCatalog(undefined);
       }
 
-      if (isFetched) {
+      if (showCatalogSelector && isFetched) {
         addSuccessToast('List refreshed');
       }
     },
-    onError: () => handleError(t('There was an error loading the catalogs')),
+    onError: () => {
+      if (showCatalogSelector) {
+        handleError(t('There was an error loading the catalogs'));
+      }
+    },
   });
 
   const catalogOptions = catalogData || EMPTY_CATALOG_OPTIONS;
@@ -365,7 +371,7 @@ export default function DatabaseSelector({
         onChange={item => changeCatalog(item as CatalogOption)}
         options={catalogOptions}
         showSearch
-        value={currentCatalog}
+        value={currentCatalog || undefined}
       />,
       refreshIcon,
     );
diff --git a/superset/db_engine_specs/databricks.py b/superset/db_engine_specs/databricks.py
index 6fc753c00e..3f72931626 100644
--- a/superset/db_engine_specs/databricks.py
+++ b/superset/db_engine_specs/databricks.py
@@ -39,7 +39,6 @@ if TYPE_CHECKING:
     from superset.models.core import Database
 
 
-#
 class DatabricksBaseSchema(Schema):
     """
     Fields that are required for both Databricks drivers that uses a
@@ -371,6 +370,8 @@ class DatabricksNativeEngineSpec(DatabricksDynamicBaseEngineSpec):
         "extra",
     }
 
+    supports_dynamic_schema = supports_catalog = supports_dynamic_catalog = 
True
+
     @classmethod
     def build_sqlalchemy_uri(  # type: ignore
         cls, parameters: DatabricksNativeParametersType, *_
@@ -428,6 +429,35 @@ class DatabricksNativeEngineSpec(DatabricksDynamicBaseEngineSpec):
         spec.components.schema(cls.__name__, schema=cls.properties_schema)
         return spec.to_dict()["components"]["schemas"][cls.__name__]
 
+    @classmethod
+    def get_default_catalog(
+        cls,
+        database: Database,
+    ) -> str | None:
+        with database.get_inspector() as inspector:
+            return inspector.bind.execute("SELECT current_catalog()").scalar()
+
+    @classmethod
+    def get_prequeries(
+        cls,
+        catalog: str | None = None,
+        schema: str | None = None,
+    ) -> list[str]:
+        prequeries = []
+        if catalog:
+            prequeries.append(f"USE CATALOG {catalog}")
+        if schema:
+            prequeries.append(f"USE SCHEMA {schema}")
+        return prequeries
+
+    @classmethod
+    def get_catalog_names(
+        cls,
+        database: Database,
+        inspector: Inspector,
+    ) -> set[str]:
+        return {catalog for (catalog,) in inspector.bind.execute("SHOW 
CATALOGS")}
+
 
 class DatabricksPythonConnectorEngineSpec(DatabricksDynamicBaseEngineSpec):
     engine = "databricks"
@@ -455,6 +485,8 @@ class DatabricksPythonConnectorEngineSpec(DatabricksDynamicBaseEngineSpec):
         "http_path_field",
     }
 
+    supports_dynamic_schema = supports_catalog = supports_dynamic_catalog = 
True
+
     @classmethod
     def build_sqlalchemy_uri(  # type: ignore
         cls, parameters: DatabricksPythonConnectorParametersType, *_
@@ -502,3 +534,34 @@ class DatabricksPythonConnectorEngineSpec(DatabricksDynamicBaseEngineSpec):
             "default_schema": query["schema"],
             "encryption": encryption,
         }
+
+    @classmethod
+    def get_default_catalog(
+        cls,
+        database: Database,
+    ) -> str | None:
+        return database.url_object.query.get("catalog")
+
+    @classmethod
+    def get_catalog_names(
+        cls,
+        database: Database,
+        inspector: Inspector,
+    ) -> set[str]:
+        return {catalog for (catalog,) in inspector.bind.execute("SHOW 
CATALOGS")}
+
+    @classmethod
+    def adjust_engine_params(
+        cls,
+        uri: URL,
+        connect_args: dict[str, Any],
+        catalog: str | None = None,
+        schema: str | None = None,
+    ) -> tuple[URL, dict[str, Any]]:
+        if catalog:
+            uri = uri.update_query_dict({"catalog": catalog})
+
+        if schema:
+            uri = uri.update_query_dict({"schema": schema})
+
+        return uri, connect_args
diff --git a/superset/migrations/shared/catalogs.py b/superset/migrations/shared/catalogs.py
index 5d01ecfbfb..4b13d60435 100644
--- a/superset/migrations/shared/catalogs.py
+++ b/superset/migrations/shared/catalogs.py
@@ -18,23 +18,84 @@
 from __future__ import annotations
 
 import logging
+from typing import Any, Type
 
+import sqlalchemy as sa
 from alembic import op
+from sqlalchemy.ext.declarative import declarative_base
 
 from superset import db, security_manager
 from superset.daos.database import DatabaseDAO
+from superset.migrations.shared.security_converge import add_pvms, ViewMenu
 from superset.models.core import Database
 
 logger = logging.getLogger(__name__)
 
 
-def upgrade_schema_perms(engine: str | None = None) -> None:
+Base: Type[Any] = declarative_base()
+
+
+class SqlaTable(Base):
+    __tablename__ = "tables"
+
+    id = sa.Column(sa.Integer, primary_key=True)
+    database_id = sa.Column(sa.Integer, nullable=False)
+    schema_perm = sa.Column(sa.String(1000))
+    schema = sa.Column(sa.String(255))
+    catalog = sa.Column(sa.String(256), nullable=True, default=None)
+
+
+class Query(Base):
+    __tablename__ = "query"
+
+    id = sa.Column(sa.Integer, primary_key=True)
+    database_id = sa.Column(sa.Integer, nullable=False)
+    catalog = sa.Column(sa.String(256), nullable=True, default=None)
+
+
+class SavedQuery(Base):
+    __tablename__ = "saved_query"
+
+    id = sa.Column(sa.Integer, primary_key=True)
+    db_id = sa.Column(sa.Integer, nullable=False)
+    catalog = sa.Column(sa.String(256), nullable=True, default=None)
+
+
+class TabState(Base):
+    __tablename__ = "tab_state"
+
+    id = sa.Column(sa.Integer, primary_key=True)
+    database_id = sa.Column(sa.Integer, nullable=False)
+    catalog = sa.Column(sa.String(256), nullable=True, default=None)
+
+
+class TableSchema(Base):
+    __tablename__ = "table_schema"
+
+    id = sa.Column(sa.Integer, primary_key=True)
+    database_id = sa.Column(sa.Integer, nullable=False)
+    catalog = sa.Column(sa.String(256), nullable=True, default=None)
+
+
+class Slice(Base):
+    __tablename__ = "slices"
+
+    id = sa.Column(sa.Integer, primary_key=True)
+    datasource_id = sa.Column(sa.Integer)
+    datasource_type = sa.Column(sa.String(200))
+    schema_perm = sa.Column(sa.String(1000))
+
+
+def upgrade_catalog_perms(engine: str | None = None) -> None:
     """
-    Update schema permissions to include the catalog part.
+    Update models when catalogs are introduced in a DB engine spec.
+
+    When an existing DB engine spec starts to support catalogs we need to:
+
+        - Add a `catalog_access` permission for each catalog.
+        - Populate the `catalog` field with the default catalog for each 
related model.
+        - Update `schema_perm` to include the default catalog.
 
-    Before SIP-95 schema permissions were stored in the format 
`[db].[schema]`. With the
-    introduction of catalogs, any existing permissions need to be renamed to 
include the
-    catalog: `[db].[catalog].[schema]`.
     """
     bind = op.get_bind()
     session = db.Session(bind=bind)
@@ -46,6 +107,16 @@ def upgrade_schema_perms(engine: str | None = None) -> None:
             continue
 
         catalog = database.get_default_catalog()
+        if catalog is None:
+            continue
+
+        perm = security_manager.get_catalog_perm(
+            database.database_name,
+            catalog,
+        )
+        add_pvms(session, {perm: ("catalog_access",)})
+
+        # update schema_perms
         ssh_tunnel = DatabaseDAO.get_ssh_tunnel(database.id)
         for schema in database.get_all_schema_names(
             catalog=catalog,
@@ -57,29 +128,47 @@ def upgrade_schema_perms(engine: str | None = None) -> None:
                 None,
                 schema,
             )
-            existing_pvm = security_manager.find_permission_view_menu(
-                "schema_access",
-                perm,
-            )
+            existing_pvm = 
session.query(ViewMenu).filter_by(name=perm).one_or_none()
             if existing_pvm:
-                existing_pvm.view_menu.name = security_manager.get_schema_perm(
+                existing_pvm.name = security_manager.get_schema_perm(
                     database.database_name,
                     catalog,
                     schema,
                 )
 
+        # update existing models
+        models = [
+            (Query, "database_id"),
+            (SavedQuery, "db_id"),
+            (TabState, "database_id"),
+            (TableSchema, "database_id"),
+            (SqlaTable, "database_id"),
+        ]
+        for model, column in models:
+            for instance in session.query(model).filter(
+                getattr(model, column) == database.id
+            ):
+                instance.catalog = catalog
+
+        for table in 
session.query(SqlaTable).filter_by(database_id=database.id):
+            schema_perm = security_manager.get_schema_perm(
+                database.database_name,
+                catalog,
+                table.schema,
+            )
+            table.schema_perm = schema_perm
+            for chart in session.query(Slice).filter_by(
+                datasource_id=table.id,
+                datasource_type="table",
+            ):
+                chart.schema_perm = schema_perm
+
     session.commit()
 
 
-def downgrade_schema_perms(engine: str | None = None) -> None:
+def downgrade_catalog_perms(engine: str | None = None) -> None:
     """
-    Update schema permissions to not have the catalog part.
-
-    Before SIP-95 schema permissions were stored in the format 
`[db].[schema]`. With the
-    introduction of catalogs, any existing permissions need to be renamed to 
include the
-    catalog: `[db].[catalog].[schema]`.
-
-    This helped function reverts the process.
+    Reverse the process of `upgrade_catalog_perms`.
     """
     bind = op.get_bind()
     session = db.Session(bind=bind)
@@ -91,6 +180,10 @@ def downgrade_schema_perms(engine: str | None = None) -> None:
             continue
 
         catalog = database.get_default_catalog()
+        if catalog is None:
+            continue
+
+        # update schema_perms
         ssh_tunnel = DatabaseDAO.get_ssh_tunnel(database.id)
         for schema in database.get_all_schema_names(
             catalog=catalog,
@@ -102,15 +195,39 @@ def downgrade_schema_perms(engine: str | None = None) -> None:
                 catalog,
                 schema,
             )
-            existing_pvm = security_manager.find_permission_view_menu(
-                "schema_access",
-                perm,
-            )
+            existing_pvm = 
session.query(ViewMenu).filter_by(name=perm).one_or_none()
             if existing_pvm:
-                existing_pvm.view_menu.name = security_manager.get_schema_perm(
+                existing_pvm.name = security_manager.get_schema_perm(
                     database.database_name,
                     None,
                     schema,
                 )
 
+        # update existing models
+        models = [
+            (Query, "database_id"),
+            (SavedQuery, "db_id"),
+            (TabState, "database_id"),
+            (TableSchema, "database_id"),
+            (SqlaTable, "database_id"),
+        ]
+        for model, column in models:
+            for instance in session.query(model).filter(
+                getattr(model, column) == database.id
+            ):
+                instance.catalog = None
+
+        for table in 
session.query(SqlaTable).filter_by(database_id=database.id):
+            schema_perm = security_manager.get_schema_perm(
+                database.database_name,
+                None,
+                table.schema,
+            )
+            table.schema_perm = schema_perm
+            for chart in session.query(Slice).filter_by(
+                datasource_id=table.id,
+                datasource_type="table",
+            ):
+                chart.schema_perm = schema_perm
+
     session.commit()
diff --git a/superset/migrations/versions/2024-05-01_10-52_58d051681a3b_add_catalog_perm_to_tables.py b/superset/migrations/versions/2024-05-01_10-52_58d051681a3b_add_catalog_perm_to_tables.py
index 17b33e1d0a..f8f7824744 100644
--- a/superset/migrations/versions/2024-05-01_10-52_58d051681a3b_add_catalog_perm_to_tables.py
+++ b/superset/migrations/versions/2024-05-01_10-52_58d051681a3b_add_catalog_perm_to_tables.py
@@ -26,8 +26,8 @@ import sqlalchemy as sa
 from alembic import op
 
 from superset.migrations.shared.catalogs import (
-    downgrade_schema_perms,
-    upgrade_schema_perms,
+    downgrade_catalog_perms,
+    upgrade_catalog_perms,
 )
 
 # revision identifiers, used by Alembic.
@@ -44,10 +44,10 @@ def upgrade():
         "slices",
         sa.Column("catalog_perm", sa.String(length=1000), nullable=True),
     )
-    upgrade_schema_perms(engine="postgresql")
+    upgrade_catalog_perms(engine="postgresql")
 
 
 def downgrade():
     op.drop_column("slices", "catalog_perm")
     op.drop_column("tables", "catalog_perm")
-    downgrade_schema_perms(engine="postgresql")
+    downgrade_catalog_perms(engine="postgresql")
diff --git a/superset/migrations/versions/2024-05-01_10-52_58d051681a3b_add_catalog_perm_to_tables.py b/superset/migrations/versions/2024-05-08_19-33_4081be5b6b74_enable_catalog_in_databricks.py
similarity index 57%
copy from superset/migrations/versions/2024-05-01_10-52_58d051681a3b_add_catalog_perm_to_tables.py
copy to superset/migrations/versions/2024-05-08_19-33_4081be5b6b74_enable_catalog_in_databricks.py
index 17b33e1d0a..f39d6fa0d6 100644
--- a/superset/migrations/versions/2024-05-01_10-52_58d051681a3b_add_catalog_perm_to_tables.py
+++ b/superset/migrations/versions/2024-05-08_19-33_4081be5b6b74_enable_catalog_in_databricks.py
@@ -14,40 +14,27 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""Add catalog_perm to tables
+"""Enable catalog in Databricks
 
-Revision ID: 58d051681a3b
-Revises: 4a33124c18ad
-Create Date: 2024-05-01 10:52:31.458433
+Revision ID: 4081be5b6b74
+Revises: 645bb206f96c
+Create Date: 2024-05-08 19:33:18.311411
 
 """
 
-import sqlalchemy as sa
-from alembic import op
-
 from superset.migrations.shared.catalogs import (
-    downgrade_schema_perms,
-    upgrade_schema_perms,
+    downgrade_catalog_perms,
+    upgrade_catalog_perms,
 )
 
 # revision identifiers, used by Alembic.
-revision = "58d051681a3b"
-down_revision = "4a33124c18ad"
+revision = "4081be5b6b74"
+down_revision = "645bb206f96c"
 
 
 def upgrade():
-    op.add_column(
-        "tables",
-        sa.Column("catalog_perm", sa.String(length=1000), nullable=True),
-    )
-    op.add_column(
-        "slices",
-        sa.Column("catalog_perm", sa.String(length=1000), nullable=True),
-    )
-    upgrade_schema_perms(engine="postgresql")
+    upgrade_catalog_perms(engine="databricks")
 
 
 def downgrade():
-    op.drop_column("slices", "catalog_perm")
-    op.drop_column("tables", "catalog_perm")
-    downgrade_schema_perms(engine="postgresql")
+    downgrade_catalog_perms(engine="databricks")
diff --git a/superset/utils/pandas_postprocessing/contribution.py b/superset/utils/pandas_postprocessing/contribution.py
index 46144ec019..ad8b070869 100644
--- a/superset/utils/pandas_postprocessing/contribution.py
+++ b/superset/utils/pandas_postprocessing/contribution.py
@@ -14,6 +14,7 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+
 from __future__ import annotations
 
 from decimal import Decimal
diff --git a/tests/integration_tests/datasets/api_tests.py b/tests/integration_tests/datasets/api_tests.py
index 543a834793..6cc3cc8828 100644
--- a/tests/integration_tests/datasets/api_tests.py
+++ b/tests/integration_tests/datasets/api_tests.py
@@ -368,8 +368,8 @@ class TestDatasetApi(SupersetTestCase):
         expected_result = {
             "cache_timeout": None,
             "database": {
-                "backend": main_db.backend,
                 "allow_multi_catalog": False,
+                "backend": main_db.backend,
                 "database_name": "examples",
                 "id": 1,
             },
diff --git a/tests/unit_tests/db_engine_specs/test_databricks.py b/tests/unit_tests/db_engine_specs/test_databricks.py
index 8709833d3f..204faed445 100644
--- a/tests/unit_tests/db_engine_specs/test_databricks.py
+++ b/tests/unit_tests/db_engine_specs/test_databricks.py
@@ -245,3 +245,22 @@ def test_convert_dttm(
     from superset.db_engine_specs.databricks import DatabricksNativeEngineSpec 
as spec
 
     assert_convert_dttm(spec, target_type, expected_result, dttm)
+
+
+def test_get_prequeries() -> None:
+    """
+    Test the ``get_prequeries`` method.
+    """
+    from superset.db_engine_specs.databricks import DatabricksNativeEngineSpec
+
+    assert DatabricksNativeEngineSpec.get_prequeries() == []
+    assert DatabricksNativeEngineSpec.get_prequeries(schema="test") == [
+        "USE SCHEMA test",
+    ]
+    assert DatabricksNativeEngineSpec.get_prequeries(catalog="test") == [
+        "USE CATALOG test",
+    ]
+    assert DatabricksNativeEngineSpec.get_prequeries(catalog="foo", 
schema="bar") == [
+        "USE CATALOG foo",
+        "USE SCHEMA bar",
+    ]
diff --git a/tests/unit_tests/migrations/shared/__init__.py b/tests/unit_tests/migrations/shared/__init__.py
new file mode 100644
index 0000000000..13a83393a9
--- /dev/null
+++ b/tests/unit_tests/migrations/shared/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/tests/unit_tests/migrations/shared/catalogs_test.py b/tests/unit_tests/migrations/shared/catalogs_test.py
new file mode 100644
index 0000000000..ca715bec94
--- /dev/null
+++ b/tests/unit_tests/migrations/shared/catalogs_test.py
@@ -0,0 +1,145 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from pytest_mock import MockerFixture
+from sqlalchemy.orm.session import Session
+
+from superset.migrations.shared.catalogs import (
+    downgrade_catalog_perms,
+    upgrade_catalog_perms,
+)
+from superset.migrations.shared.security_converge import ViewMenu
+
+
+def test_upgrade_catalog_perms(mocker: MockerFixture, session: Session) -> 
None:
+    """
+    Test the `upgrade_catalog_perms` function.
+
+    The function is called when catalogs are introduced into a new DB engine 
spec. When
+    that happens, we need to update the `catalog` attribute so it points to 
the default
+    catalog, instead of being `NULL`. We also need to update `schema_perms` to 
include
+    the default catalog.
+    """
+    from superset.connectors.sqla.models import SqlaTable
+    from superset.models.core import Database
+    from superset.models.slice import Slice
+    from superset.models.sql_lab import Query, SavedQuery, TableSchema, 
TabState
+
+    engine = session.get_bind()
+    Database.metadata.create_all(engine)
+
+    mocker.patch("superset.migrations.shared.catalogs.op")
+    db = mocker.patch("superset.migrations.shared.catalogs.db")
+    db.Session.return_value = session
+
+    mocker.patch.object(
+        Database,
+        "get_all_schema_names",
+        return_value=["public", "information_schema"],
+    )
+
+    database = Database(
+        database_name="my_db",
+        sqlalchemy_uri="postgresql://localhost/db",
+    )
+    dataset = SqlaTable(
+        table_name="my_table",
+        database=database,
+        catalog=None,
+        schema="public",
+        schema_perm="[my_db].[public]",
+    )
+    session.add(dataset)
+    session.commit()
+
+    chart = Slice(
+        slice_name="my_chart",
+        datasource_type="table",
+        datasource_id=dataset.id,
+    )
+    query = Query(
+        client_id="foo",
+        database=database,
+        catalog=None,
+        schema="public",
+    )
+    saved_query = SavedQuery(
+        database=database,
+        sql="SELECT * FROM public.t",
+        catalog=None,
+        schema="public",
+    )
+    tab_state = TabState(
+        database=database,
+        catalog=None,
+        schema="public",
+    )
+    table_schema = TableSchema(
+        database=database,
+        catalog=None,
+        schema="public",
+    )
+    session.add_all([chart, query, saved_query, tab_state, table_schema])
+    session.commit()
+
+    # before migration
+    assert dataset.catalog is None
+    assert query.catalog is None
+    assert saved_query.catalog is None
+    assert tab_state.catalog is None
+    assert table_schema.catalog is None
+    assert dataset.schema_perm == "[my_db].[public]"
+    assert chart.schema_perm == "[my_db].[public]"
+    assert session.query(ViewMenu.name).all() == [
+        ("[my_db].(id:1)",),
+        ("[my_db].[my_table](id:1)",),
+        ("[my_db].[public]",),
+    ]
+
+    upgrade_catalog_perms()
+
+    # after migration
+    assert dataset.catalog == "db"
+    assert query.catalog == "db"
+    assert saved_query.catalog == "db"
+    assert tab_state.catalog == "db"
+    assert table_schema.catalog == "db"
+    assert dataset.schema_perm == "[my_db].[db].[public]"
+    assert chart.schema_perm == "[my_db].[db].[public]"
+    assert session.query(ViewMenu.name).all() == [
+        ("[my_db].(id:1)",),
+        ("[my_db].[my_table](id:1)",),
+        ("[my_db].[db].[public]",),
+        ("[my_db].[db]",),
+    ]
+
+    downgrade_catalog_perms()
+
+    # revert
+    assert dataset.catalog is None
+    assert query.catalog is None
+    assert saved_query.catalog is None
+    assert tab_state.catalog is None
+    assert table_schema.catalog is None
+    assert dataset.schema_perm == "[my_db].[public]"
+    assert chart.schema_perm == "[my_db].[public]"
+    assert session.query(ViewMenu.name).all() == [
+        ("[my_db].(id:1)",),
+        ("[my_db].[my_table](id:1)",),
+        ("[my_db].[public]",),
+        ("[my_db].[db]",),
+    ]

Reply via email to