This is an automated email from the ASF dual-hosted git repository.

willbarrett pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-superset.git


The following commit(s) were added to refs/heads/master by this push:
     new 8a3ac70  feat(databases): test connection api (#10723)
8a3ac70 is described below

commit 8a3ac70c0644d8d62aee7f75a54b223d33d751d7
Author: Lily Kuang <[email protected]>
AuthorDate: Wed Sep 9 13:37:48 2020 -0700

    feat(databases): test connection api (#10723)
    
    * test connection api on databases
    
    * update test connection tests
    
    * update database api test and open api description
    
    * moved test connection to commands
    
    * update error message
    
    * fix isort
    
    * fix mypy
    
    * fix black
    
    * fix mypy pre commit
---
 superset/app.py                                |  2 +-
 superset/databases/api.py                      | 99 ++++++++++++++++++++++++--
 superset/databases/commands/exceptions.py      |  5 ++
 superset/databases/commands/test_connection.py | 67 +++++++++++++++++
 superset/databases/dao.py                      | 21 +++++-
 superset/databases/schemas.py                  | 23 +++++-
 superset/tasks/schedules.py                    |  2 +-
 superset/views/base_api.py                     | 25 +++----
 superset/views/core.py                         |  2 +-
 tests/databases/api_tests.py                   | 93 +++++++++++++++++++++++-
 10 files changed, 316 insertions(+), 23 deletions(-)

diff --git a/superset/app.py b/superset/app.py
index 5280922..e073f45 100644
--- a/superset/app.py
+++ b/superset/app.py
@@ -149,8 +149,8 @@ class SupersetAppInitializer:
             AlertLogModelView,
             AlertModelView,
             AlertObservationModelView,
-            ValidatorInlineView,
             SQLObserverInlineView,
+            ValidatorInlineView,
         )
         from superset.views.annotations import (
             AnnotationLayerModelView,
diff --git a/superset/databases/api.py b/superset/databases/api.py
index c3355f5..5a65b84 100644
--- a/superset/databases/api.py
+++ b/superset/databases/api.py
@@ -20,8 +20,15 @@ from typing import Any, Optional
 from flask import g, request, Response
 from flask_appbuilder.api import expose, protect, rison, safe
 from flask_appbuilder.models.sqla.interface import SQLAInterface
+from flask_babel import gettext as _
 from marshmallow import ValidationError
-from sqlalchemy.exc import NoSuchTableError, OperationalError, SQLAlchemyError
+from sqlalchemy.engine.url import make_url
+from sqlalchemy.exc import (
+    NoSuchModuleError,
+    NoSuchTableError,
+    OperationalError,
+    SQLAlchemyError,
+)
 
 from superset import event_logger
 from superset.constants import RouteMethod
@@ -33,8 +40,10 @@ from superset.databases.commands.exceptions import (
     DatabaseDeleteFailedError,
     DatabaseInvalidError,
     DatabaseNotFoundError,
+    DatabaseSecurityUnsafeError,
     DatabaseUpdateFailedError,
 )
+from superset.databases.commands.test_connection import 
TestConnectionDatabaseCommand
 from superset.databases.commands.update import UpdateDatabaseCommand
 from superset.databases.dao import DatabaseDAO
 from superset.databases.decorators import check_datasource_access
@@ -44,6 +53,7 @@ from superset.databases.schemas import (
     DatabasePostSchema,
     DatabasePutSchema,
     DatabaseRelatedObjectsResponse,
+    DatabaseTestConnectionSchema,
     SchemasResponseSchema,
     SelectStarResponseSchema,
     TableMetadataResponseSchema,
@@ -65,6 +75,7 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
         "table_metadata",
         "select_star",
         "schemas",
+        "test_connection",
         "related_objects",
     }
     class_permission_name = "DatabaseView"
@@ -343,7 +354,7 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
     @rison(database_schemas_query_schema)
     @statsd_metrics
     def schemas(self, pk: int, **kwargs: Any) -> FlaskResponse:
-        """ Get all schemas from a database
+        """Get all schemas from a database
         ---
         get:
           description: Get all schemas from a database
@@ -400,7 +411,7 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
     def table_metadata(
         self, database: Database, table_name: str, schema_name: str
     ) -> FlaskResponse:
-        """ Table schema info
+        """Table schema info
         ---
         get:
           description: Get database table metadata
@@ -457,7 +468,7 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
     def select_star(
         self, database: Database, table_name: str, schema_name: Optional[str] 
= None
     ) -> FlaskResponse:
-        """ Table schema info
+        """Table schema info
         ---
         get:
           description: Get database select star for table
@@ -506,6 +517,98 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
         self.incr_stats("success", self.select_star.__name__)
         return self.response(200, result=result)
 
+    @expose("/test_connection", methods=["POST"])
+    @protect()
+    @safe
+    @event_logger.log_this
+    @statsd_metrics
+    def test_connection(  # pylint: disable=too-many-return-statements
+        self,
+    ) -> FlaskResponse:
+        """Tests a database connection
+        ---
+        post:
+          description: >-
+            Tests a database connection
+          requestBody:
+            description: Database schema
+            required: true
+            content:
+              application/json:
+                schema:
+                  type: object
+                  properties:
+                    database_name:
+                      type: string
+                    encrypted_extra:
+                      type: string
+                    extra:
+                      type: string
+                    impersonate_user:
+                      type: boolean
+                    server_cert:
+                      type: string
+                    sqlalchemy_uri:
+                      type: string
+                  required:
+                    - sqlalchemy_uri
+          responses:
+            200:
+              description: Database Test Connection
+              content:
+                application/json:
+                  schema:
+                    type: object
+                    properties:
+                      message:
+                        type: string
+            400:
+              $ref: '#/components/responses/400'
+            422:
+              $ref: '#/components/responses/422'
+            500:
+              $ref: '#/components/responses/500'
+        """
+        if not request.is_json:
+            return self.response_400(message="Request is not JSON")
+        try:
+            item = DatabaseTestConnectionSchema().load(request.json)
+        # This validates custom Schema with custom validations
+        except ValidationError as error:
+            return self.response_400(message=error.messages)
+        try:
+            TestConnectionDatabaseCommand(g.user, item).run()
+            return self.response(200, message="OK")
+        except (NoSuchModuleError, ModuleNotFoundError):
+            logger.info("Invalid driver")
+            driver_name = make_url(item.get("sqlalchemy_uri")).drivername
+            return self.response(
+                400,
+                # Interpolate after translation so the msgid stays constant.
+                message=_(
+                    "Could not load database driver: %(driver_name)s",
+                    driver_name=driver_name,
+                ),
+                driver_name=driver_name,
+            )
+        except DatabaseSecurityUnsafeError as ex:
+            # Send the message text; the exception object is not JSON-serializable.
+            return self.response_422(message=str(ex.message))
+        except OperationalError:
+            logger.warning("Connection failed")
+            return self.response(
+                500,
+                message=_("Connection failed, please check your connection settings"),
+            )
+        except Exception as ex:  # pylint: disable=broad-except
+            # Keep the traceback: the user is told to check the logs for details.
+            logger.exception("Unexpected error %s", type(ex).__name__)
+            return self.response_400(
+                message=_(
+                    "Unexpected error occurred, please check your logs for details"
+                )
+            )
+
     @expose("/<int:pk>/related_objects/", methods=["GET"])
     @protect()
     @safe
diff --git a/superset/databases/commands/exceptions.py b/superset/databases/commands/exceptions.py
index 66a3245..51d1660 100644
--- a/superset/databases/commands/exceptions.py
+++ b/superset/databases/commands/exceptions.py
@@ -24,6 +24,7 @@ from superset.commands.exceptions import (
     DeleteFailedError,
     UpdateFailedError,
 )
+from superset.security.analytics_db_safety import DBSecurityException
 
 
 class DatabaseInvalidError(CommandInvalidError):
@@ -109,3 +110,7 @@ class DatabaseDeleteDatasetsExistFailedError(DeleteFailedError):
 
 class DatabaseDeleteFailedError(DeleteFailedError):
     message = _("Database could not be deleted.")
+
+
+class DatabaseSecurityUnsafeError(DBSecurityException):  # API maps this to HTTP 422
+    message = _("Stopped an unsafe database connection")
diff --git a/superset/databases/commands/test_connection.py b/superset/databases/commands/test_connection.py
new file mode 100644
index 0000000..3bcd5b0
--- /dev/null
+++ b/superset/databases/commands/test_connection.py
@@ -0,0 +1,83 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import logging
+from contextlib import closing
+from typing import Any, Dict, Optional
+
+from flask_appbuilder.security.sqla.models import User
+from sqlalchemy import select
+
+from superset.commands.base import BaseCommand
+from superset.databases.commands.exceptions import DatabaseSecurityUnsafeError
+from superset.databases.dao import DatabaseDAO
+from superset.models.core import Database
+from superset.security.analytics_db_safety import DBSecurityException
+
+logger = logging.getLogger(__name__)
+
+
+class TestConnectionDatabaseCommand(BaseCommand):
+    """Try to connect to a database using the settings submitted by a user.
+
+    If ``database_name`` names a saved database and the submitted URI equals that
+    database's password-masked URI, the saved (decrypted) URI is used instead, so
+    the stored password is what gets tested.
+    """
+
+    def __init__(self, user: User, data: Dict[str, Any]):
+        self._actor = user
+        self._properties = data.copy()
+        self._model: Optional[Database] = None
+
+    def run(self) -> None:
+        """Open a connection and run ``SELECT 1``.
+
+        :raises DatabaseSecurityUnsafeError: if a DBSecurityException is raised;
+            any other error (e.g. a missing driver) propagates unchanged.
+        """
+        self.validate()
+        try:
+            uri = self._properties.get("sqlalchemy_uri", "")
+            if self._model and uri == self._model.safe_sqlalchemy_uri():
+                uri = self._model.sqlalchemy_uri_decrypted
+
+            # ``extra``/``encrypted_extra`` arrive as JSON text (the API schema
+            # declares them as strings); ``json.dumps`` would double-encode them.
+            database = DatabaseDAO.build_db_for_connection_test(
+                server_cert=self._properties.get("server_cert", ""),
+                extra=self._properties.get("extra", "{}"),
+                impersonate_user=self._properties.get("impersonate_user", False),
+                encrypted_extra=self._properties.get("encrypted_extra", "{}"),
+            )
+            if database is None:
+                # Fail loudly here; otherwise ``engine`` below would be unbound.
+                raise RuntimeError("Could not build the database to test")
+            database.set_sqlalchemy_uri(uri)
+            database.db_engine_spec.mutate_db_for_connection_test(database)
+            username = self._actor.username if self._actor is not None else None
+            engine = database.get_sqla_engine(user_name=username)
+            with closing(engine.connect()) as conn:
+                conn.scalar(select([1]))
+        except DBSecurityException as ex:
+            logger.warning(ex)
+            raise DatabaseSecurityUnsafeError() from ex
+
+    def validate(self) -> None:
+        """Load the saved database named by ``database_name``, if one was given."""
+        database_name = self._properties.get("database_name")
+        if database_name is not None:
+            self._model = DatabaseDAO.get_database_by_name(database_name)
diff --git a/superset/databases/dao.py b/superset/databases/dao.py
index 804ac12..2e89ad0 100644
--- a/superset/databases/dao.py
+++ b/superset/databases/dao.py
@@ -15,7 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 import logging
-from typing import Any, Dict
+from typing import Any, Dict, Optional
 
 from superset.dao.base import BaseDAO
 from superset.databases.filters import DatabaseFilter
@@ -45,6 +45,31 @@ class DatabaseDAO(BaseDAO):
         )
         return not db.session.query(database_query.exists()).scalar()
 
+    @staticmethod
+    def get_database_by_name(database_name: str) -> Optional[Database]:
+        """Return the database with this exact name, or None if there is none."""
+        return (
+            db.session.query(Database)
+            .filter(Database.database_name == database_name)
+            .one_or_none()
+        )
+
+    @staticmethod
+    def build_db_for_connection_test(
+        server_cert: str, extra: str, impersonate_user: bool, encrypted_extra: str
+    ) -> Database:
+        """Build a ``Database`` used only to test a connection.
+
+        The instance is not added to the session. ``extra`` and ``encrypted_extra``
+        are passed to the model as given (presumably already JSON text).
+        """
+        return Database(
+            server_cert=server_cert,
+            extra=extra,
+            impersonate_user=impersonate_user,
+            encrypted_extra=encrypted_extra,
+        )
+
     @classmethod
     def get_related_objects(cls, database_id: int) -> Dict[str, Any]:
         datasets = cls.find_by_id(database_id).tables
diff --git a/superset/databases/schemas.py b/superset/databases/schemas.py
index 2d6779d..859eebb 100644
--- a/superset/databases/schemas.py
+++ b/superset/databases/schemas.py
@@ -17,6 +17,7 @@
 import inspect
 import json
 
+from flask import current_app
 from flask_babel import lazy_gettext as _
 from marshmallow import fields, Schema
 from marshmallow.validate import Length, ValidationError
@@ -24,7 +25,6 @@ from sqlalchemy import MetaData
 from sqlalchemy.engine.url import make_url
 from sqlalchemy.exc import ArgumentError
 
-from superset import app
 from superset.exceptions import CertificateException
 from superset.utils.core import markdown, parse_ssl_cert
 
@@ -142,7 +142,7 @@ def sqlalchemy_uri_validator(value: str) -> str:
                 )
             ]
         )
