mistercrunch closed pull request #4724: Improve database type inference
URL: https://github.com/apache/incubator-superset/pull/4724
This is a pull request (PR) merged from a forked repository.
Because GitHub hides the original diff on merge, it is reproduced below
for the sake of provenance:
Because this is a foreign pull request (from a fork), the diff is supplied
below; GitHub would not otherwise display it:
diff --git a/superset/dataframe.py b/superset/dataframe.py
index 79a2c3d564..5fba4ffed6 100644
--- a/superset/dataframe.py
+++ b/superset/dataframe.py
@@ -13,6 +13,7 @@
from __future__ import unicode_literals
from datetime import date, datetime
+import logging
import numpy as np
import pandas as pd
@@ -26,6 +27,27 @@
INFER_COL_TYPES_SAMPLE_SIZE = 100
+def dedup(l, suffix='__'):
+ """De-duplicates a list of string by suffixing a counter
+
+ Always returns the same number of entries as provided, and always returns
+ unique values.
+
+ >>> print(','.join(dedup(['foo', 'bar', 'bar', 'bar'])))
+ foo,bar,bar__1,bar__2
+ """
+ new_l = []
+ seen = {}
+ for s in l:
+ if s in seen:
+ seen[s] += 1
+ s += suffix + str(seen[s])
+ else:
+ seen[s] = 0
+ new_l.append(s)
+ return new_l
+
+
class SupersetDataFrame(object):
# Mapping numpy dtype.char to generic database types
type_map = {
@@ -43,19 +65,39 @@ class SupersetDataFrame(object):
'V': None, # raw data (void)
}
- def __init__(self, df):
- self.__df = df.where((pd.notnull(df)), None)
+ def __init__(self, data, cursor_description, db_engine_spec):
+ column_names = []
+ if cursor_description:
+ column_names = [col[0] for col in cursor_description]
+
+ self.column_names = dedup(
+ db_engine_spec.get_normalized_column_names(cursor_description))
+
+ data = data or []
+ self.df = (
+ pd.DataFrame(list(data), columns=column_names).infer_objects())
+
+ self._type_dict = {}
+ try:
+ # The driver may not be passing a cursor.description
+ self._type_dict = {
+ col: db_engine_spec.get_datatype(cursor_description[i][1])
+ for i, col in enumerate(self.column_names)
+ if cursor_description
+ }
+ except Exception as e:
+ logging.exception(e)
@property
def size(self):
- return len(self.__df.index)
+ return len(self.df.index)
@property
def data(self):
# work around for https://github.com/pandas-dev/pandas/issues/18372
data = [dict((k, _maybe_box_datetimelike(v))
- for k, v in zip(self.__df.columns, np.atleast_1d(row)))
- for row in self.__df.values]
+ for k, v in zip(self.df.columns, np.atleast_1d(row)))
+ for row in self.df.values]
for d in data:
for k, v in list(d.items()):
# if an int is too big for Java Script to handle
@@ -70,7 +112,8 @@ def db_type(cls, dtype):
"""Given a numpy dtype, Returns a generic database type"""
if isinstance(dtype, ExtensionDtype):
return cls.type_map.get(dtype.kind)
- return cls.type_map.get(dtype.char)
+ elif hasattr(dtype, 'char'):
+ return cls.type_map.get(dtype.char)
@classmethod
def datetime_conversion_rate(cls, data_series):
@@ -105,7 +148,7 @@ def agg_func(cls, dtype, column_name):
# consider checking for key substring too.
if cls.is_id(column_name):
return 'count_distinct'
- if (issubclass(dtype.type, np.generic) and
+ if (hasattr(dtype, 'type') and issubclass(dtype.type, np.generic) and
np.issubdtype(dtype, np.number)):
return 'sum'
return None
@@ -116,22 +159,25 @@ def columns(self):
:return: dict, with the fields name, type, is_date, is_dim and agg.
"""
- if self.__df.empty:
+ if self.df.empty:
return None
columns = []
- sample_size = min(INFER_COL_TYPES_SAMPLE_SIZE, len(self.__df.index))
- sample = self.__df
+ sample_size = min(INFER_COL_TYPES_SAMPLE_SIZE, len(self.df.index))
+ sample = self.df
if sample_size:
- sample = self.__df.sample(sample_size)
- for col in self.__df.dtypes.keys():
- col_db_type = self.db_type(self.__df.dtypes[col])
+ sample = self.df.sample(sample_size)
+ for col in self.df.dtypes.keys():
+ col_db_type = (
+ self._type_dict.get(col) or
+ self.db_type(self.df.dtypes[col])
+ )
column = {
'name': col,
- 'agg': self.agg_func(self.__df.dtypes[col], col),
+ 'agg': self.agg_func(self.df.dtypes[col], col),
'type': col_db_type,
- 'is_date': self.is_date(self.__df.dtypes[col]),
- 'is_dim': self.is_dimension(self.__df.dtypes[col], col),
+ 'is_date': self.is_date(self.df.dtypes[col]),
+ 'is_dim': self.is_dimension(self.df.dtypes[col], col),
}
if column['type'] in ('OBJECT', None):
diff --git a/superset/db_engine_specs.py b/superset/db_engine_specs.py
index 4f6b22e305..4181c49d67 100644
--- a/superset/db_engine_specs.py
+++ b/superset/db_engine_specs.py
@@ -30,6 +30,7 @@
from flask import g
from flask_babel import lazy_gettext as _
import pandas
+from past.builtins import basestring
import sqlalchemy as sqla
from sqlalchemy import select
from sqlalchemy.engine import create_engine
@@ -85,6 +86,11 @@ def epoch_to_dttm(cls):
def epoch_ms_to_dttm(cls):
return cls.epoch_to_dttm().replace('{col}', '({col}/1000.0)')
+ @classmethod
+ def get_datatype(cls, type_code):
+ if isinstance(type_code, basestring) and len(type_code):
+ return type_code.upper()
+
@classmethod
def extra_table_metadata(cls, database, table_name, schema_name):
"""Returns engine-specific table metadata"""
@@ -592,6 +598,7 @@ class MySQLEngineSpec(BaseEngineSpec):
'INTERVAL DAYOFWEEK(DATE_SUB({col}, INTERVAL 1 DAY)) - 1 DAY))',
'P1W'),
)
+ type_code_map = {} # loaded from get_datatype only if needed
@classmethod
def convert_dttm(cls, target_type, dttm):
@@ -606,6 +613,23 @@ def adjust_database_uri(cls, uri, selected_schema=None):
uri.database = selected_schema
return uri
+ @classmethod
+ def get_datatype(cls, type_code):
+ if not cls.type_code_map:
+ # only import and store if needed at least once
+ import MySQLdb
+ ft = MySQLdb.constants.FIELD_TYPE
+ cls.type_code_map = {
+ getattr(ft, k): k
+ for k in dir(ft)
+ if not k.startswith('_')
+ }
+ datatype = type_code
+ if isinstance(type_code, int):
+ datatype = cls.type_code_map.get(type_code)
+ if datatype and isinstance(datatype, basestring) and len(datatype):
+ return datatype
+
@classmethod
def epoch_to_dttm(cls):
return 'from_unixtime({col})'
diff --git a/superset/sql_lab.py b/superset/sql_lab.py
index df00a2b6b1..34a9eeb9e3 100644
--- a/superset/sql_lab.py
+++ b/superset/sql_lab.py
@@ -10,8 +10,6 @@
from celery.exceptions import SoftTimeLimitExceeded
from contextlib2 import contextmanager
-import numpy as np
-import pandas as pd
import sqlalchemy
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import NullPool
@@ -31,27 +29,6 @@ class SqlLabException(Exception):
pass
-def dedup(l, suffix='__'):
- """De-duplicates a list of string by suffixing a counter
-
- Always returns the same number of entries as provided, and always returns
- unique values.
-
- >>> print(','.join(dedup(['foo', 'bar', 'bar', 'bar'])))
- foo,bar,bar__1,bar__2
- """
- new_l = []
- seen = {}
- for s in l:
- if s in seen:
- seen[s] += 1
- s += suffix + str(seen[s])
- else:
- seen[s] = 0
- new_l.append(s)
- return new_l
-
-
def get_query(query_id, session, retry_count=5):
"""attemps to get the query and retry if it cannot"""
query = None
@@ -96,24 +73,6 @@ def session_scope(nullpool):
session.close()
-def convert_results_to_df(column_names, data):
- """Convert raw query results to a DataFrame."""
- column_names = dedup(column_names)
-
- # check whether the result set has any nested dict columns
- if data:
- first_row = data[0]
- has_dict_col = any([isinstance(c, dict) for c in first_row])
- df_data = list(data) if has_dict_col else np.array(data, dtype=object)
- else:
- df_data = []
-
- cdf = dataframe.SupersetDataFrame(
- pd.DataFrame(df_data, columns=column_names))
-
- return cdf
-
-
@celery_app.task(bind=True, soft_time_limit=SQLLAB_TIMEOUT)
def get_sql_results(
ctask, query_id, rendered_query, return_results=True, store_results=False,
@@ -233,7 +192,6 @@ def handle_error(msg):
return handle_error(db_engine_spec.extract_error_message(e))
logging.info('Fetching cursor description')
- column_names = db_engine_spec.get_normalized_column_names(cursor.description)
if conn is not None:
conn.commit()
@@ -242,7 +200,7 @@ def handle_error(msg):
if query.status == utils.QueryStatus.STOPPED:
return handle_error('The query has been stopped')
- cdf = convert_results_to_df(column_names, data)
+ cdf = dataframe.SupersetDataFrame(data, cursor.description, db_engine_spec)
query.rows = cdf.size
query.progress = 100
diff --git a/tests/celery_tests.py b/tests/celery_tests.py
index 39b7749ae8..afaeea9dfb 100644
--- a/tests/celery_tests.py
+++ b/tests/celery_tests.py
@@ -14,7 +14,7 @@
import pandas as pd
from past.builtins import basestring
-from superset import app, cli, dataframe, db, security_manager
+from superset import app, cli, db, security_manager
from superset.models.helpers import QueryStatus
from superset.models.sql_lab import Query
from superset.sql_parse import SupersetQuery
@@ -245,55 +245,6 @@ def str_if_basestring(o):
def dictify_list_of_dicts(cls, l, k):
return {str(o[k]): cls.de_unicode_dict(o) for o in l}
- def test_get_columns(self):
- main_db = self.get_main_database(db.session)
- df = main_db.get_df('SELECT * FROM multiformat_time_series', None)
- cdf = dataframe.SupersetDataFrame(df)
-
- # Making ordering non-deterministic
- cols = self.dictify_list_of_dicts(cdf.columns, 'name')
-
- if main_db.sqlalchemy_uri.startswith('sqlite'):
- self.assertEqual(self.dictify_list_of_dicts([
- {'is_date': True, 'type': 'STRING', 'name': 'ds',
- 'is_dim': False},
- {'is_date': True, 'type': 'STRING', 'name': 'ds2',
- 'is_dim': False},
- {'agg': 'sum', 'is_date': False, 'type': 'INT',
- 'name': 'epoch_ms', 'is_dim': False},
- {'agg': 'sum', 'is_date': False, 'type': 'INT',
- 'name': 'epoch_s', 'is_dim': False},
- {'is_date': True, 'type': 'STRING', 'name': 'string0',
- 'is_dim': False},
- {'is_date': False, 'type': 'STRING',
- 'name': 'string1', 'is_dim': True},
- {'is_date': True, 'type': 'STRING', 'name': 'string2',
- 'is_dim': False},
- {'is_date': False, 'type': 'STRING',
- 'name': 'string3', 'is_dim': True}], 'name'),
- cols,
- )
- else:
- self.assertEqual(self.dictify_list_of_dicts([
- {'is_date': True, 'type': 'DATETIME', 'name': 'ds',
- 'is_dim': False},
- {'is_date': True, 'type': 'DATETIME',
- 'name': 'ds2', 'is_dim': False},
- {'agg': 'sum', 'is_date': False, 'type': 'INT',
- 'name': 'epoch_ms', 'is_dim': False},
- {'agg': 'sum', 'is_date': False, 'type': 'INT',
- 'name': 'epoch_s', 'is_dim': False},
- {'is_date': True, 'type': 'STRING', 'name': 'string0',
- 'is_dim': False},
- {'is_date': False, 'type': 'STRING',
- 'name': 'string1', 'is_dim': True},
- {'is_date': True, 'type': 'STRING', 'name': 'string2',
- 'is_dim': False},
- {'is_date': False, 'type': 'STRING',
- 'name': 'string3', 'is_dim': True}], 'name'),
- cols,
- )
-
if __name__ == '__main__':
unittest.main()
diff --git a/tests/core_tests.py b/tests/core_tests.py
index 6a4f153eb8..f1a01796b7 100644
--- a/tests/core_tests.py
+++ b/tests/core_tests.py
@@ -24,6 +24,7 @@
from superset import dataframe, db, jinja_context, security_manager, sql_lab, utils
from superset.connectors.sqla.models import SqlaTable
+from superset.db_engine_specs import BaseEngineSpec
from superset.models import core as models
from superset.models.sql_lab import Query
from superset.views.core import DatabaseView
@@ -626,8 +627,7 @@ def test_dataframe_timezone(self):
(datetime.datetime(2017, 11, 18, 21, 53, 0, 219225, tzinfo=tz),),
(datetime.datetime(2017, 11, 18, 22, 6, 30, 61810, tzinfo=tz),),
]
- df = dataframe.SupersetDataFrame(pd.DataFrame(data=list(data),
- columns=['data']))
+ df = dataframe.SupersetDataFrame(list(data), [['data']], BaseEngineSpec)
data = df.data
self.assertDictEqual(
data[0],
diff --git a/tests/dataframe_test.py b/tests/dataframe_test.py
new file mode 100644
index 0000000000..b56770240b
--- /dev/null
+++ b/tests/dataframe_test.py
@@ -0,0 +1,115 @@
+# -*- coding: utf-8 -*-
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+from superset.dataframe import dedup, SupersetDataFrame
+from superset.db_engine_specs import BaseEngineSpec
+from .base_tests import SupersetTestCase
+
+
+class SupersetDataFrameTestCase(SupersetTestCase):
+ def test_dedup(self):
+ self.assertEquals(
+ dedup(['foo', 'bar']),
+ ['foo', 'bar'],
+ )
+ self.assertEquals(
+ dedup(['foo', 'bar', 'foo', 'bar']),
+ ['foo', 'bar', 'foo__1', 'bar__1'],
+ )
+ self.assertEquals(
+ dedup(['foo', 'bar', 'bar', 'bar']),
+ ['foo', 'bar', 'bar__1', 'bar__2'],
+ )
+
+ def test_get_columns_basic(self):
+ data = [
+ ('a1', 'b1', 'c1'),
+ ('a2', 'b2', 'c2'),
+ ]
+ cursor_descr = (
+ ('a', 'string'),
+ ('b', 'string'),
+ ('c', 'string'),
+ )
+ cdf = SupersetDataFrame(data, cursor_descr, BaseEngineSpec)
+ self.assertEqual(
+ cdf.columns,
+ [
+ {
+ 'is_date': False,
+ 'type': 'STRING',
+ 'name': 'a',
+ 'is_dim': True,
+ }, {
+ 'is_date': False,
+ 'type': 'STRING',
+ 'name': 'b',
+ 'is_dim': True,
+ }, {
+ 'is_date': False,
+ 'type': 'STRING',
+ 'name': 'c',
+ 'is_dim': True,
+ },
+ ],
+ )
+
+ def test_get_columns_with_int(self):
+ data = [
+ ('a1', 1),
+ ('a2', 2),
+ ]
+ cursor_descr = (
+ ('a', 'string'),
+ ('b', 'int'),
+ )
+ cdf = SupersetDataFrame(data, cursor_descr, BaseEngineSpec)
+ self.assertEqual(
+ cdf.columns,
+ [
+ {
+ 'is_date': False,
+ 'type': 'STRING',
+ 'name': 'a',
+ 'is_dim': True,
+ }, {
+ 'is_date': False,
+ 'type': 'INT',
+ 'name': 'b',
+ 'is_dim': False,
+ 'agg': 'sum',
+ },
+ ],
+ )
+
+ def test_get_columns_type_inference(self):
+ data = [
+ (1.2, 1),
+ (3.14, 2),
+ ]
+ cursor_descr = (
+ ('a', None),
+ ('b', None),
+ )
+ cdf = SupersetDataFrame(data, cursor_descr, BaseEngineSpec)
+ self.assertEqual(
+ cdf.columns,
+ [
+ {
+ 'is_date': False,
+ 'type': 'FLOAT',
+ 'name': 'a',
+ 'is_dim': False,
+ 'agg': 'sum',
+ }, {
+ 'is_date': False,
+ 'type': 'INT',
+ 'name': 'b',
+ 'is_dim': False,
+ 'agg': 'sum',
+ },
+ ],
+ )
diff --git a/tests/db_engine_specs_test.py b/tests/db_engine_specs_test.py
index bdce0b060d..447914ed5f 100644
--- a/tests/db_engine_specs_test.py
+++ b/tests/db_engine_specs_test.py
@@ -7,7 +7,9 @@
import textwrap
from superset.db_engine_specs import (
- HiveEngineSpec, MssqlEngineSpec, MySQLEngineSpec)
+ BaseEngineSpec, HiveEngineSpec, MssqlEngineSpec,
+ MySQLEngineSpec, PrestoEngineSpec,
+)
from superset.models.core import Database
from .base_tests import SupersetTestCase
@@ -193,3 +195,9 @@ def test_limit_expr_and_semicolon(self):
FROM
table LIMIT 1000"""),
)
+
+ def test_get_datatype(self):
+ self.assertEquals('STRING', PrestoEngineSpec.get_datatype('string'))
+ self.assertEquals('TINY', MySQLEngineSpec.get_datatype(1))
+ self.assertEquals('VARCHAR', MySQLEngineSpec.get_datatype(15))
+ self.assertEquals('VARCHAR', BaseEngineSpec.get_datatype('VARCHAR'))
diff --git a/tests/sqllab_tests.py b/tests/sqllab_tests.py
index 49926f80de..a3bb564dd8 100644
--- a/tests/sqllab_tests.py
+++ b/tests/sqllab_tests.py
@@ -12,8 +12,9 @@
from flask_appbuilder.security.sqla import models as ab_models
from superset import db, security_manager, utils
+from superset.dataframe import SupersetDataFrame
+from superset.db_engine_specs import BaseEngineSpec
from superset.models.sql_lab import Query
-from superset.sql_lab import convert_results_to_df
from .base_tests import SupersetTestCase
@@ -203,9 +204,13 @@ def test_alias_duplicate(self):
raise_on_error=True)
def test_df_conversion_no_dict(self):
- cols = ['string_col', 'int_col', 'float_col']
+ cols = [
+ ['string_col', 'string'],
+ ['int_col', 'int'],
+ ['float_col', 'float'],
+ ]
data = [['a', 4, 4.0]]
- cdf = convert_results_to_df(cols, data)
+ cdf = SupersetDataFrame(data, cols, BaseEngineSpec)
self.assertEquals(len(data), cdf.size)
self.assertEquals(len(cols), len(cdf.columns))
@@ -213,7 +218,7 @@ def test_df_conversion_no_dict(self):
def test_df_conversion_tuple(self):
cols = ['string_col', 'int_col', 'list_col', 'float_col']
data = [(u'Text', 111, [123], 1.0)]
- cdf = convert_results_to_df(cols, data)
+ cdf = SupersetDataFrame(data, cols, BaseEngineSpec)
self.assertEquals(len(data), cdf.size)
self.assertEquals(len(cols), len(cdf.columns))
@@ -221,7 +226,7 @@ def test_df_conversion_tuple(self):
def test_df_conversion_dict(self):
cols = ['string_col', 'dict_col', 'int_col']
data = [['a', {'c1': 1, 'c2': 2, 'c3': 3}, 4]]
- cdf = convert_results_to_df(cols, data)
+ cdf = SupersetDataFrame(data, cols, BaseEngineSpec)
self.assertEquals(len(data), cdf.size)
self.assertEquals(len(cols), len(cdf.columns))
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
users@infra.apache.org
With regards,
Apache Git Services
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]