This is an automated email from the ASF dual-hosted git repository.

husseinawala pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 125c7af1b8 Improve TaskInstance typing hints (#36487)
125c7af1b8 is described below

commit 125c7af1b8fa6c95e2093172d8d06655ee7f0c19
Author: Jens Scheffler <[email protected]>
AuthorDate: Wed Jan 3 23:45:20 2024 +0100

    Improve TaskInstance typing hints (#36487)
    
    * Improve TaskInstance typing hints
    
    * Fix mypy typing by introducing a base class for task instance dependencies
    
    * Add missing TaskMap to clean reference
---
 airflow/models/base.py             | 13 ++++++++++-
 airflow/models/renderedtifields.py |  4 ++--
 airflow/models/taskfail.py         |  4 ++--
 airflow/models/taskinstance.py     | 46 +++++++++++++++++++++++++-------------
 airflow/models/taskmap.py          |  4 ++--
 airflow/models/taskreschedule.py   |  4 ++--
 airflow/models/xcom.py             |  4 ++--
 7 files changed, 52 insertions(+), 27 deletions(-)

diff --git a/airflow/models/base.py b/airflow/models/base.py
index 934b9b1b74..31b6f9538f 100644
--- a/airflow/models/base.py
+++ b/airflow/models/base.py
@@ -19,7 +19,7 @@ from __future__ import annotations
 
 from typing import Any
 
-from sqlalchemy import MetaData, String
+from sqlalchemy import Column, Integer, MetaData, String, text
 from sqlalchemy.orm import registry
 
 from airflow.configuration import conf
@@ -79,3 +79,14 @@ COLLATION_ARGS: dict[str, Any] = get_id_collation_args()
 
 def StringID(*, length=ID_LEN, **kwargs) -> String:
     return String(length=length, **kwargs, **COLLATION_ARGS)
+
+
+class TaskInstanceDependencies(Base):
+    """Base class for depending models linked to TaskInstance."""
+
+    __abstract__ = True
+
+    task_id = Column(StringID(), nullable=False)
+    dag_id = Column(StringID(), nullable=False)
+    run_id = Column(StringID(), nullable=False)
+    map_index = Column(Integer, nullable=False, server_default=text("-1"))
diff --git a/airflow/models/renderedtifields.py b/airflow/models/renderedtifields.py
index c54ee0e903..ff850e24ee 100644
--- a/airflow/models/renderedtifields.py
+++ b/airflow/models/renderedtifields.py
@@ -36,7 +36,7 @@ from sqlalchemy.ext.associationproxy import association_proxy
 from sqlalchemy.orm import relationship
 
 from airflow.configuration import conf
-from airflow.models.base import Base, StringID
+from airflow.models.base import StringID, TaskInstanceDependencies
 from airflow.serialization.helpers import serialize_template_field
 from airflow.settings import json
 from airflow.utils.retries import retry_db_transaction
@@ -49,7 +49,7 @@ if TYPE_CHECKING:
     from airflow.models.taskinstance import TaskInstance, TaskInstancePydantic
 
 
-class RenderedTaskInstanceFields(Base):
+class RenderedTaskInstanceFields(TaskInstanceDependencies):
     """Save Rendered Template Fields."""
 
     __tablename__ = "rendered_task_instance_fields"
diff --git a/airflow/models/taskfail.py b/airflow/models/taskfail.py
index 7ae459ab8c..c8f2ac51fc 100644
--- a/airflow/models/taskfail.py
+++ b/airflow/models/taskfail.py
@@ -21,11 +21,11 @@ from __future__ import annotations
 from sqlalchemy import Column, ForeignKeyConstraint, Index, Integer, text
 from sqlalchemy.orm import relationship
 
-from airflow.models.base import Base, StringID
+from airflow.models.base import StringID, TaskInstanceDependencies
 from airflow.utils.sqlalchemy import UtcDateTime
 
 
-class TaskFail(Base):
+class TaskFail(TaskInstanceDependencies):
     """TaskFail tracks the failed run durations of each task instance."""
 
     __tablename__ = "task_fail"
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index dfe93c1ae6..15761ea2e4 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -85,7 +85,7 @@ from airflow.exceptions import (
     XComForMappingNotPushed,
 )
 from airflow.listeners.listener import get_listener_manager
-from airflow.models.base import Base, StringID
+from airflow.models.base import Base, StringID, TaskInstanceDependencies
 from airflow.models.dagbag import DagBag
 from airflow.models.log import Log
 from airflow.models.mappedoperator import MappedOperator
@@ -147,6 +147,7 @@ if TYPE_CHECKING:
     from airflow.models.dataset import DatasetEvent
     from airflow.models.operator import Operator
     from airflow.serialization.pydantic.dag import DagModelPydantic
+    from airflow.serialization.pydantic.dataset import DatasetEventPydantic
     from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic
     from airflow.timetables.base import DataInterval
     from airflow.typing_compat import Literal, TypeGuard
@@ -382,7 +383,7 @@ def _creator_note(val):
         return TaskInstanceNote(*val)
 
 
-def _execute_task(task_instance, context, task_orig):
+def _execute_task(task_instance: TaskInstance | TaskInstancePydantic, context: Context, task_orig: Operator):
     """
     Execute Task (optionally with a Timeout) and push Xcom results.
 
@@ -400,7 +401,8 @@ def _execute_task(task_instance, context, task_orig):
     # If the task has been deferred and is being executed due to a trigger,
     # then we need to pick the right method to come back to, otherwise
     # we go for the default execute
-    execute_callable_kwargs = {}
+    execute_callable_kwargs: dict[str, Any] = {}
+    execute_callable: Callable
     if task_instance.next_method:
         if task_instance.next_method:
             execute_callable = task_to_execute.resume_execution
@@ -413,7 +415,7 @@ def _execute_task(task_instance, context, task_orig):
     if task_to_execute.execution_timeout:
         # If we are coming in with a next_method (i.e. from a deferral),
         # calculate the timeout from our start_date.
-        if task_instance.next_method:
+        if task_instance.next_method and task_instance.start_date:
             timeout_seconds = (
                 task_to_execute.execution_timeout - (timezone.utcnow() - task_instance.start_date)
             ).total_seconds()
@@ -543,7 +545,7 @@ def _clear_next_method_args(*, task_instance: TaskInstance | TaskInstancePydanti
 
 def _get_template_context(
     *,
-    task_instance,
+    task_instance: TaskInstance | TaskInstancePydantic,
     session: Session | None = None,
     ignore_param_exceptions: bool = True,
 ) -> Context:
@@ -575,7 +577,7 @@ def _get_template_context(
 
     validated_params = process_params(dag, task, dag_run, suppress_exception=ignore_param_exceptions)
 
-    logical_date = timezone.coerce_datetime(task_instance.execution_date)
+    logical_date: DateTime = timezone.coerce_datetime(task_instance.execution_date)
     ds = logical_date.strftime("%Y-%m-%d")
     ds_nodash = ds.replace("-", "")
     ts = logical_date.isoformat()
@@ -682,7 +684,7 @@ def _get_template_context(
             return None
         return prev_ds.replace("-", "")
 
-    def get_triggering_events() -> dict[str, list[DatasetEvent]]:
+    def get_triggering_events() -> dict[str, list[DatasetEvent | DatasetEventPydantic]]:
         if TYPE_CHECKING:
             assert session is not None
 
@@ -694,9 +696,10 @@ def _get_template_context(
             dag_run = session.merge(dag_run, load=False)
 
         dataset_events = dag_run.consumed_dataset_events
-        triggering_events: dict[str, list[DatasetEvent]] = defaultdict(list)
+        triggering_events: dict[str, list[DatasetEvent | DatasetEventPydantic]] = defaultdict(list)
         for event in dataset_events:
-            triggering_events[event.dataset.uri].append(event)
+            if event.dataset:
+                triggering_events[event.dataset.uri].append(event)
 
         return triggering_events
 
@@ -2430,7 +2433,7 @@ class TaskInstance(Base, LoggingMixin):
                     session=session,
                 )
 
-    def _execute_task_with_callbacks(self, context, test_mode: bool = False, *, session: Session):
+    def _execute_task_with_callbacks(self, context: Context, test_mode: bool = False, *, session: Session):
         """Prepare Task for Execution."""
         from airflow.models.renderedtifields import RenderedTaskInstanceFields
 
@@ -2607,7 +2610,11 @@ class TaskInstance(Base, LoggingMixin):
 
     @provide_session
     def _handle_reschedule(
-        self, actual_start_date, reschedule_exception, test_mode=False, session=NEW_SESSION
+        self,
+        actual_start_date: datetime,
+        reschedule_exception: AirflowRescheduleException,
+        test_mode: bool = False,
+        session: Session = NEW_SESSION,
     ):
         # Don't record reschedule request in test mode
         if test_mode:
@@ -2852,7 +2859,7 @@ class TaskInstance(Base, LoggingMixin):
                 "rendering of template_fields."
             ) from e
 
-    def overwrite_params_with_dag_run_conf(self, params, dag_run):
+    def overwrite_params_with_dag_run_conf(self, params: dict, dag_run: DagRun):
         """Overwrite Task Params with DagRun.conf."""
         if dag_run and dag_run.conf:
             self.log.debug("Updating task params (%s) with DagRun.conf (%s)", params, dag_run.conf)
@@ -3084,7 +3091,7 @@ class TaskInstance(Base, LoggingMixin):
         return LazyXComAccess.build_from_xcom_query(query)
 
     @provide_session
-    def get_num_running_task_instances(self, session: Session, same_dagrun=False) -> int:
+    def get_num_running_task_instances(self, session: Session, same_dagrun: bool = False) -> int:
         """Return Number of running TIs from the DB."""
         # .count() is inefficient
         num_running_task_instances_query = session.query(func.count()).filter(
@@ -3382,7 +3389,7 @@ class TaskInstance(Base, LoggingMixin):
         map_index_start = ancestor_map_index * further_count
         return range(map_index_start, map_index_start + further_count)
 
-    def clear_db_references(self, session):
+    def clear_db_references(self, session: Session):
         """
         Clear db tables that have a reference to this instance.
 
@@ -3392,7 +3399,14 @@ class TaskInstance(Base, LoggingMixin):
         """
         from airflow.models.renderedtifields import RenderedTaskInstanceFields
 
-        tables = [TaskFail, TaskInstanceNote, TaskReschedule, XCom, RenderedTaskInstanceFields]
+        tables: list[type[TaskInstanceDependencies]] = [
+            TaskFail,
+            TaskInstanceNote,
+            TaskReschedule,
+            XCom,
+            RenderedTaskInstanceFields,
+            TaskMap,
+        ]
         for table in tables:
             session.execute(
                 delete(table).where(
@@ -3527,7 +3541,7 @@ class SimpleTaskInstance:
         return cls(**obj_dict, start_date=start_date, end_date=end_date, key=ti_key)
 
 
-class TaskInstanceNote(Base):
+class TaskInstanceNote(TaskInstanceDependencies):
     """For storage of arbitrary notes concerning the task instance."""
 
     __tablename__ = "task_instance_note"
diff --git a/airflow/models/taskmap.py b/airflow/models/taskmap.py
index 12658f8b7e..7211203171 100644
--- a/airflow/models/taskmap.py
+++ b/airflow/models/taskmap.py
@@ -24,7 +24,7 @@ from typing import TYPE_CHECKING, Any, Collection
 
 from sqlalchemy import CheckConstraint, Column, ForeignKeyConstraint, Integer, String
 
-from airflow.models.base import COLLATION_ARGS, ID_LEN, Base
+from airflow.models.base import COLLATION_ARGS, ID_LEN, TaskInstanceDependencies
 from airflow.utils.sqlalchemy import ExtendedJSON
 
 if TYPE_CHECKING:
@@ -43,7 +43,7 @@ class TaskMapVariant(enum.Enum):
     LIST = "list"
 
 
-class TaskMap(Base):
+class TaskMap(TaskInstanceDependencies):
     """Model to track dynamic task-mapping information.
 
     This is currently only populated by an upstream TaskInstance pushing an
diff --git a/airflow/models/taskreschedule.py b/airflow/models/taskreschedule.py
index 49107fda4d..a098001773 100644
--- a/airflow/models/taskreschedule.py
+++ b/airflow/models/taskreschedule.py
@@ -26,7 +26,7 @@ from sqlalchemy.ext.associationproxy import association_proxy
 from sqlalchemy.orm import relationship
 
 from airflow.exceptions import RemovedInAirflow3Warning
-from airflow.models.base import COLLATION_ARGS, ID_LEN, Base
+from airflow.models.base import COLLATION_ARGS, ID_LEN, TaskInstanceDependencies
 from airflow.utils.session import NEW_SESSION, provide_session
 from airflow.utils.sqlalchemy import UtcDateTime
 
@@ -40,7 +40,7 @@ if TYPE_CHECKING:
     from airflow.models.taskinstance import TaskInstance
 
 
-class TaskReschedule(Base):
+class TaskReschedule(TaskInstanceDependencies):
     """TaskReschedule tracks rescheduled task instances."""
 
     __tablename__ = "task_reschedule"
diff --git a/airflow/models/xcom.py b/airflow/models/xcom.py
index f10e7b4637..ee74dd89f7 100644
--- a/airflow/models/xcom.py
+++ b/airflow/models/xcom.py
@@ -48,7 +48,7 @@ from airflow import settings
 from airflow.api_internal.internal_api_call import internal_api_call
 from airflow.configuration import conf
 from airflow.exceptions import RemovedInAirflow3Warning
-from airflow.models.base import COLLATION_ARGS, ID_LEN, Base
+from airflow.models.base import COLLATION_ARGS, ID_LEN, TaskInstanceDependencies
 from airflow.utils import timezone
 from airflow.utils.helpers import exactly_one, is_container
 from airflow.utils.json import XComDecoder, XComEncoder
@@ -74,7 +74,7 @@ if TYPE_CHECKING:
     from airflow.models.taskinstancekey import TaskInstanceKey
 
 
-class BaseXCom(Base, LoggingMixin):
+class BaseXCom(TaskInstanceDependencies, LoggingMixin):
     """Base class for XCom objects."""
 
     __tablename__ = "xcom"

Reply via email to