This is an automated email from the ASF dual-hosted git repository.
beto pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/superset.git
The following commit(s) were added to refs/heads/master by this push:
new fa095a98ed fix: Trino `get_columns` (#29566)
fa095a98ed is described below
commit fa095a98ed833e028cf051a8cb6854f1fab7c801
Author: Beto Dealmeida <[email protected]>
AuthorDate: Fri Jul 12 16:37:49 2024 -0400
fix: Trino `get_columns` (#29566)
---
superset/db_engine_specs/base.py | 28 +-
superset/db_engine_specs/couchbasedb.py | 1 -
superset/db_engine_specs/presto.py | 408 ++++++++++-----------
superset/db_engine_specs/trino.py | 23 +-
.../db_engine_specs/presto_tests.py | 11 +-
tests/unit_tests/db_engine_specs/test_base.py | 23 ++
tests/unit_tests/db_engine_specs/test_trino.py | 63 +++-
7 files changed, 331 insertions(+), 226 deletions(-)
diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py
index 159c510fe9..1329597f02 100644
--- a/superset/db_engine_specs/base.py
+++ b/superset/db_engine_specs/base.py
@@ -1618,7 +1618,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
]
@classmethod
- def select_star( # pylint: disable=too-many-arguments,too-many-locals
+ def select_star( # pylint: disable=too-many-arguments
cls,
database: Database,
table: Table,
@@ -1653,14 +1653,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
if show_cols:
fields = cls._get_fields(cols)
- quote = engine.dialect.identifier_preparer.quote
- quote_schema = engine.dialect.identifier_preparer.quote_schema
- full_table_name = (
- quote_schema(table.schema) + "." + quote(table.table)
- if table.schema
- else quote(table.table)
- )
-
+ full_table_name = cls.quote_table(table, engine.dialect)
qry = select(fields).select_from(text(full_table_name))
if limit and cls.allow_limit_clause:
@@ -2224,6 +2217,23 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
return name
+ @classmethod
+ def quote_table(cls, table: Table, dialect: Dialect) -> str:
+ """
+ Fully quote a table name, including the schema and catalog.
+ """
+ quoters = {
+ "catalog": dialect.identifier_preparer.quote_schema,
+ "schema": dialect.identifier_preparer.quote_schema,
+ "table": dialect.identifier_preparer.quote,
+ }
+
+ return ".".join(
+ function(getattr(table, key))
+ for key, function in quoters.items()
+ if getattr(table, key)
+ )
+
# schema for adding a database by providing parameters instead of the
# full SQLAlchemy URI
diff --git a/superset/db_engine_specs/couchbasedb.py b/superset/db_engine_specs/couchbasedb.py
index b9cebdba32..71dc727679 100644
--- a/superset/db_engine_specs/couchbasedb.py
+++ b/superset/db_engine_specs/couchbasedb.py
@@ -14,7 +14,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-# pylint: disable=too-many-lines
from __future__ import annotations
diff --git a/superset/db_engine_specs/presto.py b/superset/db_engine_specs/presto.py
index fbd0eff484..5a375896c1 100644
--- a/superset/db_engine_specs/presto.py
+++ b/superset/db_engine_specs/presto.py
@@ -672,6 +672,209 @@ class PrestoBaseEngineSpec(BaseEngineSpec, metaclass=ABCMeta):
return ""
return df.to_dict()[field_to_return][0]
+ @classmethod
+ def _show_columns(
+ cls,
+ inspector: Inspector,
+ table: Table,
+ ) -> list[ResultRow]:
+ """
+ Show presto column names
+ :param inspector: object that performs database schema inspection
+ :param table: table instance
+ :return: list of column objects
+ """
+ full_table_name = cls.quote_table(table, inspector.engine.dialect)
+ return inspector.bind.execute(f"SHOW COLUMNS FROM {full_table_name}").fetchall()
+
+ @classmethod
+ def _create_column_info(
+ cls, name: str, data_type: types.TypeEngine
+ ) -> ResultSetColumnType:
+ """
+ Create column info object
+ :param name: column name
+ :param data_type: column data type
+ :return: column info object
+ """
+ return {
+ "column_name": name,
+ "name": name,
+ "type": f"{data_type}",
+ "is_dttm": None,
+ "type_generic": None,
+ }
+
+ @classmethod
+ def get_columns(
+ cls,
+ inspector: Inspector,
+ table: Table,
+ options: dict[str, Any] | None = None,
+ ) -> list[ResultSetColumnType]:
+ """
+ Get columns from a Presto data source. This includes handling row and
+ array data types
+ :param inspector: object that performs database schema inspection
+ :param table: table instance
+ :param options: Extra configuration options, not used by this backend
+ :return: a list of results that contain column info
+ (i.e. column name and data type)
+ """
+ columns = cls._show_columns(inspector, table)
+ result: list[ResultSetColumnType] = []
+ for column in columns:
+ # parse column if it is a row or array
+ if is_feature_enabled("PRESTO_EXPAND_DATA") and (
+ "array" in column.Type or "row" in column.Type
+ ):
+ structural_column_index = len(result)
+ cls._parse_structural_column(column.Column, column.Type, result)
+ result[structural_column_index]["nullable"] = getattr(
+ column, "Null", True
+ )
+ result[structural_column_index]["default"] = None
+ continue
+
+ # otherwise column is a basic data type
+ column_spec = cls.get_column_spec(column.Type)
+ column_type = column_spec.sqla_type if column_spec else None
+ if column_type is None:
+ column_type = types.String()
+ logger.info(
+ "Did not recognize type %s of column %s",
+ str(column.Type),
+ str(column.Column),
+ )
+ column_info = cls._create_column_info(column.Column, column_type)
+ column_info["nullable"] = getattr(column, "Null", True)
+ column_info["default"] = None
+ column_info["column_name"] = column.Column
+ result.append(column_info)
+
+ return result
+
+ @classmethod
+ def _parse_structural_column( # pylint: disable=too-many-locals
+ cls,
+ parent_column_name: str,
+ parent_data_type: str,
+ result: list[ResultSetColumnType],
+ ) -> None:
+ """
+ Parse a row or array column
+ :param result: list tracking the results
+ """
+ formatted_parent_column_name = parent_column_name
+ # Quote the column name if there is a space
+ if " " in parent_column_name:
+ formatted_parent_column_name = f'"{parent_column_name}"'
+ full_data_type = f"{formatted_parent_column_name} {parent_data_type}"
+ original_result_len = len(result)
+ # split on open parenthesis ( to get the structural
+ # data type and its component types
+ data_types = cls._split_data_type(full_data_type, r"\(")
+ stack: list[tuple[str, str]] = []
+ for data_type in data_types:
+ # split on closed parenthesis ) to track which component
+ # types belong to what structural data type
+ inner_types = cls._split_data_type(data_type, r"\)")
+ for inner_type in inner_types:
+ # We have finished parsing multiple structural data types
+ if not inner_type and stack:
+ stack.pop()
+ elif cls._has_nested_data_types(inner_type):
+ # split on comma , to get individual data types
+ single_fields = cls._split_data_type(inner_type, ",")
+ for single_field in single_fields:
+ single_field = single_field.strip()
+ # If component type starts with a comma, the first single field
+ # will be an empty string. Disregard this empty string.
+ if not single_field:
+ continue
+ # split on whitespace to get field name and data type
+ field_info = cls._split_data_type(single_field, r"\s")
+ # check if there is a structural data type within
+ # overall structural data type
+ column_spec = cls.get_column_spec(field_info[1])
+ column_type = column_spec.sqla_type if column_spec else None
+ if column_type is None:
+ column_type = types.String()
+ logger.info(
+ "Did not recognize type %s of column %s",
+ field_info[1],
+ field_info[0],
+ )
+ if field_info[1] == "array" or field_info[1] == "row":
+ stack.append((field_info[0], field_info[1]))
+ full_parent_path = cls._get_full_name(stack)
+ result.append(
+ cls._create_column_info(full_parent_path, column_type)
+ )
+ else: # otherwise this field is a basic data type
+ full_parent_path = cls._get_full_name(stack)
+ column_name = f"{full_parent_path}.{field_info[0]}"
+ result.append(
+ cls._create_column_info(column_name, column_type)
+ )
+ # If the component type ends with a structural data type, do not pop
+ # the stack. We have run across a structural data type within the
+ # overall structural data type. Otherwise, we have completely parsed
+ # through the entire structural data type and can move on.
+ if not (inner_type.endswith("array") or inner_type.endswith("row")):
+ stack.pop()
+ # We have an array of row objects (i.e. array(row(...)))
+ elif inner_type in ("array", "row"):
+ # Push a dummy object to represent the structural data type
+ stack.append(("", inner_type))
+ # We have an array of a basic data types(i.e. array(varchar)).
+ elif stack:
+ # Because it is an array of a basic data type. We have finished
+ # parsing the structural data type and can move on.
+ stack.pop()
+ # Unquote the column name if necessary
+ if formatted_parent_column_name != parent_column_name:
+ for index in range(original_result_len, len(result)):
+ result[index]["column_name"] = result[index]["column_name"].replace(
+ formatted_parent_column_name, parent_column_name
+ )
+
+ @classmethod
+ def _split_data_type(cls, data_type: str, delimiter: str) -> list[str]:
+ """
+ Split data type based on given delimiter. Do not split the string if the
+ delimiter is enclosed in quotes
+ :param data_type: data type
+ :param delimiter: string separator (i.e. open parenthesis, closed parenthesis,
+ comma, whitespace)
+ :return: list of strings after breaking it by the delimiter
+ """
+ return re.split(rf"{delimiter}(?=(?:[^\"]*\"[^\"]*\")*[^\"]*$)", data_type)
+
+ @classmethod
+ def _has_nested_data_types(cls, component_type: str) -> bool:
+ """
+ Check if string contains a data type. We determine if there is a data type by
+ whitespace or multiple data types by commas
+ :param component_type: data type
+ :return: boolean
+ """
+ comma_regex = r",(?=(?:[^\"]*\"[^\"]*\")*[^\"]*$)"
+ white_space_regex = r"\s(?=(?:[^\"]*\"[^\"]*\")*[^\"]*$)"
+ return (
+ re.search(comma_regex, component_type) is not None
+ or re.search(white_space_regex, component_type) is not None
+ )
+
+ @classmethod
+ def _get_full_name(cls, names: list[tuple[str, str]]) -> str:
+ """
+ Get the full column name
+ :param names: list of all individual column names
+ :return: full column name
+ """
+ return ".".join(column[0] for column in names if column[0])
+
class PrestoEngineSpec(PrestoBaseEngineSpec):
engine = "presto"
@@ -840,211 +1043,6 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
results = cursor.fetchall()
return {row[0] for row in results}
- @classmethod
- def _create_column_info(
- cls, name: str, data_type: types.TypeEngine
- ) -> ResultSetColumnType:
- """
- Create column info object
- :param name: column name
- :param data_type: column data type
- :return: column info object
- """
- return {
- "column_name": name,
- "name": name,
- "type": f"{data_type}",
- "is_dttm": None,
- "type_generic": None,
- }
-
- @classmethod
- def _get_full_name(cls, names: list[tuple[str, str]]) -> str:
- """
- Get the full column name
- :param names: list of all individual column names
- :return: full column name
- """
- return ".".join(column[0] for column in names if column[0])
-
- @classmethod
- def _has_nested_data_types(cls, component_type: str) -> bool:
- """
- Check if string contains a data type. We determine if there is a data type by
- whitespace or multiple data types by commas
- :param component_type: data type
- :return: boolean
- """
- comma_regex = r",(?=(?:[^\"]*\"[^\"]*\")*[^\"]*$)"
- white_space_regex = r"\s(?=(?:[^\"]*\"[^\"]*\")*[^\"]*$)"
- return (
- re.search(comma_regex, component_type) is not None
- or re.search(white_space_regex, component_type) is not None
- )
-
- @classmethod
- def _split_data_type(cls, data_type: str, delimiter: str) -> list[str]:
- """
- Split data type based on given delimiter. Do not split the string if the
- delimiter is enclosed in quotes
- :param data_type: data type
- :param delimiter: string separator (i.e. open parenthesis, closed parenthesis,
- comma, whitespace)
- :return: list of strings after breaking it by the delimiter
- """
- return re.split(rf"{delimiter}(?=(?:[^\"]*\"[^\"]*\")*[^\"]*$)", data_type)
-
- @classmethod
- def _parse_structural_column( # pylint: disable=too-many-locals
- cls,
- parent_column_name: str,
- parent_data_type: str,
- result: list[ResultSetColumnType],
- ) -> None:
- """
- Parse a row or array column
- :param result: list tracking the results
- """
- formatted_parent_column_name = parent_column_name
- # Quote the column name if there is a space
- if " " in parent_column_name:
- formatted_parent_column_name = f'"{parent_column_name}"'
- full_data_type = f"{formatted_parent_column_name} {parent_data_type}"
- original_result_len = len(result)
- # split on open parenthesis ( to get the structural
- # data type and its component types
- data_types = cls._split_data_type(full_data_type, r"\(")
- stack: list[tuple[str, str]] = []
- for data_type in data_types:
- # split on closed parenthesis ) to track which component
- # types belong to what structural data type
- inner_types = cls._split_data_type(data_type, r"\)")
- for inner_type in inner_types:
- # We have finished parsing multiple structural data types
- if not inner_type and stack:
- stack.pop()
- elif cls._has_nested_data_types(inner_type):
- # split on comma , to get individual data types
- single_fields = cls._split_data_type(inner_type, ",")
- for single_field in single_fields:
- single_field = single_field.strip()
- # If component type starts with a comma, the first single field
- # will be an empty string. Disregard this empty string.
- if not single_field:
- continue
- # split on whitespace to get field name and data type
- field_info = cls._split_data_type(single_field, r"\s")
- # check if there is a structural data type within
- # overall structural data type
- column_spec = cls.get_column_spec(field_info[1])
- column_type = column_spec.sqla_type if column_spec else None
- if column_type is None:
- column_type = types.String()
- logger.info(
- "Did not recognize type %s of column %s",
- field_info[1],
- field_info[0],
- )
- if field_info[1] == "array" or field_info[1] == "row":
- stack.append((field_info[0], field_info[1]))
- full_parent_path = cls._get_full_name(stack)
- result.append(
- cls._create_column_info(full_parent_path, column_type)
- )
- else: # otherwise this field is a basic data type
- full_parent_path = cls._get_full_name(stack)
- column_name = f"{full_parent_path}.{field_info[0]}"
- result.append(
- cls._create_column_info(column_name, column_type)
- )
- # If the component type ends with a structural data type, do not pop
- # the stack. We have run across a structural data type within the
- # overall structural data type. Otherwise, we have completely parsed
- # through the entire structural data type and can move on.
- if not (inner_type.endswith("array") or inner_type.endswith("row")):
- stack.pop()
- # We have an array of row objects (i.e. array(row(...)))
- elif inner_type in ("array", "row"):
- # Push a dummy object to represent the structural data type
- stack.append(("", inner_type))
- # We have an array of a basic data types(i.e. array(varchar)).
- elif stack:
- # Because it is an array of a basic data type. We have finished
- # parsing the structural data type and can move on.
- stack.pop()
- # Unquote the column name if necessary
- if formatted_parent_column_name != parent_column_name:
- for index in range(original_result_len, len(result)):
- result[index]["column_name"] = result[index]["column_name"].replace(
- formatted_parent_column_name, parent_column_name
- )
-
- @classmethod
- def _show_columns(
- cls,
- inspector: Inspector,
- table: Table,
- ) -> list[ResultRow]:
- """
- Show presto column names
- :param inspector: object that performs database schema inspection
- :param table: table instance
- :return: list of column objects
- """
- quote = inspector.engine.dialect.identifier_preparer.quote_identifier
- full_table = quote(table.table)
- if table.schema:
- full_table = f"{quote(table.schema)}.{full_table}"
- return inspector.bind.execute(f"SHOW COLUMNS FROM {full_table}").fetchall()
-
- @classmethod
- def get_columns(
- cls,
- inspector: Inspector,
- table: Table,
- options: dict[str, Any] | None = None,
- ) -> list[ResultSetColumnType]:
- """
- Get columns from a Presto data source. This includes handling row and
- array data types
- :param inspector: object that performs database schema inspection
- :param table: table instance
- :param options: Extra configuration options, not used by this backend
- :return: a list of results that contain column info
- (i.e. column name and data type)
- """
- columns = cls._show_columns(inspector, table)
- result: list[ResultSetColumnType] = []
- for column in columns:
- # parse column if it is a row or array
- if is_feature_enabled("PRESTO_EXPAND_DATA") and (
- "array" in column.Type or "row" in column.Type
- ):
- structural_column_index = len(result)
- cls._parse_structural_column(column.Column, column.Type, result)
- result[structural_column_index]["nullable"] = getattr(
- column, "Null", True
- )
- result[structural_column_index]["default"] = None
- continue
-
- # otherwise column is a basic data type
- column_spec = cls.get_column_spec(column.Type)
- column_type = column_spec.sqla_type if column_spec else None
- if column_type is None:
- column_type = types.String()
- logger.info(
- "Did not recognize type %s of column %s",
- str(column.Type),
- str(column.Column),
- )
- column_info = cls._create_column_info(column.Column, column_type)
- column_info["nullable"] = getattr(column, "Null", True)
- column_info["default"] = None
- column_info["column_name"] = column.Column
- result.append(column_info)
- return result
-
@classmethod
def _is_column_name_quoted(cls, column_name: str) -> bool:
"""
diff --git a/superset/db_engine_specs/trino.py b/superset/db_engine_specs/trino.py
index 143276bdc3..1eb4b30787 100644
--- a/superset/db_engine_specs/trino.py
+++ b/superset/db_engine_specs/trino.py
@@ -36,7 +36,7 @@ from sqlalchemy.exc import NoSuchTableError
from superset import db
from superset.constants import QUERY_CANCEL_KEY, QUERY_EARLY_CANCEL_KEY, USER_AGENT
from superset.databases.utils import make_url_safe
-from superset.db_engine_specs.base import BaseEngineSpec
+from superset.db_engine_specs.base import BaseEngineSpec, convert_inspector_columns
from superset.db_engine_specs.exceptions import (
SupersetDBAPIConnectionError,
SupersetDBAPIDatabaseError,
@@ -241,7 +241,11 @@ class TrinoEngineSpec(PrestoBaseEngineSpec):
execute_thread = threading.Thread(
target=_execute,
- args=(execute_result, execute_event, current_app._get_current_object()), # pylint: disable=protected-access
+ args=(
+ execute_result,
+ execute_event,
+ current_app._get_current_object(), # pylint: disable=protected-access
+ ),
)
execute_thread.start()
@@ -433,7 +437,17 @@ class TrinoEngineSpec(PrestoBaseEngineSpec):
"schema_options", expand the schema definition out to show all
subfields of nested ROWs as their appropriate dotted paths.
"""
- base_cols = super().get_columns(inspector, table, options)
+ # The Trino dialect raises `NoSuchTableError` on the inspection methods when the
+ # table is empty. We can work around this by running a `SHOW COLUMNS FROM` query
+ # when that happens, using the method from the Presto base engine spec.
+ try:
+ # `SELECT * FROM information_schema.columns WHERE ...`
+ sqla_columns = inspector.get_columns(table.table, table.schema)
+ base_cols = convert_inspector_columns(sqla_columns)
+ except NoSuchTableError:
+ # `SHOW COLUMNS FROM ...`
+ base_cols = super().get_columns(inspector, table, options)
+
if not (options or {}).get("expand_rows"):
return base_cols
@@ -483,9 +497,6 @@ class TrinoEngineSpec(PrestoBaseEngineSpec):
:param to_sql_kwargs: The `pandas.DataFrame.to_sql` keyword arguments
:see: superset.db_engine_specs.HiveEngineSpec.df_to_sql
"""
-
- # pylint: disable=import-outside-toplevel
-
if to_sql_kwargs["if_exists"] == "append":
raise SupersetException("Append operation not currently supported")
diff --git a/tests/integration_tests/db_engine_specs/presto_tests.py b/tests/integration_tests/db_engine_specs/presto_tests.py
index 607afa6953..69e5273f2b 100644
--- a/tests/integration_tests/db_engine_specs/presto_tests.py
+++ b/tests/integration_tests/db_engine_specs/presto_tests.py
@@ -78,7 +78,10 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
def verify_presto_column(self, column, expected_results):
inspector = mock.Mock()
- inspector.engine.dialect.identifier_preparer.quote_identifier = mock.Mock()
+ preparer = inspector.engine.dialect.identifier_preparer
+ preparer.quote_identifier = preparer.quote = preparer.quote_schema = (
+ lambda x: f'"{x}"'
+ )
row = mock.Mock()
row.Column, row.Type, row.Null = column
inspector.bind.execute.return_value.fetchall = mock.Mock(return_value=[row])
@@ -798,7 +801,8 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
def test_show_columns(self):
inspector = mock.MagicMock()
- inspector.engine.dialect.identifier_preparer.quote_identifier = (
+ preparer = inspector.engine.dialect.identifier_preparer
+ preparer.quote_identifier = preparer.quote = preparer.quote_schema = (
lambda x: f'"{x}"'
)
inspector.bind.execute.return_value.fetchall = mock.MagicMock(
@@ -813,7 +817,8 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
def test_show_columns_with_schema(self):
inspector = mock.MagicMock()
- inspector.engine.dialect.identifier_preparer.quote_identifier = (
+ preparer = inspector.engine.dialect.identifier_preparer
+ preparer.quote_identifier = preparer.quote = preparer.quote_schema = (
lambda x: f'"{x}"'
)
inspector.bind.execute.return_value.fetchall = mock.MagicMock(
diff --git a/tests/unit_tests/db_engine_specs/test_base.py b/tests/unit_tests/db_engine_specs/test_base.py
index a3920c8916..9ec1ebaf00 100644
--- a/tests/unit_tests/db_engine_specs/test_base.py
+++ b/tests/unit_tests/db_engine_specs/test_base.py
@@ -311,3 +311,26 @@ def test_get_default_catalog(mocker: MockerFixture) -> None:
database = mocker.MagicMock()
assert BaseEngineSpec.get_default_catalog(database) is None
+
+
+def test_quote_table() -> None:
+ """
+ Test the `quote_table` function.
+ """
+ from superset.db_engine_specs.base import BaseEngineSpec
+
+ dialect = sqlite.dialect()
+
+ assert BaseEngineSpec.quote_table(Table("table"), dialect) == '"table"'
+ assert (
+ BaseEngineSpec.quote_table(Table("table", "schema"), dialect)
+ == 'schema."table"'
+ )
+ assert (
+ BaseEngineSpec.quote_table(Table("table", "schema", "catalog"), dialect)
+ == 'catalog.schema."table"'
+ )
+ assert (
+ BaseEngineSpec.quote_table(Table("ta ble", "sche.ma", 'cata"log'), dialect)
+ == '"cata""log"."sche.ma"."ta ble"'
+ )
diff --git a/tests/unit_tests/db_engine_specs/test_trino.py b/tests/unit_tests/db_engine_specs/test_trino.py
index 3a2ac91ad6..a0923e8111 100644
--- a/tests/unit_tests/db_engine_specs/test_trino.py
+++ b/tests/unit_tests/db_engine_specs/test_trino.py
@@ -16,6 +16,7 @@
# under the License.
# pylint: disable=unused-argument, import-outside-toplevel, protected-access
import copy
+from collections import namedtuple
from datetime import datetime
from typing import Any, Optional
from unittest.mock import MagicMock, Mock, patch
@@ -25,7 +26,9 @@ import pytest
from pytest_mock import MockerFixture
from requests.exceptions import ConnectionError as RequestsConnectionError
from sqlalchemy import sql, text, types
+from sqlalchemy.dialects import sqlite
from sqlalchemy.engine.url import make_url
+from sqlalchemy.exc import NoSuchTableError
from trino.exceptions import TrinoExternalError, TrinoInternalError, TrinoUserError
from trino.sqlalchemy import datatype
from trino.sqlalchemy.dialect import TrinoDialect
@@ -464,6 +467,64 @@ def test_get_columns(mocker: MockerFixture):
_assert_columns_equal(actual, expected)
+def test_get_columns_error(mocker: MockerFixture):
+ """
+ Test that we fallback to a `SHOW COLUMNS FROM ...` query.
+ """
+ from superset.db_engine_specs.trino import TrinoEngineSpec
+
+ field1_type = datatype.parse_sqltype("row(a varchar, b date)")
+ field2_type = datatype.parse_sqltype("row(r1 row(a varchar, b varchar))")
+ field3_type = datatype.parse_sqltype("int")
+
+ mock_inspector = mocker.MagicMock()
+ mock_inspector.engine.dialect = sqlite.dialect()
+ mock_inspector.get_columns.side_effect = NoSuchTableError(
+ "The specified table does not exist."
+ )
+ Row = namedtuple("Row", ["Column", "Type"])
+ mock_inspector.bind.execute().fetchall.return_value = [
+ Row("field1", "row(a varchar, b date)"),
+ Row("field2", "row(r1 row(a varchar, b varchar))"),
+ Row("field3", "int"),
+ ]
+
+ actual = TrinoEngineSpec.get_columns(mock_inspector, Table("table", "schema"))
+ expected = [
+ ResultSetColumnType(
+ name="field1",
+ column_name="field1",
+ type=field1_type,
+ is_dttm=None,
+ type_generic=None,
+ default=None,
+ nullable=True,
+ ),
+ ResultSetColumnType(
+ name="field2",
+ column_name="field2",
+ type=field2_type,
+ is_dttm=None,
+ type_generic=None,
+ default=None,
+ nullable=True,
+ ),
+ ResultSetColumnType(
+ name="field3",
+ column_name="field3",
+ type=field3_type,
+ is_dttm=None,
+ type_generic=None,
+ default=None,
+ nullable=True,
+ ),
+ ]
+
+ _assert_columns_equal(actual, expected)
+
+ mock_inspector.bind.execute.assert_called_with('SHOW COLUMNS FROM schema."table"')
+
+
def test_get_columns_expand_rows(mocker: MockerFixture):
"""Test that ROW columns are correctly expanded with expand_rows"""
from superset.db_engine_specs.trino import TrinoEngineSpec
@@ -536,8 +597,6 @@ def test_get_columns_expand_rows(mocker: MockerFixture):
def test_get_indexes_no_table():
- from sqlalchemy.exc import NoSuchTableError
-
from superset.db_engine_specs.trino import TrinoEngineSpec
db_mock = Mock()