mistercrunch closed pull request #3438: Feature: Implementing incremental search for column values
URL: https://github.com/apache/incubator-superset/pull/3438

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:


diff --git a/superset/connectors/base/models.py b/superset/connectors/base/models.py
index 593c722d42..63ecaaff9e 100644
--- a/superset/connectors/base/models.py
+++ b/superset/connectors/base/models.py
@@ -179,7 +179,7 @@ def query(self, query_obj):
         """
         raise NotImplementedError()
 
-    def values_for_column(self, column_name, limit=10000):
+    def values_for_column(self, column_name, limit=10000, search_string=None):
         """Given a column, returns an iterable of distinct values
 
         This is used to populate the dropdown showing a list of
diff --git a/superset/connectors/druid/models.py b/superset/connectors/druid/models.py
index 6340331ba7..c945a8a7ec 100644
--- a/superset/connectors/druid/models.py
+++ b/superset/connectors/druid/models.py
@@ -6,6 +6,7 @@
 from datetime import datetime, timedelta
 from six import string_types
 
+import re
 import requests
 import sqlalchemy as sa
 from sqlalchemy import (
@@ -772,7 +773,8 @@ def recursive_get_fields(_conf):
 
     def values_for_column(self,
                           column_name,
-                          limit=10000):
+                          limit=10000,
+                          search_string=None):
         """Retrieve some values for the given column"""
         # TODO: Use Lexicographic TopNMetricSpec once supported by PyDruid
         if self.fetch_values_from:
@@ -790,10 +792,27 @@ def values_for_column(self,
             threshold=limit,
         )
 
+        if search_string:
+            # Druid can't make the regex case-insensitive :(
+            pattern = ''.join([
+                '[{0}{1}]'.format(c.upper(), c.lower())
+                if c.isalpha() else re.escape(c)
+                for c in search_string])
+
+            filter_params = {
+                'type': 'regex',
+                'dimension': column_name,
+                'pattern': ".*{}.*".format(pattern),
+            }
+            qry['filter'] = Filter(**filter_params)
+
         client = self.cluster.get_pydruid_client()
         client.topn(**qry)
         df = client.export_pandas()
-        return [row[column_name] for row in df.to_records(index=False)]
+        if df is not None and not df.empty:
+            return [row[column_name] for row in df.to_records(index=False)]
+        else:
+            return []
 
     def get_query_str(self, query_obj, phase=1, client=None):
         return self.run_query(client=client, phase=phase, **query_obj)
@@ -833,7 +852,9 @@ def run_query(  # noqa / druid
 
         columns_dict = {c.column_name: c for c in self.columns}
 
-        all_metrics, post_aggs = self._metrics_and_post_aggs(metrics, metrics_dict)
+        all_metrics, post_aggs = self._metrics_and_post_aggs(
+            metrics,
+            metrics_dict)
 
         aggregations = OrderedDict()
         for m in self.metrics:
@@ -1110,5 +1131,6 @@ def query_datasources_by_name(
             .all()
         )
 
+
 sa.event.listen(DruidDatasource, 'after_insert', set_perm)
 sa.event.listen(DruidDatasource, 'after_update', set_perm)
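
For context on the regex filter added above: Druid regex filters are
case-sensitive, so the search string is expanded character by character,
letters becoming bracketed upper/lower classes and everything else being
escaped. A minimal standalone sketch of that expansion (the helper name is
illustrative, not part of the PR):

import re

def case_insensitive_pattern(search_string):
    # Letters become [Xx] classes; all other characters are escaped so
    # regex metacharacters in the search string match literally.
    return ''.join(
        '[{0}{1}]'.format(c.upper(), c.lower()) if c.isalpha()
        else re.escape(c)
        for c in search_string)

print(case_insensitive_pattern('$t1'))  # -> \$[Tt]1

Wrapped in '.*{}.*' by values_for_column(), this yields the '.*\$[Tt]1.*'
pattern that the druid_tests.py assertion further below expects.
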
diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py
index 96ef575986..5a746986e7 100644
--- a/superset/connectors/sqla/models.py
+++ b/superset/connectors/sqla/models.py
@@ -10,8 +10,8 @@
     DateTime,
 )
 import sqlalchemy as sa
-from sqlalchemy import asc, and_, desc, select
-from sqlalchemy.sql.expression import TextAsFrom
+from sqlalchemy import asc, and_, desc, select, String as SqlString
+from sqlalchemy.sql.expression import cast, TextAsFrom
 from sqlalchemy.orm import backref, relationship
 from sqlalchemy.sql import table, literal_column, text, column
 
@@ -20,7 +20,9 @@
 from flask_babel import lazy_gettext as _
 
 from superset import db, utils, import_util, sm
-from superset.connectors.base.models import BaseDatasource, BaseColumn, BaseMetric
+from superset.connectors.base.models import (
+    BaseDatasource, BaseColumn, BaseMetric
+)
 from superset.utils import DTTM_ALIAS, QueryStatus
 from superset.models.helpers import QueryResult
 from superset.models.core import Database
@@ -106,7 +108,8 @@ def dttm_sql_literal(self, dttm):
 
         tf = self.python_date_format or '%Y-%m-%d %H:%M:%S.%f'
         if self.database_expression:
-            return self.database_expression.format(dttm.strftime('%Y-%m-%d %H:%M:%S'))
+            return self.database_expression.format(
+                dttm.strftime('%Y-%m-%d %H:%M:%S'))
         elif tf == 'epoch_s':
             return str((dttm - datetime(1970, 1, 1)).total_seconds())
         elif tf == 'epoch_ms':
@@ -285,7 +288,7 @@ def data(self):
             d['time_grain_sqla'] = grains
         return d
 
-    def values_for_column(self, column_name, limit=10000):
+    def values_for_column(self, column_name, limit=10000, search_string=None):
         """Runs query against sqla to retrieve some
         sample values for the given column.
         """
@@ -299,6 +302,14 @@ def values_for_column(self, column_name, limit=10000):
             .select_from(self.get_from_clause(tp, db_engine_spec))
             .distinct(column_name)
         )
+
+        if search_string:
+            # cast to String in case we want to search for numeric values
+            qry = qry.where(
+                cast(target_col.sqla_col, SqlString(length=100)).ilike(
+                    '%%{}%%'.format(search_string))).order_by(
+                        target_col.sqla_col)
+
         if limit:
             qry = qry.limit(limit)
 
@@ -712,5 +723,6 @@ def query_datasources_by_name(
             query = query.filter_by(schema=schema)
         return query.all()
 
+
 sa.event.listen(SqlaTable, 'after_insert', set_perm)
 sa.event.listen(SqlaTable, 'after_update', set_perm)
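
For context on the SQLAlchemy branch above: the column is cast to a String
before applying a case-insensitive ILIKE, so numeric columns can be searched
as well. A minimal self-contained sketch of the same filter, using a
hypothetical table rather than the Superset models (1.x-style select()):

from sqlalchemy import (
    Column, Integer, MetaData, String, Table, cast, select)

metadata = MetaData()
t = Table('energy_usage', metadata, Column('value', Integer))

# Distinct values of a numeric column, filtered by substring; the cast
# makes ilike() applicable to non-string columns.
qry = (
    select([t.c.value])
    .distinct()
    .where(cast(t.c.value, String(length=100)).ilike('%42%'))
    .order_by(t.c.value)
    .limit(10000))
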
diff --git a/superset/views/core.py b/superset/views/core.py
index 68e82027d0..b27861b5b8 100755
--- a/superset/views/core.py
+++ b/superset/views/core.py
@@ -1106,14 +1106,17 @@ def explore(self, datasource_type, datasource_id):
     @api
     @has_access_api
     @expose("/filter/<datasource_type>/<datasource_id>/<column>/")
-    def filter(self, datasource_type, datasource_id, column):
+    @expose("/filter/<datasource_type>/<datasource_id>/<column>/<limit>/")
+    @expose("/filter/<datasource_type>/<datasource_id>/<column>/<limit>/<search_string>")
+    def filter(self, datasource_type, datasource_id, column, limit=10000, search_string=None):
         """
         Endpoint to retrieve values for specified column.
 
         :param datasource_type: Type of datasource e.g. table
         :param datasource_id: Datasource id
         :param column: Column name to retrieve values for
-        :return:
+        :param limit: Return at most these entries (default: 10000)
+        :param search_string: Only return values containing the search_string
         """
         # TODO: Cache endpoint by user, datasource and column
         datasource = ConnectorRegistry.get_datasource(
@@ -1124,7 +1127,9 @@ def filter(self, datasource_type, datasource_id, column):
             return json_error_response(DATASOURCE_ACCESS_ERR)
 
         payload = json.dumps(
-            datasource.values_for_column(column),
+            datasource.values_for_column(column_name=column,
+                                         limit=limit,
+                                         search_string=search_string),
             default=utils.json_int_dttm_ser)
         return json_success(payload)
 
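With the two added routes, limit and search_string become optional path
segments. A hedged usage sketch against a local instance (host, datasource
id, and an already-authenticated session are assumptions, not part of the
PR):

import requests

s = requests.Session()  # assumes login has been handled elsewhere
base = 'http://localhost:8088/superset/filter/table/1/source'

s.get(base + '/')            # all distinct values, default limit of 10000
s.get(base + '/3/')          # at most 3 values
s.get(base + '/100/carbon')  # values containing "carbon", case-insensitively
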
diff --git a/tests/core_tests.py b/tests/core_tests.py
index 34f30a14f5..4399891729 100644
--- a/tests/core_tests.py
+++ b/tests/core_tests.py
@@ -183,7 +183,6 @@ def test_save_slice(self):
         assert slc.slice_name == new_slice_name
         db.session.delete(slc)
 
-
     def test_filter_endpoint(self):
         self.login(username='admin')
         slice_name = "Energy Sankey"
@@ -199,8 +198,28 @@ def test_filter_endpoint(self):
             "datasource_id=1&datasource_type=table")
 
         # Changing name
-        resp = self.get_resp(url.format(tbl_id, slice_id))
-        assert len(resp) > 0
+        resp = json.loads(self.get_resp(url.format(tbl_id, slice_id)))
+        assert len(resp) > 1
+        assert 'Carbon Dioxide' in resp
+
+        # Limit to 3
+        url = (
+            "/superset/filter/table/{}/target/3?viz_type=sankey&groupby=source"
+            "&metric=sum__value&flt_col_0=source&flt_op_0=in&flt_eq_0=&"
+            "slice_id={}&datasource_name=energy_usage&"
+            "datasource_id=1&datasource_type=table")
+        resp = json.loads(self.get_resp(url.format(tbl_id, slice_id)))
+        assert len(resp) == 3
+
+        # With search_string = "carbon"
+        url = (
+            "/superset/filter/table/{}/target/100/carbon?"
+            "viz_type=sankey&groupby=source&"
+            "metric=sum__value&flt_col_0=source&flt_op_0=in&flt_eq_0=&"
+            "slice_id={}&datasource_name=energy_usage&"
+            "datasource_id=1&datasource_type=table")
+        resp = json.loads(self.get_resp(url.format(tbl_id, slice_id)))
+        assert len(resp) == 1
         assert 'Carbon Dioxide' in resp
 
     def test_slices(self):
diff --git a/tests/druid_tests.py b/tests/druid_tests.py
index 637afe984c..540baf9800 100644
--- a/tests/druid_tests.py
+++ b/tests/druid_tests.py
@@ -390,6 +390,38 @@ def test_metrics_and_post_aggs(self):
         assert all_metrics == ['aCustomMetric']
         assert set(post_aggs.keys()) == result_postaggs
 
+    @patch('superset.connectors.druid.models.PyDruid')
+    def test_values_for_column(self, py_druid):
+        ds = 'test_datasource'
+        column = 'test_column'
+        search_string = "$t1"  # difficult test string
+
+        datasource = self.get_or_create(
+            DruidDatasource, {'datasource_name': ds},
+            db.session)
+        druid = py_druid()
+        datasource.cluster.get_pydruid_client = Mock(return_value=druid)
+
+        # search_string
+        datasource.values_for_column(column_name=column, limit=5,
+                                     search_string=search_string)
+
+        assert druid.topn.call_args[1]['datasource'] == ds
+        assert druid.topn.call_args[1]['granularity'] == 'all'
+        assert druid.topn.call_args[1]['metric'] == 'count'
+        assert druid.topn.call_args[1]['dimension'] == column
+        assert druid.topn.call_args[1]['threshold'] == 5
+
+        # test filter
+        assert(druid.topn.call_args[1]['filter']
+               .filter['filter']['dimension'] == column)
+        assert(druid.topn.call_args[1]['filter']
+               .filter['filter']['pattern'] == '.*\\$[Tt]1.*')
+
+        # no search_string
+        datasource.values_for_column(column_name=column)
+        assert not druid.topn.call_args[1].get('filter')
+
 
 if __name__ == '__main__':
     unittest.main()


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
[email protected]


With regards,
Apache Git Services
