pierrejeambrun commented on code in PR #56214:
URL: https://github.com/apache/airflow/pull/56214#discussion_r2388614495
##########
airflow-core/src/airflow/api_fastapi/common/router.py:
##########
@@ -43,3 +52,52 @@ def decorator(func: DecoratedCallable) -> DecoratedCallable:
return func
return decorator
+
+
+def _route_uses_dep(route: APIRoute, *, module: str, name: str) -> bool:
+ stack = list(route.dependant.dependencies)
+ while stack:
+ dep = stack.pop()
+ call = getattr(dep, "call", None)
+ if call is not None:
+ mod = getattr(call, "__module__", "")
+ func_name = getattr(call, "__name__", "")
+ if mod == module and func_name == name:
+ return True
+ stack.extend(getattr(dep, "dependencies", []) or [])
+ return False
+
+
+class _AirflowRoute(APIRoute):
+ def get_route_handler(self) -> Callable[[Request], Coroutine[None, None,
Response]]:
+ default_handler = super().get_route_handler()
+ uses_sync = _route_uses_dep(self,
module="airflow.api_fastapi.common.db.common", name="_get_session")
+ uses_async = _route_uses_dep(
+ self, module="airflow.api_fastapi.common.db.common",
name="_get_async_session"
+ )
+
+ async def handler(request: Request) -> Response:
+ if not (uses_sync or uses_async):
+ return await default_handler(request)
+ if uses_async:
+ async with create_session_async() as async_session:
+ setattr(request.state, "__airflow_async_db_session",
async_session)
+ response = await default_handler(request)
+ await async_session.commit()
+ try:
+ delattr(request.state, "__airflow_async_db_session")
+ except Exception:
+ pass
+ return response
Review Comment:
Done
##########
airflow-core/src/airflow/api_fastapi/common/router.py:
##########
@@ -43,3 +52,52 @@ def decorator(func: DecoratedCallable) -> DecoratedCallable:
return func
return decorator
+
+
+def _route_uses_dep(route: APIRoute, *, module: str, name: str) -> bool:
+ stack = list(route.dependant.dependencies)
+ while stack:
+ dep = stack.pop()
+ call = getattr(dep, "call", None)
+ if call is not None:
+ mod = getattr(call, "__module__", "")
+ func_name = getattr(call, "__name__", "")
+ if mod == module and func_name == name:
+ return True
+ stack.extend(getattr(dep, "dependencies", []) or [])
+ return False
+
+
+class _AirflowRoute(APIRoute):
+ def get_route_handler(self) -> Callable[[Request], Coroutine[None, None,
Response]]:
+ default_handler = super().get_route_handler()
+ uses_sync = _route_uses_dep(self,
module="airflow.api_fastapi.common.db.common", name="_get_session")
+ uses_async = _route_uses_dep(
+ self, module="airflow.api_fastapi.common.db.common",
name="_get_async_session"
+ )
+
+ async def handler(request: Request) -> Response:
+ if not (uses_sync or uses_async):
+ return await default_handler(request)
+ if uses_async:
+ async with create_session_async() as async_session:
+ setattr(request.state, "__airflow_async_db_session",
async_session)
+ response = await default_handler(request)
+ await async_session.commit()
+ try:
+ delattr(request.state, "__airflow_async_db_session")
+ except Exception:
+ pass
+ return response
+ else:
+ with create_session(scoped=False) as session:
+ setattr(request.state, "__airflow_db_session", session)
+ response = await default_handler(request)
+ session.commit()
+ try:
+ delattr(request.state, "__airflow_db_session")
+ except Exception:
+ pass
+ return response
Review Comment:
Done
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]