This is an automated email from the ASF dual-hosted git repository.

beto pushed a commit to branch trino-wrappers
in repository https://gitbox.apache.org/repos/asf/superset.git

commit 993bc79fb0e9ecef3ca640b842715a3c3dd7f8b0
Author: Beto Dealmeida <[email protected]>
AuthorDate: Thu Jul 11 17:01:37 2024 -0400

    fix: Trino get_columns
---
 superset/db_engine_specs/base.py               |  36 ++++--
 superset/db_engine_specs/presto.py             | 166 ++++++++++++-------------
 superset/db_engine_specs/trino.py              |  20 ++-
 tests/unit_tests/db_engine_specs/test_base.py  |  23 ++++
 tests/unit_tests/db_engine_specs/test_trino.py |  63 +++++++++-
 5 files changed, 207 insertions(+), 101 deletions(-)

diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py
index 159c510fe9..abd4c8c895 100644
--- a/superset/db_engine_specs/base.py
+++ b/superset/db_engine_specs/base.py
@@ -130,7 +130,9 @@ builtin_time_grains: dict[str | None, str] = {
 }
 
 
-class TimestampExpression(ColumnClause):  # pylint: disable=abstract-method, too-many-ancestors
+class TimestampExpression(
+    ColumnClause
+):  # pylint: disable=abstract-method, too-many-ancestors
     def __init__(self, expr: str, col: ColumnClause, **kwargs: Any) -> None:
         """Sqlalchemy class that can be used to render native column elements respecting
         engine-specific quoting rules as part of a string-based expression.
@@ -388,9 +390,9 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
     max_column_name_length: int | None = None
     try_remove_schema_from_table_name = True  # pylint: disable=invalid-name
     run_multiple_statements_as_one = False
-    custom_errors: dict[
-        Pattern[str], tuple[str, SupersetErrorType, dict[str, Any]]
-    ] = {}
+    custom_errors: dict[Pattern[str], tuple[str, SupersetErrorType, dict[str, Any]]] = (
+        {}
+    )
 
     # Whether the engine supports file uploads
     # if True, database will be listed as option in the upload file form
@@ -1653,14 +1655,7 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
         if show_cols:
             fields = cls._get_fields(cols)
 
-        quote = engine.dialect.identifier_preparer.quote
-        quote_schema = engine.dialect.identifier_preparer.quote_schema
-        full_table_name = (
-            quote_schema(table.schema) + "." + quote(table.table)
-            if table.schema
-            else quote(table.table)
-        )
-
+        full_table_name = cls.quote_table(table, engine.dialect)
         qry = select(fields).select_from(text(full_table_name))
 
         if limit and cls.allow_limit_clause:
@@ -2224,6 +2219,23 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
 
         return name
 
+    @classmethod
+    def quote_table(cls, table: Table, dialect: Dialect) -> str:
+        """
+        Fully quote a table name, including the schema and catalog.
+        """
+        quoters = {
+            "catalog": dialect.identifier_preparer.quote_schema,
+            "schema": dialect.identifier_preparer.quote_schema,
+            "table": dialect.identifier_preparer.quote,
+        }
+
+        return ".".join(
+            function(getattr(table, key))
+            for key, function in quoters.items()
+            if getattr(table, key)
+        )
+
 
 # schema for adding a database by providing parameters instead of the
 # full SQLAlchemy URI
diff --git a/superset/db_engine_specs/presto.py b/superset/db_engine_specs/presto.py
index fbd0eff484..d5db0fdfaa 100644
--- a/superset/db_engine_specs/presto.py
+++ b/superset/db_engine_specs/presto.py
@@ -672,6 +672,88 @@ class PrestoBaseEngineSpec(BaseEngineSpec, metaclass=ABCMeta):
             return ""
         return df.to_dict()[field_to_return][0]
 
+    @classmethod
+    def _show_columns(
+        cls,
+        inspector: Inspector,
+        table: Table,
+    ) -> list[ResultRow]:
+        """
+        Show presto column names
+        :param inspector: object that performs database schema inspection
+        :param table: table instance
+        :return: list of column objects
+        """
+        full_table_name = cls.quote_table(table, inspector.engine.dialect)
+        return inspector.bind.execute(f"SHOW COLUMNS FROM {full_table_name}").fetchall()
+
+    @classmethod
+    def _create_column_info(
+        cls, name: str, data_type: types.TypeEngine
+    ) -> ResultSetColumnType:
+        """
+        Create column info object
+        :param name: column name
+        :param data_type: column data type
+        :return: column info object
+        """
+        return {
+            "column_name": name,
+            "name": name,
+            "type": f"{data_type}",
+            "is_dttm": None,
+            "type_generic": None,
+        }
+
+    @classmethod
+    def get_columns(
+        cls,
+        inspector: Inspector,
+        table: Table,
+        options: dict[str, Any] | None = None,
+    ) -> list[ResultSetColumnType]:
+        """
+        Get columns from a Presto data source. This includes handling row and
+        array data types
+        :param inspector: object that performs database schema inspection
+        :param table: table instance
+        :param options: Extra configuration options, not used by this backend
+        :return: a list of results that contain column info
+                (i.e. column name and data type)
+        """
+        columns = cls._show_columns(inspector, table)
+        result: list[ResultSetColumnType] = []
+        for column in columns:
+            # parse column if it is a row or array
+            if is_feature_enabled("PRESTO_EXPAND_DATA") and (
+                "array" in column.Type or "row" in column.Type
+            ):
+                structural_column_index = len(result)
+                cls._parse_structural_column(column.Column, column.Type, result)
+                result[structural_column_index]["nullable"] = getattr(
+                    column, "Null", True
+                )
+                result[structural_column_index]["default"] = None
+                continue
+
+            # otherwise column is a basic data type
+            column_spec = cls.get_column_spec(column.Type)
+            column_type = column_spec.sqla_type if column_spec else None
+            if column_type is None:
+                column_type = types.String()
+                logger.info(
+                    "Did not recognize type %s of column %s",
+                    str(column.Type),
+                    str(column.Column),
+                )
+            column_info = cls._create_column_info(column.Column, column_type)
+            column_info["nullable"] = getattr(column, "Null", True)
+            column_info["default"] = None
+            column_info["column_name"] = column.Column
+            result.append(column_info)
+
+        return result
+
 
 class PrestoEngineSpec(PrestoBaseEngineSpec):
     engine = "presto"
@@ -840,24 +922,6 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
             results = cursor.fetchall()
             return {row[0] for row in results}
 
-    @classmethod
-    def _create_column_info(
-        cls, name: str, data_type: types.TypeEngine
-    ) -> ResultSetColumnType:
-        """
-        Create column info object
-        :param name: column name
-        :param data_type: column data type
-        :return: column info object
-        """
-        return {
-            "column_name": name,
-            "name": name,
-            "type": f"{data_type}",
-            "is_dttm": None,
-            "type_generic": None,
-        }
-
     @classmethod
     def _get_full_name(cls, names: list[tuple[str, str]]) -> str:
         """
