This is an automated email from the ASF dual-hosted git repository.
maximebeauchemin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-superset.git
The following commit(s) were added to refs/heads/master by this push:
new f5277fe Refactor dataframe and column name mutation logic (#6847)
f5277fe is described below
commit f5277fe6843be27e639b954e0a37a680faf22374
Author: Ville Brofeldt <[email protected]>
AuthorDate: Thu Feb 21 09:05:35 2019 +0200
Refactor dataframe and column name mutation logic (#6847)
* Merge dataframe and column name mutation logic, add flag for disabling
column aliases and add column name length checking
* Remove custom mutate_label from oracle spec
* Move hashing from mutate_label() to make_label_compatible()
* Remove empty line
* Make label mutating and truncating more robust
* Rename variables and make proposed changes from review
* Always execute labels_expected codepath
* Fix linting error
* Add comments and fix subquery errors
* Refine column compatibility
* Simplify label assignment
* Add unit tests for BQ and Oracle
* Linting
---
superset/connectors/sqla/models.py | 87 +++++++++++++-------------
superset/db_engine_specs.py | 125 ++++++++++++++++++++-----------------
tests/db_engine_specs_test.py | 29 ++++++++-
3 files changed, 137 insertions(+), 104 deletions(-)
diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py
index ff2821b..8183d7c 100644
--- a/superset/connectors/sqla/models.py
+++ b/superset/connectors/sqla/models.py
@@ -116,14 +116,14 @@ class TableColumn(Model, BaseColumn):
export_parent = 'table'
def get_sqla_col(self, label=None):
- label = label if label else self.column_name
- label = self.table.get_label(label)
+ label = label or self.column_name
if not self.expression:
db_engine_spec = self.table.database.db_engine_spec
type_ = db_engine_spec.get_sqla_column_type(self.type)
- col = column(self.column_name, type_=type_).label(label)
+ col = column(self.column_name, type_=type_)
else:
- col = literal_column(self.expression).label(label)
+ col = literal_column(self.expression)
+ col = self.table.make_sqla_column_compatible(col, label)
return col
@property
@@ -142,13 +142,14 @@ class TableColumn(Model, BaseColumn):
def get_timestamp_expression(self, time_grain):
"""Getting the time component of the query"""
- label = self.table.get_label(utils.DTTM_ALIAS)
+ label = utils.DTTM_ALIAS
db = self.table.database
pdf = self.python_date_format
is_epoch = pdf in ('epoch_s', 'epoch_ms')
if not self.expression and not time_grain and not is_epoch:
- return column(self.column_name, type_=DateTime).label(label)
+ sqla_col = column(self.column_name, type_=DateTime)
+ return self.table.make_sqla_column_compatible(sqla_col, label)
grain = None
if time_grain:
grain = db.grains_dict().get(time_grain)
@@ -158,7 +159,8 @@ class TableColumn(Model, BaseColumn):
expr = db.db_engine_spec.get_time_expr(
self.expression or self.column_name,
pdf, time_grain, grain)
- return literal_column(expr, type_=DateTime).label(label)
+ sqla_col = literal_column(expr, type_=DateTime)
+ return self.table.make_sqla_column_compatible(sqla_col, label)
@classmethod
def import_obj(cls, i_column):
@@ -218,9 +220,9 @@ class SqlMetric(Model, BaseMetric):
export_parent = 'table'
def get_sqla_col(self, label=None):
- label = label if label else self.metric_name
- label = self.table.get_label(label)
- return literal_column(self.expression).label(label)
+ label = label or self.metric_name
+ sqla_col = literal_column(self.expression)
+ return self.table.make_sqla_column_compatible(sqla_col, label)
@property
def perm(self):
@@ -298,20 +300,19 @@ class SqlaTable(Model, BaseDatasource):
'MAX': sa.func.MAX,
}
- def get_label(self, label):
- """Conditionally mutate a label to conform to db engine requirements
- and store mapping from mutated label to original label
-
- :param label: original label
- :return: Either a string or sqlalchemy.sql.elements.quoted_name if
required
- by db engine
+ def make_sqla_column_compatible(self, sqla_col, label=None):
+ """Takes a sql alchemy column object and adds label info if supported
by engine.
+ :param sqla_col: sql alchemy column instance
+ :param label: alias/label that column is expected to have
+ :return: either a sql alchemy column or label instance if supported by
engine
"""
+ label_expected = label or sqla_col.name
db_engine_spec = self.database.db_engine_spec
- sqla_label = db_engine_spec.make_label_compatible(label)
- mutated_label = str(sqla_label)
- if label != mutated_label:
- self.mutated_labels[mutated_label] = label
- return sqla_label
+ if db_engine_spec.supports_column_aliases:
+ label = db_engine_spec.make_label_compatible(label_expected)
+ sqla_col = sqla_col.label(label)
+ sqla_col._df_label_expected = label_expected
+ return sqla_col
def __repr__(self):
return self.name
@@ -517,7 +518,6 @@ class SqlaTable(Model, BaseDatasource):
"""
expression_type = metric.get('expressionType')
label = utils.get_metric_name(metric)
- label = self.get_label(label)
if expression_type == utils.ADHOC_METRIC_EXPRESSION_TYPES['SIMPLE']:
column_name = metric.get('column').get('column_name')
@@ -527,15 +527,13 @@ class SqlaTable(Model, BaseDatasource):
else:
sqla_column = column(column_name)
sqla_metric =
self.sqla_aggregations[metric.get('aggregate')](sqla_column)
- sqla_metric = sqla_metric.label(label)
- return sqla_metric
elif expression_type == utils.ADHOC_METRIC_EXPRESSION_TYPES['SQL']:
sqla_metric = literal_column(metric.get('sqlExpression'))
- sqla_metric = sqla_metric.label(label)
- return sqla_metric
else:
return None
+ return self.make_sqla_column_compatible(sqla_metric, label)
+
def get_sqla_query( # sqla
self,
groupby, metrics,
@@ -569,9 +567,6 @@ class SqlaTable(Model, BaseDatasource):
template_processor = self.get_template_processor(**template_kwargs)
db_engine_spec = self.database.db_engine_spec
- # Initialize empty cache to store mutated labels
- self.mutated_labels = {}
-
orderby = orderby or []
# For backward compatibility
@@ -601,8 +596,8 @@ class SqlaTable(Model, BaseDatasource):
if metrics_exprs:
main_metric_expr = metrics_exprs[0]
else:
- label = self.get_label('ccount')
- main_metric_expr = literal_column('COUNT(*)').label(label)
+ main_metric_expr, label = literal_column('COUNT(*)'), 'ccount'
+ main_metric_expr =
self.make_sqla_column_compatible(main_metric_expr, label)
select_exprs = []
groupby_exprs_sans_timestamp = OrderedDict()
@@ -613,14 +608,16 @@ class SqlaTable(Model, BaseDatasource):
if s in cols:
outer = cols[s].get_sqla_col()
else:
- outer = literal_column(f'({s})').label(self.get_label(s))
+ outer = literal_column(f'({s})')
+ outer = self.make_sqla_column_compatible(outer, s)
groupby_exprs_sans_timestamp[outer.name] = outer
select_exprs.append(outer)
elif columns:
for s in columns:
select_exprs.append(
- cols[s].get_sqla_col() if s in cols else literal_column(s))
+ cols[s].get_sqla_col() if s in cols else
+ self.make_sqla_column_compatible(literal_column(s)))
metrics_exprs = []
groupby_exprs_with_timestamp =
OrderedDict(groupby_exprs_sans_timestamp.items())
@@ -644,7 +641,7 @@ class SqlaTable(Model, BaseDatasource):
select_exprs += metrics_exprs
- labels_expected = [str(c.name) for c in select_exprs]
+ labels_expected = [c._df_label_expected for c in select_exprs]
select_exprs = db_engine_spec.make_select_compatible(
groupby_exprs_with_timestamp.values(),
@@ -732,12 +729,12 @@ class SqlaTable(Model, BaseDatasource):
# some sql dialects require for order by expressions
# to also be in the select clause -- others, e.g. vertica,
# require a unique inner alias
- label = self.get_label('mme_inner__')
- inner_main_metric_expr = main_metric_expr.label(label)
+ inner_main_metric_expr = self.make_sqla_column_compatible(
+ main_metric_expr, 'mme_inner__')
inner_groupby_exprs = []
inner_select_exprs = []
for gby_name, gby_obj in groupby_exprs_sans_timestamp.items():
- inner = gby_obj.label(gby_name + '__')
+ inner = self.make_sqla_column_compatible(gby_obj, gby_name
+ '__')
inner_groupby_exprs.append(inner)
inner_select_exprs.append(inner)
@@ -766,7 +763,7 @@ class SqlaTable(Model, BaseDatasource):
# in this case the column name, not the alias, needs to be
# conditionally mutated, as it refers to the column alias
in
# the inner query
- col_name = self.get_label(gby_name + '__')
+ col_name = db_engine_spec.make_label_compatible(gby_name +
'__')
on_clause.append(gby_obj == column(col_name))
tbl = tbl.join(subq.alias(), and_(*on_clause))
@@ -841,15 +838,19 @@ class SqlaTable(Model, BaseDatasource):
status = utils.QueryStatus.SUCCESS
error_message = None
df = None
- db_engine_spec = self.database.db_engine_spec
try:
df = self.database.get_df(sql, self.schema)
- if self.mutated_labels:
- df = df.rename(index=str, columns=self.mutated_labels)
- db_engine_spec.mutate_df_columns(df, sql,
query_str_ext.labels_expected)
+ labels_expected = query_str_ext.labels_expected
+ if df is not None and not df.empty:
+ if len(df.columns) != len(labels_expected):
+ raise Exception(f'For {sql}, df.columns: {df.columns}'
+ f' differs from {labels_expected}')
+ else:
+ df.columns = labels_expected
except Exception as e:
status = utils.QueryStatus.FAILED
logging.exception(f'Query {sql} on schema {self.schema} failed')
+ db_engine_spec = self.database.db_engine_spec
error_message = db_engine_spec.extract_error_message(e)
# if this is a main query with prequeries, combine them together
diff --git a/superset/db_engine_specs.py b/superset/db_engine_specs.py
index 40d1ea5..0785082 100644
--- a/superset/db_engine_specs.py
+++ b/superset/db_engine_specs.py
@@ -111,8 +111,10 @@ class BaseEngineSpec(object):
time_secondary_columns = False
inner_joins = True
allows_subquery = True
+ supports_column_aliases = True
force_column_alias_quotes = False
arraysize = None
+ max_column_name_length = None
@classmethod
def get_time_expr(cls, expr, pdf, time_grain, grain):
@@ -143,10 +145,6 @@ class BaseEngineSpec(object):
return select_exprs
@classmethod
- def mutate_df_columns(cls, df, sql, labels_expected):
- pass
-
- @classmethod
def fetch_data(cls, cursor, limit):
if cls.arraysize:
cursor.arraysize = cls.arraysize
@@ -287,6 +285,8 @@ class BaseEngineSpec(object):
schema=schema, force=True,
cache=db.table_cache_enabled,
cache_timeout=db.table_cache_timeout)
+ else:
+ raise Exception(f'Unsupported datasource_type:
{datasource_type}')
all_result_sets += [
'{}.{}'.format(schema, t) for t in all_datasource_names]
return all_result_sets
@@ -418,10 +418,15 @@ class BaseEngineSpec(object):
force_column_alias_quotes is set to True, return the label as a
sqlalchemy.sql.elements.quoted_name object to ensure that the select
query
and query results have same case. Otherwise return the mutated label
as a
- regular string.
+ regular string. If maximum supported column name length is exceeded,
+ generate a truncated label by calling truncate_label().
"""
- label = cls.mutate_label(label)
- return quoted_name(label, True) if cls.force_column_alias_quotes else
label
+ label_mutated = cls.mutate_label(label)
+ if cls.max_column_name_length and len(label_mutated) >
cls.max_column_name_length:
+ label_mutated = cls.truncate_label(label)
+ if cls.force_column_alias_quotes:
+ label_mutated = quoted_name(label_mutated, True)
+ return label_mutated
@classmethod
def get_sqla_column_type(cls, type_):
@@ -445,6 +450,19 @@ class BaseEngineSpec(object):
"""
return label
+ @classmethod
+ def truncate_label(cls, label):
+ """
+ In the case that a label exceeds the max length supported by the
engine,
+ this method is used to construct a deterministic and unique label
based on
+ an md5 hash.
+ """
+ label = hashlib.md5(label.encode('utf-8')).hexdigest()
+ # truncate hash if it exceeds max length
+ if cls.max_column_name_length and len(label) >
cls.max_column_name_length:
+ label = label[:cls.max_column_name_length]
+ return label
+
class PostgresBaseEngineSpec(BaseEngineSpec):
""" Abstract class for Postgres 'like' databases """
@@ -482,6 +500,7 @@ class PostgresBaseEngineSpec(BaseEngineSpec):
class PostgresEngineSpec(PostgresBaseEngineSpec):
engine = 'postgresql'
+ max_column_name_length = 63
@classmethod
def get_table_names(cls, inspector, schema):
@@ -494,6 +513,7 @@ class PostgresEngineSpec(PostgresBaseEngineSpec):
class SnowflakeEngineSpec(PostgresBaseEngineSpec):
engine = 'snowflake'
force_column_alias_quotes = True
+ max_column_name_length = 256
time_grain_functions = {
None: '{col}',
@@ -531,6 +551,7 @@ class VerticaEngineSpec(PostgresBaseEngineSpec):
class RedshiftEngineSpec(PostgresBaseEngineSpec):
engine = 'redshift'
+ max_column_name_length = 127
@staticmethod
def mutate_label(label):
@@ -546,6 +567,7 @@ class OracleEngineSpec(PostgresBaseEngineSpec):
engine = 'oracle'
limit_method = LimitMethod.WRAP_SQL
force_column_alias_quotes = True
+ max_column_name_length = 30
time_grain_functions = {
None: '{col}',
@@ -565,25 +587,12 @@ class OracleEngineSpec(PostgresBaseEngineSpec):
"""TO_TIMESTAMP('{}', 'YYYY-MM-DD"T"HH24:MI:SS.ff6')"""
).format(dttm.isoformat())
- @staticmethod
- def mutate_label(label):
- """
- Oracle 12.1 and earlier support a maximum of 30 byte length object
names, which
- usually means 30 characters.
- :param str label: Original label which might include unsupported
characters
- :return: String that is supported by the database
- """
- if len(label) > 30:
- hashed_label = hashlib.md5(label.encode('utf-8')).hexdigest()
- # truncate the hash to first 30 characters
- return hashed_label[:30]
- return label
-
class Db2EngineSpec(BaseEngineSpec):
engine = 'ibm_db_sa'
limit_method = LimitMethod.WRAP_SQL
force_column_alias_quotes = True
+ max_column_name_length = 30
time_grain_functions = {
None: '{col}',
@@ -618,20 +627,6 @@ class Db2EngineSpec(BaseEngineSpec):
def convert_dttm(cls, target_type, dttm):
return "'{}'".format(dttm.strftime('%Y-%m-%d-%H.%M.%S'))
- @staticmethod
- def mutate_label(label):
- """
- Db2 for z/OS supports a maximum of 30 byte length object names, which
usually
- means 30 characters.
- :param str label: Original label which might include unsupported
characters
- :return: String that is supported by the database
- """
- if len(label) > 30:
- hashed_label = hashlib.md5(label.encode('utf-8')).hexdigest()
- # truncate the hash to first 30 characters
- return hashed_label[:30]
- return label
-
class SqliteEngineSpec(BaseEngineSpec):
engine = 'sqlite'
@@ -668,6 +663,9 @@ class SqliteEngineSpec(BaseEngineSpec):
schema=schema, force=True,
cache=db.table_cache_enabled,
cache_timeout=db.table_cache_timeout)
+ else:
+ raise Exception(f'Unsupported datasource_type: {datasource_type}')
+
all_result_sets += [
'{}.{}'.format(schema, t) for t in all_datasource_names]
return all_result_sets
@@ -687,6 +685,7 @@ class SqliteEngineSpec(BaseEngineSpec):
class MySQLEngineSpec(BaseEngineSpec):
engine = 'mysql'
+ max_column_name_length = 64
time_grain_functions = {
None: '{col}',
@@ -1060,6 +1059,7 @@ class HiveEngineSpec(PrestoEngineSpec):
"""Reuses PrestoEngineSpec functionality."""
engine = 'hive'
+ max_column_name_length = 767
# Scoping regex at class level to avoid recompiling
# 17/02/07 19:36:38 INFO ql.Driver: Total jobs = 5
@@ -1366,6 +1366,7 @@ class MssqlEngineSpec(BaseEngineSpec):
engine = 'mssql'
epoch_to_dttm = "dateadd(S, {col}, '1970-01-01')"
limit_method = LimitMethod.WRAP_SQL
+ max_column_name_length = 128
time_grain_functions = {
None: '{col}',
@@ -1434,11 +1435,21 @@ class AthenaEngineSpec(BaseEngineSpec):
def epoch_to_dttm(cls):
return 'from_unixtime({col})'
+ @staticmethod
+ def mutate_label(label):
+ """
+ Athena only supports lowercase column names and aliases.
+ :param str label: Original label which might include uppercase letters
+ :return: String that is supported by the database
+ """
+ return label.lower()
+
class PinotEngineSpec(BaseEngineSpec):
engine = 'pinot'
allows_subquery = False
inner_joins = False
+ supports_column_aliases = False
_time_grain_to_datetimeconvert = {
'PT1S': '1:SECONDS',
@@ -1481,17 +1492,6 @@ class PinotEngineSpec(BaseEngineSpec):
select_sans_groupby.append(s)
return select_sans_groupby
- @classmethod
- def mutate_df_columns(cls, df, sql, labels_expected):
- if df is not None and \
- not df.empty and \
- labels_expected is not None:
- if len(df.columns) != len(labels_expected):
- raise Exception(f'For {sql}, df.columns: {df.columns}'
- f' differs from {labels_expected}')
- else:
- df.columns = labels_expected
-
class ClickHouseEngineSpec(BaseEngineSpec):
"""Dialect for ClickHouse analytical DB."""
@@ -1532,6 +1532,7 @@ class BQEngineSpec(BaseEngineSpec):
As contributed by @mxmzdlv on issue #945"""
engine = 'bigquery'
+ max_column_name_length = 128
"""
https://www.python.org/dev/peps/pep-0249/#arraysize
@@ -1574,28 +1575,33 @@ class BQEngineSpec(BaseEngineSpec):
@staticmethod
def mutate_label(label):
"""
- BigQuery field_name should start with a letter or underscore, contain
only
- alphanumeric characters and be at most 128 characters long. Labels
that start
- with a number are prefixed with an underscore. Any unsupported
characters are
- replaced with underscores and an md5 hash is added to the end of the
label to
- avoid possible collisions. If the resulting label exceeds 128
characters, only
- the md5 sum is returned.
+ BigQuery field_name should start with a letter or underscore and
contain only
+ alphanumeric characters. Labels that start with a number are prefixed
with an
+ underscore. Any unsupported characters are replaced with underscores
and an
+ md5 hash is added to the end of the label to avoid possible collisions.
:param str label: the original label which might include unsupported
characters
:return: String that is supported by the database
"""
- hashed_label = '_' + hashlib.md5(label.encode('utf-8')).hexdigest()
+ label_hashed = '_' + hashlib.md5(label.encode('utf-8')).hexdigest()
# if label starts with number, add underscore as first character
- mutated_label = '_' + label if re.match(r'^\d', label) else label
+ label_mutated = '_' + label if re.match(r'^\d', label) else label
# replace non-alphanumeric characters with underscores
- mutated_label = re.sub(r'[^\w]+', '_', mutated_label)
- if mutated_label != label:
+ label_mutated = re.sub(r'[^\w]+', '_', label_mutated)
+ if label_mutated != label:
# add md5 hash to label to avoid possible collisions
- mutated_label += hashed_label
+ label_mutated += label_hashed
+
+ return label_mutated
- # return only hash if length of final label exceeds 128 chars
- return mutated_label if len(mutated_label) <= 128 else hashed_label
+ @classmethod
+ def truncate_label(cls, label):
+ """BigQuery requires that column names start with either a letter or
+ underscore. To make sure this is always the case, an underscore is
prefixed
+ to the truncated label.
+ """
+ return '_' + hashlib.md5(label.encode('utf-8')).hexdigest()
@classmethod
def extra_table_metadata(cls, database, table_name, schema_name):
@@ -1727,6 +1733,7 @@ class TeradataEngineSpec(BaseEngineSpec):
"""Dialect for Teradata DB."""
engine = 'teradata'
limit_method = LimitMethod.WRAP_SQL
+ max_column_name_length = 30 # since 14.10 this is 128
time_grain_functions = {
None: '{col}',
diff --git a/tests/db_engine_specs_test.py b/tests/db_engine_specs_test.py
index c2f1713b..a48012d 100644
--- a/tests/db_engine_specs_test.py
+++ b/tests/db_engine_specs_test.py
@@ -17,11 +17,12 @@
import inspect
import mock
+from sqlalchemy import column
from superset import db_engine_specs
from superset.db_engine_specs import (
- BaseEngineSpec, HiveEngineSpec, MssqlEngineSpec,
- MySQLEngineSpec, PrestoEngineSpec,
+ BaseEngineSpec, BQEngineSpec, HiveEngineSpec, MssqlEngineSpec,
+ MySQLEngineSpec, OracleEngineSpec, PrestoEngineSpec,
)
from superset.models.core import Database
from .base_tests import SupersetTestCase
@@ -307,3 +308,27 @@ class DbEngineSpecsTestCase(SupersetTestCase):
def test_hive_get_view_names_return_empty_list(self):
self.assertEquals([], HiveEngineSpec.get_view_names(mock.ANY,
mock.ANY))
+
+ def test_bigquery_sqla_column_label(self):
+ label = BQEngineSpec.make_label_compatible(column('Col').name)
+ label_expected = 'Col'
+ self.assertEqual(label, label_expected)
+
+ label = BQEngineSpec.make_label_compatible(column('SUM(x)').name)
+ label_expected = 'SUM_x__5f110b965a993675bc4953bb3e03c4a5'
+ self.assertEqual(label, label_expected)
+
+ label = BQEngineSpec.make_label_compatible(column('SUM[x]').name)
+ label_expected = 'SUM_x__7ebe14a3f9534aeee125449b0bc083a8'
+ self.assertEqual(label, label_expected)
+
+ label = BQEngineSpec.make_label_compatible(column('12345_col').name)
+ label_expected = '_12345_col_8d3906e2ea99332eb185f7f8ecb2ffd6'
+ self.assertEqual(label, label_expected)
+
+ def test_oracle_sqla_column_name_length_exceeded(self):
+ col = column('This_Is_32_Character_Column_Name')
+ label = OracleEngineSpec.make_label_compatible(col.name)
+ self.assertEqual(label.quote, True)
+ label_expected = '3b26974078683be078219674eeb8f5'
+ self.assertEqual(label, label_expected)