This is an automated email from the ASF dual-hosted git repository.

weilee pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 41a8a9a4c69 AIP-84 | Add Auth for Dags (#47433)
41a8a9a4c69 is described below

commit 41a8a9a4c69adadf1602028240c8a1b1969413fd
Author: LIU ZHE YOU <[email protected]>
AuthorDate: Tue Mar 11 22:37:31 2025 +0800

    AIP-84 | Add Auth for Dags (#47433)
    
    * AIP-84 | Add Auth for Dag
    
    Refactor conftest for api_fastapi and test_dags
    
    Add unauthorized 403 test cases
    
    Remove PATCH in requires_access
    
    Fix unauthorized_test_client, requires_access_dag
    
    Add EditableDagsFilterDep, ReadableDagsFilterDep
    
    Add permitted_dag_filter for dags API
    
    Fix test_security
    
    Add OrmFilterClause
    
    Fix mypy error
    
    * fix(api_fastapi): rename methods argument to method
    
    * Fix kubernetes_tests
    
    * Fix api_fastapi/test_dags
    
    * Add dags_reserialize for k8s tests
    
    Refactor _get_jwt_token
    
    * Increase execution timeout of test_integration_run_dag_with_scheduler_failure
    
    * test: raise if we cannot get jwt_token for a reason other than a connection error
    
    * Fix _get_jwt_token after dynamic patching k8s configMap
    
    * Remove dags_reserialize setup in BaseK8STest
    
    * Fix test_docker_compose_quick_start
    
    * Ensure scheduler health in test_integration_run_dag_with_scheduler_failure
    
    * Increase timeout threshold
    
    * Add HTTP retry for _get_jwt_token
    
    * Add JWTRefreshAdapter and restart api-server if needed
    
    ---------
    
    Co-authored-by: Wei Lee <[email protected]>
---
 airflow/api_fastapi/common/db/common.py            |  52 +++++-----
 airflow/api_fastapi/common/parameters.py           |  11 +--
 airflow/api_fastapi/core_api/base.py               |  23 +++++
 .../api_fastapi/core_api/openapi/v1-generated.yaml |  12 +++
 airflow/api_fastapi/core_api/routes/public/dags.py |  20 +++-
 airflow/api_fastapi/core_api/security.py           |  51 ++++++++--
 docker_tests/test_docker_compose_quick_start.py    |  57 ++++++++++-
 kubernetes_tests/test_base.py                      | 107 ++++++++++++++++++---
 kubernetes_tests/test_kubernetes_executor.py       |   2 +-
 kubernetes_tests/test_other_executors.py           |   5 +-
 .../core_api/routes/public/test_dags.py            |  73 +++++++++++++-
 tests/api_fastapi/core_api/test_security.py        |  11 ++-
 12 files changed, 355 insertions(+), 69 deletions(-)

diff --git a/airflow/api_fastapi/common/db/common.py 
b/airflow/api_fastapi/common/db/common.py
index 84297724ce0..363a7fbdc4b 100644
--- a/airflow/api_fastapi/common/db/common.py
+++ b/airflow/api_fastapi/common/db/common.py
@@ -35,7 +35,7 @@ from airflow.utils.session import NEW_SESSION, 
create_session, create_session_as
 if TYPE_CHECKING:
     from sqlalchemy.sql import Select
 
-    from airflow.api_fastapi.common.parameters import BaseParam
+    from airflow.api_fastapi.core_api.base import OrmClause
 
 
 def _get_session() -> Session:
@@ -47,7 +47,7 @@ SessionDep = Annotated[Session, Depends(_get_session)]
 
 
 def apply_filters_to_select(
-    *, statement: Select, filters: Sequence[BaseParam | None] | None = None
+    *, statement: Select, filters: Sequence[OrmClause | None] | None = None
 ) -> Select:
     if filters is None:
         return statement
@@ -71,10 +71,10 @@ AsyncSessionDep = Annotated[AsyncSession, 
Depends(_get_async_session)]
 async def paginated_select_async(
     *,
     statement: Select,
-    filters: Sequence[BaseParam] | None = None,
-    order_by: BaseParam | None = None,
-    offset: BaseParam | None = None,
-    limit: BaseParam | None = None,
+    filters: Sequence[OrmClause] | None = None,
+    order_by: OrmClause | None = None,
+    offset: OrmClause | None = None,
+    limit: OrmClause | None = None,
     session: AsyncSession,
     return_total_entries: Literal[True] = True,
 ) -> tuple[Select, int]: ...
@@ -84,10 +84,10 @@ async def paginated_select_async(
 async def paginated_select_async(
     *,
     statement: Select,
-    filters: Sequence[BaseParam] | None = None,
-    order_by: BaseParam | None = None,
-    offset: BaseParam | None = None,
-    limit: BaseParam | None = None,
+    filters: Sequence[OrmClause] | None = None,
+    order_by: OrmClause | None = None,
+    offset: OrmClause | None = None,
+    limit: OrmClause | None = None,
     session: AsyncSession,
     return_total_entries: Literal[False],
 ) -> tuple[Select, None]: ...
@@ -96,10 +96,10 @@ async def paginated_select_async(
 async def paginated_select_async(
     *,
     statement: Select,
-    filters: Sequence[BaseParam | None] | None = None,
-    order_by: BaseParam | None = None,
-    offset: BaseParam | None = None,
-    limit: BaseParam | None = None,
+    filters: Sequence[OrmClause | None] | None = None,
+    order_by: OrmClause | None = None,
+    offset: OrmClause | None = None,
+    limit: OrmClause | None = None,
     session: AsyncSession,
     return_total_entries: bool = True,
 ) -> tuple[Select, int | None]:
@@ -129,10 +129,10 @@ async def paginated_select_async(
 def paginated_select(
     *,
     statement: Select,
-    filters: Sequence[BaseParam] | None = None,
-    order_by: BaseParam | None = None,
-    offset: BaseParam | None = None,
-    limit: BaseParam | None = None,
+    filters: Sequence[OrmClause] | None = None,
+    order_by: OrmClause | None = None,
+    offset: OrmClause | None = None,
+    limit: OrmClause | None = None,
     session: Session = NEW_SESSION,
     return_total_entries: Literal[True] = True,
 ) -> tuple[Select, int]: ...
@@ -142,10 +142,10 @@ def paginated_select(
 def paginated_select(
     *,
     statement: Select,
-    filters: Sequence[BaseParam] | None = None,
-    order_by: BaseParam | None = None,
-    offset: BaseParam | None = None,
-    limit: BaseParam | None = None,
+    filters: Sequence[OrmClause] | None = None,
+    order_by: OrmClause | None = None,
+    offset: OrmClause | None = None,
+    limit: OrmClause | None = None,
     session: Session = NEW_SESSION,
     return_total_entries: Literal[False],
 ) -> tuple[Select, None]: ...
@@ -155,10 +155,10 @@ def paginated_select(
 def paginated_select(
     *,
     statement: Select,
-    filters: Sequence[BaseParam] | None = None,
-    order_by: BaseParam | None = None,
-    offset: BaseParam | None = None,
-    limit: BaseParam | None = None,
+    filters: Sequence[OrmClause] | None = None,
+    order_by: OrmClause | None = None,
+    offset: OrmClause | None = None,
+    limit: OrmClause | None = None,
     session: Session = NEW_SESSION,
     return_total_entries: bool = True,
 ) -> tuple[Select, int | None]:
diff --git a/airflow/api_fastapi/common/parameters.py 
b/airflow/api_fastapi/common/parameters.py
index f69d64bd5d0..d7ce038d104 100644
--- a/airflow/api_fastapi/common/parameters.py
+++ b/airflow/api_fastapi/common/parameters.py
@@ -40,6 +40,7 @@ from pydantic import AfterValidator, BaseModel, NonNegativeInt
 from sqlalchemy import Column, and_, case, or_
 from sqlalchemy.inspection import inspect
 
+from airflow.api_fastapi.core_api.base import OrmClause
 from airflow.models import Base
 from airflow.models.asset import (
     AssetAliasModel,
@@ -65,18 +66,14 @@ if TYPE_CHECKING:
 T = TypeVar("T")
 
 
-class BaseParam(Generic[T], ABC):
-    """Base class for filters."""
+class BaseParam(OrmClause[T], ABC):
+    """Base class for path or query parameters with ORM transformation."""
 
     def __init__(self, value: T | None = None, skip_none: bool = True) -> None:
-        self.value = value
+        super().__init__(value)
         self.attribute: ColumnElement | None = None
         self.skip_none = skip_none
 
-    @abstractmethod
-    def to_orm(self, select: Select) -> Select:
-        pass
-
     def set_value(self, value: T | None) -> Self:
         self.value = value
         return self
diff --git a/airflow/api_fastapi/core_api/base.py 
b/airflow/api_fastapi/core_api/base.py
index d88ec1757eb..887f528f197 100644
--- a/airflow/api_fastapi/core_api/base.py
+++ b/airflow/api_fastapi/core_api/base.py
@@ -16,8 +16,16 @@
 # under the License.
 from __future__ import annotations
 
+from abc import ABC, abstractmethod
+from typing import TYPE_CHECKING, Generic, TypeVar
+
 from pydantic import BaseModel as PydanticBaseModel, ConfigDict
 
+if TYPE_CHECKING:
+    from sqlalchemy.sql import Select
+
+T = TypeVar("T")
+
 
 class BaseModel(PydanticBaseModel):
     """
@@ -39,3 +47,18 @@ class StrictBaseModel(BaseModel):
     """
 
     model_config = ConfigDict(from_attributes=True, populate_by_name=True, 
extra="forbid")
+
+
+class OrmClause(Generic[T], ABC):
+    """
+    Base class for filtering clauses with paginated_select.
+
+    Subclasses must implement the `to_orm` method; `value` is stored by `__init__`.
+    """
+
+    def __init__(self, value: T | None = None):
+        self.value = value
+
+    @abstractmethod
+    def to_orm(self, select: Select) -> Select:
+        pass
diff --git a/airflow/api_fastapi/core_api/openapi/v1-generated.yaml 
b/airflow/api_fastapi/core_api/openapi/v1-generated.yaml
index e136ee3df85..a4e7f5e2565 100644
--- a/airflow/api_fastapi/core_api/openapi/v1-generated.yaml
+++ b/airflow/api_fastapi/core_api/openapi/v1-generated.yaml
@@ -3058,6 +3058,8 @@ paths:
       summary: Get Dags
       description: Get all DAGs.
       operationId: get_dags
+      security:
+      - OAuth2PasswordBearer: []
       parameters:
       - name: limit
         in: query
@@ -3223,6 +3225,8 @@ paths:
       summary: Patch Dags
       description: Patch multiple DAGs.
       operationId: patch_dags
+      security:
+      - OAuth2PasswordBearer: []
       parameters:
       - name: update_mask
         in: query
@@ -3358,6 +3362,8 @@ paths:
       summary: Get Dag
       description: Get basic information about a DAG.
       operationId: get_dag
+      security:
+      - OAuth2PasswordBearer: []
       parameters:
       - name: dag_id
         in: path
@@ -3408,6 +3414,8 @@ paths:
       summary: Patch Dag
       description: Patch the specific DAG.
       operationId: patch_dag
+      security:
+      - OAuth2PasswordBearer: []
       parameters:
       - name: dag_id
         in: path
@@ -3474,6 +3482,8 @@ paths:
       summary: Delete Dag
       description: Delete the specific DAG.
       operationId: delete_dag
+      security:
+      - OAuth2PasswordBearer: []
       parameters:
       - name: dag_id
         in: path
@@ -3524,6 +3534,8 @@ paths:
       summary: Get Dag Details
       description: Get details of DAG.
       operationId: get_dag_details
+      security:
+      - OAuth2PasswordBearer: []
       parameters:
       - name: dag_id
         in: path
diff --git a/airflow/api_fastapi/core_api/routes/public/dags.py 
b/airflow/api_fastapi/core_api/routes/public/dags.py
index 01c44ed29a0..c5d95f78210 100644
--- a/airflow/api_fastapi/core_api/routes/public/dags.py
+++ b/airflow/api_fastapi/core_api/routes/public/dags.py
@@ -57,6 +57,11 @@ from airflow.api_fastapi.core_api.datamodels.dags import (
     DAGResponse,
 )
 from airflow.api_fastapi.core_api.openapi.exceptions import 
create_openapi_http_exception_doc
+from airflow.api_fastapi.core_api.security import (
+    EditableDagsFilterDep,
+    ReadableDagsFilterDep,
+    requires_access_dag,
+)
 from airflow.api_fastapi.logging.decorators import action_logging
 from airflow.exceptions import AirflowException, DagNotFound
 from airflow.models import DAG, DagModel
@@ -65,7 +70,7 @@ from airflow.models.dagrun import DagRun
 dags_router = AirflowRouter(tags=["DAG"], prefix="/dags")
 
 
-@dags_router.get("")
+@dags_router.get("", dependencies=[Depends(requires_access_dag(method="GET"))])
 def get_dags(
     limit: QueryLimit,
     offset: QueryOffset,
@@ -105,6 +110,7 @@ def get_dags(
             ).dynamic_depends()
         ),
     ],
+    readable_dags_filter: ReadableDagsFilterDep,
     session: SessionDep,
 ) -> DAGCollectionResponse:
     """Get all DAGs."""
@@ -132,6 +138,7 @@ def get_dags(
             tags,
             owners,
             last_dag_run_state,
+            readable_dags_filter,
         ],
         order_by=order_by,
         offset=offset,
@@ -156,6 +163,7 @@ def get_dags(
             status.HTTP_422_UNPROCESSABLE_ENTITY,
         ]
     ),
+    dependencies=[Depends(requires_access_dag(method="GET"))],
 )
 def get_dag(dag_id: str, session: SessionDep, request: Request) -> DAGResponse:
     """Get basic information about a DAG."""
@@ -182,6 +190,7 @@ def get_dag(dag_id: str, session: SessionDep, request: 
Request) -> DAGResponse:
             status.HTTP_404_NOT_FOUND,
         ]
     ),
+    dependencies=[Depends(requires_access_dag(method="GET"))],
 )
 def get_dag_details(dag_id: str, session: SessionDep, request: Request) -> 
DAGDetailsResponse:
     """Get details of DAG."""
@@ -208,7 +217,7 @@ def get_dag_details(dag_id: str, session: SessionDep, 
request: Request) -> DAGDe
             status.HTTP_404_NOT_FOUND,
         ]
     ),
