This is an automated email from the ASF dual-hosted git repository.

weilee pushed a commit to branch v2-10-test
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/v2-10-test by this push:
     new f58fb2011c5 fix(XCom): /xcom/list got exception when applying filter on the value column (#46053)
f58fb2011c5 is described below

commit f58fb2011c535b3d741f3d1fdb49e90034e81766
Author: Castle Cheng <[email protected]>
AuthorDate: Fri Feb 7 10:28:28 2025 +0800

    fix(XCom): /xcom/list got exception when applying filter on the value column (#46053)
    
    * fix(XCom): /xcom/list got exception when applying filter on the value column
    
    #42720
    
    * fix(XCom): fix the query in XComFilterNotEndsWith
    
    reduce code duplication and add type annotation
    
    fix
    
    fix
    
    fix
    
    * Add unit test for XCom filter
    
    use @pytest.mark.parametrize
    test all kind of XComFilter
    
    fix unit test parametrize
    
    ---------
    
    Co-authored-by: josix <[email protected]>
---
 airflow/www/utils.py    | 125 ++++++++++++++++++++++++++++++++++++++++++++++++
 tests/www/test_utils.py | 109 ++++++++++++++++++++++++++++++++++++++++-
 2 files changed, 233 insertions(+), 1 deletion(-)

diff --git a/airflow/www/utils.py b/airflow/www/utils.py
index 70e942329ca..ac65ffb0038 100644
--- a/airflow/www/utils.py
+++ b/airflow/www/utils.py
@@ -59,8 +59,10 @@ if TYPE_CHECKING:
     from flask_appbuilder.models.sqla import Model
     from pendulum.datetime import DateTime
     from pygments.lexer import Lexer
+    from sqlalchemy.orm.query import Query
     from sqlalchemy.orm.session import Session
     from sqlalchemy.sql import Select
+    from sqlalchemy.sql.elements import ColumnElement
     from sqlalchemy.sql.operators import ColumnOperators
 
     from airflow.www.extensions.init_appbuilder import AirflowAppBuilder
@@ -682,6 +684,14 @@ def get_attr_renderer():
     }
 
 
+def generate_filter_value_query(  # filters on the UTF-8-decoded value with surrounding double quotes trimmed
+    *, query: Query, model: Model, column_name: str, filter_cond: Callable[[ColumnElement], ColumnElement]
+) -> Query:  # NOTE(review): btrim/convert_from look PostgreSQL-specific; confirm MySQL/SQLite support
+    query, field = get_field_setup_query(query, model, column_name)
+    trimmed_value = func.btrim(func.convert_from(field, "UTF8"), '"')
+    return query.filter(filter_cond(trimmed_value))
+
+
 class UtcAwareFilterMixin:
     """Mixin for filter for UTC time."""
 
@@ -779,6 +789,104 @@ class UtcAwareFilterConverter(fab_sqlafilters.SQLAFilterConverter):
     """Retrieve conversion tables for UTC-Aware filters."""
 
 
+class XComFilterStartsWith(fab_sqlafilters.FilterStartsWith):
+    """Case-insensitive "starts with" filter on the decoded, quote-trimmed XCom value."""
+
+    def apply(self, query: Query, value: str) -> Query:
+        return generate_filter_value_query(
+            query=query,
+            model=self.model,
+            column_name=self.column_name,
+            filter_cond=lambda trimmed_value: trimmed_value.ilike(f"{value}%"),
+        )
+
+
+class XComFilterEndsWith(fab_sqlafilters.FilterEndsWith):
+    """Case-insensitive "ends with" filter on the decoded, quote-trimmed XCom value."""
+
+    def apply(self, query: Query, value: str) -> Query:
+        return generate_filter_value_query(
+            query=query,
+            model=self.model,
+            column_name=self.column_name,
+            filter_cond=lambda trimmed_value: trimmed_value.ilike(f"%{value}"),
+        )
+
+
+class XComFilterEqual(fab_sqlafilters.FilterEqual):
+    """Equality filter on the decoded, quote-trimmed XCom value, after set_value_to_type conversion."""
+
+    def apply(self, query: Query, value: str) -> Query:
+        value = set_value_to_type(self.datamodel, self.column_name, value)
+        return generate_filter_value_query(
+            query=query,
+            model=self.model,
+            column_name=self.column_name,
+            filter_cond=lambda trimmed_value: trimmed_value == value,
+        )
+
+
+class XComFilterContains(fab_sqlafilters.FilterContains):
+    """Case-insensitive "contains" filter on the decoded, quote-trimmed XCom value."""
+
+    def apply(self, query: Query, value: str) -> Query:
+        return generate_filter_value_query(
+            query=query,
+            model=self.model,
+            column_name=self.column_name,
+            filter_cond=lambda trimmed_value: trimmed_value.ilike(f"%{value}%"),
+        )
+
+
+class XComFilterNotStartsWith(fab_sqlafilters.FilterNotStartsWith):
+    """Negated case-insensitive "starts with" filter on the decoded, quote-trimmed XCom value."""
+
+    def apply(self, query: Query, value: str) -> Query:
+        return generate_filter_value_query(
+            query=query,
+            model=self.model,
+            column_name=self.column_name,
+            filter_cond=lambda trimmed_value: ~trimmed_value.ilike(f"{value}%"),
+        )
+
+
+class XComFilterNotEndsWith(fab_sqlafilters.FilterNotEndsWith):
+    """Negated case-insensitive "ends with" filter on the decoded, quote-trimmed XCom value."""
+
+    def apply(self, query: Query, value: str) -> Query:
+        return generate_filter_value_query(
+            query=query,
+            model=self.model,
+            column_name=self.column_name,
+            filter_cond=lambda trimmed_value: ~trimmed_value.ilike(f"%{value}"),
+        )
+
+
+class XComFilterNotContains(fab_sqlafilters.FilterNotContains):
+    """Negated case-insensitive "contains" filter on the decoded, quote-trimmed XCom value."""
+
+    def apply(self, query: Query, value: str) -> Query:
+        return generate_filter_value_query(
+            query=query,
+            model=self.model,
+            column_name=self.column_name,
+            filter_cond=lambda trimmed_value: ~trimmed_value.ilike(f"%{value}%"),
+        )
+
+
+class XComFilterNotEqual(fab_sqlafilters.FilterNotEqual):
+    """Inequality filter on the decoded, quote-trimmed XCom value, after set_value_to_type conversion."""
+
+    def apply(self, query: Query, value: str) -> Query:
+        value = set_value_to_type(self.datamodel, self.column_name, value)
+        return generate_filter_value_query(
+            query=query,
+            model=self.model,
+            column_name=self.column_name,
+            filter_cond=lambda trimmed_value: trimmed_value != value,
+        )
+
+
 class AirflowFilterConverter(fab_sqlafilters.SQLAFilterConverter):
     """Retrieve conversion tables for Airflow-specific filters."""
 
@@ -800,6 +908,19 @@ class AirflowFilterConverter(fab_sqlafilters.SQLAFilterConverter):
             "is_extendedjson",
             [],
         ),
