This is an automated email from the ASF dual-hosted git repository.

villebro pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-superset.git


The following commit(s) were added to refs/heads/master by this push:
     new 85778f5  fix: broken IS NULL and IS NOT NULL operator (#9613)
85778f5 is described below

commit 85778f5bec3e4321715352a2c7715c0128e5598a
Author: Ville Brofeldt <[email protected]>
AuthorDate: Wed Apr 22 19:11:45 2020 +0300

    fix: broken IS NULL and IS NOT NULL operator (#9613)
    
    * fix: broken is null and is not null operator
    
    * add unit tests
    
    * Rename filter operator enum
---
 superset/charts/schemas.py          |  2 +-
 superset/connectors/druid/models.py | 48 ++++++++++++++++++-------------------
 superset/connectors/sqla/models.py  | 32 ++++++++++++-------------
 superset/utils/core.py              |  4 ++--
 tests/sqla_models_tests.py          | 40 +++++++++++++++++++++++++++++--
 5 files changed, 81 insertions(+), 45 deletions(-)

diff --git a/superset/charts/schemas.py b/superset/charts/schemas.py
index 49d0480..2743732 100644
--- a/superset/charts/schemas.py
+++ b/superset/charts/schemas.py
@@ -362,7 +362,7 @@ class ChartDataFilterSchema(Schema):
     )
     op = fields.String(  # pylint: disable=invalid-name
         description="The comparison operator.",
-        enum=[filter_op.value for filter_op in utils.FilterOperationType],
+        enum=[filter_op.value for filter_op in utils.FilterOperator],
         required=True,
         example="IN",
     )
diff --git a/superset/connectors/druid/models.py b/superset/connectors/druid/models.py
index 8d4aeb1..eef20e2 100644
--- a/superset/connectors/druid/models.py
+++ b/superset/connectors/druid/models.py
@@ -84,7 +84,7 @@ try:
     from superset.utils.core import (
         DimSelector,
         DTTM_ALIAS,
-        FilterOperationType,
+        FilterOperator,
         flasher,
     )
 except ImportError:
@@ -1499,8 +1499,8 @@ class DruidDatasource(Model, BaseDatasource):
                     eq is None
                     and op
                     not in (
-                        FilterOperationType.IS_NULL.value,
-                        FilterOperationType.IS_NOT_NULL.value,
+                        FilterOperator.IS_NULL.value,
+                        FilterOperator.IS_NOT_NULL.value,
                     )
                 )
             ):
