This is an automated email from the ASF dual-hosted git repository.
beto pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-superset.git
The following commit(s) were added to refs/heads/master by this push:
new 2497048 feat: Add `validate_sql_json` endpoint for checking that a
given sql query is valid for the chosen database (#7422) (#7462)
2497048 is described below
commit 24970485cf05fdc4afa815c708dd38d95ebe389e
Author: Alex Berghage <[email protected]>
AuthorDate: Mon May 6 11:21:02 2019 -0600
feat: Add `validate_sql_json` endpoint for checking that a given sql query
is valid for the chosen database (#7422) (#7462)
merge from lyft-release-sp8 to master
---
docs/installation.rst | 40 ++++---
superset/config.py | 7 ++
superset/sql_validators/__init__.py | 27 +++++
superset/sql_validators/base.py | 66 +++++++++++
superset/sql_validators/presto_db.py | 186 +++++++++++++++++++++++++++++++
superset/views/core.py | 69 +++++++++++-
tests/base_tests.py | 15 +++
tests/sql_validator_tests.py | 210 +++++++++++++++++++++++++++++++++++
8 files changed, 605 insertions(+), 15 deletions(-)
diff --git a/docs/installation.rst b/docs/installation.rst
index b7c83e5..c7c24fc 100644
--- a/docs/installation.rst
+++ b/docs/installation.rst
@@ -439,8 +439,8 @@ The connection string for Teradata looks like this ::
Required environment variables: ::
- export ODBCINI=/.../teradata/client/ODBC_64/odbc.ini
- export ODBCINST=/.../teradata/client/ODBC_64/odbcinst.ini
+ export ODBCINI=/.../teradata/client/ODBC_64/odbc.ini
+ export ODBCINST=/.../teradata/client/ODBC_64/odbcinst.ini
See `Teradata SQLAlchemy <https://github.com/Teradata/sqlalchemy-teradata>`_.
@@ -811,6 +811,19 @@ in this dictionary are made available for users to use in
their SQL.
'my_crazy_macro': lambda x: x*2,
}
+SQL Lab also includes a live query validation feature with pluggable backends.
+You can configure which validation implementation is used with which database
+engine by adding a block like the following to your config.py:
+
+.. code-block:: python
+
+    FEATURE_FLAGS = {
+        'SQL_VALIDATORS_BY_ENGINE': {
+            'presto': 'PrestoDBSQLValidator',
+        }
+    }
+
+The available validators and their names can be found in ``superset/sql_validators/``.
+
**Scheduling queries**
You can optionally allow your users to schedule queries directly in SQL Lab.
@@ -967,7 +980,7 @@ Note that the above command will install Superset into
``default`` namespace of
Custom OAuth2 configuration
---------------------------
-Beyond FAB supported providers (github, twitter, linkedin, google, azure), its
easy to connect Superset with other OAuth2 Authorization Server implementations
that support "code" authorization.
+Beyond FAB supported providers (github, twitter, linkedin, google, azure), it's
easy to connect Superset with other OAuth2 Authorization Server implementations
that support "code" authorization.
The first step: Configure authorization in Superset ``superset_config.py``.
@@ -986,10 +999,10 @@ The first step: Configure authorization in Superset
``superset_config.py``.
},
'access_token_method':'POST', # HTTP Method to call
access_token_url
'access_token_params':{ # Additional parameters for
calls to access_token_url
- 'client_id':'myClientId'
+ 'client_id':'myClientId'
},
- 'access_token_headers':{ # Additional headers for calls to
access_token_url
- 'Authorization': 'Basic Base64EncodedClientIdAndSecret'
+ 'access_token_headers':{ # Additional headers for calls to
access_token_url
+ 'Authorization': 'Basic Base64EncodedClientIdAndSecret'
},
'base_url':'https://myAuthorizationServer/oauth2AuthorizationServer/',
'access_token_url':'https://myAuthorizationServer/oauth2AuthorizationServer/token',
@@ -997,25 +1010,25 @@ The first step: Configure authorization in Superset
``superset_config.py``.
}
}
]
-
+
# Will allow user self registration, allowing to create Flask users from
Authorized User
AUTH_USER_REGISTRATION = True
-
+
# The default user self registration role
AUTH_USER_REGISTRATION_ROLE = "Public"
-
+
Second step: Create a `CustomSsoSecurityManager` that extends
`SupersetSecurityManager` and overrides `oauth_user_info`:
.. code-block:: python
-
+
from superset.security import SupersetSecurityManager
-
+
class CustomSsoSecurityManager(SupersetSecurityManager):
def oauth_user_info(self, provider, response=None):
logging.debug("Oauth2 provider: {0}.".format(provider))
if provider == 'egaSSO':
- # As example, this line request a GET to base_url + '/' +
userDetails with Bearer Authentication,
+ # As an example, this line requests a GET to base_url + '/' +
userDetails with Bearer Authentication,
# and expects the authorization server to check the token and respond
with user details
me =
self.appbuilder.sm.oauth_remotes[provider].get('userDetails').data
logging.debug("user_data: {0}".format(me))
@@ -1027,7 +1040,6 @@ This file must be located at the same directory than
``superset_config.py`` with
Then we can add these two lines to ``superset_config.py``:
.. code-block:: python
-
+
from custom_sso_security_manager import CustomSsoSecurityManager
CUSTOM_SECURITY_MANAGER = CustomSsoSecurityManager
-
diff --git a/superset/config.py b/superset/config.py
index df14b0f..5a35f0b 100644
--- a/superset/config.py
+++ b/superset/config.py
@@ -420,6 +420,9 @@ DEFAULT_DB_ID = None
# Timeout duration for SQL Lab synchronous queries
SQLLAB_TIMEOUT = 30
# Timeout duration (in seconds) for SQL Lab query validation requests made
# through the validate_sql_json endpoint
SQLLAB_VALIDATION_TIMEOUT = 10
+
# SQLLAB_DEFAULT_DBID
SQLLAB_DEFAULT_DBID = None
@@ -608,6 +611,10 @@ DEFAULT_RELATIVE_END_TIME = 'today'
# localtime (in the tz where the superset webserver is running)
IS_EPOCH_S_TRULY_UTC = False
# Configure which SQL validator to use for each engine: keys are database
# engine names, values are validator names understood by
# superset.sql_validators.get_validator_by_name.
# NOTE(review): the validate_sql_json view currently reads this mapping from
# FEATURE_FLAGS['SQL_VALIDATORS_BY_ENGINE'], not from this top-level key;
# confirm which location is meant to be authoritative.
SQL_VALIDATORS_BY_ENGINE = {
    'presto': 'PrestoDBSQLValidator',
}
try:
if CONFIG_PATH_ENV_VAR in os.environ:
diff --git a/superset/sql_validators/__init__.py
b/superset/sql_validators/__init__.py
new file mode 100644
index 0000000..367aab6
--- /dev/null
+++ b/superset/sql_validators/__init__.py
@@ -0,0 +1,27 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
from typing import Optional, Type

