This is an automated email from the ASF dual-hosted git repository.
husseinawala pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new f5a83bc90b Fix BaseOperator get_task_instances query (#33054)
f5a83bc90b is described below
commit f5a83bc90b237228db1434662f9dba5ebb719d47
Author: Hussein Awala <[email protected]>
AuthorDate: Thu Aug 3 23:21:02 2023 +0200
Fix BaseOperator get_task_instances query (#33054)
* Fix BaseOperator get_task_instances query
* add unit test
---
airflow/models/baseoperator.py | 13 +++++++------
tests/models/test_baseoperator.py | 38 +++++++++++++++++++++++++++++++++++++-
2 files changed, 44 insertions(+), 7 deletions(-)
diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py
index 99911f322e..7e861e20a6 100644
--- a/airflow/models/baseoperator.py
+++ b/airflow/models/baseoperator.py
@@ -1285,16 +1285,17 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta):
"""Get task instances related to this task for a specific date range."""
from airflow.models import DagRun
- end_date = end_date or timezone.utcnow()
- return session.scalars(
+ query = (
select(TaskInstance)
.join(TaskInstance.dag_run)
.where(TaskInstance.dag_id == self.dag_id)
.where(TaskInstance.task_id == self.task_id)
- .where(DagRun.execution_date >= start_date)
- .where(DagRun.execution_date <= end_date)
- .order_by(DagRun.execution_date)
- ).all()
+ )
+ if start_date:
+ query = query.where(DagRun.execution_date >= start_date)
+ if end_date:
+ query = query.where(DagRun.execution_date <= end_date)
+ return session.scalars(query.order_by(DagRun.execution_date)).all()
@provide_session
def run(
diff --git a/tests/models/test_baseoperator.py b/tests/models/test_baseoperator.py
index dbb842c314..b5735c650c 100644
--- a/tests/models/test_baseoperator.py
+++ b/tests/models/test_baseoperator.py
@@ -31,7 +31,7 @@ import pytest
from airflow.decorators import task as task_decorator
from airflow.exceptions import AirflowException, FailStopDagInvalidTriggerRule, RemovedInAirflow3Warning
from airflow.lineage.entities import File
-from airflow.models import DAG
+from airflow.models import DAG, DagRun, TaskInstance
from airflow.models.baseoperator import (
BaseOperator,
BaseOperatorMeta,
@@ -43,6 +43,7 @@ from airflow.utils.context import Context
from airflow.utils.edgemodifier import Label
from airflow.utils.task_group import TaskGroup
from airflow.utils.trigger_rule import TriggerRule
+from airflow.utils.types import DagRunType
from airflow.utils.weight_rule import WeightRule
from tests.models import DEFAULT_DATE
from tests.test_utils.config import conf_vars
@@ -1023,3 +1024,38 @@ def test_teardown_and_fail_stop(dag_maker):
"tg_2.my_work": "skipped",
}
assert states == expected
+
+
+def test_get_task_instances(session):
+ import pendulum
+
+ first_execution_date = pendulum.datetime(2023, 1, 1)
+ second_execution_date = pendulum.datetime(2023, 1, 2)
+ third_execution_date = pendulum.datetime(2023, 1, 3)
+
+ test_dag = DAG(dag_id="test_dag", start_date=first_execution_date)
+ task = BaseOperator(task_id="test_task", dag=test_dag)
+
+ common_dr_kwargs = {
+ "dag_id": test_dag.dag_id,
+ "run_type": DagRunType.MANUAL,
+ }
+ dr1 = DagRun(execution_date=first_execution_date, run_id="test_run_id_1", **common_dr_kwargs)
+ ti_1 = TaskInstance(run_id=dr1.run_id, task=task, execution_date=first_execution_date)
+ dr2 = DagRun(execution_date=second_execution_date, run_id="test_run_id_2", **common_dr_kwargs)
+ ti_2 = TaskInstance(run_id=dr2.run_id, task=task, execution_date=second_execution_date)
+ dr3 = DagRun(execution_date=third_execution_date, run_id="test_run_id_3", **common_dr_kwargs)
+ ti_3 = TaskInstance(run_id=dr3.run_id, task=task, execution_date=third_execution_date)
+ session.add_all([dr1, dr2, dr3, ti_1, ti_2, ti_3])
+ session.commit()
+
+ # get all task instances
+ assert task.get_task_instances(session=session) == [ti_1, ti_2, ti_3]
+ # get task instances with start_date
+ assert task.get_task_instances(session=session, start_date=second_execution_date) == [ti_2, ti_3]
+ # get task instances with end_date
+ assert task.get_task_instances(session=session, end_date=second_execution_date) == [ti_1, ti_2]
+ # get task instances with start_date and end_date
+ assert task.get_task_instances(
+ session=session, start_date=second_execution_date, end_date=second_execution_date
+ ) == [ti_2]