This is an automated email from the ASF dual-hosted git repository.

tai pushed a commit to branch feat/starrocks-catalog
in repository https://gitbox.apache.org/repos/asf/superset.git

commit 74996ff6a838a680f37661edbe4261e40141fc4d
Author: Tai Dupree <[email protected]>
AuthorDate: Thu Jan 8 11:55:06 2026 -0800

    feat(starrocks): add catalog support for StarRocks database connections
---
 superset/db_engine_specs/starrocks.py              | 152 ++++++++++++++++-----
 tests/unit_tests/db_engine_specs/test_starrocks.py | 114 +++++++++++++++-
 2 files changed, 229 insertions(+), 37 deletions(-)

diff --git a/superset/db_engine_specs/starrocks.py b/superset/db_engine_specs/starrocks.py
index d3e2172f2b..3777e67149 100644
--- a/superset/db_engine_specs/starrocks.py
+++ b/superset/db_engine_specs/starrocks.py
@@ -18,11 +18,12 @@
 import logging
 import re
 from re import Pattern
-from typing import Any, Optional, Union
+from typing import Any
 from urllib import parse
 
 from flask_babel import gettext as __
 from sqlalchemy import Float, Integer, Numeric, types
+from sqlalchemy.engine.reflection import Inspector
 from sqlalchemy.engine.url import URL
 from sqlalchemy.sql.type_api import TypeEngine
 
@@ -31,6 +32,8 @@ from superset.errors import SupersetErrorType
 from superset.models.core import Database
 from superset.utils.core import GenericDataType
 
+DEFAULT_CATALOG = "default_catalog"
+
 # Regular expressions to catch custom errors
 CONNECTION_ACCESS_DENIED_REGEX = re.compile(
     "Access denied for user '(?P<username>.*?)'"
@@ -68,7 +71,7 @@ class ARRAY(TypeEngine):
     __visit_name__ = "ARRAY"
 
     @property
-    def python_type(self) -> Optional[type[list[Any]]]:
+    def python_type(self) -> type[list[Any]] | None:
         return list
 
 
@@ -76,7 +79,7 @@ class MAP(TypeEngine):
     __visit_name__ = "MAP"
 
     @property
-    def python_type(self) -> Optional[type[dict[Any, Any]]]:
+    def python_type(self) -> type[dict[Any, Any]] | None:
         return dict
 
 
@@ -84,7 +87,7 @@ class STRUCT(TypeEngine):
     __visit_name__ = "STRUCT"
 
     @property
-    def python_type(self) -> Optional[type[Any]]:
+    def python_type(self) -> type[Any] | None:
         return None
 
 
@@ -94,8 +97,11 @@ class StarRocksEngineSpec(MySQLEngineSpec):
 
     default_driver = "starrocks"
     sqlalchemy_uri_placeholder = (
-        "starrocks://user:password@host:port/catalog.db[?key=value&key=value...]"
+        "starrocks://user:password@host:port[/catalog.db]"
     )
+    supports_dynamic_schema = True
+    supports_catalog = supports_dynamic_catalog = True
+    supports_cross_catalog_queries = True
 
     column_type_mappings = (  # type: ignore
         (
@@ -168,17 +174,39 @@ class StarRocksEngineSpec(MySQLEngineSpec):
         cls,
         uri: URL,
         connect_args: dict[str, Any],
-        catalog: Optional[str] = None,
-        schema: Optional[str] = None,
+        catalog: str | None = None,
+        schema: str | None = None,
     ) -> tuple[URL, dict[str, Any]]:
-        database = uri.database
-        if schema and database:
+        """
+        Adjust engine parameters for StarRocks catalog and schema support.
+
+        StarRocks uses a "catalog.schema" format in the database field:
+        - "catalog.schema" - both specified
+        - "catalog." or "catalog" - catalog only (for browsing schemas)
+        - None - neither specified
+        """
+        if uri.database and "." in uri.database:
+            current_catalog, current_schema = uri.database.split(".", 1)
+        elif uri.database:
+            current_catalog, current_schema = uri.database, None
+        else:
+            current_catalog, current_schema = None, None
+
+        if schema:
             schema = parse.quote(schema, safe="")
-            if "." in database:
-                database = database.split(".")[0] + "." + schema
-            else:
-                database = "default_catalog." + schema
-            uri = uri.set(database=database)
+
+        effective_catalog = catalog or current_catalog or DEFAULT_CATALOG
+        # only use the schema/db from uri if we're not overriding catalog
+        effective_schema = schema
+        if not effective_schema and (not catalog or catalog == current_catalog):
+            effective_schema = current_schema
+
+        if effective_schema:
+            adjusted_database = f"{effective_catalog}.{effective_schema}"
+        else:
+            adjusted_database = f"{effective_catalog}."
+
+        uri = uri.set(database=adjusted_database)
 
         return uri, connect_args
 
@@ -187,21 +215,85 @@ class StarRocksEngineSpec(MySQLEngineSpec):
         cls,
         sqlalchemy_uri: URL,
         connect_args: dict[str, Any],
-    ) -> Optional[str]:
+    ) -> str | None:
+        """
+        Extract schema from engine parameters.
+
+        Returns the schema portion from formats like:
+        - "catalog.schema" -> "schema"
+        - "schema" -> None (ambiguous - could be catalog or schema)
+        - "" or None -> None
         """