from . import base  # noqa
from . import presto_db  # noqa
from .base import SQLValidationAnnotation  # noqa
+
+
def get_validator_by_name(
        name: str,
) -> Optional[Type[base.BaseSQLValidator]]:
    """Look up a SQL validator class by its registered name.

    :param name: the validator name, e.g. a value from the
        ``SQL_VALIDATORS_BY_ENGINE`` mapping such as
        ``'PrestoDBSQLValidator'``.
    :returns: the validator *class* (validators expose classmethods, so no
        instance is created), or ``None`` when nothing is registered under
        ``name``.
    """
    # Key each validator by its own ``name`` attribute so the registry key
    # cannot drift from the class it points at.
    validators = {
        presto_db.PrestoDBSQLValidator.name: presto_db.PrestoDBSQLValidator,
    }
    return validators.get(name)
diff --git a/superset/sql_validators/base.py b/superset/sql_validators/base.py
new file mode 100644
index 0000000..437001b
--- /dev/null
+++ b/superset/sql_validators/base.py
@@ -0,0 +1,66 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# pylint: disable=too-few-public-methods
+
+from typing import (
+ Any,
+ Dict,
+ List,
+ Optional,
+)
+
+
class SQLValidationAnnotation:
    """Represents a single annotation (error/warning) in an SQL querytext.

    Position fields are optional because a validator may not be able to
    locate the problem within the query text.
    """

    def __init__(
            self,
            message: str,
            line_number: Optional[int],
            start_column: Optional[int],
            end_column: Optional[int],
    ):
        self.message = message
        self.line_number = line_number
        self.start_column = start_column
        self.end_column = end_column

    def __repr__(self) -> str:
        return (
            f'{type(self).__name__}(message={self.message!r}, '
            f'line_number={self.line_number!r}, '
            f'start_column={self.start_column!r}, '
            f'end_column={self.end_column!r})'
        )

    def to_dict(self) -> Dict[str, Any]:
        """Return a dictionary representation of this annotation"""
        return {
            'line_number': self.line_number,
            'start_column': self.start_column,
            'end_column': self.end_column,
            'message': self.message,
        }