-    if app.config["PREVENT_UNSAFE_DB_CONNECTIONS"] and value:
+    if current_app.config.get("PREVENT_UNSAFE_DB_CONNECTIONS", True) and value:
         if value.startswith("sqlite"):
             raise ValidationError(
                 [
@@ -291,6 +291,27 @@ class DatabasePutSchema(Schema):
     )
 
 
+class DatabaseTestConnectionSchema(Schema):
+    """Payload accepted by ``POST /api/v1/database/test_connection``."""
+
+    database_name = fields.String(
+        description=database_name_description, allow_none=True, validate=Length(1, 250),
+    )
+    impersonate_user = fields.Boolean(description=impersonate_user_description)
+    extra = fields.String(description=extra_description, validate=extra_validator)
+    encrypted_extra = fields.String(
+        description=encrypted_extra_description, validate=encrypted_extra_validator
+    )
+    server_cert = fields.String(
+        description=server_cert_description, validate=server_cert_validator
+    )
+    sqlalchemy_uri = fields.String(
+        description=sqlalchemy_uri_description,
+        required=True,
+        validate=[Length(1, 1024), sqlalchemy_uri_validator],
+    )
+
+
 class TableMetadataOptionsResponseSchema(Schema):
     deferrable = fields.Bool()
     initially = fields.Bool()
diff --git a/superset/tasks/schedules.py b/superset/tasks/schedules.py
index d1fabc9..9643f09 100644
--- a/superset/tasks/schedules.py
+++ b/superset/tasks/schedules.py
@@ -70,8 +70,8 @@ from superset.utils.urls import get_url_path
 
 if TYPE_CHECKING:
     # pylint: disable=unused-import
-    from werkzeug.datastructures import TypeConversionDict
     from flask_appbuilder.security.sqla.models import User
+    from werkzeug.datastructures import TypeConversionDict
 
 # Globals
 config = app.config
diff --git a/superset/views/base_api.py b/superset/views/base_api.py
index 6ee2016..458fa83 100644
--- a/superset/views/base_api.py
+++ b/superset/views/base_api.py
@@ -91,24 +91,25 @@ class BaseSupersetModelRestApi(ModelRestApi):
 
     csrf_exempt = False
     method_permission_name = {
-        "get_list": "list",
-        "get": "show",
+        "bulk_delete": "delete",
+        "data": "list",
+        "delete": "delete",
+        "distinct": "list",
         "export": "mulexport",
+        "get": "show",
+        "get_list": "list",
+        "info": "list",
         "post": "add",
         "put": "edit",
-        "delete": "delete",
-        "bulk_delete": "delete",
-        "info": "list",
-        "related": "list",
-        "distinct": "list",
-        "thumbnail": "list",
         "refresh": "edit",
-        "data": "list",
-        "viz_types": "list",
+        "related": "list",
         "related_objects": "list",
-        "table_metadata": "list",
-        "select_star": "list",
         "schemas": "list",
+        "select_star": "list",
+        "table_metadata": "list",
+        "test_connection": "post",
+        "thumbnail": "list",
+        "viz_types": "list",
     }
 
     order_rel_fields: Dict[str, Tuple[str, str]] = {}
diff --git a/superset/views/core.py b/superset/views/core.py
index a96ce15..1ee501e 100755
--- a/superset/views/core.py
+++ b/superset/views/core.py
@@ -1162,7 +1162,7 @@ class Superset(BaseSupersetView):  # pylint: disable=too-many-public-methods
             logger.warning("Stopped an unsafe database connection")
             return json_error_response(_(str(ex)), 400)
         except Exception as ex:  # pylint: disable=broad-except
-            logger.error("Unexpected error %s", type(ex).__name__)
+            logger.warning("Unexpected error %s", type(ex).__name__)
             return json_error_response(
                 _("Unexpected error occurred, please check your logs for 
details"), 400
             )
diff --git a/tests/databases/api_tests.py b/tests/databases/api_tests.py
index 6d82202..e07b7ac 100644
--- a/tests/databases/api_tests.py
+++ b/tests/databases/api_tests.py
@@ -21,13 +21,13 @@ import json
 import prison
 from sqlalchemy.sql import func
 
-import tests.test_app
 from superset import db, security_manager
 from superset.connectors.sqla.models import SqlaTable
 from superset.models.core import Database
 from superset.utils.core import get_example_database, get_main_database
 from tests.base_tests import SupersetTestCase
 from tests.fixtures.certificates import ssl_certificate
+from tests.test_app import app
 
 
 class TestDatabaseApi(SupersetTestCase):
@@ -652,6 +652,106 @@ class TestDatabaseApi(SupersetTestCase):
         )
         self.assertEqual(rv.status_code, 400)
 
+    def test_test_connection(self):
+        """
+        Database API: Test test connection
+        """
+        # Temporarily allow sqlite dbs; the original setting is restored on cleanup
+        self.addCleanup(
+            app.config.update,
+            PREVENT_UNSAFE_DB_CONNECTIONS=app.config["PREVENT_UNSAFE_DB_CONNECTIONS"],
+        )
+        app.config["PREVENT_UNSAFE_DB_CONNECTIONS"] = False
+        self.login("admin")
+        example_db = get_example_database()
+        # validate that the endpoint works with the password-masked sqlalchemy uri
+        data = {
+            "sqlalchemy_uri": example_db.safe_sqlalchemy_uri(),
+            "database_name": "examples",
+            "impersonate_user": False,
+        }
+        url = "api/v1/database/test_connection"
+        rv = self.post_assert_metric(url, data, "test_connection")
+        self.assertEqual(rv.status_code, 200)
+        self.assertEqual(rv.headers["Content-Type"], "application/json; charset=utf-8")
+
+        # validate that the endpoint works with the decrypted sqlalchemy uri
+        data = {
+            "sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
+            "database_name": "examples",
+            "impersonate_user": False,
+        }
+        rv = self.post_assert_metric(url, data, "test_connection")
+        self.assertEqual(rv.status_code, 200)
+        self.assertEqual(rv.headers["Content-Type"], "application/json; charset=utf-8")
+
+    def test_test_connection_failed(self):
+        """
+        Database API: Test test connection failed
+        """
+        self.login("admin")
+
+        data = {
+            "sqlalchemy_uri": "broken://url",
+            "database_name": "examples",
+            "impersonate_user": False,
+        }
+        url = "api/v1/database/test_connection"
+        rv = self.post_assert_metric(url, data, "test_connection")
+        self.assertEqual(rv.status_code, 400)
+        self.assertEqual(rv.headers["Content-Type"], "application/json; charset=utf-8")
+        response = json.loads(rv.data.decode("utf-8"))
+        expected_response = {
+            "driver_name": "broken",
+            "message": "Could not load database driver: broken",
+        }
+        self.assertEqual(response, expected_response)
+
+        # NOTE: assumes the pymssql driver is not installed in the test environment
+        data = {
+            "sqlalchemy_uri": "mssql+pymssql://url",
+            "database_name": "examples",
+            "impersonate_user": False,
+        }
+        rv = self.post_assert_metric(url, data, "test_connection")
+        self.assertEqual(rv.status_code, 400)
+        self.assertEqual(rv.headers["Content-Type"], "application/json; charset=utf-8")
+        response = json.loads(rv.data.decode("utf-8"))
+        expected_response = {
+            "driver_name": "mssql+pymssql",
+            "message": "Could not load database driver: mssql+pymssql",
+        }
+        self.assertEqual(response, expected_response)
+
+    def test_test_connection_unsafe_uri(self):
+        """
+        Database API: Test test connection with unsafe uri
+        """
+        self.login("admin")
+
+        self.addCleanup(
+            app.config.update,
+            PREVENT_UNSAFE_DB_CONNECTIONS=app.config["PREVENT_UNSAFE_DB_CONNECTIONS"],
+        )
+        app.config["PREVENT_UNSAFE_DB_CONNECTIONS"] = True
+        data = {
+            "sqlalchemy_uri": "sqlite:///home/superset/unsafe.db",
+            "database_name": "unsafe",
+            "impersonate_user": False,
+        }
+        url = "api/v1/database/test_connection"
+        rv = self.post_assert_metric(url, data, "test_connection")
+        self.assertEqual(rv.status_code, 400)
+        response = json.loads(rv.data.decode("utf-8"))
+        expected_response = {
+            "message": {
+                "sqlalchemy_uri": [
+                    "SQLite database cannot be used as a data source for security reasons."
+                ]
+            }
+        }
+        self.assertEqual(response, expected_response)
+
     def test_get_database_related_objects(self):
         """
         Database API: Test get chart and dashboard count related to a database

Reply via email to