This is an automated email from the ASF dual-hosted git repository.
villebro pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-superset.git
The following commit(s) were added to refs/heads/master by this push:
new 8e439b1 chore: Add OpenAPI docs to /api/v1/chart/data EP (#9556)
8e439b1 is described below
commit 8e439b1115481b6df7f8af616ac683f399d52893
Author: Ville Brofeldt <[email protected]>
AuthorDate: Fri Apr 17 16:44:16 2020 +0300
chore: Add OpenAPI docs to /api/v1/chart/data EP (#9556)
* Add OpenAPI docs to /api/v1/chart/data EP
* Add missing fields to QueryObject, fix linting and unit test errors
* Fix unit test errors
* abc
* Fix errors uncovered by schema validation and add unit test for invalid
payload
* Add schema for response
* Remove unnecessary pylint disable
---
setup.cfg | 2 +-
superset/charts/api.py | 84 ++----
superset/charts/schemas.py | 512 +++++++++++++++++++++++++++++++-
superset/common/query_object.py | 21 +-
superset/connectors/base/models.py | 28 +-
superset/connectors/druid/models.py | 71 +++--
superset/connectors/sqla/models.py | 48 +--
superset/examples/birth_names.py | 2 +-
superset/examples/world_bank.py | 2 +-
superset/typing.py | 2 +
superset/utils/core.py | 49 ++-
superset/utils/pandas_postprocessing.py | 9 +-
tests/charts/api_tests.py | 13 +-
13 files changed, 695 insertions(+), 148 deletions(-)
diff --git a/setup.cfg b/setup.cfg
index b535c63..566633c 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -45,7 +45,7 @@ combine_as_imports = true
include_trailing_comma = true
line_length = 88
known_first_party = superset
-known_third_party
=alembic,backoff,bleach,celery,click,colorama,contextlib2,croniter,cryptography,dateutil,flask,flask_appbuilder,flask_babel,flask_caching,flask_compress,flask_login,flask_migrate,flask_sqlalchemy,flask_talisman,flask_testing,flask_wtf,geohash,geopy,humanize,isodate,jinja2,markdown,markupsafe,marshmallow,msgpack,numpy,pandas,parsedatetime,pathlib2,polyline,prison,pyarrow,pyhive,pytz,retry,selenium,setuptools,simplejson,sphinx_rtd_theme,sqlalchemy,sqlalchemy_utils,sqlpars
[...]
+known_third_party
=alembic,apispec,backoff,bleach,celery,click,colorama,contextlib2,croniter,cryptography,dateutil,flask,flask_appbuilder,flask_babel,flask_caching,flask_compress,flask_login,flask_migrate,flask_sqlalchemy,flask_talisman,flask_testing,flask_wtf,geohash,geopy,humanize,isodate,jinja2,markdown,markupsafe,marshmallow,msgpack,numpy,pandas,parsedatetime,pathlib2,polyline,prison,pyarrow,pyhive,pytz,retry,selenium,setuptools,simplejson,sphinx_rtd_theme,sqlalchemy,sqlalchemy_utils
[...]
multi_line_output = 3
order_by_type = false
diff --git a/superset/charts/api.py b/superset/charts/api.py
index be0df11..be4f407 100644
--- a/superset/charts/api.py
+++ b/superset/charts/api.py
@@ -18,10 +18,11 @@ import logging
from typing import Any, Dict
import simplejson
+from apispec import APISpec
from flask import g, make_response, redirect, request, Response, url_for
from flask_appbuilder.api import expose, protect, rison, safe
from flask_appbuilder.models.sqla.interface import SQLAInterface
-from flask_babel import ngettext
+from flask_babel import gettext as _, ngettext
from werkzeug.wrappers import Response as WerkzeugResponse
from werkzeug.wsgi import FileWrapper
@@ -41,12 +42,13 @@ from superset.charts.commands.exceptions import (
from superset.charts.commands.update import UpdateChartCommand
from superset.charts.filters import ChartFilter, ChartNameOrDescriptionFilter
from superset.charts.schemas import (
+ CHART_DATA_SCHEMAS,
+ ChartDataQueryContextSchema,
ChartPostSchema,
ChartPutSchema,
get_delete_ids_schema,
thumbnail_query_schema,
)
-from superset.common.query_context import QueryContext
from superset.constants import RouteMethod
from superset.exceptions import SupersetSecurityException
from superset.extensions import event_logger, security_manager
@@ -381,74 +383,21 @@ class ChartRestApi(BaseSupersetModelRestApi):
Takes a query context constructed in the client and returns
payload data
response for the given query.
requestBody:
- description: Query context schema
+ description: >-
+ A query context consists of a datasource from which to fetch data
+ and one or many query objects.
required: true
content:
application/json:
schema:
- type: object
- properties:
- datasource:
- type: object
- description: The datasource where the query will run
- properties:
- id:
- type: integer
- type:
- type: string
- queries:
- type: array
- items:
- type: object
- properties:
- granularity:
- type: string
- groupby:
- type: array
- items:
- type: string
- metrics:
- type: array
- items:
- type: object
- filters:
- type: array
- items:
- type: string
- row_limit:
- type: integer
+ $ref: "#/components/schemas/ChartDataQueryContextSchema"
responses:
200:
description: Query result
content:
application/json:
schema:
- type: array
- items:
- type: object
- properties:
- cache_key:
- type: string
- cached_dttm:
- type: string
- cache_timeout:
- type: integer
- error:
- type: string
- is_cached:
- type: boolean
- query:
- type: string
- status:
- type: string
- stacktrace:
- type: string
- rowcount:
- type: integer
- data:
- type: array
- items:
- type: object
+ $ref: "#/components/schemas/ChartDataResponseSchema"
400:
$ref: '#/components/responses/400'
500:
@@ -457,7 +406,11 @@ class ChartRestApi(BaseSupersetModelRestApi):
if not request.is_json:
return self.response_400(message="Request is not JSON")
try:
- query_context = QueryContext(**request.json)
+ query_context, errors =
ChartDataQueryContextSchema().load(request.json)
+ if errors:
+ return self.response_400(
+ message=_("Request is incorrect: %(error)s", error=errors)
+ )
except KeyError:
return self.response_400(message="Request is incorrect")
try:
@@ -466,7 +419,7 @@ class ChartRestApi(BaseSupersetModelRestApi):
return self.response_401()
payload_json = query_context.get_payload()
response_data = simplejson.dumps(
- payload_json, default=json_int_dttm_ser, ignore_nan=True
+ {"result": payload_json}, default=json_int_dttm_ser,
ignore_nan=True
)
resp = make_response(response_data, 200)
resp.headers["Content-Type"] = "application/json; charset=utf-8"
@@ -533,3 +486,10 @@ class ChartRestApi(BaseSupersetModelRestApi):
return Response(
FileWrapper(screenshot), mimetype="image/png",
direct_passthrough=True
)
+
+ def add_apispec_components(self, api_spec: APISpec) -> None:
+ for chart_type in CHART_DATA_SCHEMAS:
+ api_spec.components.schema(
+ chart_type.__name__, schema=chart_type,
+ )
+ super().add_apispec_components(api_spec)
diff --git a/superset/charts/schemas.py b/superset/charts/schemas.py
index bf1b57b..0a7035c 100644
--- a/superset/charts/schemas.py
+++ b/superset/charts/schemas.py
@@ -14,11 +14,12 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from typing import Union
+from typing import Any, Dict, Union
-from marshmallow import fields, Schema, ValidationError
+from marshmallow import fields, post_load, Schema, ValidationError
from marshmallow.validate import Length
+from superset.common.query_context import QueryContext
from superset.exceptions import SupersetException
from superset.utils import core as utils
@@ -59,3 +60,510 @@ class ChartPutSchema(Schema):
datasource_id = fields.Integer(allow_none=True)
datasource_type = fields.String(allow_none=True)
dashboards = fields.List(fields.Integer())
+
+
+class ChartDataColumnSchema(Schema):
+ column_name = fields.String(
+ description="The name of the target column", example="mycol",
+ )
+ type = fields.String(description="Type of target column",
example="BIGINT",)
+
+
+class ChartDataAdhocMetricSchema(Schema):
+ """
+ Ad-hoc metrics are used to define metrics outside the datasource.
+ """
+
+ expressionType = fields.String(
+ description="Simple or SQL metric",
+ required=True,
+ enum=["SIMPLE", "SQL"],
+ example="SQL",
+ )
+ aggregate = fields.String(
+ description="Aggregation operator. Only required for simple expression
types.",
+ required=False,
+ enum=["AVG", "COUNT", "COUNT_DISTINCT", "MAX", "MIN", "SUM"],
+ )
+ column = fields.Nested(ChartDataColumnSchema)
+ sqlExpression = fields.String(
+ description="The metric as defined by a SQL aggregate expression. "
+ "Only required for SQL expression type.",
+ required=False,
+ example="SUM(weight * observations) / SUM(weight)",
+ )
+ label = fields.String(
+ description="Label for the metric. Is automatically generated unless "
+ "hasCustomLabel is true, in which case label must be defined.",
+ required=False,
+ example="Weighted observations",
+ )
+ hasCustomLabel = fields.Boolean(
+ description="When false, the label will be automatically generated
based on "
+ "the aggregate expression. When true, a custom label has to be "
+ "specified.",
+ required=False,
+ example=True,
+ )
+ optionName = fields.String(
+ description="Unique identifier. Can be any string value, as long as
all "
+ "metrics have a unique identifier. If undefined, a random name "
+ "will be generated.",
+ required=False,
+ example="metric_aec60732-fac0-4b17-b736-93f1a5c93e30",
+ )
+
+
+class ChartDataAggregateConfigField(fields.Dict):
+ def __init__(self) -> None:
+ super().__init__(
+ description="The keys are the name of the aggregate column to be
created, "
+ "and the values specify the details of how to apply the "
+ "aggregation. If an operator requires additional options, "
+ "these can be passed here to be unpacked in the operator call. The
"
+ "following numpy operators are supported: average, argmin, argmax,
cumsum, "
+ "cumprod, max, mean, median, nansum, nanmin, nanmax, nanmean,
nanmedian, "
+ "min, percentile, prod, product, std, sum, var. Any options
required by "
+ "the operator can be passed to the `options` object.\n"
+ "\n"
+ "In the example, a new column `first_quantile` is created based on
values "
+ "in the column `my_col` using the `percentile` operator with "
+ "the `q=0.25` parameter.",
+ example={
+ "first_quantile": {
+ "operator": "percentile",
+ "column": "my_col",
+ "options": {"q": 0.25},
+ }
+ },
+ )
+
+
+class ChartDataPostProcessingOperationOptionsSchema(Schema):
+ def __init__(self, *args: Any, **kwargs: Any) -> None:
+ super().__init__(*args, **kwargs)
+
+
+class
ChartDataAggregateOptionsSchema(ChartDataPostProcessingOperationOptionsSchema):
+ """
+ Aggregate operation config.
+ """
+
+ groupby = (
+ fields.List(
+ fields.String(
+ allow_none=False, description="Columns by which to group by",
+ ),
+ minLength=1,
+ required=True,
+ ),
+ )
+ aggregates = ChartDataAggregateConfigField()
+
+
+class
ChartDataRollingOptionsSchema(ChartDataPostProcessingOperationOptionsSchema):
+ """
+ Rolling operation config.
+ """
+
+ columns = (
+ fields.Dict(
+ description="columns on which to perform rolling, mapping source
column to "
+ "target column. For instance, `{'y': 'y'}` will replace the "
+ "column `y` with the rolling value in `y`, while `{'y': 'y2'}` "
+ "will add a column `y2` based on rolling values calculated "
+ "from `y`, leaving the original column `y` unchanged.",
+ example={"weekly_rolling_sales": "sales"},
+ ),
+ )
+ rolling_type = fields.String(
+ description="Type of rolling window. Any numpy function will work.",
+ enum=[
+ "average",
+ "argmin",
+ "argmax",
+ "cumsum",
+ "cumprod",
+ "max",
+ "mean",
+ "median",
+ "nansum",
+ "nanmin",
+ "nanmax",
+ "nanmean",
+ "nanmedian",
+ "min",
+ "percentile",
+ "prod",
+ "product",
+ "std",
+ "sum",
+ "var",
+ ],
+ required=True,
+ example="percentile",
+ )
+ window = fields.Integer(
+ description="Size of the rolling window in days.", required=True,
example=7,
+ )
+ rolling_type_options = fields.Dict(
+        description="Optional options to pass to rolling method. Needed for "
+ "e.g. quantile operation.",
+ required=False,
+ example={},
+ )
+ center = fields.Boolean(
+ description="Should the label be at the center of the window. Default:
`false`",
+ required=False,
+ example=False,
+ )
+ win_type = fields.String(
+ description="Type of window function. See "
+ "[SciPy window functions](https://docs.scipy.org/doc/scipy/reference"
+ "/signal.windows.html#module-scipy.signal.windows) "
+ "for more details. Some window functions require passing "
+ "additional parameters to `rolling_type_options`. For instance, "
+ "to use `gaussian`, the parameter `std` needs to be provided.",
+ required=False,
+ enum=[
+ "boxcar",
+ "triang",
+ "blackman",
+ "hamming",
+ "bartlett",
+ "parzen",
+ "bohman",
+ "blackmanharris",
+ "nuttall",
+ "barthann",
+ "kaiser",
+ "gaussian",
+ "general_gaussian",
+ "slepian",
+ "exponential",
+ ],
+ )
+ min_periods = fields.Integer(
+ description="The minimum amount of periods required for a row to be
included "
+ "in the result set.",
+ required=False,
+ example=7,
+ )
+
+
+class
ChartDataSelectOptionsSchema(ChartDataPostProcessingOperationOptionsSchema):
+ """
+    Select operation config.
+ """
+
+ columns = fields.List(
+ fields.String(),
+ description="Columns which to select from the input data, in the
desired "
+ "order. If columns are renamed, the old column name should be "
+ "referenced here.",
+ example=["country", "gender", "age"],
+ )
+ rename = fields.List(
+ fields.Dict(),
+ description="columns which to rename, mapping source column to target
column. "
+ "For instance, `{'y': 'y2'}` will rename the column `y` to `y2`.",
+ example=[{"age": "average_age"}],
+ )
+
+
+class
ChartDataSortOptionsSchema(ChartDataPostProcessingOperationOptionsSchema):
+ """
+ Sort operation config.
+ """
+
+ columns = fields.Dict(
+        description="columns by which to sort. The key specifies the column
name, "
+ "value specifies if sorting in ascending order.",
+ example={"country": True, "gender": False},
+ required=True,
+ )
+ aggregates = ChartDataAggregateConfigField()
+
+
+class
ChartDataPivotOptionsSchema(ChartDataPostProcessingOperationOptionsSchema):
+ """
+ Pivot operation config.
+ """
+
+ index = (
+ fields.List(
+ fields.String(
+ allow_none=False,
+ description="Columns to group by on the table index (=rows)",
+ ),
+ minLength=1,
+ required=True,
+ ),
+ )
+ columns = fields.List(
+ fields.String(
+ allow_none=False, description="Columns to group by on the table
columns",
+ ),
+ minLength=1,
+ required=True,
+ )
+ metric_fill_value = fields.Number(
+ required=False,
+ description="Value to replace missing values with in aggregate
calculations.",
+ )
+ column_fill_value = fields.String(
+ required=False, description="Value to replace missing pivot columns
names with."
+ )
+ drop_missing_columns = fields.Boolean(
+ description="Do not include columns whose entries are all missing "
+ "(default: `true`).",
+ required=False,
+ )
+ marginal_distributions = fields.Boolean(
+ description="Add totals for row/column. (default: `false`)",
required=False,
+ )
+ marginal_distribution_name = fields.String(
+ description="Name of marginal distribution row/column. (default:
`All`)",
+ required=False,
+ )
+ aggregates = ChartDataAggregateConfigField()
+
+
+class ChartDataPostProcessingOperationSchema(Schema):
+ operation = fields.String(
+ description="Post processing operation type",
+ required=True,
+ enum=["aggregate", "pivot", "rolling", "select", "sort"],
+ example="aggregate",
+ )
+ options = fields.Nested(
+ ChartDataPostProcessingOperationOptionsSchema,
+ description="Options specifying how to perform the operation. Please
refer "
+ "to the respective post processing operation option schemas. "
+ "For example, `ChartDataPostProcessingOperationOptions` specifies "
+ "the required options for the pivot operation.",
+ example={
+ "groupby": ["country", "gender"],
+ "aggregates": {
+ "age_q1": {
+ "operator": "percentile",
+ "column": "age",
+ "options": {"q": 0.25},
+ },
+ "age_mean": {"operator": "mean", "column": "age",},
+ },
+ },
+ )
+
+
+class ChartDataFilterSchema(Schema):
+ col = fields.String(
+ description="The column to filter.", required=True, example="country"
+ )
+ op = fields.String( # pylint: disable=invalid-name
+ description="The comparison operator.",
+ enum=[filter_op.value for filter_op in utils.FilterOperationType],
+ required=True,
+ example="IN",
+ )
+ val = fields.Raw(
+ description="The value or values to compare against. Can be a string, "
+ "integer, decimal or list, depending on the operator.",
+ example=["China", "France", "Japan"],
+ )
+
+
+class ChartDataExtrasSchema(Schema):
+
+ time_range_endpoints = fields.List(
+ fields.String(enum=["INCLUSIVE", "EXCLUSIVE"]),
+ description="A list with two values, stating if start/end should be "
+ "inclusive/exclusive.",
+ required=False,
+ )
+ relative_start = fields.String(
+ description="Start time for relative time deltas. "
+ 'Default: `config["DEFAULT_RELATIVE_START_TIME"]`',
+ enum=["today", "now"],
+ required=False,
+ )
+ relative_end = fields.String(
+ description="End time for relative time deltas. "
+        'Default: `config["DEFAULT_RELATIVE_END_TIME"]`',
+ enum=["today", "now"],
+ required=False,
+ )
+
+
+class ChartDataQueryObjectSchema(Schema):
+ filters = fields.List(fields.Nested(ChartDataFilterSchema), required=False)
+ granularity = fields.String(
+ description="To what level of granularity should the temporal column
be "
+ "aggregated. Supports "
+ "[ISO 8601](https://en.wikipedia.org/wiki/ISO_8601#Durations) "
+ "durations.",
+ enum=[
+ "PT1S",
+ "PT1M",
+ "PT5M",
+ "PT10M",
+ "PT15M",
+ "PT0.5H",
+ "PT1H",
+ "P1D",
+ "P1W",
+ "P1M",
+ "P0.25Y",
+ "P1Y",
+ ],
+ required=False,
+ example="P1D",
+ )
+ groupby = fields.List(
+ fields.String(description="Columns by which to group the query.",),
+ )
+ metrics = fields.List(
+ fields.Raw(),
+ description="Aggregate expressions. Metrics can be passed as both "
+        "references to datasource metrics (strings), or ad-hoc metrics "
+ "which are defined only within the query object. See "
+ "`ChartDataAdhocMetricSchema` for the structure of ad-hoc metrics.",
+ )
+ post_processing = fields.List(
+ fields.Nested(ChartDataPostProcessingOperationSchema),
+ description="Post processing operations to be applied to the result
set. "
+ "Operations are applied to the result set in sequential order.",
+ required=False,
+ )
+ time_range = fields.String(
+        description="A time range, either expressed as a colon separated string
"
+ "`since : until`. Valid formats for `since` and `until` are: \n"
+ "- ISO 8601\n"
+ "- X days/years/hours/day/year/weeks\n"
+ "- X days/years/hours/day/year/weeks ago\n"
+ "- X days/years/hours/day/year/weeks from now\n"
+ "\n"
+ "Additionally, the following freeform can be used:\n"
+ "\n"
+ "- Last day\n"
+ "- Last week\n"
+ "- Last month\n"
+ "- Last quarter\n"
+ "- Last year\n"
+ "- No filter\n"
+ "- Last X seconds/minutes/hours/days/weeks/months/years\n"
+ "- Next X seconds/minutes/hours/days/weeks/months/years\n",
+ required=False,
+ example="Last week",
+ )
+ time_shift = fields.String(
+ description="A human-readable date/time string. "
+        "Please refer to [parsedatetime](https://github.com/bear/parsedatetime)
"
+ "documentation for details on valid values.",
+ required=False,
+ )
+ is_timeseries = fields.Boolean(
+ description="Is the `query_object` a timeseries.", required=False
+ )
+ timeseries_limit = fields.Integer(
+ description="Maximum row count for timeseries queries. Default: `0`",
+ required=False,
+ )
+ row_limit = fields.Integer(
+ description='Maximum row count. Default: `config["ROW_LIMIT"]`',
required=False,
+ )
+ order_desc = fields.Boolean(
+ description="Reverse order. Default: `false`", required=False
+ )
+ extras = fields.Dict(description=" Default: `{}`", required=False)
+ columns = fields.List(fields.String(), description="", required=False,)
+ orderby = fields.List(
+ fields.List(fields.Raw()),
+ description="Expects a list of lists where the first element is the
column "
+ "name which to sort by, and the second element is a boolean ",
+ required=False,
+ example=[["my_col_1", False], ["my_col_2", True]],
+ )
+
+
+class ChartDataDatasourceSchema(Schema):
+ description = "Chart datasource"
+ id = fields.Integer(description="Datasource id", required=True,)
+ type = fields.String(description="Datasource type", enum=["druid", "sql"])
+
+
+class ChartDataQueryContextSchema(Schema):
+ datasource = fields.Nested(ChartDataDatasourceSchema)
+ queries = fields.List(fields.Nested(ChartDataQueryObjectSchema))
+
+ # pylint: disable=no-self-use
+ @post_load
+ def make_query_context(self, data: Dict[str, Any]) -> QueryContext:
+ query_context = QueryContext(**data)
+ return query_context
+
+ # pylint: enable=no-self-use
+
+
+class ChartDataResponseResult(Schema):
+ cache_key = fields.String(
+ description="Unique cache key for query object", required=True,
allow_none=True,
+ )
+ cached_dttm = fields.String(
+ description="Cache timestamp", required=True, allow_none=True,
+ )
+ cache_timeout = fields.Integer(
+ description="Cache timeout in following order: custom timeout,
datasource "
+ "timeout, default config timeout.",
+ required=True,
+ allow_none=True,
+ )
+ error = fields.String(description="Error", allow_none=True,)
+ is_cached = fields.Boolean(
+ description="Is the result cached", required=True, allow_none=None,
+ )
+ query = fields.String(
+ description="The executed query statement", required=True,
allow_none=False,
+ )
+ status = fields.String(
+ description="Status of the query",
+ enum=[
+ "stopped",
+ "failed",
+ "pending",
+ "running",
+ "scheduled",
+ "success",
+ "timed_out",
+ ],
+ allow_none=False,
+ )
+ stacktrace = fields.String(
+        description="Stacktrace if there was an error", allow_none=True,
+ )
+ rowcount = fields.Integer(
+ description="Amount of rows in result set", allow_none=False,
+ )
+ data = fields.List(fields.Dict(), description="A list with results")
+
+
+class ChartDataResponseSchema(Schema):
+ result = fields.List(
+ fields.Nested(ChartDataResponseResult),
+ description="A list of results for each corresponding query in the
request.",
+ )
+
+
+CHART_DATA_SCHEMAS = (
+ ChartDataQueryContextSchema,
+ ChartDataResponseSchema,
+ # TODO: These should optimally be included in the QueryContext schema as
an `anyOf`
+    # in ChartDataPostProcessingOperation.options, but since `anyOf` is not
+    # supported by Marshmallow<3, this is not currently possible.
+ ChartDataAdhocMetricSchema,
+ ChartDataAggregateOptionsSchema,
+ ChartDataPivotOptionsSchema,
+ ChartDataRollingOptionsSchema,
+ ChartDataSelectOptionsSchema,
+ ChartDataSortOptionsSchema,
+)
diff --git a/superset/common/query_object.py b/superset/common/query_object.py
index 63158c6..31a6241 100644
--- a/superset/common/query_object.py
+++ b/superset/common/query_object.py
@@ -47,9 +47,9 @@ class QueryObject:
is_timeseries: bool
time_shift: Optional[timedelta]
groupby: List[str]
- metrics: List[Union[Dict, str]]
+ metrics: List[Union[Dict[str, Any], str]]
row_limit: int
- filter: List[str]
+ filter: List[Dict[str, Any]]
timeseries_limit: int
timeseries_limit_metric: Optional[Dict]
order_desc: bool
@@ -61,9 +61,9 @@ class QueryObject:
def __init__(
self,
granularity: str,
- metrics: List[Union[Dict, str]],
+ metrics: List[Union[Dict[str, Any], str]],
groupby: Optional[List[str]] = None,
- filters: Optional[List[str]] = None,
+ filters: Optional[List[Dict[str, Any]]] = None,
time_range: Optional[str] = None,
time_shift: Optional[str] = None,
is_timeseries: bool = False,
@@ -75,14 +75,17 @@ class QueryObject:
columns: Optional[List[str]] = None,
orderby: Optional[List[List]] = None,
post_processing: Optional[List[Dict[str, Any]]] = None,
- relative_start: str = app.config["DEFAULT_RELATIVE_START_TIME"],
- relative_end: str = app.config["DEFAULT_RELATIVE_END_TIME"],
):
+ extras = extras or {}
is_sip_38 = is_feature_enabled("SIP_38_VIZ_REARCHITECTURE")
self.granularity = granularity
self.from_dttm, self.to_dttm = utils.get_since_until(
- relative_start=relative_start,
- relative_end=relative_end,
+ relative_start=extras.get(
+ "relative_start", app.config["DEFAULT_RELATIVE_START_TIME"]
+ ),
+ relative_end=extras.get(
+ "relative_end", app.config["DEFAULT_RELATIVE_END_TIME"]
+ ),
time_range=time_range,
time_shift=time_shift,
)
@@ -106,7 +109,7 @@ class QueryObject:
self.timeseries_limit = timeseries_limit
self.timeseries_limit_metric = timeseries_limit_metric
self.order_desc = order_desc
- self.extras = extras or {}
+ self.extras = extras
if app.config["SIP_15_ENABLED"] and "time_range_endpoints" not in
self.extras:
self.extras["time_range_endpoints"] =
get_time_range_endpoints(form_data={})
diff --git a/superset/connectors/base/models.py
b/superset/connectors/base/models.py
index 8e1acc7..2b6e0d2 100644
--- a/superset/connectors/base/models.py
+++ b/superset/connectors/base/models.py
@@ -25,6 +25,7 @@ from sqlalchemy.orm import foreign, Query, relationship
from superset.constants import NULL_STRING
from superset.models.helpers import AuditMixinNullable, ImportMixin,
QueryResult
from superset.models.slice import Slice
+from superset.typing import FilterValue, FilterValues
from superset.utils import core as utils
METRIC_FORM_DATA_PARAMS = [
@@ -301,28 +302,33 @@ class BaseDatasource(
@staticmethod
def filter_values_handler(
- values, target_column_is_numeric=False, is_list_target=False
- ):
- def handle_single_value(v):
+ values: Optional[FilterValues],
+ target_column_is_numeric: bool = False,
+ is_list_target: bool = False,
+ ) -> Optional[FilterValues]:
+ if values is None:
+ return None
+
+ def handle_single_value(value: Optional[FilterValue]) ->
Optional[FilterValue]:
# backward compatibility with previous <select> components
- if isinstance(v, str):
- v = v.strip("\t\n'\"")
+ if isinstance(value, str):
+ value = value.strip("\t\n'\"")
if target_column_is_numeric:
# For backwards compatibility and edge cases
# where a column data type might have changed
- v = utils.string_to_num(v)
- if v == NULL_STRING:
+ value = utils.cast_to_num(value)
+ if value == NULL_STRING:
return None
- elif v == "<empty string>":
+ elif value == "<empty string>":
return ""
- return v
+ return value
if isinstance(values, (list, tuple)):
- values = [handle_single_value(v) for v in values]
+ values = [handle_single_value(v) for v in values] # type: ignore
else:
values = handle_single_value(values)
if is_list_target and not isinstance(values, (tuple, list)):
- values = [values]
+ values = [values] # type: ignore
elif not is_list_target and isinstance(values, (tuple, list)):
if values:
values = values[0]
diff --git a/superset/connectors/druid/models.py
b/superset/connectors/druid/models.py
index a4cc527..20dd732 100644
--- a/superset/connectors/druid/models.py
+++ b/superset/connectors/druid/models.py
@@ -24,7 +24,7 @@ from copy import deepcopy
from datetime import datetime, timedelta
from distutils.version import LooseVersion
from multiprocessing.pool import ThreadPool
-from typing import Dict, Iterable, List, Optional, Set, Tuple, Union
+from typing import cast, Dict, Iterable, List, Optional, Set, Tuple, Union
import pandas as pd
import sqlalchemy as sa
@@ -54,6 +54,7 @@ from superset.constants import NULL_STRING
from superset.exceptions import SupersetException
from superset.models.core import Database
from superset.models.helpers import AuditMixinNullable, ImportMixin,
QueryResult
+from superset.typing import FilterValues
from superset.utils import core as utils, import_datasource
try:
@@ -80,7 +81,12 @@ except ImportError:
pass
try:
- from superset.utils.core import DimSelector, DTTM_ALIAS, flasher
+ from superset.utils.core import (
+ DimSelector,
+ DTTM_ALIAS,
+ FilterOperationType,
+ flasher,
+ )
except ImportError:
pass
@@ -1483,13 +1489,20 @@ class DruidDatasource(Model, BaseDatasource):
"""Given Superset filter data structure, returns pydruid Filter(s)"""
filters = None
for flt in raw_filters:
- col = flt.get("col")
- op = flt.get("op")
- eq = flt.get("val")
+ col: Optional[str] = flt.get("col")
+ op: Optional[str] = flt["op"].upper() if "op" in flt else None
+ eq: Optional[FilterValues] = flt.get("val")
if (
not col
or not op
- or (eq is None and op not in ("IS NULL", "IS NOT NULL"))
+ or (
+ eq is None
+ and op
+ not in (
+ FilterOperationType.IS_NULL.value,
+ FilterOperationType.IS_NOT_NULL.value,
+ )
+ )
):
continue
@@ -1503,7 +1516,10 @@ class DruidDatasource(Model, BaseDatasource):
cond = None
is_numeric_col = col in num_cols
- is_list_target = op in ("in", "not in")
+ is_list_target = op in (
+ FilterOperationType.IN.value,
+ FilterOperationType.NOT_IN.value,
+ )
eq = cls.filter_values_handler(
eq,
is_list_target=is_list_target,
@@ -1512,15 +1528,16 @@ class DruidDatasource(Model, BaseDatasource):
# For these two ops, could have used Dimension,
# but it doesn't support extraction functions
- if op == "==":
+ if op == FilterOperationType.EQUALS.value:
cond = Filter(
dimension=col, value=eq, extraction_function=extraction_fn
)
- elif op == "!=":
+ elif op == FilterOperationType.NOT_EQUALS.value:
cond = ~Filter(
dimension=col, value=eq, extraction_function=extraction_fn
)
- elif op in ("in", "not in"):
+ elif is_list_target:
+ eq = cast(list, eq)
fields = []
# ignore the filter if it has no value
if not len(eq):
@@ -1540,9 +1557,9 @@ class DruidDatasource(Model, BaseDatasource):
for s in eq:
fields.append(Dimension(col) == s)
cond = Filter(type="or", fields=fields)
- if op == "not in":
+ if op == FilterOperationType.NOT_IN.value:
cond = ~cond
- elif op == "regex":
+ elif op == FilterOperationType.REGEX.value:
cond = Filter(
extraction_function=extraction_fn,
type="regex",
@@ -1552,7 +1569,7 @@ class DruidDatasource(Model, BaseDatasource):
# For the ops below, could have used pydruid's Bound,
# but it doesn't support extraction functions
- elif op == ">=":
+ elif op == FilterOperationType.GREATER_THAN_OR_EQUALS.value:
cond = Bound(
extraction_function=extraction_fn,
dimension=col,
@@ -1562,7 +1579,7 @@ class DruidDatasource(Model, BaseDatasource):
upper=None,
ordering=cls._get_ordering(is_numeric_col),
)
- elif op == "<=":
+ elif op == FilterOperationType.LESS_THAN_OR_EQUALS.value:
cond = Bound(
extraction_function=extraction_fn,
dimension=col,
@@ -1572,7 +1589,7 @@ class DruidDatasource(Model, BaseDatasource):
upper=eq,
ordering=cls._get_ordering(is_numeric_col),
)
- elif op == ">":
+ elif op == FilterOperationType.GREATER_THAN.value:
cond = Bound(
extraction_function=extraction_fn,
lowerStrict=True,
@@ -1582,7 +1599,7 @@ class DruidDatasource(Model, BaseDatasource):
upper=None,
ordering=cls._get_ordering(is_numeric_col),
)
- elif op == "<":
+ elif op == FilterOperationType.LESS_THAN.value:
cond = Bound(
extraction_function=extraction_fn,
upperStrict=True,
@@ -1592,9 +1609,9 @@ class DruidDatasource(Model, BaseDatasource):
upper=eq,
ordering=cls._get_ordering(is_numeric_col),
)
- elif op == "IS NULL":
+ elif op == FilterOperationType.IS_NULL.value:
cond = Filter(dimension=col, value="")
- elif op == "IS NOT NULL":
+ elif op == FilterOperationType.IS_NOT_NULL.value:
cond = ~Filter(dimension=col, value="")
if filters:
@@ -1610,21 +1627,25 @@ class DruidDatasource(Model, BaseDatasource):
def _get_having_obj(self, col: str, op: str, eq: str) -> "Having":
cond = None
- if op == "==":
+ if op == FilterOperationType.EQUALS.value:
if col in self.column_names:
cond = DimSelector(dimension=col, value=eq)
else:
cond = Aggregation(col) == eq
- elif op == ">":
+ elif op == FilterOperationType.GREATER_THAN.value:
cond = Aggregation(col) > eq
- elif op == "<":
+ elif op == FilterOperationType.LESS_THAN.value:
cond = Aggregation(col) < eq
return cond
def get_having_filters(self, raw_filters: List[Dict]) -> "Having":
filters = None
- reversed_op_map = {"!=": "==", ">=": "<", "<=": ">"}
+ reversed_op_map = {
+ FilterOperationType.NOT_EQUALS.value:
FilterOperationType.EQUALS.value,
+ FilterOperationType.GREATER_THAN_OR_EQUALS.value:
FilterOperationType.LESS_THAN.value,
+ FilterOperationType.LESS_THAN_OR_EQUALS.value:
FilterOperationType.GREATER_THAN.value,
+ }
for flt in raw_filters:
if not all(f in flt for f in ["col", "op", "val"]):
@@ -1633,7 +1654,11 @@ class DruidDatasource(Model, BaseDatasource):
op = flt["op"]
eq = flt["val"]
cond = None
- if op in ["==", ">", "<"]:
+ if op in [
+ FilterOperationType.EQUALS.value,
+ FilterOperationType.GREATER_THAN.value,
+ FilterOperationType.LESS_THAN.value,
+ ]:
cond = self._get_having_obj(col, op, eq)
elif op in reversed_op_map:
cond = ~self._get_having_obj(col, reversed_op_map[op], eq)
diff --git a/superset/connectors/sqla/models.py
b/superset/connectors/sqla/models.py
index c290363..8aac7e0 100644
--- a/superset/connectors/sqla/models.py
+++ b/superset/connectors/sqla/models.py
@@ -843,43 +843,53 @@ class SqlaTable(Model, BaseDatasource):
if not all([flt.get(s) for s in ["col", "op"]]):
continue
col = flt["col"]
- op = flt["op"]
+ op = flt["op"].upper()
col_obj = cols.get(col)
if col_obj:
- is_list_target = op in ("in", "not in")
+ is_list_target = op in (
+ utils.FilterOperationType.IN.value,
+ utils.FilterOperationType.NOT_IN.value,
+ )
eq = self.filter_values_handler(
- flt.get("val"),
+ values=flt.get("val"),
target_column_is_numeric=col_obj.is_numeric,
is_list_target=is_list_target,
)
- if op in ("in", "not in"):
+ if op in (
+ utils.FilterOperationType.IN.value,
+ utils.FilterOperationType.NOT_IN.value,
+ ):
cond = col_obj.get_sqla_col().in_(eq)
- if NULL_STRING in eq:
- cond = or_(cond, col_obj.get_sqla_col() == None)
- if op == "not in":
+ if isinstance(eq, (list, tuple)) and NULL_STRING in eq:
+ cond = or_(cond, col_obj.get_sqla_col().is_(None))
+ if op == utils.FilterOperationType.NOT_IN.value:
cond = ~cond
where_clause_and.append(cond)
else:
if col_obj.is_numeric:
- eq = utils.string_to_num(flt["val"])
- if op == "==":
+ eq = utils.cast_to_num(flt["val"])
+ if op == utils.FilterOperationType.EQUALS.value:
where_clause_and.append(col_obj.get_sqla_col() == eq)
- elif op == "!=":
+ elif op == utils.FilterOperationType.NOT_EQUALS.value:
where_clause_and.append(col_obj.get_sqla_col() != eq)
- elif op == ">":
+ elif op == utils.FilterOperationType.GREATER_THAN.value:
where_clause_and.append(col_obj.get_sqla_col() > eq)
- elif op == "<":
+ elif op == utils.FilterOperationType.LESS_THAN.value:
where_clause_and.append(col_obj.get_sqla_col() < eq)
- elif op == ">=":
+ elif op == utils.FilterOperationType.GREATER_THAN_OR_EQUALS.value:
where_clause_and.append(col_obj.get_sqla_col() >= eq)
- elif op == "<=":
+ elif op == utils.FilterOperationType.LESS_THAN_OR_EQUALS.value:
where_clause_and.append(col_obj.get_sqla_col() <= eq)
- elif op == "LIKE":
+ elif op == utils.FilterOperationType.LIKE.value:
where_clause_and.append(col_obj.get_sqla_col().like(eq))
- elif op == "IS NULL":
- where_clause_and.append(col_obj.get_sqla_col() == None)
- elif op == "IS NOT NULL":
- where_clause_and.append(col_obj.get_sqla_col() != None)
+ elif op == utils.FilterOperationType.IS_NULL.value:
+ where_clause_and.append(col_obj.get_sqla_col().is_(None))
+ elif op == utils.FilterOperationType.IS_NOT_NULL.value:
+ where_clause_and.append(col_obj.get_sqla_col().isnot(None))
+ else:
+ raise Exception(
+ _("Invalid filter operation type: %(op)s", op=op)
+ )
where_clause_and += self._get_sqla_row_level_filters(template_processor)
if extras:
diff --git a/superset/examples/birth_names.py b/superset/examples/birth_names.py
index ade420b..58657c6 100644
--- a/superset/examples/birth_names.py
+++ b/superset/examples/birth_names.py
@@ -189,7 +189,7 @@ def load_birth_names(only_metadata: bool = False, force:
bool = False) -> None:
"expressionType": "SIMPLE",
"filterOptionName": "2745eae5",
"comparator": ["other"],
- "operator": "not in",
+ "operator": "NOT IN",
"subject": "state",
}
],
diff --git a/superset/examples/world_bank.py b/superset/examples/world_bank.py
index 0136ed9..1764e11 100644
--- a/superset/examples/world_bank.py
+++ b/superset/examples/world_bank.py
@@ -249,7 +249,7 @@ def load_world_bank_health_n_pop( # pylint:
disable=too-many-locals
"AMA",
"PLW",
],
- "operator": "not in",
+ "operator": "NOT IN",
"subject": "country_code",
}
],
diff --git a/superset/typing.py b/superset/typing.py
index c84e6d3..b6686ec 100644
--- a/superset/typing.py
+++ b/superset/typing.py
@@ -25,4 +25,6 @@ DbapiDescriptionRow = Tuple[
]
DbapiDescription = Union[List[DbapiDescriptionRow], Tuple[DbapiDescriptionRow,
...]]
DbapiResult = List[Union[List[Any], Tuple[Any, ...]]]
+FilterValue = Union[float, int, str]
+FilterValues = Union[FilterValue, List[FilterValue], Tuple[FilterValue]]
VizData = Optional[Union[List[Any], Dict[Any, Any]]]
diff --git a/superset/utils/core.py b/superset/utils/core.py
index ac2e92e..5749930 100644
--- a/superset/utils/core.py
+++ b/superset/utils/core.py
@@ -198,28 +198,30 @@ def parse_js_uri_path_item(
return unquote_plus(item) if unquote and item else item
-def string_to_num(s: str):
- """Converts a string to an int/float
+def cast_to_num(value: Union[float, int, str]) -> Optional[Union[float, int]]:
+ """Casts a value to an int/float
- Returns ``None`` if it can't be converted
-
- >>> string_to_num('5')
+ >>> cast_to_num('5')
5
- >>> string_to_num('5.2')
+ >>> cast_to_num('5.2')
5.2
- >>> string_to_num(10)
+ >>> cast_to_num(10)
10
- >>> string_to_num(10.1)
+ >>> cast_to_num(10.1)
10.1
- >>> string_to_num('this is not a string') is None
+ >>> cast_to_num('this is not a number') is None
True
+
+ :param value: value to be converted to numeric representation
+ :returns: value cast to ``int`` if value is all digits, ``float`` if ``value``
+ is a decimal value and ``None`` if it can't be converted
"""
- if isinstance(s, (int, float)):
- return s
- if s.isdigit():
- return int(s)
+ if isinstance(value, (int, float)):
+ return value
+ if value.isdigit():
+ return int(value)
try:
- return float(s)
+ return float(value)
except ValueError:
return None
@@ -1346,3 +1348,22 @@ class DbColumnType(Enum):
NUMERIC = 0
STRING = 1
TEMPORAL = 2
+
+
+class FilterOperationType(str, Enum):
+ """
+ Filter operation type
+ """
+
+ EQUALS = "=="
+ NOT_EQUALS = "!="
+ GREATER_THAN = ">"
+ LESS_THAN = "<"
+ GREATER_THAN_OR_EQUALS = ">="
+ LESS_THAN_OR_EQUALS = "<="
+ LIKE = "LIKE"
+ IS_NULL = "IS NULL"
+ IS_NOT_NULL = "IS NOT NULL"
+ IN = "IN"
+ NOT_IN = "NOT IN"
+ REGEX = "REGEX"
diff --git a/superset/utils/pandas_postprocessing.py
b/superset/utils/pandas_postprocessing.py
index 2800ee1..f2a688c 100644
--- a/superset/utils/pandas_postprocessing.py
+++ b/superset/utils/pandas_postprocessing.py
@@ -96,7 +96,7 @@ def _get_aggregate_funcs(
aggregators. Currently only numpy aggregators are supported.
:param df: DataFrame on which to perform aggregate operation.
- :param aggregates: Mapping from column name to aggregat config.
+ :param aggregates: Mapping from column name to aggregate config.
:return: Mapping from metric name to function that takes a single input
argument.
"""
agg_funcs: Dict[str, NamedAgg] = {}
@@ -276,12 +276,13 @@ def rolling( # pylint: disable=too-many-arguments
on rolling values calculated from `y`, leaving the original column
`y`
unchanged.
:param rolling_type: Type of rolling window. Any numpy function will work.
+ :param window: Size of the window.
:param rolling_type_options: Optional options to pass to rolling method.
Needed
for e.g. quantile operation.
:param center: Should the label be at the center of the window.
:param win_type: Type of window function.
- :param window: Size of the window.
- :param min_periods:
+ :param min_periods: The minimum amount of periods required for a row to be
included
+ in the result set.
:return: DataFrame with the rolling columns
:raises ChartDataValidationError: If the request in incorrect
"""
@@ -332,7 +333,7 @@ def select(
:param df: DataFrame on which the rolling period will be based.
:param columns: Columns which to select from the DataFrame, in the desired
order.
- If columns are renamed, the new column name should be
referenced
+ If columns are renamed, the old column name should be
referenced
here.
:param rename: columns which to rename, mapping source column to target
column.
For instance, `{'y': 'y2'}` will rename the column `y` to
diff --git a/tests/charts/api_tests.py b/tests/charts/api_tests.py
index 257b89b..1f64bae 100644
--- a/tests/charts/api_tests.py
+++ b/tests/charts/api_tests.py
@@ -657,7 +657,18 @@ class ChartApiTests(SupersetTestCase,
ApiOwnersTestCaseMixin):
rv = self.client.post(uri, json=query_context)
self.assertEqual(rv.status_code, 200)
data = json.loads(rv.data.decode("utf-8"))
- self.assertEqual(data[0]["rowcount"], 100)
+ self.assertEqual(data["result"][0]["rowcount"], 100)
+
+ def test_invalid_chart_data(self):
+ """
+ Query API: Test chart data query
+ """
+ self.login(username="admin")
+ query_context = self._get_query_context()
+ query_context["datasource"] = "abc"
+ uri = "api/v1/chart/data"
+ rv = self.client.post(uri, json=query_context)
+ self.assertEqual(rv.status_code, 400)
def test_query_exec_not_allowed(self):
"""