This is an automated email from the ASF dual-hosted git repository.
eladkal pushed a commit to branch v2-10-test
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/v2-10-test by this push:
new bed1affbb66 Check pool_slots on partial task import instead of
execution (#39724) (#42693)
bed1affbb66 is described below
commit bed1affbb66345ada0de9435828a6519ff76716e
Author: Karen Braganza <[email protected]>
AuthorDate: Wed Nov 27 05:55:37 2024 -0500
Check pool_slots on partial task import instead of execution (#39724)
(#42693)
Co-authored-by: Ryan Hatter <[email protected]>
Co-authored-by: Utkarsh Sharma <[email protected]>
---
airflow/decorators/base.py | 6 ++++++
airflow/models/baseoperator.py | 5 +++++
tests/models/test_mappedoperator.py | 9 +++++++++
3 files changed, 20 insertions(+)
diff --git a/airflow/decorators/base.py b/airflow/decorators/base.py
index d743acbe50b..bcb64aaa6eb 100644
--- a/airflow/decorators/base.py
+++ b/airflow/decorators/base.py
@@ -457,6 +457,12 @@ class _TaskDecorator(ExpandableFactory, Generic[FParams,
FReturn, OperatorSubcla
end_date = timezone.convert_to_utc(partial_kwargs.pop("end_date",
None))
if partial_kwargs.get("pool") is None:
partial_kwargs["pool"] = Pool.DEFAULT_POOL_NAME
+ if "pool_slots" in partial_kwargs:
+ if partial_kwargs["pool_slots"] < 1:
+ dag_str = ""
+ if dag:
+ dag_str = f" in dag {dag.dag_id}"
+ raise ValueError(f"pool slots for {task_id}{dag_str} cannot be
less than 1")
partial_kwargs["retries"] =
parse_retries(partial_kwargs.get("retries", DEFAULT_RETRIES))
partial_kwargs["retry_delay"] = coerce_timedelta(
partial_kwargs.get("retry_delay", DEFAULT_RETRY_DELAY),
diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py
index 11522060fe0..773552184f1 100644
--- a/airflow/models/baseoperator.py
+++ b/airflow/models/baseoperator.py
@@ -365,6 +365,11 @@ def partial(
partial_kwargs["end_date"] =
timezone.convert_to_utc(partial_kwargs["end_date"])
if partial_kwargs["pool"] is None:
partial_kwargs["pool"] = Pool.DEFAULT_POOL_NAME
+ if partial_kwargs["pool_slots"] < 1:
+ dag_str = ""
+ if dag:
+ dag_str = f" in dag {dag.dag_id}"
+ raise ValueError(f"pool slots for {task_id}{dag_str} cannot be less
than 1")
partial_kwargs["retries"] = parse_retries(partial_kwargs["retries"])
partial_kwargs["retry_delay"] =
coerce_timedelta(partial_kwargs["retry_delay"], key="retry_delay")
if partial_kwargs["max_retry_delay"] is not None:
diff --git a/tests/models/test_mappedoperator.py
b/tests/models/test_mappedoperator.py
index 01991c0bb45..cf547912fb9 100644
--- a/tests/models/test_mappedoperator.py
+++ b/tests/models/test_mappedoperator.py
@@ -221,6 +221,15 @@ def test_partial_on_class_invalid_ctor_args() -> None:
MockOperator.partial(task_id="a", foo="bar", bar=2)
+def test_partial_on_invalid_pool_slots_raises() -> None:
+ """Test that an invalid pool_slots value in partial() fails at DAG import time.
+
+ A non-integer value cannot be compared with 1, so a TypeError is raised."""
+
+ with pytest.raises(TypeError, match="'<' not supported between instances
of 'str' and 'int'"):
+ MockOperator.partial(task_id="pool_slots_test", pool="test",
pool_slots="a").expand(arg1=[1, 2, 3])
+
+
@pytest.mark.skip_if_database_isolation_mode # Does not work in db isolation
mode
@pytest.mark.parametrize(
["num_existing_tis", "expected"],