This is an automated email from the ASF dual-hosted git repository.
bolke pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 20cb70b316 Fix broken regex for allowed_deserialization_classes (#36147)
20cb70b316 is described below
commit 20cb70b316e4efcd75cc68e98c1ae28e14ade573
Author: tobiaszorzetto <[email protected]>
AuthorDate: Mon Jan 22 07:24:28 2024 -0300
Fix broken regex for allowed_deserialization_classes (#36147)
---------
Co-authored-by: Victor Dominguite <[email protected]>
Co-authored-by: Elad Kalif <[email protected]>
---
airflow/config_templates/config.yml | 15 +++++++--
airflow/config_templates/unit_tests.cfg | 2 +-
airflow/serialization/serde.py | 27 ++++++++++++---
newsfragments/36147.significant.rst | 11 +++++++
tests/serialization/test_serde.py | 58 ++++++++++++++++++++++++++++++---
5 files changed, 101 insertions(+), 12 deletions(-)
diff --git a/airflow/config_templates/config.yml b/airflow/config_templates/config.yml
index f6c0c09fc0..8785cb08b6 100644
--- a/airflow/config_templates/config.yml
+++ b/airflow/config_templates/config.yml
@@ -241,11 +241,20 @@ core:
allowed_deserialization_classes:
description: |
What classes can be imported during deserialization. This is a multi line value.
- The individual items will be parsed as regexp. Python built-in classes (like dict)
- are always allowed. Bare "." will be replaced so you can set airflow.* .
+ The individual items will be parsed as a pattern to a glob function.
+ Python built-in classes (like dict) are always allowed.
version_added: 2.5.0
type: string
- default: 'airflow\..*'
+ default: 'airflow.*'
+ example: ~
+ allowed_deserialization_classes_regexp:
+ description: |
+ What classes can be imported during deserialization. This is a multi line value.
+ The individual items will be parsed as regexp patterns.
+ This is a secondary option to ``allowed_deserialization_classes``.
+ version_added: 2.8.1
+ type: string
+ default: ''
example: ~
killed_task_cleanup_time:
description: |
diff --git a/airflow/config_templates/unit_tests.cfg b/airflow/config_templates/unit_tests.cfg
index 69c2d65bba..42055b9d9c 100644
--- a/airflow/config_templates/unit_tests.cfg
+++ b/airflow/config_templates/unit_tests.cfg
@@ -58,7 +58,7 @@ unit_test_mode = True
# We want to use a shorter timeout for task cleanup
killed_task_cleanup_time = 5
# We only allow our own classes to be deserialized in tests
-allowed_deserialization_classes = airflow\..* tests\..*
+allowed_deserialization_classes = airflow.* tests.*
[database]
diff --git a/airflow/serialization/serde.py b/airflow/serialization/serde.py
index 23d67e6162..a214acc9a6 100644
--- a/airflow/serialization/serde.py
+++ b/airflow/serialization/serde.py
@@ -22,6 +22,7 @@ import enum
import functools
import logging
import sys
+from fnmatch import fnmatch
from importlib import import_module
from typing import TYPE_CHECKING, Any, Pattern, TypeVar, Union, cast
@@ -241,7 +242,6 @@ def deserialize(o: T | None, full=True, type_hint: Any =
None) -> object:
# only return string representation
if not full:
return _stringify(classname, version, value)
-
if not _match(classname) and classname not in _extra_allowed:
raise ImportError(
f"{classname} was not found in allow list for deserialization
imports. "
@@ -288,7 +288,22 @@ def _convert(old: dict) -> dict:
def _match(classname: str) -> bool:
- return any(p.match(classname) is not None for p in _get_patterns())
+ """Checks if the given classname matches a path pattern either using glob
format or regexp format."""
+ return _match_glob(classname) or _match_regexp(classname)
+
+
+@functools.lru_cache(maxsize=None)
+def _match_glob(classname: str):
+ """Checks if the given classname matches a pattern from
allowed_deserialization_classes using glob syntax."""
+ patterns = _get_patterns()
+ return any(fnmatch(classname, p.pattern) for p in patterns)
+
+
+@functools.lru_cache(maxsize=None)
+def _match_regexp(classname: str):
+ """Checks if the given classname matches a pattern from
allowed_deserialization_classes_regexp using regexp."""
+ patterns = _get_regexp_patterns()
+ return any(p.match(classname) is not None for p in patterns)
def _stringify(classname: str, version: int, value: T | None) -> str:
@@ -359,8 +374,12 @@ def _register():
@functools.lru_cache(maxsize=None)
def _get_patterns() -> list[Pattern]:
- patterns = conf.get("core", "allowed_deserialization_classes").split()
- return [re2.compile(re2.sub(r"(\w)\.", r"\1\..", p)) for p in patterns]
+ return [re2.compile(p) for p in conf.get("core",
"allowed_deserialization_classes").split()]
+
+
+@functools.lru_cache(maxsize=None)
+def _get_regexp_patterns() -> list[Pattern]:
+ return [re2.compile(p) for p in conf.get("core",
"allowed_deserialization_classes_regexp").split()]
_register()
diff --git a/newsfragments/36147.significant.rst b/newsfragments/36147.significant.rst
new file mode 100644
index 0000000000..105e1b74d4
--- /dev/null
+++ b/newsfragments/36147.significant.rst
@@ -0,0 +1,11 @@
+The ``allowed_deserialization_classes`` flag now follows a glob pattern.
+
+For example if one wants to add the class ``airflow.tests.custom_class`` to the
+``allowed_deserialization_classes`` list, it can be done by writing the full class
+name (``airflow.tests.custom_class``) or a pattern such as the ones used in glob
+search (e.g., ``airflow.*``, ``airflow.tests.*``).
+
+If you currently use a custom regexp path make sure to rewrite it as a glob pattern.
+
+Alternatively, if you still wish to match it as a regexp pattern, add it under the new
+list ``allowed_deserialization_classes_regexp`` instead.
diff --git a/tests/serialization/test_serde.py b/tests/serialization/test_serde.py
index dc3d3faf1e..bd39ac7166 100644
--- a/tests/serialization/test_serde.py
+++ b/tests/serialization/test_serde.py
@@ -33,7 +33,10 @@ from airflow.serialization.serde import (
SCHEMA_ID,
VERSION,
_get_patterns,
+ _get_regexp_patterns,
_match,
+ _match_glob,
+ _match_regexp,
deserialize,
serialize,
)
@@ -44,10 +47,16 @@ from tests.test_utils.config import conf_vars
@pytest.fixture()
def recalculate_patterns():
_get_patterns.cache_clear()
+ _get_regexp_patterns.cache_clear()
+ _match_glob.cache_clear()
+ _match_regexp.cache_clear()
try:
yield
finally:
_get_patterns.cache_clear()
+ _get_regexp_patterns.cache_clear()
+ _match_glob.cache_clear()
+ _match_regexp.cache_clear()
class Z:
@@ -218,7 +227,7 @@ class TestSerDe:
@conf_vars(
{
- ("core", "allowed_deserialization_classes"): "airflow[.].*",
+ ("core", "allowed_deserialization_classes"): "airflow.*",
}
)
@pytest.mark.usefixtures("recalculate_patterns")
@@ -232,13 +241,54 @@ class TestSerDe:
@conf_vars(
{
- ("core", "allowed_deserialization_classes"): "tests.*",
+ ("core", "allowed_deserialization_classes"): "tests.airflow.*",
}
)
@pytest.mark.usefixtures("recalculate_patterns")
- def test_allow_list_replace(self):
+ def test_allow_list_match(self):
assert _match("tests.airflow.deep")
- assert _match("testsfault") is False
+ assert _match("tests.wrongpath") is False
+
+ @conf_vars(
+ {
+ ("core", "allowed_deserialization_classes"): "tests.airflow.deep",
+ }
+ )
+ @pytest.mark.usefixtures("recalculate_patterns")
+ def test_allow_list_match_class(self):
+ """Test the match function when passing a full classname as
+ allowed_deserialization_classes
+ """
+ assert _match("tests.airflow.deep")
+ assert _match("tests.airflow.FALSE") is False
+
+ @conf_vars(
+ {
+ ("core", "allowed_deserialization_classes"): "",
+ ("core", "allowed_deserialization_classes_regexp"):
"tests\.airflow\..",
+ }
+ )
+ @pytest.mark.usefixtures("recalculate_patterns")
+ def test_allow_list_match_regexp(self):
+ """Test the match function when passing a path as
+ allowed_deserialization_classes_regexp with no glob pattern defined
+ """
+ assert _match("tests.airflow.deep")
+ assert _match("tests.wrongpath") is False
+
+ @conf_vars(
+ {
+ ("core", "allowed_deserialization_classes"): "",
+ ("core", "allowed_deserialization_classes_regexp"):
"tests\.airflow\.deep",
+ }
+ )
+ @pytest.mark.usefixtures("recalculate_patterns")
+ def test_allow_list_match_class_regexp(self):
+ """Test the match function when passing a full classname as
+ allowed_deserialization_classes_regexp with no glob pattern defined
+ """
+ assert _match("tests.airflow.deep")
+ assert _match("tests.airflow.FALSE") is False
def test_incompatible_version(self):
data = dict(