This is an automated email from the ASF dual-hosted git repository. beto pushed a commit to branch postgres-catalog in repository https://gitbox.apache.org/repos/asf/superset.git
commit 175ca5ca4db186a67e63ef2acebb3652f82cda31 Author: Beto Dealmeida <[email protected]> AuthorDate: Sat Apr 27 13:17:44 2024 -0400 WIP --- superset/commands/database/create.py | 33 +++- .../commands/database/ssh_tunnel/exceptions.py | 26 +-- superset/commands/database/update.py | 156 +++++----------- superset/config.py | 22 +-- superset/connectors/sqla/models.py | 10 +- superset/constants.py | 1 + superset/databases/api.py | 83 ++++++++- superset/databases/schemas.py | 14 ++ superset/db_engine_specs/base.py | 21 ++- superset/db_engine_specs/bigquery.py | 4 +- superset/db_engine_specs/impala.py | 7 +- superset/db_engine_specs/postgres.py | 34 +++- superset/db_engine_specs/presto.py | 6 +- superset/db_engine_specs/snowflake.py | 6 +- superset/extensions/metadb.py | 5 - superset/models/core.py | 64 +++++-- superset/security/manager.py | 197 ++++++++++++++++++--- superset/utils/cache.py | 10 +- superset/views/database/mixins.py | 28 ++- 19 files changed, 508 insertions(+), 219 deletions(-) diff --git a/superset/commands/database/create.py b/superset/commands/database/create.py index 4903938eb9..13d44b04e7 100644 --- a/superset/commands/database/create.py +++ b/superset/commands/database/create.py @@ -97,12 +97,35 @@ class CreateDatabaseCommand(BaseCommand): db.session.commit() - # adding a new database we always want to force refresh schema list - schemas = database.get_all_schema_names(cache=False, ssh_tunnel=ssh_tunnel) - for schema in schemas: - security_manager.add_permission_view_menu( - "schema_access", security_manager.get_schema_perm(database, schema) + # add catalog/schema permissions + if database.db_engine_spec.supports_catalog: + catalogs = database.get_all_catalog_names( + cache=False, + ssh_tunnel=ssh_tunnel, ) + for catalog in catalogs: + security_manager.add_permission_view_menu( + "catalog_access", + security_manager.get_catalog_perm(database, catalog), + ) + else: + # add a dummy catalog for DBs that don't support them + catalogs = [None] + + for catalog in catalogs: + for schema in database.get_all_schema_names( + catalog=catalog, + cache=False, + ssh_tunnel=ssh_tunnel, + ): + security_manager.add_permission_view_menu( + "schema_access", + security_manager.get_schema_perm( + database.database_name, + catalog, + schema, + ), + ) except ( SSHTunnelInvalidError, diff --git a/superset/commands/database/ssh_tunnel/exceptions.py b/superset/commands/database/ssh_tunnel/exceptions.py index a0def8c087..9dc4adbdb0 100644 --- a/superset/commands/database/ssh_tunnel/exceptions.py +++ b/superset/commands/database/ssh_tunnel/exceptions.py @@ -25,37 +25,43 @@ from superset.commands.exceptions import ( ) -class SSHTunnelDeleteFailedError(DeleteFailedError): +class SSHTunnelError(Exception): + """ + Base class. + """ + + +class SSHTunnelDeleteFailedError(SSHTunnelError, DeleteFailedError): message = _("SSH Tunnel could not be deleted.") -class SSHTunnelNotFoundError(CommandException): +class SSHTunnelNotFoundError(SSHTunnelError, CommandException): status = 404 message = _("SSH Tunnel not found.") -class SSHTunnelInvalidError(CommandInvalidError): +class SSHTunnelInvalidError(SSHTunnelError, CommandInvalidError): message = _("SSH Tunnel parameters are invalid.") -class SSHTunnelDatabasePortError(CommandInvalidError): +class SSHTunnelDatabasePortError(SSHTunnelError, CommandInvalidError): message = _("A database port is required when connecting via SSH Tunnel.") -class SSHTunnelUpdateFailedError(UpdateFailedError): +class SSHTunnelUpdateFailedError(SSHTunnelError, UpdateFailedError): message = _("SSH Tunnel could not be updated.") -class SSHTunnelCreateFailedError(CommandException): +class SSHTunnelCreateFailedError(SSHTunnelError, CommandException): message = _("Creating SSH Tunnel failed for an unknown reason") -class SSHTunnelingNotEnabledError(CommandException): +class SSHTunnelingNotEnabledError(SSHTunnelError, CommandException): status = 400 message = _("SSH Tunneling is not enabled") -class SSHTunnelRequiredFieldValidationError(ValidationError): +class SSHTunnelRequiredFieldValidationError(SSHTunnelError, ValidationError): def __init__(self, field_name: str) -> None: super().__init__( [_("Field is required")], @@ -63,9 +69,9 @@ class SSHTunnelRequiredFieldValidationError(ValidationError): ) -class SSHTunnelMissingCredentials(CommandInvalidError): +class SSHTunnelMissingCredentials(SSHTunnelError, CommandInvalidError): message = _("Must provide credentials for the SSH Tunnel") -class SSHTunnelInvalidCredentials(CommandInvalidError): +class SSHTunnelInvalidCredentials(SSHTunnelError, CommandInvalidError): message = _("Cannot have multiple credentials for the SSH Tunnel") diff --git a/superset/commands/database/update.py b/superset/commands/database/update.py index b057cb300e..80e180e099 100644 --- a/superset/commands/database/update.py +++ b/superset/commands/database/update.py @@ -35,6 +35,7 @@ from superset.commands.database.ssh_tunnel.exceptions import ( SSHTunnelCreateFailedError, SSHTunnelDatabasePortError, SSHTunnelDeleteFailedError, + SSHTunnelError, SSHTunnelingNotEnabledError, SSHTunnelInvalidError, SSHTunnelUpdateFailedError, @@ -42,6 +43,7 @@ from superset.commands.database.ssh_tunnel.exceptions import ( from superset.commands.database.ssh_tunnel.update import UpdateSSHTunnelCommand from superset.daos.database import DatabaseDAO from superset.daos.exceptions import DAOCreateFailedError, DAOUpdateFailedError +from superset.databases.ssh_tunnel.models import SSHTunnel from superset.extensions import db, security_manager from superset.models.core import Database from superset.utils.core import DatasourceType @@ -57,7 +59,7 @@ class UpdateDatabaseCommand(BaseCommand): self._model_id = model_id self._model: Optional[Database] = None - def run(self) -> Model: # pylint: disable=too-many-statements, too-many-branches + def run(self) -> Model: # pylint: disable=too-many-branches self._model = DatabaseDAO.find_by_id(self._model_id) if not self._model: @@ -68,134 +70,56 @@ class UpdateDatabaseCommand(BaseCommand): old_database_name = self._model.database_name # unmask ``encrypted_extra`` - self._properties["encrypted_extra"] = ( - self._model.db_engine_spec.unmask_encrypted_extra( - self._model.encrypted_extra, - self._properties.pop("masked_encrypted_extra", "{}"), - ) + self._properties[ + "encrypted_extra" + ] = self._model.db_engine_spec.unmask_encrypted_extra( + self._model.encrypted_extra, + self._properties.pop("masked_encrypted_extra", "{}"), ) try: database = DatabaseDAO.update(self._model, self._properties, commit=False) database.set_sqlalchemy_uri(database.sqlalchemy_uri) - ssh_tunnel = DatabaseDAO.get_ssh_tunnel(database.id) - - if "ssh_tunnel" in self._properties: - if not is_feature_enabled("SSH_TUNNELING"): - db.session.rollback() - raise SSHTunnelingNotEnabledError() - - if self._properties.get("ssh_tunnel") is None and ssh_tunnel: - # We need to remove the existing tunnel - try: - DeleteSSHTunnelCommand(ssh_tunnel.id).run() - ssh_tunnel = None - except SSHTunnelDeleteFailedError as ex: - raise ex - except Exception as ex: - raise DatabaseUpdateFailedError() from ex - - if ssh_tunnel_properties := self._properties.get("ssh_tunnel"): - if ssh_tunnel is None: - # We couldn't found an existing tunnel so we need to create one - try: - ssh_tunnel = CreateSSHTunnelCommand( - database, ssh_tunnel_properties - ).run() - except ( - SSHTunnelInvalidError, - SSHTunnelCreateFailedError, - SSHTunnelDatabasePortError, - ) as ex: - # So we can show the original message - raise ex - except Exception as ex: - raise DatabaseUpdateFailedError() from ex - else: - # We found an existing tunnel so we need to update it - try: - ssh_tunnel_id = ssh_tunnel.id - ssh_tunnel = UpdateSSHTunnelCommand( - ssh_tunnel_id, ssh_tunnel_properties - ).run() - except ( - SSHTunnelInvalidError, - SSHTunnelUpdateFailedError, - SSHTunnelDatabasePortError, - ) as ex: - # So we can show the original message - raise ex - except Exception as ex: - raise DatabaseUpdateFailedError() from ex - - # adding a new database we always want to force refresh schema list - # TODO Improve this simplistic implementation for catching DB conn fails try: - schemas = database.get_all_schema_names(ssh_tunnel=ssh_tunnel) + self._handle_ssh_tunnel(database) + except SSHTunnelError: + raise except Exception as ex: - db.session.rollback() - raise DatabaseConnectionFailedError() from ex - - # Update database schema permissions - new_schemas: list[str] = [] - - for schema in schemas: - old_view_menu_name = security_manager.get_schema_perm( - old_database_name, schema - ) - new_view_menu_name = security_manager.get_schema_perm( - database.database_name, schema - ) - schema_pvm = security_manager.find_permission_view_menu( - "schema_access", old_view_menu_name - ) - # Update the schema permission if the database name changed - if schema_pvm and old_database_name != database.database_name: - schema_pvm.view_menu.name = new_view_menu_name - - self._propagate_schema_permissions( - old_view_menu_name, new_view_menu_name - ) - else: - new_schemas.append(schema) - for schema in new_schemas: - security_manager.add_permission_view_menu( - "schema_access", security_manager.get_schema_perm(database, schema) - ) - - db.session.commit() + raise DatabaseUpdateFailedError() from ex except (DAOUpdateFailedError, DAOCreateFailedError) as ex: raise DatabaseUpdateFailedError() from ex - return database - @staticmethod - def _propagate_schema_permissions( - old_view_menu_name: str, new_view_menu_name: str - ) -> None: - from superset.connectors.sqla.models import ( # pylint: disable=import-outside-toplevel - SqlaTable, - ) - from superset.models.slice import ( # pylint: disable=import-outside-toplevel - Slice, - ) + return database - # Update schema_perm on all datasets - datasets = ( - db.session.query(SqlaTable) - .filter(SqlaTable.schema_perm == old_view_menu_name) - .all() - ) - for dataset in datasets: - dataset.schema_perm = new_view_menu_name - charts = db.session.query(Slice).filter( - Slice.datasource_type == DatasourceType.TABLE, - Slice.datasource_id == dataset.id, - ) - # Update schema_perm on all charts - for chart in charts: - chart.schema_perm = new_view_menu_name + def _handle_ssh_tunnel(self, database: Database) -> None: + """ + Delete, create, or update an SSH tunnel. + """ + if not is_feature_enabled("SSH_TUNNELING"): + db.session.rollback() + raise SSHTunnelingNotEnabledError() + + if "ssh_tunnel" not in self._properties: + return + + current_ssh_tunnel = DatabaseDAO.get_ssh_tunnel(database.id) + ssh_tunnel_properties = self._properties["ssh_tunnel"] + + if ssh_tunnel_properties is None: + if current_ssh_tunnel: + DeleteSSHTunnelCommand(current_ssh_tunnel.id).run() + return + + if current_ssh_tunnel is None: + CreateSSHTunnelCommand(database, ssh_tunnel_properties).run() + return + + UpdateSSHTunnelCommand( + current_ssh_tunnel.id, + ssh_tunnel_properties, + ).run() def validate(self) -> None: exceptions: list[ValidationError] = [] diff --git a/superset/config.py b/superset/config.py index 9388edbe84..d492b0e15a 100644 --- a/superset/config.py +++ b/superset/config.py @@ -564,9 +564,9 @@ IS_FEATURE_ENABLED_FUNC: Callable[[str, bool | None], bool] | None = None # # Takes as a parameter the common bootstrap payload before transformations. # Returns a dict containing data that should be added or overridden to the payload. -COMMON_BOOTSTRAP_OVERRIDES_FUNC: Callable[[dict[str, Any]], dict[str, Any]] = ( # noqa: E731 - lambda data: {} -) # default: empty dict +COMMON_BOOTSTRAP_OVERRIDES_FUNC: Callable[ + [dict[str, Any]], dict[str, Any] +] = lambda data: {} # noqa: E731 # default: empty dict # EXTRA_CATEGORICAL_COLOR_SCHEMES is used for adding custom categorical color schemes # example code for "My custom warm to hot" color scheme @@ -1081,7 +1081,9 @@ UPLOADED_CSV_HIVE_NAMESPACE: str | None = None # db configuration and a result of this function. # mypy doesn't catch that if case ensures list content being always str -ALLOWED_USER_CSV_SCHEMA_FUNC: Callable[[Database, models.User], list[str]] = ( # noqa: E731 +ALLOWED_USER_CSV_SCHEMA_FUNC: Callable[ + [Database, models.User], list[str] +] = ( # noqa: E731 lambda database, user: [UPLOADED_CSV_HIVE_NAMESPACE] if UPLOADED_CSV_HIVE_NAMESPACE else [] @@ -1594,9 +1596,9 @@ GLOBAL_ASYNC_QUERIES_REDIS_STREAM_LIMIT_FIREHOSE = 1000000 GLOBAL_ASYNC_QUERIES_REGISTER_REQUEST_HANDLERS = True GLOBAL_ASYNC_QUERIES_JWT_COOKIE_NAME = "async-token" GLOBAL_ASYNC_QUERIES_JWT_COOKIE_SECURE = False -GLOBAL_ASYNC_QUERIES_JWT_COOKIE_SAMESITE: None | (Literal["None", "Lax", "Strict"]) = ( - None -) +GLOBAL_ASYNC_QUERIES_JWT_COOKIE_SAMESITE: None | ( + Literal["None", "Lax", "Strict"] +) = None GLOBAL_ASYNC_QUERIES_JWT_COOKIE_DOMAIN = None GLOBAL_ASYNC_QUERIES_JWT_SECRET = "test-secret-change-me" GLOBAL_ASYNC_QUERIES_TRANSPORT: Literal["polling", "ws"] = "polling" @@ -1658,9 +1660,9 @@ ADVANCED_DATA_TYPES: dict[str, AdvancedDataType] = { # "Xyz", # [{"col": 'created_by', "opr": 'rel_o_m', "value": 10}], # ) -WELCOME_PAGE_LAST_TAB: Literal["examples", "all"] | tuple[str, list[dict[str, Any]]] = ( - "all" -) +WELCOME_PAGE_LAST_TAB: Literal["examples", "all"] | tuple[ + str, list[dict[str, Any]] +] = "all" # Max allowed size for a zipped file ZIPPED_FILE_MAX_SIZE = 100 * 1024 * 1024 # 100MB diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index 719d5af588..d0c7513fbf 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -175,7 +175,9 @@ class DatasourceKind(StrEnum): PHYSICAL = "physical" -class BaseDatasource(AuditMixinNullable, ImportExportMixin): # pylint: disable=too-many-public-methods +class BaseDatasource( + AuditMixinNullable, ImportExportMixin +): # pylint: disable=too-many-public-methods """A common interface to objects that are queryable (tables and datasources)""" @@ -1263,7 +1265,11 @@ class SqlaTable( def get_schema_perm(self) -> str | None: """Returns schema permission if present, database one otherwise.""" - return security_manager.get_schema_perm(self.database, self.schema or None) + return security_manager.get_schema_perm( + self.database.database_name, + self.catalog, + self.schema or None, + ) def get_perm(self) -> str: """ diff --git a/superset/constants.py b/superset/constants.py index 28902ded6c..8e1563c9d3 100644 --- a/superset/constants.py +++ b/superset/constants.py @@ -132,6 +132,7 @@ MODEL_API_RW_METHOD_PERMISSION_MAP = { "related_objects": "read", "tables": "read", "schemas": "read", + "catalogs": "read", "select_star": "read", "table_metadata": "read", "table_metadata_deprecated": "read", diff --git a/superset/databases/api.py b/superset/databases/api.py index a77019123b..6d0af140f0 100644 --- a/superset/databases/api.py +++ b/superset/databases/api.py @@ -72,7 +72,9 @@ from superset.daos.database import DatabaseDAO, DatabaseUserOAuth2TokensDAO from superset.databases.decorators import check_table_access from superset.databases.filters import DatabaseFilter, DatabaseUploadEnabledFilter from superset.databases.schemas import ( + CatalogsResponseSchema, CSVUploadPostSchema, + database_catalogs_query_schema, database_schemas_query_schema, database_tables_query_schema, DatabaseConnectionSchema, @@ -140,6 +142,7 @@ class DatabaseRestApi(BaseSupersetModelRestApi): "table_extra_metadata", "table_extra_metadata_deprecated", "select_star", + "catalogs", "schemas", "test_connection", "related_objects", @@ -256,6 +259,7 @@ class DatabaseRestApi(BaseSupersetModelRestApi): edit_model_schema = DatabasePutSchema() apispec_parameter_schemas = { + "database_catalogs_query_schema": database_catalogs_query_schema, "database_schemas_query_schema": database_schemas_query_schema, "database_tables_query_schema": database_tables_query_schema, "get_export_ids_schema": get_export_ids_schema, @@ -263,6 +267,7 @@ class DatabaseRestApi(BaseSupersetModelRestApi): openapi_spec_tag = "Database" openapi_spec_component_schemas = ( + CatalogsResponseSchema, CSVUploadPostSchema, DatabaseConnectionSchema, DatabaseFunctionNamesResponse, @@ -589,6 +594,69 @@ class DatabaseRestApi(BaseSupersetModelRestApi): ) return self.response_422(message=str(ex)) + @expose("/<int:pk>/catalogs/") + @protect() + @safe + @rison(database_catalogs_query_schema) + @statsd_metrics + @event_logger.log_this_with_context( + action=lambda self, *args, **kwargs: f"{self.__class__.__name__}" f".catalogs", + log_to_statsd=False, + ) + def catalogs(self, pk: int, **kwargs: Any) -> FlaskResponse: + """Get all catalogs from a database. + --- + get: + summary: Get all catalogs from a database + parameters: + - in: path + schema: + type: integer + name: pk + description: The database id + - in: query + name: q + content: + application/json: + schema: + $ref: '#/components/schemas/database_catalogs_query_schema' + responses: + 200: + description: A List of all catalogs from the database + content: + application/json: + schema: + $ref: "#/components/schemas/CatalogsResponseSchema" + 400: + $ref: '#/components/responses/400' + 401: + $ref: '#/components/responses/401' + 404: + $ref: '#/components/responses/404' + 500: + $ref: '#/components/responses/500' + """ + database = self.datamodel.get(pk, self._base_filters) + if not database: + return self.response_404() + try: + catalogs = database.get_all_catalog_names( + cache=database.catalog_cache_enabled, + cache_timeout=database.catalog_cache_timeout or None, + force=kwargs["rison"].get("force", False), + ) + catalogs = security_manager.get_catalogs_accessible_by_user( + database, + catalogs, + ) + return self.response(200, result=list(catalogs)) + except OperationalError: + return self.response( + 500, message="There was an error connecting to the database" + ) + except SupersetException as ex: + return self.response(ex.status, message=ex.message) + @expose("/<int:pk>/schemas/") @protect() @safe @@ -640,8 +708,13 @@ class DatabaseRestApi(BaseSupersetModelRestApi): cache_timeout=database.schema_cache_timeout or None, force=kwargs["rison"].get("force", False), ) - schemas = security_manager.get_schemas_accessible_by_user(database, schemas) - return self.response(200, result=schemas) + catalog = kwargs["rison"].get("catalog") + schemas = security_manager.get_schemas_accessible_by_user( + database, + catalog, + schemas, + ) + return self.response(200, result=list(schemas)) except OperationalError: return self.response( 500, message="There was an error connecting to the database" @@ -1773,9 +1846,9 @@ class DatabaseRestApi(BaseSupersetModelRestApi): and getattr(engine_spec, "default_driver") in drivers ): payload["parameters"] = engine_spec.parameters_json_schema() - payload["sqlalchemy_uri_placeholder"] = ( - engine_spec.sqlalchemy_uri_placeholder - ) + payload[ + "sqlalchemy_uri_placeholder" + ] = engine_spec.sqlalchemy_uri_placeholder available_databases.append(payload) diff --git a/superset/databases/schemas.py b/superset/databases/schemas.py index 1bc0af7472..287f2a39d3 100644 --- a/superset/databases/schemas.py +++ b/superset/databases/schemas.py @@ -56,6 +56,14 @@ from superset.security.analytics_db_safety import check_sqlalchemy_uri from superset.utils.core import markdown, parse_ssl_cert database_schemas_query_schema = { + "type": "object", + "properties": { + "force": {"type": "boolean"}, + "catalog": {"type": "string"}, + }, +} + +database_catalogs_query_schema = { "type": "object", "properties": {"force": {"type": "boolean"}}, } @@ -712,6 +720,12 @@ class SchemasResponseSchema(Schema): ) +class CatalogsResponseSchema(Schema): + result = fields.List( + fields.String(metadata={"description": "A database catalog name"}) + ) + + class DatabaseTablesResponse(Schema): extra = fields.Dict( metadata={"description": "Extra data used to specify column metadata"} diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index 3cc1315129..328d41719c 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -131,7 +131,9 @@ builtin_time_grains: dict[str | None, str] = { } -class TimestampExpression(ColumnClause): # pylint: disable=abstract-method, too-many-ancestors +class TimestampExpression( + ColumnClause +): # pylint: disable=abstract-method, too-many-ancestors def __init__(self, expr: str, col: ColumnClause, **kwargs: Any) -> None: """Sqlalchemy class that can be used to render native column elements respecting engine-specific quoting rules as part of a string-based expression. @@ -638,10 +640,17 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods return driver in cls.drivers + @classmethod + def get_default_catalog(cls, database: Database) -> str | None: + """ + Return the default catalog for a given database. + """ + return None + @classmethod def get_default_schema(cls, database: Database, catalog: str | None) -> str | None: """ - Return the default schema in a given database. + Return the default schema for a catalog in a given database. """ with database.get_inspector(catalog=catalog) as inspector: return inspector.default_schema_name @@ -1412,24 +1421,24 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods cls, database: Database, inspector: Inspector, - ) -> list[str]: + ) -> set[str]: """ Get all catalogs from database. This needs to be implemented per database, since SQLAlchemy doesn't offer an abstraction. """ - return [] + return set() @classmethod - def get_schema_names(cls, inspector: Inspector) -> list[str]: + def get_schema_names(cls, inspector: Inspector) -> set[str]: """ Get all schemas from database :param inspector: SqlAlchemy inspector :return: All schemas in the database """ - return sorted(inspector.get_schema_names()) + return set(inspector.get_schema_names()) @classmethod def get_table_names( # pylint: disable=unused-argument diff --git a/superset/db_engine_specs/bigquery.py b/superset/db_engine_specs/bigquery.py index 8a2612f5b0..8e508b0e0f 100644 --- a/superset/db_engine_specs/bigquery.py +++ b/superset/db_engine_specs/bigquery.py @@ -464,7 +464,7 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met cls, database: Database, inspector: Inspector, - ) -> list[str]: + ) -> set[str]: """ Get all catalogs. @@ -475,7 +475,7 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met client = cls._get_client(engine) projects = client.list_projects() - return sorted(project.project_id for project in projects) + return {project.project_id for project in projects} @classmethod def get_allow_cost_estimate(cls, extra: dict[str, Any]) -> bool: diff --git a/superset/db_engine_specs/impala.py b/superset/db_engine_specs/impala.py index 1d3ec4e9e5..d7d1862aaf 100644 --- a/superset/db_engine_specs/impala.py +++ b/superset/db_engine_specs/impala.py @@ -74,13 +74,12 @@ class ImpalaEngineSpec(BaseEngineSpec): return None @classmethod - def get_schema_names(cls, inspector: Inspector) -> list[str]: - schemas = [ + def get_schema_names(cls, inspector: Inspector) -> set[str]: + return { row[0] for row in inspector.engine.execute("SHOW SCHEMAS") if not row[0].startswith("_") - ] - return schemas + } @classmethod def has_implicit_cancel(cls) -> bool: diff --git a/superset/db_engine_specs/postgres.py b/superset/db_engine_specs/postgres.py index ce87aa1f9b..bba2157e0a 100644 --- a/superset/db_engine_specs/postgres.py +++ b/superset/db_engine_specs/postgres.py @@ -101,8 +101,6 @@ class PostgresBaseEngineSpec(BaseEngineSpec): engine = "" engine_name = "PostgreSQL" - supports_catalog = True - _time_grain_expressions = { None: "{col}", TimeGrain.SECOND: "DATE_TRUNC('second', {col})", @@ -199,7 +197,10 @@ class PostgresBaseEngineSpec(BaseEngineSpec): class PostgresEngineSpec(BasicParametersMixin, PostgresBaseEngineSpec): engine = "postgresql" engine_aliases = {"postgres"} + supports_dynamic_schema = True + supports_catalog = True + supports_dynamic_catalog = True default_driver = "psycopg2" sqlalchemy_uri_placeholder = ( @@ -296,6 +297,29 @@ class PostgresEngineSpec(BasicParametersMixin, PostgresBaseEngineSpec): return super().get_default_schema_for_query(database, query) + @classmethod + def adjust_engine_params( + cls, + uri: URL, + connect_args: dict[str, Any], + catalog: str | None = None, + schema: str | None = None, + ) -> tuple[URL, dict[str, Any]]: + """ + Set the catalog (database). + """ + if catalog: + uri = uri.set(database=catalog) + + return uri, connect_args + + @classmethod + def get_default_catalog(cls, database: Database) -> str | None: + """ + Return the default catalog for a given database. + """ + return database.url_object.database + @classmethod def get_prequeries( cls, @@ -346,13 +370,13 @@ class PostgresEngineSpec(BasicParametersMixin, PostgresBaseEngineSpec): cls, database: Database, inspector: Inspector, - ) -> list[str]: + ) -> set[str]: """ Return all catalogs. In Postgres, a catalog is called a "database". """ - return sorted( + return { catalog for (catalog,) in inspector.bind.execute( """ @@ -360,7 +384,7 @@ SELECT datname FROM pg_database WHERE datistemplate = false; """ ) - ) + } @classmethod def get_table_names( diff --git a/superset/db_engine_specs/presto.py b/superset/db_engine_specs/presto.py index 34c47eb522..5a2c3afa5a 100644 --- a/superset/db_engine_specs/presto.py +++ b/superset/db_engine_specs/presto.py @@ -648,6 +648,8 @@ class PrestoEngineSpec(PrestoBaseEngineSpec): engine_name = "Presto" allows_alias_to_source_column = False + supports_catalog = True + custom_errors: dict[Pattern[str], tuple[str, SupersetErrorType, dict[str, Any]]] = { COLUMN_DOES_NOT_EXIST_REGEX: ( __( @@ -815,11 +817,11 @@ class PrestoEngineSpec(PrestoBaseEngineSpec): cls, database: Database, inspector: Inspector, - ) -> list[str]: + ) -> set[str]: """ Get all catalogs. """ - return [catalog for (catalog,) in inspector.bind.execute("SHOW CATALOGS")] + return {catalog for (catalog,) in inspector.bind.execute("SHOW CATALOGS")} @classmethod def _create_column_info( diff --git a/superset/db_engine_specs/snowflake.py b/superset/db_engine_specs/snowflake.py index 83d382cda1..0f03de2188 100644 --- a/superset/db_engine_specs/snowflake.py +++ b/superset/db_engine_specs/snowflake.py @@ -174,18 +174,18 @@ class SnowflakeEngineSpec(PostgresBaseEngineSpec): cls, database: "Database", inspector: Inspector, - ) -> list[str]: + ) -> set[str]: """ Return all catalogs. In Snowflake, a catalog is called a "database". """ - return sorted( + return { catalog for (catalog,) in inspector.bind.execute( "SELECT DATABASE_NAME from information_schema.databases" ) - ) + } @classmethod def epoch_to_dttm(cls) -> str: diff --git a/superset/extensions/metadb.py b/superset/extensions/metadb.py index 2d8444cc99..fd697aea82 100644 --- a/superset/extensions/metadb.py +++ b/superset/extensions/metadb.py @@ -270,11 +270,6 @@ class SupersetShillelaghAdapter(Adapter): self.schema = parts.pop(-1) if parts else None self.catalog = parts.pop(-1) if parts else None - if self.catalog: - # TODO (betodealmeida): when SIP-95 is implemented we should check to see if - # the database has multi-catalog enabled, and if so, give access. - raise NotImplementedError("Catalogs are not currently supported") - # If the table has a single integer primary key we use that as the row ID in order # to perform updates and deletes. Otherwise we can only do inserts and selects. self._rowid: str | None = None diff --git a/superset/models/core.py b/superset/models/core.py index 9a4a1de403..65b5cd7452 100755 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -117,7 +117,9 @@ class ConfigurationMethod(StrEnum): DYNAMIC_FORM = "dynamic_form" -class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable=too-many-public-methods +class Database( + Model, AuditMixinNullable, ImportExportMixin +): # pylint: disable=too-many-public-methods """An ORM object that stores Database related information""" __tablename__ = "dbs" @@ -313,6 +315,14 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable def metadata_cache_timeout(self) -> dict[str, Any]: return self.get_extra().get("metadata_cache_timeout", {}) + @property + def catalog_cache_enabled(self) -> bool: + return "catalog_cache_timeout" in self.metadata_cache_timeout + + @property + def catalog_cache_timeout(self) -> int | None: + return self.metadata_cache_timeout.get("catalog_cache_timeout") + @property def schema_cache_enabled(self) -> bool: return "schema_cache_timeout" in self.metadata_cache_timeout @@ -549,6 +559,18 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable yield conn + def get_default_catalog(self) -> str | None: + """ + Return the default configured catalog for the database. + """ + return self.db_engine_spec.get_default_catalog(self) + + def get_default_schema(self, catalog: str | None) -> str | None: + """ + Return the default schema for the database. + """ + return self.db_engine_spec.get_default_catalog(self, catalog) + def get_default_schema_for_query(self, query: Query) -> str | None: """ Return the default schema for a given query. @@ -792,22 +814,17 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable key="db:{self.id}:schema_list", cache=cache_manager.cache, ) - def get_all_schema_names( # pylint: disable=unused-argument + def get_all_schema_names( self, + *, catalog: str | None = None, - cache: bool = False, - cache_timeout: int | None = None, - force: bool = False, ssh_tunnel: SSHTunnel | None = None, - ) -> list[str]: - """Parameters need to be passed as keyword arguments. - - For unused parameters, they are referenced in - cache_util.memoized_func decorator. + ) -> set[str]: + """ + Return the schemas in a given database - :param cache: whether cache is enabled for the function - :param cache_timeout: timeout in seconds for the cache - :param force: whether to force refresh the cache + :param catalog: override default catalog + :param ssh_tunnel: SSH tunnel information needed to establish a connection :return: schema list """ try: @@ -819,6 +836,27 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable except Exception as ex: raise self.db_engine_spec.get_dbapi_mapped_exception(ex) from ex + @cache_util.memoized_func( + key="db:{self.id}:catalog_list", + cache=cache_manager.cache, + ) + def get_all_catalog_names( + self, + *, + ssh_tunnel: SSHTunnel | None = None, + ) -> list[str]: + """ + Return the catalogs in a given database + + :param ssh_tunnel: SSH tunnel information needed to establish a connection + :return: catalog list + """ + try: + with self.get_inspector(ssh_tunnel=ssh_tunnel) as inspector: + return self.db_engine_spec.get_catalog_names(self, inspector) + except Exception as ex: + raise self.db_engine_spec.get_dbapi_mapped_exception(ex) from ex + @property def db_engine_spec(self) -> builtins.type[db_engine_specs.BaseEngineSpec]: url = make_url_safe(self.sqlalchemy_uri_decrypted) diff --git a/superset/security/manager.py b/superset/security/manager.py index a84c0cec0d..4e8fc9635f 100644 --- a/superset/security/manager.py +++ b/superset/security/manager.py @@ -346,17 +346,55 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods return self.get_guest_user_from_request(request) return None + def get_catalog_perm( + self, + database: str, + catalog: Optional[str] = None, + ) -> Optional[str]: + """ + Return the database specific catalog permission. + + :param database: The Superset database or database name + :param catalog: The database catalog name + :return: The database specific schema permission + """ + if catalog is None: + return None + + return f"[{database}].[{catalog}]" + def get_schema_perm( - self, database: Union["Database", str], schema: Optional[str] = None + self, + database: str, + catalog: Optional[str] = None, + schema: Optional[str] = None, ) -> Optional[str]: """ Return the database specific schema permission. - :param database: The Superset database or database name - :param schema: The Superset schema name + Catalogs were added in SIP-95, and not all databases support them. Because of + this, the format used for permissions is different depending on whether a + catalog is passed or not: + + [database].[schema] + [database].[catalog].[schema] + + For backwards compatibility, when processing the first format Superset should + use the default catalog when the database supports them. This way, migrating + existing permissions is not necessary. + + :param database: The database name + :param catalog: The database catalog name + :param schema: The database schema name :return: The database specific schema permission """ - return f"[{database}].[{schema}]" if schema else None + if schema is None: + return None + + if catalog: + return f"[{database}].[{catalog}].[{schema}]" + + return f"[{database}].[{schema}]" @staticmethod def get_database_perm(database_id: int, database_name: str) -> Optional[str]: @@ -435,6 +473,16 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods or self.can_access("database_access", database.perm) # type: ignore ) + def can_access_catalog(self, database: "Database", catalog: str) -> bool: + """ + Return if the user can access the specified catalog. + """ + return ( + self.can_access_all_datasources() + or self.can_access_database(database) + or self.can_access("catalog_access", f"[{database}].[{catalog}]") + ) + def can_access_schema(self, datasource: "BaseDatasource") -> bool: """ Return True if the user can access the schema associated with specified @@ -447,6 +495,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods return ( self.can_access_all_datasources() or self.can_access_database(datasource.database) + or self.can_access_catalog(datasource.database, datasource.catalog) or self.can_access("schema_access", datasource.schema_perm or "") ) @@ -705,43 +754,133 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods ] def get_schemas_accessible_by_user( - self, database: "Database", schemas: list[str], hierarchical: bool = True - ) -> list[str]: + self, + database: "Database", + catalog: Optional[str], + schemas: set[str], + hierarchical: bool = True, + ) -> set[str]: """ - Return the list of SQL schemas accessible by the user. + Returned a filtered list of the schemas accessible by the user. + + If not catalog is specified, the default catalog is used. :param database: The SQL database - :param schemas: The list of eligible SQL schemas + :param catalog: An optional database catalog + :param schemas: A set of candidate schemas :param hierarchical: Whether to check using the hierarchical permission logic - :returns: The list of accessible SQL schemas + :returns: The set of accessible database schemas """ # pylint: disable=import-outside-toplevel from superset.connectors.sqla.models import SqlaTable - if hierarchical and self.can_access_database(database): + if hierarchical and ( + self.can_access_database(database) + or self.can_access_catalog(database, catalog) + ): return schemas # schema_access - accessible_schemas = { - self.unpack_database_and_schema(s).schema - for s in self.user_view_menu_names("schema_access") - if s.startswith(f"[{database}].") - } + accessible_schemas: set[str] = set() + schema_access = self.user_view_menu_names("schema_access") + default_catalog = database.get_default_catalog() + default_schema = database.get_default_schema() + + for perm in schema_access: + parts = [part[1:-1] for part in perm.split(".")] + + if parts[0] != database.database_name: + continue + + # [database].[schema] matches when no catalog is specified, or when the user + # specifies the default catalog + if len(parts) == 2 and (catalog is None or catalog == default_catalog): + accessible_schemas.add(parts[1]) + + # [database].[catalog].[schema] matches when the catalog is equal to the + # requested catalog or, when no catalog specified, it's equal to the default + # catalog. + elif len(parts) == 3 and parts[1] == (catalog or default_catalog): + accessible_schemas.add(parts[2]) + + # datasource_access + if perms := self.user_view_menu_names("datasource_access"): + tables = ( + self.get_session.query(SqlaTable.schema) + .filter(SqlaTable.database_id == database.id) + .filter(or_(SqlaTable.perm.in_(perms))) + .distinct() + ) + accessible_schemas.update( + { + table.schema or default_schema + for table in tables + if (table.schema or default_schema) + } + ) + + return schemas & accessible_schemas + + def get_catalogs_accessible_by_user( + self, + database: "Database", + catalogs: set[str], + hierarchical: bool = True, + ) -> set[str]: + """ + Returned a filtered list of the catalogs accessible by the user. + + :param database: The SQL database + :param catalogs: A set of candidate catalogs + :param hierarchical: Whether to check using the hierarchical permission logic + :returns: The set of accessible database catalogs + """ + # pylint: disable=import-outside-toplevel + from superset.connectors.sqla.models import SqlaTable + + if hierarchical and self.can_access_database(database): + return catalogs + + # catalog access + accessible_catalogs: set[str] = set() + catalog_access = self.user_view_menu_names("catalog_access") + default_catalog = database.get_default_catalog() + + for perm in catalog_access: + parts = [part[1:-1] for part in perm.split(".")] + if parts[0] == database.database_name: + accessible_catalogs.add(parts[1]) + + # schema access + schema_access = self.user_view_menu_names("schema_access") + for perm in schema_access: + parts = [part[1:-1] for part in perm.split(".")] + + if parts[0] != database.database_name: + continue + if len(parts) == 2 and default_catalog: + accessible_catalogs.add(default_catalog) + elif len(parts) == 3: + accessible_catalogs.add(parts[2]) # datasource_access if perms := self.user_view_menu_names("datasource_access"): tables = ( self.get_session.query(SqlaTable.schema) .filter(SqlaTable.database_id == database.id) - .filter(SqlaTable.schema.isnot(None)) - .filter(SqlaTable.schema != "") .filter(or_(SqlaTable.perm.in_(perms))) .distinct() ) - accessible_schemas.update([table.schema for table in tables]) + accessible_catalogs.update( + { + table.catalog or default_catalog + for table in tables + if (table.catalog or default_catalog) + } + ) - return [s for s in schemas if s in accessible_schemas] + return catalogs & accessible_catalogs def get_datasources_accessible_by_user( # pylint: disable=invalid-name self, @@ -763,6 +902,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods if self.can_access_database(database): return datasource_names + # XXX if schema: schema_perm = self.get_schema_perm(database, schema) if schema_perm and self.can_access("schema_access", schema_perm): @@ -1234,6 +1374,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods :param target: The database object :return: A list of changed view menus (permission resource names) """ + # XXX view_menu_table = self.viewmenu_model.__table__ # pylint: disable=no-member new_database_name = target.database_name old_view_menu_name = self.get_database_perm(target.id, old_database_name) @@ -1400,7 +1541,9 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods if target.schema: dataset_schema_perm = self.get_schema_perm( - database.database_name, target.schema + database.database_name, + target.catalog, + target.schema, ) self._insert_pvm_on_sqla_event( mapper, connection, "schema_access", dataset_schema_perm @@ -1480,7 +1623,9 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods # Updates schema permissions new_dataset_schema_name = self.get_schema_perm( - target.database.database_name, target.schema + target.database.database_name, + target.catalog, + target.schema, ) self._update_dataset_schema_perm( mapper, @@ -1504,7 +1649,9 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods # When schema changes if current_schema != target.schema: new_dataset_schema_name = self.get_schema_perm( - target.database.database_name, target.schema + target.database.database_name, + target.catalog, + target.schema, ) self._update_dataset_schema_perm( mapper, @@ -1980,7 +2127,11 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods denied = set() for table_ in tables: - schema_perm = self.get_schema_perm(database, schema=table_.schema) + schema_perm = self.get_schema_perm( + database, + table.catalog, + table_.schema, + ) if not (schema_perm and self.can_access("schema_access", schema_perm)): datasources = SqlaTable.query_datasources_by_name( diff --git a/superset/utils/cache.py b/superset/utils/cache.py index 48e283e7c1..00216fc4b1 100644 --- a/superset/utils/cache.py +++ b/superset/utils/cache.py @@ -119,7 +119,11 @@ def memoized_func(key: str, cache: Cache = cache_manager.cache) -> Callable[..., def wrap(f: Callable[..., Any]) -> Callable[..., Any]: def wrapped_f(*args: Any, **kwargs: Any) -> Any: - if not kwargs.get("cache", True): + should_cache = kwargs.pop("cache", True) + force = kwargs.pop("force", False) + cache_timeout = kwargs.pop("cache_timeout", 0) + + if not should_cache: return f(*args, **kwargs) # format the key using args/kwargs passed to the decorated function @@ -129,10 +133,10 @@ def memoized_func(key: str, cache: Cache = cache_manager.cache) -> Callable[..., cache_key = key.format(**bound_args.arguments) obj = cache.get(cache_key) - if not kwargs.get("force") and obj is not None: + if not force and obj is not None: return obj obj = f(*args, **kwargs) - cache.set(cache_key, obj, timeout=kwargs.get("cache_timeout", 0)) + cache.set(cache_key, obj, timeout=cache_timeout) return obj return wrapped_f diff --git a/superset/views/database/mixins.py b/superset/views/database/mixins.py index c6e799e6d4..0d104aad5f 100644 --- a/superset/views/database/mixins.py +++ b/superset/views/database/mixins.py @@ -211,11 +211,29 @@ class DatabaseMixin: utils.parse_ssl_cert(database.server_cert) database.set_sqlalchemy_uri(database.sqlalchemy_uri) security_manager.add_permission_view_menu("database_access", database.perm) - # adding a new database we always want to force refresh schema list - for schema in database.get_all_schema_names(): - security_manager.add_permission_view_menu( - "schema_access", security_manager.get_schema_perm(database, schema) - ) + + # add catalog/schema permissions + if database.db_engine_spec.supports_catalog: + catalogs = database.get_all_catalog_names() + for catalog in catalogs: + security_manager.add_permission_view_menu( + "catalog_access", + security_manager.get_catalog_perm(database.database_name, catalog), + ) + else: + # add a dummy catalog for DBs that don't support them + catalogs = [None] + + for catalog in catalogs: + for schema in database.get_all_schema_names(catalog=catalog): + security_manager.add_permission_view_menu( + "schema_access", + security_manager.get_schema_perm( + database.database_name, + catalog, + schema, + ), + ) def pre_add(self, database: Database) -> None: self._pre_add_update(database)
