This is an automated email from the ASF dual-hosted git repository. hugh pushed a commit to branch ssh-create-command in repository https://gitbox.apache.org/repos/asf/superset.git
commit 1fb44ddf78a30eb8dca8397036f4da6a98f2c332 Author: hughhhh <[email protected]> AuthorDate: Wed Nov 16 14:15:18 2022 -0500 save --- superset/databases/commands/create.py | 4 ++ superset/databases/commands/exceptions.py | 5 ++ superset/databases/ssh_tunnel/commands/create.py | 50 +++++++++++++++ tests/unit_tests/databases/dao/dao_tests.py | 7 ++- .../databases/ssh_tunnel/commands/create_test.py | 73 ++++++++++++++++++++++ .../databases/ssh_tunnel/commands/delete_test.py | 2 +- tests/unit_tests/databases/ssh_tunnel/dao_test.py | 39 ++++++++++++ 7 files changed, 177 insertions(+), 3 deletions(-) diff --git a/superset/databases/commands/create.py b/superset/databases/commands/create.py index 4dc8e8eda4..169db649af 100644 --- a/superset/databases/commands/create.py +++ b/superset/databases/commands/create.py @@ -71,6 +71,10 @@ class CreateDatabaseCommand(BaseCommand): database = DatabaseDAO.create(self._properties, commit=False) database.set_sqlalchemy_uri(database.sqlalchemy_uri) + # create ssh tunnel + # if self._properties.get("ssh_tunnel"): + # ssh_tunnel = SSHTunnelDAO.create(self._properties["ssh_tunnel"], commit=False) + # adding a new database we always want to force refresh schema list schemas = database.get_all_schema_names(cache=False) for schema in schemas: diff --git a/superset/databases/commands/exceptions.py b/superset/databases/commands/exceptions.py index a49abd3449..ce695dae35 100644 --- a/superset/databases/commands/exceptions.py +++ b/superset/databases/commands/exceptions.py @@ -176,3 +176,8 @@ class DatabaseOfflineError(SupersetErrorException): class InvalidParametersError(SupersetErrorsException): status = 422 + + +class SSHTunnelCreateFailedError(CommandException): + status = 500 + message = _("Creating SSH Tunnel failed for an unknown reason") diff --git a/superset/databases/ssh_tunnel/commands/create.py b/superset/databases/ssh_tunnel/commands/create.py new file mode 100644 index 0000000000..56eb575fd3 --- /dev/null +++ b/superset/databases/ssh_tunnel/commands/create.py @@ -0,0 +1,50 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import logging +from typing import Any, Dict, List, Optional + +from flask_appbuilder.models.sqla import Model +from marshmallow import ValidationError + +from superset.commands.base import BaseCommand +from superset.dao.exceptions import DAOCreateFailedError +from superset.databases.commands.exceptions import SSHTunnelCreateFailedError +from superset.databases.commands.test_connection import TestConnectionDatabaseCommand +from superset.databases.ssh_tunnel_dao import SSHTunnelDAO +from superset.exceptions import SupersetErrorsException +from superset.extensions import db, event_logger, security_manager + +logger = logging.getLogger(__name__) + + +class CreateSSHTunnelCommand(BaseCommand): + def __init__(self, database_id: int, data: Dict[str, Any]): + self._properties = data.copy() + self._properties["database_id"] = database_id + + def run(self) -> Model: + self.validate() + + try: + tunnel = SSHTunnelDAO.create(self._properties, commit=False) + except DAOCreateFailedError as ex: + raise SSHTunnelCreateFailedError() from ex + + return tunnel + + def validate(self) -> None: + pass diff --git a/tests/unit_tests/databases/dao/dao_tests.py b/tests/unit_tests/databases/dao/dao_tests.py index a5a828d79d..69a299d5d7 100644 --- a/tests/unit_tests/databases/dao/dao_tests.py +++ b/tests/unit_tests/databases/dao/dao_tests.py @@ -20,6 +20,9 @@ from typing import Iterator import pytest from sqlalchemy.orm.session import Session +from superset.databases.dao import DatabaseDAO +from superset.databases.ssh_tunnel.models import SSHTunnel + @pytest.fixture def session_with_data(session: Session) -> Iterator[Session]: @@ -50,7 +53,7 @@ def session_with_data(session: Session) -> Iterator[Session]: session.rollback() -def test_database_get_shh_tunnel(session_with_data: Session) -> None: +def test_database_get_ssh_tunnel(session_with_data: Session) -> None: from superset.databases.dao import DatabaseDAO from superset.databases.ssh_tunnel.models import SSHTunnel @@ -61,7 +64,7 @@ def test_database_get_shh_tunnel(session_with_data: Session) -> None: assert 1 == result["ssh_tunnel"].database_id -def test_database_get_shh_tunnel_not_found(session_with_data: Session) -> None: +def test_database_get_ssh_tunnel_not_found(session_with_data: Session) -> None: from superset.databases.dao import DatabaseDAO result = DatabaseDAO.get_ssh_tunnel(2) diff --git a/tests/unit_tests/databases/ssh_tunnel/commands/create_test.py b/tests/unit_tests/databases/ssh_tunnel/commands/create_test.py new file mode 100644 index 0000000000..f472d71689 --- /dev/null +++ b/tests/unit_tests/databases/ssh_tunnel/commands/create_test.py @@ -0,0 +1,73 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import Iterator + +import pytest +from sqlalchemy.orm.session import Session + +# @pytest.fixture +# def session_with_data(session: Session) -> Iterator[Session]: +# from superset.connectors.sqla.models import SqlaTable +# from superset.databases.ssh_tunnel.models import SSHTunnel +# from superset.models.core import Database + +# engine = session.get_bind() +# SqlaTable.metadata.create_all(engine) # pylint: disable=no-member + +# db = Database(database_name="my_database", sqlalchemy_uri="sqlite://") +# sqla_table = SqlaTable( +# table_name="my_sqla_table", +# columns=[], +# metrics=[], +# database=db, +# ) +# ssh_tunnel = SSHTunnel( +# database_id=db.id, +# database=db, +# ) + +# session.add(db) +# session.add(sqla_table) +# session.add(ssh_tunnel) +# session.flush() +# yield session +# session.rollback() + + +# def test_create_ssh_tunnel_command(session_with_data: Session) -> None: +# from superset.connectors.sqla.models import SqlaTable +# from superset.databases.dao import DatabaseDAO +# from superset.databases.ssh_tunnel.commands.delete import DeleteSSHTunnelCommand +# from superset.databases.ssh_tunnel.dao import SSHTunnelDAO +# from superset.databases.ssh_tunnel.models import SSHTunnel +# from superset.models.core import Database + +# db = Database(database_name="my_database", sqlalchemy_uri="sqlite://") + +# result = SSHTunnelDAO.create() + +# assert result +# assert isinstance(result["ssh_tunnel"], SSHTunnel) +# assert 1 == result["ssh_tunnel"].database_id + +# DeleteSSHTunnelCommand(1).run() + +# result = DatabaseDAO.get_ssh_tunnel(1) + +# assert result +# assert result["ssh_tunnel"] is None diff --git a/tests/unit_tests/databases/ssh_tunnel/commands/delete_test.py b/tests/unit_tests/databases/ssh_tunnel/commands/delete_test.py index b0967228e2..c12b5d7216 100644 --- a/tests/unit_tests/databases/ssh_tunnel/commands/delete_test.py +++ b/tests/unit_tests/databases/ssh_tunnel/commands/delete_test.py @@ -50,7 +50,7 @@ def session_with_data(session: Session) -> Iterator[Session]: session.rollback() -def test_delete_shh_tunnel_command(session_with_data: Session) -> None: +def test_delete_ssh_tunnel_command(session_with_data: Session) -> None: from superset.databases.dao import DatabaseDAO from superset.databases.ssh_tunnel.commands.delete import DeleteSSHTunnelCommand from superset.databases.ssh_tunnel.models import SSHTunnel diff --git a/tests/unit_tests/databases/ssh_tunnel/dao_test.py b/tests/unit_tests/databases/ssh_tunnel/dao_test.py new file mode 100644 index 0000000000..94e09ea07b --- /dev/null +++ b/tests/unit_tests/databases/ssh_tunnel/dao_test.py @@ -0,0 +1,39 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import Iterator + +import pytest + + +def test_create_ssh_tunnel(): + from superset.databases.ssh_tunnel.dao import SSHTunnelDAO + from superset.models.core import Database + + db = Database(database_name="my_database", sqlalchemy_uri="sqlite://") + + properties = { + "database_id": db.id, + "server_address": "123.132.123.1", + "server_port": "3005", + "username": "foo", + "password": "bar", + } + + result = SSHTunnelDAO.create(properties, commit=True) + + assert result is not None
