This is an automated email from the ASF dual-hosted git repository.
yongjiezhao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/superset.git
The following commit(s) were added to refs/heads/master by this push:
new 05648eb489 feat: support None operand in EQUAL operator (#21713)
05648eb489 is described below
commit 05648eb489a7b5eec1c452cf1f037566dd942505
Author: Yongjie Zhao <[email protected]>
AuthorDate: Thu Oct 6 16:45:59 2022 +0800
feat: support None operand in EQUAL operator (#21713)
---
superset/charts/schemas.py | 3 +-
superset/connectors/sqla/models.py | 9 +-
tests/integration_tests/sqla_models_tests.py | 120 ++++++++++++++++++---------
3 files changed, 90 insertions(+), 42 deletions(-)
diff --git a/superset/charts/schemas.py b/superset/charts/schemas.py
index ffd64e8bb0..34dd44b38c 100644
--- a/superset/charts/schemas.py
+++ b/superset/charts/schemas.py
@@ -819,7 +819,8 @@ class ChartDataFilterSchema(Schema):
)
val = fields.Raw(
description="The value or values to compare against. Can be a string, "
- "integer, decimal or list, depending on the operator.",
+ "integer, decimal, None or list, depending on the operator.",
+ allow_none=True,
example=["China", "France", "Japan"],
)
grain = fields.String(
diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py
index c09b687023..3b7a786441 100644
--- a/superset/connectors/sqla/models.py
+++ b/superset/connectors/sqla/models.py
@@ -1615,7 +1615,14 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
elif op == utils.FilterOperator.IS_FALSE.value:
where_clause_and.append(sqla_col.is_(False))
else:
- if eq is None:
+ if (
+ op
+ not in {
+ utils.FilterOperator.EQUALS.value,
+ utils.FilterOperator.NOT_EQUALS.value,
+ }
+ and eq is None
+ ):
raise QueryObjectValidationError(
_(
"Must specify a value for filters "
diff --git a/tests/integration_tests/sqla_models_tests.py b/tests/integration_tests/sqla_models_tests.py
index 047228355d..fc37e17b57 100644
--- a/tests/integration_tests/sqla_models_tests.py
+++ b/tests/integration_tests/sqla_models_tests.py
@@ -51,6 +51,7 @@ from tests.integration_tests.fixtures.birth_names_dashboard import (
from tests.integration_tests.test_app import app
from .base_tests import SupersetTestCase
+from .conftest import only_postgresql
VIRTUAL_TABLE_INT_TYPES: Dict[str, Pattern[str]] = {
"hive": re.compile(r"^INT_TYPE$"),
@@ -659,51 +660,90 @@ def test_filter_on_text_column(text_column_table):
assert result_object.df["count"][0] == 1
-def test_should_generate_closed_and_open_time_filter_range():
- with app.app_context():
- if backend() != "postgresql":
- pytest.skip(f"{backend()} has different dialect for datetime column")
-
- table = SqlaTable(
- table_name="temporal_column_table",
- sql=(
- "SELECT '2021-12-31'::timestamp as datetime_col "
- "UNION SELECT '2022-01-01'::timestamp "
- "UNION SELECT '2022-03-10'::timestamp "
- "UNION SELECT '2023-01-01'::timestamp "
- "UNION SELECT '2023-03-10'::timestamp "
- ),
- database=get_example_database(),
- )
- TableColumn(
- column_name="datetime_col",
- type="TIMESTAMP",
- table=table,
- is_dttm=True,
- )
- SqlMetric(metric_name="count", expression="count(*)", table=table)
- result_object = table.query(
+@only_postgresql
+def test_should_generate_closed_and_open_time_filter_range(login_as_admin):
+ table = SqlaTable(
+ table_name="temporal_column_table",
+ sql=(
+ "SELECT '2021-12-31'::timestamp as datetime_col "
+ "UNION SELECT '2022-01-01'::timestamp "
+ "UNION SELECT '2022-03-10'::timestamp "
+ "UNION SELECT '2023-01-01'::timestamp "
+ "UNION SELECT '2023-03-10'::timestamp "
+ ),
+ database=get_example_database(),
+ )
+ TableColumn(
+ column_name="datetime_col",
+ type="TIMESTAMP",
+ table=table,
+ is_dttm=True,
+ )
+ SqlMetric(metric_name="count", expression="count(*)", table=table)
+ result_object = table.query(
+ {
+ "metrics": ["count"],
+ "is_timeseries": False,
+ "filter": [],
+ "from_dttm": datetime(2022, 1, 1),
+ "to_dttm": datetime(2023, 1, 1),
+ "granularity": "datetime_col",
+ }
+ )
+ """ >>> result_object.query
+ SELECT count(*) AS count
+ FROM
+ (SELECT '2021-12-31'::timestamp as datetime_col
+ UNION SELECT '2022-01-01'::timestamp
+ UNION SELECT '2022-03-10'::timestamp
+ UNION SELECT '2023-01-01'::timestamp
+ UNION SELECT '2023-03-10'::timestamp) AS virtual_table
+ WHERE datetime_col >= TO_TIMESTAMP('2022-01-01 00:00:00.000000', 'YYYY-MM-DD HH24:MI:SS.US')
+ AND datetime_col < TO_TIMESTAMP('2023-01-01 00:00:00.000000', 'YYYY-MM-DD HH24:MI:SS.US')
+ """
+ assert result_object.df.iloc[0]["count"] == 2
+
+
+def test_none_operand_in_filter(login_as_admin, physical_dataset):
+ expected_results = [
+ {
+ "operator": FilterOperator.EQUALS.value,
+ "count": 10,
+ "sql_should_contain": "COL4 IS NULL",
+ },
+ {
+ "operator": FilterOperator.NOT_EQUALS.value,
+ "count": 0,
+ "sql_should_contain": "COL4 IS NOT NULL",
+ },
+ ]
+ for expected in expected_results:
+ result = physical_dataset.query(
{
"metrics": ["count"],
+ "filter": [{"col": "col4", "val": None, "op": expected["operator"]}],
"is_timeseries": False,
- "filter": [],
- "from_dttm": datetime(2022, 1, 1),
- "to_dttm": datetime(2023, 1, 1),
- "granularity": "datetime_col",
}
)
- """ >>> result_object.query
- SELECT count(*) AS count
- FROM
- (SELECT '2021-12-31'::timestamp as datetime_col
- UNION SELECT '2022-01-01'::timestamp
- UNION SELECT '2022-03-10'::timestamp
- UNION SELECT '2023-01-01'::timestamp
- UNION SELECT '2023-03-10'::timestamp) AS virtual_table
- WHERE datetime_col >= TO_TIMESTAMP('2022-01-01 00:00:00.000000', 'YYYY-MM-DD HH24:MI:SS.US')
- AND datetime_col < TO_TIMESTAMP('2023-01-01 00:00:00.000000', 'YYYY-MM-DD HH24:MI:SS.US')
- """
- assert result_object.df.iloc[0]["count"] == 2
+ assert result.df["count"][0] == expected["count"]
+ assert expected["sql_should_contain"] in result.query.upper()
+
+ with pytest.raises(QueryObjectValidationError):
+ for flt in [
+ FilterOperator.GREATER_THAN,
+ FilterOperator.LESS_THAN,
+ FilterOperator.GREATER_THAN_OR_EQUALS,
+ FilterOperator.LESS_THAN_OR_EQUALS,
+ FilterOperator.LIKE,
+ FilterOperator.ILIKE,
+ ]:
+ physical_dataset.query(
+ {
+ "metrics": ["count"],
+ "filter": [{"col": "col4", "val": None, "op": flt.value}],
+ "is_timeseries": False,
+ }
+ )
@pytest.mark.parametrize(