-        Return the configured schema.
+        if not sqlalchemy_uri.database:
+            return None
 
-        For StarRocks the SQLAlchemy URI looks like this:
+        database = sqlalchemy_uri.database.strip("/")
+        if not database or "." not in database:
+            return None
 
-            starrocks://localhost:9030/catalog.schema
+        schema = database.split(".")[-1]
+        return parse.unquote(schema)
 
+    @classmethod
+    def get_default_catalog(cls, database: Database) -> str:
         """
-        database = sqlalchemy_uri.database.strip("/")
+        Return the default catalog.
 
-        if "." not in database:
-            return None
+        Extracts catalog from URI (e.g., "iceberg" from "iceberg.schema"),
+        otherwise returns DEFAULT_CATALOG.
+        """
+        if database.url_object.database and "." in database.url_object.database:
+            return database.url_object.database.split(".")[0]
 
-        return parse.unquote(database.split(".")[1])
+        return DEFAULT_CATALOG
+
+    @classmethod
+    def get_catalog_names(
+        cls,
+        database: Database,
+        inspector: Inspector,
+    ) -> set[str]:
+        """
+        Get all available catalogs.
+
+        Executes SHOW CATALOGS and extracts catalog names from the result.
+        The command returns columns: Catalog, Type, Comment
+        """
+        try:
+            result = inspector.bind.execute("SHOW CATALOGS")
+            catalogs = set()
+
+            for row in result:
+                try:
+                    if hasattr(row, "keys") and "Catalog" in row.keys():
+                        catalogs.add(row["Catalog"])
+                    elif hasattr(row, "Catalog"):
+                        catalogs.add(row.Catalog)
+                    else:
+                        catalogs.add(row[0])
+                except (AttributeError, TypeError, IndexError, KeyError) as ex:
+                    logger.warning("Unable to extract catalog name from row: %s (%s)", row, ex)
+                    continue
+
+            return catalogs
+        except Exception as ex:  # pylint: disable=broad-except
+            logger.exception("Error fetching catalog names from SHOW CATALOGS: %s", ex)
+            return set()
+
+    @classmethod
+    def get_schema_names(cls, inspector: Inspector) -> set[str]:
+        """
+        Get all schemas/databases using SHOW DATABASES.
+
+        The catalog context is set via the database field in the connection URL
+        (e.g., "catalog." sets the context to that catalog).
+        """
+        try:
+            result = inspector.bind.execute("SHOW DATABASES")
+            return {row[0] for row in result}
+        except Exception as ex:  # pylint: disable=broad-except
+            logger.exception("Error fetching schema names from SHOW DATABASES: %s", ex)
+            return set()
 
     @classmethod
     def impersonate_user(
@@ -225,21 +317,13 @@ class StarRocksEngineSpec(MySQLEngineSpec):
     def get_prequeries(
         cls,
         database: Database,
-        catalog: Union[str, None] = None,
-        schema: Union[str, None] = None,
+        catalog: str | None = None,
+        schema: str | None = None,
     ) -> list[str]:
         """
-        Return pre-session queries.
-
-        These are currently used as an alternative to ``adjust_engine_params`` for
-        databases where the selected schema cannot be specified in the SQLAlchemy URI or
-        connection arguments.
-
-        For example, in order to specify a default schema in RDS we need to run a query
-        at the beginning of the session:
-
-            sql> set search_path = my_schema;
+        Get pre-session queries.
 
+        For StarRocks with user impersonation enabled, returns an EXECUTE AS statement.
         """
         if database.impersonate_user:
             username = database.get_effective_user(database.url_object)
diff --git a/tests/unit_tests/db_engine_specs/test_starrocks.py b/tests/unit_tests/db_engine_specs/test_starrocks.py
index 67016a0801..e37aeab901 100644
--- a/tests/unit_tests/db_engine_specs/test_starrocks.py
+++ b/tests/unit_tests/db_engine_specs/test_starrocks.py
@@ -79,7 +79,7 @@ def test_get_column_spec(
         (
             "starrocks://user:password@host/db1",
             {"param1": "some_value"},
-            "db1",
+            "db1.",  # Single value is treated as a catalog (no schema selected)
             {"param1": "some_value"},
         ),
         (
@@ -88,12 +88,18 @@ def test_get_column_spec(
             "catalog1.db1",
             {"param1": "some_value"},
         ),
+        (
+            "starrocks://user:password@host",
+            {"param1": "some_value"},
+            "default_catalog.",
+            {"param1": "some_value"},
+        ),
     ],
 )
 def test_adjust_engine_params(
     sqlalchemy_uri: str,
     connect_args: dict[str, Any],
-    return_schema: str,
+    return_schema: Optional[str],
     return_connect_args: dict[str, Any],
 ) -> None:
     from superset.db_engine_specs.starrocks import StarRocksEngineSpec
@@ -112,6 +118,7 @@ def test_get_schema_from_engine_params() -> None:
     """
     from superset.db_engine_specs.starrocks import StarRocksEngineSpec
 
+    # With catalog.schema format
     assert (
         StarRocksEngineSpec.get_schema_from_engine_params(
             make_url("starrocks://localhost:9030/hive.default"),
@@ -120,9 +127,19 @@ def test_get_schema_from_engine_params() -> None:
         == "default"
     )
 
+    # With only catalog (no schema) - should return None
+    assert (
+        StarRocksEngineSpec.get_schema_from_engine_params(
+            make_url("starrocks://localhost:9030/sales"),
+            {},
+        )
+        is None
+    )
+
+    # With no database - should return None
     assert (
         StarRocksEngineSpec.get_schema_from_engine_params(
-            make_url("starrocks://localhost:9030/hive"),
+            make_url("starrocks://localhost:9030"),
             {},
         )
         is None
@@ -173,3 +190,94 @@ def test_impersonation_disabled(mocker: MockerFixture) -> None:
     ) == (make_url("starrocks://service_user@localhost:9030/hive.default"), {})
 
     assert StarRocksEngineSpec.get_prequeries(database) == []
+
+
+def test_get_default_catalog(mocker: MockerFixture) -> None:
+    """
+    Test the ``get_default_catalog`` method.
+    """
+    from superset.db_engine_specs.starrocks import StarRocksEngineSpec
+
+    # Test case 1: Catalog is in the URI
+    database = mocker.MagicMock()
+    database.url_object.database = "hive.default"
+
+    assert StarRocksEngineSpec.get_default_catalog(database) == "hive"
+
+    # Test case 2: Catalog is not in the URI, returns default
+    database = mocker.MagicMock()
+    database.url_object.database = "default"
+
+    assert StarRocksEngineSpec.get_default_catalog(database) == "default_catalog"
+
+
+def test_get_catalog_names(mocker: MockerFixture) -> None:
+    """
+    Test the ``get_catalog_names`` method.
+    """
+    from superset.db_engine_specs.starrocks import StarRocksEngineSpec
+
+    database = mocker.MagicMock()
+    inspector = mocker.MagicMock()
+
+    # Mock the actual StarRocks SHOW CATALOGS format
+    # StarRocks returns rows with keys: ['Catalog', 'Type', 'Comment']
+    mock_row_1 = mocker.MagicMock()
+    mock_row_1.keys.return_value = ["Catalog", "Type", "Comment"]
+    mock_row_1.__getitem__ = lambda self, key: "default_catalog" if key == "Catalog" else None
+
+    mock_row_2 = mocker.MagicMock()
+    mock_row_2.keys.return_value = ["Catalog", "Type", "Comment"]
+    mock_row_2.__getitem__ = lambda self, key: "hive" if key == "Catalog" else None
+
+    mock_row_3 = mocker.MagicMock()
+    mock_row_3.keys.return_value = ["Catalog", "Type", "Comment"]
+    mock_row_3.__getitem__ = lambda self, key: "iceberg" if key == "Catalog" else None
+
+    inspector.bind.execute.return_value = [mock_row_1, mock_row_2, mock_row_3]
+
+    catalogs = StarRocksEngineSpec.get_catalog_names(database, inspector)
+    assert catalogs == {"default_catalog", "hive", "iceberg"}
+
+
+@pytest.mark.parametrize(
+    "uri,catalog,schema,expected_database",
+    [
+        # Test with catalog and schema/db in URI
+        ("starrocks://host/hive.sales", None, None, "hive.sales"),
+        # Test overriding catalog
+        ("starrocks://host/hive.sales", "iceberg", None, "iceberg."),
+        # Test overriding schema/db
+        ("starrocks://host/hive.sales", None, "marketing", "hive.marketing"),
+        # Test overriding both
+        ("starrocks://host/hive.sales", "iceberg", "marketing", "iceberg.marketing"),
+        # Test with only catalog in URI (no schema/db), add new schema
+        ("starrocks://host/hive", None, "marketing", "hive.marketing"),
+        # Test with catalog in URI, override catalog
+        ("starrocks://host/hive", "iceberg", None, "iceberg."),
+        # Test with no catalog/database in URI, overriding catalog
+        ("starrocks://host", "iceberg", None, "iceberg."),
+        # Test with no catalog/database in URI, catalog and schema/db
+        ("starrocks://host", "iceberg", "sales", "iceberg.sales"),
+        # Test with empty database and empty overrides, uses default catalog
+        ("starrocks://host", None, None, 'default_catalog.'),
+        # Test schema only (no catalog) when URI has no database, uses default_catalog
+        ("starrocks://host", None, "sales", "default_catalog.sales"),
+    ],
+)
+def test_adjust_engine_params_with_catalog(
+    uri: str,
+    catalog: Optional[str],
+    schema: Optional[str],
+    expected_database: Optional[str],
+) -> None:
+    """
+    Test the ``adjust_engine_params`` method with catalog parameter.
+    """
+    from superset.db_engine_specs.starrocks import StarRocksEngineSpec
+
+    url = make_url(uri)
+    returned_url, _ = StarRocksEngineSpec.adjust_engine_params(
+        url, {}, catalog=catalog, schema=schema
+    )
+    assert returned_url.database == expected_database

Reply via email to