This is an automated email from the ASF dual-hosted git repository.

beto pushed a commit to branch dbt-sl
in repository https://gitbox.apache.org/repos/asf/superset.git

commit 04a770e468faa5e9b9b4980f6178f5191568ccd9
Author: Beto Dealmeida <[email protected]>
AuthorDate: Sat Apr 27 16:16:52 2024 -0400

    WIP
---
 setup.py                        |   4 +-
 superset/db_engine_specs/dbt.py | 108 +++++++++++++++++++++++
 superset/extensions/dbt.py      | 189 ++++++++++++++++++++++++++++++++++++++++
 superset/extensions/metadb.py   |   2 +-
 4 files changed, 301 insertions(+), 2 deletions(-)

diff --git a/setup.py b/setup.py
index 00b8d22e2a..1e71d6e2f5 100644
--- a/setup.py
+++ b/setup.py
@@ -64,9 +64,11 @@ setup(
             "postgres.psycopg2 = sqlalchemy.dialects.postgresql:dialect",
             "postgres = sqlalchemy.dialects.postgresql:dialect",
             "superset = superset.extensions.metadb:SupersetAPSWDialect",
+            "dbt = superset.extensions.dbt:DbtMetricFlowDialect",
         ],
         "shillelagh.adapter": [
-            "superset=superset.extensions.metadb:SupersetShillelaghAdapter"
+            "superset = superset.extensions.metadb:SupersetShillelaghAdapter",
+            "presetdbtmetricflowapi = superset.extensions.dbt:PresetDbtMetricFlowAPI",
         ],
     },
     download_url="https://www.apache.org/dist/superset/" + version_string,
