This is an automated email from the ASF dual-hosted git repository.

villebro pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-superset.git


The following commit(s) were added to refs/heads/master by this push:
     new 315518d  feat: add support for query offset (#10010)
315518d is described below

commit 315518d2d22a63de1b2e5f556e7fb5737669ebd6
Author: Ville Brofeldt <33317356+ville...@users.noreply.github.com>
AuthorDate: Tue Jun 9 11:46:28 2020 +0300

    feat: add support for query offset (#10010)
    
    * feat: add support for query offset
    
    * Address comments and add new tests
---
 superset/charts/schemas.py          | 12 ++++-
 superset/common/query_context.py    |  4 +-
 superset/common/query_object.py     | 15 ++++---
 superset/connectors/druid/models.py |  3 ++
 superset/connectors/sqla/models.py  |  4 ++
 tests/charts/api_tests.py           | 89 ++++++++++++++++++++++++++++++-------
 tests/charts/schema_tests.py        | 61 +++++++++++++++++++++++++
 7 files changed, 165 insertions(+), 23 deletions(-)

diff --git a/superset/charts/schemas.py b/superset/charts/schemas.py
index 609a868..06dc111 100644
--- a/superset/charts/schemas.py
+++ b/superset/charts/schemas.py
@@ -16,8 +16,9 @@
 # under the License.
 from typing import Any, Dict, Union
 
+from flask_babel import gettext as _
 from marshmallow import fields, post_load, Schema, validate, ValidationError
-from marshmallow.validate import Length
+from marshmallow.validate import Length, Range
 
 from superset.common.query_context import QueryContext
 from superset.exceptions import SupersetException
@@ -663,6 +664,15 @@ class ChartDataQueryObjectSchema(Schema):
     )
     row_limit = fields.Integer(
         description='Maximum row count. Default: `config["ROW_LIMIT"]`',
+        validate=[
+            Range(min=1, error=_("`row_limit` must be greater than or equal to 1"))
+        ],
+    )
+    row_offset = fields.Integer(
+        description="Number of rows to skip. Default: `0`",
+        validate=[
+            Range(min=0, error=_("`row_offset` must be greater than or equal to 0"))
+        ],
     )
     order_desc = fields.Boolean(
         description="Reverse order. Default: `false`", required=False
diff --git a/superset/common/query_context.py b/superset/common/query_context.py
index 9d31f35..7cd00fc 100644
--- a/superset/common/query_context.py
+++ b/superset/common/query_context.py
@@ -25,14 +25,13 @@ import numpy as np
 import pandas as pd
 
 from superset import app, cache, db, security_manager
+from superset.common.query_object import QueryObject
 from superset.connectors.base.models import BaseDatasource
 from superset.connectors.connector_registry import ConnectorRegistry
 from superset.stats_logger import BaseStatsLogger
 from superset.utils import core as utils
 from superset.utils.core import DTTM_ALIAS
 
-from .query_object import QueryObject
-
 config = app.config
 stats_logger: BaseStatsLogger = config["STATS_LOGGER"]
 logger = logging.getLogger(__name__)
@@ -156,6 +155,7 @@ class QueryContext:
             query_obj.metrics = []
             query_obj.post_processing = []
             query_obj.row_limit = min(row_limit, config["SAMPLES_ROW_LIMIT"])
+            query_obj.row_offset = 0
             query_obj.columns = [o.column_name for o in self.datasource.columns]
         payload = self.get_df_payload(query_obj)
         df = payload["df"]
diff --git a/superset/common/query_object.py b/superset/common/query_object.py
index 188d0b3..d09c75b 100644
--- a/superset/common/query_object.py
+++ b/superset/common/query_object.py
@@ -30,6 +30,7 @@ from superset.typing import Metric
 from superset.utils import core as utils, pandas_postprocessing
 from superset.views.utils import get_time_range_endpoints
 
+config = app.config
 logger = logging.getLogger(__name__)
 
 # TODO: Type Metrics dictionary with TypedDict when it becomes a vanilla python type
@@ -66,6 +67,7 @@ class QueryObject:
     groupby: List[str]
     metrics: List[Union[Dict[str, Any], str]]
     row_limit: int
+    row_offset: int
     filter: List[Dict[str, Any]]
     timeseries_limit: int
     timeseries_limit_metric: Optional[Metric]
@@ -85,7 +87,8 @@ class QueryObject:
         time_shift: Optional[str] = None,
         is_timeseries: bool = False,
         timeseries_limit: int = 0,
-        row_limit: int = app.config["ROW_LIMIT"],
+        row_limit: Optional[int] = None,
+        row_offset: Optional[int] = None,
         timeseries_limit_metric: Optional[Metric] = None,
         order_desc: bool = True,
         extras: Optional[Dict[str, Any]] = None,
@@ -100,10 +103,10 @@ class QueryObject:
         self.granularity = granularity
         self.from_dttm, self.to_dttm = utils.get_since_until(
             relative_start=extras.get(
-                "relative_start", app.config["DEFAULT_RELATIVE_START_TIME"]
+                "relative_start", config["DEFAULT_RELATIVE_START_TIME"]
             ),
             relative_end=extras.get(
-                "relative_end", app.config["DEFAULT_RELATIVE_END_TIME"]
+                "relative_end", config["DEFAULT_RELATIVE_END_TIME"]
             ),
             time_range=time_range,
             time_shift=time_shift,
@@ -123,14 +126,15 @@ class QueryObject:
             for metric in metrics
         ]
 
-        self.row_limit = row_limit
+        self.row_limit = row_limit or config["ROW_LIMIT"]
+        self.row_offset = row_offset or 0
         self.filter = filters or []
         self.timeseries_limit = timeseries_limit
         self.timeseries_limit_metric = timeseries_limit_metric
         self.order_desc = order_desc
         self.extras = extras
 
-        if app.config["SIP_15_ENABLED"] and "time_range_endpoints" not in self.extras:
+        if config["SIP_15_ENABLED"] and "time_range_endpoints" not in self.extras:
             self.extras["time_range_endpoints"] = get_time_range_endpoints(form_data={})
 
         self.columns = columns or []
@@ -184,6 +188,7 @@ class QueryObject:
             "is_timeseries": self.is_timeseries,
             "metrics": self.metrics,
             "row_limit": self.row_limit,
+            "row_offset": self.row_offset,
             "filter": self.filter,
             "timeseries_limit": self.timeseries_limit,
             "timeseries_limit_metric": self.timeseries_limit_metric,
diff --git a/superset/connectors/druid/models.py b/superset/connectors/druid/models.py
index 4de56c9..ccfff6e 100644
--- a/superset/connectors/druid/models.py
+++ b/superset/connectors/druid/models.py
@@ -1179,6 +1179,7 @@ class DruidDatasource(Model, BaseDatasource):
         timeseries_limit: Optional[int] = None,
         timeseries_limit_metric: Optional[Metric] = None,
         row_limit: Optional[int] = None,
+        row_offset: Optional[int] = None,
         inner_from_dttm: Optional[datetime] = None,
         inner_to_dttm: Optional[datetime] = None,
         orderby: Optional[Any] = None,
@@ -1192,6 +1193,8 @@ class DruidDatasource(Model, BaseDatasource):
         # TODO refactor into using a TBD Query object
         client = client or self.cluster.get_pydruid_client()
         row_limit = row_limit or conf.get("ROW_LIMIT")
+        if row_offset:
+            raise SupersetException("Offset not implemented for Druid connector")
 
         if not is_timeseries:
             granularity = "all"
diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py
index 4e93d5f..9fb6e4f 100644
--- a/superset/connectors/sqla/models.py
+++ b/superset/connectors/sqla/models.py
@@ -741,6 +741,7 @@ class SqlaTable(Model, BaseDatasource):
         timeseries_limit: int = 15,
         timeseries_limit_metric: Optional[Metric] = None,
         row_limit: Optional[int] = None,
+        row_offset: Optional[int] = None,
         inner_from_dttm: Optional[datetime] = None,
         inner_to_dttm: Optional[datetime] = None,
         orderby: Optional[List[Tuple[ColumnElement, bool]]] = None,
@@ -753,6 +754,7 @@ class SqlaTable(Model, BaseDatasource):
             "groupby": groupby,
             "metrics": metrics,
             "row_limit": row_limit,
+            "row_offset": row_offset,
             "to_dttm": to_dttm,
             "filter": filter,
             "columns": {col.column_name: col for col in self.columns},
@@ -967,6 +969,8 @@ class SqlaTable(Model, BaseDatasource):
 
         if row_limit:
             qry = qry.limit(row_limit)
+        if row_offset:
+            qry = qry.offset(row_offset)
 
         if (
             is_timeseries
diff --git a/tests/charts/api_tests.py b/tests/charts/api_tests.py
index 3e9769c..2723bcd 100644
--- a/tests/charts/api_tests.py
+++ b/tests/charts/api_tests.py
@@ -14,22 +14,27 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+# isort:skip_file
 """Unit tests for Superset"""
 import json
 from typing import List, Optional
+from unittest import mock
 
 import prison
 from sqlalchemy.sql import func
 
-import tests.test_app
+from tests.test_app import app
 from superset.connectors.connector_registry import ConnectorRegistry
 from superset.extensions import db, security_manager
 from superset.models.dashboard import Dashboard
 from superset.models.slice import Slice
+from superset.utils import core as utils
 from tests.base_api_tests import ApiOwnersTestCaseMixin
 from tests.base_tests import SupersetTestCase
 from tests.fixtures.query_context import get_query_context
 
+CHART_DATA_URI = "api/v1/chart/data"
+
 
 class ChartApiTests(SupersetTestCase, ApiOwnersTestCaseMixin):
     resource_name = "chart"
@@ -634,32 +639,88 @@ class ChartApiTests(SupersetTestCase, ApiOwnersTestCaseMixin):
         data = json.loads(rv.data.decode("utf-8"))
         self.assertEqual(data["count"], 0)
 
-    def test_chart_data(self):
+    def test_chart_data_simple(self):
         """
-        Query API: Test chart data query
+        Chart data API: Test chart data query
         """
         self.login(username="admin")
         table = self.get_table_by_name("birth_names")
-        payload = get_query_context(table.name, table.id, table.type)
-        uri = "api/v1/chart/data"
-        rv = self.post_assert_metric(uri, payload, "data")
+        request_payload = get_query_context(table.name, table.id, table.type)
+        rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
         self.assertEqual(rv.status_code, 200)
         data = json.loads(rv.data.decode("utf-8"))
         self.assertEqual(data["result"][0]["rowcount"], 100)
 
+    def test_chart_data_limit_offset(self):
+        """
+        Chart data API: Test chart data query with limit and offset
+        """
+        self.login(username="admin")
+        table = self.get_table_by_name("birth_names")
+        request_payload = get_query_context(table.name, table.id, table.type)
+        request_payload["queries"][0]["row_limit"] = 5
+        request_payload["queries"][0]["row_offset"] = 0
+        request_payload["queries"][0]["orderby"] = [["name", True]]
+        rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
+        response_payload = json.loads(rv.data.decode("utf-8"))
+        result = response_payload["result"][0]
+        self.assertEqual(result["rowcount"], 5)
+
+        # ensure that offset works properly
+        offset = 2
+        expected_name = result["data"][offset]["name"]
+        request_payload["queries"][0]["row_offset"] = offset
+        rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
+        response_payload = json.loads(rv.data.decode("utf-8"))
+        result = response_payload["result"][0]
+        self.assertEqual(result["rowcount"], 5)
+        self.assertEqual(result["data"][0]["name"], expected_name)
+
+    @mock.patch(
+        "superset.common.query_object.config", {**app.config, "ROW_LIMIT": 7},
+    )
+    def test_chart_data_default_row_limit(self):
+        """
+        Chart data API: Ensure row count doesn't exceed default limit
+        """
+        self.login(username="admin")
+        table = self.get_table_by_name("birth_names")
+        request_payload = get_query_context(table.name, table.id, table.type)
+        del request_payload["queries"][0]["row_limit"]
+        rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
+        response_payload = json.loads(rv.data.decode("utf-8"))
+        result = response_payload["result"][0]
+        self.assertEqual(result["rowcount"], 7)
+
+    @mock.patch(
+        "superset.common.query_context.config", {**app.config, "SAMPLES_ROW_LIMIT": 5},
+    )
+    def test_chart_data_default_sample_limit(self):
+        """
+        Chart data API: Ensure sample response row count doesn't exceed default limit
+        """
+        self.login(username="admin")
+        table = self.get_table_by_name("birth_names")
+        request_payload = get_query_context(table.name, table.id, table.type)
+        request_payload["result_type"] = utils.ChartDataResultType.SAMPLES
+        request_payload["queries"][0]["row_limit"] = 10
+        rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
+        response_payload = json.loads(rv.data.decode("utf-8"))
+        result = response_payload["result"][0]
+        self.assertEqual(result["rowcount"], 5)
+
     def test_chart_data_with_invalid_datasource(self):
-        """Query API: Test chart data query with invalid schema
+        """Chart data API: Test chart data query with invalid schema
         """
         self.login(username="admin")
         table = self.get_table_by_name("birth_names")
         payload = get_query_context(table.name, table.id, table.type)
         payload["datasource"] = "abc"
-        uri = "api/v1/chart/data"
-        rv = self.post_assert_metric(uri, payload, "data")
+        rv = self.post_assert_metric(CHART_DATA_URI, payload, "data")
         self.assertEqual(rv.status_code, 400)
 
     def test_chart_data_with_invalid_enum_value(self):
-        """Query API: Test chart data query with invalid enum value
+        """Chart data API: Test chart data query with invalid enum value
         """
         self.login(username="admin")
         table = self.get_table_by_name("birth_names")
@@ -668,19 +729,17 @@ class ChartApiTests(SupersetTestCase, ApiOwnersTestCaseMixin):
             "abc",
             "EXCLUSIVE",
         ]
-        uri = "api/v1/chart/data"
-        rv = self.client.post(uri, json=payload)
+        rv = self.client.post(CHART_DATA_URI, json=payload)
         self.assertEqual(rv.status_code, 400)
 
     def test_query_exec_not_allowed(self):
         """
-        Query API: Test chart data query not allowed
+        Chart data API: Test chart data query not allowed
         """
         self.login(username="gamma")
         table = self.get_table_by_name("birth_names")
         payload = get_query_context(table.name, table.id, table.type)
-        uri = "api/v1/chart/data"
-        rv = self.post_assert_metric(uri, payload, "data")
+        rv = self.post_assert_metric(CHART_DATA_URI, payload, "data")
         self.assertEqual(rv.status_code, 401)
 
     def test_datasources(self):
diff --git a/tests/charts/schema_tests.py b/tests/charts/schema_tests.py
new file mode 100644
index 0000000..5f0ef16
--- /dev/null
+++ b/tests/charts/schema_tests.py
@@ -0,0 +1,61 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Unit tests for Superset"""
+from typing import Any, Dict, Tuple
+
+from superset.charts.schemas import ChartDataQueryContextSchema
+from superset.common.query_context import QueryContext
+from tests.base_tests import SupersetTestCase
+from tests.fixtures.query_context import get_query_context
+from tests.test_app import app
+
+
+def load_query_context(payload: Dict[str, Any]) -> Tuple[QueryContext, Dict[str, Any]]:
+    return ChartDataQueryContextSchema().load(payload)
+
+
+class SchemaTestCase(SupersetTestCase):
+    def test_query_context_limit_and_offset(self):
+        self.login(username="admin")
+        table_name = "birth_names"
+        table = self.get_table_by_name(table_name)
+        payload = get_query_context(table.name, table.id, table.type)
+
+        # Use defaults
+        payload["queries"][0].pop("row_limit", None)
+        payload["queries"][0].pop("row_offset", None)
+        query_context, errors = load_query_context(payload)
+        self.assertEqual(errors, {})
+        query_object = query_context.queries[0]
+        self.assertEqual(query_object.row_limit, app.config["ROW_LIMIT"])
+        self.assertEqual(query_object.row_offset, 0)
+
+        # Valid limit and offset
+        payload["queries"][0]["row_limit"] = 100
+        payload["queries"][0]["row_offset"] = 200
+        query_context, errors = ChartDataQueryContextSchema().load(payload)
+        self.assertEqual(errors, {})
+        query_object = query_context.queries[0]
+        self.assertEqual(query_object.row_limit, 100)
+        self.assertEqual(query_object.row_offset, 200)
+
+        # too low limit and offset
+        payload["queries"][0]["row_limit"] = 0
+        payload["queries"][0]["row_offset"] = -1
+        query_context, errors = ChartDataQueryContextSchema().load(payload)
+        self.assertIn("row_limit", errors["queries"][0])
+        self.assertIn("row_offset", errors["queries"][0])

Reply via email to