This is an automated email from the ASF dual-hosted git repository.

dstandish pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 66407e8d226 Update backfill `list` endpoint to be async (#44208)
66407e8d226 is described below

commit 66407e8d2260ae8bf96f75048f8d1b2d1db3766c
Author: Daniel Standish <[email protected]>
AuthorDate: Fri Nov 22 12:37:09 2024 -0800

    Update backfill `list` endpoint to be async (#44208)
    
    This is a sort of hello world / proof of concept for having a route
    implemented using asyncio.  Gotta start somewhere.
---
 airflow/api_fastapi/common/db/common.py            | 84 +++++++++++++++++++++-
 .../core_api/routes/public/backfills.py            | 15 ++--
 airflow/settings.py                                | 10 +--
 airflow/utils/db.py                                | 16 +++++
 airflow/utils/session.py                           | 18 +++++
 tests/utils/test_session.py                        |  4 +-
 6 files changed, 129 insertions(+), 18 deletions(-)

diff --git a/airflow/api_fastapi/common/db/common.py b/airflow/api_fastapi/common/db/common.py
index 17da1eafacc..2d7da4bff73 100644
--- a/airflow/api_fastapi/common/db/common.py
+++ b/airflow/api_fastapi/common/db/common.py
@@ -24,8 +24,10 @@ from __future__ import annotations
 
 from typing import TYPE_CHECKING, Literal, Sequence, overload
 
-from airflow.utils.db import get_query_count
-from airflow.utils.session import NEW_SESSION, create_session, provide_session
+from sqlalchemy.ext.asyncio import AsyncSession
+
+from airflow.utils.db import get_query_count, get_query_count_async
+from airflow.utils.session import NEW_SESSION, create_session, create_session_async, provide_session
 
 if TYPE_CHECKING:
     from sqlalchemy.orm import Session
@@ -53,7 +55,9 @@ def get_session() -> Session:
 
 
 def apply_filters_to_select(
-    *, base_select: Select, filters: Sequence[BaseParam | None] | None = None
+    *,
+    base_select: Select,
+    filters: Sequence[BaseParam | None] | None = None,
 ) -> Select:
     if filters is None:
         return base_select
@@ -65,6 +69,80 @@ def apply_filters_to_select(
     return base_select
 
 
+async def get_async_session() -> AsyncSession:
+    """
+    Dependency for providing a session.
+
+    Example usage:
+
+    .. code:: python
+
+        @router.get("/your_path")
+        def your_route(session: Annotated[AsyncSession, Depends(get_async_session)]):
+            pass
+    """
+    async with create_session_async() as session:
+        yield session
+
+
+@overload
+async def paginated_select_async(
+    *,
+    query: Select,
+    filters: Sequence[BaseParam] | None = None,
+    order_by: BaseParam | None = None,
+    offset: BaseParam | None = None,
+    limit: BaseParam | None = None,
+    session: AsyncSession,
+    return_total_entries: Literal[True] = True,
+) -> tuple[Select, int]: ...
+
+
+@overload
+async def paginated_select_async(
+    *,
+    query: Select,
+    filters: Sequence[BaseParam] | None = None,
+    order_by: BaseParam | None = None,
+    offset: BaseParam | None = None,
+    limit: BaseParam | None = None,
+    session: AsyncSession,
+    return_total_entries: Literal[False],
+) -> tuple[Select, None]: ...
+
+
+async def paginated_select_async(
+    *,
+    query: Select,
+    filters: Sequence[BaseParam | None] | None = None,
+    order_by: BaseParam | None = None,
+    offset: BaseParam | None = None,
+    limit: BaseParam | None = None,
+    session: AsyncSession,
+    return_total_entries: bool = True,
+) -> tuple[Select, int | None]:
+    query = apply_filters_to_select(
+        base_select=query,
+        filters=filters,
+    )
+
+    total_entries = None
+    if return_total_entries:
+        total_entries = await get_query_count_async(query, session=session)
+
+    # TODO: Re-enable when permissions are handled. Readable / writable entities,
+    # for instance:
+    # readable_dags = get_auth_manager().get_permitted_dag_ids(user=g.user)
+    # dags_select = dags_select.where(DagModel.dag_id.in_(readable_dags))
+
+    query = apply_filters_to_select(
+        base_select=query,
+        filters=[order_by, offset, limit],
+    )
+
+    return query, total_entries
+
+
 @overload
 def paginated_select(
     *,
diff --git a/airflow/api_fastapi/core_api/routes/public/backfills.py b/airflow/api_fastapi/core_api/routes/public/backfills.py
index aa6f540d327..78b2beb5588 100644
--- a/airflow/api_fastapi/core_api/routes/public/backfills.py
+++ b/airflow/api_fastapi/core_api/routes/public/backfills.py
@@ -20,9 +20,10 @@ from typing import Annotated
 
 from fastapi import Depends, HTTPException, status
 from sqlalchemy import select, update
+from sqlalchemy.ext.asyncio import AsyncSession
 from sqlalchemy.orm import Session
 
-from airflow.api_fastapi.common.db.common import get_session, paginated_select
+from airflow.api_fastapi.common.db.common import get_async_session, get_session, paginated_select_async
 from airflow.api_fastapi.common.parameters import QueryLimit, QueryOffset, SortParam
 from airflow.api_fastapi.common.router import AirflowRouter
 from airflow.api_fastapi.core_api.datamodels.backfills import (
@@ -49,7 +50,7 @@ backfills_router = AirflowRouter(tags=["Backfill"], prefix="/backfills")
 @backfills_router.get(
     path="",
 )
-def list_backfills(
+async def list_backfills(
     dag_id: str,
     limit: QueryLimit,
     offset: QueryOffset,
@@ -57,18 +58,16 @@ def list_backfills(
         SortParam,
         Depends(SortParam(["id"], Backfill).dynamic_depends()),
     ],
-    session: Annotated[Session, Depends(get_session)],
+    session: Annotated[AsyncSession, Depends(get_async_session)],
 ) -> BackfillCollectionResponse:
-    select_stmt, total_entries = paginated_select(
-        select=select(Backfill).where(Backfill.dag_id == dag_id),
+    select_stmt, total_entries = await paginated_select_async(
+        query=select(Backfill).where(Backfill.dag_id == dag_id),
         order_by=order_by,
         offset=offset,
         limit=limit,
         session=session,
     )
-
-    backfills = session.scalars(select_stmt)
-
+    backfills = await session.scalars(select_stmt)
     return BackfillCollectionResponse(
         backfills=backfills,
         total_entries=total_entries,
diff --git a/airflow/settings.py b/airflow/settings.py
index 5b458efcba4..76b3e948964 100644
--- a/airflow/settings.py
+++ b/airflow/settings.py
@@ -31,7 +31,7 @@ from typing import TYPE_CHECKING, Any, Callable
 import pluggy
 from packaging.version import Version
 from sqlalchemy import create_engine, exc, text
-from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine
+from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession as SAAsyncSession, create_async_engine
 from sqlalchemy.orm import scoped_session, sessionmaker
 from sqlalchemy.pool import NullPool
 
@@ -111,7 +111,7 @@ Session: Callable[..., SASession]
 # this is achieved by the Session factory above.
 NonScopedSession: Callable[..., SASession]
 async_engine: AsyncEngine
-create_async_session: Callable[..., AsyncSession]
+AsyncSession: Callable[..., SAAsyncSession]
 
 # The JSON library to use for DAG Serialization and De-Serialization
 json = json
@@ -469,7 +469,7 @@ def configure_orm(disable_connection_pool=False, pool_class=None):
     global Session
     global engine
     global async_engine
-    global create_async_session
+    global AsyncSession
     global NonScopedSession
 
     if os.environ.get("_AIRFLOW_SKIP_DB_TESTS") == "true":
@@ -498,11 +498,11 @@ def configure_orm(disable_connection_pool=False, pool_class=None):
 
     engine = create_engine(SQL_ALCHEMY_CONN, connect_args=connect_args, **engine_args, future=True)
     async_engine = create_async_engine(SQL_ALCHEMY_CONN_ASYNC, future=True)
-    create_async_session = sessionmaker(
+    AsyncSession = sessionmaker(
         bind=async_engine,
         autocommit=False,
         autoflush=False,
-        class_=AsyncSession,
+        class_=SAAsyncSession,
         expire_on_commit=False,
     )
     mask_secret(engine.url.password)
diff --git a/airflow/utils/db.py b/airflow/utils/db.py
index d8939a11731..c899ebf615d 100644
--- a/airflow/utils/db.py
+++ b/airflow/utils/db.py
@@ -70,6 +70,7 @@ if TYPE_CHECKING:
     from alembic.runtime.environment import EnvironmentContext
     from alembic.script import ScriptDirectory
     from sqlalchemy.engine import Row
+    from sqlalchemy.ext.asyncio import AsyncSession
     from sqlalchemy.orm import Session
     from sqlalchemy.sql.elements import ClauseElement, TextClause
     from sqlalchemy.sql.selectable import Select
@@ -1447,6 +1448,21 @@ def get_query_count(query_stmt: Select, *, session: Session) -> int:
     return session.scalar(count_stmt)
 
 
+async def get_query_count_async(query: Select, *, session: AsyncSession) -> int:
+    """
+    Get count of a query.
+
+    A SELECT COUNT() FROM is issued against the subquery built from the
+    given statement. The ORDER BY clause is stripped from the statement
+    since it's unnecessary for COUNT, and can impact query planning and
+    degrade performance.
+
+    :meta private:
+    """
+    count_stmt = select(func.count()).select_from(query.order_by(None).subquery())
+    return await session.scalar(count_stmt)
+
+
 def check_query_exists(query_stmt: Select, *, session: Session) -> bool:
     """
     Check whether there is at least one row matching a query.
diff --git a/airflow/utils/session.py b/airflow/utils/session.py
index a63d3f3f937..49383cdf4a8 100644
--- a/airflow/utils/session.py
+++ b/airflow/utils/session.py
@@ -65,6 +65,24 @@ def create_session(scoped: bool = True) -> Generator[SASession, None, None]:
         session.close()
 
 
+@contextlib.asynccontextmanager
+async def create_session_async():
+    """
+    Context manager to create async session.
+
+    :meta private:
+    """
+    from airflow.settings import AsyncSession
+
+    async with AsyncSession() as session:
+        try:
+            yield session
+            await session.commit()
+        except Exception:
+            await session.rollback()
+            raise
+
+
 PS = ParamSpec("PS")
 RT = TypeVar("RT")
 
diff --git a/tests/utils/test_session.py b/tests/utils/test_session.py
index 02cba9e070d..8d26a25c626 100644
--- a/tests/utils/test_session.py
+++ b/tests/utils/test_session.py
@@ -58,9 +58,9 @@ class TestSession:
 
     @pytest.mark.asyncio
     async def test_async_session(self):
-        from airflow.settings import create_async_session
+        from airflow.settings import AsyncSession
 
-        session = create_async_session()
+        session = AsyncSession()
         session.add(Log(event="hihi1234"))
         await session.commit()
         my_special_log_event = await session.scalar(select(Log).where(Log.event == "hihi1234").limit(1))

Reply via email to