This is an automated email from the ASF dual-hosted git repository. beto pushed a commit to branch snowflake-semantic-layer in repository https://gitbox.apache.org/repos/asf/superset.git
commit 57ca29baf416945933b1e3ce0b52df86c650fd77 Author: Beto Dealmeida <[email protected]> AuthorDate: Wed Jul 30 16:08:32 2025 -0400 Working --- superset/connectors/sqla/models.py | 8 +- superset/daos/dataset.py | 11 ++- superset/db_engine_specs/base.py | 24 +++++- superset/db_engine_specs/snowflake.py | 141 +++++++++++++++++++++++++++------- superset/extensions/semantic_layer.py | 4 +- superset/models/core.py | 21 +++-- 6 files changed, 161 insertions(+), 48 deletions(-) diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index 5bbafda72d..4798c9970f 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -164,7 +164,9 @@ class DatasourceKind(StrEnum): PHYSICAL = "physical" -class BaseDatasource(AuditMixinNullable, ImportExportMixin): # pylint: disable=too-many-public-methods +class BaseDatasource( + AuditMixinNullable, ImportExportMixin +): # pylint: disable=too-many-public-methods """A common interface to objects that are queryable (tables and datasources)""" @@ -1778,7 +1780,9 @@ class SqlaTable( def default_query(qry: Query) -> Query: return qry.filter_by(is_sqllab_view=False) - def has_extra_cache_key_calls(self, query_obj: QueryObjectDict) -> bool: # noqa: C901 + def has_extra_cache_key_calls( + self, query_obj: QueryObjectDict + ) -> bool: # noqa: C901 """ Detects the presence of calls to `ExtraCache` methods in items in query_obj that can be templated. If any are present, the query must be evaluated to extract diff --git a/superset/daos/dataset.py b/superset/daos/dataset.py index 5fbd012801..9221af7575 100644 --- a/superset/daos/dataset.py +++ b/superset/daos/dataset.py @@ -75,12 +75,11 @@ class DatasetDAO(BaseDAO[SqlaTable]): database: Database, table: Table, ) -> bool: - try: - database.get_table(table) - return True - except SQLAlchemyError as ex: # pragma: no cover - logger.warning("Got an error %s validating table: %s", str(ex), table) - return False + with database.get_inspector( + catalog=table.catalog, + schema=table.schema, + ) as inspector: + return database.db_engine_spec.has_table(database, inspector, table) @staticmethod def validate_uniqueness( diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index f7594f3844..71ea3377e4 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -1491,6 +1491,21 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods return tables + @classmethod + def has_table( + cls, + database: Database, + inspector: Inspector, + table: Table, + ) -> bool: + if cls.semantic_layer: + semantic_layer = cls.semantic_layer(inspector.engine) + semantic_views = semantic_layer.get_semantic_views() + if table.table in {semantic_view.name for semantic_view in semantic_views}: + return True + + return inspector.has_table(table.table, table.schema) + @classmethod def get_view_names( # pylint: disable=unused-argument cls, @@ -1564,6 +1579,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods @classmethod def get_columns( # pylint: disable=unused-argument cls, + database: Database, inspector: Inspector, table: Table, options: dict[str, Any] | None = None, @@ -1588,11 +1604,15 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods for semantic_view in semantic_layer.get_semantic_views() } if semantic_view := semantic_views.get(table.table): + dialect = database.get_dialect() return [ { "name": dimension.name, "column_name": dimension.name, - "type": get_sqla_type_from_dimension_type(dimension.type), + "type": cls.column_datatype_to_string( + get_sqla_type_from_dimension_type(dimension.type), + dialect, + ), } for dimension in semantic_layer.get_dimensions(semantic_view) ] @@ -1962,7 +1982,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods :return: """ if cls.semantic_layer: - with database.get_engine() as engine: + with cls.get_engine(database, schema="tpcds_sf10tcl") as engine: semantic_layer = cls.semantic_layer(engine) query = semantic_layer.get_query_from_standard_sql(query).sql diff --git a/superset/db_engine_specs/snowflake.py b/superset/db_engine_specs/snowflake.py index 27de00bb66..2056e7b4bc 100644 --- a/superset/db_engine_specs/snowflake.py +++ b/superset/db_engine_specs/snowflake.py @@ -33,9 +33,10 @@ from flask import current_app from flask_babel import gettext as __ from marshmallow import fields, Schema from sqlalchemy import text, types +from sqlalchemy.engine.interfaces import Dialect from sqlalchemy.engine.reflection import Inspector from sqlalchemy.engine.url import URL -from sqlglot import exp +from sqlglot import exp, parse_one from superset.constants import TimeGrain from superset.databases.utils import make_url_safe @@ -124,23 +125,23 @@ class SnowflakeSemanticLayer: SHOW SEMANTIC VIEWS ->> SELECT "name" FROM $1; """ - return {row["name"] for row in self.execute(sql)} + return {SemanticView(row["name"]) for row in self.execute(sql)} def get_type(self, snowflake_type: str | None) -> type[SemanticType]: if snowflake_type is None: return STRING type_map = { - STRING: {r"VARCHAR\(\d+\)", "STRING", "TEXT", r"CHAR\(\d+\)"}, - INTEGER: {r"NUMBER\(38,\s0\)", "INT", "INTEGER", "BIGINT"}, - DECIMAL: {r"NUMBER\(10,\s2\)"}, - NUMBER: {r"NUMBER\(\d+,\s\d+)", "FLOAT", "DOUBLE"}, - BOOLEAN: {"BOOLEAN"}, - DATE: {"DATE"}, - DATETIME: {"TIMESTAMP_TZ", "TIMESTAMP__NTZ"}, - TIME: {"TIME"}, - OBJECT: {"OBJECT"}, - BINARY: {r"BINARY\(\d+\)", r"VARBINARY\(\d+\)"}, + STRING: {r"VARCHAR\(\d+\)$", "STRING$", "TEXT$", r"CHAR\(\d+\)$"}, + INTEGER: {r"NUMBER\(38,\s?0\)$", "INT$", "INTEGER$", "BIGINT$"}, + DECIMAL: {r"NUMBER\(10,\s?2\)$"}, + NUMBER: {r"NUMBER\(\d+,\s?\d+\)$", "FLOAT$", "DOUBLE$"}, + BOOLEAN: {"BOOLEAN$"}, + DATE: {"DATE$"}, + DATETIME: {"TIMESTAMP_TZ$", "TIMESTAMP__NTZ$"}, + TIME: {"TIME$"}, + OBJECT: {"OBJECT$"}, + BINARY: {r"BINARY\(\d+\)$", r"VARBINARY\(\d+\)$"}, } for semantic_type, patterns in type_map.items(): if any( @@ -150,6 +151,23 @@ SHOW SEMANTIC VIEWS return STRING + @classmethod + def quote_table(cls, table: Table, dialect: Dialect) -> str: + """ + Fully quote a table name, including the schema and catalog. + """ + quoters = { + "catalog": dialect.identifier_preparer.quote_schema, + "schema": dialect.identifier_preparer.quote_schema, + "table": dialect.identifier_preparer.quote, + } + + return ".".join( + function(getattr(table, key)) + for key, function in quoters.items() + if getattr(table, key) + ) + def get_metrics(self, semantic_view: SemanticView) -> set[SemanticMetric]: quoted_semantic_view_name = self.quote_table( Table(semantic_view.name), @@ -161,7 +179,7 @@ DESC SEMANTIC VIEW {quoted_semantic_view_name} FROM $1 WHERE "object_kind" = 'METRIC' AND - "property" IN ('EXPRESSION', 'DATA_TYPE', 'TABLE'); + "property" IN ('DATA_TYPE', 'TABLE'); """ # noqa: S608 (semantic_view.name is quoted) rows = self.execute(sql) @@ -171,9 +189,10 @@ DESC SEMANTIC VIEW {quoted_semantic_view_name} for row in group: attributes[row["property"]].add(row["property_value"]) - metric_name = attributes["TABLE"] + "." + name + table = next(iter(attributes["TABLE"])) + metric_name = table + "." + name type_ = self.get_type(next(iter(attributes["DATA_TYPE"]))) - sql = next(iter(attributes["EXPRESSION"]), name) + sql = self.engine.dialect.identifier_preparer.quote(metric_name) tables = frozenset(attributes["TABLE"]) join_columns = frozenset() @@ -261,7 +280,7 @@ DESC SEMANTIC VIEW {quoted_semantic_view_name} offset: int | None = None, ) -> exp.Select: semantic_view = exp.SemanticView( - this=Table(this=exp.Identifier(this=semantic_view.name, quoted=True)), + this=exp.Table(this=exp.Identifier(this=semantic_view.name, quoted=True)), dimensions=[ exp.Column( this=exp.Identifier(this=dimension.column.name, quoted=True), @@ -274,8 +293,8 @@ DESC SEMANTIC VIEW {quoted_semantic_view_name} ], metrics=[ exp.Column( - this=exp.Identifier(this=table, quoted=True), - table=exp.Identifier(this=column, quoted=True), + this=exp.Identifier(this=column, quoted=True), + table=exp.Identifier(this=table, quoted=True), ) for table, column in ( metric.name.split(".", 1) @@ -283,14 +302,13 @@ DESC SEMANTIC VIEW {quoted_semantic_view_name} if "." in metric.name ) ], - # where= XXX push predicates ) query = exp.Select( expressions=[exp.Star()], - **{"from": exp.From(this=exp.Table(semantic_view))}, + **{"from": exp.From(this=exp.Table(this=semantic_view))}, ) - if sort: + if sort.items: order = [ exp.Ordered( this=exp.Column(this=exp.Identifier(this=item.field.name)), @@ -309,14 +327,81 @@ DESC SEMANTIC VIEW {quoted_semantic_view_name} return query - def get_query_from_standard_sql( - self, - semantic_view: SemanticView, - sql: str, - ) -> Query: - statement = SQLStatement(sql, "snowflake") - # check if any of the tables in the statement is a semantic view + def get_query_from_standard_sql(self, sql: str) -> SemanticQuery: + """ + Convert the Explore query into a proper query. + + Explore will produce a pseudo-SQL query that references metrics and dimensions + as if they were columns in a table. This method replaces the table name with a + call to `SEMANTIC_VIEW`, and removes the `GROUP BY` clause, since all the + aggregations happen inside the `SEMANTIC_VIEW` call. + """ + ast = parse_one(sql, "snowflake") + table = ast.find(exp.Table) + if not table: + return SemanticQuery(sql=sql) + + semantic_views = self.get_semantic_views() + if table.name not in {semantic_view.name for semantic_view in semantic_views}: + return SemanticQuery(sql=sql) + + # collect all metric and dimensions + semantic_view = SemanticView(table.name) + all_metrics = self.get_metrics(semantic_view) + all_dimensions = self.get_dimensions(semantic_view) + + # collect metrics and dimensions used in the query + columns = {column.name for column in ast.find_all(exp.Column)} + metrics = [metric for metric in all_metrics if metric.name in columns] + dimensions = [ + dimension for dimension in all_dimensions if dimension.name in columns + ] + + # now replace table with a call to `SEMANTIC_VIEW` + udtf = exp.Table( + this=exp.SemanticView( + this=exp.Table( + this=exp.Identifier(this=semantic_view.name, quoted=True) + ), + metrics=[ + exp.Column( + this=exp.Identifier(this=column, quoted=True), + table=exp.Identifier(this=table, quoted=True), + ) + for table, column in ( + metric.name.split(".", 1) + for metric in metrics + if "." in metric.name + ) + ], + dimensions=[ + exp.Column( + this=exp.Identifier(this=column, quoted=True), + table=exp.Identifier(this=table, quoted=True), + ) + for table, column in ( + dimension.name.split(".", 1) + for dimension in dimensions + if "." in dimension.name + ) + ], + ), + alias=exp.TableAlias( + this=exp.Identifier(this="table_alias", quoted=False), + columns=[ + exp.Identifier(this=column.name, quoted=True) + for column in metrics + dimensions + ], + ), + ) + table.replace(udtf) + # remove group by, since aggregations are done inside the `SEMANTIC_VIEW` call + del ast.args["group"] + + print("BETO") + print(ast.sql(dialect="snowflake", pretty=True)) + return SemanticQuery(sql=ast.sql(dialect="snowflake", pretty=True)) class SnowflakeEngineSpec(PostgresBaseEngineSpec): diff --git a/superset/extensions/semantic_layer.py b/superset/extensions/semantic_layer.py index ac278788f7..c597a70619 100644 --- a/superset/extensions/semantic_layer.py +++ b/superset/extensions/semantic_layer.py @@ -333,8 +333,8 @@ TYPE_MAPPING: dict[Type, type[sqltypes.TypeEngine]] = { def get_sqla_type_from_dimension_type( dimension_type: Type, -) -> type[sqltypes.TypeEngine]: +) -> sqltypes.TypeEngine: """ Get the SQLAlchemy type corresponding to the given dimension type. """ - return TYPE_MAPPING.get(dimension_type, sqltypes.String) + return TYPE_MAPPING.get(dimension_type, sqltypes.String)() diff --git a/superset/models/core.py b/superset/models/core.py index 787c380bea..7e9f085aec 100755 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -126,7 +126,9 @@ class ConfigurationMethod(StrEnum): DYNAMIC_FORM = "dynamic_form" -class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable=too-many-public-methods +class Database( + Model, AuditMixinNullable, ImportExportMixin +): # pylint: disable=too-many-public-methods """An ORM object that stores Database related information""" __tablename__ = "dbs" @@ -400,9 +402,7 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable return ( username if (username := get_username()) - else object_url.username - if self.impersonate_user - else None + else object_url.username if self.impersonate_user else None ) @contextmanager @@ -987,7 +987,10 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable schema=table.schema, ) as inspector: return self.db_engine_spec.get_columns( - inspector, table, self.schema_options + self, + inspector, + table, + self.schema_options, ) def get_metrics( @@ -1076,9 +1079,11 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable return self.perm def has_table(self, table: Table) -> bool: - with self.get_sqla_engine(catalog=table.catalog, schema=table.schema) as engine: - # do not pass "" as an empty schema; force null - return engine.has_table(table.table, table.schema or None) + with self.get_inspector( + catalog=table.catalog, + schema=table.schema, + ) as inspector: + return self.db_engine_spec.has_table(self, inspector, table) def has_view(self, table: Table) -> bool: with self.get_sqla_engine(catalog=table.catalog, schema=table.schema) as engine:
