This is an automated email from the ASF dual-hosted git repository.

beto pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/superset.git


The following commit(s) were added to refs/heads/master by this push:
     new fa095a98ed fix: Trino `get_columns` (#29566)
fa095a98ed is described below

commit fa095a98ed833e028cf051a8cb6854f1fab7c801
Author: Beto Dealmeida <[email protected]>
AuthorDate: Fri Jul 12 16:37:49 2024 -0400

    fix: Trino `get_columns` (#29566)
---
 superset/db_engine_specs/base.py                   |  28 +-
 superset/db_engine_specs/couchbasedb.py            |   1 -
 superset/db_engine_specs/presto.py                 | 408 ++++++++++-----------
 superset/db_engine_specs/trino.py                  |  23 +-
 .../db_engine_specs/presto_tests.py                |  11 +-
 tests/unit_tests/db_engine_specs/test_base.py      |  23 ++
 tests/unit_tests/db_engine_specs/test_trino.py     |  63 +++-
 7 files changed, 331 insertions(+), 226 deletions(-)

diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py
index 159c510fe9..1329597f02 100644
--- a/superset/db_engine_specs/base.py
+++ b/superset/db_engine_specs/base.py
@@ -1618,7 +1618,7 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
         ]
 
     @classmethod
-    def select_star(  # pylint: disable=too-many-arguments,too-many-locals
+    def select_star(  # pylint: disable=too-many-arguments
         cls,
         database: Database,
         table: Table,
@@ -1653,14 +1653,7 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
         if show_cols:
             fields = cls._get_fields(cols)
 
-        quote = engine.dialect.identifier_preparer.quote
-        quote_schema = engine.dialect.identifier_preparer.quote_schema
-        full_table_name = (
-            quote_schema(table.schema) + "." + quote(table.table)
-            if table.schema
-            else quote(table.table)
-        )
-
+        full_table_name = cls.quote_table(table, engine.dialect)
         qry = select(fields).select_from(text(full_table_name))
 
         if limit and cls.allow_limit_clause:
@@ -2224,6 +2217,23 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
 
         return name
 
+    @classmethod
+    def quote_table(cls, table: Table, dialect: Dialect) -> str:
+        """
+        Fully quote a table name, including the schema and catalog.
+        """
+        quoters = {
+            "catalog": dialect.identifier_preparer.quote_schema,
+            "schema": dialect.identifier_preparer.quote_schema,
+            "table": dialect.identifier_preparer.quote,
+        }
+
+        return ".".join(
+            function(getattr(table, key))
+            for key, function in quoters.items()
+            if getattr(table, key)
+        )
+
 
 # schema for adding a database by providing parameters instead of the
 # full SQLAlchemy URI
diff --git a/superset/db_engine_specs/couchbasedb.py b/superset/db_engine_specs/couchbasedb.py
index b9cebdba32..71dc727679 100644
--- a/superset/db_engine_specs/couchbasedb.py
+++ b/superset/db_engine_specs/couchbasedb.py
@@ -14,7 +14,6 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-# pylint: disable=too-many-lines
 
 from __future__ import annotations
 
diff --git a/superset/db_engine_specs/presto.py b/superset/db_engine_specs/presto.py
index fbd0eff484..5a375896c1 100644
--- a/superset/db_engine_specs/presto.py
+++ b/superset/db_engine_specs/presto.py
@@ -672,6 +672,209 @@ class PrestoBaseEngineSpec(BaseEngineSpec, metaclass=ABCMeta):
             return ""
         return df.to_dict()[field_to_return][0]
 
+    @classmethod
+    def _show_columns(
+        cls,
+        inspector: Inspector,
+        table: Table,
+    ) -> list[ResultRow]:
+        """
+        Show presto column names
+        :param inspector: object that performs database schema inspection
+        :param table: table instance
+        :return: list of column objects
+        """
+        full_table_name = cls.quote_table(table, inspector.engine.dialect)
+        return inspector.bind.execute(f"SHOW COLUMNS FROM {full_table_name}").fetchall()
+
+    @classmethod
+    def _create_column_info(
+        cls, name: str, data_type: types.TypeEngine
+    ) -> ResultSetColumnType:
+        """
+        Create column info object
+        :param name: column name
+        :param data_type: column data type
+        :return: column info object
+        """
+        return {
+            "column_name": name,
+            "name": name,
+            "type": f"{data_type}",
+            "is_dttm": None,
+            "type_generic": None,
+        }
+
+    @classmethod
+    def get_columns(
+        cls,
+        inspector: Inspector,
+        table: Table,
+        options: dict[str, Any] | None = None,
+    ) -> list[ResultSetColumnType]:
+        """
+        Get columns from a Presto data source. This includes handling row and
+        array data types
+        :param inspector: object that performs database schema inspection
+        :param table: table instance
+        :param options: Extra configuration options, not used by this backend
+        :return: a list of results that contain column info
+                (i.e. column name and data type)
+        """
+        columns = cls._show_columns(inspector, table)
+        result: list[ResultSetColumnType] = []
+        for column in columns:
+            # parse column if it is a row or array
+            if is_feature_enabled("PRESTO_EXPAND_DATA") and (
+                "array" in column.Type or "row" in column.Type
+            ):
+                structural_column_index = len(result)
+                cls._parse_structural_column(column.Column, column.Type, result)
+                result[structural_column_index]["nullable"] = getattr(
+                    column, "Null", True
+                )
+                result[structural_column_index]["default"] = None
+                continue
+
+            # otherwise column is a basic data type
+            column_spec = cls.get_column_spec(column.Type)
+            column_type = column_spec.sqla_type if column_spec else None
+            if column_type is None:
+                column_type = types.String()
+                logger.info(
+                    "Did not recognize type %s of column %s",
+                    str(column.Type),
+                    str(column.Column),
+                )
+            column_info = cls._create_column_info(column.Column, column_type)
+            column_info["nullable"] = getattr(column, "Null", True)
+            column_info["default"] = None
+            column_info["column_name"] = column.Column
+            result.append(column_info)
+
+        return result
+
+    @classmethod
+    def _parse_structural_column(  # pylint: disable=too-many-locals
+        cls,
+        parent_column_name: str,
+        parent_data_type: str,
+        result: list[ResultSetColumnType],
+    ) -> None:
+        """
+        Parse a row or array column
+        :param result: list tracking the results
+        """
+        formatted_parent_column_name = parent_column_name
+        # Quote the column name if there is a space
+        if " " in parent_column_name:
+            formatted_parent_column_name = f'"{parent_column_name}"'
+        full_data_type = f"{formatted_parent_column_name} {parent_data_type}"
+        original_result_len = len(result)
+        # split on open parenthesis ( to get the structural
+        # data type and its component types
+        data_types = cls._split_data_type(full_data_type, r"\(")
+        stack: list[tuple[str, str]] = []
+        for data_type in data_types:
+            # split on closed parenthesis ) to track which component
+            # types belong to what structural data type
+            inner_types = cls._split_data_type(data_type, r"\)")
+            for inner_type in inner_types:
+                # We have finished parsing multiple structural data types
+                if not inner_type and stack:
+                    stack.pop()
+                elif cls._has_nested_data_types(inner_type):
+                    # split on comma , to get individual data types
+                    single_fields = cls._split_data_type(inner_type, ",")
+                    for single_field in single_fields:
+                        single_field = single_field.strip()
+                        # If component type starts with a comma, the first single field
+                        # will be an empty string. Disregard this empty string.
+                        if not single_field:
+                            continue
+                        # split on whitespace to get field name and data type
+                        field_info = cls._split_data_type(single_field, r"\s")
+                        # check if there is a structural data type within
+                        # overall structural data type
+                        column_spec = cls.get_column_spec(field_info[1])
+                        column_type = column_spec.sqla_type if column_spec else None
+                        if column_type is None:
+                            column_type = types.String()
+                            logger.info(
+                                "Did not recognize type %s of column %s",
+                                field_info[1],
+                                field_info[0],
+                            )
+                        if field_info[1] == "array" or field_info[1] == "row":
+                            stack.append((field_info[0], field_info[1]))
+                            full_parent_path = cls._get_full_name(stack)
+                            result.append(
+                                cls._create_column_info(full_parent_path, column_type)
+                            )
+                        else:  # otherwise this field is a basic data type
+                            full_parent_path = cls._get_full_name(stack)
+                            column_name = f"{full_parent_path}.{field_info[0]}"
+                            result.append(
+                                cls._create_column_info(column_name, column_type)
+                            )
+                    # If the component type ends with a structural data type, do not pop
+                    # the stack. We have run across a structural data type within the
+                    # overall structural data type. Otherwise, we have completely parsed
+                    # through the entire structural data type and can move on.
+                    if not (inner_type.endswith("array") or inner_type.endswith("row")):
+                        stack.pop()
+                # We have an array of row objects (i.e. array(row(...)))
+                elif inner_type in ("array", "row"):
+                    # Push a dummy object to represent the structural data type
+                    stack.append(("", inner_type))
+                # We have an array of a basic data types(i.e. array(varchar)).
+                elif stack:
+                    # Because it is an array of a basic data type. We have finished
+                    # parsing the structural data type and can move on.
+                    stack.pop()
+        # Unquote the column name if necessary
+        if formatted_parent_column_name != parent_column_name:
+            for index in range(original_result_len, len(result)):
+                result[index]["column_name"] = result[index]["column_name"].replace(
+                    formatted_parent_column_name, parent_column_name
+                )
+
+    @classmethod
+    def _split_data_type(cls, data_type: str, delimiter: str) -> list[str]:
+        """
+        Split data type based on given delimiter. Do not split the string if the
+        delimiter is enclosed in quotes
+        :param data_type: data type
+        :param delimiter: string separator (i.e. open parenthesis, closed parenthesis,
+               comma, whitespace)
+        :return: list of strings after breaking it by the delimiter
+        """
+        return re.split(rf"{delimiter}(?=(?:[^\"]*\"[^\"]*\")*[^\"]*$)", data_type)
+
+    @classmethod
+    def _has_nested_data_types(cls, component_type: str) -> bool:
+        """
+        Check if string contains a data type. We determine if there is a data type by
+        whitespace or multiple data types by commas
+        :param component_type: data type
+        :return: boolean
+        """
+        comma_regex = r",(?=(?:[^\"]*\"[^\"]*\")*[^\"]*$)"
+        white_space_regex = r"\s(?=(?:[^\"]*\"[^\"]*\")*[^\"]*$)"
+        return (
+            re.search(comma_regex, component_type) is not None
+            or re.search(white_space_regex, component_type) is not None
+        )
+
+    @classmethod
+    def _get_full_name(cls, names: list[tuple[str, str]]) -> str:
+        """
+        Get the full column name
+        :param names: list of all individual column names
+        :return: full column name
+        """
+        return ".".join(column[0] for column in names if column[0])
+
 
 class PrestoEngineSpec(PrestoBaseEngineSpec):
     engine = "presto"
@@ -840,211 +1043,6 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
             results = cursor.fetchall()
             return {row[0] for row in results}
 
-    @classmethod
-    def _create_column_info(
-        cls, name: str, data_type: types.TypeEngine
-    ) -> ResultSetColumnType:
-        """
-        Create column info object
-        :param name: column name
-        :param data_type: column data type
-        :return: column info object
-        """
-        return {
-            "column_name": name,
-            "name": name,
-            "type": f"{data_type}",
-            "is_dttm": None,
-            "type_generic": None,
-        }
-
-    @classmethod
-    def _get_full_name(cls, names: list[tuple[str, str]]) -> str:
-        """
-        Get the full column name
-        :param names: list of all individual column names
-        :return: full column name
-        """
-        return ".".join(column[0] for column in names if column[0])
-
-    @classmethod
-    def _has_nested_data_types(cls, component_type: str) -> bool:
-        """
-        Check if string contains a data type. We determine if there is a data type by
-        whitespace or multiple data types by commas
-        :param component_type: data type
-        :return: boolean
-        """
-        comma_regex = r",(?=(?:[^\"]*\"[^\"]*\")*[^\"]*$)"
-        white_space_regex = r"\s(?=(?:[^\"]*\"[^\"]*\")*[^\"]*$)"
-        return (
-            re.search(comma_regex, component_type) is not None
-            or re.search(white_space_regex, component_type) is not None
-        )
-
-    @classmethod
-    def _split_data_type(cls, data_type: str, delimiter: str) -> list[str]:
-        """
-        Split data type based on given delimiter. Do not split the string if 
the
-        delimiter is enclosed in quotes
-        :param data_type: data type
-        :param delimiter: string separator (i.e. open parenthesis, closed parenthesis,
-               comma, whitespace)
-        :return: list of strings after breaking it by the delimiter
-        """
-        return re.split(rf"{delimiter}(?=(?:[^\"]*\"[^\"]*\")*[^\"]*$)", data_type)
-
-    @classmethod
-    def _parse_structural_column(  # pylint: disable=too-many-locals
-        cls,
-        parent_column_name: str,
-        parent_data_type: str,
-        result: list[ResultSetColumnType],
-    ) -> None:
-        """
-        Parse a row or array column
-        :param result: list tracking the results
-        """
-        formatted_parent_column_name = parent_column_name
-        # Quote the column name if there is a space
-        if " " in parent_column_name:
-            formatted_parent_column_name = f'"{parent_column_name}"'
-        full_data_type = f"{formatted_parent_column_name} {parent_data_type}"
-        original_result_len = len(result)
-        # split on open parenthesis ( to get the structural
-        # data type and its component types
-        data_types = cls._split_data_type(full_data_type, r"\(")
-        stack: list[tuple[str, str]] = []
-        for data_type in data_types:
-            # split on closed parenthesis ) to track which component
-            # types belong to what structural data type
-            inner_types = cls._split_data_type(data_type, r"\)")
-            for inner_type in inner_types:
-                # We have finished parsing multiple structural data types
-                if not inner_type and stack:
-                    stack.pop()
-                elif cls._has_nested_data_types(inner_type):
-                    # split on comma , to get individual data types
-                    single_fields = cls._split_data_type(inner_type, ",")
-                    for single_field in single_fields:
-                        single_field = single_field.strip()
-                        # If component type starts with a comma, the first single field
-                        # will be an empty string. Disregard this empty string.
-                        if not single_field:
-                            continue
-                        # split on whitespace to get field name and data type
-                        field_info = cls._split_data_type(single_field, r"\s")
-                        # check if there is a structural data type within
-                        # overall structural data type
-                        column_spec = cls.get_column_spec(field_info[1])
-                        column_type = column_spec.sqla_type if column_spec else None
-                        if column_type is None:
-                            column_type = types.String()
-                            logger.info(
-                                "Did not recognize type %s of column %s",
-                                field_info[1],
-                                field_info[0],
-                            )
-                        if field_info[1] == "array" or field_info[1] == "row":
-                            stack.append((field_info[0], field_info[1]))
-                            full_parent_path = cls._get_full_name(stack)
-                            result.append(
-                                cls._create_column_info(full_parent_path, column_type)
-                            )
-                        else:  # otherwise this field is a basic data type
-                            full_parent_path = cls._get_full_name(stack)
-                            column_name = f"{full_parent_path}.{field_info[0]}"
-                            result.append(
-                                cls._create_column_info(column_name, column_type)
-                            )
-                    # If the component type ends with a structural data type, do not pop
-                    # the stack. We have run across a structural data type within the
-                    # overall structural data type. Otherwise, we have completely parsed
-                    # through the entire structural data type and can move on.
-                    if not (inner_type.endswith("array") or inner_type.endswith("row")):
-                        stack.pop()
-                # We have an array of row objects (i.e. array(row(...)))
-                elif inner_type in ("array", "row"):
-                    # Push a dummy object to represent the structural data type
-                    stack.append(("", inner_type))
-                # We have an array of a basic data types(i.e. array(varchar)).
-                elif stack:
-                    # Because it is an array of a basic data type. We have finished
-                    # parsing the structural data type and can move on.
-                    stack.pop()
-        # Unquote the column name if necessary
-        if formatted_parent_column_name != parent_column_name:
-            for index in range(original_result_len, len(result)):
-                result[index]["column_name"] = result[index]["column_name"].replace(
-                    formatted_parent_column_name, parent_column_name
-                )
-
-    @classmethod
-    def _show_columns(
-        cls,
-        inspector: Inspector,
-        table: Table,
-    ) -> list[ResultRow]:
-        """
-        Show presto column names
-        :param inspector: object that performs database schema inspection
-        :param table: table instance
-        :return: list of column objects
-        """
-        quote = inspector.engine.dialect.identifier_preparer.quote_identifier
-        full_table = quote(table.table)
-        if table.schema:
-            full_table = f"{quote(table.schema)}.{full_table}"
-        return inspector.bind.execute(f"SHOW COLUMNS FROM {full_table}").fetchall()
-
-    @classmethod
-    def get_columns(
-        cls,
-        inspector: Inspector,
-        table: Table,
-        options: dict[str, Any] | None = None,
-    ) -> list[ResultSetColumnType]:
-        """
-        Get columns from a Presto data source. This includes handling row and
-        array data types
-        :param inspector: object that performs database schema inspection
-        :param table: table instance
-        :param options: Extra configuration options, not used by this backend
-        :return: a list of results that contain column info
-                (i.e. column name and data type)
-        """
-        columns = cls._show_columns(inspector, table)
-        result: list[ResultSetColumnType] = []
-        for column in columns:
-            # parse column if it is a row or array
-            if is_feature_enabled("PRESTO_EXPAND_DATA") and (
-                "array" in column.Type or "row" in column.Type
-            ):
-                structural_column_index = len(result)
-                cls._parse_structural_column(column.Column, column.Type, result)
-                result[structural_column_index]["nullable"] = getattr(
-                    column, "Null", True
-                )
-                result[structural_column_index]["default"] = None
-                continue
-
-            # otherwise column is a basic data type
-            column_spec = cls.get_column_spec(column.Type)
-            column_type = column_spec.sqla_type if column_spec else None
-            if column_type is None:
-                column_type = types.String()
-                logger.info(
-                    "Did not recognize type %s of column %s",
-                    str(column.Type),
-                    str(column.Column),
-                )
-            column_info = cls._create_column_info(column.Column, column_type)
-            column_info["nullable"] = getattr(column, "Null", True)
-            column_info["default"] = None
-            column_info["column_name"] = column.Column
-            result.append(column_info)
-        return result
-
     @classmethod
     def _is_column_name_quoted(cls, column_name: str) -> bool:
         """
diff --git a/superset/db_engine_specs/trino.py b/superset/db_engine_specs/trino.py
index 143276bdc3..1eb4b30787 100644
--- a/superset/db_engine_specs/trino.py
+++ b/superset/db_engine_specs/trino.py
@@ -36,7 +36,7 @@ from sqlalchemy.exc import NoSuchTableError
 from superset import db
 from superset.constants import QUERY_CANCEL_KEY, QUERY_EARLY_CANCEL_KEY, USER_AGENT
 from superset.databases.utils import make_url_safe
-from superset.db_engine_specs.base import BaseEngineSpec
+from superset.db_engine_specs.base import BaseEngineSpec, convert_inspector_columns
 from superset.db_engine_specs.exceptions import (
     SupersetDBAPIConnectionError,
     SupersetDBAPIDatabaseError,
@@ -241,7 +241,11 @@ class TrinoEngineSpec(PrestoBaseEngineSpec):
 
         execute_thread = threading.Thread(
             target=_execute,
-            args=(execute_result, execute_event, current_app._get_current_object()),  # pylint: disable=protected-access
+            args=(
+                execute_result,
+                execute_event,
+                current_app._get_current_object(),  # pylint: disable=protected-access
+            ),
         )
         execute_thread.start()
 
@@ -433,7 +437,17 @@ class TrinoEngineSpec(PrestoBaseEngineSpec):
         "schema_options", expand the schema definition out to show all
         subfields of nested ROWs as their appropriate dotted paths.
         """
-        base_cols = super().get_columns(inspector, table, options)
+        # The Trino dialect raises `NoSuchTableError` on the inspection methods when the
+        # table is empty. We can work around this by running a `SHOW COLUMNS FROM` query
+        # when that happens, using the method from the Presto base engine spec.
+        try:
+            # `SELECT * FROM information_schema.columns WHERE ...`
+            sqla_columns = inspector.get_columns(table.table, table.schema)
+            base_cols = convert_inspector_columns(sqla_columns)
+        except NoSuchTableError:
+            # `SHOW COLUMNS FROM ...`
+            base_cols = super().get_columns(inspector, table, options)
+
         if not (options or {}).get("expand_rows"):
             return base_cols
 
@@ -483,9 +497,6 @@ class TrinoEngineSpec(PrestoBaseEngineSpec):
         :param to_sql_kwargs: The `pandas.DataFrame.to_sql` keyword arguments
         :see: superset.db_engine_specs.HiveEngineSpec.df_to_sql
         """
-
-        # pylint: disable=import-outside-toplevel
-
         if to_sql_kwargs["if_exists"] == "append":
             raise SupersetException("Append operation not currently supported")
 
diff --git a/tests/integration_tests/db_engine_specs/presto_tests.py b/tests/integration_tests/db_engine_specs/presto_tests.py
index 607afa6953..69e5273f2b 100644
--- a/tests/integration_tests/db_engine_specs/presto_tests.py
+++ b/tests/integration_tests/db_engine_specs/presto_tests.py
@@ -78,7 +78,10 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
 
     def verify_presto_column(self, column, expected_results):
         inspector = mock.Mock()
-        inspector.engine.dialect.identifier_preparer.quote_identifier = mock.Mock()
+        preparer = inspector.engine.dialect.identifier_preparer
+        preparer.quote_identifier = preparer.quote = preparer.quote_schema = (
+            lambda x: f'"{x}"'
+        )
         row = mock.Mock()
         row.Column, row.Type, row.Null = column
         inspector.bind.execute.return_value.fetchall = mock.Mock(return_value=[row])
@@ -798,7 +801,8 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
 
     def test_show_columns(self):
         inspector = mock.MagicMock()
-        inspector.engine.dialect.identifier_preparer.quote_identifier = (
+        preparer = inspector.engine.dialect.identifier_preparer
+        preparer.quote_identifier = preparer.quote = preparer.quote_schema = (
             lambda x: f'"{x}"'
         )
         inspector.bind.execute.return_value.fetchall = mock.MagicMock(
@@ -813,7 +817,8 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
 
     def test_show_columns_with_schema(self):
         inspector = mock.MagicMock()
-        inspector.engine.dialect.identifier_preparer.quote_identifier = (
+        preparer = inspector.engine.dialect.identifier_preparer
+        preparer.quote_identifier = preparer.quote = preparer.quote_schema = (
             lambda x: f'"{x}"'
         )
         inspector.bind.execute.return_value.fetchall = mock.MagicMock(
diff --git a/tests/unit_tests/db_engine_specs/test_base.py b/tests/unit_tests/db_engine_specs/test_base.py
index a3920c8916..9ec1ebaf00 100644
--- a/tests/unit_tests/db_engine_specs/test_base.py
+++ b/tests/unit_tests/db_engine_specs/test_base.py
@@ -311,3 +311,26 @@ def test_get_default_catalog(mocker: MockerFixture) -> None:
 
     database = mocker.MagicMock()
     assert BaseEngineSpec.get_default_catalog(database) is None
+
+
+def test_quote_table() -> None:
+    """
+    Test the `quote_table` function.
+    """
+    from superset.db_engine_specs.base import BaseEngineSpec
+
+    dialect = sqlite.dialect()
+
+    assert BaseEngineSpec.quote_table(Table("table"), dialect) == '"table"'
+    assert (
+        BaseEngineSpec.quote_table(Table("table", "schema"), dialect)
+        == 'schema."table"'
+    )
+    assert (
+        BaseEngineSpec.quote_table(Table("table", "schema", "catalog"), dialect)
+        == 'catalog.schema."table"'
+    )
+    assert (
+        BaseEngineSpec.quote_table(Table("ta ble", "sche.ma", 'cata"log'), dialect)
+        == '"cata""log"."sche.ma"."ta ble"'
+    )
diff --git a/tests/unit_tests/db_engine_specs/test_trino.py b/tests/unit_tests/db_engine_specs/test_trino.py
index 3a2ac91ad6..a0923e8111 100644
--- a/tests/unit_tests/db_engine_specs/test_trino.py
+++ b/tests/unit_tests/db_engine_specs/test_trino.py
@@ -16,6 +16,7 @@
 # under the License.
 # pylint: disable=unused-argument, import-outside-toplevel, protected-access
 import copy
+from collections import namedtuple
 from datetime import datetime
 from typing import Any, Optional
 from unittest.mock import MagicMock, Mock, patch
@@ -25,7 +26,9 @@ import pytest
 from pytest_mock import MockerFixture
 from requests.exceptions import ConnectionError as RequestsConnectionError
 from sqlalchemy import sql, text, types
+from sqlalchemy.dialects import sqlite
 from sqlalchemy.engine.url import make_url
+from sqlalchemy.exc import NoSuchTableError
 from trino.exceptions import TrinoExternalError, TrinoInternalError, TrinoUserError
 from trino.sqlalchemy import datatype
 from trino.sqlalchemy.dialect import TrinoDialect
@@ -464,6 +467,64 @@ def test_get_columns(mocker: MockerFixture):
     _assert_columns_equal(actual, expected)
 
 
+def test_get_columns_error(mocker: MockerFixture):
+    """
+    Test that we fallback to a `SHOW COLUMNS FROM ...` query.
+    """
+    from superset.db_engine_specs.trino import TrinoEngineSpec
+
+    field1_type = datatype.parse_sqltype("row(a varchar, b date)")
+    field2_type = datatype.parse_sqltype("row(r1 row(a varchar, b varchar))")
+    field3_type = datatype.parse_sqltype("int")
+
+    mock_inspector = mocker.MagicMock()
+    mock_inspector.engine.dialect = sqlite.dialect()
+    mock_inspector.get_columns.side_effect = NoSuchTableError(
+        "The specified table does not exist."
+    )
+    Row = namedtuple("Row", ["Column", "Type"])
+    mock_inspector.bind.execute().fetchall.return_value = [
+        Row("field1", "row(a varchar, b date)"),
+        Row("field2", "row(r1 row(a varchar, b varchar))"),
+        Row("field3", "int"),
+    ]
+
+    actual = TrinoEngineSpec.get_columns(mock_inspector, Table("table", "schema"))
+    expected = [
+        ResultSetColumnType(
+            name="field1",
+            column_name="field1",
+            type=field1_type,
+            is_dttm=None,
+            type_generic=None,
+            default=None,
+            nullable=True,
+        ),
+        ResultSetColumnType(
+            name="field2",
+            column_name="field2",
+            type=field2_type,
+            is_dttm=None,
+            type_generic=None,
+            default=None,
+            nullable=True,
+        ),
+        ResultSetColumnType(
+            name="field3",
+            column_name="field3",
+            type=field3_type,
+            is_dttm=None,
+            type_generic=None,
+            default=None,
+            nullable=True,
+        ),
+    ]
+
+    _assert_columns_equal(actual, expected)
+
+    mock_inspector.bind.execute.assert_called_with('SHOW COLUMNS FROM schema."table"')
+
+
 def test_get_columns_expand_rows(mocker: MockerFixture):
     """Test that ROW columns are correctly expanded with expand_rows"""
     from superset.db_engine_specs.trino import TrinoEngineSpec
@@ -536,8 +597,6 @@ def test_get_columns_expand_rows(mocker: MockerFixture):
 
 
 def test_get_indexes_no_table():
-    from sqlalchemy.exc import NoSuchTableError
-
     from superset.db_engine_specs.trino import TrinoEngineSpec
 
     db_mock = Mock()

Reply via email to