This is an automated email from the ASF dual-hosted git repository.
dstandish pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 66407e8d226 Update backfill `list` endpoint to be async (#44208)
66407e8d226 is described below
commit 66407e8d2260ae8bf96f75048f8d1b2d1db3766c
Author: Daniel Standish <[email protected]>
AuthorDate: Fri Nov 22 12:37:09 2024 -0800
Update backfill `list` endpoint to be async (#44208)
This is a sort of hello world / proof of concept for having a route
implemented using asyncio. Gotta start somewhere.
---
airflow/api_fastapi/common/db/common.py | 84 +++++++++++++++++++++-
.../core_api/routes/public/backfills.py | 15 ++--
airflow/settings.py | 10 +--
airflow/utils/db.py | 16 +++++
airflow/utils/session.py | 18 +++++
tests/utils/test_session.py | 4 +-
6 files changed, 129 insertions(+), 18 deletions(-)
diff --git a/airflow/api_fastapi/common/db/common.py b/airflow/api_fastapi/common/db/common.py
index 17da1eafacc..2d7da4bff73 100644
--- a/airflow/api_fastapi/common/db/common.py
+++ b/airflow/api_fastapi/common/db/common.py
@@ -24,8 +24,10 @@ from __future__ import annotations
from typing import TYPE_CHECKING, Literal, Sequence, overload
-from airflow.utils.db import get_query_count
-from airflow.utils.session import NEW_SESSION, create_session, provide_session
+from sqlalchemy.ext.asyncio import AsyncSession
+
+from airflow.utils.db import get_query_count, get_query_count_async
+from airflow.utils.session import NEW_SESSION, create_session, create_session_async, provide_session
if TYPE_CHECKING:
from sqlalchemy.orm import Session
@@ -53,7 +55,9 @@ def get_session() -> Session:
def apply_filters_to_select(
- *, base_select: Select, filters: Sequence[BaseParam | None] | None = None
+ *,
+ base_select: Select,
+ filters: Sequence[BaseParam | None] | None = None,
) -> Select:
if filters is None:
return base_select
@@ -65,6 +69,80 @@ def apply_filters_to_select(
return base_select
+async def get_async_session() -> AsyncSession:
+ """
+ Dependency for providing a session.
+
+ Example usage:
+
+ .. code:: python
+
+ @router.get("/your_path")
+ def your_route(session: Annotated[AsyncSession, Depends(get_async_session)]):
+ pass
+ """
+ async with create_session_async() as session:
+ yield session
+
+
+@overload
+async def paginated_select_async(
+ *,
+ query: Select,
+ filters: Sequence[BaseParam] | None = None,
+ order_by: BaseParam | None = None,
+ offset: BaseParam | None = None,
+ limit: BaseParam | None = None,
+ session: AsyncSession,
+ return_total_entries: Literal[True] = True,
+) -> tuple[Select, int]: ...
+
+
+@overload
+async def paginated_select_async(
+ *,
+ query: Select,
+ filters: Sequence[BaseParam] | None = None,
+ order_by: BaseParam | None = None,
+ offset: BaseParam | None = None,
+ limit: BaseParam | None = None,
+ session: AsyncSession,
+ return_total_entries: Literal[False],
+) -> tuple[Select, None]: ...
+
+
+async def paginated_select_async(
+ *,
+ query: Select,
+ filters: Sequence[BaseParam | None] | None = None,
+ order_by: BaseParam | None = None,
+ offset: BaseParam | None = None,
+ limit: BaseParam | None = None,
+ session: AsyncSession,
+ return_total_entries: bool = True,
+) -> tuple[Select, int | None]:
+ query = apply_filters_to_select(
+ base_select=query,
+ filters=filters,
+ )
+
+ total_entries = None
+ if return_total_entries:
+ total_entries = await get_query_count_async(query, session=session)
+
+ # TODO: Re-enable when permissions are handled. Readable / writable entities,
+ # for instance:
+ # readable_dags = get_auth_manager().get_permitted_dag_ids(user=g.user)
+ # dags_select = dags_select.where(DagModel.dag_id.in_(readable_dags))
+
+ query = apply_filters_to_select(
+ base_select=query,
+ filters=[order_by, offset, limit],
+ )
+
+ return query, total_entries
+
+
@overload
def paginated_select(
*,
diff --git a/airflow/api_fastapi/core_api/routes/public/backfills.py b/airflow/api_fastapi/core_api/routes/public/backfills.py
index aa6f540d327..78b2beb5588 100644
--- a/airflow/api_fastapi/core_api/routes/public/backfills.py
+++ b/airflow/api_fastapi/core_api/routes/public/backfills.py
@@ -20,9 +20,10 @@ from typing import Annotated
from fastapi import Depends, HTTPException, status
from sqlalchemy import select, update
+from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session
-from airflow.api_fastapi.common.db.common import get_session, paginated_select
+from airflow.api_fastapi.common.db.common import get_async_session, get_session, paginated_select_async
from airflow.api_fastapi.common.parameters import QueryLimit, QueryOffset, SortParam
from airflow.api_fastapi.common.router import AirflowRouter
from airflow.api_fastapi.core_api.datamodels.backfills import (
@@ -49,7 +50,7 @@ backfills_router = AirflowRouter(tags=["Backfill"], prefix="/backfills")
@backfills_router.get(
path="",
)
-def list_backfills(
+async def list_backfills(
dag_id: str,
limit: QueryLimit,
offset: QueryOffset,
@@ -57,18 +58,16 @@ def list_backfills(
SortParam,
Depends(SortParam(["id"], Backfill).dynamic_depends()),
],
- session: Annotated[Session, Depends(get_session)],
+ session: Annotated[AsyncSession, Depends(get_async_session)],
) -> BackfillCollectionResponse:
- select_stmt, total_entries = paginated_select(
- select=select(Backfill).where(Backfill.dag_id == dag_id),
+ select_stmt, total_entries = await paginated_select_async(
+ query=select(Backfill).where(Backfill.dag_id == dag_id),
order_by=order_by,
offset=offset,
limit=limit,
session=session,
)
-
- backfills = session.scalars(select_stmt)
-
+ backfills = await session.scalars(select_stmt)
return BackfillCollectionResponse(
backfills=backfills,
total_entries=total_entries,
diff --git a/airflow/settings.py b/airflow/settings.py
index 5b458efcba4..76b3e948964 100644
--- a/airflow/settings.py
+++ b/airflow/settings.py
@@ -31,7 +31,7 @@ from typing import TYPE_CHECKING, Any, Callable
import pluggy
from packaging.version import Version
from sqlalchemy import create_engine, exc, text
-from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine
+from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession as SAAsyncSession, create_async_engine
from sqlalchemy.orm import scoped_session, sessionmaker
from sqlalchemy.pool import NullPool
@@ -111,7 +111,7 @@ Session: Callable[..., SASession]
# this is achieved by the Session factory above.
NonScopedSession: Callable[..., SASession]
async_engine: AsyncEngine
-create_async_session: Callable[..., AsyncSession]
+AsyncSession: Callable[..., SAAsyncSession]
# The JSON library to use for DAG Serialization and De-Serialization
json = json
@@ -469,7 +469,7 @@ def configure_orm(disable_connection_pool=False, pool_class=None):
global Session
global engine
global async_engine
- global create_async_session
+ global AsyncSession
global NonScopedSession
if os.environ.get("_AIRFLOW_SKIP_DB_TESTS") == "true":
@@ -498,11 +498,11 @@ def configure_orm(disable_connection_pool=False, pool_class=None):
engine = create_engine(SQL_ALCHEMY_CONN, connect_args=connect_args, **engine_args, future=True)
async_engine = create_async_engine(SQL_ALCHEMY_CONN_ASYNC, future=True)
- create_async_session = sessionmaker(
+ AsyncSession = sessionmaker(
bind=async_engine,
autocommit=False,
autoflush=False,
- class_=AsyncSession,
+ class_=SAAsyncSession,
expire_on_commit=False,
)
mask_secret(engine.url.password)
diff --git a/airflow/utils/db.py b/airflow/utils/db.py
index d8939a11731..c899ebf615d 100644
--- a/airflow/utils/db.py
+++ b/airflow/utils/db.py
@@ -70,6 +70,7 @@ if TYPE_CHECKING:
from alembic.runtime.environment import EnvironmentContext
from alembic.script import ScriptDirectory
from sqlalchemy.engine import Row
+ from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session
from sqlalchemy.sql.elements import ClauseElement, TextClause
from sqlalchemy.sql.selectable import Select
@@ -1447,6 +1448,21 @@ def get_query_count(query_stmt: Select, *, session: Session) -> int:
return session.scalar(count_stmt)
+async def get_query_count_async(query: Select, *, session: AsyncSession) -> int:
+ """
+ Get count of a query.
+
+ A SELECT COUNT() FROM is issued against the subquery built from the
+ given statement. The ORDER BY clause is stripped from the statement
+ since it's unnecessary for COUNT, and can impact query planning and
+ degrade performance.
+
+ :meta private:
+ """
+ count_stmt = select(func.count()).select_from(query.order_by(None).subquery())
+ return await session.scalar(count_stmt)
+
+
def check_query_exists(query_stmt: Select, *, session: Session) -> bool:
"""
Check whether there is at least one row matching a query.
diff --git a/airflow/utils/session.py b/airflow/utils/session.py
index a63d3f3f937..49383cdf4a8 100644
--- a/airflow/utils/session.py
+++ b/airflow/utils/session.py
@@ -65,6 +65,24 @@ def create_session(scoped: bool = True) -> Generator[SASession, None, None]:
session.close()
+@contextlib.asynccontextmanager
+async def create_session_async():
+ """
+ Context manager to create async session.
+
+ :meta private:
+ """
+ from airflow.settings import AsyncSession
+
+ async with AsyncSession() as session:
+ try:
+ yield session
+ await session.commit()
+ except Exception:
+ await session.rollback()
+ raise
+
+
PS = ParamSpec("PS")
RT = TypeVar("RT")
diff --git a/tests/utils/test_session.py b/tests/utils/test_session.py
index 02cba9e070d..8d26a25c626 100644
--- a/tests/utils/test_session.py
+++ b/tests/utils/test_session.py
@@ -58,9 +58,9 @@ class TestSession:
@pytest.mark.asyncio
async def test_async_session(self):
- from airflow.settings import create_async_session
+ from airflow.settings import AsyncSession
- session = create_async_session()
+ session = AsyncSession()
session.add(Log(event="hihi1234"))
await session.commit()
my_special_log_event = await session.scalar(select(Log).where(Log.event == "hihi1234").limit(1))