This is an automated email from the ASF dual-hosted git repository.

beto pushed a commit to branch dbt-metricflow
in repository https://gitbox.apache.org/repos/asf/superset.git
commit 214f0fa5a5d0f1f9cadfa32d3020254d7eb5efdb
Author: Beto Dealmeida <robe...@dealmeida.net>
AuthorDate: Sat Apr 27 16:16:52 2024 -0400

    WIP
---
 setup.py                               |   8 +-
 superset/db_engine_specs/base.py       |  51 ++++++-
 superset/db_engine_specs/metricflow.py | 165 +++++++++++++++++++++
 superset/extensions/metadb.py          |   2 +-
 superset/extensions/metricflow.py      | 253 +++++++++++++++++++++++++++++++++
 superset/sql/parse.py                  |   1 +
 6 files changed, 473 insertions(+), 7 deletions(-)

diff --git a/setup.py b/setup.py
index 9b834cc241..c3822ccf1d 100644
--- a/setup.py
+++ b/setup.py
@@ -30,7 +30,9 @@ with open(PACKAGE_JSON) as package_file:
 
 def get_git_sha() -> str:
     try:
-        output = subprocess.check_output(["git", "rev-parse", "HEAD"])  # noqa: S603, S607
+        output = subprocess.check_output(
+            ["git", "rev-parse", "HEAD"]
+        )  # noqa: S603, S607
         return output.decode().strip()
     except Exception:  # pylint: disable=broad-except
         return ""
@@ -65,9 +67,11 @@ setup(
             "postgres.psycopg2 = sqlalchemy.dialects.postgresql:dialect",
             "postgres = sqlalchemy.dialects.postgresql:dialect",
             "superset = superset.extensions.metadb:SupersetAPSWDialect",
+            "metricflow = superset.extensions.metricflow:MetricFlowDialect",
         ],
         "shillelagh.adapter": [
-            "superset=superset.extensions.metadb:SupersetShillelaghAdapter"
+            "superset = superset.extensions.metadb:SupersetShillelaghAdapter",
+            "supersetmetricflowapi = superset.extensions.metricflow:SupersetMetricFlowAPI",
         ],
     },
     download_url="https://www.apache.org/dist/superset/" + version_string,
diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py
index e3ffdd335d..4014f1b8dc 100644
--- a/superset/db_engine_specs/base.py
+++ b/superset/db_engine_specs/base.py
@@ -88,6 +88,7 @@ if TYPE_CHECKING:
     from superset.connectors.sqla.models import TableColumn
     from superset.databases.schemas import TableMetadataResponse
     from superset.models.core import Database
+    from superset.models.helpers import ExploreMixin
    from superset.models.sql_lab import Query

@@ -143,7 +144,9 @@ builtin_time_grains: dict[str | None, str] = {
 }
 
 
-class TimestampExpression(ColumnClause):  # pylint: disable=abstract-method, too-many-ancestors
+class TimestampExpression(
+    ColumnClause
+):  # pylint: disable=abstract-method, too-many-ancestors
     def __init__(self, expr: str, col: ColumnClause, **kwargs: Any) -> None:
         """Sqlalchemy class that can be used to render native column elements
         respecting engine-specific quoting rules as part of a string-based expression.
@@ -186,6 +189,15 @@ class MetricType(TypedDict, total=False):
     extra: str | None
 
 
+class ValidColumnsType(TypedDict):
+    """
+    Type for valid columns returned by `get_valid_columns`.
+    """
+
+    columns: set[str]
+    metrics: set[str]
+
+
 class BaseEngineSpec:  # pylint: disable=too-many-public-methods
     """Abstract class for database engine specific configurations
 
@@ -384,9 +396,9 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
     max_column_name_length: int | None = None
     try_remove_schema_from_table_name = True  # pylint: disable=invalid-name
     run_multiple_statements_as_one = False
-    custom_errors: dict[
-        Pattern[str], tuple[str, SupersetErrorType, dict[str, Any]]
-    ] = {}
+    custom_errors: dict[Pattern[str], tuple[str, SupersetErrorType, dict[str, Any]]] = (
+        {}
+    )
 
     # List of JSON path to fields in `encrypted_extra` that should be masked when the
     # database is edited. By default everything is masked.
@@ -436,6 +448,11 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
     # the `cancel_query` value in the `extra` field of the `query` object
     has_query_id_before_execute = True
 
+    # This attribute is used for semantic layers, where only certain combinations of
+    # metrics and dimensions are valid for a given datasource. For traditional
+    # databases this should be set to ``False``.
+    supports_dynamic_columns = False
+
     @classmethod
     def get_rls_method(cls) -> RLSMethod:
         """
@@ -1501,6 +1518,31 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
             }
         ]
 
+    @classmethod
+    def get_valid_columns(
+        cls,
+        database: Database,
+        datasource: ExploreMixin,
+        columns: set[str],
+        metrics: set[str],
+    ) -> ValidColumnsType:
+        """
+        Given a selection of columns/metrics from a datasource, return those still valid.
+
+        This method is used for semantic layers, where tables can have columns and
+        metrics that cannot be computed together. When the user selects a given metric,
+        this allows the UI to filter the remaining metrics and dimensions so that only
+        valid combinations are possible.
+
+        The method should only be called when ``supports_dynamic_columns`` is set to
+        ``True``. The default implementation in the base class ignores the selected
+        columns and metrics and simply returns everything.
+        """
+        return {
+            "columns": {column.column_name for column in datasource.columns},
+            "metrics": {metric.metric_name for metric in datasource.metrics},
+        }
+
     @classmethod
     def where_latest_partition(  # pylint: disable=unused-argument
         cls,
@@ -2148,6 +2190,7 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
             "supports_file_upload": cls.supports_file_upload,
             "disable_ssh_tunneling": cls.disable_ssh_tunneling,
             "supports_dynamic_catalog": cls.supports_dynamic_catalog,
+            "supports_dynamic_columns": cls.supports_dynamic_columns,
             "supports_oauth2": cls.supports_oauth2,
         }
 
diff --git a/superset/db_engine_specs/metricflow.py b/superset/db_engine_specs/metricflow.py
new file mode 100644
index 0000000000..77ea1552bf
--- /dev/null
+++ b/superset/db_engine_specs/metricflow.py
@@ -0,0 +1,165 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+An interface to dbt's semantic layer, Metric Flow.
+"""
+
+from __future__ import annotations
+
+from typing import Any, TYPE_CHECKING, TypedDict
+
+from shillelagh.backends.apsw.dialects.base import get_adapter_for_table_name
+
+from superset.constants import TimeGrain
+from superset.db_engine_specs.base import ValidColumnsType
+from superset.db_engine_specs.shillelagh import ShillelaghEngineSpec
+from superset.extensions.metricflow import TABLE_NAME
+from superset.models.helpers import ExploreMixin
+from superset.superset_typing import ResultSetColumnType
+
+if TYPE_CHECKING:
+    from sqlalchemy.engine.reflection import Inspector
+
+    from superset.models.core import Database
+    from superset.sql_parse import Table
+
+
+SELECT_STAR_MESSAGE = (
+    'The dbt semantic layer does not support data preview, since the "metrics" table '
+    "is a virtual table that is not materialized. An administrator should configure "
+    'the database in Apache Superset so that the "Disable SQL Lab data preview '
+    'queries" option under "Advanced" → "SQL Lab" is enabled.'
+)
+
+
+class MetricType(TypedDict, total=False):
+    """
+    Type for metrics returned by `get_metrics`.
+    """
+
+    metric_name: str
+    expression: str
+    verbose_name: str | None
+    metric_type: str | None
+    description: str | None
+    d3format: str | None
+    warning_text: str | None
+    extra: str | None
+
+
+class DbtMetricFlowEngineSpec(ShillelaghEngineSpec):
+    """
+    Engine for the dbt semantic layer.
+    """
+
+    engine = "metricflow"
+    engine_name = "dbt Metric Flow"
+    sqlalchemy_uri_placeholder = (
+        "metricflow://[ab123.us1.dbt.com]/<environment_id>"
+        "?service_token=<service_token>"
+    )
+
+    supports_dynamic_columns = True
+
+    _time_grain_expressions = {
+        TimeGrain.DAY: "{col}__day",
+        TimeGrain.WEEK: "{col}__week",
+        TimeGrain.MONTH: "{col}__month",
+        TimeGrain.QUARTER: "{col}__quarter",
+        TimeGrain.YEAR: "{col}__year",
+    }
+
+    @classmethod
+    def select_star(cls, *args: Any, **kwargs: Any) -> str:
+        """
+        Return a warning message in place of a ``SELECT *`` query.
+        """
+        message = SELECT_STAR_MESSAGE.replace("'", "''")
+        return f"SELECT '{message}' AS warning"
+
+    @classmethod
+    def get_columns(
+        cls,
+        inspector: Inspector,
+        table: Table,
+        options: dict[str, Any] | None = None,
+    ) -> list[ResultSetColumnType]:
+        """
+        Get columns.
+
+        This method enriches the columns returned by the SQLAlchemy dialect with
+        the dimension descriptions.
+        """
+        connection = inspector.engine.connect()
+        adapter = get_adapter_for_table_name(connection, table.table)
+
+        return [
+            {
+                "name": column["name"],
+                "column_name": column["name"],
+                "type": column["type"],
+                "nullable": column["nullable"],
+                "default": column["default"],
+                "comment": adapter.dimensions.get(column["name"], ""),
+            }
+            for column in inspector.get_columns(table.table, table.schema)
+        ]
+
+    @classmethod
+    def get_metrics(
+        cls,
+        database: Database,
+        inspector: Inspector,
+        table: Table,
+    ) -> list[MetricType]:
+        """
+        Get all metrics.
+        """
+        connection = inspector.engine.connect()
+        adapter = get_adapter_for_table_name(connection, table.table)
+
+        return [
+            {
+                "metric_name": metric,
+                "expression": metric,
+                "description": description,
+            }
+            for metric, description in adapter.metrics.items()
+        ]
+
+    @classmethod
+    def get_valid_columns(
+        cls,
+        database: Database,
+        datasource: ExploreMixin,
+        columns: set[str],
+        metrics: set[str],
+    ) -> ValidColumnsType:
+        """
+        Get valid columns.
+
+        Given a datasource and the sets of selected metrics and dimensions, return
+        the sets of metrics and dimensions that can still be selected.
+        """
+        with database.get_sqla_engine() as engine:
+            connection = engine.connect()
+            adapter = get_adapter_for_table_name(connection, TABLE_NAME)
+
+        return {
+            "columns": adapter._get_dimensions_for_metrics(metrics),
+            "metrics": adapter._get_metrics_for_dimensions(columns),
+        }
diff --git a/superset/extensions/metadb.py b/superset/extensions/metadb.py
index 375b4b87ce..55bbf449e9 100644
--- a/superset/extensions/metadb.py
+++ b/superset/extensions/metadb.py
@@ -112,7 +112,7 @@ class SupersetAPSWDialect(APSWDialect):
                 "superset": {
                     "prefix": None,
                     "allowed_dbs": self.allowed_dbs,
-                }
+                },
             },
             "safe": True,
             "isolation_level": self.isolation_level,
diff --git a/superset/extensions/metricflow.py b/superset/extensions/metricflow.py
new file mode 100644
index 0000000000..451feae503
--- /dev/null
+++ b/superset/extensions/metricflow.py
@@ -0,0 +1,253 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""
+A SQLAlchemy dialect for dbt Metric Flow.
+"""
+
+from __future__ import annotations
+
+from datetime import timedelta
+from typing import Any
+
+import sqlalchemy.types
+from shillelagh.adapters.api.dbt_metricflow import DbtMetricFlowAPI
+from shillelagh.backends.apsw.dialects.base import (
+    APSWDialect,
+    get_adapter_for_table_name,
+)
+from shillelagh.fields import Field
+from sqlalchemy.engine.base import Connection
+from sqlalchemy.engine.url import URL
+from sqlalchemy.sql.visitors import VisitableType
+
+from superset.extensions import cache_manager
+from superset.utils.cache import memoized_func
+
+TABLE_NAME = "metrics"
+
+
+def get_sqla_type(field: Field) -> VisitableType:
+    """
+    Convert from Shillelagh to SQLAlchemy types.
+    """
+    type_map = {
+        "BOOLEAN": sqlalchemy.types.BOOLEAN,
+        "INTEGER": sqlalchemy.types.INT,
+        "DECIMAL": sqlalchemy.types.DECIMAL,
+        "TIMESTAMP": sqlalchemy.types.TIMESTAMP,
+        "DATE": sqlalchemy.types.DATE,
+        "TIME": sqlalchemy.types.TIME,
+        "TEXT": sqlalchemy.types.TEXT,
+    }
+
+    return type_map.get(field.type, sqlalchemy.types.TEXT)
+
+
+class SupersetMetricFlowAPI(DbtMetricFlowAPI):
+    """
+    Custom adapter for the dbt Metric Flow API.
+
+    In the original adapter, the SQL queries a base dbt API URL, e.g.:
+
+        SELECT * FROM "https://semantic-layer.cloud.getdbt.com/";
+        SELECT * FROM "https://ab123.us1.dbt.com/";  -- custom user URL
+
+    For this adapter, we want a leaner URI, mimicking a table:
+
+        SELECT * FROM metrics;
+
+    In order to do this, we override the ``supports`` method to only accept
+    ``TABLE_NAME`` instead of the URL; the table name is then passed to the adapter
+    when it is instantiated.
+
+    One problem with this change is that the adapter needs the base URL in order
+    to determine the GraphQL endpoint.
+    To solve this, we pass the original URL via a new argument ``url``, and
+    override the ``_get_endpoint`` method to use it instead of the table name.
+    """
+
+    @staticmethod
+    def supports(uri: str, fast: bool = True, **kwargs: Any) -> bool:
+        return uri == TABLE_NAME
+
+    def __init__(
+        self,
+        table: str,
+        service_token: str,
+        environment_id: int,
+        url: str,
+    ) -> None:
+        self.url = url
+        super().__init__(table, service_token, environment_id)
+
+    def _get_endpoint(self, url: str) -> str:
+        """
+        Compute the GraphQL endpoint.
+
+        Instead of using ``url`` (which points to ``TABLE_NAME`` in this adapter),
+        we should call the method using the actual dbt API base URL.
+        """
+        return super()._get_endpoint(self.url)
+
+    def _build_column_from_dimension(self, name: str) -> Field:
+        """
+        Build a Shillelagh column from a dbt dimension.
+
+        This method is terribly slow, since it needs to do a full data request for
+        each dimension in order to determine its type. To improve UX we cache the
+        results for one day.
+        """
+        return self._cached_build_column_from_dimension(
+            name,
+            cache_timeout=int(timedelta(days=1).total_seconds()),
+        )
+
+    @memoized_func(key="metricflow:dimension:{name}", cache=cache_manager.data_cache)
+    def _cached_build_column_from_dimension(
+        self,
+        name: str,
+        *args: Any,
+        **kwargs: Any,
+    ) -> Field:
+        """
+        Cached version of ``_build_column_from_dimension``.
+        """
+        return super()._build_column_from_dimension(name)
+
+
+class MetricFlowDialect(APSWDialect):
+    """
+    A dbt Metric Flow dialect.
+
+    The URL should look like:
+
+        metricflow:///<environment_id>?service_token=<service_token>
+
+    Or, when using a custom URL:
+
+        metricflow://ab123.us1.dbt.com/<environment_id>?service_token=<service_token>
+
+    """
+
+    name = "metricflow"
+
+    supports_statement_cache = True
+
+    def create_connect_args(self, url: URL) -> tuple[tuple[()], dict[str, Any]]:
+        baseurl = (
+            f"https://{url.host}/"
+            if url.host
+            else "https://semantic-layer.cloud.getdbt.com/"
+        )
+
+        return (
+            (),
+            {
+                "path": ":memory:",
+                "adapters": ["supersetmetricflowapi"],
+                "adapter_kwargs": {
+                    "supersetmetricflowapi": {
+                        "service_token": url.query["service_token"],
+                        "environment_id": int(url.database),
+                        "url": baseurl,
+                    },
+                },
+                "safe": True,
+                "isolation_level": self.isolation_level,
+            },
+        )
+
+    def get_table_names(
+        self,
+        connection: Connection,
+        schema: str | None = None,
+        sqlite_include_internal: bool = False,
+        **kwargs: Any,
+    ) -> list[str]:
+        return [TABLE_NAME]
+
+    def has_table(
+        self,
+        connection: Connection,
+        table_name: str,
+        schema: str | None = None,
+        **kwargs: Any,
+    ) -> bool:
+        return table_name == TABLE_NAME
+
+    def get_columns(
+        self,
+        connection: Connection,
+        table_name: str,
+        schema: str | None = None,
+        **kwargs: Any,
+    ) -> list[dict[str, Any]]:
+        adapter = get_adapter_for_table_name(connection, table_name)
+
+        columns = {
+            (
+                adapter.grains[dimension][0]
+                if dimension in adapter.grains
+                else dimension
+            ): adapter.columns[dimension]
+            for dimension in adapter.dimensions
+        }
+
+        return [
+            {
+                "name": name,
+                "type": get_sqla_type(field),
+                "nullable": True,
+                "default": None,
+            }
+            for name, field in columns.items()
+        ]
+
+    def get_schema_names(
+        self,
+        connection: Connection,
+        **kwargs: Any,
+    ) -> list[str]:
+        return ["main"]
+
+    def get_pk_constraint(
+        self,
+        connection: Connection,
+        table_name: str,
+        schema: str | None = None,
+        **kwargs: Any,
+    ) -> dict[str, Any]:
+        return {"constrained_columns": [], "name": None}
+
+    def get_foreign_keys(
+        self,
+        connection: Connection,
+        table_name: str,
+        schema: str | None = None,
+        **kwargs: Any,
+    ) -> list[dict[str, Any]]:
+        return []
+
+    get_check_constraints = get_foreign_keys
+    get_indexes = get_foreign_keys
+    get_unique_constraints = get_foreign_keys
+
+    def get_table_comment(
+        self,
+        connection: Connection,
+        table_name: str,
+        schema: str | None = None,
+        **kwargs: Any,
+    ) -> dict[str, Any]:
+        return {
+            "text": "A virtual table that gives access to all dbt metrics & dimensions."
+        }
diff --git a/superset/sql/parse.py b/superset/sql/parse.py
index 54d1006b93..fce01a053e 100644
--- a/superset/sql/parse.py
+++ b/superset/sql/parse.py
@@ -85,6 +85,7 @@ SQLGLOT_DIALECTS = {
     # "kustosql": ???
     # "kylin": ???
     "mariadb": Dialects.MYSQL,
+    "metricflow": Dialects.SQLITE,
     "motherduck": Dialects.DUCKDB,
     "mssql": Dialects.TSQL,
     "mysql": Dialects.MYSQL,
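A minimal sketch (not part of the commit) of how the default `get_valid_columns`
hook behaves. The stand-in datasource built from `SimpleNamespace` is
hypothetical; a real caller would pass a Superset datasource (an `ExploreMixin`):

    from types import SimpleNamespace

    from superset.db_engine_specs.base import BaseEngineSpec

    # Hypothetical datasource with one dimension and one metric.
    datasource = SimpleNamespace(
        columns=[SimpleNamespace(column_name="customer__region")],
        metrics=[SimpleNamespace(metric_name="revenue")],
    )

    # The default implementation ignores the current selection and returns
    # everything, so traditional databases behave exactly as before.
    valid = BaseEngineSpec.get_valid_columns(
        database=None,  # unused by the default implementation
        datasource=datasource,
        columns=set(),
        metrics=set(),
    )
    assert valid == {"columns": {"customer__region"}, "metrics": {"revenue"}}

Semantic-layer specs such as `DbtMetricFlowEngineSpec` override this to return
only the combinations the layer can actually compute, which is what the new
`supports_dynamic_columns` flag advertises to the frontend.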
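The `_time_grain_expressions` map relies on MetricFlow's grain-suffix
convention; the `{col}` placeholder is filled with the dimension name:

    # A daily grain on the `metric_time` dimension renders as a suffixed column.
    assert "{col}__day".format(col="metric_time") == "metric_time__day"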
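A quick illustration of the `supports` override: the adapter now answers only
to the virtual table name rather than to a dbt API URL:

    from superset.extensions.metricflow import SupersetMetricFlowAPI

    assert SupersetMetricFlowAPI.supports("metrics")
    assert not SupersetMetricFlowAPI.supports("https://ab123.us1.dbt.com/")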
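And a sketch of how `MetricFlowDialect.create_connect_args` maps the SQLAlchemy
URL onto the Shillelagh adapter kwargs. The host, environment ID, and token
below are made up, and this assumes the dialect can be instantiated with no
arguments:

    from sqlalchemy.engine.url import make_url

    from superset.extensions.metricflow import MetricFlowDialect

    url = make_url("metricflow://ab123.us1.dbt.com/12345?service_token=abc")
    _, kwargs = MetricFlowDialect().create_connect_args(url)

    adapter_kwargs = kwargs["adapter_kwargs"]["supersetmetricflowapi"]
    assert adapter_kwargs["environment_id"] == 12345
    assert adapter_kwargs["url"] == "https://ab123.us1.dbt.com/"

    # Without a host, the default dbt Cloud endpoint is used.
    url = make_url("metricflow:///12345?service_token=abc")
    _, kwargs = MetricFlowDialect().create_connect_args(url)
    assert (
        kwargs["adapter_kwargs"]["supersetmetricflowapi"]["url"]
        == "https://semantic-layer.cloud.getdbt.com/"
    )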
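Finally, `get_sqla_type` falls back to TEXT for any Shillelagh type it does not
know about; a small check, assuming the usual `type` attributes on the
`shillelagh.fields` classes:

    import sqlalchemy.types
    from shillelagh.fields import Integer, String

    from superset.extensions.metricflow import get_sqla_type

    assert get_sqla_type(Integer()) is sqlalchemy.types.INT
    assert get_sqla_type(String()) is sqlalchemy.types.TEXT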