This is an automated email from the ASF dual-hosted git repository.
maximebeauchemin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-superset.git
The following commit(s) were added to refs/heads/master by this push:
new cb7c5aa Fixed finding postaggregations (#4017)
cb7c5aa is described below
commit cb7c5aa70c3729d8f1fe0c310d30e620f9e9a581
Author: Jeff Niu <[email protected]>
AuthorDate: Wed Dec 6 21:55:43 2017 -0800
Fixed finding postaggregations (#4017)
---
superset/connectors/druid/models.py | 174 ++++++++++++++--------
tests/druid_func_tests.py | 284 ++++++++++++++++++++++++++++++++++++
tests/druid_tests.py | 87 +----------
3 files changed, 397 insertions(+), 148 deletions(-)
diff --git a/superset/connectors/druid/models.py
b/superset/connectors/druid/models.py
index bf7e176..acb1951 100644
--- a/superset/connectors/druid/models.py
+++ b/superset/connectors/druid/models.py
@@ -786,73 +786,123 @@ class DruidDatasource(Model, BaseDatasource):
return granularity
@staticmethod
- def _metrics_and_post_aggs(metrics, metrics_dict):
- all_metrics = []
- post_aggs = {}
-
- def recursive_get_fields(_conf):
- _type = _conf.get('type')
- _field = _conf.get('field')
- _fields = _conf.get('fields')
-
- field_names = []
- if _type in ['fieldAccess', 'hyperUniqueCardinality',
- 'quantile', 'quantiles']:
- field_names.append(_conf.get('fieldName', ''))
+ def get_post_agg(mconf):
+ """
+ For a metric specified as `postagg` returns the
+ kind of post aggregation for pydruid.
+ """
+ if mconf.get('type') == 'javascript':
+ return JavascriptPostAggregator(
+ name=mconf.get('name', ''),
+ field_names=mconf.get('fieldNames', []),
+ function=mconf.get('function', ''))
+ elif mconf.get('type') == 'quantile':
+ return Quantile(
+ mconf.get('name', ''),
+ mconf.get('probability', ''),
+ )
+ elif mconf.get('type') == 'quantiles':
+ return Quantiles(
+ mconf.get('name', ''),
+ mconf.get('probabilities', ''),
+ )
+ elif mconf.get('type') == 'fieldAccess':
+ return Field(mconf.get('name'))
+ elif mconf.get('type') == 'constant':
+ return Const(
+ mconf.get('value'),
+ output_name=mconf.get('name', ''),
+ )
+ elif mconf.get('type') == 'hyperUniqueCardinality':
+ return HyperUniqueCardinality(
+ mconf.get('name'),
+ )
+ elif mconf.get('type') == 'arithmetic':
+ return Postaggregator(
+ mconf.get('fn', '/'),
+ mconf.get('fields', []),
+ mconf.get('name', ''))
+ else:
+ return CustomPostAggregator(
+ mconf.get('name', ''),
+ mconf)
- if _field:
- field_names += recursive_get_fields(_field)
+ @staticmethod
+ def find_postaggs_for(postagg_names, metrics_dict):
+ """Return a list of metrics that are post aggregations"""
+ postagg_metrics = [
+ metrics_dict[name] for name in postagg_names
+ if metrics_dict[name].metric_type == 'postagg'
+ ]
+ # Remove post aggregations that were found
+ for postagg in postagg_metrics:
+ postagg_names.remove(postagg.metric_name)
+ return postagg_metrics
- if _fields:
- for _f in _fields:
- field_names += recursive_get_fields(_f)
+ @staticmethod
+ def recursive_get_fields(_conf):
+ _type = _conf.get('type')
+ _field = _conf.get('field')
+ _fields = _conf.get('fields')
+ field_names = []
+ if _type in ['fieldAccess', 'hyperUniqueCardinality',
+ 'quantile', 'quantiles']:
+ field_names.append(_conf.get('fieldName', ''))
+ if _field:
+ field_names += DruidDatasource.recursive_get_fields(_field)
+ if _fields:
+ for _f in _fields:
+ field_names += DruidDatasource.recursive_get_fields(_f)
+ return list(set(field_names))
- return list(set(field_names))
+ @staticmethod
+ def resolve_postagg(postagg, post_aggs, agg_names, visited_postaggs,
metrics_dict):
+ mconf = postagg.json_obj
+ required_fields = set(
+ DruidDatasource.recursive_get_fields(mconf)
+ + mconf.get('fieldNames', []))
+ # Check if the fields are already in aggs
+ # or is a previous postagg
+ required_fields = set([
+ field for field in required_fields
+ if field not in visited_postaggs and field not in agg_names
+ ])
+ # First try to find postaggs that match
+ if len(required_fields) > 0:
+ missing_postaggs = DruidDatasource.find_postaggs_for(
+ required_fields, metrics_dict)
+ for missing_metric in required_fields:
+ agg_names.add(missing_metric)
+ for missing_postagg in missing_postaggs:
+ # Add to visited first to avoid infinite recursion
+ # if post aggregations are cyclicly dependent
+ visited_postaggs.add(missing_postagg.metric_name)
+ for missing_postagg in missing_postaggs:
+ DruidDatasource.resolve_postagg(
+ missing_postagg, post_aggs, agg_names, visited_postaggs,
metrics_dict)
+ post_aggs[postagg.metric_name] =
DruidDatasource.get_post_agg(postagg.json_obj)
+ @staticmethod
+ def metrics_and_post_aggs(metrics, metrics_dict):
+ # Separate metrics into those that are aggregations
+ # and those that are post aggregations
+ agg_names = set()
+ postagg_names = []
for metric_name in metrics:
- metric = metrics_dict[metric_name]
- if metric.metric_type != 'postagg':
- all_metrics.append(metric_name)
+ if metrics_dict[metric_name].metric_type != 'postagg':
+ agg_names.add(metric_name)
else:
- mconf = metric.json_obj
- all_metrics += recursive_get_fields(mconf)
- all_metrics += mconf.get('fieldNames', [])
- if mconf.get('type') == 'javascript':
- post_aggs[metric_name] = JavascriptPostAggregator(
- name=mconf.get('name', ''),
- field_names=mconf.get('fieldNames', []),
- function=mconf.get('function', ''))
- elif mconf.get('type') == 'quantile':
- post_aggs[metric_name] = Quantile(
- mconf.get('name', ''),
- mconf.get('probability', ''),
- )
- elif mconf.get('type') == 'quantiles':
- post_aggs[metric_name] = Quantiles(
- mconf.get('name', ''),
- mconf.get('probabilities', ''),
- )
- elif mconf.get('type') == 'fieldAccess':
- post_aggs[metric_name] = Field(mconf.get('name'))
- elif mconf.get('type') == 'constant':
- post_aggs[metric_name] = Const(
- mconf.get('value'),
- output_name=mconf.get('name', ''),
- )
- elif mconf.get('type') == 'hyperUniqueCardinality':
- post_aggs[metric_name] = HyperUniqueCardinality(
- mconf.get('name'),
- )
- elif mconf.get('type') == 'arithmetic':
- post_aggs[metric_name] = Postaggregator(
- mconf.get('fn', '/'),
- mconf.get('fields', []),
- mconf.get('name', ''))
- else:
- post_aggs[metric_name] = CustomPostAggregator(
- mconf.get('name', ''),
- mconf)
- return all_metrics, post_aggs
+ postagg_names.append(metric_name)
+ # Create the post aggregations, maintain order since postaggs
+ # may depend on previous ones
+ post_aggs = OrderedDict()
+ visited_postaggs = set()
+ for postagg_name in postagg_names:
+ postagg = metrics_dict[postagg_name]
+ visited_postaggs.add(postagg_name)
+ DruidDatasource.resolve_postagg(
+ postagg, post_aggs, agg_names, visited_postaggs, metrics_dict)
+ return list(agg_names), post_aggs
def values_for_column(self,
column_name,
@@ -940,7 +990,7 @@ class DruidDatasource(Model, BaseDatasource):
columns_dict = {c.column_name: c for c in self.columns}
- all_metrics, post_aggs = self._metrics_and_post_aggs(
+ all_metrics, post_aggs = DruidDatasource.metrics_and_post_aggs(
metrics,
metrics_dict)
diff --git a/tests/druid_func_tests.py b/tests/druid_func_tests.py
index ba1f497..4c047df 100644
--- a/tests/druid_func_tests.py
+++ b/tests/druid_func_tests.py
@@ -2,12 +2,26 @@ import json
import unittest
from mock import Mock
+import pydruid.utils.postaggregator as postaggs
+import superset.connectors.druid.models as models
from superset.connectors.druid.models import (
DruidColumn, DruidDatasource, DruidMetric,
)
+def mock_metric(metric_name, is_postagg=False):
+ metric = Mock()
+ metric.metric_name = metric_name
+ metric.metric_type = 'postagg' if is_postagg else 'metric'
+ return metric
+
+
+def emplace(metrics_dict, metric_name, is_postagg=False):
+ metrics_dict[metric_name] = mock_metric(metric_name, is_postagg)
+
+
+# Unit tests that can be run without initializing base tests
class DruidFuncTestCase(unittest.TestCase):
def test_get_filters_ignores_invalid_filter_objects(self):
@@ -271,3 +285,273 @@ class DruidFuncTestCase(unittest.TestCase):
called_args = client.groupby.call_args_list[0][1]
self.assertIn('dimensions', called_args)
self.assertEqual(['col1', 'col2'], called_args['dimensions'])
+
+ def test_get_post_agg_returns_correct_agg_type(self):
+ get_post_agg = DruidDatasource.get_post_agg
+ # javascript PostAggregators
+ function = 'function(field1, field2) { return field1 + field2; }'
+ conf = {
+ 'type': 'javascript',
+ 'name': 'postagg_name',
+ 'fieldNames': ['field1', 'field2'],
+ 'function': function,
+ }
+ postagg = get_post_agg(conf)
+ self.assertTrue(isinstance(postagg, models.JavascriptPostAggregator))
+ self.assertEqual(postagg.name, 'postagg_name')
+ self.assertEqual(postagg.post_aggregator['type'], 'javascript')
+ self.assertEqual(postagg.post_aggregator['fieldNames'], ['field1',
'field2'])
+ self.assertEqual(postagg.post_aggregator['name'], 'postagg_name')
+ self.assertEqual(postagg.post_aggregator['function'], function)
+ # Quantile
+ conf = {
+ 'type': 'quantile',
+ 'name': 'postagg_name',
+ 'probability': '0.5',
+ }
+ postagg = get_post_agg(conf)
+ self.assertTrue(isinstance(postagg, postaggs.Quantile))
+ self.assertEqual(postagg.name, 'postagg_name')
+ self.assertEqual(postagg.post_aggregator['probability'], '0.5')
+ # Quantiles
+ conf = {
+ 'type': 'quantiles',
+ 'name': 'postagg_name',
+ 'probabilities': '0.4,0.5,0.6',
+ }
+ postagg = get_post_agg(conf)
+ self.assertTrue(isinstance(postagg, postaggs.Quantiles))
+ self.assertEqual(postagg.name, 'postagg_name')
+ self.assertEqual(postagg.post_aggregator['probabilities'],
'0.4,0.5,0.6')
+ # FieldAccess
+ conf = {
+ 'type': 'fieldAccess',
+ 'name': 'field_name',
+ }
+ postagg = get_post_agg(conf)
+ self.assertTrue(isinstance(postagg, postaggs.Field))
+ self.assertEqual(postagg.name, 'field_name')
+ # constant
+ conf = {
+ 'type': 'constant',
+ 'value': 1234,
+ 'name': 'postagg_name',
+ }
+ postagg = get_post_agg(conf)
+ self.assertTrue(isinstance(postagg, postaggs.Const))
+ self.assertEqual(postagg.name, 'postagg_name')
+ self.assertEqual(postagg.post_aggregator['value'], 1234)
+ # hyperUniqueCardinality
+ conf = {
+ 'type': 'hyperUniqueCardinality',
+ 'name': 'unique_name',
+ }
+ postagg = get_post_agg(conf)
+ self.assertTrue(isinstance(postagg, postaggs.HyperUniqueCardinality))
+ self.assertEqual(postagg.name, 'unique_name')
+ # arithmetic
+ conf = {
+ 'type': 'arithmetic',
+ 'fn': '+',
+ 'fields': ['field1', 'field2'],
+ 'name': 'postagg_name',
+ }
+ postagg = get_post_agg(conf)
+ self.assertTrue(isinstance(postagg, postaggs.Postaggregator))
+ self.assertEqual(postagg.name, 'postagg_name')
+ self.assertEqual(postagg.post_aggregator['fn'], '+')
+ self.assertEqual(postagg.post_aggregator['fields'], ['field1',
'field2'])
+ # custom post aggregator
+ conf = {
+ 'type': 'custom',
+ 'name': 'custom_name',
+ 'stuff': 'more_stuff',
+ }
+ postagg = get_post_agg(conf)
+ self.assertTrue(isinstance(postagg, models.CustomPostAggregator))
+ self.assertEqual(postagg.name, 'custom_name')
+ self.assertEqual(postagg.post_aggregator['stuff'], 'more_stuff')
+
+ def test_find_postaggs_for_returns_postaggs_and_removes(self):
+ find_postaggs_for = DruidDatasource.find_postaggs_for
+ postagg_names = set(['pa2', 'pa3', 'pa4', 'm1', 'm2', 'm3', 'm4'])
+
+ metrics = {}
+ for i in range(1, 6):
+ emplace(metrics, 'pa' + str(i), True)
+ emplace(metrics, 'm' + str(i), False)
+ postagg_list = find_postaggs_for(postagg_names, metrics)
+ self.assertEqual(3, len(postagg_list))
+ self.assertEqual(4, len(postagg_names))
+ expected_metrics = ['m1', 'm2', 'm3', 'm4']
+ expected_postaggs = set(['pa2', 'pa3', 'pa4'])
+ for postagg in postagg_list:
+ expected_postaggs.remove(postagg.metric_name)
+ for metric in expected_metrics:
+ postagg_names.remove(metric)
+ self.assertEqual(0, len(expected_postaggs))
+ self.assertEqual(0, len(postagg_names))
+
+ def test_recursive_get_fields(self):
+ conf = {
+ 'type': 'quantile',
+ 'fieldName': 'f1',
+ 'field': {
+ 'type': 'custom',
+ 'fields': [{
+ 'type': 'fieldAccess',
+ 'fieldName': 'f2',
+ }, {
+ 'type': 'fieldAccess',
+ 'fieldName': 'f3',
+ }, {
+ 'type': 'quantiles',
+ 'fieldName': 'f4',
+ 'field': {
+ 'type': 'custom',
+ },
+ }, {
+ 'type': 'custom',
+ 'fields': [{
+ 'type': 'fieldAccess',
+ 'fieldName': 'f5',
+ }, {
+ 'type': 'fieldAccess',
+ 'fieldName': 'f2',
+ 'fields': [{
+ 'type': 'fieldAccess',
+ 'fieldName': 'f3',
+ }, {
+ 'type': 'fieldIgnoreMe',
+ 'fieldName': 'f6',
+ }],
+ }],
+ }],
+ },
+ }
+ fields = DruidDatasource.recursive_get_fields(conf)
+ expected = set(['f1', 'f2', 'f3', 'f4', 'f5'])
+ self.assertEqual(5, len(fields))
+ for field in fields:
+ expected.remove(field)
+ self.assertEqual(0, len(expected))
+
+ def test_metrics_and_post_aggs_tree(self):
+ metrics = ['A', 'B', 'm1', 'm2']
+ metrics_dict = {}
+ for i in range(ord('A'), ord('K') + 1):
+ emplace(metrics_dict, chr(i), True)
+ for i in range(1, 10):
+ emplace(metrics_dict, 'm' + str(i), False)
+
+ def depends_on(index, fields):
+ dependents = fields if isinstance(fields, list) else [fields]
+ metrics_dict[index].json_obj = {'fieldNames': dependents}
+
+ depends_on('A', ['m1', 'D', 'C'])
+ depends_on('B', ['B', 'C', 'E', 'F', 'm3'])
+ depends_on('C', ['H', 'I'])
+ depends_on('D', ['m2', 'm5', 'G', 'C'])
+ depends_on('E', ['H', 'I', 'J'])
+ depends_on('F', ['J', 'm5'])
+ depends_on('G', ['m4', 'm7', 'm6', 'A'])
+ depends_on('H', ['A', 'm4', 'I'])
+ depends_on('I', ['H', 'K'])
+ depends_on('J', 'K')
+ depends_on('K', ['m8', 'm9'])
+ all_metrics, postaggs = DruidDatasource.metrics_and_post_aggs(
+ metrics, metrics_dict)
+ expected_metrics = set(all_metrics)
+ self.assertEqual(9, len(all_metrics))
+ for i in range(1, 10):
+ expected_metrics.remove('m' + str(i))
+ self.assertEqual(0, len(expected_metrics))
+ self.assertEqual(11, len(postaggs))
+ for i in range(ord('A'), ord('K') + 1):
+ del postaggs[chr(i)]
+ self.assertEqual(0, len(postaggs))
+
+ def test_metrics_and_post_aggs(self):
+ """
+ Test generation of metrics and post-aggregations from an initial list
+ of superset metrics (which may include the results of either). This
+ primarily tests that specifying a post-aggregator metric will also
+ require the raw aggregation of the associated druid metric column.
+ """
+ metrics_dict = {
+ 'unused_count': DruidMetric(
+ metric_name='unused_count',
+ verbose_name='COUNT(*)',
+ metric_type='count',
+ json=json.dumps({'type': 'count', 'name': 'unused_count'}),
+ ),
+ 'some_sum': DruidMetric(
+ metric_name='some_sum',
+ verbose_name='SUM(*)',
+ metric_type='sum',
+ json=json.dumps({'type': 'sum', 'name': 'sum'}),
+ ),
+ 'a_histogram': DruidMetric(
+ metric_name='a_histogram',
+ verbose_name='APPROXIMATE_HISTOGRAM(*)',
+ metric_type='approxHistogramFold',
+ json=json.dumps(
+ {'type': 'approxHistogramFold', 'name': 'a_histogram'},
+ ),
+ ),
+ 'aCustomMetric': DruidMetric(
+ metric_name='aCustomMetric',
+ verbose_name='MY_AWESOME_METRIC(*)',
+ metric_type='aCustomType',
+ json=json.dumps(
+ {'type': 'customMetric', 'name': 'aCustomMetric'},
+ ),
+ ),
+ 'quantile_p95': DruidMetric(
+ metric_name='quantile_p95',
+ verbose_name='P95(*)',
+ metric_type='postagg',
+ json=json.dumps({
+ 'type': 'quantile',
+ 'probability': 0.95,
+ 'name': 'p95',
+ 'fieldName': 'a_histogram',
+ }),
+ ),
+ 'aCustomPostAgg': DruidMetric(
+ metric_name='aCustomPostAgg',
+ verbose_name='CUSTOM_POST_AGG(*)',
+ metric_type='postagg',
+ json=json.dumps({
+ 'type': 'customPostAgg',
+ 'name': 'aCustomPostAgg',
+ 'field': {
+ 'type': 'fieldAccess',
+ 'fieldName': 'aCustomMetric',
+ },
+ }),
+ ),
+ }
+
+ metrics = ['some_sum']
+ all_metrics, post_aggs = DruidDatasource.metrics_and_post_aggs(
+ metrics, metrics_dict)
+
+ assert all_metrics == ['some_sum']
+ assert post_aggs == {}
+
+ metrics = ['quantile_p95']
+ all_metrics, post_aggs = DruidDatasource.metrics_and_post_aggs(
+ metrics, metrics_dict)
+
+ result_postaggs = set(['quantile_p95'])
+ assert all_metrics == ['a_histogram']
+ assert set(post_aggs.keys()) == result_postaggs
+
+ metrics = ['aCustomPostAgg']
+ all_metrics, post_aggs = DruidDatasource.metrics_and_post_aggs(
+ metrics, metrics_dict)
+
+ result_postaggs = set(['aCustomPostAgg'])
+ assert all_metrics == ['aCustomMetric']
+ assert set(post_aggs.keys()) == result_postaggs
diff --git a/tests/druid_tests.py b/tests/druid_tests.py
index c9dce33..c280da7 100644
--- a/tests/druid_tests.py
+++ b/tests/druid_tests.py
@@ -12,7 +12,7 @@ from mock import Mock, patch
from superset import db, security, sm
from superset.connectors.druid.models import (
- DruidCluster, DruidDatasource, DruidMetric,
+ DruidCluster, DruidDatasource,
)
from .base_tests import SupersetTestCase
@@ -328,91 +328,6 @@ class DruidTests(SupersetTestCase):
permission=permission, view_menu=view_menu).first()
assert pv is not None
- def test_metrics_and_post_aggs(self):
- """
- Test generation of metrics and post-aggregations from an initial list
- of superset metrics (which may include the results of either). This
- primarily tests that specifying a post-aggregator metric will also
- require the raw aggregation of the associated druid metric column.
- """
- metrics_dict = {
- 'unused_count': DruidMetric(
- metric_name='unused_count',
- verbose_name='COUNT(*)',
- metric_type='count',
- json=json.dumps({'type': 'count', 'name': 'unused_count'}),
- ),
- 'some_sum': DruidMetric(
- metric_name='some_sum',
- verbose_name='SUM(*)',
- metric_type='sum',
- json=json.dumps({'type': 'sum', 'name': 'sum'}),
- ),
- 'a_histogram': DruidMetric(
- metric_name='a_histogram',
- verbose_name='APPROXIMATE_HISTOGRAM(*)',
- metric_type='approxHistogramFold',
- json=json.dumps(
- {'type': 'approxHistogramFold', 'name': 'a_histogram'},
- ),
- ),
- 'aCustomMetric': DruidMetric(
- metric_name='aCustomMetric',
- verbose_name='MY_AWESOME_METRIC(*)',
- metric_type='aCustomType',
- json=json.dumps(
- {'type': 'customMetric', 'name': 'aCustomMetric'},
- ),
- ),
- 'quantile_p95': DruidMetric(
- metric_name='quantile_p95',
- verbose_name='P95(*)',
- metric_type='postagg',
- json=json.dumps({
- 'type': 'quantile',
- 'probability': 0.95,
- 'name': 'p95',
- 'fieldName': 'a_histogram',
- }),
- ),
- 'aCustomPostAgg': DruidMetric(
- metric_name='aCustomPostAgg',
- verbose_name='CUSTOM_POST_AGG(*)',
- metric_type='postagg',
- json=json.dumps({
- 'type': 'customPostAgg',
- 'name': 'aCustomPostAgg',
- 'field': {
- 'type': 'fieldAccess',
- 'fieldName': 'aCustomMetric',
- },
- }),
- ),
- }
-
- metrics = ['some_sum']
- all_metrics, post_aggs = DruidDatasource._metrics_and_post_aggs(
- metrics, metrics_dict)
-
- assert all_metrics == ['some_sum']
- assert post_aggs == {}
-
- metrics = ['quantile_p95']
- all_metrics, post_aggs = DruidDatasource._metrics_and_post_aggs(
- metrics, metrics_dict)
-
- result_postaggs = set(['quantile_p95'])
- assert all_metrics == ['a_histogram']
- assert set(post_aggs.keys()) == result_postaggs
-
- metrics = ['aCustomPostAgg']
- all_metrics, post_aggs = DruidDatasource._metrics_and_post_aggs(
- metrics, metrics_dict)
-
- result_postaggs = set(['aCustomPostAgg'])
- assert all_metrics == ['aCustomMetric']
- assert set(post_aggs.keys()) == result_postaggs
-
if __name__ == '__main__':
unittest.main()
--
To stop receiving notification emails like this one, please contact
[email protected].