This is an automated email from the ASF dual-hosted git repository.

beto pushed a commit to branch snowflake-semantic-layer
in repository https://gitbox.apache.org/repos/asf/superset.git

commit 57ca29baf416945933b1e3ce0b52df86c650fd77
Author: Beto Dealmeida <[email protected]>
AuthorDate: Wed Jul 30 16:08:32 2025 -0400

    Working
---
 superset/connectors/sqla/models.py    |   8 +-
 superset/daos/dataset.py              |  11 ++-
 superset/db_engine_specs/base.py      |  24 +++++-
 superset/db_engine_specs/snowflake.py | 141 +++++++++++++++++++++++++++-------
 superset/extensions/semantic_layer.py |   4 +-
 superset/models/core.py               |  21 +++--
 6 files changed, 161 insertions(+), 48 deletions(-)

diff --git a/superset/connectors/sqla/models.py 
b/superset/connectors/sqla/models.py
index 5bbafda72d..4798c9970f 100644
--- a/superset/connectors/sqla/models.py
+++ b/superset/connectors/sqla/models.py
@@ -164,7 +164,9 @@ class DatasourceKind(StrEnum):
     PHYSICAL = "physical"
 
 
-class BaseDatasource(AuditMixinNullable, ImportExportMixin):  # pylint: 
disable=too-many-public-methods
+class BaseDatasource(
+    AuditMixinNullable, ImportExportMixin
+):  # pylint: disable=too-many-public-methods
     """A common interface to objects that are queryable
     (tables and datasources)"""
 
