This is an automated email from the ASF dual-hosted git repository.
maximebeauchemin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-superset.git
The following commit(s) were added to refs/heads/master by this push:
new 7ee8afb Improve support for BigQuery, Redshift, Oracle, Db2,
Snowflake (#5827)
7ee8afb is described below
commit 7ee8afb608e1251e3c2f7617fb0ae1a5340ae1a5
Author: Ville Brofeldt <[email protected]>
AuthorDate: Fri Jan 18 18:24:11 2019 +0200
Improve support for BigQuery, Redshift, Oracle, Db2, Snowflake (#5827)
* Conditionally mutate and quote sqla labels; decouple sqla logic from viz.py
* Prefix hashed label with underscore if bigquery label exceeds 128 chars
* Add comments for label cache
* Rename to mutated_labels and simplify
* Rename mutated_label to get_label and simplify make_label_compatible in
db_engine_specs
* Add note about deterministic and unique mutated labels
* Add hash to label that has been prefixed with underscore
* Fix PEP8 escape warning
* Fix DeckPathViz get_metric_label call
---
superset/connectors/sqla/models.py | 56 +++++++++++++++-------
superset/db_engine_specs.py | 98 +++++++++++++++++++++++++++++++-------
superset/viz.py | 51 ++++++++------------
3 files changed, 138 insertions(+), 67 deletions(-)
diff --git a/superset/connectors/sqla/models.py
b/superset/connectors/sqla/models.py
index 10591d3..dff4559 100644
--- a/superset/connectors/sqla/models.py
+++ b/superset/connectors/sqla/models.py
@@ -112,8 +112,8 @@ class TableColumn(Model, BaseColumn):
export_parent = 'table'
def get_sqla_col(self, label=None):
- db_engine_spec = self.table.database.db_engine_spec
- label = db_engine_spec.make_label_compatible(label if label else
self.column_name)
+ label = label if label else self.column_name
+ label = self.table.get_label(label)
if not self.expression:
col = column(self.column_name).label(label)
else:
@@ -135,10 +135,12 @@ class TableColumn(Model, BaseColumn):
def get_timestamp_expression(self, time_grain):
"""Getting the time component of the query"""
+ label = self.table.get_label(utils.DTTM_ALIAS)
+
pdf = self.python_date_format
is_epoch = pdf in ('epoch_s', 'epoch_ms')
if not self.expression and not time_grain and not is_epoch:
- return column(self.column_name,
type_=DateTime).label(utils.DTTM_ALIAS)
+ return column(self.column_name, type_=DateTime).label(label)
expr = self.expression or self.column_name
if is_epoch:
@@ -152,7 +154,7 @@ class TableColumn(Model, BaseColumn):
grain = self.table.database.grains_dict().get(time_grain)
if grain:
expr = grain.function.format(col=expr)
- return literal_column(expr, type_=DateTime).label(utils.DTTM_ALIAS)
+ return literal_column(expr, type_=DateTime).label(label)
@classmethod
def import_obj(cls, i_column):
@@ -207,8 +209,8 @@ class SqlMetric(Model, BaseMetric):
export_parent = 'table'
def get_sqla_col(self, label=None):
- db_engine_spec = self.table.database.db_engine_spec
- label = db_engine_spec.make_label_compatible(label if label else
self.metric_name)
+ label = label if label else self.metric_name
+ label = self.table.get_label(label)
return literal_column(self.expression).label(label)
@property
@@ -287,6 +289,21 @@ class SqlaTable(Model, BaseDatasource):
'MAX': sa.func.MAX,
}
+ def get_label(self, label):
+ """Conditionally mutate a label to conform to db engine requirements
+ and store mapping from mutated label to original label
+
+ :param label: original label
+ :return: Either a string or sqlalchemy.sql.elements.quoted_name if
required
+ by db engine
+ """
+ db_engine_spec = self.database.db_engine_spec
+ sqla_label = db_engine_spec.make_label_compatible(label)
+ mutated_label = str(sqla_label)
+ if label != mutated_label:
+ self.mutated_labels[mutated_label] = label
+ return sqla_label
+
def __repr__(self):
return self.name
@@ -486,8 +503,8 @@ class SqlaTable(Model, BaseDatasource):
:rtype: sqlalchemy.sql.column
"""
expression_type = metric.get('expressionType')
- db_engine_spec = self.database.db_engine_spec
- label = db_engine_spec.make_label_compatible(metric.get('label'))
+ label = utils.get_metric_name(metric)
+ label = self.get_label(label)
if expression_type == utils.ADHOC_METRIC_EXPRESSION_TYPES['SIMPLE']:
column_name = metric.get('column').get('column_name')
@@ -540,6 +557,9 @@ class SqlaTable(Model, BaseDatasource):
template_processor = self.get_template_processor(**template_kwargs)
db_engine_spec = self.database.db_engine_spec
+ # Initialize empty cache to store mutated labels
+ self.mutated_labels = {}
+
orderby = orderby or []
# For backward compatibility
@@ -569,8 +589,8 @@ class SqlaTable(Model, BaseDatasource):
if metrics_exprs:
main_metric_expr = metrics_exprs[0]
else:
- main_metric_expr = literal_column('COUNT(*)').label(
- db_engine_spec.make_label_compatible('count'))
+ label = self.get_label('ccount')
+ main_metric_expr = literal_column('COUNT(*)').label(label)
select_exprs = []
groupby_exprs = []
@@ -695,7 +715,8 @@ class SqlaTable(Model, BaseDatasource):
# some sql dialects require for order by expressions
# to also be in the select clause -- others, e.g. vertica,
# require a unique inner alias
- inner_main_metric_expr = main_metric_expr.label('mme_inner__')
+ label = self.get_label('mme_inner__')
+ inner_main_metric_expr = main_metric_expr.label(label)
inner_select_exprs += [inner_main_metric_expr]
subq = select(inner_select_exprs)
subq = subq.select_from(tbl)
@@ -723,8 +744,11 @@ class SqlaTable(Model, BaseDatasource):
on_clause = []
for i, gb in enumerate(groupby):
- on_clause.append(
- groupby_exprs[i] == column(gb + '__'))
+ # in this case the column name, not the alias, needs to be
+ # conditionally mutated, as it refers to the column alias
in
+ # the inner query
+ col_name = self.get_label(gb + '__')
+ on_clause.append(groupby_exprs[i] == column(col_name))
tbl = tbl.join(subq.alias(), and_(*on_clause))
else:
@@ -776,6 +800,8 @@ class SqlaTable(Model, BaseDatasource):
df = None
try:
df = self.database.get_df(sql, self.schema)
+ if self.mutated_labels:
+ df = df.rename(index=str, columns=self.mutated_labels)
except Exception as e:
status = utils.QueryStatus.FAILED
logging.exception(e)
@@ -818,7 +844,6 @@ class SqlaTable(Model, BaseDatasource):
.filter(or_(TableColumn.column_name == col.name
for col in table.columns)))
dbcols = {dbcol.column_name: dbcol for dbcol in dbcols}
- db_engine_spec = self.database.db_engine_spec
for col in table.columns:
try:
@@ -850,9 +875,6 @@ class SqlaTable(Model, BaseDatasource):
))
if not self.main_dttm_col:
self.main_dttm_col = any_date_col
- for metric in metrics:
- metric.metric_name = db_engine_spec.mutate_expression_label(
- metric.metric_name)
self.add_missing_metrics(metrics)
db.session.merge(self)
db.session.commit()
diff --git a/superset/db_engine_specs.py b/superset/db_engine_specs.py
index 1a9bf81..379c0cf 100644
--- a/superset/db_engine_specs.py
+++ b/superset/db_engine_specs.py
@@ -29,6 +29,7 @@ at all. The classes here will use a common interface to
specify all this.
The general idea is to use static classes and an inheritance scheme.
"""
from collections import namedtuple
+import hashlib
import inspect
import logging
import os
@@ -392,16 +393,26 @@ class BaseEngineSpec(object):
@classmethod
def make_label_compatible(cls, label):
"""
- Return a sqlalchemy.sql.elements.quoted_name if the engine requires
- quoting of aliases to ensure that select query and query results
- have same case.
+ Conditionally mutate and/or quote a sql column/expression label. If
+ force_column_alias_quotes is set to True, return the label as a
+ sqlalchemy.sql.elements.quoted_name object to ensure that the select
query
+ and query results have same case. Otherwise return the mutated label
as a
+ regular string.
"""
- if cls.force_column_alias_quotes is True:
- return quoted_name(label, True)
- return label
+ label = cls.mutate_label(label)
+ return quoted_name(label, True) if cls.force_column_alias_quotes else
label
@staticmethod
- def mutate_expression_label(label):
+ def mutate_label(label):
+ """
+ Most engines support mixed case aliases that can include numbers
+ and special characters, like commas, parentheses etc. For engines that
+ have restrictions on what types of aliases are supported, this method
+ can be overridden to ensure that labels conform to the engine's
+ limitations. Mutated labels should be deterministic (input label A
always
+ yields output label X) and unique (input labels A and B don't yield
the same
+ output label X).
+ """
return label
@@ -490,7 +501,15 @@ class VerticaEngineSpec(PostgresBaseEngineSpec):
class RedshiftEngineSpec(PostgresBaseEngineSpec):
engine = 'redshift'
- force_column_alias_quotes = True
+
+ @staticmethod
+ def mutate_label(label):
+ """
+ Redshift only supports lowercase column names and aliases.
+ :param str label: Original label which might include uppercase letters
+ :return: String that is supported by the database
+ """
+ return label.lower()
class OracleEngineSpec(PostgresBaseEngineSpec):
@@ -516,11 +535,26 @@ class OracleEngineSpec(PostgresBaseEngineSpec):
"""TO_TIMESTAMP('{}', 'YYYY-MM-DD"T"HH24:MI:SS.ff6')"""
).format(dttm.isoformat())
+ @staticmethod
+ def mutate_label(label):
+ """
+ Oracle 12.1 and earlier support a maximum of 30 byte length object
names, which
+ usually means 30 characters.
+ :param str label: Original label which might include unsupported
characters
+ :return: String that is supported by the database
+ """
+ if len(label) > 30:
+ hashed_label = hashlib.md5(label.encode('utf-8')).hexdigest()
+ # truncate the hash to first 30 characters
+ return hashed_label[:30]
+ return label
+
class Db2EngineSpec(BaseEngineSpec):
engine = 'ibm_db_sa'
limit_method = LimitMethod.WRAP_SQL
force_column_alias_quotes = True
+
time_grain_functions = {
None: '{col}',
'PT1S': 'CAST({col} as TIMESTAMP)'
@@ -554,6 +588,20 @@ class Db2EngineSpec(BaseEngineSpec):
def convert_dttm(cls, target_type, dttm):
return "'{}'".format(dttm.strftime('%Y-%m-%d-%H.%M.%S'))
+ @staticmethod
+ def mutate_label(label):
+ """
+ Db2 for z/OS supports a maximum of 30 byte length object names, which
usually
+ means 30 characters.
+ :param str label: Original label which might include unsupported
characters
+ :return: String that is supported by the database
+ """
+ if len(label) > 30:
+ hashed_label = hashlib.md5(label.encode('utf-8')).hexdigest()
+ # truncate the hash to first 30 characters
+ return hashed_label[:30]
+ return label
+
class SqliteEngineSpec(BaseEngineSpec):
engine = 'sqlite'
@@ -1424,16 +1472,30 @@ class BQEngineSpec(BaseEngineSpec):
return data
@staticmethod
- def mutate_expression_label(label):
- mutated_label = re.sub('[^\w]+', '_', label)
- if not re.match('^[a-zA-Z_]+.*', mutated_label):
- raise SupersetTemplateException('BigQuery field_name used is
invalid {}, '
- 'should start with a letter or '
- 'underscore'.format(mutated_label))
- if len(mutated_label) > 128:
- raise SupersetTemplateException('BigQuery field_name {}, should be
atmost '
- '128
characters'.format(mutated_label))
- return mutated_label
+ def mutate_label(label):
+ """
+ BigQuery field_name should start with a letter or underscore, contain
only
+ alphanumeric characters and be at most 128 characters long. Labels
that start
+ with a number are prefixed with an underscore. Any unsupported
characters are
+ replaced with underscores and an md5 hash is added to the end of the
label to
+ avoid possible collisions. If the resulting label exceeds 128
characters, only
+ the md5 sum is returned.
+ :param str label: the original label which might include unsupported
characters
+ :return: String that is supported by the database
+ """
+ hashed_label = '_' + hashlib.md5(label.encode('utf-8')).hexdigest()
+
+ # if label starts with number, add underscore as first character
+ mutated_label = '_' + label if re.match(r'^\d', label) else label
+
+ # replace non-alphanumeric characters with underscores
+ mutated_label = re.sub(r'[^\w]+', '_', mutated_label)
+ if mutated_label != label:
+ # add md5 hash to label to avoid possible collisions
+ mutated_label += hashed_label
+
+ # return only hash if length of final label exceeds 128 chars
+ return mutated_label if len(mutated_label) <= 128 else hashed_label
@classmethod
def extra_table_metadata(cls, database, table_name, schema_name):
diff --git a/superset/viz.py b/superset/viz.py
index 7cf6465..3bdeb79 100644
--- a/superset/viz.py
+++ b/superset/viz.py
@@ -121,27 +121,13 @@ class BaseViz(object):
if not isinstance(val, list):
val = [val]
for o in val:
- label = self.get_metric_label(o)
- if isinstance(o, dict):
- o['label'] = label
+ label = utils.get_metric_name(o)
self.metric_dict[label] = o
# Cast to list needed to return serializable object in py3
self.all_metrics = list(self.metric_dict.values())
self.metric_labels = list(self.metric_dict.keys())
- def get_metric_label(self, metric):
- if isinstance(metric, str):
- return metric
-
- if isinstance(metric, dict):
- metric = metric.get('label')
-
- if self.datasource.type == 'table':
- db_engine_spec = self.datasource.database.db_engine_spec
- metric = db_engine_spec.mutate_expression_label(metric)
- return metric
-
@staticmethod
def handle_js_int_overflow(data):
for d in data.get('records', dict()):
@@ -577,7 +563,7 @@ class TableViz(BaseViz):
# Sum up and compute percentages for all percent metrics
percent_metrics = fd.get('percent_metrics') or []
- percent_metrics = [self.get_metric_label(m) for m in percent_metrics]
+ percent_metrics = [utils.get_metric_name(m) for m in percent_metrics]
if len(percent_metrics):
percent_metrics = list(filter(lambda m: m in df, percent_metrics))
@@ -595,7 +581,7 @@ class TableViz(BaseViz):
df[m_name] = pd.Series(metric_percents[m], name=m_name)
# Remove metrics that are not in the main metrics list
metrics = fd.get('metrics') or []
- metrics = [self.get_metric_label(m) for m in metrics]
+ metrics = [utils.get_metric_name(m) for m in metrics]
for m in filter(
lambda m: m not in metrics and m in df.columns,
percent_metrics,
@@ -695,7 +681,7 @@ class PivotTableViz(BaseViz):
df = df.pivot_table(
index=self.form_data.get('groupby'),
columns=self.form_data.get('columns'),
- values=[self.get_metric_label(m) for m in
self.form_data.get('metrics')],
+ values=[utils.get_metric_name(m) for m in
self.form_data.get('metrics')],
aggfunc=self.form_data.get('pandas_aggfunc'),
margins=self.form_data.get('pivot_margins'),
)
@@ -1030,7 +1016,7 @@ class BulletViz(NVD3Viz):
def get_data(self, df):
df = df.fillna(0)
- df['metric'] = df[[self.get_metric_label(self.metric)]]
+ df['metric'] = df[[utils.get_metric_name(self.metric)]]
values = df['metric'].values
return {
'measures': values.tolist(),
@@ -1150,6 +1136,7 @@ class NVD3TimeSeriesViz(NVD3Viz):
df = df.fillna(0)
if fd.get('granularity') == 'all':
raise Exception(_('Pick a time granularity for your time series'))
+
if not aggregate:
df = df.pivot_table(
index=DTTM_ALIAS,
@@ -1365,8 +1352,8 @@ class NVD3DualLineViz(NVD3Viz):
if self.form_data.get('granularity') == 'all':
raise Exception(_('Pick a time granularity for your time series'))
- metric = self.get_metric_label(fd.get('metric'))
- metric_2 = self.get_metric_label(fd.get('metric_2'))
+ metric = utils.get_metric_name(fd.get('metric'))
+ metric_2 = utils.get_metric_name(fd.get('metric_2'))
df = df.pivot_table(
index=DTTM_ALIAS,
values=[metric, metric_2])
@@ -1417,7 +1404,7 @@ class NVD3TimePivotViz(NVD3TimeSeriesViz):
df = df.pivot_table(
index=DTTM_ALIAS,
columns='series',
- values=self.get_metric_label(fd.get('metric')))
+ values=utils.get_metric_name(fd.get('metric')))
chart_data = self.to_series(df)
for serie in chart_data:
serie['rank'] = rank_lookup[serie['key']]
@@ -1589,8 +1576,8 @@ class SunburstViz(BaseViz):
def get_data(self, df):
fd = self.form_data
cols = fd.get('groupby')
- metric = self.get_metric_label(fd.get('metric'))
- secondary_metric = self.get_metric_label(fd.get('secondary_metric'))
+ metric = utils.get_metric_name(fd.get('metric'))
+ secondary_metric = utils.get_metric_name(fd.get('secondary_metric'))
if metric == secondary_metric or secondary_metric is None:
df.columns = cols + ['m1']
df['m2'] = df['m1']
@@ -1691,7 +1678,7 @@ class ChordViz(BaseViz):
qry = super(ChordViz, self).query_obj()
fd = self.form_data
qry['groupby'] = [fd.get('groupby'), fd.get('columns')]
- qry['metrics'] = [self.get_metric_label(fd.get('metric'))]
+ qry['metrics'] = [utils.get_metric_name(fd.get('metric'))]
return qry
def get_data(self, df):
@@ -1757,8 +1744,8 @@ class WorldMapViz(BaseViz):
from superset.data import countries
fd = self.form_data
cols = [fd.get('entity')]
- metric = self.get_metric_label(fd.get('metric'))
- secondary_metric = self.get_metric_label(fd.get('secondary_metric'))
+ metric = utils.get_metric_name(fd.get('metric'))
+ secondary_metric = utils.get_metric_name(fd.get('secondary_metric'))
columns = ['country', 'm1', 'm2']
if metric == secondary_metric:
ndf = df[cols]
@@ -2289,7 +2276,7 @@ class DeckScatterViz(BaseDeckGLViz):
def get_data(self, df):
fd = self.form_data
self.metric_label = \
- self.get_metric_label(self.metric) if self.metric else None
+ utils.get_metric_name(self.metric) if self.metric else None
self.point_radius_fixed = fd.get('point_radius_fixed')
self.fixed_value = None
self.dim = self.form_data.get('dimension')
@@ -2320,7 +2307,7 @@ class DeckScreengrid(BaseDeckGLViz):
}
def get_data(self, df):
- self.metric_label = self.get_metric_label(self.metric)
+ self.metric_label = utils.get_metric_name(self.metric)
return super(DeckScreengrid, self).get_data(df)
@@ -2339,7 +2326,7 @@ class DeckGrid(BaseDeckGLViz):
}
def get_data(self, df):
- self.metric_label = self.get_metric_label(self.metric)
+ self.metric_label = utils.get_metric_name(self.metric)
return super(DeckGrid, self).get_data(df)
@@ -2397,7 +2384,7 @@ class DeckPathViz(BaseDeckGLViz):
return d
def get_data(self, df):
- self.metric_label = self.get_metric_label(self.metric)
+ self.metric_label = utils.get_metric_name(self.metric)
return super(DeckPathViz, self).get_data(df)
@@ -2445,7 +2432,7 @@ class DeckHex(BaseDeckGLViz):
}
def get_data(self, df):
- self.metric_label = self.get_metric_label(self.metric)
+ self.metric_label = utils.get_metric_name(self.metric)
return super(DeckHex, self).get_data(df)