This is an automated email from the ASF dual-hosted git repository.

maximebeauchemin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-superset.git


The following commit(s) were added to refs/heads/master by this push:
     new 7ee8afb  Improve support for BigQuery, Redshift, Oracle, Db2, Snowflake (#5827)
7ee8afb is described below

commit 7ee8afb608e1251e3c2f7617fb0ae1a5340ae1a5
Author: Ville Brofeldt <[email protected]>
AuthorDate: Fri Jan 18 18:24:11 2019 +0200

    Improve support for BigQuery, Redshift, Oracle, Db2, Snowflake (#5827)
    
    * Conditionally mutate and quote sqla labels; decouple sqla logic from viz.py
    
    * Prefix hashed label with underscore if bigquery label exceeds 128 chars
    
    * Add comments for label cache
    
    * Rename to mutated_labels and simplify
    
    * Rename mutated_label to get_label and simplify make_label_compatible in db_engine_specs
    
    * Add note about deterministic and unique mutated labels
    
    * add hash to label that has been prefixed with underscore
    
    * Fix PEP8 escape warning
    
    * Fix DeckPathViz get_metric_label call
---
 superset/connectors/sqla/models.py | 56 +++++++++++++++-------
 superset/db_engine_specs.py        | 98 +++++++++++++++++++++++++++++++-------
 superset/viz.py                    | 51 ++++++++------------
 3 files changed, 138 insertions(+), 67 deletions(-)

diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py
index 10591d3..dff4559 100644
--- a/superset/connectors/sqla/models.py
+++ b/superset/connectors/sqla/models.py
@@ -112,8 +112,8 @@ class TableColumn(Model, BaseColumn):
     export_parent = 'table'
 
     def get_sqla_col(self, label=None):
-        db_engine_spec = self.table.database.db_engine_spec
-        label = db_engine_spec.make_label_compatible(label if label else 
self.column_name)
+        label = label if label else self.column_name
+        label = self.table.get_label(label)
         if not self.expression:
             col = column(self.column_name).label(label)
         else:
@@ -135,10 +135,12 @@ class TableColumn(Model, BaseColumn):
 
     def get_timestamp_expression(self, time_grain):
         """Getting the time component of the query"""
+        label = self.table.get_label(utils.DTTM_ALIAS)
+
         pdf = self.python_date_format
         is_epoch = pdf in ('epoch_s', 'epoch_ms')
         if not self.expression and not time_grain and not is_epoch:
-            return column(self.column_name, 
type_=DateTime).label(utils.DTTM_ALIAS)
+            return column(self.column_name, type_=DateTime).label(label)
 
         expr = self.expression or self.column_name
         if is_epoch:
@@ -152,7 +154,7 @@ class TableColumn(Model, BaseColumn):
             grain = self.table.database.grains_dict().get(time_grain)
             if grain:
                 expr = grain.function.format(col=expr)
-        return literal_column(expr, type_=DateTime).label(utils.DTTM_ALIAS)
+        return literal_column(expr, type_=DateTime).label(label)
 
     @classmethod
     def import_obj(cls, i_column):
@@ -207,8 +209,8 @@ class SqlMetric(Model, BaseMetric):
     export_parent = 'table'
 
     def get_sqla_col(self, label=None):
-        db_engine_spec = self.table.database.db_engine_spec
-        label = db_engine_spec.make_label_compatible(label if label else 
self.metric_name)
+        label = label if label else self.metric_name
+        label = self.table.get_label(label)
         return literal_column(self.expression).label(label)
 
     @property
@@ -287,6 +289,21 @@ class SqlaTable(Model, BaseDatasource):
         'MAX': sa.func.MAX,
     }
 
+    def get_label(self, label):
+        """Conditionally mutate a label to conform to db engine requirements
+        and store mapping from mutated label to original label
+
+        :param label: original label
+        :return: Either a string or sqlalchemy.sql.elements.quoted_name if 
required
+        by db engine
+        """
+        db_engine_spec = self.database.db_engine_spec
+        sqla_label = db_engine_spec.make_label_compatible(label)
+        mutated_label = str(sqla_label)
+        if label != mutated_label:
+            self.mutated_labels[mutated_label] = label
+        return sqla_label
+
     def __repr__(self):
         return self.name
 
@@ -486,8 +503,8 @@ class SqlaTable(Model, BaseDatasource):
         :rtype: sqlalchemy.sql.column
         """
         expression_type = metric.get('expressionType')
-        db_engine_spec = self.database.db_engine_spec
-        label = db_engine_spec.make_label_compatible(metric.get('label'))
+        label = utils.get_metric_name(metric)
+        label = self.get_label(label)
 
         if expression_type == utils.ADHOC_METRIC_EXPRESSION_TYPES['SIMPLE']:
             column_name = metric.get('column').get('column_name')
@@ -540,6 +557,9 @@ class SqlaTable(Model, BaseDatasource):
         template_processor = self.get_template_processor(**template_kwargs)
         db_engine_spec = self.database.db_engine_spec
 
+        # Initialize empty cache to store mutated labels
+        self.mutated_labels = {}
+
         orderby = orderby or []
 
         # For backward compatibility
@@ -569,8 +589,8 @@ class SqlaTable(Model, BaseDatasource):
         if metrics_exprs:
             main_metric_expr = metrics_exprs[0]
         else:
-            main_metric_expr = literal_column('COUNT(*)').label(
-                db_engine_spec.make_label_compatible('count'))
+            label = self.get_label('ccount')
+            main_metric_expr = literal_column('COUNT(*)').label(label)
 
         select_exprs = []
         groupby_exprs = []
@@ -695,7 +715,8 @@ class SqlaTable(Model, BaseDatasource):
                 # some sql dialects require for order by expressions
                 # to also be in the select clause -- others, e.g. vertica,
                 # require a unique inner alias
-                inner_main_metric_expr = main_metric_expr.label('mme_inner__')
+                label = self.get_label('mme_inner__')
+                inner_main_metric_expr = main_metric_expr.label(label)
                 inner_select_exprs += [inner_main_metric_expr]
                 subq = select(inner_select_exprs)
                 subq = subq.select_from(tbl)
@@ -723,8 +744,11 @@ class SqlaTable(Model, BaseDatasource):
 
                 on_clause = []
                 for i, gb in enumerate(groupby):
-                    on_clause.append(
-                        groupby_exprs[i] == column(gb + '__'))
+                    # in this case the column name, not the alias, needs to be
+                    # conditionally mutated, as it refers to the column alias 
in
+                    # the inner query
+                    col_name = self.get_label(gb + '__')
+                    on_clause.append(groupby_exprs[i] == column(col_name))
 
                 tbl = tbl.join(subq.alias(), and_(*on_clause))
             else:
@@ -776,6 +800,8 @@ class SqlaTable(Model, BaseDatasource):
         df = None
         try:
             df = self.database.get_df(sql, self.schema)
+            if self.mutated_labels:
+                df = df.rename(index=str, columns=self.mutated_labels)
         except Exception as e:
             status = utils.QueryStatus.FAILED
             logging.exception(e)
@@ -818,7 +844,6 @@ class SqlaTable(Model, BaseDatasource):
             .filter(or_(TableColumn.column_name == col.name
                         for col in table.columns)))
         dbcols = {dbcol.column_name: dbcol for dbcol in dbcols}
-        db_engine_spec = self.database.db_engine_spec
 
         for col in table.columns:
             try:
@@ -850,9 +875,6 @@ class SqlaTable(Model, BaseDatasource):
         ))
         if not self.main_dttm_col:
             self.main_dttm_col = any_date_col
-        for metric in metrics:
-            metric.metric_name = db_engine_spec.mutate_expression_label(
-                metric.metric_name)
         self.add_missing_metrics(metrics)
         db.session.merge(self)
         db.session.commit()
diff --git a/superset/db_engine_specs.py b/superset/db_engine_specs.py
index 1a9bf81..379c0cf 100644
--- a/superset/db_engine_specs.py
+++ b/superset/db_engine_specs.py
@@ -29,6 +29,7 @@ at all. The classes here will use a common interface to specify all this.
 The general idea is to use static classes and an inheritance scheme.
 """
 from collections import namedtuple
+import hashlib
 import inspect
 import logging
 import os
@@ -392,16 +393,26 @@ class BaseEngineSpec(object):
     @classmethod
     def make_label_compatible(cls, label):
         """
-        Return a sqlalchemy.sql.elements.quoted_name if the engine requires
-        quoting of aliases to ensure that select query and query results
-        have same case.
+        Conditionally mutate and/or quote a sql column/expression label. If
+        force_column_alias_quotes is set to True, return the label as a
+        sqlalchemy.sql.elements.quoted_name object to ensure that the select 
query
+        and query results have same case. Otherwise return the mutated label 
as a
+        regular string.
         """
-        if cls.force_column_alias_quotes is True:
-            return quoted_name(label, True)
-        return label
+        label = cls.mutate_label(label)
+        return quoted_name(label, True) if cls.force_column_alias_quotes else 
label
 
     @staticmethod
-    def mutate_expression_label(label):
+    def mutate_label(label):
+        """
+        Most engines support mixed case aliases that can include numbers
+        and special characters, like commas, parentheses etc. For engines that
+        have restrictions on what types of aliases are supported, this method
+        can be overridden to ensure that labels conform to the engine's
+        limitations. Mutated labels should be deterministic (input label A 
always
+        yields output label X) and unique (input labels A and B don't yield 
the same
+        output label X).
+        """
         return label
 
 
@@ -490,7 +501,15 @@ class VerticaEngineSpec(PostgresBaseEngineSpec):
 
 class RedshiftEngineSpec(PostgresBaseEngineSpec):
     engine = 'redshift'
-    force_column_alias_quotes = True
+
+    @staticmethod
+    def mutate_label(label):
+        """
+        Redshift only supports lowercase column names and aliases.
+        :param str label: Original label which might include uppercase letters
+        :return: String that is supported by the database
+        """
+        return label.lower()
 
 
 class OracleEngineSpec(PostgresBaseEngineSpec):
@@ -516,11 +535,26 @@ class OracleEngineSpec(PostgresBaseEngineSpec):
             """TO_TIMESTAMP('{}', 'YYYY-MM-DD"T"HH24:MI:SS.ff6')"""
         ).format(dttm.isoformat())
 