diff --git a/superset/db_engine_specs/dbt.py b/superset/db_engine_specs/dbt.py
new file mode 100644
index 0000000000..a2b27326a6
--- /dev/null
+++ b/superset/db_engine_specs/dbt.py
@@ -0,0 +1,108 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+An interface to dbt's semantic layer.
+"""
+
+from __future__ import annotations
+
+from typing import Any, TypedDict, TYPE_CHECKING
+
+from shillelagh.backends.apsw.dialects.base import get_adapter_for_table_name
+
+from superset.constants import TimeGrain
+from superset.db_engine_specs.shillelagh import ShillelaghEngineSpec
+
+if TYPE_CHECKING:
+    from superset.models.core import Database
+    from superset.sql_parse import Table
+    from sqlalchemy.engine.reflection import Inspector
+
+
+SELECT_STAR_MESSAGE = (
+    "The dbt semantic layer does not support data preview, since the `metrics` table is "
+    "a virtual table that is not materialized. An administrator should configure the "
+    "database in Apache Superset so that the `disable_data_preview` attribute is set to "
+    "`true` in the `extra` field of the database configuration."
+)
+
+
+class MetricType(TypedDict, total=False):
+    """
+    Type for metrics returned by `get_metrics`; every key is optional (`total=False`).
+    """
+
+    metric_name: str
+    expression: str
+    verbose_name: str | None
+    metric_type: str | None
+    description: str | None
+    d3format: str | None
+    warning_text: str | None
+    extra: str | None
+
+
+class DbtMetricFlowEngineSpec(ShillelaghEngineSpec):
+    """
+    Engine spec for the dbt semantic layer.
+    """
+
+    engine = "dbt"
+    engine_name = "dbt semantic layer interface"
+    sqlalchemy_uri_placeholder = "dbt:///<environment_id>?service_token=<service_token>"
+
+    _time_grain_expressions = {
+        TimeGrain.DAY: "{col}__day",
+        TimeGrain.WEEK: "{col}__week",
+        TimeGrain.MONTH: "{col}__month",
+        TimeGrain.QUARTER: "{col}__quarter",
+        TimeGrain.YEAR: "{col}__year",
+    }
+
+    @classmethod
+    def select_star(cls, *args: Any, **kwargs: Any) -> str:
+        """
+        Return a query that yields a warning, since data preview is not supported.
+        """
+        message = SELECT_STAR_MESSAGE.replace("'", "''")
+        return f"SELECT '{message}' AS warning"
+
+    @classmethod
+    def get_metrics(
+        cls,
+        database: Database,
+        inpector: Inspector,
+        table: Table,
+    ) -> list[MetricType]:
+        """
+        Return every metric exposed by the adapter that serves ``table``.
+        """
+        with database.get_sqla_engine(
+            catalog=table.catalog,
+            schema=table.schema,
+        ) as engine, engine.connect() as connection:
+            # the connection is only needed to look up the adapter; close it on exit
+            adapter = get_adapter_for_table_name(connection, table.table)
+
+        return [
+            {
+                "metric_name": metric,
+                "expression": metric,
+                "description": description,
+            }
+            for metric, description in adapter.metrics.items()
+        ]
diff --git a/superset/extensions/dbt.py b/superset/extensions/dbt.py
new file mode 100644
index 0000000000..8a68ed2285
--- /dev/null
+++ b/superset/extensions/dbt.py
@@ -0,0 +1,189 @@
+from __future__ import annotations
+
+from datetime import timedelta
+from typing import Any
+
+import sqlalchemy.types
+from shillelagh.adapters.api.dbt_metricflow import DbtMetricFlowAPI
+from shillelagh.backends.apsw.dialects.base import (
+    APSWDialect,
+    get_adapter_for_table_name,
+)
+from shillelagh.fields import Field
+from sqlalchemy.engine.base import Connection
+from sqlalchemy.engine.url import URL
+from sqlalchemy.sql.visitors import VisitableType
+
+from superset.extensions import cache_manager
+from superset.utils.cache import memoized_func
+
+
+TABLE_NAME = "metrics"  # dummy table name exposed instead of a URL
+
+
+def get_sqla_type(field: Field) -> VisitableType:
+    """
+    Convert a Shillelagh field type to a SQLAlchemy type, defaulting to TEXT.
+    """
+    type_map = {
+        "BOOLEAN": sqlalchemy.types.BOOLEAN,
+        "INTEGER": sqlalchemy.types.INT,
+        "DECIMAL": sqlalchemy.types.DECIMAL,
+        "TIMESTAMP": sqlalchemy.types.TIMESTAMP,
+        "DATE": sqlalchemy.types.DATE,
+        "TIME": sqlalchemy.types.TIME,
+        "TEXT": sqlalchemy.types.TEXT,
+    }
+
+    return type_map.get(field.type, sqlalchemy.types.TEXT)
+
+
+class PresetDbtMetricFlowAPI(DbtMetricFlowAPI):
+    """
+    Custom API adapter for dbt Metric Flow API.
+
+    Instead of querying by URL we have a dummy table name.
+    """
+
+    @staticmethod
+    def supports(uri: str, fast: bool = True, **kwargs: Any) -> bool:
+        return uri == TABLE_NAME
+
+    @staticmethod
+    def parse_uri(uri: str) -> tuple[()]:
+        return ()
+
+    def __init__(self, table: str, service_token: str, environment_id: int):
+        super().__init__(table, service_token, environment_id)
+        self.table = TABLE_NAME
+
+    def _build_column_from_dimension(self, name: str) -> Field:
+        return self._cached_get_dimension(
+            name,
+            cache_timeout=int(timedelta(days=1).total_seconds()),
+        )
+
+    @memoized_func(key="dbt:dimension:{name}", cache=cache_manager.data_cache)
+    def _cached_get_dimension(self, name: str, *args: Any, **kwargs: Any) -> Field:
+        return super()._build_column_from_dimension(name)
+
+
+class DbtMetricFlowDialect(APSWDialect):
+    """
+    A dbt Metric Flow dialect.
+
+    URL should look like:
+
+        dbt:///<environment_id>?service_token=<service_token>
+
+    Or when using a custom URL:
+
+        dbt://ab123.us1.dbt.com/<environment_id>?service_token=<service_token>
+
+    """
+
+    name = "dbt"
+
+    supports_statement_cache = True
+
+    def create_connect_args(self, url: URL) -> tuple[tuple[()], dict[str, Any]]:
+        table = (
+            f"https://{url.host}/"
+            if url.host
+            else "https://semantic-layer.cloud.getdbt.com/"
+        )
+
+        return (
+            (),
+            {
+                "path": ":memory:",
+                "adapters": ["presetdbtmetricflowapi"],
+                "adapter_kwargs": {
+                    "presetdbtmetricflowapi": {
+                        "table": table,
+                        "service_token": url.query["service_token"],
+                        "environment_id": int(url.database),
+                    },
+                },
+                "safe": True,
+                "isolation_level": self.isolation_level,
+            },
+        )
+
+    def get_table_names(
+        self,
+        connection: Connection,
+        schema: str | None = None,
+        sqlite_include_internal: bool = False,
+        **kwargs: Any,
+    ) -> list[str]:
+        return [TABLE_NAME]
+
+    def has_table(
+        self,
+        connection: Connection,
+        table_name: str,
+        schema: str | None = None,
+        **kwargs: Any,
+    ) -> bool:
+        return table_name == TABLE_NAME
+
+    def get_columns(
+        self,
+        connection: Connection,
+        table_name: str,
+        schema: str | None = None,
+        **kwargs: Any,
+    ) -> list[dict[str, Any]]:
+        adapter = get_adapter_for_table_name(connection, table_name)
+
+        columns = {
+            adapter.grains[dimension][0]
+            if dimension in adapter.grains
+            else dimension: adapter.columns[dimension]
+            for dimension in adapter.dimensions
+        }
+
+        return [
+            {
+                "name": name,
+                "type": get_sqla_type(field),
+                "nullable": True,
+                "default": None,
+            }
+            for name, field in columns.items()
+        ]
+
+    def get_schema_names(
+        self,
+        connection: Connection,
+        **kwargs: Any,
+    ) -> list[str]:
+        return ["main"]
+
+    def get_pk_constraint(
+        self,
+        connection: Connection,
+        table_name: str,
+        schema: str | None = None,
+        **kwargs: Any,
+    ) -> dict[str, Any]:
+        return {"constrained_columns": [], "name": None}
+
+    def get_foreign_keys(
+        self,
+        connection: Connection,
+        table_name: str,
+        schema: str | None = None,
+        **kwargs: Any,
+    ) -> list[dict[str, Any]]:
+        return []
+
+    get_check_constraints = get_foreign_keys
+    get_indexes = get_foreign_keys
+    get_unique_constraints = get_foreign_keys
+
+    def get_table_comment(self, connection, table_name, schema=None, **kwargs):
+        return {
+            "text": "A virtual table that gives access to all dbt metrics and dimensions."
+        }
diff --git a/superset/extensions/metadb.py b/superset/extensions/metadb.py
index 2d8444cc99..353b8279d2 100644
--- a/superset/extensions/metadb.py
+++ b/superset/extensions/metadb.py
@@ -111,7 +111,7 @@ class SupersetAPSWDialect(APSWDialect):
                     "superset": {
                         "prefix": None,
                         "allowed_dbs": self.allowed_dbs,
-                    }
+                    },
                 },
                 "safe": True,
                 "isolation_level": self.isolation_level,

Reply via email to