-    dependencies=[Depends(action_logging())],
+    dependencies=[Depends(requires_access_dag(method="PUT")), 
Depends(action_logging())],
 )
 def patch_dag(
     dag_id: str,
@@ -251,7 +260,7 @@ def patch_dag(
             status.HTTP_404_NOT_FOUND,
         ]
     ),
-    dependencies=[Depends(action_logging())],
+    dependencies=[Depends(requires_access_dag(method="PUT")), 
Depends(action_logging())],
 )
 def patch_dags(
     patch_body: DAGPatchBody,
@@ -263,6 +272,7 @@ def patch_dags(
     only_active: QueryOnlyActiveFilter,
     paused: QueryPausedFilter,
     last_dag_run_state: QueryLastDagRunStateFilter,
+    editable_dags_filter: EditableDagsFilterDep,
     session: SessionDep,
     update_mask: list[str] | None = Query(None),
 ) -> DAGCollectionResponse:
@@ -283,7 +293,7 @@ def patch_dags(
 
     dags_select, total_entries = paginated_select(
         statement=generate_dag_with_latest_run_query(),
-        filters=[only_active, paused, dag_id_pattern, tags, owners, 
last_dag_run_state],
+        filters=[only_active, paused, dag_id_pattern, tags, owners, 
last_dag_run_state, editable_dags_filter],
         order_by=None,
         offset=offset,
         limit=limit,
@@ -313,7 +323,7 @@ def patch_dags(
             status.HTTP_422_UNPROCESSABLE_ENTITY,
         ]
     ),
-    dependencies=[Depends(action_logging())],
+    dependencies=[Depends(requires_access_dag(method="DELETE")), 
Depends(action_logging())],
 )
 def delete_dag(
     dag_id: str,
diff --git a/airflow/api_fastapi/core_api/security.py 
b/airflow/api_fastapi/core_api/security.py
index 2e5745ffed2..7ef10be39e3 100644
--- a/airflow/api_fastapi/core_api/security.py
+++ b/airflow/api_fastapi/core_api/security.py
@@ -36,11 +36,15 @@ from 
airflow.api_fastapi.auth.managers.models.resource_details import (
     PoolDetails,
     VariableDetails,
 )
+from airflow.api_fastapi.core_api.base import OrmClause
 from airflow.configuration import conf
+from airflow.models.dag import DagModel
 from airflow.utils.jwt_signer import JWTSigner, get_signing_key
 
 if TYPE_CHECKING:
-    from airflow.api_fastapi.auth.managers.base_auth_manager import 
ResourceMethod
+    from sqlalchemy.sql import Select
+
+    from airflow.api_fastapi.auth.managers.base_auth_manager import 
BaseAuthManager, ResourceMethod
 
 oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
 
@@ -63,6 +67,9 @@ def get_user(token_str: Annotated[str, 
Depends(oauth2_scheme)]) -> BaseUser:
         raise HTTPException(status.HTTP_403_FORBIDDEN, "Forbidden")
 
 
+GetUserDep = Annotated[BaseUser, Depends(get_user)]
+
+
 async def get_user_with_exception_handling(request: Request) -> BaseUser | 
None:
     # Currently the UI does not support JWT authentication, this method 
defines a fallback if no token is provided by the UI.
     # We can remove this method when issue 
https://github.com/apache/airflow/issues/44884 is done.
@@ -80,12 +87,14 @@ async def get_user_with_exception_handling(request: 
Request) -> BaseUser | None:
     return get_user(token_str)
 
 
-def requires_access_dag(method: ResourceMethod, access_entity: DagAccessEntity 
| None = None) -> Callable:
+def requires_access_dag(
+    method: ResourceMethod, access_entity: DagAccessEntity | None = None
+) -> Callable[[Request, BaseUser], None]:
     def inner(
         request: Request,
-        user: Annotated[BaseUser, Depends(get_user)],
+        user: GetUserDep,
     ) -> None:
-        dag_id = request.path_params.get("dag_id") or 
request.query_params.get("dag_id")
+        dag_id: str | None = request.path_params.get("dag_id")
 
         _requires_access(
             is_authorized_callback=lambda: 
get_auth_manager().is_authorized_dag(
@@ -96,10 +105,40 @@ def requires_access_dag(method: ResourceMethod, 
access_entity: DagAccessEntity |
     return inner
 
 
+class PermittedDagFilter(OrmClause[set[str]]):
+    """ORM clause restricting a select to the dag ids in `value` (the user's permitted dags)."""
+
+    def to_orm(self, select: Select) -> Select:
+        return select.where(DagModel.dag_id.in_(self.value))
+
+
+def permitted_dag_filter_factory(method: ResourceMethod) -> Callable[[Request, 
BaseUser], PermittedDagFilter]:
+    """
+    Create a callable for Depends in FastAPI that returns a filter of the 
permitted dags for the user.
+
+    :param method: resource method passed to get_permitted_dag_ids, e.g. "GET" (readable) or "PUT" (editable).
+    :return: The callable that can be used as Depends in FastAPI.
+    """
+
+    def depends_permitted_dags_filter(
+        request: Request,
+        user: GetUserDep,
+    ) -> PermittedDagFilter:
+        auth_manager: BaseAuthManager = request.app.state.auth_manager
+        permitted_dags: set[str] = 
auth_manager.get_permitted_dag_ids(user=user, method=method)
+        return PermittedDagFilter(permitted_dags)
+
+    return depends_permitted_dags_filter
+
+
+EditableDagsFilterDep = Annotated[PermittedDagFilter, 
Depends(permitted_dag_filter_factory("PUT"))]
+ReadableDagsFilterDep = Annotated[PermittedDagFilter, 
Depends(permitted_dag_filter_factory("GET"))]
+
+
 def requires_access_pool(method: ResourceMethod) -> Callable[[Request, 
BaseUser], None]:
     def inner(
         request: Request,
-        user: Annotated[BaseUser, Depends(get_user)],
+        user: GetUserDep,
     ) -> None:
         pool_name = request.path_params.get("pool_name")
 
@@ -115,7 +154,7 @@ def requires_access_pool(method: ResourceMethod) -> 
Callable[[Request, BaseUser]
 def requires_access_connection(method: ResourceMethod) -> Callable[[Request, 
BaseUser], None]:
     def inner(
         request: Request,
-        user: Annotated[BaseUser, Depends(get_user)],
+        user: GetUserDep,
     ) -> None:
         connection_id = request.path_params.get("connection_id")
 
diff --git a/docker_tests/test_docker_compose_quick_start.py 
b/docker_tests/test_docker_compose_quick_start.py
index c91dbb36dba..ccf4ef18fcd 100644
--- a/docker_tests/test_docker_compose_quick_start.py
+++ b/docker_tests/test_docker_compose_quick_start.py
@@ -18,10 +18,12 @@ from __future__ import annotations
 
 import json
 import os
+import re
 import shlex
 from pprint import pprint
 from shutil import copyfile
 from time import sleep
+from urllib.parse import parse_qs, urlparse
 
 import pytest
 import requests
@@ -34,18 +36,67 @@ from docker_tests.constants import SOURCE_ROOT
 
 # isort:on (needed to workaround isort bug)
 
+DOCKER_COMPOSE_HOST_PORT = os.environ.get("HOST_PORT", "localhost:8080")
 AIRFLOW_WWW_USER_USERNAME = os.environ.get("_AIRFLOW_WWW_USER_USERNAME", 
"airflow")
 AIRFLOW_WWW_USER_PASSWORD = os.environ.get("_AIRFLOW_WWW_USER_PASSWORD", 
"airflow")
 DAG_ID = "example_bash_operator"
 DAG_RUN_ID = "test_dag_run_id"
 
 
-def api_request(method: str, path: str, base_url: str = 
"http://localhost:8080/public";, **kwargs) -> dict:
+def get_jwt_token() -> str:
+    """Get the JWT token.
+
+    Note: API server is still using FAB Auth Manager.
+
+    Steps:
+    1. Get the login page to get the csrf token
+        - The csrf token is in the hidden input field with id "csrf_token"
+    2. Login with the username and password
+        - Must use the same session to keep the csrf token session
+    3. Extract the JWT token from the redirect url
+        - The login request follows the redirect; the final response url is used
+        - The redirect url should have the JWT token as a query parameter
+
+    :return: The JWT token
+    """
+    # get csrf token from login page
+    session = requests.Session()
+    get_login_form_response = 
session.get(f"http://{DOCKER_COMPOSE_HOST_PORT}/auth/login";)
+    csrf_token = re.search(
+        r'<input id="csrf_token" name="csrf_token" type="hidden" 
value="(.+?)">',
+        get_login_form_response.text,
+    )
+    assert csrf_token, "Failed to get csrf token from login page"
+    csrf_token_str = csrf_token.group(1)
+    assert csrf_token_str, "Failed to get csrf token from login page"
+    # login with form data
+    login_response = session.post(
+        f"http://{DOCKER_COMPOSE_HOST_PORT}/auth/login";,
+        data={
+            "username": AIRFLOW_WWW_USER_USERNAME,
+            "password": AIRFLOW_WWW_USER_PASSWORD,
+            "csrf_token": csrf_token_str,
+        },
+    )
+    redirect_url = login_response.url
+    # ensure redirect_url is a string
+    redirect_url_str = str(redirect_url) if redirect_url is not None else ""
+    assert "/?token" in redirect_url_str, f"Login failed with redirect url 
{redirect_url_str}"
+    parsed_url = urlparse(redirect_url_str)
+    query_params = parse_qs(str(parsed_url.query))
+    jwt_token_list = query_params.get("token")
+    jwt_token = jwt_token_list[0] if jwt_token_list else None
+    assert jwt_token, f"Failed to get JWT token from redirect url 
{redirect_url_str}"
+    return jwt_token
+
+
+def api_request(
+    method: str, path: str, base_url: str = 
f"http://{DOCKER_COMPOSE_HOST_PORT}/public";, **kwargs
+) -> dict:
     response = requests.request(
         method=method,
         url=f"{base_url}/{path}",
-        auth=(AIRFLOW_WWW_USER_USERNAME, AIRFLOW_WWW_USER_PASSWORD),
-        headers={"Content-Type": "application/json"},
+        headers={"Authorization": f"Bearer {get_jwt_token()}", "Content-Type": 
"application/json"},
         **kwargs,
     )
     response.raise_for_status()
diff --git a/kubernetes_tests/test_base.py b/kubernetes_tests/test_base.py
index 31e1924c18a..31248b02ac4 100644
--- a/kubernetes_tests/test_base.py
+++ b/kubernetes_tests/test_base.py
@@ -25,6 +25,7 @@ import time
 from datetime import datetime, timezone
 from pathlib import Path
 from subprocess import check_call, check_output
+from urllib.parse import parse_qs, urlparse
 
 import pytest
 import requests
@@ -59,9 +60,12 @@ class BaseK8STest:
 
     @pytest.fixture(autouse=True)
     def base_tests_setup(self, request):
-        self.set_api_server_base_url_config()
-        self.rollout_restart_deployment("airflow-api-server")
-        self.ensure_deployment_health("airflow-api-server")
+        if self.set_api_server_base_url_config():
+            # only restart the deployment if the configmap was updated
+            # speed up the test and make the airflow-api-server deployment 
more stable
+            self.rollout_restart_deployment("airflow-api-server")
+            self.ensure_deployment_health("airflow-api-server")
+
         # Replacement for unittests.TestCase.id()
         self.test_id = f"{request.node.cls.__name__}_{request.node.name}"
         self.session = self._get_session_with_retries()
@@ -126,17 +130,92 @@ class BaseK8STest:
         if names:
             check_call(["kubectl", "delete", "pod", names[0]])
 
+    @staticmethod
+    def _get_jwt_token(username: str, password: str) -> str:
+        """Get the JWT token for the given username and password.
+
+        Note: API server is still using FAB Auth Manager.
+
+        Steps:
+        1. Get the login page to get the csrf token
+            - The csrf token is in the hidden input field with id "csrf_token"
+        2. Login with the username and password
+            - Must use the same session to keep the csrf token session
+        3. Extract the JWT token from the redirect url
+            - Expected to have a connection error
+            - The redirect url should have the JWT token as a query parameter
+        Requests are sent through an HTTPAdapter with Retry(total=5, backoff_factor=10).
+
+        :param username: The username to use for the login
+        :param password: The password to use for the login
+        :return: The JWT token
+        """
+        # get csrf token from login page
+        retry = Retry(total=5, backoff_factor=10)
+        session = requests.Session()
+        session.mount("http://";, HTTPAdapter(max_retries=retry))
+        session.mount("https://";, HTTPAdapter(max_retries=retry))
+        get_login_form_response = 
session.get(f"http://{KUBERNETES_HOST_PORT}/auth/login";)
+        csrf_token = re.search(
+            r'<input id="csrf_token" name="csrf_token" type="hidden" 
value="(.+?)">',
+            get_login_form_response.text,
+        )
+        assert csrf_token, "Failed to get csrf token from login page"
+        csrf_token_str = csrf_token.group(1)
+        assert csrf_token_str, "Failed to get csrf token from login page"
+        # login with form data
+        login_response = session.post(
+            f"http://{KUBERNETES_HOST_PORT}/auth/login";,
+            data={"username": username, "password": password, "csrf_token": 
csrf_token_str},
+        )
+        redirect_url = login_response.url
+        # ensure redirect_url is a string
+        redirect_url_str = str(redirect_url) if redirect_url is not None else 
""
+        assert "/?token" in redirect_url_str, f"Login failed with redirect url 
{redirect_url_str}"
+        parsed_url = urlparse(redirect_url_str)
+        query_params = parse_qs(str(parsed_url.query))
+        jwt_token_list = query_params.get("token")
+        jwt_token = jwt_token_list[0] if jwt_token_list else None
+        assert jwt_token, f"Failed to get JWT token from redirect url 
{redirect_url_str}"
+        return jwt_token
+
     def _get_session_with_retries(self):
+        class JWTRefreshAdapter(HTTPAdapter):
+            def __init__(self, base_instance, **kwargs):
+                self.base_instance = base_instance
+                super().__init__(**kwargs)
+
+            def send(self, request, **kwargs):
+                response = super().send(request, **kwargs)
+                if response.status_code in (401, 403):
+                    # Refresh token and update the Authorization header with 
retry logic.
+                    attempts = 0
+                    jwt_token = None
+                    while attempts < 5:
+                        try:
+                            jwt_token = 
self.base_instance._get_jwt_token("admin", "admin")
+                            break
+                        except Exception:
+                            attempts += 1
+                            time.sleep(1)
+                    if jwt_token is None:
+                        raise Exception("Failed to refresh JWT token after 5 
attempts")
+                    request.headers["Authorization"] = f"Bearer {jwt_token}"
+                    response = super().send(request, **kwargs)
+                return response
+
+        jwt_token = self._get_jwt_token("admin", "admin")
         session = requests.Session()
-        session.auth = ("admin", "admin")
+        session.headers.update({"Authorization": f"Bearer {jwt_token}"})
         retries = Retry(
-            total=3,
+            total=5,
             backoff_factor=10,
             status_forcelist=[404],
             allowed_methods=Retry.DEFAULT_ALLOWED_METHODS | 
frozenset(["PATCH", "POST"]),
         )
-        session.mount("http://";, HTTPAdapter(max_retries=retries))
-        session.mount("https://";, HTTPAdapter(max_retries=retries))
+        adapter = JWTRefreshAdapter(self, max_retries=retries)
+        session.mount("http://";, adapter)
+        session.mount("https://";, adapter)
         return session
 
     def _ensure_airflow_api_server_is_healthy(self):
@@ -236,8 +315,11 @@ class BaseK8STest:
         # escape newlines and double quotes
         return airflow_cfg_str.replace("\n", "\\n").replace('"', '\\"')
 
-    def set_api_server_base_url_config(self):
-        """Set [api/base_url] with `f"http://{KUBERNETES_HOST_PORT}"` as env 
in k8s configmap."""
+    def set_api_server_base_url_config(self) -> bool:
+        """Set [api/base_url] with `f"http://{KUBERNETES_HOST_PORT}"` as env 
in k8s configmap.
+
+        :return: True if the configmap was updated successfully, False 
otherwise
+        """
         configmap_name = "airflow-config"
         configmap_key = "airflow.cfg"
         original_configmap_json_str = check_output(
@@ -250,7 +332,7 @@ class BaseK8STest:
         airflow_cfg_dict = 
self._parse_airflow_cfg_as_dict(original_airflow_cfg)
         airflow_cfg_dict["api"]["base_url"] = f"http://{KUBERNETES_HOST_PORT}";
         # update the configmap with the new airflow.cfg
-        check_call(
+        patch_configmap_result = check_output(
             [
                 "kubectl",
                 "patch",
@@ -263,7 +345,10 @@ class BaseK8STest:
                 "-p",
                 f'{{"data": {{"{configmap_key}": 
"{self._parse_airflow_cfg_dict_as_escaped_toml(airflow_cfg_dict)}"}}}}',
             ]
-        )
+        ).decode()
+        if "(no change)" in patch_configmap_result:
+            return False
+        return True
 
     def ensure_dag_expected_state(self, host, logical_date, dag_id, 
expected_final_state, timeout):
         tries = 0
diff --git a/kubernetes_tests/test_kubernetes_executor.py 
b/kubernetes_tests/test_kubernetes_executor.py
index 63d1389b171..8a7596f3cda 100644
--- a/kubernetes_tests/test_kubernetes_executor.py
+++ b/kubernetes_tests/test_kubernetes_executor.py
@@ -50,7 +50,7 @@ class TestKubernetesExecutor(BaseK8STest):
             timeout=300,
         )
 
-    @pytest.mark.execution_timeout(300)
+    @pytest.mark.execution_timeout(500)
     def test_integration_run_dag_with_scheduler_failure(self):
         dag_id = "example_kubernetes_executor"
 
diff --git a/kubernetes_tests/test_other_executors.py 
b/kubernetes_tests/test_other_executors.py
index f8203b069b4..327e252825a 100644
--- a/kubernetes_tests/test_other_executors.py
+++ b/kubernetes_tests/test_other_executors.py
@@ -16,8 +16,6 @@
 # under the License.
 from __future__ import annotations
 
-import time
-
 import pytest
 
 from kubernetes_tests.test_base import (
@@ -68,8 +66,7 @@ class TestCeleryAndLocalExecutor(BaseK8STest):
         dag_run_id, logical_date = self.start_job_in_kubernetes(dag_id, 
self.host)
 
         self._delete_airflow_pod("scheduler")
-
-        time.sleep(10)  # give time for pod to restart
+        self.ensure_deployment_health("airflow-scheduler")
 
         # Wait some time for the operator to complete
         self.monitor_task(
diff --git a/tests/api_fastapi/core_api/routes/public/test_dags.py 
b/tests/api_fastapi/core_api/routes/public/test_dags.py
index 21f1bd0235c..93c7c20ea3a 100644
--- a/tests/api_fastapi/core_api/routes/public/test_dags.py
+++ b/tests/api_fastapi/core_api/routes/public/test_dags.py
@@ -235,13 +235,31 @@ class TestGetDags(TestDagEndpoint):
     )
     def test_get_dags(self, test_client, query_params, expected_total_entries, 
expected_ids):
         response = test_client.get("/public/dags", params=query_params)
-
         assert response.status_code == 200
         body = response.json()
 
         assert body["total_entries"] == expected_total_entries
         assert [dag["dag_id"] for dag in body["dags"]] == expected_ids
 
+    @mock.patch("airflow.api_fastapi.auth.managers.base_auth_manager.BaseAuthManager.get_permitted_dag_ids")
+    def test_get_dags_should_call_permitted_dag_ids(self, mock_get_permitted_dag_ids, test_client):
+        mock_get_permitted_dag_ids.return_value = {DAG1_ID, DAG2_ID}
+        response = test_client.get("/public/dags")
+        mock_get_permitted_dag_ids.assert_called_once_with(user=mock.ANY, method="GET")
+        assert response.status_code == 200
+        body = response.json()
+
+        assert body["total_entries"] == 2
+        assert [dag["dag_id"] for dag in body["dags"]] == [DAG1_ID, DAG2_ID]
+
+    def test_get_dags_should_response_401(self, unauthenticated_test_client):
+        response = unauthenticated_test_client.get("/public/dags")
+        assert response.status_code == 401
+
+    def test_get_dags_should_response_403(self, unauthorized_test_client):
+        response = unauthorized_test_client.get("/public/dags")
+        assert response.status_code == 403
+
 
 class TestPatchDag(TestDagEndpoint):
     """Unit tests for Patch DAG."""
@@ -268,6 +286,14 @@ class TestPatchDag(TestDagEndpoint):
             assert body["is_paused"] == expected_is_paused
             check_last_log(session, dag_id=dag_id, event="patch_dag", logical_date=None)
 
+    def test_patch_dag_should_response_401(self, unauthenticated_test_client):
+        response = unauthenticated_test_client.patch(f"/public/dags/{DAG1_ID}", json={"is_paused": True})
+        assert response.status_code == 401
+
+    def test_patch_dag_should_response_403(self, unauthorized_test_client):
+        response = unauthorized_test_client.patch(f"/public/dags/{DAG1_ID}", json={"is_paused": True})
+        assert response.status_code == 403
+
 
 class TestPatchDags(TestDagEndpoint):
     """Unit tests for Patch DAGs."""
@@ -333,6 +359,26 @@ class TestPatchDags(TestDagEndpoint):
             assert paused_dag_ids == expected_paused_ids
             check_last_log(session, dag_id=DAG1_ID, event="patch_dag", logical_date=None)
 
+    @mock.patch("airflow.api_fastapi.auth.managers.base_auth_manager.BaseAuthManager.get_permitted_dag_ids")
+    def test_patch_dags_should_call_permitted_dag_ids(self, mock_get_permitted_dag_ids, test_client):
+        mock_get_permitted_dag_ids.return_value = {DAG1_ID, DAG2_ID}
+        response = test_client.patch(
+            "/public/dags", json={"is_paused": False}, params={"only_active": False, "dag_id_pattern": "~"}
+        )
+        mock_get_permitted_dag_ids.assert_called_once_with(user=mock.ANY, method="PUT")
+        assert response.status_code == 200
+        body = response.json()
+
+        assert [dag["dag_id"] for dag in body["dags"]] == [DAG1_ID, DAG2_ID]
+
+    def test_patch_dags_should_response_401(self, unauthenticated_test_client):
+        response = unauthenticated_test_client.patch("/public/dags", json={"is_paused": True})
+        assert response.status_code == 401
+
+    def test_patch_dags_should_response_403(self, unauthorized_test_client):
+        response = unauthorized_test_client.patch("/public/dags", json={"is_paused": True})
+        assert response.status_code == 403
+
 
 class TestDagDetails(TestDagEndpoint):
     """Unit tests for DAG Details."""
@@ -414,6 +460,14 @@ class TestDagDetails(TestDagEndpoint):
         }
         assert res_json == expected
 
+    def test_dag_details_should_response_401(self, unauthenticated_test_client):
+        response = unauthenticated_test_client.get(f"/public/dags/{DAG1_ID}/details")
+        assert response.status_code == 401
+
+    def test_dag_details_should_response_403(self, unauthorized_test_client):
+        response = unauthorized_test_client.get(f"/public/dags/{DAG1_ID}/details")
+        assert response.status_code == 403
+
 
 class TestGetDag(TestDagEndpoint):
     """Unit tests for Get DAG."""
@@ -462,6 +516,14 @@ class TestGetDag(TestDagEndpoint):
         }
         assert res_json == expected
 
+    def test_get_dag_should_response_401(self, unauthenticated_test_client):
+        response = unauthenticated_test_client.get(f"/public/dags/{DAG1_ID}")
+        assert response.status_code == 401
+
+    def test_get_dag_should_response_403(self, unauthorized_test_client):
+        response = unauthorized_test_client.get(f"/public/dags/{DAG1_ID}")
+        assert response.status_code == 403
+
 
 class TestDeleteDAG(TestDagEndpoint):
     """Unit tests for Delete DAG."""
@@ -521,5 +583,14 @@ class TestDeleteDAG(TestDagEndpoint):
 
         details_response = test_client.get(f"{API_PREFIX}/{dag_id}/details")
         assert details_response.status_code == status_code_details
+
         if details_response.status_code == 204:
             check_last_log(session, dag_id=dag_id, event="delete_dag", logical_date=None)
+
+    def test_delete_dag_should_response_401(self, unauthenticated_test_client):
+        response = unauthenticated_test_client.delete(f"{API_PREFIX}/{DAG1_ID}")
+        assert response.status_code == 401
+
+    def test_delete_dag_should_response_403(self, unauthorized_test_client):
+        response = unauthorized_test_client.delete(f"{API_PREFIX}/{DAG1_ID}")
+        assert response.status_code == 403
diff --git a/tests/api_fastapi/core_api/test_security.py b/tests/api_fastapi/core_api/test_security.py
index f777cda5e83..193cc67798f 100644
--- a/tests/api_fastapi/core_api/test_security.py
+++ b/tests/api_fastapi/core_api/test_security.py
@@ -88,11 +88,10 @@ class TestFastApiSecurity:
         auth_manager = Mock()
         auth_manager.is_authorized_dag.return_value = True
         mock_get_auth_manager.return_value = auth_manager
+        fastapi_request = Mock()
+        fastapi_request.path_params.return_value = {}
 
-        mock_request = Mock()
-        mock_request.path_params.return_value = {"dag_id": "test"}
-
-        requires_access_dag("GET", DagAccessEntity.CODE)(mock_request, Mock())
+        requires_access_dag("GET", DagAccessEntity.CODE)(fastapi_request, Mock())
 
         auth_manager.is_authorized_dag.assert_called_once()
 
@@ -101,11 +100,13 @@ class TestFastApiSecurity:
         auth_manager = Mock()
         auth_manager.is_authorized_dag.return_value = False
         mock_get_auth_manager.return_value = auth_manager
+        fastapi_request = Mock()
+        fastapi_request.path_params.return_value = {}
 
         mock_request = Mock()
         mock_request.path_params.return_value = {}
 
         with pytest.raises(HTTPException, match="Forbidden"):
-            requires_access_dag("GET", DagAccessEntity.CODE)(mock_request, Mock())
+            requires_access_dag("GET", DagAccessEntity.CODE)(fastapi_request, Mock())
 
         auth_manager.is_authorized_dag.assert_called_once()


Reply via email to