This is an automated email from the ASF dual-hosted git repository.
husseinawala pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 125c7af1b8 Improve TaskInstance typing hints (#36487)
125c7af1b8 is described below
commit 125c7af1b8fa6c95e2093172d8d06655ee7f0c19
Author: Jens Scheffler <[email protected]>
AuthorDate: Wed Jan 3 23:45:20 2024 +0100
Improve TaskInstance typing hints (#36487)
* Improve TaskInstance typing hints
* Fix mypy typing by introducing a base class for task instance dependencies
* Add missing TaskMap to clean reference
---
airflow/models/base.py | 13 ++++++++++-
airflow/models/renderedtifields.py | 4 ++--
airflow/models/taskfail.py | 4 ++--
airflow/models/taskinstance.py | 46 +++++++++++++++++++++++++-------------
airflow/models/taskmap.py | 4 ++--
airflow/models/taskreschedule.py | 4 ++--
airflow/models/xcom.py | 4 ++--
7 files changed, 52 insertions(+), 27 deletions(-)
diff --git a/airflow/models/base.py b/airflow/models/base.py
index 934b9b1b74..31b6f9538f 100644
--- a/airflow/models/base.py
+++ b/airflow/models/base.py
@@ -19,7 +19,7 @@ from __future__ import annotations
from typing import Any
-from sqlalchemy import MetaData, String
+from sqlalchemy import Column, Integer, MetaData, String, text
from sqlalchemy.orm import registry
from airflow.configuration import conf
@@ -79,3 +79,14 @@ COLLATION_ARGS: dict[str, Any] = get_id_collation_args()
def StringID(*, length=ID_LEN, **kwargs) -> String:
return String(length=length, **kwargs, **COLLATION_ARGS)
+
+
+class TaskInstanceDependencies(Base):
+ """Base class for depending models linked to TaskInstance."""
+
+ __abstract__ = True
+
+ task_id = Column(StringID(), nullable=False)
+ dag_id = Column(StringID(), nullable=False)
+ run_id = Column(StringID(), nullable=False)
+ map_index = Column(Integer, nullable=False, server_default=text("-1"))
diff --git a/airflow/models/renderedtifields.py b/airflow/models/renderedtifields.py
index c54ee0e903..ff850e24ee 100644
--- a/airflow/models/renderedtifields.py
+++ b/airflow/models/renderedtifields.py
@@ -36,7 +36,7 @@ from sqlalchemy.ext.associationproxy import association_proxy
from sqlalchemy.orm import relationship
from airflow.configuration import conf
-from airflow.models.base import Base, StringID
+from airflow.models.base import StringID, TaskInstanceDependencies
from airflow.serialization.helpers import serialize_template_field
from airflow.settings import json
from airflow.utils.retries import retry_db_transaction
@@ -49,7 +49,7 @@ if TYPE_CHECKING:
from airflow.models.taskinstance import TaskInstance, TaskInstancePydantic
-class RenderedTaskInstanceFields(Base):
+class RenderedTaskInstanceFields(TaskInstanceDependencies):
"""Save Rendered Template Fields."""
__tablename__ = "rendered_task_instance_fields"
diff --git a/airflow/models/taskfail.py b/airflow/models/taskfail.py
index 7ae459ab8c..c8f2ac51fc 100644
--- a/airflow/models/taskfail.py
+++ b/airflow/models/taskfail.py
@@ -21,11 +21,11 @@ from __future__ import annotations
from sqlalchemy import Column, ForeignKeyConstraint, Index, Integer, text
from sqlalchemy.orm import relationship
-from airflow.models.base import Base, StringID
+from airflow.models.base import StringID, TaskInstanceDependencies
from airflow.utils.sqlalchemy import UtcDateTime
-class TaskFail(Base):
+class TaskFail(TaskInstanceDependencies):
"""TaskFail tracks the failed run durations of each task instance."""
__tablename__ = "task_fail"
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index dfe93c1ae6..15761ea2e4 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -85,7 +85,7 @@ from airflow.exceptions import (
XComForMappingNotPushed,
)
from airflow.listeners.listener import get_listener_manager
-from airflow.models.base import Base, StringID
+from airflow.models.base import Base, StringID, TaskInstanceDependencies
from airflow.models.dagbag import DagBag
from airflow.models.log import Log
from airflow.models.mappedoperator import MappedOperator
@@ -147,6 +147,7 @@ if TYPE_CHECKING:
from airflow.models.dataset import DatasetEvent
from airflow.models.operator import Operator
from airflow.serialization.pydantic.dag import DagModelPydantic
+ from airflow.serialization.pydantic.dataset import DatasetEventPydantic
from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic
from airflow.timetables.base import DataInterval
from airflow.typing_compat import Literal, TypeGuard
@@ -382,7 +383,7 @@ def _creator_note(val):
return TaskInstanceNote(*val)
-def _execute_task(task_instance, context, task_orig):
+def _execute_task(task_instance: TaskInstance | TaskInstancePydantic, context: Context, task_orig: Operator):
"""
Execute Task (optionally with a Timeout) and push Xcom results.
@@ -400,7 +401,8 @@ def _execute_task(task_instance, context, task_orig):
# If the task has been deferred and is being executed due to a trigger,
# then we need to pick the right method to come back to, otherwise
# we go for the default execute
- execute_callable_kwargs = {}
+ execute_callable_kwargs: dict[str, Any] = {}
+ execute_callable: Callable
if task_instance.next_method:
if task_instance.next_method:
execute_callable = task_to_execute.resume_execution
@@ -413,7 +415,7 @@ def _execute_task(task_instance, context, task_orig):
if task_to_execute.execution_timeout:
# If we are coming in with a next_method (i.e. from a deferral),
# calculate the timeout from our start_date.
- if task_instance.next_method:
+ if task_instance.next_method and task_instance.start_date:
timeout_seconds = (
task_to_execute.execution_timeout - (timezone.utcnow() - task_instance.start_date)
).total_seconds()
@@ -543,7 +545,7 @@ def _clear_next_method_args(*, task_instance: TaskInstance | TaskInstancePydanti
def _get_template_context(
*,
- task_instance,
+ task_instance: TaskInstance | TaskInstancePydantic,
session: Session | None = None,
ignore_param_exceptions: bool = True,
) -> Context:
@@ -575,7 +577,7 @@ def _get_template_context(
validated_params = process_params(dag, task, dag_run, suppress_exception=ignore_param_exceptions)
- logical_date = timezone.coerce_datetime(task_instance.execution_date)
+ logical_date: DateTime = timezone.coerce_datetime(task_instance.execution_date)
ds = logical_date.strftime("%Y-%m-%d")
ds_nodash = ds.replace("-", "")
ts = logical_date.isoformat()
@@ -682,7 +684,7 @@ def _get_template_context(
return None
return prev_ds.replace("-", "")
- def get_triggering_events() -> dict[str, list[DatasetEvent]]:
+ def get_triggering_events() -> dict[str, list[DatasetEvent | DatasetEventPydantic]]:
if TYPE_CHECKING:
assert session is not None
@@ -694,9 +696,10 @@ def _get_template_context(
dag_run = session.merge(dag_run, load=False)
dataset_events = dag_run.consumed_dataset_events
- triggering_events: dict[str, list[DatasetEvent]] = defaultdict(list)
+ triggering_events: dict[str, list[DatasetEvent | DatasetEventPydantic]] = defaultdict(list)
for event in dataset_events:
- triggering_events[event.dataset.uri].append(event)
+ if event.dataset:
+ triggering_events[event.dataset.uri].append(event)
return triggering_events
@@ -2430,7 +2433,7 @@ class TaskInstance(Base, LoggingMixin):
session=session,
)
- def _execute_task_with_callbacks(self, context, test_mode: bool = False, *, session: Session):
+ def _execute_task_with_callbacks(self, context: Context, test_mode: bool = False, *, session: Session):
"""Prepare Task for Execution."""
from airflow.models.renderedtifields import RenderedTaskInstanceFields
@@ -2607,7 +2610,11 @@ class TaskInstance(Base, LoggingMixin):
@provide_session
def _handle_reschedule(
- self, actual_start_date, reschedule_exception, test_mode=False, session=NEW_SESSION
+ self,
+ actual_start_date: datetime,
+ reschedule_exception: AirflowRescheduleException,
+ test_mode: bool = False,
+ session: Session = NEW_SESSION,
):
# Don't record reschedule request in test mode
if test_mode:
@@ -2852,7 +2859,7 @@ class TaskInstance(Base, LoggingMixin):
"rendering of template_fields."
) from e
- def overwrite_params_with_dag_run_conf(self, params, dag_run):
+ def overwrite_params_with_dag_run_conf(self, params: dict, dag_run: DagRun):
"""Overwrite Task Params with DagRun.conf."""
if dag_run and dag_run.conf:
self.log.debug("Updating task params (%s) with DagRun.conf (%s)", params, dag_run.conf)
@@ -3084,7 +3091,7 @@ class TaskInstance(Base, LoggingMixin):
return LazyXComAccess.build_from_xcom_query(query)
@provide_session
- def get_num_running_task_instances(self, session: Session, same_dagrun=False) -> int:
+ def get_num_running_task_instances(self, session: Session, same_dagrun: bool = False) -> int:
"""Return Number of running TIs from the DB."""
# .count() is inefficient
num_running_task_instances_query = session.query(func.count()).filter(
@@ -3382,7 +3389,7 @@ class TaskInstance(Base, LoggingMixin):
map_index_start = ancestor_map_index * further_count
return range(map_index_start, map_index_start + further_count)
- def clear_db_references(self, session):
+ def clear_db_references(self, session: Session):
"""
Clear db tables that have a reference to this instance.
@@ -3392,7 +3399,14 @@ class TaskInstance(Base, LoggingMixin):
"""
from airflow.models.renderedtifields import RenderedTaskInstanceFields
- tables = [TaskFail, TaskInstanceNote, TaskReschedule, XCom, RenderedTaskInstanceFields]
+ tables: list[type[TaskInstanceDependencies]] = [
+ TaskFail,
+ TaskInstanceNote,
+ TaskReschedule,
+ XCom,
+ RenderedTaskInstanceFields,
+ TaskMap,
+ ]
for table in tables:
session.execute(
delete(table).where(
@@ -3527,7 +3541,7 @@ class SimpleTaskInstance:
return cls(**obj_dict, start_date=start_date, end_date=end_date, key=ti_key)
-class TaskInstanceNote(Base):
+class TaskInstanceNote(TaskInstanceDependencies):
"""For storage of arbitrary notes concerning the task instance."""
__tablename__ = "task_instance_note"
diff --git a/airflow/models/taskmap.py b/airflow/models/taskmap.py
index 12658f8b7e..7211203171 100644
--- a/airflow/models/taskmap.py
+++ b/airflow/models/taskmap.py
@@ -24,7 +24,7 @@ from typing import TYPE_CHECKING, Any, Collection
from sqlalchemy import CheckConstraint, Column, ForeignKeyConstraint, Integer, String
-from airflow.models.base import COLLATION_ARGS, ID_LEN, Base
+from airflow.models.base import COLLATION_ARGS, ID_LEN, TaskInstanceDependencies
from airflow.utils.sqlalchemy import ExtendedJSON
if TYPE_CHECKING:
@@ -43,7 +43,7 @@ class TaskMapVariant(enum.Enum):
LIST = "list"
-class TaskMap(Base):
+class TaskMap(TaskInstanceDependencies):
"""Model to track dynamic task-mapping information.
This is currently only populated by an upstream TaskInstance pushing an
diff --git a/airflow/models/taskreschedule.py b/airflow/models/taskreschedule.py
index 49107fda4d..a098001773 100644
--- a/airflow/models/taskreschedule.py
+++ b/airflow/models/taskreschedule.py
@@ -26,7 +26,7 @@ from sqlalchemy.ext.associationproxy import association_proxy
from sqlalchemy.orm import relationship
from airflow.exceptions import RemovedInAirflow3Warning
-from airflow.models.base import COLLATION_ARGS, ID_LEN, Base
+from airflow.models.base import COLLATION_ARGS, ID_LEN, TaskInstanceDependencies
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.sqlalchemy import UtcDateTime
@@ -40,7 +40,7 @@ if TYPE_CHECKING:
from airflow.models.taskinstance import TaskInstance
-class TaskReschedule(Base):
+class TaskReschedule(TaskInstanceDependencies):
"""TaskReschedule tracks rescheduled task instances."""
__tablename__ = "task_reschedule"
diff --git a/airflow/models/xcom.py b/airflow/models/xcom.py
index f10e7b4637..ee74dd89f7 100644
--- a/airflow/models/xcom.py
+++ b/airflow/models/xcom.py
@@ -48,7 +48,7 @@ from airflow import settings
from airflow.api_internal.internal_api_call import internal_api_call
from airflow.configuration import conf
from airflow.exceptions import RemovedInAirflow3Warning
-from airflow.models.base import COLLATION_ARGS, ID_LEN, Base
+from airflow.models.base import COLLATION_ARGS, ID_LEN, TaskInstanceDependencies
from airflow.utils import timezone
from airflow.utils.helpers import exactly_one, is_container
from airflow.utils.json import XComDecoder, XComEncoder
@@ -74,7 +74,7 @@ if TYPE_CHECKING:
from airflow.models.taskinstancekey import TaskInstanceKey
-class BaseXCom(Base, LoggingMixin):
+class BaseXCom(TaskInstanceDependencies, LoggingMixin):
"""Base class for XCom objects."""
__tablename__ = "xcom"