@@ -1517,8 +1517,8 @@ class DruidDatasource(Model, BaseDatasource):
             cond = None
             is_numeric_col = col in num_cols
             is_list_target = op in (
-                FilterOperationType.IN.value,
-                FilterOperationType.NOT_IN.value,
+                FilterOperator.IN.value,
+                FilterOperator.NOT_IN.value,
             )
             eq = cls.filter_values_handler(
                 eq,
@@ -1528,11 +1528,11 @@ class DruidDatasource(Model, BaseDatasource):
 
             # For these two ops, could have used Dimension,
             # but it doesn't support extraction functions
-            if op == FilterOperationType.EQUALS.value:
+            if op == FilterOperator.EQUALS.value:
                 cond = Filter(
                     dimension=col, value=eq, extraction_function=extraction_fn
                 )
-            elif op == FilterOperationType.NOT_EQUALS.value:
+            elif op == FilterOperator.NOT_EQUALS.value:
                 cond = ~Filter(
                     dimension=col, value=eq, extraction_function=extraction_fn
                 )
@@ -1557,9 +1557,9 @@ class DruidDatasource(Model, BaseDatasource):
                     for s in eq:
                         fields.append(Dimension(col) == s)
                     cond = Filter(type="or", fields=fields)
-                if op == FilterOperationType.NOT_IN.value:
+                if op == FilterOperator.NOT_IN.value:
                     cond = ~cond
-            elif op == FilterOperationType.REGEX.value:
+            elif op == FilterOperator.REGEX.value:
                 cond = Filter(
                     extraction_function=extraction_fn,
                     type="regex",
@@ -1569,7 +1569,7 @@ class DruidDatasource(Model, BaseDatasource):
 
             # For the ops below, could have used pydruid's Bound,
             # but it doesn't support extraction functions
-            elif op == FilterOperationType.GREATER_THAN_OR_EQUALS.value:
+            elif op == FilterOperator.GREATER_THAN_OR_EQUALS.value:
                 cond = Bound(
                     extraction_function=extraction_fn,
                     dimension=col,
@@ -1579,7 +1579,7 @@ class DruidDatasource(Model, BaseDatasource):
                     upper=None,
                     ordering=cls._get_ordering(is_numeric_col),
                 )
-            elif op == FilterOperationType.LESS_THAN_OR_EQUALS.value:
+            elif op == FilterOperator.LESS_THAN_OR_EQUALS.value:
                 cond = Bound(
                     extraction_function=extraction_fn,
                     dimension=col,
@@ -1589,7 +1589,7 @@ class DruidDatasource(Model, BaseDatasource):
                     upper=eq,
                     ordering=cls._get_ordering(is_numeric_col),
                 )
-            elif op == FilterOperationType.GREATER_THAN.value:
+            elif op == FilterOperator.GREATER_THAN.value:
                 cond = Bound(
                     extraction_function=extraction_fn,
                     lowerStrict=True,
@@ -1599,7 +1599,7 @@ class DruidDatasource(Model, BaseDatasource):
                     upper=None,
                     ordering=cls._get_ordering(is_numeric_col),
                 )
-            elif op == FilterOperationType.LESS_THAN.value:
+            elif op == FilterOperator.LESS_THAN.value:
                 cond = Bound(
                     extraction_function=extraction_fn,
                     upperStrict=True,
@@ -1609,9 +1609,9 @@ class DruidDatasource(Model, BaseDatasource):
                     upper=eq,
                     ordering=cls._get_ordering(is_numeric_col),
                 )
-            elif op == FilterOperationType.IS_NULL.value:
+            elif op == FilterOperator.IS_NULL.value:
                 cond = Filter(dimension=col, value="")
-            elif op == FilterOperationType.IS_NOT_NULL.value:
+            elif op == FilterOperator.IS_NOT_NULL.value:
                 cond = ~Filter(dimension=col, value="")
 
             if filters:
@@ -1627,14 +1627,14 @@ class DruidDatasource(Model, BaseDatasource):
 
     def _get_having_obj(self, col: str, op: str, eq: str) -> "Having":
         cond = None
-        if op == FilterOperationType.EQUALS.value:
+        if op == FilterOperator.EQUALS.value:
             if col in self.column_names:
                 cond = DimSelector(dimension=col, value=eq)
             else:
                 cond = Aggregation(col) == eq
-        elif op == FilterOperationType.GREATER_THAN.value:
+        elif op == FilterOperator.GREATER_THAN.value:
             cond = Aggregation(col) > eq
-        elif op == FilterOperationType.LESS_THAN.value:
+        elif op == FilterOperator.LESS_THAN.value:
             cond = Aggregation(col) < eq
 
         return cond
@@ -1642,9 +1642,9 @@ class DruidDatasource(Model, BaseDatasource):
     def get_having_filters(self, raw_filters: List[Dict[str, Any]]) -> "Having":
         filters = None
         reversed_op_map = {
-            FilterOperationType.NOT_EQUALS.value: FilterOperationType.EQUALS.value,
-            FilterOperationType.GREATER_THAN_OR_EQUALS.value: FilterOperationType.LESS_THAN.value,
-            FilterOperationType.LESS_THAN_OR_EQUALS.value: FilterOperationType.GREATER_THAN.value,
+            FilterOperator.NOT_EQUALS.value: FilterOperator.EQUALS.value,
+            FilterOperator.GREATER_THAN_OR_EQUALS.value: FilterOperator.LESS_THAN.value,
+            FilterOperator.LESS_THAN_OR_EQUALS.value: FilterOperator.GREATER_THAN.value,
         }
 
         for flt in raw_filters:
@@ -1655,9 +1655,9 @@ class DruidDatasource(Model, BaseDatasource):
             eq = flt["val"]
             cond = None
             if op in [
-                FilterOperationType.EQUALS.value,
-                FilterOperationType.GREATER_THAN.value,
-                FilterOperationType.LESS_THAN.value,
+                FilterOperator.EQUALS.value,
+                FilterOperator.GREATER_THAN.value,
+                FilterOperator.LESS_THAN.value,
             ]:
                 cond = self._get_having_obj(col, op, eq)
             elif op in reversed_op_map:
diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py
index 8aac7e0..1eed4be 100644
--- a/superset/connectors/sqla/models.py
+++ b/superset/connectors/sqla/models.py
@@ -847,8 +847,8 @@ class SqlaTable(Model, BaseDatasource):
             col_obj = cols.get(col)
             if col_obj:
                 is_list_target = op in (
-                    utils.FilterOperationType.IN.value,
-                    utils.FilterOperationType.NOT_IN.value,
+                    utils.FilterOperator.IN.value,
+                    utils.FilterOperator.NOT_IN.value,
                 )
                 eq = self.filter_values_handler(
                     values=flt.get("val"),
@@ -856,36 +856,36 @@ class SqlaTable(Model, BaseDatasource):
                     is_list_target=is_list_target,
                 )
                 if op in (
-                    utils.FilterOperationType.IN.value,
-                    utils.FilterOperationType.NOT_IN.value,
+                    utils.FilterOperator.IN.value,
+                    utils.FilterOperator.NOT_IN.value,
                 ):
                     cond = col_obj.get_sqla_col().in_(eq)
                     if isinstance(eq, str) and NULL_STRING in eq:
                         cond = or_(cond, col_obj.get_sqla_col() is None)
-                    if op == utils.FilterOperationType.NOT_IN.value:
+                    if op == utils.FilterOperator.NOT_IN.value:
                         cond = ~cond
                     where_clause_and.append(cond)
                 else:
                     if col_obj.is_numeric:
                         eq = utils.cast_to_num(flt["val"])
-                    if op == utils.FilterOperationType.EQUALS.value:
+                    if op == utils.FilterOperator.EQUALS.value:
                         where_clause_and.append(col_obj.get_sqla_col() == eq)
-                    elif op == utils.FilterOperationType.NOT_EQUALS.value:
+                    elif op == utils.FilterOperator.NOT_EQUALS.value:
                         where_clause_and.append(col_obj.get_sqla_col() != eq)
-                    elif op == utils.FilterOperationType.GREATER_THAN.value:
+                    elif op == utils.FilterOperator.GREATER_THAN.value:
                         where_clause_and.append(col_obj.get_sqla_col() > eq)
-                    elif op == utils.FilterOperationType.LESS_THAN.value:
+                    elif op == utils.FilterOperator.LESS_THAN.value:
                         where_clause_and.append(col_obj.get_sqla_col() < eq)
-                    elif op == utils.FilterOperationType.GREATER_THAN_OR_EQUALS.value:
+                    elif op == utils.FilterOperator.GREATER_THAN_OR_EQUALS.value:
                         where_clause_and.append(col_obj.get_sqla_col() >= eq)
-                    elif op == utils.FilterOperationType.LESS_THAN_OR_EQUALS.value:
+                    elif op == utils.FilterOperator.LESS_THAN_OR_EQUALS.value:
                         where_clause_and.append(col_obj.get_sqla_col() <= eq)
-                    elif op == utils.FilterOperationType.LIKE.value:
+                    elif op == utils.FilterOperator.LIKE.value:
                         where_clause_and.append(col_obj.get_sqla_col().like(eq))
-                    elif op == utils.FilterOperationType.IS_NULL.value:
-                        where_clause_and.append(col_obj.get_sqla_col() is None)
-                    elif op == utils.FilterOperationType.IS_NOT_NULL.value:
-                        where_clause_and.append(col_obj.get_sqla_col() is None)
+                    elif op == utils.FilterOperator.IS_NULL.value:
+                        where_clause_and.append(col_obj.get_sqla_col() == None)
+                    elif op == utils.FilterOperator.IS_NOT_NULL.value:
+                        where_clause_and.append(col_obj.get_sqla_col() != None)
                     else:
                         raise Exception(
                             _("Invalid filter operation type: %(op)s", op=op)
diff --git a/superset/utils/core.py b/superset/utils/core.py
index 0c838b7..ba715dd 100644
--- a/superset/utils/core.py
+++ b/superset/utils/core.py
@@ -1350,9 +1350,9 @@ class DbColumnType(Enum):
     TEMPORAL = 2
 
 
-class FilterOperationType(str, Enum):
+class FilterOperator(str, Enum):
     """
-    Filter operation type
+    Operators used filter controls
     """
 
     EQUALS = "=="
diff --git a/tests/sqla_models_tests.py b/tests/sqla_models_tests.py
index 700c2e2..af15b26 100644
--- a/tests/sqla_models_tests.py
+++ b/tests/sqla_models_tests.py
@@ -15,12 +15,12 @@
 # specific language governing permissions and limitations
 # under the License.
 # isort:skip_file
-from typing import Dict
+from typing import Any, Dict, NamedTuple, List, Tuple, Union
 
 from superset.connectors.sqla.models import SqlaTable, TableColumn
 from superset.db_engine_specs.druid import DruidEngineSpec
 from superset.models.core import Database
-from superset.utils.core import DbColumnType, get_example_database
+from superset.utils.core import DbColumnType, get_example_database, FilterOperator
 
 from .base_tests import SupersetTestCase
 
@@ -109,3 +109,39 @@ class DatabaseModelTestCase(SupersetTestCase):
         extra_cache_keys = table.get_extra_cache_keys(query_obj)
         self.assertFalse(table.has_calls_to_cache_key_wrapper(query_obj))
         self.assertListEqual(extra_cache_keys, [])
+
+    def test_where_operators(self):
+        class FilterTestCase(NamedTuple):
+            operator: str
+            value: Union[float, int, List[Any], str]
+            expected: str
+
+        filters: Tuple[FilterTestCase, ...] = (
+            FilterTestCase(FilterOperator.IS_NULL, "", "IS NULL"),
+            FilterTestCase(FilterOperator.IS_NOT_NULL, "", "IS NOT NULL"),
+            FilterTestCase(FilterOperator.GREATER_THAN, 0, "> 0"),
+            FilterTestCase(FilterOperator.GREATER_THAN_OR_EQUALS, 0, ">= 0"),
+            FilterTestCase(FilterOperator.LESS_THAN, 0, "< 0"),
+            FilterTestCase(FilterOperator.LESS_THAN_OR_EQUALS, 0, "<= 0"),
+            FilterTestCase(FilterOperator.EQUALS, 0, "= 0"),
+            FilterTestCase(FilterOperator.NOT_EQUALS, 0, "!= 0"),
+            FilterTestCase(FilterOperator.IN, ["1", "2"], "IN (1, 2)"),
+            FilterTestCase(FilterOperator.NOT_IN, ["1", "2"], "NOT IN (1, 2)"),
+        )
+        table = self.get_table_by_name("birth_names")
+        for filter_ in filters:
+            query_obj = {
+                "granularity": None,
+                "from_dttm": None,
+                "to_dttm": None,
+                "groupby": ["gender"],
+                "metrics": ["count"],
+                "is_timeseries": False,
+                "filter": [
+                    {"col": "num", "op": filter_.operator, "val": filter_.value}
+                ],
+                "extras": {},
+            }
+            sqla_query = table.get_sqla_query(**query_obj)
+            sql = table.database.compile_sqla_query(sqla_query.sqla_query)
+            self.assertIn(filter_.expected, sql)

Reply via email to