This is an automated email from the ASF dual-hosted git repository.

hugh pushed a commit to branch ssh-create-command
in repository https://gitbox.apache.org/repos/asf/superset.git

commit 1fb44ddf78a30eb8dca8397036f4da6a98f2c332
Author: hughhhh <[email protected]>
AuthorDate: Wed Nov 16 14:15:18 2022 -0500

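    Add CreateSSHTunnelCommand and SSHTunnelCreateFailedError

    Add a command that creates an SSH tunnel for a database, sketch (commented
    out) SSH tunnel creation in CreateDatabaseCommand, fix the "shh" -> "ssh"
    typo in test names, and add a test for SSHTunnelDAO.create.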
---
 superset/databases/commands/create.py              |  4 ++
 superset/databases/commands/exceptions.py          |  5 ++
 superset/databases/ssh_tunnel/commands/create.py   | 50 +++++++++++++++
 tests/unit_tests/databases/dao/dao_tests.py        |  7 ++-
 .../databases/ssh_tunnel/commands/create_test.py   | 73 ++++++++++++++++++++++
 .../databases/ssh_tunnel/commands/delete_test.py   |  2 +-
 tests/unit_tests/databases/ssh_tunnel/dao_test.py  | 39 ++++++++++++
 7 files changed, 177 insertions(+), 3 deletions(-)

diff --git a/superset/databases/commands/create.py b/superset/databases/commands/create.py
index 4dc8e8eda4..169db649af 100644
--- a/superset/databases/commands/create.py
+++ b/superset/databases/commands/create.py
@@ -71,6 +71,10 @@ class CreateDatabaseCommand(BaseCommand):
             database = DatabaseDAO.create(self._properties, commit=False)
             database.set_sqlalchemy_uri(database.sqlalchemy_uri)
 
+            # create ssh tunnel
+            # if self._properties.get("ssh_tunnel"):
+            #     ssh_tunnel = SSHTunnelDAO.create(self._properties["ssh_tunnel"], commit=False)
+
             # adding a new database we always want to force refresh schema list
             schemas = database.get_all_schema_names(cache=False)
             for schema in schemas:
diff --git a/superset/databases/commands/exceptions.py b/superset/databases/commands/exceptions.py
index a49abd3449..ce695dae35 100644
--- a/superset/databases/commands/exceptions.py
+++ b/superset/databases/commands/exceptions.py
@@ -176,3 +176,9 @@ class DatabaseOfflineError(SupersetErrorException):
 
 class InvalidParametersError(SupersetErrorsException):
     status = 422
+
+
+class SSHTunnelCreateFailedError(CommandException):
+    """Raised when creating an SSH tunnel fails."""
+    status = 500
+    message = _("Creating SSH Tunnel failed for an unknown reason")
diff --git a/superset/databases/ssh_tunnel/commands/create.py b/superset/databases/ssh_tunnel/commands/create.py
new file mode 100644
index 0000000000..56eb575fd3
--- /dev/null
+++ b/superset/databases/ssh_tunnel/commands/create.py
@@ -0,0 +1,58 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import logging
+from typing import Any, Dict
+
+from flask_appbuilder.models.sqla import Model
+
+from superset.commands.base import BaseCommand
+from superset.dao.exceptions import DAOCreateFailedError
+from superset.databases.commands.exceptions import SSHTunnelCreateFailedError
+from superset.databases.ssh_tunnel.dao import SSHTunnelDAO
+
+logger = logging.getLogger(__name__)
+
+
+class CreateSSHTunnelCommand(BaseCommand):
+    """Create the SSH tunnel for the database identified by ``database_id``.
+
+    The tunnel is created with ``commit=False``; committing is left to the caller.
+    """
+
+    def __init__(self, database_id: int, data: Dict[str, Any]):
+        # Copy so that adding ``database_id`` does not mutate the caller's dict.
+        self._properties = data.copy()
+        self._properties["database_id"] = database_id
+
+    def run(self) -> Model:
+        """Validate the payload and create the tunnel without committing it.
+
+        :returns: the model instance returned by ``SSHTunnelDAO.create``
+        :raises SSHTunnelCreateFailedError: if ``SSHTunnelDAO.create`` fails
+        """
+        self.validate()
+
+        try:
+            tunnel = SSHTunnelDAO.create(self._properties, commit=False)
+        except DAOCreateFailedError as ex:
+            raise SSHTunnelCreateFailedError() from ex
+
+        return tunnel
+
+    def validate(self) -> None:
+        # TODO: validate the tunnel payload; no checks are performed yet.
+        pass
diff --git a/tests/unit_tests/databases/dao/dao_tests.py b/tests/unit_tests/databases/dao/dao_tests.py
index a5a828d79d..69a299d5d7 100644
--- a/tests/unit_tests/databases/dao/dao_tests.py
+++ b/tests/unit_tests/databases/dao/dao_tests.py
@@ -50,7 +50,7 @@ def session_with_data(session: Session) -> Iterator[Session]:
     session.rollback()
 
 
-def test_database_get_shh_tunnel(session_with_data: Session) -> None:
+def test_database_get_ssh_tunnel(session_with_data: Session) -> None:
     from superset.databases.dao import DatabaseDAO
     from superset.databases.ssh_tunnel.models import SSHTunnel
 
@@ -61,7 +61,7 @@ def test_database_get_shh_tunnel(session_with_data: Session) -> None:
     assert 1 == result["ssh_tunnel"].database_id
 
 
-def test_database_get_shh_tunnel_not_found(session_with_data: Session) -> None:
+def test_database_get_ssh_tunnel_not_found(session_with_data: Session) -> None:
     from superset.databases.dao import DatabaseDAO
 
     result = DatabaseDAO.get_ssh_tunnel(2)
diff --git a/tests/unit_tests/databases/ssh_tunnel/commands/create_test.py b/tests/unit_tests/databases/ssh_tunnel/commands/create_test.py
new file mode 100644
index 0000000000..f472d71689
--- /dev/null
+++ b/tests/unit_tests/databases/ssh_tunnel/commands/create_test.py
@@ -0,0 +1,39 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+
+def test_create_ssh_tunnel_command() -> None:
+    """``CreateSSHTunnelCommand.run`` returns an ``SSHTunnel`` for the database."""
+    from superset.databases.ssh_tunnel.commands.create import CreateSSHTunnelCommand
+    from superset.databases.ssh_tunnel.models import SSHTunnel
+    from superset.models.core import Database
+
+    # The Database is never added to a session, so give it an explicit id;
+    # otherwise ``database.id`` would be None.
+    database = Database(id=1, database_name="my_database", sqlalchemy_uri="sqlite://")
+
+    properties = {
+        "server_address": "123.132.123.1",
+        "server_port": "3005",
+        "username": "foo",
+        "password": "bar",
+    }
+
+    result = CreateSSHTunnelCommand(database.id, properties).run()
+
+    assert isinstance(result, SSHTunnel)
+    assert result.database_id == database.id
diff --git a/tests/unit_tests/databases/ssh_tunnel/commands/delete_test.py b/tests/unit_tests/databases/ssh_tunnel/commands/delete_test.py
index b0967228e2..c12b5d7216 100644
--- a/tests/unit_tests/databases/ssh_tunnel/commands/delete_test.py
+++ b/tests/unit_tests/databases/ssh_tunnel/commands/delete_test.py
@@ -50,7 +50,7 @@ def session_with_data(session: Session) -> Iterator[Session]:
     session.rollback()
 
 
-def test_delete_shh_tunnel_command(session_with_data: Session) -> None:
+def test_delete_ssh_tunnel_command(session_with_data: Session) -> None:
     from superset.databases.dao import DatabaseDAO
     from superset.databases.ssh_tunnel.commands.delete import DeleteSSHTunnelCommand
     from superset.databases.ssh_tunnel.models import SSHTunnel
diff --git a/tests/unit_tests/databases/ssh_tunnel/dao_test.py b/tests/unit_tests/databases/ssh_tunnel/dao_test.py
new file mode 100644
index 0000000000..94e09ea07b
--- /dev/null
+++ b/tests/unit_tests/databases/ssh_tunnel/dao_test.py
@@ -0,0 +1,41 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+
+def test_create_ssh_tunnel() -> None:
+    """``SSHTunnelDAO.create`` builds an ``SSHTunnel`` from the given properties."""
+    from superset.databases.ssh_tunnel.dao import SSHTunnelDAO
+    from superset.databases.ssh_tunnel.models import SSHTunnel
+    from superset.models.core import Database
+
+    # The Database is never added to a session, so give it an explicit id;
+    # otherwise ``database.id`` would be None.
+    database = Database(id=1, database_name="my_database", sqlalchemy_uri="sqlite://")
+
+    properties = {
+        "database_id": database.id,
+        "server_address": "123.132.123.1",
+        "server_port": "3005",
+        "username": "foo",
+        "password": "bar",
+    }
+
+    # The test only inspects the returned model, so nothing is committed.
+    result = SSHTunnelDAO.create(properties, commit=False)
+
+    assert isinstance(result, SSHTunnel)
+    assert result.database_id == database.id

Reply via email to