This is an automated email from the ASF dual-hosted git repository.
beto pushed a commit to branch cleanup-db-update-command
in repository https://gitbox.apache.org/repos/asf/superset.git
The following commit(s) were added to refs/heads/cleanup-db-update-command by
this push:
new 15d748ae7a chore: clean up DB create command
15d748ae7a is described below
commit 15d748ae7a12e444a9d605312bf54f83092c5110
Author: Beto Dealmeida <[email protected]>
AuthorDate: Sat Apr 27 19:41:31 2024 -0400
chore: clean up DB create command
---
.../commands/database/ssh_tunnel/exceptions.py | 28 ++--
superset/commands/database/update.py | 162 +++++----------------
2 files changed, 57 insertions(+), 133 deletions(-)
diff --git a/superset/commands/database/ssh_tunnel/exceptions.py b/superset/commands/database/ssh_tunnel/exceptions.py
index a0def8c087..f74e8f397a 100644
--- a/superset/commands/database/ssh_tunnel/exceptions.py
+++ b/superset/commands/database/ssh_tunnel/exceptions.py
@@ -25,47 +25,53 @@ from superset.commands.exceptions import (
)
-class SSHTunnelDeleteFailedError(DeleteFailedError):
+class SSHTunnelError(Exception):
+ """
+ Base class.
+ """
+
+
+class SSHTunnelDeleteFailedError(DeleteFailedError, SSHTunnelError):
message = _("SSH Tunnel could not be deleted.")
-class SSHTunnelNotFoundError(CommandException):
+class SSHTunnelNotFoundError(CommandException, SSHTunnelError):
status = 404
message = _("SSH Tunnel not found.")
-class SSHTunnelInvalidError(CommandInvalidError):
+class SSHTunnelInvalidError(CommandInvalidError, SSHTunnelError):
message = _("SSH Tunnel parameters are invalid.")
-class SSHTunnelDatabasePortError(CommandInvalidError):
+class SSHTunnelDatabasePortError(CommandInvalidError, SSHTunnelError):
message = _("A database port is required when connecting via SSH Tunnel.")
-class SSHTunnelUpdateFailedError(UpdateFailedError):
+class SSHTunnelUpdateFailedError(UpdateFailedError, SSHTunnelError):
message = _("SSH Tunnel could not be updated.")
-class SSHTunnelCreateFailedError(CommandException):
+class SSHTunnelCreateFailedError(CommandException, SSHTunnelError):
message = _("Creating SSH Tunnel failed for an unknown reason")
-class SSHTunnelingNotEnabledError(CommandException):
+class SSHTunnelingNotEnabledError(CommandException, SSHTunnelError):
status = 400
message = _("SSH Tunneling is not enabled")
-class SSHTunnelRequiredFieldValidationError(ValidationError):
+class SSHTunnelRequiredFieldValidationError(ValidationError, SSHTunnelError):
def __init__(self, field_name: str) -> None:
super().__init__(
- [_("Field is required")],
+ [_("Field is required")], # type: ignore
field_name=field_name,
)
-class SSHTunnelMissingCredentials(CommandInvalidError):
+class SSHTunnelMissingCredentials(CommandInvalidError, SSHTunnelError):
message = _("Must provide credentials for the SSH Tunnel")
-class SSHTunnelInvalidCredentials(CommandInvalidError):
+class SSHTunnelInvalidCredentials(CommandInvalidError, SSHTunnelError):
message = _("Cannot have multiple credentials for the SSH Tunnel")
diff --git a/superset/commands/database/update.py b/superset/commands/database/update.py
index b057cb300e..ab63574147 100644
--- a/superset/commands/database/update.py
+++ b/superset/commands/database/update.py
@@ -23,7 +23,6 @@ from marshmallow import ValidationError
from superset import is_feature_enabled
from superset.commands.base import BaseCommand
from superset.commands.database.exceptions import (
- DatabaseConnectionFailedError,
DatabaseExistsValidationError,
DatabaseInvalidError,
DatabaseNotFoundError,
@@ -32,19 +31,14 @@ from superset.commands.database.exceptions import (
from superset.commands.database.ssh_tunnel.create import CreateSSHTunnelCommand
from superset.commands.database.ssh_tunnel.delete import DeleteSSHTunnelCommand
from superset.commands.database.ssh_tunnel.exceptions import (
- SSHTunnelCreateFailedError,
- SSHTunnelDatabasePortError,
- SSHTunnelDeleteFailedError,
+ SSHTunnelError,
SSHTunnelingNotEnabledError,
- SSHTunnelInvalidError,
- SSHTunnelUpdateFailedError,
)
from superset.commands.database.ssh_tunnel.update import UpdateSSHTunnelCommand
from superset.daos.database import DatabaseDAO
from superset.daos.exceptions import DAOCreateFailedError, DAOUpdateFailedError
-from superset.extensions import db, security_manager
+from superset.extensions import db
from superset.models.core import Database
-from superset.utils.core import DatasourceType
logger = logging.getLogger(__name__)
@@ -57,7 +51,7 @@ class UpdateDatabaseCommand(BaseCommand):
self._model_id = model_id
self._model: Optional[Database] = None
- def run(self) -> Model: # pylint: disable=too-many-statements, too-many-branches
+ def run(self) -> Model:
self._model = DatabaseDAO.find_by_id(self._model_id)
if not self._model:
@@ -65,8 +59,6 @@ class UpdateDatabaseCommand(BaseCommand):
self.validate()
- old_database_name = self._model.database_name
-
# unmask ``encrypted_extra``
self._properties["encrypted_extra"] = (
self._model.db_engine_spec.unmask_encrypted_extra(
@@ -76,126 +68,52 @@ class UpdateDatabaseCommand(BaseCommand):
)
try:
- database = DatabaseDAO.update(self._model, self._properties, commit=False)
+ database = DatabaseDAO.update(
+ self._model,
+ self._properties,
+ commit=False,
+ )
database.set_sqlalchemy_uri(database.sqlalchemy_uri)
- ssh_tunnel = DatabaseDAO.get_ssh_tunnel(database.id)
-
- if "ssh_tunnel" in self._properties:
- if not is_feature_enabled("SSH_TUNNELING"):
- db.session.rollback()
- raise SSHTunnelingNotEnabledError()
-
- if self._properties.get("ssh_tunnel") is None and ssh_tunnel:
- # We need to remove the existing tunnel
- try:
- DeleteSSHTunnelCommand(ssh_tunnel.id).run()
- ssh_tunnel = None
- except SSHTunnelDeleteFailedError as ex:
- raise ex
- except Exception as ex:
- raise DatabaseUpdateFailedError() from ex
-
- if ssh_tunnel_properties := self._properties.get("ssh_tunnel"):
- if ssh_tunnel is None:
- # We couldn't found an existing tunnel so we need to create one
- try:
- ssh_tunnel = CreateSSHTunnelCommand(
- database, ssh_tunnel_properties
- ).run()
- except (
- SSHTunnelInvalidError,
- SSHTunnelCreateFailedError,
- SSHTunnelDatabasePortError,
- ) as ex:
- # So we can show the original message
- raise ex
- except Exception as ex:
- raise DatabaseUpdateFailedError() from ex
- else:
- # We found an existing tunnel so we need to update it
- try:
- ssh_tunnel_id = ssh_tunnel.id
- ssh_tunnel = UpdateSSHTunnelCommand(
- ssh_tunnel_id, ssh_tunnel_properties
- ).run()
- except (
- SSHTunnelInvalidError,
- SSHTunnelUpdateFailedError,
- SSHTunnelDatabasePortError,
- ) as ex:
- # So we can show the original message
- raise ex
- except Exception as ex:
- raise DatabaseUpdateFailedError() from ex
-
- # adding a new database we always want to force refresh schema list
- # TODO Improve this simplistic implementation for catching DB conn fails
try:
- schemas = database.get_all_schema_names(ssh_tunnel=ssh_tunnel)
+ self._handle_ssh_tunnel(database)
+ except SSHTunnelError:
+ raise
except Exception as ex:
- db.session.rollback()
- raise DatabaseConnectionFailedError() from ex
-
- # Update database schema permissions
- new_schemas: list[str] = []
-
- for schema in schemas:
- old_view_menu_name = security_manager.get_schema_perm(
- old_database_name, schema
- )
- new_view_menu_name = security_manager.get_schema_perm(
- database.database_name, schema
- )
- schema_pvm = security_manager.find_permission_view_menu(
- "schema_access", old_view_menu_name
- )
- # Update the schema permission if the database name changed
- if schema_pvm and old_database_name != database.database_name:
- schema_pvm.view_menu.name = new_view_menu_name
-
- self._propagate_schema_permissions(
- old_view_menu_name, new_view_menu_name
- )
- else:
- new_schemas.append(schema)
- for schema in new_schemas:
- security_manager.add_permission_view_menu(
- "schema_access", security_manager.get_schema_perm(database, schema)
- )
-
- db.session.commit()
+ raise DatabaseUpdateFailedError() from ex
except (DAOUpdateFailedError, DAOCreateFailedError) as ex:
raise DatabaseUpdateFailedError() from ex
- return database
- @staticmethod
- def _propagate_schema_permissions(
- old_view_menu_name: str, new_view_menu_name: str
- ) -> None:
- from superset.connectors.sqla.models import ( # pylint: disable=import-outside-toplevel
- SqlaTable,
- )
- from superset.models.slice import ( # pylint: disable=import-outside-toplevel
- Slice,
- )
+ return database
- # Update schema_perm on all datasets
- datasets = (
- db.session.query(SqlaTable)
- .filter(SqlaTable.schema_perm == old_view_menu_name)
- .all()
- )
- for dataset in datasets:
- dataset.schema_perm = new_view_menu_name
- charts = db.session.query(Slice).filter(
- Slice.datasource_type == DatasourceType.TABLE,
- Slice.datasource_id == dataset.id,
- )
- # Update schema_perm on all charts
- for chart in charts:
- chart.schema_perm = new_view_menu_name
+ def _handle_ssh_tunnel(self, database: Database) -> None:
+ """
+ Delete, create, or update an SSH tunnel.
+ """
+ if not is_feature_enabled("SSH_TUNNELING"):
+ db.session.rollback()
+ raise SSHTunnelingNotEnabledError()
+
+ if "ssh_tunnel" not in self._properties:
+ return
+
+ current_ssh_tunnel = DatabaseDAO.get_ssh_tunnel(database.id)
+ ssh_tunnel_properties = self._properties["ssh_tunnel"]
+
+ if ssh_tunnel_properties is None:
+ if current_ssh_tunnel:
+ DeleteSSHTunnelCommand(current_ssh_tunnel.id).run()
+ return
+
+ if current_ssh_tunnel is None:
+ CreateSSHTunnelCommand(database, ssh_tunnel_properties).run()
+ return
+
+ UpdateSSHTunnelCommand(
+ current_ssh_tunnel.id,
+ ssh_tunnel_properties,
+ ).run()
def validate(self) -> None:
exceptions: list[ValidationError] = []