This is an automated email from the ASF dual-hosted git repository.

villebro pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-superset.git


The following commit(s) were added to refs/heads/master by this push:
     new a6cedaa  chore: Improve chart data API + schemas + tests (#9599)
a6cedaa is described below

commit a6cedaaa879348aca49a520793bb20e63d152a1f
Author: Ville Brofeldt <33317356+ville...@users.noreply.github.com>
AuthorDate: Thu Apr 23 14:30:48 2020 +0300

    chore: Improve chart data API + schemas + tests (#9599)
    
    * Make all fields optional in QueryObject and fix having_druid schema
    
    * fix: datasource type sql to table
    
    * lint
    
    * Add missing fields
    
    * Refactor tests
    
    * Linting
    
    * Refactor query context fixtures
    
    * Add typing to test func
---
 superset/charts/schemas.py         | 208 ++++++++++++++++++++++---------------
 superset/common/query_object.py    |  61 ++++++++---
 superset/connectors/base/models.py |   4 +-
 tests/access_tests.py              |   4 +-
 tests/base_tests.py                |  21 ++--
 tests/charts/api_tests.py          |  63 +++++------
 tests/core_tests.py                | 111 --------------------
 tests/dict_import_export_tests.py  |  12 +--
 tests/fixtures/query_context.py    | 103 ++++++++++++++++++
 tests/import_export_tests.py       |  10 +-
 tests/query_context_tests.py       |  94 +++++++++++++++++
 11 files changed, 421 insertions(+), 270 deletions(-)

diff --git a/superset/charts/schemas.py b/superset/charts/schemas.py
index 2743732..d7438aa 100644
--- a/superset/charts/schemas.py
+++ b/superset/charts/schemas.py
@@ -16,7 +16,7 @@
 # under the License.
 from typing import Any, Dict, Union
 
-from marshmallow import fields, post_load, Schema, ValidationError
+from marshmallow import fields, post_load, Schema, validate, ValidationError
 from marshmallow.validate import Length
 
 from superset.common.query_context import QueryContext
@@ -77,13 +77,15 @@ class ChartDataAdhocMetricSchema(Schema):
     expressionType = fields.String(
         description="Simple or SQL metric",
         required=True,
-        enum=["SIMPLE", "SQL"],
+        validate=validate.OneOf(choices=("SIMPLE", "SQL")),
         example="SQL",
     )
     aggregate = fields.String(
         description="Aggregation operator. Only required for simple expression 
types.",
         required=False,
-        enum=["AVG", "COUNT", "COUNT_DISTINCT", "MAX", "MIN", "SUM"],
+        validate=validate.OneOf(
+            choices=("AVG", "COUNT", "COUNT_DISTINCT", "MAX", "MIN", "SUM")
+        ),
     )
     column = fields.Nested(ChartDataColumnSchema)
     sqlExpression = fields.String(
@@ -178,28 +180,30 @@ class ChartDataRollingOptionsSchema(ChartDataPostProcessingOperationOptionsSchem
     )
     rolling_type = fields.String(
         description="Type of rolling window. Any numpy function will work.",
-        enum=[
-            "average",
-            "argmin",
-            "argmax",
-            "cumsum",
-            "cumprod",
-            "max",
-            "mean",
-            "median",
-            "nansum",
-            "nanmin",
-            "nanmax",
-            "nanmean",
-            "nanmedian",
-            "min",
-            "percentile",
-            "prod",
-            "product",
-            "std",
-            "sum",
-            "var",
-        ],
+        validate=validate.OneOf(
+            choices=(
+                "average",
+                "argmin",
+                "argmax",
+                "cumsum",
+                "cumprod",
+                "max",
+                "mean",
+                "median",
+                "nansum",
+                "nanmin",
+                "nanmax",
+                "nanmean",
+                "nanmedian",
+                "min",
+                "percentile",
+                "prod",
+                "product",
+                "std",
+                "sum",
+                "var",
+            )
+        ),
         required=True,
         example="percentile",
     )
@@ -225,23 +229,25 @@ class ChartDataRollingOptionsSchema(ChartDataPostProcessingOperationOptionsSchem
         "additional parameters to `rolling_type_options`. For instance, "
         "to use `gaussian`, the parameter `std` needs to be provided.",
         required=False,
-        enum=[
-            "boxcar",
-            "triang",
-            "blackman",
-            "hamming",
-            "bartlett",
-            "parzen",
-            "bohman",
-            "blackmanharris",
-            "nuttall",
-            "barthann",
-            "kaiser",
-            "gaussian",
-            "general_gaussian",
-            "slepian",
-            "exponential",
-        ],
+        validate=validate.OneOf(
+            choices=(
+                "boxcar",
+                "triang",
+                "blackman",
+                "hamming",
+                "bartlett",
+                "parzen",
+                "bohman",
+                "blackmanharris",
+                "nuttall",
+                "barthann",
+                "kaiser",
+                "gaussian",
+                "general_gaussian",
+                "slepian",
+                "exponential",
+            )
+        ),
     )
     min_periods = fields.Integer(
         description="The minimum amount of periods required for a row to be 
included "
@@ -333,7 +339,9 @@ class ChartDataPostProcessingOperationSchema(Schema):
     operation = fields.String(
         description="Post processing operation type",
         required=True,
-        enum=["aggregate", "pivot", "rolling", "select", "sort"],
+        validate=validate.OneOf(
+            choices=("aggregate", "pivot", "rolling", "select", "sort")
+        ),
         example="aggregate",
     )
     options = fields.Nested(
@@ -362,7 +370,9 @@ class ChartDataFilterSchema(Schema):
     )
     op = fields.String(  # pylint: disable=invalid-name
         description="The comparison operator.",
-        enum=[filter_op.value for filter_op in utils.FilterOperator],
+        validate=validate.OneOf(
+            choices=[filter_op.value for filter_op in utils.FilterOperator]
+        ),
         required=True,
         example="IN",
     )
@@ -376,21 +386,23 @@ class ChartDataFilterSchema(Schema):
 class ChartDataExtrasSchema(Schema):
 
     time_range_endpoints = fields.List(
-        fields.String(enum=["INCLUSIVE", "EXCLUSIVE"]),
-        description="A list with two values, stating if start/end should be "
-        "inclusive/exclusive.",
-        required=False,
+        fields.String(
+            validate=validate.OneOf(choices=("INCLUSIVE", "EXCLUSIVE")),
+            description="A list with two values, stating if start/end should 
be "
+            "inclusive/exclusive.",
+            required=False,
+        )
     )
     relative_start = fields.String(
         description="Start time for relative time deltas. "
         'Default: `config["DEFAULT_RELATIVE_START_TIME"]`',
-        enum=["today", "now"],
+        validate=validate.OneOf(choices=("today", "now")),
         required=False,
     )
     relative_end = fields.String(
         description="End time for relative time deltas. "
         'Default: `config["DEFAULT_RELATIVE_START_TIME"]`',
-        enum=["today", "now"],
+        validate=validate.OneOf(choices=("today", "now")),
         required=False,
     )
     where = fields.String(
@@ -402,35 +414,54 @@ class ChartDataExtrasSchema(Schema):
         "AND operator.",
         required=False,
     )
-    having_druid = fields.String(
+    having_druid = fields.List(
+        fields.Nested(ChartDataFilterSchema),
         description="HAVING filters to be added to legacy Druid datasource 
queries.",
         required=False,
     )
+    time_grain_sqla = fields.String(
+        description="To what level of granularity should the temporal column 
be "
+        "aggregated. Supports "
+        "[ISO 8601](https://en.wikipedia.org/wiki/ISO_8601#Durations) 
durations.",
+        validate=validate.OneOf(
+            choices=(
+                "PT1S",
+                "PT1M",
+                "PT5M",
+                "PT10M",
+                "PT15M",
+                "PT0.5H",
+                "PT1H",
+                "P1D",
+                "P1W",
+                "P1M",
+                "P0.25Y",
+                "P1Y",
+            ),
+        ),
+        required=False,
+        example="P1D",
+    )
+    druid_time_origin = fields.String(
+        description="Starting point for time grain counting on legacy Druid "
+        "datasources. Used to change e.g. Monday/Sunday first-day-of-week.",
+        required=False,
+    )
 
 
 class ChartDataQueryObjectSchema(Schema):
     filters = fields.List(fields.Nested(ChartDataFilterSchema), required=False)
     granularity = fields.String(
-        description="To what level of granularity should the temporal column 
be "
-        "aggregated. Supports "
-        "[ISO 8601](https://en.wikipedia.org/wiki/ISO_8601#Durations) "
-        "durations.",
-        enum=[
-            "PT1S",
-            "PT1M",
-            "PT5M",
-            "PT10M",
-            "PT15M",
-            "PT0.5H",
-            "PT1H",
-            "P1D",
-            "P1W",
-            "P1M",
-            "P0.25Y",
-            "P1Y",
-        ],
+        description="Name of temporal column used for time filtering. For 
legacy Druid "
+        "datasources this defines the time grain.",
         required=False,
-        example="P1D",
+    )
+    granularity_sqla = fields.String(
+        description="Name of temporal column used for time filtering for SQL "
+        "datasources. This field is deprecated, use `granularity` "
+        "instead.",
+        required=False,
+        deprecated=True,
     )
     groupby = fields.List(
         fields.String(description="Columns by which to group the query.",),
@@ -441,6 +472,7 @@ class ChartDataQueryObjectSchema(Schema):
         "references to datasource metrics (strings), or ad-hoc metrics"
         "which are defined only within the query object. See "
         "`ChartDataAdhocMetricSchema` for the structure of ad-hoc metrics.",
+        required=False,
     )
     post_processing = fields.List(
         fields.Nested(ChartDataPostProcessingOperationSchema),
@@ -450,7 +482,8 @@ class ChartDataQueryObjectSchema(Schema):
     )
     time_range = fields.String(
         description="A time rage, either expressed as a colon separated string 
"
-        "`since : until`. Valid formats for `since` and `until` are: \n"
+        "`since : until` or human readable freeform. Valid formats for "
+        "`since` and `until` are: \n"
         "- ISO 8601\n"
         "- X days/years/hours/day/year/weeks\n"
         "- X days/years/hours/day/year/weeks ago\n"
@@ -488,7 +521,7 @@ class ChartDataQueryObjectSchema(Schema):
     order_desc = fields.Boolean(
         description="Reverse order. Default: `false`", required=False
     )
-    extras = fields.Dict(description=" Default: `{}`", required=False)
+    extras = fields.Nested(ChartDataExtrasSchema, required=False)
     columns = fields.List(fields.String(), description="", required=False,)
     orderby = fields.List(
         fields.List(fields.Raw()),
@@ -499,13 +532,13 @@ class ChartDataQueryObjectSchema(Schema):
     )
     where = fields.String(
         description="WHERE clause to be added to queries using AND operator."
-        "This field is deprecated, and should be passed to `extras`.",
+        "This field is deprecated and should be passed to `extras`.",
         required=False,
         deprecated=True,
     )
     having = fields.String(
         description="HAVING clause to be added to aggregate queries using "
-        "AND operator. This field is deprecated, and should be passed "
+        "AND operator. This field is deprecated and should be passed "
         "to `extras`.",
         required=False,
         deprecated=True,
@@ -513,7 +546,7 @@ class ChartDataQueryObjectSchema(Schema):
     having_filters = fields.List(
         fields.Dict(),
         description="HAVING filters to be added to legacy Druid datasource 
queries. "
-        "This field is deprecated, and should be passed to `extras` "
+        "This field is deprecated and should be passed to `extras` "
         "as `filters_druid`.",
         required=False,
         deprecated=True,
@@ -523,7 +556,10 @@ class ChartDataQueryObjectSchema(Schema):
 class ChartDataDatasourceSchema(Schema):
     description = "Chart datasource"
     id = fields.Integer(description="Datasource id", required=True,)
-    type = fields.String(description="Datasource type", enum=["druid", "sql"])
+    type = fields.String(
+        description="Datasource type",
+        validate=validate.OneOf(choices=("druid", "table")),
+    )
 
 
 class ChartDataQueryContextSchema(Schema):
@@ -561,15 +597,17 @@ class ChartDataResponseResult(Schema):
     )
     status = fields.String(
         description="Status of the query",
-        enum=[
-            "stopped",
-            "failed",
-            "pending",
-            "running",
-            "scheduled",
-            "success",
-            "timed_out",
-        ],
+        validate=validate.OneOf(
+            choices=(
+                "stopped",
+                "failed",
+                "pending",
+                "running",
+                "scheduled",
+                "success",
+                "timed_out",
+            )
+        ),
         allow_none=False,
     )
     stacktrace = fields.String(
diff --git a/superset/common/query_object.py b/superset/common/query_object.py
index 0a83ef7..ea1f3f5 100644
--- a/superset/common/query_object.py
+++ b/superset/common/query_object.py
@@ -35,15 +35,19 @@ logger = logging.getLogger(__name__)
 #  https://github.com/python/mypy/issues/5288
 
 
-class DeprecatedExtrasField(NamedTuple):
-    name: str
-    extras_name: str
+class DeprecatedField(NamedTuple):
+    old_name: str
+    new_name: str
 
 
+DEPRECATED_FIELDS = (
+    DeprecatedField(old_name="granularity_sqla", new_name="granularity"),
+)
+
 DEPRECATED_EXTRAS_FIELDS = (
-    DeprecatedExtrasField(name="where", extras_name="where"),
-    DeprecatedExtrasField(name="having", extras_name="having"),
-    DeprecatedExtrasField(name="having_filters", extras_name="having_druid"),
+    DeprecatedField(old_name="where", new_name="where"),
+    DeprecatedField(old_name="having", new_name="having"),
+    DeprecatedField(old_name="having_filters", new_name="having_druid"),
 )
 
 
@@ -53,7 +57,7 @@ class QueryObject:
     and druid. The query objects are constructed on the client.
     """
 
-    granularity: str
+    granularity: Optional[str]
     from_dttm: datetime
     to_dttm: datetime
     is_timeseries: bool
@@ -72,8 +76,8 @@ class QueryObject:
 
     def __init__(
         self,
-        granularity: str,
-        metrics: List[Union[Dict[str, Any], str]],
+        granularity: Optional[str] = None,
+        metrics: Optional[List[Union[Dict[str, Any], str]]] = None,
         groupby: Optional[List[str]] = None,
         filters: Optional[List[Dict[str, Any]]] = None,
         time_range: Optional[str] = None,
@@ -89,6 +93,7 @@ class QueryObject:
         post_processing: Optional[List[Dict[str, Any]]] = None,
         **kwargs: Any,
     ):
+        metrics = metrics or []
         extras = extras or {}
         is_sip_38 = is_feature_enabled("SIP_38_VIZ_REARCHITECTURE")
         self.granularity = granularity
@@ -131,22 +136,44 @@ class QueryObject:
         if is_sip_38 and groupby:
             self.columns += groupby
             logger.warning(
-                f"The field groupby is deprecated. Viz plugins should "
-                f"pass all selectables via the columns field"
+                f"The field `groupby` is deprecated. Viz plugins should "
+                f"pass all selectables via the `columns` field"
             )
 
         self.orderby = orderby or []
 
-        # move deprecated fields to extras
+        # rename deprecated fields
+        for field in DEPRECATED_FIELDS:
+            if field.old_name in kwargs:
+                logger.warning(
+                    f"The field `{field.old_name}` is deprecated, please use "
+                    f"`{field.new_name}` instead."
+                )
+                value = kwargs[field.old_name]
+                if value:
+                    if hasattr(self, field.new_name):
+                        logger.warning(
+                            f"The field `{field.new_name}` is already 
populated, "
+                            f"replacing value with contents from 
`{field.old_name}`."
+                        )
+                    setattr(self, field.new_name, value)
+
+        # move deprecated extras fields to extras
         for field in DEPRECATED_EXTRAS_FIELDS:
-            if field.name in kwargs:
+            if field.old_name in kwargs:
                 logger.warning(
-                    f"The field `{field.name} is deprecated, and should be "
-                    f"passed to `extras` via the `{field.extras_name}` 
property"
+                    f"The field `{field.old_name}` is deprecated and should be 
"
+                    f"passed to `extras` via the `{field.new_name}` property."
                 )
-                value = kwargs[field.name]
+                value = kwargs[field.old_name]
                 if value:
-                    self.extras[field.extras_name] = value
+                    if hasattr(self.extras, field.new_name):
+                        logger.warning(
+                            f"The field `{field.new_name}` is already 
populated in "
+                            f"`extras`, replacing value with contents "
+                            f"from `{field.old_name}`."
+                        )
+                    self.extras[field.new_name] = value
 
     def to_dict(self) -> Dict[str, Any]:
         query_object_dict = {
diff --git a/superset/connectors/base/models.py b/superset/connectors/base/models.py
index 2b6e0d2..2b051d3 100644
--- a/superset/connectors/base/models.py
+++ b/superset/connectors/base/models.py
@@ -366,7 +366,9 @@ class BaseDatasource(
     def default_query(qry) -> Query:
         return qry
 
-    def get_column(self, column_name: str) -> Optional["BaseColumn"]:
+    def get_column(self, column_name: Optional[str]) -> Optional["BaseColumn"]:
+        if not column_name:
+            return None
         for col in self.columns:
             if col.column_name == column_name:
                 return col
diff --git a/tests/access_tests.py b/tests/access_tests.py
index 58affb6..f5ac8e0 100644
--- a/tests/access_tests.py
+++ b/tests/access_tests.py
@@ -385,7 +385,7 @@ class RequestAccessTests(SupersetTestCase):
             )
             self.assertEqual(
                 "[Superset] Access to the datasource {} was granted".format(
-                    self.get_table(ds_1_id).full_name
+                    self.get_table_by_id(ds_1_id).full_name
                 ),
                 call_args[2]["Subject"],
             )
@@ -426,7 +426,7 @@ class RequestAccessTests(SupersetTestCase):
             )
             self.assertEqual(
                 "[Superset] Access to the datasource {} was granted".format(
-                    self.get_table(ds_2_id).full_name
+                    self.get_table_by_id(ds_2_id).full_name
                 ),
                 call_args[2]["Subject"],
             )
diff --git a/tests/base_tests.py b/tests/base_tests.py
index 370adf3..ebe1d9a 100644
--- a/tests/base_tests.py
+++ b/tests/base_tests.py
@@ -18,16 +18,18 @@
 """Unit tests for Superset"""
 import imp
 import json
-from typing import Union, Dict
+from typing import Dict, Union
 from unittest.mock import Mock, patch
 
 import pandas as pd
 from flask import Response
 from flask_appbuilder.security.sqla import models as ab_models
 from flask_testing import TestCase
+from sqlalchemy.orm import Session
 
 from tests.test_app import app  # isort:skip
 from superset import db, security_manager
+from superset.connectors.base.models import BaseDatasource
 from superset.connectors.druid.models import DruidCluster, DruidDatasource
 from superset.connectors.sqla.models import SqlaTable
 from superset.models import core as models
@@ -103,7 +105,8 @@ class SupersetTestCase(TestCase):
                 session.add(druid_datasource2)
                 session.commit()
 
-    def get_table(self, table_id):
+    @staticmethod
+    def get_table_by_id(table_id: int) -> SqlaTable:
         return db.session.query(SqlaTable).filter_by(id=table_id).one()
 
     @staticmethod
@@ -127,21 +130,25 @@ class SupersetTestCase(TestCase):
         resp = self.get_resp("/login/", data=dict(username=username, 
password=password))
         self.assertNotIn("User confirmation needed", resp)
 
-    def get_slice(self, slice_name, session):
+    def get_slice(self, slice_name: str, session: Session) -> Slice:
         slc = session.query(Slice).filter_by(slice_name=slice_name).one()
         session.expunge_all()
         return slc
 
-    def get_table_by_name(self, name):
+    @staticmethod
+    def get_table_by_name(name: str) -> SqlaTable:
         return db.session.query(SqlaTable).filter_by(table_name=name).one()
 
-    def get_database_by_id(self, db_id):
+    @staticmethod
+    def get_database_by_id(db_id: int) -> Database:
         return db.session.query(Database).filter_by(id=db_id).one()
 
-    def get_druid_ds_by_name(self, name):
+    @staticmethod
+    def get_druid_ds_by_name(name: str) -> DruidDatasource:
         return 
db.session.query(DruidDatasource).filter_by(datasource_name=name).first()
 
-    def get_datasource_mock(self):
+    @staticmethod
+    def get_datasource_mock() -> BaseDatasource:
         datasource = Mock()
         results = Mock()
         results.query = Mock()
diff --git a/tests/charts/api_tests.py b/tests/charts/api_tests.py
index 0fb8471..0a7ab58 100644
--- a/tests/charts/api_tests.py
+++ b/tests/charts/api_tests.py
@@ -16,7 +16,7 @@
 # under the License.
 """Unit tests for Superset"""
 import json
-from typing import Any, Dict, List, Optional
+from typing import List, Optional
 
 import prison
 from sqlalchemy.sql import func
@@ -28,6 +28,7 @@ from superset.models.dashboard import Dashboard
 from superset.models.slice import Slice
 from tests.base_api_tests import ApiOwnersTestCaseMixin
 from tests.base_tests import SupersetTestCase
+from tests.fixtures.query_context import get_query_context
 
 
 class ChartApiTests(SupersetTestCase, ApiOwnersTestCaseMixin):
@@ -69,32 +70,6 @@ class ChartApiTests(SupersetTestCase, ApiOwnersTestCaseMixin):
         db.session.commit()
         return slice
 
-    def _get_query_context(self) -> Dict[str, Any]:
-        self.login(username="admin")
-        slc = self.get_slice("Girl Name Cloud", db.session)
-        return {
-            "datasource": {"id": slc.datasource_id, "type": 
slc.datasource_type},
-            "queries": [
-                {
-                    "extras": {"where": ""},
-                    "granularity": "ds",
-                    "groupby": ["name"],
-                    "is_timeseries": False,
-                    "metrics": [{"label": "sum__num"}],
-                    "order_desc": True,
-                    "orderby": [],
-                    "row_limit": 100,
-                    "time_range": "100 years ago : now",
-                    "timeseries_limit": 0,
-                    "timeseries_limit_metric": None,
-                    "filters": [{"col": "gender", "op": "==", "val": "boy"}],
-                    "having": "",
-                    "having_filters": [],
-                    "where": "",
-                }
-            ],
-        }
-
     def test_delete_chart(self):
         """
         Chart API: Test delete
@@ -662,22 +637,37 @@ class ChartApiTests(SupersetTestCase, ApiOwnersTestCaseMixin):
         Query API: Test chart data query
         """
         self.login(username="admin")
-        query_context = self._get_query_context()
+        table = self.get_table_by_name("birth_names")
+        payload = get_query_context(table.name, table.id, table.type)
         uri = "api/v1/chart/data"
-        rv = self.post_assert_metric(uri, query_context, "data")
+        rv = self.post_assert_metric(uri, payload, "data")
         self.assertEqual(rv.status_code, 200)
         data = json.loads(rv.data.decode("utf-8"))
         self.assertEqual(data["result"][0]["rowcount"], 100)
 
-    def test_invalid_chart_data(self):
+    def test_chart_data_with_invalid_datasource(self):
+        """Query API: Test chart data query with invalid schema
         """
-        Query API: Test chart data query with invalid schema
+        self.login(username="admin")
+        table = self.get_table_by_name("birth_names")
+        payload = get_query_context(table.name, table.id, table.type)
+        payload["datasource"] = "abc"
+        uri = "api/v1/chart/data"
+        rv = self.post_assert_metric(uri, payload, "data")
+        self.assertEqual(rv.status_code, 400)
+
+    def test_chart_data_with_invalid_enum_value(self):
+        """Query API: Test chart data query with invalid enum value
         """
         self.login(username="admin")
-        query_context = self._get_query_context()
-        query_context["datasource"] = "abc"
+        table = self.get_table_by_name("birth_names")
+        payload = get_query_context(table.name, table.id, table.type)
+        payload["queries"][0]["extras"]["time_range_endpoints"] = [
+            "abc",
+            "EXCLUSIVE",
+        ]
         uri = "api/v1/chart/data"
-        rv = self.client.post(uri, json=query_context)
+        rv = self.client.post(uri, json=payload)
         self.assertEqual(rv.status_code, 400)
 
     def test_query_exec_not_allowed(self):
@@ -685,9 +675,10 @@ class ChartApiTests(SupersetTestCase, ApiOwnersTestCaseMixin):
         Query API: Test chart data query not allowed
         """
         self.login(username="gamma")
-        query_context = self._get_query_context()
+        table = self.get_table_by_name("birth_names")
+        payload = get_query_context(table.name, table.id, table.type)
         uri = "api/v1/chart/data"
-        rv = self.post_assert_metric(uri, query_context, "data")
+        rv = self.post_assert_metric(uri, payload, "data")
         self.assertEqual(rv.status_code, 401)
 
     def test_datasources(self):
diff --git a/tests/core_tests.py b/tests/core_tests.py
index 04cc587..ecbfa05 100644
--- a/tests/core_tests.py
+++ b/tests/core_tests.py
@@ -28,7 +28,6 @@ import pytz
 import random
 import re
 import string
-from typing import Any, Dict
 import unittest
 from unittest import mock, skipUnless
 
@@ -44,8 +43,6 @@ from superset import (
     sql_lab,
     is_feature_enabled,
 )
-from superset.common.query_context import QueryContext
-from superset.connectors.connector_registry import ConnectorRegistry
 from superset.connectors.sqla.models import SqlaTable
 from superset.db_engine_specs.base import BaseEngineSpec
 from superset.db_engine_specs.mssql import MssqlEngineSpec
@@ -111,61 +108,6 @@ class CoreTests(SupersetTestCase):
         resp = self.client.get("/superset/slice/-1/")
         assert resp.status_code == 404
 
-    def _get_query_context(self) -> Dict[str, Any]:
-        self.login(username="admin")
-        slc = self.get_slice("Girl Name Cloud", db.session)
-        return {
-            "datasource": {"id": slc.datasource_id, "type": 
slc.datasource_type},
-            "queries": [
-                {
-                    "granularity": "ds",
-                    "groupby": ["name"],
-                    "metrics": [{"label": "sum__num"}],
-                    "filters": [],
-                    "row_limit": 100,
-                }
-            ],
-        }
-
-    def _get_query_context_with_post_processing(self) -> Dict[str, Any]:
-        self.login(username="admin")
-        slc = self.get_slice("Girl Name Cloud", db.session)
-        return {
-            "datasource": {"id": slc.datasource_id, "type": 
slc.datasource_type},
-            "queries": [
-                {
-                    "granularity": "ds",
-                    "groupby": ["name", "state"],
-                    "metrics": [{"label": "sum__num"}],
-                    "filters": [],
-                    "row_limit": 100,
-                    "post_processing": [
-                        {
-                            "operation": "aggregate",
-                            "options": {
-                                "groupby": ["state"],
-                                "aggregates": {
-                                    "q1": {
-                                        "operator": "percentile",
-                                        "column": "sum__num",
-                                        "options": {"q": 25},
-                                    },
-                                    "median": {
-                                        "operator": "median",
-                                        "column": "sum__num",
-                                    },
-                                },
-                            },
-                        },
-                        {
-                            "operation": "sort",
-                            "options": {"columns": {"q1": False, "state": 
True},},
-                        },
-                    ],
-                }
-            ],
-        }
-
     def test_viz_cache_key(self):
         self.login(username="admin")
         slc = self.get_slice("Girls", db.session)
@@ -178,45 +120,6 @@ class CoreTests(SupersetTestCase):
         qobj["groupby"] = []
         self.assertNotEqual(cache_key, viz.cache_key(qobj))
 
-    def test_cache_key_changes_when_datasource_is_updated(self):
-        qc_dict = self._get_query_context()
-
-        # construct baseline cache_key
-        query_context = QueryContext(**qc_dict)
-        query_object = query_context.queries[0]
-        cache_key_original = query_context.cache_key(query_object)
-
-        # make temporary change and revert it to refresh the changed_on 
property
-        datasource = ConnectorRegistry.get_datasource(
-            datasource_type=qc_dict["datasource"]["type"],
-            datasource_id=qc_dict["datasource"]["id"],
-            session=db.session,
-        )
-        description_original = datasource.description
-        datasource.description = "temporary description"
-        db.session.commit()
-        datasource.description = description_original
-        db.session.commit()
-
-        # create new QueryContext with unchanged attributes and extract new 
cache_key
-        query_context = QueryContext(**qc_dict)
-        query_object = query_context.queries[0]
-        cache_key_new = query_context.cache_key(query_object)
-
-        # the new cache_key should be different due to updated datasource
-        self.assertNotEqual(cache_key_original, cache_key_new)
-
-    def test_query_context_time_range_endpoints(self):
-        query_context = QueryContext(**self._get_query_context())
-        query_object = query_context.queries[0]
-        extras = query_object.to_dict()["extras"]
-        self.assertTrue("time_range_endpoints" in extras)
-
-        self.assertEquals(
-            extras["time_range_endpoints"],
-            (utils.TimeRangeEndpoint.INCLUSIVE, 
utils.TimeRangeEndpoint.EXCLUSIVE),
-        )
-
     def test_get_superset_tables_not_allowed(self):
         example_db = utils.get_example_database()
         schema_name = self.default_schema_backend_map[example_db.backend]
@@ -254,20 +157,6 @@ class CoreTests(SupersetTestCase):
         rv = self.client.get(uri)
         self.assertEqual(rv.status_code, 404)
 
-    def test_api_v1_query_endpoint(self):
-        self.login(username="admin")
-        qc_dict = self._get_query_context()
-        data = json.dumps(qc_dict)
-        resp = json.loads(self.get_resp("/api/v1/query/", {"query_context": 
data}))
-        self.assertEqual(resp[0]["rowcount"], 100)
-
-    def test_api_v1_query_endpoint_with_post_processing(self):
-        self.login(username="admin")
-        qc_dict = self._get_query_context_with_post_processing()
-        data = json.dumps(qc_dict)
-        resp = json.loads(self.get_resp("/api/v1/query/", {"query_context": 
data}))
-        self.assertEqual(resp[0]["rowcount"], 6)
-
     def test_old_slice_json_endpoint(self):
         self.login(username="admin")
         slc = self.get_slice("Girls", db.session)
diff --git a/tests/dict_import_export_tests.py 
b/tests/dict_import_export_tests.py
index 404709c..e857479 100644
--- a/tests/dict_import_export_tests.py
+++ b/tests/dict_import_export_tests.py
@@ -165,7 +165,7 @@ class DictImportExportTests(SupersetTestCase):
         new_table = SqlaTable.import_from_dict(db.session, dict_table)
         db.session.commit()
         imported_id = new_table.id
-        imported = self.get_table(imported_id)
+        imported = self.get_table_by_id(imported_id)
         self.assert_table_equals(table, imported)
         self.yaml_compare(table.export_to_dict(), imported.export_to_dict())
 
@@ -178,7 +178,7 @@ class DictImportExportTests(SupersetTestCase):
         )
         imported_table = SqlaTable.import_from_dict(db.session, dict_table)
         db.session.commit()
-        imported = self.get_table(imported_table.id)
+        imported = self.get_table_by_id(imported_table.id)
         self.assert_table_equals(table, imported)
         self.assertEqual(
             {DBREF: ID_PREFIX + 2, "database_name": "main"}, 
json.loads(imported.params)
@@ -194,7 +194,7 @@ class DictImportExportTests(SupersetTestCase):
         )
         imported_table = SqlaTable.import_from_dict(db.session, dict_table)
         db.session.commit()
-        imported = self.get_table(imported_table.id)
+        imported = self.get_table_by_id(imported_table.id)
         self.assert_table_equals(table, imported)
         self.yaml_compare(table.export_to_dict(), imported.export_to_dict())
 
@@ -213,7 +213,7 @@ class DictImportExportTests(SupersetTestCase):
         imported_over_table = SqlaTable.import_from_dict(db.session, 
dict_table_over)
         db.session.commit()
 
-        imported_over = self.get_table(imported_over_table.id)
+        imported_over = self.get_table_by_id(imported_over_table.id)
         self.assertEqual(imported_table.id, imported_over.id)
         expected_table, _ = self.create_table(
             "table_override",
@@ -243,7 +243,7 @@ class DictImportExportTests(SupersetTestCase):
         )
         db.session.commit()
 
-        imported_over = self.get_table(imported_over_table.id)
+        imported_over = self.get_table_by_id(imported_over_table.id)
         self.assertEqual(imported_table.id, imported_over.id)
         expected_table, _ = self.create_table(
             "table_override",
@@ -274,7 +274,7 @@ class DictImportExportTests(SupersetTestCase):
         imported_copy_table = SqlaTable.import_from_dict(db.session, 
dict_copy_table)
         db.session.commit()
         self.assertEqual(imported_table.id, imported_copy_table.id)
-        self.assert_table_equals(copy_table, self.get_table(imported_table.id))
+        self.assert_table_equals(copy_table, 
self.get_table_by_id(imported_table.id))
         self.yaml_compare(
             imported_copy_table.export_to_dict(), 
imported_table.export_to_dict()
         )
diff --git a/tests/fixtures/query_context.py b/tests/fixtures/query_context.py
new file mode 100644
index 0000000..e886fda
--- /dev/null
+++ b/tests/fixtures/query_context.py
@@ -0,0 +1,114 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import copy
+from typing import Any, Dict, List
+
+QUERY_OBJECTS = {
+    "birth_names": {
+        "extras": {"where": "", "time_range_endpoints": ["INCLUSIVE", "EXCLUSIVE"],},
+        "granularity": "ds",
+        "groupby": ["name"],
+        "is_timeseries": False,
+        "metrics": [{"label": "sum__num"}],
+        "order_desc": True,
+        "orderby": [],
+        "row_limit": 100,
+        "time_range": "100 years ago : now",
+        "timeseries_limit": 0,
+        "timeseries_limit_metric": None,
+        "filters": [{"col": "gender", "op": "==", "val": "boy"}],
+        "having": "",
+        "having_filters": [],
+        "where": "",
+    }
+}
+
+POSTPROCESSING_OPERATIONS = {
+    "birth_names": [
+        {
+            "operation": "aggregate",
+            "options": {
+                "groupby": ["gender"],
+                "aggregates": {
+                    "q1": {
+                        "operator": "percentile",
+                        "column": "sum__num",
+                        "options": {"q": 25},
+                    },
+                    "median": {"operator": "median", "column": "sum__num",},
+                },
+            },
+        },
+        {"operation": "sort", "options": {"columns": {"q1": False, "gender": True},},},
+    ]
+}
+
+
+def _get_query_object(
+    datasource_name: str, add_postprocessing_operations: bool
+) -> Dict[str, Any]:
+    """
+    Return a deep copy of the QueryObject fixture for a datasource, optionally with
+    its post-processing operations attached under the `post_processing` key.
+
+    :raises Exception: if no QueryObject fixture is defined for the datasource
+    """
+    if datasource_name not in QUERY_OBJECTS:
+        raise Exception(
+            f"QueryObject fixture not defined for datasource: {datasource_name}"
+        )
+    query_object = copy.deepcopy(QUERY_OBJECTS[datasource_name])
+    if add_postprocessing_operations:
+        query_object["post_processing"] = _get_postprocessing_operation(datasource_name)
+    return query_object
+
+
+def _get_postprocessing_operation(datasource_name: str) -> List[Dict[str, Any]]:
+    """
+    Return a deep copy of the post-processing operations fixture for a datasource.
+
+    :raises Exception: if no post-processing fixture is defined for the datasource
+    """
+    if datasource_name not in POSTPROCESSING_OPERATIONS:
+        raise Exception(
+            f"Post-processing fixture not defined for datasource: {datasource_name}"
+        )
+    return copy.deepcopy(POSTPROCESSING_OPERATIONS[datasource_name])
+
+
+def get_query_context(
+    datasource_name: str = "birth_names",
+    datasource_id: int = 0,
+    datasource_type: str = "table",
+    add_postprocessing_operations: bool = False,
+) -> Dict[str, Any]:
+    """
+    Create a request payload for retrieving a QueryContext object via the
+    `api/v1/chart/data` endpoint. By default returns a payload corresponding to one
+    generated by the "Boy Name Cloud" chart in the examples.
+
+    :param datasource_name: name of datasource to query. Different datasources require
+           different parameters in the QueryContext.
+    :param datasource_id: id of datasource to query.
+    :param datasource_type: type of datasource to query.
+    :param add_postprocessing_operations: Add post-processing operations to QueryObject
+    :return: Request payload
+    """
+    return {
+        "datasource": {"id": datasource_id, "type": datasource_type},
+        "queries": [_get_query_object(datasource_name, add_postprocessing_operations)],
+    }
diff --git a/tests/import_export_tests.py b/tests/import_export_tests.py
index 458ab04..b47b43b 100644
--- a/tests/import_export_tests.py
+++ b/tests/import_export_tests.py
@@ -558,7 +558,7 @@ class ImportExportTests(SupersetTestCase):
     def test_import_table_no_metadata(self):
         table = self.create_table("pure_table", id=10001)
         imported_id = SqlaTable.import_obj(table, import_time=1989)
-        imported = self.get_table(imported_id)
+        imported = self.get_table_by_id(imported_id)
         self.assert_table_equals(table, imported)
 
     def test_import_table_1_col_1_met(self):
@@ -566,7 +566,7 @@ class ImportExportTests(SupersetTestCase):
             "table_1_col_1_met", id=10002, cols_names=["col1"], 
metric_names=["metric1"]
         )
         imported_id = SqlaTable.import_obj(table, import_time=1990)
-        imported = self.get_table(imported_id)
+        imported = self.get_table_by_id(imported_id)
         self.assert_table_equals(table, imported)
         self.assertEqual(
             {"remote_id": 10002, "import_time": 1990, "database_name": 
"examples"},
@@ -582,7 +582,7 @@ class ImportExportTests(SupersetTestCase):
         )
         imported_id = SqlaTable.import_obj(table, import_time=1991)
 
-        imported = self.get_table(imported_id)
+        imported = self.get_table_by_id(imported_id)
         self.assert_table_equals(table, imported)
 
     def test_import_table_override(self):
@@ -599,7 +599,7 @@ class ImportExportTests(SupersetTestCase):
         )
         imported_over_id = SqlaTable.import_obj(table_over, import_time=1992)
 
-        imported_over = self.get_table(imported_over_id)
+        imported_over = self.get_table_by_id(imported_over_id)
         self.assertEqual(imported_id, imported_over.id)
         expected_table = self.create_table(
             "table_override",
@@ -627,7 +627,7 @@ class ImportExportTests(SupersetTestCase):
         imported_id_copy = SqlaTable.import_obj(copy_table, import_time=1994)
 
         self.assertEqual(imported_id, imported_id_copy)
-        self.assert_table_equals(copy_table, self.get_table(imported_id))
+        self.assert_table_equals(copy_table, self.get_table_by_id(imported_id))
 
     def test_import_druid_no_metadata(self):
         datasource = self.create_druid_datasource("pure_druid", id=10001)
diff --git a/tests/query_context_tests.py b/tests/query_context_tests.py
new file mode 100644
index 0000000..2d378ba
--- /dev/null
+++ b/tests/query_context_tests.py
@@ -0,0 +1,100 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from typing import Any, Dict, List, Optional
+
+from superset import db
+from superset.common.query_context import QueryContext
+from superset.connectors.connector_registry import ConnectorRegistry
+from superset.utils.core import TimeRangeEndpoint
+from tests.base_tests import SupersetTestCase
+from tests.fixtures.query_context import get_query_context
+from tests.test_app import app
+
+
+class QueryContextTests(SupersetTestCase):
+    def _get_birth_names_payload(self) -> Dict[str, Any]:
+        """
+        Build a QueryContext request payload for the `birth_names` example table
+        """
+        table = self.get_table_by_name("birth_names")
+        return get_query_context(table.name, table.id, table.type)
+
+    def test_cache_key_changes_when_datasource_is_updated(self):
+        """
+        Ensure that the cache key changes after the underlying datasource has been
+        modified, even when the QueryContext payload itself is unchanged
+        """
+        self.login(username="admin")
+        payload = self._get_birth_names_payload()
+
+        # construct baseline cache_key
+        query_context = QueryContext(**payload)
+        query_object = query_context.queries[0]
+        cache_key_original = query_context.cache_key(query_object)
+
+        # make temporary change and revert it to refresh the changed_on property
+        datasource = ConnectorRegistry.get_datasource(
+            datasource_type=payload["datasource"]["type"],
+            datasource_id=payload["datasource"]["id"],
+            session=db.session,
+        )
+        description_original = datasource.description
+        datasource.description = "temporary description"
+        db.session.commit()
+        datasource.description = description_original
+        db.session.commit()
+
+        # create new QueryContext with unchanged attributes and extract new cache_key
+        query_context = QueryContext(**payload)
+        query_object = query_context.queries[0]
+        cache_key_new = query_context.cache_key(query_object)
+
+        # the new cache_key should be different due to updated datasource
+        self.assertNotEqual(cache_key_original, cache_key_new)
+
+    def test_query_context_time_range_endpoints(self):
+        """
+        Ensure that time_range_endpoints are populated automatically when missing
+        from the payload
+        """
+        self.login(username="admin")
+        payload = self._get_birth_names_payload()
+        del payload["queries"][0]["extras"]["time_range_endpoints"]
+        query_context = QueryContext(**payload)
+        query_object = query_context.queries[0]
+        extras = query_object.to_dict()["extras"]
+        self.assertIn("time_range_endpoints", extras)
+
+        self.assertEqual(
+            extras["time_range_endpoints"],
+            (TimeRangeEndpoint.INCLUSIVE, TimeRangeEndpoint.EXCLUSIVE),
+        )
+
+    def test_convert_deprecated_fields(self):
+        """
+        Ensure that deprecated fields are converted correctly
+        """
+        self.login(username="admin")
+        payload = self._get_birth_names_payload()
+        payload["queries"][0]["granularity_sqla"] = "timecol"
+        # having_filters is a list of filter dicts, matching the fixture's shape
+        payload["queries"][0]["having_filters"] = [{"col": "a", "op": "==", "val": "b"}]
+        query_context = QueryContext(**payload)
+        self.assertEqual(len(query_context.queries), 1)
+        query_object = query_context.queries[0]
+        self.assertEqual(query_object.granularity, "timecol")
+        self.assertIn("having_druid", query_object.extras)

Reply via email to