@@ -979,72 +1043,6 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
                     formatted_parent_column_name, parent_column_name
                 )
 
-    @classmethod
-    def _show_columns(
-        cls,
-        inspector: Inspector,
-        table: Table,
-    ) -> list[ResultRow]:
-        """
-        Show presto column names
-        :param inspector: object that performs database schema inspection
-        :param table: table instance
-        :return: list of column objects
-        """
-        quote = inspector.engine.dialect.identifier_preparer.quote_identifier
-        full_table = quote(table.table)
-        if table.schema:
-            full_table = f"{quote(table.schema)}.{full_table}"
-        return inspector.bind.execute(f"SHOW COLUMNS FROM {full_table}").fetchall()
-
-    @classmethod
-    def get_columns(
-        cls,
-        inspector: Inspector,
-        table: Table,
-        options: dict[str, Any] | None = None,
-    ) -> list[ResultSetColumnType]:
-        """
-        Get columns from a Presto data source. This includes handling row and
-        array data types
-        :param inspector: object that performs database schema inspection
-        :param table: table instance
-        :param options: Extra configuration options, not used by this backend
-        :return: a list of results that contain column info
-                (i.e. column name and data type)
-        """
-        columns = cls._show_columns(inspector, table)
-        result: list[ResultSetColumnType] = []
-        for column in columns:
-            # parse column if it is a row or array
-            if is_feature_enabled("PRESTO_EXPAND_DATA") and (
-                "array" in column.Type or "row" in column.Type
-            ):
-                structural_column_index = len(result)
-                cls._parse_structural_column(column.Column, column.Type, result)
-                result[structural_column_index]["nullable"] = getattr(
-                    column, "Null", True
-                )
-                result[structural_column_index]["default"] = None
-                continue
-
-            # otherwise column is a basic data type
-            column_spec = cls.get_column_spec(column.Type)
-            column_type = column_spec.sqla_type if column_spec else None
-            if column_type is None:
-                column_type = types.String()
-                logger.info(
-                    "Did not recognize type %s of column %s",
-                    str(column.Type),
-                    str(column.Column),
-                )
-            column_info = cls._create_column_info(column.Column, column_type)
-            column_info["nullable"] = getattr(column, "Null", True)
-            column_info["default"] = None
-            column_info["column_name"] = column.Column
-            result.append(column_info)
-        return result
-
     @classmethod
     def _is_column_name_quoted(cls, column_name: str) -> bool:
         """
diff --git a/superset/db_engine_specs/trino.py b/superset/db_engine_specs/trino.py
index 143276bdc3..dabb89ed8a 100644
--- a/superset/db_engine_specs/trino.py
+++ b/superset/db_engine_specs/trino.py
@@ -36,7 +36,7 @@ from sqlalchemy.exc import NoSuchTableError
 from superset import db
 from superset.constants import QUERY_CANCEL_KEY, QUERY_EARLY_CANCEL_KEY, USER_AGENT
 from superset.databases.utils import make_url_safe
-from superset.db_engine_specs.base import BaseEngineSpec
+from superset.db_engine_specs.base import BaseEngineSpec, convert_inspector_columns
 from superset.db_engine_specs.exceptions import (
     SupersetDBAPIConnectionError,
     SupersetDBAPIDatabaseError,
@@ -241,7 +241,11 @@ class TrinoEngineSpec(PrestoBaseEngineSpec):
 
         execute_thread = threading.Thread(
             target=_execute,
-            args=(execute_result, execute_event, current_app._get_current_object()),  # pylint: disable=protected-access
+            args=(
+                execute_result,
+                execute_event,
+                current_app._get_current_object(),
+            ),  # pylint: disable=protected-access
         )
         execute_thread.start()
 
@@ -433,7 +437,17 @@ class TrinoEngineSpec(PrestoBaseEngineSpec):
         "schema_options", expand the schema definition out to show all
         subfields of nested ROWs as their appropriate dotted paths.
         """
