This is an automated email from the ASF dual-hosted git repository.

beto pushed a commit to branch metricflow-db-engine-spec
in repository https://gitbox.apache.org/repos/asf/superset.git

commit 05043fb961b09c81f2c516cf3f237b786bc85006
Author: Beto Dealmeida <robe...@dealmeida.net>
AuthorDate: Tue Jul 22 11:33:15 2025 -0400

    WIP
---
 pyproject.toml                         |   1 +
 superset/db_engine_specs/metricflow.py | 189 +++++++++++++++++++++++++++++++++
 2 files changed, 190 insertions(+)

diff --git a/pyproject.toml b/pyproject.toml
index b8c5ebb0b8..f513b9fb03 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -148,6 +148,7 @@ hive = [
 impala = ["impyla>0.16.2, <0.17"]
 kusto = ["sqlalchemy-kusto>=3.0.0, <4"]
 kylin = ["kylinpy>=2.8.1, <2.9"]
+metricflow = ["shillelagh[dbtmetricflowapi]>=1.4.2, <2"]
 motherduck = ["duckdb==0.10.2", "duckdb-engine>=0.12.1, <0.13"]
 mssql = ["pymssql>=2.2.8, <3"]
 mysql = ["mysqlclient>=2.1.0, <3"]
diff --git a/superset/db_engine_specs/metricflow.py b/superset/db_engine_specs/metricflow.py
new file mode 100644
index 0000000000..e1c9c2115b
--- /dev/null
+++ b/superset/db_engine_specs/metricflow.py
@@ -0,0 +1,189 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+An interface to dbt's semantic layer, Metric Flow.
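+
+Example connection URI (the host, environment ID, and service token below are
+illustrative placeholders, not real values):
+
+    metricflow://ab123.us1.dbt.com/123456?service_token=<service_token>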
+"""
+
+from __future__ import annotations
+
+from datetime import timedelta
+from typing import Any, TYPE_CHECKING
+
+from shillelagh.backends.apsw.dialects.base import get_adapter_for_table_name
+from shillelagh.backends.apsw.dialects.metricflow import TABLE_NAME
+from sqlalchemy import event
+from sqlalchemy.engine import Engine
+from sqlalchemy.pool import _ConnectionRecord
+from sqlalchemy.pool.base import _ConnectionFairy
+
+from superset.connectors.sqla.models import SqlaTable
+from superset.constants import TimeGrain
+from superset.db_engine_specs.base import MetricType, ValidColumnsType
+from superset.db_engine_specs.shillelagh import ShillelaghEngineSpec
+from superset.extensions import cache_manager
+from superset.utils.cache import memoized_func
+
+if TYPE_CHECKING:
+    from shillelagh.backends.apsw.dialects.metricflow import TableMetricFlowAPI
+    from shillelagh.fields import Field
+    from sqlalchemy.engine.reflection import Inspector
+
+    from superset.models.core import Database
+    from superset.sql.parse import Table
+    from superset.superset_typing import ResultSetColumnType
+
+
+@event.listens_for(Engine, "connect")
+def receive_connect(
+    dbapi_connection: _ConnectionFairy,
+    connection_record: _ConnectionRecord,
+) -> None:
+    """
+    Called when a new DB connection is created.
+
+    This hook adds a cache to the `_build_column_from_dimension` method of the Metric
+    Flow adapter, since it's called frequently and can be expensive.
+    """
+    engine = connection_record.info.get("engine")
+    if (
+        not engine
+        or engine.name != "metricflow"
+        or getattr(engine.dialect, "_patched", False)
+    ):
+        return
+
+    original_method = engine.dialect._build_column_from_dimension
+
+    @memoized_func(
+        key="metricflow:dimension:{name}",
+        cache=cache_manager.data_cache,
+    )
+    def cached_build_column_from_dimension(
+        self: TableMetricFlowAPI,
+        name: str,
+        *args: Any,
+        **kwargs: Any,
+    ) -> Field:
+        # forward any extra arguments untouched; `cache_timeout` has already
+        # been popped by the `memoized_func` wrapper by the time this runs
+        return original_method(self, name, *args, **kwargs)
+
+    def build_column_from_dimension(
+        self: TableMetricFlowAPI,
+        name: str,
+        **kwargs: Any,
+    ) -> Field:
+        # `memoized_func` pops `cache_timeout` from the call kwargs, so the
+        # one-day timeout is injected here rather than passed to the original
+        # method, which does not accept it
+        kwargs.setdefault("cache_timeout", int(timedelta(days=1).total_seconds()))
+        return cached_build_column_from_dimension(self, name=name, **kwargs)
+
+    engine.dialect._build_column_from_dimension = build_column_from_dimension
+    engine.dialect._patched = True
+
+
+SELECT_STAR_MESSAGE = (
+    'The dbt semantic layer does not support data preview, since the "metrics" table '
+    "is a virtual table that is not materialized. An administrator should configure "
+    'the database in Apache Superset so that the "Disable SQL Lab data preview '
+    'queries" option under "Advanced" → "SQL Lab" is enabled.'
+)
+
+
+class DbtMetricFlowEngineSpec(ShillelaghEngineSpec):
+    """
+    Engine for the dbt semantic layer.
+    """
+
+    engine = "metricflow"
+    engine_name = "dbt Metric Flow"
+    sqlalchemy_uri_placeholder = (
+        "metricflow://[ab123.us1.dbt.com]/<environment_id>"
+        "?service_token=<service_token>"
+    )
+
+    supports_dynamic_columns = True
+
+    _time_grain_expressions = {
+        TimeGrain.DAY: "{col}__day",
+        TimeGrain.WEEK: "{col}__week",
+        TimeGrain.MONTH: "{col}__month",
+        TimeGrain.QUARTER: "{col}__quarter",
+        TimeGrain.YEAR: "{col}__year",
+    }
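+    # e.g., with a temporal dimension "order_date" (illustrative name), the
+    # TimeGrain.MONTH template renders as "order_date__month"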
+
+    @classmethod
+    def select_star(cls, *args: Any, **kwargs: Any) -> str:
+        """
+        Return a warning in place of a ``SELECT *`` query, since data preview is not supported.
+        """
+        message = SELECT_STAR_MESSAGE.replace("'", "''")
+        return f"SELECT '{message}' AS warning"
+
+    @classmethod
+    def get_columns(
+        cls,
+        inspector: Inspector,
+        table: Table,
+        options: dict[str, Any] | None = None,
+    ) -> list[ResultSetColumnType]:
+        columns: list[ResultSetColumnType] = []
+
+        for column in inspector.get_columns(table.table, table.schema):
+            # metrics are exposed as computed columns; skip them here
+            if "computed" in column:
+                continue
+
+            column["column_name"] = column["name"]
+            columns.append(column)
+
+        return columns
+
+    @classmethod
+    def get_metrics(
+        cls,
+        database: Database,
+        inspector: Inspector,
+        table: Table,
+    ) -> list[MetricType]:
+        """
+        Get all metrics.
+        """
+        return [
+            {
+                "metric_name": column["name"],
+                "expression": column["computed"]["sqltext"],
+                "description": column["comment"],
+            }
+            for column in inspector.get_columns(table.table, table.schema)
+            if "computed" in column
+        ]
+
+    @classmethod
+    def get_valid_metrics_and_dimensions(
+        cls,
+        database: Database,
+        table: SqlaTable,
+        dimensions: set[str],
+        metrics: set[str],
+    ) -> ValidColumnsType:
+        """
+        Get valid metrics and dimensions.
+
+        Given a datasource and sets of selected metrics and dimensions, return the
+        sets of valid metrics and dimensions that can still be selected.
+        """
+        with database.get_sqla_engine() as engine:
+            # keep the connection open while the adapter is queried
+            with engine.connect() as connection:
+                adapter = get_adapter_for_table_name(connection, TABLE_NAME)
+
+                return {
+                    "metrics": adapter._get_metrics_for_dimensions(dimensions),
+                    "dimensions": adapter._get_dimensions_for_metrics(metrics),
+                }
