This is an automated email from the ASF dual-hosted git repository.

beto pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/superset.git


The following commit(s) were added to refs/heads/master by this push:
     new 0db59b45b8 fix: adhoc metrics (#30202)
0db59b45b8 is described below

commit 0db59b45b8ef7af003c7ab4518e4eb63c08f1ff5
Author: Beto Dealmeida <[email protected]>
AuthorDate: Thu Oct 10 16:46:17 2024 -0400

    fix: adhoc metrics (#30202)
---
 superset/connectors/sqla/models.py             |  2 ++
 superset/models/helpers.py                     | 21 +++++++++--
 superset/models/sql_lab.py                     |  1 +
 superset/sql_parse.py                          | 40 +++++++--------------
 tests/integration_tests/datasource_tests.py    |  4 +++
 tests/integration_tests/query_context_tests.py |  7 +++-
 tests/unit_tests/sql_parse_tests.py            | 50 ++++++++++++++++++--------
 7 files changed, 80 insertions(+), 45 deletions(-)

diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py
index 9215e1545b..82980cef16 100644
--- a/superset/connectors/sqla/models.py
+++ b/superset/connectors/sqla/models.py
@@ -1533,6 +1533,7 @@ class SqlaTable(
                 expression = self._process_sql_expression(
                     expression=metric["sqlExpression"],
                     database_id=self.database_id,
+                    engine=self.database.backend,
                     schema=self.schema,
                     template_processor=template_processor,
                 )
@@ -1566,6 +1567,7 @@ class SqlaTable(
             expression = self._process_sql_expression(
                 expression=col["sqlExpression"],
                 database_id=self.database_id,
+                engine=self.database.backend,
                 schema=self.schema,
                 template_processor=template_processor,
             )
diff --git a/superset/models/helpers.py b/superset/models/helpers.py
index 4085d3a0aa..51808f9a46 100644
--- a/superset/models/helpers.py
+++ b/superset/models/helpers.py
@@ -63,6 +63,7 @@ from superset.exceptions import (
     ColumnNotFoundException,
     QueryClauseValidationException,
     QueryObjectValidationError,
+    SupersetParseError,
     SupersetSecurityException,
 )
 from superset.extensions import feature_flag_manager
@@ -112,6 +113,7 @@ ADVANCED_DATA_TYPES = config["ADVANCED_DATA_TYPES"]
 def validate_adhoc_subquery(
     sql: str,
     database_id: int,
+    engine: str,
     default_schema: str,
 ) -> str:
     """
@@ -126,7 +128,12 @@ def validate_adhoc_subquery(
     """
     statements = []
     for statement in sqlparse.parse(sql):
-        if has_table_query(statement):
+        try:
+            has_table = has_table_query(str(statement), engine)
+        except SupersetParseError:
+            has_table = True
+
+        if has_table:
             if not is_feature_enabled("ALLOW_ADHOC_SUBQUERY"):
                 raise SupersetSecurityException(
                     SupersetError(
@@ -135,7 +142,9 @@ def validate_adhoc_subquery(
                         level=ErrorLevel.ERROR,
                     )
                 )
+            # TODO (betodealmeida): reimplement with sqlglot
             statement = insert_rls_in_predicate(statement, database_id, default_schema)
+
         statements.append(statement)
 
     return ";\n".join(str(statement) for statement in statements)
@@ -810,10 +819,11 @@ class ExploreMixin:  # pylint: disable=too-many-public-methods
         # for datasources of type query
         return []
 
-    def _process_sql_expression(
+    def _process_sql_expression(  # pylint: disable=too-many-arguments
         self,
         expression: Optional[str],
         database_id: int,
+        engine: str,
         schema: str,
         template_processor: Optional[BaseTemplateProcessor],
     ) -> Optional[str]:
@@ -823,6 +833,7 @@ class ExploreMixin:  # pylint: disable=too-many-public-methods
             expression = validate_adhoc_subquery(
                 expression,
                 database_id,
+                engine,
                 schema,
             )
             try:
@@ -1108,6 +1119,7 @@ class ExploreMixin:  # pylint: disable=too-many-public-methods
             expression = self._process_sql_expression(
                 expression=metric["sqlExpression"],
                 database_id=self.database_id,
+                engine=self.database.backend,
                 schema=self.schema,
                 template_processor=template_processor,
             )
@@ -1551,6 +1563,7 @@ class ExploreMixin:  # pylint: disable=too-many-public-methods
                     col["sqlExpression"] = self._process_sql_expression(
                         expression=col["sqlExpression"],
                         database_id=self.database_id,
+                        engine=self.database.backend,
                         schema=self.schema,
                         template_processor=template_processor,
                     )
@@ -1613,6 +1626,7 @@ class ExploreMixin:  # pylint: disable=too-many-public-methods
                         selected = validate_adhoc_subquery(
                             selected,
                             self.database_id,
+                            self.database.backend,
                             self.schema,
                         )
                         outer = literal_column(f"({selected})")
@@ -1639,6 +1653,7 @@ class ExploreMixin:  # pylint: disable=too-many-public-methods
                 selected = validate_adhoc_subquery(
                     _sql,
                     self.database_id,
+                    self.database.backend,
                     self.schema,
                 )
 
@@ -1915,6 +1930,7 @@ class ExploreMixin:  # pylint: disable=too-many-public-methods
                 where = self._process_sql_expression(
                     expression=where,
                     database_id=self.database_id,
+                    engine=self.database.backend,
                     schema=self.schema,
                     template_processor=template_processor,
                 )
@@ -1933,6 +1949,7 @@ class ExploreMixin:  # pylint: disable=too-many-public-methods
                 having = self._process_sql_expression(
                     expression=having,
                     database_id=self.database_id,
+                    engine=self.database.backend,
                     schema=self.schema,
                     template_processor=template_processor,
                 )
diff --git a/superset/models/sql_lab.py b/superset/models/sql_lab.py
index 6f25a5a660..1702601d0f 100644
--- a/superset/models/sql_lab.py
+++ b/superset/models/sql_lab.py
@@ -374,6 +374,7 @@ class Query(
         expression = self._process_sql_expression(
             expression=col["sqlExpression"],
             database_id=self.database_id,
+            engine=self.database.backend,
             schema=self.schema,
             template_processor=template_processor,
         )
diff --git a/superset/sql_parse.py b/superset/sql_parse.py
index 1581b0c6e7..cb457cd4f5 100644
--- a/superset/sql_parse.py
+++ b/superset/sql_parse.py
@@ -64,6 +64,7 @@ from superset.sql.parse import (
     extract_tables_from_statement,
     SQLGLOT_DIALECTS,
     SQLScript,
+    SQLStatement,
     Table,
 )
 from superset.utils.backports import StrEnum
@@ -570,46 +571,31 @@ class InsertRLSState(StrEnum):
     FOUND_TABLE = "FOUND_TABLE"
 
 
-def has_table_query(token_list: TokenList) -> bool:
+def has_table_query(expression: str, engine: str) -> bool:
     """
     Return if a statement has a query reading from a table.
 
-        >>> has_table_query(sqlparse.parse("COUNT(*)")[0])
+        >>> has_table_query("COUNT(*)", "postgresql")
         False
-        >>> has_table_query(sqlparse.parse("SELECT * FROM table")[0])
+        >>> has_table_query("SELECT * FROM table", "postgresql")
         True
 
     Note that queries reading from constant values return false:
 
-        >>> has_table_query(sqlparse.parse("SELECT * FROM (SELECT 1)")[0])
+        >>> has_table_query("SELECT * FROM (SELECT 1)", "postgresql")
         False
 
     """
-    state = InsertRLSState.SCANNING
-    for token in token_list.tokens:
-        # Ignore comments
-        if isinstance(token, sqlparse.sql.Comment):
-            continue
-
-        # Recurse into child token list
-        if isinstance(token, TokenList) and has_table_query(token):
-            return True
-
-        # Found a source keyword (FROM/JOIN)
-        if imt(token, m=[(Keyword, "FROM"), (Keyword, "JOIN")]):
-            state = InsertRLSState.SEEN_SOURCE
-
-        # Found identifier/keyword after FROM/JOIN
-        elif state == InsertRLSState.SEEN_SOURCE and (
-            isinstance(token, sqlparse.sql.Identifier) or token.ttype == Keyword
-        ):
-            return True
+    # Remove trailing semicolon.
+    expression = expression.strip().rstrip(";")
 
-        # Found nothing, leaving source
-        elif state == InsertRLSState.SEEN_SOURCE and token.ttype != Whitespace:
-            state = InsertRLSState.SCANNING
+    # Wrap the expression in parentheses if it's not already.
+    if not expression.startswith("("):
+        expression = f"({expression})"
 
-    return False
+    sql = f"SELECT {expression}"
+    statement = SQLStatement(sql, engine)
+    return any(statement.tables)
 
 
 def add_table_name(rls: TokenList, table: str) -> None:
diff --git a/tests/integration_tests/datasource_tests.py b/tests/integration_tests/datasource_tests.py
index ec45c8c57e..ab13fc4daf 100644
--- a/tests/integration_tests/datasource_tests.py
+++ b/tests/integration_tests/datasource_tests.py
@@ -42,6 +42,7 @@ from superset.utils.database import (  # noqa: F401
     get_main_database,
 )
 from tests.integration_tests.base_tests import db_insert_temp_object, SupersetTestCase
+from tests.integration_tests.conftest import with_feature_flags
 from tests.integration_tests.constants import ADMIN_USERNAME
 from tests.integration_tests.fixtures.birth_names_dashboard import (
     load_birth_names_dashboard_with_slices,  # noqa: F401
@@ -585,6 +586,7 @@ def test_get_samples_with_incorrect_cc(test_client, login_as_admin, virtual_data
         assert "INCORRECT SQL" in rv.json.get("error")
 
 
+@with_feature_flags(ALLOW_ADHOC_SUBQUERY=True)
 def test_get_samples_on_physical_dataset(test_client, login_as_admin, physical_dataset):
     uri = (
         f"/datasource/samples?datasource_id={physical_dataset.id}&datasource_type=table"
@@ -649,6 +651,7 @@ def test_get_samples_with_filters(test_client, login_as_admin, virtual_dataset):
     assert rv.json["result"]["rowcount"] == 0
 
 
+@with_feature_flags(ALLOW_ADHOC_SUBQUERY=True)
 def test_get_samples_with_time_filter(test_client, login_as_admin, physical_dataset):
     uri = (
         f"/datasource/samples?datasource_id={physical_dataset.id}&datasource_type=table"
@@ -669,6 +672,7 @@ def test_get_samples_with_time_filter(test_client, login_as_admin, physical_data
     assert rv.json["result"]["total_count"] == 2
 
 
+@with_feature_flags(ALLOW_ADHOC_SUBQUERY=True)
 def test_get_samples_with_multiple_filters(
     test_client, login_as_admin, physical_dataset
 ):
diff --git a/tests/integration_tests/query_context_tests.py b/tests/integration_tests/query_context_tests.py
index 4822b690ed..d77523c71c 100644
--- a/tests/integration_tests/query_context_tests.py
+++ b/tests/integration_tests/query_context_tests.py
@@ -42,7 +42,11 @@ from superset.utils.core import (
 )
 from superset.utils.pandas_postprocessing.utils import FLAT_COLUMN_SEPARATOR
 from tests.integration_tests.base_tests import SupersetTestCase
-from tests.integration_tests.conftest import only_postgresql, only_sqlite
+from tests.integration_tests.conftest import (
+    only_postgresql,
+    only_sqlite,
+    with_feature_flags,
+)
 from tests.integration_tests.fixtures.birth_names_dashboard import (
     load_birth_names_dashboard_with_slices,  # noqa: F401
     load_birth_names_data,  # noqa: F401
@@ -858,6 +862,7 @@ def test_non_time_column_with_time_grain(app_context, physical_dataset):
     assert df["COL2 ALIAS"][0] == "a"
 
 
+@with_feature_flags(ALLOW_ADHOC_SUBQUERY=True)
 def test_special_chars_in_column_name(app_context, physical_dataset):
     qc = QueryContextFactory().create(
         datasource={
diff --git a/tests/unit_tests/sql_parse_tests.py b/tests/unit_tests/sql_parse_tests.py
index 23d51de64c..44d52c7f6e 100644
--- a/tests/unit_tests/sql_parse_tests.py
+++ b/tests/unit_tests/sql_parse_tests.py
@@ -1286,46 +1286,66 @@ def test_sqlparse_issue_652():
 
 
 @pytest.mark.parametrize(
-    "sql,expected",
+    ("engine", "sql", "expected"),
     [
-        ("SELECT * FROM table", True),
-        ("SELECT a FROM (SELECT 1 AS a) JOIN (SELECT * FROM table)", True),
-        ("(SELECT COUNT(DISTINCT name) AS foo FROM    birth_names)", True),
-        ("COUNT(*)", False),
-        ("SELECT a FROM (SELECT 1 AS a)", False),
-        ("SELECT a FROM (SELECT 1 AS a) JOIN table", True),
-        ("SELECT * FROM (SELECT 1 AS foo, 2 AS bar) ORDER BY foo ASC, bar", False),
-        ("SELECT * FROM other_table", True),
-        ("extract(HOUR from from_unixtime(hour_ts)", False),
-        ("(SELECT * FROM table)", True),
-        ("(SELECT COUNT(DISTINCT name) from birth_names)", True),
+        ("postgresql", "extract(HOUR from from_unixtime(hour_ts))", False),
+        ("postgresql", "SELECT * FROM table", True),
+        ("postgresql", "(SELECT * FROM table)", True),
         (
+            "postgresql",
+            "SELECT a FROM (SELECT 1 AS a) JOIN (SELECT * FROM table)",
+            True,
+        ),
+        (
+            "postgresql",
+            "(SELECT COUNT(DISTINCT name) AS foo FROM    birth_names)",
+            True,
+        ),
+        ("postgresql", "COUNT(*)", False),
+        ("postgresql", "SELECT a FROM (SELECT 1 AS a)", False),
+        ("postgresql", "SELECT a FROM (SELECT 1 AS a) JOIN table", True),
+        (
+            "postgresql",
+            "SELECT * FROM (SELECT 1 AS foo, 2 AS bar) ORDER BY foo ASC, bar",
+            False,
+        ),
+        ("postgresql", "SELECT * FROM other_table", True),
+        ("postgresql", "(SELECT COUNT(DISTINCT name) from birth_names)", True),
+        (
+            "postgresql",
             "(SELECT table_name FROM information_schema.tables WHERE 
table_name LIKE '%user%' LIMIT 1)",
             True,
         ),
         (
+            "postgresql",
             "(SELECT table_name FROM /**/ information_schema.tables WHERE table_name LIKE '%user%' LIMIT 1)",
             True,
         ),
         (
+            "postgresql",
             "SELECT FROM (SELECT FROM forbidden_table) AS forbidden_table;",
             True,
         ),
         (
+            "postgresql",
             "SELECT * FROM (SELECT * FROM forbidden_table) forbidden_table",
             True,
         ),
+        (
+            "postgresql",
+            "((select users.id from (select 'majorie' as a) b, users where b.a = users.name and users.name in ('majorie') limit 1) like 'U%')",
+            True,
+        ),
     ],
 )
-def test_has_table_query(sql: str, expected: bool) -> None:
+def test_has_table_query(engine: str, sql: str, expected: bool) -> None:
     """
     Test if a given statement queries a table.
 
     This is used to prevent ad-hoc metrics from querying unauthorized tables, bypassing
     row-level security.
     """
-    statement = sqlparse.parse(sql)[0]
-    assert has_table_query(statement) == expected
+    assert has_table_query(sql, engine) == expected
 
 
 @pytest.mark.parametrize(

Reply via email to