This is an automated email from the ASF dual-hosted git repository.
dpgaspar pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-superset.git
The following commit(s) were added to refs/heads/master by this push:
new 516bdf6 fix(mssql): apply limit and set alias for functions (#9644)
516bdf6 is described below
commit 516bdf6db1bbb9591be24ee8869c6466eb14e8c2
Author: Daniel Vaz Gaspar <[email protected]>
AuthorDate: Mon Apr 27 09:23:08 2020 +0100
fix(mssql): apply limit and set alias for functions (#9644)
---
superset/db_engine_specs/mssql.py | 14 +++++++-
superset/sql_parse.py | 45 ++++++++++++++++++++++++-
tests/db_engine_specs/mssql_tests.py | 64 +++++++++++++++++++++++++++++++++++-
3 files changed, 120 insertions(+), 3 deletions(-)
diff --git a/superset/db_engine_specs/mssql.py b/superset/db_engine_specs/mssql.py
index 6a231b2..4fc6e6f 100644
--- a/superset/db_engine_specs/mssql.py
+++ b/superset/db_engine_specs/mssql.py
@@ -14,13 +14,20 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+import logging
import re
from datetime import datetime
-from typing import Any, List, Optional, Tuple
+from typing import Any, List, Optional, Tuple, TYPE_CHECKING
from sqlalchemy.types import String, TypeEngine, UnicodeText
from superset.db_engine_specs.base import BaseEngineSpec, LimitMethod
+from superset.sql_parse import ParsedQuery
+
+if TYPE_CHECKING:
+ from superset.models.core import Database # pylint: disable=unused-import
+
+logger = logging.getLogger(__name__)
class MssqlEngineSpec(BaseEngineSpec):
@@ -76,3 +83,8 @@ class MssqlEngineSpec(BaseEngineSpec):
if regex.match(type_):
return sqla_type
return None
+
+ @classmethod
+ def apply_limit_to_sql(cls, sql: str, limit: int, database: "Database") -> str:
+ new_sql = ParsedQuery(sql).set_alias()
+ return super().apply_limit_to_sql(new_sql, limit, database)
diff --git a/superset/sql_parse.py b/superset/sql_parse.py
index b39fc44..8cac2ff 100644
--- a/superset/sql_parse.py
+++ b/superset/sql_parse.py
@@ -18,7 +18,14 @@ import logging
from typing import List, Optional, Set
import sqlparse
-from sqlparse.sql import Identifier, IdentifierList, remove_quotes, Token, TokenList
+from sqlparse.sql import (
+ Function,
+ Identifier,
+ IdentifierList,
+ remove_quotes,
+ Token,
+ TokenList,
+)
from sqlparse.tokens import Keyword, Name, Punctuation, String, Whitespace
from sqlparse.utils import imt
@@ -247,3 +254,39 @@ class ParsedQuery:
for i in statement.tokens:
str_res += str(i.value)
return str_res
+
+ def set_alias(self) -> str:
+ """
+ Returns a new query string where all functions have alias.
+ This is particularly necessary for MSSQL engines.
+
+ :return: String with new aliased SQL query
+ """
+ new_sql = ""
+ changed_counter = 1
+ for token in self._parsed[0].tokens:
+ # Identifier list (list of columns)
+ if isinstance(token, IdentifierList) and token.ttype is None:
+ for i, identifier in enumerate(token.get_identifiers()):
+ # Functions are anonymous on MSSQL
+ if isinstance(identifier, Function) and not identifier.has_alias():
+ identifier.value = (
+ f"{identifier.value} AS"
+ f" {identifier.get_real_name()}_{changed_counter}"
+ )
+ changed_counter += 1
+ new_sql += str(identifier.value)
+ # If not last identifier
+ if i != len(list(token.get_identifiers())) - 1:
+ new_sql += ", "
+ # Just a lonely function?
+ elif isinstance(token, Function) and token.ttype is None:
+ if not token.has_alias():
+ token.value = (
+ f"{token.value} AS {token.get_real_name()}_{changed_counter}"
+ )
+ new_sql += str(token.value)
+ # Nothing to change, assemble what we have
+ else:
+ new_sql += str(token.value)
+ return new_sql
diff --git a/tests/db_engine_specs/mssql_tests.py b/tests/db_engine_specs/mssql_tests.py
index 238dd2a..9f5351c 100644
--- a/tests/db_engine_specs/mssql_tests.py
+++ b/tests/db_engine_specs/mssql_tests.py
@@ -15,15 +15,18 @@
# specific language governing permissions and limitations
# under the License.
import unittest.mock as mock
+from typing import Optional
from sqlalchemy import column, table
from sqlalchemy.dialects import mssql
from sqlalchemy.dialects.mssql import DATE, NTEXT, NVARCHAR, TEXT, VARCHAR
-from sqlalchemy.sql import select
+from sqlalchemy.sql import select, Select
from sqlalchemy.types import String, UnicodeText
from superset.db_engine_specs.base import BaseEngineSpec
from superset.db_engine_specs.mssql import MssqlEngineSpec
+from superset.extensions import db
+from superset.models.core import Database
from tests.db_engine_specs.base_tests import DbEngineSpecTestCase
@@ -94,6 +97,65 @@ class MssqlEngineSpecTest(DbEngineSpecTestCase):
for actual, expected in test_cases:
self.assertEqual(actual, expected)
+ def test_apply_limit(self):
+ def compile_sqla_query(qry: Select, schema: Optional[str] = None) -> str:
+ return str(
+ qry.compile(
+ dialect=mssql.dialect(), compile_kwargs={"literal_binds": True}
+ )
+ )
+
+ database = Database(
+ database_name="mssql_test",
+ sqlalchemy_uri="mssql+pymssql://sa:Password_123@localhost:1433/msdb",
+ )
+ db.session.add(database)
+ db.session.commit()
+
+ with mock.patch.object(database, "compile_sqla_query", new=compile_sqla_query):
+ test_sql = "SELECT COUNT(*) FROM FOO_TABLE"
+
+ limited_sql = MssqlEngineSpec.apply_limit_to_sql(test_sql, 1000, database)
+
+ expected_sql = (
+ "SELECT TOP 1000 * \n"
+ "FROM (SELECT COUNT(*) AS COUNT_1 FROM FOO_TABLE) AS inner_qry"
+ )
+ self.assertEqual(expected_sql, limited_sql)
+
+ test_sql = "SELECT COUNT(*), SUM(id) FROM FOO_TABLE"
+ limited_sql = MssqlEngineSpec.apply_limit_to_sql(test_sql, 1000, database)
+
+ expected_sql = (
+ "SELECT TOP 1000 * \n"
+ "FROM (SELECT COUNT(*) AS COUNT_1, SUM(id) AS SUM_2 FROM FOO_TABLE) "
+ "AS inner_qry"
+ )
+ self.assertEqual(expected_sql, limited_sql)
+
+ test_sql = "SELECT COUNT(*), FOO_COL1 FROM FOO_TABLE GROUP BY FOO_COL1"
+ limited_sql = MssqlEngineSpec.apply_limit_to_sql(test_sql, 1000, database)
+
+ expected_sql = (
+ "SELECT TOP 1000 * \n"
+ "FROM (SELECT COUNT(*) AS COUNT_1, "
+ "FOO_COL1 FROM FOO_TABLE GROUP BY FOO_COL1)"
+ " AS inner_qry"
+ )
+ self.assertEqual(expected_sql, limited_sql)
+
+ test_sql = "SELECT COUNT(*), COUNT(*) FROM FOO_TABLE"
+ limited_sql = MssqlEngineSpec.apply_limit_to_sql(test_sql, 1000, database)
+ expected_sql = (
+ "SELECT TOP 1000 * \n"
+ "FROM (SELECT COUNT(*) AS COUNT_1, COUNT(*) AS COUNT_2 FROM FOO_TABLE)"
+ " AS inner_qry"
+ )
+ self.assertEqual(expected_sql, limited_sql)
+
+ db.session.delete(database)
+ db.session.commit()
+
@mock.patch.object(
MssqlEngineSpec, "pyodbc_rows_to_tuples", return_value="converted"
)