-        base_cols = super().get_columns(inspector, table, options)
+        # The Trino dialect raises `NoSuchTableError` on the inspection methods when the
+        # table is empty. We can work around this by running a `SHOW COLUMNS FROM` query
+        # when that happens, using the method from the Presto base engine spec.
+        try:
+            # `SELECT * FROM information_schema.columns WHERE ...`
+            sqla_columns = inspector.get_columns(table.table, table.schema)
+            base_cols = convert_inspector_columns(sqla_columns)
+        except NoSuchTableError:
+            # `SHOW COLUMNS FROM ...`
+            base_cols = super().get_columns(inspector, table, options)
+
         if not (options or {}).get("expand_rows"):
             return base_cols
 
diff --git a/tests/unit_tests/db_engine_specs/test_base.py b/tests/unit_tests/db_engine_specs/test_base.py
index a3920c8916..9ec1ebaf00 100644
--- a/tests/unit_tests/db_engine_specs/test_base.py
+++ b/tests/unit_tests/db_engine_specs/test_base.py
@@ -311,3 +311,26 @@ def test_get_default_catalog(mocker: MockerFixture) -> None:
 
     database = mocker.MagicMock()
     assert BaseEngineSpec.get_default_catalog(database) is None
+
+
+def test_quote_table() -> None:
+    """
+    Test the `quote_table` function.
+    """
+    from superset.db_engine_specs.base import BaseEngineSpec
+
+    dialect = sqlite.dialect()
+
+    assert BaseEngineSpec.quote_table(Table("table"), dialect) == '"table"'
+    assert (
+        BaseEngineSpec.quote_table(Table("table", "schema"), dialect)
+        == 'schema."table"'
+    )
+    assert (
+        BaseEngineSpec.quote_table(Table("table", "schema", "catalog"), dialect)
+        == 'catalog.schema."table"'
+    )
+    assert (
+        BaseEngineSpec.quote_table(Table("ta ble", "sche.ma", 'cata"log'), dialect)
+        == '"cata""log"."sche.ma"."ta ble"'
+    )
diff --git a/tests/unit_tests/db_engine_specs/test_trino.py b/tests/unit_tests/db_engine_specs/test_trino.py
index 3a2ac91ad6..a0923e8111 100644
--- a/tests/unit_tests/db_engine_specs/test_trino.py
+++ b/tests/unit_tests/db_engine_specs/test_trino.py
@@ -16,6 +16,7 @@
 # under the License.
 # pylint: disable=unused-argument, import-outside-toplevel, protected-access
 import copy
