This is an automated email from the ASF dual-hosted git repository.

villebro pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-superset.git
The following commit(s) were added to refs/heads/master by this push:
     new a6cedaa  chore: Improve chart data API + schemas + tests (#9599)
a6cedaa is described below

commit a6cedaaa879348aca49a520793bb20e63d152a1f
Author: Ville Brofeldt <33317356+ville...@users.noreply.github.com>
AuthorDate: Thu Apr 23 14:30:48 2020 +0300

    chore: Improve chart data API + schemas + tests (#9599)

    * Make all fields optional in QueryObject and fix having_druid schema

    * fix: datasource type sql to table

    * lint

    * Add missing fields

    * Refactor tests

    * Linting

    * Refactor query context fixtures

    * Add typing to test func
---
 superset/charts/schemas.py         | 208 ++++++++++++++++++++++---------------
 superset/common/query_object.py    |  61 ++++++++---
 superset/connectors/base/models.py |   4 +-
 tests/access_tests.py              |   4 +-
 tests/base_tests.py                |  21 ++--
 tests/charts/api_tests.py          |  63 +++++------
 tests/core_tests.py                | 111 --------------------
 tests/dict_import_export_tests.py  |  12 +--
 tests/fixtures/query_context.py    | 103 ++++++++++++++++++
 tests/import_export_tests.py       |  10 +-
 tests/query_context_tests.py       |  94 +++++++++++++++++
 11 files changed, 421 insertions(+), 270 deletions(-)
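For orientation before the diff: after this change, a `POST api/v1/chart/data`
payload is expected to look roughly as follows. This is an illustrative sketch
only, not part of the commit; the values mirror the `birth_names` test fixture
added below, and the datasource id (1) is a placeholder.

    payload = {
        "datasource": {"id": 1, "type": "table"},  # "sql" is no longer a valid type
        "queries": [
            {
                "granularity": "ds",
                "groupby": ["name"],
                "metrics": [{"label": "sum__num"}],
                "filters": [{"col": "gender", "op": "==", "val": "boy"}],
                # extras is now a validated nested object, not a free-form dict
                "extras": {
                    "where": "",
                    "time_range_endpoints": ["INCLUSIVE", "EXCLUSIVE"],
                },
                "row_limit": 100,
            }
        ],
    }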
diff --git a/superset/charts/schemas.py b/superset/charts/schemas.py
index 2743732..d7438aa 100644
--- a/superset/charts/schemas.py
+++ b/superset/charts/schemas.py
@@ -16,7 +16,7 @@
 # under the License.
 from typing import Any, Dict, Union
 
-from marshmallow import fields, post_load, Schema, ValidationError
+from marshmallow import fields, post_load, Schema, validate, ValidationError
 from marshmallow.validate import Length
 
 from superset.common.query_context import QueryContext
@@ -77,13 +77,15 @@ class ChartDataAdhocMetricSchema(Schema):
     expressionType = fields.String(
         description="Simple or SQL metric",
         required=True,
-        enum=["SIMPLE", "SQL"],
+        validate=validate.OneOf(choices=("SIMPLE", "SQL")),
         example="SQL",
     )
     aggregate = fields.String(
         description="Aggregation operator. Only required for simple expression types.",
         required=False,
-        enum=["AVG", "COUNT", "COUNT_DISTINCT", "MAX", "MIN", "SUM"],
+        validate=validate.OneOf(
+            choices=("AVG", "COUNT", "COUNT_DISTINCT", "MAX", "MIN", "SUM")
+        ),
     )
     column = fields.Nested(ChartDataColumnSchema)
     sqlExpression = fields.String(
@@ -178,28 +180,30 @@ class ChartDataRollingOptionsSchema(ChartDataPostProcessingOperationOptionsSchem
     )
     rolling_type = fields.String(
         description="Type of rolling window. Any numpy function will work.",
-        enum=[
-            "average",
-            "argmin",
-            "argmax",
-            "cumsum",
-            "cumprod",
-            "max",
-            "mean",
-            "median",
-            "nansum",
-            "nanmin",
-            "nanmax",
-            "nanmean",
-            "nanmedian",
-            "min",
-            "percentile",
-            "prod",
-            "product",
-            "std",
-            "sum",
-            "var",
-        ],
+        validate=validate.OneOf(
+            choices=(
+                "average",
+                "argmin",
+                "argmax",
+                "cumsum",
+                "cumprod",
+                "max",
+                "mean",
+                "median",
+                "nansum",
+                "nanmin",
+                "nanmax",
+                "nanmean",
+                "nanmedian",
+                "min",
+                "percentile",
+                "prod",
+                "product",
+                "std",
+                "sum",
+                "var",
+            )
+        ),
         required=True,
         example="percentile",
     )
@@ -225,23 +229,25 @@ class ChartDataRollingOptionsSchema(ChartDataPostProcessingOperationOptionsSchem
         "additional parameters to `rolling_type_options`. For instance, "
         "to use `gaussian`, the parameter `std` needs to be provided.",
         required=False,
-        enum=[
-            "boxcar",
-            "triang",
-            "blackman",
-            "hamming",
-            "bartlett",
-            "parzen",
-            "bohman",
-            "blackmanharris",
-            "nuttall",
-            "barthann",
-            "kaiser",
-            "gaussian",
-            "general_gaussian",
-            "slepian",
-            "exponential",
-        ],
+        validate=validate.OneOf(
+            choices=(
+                "boxcar",
+                "triang",
+                "blackman",
+                "hamming",
+                "bartlett",
+                "parzen",
+                "bohman",
+                "blackmanharris",
+                "nuttall",
+                "barthann",
+                "kaiser",
+                "gaussian",
+                "general_gaussian",
+                "slepian",
+                "exponential",
+            )
+        ),
     )
     min_periods = fields.Integer(
         description="The minimum amount of periods required for a row to be included "
@@ -333,7 +339,9 @@ class ChartDataPostProcessingOperationSchema(Schema):
     operation = fields.String(
         description="Post processing operation type",
         required=True,
-        enum=["aggregate", "pivot", "rolling", "select", "sort"],
+        validate=validate.OneOf(
+            choices=("aggregate", "pivot", "rolling", "select", "sort")
+        ),
         example="aggregate",
     )
     options = fields.Nested(
@@ -362,7 +370,9 @@ class ChartDataFilterSchema(Schema):
     )
     op = fields.String(  # pylint: disable=invalid-name
         description="The comparison operator.",
-        enum=[filter_op.value for filter_op in utils.FilterOperator],
+        validate=validate.OneOf(
+            choices=[filter_op.value for filter_op in utils.FilterOperator]
+        ),
         required=True,
         example="IN",
     )
@@ -376,21 +386,23 @@ class ChartDataFilterSchema(Schema):
 
 class ChartDataExtrasSchema(Schema):
     time_range_endpoints = fields.List(
-        fields.String(enum=["INCLUSIVE", "EXCLUSIVE"]),
-        description="A list with two values, stating if start/end should be "
-        "inclusive/exclusive.",
-        required=False,
+        fields.String(
+            validate=validate.OneOf(choices=("INCLUSIVE", "EXCLUSIVE")),
+            description="A list with two values, stating if start/end should be "
+            "inclusive/exclusive.",
+            required=False,
+        )
     )
     relative_start = fields.String(
         description="Start time for relative time deltas. "
         'Default: `config["DEFAULT_RELATIVE_START_TIME"]`',
-        enum=["today", "now"],
+        validate=validate.OneOf(choices=("today", "now")),
        required=False,
     )
     relative_end = fields.String(
         description="End time for relative time deltas. "
         'Default: `config["DEFAULT_RELATIVE_START_TIME"]`',
-        enum=["today", "now"],
+        validate=validate.OneOf(choices=("today", "now")),
         required=False,
     )
     where = fields.String(
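Why `enum=` had to go: marshmallow has no `enum` keyword, so those values were
stored as field metadata (useful for the generated OpenAPI spec) but never
enforced at load time; `validate.OneOf` performs actual validation. A minimal
standalone sketch (not from this commit), assuming marshmallow 3:

    from marshmallow import Schema, fields, validate, ValidationError

    class DocumentedOnly(Schema):
        # `enum` is swallowed as metadata; any string is accepted on load
        op = fields.String(enum=["IN", "NOT IN"])

    class Enforced(Schema):
        op = fields.String(validate=validate.OneOf(choices=("IN", "NOT IN")))

    DocumentedOnly().load({"op": "bogus"})  # silently accepted
    try:
        Enforced().load({"op": "bogus"})
    except ValidationError as err:
        print(err.messages)  # {'op': ['Must be one of: IN, NOT IN.']}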
@@ -402,35 +414,54 @@
         "AND operator.",
         required=False,
     )
-    having_druid = fields.String(
+    having_druid = fields.List(
+        fields.Nested(ChartDataFilterSchema),
         description="HAVING filters to be added to legacy Druid datasource queries.",
         required=False,
     )
+    time_grain_sqla = fields.String(
+        description="To what level of granularity should the temporal column be "
+        "aggregated. Supports "
+        "[ISO 8601](https://en.wikipedia.org/wiki/ISO_8601#Durations) durations.",
+        validate=validate.OneOf(
+            choices=(
+                "PT1S",
+                "PT1M",
+                "PT5M",
+                "PT10M",
+                "PT15M",
+                "PT0.5H",
+                "PT1H",
+                "P1D",
+                "P1W",
+                "P1M",
+                "P0.25Y",
+                "P1Y",
+            ),
+        ),
+        required=False,
+        example="P1D",
+    )
+    druid_time_origin = fields.String(
+        description="Starting point for time grain counting on legacy Druid "
+        "datasources. Used to change e.g. Monday/Sunday first-day-of-week.",
+        required=False,
+    )
 
 
 class ChartDataQueryObjectSchema(Schema):
     filters = fields.List(fields.Nested(ChartDataFilterSchema), required=False)
     granularity = fields.String(
-        description="To what level of granularity should the temporal column be "
-        "aggregated. Supports "
-        "[ISO 8601](https://en.wikipedia.org/wiki/ISO_8601#Durations) "
-        "durations.",
-        enum=[
-            "PT1S",
-            "PT1M",
-            "PT5M",
-            "PT10M",
-            "PT15M",
-            "PT0.5H",
-            "PT1H",
-            "P1D",
-            "P1W",
-            "P1M",
-            "P0.25Y",
-            "P1Y",
-        ],
+        description="Name of temporal column used for time filtering. For legacy Druid "
+        "datasources this defines the time grain.",
         required=False,
-        example="P1D",
+    )
+    granularity_sqla = fields.String(
+        description="Name of temporal column used for time filtering for SQL "
+        "datasources. This field is deprecated, use `granularity` "
+        "instead.",
+        required=False,
+        deprecated=True,
     )
     groupby = fields.List(
         fields.String(description="Columns by which to group the query.",),
@@ -441,6 +472,7 @@ class ChartDataQueryObjectSchema(Schema):
         "references to datasource metrics (strings), or ad-hoc metrics"
         "which are defined only within the query object. See "
         "`ChartDataAdhocMetricSchema` for the structure of ad-hoc metrics.",
+        required=False,
     )
     post_processing = fields.List(
         fields.Nested(ChartDataPostProcessingOperationSchema),
@@ -450,7 +482,8 @@ class ChartDataQueryObjectSchema(Schema):
         required=False,
     )
     time_range = fields.String(
         description="A time rage, either expressed as a colon separated string "
-        "`since : until`. Valid formats for `since` and `until` are: \n"
+        "`since : until` or human readable freeform. Valid formats for "
+        "`since` and `until` are: \n"
         "- ISO 8601\n"
         "- X days/years/hours/day/year/weeks\n"
         "- X days/years/hours/day/year/weeks ago\n"
@@ -488,7 +521,7 @@ class ChartDataQueryObjectSchema(Schema):
     order_desc = fields.Boolean(
         description="Reverse order. Default: `false`", required=False
     )
-    extras = fields.Dict(description=" Default: `{}`", required=False)
+    extras = fields.Nested(ChartDataExtrasSchema, required=False)
     columns = fields.List(fields.String(), description="", required=False,)
     orderby = fields.List(
         fields.List(fields.Raw()),
" - "This field is deprecated, and should be passed to `extras` " + "This field is deprecated and should be passed to `extras` " "as `filters_druid`.", required=False, deprecated=True, @@ -523,7 +556,10 @@ class ChartDataQueryObjectSchema(Schema): class ChartDataDatasourceSchema(Schema): description = "Chart datasource" id = fields.Integer(description="Datasource id", required=True,) - type = fields.String(description="Datasource type", enum=["druid", "sql"]) + type = fields.String( + description="Datasource type", + validate=validate.OneOf(choices=("druid", "table")), + ) class ChartDataQueryContextSchema(Schema): @@ -561,15 +597,17 @@ class ChartDataResponseResult(Schema): ) status = fields.String( description="Status of the query", - enum=[ - "stopped", - "failed", - "pending", - "running", - "scheduled", - "success", - "timed_out", - ], + validate=validate.OneOf( + choices=( + "stopped", + "failed", + "pending", + "running", + "scheduled", + "success", + "timed_out", + ) + ), allow_none=False, ) stacktrace = fields.String( diff --git a/superset/common/query_object.py b/superset/common/query_object.py index 0a83ef7..ea1f3f5 100644 --- a/superset/common/query_object.py +++ b/superset/common/query_object.py @@ -35,15 +35,19 @@ logger = logging.getLogger(__name__) # https://github.com/python/mypy/issues/5288 -class DeprecatedExtrasField(NamedTuple): - name: str - extras_name: str +class DeprecatedField(NamedTuple): + old_name: str + new_name: str +DEPRECATED_FIELDS = ( + DeprecatedField(old_name="granularity_sqla", new_name="granularity"), +) + DEPRECATED_EXTRAS_FIELDS = ( - DeprecatedExtrasField(name="where", extras_name="where"), - DeprecatedExtrasField(name="having", extras_name="having"), - DeprecatedExtrasField(name="having_filters", extras_name="having_druid"), + DeprecatedField(old_name="where", new_name="where"), + DeprecatedField(old_name="having", new_name="having"), + DeprecatedField(old_name="having_filters", new_name="having_druid"), ) @@ -53,7 +57,7 @@ class QueryObject: and druid. The query objects are constructed on the client. """ - granularity: str + granularity: Optional[str] from_dttm: datetime to_dttm: datetime is_timeseries: bool @@ -72,8 +76,8 @@ class QueryObject: def __init__( self, - granularity: str, - metrics: List[Union[Dict[str, Any], str]], + granularity: Optional[str] = None, + metrics: Optional[List[Union[Dict[str, Any], str]]] = None, groupby: Optional[List[str]] = None, filters: Optional[List[Dict[str, Any]]] = None, time_range: Optional[str] = None, @@ -89,6 +93,7 @@ class QueryObject: post_processing: Optional[List[Dict[str, Any]]] = None, **kwargs: Any, ): + metrics = metrics or [] extras = extras or {} is_sip_38 = is_feature_enabled("SIP_38_VIZ_REARCHITECTURE") self.granularity = granularity @@ -131,22 +136,44 @@ class QueryObject: if is_sip_38 and groupby: self.columns += groupby logger.warning( - f"The field groupby is deprecated. Viz plugins should " - f"pass all selectables via the columns field" + f"The field `groupby` is deprecated. Viz plugins should " + f"pass all selectables via the `columns` field" ) self.orderby = orderby or [] - # move deprecated fields to extras + # rename deprecated fields + for field in DEPRECATED_FIELDS: + if field.old_name in kwargs: + logger.warning( + f"The field `{field.old_name}` is deprecated, please use " + f"`{field.new_name}` instead." 
diff --git a/superset/common/query_object.py b/superset/common/query_object.py
index 0a83ef7..ea1f3f5 100644
--- a/superset/common/query_object.py
+++ b/superset/common/query_object.py
@@ -35,15 +35,19 @@
 logger = logging.getLogger(__name__)
 
 
 # https://github.com/python/mypy/issues/5288
-class DeprecatedExtrasField(NamedTuple):
-    name: str
-    extras_name: str
+class DeprecatedField(NamedTuple):
+    old_name: str
+    new_name: str
+
+
+DEPRECATED_FIELDS = (
+    DeprecatedField(old_name="granularity_sqla", new_name="granularity"),
+)
 
 
 DEPRECATED_EXTRAS_FIELDS = (
-    DeprecatedExtrasField(name="where", extras_name="where"),
-    DeprecatedExtrasField(name="having", extras_name="having"),
-    DeprecatedExtrasField(name="having_filters", extras_name="having_druid"),
+    DeprecatedField(old_name="where", new_name="where"),
+    DeprecatedField(old_name="having", new_name="having"),
+    DeprecatedField(old_name="having_filters", new_name="having_druid"),
 )
@@ -53,7 +57,7 @@ class QueryObject:
     and druid. The query objects are constructed on the client.
     """
 
-    granularity: str
+    granularity: Optional[str]
     from_dttm: datetime
     to_dttm: datetime
     is_timeseries: bool
@@ -72,8 +76,8 @@ class QueryObject:
     def __init__(
         self,
-        granularity: str,
-        metrics: List[Union[Dict[str, Any], str]],
+        granularity: Optional[str] = None,
+        metrics: Optional[List[Union[Dict[str, Any], str]]] = None,
         groupby: Optional[List[str]] = None,
         filters: Optional[List[Dict[str, Any]]] = None,
         time_range: Optional[str] = None,
@@ -89,6 +93,7 @@ class QueryObject:
         post_processing: Optional[List[Dict[str, Any]]] = None,
         **kwargs: Any,
     ):
+        metrics = metrics or []
         extras = extras or {}
         is_sip_38 = is_feature_enabled("SIP_38_VIZ_REARCHITECTURE")
         self.granularity = granularity
@@ -131,22 +136,44 @@ class QueryObject:
         if is_sip_38 and groupby:
             self.columns += groupby
             logger.warning(
-                f"The field groupby is deprecated. Viz plugins should "
-                f"pass all selectables via the columns field"
+                f"The field `groupby` is deprecated. Viz plugins should "
+                f"pass all selectables via the `columns` field"
             )
 
         self.orderby = orderby or []
 
-        # move deprecated fields to extras
+        # rename deprecated fields
+        for field in DEPRECATED_FIELDS:
+            if field.old_name in kwargs:
+                logger.warning(
+                    f"The field `{field.old_name}` is deprecated, please use "
+                    f"`{field.new_name}` instead."
+                )
+                value = kwargs[field.old_name]
+                if value:
+                    if hasattr(self, field.new_name):
+                        logger.warning(
+                            f"The field `{field.new_name}` is already populated, "
+                            f"replacing value with contents from `{field.old_name}`."
+                        )
+                    setattr(self, field.new_name, value)
+
+        # move deprecated extras fields to extras
         for field in DEPRECATED_EXTRAS_FIELDS:
-            if field.name in kwargs:
+            if field.old_name in kwargs:
                 logger.warning(
-                    f"The field `{field.name} is deprecated, and should be "
-                    f"passed to `extras` via the `{field.extras_name}` property"
+                    f"The field `{field.old_name}` is deprecated and should be "
+                    f"passed to `extras` via the `{field.new_name}` property."
                 )
-                value = kwargs[field.name]
+                value = kwargs[field.old_name]
                 if value:
-                    self.extras[field.extras_name] = value
+                    if hasattr(self.extras, field.new_name):
+                        logger.warning(
+                            f"The field `{field.new_name}` is already populated in "
+                            f"`extras`, replacing value with contents "
+                            f"from `{field.old_name}`."
+                        )
+                    self.extras[field.new_name] = value
 
     def to_dict(self) -> Dict[str, Any]:
         query_object_dict = {
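The effect of the two loops above, sketched with hypothetical values (a legacy
client still sending the deprecated spellings; this snippet is illustrative and
assumes it runs inside the Flask app/test context that QueryObject needs):

    qobj = QueryObject(
        granularity_sqla="ds",  # deprecated; renamed to `granularity`
        having_filters=[{"col": "a", "op": "==", "val": "b"}],  # moved to extras
    )
    assert qobj.granularity == "ds"
    assert qobj.extras["having_druid"] == [{"col": "a", "op": "==", "val": "b"}]
    # each conversion is also logged as a deprecation warning

diff --git a/superset/connectors/base/models.py b/superset/connectors/base/models.py
index 2b6e0d2..2b051d3 100644
--- a/superset/connectors/base/models.py
+++ b/superset/connectors/base/models.py
@@ -366,7 +366,9 @@ class BaseDatasource(
     def default_query(qry) -> Query:
         return qry
 
-    def get_column(self, column_name: str) -> Optional["BaseColumn"]:
+    def get_column(self, column_name: Optional[str]) -> Optional["BaseColumn"]:
+        if not column_name:
+            return None
         for col in self.columns:
             if col.column_name == column_name:
                 return col
diff --git a/tests/access_tests.py b/tests/access_tests.py
index 58affb6..f5ac8e0 100644
--- a/tests/access_tests.py
+++ b/tests/access_tests.py
@@ -385,7 +385,7 @@ class RequestAccessTests(SupersetTestCase):
         )
         self.assertEqual(
             "[Superset] Access to the datasource {} was granted".format(
-                self.get_table(ds_1_id).full_name
+                self.get_table_by_id(ds_1_id).full_name
             ),
             call_args[2]["Subject"],
         )
@@ -426,7 +426,7 @@ class RequestAccessTests(SupersetTestCase):
         )
         self.assertEqual(
             "[Superset] Access to the datasource {} was granted".format(
-                self.get_table(ds_2_id).full_name
+                self.get_table_by_id(ds_2_id).full_name
             ),
             call_args[2]["Subject"],
         )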
diff --git a/tests/base_tests.py b/tests/base_tests.py
index 370adf3..ebe1d9a 100644
--- a/tests/base_tests.py
+++ b/tests/base_tests.py
@@ -18,16 +18,18 @@
 """Unit tests for Superset"""
 import imp
 import json
-from typing import Union, Dict
+from typing import Dict, Union
 from unittest.mock import Mock, patch
 
 import pandas as pd
 from flask import Response
 from flask_appbuilder.security.sqla import models as ab_models
 from flask_testing import TestCase
+from sqlalchemy.orm import Session
 
 from tests.test_app import app  # isort:skip
 from superset import db, security_manager
+from superset.connectors.base.models import BaseDatasource
 from superset.connectors.druid.models import DruidCluster, DruidDatasource
 from superset.connectors.sqla.models import SqlaTable
 from superset.models import core as models
@@ -103,7 +105,8 @@ class SupersetTestCase(TestCase):
         session.add(druid_datasource2)
         session.commit()
 
-    def get_table(self, table_id):
+    @staticmethod
+    def get_table_by_id(table_id: int) -> SqlaTable:
         return db.session.query(SqlaTable).filter_by(id=table_id).one()
 
     @staticmethod
@@ -127,21 +130,25 @@ class SupersetTestCase(TestCase):
         resp = self.get_resp("/login/", data=dict(username=username, password=password))
         self.assertNotIn("User confirmation needed", resp)
 
-    def get_slice(self, slice_name, session):
+    def get_slice(self, slice_name: str, session: Session) -> Slice:
         slc = session.query(Slice).filter_by(slice_name=slice_name).one()
         session.expunge_all()
         return slc
 
-    def get_table_by_name(self, name):
+    @staticmethod
+    def get_table_by_name(name: str) -> SqlaTable:
         return db.session.query(SqlaTable).filter_by(table_name=name).one()
 
-    def get_database_by_id(self, db_id):
+    @staticmethod
+    def get_database_by_id(db_id: int) -> Database:
         return db.session.query(Database).filter_by(id=db_id).one()
 
-    def get_druid_ds_by_name(self, name):
+    @staticmethod
+    def get_druid_ds_by_name(name: str) -> DruidDatasource:
         return db.session.query(DruidDatasource).filter_by(datasource_name=name).first()
 
-    def get_datasource_mock(self):
+    @staticmethod
+    def get_datasource_mock() -> BaseDatasource:
         datasource = Mock()
         results = Mock()
         results.query = Mock()
diff --git a/tests/charts/api_tests.py b/tests/charts/api_tests.py
index 0fb8471..0a7ab58 100644
--- a/tests/charts/api_tests.py
+++ b/tests/charts/api_tests.py
@@ -16,7 +16,7 @@
 # under the License.
 """Unit tests for Superset"""
 import json
-from typing import Any, Dict, List, Optional
+from typing import List, Optional
 
 import prison
 from sqlalchemy.sql import func
@@ -28,6 +28,7 @@ from superset.models.dashboard import Dashboard
 from superset.models.slice import Slice
 from tests.base_api_tests import ApiOwnersTestCaseMixin
 from tests.base_tests import SupersetTestCase
+from tests.fixtures.query_context import get_query_context
 
 
 class ChartApiTests(SupersetTestCase, ApiOwnersTestCaseMixin):
@@ -69,32 +70,6 @@ class ChartApiTests(SupersetTestCase, ApiOwnersTestCaseMixin):
         db.session.commit()
         return slice
 
-    def _get_query_context(self) -> Dict[str, Any]:
-        self.login(username="admin")
-        slc = self.get_slice("Girl Name Cloud", db.session)
-        return {
-            "datasource": {"id": slc.datasource_id, "type": slc.datasource_type},
-            "queries": [
-                {
-                    "extras": {"where": ""},
-                    "granularity": "ds",
-                    "groupby": ["name"],
-                    "is_timeseries": False,
-                    "metrics": [{"label": "sum__num"}],
-                    "order_desc": True,
-                    "orderby": [],
-                    "row_limit": 100,
-                    "time_range": "100 years ago : now",
-                    "timeseries_limit": 0,
-                    "timeseries_limit_metric": None,
-                    "filters": [{"col": "gender", "op": "==", "val": "boy"}],
-                    "having": "",
-                    "having_filters": [],
-                    "where": "",
-                }
-            ],
-        }
-
     def test_delete_chart(self):
         """
         Chart API: Test delete
@@ -662,22 +637,37 @@ class ChartApiTests(SupersetTestCase, ApiOwnersTestCaseMixin):
         Query API: Test chart data query
         """
         self.login(username="admin")
-        query_context = self._get_query_context()
+        table = self.get_table_by_name("birth_names")
+        payload = get_query_context(table.name, table.id, table.type)
         uri = "api/v1/chart/data"
-        rv = self.post_assert_metric(uri, query_context, "data")
+        rv = self.post_assert_metric(uri, payload, "data")
         self.assertEqual(rv.status_code, 200)
         data = json.loads(rv.data.decode("utf-8"))
         self.assertEqual(data["result"][0]["rowcount"], 100)
 
-    def test_invalid_chart_data(self):
+    def test_chart_data_with_invalid_datasource(self):
+        """Query API: Test chart data query with invalid schema
         """
-        Query API: Test chart data query with invalid schema
+        self.login(username="admin")
+        table = self.get_table_by_name("birth_names")
+        payload = get_query_context(table.name, table.id, table.type)
+        payload["datasource"] = "abc"
+        uri = "api/v1/chart/data"
+        rv = self.post_assert_metric(uri, payload, "data")
+        self.assertEqual(rv.status_code, 400)
+
+    def test_chart_data_with_invalid_enum_value(self):
+        """Query API: Test chart data query with invalid enum value
         """
         self.login(username="admin")
-        query_context = self._get_query_context()
-        query_context["datasource"] = "abc"
+        table = self.get_table_by_name("birth_names")
+        payload = get_query_context(table.name, table.id, table.type)
+        payload["queries"][0]["extras"]["time_range_endpoints"] = [
+            "abc",
+            "EXCLUSIVE",
+        ]
         uri = "api/v1/chart/data"
-        rv = self.client.post(uri, json=query_context)
+        rv = self.client.post(uri, json=payload)
         self.assertEqual(rv.status_code, 400)
 
     def test_query_exec_not_allowed(self):
@@ -685,9 +675,10 @@ class ChartApiTests(SupersetTestCase, ApiOwnersTestCaseMixin):
         Query API: Test chart data query not allowed
         """
         self.login(username="gamma")
-        query_context = self._get_query_context()
+        table = self.get_table_by_name("birth_names")
+        payload = get_query_context(table.name, table.id, table.type)
         uri = "api/v1/chart/data"
-        rv = self.post_assert_metric(uri, query_context, "data")
+        rv = self.post_assert_metric(uri, payload, "data")
         self.assertEqual(rv.status_code, 401)
 
     def test_datasources(self):
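With the stricter schemas, the invalid-enum test above gets its 400 from
marshmallow validation rather than from deeper query logic. Assuming the
endpoint returns marshmallow's error messages via FAB's standard 400 wrapper,
the response body should look something like this (shape approximate, for
illustration only):

    {
        "message": {
            "queries": {
                "0": {
                    "extras": {
                        "time_range_endpoints": {
                            "0": ["Must be one of: INCLUSIVE, EXCLUSIVE."]
                        }
                    }
                }
            }
        }
    }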
diff --git a/tests/core_tests.py b/tests/core_tests.py
index 04cc587..ecbfa05 100644
--- a/tests/core_tests.py
+++ b/tests/core_tests.py
@@ -28,7 +28,6 @@ import pytz
 import random
 import re
 import string
-from typing import Any, Dict
 import unittest
 from unittest import mock, skipUnless
 
@@ -44,8 +43,6 @@ from superset import (
     sql_lab,
     is_feature_enabled,
 )
-from superset.common.query_context import QueryContext
-from superset.connectors.connector_registry import ConnectorRegistry
 from superset.connectors.sqla.models import SqlaTable
 from superset.db_engine_specs.base import BaseEngineSpec
 from superset.db_engine_specs.mssql import MssqlEngineSpec
@@ -111,61 +108,6 @@ class CoreTests(SupersetTestCase):
         resp = self.client.get("/superset/slice/-1/")
         assert resp.status_code == 404
 
-    def _get_query_context(self) -> Dict[str, Any]:
-        self.login(username="admin")
-        slc = self.get_slice("Girl Name Cloud", db.session)
-        return {
-            "datasource": {"id": slc.datasource_id, "type": slc.datasource_type},
-            "queries": [
-                {
-                    "granularity": "ds",
-                    "groupby": ["name"],
-                    "metrics": [{"label": "sum__num"}],
-                    "filters": [],
-                    "row_limit": 100,
-                }
-            ],
-        }
-
-    def _get_query_context_with_post_processing(self) -> Dict[str, Any]:
-        self.login(username="admin")
-        slc = self.get_slice("Girl Name Cloud", db.session)
-        return {
-            "datasource": {"id": slc.datasource_id, "type": slc.datasource_type},
-            "queries": [
-                {
-                    "granularity": "ds",
-                    "groupby": ["name", "state"],
-                    "metrics": [{"label": "sum__num"}],
-                    "filters": [],
-                    "row_limit": 100,
-                    "post_processing": [
-                        {
-                            "operation": "aggregate",
-                            "options": {
-                                "groupby": ["state"],
-                                "aggregates": {
-                                    "q1": {
-                                        "operator": "percentile",
-                                        "column": "sum__num",
-                                        "options": {"q": 25},
-                                    },
-                                    "median": {
-                                        "operator": "median",
-                                        "column": "sum__num",
-                                    },
-                                },
-                            },
-                        },
-                        {
-                            "operation": "sort",
-                            "options": {"columns": {"q1": False, "state": True},},
-                        },
-                    ],
-                }
-            ],
-        }
-
     def test_viz_cache_key(self):
         self.login(username="admin")
         slc = self.get_slice("Girls", db.session)
@@ -178,45 +120,6 @@ class CoreTests(SupersetTestCase):
         qobj["groupby"] = []
         self.assertNotEqual(cache_key, viz.cache_key(qobj))
 
-    def test_cache_key_changes_when_datasource_is_updated(self):
-        qc_dict = self._get_query_context()
-
-        # construct baseline cache_key
-        query_context = QueryContext(**qc_dict)
-        query_object = query_context.queries[0]
-        cache_key_original = query_context.cache_key(query_object)
-
-        # make temporary change and revert it to refresh the changed_on property
-        datasource = ConnectorRegistry.get_datasource(
-            datasource_type=qc_dict["datasource"]["type"],
-            datasource_id=qc_dict["datasource"]["id"],
-            session=db.session,
-        )
-        description_original = datasource.description
-        datasource.description = "temporary description"
-        db.session.commit()
-        datasource.description = description_original
-        db.session.commit()
-
-        # create new QueryContext with unchanged attributes and extract new cache_key
-        query_context = QueryContext(**qc_dict)
-        query_object = query_context.queries[0]
-        cache_key_new = query_context.cache_key(query_object)
-
-        # the new cache_key should be different due to updated datasource
-        self.assertNotEqual(cache_key_original, cache_key_new)
-
-    def test_query_context_time_range_endpoints(self):
-        query_context = QueryContext(**self._get_query_context())
-        query_object = query_context.queries[0]
-        extras = query_object.to_dict()["extras"]
-        self.assertTrue("time_range_endpoints" in extras)
-
-        self.assertEquals(
-            extras["time_range_endpoints"],
-            (utils.TimeRangeEndpoint.INCLUSIVE, utils.TimeRangeEndpoint.EXCLUSIVE),
-        )
-
     def test_get_superset_tables_not_allowed(self):
         example_db = utils.get_example_database()
         schema_name = self.default_schema_backend_map[example_db.backend]
@@ -254,20 +157,6 @@ class CoreTests(SupersetTestCase):
         rv = self.client.get(uri)
         self.assertEqual(rv.status_code, 404)
 
-    def test_api_v1_query_endpoint(self):
-        self.login(username="admin")
-        qc_dict = self._get_query_context()
-        data = json.dumps(qc_dict)
-        resp = json.loads(self.get_resp("/api/v1/query/", {"query_context": data}))
-        self.assertEqual(resp[0]["rowcount"], 100)
-
-    def test_api_v1_query_endpoint_with_post_processing(self):
-        self.login(username="admin")
-        qc_dict = self._get_query_context_with_post_processing()
-        data = json.dumps(qc_dict)
-        resp = json.loads(self.get_resp("/api/v1/query/", {"query_context": data}))
-        self.assertEqual(resp[0]["rowcount"], 6)
-
     def test_old_slice_json_endpoint(self):
         self.login(username="admin")
         slc = self.get_slice("Girls", db.session)
diff --git a/tests/dict_import_export_tests.py b/tests/dict_import_export_tests.py
index 404709c..e857479 100644
--- a/tests/dict_import_export_tests.py
+++ b/tests/dict_import_export_tests.py
@@ -165,7 +165,7 @@ class DictImportExportTests(SupersetTestCase):
         new_table = SqlaTable.import_from_dict(db.session, dict_table)
         db.session.commit()
         imported_id = new_table.id
-        imported = self.get_table(imported_id)
+        imported = self.get_table_by_id(imported_id)
         self.assert_table_equals(table, imported)
         self.yaml_compare(table.export_to_dict(), imported.export_to_dict())
 
@@ -178,7 +178,7 @@ class DictImportExportTests(SupersetTestCase):
         )
         imported_table = SqlaTable.import_from_dict(db.session, dict_table)
         db.session.commit()
-        imported = self.get_table(imported_table.id)
+        imported = self.get_table_by_id(imported_table.id)
         self.assert_table_equals(table, imported)
         self.assertEqual(
             {DBREF: ID_PREFIX + 2, "database_name": "main"}, json.loads(imported.params)
@@ -194,7 +194,7 @@ class DictImportExportTests(SupersetTestCase):
         )
         imported_table = SqlaTable.import_from_dict(db.session, dict_table)
         db.session.commit()
-        imported = self.get_table(imported_table.id)
+        imported = self.get_table_by_id(imported_table.id)
         self.assert_table_equals(table, imported)
         self.yaml_compare(table.export_to_dict(), imported.export_to_dict())
 
@@ -213,7 +213,7 @@ class DictImportExportTests(SupersetTestCase):
         imported_over_table = SqlaTable.import_from_dict(db.session, dict_table_over)
         db.session.commit()
 
-        imported_over = self.get_table(imported_over_table.id)
+        imported_over = self.get_table_by_id(imported_over_table.id)
         self.assertEqual(imported_table.id, imported_over.id)
         expected_table, _ = self.create_table(
             "table_override",
@@ -243,7 +243,7 @@ class DictImportExportTests(SupersetTestCase):
         )
         db.session.commit()
 
-        imported_over = self.get_table(imported_over_table.id)
+        imported_over = self.get_table_by_id(imported_over_table.id)
         self.assertEqual(imported_table.id, imported_over.id)
         expected_table, _ = self.create_table(
             "table_override",
@@ -274,7 +274,7 @@ class DictImportExportTests(SupersetTestCase):
         imported_copy_table = SqlaTable.import_from_dict(db.session, dict_copy_table)
         db.session.commit()
         self.assertEqual(imported_table.id, imported_copy_table.id)
-        self.assert_table_equals(copy_table, self.get_table(imported_table.id))
+        self.assert_table_equals(copy_table, self.get_table_by_id(imported_table.id))
         self.yaml_compare(
             imported_copy_table.export_to_dict(), imported_table.export_to_dict()
         )
diff --git a/tests/fixtures/query_context.py b/tests/fixtures/query_context.py
new file mode 100644
index 0000000..e886fda
--- /dev/null
+++ b/tests/fixtures/query_context.py
@@ -0,0 +1,103 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import copy
+from typing import Any, Dict, List
+
+QUERY_OBJECTS = {
+    "birth_names": {
+        "extras": {"where": "", "time_range_endpoints": ["INCLUSIVE", "EXCLUSIVE"],},
+        "granularity": "ds",
+        "groupby": ["name"],
+        "is_timeseries": False,
+        "metrics": [{"label": "sum__num"}],
+        "order_desc": True,
+        "orderby": [],
+        "row_limit": 100,
+        "time_range": "100 years ago : now",
+        "timeseries_limit": 0,
+        "timeseries_limit_metric": None,
+        "filters": [{"col": "gender", "op": "==", "val": "boy"}],
+        "having": "",
+        "having_filters": [],
+        "where": "",
+    }
+}
+
+POSTPROCESSING_OPERATIONS = {
+    "birth_names": [
+        {
+            "operation": "aggregate",
+            "options": {
+                "groupby": ["gender"],
+                "aggregates": {
+                    "q1": {
+                        "operator": "percentile",
+                        "column": "sum__num",
+                        "options": {"q": 25},
+                    },
+                    "median": {"operator": "median", "column": "sum__num",},
+                },
+            },
+        },
+        {"operation": "sort", "options": {"columns": {"q1": False, "gender": True},},},
+    ]
+}
+
+
+def _get_query_object(
+    datasource_name: str, add_postprocessing_operations: bool
+) -> Dict[str, Any]:
+    if datasource_name not in QUERY_OBJECTS:
+        raise Exception(
+            f"QueryObject fixture not defined for datasource: {datasource_name}"
+        )
+    query_object = copy.deepcopy(QUERY_OBJECTS[datasource_name])
+    if add_postprocessing_operations:
+        query_object["post_processing"] = _get_postprocessing_operation(datasource_name)
+    return query_object
+
+
+def _get_postprocessing_operation(datasource_name: str) -> List[Dict[str, Any]]:
+    if datasource_name not in QUERY_OBJECTS:
+        raise Exception(
+            f"Post-processing fixture not defined for datasource: {datasource_name}"
+        )
+    return copy.deepcopy(POSTPROCESSING_OPERATIONS[datasource_name])
+
+
+def get_query_context(
+    datasource_name: str = "birth_names",
+    datasource_id: int = 0,
+    datasource_type: str = "table",
+    add_postprocessing_operations: bool = False,
+) -> Dict[str, Any]:
+    """
+    Create a request payload for retrieving a QueryContext object via the
+    `api/v1/chart/data` endpoint. By default returns a payload corresponding to one
+    generated by the "Boy Name Cloud" chart in the examples.
+
+    :param datasource_name: name of datasource to query. Different datasources require
+        different parameters in the QueryContext.
+    :param datasource_id: id of datasource to query.
+    :param datasource_type: type of datasource to query.
+    :param add_postprocessing_operations: Add post-processing operations to QueryObject
+    :return: Request payload
+    """
+    return {
+        "datasource": {"id": datasource_id, "type": datasource_type},
+        "queries": [_get_query_object(datasource_name, add_postprocessing_operations)],
+    }
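Typical use of the new fixture, mirroring the updated API tests above (the
`birth_names` table ships with the standard example data; this snippet is an
illustration, not part of the commit):

    table = self.get_table_by_name("birth_names")
    payload = get_query_context(table.name, table.id, table.type)
    rv = self.client.post("api/v1/chart/data", json=payload)

    # optionally exercise the post-processing pipeline as well
    payload_pp = get_query_context(
        table.name, table.id, table.type, add_postprocessing_operations=True
    )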
diff --git a/tests/import_export_tests.py b/tests/import_export_tests.py
index 458ab04..b47b43b 100644
--- a/tests/import_export_tests.py
+++ b/tests/import_export_tests.py
@@ -558,7 +558,7 @@ class ImportExportTests(SupersetTestCase):
     def test_import_table_no_metadata(self):
         table = self.create_table("pure_table", id=10001)
         imported_id = SqlaTable.import_obj(table, import_time=1989)
-        imported = self.get_table(imported_id)
+        imported = self.get_table_by_id(imported_id)
         self.assert_table_equals(table, imported)
 
     def test_import_table_1_col_1_met(self):
@@ -566,7 +566,7 @@ class ImportExportTests(SupersetTestCase):
             "table_1_col_1_met", id=10002, cols_names=["col1"], metric_names=["metric1"]
         )
         imported_id = SqlaTable.import_obj(table, import_time=1990)
-        imported = self.get_table(imported_id)
+        imported = self.get_table_by_id(imported_id)
         self.assert_table_equals(table, imported)
         self.assertEqual(
             {"remote_id": 10002, "import_time": 1990, "database_name": "examples"},
@@ -582,7 +582,7 @@ class ImportExportTests(SupersetTestCase):
         )
 
         imported_id = SqlaTable.import_obj(table, import_time=1991)
-        imported = self.get_table(imported_id)
+        imported = self.get_table_by_id(imported_id)
         self.assert_table_equals(table, imported)
 
     def test_import_table_override(self):
@@ -599,7 +599,7 @@ class ImportExportTests(SupersetTestCase):
         )
 
         imported_over_id = SqlaTable.import_obj(table_over, import_time=1992)
-        imported_over = self.get_table(imported_over_id)
+        imported_over = self.get_table_by_id(imported_over_id)
         self.assertEqual(imported_id, imported_over.id)
         expected_table = self.create_table(
             "table_override",
@@ -627,7 +627,7 @@ class ImportExportTests(SupersetTestCase):
         imported_id_copy = SqlaTable.import_obj(copy_table, import_time=1994)
 
         self.assertEqual(imported_id, imported_id_copy)
-        self.assert_table_equals(copy_table, self.get_table(imported_id))
+        self.assert_table_equals(copy_table, self.get_table_by_id(imported_id))
 
     def test_import_druid_no_metadata(self):
         datasource = self.create_druid_datasource("pure_druid", id=10001)
diff --git a/tests/query_context_tests.py b/tests/query_context_tests.py
new file mode 100644
index 0000000..2d378ba
--- /dev/null
+++ b/tests/query_context_tests.py
@@ -0,0 +1,94 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from typing import Any, Dict, List, Optional
+
+from superset import db
+from superset.common.query_context import QueryContext
+from superset.connectors.connector_registry import ConnectorRegistry
+from superset.utils.core import TimeRangeEndpoint
+from tests.base_tests import SupersetTestCase
+from tests.fixtures.query_context import get_query_context
+from tests.test_app import app
+
+
+class QueryContextTests(SupersetTestCase):
+    def test_cache_key_changes_when_datasource_is_updated(self):
+        self.login(username="admin")
+        table_name = "birth_names"
+        table = self.get_table_by_name(table_name)
+        payload = get_query_context(table.name, table.id, table.type)
+
+        # construct baseline cache_key
+        query_context = QueryContext(**payload)
+        query_object = query_context.queries[0]
+        cache_key_original = query_context.cache_key(query_object)
+
+        # make temporary change and revert it to refresh the changed_on property
+        datasource = ConnectorRegistry.get_datasource(
+            datasource_type=payload["datasource"]["type"],
+            datasource_id=payload["datasource"]["id"],
+            session=db.session,
+        )
+        description_original = datasource.description
+        datasource.description = "temporary description"
+        db.session.commit()
+        datasource.description = description_original
+        db.session.commit()
+
+        # create new QueryContext with unchanged attributes and extract new cache_key
+        query_context = QueryContext(**payload)
+        query_object = query_context.queries[0]
+        cache_key_new = query_context.cache_key(query_object)
+
+        # the new cache_key should be different due to updated datasource
+        self.assertNotEqual(cache_key_original, cache_key_new)
+
+    def test_query_context_time_range_endpoints(self):
+        """
+        Ensure that time_range_endpoints are populated automatically when missing
+        from the payload
+        """
+        self.login(username="admin")
+        table_name = "birth_names"
+        table = self.get_table_by_name(table_name)
+        payload = get_query_context(table.name, table.id, table.type)
+        del payload["queries"][0]["extras"]["time_range_endpoints"]
+        query_context = QueryContext(**payload)
+        query_object = query_context.queries[0]
+        extras = query_object.to_dict()["extras"]
+        self.assertTrue("time_range_endpoints" in extras)
+
+        self.assertEquals(
+            extras["time_range_endpoints"],
+            (TimeRangeEndpoint.INCLUSIVE, TimeRangeEndpoint.EXCLUSIVE),
+        )
+
+    def test_convert_deprecated_fields(self):
+        """
+        Ensure that deprecated fields are converted correctly
+        """
+        self.login(username="admin")
+        table_name = "birth_names"
+        table = self.get_table_by_name(table_name)
+        payload = get_query_context(table.name, table.id, table.type)
+        payload["queries"][0]["granularity_sqla"] = "timecol"
+        payload["queries"][0]["having_filters"] = {"col": "a", "op": "==", "val": "b"}
+        query_context = QueryContext(**payload)
+        self.assertEqual(len(query_context.queries), 1)
+        query_object = query_context.queries[0]
+        self.assertEqual(query_object.granularity, "timecol")
+        self.assertIn("having_druid", query_object.extras)