This is an automated email from the ASF dual-hosted git repository.

beto pushed a commit to branch snowflake-semantic-layer
in repository https://gitbox.apache.org/repos/asf/superset.git

commit cd019bab3e9a89c313d8fa593e1e2991af8ecf7b
Author: Beto Dealmeida <[email protected]>
AuthorDate: Wed Jul 23 15:51:07 2025 -0400

    WIP
---
 superset/db_engine_specs/base.py      |  30 ++++--
 superset/db_engine_specs/snowflake.py | 172 +++++++++++++++++++++++++++++++++-
 superset/extensions/semantic_layer.py |  30 ++++++
 3 files changed, 224 insertions(+), 8 deletions(-)

diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py
index 25e5f1b834..f7594f3844 100644
--- a/superset/db_engine_specs/base.py
+++ b/superset/db_engine_specs/base.py
@@ -30,6 +30,7 @@ from typing import (
     cast,
     ContextManager,
     NamedTuple,
+    Type,
     TYPE_CHECKING,
     TypedDict,
     Union,
@@ -62,7 +63,10 @@ from superset.constants import QUERY_CANCEL_KEY, TimeGrain 
as TimeGrainConstants
 from superset.databases.utils import get_table_metadata, make_url_safe
 from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
 from superset.exceptions import OAuth2Error, OAuth2RedirectError
-from superset.extensions.semantic_layer import SemanticLayer
+from superset.extensions.semantic_layer import (
+    get_sqla_type_from_dimension_type,
+    SemanticLayer,
+)
 from superset.sql.parse import (
     BaseSQLStatement,
     LimitMethod,
@@ -107,6 +111,15 @@ logger = logging.getLogger()
 GenericDBException = Exception
 
 
+class ValidColumnsType(TypedDict):
+    """
+    Type for valid columns returned by `get_valid_metrics_and_dimensions`.
+    """
+
+    dimensions: set[str]
+    metrics: set[str]
+
+
 def convert_inspector_columns(cols: list[SQLAColumnType]) -> 
list[ResultSetColumnType]:
     result_set_columns: list[ResultSetColumnType] = []
     for col in cols:
@@ -218,7 +231,7 @@ class BaseEngineSpec:  # pylint: 
disable=too-many-public-methods
     )
 
     # databases can optionally specify a semantic layer
-    semantic_layer: SemanticLayer | None = None
+    semantic_layer: Type[SemanticLayer] | None = None
 
     disable_ssh_tunneling = False
 
@@ -1468,12 +1481,12 @@ class BaseEngineSpec:  # pylint: 
disable=too-many-public-methods
         if schema and cls.try_remove_schema_from_table_name:
             tables = {re.sub(f"^{schema}\\.", "", table) for table in tables}
 
+        # add semantic views as tables too
         if cls.semantic_layer:
+            semantic_layer = cls.semantic_layer(inspector.engine)
             tables.update(
                 semantic_view.name
-                for semantic_view in cls.semantic_layer(
-                    inspector.engine
-                ).get_semantic_views()
+                for semantic_view in semantic_layer.get_semantic_views()
             )
 
         return tables
@@ -1579,7 +1592,7 @@ class BaseEngineSpec:  # pylint: 
disable=too-many-public-methods
                     {
                         "name": dimension.name,
                         "column_name": dimension.name,
-                        "type": XXX,
+                        "type": 
get_sqla_type_from_dimension_type(dimension.type),
                     }
                     for dimension in 
semantic_layer.get_dimensions(semantic_view)
                 ]
