shahar1 commented on code in PR #46584:
URL: https://github.com/apache/airflow/pull/46584#discussion_r1976684764
##########
providers/standard/tests/unit/standard/operators/test_latest_only_operator.py:
##########
@@ -98,81 +101,94 @@ def test_skipping_non_latest(self, dag_maker):
triggered_by_kwargs = {"triggered_by": DagRunTriggeredByType.TEST} if
AIRFLOW_V_3_0_PLUS else {}
- dag_maker.create_dagrun(
+ dr0 = dag_maker.create_dagrun(
run_type=DagRunType.SCHEDULED,
start_date=timezone.utcnow(),
logical_date=DEFAULT_DATE,
state=State.RUNNING,
- data_interval=(DEFAULT_DATE, DEFAULT_DATE),
+ data_interval=DataInterval(DEFAULT_DATE, DEFAULT_DATE),
**triggered_by_kwargs,
)
- dag_maker.create_dagrun(
+ dr1 = dag_maker.create_dagrun(
run_type=DagRunType.SCHEDULED,
start_date=timezone.utcnow(),
logical_date=timezone.datetime(2016, 1, 1, 12),
state=State.RUNNING,
- data_interval=(timezone.datetime(2016, 1, 1, 12),
timezone.datetime(2016, 1, 1, 12) + INTERVAL),
+ data_interval=DataInterval(
+ timezone.datetime(2016, 1, 1, 12), timezone.datetime(2016, 1,
1, 12) + INTERVAL
+ ),
**triggered_by_kwargs,
)
- dag_maker.create_dagrun(
+ dr2 = dag_maker.create_dagrun(
run_type=DagRunType.SCHEDULED,
start_date=timezone.utcnow(),
logical_date=END_DATE,
state=State.RUNNING,
- data_interval=(END_DATE, END_DATE + INTERVAL),
+ data_interval=DataInterval(END_DATE, END_DATE + INTERVAL),
**triggered_by_kwargs,
)
- latest_task.run(start_date=DEFAULT_DATE, end_date=END_DATE)
- downstream_task.run(start_date=DEFAULT_DATE, end_date=END_DATE)
- downstream_task2.run(start_date=DEFAULT_DATE, end_date=END_DATE)
- downstream_task3.run(start_date=DEFAULT_DATE, end_date=END_DATE)
-
- latest_instances = get_task_instances("latest")
if AIRFLOW_V_3_0_PLUS:
- exec_date_to_latest_state = {ti.logical_date: ti.state for ti in
latest_instances}
+ for idx, (dr, exec_date) in enumerate(
+ [
+ (dr0, timezone.datetime(2016, 1, 1)),
+ (dr1, timezone.datetime(2016, 1, 1, 12)),
+ (dr2, timezone.datetime(2016, 1, 2)),
+ ]
+ ):
+ ## FIXME: dr1 and dr2 raise "AttributeError: 'NoneType' object
has no attribute 'queue'"
+ # if idx in [1, 2]:
+ # continue
+
+ latest_ti = dr.get_task_instance(task_id="latest")
+ with pytest.raises(DownstreamTasksSkipped) as exc_info:
+ latest_ti.run()
+ assert exc_info.value.tasks == [("downstream", -1)]
+ ## FIXME: Cannot assert execution date because we run the TI
directly and not via SDK
+ # (Fails due to the DownstreamTasksSkipped signal)
+ # assert latest_ti.execution_date == exec_date
+
+ ## TODO: assert the state of the downstream tasks after
resolving the above issues
Review Comment:
@ashb
tl;dr - how can we run the Task SDK (possibly mocked) in unit tests?
Since the ShortCircuit and Branch operators now rely on signaling
`DownstreamTasksSkipped` (which is eventually handled by the Task SDK), it
is problematic to adjust some existing tests that currently run TIs
directly.
To simulate a simple TI run of such an operator without the Task SDK, I have
to wrap it as follows:
```python
with pytest.raises(DownstreamTasksSkipped) as exc_info:
    skippable_ti.run()
# That's how to assert that specific tasks are skipped:
assert exc_info.value.tasks == [('task1', -1)]
# We need to do this if we schedule downstream tasks as part of the test,
# since it would normally be done by the Task SDK:
skippable_ti.set_state(TaskInstanceState.SUCCESS)
```
For simple operators this works reasonably well, but there are more complex
operators, like this one, where we also need to access the execution date of
each TI.
Is it possible to add a way to run the Task SDK in unit
tests?
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]