This is an automated email from the ASF dual-hosted git repository.

beto pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/superset.git


The following commit(s) were added to refs/heads/master by this push:
     new 68b84acd93 feat: improve Doris catalog support (#34140)
68b84acd93 is described below

commit 68b84acd93a51cad9bbee69aa78584bf0a1d1f33
Author: Beto Dealmeida <robe...@dealmeida.net>
AuthorDate: Mon Jul 14 12:01:08 2025 -0400

    feat: improve Doris catalog support (#34140)
---
 superset/db_engine_specs/doris.py              | 46 ++++++++------
 tests/unit_tests/db_engine_specs/test_doris.py | 83 ++++++++++++++++++++++----
 2 files changed, 101 insertions(+), 28 deletions(-)

diff --git a/superset/db_engine_specs/doris.py b/superset/db_engine_specs/doris.py
index c00dd8040d..9c054ce2a7 100644
--- a/superset/db_engine_specs/doris.py
+++ b/superset/db_engine_specs/doris.py
@@ -31,6 +31,9 @@ from superset.errors import SupersetErrorType
 from superset.models.core import Database
 from superset.utils.core import GenericDataType
 
+DEFAULT_CATALOG = "internal"
+DEFAULT_SCHEMA = "information_schema"
+
 # Regular expressions to catch custom errors
 CONNECTION_ACCESS_DENIED_REGEX = re.compile(
     "Access denied for user '(?P<username>.*?)'"
@@ -248,29 +251,39 @@ class DorisEngineSpec(MySQLEngineSpec):
         catalog: Optional[str] = None,
         schema: Optional[str] = None,
     ) -> tuple[URL, dict[str, Any]]:
-        if catalog:
-            pass
-        elif uri.database and "." in uri.database:
-            catalog, _ = uri.database.split(".", 1)
+        if not uri.database:
+            raise ValueError("Doris requires a database to be specified in the URI.")
+        elif "." not in uri.database:
+            current_catalog, current_schema = None, uri.database
         else:
-            catalog = "internal"
+            current_catalog, current_schema = uri.database.split(".", 1)
+
+        # and possibly override them
+        catalog = catalog or current_catalog
+        schema = schema or current_schema
 
-        # In Apache Doris, each catalog has an information_schema for BI tool
-        # compatibility. See: https://github.com/apache/doris/pull/28919
-        schema = schema or "information_schema"
-        database = ".".join([catalog or "", schema])
+        database = ".".join(part for part in (catalog, schema) if part)
         uri = uri.set(database=database)
+
         return uri, connect_args
 
     @classmethod
-    def get_default_catalog(cls, database: Database) -> Optional[str]:
+    def get_default_catalog(cls, database: Database) -> str:
         """
         Return the default catalog.
         """
-        if database.url_object.database is None:
-            return None
+        # first check the URI to see if a default catalog is set
+        if database.url_object.database and "." in database.url_object.database:
+            return database.url_object.database.split(".")[0]
+
+        # if not, iterate over existing catalogs and find the current one
+        with database.get_sqla_engine() as engine:
+            for catalog in engine.execute("SHOW CATALOGS"):
+                if catalog.IsCurrent:
+                    return catalog.CatalogName
 
-        return database.url_object.database.split(".")[0]
+        # fallback to "internal"
+        return DEFAULT_CATALOG
 
     @classmethod
     def get_catalog_names(
@@ -301,9 +314,8 @@ class DorisEngineSpec(MySQLEngineSpec):
             doris://localhost:9030/catalog.database
 
         """
-        database = sqlalchemy_uri.database.strip("/")
-
-        if "." not in database:
+        if not sqlalchemy_uri.database:
             return None
 
-        return parse.unquote(database.split(".")[1])
+        schema = sqlalchemy_uri.database.split(".")[-1].strip("/")
+        return parse.unquote(schema)
diff --git a/tests/unit_tests/db_engine_specs/test_doris.py b/tests/unit_tests/db_engine_specs/test_doris.py
index d79bc7dcbb..ed68b4e8a3 100644
--- a/tests/unit_tests/db_engine_specs/test_doris.py
+++ b/tests/unit_tests/db_engine_specs/test_doris.py
@@ -19,6 +19,7 @@ from typing import Any, Optional
 from unittest.mock import Mock
 
 import pytest
+from pytest_mock import MockerFixture
 from sqlalchemy import JSON, types
 from sqlalchemy.engine.url import make_url
 
@@ -81,30 +82,62 @@ def test_get_column_spec(
 
 
 @pytest.mark.parametrize(
-    "sqlalchemy_uri,connect_args,return_schema,return_connect_args",
+    "sqlalchemy_uri, connect_args, catalog, schema, return_schema,return_connect_args",
     [
         (
             "doris://user:password@host/db1",
             {"param1": "some_value"},
-            "internal.information_schema",
+            None,
+            None,
+            "db1",
             {"param1": "some_value"},
         ),
         (
             "pydoris://user:password@host/db1",
             {"param1": "some_value"},
-            "internal.information_schema",
+            None,
+            None,
+            "db1",
             {"param1": "some_value"},
         ),
         (
             "doris://user:password@host/catalog1.db1",
             {"param1": "some_value"},
-            "catalog1.information_schema",
+            None,
+            None,
+            "catalog1.db1",
             {"param1": "some_value"},
         ),
         (
             "pydoris://user:password@host/catalog1.db1",
             {"param1": "some_value"},
-            "catalog1.information_schema",
+            None,
+            None,
+            "catalog1.db1",
+            {"param1": "some_value"},
+        ),
+        (
+            "pydoris://user:password@host/catalog1.db1",
+            {"param1": "some_value"},
+            "catalog2",
+            None,
+            "catalog2.db1",
+            {"param1": "some_value"},
+        ),
+        (
+            "pydoris://user:password@host/catalog1.db1",
+            {"param1": "some_value"},
+            None,
+            "db2",
+            "catalog1.db2",
+            {"param1": "some_value"},
+        ),
+        (
+            "pydoris://user:password@host/catalog1.db1",
+            {"param1": "some_value"},
+            "catalog2",
+            "db2",
+            "catalog2.db2",
             {"param1": "some_value"},
         ),
     ],
@@ -112,6 +145,8 @@ def test_get_column_spec(
 def test_adjust_engine_params(
     sqlalchemy_uri: str,
     connect_args: dict[str, Any],
+    catalog: str | None,
+    schema: str | None,
     return_schema: str,
     return_connect_args: dict[str, Any],
 ) -> None:
@@ -119,18 +154,36 @@ def test_adjust_engine_params(
 
     url = make_url(sqlalchemy_uri)
     returned_url, returned_connect_args = DorisEngineSpec.adjust_engine_params(
-        url, connect_args
+        url,
+        connect_args,
+        catalog,
+        schema,
     )
 
     assert returned_url.database == return_schema
     assert returned_connect_args == return_connect_args
 
 
+def test_adjust_engine_params_no_database() -> None:
+    """
+    Test that we raise an exception when the database is not specified.
+    """
+    from superset.db_engine_specs.doris import DorisEngineSpec
+
+    url = make_url("doris://user:password@host")
+    with pytest.raises(
+        ValueError,
+        match="Doris requires a database to be specified in the URI.",
+    ):
+        DorisEngineSpec.adjust_engine_params(url, {})
+
+
 @pytest.mark.parametrize(
     "url,expected_schema",
     [
         ("doris://localhost:9030/hive.test", "test"),
-        ("doris://localhost:9030/hive", None),
+        ("doris://localhost:9030/test", "test"),
+        ("doris://localhost:9030/", None),
     ],
 )
 def test_get_schema_from_engine_params(
@@ -154,12 +207,14 @@ def test_get_schema_from_engine_params(
     "database_value,expected_catalog",
     [
         ("catalog1.schema1", "catalog1"),
-        ("catalog1", "catalog1"),
-        (None, None),
+        ("schema1", "catalog2"),
+        ("", "catalog2"),
     ],
 )
 def test_get_default_catalog(
-    database_value: Optional[str], expected_catalog: Optional[str]
+    mocker: MockerFixture,
+    database_value: Optional[str],
+    expected_catalog: Optional[str],
 ) -> None:
     """
     Test the ``get_default_catalog`` method.
@@ -167,8 +222,14 @@ def test_get_default_catalog(
     from superset.db_engine_specs.doris import DorisEngineSpec
     from superset.models.core import Database
 
-    database = Mock(spec=Database)
+    database = mocker.MagicMock(spec=Database)
     database.url_object.database = database_value
+    rows = [
+        mocker.MagicMock(IsCurrent=False, CatalogName="catalog1"),
+        mocker.MagicMock(IsCurrent=True, CatalogName="catalog2"),
+    ]
+    with database.get_sqla_engine() as engine:
+        engine.execute.return_value = rows
 
     assert DorisEngineSpec.get_default_catalog(database) == expected_catalog
 

Reply via email to