jason810496 commented on code in PR #58615:
URL: https://github.com/apache/airflow/pull/58615#discussion_r2567505210
##########
task-sdk/tests/task_sdk/definitions/test_param.py:
##########
@@ -307,3 +311,47 @@ def test_update(self):
def test_repr(self):
pd = ParamsDict({"key": Param("value", type="string")})
assert repr(pd) == "{'key': 'value'}"
+
+ @pytest.mark.parametrize("source", ("dag", "task"))
+ def test_fill_missing_param_source(self, source: Literal["dag", "task"]):
+ pd = ParamsDict(
+ {
+ "key": Param("value", type="string"),
+ "key2": "value2",
+ }
+ )
+ pd._fill_missing_param_source(source)
+ for param in pd.values():
+ assert param.source == source
+
+ def test_fill_missing_param_source_not_overwrite_existing(self):
+ pd = ParamsDict(
+ {
+ "key": Param("value", type="string", source="dag"),
+ "key2": "value2",
+ "key3": "value3",
+ }
+ )
+ pd._fill_missing_param_source("task")
+ for key, expected_source in (
+ ("key", "Dag"),
Review Comment:
```suggestion
("key", "dag"),
```
##########
task-sdk/tests/task_sdk/definitions/test_param.py:
##########
@@ -307,3 +311,47 @@ def test_update(self):
def test_repr(self):
pd = ParamsDict({"key": Param("value", type="string")})
assert repr(pd) == "{'key': 'value'}"
+
+ @pytest.mark.parametrize("source", ("dag", "task"))
+ def test_fill_missing_param_source(self, source: Literal["dag", "task"]):
+ pd = ParamsDict(
+ {
+ "key": Param("value", type="string"),
+ "key2": "value2",
+ }
+ )
+ pd._fill_missing_param_source(source)
+ for param in pd.values():
+ assert param.source == source
+
+ def test_fill_missing_param_source_not_overwrite_existing(self):
+ pd = ParamsDict(
+ {
+ "key": Param("value", type="string", source="dag"),
+ "key2": "value2",
+ "key3": "value3",
+ }
+ )
+ pd._fill_missing_param_source("task")
+ for key, expected_source in (
+ ("key", "Dag"),
+ ("key2", "task"),
+ ("key3", "task"),
+ ):
+ assert pd.get_param(key).source == expected_source
+
+ def test_filter_params_by_source(self):
+ pd = ParamsDict(
+ {
+ "key": Param("value", type="string", source="dag"),
+ "key2": Param("value", source="task"),
+ }
+ )
+ assert ParamsDict.filter_params_by_source(pd, "Dag") == ParamsDict(
+ {"key": Param("value", type="string", source="Dag")},
Review Comment:
```suggestion
assert ParamsDict.filter_params_by_source(pd, "dag") == ParamsDict(
{"key": Param("value", type="string", source="dag")},
```
##########
task-sdk/src/airflow/sdk/definitions/param.py:
##########
@@ -253,6 +275,20 @@ def deserialize(data: dict, version: int) -> ParamsDict:
return ParamsDict(data)
+ def _fill_missing_param_source(
+ self,
+ source: Literal["dag", "task"] | None = None,
+ ) -> None:
+ for key in self.__dict:
+ if self.__dict[key].source is None:
+ self.__dict[key].source = source
+
+ @staticmethod
+ def filter_params_by_source(params: ParamsDict, source: Literal["Dag",
"task"]) -> ParamsDict:
Review Comment:
```suggestion
def filter_params_by_source(params: ParamsDict, source: Literal["dag",
"task"]) -> ParamsDict:
```
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]