This is an automated email from the ASF dual-hosted git repository.
jasonliu pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 1f6202df9e9 feat(param): add source to Param (#58615)
1f6202df9e9 is described below
commit 1f6202df9e95fe544256a759d4ec843d05b1073e
Author: Wei Lee <[email protected]>
AuthorDate: Sat Nov 29 00:03:36 2025 +0800
feat(param): add source to Param (#58615)
---
.../core_api/services/ui/connections.py | 4 +-
.../src/airflow/serialization/definitions/param.py | 12 ++++-
.../airflow/serialization/serialized_objects.py | 4 +-
.../core_api/routes/public/test_dags.py | 18 ++++++-
.../core_api/routes/public/test_tasks.py | 38 +++++++++++++--
.../unit/serialization/test_dag_serialization.py | 1 +
.../airflow/providers/standard/operators/hitl.py | 7 +++
.../tests/unit/standard/operators/test_hitl.py | 36 +++++++-------
.../tests/unit/standard/triggers/test_hitl.py | 20 ++++++--
task-sdk/src/airflow/sdk/bases/operator.py | 10 +++-
task-sdk/src/airflow/sdk/definitions/param.py | 46 ++++++++++++++++--
task-sdk/tests/task_sdk/bases/test_operator.py | 31 ++++++++++++
task-sdk/tests/task_sdk/definitions/test_param.py | 56 ++++++++++++++++++++--
13 files changed, 242 insertions(+), 41 deletions(-)
diff --git a/airflow-core/src/airflow/api_fastapi/core_api/services/ui/connections.py b/airflow-core/src/airflow/api_fastapi/core_api/services/ui/connections.py
index ef045813e87..787f43c9a73 100644
--- a/airflow-core/src/airflow/api_fastapi/core_api/services/ui/connections.py
+++ b/airflow-core/src/airflow/api_fastapi/core_api/services/ui/connections.py
@@ -20,7 +20,7 @@ from __future__ import annotations
import logging
from collections.abc import MutableMapping
from functools import cache
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Literal
from airflow.api_fastapi.core_api.datamodels.connections import (
ConnectionHookFieldBehavior,
@@ -68,6 +68,7 @@ class HookMetaService:
description: str = "",
default: str | None = None,
widget=None,
+ source: Literal["dag", "task"] | None = None,
):
type: str | list[str] = [self.param_type, "null"]
enum = {}
@@ -82,6 +83,7 @@ class HookMetaService:
default=default,
title=label,
description=description or None,
+ source=source or None,
type=type,
**format,
**enum,
diff --git a/airflow-core/src/airflow/serialization/definitions/param.py b/airflow-core/src/airflow/serialization/definitions/param.py
index 733131f3eab..12470c5fdd2 100644
--- a/airflow-core/src/airflow/serialization/definitions/param.py
+++ b/airflow-core/src/airflow/serialization/definitions/param.py
@@ -20,7 +20,7 @@ from __future__ import annotations
import collections.abc
import copy
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any, Literal
from airflow.serialization.definitions.notset import NOTSET, is_arg_set
@@ -31,11 +31,18 @@ if TYPE_CHECKING:
class SerializedParam:
"""Server-side param class for deserialization."""
- def __init__(self, default: Any = NOTSET, description: str | None = None, **schema):
+ def __init__(
+ self,
+ default: Any = NOTSET,
+ description: str | None = None,
+ source: Literal["dag", "task"] | None = None,
+ **schema,
+ ):
# No validation needed - the SDK already validated the default.
self.value = default
self.description = description
self.schema = schema
+ self.source = source
def resolve(self, *, raises: bool = False) -> Any:
"""
@@ -66,6 +73,7 @@ class SerializedParam:
"value": self.resolve(),
"schema": self.schema,
"description": self.description,
+ "source": self.source,
}
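For reference, a minimal sketch (not part of this commit) of how the server-side SerializedParam carries the new field, assuming only the constructor shown in the hunk above:

    from airflow.serialization.definitions.param import SerializedParam

    # "source" defaults to None; any extra keyword arguments still land in .schema
    p = SerializedParam(default=1, description=None, source="task", type="integer")
    assert p.value == 1
    assert p.source == "task"
    assert p.schema == {"type": "integer"}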
diff --git a/airflow-core/src/airflow/serialization/serialized_objects.py b/airflow-core/src/airflow/serialization/serialized_objects.py
index c0f2b488322..d9c19382077 100644
--- a/airflow-core/src/airflow/serialization/serialized_objects.py
+++ b/airflow-core/src/airflow/serialization/serialized_objects.py
@@ -1037,6 +1037,7 @@ class BaseSerialization:
"default": cls.serialize(param.value),
"description": cls.serialize(param.description),
"schema": cls.serialize(param.schema),
+ "source": cls.serialize(getattr(param, "source", None)),
}
@classmethod
@@ -1048,7 +1049,7 @@ class BaseSerialization:
this class's ``serialize`` method. So before running through ``deserialize``,
we first verify that it's necessary to do.
"""
- attrs = ("default", "description", "schema")
+ attrs = ("default", "description", "schema", "source")
kwargs = {}
def is_serialized(val):
@@ -1068,6 +1069,7 @@ class BaseSerialization:
return SerializedParam(
default=kwargs.get("default"),
description=kwargs.get("description"),
+ source=kwargs.get("source", None),
**(kwargs.get("schema") or {}),
)
diff --git a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_dags.py b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_dags.py
index 891d173aeba..15a73951815 100644
--- a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_dags.py
+++ b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_dags.py
@@ -944,7 +944,14 @@ class TestDagDetails(TestDagEndpoint):
"next_dagrun_run_after": None,
"owners": ["airflow"],
"owner_links": {},
- "params": {"foo": {"value": 1, "schema": {}, "description": None}},
+ "params": {
+ "foo": {
+ "value": 1,
+ "schema": {},
+ "description": None,
+ "source": None,
+ }
+ },
"relative_fileloc": "test_dags.py",
"render_template_as_native_obj": False,
"timetable_summary": None,
@@ -1034,7 +1041,14 @@ class TestDagDetails(TestDagEndpoint):
"next_dagrun_run_after": None,
"owners": ["airflow"],
"owner_links": {},
- "params": {"foo": {"value": 1, "schema": {}, "description": None}},
+ "params": {
+ "foo": {
+ "value": 1,
+ "schema": {},
+ "description": None,
+ "source": None,
+ }
+ },
"relative_fileloc": "test_dags.py",
"render_template_as_native_obj": False,
"timetable_summary": None,
diff --git a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_tasks.py b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_tasks.py
index 1bd788d1436..dd85ad1c325 100644
--- a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_tasks.py
+++ b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_tasks.py
@@ -102,7 +102,7 @@ class TestGetTask(TestTaskEndpoint):
"extra_links": [],
"operator_name": "EmptyOperator",
"owner": "airflow",
-        "params": {"foo": {"value": "bar", "schema": {}, "description": None}},
+        "params": {"foo": {"value": "bar", "schema": {}, "description": None, "source": "task"}},
"pool": "default_pool",
"pool_slots": 1.0,
"priority_weight": 1.0,
@@ -180,7 +180,14 @@ class TestGetTask(TestTaskEndpoint):
"extra_links": [],
"operator_name": "EmptyOperator",
"owner": "airflow",
-        "params": {"is_unscheduled": {"value": True, "schema": {}, "description": None}},
+ "params": {
+ "is_unscheduled": {
+ "value": True,
+ "schema": {},
+ "description": None,
+ "source": "task",
+ }
+ },
"pool": "default_pool",
"pool_slots": 1.0,
"priority_weight": 1.0,
@@ -239,7 +246,14 @@ class TestGetTask(TestTaskEndpoint):
"extra_links": [],
"operator_name": "EmptyOperator",
"owner": "airflow",
-        "params": {"foo": {"value": "bar", "schema": {}, "description": None}},
+ "params": {
+ "foo": {
+ "value": "bar",
+ "schema": {},
+ "description": None,
+ "source": "task",
+ }
+ },
"pool": "default_pool",
"pool_slots": 1.0,
"priority_weight": 1.0,
@@ -304,7 +318,14 @@ class TestGetTasks(TestTaskEndpoint):
"extra_links": [],
"operator_name": "EmptyOperator",
"owner": "airflow",
-            "params": {"foo": {"value": "bar", "schema": {}, "description": None}},
+ "params": {
+ "foo": {
+ "value": "bar",
+ "schema": {},
+ "description": None,
+ "source": "task",
+ }
+ },
"pool": "default_pool",
"pool_slots": 1.0,
"priority_weight": 1.0,
@@ -459,7 +480,14 @@ class TestGetTasks(TestTaskEndpoint):
"extra_links": [],
"operator_name": "EmptyOperator",
"owner": "airflow",
-            "params": {"is_unscheduled": {"value": True, "schema": {}, "description": None}},
+ "params": {
+ "is_unscheduled": {
+ "value": True,
+ "schema": {},
+ "description": None,
+ "source": "task",
+ }
+ },
"pool": "default_pool",
"pool_slots": 1.0,
"priority_weight": 1.0,
diff --git a/airflow-core/tests/unit/serialization/test_dag_serialization.py b/airflow-core/tests/unit/serialization/test_dag_serialization.py
index c6fe7f5c0ae..fb261921aef 100644
--- a/airflow-core/tests/unit/serialization/test_dag_serialization.py
+++ b/airflow-core/tests/unit/serialization/test_dag_serialization.py
@@ -1193,6 +1193,7 @@ class TestStringifiedDAGs:
"value": None if param.value is NOTSET else param.value,
"schema": param.schema,
"description": param.description,
+ "source": None,
}
@pytest.mark.parametrize(
diff --git a/providers/standard/src/airflow/providers/standard/operators/hitl.py b/providers/standard/src/airflow/providers/standard/operators/hitl.py
index 8b5c0cdd5b6..de208174827 100644
--- a/providers/standard/src/airflow/providers/standard/operators/hitl.py
+++ b/providers/standard/src/airflow/providers/standard/operators/hitl.py
@@ -84,6 +84,13 @@ class HITLOperator(BaseOperator):
self.multiple = multiple
self.params: ParamsDict = params if isinstance(params, ParamsDict) else ParamsDict(params or {})
+        if hasattr(ParamsDict, "filter_params_by_source"):
+            # Params that exist only at the Dag level do not make sense to appear in HITLOperator
+            self.params = ParamsDict.filter_params_by_source(self.params, source="task")
+        elif self.params:
+            self.log.debug(
+                "ParamsDict.filter_params_by_source not available; HITLOperator will also include Dag-level params."
+            )
self.notifiers: Sequence[BaseNotifier] = (
[notifiers] if isinstance(notifiers, BaseNotifier) else notifiers or []
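As an illustration (not part of this commit), the filtering above is expected to drop params that only exist at the Dag level and keep task-level ones, assuming the ParamsDict.filter_params_by_source helper added in the task-sdk change further below:

    from airflow.sdk.definitions.param import Param, ParamsDict

    params = ParamsDict({
        "dag_only": Param(1, source="dag"),
        "task_level": Param(2, source="task"),
    })
    filtered = ParamsDict.filter_params_by_source(params, source="task")
    assert "task_level" in filtered
    assert "dag_only" not in filtered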
diff --git a/providers/standard/tests/unit/standard/operators/test_hitl.py b/providers/standard/tests/unit/standard/operators/test_hitl.py
index 18984cb8912..4e59f2ae503 100644
--- a/providers/standard/tests/unit/standard/operators/test_hitl.py
+++ b/providers/standard/tests/unit/standard/operators/test_hitl.py
@@ -18,7 +18,7 @@ from __future__ import annotations
import pytest
-from tests_common.test_utils.version_compat import AIRFLOW_V_3_1_PLUS
+from tests_common.test_utils.version_compat import AIRFLOW_V_3_1_PLUS, AIRFLOW_V_3_2_PLUS
if not AIRFLOW_V_3_1_PLUS:
pytest.skip("Human in the loop is only compatible with Airflow >= 3.1.0",
allow_module_level=True)
@@ -240,19 +240,23 @@ class TestHITLOperator:
assert hitl_detail_model.responded_by is None
assert hitl_detail_model.chosen_options is None
assert hitl_detail_model.params_input == {}
- if AIRFLOW_V_3_1_3_PLUS:
- assert hitl_detail_model.params == {
- "input_1": {
- "value": 1,
- "description": None,
- "schema": {},
- }
- }
+ expected_params: dict[str, Any]
+ if AIRFLOW_V_3_2_PLUS:
+            expected_params = {"input_1": {"value": 1, "description": None, "schema": {}, "source": "task"}}
+ elif AIRFLOW_V_3_1_3_PLUS:
+            expected_params = {"input_1": {"value": 1, "description": None, "schema": {}}}
else:
- assert hitl_detail_model.params == {"input_1": 1}
+ expected_params = {"input_1": 1}
+ assert hitl_detail_model.params == expected_params
assert notifier.called is True
+ expected_params_in_trigger_kwargs: dict[str, dict[str, Any]]
+ if AIRFLOW_V_3_2_PLUS:
+ expected_params_in_trigger_kwargs = expected_params
+ else:
+            expected_params_in_trigger_kwargs = {"input_1": {"value": 1, "description": None, "schema": {}}}
+
registered_trigger = session.scalar(
select(Trigger).where(Trigger.classpath == "airflow.providers.standard.triggers.hitl.HITLTrigger")
)
@@ -261,13 +265,7 @@ class TestHITLOperator:
"ti_id": ti.id,
"options": ["1", "2", "3", "4", "5"],
"defaults": ["1"],
- "params": {
- "input_1": {
- "value": 1,
- "description": None,
- "schema": {},
- }
- },
+ "params": expected_params_in_trigger_kwargs,
"multiple": False,
"timeout_datetime": None,
"poke_interval": 5.0,
@@ -323,6 +321,10 @@ class TestHITLOperator:
options=["1", "2", "3", "4", "5"],
params=input_params,
)
+ if AIRFLOW_V_3_2_PLUS:
+ for key in expected_params:
+ expected_params[key]["source"] = "task"
+
assert hitl_op.serialized_params == expected_params
@pytest.mark.skipif(
diff --git a/providers/standard/tests/unit/standard/triggers/test_hitl.py b/providers/standard/tests/unit/standard/triggers/test_hitl.py
index adb82ff00c0..7166d4428bb 100644
--- a/providers/standard/tests/unit/standard/triggers/test_hitl.py
+++ b/providers/standard/tests/unit/standard/triggers/test_hitl.py
@@ -21,7 +21,7 @@ from typing import Any
import pytest
-from tests_common.test_utils.version_compat import AIRFLOW_V_3_1_PLUS
+from tests_common.test_utils.version_compat import AIRFLOW_V_3_1_PLUS, AIRFLOW_V_3_2_PLUS
if not AIRFLOW_V_3_1_PLUS:
pytest.skip("Human in the loop public API compatible with Airflow >=
3.1.0", allow_module_level=True)
@@ -50,7 +50,12 @@ def default_trigger_args() -> dict[str, Any]:
"ti_id": TI_ID,
"options": ["1", "2", "3", "4", "5"],
"params": {
- "input": {"value": 1, "schema": {}, "description": None},
+ "input": {
+ "value": 1,
+ "schema": {},
+ "description": None,
+ "source": "task",
+ },
},
"multiple": False,
}
@@ -65,11 +70,20 @@ class TestHITLTrigger:
**default_trigger_args,
)
classpath, kwargs = trigger.serialize()
+
+ expected_params_in_trigger_kwargs: dict[str, dict[str, Any]]
+ if AIRFLOW_V_3_2_PLUS:
+ expected_params_in_trigger_kwargs = {
+            "input": {"value": 1, "description": None, "schema": {}, "source": "task"}
+ }
+ else:
+        expected_params_in_trigger_kwargs = {"input": {"value": 1, "description": None, "schema": {}}}
+
assert classpath == "airflow.providers.standard.triggers.hitl.HITLTrigger"
assert kwargs == {
"ti_id": TI_ID,
"options": ["1", "2", "3", "4", "5"],
-        "params": {"input": {"value": 1, "description": None, "schema": {}}},
+ "params": expected_params_in_trigger_kwargs,
"defaults": ["1"],
"multiple": False,
"timeout_datetime": None,
diff --git a/task-sdk/src/airflow/sdk/bases/operator.py b/task-sdk/src/airflow/sdk/bases/operator.py
index f040d99bc2f..27c3e702d53 100644
--- a/task-sdk/src/airflow/sdk/bases/operator.py
+++ b/task-sdk/src/airflow/sdk/bases/operator.py
@@ -138,6 +138,7 @@ def _get_parent_defaults(dag: DAG | None, task_group: TaskGroup | None) -> tuple
return {}, ParamsDict()
dag_args = copy.copy(dag.default_args)
dag_params = copy.deepcopy(dag.params)
+ dag_params._fill_missing_param_source("dag")
if task_group:
if task_group.default_args and not isinstance(task_group.default_args, collections.abc.Mapping):
raise TypeError("default_args must be a mapping")
@@ -155,13 +156,20 @@ def get_merged_defaults(
if task_params:
if not isinstance(task_params, collections.abc.Mapping):
raise TypeError(f"params must be a mapping, got
{type(task_params)}")
+
+ task_params = ParamsDict(task_params)
+ task_params._fill_missing_param_source("task")
params.update(task_params)
+
if task_default_args:
if not isinstance(task_default_args, collections.abc.Mapping):
raise TypeError(f"default_args must be a mapping, got
{type(task_params)}")
args.update(task_default_args)
with contextlib.suppress(KeyError):
- params.update(task_default_args["params"] or {})
+            if params_from_default_args := ParamsDict(task_default_args["params"] or {}):
+ params_from_default_args._fill_missing_param_source("task")
+ params.update(params_from_default_args)
+
return args, params
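A quick sketch (an assumption, not part of this commit) of the intended effect of the tagging above when merging defaults for a task that has no Dag; the keyword names dag, task_group, task_params and task_default_args are taken from the hunks above and assumed to be the full signature:

    from airflow.sdk.bases.operator import get_merged_defaults
    from airflow.sdk.definitions.param import Param

    # Params supplied at the task level are expected to come back tagged "task";
    # Dag-level params (none here) would be tagged "dag" via _get_parent_defaults.
    args, params = get_merged_defaults(
        dag=None,
        task_group=None,
        task_params={"x": Param(1)},
        task_default_args=None,
    )
    assert params.get_param("x").source == "task"

The new test_params_source test further below exercises the full Dag-plus-task case.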
diff --git a/task-sdk/src/airflow/sdk/definitions/param.py b/task-sdk/src/airflow/sdk/definitions/param.py
index 410da71fde7..d1174ec561b 100644
--- a/task-sdk/src/airflow/sdk/definitions/param.py
+++ b/task-sdk/src/airflow/sdk/definitions/param.py
@@ -21,7 +21,7 @@ import copy
import json
import logging
from collections.abc import ItemsView, Iterable, Mapping, MutableMapping, ValuesView
-from typing import TYPE_CHECKING, Any, ClassVar
+from typing import TYPE_CHECKING, Any, ClassVar, Literal
from airflow.sdk.definitions._internal.mixins import ResolveMixin
from airflow.sdk.definitions._internal.types import NOTSET, is_arg_set
@@ -51,15 +51,27 @@ class Param:
CLASS_IDENTIFIER = "__class"
- def __init__(self, default: Any = NOTSET, description: str | None = None, **kwargs):
+ def __init__(
+ self,
+ default: Any = NOTSET,
+ description: str | None = None,
+ source: Literal["dag", "task"] | None = None,
+ **kwargs,
+ ):
if default is not NOTSET:
self._check_json(default)
self.value = default
self.description = description
self.schema = kwargs.pop("schema") if "schema" in kwargs else kwargs
+ self.source = source
def __copy__(self) -> Param:
- return Param(self.value, self.description, schema=self.schema)
+ return Param(
+ self.value,
+ self.description,
+ schema=self.schema,
+ source=self.source,
+ )
@staticmethod
def _check_json(value):
@@ -119,14 +131,24 @@ class Param:
return self.value is not NOTSET and self.value is not None
def serialize(self) -> dict:
-        return {"value": self.value, "description": self.description, "schema": self.schema}
+ return {
+ "value": self.value,
+ "description": self.description,
+ "schema": self.schema,
+ "source": self.source,
+ }
@staticmethod
def deserialize(data: dict[str, Any], version: int) -> Param:
if version > Param.__version__:
raise TypeError("serialized version > class version")
-        return Param(default=data["value"], description=data["description"], schema=data["schema"])
+ return Param(
+ default=data["value"],
+ description=data["description"],
+ schema=data["schema"],
+ source=data.get("source", None),
+ )
class ParamsDict(MutableMapping[str, Any]):
@@ -253,6 +275,20 @@ class ParamsDict(MutableMapping[str, Any]):
return ParamsDict(data)
+ def _fill_missing_param_source(
+ self,
+ source: Literal["dag", "task"] | None = None,
+ ) -> None:
+ for key in self.__dict:
+ if self.__dict[key].source is None:
+ self.__dict[key].source = source
+
+ @staticmethod
+    def filter_params_by_source(params: ParamsDict, source: Literal["dag", "task"]) -> ParamsDict:
+ return ParamsDict(
+            {key: param for key, param in params.__dict.items() if param.source == source},
+ )
+
class DagParam(ResolveMixin):
"""
diff --git a/task-sdk/tests/task_sdk/bases/test_operator.py b/task-sdk/tests/task_sdk/bases/test_operator.py
index 535ba511d09..5db27774e23 100644
--- a/task-sdk/tests/task_sdk/bases/test_operator.py
+++ b/task-sdk/tests/task_sdk/bases/test_operator.py
@@ -41,6 +41,7 @@ from airflow.sdk.bases.operator import (
)
from airflow.sdk.definitions.dag import DAG
from airflow.sdk.definitions.edges import Label
+from airflow.sdk.definitions.param import ParamsDict
from airflow.sdk.definitions.taskgroup import TaskGroup
from airflow.sdk.definitions.template import literal
from airflow.task.priority_strategy import _DownstreamPriorityWeightStrategy, _UpstreamPriorityWeightStrategy
@@ -785,6 +786,36 @@ class TestBaseOperator:
task.render_template_fields(context={"foo": "whatever", "bar": "whatever"})
assert mock_jinja_env.call_count == 1
+ def test_params_source(self):
+        # Params inherited from the Dag and params set on the task should carry the matching source
+ with DAG(
+ "dag0",
+ params=ParamsDict(
+ {
+ "param from Dag": "value1",
+ "overwritten by task": "value 2",
+ }
+ ),
+ schedule=None,
+ start_date=DEFAULT_DATE,
+ ):
+ op1 = MockOperator(
+ task_id="task1",
+ params=ParamsDict(
+ {
+ "overwritten by task": "value 3",
+ "param from task": "value 4",
+ }
+ ),
+ )
+
+ for key, expected_source in (
+ ("param from Dag", "dag"),
+ ("overwritten by task", "task"),
+ ("param from task", "task"),
+ ):
+ assert op1.params.get_param(key).source == expected_source
+
def test_deepcopy(self):
# Test bug when copying an operator attached to a Dag
with DAG("dag0", schedule=None, start_date=DEFAULT_DATE) as dag:
diff --git a/task-sdk/tests/task_sdk/definitions/test_param.py b/task-sdk/tests/task_sdk/definitions/test_param.py
index aa2aa71c51a..887003d13dc 100644
--- a/task-sdk/tests/task_sdk/definitions/test_param.py
+++ b/task-sdk/tests/task_sdk/definitions/test_param.py
@@ -17,6 +17,7 @@
from __future__ import annotations
from contextlib import nullcontext
+from typing import Literal
import pytest
@@ -207,10 +208,13 @@ class TestParam:
def test_dump(self):
p = Param("hello", description="world", type="string", minLength=2)
dump = p.dump()
- assert dump["__class"] == "airflow.sdk.definitions.param.Param"
- assert dump["value"] == "hello"
- assert dump["description"] == "world"
- assert dump["schema"] == {"type": "string", "minLength": 2}
+ assert dump == {
+ "__class": "airflow.sdk.definitions.param.Param",
+ "value": "hello",
+ "description": "world",
+ "schema": {"type": "string", "minLength": 2},
+ "source": None,
+ }
@pytest.mark.parametrize(
"param",
@@ -307,3 +311,47 @@ class TestParamsDict:
def test_repr(self):
pd = ParamsDict({"key": Param("value", type="string")})
assert repr(pd) == "{'key': 'value'}"
+
+ @pytest.mark.parametrize("source", ("dag", "task"))
+ def test_fill_missing_param_source(self, source: Literal["dag", "task"]):
+ pd = ParamsDict(
+ {
+ "key": Param("value", type="string"),
+ "key2": "value2",
+ }
+ )
+ pd._fill_missing_param_source(source)
+ for param in pd.values():
+ assert param.source == source
+
+ def test_fill_missing_param_source_not_overwrite_existing(self):
+ pd = ParamsDict(
+ {
+ "key": Param("value", type="string", source="dag"),
+ "key2": "value2",
+ "key3": "value3",
+ }
+ )
+ pd._fill_missing_param_source("task")
+ for key, expected_source in (
+ ("key", "dag"),
+ ("key2", "task"),
+ ("key3", "task"),
+ ):
+ assert pd.get_param(key).source == expected_source
+
+ def test_filter_params_by_source(self):
+ pd = ParamsDict(
+ {
+ "key": Param("value", type="string", source="dag"),
+ "key2": Param("value", source="task"),
+ }
+ )
+ assert ParamsDict.filter_params_by_source(pd, "dag") == ParamsDict(
+ {"key": Param("value", type="string", source="dag")},
+ )
+ assert ParamsDict.filter_params_by_source(pd, "task") == ParamsDict(
+ {
+ "key2": Param("value", source="task"),
+ }
+ )