This is an automated email from the ASF dual-hosted git repository.

villebro pushed a commit to branch 1.4
in repository https://gitbox.apache.org/repos/asf/superset.git

commit 92c245835f1a17d1991d1417f37ec9b6f08d4546
Author: Ville Brofeldt <[email protected]>
AuthorDate: Fri Dec 3 12:35:26 2021 +0200

    fix(sqla): make text clause escaping optional (#17641)
---
 superset/connectors/sqla/models.py                 | 31 ++++----
 superset/db_engine_specs/athena.py                 |  1 +
 superset/db_engine_specs/base.py                   | 15 +++-
 .../db_engine_specs/athena_tests.py                | 57 --------------
 tests/unit_tests/db_engine_specs/test_athena.py    | 87 ++++++++++++++++++++++
 tests/unit_tests/db_engine_specs/test_base.py      | 33 ++++++++
 tests/unit_tests/fixtures/common.py                | 25 +++++++
 7 files changed, 175 insertions(+), 74 deletions(-)

diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py
index fe79f81..36ffc74 100644
--- a/superset/connectors/sqla/models.py
+++ b/superset/connectors/sqla/models.py
@@ -105,16 +105,6 @@ logger = logging.getLogger(__name__)
 VIRTUAL_TABLE_ALIAS = "virtual_table"
 
 
-def text(clause: str) -> TextClause:
-    """
-    SQLALchemy wrapper to ensure text clauses are escaped properly
-
-    :param clause: clause potentially containing colons
-    :return: text clause with escaped colons
-    """
-    return sa.text(clause.replace(":", "\\:"))
-
-
 class SqlaQuery(NamedTuple):
     applied_template_filters: List[str]
     extra_cache_keys: List[Any]
@@ -299,7 +289,10 @@ class TableColumn(Model, BaseColumn, CertificationMixin):
         l = []
         if start_dttm:
             l.append(
-                col >= text(self.dttm_sql_literal(start_dttm, time_range_endpoints))
+                col
+                >= self.table.text(
+                    self.dttm_sql_literal(start_dttm, time_range_endpoints)
+                )
             )
         if end_dttm:
             if (
@@ -307,10 +300,13 @@ class TableColumn(Model, BaseColumn, CertificationMixin):
                 and time_range_endpoints[1] == utils.TimeRangeEndpoint.EXCLUSIVE
             ):
                 l.append(
-                    col < text(self.dttm_sql_literal(end_dttm, time_range_endpoints))
+                    col
+                    < self.table.text(
+                        self.dttm_sql_literal(end_dttm, time_range_endpoints)
+                    )
                 )
             else:
-                l.append(col <= text(self.dttm_sql_literal(end_dttm, None)))
+                l.append(col <= self.table.text(self.dttm_sql_literal(end_dttm, None)))
         return and_(*l)
 
     def get_timestamp_expression(
@@ -716,7 +712,7 @@ class SqlaTable(Model, BaseDatasource):  # pylint: disable=too-many-public-metho
     def get_fetch_values_predicate(self) -> TextClause:
         tp = self.get_template_processor()
         try:
-            return text(tp.process_template(self.fetch_values_predicate))
+            return self.text(tp.process_template(self.fetch_values_predicate))
         except TemplateError as ex:
             raise QueryObjectValidationError(
                 _(
@@ -806,7 +802,7 @@ class SqlaTable(Model, BaseDatasource):  # pylint: disable=too-many-public-metho
             raise QueryObjectValidationError(
                 _("Virtual dataset query must be read-only")
             )
-        return TextAsFrom(text(from_sql), []).alias(VIRTUAL_TABLE_ALIAS)
+        return TextAsFrom(self.text(from_sql), []).alias(VIRTUAL_TABLE_ALIAS)
 
     def get_rendered_sql(
         self, template_processor: Optional[BaseTemplateProcessor] = None
@@ -932,7 +928,7 @@ class SqlaTable(Model, BaseDatasource):  # pylint: disable=too-many-public-metho
         filters_grouped: Dict[Union[int, str], List[str]] = defaultdict(list)
         try:
             for filter_ in security_manager.get_rls_filters(self):
-                clause = text(
+                clause = self.text(
                     f"({template_processor.process_template(filter_.clause)})"
                 )
                 filters_grouped[filter_.group_key or filter_.id].append(clause)
@@ -942,6 +938,9 @@ class SqlaTable(Model, BaseDatasource):  # pylint: disable=too-many-public-metho
                 _("Error in jinja expression in RLS filters: %(msg)s", msg=ex.message,)
             ) from ex
 
+    def text(self, clause: str) -> TextClause:
+        return self.db_engine_spec.get_text_clause(clause)
+
     def get_sqla_query(  # pylint: disable=too-many-arguments,too-many-locals,too-many-branches,too-many-statements
         self,
         apply_fetch_values_predicate: bool = False,
diff --git a/superset/db_engine_specs/athena.py b/superset/db_engine_specs/athena.py
index 666049b..e7f67d7 100644
--- a/superset/db_engine_specs/athena.py
+++ b/superset/db_engine_specs/athena.py
@@ -32,6 +32,7 @@ SYNTAX_ERROR_REGEX = re.compile(
 class AthenaEngineSpec(BaseEngineSpec):
     engine = "awsathena"
     engine_name = "Amazon Athena"
+    allows_escaped_colons = False
 
     _time_grain_expressions = {
         None: "{col}",
diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py
index ed3851e..53a3c62 100644
--- a/superset/db_engine_specs/base.py
+++ b/superset/db_engine_specs/base.py
@@ -52,7 +52,7 @@ from sqlalchemy.engine.url import make_url, URL
 from sqlalchemy.ext.compiler import compiles
 from sqlalchemy.orm import Session
 from sqlalchemy.sql import quoted_name, text
-from sqlalchemy.sql.expression import ColumnClause, Select, TextAsFrom
+from sqlalchemy.sql.expression import ColumnClause, Select, TextAsFrom, TextClause
 from sqlalchemy.types import String, TypeEngine, UnicodeText
 from typing_extensions import TypedDict
 
@@ -280,6 +280,7 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
     allows_alias_in_select = True
     allows_alias_in_orderby = True
     allows_sql_comments = True
+    allows_escaped_colons = True
 
     # Whether ORDER BY clause can use aliases created in SELECT
     # that are the same as a source column
@@ -334,6 +335,18 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
         return False
 
     @classmethod
+    def get_text_clause(cls, clause: str) -> TextClause:
+        """
+        SQLALchemy wrapper to ensure text clauses are escaped properly
+
+        :param clause: string clause with potentially unescaped characters
+        :return: text clause with escaped characters
+        """
+        if cls.allows_escaped_colons:
+            clause = clause.replace(":", "\\:")
+        return text(clause)
+
+    @classmethod
     def get_engine(
         cls,
         database: "Database",
diff --git a/tests/integration_tests/db_engine_specs/athena_tests.py b/tests/integration_tests/db_engine_specs/athena_tests.py
deleted file mode 100644
index 484788f..0000000
--- a/tests/integration_tests/db_engine_specs/athena_tests.py
+++ /dev/null
@@ -1,57 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-from superset.db_engine_specs.athena import AthenaEngineSpec
-from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
-from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec
-
-
-class TestAthenaDbEngineSpec(TestDbEngineSpec):
-    def test_convert_dttm(self):
-        dttm = self.get_dttm()
-
-        self.assertEqual(
-            AthenaEngineSpec.convert_dttm("DATE", dttm),
-            "from_iso8601_date('2019-01-02')",
-        )
-
-        self.assertEqual(
-            AthenaEngineSpec.convert_dttm("TIMESTAMP", dttm),
-            "from_iso8601_timestamp('2019-01-02T03:04:05.678900')",
-        )
-
-    def test_extract_errors(self):
-        """
-        Test that custom error messages are extracted correctly.
-        """
-        msg = ": mismatched input 'fromm'. Expecting: "
-        result = AthenaEngineSpec.extract_errors(Exception(msg))
-        assert result == [
-            SupersetError(
-                message='Please check your query for syntax errors at or near "fromm". Then, try running your query again.',
-                error_type=SupersetErrorType.SYNTAX_ERROR,
-                level=ErrorLevel.ERROR,
-                extra={
-                    "engine_name": "Amazon Athena",
-                    "issue_codes": [
-                        {
-                            "code": 1030,
-                            "message": "Issue 1030 - The query has a syntax error.",
-                        }
-                    ],
-                },
-            )
-        ]
diff --git a/tests/unit_tests/db_engine_specs/test_athena.py b/tests/unit_tests/db_engine_specs/test_athena.py
new file mode 100644
index 0000000..32a401f
--- /dev/null
+++ b/tests/unit_tests/db_engine_specs/test_athena.py
@@ -0,0 +1,87 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=unused-argument, import-outside-toplevel, protected-access
+import re
+from datetime import datetime
+
+from flask.ctx import AppContext
+
+from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
+from tests.unit_tests.fixtures.common import dttm
+
+SYNTAX_ERROR_REGEX = re.compile(
+    ": mismatched input '(?P<syntax_error>.*?)'. Expecting: "
+)
+
+
+def test_convert_dttm(app_context: AppContext, dttm: datetime) -> None:
+    """
+    Test that date objects are converted correctly.
+    """
+
+    from superset.db_engine_specs.athena import AthenaEngineSpec
+
+    assert (
+        AthenaEngineSpec.convert_dttm("DATE", dttm) == "from_iso8601_date('2019-01-02')"
+    )
+
+    assert (
+        AthenaEngineSpec.convert_dttm("TIMESTAMP", dttm)
+        == "from_iso8601_timestamp('2019-01-02T03:04:05.678900')"
+    )
+
+
+def test_extract_errors(app_context: AppContext) -> None:
+    """
+    Test that custom error messages are extracted correctly.
+    """
+
+    from superset.db_engine_specs.athena import AthenaEngineSpec
+
+    msg = ": mismatched input 'fromm'. Expecting: "
+    result = AthenaEngineSpec.extract_errors(Exception(msg))
+    assert result == [
+        SupersetError(
+            message='Please check your query for syntax errors at or near "fromm". Then, try running your query again.',
+            error_type=SupersetErrorType.SYNTAX_ERROR,
+            level=ErrorLevel.ERROR,
+            extra={
+                "engine_name": "Amazon Athena",
+                "issue_codes": [
+                    {
+                        "code": 1030,
+                        "message": "Issue 1030 - The query has a syntax error.",
+                    }
+                ],
+            },
+        )
+    ]
+
+
+def test_get_text_clause_with_colon(app_context: AppContext) -> None:
+    """
+    Make sure text clauses don't escape the colon character
+    """
+
+    from superset.db_engine_specs.athena import AthenaEngineSpec
+
+    query = (
+        "SELECT foo FROM tbl WHERE "
+        "abc >= from_iso8601_timestamp('2021-11-26T00\:00\:00.000000')"
+    )
+    text_clause = AthenaEngineSpec.get_text_clause(query)
+    assert text_clause.text == query
diff --git a/tests/unit_tests/db_engine_specs/test_base.py b/tests/unit_tests/db_engine_specs/test_base.py
new file mode 100644
index 0000000..e450aaf
--- /dev/null
+++ b/tests/unit_tests/db_engine_specs/test_base.py
@@ -0,0 +1,33 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=unused-argument, import-outside-toplevel, protected-access
+import re
+
+from flask.ctx import AppContext
+
+
+def test_get_text_clause_with_colon(app_context: AppContext) -> None:
+    """
+    Make sure text clauses are correctly escaped
+    """
+
+    from superset.db_engine_specs.base import BaseEngineSpec
+
+    text_clause = BaseEngineSpec.get_text_clause(
+        "SELECT foo FROM tbl WHERE foo = '123:456')"
+    )
+    assert text_clause.text == "SELECT foo FROM tbl WHERE foo = '123\\:456')"
diff --git a/tests/unit_tests/fixtures/common.py b/tests/unit_tests/fixtures/common.py
new file mode 100644
index 0000000..6c2af8d
--- /dev/null
+++ b/tests/unit_tests/fixtures/common.py
@@ -0,0 +1,25 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from datetime import datetime
+
+import pytest
+
+
+@pytest.fixture
+def dttm() -> datetime:
+    return datetime.strptime("2019-01-02 03:04:05.678900", "%Y-%m-%d %H:%M:%S.%f")

Reply via email to