This is an automated email from the ASF dual-hosted git repository.

ephraimanierobi pushed a commit to branch v2-8-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit da72faefd6875dad7c467d2ffb53292fcfbdec5c
Author: Kalyan <[email protected]>
AuthorDate: Tue Feb 6 05:35:32 2024 +0530

    Type Check for retries: Add tests (#37183)
    
    * use type instead of isinstance
    
    * add tests
    
    (cherry picked from commit 7f44d9bc1dcaeda5c48e3d5afb395363ddba0ddb)
---
 airflow/models/baseoperator.py    |  2 +-
 tests/models/test_baseoperator.py | 18 +++++++++++++++---
 2 files changed, 16 insertions(+), 4 deletions(-)

diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py
index f7f1d6ccc6..37026564a6 100644
--- a/airflow/models/baseoperator.py
+++ b/airflow/models/baseoperator.py
@@ -120,7 +120,7 @@ logger = logging.getLogger("airflow.models.baseoperator.BaseOperator")
 
 
 def parse_retries(retries: Any) -> int | None:
-    if retries is None or isinstance(retries, int):
+    if retries is None or type(retries) == int:  # noqa: E721
         return retries
     try:
         parsed_retries = int(retries)
diff --git a/tests/models/test_baseoperator.py b/tests/models/test_baseoperator.py
index fb46fd39c7..4454ef137c 100644
--- a/tests/models/test_baseoperator.py
+++ b/tests/models/test_baseoperator.py
@@ -101,6 +101,11 @@ class MockNamedTuple(NamedTuple):
     var2: str
 
 
+class CustomInt(int):
+    def __int__(self):
+        raise ValueError("Cannot cast to int")
+
+
 class TestBaseOperator:
     def test_expand(self):
         dummy = DummyClass(test_param=True)
@@ -828,11 +833,18 @@ def test_init_subclass_args():
 
 
 @pytest.mark.db_test
-def test_operator_retries_invalid(dag_maker):
+@pytest.mark.parametrize(
+    ("retries", "expected"),
+    [
+        pytest.param("foo", "'retries' type must be int, not str", id="string"),
+        pytest.param(CustomInt(10), "'retries' type must be int, not CustomInt", id="custom int"),
+    ],
+)
+def test_operator_retries_invalid(dag_maker, retries, expected):
     with pytest.raises(AirflowException) as ctx:
         with dag_maker():
-            BaseOperator(task_id="test_illegal_args", retries="foo")
-    assert str(ctx.value) == "'retries' type must be int, not str"
+            BaseOperator(task_id="test_illegal_args", retries=retries)
+    assert str(ctx.value) == expected
 
 
 @pytest.mark.db_test

Reply via email to