@@ -1948,6 +1961,11 @@ class BaseEngineSpec:  # pylint: 
disable=too-many-public-methods
         :param kwargs: kwargs to be passed to cursor.execute()
         :return:
         """
+        if cls.semantic_layer:
+            with database.get_engine() as engine:
+                semantic_layer = cls.semantic_layer(engine)
+                query = semantic_layer.get_query_from_standard_sql(query).sql
+
         if cls.arraysize:
             cursor.arraysize = cls.arraysize
         try:
diff --git a/superset/db_engine_specs/snowflake.py 
b/superset/db_engine_specs/snowflake.py
index 5b064c22c7..be13acf839 100644
--- a/superset/db_engine_specs/snowflake.py
+++ b/superset/db_engine_specs/snowflake.py
@@ -16,11 +16,13 @@
 # under the License.
 from __future__ import annotations
 
+import itertools
 import logging
 import re
+from collections import defaultdict
 from datetime import datetime
 from re import Pattern
-from typing import Any, Optional, TYPE_CHECKING, TypedDict
+from typing import Any, Iterator, Optional, TYPE_CHECKING, TypedDict
 from urllib import parse
 
 from apispec import APISpec
@@ -30,7 +32,7 @@ from cryptography.hazmat.primitives import serialization
 from flask import current_app
 from flask_babel import gettext as __
 from marshmallow import fields, Schema
-from sqlalchemy import types
+from sqlalchemy import text, types
 from sqlalchemy.engine.reflection import Inspector
 from sqlalchemy.engine.url import URL
 
@@ -39,11 +41,36 @@ from superset.databases.utils import make_url_safe
 from superset.db_engine_specs.base import BaseEngineSpec, BasicPropertiesType
 from superset.db_engine_specs.postgres import PostgresBaseEngineSpec
 from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
+from superset.extensions.semantic_layer import (
+    BINARY,
+    BOOLEAN,
+    Column as SemanticColumn,
+    DATE,
+    DATETIME,
+    DECIMAL,
+    Dimension as SemanticDimension,
+    Filter as SemanticFilter,
+    INTEGER,
+    Metric as SemanticMetric,
+    NoSort,
+    NUMBER,
+    OBJECT,
+    Query as SemanticQuery,
+    SemanticView,
+    Sort as SemanticSort,
+    STRING,
+    Table as SemanticTable,
+    TIME,
+    Type as SemanticType,
+)
 from superset.models.sql_lab import Query
+from superset.sql.parse import Table
 from superset.utils import json
 from superset.utils.core import get_user_agent, QuerySource
 
 if TYPE_CHECKING:
+    from sqlalchemy.engine.base import Engine
+
     from superset.models.core import Database
 
 # Regular expressions to catch custom errors
@@ -77,6 +104,145 @@ class SnowflakeParametersType(TypedDict):
     warehouse: str
 
 
+class SnowflakeSemanticLayer:
+    def __init__(self, engine: Engine) -> None:
+        self.engine = engine
+
+    def execute(
+        self,
+        sql: str,
+        **kwargs: Any,
+    ) -> Iterator[dict[str, Any]]:
+        with self.engine.connect() as connection:
+            for row in connection.execute(text(sql), kwargs).mappings():
+                yield dict(row)
+
+    def get_semantic_views(self) -> set[SemanticView]:
+        sql = """
+SHOW SEMANTIC VIEWS
+    ->> SELECT "name" FROM $1;
+        """
+        return {row["name"] for row in self.execute(sql)}
+
+    def get_type(self, snowflake_type: str | None) -> type[SemanticType]:
+        if snowflake_type is None:
+            return STRING
+
+        type_map = {
+            STRING: {r"VARCHAR\(\d+\)", "STRING", "TEXT", r"CHAR\(\d+\)"},
+            INTEGER: {r"NUMBER\(38,\s0\)", "INT", "INTEGER", "BIGINT"},
+            DECIMAL: {r"NUMBER\(10,\s2\)"},
+            NUMBER: {r"NUMBER\(\d+,\s\d+)", "FLOAT", "DOUBLE"},
+            BOOLEAN: {"BOOLEAN"},
+            DATE: {"DATE"},
+            DATETIME: {"TIMESTAMP_TZ", "TIMESTAMP__NTZ"},
+            TIME: {"TIME"},
+            OBJECT: {"OBJECT"},
+            BINARY: {r"BINARY\(\d+\)", r"VARBINARY\(\d+\)"},
+        }
+        for semantic_type, patterns in type_map.items():
+            if any(
+                re.match(pattern, snowflake_type, re.IGNORECASE) for pattern 
in patterns
+            ):
+                return semantic_type
+
+        return STRING
+
+    def get_metrics(self, semantic_view: SemanticView) -> set[SemanticMetric]:
+        quoted_semantic_view_name = self.quote_table(
+            Table(semantic_view.name),
+            self.engine.dialect,
+        )
+        sql = f"""
+DESC SEMANTIC VIEW {quoted_semantic_view_name}
+    ->> SELECT "object_name", "property", "property_value"
+        FROM $1
+        WHERE
+            "object_kind" = 'METRIC' AND
+            "property" IN ('EXPRESSION', 'DATA_TYPE', 'TABLE');
+        """  # noqa: S608 (semantic_view.name is quoted)
+        rows = self.execute(sql)
+
+        metrics: set[SemanticMetric] = set()
+        for name, group in itertools.groupby(rows, key=lambda x: 
x["object_name"]):
+            attributes = defaultdict(set)
+            for row in group:
+                attributes[row["property"]].add(row["property_value"])
+
+            type_ = self.get_type(next(iter(attributes["DATA_TYPE"]), None))
+            sql = next(iter(attributes["EXPRESSION"]), name)
+            tables = frozenset(attributes["TABLE"])
+            join_columns = frozenset()
+
+            metrics.add(SemanticMetric(name, type_, sql, tables, join_columns))
+
+        return metrics
+
+    def get_dimensions(self, semantic_view: SemanticView) -> 
set[SemanticDimension]:
+        quoted_semantic_view_name = self.quote_table(
+            Table(semantic_view.name),
+            self.engine.dialect,
+        )
+        sql = f"""
+DESC SEMANTIC VIEW {quoted_semantic_view_name}
+    ->> SELECT "object_name", "property", "property_value"
+        FROM $1
+        WHERE
+            "object_kind" = 'DIMENSION' AND
+            "property" IN ('DATA_TYPE', 'TABLE', 'EXPRESSION');
+        """  # noqa: S608 (semantic_view.name is quoted)
+        rows = self.execute(sql)
+
+        dimensions: set[SemanticDimension] = set()
+        for name, group in itertools.groupby(rows, key=lambda x: 
x["object_name"]):
+            attributes = defaultdict(set)
+            for row in group:
+                attributes[row["property"]].add(row["property_value"])
+
+            table = next(iter(attributes["TABLE"]), None)
+            expression = next(iter(attributes["EXPRESSION"]), None)
+            column = (
+                SemanticColumn(SemanticTable(table), expression)
+                if table and expression
+                else None
+            )
+            type_ = self.get_type(next(iter(attributes["DATA_TYPE"]), None))
+
+            dimensions.add(SemanticDimension(column, name, type_))
+
+        return dimensions
+
+    def get_valid_metrics(
+        self,
+        semantic_view: SemanticView,
+        metrics: set[SemanticMetric],
+        dimensions: set[SemanticDimension],
+    ) -> set[SemanticMetric]:
+        # all metrics and dimensions are valid inside a given semantic view
+        return self.get_metrics(semantic_view)
+
+    def get_valid_dimensions(
+        self,
+        semantic_view: SemanticView,
+        metrics: set[SemanticMetric],
+        dimensions: set[SemanticDimension],
+    ) -> set[SemanticDimension]:
+        # all metrics and dimensions are valid inside a given semantic view
+        return self.get_dimensions(semantic_view)
+
+    def get_query(
+        self,
+        semantic_view: SemanticView,
+        metrics: set[SemanticMetric],
+        dimensions: set[SemanticDimension],
+        filters: set[SemanticFilter],
+        sort: SemanticSort = NoSort,
+        limit: int | None = None,
+        offset: int | None = None,
+    ) -> SemanticQuery:
+        pass
+
+
 class SnowflakeEngineSpec(PostgresBaseEngineSpec):
     engine = "snowflake"
     engine_name = "Snowflake"
@@ -90,6 +256,8 @@ class SnowflakeEngineSpec(PostgresBaseEngineSpec):
     default_driver = "snowflake"
     sqlalchemy_uri_placeholder = "snowflake://"
 
+    semantic_layer = SnowflakeSemanticLayer
+
     supports_dynamic_schema = True
     supports_catalog = supports_dynamic_catalog = 
supports_cross_catalog_queries = True
 
diff --git a/superset/extensions/semantic_layer.py 
b/superset/extensions/semantic_layer.py
index 85b39f1867..ac278788f7 100644
--- a/superset/extensions/semantic_layer.py
+++ b/superset/extensions/semantic_layer.py
@@ -4,6 +4,7 @@ from datetime import timedelta
 from functools import total_ordering
 from typing import Protocol, runtime_checkable
 
+from sqlalchemy import types as sqltypes
 from sqlalchemy.engine.base import Engine
 
 
@@ -308,3 +309,32 @@ class SemanticLayer(Protocol):
 
         """
         ...
+
+
+TYPE_MAPPING: dict[Type, type[sqltypes.TypeEngine]] = {
+    # Numeric types
+    INTEGER: sqltypes.Integer,
+    NUMBER: sqltypes.Numeric,
+    DECIMAL: sqltypes.DECIMAL,
+    # String types
+    STRING: sqltypes.String,
+    # Boolean type
+    BOOLEAN: sqltypes.Boolean,
+    # Date/time types
+    DATE: sqltypes.Date,
+    TIME: sqltypes.Time,
+    DATETIME: sqltypes.DateTime,
+    INTERVAL: sqltypes.Interval,
+    # Complex types
+    OBJECT: sqltypes.JSON,
+    BINARY: sqltypes.LargeBinary,
+}
+
+
+def get_sqla_type_from_dimension_type(
+    dimension_type: Type,
+) -> type[sqltypes.TypeEngine]:
+    """
+    Get the SQLAlchemy type corresponding to the given dimension type.
+    """
+    return TYPE_MAPPING.get(dimension_type, sqltypes.String)

Reply via email to