jedcunningham commented on code in PR #46032:
URL: https://github.com/apache/airflow/pull/46032#discussion_r1943685513
##########
airflow/models/baseoperator.py:
##########
@@ -848,11 +848,20 @@ def _(cls, task: TaskSDKAbstractOperator, run_id: str, *,
session: Session) -> i
@get_mapped_ti_count.register(MappedOperator)
@classmethod
def _(cls, task: MappedOperator, run_id: str, *, session: Session) ->
int:
- from airflow.serialization.serialized_objects import
_ExpandInputRef
+ from airflow.serialization.serialized_objects import
BaseSerialization, _ExpandInputRef
exp_input = task._get_specified_expand_input()
if isinstance(exp_input, _ExpandInputRef):
exp_input = exp_input.deref(task.dag)
+ # TODO: TaskSDK This is only needed to support `dag.test()` etc
until we port it over ot use the
Review Comment:
```suggestion
# TODO: TaskSDK This is only needed to support `dag.test()` etc
until we port it over to use the
```
##########
airflow/models/baseoperator.py:
##########
@@ -878,18 +887,24 @@ def _(cls, group: TaskGroup, run_id: str, *, session:
Session) -> int:
:raise NotFullyPopulated: If upstream tasks are not all complete
yet.
:return: Total number of mapped TIs this task should have.
"""
+ from airflow.serialization.serialized_objects import
BaseSerialization, _ExpandInputRef
- def iter_mapped_task_groups(group) -> Iterator[MappedTaskGroup]:
+ def iter_mapped_task_group_lengths(group) -> Iterator[int]:
while group is not None:
if isinstance(group, MappedTaskGroup):
- yield group
+ exp_input = group._expand_input
+ # TODO: TaskSDK This is only needed to support
`dag.test()` etc until we port it over ot use the
Review Comment:
```suggestion
# TODO: TaskSDK This is only needed to support
`dag.test()` etc until we port it over to use the
```
##########
task_sdk/src/airflow/sdk/execution_time/task_runner.py:
##########
@@ -204,7 +219,7 @@ def xcom_pull(
key: str = "return_value", # TODO: Make this a constant
(``XCOM_RETURN_KEY``)
include_prior_dates: bool = False, # TODO: Add support for this
*,
- map_indexes: int | Iterable[int] | None = None,
+ map_index: int | None | ArgNotSet = NOTSET,
Review Comment:
Does this impact users?
##########
task_sdk/tests/conftest.py:
##########
@@ -236,3 +239,135 @@ def mock_supervisor_comms():
"airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", create=True
) as supervisor_comms:
yield supervisor_comms
+
+
+@pytest.fixture
+def mocked_parse(spy_agency):
+ """
+ Fixture to set up an inline DAG and use it in a stubbed `parse` function.
Use this fixture if you
+ want to isolate and test `parse` or `run` logic without having to define a
DAG file.
+
+ This fixture returns a helper function `set_dag` that:
+ 1. Creates an inline DAG with the given `dag_id` and `task` (limited to
one task)
+ 2. Constructs a `RuntimeTaskInstance` based on the provided
`StartupDetails` and task.
+ 3. Stubs the `parse` function using `spy_agency`, to return the mocked
`RuntimeTaskInstance`.
+
+ After adding the fixture in your test function signature, you can use it
like this ::
+
+ mocked_parse(
+ StartupDetails(
+ ti=TaskInstance(id=uuid7(), task_id="hello",
dag_id="super_basic_run", run_id="c", try_number=1),
+ file="",
+ requests_fd=0,
+ ),
+ "example_dag_id",
+ CustomOperator(task_id="hello"),
+ )
+ """
+
+ def set_dag(what: StartupDetails, dag_id: str, task: BaseOperator) ->
RuntimeTaskInstance:
+ from airflow.sdk.definitions.dag import DAG
+ from airflow.sdk.execution_time.task_runner import
RuntimeTaskInstance, parse
+ from airflow.utils import timezone
+
+ if not task.has_dag():
+ dag = DAG(dag_id=dag_id, start_date=timezone.datetime(2024, 12, 3))
+ task.dag = dag
+ task = dag.task_dict[task.task_id]
+ else:
+ dag = task.dag
+ if what.ti_context.dag_run.conf:
+ dag.params = what.ti_context.dag_run.conf # type:
ignore[assignment]
+ ti = RuntimeTaskInstance.model_construct(
+ **what.ti.model_dump(exclude_unset=True),
+ task=task,
+ _ti_context_from_server=what.ti_context,
+ max_tries=what.ti_context.max_tries,
+ )
+ if hasattr(parse, "spy"):
+ spy_agency.unspy(parse)
+ spy_agency.spy_on(parse, call_fake=lambda _: ti)
+ return ti
+
+ return set_dag
+
+
+@pytest.fixture
+def create_runtime_ti(mocked_parse, make_ti_context):
+ """
+ Fixture to create a Runtime TaskInstance for testing purposes without
defining a dag file.
+
+ It mimics the behavior of the `parse` function by creating a
`RuntimeTaskInstance` based on the provided
+ `StartupDetails` (formed from arguments) and task. This allows you to test
the logic of a task without
+ having to define a DAG file, parse it, get context from the server, etc.
+
+ Example usage: ::
+
+ def test_custom_task_instance(create_runtime_ti):
+ class MyTaskOperator(BaseOperator):
+ def execute(self, context):
+ assert context["dag_run"].run_id == "test_run"
+
+ task = MyTaskOperator(task_id="test_task")
+ ti = create_runtime_ti(task,
context_from_server=make_ti_context(run_id="test_run"))
+ # Further test logic...
+ """
+ from uuid6 import uuid7
+
+ from airflow.sdk.api.datamodels._generated import TaskInstance
+ from airflow.sdk.execution_time.comms import BundleInfo, StartupDetails
+
+ def _create_task_instance(
+ task: BaseOperator,
+ dag_id: str = "test_dag",
+ run_id: str = "test_run",
+ logical_date: str | datetime = "2024-12-01T01:00:00Z",
+ data_interval_start: str | datetime = "2024-12-01T00:00:00Z",
+ data_interval_end: str | datetime = "2024-12-01T01:00:00Z",
+ start_date: str | datetime = "2024-12-01T01:00:00Z",
+ run_type: str = "manual",
+ try_number: int = 1,
+ map_index: int | None = None,
Review Comment:
Should this default to -1? Seems more realistic.
##########
task_sdk/src/airflow/sdk/api/client.py:
##########
@@ -253,14 +253,25 @@ class XComOperations:
def __init__(self, client: Client):
self.client = client
+ def head(self, dag_id: str, run_id: str, task_id: str, key: str) -> int:
+ """Get the number of mapped XCom values."""
+ resp = self.client.head(f"xcoms/{dag_id}/{run_id}/{task_id}/{key}")
+
+ # content_range: str | None
+ if not (content_range := resp.headers["Content-Range"]) or not
content_range.startswith(
+ "map_indexes "
+ ):
+ raise RuntimeError(f"Unable to parse Content-Range header from
HEAD {resp.request.url}")
+ return int(content_range[len("map_indexes ") :], base=10)
Review Comment:
`base=10` is the default, do we need it here?
##########
task_sdk/src/airflow/sdk/definitions/xcom_arg.py:
##########
@@ -354,17 +332,16 @@ def concat(self, *others: XComArg) -> ConcatXComArg:
raise ValueError("cannot concatenate non-return XCom")
return super().concat(*others)
- # TODO: Task-SDK: Remove session argument once everything is ported over
to Task SDK
- def resolve(
- self, context: Mapping[str, Any], session: Session | None = None, *,
include_xcom: bool = True
- ) -> Any:
+ def resolve(self, context: Mapping[str, Any]) -> Any:
ti = context["ti"]
task_id = self.operator.task_id
- map_indexes = context.get("_upstream_map_indexes", {}).get(task_id)
+
+ if self.operator.is_mapped:
+ return LazyXComSequence[Any](xcom_arg=self, ti=ti)
result = ti.xcom_pull(
task_ids=task_id,
- map_indexes=map_indexes,
+ # map_indexes=map_indexes,
Review Comment:
?
##########
providers/src/airflow/providers/microsoft/azure/operators/msgraph.py:
##########
@@ -241,7 +241,6 @@ def pull_xcom(self, context: Context) -> list:
key=self.key,
task_ids=self.task_id,
dag_id=self.dag_id,
- map_indexes=map_index,
Review Comment:
What about backcompat?
##########
providers/standard/tests/provider_tests/standard/operators/test_python.py:
##########
@@ -110,8 +109,12 @@ def base_tests_setup(self, request,
create_serialized_task_instance_of_operator,
self.run_id = f"run_{slugify(request.node.name, max_length=40)}"
self.ds_templated = self.default_date.date().isoformat()
self.ti_maker = create_serialized_task_instance_of_operator
+
self.dag_maker = dag_maker
self.dag_non_serialized = self.dag_maker(self.dag_id,
template_searchpath=TEMPLATE_SEARCHPATH).dag
+ # We need to entre the context in order to the factory to create things
Review Comment:
```suggestion
# We need to enter the context in order for the factory to create
things
```
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]