This is an automated email from the ASF dual-hosted git repository.
weilee pushed a commit to branch v2-10-test
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/v2-10-test by this push:
new f58fb2011c5 fix(XCom): /xcom/list got exception when applying filter
on the value column (#46053)
f58fb2011c5 is described below
commit f58fb2011c535b3d741f3d1fdb49e90034e81766
Author: Castle Cheng <[email protected]>
AuthorDate: Fri Feb 7 10:28:28 2025 +0800
fix(XCom): /xcom/list got exception when applying filter on the value
column (#46053)
* fix(XCom): /xcom/list got exception when applying filter on the value
column
#42720
* fix(XCom): fix the query in XComFilterNotEndsWith
reduce code duplication and add type annotation
fix
fix
fix
* Add unit test for XCom filter
use @pytest.mark.parametrize
test all kinds of XComFilter
fix unit test parametrize
---------
Co-authored-by: josix <[email protected]>
---
airflow/www/utils.py | 125 ++++++++++++++++++++++++++++++++++++++++++++++++
tests/www/test_utils.py | 109 ++++++++++++++++++++++++++++++++++++++++-
2 files changed, 233 insertions(+), 1 deletion(-)
diff --git a/airflow/www/utils.py b/airflow/www/utils.py
index 70e942329ca..ac65ffb0038 100644
--- a/airflow/www/utils.py
+++ b/airflow/www/utils.py
@@ -59,8 +59,10 @@ if TYPE_CHECKING:
from flask_appbuilder.models.sqla import Model
from pendulum.datetime import DateTime
from pygments.lexer import Lexer
+ from sqlalchemy.orm.query import Query
from sqlalchemy.orm.session import Session
from sqlalchemy.sql import Select
+ from sqlalchemy.sql.elements import ColumnElement
from sqlalchemy.sql.operators import ColumnOperators
from airflow.www.extensions.init_appbuilder import AirflowAppBuilder
@@ -682,6 +684,14 @@ def get_attr_renderer():
}
+def generate_filter_value_query(
+ *, query: Query, model: Model, column_name: str, filter_cond:
Callable[[ColumnElement], Query]
+) -> Query:
+ query, field = get_field_setup_query(query, model, column_name)
+ trimmed_value = func.btrim(func.convert_from(field, "UTF8"), '"')
+ return query.filter(filter_cond(trimmed_value))
+
+
class UtcAwareFilterMixin:
"""Mixin for filter for UTC time."""
@@ -779,6 +789,104 @@ class
UtcAwareFilterConverter(fab_sqlafilters.SQLAFilterConverter):
"""Retrieve conversion tables for UTC-Aware filters."""
+class XComFilterStartsWith(fab_sqlafilters.FilterStartsWith):
+ """Starts With filter for XCom values."""
+
+ def apply(self, query: Query, value: str) -> Query:
+ return generate_filter_value_query(
+ query=query,
+ model=self.model,
+ column_name=self.column_name,
+ filter_cond=lambda trimmed_value: trimmed_value.ilike(f"{value}%"),
+ )
+
+
+class XComFilterEndsWith(fab_sqlafilters.FilterEndsWith):
+ """Ends With filter for XCom values."""
+
+ def apply(self, query: Query, value: str) -> Query:
+ return generate_filter_value_query(
+ query=query,
+ model=self.model,
+ column_name=self.column_name,
+ filter_cond=lambda trimmed_value: trimmed_value.ilike(f"%{value}"),
+ )
+
+
+class XComFilterEqual(fab_sqlafilters.FilterEqual):
+ """Equality filter for XCom values."""
+
+ def apply(self, query: Query, value: str) -> Query:
+ value = set_value_to_type(self.datamodel, self.column_name, value)
+ return generate_filter_value_query(
+ query=query,
+ model=self.model,
+ column_name=self.column_name,
+ filter_cond=lambda trimmed_value: trimmed_value == value,
+ )
+
+
+class XComFilterContains(fab_sqlafilters.FilterContains):
+ """Contains filter for XCom values."""
+
+ def apply(self, query: Query, value: str) -> Query:
+ return generate_filter_value_query(
+ query=query,
+ model=self.model,
+ column_name=self.column_name,
+ filter_cond=lambda trimmed_value:
trimmed_value.ilike(f"%{value}%"),
+ )
+
+
+class XComFilterNotStartsWith(fab_sqlafilters.FilterNotStartsWith):
+ """Not Starts With filter for XCom values."""
+
+ def apply(self, query: Query, value: str) -> Query:
+ return generate_filter_value_query(
+ query=query,
+ model=self.model,
+ column_name=self.column_name,
+ filter_cond=lambda trimmed_value:
~trimmed_value.ilike(f"{value}%"),
+ )
+
+
+class XComFilterNotEndsWith(fab_sqlafilters.FilterNotEndsWith):
+ """Not Ends With filter for XCom values."""
+
+ def apply(self, query: Query, value: str) -> Query:
+ return generate_filter_value_query(
+ query=query,
+ model=self.model,
+ column_name=self.column_name,
+ filter_cond=lambda trimmed_value:
~trimmed_value.ilike(f"%{value}"),
+ )
+
+
+class XComFilterNotContains(fab_sqlafilters.FilterNotContains):
+ """Not Contains filter for XCom values."""
+
+ def apply(self, query: Query, value: str) -> Query:
+ return generate_filter_value_query(
+ query=query,
+ model=self.model,
+ column_name=self.column_name,
+ filter_cond=lambda trimmed_value:
~trimmed_value.ilike(f"%{value}%"),
+ )
+
+
+class XComFilterNotEqual(fab_sqlafilters.FilterNotEqual):
+ """Not Equal To filter for XCom values."""
+
+ def apply(self, query: Query, value: str) -> Query:
+ value = set_value_to_type(self.datamodel, self.column_name, value)
+ return generate_filter_value_query(
+ query=query,
+ model=self.model,
+ column_name=self.column_name,
+ filter_cond=lambda trimmed_value: trimmed_value != value,
+ )
+
+
class AirflowFilterConverter(fab_sqlafilters.SQLAFilterConverter):
"""Retrieve conversion tables for Airflow-specific filters."""
@@ -800,6 +908,19 @@ class
AirflowFilterConverter(fab_sqlafilters.SQLAFilterConverter):
"is_extendedjson",
[],
),
+ (
+ "is_xcom_value",
+ [
+ XComFilterStartsWith,
+ XComFilterEndsWith,
+ XComFilterEqual,
+ XComFilterContains,
+ XComFilterNotStartsWith,
+ XComFilterNotEndsWith,
+ XComFilterNotContains,
+ XComFilterNotEqual,
+ ],
+ ),
*fab_sqlafilters.SQLAFilterConverter.conversion_table,
)
@@ -864,6 +985,10 @@ class CustomSQLAInterface(SQLAInterface):
)
return False
+ def is_xcom_value(self, col_name: str) -> bool:
+ """Check if col_name is the value column of the xcom table."""
+ return col_name == "value" and self.obj.__tablename__ == "xcom"
+
def get_col_default(self, col_name: str) -> Any:
if col_name not in self.list_columns:
# Handle AssociationProxy etc, or anything that isn't a "real"
column
diff --git a/tests/www/test_utils.py b/tests/www/test_utils.py
index 669f08a2b61..8ad6ef3ca47 100644
--- a/tests/www/test_utils.py
+++ b/tests/www/test_utils.py
@@ -22,7 +22,8 @@ import os
import re
import time
from datetime import datetime
-from unittest.mock import Mock
+from typing import TYPE_CHECKING, Callable
+from unittest.mock import MagicMock, Mock
from urllib.parse import parse_qs
import pendulum
@@ -31,6 +32,7 @@ from bs4 import BeautifulSoup
from flask_appbuilder.models.sqla.filters import get_field_setup_query,
set_value_to_type
from flask_wtf import FlaskForm
from markupsafe import Markup
+from sqlalchemy import func
from sqlalchemy.orm import Query
from wtforms.fields import StringField, TextAreaField
@@ -47,6 +49,9 @@ from airflow.www.utils import (
from airflow.www.widgets import AirflowDateTimePickerROWidget,
BS3TextAreaROWidget, BS3TextFieldROWidget
from tests.test_utils.config import conf_vars
+if TYPE_CHECKING:
+ from sqlalchemy.sql.elements import ColumnElement
+
class TestUtils:
def check_generate_pages_html(
@@ -663,6 +668,108 @@ class TestFilter:
assert result_query_filter == self.mock_query
+class TestXComFilter:
+ def setup_method(self):
+ self.mock_datamodel = MagicMock()
+ self.mock_query = MagicMock(spec=Query)
+ self.mock_column_name = "test_column"
+
+ def _assert_filter_query(
+ self,
+ xcom_filter,
+ raw_value: str,
+ expected_expr_builder: Callable[[ColumnElement, str], ColumnElement],
+ convert_value: bool = False,
+ ) -> None:
+ """
+ A helper function to assert the filter query.
+
+ :param xcom_filter: The XCom filter instance (e.g.,
XComFilterStartsWith).
+ :param raw_value: The raw string value we want to filter on.
+ :param expected_expr_builder: A function that takes in
`returned_field` and the filter value and returns the expected SQL expression.
+ :param convert_value: Whether to run `set_value_to_type(...)` on the
raw_value.
+ """
+ returned_query, returned_field = get_field_setup_query(
+ self.mock_query, self.mock_datamodel, self.mock_column_name
+ )
+
+ if convert_value:
+ value = set_value_to_type(self.mock_datamodel,
self.mock_column_name, raw_value)
+ else:
+ value = raw_value
+ xcom_filter.apply(self.mock_query, value)
+ self.mock_query.filter.assert_called_once()
+ actual_filter_arg = self.mock_query.filter.call_args[0][0]
+ expected_filter_arg = expected_expr_builder(returned_field, value)
+ assert str(actual_filter_arg) == str(expected_filter_arg)
+
+ @pytest.mark.parametrize(
+ "filter_class, convert_value, expected_expr_builder",
+ [
+ (
+ utils.XComFilterStartsWith,
+ False,
+ lambda field, v: func.btrim(func.convert_from(field, "UTF8"),
'"').ilike(f"{v}%"),
+ ),
+ (
+ utils.XComFilterEndsWith,
+ False,
+ lambda field, v: func.btrim(func.convert_from(field, "UTF8"),
'"').ilike(f"%{v}"),
+ ),
+ (
+ utils.XComFilterEqual,
+ True,
+ lambda field, v: func.btrim(func.convert_from(field, "UTF8"),
'"') == v,
+ ),
+ (
+ utils.XComFilterContains,
+ False,
+ lambda field, v: func.btrim(func.convert_from(field, "UTF8"),
'"').ilike(f"%{v}%"),
+ ),
+ (
+ utils.XComFilterNotStartsWith,
+ False,
+ lambda field, v: ~func.btrim(func.convert_from(field, "UTF8"),
'"').ilike(f"{v}%"),
+ ),
+ (
+ utils.XComFilterNotEndsWith,
+ False,
+ lambda field, v: ~func.btrim(func.convert_from(field, "UTF8"),
'"').ilike(f"%{v}"),
+ ),
+ (
+ utils.XComFilterNotContains,
+ False,
+ lambda field, v: ~func.btrim(func.convert_from(field, "UTF8"),
'"').ilike(f"%{v}%"),
+ ),
+ (
+ utils.XComFilterNotEqual,
+ True,
+ lambda field, v: func.btrim(func.convert_from(field, "UTF8"),
'"') != v,
+ ),
+ ],
+ ids=[
+ "StartsWith",
+ "EndsWith",
+ "Equal",
+ "Contains",
+ "NotStartsWith",
+ "NotEndsWith",
+ "NotContains",
+ "NotEqual",
+ ],
+ )
+ def test_xcom_filters(self, filter_class, convert_value,
expected_expr_builder):
+ xcom_filter_query = filter_class(datamodel=self.mock_datamodel,
column_name=self.mock_column_name)
+ raw_value = "test_value"
+
+ self._assert_filter_query(
+ xcom_filter_query,
+ raw_value=raw_value,
+ expected_expr_builder=expected_expr_builder,
+ convert_value=convert_value,
+ )
+
+
@pytest.mark.db_test
def test_get_col_default_not_existing(session):
interface = CustomSQLAInterface(obj=DagRun, session=session)