This is an automated email from the ASF dual-hosted git repository.

uranusjr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new b3536217cd Fix typing in external task triggers (#31490)
b3536217cd is described below

commit b3536217cd80bda8b56068a2efb3fa6979d17b3f
Author: Tzu-ping Chung <[email protected]>
AuthorDate: Wed May 24 14:55:50 2023 +0800

    Fix typing in external task triggers (#31490)
---
 airflow/triggers/external_task.py | 11 +++++------
 1 file changed, 5 insertions(+), 6 deletions(-)

diff --git a/airflow/triggers/external_task.py b/airflow/triggers/external_task.py
index 6099dc0a37..fc70a63a45 100644
--- a/airflow/triggers/external_task.py
+++ b/airflow/triggers/external_task.py
@@ -19,7 +19,6 @@ from __future__ import annotations
 import asyncio
 import datetime
 import typing
-from typing import Any
 
 from asgiref.sync import sync_to_async
 from sqlalchemy import func
@@ -27,7 +26,7 @@ from sqlalchemy.orm import Session
 
 from airflow.models import DagRun, TaskInstance
 from airflow.triggers.base import BaseTrigger, TriggerEvent
-from airflow.utils.session import provide_session
+from airflow.utils.session import NEW_SESSION, provide_session
 
 
 class TaskStateTrigger(BaseTrigger):
@@ -59,7 +58,7 @@ class TaskStateTrigger(BaseTrigger):
         self.execution_dates = execution_dates
         self.poll_interval = poll_interval
 
-    def serialize(self) -> tuple[str, dict[str, Any]]:
+    def serialize(self) -> tuple[str, dict[str, typing.Any]]:
         """Serializes TaskStateTrigger arguments and classpath."""
         return (
             "airflow.triggers.external_task.TaskStateTrigger",
@@ -85,7 +84,7 @@ class TaskStateTrigger(BaseTrigger):
 
     @sync_to_async
     @provide_session
-    def count_tasks(self, session: Session) -> int | None:
+    def count_tasks(self, *, session: Session = NEW_SESSION) -> int | None:
         """Count how many task instances in the database match our criteria."""
         count = (
             session.query(func.count("*"))  # .count() is inefficient
@@ -124,7 +123,7 @@ class DagStateTrigger(BaseTrigger):
         self.execution_dates = execution_dates
         self.poll_interval = poll_interval
 
-    def serialize(self) -> tuple[str, dict[str, Any]]:
+    def serialize(self) -> tuple[str, dict[str, typing.Any]]:
         """Serializes DagStateTrigger arguments and classpath."""
         return (
             "airflow.triggers.external_task.DagStateTrigger",
@@ -149,7 +148,7 @@ class DagStateTrigger(BaseTrigger):
 
     @sync_to_async
     @provide_session
-    def count_dags(self, session: Session) -> int | None:
+    def count_dags(self, *, session: Session = NEW_SESSION) -> int | None:
         """Count how many dag runs in the database match our criteria."""
         count = (
             session.query(func.count("*"))  # .count() is inefficient

Reply via email to