kaxil commented on code in PR #62174:
URL: https://github.com/apache/airflow/pull/62174#discussion_r2927160308
##########
task-sdk/tests/task_sdk/bases/test_decorator.py:
##########
@@ -69,6 +70,208 @@ def
test_get_python_source_strips_decorator_and_comment(self, tmp_path: Path):
assert cleaned.lstrip().splitlines()[0].startswith("def a_task")
+class DummyDecoratedOperator(DecoratedOperator):
+ custom_operator_name = "@task.dummy"
+
+ def execute(self, context):
+ return self.python_callable(*self.op_args, **self.op_kwargs)
+
+
+class TestDefaultFillingLogic:
+ @pytest.mark.parametrize(
+ ("func", "kwargs", "args"),
+ [
+ pytest.param(
+ lambda: 42,
+ {},
+ [],
+ id="no_params",
+ ),
+ pytest.param(
+ lambda a, b=5, c=None: (a, b, c),
+ {"a": 1},
+ [],
+ id="param_after_first_default_without_default",
+ ),
+ pytest.param(
+ lambda x, y=99: (x, y),
+ {"x": 1},
+ [],
+ id="param_after_first_default_is_given_none",
+ ),
+ pytest.param(
+ lambda a, b, c=99: (a, b, c),
+ {},
+ [1, 2],
+ id="single_trailing_optional",
+ ),
+ ],
+ )
+ def test_construction_succeeds(self, func, kwargs, args):
+ op = make_op(func, op_kwargs=kwargs, op_args=args)
+ assert op is not None
+
+ @pytest.mark.parametrize(
+ ("func", "op_kwargs", "op_args", "expected_defaults"),
+ [
+ pytest.param(
+ lambda a, b, c: a + b + c,
+ {},
+ [1, 2, 3],
+ [inspect.Parameter.empty, inspect.Parameter.empty,
inspect.Parameter.empty],
+ id="all_required_no_defaults_injected",
+ ),
+ pytest.param(
+ lambda required, optional=10: required + optional,
+ {"required": 5},
+ [],
+ [inspect.Parameter.empty, 10],
+ id="params_before_first_default_stay_required",
+ ),
+ pytest.param(
+ lambda a, b=1, c=2, d=3: a + b + c + d,
+ {"a": 10},
+ [],
+ [inspect.Parameter.empty, 1, 2, 3],
+ id="explicit_defaults_after_first_default_preserved",
+ ),
+ pytest.param(
+ lambda no_default_1, no_default_2, first_default=42,
after=None: None,
+ {},
+ [1, 2],
+ [inspect.Parameter.empty, inspect.Parameter.empty, 42, None],
+ id="first_default_defines_boundary",
+ ),
+ pytest.param(
+ lambda a=1, b=2, c=3: a + b + c,
+ {},
+ [],
+ [1, 2, 3],
+ id="all_params_have_defaults_none_overwritten",
+ ),
+ ],
+ )
+ def test_param_defaults(self, func, op_kwargs, op_args, expected_defaults):
+ op = make_op(func, op_kwargs=op_kwargs, op_args=op_args)
+ sig = inspect.signature(op.python_callable)
Review Comment:
`inspect.signature(op.python_callable)` returns the original function's
signature. The fix modifies a local `parameters` list used only for `bind()`
validation; it never touches the callable itself. So this test is verifying
that the original callable's defaults are preserved (which was always true),
not that the fix's default-injection logic produces the right values.
Consider inspecting the modified signature instead (e.g., by capturing it
inside `__init__` or testing that `signature.bind()` succeeds/fails as
expected).
##########
task-sdk/tests/task_sdk/bases/test_decorator.py:
##########
@@ -69,6 +70,208 @@ def
test_get_python_source_strips_decorator_and_comment(self, tmp_path: Path):
assert cleaned.lstrip().splitlines()[0].startswith("def a_task")
+class DummyDecoratedOperator(DecoratedOperator):
+ custom_operator_name = "@task.dummy"
+
+ def execute(self, context):
+ return self.python_callable(*self.op_args, **self.op_kwargs)
+
+
+class TestDefaultFillingLogic:
+ @pytest.mark.parametrize(
+ ("func", "kwargs", "args"),
+ [
+ pytest.param(
+ lambda: 42,
+ {},
+ [],
+ id="no_params",
+ ),
+ pytest.param(
+ lambda a, b=5, c=None: (a, b, c),
+ {"a": 1},
+ [],
+ id="param_after_first_default_without_default",
+ ),
+ pytest.param(
+ lambda x, y=99: (x, y),
+ {"x": 1},
+ [],
+ id="param_after_first_default_is_given_none",
+ ),
+ pytest.param(
+ lambda a, b, c=99: (a, b, c),
+ {},
+ [1, 2],
+ id="single_trailing_optional",
+ ),
+ ],
+ )
+ def test_construction_succeeds(self, func, kwargs, args):
Review Comment:
None of the `test_construction_succeeds` cases reproduce the exact scenario
from issue #56128: `def foo(start_date, end_date): ...; foo(None, None)`.
Consider adding a case using actual context key names to serve as a regression
test for that specific bug.
##########
task-sdk/tests/task_sdk/bases/test_decorator.py:
##########
@@ -69,6 +70,208 @@ def
test_get_python_source_strips_decorator_and_comment(self, tmp_path: Path):
assert cleaned.lstrip().splitlines()[0].startswith("def a_task")
+class DummyDecoratedOperator(DecoratedOperator):
+ custom_operator_name = "@task.dummy"
+
+ def execute(self, context):
+ return self.python_callable(*self.op_args, **self.op_kwargs)
+
+
+class TestDefaultFillingLogic:
+ @pytest.mark.parametrize(
+ ("func", "kwargs", "args"),
+ [
+ pytest.param(
+ lambda: 42,
+ {},
+ [],
+ id="no_params",
+ ),
+ pytest.param(
+ lambda a, b=5, c=None: (a, b, c),
+ {"a": 1},
+ [],
+ id="param_after_first_default_without_default",
+ ),
+ pytest.param(
+ lambda x, y=99: (x, y),
+ {"x": 1},
+ [],
+ id="param_after_first_default_is_given_none",
+ ),
+ pytest.param(
+ lambda a, b, c=99: (a, b, c),
+ {},
+ [1, 2],
+ id="single_trailing_optional",
+ ),
+ ],
+ )
+ def test_construction_succeeds(self, func, kwargs, args):
+ op = make_op(func, op_kwargs=kwargs, op_args=args)
+ assert op is not None
+
+ @pytest.mark.parametrize(
+ ("func", "op_kwargs", "op_args", "expected_defaults"),
+ [
+ pytest.param(
+ lambda a, b, c: a + b + c,
+ {},
+ [1, 2, 3],
+ [inspect.Parameter.empty, inspect.Parameter.empty,
inspect.Parameter.empty],
+ id="all_required_no_defaults_injected",
+ ),
+ pytest.param(
+ lambda required, optional=10: required + optional,
+ {"required": 5},
+ [],
+ [inspect.Parameter.empty, 10],
+ id="params_before_first_default_stay_required",
+ ),
+ pytest.param(
+ lambda a, b=1, c=2, d=3: a + b + c + d,
+ {"a": 10},
+ [],
+ [inspect.Parameter.empty, 1, 2, 3],
+ id="explicit_defaults_after_first_default_preserved",
+ ),
+ pytest.param(
+ lambda no_default_1, no_default_2, first_default=42,
after=None: None,
+ {},
+ [1, 2],
+ [inspect.Parameter.empty, inspect.Parameter.empty, 42, None],
+ id="first_default_defines_boundary",
+ ),
+ pytest.param(
+ lambda a=1, b=2, c=3: a + b + c,
+ {},
+ [],
+ [1, 2, 3],
+ id="all_params_have_defaults_none_overwritten",
+ ),
+ ],
+ )
+ def test_param_defaults(self, func, op_kwargs, op_args, expected_defaults):
+ op = make_op(func, op_kwargs=op_kwargs, op_args=op_args)
+ sig = inspect.signature(op.python_callable)
+ actual = [p.default for p in sig.parameters.values()]
+ assert actual == expected_defaults
+
+ def test_context_key_default_none_does_not_raise(self):
+ from airflow.sdk.bases.decorator import KNOWN_CONTEXT_KEYS
+
+ ctx_key = next(iter(KNOWN_CONTEXT_KEYS))
+ f = _make_func(f"def dummy_task(x, {ctx_key}=None): return x")
+ assert make_op(f, op_kwargs={"x": 1}) is not None
+
+ def test_context_key_with_non_none_default_raises(self):
+ from airflow.sdk.bases.decorator import KNOWN_CONTEXT_KEYS
+
+ ctx_key = next(iter(KNOWN_CONTEXT_KEYS))
+ f = _make_func(f"def dummy_task(x, {ctx_key}='bad_default'): return x")
+ with pytest.raises(ValueError, match="can't have a default other than
None"):
+ make_op(f, op_kwargs={"x": 1})
+
+ @pytest.mark.parametrize(
+ ("func_src", "op_kwargs"),
+ [
+ pytest.param(
+ "def dummy_task({ctx0}, x, y=10): return (x, y)",
+ {"x": 1},
+ id="context_key_before_first_default_shifts_boundary",
+ ),
+ pytest.param(
+ "def dummy_task(x, y=5, {ctx0}=None): return (x, y)",
+ {"x": 1},
+ id="context_key_after_regular_default",
+ ),
+ pytest.param(
+ "def dummy_task(a, {ctx0}=None, b=7, {ctx1}=None): return (a,
b)",
+ {"a": 1},
+ id="multiple_context_keys_mixed_with_regular_defaults",
+ ),
+ pytest.param(
+ "def dummy_task({ctx0}, x, y=10): return (x, y)",
+ {"x": 42},
+
id="required_param_between_context_key_and_regular_default_gets_none",
+ ),
+ pytest.param(
+ "def dummy_task({ctx0}=None, {ctx1}=None, {ctx2}=None): return
True",
+ {},
+ id="context_key_only_signature",
+ ),
+ ],
+ )
+ def test_context_key_construction_succeeds(self, func_src, op_kwargs):
+ """All context-key signature shapes must construct without raising."""
+ from airflow.sdk.bases.decorator import KNOWN_CONTEXT_KEYS
+
+ ctx_keys = list(KNOWN_CONTEXT_KEYS)
+ src = func_src.format(
+ ctx0=ctx_keys[0],
+ ctx1=ctx_keys[1] if len(ctx_keys) > 1 else ctx_keys[0],
+ ctx2=ctx_keys[2] if len(ctx_keys) > 2 else ctx_keys[0],
+ )
+ op = make_op(_make_func(src), op_kwargs=op_kwargs)
+ assert op is not None
+
+ def
test_context_key_after_regular_default_preserves_original_default(self):
+ from airflow.sdk.bases.decorator import KNOWN_CONTEXT_KEYS
+
+ ctx_key = next(iter(KNOWN_CONTEXT_KEYS))
+ f = _make_func(f"def dummy_task(x, y=5, {ctx_key}=None): return (x,
y)")
+ op = make_op(f, op_kwargs={"x": 1})
+ sig = inspect.signature(op.python_callable)
+ y_param = next(p for p in sig.parameters.values() if p.name == "y")
+ assert y_param.default == 5
+
+ def test_non_context_param_after_context_key_gets_none_injected(self):
+ from airflow.sdk.bases.decorator import KNOWN_CONTEXT_KEYS
+
+ ctx_key = next(iter(KNOWN_CONTEXT_KEYS))
+ f = _make_func(f"def dummy_task({ctx_key}, a): ...")
+ assert make_op(f, op_kwargs={"a": "2024-01-01"}) is not None
+ assert make_op(f) is not None
Review Comment:
This asserts that `make_op(f)` with zero kwargs succeeds, meaning the required
non-context-key param `a` is silently given `default=None`. Is that the
intended behavior? A user who writes
`@task def foo(start_date, my_data): ...; foo()` (forgetting `my_data`)
won't get any parse-time error.
##########
task-sdk/src/airflow/sdk/bases/decorator.py:
##########
@@ -313,6 +313,23 @@ def __init__(
param.replace(default=None) if param.name in KNOWN_CONTEXT_KEYS
else param
for param in signature.parameters.values()
]
+
+ # Python requires that positional parameters with defaults don't
precede those without.
+ # This only applies to POSITIONAL_ONLY and POSITIONAL_OR_KEYWORD
parameters — *args,
+ # **kwargs, and keyword-only parameters follow different rules.
+ positional_kinds = (inspect.Parameter.POSITIONAL_ONLY,
inspect.Parameter.POSITIONAL_OR_KEYWORD)
+ positional = [(i, p) for i, p in enumerate(parameters) if p.kind in
positional_kinds]
+ first_default_idx = next((i for i, p in positional if p.default !=
inspect.Parameter.empty), None)
+ if first_default_idx is not None:
+ parameters = [
+ param.replace(default=None)
Review Comment:
This injects `default=None` for all positional params after the first
defaulted one, not just context-key params. If a user writes
`def foo(start_date, my_data): ...; foo()` and forgets `my_data`, the
parse-time `signature.bind()` check won't catch it because `my_data` now has
a default too. The error moves to runtime.

Before this PR the error message was confusing, but at least it happened
early. It is worth considering whether `None` should be injected only for the
params that sit between a context-key-defaulted param and the next
user-provided default, rather than for all trailing positional params.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]