This is an automated email from the ASF dual-hosted git repository. hugh pushed a commit to branch test-ssh-tunnel-1 in repository https://gitbox.apache.org/repos/asf/superset.git
commit b2cc56471a35383aa9d5e4247417bf67d1e0b412 Author: Antonio Rivero <[email protected]> AuthorDate: Fri Nov 25 08:32:21 2022 -0300 SSH Tunnel: - Uncomment ssh tunnel creation for TestConnection command - Update our tests - Add update changes in API --- superset/databases/commands/test_connection.py | 10 +- superset/databases/commands/update.py | 27 ++++- superset/databases/schemas.py | 1 + tests/integration_tests/databases/api_tests.py | 140 ++++++++++++++++++++++++- 4 files changed, 170 insertions(+), 8 deletions(-) diff --git a/superset/databases/commands/test_connection.py b/superset/databases/commands/test_connection.py index 7b913ed202..098c24ff11 100644 --- a/superset/databases/commands/test_connection.py +++ b/superset/databases/commands/test_connection.py @@ -32,6 +32,7 @@ from superset.databases.commands.exceptions import ( DatabaseTestConnectionUnexpectedError, ) from superset.databases.dao import DatabaseDAO +from superset.databases.ssh_tunnel.models import SSHTunnel from superset.databases.utils import make_url_safe from superset.errors import ErrorLevel, SupersetErrorType from superset.exceptions import ( @@ -90,10 +91,13 @@ class TestConnectionDatabaseCommand(BaseCommand): database.set_sqlalchemy_uri(uri) database.db_engine_spec.mutate_db_for_connection_test(database) - # TODO: (hughhh) uncomment in API enablement PR + # Generate tunnel if present in the properties ssh_tunnel = None - # if self._properties.get("ssh_tunnel"): - # ssh_tunnel = SSHTunnel(**self._properties["ssh_tunnel"]) + if ssh_tunnel := self._properties.get("ssh_tunnel"): + url = make_url_safe(database.sqlalchemy_uri_decrypted) + ssh_tunnel["bind_host"] = url.host + ssh_tunnel["bind_port"] = url.port + ssh_tunnel = SSHTunnel(**ssh_tunnel) event_logger.log_with_context( action="test_connection_attempt", diff --git a/superset/databases/commands/update.py b/superset/databases/commands/update.py index 80e3a9b54e..b662ab3546 100644 --- a/superset/databases/commands/update.py +++ b/superset/databases/commands/update.py @@ -21,7 +21,7 @@ from flask_appbuilder.models.sqla import Model from marshmallow import ValidationError from superset.commands.base import BaseCommand -from superset.dao.exceptions import DAOUpdateFailedError +from superset.dao.exceptions import DAOCreateFailedError, DAOUpdateFailedError from superset.databases.commands.exceptions import ( DatabaseConnectionFailedError, DatabaseExistsValidationError, @@ -30,6 +30,7 @@ from superset.databases.commands.exceptions import ( DatabaseUpdateFailedError, ) from superset.databases.dao import DatabaseDAO +from superset.databases.ssh_tunnel.dao import SSHTunnelDAO from superset.extensions import db, security_manager from superset.models.core import Database from superset.utils.core import DatasourceType @@ -94,11 +95,35 @@ class UpdateDatabaseCommand(BaseCommand): security_manager.add_permission_view_menu( "schema_access", security_manager.get_schema_perm(database, schema) ) + + if self._properties.get("ssh_tunnel"): + existing_ssh_tunnel = DatabaseDAO.get_ssh_tunnel(database.id) + if existing_ssh_tunnel is None: + # We couldn't found an existing tunnel so we need to create one + SSHTunnelDAO.create( + { + **self._properties.get("ssh_tunnel"), + "database_id": database.id, + }, + commit=False, + ) + else: + # We found an existing tunnel so we need to update it + ssh_tunnel_model = SSHTunnelDAO.find_by_id(existing_ssh_tunnel.id) + SSHTunnelDAO.update( + ssh_tunnel_model, + self._properties.get("ssh_tunnel"), + commit=False, + ) + db.session.commit() except DAOUpdateFailedError as ex: logger.exception(ex.exception) raise DatabaseUpdateFailedError() from ex + except DAOCreateFailedError as ex: + logger.exception(ex.exception) + raise DatabaseUpdateFailedError() from ex return database @staticmethod diff --git a/superset/databases/schemas.py b/superset/databases/schemas.py index 9c2de0b18e..ff134eb0af 100644 --- a/superset/databases/schemas.py +++ b/superset/databases/schemas.py @@ -475,6 +475,7 @@ class DatabasePutSchema(Schema, DatabaseParametersSchemaMixin): ) is_managed_externally = fields.Boolean(allow_none=True, default=False) external_url = fields.String(allow_none=True) + ssh_tunnel = fields.Nested(DatabaseSSHTunnel, allow_none=True) class DatabaseTestConnectionSchema(Schema, DatabaseParametersSchemaMixin): diff --git a/tests/integration_tests/databases/api_tests.py b/tests/integration_tests/databases/api_tests.py index 7821900c87..2dd5559219 100644 --- a/tests/integration_tests/databases/api_tests.py +++ b/tests/integration_tests/databases/api_tests.py @@ -35,6 +35,7 @@ from sqlalchemy.sql import func from superset import db, security_manager from superset.connectors.sqla.models import SqlaTable +from superset.databases.ssh_tunnel.models import SSHTunnel from superset.databases.utils import make_url_safe from superset.db_engine_specs.mysql import MySQLEngineSpec from superset.db_engine_specs.postgres import PostgresEngineSpec @@ -280,7 +281,12 @@ class TestDatabaseApi(SupersetTestCase): db.session.delete(model) db.session.commit() - def test_create_database_with_ssh_tunnel(self): + @mock.patch( + "superset.databases.commands.test_connection.TestConnectionDatabaseCommand.run", + ) + def test_create_database_with_ssh_tunnel( + self, mock_test_connection_database_command_run + ): """ Database API: Test create with SSH Tunnel """ @@ -290,13 +296,12 @@ class TestDatabaseApi(SupersetTestCase): return ssh_tunnel_properties = { "server_address": "123.132.123.1", - "bind_host": "localhost", - "bind_port": "5432", + "server_port": 8080, "username": "foo", "password": "bar", } database_data = { - "database_name": "test-create-db-with-ssh-tunnel", + "database_name": "test-db-with-ssh-tunnel", "sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted, "ssh_tunnel": ssh_tunnel_properties, } @@ -305,11 +310,138 @@ class TestDatabaseApi(SupersetTestCase): rv = self.client.post(uri, json=database_data) response = json.loads(rv.data.decode("utf-8")) self.assertEqual(rv.status_code, 201) + model_ssh_tunnel = ( + db.session.query(SSHTunnel) + .filter(SSHTunnel.database_id == response.get("id")) + .one() + ) + self.assertEqual(model_ssh_tunnel.database_id, response.get("id")) # Cleanup model = db.session.query(Database).get(response.get("id")) db.session.delete(model) db.session.commit() + @mock.patch( + "superset.databases.commands.test_connection.TestConnectionDatabaseCommand.run", + ) + @mock.patch( + "superset.models.core.Database.get_all_schema_names", + ) + def test_update_database_with_ssh_tunnel( + self, mock_test_connection_database_command_run, mock_get_all_schema_names + ): + """ + Database API: Test update with SSH Tunnel + """ + self.login(username="admin") + example_db = get_example_database() + if example_db.backend == "sqlite": + return + ssh_tunnel_properties = { + "server_address": "123.132.123.1", + "server_port": 8080, + "username": "foo", + "password": "bar", + } + database_data = { + "database_name": "test-db-with-ssh-tunnel", + "sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted, + } + database_data_with_ssh_tunnel = { + "database_name": "test-db-with-ssh-tunnel", + "sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted, + "ssh_tunnel": ssh_tunnel_properties, + } + + uri = "api/v1/database/" + rv = self.client.post(uri, json=database_data) + response = json.loads(rv.data.decode("utf-8")) + self.assertEqual(rv.status_code, 201) + + uri = "api/v1/database/{}".format(response.get("id")) + rv = self.client.put(uri, json=database_data_with_ssh_tunnel) + response_update = json.loads(rv.data.decode("utf-8")) + self.assertEqual(rv.status_code, 200) + + model_ssh_tunnel = ( + db.session.query(SSHTunnel) + .filter(SSHTunnel.database_id == response_update.get("id")) + .one() + ) + self.assertEqual(model_ssh_tunnel.database_id, response_update.get("id")) + # Cleanup + model = db.session.query(Database).get(response.get("id")) + db.session.delete(model) + db.session.commit() + + @mock.patch( + "superset.databases.commands.test_connection.TestConnectionDatabaseCommand.run", + ) + def test_update_ssh_tunnel_via_database_api( + self, mock_test_connection_database_command_run + ): + """ + Database API: Test update with SSH Tunnel + """ + self.login(username="admin") + example_db = get_example_database() + + if example_db.backend == "sqlite": + return + initial_ssh_tunnel_properties = { + "server_address": "123.132.123.1", + "server_port": 8080, + "username": "foo", + "password": "bar", + } + updated_ssh_tunnel_properties = { + "server_address": "123.132.123.2", + "server_port": 8081, + "username": "Test", + "password": "bar", + } + database_data_with_ssh_tunnel = { + "database_name": "test-db-with-ssh-tunnel", + "sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted, + "ssh_tunnel": initial_ssh_tunnel_properties, + } + database_data_with_ssh_tunnel_update = { + "database_name": "test-db-with-ssh-tunnel", + "sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted, + "ssh_tunnel": updated_ssh_tunnel_properties, + } + + uri = "api/v1/database/" + rv = self.client.post(uri, json=database_data_with_ssh_tunnel) + response = json.loads(rv.data.decode("utf-8")) + self.assertEqual(rv.status_code, 201) + model_ssh_tunnel = ( + db.session.query(SSHTunnel) + .filter(SSHTunnel.database_id == response.get("id")) + .one() + ) + self.assertEqual(model_ssh_tunnel.database_id, response.get("id")) + self.assertEqual(model_ssh_tunnel.username, "foo") + with mock.patch( + "superset.models.core.Database.get_all_schema_names", + return_value=["information_schema", "public"], + ): + uri = "api/v1/database/{}".format(response.get("id")) + rv = self.client.put(uri, json=database_data_with_ssh_tunnel_update) + response_update = json.loads(rv.data.decode("utf-8")) + self.assertEqual(rv.status_code, 200) + model_ssh_tunnel = ( + db.session.query(SSHTunnel) + .filter(SSHTunnel.database_id == response_update.get("id")) + .one() + ) + self.assertEqual(model_ssh_tunnel.database_id, response_update.get("id")) + self.assertEqual(model_ssh_tunnel.username, "Test") + # Cleanup + model = db.session.query(Database).get(response.get("id")) + db.session.delete(model) + db.session.commit() + def test_create_database_invalid_configuration_method(self): """ Database API: Test create with an invalid configuration method.
