This is an automated email from the ASF dual-hosted git repository.

jscheffl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 153200fa052 Revert "Remove AIP-44 from airflow/sensors/base.py (#44518)" (#44527)
153200fa052 is described below

commit 153200fa05229546cb91cc341a6088a7d74f88ac
Author: Jens Scheffler <[email protected]>
AuthorDate: Sun Dec 1 09:47:41 2024 +0100

    Revert "Remove AIP-44 from airflow/sensors/base.py (#44518)" (#44527)
    
    This reverts commit de94c067486c5df68648a069796e06137608a73e.
---
 airflow/sensors/base.py | 50 ++++++++++++++++++++++++++++++++++---------------
 1 file changed, 35 insertions(+), 15 deletions(-)

diff --git a/airflow/sensors/base.py b/airflow/sensors/base.py
index c0cf255cad8..1c56aa42005 100644
--- a/airflow/sensors/base.py
+++ b/airflow/sensors/base.py
@@ -29,6 +29,7 @@ from typing import TYPE_CHECKING, Any, Callable
 from sqlalchemy import select
 
 from airflow import settings
+from airflow.api_internal.internal_api_call import internal_api_call
 from airflow.configuration import conf
 from airflow.exceptions import (
     AirflowException,
@@ -48,6 +49,8 @@ from airflow.utils import timezone
 from airflow.utils.session import NEW_SESSION, provide_session
 
 if TYPE_CHECKING:
+    from sqlalchemy.orm.session import Session
+
     from airflow.utils.context import Context
 
 # As documented in https://dev.mysql.com/doc/refman/5.7/en/datetime.html.
@@ -80,6 +83,31 @@ class PokeReturnValue:
         return self.is_done
 
 
+@internal_api_call
+@provide_session
+def _orig_start_date(
+    dag_id: str, task_id: str, run_id: str, map_index: int, try_number: int, session: Session = NEW_SESSION
+):
+    """
+    Get the original start_date for a rescheduled task.
+
+    :meta private:
+    """
+    return session.scalar(
+        select(TaskReschedule)
+        .where(
+            TaskReschedule.dag_id == dag_id,
+            TaskReschedule.task_id == task_id,
+            TaskReschedule.run_id == run_id,
+            TaskReschedule.map_index == map_index,
+            TaskReschedule.try_number == try_number,
+        )
+        .order_by(TaskReschedule.id.asc())
+        .with_only_columns(TaskReschedule.start_date)
+        .limit(1)
+    )
+
+
 class BaseSensorOperator(BaseOperator, SkipMixin):
     """
     Sensor operators are derived from this class and inherit these attributes.
@@ -211,8 +239,7 @@ class BaseSensorOperator(BaseOperator, SkipMixin):
         """Override when deriving this class."""
         raise AirflowException("Override me.")
 
-    @provide_session
-    def execute(self, context: Context, session=NEW_SESSION) -> Any:
+    def execute(self, context: Context) -> Any:
         started_at: datetime.datetime | float
 
         if self.reschedule:
@@ -222,19 +249,12 @@ class BaseSensorOperator(BaseOperator, SkipMixin):
             # If reschedule, use the start date of the first try (first try can be either the very
             # first execution of the task, or the first execution after the task was cleared.)
             first_try_number = max_tries - retries + 1
-
-            start_date = session.scalar(
-                select(TaskReschedule)
-                .where(
-                    TaskReschedule.dag_id == ti.dag_id,
-                    TaskReschedule.task_id == ti.task_id,
-                    TaskReschedule.run_id == ti.run_id,
-                    TaskReschedule.map_index == ti.map_index,
-                    TaskReschedule.try_number == first_try_number,
-                )
-                .order_by(TaskReschedule.id.asc())
-                .with_only_columns(TaskReschedule.start_date)
-                .limit(1)
+            start_date = _orig_start_date(
+                dag_id=ti.dag_id,
+                task_id=ti.task_id,
+                run_id=ti.run_id,
+                map_index=ti.map_index,
+                try_number=first_try_number,
             )
             if not start_date:
                 start_date = timezone.utcnow()

Reply via email to