This is an automated email from the ASF dual-hosted git repository.

hugh pushed a commit to branch test-ssh-tunnel-1
in repository https://gitbox.apache.org/repos/asf/superset.git

commit b2cc56471a35383aa9d5e4247417bf67d1e0b412
Author: Antonio Rivero <[email protected]>
AuthorDate: Fri Nov 25 08:32:21 2022 -0300

    SSH Tunnel:
    
    - Uncomment ssh tunnel creation for TestConnection command
    - Update our tests
    - Support creating or updating an SSH tunnel through the database update (PUT) API
---
 superset/databases/commands/test_connection.py |  10 +-
 superset/databases/commands/update.py          |  27 ++++-
 superset/databases/schemas.py                  |   1 +
 tests/integration_tests/databases/api_tests.py | 140 ++++++++++++++++++++++++-
 4 files changed, 170 insertions(+), 8 deletions(-)

diff --git a/superset/databases/commands/test_connection.py b/superset/databases/commands/test_connection.py
index 7b913ed202..098c24ff11 100644
--- a/superset/databases/commands/test_connection.py
+++ b/superset/databases/commands/test_connection.py
@@ -32,6 +32,7 @@ from superset.databases.commands.exceptions import (
     DatabaseTestConnectionUnexpectedError,
 )
 from superset.databases.dao import DatabaseDAO
