This is an automated email from the ASF dual-hosted git repository.

villebro pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-superset.git
The following commit(s) were added to refs/heads/master by this push: new 315518d feat: add support for query offset (#10010) 315518d is described below commit 315518d2d22a63de1b2e5f556e7fb5737669ebd6 Author: Ville Brofeldt <33317356+ville...@users.noreply.github.com> AuthorDate: Tue Jun 9 11:46:28 2020 +0300 feat: add support for query offset (#10010) * feat: add support for query offset * Address comments and add new tests --- superset/charts/schemas.py | 12 ++++- superset/common/query_context.py | 4 +- superset/common/query_object.py | 15 ++++--- superset/connectors/druid/models.py | 3 ++ superset/connectors/sqla/models.py | 4 ++ tests/charts/api_tests.py | 89 ++++++++++++++++++++++++++++++------- tests/charts/schema_tests.py | 61 +++++++++++++++++++++++++ 7 files changed, 165 insertions(+), 23 deletions(-) diff --git a/superset/charts/schemas.py b/superset/charts/schemas.py index 609a868..06dc111 100644 --- a/superset/charts/schemas.py +++ b/superset/charts/schemas.py @@ -16,8 +16,9 @@ # under the License. from typing import Any, Dict, Union +from flask_babel import gettext as _ from marshmallow import fields, post_load, Schema, validate, ValidationError -from marshmallow.validate import Length +from marshmallow.validate import Length, Range from superset.common.query_context import QueryContext from superset.exceptions import SupersetException @@ -663,6 +664,15 @@ class ChartDataQueryObjectSchema(Schema): ) row_limit = fields.Integer( description='Maximum row count. Default: `config["ROW_LIMIT"]`', + validate=[ + Range(min=1, error=_("`row_limit` must be greater than or equal to 1")) + ], + ) + row_offset = fields.Integer( + description="Number of rows to skip. Default: `0`", + validate=[ + Range(min=0, error=_("`row_offset` must be greater than or equal to 0")) + ], ) order_desc = fields.Boolean( description="Reverse order. 
Default: `false`", required=False diff --git a/superset/common/query_context.py b/superset/common/query_context.py index 9d31f35..7cd00fc 100644 --- a/superset/common/query_context.py +++ b/superset/common/query_context.py @@ -25,14 +25,13 @@ import numpy as np import pandas as pd from superset import app, cache, db, security_manager +from superset.common.query_object import QueryObject from superset.connectors.base.models import BaseDatasource from superset.connectors.connector_registry import ConnectorRegistry from superset.stats_logger import BaseStatsLogger from superset.utils import core as utils from superset.utils.core import DTTM_ALIAS -from .query_object import QueryObject - config = app.config stats_logger: BaseStatsLogger = config["STATS_LOGGER"] logger = logging.getLogger(__name__) @@ -156,6 +155,7 @@ class QueryContext: query_obj.metrics = [] query_obj.post_processing = [] query_obj.row_limit = min(row_limit, config["SAMPLES_ROW_LIMIT"]) + query_obj.row_offset = 0 query_obj.columns = [o.column_name for o in self.datasource.columns] payload = self.get_df_payload(query_obj) df = payload["df"] diff --git a/superset/common/query_object.py b/superset/common/query_object.py index 188d0b3..d09c75b 100644 --- a/superset/common/query_object.py +++ b/superset/common/query_object.py @@ -30,6 +30,7 @@ from superset.typing import Metric from superset.utils import core as utils, pandas_postprocessing from superset.views.utils import get_time_range_endpoints +config = app.config logger = logging.getLogger(__name__) # TODO: Type Metrics dictionary with TypedDict when it becomes a vanilla python type @@ -66,6 +67,7 @@ class QueryObject: groupby: List[str] metrics: List[Union[Dict[str, Any], str]] row_limit: int + row_offset: int filter: List[Dict[str, Any]] timeseries_limit: int timeseries_limit_metric: Optional[Metric] @@ -85,7 +87,8 @@ class QueryObject: time_shift: Optional[str] = None, is_timeseries: bool = False, timeseries_limit: int = 0, - row_limit: int = 
app.config["ROW_LIMIT"], + row_limit: Optional[int] = None, + row_offset: Optional[int] = None, timeseries_limit_metric: Optional[Metric] = None, order_desc: bool = True, extras: Optional[Dict[str, Any]] = None, @@ -100,10 +103,10 @@ class QueryObject: self.granularity = granularity self.from_dttm, self.to_dttm = utils.get_since_until( relative_start=extras.get( - "relative_start", app.config["DEFAULT_RELATIVE_START_TIME"] + "relative_start", config["DEFAULT_RELATIVE_START_TIME"] ), relative_end=extras.get( - "relative_end", app.config["DEFAULT_RELATIVE_END_TIME"] + "relative_end", config["DEFAULT_RELATIVE_END_TIME"] ), time_range=time_range, time_shift=time_shift, @@ -123,14 +126,15 @@ class QueryObject: for metric in metrics ] - self.row_limit = row_limit + self.row_limit = row_limit or config["ROW_LIMIT"] + self.row_offset = row_offset or 0 self.filter = filters or [] self.timeseries_limit = timeseries_limit self.timeseries_limit_metric = timeseries_limit_metric self.order_desc = order_desc self.extras = extras - if app.config["SIP_15_ENABLED"] and "time_range_endpoints" not in self.extras: + if config["SIP_15_ENABLED"] and "time_range_endpoints" not in self.extras: self.extras["time_range_endpoints"] = get_time_range_endpoints(form_data={}) self.columns = columns or [] @@ -184,6 +188,7 @@ class QueryObject: "is_timeseries": self.is_timeseries, "metrics": self.metrics, "row_limit": self.row_limit, + "row_offset": self.row_offset, "filter": self.filter, "timeseries_limit": self.timeseries_limit, "timeseries_limit_metric": self.timeseries_limit_metric, diff --git a/superset/connectors/druid/models.py b/superset/connectors/druid/models.py index 4de56c9..ccfff6e 100644 --- a/superset/connectors/druid/models.py +++ b/superset/connectors/druid/models.py @@ -1179,6 +1179,7 @@ class DruidDatasource(Model, BaseDatasource): timeseries_limit: Optional[int] = None, timeseries_limit_metric: Optional[Metric] = None, row_limit: Optional[int] = None, + row_offset: Optional[int] 
= None, inner_from_dttm: Optional[datetime] = None, inner_to_dttm: Optional[datetime] = None, orderby: Optional[Any] = None, @@ -1192,6 +1193,8 @@ class DruidDatasource(Model, BaseDatasource): # TODO refactor into using a TBD Query object client = client or self.cluster.get_pydruid_client() row_limit = row_limit or conf.get("ROW_LIMIT") + if row_offset: + raise SupersetException("Offset not implemented for Druid connector") if not is_timeseries: granularity = "all" diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index 4e93d5f..9fb6e4f 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -741,6 +741,7 @@ class SqlaTable(Model, BaseDatasource): timeseries_limit: int = 15, timeseries_limit_metric: Optional[Metric] = None, row_limit: Optional[int] = None, + row_offset: Optional[int] = None, inner_from_dttm: Optional[datetime] = None, inner_to_dttm: Optional[datetime] = None, orderby: Optional[List[Tuple[ColumnElement, bool]]] = None, @@ -753,6 +754,7 @@ class SqlaTable(Model, BaseDatasource): "groupby": groupby, "metrics": metrics, "row_limit": row_limit, + "row_offset": row_offset, "to_dttm": to_dttm, "filter": filter, "columns": {col.column_name: col for col in self.columns}, @@ -967,6 +969,8 @@ class SqlaTable(Model, BaseDatasource): if row_limit: qry = qry.limit(row_limit) + if row_offset: + qry = qry.offset(row_offset) if ( is_timeseries diff --git a/tests/charts/api_tests.py b/tests/charts/api_tests.py index 3e9769c..2723bcd 100644 --- a/tests/charts/api_tests.py +++ b/tests/charts/api_tests.py @@ -14,22 +14,27 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
+# isort:skip_file """Unit tests for Superset""" import json from typing import List, Optional +from unittest import mock import prison from sqlalchemy.sql import func -import tests.test_app +from tests.test_app import app from superset.connectors.connector_registry import ConnectorRegistry from superset.extensions import db, security_manager from superset.models.dashboard import Dashboard from superset.models.slice import Slice +from superset.utils import core as utils from tests.base_api_tests import ApiOwnersTestCaseMixin from tests.base_tests import SupersetTestCase from tests.fixtures.query_context import get_query_context +CHART_DATA_URI = "api/v1/chart/data" + class ChartApiTests(SupersetTestCase, ApiOwnersTestCaseMixin): resource_name = "chart" @@ -634,32 +639,88 @@ class ChartApiTests(SupersetTestCase, ApiOwnersTestCaseMixin): data = json.loads(rv.data.decode("utf-8")) self.assertEqual(data["count"], 0) - def test_chart_data(self): + def test_chart_data_simple(self): """ - Query API: Test chart data query + Chart data API: Test chart data query """ self.login(username="admin") table = self.get_table_by_name("birth_names") - payload = get_query_context(table.name, table.id, table.type) - uri = "api/v1/chart/data" - rv = self.post_assert_metric(uri, payload, "data") + request_payload = get_query_context(table.name, table.id, table.type) + rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data") self.assertEqual(rv.status_code, 200) data = json.loads(rv.data.decode("utf-8")) self.assertEqual(data["result"][0]["rowcount"], 100) + def test_chart_data_limit_offset(self): + """ + Chart data API: Test chart data query with limit and offset + """ + self.login(username="admin") + table = self.get_table_by_name("birth_names") + request_payload = get_query_context(table.name, table.id, table.type) + request_payload["queries"][0]["row_limit"] = 5 + request_payload["queries"][0]["row_offset"] = 0 + request_payload["queries"][0]["orderby"] = [["name", True]] 
+ rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data") + response_payload = json.loads(rv.data.decode("utf-8")) + result = response_payload["result"][0] + self.assertEqual(result["rowcount"], 5) + + # ensure that offset works properly + offset = 2 + expected_name = result["data"][offset]["name"] + request_payload["queries"][0]["row_offset"] = offset + rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data") + response_payload = json.loads(rv.data.decode("utf-8")) + result = response_payload["result"][0] + self.assertEqual(result["rowcount"], 5) + self.assertEqual(result["data"][0]["name"], expected_name) + + @mock.patch( + "superset.common.query_object.config", {**app.config, "ROW_LIMIT": 7}, + ) + def test_chart_data_default_row_limit(self): + """ + Chart data API: Ensure row count doesn't exceed default limit + """ + self.login(username="admin") + table = self.get_table_by_name("birth_names") + request_payload = get_query_context(table.name, table.id, table.type) + del request_payload["queries"][0]["row_limit"] + rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data") + response_payload = json.loads(rv.data.decode("utf-8")) + result = response_payload["result"][0] + self.assertEqual(result["rowcount"], 7) + + @mock.patch( + "superset.common.query_context.config", {**app.config, "SAMPLES_ROW_LIMIT": 5}, + ) + def test_chart_data_default_sample_limit(self): + """ + Chart data API: Ensure sample response row count doesn't exceed default limit + """ + self.login(username="admin") + table = self.get_table_by_name("birth_names") + request_payload = get_query_context(table.name, table.id, table.type) + request_payload["result_type"] = utils.ChartDataResultType.SAMPLES + request_payload["queries"][0]["row_limit"] = 10 + rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data") + response_payload = json.loads(rv.data.decode("utf-8")) + result = response_payload["result"][0] + self.assertEqual(result["rowcount"], 5) + 
def test_chart_data_with_invalid_datasource(self): - """Query API: Test chart data query with invalid schema + """Chart data API: Test chart data query with invalid schema """ self.login(username="admin") table = self.get_table_by_name("birth_names") payload = get_query_context(table.name, table.id, table.type) payload["datasource"] = "abc" - uri = "api/v1/chart/data" - rv = self.post_assert_metric(uri, payload, "data") + rv = self.post_assert_metric(CHART_DATA_URI, payload, "data") self.assertEqual(rv.status_code, 400) def test_chart_data_with_invalid_enum_value(self): - """Query API: Test chart data query with invalid enum value + """Chart data API: Test chart data query with invalid enum value """ self.login(username="admin") table = self.get_table_by_name("birth_names") @@ -668,19 +729,17 @@ class ChartApiTests(SupersetTestCase, ApiOwnersTestCaseMixin): "abc", "EXCLUSIVE", ] - uri = "api/v1/chart/data" - rv = self.client.post(uri, json=payload) + rv = self.client.post(CHART_DATA_URI, json=payload) self.assertEqual(rv.status_code, 400) def test_query_exec_not_allowed(self): """ - Query API: Test chart data query not allowed + Chart data API: Test chart data query not allowed """ self.login(username="gamma") table = self.get_table_by_name("birth_names") payload = get_query_context(table.name, table.id, table.type) - uri = "api/v1/chart/data" - rv = self.post_assert_metric(uri, payload, "data") + rv = self.post_assert_metric(CHART_DATA_URI, payload, "data") self.assertEqual(rv.status_code, 401) def test_datasources(self): diff --git a/tests/charts/schema_tests.py b/tests/charts/schema_tests.py new file mode 100644 index 0000000..5f0ef16 --- /dev/null +++ b/tests/charts/schema_tests.py @@ -0,0 +1,61 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Unit tests for Superset""" +from typing import Any, Dict, Tuple + +from superset.charts.schemas import ChartDataQueryContextSchema +from superset.common.query_context import QueryContext +from tests.base_tests import SupersetTestCase +from tests.fixtures.query_context import get_query_context +from tests.test_app import app + + +def load_query_context(payload: Dict[str, Any]) -> Tuple[QueryContext, Dict[str, Any]]: + return ChartDataQueryContextSchema().load(payload) + + +class SchemaTestCase(SupersetTestCase): + def test_query_context_limit_and_offset(self): + self.login(username="admin") + table_name = "birth_names" + table = self.get_table_by_name(table_name) + payload = get_query_context(table.name, table.id, table.type) + + # Use defaults + payload["queries"][0].pop("row_limit", None) + payload["queries"][0].pop("row_offset", None) + query_context, errors = load_query_context(payload) + self.assertEqual(errors, {}) + query_object = query_context.queries[0] + self.assertEqual(query_object.row_limit, app.config["ROW_LIMIT"]) + self.assertEqual(query_object.row_offset, 0) + + # Valid limit and offset + payload["queries"][0]["row_limit"] = 100 + payload["queries"][0]["row_offset"] = 200 + query_context, errors = ChartDataQueryContextSchema().load(payload) + self.assertEqual(errors, {}) + query_object = query_context.queries[0] + self.assertEqual(query_object.row_limit, 100) + 
self.assertEqual(query_object.row_offset, 200) + + # too low limit and offset + payload["queries"][0]["row_limit"] = 0 + payload["queries"][0]["row_offset"] = -1 + query_context, errors = ChartDataQueryContextSchema().load(payload) + self.assertIn("row_limit", errors["queries"][0]) + self.assertIn("row_offset", errors["queries"][0])