This is an automated email from the ASF dual-hosted git repository.

dpgaspar pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-superset.git


The following commit(s) were added to refs/heads/master by this push:
     new 516bdf6  fix(mssql): apply limit and set alias for functions (#9644)
516bdf6 is described below

commit 516bdf6db1bbb9591be24ee8869c6466eb14e8c2
Author: Daniel Vaz Gaspar <[email protected]>
AuthorDate: Mon Apr 27 09:23:08 2020 +0100

    fix(mssql): apply limit and set alias for functions (#9644)
---
 superset/db_engine_specs/mssql.py    | 14 +++++++-
 superset/sql_parse.py                | 45 ++++++++++++++++++++++++-
 tests/db_engine_specs/mssql_tests.py | 64 +++++++++++++++++++++++++++++++++++-
 3 files changed, 120 insertions(+), 3 deletions(-)

diff --git a/superset/db_engine_specs/mssql.py b/superset/db_engine_specs/mssql.py
index 6a231b2..4fc6e6f 100644
--- a/superset/db_engine_specs/mssql.py
+++ b/superset/db_engine_specs/mssql.py
@@ -14,13 +14,20 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+import logging
 import re
 from datetime import datetime
-from typing import Any, List, Optional, Tuple
+from typing import Any, List, Optional, Tuple, TYPE_CHECKING
 
 from sqlalchemy.types import String, TypeEngine, UnicodeText
 
 from superset.db_engine_specs.base import BaseEngineSpec, LimitMethod
+from superset.sql_parse import ParsedQuery
+
+if TYPE_CHECKING:
+    from superset.models.core import Database  # pylint: disable=unused-import
+
+logger = logging.getLogger(__name__)
 
 
 class MssqlEngineSpec(BaseEngineSpec):
@@ -76,3 +83,15 @@ class MssqlEngineSpec(BaseEngineSpec):
             if regex.match(type_):
                 return sqla_type
         return None
+
+    @classmethod
+    def apply_limit_to_sql(cls, sql: str, limit: int, database: "Database") -> str:
+        """
+        Alias anonymous function columns, then apply the limit.
+
+        MSSQL needs every column of a derived table to have a name, and the
+        limited query wraps ``sql`` in a subquery (``SELECT TOP n * FROM
+        (...) AS inner_qry``), so e.g. ``COUNT(*)`` must be aliased first.
+        """
+        new_sql = ParsedQuery(sql).set_alias()
+        return super().apply_limit_to_sql(new_sql, limit, database)
diff --git a/superset/sql_parse.py b/superset/sql_parse.py
index b39fc44..8cac2ff 100644
--- a/superset/sql_parse.py
+++ b/superset/sql_parse.py
@@ -18,7 +18,14 @@ import logging
 from typing import List, Optional, Set
 
 import sqlparse
-from sqlparse.sql import Identifier, IdentifierList, remove_quotes, Token, TokenList
+from sqlparse.sql import (
+    Function,
+    Identifier,
+    IdentifierList,
+    remove_quotes,
+    Token,
+    TokenList,
+)
 from sqlparse.tokens import Keyword, Name, Punctuation, String, Whitespace
 from sqlparse.utils import imt
 
@@ -247,3 +254,47 @@ class ParsedQuery:
         for i in statement.tokens:
             str_res += str(i.value)
         return str_res
+
+    def set_alias(self) -> str:
+        """
+        Returns a new query string where all functions in the SELECT list
+        have an alias. This is particularly necessary for MSSQL engines.
+
+        Only the top-level SELECT list of the first statement is rewritten:
+        aliasing a function in another clause (e.g. ``GROUP BY LOWER(col)``)
+        would produce invalid SQL. The parse tree is left untouched, so
+        calling this method more than once is safe.
+
+        :return: String with new aliased SQL query
+        """
+        if not self._parsed:
+            # Nothing was parsed (e.g. empty SQL): nothing to alias
+            return ""
+        parts: List[str] = []
+        changed_counter = 1
+        in_select_list = False
+
+        def render(item: Token) -> str:
+            """Return the SQL for ``item``, aliasing unnamed SELECT functions."""
+            nonlocal changed_counter
+            # Functions are anonymous on MSSQL
+            if in_select_list and isinstance(item, Function):
+                if not item.has_alias():
+                    name = f"{item.get_real_name()}_{changed_counter}"
+                    changed_counter += 1
+                    return f"{item.value} AS {name}"
+            return str(item.value)
+
+        for token in self._parsed[0].tokens:
+            # The SELECT list starts after SELECT and ends at FROM
+            if token.match(Keyword.DML, "SELECT"):
+                in_select_list = True
+            elif token.match(Keyword, "FROM"):
+                in_select_list = False
+
+            if in_select_list and isinstance(token, IdentifierList):
+                # Identifier list (list of columns): rebuild it item by item
+                parts.append(", ".join(render(i) for i in token.get_identifiers()))
+            else:
+                parts.append(render(token))
+        return "".join(parts)
diff --git a/tests/db_engine_specs/mssql_tests.py b/tests/db_engine_specs/mssql_tests.py
index 238dd2a..9f5351c 100644
--- a/tests/db_engine_specs/mssql_tests.py
+++ b/tests/db_engine_specs/mssql_tests.py
@@ -15,15 +15,18 @@
 # specific language governing permissions and limitations
 # under the License.
 import unittest.mock as mock
+from typing import Optional
 
 from sqlalchemy import column, table
 from sqlalchemy.dialects import mssql
 from sqlalchemy.dialects.mssql import DATE, NTEXT, NVARCHAR, TEXT, VARCHAR
-from sqlalchemy.sql import select
+from sqlalchemy.sql import select, Select
 from sqlalchemy.types import String, UnicodeText
 
 from superset.db_engine_specs.base import BaseEngineSpec
 from superset.db_engine_specs.mssql import MssqlEngineSpec
+from superset.extensions import db
+from superset.models.core import Database
 from tests.db_engine_specs.base_tests import DbEngineSpecTestCase
 
 
@@ -94,6 +97,64 @@ class MssqlEngineSpecTest(DbEngineSpecTestCase):
         for actual, expected in test_cases:
             self.assertEqual(actual, expected)
 
+    def test_apply_limit(self):
+        def compile_sqla_query(qry: Select, schema: Optional[str] = None) -> str:
+            return str(
+                qry.compile(
+                    dialect=mssql.dialect(), compile_kwargs={"literal_binds": True}
+                )
+            )
+
+        # (input SQL, expected SQL after applying a limit of 1000)
+        test_cases = [
+            (
+                "SELECT COUNT(*) FROM FOO_TABLE",
+                "SELECT TOP 1000 * \n"
+                "FROM (SELECT COUNT(*) AS COUNT_1 FROM FOO_TABLE) AS inner_qry",
+            ),
+            (
+                "SELECT COUNT(*), SUM(id) FROM FOO_TABLE",
+                "SELECT TOP 1000 * \n"
+                "FROM (SELECT COUNT(*) AS COUNT_1, SUM(id) AS SUM_2 FROM FOO_TABLE) "
+                "AS inner_qry",
+            ),
+            (
+                "SELECT COUNT(*), FOO_COL1 FROM FOO_TABLE GROUP BY FOO_COL1",
+                "SELECT TOP 1000 * \n"
+                "FROM (SELECT COUNT(*) AS COUNT_1, "
+                "FOO_COL1 FROM FOO_TABLE GROUP BY FOO_COL1)"
+                " AS inner_qry",
+            ),
+            (
+                "SELECT COUNT(*), COUNT(*) FROM FOO_TABLE",
+                "SELECT TOP 1000 * \n"
+                "FROM (SELECT COUNT(*) AS COUNT_1, COUNT(*) AS COUNT_2 FROM FOO_TABLE)"
+                " AS inner_qry",
+            ),
+        ]
+
+        database = Database(
+            database_name="mssql_test",
+            sqlalchemy_uri="mssql+pymssql://sa:Password_123@localhost:1433/msdb",
+        )
+        db.session.add(database)
+        db.session.commit()
+
+        # Remove the database row even if an assertion fails, so a failed run
+        # does not leave it behind for subsequent tests
+        try:
+            with mock.patch.object(
+                database, "compile_sqla_query", new=compile_sqla_query
+            ):
+                for test_sql, expected_sql in test_cases:
+                    limited_sql = MssqlEngineSpec.apply_limit_to_sql(
+                        test_sql, 1000, database
+                    )
+                    self.assertEqual(expected_sql, limited_sql)
+        finally:
+            db.session.delete(database)
+            db.session.commit()
+
     @mock.patch.object(
         MssqlEngineSpec, "pyodbc_rows_to_tuples", return_value="converted"
     )

Reply via email to