This is an automated email from the ASF dual-hosted git repository.
villebro pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-superset.git
The following commit(s) were added to refs/heads/master by this push:
new 9461f9c fix(db_engine_specs): improve Presto column type matching
(#10658)
9461f9c is described below
commit 9461f9c1e0d4a4d07ae0ff93acba8f5cf4ccbe97
Author: Ville Brofeldt <[email protected]>
AuthorDate: Mon Aug 24 22:42:07 2020 +0300
fix(db_engine_specs): improve Presto column type matching (#10658)
* fix: improve Presto column type matching
* add optional callback to type map and add tests
* lint
* change private to public
---
superset/db_engine_specs/base.py | 13 +++-
superset/db_engine_specs/mssql.py | 15 ++---
superset/db_engine_specs/presto.py | 89 ++++++++++++++++++---------
superset/models/sql_types/presto_sql_types.py | 24 --------
tests/db_engine_specs/presto_tests.py | 21 +++++++
5 files changed, 97 insertions(+), 65 deletions(-)
diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py
index 331961c..d3d9dab 100644
--- a/superset/db_engine_specs/base.py
+++ b/superset/db_engine_specs/base.py
@@ -24,8 +24,10 @@ from contextlib import closing
from datetime import datetime
from typing import (
Any,
+ Callable,
Dict,
List,
+ Match,
NamedTuple,
Optional,
Pattern,
@@ -142,6 +144,9 @@ class BaseEngineSpec: # pylint:
disable=too-many-public-methods
] = None # used for user messages, overridden in child classes
_date_trunc_functions: Dict[str, str] = {}
_time_grain_expressions: Dict[Optional[str], str] = {}
+ column_type_mappings: Tuple[
+ Tuple[Pattern[str], Union[TypeEngine, Callable[[Match[str]],
TypeEngine]]], ...,
+ ] = ()
time_groupby_inline = False
limit_method = LimitMethod.FORCE_LIMIT
time_secondary_columns = False
@@ -886,12 +891,18 @@ class BaseEngineSpec: # pylint:
disable=too-many-public-methods
"""
Return a sqlalchemy native column type that corresponds to the column
type
defined in the data source (return None to use default type inferred by
- SQLAlchemy). Needs to be overridden if column requires special handling
+ SQLAlchemy). Override `column_type_mappings` for specific needs
(see MSSQL for example of NCHAR/NVARCHAR handling).
:param type_: Column type returned by inspector
:return: SqlAlchemy column type
"""
+ for regex, sqla_type in cls.column_type_mappings:
+ match = regex.match(type_)
+ if match:
+ if callable(sqla_type):
+ return sqla_type(match)
+ return sqla_type
return None
@staticmethod
diff --git a/superset/db_engine_specs/mssql.py
b/superset/db_engine_specs/mssql.py
index abe1f6c..70bd9b5 100644
--- a/superset/db_engine_specs/mssql.py
+++ b/superset/db_engine_specs/mssql.py
@@ -19,7 +19,7 @@ import re
from datetime import datetime
from typing import Any, List, Optional, Tuple, TYPE_CHECKING
-from sqlalchemy.types import String, TypeEngine, UnicodeText
+from sqlalchemy.types import String, UnicodeText
from superset.db_engine_specs.base import BaseEngineSpec, LimitMethod
from superset.utils import core as utils
@@ -73,19 +73,12 @@ class MssqlEngineSpec(BaseEngineSpec):
# Lists of `pyodbc.Row` need to be unpacked further
return cls.pyodbc_rows_to_tuples(data)
- column_types = (
- (String(), re.compile(r"^(?<!N)((VAR){0,1}CHAR|TEXT|STRING)",
re.IGNORECASE)),
- (UnicodeText(), re.compile(r"^N((VAR){0,1}CHAR|TEXT)", re.IGNORECASE)),
+ # (pattern, SQLAlchemy type) pairs, tried in order; the first matching pattern
+ # wins, so the N-prefixed (Unicode) pattern is listed before the plain one.
+ column_type_mappings = (
+ (re.compile(r"^N((VAR)?CHAR|TEXT)", re.IGNORECASE), UnicodeText()),
+ (re.compile(r"^((VAR)?CHAR|TEXT|STRING)", re.IGNORECASE), String()),
)
@classmethod
- def get_sqla_column_type(cls, type_: str) -> Optional[TypeEngine]:
- for sqla_type, regex in cls.column_types:
- if regex.match(type_):
- return sqla_type
- return None
-
- @classmethod
def extract_error_message(cls, ex: Exception) -> str:
if str(ex).startswith("(8155,"):
return (
diff --git a/superset/db_engine_specs/presto.py
b/superset/db_engine_specs/presto.py
index 16e6a4c..9a53d5d 100644
--- a/superset/db_engine_specs/presto.py
+++ b/superset/db_engine_specs/presto.py
@@ -28,7 +28,7 @@ from urllib import parse
import pandas as pd
import simplejson as json
from flask_babel import lazy_gettext as _
-from sqlalchemy import Column, literal_column
+from sqlalchemy import Column, literal_column, types
from sqlalchemy.engine.base import Engine
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.engine.result import RowProxy
@@ -40,7 +40,13 @@ from superset import app, cache, is_feature_enabled,
security_manager
from superset.db_engine_specs.base import BaseEngineSpec
from superset.exceptions import SupersetTemplateException
from superset.models.sql_lab import Query
-from superset.models.sql_types.presto_sql_types import type_map as
presto_type_map
+from superset.models.sql_types.presto_sql_types import (
+ Array,
+ Interval,
+ Map,
+ Row,
+ TinyInteger,
+)
from superset.result_set import destringify
from superset.sql_parse import ParsedQuery
from superset.utils import core as utils
@@ -260,13 +266,16 @@ class PrestoEngineSpec(BaseEngineSpec):
field_info = cls._split_data_type(single_field, r"\s")
# check if there is a structural data type within
# overall structural data type
+ column_type = cls.get_sqla_column_type(field_info[1])
+ if column_type is None:
+ raise NotImplementedError(
+ _("Unknown column type: %(col)s",
col=field_info[1])
+ )
if field_info[1] == "array" or field_info[1] == "row":
stack.append((field_info[0], field_info[1]))
full_parent_path = cls._get_full_name(stack)
result.append(
- cls._create_column_info(
- full_parent_path,
presto_type_map[field_info[1]]()
- )
+ cls._create_column_info(full_parent_path,
column_type)
)
else: # otherwise this field is a basic data type
full_parent_path = cls._get_full_name(stack)
@@ -274,9 +283,7 @@ class PrestoEngineSpec(BaseEngineSpec):
full_parent_path, field_info[0]
)
result.append(
- cls._create_column_info(
- column_name,
presto_type_map[field_info[1]]()
- )
+ cls._create_column_info(column_name,
column_type)
)
# If the component type ends with a structural data type,
do not pop
# the stack. We have run across a structural data type
within the
@@ -318,6 +325,34 @@ class PrestoEngineSpec(BaseEngineSpec):
columns = inspector.bind.execute("SHOW COLUMNS FROM
{}".format(full_table))
return columns
+ # Ordered (pattern, SQLAlchemy type or factory) pairs. Patterns are tried in
+ # order and the first match wins, so when one type name is a prefix of
+ # another the longer name must be listed first. A callable entry receives
+ # the regex match and returns the type (used to carry over varchar/char length).
+ column_type_mappings = (
+ (re.compile(r"^boolean.*", re.IGNORECASE), types.Boolean()),
+ (re.compile(r"^tinyint.*", re.IGNORECASE), TinyInteger()),
+ (re.compile(r"^smallint.*", re.IGNORECASE), types.SmallInteger()),
+ (re.compile(r"^integer.*", re.IGNORECASE), types.Integer()),
+ (re.compile(r"^bigint.*", re.IGNORECASE), types.BigInteger()),
+ (re.compile(r"^real.*", re.IGNORECASE), types.Float()),
+ (re.compile(r"^double.*", re.IGNORECASE), types.Float()),
+ (re.compile(r"^decimal.*", re.IGNORECASE), types.DECIMAL()),
+ (
+ re.compile(r"^varchar(\((\d+)\))*$", re.IGNORECASE),
+ lambda match: types.VARCHAR(int(match[2])) if match[2] else
types.String(),
+ ),
+ (
+ re.compile(r"^char(\((\d+)\))*$", re.IGNORECASE),
+ lambda match: types.CHAR(int(match[2])) if match[2] else
types.CHAR(),
+ ),
+ (re.compile(r"^varbinary.*", re.IGNORECASE), types.VARBINARY()),
+ (re.compile(r"^json.*", re.IGNORECASE), types.JSON()),
+ (re.compile(r"^date.*", re.IGNORECASE), types.DATE()),
+ # `timestamp` must precede `time`: `^time.*` also matches "timestamp...",
+ # which would otherwise map timestamp columns to Time.
+ (re.compile(r"^timestamp.*", re.IGNORECASE), types.TIMESTAMP()),
+ (re.compile(r"^time.*", re.IGNORECASE), types.Time()),
+ (re.compile(r"^interval.*", re.IGNORECASE), Interval()),
+ (re.compile(r"^array.*", re.IGNORECASE), Array()),
+ (re.compile(r"^map.*", re.IGNORECASE), Map()),
+ (re.compile(r"^row.*", re.IGNORECASE), Row()),
+ )
+
@classmethod
def get_columns(
cls, inspector: Inspector, table_name: str, schema: Optional[str]
@@ -334,28 +369,24 @@ class PrestoEngineSpec(BaseEngineSpec):
columns = cls._show_columns(inspector, table_name, schema)
result: List[Dict[str, Any]] = []
for column in columns:
- try:
- # parse column if it is a row or array
- if is_feature_enabled("PRESTO_EXPAND_DATA") and (
- "array" in column.Type or "row" in column.Type
- ):
- structural_column_index = len(result)
- cls._parse_structural_column(column.Column, column.Type,
result)
- result[structural_column_index]["nullable"] = getattr(
- column, "Null", True
- )
- result[structural_column_index]["default"] = None
- continue
-
- # otherwise column is a basic data type
- column_type = presto_type_map[column.Type]()
- except KeyError:
- logger.info(
- "Did not recognize type {} of column {}".format( #
pylint: disable=logging-format-interpolation
- column.Type, column.Column
- )
+ # parse column if it is a row or array
+ if is_feature_enabled("PRESTO_EXPAND_DATA") and (
+ "array" in column.Type or "row" in column.Type
+ ):
+ structural_column_index = len(result)
+ cls._parse_structural_column(column.Column, column.Type,
result)
+ result[structural_column_index]["nullable"] = getattr(
+ column, "Null", True
+ )
+ result[structural_column_index]["default"] = None
+ continue
+
+ # otherwise column is a basic data type
+ column_type = cls.get_sqla_column_type(column.Type)
+ if column_type is None:
+ raise NotImplementedError(
+ _("Unknown column type: %(col)s", col=column_type)
)
- column_type = "OTHER"
column_info = cls._create_column_info(column.Column, column_type)
column_info["nullable"] = getattr(column, "Null", True)
column_info["default"] = None
diff --git a/superset/models/sql_types/presto_sql_types.py
b/superset/models/sql_types/presto_sql_types.py
index d6f6d39..a314639 100644
--- a/superset/models/sql_types/presto_sql_types.py
+++ b/superset/models/sql_types/presto_sql_types.py
@@ -16,7 +16,6 @@
# under the License.
from typing import Any, Dict, List, Optional, Type
-from sqlalchemy import types
from sqlalchemy.sql.sqltypes import Integer
from sqlalchemy.sql.type_api import TypeEngine
from sqlalchemy.sql.visitors import Visitable
@@ -92,26 +91,3 @@ class Row(TypeEngine):
@classmethod
def _compiler_dispatch(cls, _visitor: Visitable, **_kw: Any) -> str:
return "ROW"
-
-
-type_map = {
- "boolean": types.Boolean,
- "tinyint": TinyInteger,
- "smallint": types.SmallInteger,
- "integer": types.Integer,
- "bigint": types.BigInteger,
- "real": types.Float,
- "double": types.Float,
- "decimal": types.DECIMAL,
- "varchar": types.String,
- "char": types.CHAR,
- "varbinary": types.VARBINARY,
- "JSON": types.JSON,
- "date": types.DATE,
- "time": types.Time,
- "timestamp": types.TIMESTAMP,
- "interval": Interval,
- "array": Array,
- "map": Map,
- "row": Row,
-}
diff --git a/tests/db_engine_specs/presto_tests.py
b/tests/db_engine_specs/presto_tests.py
index 9d1d384..3a0346b 100644
--- a/tests/db_engine_specs/presto_tests.py
+++ b/tests/db_engine_specs/presto_tests.py
@@ -17,6 +17,7 @@
from unittest import mock, skipUnless
import pandas as pd
+from sqlalchemy import types
from sqlalchemy.engine.result import RowProxy
from sqlalchemy.sql import select
@@ -490,3 +491,23 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
self.assertEqual(actual_cols, expected_cols)
self.assertEqual(actual_data, expected_data)
self.assertEqual(actual_expanded_cols, expected_expanded_cols)
+
+ def test_get_sqla_column_type(self):
+ """Presto type strings resolve to the expected SQLAlchemy types."""
+ # varchar with an explicit length keeps that length
+ sqla_type = PrestoEngineSpec.get_sqla_column_type("varchar(255)")
+ assert isinstance(sqla_type, types.VARCHAR)
+ assert sqla_type.length == 255
+
+ # varchar without a length has no length set
+ sqla_type = PrestoEngineSpec.get_sqla_column_type("varchar")
+ assert isinstance(sqla_type, types.String)
+ assert sqla_type.length is None
+
+ # char with an explicit length keeps that length
+ sqla_type = PrestoEngineSpec.get_sqla_column_type("char(10)")
+ assert isinstance(sqla_type, types.CHAR)
+ assert sqla_type.length == 10
+
+ # char without a length has no length set
+ sqla_type = PrestoEngineSpec.get_sqla_column_type("char")
+ assert isinstance(sqla_type, types.CHAR)
+ assert sqla_type.length is None
+
+ # a plain type name with no parameters
+ sqla_type = PrestoEngineSpec.get_sqla_column_type("integer")
+ assert isinstance(sqla_type, types.Integer)