+
+
class BaseSQLValidator:
    """BaseSQLValidator defines the interface for checking that a given sql
    query is valid for a given database engine.

    ``name`` is a human-readable identifier for the validator; the
    validate_sql_json view interpolates it into its error messages.
    """

    name = 'BaseSQLValidator'

    @classmethod
    def validate(
            cls,
            sql: str,
            schema: Optional[str],
            database: Any,
    ) -> List['SQLValidationAnnotation']:
        """Check that the given SQL querystring is valid for the given engine.

        :param sql: the query text to validate (it may contain several
            statements)
        :param schema: the schema to validate against, or ``None`` when the
            caller did not pick one
        :param database: the Superset database model the query targets
        :returns: one annotation per problem found; an empty list means the
            query is valid
        :raises NotImplementedError: always, on this base class
        """
        raise NotImplementedError(
            f'{cls.__name__} does not implement SQL validation')
diff --git a/superset/sql_validators/presto_db.py
b/superset/sql_validators/presto_db.py
new file mode 100644
index 0000000..87c2d8e
--- /dev/null
+++ b/superset/sql_validators/presto_db.py
@@ -0,0 +1,186 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from contextlib import closing
+import logging
+import time
+from typing import (
+ Any,
+ Dict,
+ List,
+ Optional,
+)
+
+from flask import g
+from pyhive.exc import DatabaseError
+
+from superset import app, security_manager
+from superset.sql_parse import ParsedQuery
+from superset.sql_validators.base import (
+ BaseSQLValidator,
+ SQLValidationAnnotation,
+)
+from superset.utils.core import sources
+
# Row limit passed to fetch_data when draining the result of an
# EXPLAIN (TYPE VALIDATE) query.
MAX_ERROR_ROWS = 10

config = app.config


class PrestoSQLValidationError(Exception):
    """Raised when asking Presto to validate SQL querytext fails in a way
    that cannot be turned into an annotation."""
