This is an automated email from the ASF dual-hosted git repository.

villebro pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-superset.git


The following commit(s) were added to refs/heads/master by this push:
     new 08358d6  fix: handle query exceptions gracefully (#10548)
08358d6 is described below

commit 08358d623b4938956526df840ef9e466bf281b6a
Author: Ville Brofeldt <[email protected]>
AuthorDate: Fri Aug 7 17:37:40 2020 +0300

    fix: handle query exceptions gracefully (#10548)
    
    * fix: handle query exceptions gracefully
    
    * add more recasts
    
    * add test
    
    * disable test for presto
    
    * switch to SQLA error
---
 superset/common/query_context.py   |  6 ++-
 superset/connectors/sqla/models.py | 85 ++++++++++++++++++++++++++++----------
 superset/views/core.py             |  7 ++--
 superset/viz.py                    | 50 +++++++++++++++-------
 tests/sqla_models_tests.py         | 25 +++++++++++
 5 files changed, 132 insertions(+), 41 deletions(-)

diff --git a/superset/common/query_context.py b/superset/common/query_context.py
index e602fbf..0d33f9c 100644
--- a/superset/common/query_context.py
+++ b/superset/common/query_context.py
@@ -27,6 +27,7 @@ from superset import app, cache, db, security_manager
 from superset.common.query_object import QueryObject
 from superset.connectors.base.models import BaseDatasource
 from superset.connectors.connector_registry import ConnectorRegistry
+from superset.exceptions import QueryObjectValidationError
 from superset.stats_logger import BaseStatsLogger
 from superset.utils import core as utils
 from superset.utils.core import DTTM_ALIAS
@@ -244,10 +245,13 @@ class QueryContext:
                     if not self.force:
                         stats_logger.incr("loaded_from_source_without_force")
                     is_loaded = True
+            except QueryObjectValidationError as ex:
+                error_message = str(ex)
+                status = utils.QueryStatus.FAILED
             except Exception as ex:  # pylint: disable=broad-except
                 logger.exception(ex)
                 if not error_message:
-                    error_message = "{}".format(ex)
+                    error_message = str(ex)
                 status = utils.QueryStatus.FAILED
                 stacktrace = utils.get_stacktrace()
 
diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py
index 530a2e1..cfc807d 100644
--- a/superset/connectors/sqla/models.py
+++ b/superset/connectors/sqla/models.py
@@ -25,6 +25,7 @@ import sqlparse
 from flask import escape, Markup
 from flask_appbuilder import Model
 from flask_babel import lazy_gettext as _
+from jinja2.exceptions import TemplateError
 from sqlalchemy import (
     and_,
     asc,
@@ -40,7 +41,7 @@ from sqlalchemy import (
     Table,
     Text,
 )
-from sqlalchemy.exc import CompileError
+from sqlalchemy.exc import CompileError, SQLAlchemyError
 from sqlalchemy.orm import backref, Query, relationship, RelationshipProperty, Session
 from sqlalchemy.orm.exc import NoResultFound
 from sqlalchemy.schema import UniqueConstraint
@@ -51,7 +52,7 @@ from superset import app, db, is_feature_enabled, security_manager
 from superset.connectors.base.models import BaseColumn, BaseDatasource, BaseMetric
 from superset.constants import NULL_STRING
 from superset.db_engine_specs.base import TimestampExpression
-from superset.exceptions import DatabaseNotFound
+from superset.exceptions import DatabaseNotFound, QueryObjectValidationError
 from superset.jinja_context import (
     BaseTemplateProcessor,
     ExtraCache,
@@ -634,7 +635,15 @@ class SqlaTable(  # pylint: disable=too-many-public-methods,too-many-instance-at
 
         if self.fetch_values_predicate:
             tp = self.get_template_processor()
-            qry = qry.where(text(tp.process_template(self.fetch_values_predicate)))
+            try:
+                qry = qry.where(text(tp.process_template(self.fetch_values_predicate)))
+            except TemplateError as ex:
+                raise QueryObjectValidationError(
+                    _(
+                        "Error in jinja expression in fetch values predicate: %(msg)s",
+                        msg=ex.message,
+                    )
+                )
 
         engine = self.database.get_sqla_engine()
         sql = "{}".format(qry.compile(engine, compile_kwargs={"literal_binds": True}))
@@ -684,7 +693,16 @@ class SqlaTable(  # pylint: disable=too-many-public-methods,too-many-instance-at
         if self.sql:
             from_sql = self.sql
             if template_processor:
-                from_sql = template_processor.process_template(from_sql)
+                try:
+                    from_sql = template_processor.process_template(from_sql)
+                except TemplateError as ex:
+                    raise QueryObjectValidationError(
+                        _(
+                            "Error in jinja expression in FROM clause: %(msg)s",
+                            msg=ex.message,
+                        )
+                    )
+
             from_sql = sqlparse.format(from_sql, strip_comments=True)
             return TextAsFrom(sa.text(from_sql), []).alias("expr_qry")
         return self.get_sqla_table()
@@ -730,10 +748,15 @@ class SqlaTable(  # pylint: disable=too-many-public-methods,too-many-instance-at
         :returns: A list of SQL clauses to be ANDed together.
         :rtype: List[str]
         """
-        return [
-            text("({})".format(template_processor.process_template(f.clause)))
-            for f in security_manager.get_rls_filters(self)
-        ]
+        try:
+            return [
+                text("({})".format(template_processor.process_template(f.clause)))
+                for f in security_manager.get_rls_filters(self)
+            ]
+        except TemplateError as ex:
+            raise QueryObjectValidationError(
+                _("Error in jinja expression in RLS filters: %(msg)s", msg=ex.message,)
+            )
 
     def get_sqla_query(  # pylint: disable=too-many-arguments,too-many-locals,too-many-branches,too-many-statements
         self,
@@ -791,7 +814,7 @@ class SqlaTable(  # pylint: disable=too-many-public-methods,too-many-instance-at
         metrics_by_name: Dict[str, SqlMetric] = {m.metric_name: m for m in self.metrics}
 
         if not granularity and is_timeseries:
-            raise Exception(
+            raise QueryObjectValidationError(
                 _(
                     "Datetime column not provided as part table configuration "
                     "and is required by this type of chart"
@@ -802,7 +825,7 @@ class SqlaTable(  # pylint: disable=too-many-public-methods,too-many-instance-at
             and not columns
             and (is_sip_38 or (not is_sip_38 and not groupby))
         ):
-            raise Exception(_("Empty query?"))
+            raise QueryObjectValidationError(_("Empty query?"))
         metrics_exprs: List[ColumnElement] = []
         for metric in metrics:
             if utils.is_adhoc_metric(metric):
@@ -811,7 +834,9 @@ class SqlaTable(  # pylint: disable=too-many-public-methods,too-many-instance-at
             elif isinstance(metric, str) and metric in metrics_by_name:
                 metrics_exprs.append(metrics_by_name[metric].get_sqla_col())
             else:
-                raise Exception(_("Metric '%(metric)s' does not exist", metric=metric))
+                raise QueryObjectValidationError(
+                    _("Metric '%(metric)s' does not exist", metric=metric)
+                )
         if metrics_exprs:
             main_metric_expr = metrics_exprs[0]
         else:
@@ -958,7 +983,7 @@ class SqlaTable(  # pylint: disable=too-many-public-methods,too-many-instance-at
                             != None
                         )
                     else:
-                        raise Exception(
+                        raise QueryObjectValidationError(
                             _("Invalid filter operation type: %(op)s", op=op)
                         )
         if config["ENABLE_ROW_LEVEL_SECURITY"]:
@@ -966,11 +991,27 @@ class SqlaTable(  # pylint: disable=too-many-public-methods,too-many-instance-at
         if extras:
             where = extras.get("where")
             if where:
-                where = template_processor.process_template(where)
+                try:
+                    where = template_processor.process_template(where)
+                except TemplateError as ex:
+                    raise QueryObjectValidationError(
+                        _(
+                            "Error in jinja expression in WHERE clause: %(msg)s",
+                            msg=ex.message,
+                        )
+                    )
                 where_clause_and += [sa.text("({})".format(where))]
             having = extras.get("having")
             if having:
-                having = template_processor.process_template(having)
+                try:
+                    having = template_processor.process_template(having)
+                except TemplateError as ex:
+                    raise QueryObjectValidationError(
+                        _(
+                            "Error in jinja expression in HAVING clause: %(msg)s",
+                            msg=ex.message,
+                        )
+                    )
                 having_clause_and += [sa.text("({})".format(having))]
         if granularity:
             qry = qry.where(and_(*(time_filters + where_clause_and)))
@@ -1117,7 +1158,7 @@ class SqlaTable(  # pylint: disable=too-many-public-methods,too-many-instance-at
         ):
             ob = metrics_by_name[timeseries_limit_metric].get_sqla_col()
         else:
-            raise Exception(
+            raise QueryObjectValidationError(
                 _("Metric '%(metric)s' does not exist", metric=timeseries_limit_metric)
             )
 
@@ -1159,7 +1200,7 @@ class SqlaTable(  # pylint: disable=too-many-public-methods,too-many-instance-at
             labels_expected = query_str_ext.labels_expected
             if df is not None and not df.empty:
                 if len(df.columns) != len(labels_expected):
-                    raise Exception(
+                    raise QueryObjectValidationError(
                         f"For {sql}, df.columns: {df.columns}"
                         f" differs from {labels_expected}"
                     )
@@ -1193,13 +1234,13 @@ class SqlaTable(  # pylint: disable=too-many-public-methods,too-many-instance-at
         """Fetches the metadata for the table and merges it in"""
         try:
             table_ = self.get_sqla_table_object()
-        except Exception as ex:
-            logger.exception(ex)
-            raise Exception(
+        except SQLAlchemyError:
+            raise QueryObjectValidationError(
                 _(
-                    "Table [{}] doesn't seem to exist in the specified database, "
-                    "couldn't fetch column information"
-                ).format(self.table_name)
+                    "Table %(table)s doesn't seem to exist in the specified database, "
+                    "couldn't fetch column information",
+                    table=self.table_name,
+                )
             )
 
         metrics = []
diff --git a/superset/views/core.py b/superset/views/core.py
index f3a2264..48cd458 100755
--- a/superset/views/core.py
+++ b/superset/views/core.py
@@ -32,6 +32,7 @@ from flask_appbuilder.models.sqla.interface import SQLAInterface
 from flask_appbuilder.security.decorators import has_access, has_access_api
 from flask_appbuilder.security.sqla import models as ab_models
 from flask_babel import gettext as __, lazy_gettext as _
+from jinja2.exceptions import TemplateError
 from sqlalchemy import and_, or_, select
 from sqlalchemy.engine.url import make_url
 from sqlalchemy.exc import (
@@ -535,7 +536,7 @@ class Superset(BaseSupersetView):  # pylint: disable=too-many-public-methods
 
             return self.generate_json(viz_obj, response_type)
         except SupersetException as ex:
-            return json_error_response(utils.error_msg_from_exception(ex))
+            return json_error_response(utils.error_msg_from_exception(ex), 400)
 
     @event_logger.log_this
     @has_access
@@ -2300,10 +2301,10 @@ class Superset(BaseSupersetView):  # pylint: disable=too-many-public-methods
             rendered_query = template_processor.process_template(
                 query.sql, **template_params
             )
-        except Exception as ex:  # pylint: disable=broad-except
+        except TemplateError as ex:
             error_msg = utils.error_msg_from_exception(ex)
             return json_error_response(
-                f"Query {query_id}: Template rendering failed: {error_msg}"
+                f"Query {query_id}: Template syntax error: {error_msg}"
             )
 
         # Limit is not applied to the CTA queries if SQLLAB_CTAS_NO_LIMIT flag is set
diff --git a/superset/viz.py b/superset/viz.py
index 27b6ad9..88d67bb 100644
--- a/superset/viz.py
+++ b/superset/viz.py
@@ -329,13 +329,17 @@ class BaseViz:
         # default order direction
         order_desc = form_data.get("order_desc", True)
 
-        since, until = utils.get_since_until(
-            relative_start=relative_start,
-            relative_end=relative_end,
-            time_range=form_data.get("time_range"),
-            since=form_data.get("since"),
-            until=form_data.get("until"),
-        )
+        try:
+            since, until = utils.get_since_until(
+                relative_start=relative_start,
+                relative_end=relative_end,
+                time_range=form_data.get("time_range"),
+                since=form_data.get("since"),
+                until=form_data.get("until"),
+            )
+        except ValueError as ex:
+            raise QueryObjectValidationError(str(ex))
+
         time_shift = form_data.get("time_shift", "")
         self.time_shift = utils.parse_past_timedelta(time_shift)
         from_dttm = None if since is None else (since - self.time_shift)
@@ -475,6 +479,16 @@ class BaseViz:
                     if not self.force:
                         stats_logger.incr("loaded_from_source_without_force")
                     is_loaded = True
+            except QueryObjectValidationError as ex:
+                error = dataclasses.asdict(
+                    SupersetError(
+                        message=str(ex),
+                        level=ErrorLevel.ERROR,
+                        error_type=SupersetErrorType.VIZ_GET_DF_ERROR,
+                    )
+                )
+                self.errors.append(error)
+                self.status = utils.QueryStatus.FAILED
             except Exception as ex:
                 logger.exception(ex)
 
@@ -889,13 +903,16 @@ class CalHeatmapViz(BaseViz):
                 values[str(v / 10 ** 9)] = obj.get(metric)
             data[metric] = values
 
-        start, end = utils.get_since_until(
-            relative_start=relative_start,
-            relative_end=relative_end,
-            time_range=form_data.get("time_range"),
-            since=form_data.get("since"),
-            until=form_data.get("until"),
-        )
+        try:
+            start, end = utils.get_since_until(
+                relative_start=relative_start,
+                relative_end=relative_end,
+                time_range=form_data.get("time_range"),
+                since=form_data.get("since"),
+                until=form_data.get("until"),
+            )
+        except ValueError as ex:
+            raise QueryObjectValidationError(str(ex))
         if not start or not end:
             raise QueryObjectValidationError(
                 "Please provide both time bounds (Since and Until)"
@@ -1288,7 +1305,10 @@ class NVD3TimeSeriesViz(NVD3Viz):
 
         for option in time_compare:
             query_object = self.query_obj()
-            delta = utils.parse_past_timedelta(option)
+            try:
+                delta = utils.parse_past_timedelta(option)
+            except ValueError as ex:
+                raise QueryObjectValidationError(str(ex))
             query_object["inner_from_dttm"] = query_object["from_dttm"]
             query_object["inner_to_dttm"] = query_object["to_dttm"]
 
diff --git a/tests/sqla_models_tests.py b/tests/sqla_models_tests.py
index 3b9a84f..4666fd7 100644
--- a/tests/sqla_models_tests.py
+++ b/tests/sqla_models_tests.py
@@ -17,10 +17,12 @@
 # isort:skip_file
 from typing import Any, Dict, NamedTuple, List, Tuple, Union
 from unittest.mock import patch
+import pytest
 
 import tests.test_app
 from superset.connectors.sqla.models import SqlaTable, TableColumn
 from superset.db_engine_specs.druid import DruidEngineSpec
+from superset.exceptions import QueryObjectValidationError
 from superset.models.core import Database
 from superset.utils.core import DbColumnType, get_example_database, FilterOperator
 
@@ -170,3 +172,26 @@ class TestDatabaseModel(SupersetTestCase):
             sqla_query = table.get_sqla_query(**query_obj)
             sql = table.database.compile_sqla_query(sqla_query.sqla_query)
             self.assertIn(filter_.expected, sql)
+
+    def test_incorrect_jinja_syntax_raises_correct_exception(self):
+        query_obj = {
+            "granularity": None,
+            "from_dttm": None,
+            "to_dttm": None,
+            "groupby": ["user"],
+            "metrics": [],
+            "is_timeseries": False,
+            "filter": [],
+            "extras": {},
+        }
+
+        # Table with Jinja callable.
+        table = SqlaTable(
+            table_name="test_table",
+            sql="SELECT '{{ abcd xyz + 1 ASDF }}' as user",
+            database=get_example_database(),
+        )
+        # TODO(villebro): make it work with presto
+        if get_example_database().backend != "presto":
+            with pytest.raises(QueryObjectValidationError):
+                table.get_sqla_query(**query_obj)

Reply via email to