This is an automated email from the ASF dual-hosted git repository. beto pushed a commit to branch dbt-sl in repository https://gitbox.apache.org/repos/asf/superset.git
commit 5c51fe72922fa1802fd24d4747ff62415fe353a9 Author: Beto Dealmeida <[email protected]> AuthorDate: Sat Apr 27 16:16:52 2024 -0400 WIP --- setup.py | 4 +- superset/db_engine_specs/base.py | 45 +++++++- superset/db_engine_specs/dbt.py | 129 ++++++++++++++++++++++ superset/extensions/dbt.py | 231 +++++++++++++++++++++++++++++++++++++++ superset/extensions/metadb.py | 2 +- 5 files changed, 408 insertions(+), 3 deletions(-) diff --git a/setup.py b/setup.py index 00b8d22e2a..1e71d6e2f5 100644 --- a/setup.py +++ b/setup.py @@ -64,9 +64,11 @@ setup( "postgres.psycopg2 = sqlalchemy.dialects.postgresql:dialect", "postgres = sqlalchemy.dialects.postgresql:dialect", "superset = superset.extensions.metadb:SupersetAPSWDialect", + "dbt = superset.extensions.dbt:DbtMetricFlowDialect", ], "shillelagh.adapter": [ - "superset=superset.extensions.metadb:SupersetShillelaghAdapter" + "superset = superset.extensions.metadb:SupersetShillelaghAdapter", + "presetdbtmetricflowapi = superset.extensions.dbt:PresetDbtMetricFlowAPI", ], }, download_url="https://www.apache.org/dist/superset/" + version_string, diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index 3cc1315129..316bde3b63 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -82,6 +82,7 @@ if TYPE_CHECKING: from superset.connectors.sqla.models import TableColumn from superset.databases.schemas import TableMetadataResponse from superset.models.core import Database + from superset.models.helpers import ExploreMixin from superset.models.sql_lab import Query @@ -131,7 +132,9 @@ builtin_time_grains: dict[str | None, str] = { } -class TimestampExpression(ColumnClause): # pylint: disable=abstract-method, too-many-ancestors +class TimestampExpression( + ColumnClause +): # pylint: disable=abstract-method, too-many-ancestors def __init__(self, expr: str, col: ColumnClause, **kwargs: Any) -> None: """Sqlalchemy class that can be used to render native column elements respecting engine-specific quoting rules as part of a string-based expression. @@ -182,6 +185,15 @@ class MetricType(TypedDict, total=False): extra: str | None +class ValidColumnsType(TypedDict): + """ + Type for valid columns returned by `get_valid_columns`. + """ + + columns: set[str] + metrics: set[str] + + class BaseEngineSpec: # pylint: disable=too-many-public-methods """Abstract class for database engine specific configurations @@ -419,6 +431,11 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods # Driver-specific exception that should be mapped to OAuth2RedirectError oauth2_exception = OAuth2RedirectError + # This attribute is used for semantic layers, where only certain combinations of + # metrics and dimensions are valid for given datasource. For traditional databases + # this should be set to false. + supports_dynamic_columns = False + @classmethod def is_oauth2_enabled(cls) -> bool: return ( @@ -1573,6 +1590,32 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods } ] + @classmethod + def get_valid_columns( + cls, + database: Database, + inspector: Inspector, + datasource: ExploreMixin, + columns: set[str], + metrics: set[str], + ) -> ValidColumnsType: + """ + Given a selection of columns/metrics from a datasource, return related columns. + + This is a method used for semantic layers, where tables can have columns and + metrics that cannot be computed together. When the user selects a given metric + it allows the UI to filter the remaining metrics and dimensions so that only + valid combinations are possible. + + The method should only be called when ``supports_dynamic_columns`` is set to + true. The default method in the base class ignores the selected columns and + metrics, and simply returns everything, for reference. + """ + return { + "columns": {column.column_name for column in datasource.columns}, + "metrics": {metric.metric_name for metric in datasource.metrics}, + } + @classmethod def where_latest_partition( # pylint: disable=unused-argument cls, diff --git a/superset/db_engine_specs/dbt.py b/superset/db_engine_specs/dbt.py new file mode 100644 index 0000000000..932978f63e --- /dev/null +++ b/superset/db_engine_specs/dbt.py @@ -0,0 +1,129 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +An interface to dbt's semantic layer. +""" + +from __future__ import annotations + +from typing import Any, TypedDict, TYPE_CHECKING + +from shillelagh.backends.apsw.dialects.base import get_adapter_for_table_name + +from superset.constants import TimeGrain +from superset.db_engine_specs.base import ValidColumnsType +from superset.db_engine_specs.shillelagh import ShillelaghEngineSpec +from superset.extensions.dbt import TABLE_NAME +from superset.models.helpers import ExploreMixin + +if TYPE_CHECKING: + from superset.models.core import Database + from superset.sql_parse import Table + from sqlalchemy.engine.reflection import Inspector + + +SELECT_STAR_MESSAGE = ( + 'The dbt semantic layer does not support data preview, since the "metrics" table is ' + "a virtual table that is not materialized. An administrator should configure the " + 'database in Apache Superset so that the "Disable SQL Lab data preview queries" ' + 'option under "Advanced" → "SQL Lab" is enabled.' +) + + +class MetricType(TypedDict, total=False): + """ + Type for metrics return by `get_metrics`. + """ + + metric_name: str + expression: str + verbose_name: str | None + metric_type: str | None + description: str | None + d3format: str | None + warning_text: str | None + extra: str | None + + +class DbtMetricFlowEngineSpec(ShillelaghEngineSpec): + """ + Engine for the the dbt semantic layer. + """ + + engine = "dbt" + engine_name = "dbt Semantic Layer" + sqlalchemy_uri_placeholder = "dbt:///<environment_id>?service_token=<service_token>" + + supports_dynamic_columns = True + + _time_grain_expressions = { + TimeGrain.DAY: "{col}__day", + TimeGrain.WEEK: "{col}__week", + TimeGrain.MONTH: "{col}__month", + TimeGrain.QUARTER: "{col}__quarter", + TimeGrain.YEAR: "{col}__year", + } + + @classmethod + def select_star(cls, *args: Any, **kwargs: Any) -> str: + """ + Return a ``SELECT *`` query. + """ + message = SELECT_STAR_MESSAGE.replace("'", "''") + return f"SELECT '{message}' AS warning" + + @classmethod + def get_metrics( + cls, + database: Database, + inspector: Inspector, + table: Table, + ) -> list[MetricType]: + """ + Get all metrics. + """ + connection = inspector.engine.connect() + adapter = get_adapter_for_table_name(connection, table.table) + + return [ + { + "metric_name": metric, + "expression": metric, + "description": description, + } + for metric, description in adapter.metrics.items() + ] + + @classmethod + def get_valid_columns( + cls, + database: Database, + inspector: Inspector, + datasource: ExploreMixin, + columns: set[str], + metrics: set[str], + ) -> ValidColumnsType: + """ + Get valid columns. + """ + connection = inspector.engine.connect() + adapter = get_adapter_for_table_name(connection, TABLE_NAME) + + return { + "metrics": adapter._get_metrics_for_dimensions(columns), + "dimensions": adapter._get_dimensions_for_metrics(metrics), + } diff --git a/superset/extensions/dbt.py b/superset/extensions/dbt.py new file mode 100644 index 0000000000..6ec31c0cd5 --- /dev/null +++ b/superset/extensions/dbt.py @@ -0,0 +1,231 @@ +from __future__ import annotations + +from datetime import timedelta +from typing import Any + +import sqlalchemy.types +from shillelagh.adapters.api.dbt_metricflow import DbtMetricFlowAPI +from shillelagh.backends.apsw.dialects.base import ( + APSWDialect, + get_adapter_for_table_name, +) +from shillelagh.fields import Field +from sqlalchemy.engine.base import Connection +from sqlalchemy.engine.url import URL +from sqlalchemy.sql.visitors import VisitableType + +from superset.extensions import cache_manager +from superset.utils.cache import memoized_func + + +TABLE_NAME = "metrics" + + +def get_sqla_type(field: Field) -> VisitableType: + """ + Convert from Shillelagh to SQLAlchemy types. + """ + type_map = { + "BOOLEAN": sqlalchemy.types.BOOLEAN, + "INTEGER": sqlalchemy.types.INT, + "DECIMAL": sqlalchemy.types.DECIMAL, + "TIMESTAMP": sqlalchemy.types.TIMESTAMP, + "DATE": sqlalchemy.types.DATE, + "TIME": sqlalchemy.types.TIME, + "TEXT": sqlalchemy.types.TEXT, + } + + return type_map.get(field.type, sqlalchemy.types.TEXT) + + +class PresetDbtMetricFlowAPI(DbtMetricFlowAPI): + """ + Custom API adapter for dbt Metric Flow API. + + In the original adapter, the SQL queries a base dbt API URL, eg: + + SELECT * FROM "https://semantic-layer.cloud.getdbt.com/"; + SELECT * FROM "https://ab123.us1.dbt.com/"; -- custom user URL + + For this adapter, we want a leaner URI, mimicking a table: + + SELECT * FROM metrics; + + In order to do this, we override the ``supports`` method to only accept + ``$TABLE_NAME`` instead of the URL, which is then passed to the adapter when it is + instantiated. + + One problem with this change is that the adapter needs the base URL in order to + determine the GraphQL endpoint. To solve this we pass the original URL via a new + argument ``url``, and override the ``_get_endpoint`` method to use it instead of the + table name. + """ + + @staticmethod + def supports(uri: str, fast: bool = True, **kwargs: Any) -> bool: + return uri == TABLE_NAME + + def __init__( + self, + table: str, + service_token: str, + environment_id: int, + url: str, + ) -> None: + self.url = url + super().__init__(table, service_token, environment_id) + + def _get_endpoint(self, url: str) -> str: + """ + Compute the GraphQL endpoint. + + Instead of using ``url`` (which points to ``TABLE_NAME`` in this adapter), we + should call the method using the actual dbt API base URL. + """ + return super()._get_endpoint(self.url) + + def _build_column_from_dimension(self, name: str) -> Field: + """ + Build a Shillelagh column from a dbt dimension. + + This method is terrible slow, since it needs to do a full data request for each + dimension in order to determine their types. To improve UX we cache the results + for one day. + """ + return self._cached_build_column_from_dimension( + name, + cache_timeout=int(timedelta(days=1).total_seconds()), + ) + + @memoized_func(key="dbt:dimension:{name}", cache=cache_manager.data_cache) + def _cached_build_column_from_dimension( + self, + name: str, + *args: Any, + **kwargs: Any, + ) -> Field: + """ + Cached version of ``_build_column_from_dimension``. + """ + return super()._build_column_from_dimension(name) + + +class DbtMetricFlowDialect(APSWDialect): + """ + A dbt Metric Flow dialect. + + URL should look like: + + dbt:///<environment_id>?service_token=<service_token> + + Or when using a custom URL: + + dbt://ab123.us1.dbt.com/<environment_id>?service_token=<service_token> + + """ + + name = "dbt" + + supports_statement_cache = True + + def create_connect_args(self, url: URL) -> tuple[tuple[()], dict[str, Any]]: + baseurl = ( + f"https://{url.host}/" + if url.host + else "https://semantic-layer.cloud.getdbt.com/" + ) + + return ( + (), + { + "path": ":memory:", + "adapters": ["presetdbtmetricflowapi"], + "adapter_kwargs": { + "presetdbtmetricflowapi": { + "service_token": url.query["service_token"], + "environment_id": int(url.database), + "url": baseurl, + }, + }, + "safe": True, + "isolation_level": self.isolation_level, + }, + ) + + def get_table_names( + self, + connection: Connection, + schema: str | None = None, + sqlite_include_internal: bool = False, + **kwargs: Any, + ) -> list[str]: + return [TABLE_NAME] + + def has_table( + self, + connection: Connection, + table_name: str, + schema: str | None = None, + **kwargs: Any, + ) -> bool: + return table_name == TABLE_NAME + + def get_columns( + self, + connection: Connection, + table_name: str, + schema: str | None = None, + **kwargs: Any, + ) -> list[tuple[str, str]]: + adapter = get_adapter_for_table_name(connection, table_name) + + columns = { + adapter.grains[dimension][0] + if dimension in adapter.grains + else dimension: adapter.columns[dimension] + for dimension in adapter.dimensions + } + + return [ + { + "name": name, + "type": get_sqla_type(field), + "nullable": True, + "default": None, + } + for name, field in columns.items() + ] + + def get_schema_names( + self, + connection: Connection, + **kwargs: Any, + ) -> list[str]: + return ["main"] + + def get_pk_constraint( + self, + connection: Connection, + table_name: str, + schema: str | None = None, + **kwargs: Any, + ) -> dict[str, Any]: + return {"constrained_columns": [], "name": None} + + def get_foreign_keys( + self, + connection: Connection, + table_name: str, + schema: str | None = None, + **kwargs: Any, + ) -> list[dict[str, Any]]: + return [] + + get_check_constraints = get_foreign_keys + get_indexes = get_foreign_keys + get_unique_constraints = get_foreign_keys + + def get_table_comment(self, connection, table_name, schema=None, **kwargs): + return { + "text": "A virtual table that gives access to all dbt metrics and dimensions." + } diff --git a/superset/extensions/metadb.py b/superset/extensions/metadb.py index 2d8444cc99..353b8279d2 100644 --- a/superset/extensions/metadb.py +++ b/superset/extensions/metadb.py @@ -111,7 +111,7 @@ class SupersetAPSWDialect(APSWDialect): "superset": { "prefix": None, "allowed_dbs": self.allowed_dbs, - } + }, }, "safe": True, "isolation_level": self.isolation_level,