+
+
class PrestoDBSQLValidator(BaseSQLValidator):
    """Validate SQL queries using Presto's built-in EXPLAIN subtype"""

    name = 'PrestoDBSQLValidator'

    @classmethod
    def validate_statement(
            cls,
            statement: str,
            database: Any,
            cursor: Any,
            user_name: Optional[str],
    ) -> Optional[SQLValidationAnnotation]:
        """Validate a single SQL statement by running it through Presto's
        ``EXPLAIN (TYPE VALIDATE)``.

        :param statement: one SQL statement (the caller splits multi-statement
            text)
        :param database: the Superset database model being validated against
        :param cursor: an open DB-API cursor used to run the EXPLAIN query
        :param user_name: forwarded to ``SQL_QUERY_MUTATOR`` when configured
        :returns: ``None`` when the statement validates, otherwise an
            annotation built from the error Presto reported
        :raises PrestoSQLValidationError: when pyhive raises a DatabaseError
            whose payload is not in the expected shape
        """
        db_engine_spec = database.db_engine_spec
        sql = ParsedQuery(statement).stripped()

        # Hook to allow environment-specific mutation (usually comments) to
        # the SQL
        sql_query_mutator = config.get('SQL_QUERY_MUTATOR')
        if sql_query_mutator:
            sql = sql_query_mutator(
                sql, user_name, security_manager, database)

        # Transform the final statement to an explain call before sending it
        # on to presto to validate
        sql = f'EXPLAIN (TYPE VALIDATE) {sql}'

        # Invoke the query against presto. NB this deliberately doesn't use
        # the engine spec's handle_cursor implementation since we don't
        # record these EXPLAIN queries done in validation as proper Query
        # objects in the superset ORM.
        try:
            db_engine_spec.execute(cursor, sql)
            polled = cursor.poll()
            while polled:
                logging.info('polling presto for validation progress')
                stats = polled.get('stats', {})
                if stats and stats.get('state') == 'FINISHED':
                    break
                time.sleep(0.2)
                polled = cursor.poll()
            db_engine_spec.fetch_data(cursor, MAX_ERROR_ROWS)
            return None
        except DatabaseError as db_error:
            # The pyhive presto client yields EXPLAIN (TYPE VALIDATE)
            # responses as though they were normal queries, so a query
            # that fails validation surfaces here as an exception rather
            # than as a result row.
            return cls._annotation_from_db_error(db_error)
        except Exception as e:
            logging.exception(
                'Unexpected error running validation query: %s', e)
            # Bare raise keeps the original traceback intact.
            raise

    @staticmethod
    def _annotation_from_db_error(
            db_error: DatabaseError,
    ) -> SQLValidationAnnotation:
        """Translate a pyhive DatabaseError raised by a validation query into
        an annotation.

        :raises PrestoSQLValidationError: when the error does not carry a
            dict with ``message`` and ``errorLocation`` as its first argument
        """
        # Confirm the first element in the DatabaseError constructor is a
        # dictionary with error information. This is currently provided by
        # the pyhive client, but may break if their interface changes when
        # we update at some point in the future.
        if not db_error.args or not isinstance(db_error.args[0], dict):
            raise PrestoSQLValidationError(
                'The pyhive presto client returned an unhandled '
                'database error.',
            ) from db_error
        error_args: Dict[str, Any] = db_error.args[0]

        # Confirm the two fields we need to be able to present an annotation
        # are present in the error response -- a message, and a location.
        if 'message' not in error_args:
            raise PrestoSQLValidationError(
                'The pyhive presto client did not report an error message',
            ) from db_error
        if 'errorLocation' not in error_args:
            raise PrestoSQLValidationError(
                'The pyhive presto client did not report an error location',
            ) from db_error

        # Pylint is confused about the type of error_args, despite the hints
        # and checks above.
        # pylint: disable=invalid-sequence-index
        err_loc = error_args['errorLocation']
        # Only ``columnNumber`` is read from the location, so the annotation's
        # start and end columns are both set to it.
        column = err_loc.get('columnNumber', None)
        return SQLValidationAnnotation(
            message=error_args['message'],
            line_number=err_loc.get('lineNumber', None),
            start_column=column,
            end_column=column,
        )

    @classmethod
    def validate(
            cls,
            sql: str,
            schema: Optional[str],
            database: Any,
    ) -> List[SQLValidationAnnotation]:
        """
        Presto supports query-validation queries by running them with a
        prepended explain.

        For example, "SELECT 1 FROM default.mytable" becomes "EXPLAIN (TYPE
        VALIDATE) SELECT 1 FROM default.mytable.

        Each statement in ``sql`` is validated separately over one shared
        connection and cursor; the annotations for all of them are returned.
        """
        # Look the user up defensively: ``g.user`` may not have been set.
        user = getattr(g, 'user', None)
        user_name = user.username if user else None
        statements = ParsedQuery(sql).get_statements()

        logging.info('Validating %d statement(s)', len(statements))
        engine = database.get_sqla_engine(
            schema=schema,
            nullpool=True,
            user_name=user_name,
            source=sources.get('sql_lab', None),
        )
        # Sharing a single connection and cursor across the
        # execution of all statements (if many)
        annotations: List[SQLValidationAnnotation] = []
        with closing(engine.raw_connection()) as conn:
            with closing(conn.cursor()) as cursor:
                for statement in statements:
                    annotation = cls.validate_statement(
                        statement,
                        database,
                        cursor,
                        user_name,
                    )
                    if annotation:
                        annotations.append(annotation)
        logging.debug('Validation found %d error(s)', len(annotations))

        return annotations
