This is an automated email from the ASF dual-hosted git repository.

villebro pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-superset.git


The following commit(s) were added to refs/heads/master by this push:
     new 8e439b1  chore: Add OpenAPI docs to /api/v1/chart/data EP (#9556)
8e439b1 is described below

commit 8e439b1115481b6df7f8af616ac683f399d52893
Author: Ville Brofeldt <[email protected]>
AuthorDate: Fri Apr 17 16:44:16 2020 +0300

    chore: Add OpenAPI docs to /api/v1/chart/data EP (#9556)
    
    * Add OpenAPI docs to /api/v1/chart/data EP
    
    * Add missing fields to QueryObject, fix linting and unit test errors
    
    * Fix unit test errors
    
    * abc
    
    * Fix errors uncovered by schema validation and add unit test for invalid payload
    
    * Add schema for response
    
    * Remove unnecessary pylint disable
---
 setup.cfg                               |   2 +-
 superset/charts/api.py                  |  84 ++----
 superset/charts/schemas.py              | 512 +++++++++++++++++++++++++++++++-
 superset/common/query_object.py         |  21 +-
 superset/connectors/base/models.py      |  28 +-
 superset/connectors/druid/models.py     |  71 +++--
 superset/connectors/sqla/models.py      |  48 +--
 superset/examples/birth_names.py        |   2 +-
 superset/examples/world_bank.py         |   2 +-
 superset/typing.py                      |   2 +
 superset/utils/core.py                  |  49 ++-
 superset/utils/pandas_postprocessing.py |   9 +-
 tests/charts/api_tests.py               |  13 +-
 13 files changed, 695 insertions(+), 148 deletions(-)
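
For illustration, a minimal request body that should validate against the new
ChartDataQueryContextSchema introduced below (a sketch only: the datasource id and
the column, metric and filter values are placeholders, not values from this commit):

    # Sketch of a POST /api/v1/chart/data payload; all identifiers are
    # illustrative placeholders.
    payload = {
        "datasource": {"id": 3, "type": "sql"},
        "queries": [
            {
                "granularity": "P1D",
                "groupby": ["country"],
                "metrics": ["count"],
                "filters": [{"col": "gender", "op": "==", "val": "boy"}],
                "time_range": "Last week",
                "row_limit": 100,
            }
        ],
    }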

diff --git a/setup.cfg b/setup.cfg
index b535c63..566633c 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -45,7 +45,7 @@ combine_as_imports = true
 include_trailing_comma = true
 line_length = 88
 known_first_party = superset
-known_third_party =alembic,backoff,bleach,celery,click,colorama,contextlib2,croniter,cryptography,dateutil,flask,flask_appbuilder,flask_babel,flask_caching,flask_compress,flask_login,flask_migrate,flask_sqlalchemy,flask_talisman,flask_testing,flask_wtf,geohash,geopy,humanize,isodate,jinja2,markdown,markupsafe,marshmallow,msgpack,numpy,pandas,parsedatetime,pathlib2,polyline,prison,pyarrow,pyhive,pytz,retry,selenium,setuptools,simplejson,sphinx_rtd_theme,sqlalchemy,sqlalchemy_utils,sqlpars
 [...]
+known_third_party =alembic,apispec,backoff,bleach,celery,click,colorama,contextlib2,croniter,cryptography,dateutil,flask,flask_appbuilder,flask_babel,flask_caching,flask_compress,flask_login,flask_migrate,flask_sqlalchemy,flask_talisman,flask_testing,flask_wtf,geohash,geopy,humanize,isodate,jinja2,markdown,markupsafe,marshmallow,msgpack,numpy,pandas,parsedatetime,pathlib2,polyline,prison,pyarrow,pyhive,pytz,retry,selenium,setuptools,simplejson,sphinx_rtd_theme,sqlalchemy,sqlalchemy_utils
 [...]
 multi_line_output = 3
 order_by_type = false
 
diff --git a/superset/charts/api.py b/superset/charts/api.py
index be0df11..be4f407 100644
--- a/superset/charts/api.py
+++ b/superset/charts/api.py
@@ -18,10 +18,11 @@ import logging
 from typing import Any, Dict
 
 import simplejson
+from apispec import APISpec
 from flask import g, make_response, redirect, request, Response, url_for
 from flask_appbuilder.api import expose, protect, rison, safe
 from flask_appbuilder.models.sqla.interface import SQLAInterface
-from flask_babel import ngettext
+from flask_babel import gettext as _, ngettext
 from werkzeug.wrappers import Response as WerkzeugResponse
 from werkzeug.wsgi import FileWrapper
 
@@ -41,12 +42,13 @@ from superset.charts.commands.exceptions import (
 from superset.charts.commands.update import UpdateChartCommand
 from superset.charts.filters import ChartFilter, ChartNameOrDescriptionFilter
 from superset.charts.schemas import (
+    CHART_DATA_SCHEMAS,
+    ChartDataQueryContextSchema,
     ChartPostSchema,
     ChartPutSchema,
     get_delete_ids_schema,
     thumbnail_query_schema,
 )
-from superset.common.query_context import QueryContext
 from superset.constants import RouteMethod
 from superset.exceptions import SupersetSecurityException
 from superset.extensions import event_logger, security_manager
@@ -381,74 +383,21 @@ class ChartRestApi(BaseSupersetModelRestApi):
             Takes a query context constructed in the client and returns payload data
             response for the given query.
           requestBody:
-            description: Query context schema
+            description: >-
+              A query context consists of a datasource from which to fetch data
+              and one or many query objects.
             required: true
             content:
               application/json:
                 schema:
-                  type: object
-                  properties:
-                    datasource:
-                      type: object
-                      description: The datasource where the query will run
-                      properties:
-                        id:
-                          type: integer
-                        type:
-                          type: string
-                    queries:
-                      type: array
-                      items:
-                        type: object
-                        properties:
-                          granularity:
-                            type: string
-                          groupby:
-                            type: array
-                            items:
-                              type: string
-                          metrics:
-                            type: array
-                            items:
-                              type: object
-                          filters:
-                            type: array
-                            items:
-                              type: string
-                          row_limit:
-                            type: integer
+                  $ref: "#/components/schemas/ChartDataQueryContextSchema"
           responses:
             200:
               description: Query result
               content:
                 application/json:
                   schema:
-                    type: array
-                    items:
-                      type: object
-                      properties:
-                        cache_key:
-                          type: string
-                        cached_dttm:
-                          type: string
-                        cache_timeout:
-                          type: integer
-                        error:
-                          type: string
-                        is_cached:
-                          type: boolean
-                        query:
-                          type: string
-                        status:
-                          type: string
-                        stacktrace:
-                          type: string
-                        rowcount:
-                          type: integer
-                        data:
-                          type: array
-                          items:
-                            type: object
+                    $ref: "#/components/schemas/ChartDataResponseSchema"
             400:
               $ref: '#/components/responses/400'
             500:
@@ -457,7 +406,11 @@ class ChartRestApi(BaseSupersetModelRestApi):
         if not request.is_json:
             return self.response_400(message="Request is not JSON")
         try:
-            query_context = QueryContext(**request.json)
+            query_context, errors = ChartDataQueryContextSchema().load(request.json)
+            if errors:
+                return self.response_400(
+                    message=_("Request is incorrect: %(error)s", error=errors)
+                )
         except KeyError:
             return self.response_400(message="Request is incorrect")
         try:
@@ -466,7 +419,7 @@ class ChartRestApi(BaseSupersetModelRestApi):
             return self.response_401()
         payload_json = query_context.get_payload()
         response_data = simplejson.dumps(
-            payload_json, default=json_int_dttm_ser, ignore_nan=True
+            {"result": payload_json}, default=json_int_dttm_ser, 
ignore_nan=True
         )
         resp = make_response(response_data, 200)
         resp.headers["Content-Type"] = "application/json; charset=utf-8"
@@ -533,3 +486,10 @@ class ChartRestApi(BaseSupersetModelRestApi):
         return Response(
             FileWrapper(screenshot), mimetype="image/png", direct_passthrough=True
         )
+
+    def add_apispec_components(self, api_spec: APISpec) -> None:
+        for chart_type in CHART_DATA_SCHEMAS:
+            api_spec.components.schema(
+                chart_type.__name__, schema=chart_type,
+            )
+        super().add_apispec_components(api_spec)
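
Note that the response body is now wrapped in a "result" envelope (see the
simplejson.dumps change above and the updated assertion in tests/charts/api_tests.py
below). A hedged client-side sketch, reusing the payload sketched after the diffstat;
the host, port and omitted authentication are assumptions:

    import requests

    # (authentication omitted for brevity)
    resp = requests.post("http://localhost:8088/api/v1/chart/data", json=payload)
    resp.raise_for_status()
    # Previously the response was a bare list; it is now {"result": [...]}.
    for query_result in resp.json()["result"]:
        print(query_result["status"], query_result["rowcount"])
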
diff --git a/superset/charts/schemas.py b/superset/charts/schemas.py
index bf1b57b..0a7035c 100644
--- a/superset/charts/schemas.py
+++ b/superset/charts/schemas.py
@@ -14,11 +14,12 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-from typing import Union
+from typing import Any, Dict, Union
 
-from marshmallow import fields, Schema, ValidationError
+from marshmallow import fields, post_load, Schema, ValidationError
 from marshmallow.validate import Length
 
+from superset.common.query_context import QueryContext
 from superset.exceptions import SupersetException
 from superset.utils import core as utils
 
@@ -59,3 +60,510 @@ class ChartPutSchema(Schema):
     datasource_id = fields.Integer(allow_none=True)
     datasource_type = fields.String(allow_none=True)
     dashboards = fields.List(fields.Integer())
+
+
+class ChartDataColumnSchema(Schema):
+    column_name = fields.String(
+        description="The name of the target column", example="mycol",
+    )
+    type = fields.String(description="Type of target column", example="BIGINT",)
+
+
+class ChartDataAdhocMetricSchema(Schema):
+    """
+    Ad-hoc metrics are used to define metrics outside the datasource.
+    """
+
+    expressionType = fields.String(
+        description="Simple or SQL metric",
+        required=True,
+        enum=["SIMPLE", "SQL"],
+        example="SQL",
+    )
+    aggregate = fields.String(
+        description="Aggregation operator. Only required for simple expression 
types.",
+        required=False,
+        enum=["AVG", "COUNT", "COUNT_DISTINCT", "MAX", "MIN", "SUM"],
+    )
+    column = fields.Nested(ChartDataColumnSchema)
+    sqlExpression = fields.String(
+        description="The metric as defined by a SQL aggregate expression. "
+        "Only required for SQL expression type.",
+        required=False,
+        example="SUM(weight * observations) / SUM(weight)",
+    )
+    label = fields.String(
+        description="Label for the metric. Is automatically generated unless "
+        "hasCustomLabel is true, in which case label must be defined.",
+        required=False,
+        example="Weighted observations",
+    )
+    hasCustomLabel = fields.Boolean(
+        description="When false, the label will be automatically generated 
based on "
+        "the aggregate expression. When true, a custom label has to be "
+        "specified.",
+        required=False,
+        example=True,
+    )
+    optionName = fields.String(
+        description="Unique identifier. Can be any string value, as long as 
all "
+        "metrics have a unique identifier. If undefined, a random name "
+        "will be generated.",
+        required=False,
+        example="metric_aec60732-fac0-4b17-b736-93f1a5c93e30",
+    )
+
+
+class ChartDataAggregateConfigField(fields.Dict):
+    def __init__(self) -> None:
+        super().__init__(
+            description="The keys are the name of the aggregate column to be 
created, "
+            "and the values specify the details of how to apply the "
+            "aggregation. If an operator requires additional options, "
+            "these can be passed here to be unpacked in the operator call. The 
"
+            "following numpy operators are supported: average, argmin, argmax, 
cumsum, "
+            "cumprod, max, mean, median, nansum, nanmin, nanmax, nanmean, 
nanmedian, "
+            "min, percentile, prod, product, std, sum, var. Any options 
required by "
+            "the operator can be passed to the `options` object.\n"
+            "\n"
+            "In the example, a new column `first_quantile` is created based on 
values "
+            "in the column `my_col` using the `percentile` operator with "
+            "the `q=0.25` parameter.",
+            example={
+                "first_quantile": {
+                    "operator": "percentile",
+                    "column": "my_col",
+                    "options": {"q": 0.25},
+                }
+            },
+        )
+
+
+class ChartDataPostProcessingOperationOptionsSchema(Schema):
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
+        super().__init__(*args, **kwargs)
+
+
+class ChartDataAggregateOptionsSchema(ChartDataPostProcessingOperationOptionsSchema):
+    """
+    Aggregate operation config.
+    """
+
+    groupby = fields.List(
+        fields.String(allow_none=False, description="Columns by which to group",),
+        minLength=1,
+        required=True,
+    )
+    aggregates = ChartDataAggregateConfigField()
+
+
+class ChartDataRollingOptionsSchema(ChartDataPostProcessingOperationOptionsSchema):
+    """
+    Rolling operation config.
+    """
+
+    columns = fields.Dict(
+        description="Columns on which to perform rolling, mapping source column to "
+        "target column. For instance, `{'y': 'y'}` will replace the "
+        "column `y` with the rolling value in `y`, while `{'y': 'y2'}` "
+        "will add a column `y2` based on rolling values calculated "
+        "from `y`, leaving the original column `y` unchanged.",
+        example={"weekly_rolling_sales": "sales"},
+    )
+    rolling_type = fields.String(
+        description="Type of rolling window. Any numpy function will work.",
+        enum=[
+            "average",
+            "argmin",
+            "argmax",
+            "cumsum",
+            "cumprod",
+            "max",
+            "mean",
+            "median",
+            "nansum",
+            "nanmin",
+            "nanmax",
+            "nanmean",
+            "nanmedian",
+            "min",
+            "percentile",
+            "prod",
+            "product",
+            "std",
+            "sum",
+            "var",
+        ],
+        required=True,
+        example="percentile",
+    )
+    window = fields.Integer(
+        description="Size of the rolling window in days.", required=True, 
example=7,
+    )
+    rolling_type_options = fields.Dict(
+        description="Optional options to pass to rolling method. Needed for "
+        "e.g. quantile operation.",
+        required=False,
+        example={},
+    )
+    center = fields.Boolean(
+        description="Should the label be at the center of the window. Default: 
`false`",
+        required=False,
+        example=False,
+    )
+    win_type = fields.String(
+        description="Type of window function. See "
+        "[SciPy window functions](https://docs.scipy.org/doc/scipy/reference";
+        "/signal.windows.html#module-scipy.signal.windows) "
+        "for more details. Some window functions require passing "
+        "additional parameters to `rolling_type_options`. For instance, "
+        "to use `gaussian`, the parameter `std` needs to be provided.",
+        required=False,
+        enum=[
+            "boxcar",
+            "triang",
+            "blackman",
+            "hamming",
+            "bartlett",
+            "parzen",
+            "bohman",
+            "blackmanharris",
+            "nuttall",
+            "barthann",
+            "kaiser",
+            "gaussian",
+            "general_gaussian",
+            "slepian",
+            "exponential",
+        ],
+    )
+    min_periods = fields.Integer(
+        description="The minimum amount of periods required for a row to be 
included "
+        "in the result set.",
+        required=False,
+        example=7,
+    )
+
+
+class ChartDataSelectOptionsSchema(ChartDataPostProcessingOperationOptionsSchema):
+    """
+    Select operation config.
+    """
+
+    columns = fields.List(
+        fields.String(),
+        description="Columns which to select from the input data, in the 
desired "
+        "order. If columns are renamed, the old column name should be "
+        "referenced here.",
+        example=["country", "gender", "age"],
+    )
+    rename = fields.List(
+        fields.Dict(),
+        description="columns which to rename, mapping source column to target 
column. "
+        "For instance, `{'y': 'y2'}` will rename the column `y` to `y2`.",
+        example=[{"age": "average_age"}],
+    )
+
+
+class ChartDataSortOptionsSchema(ChartDataPostProcessingOperationOptionsSchema):
+    """
+    Sort operation config.
+    """
+
+    columns = fields.Dict(
+        description="columns by by which to sort. The key specifies the column 
name, "
+        "value specifies if sorting in ascending order.",
+        example={"country": True, "gender": False},
+        required=True,
+    )
+    aggregates = ChartDataAggregateConfigField()
+
+
+class ChartDataPivotOptionsSchema(ChartDataPostProcessingOperationOptionsSchema):
+    """
+    Pivot operation config.
+    """
+
+    index = fields.List(
+        fields.String(
+            allow_none=False,
+            description="Columns to group by on the table index (=rows)",
+        ),
+        minLength=1,
+        required=True,
+    )
+    columns = fields.List(
+        fields.String(
+            allow_none=False, description="Columns to group by on the table 
columns",
+        ),
+        minLength=1,
+        required=True,
+    )
+    metric_fill_value = fields.Number(
+        required=False,
+        description="Value to replace missing values with in aggregate 
calculations.",
+    )
+    column_fill_value = fields.String(
+        required=False, description="Value to replace missing pivot columns 
names with."
+    )
+    drop_missing_columns = fields.Boolean(
+        description="Do not include columns whose entries are all missing "
+        "(default: `true`).",
+        required=False,
+    )
+    marginal_distributions = fields.Boolean(
+        description="Add totals for row/column. (default: `false`)", 
required=False,
+    )
+    marginal_distribution_name = fields.String(
+        description="Name of marginal distribution row/column. (default: 
`All`)",
+        required=False,
+    )
+    aggregates = ChartDataAggregateConfigField()
+
+
+class ChartDataPostProcessingOperationSchema(Schema):
+    operation = fields.String(
+        description="Post processing operation type",
+        required=True,
+        enum=["aggregate", "pivot", "rolling", "select", "sort"],
+        example="aggregate",
+    )
+    options = fields.Nested(
+        ChartDataPostProcessingOperationOptionsSchema,
+        description="Options specifying how to perform the operation. Please 
refer "
+        "to the respective post processing operation option schemas. "
+        "For example, `ChartDataPostProcessingOperationOptions` specifies "
+        "the required options for the pivot operation.",
+        example={
+            "groupby": ["country", "gender"],
+            "aggregates": {
+                "age_q1": {
+                    "operator": "percentile",
+                    "column": "age",
+                    "options": {"q": 0.25},
+                },
+                "age_mean": {"operator": "mean", "column": "age",},
+            },
+        },
+    )
+
+
+class ChartDataFilterSchema(Schema):
+    col = fields.String(
+        description="The column to filter.", required=True, example="country"
+    )
+    op = fields.String(  # pylint: disable=invalid-name
+        description="The comparison operator.",
+        enum=[filter_op.value for filter_op in utils.FilterOperationType],
+        required=True,
+        example="IN",
+    )
+    val = fields.Raw(
+        description="The value or values to compare against. Can be a string, "
+        "integer, decimal or list, depending on the operator.",
+        example=["China", "France", "Japan"],
+    )
+
+
+class ChartDataExtrasSchema(Schema):
+
+    time_range_endpoints = fields.List(
+        fields.String(enum=["INCLUSIVE", "EXCLUSIVE"]),
+        description="A list with two values, stating if start/end should be "
+        "inclusive/exclusive.",
+        required=False,
+    )
+    relative_start = fields.String(
+        description="Start time for relative time deltas. "
+        'Default: `config["DEFAULT_RELATIVE_START_TIME"]`',
+        enum=["today", "now"],
+        required=False,
+    )
+    relative_end = fields.String(
+        description="End time for relative time deltas. "
+        'Default: `config["DEFAULT_RELATIVE_END_TIME"]`',
+        enum=["today", "now"],
+        required=False,
+    )
+
+
+class ChartDataQueryObjectSchema(Schema):
+    filters = fields.List(fields.Nested(ChartDataFilterSchema), required=False)
+    granularity = fields.String(
+        description="To what level of granularity should the temporal column 
be "
+        "aggregated. Supports "
+        "[ISO 8601](https://en.wikipedia.org/wiki/ISO_8601#Durations) "
+        "durations.",
+        enum=[
+            "PT1S",
+            "PT1M",
+            "PT5M",
+            "PT10M",
+            "PT15M",
+            "PT0.5H",
+            "PT1H",
+            "P1D",
+            "P1W",
+            "P1M",
+            "P0.25Y",
+            "P1Y",
+        ],
+        required=False,
+        example="P1D",
+    )
+    groupby = fields.List(
+        fields.String(description="Columns by which to group the query.",),
+    )
+    metrics = fields.List(
+        fields.Raw(),
+        description="Aggregate expressions. Metrics can be passed as both "
+        "references to datasource metrics (strings), or ad-hoc metrics"
+        "which are defined only within the query object. See "
+        "`ChartDataAdhocMetricSchema` for the structure of ad-hoc metrics.",
+    )
+    post_processing = fields.List(
+        fields.Nested(ChartDataPostProcessingOperationSchema),
+        description="Post processing operations to be applied to the result 
set. "
+        "Operations are applied to the result set in sequential order.",
+        required=False,
+    )
+    time_range = fields.String(
+        description="A time rage, either expressed as a colon separated string 
"
+        "`since : until`. Valid formats for `since` and `until` are: \n"
+        "- ISO 8601\n"
+        "- X days/years/hours/day/year/weeks\n"
+        "- X days/years/hours/day/year/weeks ago\n"
+        "- X days/years/hours/day/year/weeks from now\n"
+        "\n"
+        "Additionally, the following freeform can be used:\n"
+        "\n"
+        "- Last day\n"
+        "- Last week\n"
+        "- Last month\n"
+        "- Last quarter\n"
+        "- Last year\n"
+        "- No filter\n"
+        "- Last X seconds/minutes/hours/days/weeks/months/years\n"
+        "- Next X seconds/minutes/hours/days/weeks/months/years\n",
+        required=False,
+        example="Last week",
+    )
+    time_shift = fields.String(
+        description="A human-readable date/time string. "
+        "Please refer to [parsdatetime](https://github.com/bear/parsedatetime) 
"
+        "documentation for details on valid values.",
+        required=False,
+    )
+    is_timeseries = fields.Boolean(
+        description="Is the `query_object` a timeseries.", required=False
+    )
+    timeseries_limit = fields.Integer(
+        description="Maximum row count for timeseries queries. Default: `0`",
+        required=False,
+    )
+    row_limit = fields.Integer(
+        description='Maximum row count. Default: `config["ROW_LIMIT"]`', required=False,
+    )
+    order_desc = fields.Boolean(
+        description="Reverse order. Default: `false`", required=False
+    )
+    extras = fields.Dict(description="Default: `{}`", required=False)
+    columns = fields.List(fields.String(), description="", required=False,)
+    orderby = fields.List(
+        fields.List(fields.Raw()),
+        description="Expects a list of lists where the first element is the 
column "
+        "name which to sort by, and the second element is a boolean ",
+        required=False,
+        example=[["my_col_1", False], ["my_col_2", True]],
+    )
+
+
+class ChartDataDatasourceSchema(Schema):
+    description = "Chart datasource"
+    id = fields.Integer(description="Datasource id", required=True,)
+    type = fields.String(description="Datasource type", enum=["druid", "sql"])
+
+
+class ChartDataQueryContextSchema(Schema):
+    datasource = fields.Nested(ChartDataDatasourceSchema)
+    queries = fields.List(fields.Nested(ChartDataQueryObjectSchema))
+
+    # pylint: disable=no-self-use
+    @post_load
+    def make_query_context(self, data: Dict[str, Any]) -> QueryContext:
+        query_context = QueryContext(**data)
+        return query_context
+
+    # pylint: enable=no-self-use
+
+
+class ChartDataResponseResult(Schema):
+    cache_key = fields.String(
+        description="Unique cache key for query object", required=True, 
allow_none=True,
+    )
+    cached_dttm = fields.String(
+        description="Cache timestamp", required=True, allow_none=True,
+    )
+    cache_timeout = fields.Integer(
+        description="Cache timeout in following order: custom timeout, 
datasource "
+        "timeout, default config timeout.",
+        required=True,
+        allow_none=True,
+    )
+    error = fields.String(description="Error", allow_none=True,)
+    is_cached = fields.Boolean(
+        description="Is the result cached", required=True, allow_none=None,
+    )
+    query = fields.String(
+        description="The executed query statement", required=True, 
allow_none=False,
+    )
+    status = fields.String(
+        description="Status of the query",
+        enum=[
+            "stopped",
+            "failed",
+            "pending",
+            "running",
+            "scheduled",
+            "success",
+            "timed_out",
+        ],
+        allow_none=False,
+    )
+    stacktrace = fields.String(
+        description="Stacktrace if there was an error", allow_none=True,
+    )
+    rowcount = fields.Integer(
+        description="Amount of rows in result set", allow_none=False,
+    )
+    data = fields.List(fields.Dict(), description="A list with results")
+
+
+class ChartDataResponseSchema(Schema):
+    result = fields.List(
+        fields.Nested(ChartDataResponseResult),
+        description="A list of results for each corresponding query in the 
request.",
+    )
+
+
+CHART_DATA_SCHEMAS = (
+    ChartDataQueryContextSchema,
+    ChartDataResponseSchema,
+    # TODO: These should optimally be included in the QueryContext schema as an `anyOf`
+    #  in ChartDataPostProcessingOperation.options, but since `anyOf` is not supported
+    #  by Marshmallow<3, this is not currently possible.
+    ChartDataAdhocMetricSchema,
+    ChartDataAggregateOptionsSchema,
+    ChartDataPivotOptionsSchema,
+    ChartDataRollingOptionsSchema,
+    ChartDataSelectOptionsSchema,
+    ChartDataSortOptionsSchema,
+)
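
Because this targets Marshmallow<3, ChartDataQueryContextSchema().load() returns a
(data, errors) tuple, and the @post_load hook replaces the validated dict with a
QueryContext instance. A self-contained sketch of the same pattern, with a stand-in
class instead of superset's QueryContext:

    from marshmallow import fields, post_load, Schema

    class Point:
        def __init__(self, x, y):
            self.x, self.y = x, y

    class PointSchema(Schema):
        x = fields.Integer(required=True)
        y = fields.Integer(required=True)

        @post_load
        def make_point(self, data):
            # Swap the validated dict for a domain object, as above.
            return Point(**data)

    point, errors = PointSchema().load({"x": 1, "y": 2})
    assert not errors
    assert (point.x, point.y) == (1, 2)
    # On invalid input, load() instead populates the errors dict, which the
    # endpoint above turns into a 400 response.
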
diff --git a/superset/common/query_object.py b/superset/common/query_object.py
index 63158c6..31a6241 100644
--- a/superset/common/query_object.py
+++ b/superset/common/query_object.py
@@ -47,9 +47,9 @@ class QueryObject:
     is_timeseries: bool
     time_shift: Optional[timedelta]
     groupby: List[str]
-    metrics: List[Union[Dict, str]]
+    metrics: List[Union[Dict[str, Any], str]]
     row_limit: int
-    filter: List[str]
+    filter: List[Dict[str, Any]]
     timeseries_limit: int
     timeseries_limit_metric: Optional[Dict]
     order_desc: bool
@@ -61,9 +61,9 @@ class QueryObject:
     def __init__(
         self,
         granularity: str,
-        metrics: List[Union[Dict, str]],
+        metrics: List[Union[Dict[str, Any], str]],
         groupby: Optional[List[str]] = None,
-        filters: Optional[List[str]] = None,
+        filters: Optional[List[Dict[str, Any]]] = None,
         time_range: Optional[str] = None,
         time_shift: Optional[str] = None,
         is_timeseries: bool = False,
@@ -75,14 +75,17 @@ class QueryObject:
         columns: Optional[List[str]] = None,
         orderby: Optional[List[List]] = None,
         post_processing: Optional[List[Dict[str, Any]]] = None,
-        relative_start: str = app.config["DEFAULT_RELATIVE_START_TIME"],
-        relative_end: str = app.config["DEFAULT_RELATIVE_END_TIME"],
     ):
+        extras = extras or {}
         is_sip_38 = is_feature_enabled("SIP_38_VIZ_REARCHITECTURE")
         self.granularity = granularity
         self.from_dttm, self.to_dttm = utils.get_since_until(
-            relative_start=relative_start,
-            relative_end=relative_end,
+            relative_start=extras.get(
+                "relative_start", app.config["DEFAULT_RELATIVE_START_TIME"]
+            ),
+            relative_end=extras.get(
+                "relative_end", app.config["DEFAULT_RELATIVE_END_TIME"]
+            ),
             time_range=time_range,
             time_shift=time_shift,
         )
@@ -106,7 +109,7 @@ class QueryObject:
         self.timeseries_limit = timeseries_limit
         self.timeseries_limit_metric = timeseries_limit_metric
         self.order_desc = order_desc
-        self.extras = extras or {}
+        self.extras = extras
 
         if app.config["SIP_15_ENABLED"] and "time_range_endpoints" not in 
self.extras:
             self.extras["time_range_endpoints"] = 
get_time_range_endpoints(form_data={})
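
The relative start/end defaults now travel inside the `extras` dict instead of being
dedicated constructor arguments. A minimal sketch of the new lookup, with a plain
dict standing in for superset's app.config:

    config = {
        "DEFAULT_RELATIVE_START_TIME": "today",
        "DEFAULT_RELATIVE_END_TIME": "today",
    }

    def resolve_relative_bounds(extras=None):
        # Mirrors the extras.get(...) fallbacks in the diff above.
        extras = extras or {}
        return (
            extras.get("relative_start", config["DEFAULT_RELATIVE_START_TIME"]),
            extras.get("relative_end", config["DEFAULT_RELATIVE_END_TIME"]),
        )

    assert resolve_relative_bounds() == ("today", "today")
    assert resolve_relative_bounds({"relative_end": "now"}) == ("today", "now")
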
diff --git a/superset/connectors/base/models.py b/superset/connectors/base/models.py
index 8e1acc7..2b6e0d2 100644
--- a/superset/connectors/base/models.py
+++ b/superset/connectors/base/models.py
@@ -25,6 +25,7 @@ from sqlalchemy.orm import foreign, Query, relationship
 from superset.constants import NULL_STRING
 from superset.models.helpers import AuditMixinNullable, ImportMixin, QueryResult
 from superset.models.slice import Slice
+from superset.typing import FilterValue, FilterValues
 from superset.utils import core as utils
 
 METRIC_FORM_DATA_PARAMS = [
@@ -301,28 +302,33 @@ class BaseDatasource(
 
     @staticmethod
     def filter_values_handler(
-        values, target_column_is_numeric=False, is_list_target=False
-    ):
-        def handle_single_value(v):
+        values: Optional[FilterValues],
+        target_column_is_numeric: bool = False,
+        is_list_target: bool = False,
+    ) -> Optional[FilterValues]:
+        if values is None:
+            return None
+
+        def handle_single_value(value: Optional[FilterValue]) -> Optional[FilterValue]:
             # backward compatibility with previous <select> components
-            if isinstance(v, str):
-                v = v.strip("\t\n'\"")
+            if isinstance(value, str):
+                value = value.strip("\t\n'\"")
                 if target_column_is_numeric:
                     # For backwards compatibility and edge cases
                     # where a column data type might have changed
-                    v = utils.string_to_num(v)
-                if v == NULL_STRING:
+                    value = utils.cast_to_num(value)
+                if value == NULL_STRING:
                     return None
-                elif v == "<empty string>":
+                elif value == "<empty string>":
                     return ""
-            return v
+            return value
 
         if isinstance(values, (list, tuple)):
-            values = [handle_single_value(v) for v in values]
+            values = [handle_single_value(v) for v in values]  # type: ignore
         else:
             values = handle_single_value(values)
         if is_list_target and not isinstance(values, (tuple, list)):
-            values = [values]
+            values = [values]  # type: ignore
         elif not is_list_target and isinstance(values, (tuple, list)):
             if values:
                 values = values[0]
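
A hedged usage sketch of the now fully typed helper (behavior inferred from the diff
above; running it assumes an importable, configured Superset environment):

    from superset.connectors.base.models import BaseDatasource

    # New early return for missing values.
    assert BaseDatasource.filter_values_handler(values=None) is None
    # Legacy <select> quoting is stripped and numeric columns are cast.
    assert BaseDatasource.filter_values_handler(
        values="'5'", target_column_is_numeric=True
    ) == 5
    # Scalars are wrapped when the target expects a list.
    assert BaseDatasource.filter_values_handler(values="x", is_list_target=True) == ["x"]
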
diff --git a/superset/connectors/druid/models.py b/superset/connectors/druid/models.py
index a4cc527..20dd732 100644
--- a/superset/connectors/druid/models.py
+++ b/superset/connectors/druid/models.py
@@ -24,7 +24,7 @@ from copy import deepcopy
 from datetime import datetime, timedelta
 from distutils.version import LooseVersion
 from multiprocessing.pool import ThreadPool
-from typing import Dict, Iterable, List, Optional, Set, Tuple, Union
+from typing import cast, Dict, Iterable, List, Optional, Set, Tuple, Union
 
 import pandas as pd
 import sqlalchemy as sa
@@ -54,6 +54,7 @@ from superset.constants import NULL_STRING
 from superset.exceptions import SupersetException
 from superset.models.core import Database
 from superset.models.helpers import AuditMixinNullable, ImportMixin, QueryResult
+from superset.typing import FilterValues
 from superset.utils import core as utils, import_datasource
 
 try:
@@ -80,7 +81,12 @@ except ImportError:
     pass
 
 try:
-    from superset.utils.core import DimSelector, DTTM_ALIAS, flasher
+    from superset.utils.core import (
+        DimSelector,
+        DTTM_ALIAS,
+        FilterOperationType,
+        flasher,
+    )
 except ImportError:
     pass
 
@@ -1483,13 +1489,20 @@ class DruidDatasource(Model, BaseDatasource):
         """Given Superset filter data structure, returns pydruid Filter(s)"""
         filters = None
         for flt in raw_filters:
-            col = flt.get("col")
-            op = flt.get("op")
-            eq = flt.get("val")
+            col: Optional[str] = flt.get("col")
+            op: Optional[str] = flt["op"].upper() if "op" in flt else None
+            eq: Optional[FilterValues] = flt.get("val")
             if (
                 not col
                 or not op
-                or (eq is None and op not in ("IS NULL", "IS NOT NULL"))
+                or (
+                    eq is None
+                    and op
+                    not in (
+                        FilterOperationType.IS_NULL.value,
+                        FilterOperationType.IS_NOT_NULL.value,
+                    )
+                )
             ):
                 continue
 
@@ -1503,7 +1516,10 @@ class DruidDatasource(Model, BaseDatasource):
 
             cond = None
             is_numeric_col = col in num_cols
-            is_list_target = op in ("in", "not in")
+            is_list_target = op in (
+                FilterOperationType.IN.value,
+                FilterOperationType.NOT_IN.value,
+            )
             eq = cls.filter_values_handler(
                 eq,
                 is_list_target=is_list_target,
@@ -1512,15 +1528,16 @@ class DruidDatasource(Model, BaseDatasource):
 
             # For these two ops, could have used Dimension,
             # but it doesn't support extraction functions
-            if op == "==":
+            if op == FilterOperationType.EQUALS.value:
                 cond = Filter(
                     dimension=col, value=eq, extraction_function=extraction_fn
                 )
-            elif op == "!=":
+            elif op == FilterOperationType.NOT_EQUALS.value:
                 cond = ~Filter(
                     dimension=col, value=eq, extraction_function=extraction_fn
                 )
-            elif op in ("in", "not in"):
+            elif is_list_target:
+                eq = cast(list, eq)
                 fields = []
                 # ignore the filter if it has no value
                 if not len(eq):
@@ -1540,9 +1557,9 @@ class DruidDatasource(Model, BaseDatasource):
                     for s in eq:
                         fields.append(Dimension(col) == s)
                     cond = Filter(type="or", fields=fields)
-                if op == "not in":
+                if op == FilterOperationType.NOT_IN.value:
                     cond = ~cond
-            elif op == "regex":
+            elif op == FilterOperationType.REGEX.value:
                 cond = Filter(
                     extraction_function=extraction_fn,
                     type="regex",
@@ -1552,7 +1569,7 @@ class DruidDatasource(Model, BaseDatasource):
 
             # For the ops below, could have used pydruid's Bound,
             # but it doesn't support extraction functions
-            elif op == ">=":
+            elif op == FilterOperationType.GREATER_THAN_OR_EQUALS.value:
                 cond = Bound(
                     extraction_function=extraction_fn,
                     dimension=col,
@@ -1562,7 +1579,7 @@ class DruidDatasource(Model, BaseDatasource):
                     upper=None,
                     ordering=cls._get_ordering(is_numeric_col),
                 )
-            elif op == "<=":
+            elif op == FilterOperationType.LESS_THAN_OR_EQUALS.value:
                 cond = Bound(
                     extraction_function=extraction_fn,
                     dimension=col,
@@ -1572,7 +1589,7 @@ class DruidDatasource(Model, BaseDatasource):
                     upper=eq,
                     ordering=cls._get_ordering(is_numeric_col),
                 )
-            elif op == ">":
+            elif op == FilterOperationType.GREATER_THAN.value:
                 cond = Bound(
                     extraction_function=extraction_fn,
                     lowerStrict=True,
@@ -1582,7 +1599,7 @@ class DruidDatasource(Model, BaseDatasource):
                     upper=None,
                     ordering=cls._get_ordering(is_numeric_col),
                 )
-            elif op == "<":
+            elif op == FilterOperationType.LESS_THAN.value:
                 cond = Bound(
                     extraction_function=extraction_fn,
                     upperStrict=True,
@@ -1592,9 +1609,9 @@ class DruidDatasource(Model, BaseDatasource):
                     upper=eq,
                     ordering=cls._get_ordering(is_numeric_col),
                 )
-            elif op == "IS NULL":
+            elif op == FilterOperationType.IS_NULL.value:
                 cond = Filter(dimension=col, value="")
-            elif op == "IS NOT NULL":
+            elif op == FilterOperationType.IS_NOT_NULL.value:
                 cond = ~Filter(dimension=col, value="")
 
             if filters:
@@ -1610,21 +1627,25 @@ class DruidDatasource(Model, BaseDatasource):
 
     def _get_having_obj(self, col: str, op: str, eq: str) -> "Having":
         cond = None
-        if op == "==":
+        if op == FilterOperationType.EQUALS.value:
             if col in self.column_names:
                 cond = DimSelector(dimension=col, value=eq)
             else:
                 cond = Aggregation(col) == eq
-        elif op == ">":
+        elif op == FilterOperationType.GREATER_THAN.value:
             cond = Aggregation(col) > eq
-        elif op == "<":
+        elif op == FilterOperationType.LESS_THAN.value:
             cond = Aggregation(col) < eq
 
         return cond
 
     def get_having_filters(self, raw_filters: List[Dict]) -> "Having":
         filters = None
-        reversed_op_map = {"!=": "==", ">=": "<", "<=": ">"}
+        reversed_op_map = {
+            FilterOperationType.NOT_EQUALS.value: FilterOperationType.EQUALS.value,
+            FilterOperationType.GREATER_THAN_OR_EQUALS.value: FilterOperationType.LESS_THAN.value,
+            FilterOperationType.LESS_THAN_OR_EQUALS.value: FilterOperationType.GREATER_THAN.value,
+        }
 
         for flt in raw_filters:
             if not all(f in flt for f in ["col", "op", "val"]):
@@ -1633,7 +1654,11 @@ class DruidDatasource(Model, BaseDatasource):
             op = flt["op"]
             eq = flt["val"]
             cond = None
-            if op in ["==", ">", "<"]:
+            if op in [
+                FilterOperationType.EQUALS.value,
+                FilterOperationType.GREATER_THAN.value,
+                FilterOperationType.LESS_THAN.value,
+            ]:
                 cond = self._get_having_obj(col, op, eq)
             elif op in reversed_op_map:
                 cond = ~self._get_having_obj(col, reversed_op_map[op], eq)
diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py
index c290363..8aac7e0 100644
--- a/superset/connectors/sqla/models.py
+++ b/superset/connectors/sqla/models.py
@@ -843,43 +843,53 @@ class SqlaTable(Model, BaseDatasource):
             if not all([flt.get(s) for s in ["col", "op"]]):
                 continue
             col = flt["col"]
-            op = flt["op"]
+            op = flt["op"].upper()
             col_obj = cols.get(col)
             if col_obj:
-                is_list_target = op in ("in", "not in")
+                is_list_target = op in (
+                    utils.FilterOperationType.IN.value,
+                    utils.FilterOperationType.NOT_IN.value,
+                )
                 eq = self.filter_values_handler(
-                    flt.get("val"),
+                    values=flt.get("val"),
                     target_column_is_numeric=col_obj.is_numeric,
                     is_list_target=is_list_target,
                 )
-                if op in ("in", "not in"):
+                if op in (
+                    utils.FilterOperationType.IN.value,
+                    utils.FilterOperationType.NOT_IN.value,
+                ):
                     cond = col_obj.get_sqla_col().in_(eq)
-                    if NULL_STRING in eq:
-                        cond = or_(cond, col_obj.get_sqla_col() == None)
-                    if op == "not in":
+                    if isinstance(eq, str) and NULL_STRING in eq:
+                        cond = or_(cond, col_obj.get_sqla_col() == None)
+                    if op == utils.FilterOperationType.NOT_IN.value:
                         cond = ~cond
                     where_clause_and.append(cond)
                 else:
                     if col_obj.is_numeric:
-                        eq = utils.string_to_num(flt["val"])
-                    if op == "==":
+                        eq = utils.cast_to_num(flt["val"])
+                    if op == utils.FilterOperationType.EQUALS.value:
                         where_clause_and.append(col_obj.get_sqla_col() == eq)
-                    elif op == "!=":
+                    elif op == utils.FilterOperationType.NOT_EQUALS.value:
                         where_clause_and.append(col_obj.get_sqla_col() != eq)
-                    elif op == ">":
+                    elif op == utils.FilterOperationType.GREATER_THAN.value:
                         where_clause_and.append(col_obj.get_sqla_col() > eq)
-                    elif op == "<":
+                    elif op == utils.FilterOperationType.LESS_THAN.value:
                         where_clause_and.append(col_obj.get_sqla_col() < eq)
-                    elif op == ">=":
+                    elif op == utils.FilterOperationType.GREATER_THAN_OR_EQUALS.value:
                         where_clause_and.append(col_obj.get_sqla_col() >= eq)
-                    elif op == "<=":
+                    elif op == utils.FilterOperationType.LESS_THAN_OR_EQUALS.value:
                         where_clause_and.append(col_obj.get_sqla_col() <= eq)
-                    elif op == "LIKE":
+                    elif op == utils.FilterOperationType.LIKE.value:
                         where_clause_and.append(col_obj.get_sqla_col().like(eq))
-                    elif op == "IS NULL":
-                        where_clause_and.append(col_obj.get_sqla_col() == None)
-                    elif op == "IS NOT NULL":
-                        where_clause_and.append(col_obj.get_sqla_col() != None)
+                    elif op == utils.FilterOperationType.IS_NULL.value:
+                        where_clause_and.append(col_obj.get_sqla_col() == None)
+                    elif op == utils.FilterOperationType.IS_NOT_NULL.value:
+                        where_clause_and.append(col_obj.get_sqla_col() != None)
+                    else:
+                        raise Exception(
+                            _("Invalid filter operation type: %(op)s", op=op)
+                        )
 
         where_clause_and += self._get_sqla_row_level_filters(template_processor)
         if extras:
diff --git a/superset/examples/birth_names.py b/superset/examples/birth_names.py
index ade420b..58657c6 100644
--- a/superset/examples/birth_names.py
+++ b/superset/examples/birth_names.py
@@ -189,7 +189,7 @@ def load_birth_names(only_metadata: bool = False, force: bool = False) -> None:
                         "expressionType": "SIMPLE",
                         "filterOptionName": "2745eae5",
                         "comparator": ["other"],
-                        "operator": "not in",
+                        "operator": "NOT IN",
                         "subject": "state",
                     }
                 ],
diff --git a/superset/examples/world_bank.py b/superset/examples/world_bank.py
index 0136ed9..1764e11 100644
--- a/superset/examples/world_bank.py
+++ b/superset/examples/world_bank.py
@@ -249,7 +249,7 @@ def load_world_bank_health_n_pop(  # pylint: disable=too-many-locals
                             "AMA",
                             "PLW",
                         ],
-                        "operator": "not in",
+                        "operator": "NOT IN",
                         "subject": "country_code",
                     }
                 ],
diff --git a/superset/typing.py b/superset/typing.py
index c84e6d3..b6686ec 100644
--- a/superset/typing.py
+++ b/superset/typing.py
@@ -25,4 +25,6 @@ DbapiDescriptionRow = Tuple[
 ]
 DbapiDescription = Union[List[DbapiDescriptionRow], Tuple[DbapiDescriptionRow, ...]]
 DbapiResult = List[Union[List[Any], Tuple[Any, ...]]]
+FilterValue = Union[float, int, str]
+FilterValues = Union[FilterValue, List[FilterValue], Tuple[FilterValue]]
 VizData = Optional[Union[List[Any], Dict[Any, Any]]]
diff --git a/superset/utils/core.py b/superset/utils/core.py
index ac2e92e..5749930 100644
--- a/superset/utils/core.py
+++ b/superset/utils/core.py
@@ -198,28 +198,30 @@ def parse_js_uri_path_item(
     return unquote_plus(item) if unquote and item else item
 
 
-def string_to_num(s: str):
-    """Converts a string to an int/float
+def cast_to_num(value: Union[float, int, str]) -> Optional[Union[float, int]]:
+    """Casts a value to an int/float
 
-    Returns ``None`` if it can't be converted
-
-    >>> string_to_num('5')
+    >>> cast_to_num('5')
     5
-    >>> string_to_num('5.2')
+    >>> cast_to_num('5.2')
     5.2
-    >>> string_to_num(10)
+    >>> cast_to_num(10)
     10
-    >>> string_to_num(10.1)
+    >>> cast_to_num(10.1)
     10.1
-    >>> string_to_num('this is not a string') is None
+    >>> cast_to_num('this is not a string') is None
     True
+
+    :param value: value to be converted to numeric representation
+    :returns: value cast to `int` if value is all digits, `float` if `value` is
+              a decimal value, and `None` if it can't be converted
     """
-    if isinstance(s, (int, float)):
-        return s
-    if s.isdigit():
-        return int(s)
+    if isinstance(value, (int, float)):
+        return value
+    if value.isdigit():
+        return int(value)
     try:
-        return float(s)
+        return float(value)
     except ValueError:
         return None
 
@@ -1346,3 +1348,22 @@ class DbColumnType(Enum):
     NUMERIC = 0
     STRING = 1
     TEMPORAL = 2
+
+
+class FilterOperationType(str, Enum):
+    """
+    Filter operation type
+    """
+
+    EQUALS = "=="
+    NOT_EQUALS = "!="
+    GREATER_THAN = ">"
+    LESS_THAN = "<"
+    GREATER_THAN_OR_EQUALS = ">="
+    LESS_THAN_OR_EQUALS = "<="
+    LIKE = "LIKE"
+    IS_NULL = "IS NULL"
+    IS_NOT_NULL = "IS NOT NULL"
+    IN = "IN"
+    NOT_IN = "NOT IN"
+    REGEX = "REGEX"
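
Because FilterOperationType subclasses str, its member values are plain strings,
which is what lets the callers above normalize legacy lowercase operators with
.upper() and compare them against .value. A small sketch:

    from superset.utils.core import FilterOperationType

    op = "not in".upper()
    assert op == FilterOperationType.NOT_IN.value
    # Same membership test as in the sqla and druid connectors above.
    is_list_target = op in (
        FilterOperationType.IN.value,
        FilterOperationType.NOT_IN.value,
    )
    assert is_list_target
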
diff --git a/superset/utils/pandas_postprocessing.py b/superset/utils/pandas_postprocessing.py
index 2800ee1..f2a688c 100644
--- a/superset/utils/pandas_postprocessing.py
+++ b/superset/utils/pandas_postprocessing.py
@@ -96,7 +96,7 @@ def _get_aggregate_funcs(
     aggregators. Currently only numpy aggregators are supported.
 
     :param df: DataFrame on which to perform aggregate operation.
-    :param aggregates: Mapping from column name to aggregat config.
+    :param aggregates: Mapping from column name to aggregate config.
     :return: Mapping from metric name to function that takes a single input argument.
     """
     agg_funcs: Dict[str, NamedAgg] = {}
@@ -276,12 +276,13 @@ def rolling(  # pylint: disable=too-many-arguments
           on rolling values calculated from `y`, leaving the original column `y`
            unchanged.
     :param rolling_type: Type of rolling window. Any numpy function will work.
+    :param window: Size of the window.
     :param rolling_type_options: Optional options to pass to rolling method. Needed
            for e.g. quantile operation.
     :param center: Should the label be at the center of the window.
     :param win_type: Type of window function.
-    :param window: Size of the window.
-    :param min_periods:
+    :param min_periods: The minimum number of periods required for a row to be included
+                        in the result set.
     :return: DataFrame with the rolling columns
     :raises ChartDataValidationError: If the request is incorrect
     """
@@ -332,7 +333,7 @@ def select(
 
     :param df: DataFrame on which the rolling period will be based.
     :param columns: Columns which to select from the DataFrame, in the desired order.
-                    If columns are renamed, the new column name should be referenced
+                    If columns are renamed, the old column name should be referenced
                     here.
     :param rename: columns which to rename, mapping source column to target column.
                    For instance, `{'y': 'y2'}` will rename the column `y` to
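
For intuition, a pandas-only sketch of what the documented rolling options map onto
(an assumption for illustration: the helper ultimately delegates to
DataFrame.rolling with these keyword arguments):

    import pandas as pd

    df = pd.DataFrame({"sales": range(1, 11)})
    # window=7 with min_periods=7 leaves the first six rows as NaN.
    df["weekly_rolling_sales"] = df["sales"].rolling(window=7, min_periods=7).sum()
    print(df.tail(3))
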
diff --git a/tests/charts/api_tests.py b/tests/charts/api_tests.py
index 257b89b..1f64bae 100644
--- a/tests/charts/api_tests.py
+++ b/tests/charts/api_tests.py
@@ -657,7 +657,18 @@ class ChartApiTests(SupersetTestCase, ApiOwnersTestCaseMixin):
         rv = self.client.post(uri, json=query_context)
         self.assertEqual(rv.status_code, 200)
         data = json.loads(rv.data.decode("utf-8"))
-        self.assertEqual(data[0]["rowcount"], 100)
+        self.assertEqual(data["result"][0]["rowcount"], 100)
+
+    def test_invalid_chart_data(self):
+        """
+            Query API: Test chart data query with invalid payload
+        """
+        self.login(username="admin")
+        query_context = self._get_query_context()
+        query_context["datasource"] = "abc"
+        uri = "api/v1/chart/data"
+        rv = self.client.post(uri, json=query_context)
+        self.assertEqual(rv.status_code, 400)
 
     def test_query_exec_not_allowed(self):
         """