+        (
+            "is_xcom_value",
+            [
+                XComFilterStartsWith,
+                XComFilterEndsWith,
+                XComFilterEqual,
+                XComFilterContains,
+                XComFilterNotStartsWith,
+                XComFilterNotEndsWith,
+                XComFilterNotContains,
+                XComFilterNotEqual,
+            ],
+        ),
         *fab_sqlafilters.SQLAFilterConverter.conversion_table,
     )
 
@@ -864,6 +985,10 @@ class CustomSQLAInterface(SQLAInterface):
             )
         return False
 
+    def is_xcom_value(self, col_name: str) -> bool:
+        """Return True if ``col_name`` is the ``value`` column of the ``xcom`` table."""
+        return col_name == "value" and self.obj.__tablename__ == "xcom"
+
     def get_col_default(self, col_name: str) -> Any:
         if col_name not in self.list_columns:
             # Handle AssociationProxy etc, or anything that isn't a "real" 
column
diff --git a/tests/www/test_utils.py b/tests/www/test_utils.py
index 669f08a2b61..8ad6ef3ca47 100644
--- a/tests/www/test_utils.py
+++ b/tests/www/test_utils.py
@@ -22,7 +22,8 @@ import os
 import re
 import time
 from datetime import datetime
-from unittest.mock import Mock
+from typing import TYPE_CHECKING, Callable
+from unittest.mock import MagicMock, Mock
 from urllib.parse import parse_qs
 
 import pendulum
@@ -31,6 +32,7 @@ from bs4 import BeautifulSoup
 from flask_appbuilder.models.sqla.filters import get_field_setup_query, 
set_value_to_type
 from flask_wtf import FlaskForm
 from markupsafe import Markup
+from sqlalchemy import func
 from sqlalchemy.orm import Query
 from wtforms.fields import StringField, TextAreaField
 
