This is an automated email from the ASF dual-hosted git repository.

jasonliu pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 1f6202df9e9 feat(param): add source to Param (#58615)
1f6202df9e9 is described below

commit 1f6202df9e95fe544256a759d4ec843d05b1073e
Author: Wei Lee <[email protected]>
AuthorDate: Sat Nov 29 00:03:36 2025 +0800

    feat(param): add source to Param (#58615)
---
 .../core_api/services/ui/connections.py            |  4 +-
 .../src/airflow/serialization/definitions/param.py | 12 ++++-
 .../airflow/serialization/serialized_objects.py    |  4 +-
 .../core_api/routes/public/test_dags.py            | 18 ++++++-
 .../core_api/routes/public/test_tasks.py           | 38 +++++++++++++--
 .../unit/serialization/test_dag_serialization.py   |  1 +
 .../airflow/providers/standard/operators/hitl.py   |  7 +++
 .../tests/unit/standard/operators/test_hitl.py     | 36 +++++++-------
 .../tests/unit/standard/triggers/test_hitl.py      | 20 ++++++--
 task-sdk/src/airflow/sdk/bases/operator.py         | 10 +++-
 task-sdk/src/airflow/sdk/definitions/param.py      | 46 ++++++++++++++++--
 task-sdk/tests/task_sdk/bases/test_operator.py     | 31 ++++++++++++
 task-sdk/tests/task_sdk/definitions/test_param.py  | 56 ++++++++++++++++++++--
 13 files changed, 242 insertions(+), 41 deletions(-)

diff --git a/airflow-core/src/airflow/api_fastapi/core_api/services/ui/connections.py b/airflow-core/src/airflow/api_fastapi/core_api/services/ui/connections.py
index ef045813e87..787f43c9a73 100644
--- a/airflow-core/src/airflow/api_fastapi/core_api/services/ui/connections.py
+++ b/airflow-core/src/airflow/api_fastapi/core_api/services/ui/connections.py
@@ -20,7 +20,7 @@ from __future__ import annotations
 import logging
 from collections.abc import MutableMapping
 from functools import cache
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Literal
 
 from airflow.api_fastapi.core_api.datamodels.connections import (
     ConnectionHookFieldBehavior,
@@ -68,6 +68,7 @@ class HookMetaService:
             description: str = "",
             default: str | None = None,
             widget=None,
+            source: Literal["dag", "task"] | None = None,
         ):
             type: str | list[str] = [self.param_type, "null"]
             enum = {}
@@ -82,6 +83,7 @@ class HookMetaService:
                 default=default,
                 title=label,
                 description=description or None,
+                source=source or None,
                 type=type,
                 **format,
                 **enum,
diff --git a/airflow-core/src/airflow/serialization/definitions/param.py b/airflow-core/src/airflow/serialization/definitions/param.py
index 733131f3eab..12470c5fdd2 100644
--- a/airflow-core/src/airflow/serialization/definitions/param.py
+++ b/airflow-core/src/airflow/serialization/definitions/param.py
@@ -20,7 +20,7 @@ from __future__ import annotations
 
 import collections.abc
 import copy
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any, Literal
 
 from airflow.serialization.definitions.notset import NOTSET, is_arg_set
 
@@ -31,11 +31,18 @@ if TYPE_CHECKING:
 class SerializedParam:
     """Server-side param class for deserialization."""
 
-    def __init__(self, default: Any = NOTSET, description: str | None = None, **schema):
+    def __init__(
+        self,
+        default: Any = NOTSET,
+        description: str | None = None,
+        source: Literal["dag", "task"] | None = None,
+        **schema,
+    ):
         # No validation needed - the SDK already validated the default.
         self.value = default
         self.description = description
         self.schema = schema
+        self.source = source
 
     def resolve(self, *, raises: bool = False) -> Any:
         """
@@ -66,6 +73,7 @@ class SerializedParam:
             "value": self.resolve(),
             "schema": self.schema,
             "description": self.description,
+            "source": self.source,
         }
 
 
diff --git a/airflow-core/src/airflow/serialization/serialized_objects.py b/airflow-core/src/airflow/serialization/serialized_objects.py
index c0f2b488322..d9c19382077 100644
--- a/airflow-core/src/airflow/serialization/serialized_objects.py
+++ b/airflow-core/src/airflow/serialization/serialized_objects.py
@@ -1037,6 +1037,7 @@ class BaseSerialization:
             "default": cls.serialize(param.value),
             "description": cls.serialize(param.description),
             "schema": cls.serialize(param.schema),
+            "source": cls.serialize(getattr(param, "source", None)),
         }
 
     @classmethod
@@ -1048,7 +1049,7 @@ class BaseSerialization:
         this class's ``serialize`` method.  So before running through ``deserialize``,
         we first verify that it's necessary to do.
         """
-        attrs = ("default", "description", "schema")
+        attrs = ("default", "description", "schema", "source")
         kwargs = {}
 
         def is_serialized(val):
@@ -1068,6 +1069,7 @@ class BaseSerialization:
         return SerializedParam(
             default=kwargs.get("default"),
             description=kwargs.get("description"),
+            source=kwargs.get("source", None),
             **(kwargs.get("schema") or {}),
         )
 
diff --git a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_dags.py b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_dags.py
index 891d173aeba..15a73951815 100644
--- a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_dags.py
+++ b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_dags.py
@@ -944,7 +944,14 @@ class TestDagDetails(TestDagEndpoint):
             "next_dagrun_run_after": None,
             "owners": ["airflow"],
             "owner_links": {},
-            "params": {"foo": {"value": 1, "schema": {}, "description": None}},
+            "params": {
+                "foo": {
+                    "value": 1,
+                    "schema": {},
+                    "description": None,
+                    "source": None,
+                }
+            },
             "relative_fileloc": "test_dags.py",
             "render_template_as_native_obj": False,
             "timetable_summary": None,
@@ -1034,7 +1041,14 @@ class TestDagDetails(TestDagEndpoint):
             "next_dagrun_run_after": None,
             "owners": ["airflow"],
             "owner_links": {},
-            "params": {"foo": {"value": 1, "schema": {}, "description": None}},
+            "params": {
+                "foo": {
+                    "value": 1,
+                    "schema": {},
+                    "description": None,
+                    "source": None,
+                }
+            },
             "relative_fileloc": "test_dags.py",
             "render_template_as_native_obj": False,
             "timetable_summary": None,
diff --git a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_tasks.py b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_tasks.py
index 1bd788d1436..dd85ad1c325 100644
--- a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_tasks.py
+++ b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_tasks.py
@@ -102,7 +102,7 @@ class TestGetTask(TestTaskEndpoint):
             "extra_links": [],
             "operator_name": "EmptyOperator",
             "owner": "airflow",
-            "params": {"foo": {"value": "bar", "schema": {}, "description": None}},
+            "params": {"foo": {"value": "bar", "schema": {}, "description": None, "source": "task"}},
             "pool": "default_pool",
             "pool_slots": 1.0,
             "priority_weight": 1.0,
@@ -180,7 +180,14 @@ class TestGetTask(TestTaskEndpoint):
             "extra_links": [],
             "operator_name": "EmptyOperator",
             "owner": "airflow",
-            "params": {"is_unscheduled": {"value": True, "schema": {}, "description": None}},
+            "params": {
+                "is_unscheduled": {
+                    "value": True,
+                    "schema": {},
+                    "description": None,
+                    "source": "task",
+                }
+            },
             "pool": "default_pool",
             "pool_slots": 1.0,
             "priority_weight": 1.0,
@@ -239,7 +246,14 @@ class TestGetTask(TestTaskEndpoint):
             "extra_links": [],
             "operator_name": "EmptyOperator",
             "owner": "airflow",
-            "params": {"foo": {"value": "bar", "schema": {}, "description": None}},
+            "params": {
+                "foo": {
+                    "value": "bar",
+                    "schema": {},
+                    "description": None,
+                    "source": "task",
+                }
+            },
             "pool": "default_pool",
             "pool_slots": 1.0,
             "priority_weight": 1.0,
@@ -304,7 +318,14 @@ class TestGetTasks(TestTaskEndpoint):
                     "extra_links": [],
                     "operator_name": "EmptyOperator",
                     "owner": "airflow",
-                    "params": {"foo": {"value": "bar", "schema": {}, "description": None}},
+                    "params": {
+                        "foo": {
+                            "value": "bar",
+                            "schema": {},
+                            "description": None,
+                            "source": "task",
+                        }
+                    },
                     "pool": "default_pool",
                     "pool_slots": 1.0,
                     "priority_weight": 1.0,
@@ -459,7 +480,14 @@ class TestGetTasks(TestTaskEndpoint):
                     "extra_links": [],
                     "operator_name": "EmptyOperator",
                     "owner": "airflow",
-                    "params": {"is_unscheduled": {"value": True, "schema": {}, "description": None}},
+                    "params": {
+                        "is_unscheduled": {
+                            "value": True,
+                            "schema": {},
+                            "description": None,
+                            "source": "task",
+                        }
+                    },
                     "pool": "default_pool",
                     "pool_slots": 1.0,
                     "priority_weight": 1.0,
diff --git a/airflow-core/tests/unit/serialization/test_dag_serialization.py b/airflow-core/tests/unit/serialization/test_dag_serialization.py
index c6fe7f5c0ae..fb261921aef 100644
--- a/airflow-core/tests/unit/serialization/test_dag_serialization.py
+++ b/airflow-core/tests/unit/serialization/test_dag_serialization.py
@@ -1193,6 +1193,7 @@ class TestStringifiedDAGs:
             "value": None if param.value is NOTSET else param.value,
             "schema": param.schema,
             "description": param.description,
+            "source": None,
         }
 
     @pytest.mark.parametrize(
diff --git a/providers/standard/src/airflow/providers/standard/operators/hitl.py b/providers/standard/src/airflow/providers/standard/operators/hitl.py
index 8b5c0cdd5b6..de208174827 100644
--- a/providers/standard/src/airflow/providers/standard/operators/hitl.py
+++ b/providers/standard/src/airflow/providers/standard/operators/hitl.py
@@ -84,6 +84,13 @@ class HITLOperator(BaseOperator):
         self.multiple = multiple
 
         self.params: ParamsDict = params if isinstance(params, ParamsDict) else ParamsDict(params or {})
+        if hasattr(ParamsDict, "filter_params_by_source"):
+            # Params that exist only in Dag level does not make sense to appear in HITLOperator
+            self.params = ParamsDict.filter_params_by_source(self.params, source="task")
+        elif self.params:
+            self.log.debug(
+                "ParamsDict.filter_params_by_source not available; HITLOperator will also include Dag level params."
+            )
 
         self.notifiers: Sequence[BaseNotifier] = (
             [notifiers] if isinstance(notifiers, BaseNotifier) else notifiers or []
diff --git a/providers/standard/tests/unit/standard/operators/test_hitl.py b/providers/standard/tests/unit/standard/operators/test_hitl.py
index 18984cb8912..4e59f2ae503 100644
--- a/providers/standard/tests/unit/standard/operators/test_hitl.py
+++ b/providers/standard/tests/unit/standard/operators/test_hitl.py
@@ -18,7 +18,7 @@ from __future__ import annotations
 
 import pytest
 
-from tests_common.test_utils.version_compat import AIRFLOW_V_3_1_PLUS
+from tests_common.test_utils.version_compat import AIRFLOW_V_3_1_PLUS, AIRFLOW_V_3_2_PLUS
 
 if not AIRFLOW_V_3_1_PLUS:
     pytest.skip("Human in the loop is only compatible with Airflow >= 3.1.0", allow_module_level=True)
@@ -240,19 +240,23 @@ class TestHITLOperator:
         assert hitl_detail_model.responded_by is None
         assert hitl_detail_model.chosen_options is None
         assert hitl_detail_model.params_input == {}
-        if AIRFLOW_V_3_1_3_PLUS:
-            assert hitl_detail_model.params == {
-                "input_1": {
-                    "value": 1,
-                    "description": None,
-                    "schema": {},
-                }
-            }
+        expected_params: dict[str, Any]
+        if AIRFLOW_V_3_2_PLUS:
+            expected_params = {"input_1": {"value": 1, "description": None, "schema": {}, "source": "task"}}
+        elif AIRFLOW_V_3_1_3_PLUS:
+            expected_params = {"input_1": {"value": 1, "description": None, "schema": {}}}
         else:
-            assert hitl_detail_model.params == {"input_1": 1}
+            expected_params = {"input_1": 1}
+        assert hitl_detail_model.params == expected_params
 
         assert notifier.called is True
 
+        expected_params_in_trigger_kwargs: dict[str, dict[str, Any]]
+        if AIRFLOW_V_3_2_PLUS:
+            expected_params_in_trigger_kwargs = expected_params
+        else:
+            expected_params_in_trigger_kwargs = {"input_1": {"value": 1, "description": None, "schema": {}}}
+
         registered_trigger = session.scalar(
             select(Trigger).where(Trigger.classpath == "airflow.providers.standard.triggers.hitl.HITLTrigger")
         )
@@ -261,13 +265,7 @@ class TestHITLOperator:
             "ti_id": ti.id,
             "options": ["1", "2", "3", "4", "5"],
             "defaults": ["1"],
-            "params": {
-                "input_1": {
-                    "value": 1,
-                    "description": None,
-                    "schema": {},
-                }
-            },
+            "params": expected_params_in_trigger_kwargs,
             "multiple": False,
             "timeout_datetime": None,
             "poke_interval": 5.0,
@@ -323,6 +321,10 @@ class TestHITLOperator:
             options=["1", "2", "3", "4", "5"],
             params=input_params,
         )
+        if AIRFLOW_V_3_2_PLUS:
+            for key in expected_params:
+                expected_params[key]["source"] = "task"
+
         assert hitl_op.serialized_params == expected_params
 
     @pytest.mark.skipif(
diff --git a/providers/standard/tests/unit/standard/triggers/test_hitl.py b/providers/standard/tests/unit/standard/triggers/test_hitl.py
index adb82ff00c0..7166d4428bb 100644
--- a/providers/standard/tests/unit/standard/triggers/test_hitl.py
+++ b/providers/standard/tests/unit/standard/triggers/test_hitl.py
@@ -21,7 +21,7 @@ from typing import Any
 
 import pytest
 
-from tests_common.test_utils.version_compat import AIRFLOW_V_3_1_PLUS
+from tests_common.test_utils.version_compat import AIRFLOW_V_3_1_PLUS, AIRFLOW_V_3_2_PLUS
 
 if not AIRFLOW_V_3_1_PLUS:
     pytest.skip("Human in the loop public API compatible with Airflow >= 3.1.0", allow_module_level=True)
@@ -50,7 +50,12 @@ def default_trigger_args() -> dict[str, Any]:
         "ti_id": TI_ID,
         "options": ["1", "2", "3", "4", "5"],
         "params": {
-            "input": {"value": 1, "schema": {}, "description": None},
+            "input": {
+                "value": 1,
+                "schema": {},
+                "description": None,
+                "source": "task",
+            },
         },
         "multiple": False,
     }
@@ -65,11 +70,20 @@ class TestHITLTrigger:
             **default_trigger_args,
         )
         classpath, kwargs = trigger.serialize()
+
+        expected_params_in_trigger_kwargs: dict[str, dict[str, Any]]
+        if AIRFLOW_V_3_2_PLUS:
+            expected_params_in_trigger_kwargs = {
+                "input": {"value": 1, "description": None, "schema": {}, "source": "task"}
+            }
+        else:
+            expected_params_in_trigger_kwargs = {"input": {"value": 1, "description": None, "schema": {}}}
+
         assert classpath == "airflow.providers.standard.triggers.hitl.HITLTrigger"
         assert kwargs == {
             "ti_id": TI_ID,
             "options": ["1", "2", "3", "4", "5"],
-            "params": {"input": {"value": 1, "description": None, "schema": {}}},
+            "params": expected_params_in_trigger_kwargs,
             "defaults": ["1"],
             "multiple": False,
             "timeout_datetime": None,
diff --git a/task-sdk/src/airflow/sdk/bases/operator.py b/task-sdk/src/airflow/sdk/bases/operator.py
index f040d99bc2f..27c3e702d53 100644
--- a/task-sdk/src/airflow/sdk/bases/operator.py
+++ b/task-sdk/src/airflow/sdk/bases/operator.py
@@ -138,6 +138,7 @@ def _get_parent_defaults(dag: DAG | None, task_group: TaskGroup | None) -> tuple
         return {}, ParamsDict()
     dag_args = copy.copy(dag.default_args)
     dag_params = copy.deepcopy(dag.params)
+    dag_params._fill_missing_param_source("dag")
     if task_group:
         if task_group.default_args and not isinstance(task_group.default_args, collections.abc.Mapping):
             raise TypeError("default_args must be a mapping")
@@ -155,13 +156,20 @@ def get_merged_defaults(
     if task_params:
         if not isinstance(task_params, collections.abc.Mapping):
             raise TypeError(f"params must be a mapping, got {type(task_params)}")
+
+        task_params = ParamsDict(task_params)
+        task_params._fill_missing_param_source("task")
         params.update(task_params)
+
     if task_default_args:
         if not isinstance(task_default_args, collections.abc.Mapping):
             raise TypeError(f"default_args must be a mapping, got {type(task_params)}")
         args.update(task_default_args)
         with contextlib.suppress(KeyError):
-            params.update(task_default_args["params"] or {})
+            if params_from_default_args := ParamsDict(task_default_args["params"] or {}):
+                params_from_default_args._fill_missing_param_source("task")
+                params.update(params_from_default_args)
+
     return args, params
 
 
diff --git a/task-sdk/src/airflow/sdk/definitions/param.py b/task-sdk/src/airflow/sdk/definitions/param.py
index 410da71fde7..d1174ec561b 100644
--- a/task-sdk/src/airflow/sdk/definitions/param.py
+++ b/task-sdk/src/airflow/sdk/definitions/param.py
@@ -21,7 +21,7 @@ import copy
 import json
 import logging
 from collections.abc import ItemsView, Iterable, Mapping, MutableMapping, ValuesView
-from typing import TYPE_CHECKING, Any, ClassVar
+from typing import TYPE_CHECKING, Any, ClassVar, Literal
 
 from airflow.sdk.definitions._internal.mixins import ResolveMixin
 from airflow.sdk.definitions._internal.types import NOTSET, is_arg_set
@@ -51,15 +51,27 @@ class Param:
 
     CLASS_IDENTIFIER = "__class"
 
-    def __init__(self, default: Any = NOTSET, description: str | None = None, **kwargs):
+    def __init__(
+        self,
+        default: Any = NOTSET,
+        description: str | None = None,
+        source: Literal["dag", "task"] | None = None,
+        **kwargs,
+    ):
         if default is not NOTSET:
             self._check_json(default)
         self.value = default
         self.description = description
         self.schema = kwargs.pop("schema") if "schema" in kwargs else kwargs
+        self.source = source
 
     def __copy__(self) -> Param:
-        return Param(self.value, self.description, schema=self.schema)
+        return Param(
+            self.value,
+            self.description,
+            schema=self.schema,
+            source=self.source,
+        )
 
     @staticmethod
     def _check_json(value):
@@ -119,14 +131,24 @@ class Param:
         return self.value is not NOTSET and self.value is not None
 
     def serialize(self) -> dict:
-        return {"value": self.value, "description": self.description, "schema": self.schema}
+        return {
+            "value": self.value,
+            "description": self.description,
+            "schema": self.schema,
+            "source": self.source,
+        }
 
     @staticmethod
     def deserialize(data: dict[str, Any], version: int) -> Param:
         if version > Param.__version__:
             raise TypeError("serialized version > class version")
 
-        return Param(default=data["value"], description=data["description"], schema=data["schema"])
+        return Param(
+            default=data["value"],
+            description=data["description"],
+            schema=data["schema"],
+            source=data.get("source", None),
+        )
 
 
 class ParamsDict(MutableMapping[str, Any]):
@@ -253,6 +275,20 @@ class ParamsDict(MutableMapping[str, Any]):
 
         return ParamsDict(data)
 
+    def _fill_missing_param_source(
+        self,
+        source: Literal["dag", "task"] | None = None,
+    ) -> None:
+        for key in self.__dict:
+            if self.__dict[key].source is None:
+                self.__dict[key].source = source
+
+    @staticmethod
+    def filter_params_by_source(params: ParamsDict, source: Literal["dag", "task"]) -> ParamsDict:
+        return ParamsDict(
+            {key: param for key, param in params.__dict.items() if param.source == source},
+        )
+
 
 class DagParam(ResolveMixin):
     """
diff --git a/task-sdk/tests/task_sdk/bases/test_operator.py b/task-sdk/tests/task_sdk/bases/test_operator.py
index 535ba511d09..5db27774e23 100644
--- a/task-sdk/tests/task_sdk/bases/test_operator.py
+++ b/task-sdk/tests/task_sdk/bases/test_operator.py
@@ -41,6 +41,7 @@ from airflow.sdk.bases.operator import (
 )
 from airflow.sdk.definitions.dag import DAG
 from airflow.sdk.definitions.edges import Label
+from airflow.sdk.definitions.param import ParamsDict
 from airflow.sdk.definitions.taskgroup import TaskGroup
 from airflow.sdk.definitions.template import literal
 from airflow.task.priority_strategy import _DownstreamPriorityWeightStrategy, _UpstreamPriorityWeightStrategy
@@ -785,6 +786,36 @@ class TestBaseOperator:
         task.render_template_fields(context={"foo": "whatever", "bar": "whatever"})
         assert mock_jinja_env.call_count == 1
 
+    def test_params_source(self):
+        # Test bug when copying an operator attached to a Dag
+        with DAG(
+            "dag0",
+            params=ParamsDict(
+                {
+                    "param from Dag": "value1",
+                    "overwritten by task": "value 2",
+                }
+            ),
+            schedule=None,
+            start_date=DEFAULT_DATE,
+        ):
+            op1 = MockOperator(
+                task_id="task1",
+                params=ParamsDict(
+                    {
+                        "overwritten by task": "value 3",
+                        "param from task": "value 4",
+                    }
+                ),
+            )
+
+        for key, expected_source in (
+            ("param from Dag", "dag"),
+            ("overwritten by task", "task"),
+            ("param from task", "task"),
+        ):
+            assert op1.params.get_param(key).source == expected_source
+
     def test_deepcopy(self):
         # Test bug when copying an operator attached to a Dag
         with DAG("dag0", schedule=None, start_date=DEFAULT_DATE) as dag:
diff --git a/task-sdk/tests/task_sdk/definitions/test_param.py b/task-sdk/tests/task_sdk/definitions/test_param.py
index aa2aa71c51a..887003d13dc 100644
--- a/task-sdk/tests/task_sdk/definitions/test_param.py
+++ b/task-sdk/tests/task_sdk/definitions/test_param.py
@@ -17,6 +17,7 @@
 from __future__ import annotations
 
 from contextlib import nullcontext
+from typing import Literal
 
 import pytest
 
@@ -207,10 +208,13 @@ class TestParam:
     def test_dump(self):
         p = Param("hello", description="world", type="string", minLength=2)
         dump = p.dump()
-        assert dump["__class"] == "airflow.sdk.definitions.param.Param"
-        assert dump["value"] == "hello"
-        assert dump["description"] == "world"
-        assert dump["schema"] == {"type": "string", "minLength": 2}
+        assert dump == {
+            "__class": "airflow.sdk.definitions.param.Param",
+            "value": "hello",
+            "description": "world",
+            "schema": {"type": "string", "minLength": 2},
+            "source": None,
+        }
 
     @pytest.mark.parametrize(
         "param",
@@ -307,3 +311,47 @@ class TestParamsDict:
     def test_repr(self):
         pd = ParamsDict({"key": Param("value", type="string")})
         assert repr(pd) == "{'key': 'value'}"
+
+    @pytest.mark.parametrize("source", ("dag", "task"))
+    def test_fill_missing_param_source(self, source: Literal["dag", "task"]):
+        pd = ParamsDict(
+            {
+                "key": Param("value", type="string"),
+                "key2": "value2",
+            }
+        )
+        pd._fill_missing_param_source(source)
+        for param in pd.values():
+            assert param.source == source
+
+    def test_fill_missing_param_source_not_overwrite_existing(self):
+        pd = ParamsDict(
+            {
+                "key": Param("value", type="string", source="dag"),
+                "key2": "value2",
+                "key3": "value3",
+            }
+        )
+        pd._fill_missing_param_source("task")
+        for key, expected_source in (
+            ("key", "dag"),
+            ("key2", "task"),
+            ("key3", "task"),
+        ):
+            assert pd.get_param(key).source == expected_source
+
+    def test_filter_params_by_source(self):
+        pd = ParamsDict(
+            {
+                "key": Param("value", type="string", source="dag"),
+                "key2": Param("value", source="task"),
+            }
+        )
+        assert ParamsDict.filter_params_by_source(pd, "dag") == ParamsDict(
+            {"key": Param("value", type="string", source="dag")},
+        )
+        assert ParamsDict.filter_params_by_source(pd, "task") == ParamsDict(
+            {
+                "key2": Param("value", source="task"),
+            }
+        )

Reply via email to