amoghrajesh commented on code in PR #46032: URL: https://github.com/apache/airflow/pull/46032#discussion_r1930373706
########## task_sdk/src/airflow/sdk/definitions/_internal/expandinput.py: ########## @@ -0,0 +1,271 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import functools +import operator +from collections.abc import Iterable, Mapping, Sequence, Sized +from typing import TYPE_CHECKING, Any, ClassVar, Union + +import attrs + +from airflow.sdk.definitions._internal.mixins import ResolveMixin + +if TYPE_CHECKING: + from airflow.sdk.definitions.xcom_arg import XComArg + from airflow.sdk.types import Operator + from airflow.typing_compat import TypeGuard + +ExpandInput = Union["DictOfListsExpandInput", "ListOfDictsExpandInput"] + +# Each keyword argument to expand() can be an XComArg, sequence, or dict (not +# any mapping since we need the value to be ordered). +OperatorExpandArgument = Union["MappedArgument", "XComArg", Sequence, dict[str, Any]] + +# The single argument of expand_kwargs() can be an XComArg, or a list with each +# element being either an XComArg or a dict. +OperatorExpandKwargsArgument = Union["XComArg", Sequence[Union["XComArg", Mapping[str, Any]]]] + + +class NotFullyPopulated(RuntimeError): + """ + Raise when ``get_map_lengths`` cannot populate all mapping metadata. 
+ + This is generally due to not all upstream tasks have finished when the + function is called. + """ + + def __init__(self, missing: set[str]) -> None: + self.missing = missing + + def __str__(self) -> str: + keys = ", ".join(repr(k) for k in sorted(self.missing)) + return f"Failed to populate all mapping metadata; missing: {keys}" + + +# To replace tedious isinstance() checks. +def is_mappable(v: Any) -> TypeGuard[OperatorExpandArgument]: + from airflow.sdk.definitions.xcom_arg import XComArg + + return isinstance(v, (MappedArgument, XComArg, Mapping, Sequence)) and not isinstance(v, str) + + +# To replace tedious isinstance() checks. +def _is_parse_time_mappable(v: OperatorExpandArgument) -> TypeGuard[Mapping | Sequence]: + from airflow.models.xcom_arg import XComArg + + return not isinstance(v, (MappedArgument, XComArg)) + + +# To replace tedious isinstance() checks. +def _needs_run_time_resolution(v: OperatorExpandArgument) -> TypeGuard[MappedArgument | XComArg]: + from airflow.models.xcom_arg import XComArg + + return isinstance(v, (MappedArgument, XComArg)) + + [email protected](kw_only=True) +class MappedArgument(ResolveMixin): + """ + Stand-in stub for task-group-mapping arguments. + + This is very similar to an XComArg, but resolved differently. Declared here + (instead of in the task group module) to avoid import cycles. + """ + + _input: ExpandInput + _key: str + + def iter_references(self) -> Iterable[tuple[Operator, str]]: + yield from self._input.iter_references() + + def resolve(self, context: Mapping[str, Any], *, include_xcom: bool = True) -> Any: + data, _ = self._input.resolve(context, include_xcom=include_xcom) + return data[self._key] + + [email protected]() +class DictOfListsExpandInput(ResolveMixin): + """ + Storage type of a mapped operator's mapped kwargs. + + This is created from ``expand(**kwargs)``. 
+ """ + + value: dict[str, OperatorExpandArgument] + + EXPAND_INPUT_TYPE: ClassVar[str] = "dict-of-lists" + + def _iter_parse_time_resolved_kwargs(self) -> Iterable[tuple[str, Sized]]: + """Generate kwargs with values available on parse-time.""" + return ((k, v) for k, v in self.value.items() if _is_parse_time_mappable(v)) + + def get_parse_time_mapped_ti_count(self) -> int: + if not self.value: + return 0 + literal_values = [len(v) for _, v in self._iter_parse_time_resolved_kwargs()] + if len(literal_values) != len(self.value): + literal_keys = (k for k, _ in self._iter_parse_time_resolved_kwargs()) + raise NotFullyPopulated(set(self.value).difference(literal_keys)) + return functools.reduce(operator.mul, literal_values, 1) + + def _get_map_lengths(self, resolved_val: Sized, upstream_map_indexes: dict[str, int]) -> dict[str, int]: + """ + Return dict of argument name to map length. + + If any arguments are not known right now (upstream task not finished), + they will not be present in the dict. + """ + + # TODO: This initiates one API call for each XComArg. Would it be + # more efficient to do one single call and unpack the value here? Review Comment: Yeah, I think so. ########## tests/models/test_dagrun.py: ########## @@ -1638,8 +1625,8 @@ def double(value): @dag.task def consumer(value): - nonlocal result - result = list(value) + ... + # result = list(value) Review Comment: Intentional? ########## task_sdk/src/airflow/sdk/definitions/_internal/expandinput.py: ########## @@ -0,0 +1,271 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License.
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import functools +import operator +from collections.abc import Iterable, Mapping, Sequence, Sized +from typing import TYPE_CHECKING, Any, ClassVar, Union + +import attrs + +from airflow.sdk.definitions._internal.mixins import ResolveMixin + +if TYPE_CHECKING: + from airflow.sdk.definitions.xcom_arg import XComArg + from airflow.sdk.types import Operator + from airflow.typing_compat import TypeGuard + +ExpandInput = Union["DictOfListsExpandInput", "ListOfDictsExpandInput"] + +# Each keyword argument to expand() can be an XComArg, sequence, or dict (not +# any mapping since we need the value to be ordered). +OperatorExpandArgument = Union["MappedArgument", "XComArg", Sequence, dict[str, Any]] + +# The single argument of expand_kwargs() can be an XComArg, or a list with each +# element being either an XComArg or a dict. +OperatorExpandKwargsArgument = Union["XComArg", Sequence[Union["XComArg", Mapping[str, Any]]]] + + +class NotFullyPopulated(RuntimeError): + """ + Raise when ``get_map_lengths`` cannot populate all mapping metadata. + + This is generally due to not all upstream tasks have finished when the + function is called. + """ + + def __init__(self, missing: set[str]) -> None: + self.missing = missing + + def __str__(self) -> str: + keys = ", ".join(repr(k) for k in sorted(self.missing)) + return f"Failed to populate all mapping metadata; missing: {keys}" + + +# To replace tedious isinstance() checks. 
+def is_mappable(v: Any) -> TypeGuard[OperatorExpandArgument]: + from airflow.sdk.definitions.xcom_arg import XComArg + + return isinstance(v, (MappedArgument, XComArg, Mapping, Sequence)) and not isinstance(v, str) Review Comment: ```suggestion def is_mappable(v: Any) -> TypeGuard[OperatorExpandArgument]: if isinstance(v, str): return False return isinstance(v, (MappedArgument, XComArg, Mapping, Sequence)) ``` Better? ########## task_sdk/src/airflow/sdk/definitions/_internal/expandinput.py: ########## @@ -0,0 +1,271 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import functools +import operator +from collections.abc import Iterable, Mapping, Sequence, Sized +from typing import TYPE_CHECKING, Any, ClassVar, Union + +import attrs + +from airflow.sdk.definitions._internal.mixins import ResolveMixin + +if TYPE_CHECKING: + from airflow.sdk.definitions.xcom_arg import XComArg + from airflow.sdk.types import Operator + from airflow.typing_compat import TypeGuard + +ExpandInput = Union["DictOfListsExpandInput", "ListOfDictsExpandInput"] + +# Each keyword argument to expand() can be an XComArg, sequence, or dict (not +# any mapping since we need the value to be ordered). 
+OperatorExpandArgument = Union["MappedArgument", "XComArg", Sequence, dict[str, Any]] + +# The single argument of expand_kwargs() can be an XComArg, or a list with each +# element being either an XComArg or a dict. +OperatorExpandKwargsArgument = Union["XComArg", Sequence[Union["XComArg", Mapping[str, Any]]]] + + +class NotFullyPopulated(RuntimeError): + """ + Raise when ``get_map_lengths`` cannot populate all mapping metadata. + + This is generally due to not all upstream tasks have finished when the + function is called. + """ + + def __init__(self, missing: set[str]) -> None: + self.missing = missing + + def __str__(self) -> str: + keys = ", ".join(repr(k) for k in sorted(self.missing)) + return f"Failed to populate all mapping metadata; missing: {keys}" + + +# To replace tedious isinstance() checks. +def is_mappable(v: Any) -> TypeGuard[OperatorExpandArgument]: + from airflow.sdk.definitions.xcom_arg import XComArg + + return isinstance(v, (MappedArgument, XComArg, Mapping, Sequence)) and not isinstance(v, str) + + +# To replace tedious isinstance() checks. +def _is_parse_time_mappable(v: OperatorExpandArgument) -> TypeGuard[Mapping | Sequence]: + from airflow.models.xcom_arg import XComArg + + return not isinstance(v, (MappedArgument, XComArg)) + + +# To replace tedious isinstance() checks. +def _needs_run_time_resolution(v: OperatorExpandArgument) -> TypeGuard[MappedArgument | XComArg]: + from airflow.models.xcom_arg import XComArg + + return isinstance(v, (MappedArgument, XComArg)) + + [email protected](kw_only=True) +class MappedArgument(ResolveMixin): + """ + Stand-in stub for task-group-mapping arguments. + + This is very similar to an XComArg, but resolved differently. Declared here + (instead of in the task group module) to avoid import cycles. 
+ """ + + _input: ExpandInput + _key: str + + def iter_references(self) -> Iterable[tuple[Operator, str]]: + yield from self._input.iter_references() + + def resolve(self, context: Mapping[str, Any], *, include_xcom: bool = True) -> Any: + data, _ = self._input.resolve(context, include_xcom=include_xcom) + return data[self._key] + + +@attrs.define() +class DictOfListsExpandInput(ResolveMixin): Review Comment: Wow, this name is hard to understand, but it's historical, so I will leave the decision on whether to keep it to you. ########## task_sdk/src/airflow/sdk/execution_time/task_runner.py: ########## @@ -147,6 +147,11 @@ def get_template_context(self) -> Context: } context.update(context_from_server) + if from_server.upstream_map_indexes is not None: + # We stash this in here for later use, but we purposefully don't want to document it's + # existence. Should this be a private attribute on RuntimeTI instead perhaps? + context["_upstream_map_indexes"] = from_server.upstream_map_indexes # type: ignore [typeddict-unknown-key] Review Comment: Umm yeah, private and no documentation should do. ########## task_sdk/src/airflow/sdk/execution_time/context.py: ########## @@ -315,3 +316,22 @@ def set_current_context(context: Context) -> Generator[Context, None, None]: expected=context, got=expected_state, ) + + +def context_update_for_unmapped(context: Context, task: BaseOperator) -> None: + """ + Update context after task unmapping. + + Since ``get_template_context()`` is called before unmapping, the context + contains information about the mapped task. We need to do some in-place + updates to ensure the template context reflects the unmapped task instead. + + :meta private: + """ + # TODO: Task-SDK this need to live in sdk too + from airflow.models.param import process_params Review Comment: Yeah we have a ticket for this.
Can be tracked by https://github.com/apache/airflow/issues/44361 ########## airflow/models/xcom_arg.py: ########## @@ -114,3 +178,17 @@ def _(xcom_arg: ConcatXComArg, run_id: str, *, session: Session): if len(ready_lengths) != len(xcom_arg.args): return None # If any of the referenced XComs is not ready, we are not ready either. return sum(ready_lengths) + + +def deserialize_xcom_arg(data: dict[str, Any], dag: SchedulerDAG): + """DAG serialization interface.""" + klass = _XCOM_ARG_TYPES[data.get("type", "")] + return klass._deserialize(data, dag) + + +_XCOM_ARG_TYPES: dict[str, type[SchedulerXComArg]] = { + "": SchedulerPlainXComArg, + "concat": SchedulerConcatXComArg, + "map": SchedulerMapXComArg, + "zip": SchedulerZipXComArg, +} Review Comment: Nice! -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
