This is an automated email from the ASF dual-hosted git repository.

beto pushed a commit to branch cleanup-db-update-command
in repository https://gitbox.apache.org/repos/asf/superset.git


The following commit(s) were added to refs/heads/cleanup-db-update-command by this push:
     new 15d748ae7a chore: clean up DB update command
15d748ae7a is described below

commit 15d748ae7a12e444a9d605312bf54f83092c5110
Author: Beto Dealmeida <[email protected]>
AuthorDate: Sat Apr 27 19:41:31 2024 -0400

    chore: clean up DB update command
---
 .../commands/database/ssh_tunnel/exceptions.py     |  28 ++--
 superset/commands/database/update.py               | 162 +++++----------------
 2 files changed, 57 insertions(+), 133 deletions(-)

diff --git a/superset/commands/database/ssh_tunnel/exceptions.py b/superset/commands/database/ssh_tunnel/exceptions.py
index a0def8c087..f74e8f397a 100644
--- a/superset/commands/database/ssh_tunnel/exceptions.py
+++ b/superset/commands/database/ssh_tunnel/exceptions.py
@@ -25,47 +25,53 @@ from superset.commands.exceptions import (
 )


-class SSHTunnelDeleteFailedError(DeleteFailedError):
+class SSHTunnelError(Exception):
+    """
+    Base class for all SSH tunnel errors, so callers can catch them together.
+    """
+
+
+class SSHTunnelDeleteFailedError(DeleteFailedError, SSHTunnelError):
     message = _("SSH Tunnel could not be deleted.")


-class SSHTunnelNotFoundError(CommandException):
+class SSHTunnelNotFoundError(CommandException, SSHTunnelError):
     status = 404
     message = _("SSH Tunnel not found.")


-class SSHTunnelInvalidError(CommandInvalidError):
+class SSHTunnelInvalidError(CommandInvalidError, SSHTunnelError):
     message = _("SSH Tunnel parameters are invalid.")


-class SSHTunnelDatabasePortError(CommandInvalidError):
+class SSHTunnelDatabasePortError(CommandInvalidError, SSHTunnelError):
     message = _("A database port is required when connecting via SSH Tunnel.")


-class SSHTunnelUpdateFailedError(UpdateFailedError):
+class SSHTunnelUpdateFailedError(UpdateFailedError, SSHTunnelError):
     message = _("SSH Tunnel could not be updated.")


-class SSHTunnelCreateFailedError(CommandException):
+class SSHTunnelCreateFailedError(CommandException, SSHTunnelError):
     message = _("Creating SSH Tunnel failed for an unknown reason")


-class SSHTunnelingNotEnabledError(CommandException):
+class SSHTunnelingNotEnabledError(CommandException, SSHTunnelError):
     status = 400
     message = _("SSH Tunneling is not enabled")


-class SSHTunnelRequiredFieldValidationError(ValidationError):
+class SSHTunnelRequiredFieldValidationError(ValidationError, SSHTunnelError):
     def __init__(self, field_name: str) -> None:
         super().__init__(
-            [_("Field is required")],
+            [_("Field is required")],  # type: ignore
             field_name=field_name,
         )


-class SSHTunnelMissingCredentials(CommandInvalidError):
+class SSHTunnelMissingCredentials(CommandInvalidError, SSHTunnelError):
     message = _("Must provide credentials for the SSH Tunnel")


-class SSHTunnelInvalidCredentials(CommandInvalidError):
+class SSHTunnelInvalidCredentials(CommandInvalidError, SSHTunnelError):
     message = _("Cannot have multiple credentials for the SSH Tunnel")
