This is an automated email from the ASF dual-hosted git repository.

beto pushed a commit to branch dbt-metricflow
in repository https://gitbox.apache.org/repos/asf/superset.git

commit 214f0fa5a5d0f1f9cadfa32d3020254d7eb5efdb
Author: Beto Dealmeida <robe...@dealmeida.net>
AuthorDate: Sat Apr 27 16:16:52 2024 -0400

    WIP
---
 setup.py                               |   8 +-
 superset/db_engine_specs/base.py       |  51 ++++++-
 superset/db_engine_specs/metricflow.py | 165 +++++++++++++++++++++
 superset/extensions/metadb.py          |   2 +-
 superset/extensions/metricflow.py      | 253 +++++++++++++++++++++++++++++++++
 superset/sql/parse.py                  |   1 +
 6 files changed, 473 insertions(+), 7 deletions(-)

diff --git a/setup.py b/setup.py
index 9b834cc241..c3822ccf1d 100644
--- a/setup.py
+++ b/setup.py
@@ -30,7 +30,9 @@ with open(PACKAGE_JSON) as package_file:
 
 def get_git_sha() -> str:
     try:
-        output = subprocess.check_output(["git", "rev-parse", "HEAD"])  # noqa: S603, S607
+        output = subprocess.check_output(
+            ["git", "rev-parse", "HEAD"]
+        )  # noqa: S603, S607
         return output.decode().strip()
     except Exception:  # pylint: disable=broad-except
         return ""
@@ -65,9 +67,11 @@ setup(
             "postgres.psycopg2 = sqlalchemy.dialects.postgresql:dialect",
             "postgres = sqlalchemy.dialects.postgresql:dialect",
             "superset = superset.extensions.metadb:SupersetAPSWDialect",
+            "metricflow = superset.extensions.metricflow:MetricFlowDialect",
         ],
         "shillelagh.adapter": [
-            "superset=superset.extensions.metadb:SupersetShillelaghAdapter"
+            "superset = superset.extensions.metadb:SupersetShillelaghAdapter",
+            "supersetmetricflowapi = 
superset.extensions.metricflow:SupersetMetricFlowAPI",
         ],
     },
     download_url="https://www.apache.org/dist/superset/"; + version_string,
diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py
index e3ffdd335d..4014f1b8dc 100644
--- a/superset/db_engine_specs/base.py
+++ b/superset/db_engine_specs/base.py
@@ -88,6 +88,7 @@ if TYPE_CHECKING:
     from superset.connectors.sqla.models import TableColumn
     from superset.databases.schemas import TableMetadataResponse
     from superset.models.core import Database
+    from superset.models.helpers import ExploreMixin
     from superset.models.sql_lab import Query
 
 
@@ -143,7 +144,9 @@ builtin_time_grains: dict[str | None, str] = {
 }
 
 
-class TimestampExpression(ColumnClause):  # pylint: disable=abstract-method, too-many-ancestors
+class TimestampExpression(
+    ColumnClause
+):  # pylint: disable=abstract-method, too-many-ancestors
     def __init__(self, expr: str, col: ColumnClause, **kwargs: Any) -> None:
         """Sqlalchemy class that can be used to render native column elements 
respecting
         engine-specific quoting rules as part of a string-based expression.
@@ -186,6 +189,15 @@ class MetricType(TypedDict, total=False):
     extra: str | None
 
 
+class ValidColumnsType(TypedDict):
+    """
+    Type for valid columns returned by `get_valid_columns`.
+    """
+
+    columns: set[str]
+    metrics: set[str]
+
+
 class BaseEngineSpec:  # pylint: disable=too-many-public-methods
     """Abstract class for database engine specific configurations
 
@@ -384,9 +396,9 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
     max_column_name_length: int | None = None
     try_remove_schema_from_table_name = True  # pylint: disable=invalid-name
     run_multiple_statements_as_one = False
-    custom_errors: dict[
-        Pattern[str], tuple[str, SupersetErrorType, dict[str, Any]]
-    ] = {}
+    custom_errors: dict[Pattern[str], tuple[str, SupersetErrorType, dict[str, Any]]] = (
+        {}
+    )
 
     # List of JSON path to fields in `encrypted_extra` that should be masked when the
     # database is edited. By default everything is masked.
@@ -436,6 +448,11 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
     # the `cancel_query` value in the `extra` field of the `query` object
     has_query_id_before_execute = True
 
+    # This attribute is used for semantic layers, where only certain combinations of
+    # metrics and dimensions are valid for a given datasource. For traditional
+    # databases this should be set to False.
+    supports_dynamic_columns = False
+
     @classmethod
     def get_rls_method(cls) -> RLSMethod:
         """
@@ -1501,6 +1518,31 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
             }
         ]
 
+    @classmethod
+    def get_valid_columns(
+        cls,
+        database: Database,
+        datasource: ExploreMixin,
+        columns: set[str],
+        metrics: set[str],
+    ) -> ValidColumnsType:
+        """
+        Given a selection of columns/metrics from a datasource, return related columns.
+
+        This is a method used for semantic layers, where tables can have columns and
+        metrics that cannot be computed together. When the user selects a given metric,
+        it allows the UI to filter the remaining metrics and dimensions so that only
+        valid combinations are possible.
+
+        The method should only be called when ``supports_dynamic_columns`` is set to
+        true. The default method in the base class ignores the selected columns and
+        metrics, and simply returns everything, for reference.
+        """
+        return {
+            "columns": {column.column_name for column in datasource.columns},
+            "metrics": {metric.metric_name for metric in datasource.metrics},
+        }
+
     @classmethod
     def where_latest_partition(  # pylint: disable=unused-argument
         cls,
@@ -2148,6 +2190,7 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
             "supports_file_upload": cls.supports_file_upload,
             "disable_ssh_tunneling": cls.disable_ssh_tunneling,
             "supports_dynamic_catalog": cls.supports_dynamic_catalog,
+            "supports_dynamic_columns": cls.supports_dynamic_columns,
             "supports_oauth2": cls.supports_oauth2,
         }
 
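
As a usage illustration of the new hook: a hypothetical caller (an Explore or
chart-data endpoint, say) could narrow the selectable fields after each user
pick. A sketch only, assuming ``database`` and ``datasource`` are already
resolved as elsewhere in Superset; the selected names are made up:

    # Hypothetical caller; the names inside the sets are examples.
    spec = database.db_engine_spec
    if spec.supports_dynamic_columns:
        valid = spec.get_valid_columns(
            database=database,
            datasource=datasource,
            columns={"customer__region"},  # dimensions already selected
            metrics={"revenue"},           # metrics already selected
        )
        # valid["columns"] and valid["metrics"] hold only the names that
        # remain compatible with the current selection.
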
diff --git a/superset/db_engine_specs/metricflow.py b/superset/db_engine_specs/metricflow.py
new file mode 100644
index 0000000000..77ea1552bf
--- /dev/null
+++ b/superset/db_engine_specs/metricflow.py
@@ -0,0 +1,165 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+An interface to dbt's semantic layer, Metric Flow.
+"""
+
+from __future__ import annotations
+
+from typing import Any, TYPE_CHECKING, TypedDict
+
+from shillelagh.backends.apsw.dialects.base import get_adapter_for_table_name
+
+from superset.constants import TimeGrain
+from superset.db_engine_specs.base import ValidColumnsType
+from superset.db_engine_specs.shillelagh import ShillelaghEngineSpec
+from superset.extensions.metricflow import TABLE_NAME
+from superset.models.helpers import ExploreMixin
+from superset.superset_typing import ResultSetColumnType
+
+if TYPE_CHECKING:
+    from sqlalchemy.engine.reflection import Inspector
+
+    from superset.models.core import Database
+    from superset.sql_parse import Table
+
+
+SELECT_STAR_MESSAGE = (
+    'The dbt semantic layer does not support data preview, since the "metrics" table '
+    "is a virtual table that is not materialized. An administrator should configure "
+    'the database in Apache Superset so that the "Disable SQL Lab data preview '
+    'queries" option under "Advanced" → "SQL Lab" is enabled.'
+)
+
+
+class MetricType(TypedDict, total=False):
+    """
+    Type for metrics returned by `get_metrics`.
+    """
+
+    metric_name: str
+    expression: str
+    verbose_name: str | None
+    metric_type: str | None
+    description: str | None
+    d3format: str | None
+    warning_text: str | None
+    extra: str | None
+
+
+class DbtMetricFlowEngineSpec(ShillelaghEngineSpec):
+    """
+    Engine for the dbt semantic layer.
+    """
+
+    engine = "metricflow"
+    engine_name = "dbt Metric Flow"
+    sqlalchemy_uri_placeholder = (
+        "metricflow://[ab123.us1.dbt.com]/<environment_id>"
+        "?service_token=<service_token>"
+    )
+
+    supports_dynamic_columns = True
+
+    _time_grain_expressions = {
+        TimeGrain.DAY: "{col}__day",
+        TimeGrain.WEEK: "{col}__week",
+        TimeGrain.MONTH: "{col}__month",
+        TimeGrain.QUARTER: "{col}__quarter",
+        TimeGrain.YEAR: "{col}__year",
+    }
+
+    @classmethod
+    def select_star(cls, *args: Any, **kwargs: Any) -> str:
+        """
+        Return a ``SELECT *`` query.
+        """
+        message = SELECT_STAR_MESSAGE.replace("'", "''")
+        return f"SELECT '{message}' AS warning"
+
+    @classmethod
+    def get_columns(
+        cls,
+        inspector: Inspector,
+        table: Table,
+        options: dict[str, Any] | None = None,
+    ) -> list[ResultSetColumnType]:
+        """
+        Get columns.
+
+        This method enriches the method from the SQLAlchemy dialect to include the
+        dimension descriptions.
+        """
+        connection = inspector.engine.connect()
+        adapter = get_adapter_for_table_name(connection, table.table)
+
+        return [
+            {
+                "name": column["name"],
+                "column_name": column["name"],
+                "type": column["type"],
+                "nullable": column["nullable"],
+                "default": column["default"],
+                "comment": adapter.dimensions.get(column["name"], ""),
+            }
+            for column in inspector.get_columns(table.table, table.schema)
+        ]
+
+    @classmethod
+    def get_metrics(
+        cls,
+        database: Database,
+        inspector: Inspector,
+        table: Table,
+    ) -> list[MetricType]:
+        """
+        Get all metrics.
+        """
+        connection = inspector.engine.connect()
+        adapter = get_adapter_for_table_name(connection, table.table)
+
+        return [
+            {
+                "metric_name": metric,
+                "expression": metric,
+                "description": description,
+            }
+            for metric, description in adapter.metrics.items()
+        ]
+
+    @classmethod
+    def get_valid_columns(
+        cls,
+        database: Database,
+        datasource: ExploreMixin,
+        columns: set[str],
+        metrics: set[str],
+    ) -> ValidColumnsType:
+        """
+        Get valid columns.
+
+        Given a datasource, and sets of selected metrics and dimensions, return the
+        sets of valid metrics and dimensions that can further be selected.
+        """
+        with database.get_sqla_engine() as engine:
+            connection = engine.connect()
+            adapter = get_adapter_for_table_name(connection, TABLE_NAME)
+
+        return {
+            "metrics": adapter._get_metrics_for_dimensions(columns),
+            "dimensions": adapter._get_dimensions_for_metrics(metrics),
+        }
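
The ``_time_grain_expressions`` map above relies on MetricFlow's convention of
exposing a time dimension once per grain, named ``<dimension>__<grain>``. A
tiny illustration of how a template expands (the substitution itself is done
by the base engine spec when building timestamp expressions; the dimension
name is made up):

    # Example expansion of one grain template.
    template = "{col}__month"
    print(template.format(col="order_date"))  # -> order_date__month
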
diff --git a/superset/extensions/metadb.py b/superset/extensions/metadb.py
index 375b4b87ce..55bbf449e9 100644
--- a/superset/extensions/metadb.py
+++ b/superset/extensions/metadb.py
@@ -112,7 +112,7 @@ class SupersetAPSWDialect(APSWDialect):
                     "superset": {
                         "prefix": None,
                         "allowed_dbs": self.allowed_dbs,
-                    }
+                    },
                 },
                 "safe": True,
                 "isolation_level": self.isolation_level,
diff --git a/superset/extensions/metricflow.py b/superset/extensions/metricflow.py
new file mode 100644
index 0000000000..451feae503
--- /dev/null
+++ b/superset/extensions/metricflow.py
@@ -0,0 +1,253 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""
+A SQLAlchemy dialect for dbt Metric Flow.
+"""
+
+from __future__ import annotations
+
+from datetime import timedelta
+from typing import Any
+
+import sqlalchemy.types
+from shillelagh.adapters.api.dbt_metricflow import DbtMetricFlowAPI
+from shillelagh.backends.apsw.dialects.base import (
+    APSWDialect,
+    get_adapter_for_table_name,
+)
+from shillelagh.fields import Field
+from sqlalchemy.engine.base import Connection
+from sqlalchemy.engine.url import URL
+from sqlalchemy.sql.visitors import VisitableType
+
+from superset.extensions import cache_manager
+from superset.utils.cache import memoized_func
+
+TABLE_NAME = "metrics"
+
+
+def get_sqla_type(field: Field) -> VisitableType:
+    """
+    Convert from Shillelagh to SQLAlchemy types.
+    """
+    type_map = {
+        "BOOLEAN": sqlalchemy.types.BOOLEAN,
+        "INTEGER": sqlalchemy.types.INT,
+        "DECIMAL": sqlalchemy.types.DECIMAL,
+        "TIMESTAMP": sqlalchemy.types.TIMESTAMP,
+        "DATE": sqlalchemy.types.DATE,
+        "TIME": sqlalchemy.types.TIME,
+        "TEXT": sqlalchemy.types.TEXT,
+    }
+
+    return type_map.get(field.type, sqlalchemy.types.TEXT)
+
+
+class SupersetMetricFlowAPI(DbtMetricFlowAPI):
+    """
+    Custom API adapter for dbt Metric Flow API.
+
+    In the original adapter, the SQL queries a base dbt API URL, eg:
+
+        SELECT * FROM "https://semantic-layer.cloud.getdbt.com/";;
+        SELECT * FROM "https://ab123.us1.dbt.com/";;  -- custom user URL
+
+    For this adapter, we want a leaner URI, mimicking a table:
+
+        SELECT * FROM metrics;
+
+    In order to do this, we override the ``supports`` method to only accept
+    ``$TABLE_NAME`` instead of the URL, which is then passed to the adapter when it is
+    instantiated.
+
+    One problem with this change is that the adapter needs the base URL in order to
+    determine the GraphQL endpoint. To solve this we pass the original URL via a new
+    argument ``url``, and override the ``_get_endpoint`` method to use it instead of the
+    table name.
+    """
+
+    @staticmethod
+    def supports(uri: str, fast: bool = True, **kwargs: Any) -> bool:
+        return uri == TABLE_NAME
+
+    def __init__(
+        self,
+        table: str,
+        service_token: str,
+        environment_id: int,
+        url: str,
+    ) -> None:
+        self.url = url
+        super().__init__(table, service_token, environment_id)
+
+    def _get_endpoint(self, url: str) -> str:
+        """
+        Compute the GraphQL endpoint.
+
+        Instead of using ``url`` (which points to ``TABLE_NAME`` in this adapter), we
+        should call the method using the actual dbt API base URL.
+        """
+        return super()._get_endpoint(self.url)
+
+    def _build_column_from_dimension(self, name: str) -> Field:
+        """
+        Build a Shillelagh column from a dbt dimension.
+
+        This method is terribly slow, since it needs to do a full data request for each
+        dimension in order to determine its type. To improve UX we cache the results
+        for one day.
+        """
+        return self._cached_build_column_from_dimension(
+            name,
+            cache_timeout=int(timedelta(days=1).total_seconds()),
+        )
+
+    @memoized_func(key="metricflow:dimension:{name}", cache=cache_manager.data_cache)
+    def _cached_build_column_from_dimension(
+        self,
+        name: str,
+        *args: Any,
+        **kwargs: Any,
+    ) -> Field:
+        """
+        Cached version of ``_build_column_from_dimension``.
+        """
+        return super()._build_column_from_dimension(name)
+
+
+class MetricFlowDialect(APSWDialect):
+    """
+    A dbt Metric Flow dialect.
+
+    URL should look like:
+
+        metricflow:///<environment_id>?service_token=<service_token>
+
+    Or when using a custom URL:
+
+        metricflow://ab123.us1.dbt.com/<environment_id>?service_token=<service_token>
+
+    """
+
+    name = "metricflow"
+
+    supports_statement_cache = True
+
+    def create_connect_args(self, url: URL) -> tuple[tuple[()], dict[str, Any]]:
+        baseurl = (
+            f"https://{url.host}/"
+            if url.host
+            else "https://semantic-layer.cloud.getdbt.com/"
+        )
+
+        return (
+            (),
+            {
+                "path": ":memory:",
+                "adapters": ["supersetmetricflowapi"],
+                "adapter_kwargs": {
+                    "supersetmetricflowapi": {
+                        "service_token": url.query["service_token"],
+                        "environment_id": int(url.database),
+                        "url": baseurl,
+                    },
+                },
+                "safe": True,
+                "isolation_level": self.isolation_level,
+            },
+        )
+
+    def get_table_names(
+        self,
+        connection: Connection,
+        schema: str | None = None,
+        sqlite_include_internal: bool = False,
+        **kwargs: Any,
+    ) -> list[str]:
+        return [TABLE_NAME]
+
+    def has_table(
+        self,
+        connection: Connection,
+        table_name: str,
+        schema: str | None = None,
+        **kwargs: Any,
+    ) -> bool:
+        return table_name == TABLE_NAME
+
+    def get_columns(
+        self,
+        connection: Connection,
+        table_name: str,
+        schema: str | None = None,
+        **kwargs: Any,
+    ) -> list[dict[str, Any]]:
+        adapter = get_adapter_for_table_name(connection, table_name)
+
+        columns = {
+            (
+                adapter.grains[dimension][0]
+                if dimension in adapter.grains
+                else dimension
+            ): adapter.columns[dimension]
+            for dimension in adapter.dimensions
+        }
+
+        return [
+            {
+                "name": name,
+                "type": get_sqla_type(field),
+                "nullable": True,
+                "default": None,
+            }
+            for name, field in columns.items()
+        ]
+
+    def get_schema_names(
+        self,
+        connection: Connection,
+        **kwargs: Any,
+    ) -> list[str]:
+        return ["main"]
+
+    def get_pk_constraint(
+        self,
+        connection: Connection,
+        table_name: str,
+        schema: str | None = None,
+        **kwargs: Any,
+    ) -> dict[str, Any]:
+        return {"constrained_columns": [], "name": None}
+
+    def get_foreign_keys(
+        self,
+        connection: Connection,
+        table_name: str,
+        schema: str | None = None,
+        **kwargs: Any,
+    ) -> list[dict[str, Any]]:
+        return []
+
+    get_check_constraints = get_foreign_keys
+    get_indexes = get_foreign_keys
+    get_unique_constraints = get_foreign_keys
+
+    def get_table_comment(
+        self,
+        connection: Connection,
+        table_name: str,
+        schema: str | None = None,
+        **kwargs: Any,
+    ) -> dict[str, Any]:
+        return {
+            "text": "A virtual table that gives access to all dbt metrics & dimensions."
+        }
diff --git a/superset/sql/parse.py b/superset/sql/parse.py
index 54d1006b93..fce01a053e 100644
--- a/superset/sql/parse.py
+++ b/superset/sql/parse.py
@@ -85,6 +85,7 @@ SQLGLOT_DIALECTS = {
     # "kustosql": ???
     # "kylin": ???
     "mariadb": Dialects.MYSQL,
+    "metricflow": Dialects.SQLITE,
     "motherduck": Dialects.DUCKDB,
     "mssql": Dialects.TSQL,
     "mysql": Dialects.MYSQL,
