This is an automated email from the ASF dual-hosted git repository. beto pushed a commit to branch snowflake-semantic-layer in repository https://gitbox.apache.org/repos/asf/superset.git
commit cd019bab3e9a89c313d8fa593e1e2991af8ecf7b Author: Beto Dealmeida <[email protected]> AuthorDate: Wed Jul 23 15:51:07 2025 -0400 WIP --- superset/db_engine_specs/base.py | 30 ++++-- superset/db_engine_specs/snowflake.py | 172 +++++++++++++++++++++++++++++++++- superset/extensions/semantic_layer.py | 30 ++++++ 3 files changed, 224 insertions(+), 8 deletions(-) diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index 25e5f1b834..f7594f3844 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -30,6 +30,7 @@ from typing import ( cast, ContextManager, NamedTuple, + Type, TYPE_CHECKING, TypedDict, Union, @@ -62,7 +63,10 @@ from superset.constants import QUERY_CANCEL_KEY, TimeGrain as TimeGrainConstants from superset.databases.utils import get_table_metadata, make_url_safe from superset.errors import ErrorLevel, SupersetError, SupersetErrorType from superset.exceptions import OAuth2Error, OAuth2RedirectError -from superset.extensions.semantic_layer import SemanticLayer +from superset.extensions.semantic_layer import ( + get_sqla_type_from_dimension_type, + SemanticLayer, +) from superset.sql.parse import ( BaseSQLStatement, LimitMethod, @@ -107,6 +111,15 @@ logger = logging.getLogger() GenericDBException = Exception +class ValidColumnsType(TypedDict): + """ + Type for valid columns returned by `get_valid_metrics_and_dimensions`. + """ + + dimensions: set[str] + metrics: set[str] + + def convert_inspector_columns(cols: list[SQLAColumnType]) -> list[ResultSetColumnType]: result_set_columns: list[ResultSetColumnType] = [] for col in cols: @@ -218,7 +231,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods ) # databases can optionally specify a semantic layer - semantic_layer: SemanticLayer | None = None + semantic_layer: Type[SemanticLayer] | None = None disable_ssh_tunneling = False @@ -1468,12 +1481,12 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods if schema and cls.try_remove_schema_from_table_name: tables = {re.sub(f"^{schema}\\.", "", table) for table in tables} + # add semantic views as tables too if cls.semantic_layer: + semantic_layer = cls.semantic_layer(inspector.engine) tables.update( semantic_view.name - for semantic_view in cls.semantic_layer( - inspector.engine - ).get_semantic_views() + for semantic_view in semantic_layer.get_semantic_views() ) return tables @@ -1579,7 +1592,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods { "name": dimension.name, "column_name": dimension.name, - "type": XXX, + "type": get_sqla_type_from_dimension_type(dimension.type), } for dimension in semantic_layer.get_dimensions(semantic_view) ] @@ -1948,6 +1961,11 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods :param kwargs: kwargs to be passed to cursor.execute() :return: """ + if cls.semantic_layer: + with database.get_engine() as engine: + semantic_layer = cls.semantic_layer(engine) + query = semantic_layer.get_query_from_standard_sql(query).sql + if cls.arraysize: cursor.arraysize = cls.arraysize try: diff --git a/superset/db_engine_specs/snowflake.py b/superset/db_engine_specs/snowflake.py index 5b064c22c7..be13acf839 100644 --- a/superset/db_engine_specs/snowflake.py +++ b/superset/db_engine_specs/snowflake.py @@ -16,11 +16,13 @@ # under the License. from __future__ import annotations +import itertools import logging import re +from collections import defaultdict from datetime import datetime from re import Pattern -from typing import Any, Optional, TYPE_CHECKING, TypedDict +from typing import Any, Iterator, Optional, TYPE_CHECKING, TypedDict from urllib import parse from apispec import APISpec @@ -30,7 +32,7 @@ from cryptography.hazmat.primitives import serialization from flask import current_app from flask_babel import gettext as __ from marshmallow import fields, Schema -from sqlalchemy import types +from sqlalchemy import text, types from sqlalchemy.engine.reflection import Inspector from sqlalchemy.engine.url import URL @@ -39,11 +41,36 @@ from superset.databases.utils import make_url_safe from superset.db_engine_specs.base import BaseEngineSpec, BasicPropertiesType from superset.db_engine_specs.postgres import PostgresBaseEngineSpec from superset.errors import ErrorLevel, SupersetError, SupersetErrorType +from superset.extensions.semantic_layer import ( + BINARY, + BOOLEAN, + Column as SemanticColumn, + DATE, + DATETIME, + DECIMAL, + Dimension as SemanticDimension, + Filter as SemanticFilter, + INTEGER, + Metric as SemanticMetric, + NoSort, + NUMBER, + OBJECT, + Query as SemanticQuery, + SemanticView, + Sort as SemanticSort, + STRING, + Table as SemanticTable, + TIME, + Type as SemanticType, +) from superset.models.sql_lab import Query +from superset.sql.parse import Table from superset.utils import json from superset.utils.core import get_user_agent, QuerySource if TYPE_CHECKING: + from sqlalchemy.engine.base import Engine + from superset.models.core import Database # Regular expressions to catch custom errors @@ -77,6 +104,145 @@ class SnowflakeParametersType(TypedDict): warehouse: str +class SnowflakeSemanticLayer: + def __init__(self, engine: Engine) -> None: + self.engine = engine + + def execute( + self, + sql: str, + **kwargs: Any, + ) -> Iterator[dict[str, Any]]: + with self.engine.connect() as connection: + for row in connection.execute(text(sql), kwargs).mappings(): + yield dict(row) + + def get_semantic_views(self) -> set[SemanticView]: + sql = """ +SHOW SEMANTIC VIEWS + ->> SELECT "name" FROM $1; + """ + return {row["name"] for row in self.execute(sql)} + + def get_type(self, snowflake_type: str | None) -> type[SemanticType]: + if snowflake_type is None: + return STRING + + type_map = { + STRING: {r"VARCHAR\(\d+\)", "STRING", "TEXT", r"CHAR\(\d+\)"}, + INTEGER: {r"NUMBER\(38,\s0\)", "INT", "INTEGER", "BIGINT"}, + DECIMAL: {r"NUMBER\(10,\s2\)"}, + NUMBER: {r"NUMBER\(\d+,\s\d+)", "FLOAT", "DOUBLE"}, + BOOLEAN: {"BOOLEAN"}, + DATE: {"DATE"}, + DATETIME: {"TIMESTAMP_TZ", "TIMESTAMP__NTZ"}, + TIME: {"TIME"}, + OBJECT: {"OBJECT"}, + BINARY: {r"BINARY\(\d+\)", r"VARBINARY\(\d+\)"}, + } + for semantic_type, patterns in type_map.items(): + if any( + re.match(pattern, snowflake_type, re.IGNORECASE) for pattern in patterns + ): + return semantic_type + + return STRING + + def get_metrics(self, semantic_view: SemanticView) -> set[SemanticMetric]: + quoted_semantic_view_name = self.quote_table( + Table(semantic_view.name), + self.engine.dialect, + ) + sql = f""" +DESC SEMANTIC VIEW {quoted_semantic_view_name} + ->> SELECT "object_name", "property", "property_value" + FROM $1 + WHERE + "object_kind" = 'METRIC' AND + "property" IN ('EXPRESSION', 'DATA_TYPE', 'TABLE'); + """ # noqa: S608 (semantic_view.name is quoted) + rows = self.execute(sql) + + metrics: set[SemanticMetric] = set() + for name, group in itertools.groupby(rows, key=lambda x: x["object_name"]): + attributes = defaultdict(set) + for row in group: + attributes[row["property"]].add(row["property_value"]) + + type_ = self.get_type(next(iter(attributes["DATA_TYPE"]), None)) + sql = next(iter(attributes["EXPRESSION"]), name) + tables = frozenset(attributes["TABLE"]) + join_columns = frozenset() + + metrics.add(SemanticMetric(name, type_, sql, tables, join_columns)) + + return metrics + + def get_dimensions(self, semantic_view: SemanticView) -> set[SemanticDimension]: + quoted_semantic_view_name = self.quote_table( + Table(semantic_view.name), + self.engine.dialect, + ) + sql = f""" +DESC SEMANTIC VIEW {quoted_semantic_view_name} + ->> SELECT "object_name", "property", "property_value" + FROM $1 + WHERE + "object_kind" = 'DIMENSION' AND + "property" IN ('DATA_TYPE', 'TABLE', 'EXPRESSION'); + """ # noqa: S608 (semantic_view.name is quoted) + rows = self.execute(sql) + + dimensions: set[SemanticDimension] = set() + for name, group in itertools.groupby(rows, key=lambda x: x["object_name"]): + attributes = defaultdict(set) + for row in group: + attributes[row["property"]].add(row["property_value"]) + + table = next(iter(attributes["TABLE"]), None) + expression = next(iter(attributes["EXPRESSION"]), None) + column = ( + SemanticColumn(SemanticTable(table), expression) + if table and expression + else None + ) + type_ = self.get_type(next(iter(attributes["DATA_TYPE"]), None)) + + dimensions.add(SemanticDimension(column, name, type_)) + + return dimensions + + def get_valid_metrics( + self, + semantic_view: SemanticView, + metrics: set[SemanticMetric], + dimensions: set[SemanticDimension], + ) -> set[SemanticMetric]: + # all metrics and dimensions are valid inside a given semantic view + return self.get_metrics(semantic_view) + + def get_valid_dimensions( + self, + semantic_view: SemanticView, + metrics: set[SemanticMetric], + dimensions: set[SemanticDimension], + ) -> set[SemanticDimension]: + # all metrics and dimensions are valid inside a given semantic view + return self.get_dimensions(semantic_view) + + def get_query( + self, + semantic_view: SemanticView, + metrics: set[SemanticMetric], + dimensions: set[SemanticDimension], + filters: set[SemanticFilter], + sort: SemanticSort = NoSort, + limit: int | None = None, + offset: int | None = None, + ) -> SemanticQuery: + pass + + class SnowflakeEngineSpec(PostgresBaseEngineSpec): engine = "snowflake" engine_name = "Snowflake" @@ -90,6 +256,8 @@ class SnowflakeEngineSpec(PostgresBaseEngineSpec): default_driver = "snowflake" sqlalchemy_uri_placeholder = "snowflake://" + semantic_layer = SnowflakeSemanticLayer + supports_dynamic_schema = True supports_catalog = supports_dynamic_catalog = supports_cross_catalog_queries = True diff --git a/superset/extensions/semantic_layer.py b/superset/extensions/semantic_layer.py index 85b39f1867..ac278788f7 100644 --- a/superset/extensions/semantic_layer.py +++ b/superset/extensions/semantic_layer.py @@ -4,6 +4,7 @@ from datetime import timedelta from functools import total_ordering from typing import Protocol, runtime_checkable +from sqlalchemy import types as sqltypes from sqlalchemy.engine.base import Engine @@ -308,3 +309,32 @@ class SemanticLayer(Protocol): """ ... + + +TYPE_MAPPING: dict[Type, type[sqltypes.TypeEngine]] = { + # Numeric types + INTEGER: sqltypes.Integer, + NUMBER: sqltypes.Numeric, + DECIMAL: sqltypes.DECIMAL, + # String types + STRING: sqltypes.String, + # Boolean type + BOOLEAN: sqltypes.Boolean, + # Date/time types + DATE: sqltypes.Date, + TIME: sqltypes.Time, + DATETIME: sqltypes.DateTime, + INTERVAL: sqltypes.Interval, + # Complex types + OBJECT: sqltypes.JSON, + BINARY: sqltypes.LargeBinary, +} + + +def get_sqla_type_from_dimension_type( + dimension_type: Type, +) -> type[sqltypes.TypeEngine]: + """ + Get the SQLAlchemy type corresponding to the given dimension type. + """ + return TYPE_MAPPING.get(dimension_type, sqltypes.String)
