This is an automated email from the ASF dual-hosted git repository.

ephraimanierobi pushed a commit to branch v2-7-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 7ecc84125dcf317e106cd2f3d1ed0d44b77b04d7
Author: Hussein Awala <[email protected]>
AuthorDate: Thu Aug 3 23:21:02 2023 +0200

    Fix BaseOperator get_task_instances query (#33054)
    
    * Fix BaseOperator get_task_instances query
    
    * add unit test
    
    (cherry picked from commit f5a83bc90b237228db1434662f9dba5ebb719d47)
---
 airflow/models/baseoperator.py    | 13 +++++++------
 tests/models/test_baseoperator.py | 38 +++++++++++++++++++++++++++++++++++++-
 2 files changed, 44 insertions(+), 7 deletions(-)

diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py
index 99911f322e..7e861e20a6 100644
--- a/airflow/models/baseoperator.py
+++ b/airflow/models/baseoperator.py
@@ -1285,16 +1285,17 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta):
         """Get task instances related to this task for a specific date range."""
         from airflow.models import DagRun
 
-        end_date = end_date or timezone.utcnow()
-        return session.scalars(
+        query = (
             select(TaskInstance)
             .join(TaskInstance.dag_run)
             .where(TaskInstance.dag_id == self.dag_id)
             .where(TaskInstance.task_id == self.task_id)
-            .where(DagRun.execution_date >= start_date)
-            .where(DagRun.execution_date <= end_date)
-            .order_by(DagRun.execution_date)
-        ).all()
+        )
+        if start_date:
+            query = query.where(DagRun.execution_date >= start_date)
+        if end_date:
+            query = query.where(DagRun.execution_date <= end_date)
+        return session.scalars(query.order_by(DagRun.execution_date)).all()
 
     @provide_session
     def run(
diff --git a/tests/models/test_baseoperator.py b/tests/models/test_baseoperator.py
index dbb842c314..b5735c650c 100644
--- a/tests/models/test_baseoperator.py
+++ b/tests/models/test_baseoperator.py
@@ -31,7 +31,7 @@ import pytest
 from airflow.decorators import task as task_decorator
 from airflow.exceptions import AirflowException, FailStopDagInvalidTriggerRule, RemovedInAirflow3Warning
 from airflow.lineage.entities import File
-from airflow.models import DAG
+from airflow.models import DAG, DagRun, TaskInstance
 from airflow.models.baseoperator import (
     BaseOperator,
     BaseOperatorMeta,
@@ -43,6 +43,7 @@ from airflow.utils.context import Context
 from airflow.utils.edgemodifier import Label
 from airflow.utils.task_group import TaskGroup
 from airflow.utils.trigger_rule import TriggerRule
+from airflow.utils.types import DagRunType
 from airflow.utils.weight_rule import WeightRule
 from tests.models import DEFAULT_DATE
 from tests.test_utils.config import conf_vars
@@ -1023,3 +1024,38 @@ def test_teardown_and_fail_stop(dag_maker):
         "tg_2.my_work": "skipped",
     }
     assert states == expected
+
+
+def test_get_task_instances(session):
+    import pendulum
+
+    first_execution_date = pendulum.datetime(2023, 1, 1)
+    second_execution_date = pendulum.datetime(2023, 1, 2)
+    third_execution_date = pendulum.datetime(2023, 1, 3)
+
+    test_dag = DAG(dag_id="test_dag", start_date=first_execution_date)
+    task = BaseOperator(task_id="test_task", dag=test_dag)
+
+    common_dr_kwargs = {
+        "dag_id": test_dag.dag_id,
+        "run_type": DagRunType.MANUAL,
+    }
+    dr1 = DagRun(execution_date=first_execution_date, run_id="test_run_id_1", **common_dr_kwargs)
+    ti_1 = TaskInstance(run_id=dr1.run_id, task=task, execution_date=first_execution_date)
+    dr2 = DagRun(execution_date=second_execution_date, run_id="test_run_id_2", **common_dr_kwargs)
+    ti_2 = TaskInstance(run_id=dr2.run_id, task=task, execution_date=second_execution_date)
+    dr3 = DagRun(execution_date=third_execution_date, run_id="test_run_id_3", **common_dr_kwargs)
+    ti_3 = TaskInstance(run_id=dr3.run_id, task=task, execution_date=third_execution_date)
+    session.add_all([dr1, dr2, dr3, ti_1, ti_2, ti_3])
+    session.commit()
+
+    # get all task instances
+    assert task.get_task_instances(session=session) == [ti_1, ti_2, ti_3]
+    # get task instances with start_date
+    assert task.get_task_instances(session=session, start_date=second_execution_date) == [ti_2, ti_3]
+    # get task instances with end_date
+    assert task.get_task_instances(session=session, end_date=second_execution_date) == [ti_1, ti_2]
+    # get task instances with start_date and end_date
+    assert task.get_task_instances(
+        session=session, start_date=second_execution_date, end_date=second_execution_date
+    ) == [ti_2]

Reply via email to