This is an automated email from the ASF dual-hosted git repository.
villebro pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/superset.git
The following commit(s) were added to refs/heads/master by this push:
new 7eab59af51 fix(mysql): handle string typed decimal results (#24241)
7eab59af51 is described below
commit 7eab59af513ccccb3b1fed7aca5798c98c35fdb8
Author: Ville Brofeldt <[email protected]>
AuthorDate: Fri Sep 29 10:48:08 2023 -0700
fix(mysql): handle string typed decimal results (#24241)
---
superset/db_engine_specs/base.py | 29 ++++++++++++++++++-
superset/db_engine_specs/mysql.py | 6 +++-
tests/unit_tests/db_engine_specs/test_mysql.py | 40 ++++++++++++++++++++++++++
3 files changed, 73 insertions(+), 2 deletions(-)
diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py
index b9c083e42b..f355e4ef8c 100644
--- a/superset/db_engine_specs/base.py
+++ b/superset/db_engine_specs/base.py
@@ -323,6 +323,10 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
# engine-specific type mappings to check prior to the defaults
column_type_mappings: tuple[ColumnTypeMapping, ...] = ()
+ # type-specific functions to mutate values received from the database.
+ # Needed on certain databases that return values in an unexpected format
+ column_type_mutators: dict[TypeEngine, Callable[[Any], Any]] = {}
+
# Does database support join-free timeslot grouping
time_groupby_inline = False
limit_method = LimitMethod.FORCE_LIMIT
@@ -743,7 +747,30 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
try:
if cls.limit_method == LimitMethod.FETCH_MANY and limit:
return cursor.fetchmany(limit)
- return cursor.fetchall()
+ data = cursor.fetchall()
+ description = cursor.description or []
+ # Create a mapping between column name and a mutator function to normalize
+ # values with. The first two items in the description row are
+ # the column name and type.
+ column_mutators = {
+ row[0]: func
+ for row in description
+ if (
+ func := cls.column_type_mutators.get(
+ type(cls.get_sqla_column_type(cls.get_datatype(row[1])))
+ )
+ )
+ }
+ if column_mutators:
+ indexes = {row[0]: idx for idx, row in enumerate(description)}
+ for row_idx, row in enumerate(data):
+ new_row = list(row)
+ for col, func in column_mutators.items():
+ col_idx = indexes[col]
+ new_row[col_idx] = func(row[col_idx])
+ data[row_idx] = tuple(new_row)
+
+ return data
except Exception as ex:
raise cls.get_dbapi_mapped_exception(ex) from ex
diff --git a/superset/db_engine_specs/mysql.py b/superset/db_engine_specs/mysql.py
index 4d5604222d..687ffee7d5 100644
--- a/superset/db_engine_specs/mysql.py
+++ b/superset/db_engine_specs/mysql.py
@@ -17,8 +17,9 @@
import contextlib
import re
from datetime import datetime
+from decimal import Decimal
from re import Pattern
-from typing import Any, Optional
+from typing import Any, Callable, Optional
from urllib import parse
from flask_babel import gettext as __
@@ -126,6 +127,9 @@ class MySQLEngineSpec(BaseEngineSpec, BasicParametersMixin):
GenericDataType.STRING,
),
)
+ column_type_mutators: dict[types.TypeEngine, Callable[[Any], Any]] = {
+ DECIMAL: lambda val: Decimal(val) if isinstance(val, str) else val
+ }
_time_grain_expressions = {
None: "{col}",
diff --git a/tests/unit_tests/db_engine_specs/test_mysql.py b/tests/unit_tests/db_engine_specs/test_mysql.py
index 89abf2321d..ed64347017 100644
--- a/tests/unit_tests/db_engine_specs/test_mysql.py
+++ b/tests/unit_tests/db_engine_specs/test_mysql.py
@@ -16,6 +16,7 @@
# under the License.
from datetime import datetime
+from decimal import Decimal
from typing import Any, Optional
from unittest.mock import Mock, patch
@@ -220,3 +221,42 @@ def test_get_schema_from_engine_params() -> None:
)
== "db1"
)
+
+
+@pytest.mark.parametrize(
+ "data,description,expected_result",
+ [
+ (
+ [("1.23456", "abc")],
+ [("dec", "decimal(12,6)"), ("str", "varchar(3)")],
+ [(Decimal("1.23456"), "abc")],
+ ),
+ (
+ [(Decimal("1.23456"), "abc")],
+ [("dec", "decimal(12,6)"), ("str", "varchar(3)")],
+ [(Decimal("1.23456"), "abc")],
+ ),
+ (
+ [(None, "abc")],
+ [("dec", "decimal(12,6)"), ("str", "varchar(3)")],
+ [(None, "abc")],
+ ),
+ (
+ [("1.23456", "abc")],
+ [("dec", "varchar(255)"), ("str", "varchar(3)")],
+ [("1.23456", "abc")],
+ ),
+ ],
+)
+def test_column_type_mutator(
+ data: list[tuple[Any, ...]],
+ description: list[Any],
+ expected_result: list[tuple[Any, ...]],
+):
+ from superset.db_engine_specs.mysql import MySQLEngineSpec as spec
+
+ mock_cursor = Mock()
+ mock_cursor.fetchall.return_value = data
+ mock_cursor.description = description
+
+ assert spec.fetch_data(mock_cursor) == expected_result