This is an automated email from the ASF dual-hosted git repository.

utkarsharma pushed a commit to branch sync_2-10-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 3d34df4a2826d3409836abee6fc878d1ada5cf66
Author: Karen Braganza <[email protected]>
AuthorDate: Wed Nov 27 05:55:37 2024 -0500

    Check pool_slots on partial task import instead of execution (#39724) (#42693)
    
    Co-authored-by: Ryan Hatter <[email protected]>
    Co-authored-by: Utkarsh Sharma <[email protected]>
---
 airflow/decorators/base.py          | 6 ++++++
 airflow/models/baseoperator.py      | 5 +++++
 tests/models/test_mappedoperator.py | 9 +++++++++
 3 files changed, 20 insertions(+)

diff --git a/airflow/decorators/base.py b/airflow/decorators/base.py
index d743acbe50b..bcb64aaa6eb 100644
--- a/airflow/decorators/base.py
+++ b/airflow/decorators/base.py
@@ -457,6 +457,12 @@ class _TaskDecorator(ExpandableFactory, Generic[FParams, FReturn, OperatorSubcla
         end_date = timezone.convert_to_utc(partial_kwargs.pop("end_date", None))
         if partial_kwargs.get("pool") is None:
             partial_kwargs["pool"] = Pool.DEFAULT_POOL_NAME
+        if "pool_slots" in partial_kwargs:
+            if partial_kwargs["pool_slots"] < 1:
+                dag_str = ""
+                if dag:
+                    dag_str = f" in dag {dag.dag_id}"
+                raise ValueError(f"pool slots for {task_id}{dag_str} cannot be less than 1")
         partial_kwargs["retries"] = parse_retries(partial_kwargs.get("retries", DEFAULT_RETRIES))
         partial_kwargs["retry_delay"] = coerce_timedelta(
             partial_kwargs.get("retry_delay", DEFAULT_RETRY_DELAY),
diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py
index 11522060fe0..773552184f1 100644
--- a/airflow/models/baseoperator.py
+++ b/airflow/models/baseoperator.py
@@ -365,6 +365,11 @@ def partial(
     partial_kwargs["end_date"] = timezone.convert_to_utc(partial_kwargs["end_date"])
     if partial_kwargs["pool"] is None:
         partial_kwargs["pool"] = Pool.DEFAULT_POOL_NAME
+    if partial_kwargs["pool_slots"] < 1:
+        dag_str = ""
+        if dag:
+            dag_str = f" in dag {dag.dag_id}"
+        raise ValueError(f"pool slots for {task_id}{dag_str} cannot be less than 1")
     partial_kwargs["retries"] = parse_retries(partial_kwargs["retries"])
     partial_kwargs["retry_delay"] = coerce_timedelta(partial_kwargs["retry_delay"], key="retry_delay")
     if partial_kwargs["max_retry_delay"] is not None:
diff --git a/tests/models/test_mappedoperator.py b/tests/models/test_mappedoperator.py
index 01991c0bb45..cf547912fb9 100644
--- a/tests/models/test_mappedoperator.py
+++ b/tests/models/test_mappedoperator.py
@@ -221,6 +221,15 @@ def test_partial_on_class_invalid_ctor_args() -> None:
         MockOperator.partial(task_id="a", foo="bar", bar=2)
 
 
+def test_partial_on_invalid_pool_slots_raises() -> None:
+    """Test that when we pass an invalid value to pool_slots in partial(),
+
+    i.e. if the value is not an integer, an error is raised at import time."""
+
+    with pytest.raises(TypeError, match="'<' not supported between instances of 'str' and 'int'"):
+        MockOperator.partial(task_id="pool_slots_test", pool="test", pool_slots="a").expand(arg1=[1, 2, 3])
+
+
 @pytest.mark.skip_if_database_isolation_mode  # Does not work in db isolation mode
 @pytest.mark.parametrize(
     ["num_existing_tis", "expected"],

Reply via email to