@@ -47,6 +49,9 @@ from airflow.www.utils import (
 from airflow.www.widgets import AirflowDateTimePickerROWidget, 
BS3TextAreaROWidget, BS3TextFieldROWidget
 from tests.test_utils.config import conf_vars
 
+if TYPE_CHECKING:
+    from sqlalchemy.sql.elements import ColumnElement
+
 
 class TestUtils:
     def check_generate_pages_html(
@@ -663,6 +668,108 @@ class TestFilter:
         assert result_query_filter == self.mock_query
 
 
+class TestXComFilter:
+    def setup_method(self):
+        self.mock_datamodel = MagicMock()
+        # set_value_to_type() likely probes is_<type>(); bare MagicMocks answer truthy, so force False.
+        for type_probe in ("is_integer", "is_float", "is_boolean", "is_date", "is_datetime"):
+            getattr(self.mock_datamodel, type_probe).return_value = False
+        self.mock_query = MagicMock(spec=Query)
+        self.mock_column_name = "test_column"
+
+    def _assert_filter_query(
+        self,
+        xcom_filter,
+        raw_value: str,
+        expected_expr_builder: Callable[[ColumnElement, str], ColumnElement],
+        convert_value: bool = False,
+    ) -> None:
+        """
+        Apply ``xcom_filter`` to the mock query and compare the filter clause it emits.
+
+        :param xcom_filter: The XCom filter instance (e.g., XComFilterStartsWith).
+        :param raw_value: The raw string value passed to ``apply``, as a user would enter it.
+        :param expected_expr_builder: Builds the expected SQL expression from the field and the value.
+        :param convert_value: Whether the filter converts the value with ``set_value_to_type(...)``.
+        """
+        _, field = get_field_setup_query(self.mock_query, xcom_filter.model, self.mock_column_name)
+        expected_value = raw_value
+        if convert_value:
+            expected_value = set_value_to_type(self.mock_datamodel, self.mock_column_name, raw_value)
+        result = xcom_filter.apply(self.mock_query, raw_value)
+        self.mock_query.filter.assert_called_once()
+        assert result is self.mock_query.filter.return_value
+        actual_filter_arg = self.mock_query.filter.call_args[0][0]
+        expected_filter_arg = expected_expr_builder(field, expected_value)
+        assert str(actual_filter_arg) == str(expected_filter_arg)
+
+    @pytest.mark.parametrize(
+        "filter_class, convert_value, expected_expr_builder",
+        [
+            (
+                utils.XComFilterStartsWith,
+                False,
+                lambda field, v: func.btrim(func.convert_from(field, "UTF8"), '"').ilike(f"{v}%"),
+            ),
+            (
+                utils.XComFilterEndsWith,
+                False,
+                lambda field, v: func.btrim(func.convert_from(field, "UTF8"), '"').ilike(f"%{v}"),
+            ),
+            (
+                utils.XComFilterEqual,
+                True,
+                lambda field, v: func.btrim(func.convert_from(field, "UTF8"), '"') == v,
+            ),
+            (
+                utils.XComFilterContains,
+                False,
+                lambda field, v: func.btrim(func.convert_from(field, "UTF8"), '"').ilike(f"%{v}%"),
+            ),
+            (
+                utils.XComFilterNotStartsWith,
+                False,
+                lambda field, v: ~func.btrim(func.convert_from(field, "UTF8"), '"').ilike(f"{v}%"),
+            ),
+            (
+                utils.XComFilterNotEndsWith,
+                False,
+                lambda field, v: ~func.btrim(func.convert_from(field, "UTF8"), '"').ilike(f"%{v}"),
+            ),
+            (
+                utils.XComFilterNotContains,
+                False,
+                lambda field, v: ~func.btrim(func.convert_from(field, "UTF8"), '"').ilike(f"%{v}%"),
+            ),
+            (
+                utils.XComFilterNotEqual,
+                True,
+                lambda field, v: func.btrim(func.convert_from(field, "UTF8"), '"') != v,
+            ),
+        ],
+        ids=[
+            "StartsWith",
+            "EndsWith",
+            "Equal",
+            "Contains",
+            "NotStartsWith",
+            "NotEndsWith",
+            "NotContains",
+            "NotEqual",
+        ],
+    )
+    def test_xcom_filters(self, filter_class, convert_value, expected_expr_builder):
+        xcom_filter = filter_class(datamodel=self.mock_datamodel, column_name=self.mock_column_name)
+        raw_value = "test_value"
+
+        self._assert_filter_query(
+            xcom_filter,
+            raw_value=raw_value,
+            expected_expr_builder=expected_expr_builder,
+            convert_value=convert_value,
+        )
+
+
 @pytest.mark.db_test
 def test_get_col_default_not_existing(session):
     interface = CustomSQLAInterface(obj=DagRun, session=session)

Reply via email to