This is an automated email from the ASF dual-hosted git repository. beto pushed a commit to branch canva-demo in repository https://gitbox.apache.org/repos/asf/superset.git
commit 66cb9cb40c7c27091234d2bf1c9ce07cce419107 Author: Beto Dealmeida <[email protected]> AuthorDate: Wed Jul 30 17:46:02 2025 -0400 Snowflake --- superset/daos/dataset.py | 11 +- superset/db_engine_specs/base.py | 223 ++++++++++++++++++++-- superset/db_engine_specs/snowflake.py | 336 ++++++++++++++++++++++++++++++++- superset/extensions/semantic_layer.py | 340 ++++++++++++++++++++++++++++++++++ superset/models/core.py | 21 ++- 5 files changed, 895 insertions(+), 36 deletions(-) diff --git a/superset/daos/dataset.py b/superset/daos/dataset.py index 5fbd012801..9221af7575 100644 --- a/superset/daos/dataset.py +++ b/superset/daos/dataset.py @@ -75,12 +75,11 @@ class DatasetDAO(BaseDAO[SqlaTable]): database: Database, table: Table, ) -> bool: - try: - database.get_table(table) - return True - except SQLAlchemyError as ex: # pragma: no cover - logger.warning("Got an error %s validating table: %s", str(ex), table) - return False + with database.get_inspector( + catalog=table.catalog, + schema=table.schema, + ) as inspector: + return database.db_engine_spec.has_table(database, inspector, table) @staticmethod def validate_uniqueness( diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index f8449f4030..47bd1afec7 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -30,6 +30,7 @@ from typing import ( cast, ContextManager, NamedTuple, + Type, TYPE_CHECKING, TypedDict, Union, @@ -54,7 +55,7 @@ from sqlalchemy.engine.reflection import Inspector from sqlalchemy.engine.url import URL from sqlalchemy.ext.compiler import compiles from sqlalchemy.sql import literal_column, quoted_name, text -from sqlalchemy.sql.expression import ColumnClause, Select, TextClause +from sqlalchemy.sql.expression import BinaryExpression, ColumnClause, Select, TextClause from sqlalchemy.types import TypeEngine from superset import db @@ -62,6 +63,10 @@ from superset.constants import QUERY_CANCEL_KEY, TimeGrain as TimeGrainConstants from 
superset.databases.utils import get_table_metadata, make_url_safe from superset.errors import ErrorLevel, SupersetError, SupersetErrorType from superset.exceptions import OAuth2Error, OAuth2RedirectError +from superset.extensions.semantic_layer import ( + get_sqla_type_from_dimension_type, + SemanticLayer, +) from superset.sql.parse import ( BaseSQLStatement, LimitMethod, @@ -106,6 +111,15 @@ logger = logging.getLogger() GenericDBException = Exception +class ValidColumnsType(TypedDict): + """ + Type for valid columns returned by `get_valid_metrics_and_dimensions`. + """ + + dimensions: set[str] + metrics: set[str] + + def convert_inspector_columns(cols: list[SQLAColumnType]) -> list[ResultSetColumnType]: result_set_columns: list[ResultSetColumnType] = [] for col in cols: @@ -188,15 +202,6 @@ class MetricType(TypedDict, total=False): extra: str | None -class ValidColumnsType(TypedDict): - """ - Type for valid columns returned by `get_valid_metrics_and_dimensions`. - """ - - dimensions: set[str] - metrics: set[str] - - class BaseEngineSpec: # pylint: disable=too-many-public-methods """Abstract class for database engine specific configurations @@ -225,6 +230,9 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods "engine+driver://user:password@host:port/dbname[?key=value&key=value...]" ) + # databases can optionally specify a semantic layer + semantic_layer: Type[SemanticLayer] | None = None + disable_ssh_tunneling = False _date_trunc_functions: dict[str, str] = {} @@ -388,6 +396,10 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods disallow_uri_query_params: dict[str, set[str]] = {} # A Dict of query parameters that will always be used on every connection # by driver name + + # Whether to use equality operators (= true/false) instead of IS operators + # for boolean filters. 
Some databases like Snowflake don't support IS true/false + use_equality_for_boolean_filters = False enforce_uri_query_params: dict[str, dict[str, Any]] = {} force_column_alias_quotes = False @@ -1218,6 +1230,78 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods """ return None + @classmethod + def handle_boolean_filter( + cls, sqla_col: Any, op: str, value: bool + ) -> BinaryExpression: + """ + Handle boolean filter operations with engine-specific logic. + + By default, uses SQLAlchemy's IS operator (column IS true/false). + Engines that don't support IS for boolean values can override + use_equality_for_boolean_filters to use equality operators instead. + + :param sqla_col: SQLAlchemy column element + :param op: Filter operator (IS_TRUE or IS_FALSE) + :param value: Boolean value (True or False) + :return: SQLAlchemy expression for the boolean filter + """ + if cls.use_equality_for_boolean_filters: + return sqla_col == value + else: + return sqla_col.is_(value) + + @classmethod + def handle_null_filter( + cls, + sqla_col: Any, + op: utils.FilterOperator, + ) -> BinaryExpression: + """ + Handle null/not null filter operations. + + :param sqla_col: SQLAlchemy column element + :param op: Filter operator (IS_NULL or IS_NOT_NULL) + :return: SQLAlchemy expression for the null filter + """ + from superset.utils import core as utils + + if op == utils.FilterOperator.IS_NULL: + return sqla_col.is_(None) + elif op == utils.FilterOperator.IS_NOT_NULL: + return sqla_col.isnot(None) + else: + raise ValueError(f"Invalid null filter operator: {op}") + + @classmethod + def handle_comparison_filter( + cls, sqla_col: Any, op: utils.FilterOperator, value: Any + ) -> BinaryExpression: + """ + Handle comparison filter operations (=, !=, >, <, >=, <=). 
+ + :param sqla_col: SQLAlchemy column element + :param op: Filter operator + :param value: Filter value + :return: SQLAlchemy expression for the comparison filter + """ + from superset.utils import core as utils + + if op == utils.FilterOperator.EQUALS: + return sqla_col == value + elif op == utils.FilterOperator.NOT_EQUALS: + return sqla_col != value + elif op == utils.FilterOperator.GREATER_THAN: + return sqla_col > value + elif op == utils.FilterOperator.LESS_THAN: + return sqla_col < value + elif op == utils.FilterOperator.GREATER_THAN_OR_EQUALS: + return sqla_col >= value + elif op == utils.FilterOperator.LESS_THAN_OR_EQUALS: + return sqla_col <= value + else: + raise ValueError(f"Invalid comparison filter operator: {op}") + @classmethod def handle_cursor(cls, cursor: Any, query: Query) -> None: """Handle a live cursor between the execute and fetchall calls @@ -1401,8 +1485,32 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods if schema and cls.try_remove_schema_from_table_name: tables = {re.sub(f"^{schema}\\.", "", table) for table in tables} + + # add semantic views as tables too + if cls.semantic_layer: + semantic_layer = cls.semantic_layer(inspector.engine) + tables.update( + semantic_view.name + for semantic_view in semantic_layer.get_semantic_views() + ) + return tables + @classmethod + def has_table( + cls, + database: Database, + inspector: Inspector, + table: Table, + ) -> bool: + if cls.semantic_layer: + semantic_layer = cls.semantic_layer(inspector.engine) + semantic_views = semantic_layer.get_semantic_views() + if table.table in {semantic_view.name for semantic_view in semantic_views}: + return True + + return inspector.has_table(table.table, table.schema) + @classmethod def get_view_names( # pylint: disable=unused-argument cls, @@ -1476,6 +1584,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods @classmethod def get_columns( # pylint: disable=unused-argument cls, + database: Database, inspector: Inspector, table: 
Table, options: dict[str, Any] | None = None, @@ -1483,7 +1592,9 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods """ Get all columns from a given schema and table. - The inspector will be bound to a catalog, if one was specified. + The inspector will be bound to a catalog, if one was specified. If the database + supports semantic layers the method will check if the table is a semantic view, + and return columns (metrics and dimensions) from it instead. :param inspector: SqlAlchemy Inspector instance :param table: Table instance @@ -1491,6 +1602,26 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods some databases :return: All columns in table """ + if cls.semantic_layer: + semantic_layer = cls.semantic_layer(inspector.engine) + semantic_views = { + semantic_view.name: semantic_view + for semantic_view in semantic_layer.get_semantic_views() + } + if semantic_view := semantic_views.get(table.table): + dialect = database.get_dialect() + return [ + { + "name": dimension.name, + "column_name": dimension.name, + "type": cls.column_datatype_to_string( + get_sqla_type_from_dimension_type(dimension.type), + dialect, + ), + } + for dimension in semantic_layer.get_dimensions(semantic_view) + ] + return convert_inspector_columns( cast( list[SQLAColumnType], @@ -1508,6 +1639,22 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods """ Get all metrics from a given schema and table. 
""" + if cls.semantic_layer: + semantic_layer = cls.semantic_layer(inspector.engine) + semantic_views = { + semantic_view.name: semantic_view + for semantic_view in semantic_layer.get_semantic_views() + } + if semantic_view := semantic_views.get(table.table): + return [ + { + "metric_name": metric.name, + "verbose_name": metric.name, + "expression": metric.sql, + } + for metric in semantic_layer.get_metrics(semantic_view) + ] + return [ { "metric_name": "count", @@ -1526,17 +1673,48 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods metrics: set[str], ) -> ValidColumnsType: """ - Given a selection of columns/metrics from a datasource, return related columns. - - This is a method used for semantic layers, where tables can have columns and - metrics that cannot be computed together. When the user selects a given metric - it allows the UI to filter the remaining metrics and dimensions so that only - valid combinations are possible. + Get valid metrics and dimensions. - The method should only be called when ``supports_dynamic_columns`` is set to - true. The default method in the base class ignores the selected columns and - metrics, and simply returns everything, for reference. + Given a datasource, and sets of selected metrics and dimensions, return the + sets of valid metrics and dimensions that can further be selected. 
""" + if cls.semantic_layer: + with database.get_sqla_engine() as engine: + semantic_layer = cls.semantic_layer(engine) + semantic_views = { + semantic_view.name: semantic_view + for semantic_view in semantic_layer.get_semantic_views() + } + if semantic_view := semantic_views.get(table.table): + selected_metrics = { + metric + for metric in semantic_layer.get_metrics(semantic_view) + if metric.name in metrics + } + selected_dimensions = { + dimension + for dimension in semantic_layer.get_dimensions(semantic_view) + if dimension.name in dimensions + } + return { + "metrics": { + metric.name + for metric in semantic_layer.get_valid_metrics( + semantic_view, + selected_metrics, + selected_dimensions, + ) + }, + "dimensions": { + dimension.name + for dimension in semantic_layer.get_valid_dimensions( + semantic_view, + selected_metrics, + selected_dimensions, + ) + }, + } + return { "dimensions": {column.column_name for column in table.columns}, "metrics": {metric.metric_name for metric in table.metrics}, @@ -1808,6 +1986,11 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods :param kwargs: kwargs to be passed to cursor.execute() :return: """ + if cls.semantic_layer: + with cls.get_engine(database, schema="tpcds_sf10tcl") as engine: + semantic_layer = cls.semantic_layer(engine) + query = semantic_layer.get_query_from_standard_sql(query).sql + if cls.arraysize: cursor.arraysize = cls.arraysize try: diff --git a/superset/db_engine_specs/snowflake.py b/superset/db_engine_specs/snowflake.py index 7e353f006d..2056e7b4bc 100644 --- a/superset/db_engine_specs/snowflake.py +++ b/superset/db_engine_specs/snowflake.py @@ -16,11 +16,13 @@ # under the License. 
from __future__ import annotations +import itertools import logging import re +from collections import defaultdict from datetime import datetime from re import Pattern -from typing import Any, Optional, TYPE_CHECKING, TypedDict +from typing import Any, Iterator, Optional, TYPE_CHECKING, TypedDict from urllib import parse from apispec import APISpec @@ -30,20 +32,48 @@ from cryptography.hazmat.primitives import serialization from flask import current_app from flask_babel import gettext as __ from marshmallow import fields, Schema -from sqlalchemy import types +from sqlalchemy import text, types +from sqlalchemy.engine.interfaces import Dialect from sqlalchemy.engine.reflection import Inspector from sqlalchemy.engine.url import URL +from sqlglot import exp, parse_one from superset.constants import TimeGrain from superset.databases.utils import make_url_safe from superset.db_engine_specs.base import BaseEngineSpec, BasicPropertiesType from superset.db_engine_specs.postgres import PostgresBaseEngineSpec from superset.errors import ErrorLevel, SupersetError, SupersetErrorType +from superset.extensions.semantic_layer import ( + BINARY, + BOOLEAN, + Column as SemanticColumn, + DATE, + DATETIME, + DECIMAL, + Dimension as SemanticDimension, + Filter as SemanticFilter, + INTEGER, + Metric as SemanticMetric, + NoSort, + NUMBER, + OBJECT, + Query as SemanticQuery, + SemanticView, + Sort as SemanticSort, + SortDirectionEnum, + STRING, + Table as SemanticTable, + TIME, + Type as SemanticType, +) from superset.models.sql_lab import Query +from superset.sql.parse import Table from superset.utils import json from superset.utils.core import get_user_agent, QuerySource if TYPE_CHECKING: + from sqlalchemy.engine.base import Engine + from superset.models.core import Database # Regular expressions to catch custom errors @@ -77,16 +107,318 @@ class SnowflakeParametersType(TypedDict): warehouse: str +class SnowflakeSemanticLayer: + def __init__(self, engine: Engine) -> None: + self.engine 
= engine + + def execute( + self, + sql: str, + **kwargs: Any, + ) -> Iterator[dict[str, Any]]: + with self.engine.connect() as connection: + for row in connection.execute(text(sql), kwargs).mappings(): + yield dict(row) + + def get_semantic_views(self) -> set[SemanticView]: + sql = """ +SHOW SEMANTIC VIEWS + ->> SELECT "name" FROM $1; + """ + return {SemanticView(row["name"]) for row in self.execute(sql)} + + def get_type(self, snowflake_type: str | None) -> type[SemanticType]: + if snowflake_type is None: + return STRING + + type_map = { + STRING: {r"VARCHAR\(\d+\)$", "STRING$", "TEXT$", r"CHAR\(\d+\)$"}, + INTEGER: {r"NUMBER\(38,\s?0\)$", "INT$", "INTEGER$", "BIGINT$"}, + DECIMAL: {r"NUMBER\(10,\s?2\)$"}, + NUMBER: {r"NUMBER\(\d+,\s?\d+\)$", "FLOAT$", "DOUBLE$"}, + BOOLEAN: {"BOOLEAN$"}, + DATE: {"DATE$"}, + DATETIME: {"TIMESTAMP_TZ$", "TIMESTAMP__NTZ$"}, + TIME: {"TIME$"}, + OBJECT: {"OBJECT$"}, + BINARY: {r"BINARY\(\d+\)$", r"VARBINARY\(\d+\)$"}, + } + for semantic_type, patterns in type_map.items(): + if any( + re.match(pattern, snowflake_type, re.IGNORECASE) for pattern in patterns + ): + return semantic_type + + return STRING + + @classmethod + def quote_table(cls, table: Table, dialect: Dialect) -> str: + """ + Fully quote a table name, including the schema and catalog. 
+ """ + quoters = { + "catalog": dialect.identifier_preparer.quote_schema, + "schema": dialect.identifier_preparer.quote_schema, + "table": dialect.identifier_preparer.quote, + } + + return ".".join( + function(getattr(table, key)) + for key, function in quoters.items() + if getattr(table, key) + ) + + def get_metrics(self, semantic_view: SemanticView) -> set[SemanticMetric]: + quoted_semantic_view_name = self.quote_table( + Table(semantic_view.name), + self.engine.dialect, + ) + sql = f""" +DESC SEMANTIC VIEW {quoted_semantic_view_name} + ->> SELECT "object_name", "property", "property_value" + FROM $1 + WHERE + "object_kind" = 'METRIC' AND + "property" IN ('DATA_TYPE', 'TABLE'); + """ # noqa: S608 (semantic_view.name is quoted) + rows = self.execute(sql) + + metrics: set[SemanticMetric] = set() + for name, group in itertools.groupby(rows, key=lambda x: x["object_name"]): + attributes = defaultdict(set) + for row in group: + attributes[row["property"]].add(row["property_value"]) + + table = next(iter(attributes["TABLE"])) + metric_name = table + "." 
+ name + type_ = self.get_type(next(iter(attributes["DATA_TYPE"]))) + sql = self.engine.dialect.identifier_preparer.quote(metric_name) + tables = frozenset(attributes["TABLE"]) + join_columns = frozenset() + + metrics.add(SemanticMetric(metric_name, type_, sql, tables, join_columns)) + + return metrics + + def get_dimensions(self, semantic_view: SemanticView) -> set[SemanticDimension]: + quoted_semantic_view_name = self.quote_table( + Table(semantic_view.name), + self.engine.dialect, + ) + sql = f""" +DESC SEMANTIC VIEW {quoted_semantic_view_name} + ->> SELECT "object_name", "property", "property_value" + FROM $1 + WHERE + "object_kind" = 'DIMENSION' AND + "property" IN ('DATA_TYPE', 'TABLE'); + """ # noqa: S608 (semantic_view.name is quoted) + rows = self.execute(sql) + + dimensions: set[SemanticDimension] = set() + for name, group in itertools.groupby(rows, key=lambda x: x["object_name"]): + attributes = defaultdict(set) + for row in group: + attributes[row["property"]].add(row["property_value"]) + + table = next(iter(attributes["TABLE"])) + dimension_name = table + "." 
+ name + column = SemanticColumn(SemanticTable(table), name) + type_ = self.get_type(next(iter(attributes["DATA_TYPE"]))) + + dimensions.add(SemanticDimension(column, dimension_name, type_)) + + return dimensions + + def get_valid_metrics( + self, + semantic_view: SemanticView, + metrics: set[SemanticMetric], + dimensions: set[SemanticDimension], + ) -> set[SemanticMetric]: + # all metrics and dimensions are valid inside a given semantic view + return self.get_metrics(semantic_view) + + def get_valid_dimensions( + self, + semantic_view: SemanticView, + metrics: set[SemanticMetric], + dimensions: set[SemanticDimension], + ) -> set[SemanticDimension]: + # all metrics and dimensions are valid inside a given semantic view + return self.get_dimensions(semantic_view) + + def get_query( + self, + semantic_view: SemanticView, + metrics: set[SemanticMetric], + dimensions: set[SemanticDimension], + filters: set[SemanticFilter], + sort: SemanticSort = NoSort, + limit: int | None = None, + offset: int | None = None, + ) -> SemanticQuery: + ast = self.build_query( + semantic_view, + metrics, + dimensions, + filters, + sort, + limit, + offset, + ) + return SemanticQuery(sql=ast.sql(dialect="snowflake", pretty=True)) + + def build_query( + self, + semantic_view: SemanticView, + metrics: set[SemanticMetric], + dimensions: set[SemanticDimension], + filters: set[SemanticFilter], + sort: SemanticSort = NoSort, + limit: int | None = None, + offset: int | None = None, + ) -> exp.Select: + semantic_view = exp.SemanticView( + this=exp.Table(this=exp.Identifier(this=semantic_view.name, quoted=True)), + dimensions=[ + exp.Column( + this=exp.Identifier(this=dimension.column.name, quoted=True), + table=exp.Identifier( + this=dimension.column.relation.name, + quoted=True, + ), + ) + for dimension in dimensions + ], + metrics=[ + exp.Column( + this=exp.Identifier(this=column, quoted=True), + table=exp.Identifier(this=table, quoted=True), + ) + for table, column in ( + metric.name.split(".", 1) 
+ for metric in metrics + if "." in metric.name + ) + ], + ) + query = exp.Select( + expressions=[exp.Star()], + **{"from": exp.From(this=exp.Table(this=semantic_view))}, + ) + + if sort.items: + order = [ + exp.Ordered( + this=exp.Column(this=exp.Identifier(this=item.field.name)), + desc=item.direction == SortDirectionEnum.DESC, + nulls_first=item.nulls_first, + ) + for item in sort.items + ] + query.args["order"] = exp.Order(expressions=order) + + if offset: + query = query.offset(offset) + + if limit: + query = query.limit(limit) + + return query + + def get_query_from_standard_sql(self, sql: str) -> SemanticQuery: + """ + Convert the Explore query into a proper query. + + Explore will produce a pseudo-SQL query that references metrics and dimensions + as if they were columns in a table. This method replaces the table name with a + call to `SEMANTIC_VIEW`, and removes the `GROUP BY` clause, since all the + aggregations happen inside the `SEMANTIC_VIEW` call. + """ + ast = parse_one(sql, "snowflake") + table = ast.find(exp.Table) + if not table: + return SemanticQuery(sql=sql) + + semantic_views = self.get_semantic_views() + if table.name not in {semantic_view.name for semantic_view in semantic_views}: + return SemanticQuery(sql=sql) + + # collect all metric and dimensions + semantic_view = SemanticView(table.name) + all_metrics = self.get_metrics(semantic_view) + all_dimensions = self.get_dimensions(semantic_view) + + # collect metrics and dimensions used in the query + columns = {column.name for column in ast.find_all(exp.Column)} + metrics = [metric for metric in all_metrics if metric.name in columns] + dimensions = [ + dimension for dimension in all_dimensions if dimension.name in columns + ] + + # now replace table with a call to `SEMANTIC_VIEW` + udtf = exp.Table( + this=exp.SemanticView( + this=exp.Table( + this=exp.Identifier(this=semantic_view.name, quoted=True) + ), + metrics=[ + exp.Column( + this=exp.Identifier(this=column, quoted=True), + 
table=exp.Identifier(this=table, quoted=True), + ) + for table, column in ( + metric.name.split(".", 1) + for metric in metrics + if "." in metric.name + ) + ], + dimensions=[ + exp.Column( + this=exp.Identifier(this=column, quoted=True), + table=exp.Identifier(this=table, quoted=True), + ) + for table, column in ( + dimension.name.split(".", 1) + for dimension in dimensions + if "." in dimension.name + ) + ], + ), + alias=exp.TableAlias( + this=exp.Identifier(this="table_alias", quoted=False), + columns=[ + exp.Identifier(this=column.name, quoted=True) + for column in metrics + dimensions + ], + ), + ) + table.replace(udtf) + + # remove group by, since aggregations are done inside the `SEMANTIC_VIEW` call + del ast.args["group"] + + print("BETO") + print(ast.sql(dialect="snowflake", pretty=True)) + return SemanticQuery(sql=ast.sql(dialect="snowflake", pretty=True)) + + class SnowflakeEngineSpec(PostgresBaseEngineSpec): engine = "snowflake" engine_name = "Snowflake" force_column_alias_quotes = True max_column_name_length = 256 + # Snowflake doesn't support IS true/false syntax, use = true/false instead + use_equality_for_boolean_filters = True + parameters_schema = SnowflakeParametersSchema() default_driver = "snowflake" sqlalchemy_uri_placeholder = "snowflake://" + semantic_layer = SnowflakeSemanticLayer + supports_dynamic_schema = True supports_catalog = supports_dynamic_catalog = supports_cross_catalog_queries = True diff --git a/superset/extensions/semantic_layer.py b/superset/extensions/semantic_layer.py new file mode 100644 index 0000000000..c597a70619 --- /dev/null +++ b/superset/extensions/semantic_layer.py @@ -0,0 +1,340 @@ +import enum +from dataclasses import dataclass +from datetime import timedelta +from functools import total_ordering +from typing import Protocol, runtime_checkable + +from sqlalchemy import types as sqltypes +from sqlalchemy.engine.base import Engine + + +class Type: + """ + Base class for types. 
+ """ + + +class INTEGER(Type): + """ + Represents an integer type. + """ + + +class NUMBER(Type): + """ + Represents a number type. + """ + + +class DECIMAL(Type): + """ + Represents a decimal type. + """ + + +class STRING(Type): + """ + Represents a string type. + """ + + +class BOOLEAN(Type): + """ + Represents a boolean type. + """ + + +class DATE(Type): + """ + Represents a date type. + """ + + +class TIME(Type): + """ + Represents a time type. + """ + + +class DATETIME(DATE, TIME): + """ + Represents a datetime type. + """ + + +class INTERVAL(Type): + """ + Represents an interval type. + """ + + +class OBJECT(Type): + """ + Represents an object type. + """ + + +class BINARY(Type): + """ + Represents a binary type. + """ + + +@dataclass(frozen=True) +class SemanticView: + name: str + description: str | None = None + + +@dataclass(frozen=True) +class Relation: + name: str + schema: str | None = None + catalog: str | None = None + + +@dataclass(frozen=True) +class Table: + name: str + schema: str | None = None + catalog: str | None = None + + +@dataclass(frozen=True) +class View: + name: str + sql: str + schema: str | None = None + catalog: str | None = None + + +@dataclass(frozen=True) +class Virtual: + name: str + + +@dataclass(frozen=True) +class Metric: + name: str + type: type[Type] + sql: str + tables: frozenset[Table] + join_columns: frozenset[str] + + +@total_ordering +class ComparableEnum(enum.Enum): + def __eq__(self, other: object) -> bool: + if isinstance(other, enum.Enum): + return self.value == other.value + return NotImplemented + + def __lt__(self, other: object) -> bool: + if isinstance(other, enum.Enum): + return self.value < other.value + return NotImplemented + + def __hash__(self): + return hash((self.__class__, self.name)) + + +class TimeGrain(ComparableEnum): + second = timedelta(seconds=1) + minute = timedelta(minutes=1) + hour = timedelta(hours=1) + + +class DateGrain(ComparableEnum): + day = timedelta(days=1) + week = timedelta(weeks=1) 
+ month = timedelta(days=30) + quarter = timedelta(days=90) + year = timedelta(days=365) + + +@dataclass(frozen=True) +class Column: + relation: Table | View | Virtual + name: str + + +@dataclass(frozen=True) +class Dimension: + column: Column + name: str + type: type[Type] + grain: TimeGrain | DateGrain | None = None + + def __repr__(self) -> str: + metadata = f"[{self.grain.name}]" if self.grain else "" + return f"{self.type.__name__} {self.name} {metadata}".strip() + + +class FilterTypeEnum(enum.Enum): + WHERE = enum.auto() + HAVING = enum.auto() + + +@dataclass(frozen=True) +class Filter: + type: FilterTypeEnum + expression: str + + +class SortDirectionEnum(enum.Enum): + ASC = enum.auto() + DESC = enum.auto() + + +@dataclass(frozen=True) +class SortField: + field: Metric | Dimension + direction: SortDirectionEnum + nulls_first: bool = True + + +@dataclass(frozen=True) +class Sort: + items: list[SortField] + + +@dataclass(frozen=True) +class Query: + sql: str + + +NoSort = Sort(items=[]) + + +@runtime_checkable +class SemanticLayer(Protocol): + """ + A generic protocol for semantic layers. + """ + + def __init__(self, engine: Engine) -> None: ... + + def get_semantic_views(self) -> set[SemanticView]: + """ + Return a set of the semantic views. + + A semantic view is an organizational group of metrics and dimensions. It's not a + logical grouping, since metrics and dimensions from a given semantic view might + not be compatible. An implementation might expose a single semantic view for + exploration of available metrics and dimensions, and smaller curated semantic + views that are domain specific. + """ + ... + + def get_metrics(self, semantic_view: SemanticView) -> set[Metric]: + """ + Return a set of metrics from a given semantic view. + """ + ... + + def get_dimensions(self, semantic_view: SemanticView) -> set[Dimension]: + """ + Return a set of dimensions from a given semantic view. + """ + ... 
+ + def get_valid_metrics( + self, + semantic_view: SemanticView, + metrics: set[Metric], + dimensions: set[Dimension], + ) -> set[Metric]: + """ + Return compatible metrics for the given metrics and dimensions. + + For metrics to be valid they must be compatible with all the provided + dimensions. + """ + ... + + def get_valid_dimensions( + self, + semantic_view: SemanticView, + metrics: set[Metric], + dimensions: set[Dimension], + ) -> set[Dimension]: + """ + Return compatible dimensions for the given metrics. + + For dimensions to be valid they must be compatible with all the provided + metrics. + """ + ... + + def get_query( + self, + semantic_view: SemanticView, + metrics: set[Metric], + dimensions: set[Dimension], + # populations: set[Population], + filters: set[Filter], + sort: Sort = NoSort, + limit: int | None = None, + offset: int | None = None, + ) -> Query: + """ + Build a SQL query from the given metrics, dimensions, filters, and sort order. + """ + ... + + def get_query_from_standard_sql( + self, + semantic_view: SemanticView, + sql: str, + ) -> Query: + """ + Build a SQL query from a pseudo-query referencing metrics and dimensions. + + For example, given `metric1` having the expression `COUNT(*)`, this query: + + SELECT metric1, dim1 + FROM semantic_layer + GROUP BY dim1 + + Becomes: + + SELECT metric1, dim1 + FROM ( + SELECT COUNT(*) AS metric1, dim1 + FROM fact_table + JOIN dim_table + ON fact_table.dim_id = dim_table.id + GROUP BY dim1 + ) AS semantic_view + + """ + ... 
+ + +TYPE_MAPPING: dict[Type, type[sqltypes.TypeEngine]] = { + # Numeric types + INTEGER: sqltypes.Integer, + NUMBER: sqltypes.Numeric, + DECIMAL: sqltypes.DECIMAL, + # String types + STRING: sqltypes.String, + # Boolean type + BOOLEAN: sqltypes.Boolean, + # Date/time types + DATE: sqltypes.Date, + TIME: sqltypes.Time, + DATETIME: sqltypes.DateTime, + INTERVAL: sqltypes.Interval, + # Complex types + OBJECT: sqltypes.JSON, + BINARY: sqltypes.LargeBinary, +} + + +def get_sqla_type_from_dimension_type( + dimension_type: Type, +) -> sqltypes.TypeEngine: + """ + Get the SQLAlchemy type corresponding to the given dimension type. + """ + return TYPE_MAPPING.get(dimension_type, sqltypes.String)() diff --git a/superset/models/core.py b/superset/models/core.py index 787c380bea..7e9f085aec 100755 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -126,7 +126,9 @@ class ConfigurationMethod(StrEnum): DYNAMIC_FORM = "dynamic_form" -class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable=too-many-public-methods +class Database( + Model, AuditMixinNullable, ImportExportMixin +): # pylint: disable=too-many-public-methods """An ORM object that stores Database related information""" __tablename__ = "dbs" @@ -400,9 +402,7 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable return ( username if (username := get_username()) - else object_url.username - if self.impersonate_user - else None + else object_url.username if self.impersonate_user else None ) @contextmanager @@ -987,7 +987,10 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable schema=table.schema, ) as inspector: return self.db_engine_spec.get_columns( - inspector, table, self.schema_options + self, + inspector, + table, + self.schema_options, ) def get_metrics( @@ -1076,9 +1079,11 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable return self.perm def has_table(self, table: Table) -> bool: - with 
self.get_sqla_engine(catalog=table.catalog, schema=table.schema) as engine: - # do not pass "" as an empty schema; force null - return engine.has_table(table.table, table.schema or None) + with self.get_inspector( + catalog=table.catalog, + schema=table.schema, + ) as inspector: + return self.db_engine_spec.has_table(self, inspector, table) def has_view(self, table: Table) -> bool: with self.get_sqla_engine(catalog=table.catalog, schema=table.schema) as engine:
