This is an automated email from the ASF dual-hosted git repository.
villebro pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-superset.git
The following commit(s) were added to refs/heads/master by this push:
new a6cedaa chore: Improve chart data API + schemas + tests (#9599)
a6cedaa is described below
commit a6cedaaa879348aca49a520793bb20e63d152a1f
Author: Ville Brofeldt <[email protected]>
AuthorDate: Thu Apr 23 14:30:48 2020 +0300
chore: Improve chart data API + schemas + tests (#9599)
* Make all fields optional in QueryObject and fix having_druid schema
* fix: change datasource type `sql` to `table`
* lint
* Add missing fields
* Refactor tests
* Linting
* Refactor query context fixtures
* Add typing to test func
---
superset/charts/schemas.py | 208 ++++++++++++++++++++++---------------
superset/common/query_object.py | 61 ++++++++---
superset/connectors/base/models.py | 4 +-
tests/access_tests.py | 4 +-
tests/base_tests.py | 21 ++--
tests/charts/api_tests.py | 63 +++++------
tests/core_tests.py | 111 --------------------
tests/dict_import_export_tests.py | 12 +--
tests/fixtures/query_context.py | 103 ++++++++++++++++++
tests/import_export_tests.py | 10 +-
tests/query_context_tests.py | 94 +++++++++++++++++
11 files changed, 421 insertions(+), 270 deletions(-)
diff --git a/superset/charts/schemas.py b/superset/charts/schemas.py
index 2743732..d7438aa 100644
--- a/superset/charts/schemas.py
+++ b/superset/charts/schemas.py
@@ -16,7 +16,7 @@
# under the License.
from typing import Any, Dict, Union
-from marshmallow import fields, post_load, Schema, ValidationError
+from marshmallow import fields, post_load, Schema, validate, ValidationError
from marshmallow.validate import Length
from superset.common.query_context import QueryContext
@@ -77,13 +77,15 @@ class ChartDataAdhocMetricSchema(Schema):
expressionType = fields.String(
description="Simple or SQL metric",
required=True,
- enum=["SIMPLE", "SQL"],
+ validate=validate.OneOf(choices=("SIMPLE", "SQL")),
example="SQL",
)
aggregate = fields.String(
description="Aggregation operator. Only required for simple expression
types.",
required=False,
- enum=["AVG", "COUNT", "COUNT_DISTINCT", "MAX", "MIN", "SUM"],
+ validate=validate.OneOf(
+ choices=("AVG", "COUNT", "COUNT_DISTINCT", "MAX", "MIN", "SUM")
+ ),
)
column = fields.Nested(ChartDataColumnSchema)
sqlExpression = fields.String(
@@ -178,28 +180,30 @@ class
ChartDataRollingOptionsSchema(ChartDataPostProcessingOperationOptionsSchem
)
rolling_type = fields.String(
description="Type of rolling window. Any numpy function will work.",
- enum=[
- "average",
- "argmin",
- "argmax",
- "cumsum",
- "cumprod",
- "max",
- "mean",
- "median",
- "nansum",
- "nanmin",
- "nanmax",
- "nanmean",
- "nanmedian",
- "min",
- "percentile",
- "prod",
- "product",
- "std",
- "sum",
- "var",
- ],
+ validate=validate.OneOf(
+ choices=(
+ "average",
+ "argmin",
+ "argmax",
+ "cumsum",
+ "cumprod",
+ "max",
+ "mean",
+ "median",
+ "nansum",
+ "nanmin",
+ "nanmax",
+ "nanmean",
+ "nanmedian",
+ "min",
+ "percentile",
+ "prod",
+ "product",
+ "std",
+ "sum",
+ "var",
+ )
+ ),
required=True,
example="percentile",
)
@@ -225,23 +229,25 @@ class
ChartDataRollingOptionsSchema(ChartDataPostProcessingOperationOptionsSchem
"additional parameters to `rolling_type_options`. For instance, "
"to use `gaussian`, the parameter `std` needs to be provided.",
required=False,
- enum=[
- "boxcar",
- "triang",
- "blackman",
- "hamming",
- "bartlett",
- "parzen",
- "bohman",
- "blackmanharris",
- "nuttall",
- "barthann",
- "kaiser",
- "gaussian",
- "general_gaussian",
- "slepian",
- "exponential",
- ],
+ validate=validate.OneOf(
+ choices=(
+ "boxcar",
+ "triang",
+ "blackman",
+ "hamming",
+ "bartlett",
+ "parzen",
+ "bohman",
+ "blackmanharris",
+ "nuttall",
+ "barthann",
+ "kaiser",
+ "gaussian",
+ "general_gaussian",
+ "slepian",
+ "exponential",
+ )
+ ),
)
min_periods = fields.Integer(
description="The minimum amount of periods required for a row to be
included "
@@ -333,7 +339,9 @@ class ChartDataPostProcessingOperationSchema(Schema):
operation = fields.String(
description="Post processing operation type",
required=True,
- enum=["aggregate", "pivot", "rolling", "select", "sort"],
+ validate=validate.OneOf(
+ choices=("aggregate", "pivot", "rolling", "select", "sort")
+ ),
example="aggregate",
)
options = fields.Nested(
@@ -362,7 +370,9 @@ class ChartDataFilterSchema(Schema):
)
op = fields.String( # pylint: disable=invalid-name
description="The comparison operator.",
- enum=[filter_op.value for filter_op in utils.FilterOperator],
+ validate=validate.OneOf(
+ choices=[filter_op.value for filter_op in utils.FilterOperator]
+ ),
required=True,
example="IN",
)
@@ -376,21 +386,23 @@ class ChartDataFilterSchema(Schema):
class ChartDataExtrasSchema(Schema):
time_range_endpoints = fields.List(
- fields.String(enum=["INCLUSIVE", "EXCLUSIVE"]),
- description="A list with two values, stating if start/end should be "
- "inclusive/exclusive.",
- required=False,
+ fields.String(
+ validate=validate.OneOf(choices=("INCLUSIVE", "EXCLUSIVE")),
+ description="A list with two values, stating if start/end should
be "
+ "inclusive/exclusive.",
+ required=False,
+ )
)
relative_start = fields.String(
description="Start time for relative time deltas. "
'Default: `config["DEFAULT_RELATIVE_START_TIME"]`',
- enum=["today", "now"],
+ validate=validate.OneOf(choices=("today", "now")),
required=False,
)
relative_end = fields.String(
description="End time for relative time deltas. "
'Default: `config["DEFAULT_RELATIVE_START_TIME"]`',
- enum=["today", "now"],
+ validate=validate.OneOf(choices=("today", "now")),
required=False,
)
where = fields.String(
@@ -402,35 +414,54 @@ class ChartDataExtrasSchema(Schema):
"AND operator.",
required=False,
)
- having_druid = fields.String(
+ having_druid = fields.List(
+ fields.Nested(ChartDataFilterSchema),
description="HAVING filters to be added to legacy Druid datasource
queries.",
required=False,
)
+ time_grain_sqla = fields.String(
+ description="To what level of granularity should the temporal column
be "
+ "aggregated. Supports "
+ "[ISO 8601](https://en.wikipedia.org/wiki/ISO_8601#Durations)
durations.",
+ validate=validate.OneOf(
+ choices=(
+ "PT1S",
+ "PT1M",
+ "PT5M",
+ "PT10M",
+ "PT15M",
+ "PT0.5H",
+ "PT1H",
+ "P1D",
+ "P1W",
+ "P1M",
+ "P0.25Y",
+ "P1Y",
+ ),
+ ),
+ required=False,
+ example="P1D",
+ )
+ druid_time_origin = fields.String(
+ description="Starting point for time grain counting on legacy Druid "
+ "datasources. Used to change e.g. Monday/Sunday first-day-of-week.",
+ required=False,
+ )
class ChartDataQueryObjectSchema(Schema):
filters = fields.List(fields.Nested(ChartDataFilterSchema), required=False)
granularity = fields.String(
- description="To what level of granularity should the temporal column
be "
- "aggregated. Supports "
- "[ISO 8601](https://en.wikipedia.org/wiki/ISO_8601#Durations) "
- "durations.",
- enum=[
- "PT1S",
- "PT1M",
- "PT5M",
- "PT10M",
- "PT15M",
- "PT0.5H",
- "PT1H",
- "P1D",
- "P1W",
- "P1M",
- "P0.25Y",
- "P1Y",
- ],
+ description="Name of temporal column used for time filtering. For
legacy Druid "
+ "datasources this defines the time grain.",
required=False,
- example="P1D",
+ )
+ granularity_sqla = fields.String(
+ description="Name of temporal column used for time filtering for SQL "
+ "datasources. This field is deprecated, use `granularity` "
+ "instead.",
+ required=False,
+ deprecated=True,
)
groupby = fields.List(
fields.String(description="Columns by which to group the query.",),
@@ -441,6 +472,7 @@ class ChartDataQueryObjectSchema(Schema):
"references to datasource metrics (strings), or ad-hoc metrics"
"which are defined only within the query object. See "
"`ChartDataAdhocMetricSchema` for the structure of ad-hoc metrics.",
+ required=False,
)
post_processing = fields.List(
fields.Nested(ChartDataPostProcessingOperationSchema),
@@ -450,7 +482,8 @@ class ChartDataQueryObjectSchema(Schema):
)
time_range = fields.String(
description="A time rage, either expressed as a colon separated string
"
- "`since : until`. Valid formats for `since` and `until` are: \n"
+ "`since : until` or human readable freeform. Valid formats for "
+ "`since` and `until` are: \n"
"- ISO 8601\n"
"- X days/years/hours/day/year/weeks\n"
"- X days/years/hours/day/year/weeks ago\n"
@@ -488,7 +521,7 @@ class ChartDataQueryObjectSchema(Schema):
order_desc = fields.Boolean(
description="Reverse order. Default: `false`", required=False
)
- extras = fields.Dict(description=" Default: `{}`", required=False)
+ extras = fields.Nested(ChartDataExtrasSchema, required=False)
columns = fields.List(fields.String(), description="", required=False,)
orderby = fields.List(
fields.List(fields.Raw()),
@@ -499,13 +532,13 @@ class ChartDataQueryObjectSchema(Schema):
)
where = fields.String(
description="WHERE clause to be added to queries using AND operator."
- "This field is deprecated, and should be passed to `extras`.",
+ "This field is deprecated and should be passed to `extras`.",
required=False,
deprecated=True,
)
having = fields.String(
description="HAVING clause to be added to aggregate queries using "
- "AND operator. This field is deprecated, and should be passed "
+ "AND operator. This field is deprecated and should be passed "
"to `extras`.",
required=False,
deprecated=True,
@@ -513,7 +546,7 @@ class ChartDataQueryObjectSchema(Schema):
having_filters = fields.List(
fields.Dict(),
description="HAVING filters to be added to legacy Druid datasource
queries. "
- "This field is deprecated, and should be passed to `extras` "
+ "This field is deprecated and should be passed to `extras` "
"as `filters_druid`.",
required=False,
deprecated=True,
@@ -523,7 +556,10 @@ class ChartDataQueryObjectSchema(Schema):
class ChartDataDatasourceSchema(Schema):
description = "Chart datasource"
id = fields.Integer(description="Datasource id", required=True,)
- type = fields.String(description="Datasource type", enum=["druid", "sql"])
+ type = fields.String(
+ description="Datasource type",
+ validate=validate.OneOf(choices=("druid", "table")),
+ )
class ChartDataQueryContextSchema(Schema):
@@ -561,15 +597,17 @@ class ChartDataResponseResult(Schema):
)
status = fields.String(
description="Status of the query",
- enum=[
- "stopped",
- "failed",
- "pending",
- "running",
- "scheduled",
- "success",
- "timed_out",
- ],
+ validate=validate.OneOf(
+ choices=(
+ "stopped",
+ "failed",
+ "pending",
+ "running",
+ "scheduled",
+ "success",
+ "timed_out",
+ )
+ ),
allow_none=False,
)
stacktrace = fields.String(
diff --git a/superset/common/query_object.py b/superset/common/query_object.py
index 0a83ef7..ea1f3f5 100644
--- a/superset/common/query_object.py
+++ b/superset/common/query_object.py
@@ -35,15 +35,19 @@ logger = logging.getLogger(__name__)
# https://github.com/python/mypy/issues/5288
-class DeprecatedExtrasField(NamedTuple):
- name: str
- extras_name: str
+class DeprecatedField(NamedTuple):
+ old_name: str
+ new_name: str
+DEPRECATED_FIELDS = (
+ DeprecatedField(old_name="granularity_sqla", new_name="granularity"),
+)
+
DEPRECATED_EXTRAS_FIELDS = (
- DeprecatedExtrasField(name="where", extras_name="where"),
- DeprecatedExtrasField(name="having", extras_name="having"),
- DeprecatedExtrasField(name="having_filters", extras_name="having_druid"),
+ DeprecatedField(old_name="where", new_name="where"),
+ DeprecatedField(old_name="having", new_name="having"),
+ DeprecatedField(old_name="having_filters", new_name="having_druid"),
)
@@ -53,7 +57,7 @@ class QueryObject:
and druid. The query objects are constructed on the client.
"""
- granularity: str
+ granularity: Optional[str]
from_dttm: datetime
to_dttm: datetime
is_timeseries: bool
@@ -72,8 +76,8 @@ class QueryObject:
def __init__(
self,
- granularity: str,
- metrics: List[Union[Dict[str, Any], str]],
+ granularity: Optional[str] = None,
+ metrics: Optional[List[Union[Dict[str, Any], str]]] = None,
groupby: Optional[List[str]] = None,
filters: Optional[List[Dict[str, Any]]] = None,
time_range: Optional[str] = None,
@@ -89,6 +93,7 @@ class QueryObject:
post_processing: Optional[List[Dict[str, Any]]] = None,
**kwargs: Any,
):
+ metrics = metrics or []
extras = extras or {}
is_sip_38 = is_feature_enabled("SIP_38_VIZ_REARCHITECTURE")
self.granularity = granularity
@@ -131,22 +136,44 @@ class QueryObject:
if is_sip_38 and groupby:
self.columns += groupby
logger.warning(
- f"The field groupby is deprecated. Viz plugins should "
- f"pass all selectables via the columns field"
+ f"The field `groupby` is deprecated. Viz plugins should "
+ f"pass all selectables via the `columns` field"
)
self.orderby = orderby or []
- # move deprecated fields to extras
+ # rename deprecated fields
+ for field in DEPRECATED_FIELDS:
+ if field.old_name in kwargs:
+ logger.warning(
+ f"The field `{field.old_name}` is deprecated, please use "
+ f"`{field.new_name}` instead."
+ )
+ value = kwargs[field.old_name]
+ if value:
+ if hasattr(self, field.new_name):
+ logger.warning(
+ f"The field `{field.new_name}` is already
populated, "
+ f"replacing value with contents from
`{field.old_name}`."
+ )
+ setattr(self, field.new_name, value)
+
+ # move deprecated extras fields to extras
for field in DEPRECATED_EXTRAS_FIELDS:
- if field.name in kwargs:
+ if field.old_name in kwargs:
logger.warning(
- f"The field `{field.name} is deprecated, and should be "
- f"passed to `extras` via the `{field.extras_name}`
property"
+ f"The field `{field.old_name}` is deprecated and should be
"
+ f"passed to `extras` via the `{field.new_name}` property."
)
- value = kwargs[field.name]
+ value = kwargs[field.old_name]
if value:
- self.extras[field.extras_name] = value
+ if hasattr(self.extras, field.new_name):
+ logger.warning(
+ f"The field `{field.new_name}` is already
populated in "
+ f"`extras`, replacing value with contents "
+ f"from `{field.old_name}`."
+ )
+ self.extras[field.new_name] = value
def to_dict(self) -> Dict[str, Any]:
query_object_dict = {
diff --git a/superset/connectors/base/models.py
b/superset/connectors/base/models.py
index 2b6e0d2..2b051d3 100644
--- a/superset/connectors/base/models.py
+++ b/superset/connectors/base/models.py
@@ -366,7 +366,9 @@ class BaseDatasource(
def default_query(qry) -> Query:
return qry
- def get_column(self, column_name: str) -> Optional["BaseColumn"]:
+ def get_column(self, column_name: Optional[str]) -> Optional["BaseColumn"]:
+ if not column_name:
+ return None
for col in self.columns:
if col.column_name == column_name:
return col
diff --git a/tests/access_tests.py b/tests/access_tests.py
index 58affb6..f5ac8e0 100644
--- a/tests/access_tests.py
+++ b/tests/access_tests.py
@@ -385,7 +385,7 @@ class RequestAccessTests(SupersetTestCase):
)
self.assertEqual(
"[Superset] Access to the datasource {} was granted".format(
- self.get_table(ds_1_id).full_name
+ self.get_table_by_id(ds_1_id).full_name
),
call_args[2]["Subject"],
)
@@ -426,7 +426,7 @@ class RequestAccessTests(SupersetTestCase):
)
self.assertEqual(
"[Superset] Access to the datasource {} was granted".format(
- self.get_table(ds_2_id).full_name
+ self.get_table_by_id(ds_2_id).full_name
),
call_args[2]["Subject"],
)
diff --git a/tests/base_tests.py b/tests/base_tests.py
index 370adf3..ebe1d9a 100644
--- a/tests/base_tests.py
+++ b/tests/base_tests.py
@@ -18,16 +18,18 @@
"""Unit tests for Superset"""
import imp
import json
-from typing import Union, Dict
+from typing import Dict, Union
from unittest.mock import Mock, patch
import pandas as pd
from flask import Response
from flask_appbuilder.security.sqla import models as ab_models
from flask_testing import TestCase
+from sqlalchemy.orm import Session
from tests.test_app import app # isort:skip
from superset import db, security_manager
+from superset.connectors.base.models import BaseDatasource
from superset.connectors.druid.models import DruidCluster, DruidDatasource
from superset.connectors.sqla.models import SqlaTable
from superset.models import core as models
@@ -103,7 +105,8 @@ class SupersetTestCase(TestCase):
session.add(druid_datasource2)
session.commit()
- def get_table(self, table_id):
+ @staticmethod
+ def get_table_by_id(table_id: int) -> SqlaTable:
return db.session.query(SqlaTable).filter_by(id=table_id).one()
@staticmethod
@@ -127,21 +130,25 @@ class SupersetTestCase(TestCase):
resp = self.get_resp("/login/", data=dict(username=username,
password=password))
self.assertNotIn("User confirmation needed", resp)
- def get_slice(self, slice_name, session):
+ def get_slice(self, slice_name: str, session: Session) -> Slice:
slc = session.query(Slice).filter_by(slice_name=slice_name).one()
session.expunge_all()
return slc
- def get_table_by_name(self, name):
+ @staticmethod
+ def get_table_by_name(name: str) -> SqlaTable:
return db.session.query(SqlaTable).filter_by(table_name=name).one()
- def get_database_by_id(self, db_id):
+ @staticmethod
+ def get_database_by_id(db_id: int) -> Database:
return db.session.query(Database).filter_by(id=db_id).one()
- def get_druid_ds_by_name(self, name):
+ @staticmethod
+ def get_druid_ds_by_name(name: str) -> DruidDatasource:
return
db.session.query(DruidDatasource).filter_by(datasource_name=name).first()
- def get_datasource_mock(self):
+ @staticmethod
+ def get_datasource_mock() -> BaseDatasource:
datasource = Mock()
results = Mock()
results.query = Mock()
diff --git a/tests/charts/api_tests.py b/tests/charts/api_tests.py
index 0fb8471..0a7ab58 100644
--- a/tests/charts/api_tests.py
+++ b/tests/charts/api_tests.py
@@ -16,7 +16,7 @@
# under the License.
"""Unit tests for Superset"""
import json
-from typing import Any, Dict, List, Optional
+from typing import List, Optional
import prison
from sqlalchemy.sql import func
@@ -28,6 +28,7 @@ from superset.models.dashboard import Dashboard
from superset.models.slice import Slice
from tests.base_api_tests import ApiOwnersTestCaseMixin
from tests.base_tests import SupersetTestCase
+from tests.fixtures.query_context import get_query_context
class ChartApiTests(SupersetTestCase, ApiOwnersTestCaseMixin):
@@ -69,32 +70,6 @@ class ChartApiTests(SupersetTestCase,
ApiOwnersTestCaseMixin):
db.session.commit()
return slice
- def _get_query_context(self) -> Dict[str, Any]:
- self.login(username="admin")
- slc = self.get_slice("Girl Name Cloud", db.session)
- return {
- "datasource": {"id": slc.datasource_id, "type":
slc.datasource_type},
- "queries": [
- {
- "extras": {"where": ""},
- "granularity": "ds",
- "groupby": ["name"],
- "is_timeseries": False,
- "metrics": [{"label": "sum__num"}],
- "order_desc": True,
- "orderby": [],
- "row_limit": 100,
- "time_range": "100 years ago : now",
- "timeseries_limit": 0,
- "timeseries_limit_metric": None,
- "filters": [{"col": "gender", "op": "==", "val": "boy"}],
- "having": "",
- "having_filters": [],
- "where": "",
- }
- ],
- }
-
def test_delete_chart(self):
"""
Chart API: Test delete
@@ -662,22 +637,37 @@ class ChartApiTests(SupersetTestCase,
ApiOwnersTestCaseMixin):
Query API: Test chart data query
"""
self.login(username="admin")
- query_context = self._get_query_context()
+ table = self.get_table_by_name("birth_names")
+ payload = get_query_context(table.name, table.id, table.type)
uri = "api/v1/chart/data"
- rv = self.post_assert_metric(uri, query_context, "data")
+ rv = self.post_assert_metric(uri, payload, "data")
self.assertEqual(rv.status_code, 200)
data = json.loads(rv.data.decode("utf-8"))
self.assertEqual(data["result"][0]["rowcount"], 100)
- def test_invalid_chart_data(self):
+ def test_chart_data_with_invalid_datasource(self):
+ """Query API: Test chart data query with invalid schema
"""
- Query API: Test chart data query with invalid schema
+ self.login(username="admin")
+ table = self.get_table_by_name("birth_names")
+ payload = get_query_context(table.name, table.id, table.type)
+ payload["datasource"] = "abc"
+ uri = "api/v1/chart/data"
+ rv = self.post_assert_metric(uri, payload, "data")
+ self.assertEqual(rv.status_code, 400)
+
+ def test_chart_data_with_invalid_enum_value(self):
+ """Query API: Test chart data query with invalid enum value
"""
self.login(username="admin")
- query_context = self._get_query_context()
- query_context["datasource"] = "abc"
+ table = self.get_table_by_name("birth_names")
+ payload = get_query_context(table.name, table.id, table.type)
+ payload["queries"][0]["extras"]["time_range_endpoints"] = [
+ "abc",
+ "EXCLUSIVE",
+ ]
uri = "api/v1/chart/data"
- rv = self.client.post(uri, json=query_context)
+ rv = self.client.post(uri, json=payload)
self.assertEqual(rv.status_code, 400)
def test_query_exec_not_allowed(self):
@@ -685,9 +675,10 @@ class ChartApiTests(SupersetTestCase,
ApiOwnersTestCaseMixin):
Query API: Test chart data query not allowed
"""
self.login(username="gamma")
- query_context = self._get_query_context()
+ table = self.get_table_by_name("birth_names")
+ payload = get_query_context(table.name, table.id, table.type)
uri = "api/v1/chart/data"
- rv = self.post_assert_metric(uri, query_context, "data")
+ rv = self.post_assert_metric(uri, payload, "data")
self.assertEqual(rv.status_code, 401)
def test_datasources(self):
diff --git a/tests/core_tests.py b/tests/core_tests.py
index 04cc587..ecbfa05 100644
--- a/tests/core_tests.py
+++ b/tests/core_tests.py
@@ -28,7 +28,6 @@ import pytz
import random
import re
import string
-from typing import Any, Dict
import unittest
from unittest import mock, skipUnless
@@ -44,8 +43,6 @@ from superset import (
sql_lab,
is_feature_enabled,
)
-from superset.common.query_context import QueryContext
-from superset.connectors.connector_registry import ConnectorRegistry
from superset.connectors.sqla.models import SqlaTable
from superset.db_engine_specs.base import BaseEngineSpec
from superset.db_engine_specs.mssql import MssqlEngineSpec
@@ -111,61 +108,6 @@ class CoreTests(SupersetTestCase):
resp = self.client.get("/superset/slice/-1/")
assert resp.status_code == 404
- def _get_query_context(self) -> Dict[str, Any]:
- self.login(username="admin")
- slc = self.get_slice("Girl Name Cloud", db.session)
- return {
- "datasource": {"id": slc.datasource_id, "type":
slc.datasource_type},
- "queries": [
- {
- "granularity": "ds",
- "groupby": ["name"],
- "metrics": [{"label": "sum__num"}],
- "filters": [],
- "row_limit": 100,
- }
- ],
- }
-
- def _get_query_context_with_post_processing(self) -> Dict[str, Any]:
- self.login(username="admin")
- slc = self.get_slice("Girl Name Cloud", db.session)
- return {
- "datasource": {"id": slc.datasource_id, "type":
slc.datasource_type},
- "queries": [
- {
- "granularity": "ds",
- "groupby": ["name", "state"],
- "metrics": [{"label": "sum__num"}],
- "filters": [],
- "row_limit": 100,
- "post_processing": [
- {
- "operation": "aggregate",
- "options": {
- "groupby": ["state"],
- "aggregates": {
- "q1": {
- "operator": "percentile",
- "column": "sum__num",
- "options": {"q": 25},
- },
- "median": {
- "operator": "median",
- "column": "sum__num",
- },
- },
- },
- },
- {
- "operation": "sort",
- "options": {"columns": {"q1": False, "state":
True},},
- },
- ],
- }
- ],
- }
-
def test_viz_cache_key(self):
self.login(username="admin")
slc = self.get_slice("Girls", db.session)
@@ -178,45 +120,6 @@ class CoreTests(SupersetTestCase):
qobj["groupby"] = []
self.assertNotEqual(cache_key, viz.cache_key(qobj))
- def test_cache_key_changes_when_datasource_is_updated(self):
- qc_dict = self._get_query_context()
-
- # construct baseline cache_key
- query_context = QueryContext(**qc_dict)
- query_object = query_context.queries[0]
- cache_key_original = query_context.cache_key(query_object)
-
- # make temporary change and revert it to refresh the changed_on
property
- datasource = ConnectorRegistry.get_datasource(
- datasource_type=qc_dict["datasource"]["type"],
- datasource_id=qc_dict["datasource"]["id"],
- session=db.session,
- )
- description_original = datasource.description
- datasource.description = "temporary description"
- db.session.commit()
- datasource.description = description_original
- db.session.commit()
-
- # create new QueryContext with unchanged attributes and extract new
cache_key
- query_context = QueryContext(**qc_dict)
- query_object = query_context.queries[0]
- cache_key_new = query_context.cache_key(query_object)
-
- # the new cache_key should be different due to updated datasource
- self.assertNotEqual(cache_key_original, cache_key_new)
-
- def test_query_context_time_range_endpoints(self):
- query_context = QueryContext(**self._get_query_context())
- query_object = query_context.queries[0]
- extras = query_object.to_dict()["extras"]
- self.assertTrue("time_range_endpoints" in extras)
-
- self.assertEquals(
- extras["time_range_endpoints"],
- (utils.TimeRangeEndpoint.INCLUSIVE,
utils.TimeRangeEndpoint.EXCLUSIVE),
- )
-
def test_get_superset_tables_not_allowed(self):
example_db = utils.get_example_database()
schema_name = self.default_schema_backend_map[example_db.backend]
@@ -254,20 +157,6 @@ class CoreTests(SupersetTestCase):
rv = self.client.get(uri)
self.assertEqual(rv.status_code, 404)
- def test_api_v1_query_endpoint(self):
- self.login(username="admin")
- qc_dict = self._get_query_context()
- data = json.dumps(qc_dict)
- resp = json.loads(self.get_resp("/api/v1/query/", {"query_context":
data}))
- self.assertEqual(resp[0]["rowcount"], 100)
-
- def test_api_v1_query_endpoint_with_post_processing(self):
- self.login(username="admin")
- qc_dict = self._get_query_context_with_post_processing()
- data = json.dumps(qc_dict)
- resp = json.loads(self.get_resp("/api/v1/query/", {"query_context":
data}))
- self.assertEqual(resp[0]["rowcount"], 6)
-
def test_old_slice_json_endpoint(self):
self.login(username="admin")
slc = self.get_slice("Girls", db.session)
diff --git a/tests/dict_import_export_tests.py
b/tests/dict_import_export_tests.py
index 404709c..e857479 100644
--- a/tests/dict_import_export_tests.py
+++ b/tests/dict_import_export_tests.py
@@ -165,7 +165,7 @@ class DictImportExportTests(SupersetTestCase):
new_table = SqlaTable.import_from_dict(db.session, dict_table)
db.session.commit()
imported_id = new_table.id
- imported = self.get_table(imported_id)
+ imported = self.get_table_by_id(imported_id)
self.assert_table_equals(table, imported)
self.yaml_compare(table.export_to_dict(), imported.export_to_dict())
@@ -178,7 +178,7 @@ class DictImportExportTests(SupersetTestCase):
)
imported_table = SqlaTable.import_from_dict(db.session, dict_table)
db.session.commit()
- imported = self.get_table(imported_table.id)
+ imported = self.get_table_by_id(imported_table.id)
self.assert_table_equals(table, imported)
self.assertEqual(
{DBREF: ID_PREFIX + 2, "database_name": "main"},
json.loads(imported.params)
@@ -194,7 +194,7 @@ class DictImportExportTests(SupersetTestCase):
)
imported_table = SqlaTable.import_from_dict(db.session, dict_table)
db.session.commit()
- imported = self.get_table(imported_table.id)
+ imported = self.get_table_by_id(imported_table.id)
self.assert_table_equals(table, imported)
self.yaml_compare(table.export_to_dict(), imported.export_to_dict())
@@ -213,7 +213,7 @@ class DictImportExportTests(SupersetTestCase):
imported_over_table = SqlaTable.import_from_dict(db.session,
dict_table_over)
db.session.commit()
- imported_over = self.get_table(imported_over_table.id)
+ imported_over = self.get_table_by_id(imported_over_table.id)
self.assertEqual(imported_table.id, imported_over.id)
expected_table, _ = self.create_table(
"table_override",
@@ -243,7 +243,7 @@ class DictImportExportTests(SupersetTestCase):
)
db.session.commit()
- imported_over = self.get_table(imported_over_table.id)
+ imported_over = self.get_table_by_id(imported_over_table.id)
self.assertEqual(imported_table.id, imported_over.id)
expected_table, _ = self.create_table(
"table_override",
@@ -274,7 +274,7 @@ class DictImportExportTests(SupersetTestCase):
imported_copy_table = SqlaTable.import_from_dict(db.session,
dict_copy_table)
db.session.commit()
self.assertEqual(imported_table.id, imported_copy_table.id)
- self.assert_table_equals(copy_table, self.get_table(imported_table.id))
+ self.assert_table_equals(copy_table,
self.get_table_by_id(imported_table.id))
self.yaml_compare(
imported_copy_table.export_to_dict(),
imported_table.export_to_dict()
)
diff --git a/tests/fixtures/query_context.py b/tests/fixtures/query_context.py
new file mode 100644
index 0000000..e886fda
--- /dev/null
+++ b/tests/fixtures/query_context.py
@@ -0,0 +1,103 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import copy
+from typing import Any, Dict, List
+
+# QueryObject payload fragments keyed by fixture (datasource) name.
+# `_get_query_object` returns a deep copy of an entry, so callers may mutate
+# the result without affecting this module-level template.
+QUERY_OBJECTS = {
+    "birth_names": {
+        "extras": {"where": "", "time_range_endpoints": ["INCLUSIVE", "EXCLUSIVE"],},
+        "granularity": "ds",
+        "groupby": ["name"],
+        "is_timeseries": False,
+        "metrics": [{"label": "sum__num"}],
+        "order_desc": True,
+        "orderby": [],
+        "row_limit": 100,
+        "time_range": "100 years ago : now",
+        "timeseries_limit": 0,
+        "timeseries_limit_metric": None,
+        "filters": [{"col": "gender", "op": "==", "val": "boy"}],
+        # Deprecated top-level fields (DEPRECATED_EXTRAS_FIELDS in
+        # superset/common/query_object.py). QueryObject only copies non-empty
+        # values into `extras`, so these empty values are effectively ignored.
+        "having": "",
+        "having_filters": [],
+        "where": "",
+    }
+}
+
+# Post-processing operation lists keyed by the same fixture names as
+# QUERY_OBJECTS; attached under `post_processing` when requested.
+POSTPROCESSING_OPERATIONS = {
+    "birth_names": [
+        {
+            "operation": "aggregate",
+            "options": {
+                # aggregate `sum__num` per gender: 25th percentile and median
+                "groupby": ["gender"],
+                "aggregates": {
+                    "q1": {
+                        "operator": "percentile",
+                        "column": "sum__num",
+                        "options": {"q": 25},
+                    },
+                    "median": {"operator": "median", "column": "sum__num",},
+                },
+            },
+        },
+        {"operation": "sort", "options": {"columns": {"q1": False, "gender": True},},},
+    ]
+}
+
+
+def _get_query_object(
+    datasource_name: str, add_postprocessing_operations: bool
+) -> Dict[str, Any]:
+    """
+    Return a fresh copy of the QueryObject fixture for a datasource.
+
+    :param datasource_name: key into ``QUERY_OBJECTS``
+    :param add_postprocessing_operations: when true, also attach the
+        datasource's post-processing fixture under ``post_processing``
+    :return: a deep-copied, freely mutable query object dict
+    :raises Exception: if no QueryObject fixture exists for the datasource
+    """
+    # deepcopy(None) is None, so a missing key is detected after the copy
+    query_object = copy.deepcopy(QUERY_OBJECTS.get(datasource_name))
+    if query_object is None:
+        raise Exception(
+            f"QueryObject fixture not defined for datasource: {datasource_name}"
+        )
+    if not add_postprocessing_operations:
+        return query_object
+    query_object["post_processing"] = _get_postprocessing_operation(datasource_name)
+    return query_object
+
+
+def _get_postprocessing_operation(datasource_name: str) -> List[Dict[str, Any]]:
+    """
+    Return a fresh copy of the post-processing fixture for a datasource.
+
+    :param datasource_name: key into ``POSTPROCESSING_OPERATIONS``
+    :return: a deep-copied list of post-processing operation dicts
+    :raises Exception: if no post-processing fixture exists for the datasource
+    """
+    # Check the same mapping that is read below, so a missing post-processing
+    # fixture produces the descriptive error rather than a bare KeyError.
+    if datasource_name not in POSTPROCESSING_OPERATIONS:
+        raise Exception(
+            f"Post-processing fixture not defined for datasource: {datasource_name}"
+        )
+    return copy.deepcopy(POSTPROCESSING_OPERATIONS[datasource_name])
+
+
+def get_query_context(
+    datasource_name: str = "birth_names",
+    datasource_id: int = 0,
+    datasource_type: str = "table",
+    add_postprocessing_operations: bool = False,
+) -> Dict[str, Any]:
+    """
+    Build a request payload for the `api/v1/chart/data` endpoint.
+
+    The payload holds a single query object taken from the fixture named
+    `datasource_name`. The default `birth_names` fixture groups by `name`
+    and filters on `gender == "boy"`.
+
+    :param datasource_name: fixture key; different datasources need different
+        QueryObject parameters.
+    :param datasource_id: id placed in the payload's `datasource` section.
+    :param datasource_type: type placed in the payload's `datasource` section.
+    :param add_postprocessing_operations: attach the fixture's post-processing
+        operations to the query object.
+    :return: request payload with `datasource` and `queries` keys.
+    """
+    datasource = {"id": datasource_id, "type": datasource_type}
+    query_object = _get_query_object(datasource_name, add_postprocessing_operations)
+    return {"datasource": datasource, "queries": [query_object]}
diff --git a/tests/import_export_tests.py b/tests/import_export_tests.py
index 458ab04..b47b43b 100644
--- a/tests/import_export_tests.py
+++ b/tests/import_export_tests.py
@@ -558,7 +558,7 @@ class ImportExportTests(SupersetTestCase):
def test_import_table_no_metadata(self):
table = self.create_table("pure_table", id=10001)
imported_id = SqlaTable.import_obj(table, import_time=1989)
- imported = self.get_table(imported_id)
+ imported = self.get_table_by_id(imported_id)
self.assert_table_equals(table, imported)
def test_import_table_1_col_1_met(self):
@@ -566,7 +566,7 @@ class ImportExportTests(SupersetTestCase):
"table_1_col_1_met", id=10002, cols_names=["col1"],
metric_names=["metric1"]
)
imported_id = SqlaTable.import_obj(table, import_time=1990)
- imported = self.get_table(imported_id)
+ imported = self.get_table_by_id(imported_id)
self.assert_table_equals(table, imported)
self.assertEqual(
{"remote_id": 10002, "import_time": 1990, "database_name":
"examples"},
@@ -582,7 +582,7 @@ class ImportExportTests(SupersetTestCase):
)
imported_id = SqlaTable.import_obj(table, import_time=1991)
- imported = self.get_table(imported_id)
+ imported = self.get_table_by_id(imported_id)
self.assert_table_equals(table, imported)
def test_import_table_override(self):
@@ -599,7 +599,7 @@ class ImportExportTests(SupersetTestCase):
)
imported_over_id = SqlaTable.import_obj(table_over, import_time=1992)
- imported_over = self.get_table(imported_over_id)
+ imported_over = self.get_table_by_id(imported_over_id)
self.assertEqual(imported_id, imported_over.id)
expected_table = self.create_table(
"table_override",
@@ -627,7 +627,7 @@ class ImportExportTests(SupersetTestCase):
imported_id_copy = SqlaTable.import_obj(copy_table, import_time=1994)
self.assertEqual(imported_id, imported_id_copy)
- self.assert_table_equals(copy_table, self.get_table(imported_id))
+ self.assert_table_equals(copy_table, self.get_table_by_id(imported_id))
def test_import_druid_no_metadata(self):
datasource = self.create_druid_datasource("pure_druid", id=10001)
diff --git a/tests/query_context_tests.py b/tests/query_context_tests.py
new file mode 100644
index 0000000..2d378ba
--- /dev/null
+++ b/tests/query_context_tests.py
@@ -0,0 +1,94 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from typing import Any, Dict, List, Optional
+
+from superset import db
+from superset.common.query_context import QueryContext
+from superset.connectors.connector_registry import ConnectorRegistry
+from superset.utils.core import TimeRangeEndpoint
+from tests.base_tests import SupersetTestCase
+from tests.fixtures.query_context import get_query_context
+from tests.test_app import app
+
+
+class QueryContextTests(SupersetTestCase):
+    """Tests for building QueryContext objects from chart data request payloads."""
+
+    def test_cache_key_changes_when_datasource_is_updated(self):
+        """
+        Ensure that the cache key of an otherwise identical query changes after
+        the underlying datasource has been updated and committed, even when the
+        update is reverted so its attributes end up unchanged.
+        """
+        self.login(username="admin")
+        table_name = "birth_names"
+        table = self.get_table_by_name(table_name)
+        payload = get_query_context(table.name, table.id, table.type)
+
+        # construct baseline cache_key
+        query_context = QueryContext(**payload)
+        query_object = query_context.queries[0]
+        cache_key_original = query_context.cache_key(query_object)
+
+        # make temporary change and revert it to refresh the changed_on property
+        datasource = ConnectorRegistry.get_datasource(
+            datasource_type=payload["datasource"]["type"],
+            datasource_id=payload["datasource"]["id"],
+            session=db.session,
+        )
+        description_original = datasource.description
+        datasource.description = "temporary description"
+        db.session.commit()
+        datasource.description = description_original
+        db.session.commit()
+
+        # create new QueryContext with unchanged attributes and extract new cache_key
+        query_context = QueryContext(**payload)
+        query_object = query_context.queries[0]
+        cache_key_new = query_context.cache_key(query_object)
+
+        # the new cache_key should be different due to updated datasource
+        self.assertNotEqual(cache_key_original, cache_key_new)
+
+    def test_query_context_time_range_endpoints(self):
+        """
+        Ensure that time_range_endpoints are populated automatically when missing
+        from the payload
+        """
+        self.login(username="admin")
+        table_name = "birth_names"
+        table = self.get_table_by_name(table_name)
+        payload = get_query_context(table.name, table.id, table.type)
+        del payload["queries"][0]["extras"]["time_range_endpoints"]
+        query_context = QueryContext(**payload)
+        query_object = query_context.queries[0]
+        extras = query_object.to_dict()["extras"]
+        self.assertIn("time_range_endpoints", extras)
+
+        self.assertEqual(
+            extras["time_range_endpoints"],
+            (TimeRangeEndpoint.INCLUSIVE, TimeRangeEndpoint.EXCLUSIVE),
+        )
+
+    def test_convert_deprecated_fields(self):
+        """
+        Ensure that deprecated fields are converted correctly: `granularity_sqla`
+        must populate `granularity`, and `having_filters` must end up in
+        `extras["having_druid"]`.
+        """
+        self.login(username="admin")
+        table_name = "birth_names"
+        table = self.get_table_by_name(table_name)
+        payload = get_query_context(table.name, table.id, table.type)
+        payload["queries"][0]["granularity_sqla"] = "timecol"
+        payload["queries"][0]["having_filters"] = {"col": "a", "op": "==", "val": "b"}
+        query_context = QueryContext(**payload)
+        self.assertEqual(len(query_context.queries), 1)
+        query_object = query_context.queries[0]
+        self.assertEqual(query_object.granularity, "timecol")
+        self.assertIn("having_druid", query_object.extras)