This is an automated email from the ASF dual-hosted git repository. ephraimanierobi pushed a commit to branch v2-8-test in repository https://gitbox.apache.org/repos/asf/airflow.git
commit da72faefd6875dad7c467d2ffb53292fcfbdec5c Author: Kalyan <[email protected]> AuthorDate: Tue Feb 6 05:35:32 2024 +0530 Type Check for retries: Add tests (#37183) * use type instead of isinstance * add tests (cherry picked from commit 7f44d9bc1dcaeda5c48e3d5afb395363ddba0ddb) --- airflow/models/baseoperator.py | 2 +- tests/models/test_baseoperator.py | 18 +++++++++++++++--- 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py index f7f1d6ccc6..37026564a6 100644 --- a/airflow/models/baseoperator.py +++ b/airflow/models/baseoperator.py @@ -120,7 +120,7 @@ logger = logging.getLogger("airflow.models.baseoperator.BaseOperator") def parse_retries(retries: Any) -> int | None: - if retries is None or isinstance(retries, int): + if retries is None or type(retries) == int: # noqa: E721 return retries try: parsed_retries = int(retries) diff --git a/tests/models/test_baseoperator.py b/tests/models/test_baseoperator.py index fb46fd39c7..4454ef137c 100644 --- a/tests/models/test_baseoperator.py +++ b/tests/models/test_baseoperator.py @@ -101,6 +101,11 @@ class MockNamedTuple(NamedTuple): var2: str +class CustomInt(int): + def __int__(self): + raise ValueError("Cannot cast to int") + + class TestBaseOperator: def test_expand(self): dummy = DummyClass(test_param=True) @@ -828,11 +833,18 @@ def test_init_subclass_args(): @pytest.mark.db_test -def test_operator_retries_invalid(dag_maker): +@pytest.mark.parametrize( + ("retries", "expected"), + [ + pytest.param("foo", "'retries' type must be int, not str", id="string"), + pytest.param(CustomInt(10), "'retries' type must be int, not CustomInt", id="custom int"), + ], +) +def test_operator_retries_invalid(dag_maker, retries, expected): with pytest.raises(AirflowException) as ctx: with dag_maker(): - BaseOperator(task_id="test_illegal_args", retries="foo") - assert str(ctx.value) == "'retries' type must be int, not str" + BaseOperator(task_id="test_illegal_args", 
retries=retries) + assert str(ctx.value) == expected @pytest.mark.db_test
