This is an automated email from the ASF dual-hosted git repository.
maximebeauchemin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-superset.git
The following commit(s) were added to refs/heads/master by this push:
new 5a40f71 [sql lab] improve table name detection in free form SQL (#6793)
5a40f71 is described below
commit 5a40f7171079280c1c7d452e6f1344156b24a409
Author: Maxime Beauchemin <[email protected]>
AuthorDate: Mon Feb 4 16:03:23 2019 -0800
[sql lab] improve table name detection in free form SQL (#6793)
* [sql lab] improve table name detection in free form SQL
* flake
* Addressing comments
---
superset/sql_parse.py | 67 +++++++++++++++++++++---------------------------
tests/sql_parse_tests.py | 42 +++++++++++++++++++++++++++++-
2 files changed, 70 insertions(+), 39 deletions(-)
diff --git a/superset/sql_parse.py b/superset/sql_parse.py
index 241917a..d1ad23d 100644
--- a/superset/sql_parse.py
+++ b/superset/sql_parse.py
@@ -23,7 +23,10 @@ from sqlparse.tokens import Keyword, Name
RESULT_OPERATIONS = {'UNION', 'INTERSECT', 'EXCEPT', 'SELECT'}
ON_KEYWORD = 'ON'
-PRECEDES_TABLE_NAME = {'FROM', 'JOIN', 'DESC', 'DESCRIBE', 'WITH'}
+PRECEDES_TABLE_NAME = {  # keywords followed by table names
+ 'FROM', 'JOIN', 'DESCRIBE', 'WITH', 'LEFT JOIN', 'RIGHT JOIN',
+}
+CTE_PREFIX = 'CTE__'  # prefix marking CTE aliases, not tables
class ParsedQuery(object):
@@ -72,13 +75,6 @@ class ParsedQuery(object):
return statements
@staticmethod
- def __precedes_table_name(token_value):
- for keyword in PRECEDES_TABLE_NAME:
- if keyword in token_value:
- return True
- return False
-
- @staticmethod
def __get_full_name(identifier):
if len(identifier.tokens) > 1 and identifier.tokens[1].value == '.':
return '{}.{}'.format(identifier.tokens[0].value,
@@ -86,20 +82,15 @@ class ParsedQuery(object):
return identifier.get_real_name()
@staticmethod
- def __is_result_operation(keyword):
- for operation in RESULT_OPERATIONS:
- if operation in keyword.upper():
- return True
- return False
-
- @staticmethod
def __is_identifier(token):
return isinstance(token, (IdentifierList, Identifier))
def __process_identifier(self, identifier):
# exclude subselects
- if '(' not in '{}'.format(identifier):
- self._table_names.add(self.__get_full_name(identifier))
+ if '(' not in str(identifier):
+ table_name = self.__get_full_name(identifier)
+ if not table_name.startswith(CTE_PREFIX):
+ self._table_names.add(table_name)
return
# store aliases
@@ -129,39 +120,39 @@ class ParsedQuery(object):
exec_sql += f'CREATE TABLE {table_name} AS \n{sql}'
return exec_sql
- def __extract_from_token(self, token):
+ def __extract_from_token(self, token, depth=0):
if not hasattr(token, 'tokens'):
return
table_name_preceding_token = False
for item in token.tokens:
+ logging.debug('%s%s%s', ' ' * depth, item.ttype, item.value)
if item.is_group and not self.__is_identifier(item):
- self.__extract_from_token(item)
+ self.__extract_from_token(item, depth=depth + 1)
+
+ if (
+ item.ttype in Keyword and (
+ item.normalized in PRECEDES_TABLE_NAME or
+ item.normalized.endswith(' JOIN')
+ )):
+ table_name_preceding_token = True
+ continue
if item.ttype in Keyword:
- if self.__precedes_table_name(item.value.upper()):
- table_name_preceding_token = True
- continue
-
- if not table_name_preceding_token:
+ table_name_preceding_token = False
continue
- if item.ttype in Keyword or item.value == ',':
- if (self.__is_result_operation(item.value) or
- item.value.upper() == ON_KEYWORD):
- table_name_preceding_token = False
- continue
- # FROM clause is over
- break
-
- if isinstance(item, Identifier):
- self.__process_identifier(item)
-
- if isinstance(item, IdentifierList):
- for token in item.tokens:
- if self.__is_identifier(token):
+ if table_name_preceding_token:
+ if isinstance(item, Identifier):
+ self.__process_identifier(item)
+ elif isinstance(item, IdentifierList):
+ for token in item.get_identifiers():
self.__process_identifier(token)
+ elif isinstance(item, IdentifierList):
+ for child in item.tokens:
+ if not self.__is_identifier(child):
+ self.__extract_from_token(child, depth=depth + 1)
def _get_limit_from_token(self, token):
if token.ttype == sqlparse.tokens.Literal.Number.Integer:
diff --git a/tests/sql_parse_tests.py b/tests/sql_parse_tests.py
index 9247780..e821fce 100644
--- a/tests/sql_parse_tests.py
+++ b/tests/sql_parse_tests.py
@@ -167,7 +167,6 @@ class SupersetTestCase(unittest.TestCase):
# DESCRIBE | DESC qualifiedName
def test_describe(self):
self.assertEquals({'t1'}, self.extract_tables('DESCRIBE t1'))
- self.assertEquals({'t1'}, self.extract_tables('DESC t1'))
# SHOW PARTITIONS FROM qualifiedName (WHERE booleanExpression)?
# (ORDER BY sortItem (',' sortItem)*)? (LIMIT limit=(INTEGER_VALUE | ALL))?
@@ -349,6 +348,32 @@ class SupersetTestCase(unittest.TestCase):
{'table_a', 'table_b', 'table_c'},
self.extract_tables(query))
+ def test_mixed_from_clause(self):
+ query = """SELECT *
+ FROM table_a AS a, (select * from table_b) AS b, table_c as c
+ WHERE a.id = b.id and b.id = c.id"""
+ self.assertEqual(
+ {'table_a', 'table_b', 'table_c'},
+ self.extract_tables(query))
+
+ def test_nested_selects(self):
+ query = """
+ select (extractvalue(1,concat(0x7e,(select GROUP_CONCAT(TABLE_NAME)
+ from INFORMATION_SCHEMA.COLUMNS
+ WHERE TABLE_SCHEMA like "%bi%"),0x7e)));
+ """
+ self.assertEqual(
+ {'INFORMATION_SCHEMA.COLUMNS'},
+ self.extract_tables(query))
+ query = """
+ select (extractvalue(1,concat(0x7e,(select GROUP_CONCAT(COLUMN_NAME)
+ from INFORMATION_SCHEMA.COLUMNS
+ WHERE TABLE_NAME="bi_achivement_daily"),0x7e)));
+ """
+ self.assertEqual(
+ {'INFORMATION_SCHEMA.COLUMNS'},
+ self.extract_tables(query))
+
def test_complex_extract_tables3(self):
query = """SELECT somecol AS somecol
FROM
@@ -386,6 +411,21 @@ class SupersetTestCase(unittest.TestCase):
{'a', 'b', 'c', 'd', 'e', 'f'},
self.extract_tables(query))
+ def test_complex_cte_with_prefix(self):
+ query = """
+ WITH CTE__test (SalesPersonID, SalesOrderID, SalesYear)
+ AS (
+ SELECT SalesPersonID, SalesOrderID, YEAR(OrderDate) AS SalesYear
+ FROM SalesOrderHeader
+ WHERE SalesPersonID IS NOT NULL
+ )
+ SELECT SalesPersonID, COUNT(SalesOrderID) AS TotalSales, SalesYear
+ FROM CTE__test
+ GROUP BY SalesYear, SalesPersonID
+ ORDER BY SalesPersonID, SalesYear;
+ """
+ self.assertEqual({'SalesOrderHeader'}, self.extract_tables(query))
+
def test_basic_breakdown_statements(self):
multi_sql = """
SELECT * FROM ab_user;