This is an automated email from the ASF dual-hosted git repository. utkarsharma pushed a commit to branch sync_2-10-test in repository https://gitbox.apache.org/repos/asf/airflow.git
commit 3d34df4a2826d3409836abee6fc878d1ada5cf66 Author: Karen Braganza <[email protected]> AuthorDate: Wed Nov 27 05:55:37 2024 -0500 Check pool_slots on partial task import instead of execution (#39724) (#42693) Co-authored-by: Ryan Hatter <[email protected]> Co-authored-by: Utkarsh Sharma <[email protected]> --- airflow/decorators/base.py | 6 ++++++ airflow/models/baseoperator.py | 5 +++++ tests/models/test_mappedoperator.py | 9 +++++++++ 3 files changed, 20 insertions(+) diff --git a/airflow/decorators/base.py b/airflow/decorators/base.py index d743acbe50b..bcb64aaa6eb 100644 --- a/airflow/decorators/base.py +++ b/airflow/decorators/base.py @@ -457,6 +457,12 @@ class _TaskDecorator(ExpandableFactory, Generic[FParams, FReturn, OperatorSubcla end_date = timezone.convert_to_utc(partial_kwargs.pop("end_date", None)) if partial_kwargs.get("pool") is None: partial_kwargs["pool"] = Pool.DEFAULT_POOL_NAME + if "pool_slots" in partial_kwargs: + if partial_kwargs["pool_slots"] < 1: + dag_str = "" + if dag: + dag_str = f" in dag {dag.dag_id}" + raise ValueError(f"pool slots for {task_id}{dag_str} cannot be less than 1") partial_kwargs["retries"] = parse_retries(partial_kwargs.get("retries", DEFAULT_RETRIES)) partial_kwargs["retry_delay"] = coerce_timedelta( partial_kwargs.get("retry_delay", DEFAULT_RETRY_DELAY), diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py index 11522060fe0..773552184f1 100644 --- a/airflow/models/baseoperator.py +++ b/airflow/models/baseoperator.py @@ -365,6 +365,11 @@ def partial( partial_kwargs["end_date"] = timezone.convert_to_utc(partial_kwargs["end_date"]) if partial_kwargs["pool"] is None: partial_kwargs["pool"] = Pool.DEFAULT_POOL_NAME + if partial_kwargs["pool_slots"] < 1: + dag_str = "" + if dag: + dag_str = f" in dag {dag.dag_id}" + raise ValueError(f"pool slots for {task_id}{dag_str} cannot be less than 1") partial_kwargs["retries"] = parse_retries(partial_kwargs["retries"]) partial_kwargs["retry_delay"] = coerce_timedelta(partial_kwargs["retry_delay"], key="retry_delay") if partial_kwargs["max_retry_delay"] is not None: diff --git a/tests/models/test_mappedoperator.py b/tests/models/test_mappedoperator.py index 01991c0bb45..cf547912fb9 100644 --- a/tests/models/test_mappedoperator.py +++ b/tests/models/test_mappedoperator.py @@ -221,6 +221,15 @@ def test_partial_on_class_invalid_ctor_args() -> None: MockOperator.partial(task_id="a", foo="bar", bar=2) +def test_partial_on_invalid_pool_slots_raises() -> None: + """Test that when we pass an invalid value to pool_slots in partial(), + + i.e. if the value is not an integer, an error is raised at import time.""" + + with pytest.raises(TypeError, match="'<' not supported between instances of 'str' and 'int'"): + MockOperator.partial(task_id="pool_slots_test", pool="test", pool_slots="a").expand(arg1=[1, 2, 3]) + + @pytest.mark.skip_if_database_isolation_mode # Does not work in db isolation mode @pytest.mark.parametrize( ["num_existing_tis", "expected"],
