This is an automated email from the ASF dual-hosted git repository.

beto pushed a commit to branch databricks-catalogs
in repository https://gitbox.apache.org/repos/asf/superset.git

commit 493d00af59aaf27447e01bd9ec3de2f3a9868995
Author: Beto Dealmeida <[email protected]>
AuthorDate: Wed May 8 17:06:20 2024 -0400

    feat: catalog support for Databricks native
---
 .../src/components/DatabaseSelector/index.tsx      |  16 ++-
 superset/db_engine_specs/databricks.py             |  32 ++++-
 superset/migrations/shared/catalogs.py             | 143 +++++++++++++++++++--
 ...33_4081be5b6b74_enable_catalog_in_databricks.py |  40 ++++++
 .../utils/pandas_postprocessing/contribution.py    |   3 +
 .../unit_tests/db_engine_specs/test_databricks.py  |  19 +++
 6 files changed, 237 insertions(+), 16 deletions(-)

diff --git a/superset-frontend/src/components/DatabaseSelector/index.tsx b/superset-frontend/src/components/DatabaseSelector/index.tsx
index 6eb1340d5b..77878ce070 100644
--- a/superset-frontend/src/components/DatabaseSelector/index.tsx
+++ b/superset-frontend/src/components/DatabaseSelector/index.tsx
@@ -143,7 +143,7 @@ export default function DatabaseSelector({
   const showCatalogSelector = !!db?.allow_multi_catalog;
   const [currentDb, setCurrentDb] = useState<DatabaseValue | undefined>();
   const [currentCatalog, setCurrentCatalog] = useState<
-    CatalogOption | undefined
+    CatalogOption | null | undefined
   >(catalog ? { label: catalog, value: catalog, title: catalog } : undefined);
   const catalogRef = useRef(catalog);
   catalogRef.current = catalog;
@@ -265,7 +265,7 @@ export default function DatabaseSelector({
 
   const schemaOptions = schemaData || EMPTY_SCHEMA_OPTIONS;
 
-  function changeCatalog(catalog: CatalogOption | undefined) {
+  function changeCatalog(catalog: CatalogOption | null | undefined) {
     setCurrentCatalog(catalog);
     setCurrentSchema(undefined);
     if (onCatalogChange && catalog?.value !== catalogRef.current) {
@@ -280,7 +280,9 @@ export default function DatabaseSelector({
   } = useCatalogs({
     dbId: currentDb?.value,
     onSuccess: (catalogs, isFetched) => {
-      if (catalogs.length === 1) {
+      if (!showCatalogSelector) {
+        changeCatalog(null);
+      } else if (catalogs.length === 1) {
         changeCatalog(catalogs[0]);
       } else if (
         !catalogs.find(
@@ -290,11 +292,15 @@ export default function DatabaseSelector({
         changeCatalog(undefined);
       }
 
-      if (isFetched) {
+      if (showCatalogSelector && isFetched) {
         addSuccessToast('List refreshed');
       }
     },
-    onError: () => handleError(t('There was an error loading the catalogs')),
+    onError: () => {
+      if (showCatalogSelector) {
+        handleError(t('There was an error loading the catalogs'));
+      }
+    },
   });
 
   const catalogOptions = catalogData || EMPTY_CATALOG_OPTIONS;
diff --git a/superset/db_engine_specs/databricks.py b/superset/db_engine_specs/databricks.py
index 4b2f93ca5d..f2f5b81dcd 100644
--- a/superset/db_engine_specs/databricks.py
+++ b/superset/db_engine_specs/databricks.py
@@ -39,7 +39,6 @@ if TYPE_CHECKING:
     from superset.models.core import Database
 
 
-#
 class DatabricksParametersSchema(Schema):
     """
     This is the list of fields that are expected
@@ -160,6 +159,8 @@ class DatabricksNativeEngineSpec(BasicParametersMixin, DatabricksODBCEngineSpec)
     )
     encryption_parameters = {"ssl": "1"}
 
+    supports_dynamic_schema = supports_catalog = supports_dynamic_catalog = True
+
     @staticmethod
     def get_extra_params(database: Database) -> dict[str, Any]:
         """
@@ -367,3 +368,32 @@ class DatabricksNativeEngineSpec(BasicParametersMixin, DatabricksODBCEngineSpec)
         )
         spec.components.schema(cls.__name__, schema=cls.properties_schema)
         return spec.to_dict()["components"]["schemas"][cls.__name__]
+
+    @classmethod
+    def get_default_catalog(
+        cls,
+        database: Database,
+    ) -> str | None:
+        with database.get_inspector() as inspector:
+            return inspector.bind.execute("SELECT current_catalog()").scalar()
+
+    @classmethod
+    def get_prequeries(
+        cls,
+        catalog: str | None = None,
+        schema: str | None = None,
+    ) -> list[str]:
+        prequeries = []
+        if catalog:
+            prequeries.append(f"USE CATALOG {catalog}")
+        if schema:
+            prequeries.append(f"USE SCHEMA {schema}")
+        return prequeries
+
+    @classmethod
+    def get_catalog_names(
+        cls,
+        database: Database,
+        inspector: Inspector,
+    ) -> set[str]:
+        return {catalog for (catalog,) in inspector.bind.execute("SHOW CATALOGS")}
diff --git a/superset/migrations/shared/catalogs.py b/superset/migrations/shared/catalogs.py
index 5d01ecfbfb..93d352b30e 100644
--- a/superset/migrations/shared/catalogs.py
+++ b/superset/migrations/shared/catalogs.py
@@ -19,13 +19,72 @@ from __future__ import annotations
 
 import logging
 
+import sqlalchemy as sa
 from alembic import op
+from flask import current_app
+from sqlalchemy.ext.declarative import declarative_base
 
 from superset import db, security_manager
 from superset.daos.database import DatabaseDAO
+from superset.migrations.shared.security_converge import add_pvms, ViewMenu
 from superset.models.core import Database
 
 logger = logging.getLogger(__name__)
+custom_password_store = current_app.config["SQLALCHEMY_CUSTOM_PASSWORD_STORE"]
+
+
+Base = declarative_base()
+
+
+class SqlaTable(Base):
+    __tablename__ = "tables"
+
+    id = sa.Column(sa.Integer, primary_key=True)
+    database_id = sa.Column(sa.Integer, nullable=False)
+    schema_perm = sa.Column(sa.String(1000))
+    schema = sa.Column(sa.String(255))
+    catalog = sa.Column(sa.String(256), nullable=True, default=None)
+
+
+class Query(Base):
+    __tablename__ = "query"
+
+    id = sa.Column(sa.Integer, primary_key=True)
+    database_id = sa.Column(sa.Integer, nullable=False)
+    catalog = sa.Column(sa.String(256), nullable=True, default=None)
+
+
+class SavedQuery(Base):
+    __tablename__ = "saved_query"
+
+    id = sa.Column(sa.Integer, primary_key=True)
+    db_id = sa.Column(sa.Integer, nullable=False)
+    catalog = sa.Column(sa.String(256), nullable=True, default=None)
+
+
+class TabState(Base):
+    __tablename__ = "tab_state"
+
+    id = sa.Column(sa.Integer, primary_key=True)
+    database_id = sa.Column(sa.Integer, nullable=False)
+    catalog = sa.Column(sa.String(256), nullable=True, default=None)
+
+
+class TableSchema(Base):
+    __tablename__ = "table_schema"
+
+    id = sa.Column(sa.Integer, primary_key=True)
+    database_id = sa.Column(sa.Integer, nullable=False)
+    catalog = sa.Column(sa.String(256), nullable=True, default=None)
+
+
+class Slice(Base):
+    __tablename__ = "slices"
+
+    id = sa.Column(sa.Integer, primary_key=True)
+    datasource_id = sa.Column(sa.Integer)
+    datasource_type = sa.Column(sa.String(200))
+    schema_perm = sa.Column(sa.String(1000))
 
 
 def upgrade_schema_perms(engine: str | None = None) -> None:
@@ -35,6 +94,8 @@ def upgrade_schema_perms(engine: str | None = None) -> None:
    Before SIP-95 schema permissions were stored in the format `[db].[schema]`. With the
    introduction of catalogs, any existing permissions need to be renamed to include the
    catalog: `[db].[catalog].[schema]`.
+
+    This also points existing models to the correct catalog.
     """
     bind = op.get_bind()
     session = db.Session(bind=bind)
@@ -46,6 +107,16 @@ def upgrade_schema_perms(engine: str | None = None) -> None:
             continue
 
         catalog = database.get_default_catalog()
+        if catalog is None:
+            continue
+
+        perm = security_manager.get_catalog_perm(
+            database.database_name,
+            catalog,
+        )
+        add_pvms(session, {perm: ("catalog_access",)})
+
+        # update schema_perms
         ssh_tunnel = DatabaseDAO.get_ssh_tunnel(database.id)
         for schema in database.get_all_schema_names(
             catalog=catalog,
@@ -57,17 +128,41 @@ def upgrade_schema_perms(engine: str | None = None) -> None:
                 None,
                 schema,
             )
-            existing_pvm = security_manager.find_permission_view_menu(
-                "schema_access",
-                perm,
-            )
+            existing_pvm = session.query(ViewMenu).filter_by(name=perm).one_or_none()
             if existing_pvm:
-                existing_pvm.view_menu.name = security_manager.get_schema_perm(
+                existing_pvm.name = security_manager.get_schema_perm(
                     database.database_name,
                     catalog,
                     schema,
                 )
 
+        # update existing models
+        models = [
+            (Query, "database_id"),
+            (SavedQuery, "db_id"),
+            (TabState, "database_id"),
+            (TableSchema, "database_id"),
+            (SqlaTable, "database_id"),
+        ]
+        for model, column in models:
+            for instance in session.query(model).filter(
+                getattr(model, column) == database.id
+            ):
+                instance.catalog = catalog
+
+        for table in session.query(SqlaTable).filter_by(database_id=database.id):
+            schema_perm = security_manager.get_schema_perm(
+                database.database_name,
+                catalog,
+                table.schema,
+            )
+            table.schema_perm = schema_perm
+            for chart in session.query(Slice).filter_by(
+                datasource_id=table.id,
+                datasource_type="table",
+            ):
+                chart.schema_perm = schema_perm
+
     session.commit()
 
 
@@ -91,6 +186,10 @@ def downgrade_schema_perms(engine: str | None = None) -> None:
             continue
 
         catalog = database.get_default_catalog()
+        if catalog is None:
+            continue
+
+        # update schema_perms
         ssh_tunnel = DatabaseDAO.get_ssh_tunnel(database.id)
         for schema in database.get_all_schema_names(
             catalog=catalog,
@@ -102,15 +201,39 @@ def downgrade_schema_perms(engine: str | None = None) -> None:
                 catalog,
                 schema,
             )
-            existing_pvm = security_manager.find_permission_view_menu(
-                "schema_access",
-                perm,
-            )
+            existing_pvm = session.query(ViewMenu).filter_by(name=perm).one_or_none()
             if existing_pvm:
-                existing_pvm.view_menu.name = security_manager.get_schema_perm(
+                existing_pvm.name = security_manager.get_schema_perm(
                     database.database_name,
                     None,
                     schema,
                 )
 
+        # update existing models
+        models = [
+            (Query, "database_id"),
+            (SavedQuery, "db_id"),
+            (TabState, "database_id"),
+            (TableSchema, "database_id"),
+            (SqlaTable, "database_id"),
+        ]
+        for model, column in models:
+            for instance in session.query(model).filter(
+                getattr(model, column) == database.id
+            ):
+                instance.catalog = None
+
+        for table in session.query(SqlaTable).filter_by(database_id=database.id):
+            schema_perm = security_manager.get_schema_perm(
+                database.database_name,
+                None,
+                table.schema,
+            )
+            table.schema_perm = schema_perm
+            for chart in session.query(Slice).filter_by(
+                datasource_id=table.id,
+                datasource_type="table",
+            ):
+                chart.schema_perm = schema_perm
+
     session.commit()
diff --git a/superset/migrations/versions/2024-05-08_19-33_4081be5b6b74_enable_catalog_in_databricks.py b/superset/migrations/versions/2024-05-08_19-33_4081be5b6b74_enable_catalog_in_databricks.py
new file mode 100644
index 0000000000..36e202b71e
--- /dev/null
+++ b/superset/migrations/versions/2024-05-08_19-33_4081be5b6b74_enable_catalog_in_databricks.py
@@ -0,0 +1,40 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Enable catalog in Databricks
+
+Revision ID: 4081be5b6b74
+Revises: 645bb206f96c
+Create Date: 2024-05-08 19:33:18.311411
+
+"""
+
+from superset.migrations.shared.catalogs import (
+    downgrade_schema_perms,
+    upgrade_schema_perms,
+)
+
+# revision identifiers, used by Alembic.
+revision = "4081be5b6b74"
+down_revision = "645bb206f96c"
+
+
+def upgrade():
+    upgrade_schema_perms(engine="databricks")
+
+
+def downgrade():
+    downgrade_schema_perms(engine="databricks")
diff --git a/superset/utils/pandas_postprocessing/contribution.py b/superset/utils/pandas_postprocessing/contribution.py
index 89a1413b74..ad8b070869 100644
--- a/superset/utils/pandas_postprocessing/contribution.py
+++ b/superset/utils/pandas_postprocessing/contribution.py
@@ -14,6 +14,9 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+
+from __future__ import annotations
+
 from decimal import Decimal
 from typing import Any
 
diff --git a/tests/unit_tests/db_engine_specs/test_databricks.py b/tests/unit_tests/db_engine_specs/test_databricks.py
index de06f919be..03379f5c2e 100644
--- a/tests/unit_tests/db_engine_specs/test_databricks.py
+++ b/tests/unit_tests/db_engine_specs/test_databricks.py
@@ -245,3 +245,22 @@ def test_convert_dttm(
     from superset.db_engine_specs.databricks import DatabricksNativeEngineSpec as spec
 
     assert_convert_dttm(spec, target_type, expected_result, dttm)
+
+
+def test_get_prequeries() -> None:
+    """
+    Test the ``get_prequeries`` method.
+    """
+    from superset.db_engine_specs.databricks import DatabricksNativeEngineSpec
+
+    assert DatabricksNativeEngineSpec.get_prequeries() == []
+    assert DatabricksNativeEngineSpec.get_prequeries(schema="test") == [
+        "USE SCHEMA test",
+    ]
+    assert DatabricksNativeEngineSpec.get_prequeries(catalog="test") == [
+        "USE CATALOG test",
+    ]
+    assert DatabricksNativeEngineSpec.get_prequeries(catalog="foo", schema="bar") == [
+        "USE CATALOG foo",
+        "USE SCHEMA bar",
+    ]

Reply via email to