This is an automated email from the ASF dual-hosted git repository. beto pushed a commit to branch improve-doris-catalog in repository https://gitbox.apache.org/repos/asf/superset.git
commit 841693a55fbddb74d8d8b6b7ede9d8b1bab65600 Author: Beto Dealmeida <robe...@dealmeida.net> AuthorDate: Fri Jul 11 11:53:38 2025 -0400 feat: improve Doris catalog support --- superset/db_engine_specs/doris.py | 41 ++++++++++------ tests/unit_tests/db_engine_specs/test_doris.py | 66 ++++++++++++++++++++++---- 2 files changed, 83 insertions(+), 24 deletions(-) diff --git a/superset/db_engine_specs/doris.py b/superset/db_engine_specs/doris.py index c00dd8040d..5bebbb20ca 100644 --- a/superset/db_engine_specs/doris.py +++ b/superset/db_engine_specs/doris.py @@ -31,6 +31,9 @@ from superset.errors import SupersetErrorType from superset.models.core import Database from superset.utils.core import GenericDataType +DEFAULT_CATALOG = "internal" +DEFAULT_SCHEMA = "information_schema" + # Regular expressions to catch custom errors CONNECTION_ACCESS_DENIED_REGEX = re.compile( "Access denied for user '(?P<username>.*?)'" @@ -248,29 +251,39 @@ class DorisEngineSpec(MySQLEngineSpec): catalog: Optional[str] = None, schema: Optional[str] = None, ) -> tuple[URL, dict[str, Any]]: - if catalog: - pass - elif uri.database and "." in uri.database: - catalog, _ = uri.database.split(".", 1) + if not uri.database: + current_catalog, current_schema = None, None + elif "." not in uri.database: + current_catalog, current_schema = None, uri.database else: - catalog = "internal" + current_catalog, current_schema = uri.database.split(".", 1) + + # and possibly override them + catalog = catalog or current_catalog + schema = schema or current_schema - # In Apache Doris, each catalog has an information_schema for BI tool - # compatibility. See: https://github.com/apache/doris/pull/28919 - schema = schema or "information_schema" - database = ".".join([catalog or "", schema]) + database = ".".join(part for part in (catalog, schema) if part) uri = uri.set(database=database) + return uri, connect_args @classmethod - def get_default_catalog(cls, database: Database) -> Optional[str]: + def get_default_catalog(cls, database: Database) -> str: """ Return the default catalog. """ - if database.url_object.database is None: - return None - - return database.url_object.database.split(".")[0] + # first check the URI to see if a default catalog is set + if database.url_object.database and "." in database.url_object.database: + return database.url_object.database.split(".")[0] + + # if not, iterate over existing catalogs and find the current one + with database.get_sqla_engine() as engine: + for catalog in engine.execute("SHOW CATALOGS"): + if catalog.IsCurrent: + return catalog.CatalogName + + # fallback to "internal" + return DEFAULT_CATALOG @classmethod def get_catalog_names( diff --git a/tests/unit_tests/db_engine_specs/test_doris.py b/tests/unit_tests/db_engine_specs/test_doris.py index d79bc7dcbb..30dfbb3c37 100644 --- a/tests/unit_tests/db_engine_specs/test_doris.py +++ b/tests/unit_tests/db_engine_specs/test_doris.py @@ -19,6 +19,7 @@ from typing import Any, Optional from unittest.mock import Mock import pytest +from pytest_mock import MockerFixture from sqlalchemy import JSON, types from sqlalchemy.engine.url import make_url @@ -81,30 +82,62 @@ def test_get_column_spec( @pytest.mark.parametrize( - "sqlalchemy_uri,connect_args,return_schema,return_connect_args", + "sqlalchemy_uri, connect_args, catalog, schema, return_schema,return_connect_args", [ ( "doris://user:password@host/db1", {"param1": "some_value"}, - "internal.information_schema", + None, + None, + "db1", {"param1": "some_value"}, ), ( "pydoris://user:password@host/db1", {"param1": "some_value"}, - "internal.information_schema", + None, + None, + "db1", {"param1": "some_value"}, ), ( "doris://user:password@host/catalog1.db1", {"param1": "some_value"}, - "catalog1.information_schema", + None, + None, + "catalog1.db1", {"param1": "some_value"}, ), ( "pydoris://user:password@host/catalog1.db1", {"param1": "some_value"}, - "catalog1.information_schema", + None, + None, + "catalog1.db1", + {"param1": "some_value"}, + ), + ( + "pydoris://user:password@host/catalog1.db1", + {"param1": "some_value"}, + "catalog2", + None, + "catalog2.db1", + {"param1": "some_value"}, + ), + ( + "pydoris://user:password@host/catalog1.db1", + {"param1": "some_value"}, + None, + "db2", + "catalog1.db2", + {"param1": "some_value"}, + ), + ( + "pydoris://user:password@host/catalog1.db1", + {"param1": "some_value"}, + "catalog2", + "db2", + "catalog2.db2", {"param1": "some_value"}, ), ], @@ -112,6 +145,8 @@ def test_get_column_spec( def test_adjust_engine_params( sqlalchemy_uri: str, connect_args: dict[str, Any], + catalog: str | None, + schema: str | None, return_schema: str, return_connect_args: dict[str, Any], ) -> None: @@ -119,7 +154,10 @@ def test_adjust_engine_params( url = make_url(sqlalchemy_uri) returned_url, returned_connect_args = DorisEngineSpec.adjust_engine_params( - url, connect_args + url, + connect_args, + catalog, + schema, ) assert returned_url.database == return_schema @@ -154,12 +192,14 @@ def test_get_schema_from_engine_params( "database_value,expected_catalog", [ ("catalog1.schema1", "catalog1"), - ("catalog1", "catalog1"), - (None, None), + ("schema1", "catalog2"), + ("", "catalog2"), ], ) def test_get_default_catalog( - database_value: Optional[str], expected_catalog: Optional[str] + mocker: MockerFixture, + database_value: Optional[str], + expected_catalog: Optional[str], ) -> None: """ Test the ``get_default_catalog`` method. @@ -167,8 +207,14 @@ def test_get_default_catalog( from superset.db_engine_specs.doris import DorisEngineSpec from superset.models.core import Database - database = Mock(spec=Database) + database = mocker.MagicMock(spec=Database) database.url_object.database = database_value + rows = [ + mocker.MagicMock(IsCurrent=False, CatalogName="catalog1"), + mocker.MagicMock(IsCurrent=True, CatalogName="catalog2"), + ] + with database.get_sqla_engine() as engine: + engine.execute.return_value = rows assert DorisEngineSpec.get_default_catalog(database) == expected_catalog