vincbeck commented on code in PR #55306:
URL: https://github.com/apache/airflow/pull/55306#discussion_r2330654604


##########
airflow-core/src/airflow/api_fastapi/core_api/security.py:
##########
@@ -255,21 +255,28 @@ def inner(
         request: BulkBody[PoolBody],
         user: GetUserDep,
     ) -> None:
+        existing_pool_names = [
+            cast("str", entity) if action.action == BulkAction.DELETE else cast("PoolBody", entity).pool
+            for action in request.actions
+            for entity in action.entities
+            if action.action != BulkAction.CREATE
+        ]
+        teams = Pool.get_bulk_team_name(existing_pool_names)
+
         requests: list[IsAuthorizedPoolRequest] = []
         for action in request.actions:
-            requests.extend(
-                [
-                    {
-                        "method": MAP_BULK_ACTION_TO_AUTH_METHOD[action.action],
-                        "details": PoolDetails(
-                            name=cast("str", pool)
-                            if action.action == BulkAction.DELETE
-                            else cast("PoolBody", pool).pool
-                        ),
-                    }
-                    for pool in action.entities
-                ]
-            )
+            for pool in action.entities:
+                pool_name = (
+                    cast("str", pool) if action.action == BulkAction.DELETE else cast("PoolBody", pool).pool
+                )
+                req: IsAuthorizedPoolRequest = {
+                    "method": MAP_BULK_ACTION_TO_AUTH_METHOD[action.action],
+                    "details": PoolDetails(
+                        name=pool_name,
+                        team_name=teams.get(pool_name),
+                    ),
+                }
+                requests.append(req)

Review Comment:
   Fair point, I can add some



##########
airflow-core/src/airflow/api_fastapi/core_api/routes/public/import_error.py:
##########
@@ -179,19 +179,21 @@ def get_import_errors(
         limit=limit,
         session=session,
     )
-    import_errors_result: Iterable[tuple[ParseImportError, Iterable[str]]] = groupby(
+    import_errors_result: Iterable[tuple[ParseImportError, Iterable]] = groupby(
         session.execute(import_errors_select), itemgetter(0)
     )
 
     import_errors = []
     for import_error, file_dag_ids in import_errors_result:
+        dag_ids = [dag_id for _, dag_id in file_dag_ids]
+        teams = DagModel.get_bulk_team_name(dag_ids, session=session)

Review Comment:
   Fair point for the renaming.
   
   Regarding the second point, this API has always worked at the Dag file level: it returns the errors per Dag file. I think changing this behavior would be a breaking change with many consequences.



##########
airflow-core/src/airflow/api_fastapi/core_api/routes/public/import_error.py:
##########
@@ -179,19 +179,21 @@ def get_import_errors(
         limit=limit,
         session=session,
     )
-    import_errors_result: Iterable[tuple[ParseImportError, Iterable[str]]] = groupby(
+    import_errors_result: Iterable[tuple[ParseImportError, Iterable]] = groupby(
         session.execute(import_errors_select), itemgetter(0)
     )
 
     import_errors = []
     for import_error, file_dag_ids in import_errors_result:
+        dag_ids = [dag_id for _, dag_id in file_dag_ids]

Review Comment:
   Nope, see my comment above.



##########
airflow-core/src/airflow/models/connection.py:
##########
@@ -598,3 +598,13 @@ def get_team_name(connection_id: str, session=NEW_SESSION) -> str | None:
             .where(Connection.conn_id == connection_id)
         )
         return session.scalar(stmt)
+
+    @staticmethod
+    @provide_session
+    def get_bulk_team_name(connection_ids: list[str], session=NEW_SESSION) -> dict[str, str | None]:

Review Comment:
   Agree



##########
airflow-core/tests/unit/api_fastapi/core_api/test_security.py:
##########
@@ -106,7 +106,7 @@ async def test_get_user_expired_token(self, mock_get_auth_manager):
 
     @pytest.mark.db_test
     @patch("airflow.api_fastapi.core_api.security.get_auth_manager")
-    async def test_requires_access_dag_authorized(self, mock_get_auth_manager):

Review Comment:
   It is not related; it is something I noticed while working on the tests. The `async` is not necessary here. I am pretty sure it was copy-pasted from the tests for `resolve_user_from_token`, which is an async function.



##########
airflow-core/src/airflow/api_fastapi/core_api/routes/public/import_error.py:
##########
@@ -179,19 +179,21 @@ def get_import_errors(
         limit=limit,
         session=session,
     )
-    import_errors_result: Iterable[tuple[ParseImportError, Iterable[str]]] = groupby(
+    import_errors_result: Iterable[tuple[ParseImportError, Iterable]] = groupby(

Review Comment:
   Good catch! I should have called it out, actually. While working on this issue I found that bug: the items of `file_dag_ids` (the second element of each `import_errors_result` entry) are not `str` but `tuple`s, so today we pass a `tuple` as the Dag ID to `batch_is_authorized_dag` (needless to say, garbage data). That is also why I added this test in `airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_import_error.py`:
   ```
   mock_batch_is_authorized_dag.assert_called_once_with(
       [
           {
               "method": "GET",
               "details": DagDetails(id=permitted_dag_model.dag_id, team_name=team),
           }
       ],
       user=mock.ANY,
   )
   ```
   
   It is there to catch this bug.
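   To illustrate the bug, here is a minimal, self-contained sketch. It assumes plain tuples as a stand-in for the rows returned by `session.execute(import_errors_select)`; in the real code the rows are SQLAlchemy `Row` objects and the grouping key is a `ParseImportError`, not a file name string. Each item of a `groupby` group is the whole row, not a Dag ID, so the Dag ID has to be unpacked from it.

```python
from itertools import groupby
from operator import itemgetter

# Hypothetical stand-in for the query result: (import_error, dag_id) rows,
# already ordered by import error so that groupby sees contiguous keys.
rows = [
    ("error_a.py", "dag_1"),
    ("error_a.py", "dag_2"),
    ("error_b.py", "dag_3"),
]

buggy = {}
fixed = {}
for import_error, file_dag_ids in groupby(rows, itemgetter(0)):
    # Materialize the group: a groupby group can only be consumed once.
    group = list(file_dag_ids)
    # Buggy: treats each grouped row as if it were a Dag ID string.
    buggy[import_error] = [dag_id for dag_id in group]
    # Fixed: unpack each row and keep only the Dag ID.
    fixed[import_error] = [dag_id for _, dag_id in group]

print(buggy["error_a.py"])  # [('error_a.py', 'dag_1'), ('error_a.py', 'dag_2')]
print(fixed["error_a.py"])  # ['dag_1', 'dag_2']
```

   The fix in this PR, `[dag_id for _, dag_id in file_dag_ids]`, is the "fixed" branch above.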



##########
airflow-core/src/airflow/api_fastapi/core_api/security.py:
##########
@@ -255,21 +255,28 @@ def inner(
         request: BulkBody[PoolBody],
         user: GetUserDep,
     ) -> None:
+        existing_pool_names = [
+            cast("str", entity) if action.action == BulkAction.DELETE else cast("PoolBody", entity).pool
+            for action in request.actions
+            for entity in action.entities
+            if action.action != BulkAction.CREATE
+        ]
+        teams = Pool.get_bulk_team_name(existing_pool_names)
+
         requests: list[IsAuthorizedPoolRequest] = []
         for action in request.actions:
-            requests.extend(
-                [
-                    {
-                        "method": MAP_BULK_ACTION_TO_AUTH_METHOD[action.action],
-                        "details": PoolDetails(
-                            name=cast("str", pool)
-                            if action.action == BulkAction.DELETE
-                            else cast("PoolBody", pool).pool
-                        ),
-                    }
-                    for pool in action.entities
-                ]
-            )
+            for pool in action.entities:
+                pool_name = (
+                    cast("str", pool) if action.action == BulkAction.DELETE else cast("PoolBody", pool).pool
+                )
+                req: IsAuthorizedPoolRequest = {
+                    "method": MAP_BULK_ACTION_TO_AUTH_METHOD[action.action],
+                    "details": PoolDetails(
+                        name=pool_name,
+                        team_name=teams.get(pool_name),

Review Comment:
   Indeed, `existing_pool_names` contains only pool names that are associated with one team. But the authorization requests we send contain all entities (here, pools). Look at line `267`: `for action in request.actions:`. We loop over the entire list of entities. If a pool is not associated with a team, `team_name` is `None` in the authorization request (`teams.get(pool_name)` returns `None` for a missing key). All we do here is pass information to the auth manager so that it can make the authorization decision.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
