This is an automated email from the ASF dual-hosted git repository.

yongjiezhao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/superset.git


The following commit(s) were added to refs/heads/master by this push:
     new 05648eb489 feat: support None operand in EQUAL operator (#21713)
05648eb489 is described below

commit 05648eb489a7b5eec1c452cf1f037566dd942505
Author: Yongjie Zhao <[email protected]>
AuthorDate: Thu Oct 6 16:45:59 2022 +0800

    feat: support None operand in EQUAL operator (#21713)
---
 superset/charts/schemas.py                   |   3 +-
 superset/connectors/sqla/models.py           |   9 +-
 tests/integration_tests/sqla_models_tests.py | 120 ++++++++++++++++++---------
 3 files changed, 90 insertions(+), 42 deletions(-)

diff --git a/superset/charts/schemas.py b/superset/charts/schemas.py
index ffd64e8bb0..34dd44b38c 100644
--- a/superset/charts/schemas.py
+++ b/superset/charts/schemas.py
@@ -819,7 +819,8 @@ class ChartDataFilterSchema(Schema):
     )
     val = fields.Raw(
         description="The value or values to compare against. Can be a string, "
-        "integer, decimal or list, depending on the operator.",
+        "integer, decimal, None or list, depending on the operator.",
+        allow_none=True,
         example=["China", "France", "Japan"],
     )
     grain = fields.String(
diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py
index c09b687023..3b7a786441 100644
--- a/superset/connectors/sqla/models.py
+++ b/superset/connectors/sqla/models.py
@@ -1615,7 +1615,14 @@ class SqlaTable(Model, BaseDatasource):  # pylint: disable=too-many-public-metho
                 elif op == utils.FilterOperator.IS_FALSE.value:
                     where_clause_and.append(sqla_col.is_(False))
                 else:
-                    if eq is None:
+                    if (
+                        op
+                        not in {
+                            utils.FilterOperator.EQUALS.value,
+                            utils.FilterOperator.NOT_EQUALS.value,
+                        }
+                        and eq is None
+                    ):
                         raise QueryObjectValidationError(
                             _(
                                 "Must specify a value for filters "
diff --git a/tests/integration_tests/sqla_models_tests.py b/tests/integration_tests/sqla_models_tests.py
index 047228355d..fc37e17b57 100644
--- a/tests/integration_tests/sqla_models_tests.py
+++ b/tests/integration_tests/sqla_models_tests.py
@@ -51,6 +51,7 @@ from tests.integration_tests.fixtures.birth_names_dashboard import (
 from tests.integration_tests.test_app import app
 
 from .base_tests import SupersetTestCase
+from .conftest import only_postgresql
 
 VIRTUAL_TABLE_INT_TYPES: Dict[str, Pattern[str]] = {
     "hive": re.compile(r"^INT_TYPE$"),
@@ -659,51 +660,90 @@ def test_filter_on_text_column(text_column_table):
     assert result_object.df["count"][0] == 1
 
 
-def test_should_generate_closed_and_open_time_filter_range():
-    with app.app_context():
-        if backend() != "postgresql":
-            pytest.skip(f"{backend()} has different dialect for datetime column")
-
-        table = SqlaTable(
-            table_name="temporal_column_table",
-            sql=(
-                "SELECT '2021-12-31'::timestamp as datetime_col "
-                "UNION SELECT '2022-01-01'::timestamp "
-                "UNION SELECT '2022-03-10'::timestamp "
-                "UNION SELECT '2023-01-01'::timestamp "
-                "UNION SELECT '2023-03-10'::timestamp "
-            ),
-            database=get_example_database(),
-        )
-        TableColumn(
-            column_name="datetime_col",
-            type="TIMESTAMP",
-            table=table,
-            is_dttm=True,
-        )
-        SqlMetric(metric_name="count", expression="count(*)", table=table)
-        result_object = table.query(
+@only_postgresql
+def test_should_generate_closed_and_open_time_filter_range(login_as_admin):
+    table = SqlaTable(
+        table_name="temporal_column_table",
+        sql=(
+            "SELECT '2021-12-31'::timestamp as datetime_col "
+            "UNION SELECT '2022-01-01'::timestamp "
+            "UNION SELECT '2022-03-10'::timestamp "
+            "UNION SELECT '2023-01-01'::timestamp "
+            "UNION SELECT '2023-03-10'::timestamp "
+        ),
+        database=get_example_database(),
+    )
+    TableColumn(
+        column_name="datetime_col",
+        type="TIMESTAMP",
+        table=table,
+        is_dttm=True,
+    )
+    SqlMetric(metric_name="count", expression="count(*)", table=table)
+    result_object = table.query(
+        {
+            "metrics": ["count"],
+            "is_timeseries": False,
+            "filter": [],
+            "from_dttm": datetime(2022, 1, 1),
+            "to_dttm": datetime(2023, 1, 1),
+            "granularity": "datetime_col",
+        }
+    )
+    """ >>> result_object.query
+            SELECT count(*) AS count
+            FROM
+              (SELECT '2021-12-31'::timestamp as datetime_col
+               UNION SELECT '2022-01-01'::timestamp
+               UNION SELECT '2022-03-10'::timestamp
+               UNION SELECT '2023-01-01'::timestamp
+               UNION SELECT '2023-03-10'::timestamp) AS virtual_table
+            WHERE datetime_col >= TO_TIMESTAMP('2022-01-01 00:00:00.000000', 'YYYY-MM-DD HH24:MI:SS.US')
+              AND datetime_col < TO_TIMESTAMP('2023-01-01 00:00:00.000000', 'YYYY-MM-DD HH24:MI:SS.US')
+    """
+    assert result_object.df.iloc[0]["count"] == 2
+
+
+def test_none_operand_in_filter(login_as_admin, physical_dataset):
+    expected_results = [
+        {
+            "operator": FilterOperator.EQUALS.value,
+            "count": 10,
+            "sql_should_contain": "COL4 IS NULL",
+        },
+        {
+            "operator": FilterOperator.NOT_EQUALS.value,
+            "count": 0,
+            "sql_should_contain": "COL4 IS NOT NULL",
+        },
+    ]
+    for expected in expected_results:
+        result = physical_dataset.query(
             {
                 "metrics": ["count"],
+                "filter": [{"col": "col4", "val": None, "op": expected["operator"]}],
                 "is_timeseries": False,
-                "filter": [],
-                "from_dttm": datetime(2022, 1, 1),
-                "to_dttm": datetime(2023, 1, 1),
-                "granularity": "datetime_col",
             }
         )
-        """ >>> result_object.query
-                SELECT count(*) AS count
-                FROM
-                  (SELECT '2021-12-31'::timestamp as datetime_col
-                   UNION SELECT '2022-01-01'::timestamp
-                   UNION SELECT '2022-03-10'::timestamp
-                   UNION SELECT '2023-01-01'::timestamp
-                   UNION SELECT '2023-03-10'::timestamp) AS virtual_table
-                WHERE datetime_col >= TO_TIMESTAMP('2022-01-01 00:00:00.000000', 'YYYY-MM-DD HH24:MI:SS.US')
-                  AND datetime_col < TO_TIMESTAMP('2023-01-01 00:00:00.000000', 'YYYY-MM-DD HH24:MI:SS.US')
-        """
-        assert result_object.df.iloc[0]["count"] == 2
+        assert result.df["count"][0] == expected["count"]
+        assert expected["sql_should_contain"] in result.query.upper()
+
+    with pytest.raises(QueryObjectValidationError):
+        for flt in [
+            FilterOperator.GREATER_THAN,
+            FilterOperator.LESS_THAN,
+            FilterOperator.GREATER_THAN_OR_EQUALS,
+            FilterOperator.LESS_THAN_OR_EQUALS,
+            FilterOperator.LIKE,
+            FilterOperator.ILIKE,
+        ]:
+            physical_dataset.query(
+                {
+                    "metrics": ["count"],
+                    "filter": [{"col": "col4", "val": None, "op": flt.value}],
+                    "is_timeseries": False,
+                }
+            )
 
 
 @pytest.mark.parametrize(

Reply via email to