mistercrunch closed pull request #3438: Feature: Implementing incremental search for a column values
URL: https://github.com/apache/incubator-superset/pull/3438
This is a PR merged from a forked repository.
Because GitHub hides the original diff of a pull request from a fork once
it has been merged, the diff is reproduced below for the sake of provenance:
diff --git a/superset/connectors/base/models.py b/superset/connectors/base/models.py
index 593c722d42..63ecaaff9e 100644
--- a/superset/connectors/base/models.py
+++ b/superset/connectors/base/models.py
@@ -179,7 +179,7 @@ def query(self, query_obj):
"""
raise NotImplementedError()
- def values_for_column(self, column_name, limit=10000):
+ def values_for_column(self, column_name, limit=10000, search_string=None):
"""Given a column, returns an iterable of distinct values
This is used to populate the dropdown showing a list of
diff --git a/superset/connectors/druid/models.py b/superset/connectors/druid/models.py
index 6340331ba7..c945a8a7ec 100644
--- a/superset/connectors/druid/models.py
+++ b/superset/connectors/druid/models.py
@@ -6,6 +6,7 @@
from datetime import datetime, timedelta
from six import string_types
+import re
import requests
import sqlalchemy as sa
from sqlalchemy import (
@@ -772,7 +773,8 @@ def recursive_get_fields(_conf):
def values_for_column(self,
column_name,
- limit=10000):
+ limit=10000,
+ search_string=None):
"""Retrieve some values for the given column"""
# TODO: Use Lexicographic TopNMetricSpec once supported by PyDruid
if self.fetch_values_from:
@@ -790,10 +792,27 @@ def values_for_column(self,
threshold=limit,
)
+ if search_string:
+ # Druid can't make the regex case-insensitive :(
+ pattern = ''.join([
+ '[{0}{1}]'.format(c.upper(), c.lower())
+ if c.isalpha() else re.escape(c)
+ for c in search_string])
+
+ filter_params = {
+ 'type': 'regex',
+ 'dimension': column_name,
+ 'pattern': ".*{}.*".format(pattern),
+ }
+ qry['filter'] = Filter(**filter_params)
+
client = self.cluster.get_pydruid_client()
client.topn(**qry)
df = client.export_pandas()
- return [row[column_name] for row in df.to_records(index=False)]
+ if (df.values.any()):
+ return [row[column_name] for row in df.to_records(index=False)]
+ else:
+ return []
def get_query_str(self, query_obj, phase=1, client=None):
return self.run_query(client=client, phase=phase, **query_obj)
@@ -833,7 +852,9 @@ def run_query( # noqa / druid
columns_dict = {c.column_name: c for c in self.columns}
- all_metrics, post_aggs = self._metrics_and_post_aggs(metrics,
metrics_dict)
+ all_metrics, post_aggs = self._metrics_and_post_aggs(
+ metrics,
+ metrics_dict)
aggregations = OrderedDict()
for m in self.metrics:
@@ -1110,5 +1131,6 @@ def query_datasources_by_name(
.all()
)
+
sa.event.listen(DruidDatasource, 'after_insert', set_perm)
sa.event.listen(DruidDatasource, 'after_update', set_perm)
diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py
index 96ef575986..5a746986e7 100644
--- a/superset/connectors/sqla/models.py
+++ b/superset/connectors/sqla/models.py
@@ -10,8 +10,8 @@
DateTime,
)
import sqlalchemy as sa
-from sqlalchemy import asc, and_, desc, select
-from sqlalchemy.sql.expression import TextAsFrom
+from sqlalchemy import asc, and_, desc, select, String as SqlString
+from sqlalchemy.sql.expression import cast, TextAsFrom
from sqlalchemy.orm import backref, relationship
from sqlalchemy.sql import table, literal_column, text, column
@@ -20,7 +20,9 @@
from flask_babel import lazy_gettext as _
from superset import db, utils, import_util, sm
-from superset.connectors.base.models import BaseDatasource, BaseColumn,
BaseMetric
+from superset.connectors.base.models import (
+ BaseDatasource, BaseColumn, BaseMetric
+)
from superset.utils import DTTM_ALIAS, QueryStatus
from superset.models.helpers import QueryResult
from superset.models.core import Database
@@ -106,7 +108,8 @@ def dttm_sql_literal(self, dttm):
tf = self.python_date_format or '%Y-%m-%d %H:%M:%S.%f'
if self.database_expression:
- return self.database_expression.format(dttm.strftime('%Y-%m-%d
%H:%M:%S'))
+ return self.database_expression.format(
+ dttm.strftime('%Y-%m-%d %H:%M:%S'))
elif tf == 'epoch_s':
return str((dttm - datetime(1970, 1, 1)).total_seconds())
elif tf == 'epoch_ms':
@@ -285,7 +288,7 @@ def data(self):
d['time_grain_sqla'] = grains
return d
- def values_for_column(self, column_name, limit=10000):
+ def values_for_column(self, column_name, limit=10000, search_string=None):
"""Runs query against sqla to retrieve some
sample values for the given column.
"""
@@ -299,6 +302,14 @@ def values_for_column(self, column_name, limit=10000):
.select_from(self.get_from_clause(tp, db_engine_spec))
.distinct(column_name)
)
+
+ if search_string:
+ # cast to String in case we want to search for numeric values
+ qry = qry.where(
+ cast(target_col.sqla_col, SqlString(length=100)).ilike(
+ '%%{}%%'.format(search_string))).order_by(
+ target_col.sqla_col)
+
if limit:
qry = qry.limit(limit)
@@ -712,5 +723,6 @@ def query_datasources_by_name(
query = query.filter_by(schema=schema)
return query.all()
+
sa.event.listen(SqlaTable, 'after_insert', set_perm)
sa.event.listen(SqlaTable, 'after_update', set_perm)
diff --git a/superset/views/core.py b/superset/views/core.py
index 68e82027d0..b27861b5b8 100755
--- a/superset/views/core.py
+++ b/superset/views/core.py
@@ -1106,14 +1106,17 @@ def explore(self, datasource_type, datasource_id):
@api
@has_access_api
@expose("/filter/<datasource_type>/<datasource_id>/<column>/")
- def filter(self, datasource_type, datasource_id, column):
+ @expose("/filter/<datasource_type>/<datasource_id>/<column>/<limit>/")
+
@expose("/filter/<datasource_type>/<datasource_id>/<column>/<limit>/<search_string>")
+ def filter(self, datasource_type, datasource_id, column, limit=10000,
search_string=None):
"""
Endpoint to retrieve values for specified column.
:param datasource_type: Type of datasource e.g. table
:param datasource_id: Datasource id
:param column: Column name to retrieve values for
- :return:
+ :param limit: Return at most these entries (default: 10000)
+ :return: search_string: Only return columns containing the
search_string
"""
# TODO: Cache endpoint by user, datasource and column
datasource = ConnectorRegistry.get_datasource(
@@ -1124,7 +1127,9 @@ def filter(self, datasource_type, datasource_id, column):
return json_error_response(DATASOURCE_ACCESS_ERR)
payload = json.dumps(
- datasource.values_for_column(column),
+ datasource.values_for_column(column_name=column,
+ limit=limit,
+ search_string=search_string),
default=utils.json_int_dttm_ser)
return json_success(payload)
diff --git a/tests/core_tests.py b/tests/core_tests.py
index 34f30a14f5..4399891729 100644
--- a/tests/core_tests.py
+++ b/tests/core_tests.py
@@ -183,7 +183,6 @@ def test_save_slice(self):
assert slc.slice_name == new_slice_name
db.session.delete(slc)
-
def test_filter_endpoint(self):
self.login(username='admin')
slice_name = "Energy Sankey"
@@ -199,8 +198,28 @@ def test_filter_endpoint(self):
"datasource_id=1&datasource_type=table")
# Changing name
- resp = self.get_resp(url.format(tbl_id, slice_id))
- assert len(resp) > 0
+ resp = json.loads(self.get_resp(url.format(tbl_id, slice_id)))
+ assert len(resp) > 1
+ assert 'Carbon Dioxide' in resp
+
+ # Limit to 3
+ url = (
+ "/superset/filter/table/{}/target/3?viz_type=sankey&groupby=source"
+ "&metric=sum__value&flt_col_0=source&flt_op_0=in&flt_eq_0=&"
+ "slice_id={}&datasource_name=energy_usage&"
+ "datasource_id=1&datasource_type=table")
+ resp = json.loads(self.get_resp(url.format(tbl_id, slice_id)))
+ assert len(resp) == 3
+
+ # With search_string = "carbon"
+ url = (
+ "/superset/filter/table/{}/target/100/carbon?"
+ "viz_type=sankey&groupby=source&"
+ "metric=sum__value&flt_col_0=source&flt_op_0=in&flt_eq_0=&"
+ "slice_id={}&datasource_name=energy_usage&"
+ "datasource_id=1&datasource_type=table")
+ resp = json.loads(self.get_resp(url.format(tbl_id, slice_id)))
+ assert len(resp) == 1
assert 'Carbon Dioxide' in resp
def test_slices(self):
diff --git a/tests/druid_tests.py b/tests/druid_tests.py
index 637afe984c..540baf9800 100644
--- a/tests/druid_tests.py
+++ b/tests/druid_tests.py
@@ -390,6 +390,38 @@ def test_metrics_and_post_aggs(self):
assert all_metrics == ['aCustomMetric']
assert set(post_aggs.keys()) == result_postaggs
+ @patch('superset.connectors.druid.models.PyDruid')
+ def test_values_for_column(self, py_druid):
+ ds = 'test_datasource'
+ column = 'test_column'
+ search_string = "$t1" # difficult test string
+
+ datasource = self.get_or_create(
+ DruidDatasource, {'datasource_name': ds},
+ db.session)
+ druid = py_druid()
+ datasource.cluster.get_pydruid_client = Mock(return_value=druid)
+
+ # search_string
+ datasource.values_for_column(column_name=column, limit=5,
+ search_string=search_string)
+
+ assert druid.topn.call_args[1]['datasource'] == ds
+ assert druid.topn.call_args[1]['granularity'] == 'all'
+ assert druid.topn.call_args[1]['metric'] == 'count'
+ assert druid.topn.call_args[1]['dimension'] == column
+ assert druid.topn.call_args[1]['threshold'] == 5
+
+ # test filter
+ assert(druid.topn.call_args[1]['filter']
+ .filter['filter']['dimension'] == column)
+ assert(druid.topn.call_args[1]['filter']
+ .filter['filter']['pattern'] == '.*\\$[Tt]1.*')
+
+ # no search_string
+ datasource.values_for_column(column_name=column)
+ assert not druid.topn.call_args[1].get('filter')
+
if __name__ == '__main__':
unittest.main()
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
users@infra.apache.org
With regards,
Apache Git Services