This is an automated email from the ASF dual-hosted git repository. diegopucci pushed a commit to branch diego/ch78628/fix-disabled-ssh-toggle in repository https://gitbox.apache.org/repos/asf/superset.git
commit b94994f07fd3a121cd17536dc610a3d48bc325ba Author: geido <[email protected]> AuthorDate: Tue Feb 20 13:23:12 2024 +0200 Catch missing database port for SSH Tunnel --- superset/commands/database/create.py | 4 ++ superset/commands/database/ssh_tunnel/create.py | 8 ++++ .../commands/database/ssh_tunnel/exceptions.py | 4 ++ superset/commands/database/ssh_tunnel/update.py | 6 +++ superset/commands/database/test_connection.py | 52 ++++++++++++---------- superset/commands/database/update.py | 33 ++++++++------ superset/databases/api.py | 5 ++- 7 files changed, 73 insertions(+), 39 deletions(-) diff --git a/superset/commands/database/create.py b/superset/commands/database/create.py index cde9dd8e88..1ddc08e6a1 100644 --- a/superset/commands/database/create.py +++ b/superset/commands/database/create.py @@ -19,6 +19,7 @@ from typing import Any, Optional from flask import current_app from flask_appbuilder.models.sqla import Model +from flask_babel import gettext as _ from marshmallow import ValidationError from superset import is_feature_enabled @@ -33,6 +34,7 @@ from superset.commands.database.exceptions import ( from superset.commands.database.ssh_tunnel.create import CreateSSHTunnelCommand from superset.commands.database.ssh_tunnel.exceptions import ( SSHTunnelCreateFailedError, + SSHTunnelDatabasePortError, SSHTunnelingNotEnabledError, SSHTunnelInvalidError, ) @@ -103,6 +105,7 @@ class CreateDatabaseCommand(BaseCommand): SSHTunnelInvalidError, SSHTunnelCreateFailedError, SSHTunnelingNotEnabledError, + SSHTunnelDatabasePortError, ) as ex: db.session.rollback() event_logger.log_with_context( @@ -140,6 +143,7 @@ class CreateDatabaseCommand(BaseCommand): # Check database_name uniqueness if not DatabaseDAO.validate_uniqueness(database_name): exceptions.append(DatabaseExistsValidationError()) + if exceptions: exception = DatabaseInvalidError() exception.extend(exceptions) diff --git a/superset/commands/database/ssh_tunnel/create.py b/superset/commands/database/ssh_tunnel/create.py index 59e083d4d8..287accc5aa 100644 --- a/superset/commands/database/ssh_tunnel/create.py +++ b/superset/commands/database/ssh_tunnel/create.py @@ -23,11 +23,13 @@ from marshmallow import ValidationError from superset.commands.base import BaseCommand from superset.commands.database.ssh_tunnel.exceptions import ( SSHTunnelCreateFailedError, + SSHTunnelDatabasePortError, SSHTunnelInvalidError, SSHTunnelRequiredFieldValidationError, ) from superset.daos.database import SSHTunnelDAO from superset.daos.exceptions import DAOCreateFailedError +from superset.databases.utils import make_url_safe from superset.extensions import event_logger from superset.models.core import Database @@ -35,9 +37,12 @@ logger = logging.getLogger(__name__) class CreateSSHTunnelCommand(BaseCommand): + _database: Database + def __init__(self, database: Database, data: dict[str, Any]): self._properties = data.copy() self._properties["database"] = database + self._database = database def run(self) -> Model: try: @@ -62,6 +67,9 @@ class CreateSSHTunnelCommand(BaseCommand): private_key_password: Optional[str] = self._properties.get( "private_key_password" ) + url = make_url_safe(self._database.sqlalchemy_uri) + if not url.port: + raise SSHTunnelDatabasePortError() if not server_address: exceptions.append(SSHTunnelRequiredFieldValidationError("server_address")) if not server_port: diff --git a/superset/commands/database/ssh_tunnel/exceptions.py b/superset/commands/database/ssh_tunnel/exceptions.py index 0e3f91cae6..a0def8c087 100644 --- a/superset/commands/database/ssh_tunnel/exceptions.py +++ b/superset/commands/database/ssh_tunnel/exceptions.py @@ -38,6 +38,10 @@ class SSHTunnelInvalidError(CommandInvalidError): message = _("SSH Tunnel parameters are invalid.") +class SSHTunnelDatabasePortError(CommandInvalidError): + message = _("A database port is required when connecting via SSH Tunnel.") + + class SSHTunnelUpdateFailedError(UpdateFailedError): message = _("SSH Tunnel could not be updated.") diff --git a/superset/commands/database/ssh_tunnel/update.py b/superset/commands/database/ssh_tunnel/update.py index 47f7d4947a..077ed4c321 100644 --- a/superset/commands/database/ssh_tunnel/update.py +++ b/superset/commands/database/ssh_tunnel/update.py @@ -21,6 +21,7 @@ from flask_appbuilder.models.sqla import Model from superset.commands.base import BaseCommand from superset.commands.database.ssh_tunnel.exceptions import ( + SSHTunnelDatabasePortError, SSHTunnelInvalidError, SSHTunnelNotFoundError, SSHTunnelRequiredFieldValidationError, @@ -29,6 +30,7 @@ from superset.commands.database.ssh_tunnel.exceptions import ( from superset.daos.database import SSHTunnelDAO from superset.daos.exceptions import DAOUpdateFailedError from superset.databases.ssh_tunnel.models import SSHTunnel +from superset.databases.utils import make_url_safe logger = logging.getLogger(__name__) @@ -62,6 +64,8 @@ class UpdateSSHTunnelCommand(BaseCommand): self._model = SSHTunnelDAO.find_by_id(self._model_id) if not self._model: raise SSHTunnelNotFoundError() + + url = make_url_safe(self._model.database.sqlalchemy_uri) private_key: Optional[str] = self._properties.get("private_key") private_key_password: Optional[str] = self._properties.get( "private_key_password" @@ -70,3 +74,5 @@ class UpdateSSHTunnelCommand(BaseCommand): raise SSHTunnelInvalidError( exceptions=[SSHTunnelRequiredFieldValidationError("private_key")] ) + if not url.port: + raise SSHTunnelDatabasePortError() diff --git a/superset/commands/database/test_connection.py b/superset/commands/database/test_connection.py index 0ffdf3ddd9..e91eec3a89 100644 --- a/superset/commands/database/test_connection.py +++ b/superset/commands/database/test_connection.py @@ -32,8 +32,11 @@ from superset.commands.database.exceptions import ( DatabaseTestConnectionDriverError, DatabaseTestConnectionUnexpectedError, ) -from superset.commands.database.ssh_tunnel.exceptions import SSHTunnelingNotEnabledError -from superset.daos.database import DatabaseDAO, SSHTunnelDAO +from superset.commands.database.ssh_tunnel.exceptions import ( + SSHTunnelDatabasePortError, + SSHTunnelingNotEnabledError, +) +from superset.daos.database import DatabaseDAO from superset.databases.ssh_tunnel.models import SSHTunnel from superset.databases.utils import make_url_safe from superset.errors import ErrorLevel, SupersetErrorType @@ -44,7 +47,6 @@ from superset.exceptions import ( ) from superset.extensions import event_logger from superset.models.core import Database -from superset.utils.ssh_tunnel import unmask_password_info logger = logging.getLogger(__name__) @@ -61,20 +63,22 @@ def get_log_connection_action( class TestConnectionDatabaseCommand(BaseCommand): + _model: Optional[Database] = None + _context: dict[str, Any] + _uri: str + def __init__(self, data: dict[str, Any]): self._properties = data.copy() - self._model: Optional[Database] = None - def run(self) -> None: # pylint: disable=too-many-statements, too-many-branches - self.validate() - ex_str = "" + if (database_name := self._properties.get("database_name")) is not None: + self._model = DatabaseDAO.get_database_by_name(database_name) + uri = self._properties.get("sqlalchemy_uri", "") if self._model and uri == self._model.safe_sqlalchemy_uri(): uri = self._model.sqlalchemy_uri_decrypted - ssh_tunnel = self._properties.get("ssh_tunnel") - # context for error messages url = make_url_safe(uri) + context = { "hostname": url.host, "password": url.password, @@ -83,6 +87,14 @@ class TestConnectionDatabaseCommand(BaseCommand): "database": url.database, } + self._context = context + self._uri = uri + + def run(self) -> None: # pylint: disable=too-many-statements, too-many-branches + self.validate() + ex_str = "" + ssh_tunnel = self._properties.get("ssh_tunnel") + serialized_encrypted_extra = self._properties.get( "masked_encrypted_extra", "{}", @@ -103,20 +115,11 @@ class TestConnectionDatabaseCommand(BaseCommand): encrypted_extra=serialized_encrypted_extra, ) - database.set_sqlalchemy_uri(uri) + database.set_sqlalchemy_uri(self._uri) database.db_engine_spec.mutate_db_for_connection_test(database) # Generate tunnel if present in the properties if ssh_tunnel: - if not is_feature_enabled("SSH_TUNNELING"): - raise SSHTunnelingNotEnabledError() - # If there's an existing tunnel for that DB we need to use the stored - # password, private_key and private_key_password instead - if ssh_tunnel_id := ssh_tunnel.pop("id", None): - if existing_ssh_tunnel := SSHTunnelDAO.find_by_id(ssh_tunnel_id): - ssh_tunnel = unmask_password_info( - ssh_tunnel, existing_ssh_tunnel - ) ssh_tunnel = SSHTunnel(**ssh_tunnel) event_logger.log_with_context( @@ -186,7 +189,7 @@ class TestConnectionDatabaseCommand(BaseCommand): engine=database.db_engine_spec.__name__, ) # check for custom errors (wrong username, wrong password, etc) - errors = database.db_engine_spec.extract_errors(ex, context) + errors = database.db_engine_spec.extract_errors(ex, self._context) raise SupersetErrorsException(errors) from ex except SupersetSecurityException as ex: event_logger.log_with_context( @@ -221,9 +224,12 @@ class TestConnectionDatabaseCommand(BaseCommand): ), engine=database.db_engine_spec.__name__, ) - errors = database.db_engine_spec.extract_errors(ex, context) + errors = database.db_engine_spec.extract_errors(ex, self._context) raise DatabaseTestConnectionUnexpectedError(errors) from ex def validate(self) -> None: - if (database_name := self._properties.get("database_name")) is not None: - self._model = DatabaseDAO.get_database_by_name(database_name) + if self._properties.get("ssh_tunnel"): + if not is_feature_enabled("SSH_TUNNELING"): + raise SSHTunnelingNotEnabledError() + if not self._context.get("port"): + raise SSHTunnelDatabasePortError() diff --git a/superset/commands/database/update.py b/superset/commands/database/update.py index b891c8f157..88539a2c7b 100644 --- a/superset/commands/database/update.py +++ b/superset/commands/database/update.py @@ -18,6 +18,7 @@ import logging from typing import Any, Optional from flask_appbuilder.models.sqla import Model +from flask_babel import gettext as _ from marshmallow import ValidationError from superset import is_feature_enabled @@ -33,6 +34,7 @@ from superset.commands.database.ssh_tunnel.create import CreateSSHTunnelCommand from superset.commands.database.ssh_tunnel.delete import DeleteSSHTunnelCommand from superset.commands.database.ssh_tunnel.exceptions import ( SSHTunnelCreateFailedError, + SSHTunnelDatabasePortError, SSHTunnelDeleteFailedError, SSHTunnelingNotEnabledError, SSHTunnelInvalidError, @@ -49,15 +51,19 @@ logger = logging.getLogger(__name__) class UpdateDatabaseCommand(BaseCommand): + _model: Optional[Database] + def __init__(self, model_id: int, data: dict[str, Any]): self._properties = data.copy() self._model_id = model_id - self._model: Optional[Database] = None + self._model = DatabaseDAO.find_by_id(self._model_id) def run(self) -> Model: - self.validate() if not self._model: raise DatabaseNotFoundError() + + self.validate() + old_database_name = self._model.database_name # unmask ``encrypted_extra`` @@ -72,32 +78,34 @@ class UpdateDatabaseCommand(BaseCommand): database = DatabaseDAO.update(self._model, self._properties, commit=False) database.set_sqlalchemy_uri(database.sqlalchemy_uri) - existing_ssh_tunnel_model = DatabaseDAO.get_ssh_tunnel(database.id) + ssh_tunnel = DatabaseDAO.get_ssh_tunnel(database.id) if "ssh_tunnel" in self._properties: if not is_feature_enabled("SSH_TUNNELING"): db.session.rollback() raise SSHTunnelingNotEnabledError() - if not self._properties.get("ssh_tunnel") and existing_ssh_tunnel_model: + if not self._properties.get("ssh_tunnel") and ssh_tunnel: # We need to remove the existing tunnel try: - DeleteSSHTunnelCommand(existing_ssh_tunnel_model.id).run() + DeleteSSHTunnelCommand(ssh_tunnel.id).run() + ssh_tunnel = None except SSHTunnelDeleteFailedError as ex: raise ex except Exception as ex: raise DatabaseUpdateFailedError() from ex if ssh_tunnel_properties := self._properties.get("ssh_tunnel"): - if existing_ssh_tunnel_model is None: + if ssh_tunnel is None: # We couldn't found an existing tunnel so we need to create one try: - CreateSSHTunnelCommand( + ssh_tunnel = CreateSSHTunnelCommand( database, ssh_tunnel_properties ).run() except ( SSHTunnelInvalidError, SSHTunnelCreateFailedError, + SSHTunnelDatabasePortError, ) as ex: # So we can show the original message raise ex @@ -106,12 +114,14 @@ class UpdateDatabaseCommand(BaseCommand): else: # We found an existing tunnel so we need to update it try: - UpdateSSHTunnelCommand( - existing_ssh_tunnel_model.id, ssh_tunnel_properties + ssh_tunnel_id = ssh_tunnel.id + ssh_tunnel = UpdateSSHTunnelCommand( + ssh_tunnel_id, ssh_tunnel_properties ).run() except ( SSHTunnelInvalidError, SSHTunnelUpdateFailedError, + SSHTunnelDatabasePortError, ) as ex: # So we can show the original message raise ex @@ -121,7 +131,6 @@ class UpdateDatabaseCommand(BaseCommand): # adding a new database we always want to force refresh schema list # TODO Improve this simplistic implementation for catching DB conn fails try: - ssh_tunnel = DatabaseDAO.get_ssh_tunnel(database.id) schemas = database.get_all_schema_names(ssh_tunnel=ssh_tunnel) except Exception as ex: db.session.rollback() @@ -189,10 +198,6 @@ class UpdateDatabaseCommand(BaseCommand): def validate(self) -> None: exceptions: list[ValidationError] = [] - # Validate/populate model exists - self._model = DatabaseDAO.find_by_id(self._model_id) - if not self._model: - raise DatabaseNotFoundError() database_name: Optional[str] = self._properties.get("database_name") if database_name: # Check database_name uniqueness diff --git a/superset/databases/api.py b/superset/databases/api.py index 2f95bd0442..e6aca61a20 100644 --- a/superset/databases/api.py +++ b/superset/databases/api.py @@ -47,6 +47,7 @@ from superset.commands.database.export import ExportDatabasesCommand from superset.commands.database.importers.dispatcher import ImportDatabasesCommand from superset.commands.database.ssh_tunnel.delete import DeleteSSHTunnelCommand from superset.commands.database.ssh_tunnel.exceptions import ( + SSHTunnelDatabasePortError, SSHTunnelDeleteFailedError, SSHTunnelingNotEnabledError, ) @@ -415,7 +416,7 @@ class DatabaseRestApi(BaseSupersetModelRestApi): exc_info=True, ) return self.response_422(message=str(ex)) - except SSHTunnelingNotEnabledError as ex: + except (SSHTunnelingNotEnabledError, SSHTunnelDatabasePortError) as ex: return self.response_400(message=str(ex)) except SupersetException as ex: return self.response(ex.status, message=ex.message) @@ -500,7 +501,7 @@ class DatabaseRestApi(BaseSupersetModelRestApi): exc_info=True, ) return self.response_422(message=str(ex)) - except SSHTunnelingNotEnabledError as ex: + except (SSHTunnelingNotEnabledError, SSHTunnelDatabasePortError) as ex: return self.response_400(message=str(ex)) @expose("/<int:pk>", methods=("DELETE",))
