This is an automated email from the ASF dual-hosted git repository.
beto pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/superset.git
The following commit(s) were added to refs/heads/master by this push:
new 0db59b45b8 fix: adhoc metrics (#30202)
0db59b45b8 is described below
commit 0db59b45b8ef7af003c7ab4518e4eb63c08f1ff5
Author: Beto Dealmeida <[email protected]>
AuthorDate: Thu Oct 10 16:46:17 2024 -0400
fix: adhoc metrics (#30202)
---
superset/connectors/sqla/models.py | 2 ++
superset/models/helpers.py | 21 +++++++++--
superset/models/sql_lab.py | 1 +
superset/sql_parse.py | 40 +++++++--------------
tests/integration_tests/datasource_tests.py | 4 +++
tests/integration_tests/query_context_tests.py | 7 +++-
tests/unit_tests/sql_parse_tests.py | 50 ++++++++++++++++++--------
7 files changed, 80 insertions(+), 45 deletions(-)
diff --git a/superset/connectors/sqla/models.py
b/superset/connectors/sqla/models.py
index 9215e1545b..82980cef16 100644
--- a/superset/connectors/sqla/models.py
+++ b/superset/connectors/sqla/models.py
@@ -1533,6 +1533,7 @@ class SqlaTable(
expression = self._process_sql_expression(
expression=metric["sqlExpression"],
database_id=self.database_id,
+ engine=self.database.backend,
schema=self.schema,
template_processor=template_processor,
)
@@ -1566,6 +1567,7 @@ class SqlaTable(
expression = self._process_sql_expression(
expression=col["sqlExpression"],
database_id=self.database_id,
+ engine=self.database.backend,
schema=self.schema,
template_processor=template_processor,
)
diff --git a/superset/models/helpers.py b/superset/models/helpers.py
index 4085d3a0aa..51808f9a46 100644
--- a/superset/models/helpers.py
+++ b/superset/models/helpers.py
@@ -63,6 +63,7 @@ from superset.exceptions import (
ColumnNotFoundException,
QueryClauseValidationException,
QueryObjectValidationError,
+ SupersetParseError,
SupersetSecurityException,
)
from superset.extensions import feature_flag_manager
@@ -112,6 +113,7 @@ ADVANCED_DATA_TYPES = config["ADVANCED_DATA_TYPES"]
def validate_adhoc_subquery(
sql: str,
database_id: int,
+ engine: str,
default_schema: str,
) -> str:
"""
@@ -126,7 +128,12 @@ def validate_adhoc_subquery(
"""
statements = []
for statement in sqlparse.parse(sql):
- if has_table_query(statement):
+ try:
+ has_table = has_table_query(str(statement), engine)
+ except SupersetParseError:
+ has_table = True
+
+ if has_table:
if not is_feature_enabled("ALLOW_ADHOC_SUBQUERY"):
raise SupersetSecurityException(
SupersetError(
@@ -135,7 +142,9 @@ def validate_adhoc_subquery(
level=ErrorLevel.ERROR,
)
)
+ # TODO (betodealmeida): reimplement with sqlglot
statement = insert_rls_in_predicate(statement, database_id,
default_schema)
+
statements.append(statement)
return ";\n".join(str(statement) for statement in statements)
@@ -810,10 +819,11 @@ class ExploreMixin: # pylint:
disable=too-many-public-methods
# for datasources of type query
return []
- def _process_sql_expression(
+ def _process_sql_expression( # pylint: disable=too-many-arguments
self,
expression: Optional[str],
database_id: int,
+ engine: str,
schema: str,
template_processor: Optional[BaseTemplateProcessor],
) -> Optional[str]:
@@ -823,6 +833,7 @@ class ExploreMixin: # pylint:
disable=too-many-public-methods
expression = validate_adhoc_subquery(
expression,
database_id,
+ engine,
schema,
)
try:
@@ -1108,6 +1119,7 @@ class ExploreMixin: # pylint:
disable=too-many-public-methods
expression = self._process_sql_expression(
expression=metric["sqlExpression"],
database_id=self.database_id,
+ engine=self.database.backend,
schema=self.schema,
template_processor=template_processor,
)
@@ -1551,6 +1563,7 @@ class ExploreMixin: # pylint:
disable=too-many-public-methods
col["sqlExpression"] = self._process_sql_expression(
expression=col["sqlExpression"],
database_id=self.database_id,
+ engine=self.database.backend,
schema=self.schema,
template_processor=template_processor,
)
@@ -1613,6 +1626,7 @@ class ExploreMixin: # pylint:
disable=too-many-public-methods
selected = validate_adhoc_subquery(
selected,
self.database_id,
+ self.database.backend,
self.schema,
)
outer = literal_column(f"({selected})")
@@ -1639,6 +1653,7 @@ class ExploreMixin: # pylint:
disable=too-many-public-methods
selected = validate_adhoc_subquery(
_sql,
self.database_id,
+ self.database.backend,
self.schema,
)
@@ -1915,6 +1930,7 @@ class ExploreMixin: # pylint:
disable=too-many-public-methods
where = self._process_sql_expression(
expression=where,
database_id=self.database_id,
+ engine=self.database.backend,
schema=self.schema,
template_processor=template_processor,
)
@@ -1933,6 +1949,7 @@ class ExploreMixin: # pylint:
disable=too-many-public-methods
having = self._process_sql_expression(
expression=having,
database_id=self.database_id,
+ engine=self.database.backend,
schema=self.schema,
template_processor=template_processor,
)
diff --git a/superset/models/sql_lab.py b/superset/models/sql_lab.py
index 6f25a5a660..1702601d0f 100644
--- a/superset/models/sql_lab.py
+++ b/superset/models/sql_lab.py
@@ -374,6 +374,7 @@ class Query(
expression = self._process_sql_expression(
expression=col["sqlExpression"],
database_id=self.database_id,
+ engine=self.database.backend,
schema=self.schema,
template_processor=template_processor,
)
diff --git a/superset/sql_parse.py b/superset/sql_parse.py
index 1581b0c6e7..cb457cd4f5 100644
--- a/superset/sql_parse.py
+++ b/superset/sql_parse.py
@@ -64,6 +64,7 @@ from superset.sql.parse import (
extract_tables_from_statement,
SQLGLOT_DIALECTS,
SQLScript,
+ SQLStatement,
Table,
)
from superset.utils.backports import StrEnum
@@ -570,46 +571,31 @@ class InsertRLSState(StrEnum):
FOUND_TABLE = "FOUND_TABLE"
-def has_table_query(token_list: TokenList) -> bool:
+def has_table_query(expression: str, engine: str) -> bool:
"""
Return if a statement has a query reading from a table.
- >>> has_table_query(sqlparse.parse("COUNT(*)")[0])
+ >>> has_table_query("COUNT(*)", "postgresql")
False
- >>> has_table_query(sqlparse.parse("SELECT * FROM table")[0])
+ >>> has_table_query("SELECT * FROM table", "postgresql")
True
Note that queries reading from constant values return false:
- >>> has_table_query(sqlparse.parse("SELECT * FROM (SELECT 1)")[0])
+ >>> has_table_query("SELECT * FROM (SELECT 1)", "postgresql")
False
"""
- state = InsertRLSState.SCANNING
- for token in token_list.tokens:
- # Ignore comments
- if isinstance(token, sqlparse.sql.Comment):
- continue
-
- # Recurse into child token list
- if isinstance(token, TokenList) and has_table_query(token):
- return True
-
- # Found a source keyword (FROM/JOIN)
- if imt(token, m=[(Keyword, "FROM"), (Keyword, "JOIN")]):
- state = InsertRLSState.SEEN_SOURCE
-
- # Found identifier/keyword after FROM/JOIN
- elif state == InsertRLSState.SEEN_SOURCE and (
- isinstance(token, sqlparse.sql.Identifier) or token.ttype ==
Keyword
- ):
- return True
+ # Remove trailing semicolon.
+ expression = expression.strip().rstrip(";")
- # Found nothing, leaving source
- elif state == InsertRLSState.SEEN_SOURCE and token.ttype != Whitespace:
- state = InsertRLSState.SCANNING
+ # Wrap the expression in parentheses if it's not already.
+ if not expression.startswith("("):
+ expression = f"({expression})"
- return False
+ sql = f"SELECT {expression}"
+ statement = SQLStatement(sql, engine)
+ return any(statement.tables)
def add_table_name(rls: TokenList, table: str) -> None:
diff --git a/tests/integration_tests/datasource_tests.py
b/tests/integration_tests/datasource_tests.py
index ec45c8c57e..ab13fc4daf 100644
--- a/tests/integration_tests/datasource_tests.py
+++ b/tests/integration_tests/datasource_tests.py
@@ -42,6 +42,7 @@ from superset.utils.database import ( # noqa: F401
get_main_database,
)
from tests.integration_tests.base_tests import db_insert_temp_object,
SupersetTestCase
+from tests.integration_tests.conftest import with_feature_flags
from tests.integration_tests.constants import ADMIN_USERNAME
from tests.integration_tests.fixtures.birth_names_dashboard import (
load_birth_names_dashboard_with_slices, # noqa: F401
@@ -585,6 +586,7 @@ def test_get_samples_with_incorrect_cc(test_client,
login_as_admin, virtual_data
assert "INCORRECT SQL" in rv.json.get("error")
+@with_feature_flags(ALLOW_ADHOC_SUBQUERY=True)
def test_get_samples_on_physical_dataset(test_client, login_as_admin,
physical_dataset):
uri = (
f"/datasource/samples?datasource_id={physical_dataset.id}&datasource_type=table"
@@ -649,6 +651,7 @@ def test_get_samples_with_filters(test_client,
login_as_admin, virtual_dataset):
assert rv.json["result"]["rowcount"] == 0
+@with_feature_flags(ALLOW_ADHOC_SUBQUERY=True)
def test_get_samples_with_time_filter(test_client, login_as_admin,
physical_dataset):
uri = (
f"/datasource/samples?datasource_id={physical_dataset.id}&datasource_type=table"
@@ -669,6 +672,7 @@ def test_get_samples_with_time_filter(test_client,
login_as_admin, physical_data
assert rv.json["result"]["total_count"] == 2
+@with_feature_flags(ALLOW_ADHOC_SUBQUERY=True)
def test_get_samples_with_multiple_filters(
test_client, login_as_admin, physical_dataset
):
diff --git a/tests/integration_tests/query_context_tests.py
b/tests/integration_tests/query_context_tests.py
index 4822b690ed..d77523c71c 100644
--- a/tests/integration_tests/query_context_tests.py
+++ b/tests/integration_tests/query_context_tests.py
@@ -42,7 +42,11 @@ from superset.utils.core import (
)
from superset.utils.pandas_postprocessing.utils import FLAT_COLUMN_SEPARATOR
from tests.integration_tests.base_tests import SupersetTestCase
-from tests.integration_tests.conftest import only_postgresql, only_sqlite
+from tests.integration_tests.conftest import (
+ only_postgresql,
+ only_sqlite,
+ with_feature_flags,
+)
from tests.integration_tests.fixtures.birth_names_dashboard import (
load_birth_names_dashboard_with_slices, # noqa: F401
load_birth_names_data, # noqa: F401
@@ -858,6 +862,7 @@ def test_non_time_column_with_time_grain(app_context,
physical_dataset):
assert df["COL2 ALIAS"][0] == "a"
+@with_feature_flags(ALLOW_ADHOC_SUBQUERY=True)
def test_special_chars_in_column_name(app_context, physical_dataset):
qc = QueryContextFactory().create(
datasource={
diff --git a/tests/unit_tests/sql_parse_tests.py
b/tests/unit_tests/sql_parse_tests.py
index 23d51de64c..44d52c7f6e 100644
--- a/tests/unit_tests/sql_parse_tests.py
+++ b/tests/unit_tests/sql_parse_tests.py
@@ -1286,46 +1286,66 @@ def test_sqlparse_issue_652():
@pytest.mark.parametrize(
- "sql,expected",
+ ("engine", "sql", "expected"),
[
- ("SELECT * FROM table", True),
- ("SELECT a FROM (SELECT 1 AS a) JOIN (SELECT * FROM table)", True),
- ("(SELECT COUNT(DISTINCT name) AS foo FROM birth_names)", True),
- ("COUNT(*)", False),
- ("SELECT a FROM (SELECT 1 AS a)", False),
- ("SELECT a FROM (SELECT 1 AS a) JOIN table", True),
- ("SELECT * FROM (SELECT 1 AS foo, 2 AS bar) ORDER BY foo ASC, bar",
False),
- ("SELECT * FROM other_table", True),
- ("extract(HOUR from from_unixtime(hour_ts)", False),
- ("(SELECT * FROM table)", True),
- ("(SELECT COUNT(DISTINCT name) from birth_names)", True),
+ ("postgresql", "extract(HOUR from from_unixtime(hour_ts))", False),
+ ("postgresql", "SELECT * FROM table", True),
+ ("postgresql", "(SELECT * FROM table)", True),
(
+ "postgresql",
+ "SELECT a FROM (SELECT 1 AS a) JOIN (SELECT * FROM table)",
+ True,
+ ),
+ (
+ "postgresql",
+ "(SELECT COUNT(DISTINCT name) AS foo FROM birth_names)",
+ True,
+ ),
+ ("postgresql", "COUNT(*)", False),
+ ("postgresql", "SELECT a FROM (SELECT 1 AS a)", False),
+ ("postgresql", "SELECT a FROM (SELECT 1 AS a) JOIN table", True),
+ (
+ "postgresql",
+ "SELECT * FROM (SELECT 1 AS foo, 2 AS bar) ORDER BY foo ASC, bar",
+ False,
+ ),
+ ("postgresql", "SELECT * FROM other_table", True),
+ ("postgresql", "(SELECT COUNT(DISTINCT name) from birth_names)", True),
+ (
+ "postgresql",
"(SELECT table_name FROM information_schema.tables WHERE
table_name LIKE '%user%' LIMIT 1)",
True,
),
(
+ "postgresql",
"(SELECT table_name FROM /**/ information_schema.tables WHERE
table_name LIKE '%user%' LIMIT 1)",
True,
),
(
+ "postgresql",
"SELECT FROM (SELECT FROM forbidden_table) AS forbidden_table;",
True,
),
(
+ "postgresql",
"SELECT * FROM (SELECT * FROM forbidden_table) forbidden_table",
True,
),
+ (
+ "postgresql",
+ "((select users.id from (select 'majorie' as a) b, users where b.a
= users.name and users.name in ('majorie') limit 1) like 'U%')",
+ True,
+ ),
],
)
-def test_has_table_query(sql: str, expected: bool) -> None:
+def test_has_table_query(engine: str, sql: str, expected: bool) -> None:
"""
Test if a given statement queries a table.
This is used to prevent ad-hoc metrics from querying unauthorized tables,
bypassing
row-level security.
"""
- statement = sqlparse.parse(sql)[0]
- assert has_table_query(statement) == expected
+ assert has_table_query(sql, engine) == expected
@pytest.mark.parametrize(