@@ -1778,7 +1780,9 @@ class SqlaTable(
     def default_query(qry: Query) -> Query:
         return qry.filter_by(is_sqllab_view=False)
 
-    def has_extra_cache_key_calls(self, query_obj: QueryObjectDict) -> bool:  
# noqa: C901
+    def has_extra_cache_key_calls(
+        self, query_obj: QueryObjectDict
+    ) -> bool:  # noqa: C901
         """
         Detects the presence of calls to `ExtraCache` methods in items in 
query_obj that
         can be templated. If any are present, the query must be evaluated to 
extract
diff --git a/superset/daos/dataset.py b/superset/daos/dataset.py
index 5fbd012801..9221af7575 100644
--- a/superset/daos/dataset.py
+++ b/superset/daos/dataset.py
@@ -75,12 +75,11 @@ class DatasetDAO(BaseDAO[SqlaTable]):
         database: Database,
         table: Table,
     ) -> bool:
-        try:
-            database.get_table(table)
-            return True
-        except SQLAlchemyError as ex:  # pragma: no cover
-            logger.warning("Got an error %s validating table: %s", str(ex), 
table)
-            return False
+        with database.get_inspector(
+            catalog=table.catalog,
+            schema=table.schema,
+        ) as inspector:
+            return database.db_engine_spec.has_table(database, inspector, 
table)
 
     @staticmethod
     def validate_uniqueness(
diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py
index f7594f3844..71ea3377e4 100644
--- a/superset/db_engine_specs/base.py
+++ b/superset/db_engine_specs/base.py
@@ -1491,6 +1491,21 @@ class BaseEngineSpec:  # pylint: 
disable=too-many-public-methods
 
         return tables
 
+    @classmethod
+    def has_table(
+        cls,
+        database: Database,
+        inspector: Inspector,
+        table: Table,
+    ) -> bool:
+        if cls.semantic_layer:
+            semantic_layer = cls.semantic_layer(inspector.engine)
+            semantic_views = semantic_layer.get_semantic_views()
+            if table.table in {semantic_view.name for semantic_view in 
semantic_views}:
+                return True
+
+        return inspector.has_table(table.table, table.schema)
+
     @classmethod
     def get_view_names(  # pylint: disable=unused-argument
         cls,
@@ -1564,6 +1579,7 @@ class BaseEngineSpec:  # pylint: 
disable=too-many-public-methods
     @classmethod
     def get_columns(  # pylint: disable=unused-argument
         cls,
+        database: Database,
         inspector: Inspector,
         table: Table,
         options: dict[str, Any] | None = None,
@@ -1588,11 +1604,15 @@ class BaseEngineSpec:  # pylint: 
disable=too-many-public-methods
                 for semantic_view in semantic_layer.get_semantic_views()
             }
             if semantic_view := semantic_views.get(table.table):
+                dialect = database.get_dialect()
                 return [
                     {
                         "name": dimension.name,
                         "column_name": dimension.name,
-                        "type": 
get_sqla_type_from_dimension_type(dimension.type),
+                        "type": cls.column_datatype_to_string(
+                            get_sqla_type_from_dimension_type(dimension.type),
+                            dialect,
+                        ),
                     }
                     for dimension in 
semantic_layer.get_dimensions(semantic_view)
                 ]
@@ -1962,7 +1982,7 @@ class BaseEngineSpec:  # pylint: 
disable=too-many-public-methods
         :return:
         """
         if cls.semantic_layer:
-            with database.get_engine() as engine:
+            with cls.get_engine(database, schema="tpcds_sf10tcl") as engine:
                 semantic_layer = cls.semantic_layer(engine)
                 query = semantic_layer.get_query_from_standard_sql(query).sql
 
diff --git a/superset/db_engine_specs/snowflake.py 
b/superset/db_engine_specs/snowflake.py
index 27de00bb66..2056e7b4bc 100644
--- a/superset/db_engine_specs/snowflake.py
+++ b/superset/db_engine_specs/snowflake.py
@@ -33,9 +33,10 @@ from flask import current_app
 from flask_babel import gettext as __
 from marshmallow import fields, Schema
 from sqlalchemy import text, types
+from sqlalchemy.engine.interfaces import Dialect
 from sqlalchemy.engine.reflection import Inspector
 from sqlalchemy.engine.url import URL
-from sqlglot import exp
+from sqlglot import exp, parse_one
 
 from superset.constants import TimeGrain
 from superset.databases.utils import make_url_safe
@@ -124,23 +125,23 @@ class SnowflakeSemanticLayer:
 SHOW SEMANTIC VIEWS
     ->> SELECT "name" FROM $1;
         """
-        return {row["name"] for row in self.execute(sql)}
+        return {SemanticView(row["name"]) for row in self.execute(sql)}
 
     def get_type(self, snowflake_type: str | None) -> type[SemanticType]:
         if snowflake_type is None:
             return STRING
 
         type_map = {
-            STRING: {r"VARCHAR\(\d+\)", "STRING", "TEXT", r"CHAR\(\d+\)"},
-            INTEGER: {r"NUMBER\(38,\s0\)", "INT", "INTEGER", "BIGINT"},
-            DECIMAL: {r"NUMBER\(10,\s2\)"},
-            NUMBER: {r"NUMBER\(\d+,\s\d+)", "FLOAT", "DOUBLE"},
-            BOOLEAN: {"BOOLEAN"},
-            DATE: {"DATE"},
-            DATETIME: {"TIMESTAMP_TZ", "TIMESTAMP__NTZ"},
-            TIME: {"TIME"},
-            OBJECT: {"OBJECT"},
-            BINARY: {r"BINARY\(\d+\)", r"VARBINARY\(\d+\)"},
+            STRING: {r"VARCHAR\(\d+\)$", "STRING$", "TEXT$", r"CHAR\(\d+\)$"},
+            INTEGER: {r"NUMBER\(38,\s?0\)$", "INT$", "INTEGER$", "BIGINT$"},
+            DECIMAL: {r"NUMBER\(10,\s?2\)$"},
+            NUMBER: {r"NUMBER\(\d+,\s?\d+\)$", "FLOAT$", "DOUBLE$"},
+            BOOLEAN: {"BOOLEAN$"},
+            DATE: {"DATE$"},
+            DATETIME: {"TIMESTAMP_TZ$", "TIMESTAMP__NTZ$"},
+            TIME: {"TIME$"},
+            OBJECT: {"OBJECT$"},
+            BINARY: {r"BINARY\(\d+\)$", r"VARBINARY\(\d+\)$"},
         }
         for semantic_type, patterns in type_map.items():
             if any(
@@ -150,6 +151,23 @@ SHOW SEMANTIC VIEWS
 
         return STRING
 
+    @classmethod
+    def quote_table(cls, table: Table, dialect: Dialect) -> str:
+        """
+        Fully quote a table name, including the schema and catalog.
+        """
+        quoters = {
+            "catalog": dialect.identifier_preparer.quote_schema,
+            "schema": dialect.identifier_preparer.quote_schema,
+            "table": dialect.identifier_preparer.quote,
+        }
+
+        return ".".join(
+            function(getattr(table, key))
+            for key, function in quoters.items()
+            if getattr(table, key)
+        )
+
     def get_metrics(self, semantic_view: SemanticView) -> set[SemanticMetric]:
         quoted_semantic_view_name = self.quote_table(
             Table(semantic_view.name),
@@ -161,7 +179,7 @@ DESC SEMANTIC VIEW {quoted_semantic_view_name}
         FROM $1
         WHERE
             "object_kind" = 'METRIC' AND
-            "property" IN ('EXPRESSION', 'DATA_TYPE', 'TABLE');
+            "property" IN ('DATA_TYPE', 'TABLE');
         """  # noqa: S608 (semantic_view.name is quoted)
         rows = self.execute(sql)
 
@@ -171,9 +189,10 @@ DESC SEMANTIC VIEW {quoted_semantic_view_name}
             for row in group:
                 attributes[row["property"]].add(row["property_value"])
 
-            metric_name = attributes["TABLE"] + "." + name
+            table = next(iter(attributes["TABLE"]))
+            metric_name = table + "." + name
             type_ = self.get_type(next(iter(attributes["DATA_TYPE"])))
-            sql = next(iter(attributes["EXPRESSION"]), name)
+            sql = self.engine.dialect.identifier_preparer.quote(metric_name)
             tables = frozenset(attributes["TABLE"])
             join_columns = frozenset()
 
@@ -261,7 +280,7 @@ DESC SEMANTIC VIEW {quoted_semantic_view_name}
         offset: int | None = None,
     ) -> exp.Select:
         semantic_view = exp.SemanticView(
-            this=Table(this=exp.Identifier(this=semantic_view.name, 
quoted=True)),
+            this=exp.Table(this=exp.Identifier(this=semantic_view.name, 
quoted=True)),
             dimensions=[
                 exp.Column(
                     this=exp.Identifier(this=dimension.column.name, 
quoted=True),
@@ -274,8 +293,8 @@ DESC SEMANTIC VIEW {quoted_semantic_view_name}
             ],
             metrics=[
                 exp.Column(
-                    this=exp.Identifier(this=table, quoted=True),
-                    table=exp.Identifier(this=column, quoted=True),
+                    this=exp.Identifier(this=column, quoted=True),
+                    table=exp.Identifier(this=table, quoted=True),
                 )
                 for table, column in (
                     metric.name.split(".", 1)
@@ -283,14 +302,13 @@ DESC SEMANTIC VIEW {quoted_semantic_view_name}
                     if "." in metric.name
                 )
             ],
-            # where=  XXX push predicates
         )
         query = exp.Select(
             expressions=[exp.Star()],
-            **{"from": exp.From(this=exp.Table(semantic_view))},
+            **{"from": exp.From(this=exp.Table(this=semantic_view))},
         )
 
-        if sort:
+        if sort.items:
             order = [
                 exp.Ordered(
                     this=exp.Column(this=exp.Identifier(this=item.field.name)),
@@ -309,14 +327,81 @@ DESC SEMANTIC VIEW {quoted_semantic_view_name}
 
         return query
 
-    def get_query_from_standard_sql(
-        self,
-        semantic_view: SemanticView,
-        sql: str,
-    ) -> Query:
-        statement = SQLStatement(sql, "snowflake")
-        # check if any of the tables in the statement is a semantic view
+    def get_query_from_standard_sql(self, sql: str) -> SemanticQuery:
+        """
+        Convert the Explore query into a proper query.
+
+        Explore will produce a pseudo-SQL query that references metrics and 
dimensions
+        as if they were columns in a table. This method replaces the table 
name with a
+        call to `SEMANTIC_VIEW`, and removes the `GROUP BY` clause, since all 
the
+        aggregations happen inside the `SEMANTIC_VIEW` call.
+        """
+        ast = parse_one(sql, "snowflake")
+        table = ast.find(exp.Table)
+        if not table:
+            return SemanticQuery(sql=sql)
+
+        semantic_views = self.get_semantic_views()
+        if table.name not in {semantic_view.name for semantic_view in 
semantic_views}:
+            return SemanticQuery(sql=sql)
+
+        # collect all metric and dimensions
+        semantic_view = SemanticView(table.name)
+        all_metrics = self.get_metrics(semantic_view)
+        all_dimensions = self.get_dimensions(semantic_view)
+
+        # collect metrics and dimensions used in the query
+        columns = {column.name for column in ast.find_all(exp.Column)}
+        metrics = [metric for metric in all_metrics if metric.name in columns]
+        dimensions = [
+            dimension for dimension in all_dimensions if dimension.name in 
columns
+        ]
+
+        # now replace table with a call to `SEMANTIC_VIEW`
+        udtf = exp.Table(
+            this=exp.SemanticView(
+                this=exp.Table(
+                    this=exp.Identifier(this=semantic_view.name, quoted=True)
+                ),
+                metrics=[
+                    exp.Column(
+                        this=exp.Identifier(this=column, quoted=True),
+                        table=exp.Identifier(this=table, quoted=True),
+                    )
+                    for table, column in (
+                        metric.name.split(".", 1)
+                        for metric in metrics
+                        if "." in metric.name
+                    )
+                ],
+                dimensions=[
+                    exp.Column(
+                        this=exp.Identifier(this=column, quoted=True),
+                        table=exp.Identifier(this=table, quoted=True),
+                    )
+                    for table, column in (
+                        dimension.name.split(".", 1)
+                        for dimension in dimensions
+                        if "." in dimension.name
+                    )
+                ],
+            ),
+            alias=exp.TableAlias(
+                this=exp.Identifier(this="table_alias", quoted=False),
+                columns=[
+                    exp.Identifier(this=column.name, quoted=True)
+                    for column in metrics + dimensions
+                ],
+            ),
+        )
+        table.replace(udtf)
 
+        # remove group by, since aggregations are done inside the 
`SEMANTIC_VIEW` call
+        del ast.args["group"]
+
+        print("BETO")
+        print(ast.sql(dialect="snowflake", pretty=True))
+        return SemanticQuery(sql=ast.sql(dialect="snowflake", pretty=True))
 
 
 class SnowflakeEngineSpec(PostgresBaseEngineSpec):
diff --git a/superset/extensions/semantic_layer.py 
b/superset/extensions/semantic_layer.py
index ac278788f7..c597a70619 100644
--- a/superset/extensions/semantic_layer.py
+++ b/superset/extensions/semantic_layer.py
@@ -333,8 +333,8 @@ TYPE_MAPPING: dict[Type, type[sqltypes.TypeEngine]] = {
 
 def get_sqla_type_from_dimension_type(
     dimension_type: Type,
-) -> type[sqltypes.TypeEngine]:
+) -> sqltypes.TypeEngine:
     """
     Get the SQLAlchemy type corresponding to the given dimension type.
     """
-    return TYPE_MAPPING.get(dimension_type, sqltypes.String)
+    return TYPE_MAPPING.get(dimension_type, sqltypes.String)()
diff --git a/superset/models/core.py b/superset/models/core.py
index 787c380bea..7e9f085aec 100755
--- a/superset/models/core.py
+++ b/superset/models/core.py
@@ -126,7 +126,9 @@ class ConfigurationMethod(StrEnum):
     DYNAMIC_FORM = "dynamic_form"
 
 
-class Database(Model, AuditMixinNullable, ImportExportMixin):  # pylint: 
disable=too-many-public-methods
+class Database(
+    Model, AuditMixinNullable, ImportExportMixin
+):  # pylint: disable=too-many-public-methods
     """An ORM object that stores Database related information"""
 
     __tablename__ = "dbs"
@@ -400,9 +402,7 @@ class Database(Model, AuditMixinNullable, 
ImportExportMixin):  # pylint: disable
         return (
             username
             if (username := get_username())
-            else object_url.username
-            if self.impersonate_user
-            else None
+            else object_url.username if self.impersonate_user else None
         )
 
     @contextmanager
@@ -987,7 +987,10 @@ class Database(Model, AuditMixinNullable, 
ImportExportMixin):  # pylint: disable
             schema=table.schema,
         ) as inspector:
             return self.db_engine_spec.get_columns(
-                inspector, table, self.schema_options
+                self,
+                inspector,
+                table,
+                self.schema_options,
             )
 
     def get_metrics(
@@ -1076,9 +1079,11 @@ class Database(Model, AuditMixinNullable, 
ImportExportMixin):  # pylint: disable
         return self.perm
 
     def has_table(self, table: Table) -> bool:
-        with self.get_sqla_engine(catalog=table.catalog, schema=table.schema) 
as engine:
-            # do not pass "" as an empty schema; force null
-            return engine.has_table(table.table, table.schema or None)
+        with self.get_inspector(
+            catalog=table.catalog,
+            schema=table.schema,
+        ) as inspector:
+            return self.db_engine_spec.has_table(self, inspector, table)
 
     def has_view(self, table: Table) -> bool:
         with self.get_sqla_engine(catalog=table.catalog, schema=table.schema) 
as engine:

Reply via email to