This is an automated email from the ASF dual-hosted git repository.
villebro pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-superset.git
The following commit(s) were added to refs/heads/master by this push:
new 85778f5 fix: broken IS NULL and IS NOT NULL operator (#9613)
85778f5 is described below
commit 85778f5bec3e4321715352a2c7715c0128e5598a
Author: Ville Brofeldt <[email protected]>
AuthorDate: Wed Apr 22 19:11:45 2020 +0300
fix: broken IS NULL and IS NOT NULL operator (#9613)
* fix: broken is null and is not null operator
* add unit tests
* Rename filter operator enum
---
superset/charts/schemas.py | 2 +-
superset/connectors/druid/models.py | 48 ++++++++++++++++++-------------------
superset/connectors/sqla/models.py | 32 ++++++++++++-------------
superset/utils/core.py | 4 ++--
tests/sqla_models_tests.py | 40 +++++++++++++++++++++++++++++--
5 files changed, 81 insertions(+), 45 deletions(-)
diff --git a/superset/charts/schemas.py b/superset/charts/schemas.py
index 49d0480..2743732 100644
--- a/superset/charts/schemas.py
+++ b/superset/charts/schemas.py
@@ -362,7 +362,7 @@ class ChartDataFilterSchema(Schema):
)
op = fields.String( # pylint: disable=invalid-name
description="The comparison operator.",
- enum=[filter_op.value for filter_op in utils.FilterOperationType],
+ enum=[filter_op.value for filter_op in utils.FilterOperator],
required=True,
example="IN",
)
diff --git a/superset/connectors/druid/models.py b/superset/connectors/druid/models.py
index 8d4aeb1..eef20e2 100644
--- a/superset/connectors/druid/models.py
+++ b/superset/connectors/druid/models.py
@@ -84,7 +84,7 @@ try:
from superset.utils.core import (
DimSelector,
DTTM_ALIAS,
- FilterOperationType,
+ FilterOperator,
flasher,
)
except ImportError:
@@ -1499,8 +1499,8 @@ class DruidDatasource(Model, BaseDatasource):
eq is None
and op
not in (
- FilterOperationType.IS_NULL.value,
- FilterOperationType.IS_NOT_NULL.value,
+ FilterOperator.IS_NULL.value,
+ FilterOperator.IS_NOT_NULL.value,
)
)
):
@@ -1517,8 +1517,8 @@ class DruidDatasource(Model, BaseDatasource):
cond = None
is_numeric_col = col in num_cols
is_list_target = op in (
- FilterOperationType.IN.value,
- FilterOperationType.NOT_IN.value,
+ FilterOperator.IN.value,
+ FilterOperator.NOT_IN.value,
)
eq = cls.filter_values_handler(
eq,
@@ -1528,11 +1528,11 @@ class DruidDatasource(Model, BaseDatasource):
# For these two ops, could have used Dimension,
# but it doesn't support extraction functions
- if op == FilterOperationType.EQUALS.value:
+ if op == FilterOperator.EQUALS.value:
cond = Filter(
dimension=col, value=eq, extraction_function=extraction_fn
)
- elif op == FilterOperationType.NOT_EQUALS.value:
+ elif op == FilterOperator.NOT_EQUALS.value:
cond = ~Filter(
dimension=col, value=eq, extraction_function=extraction_fn
)
@@ -1557,9 +1557,9 @@ class DruidDatasource(Model, BaseDatasource):
for s in eq:
fields.append(Dimension(col) == s)
cond = Filter(type="or", fields=fields)
- if op == FilterOperationType.NOT_IN.value:
+ if op == FilterOperator.NOT_IN.value:
cond = ~cond
- elif op == FilterOperationType.REGEX.value:
+ elif op == FilterOperator.REGEX.value:
cond = Filter(
extraction_function=extraction_fn,
type="regex",
@@ -1569,7 +1569,7 @@ class DruidDatasource(Model, BaseDatasource):
# For the ops below, could have used pydruid's Bound,
# but it doesn't support extraction functions
- elif op == FilterOperationType.GREATER_THAN_OR_EQUALS.value:
+ elif op == FilterOperator.GREATER_THAN_OR_EQUALS.value:
cond = Bound(
extraction_function=extraction_fn,
dimension=col,
@@ -1579,7 +1579,7 @@ class DruidDatasource(Model, BaseDatasource):
upper=None,
ordering=cls._get_ordering(is_numeric_col),
)
- elif op == FilterOperationType.LESS_THAN_OR_EQUALS.value:
+ elif op == FilterOperator.LESS_THAN_OR_EQUALS.value:
cond = Bound(
extraction_function=extraction_fn,
dimension=col,
@@ -1589,7 +1589,7 @@ class DruidDatasource(Model, BaseDatasource):
upper=eq,
ordering=cls._get_ordering(is_numeric_col),
)
- elif op == FilterOperationType.GREATER_THAN.value:
+ elif op == FilterOperator.GREATER_THAN.value:
cond = Bound(
extraction_function=extraction_fn,
lowerStrict=True,
@@ -1599,7 +1599,7 @@ class DruidDatasource(Model, BaseDatasource):
upper=None,
ordering=cls._get_ordering(is_numeric_col),
)
- elif op == FilterOperationType.LESS_THAN.value:
+ elif op == FilterOperator.LESS_THAN.value:
cond = Bound(
extraction_function=extraction_fn,
upperStrict=True,
@@ -1609,9 +1609,9 @@ class DruidDatasource(Model, BaseDatasource):
upper=eq,
ordering=cls._get_ordering(is_numeric_col),
)
- elif op == FilterOperationType.IS_NULL.value:
+ elif op == FilterOperator.IS_NULL.value:
cond = Filter(dimension=col, value="")
- elif op == FilterOperationType.IS_NOT_NULL.value:
+ elif op == FilterOperator.IS_NOT_NULL.value:
cond = ~Filter(dimension=col, value="")
if filters:
@@ -1627,14 +1627,14 @@ class DruidDatasource(Model, BaseDatasource):
def _get_having_obj(self, col: str, op: str, eq: str) -> "Having":
cond = None
- if op == FilterOperationType.EQUALS.value:
+ if op == FilterOperator.EQUALS.value:
if col in self.column_names:
cond = DimSelector(dimension=col, value=eq)
else:
cond = Aggregation(col) == eq
- elif op == FilterOperationType.GREATER_THAN.value:
+ elif op == FilterOperator.GREATER_THAN.value:
cond = Aggregation(col) > eq
- elif op == FilterOperationType.LESS_THAN.value:
+ elif op == FilterOperator.LESS_THAN.value:
cond = Aggregation(col) < eq
return cond
@@ -1642,9 +1642,9 @@ class DruidDatasource(Model, BaseDatasource):
def get_having_filters(self, raw_filters: List[Dict[str, Any]]) -> "Having":
filters = None
reversed_op_map = {
- FilterOperationType.NOT_EQUALS.value: FilterOperationType.EQUALS.value,
- FilterOperationType.GREATER_THAN_OR_EQUALS.value: FilterOperationType.LESS_THAN.value,
- FilterOperationType.LESS_THAN_OR_EQUALS.value: FilterOperationType.GREATER_THAN.value,
+ FilterOperator.NOT_EQUALS.value: FilterOperator.EQUALS.value,
+ FilterOperator.GREATER_THAN_OR_EQUALS.value: FilterOperator.LESS_THAN.value,
+ FilterOperator.LESS_THAN_OR_EQUALS.value: FilterOperator.GREATER_THAN.value,
}
for flt in raw_filters:
@@ -1655,9 +1655,9 @@ class DruidDatasource(Model, BaseDatasource):
eq = flt["val"]
cond = None
if op in [
- FilterOperationType.EQUALS.value,
- FilterOperationType.GREATER_THAN.value,
- FilterOperationType.LESS_THAN.value,
+ FilterOperator.EQUALS.value,
+ FilterOperator.GREATER_THAN.value,
+ FilterOperator.LESS_THAN.value,
]:
cond = self._get_having_obj(col, op, eq)
elif op in reversed_op_map:
diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py
index 8aac7e0..1eed4be 100644
--- a/superset/connectors/sqla/models.py
+++ b/superset/connectors/sqla/models.py
@@ -847,8 +847,8 @@ class SqlaTable(Model, BaseDatasource):
col_obj = cols.get(col)
if col_obj:
is_list_target = op in (
- utils.FilterOperationType.IN.value,
- utils.FilterOperationType.NOT_IN.value,
+ utils.FilterOperator.IN.value,
+ utils.FilterOperator.NOT_IN.value,
)
eq = self.filter_values_handler(
values=flt.get("val"),
@@ -856,36 +856,36 @@ class SqlaTable(Model, BaseDatasource):
is_list_target=is_list_target,
)
if op in (
- utils.FilterOperationType.IN.value,
- utils.FilterOperationType.NOT_IN.value,
+ utils.FilterOperator.IN.value,
+ utils.FilterOperator.NOT_IN.value,
):
cond = col_obj.get_sqla_col().in_(eq)
if isinstance(eq, str) and NULL_STRING in eq:
cond = or_(cond, col_obj.get_sqla_col() is None)
- if op == utils.FilterOperationType.NOT_IN.value:
+ if op == utils.FilterOperator.NOT_IN.value:
cond = ~cond
where_clause_and.append(cond)
else:
if col_obj.is_numeric:
eq = utils.cast_to_num(flt["val"])
- if op == utils.FilterOperationType.EQUALS.value:
+ if op == utils.FilterOperator.EQUALS.value:
where_clause_and.append(col_obj.get_sqla_col() == eq)
- elif op == utils.FilterOperationType.NOT_EQUALS.value:
+ elif op == utils.FilterOperator.NOT_EQUALS.value:
where_clause_and.append(col_obj.get_sqla_col() != eq)
- elif op == utils.FilterOperationType.GREATER_THAN.value:
+ elif op == utils.FilterOperator.GREATER_THAN.value:
where_clause_and.append(col_obj.get_sqla_col() > eq)
- elif op == utils.FilterOperationType.LESS_THAN.value:
+ elif op == utils.FilterOperator.LESS_THAN.value:
where_clause_and.append(col_obj.get_sqla_col() < eq)
- elif op == utils.FilterOperationType.GREATER_THAN_OR_EQUALS.value:
+ elif op == utils.FilterOperator.GREATER_THAN_OR_EQUALS.value:
where_clause_and.append(col_obj.get_sqla_col() >= eq)
- elif op == utils.FilterOperationType.LESS_THAN_OR_EQUALS.value:
+ elif op == utils.FilterOperator.LESS_THAN_OR_EQUALS.value:
where_clause_and.append(col_obj.get_sqla_col() <= eq)
- elif op == utils.FilterOperationType.LIKE.value:
+ elif op == utils.FilterOperator.LIKE.value:
where_clause_and.append(col_obj.get_sqla_col().like(eq))
- elif op == utils.FilterOperationType.IS_NULL.value:
- where_clause_and.append(col_obj.get_sqla_col() is None)
- elif op == utils.FilterOperationType.IS_NOT_NULL.value:
- where_clause_and.append(col_obj.get_sqla_col() is None)
+ elif op == utils.FilterOperator.IS_NULL.value:
+ where_clause_and.append(col_obj.get_sqla_col() == None)
+ elif op == utils.FilterOperator.IS_NOT_NULL.value:
+ where_clause_and.append(col_obj.get_sqla_col() != None)
else:
raise Exception(
_("Invalid filter operation type: %(op)s", op=op)
diff --git a/superset/utils/core.py b/superset/utils/core.py
index 0c838b7..ba715dd 100644
--- a/superset/utils/core.py
+++ b/superset/utils/core.py
@@ -1350,9 +1350,9 @@ class DbColumnType(Enum):
TEMPORAL = 2
-class FilterOperationType(str, Enum):
+class FilterOperator(str, Enum):
"""
- Filter operation type
+ Operators used filter controls
"""
EQUALS = "=="
diff --git a/tests/sqla_models_tests.py b/tests/sqla_models_tests.py
index 700c2e2..af15b26 100644
--- a/tests/sqla_models_tests.py
+++ b/tests/sqla_models_tests.py
@@ -15,12 +15,12 @@
# specific language governing permissions and limitations
# under the License.
# isort:skip_file
-from typing import Dict
+from typing import Any, Dict, NamedTuple, List, Tuple, Union
from superset.connectors.sqla.models import SqlaTable, TableColumn
from superset.db_engine_specs.druid import DruidEngineSpec
from superset.models.core import Database
-from superset.utils.core import DbColumnType, get_example_database
+from superset.utils.core import DbColumnType, get_example_database, FilterOperator
from .base_tests import SupersetTestCase
@@ -109,3 +109,39 @@ class DatabaseModelTestCase(SupersetTestCase):
extra_cache_keys = table.get_extra_cache_keys(query_obj)
self.assertFalse(table.has_calls_to_cache_key_wrapper(query_obj))
self.assertListEqual(extra_cache_keys, [])
+
+ def test_where_operators(self):
+ class FilterTestCase(NamedTuple):
+ operator: str
+ value: Union[float, int, List[Any], str]
+ expected: str
+
+ filters: Tuple[FilterTestCase, ...] = (
+ FilterTestCase(FilterOperator.IS_NULL, "", "IS NULL"),
+ FilterTestCase(FilterOperator.IS_NOT_NULL, "", "IS NOT NULL"),
+ FilterTestCase(FilterOperator.GREATER_THAN, 0, "> 0"),
+ FilterTestCase(FilterOperator.GREATER_THAN_OR_EQUALS, 0, ">= 0"),
+ FilterTestCase(FilterOperator.LESS_THAN, 0, "< 0"),
+ FilterTestCase(FilterOperator.LESS_THAN_OR_EQUALS, 0, "<= 0"),
+ FilterTestCase(FilterOperator.EQUALS, 0, "= 0"),
+ FilterTestCase(FilterOperator.NOT_EQUALS, 0, "!= 0"),
+ FilterTestCase(FilterOperator.IN, ["1", "2"], "IN (1, 2)"),
+ FilterTestCase(FilterOperator.NOT_IN, ["1", "2"], "NOT IN (1, 2)"),
+ )
+ table = self.get_table_by_name("birth_names")
+ for filter_ in filters:
+ query_obj = {
+ "granularity": None,
+ "from_dttm": None,
+ "to_dttm": None,
+ "groupby": ["gender"],
+ "metrics": ["count"],
+ "is_timeseries": False,
+ "filter": [
+ {"col": "num", "op": filter_.operator, "val": filter_.value}
+ ],
+ "extras": {},
+ }
+ sqla_query = table.get_sqla_query(**query_obj)
+ sql = table.database.compile_sqla_query(sqla_query.sqla_query)
+ self.assertIn(filter_.expected, sql)