diff --git a/superset/commands/database/update.py b/superset/commands/database/update.py
index b057cb300e..ab63574147 100644
--- a/superset/commands/database/update.py
+++ b/superset/commands/database/update.py
@@ -23,7 +23,6 @@ from marshmallow import ValidationError
 from superset import is_feature_enabled
 from superset.commands.base import BaseCommand
 from superset.commands.database.exceptions import (
-    DatabaseConnectionFailedError,
     DatabaseExistsValidationError,
     DatabaseInvalidError,
     DatabaseNotFoundError,
@@ -32,19 +31,14 @@ from superset.commands.database.exceptions import (
 from superset.commands.database.ssh_tunnel.create import CreateSSHTunnelCommand
 from superset.commands.database.ssh_tunnel.delete import DeleteSSHTunnelCommand
 from superset.commands.database.ssh_tunnel.exceptions import (
-    SSHTunnelCreateFailedError,
-    SSHTunnelDatabasePortError,
-    SSHTunnelDeleteFailedError,
+    SSHTunnelError,
     SSHTunnelingNotEnabledError,
-    SSHTunnelInvalidError,
-    SSHTunnelUpdateFailedError,
 )
 from superset.commands.database.ssh_tunnel.update import UpdateSSHTunnelCommand
 from superset.daos.database import DatabaseDAO
 from superset.daos.exceptions import DAOCreateFailedError, DAOUpdateFailedError
-from superset.extensions import db, security_manager
+from superset.extensions import db
 from superset.models.core import Database
-from superset.utils.core import DatasourceType

 logger = logging.getLogger(__name__)

@@ -57,7 +51,7 @@ class UpdateDatabaseCommand(BaseCommand):
         self._model_id = model_id
         self._model: Optional[Database] = None

-    def run(self) -> Model:  # pylint: disable=too-many-statements, too-many-branches
+    def run(self) -> Model:
         self._model = DatabaseDAO.find_by_id(self._model_id)

         if not self._model:
@@ -65,8 +59,6 @@ class UpdateDatabaseCommand(BaseCommand):

         self.validate()

-        old_database_name = self._model.database_name
-
         # unmask ``encrypted_extra``
         self._properties["encrypted_extra"] = (
             self._model.db_engine_spec.unmask_encrypted_extra(
@@ -76,126 +68,52 @@ class UpdateDatabaseCommand(BaseCommand):
         )

         try:
-            database = DatabaseDAO.update(self._model, self._properties, commit=False)
+            database = DatabaseDAO.update(
+                self._model,
+                self._properties,
+                commit=False,
+            )
             database.set_sqlalchemy_uri(database.sqlalchemy_uri)

-            ssh_tunnel = DatabaseDAO.get_ssh_tunnel(database.id)
-
-            if "ssh_tunnel" in self._properties:
-                if not is_feature_enabled("SSH_TUNNELING"):
-                    db.session.rollback()
-                    raise SSHTunnelingNotEnabledError()
-
-                if self._properties.get("ssh_tunnel") is None and ssh_tunnel:
-                    # We need to remove the existing tunnel
-                    try:
-                        DeleteSSHTunnelCommand(ssh_tunnel.id).run()
-                        ssh_tunnel = None
-                    except SSHTunnelDeleteFailedError as ex:
-                        raise ex
-                    except Exception as ex:
-                        raise DatabaseUpdateFailedError() from ex
-
-                if ssh_tunnel_properties := self._properties.get("ssh_tunnel"):
-                    if ssh_tunnel is None:
-                        # We couldn't found an existing tunnel so we need to create one
-                        try:
-                            ssh_tunnel = CreateSSHTunnelCommand(
-                                database, ssh_tunnel_properties
-                            ).run()
-                        except (
-                            SSHTunnelInvalidError,
-                            SSHTunnelCreateFailedError,
-                            SSHTunnelDatabasePortError,
-                        ) as ex:
-                            # So we can show the original message
-                            raise ex
-                        except Exception as ex:
-                            raise DatabaseUpdateFailedError() from ex
-                    else:
-                        # We found an existing tunnel so we need to update it
-                        try:
-                            ssh_tunnel_id = ssh_tunnel.id
-                            ssh_tunnel = UpdateSSHTunnelCommand(
-                                ssh_tunnel_id, ssh_tunnel_properties
-                            ).run()
-                        except (
-                            SSHTunnelInvalidError,
-                            SSHTunnelUpdateFailedError,
-                            SSHTunnelDatabasePortError,
-                        ) as ex:
-                            # So we can show the original message
-                            raise ex
-                        except Exception as ex:
-                            raise DatabaseUpdateFailedError() from ex
-
-            # adding a new database we always want to force refresh schema list
-            # TODO Improve this simplistic implementation for catching DB conn fails
             try:
-                schemas = database.get_all_schema_names(ssh_tunnel=ssh_tunnel)
+                self._handle_ssh_tunnel(database)
+            except SSHTunnelError:
+                raise
             except Exception as ex:
-                db.session.rollback()
-                raise DatabaseConnectionFailedError() from ex
-
-            # Update database schema permissions
-            new_schemas: list[str] = []
-
-            for schema in schemas:
-                old_view_menu_name = security_manager.get_schema_perm(
-                    old_database_name, schema
-                )
-                new_view_menu_name = security_manager.get_schema_perm(
-                    database.database_name, schema
-                )
-                schema_pvm = security_manager.find_permission_view_menu(
-                    "schema_access", old_view_menu_name
-                )
-                # Update the schema permission if the database name changed
-                if schema_pvm and old_database_name != database.database_name:
-                    schema_pvm.view_menu.name = new_view_menu_name
-
-                    self._propagate_schema_permissions(
-                        old_view_menu_name, new_view_menu_name
-                    )
-                else:
-                    new_schemas.append(schema)
-            for schema in new_schemas:
-                security_manager.add_permission_view_menu(
-                    "schema_access", security_manager.get_schema_perm(database, schema)
-                )
-
-            db.session.commit()
+                raise DatabaseUpdateFailedError() from ex
+
+            # The DAO update above used commit=False, so commit explicitly here.
+            db.session.commit()

         except (DAOUpdateFailedError, DAOCreateFailedError) as ex:
             raise DatabaseUpdateFailedError() from ex
-        return database

-    @staticmethod
-    def _propagate_schema_permissions(
-        old_view_menu_name: str, new_view_menu_name: str
-    ) -> None:
-        from superset.connectors.sqla.models import (  # pylint: disable=import-outside-toplevel
-            SqlaTable,
-        )
-        from superset.models.slice import (  # pylint: disable=import-outside-toplevel
-            Slice,
-        )
+        return database

-        # Update schema_perm on all datasets
-        datasets = (
-            db.session.query(SqlaTable)
-            .filter(SqlaTable.schema_perm == old_view_menu_name)
-            .all()
-        )
-        for dataset in datasets:
-            dataset.schema_perm = new_view_menu_name
-            charts = db.session.query(Slice).filter(
-                Slice.datasource_type == DatasourceType.TABLE,
-                Slice.datasource_id == dataset.id,
-            )
-            # Update schema_perm on all charts
-            for chart in charts:
-                chart.schema_perm = new_view_menu_name
+    def _handle_ssh_tunnel(self, database: Database) -> None:
+        """
+        Delete, create, or update the SSH tunnel when ``ssh_tunnel`` is passed.
+        """
+        if "ssh_tunnel" not in self._properties:
+            return
+
+        if not is_feature_enabled("SSH_TUNNELING"):
+            db.session.rollback()
+            raise SSHTunnelingNotEnabledError()
+
+        current_ssh_tunnel = DatabaseDAO.get_ssh_tunnel(database.id)
+        ssh_tunnel_properties = self._properties["ssh_tunnel"]
+
+        if ssh_tunnel_properties is None:
+            if current_ssh_tunnel:
+                DeleteSSHTunnelCommand(current_ssh_tunnel.id).run()
+            return
+
+        if current_ssh_tunnel is None:
+            CreateSSHTunnelCommand(database, ssh_tunnel_properties).run()
+            return
+
+        UpdateSSHTunnelCommand(current_ssh_tunnel.id, ssh_tunnel_properties).run()

     def validate(self) -> None:
         exceptions: list[ValidationError] = []

Reply via email to