This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch v2-8-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 9d9d62e5cab46320f5c0a5423fa652269568f90a
Author: tobiaszorzetto <[email protected]>
AuthorDate: Mon Jan 22 07:24:28 2024 -0300

    Fix broken regex for allowed_deserialization_classes (#36147)
    
    ---------
    
    Co-authored-by: Victor Dominguite <[email protected]>
    Co-authored-by: Elad Kalif <[email protected]>
    (cherry picked from commit 20cb70b316e4efcd75cc68e98c1ae28e14ade573)
---
 airflow/config_templates/config.yml     | 15 +++++++--
 airflow/config_templates/unit_tests.cfg |  2 +-
 airflow/serialization/serde.py          | 27 ++++++++++++---
 newsfragments/36147.significant.rst     | 11 +++++++
 tests/serialization/test_serde.py       | 58 ++++++++++++++++++++++++++++++---
 5 files changed, 101 insertions(+), 12 deletions(-)

diff --git a/airflow/config_templates/config.yml b/airflow/config_templates/config.yml
index 4107a49895..cae508e773 100644
--- a/airflow/config_templates/config.yml
+++ b/airflow/config_templates/config.yml
@@ -241,11 +241,20 @@ core:
     allowed_deserialization_classes:
       description: |
         What classes can be imported during deserialization. This is a multi line value.
-        The individual items will be parsed as regexp. Python built-in classes (like dict)
-        are always allowed. Bare "." will be replaced so you can set airflow.* .
+        The individual items will be parsed as a pattern to a glob function.
+        Python built-in classes (like dict) are always allowed.
       version_added: 2.5.0
       type: string
-      default: 'airflow\..*'
+      default: 'airflow.*'
+      example: ~
+    allowed_deserialization_classes_regexp:
+      description: |
+        What classes can be imported during deserialization. This is a multi line value.
+        The individual items will be parsed as regexp patterns.
+        This is a secondary option to ``allowed_deserialization_classes``.
+      version_added: 2.8.1
+      type: string
+      default: ''
       example: ~
     killed_task_cleanup_time:
       description: |
diff --git a/airflow/config_templates/unit_tests.cfg b/airflow/config_templates/unit_tests.cfg
index 69c2d65bba..42055b9d9c 100644
--- a/airflow/config_templates/unit_tests.cfg
+++ b/airflow/config_templates/unit_tests.cfg
@@ -58,7 +58,7 @@ unit_test_mode = True
 # We want to use a shorter timeout for task cleanup
 killed_task_cleanup_time = 5
 # We only allow our own classes to be deserialized in tests
-allowed_deserialization_classes = airflow\..* tests\..*
+allowed_deserialization_classes = airflow.* tests.*
 
 [database]
 
diff --git a/airflow/serialization/serde.py b/airflow/serialization/serde.py
index 23d67e6162..a214acc9a6 100644
--- a/airflow/serialization/serde.py
+++ b/airflow/serialization/serde.py
@@ -22,6 +22,7 @@ import enum
 import functools
 import logging
 import sys
+from fnmatch import fnmatch
 from importlib import import_module
 from typing import TYPE_CHECKING, Any, Pattern, TypeVar, Union, cast
 
@@ -241,7 +242,6 @@ def deserialize(o: T | None, full=True, type_hint: Any = None) -> object:
     # only return string representation
     if not full:
         return _stringify(classname, version, value)
-
     if not _match(classname) and classname not in _extra_allowed:
         raise ImportError(
             f"{classname} was not found in allow list for deserialization imports. "
@@ -288,7 +288,22 @@ def _convert(old: dict) -> dict:
 
 
 def _match(classname: str) -> bool:
-    return any(p.match(classname) is not None for p in _get_patterns())
+    """Checks if the given classname matches a path pattern either using glob format or regexp format."""
+    return _match_glob(classname) or _match_regexp(classname)
+
+
+@functools.lru_cache(maxsize=None)
+def _match_glob(classname: str):
+    """Checks if the given classname matches a pattern from allowed_deserialization_classes using glob syntax."""
+    patterns = _get_patterns()
+    return any(fnmatch(classname, p.pattern) for p in patterns)
+
+
+@functools.lru_cache(maxsize=None)
+def _match_regexp(classname: str):
+    """Checks if the given classname matches a pattern from allowed_deserialization_classes_regexp using regexp."""
+    patterns = _get_regexp_patterns()
+    return any(p.match(classname) is not None for p in patterns)
 
 
 def _stringify(classname: str, version: int, value: T | None) -> str:
@@ -359,8 +374,12 @@ def _register():
 
 @functools.lru_cache(maxsize=None)
 def _get_patterns() -> list[Pattern]:
-    patterns = conf.get("core", "allowed_deserialization_classes").split()
-    return [re2.compile(re2.sub(r"(\w)\.", r"\1\..", p)) for p in patterns]
+    return [re2.compile(p) for p in conf.get("core", "allowed_deserialization_classes").split()]
+
+
+@functools.lru_cache(maxsize=None)
+def _get_regexp_patterns() -> list[Pattern]:
+    return [re2.compile(p) for p in conf.get("core", "allowed_deserialization_classes_regexp").split()]
 
 
 _register()
diff --git a/newsfragments/36147.significant.rst b/newsfragments/36147.significant.rst
new file mode 100644
index 0000000000..105e1b74d4
--- /dev/null
+++ b/newsfragments/36147.significant.rst
@@ -0,0 +1,11 @@
+The ``allowed_deserialization_classes`` flag now follows a glob pattern.
+
+For example if one wants to add the class ``airflow.tests.custom_class`` to the
+``allowed_deserialization_classes`` list, it can be done by writing the full class
+name (``airflow.tests.custom_class``) or a pattern such as the ones used in glob
+search (e.g., ``airflow.*``, ``airflow.tests.*``).
+
+If you currently use a custom regexp path make sure to rewrite it as a glob pattern.
+
+Alternatively, if you still wish to match it as a regexp pattern, add it under the new
+list ``allowed_deserialization_classes_regexp`` instead.
diff --git a/tests/serialization/test_serde.py b/tests/serialization/test_serde.py
index dc3d3faf1e..bd39ac7166 100644
--- a/tests/serialization/test_serde.py
+++ b/tests/serialization/test_serde.py
@@ -33,7 +33,10 @@ from airflow.serialization.serde import (
     SCHEMA_ID,
     VERSION,
     _get_patterns,
+    _get_regexp_patterns,
     _match,
+    _match_glob,
+    _match_regexp,
     deserialize,
     serialize,
 )
@@ -44,10 +47,16 @@ from tests.test_utils.config import conf_vars
 @pytest.fixture()
 def recalculate_patterns():
     _get_patterns.cache_clear()
+    _get_regexp_patterns.cache_clear()
+    _match_glob.cache_clear()
+    _match_regexp.cache_clear()
     try:
         yield
     finally:
         _get_patterns.cache_clear()
+        _get_regexp_patterns.cache_clear()
+        _match_glob.cache_clear()
+        _match_regexp.cache_clear()
 
 
 class Z:
@@ -218,7 +227,7 @@ class TestSerDe:
 
     @conf_vars(
         {
-            ("core", "allowed_deserialization_classes"): "airflow[.].*",
+            ("core", "allowed_deserialization_classes"): "airflow.*",
         }
     )
     @pytest.mark.usefixtures("recalculate_patterns")
@@ -232,13 +241,54 @@ class TestSerDe:
 
     @conf_vars(
         {
-            ("core", "allowed_deserialization_classes"): "tests.*",
+            ("core", "allowed_deserialization_classes"): "tests.airflow.*",
         }
     )
     @pytest.mark.usefixtures("recalculate_patterns")
-    def test_allow_list_replace(self):
+    def test_allow_list_match(self):
         assert _match("tests.airflow.deep")
-        assert _match("testsfault") is False
+        assert _match("tests.wrongpath") is False
+
+    @conf_vars(
+        {
+            ("core", "allowed_deserialization_classes"): "tests.airflow.deep",
+        }
+    )
+    @pytest.mark.usefixtures("recalculate_patterns")
+    def test_allow_list_match_class(self):
+        """Test the match function when passing a full classname as
+        allowed_deserialization_classes
+        """
+        assert _match("tests.airflow.deep")
+        assert _match("tests.airflow.FALSE") is False
+
+    @conf_vars(
+        {
+            ("core", "allowed_deserialization_classes"): "",
+            ("core", "allowed_deserialization_classes_regexp"): "tests\.airflow\..",
+        }
+    )
+    @pytest.mark.usefixtures("recalculate_patterns")
+    def test_allow_list_match_regexp(self):
+        """Test the match function when passing a path as
+        allowed_deserialization_classes_regexp with no glob pattern defined
+        """
+        assert _match("tests.airflow.deep")
+        assert _match("tests.wrongpath") is False
+
+    @conf_vars(
+        {
+            ("core", "allowed_deserialization_classes"): "",
+            ("core", "allowed_deserialization_classes_regexp"): "tests\.airflow\.deep",
+        }
+    )
+    @pytest.mark.usefixtures("recalculate_patterns")
+    def test_allow_list_match_class_regexp(self):
+        """Test the match function when passing a full classname as
+        allowed_deserialization_classes_regexp with no glob pattern defined
+        """
+        assert _match("tests.airflow.deep")
+        assert _match("tests.airflow.FALSE") is False
 
     def test_incompatible_version(self):
         data = dict(

Reply via email to