diff --git a/superset/views/core.py b/superset/views/core.py
index e22acb7..eb25cd0 100755
--- a/superset/views/core.py
+++ b/superset/views/core.py
@@ -44,7 +44,7 @@ from werkzeug.routing import BaseConverter
from werkzeug.utils import secure_filename
from superset import (
- app, appbuilder, cache, conf, db, results_backend,
+ app, appbuilder, cache, conf, db, get_feature_flags, results_backend,
security_manager, sql_lab, viz)
from superset.connectors.connector_registry import ConnectorRegistry
from superset.connectors.sqla.models import AnnotationDatasource, SqlaTable
@@ -56,6 +56,7 @@ import superset.models.core as models
from superset.models.sql_lab import Query
from superset.models.user_attributes import UserAttribute
from superset.sql_parse import ParsedQuery
+from superset.sql_validators import get_validator_by_name
from superset.utils import core as utils
from superset.utils import dashboard_import_export
from superset.utils.dates import now_as_float
@@ -2517,6 +2518,72 @@ class Superset(BaseSupersetView):
return self.json_response('OK')
@has_access_api
+ @expose('/validate_sql_json/', methods=['POST', 'GET'])
+ @log_this
+ def validate_sql_json(self):
+ """Validates that arbitrary sql is acceptable for the given database.
+ Returns a list of error/warning annotations as json.
+ """
+ sql = request.form.get('sql')
+ database_id = request.form.get('database_id')
+ schema = request.form.get('schema') or None
+ template_params = json.loads(
+ request.form.get('templateParams') or '{}')
+
+ if len(template_params) > 0:
+ # TODO: factor the Database object out of template rendering
+ # or provide it as mydb so we can render template params
+ # without having to also persist a Query ORM object.
+ return json_error_response(
+ 'SQL validation does not support template parameters',
+ status=400)
+
+ session = db.session()
+ mydb = session.query(models.Database).filter_by(id=database_id).first()
+ if not mydb:
+ json_error_response(
+ 'Database with id {} is missing.'.format(database_id),
+ status=400,
+ )
+
+ spec = mydb.db_engine_spec
+ validators_by_engine = get_feature_flags().get(
+ 'SQL_VALIDATORS_BY_ENGINE')
+ if not validators_by_engine or spec.engine not in validators_by_engine:
+ return json_error_response(
+ 'no SQL validator is configured for {}'.format(spec.engine),
+ status=400)
+ validator_name = validators_by_engine[spec.engine]
+ validator = get_validator_by_name(validator_name)
+ if not validator:
+ return json_error_response(
+ 'No validator named {} found (configured for the {} engine)'
+ .format(validator_name, spec.engine))
+
+ try:
+ timeout = config.get('SQLLAB_VALIDATION_TIMEOUT')
+ timeout_msg = (
+ f'The query exceeded the {timeout} seconds timeout.')
+ with utils.timeout(seconds=timeout,
+ error_message=timeout_msg):
+ errors = validator.validate(sql, schema, mydb)
+ payload = json.dumps(
+ [err.to_dict() for err in errors],
+ default=utils.pessimistic_json_iso_dttm_ser,
+ ignore_nan=True,
+ encoding=None,
+ )
+ return json_success(payload)
+ except Exception as e:
+ logging.exception(e)
+ msg = _(
+ 'Failed to validate your SQL query text. Please check that '
+ f'you have configured the {validator.name} validator '
+ 'correctly and that any services it depends on are up. '
+ f'Exception: {e}')
+ return json_error_response(f'{msg}')
+
+ @has_access_api
@expose('/sql_json/', methods=['POST', 'GET'])
@log_this
def sql_json(self):
diff --git a/tests/base_tests.py b/tests/base_tests.py
index 8555915..6de082a 100644
--- a/tests/base_tests.py
+++ b/tests/base_tests.py
@@ -190,6 +190,21 @@ class SupersetTestCase(unittest.TestCase):
raise Exception('run_sql failed')
return resp
+ def validate_sql(self, sql, client_id=None, user_name=None,
+ raise_on_error=False):
+ if user_name:
+ self.logout()
+ self.login(username=(user_name if user_name else 'admin'))
+ dbid = get_main_database(db.session).id
+ resp = self.get_json_resp(
+ '/superset/validate_sql_json/',
+ raise_on_error=False,
+ data=dict(database_id=dbid, sql=sql, client_id=client_id),
+ )
+ if raise_on_error and 'error' in resp:
+ raise Exception('validate_sql failed')
+ return resp
+
@patch.dict('superset._feature_flags', {'FOO': True}, clear=True)
def test_existing_feature_flags(self):
self.assertTrue(is_feature_enabled('FOO'))
diff --git a/tests/sql_validator_tests.py b/tests/sql_validator_tests.py
new file mode 100644
index 0000000..0e1310c
--- /dev/null
+++ b/tests/sql_validator_tests.py
@@ -0,0 +1,210 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Unit tests for Sql Lab"""
+import unittest
+from unittest.mock import (
+ MagicMock,
+ patch,
+)
+
+from pyhive.exc import DatabaseError
+
+from superset import app
+from superset.sql_validators import SQLValidationAnnotation
+from superset.sql_validators.base import BaseSQLValidator
+from superset.sql_validators.presto_db import (
+ PrestoDBSQLValidator,
+ PrestoSQLValidationError,
+)
+from .base_tests import SupersetTestCase
+
# Feature flags for the endpoint tests. Several engines (presumably covering
# whichever backend the test database uses) all map to the Presto
# validator's name; the tests using these flags also patch
# get_validator_by_name, so no real Presto server is contacted.
PRESTO_TEST_FEATURE_FLAGS = {
    'SQL_VALIDATORS_BY_ENGINE': dict.fromkeys(
        ('presto', 'sqlite', 'postgresql', 'mysql'),
        'PrestoDBSQLValidator',
    ),
}
+
+
class SqlValidatorEndpointTests(SupersetTestCase):
    """Testing for Sql Lab querytext validation endpoint"""

    def tearDown(self):
        self.logout()

    @patch.dict('superset._feature_flags', {}, clear=True)
    def test_validate_sql_endpoint_noconfig(self):
        """Assert that validate_sql_json errors out when no validators are
        configured for any db"""
        self.login('admin')

        # The endpoint reads validators from the feature flags (cleared by
        # the decorator above); the config key of the same name is blanked
        # as well. patch.dict restores both afterwards, so this test no
        # longer leaks an emptied app.config entry into other tests.
        with patch.dict(app.config, {'SQL_VALIDATORS_BY_ENGINE': {}}):
            resp = self.validate_sql(
                'SELECT * FROM ab_user',
                client_id='1',
                raise_on_error=False,
            )
        self.assertIn('error', resp)
        self.assertIn('no SQL validator is configured', resp['error'])

    @patch('superset.views.core.get_validator_by_name')
    @patch.dict('superset._feature_flags',
                PRESTO_TEST_FEATURE_FLAGS,
                clear=True)
    def test_validate_sql_endpoint_mocked(self, get_validator_by_name):
        """Assert that, with a mocked validator, annotations make it back out
        from the validate_sql_json endpoint as a list of json dictionaries"""
        self.login('admin')

        validator = MagicMock()
        get_validator_by_name.return_value = validator
        validator.validate.return_value = [
            SQLValidationAnnotation(
                message="I don't know what I expected, but it wasn't this",
                line_number=4,
                start_column=12,
                end_column=42,
            ),
        ]

        resp = self.validate_sql(
            'SELECT * FROM somewhere_over_the_rainbow',
            client_id='1',
            raise_on_error=False,
        )

        self.assertEqual(1, len(resp))
        self.assertIn('expected,', resp[0]['message'])

    @patch('superset.views.core.get_validator_by_name')
    @patch.dict('superset._feature_flags',
                PRESTO_TEST_FEATURE_FLAGS,
                clear=True)
    def test_validate_sql_endpoint_failure(self, get_validator_by_name):
        """Assert that validate_sql_json errors out when the selected validator
        raises an unexpected exception"""
        self.login('admin')

        validator = MagicMock()
        get_validator_by_name.return_value = validator
        validator.validate.side_effect = Exception('Kaboom!')

        resp = self.validate_sql(
            'SELECT * FROM ab_user',
            client_id='1',
            raise_on_error=False,
        )
        self.assertIn('error', resp)
        self.assertIn('Kaboom!', resp['error'])
+
+
class BaseValidatorTests(SupersetTestCase):
    """Testing for the base sql validator"""

    def setUp(self):
        self.validator = BaseSQLValidator

    def test_validator_excepts(self):
        """The abstract base validator refuses to validate anything."""
        self.assertRaises(
            NotImplementedError,
            self.validator.validate,
            None,
            None,
            None,
        )
+
+
class PrestoValidatorTests(SupersetTestCase):
    """Testing for the prestodb sql validator"""

    def setUp(self):
        self.validator = PrestoDBSQLValidator
        self.database = MagicMock()  # noqa
        self.database_engine = self.database.get_sqla_engine.return_value
        self.database_conn = self.database_engine.raw_connection.return_value
        self.database_cursor = self.database_conn.cursor.return_value
        # A falsy poll() result ends the validator's polling loop at once.
        self.database_cursor.poll.return_value = None

    def tearDown(self):
        self.logout()

    PRESTO_ERROR_TEMPLATE = {
        'errorLocation': {
            'lineNumber': 10,
            'columnNumber': 20,
        },
        'message': "your query isn't how I like it",
    }

    @patch('superset.sql_validators.presto_db.g')
    def test_validator_success(self, flask_g):
        flask_g.user.username = 'nobody'
        sql = 'SELECT 1 FROM default.notarealtable'
        schema = 'default'

        errors = self.validator.validate(sql, schema, self.database)

        self.assertEqual([], errors)

    @patch('superset.sql_validators.presto_db.g')
    def test_validator_db_error(self, flask_g):
        flask_g.user.username = 'nobody'
        sql = 'SELECT 1 FROM default.notarealtable'
        schema = 'default'

        fetch_fn = self.database.db_engine_spec.fetch_data
        fetch_fn.side_effect = DatabaseError('dummy db error')

        with self.assertRaises(PrestoSQLValidationError):
            self.validator.validate(sql, schema, self.database)

    @patch('superset.sql_validators.presto_db.g')
    def test_validator_unexpected_error(self, flask_g):
        flask_g.user.username = 'nobody'
        sql = 'SELECT 1 FROM default.notarealtable'
        schema = 'default'

        fetch_fn = self.database.db_engine_spec.fetch_data
        fetch_fn.side_effect = Exception('a mysterious failure')

        with self.assertRaises(Exception):
            self.validator.validate(sql, schema, self.database)

    @patch('superset.sql_validators.presto_db.g')
    def test_validator_query_error(self, flask_g):
        flask_g.user.username = 'nobody'
        sql = 'SELECT 1 FROM default.notarealtable'
        schema = 'default'

        fetch_fn = self.database.db_engine_spec.fetch_data
        fetch_fn.side_effect = DatabaseError(self.PRESTO_ERROR_TEMPLATE)

        errors = self.validator.validate(sql, schema, self.database)

        self.assertEqual(1, len(errors))
        # The annotation must carry the message and position Presto reported.
        annotation = errors[0]
        self.assertEqual(
            self.PRESTO_ERROR_TEMPLATE['message'], annotation.message)
        self.assertEqual(10, annotation.line_number)
        self.assertEqual(20, annotation.start_column)
        self.assertEqual(20, annotation.end_column)

    @patch('superset.sql_validators.presto_db.g')
    def test_validator_query_error_without_location(self, flask_g):
        """An error dict lacking errorLocation cannot become an annotation."""
        flask_g.user.username = 'nobody'
        sql = 'SELECT 1 FROM default.notarealtable'
        schema = 'default'

        fetch_fn = self.database.db_engine_spec.fetch_data
        fetch_fn.side_effect = DatabaseError(
            {'message': self.PRESTO_ERROR_TEMPLATE['message']})

        with self.assertRaises(PrestoSQLValidationError):
            self.validator.validate(sql, schema, self.database)

    def test_validate_sql_endpoint(self):
        self.login('admin')
        # NB this is effectively an integration test -- when there's a default
        # validator for sqlite, this test will fail because the validator
        # will no longer error out.
        resp = self.validate_sql(
            'SELECT * FROM ab_user',
            client_id='1',
            raise_on_error=False,
        )
        self.assertIn('error', resp)
        self.assertIn('no SQL validator is configured', resp['error'])
+
+
# Allow running this test module directly with the unittest runner.
if __name__ == '__main__':
    unittest.main()