+from superset.databases.ssh_tunnel.models import SSHTunnel
 from superset.databases.utils import make_url_safe
 from superset.errors import ErrorLevel, SupersetErrorType
 from superset.exceptions import (
@@ -90,10 +91,13 @@ class TestConnectionDatabaseCommand(BaseCommand):
             database.set_sqlalchemy_uri(uri)
             database.db_engine_spec.mutate_db_for_connection_test(database)
 
-            # TODO: (hughhh) uncomment in API enablement PR
+            # Generate tunnel if present in the properties
             ssh_tunnel = None
-            # if self._properties.get("ssh_tunnel"):
-            #     ssh_tunnel = SSHTunnel(**self._properties["ssh_tunnel"])
+            if ssh_tunnel := self._properties.get("ssh_tunnel"):
+                url = make_url_safe(database.sqlalchemy_uri_decrypted)
+                ssh_tunnel["bind_host"] = url.host
+                ssh_tunnel["bind_port"] = url.port
+                ssh_tunnel = SSHTunnel(**ssh_tunnel)
 
             event_logger.log_with_context(
                 action="test_connection_attempt",
diff --git a/superset/databases/commands/update.py b/superset/databases/commands/update.py
index 80e3a9b54e..b662ab3546 100644
--- a/superset/databases/commands/update.py
+++ b/superset/databases/commands/update.py
@@ -21,7 +21,7 @@ from flask_appbuilder.models.sqla import Model
 from marshmallow import ValidationError
 
 from superset.commands.base import BaseCommand
-from superset.dao.exceptions import DAOUpdateFailedError
+from superset.dao.exceptions import DAOCreateFailedError, DAOUpdateFailedError
 from superset.databases.commands.exceptions import (
     DatabaseConnectionFailedError,
     DatabaseExistsValidationError,
@@ -30,6 +30,7 @@ from superset.databases.commands.exceptions import (
     DatabaseUpdateFailedError,
 )
 from superset.databases.dao import DatabaseDAO
+from superset.databases.ssh_tunnel.dao import SSHTunnelDAO
 from superset.extensions import db, security_manager
 from superset.models.core import Database
 from superset.utils.core import DatasourceType
@@ -94,11 +95,35 @@ class UpdateDatabaseCommand(BaseCommand):
                 security_manager.add_permission_view_menu(
                     "schema_access", security_manager.get_schema_perm(database, schema)
                 )
+
+            if self._properties.get("ssh_tunnel"):
+                existing_ssh_tunnel = DatabaseDAO.get_ssh_tunnel(database.id)
+                if existing_ssh_tunnel is None:
+                    # We couldn't find an existing tunnel so we need to create one
+                    SSHTunnelDAO.create(
+                        {
+                            **self._properties.get("ssh_tunnel"),
+                            "database_id": database.id,
+                        },
+                        commit=False,
+                    )
+                else:
+                    # We found an existing tunnel so we need to update it
+                    ssh_tunnel_model = SSHTunnelDAO.find_by_id(existing_ssh_tunnel.id)
+                    SSHTunnelDAO.update(
+                        ssh_tunnel_model,
+                        self._properties.get("ssh_tunnel"),
+                        commit=False,
+                    )
+
             db.session.commit()
 
         except DAOUpdateFailedError as ex:
             logger.exception(ex.exception)
             raise DatabaseUpdateFailedError() from ex
+        except DAOCreateFailedError as ex:
+            logger.exception(ex.exception)
+            raise DatabaseUpdateFailedError() from ex
         return database
 
     @staticmethod
diff --git a/superset/databases/schemas.py b/superset/databases/schemas.py
index 9c2de0b18e..ff134eb0af 100644
--- a/superset/databases/schemas.py
+++ b/superset/databases/schemas.py
@@ -475,6 +475,7 @@ class DatabasePutSchema(Schema, DatabaseParametersSchemaMixin):
     )
     is_managed_externally = fields.Boolean(allow_none=True, default=False)
     external_url = fields.String(allow_none=True)
+    ssh_tunnel = fields.Nested(DatabaseSSHTunnel, allow_none=True)
 
 
 class DatabaseTestConnectionSchema(Schema, DatabaseParametersSchemaMixin):
diff --git a/tests/integration_tests/databases/api_tests.py b/tests/integration_tests/databases/api_tests.py
index 7821900c87..2dd5559219 100644
--- a/tests/integration_tests/databases/api_tests.py
+++ b/tests/integration_tests/databases/api_tests.py
@@ -35,6 +35,7 @@ from sqlalchemy.sql import func
 
 from superset import db, security_manager
 from superset.connectors.sqla.models import SqlaTable
+from superset.databases.ssh_tunnel.models import SSHTunnel
 from superset.databases.utils import make_url_safe
 from superset.db_engine_specs.mysql import MySQLEngineSpec
 from superset.db_engine_specs.postgres import PostgresEngineSpec
@@ -280,7 +281,12 @@ class TestDatabaseApi(SupersetTestCase):
         db.session.delete(model)
         db.session.commit()
 
-    def test_create_database_with_ssh_tunnel(self):
+    @mock.patch(
+        "superset.databases.commands.test_connection.TestConnectionDatabaseCommand.run",
+    )
+    def test_create_database_with_ssh_tunnel(
+        self, mock_test_connection_database_command_run
+    ):
         """
         Database API: Test create with SSH Tunnel
         """
@@ -290,13 +296,12 @@ class TestDatabaseApi(SupersetTestCase):
             return
         ssh_tunnel_properties = {
             "server_address": "123.132.123.1",
-            "bind_host": "localhost",
-            "bind_port": "5432",
+            "server_port": 8080,
             "username": "foo",
             "password": "bar",
         }
         database_data = {
-            "database_name": "test-create-db-with-ssh-tunnel",
+            "database_name": "test-db-with-ssh-tunnel",
             "sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
             "ssh_tunnel": ssh_tunnel_properties,
         }
@@ -305,11 +310,138 @@ class TestDatabaseApi(SupersetTestCase):
         rv = self.client.post(uri, json=database_data)
         response = json.loads(rv.data.decode("utf-8"))
         self.assertEqual(rv.status_code, 201)
+        model_ssh_tunnel = (
+            db.session.query(SSHTunnel)
+            .filter(SSHTunnel.database_id == response.get("id"))
+            .one()
+        )
+        self.assertEqual(model_ssh_tunnel.database_id, response.get("id"))
         # Cleanup
         model = db.session.query(Database).get(response.get("id"))
         db.session.delete(model)
         db.session.commit()
 
+    @mock.patch(
+        "superset.databases.commands.test_connection.TestConnectionDatabaseCommand.run",
+    )
+    @mock.patch(
+        "superset.models.core.Database.get_all_schema_names",
+    )
+    def test_update_database_with_ssh_tunnel(
+        self, mock_test_connection_database_command_run, mock_get_all_schema_names
+    ):
+        """
+        Database API: Test update with SSH Tunnel
+        """
+        self.login(username="admin")
+        example_db = get_example_database()
+        if example_db.backend == "sqlite":
+            return
+        ssh_tunnel_properties = {
+            "server_address": "123.132.123.1",
+            "server_port": 8080,
+            "username": "foo",
+            "password": "bar",
+        }
+        database_data = {
+            "database_name": "test-db-with-ssh-tunnel",
+            "sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
+        }
+        database_data_with_ssh_tunnel = {
+            "database_name": "test-db-with-ssh-tunnel",
+            "sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
+            "ssh_tunnel": ssh_tunnel_properties,
+        }
+
+        uri = "api/v1/database/"
+        rv = self.client.post(uri, json=database_data)
+        response = json.loads(rv.data.decode("utf-8"))
+        self.assertEqual(rv.status_code, 201)
+
+        uri = "api/v1/database/{}".format(response.get("id"))
+        rv = self.client.put(uri, json=database_data_with_ssh_tunnel)
+        response_update = json.loads(rv.data.decode("utf-8"))
+        self.assertEqual(rv.status_code, 200)
+
+        model_ssh_tunnel = (
+            db.session.query(SSHTunnel)
+            .filter(SSHTunnel.database_id == response_update.get("id"))
+            .one()
+        )
+        self.assertEqual(model_ssh_tunnel.database_id, response_update.get("id"))
+        # Cleanup
+        model = db.session.query(Database).get(response.get("id"))
+        db.session.delete(model)
+        db.session.commit()
+
+    @mock.patch(
+        "superset.databases.commands.test_connection.TestConnectionDatabaseCommand.run",
+    )
+    def test_update_ssh_tunnel_via_database_api(
+        self, mock_test_connection_database_command_run
+    ):
+        """
+        Database API: Test updating an existing SSH Tunnel via the database API
+        """
+        self.login(username="admin")
+        example_db = get_example_database()
+
+        if example_db.backend == "sqlite":
+            return
+        initial_ssh_tunnel_properties = {
+            "server_address": "123.132.123.1",
+            "server_port": 8080,
+            "username": "foo",
+            "password": "bar",
+        }
+        updated_ssh_tunnel_properties = {
+            "server_address": "123.132.123.2",
+            "server_port": 8081,
+            "username": "Test",
+            "password": "bar",
+        }
+        database_data_with_ssh_tunnel = {
+            "database_name": "test-db-with-ssh-tunnel",
+            "sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
+            "ssh_tunnel": initial_ssh_tunnel_properties,
+        }
+        database_data_with_ssh_tunnel_update = {
+            "database_name": "test-db-with-ssh-tunnel",
+            "sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
+            "ssh_tunnel": updated_ssh_tunnel_properties,
+        }
+
+        uri = "api/v1/database/"
+        rv = self.client.post(uri, json=database_data_with_ssh_tunnel)
+        response = json.loads(rv.data.decode("utf-8"))
+        self.assertEqual(rv.status_code, 201)
+        model_ssh_tunnel = (
+            db.session.query(SSHTunnel)
+            .filter(SSHTunnel.database_id == response.get("id"))
+            .one()
+        )
+        self.assertEqual(model_ssh_tunnel.database_id, response.get("id"))
+        self.assertEqual(model_ssh_tunnel.username, "foo")
+        with mock.patch(
+            "superset.models.core.Database.get_all_schema_names",
+            return_value=["information_schema", "public"],
+        ):
+            uri = "api/v1/database/{}".format(response.get("id"))
+            rv = self.client.put(uri, json=database_data_with_ssh_tunnel_update)
+            response_update = json.loads(rv.data.decode("utf-8"))
+            self.assertEqual(rv.status_code, 200)
+            model_ssh_tunnel = (
+                db.session.query(SSHTunnel)
+                .filter(SSHTunnel.database_id == response_update.get("id"))
+                .one()
+            )
+            self.assertEqual(model_ssh_tunnel.database_id, response_update.get("id"))
+            self.assertEqual(model_ssh_tunnel.username, "Test")
+            # Cleanup
+            model = db.session.query(Database).get(response.get("id"))
+            db.session.delete(model)
+            db.session.commit()
+
     def test_create_database_invalid_configuration_method(self):
         """
         Database API: Test create with an invalid configuration method.

Reply via email to