This is an automated email from the ASF dual-hosted git repository. beto pushed a commit to branch improve-function-detection in repository https://gitbox.apache.org/repos/asf/superset.git
commit 7a64a82cd9dbd9a47ded59f14d990d26313efb1f Author: Beto Dealmeida <[email protected]> AuthorDate: Wed Apr 30 16:58:11 2025 -0400 fix: improve function detection --- superset/sql/parse.py | 40 +++++++++++++++++++++++++++++++++++++ superset/sql_parse.py | 2 +- tests/unit_tests/sql_parse_tests.py | 29 +++++++++++++++++++++++++++ 3 files changed, 70 insertions(+), 1 deletion(-) diff --git a/superset/sql/parse.py b/superset/sql/parse.py index 9e006c2809..ea293b858f 100644 --- a/superset/sql/parse.py +++ b/superset/sql/parse.py @@ -32,6 +32,7 @@ from deprecation import deprecated from sqlglot import exp from sqlglot.dialects.dialect import Dialect, Dialects from sqlglot.errors import ParseError +from sqlglot.expressions import Func from sqlglot.optimizer.pushdown_predicates import pushdown_predicates from sqlglot.optimizer.scope import Scope, ScopeType, traverse_scope @@ -453,6 +454,23 @@ class SQLStatement(BaseSQLStatement[exp.Expression]): return SQLStatement(sql, self.engine, optimized) + def check_functions_present(self, functions: set[str]) -> bool: + """ + Check if any of the given functions are present in the script. + + :param function_list: List of functions to check for + :return: True if any of the functions are present + """ + present = { + ( + function.sql_name() + if function.sql_name() != "ANONYMOUS" + else function.name.upper() + ) + for function in self._parsed.find_all(Func) + } + return any(function.upper() in present for function in functions) + class KQLSplitState(enum.Enum): """ @@ -619,6 +637,16 @@ class KustoKQLStatement(BaseSQLStatement[str]): """ return KustoKQLStatement(self._sql, self.engine, self._parsed) + def check_functions_present(self, functions: set[str]) -> bool: + """ + Check if any of the given functions are present in the script. + + :param function_list: List of functions to check for + :return: True if any of the functions are present + """ + logger.warning("Kusto KQL doesn't support checking for functions present.") + return True + class SQLScript: """ @@ -684,6 +712,18 @@ class SQLScript: return script + def check_functions_present(self, functions: set[str]) -> bool: + """ + Check if any of the given functions are present in the script. + + :param function_list: List of functions to check for + :return: True if any of the functions are present + """ + return any( + statement.check_functions_present(functions) + for statement in self.statements + ) + def extract_tables_from_statement( statement: exp.Expression, diff --git a/superset/sql_parse.py b/superset/sql_parse.py index 2cd8d0ac9d..d51cc8ac30 100644 --- a/superset/sql_parse.py +++ b/superset/sql_parse.py @@ -181,7 +181,7 @@ def check_sql_functions_exist( :param function_list: The list of functions to search for :param engine: The engine to use for parsing the SQL statement """ - return ParsedQuery(sql, engine=engine).check_functions_exist(function_list) + return SQLScript(sql, engine=engine).check_functions_present(function_list) def strip_comments_from_sql(statement: str, engine: str = "base") -> str: diff --git a/tests/unit_tests/sql_parse_tests.py b/tests/unit_tests/sql_parse_tests.py index 9c814a0f42..23aa6b0b12 100644 --- a/tests/unit_tests/sql_parse_tests.py +++ b/tests/unit_tests/sql_parse_tests.py @@ -1237,6 +1237,35 @@ def test_check_sql_functions_exist() -> None: ) +def test_check_sql_functions_exist_with_comments() -> None: + """ + Test sql functions are detected correctly with comments + """ + assert not ( + check_sql_functions_exist( + "select a, b from version/**/", {"version"}, "postgresql" + ) + ) + + assert check_sql_functions_exist("select version/**/()", {"version"}, "postgresql") + + assert check_sql_functions_exist( + "select version from version/**/()", {"version"}, "postgresql" + ) + + assert check_sql_functions_exist( + "select 1, a.version from (select version from version/**/()) as a", + {"version"}, + "postgresql", + ) + + assert check_sql_functions_exist( + "select 1, a.version from (select version/**/()) as a", + {"version"}, + "postgresql", + ) + + def test_sanitize_clause_valid(): # regular clauses assert sanitize_clause("col = 1") == "col = 1"