+    @staticmethod
+    def mutate_label(label):
+        """
+        Oracle 12.1 and earlier support a maximum of 30 byte length object 
names, which
+        usually means 30 characters.
+        :param str label: Original label which might include unsupported 
characters
+        :return: String that is supported by the database
+        """
+        if len(label) > 30:
+            hashed_label = hashlib.md5(label.encode('utf-8')).hexdigest()
+            # truncate the hash to first 30 characters
+            return hashed_label[:30]
+        return label
+
 
 class Db2EngineSpec(BaseEngineSpec):
     engine = 'ibm_db_sa'
     limit_method = LimitMethod.WRAP_SQL
     force_column_alias_quotes = True
+
     time_grain_functions = {
         None: '{col}',
         'PT1S': 'CAST({col} as TIMESTAMP)'
@@ -554,6 +588,20 @@ class Db2EngineSpec(BaseEngineSpec):
     def convert_dttm(cls, target_type, dttm):
         return "'{}'".format(dttm.strftime('%Y-%m-%d-%H.%M.%S'))
 
+    @staticmethod
+    def mutate_label(label):
+        """
+        Db2 for z/OS supports a maximum of 30 byte length object names, which 
usually
+        means 30 characters.
+        :param str label: Original label which might include unsupported 
characters
+        :return: String that is supported by the database
+        """
+        if len(label) > 30:
+            hashed_label = hashlib.md5(label.encode('utf-8')).hexdigest()
+            # truncate the hash to first 30 characters
+            return hashed_label[:30]
+        return label
+
 
 class SqliteEngineSpec(BaseEngineSpec):
     engine = 'sqlite'
@@ -1424,16 +1472,30 @@ class BQEngineSpec(BaseEngineSpec):
         return data
 
     @staticmethod
-    def mutate_expression_label(label):
-        mutated_label = re.sub('[^\w]+', '_', label)
-        if not re.match('^[a-zA-Z_]+.*', mutated_label):
-            raise SupersetTemplateException('BigQuery field_name used is 
invalid {}, '
-                                            'should start with a letter or '
-                                            'underscore'.format(mutated_label))
-        if len(mutated_label) > 128:
-            raise SupersetTemplateException('BigQuery field_name {}, should be 
atmost '
-                                            '128 
characters'.format(mutated_label))
-        return mutated_label
+    def mutate_label(label):
+        """
+        BigQuery field_name should start with a letter or underscore, contain 
only
+        alphanumeric characters and be at most 128 characters long. Labels 
that start
+        with a number are prefixed with an underscore. Any unsupported 
characters are
+        replaced with underscores and an md5 hash is added to the end of the 
label to
+        avoid possible collisions. If the resulting label exceeds 128 
characters, only
+        the md5 sum is returned.
+        :param str label: the original label which might include unsupported 
characters
+        :return: String that is supported by the database
+        """
+        hashed_label = '_' + hashlib.md5(label.encode('utf-8')).hexdigest()
+
+        # if label starts with number, add underscore as first character
+        mutated_label = '_' + label if re.match(r'^\d', label) else label
+
+        # replace non-alphanumeric characters with underscores
+        mutated_label = re.sub(r'[^\w]+', '_', mutated_label)
+        if mutated_label != label:
+            # add md5 hash to label to avoid possible collisions
+            mutated_label += hashed_label
+
+        # return only hash if length of final label exceeds 128 chars
+        return mutated_label if len(mutated_label) <= 128 else hashed_label
 
     @classmethod
     def extra_table_metadata(cls, database, table_name, schema_name):
diff --git a/superset/viz.py b/superset/viz.py
index 7cf6465..3bdeb79 100644
--- a/superset/viz.py
+++ b/superset/viz.py
@@ -121,27 +121,13 @@ class BaseViz(object):
                 if not isinstance(val, list):
                     val = [val]
                 for o in val:
-                    label = self.get_metric_label(o)
-                    if isinstance(o, dict):
-                        o['label'] = label
+                    label = utils.get_metric_name(o)
                     self.metric_dict[label] = o
 
         # Cast to list needed to return serializable object in py3
         self.all_metrics = list(self.metric_dict.values())
         self.metric_labels = list(self.metric_dict.keys())
 
-    def get_metric_label(self, metric):
-        if isinstance(metric, str):
-            return metric
-
-        if isinstance(metric, dict):
-            metric = metric.get('label')
-
-        if self.datasource.type == 'table':
-            db_engine_spec = self.datasource.database.db_engine_spec
-            metric = db_engine_spec.mutate_expression_label(metric)
-        return metric
-
     @staticmethod
     def handle_js_int_overflow(data):
         for d in data.get('records', dict()):
@@ -577,7 +563,7 @@ class TableViz(BaseViz):
 
         # Sum up and compute percentages for all percent metrics
         percent_metrics = fd.get('percent_metrics') or []
-        percent_metrics = [self.get_metric_label(m) for m in percent_metrics]
+        percent_metrics = [utils.get_metric_name(m) for m in percent_metrics]
 
         if len(percent_metrics):
             percent_metrics = list(filter(lambda m: m in df, percent_metrics))
@@ -595,7 +581,7 @@ class TableViz(BaseViz):
                 df[m_name] = pd.Series(metric_percents[m], name=m_name)
             # Remove metrics that are not in the main metrics list
             metrics = fd.get('metrics') or []
-            metrics = [self.get_metric_label(m) for m in metrics]
+            metrics = [utils.get_metric_name(m) for m in metrics]
             for m in filter(
                 lambda m: m not in metrics and m in df.columns,
                 percent_metrics,
@@ -695,7 +681,7 @@ class PivotTableViz(BaseViz):
         df = df.pivot_table(
             index=self.form_data.get('groupby'),
             columns=self.form_data.get('columns'),
-            values=[self.get_metric_label(m) for m in 
self.form_data.get('metrics')],
+            values=[utils.get_metric_name(m) for m in 
self.form_data.get('metrics')],
             aggfunc=self.form_data.get('pandas_aggfunc'),
             margins=self.form_data.get('pivot_margins'),
         )
@@ -1030,7 +1016,7 @@ class BulletViz(NVD3Viz):
 
     def get_data(self, df):
         df = df.fillna(0)
-        df['metric'] = df[[self.get_metric_label(self.metric)]]
+        df['metric'] = df[[utils.get_metric_name(self.metric)]]
         values = df['metric'].values
         return {
             'measures': values.tolist(),
@@ -1150,6 +1136,7 @@ class NVD3TimeSeriesViz(NVD3Viz):
         df = df.fillna(0)
         if fd.get('granularity') == 'all':
             raise Exception(_('Pick a time granularity for your time series'))
+
         if not aggregate:
             df = df.pivot_table(
                 index=DTTM_ALIAS,
@@ -1365,8 +1352,8 @@ class NVD3DualLineViz(NVD3Viz):
         if self.form_data.get('granularity') == 'all':
             raise Exception(_('Pick a time granularity for your time series'))
 
-        metric = self.get_metric_label(fd.get('metric'))
-        metric_2 = self.get_metric_label(fd.get('metric_2'))
+        metric = utils.get_metric_name(fd.get('metric'))
+        metric_2 = utils.get_metric_name(fd.get('metric_2'))
         df = df.pivot_table(
             index=DTTM_ALIAS,
             values=[metric, metric_2])
@@ -1417,7 +1404,7 @@ class NVD3TimePivotViz(NVD3TimeSeriesViz):
         df = df.pivot_table(
             index=DTTM_ALIAS,
             columns='series',
-            values=self.get_metric_label(fd.get('metric')))
+            values=utils.get_metric_name(fd.get('metric')))
         chart_data = self.to_series(df)
         for serie in chart_data:
             serie['rank'] = rank_lookup[serie['key']]
@@ -1589,8 +1576,8 @@ class SunburstViz(BaseViz):
     def get_data(self, df):
         fd = self.form_data
         cols = fd.get('groupby')
-        metric = self.get_metric_label(fd.get('metric'))
-        secondary_metric = self.get_metric_label(fd.get('secondary_metric'))
+        metric = utils.get_metric_name(fd.get('metric'))
+        secondary_metric = utils.get_metric_name(fd.get('secondary_metric'))
         if metric == secondary_metric or secondary_metric is None:
             df.columns = cols + ['m1']
             df['m2'] = df['m1']
@@ -1691,7 +1678,7 @@ class ChordViz(BaseViz):
         qry = super(ChordViz, self).query_obj()
         fd = self.form_data
         qry['groupby'] = [fd.get('groupby'), fd.get('columns')]
-        qry['metrics'] = [self.get_metric_label(fd.get('metric'))]
+        qry['metrics'] = [utils.get_metric_name(fd.get('metric'))]
         return qry
 
     def get_data(self, df):
@@ -1757,8 +1744,8 @@ class WorldMapViz(BaseViz):
         from superset.data import countries
         fd = self.form_data
         cols = [fd.get('entity')]
-        metric = self.get_metric_label(fd.get('metric'))
-        secondary_metric = self.get_metric_label(fd.get('secondary_metric'))
+        metric = utils.get_metric_name(fd.get('metric'))
+        secondary_metric = utils.get_metric_name(fd.get('secondary_metric'))
         columns = ['country', 'm1', 'm2']
         if metric == secondary_metric:
             ndf = df[cols]
@@ -2289,7 +2276,7 @@ class DeckScatterViz(BaseDeckGLViz):
     def get_data(self, df):
         fd = self.form_data
         self.metric_label = \
-            self.get_metric_label(self.metric) if self.metric else None
+            utils.get_metric_name(self.metric) if self.metric else None
         self.point_radius_fixed = fd.get('point_radius_fixed')
         self.fixed_value = None
         self.dim = self.form_data.get('dimension')
@@ -2320,7 +2307,7 @@ class DeckScreengrid(BaseDeckGLViz):
         }
 
     def get_data(self, df):
-        self.metric_label = self.get_metric_label(self.metric)
+        self.metric_label = utils.get_metric_name(self.metric)
         return super(DeckScreengrid, self).get_data(df)
 
 
@@ -2339,7 +2326,7 @@ class DeckGrid(BaseDeckGLViz):
         }
 
     def get_data(self, df):
-        self.metric_label = self.get_metric_label(self.metric)
+        self.metric_label = utils.get_metric_name(self.metric)
         return super(DeckGrid, self).get_data(df)
 
 
@@ -2397,7 +2384,7 @@ class DeckPathViz(BaseDeckGLViz):
         return d
 
     def get_data(self, df):
-        self.metric_label = self.get_metric_label(self.metric)
+        self.metric_label = utils.get_metric_name(self.metric)
         return super(DeckPathViz, self).get_data(df)
 
 
@@ -2445,7 +2432,7 @@ class DeckHex(BaseDeckGLViz):
         }
 
     def get_data(self, df):
-        self.metric_label = self.get_metric_label(self.metric)
+        self.metric_label = utils.get_metric_name(self.metric)
         return super(DeckHex, self).get_data(df)
 
 

Reply via email to