This is an automated email from the ASF dual-hosted git repository. beto pushed a commit to branch trino-wrappers in repository https://gitbox.apache.org/repos/asf/superset.git
commit 993bc79fb0e9ecef3ca640b842715a3c3dd7f8b0 Author: Beto Dealmeida <[email protected]> AuthorDate: Thu Jul 11 17:01:37 2024 -0400 fix: Trino get_columns --- superset/db_engine_specs/base.py | 36 ++++-- superset/db_engine_specs/presto.py | 166 ++++++++++++------------- superset/db_engine_specs/trino.py | 20 ++- tests/unit_tests/db_engine_specs/test_base.py | 23 ++++ tests/unit_tests/db_engine_specs/test_trino.py | 63 +++++++++- 5 files changed, 207 insertions(+), 101 deletions(-) diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index 159c510fe9..abd4c8c895 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -130,7 +130,9 @@ builtin_time_grains: dict[str | None, str] = { } -class TimestampExpression(ColumnClause): # pylint: disable=abstract-method, too-many-ancestors +class TimestampExpression( + ColumnClause +): # pylint: disable=abstract-method, too-many-ancestors def __init__(self, expr: str, col: ColumnClause, **kwargs: Any) -> None: """Sqlalchemy class that can be used to render native column elements respecting engine-specific quoting rules as part of a string-based expression. @@ -388,9 +390,9 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods max_column_name_length: int | None = None try_remove_schema_from_table_name = True # pylint: disable=invalid-name run_multiple_statements_as_one = False - custom_errors: dict[ - Pattern[str], tuple[str, SupersetErrorType, dict[str, Any]] - ] = {} + custom_errors: dict[Pattern[str], tuple[str, SupersetErrorType, dict[str, Any]]] = ( + {} + ) # Whether the engine supports file uploads # if True, database will be listed as option in the upload file form @@ -1653,14 +1655,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods if show_cols: fields = cls._get_fields(cols) - quote = engine.dialect.identifier_preparer.quote - quote_schema = engine.dialect.identifier_preparer.quote_schema - full_table_name = ( - quote_schema(table.schema) + "." + quote(table.table) - if table.schema - else quote(table.table) - ) - + full_table_name = cls.quote_table(table, engine.dialect) qry = select(fields).select_from(text(full_table_name)) if limit and cls.allow_limit_clause: @@ -2224,6 +2219,23 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods return name + @classmethod + def quote_table(cls, table: Table, dialect: Dialect) -> str: + """ + Fully quote a table name, including the schema and catalog. + """ + quoters = { + "catalog": dialect.identifier_preparer.quote_schema, + "schema": dialect.identifier_preparer.quote_schema, + "table": dialect.identifier_preparer.quote, + } + + return ".".join( + function(getattr(table, key)) + for key, function in quoters.items() + if getattr(table, key) + ) + # schema for adding a database by providing parameters instead of the # full SQLAlchemy URI diff --git a/superset/db_engine_specs/presto.py b/superset/db_engine_specs/presto.py index fbd0eff484..d5db0fdfaa 100644 --- a/superset/db_engine_specs/presto.py +++ b/superset/db_engine_specs/presto.py @@ -672,6 +672,88 @@ class PrestoBaseEngineSpec(BaseEngineSpec, metaclass=ABCMeta): return "" return df.to_dict()[field_to_return][0] + @classmethod + def _show_columns( + cls, + inspector: Inspector, + table: Table, + ) -> list[ResultRow]: + """ + Show presto column names + :param inspector: object that performs database schema inspection + :param table: table instance + :return: list of column objects + """ + full_table_name = cls.quote_table(table, inspector.engine.dialect) + return inspector.bind.execute(f"SHOW COLUMNS FROM {full_table_name}").fetchall() + + @classmethod + def _create_column_info( + cls, name: str, data_type: types.TypeEngine + ) -> ResultSetColumnType: + """ + Create column info object + :param name: column name + :param data_type: column data type + :return: column info object + """ + return { + "column_name": name, + "name": name, + "type": f"{data_type}", + "is_dttm": None, + "type_generic": None, + } + + @classmethod + def get_columns( + cls, + inspector: Inspector, + table: Table, + options: dict[str, Any] | None = None, + ) -> list[ResultSetColumnType]: + """ + Get columns from a Presto data source. This includes handling row and + array data types + :param inspector: object that performs database schema inspection + :param table: table instance + :param options: Extra configuration options, not used by this backend + :return: a list of results that contain column info + (i.e. column name and data type) + """ + columns = cls._show_columns(inspector, table) + result: list[ResultSetColumnType] = [] + for column in columns: + # parse column if it is a row or array + if is_feature_enabled("PRESTO_EXPAND_DATA") and ( + "array" in column.Type or "row" in column.Type + ): + structural_column_index = len(result) + cls._parse_structural_column(column.Column, column.Type, result) + result[structural_column_index]["nullable"] = getattr( + column, "Null", True + ) + result[structural_column_index]["default"] = None + continue + + # otherwise column is a basic data type + column_spec = cls.get_column_spec(column.Type) + column_type = column_spec.sqla_type if column_spec else None + if column_type is None: + column_type = types.String() + logger.info( + "Did not recognize type %s of column %s", + str(column.Type), + str(column.Column), + ) + column_info = cls._create_column_info(column.Column, column_type) + column_info["nullable"] = getattr(column, "Null", True) + column_info["default"] = None + column_info["column_name"] = column.Column + result.append(column_info) + + return result + class PrestoEngineSpec(PrestoBaseEngineSpec): engine = "presto" @@ -840,24 +922,6 @@ class PrestoEngineSpec(PrestoBaseEngineSpec): results = cursor.fetchall() return {row[0] for row in results} - @classmethod - def _create_column_info( - cls, name: str, data_type: types.TypeEngine - ) -> ResultSetColumnType: - """ - Create column info object - :param name: column name - :param data_type: column data type - :return: column info object - """ - return { - "column_name": name, - "name": name, - "type": f"{data_type}", - "is_dttm": None, - "type_generic": None, - } - @classmethod def _get_full_name(cls, names: list[tuple[str, str]]) -> str: """ @@ -979,72 +1043,6 @@ class PrestoEngineSpec(PrestoBaseEngineSpec): formatted_parent_column_name, parent_column_name ) - @classmethod - def _show_columns( - cls, - inspector: Inspector, - table: Table, - ) -> list[ResultRow]: - """ - Show presto column names - :param inspector: object that performs database schema inspection - :param table: table instance - :return: list of column objects - """ - quote = inspector.engine.dialect.identifier_preparer.quote_identifier - full_table = quote(table.table) - if table.schema: - full_table = f"{quote(table.schema)}.{full_table}" - return inspector.bind.execute(f"SHOW COLUMNS FROM {full_table}").fetchall() - - @classmethod - def get_columns( - cls, - inspector: Inspector, - table: Table, - options: dict[str, Any] | None = None, - ) -> list[ResultSetColumnType]: - """ - Get columns from a Presto data source. This includes handling row and - array data types - :param inspector: object that performs database schema inspection - :param table: table instance - :param options: Extra configuration options, not used by this backend - :return: a list of results that contain column info - (i.e. column name and data type) - """ - columns = cls._show_columns(inspector, table) - result: list[ResultSetColumnType] = [] - for column in columns: - # parse column if it is a row or array - if is_feature_enabled("PRESTO_EXPAND_DATA") and ( - "array" in column.Type or "row" in column.Type - ): - structural_column_index = len(result) - cls._parse_structural_column(column.Column, column.Type, result) - result[structural_column_index]["nullable"] = getattr( - column, "Null", True - ) - result[structural_column_index]["default"] = None - continue - - # otherwise column is a basic data type - column_spec = cls.get_column_spec(column.Type) - column_type = column_spec.sqla_type if column_spec else None - if column_type is None: - column_type = types.String() - logger.info( - "Did not recognize type %s of column %s", - str(column.Type), - str(column.Column), - ) - column_info = cls._create_column_info(column.Column, column_type) - column_info["nullable"] = getattr(column, "Null", True) - column_info["default"] = None - column_info["column_name"] = column.Column - result.append(column_info) - return result - @classmethod def _is_column_name_quoted(cls, column_name: str) -> bool: """ diff --git a/superset/db_engine_specs/trino.py b/superset/db_engine_specs/trino.py index 143276bdc3..dabb89ed8a 100644 --- a/superset/db_engine_specs/trino.py +++ b/superset/db_engine_specs/trino.py @@ -36,7 +36,7 @@ from sqlalchemy.exc import NoSuchTableError from superset import db from superset.constants import QUERY_CANCEL_KEY, QUERY_EARLY_CANCEL_KEY, USER_AGENT from superset.databases.utils import make_url_safe -from superset.db_engine_specs.base import BaseEngineSpec +from superset.db_engine_specs.base import BaseEngineSpec, convert_inspector_columns from superset.db_engine_specs.exceptions import ( SupersetDBAPIConnectionError, SupersetDBAPIDatabaseError, @@ -241,7 +241,11 @@ class TrinoEngineSpec(PrestoBaseEngineSpec): execute_thread = threading.Thread( target=_execute, - args=(execute_result, execute_event, current_app._get_current_object()), # pylint: disable=protected-access + args=( + execute_result, + execute_event, + current_app._get_current_object(), + ), # pylint: disable=protected-access ) execute_thread.start() @@ -433,7 +437,17 @@ class TrinoEngineSpec(PrestoBaseEngineSpec): "schema_options", expand the schema definition out to show all subfields of nested ROWs as their appropriate dotted paths. """ - base_cols = super().get_columns(inspector, table, options) + # The Trino dialect raises `NoSuchTableError` on the inspection methods when the + # table is empty. We can work around this by running a `SHOW COLUMNS FROM` query + # when that happens, using the method from the Presto base engine spec. + try: + # `SELECT * FROM information_schema.columns WHERE ...` + sqla_columns = inspector.get_columns(table.table, table.schema) + base_cols = convert_inspector_columns(sqla_columns) + except NoSuchTableError: + # `SHOW COLUMNS FROM ...` + base_cols = super().get_columns(inspector, table, options) + if not (options or {}).get("expand_rows"): return base_cols diff --git a/tests/unit_tests/db_engine_specs/test_base.py b/tests/unit_tests/db_engine_specs/test_base.py index a3920c8916..9ec1ebaf00 100644 --- a/tests/unit_tests/db_engine_specs/test_base.py +++ b/tests/unit_tests/db_engine_specs/test_base.py @@ -311,3 +311,26 @@ def test_get_default_catalog(mocker: MockerFixture) -> None: database = mocker.MagicMock() assert BaseEngineSpec.get_default_catalog(database) is None + + +def test_quote_table() -> None: + """ + Test the `quote_table` function. + """ + from superset.db_engine_specs.base import BaseEngineSpec + + dialect = sqlite.dialect() + + assert BaseEngineSpec.quote_table(Table("table"), dialect) == '"table"' + assert ( + BaseEngineSpec.quote_table(Table("table", "schema"), dialect) + == 'schema."table"' + ) + assert ( + BaseEngineSpec.quote_table(Table("table", "schema", "catalog"), dialect) + == 'catalog.schema."table"' + ) + assert ( + BaseEngineSpec.quote_table(Table("ta ble", "sche.ma", 'cata"log'), dialect) + == '"cata""log"."sche.ma"."ta ble"' + ) diff --git a/tests/unit_tests/db_engine_specs/test_trino.py b/tests/unit_tests/db_engine_specs/test_trino.py index 3a2ac91ad6..a0923e8111 100644 --- a/tests/unit_tests/db_engine_specs/test_trino.py +++ b/tests/unit_tests/db_engine_specs/test_trino.py @@ -16,6 +16,7 @@ # under the License. # pylint: disable=unused-argument, import-outside-toplevel, protected-access import copy +from collections import namedtuple from datetime import datetime from typing import Any, Optional from unittest.mock import MagicMock, Mock, patch @@ -25,7 +26,9 @@ import pytest from pytest_mock import MockerFixture from requests.exceptions import ConnectionError as RequestsConnectionError from sqlalchemy import sql, text, types +from sqlalchemy.dialects import sqlite from sqlalchemy.engine.url import make_url +from sqlalchemy.exc import NoSuchTableError from trino.exceptions import TrinoExternalError, TrinoInternalError, TrinoUserError from trino.sqlalchemy import datatype from trino.sqlalchemy.dialect import TrinoDialect @@ -464,6 +467,64 @@ def test_get_columns(mocker: MockerFixture): _assert_columns_equal(actual, expected) +def test_get_columns_error(mocker: MockerFixture): + """ + Test that we fallback to a `SHOW COLUMNS FROM ...` query. + """ + from superset.db_engine_specs.trino import TrinoEngineSpec + + field1_type = datatype.parse_sqltype("row(a varchar, b date)") + field2_type = datatype.parse_sqltype("row(r1 row(a varchar, b varchar))") + field3_type = datatype.parse_sqltype("int") + + mock_inspector = mocker.MagicMock() + mock_inspector.engine.dialect = sqlite.dialect() + mock_inspector.get_columns.side_effect = NoSuchTableError( + "The specified table does not exist." + ) + Row = namedtuple("Row", ["Column", "Type"]) + mock_inspector.bind.execute().fetchall.return_value = [ + Row("field1", "row(a varchar, b date)"), + Row("field2", "row(r1 row(a varchar, b varchar))"), + Row("field3", "int"), + ] + + actual = TrinoEngineSpec.get_columns(mock_inspector, Table("table", "schema")) + expected = [ + ResultSetColumnType( + name="field1", + column_name="field1", + type=field1_type, + is_dttm=None, + type_generic=None, + default=None, + nullable=True, + ), + ResultSetColumnType( + name="field2", + column_name="field2", + type=field2_type, + is_dttm=None, + type_generic=None, + default=None, + nullable=True, + ), + ResultSetColumnType( + name="field3", + column_name="field3", + type=field3_type, + is_dttm=None, + type_generic=None, + default=None, + nullable=True, + ), + ] + + _assert_columns_equal(actual, expected) + + mock_inspector.bind.execute.assert_called_with('SHOW COLUMNS FROM schema."table"') + + def test_get_columns_expand_rows(mocker: MockerFixture): """Test that ROW columns are correctly expanded with expand_rows""" from superset.db_engine_specs.trino import TrinoEngineSpec @@ -536,8 +597,6 @@ def test_get_columns_expand_rows(mocker: MockerFixture): def test_get_indexes_no_table(): - from sqlalchemy.exc import NoSuchTableError - from superset.db_engine_specs.trino import TrinoEngineSpec db_mock = Mock()
