ashb commented on code in PR #46176:
URL: https://github.com/apache/airflow/pull/46176#discussion_r1935358299
##########
providers/edge/src/airflow/providers/edge/example_dags/win_test.py:
##########
@@ -37,9 +37,9 @@
from airflow.hooks.base import BaseHook
from airflow.models import BaseOperator
from airflow.models.dag import DAG
-from airflow.models.param import Param
from airflow.models.variable import Variable
from airflow.operators.empty import EmptyOperator
+from airflow.sdk.definitions.param import Param
Review Comment:
```suggestion
from airflow.sdk import Param
```
##########
providers/edge/src/airflow/providers/edge/example_dags/win_notepad.py:
##########
@@ -34,7 +34,7 @@
from airflow.models import BaseOperator
from airflow.models.dag import DAG
-from airflow.models.param import Param
+from airflow.sdk.definitions.param import Param
Review Comment:
```suggestion
from airflow.sdk import Param
```
##########
tests/models/test_param.py:
##########
@@ -22,278 +22,13 @@
from airflow.decorators import task
from airflow.exceptions import ParamValidationError
-from airflow.models.param import Param, ParamsDict
-from airflow.serialization.serialized_objects import BaseSerialization
+from airflow.sdk.definitions.param import Param
from airflow.utils import timezone
from airflow.utils.types import DagRunType
from tests_common.test_utils.db import clear_db_dags, clear_db_runs,
clear_db_xcom
-class TestParam:
- def test_param_without_schema(self):
- p = Param("test")
- assert p.resolve() == "test"
-
- p.value = 10
- assert p.resolve() == 10
-
- def test_null_param(self):
- p = Param()
- with pytest.raises(ParamValidationError, match="No value passed and
Param has no default value"):
- p.resolve()
- assert p.resolve(None) is None
- assert p.dump()["value"] is None
- assert not p.has_value
-
- p = Param(None)
- assert p.resolve() is None
- assert p.resolve(None) is None
- assert p.dump()["value"] is None
- assert not p.has_value
-
- p = Param(None, type="null")
- assert p.resolve() is None
- assert p.resolve(None) is None
- assert p.dump()["value"] is None
- assert not p.has_value
- with pytest.raises(ParamValidationError):
- p.resolve("test")
-
- def test_string_param(self):
- p = Param("test", type="string")
- assert p.resolve() == "test"
-
- p = Param("test")
- assert p.resolve() == "test"
-
- p = Param("10.0.0.0", type="string", format="ipv4")
- assert p.resolve() == "10.0.0.0"
-
- p = Param(type="string")
- with pytest.raises(ParamValidationError):
- p.resolve(None)
- with pytest.raises(ParamValidationError, match="No value passed and
Param has no default value"):
- p.resolve()
-
- @pytest.mark.parametrize(
- "dt",
- [
- pytest.param("2022-01-02T03:04:05.678901Z",
id="microseconds-zed-timezone"),
- pytest.param("2022-01-02T03:04:05.678Z",
id="milliseconds-zed-timezone"),
- pytest.param("2022-01-02T03:04:05+00:00",
id="seconds-00-00-timezone"),
- pytest.param("2022-01-02T03:04:05+04:00",
id="seconds-custom-timezone"),
- ],
- )
- def test_string_rfc3339_datetime_format(self, dt):
- """Test valid rfc3339 datetime."""
- assert Param(dt, type="string", format="date-time").resolve() == dt
-
- @pytest.mark.parametrize(
- "dt",
- [
- pytest.param("2022-01-02", id="date"),
- pytest.param("03:04:05", id="time"),
- pytest.param("Thu, 04 Mar 2021 05:06:07 GMT",
id="rfc2822-datetime"),
- ],
- )
- def test_string_datetime_invalid_format(self, dt):
- """Test invalid iso8601 and rfc3339 datetime format."""
- with pytest.raises(ParamValidationError, match="is not a 'date-time'"):
- Param(dt, type="string", format="date-time").resolve()
-
- def test_string_time_format(self):
- """Test string time format."""
- assert Param("03:04:05", type="string", format="time").resolve() ==
"03:04:05"
-
- error_pattern = "is not a 'time'"
- with pytest.raises(ParamValidationError, match=error_pattern):
- Param("03:04:05.06", type="string", format="time").resolve()
-
- with pytest.raises(ParamValidationError, match=error_pattern):
- Param("03:04", type="string", format="time").resolve()
-
- with pytest.raises(ParamValidationError, match=error_pattern):
- Param("24:00:00", type="string", format="time").resolve()
-
- @pytest.mark.parametrize(
- "date_string",
- [
- "2021-01-01",
- ],
- )
- def test_string_date_format(self, date_string):
- """Test string date format."""
- assert Param(date_string, type="string", format="date").resolve() ==
date_string
-
- # Note that 20120503 behaved differently in 3.11.3 Official python image.
It was validated as a date
- # there but it started to fail again in 3.11.4 released on 2023-07-05.
- @pytest.mark.parametrize(
- "date_string",
- [
- "01/01/2021",
- "21 May 1975",
- "20120503",
- ],
- )
- def test_string_date_format_error(self, date_string):
- """Test string date format failures."""
- with pytest.raises(ParamValidationError, match="is not a 'date'"):
- Param(date_string, type="string", format="date").resolve()
-
- def test_int_param(self):
- p = Param(5)
- assert p.resolve() == 5
-
- p = Param(type="integer", minimum=0, maximum=10)
- assert p.resolve(value=5) == 5
-
- with pytest.raises(ParamValidationError):
- p.resolve(value=20)
-
- def test_number_param(self):
- p = Param(42, type="number")
- assert p.resolve() == 42
-
- p = Param(1.2, type="number")
- assert p.resolve() == 1.2
-
- p = Param("42", type="number")
- with pytest.raises(ParamValidationError):
- p.resolve()
-
- def test_list_param(self):
- p = Param([1, 2], type="array")
- assert p.resolve() == [1, 2]
-
- def test_dict_param(self):
- p = Param({"a": 1, "b": 2}, type="object")
- assert p.resolve() == {"a": 1, "b": 2}
-
- def test_composite_param(self):
- p = Param(type=["string", "number"])
- assert p.resolve(value="abc") == "abc"
- assert p.resolve(value=5.0) == 5.0
-
- def test_param_with_description(self):
- p = Param(10, description="Sample description")
- assert p.description == "Sample description"
-
- def test_suppress_exception(self):
- p = Param("abc", type="string", minLength=2, maxLength=4)
- assert p.resolve() == "abc"
-
- p.value = "long_string"
- assert p.resolve(suppress_exception=True) is None
-
- def test_explicit_schema(self):
- p = Param("abc", schema={type: "string"})
- assert p.resolve() == "abc"
-
- def test_custom_param(self):
- class S3Param(Param):
- def __init__(self, path: str):
- schema = {"type": "string", "pattern": r"s3:\/\/(.+?)\/(.+)"}
- super().__init__(default=path, schema=schema)
-
- p = S3Param("s3://my_bucket/my_path")
- assert p.resolve() == "s3://my_bucket/my_path"
-
- p = S3Param("file://not_valid/s3_path")
- with pytest.raises(ParamValidationError):
- p.resolve()
-
- def test_value_saved(self):
- p = Param("hello", type="string")
- assert p.resolve("world") == "world"
- assert p.resolve() == "world"
-
- def test_dump(self):
- p = Param("hello", description="world", type="string", minLength=2)
- dump = p.dump()
- assert dump["__class"] == "airflow.models.param.Param"
- assert dump["value"] == "hello"
- assert dump["description"] == "world"
- assert dump["schema"] == {"type": "string", "minLength": 2}
-
- @pytest.mark.parametrize(
- "param",
- [
- Param("my value", description="hello", schema={"type": "string"}),
- Param("my value", description="hello"),
- Param(None, description=None),
- Param([True], type="array", items={"type": "boolean"}),
- Param(),
- ],
- )
- def test_param_serialization(self, param: Param):
- """
- Test to make sure that native Param objects can be correctly serialized
- """
-
- serializer = BaseSerialization()
- serialized_param = serializer.serialize(param)
- restored_param: Param = serializer.deserialize(serialized_param)
-
- assert restored_param.value == param.value
- assert isinstance(restored_param, Param)
- assert restored_param.description == param.description
- assert restored_param.schema == param.schema
-
-
-class TestParamsDict:
- def test_params_dict(self):
- # Init with a simple dictionary
- pd = ParamsDict(dict_obj={"key": "value"})
- assert isinstance(pd.get_param("key"), Param)
- assert pd["key"] == "value"
- assert pd.suppress_exception is False
-
- # Init with a dict which contains Param objects
- pd2 = ParamsDict({"key": Param("value", type="string")},
suppress_exception=True)
- assert isinstance(pd2.get_param("key"), Param)
- assert pd2["key"] == "value"
- assert pd2.suppress_exception is True
-
- # Init with another object of another ParamsDict
- pd3 = ParamsDict(pd2)
- assert isinstance(pd3.get_param("key"), Param)
- assert pd3["key"] == "value"
- assert pd3.suppress_exception is False # as it's not a deepcopy of pd2
-
- # Dump the ParamsDict
- assert pd.dump() == {"key": "value"}
- assert pd2.dump() == {"key": "value"}
- assert pd3.dump() == {"key": "value"}
-
- # Validate the ParamsDict
- plain_dict = pd.validate()
- assert isinstance(plain_dict, dict)
- pd2.validate()
- pd3.validate()
-
- # Update the ParamsDict
- with pytest.raises(ParamValidationError, match=r"Invalid input for
param key: 1 is not"):
- pd3["key"] = 1
-
- # Should not raise an error as suppress_exception is True
- pd2["key"] = 1
- pd2.validate()
-
- def test_update(self):
- pd = ParamsDict({"key": Param("value", type="string")})
-
- pd.update({"key": "a"})
- internal_value = pd.get_param("key")
- assert isinstance(internal_value, Param)
- with pytest.raises(ParamValidationError, match=r"Invalid input for
param key: 1 is not"):
- pd.update({"key": 1})
-
- def test_repr(self):
- pd = ParamsDict({"key": Param("value", type="string")})
- assert repr(pd) == "{'key': 'value'}"
-
-
class TestDagParamRuntime:
Review Comment:
I think we might need to move these too?
This is testing the runtime resolution behaviour (i.e. that the task is
called with the value, not the Param object, and that dagrun.conf is used in
preference to the dag params).
##########
airflow/models/param.py:
##########
@@ -14,340 +14,9 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from __future__ import annotations
-
-import contextlib
-import copy
-import json
-import logging
-from collections.abc import ItemsView, Iterable, MutableMapping, ValuesView
-from typing import TYPE_CHECKING, Any, ClassVar
-
-from airflow.exceptions import AirflowException, ParamValidationError
-from airflow.sdk.definitions._internal.mixins import ResolveMixin
-from airflow.utils.types import NOTSET, ArgNotSet
-
-if TYPE_CHECKING:
- from airflow.sdk.definitions.context import Context
- from airflow.sdk.definitions.dag import DAG
- from airflow.sdk.types import Operator
-
-logger = logging.getLogger(__name__)
-
-
-class Param:
- """
- Class to hold the default value of a Param and rule set to do the
validations.
-
- Without the rule set it always validates and returns the default value.
-
- :param default: The value this Param object holds
- :param description: Optional help text for the Param
- :param schema: The validation schema of the Param, if not given then all
kwargs except
- default & description will form the schema
- """
-
- __version__: ClassVar[int] = 1
-
- CLASS_IDENTIFIER = "__class"
-
- def __init__(self, default: Any = NOTSET, description: str | None = None,
**kwargs):
- if default is not NOTSET:
- self._check_json(default)
- self.value = default
- self.description = description
- self.schema = kwargs.pop("schema") if "schema" in kwargs else kwargs
-
- def __copy__(self) -> Param:
- return Param(self.value, self.description, schema=self.schema)
-
- @staticmethod
- def _check_json(value):
- try:
- json.dumps(value)
- except Exception:
- raise ParamValidationError(
- "All provided parameters must be json-serializable. "
- f"The value '{value}' is not serializable."
- )
-
- def resolve(self, value: Any = NOTSET, suppress_exception: bool = False)
-> Any:
- """
- Run the validations and returns the Param's final value.
-
- May raise ValueError on failed validations, or TypeError
- if no value is passed and no value already exists.
- We first check that value is json-serializable; if not, warn.
- In future release we will require the value to be json-serializable.
-
- :param value: The value to be updated for the Param
- :param suppress_exception: To raise an exception or not when the
validations fails.
- If true and validations fails, the return value would be None.
- """
- import jsonschema
- from jsonschema import FormatChecker
- from jsonschema.exceptions import ValidationError
-
- if value is not NOTSET:
- self._check_json(value)
- final_val = self.value if value is NOTSET else value
- if isinstance(final_val, ArgNotSet):
- if suppress_exception:
- return None
- raise ParamValidationError("No value passed and Param has no
default value")
- try:
- jsonschema.validate(final_val, self.schema,
format_checker=FormatChecker())
- except ValidationError as err:
- if suppress_exception:
- return None
- raise ParamValidationError(err) from None
- self.value = final_val
- return final_val
-
- def dump(self) -> dict:
- """Dump the Param as a dictionary."""
- out_dict: dict[str, str | None] = {
- self.CLASS_IDENTIFIER:
f"{self.__module__}.{self.__class__.__name__}"
- }
- out_dict.update(self.__dict__)
- # Ensure that not set is translated to None
- if self.value is NOTSET:
- out_dict["value"] = None
- return out_dict
-
- @property
- def has_value(self) -> bool:
- return self.value is not NOTSET and self.value is not None
-
- def serialize(self) -> dict:
- return {"value": self.value, "description": self.description,
"schema": self.schema}
-
- @staticmethod
- def deserialize(data: dict[str, Any], version: int) -> Param:
- if version > Param.__version__:
- raise TypeError("serialized version > class version")
-
- return Param(default=data["value"], description=data["description"],
schema=data["schema"])
-
-
-class ParamsDict(MutableMapping[str, Any]):
- """
- Class to hold all params for dags or tasks.
-
- All the keys are strictly string and values are converted into Param's
object
- if they are not already. This class is to replace param's dictionary
implicitly
- and ideally not needed to be used directly.
-
-
- :param dict_obj: A dict or dict like object to init ParamsDict
- :param suppress_exception: Flag to suppress value exceptions while
initializing the ParamsDict
- """
-
- __version__: ClassVar[int] = 1
- __slots__ = ["__dict", "suppress_exception"]
-
- def __init__(self, dict_obj: MutableMapping | None = None,
suppress_exception: bool = False):
- params_dict: dict[str, Param] = {}
- dict_obj = dict_obj or {}
- for k, v in dict_obj.items():
- if not isinstance(v, Param):
- params_dict[k] = Param(v)
- else:
- params_dict[k] = v
- self.__dict = params_dict
- self.suppress_exception = suppress_exception
-
- def __bool__(self) -> bool:
- return bool(self.__dict)
-
- def __eq__(self, other: Any) -> bool:
- if isinstance(other, ParamsDict):
- return self.dump() == other.dump()
- if isinstance(other, dict):
- return self.dump() == other
- return NotImplemented
-
- def __copy__(self) -> ParamsDict:
- return ParamsDict(self.__dict, self.suppress_exception)
-
- def __deepcopy__(self, memo: dict[int, Any] | None) -> ParamsDict:
- return ParamsDict(copy.deepcopy(self.__dict, memo),
self.suppress_exception)
-
- def __contains__(self, o: object) -> bool:
- return o in self.__dict
-
- def __len__(self) -> int:
- return len(self.__dict)
- def __delitem__(self, v: str) -> None:
- del self.__dict[v]
+"""Re exporting the new param module from Task SDK for backward
compatibility."""
- def __iter__(self):
- return iter(self.__dict)
-
- def __repr__(self):
- return repr(self.dump())
-
- def __setitem__(self, key: str, value: Any) -> None:
- """
- Override for dictionary's ``setitem`` method to ensure all values are
of Param's type only.
-
- :param key: A key which needs to be inserted or updated in the dict
- :param value: A value which needs to be set against the key. It could
be of any
- type but will be converted and stored as a Param object eventually.
- """
- if isinstance(value, Param):
- param = value
- elif key in self.__dict:
- param = self.__dict[key]
- try:
- param.resolve(value=value,
suppress_exception=self.suppress_exception)
- except ParamValidationError as ve:
- raise ParamValidationError(f"Invalid input for param {key}:
{ve}") from None
- else:
- # if the key isn't there already and if the value isn't of Param
type create a new Param object
- param = Param(value)
-
- self.__dict[key] = param
-
- def __getitem__(self, key: str) -> Any:
- """
- Override for dictionary's ``getitem`` method to call the resolve
method after fetching the key.
-
- :param key: The key to fetch
- """
- param = self.__dict[key]
- return param.resolve(suppress_exception=self.suppress_exception)
-
- def get_param(self, key: str) -> Param:
- """Get the internal :class:`.Param` object for this key."""
- return self.__dict[key]
-
- def items(self):
- return ItemsView(self.__dict)
-
- def values(self):
- return ValuesView(self.__dict)
-
- def update(self, *args, **kwargs) -> None:
- if len(args) == 1 and not kwargs and isinstance(args[0], ParamsDict):
- return super().update(args[0].__dict)
- super().update(*args, **kwargs)
-
- def dump(self) -> dict[str, Any]:
- """Dump the ParamsDict object as a dictionary, while suppressing
exceptions."""
- return {k: v.resolve(suppress_exception=True) for k, v in self.items()}
-
- def validate(self) -> dict[str, Any]:
- """Validate & returns all the Params object stored in the
dictionary."""
- resolved_dict = {}
- try:
- for k, v in self.items():
- resolved_dict[k] =
v.resolve(suppress_exception=self.suppress_exception)
- except ParamValidationError as ve:
- raise ParamValidationError(f"Invalid input for param {k}: {ve}")
from None
-
- return resolved_dict
-
- def serialize(self) -> dict[str, Any]:
- return self.dump()
-
- @staticmethod
- def deserialize(data: dict, version: int) -> ParamsDict:
- if version > ParamsDict.__version__:
- raise TypeError("serialized version > class version")
-
- return ParamsDict(data)
-
-
-class DagParam(ResolveMixin):
- """
- DAG run parameter reference.
-
- This binds a simple Param object to a name within a DAG instance, so that
it
- can be resolved during the runtime via the ``{{ context }}`` dictionary.
The
- ideal use case of this class is to implicitly convert args passed to a
- method decorated by ``@dag``.
-
- It can be used to parameterize a DAG. You can overwrite its value by
setting
- it on conf when you trigger your DagRun.
-
- This can also be used in templates by accessing ``{{ context.params }}``.
-
- **Example**:
-
- with DAG(...) as dag:
- EmailOperator(subject=dag.param('subject', 'Hi from Airflow!'))
-
- :param current_dag: Dag being used for parameter.
- :param name: key value which is used to set the parameter
- :param default: Default value used if no parameter was set.
- """
-
- def __init__(self, current_dag: DAG, name: str, default: Any = NOTSET):
- if default is not NOTSET:
- current_dag.params[name] = default
- self._name = name
- self._default = default
- self.current_dag = current_dag
-
- def iter_references(self) -> Iterable[tuple[Operator, str]]:
- return ()
-
- def resolve(self, context: Context, *, include_xcom: bool = True) -> Any:
- """Pull DagParam value from DagRun context. This method is run during
``op.execute()``."""
- with contextlib.suppress(KeyError):
- if context["dag_run"].conf:
- return context["dag_run"].conf[self._name]
- if self._default is not NOTSET:
- return self._default
- with contextlib.suppress(KeyError):
- return context["params"][self._name]
- raise AirflowException(f"No value could be resolved for parameter
{self._name}")
-
- def serialize(self) -> dict:
- """Serialize the DagParam object into a dictionary."""
- return {
- "dag_id": self.current_dag.dag_id,
- "name": self._name,
- "default": self._default,
- }
-
- @classmethod
- def deserialize(cls, data: dict, dags: dict) -> DagParam:
- """
- Deserializes the dictionary back into a DagParam object.
-
- :param data: The serialized representation of the DagParam.
- :param dags: A dictionary of available DAGs to look up the DAG.
- """
- dag_id = data["dag_id"]
- # Retrieve the current DAG from the provided DAGs dictionary
- current_dag = dags.get(dag_id)
- if not current_dag:
- raise ValueError(f"DAG with id {dag_id} not found.")
-
- return cls(current_dag=current_dag, name=data["name"],
default=data["default"])
-
-
-def process_params(
- dag: DAG,
- task: Operator,
- dagrun_conf: dict[str, Any] | None,
- *,
- suppress_exception: bool,
-) -> dict[str, Any]:
- """Merge, validate params, and convert them into a simple dict."""
- from airflow.configuration import conf
-
- dagrun_conf = dagrun_conf or {}
+from __future__ import annotations
- params = ParamsDict(suppress_exception=suppress_exception)
- with contextlib.suppress(AttributeError):
- params.update(dag.params)
- if task.params:
- params.update(task.params)
- if conf.getboolean("core", "dag_run_conf_overrides_params") and
dagrun_conf:
- logger.debug("Updating task params (%s) with DagRun.conf (%s)",
params, dagrun_conf)
- params.update(dagrun_conf)
- return params.validate()
+from airflow.sdk.definitions.param import * # noqa: F403
Review Comment:
Do we need `*`, or could we do
```suggestion
from airflow.sdk.definitions.param import Param, ParamsDict
__all__ = ["Param", "ParamsDict"]
```
##########
providers/edge/src/airflow/providers/edge/example_dags/win_notepad.py:
##########
@@ -34,7 +34,7 @@
from airflow.models import BaseOperator
from airflow.models.dag import DAG
-from airflow.models.param import Param
+from airflow.sdk.definitions.param import Param
Review Comment:
Just for consistency
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]