This is an automated email from the ASF dual-hosted git repository.
willbarrett pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-superset.git
The following commit(s) were added to refs/heads/master by this push:
new 8a3ac70 feat(databases): test connection api (#10723)
8a3ac70 is described below
commit 8a3ac70c0644d8d62aee7f75a54b223d33d751d7
Author: Lily Kuang <[email protected]>
AuthorDate: Wed Sep 9 13:37:48 2020 -0700
feat(databases): test connection api (#10723)
* test connection api on databases
* update test connection tests
* update database api test and OpenAPI description
* moved test connection to commands
* update error message
* fix isort
* fix mypy
* fix black
* fix mypy pre commit
---
superset/app.py | 2 +-
superset/databases/api.py | 99 ++++++++++++++++++++++++--
superset/databases/commands/exceptions.py | 5 ++
superset/databases/commands/test_connection.py | 67 +++++++++++++++++
superset/databases/dao.py | 21 +++++-
superset/databases/schemas.py | 23 +++++-
superset/tasks/schedules.py | 2 +-
superset/views/base_api.py | 25 +++----
superset/views/core.py | 2 +-
tests/databases/api_tests.py | 93 +++++++++++++++++++++++-
10 files changed, 316 insertions(+), 23 deletions(-)
diff --git a/superset/app.py b/superset/app.py
index 5280922..e073f45 100644
--- a/superset/app.py
+++ b/superset/app.py
@@ -149,8 +149,8 @@ class SupersetAppInitializer:
AlertLogModelView,
AlertModelView,
AlertObservationModelView,
- ValidatorInlineView,
SQLObserverInlineView,
+ ValidatorInlineView,
)
from superset.views.annotations import (
AnnotationLayerModelView,
diff --git a/superset/databases/api.py b/superset/databases/api.py
index c3355f5..5a65b84 100644
--- a/superset/databases/api.py
+++ b/superset/databases/api.py
@@ -20,8 +20,15 @@ from typing import Any, Optional
from flask import g, request, Response
from flask_appbuilder.api import expose, protect, rison, safe
from flask_appbuilder.models.sqla.interface import SQLAInterface
+from flask_babel import gettext as _
from marshmallow import ValidationError
-from sqlalchemy.exc import NoSuchTableError, OperationalError, SQLAlchemyError
+from sqlalchemy.engine.url import make_url
+from sqlalchemy.exc import (
+ NoSuchModuleError,
+ NoSuchTableError,
+ OperationalError,
+ SQLAlchemyError,
+)
from superset import event_logger
from superset.constants import RouteMethod
@@ -33,8 +40,10 @@ from superset.databases.commands.exceptions import (
DatabaseDeleteFailedError,
DatabaseInvalidError,
DatabaseNotFoundError,
+ DatabaseSecurityUnsafeError,
DatabaseUpdateFailedError,
)
+from superset.databases.commands.test_connection import
TestConnectionDatabaseCommand
from superset.databases.commands.update import UpdateDatabaseCommand
from superset.databases.dao import DatabaseDAO
from superset.databases.decorators import check_datasource_access
@@ -44,6 +53,7 @@ from superset.databases.schemas import (
DatabasePostSchema,
DatabasePutSchema,
DatabaseRelatedObjectsResponse,
+ DatabaseTestConnectionSchema,
SchemasResponseSchema,
SelectStarResponseSchema,
TableMetadataResponseSchema,
@@ -65,6 +75,7 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
"table_metadata",
"select_star",
"schemas",
+ "test_connection",
"related_objects",
}
class_permission_name = "DatabaseView"
@@ -343,7 +354,7 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
@rison(database_schemas_query_schema)
@statsd_metrics
def schemas(self, pk: int, **kwargs: Any) -> FlaskResponse:
- """ Get all schemas from a database
+ """Get all schemas from a database
---
get:
description: Get all schemas from a database
@@ -400,7 +411,7 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
def table_metadata(
self, database: Database, table_name: str, schema_name: str
) -> FlaskResponse:
- """ Table schema info
+ """Table schema info
---
get:
description: Get database table metadata
@@ -457,7 +468,7 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
def select_star(
self, database: Database, table_name: str, schema_name: Optional[str]
= None
) -> FlaskResponse:
- """ Table schema info
+ """Table schema info
---
get:
description: Get database select star for table
@@ -506,6 +517,86 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
self.incr_stats("success", self.select_star.__name__)
return self.response(200, result=result)
+    @expose("/test_connection", methods=["POST"])
+    @protect()
+    @safe
+    @event_logger.log_this
+    @statsd_metrics
+    def test_connection(  # pylint: disable=too-many-return-statements
+        self,
+    ) -> FlaskResponse:
+        """Tests a database connection
+        ---
+        post:
+          description: >-
+            Tests a database connection
+          requestBody:
+            description: Database schema
+            required: true
+            content:
+              application/json:
+                schema:
+                  type: object
+                  properties:
+                    database_name:
+                      type: string
+                    encrypted_extra:
+                      type: string
+                    extra:
+                      type: string
+                    impersonate_user:
+                      type: boolean
+                    server_cert:
+                      type: string
+                    sqlalchemy_uri:
+                      type: string
+                  required:
+                    - sqlalchemy_uri
+          responses:
+            200:
+              description: Database Test Connection
+              content:
+                application/json:
+                  schema:
+                    type: object
+                    properties:
+                      message:
+                        type: string
+            400:
+              $ref: '#/components/responses/400'
+            422:
+              $ref: '#/components/responses/422'
+            500:
+              $ref: '#/components/responses/500'
+        """
+        if not request.is_json:
+            return self.response_400(message="Request is not JSON")
+        try:
+            # Runs the schema's field validators as well (e.g. the
+            # sqlalchemy_uri safety validator), reporting failures as 400.
+            item = DatabaseTestConnectionSchema().load(request.json)
+        except ValidationError as error:
+            return self.response_400(message=error.messages)
+        try:
+            TestConnectionDatabaseCommand(g.user, item).run()
+            return self.response(200, message="OK")
+        except (NoSuchModuleError, ModuleNotFoundError):
+            logger.info("Invalid driver")
+            driver_name = make_url(item.get("sqlalchemy_uri")).drivername
+            # Interpolate *after* translation so the msgid stays constant
+            # and can actually be found in the catalog.
+            return self.response(
+                400,
+                message=_(
+                    "Could not load database driver: %(driver_name)s",
+                    driver_name=driver_name,
+                ),
+                driver_name=driver_name,
+            )
+        except DatabaseSecurityUnsafeError as ex:
+            # Pass the text, not the exception object: the response body is
+            # JSON-encoded and an exception instance is not serializable.
+            return self.response_422(message=str(ex))
+        except OperationalError:
+            logger.warning("Connection failed")
+            return self.response(
+                500,
+                message=_(
+                    "Connection failed, please check your connection settings"
+                ),
+            )
+        except Exception as ex:  # pylint: disable=broad-except
+            logger.error("Unexpected error %s", type(ex).__name__)
+            return self.response_400(
+                message=_(
+                    "Unexpected error occurred, please check your logs for details"
+                )
+            )
+
@expose("/<int:pk>/related_objects/", methods=["GET"])
@protect()
@safe
diff --git a/superset/databases/commands/exceptions.py
b/superset/databases/commands/exceptions.py
index 66a3245..51d1660 100644
--- a/superset/databases/commands/exceptions.py
+++ b/superset/databases/commands/exceptions.py
@@ -24,6 +24,7 @@ from superset.commands.exceptions import (
DeleteFailedError,
UpdateFailedError,
)
+from superset.security.analytics_db_safety import DBSecurityException
class DatabaseInvalidError(CommandInvalidError):
@@ -109,3 +110,7 @@ class
DatabaseDeleteDatasetsExistFailedError(DeleteFailedError):
class DatabaseDeleteFailedError(DeleteFailedError):
message = _("Database could not be deleted.")
+
+
+class DatabaseSecurityUnsafeError(DBSecurityException):
+    """A database connection was rejected as unsafe by the safety checks."""
+
+    message = _("Stopped an unsafe database connection")
diff --git a/superset/databases/commands/test_connection.py
b/superset/databases/commands/test_connection.py
new file mode 100644
index 0000000..3bcd5b0
--- /dev/null
+++ b/superset/databases/commands/test_connection.py
@@ -0,0 +1,67 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import logging
+from contextlib import closing
+from typing import Any, Dict, Optional
+
+import simplejson as json
+from flask_appbuilder.security.sqla.models import User
+from sqlalchemy import select
+
+from superset.commands.base import BaseCommand
+from superset.databases.commands.exceptions import DatabaseSecurityUnsafeError
+from superset.databases.dao import DatabaseDAO
+from superset.models.core import Database
+from superset.security.analytics_db_safety import DBSecurityException
+
+logger = logging.getLogger(__name__)
+
+
+class TestConnectionDatabaseCommand(BaseCommand):
+    """Try to connect with the submitted database settings and run ``SELECT 1``.
+
+    A transient ``Database`` is built from the payload (it is not added to the
+    session here). Driver and connection errors raised while connecting are
+    left to propagate to the caller; ``DBSecurityException`` is logged and
+    re-raised as ``DatabaseSecurityUnsafeError``.
+    """
+
+    def __init__(self, user: User, data: Dict[str, Any]):
+        self._actor = user
+        self._properties = data.copy()
+        self._model: Optional[Database] = None
+
+    @staticmethod
+    def _as_json_str(value: Any) -> str:
+        """Return ``value`` as JSON text without double-encoding strings.
+
+        The request schema declares ``extra``/``encrypted_extra`` as strings
+        that already hold JSON, so strings pass through unchanged; any other
+        value (e.g. a dict from a programmatic caller) is serialized.
+        """
+        if isinstance(value, str):
+            return value
+        return json.dumps(value)
+
+    def run(self) -> None:
+        self.validate()
+        try:
+            uri = self._properties.get("sqlalchemy_uri", "")
+            # If the client echoed back the password-masked URI of an existing
+            # database, use the stored decrypted URI so authentication works.
+            if self._model and uri == self._model.safe_sqlalchemy_uri():
+                uri = self._model.sqlalchemy_uri_decrypted
+
+            database = DatabaseDAO.build_db_for_connection_test(
+                server_cert=self._properties.get("server_cert", ""),
+                extra=self._as_json_str(self._properties.get("extra", {})),
+                impersonate_user=self._properties.get("impersonate_user", False),
+                encrypted_extra=self._as_json_str(
+                    self._properties.get("encrypted_extra", {})
+                ),
+            )
+            if database is not None:
+                database.set_sqlalchemy_uri(uri)
+                database.db_engine_spec.mutate_db_for_connection_test(database)
+                username = self._actor.username if self._actor is not None else None
+                engine = database.get_sqla_engine(user_name=username)
+                with closing(engine.connect()) as conn:
+                    conn.scalar(select([1]))
+        except DBSecurityException as ex:
+            logger.warning(ex)
+            raise DatabaseSecurityUnsafeError() from ex
+
+    def validate(self) -> None:
+        """Look up an existing database with the submitted name, if any."""
+        database_name = self._properties.get("database_name")
+        if database_name is not None:
+            self._model = DatabaseDAO.get_database_by_name(database_name)
diff --git a/superset/databases/dao.py b/superset/databases/dao.py
index 804ac12..2e89ad0 100644
--- a/superset/databases/dao.py
+++ b/superset/databases/dao.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
-from typing import Any, Dict
+from typing import Any, Dict, Optional
from superset.dao.base import BaseDAO
from superset.databases.filters import DatabaseFilter
@@ -45,6 +45,25 @@ class DatabaseDAO(BaseDAO):
)
return not db.session.query(database_query.exists()).scalar()
+    @staticmethod
+    def get_database_by_name(database_name: str) -> Optional[Database]:
+        """Return the ``Database`` whose name is ``database_name``, or ``None``.
+
+        ``one_or_none`` raises if more than one row matches.
+        """
+        query = db.session.query(Database).filter(
+            Database.database_name == database_name
+        )
+        return query.one_or_none()
+
+    @staticmethod
+    def build_db_for_connection_test(
+        server_cert: str, extra: str, impersonate_user: bool, encrypted_extra: str
+    ) -> Database:
+        """Construct an unsaved ``Database`` holding connection-test settings.
+
+        The instance is only instantiated here, not added to the session, and
+        has no SQLAlchemy URI yet; the caller sets it before connecting.
+        ``extra`` and ``encrypted_extra`` are expected to be JSON text.
+        """
+        return Database(
+            server_cert=server_cert,
+            extra=extra,
+            impersonate_user=impersonate_user,
+            encrypted_extra=encrypted_extra,
+        )
+
@classmethod
def get_related_objects(cls, database_id: int) -> Dict[str, Any]:
datasets = cls.find_by_id(database_id).tables
diff --git a/superset/databases/schemas.py b/superset/databases/schemas.py
index 2d6779d..859eebb 100644
--- a/superset/databases/schemas.py
+++ b/superset/databases/schemas.py
@@ -17,6 +17,7 @@
import inspect
import json
+from flask import current_app
from flask_babel import lazy_gettext as _
from marshmallow import fields, Schema
from marshmallow.validate import Length, ValidationError
@@ -24,7 +25,6 @@ from sqlalchemy import MetaData
from sqlalchemy.engine.url import make_url
from sqlalchemy.exc import ArgumentError
-from superset import app
from superset.exceptions import CertificateException
from superset.utils.core import markdown, parse_ssl_cert
@@ -142,7 +142,7 @@ def sqlalchemy_uri_validator(value: str) -> str:
)
]
)
- if app.config["PREVENT_UNSAFE_DB_CONNECTIONS"] and value:
+ if current_app.config.get("PREVENT_UNSAFE_DB_CONNECTIONS", True) and value:
if value.startswith("sqlite"):
raise ValidationError(
[
@@ -291,6 +291,25 @@ class DatabasePutSchema(Schema):
)
+class DatabaseTestConnectionSchema(Schema):
+    """Payload for ``POST api/v1/database/test_connection``.
+
+    Only ``sqlalchemy_uri`` is required; it is length-checked and run through
+    ``sqlalchemy_uri_validator``.
+    """
+
+    database_name = fields.String(
+        description=database_name_description,
+        allow_none=True,
+        validate=Length(1, 250),
+    )
+    impersonate_user = fields.Boolean(description=impersonate_user_description)
+    extra = fields.String(
+        description=extra_description,
+        validate=extra_validator,
+    )
+    encrypted_extra = fields.String(
+        description=encrypted_extra_description,
+        validate=encrypted_extra_validator,
+    )
+    server_cert = fields.String(
+        description=server_cert_description,
+        validate=server_cert_validator,
+    )
+    sqlalchemy_uri = fields.String(
+        description=sqlalchemy_uri_description,
+        required=True,
+        validate=[Length(1, 1024), sqlalchemy_uri_validator],
+    )
+
+
class TableMetadataOptionsResponseSchema(Schema):
deferrable = fields.Bool()
initially = fields.Bool()
diff --git a/superset/tasks/schedules.py b/superset/tasks/schedules.py
index d1fabc9..9643f09 100644
--- a/superset/tasks/schedules.py
+++ b/superset/tasks/schedules.py
@@ -70,8 +70,8 @@ from superset.utils.urls import get_url_path
if TYPE_CHECKING:
# pylint: disable=unused-import
- from werkzeug.datastructures import TypeConversionDict
from flask_appbuilder.security.sqla.models import User
+ from werkzeug.datastructures import TypeConversionDict
# Globals
config = app.config
diff --git a/superset/views/base_api.py b/superset/views/base_api.py
index 6ee2016..458fa83 100644
--- a/superset/views/base_api.py
+++ b/superset/views/base_api.py
@@ -91,24 +91,25 @@ class BaseSupersetModelRestApi(ModelRestApi):
csrf_exempt = False
method_permission_name = {
- "get_list": "list",
- "get": "show",
+ "bulk_delete": "delete",
+ "data": "list",
+ "delete": "delete",
+ "distinct": "list",
"export": "mulexport",
+ "get": "show",
+ "get_list": "list",
+ "info": "list",
"post": "add",
"put": "edit",
- "delete": "delete",
- "bulk_delete": "delete",
- "info": "list",
- "related": "list",
- "distinct": "list",
- "thumbnail": "list",
"refresh": "edit",
- "data": "list",
- "viz_types": "list",
+ "related": "list",
"related_objects": "list",
- "table_metadata": "list",
- "select_star": "list",
"schemas": "list",
+ "select_star": "list",
+ "table_metadata": "list",
+ "test_connection": "post",
+ "thumbnail": "list",
+ "viz_types": "list",
}
order_rel_fields: Dict[str, Tuple[str, str]] = {}
diff --git a/superset/views/core.py b/superset/views/core.py
index a96ce15..1ee501e 100755
--- a/superset/views/core.py
+++ b/superset/views/core.py
@@ -1162,7 +1162,7 @@ class Superset(BaseSupersetView): # pylint:
disable=too-many-public-methods
logger.warning("Stopped an unsafe database connection")
return json_error_response(_(str(ex)), 400)
except Exception as ex: # pylint: disable=broad-except
- logger.error("Unexpected error %s", type(ex).__name__)
+ logger.warning("Unexpected error %s", type(ex).__name__)
return json_error_response(
_("Unexpected error occurred, please check your logs for
details"), 400
)
diff --git a/tests/databases/api_tests.py b/tests/databases/api_tests.py
index 6d82202..e07b7ac 100644
--- a/tests/databases/api_tests.py
+++ b/tests/databases/api_tests.py
@@ -21,13 +21,13 @@ import json
import prison
from sqlalchemy.sql import func
-import tests.test_app
from superset import db, security_manager
from superset.connectors.sqla.models import SqlaTable
from superset.models.core import Database
from superset.utils.core import get_example_database, get_main_database
from tests.base_tests import SupersetTestCase
from tests.fixtures.certificates import ssl_certificate
+from tests.test_app import app
class TestDatabaseApi(SupersetTestCase):
@@ -652,6 +652,97 @@ class TestDatabaseApi(SupersetTestCase):
)
self.assertEqual(rv.status_code, 400)
+    def test_test_connection(self):
+        """
+        Database API: Test test connection
+        """
+        self.login("admin")
+        example_db = get_example_database()
+        url = "api/v1/database/test_connection"
+        # Temporarily allow sqlite dbs. The previous value is restored in
+        # ``finally`` so this global config change cannot leak into other tests.
+        previous = app.config["PREVENT_UNSAFE_DB_CONNECTIONS"]
+        app.config["PREVENT_UNSAFE_DB_CONNECTIONS"] = False
+        try:
+            # validate that the endpoint works with the password-masked
+            # sqlalchemy uri
+            data = {
+                "sqlalchemy_uri": example_db.safe_sqlalchemy_uri(),
+                "database_name": "examples",
+                "impersonate_user": False,
+            }
+            rv = self.post_assert_metric(url, data, "test_connection")
+            self.assertEqual(rv.status_code, 200)
+            self.assertEqual(
+                rv.headers["Content-Type"], "application/json; charset=utf-8"
+            )
+
+            # validate that the endpoint works with the decrypted sqlalchemy uri
+            data = {
+                "sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
+                "database_name": "examples",
+                "impersonate_user": False,
+            }
+            rv = self.post_assert_metric(url, data, "test_connection")
+            self.assertEqual(rv.status_code, 200)
+            self.assertEqual(
+                rv.headers["Content-Type"], "application/json; charset=utf-8"
+            )
+        finally:
+            app.config["PREVENT_UNSAFE_DB_CONNECTIONS"] = previous
+
+    def test_test_connection_failed(self):
+        """
+        Database API: Test test connection failed
+        """
+        self.login("admin")
+        url = "api/v1/database/test_connection"
+
+        # an unknown dialect cannot be loaded
+        data = {
+            "sqlalchemy_uri": "broken://url",
+            "database_name": "examples",
+            "impersonate_user": False,
+        }
+        rv = self.post_assert_metric(url, data, "test_connection")
+        self.assertEqual(rv.status_code, 400)
+        self.assertEqual(rv.headers["Content-Type"], "application/json; charset=utf-8")
+        response = json.loads(rv.data.decode("utf-8"))
+        expected_response = {
+            "driver_name": "broken",
+            "message": "Could not load database driver: broken",
+        }
+        self.assertEqual(response, expected_response)
+
+        # NOTE(review): assumes the pymssql package is NOT installed in the
+        # test environment; if it is, this case will not hit the driver error.
+        data = {
+            "sqlalchemy_uri": "mssql+pymssql://url",
+            "database_name": "examples",
+            "impersonate_user": False,
+        }
+        rv = self.post_assert_metric(url, data, "test_connection")
+        self.assertEqual(rv.status_code, 400)
+        self.assertEqual(rv.headers["Content-Type"], "application/json; charset=utf-8")
+        response = json.loads(rv.data.decode("utf-8"))
+        expected_response = {
+            "driver_name": "mssql+pymssql",
+            "message": "Could not load database driver: mssql+pymssql",
+        }
+        self.assertEqual(response, expected_response)
+
+    def test_test_connection_unsafe_uri(self):
+        """
+        Database API: Test test connection with unsafe uri
+        """
+        self.login("admin")
+        url = "api/v1/database/test_connection"
+        # Force the safety check on regardless of what earlier tests did, and
+        # restore the previous value so the override does not leak.
+        previous = app.config["PREVENT_UNSAFE_DB_CONNECTIONS"]
+        app.config["PREVENT_UNSAFE_DB_CONNECTIONS"] = True
+        try:
+            data = {
+                "sqlalchemy_uri": "sqlite:///home/superset/unsafe.db",
+                "database_name": "unsafe",
+                "impersonate_user": False,
+            }
+            rv = self.post_assert_metric(url, data, "test_connection")
+            self.assertEqual(rv.status_code, 400)
+            response = json.loads(rv.data.decode("utf-8"))
+            expected_response = {
+                "message": {
+                    "sqlalchemy_uri": [
+                        "SQLite database cannot be used as a data source for "
+                        "security reasons."
+                    ]
+                }
+            }
+            self.assertEqual(response, expected_response)
+        finally:
+            app.config["PREVENT_UNSAFE_DB_CONNECTIONS"] = previous
+
def test_get_database_related_objects(self):
"""
Database API: Test get chart and dashboard count related to a database