This is an automated email from the ASF dual-hosted git repository. ephraimanierobi pushed a commit to branch v2-7-test in repository https://gitbox.apache.org/repos/asf/airflow.git
commit 7ecc84125dcf317e106cd2f3d1ed0d44b77b04d7
Author: Hussein Awala <[email protected]>
AuthorDate: Thu Aug 3 23:21:02 2023 +0200

    Fix BaseOperator get_task_instances query (#33054)

    * Fix BaseOperator get_task_instances query

    * add unit test

    (cherry picked from commit f5a83bc90b237228db1434662f9dba5ebb719d47)
---
 airflow/models/baseoperator.py    | 13 +++++++------
 tests/models/test_baseoperator.py | 38 +++++++++++++++++++++++++++++++++++++-
 2 files changed, 44 insertions(+), 7 deletions(-)

diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py
index 99911f322e..7e861e20a6 100644
--- a/airflow/models/baseoperator.py
+++ b/airflow/models/baseoperator.py
@@ -1285,16 +1285,17 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta):
         """Get task instances related to this task for a specific date range."""
         from airflow.models import DagRun
 
-        end_date = end_date or timezone.utcnow()
-        return session.scalars(
+        query = (
             select(TaskInstance)
             .join(TaskInstance.dag_run)
             .where(TaskInstance.dag_id == self.dag_id)
             .where(TaskInstance.task_id == self.task_id)
-            .where(DagRun.execution_date >= start_date)
-            .where(DagRun.execution_date <= end_date)
-            .order_by(DagRun.execution_date)
-        ).all()
+        )
+        if start_date:
+            query = query.where(DagRun.execution_date >= start_date)
+        if end_date:
+            query = query.where(DagRun.execution_date <= end_date)
+        return session.scalars(query.order_by(DagRun.execution_date)).all()
 
     @provide_session
     def run(
diff --git a/tests/models/test_baseoperator.py b/tests/models/test_baseoperator.py
index dbb842c314..b5735c650c 100644
--- a/tests/models/test_baseoperator.py
+++ b/tests/models/test_baseoperator.py
@@ -31,7 +31,7 @@ import pytest
 from airflow.decorators import task as task_decorator
 from airflow.exceptions import AirflowException, FailStopDagInvalidTriggerRule, RemovedInAirflow3Warning
 from airflow.lineage.entities import File
-from airflow.models import DAG
+from airflow.models import DAG, DagRun, TaskInstance
 from airflow.models.baseoperator import (
     BaseOperator,
     BaseOperatorMeta,
@@ -43,6 +43,7 @@ from airflow.utils.context import Context
 from airflow.utils.edgemodifier import Label
 from airflow.utils.task_group import TaskGroup
 from airflow.utils.trigger_rule import TriggerRule
+from airflow.utils.types import DagRunType
 from airflow.utils.weight_rule import WeightRule
 from tests.models import DEFAULT_DATE
 from tests.test_utils.config import conf_vars
@@ -1023,3 +1024,38 @@ def test_teardown_and_fail_stop(dag_maker):
         "tg_2.my_work": "skipped",
     }
     assert states == expected
+
+
+def test_get_task_instances(session):
+    import pendulum
+
+    first_execution_date = pendulum.datetime(2023, 1, 1)
+    second_execution_date = pendulum.datetime(2023, 1, 2)
+    third_execution_date = pendulum.datetime(2023, 1, 3)
+
+    test_dag = DAG(dag_id="test_dag", start_date=first_execution_date)
+    task = BaseOperator(task_id="test_task", dag=test_dag)
+
+    common_dr_kwargs = {
+        "dag_id": test_dag.dag_id,
+        "run_type": DagRunType.MANUAL,
+    }
+    dr1 = DagRun(execution_date=first_execution_date, run_id="test_run_id_1", **common_dr_kwargs)
+    ti_1 = TaskInstance(run_id=dr1.run_id, task=task, execution_date=first_execution_date)
+    dr2 = DagRun(execution_date=second_execution_date, run_id="test_run_id_2", **common_dr_kwargs)
+    ti_2 = TaskInstance(run_id=dr2.run_id, task=task, execution_date=second_execution_date)
+    dr3 = DagRun(execution_date=third_execution_date, run_id="test_run_id_3", **common_dr_kwargs)
+    ti_3 = TaskInstance(run_id=dr3.run_id, task=task, execution_date=third_execution_date)
+    session.add_all([dr1, dr2, dr3, ti_1, ti_2, ti_3])
+    session.commit()
+
+    # get all task instances
+    assert task.get_task_instances(session=session) == [ti_1, ti_2, ti_3]
+    # get task instances with start_date
+    assert task.get_task_instances(session=session, start_date=second_execution_date) == [ti_2, ti_3]
+    # get task instances with end_date
+    assert task.get_task_instances(session=session, end_date=second_execution_date) == [ti_1, ti_2]
+    # get task instances with start_date and end_date
+    assert task.get_task_instances(
+        session=session, start_date=second_execution_date, end_date=second_execution_date
+    ) == [ti_2]
