kaxil commented on code in PR #62174:
URL: https://github.com/apache/airflow/pull/62174#discussion_r2927160308


##########
task-sdk/tests/task_sdk/bases/test_decorator.py:
##########
@@ -69,6 +70,208 @@ def test_get_python_source_strips_decorator_and_comment(self, tmp_path: Path):
         assert cleaned.lstrip().splitlines()[0].startswith("def a_task")
 
 
+class DummyDecoratedOperator(DecoratedOperator):
+    custom_operator_name = "@task.dummy"
+
+    def execute(self, context):
+        return self.python_callable(*self.op_args, **self.op_kwargs)
+
+
+class TestDefaultFillingLogic:
+    @pytest.mark.parametrize(
+        ("func", "kwargs", "args"),
+        [
+            pytest.param(
+                lambda: 42,
+                {},
+                [],
+                id="no_params",
+            ),
+            pytest.param(
+                lambda a, b=5, c=None: (a, b, c),
+                {"a": 1},
+                [],
+                id="param_after_first_default_without_default",
+            ),
+            pytest.param(
+                lambda x, y=99: (x, y),
+                {"x": 1},
+                [],
+                id="param_after_first_default_is_given_none",
+            ),
+            pytest.param(
+                lambda a, b, c=99: (a, b, c),
+                {},
+                [1, 2],
+                id="single_trailing_optional",
+            ),
+        ],
+    )
+    def test_construction_succeeds(self, func, kwargs, args):
+        op = make_op(func, op_kwargs=kwargs, op_args=args)
+        assert op is not None
+
+    @pytest.mark.parametrize(
+        ("func", "op_kwargs", "op_args", "expected_defaults"),
+        [
+            pytest.param(
+                lambda a, b, c: a + b + c,
+                {},
+                [1, 2, 3],
+                [inspect.Parameter.empty, inspect.Parameter.empty, inspect.Parameter.empty],
+                id="all_required_no_defaults_injected",
+            ),
+            pytest.param(
+                lambda required, optional=10: required + optional,
+                {"required": 5},
+                [],
+                [inspect.Parameter.empty, 10],
+                id="params_before_first_default_stay_required",
+            ),
+            pytest.param(
+                lambda a, b=1, c=2, d=3: a + b + c + d,
+                {"a": 10},
+                [],
+                [inspect.Parameter.empty, 1, 2, 3],
+                id="explicit_defaults_after_first_default_preserved",
+            ),
+            pytest.param(
+                lambda no_default_1, no_default_2, first_default=42, after=None: None,
+                {},
+                [1, 2],
+                [inspect.Parameter.empty, inspect.Parameter.empty, 42, None],
+                id="first_default_defines_boundary",
+            ),
+            pytest.param(
+                lambda a=1, b=2, c=3: a + b + c,
+                {},
+                [],
+                [1, 2, 3],
+                id="all_params_have_defaults_none_overwritten",
+            ),
+        ],
+    )
+    def test_param_defaults(self, func, op_kwargs, op_args, expected_defaults):
+        op = make_op(func, op_kwargs=op_kwargs, op_args=op_args)
+        sig = inspect.signature(op.python_callable)

Review Comment:
   `inspect.signature(op.python_callable)` returns the original function's 
signature. The fix modifies a local `parameters` list used only for `bind()` 
validation; it never touches the callable itself. So this test is verifying 
that the original callable's defaults are preserved (which was always true), 
not that the fix's default-injection logic produces the right values.
   
   Consider inspecting the modified signature instead (e.g., by capturing it 
inside `__init__` or testing that `signature.bind()` succeeds/fails as 
expected).
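   A minimal sketch of the suggested test shape, assuming a helper that approximates the diff's default filling (the diff tail is truncated above) and returns the rebuilt signature instead of keeping it local. `build_validation_signature` and `CONTEXT_KEYS` are hypothetical names, not Airflow APIs. The assertions target the rebuilt signature, because `inspect.signature(func)` keeps reporting the untouched callable:

```python
import inspect
from collections.abc import Callable, Iterable

# Hypothetical stand-in for airflow.sdk.bases.decorator.KNOWN_CONTEXT_KEYS.
CONTEXT_KEYS = frozenset({"ds", "start_date"})

POSITIONAL_KINDS = (
    inspect.Parameter.POSITIONAL_ONLY,
    inspect.Parameter.POSITIONAL_OR_KEYWORD,
)


def build_validation_signature(
    func: Callable, context_keys: Iterable[str] = CONTEXT_KEYS
) -> inspect.Signature:
    """Hypothetical helper approximating the PR's logic, but returning the result."""
    keys = set(context_keys)
    signature = inspect.signature(func)
    # Existing step: context keys get default=None.
    parameters = [
        p.replace(default=None) if p.name in keys else p
        for p in signature.parameters.values()
    ]
    # Added step (as described in the review): positional parameters after the
    # first defaulted one that still lack a default get default=None.
    first_default_idx = next(
        (
            i
            for i, p in enumerate(parameters)
            if p.kind in POSITIONAL_KINDS and p.default is not inspect.Parameter.empty
        ),
        None,
    )
    if first_default_idx is not None:
        parameters = [
            p.replace(default=None)
            if i > first_default_idx
            and p.kind in POSITIONAL_KINDS
            and p.default is inspect.Parameter.empty
            else p
            for i, p in enumerate(parameters)
        ]
    # Signature.replace() re-validates the parameter ordering.
    return signature.replace(parameters=parameters)


def ds_then_value(ds, value):
    return value


original = inspect.signature(ds_then_value)
modified = build_validation_signature(ds_then_value)
# The callable itself still reports `value` as required...
assert original.parameters["value"].default is inspect.Parameter.empty
# ...while the rebuilt signature, which bind() validates, has it defaulted.
assert modified.parameters["value"].default is None
```

   In the real test this corresponds to capturing the signature that `bind()` validates inside `__init__`, rather than re-inspecting `op.python_callable`.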



##########
task-sdk/tests/task_sdk/bases/test_decorator.py:
##########
@@ -69,6 +70,208 @@ def test_get_python_source_strips_decorator_and_comment(self, tmp_path: Path):
         assert cleaned.lstrip().splitlines()[0].startswith("def a_task")
 
 
+class DummyDecoratedOperator(DecoratedOperator):
+    custom_operator_name = "@task.dummy"
+
+    def execute(self, context):
+        return self.python_callable(*self.op_args, **self.op_kwargs)
+
+
+class TestDefaultFillingLogic:
+    @pytest.mark.parametrize(
+        ("func", "kwargs", "args"),
+        [
+            pytest.param(
+                lambda: 42,
+                {},
+                [],
+                id="no_params",
+            ),
+            pytest.param(
+                lambda a, b=5, c=None: (a, b, c),
+                {"a": 1},
+                [],
+                id="param_after_first_default_without_default",
+            ),
+            pytest.param(
+                lambda x, y=99: (x, y),
+                {"x": 1},
+                [],
+                id="param_after_first_default_is_given_none",
+            ),
+            pytest.param(
+                lambda a, b, c=99: (a, b, c),
+                {},
+                [1, 2],
+                id="single_trailing_optional",
+            ),
+        ],
+    )
+    def test_construction_succeeds(self, func, kwargs, args):

Review Comment:
   None of the `test_construction_succeeds` cases reproduce the exact scenario 
from issue #56128: `def foo(start_date, end_date): ...; foo(None, None)`. 
Consider adding a case using actual context key names to serve as a regression 
test for that specific bug.
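   A sketch of the #56128 failure mode using only `inspect`, assuming (as the comment implies) that `start_date` is a context key and `end_date` is not, and that the rewritten parameters are passed back through `Signature.replace()`, which validates ordering. `context_defaults_only` and `CONTEXT_KEYS` are hypothetical names:

```python
import inspect

# Hypothetical stand-in for KNOWN_CONTEXT_KEYS: assumes "start_date" is a
# context key and "end_date" is not.
CONTEXT_KEYS = frozenset({"start_date"})


def context_defaults_only(func):
    """Approximates the code before this PR: only context keys get default=None."""
    signature = inspect.signature(func)
    parameters = [
        p.replace(default=None) if p.name in CONTEXT_KEYS else p
        for p in signature.parameters.values()
    ]
    # Signature.replace() re-validates ordering and raises ValueError when a
    # positional parameter without a default follows one with a default.
    return signature.replace(parameters=parameters)


def foo(start_date, end_date):
    return (start_date, end_date)


try:
    context_defaults_only(foo)
    reproduced = False
except ValueError:
    reproduced = True
```

   A `make_op` case using these exact names with `op_args=[None, None]` would pin the regression against the real `KNOWN_CONTEXT_KEYS`.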



##########
task-sdk/tests/task_sdk/bases/test_decorator.py:
##########
@@ -69,6 +70,208 @@ def test_get_python_source_strips_decorator_and_comment(self, tmp_path: Path):
         assert cleaned.lstrip().splitlines()[0].startswith("def a_task")
 
 
+class DummyDecoratedOperator(DecoratedOperator):
+    custom_operator_name = "@task.dummy"
+
+    def execute(self, context):
+        return self.python_callable(*self.op_args, **self.op_kwargs)
+
+
+class TestDefaultFillingLogic:
+    @pytest.mark.parametrize(
+        ("func", "kwargs", "args"),
+        [
+            pytest.param(
+                lambda: 42,
+                {},
+                [],
+                id="no_params",
+            ),
+            pytest.param(
+                lambda a, b=5, c=None: (a, b, c),
+                {"a": 1},
+                [],
+                id="param_after_first_default_without_default",
+            ),
+            pytest.param(
+                lambda x, y=99: (x, y),
+                {"x": 1},
+                [],
+                id="param_after_first_default_is_given_none",
+            ),
+            pytest.param(
+                lambda a, b, c=99: (a, b, c),
+                {},
+                [1, 2],
+                id="single_trailing_optional",
+            ),
+        ],
+    )
+    def test_construction_succeeds(self, func, kwargs, args):
+        op = make_op(func, op_kwargs=kwargs, op_args=args)
+        assert op is not None
+
+    @pytest.mark.parametrize(
+        ("func", "op_kwargs", "op_args", "expected_defaults"),
+        [
+            pytest.param(
+                lambda a, b, c: a + b + c,
+                {},
+                [1, 2, 3],
+                [inspect.Parameter.empty, inspect.Parameter.empty, inspect.Parameter.empty],
+                id="all_required_no_defaults_injected",
+            ),
+            pytest.param(
+                lambda required, optional=10: required + optional,
+                {"required": 5},
+                [],
+                [inspect.Parameter.empty, 10],
+                id="params_before_first_default_stay_required",
+            ),
+            pytest.param(
+                lambda a, b=1, c=2, d=3: a + b + c + d,
+                {"a": 10},
+                [],
+                [inspect.Parameter.empty, 1, 2, 3],
+                id="explicit_defaults_after_first_default_preserved",
+            ),
+            pytest.param(
+                lambda no_default_1, no_default_2, first_default=42, after=None: None,
+                {},
+                [1, 2],
+                [inspect.Parameter.empty, inspect.Parameter.empty, 42, None],
+                id="first_default_defines_boundary",
+            ),
+            pytest.param(
+                lambda a=1, b=2, c=3: a + b + c,
+                {},
+                [],
+                [1, 2, 3],
+                id="all_params_have_defaults_none_overwritten",
+            ),
+        ],
+    )
+    def test_param_defaults(self, func, op_kwargs, op_args, expected_defaults):
+        op = make_op(func, op_kwargs=op_kwargs, op_args=op_args)
+        sig = inspect.signature(op.python_callable)
+        actual = [p.default for p in sig.parameters.values()]
+        assert actual == expected_defaults
+
+    def test_context_key_default_none_does_not_raise(self):
+        from airflow.sdk.bases.decorator import KNOWN_CONTEXT_KEYS
+
+        ctx_key = next(iter(KNOWN_CONTEXT_KEYS))
+        f = _make_func(f"def dummy_task(x, {ctx_key}=None): return x")
+        assert make_op(f, op_kwargs={"x": 1}) is not None
+
+    def test_context_key_with_non_none_default_raises(self):
+        from airflow.sdk.bases.decorator import KNOWN_CONTEXT_KEYS
+
+        ctx_key = next(iter(KNOWN_CONTEXT_KEYS))
+        f = _make_func(f"def dummy_task(x, {ctx_key}='bad_default'): return x")
+        with pytest.raises(ValueError, match="can't have a default other than None"):
+            make_op(f, op_kwargs={"x": 1})
+
+    @pytest.mark.parametrize(
+        ("func_src", "op_kwargs"),
+        [
+            pytest.param(
+                "def dummy_task({ctx0}, x, y=10): return (x, y)",
+                {"x": 1},
+                id="context_key_before_first_default_shifts_boundary",
+            ),
+            pytest.param(
+                "def dummy_task(x, y=5, {ctx0}=None): return (x, y)",
+                {"x": 1},
+                id="context_key_after_regular_default",
+            ),
+            pytest.param(
+                "def dummy_task(a, {ctx0}=None, b=7, {ctx1}=None): return (a, b)",
+                {"a": 1},
+                id="multiple_context_keys_mixed_with_regular_defaults",
+            ),
+            pytest.param(
+                "def dummy_task({ctx0}, x, y=10): return (x, y)",
+                {"x": 42},
+                id="required_param_between_context_key_and_regular_default_gets_none",
+            ),
+            pytest.param(
+                "def dummy_task({ctx0}=None, {ctx1}=None, {ctx2}=None): return True",
+                {},
+                id="context_key_only_signature",
+            ),
+        ],
+    )
+    def test_context_key_construction_succeeds(self, func_src, op_kwargs):
+        """All context-key signature shapes must construct without raising."""
+        from airflow.sdk.bases.decorator import KNOWN_CONTEXT_KEYS
+
+        ctx_keys = list(KNOWN_CONTEXT_KEYS)
+        src = func_src.format(
+            ctx0=ctx_keys[0],
+            ctx1=ctx_keys[1] if len(ctx_keys) > 1 else ctx_keys[0],
+            ctx2=ctx_keys[2] if len(ctx_keys) > 2 else ctx_keys[0],
+        )
+        op = make_op(_make_func(src), op_kwargs=op_kwargs)
+        assert op is not None
+
+    def test_context_key_after_regular_default_preserves_original_default(self):
+        from airflow.sdk.bases.decorator import KNOWN_CONTEXT_KEYS
+
+        ctx_key = next(iter(KNOWN_CONTEXT_KEYS))
+        f = _make_func(f"def dummy_task(x, y=5, {ctx_key}=None): return (x, y)")
+        op = make_op(f, op_kwargs={"x": 1})
+        sig = inspect.signature(op.python_callable)
+        y_param = next(p for p in sig.parameters.values() if p.name == "y")
+        assert y_param.default == 5
+
+    def test_non_context_param_after_context_key_gets_none_injected(self):
+        from airflow.sdk.bases.decorator import KNOWN_CONTEXT_KEYS
+
+        ctx_key = next(iter(KNOWN_CONTEXT_KEYS))
+        f = _make_func(f"def dummy_task({ctx_key}, a): ...")
+        assert make_op(f, op_kwargs={"a": "2024-01-01"}) is not None
+        assert make_op(f) is not None

Review Comment:
   This asserts `make_op(f)` with zero kwargs succeeds, meaning a required 
non-context-key param `a` is silently given `default=None`. Is that the 
intended behavior? A user who writes `@task def foo(start_date, my_data): ...; 
foo()` (forgetting `my_data`) won't get any parse-time error.
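   A minimal illustration with plain `inspect`, assuming `start_date` is a context key, so that the PR's logic ends up giving both parameters `default=None`. The `filled` signature is built by hand to match that outcome, and `binds` is a hypothetical helper:

```python
import inspect


def foo(start_date, my_data):
    return my_data


original = inspect.signature(foo)
# Approximates the PR's validation signature for foo: start_date gets None as a
# context key, and my_data gets None because it follows the first default.
filled = original.replace(
    parameters=[p.replace(default=None) for p in original.parameters.values()]
)


def binds(sig, *args, **kwargs):
    """Return True if sig.bind() accepts the arguments, False on TypeError."""
    try:
        sig.bind(*args, **kwargs)
    except TypeError:
        return False
    return True


# Forgetting my_data: the original signature rejects the call...
print(binds(original, start_date=None))  # False
# ...but the filled signature accepts a call with no arguments at all.
print(binds(filled))  # True
```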



##########
task-sdk/src/airflow/sdk/bases/decorator.py:
##########
@@ -313,6 +313,23 @@ def __init__(
             param.replace(default=None) if param.name in KNOWN_CONTEXT_KEYS else param
             for param in signature.parameters.values()
         ]
+
+        # Python requires that positional parameters with defaults don't precede those without.
+        # This only applies to POSITIONAL_ONLY and POSITIONAL_OR_KEYWORD parameters; *args,
+        # **kwargs, and keyword-only parameters follow different rules.
+        positional_kinds = (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD)
+        positional = [(i, p) for i, p in enumerate(parameters) if p.kind in positional_kinds]
+        first_default_idx = next((i for i, p in positional if p.default != inspect.Parameter.empty), None)
+        if first_default_idx is not None:
+            parameters = [
+                param.replace(default=None)

Review Comment:
   This injects `default=None` for all positional params after the first 
defaulted one, not just context-key params. If a user writes `def 
foo(start_date, my_data): ...; foo()` and forgets `my_data`, the parse-time 
`signature.bind()` check won't catch it because `my_data` now has a default 
too. The error moves to runtime.
   
   Before this PR the error message was confusing, but at least it happened 
early. Worth considering whether only params between context-key-defaulted 
params and the next user-provided default should get `None`, rather than all 
trailing positional params.
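   One possible alternative, offered as an assumption rather than the PR's approach or a reading of the proposal above: leave all defaults untouched and, at validation time, pass a placeholder only for context-key parameters the caller did not bind. `check_call` and `CONTEXT_KEYS` are hypothetical names, not Airflow APIs:

```python
import inspect
from collections.abc import Callable, Iterable, Mapping, Sequence
from typing import Any, Optional

# Hypothetical stand-in for KNOWN_CONTEXT_KEYS; the real set is larger.
CONTEXT_KEYS = frozenset({"ds", "start_date"})


def check_call(
    func: Callable,
    op_args: Sequence[Any] = (),
    op_kwargs: Optional[Mapping[str, Any]] = None,
    context_keys: Iterable[str] = CONTEXT_KEYS,
) -> None:
    """Raise TypeError if func cannot accept op_args/op_kwargs plus context.

    No defaults are rewritten, so required user parameters stay required.
    """
    op_kwargs = dict(op_kwargs or {})
    keys = set(context_keys)
    signature = inspect.signature(func)
    # bind_partial() tolerates missing arguments, so it reports what the
    # caller already supplied without requiring the rest.
    supplied = signature.bind_partial(*op_args, **op_kwargs).arguments
    # Context keys left unbound would be injected at runtime, so pass a
    # placeholder for each; positional-only ones cannot be passed by keyword
    # and are not handled by this sketch.
    placeholders = {
        name: None
        for name, param in signature.parameters.items()
        if name in keys
        and name not in supplied
        and param.kind is not inspect.Parameter.POSITIONAL_ONLY
    }
    signature.bind(*op_args, **op_kwargs, **placeholders)


def foo(start_date, my_data):
    return my_data


def foo_dates(start_date, end_date):
    return (start_date, end_date)
```

   With this shape, `check_call(foo)` still raises `TypeError` for the missing `my_data`, while `check_call(foo_dates, op_args=[None, None])` (the #56128 call) and `check_call(foo, op_kwargs={"my_data": 1})` both pass.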



-- 
This is an automated message from the Apache Git Service.