This is an automated email from the ASF dual-hosted git repository.

beto pushed a commit to branch canva-demo
in repository https://gitbox.apache.org/repos/asf/superset.git

commit 66cb9cb40c7c27091234d2bf1c9ce07cce419107
Author: Beto Dealmeida <[email protected]>
AuthorDate: Wed Jul 30 17:46:02 2025 -0400

    Snowflake
---
 superset/daos/dataset.py              |  11 +-
 superset/db_engine_specs/base.py      | 223 ++++++++++++++++++++--
 superset/db_engine_specs/snowflake.py | 336 ++++++++++++++++++++++++++++++++-
 superset/extensions/semantic_layer.py | 340 ++++++++++++++++++++++++++++++++++
 superset/models/core.py               |  21 ++-
 5 files changed, 895 insertions(+), 36 deletions(-)

diff --git a/superset/daos/dataset.py b/superset/daos/dataset.py
index 5fbd012801..9221af7575 100644
--- a/superset/daos/dataset.py
+++ b/superset/daos/dataset.py
@@ -75,12 +75,18 @@ class DatasetDAO(BaseDAO[SqlaTable]):
         database: Database,
         table: Table,
     ) -> bool:
-        try:
-            database.get_table(table)
-            return True
-        except SQLAlchemyError as ex:  # pragma: no cover
-            logger.warning("Got an error %s validating table: %s", str(ex), table)
-            return False
+        # Delegate to the engine spec so engines with a semantic layer can report
+        # semantic views as tables; keep the error handling so a reflection failure
+        # still yields ``False`` instead of propagating to the caller.
+        try:
+            with database.get_inspector(
+                catalog=table.catalog,
+                schema=table.schema,
+            ) as inspector:
+                return database.db_engine_spec.has_table(database, inspector, table)
+        except SQLAlchemyError as ex:  # pragma: no cover
+            logger.warning("Got an error %s validating table: %s", str(ex), table)
+            return False
 
     @staticmethod
     def validate_uniqueness(
diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py
index f8449f4030..47bd1afec7 100644
--- a/superset/db_engine_specs/base.py
+++ b/superset/db_engine_specs/base.py
@@ -30,6 +30,7 @@ from typing import (
     cast,
     ContextManager,
     NamedTuple,
+    Type,
     TYPE_CHECKING,
     TypedDict,
     Union,
@@ -54,7 +55,7 @@ from sqlalchemy.engine.reflection import Inspector
 from sqlalchemy.engine.url import URL
 from sqlalchemy.ext.compiler import compiles
 from sqlalchemy.sql import literal_column, quoted_name, text
-from sqlalchemy.sql.expression import ColumnClause, Select, TextClause
+from sqlalchemy.sql.expression import BinaryExpression, ColumnClause, Select, TextClause
 from sqlalchemy.types import TypeEngine
 
 from superset import db
@@ -62,6 +63,10 @@ from superset.constants import QUERY_CANCEL_KEY, TimeGrain as TimeGrainConstants
 from superset.databases.utils import get_table_metadata, make_url_safe
 from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
 from superset.exceptions import OAuth2Error, OAuth2RedirectError
+from superset.extensions.semantic_layer import (
+    get_sqla_type_from_dimension_type,
+    SemanticLayer,
+)
 from superset.sql.parse import (
     BaseSQLStatement,
     LimitMethod,
@@ -106,6 +111,15 @@ logger = logging.getLogger()
 GenericDBException = Exception
 
 
+class ValidColumnsType(TypedDict):
+    """
+    Type for valid columns returned by `get_valid_metrics_and_dimensions`.
+    """
+
+    dimensions: set[str]
+    metrics: set[str]
+
+
 def convert_inspector_columns(cols: list[SQLAColumnType]) -> list[ResultSetColumnType]:
     result_set_columns: list[ResultSetColumnType] = []
     for col in cols:
@@ -188,15 +202,6 @@ class MetricType(TypedDict, total=False):
     extra: str | None
 
 
-class ValidColumnsType(TypedDict):
-    """
-    Type for valid columns returned by `get_valid_metrics_and_dimensions`.
-    """
-
-    dimensions: set[str]
-    metrics: set[str]
-
-
 class BaseEngineSpec:  # pylint: disable=too-many-public-methods
     """Abstract class for database engine specific configurations
 
@@ -225,6 +230,9 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
         "engine+driver://user:password@host:port/dbname[?key=value&key=value...]"
     )
 
+    # engine specs can optionally specify a semantic layer; its semantic views
+    semantic_layer: Type[SemanticLayer] | None = None
+
     disable_ssh_tunneling = False
 
     _date_trunc_functions: dict[str, str] = {}
@@ -388,6 +396,10 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
     disallow_uri_query_params: dict[str, set[str]] = {}
     # A Dict of query parameters that will always be used on every connection
     # by driver name
     enforce_uri_query_params: dict[str, dict[str, Any]] = {}
+
+    # Whether to use equality operators (= true/false) instead of IS operators
+    # for boolean filters. Some databases like Snowflake don't support IS true/false
+    use_equality_for_boolean_filters = False
 
     force_column_alias_quotes = False
@@ -1218,6 +1230,82 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
         """
         return None
 
+    @classmethod
+    def handle_boolean_filter(
+        cls, sqla_col: Any, op: str, value: bool
+    ) -> BinaryExpression:
+        """
+        Handle boolean filter operations with engine-specific logic.
+
+        By default, uses SQLAlchemy's IS operator (column IS true/false).
+        Engines that don't support IS for boolean values can override
+        use_equality_for_boolean_filters to use equality operators instead.
+
+        :param sqla_col: SQLAlchemy column element
+        :param op: Filter operator (IS_TRUE or IS_FALSE); the default
+            implementation builds the expression from ``value`` and ignores ``op``
+        :param value: Boolean value (True or False)
+        :return: SQLAlchemy expression for the boolean filter
+        """
+        if cls.use_equality_for_boolean_filters:
+            return sqla_col == value
+        else:
+            return sqla_col.is_(value)
+
+    @classmethod
+    def handle_null_filter(
+        cls,
+        sqla_col: Any,
+        op: utils.FilterOperator,
+    ) -> BinaryExpression:
+        """
+        Handle null/not null filter operations.
+
+        :param sqla_col: SQLAlchemy column element
+        :param op: Filter operator (IS_NULL or IS_NOT_NULL)
+        :return: SQLAlchemy expression for the null filter
+        :raises ValueError: if ``op`` is neither IS_NULL nor IS_NOT_NULL
+        """
+        from superset.utils import core as utils
+
+        if op == utils.FilterOperator.IS_NULL:
+            return sqla_col.is_(None)
+        elif op == utils.FilterOperator.IS_NOT_NULL:
+            # ``is_not`` is the SQLAlchemy 1.4+ name; ``isnot`` is a legacy alias
+            return sqla_col.is_not(None)
+        else:
+            raise ValueError(f"Invalid null filter operator: {op}")
+
+    @classmethod
+    def handle_comparison_filter(
+        cls, sqla_col: Any, op: utils.FilterOperator, value: Any
+    ) -> BinaryExpression:
+        """
+        Handle comparison filter operations (=, !=, >, <, >=, <=).
+
+        :param sqla_col: SQLAlchemy column element
+        :param op: Filter operator
+        :param value: Filter value
+        :return: SQLAlchemy expression for the comparison filter
+        :raises ValueError: if ``op`` is not one of the comparison operators above
+        """
+        from superset.utils import core as utils
+
+        if op == utils.FilterOperator.EQUALS:
+            return sqla_col == value
+        elif op == utils.FilterOperator.NOT_EQUALS:
+            return sqla_col != value
+        elif op == utils.FilterOperator.GREATER_THAN:
+            return sqla_col > value
+        elif op == utils.FilterOperator.LESS_THAN:
+            return sqla_col < value
+        elif op == utils.FilterOperator.GREATER_THAN_OR_EQUALS:
+            return sqla_col >= value
+        elif op == utils.FilterOperator.LESS_THAN_OR_EQUALS:
+            return sqla_col <= value
+        else:
+            raise ValueError(f"Invalid comparison filter operator: {op}")
+
     @classmethod
     def handle_cursor(cls, cursor: Any, query: Query) -> None:
         """Handle a live cursor between the execute and fetchall calls
@@ -1401,8 +1485,44 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
 
         if schema and cls.try_remove_schema_from_table_name:
             tables = {re.sub(f"^{schema}\\.", "", table) for table in tables}
+
+        # add semantic views as tables too
+        # NOTE(review): semantic views are not filtered by ``schema`` here; confirm
+        # the semantic layer only returns views for the inspected schema.
+        if cls.semantic_layer:
+            semantic_layer = cls.semantic_layer(inspector.engine)
+            tables.update(
+                semantic_view.name
+                for semantic_view in semantic_layer.get_semantic_views()
+            )
+
         return tables
 
+    @classmethod
+    def has_table(  # pylint: disable=unused-argument
+        cls,
+        database: Database,
+        inspector: Inspector,
+        table: Table,
+    ) -> bool:
+        """
+        Check if a table exists, treating semantic views as tables.
+
+        :param database: Database instance (unused by the default implementation)
+        :param inspector: SqlAlchemy Inspector instance
+        :param table: Table instance
+        :return: True if the table or a semantic view with that name exists
+        """
+        if cls.semantic_layer:
+            semantic_layer = cls.semantic_layer(inspector.engine)
+            if any(
+                semantic_view.name == table.table
+                for semantic_view in semantic_layer.get_semantic_views()
+            ):
+                return True
+
+        return inspector.has_table(table.table, table.schema)
+
     @classmethod
     def get_view_names(  # pylint: disable=unused-argument
         cls,
@@ -1476,6 +1584,7 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
     @classmethod
     def get_columns(  # pylint: disable=unused-argument
         cls,
+        database: Database,
         inspector: Inspector,
         table: Table,
         options: dict[str, Any] | None = None,
@@ -1483,7 +1592,11 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
         """
         Get all columns from a given schema and table.
 
-        The inspector will be bound to a catalog, if one was specified.
+        The inspector will be bound to a catalog, if one was specified. If the database
+        supports semantic layers the method will check if the table is a semantic view,
+        and return its dimensions as columns instead.
 
+        :param database: Database instance, used for its dialect when the table is a
+            semantic view
         :param inspector: SqlAlchemy Inspector instance
         :param table: Table instance
@@ -1491,6 +1602,29 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
                         some databases
         :return: All columns in table
         """
+        # NOTE(review): ``database`` is now the first positional parameter; any
+        # engine spec overriding ``get_columns`` must accept it as well.
+        if cls.semantic_layer:
+            semantic_layer = cls.semantic_layer(inspector.engine)
+            semantic_views = {
+                semantic_view.name: semantic_view
+                for semantic_view in semantic_layer.get_semantic_views()
+            }
+            if semantic_view := semantic_views.get(table.table):
+                # semantic views expose their dimensions as columns
+                dialect = database.get_dialect()
+                return [
+                    {
+                        "name": dimension.name,
+                        "column_name": dimension.name,
+                        "type": cls.column_datatype_to_string(
+                            get_sqla_type_from_dimension_type(dimension.type),
+                            dialect,
+                        ),
+                    }
+                    for dimension in semantic_layer.get_dimensions(semantic_view)
+                ]
+
         return convert_inspector_columns(
             cast(
                 list[SQLAColumnType],
@@ -1508,6 +1639,23 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
         """
         Get all metrics from a given schema and table.
         """
+        if cls.semantic_layer:
+            semantic_layer = cls.semantic_layer(inspector.engine)
+            semantic_views = {
+                semantic_view.name: semantic_view
+                for semantic_view in semantic_layer.get_semantic_views()
+            }
+            if semantic_view := semantic_views.get(table.table):
+                # semantic views define their own metrics; return those instead
+                return [
+                    {
+                        "metric_name": metric.name,
+                        "verbose_name": metric.name,
+                        "expression": metric.sql,
+                    }
+                    for metric in semantic_layer.get_metrics(semantic_view)
+                ]
+
         return [
             {
                 "metric_name": "count",
@@ -1526,17 +1673,57 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
         metrics: set[str],
     ) -> ValidColumnsType:
         """
         Given a selection of columns/metrics from a datasource, return related columns.
 
         This is a method used for semantic layers, where tables can have columns and
         metrics that cannot be computed together. When the user selects a given metric
         it allows the UI to filter the remaining metrics and dimensions so that only
         valid combinations are possible.
 
         The method should only be called when ``supports_dynamic_columns`` is set to
-        true. The default method in the base class ignores the selected columns and
-        metrics, and simply returns everything, for reference.
+        true. If the engine spec has a ``semantic_layer`` and the table is one of its
+        semantic views, the layer decides which metrics and dimensions remain valid;
+        otherwise the selection is ignored and everything is returned.
         """
+        if cls.semantic_layer:
+            with database.get_sqla_engine() as engine:
+                semantic_layer = cls.semantic_layer(engine)
+                semantic_views = {
+                    semantic_view.name: semantic_view
+                    for semantic_view in semantic_layer.get_semantic_views()
+                }
+                # NOTE(review): ``table`` is the dataset here (see ``table.columns``
+                # below); confirm it exposes ``.table`` and not only ``table_name``.
+                if semantic_view := semantic_views.get(table.table):
+                    selected_metrics = {
+                        metric
+                        for metric in semantic_layer.get_metrics(semantic_view)
+                        if metric.name in metrics
+                    }
+                    selected_dimensions = {
+                        dimension
+                        for dimension in semantic_layer.get_dimensions(semantic_view)
+                        if dimension.name in dimensions
+                    }
+                    return {
+                        "metrics": {
+                            metric.name
+                            for metric in semantic_layer.get_valid_metrics(
+                                semantic_view,
+                                selected_metrics,
+                                selected_dimensions,
+                            )
+                        },
+                        "dimensions": {
+                            dimension.name
+                            for dimension in semantic_layer.get_valid_dimensions(
+                                semantic_view,
+                                selected_metrics,
+                                selected_dimensions,
+                            )
+                        },
+                    }
+
         return {
             "dimensions": {column.column_name for column in table.columns},
             "metrics": {metric.metric_name for metric in table.metrics},
@@ -1808,6 +1986,13 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
         :param kwargs: kwargs to be passed to cursor.execute()
         :return:
         """
+        if cls.semantic_layer:
+            # rewrite queries that read from semantic views into SEMANTIC_VIEW()
+            # calls; the engine is not pinned to any particular schema
+            with cls.get_engine(database) as engine:
+                semantic_layer = cls.semantic_layer(engine)
+                query = semantic_layer.get_query_from_standard_sql(query).sql
+
         if cls.arraysize:
             cursor.arraysize = cls.arraysize
         try:
diff --git a/superset/db_engine_specs/snowflake.py b/superset/db_engine_specs/snowflake.py
index 7e353f006d..2056e7b4bc 100644
--- a/superset/db_engine_specs/snowflake.py
+++ b/superset/db_engine_specs/snowflake.py
@@ -16,11 +16,13 @@
 # under the License.
 from __future__ import annotations
 
+import itertools
 import logging
 import re
+from collections import defaultdict
 from datetime import datetime
 from re import Pattern
-from typing import Any, Optional, TYPE_CHECKING, TypedDict
+from typing import Any, Iterator, Optional, TYPE_CHECKING, TypedDict
 from urllib import parse
 
 from apispec import APISpec
@@ -30,20 +32,48 @@ from cryptography.hazmat.primitives import serialization
 from flask import current_app
 from flask_babel import gettext as __
 from marshmallow import fields, Schema
-from sqlalchemy import types
+from sqlalchemy import text, types
+from sqlalchemy.engine.interfaces import Dialect
 from sqlalchemy.engine.reflection import Inspector
 from sqlalchemy.engine.url import URL
+from sqlglot import exp, parse_one
 
 from superset.constants import TimeGrain
 from superset.databases.utils import make_url_safe
 from superset.db_engine_specs.base import BaseEngineSpec, BasicPropertiesType
 from superset.db_engine_specs.postgres import PostgresBaseEngineSpec
 from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
+from superset.extensions.semantic_layer import (
+    BINARY,
+    BOOLEAN,
+    Column as SemanticColumn,
+    DATE,
+    DATETIME,
+    DECIMAL,
+    Dimension as SemanticDimension,
+    Filter as SemanticFilter,
+    INTEGER,
+    Metric as SemanticMetric,
+    NoSort,
+    NUMBER,
+    OBJECT,
+    Query as SemanticQuery,
+    SemanticView,
+    Sort as SemanticSort,
+    SortDirectionEnum,
+    STRING,
+    Table as SemanticTable,
+    TIME,
+    Type as SemanticType,
+)
 from superset.models.sql_lab import Query
+from superset.sql.parse import Table
 from superset.utils import json
 from superset.utils.core import get_user_agent, QuerySource
 
 if TYPE_CHECKING:
+    from sqlalchemy.engine.base import Engine
+
     from superset.models.core import Database
 
 # Regular expressions to catch custom errors
@@ -77,16 +107,386 @@ class SnowflakeParametersType(TypedDict):
     warehouse: str
 
 
+class SnowflakeSemanticLayer:
+    """
+    Semantic layer backed by Snowflake semantic views.
+
+    Semantic views are listed with ``SHOW SEMANTIC VIEWS`` and described with
+    ``DESC SEMANTIC VIEW``; Explore queries against them are rewritten into
+    ``SEMANTIC_VIEW(...)`` table function calls.
+    """
+
+    # Snowflake type patterns for each semantic type, matched case-insensitively
+    # from the start of the type string. Order matters: the first matching entry
+    # wins, so INTEGER is checked before the generic NUMBER patterns. Timestamp and
+    # time patterns also accept an optional precision suffix such as ``(9)``.
+    TYPE_PATTERNS: dict[type[SemanticType], set[str]] = {
+        STRING: {r"VARCHAR\(\d+\)$", "STRING$", "TEXT$", r"CHAR\(\d+\)$"},
+        INTEGER: {r"NUMBER\(38,\s?0\)$", "INT$", "INTEGER$", "BIGINT$"},
+        DECIMAL: {r"NUMBER\(10,\s?2\)$"},
+        NUMBER: {r"NUMBER\(\d+,\s?\d+\)$", "FLOAT$", "DOUBLE$"},
+        BOOLEAN: {"BOOLEAN$"},
+        DATE: {"DATE$"},
+        DATETIME: {r"TIMESTAMP(_NTZ|_LTZ|_TZ)?(\(\d+\))?$"},
+        TIME: {r"TIME(\(\d+\))?$"},
+        OBJECT: {"OBJECT$"},
+        BINARY: {r"BINARY\(\d+\)$", r"VARBINARY\(\d+\)$"},
+    }
+
+    def __init__(self, engine: Engine) -> None:
+        self.engine = engine
+
+    def execute(
+        self,
+        sql: str,
+        **kwargs: Any,
+    ) -> Iterator[dict[str, Any]]:
+        """
+        Run ``sql`` with ``kwargs`` as bind parameters and yield rows as dicts.
+        """
+        with self.engine.connect() as connection:
+            for row in connection.execute(text(sql), kwargs).mappings():
+                yield dict(row)
+
+    def get_semantic_views(self) -> set[SemanticView]:
+        """
+        Return the semantic views listed by ``SHOW SEMANTIC VIEWS``.
+        """
+        sql = """
+SHOW SEMANTIC VIEWS
+    ->> SELECT "name" FROM $1;
+        """
+        return {SemanticView(row["name"]) for row in self.execute(sql)}
+
+    def get_type(self, snowflake_type: str | None) -> type[SemanticType]:
+        """
+        Map a Snowflake type string to a semantic type, defaulting to ``STRING``.
+        """
+        if snowflake_type is None:
+            return STRING
+
+        for semantic_type, patterns in self.TYPE_PATTERNS.items():
+            if any(
+                re.match(pattern, snowflake_type, re.IGNORECASE) for pattern in patterns
+            ):
+                return semantic_type
+
+        return STRING
+
+    @classmethod
+    def quote_table(cls, table: Table, dialect: Dialect) -> str:
+        """
+        Fully quote a table name, including the schema and catalog.
+        """
+        quoters = {
+            "catalog": dialect.identifier_preparer.quote_schema,
+            "schema": dialect.identifier_preparer.quote_schema,
+            "table": dialect.identifier_preparer.quote,
+        }
+
+        return ".".join(
+            function(getattr(table, key))
+            for key, function in quoters.items()
+            if getattr(table, key)
+        )
+
+    def _describe_objects(
+        self,
+        semantic_view: SemanticView,
+        object_kind: str,
+    ) -> dict[str, dict[str, set[str]]]:
+        """
+        Return ``{object_name: {property: {values}}}`` for one kind of object.
+
+        Only the ``DATA_TYPE`` and ``TABLE`` properties are fetched. Rows are sorted
+        before grouping because ``itertools.groupby`` only merges adjacent rows and
+        the query does not request any ordering.
+        """
+        if object_kind not in {"METRIC", "DIMENSION"}:
+            raise ValueError(f"Unsupported semantic object kind: {object_kind}")
+
+        quoted_semantic_view_name = self.quote_table(
+            Table(semantic_view.name),
+            self.engine.dialect,
+        )
+        sql = f"""
+DESC SEMANTIC VIEW {quoted_semantic_view_name}
+    ->> SELECT "object_name", "property", "property_value"
+        FROM $1
+        WHERE
+            "object_kind" = '{object_kind}' AND
+            "property" IN ('DATA_TYPE', 'TABLE');
+        """  # noqa: S608 (view name is quoted, object_kind is whitelisted)
+
+        rows = sorted(self.execute(sql), key=lambda row: row["object_name"])
+        objects: dict[str, dict[str, set[str]]] = {}
+        for name, group in itertools.groupby(
+            rows,
+            key=lambda row: row["object_name"],
+        ):
+            properties: dict[str, set[str]] = defaultdict(set)
+            for row in group:
+                properties[row["property"]].add(row["property_value"])
+            objects[name] = properties
+
+        return objects
+
+    def get_metrics(self, semantic_view: SemanticView) -> set[SemanticMetric]:
+        """
+        Return the metrics of a semantic view, named ``<table>.<metric>``.
+
+        Metrics without a ``TABLE`` property are skipped since they cannot be
+        qualified; a missing ``DATA_TYPE`` falls back to ``STRING``.
+        """
+        metrics: set[SemanticMetric] = set()
+        objects = self._describe_objects(semantic_view, "METRIC")
+        for name, properties in objects.items():
+            tables = properties.get("TABLE")
+            if not tables:
+                continue
+
+            # ``min`` picks the same table on every call if several are reported
+            table = min(tables)
+            metric_name = table + "." + name
+            data_types = properties.get("DATA_TYPE")
+            type_ = self.get_type(min(data_types) if data_types else None)
+            sql = self.engine.dialect.identifier_preparer.quote(metric_name)
+
+            metrics.add(
+                SemanticMetric(
+                    metric_name,
+                    type_,
+                    sql,
+                    frozenset(tables),
+                    frozenset(),
+                )
+            )
+
+        return metrics
+
+    def get_dimensions(self, semantic_view: SemanticView) -> set[SemanticDimension]:
+        """
+        Return the dimensions of a semantic view, named ``<table>.<dimension>``.
+
+        Dimensions without a ``TABLE`` property are skipped since they cannot be
+        qualified; a missing ``DATA_TYPE`` falls back to ``STRING``.
+        """
+        dimensions: set[SemanticDimension] = set()
+        objects = self._describe_objects(semantic_view, "DIMENSION")
+        for name, properties in objects.items():
+            tables = properties.get("TABLE")
+            if not tables:
+                continue
+
+            table = min(tables)
+            column = SemanticColumn(SemanticTable(table), name)
+            data_types = properties.get("DATA_TYPE")
+            type_ = self.get_type(min(data_types) if data_types else None)
+
+            dimensions.add(SemanticDimension(column, table + "." + name, type_))
+
+        return dimensions
+
+    def get_valid_metrics(
+        self,
+        semantic_view: SemanticView,
+        metrics: set[SemanticMetric],
+        dimensions: set[SemanticDimension],
+    ) -> set[SemanticMetric]:
+        # all metrics and dimensions are valid inside a given semantic view
+        return self.get_metrics(semantic_view)
+
+    def get_valid_dimensions(
+        self,
+        semantic_view: SemanticView,
+        metrics: set[SemanticMetric],
+        dimensions: set[SemanticDimension],
+    ) -> set[SemanticDimension]:
+        # all metrics and dimensions are valid inside a given semantic view
+        return self.get_dimensions(semantic_view)
+
+    def get_query(
+        self,
+        semantic_view: SemanticView,
+        metrics: set[SemanticMetric],
+        dimensions: set[SemanticDimension],
+        filters: set[SemanticFilter],
+        sort: SemanticSort = NoSort,
+        limit: int | None = None,
+        offset: int | None = None,
+    ) -> SemanticQuery:
+        ast = self.build_query(
+            semantic_view,
+            metrics,
+            dimensions,
+            filters,
+            sort,
+            limit,
+            offset,
+        )
+        return SemanticQuery(sql=ast.sql(dialect="snowflake", pretty=True))
+
+    @staticmethod
+    def _to_columns(names: list[str]) -> list[exp.Column]:
+        """
+        Convert ``table.column`` names into quoted sqlglot column references.
+
+        Names without a ``.`` are skipped, as ``SEMANTIC_VIEW()`` needs the table.
+        """
+        columns: list[exp.Column] = []
+        for name in names:
+            if "." not in name:
+                continue
+            table_name, column_name = name.split(".", 1)
+            columns.append(
+                exp.Column(
+                    this=exp.Identifier(this=column_name, quoted=True),
+                    table=exp.Identifier(this=table_name, quoted=True),
+                )
+            )
+        return columns
+
+    def build_query(
+        self,
+        semantic_view: SemanticView,
+        metrics: set[SemanticMetric],
+        dimensions: set[SemanticDimension],
+        filters: set[SemanticFilter],
+        sort: SemanticSort = NoSort,
+        limit: int | None = None,
+        offset: int | None = None,
+    ) -> exp.Select:
+        """
+        Build ``SELECT * FROM SEMANTIC_VIEW(...)`` for the requested selection.
+
+        NOTE(review): ``filters`` is accepted but not applied to the query yet.
+        """
+        # use a distinct name so the ``semantic_view`` argument is not shadowed
+        view_call = exp.SemanticView(
+            this=exp.Table(this=exp.Identifier(this=semantic_view.name, quoted=True)),
+            dimensions=[
+                exp.Column(
+                    this=exp.Identifier(this=dimension.column.name, quoted=True),
+                    table=exp.Identifier(
+                        this=dimension.column.relation.name,
+                        quoted=True,
+                    ),
+                )
+                for dimension in dimensions
+            ],
+            metrics=self._to_columns([metric.name for metric in metrics]),
+        )
+        query = exp.Select(
+            expressions=[exp.Star()],
+            **{"from": exp.From(this=exp.Table(this=view_call))},
+        )
+
+        if sort.items:
+            order = [
+                exp.Ordered(
+                    this=exp.Column(this=exp.Identifier(this=item.field.name)),
+                    desc=item.direction == SortDirectionEnum.DESC,
+                    nulls_first=item.nulls_first,
+                )
+                for item in sort.items
+            ]
+            query.args["order"] = exp.Order(expressions=order)
+
+        if offset:
+            query = query.offset(offset)
+
+        if limit:
+            query = query.limit(limit)
+
+        return query
+
+    def get_query_from_standard_sql(self, sql: str) -> SemanticQuery:
+        """
+        Convert the Explore query into a proper query.
+
+        Explore will produce a pseudo-SQL query that references metrics and dimensions
+        as if they were columns in a table. This method replaces the table name with a
+        call to `SEMANTIC_VIEW`, and removes the `GROUP BY` clause, since all the
+        aggregations happen inside the `SEMANTIC_VIEW` call.
+
+        SQL that cannot be parsed, or that does not read from a semantic view, is
+        returned unchanged.
+        """
+        from sqlglot.errors import SqlglotError
+
+        try:
+            ast = parse_one(sql, "snowflake")
+        except SqlglotError:
+            return SemanticQuery(sql=sql)
+
+        table = ast.find(exp.Table)
+        if not table:
+            return SemanticQuery(sql=sql)
+
+        semantic_view_names = {view.name for view in self.get_semantic_views()}
+        if table.name not in semantic_view_names:
+            return SemanticQuery(sql=sql)
+
+        # collect all metrics and dimensions
+        semantic_view = SemanticView(table.name)
+        all_metrics = self.get_metrics(semantic_view)
+        all_dimensions = self.get_dimensions(semantic_view)
+
+        # collect metrics and dimensions used in the query, sorted by name so the
+        # generated SQL is the same on every run
+        columns = {column.name for column in ast.find_all(exp.Column)}
+        metrics = sorted(
+            (metric for metric in all_metrics if metric.name in columns),
+            key=lambda metric: metric.name,
+        )
+        dimensions = sorted(
+            (dimension for dimension in all_dimensions if dimension.name in columns),
+            key=lambda dimension: dimension.name,
+        )
+
+        # now replace table with a call to `SEMANTIC_VIEW`
+        udtf = exp.Table(
+            this=exp.SemanticView(
+                this=exp.Table(
+                    this=exp.Identifier(this=semantic_view.name, quoted=True)
+                ),
+                metrics=self._to_columns([metric.name for metric in metrics]),
+                dimensions=self._to_columns(
+                    [dimension.name for dimension in dimensions]
+                ),
+            ),
+            alias=exp.TableAlias(
+                this=exp.Identifier(this="table_alias", quoted=False),
+                columns=[
+                    exp.Identifier(this=column.name, quoted=True)
+                    for column in metrics + dimensions
+                ],
+            ),
+        )
+        table.replace(udtf)
+
+        # remove GROUP BY, since aggregations are done inside the `SEMANTIC_VIEW`
+        # call; ``pop`` because queries without a GROUP BY have no "group" arg
+        ast.args.pop("group", None)
+
+        return SemanticQuery(sql=ast.sql(dialect="snowflake", pretty=True))
+
+
 class SnowflakeEngineSpec(PostgresBaseEngineSpec):
     engine = "snowflake"
     engine_name = "Snowflake"
     force_column_alias_quotes = True
     max_column_name_length = 256
 
+    # Snowflake doesn't support IS true/false syntax, use = true/false instead
+    use_equality_for_boolean_filters = True
+
     parameters_schema = SnowflakeParametersSchema()
     default_driver = "snowflake"
     sqlalchemy_uri_placeholder = "snowflake://"
 
+    semantic_layer = SnowflakeSemanticLayer
+
     supports_dynamic_schema = True
     supports_catalog = supports_dynamic_catalog = supports_cross_catalog_queries = True

diff --git a/superset/extensions/semantic_layer.py 
b/superset/extensions/semantic_layer.py
new file mode 100644
index 0000000000..c597a70619
--- /dev/null
+++ b/superset/extensions/semantic_layer.py
@@ -0,0 +1,340 @@
+import enum
+from dataclasses import dataclass
+from datetime import timedelta
+from functools import total_ordering
+from typing import Protocol, runtime_checkable
+
+from sqlalchemy import types as sqltypes
+from sqlalchemy.engine.base import Engine
+
+
+class Type:
+    """Base class for the abstract data types of metrics and dimensions."""
+
+
+class INTEGER(Type):
+    """Integer values."""
+
+
+class NUMBER(Type):
+    """Generic numeric values."""
+
+
+class DECIMAL(Type):
+    """Decimal values."""
+
+
+class STRING(Type):
+    """String values."""
+
+
+class BOOLEAN(Type):
+    """Boolean values."""
+
+
+class DATE(Type):
+    """Calendar date values."""
+
+
+class TIME(Type):
+    """Time-of-day values."""
+
+
+class DATETIME(DATE, TIME):
+    """Timestamp values; a subtype of both `DATE` and `TIME`."""
+
+
+class INTERVAL(Type):
+    """Interval (duration) values."""
+
+
+class OBJECT(Type):
+    """Structured object values."""
+
+
+class BINARY(Type):
+    """Binary values."""
+
+
+@dataclass(frozen=True)
+class SemanticView:
+    """
+    A named, optionally described group of metrics and dimensions.
+    """
+
+    name: str
+    description: str | None = None
+
+
+@dataclass(frozen=True)
+class Relation:
+    """
+    A named relation, optionally qualified by schema and catalog.
+    """
+
+    # NOTE(review): same fields as `Table`; how the two are meant to differ
+    # is not visible here -- confirm before consolidating.
+    name: str
+    schema: str | None = None
+    catalog: str | None = None
+
+
+@dataclass(frozen=True)
+class Table:
+    """
+    A table, optionally qualified by schema and catalog.
+    """
+
+    name: str
+    schema: str | None = None
+    catalog: str | None = None
+
+
+@dataclass(frozen=True)
+class View:
+    """
+    A view whose definition is the SQL text in `sql`, optionally qualified
+    by schema and catalog.
+    """
+
+    name: str
+    sql: str
+    schema: str | None = None
+    catalog: str | None = None
+
+
+@dataclass(frozen=True)
+class Virtual:
+    """
+    A relation identified only by its name, with no schema or catalog.
+    """
+
+    name: str
+
+
+@dataclass(frozen=True)
+class Metric:
+    """
+    A named metric exposed by a semantic layer (e.g. `COUNT(*)`).
+    """
+
+    # public identifier of the metric
+    name: str
+    # `Type` subclass describing the metric's values
+    type: type[Type]
+    # SQL text of the metric
+    sql: str
+    # tables associated with the metric
+    tables: frozenset[Table]
+    # presumably the columns used to join `tables` -- TODO confirm
+    join_columns: frozenset[str]
+    join_columns: frozenset[str]
+
+
+@total_ordering
+class ComparableEnum(enum.Enum):
+    """
+    Enum whose members compare and sort by their underlying values.
+
+    Equality is value-based and works across enum classes, so the hash is
+    derived from the value too. Hashing on the class (as before) broke the
+    `a == b` implies `hash(a) == hash(b)` contract that sets and dicts rely
+    on.
+    """
+
+    def __eq__(self, other: object) -> bool:
+        if isinstance(other, enum.Enum):
+            return self.value == other.value
+        return NotImplemented
+
+    def __lt__(self, other: object) -> bool:
+        if isinstance(other, enum.Enum):
+            return self.value < other.value
+        return NotImplemented
+
+    def __hash__(self) -> int:
+        # must agree with `__eq__`, which only looks at the value
+        return hash(self.value)
+
+
+class TimeGrain(ComparableEnum):
+    """
+    Sub-day grains; members order by their `timedelta` durations.
+    """
+
+    second = timedelta(seconds=1)
+    minute = timedelta(seconds=60)
+    hour = timedelta(minutes=60)
+
+
+class DateGrain(ComparableEnum):
+    """
+    Day-or-coarser grains; members order by their `timedelta` durations.
+
+    Month, quarter and year use fixed approximations (30, 90 and 365 days),
+    not exact calendar lengths.
+    """
+
+    day = timedelta(hours=24)
+    week = timedelta(days=7)
+    month = timedelta(days=30)
+    quarter = timedelta(days=3 * 30)
+    year = timedelta(days=365)
+
+
+@dataclass(frozen=True)
+class Column:
+    """
+    A named column belonging to a table, view or virtual relation.
+    """
+
+    relation: Table | View | Virtual
+    name: str
+
+
+@dataclass(frozen=True)
+class Dimension:
+    """
+    A named, typed dimension backed by a `Column`, optionally at a time or
+    date grain.
+    """
+
+    column: Column
+    name: str
+    type: type[Type]
+    grain: TimeGrain | DateGrain | None = None
+
+    def __repr__(self) -> str:
+        # e.g. "DATE order_date [month]"; the grain suffix is dropped when
+        # no grain is set
+        pieces = [
+            self.type.__name__,
+            self.name,
+            f"[{self.grain.name}]" if self.grain else "",
+        ]
+        return " ".join(pieces).strip()
+
+
+class FilterTypeEnum(enum.Enum):
+    """
+    The SQL clause a filter targets.
+    """
+
+    WHERE = enum.auto()
+    HAVING = enum.auto()
+
+
+@dataclass(frozen=True)
+class Filter:
+    """
+    A filter given as raw SQL text, tagged with the clause it targets.
+    """
+
+    type: FilterTypeEnum
+    expression: str
+
+
+class SortDirectionEnum(enum.Enum):
+    """
+    Ascending or descending sort direction.
+    """
+
+    ASC = enum.auto()
+    DESC = enum.auto()
+
+
+@dataclass(frozen=True)
+class SortField:
+    """
+    A single metric or dimension to sort by, with its direction.
+    """
+
+    field: Metric | Dimension
+    direction: SortDirectionEnum
+    # whether NULLs sort first; defaults to True
+    nulls_first: bool = True
+
+
+@dataclass(frozen=True)
+class Sort:
+    """
+    An ordered collection of sort fields.
+
+    `items` is stored as a tuple so the frozen dataclass is really immutable
+    and hashable. With a list, `hash(Sort(...))` raised `TypeError`, and the
+    shared `NoSort` instance -- the default `sort` argument of
+    `SemanticLayer.get_query` -- could be mutated in place, leaking sort
+    fields across calls.
+    """
+
+    items: tuple[SortField, ...]
+
+    def __post_init__(self) -> None:
+        # Accept any iterable (e.g. a list) for backwards compatibility.
+        # Frozen dataclasses forbid plain assignment, so the normalized value
+        # is written with `object.__setattr__`.
+        object.__setattr__(self, "items", tuple(self.items))
+
+
+@dataclass(frozen=True)
+class Query:
+    """
+    A SQL query built by a semantic layer.
+    """
+
+    sql: str
+
+
+# Default "no ordering" value; safe to share because `Sort` is immutable.
+NoSort = Sort(items=())
+
+
+@runtime_checkable
+class SemanticLayer(Protocol):
+    """
+    A generic protocol for semantic layers.
+
+    Implementations are constructed from a SQLAlchemy `Engine`. Because the
+    protocol is `runtime_checkable`, `isinstance` checks against it only
+    verify that the methods exist, not their signatures.
+    """
+
+    def __init__(self, engine: Engine) -> None: ...
+
+    def get_semantic_views(self) -> set[SemanticView]:
+        """
+        Return a set of the semantic views.
+
+        A semantic view is an organizational group of metrics and
+        dimensions. It's not a logical grouping, since metrics and
+        dimensions from a given semantic view might not be compatible. An
+        implementation might expose a single semantic view for exploration
+        of all available metrics and dimensions, and smaller curated
+        semantic views that are domain specific.
+        """
+        ...
+
+    def get_metrics(self, semantic_view: SemanticView) -> set[Metric]:
+        """
+        Return the set of metrics in the given semantic view.
+        """
+        ...
+
+    def get_dimensions(self, semantic_view: SemanticView) -> set[Dimension]:
+        """
+        Return the set of dimensions in the given semantic view.
+        """
+        ...
+
+    def get_valid_metrics(
+        self,
+        semantic_view: SemanticView,
+        metrics: set[Metric],
+        dimensions: set[Dimension],
+    ) -> set[Metric]:
+        """
+        Return compatible metrics for the given metrics and dimensions.
+
+        For metrics to be valid they must be compatible with all the
+        provided dimensions.
+        """
+        ...
+
+    def get_valid_dimensions(
+        self,
+        semantic_view: SemanticView,
+        metrics: set[Metric],
+        dimensions: set[Dimension],
+    ) -> set[Dimension]:
+        """
+        Return compatible dimensions for the given metrics and dimensions.
+
+        For dimensions to be valid they must be compatible with all the
+        provided metrics.
+        """
+        ...
+
+    def get_query(
+        self,
+        semantic_view: SemanticView,
+        metrics: set[Metric],
+        dimensions: set[Dimension],
+        # populations: set[Population],
+        filters: set[Filter],
+        sort: Sort = NoSort,
+        limit: int | None = None,
+        offset: int | None = None,
+    ) -> Query:
+        """
+        Build a SQL query from the given metrics, dimensions, filters, and
+        sort order.
+
+        `limit` and `offset` default to `None`; presumably that means "no
+        limit" / "no offset" -- confirm against implementations.
+        """
+        ...
+
+    def get_query_from_standard_sql(
+        self,
+        semantic_view: SemanticView,
+        sql: str,
+    ) -> Query:
+        """
+        Build a SQL query from a pseudo-query referencing metrics and
+        dimensions.
+
+        For example, given `metric1` having the expression `COUNT(*)`, this
+        query:
+
+            SELECT metric1, dim1
+            FROM semantic_layer
+            GROUP BY dim1
+
+        Becomes:
+
+            SELECT metric1, dim1
+            FROM (
+              SELECT COUNT(*) AS metric1, dim1
+              FROM fact_table
+              JOIN dim_table
+                ON fact_table.dim_id = dim_table.id
+              GROUP BY dim1
+            ) AS semantic_view
+
+        """
+        ...
+
+
+# Keys are `Type` subclasses (not instances), matching `Dimension.type`.
+TYPE_MAPPING: dict[type[Type], type[sqltypes.TypeEngine]] = {
+    # Numeric types
+    INTEGER: sqltypes.Integer,
+    NUMBER: sqltypes.Numeric,
+    DECIMAL: sqltypes.DECIMAL,
+    # String types
+    STRING: sqltypes.String,
+    # Boolean type
+    BOOLEAN: sqltypes.Boolean,
+    # Date/time types
+    DATE: sqltypes.Date,
+    TIME: sqltypes.Time,
+    DATETIME: sqltypes.DateTime,
+    INTERVAL: sqltypes.Interval,
+    # Complex types
+    OBJECT: sqltypes.JSON,
+    BINARY: sqltypes.LargeBinary,
+}
+
+
+def get_sqla_type_from_dimension_type(
+    dimension_type: type[Type] | Type,
+) -> sqltypes.TypeEngine:
+    """
+    Get the SQLAlchemy type corresponding to the given dimension type.
+
+    The lookup walks the class hierarchy, so a subclass of a mapped type
+    (e.g. a custom `class BIGINT(INTEGER)`) resolves to its nearest mapped
+    ancestor instead of silently degrading to `String`. Instances are
+    accepted as well as classes. Anything unmapped falls back to `String`.
+
+    :param dimension_type: a `Type` subclass (as stored in `Dimension.type`)
+        or an instance of one
+    :return: an instance of the matching SQLAlchemy type
+    """
+    cls = (
+        dimension_type
+        if isinstance(dimension_type, type)
+        else type(dimension_type)
+    )
+    # `__mro__` starts with `cls` itself, so exact matches win; for
+    # DATETIME(DATE, TIME) the DATETIME entry is found before DATE or TIME.
+    for candidate in cls.__mro__:
+        if candidate in TYPE_MAPPING:
+            return TYPE_MAPPING[candidate]()
+    return sqltypes.String()
diff --git a/superset/models/core.py b/superset/models/core.py
index 787c380bea..7e9f085aec 100755
--- a/superset/models/core.py
+++ b/superset/models/core.py
@@ -126,7 +126,9 @@ class ConfigurationMethod(StrEnum):
     DYNAMIC_FORM = "dynamic_form"
 
 
-class Database(Model, AuditMixinNullable, ImportExportMixin):  # pylint: 
disable=too-many-public-methods
+class Database(
+    Model, AuditMixinNullable, ImportExportMixin
+):  # pylint: disable=too-many-public-methods
     """An ORM object that stores Database related information"""
 
     __tablename__ = "dbs"
@@ -400,9 +402,7 @@ class Database(Model, AuditMixinNullable, 
ImportExportMixin):  # pylint: disable
         return (
             username
             if (username := get_username())
-            else object_url.username
-            if self.impersonate_user
-            else None
+            else object_url.username if self.impersonate_user else None
         )
 
     @contextmanager
@@ -987,7 +987,10 @@ class Database(Model, AuditMixinNullable, 
ImportExportMixin):  # pylint: disable
             schema=table.schema,
         ) as inspector:
             return self.db_engine_spec.get_columns(
-                inspector, table, self.schema_options
+                self,
+                inspector,
+                table,
+                self.schema_options,
             )
 
     def get_metrics(
@@ -1076,9 +1079,11 @@ class Database(Model, AuditMixinNullable, 
ImportExportMixin):  # pylint: disable
         return self.perm
 
     def has_table(self, table: Table) -> bool:
-        with self.get_sqla_engine(catalog=table.catalog, schema=table.schema) 
as engine:
-            # do not pass "" as an empty schema; force null
-            return engine.has_table(table.table, table.schema or None)
+        with self.get_inspector(
+            catalog=table.catalog,
+            schema=table.schema,
+        ) as inspector:
+            return self.db_engine_spec.has_table(self, inspector, table)
 
     def has_view(self, table: Table) -> bool:
         with self.get_sqla_engine(catalog=table.catalog, schema=table.schema) 
as engine:


Reply via email to