+from collections import namedtuple
 from datetime import datetime
 from typing import Any, Optional
 from unittest.mock import MagicMock, Mock, patch
@@ -25,7 +26,9 @@ import pytest
 from pytest_mock import MockerFixture
 from requests.exceptions import ConnectionError as RequestsConnectionError
 from sqlalchemy import sql, text, types
+from sqlalchemy.dialects import sqlite
 from sqlalchemy.engine.url import make_url
+from sqlalchemy.exc import NoSuchTableError
 from trino.exceptions import TrinoExternalError, TrinoInternalError, TrinoUserError
 from trino.sqlalchemy import datatype
 from trino.sqlalchemy.dialect import TrinoDialect
@@ -464,6 +467,64 @@ def test_get_columns(mocker: MockerFixture):
     _assert_columns_equal(actual, expected)
 
 
+def test_get_columns_error(mocker: MockerFixture):
+    """
+    Test that we fallback to a `SHOW COLUMNS FROM ...` query.
+    """
+    from superset.db_engine_specs.trino import TrinoEngineSpec
+
+    field1_type = datatype.parse_sqltype("row(a varchar, b date)")
+    field2_type = datatype.parse_sqltype("row(r1 row(a varchar, b varchar))")
+    field3_type = datatype.parse_sqltype("int")
+
+    mock_inspector = mocker.MagicMock()
+    mock_inspector.engine.dialect = sqlite.dialect()
+    mock_inspector.get_columns.side_effect = NoSuchTableError(
+        "The specified table does not exist."
+    )
+    Row = namedtuple("Row", ["Column", "Type"])
+    mock_inspector.bind.execute().fetchall.return_value = [
+        Row("field1", "row(a varchar, b date)"),
+        Row("field2", "row(r1 row(a varchar, b varchar))"),
+        Row("field3", "int"),
+    ]
+
+    actual = TrinoEngineSpec.get_columns(mock_inspector, Table("table", "schema"))
+    expected = [
+        ResultSetColumnType(
+            name="field1",
+            column_name="field1",
+            type=field1_type,
+            is_dttm=None,
+            type_generic=None,
+            default=None,
+            nullable=True,
+        ),
+        ResultSetColumnType(
+            name="field2",
+            column_name="field2",
+            type=field2_type,
+            is_dttm=None,
+            type_generic=None,
+            default=None,
+            nullable=True,
+        ),
+        ResultSetColumnType(
+            name="field3",
+            column_name="field3",
+            type=field3_type,
+            is_dttm=None,
+            type_generic=None,
+            default=None,
+            nullable=True,
+        ),
+    ]
+
+    _assert_columns_equal(actual, expected)
+
+    mock_inspector.bind.execute.assert_called_with('SHOW COLUMNS FROM schema."table"')
+
+
 def test_get_columns_expand_rows(mocker: MockerFixture):
     """Test that ROW columns are correctly expanded with expand_rows"""
     from superset.db_engine_specs.trino import TrinoEngineSpec
@@ -536,8 +597,6 @@ def test_get_columns_expand_rows(mocker: MockerFixture):
 
 
 def test_get_indexes_no_table():
-    from sqlalchemy.exc import NoSuchTableError
-
     from superset.db_engine_specs.trino import TrinoEngineSpec
 
     db_mock = Mock()

Reply via email to