This is an automated email from the ASF dual-hosted git repository.

beto pushed a commit to branch databricks-catalogs
in repository https://gitbox.apache.org/repos/asf/superset.git
commit 493d00af59aaf27447e01bd9ec3de2f3a9868995
Author: Beto Dealmeida <[email protected]>
AuthorDate: Wed May 8 17:06:20 2024 -0400

    feat: catalog support for Databricks native
---
 .../src/components/DatabaseSelector/index.tsx      |  16 ++-
 superset/db_engine_specs/databricks.py             |  32 ++++-
 superset/migrations/shared/catalogs.py             | 143 +++++++++++++++++++--
 ...33_4081be5b6b74_enable_catalog_in_databricks.py |  40 ++++++
 .../utils/pandas_postprocessing/contribution.py    |   3 +
 .../unit_tests/db_engine_specs/test_databricks.py  |  19 +++
 6 files changed, 237 insertions(+), 16 deletions(-)

diff --git a/superset-frontend/src/components/DatabaseSelector/index.tsx b/superset-frontend/src/components/DatabaseSelector/index.tsx
index 6eb1340d5b..77878ce070 100644
--- a/superset-frontend/src/components/DatabaseSelector/index.tsx
+++ b/superset-frontend/src/components/DatabaseSelector/index.tsx
@@ -143,7 +143,7 @@ export default function DatabaseSelector({
   const showCatalogSelector = !!db?.allow_multi_catalog;
   const [currentDb, setCurrentDb] = useState<DatabaseValue | undefined>();
   const [currentCatalog, setCurrentCatalog] = useState<
-    CatalogOption | undefined
+    CatalogOption | null | undefined
   >(catalog ? { label: catalog, value: catalog, title: catalog } : undefined);
   const catalogRef = useRef(catalog);
   catalogRef.current = catalog;
@@ -265,7 +265,7 @@ export default function DatabaseSelector({

   const schemaOptions = schemaData || EMPTY_SCHEMA_OPTIONS;

-  function changeCatalog(catalog: CatalogOption | undefined) {
+  function changeCatalog(catalog: CatalogOption | null | undefined) {
     setCurrentCatalog(catalog);
     setCurrentSchema(undefined);
     if (onCatalogChange && catalog?.value !== catalogRef.current) {
@@ -280,7 +280,9 @@ export default function DatabaseSelector({
   } = useCatalogs({
     dbId: currentDb?.value,
     onSuccess: (catalogs, isFetched) => {
-      if (catalogs.length === 1) {
+      if (!showCatalogSelector) {
+        changeCatalog(null);
+      } else if (catalogs.length === 1) {
         changeCatalog(catalogs[0]);
       } else if (
         !catalogs.find(
@@ -290,11 +292,15 @@ export default function DatabaseSelector({
         changeCatalog(undefined);
       }

-      if (isFetched) {
+      if (showCatalogSelector && isFetched) {
         addSuccessToast('List refreshed');
       }
     },
-    onError: () => handleError(t('There was an error loading the catalogs')),
+    onError: () => {
+      if (showCatalogSelector) {
+        handleError(t('There was an error loading the catalogs'));
+      }
+    },
   });

   const catalogOptions = catalogData || EMPTY_CATALOG_OPTIONS;

diff --git a/superset/db_engine_specs/databricks.py b/superset/db_engine_specs/databricks.py
index 4b2f93ca5d..f2f5b81dcd 100644
--- a/superset/db_engine_specs/databricks.py
+++ b/superset/db_engine_specs/databricks.py
@@ -39,7 +39,6 @@ if TYPE_CHECKING:
     from superset.models.core import Database


-#
 class DatabricksParametersSchema(Schema):
     """
     This is the list of fields that are expected
@@ -160,6 +159,8 @@ class DatabricksNativeEngineSpec(BasicParametersMixin, DatabricksODBCEngineSpec)
     )
     encryption_parameters = {"ssl": "1"}

+    supports_dynamic_schema = supports_catalog = supports_dynamic_catalog = True
+
     @staticmethod
     def get_extra_params(database: Database) -> dict[str, Any]:
         """
@@ -367,3 +368,32 @@ class DatabricksNativeEngineSpec(BasicParametersMixin, DatabricksODBCEngineSpec)
         )
         spec.components.schema(cls.__name__, schema=cls.properties_schema)
         return spec.to_dict()["components"]["schemas"][cls.__name__]
+
+    @classmethod
+    def get_default_catalog(
+        cls,
+        database: Database,
+    ) -> str | None:
+        with database.get_inspector() as inspector:
+            return inspector.bind.execute("SELECT current_catalog()").scalar()
+
+    @classmethod
+    def get_prequeries(
+        cls,
+        catalog: str | None = None,
+        schema: str | None = None,
+    ) -> list[str]:
+        prequeries = []
+        if catalog:
+            prequeries.append(f"USE CATALOG {catalog}")
+        if schema:
+            prequeries.append(f"USE SCHEMA {schema}")
+        return prequeries
+
+    @classmethod
+    def get_catalog_names(
+        cls,
+        database: Database,
+        inspector: Inspector,
+    ) -> set[str]:
+        return {catalog for (catalog,) in inspector.bind.execute("SHOW CATALOGS")}
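Note on how the three new engine-spec methods above fit together: get_catalog_names() enumerates catalogs with SHOW CATALOGS, get_default_catalog() reads the connection's current catalog, and get_prequeries() pins a session to the user's catalog/schema choice before each query runs. A minimal sketch of the intended effect; the run_with_context helper and raw-cursor wiring below are illustrative assumptions, not Superset's actual execution path:

    from superset.db_engine_specs.databricks import DatabricksNativeEngineSpec

    def run_with_context(cursor, sql, catalog=None, schema=None):
        # Pin the session first, so unqualified table names in `sql`
        # resolve against the chosen catalog and schema.
        for prequery in DatabricksNativeEngineSpec.get_prequeries(catalog, schema):
            cursor.execute(prequery)
        cursor.execute(sql)

    # run_with_context(cursor, "SELECT * FROM t", catalog="main", schema="sales")
    # executes: USE CATALOG main; USE SCHEMA sales; SELECT * FROM t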
""" bind = op.get_bind() session = db.Session(bind=bind) @@ -46,6 +107,16 @@ def upgrade_schema_perms(engine: str | None = None) -> None: continue catalog = database.get_default_catalog() + if catalog is None: + continue + + perm = security_manager.get_catalog_perm( + database.database_name, + catalog, + ) + add_pvms(session, {perm: ("catalog_access",)}) + + # update schema_perms ssh_tunnel = DatabaseDAO.get_ssh_tunnel(database.id) for schema in database.get_all_schema_names( catalog=catalog, @@ -57,17 +128,41 @@ def upgrade_schema_perms(engine: str | None = None) -> None: None, schema, ) - existing_pvm = security_manager.find_permission_view_menu( - "schema_access", - perm, - ) + existing_pvm = session.query(ViewMenu).filter_by(name=perm).one_or_none() if existing_pvm: - existing_pvm.view_menu.name = security_manager.get_schema_perm( + existing_pvm.name = security_manager.get_schema_perm( database.database_name, catalog, schema, ) + # update existing models + models = [ + (Query, "database_id"), + (SavedQuery, "db_id"), + (TabState, "database_id"), + (TableSchema, "database_id"), + (SqlaTable, "database_id"), + ] + for model, column in models: + for instance in session.query(model).filter( + getattr(model, column) == database.id + ): + instance.catalog = catalog + + for table in session.query(SqlaTable).filter_by(database_id=database.id): + schema_perm = security_manager.get_schema_perm( + database.database_name, + catalog, + table.schema, + ) + table.schema_perm = schema_perm + for chart in session.query(Slice).filter_by( + datasource_id=table.id, + datasource_type="table", + ): + chart.schema_perm = schema_perm + session.commit() @@ -91,6 +186,10 @@ def downgrade_schema_perms(engine: str | None = None) -> None: continue catalog = database.get_default_catalog() + if catalog is None: + continue + + # update schema_perms ssh_tunnel = DatabaseDAO.get_ssh_tunnel(database.id) for schema in database.get_all_schema_names( catalog=catalog, @@ -102,15 +201,39 @@ def downgrade_schema_perms(engine: str | None = None) -> None: catalog, schema, ) - existing_pvm = security_manager.find_permission_view_menu( - "schema_access", - perm, - ) + existing_pvm = session.query(ViewMenu).filter_by(name=perm).one_or_none() if existing_pvm: - existing_pvm.view_menu.name = security_manager.get_schema_perm( + existing_pvm.name = security_manager.get_schema_perm( database.database_name, None, schema, ) + # update existing models + models = [ + (Query, "database_id"), + (SavedQuery, "db_id"), + (TabState, "database_id"), + (TableSchema, "database_id"), + (SqlaTable, "database_id"), + ] + for model, column in models: + for instance in session.query(model).filter( + getattr(model, column) == database.id + ): + instance.catalog = None + + for table in session.query(SqlaTable).filter_by(database_id=database.id): + schema_perm = security_manager.get_schema_perm( + database.database_name, + None, + table.schema, + ) + table.schema_perm = schema_perm + for chart in session.query(Slice).filter_by( + datasource_id=table.id, + datasource_type="table", + ): + chart.schema_perm = schema_perm + session.commit() diff --git a/superset/migrations/versions/2024-05-08_19-33_4081be5b6b74_enable_catalog_in_databricks.py b/superset/migrations/versions/2024-05-08_19-33_4081be5b6b74_enable_catalog_in_databricks.py new file mode 100644 index 0000000000..36e202b71e --- /dev/null +++ b/superset/migrations/versions/2024-05-08_19-33_4081be5b6b74_enable_catalog_in_databricks.py @@ -0,0 +1,40 @@ +# Licensed to the Apache Software Foundation (ASF) 
diff --git a/superset/migrations/versions/2024-05-08_19-33_4081be5b6b74_enable_catalog_in_databricks.py b/superset/migrations/versions/2024-05-08_19-33_4081be5b6b74_enable_catalog_in_databricks.py
new file mode 100644
index 0000000000..36e202b71e
--- /dev/null
+++ b/superset/migrations/versions/2024-05-08_19-33_4081be5b6b74_enable_catalog_in_databricks.py
@@ -0,0 +1,40 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Enable catalog in Databricks
+
+Revision ID: 4081be5b6b74
+Revises: 645bb206f96c
+Create Date: 2024-05-08 19:33:18.311411
+
+"""
+
+from superset.migrations.shared.catalogs import (
+    downgrade_schema_perms,
+    upgrade_schema_perms,
+)
+
+# revision identifiers, used by Alembic.
+revision = "4081be5b6b74"
+down_revision = "645bb206f96c"
+
+
+def upgrade():
+    upgrade_schema_perms(engine="databricks")
+
+
+def downgrade():
+    downgrade_schema_perms(engine="databricks")

diff --git a/superset/utils/pandas_postprocessing/contribution.py b/superset/utils/pandas_postprocessing/contribution.py
index 89a1413b74..ad8b070869 100644
--- a/superset/utils/pandas_postprocessing/contribution.py
+++ b/superset/utils/pandas_postprocessing/contribution.py
@@ -14,6 +14,9 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
+
+from __future__ import annotations
+
 from decimal import Decimal
 from typing import Any

diff --git a/tests/unit_tests/db_engine_specs/test_databricks.py b/tests/unit_tests/db_engine_specs/test_databricks.py
index de06f919be..03379f5c2e 100644
--- a/tests/unit_tests/db_engine_specs/test_databricks.py
+++ b/tests/unit_tests/db_engine_specs/test_databricks.py
@@ -245,3 +245,22 @@ def test_convert_dttm(
     from superset.db_engine_specs.databricks import DatabricksNativeEngineSpec as spec

     assert_convert_dttm(spec, target_type, expected_result, dttm)
+
+
+def test_get_prequeries() -> None:
+    """
+    Test the ``get_prequeries`` method.
+    """
+    from superset.db_engine_specs.databricks import DatabricksNativeEngineSpec
+
+    assert DatabricksNativeEngineSpec.get_prequeries() == []
+    assert DatabricksNativeEngineSpec.get_prequeries(schema="test") == [
+        "USE SCHEMA test",
+    ]
+    assert DatabricksNativeEngineSpec.get_prequeries(catalog="test") == [
+        "USE CATALOG test",
+    ]
+    assert DatabricksNativeEngineSpec.get_prequeries(catalog="foo", schema="bar") == [
+        "USE CATALOG foo",
+        "USE SCHEMA bar",
+    ]
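The new test covers get_prequeries only; get_catalog_names could be exercised the same way with a stubbed inspector. A sketch in the style of the existing unit tests (the MagicMock wiring is an assumption about how the inspector would be stubbed, not part of this commit):

    from unittest import mock

    def test_get_catalog_names() -> None:
        from superset.db_engine_specs.databricks import DatabricksNativeEngineSpec

        inspector = mock.MagicMock()
        # "SHOW CATALOGS" yields one-column rows, one per catalog.
        inspector.bind.execute.return_value = [("hive_metastore",), ("main",)]
        assert DatabricksNativeEngineSpec.get_catalog_names(
            database=mock.MagicMock(),
            inspector=inspector,
        ) == {"hive_metastore", "main"}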
